File size: 1,370 Bytes
b13e05b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
"""
lora_manager.py - manages LoRA adapter state lifecycle with caching
"""

import logging
from helpers import calculate_duration

logger = logging.getLogger(__name__)

class LoRAManager:
    """Manages which LoRA adapters are loaded onto a diffusers-style pipeline."""

    def __init__(self, loras_list):
        """
        Args:
            loras_list: catalogue of available LoRA descriptors
                (presumably a list of dicts — see set_active_loras).
        """
        # Full catalogue of known LoRAs; not all are active at once.
        self.loras = loras_list
        # Adapter names currently applied to the pipeline; empty until
        # set_active_loras() succeeds.
        self.active_adapters = []

    def set_active_loras(self, pipe, selected_loras, scales):
        """Replace the pipeline's LoRA adapters with ``selected_loras``.

        Unloads any previously loaded adapters, loads each selected LoRA
        under a generated name (``lora_0``, ``lora_1``, ...), and activates
        them with the given per-adapter scales.

        Args:
            pipe: pipeline exposing ``unload_lora_weights``,
                ``load_lora_weights`` and ``set_adapters``
                (diffusers-style interface — confirm against caller).
            selected_loras: LoRA descriptors; each is a dict with at least
                ``'repo'`` and ``'title'`` keys and an optional ``'weights'``
                filename.
            scales: per-LoRA weight multipliers; if shorter than
                ``selected_loras``, missing trailing entries default to 1.0.

        Raises:
            ValueError: if ``selected_loras`` is empty or None.
        """
        if not selected_loras:
            raise ValueError("No LoRAs selected")

        # Drop previously-loaded adapters so stale weights don't linger.
        with calculate_duration("Unloading LoRA weights"):
            pipe.unload_lora_weights()

        lora_names = []
        lora_weights = []
        with calculate_duration("Loading LoRA weights"):
            for idx, lora in enumerate(selected_loras):
                adapter_name = f"lora_{idx}"
                # Lazy %-args: the message is only formatted if emitted.
                logger.info("Loading %s as %s", lora['title'], adapter_name)
                pipe.load_lora_weights(
                    lora['repo'],
                    weight_name=lora.get("weights"),
                    low_cpu_mem_usage=True,
                    adapter_name=adapter_name
                )
                lora_names.append(adapter_name)
                # Default scale of 1.0 when fewer scales than LoRAs are given.
                lora_weights.append(scales[idx] if idx < len(scales) else 1.0)

            pipe.set_adapters(lora_names, adapter_weights=lora_weights)
            self.active_adapters = lora_names
            logger.info("Active adapters set: %s", self.active_adapters)