#!/usr/bin/env python3
"""
Graduated Dissimilarity Experiment
====================================
Train on modular addition, then fork into 5 branches with tasks of
increasing dissimilarity. Track representation metrics to find the
tipping point where forgetting begins.

Branches:
  A→A: Continue addition          (Level 0: identical)
  A→B: Switch to subtraction      (Level 1: same Fourier circuit)
  A→C: Switch to multiplication   (Level 2: different Fourier freqs)
  A→D: Switch to max(a,b)         (Level 3: linear/ordinal circuit)
  A→E: Switch to XOR              (Level 4: bit-level circuit)
"""
import torch, torch.nn as nn, torch.optim as optim, torch.nn.functional as F
import numpy as np, json, os, math, copy, time
from dataclasses import dataclass
from torch.utils.data import Dataset, DataLoader
from typing import Dict, List, Tuple

# ============= INLINE DEPS (self-contained for job) =============

# --- Representation metrics ---


def centering(K):
    """Double-centre a square matrix: K - uK - Ku + uKu with u = ones(n, n) / n."""
    n = K.shape[0]
    u = torch.ones(n, n, device=K.device, dtype=K.dtype) / n
    return K - u @ K - K @ u + u @ K @ u


def linear_HSIC(X, Y):
    """Linear-kernel HSIC estimate between X and Y (rows are paired samples)."""
    K = X @ X.T
    L = Y @ Y.T
    return (centering(K) * centering(L)).sum() / ((X.shape[0] - 1) ** 2)


def linear_CKA(X, Y):
    """Return linear CKA between X and Y as a Python float.

    The denominator is clamped at 1e-10 so degenerate (zero-variance)
    inputs do not divide by zero.
    """
    xy = linear_HSIC(X, Y)
    xx = linear_HSIC(X, X)
    yy = linear_HSIC(Y, Y)
    return (xy / (xx.sqrt() * yy.sqrt()).clamp(min=1e-10)).item()


def cka_heatmap(sa, sb):
    """Return a len(sa) x len(sb) NumPy array of pairwise linear CKA values."""
    hm = np.zeros((len(sa), len(sb)))
    for i, xa in enumerate(sa):
        for j, xb in enumerate(sb):
            hm[i, j] = linear_CKA(xa, xb)
    return hm


def subspace_angles(X, Y, k=10):
    """Principal angles (radians) between the top-k right-singular subspaces.

    Each input is mean-centred over rows before the SVD; at most ``k``
    directions are kept per input and the smaller count is used for both.
    """
    def tb(Z, k):
        _, _, Vh = torch.linalg.svd(Z - Z.mean(0), full_matrices=False)
        return Vh[:min(k, Vh.shape[0])].T

    Qx = tb(X, k)
    Qy = tb(Y, k)
    mk = min(Qx.shape[1], Qy.shape[1])
    cosines = torch.linalg.svdvals(Qx[:, :mk].T @ Qy[:, :mk])
    # Clamp guards arccos against values marginally outside [-1, 1].
    return torch.arccos(cosines.clamp(-1, 1))


def mean_sa_deg(X, Y, k=10):
    """Mean principal angle between X and Y, in degrees."""
    return (subspace_angles(X, Y, k).mean() * 180 / torch.pi).item()


def _flat_grads(model):
    """Concatenate all parameter gradients into one vector.

    Parameters that received no gradient contribute zeros, so two calls on
    the same model always yield vectors of identical length.
    """
    return torch.cat([
        (p.grad if p.grad is not None else torch.zeros_like(p)).flatten()
        for p in model.parameters()
    ])


def grad_align(model, ba, bb, lfn):
    """Cosine similarity between the loss gradients of two batches.

    Args:
        model: module whose parameters are differentiated.
        ba, bb: the two batches, passed unchanged to ``lfn``.
        lfn: callable ``(model, batch) -> scalar loss tensor``.

    Gradients are zeroed before each backward pass and again before
    returning, so no gradient state leaks into the caller.
    """
    model.zero_grad()
    lfn(model, ba).backward()
    ga = _flat_grads(model)
    model.zero_grad()
    lfn(model, bb).backward()
    gb = _flat_grads(model)
    model.zero_grad()
    return F.cosine_similarity(ga.unsqueeze(0), gb.unsqueeze(0)).item()


def attn_entropy(aw):
    """Entropy (bits) of attention distributions along the last axis.

    Returns the overall mean and a per-head list obtained by averaging over
    axes 0 and 2 of the entropy tensor (batch and query position for the
    [B, heads, T, T] weights produced by ``MHA.forward``).
    """
    # The 1e-9 keeps log2 finite for masked (exactly zero) attention weights.
    H = -(aw * (aw + 1e-9).log2()).sum(-1)
    return {'mean': H.mean().item(), 'per_head': H.mean(dim=(0, 2)).cpu().tolist()}


