| |
"""
Graduated Dissimilarity Experiment
====================================
Train on modular addition, then fork into 5 branches with tasks of
increasing dissimilarity. Track representation metrics to find the
tipping point where forgetting begins.

Branches:
A→A: Continue addition (Level 0: identical)
A→B: Switch to subtraction (Level 1: same Fourier circuit)
A→C: Switch to multiplication (Level 2: different Fourier freqs)
A→D: Switch to max(a,b) (Level 3: linear/ordinal circuit)
A→E: Switch to XOR (Level 4: bit-level circuit)
"""
|
|
| import torch, torch.nn as nn, torch.optim as optim, torch.nn.functional as F |
| import numpy as np, json, os, math, copy, time |
| from dataclasses import dataclass |
| from torch.utils.data import Dataset, DataLoader |
| from typing import Dict, List, Tuple |
|
|
| |
|
|
| |
def centering(K):
    """Double-center a square matrix: subtract row and column means, add back the grand mean.

    Args:
        K: square 2-D tensor (n x n).

    Returns:
        Tensor with the same shape, dtype and device as ``K``.
    """
    size = K.shape[0]
    # Multiplying by this matrix averages along one axis.
    averager = torch.ones(size, size, device=K.device, dtype=K.dtype) / size
    column_means = averager @ K
    row_means = K @ averager
    grand_mean = column_means @ averager
    return K - column_means - row_means + grand_mean
def linear_HSIC(X, Y):
    """Linear-kernel HSIC estimate between two sets of row-aligned features.

    Args:
        X, Y: 2-D tensors with the same number of rows (one row per sample).

    Returns:
        0-d tensor: sum of the elementwise product of the two centered Gram
        matrices, divided by ``(n - 1) ** 2`` where ``n = X.shape[0]``.
    """
    gram_x = X @ X.T
    gram_y = Y @ Y.T
    overlap = (centering(gram_x) * centering(gram_y)).sum()
    normaliser = (X.shape[0] - 1) ** 2
    return overlap / normaliser
def linear_CKA(X, Y):
    """Linear CKA similarity between two row-aligned feature matrices.

    Returns:
        Python float ``HSIC(X, Y) / sqrt(HSIC(X, X) * HSIC(Y, Y))``, with the
        denominator clamped to at least 1e-10 to avoid division by zero.
    """
    cross = linear_HSIC(X, Y)
    self_x = linear_HSIC(X, X)
    self_y = linear_HSIC(Y, Y)
    denominator = (self_x.sqrt() * self_y.sqrt()).clamp(min=1e-10)
    return (cross / denominator).item()
def cka_heatmap(sa, sb):
    """Pairwise linear CKA between every entry of ``sa`` and every entry of ``sb``.

    Args:
        sa, sb: sequences of feature matrices accepted by ``linear_CKA``.

    Returns:
        NumPy array of shape ``(len(sa), len(sb))``; entry ``[i, j]`` is
        ``linear_CKA(sa[i], sb[j])``.
    """
    grid = np.zeros((len(sa), len(sb)))
    for row, feats_a in enumerate(sa):
        for col, feats_b in enumerate(sb):
            grid[row, col] = linear_CKA(feats_a, feats_b)
    return grid
def subspace_angles(X, Y, k=10):
    """Principal angles (radians) between the top-``k`` principal subspaces of X and Y.

    Each input is mean-centered over rows and decomposed with an SVD; the
    leading right singular vectors span its subspace.

    Args:
        X, Y: 2-D tensors (samples x features) with the same feature width.
        k: maximum number of directions kept per input.

    Returns:
        1-D tensor of angles, one per retained direction.
    """
    def top_directions(Z, count):
        # Rows of Vh are right singular vectors; keep at most `count` of them as columns.
        _, _, Vh = torch.linalg.svd(Z - Z.mean(0), full_matrices=False)
        keep = min(count, Vh.shape[0])
        return Vh[:keep].T

    basis_x = top_directions(X, k)
    basis_y = top_directions(Y, k)
    shared = min(basis_x.shape[1], basis_y.shape[1])
    overlap = basis_x[:, :shared].T @ basis_y[:, :shared]
    # Singular values of the overlap are the cosines; clamp guards arccos against round-off.
    cosines = torch.linalg.svdvals(overlap).clamp(-1, 1)
    return torch.arccos(cosines)
def mean_sa_deg(X, Y, k=10):
    """Mean principal angle between the top-``k`` subspaces of X and Y, in degrees (Python float)."""
    angles_rad = subspace_angles(X, Y, k)
    return (angles_rad.mean() * 180 / torch.pi).item()
def grad_align(model, ba, bb, lfn):
    """Cosine similarity between the loss gradients produced by two batches.

    Args:
        model: module whose parameters are differentiated.
        ba, bb: the two batches; each is handed unchanged to ``lfn``.
        lfn: callable ``lfn(model, batch)`` returning a scalar loss tensor.

    Returns:
        Python float: cosine similarity of the two flattened gradient vectors.

    Side effects:
        ``model.zero_grad()`` is called before each backward pass and once more
        at the end, so any previously accumulated gradients are discarded.
    """
    def flat_gradient(batch):
        model.zero_grad()
        lfn(model, batch).backward()
        # A parameter untouched by this loss has no gradient; use zeros so both
        # vectors always cover the same parameters in the same order. Dropping
        # such parameters would misalign (or mis-size) the two vectors.
        pieces = [
            (param.grad if param.grad is not None else torch.zeros_like(param)).flatten()
            for param in model.parameters()
        ]
        # torch.cat copies, so the result survives the next zero_grad().
        return torch.cat(pieces)

    grad_a = flat_gradient(ba)
    grad_b = flat_gradient(bb)
    model.zero_grad()
    return F.cosine_similarity(grad_a.unsqueeze(0), grad_b.unsqueeze(0)).item()
