tekkmaven's picture
Upload run_graduated.py with huggingface_hub
36ee7e2 verified
#!/usr/bin/env python3
"""
Graduated Dissimilarity Experiment
====================================
Train on modular addition, then fork into 5 branches with tasks of
increasing dissimilarity. Track representation metrics to find the
tipping point where forgetting begins.
Branches:
A→A: Continue addition (Level 0: identical)
A→B: Switch to subtraction (Level 1: same Fourier circuit)
A→C: Switch to multiplication (Level 2: different Fourier freqs)
A→D: Switch to max(a,b) (Level 3: linear/ordinal circuit)
A→E: Switch to XOR (Level 4: bit-level circuit)
"""
import torch, torch.nn as nn, torch.optim as optim, torch.nn.functional as F
import numpy as np, json, os, math, copy, time
from dataclasses import dataclass
from torch.utils.data import Dataset, DataLoader
from typing import Dict, List, Tuple
# ============= INLINE DEPS (self-contained for job) =============
# --- Representation metrics ---
def centering(K):
    """Double-center a square matrix.

    Removes the column means and the row means of ``K`` and adds the grand
    mean back, using an averaging matrix built on ``K``'s device and dtype.
    """
    size = K.shape[0]
    averager = torch.ones(size, size, device=K.device, dtype=K.dtype) / size
    col_means = averager @ K
    row_means = K @ averager
    grand_mean = col_means @ averager
    return K - col_means - row_means + grand_mean
def linear_HSIC(X, Y):
    """Return the linear-kernel HSIC statistic between ``X`` and ``Y``.

    Both inputs are turned into Gram matrices (``X @ X.T`` / ``Y @ Y.T``),
    centered, multiplied elementwise, summed, and scaled by ``(n - 1) ** 2``
    where ``n`` is the number of rows of ``X``.
    """
    gram_x = X @ X.T
    gram_y = Y @ Y.T
    n_rows = X.shape[0]
    overlap = centering(gram_x) * centering(gram_y)
    return overlap.sum() / ((n_rows - 1) ** 2)
def linear_CKA(X, Y):
    """Return linear CKA between ``X`` and ``Y`` as a Python float.

    The cross HSIC is normalized by the product of the square roots of the
    two self-HSIC terms; the denominator is clamped to at least 1e-10 to
    avoid dividing by zero.
    """
    cross = linear_HSIC(X, Y)
    self_x = linear_HSIC(X, X)
    self_y = linear_HSIC(Y, Y)
    denominator = (self_x.sqrt() * self_y.sqrt()).clamp(min=1e-10)
    return (cross / denominator).item()
def cka_heatmap(sa, sb):
    """Return a ``len(sa) x len(sb)`` NumPy array of pairwise linear CKA values.

    Entry ``[i, j]`` is ``linear_CKA(sa[i], sb[j])``.
    """
    heat = np.zeros((len(sa), len(sb)))
    for row, rep_a in enumerate(sa):
        for col, rep_b in enumerate(sb):
            heat[row, col] = linear_CKA(rep_a, rep_b)
    return heat
def subspace_angles(X, Y, k=10):
    """Return the principal angles (radians) between top-``k`` subspaces.

    Each input is mean-centered along dim 0 and decomposed by SVD; the first
    ``k`` right singular vectors (fewer if not available) form its basis.
    The angles are the arccos of the singular values of the basis overlap,
    clamped to [-1, 1] before the arccos.
    """
    def top_basis(Z, rank):
        # Right singular vectors of the centered data, as columns.
        centered = Z - Z.mean(0)
        _, _, Vh = torch.linalg.svd(centered, full_matrices=False)
        keep = min(rank, Vh.shape[0])
        return Vh[:keep].T

    basis_x = top_basis(X, k)
    basis_y = top_basis(Y, k)
    shared = min(basis_x.shape[1], basis_y.shape[1])
    overlap = basis_x[:, :shared].T @ basis_y[:, :shared]
    cosines = torch.linalg.svdvals(overlap).clamp(-1, 1)
    return torch.arccos(cosines)
def mean_sa_deg(X, Y, k=10):
    """Return the mean principal angle between ``X`` and ``Y`` in degrees."""
    angles_rad = subspace_angles(X, Y, k)
    mean_deg = angles_rad.mean() * 180 / torch.pi
    return mean_deg.item()
def grad_align(model, ba, bb, lfn):
    """Return the cosine similarity between the gradients of two batches.

    ``lfn(model, batch)`` must return a scalar loss. For each batch the
    model's gradients are cleared, the loss is back-propagated, and all
    available parameter gradients are flattened into one vector. Gradients
    are cleared again before returning.
    """
    def flat_gradient(batch):
        model.zero_grad()
        lfn(model, batch).backward()
        pieces = [
            param.grad.flatten()
            for param in model.parameters()
            if param.grad is not None
        ]
        return torch.cat(pieces).clone()

    grad_a = flat_gradient(ba)
    grad_b = flat_gradient(bb)
    model.zero_grad()
    cosine = F.cosine_similarity(grad_a.unsqueeze(0), grad_b.unsqueeze(0))
    return cosine.item()
