"""PG-v2: SP8192+ParallelRes+DepthRec+TTT+Int6GPTQ+EMA""" from __future__ import annotations import copy,glob,io,math,os,random,subprocess,sys,time,uuid,zlib from pathlib import Path import numpy as np import sentencepiece as spm import torch,torch.distributed as dist,torch.nn.functional as F from torch import Tensor,nn from torch.nn.parallel import DistributedDataParallel as DDP class H: dp=os.environ.get("DATA_PATH","./data/datasets/fineweb10B_sp8192") tf=os.path.join(dp,"fineweb_train_*.bin") vf=os.path.join(dp,"fineweb_val_*.bin") tp=os.environ.get("TOKENIZER_PATH","./data/tokenizers/fineweb_8192_bpe.model") rid=os.environ.get("RUN_ID",str(uuid.uuid4())) seed=int(os.environ.get("SEED","1337")) vbs=int(os.environ.get("VBS","524288"));vle=int(os.environ.get("VLE","1000")) tle=int(os.environ.get("TLE","200")) iters=int(os.environ.get("ITERS","20000")) wdi=int(os.environ.get("WDI","3500"));wui=int(os.environ.get("WUI","20")) tbt=int(os.environ.get("TBT","524288"));tsl=int(os.environ.get("TSL","1024")) mws=float(os.environ.get("MWS","600.0")) V=int(os.environ.get("V","8192"));D=int(os.environ.get("D","768")) nh=int(os.environ.get("NH","12"));nkv=int(os.environ.get("NKV","4")) mm=int(os.environ.get("MM","4")) nul=int(os.environ.get("NUL","3"));nr=int(os.environ.get("NR","8")) ner=int(os.environ.get("NER","0")) rb=float(os.environ.get("RB","10000.0")) lsc=float(os.environ.get("LSC","30.0")) qkg=float(os.environ.get("QKG","5.25")) sws=int(os.environ.get("SWS","64"));swl=int(os.environ.get("SWL","1024")) tte=int(os.environ.get("TTE","1")) ttlr=float(os.environ.get("TTLR","0.01")) ttcs=int(os.environ.get("TTCS","64")) ttly=os.environ.get("TTLY","all") elr=float(os.environ.get("ELR","0.05")) mlr=float(os.environ.get("MLR","0.04")) slr=float(os.environ.get("SLR","0.04")) mmo=float(os.environ.get("MMO","0.95")) mbs=int(os.environ.get("MBS","5")) mwd=float(os.environ.get("MWD","0.09")) b1=float(os.environ.get("B1","0.9")) b2=float(os.environ.get("B2","0.95")) ae=float(os.environ.get("AE","1e-8")) gb=int(os.environ.get("GB","6")) sdn=float(os.environ.get("SDN","2.5")) esf=float(os.environ.get("ESF","0.4")) CP=tuple(p for p in "attn_scale,mlp_scale,resid_mix,q_gain".split(",") if p) def zp5(G,s=10,e=1e-7): a,b,c=3.4445,-4.7750,2.0315 X=G.bfloat16();X/=X.norm()+e tr=G.size(0)>G.size(1) if tr:X=X.T for _ in range(s): A=X@X.T;B=b*A+c*A@A;X=a*X+B@X return X.T if tr else X class Muon(torch.optim.Optimizer): def __init__(s,p,lr,mom,bs,wd=0.,nest=True): super().__init__(p,dict(lr=lr,mom=mom,bs=bs,wd=wd,nest=nest)) @torch.no_grad() def step(s,cl=None): lo=None if cl: with torch.enable_grad():lo=cl() dd=dist.is_available() and dist.is_initialized() ws=dist.get_world_size() if dd else 1 rk=dist.get_rank() if dd else 0 for g in s.param_groups: ps=g["params"];lr=g["lr"];mo=g["mom"];bs=g["bs"];wd=g["wd"];ne=g["nest"] tot=sum(int(p.numel()) for p in ps) fl=torch.zeros(tot,device=ps[0].device,dtype=torch.bfloat16) cur=0 for i,p in enumerate(ps): if i%ws==rk and p.grad is not None: gr=p.grad if wd:gr=gr+wd*p.data.to(gr.dtype) st=s.state[p] if "mb" not in st:st["mb"]=torch.zeros_like(gr) buf=st["mb"];buf.mul_(mo).add_(gr) if ne:gr=gr.add(buf,alpha=mo) gr=zp5(gr,steps=bs) gr*=max(1,gr.size(0)/gr.size(1))**0.5 fl[cur:cur+p.numel()]=gr.reshape(-1) cur+=p.numel() if dd:dist.all_reduce(fl,op=dist.ReduceOp.SUM) cur=0 for p in ps: gr=fl[cur:cur+p.numel()].view_as(p).to(dtype=p.dtype) p.add_(gr,alpha=-lr);cur+=p.numel() return lo def build_sp_luts(sp,vs,dev): sv=int(sp.vocab_size());sz=max(sv,vs) 
def build_sp_luts(sp,vs,dev):
    """Per-token lookup tables for bits-per-byte accounting: UTF-8 byte
    length, leading-space flag, and an ignore mask (control/unknown/unused)."""
    sv=int(sp.vocab_size());sz=max(sv,vs)
    bb=np.zeros(sz,dtype=np.int16);hs=np.zeros(sz,dtype=bool);ib=np.ones(sz,dtype=bool)
    for t in range(sv):
        if sp.is_control(t) or sp.is_unknown(t) or sp.is_unused(t):continue
        ib[t]=False
        if sp.is_byte(t):bb[t]=1;continue
        pc=sp.id_to_piece(t)
        if pc.startswith("\u2581"):hs[t]=True;pc=pc[1:]
        bb[t]=len(pc.encode("utf-8"))
    return(torch.tensor(bb,dtype=torch.int16,device=dev),
           torch.tensor(hs,dtype=torch.bool,device=dev),
           torch.tensor(ib,dtype=torch.bool,device=dev))

def eval_sw(a,mdl,rk,ws,dev,vt,bbl,hsl,ibl,ttt=False):
    """Sliding-window eval: stride sws over windows of length swl, scoring
    only the last sws tokens of each window; returns (loss, bits/byte)."""
    sl=a.swl;st=a.sws;T=vt.numel()
    starts=list(range(0,T-sl-1,st))
    my=starts[rk::ws]
    ls=torch.zeros((),device=dev,dtype=torch.float64)
    tc=torch.zeros((),device=dev,dtype=torch.float64)
    bc=torch.zeros((),device=dev,dtype=torch.float64)
    rm=mdl
    while hasattr(rm,'module'):rm=rm.module
    if hasattr(rm,'_orig_mod'):rm=rm._orig_mod
    rm.eval()
    # TTT mutates weights mid-eval, so it needs no_grad, not inference_mode.
    ctx=torch.no_grad if ttt else torch.inference_mode
    with ctx():
        for s in my:
            e=s+sl
            x=vt[s:e].unsqueeze(0).to(dev,dtype=torch.int64)
            y=vt[s+1:e+1].unsqueeze(0).to(dev,dtype=torch.int64)
            with torch.autocast("cuda",dtype=torch.bfloat16):
                if ttt and a.tte:ptl=rm.ptl_ttt(x,y,a)
                else:ptl=rm.ptl(x,y)
            lo=sl-st;ps=ptl[0,lo:];ys=y[0,lo:];xs=x[0,lo:]
            ls+=ps.to(torch.float64).sum();tc+=ps.numel()
            tb=bbl[ys].to(torch.float64)
            # A leading-space piece contributes its space byte unless the
            # preceding token is an ignored (control/unknown) token.
            tb+=(hsl[ys]&~ibl[xs]).to(torch.float64)
            bc+=tb.sum()
    if dist.is_available() and dist.is_initialized():
        for t in(ls,tc,bc):dist.all_reduce(t,op=dist.ReduceOp.SUM)
    vl=float((ls/tc).item());bpb=float((ls/math.log(2)/bc).item())
    rm.train();return vl,bpb

def sdclip(t,n=2.5):
    """Clamp to mean +/- n standard deviations."""
    m=t.float().mean();s=t.float().std()
    return t.clamp((m-n*s).item(),(m+n*s).item())

def qi6(t,ns=2.5):
    """Symmetric int6 quantization (range +/-31); per-row scales for
    matrices, a single scale otherwise."""
    t32=t.float();mx=31
    if t32.ndim==2:
        m=t32.mean(1,keepdim=True);s=t32.std(1,keepdim=True).clamp_min(1e-9)
        lo=m-ns*s;hi=m+ns*s
        tc=t32.clamp(lo.expand_as(t32),hi.expand_as(t32))
        cv=tc.abs().amax(1).clamp_min(1e-9);sc=cv/mx
        q=torch.clamp(torch.round(tc/sc[:,None]),-mx,mx).to(torch.int8)
        return q.contiguous(),sc.to(torch.float16).contiguous()
    tc=sdclip(t32,ns);cv=float(tc.abs().max().item())
    sc=torch.tensor(max(cv/mx,1./mx),dtype=torch.float32)
    q=torch.clamp(torch.round(tc/sc),-mx,mx).to(torch.int8)
    return q.contiguous(),sc
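# Illustration only: round-trip a weight matrix through qi6 and inspect the
# relative error; per-row scales keep it modest for Gaussian-like weights.
if False:
    w=torch.randn(256,512)*0.02
    q,s=qi6(w)
    w2=q.float()*s.float()[:,None]
    print((w2-w).norm().item()/w.norm().item())  # typically a few percent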
def qi8(t,ns=2.5):
    """Symmetric int8 quantization; per-row scales for matrices."""
    t32=t.float()
    if t32.ndim==2:
        m=t32.mean(1,keepdim=True);s=t32.std(1,keepdim=True).clamp_min(1e-9)
        lo=m-ns*s;hi=m+ns*s
        tc=t32.clamp(lo.expand_as(t32),hi.expand_as(t32))
        cv=tc.abs().amax(1).clamp_min(1e-9);sc=cv/127.
        q=torch.clamp(torch.round(tc/sc[:,None]),-127,127).to(torch.int8)
        return q.contiguous(),sc.to(torch.float16).contiguous()
    cv=float(sdclip(t32,ns).abs().max().item())
    sc=torch.tensor(max(cv/127.,1./127.),dtype=torch.float32)
    q=torch.clamp(torch.round(t32.clamp(-cv,cv)/sc),-127,127).to(torch.int8)
    return q.contiguous(),sc

def qsd(sd,gb=6,ns=2.5):
    """Quantize a state dict: int8 embeddings, full precision for critical
    or small tensors, int{gb} for everything else."""
    qf=qi6 if gb==6 else qi8
    qu,sc,dt,pt,po,qm={},{},{},{},{},{}
    st={k:0 for k in("pc","nt","bb","qb")}
    for n,t in sd.items():
        t=t.detach().cpu().contiguous()
        st["pc"]+=t.numel();st["nt"]+=1;st["bb"]+=t.numel()*t.element_size()
        if not t.is_floating_point():pt[n]=t;st["qb"]+=t.numel()*t.element_size();continue
        ic=any(p in n for p in CP);ism=t.numel()<=65536
        # Fixed: the source only checked "tok_emb", which never matches this
        # model's tied embedding key (te.weight), so the int8 branch was dead.
        if "tok_emb" in n or n.endswith("te.weight"):
            po[n]=str(t.dtype).removeprefix("torch.")
            q,s=qi8(t,ns);qu[n]=q;sc[n]=s;dt[n]=po[n]
            if s.ndim>0:qm[n]={"scheme":"per_row","axis":0,"bits":8}
            st["qb"]+=q.numel()+s.numel()*s.element_size();continue
        if ic or ism:
            if t.dtype in(torch.float32,torch.bfloat16):po[n]=str(t.dtype).removeprefix("torch.")
            pt[n]=t.float() if ic else t.to(torch.float16)
            pt[n]=pt[n].contiguous();st["qb"]+=pt[n].numel()*pt[n].element_size();continue
        q,s=qf(t,ns)
        if s.ndim>0:qm[n]={"scheme":"per_row","axis":0,"bits":gb}
        qu[n]=q;sc[n]=s;dt[n]=str(t.dtype).removeprefix("torch.")
        st["qb"]+=q.numel()+s.numel()*s.element_size()
    obj={"__qf__":f"i{gb}sd","q":qu,"s":sc,"d":dt,"p":pt}
    if qm:obj["m"]=qm
    if po:obj["o"]=po
    return obj,st

def dqsd(obj):
    """Dequantize the object produced by qsd back into a state dict."""
    out={};qm=obj.get("m",{});po=obj.get("o",{})
    for n,q in obj["q"].items():
        dt=getattr(torch,obj["d"][n]);s=obj["s"][n]
        if qm.get(n,{}).get("scheme")=="per_row" or s.ndim>0:
            s=s.to(torch.float32)
            out[n]=(q.float()*s.view(q.shape[0],*([1]*(q.ndim-1)))).to(dt).contiguous()
        else:out[n]=(q.float()*float(s.item())).to(dt).contiguous()
    for n,t in obj["p"].items():
        ot=t.detach().cpu().contiguous();od=po.get(n)
        if isinstance(od,str):ot=ot.to(dtype=getattr(torch,od)).contiguous()
        out[n]=ot
    return out

def lds(f):
    # NOTE: the body of lds and the header of class TS were lost in the
    # source (text between the '<' in the dtype string and the next '>' was
    # stripped). Reconstruction below assumes the llm.c/fineweb .bin
    # convention: a 256-int32 header whose third field is the token count,
    # followed by little-endian uint16 tokens.
    h=np.fromfile(f,dtype="<i4",count=256)
    nt=int(h[2])
    t=np.fromfile(f,dtype="<u2",offset=256*4,count=nt)
    return torch.from_numpy(t.astype(np.int32))

class TS:
    """Infinite token stream over a glob of .bin shards (reconstructed
    wrapper around the surviving take() body)."""
    def __init__(s,pat):
        s.fs=sorted(glob.glob(pat));assert s.fs,pat
        s.i=0;s.t=lds(s.fs[0]);s.p=0
    def take(s,n):
        r=n;ch=[]
        while r>0:
            av=s.t.numel()-s.p
            if av<=0:s.i=(s.i+1)%len(s.fs);s.t=lds(s.fs[s.i]);s.p=0;av=s.t.numel()
            k=min(r,av);ch.append(s.t[s.p:s.p+k]);s.p+=k;r-=k
        return ch[0] if len(ch)==1 else torch.cat(ch)

class DTL:
    """Distributed token loader: each rank takes a disjoint slice of the
    shared stream."""
    def __init__(s,pat,rk,ws,dev):s.rk=rk;s.ws=ws;s.dev=dev;s.st=TS(pat)
    def nb(s,gt,sl,ga):
        lt=gt//(s.ws*ga);ps=lt+1
        ch=s.st.take(ps*s.ws);st=s.rk*ps
        lo=ch[st:st+ps].to(torch.int64)
        x=lo[:-1].reshape(-1,sl);y=lo[1:].reshape(-1,sl)
        return x.to(s.dev,non_blocking=True),y.to(s.dev,non_blocking=True)

class RN(nn.Module):
    def __init__(s,eps=None):super().__init__();s.eps=eps
    def forward(s,x):return F.rms_norm(x,(x.size(-1),),eps=s.eps)

class Rot(nn.Module):
    """Rotary position embeddings with a cached cos/sin table."""
    def __init__(s,d,b=10000.):
        super().__init__()
        s.register_buffer("if_",1./(b**(torch.arange(0,d,2,dtype=torch.float32)/d)),persistent=False)
        s._cl=0;s._c=None;s._s=None
    def forward(s,sl,dev,dt):
        if s._c is None or s._cl!=sl or s._c.device!=dev:
            t=torch.arange(sl,device=dev,dtype=s.if_.dtype)
            fr=torch.outer(t,s.if_.to(dev))
            s._c=fr.cos()[None,None,:,:];s._s=fr.sin()[None,None,:,:];s._cl=sl
        return s._c.to(dtype=dt),s._s.to(dtype=dt)

def arot(x,c,si):
    """Apply the rotary rotation to the two halves of the head dimension."""
    h=x.size(-1)//2;x1,x2=x[...,:h],x[...,h:]
    return torch.cat((x1*c+x2*si,x1*(-si)+x2*c),dim=-1)

class CSA(nn.Module):
    """Causal self-attention with grouped KV heads and learned per-head
    query gains."""
    def __init__(s,d,nh,nk,rb,qkg):
        super().__init__()
        assert d%nh==0 and nh%nk==0
        s.nh=nh;s.nk=nk;s.hd=d//nh;kd=nk*s.hd
        s.cq=nn.Linear(d,d,bias=False);s.ck=nn.Linear(d,kd,bias=False)
        s.cv=nn.Linear(d,kd,bias=False);s.pr=nn.Linear(d,d,bias=False)
        s.qg=nn.Parameter(torch.full((nh,),qkg,dtype=torch.float32))
        s.rot=Rot(s.hd,b=rb)  # fixed: Rot takes the rope base as `b`
    def forward(s,x):
        B,T,_=x.shape
        q=s.cq(x).reshape(B,T,s.nh,s.hd).transpose(1,2)
        k=s.ck(x).reshape(B,T,s.nk,s.hd).transpose(1,2)
        v=s.cv(x).reshape(B,T,s.nk,s.hd).transpose(1,2)
        q=F.rms_norm(q,(q.size(-1),));k=F.rms_norm(k,(k.size(-1),))
        c,si=s.rot(T,x.device,q.dtype)
        q=arot(q,c,si);k=arot(k,c,si)
        q=q*s.qg.to(dtype=q.dtype)[None,:,None,None]
        y=F.scaled_dot_product_attention(q,k,v,attn_mask=None,is_causal=True,
            enable_gqa=(s.nk!=s.nh))
        return s.pr(y.transpose(1,2).contiguous().reshape(B,T,-1))
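# Illustration only: arot applies a norm-preserving rotation, so token
# vectors keep their length and relative angles depend only on position
# offsets. Flip the guard on to check the helper on CPU.
if False:
    rot=Rot(64)
    c,si=rot(16,torch.device("cpu"),torch.float32)
    x=torch.randn(1,1,16,64)
    xr=arot(x,c,si)
    print(torch.allclose(xr.norm(dim=-1),x.norm(dim=-1),atol=1e-5))  # True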
class MLP(nn.Module):
    def __init__(s,d,m):
        super().__init__()
        h=d*m;s.fc=nn.Linear(d,h,bias=False);s.pr=nn.Linear(h,d,bias=False)
    def forward(s,x):return s.pr(torch.relu(s.fc(x)).square())

class PB(nn.Module):
    """Parallel residual block: attention and MLP read the same normed
    input, and x is re-mixed with the embedding stream x0 on every pass."""
    def __init__(s,d,nh,nk,mm,rb,qkg):
        super().__init__()
        s.n=RN();s.a=CSA(d,nh,nk,rb,qkg);s.m=MLP(d,mm)
        s.as_=nn.Parameter(torch.ones(d,dtype=torch.float32))
        s.ms=nn.Parameter(torch.ones(d,dtype=torch.float32))
        s.rm=nn.Parameter(torch.stack([torch.ones(d),torch.zeros(d)]).float())
    def forward(s,x,x0):
        mx=s.rm.to(x.dtype)
        x=mx[0][None,None,:]*x+mx[1][None,None,:]*x0
        h=s.n(x)
        x=x+s.as_.to(x.dtype)[None,None,:]*s.a(h)+s.ms.to(x.dtype)[None,None,:]*s.m(h)
        return x

class RGPT(nn.Module):
    """Depth-recurrent GPT: nul unique blocks applied nr times in training
    (ner times at eval), with tied input/output embeddings."""
    def __init__(s,a):
        super().__init__()
        s.lsc=a.lsc;s._tr=a.nr;s._er=a.ner or a.nr*2;s._V=a.V
        s.te=nn.Embedding(a.V,a.D)
        s.bl=nn.ModuleList([PB(a.D,a.nh,a.nkv,a.mm,a.rb,a.qkg) for _ in range(a.nul)])
        s.fn=RN();nn.init.normal_(s.te.weight,std=0.005)
    def _fh(s,ids):
        x=F.rms_norm(s.te(ids),(s.te.embedding_dim,));x0=x
        n=s._tr if s.training else s._er
        for _ in range(n):
            for b in s.bl:x=b(x,x0)
        return s.fn(x)
    def forward(s,ids,tgt):
        h=s._fh(ids);lo=F.linear(h.reshape(-1,h.size(-1)),s.te.weight)
        lo=s.lsc*torch.tanh(lo/s.lsc)
        return F.cross_entropy(lo.float(),tgt.reshape(-1),reduction="mean")
    def ptl(s,ids,tgt):
        h=s._fh(ids);B,T,D=h.shape
        lo=F.linear(h.reshape(B*T,D),s.te.weight)
        lo=s.lsc*torch.tanh(lo/s.lsc)
        return F.cross_entropy(lo.float(),tgt.reshape(B*T),reduction="none").reshape(B,T)
    @torch.no_grad()
    def ptl_ttt(s,ids,tgt,a):
        """Score-first TTT: score chunk, then update MLP W_down for next chunk."""
        cs=a.ttcs;lr=a.ttlr;B,T=ids.shape
        if a.ttly=="all":li=list(range(len(s.bl)))
        else:li=[int(x) for x in a.ttly.split(",")]
        ow={i:s.bl[i].m.pr.weight.data.clone() for i in li}
        ap=[];nc=(T+cs-1)//cs
        for ci in range(nc):
            lo=ci*cs;hi=min((ci+1)*cs,T)
            h=s._fh(ids);hc=h[:,lo:hi,:];yc=tgt[:,lo:hi]
            lg=F.linear(hc.reshape(-1,hc.size(-1)),s.te.weight)
            lg=s.lsc*torch.tanh(lg/s.lsc)
            pt=F.cross_entropy(lg.float(),yc.reshape(-1),reduction="none").reshape(B,hi-lo)
            ap.append(pt)
            # --- Reconstructed (the update body was lost in the source):
            # after scoring a chunk, take one SGD step on each selected
            # block's MLP down-projection using that chunk's loss, as the
            # docstring and the saved weights `ow` imply. ---
            if ci<nc-1:
                wl=[s.bl[i].m.pr.weight for i in li]
                with torch.enable_grad():
                    h2=s._fh(ids)[:,lo:hi,:]
                    l2=F.linear(h2.reshape(-1,h2.size(-1)),s.te.weight)
                    l2=s.lsc*torch.tanh(l2/s.lsc)
                    gl=torch.autograd.grad(F.cross_entropy(l2.float(),yc.reshape(-1)),wl)
                for w,g in zip(wl,gl):w.data.add_(g.to(w.dtype),alpha=-lr)
        for i in li:s.bl[i].m.pr.weight.data.copy_(ow[i])  # restore pre-TTT weights
        return torch.cat(ap,dim=1)

# --- NOTE: a span of the source was lost here (everything between a '<'
# and the next '>' was stripped, HTML-fashion). Missing: the EMA helper
# class (with .ap/.re/.up methods) and the setup half of main(): dist and
# device init (rk, ws, dev, dd), model construction and compile/DDP
# wrapping (mdl, raw module bm), the optimizer list `opts` whose param
# groups carry "base_lr", grad-accumulation count ga and loss scale gs,
# zero-grad helper za, the EMA instance and its start step ess, val tokens
# vt, tokenizer LUTs (bbl, hsl, ibl), logger l0, flag ma, and the source
# text `code` used for artifact sizing. The code below resumes verbatim. ---
def main():
    a=H  # hyperparameter namespace (reconstructed; the setup above is lost)
    ...  # lost setup, see the note above
    # Wall-clock budget in ms; only the tail "0 else None" of this line
    # survived in the source, so the head is reconstructed from lrm's use.
    mms=a.mws*1000. if a.mws>0 else None
    def lrm(st,el):
        # LR multiplier: flat, then linear decay over the final wdi steps.
        if a.wdi<=0:return 1.
        if mms is None:
            w=max(a.iters-a.wdi,0)
            if st<=w:return 1.
            return max((a.iters-st)/max(a.wdi,1),0.)
        # Reconstructed (assumption): under a wall-clock budget, decay over
        # the same final fraction of the budget instead of a step count.
        f=min(el/mms,1.);df=a.wdi/max(a.iters,1)
        return 1. if f<=1.-df else max((1.-f)/df,0.)
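    # Schedule illustration (assumption, using the defaults ITERS=20000,
    # WDI=3500): lrm stays at 1.0 through step 16500, then decays linearly
    # to 0.0 at step 20000; with MWS>0 the reconstructed branch applies the
    # same shape to the elapsed-time fraction instead.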
    # Warmup: run a few throwaway steps to trigger compilation/allocator
    # warmup, then restore model and optimizer state so timing starts clean.
    # (Condition reconstructed; the source is garbled here as "if w<=st0:".)
    if a.wui>0:
        im={n:t.detach().cpu().clone() for n,t in bm.state_dict().items()}
        io_=[copy.deepcopy(o.state_dict()) for o in opts]
        mdl.train();tw=DTL(a.tf,rk,ws,dev)
        for _ in range(a.wui):
            za()
            for mi in range(ga):
                if dd:mdl.require_backward_grad_sync=(mi==ga-1)
                x,y=tw.nb(a.tbt,a.tsl,ga)
                with torch.autocast("cuda",torch.bfloat16):(mdl(x,y)*gs).backward()
            for o in opts:o.step()
        za()
        bm.load_state_dict(im,strict=True)
        for o,s in zip(opts,io_):o.load_state_dict(s)
        za()
        if dd:mdl.require_backward_grad_sync=True
    tl=DTL(a.tf,rk,ws,dev);tms=0.;ss=None
    torch.cuda.synchronize();t0=time.perf_counter();step=0
    while True:
        ls=step==a.iters or(ss is not None and step>=ss)
        dv=ls or(a.vle>0 and step%a.vle==0)
        if dv:
            torch.cuda.synchronize();tms+=1000.*(time.perf_counter()-t0)
            vl,vb=eval_sw(a,mdl,rk,ws,dev,vt,bbl,hsl,ibl,ttt=False)
            l0(f"step:{step}/{a.iters} val_loss:{vl:.4f} val_bpb:{vb:.4f} train_ms:{tms:.0f} step_avg:{tms/max(step,1):.2f}ms")
            torch.cuda.synchronize();t0=time.perf_counter()
        if ls:
            if ma:
                # Final pass: EMA weights + TTT eval, then the quantized
                # artifact is sized, reloaded, and re-evaluated.
                ema.ap();l0("EMA+TTT eval...")
                vle,vbe=eval_sw(a,bm,rk,ws,dev,vt,bbl,hsl,ibl,ttt=True)
                l0(f"ema_ttt val_loss:{vle:.4f} val_bpb:{vbe:.4f}")
                sd=bm.state_dict();obj,st=qsd(sd,a.gb,a.sdn)
                buf=io.BytesIO();torch.save(obj,buf)
                cmp=zlib.compress(buf.getvalue(),level=9)
                cb=len(code.encode());mb=len(cmp);tb=cb+mb
                l0(f"artifact code:{cb} model:{mb} total:{tb} ({tb/1e6:.3f}MB) params:{st['pc']}")
                sd2=dqsd(obj);bm.load_state_dict(sd2,strict=True)
                vl2,vb2=eval_sw(a,bm,rk,ws,dev,vt,bbl,hsl,ibl,ttt=True)
                l0(f"quant+ttt val_loss:{vl2:.4f} val_bpb:{vb2:.4f}")
                ema.re()
            break
        # Under a wall-clock budget, schedule one final step once time is up.
        if ss is None and mms is not None:
            torch.cuda.synchronize()
            el=1000.*(time.perf_counter()-t0)+tms
            if el>=mms:ss=step+1
        za()
        for mi in range(ga):
            if dd:mdl.require_backward_grad_sync=(mi==ga-1)
            x,y=tl.nb(a.tbt,a.tsl,ga)
            with torch.autocast("cuda",torch.bfloat16):(mdl(x,y)*gs).backward()
        torch.cuda.synchronize();el=1000.*(time.perf_counter()-t0)+tms
        m=lrm(step,el)
        for o in opts:
            for g in o.param_groups:g["lr"]=g["base_lr"]*m
        for o in opts:o.step()
        if step>=ess:ema.up()
        if step%a.tle==0 and ma:l0(f"step:{step} lr_mul:{m:.4f}")
        step+=1

if __name__=="__main__":main()
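# Example launch (illustration; the file name and world size are
# assumptions, not given in the source):
#   DATA_PATH=./data/datasets/fineweb10B_sp8192 \
#   TOKENIZER_PATH=./data/tokenizers/fineweb_8192_bpe.model \
#   ITERS=20000 torchrun --standalone --nproc_per_node=8 pg_v2.py
# Every hyperparameter in class H can be overridden the same way via its
# environment variable (e.g. TTE=0 disables test-time training at eval).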