Spaces:
Sleeping
Sleeping
lowvoltagenation
committed on
Commit
·
2043924
1
Parent(s):
ed31be4
Add support for LoRA model loading in ModelInterface
Browse files
- Updated requirements.txt to include 'peft' library.
- Enhanced ModelInterface to load LoRA adapters with base models, including error handling and tokenizer setup.
- Integrated logging for model loading processes to improve feedback during operations.
requirements.txt
CHANGED
|
@@ -12,6 +12,7 @@ langchain-community>=0.0.10
|
|
| 12 |
# HuggingFace Integration
|
| 13 |
huggingface_hub>=0.18.0
|
| 14 |
datasets>=2.14.0
|
|
|
|
| 15 |
|
| 16 |
# Model Providers (Optional)
|
| 17 |
anthropic>=0.5.0
|
|
|
|
| 12 |
# HuggingFace Integration
|
| 13 |
huggingface_hub>=0.18.0
|
| 14 |
datasets>=2.14.0
|
| 15 |
+
peft>=0.6.0
|
| 16 |
|
| 17 |
# Model Providers (Optional)
|
| 18 |
anthropic>=0.5.0
|
src/__pycache__/model_interface.cpython-313.pyc
CHANGED
|
Binary files a/src/__pycache__/model_interface.cpython-313.pyc and b/src/__pycache__/model_interface.cpython-313.pyc differ
|
|
|
src/model_interface.py
CHANGED
|
@@ -12,6 +12,7 @@ from transformers import (
|
|
| 12 |
pipeline,
|
| 13 |
BitsAndBytesConfig
|
| 14 |
)
|
|
|
|
| 15 |
import torch
|
| 16 |
from huggingface_hub import HfApi
|
| 17 |
import json
|
|
@@ -173,6 +174,60 @@ class ModelInterface:
|
|
| 173 |
"type": "local"
|
| 174 |
}
|
| 175 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
else:
|
| 177 |
logger.error(f"Unknown model type: {model_type}")
|
| 178 |
return False
|
|
|
|
| 12 |
pipeline,
|
| 13 |
BitsAndBytesConfig
|
| 14 |
)
|
| 15 |
+
from peft import PeftModel
|
| 16 |
import torch
|
| 17 |
from huggingface_hub import HfApi
|
| 18 |
import json
|
|
|
|
| 174 |
"type": "local"
|
| 175 |
}
|
| 176 |
|
| 177 |
+
elif model_type == "lora":
|
| 178 |
+
# Load LoRA adapter with base model
|
| 179 |
+
logger.info(f"Loading LoRA model {model_id}...")
|
| 180 |
+
|
| 181 |
+
base_model_id = model_config.get("base_model")
|
| 182 |
+
if not base_model_id:
|
| 183 |
+
logger.error(f"No base model specified for LoRA {model_id}")
|
| 184 |
+
return False
|
| 185 |
+
|
| 186 |
+
# Use auth token if available
|
| 187 |
+
auth_token = os.getenv("HUGGINGFACE_API_TOKEN") if use_auth_token else None
|
| 188 |
+
|
| 189 |
+
# Load base model first
|
| 190 |
+
logger.info(f"Loading base model {base_model_id}...")
|
| 191 |
+
base_model = AutoModelForCausalLM.from_pretrained(
|
| 192 |
+
base_model_id,
|
| 193 |
+
token=auth_token,
|
| 194 |
+
torch_dtype=torch.float16,
|
| 195 |
+
device_map="auto" if torch.cuda.is_available() else None,
|
| 196 |
+
low_cpu_mem_usage=True
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
# Load LoRA adapter
|
| 200 |
+
logger.info(f"Loading LoRA adapter {model_id}...")
|
| 201 |
+
model = PeftModel.from_pretrained(base_model, model_id, token=auth_token)
|
| 202 |
+
|
| 203 |
+
# Load tokenizer (from base model)
|
| 204 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 205 |
+
base_model_id,
|
| 206 |
+
token=auth_token,
|
| 207 |
+
padding_side="left"
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
# Add pad token if missing
|
| 211 |
+
if tokenizer.pad_token is None:
|
| 212 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 213 |
+
|
| 214 |
+
# Create pipeline
|
| 215 |
+
pipe = pipeline(
|
| 216 |
+
"text-generation",
|
| 217 |
+
model=model,
|
| 218 |
+
tokenizer=tokenizer,
|
| 219 |
+
device=0 if torch.cuda.is_available() else -1,
|
| 220 |
+
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
self.models[model_id] = {
|
| 224 |
+
"pipeline": pipe,
|
| 225 |
+
"tokenizer": tokenizer,
|
| 226 |
+
"model": model,
|
| 227 |
+
"type": "lora",
|
| 228 |
+
"base_model": base_model_id
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
else:
|
| 232 |
logger.error(f"Unknown model type: {model_type}")
|
| 233 |
return False
|