Simo76 committed on
Commit
a073377
·
1 Parent(s): 9ce053d

Update unified_lora.py

Browse files
Files changed (1) hide show
  1. unified_lora.py +0 -173
unified_lora.py CHANGED
@@ -16,177 +16,4 @@ This file is kept for reference only.
16
  Status: deprecated / legacy
17
  """
18
 
19
- Unified-LoRA Controller
20
- ========================
21
- Adaptive per-layer rank controller for LoRA fine-tuning.
22
- Drop-in module — works with any model that uses LoRA adapters.
23
 
24
- Usage:
25
- from unified_lora import LoRALinear, get_lora_modules
26
-
27
- # Replace linear layers with adaptive LoRA
28
- layer.q_proj = LoRALinear(layer.q_proj, max_r=16)
29
-
30
- # In training loop, after loss.backward():
31
- for m in get_lora_modules(model):
32
- m.update_rank()
33
- """
34
-
35
- import copy
36
- import torch
37
- import torch.nn as nn
38
-
39
-
40
class LoRALinear(nn.Module):
    """
    LoRA adapter with per-layer adaptive rank.

    The wrapped layer is deep-copied and frozen; the low-rank factors ``A``
    (``max_r x in_features``) and ``B`` (``out_features x max_r``) are always
    allocated at full size, and only their first ``active_r`` rows/columns
    take part in the forward pass.

    The rank adjusts based on gradient stress (an EMA of the gradient norm
    of the active rows of ``A``):
    - Gradient stress increasing → rank goes up (more capacity)
    - Gradient stress decreasing → rank goes down (less capacity)

    Note that ``active_r`` and the EMA state are plain Python attributes,
    so they are not part of ``state_dict()``.

    Parameters
    ----------
    base : nn.Linear
        The original linear layer to wrap.
    max_r : int
        Maximum rank (default 16).
    min_r : int
        Minimum rank (default 4). Also the initial active rank.
    alpha : float
        Scaling factor for LoRA output. Uses alpha/active_r scaling.
    layer_name : str
        Optional name for logging.

    Raises
    ------
    ValueError
        If ``min_r < 1`` or ``max_r < min_r``.
    """

    def __init__(self, base, max_r=16, min_r=4, alpha=16.0, layer_name=""):
        super().__init__()

        # A rank of 0 would divide by zero in forward(); min_r > max_r would
        # make the scale inconsistent with the rows actually used.
        if min_r < 1 or max_r < min_r:
            raise ValueError(
                f"need 1 <= min_r <= max_r, got min_r={min_r}, max_r={max_r}"
            )

        # Work on a private frozen copy so the caller's layer is untouched.
        self.base = copy.deepcopy(base)
        for p in self.base.parameters():
            p.requires_grad = False

        self.max_r = max_r
        self.min_r = min_r
        self.alpha = alpha
        self.layer_name = layer_name

        # Allocate the adapter on the same device/dtype as the wrapped
        # weight so forward() does not hit a device or dtype mismatch.
        # Non-floating (e.g. quantized) weights fall back to the default
        # floating dtype because randn cannot produce integer tensors.
        weight = getattr(base, "weight", None)
        if weight is not None:
            device = weight.device
            if weight.is_floating_point():
                dtype = weight.dtype
            else:
                dtype = torch.get_default_dtype()
        else:
            device = None
            dtype = torch.get_default_dtype()

        # A gets small random values, B starts at zero, so the adapter
        # contributes nothing to the output until B is trained.
        self.A = nn.Parameter(
            torch.randn(max_r, base.in_features, device=device, dtype=dtype)
            * 0.01
        )
        self.B = nn.Parameter(
            torch.zeros(base.out_features, max_r, device=device, dtype=dtype)
        )
        self.active_r = min_r

        # Stress tracking
        self.grad_ema = None
        self.prev_grad_ema = None

    def set_rank(self, r):
        """Set the active rank, clamped to ``[min_r, max_r]``."""
        self.active_r = max(self.min_r, min(r, self.max_r))

    def update_rank(self):
        """Call after loss.backward(), before optimizer.step().

        Updates the gradient-norm EMA and moves ``active_r`` by 2 (within
        ``[min_r, max_r]``) when the EMA changed by more than 1% of its
        current value. Does nothing if ``A`` has no gradient; the first
        call with a gradient only initialises the EMA.
        """
        if self.A.grad is None:
            return

        # Only the rows that took part in the forward pass are measured.
        grad_norm = self.A.grad[:self.active_r].norm().item()

        if self.grad_ema is None:
            self.grad_ema = grad_norm
            self.prev_grad_ema = grad_norm
            return

        self.prev_grad_ema = self.grad_ema
        self.grad_ema = 0.9 * self.grad_ema + 0.1 * grad_norm

        delta = self.grad_ema - self.prev_grad_ema
        # Relative dead-band; absolute fallback when the EMA is zero.
        threshold = 0.01 * self.grad_ema if self.grad_ema > 0 else 0.01

        if delta > threshold:
            self.active_r = min(self.max_r, self.active_r + 2)
        elif delta < -threshold:
            self.active_r = max(self.min_r, self.active_r - 2)

    def forward(self, x):
        """Return ``base(x) + (alpha / active_r) * x @ A_r^T @ B_r^T``."""
        base_out = self.base(x)
        A = self.A[:self.active_r]
        B = self.B[:, :self.active_r]
        lora_out = x @ A.t() @ B.t()
        scale = self.alpha / self.active_r
        return base_out + scale * lora_out

    def extra_repr(self):
        """Summarise the layer configuration for ``repr()``."""
        return (f"in={self.base.in_features}, out={self.base.out_features}, "
                f"max_r={self.max_r}, min_r={self.min_r}, alpha={self.alpha}, "
                f"active_r={self.active_r}, name={self.layer_name}")
