Pull request #522 — "Fix quantize logic (and LoRA adapter weights loading)" — opened by yuz299.
File changed: geneformer/perturber_utils.py
|
@@ -127,7 +127,10 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False):
|
|
| 127 |
output_hidden_states = (mode == "eval")
|
| 128 |
|
| 129 |
# Quantization logic
|
| 130 |
-
if quantize:
|
|
|
|
|
|
|
|
|
|
| 131 |
if inference_only:
|
| 132 |
quantize_config = BitsAndBytesConfig(load_in_8bit=True)
|
| 133 |
peft_config = None
|
|
@@ -138,13 +141,22 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False):
|
|
| 138 |
bnb_4bit_quant_type="nf4",
|
| 139 |
bnb_4bit_compute_dtype=torch.bfloat16,
|
| 140 |
)
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
else:
|
| 149 |
quantize_config = None
|
| 150 |
peft_config = None
|
|
@@ -181,14 +193,22 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False):
|
|
| 181 |
model.eval()
|
| 182 |
|
| 183 |
# Handle device placement and PEFT
|
|
|
|
|
|
|
| 184 |
if not quantize:
|
| 185 |
# Only move non-quantized models
|
| 186 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 187 |
model = model.to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
elif peft_config:
|
| 189 |
# Apply PEFT for quantized models (except MTLCellClassifier and CellClassifier-QuantInf)
|
| 190 |
model.enable_input_require_grads()
|
| 191 |
model = get_peft_model(model, peft_config)
|
|
|
|
| 192 |
|
| 193 |
return model
|
| 194 |
|
|
|
|
| 127 |
output_hidden_states = (mode == "eval")
|
| 128 |
|
| 129 |
# Quantization logic
|
| 130 |
+
if isinstance(quantize, dict):
|
| 131 |
+
quantize_config = quantize.get("bnb_config", None)
|
| 132 |
+
peft_config = quantize.get("peft_config", None)
|
| 133 |
+
elif quantize:
|
| 134 |
if inference_only:
|
| 135 |
quantize_config = BitsAndBytesConfig(load_in_8bit=True)
|
| 136 |
peft_config = None
|
|
|
|
| 141 |
bnb_4bit_quant_type="nf4",
|
| 142 |
bnb_4bit_compute_dtype=torch.bfloat16,
|
| 143 |
)
|
| 144 |
+
try:
|
| 145 |
+
peft_config = LoraConfig(
|
| 146 |
+
lora_alpha=128,
|
| 147 |
+
lora_dropout=0.1,
|
| 148 |
+
r=64,
|
| 149 |
+
bias="none",
|
| 150 |
+
task_type="TokenClassification",
|
| 151 |
+
)
|
| 152 |
+
except ValueError as e:
|
| 153 |
+
peft_config = LoraConfig(
|
| 154 |
+
lora_alpha=128,
|
| 155 |
+
lora_dropout=0.1,
|
| 156 |
+
r=64,
|
| 157 |
+
bias="none",
|
| 158 |
+
task_type="TOKEN_CLS",
|
| 159 |
+
)
|
| 160 |
else:
|
| 161 |
quantize_config = None
|
| 162 |
peft_config = None
|
|
|
|
| 193 |
model.eval()
|
| 194 |
|
| 195 |
# Handle device placement and PEFT
|
| 196 |
+
adapter_config_path = os.path.join(model_directory, "adapter_config.json")
|
| 197 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 198 |
if not quantize:
|
| 199 |
# Only move non-quantized models
|
| 200 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 201 |
model = model.to(device)
|
| 202 |
+
elif os.path.exists(adapter_config_path):
|
| 203 |
+
# If adapter files exist, load them into the model using PEFT's from_pretrained
|
| 204 |
+
model = PeftModel.from_pretrained(model, model_directory)
|
| 205 |
+
model = model.to(device)
|
| 206 |
+
print("loading lora weights")
|
| 207 |
elif peft_config:
|
| 208 |
# Apply PEFT for quantized models (except MTLCellClassifier and CellClassifier-QuantInf)
|
| 209 |
model.enable_input_require_grads()
|
| 210 |
model = get_peft_model(model, peft_config)
|
| 211 |
+
model = model.to(device)
|
| 212 |
|
| 213 |
return model
|
| 214 |
|