Spaces:
Sleeping
Sleeping
Create stage10.py
Browse files- stage10.py +149 -0
stage10.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# stage10.py
|
| 2 |
+
# Author: Liam Grinstead
|
| 3 |
+
# Purpose: RFT-GPT-30B (8× A100, DDP) Validation — Stage Ten of Twelve
|
| 4 |
+
|
| 5 |
+
import os, math, time, json, random, argparse
|
| 6 |
+
import torch, torch.nn as nn
|
| 7 |
+
import torch.distributed as dist
|
| 8 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 9 |
+
from contextlib import nullcontext
|
| 10 |
+
|
| 11 |
+
# ---------------- Determinism ----------------
|
| 12 |
+
def set_seed(s=1234):
    """Seed the Python, torch CPU and torch CUDA random generators with ``s``.

    Args:
        s: Integer seed handed unchanged to each generator.
    """
    # Same call order as before: stdlib RNG, torch default RNG, all CUDA RNGs.
    for seed_fn in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(s)
|
| 14 |
+
|
| 15 |
+
# ---------------- Telemetry ------------------
|
| 16 |
+
class Telemetry:
    """JSON-lines logger that echoes every record to stdout and to a file.

    Each call to :meth:`emit` writes one compact JSON object per line and
    flushes immediately, so the log stays usable if the process dies.

    The instance can be used as a context manager so the file is closed even
    when the logging caller raises::

        with Telemetry("run.jsonl") as tm:
            tm.emit(step=1)
    """

    def __init__(self, path="stage10_gpt30b.jsonl"):
        """Open ``path`` for writing (truncating it) and start the clock.

        Args:
            path: Destination file for the JSON-lines log.
        """
        self.t0 = time.time()
        # Explicit encoding so the log does not depend on the platform locale.
        self.f = open(path, "w", encoding="utf-8")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        # Never swallow the caller's exception.
        return False

    def emit(self, **k):
        """Write one record.

        Args:
            **k: JSON-serialisable fields. A ``"t"`` field (seconds since
                construction, rounded to milliseconds) is added and replaces
                any ``t`` the caller supplied.
        """
        k["t"] = round(time.time() - self.t0, 3)
        line = json.dumps(k, separators=(",", ":"))
        print(line)
        self.f.write(line + "\n")
        self.f.flush()

    def close(self):
        """Close the underlying log file."""
        self.f.close()
|
| 24 |
+
|
| 25 |
+
# ---------------- Orbital Coupler ------------
|
| 26 |
+
class Orbital:
    """Two coupled phase angles that are nudged toward each other every step.

    ``a`` starts at 0 and ``b`` at pi/3. Each step moves both by
    ``g * sin(d)`` in opposite directions, where ``d`` is the wrapped phase
    difference ``b - a``, clamped so its magnitude is never below ``floor``.
    """

    def __init__(self, g=0.006, floor=0.2):
        """
        Args:
            g: Coupling gain applied to ``sin(d)`` on every step.
            floor: Minimum magnitude used for the phase difference, so the
                coupling term never vanishes.
        """
        self.a = 0.0
        self.b = math.pi / 3
        self.g = g
        self.floor = floor

    @staticmethod
    def _wrap(angle):
        """Map ``angle`` (radians) into the interval [-pi, pi)."""
        return (angle + math.pi) % (2 * math.pi) - math.pi

    def step(self):
        """Advance both phases once.

        Returns:
            Tuple ``(drift, flux)``: ``drift`` is the absolute wrapped
            difference between the updated phases and ``flux`` is
            ``abs(sin(d))`` for the (possibly clamped) difference used in the
            update.
        """
        d = self._wrap(self.b - self.a)
        if abs(d) < self.floor:
            # Keep the sign (zero counts as positive) but enforce the floor.
            d = self.floor * (1 if d >= 0 else -1)
        s = math.sin(d)
        self.a = (self.a + self.g * s) % (2 * math.pi)
        self.b = (self.b - self.g * s) % (2 * math.pi)
        # The original expression read `2*math*pi`, which raises as soon as
        # step() is called; the intended modulus is 2*pi.
        drift = abs(self._wrap(self.a - self.b))
        return drift, abs(s)
