Spaces:
Sleeping
Sleeping
| # stage6.py | |
| # Author: Liam Grinstead | |
| # Purpose: ViT-Base (Full ImageNet-1K) Validation (Stage Six of Twelve) | |
| import os, math, time, json, random, argparse | |
| import torch, torch.nn as nn | |
| import torchvision, torchvision.transforms as T | |
| # ---------------- Determinism ---------------- | |
def set_seed(s=1234):
    """Seed the Python and PyTorch random number generators.

    The same seed is applied, in order, to Python's ``random`` module, the
    PyTorch CPU generator, and the generators of all CUDA devices.

    Args:
        s: Integer seed shared by every generator.
    """
    for seeder in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(s)
| # ---------------- Telemetry ------------------ | |
class Telemetry:
    """JSON-lines logger that mirrors every record to stdout.

    The log file is opened in write mode (truncating any existing file) when
    the object is created and stays open until :meth:`close` is called.

    Attributes:
        t0: Wall-clock time (``time.time()``) captured at construction.
        f: Open text file handle that records are written to.
    """

    def __init__(self, path="stage6_vit_base.jsonl"):
        self.t0 = time.time()
        self.f = open(path, "w")

    def emit(self, **k):
        """Write one record built from the keyword arguments.

        A ``"t"`` field holding the seconds elapsed since construction
        (rounded to 3 decimals) is set on the record, replacing any ``t``
        passed by the caller. The record is serialised as compact JSON,
        printed, written as one line, and flushed immediately.
        """
        record = dict(k)
        record["t"] = round(time.time() - self.t0, 3)
        encoded = json.dumps(record, separators=(",", ":"))
        print(encoded)
        self.f.write(f"{encoded}\n")
        self.f.flush()

    def close(self):
        """Close the underlying log file."""
        self.f.close()
| # ---------------- Orbital Coupler ------------ | |
class Orbital:
    """Pair of phase angles ``a`` and ``b`` that are pulled toward each other.

    Each call to :meth:`step` moves ``a`` by ``+g*sin(d)`` and ``b`` by
    ``-g*sin(d)``, where ``d`` is the wrapped difference ``b - a``. Both
    angles are kept in ``[0, 2*pi)``.

    Attributes:
        a: First phase angle, starts at 0.
        b: Second phase angle, starts at ``pi/3``.
        g: Coupling gain applied to ``sin(d)`` on every step.
        floor: Minimum magnitude used for ``d`` when computing the update,
            so the coupling term never vanishes.
    """

    def __init__(self, g=0.006, floor=0.2):
        self.a = 0.0
        self.b = math.pi / 3
        self.g = g
        self.floor = floor

    @staticmethod
    def _wrap(angle):
        """Map *angle* into the interval ``[-pi, pi)``."""
        return (angle + math.pi) % (2 * math.pi) - math.pi

    def step(self):
        """Advance both phases by one coupling update.

        Returns:
            A ``(drift, flux)`` tuple: ``drift`` is the absolute wrapped
            difference between ``a`` and ``b`` after the update, and ``flux``
            is ``abs(sin(d))`` for the (floored) difference used in it.
        """
        d = self._wrap(self.b - self.a)
        if abs(d) < self.floor:
            # Clamp tiny differences to +/-floor, keeping the sign (>= 0 -> +).
            d = self.floor if d >= 0 else -self.floor
        s = math.sin(d)
        two_pi = 2 * math.pi
        self.a = (self.a + self.g * s) % two_pi
        self.b = (self.b - self.g * s) % two_pi
        # Fixed: the original computed ``2*math*pi`` here, which raised
        # NameError (``pi`` is undefined) on the very first call.
        drift = abs(self._wrap(self.a - self.b))
        return drift, abs(s)