def attn_entropy(aw):
    """Shannon entropy (bits) of attention distributions along the last axis.

    Args:
        aw: attention weights; the last axis is the distribution. The
            ``per_head`` reduction averages over axes 0 and 2 of the entropy
            tensor, so a 4-D input is required
            (assumed batch x heads x query x key — TODO confirm against callers).

    Returns:
        dict with ``'mean'`` (float over all entries) and ``'per_head'``
        (list of floats, one per index of axis 1).
    """
    # The 1e-9 offset keeps log2 finite where a weight is exactly zero.
    log_weights = torch.log2(aw + 1e-9)
    entropy = -(aw * log_weights).sum(-1)
    summary = {
        'mean': entropy.mean().item(),
        'per_head': entropy.mean(dim=(0, 2)).cpu().tolist(),
    }
    return summary
def wc_mag(s0, s1):
    """Per-tensor L2 norm of the change between two state dicts.

    Args:
        s0, s1: mappings from name to tensor.

    Returns:
        dict mapping each name present in both inputs to
        ``||s1[name] - s0[name]||`` (computed in float32, as a Python float).
        Names missing from ``s1`` are skipped.
    """
    deltas = {}
    for name, before in s0.items():
        if name not in s1:
            continue
        difference = s1[name].float() - before.float()
        deltas[name] = difference.norm().item()
    return deltas
|
|
| |
def fourier_power_spectrum(W_E, p, num_special=None):
    """Fourier power of the numeric-token embeddings at each frequency.

    The rows of ``W_E`` holding the ``p`` numeric tokens are transformed with a
    real FFT along the token axis; power is the squared magnitude summed over
    embedding dimensions.

    Args:
        W_E: embedding matrix, ``[vocab_size, d_model]``.
        p: number of numeric tokens to analyse.
        num_special: number of leading non-numeric rows to skip. Defaults to
            this module's ``NS``, the same offset ``MAD`` adds to numeric
            values when building token ids. (Previously this was imported from
            an external ``tasks`` module, which could be missing or disagree
            with the ids the model was actually trained on.)

    Returns:
        1-D NumPy array of power at frequencies ``0..p//2`` (``rfft`` output
        length when ``W_E`` has at least ``num_special + p`` rows).
    """
    if num_special is None:
        num_special = NS
    numeric_rows = W_E[num_special:num_special + p, :]
    spectrum = torch.fft.rfft(numeric_rows, dim=0)
    power = (spectrum.abs() ** 2).sum(dim=1)
    return power.cpu().numpy()
|
|
| |
# Offset added to a numeric value to obtain its token id (ids below NS are
# used for the separator and operator tokens).
NS = 7
# Default modulus / size of the numeric range.
DP = 97

# Token id of each operation symbol.
OP_TOKENS = {
    'add': 2,
    'subtract': 3,
    'multiply': 4,
    'max': 5,
    'xor': 6,
}
# Operation names in dissimilarity-ladder order (the dict's insertion order).
ALL_OPS = list(OP_TOKENS)
|
|
class MAD(Dataset):
    """Dataset of binary operations over all ordered pairs ``(a, b)`` in ``[0, p)``.

    All ``p * p`` pairs are shuffled with a seeded ``RandomState``; the first
    ``tf`` fraction is the train split and the remainder is the other split.

    Each item is a dict with
      * ``input_ids``: ``[a + NS, op_token, b + NS, 1, ANSWER_SLOT]``
      * ``labels``:    ``[-100, -100, -100, -100, c + NS]``

    The answer ``c`` appears only in ``labels``. The model in this file scores
    ``logits[t]`` against ``labels[t]`` without shifting, so placing ``c`` in
    ``input_ids`` at the labelled position (as earlier revisions did) let the
    model read the answer from its own input. A reserved placeholder id is fed
    there instead; sequence length and dict keys are unchanged.

    Args:
        op: key of ``OP_TOKENS`` selecting the operation.
        p: modulus / size of the operand range.
        split: ``'train'`` for the first part of the shuffle, anything else
            for the remainder.
        tf: fraction of pairs assigned to the train split.
        seed: seed for the shuffle.
    """

    # Reserved id fed at the answer position. 0 is below NS and is not used by
    # the separator (1) or any operator token (2-6).
    ANSWER_SLOT = 0

    def __init__(self, op='add', p=DP, split='train', tf=0.5, seed=42):
        self.p = p
        self.op = op
        self.ot = OP_TOKENS[op]
        every_pair = [(a, b) for a in range(p) for b in range(p)]
        rng = np.random.RandomState(seed)
        rng.shuffle(every_pair)
        cut = int(len(every_pair) * tf)
        self.pairs = every_pair[:cut] if split == 'train' else every_pair[cut:]

    def _c(self, a, b):
        """Return the target value for operands ``a`` and ``b``."""
        if self.op == 'add':
            return (a + b) % self.p
        if self.op == 'subtract':
            return (a - b) % self.p
        if self.op == 'multiply':
            return (a * b) % self.p
        if self.op == 'max':
            return max(a, b)
        if self.op == 'xor':
            return (a ^ b) % self.p
        # Previously fell through and returned None, which surfaced later as an
        # opaque TypeError in __getitem__.
        raise ValueError(f"unsupported operation: {self.op!r}")

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, i):
        a, b = self.pairs[i]
        c = self._c(a, b)
        input_ids = torch.tensor(
            [a + NS, self.ot, b + NS, 1, self.ANSWER_SLOT], dtype=torch.long
        )
        labels = torch.tensor([-100, -100, -100, -100, c + NS], dtype=torch.long)
        return {'input_ids': input_ids, 'labels': labels}