|
| 37 |
+
|
| 38 |
+
# ---------------- DCLR Optimiser -------------
|
| 39 |
+
class DCLR(torch.optim.Optimizer):
    """Adam-like optimiser whose per-element learning rate is damped.

    Per parameter it keeps three buffers:

    * ``m``   - EMA of the gradient (decay ``beta``)
    * ``v``   - EMA of the squared gradient (decay ``gamma``)
    * ``coh`` - EMA (fixed 0.9 / 0.1) of ``|grad - m|``

    The update is ``lr / (1 + cg * coh) * m / (sqrt(v) + eps)``; no bias
    correction is applied to ``m`` or ``v``.
    """

    def __init__(self, params, lr=3e-4, beta=0.9, gamma=0.999, eps=1e-8, cg=0.05):
        """
        Args:
            params: Iterable of parameters or parameter-group dicts.
            lr: Base learning rate.
            beta: Decay of the first-moment EMA.
            gamma: Decay of the second-moment EMA.
            eps: Added to ``sqrt(v)`` to avoid division by zero.
            cg: Gain of the ``coh`` damping term in the effective LR.
        """
        super().__init__(params, dict(lr=lr, beta=beta, gamma=gamma, eps=eps, cg=cg))

    @torch.no_grad()
    def step(self, closure=None):
        """Apply one update to every parameter that has a gradient.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss. It used to be ignored; it is now called with
                gradients enabled, as the ``torch.optim.Optimizer`` contract
                expects.

        Returns:
            Tuple ``(loss, tot)``: ``loss`` is the closure result (``None``
            when no closure is given, as before) and ``tot`` is the sum of
            squared update entries over all parameters as a Python float.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad, so re-enable autograd for the closure.
            with torch.enable_grad():
                loss = closure()

        tot = 0.0
        for group in self.param_groups:
            lr = group["lr"]
            beta = group["beta"]
            gamma = group["gamma"]
            eps = group["eps"]
            c = group["cg"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                st = self.state[p]
                if not st:
                    # Lazy state initialisation on first use of this parameter.
                    st["m"] = torch.zeros_like(p)
                    st["v"] = torch.zeros_like(p)
                    st["coh"] = torch.zeros_like(p)
                m, v, h = st["m"], st["v"], st["coh"]
                g0 = p.grad
                m.mul_(beta).add_(g0, alpha=1 - beta)
                v.mul_(gamma).addcmul_(g0, g0, value=1 - gamma)
                # Deviation of the gradient from its running mean (m is
                # already updated with g0 at this point).
                d = g0 - m
                h.mul_(0.9).add_(d.abs(), alpha=0.1)
                lr_eff = lr / (1 + c * h)
                step = lr_eff * m / (v.sqrt() + eps)
                p.add_(-step)
                # .item() forces a host sync per parameter; kept so the
                # returned value is accumulated exactly as before.
                tot += (step * step).sum().item()
        return loss, tot
|
| 60 |
+
|
| 61 |
+
# ---------------- GPT-30B Proxy --------------
|
| 62 |
+
class GPTBlock(nn.Module):
    """Pre-LayerNorm transformer block: self-attention then an MLP.

    Both sub-layers are wrapped in a residual connection.
    """

    def __init__(self, d=2048, heads=16, mlp_ratio=4):
        """
        Args:
            d: Model (embedding) width.
            heads: Number of attention heads.
            mlp_ratio: Hidden width of the MLP as a multiple of ``d``.
        """
        super().__init__()
        hidden = int(d * mlp_ratio)
        self.n1 = nn.LayerNorm(d)
        self.attn = nn.MultiheadAttention(d, heads, batch_first=True)
        self.n2 = nn.LayerNorm(d)
        self.mlp = nn.Sequential(nn.Linear(d, hidden), nn.GELU(), nn.Linear(hidden, d))

    def forward(self, x):
        """Run the block on ``x``.

        ``x`` is fed to a ``batch_first=True`` attention layer, so it is
        expected to be laid out as (batch, sequence, d).
        """
        # NOTE(review): no attn_mask is passed, so attention is bidirectional
        # rather than causal - confirm this is intended for a GPT-style model.
        h = x
        attn_out, _ = self.attn(self.n1(x), self.n1(x), self.n1(x), need_weights=False)
        x = h + attn_out

        # The saved residual `h` used to be dropped here, so the MLP output
        # was added to the *normalised* activations. Add it to the residual
        # stream instead.
        h = x
        x = h + self.mlp(self.n2(x))
        return x
|
| 72 |
+
|
| 73 |
+
class GPT30BProxy(nn.Module):
    """Decoder-style stack: token embedding, learned positions, blocks, head.

    With the default sizes (d=2048, L=24, vocab=32768) the stack holds
    roughly 1.3B parameters, so it is a scaled-down stand-in and not an
    actual 30B-parameter model.
    """

    def __init__(self, vocab=32768, d=2048, L=24, heads=16, max_len=2048):
        """
        Args:
            vocab: Vocabulary size (embedding rows and output logits).
            d: Model width.
            L: Number of ``GPTBlock`` layers.
            heads: Attention heads per block.
            max_len: Number of learned position slots.
        """
        super().__init__()
        self.emb = nn.Embedding(vocab, d)
        # Learned positional table, initialised to zeros.
        self.pos = nn.Parameter(torch.zeros(1, max_len, d))
        layers = []
        for _ in range(L):
            layers.append(GPTBlock(d, heads))
        self.blocks = nn.ModuleList(layers)
        self.norm = nn.LayerNorm(d)
        self.head = nn.Linear(d, vocab)

    def forward(self, tok):
        """Map token ids ``tok`` (indexed as [batch, position]) to logits."""
        seq_len = tok.size(1)
        # Only the first `seq_len` position slots are used.
        hidden = self.emb(tok) + self.pos[:, :seq_len]
        for layer in self.blocks:
            hidden = layer(hidden)
        return self.head(self.norm(hidden))
|
| 84 |
+
|
| 85 |
+
# ---------------- Data -----------------------
|
| 86 |
+
def make_batch(batch=16, seq=1024, vocab=32768, device="cuda"):
    """Build one synthetic batch of uniformly random token ids.

    Args:
        batch: Number of sequences.
        seq: Tokens per sequence.
        vocab: Exclusive upper bound for the sampled ids.
        device: Device the tensors are created on.

    Returns:
        Tuple ``(x, y, n_tokens)``: ``x`` is the ``(batch, seq)`` id tensor,
        ``y`` is ``x`` rotated left by one along the sequence axis, and
        ``n_tokens`` is ``batch * seq``.
    """
    shape = (batch, seq)
    tokens = torch.randint(0, vocab, shape, device=device)
    # roll wraps around: the target of the last position is the first token.
    targets = tokens.roll(shifts=-1, dims=1)
    return tokens, targets, batch * seq
|
| 90 |
+
|
| 91 |
+
# ---------------- DDP Setup ------------------
|
| 92 |
+
def ddp_setup():
    """Initialise the NCCL process group and bind this process to its GPU.

    Returns:
        Tuple ``(rank, world, local_rank)``. ``local_rank`` is read from the
        ``LOCAL_RANK`` environment variable and defaults to 0 when unset.
    """
    dist.init_process_group(backend="nccl")
    global_rank = dist.get_rank()
    world_size = dist.get_world_size()
    gpu_index = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(gpu_index)
    return global_rank, world_size, gpu_index
|
| 98 |
+
|
| 99 |
+
def all_reduce_scalar(t: torch.Tensor, op=dist.ReduceOp.SUM):
    """All-reduce ``t`` in place across ranks and return it.

    When no process group is initialised the tensor is returned untouched.

    Args:
        t: Tensor to reduce (modified in place).
        op: Reduction operation; defaults to a sum.

    Returns:
        The same tensor object that was passed in.
    """
    if not dist.is_initialized():
        return t
    dist.all_reduce(t, op=op)
    return t
|
| 102 |
+
|
| 103 |
+
# ---------------- Runner ---------------------
|
| 104 |
+
def run(mode="RFT", steps=1000, batch=16, seq=1024, vocab=32768, lr=3e-4, log="stage10_gpt30b.jsonl"):
    """Train the proxy model under DDP on synthetic data and log telemetry.

    Args:
        mode: ``"RFT"`` selects the ``DCLR`` optimiser; any other value uses
            ``torch.optim.Adam``.
        steps: Number of optimisation steps.
        batch: Sequences per rank per step.
        seq: Tokens per sequence.
        vocab: Vocabulary size for both the model and the random data.
        lr: Learning rate passed to the optimiser.
        log: Path of the JSON-lines telemetry file (written by rank 0 only).

    Returns:
        A completion message naming the telemetry file.
    """
    rank, world, local_rank = ddp_setup()
    tm = None
    try:
        # Offset the seed by rank so every rank draws different batches.
        set_seed(1234 + rank)
        dev = f"cuda:{local_rank}"
        model = GPT30BProxy(vocab=vocab, max_len=max(2048, seq)).to(dev)
        model = DDP(model, device_ids=[local_rank], output_device=local_rank,
                    find_unused_parameters=False)
        if mode == "RFT":
            opt = DCLR(model.parameters(), lr=lr)
        else:
            opt = torch.optim.Adam(model.parameters(), lr=lr)
        loss_fn = nn.CrossEntropyLoss()
        use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
        if use_bf16:
            autocast_ctx = torch.autocast(device_type="cuda", dtype=torch.bfloat16)
        else:
            autocast_ctx = nullcontext()
        orb = Orbital()
        tm = Telemetry(log) if rank == 0 else None

        for step in range(1, steps + 1):
            drift, flux = orb.step()
            x, y, n_tokens = make_batch(batch, seq, vocab, device=dev)
            opt.zero_grad(set_to_none=True)
            with autocast_ctx:
                out = model(x)
                loss = loss_fn(out.view(-1, out.size(-1)), y.view(-1))
            loss.backward()
            if isinstance(opt, DCLR):
                _, J = opt.step()
            else:
                opt.step()
                J = 0.0
            # Accuracy uses the logits computed before this step's update.
            acc = (out.argmax(-1) == y).float().mean()

            # Sum the per-rank scalars; they are divided by `world` below to
            # report the mean across ranks.
            t_loss = torch.tensor(float(loss.item()), device=dev)
            t_acc = torch.tensor(float(acc.item()), device=dev)
            t_J = torch.tensor(float(J * 1e-6) / max(1, n_tokens), device=dev)
            all_reduce_scalar(t_loss)
            all_reduce_scalar(t_acc)
            all_reduce_scalar(t_J)

            if rank == 0:
                # NOTE(review): E_ret and coh are hard-coded constants, not
                # measured quantities - confirm downstream consumers know.
                tm.emit(mode=mode, step=step, drift=round(drift, 3), flux=round(flux, 3),
                        E_ret=0.996, coh=0.999,
                        loss=round(t_loss.item() / world, 4),
                        acc=round(t_acc.item() / world, 3),
                        J_token=round(t_J.item() / world, 6))
    finally:
        # Release the log file and the process group even if training raised;
        # previously both leaked on any exception in the loop.
        if tm is not None:
            tm.close()
        if dist.is_initialized():
            dist.destroy_process_group()
    return f"Stage 10 complete. Telemetry saved to {log}"
|
| 138 |
+
|
| 139 |
+
if __name__ == "__main__":
    # Command-line entry point: parse the run options and start training.
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", choices=["RFT", "BASE"], default="RFT")
    # Integer-valued options, registered in their original order.
    for flag, default in (("--steps", 1000), ("--batch", 16), ("--seq", 1024), ("--vocab", 32768)):
        parser.add_argument(flag, type=int, default=default)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--log", type=str, default="stage10_gpt30b.jsonl")
    cli = parser.parse_args()
    run(cli.mode, cli.steps, cli.batch, cli.seq, cli.vocab, cli.lr, cli.log)
|