def wc_mag(s0, s1):
    """Per-tensor L2 norm of (s1 - s0) for every key present in both state dicts."""
    return {n: (s1[n].float() - s0[n].float()).norm().item() for n in s0 if n in s1}


# Fourier power spectrum of embedding matrix
def fourier_power_spectrum(W_E, p):
    """Compute Fourier power at each frequency for token embeddings.

    W_E: [vocab_size, d_model]. Returns power at each freq 0..p//2
    as a NumPy array, summed over the embedding dimension.
    """
    # Extract number embeddings only (skip special tokens). NS is this file's
    # own special-token count; relying on an external `tasks` module would
    # break the self-contained job (or silently pick up a stale value).
    num_emb = W_E[NS:NS + p, :]  # [p, d_model]
    fft = torch.fft.rfft(num_emb, dim=0)  # [p//2+1, d_model]
    power = (fft.abs() ** 2).sum(dim=1)  # [p//2+1]
    return power.cpu().numpy()


# --- Tasks ---
NS = 7  # NUM_SPECIAL (updated)
DP = 97
OP_TOKENS = {'add': 2, 'subtract': 3, 'multiply': 4, 'max': 5, 'xor': 6}
ALL_OPS = ['add', 'subtract', 'multiply', 'max', 'xor']


class MAD(Dataset):
    """Dataset of ``a <op> b = c`` examples over all pairs in range(p) x range(p).

    All p*p pairs are shuffled with ``np.random.RandomState(seed)``; the first
    ``int(p*p*tf)`` form the 'train' split and the remainder the other split.
    Numbers are offset by NS so they do not collide with special tokens.
    """

    def __init__(self, op='add', p=DP, split='train', tf=0.5, seed=42):
        self.p = p
        self.op = op
        self.ot = OP_TOKENS[op]  # raises KeyError for an unknown op
        ap = [(a, b) for a in range(p) for b in range(p)]
        rng = np.random.RandomState(seed)
        rng.shuffle(ap)
        nt = int(len(ap) * tf)
        self.pairs = ap[:nt] if split == 'train' else ap[nt:]

    def _c(self, a, b):
        """Return the target value for the pair (a, b) under ``self.op``."""
        if self.op == 'add':
            return (a + b) % self.p
        elif self.op == 'subtract':
            return (a - b) % self.p
        elif self.op == 'multiply':
            return (a * b) % self.p
        elif self.op == 'max':
            return max(a, b)
        elif self.op == 'xor':
            # a ^ b can exceed p - 1, so it is folded back into range(p).
            return (a ^ b) % self.p
        raise ValueError(f"unknown op: {self.op!r}")

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, i):
        a, b = self.pairs[i]
        c = self._c(a, b)
        # Sequence layout: [a, op, b, 1, c]; token 1 is presumably the '='
        # separator (NOTE(review): confirm). Only the answer position is scored.
        return {'input_ids': torch.tensor([a + NS, self.ot, b + NS, 1, c + NS], dtype=torch.long),
                'labels': torch.tensor([-100, -100, -100, -100, c + NS], dtype=torch.long)}


def get_probe(ds, n=500):
    """Return (stacked input_ids, NumPy array of answers) for the first n items of ds."""
    n = min(n, len(ds))
    its = [ds[i] for i in range(n)]
    return (torch.stack([it['input_ids'] for it in its]),
            np.array([it['labels'][-1].item() - NS for it in its]))


def get_loaders(p=DP, bs=512, tf=0.5, seed=42):
    """Build ``{'<op>_<split>': DataLoader}`` for every op in ALL_OPS.

    Only the train loaders shuffle; no batch is dropped.
    """
    ld = {}
    for op in ALL_OPS:
        for sp in ['train', 'test']:
            ld[f'{op}_{sp}'] = DataLoader(MAD(op, p, sp, tf, seed), batch_size=bs,
                                          shuffle=(sp == 'train'), drop_last=False)
    return ld


# --- Model ---
@dataclass
class TC:
    """Transformer hyper-parameters."""
    vs: int = 104      # vocabulary size (the script passes p + NS)
    nl: int = 2        # number of transformer blocks
    dm: int = 128      # residual-stream width
    nh: int = 4        # attention heads
    dmlp: int = 512    # MLP hidden width
    msl: int = 5       # size of the positional-embedding table
    do: float = 0.0    # NOTE(review): presumably dropout; no module in this file reads it
    ln: bool = True    # use LayerNorm (otherwise nn.Identity)