|
|
def get_probe(ds, n=500):
    """Collect the first ``n`` items of ``ds`` as a fixed probe batch.

    Args:
        ds: indexable dataset whose items are dicts with ``'input_ids'`` and
            ``'labels'`` tensors.
        n: maximum number of items (capped at ``len(ds)``).

    Returns:
        Tuple ``(ids, answers)``: the stacked ``input_ids`` tensor and a NumPy
        array holding each item's last label minus ``NS``.
    """
    count = min(n, len(ds))
    samples = [ds[idx] for idx in range(count)]
    ids = torch.stack([sample['input_ids'] for sample in samples])
    answers = np.array([sample['labels'][-1].item() - NS for sample in samples])
    return ids, answers
|
|
def get_loaders(p=DP, bs=512, tf=0.5, seed=42):
    """Build train and test ``DataLoader``s for every operation in ``ALL_OPS``.

    Args:
        p: modulus passed to ``MAD``.
        bs: batch size.
        tf: train fraction passed to ``MAD``.
        seed: split seed passed to ``MAD``.

    Returns:
        dict keyed ``'{op}_train'`` / ``'{op}_test'``. Only train loaders
        shuffle; no loader drops the last partial batch.
    """
    return {
        f'{op}_{sp}': DataLoader(
            MAD(op, p, sp, tf, seed),
            batch_size=bs,
            shuffle=(sp == 'train'),
            drop_last=False,
        )
        for op in ALL_OPS
        for sp in ('train', 'test')
    }
|
|
| |
@dataclass
class TC:
    """Hyper-parameters consumed by the transformer modules in this file."""

    vs: int = 104      # vocabulary size (token-embedding rows / output logits)
    nl: int = 2        # number of transformer blocks
    dm: int = 128      # model (residual stream) width
    nh: int = 4        # attention heads; head width is dm // nh
    dmlp: int = 512    # hidden width of the MLP
    msl: int = 5       # number of positional-embedding rows
    do: float = 0.0    # not read by any module visible here (presumably dropout — TODO confirm)
    ln: bool = True    # use LayerNorm when True, nn.Identity otherwise
|
|
class MHA(nn.Module):
    """Causal multi-head self-attention with bias-free projections."""

    def __init__(self, c):
        super().__init__()
        self.nh = c.nh
        self.dh = c.dm // c.nh
        self.WQ = nn.Linear(c.dm, c.dm, bias=False)
        self.WK = nn.Linear(c.dm, c.dm, bias=False)
        self.WV = nn.Linear(c.dm, c.dm, bias=False)
        self.WO = nn.Linear(c.dm, c.dm, bias=False)

    def forward(self, x, ra=False):
        """Apply attention to ``x`` of shape (batch, seq, dm).

        Returns:
            ``(output, weights)``: ``output`` has the shape of ``x``;
            ``weights`` is the (batch, heads, seq, seq) softmax matrix when
            ``ra`` is true, else ``None``.
        """
        B, T, D = x.shape

        def split_heads(projection):
            # (B, T, dm) -> (B, heads, T, head_dim)
            return projection(x).view(B, T, self.nh, self.dh).transpose(1, 2)

        Q = split_heads(self.WQ)
        K = split_heads(self.WK)
        V = split_heads(self.WV)
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.dh)
        # Strictly-upper-triangular mask: a position may attend to itself and the past.
        future = torch.triu(torch.ones(T, T, device=x.device), 1).bool()
        scores.masked_fill_(future.unsqueeze(0).unsqueeze(0), float('-inf'))
        weights = F.softmax(scores, dim=-1)
        merged = (weights @ V).transpose(1, 2).reshape(B, T, D)
        return self.WO(merged), (weights if ra else None)
|
|
class MLP2(nn.Module):
    """Two-layer GELU feed-forward network that also exposes its hidden activations."""

    def __init__(self, c):
        super().__init__()
        self.wi = nn.Linear(c.dm, c.dmlp)
        self.wo = nn.Linear(c.dmlp, c.dm)

    def forward(self, x):
        """Return ``(output, hidden)`` where ``hidden`` is the post-GELU activation."""
        hidden = F.gelu(self.wi(x))
        output = self.wo(hidden)
        return output, hidden
|
|
class TB(nn.Module):
    """Pre-norm transformer block: attention then MLP, each added to the residual stream."""

    def __init__(self, c):
        super().__init__()
        self.attn = MHA(c)
        self.mlp = MLP2(c)
        self.ln1 = nn.LayerNorm(c.dm) if c.ln else nn.Identity()
        self.ln2 = nn.LayerNorm(c.dm) if c.ln else nn.Identity()

    def forward(self, x, ri=False):
        """Run the block.

        Returns:
            dict with ``'hs'`` (the updated residual stream) and, when ``ri``
            is true, ``'aw'`` (attention weights) and ``'mh'`` (MLP hidden
            activations).
        """
        attn_out, attn_weights = self.attn(self.ln1(x), ra=ri)
        x = x + attn_out
        mlp_out, mlp_hidden = self.mlp(self.ln2(x))
        x = x + mlp_out

        result = {'hs': x}
        if ri:
            result['aw'] = attn_weights
            result['mh'] = mlp_hidden
        return result