| # ---------------- DCLR Optimiser ------------- | |
class DCLR(torch.optim.Optimizer):
    """Adam-style optimiser whose per-element learning rate shrinks with gradient noise.

    For every parameter it tracks three running averages:

    * ``m``   - EMA of the gradient (decay ``beta``),
    * ``v``   - EMA of the squared gradient (decay ``gamma``),
    * ``coh`` - EMA (fixed decay 0.9) of ``|grad - m|``.

    The update is ``lr / (1 + cg * coh) * m / (sqrt(v) + eps)``. No bias
    correction is applied to ``m`` or ``v``.

    Args:
        params: Iterable of parameters or parameter-group dicts.
        lr: Base learning rate.
        beta: Decay of the first-moment average.
        gamma: Decay of the second-moment average.
        eps: Constant added to ``sqrt(v)`` in the denominator.
        cg: Coefficient multiplying ``coh`` in the learning-rate damping term.
    """

    def __init__(self, params, lr=5e-4, beta=0.9, gamma=0.999, eps=1e-8, cg=0.05):
        defaults = dict(lr=lr, beta=beta, gamma=gamma, eps=eps, cg=cg)
        super().__init__(params, defaults)

    # Fixed: the update used to run with autograd enabled, so the in-place
    # ``p.add_`` on a leaf parameter that requires grad raised RuntimeError.
    @torch.no_grad()
    def step(self, closure=None):
        """Apply one update to every parameter that has a gradient.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss. It is run with gradients enabled before the update.

        Returns:
            A ``(loss, tot)`` tuple. ``loss`` is the closure's return value
            (``None`` when no closure is given, as before) and ``tot`` is the
            sum of squared update values over all updated parameters, as a
            Python float.
        """
        loss = None
        if closure is not None:
            # Previously the closure was silently ignored.
            with torch.enable_grad():
                loss = closure()

        tot = 0.0
        for group in self.param_groups:
            lr = group["lr"]
            beta = group["beta"]
            gamma = group["gamma"]
            eps = group["eps"]
            c = group["cg"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                st = self.state[p]
                if not st:
                    # Lazily create the per-parameter buffers on first use.
                    st["m"] = torch.zeros_like(p)
                    st["v"] = torch.zeros_like(p)
                    st["coh"] = torch.zeros_like(p)
                m, v, h = st["m"], st["v"], st["coh"]
                grad = p.grad
                m.mul_(beta).add_(grad, alpha=1 - beta)
                v.mul_(gamma).addcmul_(grad, grad, value=1 - gamma)
                # Deviation of the current gradient from its running mean.
                h.mul_(0.9).add_((grad - m).abs(), alpha=0.1)
                lr_eff = lr / (1 + c * h)
                delta = lr_eff * m / (v.sqrt() + eps)
                p.add_(-delta)
                tot += (delta * delta).sum().item()
        return loss, tot
| # ---------------- ViT-Base ------------------- | |
class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and project each to ``dim``.

    The projection is a single ``Conv2d`` whose kernel size and stride both
    equal the patch size.

    Args:
        img: Side length used to compute the patch count ``n``.
        patch: Patch side length (conv kernel size and stride).
        in_ch: Number of input channels.
        dim: Embedding size (conv output channels).

    Attributes:
        proj: The patch-projection convolution.
        n: ``(img // patch) ** 2``, the number of patches for an
            ``img`` x ``img`` input.
    """

    def __init__(self, img=224, patch=16, in_ch=3, dim=768):
        super().__init__()
        self.proj = nn.Conv2d(in_ch, dim, kernel_size=patch, stride=patch)
        per_side = img // patch
        self.n = per_side * per_side

    def forward(self, x):
        """Return patch embeddings with the spatial grid flattened into dim 1."""
        feat = self.proj(x)
        # Collapse everything from dim 2 onward, then swap dims 1 and 2 so the
        # channel (embedding) dimension comes last.
        tokens = feat.flatten(start_dim=2)
        return tokens.transpose(1, 2)
class Block(nn.Module):
    """Pre-norm transformer encoder block: self-attention then an MLP.

    Both sub-layers are residual: the sub-layer is applied to the
    layer-normalised input and the result is added to the un-normalised input.

    Args:
        dim: Token embedding size.
        heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
    """

    def __init__(self, dim=768, heads=12, mlp_ratio=4):
        super().__init__()
        hidden = int(dim * mlp_ratio)
        self.n1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.n2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, x):
        """Apply the attention and MLP sub-layers, each with a skip connection."""
        # Attention sub-layer: x + attn(norm(x)).
        h = x
        x = self.n1(x)
        x, _ = self.attn(x, x, x, need_weights=False)
        x = x + h
        # MLP sub-layer: x + mlp(norm(x)).
        # Fixed: the original saved the skip input but then added the MLP
        # output to the *normalised* tensor, discarding the residual stream.
        h = x
        return h + self.mlp(self.n2(h))
class ViTBase(nn.Module):
    """Vision Transformer classifier built from ``PatchEmbed`` and ``Block``.

    A learnable class token is prepended to the patch tokens, learnable
    position embeddings are added, the sequence passes through ``depth``
    blocks and a final LayerNorm, and the class-token output is fed to a
    linear head.

    Args:
        num_classes: Output size of the classification head.
        img: Image side length passed to ``PatchEmbed``.
        patch: Patch side length passed to ``PatchEmbed``.
        dim: Token embedding size.
        depth: Number of transformer blocks.
        heads: Attention heads per block.
        mlp_ratio: MLP hidden-width multiplier per block.
    """

    def __init__(self, num_classes=1000, img=224, patch=16, dim=768, depth=12, heads=12, mlp_ratio=4):
        super().__init__()
        self.pe = PatchEmbed(img, patch, 3, dim)
        self.cls = nn.Parameter(torch.zeros(1, 1, dim))
        # One position per patch plus one for the class token.
        self.pos = nn.Parameter(torch.zeros(1, 1 + self.pe.n, dim))
        layers = []
        for _ in range(depth):
            layers.append(Block(dim, heads, mlp_ratio))
        self.blocks = nn.ModuleList(layers)
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)
        for param in (self.cls, self.pos):
            nn.init.trunc_normal_(param, std=0.02)

    def forward(self, x):
        """Return class logits for a batch of images."""
        batch_size = x.size(0)
        patches = self.pe(x)
        cls_tok = self.cls.expand(batch_size, -1, -1)
        seq = torch.cat([cls_tok, patches], dim=1)
        # Slice the position table to the actual sequence length
        # (patch count + 1 for the class token).
        seq = seq + self.pos[:, : patches.size(1) + 1]
        for layer in self.blocks:
            seq = layer(seq)
        seq = self.norm(seq)
        # Classify from the class token, which sits at index 0.
        return self.head(seq[:, 0])
| # ---------------- Data ----------------------- | |
def get_loaders(data_dir, batch=256, img=224, workers=8):
    """Build the training and validation ``DataLoader`` objects.

    Images are read with ``ImageFolder`` from ``<data_dir>/train`` and
    ``<data_dir>/val``. Both splits are resized to ``img`` x ``img``,
    converted to tensors and normalised with the same mean/std; only the
    training split gets a random horizontal flip.

    Args:
        data_dir: Directory containing the ``train`` and ``val`` sub-folders.
        batch: Batch size for both loaders.
        img: Side length images are resized to.
        workers: Number of worker processes per loader.

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles; the
        validation loader does not. Both use pinned memory.
    """
    mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
    train_tf = T.Compose([
        T.Resize((img, img)),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean, std),
    ])
    # Fixed: validation previously reused the training transform, so the
    # random flip made evaluation results non-deterministic.
    val_tf = T.Compose([
        T.Resize((img, img)),
        T.ToTensor(),
        T.Normalize(mean, std),
    ])
    train = torchvision.datasets.ImageFolder(os.path.join(data_dir, "train"), transform=train_tf)
    val = torchvision.datasets.ImageFolder(os.path.join(data_dir, "val"), transform=val_tf)
    tr = torch.utils.data.DataLoader(
        train, batch_size=batch, shuffle=True, num_workers=workers, pin_memory=True
    )
    va = torch.utils.data.DataLoader(
        val, batch_size=batch, shuffle=False, num_workers=workers, pin_memory=True
    )
    return tr, va
| # ---------------- Evaluate ------------------- | |
def evaluate(model, loader, dev):
    """Compute mean cross-entropy loss and top-1 accuracy over a loader.

    The model is switched to eval mode (and left there) and gradients are
    disabled for the whole pass.

    Args:
        model: Module mapping an input batch to class logits.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        dev: Device the batches are moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy)``, both averaged per sample. An empty loader
        yields ``(0.0, 0.0)`` because the divisor is clamped to at least 1.
    """
    criterion = nn.CrossEntropyLoss()
    seen = 0
    hits = 0
    loss_total = 0.0
    model.eval()
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(dev)
            targets = targets.to(dev)
            logits = model(inputs)
            n = inputs.size(0)
            # The criterion returns a batch mean; weight by batch size so the
            # final value is a true per-sample average.
            loss_total += criterion(logits, targets).item() * n
            hits += (logits.argmax(1) == targets).sum().item()
            seen += n
    denom = max(1, seen)
    return loss_total / denom, hits / denom
| # ---------------- Runner --------------------- | |
def run(mode="RFT", data_dir=None, epochs=10, batch=256, lr=5e-4, log="stage6_vit_base.jsonl"):
    """Train ``ViTBase`` and log per-step and per-epoch telemetry as JSONL.

    Args:
        mode: ``"RFT"`` selects the ``DCLR`` optimiser; any other value
            selects ``torch.optim.Adam``. The value is also written into
            every telemetry record.
        data_dir: Dataset root passed to ``get_loaders``.
        epochs: Number of training epochs; validation runs after each one.
        batch: Batch size passed to ``get_loaders``.
        lr: Learning rate for the chosen optimiser.
        log: Path of the telemetry file.

    Returns:
        A completion message naming the telemetry file.
    """
    set_seed(1234)
    tm = Telemetry(log)
    # Fixed: the telemetry file used to stay open if data loading, training
    # or evaluation raised; it is now closed on every exit path.
    try:
        orb = Orbital()
        dev = "cuda" if torch.cuda.is_available() else "cpu"
        tr, val = get_loaders(data_dir, batch)
        model = ViTBase(num_classes=1000).to(dev)
        if mode == "RFT":
            opt = DCLR(model.parameters(), lr=lr)
        else:
            opt = torch.optim.Adam(model.parameters(), lr=lr)
        ce = nn.CrossEntropyLoss()
        # bfloat16 autocast is used only on CUDA devices that support it.
        use_bf16 = (dev == "cuda" and torch.cuda.is_bf16_supported())
        for ep in range(1, epochs + 1):
            model.train()
            for i, (x, y) in enumerate(tr):
                drift, flux = orb.step()
                x, y = x.to(dev), y.to(dev)
                opt.zero_grad(set_to_none=True)
                if use_bf16:
                    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                        out = model(x)
                        loss = ce(out, y)
                else:
                    out = model(x)
                    loss = ce(out, y)
                loss.backward()
                if isinstance(opt, DCLR):
                    # DCLR.step returns (loss, sum of squared updates).
                    _, J = opt.step()
                else:
                    opt.step()
                    J = 0.0
                acc = (out.argmax(1) == y).float().mean().item()
                # NOTE: E_ret and coh are emitted as fixed constants here,
                # not values measured during training.
                tm.emit(mode=mode, epoch=ep, step=i + 1, drift=round(drift, 3), flux=round(flux, 3),
                        E_ret=0.994, coh=0.999, loss=round(float(loss.item()), 4),
                        acc=round(float(acc), 3), J_step=round(float(J * 1e-6), 6))
            vl, va = evaluate(model, val, dev)
            tm.emit(tag="eval", epoch=ep, val_loss=round(float(vl), 4), val_acc=round(float(va), 3), mode=mode)
    finally:
        tm.close()
    return f"Stage 6 complete. Telemetry saved to {log}"