Simo76 commited on
Commit
64f78fc
Β·
1 Parent(s): eeccc5f

Add stable task parity test for Unified-LoRA

Browse files

This script implements the Unified-LoRA Stable Task Parity Test for the MRPC dataset, validating that the controller causes no degradation during stable training. It includes functions for data loading, model training, and evaluation.

Files changed (1) hide show
  1. experiments/stable_task_test.py +172 -0
experiments/stable_task_test.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Unified-LoRA β€” Stable Task Parity Test
3
+ ========================================
4
+
5
+ MRPC only, 120 steps, 3 seeds.
6
+ Validates that the controller causes zero degradation on stable training.
7
+
8
+ Usage:
9
+ pip install transformers datasets evaluate
10
+ python stable_task_test.py
11
+ """
12
+
13
+ import time, random, math, numpy as np, torch, torch.nn as nn
14
+ import torch.nn.functional as F, evaluate
15
+ from datasets import load_dataset
16
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
17
+ from torch.utils.data import DataLoader
18
+
19
+ import sys, os
20
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
21
+ from controller import NestedLoRALinear, OrbitalController, inject_nested_lora, set_rank
22
+
23
+ # ── CONFIG ──────────────────────────────────────────
24
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
25
+ MODEL = "distilbert-base-uncased"
26
+ BATCH = 8
27
+ STEPS = 120
28
+ LR = 5e-5
29
+ SEEDS = [0, 1, 2]
30
+
31
+ MAX_RANK = 16
32
+ WARMUP = 15
33
+ STABLE_WINDOW = 8
34
+
35
+ # ── DATA ────────────────────────────────────────────
36
+ print("Loading data...")
37
+ tok = AutoTokenizer.from_pretrained(MODEL)
38
+ ds = load_dataset("glue", "mrpc")
39
+
40
+ def tok_fn(x):
41
+ return tok(x["sentence1"], x["sentence2"],
42
+ truncation=True, padding="max_length", max_length=128)
43
+
44
+ ds = ds.map(tok_fn, batched=True)
45
+ ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
46
+ train_loader = DataLoader(ds["train"], batch_size=BATCH, shuffle=True)
47
+ val_loader = DataLoader(ds["validation"], batch_size=BATCH)
48
+ metric = evaluate.load("glue", "mrpc")
49
+
50
+ # ── HELPERS ─────────────────────────────────────────
51
+ def build_model():
52
+ base = AutoModelForSequenceClassification.from_pretrained(
53
+ MODEL, num_labels=2, ignore_mismatched_sizes=True
54
+ )
55
+ return inject_nested_lora(base, MAX_RANK).to(DEVICE)
56
+
57
+ def eval_model(model):
58
+ model.eval()
59
+ preds, labels = [], []
60
+ with torch.no_grad():
61
+ for batch in val_loader:
62
+ x = batch["input_ids"].to(DEVICE)
63
+ m = batch["attention_mask"].to(DEVICE)
64
+ y = batch["label"].to(DEVICE)
65
+ logits = model(input_ids=x, attention_mask=m).logits
66
+ preds.extend(logits.argmax(dim=-1).cpu().numpy())
67
+ labels.extend(y.cpu().numpy())
68
+ return metric.compute(predictions=preds, references=labels)["f1"]
69
+
70
+ def eff_rank(usage):
71
+ tot = sum(usage.values())
72
+ return sum(k * v for k, v in usage.items()) / tot if tot > 0 else 0
73
+
74
+ # ── TRAIN BASELINE ──────────────────────────────────
75
+ def train_baseline(model):
76
+ opt = torch.optim.AdamW(model.parameters(), lr=LR)
77
+ set_rank(model, 16)
78
+ it = iter(train_loader)
79
+
80
+ for step in range(STEPS):
81
+ try:
82
+ batch = next(it)
83
+ except StopIteration:
84
+ it = iter(train_loader); batch = next(it)
85
+
86
+ x = batch["input_ids"].to(DEVICE)
87
+ m = batch["attention_mask"].to(DEVICE)
88
+ y = batch["label"].to(DEVICE)
89
+
90
+ loss = model(input_ids=x, attention_mask=m, labels=y).loss
91
+ loss.backward()
92
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
93
+ opt.step()
94
+ opt.zero_grad()
95
+
96
+ return model
97
+
98
+ # ── TRAIN UNIFIED ───────────────────────────────────
99
+ def train_unified(model):
100
+ ctrl = OrbitalController(warmup=WARMUP, stable_window=STABLE_WINDOW)
101
+ opt = torch.optim.AdamW(model.parameters(), lr=LR)
102
+ usage = {4: 0, 8: 0, 16: 0}
103
+ rank_trace = []
104
+ it = iter(train_loader)
105
+
106
+ for step in range(STEPS):
107
+ try:
108
+ batch = next(it)
109
+ except StopIteration:
110
+ it = iter(train_loader); batch = next(it)
111
+
112
+ x = batch["input_ids"].to(DEVICE)
113
+ m = batch["attention_mask"].to(DEVICE)
114
+ y = batch["label"].to(DEVICE)
115
+
116
+ loss = model(input_ids=x, attention_mask=m, labels=y).loss
117
+ new_rank = ctrl.step(loss.item())
118
+ set_rank(model, new_rank)
119
+
120
+ usage[new_rank] += 1
121
+ rank_trace.append(new_rank)
122
+
123
+ loss.backward()
124
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
125
+ opt.step()
126
+ opt.zero_grad()
127
+
128
+ return model, usage, rank_trace, ctrl
129
+
130
+ # ── RUN ─────────────────────────────────────────────
131
+ print(f"\nDevice: {DEVICE}")
132
+ print(f"Task: MRPC, {STEPS} steps")
133
+ print("=" * 55)
134
+
135
+ results = []
136
+
137
+ for seed in SEEDS:
138
+ print(f"\n{'─' * 50}\n SEED {seed}\n{'─' * 50}")
139
+
140
+ torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)
141
+ base_model = build_model()
142
+ base_model = train_baseline(base_model)
143
+ f1_base = eval_model(base_model)
144
+ del base_model; torch.cuda.empty_cache()
145
+
146
+ torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)
147
+ uni_model = build_model()
148
+ uni_model, usage, trace, ctrl = train_unified(uni_model)
149
+ f1_uni = eval_model(uni_model)
150
+
151
+ er = eff_rank(usage)
152
+ saving = 1 - er / 16
153
+ transitions = sum(1 for i in range(1, len(trace)) if trace[i] != trace[i-1])
154
+
155
+ print(f"\n BASELINE F1 = {f1_base:.3f} (rank=16 fixed)")
156
+ print(f" UNIFIED F1 = {f1_uni:.3f} (eff_rank={er:.1f}, saving={saving*100:.0f}%)")
157
+ print(f" delta F1 = {f1_uni - f1_base:+.3f}")
158
+ print(f" Usage: r4={usage[4]} r8={usage[8]} r16={usage[16]} transitions={transitions}")
159
+
160
+ results.append({
161
+ 'seed': seed, 'f1_base': f1_base, 'f1_uni': f1_uni,
162
+ 'delta': f1_uni - f1_base, 'eff_rank': er,
163
+ })
164
+ del uni_model; torch.cuda.empty_cache()
165
+
166
+ # ── SUMMARY ─────────────────────────────────────────
167
+ print(f"\n{'=' * 55}\n SUMMARY\n{'=' * 55}")
168
+ f1b = [r['f1_base'] for r in results]
169
+ f1u = [r['f1_uni'] for r in results]
170
+ print(f"\n Baseline F1: {np.mean(f1b):.3f} +/- {np.std(f1b):.3f}")
171
+ print(f" Unified F1: {np.mean(f1u):.3f} +/- {np.std(f1u):.3f}")
172
+ print(f" delta F1: {np.mean([r['delta'] for r in results]):+.3f}")