Spaces:
Runtime error
Runtime error
| """ | |
| lora_manager.py - manages LoRA adapter state lifecycle with caching | |
| """ | |
| import logging | |
| from helpers import calculate_duration | |
| logger = logging.getLogger(__name__) | |
class LoRAManager:
    """Keep track of which LoRA adapters are applied to a pipeline."""

    def __init__(self, loras_list):
        """Store the available LoRA descriptors; no adapters are active yet.

        Args:
            loras_list: Collection of LoRA descriptors, stored as-is on
                ``self.loras``.
        """
        self.loras = loras_list
        # Adapter names registered by the latest successful
        # set_active_loras() call; empty when nothing is loaded.
        self.active_adapters = []

    def set_active_loras(self, pipe, selected_loras, scales):
        """Replace the LoRA adapters on ``pipe`` with ``selected_loras``.

        Existing LoRA weights are unloaded first, then every selected LoRA
        is loaded under the adapter name ``lora_<index>`` and all of them
        are activated together with their scales.

        Args:
            pipe: Object exposing ``unload_lora_weights()``,
                ``load_lora_weights(...)`` and ``set_adapters(...)``.
            selected_loras: Non-empty sequence of mappings. Each one is read
                for ``'title'`` (logging), ``'repo'`` (weights source) and
                the optional ``'weights'`` entry (passed as ``weight_name``).
            scales: Per-LoRA adapter weights, matched to ``selected_loras``
                by position. Missing trailing entries, or ``None`` for the
                whole argument, fall back to ``1.0``.

        Raises:
            ValueError: If ``selected_loras`` is empty.
        """
        if not selected_loras:
            raise ValueError("No LoRAs selected")

        # Treat "no scales supplied" the same as a too-short list, so every
        # adapter gets the 1.0 default instead of len(None) raising.
        if scales is None:
            scales = ()

        with calculate_duration("Unloading LoRA weights"):
            pipe.unload_lora_weights()

        # The pipeline holds no adapters from this point on. Clear the record
        # now so that a failure while loading does not leave stale names.
        self.active_adapters = []

        lora_names = []
        lora_weights = []
        with calculate_duration("Loading LoRA weights"):
            for idx, lora in enumerate(selected_loras):
                adapter_name = f"lora_{idx}"
                logger.info("Loading %s as %s", lora['title'], adapter_name)
                pipe.load_lora_weights(
                    lora['repo'],
                    weight_name=lora.get("weights"),
                    low_cpu_mem_usage=True,
                    adapter_name=adapter_name,
                )
                lora_names.append(adapter_name)
                lora_weights.append(scales[idx] if idx < len(scales) else 1.0)

        pipe.set_adapters(lora_names, adapter_weights=lora_weights)
        # Record the new adapters only after the pipeline has accepted them.
        self.active_adapters = lora_names
        logger.info("Active adapters set: %s", self.active_adapters)