Simo76 commited on
Commit
28c5d43
·
1 Parent(s): 7c2c644

Add Unified LoRA Controller implementation

Browse files

This file implements a Unified LoRA Controller for adaptive parameter-efficient fine-tuning, including methods for updating learning rates based on training loss and maintaining state history.

Files changed (1) hide show
  1. controller.py +211 -0
controller.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Unified LoRA Controller
3
+ ========================
4
+
5
+ Adaptive parameter-efficient fine-tuning controller with automatic
6
+ Single/Multi/Mirror mode switching based on synaptic stress signals.
7
+
8
+ Author: Simona Vargiu
9
+ License: Apache 2.0
10
+ """
11
+
12
+ import torch
13
+ from typing import Dict, Optional, Tuple
14
+
15
+
16
+ class UnifiedController:
17
+ """
18
+ Unified LoRA adaptive controller.
19
+
20
+ Monitors training stress via synaptic signal φ(t) and automatically
21
+ switches between three operational modes:
22
+ - Mode 0 (Single): Shared adapter for low conflict
23
+ - Mode 1 (Multi): Task-specific adapters for moderate stress
24
+ - Mode 2 (Mirror): Stability snapshots for catastrophic forgetting
25
+
26
+ Args:
27
+ alpha (float): Learning rate for φ(t) updates (default: 0.1)
28
+ beta (float): EMA smoothing factor for loss (default: 0.9)
29
+ theta0 (float): Single/Multi threshold (default: 0.3)
30
+ theta1 (float): Multi/Mirror threshold (default: 0.7)
31
+ lr_single (float): Learning rate for Single mode (default: 5e-5)
32
+ lr_multi (float): Learning rate for Multi mode (default: 3e-5)
33
+ lr_mirror (float): Learning rate for Mirror mode (default: 1e-5)
34
+
35
+ Example:
36
+ >>> controller = UnifiedController()
37
+ >>> for step, batch in enumerate(train_loader):
38
+ ... outputs = model(**batch)
39
+ ... new_lr = controller.update(outputs.loss.item())
40
+ ... # Apply new_lr to optimizer
41
+ """
42
+
43
+ def __init__(
44
+ self,
45
+ alpha: float = 0.1,
46
+ beta: float = 0.9,
47
+ theta0: float = 0.3,
48
+ theta1: float = 0.7,
49
+ lr_single: float = 5e-5,
50
+ lr_multi: float = 3e-5,
51
+ lr_mirror: float = 1e-5,
52
+ ):
53
+ self.alpha = alpha
54
+ self.beta = beta
55
+ self.theta0 = theta0
56
+ self.theta1 = theta1
57
+
58
+ # Learning rates per mode
59
+ self.lr_map = {
60
+ 0: lr_single,
61
+ 1: lr_multi,
62
+ 2: lr_mirror,
63
+ }
64
+
65
+ # State variables
66
+ self.phi = 0.5 # Synaptic stress signal
67
+ self.E_smooth = 1.0 # Smoothed loss
68
+ self.mode = 1 # Current mode (start with Multi)
69
+ self.step = 0
70
+
71
+ # History tracking
72
+ self.history = {
73
+ "phi": [],
74
+ "E_smooth": [],
75
+ "mode": [],
76
+ "step": [],
77
+ }
78
+
79
+ def update(self, loss: float) -> float:
80
+ """
81
+ Update controller state and return new learning rate.
82
+
83
+ Args:
84
+ loss (float): Current training loss
85
+
86
+ Returns:
87
+ float: New learning rate based on current mode
88
+ """
89
+ self.step += 1
90
+
91
+ # Update smoothed loss (EMA)
92
+ E = float(loss)
93
+ self.E_smooth = self.beta * self.E_smooth + (1 - self.beta) * E
94
+
95
+ # Compute normalized stress signal
96
+ D = self.E_smooth / (1 + self.E_smooth) # Normalize to [0,1]
97
+
98
+ # Update synaptic signal φ(t) with EMA
99
+ self.phi = (1 - self.alpha) * self.phi + self.alpha * D
100
+
101
+ # FSM: Determine mode based on φ(t)
102
+ if self.phi < self.theta0:
103
+ self.mode = 0 # Single
104
+ elif self.phi < self.theta1:
105
+ self.mode = 1 # Multi
106
+ else:
107
+ self.mode = 2 # Mirror
108
+
109
+ # Log history
110
+ self.history["phi"].append(self.phi)
111
+ self.history["E_smooth"].append(self.E_smooth)
112
+ self.history["mode"].append(self.mode)
113
+ self.history["step"].append(self.step)
114
+
115
+ # Return learning rate for current mode
116
+ return self.lr_map[self.mode]
117
+
118
+ def get_state(self) -> Dict[str, float]:
119
+ """
120
+ Get current controller state.
121
+
122
+ Returns:
123
+ dict: Current values of phi, E_smooth, mode, step
124
+ """
125
+ return {
126
+ "phi": self.phi,
127
+ "E_smooth": self.E_smooth,
128
+ "mode": self.mode,
129
+ "step": self.step,
130
+ }
131
+
132
+ def get_history(self) -> Dict[str, list]:
133
+ """
134
+ Get complete training history.
135
+
136
+ Returns:
137
+ dict: History of phi, E_smooth, mode, step
138
+ """
139
+ return self.history
140
+
141
+ def reset(self):
142
+ """Reset controller to initial state."""
143
+ self.phi = 0.5
144
+ self.E_smooth = 1.0
145
+ self.mode = 1
146
+ self.step = 0
147
+ self.history = {
148
+ "phi": [],
149
+ "E_smooth": [],
150
+ "mode": [],
151
+ "step": [],
152
+ }
153
+
154
+ @staticmethod
155
+ def mode_name(mode: int) -> str:
156
+ """
157
+ Get human-readable mode name.
158
+
159
+ Args:
160
+ mode (int): Mode number (0, 1, or 2)
161
+
162
+ Returns:
163
+ str: Mode name
164
+ """
165
+ names = {0: "Single", 1: "Multi", 2: "Mirror"}
166
+ return names.get(mode, "Unknown")
167
+
168
+ def __repr__(self) -> str:
169
+ """String representation of controller state."""
170
+ return (
171
+ f"UnifiedController(step={self.step}, phi={self.phi:.3f}, "
172
+ f"mode={self.mode} ({self.mode_name(self.mode)}), "
173
+ f"E_smooth={self.E_smooth:.3f})"
174
+ )
175
+
176
+
177
+ # Example usage
178
+ if __name__ == "__main__":
179
+ import numpy as np
180
+
181
+ print("Unified LoRA Controller - Example")
182
+ print("=" * 50)
183
+
184
+ controller = UnifiedController()
185
+
186
+ # Simulate training with stress events
187
+ print("\nSimulating training with SHOCK at step 150...")
188
+ print()
189
+
190
+ for step in range(300):
191
+ # Simulate loss
192
+ if step < 150:
193
+ loss = np.random.uniform(0.4, 0.6) # Normal training
194
+ else:
195
+ loss = np.random.uniform(2.0, 4.0) # SHOCK
196
+
197
+ # Update controller
198
+ new_lr = controller.update(loss)
199
+
200
+ # Log every 50 steps
201
+ if step % 50 == 0:
202
+ state = controller.get_state()
203
+ print(
204
+ f"[{step:3d}] phi={state['phi']:.3f} | "
205
+ f"mode={state['mode']} ({controller.mode_name(state['mode'])}) | "
206
+ f"lr={new_lr:.1e}"
207
+ )
208
+
209
+ print("\n" + "=" * 50)
210
+ print("Simulation complete!")
211
+ print(f"\nFinal state: {controller}")