|
|
class ST(nn.Module):
    """Small decoder-style transformer with learned positions and tied input/output embeddings."""

    def __init__(self, c):
        super().__init__()
        self.c = c
        self.te = nn.Embedding(c.vs, c.dm)
        self.pe = nn.Embedding(c.msl, c.dm)
        self.blocks = nn.ModuleList([TB(c) for _ in range(c.nl)])
        self.lnf = nn.LayerNorm(c.dm) if c.ln else nn.Identity()
        self.head = nn.Linear(c.dm, c.vs, bias=False)
        # Weight tying: the output projection shares the token-embedding matrix.
        self.head.weight = self.te.weight
        self.apply(self._iw)

    def _iw(self, m):
        """Initialise Linear and Embedding weights from N(0, 0.02^2)."""
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, std=0.02)

    def forward(self, ids, labels=None, ri=False):
        """Run the model on token ids of shape (batch, seq).

        Args:
            ids: integer token ids.
            labels: optional targets with the same shape as ``ids``. Logits at
                position ``t`` are scored against ``labels[:, t]`` (no shift);
                entries equal to -100 are ignored.
            ri: when true, also return intermediate activations.

        Returns:
            dict with ``'logits'``; ``'loss'`` if ``labels`` was given; and,
            if ``ri``, ``'hs'`` (detached residual stream after the embedding
            and after each block), ``'aw'`` and ``'mh'`` (detached per-block
            attention weights and MLP activations).
        """
        _, T = ids.shape
        positions = torch.arange(T, device=ids.device)
        x = self.te(ids) + self.pe(positions)

        hidden_states = [x.detach()]
        attn_weights = []
        mlp_hidden = []
        for block in self.blocks:
            out = block(x, ri=ri)
            x = out['hs']
            hidden_states.append(x.detach())
            if ri:
                attn_weights.append(out['aw'].detach())
                mlp_hidden.append(out['mh'].detach())

        logits = self.head(self.lnf(x))
        res = {'logits': logits}
        if labels is not None:
            res['loss'] = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                labels.view(-1),
                ignore_index=-100,
            )
        if ri:
            res['hs'] = hidden_states
            res['aw'] = attn_weights
            res['mh'] = mlp_hidden
        return res

    def nparams(self):
        """Total number of parameter elements (tied weights are counted once)."""
        return sum(param.numel() for param in self.parameters())
|
|
| |
def ev(model, dl, dev):
    """Evaluate ``model`` over a loader and return sample-weighted loss and accuracy.

    Args:
        model: callable as ``model(ids, labels=...)`` returning a dict with
            ``'loss'`` and ``'logits'``.
        dl: iterable of batches with ``'input_ids'`` and ``'labels'`` tensors.
        dev: device the batches are moved to.

    Returns:
        ``{'loss': ..., 'acc': ...}``. Accuracy compares the argmax of the
        logits at the last position with the label at the last position.

    The model is left in eval mode.
    """
    model.eval()
    loss_sum = 0
    correct = 0
    seen = 0
    with torch.no_grad():
        for batch in dl:
            ids = batch['input_ids'].to(dev)
            labs = batch['labels'].to(dev)
            out = model(ids, labels=labs)
            batch_size = ids.shape[0]
            # Weight each batch's mean loss by its size so partial batches count fairly.
            loss_sum += out['loss'].item() * batch_size
            predictions = out['logits'][:, -1, :].argmax(-1)
            correct += (predictions == labs[:, -1]).sum().item()
            seen += batch_size
    return {'loss': loss_sum / seen, 'acc': correct / seen}
|
|
def creps(model, pids, dev, pos=-1):
    """Capture probe representations from ``model`` without tracking gradients.

    Args:
        model: callable as ``model(ids, ri=True)`` returning a dict with
            ``'hs'`` and ``'aw'`` lists of tensors.
        pids: probe token ids; moved to ``dev`` before the forward pass.
        dev: device to run on.
        pos: sequence position whose hidden state is kept (default: last).

    Returns:
        dict with ``'hs'`` (per layer, the hidden state at ``pos``, on CPU)
        and ``'aw'`` (per layer, the full attention weights, on CPU).

    The model is left in eval mode.
    """
    model.eval()
    with torch.no_grad():
        out = model(pids.to(dev), ri=True)
    hidden_at_pos = [layer_state[:, pos, :].cpu() for layer_state in out['hs']]
    attention = [weights.cpu() for weights in out['aw']]
    return {'hs': hidden_at_pos, 'aw': attention}
