# Loras / lora_manager.py
# BoobyBoobs's picture
# Create lora_manager.py
# b13e05b verified
"""
lora_manager.py - manages LoRA adapter state lifecycle with caching
"""
import logging
from helpers import calculate_duration
logger = logging.getLogger(__name__)
class LoRAManager:
    """Track and apply the set of LoRA adapters active on a pipeline.

    Attributes:
        loras: The LoRA catalogue passed in at construction (stored only;
            not read by ``set_active_loras``).
        active_adapters: Adapter names most recently activated on the
            pipeline by ``set_active_loras``; empty when none are known to
            be active.
    """

    def __init__(self, loras_list):
        """Store the LoRA catalogue and start with no active adapters.

        Args:
            loras_list: Collection of LoRA entries, kept as ``self.loras``.
        """
        self.loras = loras_list
        self.active_adapters = []

    def set_active_loras(self, pipe, selected_loras, scales):
        """Replace all LoRA weights on ``pipe`` with ``selected_loras``.

        Every previously loaded LoRA is unloaded first. Each selected entry
        is then loaded under the adapter name ``lora_<index>``, and all of
        them are activated together with their scales.

        Args:
            pipe: Object exposing ``unload_lora_weights``,
                ``load_lora_weights`` and ``set_adapters`` (presumably a
                diffusers pipeline — confirm against callers).
            selected_loras: Sequence of dict-like entries. ``"repo"`` is
                required, ``"weights"`` (weight file name) is optional, and
                ``"title"`` is optional and only used for logging.
            scales: Per-LoRA adapter weights, matched by position. Entries
                beyond the end of ``scales`` (or all entries, if ``scales``
                is ``None``) default to ``1.0``.

        Raises:
            ValueError: If ``selected_loras`` is empty.
        """
        if not selected_loras:
            raise ValueError("No LoRAs selected")
        if scales is None:
            scales = ()

        with calculate_duration("Unloading LoRA weights"):
            pipe.unload_lora_weights()
        # The pipeline now has no LoRAs loaded. Clear the record so that a
        # failure while loading below does not leave stale adapter names
        # behind.
        self.active_adapters = []

        lora_names = []
        lora_weights = []
        with calculate_duration("Loading LoRA weights"):
            for idx, lora in enumerate(selected_loras):
                adapter_name = f"lora_{idx}"
                # "title" is only used for logging; fall back to the repo id
                # so that a missing title cannot abort the load.
                logger.info(
                    "Loading %s as %s",
                    lora.get("title", lora["repo"]),
                    adapter_name,
                )
                pipe.load_lora_weights(
                    lora["repo"],
                    weight_name=lora.get("weights"),
                    low_cpu_mem_usage=True,
                    adapter_name=adapter_name,
                )
                lora_names.append(adapter_name)
                # Scales are positional; a LoRA without a matching entry gets 1.0.
                lora_weights.append(scales[idx] if idx < len(scales) else 1.0)

        pipe.set_adapters(lora_names, adapter_weights=lora_weights)
        # Only record the new adapters once they have actually been activated.
        self.active_adapters = lora_names
        logger.info("Active adapters set: %s", self.active_adapters)