lowvoltagenation committed
Commit 2043924 · 1 Parent(s): ed31be4

Add support for LoRA model loading in ModelInterface


- Updated requirements.txt to add the 'peft' library.
- Enhanced ModelInterface to load LoRA adapters on top of their base models, including error handling and tokenizer setup (see the usage sketch below).
- Integrated logging into the model-loading path to improve feedback during operations.
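
For context, a hedged sketch of how the new "lora" path might be invoked. The method name load_model and the exact config shape are assumptions inferred from the diff below, not a documented API:

from src.model_interface import ModelInterface  # import path per this repo's layout

interface = ModelInterface()
ok = interface.load_model(                       # hypothetical method name
    model_id="your-org/your-lora-adapter",       # placeholder LoRA adapter repo
    model_type="lora",
    model_config={"base_model": "base-org/base-model"},  # required for LoRA loading
    use_auth_token=True,  # reads HUGGINGFACE_API_TOKEN from the environment
)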

requirements.txt CHANGED
@@ -12,6 +12,7 @@ langchain-community>=0.0.10
 # HuggingFace Integration
 huggingface_hub>=0.18.0
 datasets>=2.14.0
+peft>=0.6.0
 
 # Model Providers (Optional)
 anthropic>=0.5.0
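
A quick way to confirm the new dependency resolves after installing from the updated requirements (peft exposes its version string):

import peft
print(peft.__version__)  # expect >= 0.6.0 per the pin above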
src/__pycache__/model_interface.cpython-313.pyc CHANGED
Binary files a/src/__pycache__/model_interface.cpython-313.pyc and b/src/__pycache__/model_interface.cpython-313.pyc differ
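
As an aside, compiled bytecode under __pycache__/ is usually kept out of version control; a typical .gitignore entry would be (a suggestion, not part of this commit):

__pycache__/
*.pyc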
 
src/model_interface.py CHANGED
@@ -12,6 +12,7 @@ from transformers import (
     pipeline,
     BitsAndBytesConfig
 )
+from peft import PeftModel
 import torch
 from huggingface_hub import HfApi
 import json
@@ -173,6 +174,60 @@ class ModelInterface:
                 "type": "local"
             }
 
+        elif model_type == "lora":
+            # Load LoRA adapter with base model
+            logger.info(f"Loading LoRA model {model_id}...")
+
+            base_model_id = model_config.get("base_model")
+            if not base_model_id:
+                logger.error(f"No base model specified for LoRA {model_id}")
+                return False
+
+            # Use auth token if available
+            auth_token = os.getenv("HUGGINGFACE_API_TOKEN") if use_auth_token else None
+
+            # Load base model first
+            logger.info(f"Loading base model {base_model_id}...")
+            base_model = AutoModelForCausalLM.from_pretrained(
+                base_model_id,
+                token=auth_token,
+                torch_dtype=torch.float16,
+                device_map="auto" if torch.cuda.is_available() else None,
+                low_cpu_mem_usage=True
+            )
+
+            # Load LoRA adapter
+            logger.info(f"Loading LoRA adapter {model_id}...")
+            model = PeftModel.from_pretrained(base_model, model_id, token=auth_token)
+
+            # Load tokenizer (from base model)
+            tokenizer = AutoTokenizer.from_pretrained(
+                base_model_id,
+                token=auth_token,
+                padding_side="left"
+            )
+
+            # Add pad token if missing
+            if tokenizer.pad_token is None:
+                tokenizer.pad_token = tokenizer.eos_token
+
+            # Create pipeline
+            pipe = pipeline(
+                "text-generation",
+                model=model,
+                tokenizer=tokenizer,
+                device=0 if torch.cuda.is_available() else -1,
+                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
+            )
+
+            self.models[model_id] = {
+                "pipeline": pipe,
+                "tokenizer": tokenizer,
+                "model": model,
+                "type": "lora",
+                "base_model": base_model_id
+            }
+
         else:
             logger.error(f"Unknown model type: {model_type}")
             return False
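
One caveat when reusing this pattern: as committed, both device_map="auto" (in from_pretrained) and device=0 (in pipeline) are set when CUDA is available, and recent transformers releases may reject that combination with a ValueError, since an accelerate-dispatched model cannot be moved to a specific device afterwards. Below is a minimal standalone sketch of the same LoRA loading path that sidesteps the conflict; the repo IDs are placeholders, not models referenced by this commit:

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

BASE_MODEL = "base-org/base-model"        # placeholder base model repo
ADAPTER = "your-org/your-lora-adapter"    # placeholder LoRA adapter repo

use_cuda = torch.cuda.is_available()

# Let accelerate place the weights on GPU when available; keep float32 on CPU,
# where float16 matmuls are slow or unsupported.
base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float16 if use_cuda else torch.float32,
    device_map="auto" if use_cuda else None,
    low_cpu_mem_usage=True,
)
model = PeftModel.from_pretrained(base, ADAPTER)

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, padding_side="left")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # causal LMs often ship without one

# Key difference from the commit: pass device= only when no device_map was used,
# leaving accelerate-dispatched models where they already are.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    **({} if use_cuda else {"device": -1}),
)
print(pipe("Hello,", max_new_tokens=8)[0]["generated_text"])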