|
|
def cmetrics(model, s0, s1, rc, ri, rp, dev, cfg):
    """Compute representation-drift metrics for the current model state.

    Args:
        model: model whose current ``state_dict()`` is compared with ``s1``.
        s0: unused here; kept for signature compatibility.
        s1: reference state dict (name -> tensor) for weight-change norms.
        rc: current probe representations (dict with ``'hs'`` and ``'aw'``).
        ri: probe representations compared under the ``cka_vs_init`` keys.
        rp: probe representations compared under the ``*_vs_phase1`` keys.
        dev: unused here; kept for signature compatibility.
        cfg: config providing ``nl`` (number of blocks).

    Returns:
        Flat dict of metrics:
          * ``layer_{i}/cka_vs_init``, ``layer_{i}/cka_vs_phase1`` and
            ``layer_{i}/subspace_angle_vs_phase1`` for ``i`` in ``0..nl``;
          * ``layer_{i+1}/attn_entropy_mean`` and
            ``layer_{i+1}/head_{h}_entropy`` for each attention map in ``rc``;
          * ``block_{b}/weight_change_from_phase1`` for each block.
    """
    m = {}
    for li in range(cfg.nl + 1):
        prefix = f'layer_{li}'
        current = rc['hs'][li]
        m[f'{prefix}/cka_vs_init'] = linear_CKA(current, ri['hs'][li])
        m[f'{prefix}/cka_vs_phase1'] = linear_CKA(current, rp['hs'][li])
        # Never ask for more directions than the data can support.
        k = min(10, current.shape[0] // 2, current.shape[1])
        m[f'{prefix}/subspace_angle_vs_phase1'] = (
            mean_sa_deg(current, rp['hs'][li], k=k) if k > 0 else 0.
        )

    for li, aw in enumerate(rc['aw']):
        entropy = attn_entropy(aw)
        m[f'layer_{li + 1}/attn_entropy_mean'] = entropy['mean']
        for head, head_entropy in enumerate(entropy['per_head']):
            m[f'layer_{li + 1}/head_{head}_entropy'] = head_entropy

    current_state = {name: tensor.cpu() for name, tensor in model.state_dict().items()}
    change = wc_mag(s1, current_state)
    for bi in range(cfg.nl):
        # Match "blocks.<bi>" as a whole dotted path component. A plain
        # substring test would let block 1 also absorb blocks 10, 11, ...
        component = f'.blocks.{bi}.'
        m[f'block_{bi}/weight_change_from_phase1'] = sum(
            value for name, value in change.items() if component in f'.{name}'
        )
    return m
|
|
def train_phase(model, opt, dl, ne, dev, pn, s0, s1, ri, rp, pids, eval_loaders, cfg, ce=20):
    """Train ``model`` for ``ne`` epochs, recording metrics every ``ce`` optimizer steps.

    Args:
        model: model to train; called as ``model(ids, labels=...)``.
        opt: optimizer over the model's parameters.
        dl: training loader yielding dicts with ``'input_ids'`` and ``'labels'``.
        ne: number of epochs.
        dev: device batches are moved to.
        pn: phase name used in log lines and records; its last ``_``-separated
            component selects which ``eval/<task>_test_acc`` is printed.
        s0, s1: state-dict snapshots forwarded to ``cmetrics``.
        ri, rp: reference probe representations forwarded to ``cmetrics``.
        pids: probe token ids forwarded to ``creps``.
        eval_loaders: dict name -> loader. Every entry whose name contains
            ``'_test'`` is evaluated; ``'<task>_test'`` entries also supply
            batches for the gradient-alignment metrics.
        cfg: model config (``nl`` and ``vs`` are used).
        ce: number of optimizer steps between metric checkpoints.

    Returns:
        List of metric dicts, one per checkpoint.
    """
    def lfn(m, b):
        # Loss closure handed to grad_align; depends only on `dev`.
        return m(b['input_ids'].to(dev), labels=b['labels'].to(dev))['loss']

    hist = []
    gs = 0
    for ep in range(ne):
        model.train()
        el = 0
        nb = 0
        for b in dl:
            ids = b['input_ids'].to(dev)
            labs = b['labels'].to(dev)
            o = model(ids, labels=labs)
            o['loss'].backward()
            opt.step()
            opt.zero_grad()
            el += o['loss'].item()
            nb += 1
            gs += 1
            if gs % ce != 0:
                continue

            # ---- periodic checkpoint: representation and evaluation metrics ----
            model.eval()
            rc = creps(model, pids, dev)
            sm = cmetrics(model, s0, s1, rc, ri, rp, dev, cfg)

            for n, ld in eval_loaders.items():
                if '_test' in n:
                    e = ev(model, ld, dev)
                    sm[f'eval/{n}_loss'] = e['loss']
                    sm[f'eval/{n}_acc'] = e['acc']

            # Gradient alignment between addition and every other task.
            add_batch = next(iter(eval_loaders['add_test']))
            for task_name in ['subtract', 'multiply', 'max', 'xor']:
                key = f'grad_align_add_vs_{task_name}'
                try:
                    tb = next(iter(eval_loaders[f'{task_name}_test']))
                    sm[key] = grad_align(model, add_batch, tb, lfn)
                except Exception as exc:
                    # Best-effort metric: keep training, keep the historical
                    # 0.0 fallback, but report the failure. A bare `except:`
                    # here also swallowed KeyboardInterrupt/SystemExit.
                    print(f"[{pn}] warning: {key} failed ({type(exc).__name__}: {exc}); recording 0.0")
                    sm[key] = 0.

            # Fourier concentration of the numeric-token embeddings.
            emb_w = model.te.weight.detach().cpu()
            fps = fourier_power_spectrum(emb_w, cfg.vs - NS)
            top5 = np.argsort(fps)[::-1][:5]
            sm['fourier_top5_freqs'] = top5.tolist()
            sm['fourier_top5_power'] = fps[top5].tolist()
            sm['fourier_total_power'] = float(fps.sum())
            sm['fourier_concentration'] = float(fps[top5].sum() / fps.sum())

            sm['phase'] = pn
            sm['epoch'] = ep
            sm['step'] = gs
            sm['train_loss'] = el / nb
            hist.append(sm)
            add_acc = sm.get('eval/add_test_acc', 0)
            task_acc = sm.get(f'eval/{pn.split("_")[-1]}_test_acc', add_acc)
            print(f"[{pn}] S{gs} L:{el/nb:.4f} AddAcc:{add_acc:.3f} TaskAcc:{task_acc:.3f} "
                  f"CKA(L1vP1):{sm.get('layer_1/cka_vs_phase1',0):.3f} "
                  f"FourierConc:{sm.get('fourier_concentration',0):.3f}")
            model.train()

        # An empty loader yields no batches; report nan instead of dividing by zero.
        epoch_loss = el / nb if nb else float('nan')
        print(f"[{pn}] Ep{ep+1}/{ne} L:{epoch_loss:.4f}")
    return hist
|
|
| |
# ---------------------------------------------------------------------------
# Setup and Phase 1: train the base model on addition
# ---------------------------------------------------------------------------
print("="*70)
print("GRADUATED DISSIMILARITY EXPERIMENT")
print("="*70)
t0 = time.time()
dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {dev}")

# Seed before the model is built so initialisation is reproducible.
p = 97
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
cfg = TC(vs=p+NS, nl=2, dm=128, nh=4, dmlp=512, msl=5)
model = ST(cfg).to(dev)
print(f"Parameters: {model.nparams():,}")

# Snapshot of the untrained weights (CPU copies).
s0 = {name: tensor.cpu().clone() for name, tensor in model.state_dict().items()}
lds = get_loaders(p=p, bs=512, tf=0.5, seed=seed)

# Fixed probe batch drawn from the addition test split, and the
# representations of the untrained model on it.
dsa = MAD('add', p=p, split='test', tf=0.5, seed=seed)
pids_a, pla = get_probe(dsa, 500)
ri = creps(model, pids_a, dev)

# Embedding spectrum at initialisation.
init_fps = fourier_power_spectrum(model.te.weight.detach().cpu(), p)

print(f"\n{'='*60}\nPHASE 1: Train on Addition ({150} epochs)\n{'='*60}")
opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1.0)
# During phase 1 the "phase-1 reference" does not exist yet, so the initial
# snapshot and representations are passed for both reference slots.
h1 = train_phase(
    model, opt, lds['add_train'], 150, dev, 'p1_add',
    s0, s0, ri, ri, pids_a, lds, cfg, ce=20,
)