class MHA(nn.Module):
    """Causal multi-head self-attention without biases."""

    def __init__(self, c):
        super().__init__()
        self.nh = c.nh
        self.dh = c.dm // c.nh
        # Construction order is kept as-is: it determines RNG consumption.
        self.WQ = nn.Linear(c.dm, c.dm, bias=False)
        self.WK = nn.Linear(c.dm, c.dm, bias=False)
        self.WV = nn.Linear(c.dm, c.dm, bias=False)
        self.WO = nn.Linear(c.dm, c.dm, bias=False)

    def forward(self, x, ra=False):
        """Return (output, attention weights or None); weights only when ``ra``."""
        B, T, D = x.shape
        Q = self.WQ(x).view(B, T, self.nh, self.dh).transpose(1, 2)
        K = self.WK(x).view(B, T, self.nh, self.dh).transpose(1, 2)
        V = self.WV(x).view(B, T, self.nh, self.dh).transpose(1, 2)
        s = (Q @ K.transpose(-2, -1)) / math.sqrt(self.dh)
        # Strict upper triangle = future positions; mask them out.
        future = torch.triu(torch.ones(T, T, device=x.device), 1).bool()
        s.masked_fill_(future.unsqueeze(0).unsqueeze(0), float('-inf'))
        a = F.softmax(s, dim=-1)
        out = self.WO((a @ V).transpose(1, 2).reshape(B, T, D))
        return out, (a if ra else None)


class MLP2(nn.Module):
    """Two-layer GELU MLP that also returns its hidden activations."""

    def __init__(self, c):
        super().__init__()
        self.wi = nn.Linear(c.dm, c.dmlp)
        self.wo = nn.Linear(c.dmlp, c.dm)

    def forward(self, x):
        h = F.gelu(self.wi(x))
        return self.wo(h), h


