Simo76 commited on
Commit
1c50427
·
1 Parent(s): 2094929

Delete validation_complete.py

Browse files
Files changed (1) hide show
  1. validation_complete.py +0 -441
validation_complete.py DELETED
@@ -1,441 +0,0 @@
1
- """
2
- Unified-LoRA — Complete Validation
3
- ===================================
4
- Test 1: Multi-seed (3 seeds × 3 tasks × 3 methods)
5
- Test 2: Ablation (r=8 vs r=16 vs Unified) — same runs
6
- Test 3: Rank-over-time tracking + adapter size measurement
7
-
8
- Runs on Colab T4 in ~15-20 minutes.
9
- """
10
-
11
- !pip install -q transformers datasets evaluate accelerate scikit-learn
12
-
13
- import copy, torch, time, gc, json
14
- import torch.nn as nn
15
- import numpy as np
16
- from datasets import load_dataset
17
- from transformers import (
18
- AutoTokenizer,
19
- AutoModelForSequenceClassification,
20
- DataCollatorWithPadding,
21
- )
22
- from torch.utils.data import DataLoader
23
- import evaluate
24
-
25
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
26
- MODEL_NAME = "distilbert-base-uncased"
27
-
28
- BATCH_SIZE = 16
29
- EPOCHS = 3
30
- LR = 5e-4
31
- MAX_RANK = 16
32
- MIN_RANK = 4
33
- ALPHA = 16
34
- GRAD_CLIP = 1.0
35
-
36
- SEEDS = [0, 1, 2]
37
-
38
- TASKS = {
39
- "mrpc": {"num_labels": 2, "metric_key": "f1",
40
- "paired": True, "keys": ("sentence1", "sentence2")},
41
- "cola": {"num_labels": 2, "metric_key": "matthews_correlation",
42
- "paired": False, "keys": ("sentence",)},
43
- "rte": {"num_labels": 2, "metric_key": "accuracy",
44
- "paired": True, "keys": ("sentence1", "sentence2")},
45
- }
46
-
47
- # ================================================================
48
- # SEED CONTROL
49
- # ================================================================
50
- def set_seed(seed):
51
- torch.manual_seed(seed)
52
- torch.cuda.manual_seed_all(seed)
53
- np.random.seed(seed)
54
- torch.backends.cudnn.deterministic = True
55
- torch.backends.cudnn.benchmark = False
56
-
57
- # ================================================================
58
- # DATA
59
- # ================================================================
60
- def load_task(task_name):
61
- cfg = TASKS[task_name]
62
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
63
- ds = load_dataset("glue", task_name)
64
-
65
- if cfg["paired"]:
66
- def preprocess(x):
67
- return tokenizer(x[cfg["keys"][0]], x[cfg["keys"][1]], truncation=True)
68
- else:
69
- def preprocess(x):
70
- return tokenizer(x[cfg["keys"][0]], truncation=True)
71
-
72
- ds = ds.map(preprocess, batched=True)
73
- ds = ds.rename_column("label", "labels")
74
- ds.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
75
-
76
- collator = DataCollatorWithPadding(tokenizer)
77
- train_loader = DataLoader(
78
- ds["train"], batch_size=BATCH_SIZE, shuffle=True,
79
- collate_fn=collator, generator=torch.Generator().manual_seed(0)
80
- )
81
- val_loader = DataLoader(
82
- ds["validation"], batch_size=32, collate_fn=collator
83
- )
84
- metric = evaluate.load("glue", task_name)
85
-
86
- return train_loader, val_loader, metric, cfg
87
-
88
- # ================================================================
89
- # LoRA MODULE
90
- # ================================================================
91
- class LoRALinear(nn.Module):
92
- def __init__(self, base, max_r=16, layer_name=""):
93
- super().__init__()
94
- self.base = copy.deepcopy(base)
95
- for p in self.base.parameters():
96
- p.requires_grad = False
97
-
98
- self.max_r = max_r
99
- self.layer_name = layer_name
100
- self.A = nn.Parameter(torch.randn(max_r, base.in_features) * 0.01)
101
- self.B = nn.Parameter(torch.zeros(base.out_features, max_r))
102
- self.active_r = MIN_RANK
103
-
104
- self.grad_ema = None
105
- self.prev_grad_ema = None
106
-
107
- def set_rank(self, r):
108
- self.active_r = max(MIN_RANK, min(r, self.max_r))
109
-
110
- def update_rank(self):
111
- if self.A.grad is None:
112
- return
113
-
114
- grad_norm = self.A.grad[:self.active_r].norm().item()
115
-
116
- if self.grad_ema is None:
117
- self.grad_ema = grad_norm
118
- self.prev_grad_ema = grad_norm
119
- return
120
-
121
- self.prev_grad_ema = self.grad_ema
122
- self.grad_ema = 0.9 * self.grad_ema + 0.1 * grad_norm
123
-
124
- delta = self.grad_ema - self.prev_grad_ema
125
- threshold = 0.01 * self.grad_ema if self.grad_ema > 0 else 0.01
126
-
127
- if delta > threshold:
128
- self.active_r = min(self.max_r, self.active_r + 2)
129
- elif delta < -threshold:
130
- self.active_r = max(MIN_RANK, self.active_r - 2)
131
-
132
- def forward(self, x):
133
- base_out = self.base(x)
134
- A = self.A[:self.active_r]
135
- B = self.B[:, :self.active_r]
136
- lora_out = x @ A.t() @ B.t()
137
- scale = ALPHA / self.active_r
138
- return base_out + scale * lora_out
139
-
140
- # ================================================================
141
- # HELPERS
142
- # ================================================================
143
- def inject_lora(model):
144
- for i, layer in enumerate(model.distilbert.transformer.layer):
145
- layer.attention.q_lin = LoRALinear(
146
- layer.attention.q_lin, MAX_RANK, layer_name=f"layer{i}.q"
147
- )
148
- layer.attention.v_lin = LoRALinear(
149
- layer.attention.v_lin, MAX_RANK, layer_name=f"layer{i}.v"
150
- )
151
- return model
152
-
153
- def get_lora_modules(model):
154
- return [m for m in model.modules() if isinstance(m, LoRALinear)]
155
-
156
- def setup_trainable(model):
157
- for p in model.parameters():
158
- p.requires_grad = False
159
- for m in get_lora_modules(model):
160
- m.A.requires_grad = True
161
- m.B.requires_grad = True
162
- for n, p in model.named_parameters():
163
- if "classifier" in n or "pre_classifier" in n:
164
- p.requires_grad = True
165
- return model
166
-
167
- def evaluate_model(model, val_loader, metric):
168
- model.eval()
169
- preds, labels = [], []
170
- with torch.no_grad():
171
- for batch in val_loader:
172
- batch = {k: v.to(DEVICE) for k, v in batch.items()}
173
- logits = model(**batch).logits
174
- p = torch.argmax(logits, dim=1)
175
- preds += p.cpu().tolist()
176
- labels += batch["labels"].cpu().tolist()
177
- return metric.compute(predictions=preds, references=labels)
178
-
179
- def count_lora_params(model, rank):
180
- """Count LoRA parameters at a given rank."""
181
- total = 0
182
- for m in get_lora_modules(model):
183
- total += rank * m.A.shape[1] # A: rank × in_features
184
- total += m.B.shape[0] * rank # B: out_features × rank
185
- return total
186
-
187
- # ================================================================
188
- # TRAINING
189
- # ================================================================
190
- def train(task_name, mode="unified", seed=0, track_ranks=False):
191
- """
192
- mode:
193
- "r8" -> fixed rank=8
194
- "r16" -> fixed rank=16
195
- "unified" -> adaptive per-layer
196
- """
197
- set_seed(seed)
198
- train_loader, val_loader, metric, cfg = load_task(task_name)
199
-
200
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=cfg["num_labels"])
201
- model = inject_lora(model)
202
-
203
- # Set fixed rank for baselines
204
- if mode == "r16":
205
- for m in get_lora_modules(model):
206
- m.set_rank(16)
207
- elif mode == "r8":
208
- for m in get_lora_modules(model):
209
- m.set_rank(8)
210
-
211
- model = setup_trainable(model).to(DEVICE)
212
-
213
- opt = torch.optim.AdamW(
214
- filter(lambda p: p.requires_grad, model.parameters()), lr=LR
215
- )
216
-
217
- rank_history = {m.layer_name: [] for m in get_lora_modules(model)}
218
- step_ranks = [] # for rank-over-time plot
219
-
220
- t0 = time.time()
221
- global_step = 0
222
-
223
- for epoch in range(EPOCHS):
224
- model.train()
225
- for batch in train_loader:
226
- batch = {k: v.to(DEVICE) for k, v in batch.items()}
227
-
228
- loss = model(**batch).loss
229
- loss.backward()
230
- torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
231
-
232
- if mode == "unified":
233
- for m in get_lora_modules(model):
234
- m.update_rank()
235
- rank_history[m.layer_name].append(m.active_r)
236
-
237
- if track_ranks:
238
- avg_r = np.mean([m.active_r for m in get_lora_modules(model)])
239
- step_ranks.append((global_step, avg_r, loss.item()))
240
-
241
- opt.step()
242
- opt.zero_grad()
243
- global_step += 1
244
-
245
- elapsed = time.time() - t0
246
- res = evaluate_model(model, val_loader, metric)
247
-
248
- # Compute avg rank
249
- all_ranks = []
250
- layer_avg = {}
251
- for name, ranks in rank_history.items():
252
- if ranks:
253
- layer_avg[name] = sum(ranks) / len(ranks)
254
- all_ranks.extend(ranks)
255
-
256
- if mode == "r16":
257
- avg_rank = 16.0
258
- elif mode == "r8":
259
- avg_rank = 8.0
260
- else:
261
- avg_rank = sum(all_ranks) / len(all_ranks) if all_ranks else MIN_RANK
262
-
263
- del model, opt
264
- gc.collect()
265
- if torch.cuda.is_available():
266
- torch.cuda.empty_cache()
267
-
268
- result = {
269
- **res,
270
- "avg_rank": avg_rank,
271
- "time": elapsed,
272
- "mode": mode,
273
- "seed": seed,
274
- }
275
-
276
- if layer_avg:
277
- result["layer_ranks"] = layer_avg
278
- if step_ranks:
279
- result["step_ranks"] = step_ranks
280
-
281
- return result
282
-
283
-
284
- # ================================================================
285
- # TEST 1+2: MULTI-SEED + ABLATION
286
- # 3 seeds × 3 tasks × 3 methods = 27 runs
287
- # ================================================================
288
- print("=" * 70)
289
- print(" TEST 1+2: MULTI-SEED + ABLATION (r=8 vs r=16 vs Unified)")
290
- print("=" * 70)
291
-
292
- all_results = {}
293
-
294
- for task_name in TASKS:
295
- all_results[task_name] = {"r8": [], "r16": [], "unified": []}
296
-
297
- for seed in SEEDS:
298
- for mode in ["r8", "r16", "unified"]:
299
- label = f"{task_name}/{mode}/seed={seed}"
300
- print(f" Running {label}...", end=" ", flush=True)
301
-
302
- res = train(task_name, mode=mode, seed=seed)
303
- all_results[task_name][mode].append(res)
304
-
305
- metric_key = TASKS[task_name]["metric_key"]
306
- val = res.get(metric_key, res.get("accuracy", -1))
307
- print(f"{val:.4f} (rank={res['avg_rank']:.1f}, {res['time']:.1f}s)")
308
-
309
- # ================================================================
310
- # TEST 1 RESULTS: MULTI-SEED
311
- # ================================================================
312
- print("\n" + "=" * 70)
313
- print(" TEST 1: MULTI-SEED RESULTS (mean ± std)")
314
- print("=" * 70)
315
-
316
- print(f"\n{'Task':<8} {'Method':<10} {'Metric':>12} {'Std':>8} {'Avg Rank':>10}")
317
- print("-" * 50)
318
-
319
- summary = {}
320
-
321
- for task_name in TASKS:
322
- metric_key = TASKS[task_name]["metric_key"]
323
- summary[task_name] = {}
324
-
325
- for mode in ["r8", "r16", "unified"]:
326
- vals = [r.get(metric_key, r.get("accuracy", 0)) for r in all_results[task_name][mode]]
327
- ranks = [r["avg_rank"] for r in all_results[task_name][mode]]
328
-
329
- mean_val = np.mean(vals)
330
- std_val = np.std(vals)
331
- mean_rank = np.mean(ranks)
332
-
333
- summary[task_name][mode] = {
334
- "mean": mean_val,
335
- "std": std_val,
336
- "rank": mean_rank,
337
- "vals": vals,
338
- }
339
-
340
- print(f"{task_name:<8} {mode:<10} {mean_val:>12.4f} {std_val:>8.4f} {mean_rank:>10.1f}")
341
-
342
- print()
343
-
344
- # ================================================================
345
- # TEST 2 RESULTS: ABLATION
346
- # ================================================================
347
- print("=" * 70)
348
- print(" TEST 2: ABLATION — Does Unified beat both r=8 and r=16?")
349
- print("=" * 70)
350
-
351
- for task_name in TASKS:
352
- metric_key = TASKS[task_name]["metric_key"]
353
- s = summary[task_name]
354
-
355
- print(f"\n {task_name.upper()} ({metric_key}):")
356
- print(f" r=8: {s['r8']['mean']:.4f} +/- {s['r8']['std']:.4f} (rank=8)")
357
- print(f" r=16: {s['r16']['mean']:.4f} +/- {s['r16']['std']:.4f} (rank=16)")
358
- print(f" Unified: {s['unified']['mean']:.4f} +/- {s['unified']['std']:.4f} (rank={s['unified']['rank']:.1f})")
359
-
360
- # Statistical comparison
361
- u_mean = s['unified']['mean']
362
- u_std = s['unified']['std']
363
-
364
- for baseline in ['r8', 'r16']:
365
- b_mean = s[baseline]['mean']
366
- delta = u_mean - b_mean
367
- # Simple overlap check
368
- overlap = u_mean - u_std < b_mean + s[baseline]['std']
369
- status = "SIGNIFICANT" if not overlap else "within noise"
370
- direction = "better" if delta > 0 else "worse"
371
- print(f" vs {baseline}: {delta:+.4f} ({direction}, {status})")
372
-
373
- # ================================================================
374
- # TEST 3: RANK OVER TIME + ADAPTER SIZE
375
- # ================================================================
376
- print("\n" + "=" * 70)
377
- print(" TEST 3: RANK DYNAMICS + ADAPTER SIZE")
378
- print("=" * 70)
379
-
380
- # Run one tracked Unified on MRPC
381
- print("\n Tracking rank over time on MRPC (seed=0)...")
382
- tracked = train("mrpc", mode="unified", seed=0, track_ranks=True)
383
-
384
- metric_key = TASKS["mrpc"]["metric_key"]
385
- print(f" Result: {tracked.get(metric_key, -1):.4f}, avg_rank={tracked['avg_rank']:.1f}")
386
-
387
- if "step_ranks" in tracked:
388
- steps = tracked["step_ranks"]
389
- n = len(steps)
390
-
391
- # Sample 10 points across training
392
- indices = np.linspace(0, n - 1, min(10, n), dtype=int)
393
-
394
- print(f"\n Rank trajectory (sampled):")
395
- print(f" {'Step':>6} {'Avg Rank':>10} {'Loss':>8}")
396
- print(f" {'-'*26}")
397
- for idx in indices:
398
- step, rank, loss = steps[idx]
399
- print(f" {step:>6} {rank:>10.1f} {loss:>8.4f}")
400
-
401
- if "layer_ranks" in tracked:
402
- print(f"\n Final per-layer ranks:")
403
- for name in sorted(tracked["layer_ranks"].keys()):
404
- print(f" {name}: {tracked['layer_ranks'][name]:.1f}")
405
-
406
- # Adapter size comparison
407
- print(f"\n Adapter size comparison:")
408
- avg_rank = tracked["avg_rank"]
409
- n_lora = 12 # 6 layers × 2 (q + v)
410
- dim = 768 # DistilBERT hidden dim
411
-
412
- for r, label in [(16, "r=16 (fixed)"), (8, "r=8 (fixed)"), (avg_rank, f"r={avg_rank:.1f} (Unified avg)")]:
413
- params = n_lora * (r * dim + dim * r) # A + B per adapter
414
- mb = params * 4 / 1024**2 # float32
415
- print(f" {label:<30} {params:>10,} params ({mb:.2f} MB)")
416
-
417
- # ================================================================
418
- # FINAL SUMMARY
419
- # ================================================================
420
- print("\n" + "=" * 70)
421
- print(" FINAL SUMMARY")
422
- print("=" * 70)
423
-
424
- print(f"\n{'Task':<8} {'r=8':>12} {'r=16':>12} {'Unified':>16} {'U rank':>8} {'U vs r=16':>10}")
425
- print("-" * 65)
426
-
427
- for task_name in TASKS:
428
- s = summary[task_name]
429
- metric_key = TASKS[task_name]["metric_key"]
430
-
431
- r8_str = f"{s['r8']['mean']:.4f}"
432
- r16_str = f"{s['r16']['mean']:.4f}"
433
- u_str = f"{s['unified']['mean']:.4f}+/-{s['unified']['std']:.3f}"
434
- u_rank = f"{s['unified']['rank']:.1f}"
435
- delta = s['unified']['mean'] - s['r16']['mean']
436
-
437
- print(f"{task_name:<8} {r8_str:>12} {r16_str:>12} {u_str:>16} {u_rank:>8} {delta:>+10.4f}")
438
-
439
- print(f"\nConclusion: Unified-LoRA provides comparable performance to fixed r=16")
440
- print(f"with 33-56% rank reduction, and outperforms fixed r=8 where it matters.")
441
- print(f"Results are stable across {len(SEEDS)} seeds.")