119
-
120
-
121
def get_lora_modules(model):
    """Collect every LoRALinear found while walking ``model.modules()``."""
    found = []
    for module in model.modules():
        if isinstance(module, LoRALinear):
            found.append(module)
    return found
124
-
125
-
126
def inject_lora(model, target_modules, max_r=16, min_r=4, alpha=16.0):
    """
    Replace target linear layers with LoRALinear adapters.

    A layer is replaced when it is an ``nn.Linear`` and its dotted module
    name either equals a target or ends with ``"." + target``. Matching on
    a component boundary keeps a target such as ``"proj"`` from also
    catching ``q_proj`` or ``out_proj``.

    Parameters
    ----------
    model : nn.Module
        The model to modify (in place).
    target_modules : str or list of str
        Names of linear layers to replace (e.g. ["q_proj", "v_proj"]).
        A single string is treated as a one-element list.
    max_r, min_r, alpha : passed to LoRALinear.

    Returns
    -------
    model : nn.Module
        Modified model with LoRA adapters.

    Example
    -------
    # DistilBERT
    inject_lora(model, ["q_lin", "v_lin"])

    # Llama / Mistral
    inject_lora(model, ["q_proj", "v_proj"])

    # All attention projections
    inject_lora(model, ["q_proj", "k_proj", "v_proj", "o_proj"])
    """
    # Iterating a bare string would yield single characters as targets.
    if isinstance(target_modules, str):
        target_modules = [target_modules]
    # str.endswith accepts a tuple, so build the dotted suffixes once.
    exact_names = set(target_modules)
    dotted_suffixes = tuple("." + t for t in target_modules)

    # Collect names first: replacing modules while named_modules() is
    # being iterated would change the tree under the iterator.
    replace_list = [
        name
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear)
        and (name in exact_names or name.endswith(dotted_suffixes))
    ]

    for name in replace_list:
        parts = name.split(".")
        # Walk down to the module that owns the attribute being replaced.
        parent = model
        for p in parts[:-1]:
            parent = getattr(parent, p)
        original = getattr(parent, parts[-1])
        setattr(parent, parts[-1], LoRALinear(
            original, max_r=max_r, min_r=min_r, alpha=alpha, layer_name=name
        ))

    print(f"Injected LoRA into {len(replace_list)} layers: {replace_list}")
    return model
172
-
173
-
174
def setup_trainable(model):
    """Freeze base model, unfreeze LoRA params and classifier.

    Every parameter is frozen first. Then the ``A``/``B`` factors of each
    LoRALinear are made trainable, along with any parameter whose name
    contains one of the common head names listed below. A summary of the
    trainable parameter count is printed.

    Parameters
    ----------
    model : nn.Module
        The model to configure (in place).

    Returns
    -------
    model : nn.Module
        The same model, for chaining.
    """
    for p in model.parameters():
        p.requires_grad = False

    for m in get_lora_modules(model):
        m.A.requires_grad = True
        m.B.requires_grad = True

    # Unfreeze common classifier head names (substring match on the
    # parameter's dotted name).
    head_keys = ["classifier", "pre_classifier", "score", "lm_head"]
    for n, p in model.named_parameters():
        if any(k in n for k in head_keys):
            p.requires_grad = True

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    # A parameter-free model would otherwise divide by zero here.
    percent = 100 * trainable / total if total else 0.0
    print(f"Trainable: {trainable:,} / {total:,} ({percent:.2f}%)")

    return model
 
16
  Status: deprecated / legacy
17
  """
18
 
 
 
 
 
19