class TB(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual add."""

    def __init__(self, c):
        super().__init__()
        self.attn = MHA(c)
        self.mlp = MLP2(c)
        self.ln1 = nn.LayerNorm(c.dm) if c.ln else nn.Identity()
        self.ln2 = nn.LayerNorm(c.dm) if c.ln else nn.Identity()

    def forward(self, x, ri=False):
        """Return {'hs': output}; with ``ri`` also 'aw' (attention) and 'mh' (MLP hidden)."""
        ao, aw = self.attn(self.ln1(x), ra=ri)
        x = x + ao
        mo, mh = self.mlp(self.ln2(x))
        x = x + mo
        r = {'hs': x}
        if ri:
            r['aw'] = aw
            r['mh'] = mh
        return r


class ST(nn.Module):
    """Small decoder-only transformer with tied input/output embeddings."""

    def __init__(self, c):
        super().__init__()
        self.c = c
        self.te = nn.Embedding(c.vs, c.dm)
        self.pe = nn.Embedding(c.msl, c.dm)
        self.blocks = nn.ModuleList([TB(c) for _ in range(c.nl)])
        self.lnf = nn.LayerNorm(c.dm) if c.ln else nn.Identity()
        self.head = nn.Linear(c.dm, c.vs, bias=False)
        self.head.weight = self.te.weight  # weight tying
        self.apply(self._iw)

    def _iw(self, m):
        """Initialise Linear and Embedding weights from N(0, 0.02^2)."""
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, std=0.02)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, std=0.02)

    def forward(self, ids, labels=None, ri=False):
        """Run the model.

        Returns a dict with 'logits'; 'loss' (cross-entropy ignoring -100) when
        ``labels`` is given; and, when ``ri`` is set, detached lists 'hs'
        (embedding output plus one entry per block), 'aw' and 'mh'.
        """
        B, T = ids.shape
        x = self.te(ids) + self.pe(torch.arange(T, device=ids.device))
        ahs = [x.detach()]
        aaw = []
        amh = []
        for b in self.blocks:
            r = b(x, ri=ri)
            x = r['hs']
            ahs.append(x.detach())
            if ri:
                aaw.append(r['aw'].detach())
                amh.append(r['mh'].detach())
        lo = self.head(self.lnf(x))
        res = {'logits': lo}
        if labels is not None:
            res['loss'] = F.cross_entropy(lo.view(-1, lo.size(-1)), labels.view(-1),
                                          ignore_index=-100)
        if ri:
            res['hs'] = ahs
            res['aw'] = aaw
            res['mh'] = amh
        return res

    def nparams(self):
        """Total number of parameter elements (tied weights counted once)."""
        return sum(p.numel() for p in self.parameters())


# --- Experiment helpers ---
def ev(model, dl, dev):
    """Evaluate on a loader; return sample-weighted mean loss and last-position accuracy."""
    model.eval()
    tl = tc = tt = 0
    with torch.no_grad():
        for b in dl:
            ids = b['input_ids'].to(dev)
            labs = b['labels'].to(dev)
            o = model(ids, labels=labs)
            tl += o['loss'].item() * ids.shape[0]
            tc += (o['logits'][:, -1, :].argmax(-1) == labs[:, -1]).sum().item()
            tt += ids.shape[0]
    return {'loss': tl / tt, 'acc': tc / tt}


def creps(model, pids, dev, pos=-1):
    """Collect probe representations.

    Returns 'hs' (per-layer hidden state at sequence position ``pos``) and
    'aw' (full per-block attention weights), all moved to CPU.
    """
    model.eval()
    with torch.no_grad():
        o = model(pids.to(dev), ri=True)
    return {'hs': [h[:, pos, :].cpu() for h in o['hs']],
            'aw': [a.cpu() for a in o['aw']]}


def cmetrics(model, s0, s1, rc, ri, rp, dev, cfg):
    """Compute representation metrics for the current model state.

    Args:
        model: model whose current weights are compared against ``s1``.
        s0: initial state dict (accepted for interface parity; not used here).
        s1: reference state dict for the weight-change metric.
        rc, ri, rp: ``creps`` outputs for the current, initial and
            end-of-phase-1 model respectively.
        dev: unused; kept for interface compatibility.
        cfg: config providing ``nl``.

    Returns:
        Flat dict of metric name -> float.
    """
    m = {}
    for li in range(cfg.nl + 1):  # +1: index 0 is the embedding output
        p = f'layer_{li}'
        c = rc['hs'][li]
        ii = ri['hs'][li]
        pp = rp['hs'][li]
        m[f'{p}/cka_vs_init'] = linear_CKA(c, ii)
        m[f'{p}/cka_vs_phase1'] = linear_CKA(c, pp)
        k = min(10, c.shape[0] // 2, c.shape[1])
        m[f'{p}/subspace_angle_vs_phase1'] = mean_sa_deg(c, pp, k=k) if k > 0 else 0.
    for li, aw in enumerate(rc['aw']):
        e = attn_entropy(aw)
        m[f'layer_{li+1}/attn_entropy_mean'] = e['mean']
        for h, he in enumerate(e['per_head']):
            m[f'layer_{li+1}/head_{h}_entropy'] = he
    cs = {k: v.cpu() for k, v in model.state_dict().items()}
    wp = wc_mag(s1, cs)
    for bi in range(cfg.nl):
        # Prefix match including the trailing dot, so that block 1 does not
        # also collect the tensors of blocks 10, 11, ...
        prefix = f'blocks.{bi}.'
        m[f'block_{bi}/weight_change_from_phase1'] = sum(
            v for k, v in wp.items() if k.startswith(prefix))
    return m


def train_phase(model, opt, dl, ne, dev, pn, s0, s1, ri, rp, pids, eval_loaders, cfg, ce=20):
    """Train for ``ne`` epochs, recording metrics every ``ce`` optimizer steps.

    Args:
        model, opt: model and its optimizer.
        dl: training loader.
        ne: number of epochs.
        dev: device batches are moved to.
        pn: phase name; its last '_'-separated token selects the task whose
            accuracy is printed as TaskAcc.
        s0, s1, ri, rp: reference state dicts / representations forwarded to
            ``cmetrics``.
        pids: probe input ids.
        eval_loaders: dict of loaders; every key containing '_test' is
            evaluated, and '<task>_test' batches feed gradient alignment.
        cfg: model config.
        ce: checkpoint interval in steps.

    Returns:
        List of metric dicts, one per checkpoint.
    """
    def lfn(m, b):
        return m(b['input_ids'].to(dev), labels=b['labels'].to(dev))['loss']

    hist = []
    gs = 0
    for ep in range(ne):
        model.train()
        el = nb = 0
        for b in dl:
            ids = b['input_ids'].to(dev)
            labs = b['labels'].to(dev)
            o = model(ids, labels=labs)
            o['loss'].backward()
            opt.step()
            opt.zero_grad()
            el += o['loss'].item()
            nb += 1
            gs += 1
            if gs % ce == 0:
                model.eval()
                rc = creps(model, pids, dev)
                sm = cmetrics(model, s0, s1, rc, ri, rp, dev, cfg)
                # Evaluate on ALL tasks
                for n, ld in eval_loaders.items():
                    if '_test' in n:
                        e = ev(model, ld, dev)
                        sm[f'eval/{n}_loss'] = e['loss']
                        sm[f'eval/{n}_acc'] = e['acc']
                # Gradient alignment: addition vs each other task
                add_batch = next(iter(eval_loaders['add_test']))
                for task_name in ['subtract', 'multiply', 'max', 'xor']:
                    key = f'grad_align_add_vs_{task_name}'
                    try:
                        tb = next(iter(eval_loaders[f'{task_name}_test']))
                        sm[key] = grad_align(model, add_batch, tb, lfn)
                    except Exception as err:
                        # Best effort: record 0 but say so, and make sure no
                        # half-computed gradients reach the next opt.step().
                        print(f"[{pn}] warning: {key} failed: {err}")
                        model.zero_grad()
                        sm[key] = 0.
                # Fourier power spectrum of embeddings
                emb_w = model.te.weight.detach().cpu()
                fps = fourier_power_spectrum(emb_w, cfg.vs - NS)
                # Store top-5 peak frequencies
                top5 = np.argsort(fps)[::-1][:5]
                sm['fourier_top5_freqs'] = top5.tolist()
                sm['fourier_top5_power'] = fps[top5].tolist()
                sm['fourier_total_power'] = float(fps.sum())
                sm['fourier_concentration'] = float(fps[top5].sum() / fps.sum())  # how concentrated
                sm['phase'] = pn
                sm['epoch'] = ep
                sm['step'] = gs
                sm['train_loss'] = el / nb
                hist.append(sm)
                add_acc = sm.get('eval/add_test_acc', 0)
                task_key = pn.split("_")[-1]
                task_acc = sm.get(f'eval/{task_key}_test_acc', add_acc)
                print(f"[{pn}] S{gs} L:{el/nb:.4f} AddAcc:{add_acc:.3f} TaskAcc:{task_acc:.3f} "
                      f"CKA(L1vP1):{sm.get('layer_1/cka_vs_phase1',0):.3f} "
                      f"FourierConc:{sm.get('fourier_concentration',0):.3f}")
                model.train()
        print(f"[{pn}] Ep{ep+1}/{ne} L:{el/nb:.4f}")
    return hist


# ============= MAIN =============
# Guarded so that importing this file (e.g. to unit-test the helpers) does not
# launch the experiment; when executed as a script it runs exactly as before.
if __name__ == "__main__":
    print("=" * 70)
    print("GRADUATED DISSIMILARITY EXPERIMENT")
    print("=" * 70)
    t0 = time.time()

    dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {dev}")

    p = 97
    seed = 42
    torch.manual_seed(seed)
    np.random.seed(seed)

    cfg = TC(vs=p + NS, nl=2, dm=128, nh=4, dmlp=512, msl=5)
    model = ST(cfg).to(dev)
    print(f"Parameters: {model.nparams():,}")
    s0 = {k: v.cpu().clone() for k, v in model.state_dict().items()}

    lds = get_loaders(p=p, bs=512, tf=0.5, seed=seed)

    # Probe data from addition test set (same across all branches)
    dsa = MAD('add', p=p, split='test', tf=0.5, seed=seed)
    pids_a, pla = get_probe(dsa, 500)
    ri = creps(model, pids_a, dev)

    # Initial Fourier spectrum
    init_fps = fourier_power_spectrum(model.te.weight.detach().cpu(), p)

    # ========== PHASE 1: Train on Addition ==========
    print(f"\n{'='*60}\nPHASE 1: Train on Addition ({150} epochs)\n{'='*60}")
    opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1.0)
    # No phase-1 reference exists yet, so the initial state/reps stand in for it.
    h1 = train_phase(model, opt, lds['add_train'], 150, dev, 'p1_add',
                     s0, s0, ri, ri, pids_a, lds, cfg, ce=20)
    s1 = {k: v.cpu().clone() for k, v in model.state_dict().items()}
    rp = creps(model, pids_a, dev)
# Guarded so that importing this file does not start training; when the file
# is executed as a script this section runs exactly as before.
if __name__ == "__main__":
    # Reference accuracies of the phase-1 model on every task's test split.
    p1_eval = {op: ev(model, lds[f'{op}_test'], dev) for op in ALL_OPS}
    print(f"\nPhase 1 final accuracies:")
    for op, e in p1_eval.items():
        print(f" {op}: {e['acc']:.3f}")

    os.makedirs('results', exist_ok=True)
    torch.save(model.state_dict(), 'results/p1.pt')
    p1_fps = fourier_power_spectrum(model.te.weight.detach().cpu(), p)

    # ========== PHASE 2: Fork into 5 branches ==========
    branch_results = {}
    phase2_epochs = 150

    for task_name in ALL_OPS:
        branch_name = f'p2_{task_name}'
        level = ALL_OPS.index(task_name)
        print(f"\n{'='*60}")
        print(f"PHASE 2 — Branch A→{task_name.upper()} (Level {level}, {phase2_epochs} epochs)")
        print(f"{'='*60}")

        # Every branch starts from the same phase-1 checkpoint with a fresh
        # optimizer. map_location keeps the reload on the device in use.
        m = ST(cfg).to(dev)
        m.load_state_dict(torch.load('results/p1.pt', map_location=dev, weights_only=True))
        o = optim.AdamW(m.parameters(), lr=1e-3, weight_decay=1.0)

        h = train_phase(m, o, lds[f'{task_name}_train'], phase2_epochs, dev, branch_name,
                        s0, s1, ri, rp, pids_a, lds, cfg, ce=20)

        # Final evaluation on ALL tasks
        final_eval = {op: ev(m, lds[f'{op}_test'], dev) for op in ALL_OPS}
        fps_final = fourier_power_spectrum(m.te.weight.detach().cpu(), p)

        branch_results[task_name] = {
            'history': h,
            'final_eval': final_eval,
            'fourier_spectrum_final': fps_final.tolist(),
        }

        print(f"\n Final accuracies after A→{task_name}:")
        for op, e in final_eval.items():
            marker = " ←TRAINED" if op == task_name else (" ←BASE" if op == 'add' else "")
            print(f" {op}: {e['acc']:.3f}{marker}")

        # Forgetting = Phase1 add accuracy - current add accuracy
        forgetting = p1_eval['add']['acc'] - final_eval['add']['acc']
        print(f" Addition FORGETTING: {forgetting*100:.1f}%")

    # ========== PHASE 3: Cross-comparison ==========
    print(f"\n{'='*60}\nPHASE 3: Cross-model Representation Comparison\n{'='*60}")
    # NOTE(review): branch models are not saved to disk, so no cross-model
    # comparison is actually computed here; the per-checkpoint metrics in each
    # branch history are the only representation data kept. The former loop
    # over ALL_OPS had an empty body and has been removed.
    branch_reps = {}

    # Build summary table
# Guarded so that importing this file does not run reporting/plotting/upload;
# when the file is executed as a script this section runs exactly as before.
if __name__ == "__main__":
    summary = {
        # Architecture values are read from cfg so the report cannot drift
        # from the model that was actually trained.
        'config': {'p': p, 'n_layers': cfg.nl, 'd_model': cfg.dm, 'n_heads': cfg.nh,
                   'd_mlp': cfg.dmlp,
                   'phase1_epochs': 150, 'phase2_epochs': phase2_epochs,
                   'lr': 1e-3, 'weight_decay': 1.0, 'batch_size': 512,
                   'train_frac': 0.5, 'seed': seed, 'n_parameters': model.nparams()},
        'dissimilarity_ladder': {
            'add': {'level': 0, 'description': 'Identical (Fourier circuit)'},
            'subtract': {'level': 1, 'description': 'Same Fourier circuit (sign flip)'},
            'multiply': {'level': 2, 'description': 'Discrete-log Fourier circuit'},
            'max': {'level': 3, 'description': 'Linear/ordinal circuit'},
            'xor': {'level': 4, 'description': 'Bit-level circuit'},
        },
        'phase1_history': h1,
        'phase1_final_eval': {op: p1_eval[op] for op in ALL_OPS},
        'fourier_spectrum_init': init_fps.tolist(),
        'fourier_spectrum_phase1': p1_fps.tolist(),
    }

    # Per-branch results
    for task_name, br in branch_results.items():
        level = ALL_OPS.index(task_name)
        summary[f'branch_{task_name}'] = {
            'level': level,
            'history': br['history'],
            'final_eval': br['final_eval'],
            'fourier_spectrum': br['fourier_spectrum_final'],
            'addition_forgetting': p1_eval['add']['acc'] - br['final_eval']['add']['acc'],
        }

    # Forgetting summary table
    print("\n" + "=" * 60)
    print("FORGETTING SUMMARY")
    print("=" * 60)
    print(f"{'Task':<15} {'Level':<8} {'Add Acc':<10} {'Task Acc':<10} {'Forgetting':<12} {'Grad Align':<12}")
    print("-" * 67)
    for task_name in ALL_OPS:
        br = branch_results[task_name]
        level = ALL_OPS.index(task_name)
        add_acc = br['final_eval']['add']['acc']
        task_acc = br['final_eval'][task_name]['acc']
        forgetting = p1_eval['add']['acc'] - add_acc
        # Get last gradient alignment from history. There is no add-vs-add
        # entry, so the 'add' branch reports add-vs-subtract instead.
        last_ga = 0
        for rec in reversed(br['history']):
            if f'grad_align_add_vs_{task_name}' in rec:
                last_ga = rec[f'grad_align_add_vs_{task_name}']
                break
            elif task_name == 'add' and 'grad_align_add_vs_subtract' in rec:
                last_ga = rec['grad_align_add_vs_subtract']
                break
        print(f"{task_name:<15} {level:<8} {add_acc:<10.3f} {task_acc:<10.3f} {forgetting*100:<12.1f}% {last_ga:<12.3f}")

    summary['forgetting_table'] = {}
    for task_name in ALL_OPS:
        br = branch_results[task_name]
        summary['forgetting_table'][task_name] = {
            'level': ALL_OPS.index(task_name),
            'addition_accuracy': br['final_eval']['add']['acc'],
            'task_accuracy': br['final_eval'][task_name]['acc'],
            'forgetting_pct': (p1_eval['add']['acc'] - br['final_eval']['add']['acc']) * 100,
        }

    # default=str is a deliberate best-effort fallback for any value json
    # cannot encode natively.
    with open('results/graduated_experiment_results.json', 'w', encoding='utf-8') as out_file:
        json.dump(summary, out_file, indent=2, default=str)

    elapsed = time.time() - t0
    print(f"\nExperiment completed in {elapsed/60:.1f} minutes")
    print("Results saved to results/graduated_experiment_results.json")

    # ============= GENERATE PLOTS =============
    print("\nGenerating plots...")
    import matplotlib
    matplotlib.use('Agg')  # headless backend: figures are only written to disk
    import matplotlib.pyplot as plt

    colors = {'add': '#2196F3', 'subtract': '#4CAF50', 'multiply': '#FF9800',
              'max': '#F44336', 'xor': '#9C27B0'}
    labels = {'add': 'Addition (L0)', 'subtract': 'Subtraction (L1)',
              'multiply': 'Multiplication (L2)', 'max': 'Max (L3)', 'xor': 'XOR (L4)'}

    # Plot 1: Forgetting bar chart
    fig, ax = plt.subplots(figsize=(10, 6))
    tasks = list(ALL_OPS)
    forg = [summary['forgetting_table'][t]['forgetting_pct'] for t in tasks]
    task_acc = [summary['forgetting_table'][t]['task_accuracy'] for t in tasks]
    x = np.arange(len(tasks))
    bars = ax.bar(x, forg, color=[colors[t] for t in tasks], edgecolor='black', linewidth=0.5)
    ax.set_xticks(x)
    ax.set_xticklabels([labels[t] for t in tasks], rotation=15, ha='right')
    ax.set_ylabel('Addition Accuracy Forgetting (%)')
    ax.set_title('Forgetting vs Task Dissimilarity Level')
    ax.axhline(0, color='gray', ls=':', alpha=0.5)
    for bar, ta in zip(bars, task_acc):
        ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.5, f'Task:{ta:.1%}',
                ha='center', va='bottom', fontsize=8)
    plt.tight_layout()
    plt.savefig('results/forgetting_ladder.png', dpi=150)
    plt.close()

    # Plot 2: Addition accuracy over training for each branch
    fig, ax = plt.subplots(figsize=(12, 6))
    s1s = [rec['step'] for rec in h1]
    # Step at which phase 1 ended; 0 if phase 1 recorded no checkpoints.
    p1_end = s1s[-1] if s1s else 0
    ax.plot(s1s, [rec.get('eval/add_test_acc', 0) for rec in h1], 'k-', lw=2,
            label='Phase 1 (training)')
    for task_name in ALL_OPS:
        br = branch_results[task_name]
        steps = [rec['step'] + p1_end for rec in br['history']]
        acc = [rec.get('eval/add_test_acc', 0) for rec in br['history']]
        ax.plot(steps, acc, '-', color=colors[task_name], lw=2, label=labels[task_name])
    ax.axvline(p1_end, color='gray', ls='--', alpha=0.5, label='Phase transition')
    ax.set_xlabel('Training Step')
    ax.set_ylabel('Addition Test Accuracy')
    ax.set_title('Addition Accuracy: Learning vs Forgetting')
    ax.legend()
    ax.set_ylim(-0.05, 1.05)
    plt.tight_layout()
    plt.savefig('results/addition_accuracy_all_branches.png', dpi=150)
    plt.close()

    # Plot 3: CKA vs Phase 1 for each branch (layer 2 = final layer)
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    for li, ax in enumerate(axes):
        m_name = f'layer_{li}/cka_vs_phase1'
        for task_name in ALL_OPS:
            br = branch_results[task_name]
            steps = [rec['step'] for rec in br['history'] if m_name in rec]
            vals = [rec[m_name] for rec in br['history'] if m_name in rec]
            ax.plot(steps, vals, '-', color=colors[task_name], lw=1.5, label=labels[task_name])
        layer_label = "Embedding" if li == 0 else f"Layer {li}"
        ax.set_title(f'{layer_label}: CKA vs Phase 1 End')
        ax.set_xlabel('Step')
        ax.set_ylabel('CKA')
        ax.legend(fontsize=7)
        ax.set_ylim(0, 1.05)
    plt.tight_layout()
    plt.savefig('results/cka_all_branches.png', dpi=150)
    plt.close()

    # Plot 4: Gradient alignment with addition over training
    fig, ax = plt.subplots(figsize=(12, 6))
    for task_name in ['subtract', 'multiply', 'max', 'xor']:
        m_name = f'grad_align_add_vs_{task_name}'
        # Phase 1 data
        steps_p1 = [rec['step'] for rec in h1 if m_name in rec]
        vals_p1 = [rec[m_name] for rec in h1 if m_name in rec]
        ax.plot(steps_p1, vals_p1, ':', color=colors[task_name], lw=1, alpha=0.5)
        # Branch data (training on this task)
        br = branch_results[task_name]
        steps = [rec['step'] + p1_end for rec in br['history'] if m_name in rec]
        vals = [rec[m_name] for rec in br['history'] if m_name in rec]
        ax.plot(steps, vals, '-', color=colors[task_name], lw=2, label=labels[task_name])
    ax.axvline(p1_end, color='gray', ls='--', alpha=0.5)
    ax.axhline(0, color='gray', ls=':', alpha=0.3)
    ax.set_xlabel('Step')
    ax.set_ylabel('Gradient Cosine Similarity')
    ax.set_title('Gradient Alignment with Addition Task')
    ax.legend()
    plt.tight_layout()
    plt.savefig('results/gradient_alignment_all.png', dpi=150)
    plt.close()

    # Plot 5: Fourier spectrum comparison
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    freqs = np.arange(len(init_fps))
    axes[0, 0].bar(freqs, init_fps, color='gray', alpha=0.5)
    axes[0, 0].set_title('Init')
    axes[0, 0].set_ylabel('Power')
    axes[0, 1].bar(freqs, p1_fps, color='black', alpha=0.7)
    axes[0, 1].set_title('After Phase 1 (Addition)')
    for i, task_name in enumerate(['subtract', 'multiply', 'max', 'xor']):
        # The first two grid cells are taken, so branch panels start at slot 2.
        row, col = divmod(i + 2, 3)
        ax = axes[row, col]
        fps = np.array(branch_results[task_name]['fourier_spectrum_final'])
        ax.bar(freqs[:len(fps)], fps, color=colors[task_name], alpha=0.7)
        ax.set_title(f'After A→{task_name.title()}')
        if row == 1:
            ax.set_xlabel('Frequency')
        if col == 0:
            ax.set_ylabel('Power')
    plt.suptitle('Embedding Fourier Power Spectrum Evolution', fontsize=14)
    plt.tight_layout()
    plt.savefig('results/fourier_spectra.png', dpi=150)
    plt.close()

    # Plot 6: Subspace angles
    fig, ax = plt.subplots(figsize=(12, 6))
    for task_name in ALL_OPS:
        m_name = 'layer_2/subspace_angle_vs_phase1'
        br = branch_results[task_name]
        steps = [rec['step'] for rec in br['history'] if m_name in rec]
        vals = [rec[m_name] for rec in br['history'] if m_name in rec]
        ax.plot(steps, vals, '-', color=colors[task_name], lw=2, label=labels[task_name])
    ax.set_xlabel('Step')
    ax.set_ylabel('Subspace Angle (degrees)')
    ax.set_title('Final Layer Subspace Angle Drift from Phase 1')
    ax.legend()
    plt.tight_layout()
    plt.savefig('results/subspace_angles_all.png', dpi=150)
    plt.close()

    print("All plots generated!")

    # Upload to HF
    repo = 'tekkmaven/representation-learning-dynamics'
    files_to_upload = [
        'graduated_experiment_results.json',
        'forgetting_ladder.png', 'addition_accuracy_all_branches.png',
        'cka_all_branches.png', 'gradient_alignment_all.png',
        'fourier_spectra.png', 'subspace_angles_all.png',
    ]
    # Uploading is best effort: all results are already on disk, so a missing
    # client library should not turn a finished run into a failure.
    try:
        from huggingface_hub import HfApi
    except ImportError as err:
        print(f"Skipping upload: huggingface_hub is not available ({err})")
    else:
        api = HfApi(token=os.environ.get('HF_TOKEN'))
        for fname in files_to_upload:
            try:
                api.upload_file(path_or_fileobj=f'results/{fname}',
                                path_in_repo=f'results/{fname}',
                                repo_id=repo, repo_type='model')
                print(f"Uploaded {fname}")
            except Exception as e:
                print(f"Failed to upload {fname}: {e}")

    print(f"\n=== EXPERIMENT COMPLETE ===")
    print(f"Total time: {(time.time()-t0)/60:.1f} minutes")
    print(f"Repository: https://huggingface.co/{repo}")