Upload run_graduated.py with huggingface_hub
Browse files- run_graduated.py +509 -0
run_graduated.py
ADDED
|
@@ -0,0 +1,509 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Graduated Dissimilarity Experiment
|
| 4 |
+
====================================
|
| 5 |
+
Train on modular addition, then fork into 5 branches with tasks of
|
| 6 |
+
increasing dissimilarity. Track representation metrics to find the
|
| 7 |
+
tipping point where forgetting begins.
|
| 8 |
+
|
| 9 |
+
Branches:
|
| 10 |
+
A→A: Continue addition (Level 0: identical)
|
| 11 |
+
A→B: Switch to subtraction (Level 1: same Fourier circuit)
|
| 12 |
+
A→C: Switch to multiplication (Level 2: different Fourier freqs)
|
| 13 |
+
A→D: Switch to max(a,b) (Level 3: linear/ordinal circuit)
|
| 14 |
+
A→E: Switch to XOR (Level 4: bit-level circuit)
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import torch, torch.nn as nn, torch.optim as optim, torch.nn.functional as F
|
| 18 |
+
import numpy as np, json, os, math, copy, time
|
| 19 |
+
from dataclasses import dataclass
|
| 20 |
+
from torch.utils.data import Dataset, DataLoader
|
| 21 |
+
from typing import Dict, List, Tuple
|
| 22 |
+
|
| 23 |
+
# ============= INLINE DEPS (self-contained for job) =============
|
| 24 |
+
|
| 25 |
+
# --- Representation metrics ---
|
| 26 |
+
def centering(K):
|
| 27 |
+
n=K.shape[0]; u=torch.ones(n,n,device=K.device,dtype=K.dtype)/n
|
| 28 |
+
return K-u@K-K@u+u@K@u
|
| 29 |
+
def linear_HSIC(X,Y):
|
| 30 |
+
K=X@X.T; L=Y@Y.T; return (centering(K)*centering(L)).sum()/((X.shape[0]-1)**2)
|
| 31 |
+
def linear_CKA(X,Y):
|
| 32 |
+
xy=linear_HSIC(X,Y);xx=linear_HSIC(X,X);yy=linear_HSIC(Y,Y)
|
| 33 |
+
return (xy/(xx.sqrt()*yy.sqrt()).clamp(min=1e-10)).item()
|
| 34 |
+
def cka_heatmap(sa,sb):
|
| 35 |
+
hm=np.zeros((len(sa),len(sb)))
|
| 36 |
+
for i in range(len(sa)):
|
| 37 |
+
for j in range(len(sb)): hm[i,j]=linear_CKA(sa[i],sb[j])
|
| 38 |
+
return hm
|
| 39 |
+
def subspace_angles(X,Y,k=10):
|
| 40 |
+
def tb(Z,k):
|
| 41 |
+
_,_,Vh=torch.linalg.svd(Z-Z.mean(0),full_matrices=False); return Vh[:min(k,Vh.shape[0])].T
|
| 42 |
+
Qx=tb(X,k);Qy=tb(Y,k); mk=min(Qx.shape[1],Qy.shape[1])
|
| 43 |
+
return torch.arccos(torch.linalg.svdvals(Qx[:,:mk].T@Qy[:,:mk]).clamp(-1,1))
|
| 44 |
+
def mean_sa_deg(X,Y,k=10): return (subspace_angles(X,Y,k).mean()*180/torch.pi).item()
|
| 45 |
+
def grad_align(model,ba,bb,lfn):
|
| 46 |
+
model.zero_grad();lfn(model,ba).backward()
|
| 47 |
+
ga=torch.cat([p.grad.flatten() for p in model.parameters() if p.grad is not None]).clone()
|
| 48 |
+
model.zero_grad();lfn(model,bb).backward()
|
| 49 |
+
gb=torch.cat([p.grad.flatten() for p in model.parameters() if p.grad is not None]).clone()
|
| 50 |
+
model.zero_grad(); return F.cosine_similarity(ga.unsqueeze(0),gb.unsqueeze(0)).item()
|
| 51 |
+
def attn_entropy(aw):
|
| 52 |
+
H=-(aw*(aw+1e-9).log2()).sum(-1)
|
| 53 |
+
return {'mean':H.mean().item(),'per_head':H.mean(dim=(0,2)).cpu().tolist()}
|
| 54 |
+
def wc_mag(s0,s1):
|
| 55 |
+
return {n:(s1[n].float()-s0[n].float()).norm().item() for n in s0 if n in s1}
|
| 56 |
+
|
| 57 |
+
# Fourier power spectrum of embedding matrix
|
| 58 |
+
def fourier_power_spectrum(W_E, p):
|
| 59 |
+
"""Compute Fourier power at each frequency for token embeddings.
|
| 60 |
+
W_E: [vocab_size, d_model]. Returns power at each freq 0..p//2."""
|
| 61 |
+
# Extract number embeddings only (skip special tokens)
|
| 62 |
+
from tasks import NUM_SPECIAL
|
| 63 |
+
num_emb = W_E[NUM_SPECIAL:NUM_SPECIAL+p, :] # [p, d_model]
|
| 64 |
+
fft = torch.fft.rfft(num_emb, dim=0) # [p//2+1, d_model]
|
| 65 |
+
power = (fft.abs() ** 2).sum(dim=1) # [p//2+1]
|
| 66 |
+
return power.cpu().numpy()
|
| 67 |
+
|
| 68 |
+
# --- Tasks ---
|
| 69 |
+
NS = 7 # NUM_SPECIAL (updated)
|
| 70 |
+
DP = 97
|
| 71 |
+
|
| 72 |
+
OP_TOKENS = {'add':2,'subtract':3,'multiply':4,'max':5,'xor':6}
|
| 73 |
+
ALL_OPS = ['add','subtract','multiply','max','xor']
|
| 74 |
+
|
| 75 |
+
class MAD(Dataset):
|
| 76 |
+
def __init__(self,op='add',p=DP,split='train',tf=0.5,seed=42):
|
| 77 |
+
self.p=p;self.op=op;self.ot=OP_TOKENS[op]
|
| 78 |
+
ap=[(a,b) for a in range(p) for b in range(p)]
|
| 79 |
+
rng=np.random.RandomState(seed);rng.shuffle(ap)
|
| 80 |
+
nt=int(len(ap)*tf); self.pairs=ap[:nt] if split=='train' else ap[nt:]
|
| 81 |
+
def _c(self,a,b):
|
| 82 |
+
if self.op=='add': return(a+b)%self.p
|
| 83 |
+
elif self.op=='subtract': return(a-b)%self.p
|
| 84 |
+
elif self.op=='multiply': return(a*b)%self.p
|
| 85 |
+
elif self.op=='max': return max(a,b)
|
| 86 |
+
elif self.op=='xor': return(a^b)%self.p
|
| 87 |
+
def __len__(self): return len(self.pairs)
|
| 88 |
+
def __getitem__(self,i):
|
| 89 |
+
a,b=self.pairs[i];c=self._c(a,b)
|
| 90 |
+
return {'input_ids':torch.tensor([a+NS,self.ot,b+NS,1,c+NS],dtype=torch.long),
|
| 91 |
+
'labels':torch.tensor([-100,-100,-100,-100,c+NS],dtype=torch.long)}
|
| 92 |
+
|
| 93 |
+
def get_probe(ds,n=500):
|
| 94 |
+
n=min(n,len(ds));its=[ds[i] for i in range(n)]
|
| 95 |
+
return torch.stack([it['input_ids'] for it in its]),np.array([it['labels'][-1].item()-NS for it in its])
|
| 96 |
+
|
| 97 |
+
def get_loaders(p=DP,bs=512,tf=0.5,seed=42):
|
| 98 |
+
ld={}
|
| 99 |
+
for op in ALL_OPS:
|
| 100 |
+
for sp in ['train','test']:
|
| 101 |
+
ld[f'{op}_{sp}']=DataLoader(MAD(op,p,sp,tf,seed),batch_size=bs,shuffle=(sp=='train'),drop_last=False)
|
| 102 |
+
return ld
|
| 103 |
+
|
| 104 |
+
# --- Model ---
|
| 105 |
+
@dataclass
|
| 106 |
+
class TC:
|
| 107 |
+
vs:int=104;nl:int=2;dm:int=128;nh:int=4;dmlp:int=512;msl:int=5;do:float=0.0;ln:bool=True
|
| 108 |
+
|
| 109 |
+
class MHA(nn.Module):
|
| 110 |
+
def __init__(self,c):
|
| 111 |
+
super().__init__();self.nh=c.nh;self.dh=c.dm//c.nh
|
| 112 |
+
self.WQ=nn.Linear(c.dm,c.dm,bias=False);self.WK=nn.Linear(c.dm,c.dm,bias=False)
|
| 113 |
+
self.WV=nn.Linear(c.dm,c.dm,bias=False);self.WO=nn.Linear(c.dm,c.dm,bias=False)
|
| 114 |
+
def forward(self,x,ra=False):
|
| 115 |
+
B,T,D=x.shape
|
| 116 |
+
Q=self.WQ(x).view(B,T,self.nh,self.dh).transpose(1,2)
|
| 117 |
+
K=self.WK(x).view(B,T,self.nh,self.dh).transpose(1,2)
|
| 118 |
+
V=self.WV(x).view(B,T,self.nh,self.dh).transpose(1,2)
|
| 119 |
+
s=(Q@K.transpose(-2,-1))/math.sqrt(self.dh)
|
| 120 |
+
s.masked_fill_(torch.triu(torch.ones(T,T,device=x.device),1).bool().unsqueeze(0).unsqueeze(0),float('-inf'))
|
| 121 |
+
a=F.softmax(s,dim=-1);return self.WO((a@V).transpose(1,2).reshape(B,T,D)),(a if ra else None)
|
| 122 |
+
|
| 123 |
+
class MLP2(nn.Module):
|
| 124 |
+
def __init__(self,c):
|
| 125 |
+
super().__init__();self.wi=nn.Linear(c.dm,c.dmlp);self.wo=nn.Linear(c.dmlp,c.dm)
|
| 126 |
+
def forward(self,x):h=F.gelu(self.wi(x));return self.wo(h),h
|
| 127 |
+
|
| 128 |
+
class TB(nn.Module):
|
| 129 |
+
def __init__(self,c):
|
| 130 |
+
super().__init__();self.attn=MHA(c);self.mlp=MLP2(c)
|
| 131 |
+
self.ln1=nn.LayerNorm(c.dm) if c.ln else nn.Identity()
|
| 132 |
+
self.ln2=nn.LayerNorm(c.dm) if c.ln else nn.Identity()
|
| 133 |
+
def forward(self,x,ri=False):
|
| 134 |
+
ao,aw=self.attn(self.ln1(x),ra=ri);x=x+ao;mo,mh=self.mlp(self.ln2(x));x=x+mo
|
| 135 |
+
r={'hs':x}
|
| 136 |
+
if ri:r['aw']=aw;r['mh']=mh
|
| 137 |
+
return r
|
| 138 |
+
|
| 139 |
+
class ST(nn.Module):
|
| 140 |
+
def __init__(self,c):
|
| 141 |
+
super().__init__();self.c=c
|
| 142 |
+
self.te=nn.Embedding(c.vs,c.dm);self.pe=nn.Embedding(c.msl,c.dm)
|
| 143 |
+
self.blocks=nn.ModuleList([TB(c) for _ in range(c.nl)])
|
| 144 |
+
self.lnf=nn.LayerNorm(c.dm) if c.ln else nn.Identity()
|
| 145 |
+
self.head=nn.Linear(c.dm,c.vs,bias=False);self.head.weight=self.te.weight
|
| 146 |
+
self.apply(self._iw)
|
| 147 |
+
def _iw(self,m):
|
| 148 |
+
if isinstance(m,nn.Linear):nn.init.normal_(m.weight,std=0.02)
|
| 149 |
+
elif isinstance(m,nn.Embedding):nn.init.normal_(m.weight,std=0.02)
|
| 150 |
+
def forward(self,ids,labels=None,ri=False):
|
| 151 |
+
B,T=ids.shape;x=self.te(ids)+self.pe(torch.arange(T,device=ids.device))
|
| 152 |
+
ahs=[x.detach()];aaw=[];amh=[]
|
| 153 |
+
for b in self.blocks:
|
| 154 |
+
r=b(x,ri=ri);x=r['hs'];ahs.append(x.detach())
|
| 155 |
+
if ri:aaw.append(r['aw'].detach());amh.append(r['mh'].detach())
|
| 156 |
+
lo=self.head(self.lnf(x));res={'logits':lo}
|
| 157 |
+
if labels is not None:res['loss']=F.cross_entropy(lo.view(-1,lo.size(-1)),labels.view(-1),ignore_index=-100)
|
| 158 |
+
if ri:res['hs']=ahs;res['aw']=aaw;res['mh']=amh
|
| 159 |
+
return res
|
| 160 |
+
def nparams(self):return sum(p.numel() for p in self.parameters())
|
| 161 |
+
|
| 162 |
+
# --- Experiment helpers ---
|
| 163 |
+
def ev(model,dl,dev):
|
| 164 |
+
model.eval();tl=tc=tt=0
|
| 165 |
+
with torch.no_grad():
|
| 166 |
+
for b in dl:
|
| 167 |
+
ids=b['input_ids'].to(dev);labs=b['labels'].to(dev)
|
| 168 |
+
o=model(ids,labels=labs);tl+=o['loss'].item()*ids.shape[0]
|
| 169 |
+
tc+=(o['logits'][:,-1,:].argmax(-1)==labs[:,-1]).sum().item();tt+=ids.shape[0]
|
| 170 |
+
return {'loss':tl/tt,'acc':tc/tt}
|
| 171 |
+
|
| 172 |
+
def creps(model,pids,dev,pos=-1):
|
| 173 |
+
model.eval()
|
| 174 |
+
with torch.no_grad():o=model(pids.to(dev),ri=True)
|
| 175 |
+
return {'hs':[h[:,pos,:].cpu() for h in o['hs']],'aw':[a.cpu() for a in o['aw']]}
|
| 176 |
+
|
| 177 |
+
def cmetrics(model,s0,s1,rc,ri,rp,dev,cfg):
|
| 178 |
+
m={};nl=cfg.nl+1
|
| 179 |
+
for li in range(nl):
|
| 180 |
+
p=f'layer_{li}';c=rc['hs'][li];ii=ri['hs'][li];pp=rp['hs'][li]
|
| 181 |
+
m[f'{p}/cka_vs_init']=linear_CKA(c,ii)
|
| 182 |
+
m[f'{p}/cka_vs_phase1']=linear_CKA(c,pp)
|
| 183 |
+
k=min(10,c.shape[0]//2,c.shape[1])
|
| 184 |
+
m[f'{p}/subspace_angle_vs_phase1']=mean_sa_deg(c,pp,k=k) if k>0 else 0.
|
| 185 |
+
for li,aw in enumerate(rc['aw']):
|
| 186 |
+
e=attn_entropy(aw);m[f'layer_{li+1}/attn_entropy_mean']=e['mean']
|
| 187 |
+
for h,he in enumerate(e['per_head']):m[f'layer_{li+1}/head_{h}_entropy']=he
|
| 188 |
+
cs={k:v.cpu() for k,v in model.state_dict().items()}
|
| 189 |
+
wp=wc_mag(s1,cs)
|
| 190 |
+
for bi in range(cfg.nl):
|
| 191 |
+
m[f'block_{bi}/weight_change_from_phase1']=sum(v for k,v in wp.items() if f'blocks.{bi}' in k)
|
| 192 |
+
return m
|
| 193 |
+
|
| 194 |
+
def train_phase(model,opt,dl,ne,dev,pn,s0,s1,ri,rp,pids,eval_loaders,cfg,ce=20):
|
| 195 |
+
hist=[];gs=0
|
| 196 |
+
for ep in range(ne):
|
| 197 |
+
model.train();el=nb=0
|
| 198 |
+
for b in dl:
|
| 199 |
+
ids=b['input_ids'].to(dev);labs=b['labels'].to(dev)
|
| 200 |
+
o=model(ids,labels=labs);o['loss'].backward();opt.step();opt.zero_grad()
|
| 201 |
+
el+=o['loss'].item();nb+=1;gs+=1
|
| 202 |
+
if gs%ce==0:
|
| 203 |
+
model.eval();rc=creps(model,pids,dev)
|
| 204 |
+
sm=cmetrics(model,s0,s1,rc,ri,rp,dev,cfg)
|
| 205 |
+
# Evaluate on ALL tasks
|
| 206 |
+
for n,ld in eval_loaders.items():
|
| 207 |
+
if '_test' in n:
|
| 208 |
+
e=ev(model,ld,dev);sm[f'eval/{n}_loss']=e['loss'];sm[f'eval/{n}_acc']=e['acc']
|
| 209 |
+
# Gradient alignment: addition vs each other task
|
| 210 |
+
add_batch=next(iter(eval_loaders['add_test']))
|
| 211 |
+
def lfn(m,b):return m(b['input_ids'].to(dev),labels=b['labels'].to(dev))['loss']
|
| 212 |
+
for task_name in ['subtract','multiply','max','xor']:
|
| 213 |
+
try:
|
| 214 |
+
tb=next(iter(eval_loaders[f'{task_name}_test']))
|
| 215 |
+
sm[f'grad_align_add_vs_{task_name}']=grad_align(model,add_batch,tb,lfn)
|
| 216 |
+
except:sm[f'grad_align_add_vs_{task_name}']=0.
|
| 217 |
+
# Fourier power spectrum of embeddings
|
| 218 |
+
emb_w = model.te.weight.detach().cpu()
|
| 219 |
+
fps = fourier_power_spectrum(emb_w, cfg.vs - NS)
|
| 220 |
+
# Store top-5 peak frequencies
|
| 221 |
+
top5 = np.argsort(fps)[::-1][:5]
|
| 222 |
+
sm['fourier_top5_freqs'] = top5.tolist()
|
| 223 |
+
sm['fourier_top5_power'] = fps[top5].tolist()
|
| 224 |
+
sm['fourier_total_power'] = float(fps.sum())
|
| 225 |
+
sm['fourier_concentration'] = float(fps[top5].sum() / fps.sum()) # how concentrated
|
| 226 |
+
|
| 227 |
+
sm['phase']=pn;sm['epoch']=ep;sm['step']=gs;sm['train_loss']=el/nb
|
| 228 |
+
hist.append(sm)
|
| 229 |
+
add_acc=sm.get('eval/add_test_acc',0)
|
| 230 |
+
task_acc=sm.get(f'eval/{pn.split("_")[-1]}_test_acc', add_acc)
|
| 231 |
+
print(f"[{pn}] S{gs} L:{el/nb:.4f} AddAcc:{add_acc:.3f} TaskAcc:{task_acc:.3f} "
|
| 232 |
+
f"CKA(L1vP1):{sm.get('layer_1/cka_vs_phase1',0):.3f} "
|
| 233 |
+
f"FourierConc:{sm.get('fourier_concentration',0):.3f}")
|
| 234 |
+
model.train()
|
| 235 |
+
print(f"[{pn}] Ep{ep+1}/{ne} L:{el/nb:.4f}")
|
| 236 |
+
return hist
|
| 237 |
+
|
| 238 |
+
# ============= MAIN =============
|
| 239 |
+
print("="*70)
|
| 240 |
+
print("GRADUATED DISSIMILARITY EXPERIMENT")
|
| 241 |
+
print("="*70)
|
| 242 |
+
t0=time.time()
|
| 243 |
+
dev=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 244 |
+
print(f"Device: {dev}")
|
| 245 |
+
|
| 246 |
+
p=97;seed=42;torch.manual_seed(seed);np.random.seed(seed)
|
| 247 |
+
cfg=TC(vs=p+NS,nl=2,dm=128,nh=4,dmlp=512,msl=5)
|
| 248 |
+
model=ST(cfg).to(dev)
|
| 249 |
+
print(f"Parameters: {model.nparams():,}")
|
| 250 |
+
|
| 251 |
+
s0={k:v.cpu().clone() for k,v in model.state_dict().items()}
|
| 252 |
+
lds=get_loaders(p=p,bs=512,tf=0.5,seed=seed)
|
| 253 |
+
|
| 254 |
+
# Probe data from addition test set (same across all branches)
|
| 255 |
+
dsa=MAD('add',p=p,split='test',tf=0.5,seed=seed)
|
| 256 |
+
pids_a,pla=get_probe(dsa,500)
|
| 257 |
+
ri=creps(model,pids_a,dev)
|
| 258 |
+
|
| 259 |
+
# Initial Fourier spectrum
|
| 260 |
+
init_fps = fourier_power_spectrum(model.te.weight.detach().cpu(), p)
|
| 261 |
+
|
| 262 |
+
# ========== PHASE 1: Train on Addition ==========
|
| 263 |
+
print(f"\n{'='*60}\nPHASE 1: Train on Addition ({150} epochs)\n{'='*60}")
|
| 264 |
+
opt=optim.AdamW(model.parameters(),lr=1e-3,weight_decay=1.0)
|
| 265 |
+
h1=train_phase(model,opt,lds['add_train'],150,dev,'p1_add',s0,s0,ri,ri,pids_a,lds,cfg,ce=20)
|
| 266 |
+
|
| 267 |
+
s1={k:v.cpu().clone() for k,v in model.state_dict().items()}
|
| 268 |
+
rp=creps(model,pids_a,dev)
|
| 269 |
+
p1_eval={op:ev(model,lds[f'{op}_test'],dev) for op in ALL_OPS}
|
| 270 |
+
print(f"\nPhase 1 final accuracies:")
|
| 271 |
+
for op,e in p1_eval.items():print(f" {op}: {e['acc']:.3f}")
|
| 272 |
+
|
| 273 |
+
os.makedirs('results',exist_ok=True)
|
| 274 |
+
torch.save(model.state_dict(),'results/p1.pt')
|
| 275 |
+
p1_fps = fourier_power_spectrum(model.te.weight.detach().cpu(), p)
|
| 276 |
+
|
| 277 |
+
# ========== PHASE 2: Fork into 5 branches ==========
|
| 278 |
+
branch_results = {}
|
| 279 |
+
phase2_epochs = 150
|
| 280 |
+
|
| 281 |
+
for task_name in ALL_OPS:
|
| 282 |
+
branch_name = f'p2_{task_name}'
|
| 283 |
+
level = ALL_OPS.index(task_name)
|
| 284 |
+
print(f"\n{'='*60}")
|
| 285 |
+
print(f"PHASE 2 — Branch A→{task_name.upper()} (Level {level}, {phase2_epochs} epochs)")
|
| 286 |
+
print(f"{'='*60}")
|
| 287 |
+
|
| 288 |
+
m = ST(cfg).to(dev)
|
| 289 |
+
m.load_state_dict(torch.load('results/p1.pt', weights_only=True))
|
| 290 |
+
o = optim.AdamW(m.parameters(), lr=1e-3, weight_decay=1.0)
|
| 291 |
+
h = train_phase(m, o, lds[f'{task_name}_train'], phase2_epochs, dev,
|
| 292 |
+
branch_name, s0, s1, ri, rp, pids_a, lds, cfg, ce=20)
|
| 293 |
+
|
| 294 |
+
# Final evaluation on ALL tasks
|
| 295 |
+
final_eval = {op: ev(m, lds[f'{op}_test'], dev) for op in ALL_OPS}
|
| 296 |
+
fps_final = fourier_power_spectrum(m.te.weight.detach().cpu(), p)
|
| 297 |
+
|
| 298 |
+
branch_results[task_name] = {
|
| 299 |
+
'history': h,
|
| 300 |
+
'final_eval': final_eval,
|
| 301 |
+
'fourier_spectrum_final': fps_final.tolist(),
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
+
print(f"\n Final accuracies after A→{task_name}:")
|
| 305 |
+
for op, e in final_eval.items():
|
| 306 |
+
marker = " ←TRAINED" if op == task_name else (" ←BASE" if op == 'add' else "")
|
| 307 |
+
print(f" {op}: {e['acc']:.3f}{marker}")
|
| 308 |
+
|
| 309 |
+
# Forgetting = Phase1 add accuracy - current add accuracy
|
| 310 |
+
forgetting = p1_eval['add']['acc'] - final_eval['add']['acc']
|
| 311 |
+
print(f" Addition FORGETTING: {forgetting*100:.1f}%")
|
| 312 |
+
|
| 313 |
+
# ========== PHASE 3: Cross-comparison ==========
|
| 314 |
+
print(f"\n{'='*60}\nPHASE 3: Cross-model Representation Comparison\n{'='*60}")
|
| 315 |
+
|
| 316 |
+
# Reload all branch final models for comparison
|
| 317 |
+
branch_reps = {}
|
| 318 |
+
for task_name in ALL_OPS:
|
| 319 |
+
# Re-run the final model to get representations
|
| 320 |
+
# (models were not saved to disk, so re-extract from branch_results last step)
|
| 321 |
+
# Actually let's just use the last checkpoint reps from training
|
| 322 |
+
pass
|
| 323 |
+
|
| 324 |
+
# Build summary table
|
| 325 |
+
summary = {
|
| 326 |
+
'config': {'p':p,'n_layers':2,'d_model':128,'n_heads':4,'d_mlp':512,
|
| 327 |
+
'phase1_epochs':150,'phase2_epochs':phase2_epochs,
|
| 328 |
+
'lr':1e-3,'weight_decay':1.0,'batch_size':512,
|
| 329 |
+
'train_frac':0.5,'seed':seed,'n_parameters':model.nparams()},
|
| 330 |
+
'dissimilarity_ladder': {
|
| 331 |
+
'add': {'level':0, 'description':'Identical (Fourier circuit)'},
|
| 332 |
+
'subtract': {'level':1, 'description':'Same Fourier circuit (sign flip)'},
|
| 333 |
+
'multiply': {'level':2, 'description':'Discrete-log Fourier circuit'},
|
| 334 |
+
'max': {'level':3, 'description':'Linear/ordinal circuit'},
|
| 335 |
+
'xor': {'level':4, 'description':'Bit-level circuit'},
|
| 336 |
+
},
|
| 337 |
+
'phase1_history': h1,
|
| 338 |
+
'phase1_final_eval': {op: p1_eval[op] for op in ALL_OPS},
|
| 339 |
+
'fourier_spectrum_init': init_fps.tolist(),
|
| 340 |
+
'fourier_spectrum_phase1': p1_fps.tolist(),
|
| 341 |
+
}
|
| 342 |
+
|
| 343 |
+
# Per-branch results
|
| 344 |
+
for task_name, br in branch_results.items():
|
| 345 |
+
level = ALL_OPS.index(task_name)
|
| 346 |
+
summary[f'branch_{task_name}'] = {
|
| 347 |
+
'level': level,
|
| 348 |
+
'history': br['history'],
|
| 349 |
+
'final_eval': br['final_eval'],
|
| 350 |
+
'fourier_spectrum': br['fourier_spectrum_final'],
|
| 351 |
+
'addition_forgetting': p1_eval['add']['acc'] - br['final_eval']['add']['acc'],
|
| 352 |
+
}
|
| 353 |
+
|
| 354 |
+
# Forgetting summary table
|
| 355 |
+
print("\n" + "="*60)
|
| 356 |
+
print("FORGETTING SUMMARY")
|
| 357 |
+
print("="*60)
|
| 358 |
+
print(f"{'Task':<15} {'Level':<8} {'Add Acc':<10} {'Task Acc':<10} {'Forgetting':<12} {'Grad Align':<12}")
|
| 359 |
+
print("-"*67)
|
| 360 |
+
for task_name in ALL_OPS:
|
| 361 |
+
br = branch_results[task_name]
|
| 362 |
+
level = ALL_OPS.index(task_name)
|
| 363 |
+
add_acc = br['final_eval']['add']['acc']
|
| 364 |
+
task_acc = br['final_eval'][task_name]['acc']
|
| 365 |
+
forgetting = p1_eval['add']['acc'] - add_acc
|
| 366 |
+
# Get last gradient alignment from history
|
| 367 |
+
last_ga = 0
|
| 368 |
+
for h in reversed(br['history']):
|
| 369 |
+
if f'grad_align_add_vs_{task_name}' in h:
|
| 370 |
+
last_ga = h[f'grad_align_add_vs_{task_name}']; break
|
| 371 |
+
elif task_name == 'add' and 'grad_align_add_vs_subtract' in h:
|
| 372 |
+
last_ga = h['grad_align_add_vs_subtract']; break
|
| 373 |
+
print(f"{task_name:<15} {level:<8} {add_acc:<10.3f} {task_acc:<10.3f} {forgetting*100:<12.1f}% {last_ga:<12.3f}")
|
| 374 |
+
|
| 375 |
+
summary['forgetting_table'] = {}
|
| 376 |
+
for task_name in ALL_OPS:
|
| 377 |
+
br = branch_results[task_name]
|
| 378 |
+
summary['forgetting_table'][task_name] = {
|
| 379 |
+
'level': ALL_OPS.index(task_name),
|
| 380 |
+
'addition_accuracy': br['final_eval']['add']['acc'],
|
| 381 |
+
'task_accuracy': br['final_eval'][task_name]['acc'],
|
| 382 |
+
'forgetting_pct': (p1_eval['add']['acc'] - br['final_eval']['add']['acc']) * 100,
|
| 383 |
+
}
|
| 384 |
+
|
| 385 |
+
with open('results/graduated_experiment_results.json','w') as f:
|
| 386 |
+
json.dump(summary,f,indent=2,default=str)
|
| 387 |
+
|
| 388 |
+
elapsed = time.time() - t0
|
| 389 |
+
print(f"\nExperiment completed in {elapsed/60:.1f} minutes")
|
| 390 |
+
print("Results saved to results/graduated_experiment_results.json")
|
| 391 |
+
|
| 392 |
+
# ============= GENERATE PLOTS =============
|
| 393 |
+
print("\nGenerating plots...")
|
| 394 |
+
import matplotlib;matplotlib.use('Agg');import matplotlib.pyplot as plt
|
| 395 |
+
|
| 396 |
+
colors = {'add':'#2196F3','subtract':'#4CAF50','multiply':'#FF9800','max':'#F44336','xor':'#9C27B0'}
|
| 397 |
+
labels = {'add':'Addition (L0)','subtract':'Subtraction (L1)','multiply':'Multiplication (L2)',
|
| 398 |
+
'max':'Max (L3)','xor':'XOR (L4)'}
|
| 399 |
+
|
| 400 |
+
# Plot 1: Forgetting bar chart
|
| 401 |
+
fig,ax=plt.subplots(figsize=(10,6))
|
| 402 |
+
tasks=[t for t in ALL_OPS]
|
| 403 |
+
forg=[summary['forgetting_table'][t]['forgetting_pct'] for t in tasks]
|
| 404 |
+
task_acc=[summary['forgetting_table'][t]['task_accuracy'] for t in tasks]
|
| 405 |
+
x=np.arange(len(tasks))
|
| 406 |
+
bars=ax.bar(x,forg,color=[colors[t] for t in tasks],edgecolor='black',linewidth=0.5)
|
| 407 |
+
ax.set_xticks(x);ax.set_xticklabels([labels[t] for t in tasks],rotation=15,ha='right')
|
| 408 |
+
ax.set_ylabel('Addition Accuracy Forgetting (%)');ax.set_title('Forgetting vs Task Dissimilarity Level')
|
| 409 |
+
ax.axhline(0,color='gray',ls=':',alpha=0.5)
|
| 410 |
+
for bar,ta in zip(bars,task_acc):
|
| 411 |
+
ax.text(bar.get_x()+bar.get_width()/2,bar.get_height()+0.5,f'Task:{ta:.1%}',
|
| 412 |
+
ha='center',va='bottom',fontsize=8)
|
| 413 |
+
plt.tight_layout();plt.savefig('results/forgetting_ladder.png',dpi=150);plt.close()
|
| 414 |
+
|
| 415 |
+
# Plot 2: Addition accuracy over training for each branch
|
| 416 |
+
fig,ax=plt.subplots(figsize=(12,6))
|
| 417 |
+
s1s=[h['step'] for h in h1]
|
| 418 |
+
ax.plot(s1s,[h.get('eval/add_test_acc',0) for h in h1],'k-',lw=2,label='Phase 1 (training)')
|
| 419 |
+
for task_name in ALL_OPS:
|
| 420 |
+
br=branch_results[task_name]
|
| 421 |
+
steps=[h['step']+s1s[-1] for h in br['history']]
|
| 422 |
+
acc=[h.get('eval/add_test_acc',0) for h in br['history']]
|
| 423 |
+
ax.plot(steps,acc,'-',color=colors[task_name],lw=2,label=labels[task_name])
|
| 424 |
+
ax.axvline(s1s[-1],color='gray',ls='--',alpha=0.5,label='Phase transition')
|
| 425 |
+
ax.set_xlabel('Training Step');ax.set_ylabel('Addition Test Accuracy')
|
| 426 |
+
ax.set_title('Addition Accuracy: Learning vs Forgetting');ax.legend();ax.set_ylim(-0.05,1.05)
|
| 427 |
+
plt.tight_layout();plt.savefig('results/addition_accuracy_all_branches.png',dpi=150);plt.close()
|
| 428 |
+
|
| 429 |
+
# Plot 3: CKA vs Phase 1 for each branch (layer 2 = final layer)
|
| 430 |
+
fig,axes=plt.subplots(1,3,figsize=(18,5))
|
| 431 |
+
for li,ax in enumerate(axes):
|
| 432 |
+
m_name=f'layer_{li}/cka_vs_phase1'
|
| 433 |
+
for task_name in ALL_OPS:
|
| 434 |
+
br=branch_results[task_name]
|
| 435 |
+
steps=[h['step'] for h in br['history'] if m_name in h]
|
| 436 |
+
vals=[h[m_name] for h in br['history'] if m_name in h]
|
| 437 |
+
ax.plot(steps,vals,'-',color=colors[task_name],lw=1.5,label=labels[task_name])
|
| 438 |
+
ax.set_title(f'{"Embedding" if li==0 else f"Layer {li}"}: CKA vs Phase 1 End')
|
| 439 |
+
ax.set_xlabel('Step');ax.set_ylabel('CKA');ax.legend(fontsize=7);ax.set_ylim(0,1.05)
|
| 440 |
+
plt.tight_layout();plt.savefig('results/cka_all_branches.png',dpi=150);plt.close()
|
| 441 |
+
|
| 442 |
+
# Plot 4: Gradient alignment with addition over training
|
| 443 |
+
fig,ax=plt.subplots(figsize=(12,6))
|
| 444 |
+
for task_name in ['subtract','multiply','max','xor']:
|
| 445 |
+
m_name=f'grad_align_add_vs_{task_name}'
|
| 446 |
+
# Phase 1 data
|
| 447 |
+
steps_p1=[h['step'] for h in h1 if m_name in h]
|
| 448 |
+
vals_p1=[h[m_name] for h in h1 if m_name in h]
|
| 449 |
+
ax.plot(steps_p1,vals_p1,':',color=colors[task_name],lw=1,alpha=0.5)
|
| 450 |
+
# Branch data (training on this task)
|
| 451 |
+
br=branch_results[task_name]
|
| 452 |
+
steps=[h['step']+s1s[-1] for h in br['history'] if m_name in h]
|
| 453 |
+
vals=[h[m_name] for h in br['history'] if m_name in h]
|
| 454 |
+
ax.plot(steps,vals,'-',color=colors[task_name],lw=2,label=labels[task_name])
|
| 455 |
+
ax.axvline(s1s[-1],color='gray',ls='--',alpha=0.5)
|
| 456 |
+
ax.axhline(0,color='gray',ls=':',alpha=0.3)
|
| 457 |
+
ax.set_xlabel('Step');ax.set_ylabel('Gradient Cosine Similarity')
|
| 458 |
+
ax.set_title('Gradient Alignment with Addition Task');ax.legend()
|
| 459 |
+
plt.tight_layout();plt.savefig('results/gradient_alignment_all.png',dpi=150);plt.close()
|
| 460 |
+
|
| 461 |
+
# Plot 5: Fourier spectrum comparison
|
| 462 |
+
fig,axes=plt.subplots(2,3,figsize=(18,10))
|
| 463 |
+
freqs=np.arange(len(init_fps))
|
| 464 |
+
axes[0,0].bar(freqs,init_fps,color='gray',alpha=0.5);axes[0,0].set_title('Init');axes[0,0].set_ylabel('Power')
|
| 465 |
+
axes[0,1].bar(freqs,p1_fps,color='black',alpha=0.7);axes[0,1].set_title('After Phase 1 (Addition)')
|
| 466 |
+
for i,task_name in enumerate(['subtract','multiply','max','xor']):
|
| 467 |
+
ax=axes[(i+2)//3,(i+2)%3]
|
| 468 |
+
fps=np.array(branch_results[task_name]['fourier_spectrum_final'])
|
| 469 |
+
ax.bar(freqs[:len(fps)],fps,color=colors[task_name],alpha=0.7)
|
| 470 |
+
ax.set_title(f'After A→{task_name.title()}')
|
| 471 |
+
if (i+2)//3==1:ax.set_xlabel('Frequency')
|
| 472 |
+
if (i+2)%3==0:ax.set_ylabel('Power')
|
| 473 |
+
plt.suptitle('Embedding Fourier Power Spectrum Evolution',fontsize=14)
|
| 474 |
+
plt.tight_layout();plt.savefig('results/fourier_spectra.png',dpi=150);plt.close()
|
| 475 |
+
|
| 476 |
+
# Plot 6: Subspace angles
|
| 477 |
+
fig,ax=plt.subplots(figsize=(12,6))
|
| 478 |
+
for task_name in ALL_OPS:
|
| 479 |
+
m_name='layer_2/subspace_angle_vs_phase1'
|
| 480 |
+
br=branch_results[task_name]
|
| 481 |
+
steps=[h['step'] for h in br['history'] if m_name in h]
|
| 482 |
+
vals=[h[m_name] for h in br['history'] if m_name in h]
|
| 483 |
+
ax.plot(steps,vals,'-',color=colors[task_name],lw=2,label=labels[task_name])
|
| 484 |
+
ax.set_xlabel('Step');ax.set_ylabel('Subspace Angle (degrees)')
|
| 485 |
+
ax.set_title('Final Layer Subspace Angle Drift from Phase 1');ax.legend()
|
| 486 |
+
plt.tight_layout();plt.savefig('results/subspace_angles_all.png',dpi=150);plt.close()
|
| 487 |
+
|
| 488 |
+
print("All plots generated!")
|
| 489 |
+
|
| 490 |
+
# Upload to HF
|
| 491 |
+
from huggingface_hub import HfApi
|
| 492 |
+
api=HfApi(token=os.environ.get('HF_TOKEN'))
|
| 493 |
+
repo='tekkmaven/representation-learning-dynamics'
|
| 494 |
+
files_to_upload = [
|
| 495 |
+
'graduated_experiment_results.json',
|
| 496 |
+
'forgetting_ladder.png','addition_accuracy_all_branches.png',
|
| 497 |
+
'cka_all_branches.png','gradient_alignment_all.png',
|
| 498 |
+
'fourier_spectra.png','subspace_angles_all.png',
|
| 499 |
+
]
|
| 500 |
+
for f in files_to_upload:
|
| 501 |
+
try:
|
| 502 |
+
api.upload_file(path_or_fileobj=f'results/{f}',path_in_repo=f'results/{f}',repo_id=repo,repo_type='model')
|
| 503 |
+
print(f"Uploaded {f}")
|
| 504 |
+
except Exception as e:
|
| 505 |
+
print(f"Failed to upload {f}: {e}")
|
| 506 |
+
|
| 507 |
+
print(f"\n=== EXPERIMENT COMPLETE ===")
|
| 508 |
+
print(f"Total time: {(time.time()-t0)/60:.1f} minutes")
|
| 509 |
+
print(f"Repository: https://huggingface.co/{repo}")
|