def attn_entropy(aw):
    """Return base-2 entropy statistics of attention weights.

    Entropy is taken over the last dimension (1e-9 is added inside the log
    for stability). ``'mean'`` averages over everything; ``'per_head'``
    averages over dims 0 and 2 of the entropy tensor.
    NOTE(review): this presumes ``aw`` is (batch, heads, query, key) — confirm.
    """
    log_weights = (aw + 1e-9).log2()
    entropy = -(aw * log_weights).sum(-1)
    head_means = entropy.mean(dim=(0, 2))
    return {'mean': entropy.mean().item(), 'per_head': head_means.cpu().tolist()}
def wc_mag(s0, s1):
    """Return the norm of the change of every tensor present in both dicts.

    Keys follow the iteration order of ``s0``; entries missing from ``s1``
    are skipped. Tensors are cast to float before subtracting.
    """
    magnitudes = {}
    for name in s0:
        if name not in s1:
            continue
        delta = s1[name].float() - s0[name].float()
        magnitudes[name] = delta.norm().item()
    return magnitudes
# Fourier power spectrum of embedding matrix
def fourier_power_spectrum(W_E, p, num_special=None):
    """Compute Fourier power at each frequency for the number-token embeddings.

    Args:
        W_E: embedding matrix, [vocab_size, d_model].
        p: how many number tokens to take (rows ``num_special`` ..
            ``num_special + p - 1``).
        num_special: count of leading special-token rows to skip. Defaults
            to this script's module-level ``NS`` so the offset always matches
            the token layout the datasets actually use.

    Returns:
        NumPy array of length ``p // 2 + 1``: squared rFFT magnitude along
        the token axis, summed over the embedding dimension.
    """
    # The script is meant to be self-contained: rely on the in-file constant
    # rather than importing NUM_SPECIAL from an external `tasks` module that
    # may be absent or out of date.
    if num_special is None:
        num_special = NS
    num_emb = W_E[num_special:num_special + p, :]   # [p, d_model]
    spectrum = torch.fft.rfft(num_emb, dim=0)       # [p//2+1, d_model]
    power = (spectrum.abs() ** 2).sum(dim=1)        # [p//2+1]
    return power.cpu().numpy()
# --- Tasks ---
# Offset added to every number when it is turned into a token id
# (NUM_SPECIAL, updated); ids below NS are not number tokens.
NS = 7
# Default modulus / operand range used by the datasets.
DP = 97
# Supported operations, in branch order.
ALL_OPS = ['add', 'subtract', 'multiply', 'max', 'xor']
# Operator token id for each operation: consecutive ids starting at 2.
OP_TOKENS = {op_name: token_id for token_id, op_name in enumerate(ALL_OPS, start=2)}
class MAD(Dataset):
    """Dataset of two-operand problems over the integers ``0 .. p-1``.

    All ``p * p`` ordered pairs are shuffled with a seeded NumPy RNG; the
    first ``tf`` fraction is the train split and the remainder is returned
    for any other ``split`` value. Each item is the 5-token sequence
    ``[a, op, b, 1, c]`` (numbers shifted by ``NS``) with labels that are
    -100 everywhere except the final answer position.
    """

    def __init__(self, op='add', p=DP, split='train', tf=0.5, seed=42):
        self.p = p
        self.op = op
        self.ot = OP_TOKENS[op]
        every_pair = []
        for a in range(p):
            for b in range(p):
                every_pair.append((a, b))
        shuffler = np.random.RandomState(seed)
        shuffler.shuffle(every_pair)
        cut = int(len(every_pair) * tf)
        if split == 'train':
            self.pairs = every_pair[:cut]
        else:
            self.pairs = every_pair[cut:]

    def _c(self, a, b):
        """Return the answer for ``(a, b)`` under this dataset's operation."""
        op = self.op
        if op == 'add':
            return (a + b) % self.p
        if op == 'subtract':
            return (a - b) % self.p
        if op == 'multiply':
            return (a * b) % self.p
        if op == 'max':
            return max(a, b)
        if op == 'xor':
            return (a ^ b) % self.p
        return None

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, i):
        a, b = self.pairs[i]
        answer = self._c(a, b) + NS
        tokens = [a + NS, self.ot, b + NS, 1, answer]
        targets = [-100, -100, -100, -100, answer]
        return {
            'input_ids': torch.tensor(tokens, dtype=torch.long),
            'labels': torch.tensor(targets, dtype=torch.long),
        }
def get_probe(ds, n=500):
    """Return the first ``n`` items of ``ds`` as a probe set.

    Returns a tuple ``(input_ids, answers)``: the stacked ``input_ids``
    tensors and a NumPy array of the last label of each item minus ``NS``.
    ``n`` is capped at ``len(ds)``.
    """
    count = min(n, len(ds))
    id_rows = []
    answers = []
    for idx in range(count):
        item = ds[idx]
        id_rows.append(item['input_ids'])
        answers.append(item['labels'][-1].item() - NS)
    return torch.stack(id_rows), np.array(answers)
