Add a fallback task_type for LoraConfig to support different PEFT versions

#537
Files changed (1) hide show
  1. geneformer/perturber_utils.py +14 -7
geneformer/perturber_utils.py CHANGED
@@ -138,13 +138,20 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False):
138
  bnb_4bit_quant_type="nf4",
139
  bnb_4bit_compute_dtype=torch.bfloat16,
140
  )
141
- peft_config = LoraConfig(
142
- lora_alpha=128,
143
- lora_dropout=0.1,
144
- r=64,
145
- bias="none",
146
- task_type="TokenClassification",
147
- )
 
 
 
 
 
 
 
148
  else:
149
  quantize_config = None
150
  peft_config = None
 
138
  bnb_4bit_quant_type="nf4",
139
  bnb_4bit_compute_dtype=torch.bfloat16,
140
  )
141
+ # Define common LoraConfig parameters
142
+ lora_config_params = {
143
+ "lora_alpha": 128,
144
+ "lora_dropout": 0.1,
145
+ "r": 64,
146
+ "bias": "none"
147
+ }
148
+
149
+ # Try "TokenClassification" first; fall back to "TOKEN_CLS" if this PEFT version rejects it
150
+ try:
151
+ peft_config = LoraConfig(**lora_config_params, task_type="TokenClassification")
152
+ except ValueError:
153
+ # Some PEFT versions expect the TaskType enum value "TOKEN_CLS" instead of "TokenClassification"
154
+ peft_config = LoraConfig(**lora_config_params, task_type="TOKEN_CLS")
155
  else:
156
  quantize_config = None
157
  peft_config = None