# End-of-phase-1 references used by every phase-2 branch.
s1 = {name: tensor.cpu().clone() for name, tensor in model.state_dict().items()}
rp = creps(model, pids_a, dev)
p1_eval = {op: ev(model, lds[f'{op}_test'], dev) for op in ALL_OPS}
print(f"\nPhase 1 final accuracies:")
for op, e in p1_eval.items():
    print(f" {op}: {e['acc']:.3f}")

# Persist the phase-1 checkpoint that each branch is forked from.
os.makedirs('results', exist_ok=True)
torch.save(model.state_dict(), 'results/p1.pt')
p1_fps = fourier_power_spectrum(model.te.weight.detach().cpu(), p)
|
|
| |
# ---------------------------------------------------------------------------
# Phase 2: fork the phase-1 checkpoint into one branch per task
# ---------------------------------------------------------------------------
branch_results = {}
phase2_epochs = 150

for level, task_name in enumerate(ALL_OPS):
    branch_name = f'p2_{task_name}'
    print(f"\n{'='*60}")
    print(f"PHASE 2 — Branch A→{task_name.upper()} (Level {level}, {phase2_epochs} epochs)")
    print(f"{'='*60}")

    # Every branch starts from the same saved weights with a fresh optimizer.
    m = ST(cfg).to(dev)
    m.load_state_dict(torch.load('results/p1.pt', weights_only=True))
    o = optim.AdamW(m.parameters(), lr=1e-3, weight_decay=1.0)
    h = train_phase(
        m, o, lds[f'{task_name}_train'], phase2_epochs, dev,
        branch_name, s0, s1, ri, rp, pids_a, lds, cfg, ce=20,
    )

    # Final evaluation of this branch on every task.
    final_eval = {op: ev(m, lds[f'{op}_test'], dev) for op in ALL_OPS}
    fps_final = fourier_power_spectrum(m.te.weight.detach().cpu(), p)

    branch_results[task_name] = {
        'history': h,
        'final_eval': final_eval,
        'fourier_spectrum_final': fps_final.tolist(),
    }

    print(f"\n Final accuracies after A→{task_name}:")
    for op, e in final_eval.items():
        if op == task_name:
            marker = " ←TRAINED"
        elif op == 'add':
            marker = " ←BASE"
        else:
            marker = ""
        print(f" {op}: {e['acc']:.3f}{marker}")

    # Forgetting = drop in addition accuracy relative to the end of phase 1.
    forgetting = p1_eval['add']['acc'] - final_eval['add']['acc']
    print(f" Addition FORGETTING: {forgetting*100:.1f}%")
|
|
| |
# ---------------------------------------------------------------------------
# Phase 3 placeholder, summary assembly and JSON export
# ---------------------------------------------------------------------------
print(f"\n{'='*60}\nPHASE 3: Cross-model Representation Comparison\n{'='*60}")

# Placeholder: no cross-branch representations are collected yet.
branch_reps = {}
for task_name in ALL_OPS:
    pass

summary = {
    'config': {
        'p': p, 'n_layers': 2, 'd_model': 128, 'n_heads': 4, 'd_mlp': 512,
        'phase1_epochs': 150, 'phase2_epochs': phase2_epochs,
        'lr': 1e-3, 'weight_decay': 1.0, 'batch_size': 512,
        'train_frac': 0.5, 'seed': seed, 'n_parameters': model.nparams(),
    },
    'dissimilarity_ladder': {
        'add': {'level': 0, 'description': 'Identical (Fourier circuit)'},
        'subtract': {'level': 1, 'description': 'Same Fourier circuit (sign flip)'},
        'multiply': {'level': 2, 'description': 'Discrete-log Fourier circuit'},
        'max': {'level': 3, 'description': 'Linear/ordinal circuit'},
        'xor': {'level': 4, 'description': 'Bit-level circuit'},
    },
    'phase1_history': h1,
    'phase1_final_eval': {op: p1_eval[op] for op in ALL_OPS},
    'fourier_spectrum_init': init_fps.tolist(),
    'fourier_spectrum_phase1': p1_fps.tolist(),
}