def get_loaders(p=DP, bs=512, tf=0.5, seed=42):
    """Build train and test DataLoaders for every operation in ``ALL_OPS``.

    Returns a dict keyed ``'<op>_train'`` / ``'<op>_test'``. Only the train
    loaders shuffle; no loader drops the last partial batch.
    """
    loaders = {}
    for op in ALL_OPS:
        for split in ('train', 'test'):
            dataset = MAD(op, p, split, tf, seed)
            loaders[f'{op}_{split}'] = DataLoader(
                dataset,
                batch_size=bs,
                shuffle=(split == 'train'),
                drop_last=False,
            )
    return loaders
# --- Model ---
@dataclass
class TC:
    """Transformer hyper-parameters consumed by ``ST`` and its sub-modules."""

    vs: int = 104      # vocabulary size (token embedding rows / output logits)
    nl: int = 2        # number of transformer blocks
    dm: int = 128      # model (residual stream) width
    nh: int = 4        # attention heads per block
    dmlp: int = 512    # MLP hidden width
    msl: int = 5       # number of position-embedding rows
    do: float = 0.0    # NOTE(review): presumably dropout; not read by the modules shown here
    ln: bool = True    # use LayerNorm (True) or Identity (False)
class MHA(nn.Module):
    """Causal multi-head self-attention with bias-free projections."""

    def __init__(self, c):
        super().__init__()
        self.nh = c.nh
        self.dh = c.dm // c.nh
        self.WQ = nn.Linear(c.dm, c.dm, bias=False)
        self.WK = nn.Linear(c.dm, c.dm, bias=False)
        self.WV = nn.Linear(c.dm, c.dm, bias=False)
        self.WO = nn.Linear(c.dm, c.dm, bias=False)

    def _heads(self, proj, x):
        """Project ``x`` and reshape to (batch, heads, seq, head_dim)."""
        batch, seq, _ = x.shape
        return proj(x).view(batch, seq, self.nh, self.dh).transpose(1, 2)

    def forward(self, x, ra=False):
        """Return ``(output, attention)``; attention is None unless ``ra``."""
        B, T, D = x.shape
        Q = self._heads(self.WQ, x)
        K = self._heads(self.WK, x)
        V = self._heads(self.WV, x)
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.dh)
        # Strict upper triangle = future positions, which must not be attended to.
        future = torch.triu(torch.ones(T, T, device=x.device), 1).bool()
        scores.masked_fill_(future[None, None], float('-inf'))
        weights = F.softmax(scores, dim=-1)
        mixed = (weights @ V).transpose(1, 2).reshape(B, T, D)
        return self.WO(mixed), (weights if ra else None)
class MLP2(nn.Module):
    """Two-layer GELU MLP that also returns its hidden activations."""

    def __init__(self, c):
        super().__init__()
        self.wi = nn.Linear(c.dm, c.dmlp)
        self.wo = nn.Linear(c.dmlp, c.dm)

    def forward(self, x):
        """Return ``(output, hidden)`` where hidden is the post-GELU activation."""
        hidden = F.gelu(self.wi(x))
        return self.wo(hidden), hidden
