Spaces:
Runtime error
Runtime error
Create lora_manager.py
Browse files- lora_manager.py +39 -0
lora_manager.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
lora_manager.py - manages LoRA adapter state lifecycle with caching
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
from helpers import calculate_duration
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
class LoRAManager:
|
| 11 |
+
def __init__(self, loras_list):
|
| 12 |
+
self.loras = loras_list
|
| 13 |
+
self.active_adapters = []
|
| 14 |
+
|
| 15 |
+
def set_active_loras(self, pipe, selected_loras, scales):
|
| 16 |
+
if not selected_loras:
|
| 17 |
+
raise ValueError("No LoRAs selected")
|
| 18 |
+
|
| 19 |
+
with calculate_duration("Unloading LoRA weights"):
|
| 20 |
+
pipe.unload_lora_weights()
|
| 21 |
+
|
| 22 |
+
lora_names = []
|
| 23 |
+
lora_weights = []
|
| 24 |
+
with calculate_duration("Loading LoRA weights"):
|
| 25 |
+
for idx, lora in enumerate(selected_loras):
|
| 26 |
+
adapter_name = f"lora_{idx}"
|
| 27 |
+
logger.info(f"Loading {lora['title']} as {adapter_name}")
|
| 28 |
+
pipe.load_lora_weights(
|
| 29 |
+
lora['repo'],
|
| 30 |
+
weight_name=lora.get("weights"),
|
| 31 |
+
low_cpu_mem_usage=True,
|
| 32 |
+
adapter_name=adapter_name
|
| 33 |
+
)
|
| 34 |
+
lora_names.append(adapter_name)
|
| 35 |
+
lora_weights.append(scales[idx] if idx < len(scales) else 1.0)
|
| 36 |
+
|
| 37 |
+
pipe.set_adapters(lora_names, adapter_weights=lora_weights)
|
| 38 |
+
self.active_adapters = lora_names
|
| 39 |
+
logger.info(f"Active adapters set: {self.active_adapters}")
|