tekkmaven commited on
Commit
36ee7e2
·
verified ·
1 Parent(s): 6da318d

Upload run_graduated.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. run_graduated.py +509 -0
run_graduated.py ADDED
@@ -0,0 +1,509 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Graduated Dissimilarity Experiment
4
+ ====================================
5
+ Train on modular addition, then fork into 5 branches with tasks of
6
+ increasing dissimilarity. Track representation metrics to find the
7
+ tipping point where forgetting begins.
8
+
9
+ Branches:
10
+ A→A: Continue addition (Level 0: identical)
11
+ A→B: Switch to subtraction (Level 1: same Fourier circuit)
12
+ A→C: Switch to multiplication (Level 2: different Fourier freqs)
13
+ A→D: Switch to max(a,b) (Level 3: linear/ordinal circuit)
14
+ A→E: Switch to XOR (Level 4: bit-level circuit)
15
+ """
16
+
17
+ import torch, torch.nn as nn, torch.optim as optim, torch.nn.functional as F
18
+ import numpy as np, json, os, math, copy, time
19
+ from dataclasses import dataclass
20
+ from torch.utils.data import Dataset, DataLoader
21
+ from typing import Dict, List, Tuple
22
+
23
+ # ============= INLINE DEPS (self-contained for job) =============
24
+
25
+ # --- Representation metrics ---
26
+ def centering(K):
27
+ n=K.shape[0]; u=torch.ones(n,n,device=K.device,dtype=K.dtype)/n
28
+ return K-u@K-K@u+u@K@u
29
+ def linear_HSIC(X,Y):
30
+ K=X@X.T; L=Y@Y.T; return (centering(K)*centering(L)).sum()/((X.shape[0]-1)**2)
31
+ def linear_CKA(X,Y):
32
+ xy=linear_HSIC(X,Y);xx=linear_HSIC(X,X);yy=linear_HSIC(Y,Y)
33
+ return (xy/(xx.sqrt()*yy.sqrt()).clamp(min=1e-10)).item()
34
+ def cka_heatmap(sa,sb):
35
+ hm=np.zeros((len(sa),len(sb)))
36
+ for i in range(len(sa)):
37
+ for j in range(len(sb)): hm[i,j]=linear_CKA(sa[i],sb[j])
38
+ return hm
39
+ def subspace_angles(X,Y,k=10):
40
+ def tb(Z,k):
41
+ _,_,Vh=torch.linalg.svd(Z-Z.mean(0),full_matrices=False); return Vh[:min(k,Vh.shape[0])].T
42
+ Qx=tb(X,k);Qy=tb(Y,k); mk=min(Qx.shape[1],Qy.shape[1])
43
+ return torch.arccos(torch.linalg.svdvals(Qx[:,:mk].T@Qy[:,:mk]).clamp(-1,1))
44
+ def mean_sa_deg(X,Y,k=10): return (subspace_angles(X,Y,k).mean()*180/torch.pi).item()
45
+ def grad_align(model,ba,bb,lfn):
46
+ model.zero_grad();lfn(model,ba).backward()
47
+ ga=torch.cat([p.grad.flatten() for p in model.parameters() if p.grad is not None]).clone()
48
+ model.zero_grad();lfn(model,bb).backward()
49
+ gb=torch.cat([p.grad.flatten() for p in model.parameters() if p.grad is not None]).clone()
50
+ model.zero_grad(); return F.cosine_similarity(ga.unsqueeze(0),gb.unsqueeze(0)).item()
51
+ def attn_entropy(aw):
52
+ H=-(aw*(aw+1e-9).log2()).sum(-1)
53
+ return {'mean':H.mean().item(),'per_head':H.mean(dim=(0,2)).cpu().tolist()}
54
+ def wc_mag(s0,s1):
55
+ return {n:(s1[n].float()-s0[n].float()).norm().item() for n in s0 if n in s1}
56
+
57
+ # Fourier power spectrum of embedding matrix
58
+ def fourier_power_spectrum(W_E, p):
59
+ """Compute Fourier power at each frequency for token embeddings.
60
+ W_E: [vocab_size, d_model]. Returns power at each freq 0..p//2."""
61
+ # Extract number embeddings only (skip special tokens)
62
+ from tasks import NUM_SPECIAL
63
+ num_emb = W_E[NUM_SPECIAL:NUM_SPECIAL+p, :] # [p, d_model]
64
+ fft = torch.fft.rfft(num_emb, dim=0) # [p//2+1, d_model]
65
+ power = (fft.abs() ** 2).sum(dim=1) # [p//2+1]
66
+ return power.cpu().numpy()
67
+
68
+ # --- Tasks ---
69
+ NS = 7 # NUM_SPECIAL (updated)
70
+ DP = 97
71
+
72
+ OP_TOKENS = {'add':2,'subtract':3,'multiply':4,'max':5,'xor':6}
73
+ ALL_OPS = ['add','subtract','multiply','max','xor']
74
+
75
+ class MAD(Dataset):
76
+ def __init__(self,op='add',p=DP,split='train',tf=0.5,seed=42):
77
+ self.p=p;self.op=op;self.ot=OP_TOKENS[op]
78
+ ap=[(a,b) for a in range(p) for b in range(p)]
79
+ rng=np.random.RandomState(seed);rng.shuffle(ap)
80
+ nt=int(len(ap)*tf); self.pairs=ap[:nt] if split=='train' else ap[nt:]
81
+ def _c(self,a,b):
82
+ if self.op=='add': return(a+b)%self.p
83
+ elif self.op=='subtract': return(a-b)%self.p
84
+ elif self.op=='multiply': return(a*b)%self.p
85
+ elif self.op=='max': return max(a,b)
86
+ elif self.op=='xor': return(a^b)%self.p
87
+ def __len__(self): return len(self.pairs)
88
+ def __getitem__(self,i):
89
+ a,b=self.pairs[i];c=self._c(a,b)
90
+ return {'input_ids':torch.tensor([a+NS,self.ot,b+NS,1,c+NS],dtype=torch.long),
91
+ 'labels':torch.tensor([-100,-100,-100,-100,c+NS],dtype=torch.long)}
92
+
93
+ def get_probe(ds,n=500):
94
+ n=min(n,len(ds));its=[ds[i] for i in range(n)]
95
+ return torch.stack([it['input_ids'] for it in its]),np.array([it['labels'][-1].item()-NS for it in its])
96
+
97
+ def get_loaders(p=DP,bs=512,tf=0.5,seed=42):
98
+ ld={}
99
+ for op in ALL_OPS:
100
+ for sp in ['train','test']:
101
+ ld[f'{op}_{sp}']=DataLoader(MAD(op,p,sp,tf,seed),batch_size=bs,shuffle=(sp=='train'),drop_last=False)
102
+ return ld
103
+
104
+ # --- Model ---
105
+ @dataclass
106
+ class TC:
107
+ vs:int=104;nl:int=2;dm:int=128;nh:int=4;dmlp:int=512;msl:int=5;do:float=0.0;ln:bool=True
108
+
109
+ class MHA(nn.Module):
110
+ def __init__(self,c):
111
+ super().__init__();self.nh=c.nh;self.dh=c.dm//c.nh
112
+ self.WQ=nn.Linear(c.dm,c.dm,bias=False);self.WK=nn.Linear(c.dm,c.dm,bias=False)
113
+ self.WV=nn.Linear(c.dm,c.dm,bias=False);self.WO=nn.Linear(c.dm,c.dm,bias=False)
114
+ def forward(self,x,ra=False):
115
+ B,T,D=x.shape
116
+ Q=self.WQ(x).view(B,T,self.nh,self.dh).transpose(1,2)
117
+ K=self.WK(x).view(B,T,self.nh,self.dh).transpose(1,2)
118
+ V=self.WV(x).view(B,T,self.nh,self.dh).transpose(1,2)
119
+ s=(Q@K.transpose(-2,-1))/math.sqrt(self.dh)
120
+ s.masked_fill_(torch.triu(torch.ones(T,T,device=x.device),1).bool().unsqueeze(0).unsqueeze(0),float('-inf'))
121
+ a=F.softmax(s,dim=-1);return self.WO((a@V).transpose(1,2).reshape(B,T,D)),(a if ra else None)
122
+
123
+ class MLP2(nn.Module):
124
+ def __init__(self,c):
125
+ super().__init__();self.wi=nn.Linear(c.dm,c.dmlp);self.wo=nn.Linear(c.dmlp,c.dm)
126
+ def forward(self,x):h=F.gelu(self.wi(x));return self.wo(h),h
127
+
128
+ class TB(nn.Module):
129
+ def __init__(self,c):
130
+ super().__init__();self.attn=MHA(c);self.mlp=MLP2(c)
131
+ self.ln1=nn.LayerNorm(c.dm) if c.ln else nn.Identity()
132
+ self.ln2=nn.LayerNorm(c.dm) if c.ln else nn.Identity()
133
+ def forward(self,x,ri=False):
134
+ ao,aw=self.attn(self.ln1(x),ra=ri);x=x+ao;mo,mh=self.mlp(self.ln2(x));x=x+mo
135
+ r={'hs':x}
136
+ if ri:r['aw']=aw;r['mh']=mh
137
+ return r
138
+
139
+ class ST(nn.Module):
140
+ def __init__(self,c):
141
+ super().__init__();self.c=c
142
+ self.te=nn.Embedding(c.vs,c.dm);self.pe=nn.Embedding(c.msl,c.dm)
143
+ self.blocks=nn.ModuleList([TB(c) for _ in range(c.nl)])
144
+ self.lnf=nn.LayerNorm(c.dm) if c.ln else nn.Identity()
145
+ self.head=nn.Linear(c.dm,c.vs,bias=False);self.head.weight=self.te.weight
146
+ self.apply(self._iw)
147
+ def _iw(self,m):
148
+ if isinstance(m,nn.Linear):nn.init.normal_(m.weight,std=0.02)
149
+ elif isinstance(m,nn.Embedding):nn.init.normal_(m.weight,std=0.02)
150
+ def forward(self,ids,labels=None,ri=False):
151
+ B,T=ids.shape;x=self.te(ids)+self.pe(torch.arange(T,device=ids.device))
152
+ ahs=[x.detach()];aaw=[];amh=[]
153
+ for b in self.blocks:
154
+ r=b(x,ri=ri);x=r['hs'];ahs.append(x.detach())
155
+ if ri:aaw.append(r['aw'].detach());amh.append(r['mh'].detach())
156
+ lo=self.head(self.lnf(x));res={'logits':lo}
157
+ if labels is not None:res['loss']=F.cross_entropy(lo.view(-1,lo.size(-1)),labels.view(-1),ignore_index=-100)
158
+ if ri:res['hs']=ahs;res['aw']=aaw;res['mh']=amh
159
+ return res
160
+ def nparams(self):return sum(p.numel() for p in self.parameters())
161
+
162
+ # --- Experiment helpers ---
163
+ def ev(model,dl,dev):
164
+ model.eval();tl=tc=tt=0
165
+ with torch.no_grad():
166
+ for b in dl:
167
+ ids=b['input_ids'].to(dev);labs=b['labels'].to(dev)
168
+ o=model(ids,labels=labs);tl+=o['loss'].item()*ids.shape[0]
169
+ tc+=(o['logits'][:,-1,:].argmax(-1)==labs[:,-1]).sum().item();tt+=ids.shape[0]
170
+ return {'loss':tl/tt,'acc':tc/tt}
171
+
172
+ def creps(model,pids,dev,pos=-1):
173
+ model.eval()
174
+ with torch.no_grad():o=model(pids.to(dev),ri=True)
175
+ return {'hs':[h[:,pos,:].cpu() for h in o['hs']],'aw':[a.cpu() for a in o['aw']]}
176
+
177
+ def cmetrics(model,s0,s1,rc,ri,rp,dev,cfg):
178
+ m={};nl=cfg.nl+1
179
+ for li in range(nl):
180
+ p=f'layer_{li}';c=rc['hs'][li];ii=ri['hs'][li];pp=rp['hs'][li]
181
+ m[f'{p}/cka_vs_init']=linear_CKA(c,ii)
182
+ m[f'{p}/cka_vs_phase1']=linear_CKA(c,pp)
183
+ k=min(10,c.shape[0]//2,c.shape[1])
184
+ m[f'{p}/subspace_angle_vs_phase1']=mean_sa_deg(c,pp,k=k) if k>0 else 0.
185
+ for li,aw in enumerate(rc['aw']):
186
+ e=attn_entropy(aw);m[f'layer_{li+1}/attn_entropy_mean']=e['mean']
187
+ for h,he in enumerate(e['per_head']):m[f'layer_{li+1}/head_{h}_entropy']=he
188
+ cs={k:v.cpu() for k,v in model.state_dict().items()}
189
+ wp=wc_mag(s1,cs)
190
+ for bi in range(cfg.nl):
191
+ m[f'block_{bi}/weight_change_from_phase1']=sum(v for k,v in wp.items() if f'blocks.{bi}' in k)
192
+ return m
193
+
194
+ def train_phase(model,opt,dl,ne,dev,pn,s0,s1,ri,rp,pids,eval_loaders,cfg,ce=20):
195
+ hist=[];gs=0
196
+ for ep in range(ne):
197
+ model.train();el=nb=0
198
+ for b in dl:
199
+ ids=b['input_ids'].to(dev);labs=b['labels'].to(dev)
200
+ o=model(ids,labels=labs);o['loss'].backward();opt.step();opt.zero_grad()
201
+ el+=o['loss'].item();nb+=1;gs+=1
202
+ if gs%ce==0:
203
+ model.eval();rc=creps(model,pids,dev)
204
+ sm=cmetrics(model,s0,s1,rc,ri,rp,dev,cfg)
205
+ # Evaluate on ALL tasks
206
+ for n,ld in eval_loaders.items():
207
+ if '_test' in n:
208
+ e=ev(model,ld,dev);sm[f'eval/{n}_loss']=e['loss'];sm[f'eval/{n}_acc']=e['acc']
209
+ # Gradient alignment: addition vs each other task
210
+ add_batch=next(iter(eval_loaders['add_test']))
211
+ def lfn(m,b):return m(b['input_ids'].to(dev),labels=b['labels'].to(dev))['loss']
212
+ for task_name in ['subtract','multiply','max','xor']:
213
+ try:
214
+ tb=next(iter(eval_loaders[f'{task_name}_test']))
215
+ sm[f'grad_align_add_vs_{task_name}']=grad_align(model,add_batch,tb,lfn)
216
+ except:sm[f'grad_align_add_vs_{task_name}']=0.
217
+ # Fourier power spectrum of embeddings
218
+ emb_w = model.te.weight.detach().cpu()
219
+ fps = fourier_power_spectrum(emb_w, cfg.vs - NS)
220
+ # Store top-5 peak frequencies
221
+ top5 = np.argsort(fps)[::-1][:5]
222
+ sm['fourier_top5_freqs'] = top5.tolist()
223
+ sm['fourier_top5_power'] = fps[top5].tolist()
224
+ sm['fourier_total_power'] = float(fps.sum())
225
+ sm['fourier_concentration'] = float(fps[top5].sum() / fps.sum()) # how concentrated
226
+
227
+ sm['phase']=pn;sm['epoch']=ep;sm['step']=gs;sm['train_loss']=el/nb
228
+ hist.append(sm)
229
+ add_acc=sm.get('eval/add_test_acc',0)
230
+ task_acc=sm.get(f'eval/{pn.split("_")[-1]}_test_acc', add_acc)
231
+ print(f"[{pn}] S{gs} L:{el/nb:.4f} AddAcc:{add_acc:.3f} TaskAcc:{task_acc:.3f} "
232
+ f"CKA(L1vP1):{sm.get('layer_1/cka_vs_phase1',0):.3f} "
233
+ f"FourierConc:{sm.get('fourier_concentration',0):.3f}")
234
+ model.train()
235
+ print(f"[{pn}] Ep{ep+1}/{ne} L:{el/nb:.4f}")
236
+ return hist
237
+
238
+ # ============= MAIN =============
239
+ print("="*70)
240
+ print("GRADUATED DISSIMILARITY EXPERIMENT")
241
+ print("="*70)
242
+ t0=time.time()
243
+ dev=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
244
+ print(f"Device: {dev}")
245
+
246
+ p=97;seed=42;torch.manual_seed(seed);np.random.seed(seed)
247
+ cfg=TC(vs=p+NS,nl=2,dm=128,nh=4,dmlp=512,msl=5)
248
+ model=ST(cfg).to(dev)
249
+ print(f"Parameters: {model.nparams():,}")
250
+
251
+ s0={k:v.cpu().clone() for k,v in model.state_dict().items()}
252
+ lds=get_loaders(p=p,bs=512,tf=0.5,seed=seed)
253
+
254
+ # Probe data from addition test set (same across all branches)
255
+ dsa=MAD('add',p=p,split='test',tf=0.5,seed=seed)
256
+ pids_a,pla=get_probe(dsa,500)
257
+ ri=creps(model,pids_a,dev)
258
+
259
+ # Initial Fourier spectrum
260
+ init_fps = fourier_power_spectrum(model.te.weight.detach().cpu(), p)
261
+
262
+ # ========== PHASE 1: Train on Addition ==========
263
+ print(f"\n{'='*60}\nPHASE 1: Train on Addition ({150} epochs)\n{'='*60}")
264
+ opt=optim.AdamW(model.parameters(),lr=1e-3,weight_decay=1.0)
265
+ h1=train_phase(model,opt,lds['add_train'],150,dev,'p1_add',s0,s0,ri,ri,pids_a,lds,cfg,ce=20)
266
+
267
+ s1={k:v.cpu().clone() for k,v in model.state_dict().items()}
268
+ rp=creps(model,pids_a,dev)
269
+ p1_eval={op:ev(model,lds[f'{op}_test'],dev) for op in ALL_OPS}
270
+ print(f"\nPhase 1 final accuracies:")
271
+ for op,e in p1_eval.items():print(f" {op}: {e['acc']:.3f}")
272
+
273
+ os.makedirs('results',exist_ok=True)
274
+ torch.save(model.state_dict(),'results/p1.pt')
275
+ p1_fps = fourier_power_spectrum(model.te.weight.detach().cpu(), p)
276
+
277
+ # ========== PHASE 2: Fork into 5 branches ==========
278
+ branch_results = {}
279
+ phase2_epochs = 150
280
+
281
+ for task_name in ALL_OPS:
282
+ branch_name = f'p2_{task_name}'
283
+ level = ALL_OPS.index(task_name)
284
+ print(f"\n{'='*60}")
285
+ print(f"PHASE 2 — Branch A→{task_name.upper()} (Level {level}, {phase2_epochs} epochs)")
286
+ print(f"{'='*60}")
287
+
288
+ m = ST(cfg).to(dev)
289
+ m.load_state_dict(torch.load('results/p1.pt', weights_only=True))
290
+ o = optim.AdamW(m.parameters(), lr=1e-3, weight_decay=1.0)
291
+ h = train_phase(m, o, lds[f'{task_name}_train'], phase2_epochs, dev,
292
+ branch_name, s0, s1, ri, rp, pids_a, lds, cfg, ce=20)
293
+
294
+ # Final evaluation on ALL tasks
295
+ final_eval = {op: ev(m, lds[f'{op}_test'], dev) for op in ALL_OPS}
296
+ fps_final = fourier_power_spectrum(m.te.weight.detach().cpu(), p)
297
+
298
+ branch_results[task_name] = {
299
+ 'history': h,
300
+ 'final_eval': final_eval,
301
+ 'fourier_spectrum_final': fps_final.tolist(),
302
+ }
303
+
304
+ print(f"\n Final accuracies after A→{task_name}:")
305
+ for op, e in final_eval.items():
306
+ marker = " ←TRAINED" if op == task_name else (" ←BASE" if op == 'add' else "")
307
+ print(f" {op}: {e['acc']:.3f}{marker}")
308
+
309
+ # Forgetting = Phase1 add accuracy - current add accuracy
310
+ forgetting = p1_eval['add']['acc'] - final_eval['add']['acc']
311
+ print(f" Addition FORGETTING: {forgetting*100:.1f}%")
312
+
313
+ # ========== PHASE 3: Cross-comparison ==========
314
+ print(f"\n{'='*60}\nPHASE 3: Cross-model Representation Comparison\n{'='*60}")
315
+
316
+ # Reload all branch final models for comparison
317
+ branch_reps = {}
318
+ for task_name in ALL_OPS:
319
+ # Re-run the final model to get representations
320
+ # (models were not saved to disk, so re-extract from branch_results last step)
321
+ # Actually let's just use the last checkpoint reps from training
322
+ pass
323
+
324
+ # Build summary table
325
+ summary = {
326
+ 'config': {'p':p,'n_layers':2,'d_model':128,'n_heads':4,'d_mlp':512,
327
+ 'phase1_epochs':150,'phase2_epochs':phase2_epochs,
328
+ 'lr':1e-3,'weight_decay':1.0,'batch_size':512,
329
+ 'train_frac':0.5,'seed':seed,'n_parameters':model.nparams()},
330
+ 'dissimilarity_ladder': {
331
+ 'add': {'level':0, 'description':'Identical (Fourier circuit)'},
332
+ 'subtract': {'level':1, 'description':'Same Fourier circuit (sign flip)'},
333
+ 'multiply': {'level':2, 'description':'Discrete-log Fourier circuit'},
334
+ 'max': {'level':3, 'description':'Linear/ordinal circuit'},
335
+ 'xor': {'level':4, 'description':'Bit-level circuit'},
336
+ },
337
+ 'phase1_history': h1,
338
+ 'phase1_final_eval': {op: p1_eval[op] for op in ALL_OPS},
339
+ 'fourier_spectrum_init': init_fps.tolist(),
340
+ 'fourier_spectrum_phase1': p1_fps.tolist(),
341
+ }
342
+
343
+ # Per-branch results
344
+ for task_name, br in branch_results.items():
345
+ level = ALL_OPS.index(task_name)
346
+ summary[f'branch_{task_name}'] = {
347
+ 'level': level,
348
+ 'history': br['history'],
349
+ 'final_eval': br['final_eval'],
350
+ 'fourier_spectrum': br['fourier_spectrum_final'],
351
+ 'addition_forgetting': p1_eval['add']['acc'] - br['final_eval']['add']['acc'],
352
+ }
353
+
354
+ # Forgetting summary table
355
+ print("\n" + "="*60)
356
+ print("FORGETTING SUMMARY")
357
+ print("="*60)
358
+ print(f"{'Task':<15} {'Level':<8} {'Add Acc':<10} {'Task Acc':<10} {'Forgetting':<12} {'Grad Align':<12}")
359
+ print("-"*67)
360
+ for task_name in ALL_OPS:
361
+ br = branch_results[task_name]
362
+ level = ALL_OPS.index(task_name)
363
+ add_acc = br['final_eval']['add']['acc']
364
+ task_acc = br['final_eval'][task_name]['acc']
365
+ forgetting = p1_eval['add']['acc'] - add_acc
366
+ # Get last gradient alignment from history
367
+ last_ga = 0
368
+ for h in reversed(br['history']):
369
+ if f'grad_align_add_vs_{task_name}' in h:
370
+ last_ga = h[f'grad_align_add_vs_{task_name}']; break
371
+ elif task_name == 'add' and 'grad_align_add_vs_subtract' in h:
372
+ last_ga = h['grad_align_add_vs_subtract']; break
373
+ print(f"{task_name:<15} {level:<8} {add_acc:<10.3f} {task_acc:<10.3f} {forgetting*100:<12.1f}% {last_ga:<12.3f}")
374
+
375
+ summary['forgetting_table'] = {}
376
+ for task_name in ALL_OPS:
377
+ br = branch_results[task_name]
378
+ summary['forgetting_table'][task_name] = {
379
+ 'level': ALL_OPS.index(task_name),
380
+ 'addition_accuracy': br['final_eval']['add']['acc'],
381
+ 'task_accuracy': br['final_eval'][task_name]['acc'],
382
+ 'forgetting_pct': (p1_eval['add']['acc'] - br['final_eval']['add']['acc']) * 100,
383
+ }
384
+
385
+ with open('results/graduated_experiment_results.json','w') as f:
386
+ json.dump(summary,f,indent=2,default=str)
387
+
388
+ elapsed = time.time() - t0
389
+ print(f"\nExperiment completed in {elapsed/60:.1f} minutes")
390
+ print("Results saved to results/graduated_experiment_results.json")
391
+
392
+ # ============= GENERATE PLOTS =============
393
+ print("\nGenerating plots...")
394
+ import matplotlib;matplotlib.use('Agg');import matplotlib.pyplot as plt
395
+
396
+ colors = {'add':'#2196F3','subtract':'#4CAF50','multiply':'#FF9800','max':'#F44336','xor':'#9C27B0'}
397
+ labels = {'add':'Addition (L0)','subtract':'Subtraction (L1)','multiply':'Multiplication (L2)',
398
+ 'max':'Max (L3)','xor':'XOR (L4)'}
399
+
400
+ # Plot 1: Forgetting bar chart
401
+ fig,ax=plt.subplots(figsize=(10,6))
402
+ tasks=[t for t in ALL_OPS]
403
+ forg=[summary['forgetting_table'][t]['forgetting_pct'] for t in tasks]
404
+ task_acc=[summary['forgetting_table'][t]['task_accuracy'] for t in tasks]
405
+ x=np.arange(len(tasks))
406
+ bars=ax.bar(x,forg,color=[colors[t] for t in tasks],edgecolor='black',linewidth=0.5)
407
+ ax.set_xticks(x);ax.set_xticklabels([labels[t] for t in tasks],rotation=15,ha='right')
408
+ ax.set_ylabel('Addition Accuracy Forgetting (%)');ax.set_title('Forgetting vs Task Dissimilarity Level')
409
+ ax.axhline(0,color='gray',ls=':',alpha=0.5)
410
+ for bar,ta in zip(bars,task_acc):
411
+ ax.text(bar.get_x()+bar.get_width()/2,bar.get_height()+0.5,f'Task:{ta:.1%}',
412
+ ha='center',va='bottom',fontsize=8)
413
+ plt.tight_layout();plt.savefig('results/forgetting_ladder.png',dpi=150);plt.close()
414
+
415
+ # Plot 2: Addition accuracy over training for each branch
416
+ fig,ax=plt.subplots(figsize=(12,6))
417
+ s1s=[h['step'] for h in h1]
418
+ ax.plot(s1s,[h.get('eval/add_test_acc',0) for h in h1],'k-',lw=2,label='Phase 1 (training)')
419
+ for task_name in ALL_OPS:
420
+ br=branch_results[task_name]
421
+ steps=[h['step']+s1s[-1] for h in br['history']]
422
+ acc=[h.get('eval/add_test_acc',0) for h in br['history']]
423
+ ax.plot(steps,acc,'-',color=colors[task_name],lw=2,label=labels[task_name])
424
+ ax.axvline(s1s[-1],color='gray',ls='--',alpha=0.5,label='Phase transition')
425
+ ax.set_xlabel('Training Step');ax.set_ylabel('Addition Test Accuracy')
426
+ ax.set_title('Addition Accuracy: Learning vs Forgetting');ax.legend();ax.set_ylim(-0.05,1.05)
427
+ plt.tight_layout();plt.savefig('results/addition_accuracy_all_branches.png',dpi=150);plt.close()
428
+
429
+ # Plot 3: CKA vs Phase 1 for each branch (layer 2 = final layer)
430
+ fig,axes=plt.subplots(1,3,figsize=(18,5))
431
+ for li,ax in enumerate(axes):
432
+ m_name=f'layer_{li}/cka_vs_phase1'
433
+ for task_name in ALL_OPS:
434
+ br=branch_results[task_name]
435
+ steps=[h['step'] for h in br['history'] if m_name in h]
436
+ vals=[h[m_name] for h in br['history'] if m_name in h]
437
+ ax.plot(steps,vals,'-',color=colors[task_name],lw=1.5,label=labels[task_name])
438
+ ax.set_title(f'{"Embedding" if li==0 else f"Layer {li}"}: CKA vs Phase 1 End')
439
+ ax.set_xlabel('Step');ax.set_ylabel('CKA');ax.legend(fontsize=7);ax.set_ylim(0,1.05)
440
+ plt.tight_layout();plt.savefig('results/cka_all_branches.png',dpi=150);plt.close()
441
+
442
+ # Plot 4: Gradient alignment with addition over training
443
+ fig,ax=plt.subplots(figsize=(12,6))
444
+ for task_name in ['subtract','multiply','max','xor']:
445
+ m_name=f'grad_align_add_vs_{task_name}'
446
+ # Phase 1 data
447
+ steps_p1=[h['step'] for h in h1 if m_name in h]
448
+ vals_p1=[h[m_name] for h in h1 if m_name in h]
449
+ ax.plot(steps_p1,vals_p1,':',color=colors[task_name],lw=1,alpha=0.5)
450
+ # Branch data (training on this task)
451
+ br=branch_results[task_name]
452
+ steps=[h['step']+s1s[-1] for h in br['history'] if m_name in h]
453
+ vals=[h[m_name] for h in br['history'] if m_name in h]
454
+ ax.plot(steps,vals,'-',color=colors[task_name],lw=2,label=labels[task_name])
455
+ ax.axvline(s1s[-1],color='gray',ls='--',alpha=0.5)
456
+ ax.axhline(0,color='gray',ls=':',alpha=0.3)
457
+ ax.set_xlabel('Step');ax.set_ylabel('Gradient Cosine Similarity')
458
+ ax.set_title('Gradient Alignment with Addition Task');ax.legend()
459
+ plt.tight_layout();plt.savefig('results/gradient_alignment_all.png',dpi=150);plt.close()
460
+
461
+ # Plot 5: Fourier spectrum comparison
462
+ fig,axes=plt.subplots(2,3,figsize=(18,10))
463
+ freqs=np.arange(len(init_fps))
464
+ axes[0,0].bar(freqs,init_fps,color='gray',alpha=0.5);axes[0,0].set_title('Init');axes[0,0].set_ylabel('Power')
465
+ axes[0,1].bar(freqs,p1_fps,color='black',alpha=0.7);axes[0,1].set_title('After Phase 1 (Addition)')
466
+ for i,task_name in enumerate(['subtract','multiply','max','xor']):
467
+ ax=axes[(i+2)//3,(i+2)%3]
468
+ fps=np.array(branch_results[task_name]['fourier_spectrum_final'])
469
+ ax.bar(freqs[:len(fps)],fps,color=colors[task_name],alpha=0.7)
470
+ ax.set_title(f'After A→{task_name.title()}')
471
+ if (i+2)//3==1:ax.set_xlabel('Frequency')
472
+ if (i+2)%3==0:ax.set_ylabel('Power')
473
+ plt.suptitle('Embedding Fourier Power Spectrum Evolution',fontsize=14)
474
+ plt.tight_layout();plt.savefig('results/fourier_spectra.png',dpi=150);plt.close()
475
+
476
+ # Plot 6: Subspace angles
477
+ fig,ax=plt.subplots(figsize=(12,6))
478
+ for task_name in ALL_OPS:
479
+ m_name='layer_2/subspace_angle_vs_phase1'
480
+ br=branch_results[task_name]
481
+ steps=[h['step'] for h in br['history'] if m_name in h]
482
+ vals=[h[m_name] for h in br['history'] if m_name in h]
483
+ ax.plot(steps,vals,'-',color=colors[task_name],lw=2,label=labels[task_name])
484
+ ax.set_xlabel('Step');ax.set_ylabel('Subspace Angle (degrees)')
485
+ ax.set_title('Final Layer Subspace Angle Drift from Phase 1');ax.legend()
486
+ plt.tight_layout();plt.savefig('results/subspace_angles_all.png',dpi=150);plt.close()
487
+
488
+ print("All plots generated!")
489
+
490
+ # Upload to HF
491
+ from huggingface_hub import HfApi
492
+ api=HfApi(token=os.environ.get('HF_TOKEN'))
493
+ repo='tekkmaven/representation-learning-dynamics'
494
+ files_to_upload = [
495
+ 'graduated_experiment_results.json',
496
+ 'forgetting_ladder.png','addition_accuracy_all_branches.png',
497
+ 'cka_all_branches.png','gradient_alignment_all.png',
498
+ 'fourier_spectra.png','subspace_angles_all.png',
499
+ ]
500
+ for f in files_to_upload:
501
+ try:
502
+ api.upload_file(path_or_fileobj=f'results/{f}',path_in_repo=f'results/{f}',repo_id=repo,repo_type='model')
503
+ print(f"Uploaded {f}")
504
+ except Exception as e:
505
+ print(f"Failed to upload {f}: {e}")
506
+
507
+ print(f"\n=== EXPERIMENT COMPLETE ===")
508
+ print(f"Total time: {(time.time()-t0)/60:.1f} minutes")
509
+ print(f"Repository: https://huggingface.co/{repo}")