class TB(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual add."""

    def __init__(self, c):
        super().__init__()
        self.attn = MHA(c)
        self.mlp = MLP2(c)
        if c.ln:
            self.ln1 = nn.LayerNorm(c.dm)
            self.ln2 = nn.LayerNorm(c.dm)
        else:
            self.ln1 = nn.Identity()
            self.ln2 = nn.Identity()

    def forward(self, x, ri=False):
        """Return ``{'hs': output}``; with ``ri`` also ``'aw'`` and ``'mh'``."""
        attn_out, attn_weights = self.attn(self.ln1(x), ra=ri)
        x = x + attn_out
        mlp_out, mlp_hidden = self.mlp(self.ln2(x))
        x = x + mlp_out
        if not ri:
            return {'hs': x}
        return {'hs': x, 'aw': attn_weights, 'mh': mlp_hidden}
class ST(nn.Module):
    """Small decoder-style transformer with tied input/output embeddings."""

    def __init__(self, c):
        super().__init__()
        self.c = c
        self.te = nn.Embedding(c.vs, c.dm)
        self.pe = nn.Embedding(c.msl, c.dm)
        self.blocks = nn.ModuleList()
        for _ in range(c.nl):
            self.blocks.append(TB(c))
        self.lnf = nn.LayerNorm(c.dm) if c.ln else nn.Identity()
        self.head = nn.Linear(c.dm, c.vs, bias=False)
        # Weight tying: the output projection shares the token embedding matrix.
        self.head.weight = self.te.weight
        self.apply(self._iw)

    def _iw(self, m):
        """Re-initialize Linear and Embedding weights from N(0, 0.02^2)."""
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, std=0.02)

    def forward(self, ids, labels=None, ri=False):
        """Run the model.

        Returns a dict with ``'logits'``; ``'loss'`` (cross-entropy ignoring
        label -100) when ``labels`` is given; and, when ``ri`` is true, the
        detached per-layer hidden states ``'hs'`` (embedding output first),
        attention weights ``'aw'`` and MLP hidden activations ``'mh'``.
        """
        _, T = ids.shape
        positions = torch.arange(T, device=ids.device)
        x = self.te(ids) + self.pe(positions)
        hidden_trace = [x.detach()]
        attn_trace = []
        mlp_trace = []
        for block in self.blocks:
            step = block(x, ri=ri)
            x = step['hs']
            hidden_trace.append(x.detach())
            if ri:
                attn_trace.append(step['aw'].detach())
                mlp_trace.append(step['mh'].detach())
        logits = self.head(self.lnf(x))
        res = {'logits': logits}
        if labels is not None:
            flat_logits = logits.view(-1, logits.size(-1))
            res['loss'] = F.cross_entropy(flat_logits, labels.view(-1), ignore_index=-100)
        if ri:
            res['hs'] = hidden_trace
            res['aw'] = attn_trace
            res['mh'] = mlp_trace
        return res

    def nparams(self):
        """Return the total number of parameter elements."""
        return sum(param.numel() for param in self.parameters())
# --- Experiment helpers ---
def ev(model, dl, dev):
    """Evaluate ``model`` on every batch of ``dl``.

    Returns ``{'loss', 'acc'}``: the example-weighted mean of the batch
    losses and the fraction of examples whose argmax logit at the last
    position equals the last label. The model is left in eval mode.
    """
    model.eval()
    loss_sum = 0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch in dl:
            ids = batch['input_ids'].to(dev)
            labs = batch['labels'].to(dev)
            out = model(ids, labels=labs)
            batch_size = ids.shape[0]
            loss_sum += out['loss'].item() * batch_size
            predictions = out['logits'][:, -1, :].argmax(-1)
            n_correct += (predictions == labs[:, -1]).sum().item()
            n_seen += batch_size
    return {'loss': loss_sum / n_seen, 'acc': n_correct / n_seen}
def creps(model, pids, dev, pos=-1):
    """Collect probe representations from ``model``.

    Runs the model on ``pids`` with internals enabled and returns CPU copies:
    ``'hs'`` is each layer's hidden state at sequence position ``pos`` and
    ``'aw'`` is each layer's attention weights. The model is left in eval mode.
    """
    model.eval()
    with torch.no_grad():
        out = model(pids.to(dev), ri=True)
    hidden_at_pos = [layer[:, pos, :].cpu() for layer in out['hs']]
    attention = [weights.cpu() for weights in out['aw']]
    return {'hs': hidden_at_pos, 'aw': attention}
def cmetrics(model, s0, s1, rc, ri, rp, dev, cfg):
    """Compute representation-drift metrics for the current model state.

    Args:
        model: model whose current weights are compared against ``s1``.
        s0: initial state dict (accepted for interface symmetry; unused here).
        s1: reference state dict for the weight-change magnitudes.
        rc: current probe representations (``creps`` output).
        ri: probe representations used as the "init" reference.
        rp: probe representations used as the "phase 1" reference.
        dev: device (unused here).
        cfg: config providing ``nl``, the number of blocks.

    Returns:
        Flat dict of metrics: per-layer CKA vs. both references, mean
        subspace angle vs. ``rp``, attention entropies, and per-block summed
        weight-change norms relative to ``s1``.
    """
    m = {}
    # Hidden-state index 0 is the embedding output, hence nl + 1 entries.
    for li in range(cfg.nl + 1):
        prefix = f'layer_{li}'
        cur = rc['hs'][li]
        init_rep = ri['hs'][li]
        ref_rep = rp['hs'][li]
        m[f'{prefix}/cka_vs_init'] = linear_CKA(cur, init_rep)
        m[f'{prefix}/cka_vs_phase1'] = linear_CKA(cur, ref_rep)
        rank = min(10, cur.shape[0] // 2, cur.shape[1])
        m[f'{prefix}/subspace_angle_vs_phase1'] = mean_sa_deg(cur, ref_rep, k=rank) if rank > 0 else 0.
    for li, aw in enumerate(rc['aw']):
        ent = attn_entropy(aw)
        m[f'layer_{li+1}/attn_entropy_mean'] = ent['mean']
        for head, head_entropy in enumerate(ent['per_head']):
            m[f'layer_{li+1}/head_{head}_entropy'] = head_entropy
    current_state = {name: tensor.cpu() for name, tensor in model.state_dict().items()}
    change = wc_mag(s1, current_state)
    for bi in range(cfg.nl):
        # Match on the trailing dot so that e.g. block 1 does not also pick up
        # parameters of blocks 10..19 when the model has ten or more blocks.
        marker = f'blocks.{bi}.'
        m[f'block_{bi}/weight_change_from_phase1'] = sum(
            value for name, value in change.items() if marker in name
        )
    return m
def train_phase(model, opt, dl, ne, dev, pn, s0, s1, ri, rp, pids, eval_loaders, cfg, ce=20):
    """Train ``model`` for ``ne`` epochs and record metrics every ``ce`` steps.

    Args:
        model, opt: model and optimizer to train.
        dl: training DataLoader yielding dicts with 'input_ids' and 'labels'.
        ne: number of epochs.
        dev: device batches are moved to.
        pn: phase name; used in log lines and stored in every record. The
            part after the last underscore selects the "task" accuracy shown.
        s0, s1: reference state dicts forwarded to ``cmetrics``.
        ri, rp: reference probe representations forwarded to ``cmetrics``.
        pids: probe input ids used to extract current representations.
        eval_loaders: dict of loaders; every key containing '_test' is evaluated.
        cfg: model config (``vs`` and ``nl`` are read).
        ce: checkpoint interval in optimizer steps.

    Returns:
        List of metric dicts, one per checkpoint.
    """
    def lfn(m, b):
        # Loss closure handed to grad_align; defined once instead of per checkpoint.
        return m(b['input_ids'].to(dev), labels=b['labels'].to(dev))['loss']

    hist = []
    gs = 0
    for ep in range(ne):
        model.train()
        el = nb = 0
        for b in dl:
            ids = b['input_ids'].to(dev)
            labs = b['labels'].to(dev)
            o = model(ids, labels=labs)
            o['loss'].backward()
            opt.step()
            opt.zero_grad()
            el += o['loss'].item()
            nb += 1
            gs += 1
            if gs % ce != 0:
                continue

            # ---- checkpoint: metrics are computed in eval mode ----
            model.eval()
            rc = creps(model, pids, dev)
            sm = cmetrics(model, s0, s1, rc, ri, rp, dev, cfg)
            # Evaluate on ALL tasks
            for n, ld in eval_loaders.items():
                if '_test' in n:
                    e = ev(model, ld, dev)
                    sm[f'eval/{n}_loss'] = e['loss']
                    sm[f'eval/{n}_acc'] = e['acc']
            # Gradient alignment: addition vs each other task
            add_batch = next(iter(eval_loaders['add_test']))
            for task_name in ['subtract', 'multiply', 'max', 'xor']:
                ga_key = f'grad_align_add_vs_{task_name}'
                try:
                    tb = next(iter(eval_loaders[f'{task_name}_test']))
                    sm[ga_key] = grad_align(model, add_batch, tb, lfn)
                except Exception as exc:
                    # Best effort: keep the 0.0 placeholder, but report the
                    # failure and let KeyboardInterrupt/SystemExit propagate.
                    print(f"[{pn}] grad_align add vs {task_name} failed at step {gs}: {exc!r}")
                    sm[ga_key] = 0.
            # Fourier power spectrum of embeddings
            emb_w = model.te.weight.detach().cpu()
            fps = fourier_power_spectrum(emb_w, cfg.vs - NS)
            # Store top-5 peak frequencies
            top5 = np.argsort(fps)[::-1][:5]
            sm['fourier_top5_freqs'] = top5.tolist()
            sm['fourier_top5_power'] = fps[top5].tolist()
            sm['fourier_total_power'] = float(fps.sum())
            sm['fourier_concentration'] = float(fps[top5].sum() / fps.sum())  # how concentrated
            sm['phase'] = pn
            sm['epoch'] = ep
            sm['step'] = gs
            sm['train_loss'] = el / nb
            hist.append(sm)
            add_acc = sm.get('eval/add_test_acc', 0)
            task_acc = sm.get(f'eval/{pn.split("_")[-1]}_test_acc', add_acc)
            print(f"[{pn}] S{gs} L:{el/nb:.4f} AddAcc:{add_acc:.3f} TaskAcc:{task_acc:.3f} "
                  f"CKA(L1vP1):{sm.get('layer_1/cka_vs_phase1',0):.3f} "
                  f"FourierConc:{sm.get('fourier_concentration',0):.3f}")
            model.train()
        print(f"[{pn}] Ep{ep+1}/{ne} L:{el/nb:.4f}")
    return hist
# ============= MAIN =============
banner = "=" * 70
print(banner)
print("GRADUATED DISSIMILARITY EXPERIMENT")
print(banner)
t0 = time.time()
dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {dev}")

# Modulus and RNG seed shared by every phase.
p = 97
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

cfg = TC(vs=p + NS, nl=2, dm=128, nh=4, dmlp=512, msl=5)
model = ST(cfg).to(dev)
print(f"Parameters: {model.nparams():,}")
# CPU snapshot of the freshly initialized weights.
s0 = {name: tensor.cpu().clone() for name, tensor in model.state_dict().items()}
lds = get_loaders(p=p, bs=512, tf=0.5, seed=seed)

# Probe inputs come from the addition test split and are reused by all branches.
dsa = MAD('add', p=p, split='test', tf=0.5, seed=seed)
pids_a, pla = get_probe(dsa, 500)
ri = creps(model, pids_a, dev)

# Embedding Fourier spectrum before any training.
init_fps = fourier_power_spectrum(model.te.weight.detach().cpu(), p)
# ========== PHASE 1: Train on Addition ==========
print(f"\n{'='*60}\nPHASE 1: Train on Addition ({150} epochs)\n{'='*60}")
opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1.0)
# During phase 1 the initial snapshot/representations double as the "phase 1" references.
h1 = train_phase(
    model, opt, lds['add_train'], 150, dev, 'p1_add',
    s0, s0, ri, ri, pids_a, lds, cfg, ce=20,
)

# Freeze the end-of-phase-1 references used by every phase-2 branch.
s1 = {name: tensor.cpu().clone() for name, tensor in model.state_dict().items()}
rp = creps(model, pids_a, dev)

p1_eval = {}
for op in ALL_OPS:
    p1_eval[op] = ev(model, lds[f'{op}_test'], dev)
print("\nPhase 1 final accuracies:")
for op, e in p1_eval.items():
    print(f" {op}: {e['acc']:.3f}")

# Checkpoint that each branch reloads as its starting point.
os.makedirs('results', exist_ok=True)
torch.save(model.state_dict(), 'results/p1.pt')
p1_fps = fourier_power_spectrum(model.te.weight.detach().cpu(), p)
# ========== PHASE 2: Fork into 5 branches ==========
branch_results = {}
phase2_epochs = 150
for level, task_name in enumerate(ALL_OPS):
    # The dissimilarity level is the task's position in ALL_OPS.
    branch_name = f'p2_{task_name}'
    rule = '=' * 60
    print(f"\n{rule}")
    print(f"PHASE 2 — Branch A→{task_name.upper()} (Level {level}, {phase2_epochs} epochs)")
    print(f"{rule}")

    # Every branch restarts from the saved phase-1 weights with a fresh optimizer.
    branch_model = ST(cfg).to(dev)
    branch_model.load_state_dict(torch.load('results/p1.pt', weights_only=True))
    branch_opt = optim.AdamW(branch_model.parameters(), lr=1e-3, weight_decay=1.0)
    branch_hist = train_phase(
        branch_model, branch_opt, lds[f'{task_name}_train'], phase2_epochs, dev,
        branch_name, s0, s1, ri, rp, pids_a, lds, cfg, ce=20,
    )

    # Final evaluation on ALL tasks
    final_eval = {}
    for op in ALL_OPS:
        final_eval[op] = ev(branch_model, lds[f'{op}_test'], dev)
    fps_final = fourier_power_spectrum(branch_model.te.weight.detach().cpu(), p)
    branch_results[task_name] = {
        'history': branch_hist,
        'final_eval': final_eval,
        'fourier_spectrum_final': fps_final.tolist(),
    }

    print(f"\n Final accuracies after A→{task_name}:")
    for op, e in final_eval.items():
        if op == task_name:
            marker = " ←TRAINED"
        elif op == 'add':
            marker = " ←BASE"
        else:
            marker = ""
        print(f" {op}: {e['acc']:.3f}{marker}")
    # Forgetting = Phase1 add accuracy - current add accuracy
    forgetting = p1_eval['add']['acc'] - final_eval['add']['acc']
    print(f" Addition FORGETTING: {forgetting*100:.1f}%")
# ========== PHASE 3: Cross-comparison ==========
print(f"\n{'='*60}\nPHASE 3: Cross-model Representation Comparison\n{'='*60}")
# NOTE: no cross-model comparison is computed here. The branch models are not
# saved to disk, so their final representations are not available at this
# point; the per-branch metrics recorded during training (branch_results
# histories) are what the summary and plots below rely on.
# TODO: persist each branch model (or its final probe representations) during
# phase 2 and fill `branch_reps` if a direct branch-vs-branch comparison is wanted.
branch_reps = {}
# Build summary table
# Run configuration as recorded in the results file.
run_config = {
    'p': p, 'n_layers': 2, 'd_model': 128, 'n_heads': 4, 'd_mlp': 512,
    'phase1_epochs': 150, 'phase2_epochs': phase2_epochs,
    'lr': 1e-3, 'weight_decay': 1.0, 'batch_size': 512,
    'train_frac': 0.5, 'seed': seed, 'n_parameters': model.nparams(),
}
# Human-readable description of each dissimilarity level.
ladder = {
    'add': {'level': 0, 'description': 'Identical (Fourier circuit)'},
    'subtract': {'level': 1, 'description': 'Same Fourier circuit (sign flip)'},
    'multiply': {'level': 2, 'description': 'Discrete-log Fourier circuit'},
    'max': {'level': 3, 'description': 'Linear/ordinal circuit'},
    'xor': {'level': 4, 'description': 'Bit-level circuit'},
}
summary = {
    'config': run_config,
    'dissimilarity_ladder': ladder,
    'phase1_history': h1,
    'phase1_final_eval': {op: p1_eval[op] for op in ALL_OPS},
    'fourier_spectrum_init': init_fps.tolist(),
    'fourier_spectrum_phase1': p1_fps.tolist(),
}
# Per-branch results
for task_name, br in branch_results.items():
    level = ALL_OPS.index(task_name)
    add_drop = p1_eval['add']['acc'] - br['final_eval']['add']['acc']
    summary[f'branch_{task_name}'] = {
        'level': level,
        'history': br['history'],
        'final_eval': br['final_eval'],
        'fourier_spectrum': br['fourier_spectrum_final'],
        'addition_forgetting': add_drop,
    }
# Forgetting summary table
print("\n" + "="*60)
print("FORGETTING SUMMARY")
print("="*60)
print(f"{'Task':<15} {'Level':<8} {'Add Acc':<10} {'Task Acc':<10} {'Forgetting':<12} {'Grad Align':<12}")
print("-"*67)
for task_name in ALL_OPS:
    br = branch_results[task_name]
    level = ALL_OPS.index(task_name)
    add_acc = br['final_eval']['add']['acc']
    task_acc = br['final_eval'][task_name]['acc']
    forgetting = p1_eval['add']['acc'] - add_acc
    # Most recent recorded gradient alignment for this task. The 'add' branch
    # has no add-vs-add entry, so it falls back to the add-vs-subtract value.
    own_key = f'grad_align_add_vs_{task_name}'
    last_ga = 0
    for record in reversed(br['history']):
        if own_key in record:
            last_ga = record[own_key]
            break
        if task_name == 'add' and 'grad_align_add_vs_subtract' in record:
            last_ga = record['grad_align_add_vs_subtract']
            break
    print(f"{task_name:<15} {level:<8} {add_acc:<10.3f} {task_acc:<10.3f} {forgetting*100:<12.1f}% {last_ga:<12.3f}")
# Machine-readable version of the forgetting table.
forgetting_table = {}
for task_name in ALL_OPS:
    br = branch_results[task_name]
    final_add_acc = br['final_eval']['add']['acc']
    forgetting_table[task_name] = {
        'level': ALL_OPS.index(task_name),
        'addition_accuracy': final_add_acc,
        'task_accuracy': br['final_eval'][task_name]['acc'],
        'forgetting_pct': (p1_eval['add']['acc'] - final_add_acc) * 100,
    }
summary['forgetting_table'] = forgetting_table

# default=str stringifies anything json cannot serialize natively.
with open('results/graduated_experiment_results.json', 'w') as out_file:
    json.dump(summary, out_file, indent=2, default=str)
elapsed = time.time() - t0
print(f"\nExperiment completed in {elapsed/60:.1f} minutes")
print("Results saved to results/graduated_experiment_results.json")
# ============= GENERATE PLOTS =============
print("\nGenerating plots...")
import matplotlib
# Select the non-interactive Agg backend before pyplot is imported.
matplotlib.use('Agg')
import matplotlib.pyplot as plt

# Line/bar color per task.
colors = {
    'add': '#2196F3',
    'subtract': '#4CAF50',
    'multiply': '#FF9800',
    'max': '#F44336',
    'xor': '#9C27B0',
}
# Legend / tick label per task.
labels = {
    'add': 'Addition (L0)',
    'subtract': 'Subtraction (L1)',
    'multiply': 'Multiplication (L2)',
    'max': 'Max (L3)',
    'xor': 'XOR (L4)',
}
# Plot 1: Forgetting bar chart
fig, ax = plt.subplots(figsize=(10, 6))
tasks = list(ALL_OPS)
forg = [summary['forgetting_table'][t]['forgetting_pct'] for t in tasks]
task_acc = [summary['forgetting_table'][t]['task_accuracy'] for t in tasks]
x = np.arange(len(tasks))
bars = ax.bar(
    x, forg,
    color=[colors[t] for t in tasks],
    edgecolor='black', linewidth=0.5,
)
ax.set_xticks(x)
ax.set_xticklabels([labels[t] for t in tasks], rotation=15, ha='right')
ax.set_ylabel('Addition Accuracy Forgetting (%)')
ax.set_title('Forgetting vs Task Dissimilarity Level')
ax.axhline(0, color='gray', ls=':', alpha=0.5)
# Annotate each bar with the accuracy reached on the branch's own task.
for bar, ta in zip(bars, task_acc):
    label_x = bar.get_x() + bar.get_width() / 2
    label_y = bar.get_height() + 0.5
    ax.text(label_x, label_y, f'Task:{ta:.1%}', ha='center', va='bottom', fontsize=8)
plt.tight_layout()
plt.savefig('results/forgetting_ladder.png', dpi=150)
plt.close()
# Plot 2: Addition accuracy over training for each branch
fig, ax = plt.subplots(figsize=(12, 6))
# Checkpoint steps of phase 1; the last one is the x-offset for all branches.
s1s = [h['step'] for h in h1]
phase1_acc = [h.get('eval/add_test_acc', 0) for h in h1]
ax.plot(s1s, phase1_acc, 'k-', lw=2, label='Phase 1 (training)')
for task_name in ALL_OPS:
    br = branch_results[task_name]
    steps = []
    acc = []
    for record in br['history']:
        steps.append(record['step'] + s1s[-1])
        acc.append(record.get('eval/add_test_acc', 0))
    ax.plot(steps, acc, '-', color=colors[task_name], lw=2, label=labels[task_name])
ax.axvline(s1s[-1], color='gray', ls='--', alpha=0.5, label='Phase transition')
ax.set_xlabel('Training Step')
ax.set_ylabel('Addition Test Accuracy')
ax.set_title('Addition Accuracy: Learning vs Forgetting')
ax.legend()
ax.set_ylim(-0.05, 1.05)
plt.tight_layout()
plt.savefig('results/addition_accuracy_all_branches.png', dpi=150)
plt.close()
# Plot 3: CKA vs Phase 1 for each branch (layer 2 = final layer)
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
# One panel per hidden-state index: 0 is the embedding output, then each block.
for li, ax in enumerate(axes):
    m_name = f'layer_{li}/cka_vs_phase1'
    for task_name in ALL_OPS:
        br = branch_results[task_name]
        recorded = [record for record in br['history'] if m_name in record]
        steps = [record['step'] for record in recorded]
        vals = [record[m_name] for record in recorded]
        ax.plot(steps, vals, '-', color=colors[task_name], lw=1.5, label=labels[task_name])
    panel_name = "Embedding" if li == 0 else f"Layer {li}"
    ax.set_title(f'{panel_name}: CKA vs Phase 1 End')
    ax.set_xlabel('Step')
    ax.set_ylabel('CKA')
    ax.legend(fontsize=7)
    ax.set_ylim(0, 1.05)
plt.tight_layout()
plt.savefig('results/cka_all_branches.png', dpi=150)
plt.close()
# Plot 4: Gradient alignment with addition over training
fig, ax = plt.subplots(figsize=(12, 6))
for task_name in ['subtract', 'multiply', 'max', 'xor']:
    m_name = f'grad_align_add_vs_{task_name}'
    # Phase 1 data (faint dotted line)
    p1_records = [record for record in h1 if m_name in record]
    steps_p1 = [record['step'] for record in p1_records]
    vals_p1 = [record[m_name] for record in p1_records]
    ax.plot(steps_p1, vals_p1, ':', color=colors[task_name], lw=1, alpha=0.5)
    # Branch data (training on this task), shifted to follow phase 1
    br = branch_results[task_name]
    branch_records = [record for record in br['history'] if m_name in record]
    steps = [record['step'] + s1s[-1] for record in branch_records]
    vals = [record[m_name] for record in branch_records]
    ax.plot(steps, vals, '-', color=colors[task_name], lw=2, label=labels[task_name])
ax.axvline(s1s[-1], color='gray', ls='--', alpha=0.5)
ax.axhline(0, color='gray', ls=':', alpha=0.3)
ax.set_xlabel('Step')
ax.set_ylabel('Gradient Cosine Similarity')
ax.set_title('Gradient Alignment with Addition Task')
ax.legend()
plt.tight_layout()
plt.savefig('results/gradient_alignment_all.png', dpi=150)
plt.close()
# Plot 5: Fourier spectrum comparison
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
freqs = np.arange(len(init_fps))
axes[0, 0].bar(freqs, init_fps, color='gray', alpha=0.5)
axes[0, 0].set_title('Init')
axes[0, 0].set_ylabel('Power')
axes[0, 1].bar(freqs, p1_fps, color='black', alpha=0.7)
axes[0, 1].set_title('After Phase 1 (Addition)')
# The four non-addition branches fill the remaining grid cells (slots 2..5).
for i, task_name in enumerate(['subtract', 'multiply', 'max', 'xor']):
    row, col = divmod(i + 2, 3)
    ax = axes[row, col]
    fps = np.array(branch_results[task_name]['fourier_spectrum_final'])
    ax.bar(freqs[:len(fps)], fps, color=colors[task_name], alpha=0.7)
    ax.set_title(f'After A→{task_name.title()}')
    if row == 1:
        ax.set_xlabel('Frequency')
    if col == 0:
        ax.set_ylabel('Power')
plt.suptitle('Embedding Fourier Power Spectrum Evolution', fontsize=14)
plt.tight_layout()
plt.savefig('results/fourier_spectra.png', dpi=150)
plt.close()
# Plot 6: Subspace angles
fig, ax = plt.subplots(figsize=(12, 6))
# Same metric key for every branch: subspace angle at hidden-state index 2.
m_name = 'layer_2/subspace_angle_vs_phase1'
for task_name in ALL_OPS:
    br = branch_results[task_name]
    recorded = [record for record in br['history'] if m_name in record]
    steps = [record['step'] for record in recorded]
    vals = [record[m_name] for record in recorded]
    ax.plot(steps, vals, '-', color=colors[task_name], lw=2, label=labels[task_name])
ax.set_xlabel('Step')
ax.set_ylabel('Subspace Angle (degrees)')
ax.set_title('Final Layer Subspace Angle Drift from Phase 1')
ax.legend()
plt.tight_layout()
plt.savefig('results/subspace_angles_all.png', dpi=150)
plt.close()
print("All plots generated!")
# Upload to HF
from huggingface_hub import HfApi

# The token is read from the HF_TOKEN environment variable (None if unset).
api = HfApi(token=os.environ.get('HF_TOKEN'))
repo = 'tekkmaven/representation-learning-dynamics'
files_to_upload = [
    'graduated_experiment_results.json',
    'forgetting_ladder.png',
    'addition_accuracy_all_branches.png',
    'cka_all_branches.png',
    'gradient_alignment_all.png',
    'fourier_spectra.png',
    'subspace_angles_all.png',
]
# Each upload is attempted independently so one failure does not stop the rest.
for f in files_to_upload:
    local_path = f'results/{f}'
    try:
        api.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=f'results/{f}',
            repo_id=repo,
            repo_type='model',
        )
        print(f"Uploaded {f}")
    except Exception as e:
        print(f"Failed to upload {f}: {e}")
print("\n=== EXPERIMENT COMPLETE ===")
print(f"Total time: {(time.time()-t0)/60:.1f} minutes")
print(f"Repository: https://huggingface.co/{repo}")