Simo76 committed on
Commit
eeccc5f
·
1 Parent(s): aae2316

Add stress test for task switching with Nested Orbital Controller

This script stress-tests task switching from MRPC to SST-2 using a Nested Orbital Controller, comparing it against a fixed-rank (r=16) LoRA baseline. It covers model training, evaluation, and a summary of results across seeds.

experiments/stress_test_task_switch.py ADDED
@@ -0,0 +1,223 @@
+ """
+ Unified-LoRA — Stress Test: Task Switch
+ =========================================
+
+ MRPC (60 steps) → SST-2 (60 steps)
+ Baseline (r=16 fixed) vs Nested Orbital Controller
+
+ Self-contained, reproducible on Google Colab with a T4 GPU.
+
+ Usage:
+     pip install transformers datasets evaluate
+     python stress_test_task_switch.py
+ """
+
+ import time, random, math, numpy as np, torch, torch.nn as nn
+ import torch.nn.functional as F, evaluate
+ from datasets import load_dataset
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+ from torch.utils.data import DataLoader
+
+ # Import from controller.py (same repo)
+ import sys, os
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+ from controller import NestedLoRALinear, OrbitalController, inject_nested_lora, set_rank
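+ # NOTE: controller.py is not part of this diff. The calls below assume roughly
+ # this interface, inferred from usage rather than from the controller source:
+ #   inject_nested_lora(model, max_rank)  -> wraps Linear layers in NestedLoRALinear
+ #   set_rank(model, r)                   -> activates the first r of max_rank components
+ #   OrbitalController(...).step(loss)    -> returns the next rank, here one of {4, 8, 16}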
+
+ # ── CONFIG ──────────────────────────────────────────
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ MODEL = "distilbert-base-uncased"
+ BATCH = 8
+ LR = 5e-5
+ SEEDS = [0, 1, 2]
+
+ MAX_RANK = 16
+ WARMUP = 10
+ STABLE_WINDOW = 6
+
+ STEPS_TASK1 = 60   # MRPC
+ STEPS_TASK2 = 60   # SST-2
+ TOTAL_STEPS = STEPS_TASK1 + STEPS_TASK2
+
+ # ── DATA ────────────────────────────────────────────
+ print("Loading data...")
+ tok = AutoTokenizer.from_pretrained(MODEL)
+
+ ds_mrpc = load_dataset("glue", "mrpc")
+ def tok_mrpc(x):
+     return tok(x["sentence1"], x["sentence2"],
+                truncation=True, padding="max_length", max_length=128)
+ ds_mrpc = ds_mrpc.map(tok_mrpc, batched=True)
+ ds_mrpc.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
+ train_mrpc = DataLoader(ds_mrpc["train"], batch_size=BATCH, shuffle=True)
+ val_mrpc = DataLoader(ds_mrpc["validation"], batch_size=BATCH)
+
+ ds_sst2 = load_dataset("glue", "sst2")
+ def tok_sst2(x):
+     return tok(x["sentence"], truncation=True, padding="max_length", max_length=128)
+ ds_sst2 = ds_sst2.map(tok_sst2, batched=True)
+ ds_sst2.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
+ train_sst2 = DataLoader(ds_sst2["train"], batch_size=BATCH, shuffle=True)
+ val_sst2 = DataLoader(ds_sst2["validation"], batch_size=BATCH)
+
+ metric_mrpc = evaluate.load("glue", "mrpc")
+ metric_sst2 = evaluate.load("glue", "sst2")
+
+ # ── HELPERS ─────────────────────────────────────────
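+ # make_iter cycles a DataLoader indefinitely, so a fixed step budget can run
+ # past the end of an epoch without raising StopIteration.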
+ def make_iter(loader):
+     while True:
+         for batch in loader:
+             yield batch
+
+ def get_batch(it, device):
+     batch = next(it)
+     return (batch["input_ids"].to(device),
+             batch["attention_mask"].to(device),
+             batch["label"].to(device))
+
+ def build_model():
+     base = AutoModelForSequenceClassification.from_pretrained(
+         MODEL, num_labels=2, ignore_mismatched_sizes=True
+     )
+     return inject_nested_lora(base, MAX_RANK).to(DEVICE)
+
+ def eval_f1(model, loader, metric_fn):
+     model.eval()
+     preds, labels = [], []
+     with torch.no_grad():
+         for batch in loader:
+             x = batch["input_ids"].to(DEVICE)
+             m = batch["attention_mask"].to(DEVICE)
+             y = batch["label"].to(DEVICE)
+             logits = model(input_ids=x, attention_mask=m).logits
+             preds.extend(logits.argmax(dim=-1).cpu().numpy())
+             labels.extend(y.cpu().numpy())
+     model.train()
+     result = metric_fn.compute(predictions=preds, references=labels)
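+     # MRPC reports F1; SST-2 reports only accuracy, hence the fallback below.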
+     return result.get("f1", result.get("accuracy", 0.0))
+
+ def eff_rank(usage):
+     tot = sum(usage.values())
+     return sum(k * v for k, v in usage.items()) / tot if tot > 0 else 0
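+ # eff_rank is the step-weighted mean rank. Illustrative arithmetic with
+ # hypothetical counts: usage = {4: 40, 8: 50, 16: 30} gives
+ # (160 + 400 + 480) / 120 ≈ 8.7.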
+
+ # ── TRAIN BASELINE ──────────────────────────────────
+ def train_baseline(model):
+     opt = torch.optim.AdamW(model.parameters(), lr=LR)
+     set_rank(model, 16)
+     it_mrpc = make_iter(train_mrpc)
+     it_sst2 = make_iter(train_sst2)
+     loss_trace = []
+
+     for step in range(TOTAL_STEPS):
+         if step < STEPS_TASK1:
+             x, m, y = get_batch(it_mrpc, DEVICE)
+         else:
+             x, m, y = get_batch(it_sst2, DEVICE)
+
+         loss = model(input_ids=x, attention_mask=m, labels=y).loss
+         loss_trace.append(loss.item())
+         loss.backward()
+         torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+         opt.step()
+         opt.zero_grad()
+
+     return model, loss_trace
+
+ # ── TRAIN UNIFIED ───────────────────────────────────
+ def train_unified(model):
+     ctrl = OrbitalController(warmup=WARMUP, stable_window=STABLE_WINDOW)
+     ctrl.rank = 4
+     set_rank(model, 4)
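+     # Start at the smallest nested rank and let the controller grow it.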
+
+     opt = torch.optim.AdamW(model.parameters(), lr=LR)
+     usage = {4: 0, 8: 0, 16: 0}
+     rank_trace, loss_trace = [], []
+     it_mrpc = make_iter(train_mrpc)
+     it_sst2 = make_iter(train_sst2)
+
+     for step in range(TOTAL_STEPS):
+         if step < STEPS_TASK1:
+             x, m, y = get_batch(it_mrpc, DEVICE)
+         else:
+             x, m, y = get_batch(it_sst2, DEVICE)
+
+         loss = model(input_ids=x, attention_mask=m, labels=y).loss
+         new_rank = ctrl.step(loss.item())
+         set_rank(model, new_rank)
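+         # The rank is updated between the forward and backward passes, so the
+         # gradient step presumably runs at the newly selected rank (this depends
+         # on how set_rank gates parameters inside NestedLoRALinear).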
+
+         usage[new_rank] += 1
+         rank_trace.append(new_rank)
+         loss_trace.append(loss.item())
+
+         loss.backward()
+         torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+         opt.step()
+         opt.zero_grad()
+
+     return model, usage, rank_trace, loss_trace, ctrl
+
+ # ── RUN ─────────────────────────────────────────────
+ print(f"\nDevice: {DEVICE}")
+ print(f"Plan: MRPC × {STEPS_TASK1} → SST-2 × {STEPS_TASK2}")
+ print(f"Shock at step {STEPS_TASK1}")
+ print("=" * 55)
+
+ results = []
+
+ for seed in SEEDS:
+     print(f"\n{'─' * 55}\n SEED {seed}\n{'─' * 55}")
+
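+     # Re-seed before each run so baseline and unified start from identical weights.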
+     torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)
+     base_model = build_model()
+     base_model, base_loss = train_baseline(base_model)
+     f1_mrpc_base = eval_f1(base_model, val_mrpc, metric_mrpc)
+     f1_sst2_base = eval_f1(base_model, val_sst2, metric_sst2)
+     del base_model; torch.cuda.empty_cache()
+
+     torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)
+     uni_model = build_model()
+     uni_model, usage, rank_trace, uni_loss, ctrl = train_unified(uni_model)
+     f1_mrpc_uni = eval_f1(uni_model, val_mrpc, metric_mrpc)
+     f1_sst2_uni = eval_f1(uni_model, val_sst2, metric_sst2)
+
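+     # Derived diagnostics: eff_rank is the step-weighted mean rank, saving is the
+     # fraction of the fixed r=16 budget avoided, transitions counts rank changes.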
+     er = eff_rank(usage)
+     saving = 1 - er / 16
+     transitions = sum(1 for i in range(1, len(rank_trace)) if rank_trace[i] != rank_trace[i-1])
+
+     print(f"\n {'':30s} {'BASELINE':>10s} {'UNIFIED':>10s}")
+     print(f" {'─' * 55}")
+     print(f" {'MRPC F1 (retention)':30s} {f1_mrpc_base:10.3f} {f1_mrpc_uni:10.3f}")
+     print(f" {'SST-2 Acc (new task)':30s} {f1_sst2_base:10.3f} {f1_sst2_uni:10.3f}")
+     print(f"\n Unified: eff_rank={er:.1f} saving={saving*100:.0f}% transitions={transitions}")
+     print(f" Usage: r4={usage[4]} r8={usage[8]} r16={usage[16]}")
+
+     # Rank trace
+     trace_str = ""
+     for i, r in enumerate(rank_trace):
+         if i % 10 == 0:
+             marker = " <<<SHOCK" if i == STEPS_TASK1 else ""
+             trace_str += f"\n [{i:3d}]{marker} "
+         trace_str += f"r{r:<3d}"
+     print(f" Rank trace:{trace_str}")
+
+     results.append({
+         'seed': seed, 'f1_mrpc_base': f1_mrpc_base, 'f1_sst2_base': f1_sst2_base,
+         'f1_mrpc_uni': f1_mrpc_uni, 'f1_sst2_uni': f1_sst2_uni,
+         'eff_rank': er, 'saving': saving, 'transitions': transitions,
+     })
+     del uni_model; torch.cuda.empty_cache()
+
+ # ── SUMMARY ─────────────────────────────────────────
+ print(f"\n{'=' * 55}\n SUMMARY\n{'=' * 55}")
+ mrpc_b = np.mean([r['f1_mrpc_base'] for r in results])
+ mrpc_u = np.mean([r['f1_mrpc_uni'] for r in results])
+ sst2_b = np.mean([r['f1_sst2_base'] for r in results])
+ sst2_u = np.mean([r['f1_sst2_uni'] for r in results])
+ er_avg = np.mean([r['eff_rank'] for r in results])
+ sv_avg = np.mean([r['saving'] for r in results])
+
+ print(f"\n {'':30s} {'BASELINE':>10s} {'UNIFIED':>10s} {'DELTA':>8s}")
+ print(f" {'─' * 60}")
+ print(f" {'MRPC F1 (retention)':30s} {mrpc_b:10.3f} {mrpc_u:10.3f} {mrpc_u-mrpc_b:+8.3f}")
+ print(f" {'SST-2 Acc (new task)':30s} {sst2_b:10.3f} {sst2_u:10.3f} {sst2_u-sst2_b:+8.3f}")
+ print(f" {'Eff rank':30s} {'16.0':>10s} {er_avg:10.1f}")
+ print(f" {'Saving':30s} {'0%':>10s} {sv_avg*100:9.0f}%")