# One summary entry per phase-2 branch.
for task_name, br in branch_results.items():
    level = ALL_OPS.index(task_name)
    summary[f'branch_{task_name}'] = {
        'level': level,
        'history': br['history'],
        'final_eval': br['final_eval'],
        'fourier_spectrum': br['fourier_spectrum_final'],
        'addition_forgetting': p1_eval['add']['acc'] - br['final_eval']['add']['acc'],
    }

# Console table of the headline numbers.
print("\n" + "="*60)
print("FORGETTING SUMMARY")
print("="*60)
print(f"{'Task':<15} {'Level':<8} {'Add Acc':<10} {'Task Acc':<10} {'Forgetting':<12} {'Grad Align':<12}")
print("-"*67)
for task_name in ALL_OPS:
    br = branch_results[task_name]
    level = ALL_OPS.index(task_name)
    add_acc = br['final_eval']['add']['acc']
    task_acc = br['final_eval'][task_name]['acc']
    forgetting = p1_eval['add']['acc'] - add_acc

    # Most recent gradient-alignment value recorded for this branch. There is
    # no add-vs-add metric, so the addition branch reports add-vs-subtract.
    own_key = f'grad_align_add_vs_{task_name}'
    last_ga = 0
    for h in reversed(br['history']):
        if own_key in h:
            last_ga = h[own_key]
            break
        if task_name == 'add' and 'grad_align_add_vs_subtract' in h:
            last_ga = h['grad_align_add_vs_subtract']
            break
    print(f"{task_name:<15} {level:<8} {add_acc:<10.3f} {task_acc:<10.3f} {forgetting*100:<12.1f}% {last_ga:<12.3f}")

summary['forgetting_table'] = {}
for task_name in ALL_OPS:
    br = branch_results[task_name]
    summary['forgetting_table'][task_name] = {
        'level': ALL_OPS.index(task_name),
        'addition_accuracy': br['final_eval']['add']['acc'],
        'task_accuracy': br['final_eval'][task_name]['acc'],
        'forgetting_pct': (p1_eval['add']['acc'] - br['final_eval']['add']['acc']) * 100,
    }

# default=str stringifies anything json cannot serialise natively.
with open('results/graduated_experiment_results.json', 'w') as f:
    json.dump(summary, f, indent=2, default=str)

elapsed = time.time() - t0
print(f"\nExperiment completed in {elapsed/60:.1f} minutes")
print("Results saved to results/graduated_experiment_results.json")
|
|
| |
# ---------------------------------------------------------------------------
# Plots (written to results/ with the non-interactive Agg backend)
# ---------------------------------------------------------------------------
print("\nGenerating plots...")
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

colors = {
    'add': '#2196F3',
    'subtract': '#4CAF50',
    'multiply': '#FF9800',
    'max': '#F44336',
    'xor': '#9C27B0',
}
labels = {
    'add': 'Addition (L0)',
    'subtract': 'Subtraction (L1)',
    'multiply': 'Multiplication (L2)',
    'max': 'Max (L3)',
    'xor': 'XOR (L4)',
}

# --- 1. Forgetting per dissimilarity level (bar chart) ---
fig, ax = plt.subplots(figsize=(10, 6))
tasks = list(ALL_OPS)
forg = [summary['forgetting_table'][t]['forgetting_pct'] for t in tasks]
task_acc = [summary['forgetting_table'][t]['task_accuracy'] for t in tasks]
x = np.arange(len(tasks))
bars = ax.bar(x, forg, color=[colors[t] for t in tasks], edgecolor='black', linewidth=0.5)
ax.set_xticks(x)
ax.set_xticklabels([labels[t] for t in tasks], rotation=15, ha='right')
ax.set_ylabel('Addition Accuracy Forgetting (%)')
ax.set_title('Forgetting vs Task Dissimilarity Level')
ax.axhline(0, color='gray', ls=':', alpha=0.5)
# Annotate each bar with the accuracy reached on the branch's own task.
for bar, ta in zip(bars, task_acc):
    ax.text(
        bar.get_x() + bar.get_width() / 2,
        bar.get_height() + 0.5,
        f'Task:{ta:.1%}',
        ha='center', va='bottom', fontsize=8,
    )
plt.tight_layout()
plt.savefig('results/forgetting_ladder.png', dpi=150)
plt.close()

# --- 2. Addition accuracy through phase 1 and every branch ---
fig, ax = plt.subplots(figsize=(12, 6))
s1s = [h['step'] for h in h1]
ax.plot(s1s, [h.get('eval/add_test_acc', 0) for h in h1], 'k-', lw=2, label='Phase 1 (training)')
for task_name in ALL_OPS:
    br = branch_results[task_name]
    # Branch steps restart at zero; offset them by the last phase-1 step.
    steps = [h['step'] + s1s[-1] for h in br['history']]
    acc = [h.get('eval/add_test_acc', 0) for h in br['history']]
    ax.plot(steps, acc, '-', color=colors[task_name], lw=2, label=labels[task_name])
ax.axvline(s1s[-1], color='gray', ls='--', alpha=0.5, label='Phase transition')
ax.set_xlabel('Training Step')
ax.set_ylabel('Addition Test Accuracy')
ax.set_title('Addition Accuracy: Learning vs Forgetting')
ax.legend()
ax.set_ylim(-0.05, 1.05)
plt.tight_layout()
plt.savefig('results/addition_accuracy_all_branches.png', dpi=150)
plt.close()

