BoobyBoobs commited on
Commit
b13e05b
·
verified ·
1 Parent(s): 955066e

Create lora_manager.py

Browse files
Files changed (1) hide show
  1. lora_manager.py +39 -0
lora_manager.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ lora_manager.py - manages LoRA adapter state lifecycle with caching
3
+ """
4
+
5
+ import logging
6
+ from helpers import calculate_duration
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ class LoRAManager:
11
+ def __init__(self, loras_list):
12
+ self.loras = loras_list
13
+ self.active_adapters = []
14
+
15
+ def set_active_loras(self, pipe, selected_loras, scales):
16
+ if not selected_loras:
17
+ raise ValueError("No LoRAs selected")
18
+
19
+ with calculate_duration("Unloading LoRA weights"):
20
+ pipe.unload_lora_weights()
21
+
22
+ lora_names = []
23
+ lora_weights = []
24
+ with calculate_duration("Loading LoRA weights"):
25
+ for idx, lora in enumerate(selected_loras):
26
+ adapter_name = f"lora_{idx}"
27
+ logger.info(f"Loading {lora['title']} as {adapter_name}")
28
+ pipe.load_lora_weights(
29
+ lora['repo'],
30
+ weight_name=lora.get("weights"),
31
+ low_cpu_mem_usage=True,
32
+ adapter_name=adapter_name
33
+ )
34
+ lora_names.append(adapter_name)
35
+ lora_weights.append(scales[idx] if idx < len(scales) else 1.0)
36
+
37
+ pipe.set_adapters(lora_names, adapter_weights=lora_weights)
38
+ self.active_adapters = lora_names
39
+ logger.info(f"Active adapters set: {self.active_adapters}")