fix quantize logic (and lora adapter weights loading)

#522
by yuz299 — opened
Files changed (1)
  1. geneformer/perturber_utils.py (+28 −8)
geneformer/perturber_utils.py CHANGED
@@ -127,7 +127,10 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False):
127
  output_hidden_states = (mode == "eval")
128
 
129
  # Quantization logic
130
- if quantize:
 
 
 
131
  if inference_only:
132
  quantize_config = BitsAndBytesConfig(load_in_8bit=True)
133
  peft_config = None
@@ -138,13 +141,22 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False):
138
  bnb_4bit_quant_type="nf4",
139
  bnb_4bit_compute_dtype=torch.bfloat16,
140
  )
141
- peft_config = LoraConfig(
142
- lora_alpha=128,
143
- lora_dropout=0.1,
144
- r=64,
145
- bias="none",
146
- task_type="TokenClassification",
147
- )
 
 
 
 
 
 
 
 
 
148
  else:
149
  quantize_config = None
150
  peft_config = None
@@ -181,14 +193,22 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False):
181
  model.eval()
182
 
183
  # Handle device placement and PEFT
 
 
184
  if not quantize:
185
  # Only move non-quantized models
186
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
187
  model = model.to(device)
 
 
 
 
 
188
  elif peft_config:
189
  # Apply PEFT for quantized models (except MTLCellClassifier and CellClassifier-QuantInf)
190
  model.enable_input_require_grads()
191
  model = get_peft_model(model, peft_config)
 
192
 
193
  return model
194
 
 
127
  output_hidden_states = (mode == "eval")
128
 
129
  # Quantization logic
130
+ if isinstance(quantize, dict):
131
+ quantize_config = quantize.get("bnb_config", None)
132
+ peft_config = quantize.get("peft_config", None)
133
+ elif quantize:
134
  if inference_only:
135
  quantize_config = BitsAndBytesConfig(load_in_8bit=True)
136
  peft_config = None
 
141
  bnb_4bit_quant_type="nf4",
142
  bnb_4bit_compute_dtype=torch.bfloat16,
143
  )
144
+ try:
145
+ peft_config = LoraConfig(
146
+ lora_alpha=128,
147
+ lora_dropout=0.1,
148
+ r=64,
149
+ bias="none",
150
+ task_type="TokenClassification",
151
+ )
152
+ except ValueError as e:
153
+ peft_config = LoraConfig(
154
+ lora_alpha=128,
155
+ lora_dropout=0.1,
156
+ r=64,
157
+ bias="none",
158
+ task_type="TOKEN_CLS",
159
+ )
160
  else:
161
  quantize_config = None
162
  peft_config = None
 
193
  model.eval()
194
 
195
  # Handle device placement and PEFT
196
+ adapter_config_path = os.path.join(model_directory, "adapter_config.json")
197
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
198
  if not quantize:
199
  # Only move non-quantized models
200
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
201
  model = model.to(device)
202
+ elif os.path.exists(adapter_config_path):
203
+ # If adapter files exist, load them into the model using PEFT's from_pretrained
204
+ model = PeftModel.from_pretrained(model, model_directory)
205
+ model = model.to(device)
206
+ print("loading lora weights")
207
  elif peft_config:
208
  # Apply PEFT for quantized models (except MTLCellClassifier and CellClassifier-QuantInf)
209
  model.enable_input_require_grads()
210
  model = get_peft_model(model, peft_config)
211
+ model = model.to(device)
212
 
213
  return model
214