# --- 3. CKA against the end of phase 1, one panel per layer ---
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
for li, ax in enumerate(axes):
    m_name = f'layer_{li}/cka_vs_phase1'
    for task_name in ALL_OPS:
        br = branch_results[task_name]
        steps = [h['step'] for h in br['history'] if m_name in h]
        vals = [h[m_name] for h in br['history'] if m_name in h]
        ax.plot(steps, vals, '-', color=colors[task_name], lw=1.5, label=labels[task_name])
    ax.set_title(f'{"Embedding" if li==0 else f"Layer {li}"}: CKA vs Phase 1 End')
    ax.set_xlabel('Step')
    ax.set_ylabel('CKA')
    ax.legend(fontsize=7)
    ax.set_ylim(0, 1.05)
plt.tight_layout()
plt.savefig('results/cka_all_branches.png', dpi=150)
plt.close()

# --- 4. Gradient alignment with addition ---
fig, ax = plt.subplots(figsize=(12, 6))
for task_name in ['subtract', 'multiply', 'max', 'xor']:
    m_name = f'grad_align_add_vs_{task_name}'
    # Phase-1 trace (dotted, faint).
    steps_p1 = [h['step'] for h in h1 if m_name in h]
    vals_p1 = [h[m_name] for h in h1 if m_name in h]
    ax.plot(steps_p1, vals_p1, ':', color=colors[task_name], lw=1, alpha=0.5)
    # Phase-2 trace for the matching branch (solid).
    br = branch_results[task_name]
    steps = [h['step'] + s1s[-1] for h in br['history'] if m_name in h]
    vals = [h[m_name] for h in br['history'] if m_name in h]
    ax.plot(steps, vals, '-', color=colors[task_name], lw=2, label=labels[task_name])
ax.axvline(s1s[-1], color='gray', ls='--', alpha=0.5)
ax.axhline(0, color='gray', ls=':', alpha=0.3)
ax.set_xlabel('Step')
ax.set_ylabel('Gradient Cosine Similarity')
ax.set_title('Gradient Alignment with Addition Task')
ax.legend()
plt.tight_layout()
plt.savefig('results/gradient_alignment_all.png', dpi=150)
plt.close()

# --- 5. Embedding Fourier spectra: init, phase 1, then four branches ---
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
freqs = np.arange(len(init_fps))
axes[0, 0].bar(freqs, init_fps, color='gray', alpha=0.5)
axes[0, 0].set_title('Init')
axes[0, 0].set_ylabel('Power')
axes[0, 1].bar(freqs, p1_fps, color='black', alpha=0.7)
axes[0, 1].set_title('After Phase 1 (Addition)')
for i, task_name in enumerate(['subtract', 'multiply', 'max', 'xor']):
    # The first two cells of the 2x3 grid are taken; continue from the third.
    ax = axes[(i+2)//3, (i+2)%3]
    fps = np.array(branch_results[task_name]['fourier_spectrum_final'])
    ax.bar(freqs[:len(fps)], fps, color=colors[task_name], alpha=0.7)
    ax.set_title(f'After A→{task_name.title()}')
    if (i+2)//3 == 1:
        ax.set_xlabel('Frequency')
    if (i+2)%3 == 0:
        ax.set_ylabel('Power')
plt.suptitle('Embedding Fourier Power Spectrum Evolution', fontsize=14)
plt.tight_layout()
plt.savefig('results/fourier_spectra.png', dpi=150)
plt.close()

# --- 6. Final-layer subspace angle relative to phase 1 ---
fig, ax = plt.subplots(figsize=(12, 6))
for task_name in ALL_OPS:
    m_name = 'layer_2/subspace_angle_vs_phase1'
    br = branch_results[task_name]
    steps = [h['step'] for h in br['history'] if m_name in h]
    vals = [h[m_name] for h in br['history'] if m_name in h]
    ax.plot(steps, vals, '-', color=colors[task_name], lw=2, label=labels[task_name])
ax.set_xlabel('Step')
ax.set_ylabel('Subspace Angle (degrees)')
ax.set_title('Final Layer Subspace Angle Drift from Phase 1')
ax.legend()
plt.tight_layout()
plt.savefig('results/subspace_angles_all.png', dpi=150)
plt.close()

print("All plots generated!")
|
|
| |
# ---------------------------------------------------------------------------
# Best-effort upload of the result files to the Hugging Face Hub
# ---------------------------------------------------------------------------
# The upload is optional: results are already on disk, so a missing
# huggingface_hub package must not abort the run at its very last step.
try:
    from huggingface_hub import HfApi
except ImportError as exc:
    HfApi = None
    print(f"huggingface_hub is not available ({exc}); skipping upload")

repo = 'tekkmaven/representation-learning-dynamics'
files_to_upload = [
    'graduated_experiment_results.json',
    'forgetting_ladder.png', 'addition_accuracy_all_branches.png',
    'cka_all_branches.png', 'gradient_alignment_all.png',
    'fourier_spectra.png', 'subspace_angles_all.png',
]
if HfApi is not None:
    # The token is read from the HF_TOKEN environment variable (None if unset).
    api = HfApi(token=os.environ.get('HF_TOKEN'))
    for f in files_to_upload:
        # Each file is attempted independently so one failure does not block the rest.
        try:
            api.upload_file(
                path_or_fileobj=f'results/{f}',
                path_in_repo=f'results/{f}',
                repo_id=repo,
                repo_type='model',
            )
            print(f"Uploaded {f}")
        except Exception as e:
            print(f"Failed to upload {f}: {e}")

print(f"\n=== EXPERIMENT COMPLETE ===")
print(f"Total time: {(time.time()-t0)/60:.1f} minutes")
print(f"Repository: https://huggingface.co/{repo}")
|
|