Fix: backward() inside @torch .no_grad() — use torch.enable_grad() for dense gradient computation
Browse files- exp5_mechanism.py +19 -71
exp5_mechanism.py
CHANGED
|
@@ -29,9 +29,6 @@ import torch,torch.nn as nn,torch.nn.functional as F
|
|
| 29 |
import tiktoken
|
| 30 |
print("imports ok",flush=True)
|
| 31 |
|
| 32 |
-
# ═══════════════════════════════════════════════════════════════
|
| 33 |
-
# DATA
|
| 34 |
-
# ═══════════════════════════════════════════════════════════════
|
| 35 |
class Corpus:
|
| 36 |
_i=None
|
| 37 |
@classmethod
|
|
@@ -56,9 +53,6 @@ class Corpus:
|
|
| 56 |
def mg(s):
|
| 57 |
g=torch.Generator(device="cpu"); g.manual_seed(s); return g
|
| 58 |
|
| 59 |
-
# ═══════════════════════════════════════════════════════════════
|
| 60 |
-
# SPARSE LINEAR
|
| 61 |
-
# ═══════════════════════════════════════════════════════════════
|
| 62 |
class SparseBwd(torch.autograd.Function):
|
| 63 |
@staticmethod
|
| 64 |
def forward(ctx,x,w,b,ac,cs,sdx):
|
|
@@ -86,9 +80,6 @@ class SL(nn.Linear):
|
|
| 86 |
if not self.se or self.ac is None: return F.linear(x,self.weight,self.bias)
|
| 87 |
return SparseBwd.apply(x,self.weight,self.bias,self.ac,self.cs,self.sdx)
|
| 88 |
|
| 89 |
-
# ═══════════════════════════════════════════════════════════════
|
| 90 |
-
# MODEL
|
| 91 |
-
# ═══════════════════════════════════════════════════════════════
|
| 92 |
class Attn(nn.Module):
|
| 93 |
def __init__(self,d,nh,bs,do):
|
| 94 |
super().__init__(); self.nh=nh; self.hd=d//nh
|
|
@@ -127,9 +118,6 @@ class GPT(nn.Module):
|
|
| 127 |
def np(self): return sum(p.numel() for p in self.parameters())
|
| 128 |
def gsl(m): return [x for x in m.modules() if isinstance(x,SL)]
|
| 129 |
|
| 130 |
-
# ═══════════════════════════════════════════════════════════════
|
| 131 |
-
# SCHEDULER (builds similarity matrix during warmup)
|
| 132 |
-
# ═══════════════════════════════════════════════════════════════
|
| 133 |
class Sched:
|
| 134 |
def __init__(self,model,frac,cs,dev,beta=0.95,sim_hist=128,min_sim=8):
|
| 135 |
self.frac,self.cs,self.dev,self.beta=frac,cs,dev,beta
|
|
@@ -184,17 +172,12 @@ class Sched:
|
|
| 184 |
for _,ids in self.m2i.items(): ok[ids[:,None],ids[None,:]]=True
|
| 185 |
self.similarity=torch.where(ok,S,torch.zeros_like(S))
|
| 186 |
def mask_jaccard(self):
|
| 187 |
-
"""Jaccard between current and previous active set."""
|
| 188 |
if self.prev_act.sum()==0: return 0.0
|
| 189 |
i=(self.act&self.prev_act).sum().item()
|
| 190 |
u=(self.act|self.prev_act).sum().item()
|
| 191 |
return i/max(u,1)
|
| 192 |
|
| 193 |
-
# ═══════════════════════════════════════════════════════════════
|
| 194 |
-
# RELAXERS
|
| 195 |
-
# ═══════════════════════════════════════════════════════════════
|
| 196 |
class GraphRelaxer:
|
| 197 |
-
"""Graph Laplacian relaxation using the real similarity matrix."""
|
| 198 |
def __init__(self, sched, alpha=0.1, iters=3):
|
| 199 |
self.sched,self.alpha,self.iters=sched,alpha,iters
|
| 200 |
@torch.no_grad()
|
|
@@ -213,11 +196,10 @@ class GraphRelaxer:
|
|
| 213 |
Wf=W.reshape(nc,-1); Wa=(S_n@Wf).view(nc,cs,di)
|
| 214 |
W[li]=(1-self.alpha)*W[li]+self.alpha*Wa[li]
|
| 215 |
m.weight.data=W.view(m.out_features,di)
|
| 216 |
-
deltas[m]=W[li]-W_before
|
| 217 |
return deltas
|
| 218 |
|
| 219 |
class RollRelaxer:
|
| 220 |
-
"""Spatial neighbor relaxation via torch.roll."""
|
| 221 |
def __init__(self, sched, alpha=0.1, iters=3):
|
| 222 |
self.sched,self.alpha,self.iters=sched,alpha,iters
|
| 223 |
@torch.no_grad()
|
|
@@ -237,7 +219,6 @@ class RollRelaxer:
|
|
| 237 |
return deltas
|
| 238 |
|
| 239 |
class RandomRelaxer:
|
| 240 |
-
"""Control: random similarity matrix (same sparsity pattern, random values)."""
|
| 241 |
def __init__(self, sched, alpha=0.1, iters=3):
|
| 242 |
self.sched,self.alpha,self.iters=sched,alpha,iters
|
| 243 |
self._rand_sim=None
|
|
@@ -245,7 +226,6 @@ class RandomRelaxer:
|
|
| 245 |
if self._rand_sim is not None: return self._rand_sim
|
| 246 |
S=self.sched.similarity
|
| 247 |
if S is None: return None
|
| 248 |
-
# Random positive values with same mask structure
|
| 249 |
R=torch.rand_like(S)*S.abs().mean()
|
| 250 |
R.fill_diagonal_(0)
|
| 251 |
ok=torch.zeros_like(R,dtype=torch.bool)
|
|
@@ -272,7 +252,6 @@ class RandomRelaxer:
|
|
| 272 |
return deltas
|
| 273 |
|
| 274 |
class ShuffledGraphRelaxer:
|
| 275 |
-
"""Control: real similarity stats, shuffled structure within each layer."""
|
| 276 |
def __init__(self, sched, alpha=0.1, iters=3):
|
| 277 |
self.sched,self.alpha,self.iters=sched,alpha,iters
|
| 278 |
self._shuf_sim=None
|
|
@@ -281,11 +260,9 @@ class ShuffledGraphRelaxer:
|
|
| 281 |
S=self.sched.similarity
|
| 282 |
if S is None: return None
|
| 283 |
Ss=S.clone()
|
| 284 |
-
# Shuffle within each layer block
|
| 285 |
for _,ids in self.sched.m2i.items():
|
| 286 |
n=len(ids)
|
| 287 |
-
block=Ss[ids][:,ids].clone()
|
| 288 |
-
# Shuffle rows and columns with same permutation
|
| 289 |
perm=torch.randperm(n,device=S.device)
|
| 290 |
block=block[perm][:,perm]
|
| 291 |
block.fill_diagonal_(0)
|
|
@@ -312,12 +289,8 @@ class ShuffledGraphRelaxer:
|
|
| 312 |
return deltas
|
| 313 |
|
| 314 |
class NullRelaxer:
|
| 315 |
-
"""No-op relaxer."""
|
| 316 |
def relax(self): return {}
|
| 317 |
|
| 318 |
-
# ═══════════════════════════════════════════════════════════════
|
| 319 |
-
# OPTIMIZER
|
| 320 |
-
# ═══════════════════════════════════════════════════════════════
|
| 321 |
class CAdam:
|
| 322 |
def __init__(self,model,lr=3e-4,cs=64):
|
| 323 |
self.model,self.lr,self.cs=model,lr,cs
|
|
@@ -343,18 +316,11 @@ class CAdam:
|
|
| 343 |
s,e=c*self.cs,(c+1)*self.cs
|
| 344 |
p.data[s:e].sub_(m[s:e]/(torch.sqrt(v[s:e])+1e-8),alpha=self.lr)
|
| 345 |
|
| 346 |
-
# ═══════════════════════════════════════════════════════════════
|
| 347 |
-
# EVAL
|
| 348 |
-
# ═══════════════════════════════════════════════════════════════
|
| 349 |
@torch.no_grad()
|
| 350 |
def ev(model,corpus,bs,n=20,seed=9999):
|
| 351 |
model.eval(); ls=[model(*corpus.get_batch("val",bs,mg(seed+i)))[1].item() for i in range(n)]
|
| 352 |
model.train(); a=sum(ls)/len(ls); return a,math.exp(min(a,20))
|
| 353 |
|
| 354 |
-
# ═══════════════════════════════════════════════════════════════
|
| 355 |
-
# ORACLE GRADIENT DIAGNOSTIC
|
| 356 |
-
# ═══════════════════════════════════════════════════════════════
|
| 357 |
-
@torch.no_grad()
|
| 358 |
def compute_relaxer_diagnostics(model, sched, relaxer_deltas, x, y, corpus, bs, cs):
|
| 359 |
"""
|
| 360 |
Compare relaxer delta on inactive chunks to what dense gradient would have been.
|
|
@@ -362,38 +328,33 @@ def compute_relaxer_diagnostics(model, sched, relaxer_deltas, x, y, corpus, bs,
|
|
| 362 |
"""
|
| 363 |
if not relaxer_deltas: return None, None
|
| 364 |
|
| 365 |
-
#
|
| 366 |
for m in gsl(model): m.se=False
|
| 367 |
for p in model.parameters(): p.grad=None
|
| 368 |
-
|
|
|
|
|
|
|
| 369 |
|
| 370 |
cos_sims=[]; mag_ratios=[]
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
cos_sims.append(F.cosine_similarity(d_flat.unsqueeze(0),g_flat.unsqueeze(0)).item())
|
| 385 |
-
mag_ratios.append((dn/gn).item())
|
| 386 |
-
|
| 387 |
-
# Restore sparse mode
|
| 388 |
for m in gsl(model): m.se=True
|
| 389 |
for p in model.parameters(): p.grad=None
|
| 390 |
|
| 391 |
if not cos_sims: return None, None
|
| 392 |
return sum(cos_sims)/len(cos_sims), sum(mag_ratios)/len(mag_ratios)
|
| 393 |
|
| 394 |
-
# ═══════════════════════════════════════════════════════════════
|
| 395 |
-
# SINGLE RUN
|
| 396 |
-
# ══════════════════════════════════��════════════════════════════
|
| 397 |
def run1(mode, steps, bs, bsz, nl, nh, d, cs, af, wu, an, lr, dev, seed,
|
| 398 |
alpha=0.1, iters=3, diag_interval=100):
|
| 399 |
torch.manual_seed(seed); random.seed(seed)
|
|
@@ -410,12 +371,10 @@ def run1(mode, steps, bs, bsz, nl, nh, d, cs, af, wu, an, lr, dev, seed,
|
|
| 410 |
if is_sparse:
|
| 411 |
sched=Sched(model,af,cs,dev)
|
| 412 |
elif needs_relax:
|
| 413 |
-
# Dense + relax: need scheduler for similarity matrix but run dense forward/backward
|
| 414 |
sched=Sched(model,af,cs,dev)
|
| 415 |
|
| 416 |
opt=CAdam(model,lr,cs)
|
| 417 |
|
| 418 |
-
# Create relaxer
|
| 419 |
if not needs_relax:
|
| 420 |
relaxer=NullRelaxer()
|
| 421 |
elif "random" in mode:
|
|
@@ -443,7 +402,6 @@ def run1(mode, steps, bs, bsz, nl, nh, d, cs, af, wu, an, lr, dev, seed,
|
|
| 443 |
for m in gsl(model): m.se=True; m.sdx=False
|
| 444 |
else:
|
| 445 |
for m in gsl(model): m.se=False; m.ac=None
|
| 446 |
-
# For dense+relax: still run scheduler to build similarity & set active mask
|
| 447 |
if sched:
|
| 448 |
sched.choose(step,wu,an)
|
| 449 |
|
|
@@ -455,10 +413,8 @@ def run1(mode, steps, bs, bsz, nl, nh, d, cs, af, wu, an, lr, dev, seed,
|
|
| 455 |
|
| 456 |
opt.step()
|
| 457 |
|
| 458 |
-
# Relaxation (only after annealing completes)
|
| 459 |
relax_deltas={}
|
| 460 |
if needs_relax and step>=wu+an:
|
| 461 |
-
# For dense+relax: temporarily set active mask so relaxer knows what's "active"
|
| 462 |
if is_dense and sched:
|
| 463 |
for m,ids in sched.m2i.items():
|
| 464 |
m.ac=sched.m2l[m][sched.act[ids]]
|
|
@@ -466,7 +422,6 @@ def run1(mode, steps, bs, bsz, nl, nh, d, cs, af, wu, an, lr, dev, seed,
|
|
| 466 |
if is_dense and sched:
|
| 467 |
for m in gsl(model): m.ac=None
|
| 468 |
|
| 469 |
-
# Diagnostics
|
| 470 |
if step%50==0:
|
| 471 |
vl,_=ev(model,corpus,bs,n=10,seed=7777)
|
| 472 |
val_curve.append((step,vl))
|
|
@@ -500,9 +455,6 @@ def runs(cfg,seeds):
|
|
| 500 |
sl=(sum((x-ml)**2 for x in vls)/max(1,len(vls)-1))**0.5
|
| 501 |
return {"ml":ml,"sl":sl,"rs":rs,"ms":sum(r["ms"] for r in rs)/len(rs)}
|
| 502 |
|
| 503 |
-
# ═══════════════════════════════════════════════════════════════
|
| 504 |
-
# MAIN
|
| 505 |
-
# ═══════════════════════════════════════════════════════════════
|
| 506 |
def main():
|
| 507 |
p=argparse.ArgumentParser()
|
| 508 |
p.add_argument("--device",default="cuda"); p.add_argument("--steps",type=int,default=1000)
|
|
@@ -522,7 +474,6 @@ def main():
|
|
| 522 |
base=dict(steps=a.steps,bs=a.bs,bsz=a.bsz,nl=a.nl,nh=a.nh,d=a.d,cs=a.cs,af=a.af,
|
| 523 |
wu=a.wu,an=a.an,lr=a.lr,dev=a.device,alpha=0.1,iters=3)
|
| 524 |
|
| 525 |
-
# ── Part 1: Main configs ──
|
| 526 |
configs=[
|
| 527 |
("dense", "dense"),
|
| 528 |
("dense+relax_graph", "dense+relax_graph"),
|
|
@@ -549,7 +500,6 @@ def main():
|
|
| 549 |
r=R[name]
|
| 550 |
print(f"{name:<25} | {r['ml']:.4f} ± {r['sl']:.4f} | {r['ms']:>7.1f}",flush=True)
|
| 551 |
|
| 552 |
-
# ── Part 2: Alpha sweep ──
|
| 553 |
print(f"\n--- Alpha sweep (ema+relax_graph, 5 seeds) ---",flush=True)
|
| 554 |
print(f"{'alpha':>6} | {'Val Loss':>20} | {'ms/step':>8}",flush=True)
|
| 555 |
print("-"*42,flush=True)
|
|
@@ -559,7 +509,6 @@ def main():
|
|
| 559 |
alpha_results[alpha]=r
|
| 560 |
print(f"{alpha:>6.2f} | {r['ml']:.4f} ± {r['sl']:.4f} | {r['ms']:>7.1f}",flush=True)
|
| 561 |
|
| 562 |
-
# ── Part 3: Diagnostics summary ──
|
| 563 |
print(f"\n--- Diagnostics (grad_cos, mag_ratio) ---",flush=True)
|
| 564 |
for name in ["ema+relax_graph","ema+relax_roll","ema+relax_random","ema+relax_shuffled"]:
|
| 565 |
if name not in R: continue
|
|
@@ -571,7 +520,6 @@ def main():
|
|
| 571 |
gc_m=sum(gc_all)/len(gc_all); mr_m=sum(mr_all)/len(mr_all)
|
| 572 |
print(f" {name:<25}: grad_cos={gc_m:.4f} mag_ratio={mr_m:.4f}",flush=True)
|
| 573 |
|
| 574 |
-
# Save
|
| 575 |
all_results={"configs":R,"alpha_sweep":alpha_results}
|
| 576 |
with open("exp5.json","w") as f:
|
| 577 |
json.dump(all_results,f,indent=2,default=str)
|
|
|
|
| 29 |
import tiktoken
|
| 30 |
print("imports ok",flush=True)
|
| 31 |
|
|
|
|
|
|
|
|
|
|
| 32 |
class Corpus:
|
| 33 |
_i=None
|
| 34 |
@classmethod
|
|
|
|
| 53 |
def mg(s):
|
| 54 |
g=torch.Generator(device="cpu"); g.manual_seed(s); return g
|
| 55 |
|
|
|
|
|
|
|
|
|
|
| 56 |
class SparseBwd(torch.autograd.Function):
|
| 57 |
@staticmethod
|
| 58 |
def forward(ctx,x,w,b,ac,cs,sdx):
|
|
|
|
| 80 |
if not self.se or self.ac is None: return F.linear(x,self.weight,self.bias)
|
| 81 |
return SparseBwd.apply(x,self.weight,self.bias,self.ac,self.cs,self.sdx)
|
| 82 |
|
|
|
|
|
|
|
|
|
|
| 83 |
class Attn(nn.Module):
|
| 84 |
def __init__(self,d,nh,bs,do):
|
| 85 |
super().__init__(); self.nh=nh; self.hd=d//nh
|
|
|
|
| 118 |
def np(self): return sum(p.numel() for p in self.parameters())
|
| 119 |
def gsl(m): return [x for x in m.modules() if isinstance(x,SL)]
|
| 120 |
|
|
|
|
|
|
|
|
|
|
| 121 |
class Sched:
|
| 122 |
def __init__(self,model,frac,cs,dev,beta=0.95,sim_hist=128,min_sim=8):
|
| 123 |
self.frac,self.cs,self.dev,self.beta=frac,cs,dev,beta
|
|
|
|
| 172 |
for _,ids in self.m2i.items(): ok[ids[:,None],ids[None,:]]=True
|
| 173 |
self.similarity=torch.where(ok,S,torch.zeros_like(S))
|
| 174 |
def mask_jaccard(self):
|
|
|
|
| 175 |
if self.prev_act.sum()==0: return 0.0
|
| 176 |
i=(self.act&self.prev_act).sum().item()
|
| 177 |
u=(self.act|self.prev_act).sum().item()
|
| 178 |
return i/max(u,1)
|
| 179 |
|
|
|
|
|
|
|
|
|
|
| 180 |
class GraphRelaxer:
|
|
|
|
| 181 |
def __init__(self, sched, alpha=0.1, iters=3):
|
| 182 |
self.sched,self.alpha,self.iters=sched,alpha,iters
|
| 183 |
@torch.no_grad()
|
|
|
|
| 196 |
Wf=W.reshape(nc,-1); Wa=(S_n@Wf).view(nc,cs,di)
|
| 197 |
W[li]=(1-self.alpha)*W[li]+self.alpha*Wa[li]
|
| 198 |
m.weight.data=W.view(m.out_features,di)
|
| 199 |
+
deltas[m]=W[li]-W_before
|
| 200 |
return deltas
|
| 201 |
|
| 202 |
class RollRelaxer:
|
|
|
|
| 203 |
def __init__(self, sched, alpha=0.1, iters=3):
|
| 204 |
self.sched,self.alpha,self.iters=sched,alpha,iters
|
| 205 |
@torch.no_grad()
|
|
|
|
| 219 |
return deltas
|
| 220 |
|
| 221 |
class RandomRelaxer:
|
|
|
|
| 222 |
def __init__(self, sched, alpha=0.1, iters=3):
|
| 223 |
self.sched,self.alpha,self.iters=sched,alpha,iters
|
| 224 |
self._rand_sim=None
|
|
|
|
| 226 |
if self._rand_sim is not None: return self._rand_sim
|
| 227 |
S=self.sched.similarity
|
| 228 |
if S is None: return None
|
|
|
|
| 229 |
R=torch.rand_like(S)*S.abs().mean()
|
| 230 |
R.fill_diagonal_(0)
|
| 231 |
ok=torch.zeros_like(R,dtype=torch.bool)
|
|
|
|
| 252 |
return deltas
|
| 253 |
|
| 254 |
class ShuffledGraphRelaxer:
|
|
|
|
| 255 |
def __init__(self, sched, alpha=0.1, iters=3):
|
| 256 |
self.sched,self.alpha,self.iters=sched,alpha,iters
|
| 257 |
self._shuf_sim=None
|
|
|
|
| 260 |
S=self.sched.similarity
|
| 261 |
if S is None: return None
|
| 262 |
Ss=S.clone()
|
|
|
|
| 263 |
for _,ids in self.sched.m2i.items():
|
| 264 |
n=len(ids)
|
| 265 |
+
block=Ss[ids][:,ids].clone()
|
|
|
|
| 266 |
perm=torch.randperm(n,device=S.device)
|
| 267 |
block=block[perm][:,perm]
|
| 268 |
block.fill_diagonal_(0)
|
|
|
|
| 289 |
return deltas
|
| 290 |
|
| 291 |
class NullRelaxer:
|
|
|
|
| 292 |
def relax(self): return {}
|
| 293 |
|
|
|
|
|
|
|
|
|
|
| 294 |
class CAdam:
|
| 295 |
def __init__(self,model,lr=3e-4,cs=64):
|
| 296 |
self.model,self.lr,self.cs=model,lr,cs
|
|
|
|
| 316 |
s,e=c*self.cs,(c+1)*self.cs
|
| 317 |
p.data[s:e].sub_(m[s:e]/(torch.sqrt(v[s:e])+1e-8),alpha=self.lr)
|
| 318 |
|
|
|
|
|
|
|
|
|
|
| 319 |
@torch.no_grad()
|
| 320 |
def ev(model,corpus,bs,n=20,seed=9999):
|
| 321 |
model.eval(); ls=[model(*corpus.get_batch("val",bs,mg(seed+i)))[1].item() for i in range(n)]
|
| 322 |
model.train(); a=sum(ls)/len(ls); return a,math.exp(min(a,20))
|
| 323 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 324 |
def compute_relaxer_diagnostics(model, sched, relaxer_deltas, x, y, corpus, bs, cs):
|
| 325 |
"""
|
| 326 |
Compare relaxer delta on inactive chunks to what dense gradient would have been.
|
|
|
|
| 328 |
"""
|
| 329 |
if not relaxer_deltas: return None, None
|
| 330 |
|
| 331 |
+
# Need gradients enabled for the dense forward/backward
|
| 332 |
for m in gsl(model): m.se=False
|
| 333 |
for p in model.parameters(): p.grad=None
|
| 334 |
+
with torch.enable_grad():
|
| 335 |
+
_,lo=model(x,y)
|
| 336 |
+
lo.backward()
|
| 337 |
|
| 338 |
cos_sims=[]; mag_ratios=[]
|
| 339 |
+
with torch.no_grad():
|
| 340 |
+
for m,delta in relaxer_deltas.items():
|
| 341 |
+
if m not in sched.m2i: continue
|
| 342 |
+
ids=sched.m2i[m]; nc=len(ids); di=m.weight.shape[1]
|
| 343 |
+
la=sched.act[ids]; li=~la
|
| 344 |
+
if li.sum()==0 or m.weight.grad is None: continue
|
| 345 |
+
dense_g=m.weight.grad.view(nc,cs,di)[li]
|
| 346 |
+
d_flat=delta.reshape(-1); g_flat=dense_g.reshape(-1)
|
| 347 |
+
dn=d_flat.norm(); gn=g_flat.norm()
|
| 348 |
+
if dn>1e-12 and gn>1e-12:
|
| 349 |
+
cos_sims.append(F.cosine_similarity(d_flat.unsqueeze(0),g_flat.unsqueeze(0)).item())
|
| 350 |
+
mag_ratios.append((dn/gn).item())
|
| 351 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 352 |
for m in gsl(model): m.se=True
|
| 353 |
for p in model.parameters(): p.grad=None
|
| 354 |
|
| 355 |
if not cos_sims: return None, None
|
| 356 |
return sum(cos_sims)/len(cos_sims), sum(mag_ratios)/len(mag_ratios)
|
| 357 |
|
|
|
|
|
|
|
|
|
|
| 358 |
def run1(mode, steps, bs, bsz, nl, nh, d, cs, af, wu, an, lr, dev, seed,
|
| 359 |
alpha=0.1, iters=3, diag_interval=100):
|
| 360 |
torch.manual_seed(seed); random.seed(seed)
|
|
|
|
| 371 |
if is_sparse:
|
| 372 |
sched=Sched(model,af,cs,dev)
|
| 373 |
elif needs_relax:
|
|
|
|
| 374 |
sched=Sched(model,af,cs,dev)
|
| 375 |
|
| 376 |
opt=CAdam(model,lr,cs)
|
| 377 |
|
|
|
|
| 378 |
if not needs_relax:
|
| 379 |
relaxer=NullRelaxer()
|
| 380 |
elif "random" in mode:
|
|
|
|
| 402 |
for m in gsl(model): m.se=True; m.sdx=False
|
| 403 |
else:
|
| 404 |
for m in gsl(model): m.se=False; m.ac=None
|
|
|
|
| 405 |
if sched:
|
| 406 |
sched.choose(step,wu,an)
|
| 407 |
|
|
|
|
| 413 |
|
| 414 |
opt.step()
|
| 415 |
|
|
|
|
| 416 |
relax_deltas={}
|
| 417 |
if needs_relax and step>=wu+an:
|
|
|
|
| 418 |
if is_dense and sched:
|
| 419 |
for m,ids in sched.m2i.items():
|
| 420 |
m.ac=sched.m2l[m][sched.act[ids]]
|
|
|
|
| 422 |
if is_dense and sched:
|
| 423 |
for m in gsl(model): m.ac=None
|
| 424 |
|
|
|
|
| 425 |
if step%50==0:
|
| 426 |
vl,_=ev(model,corpus,bs,n=10,seed=7777)
|
| 427 |
val_curve.append((step,vl))
|
|
|
|
| 455 |
sl=(sum((x-ml)**2 for x in vls)/max(1,len(vls)-1))**0.5
|
| 456 |
return {"ml":ml,"sl":sl,"rs":rs,"ms":sum(r["ms"] for r in rs)/len(rs)}
|
| 457 |
|
|
|
|
|
|
|
|
|
|
| 458 |
def main():
|
| 459 |
p=argparse.ArgumentParser()
|
| 460 |
p.add_argument("--device",default="cuda"); p.add_argument("--steps",type=int,default=1000)
|
|
|
|
| 474 |
base=dict(steps=a.steps,bs=a.bs,bsz=a.bsz,nl=a.nl,nh=a.nh,d=a.d,cs=a.cs,af=a.af,
|
| 475 |
wu=a.wu,an=a.an,lr=a.lr,dev=a.device,alpha=0.1,iters=3)
|
| 476 |
|
|
|
|
| 477 |
configs=[
|
| 478 |
("dense", "dense"),
|
| 479 |
("dense+relax_graph", "dense+relax_graph"),
|
|
|
|
| 500 |
r=R[name]
|
| 501 |
print(f"{name:<25} | {r['ml']:.4f} ± {r['sl']:.4f} | {r['ms']:>7.1f}",flush=True)
|
| 502 |
|
|
|
|
| 503 |
print(f"\n--- Alpha sweep (ema+relax_graph, 5 seeds) ---",flush=True)
|
| 504 |
print(f"{'alpha':>6} | {'Val Loss':>20} | {'ms/step':>8}",flush=True)
|
| 505 |
print("-"*42,flush=True)
|
|
|
|
| 509 |
alpha_results[alpha]=r
|
| 510 |
print(f"{alpha:>6.2f} | {r['ml']:.4f} ± {r['sl']:.4f} | {r['ms']:>7.1f}",flush=True)
|
| 511 |
|
|
|
|
| 512 |
print(f"\n--- Diagnostics (grad_cos, mag_ratio) ---",flush=True)
|
| 513 |
for name in ["ema+relax_graph","ema+relax_roll","ema+relax_random","ema+relax_shuffled"]:
|
| 514 |
if name not in R: continue
|
|
|
|
| 520 |
gc_m=sum(gc_all)/len(gc_all); mr_m=sum(mr_all)/len(mr_all)
|
| 521 |
print(f" {name:<25}: grad_cos={gc_m:.4f} mag_ratio={mr_m:.4f}",flush=True)
|
| 522 |
|
|
|
|
| 523 |
all_results={"configs":R,"alpha_sweep":alpha_results}
|
| 524 |
with open("exp5.json","w") as f:
|
| 525 |
json.dump(all_results,f,indent=2,default=str)
|