theapemachine committed on
Commit
39301d6
·
verified ·
1 Parent(s): 1f4765b

Fix: backward() inside @torch.no_grad() — use torch.enable_grad() for dense gradient computation

Browse files
Files changed (1) hide show
  1. exp5_mechanism.py +19 -71
exp5_mechanism.py CHANGED
@@ -29,9 +29,6 @@ import torch,torch.nn as nn,torch.nn.functional as F
29
  import tiktoken
30
  print("imports ok",flush=True)
31
 
32
- # ═══════════════════════════════════════════════════════════════
33
- # DATA
34
- # ═══════════════════════════════════════════════════════════════
35
  class Corpus:
36
  _i=None
37
  @classmethod
@@ -56,9 +53,6 @@ class Corpus:
56
  def mg(s):
57
  g=torch.Generator(device="cpu"); g.manual_seed(s); return g
58
 
59
- # ═══════════════════════════════════════════════════════════════
60
- # SPARSE LINEAR
61
- # ═══════════════════════════════════════════════════════════════
62
  class SparseBwd(torch.autograd.Function):
63
  @staticmethod
64
  def forward(ctx,x,w,b,ac,cs,sdx):
@@ -86,9 +80,6 @@ class SL(nn.Linear):
86
  if not self.se or self.ac is None: return F.linear(x,self.weight,self.bias)
87
  return SparseBwd.apply(x,self.weight,self.bias,self.ac,self.cs,self.sdx)
88
 
89
- # ═══════════════════════════════════════════════════════════════
90
- # MODEL
91
- # ═══════════════════════════════════════════════════════════════
92
  class Attn(nn.Module):
93
  def __init__(self,d,nh,bs,do):
94
  super().__init__(); self.nh=nh; self.hd=d//nh
@@ -127,9 +118,6 @@ class GPT(nn.Module):
127
  def np(self): return sum(p.numel() for p in self.parameters())
128
  def gsl(m): return [x for x in m.modules() if isinstance(x,SL)]
129
 
130
- # ═══════════════════════════════════════════════════════════════
131
- # SCHEDULER (builds similarity matrix during warmup)
132
- # ═══════════════════════════════════════════════════════════════
133
  class Sched:
134
  def __init__(self,model,frac,cs,dev,beta=0.95,sim_hist=128,min_sim=8):
135
  self.frac,self.cs,self.dev,self.beta=frac,cs,dev,beta
@@ -184,17 +172,12 @@ class Sched:
184
  for _,ids in self.m2i.items(): ok[ids[:,None],ids[None,:]]=True
185
  self.similarity=torch.where(ok,S,torch.zeros_like(S))
186
  def mask_jaccard(self):
187
- """Jaccard between current and previous active set."""
188
  if self.prev_act.sum()==0: return 0.0
189
  i=(self.act&self.prev_act).sum().item()
190
  u=(self.act|self.prev_act).sum().item()
191
  return i/max(u,1)
192
 
193
- # ═══════════════════════════════════════════════════════════════
194
- # RELAXERS
195
- # ═══════════════════════════════════════════════════════════════
196
  class GraphRelaxer:
197
- """Graph Laplacian relaxation using the real similarity matrix."""
198
  def __init__(self, sched, alpha=0.1, iters=3):
199
  self.sched,self.alpha,self.iters=sched,alpha,iters
200
  @torch.no_grad()
@@ -213,11 +196,10 @@ class GraphRelaxer:
213
  Wf=W.reshape(nc,-1); Wa=(S_n@Wf).view(nc,cs,di)
214
  W[li]=(1-self.alpha)*W[li]+self.alpha*Wa[li]
215
  m.weight.data=W.view(m.out_features,di)
216
- deltas[m]=W[li]-W_before # (n_inactive, cs, di)
217
  return deltas
218
 
219
  class RollRelaxer:
220
- """Spatial neighbor relaxation via torch.roll."""
221
  def __init__(self, sched, alpha=0.1, iters=3):
222
  self.sched,self.alpha,self.iters=sched,alpha,iters
223
  @torch.no_grad()
@@ -237,7 +219,6 @@ class RollRelaxer:
237
  return deltas
238
 
239
  class RandomRelaxer:
240
- """Control: random similarity matrix (same sparsity pattern, random values)."""
241
  def __init__(self, sched, alpha=0.1, iters=3):
242
  self.sched,self.alpha,self.iters=sched,alpha,iters
243
  self._rand_sim=None
@@ -245,7 +226,6 @@ class RandomRelaxer:
245
  if self._rand_sim is not None: return self._rand_sim
246
  S=self.sched.similarity
247
  if S is None: return None
248
- # Random positive values with same mask structure
249
  R=torch.rand_like(S)*S.abs().mean()
250
  R.fill_diagonal_(0)
251
  ok=torch.zeros_like(R,dtype=torch.bool)
@@ -272,7 +252,6 @@ class RandomRelaxer:
272
  return deltas
273
 
274
  class ShuffledGraphRelaxer:
275
- """Control: real similarity stats, shuffled structure within each layer."""
276
  def __init__(self, sched, alpha=0.1, iters=3):
277
  self.sched,self.alpha,self.iters=sched,alpha,iters
278
  self._shuf_sim=None
@@ -281,11 +260,9 @@ class ShuffledGraphRelaxer:
281
  S=self.sched.similarity
282
  if S is None: return None
283
  Ss=S.clone()
284
- # Shuffle within each layer block
285
  for _,ids in self.sched.m2i.items():
286
  n=len(ids)
287
- block=Ss[ids][:,ids].clone() # (n,n)
288
- # Shuffle rows and columns with same permutation
289
  perm=torch.randperm(n,device=S.device)
290
  block=block[perm][:,perm]
291
  block.fill_diagonal_(0)
@@ -312,12 +289,8 @@ class ShuffledGraphRelaxer:
312
  return deltas
313
 
314
  class NullRelaxer:
315
- """No-op relaxer."""
316
  def relax(self): return {}
317
 
318
- # ═══════════════════════════════════════════════════════════════
319
- # OPTIMIZER
320
- # ═══════════════════════════════════════════════════════════════
321
  class CAdam:
322
  def __init__(self,model,lr=3e-4,cs=64):
323
  self.model,self.lr,self.cs=model,lr,cs
@@ -343,18 +316,11 @@ class CAdam:
343
  s,e=c*self.cs,(c+1)*self.cs
344
  p.data[s:e].sub_(m[s:e]/(torch.sqrt(v[s:e])+1e-8),alpha=self.lr)
345
 
346
- # ═══════════════════════════════════════════════════════════════
347
- # EVAL
348
- # ═══════════════════════════════════════════════════════════════
349
  @torch.no_grad()
350
  def ev(model,corpus,bs,n=20,seed=9999):
351
  model.eval(); ls=[model(*corpus.get_batch("val",bs,mg(seed+i)))[1].item() for i in range(n)]
352
  model.train(); a=sum(ls)/len(ls); return a,math.exp(min(a,20))
353
 
354
- # ═══════════════════════════════════════════════════════════════
355
- # ORACLE GRADIENT DIAGNOSTIC
356
- # ═══════════════════════════════════════════════════════════════
357
- @torch.no_grad()
358
  def compute_relaxer_diagnostics(model, sched, relaxer_deltas, x, y, corpus, bs, cs):
359
  """
360
  Compare relaxer delta on inactive chunks to what dense gradient would have been.
@@ -362,38 +328,33 @@ def compute_relaxer_diagnostics(model, sched, relaxer_deltas, x, y, corpus, bs,
362
  """
363
  if not relaxer_deltas: return None, None
364
 
365
- # Compute dense gradients
366
  for m in gsl(model): m.se=False
367
  for p in model.parameters(): p.grad=None
368
- _,lo=model(x,y); lo.backward()
 
 
369
 
370
  cos_sims=[]; mag_ratios=[]
371
- for m,delta in relaxer_deltas.items():
372
- if m not in sched.m2i: continue
373
- ids=sched.m2i[m]; nc=len(ids); di=m.weight.shape[1]
374
- la=sched.act[ids]; li=~la
375
- if li.sum()==0 or m.weight.grad is None: continue
376
-
377
- # Dense gradient for inactive chunks, reshaped
378
- dense_g=m.weight.grad.view(nc,cs,di)[li] # (n_inact, cs, di)
379
-
380
- # Flatten for cosine/magnitude
381
- d_flat=delta.reshape(-1); g_flat=dense_g.reshape(-1)
382
- dn=d_flat.norm(); gn=g_flat.norm()
383
- if dn>1e-12 and gn>1e-12:
384
- cos_sims.append(F.cosine_similarity(d_flat.unsqueeze(0),g_flat.unsqueeze(0)).item())
385
- mag_ratios.append((dn/gn).item())
386
-
387
- # Restore sparse mode
388
  for m in gsl(model): m.se=True
389
  for p in model.parameters(): p.grad=None
390
 
391
  if not cos_sims: return None, None
392
  return sum(cos_sims)/len(cos_sims), sum(mag_ratios)/len(mag_ratios)
393
 
394
- # ═══════════════════════════════════════════════════════════════
395
- # SINGLE RUN
396
- # ═══════════════════════════════════════════════════════════════
397
  def run1(mode, steps, bs, bsz, nl, nh, d, cs, af, wu, an, lr, dev, seed,
398
  alpha=0.1, iters=3, diag_interval=100):
399
  torch.manual_seed(seed); random.seed(seed)
@@ -410,12 +371,10 @@ def run1(mode, steps, bs, bsz, nl, nh, d, cs, af, wu, an, lr, dev, seed,
410
  if is_sparse:
411
  sched=Sched(model,af,cs,dev)
412
  elif needs_relax:
413
- # Dense + relax: need scheduler for similarity matrix but run dense forward/backward
414
  sched=Sched(model,af,cs,dev)
415
 
416
  opt=CAdam(model,lr,cs)
417
 
418
- # Create relaxer
419
  if not needs_relax:
420
  relaxer=NullRelaxer()
421
  elif "random" in mode:
@@ -443,7 +402,6 @@ def run1(mode, steps, bs, bsz, nl, nh, d, cs, af, wu, an, lr, dev, seed,
443
  for m in gsl(model): m.se=True; m.sdx=False
444
  else:
445
  for m in gsl(model): m.se=False; m.ac=None
446
- # For dense+relax: still run scheduler to build similarity & set active mask
447
  if sched:
448
  sched.choose(step,wu,an)
449
 
@@ -455,10 +413,8 @@ def run1(mode, steps, bs, bsz, nl, nh, d, cs, af, wu, an, lr, dev, seed,
455
 
456
  opt.step()
457
 
458
- # Relaxation (only after annealing completes)
459
  relax_deltas={}
460
  if needs_relax and step>=wu+an:
461
- # For dense+relax: temporarily set active mask so relaxer knows what's "active"
462
  if is_dense and sched:
463
  for m,ids in sched.m2i.items():
464
  m.ac=sched.m2l[m][sched.act[ids]]
@@ -466,7 +422,6 @@ def run1(mode, steps, bs, bsz, nl, nh, d, cs, af, wu, an, lr, dev, seed,
466
  if is_dense and sched:
467
  for m in gsl(model): m.ac=None
468
 
469
- # Diagnostics
470
  if step%50==0:
471
  vl,_=ev(model,corpus,bs,n=10,seed=7777)
472
  val_curve.append((step,vl))
@@ -500,9 +455,6 @@ def runs(cfg,seeds):
500
  sl=(sum((x-ml)**2 for x in vls)/max(1,len(vls)-1))**0.5
501
  return {"ml":ml,"sl":sl,"rs":rs,"ms":sum(r["ms"] for r in rs)/len(rs)}
502
 
503
- # ═══════════════════════════════════════════════════════════════
504
- # MAIN
505
- # ═══════════════════════════════════════════════════════════════
506
  def main():
507
  p=argparse.ArgumentParser()
508
  p.add_argument("--device",default="cuda"); p.add_argument("--steps",type=int,default=1000)
@@ -522,7 +474,6 @@ def main():
522
  base=dict(steps=a.steps,bs=a.bs,bsz=a.bsz,nl=a.nl,nh=a.nh,d=a.d,cs=a.cs,af=a.af,
523
  wu=a.wu,an=a.an,lr=a.lr,dev=a.device,alpha=0.1,iters=3)
524
 
525
- # ── Part 1: Main configs ──
526
  configs=[
527
  ("dense", "dense"),
528
  ("dense+relax_graph", "dense+relax_graph"),
@@ -549,7 +500,6 @@ def main():
549
  r=R[name]
550
  print(f"{name:<25} | {r['ml']:.4f} ± {r['sl']:.4f} | {r['ms']:>7.1f}",flush=True)
551
 
552
- # ── Part 2: Alpha sweep ──
553
  print(f"\n--- Alpha sweep (ema+relax_graph, 5 seeds) ---",flush=True)
554
  print(f"{'alpha':>6} | {'Val Loss':>20} | {'ms/step':>8}",flush=True)
555
  print("-"*42,flush=True)
@@ -559,7 +509,6 @@ def main():
559
  alpha_results[alpha]=r
560
  print(f"{alpha:>6.2f} | {r['ml']:.4f} ± {r['sl']:.4f} | {r['ms']:>7.1f}",flush=True)
561
 
562
- # ── Part 3: Diagnostics summary ──
563
  print(f"\n--- Diagnostics (grad_cos, mag_ratio) ---",flush=True)
564
  for name in ["ema+relax_graph","ema+relax_roll","ema+relax_random","ema+relax_shuffled"]:
565
  if name not in R: continue
@@ -571,7 +520,6 @@ def main():
571
  gc_m=sum(gc_all)/len(gc_all); mr_m=sum(mr_all)/len(mr_all)
572
  print(f" {name:<25}: grad_cos={gc_m:.4f} mag_ratio={mr_m:.4f}",flush=True)
573
 
574
- # Save
575
  all_results={"configs":R,"alpha_sweep":alpha_results}
576
  with open("exp5.json","w") as f:
577
  json.dump(all_results,f,indent=2,default=str)
 
29
  import tiktoken
30
  print("imports ok",flush=True)
31
 
 
 
 
32
  class Corpus:
33
  _i=None
34
  @classmethod
 
53
  def mg(s):
54
  g=torch.Generator(device="cpu"); g.manual_seed(s); return g
55
 
 
 
 
56
  class SparseBwd(torch.autograd.Function):
57
  @staticmethod
58
  def forward(ctx,x,w,b,ac,cs,sdx):
 
80
  if not self.se or self.ac is None: return F.linear(x,self.weight,self.bias)
81
  return SparseBwd.apply(x,self.weight,self.bias,self.ac,self.cs,self.sdx)
82
 
 
 
 
83
  class Attn(nn.Module):
84
  def __init__(self,d,nh,bs,do):
85
  super().__init__(); self.nh=nh; self.hd=d//nh
 
118
  def np(self): return sum(p.numel() for p in self.parameters())
119
  def gsl(m): return [x for x in m.modules() if isinstance(x,SL)]
120
 
 
 
 
121
  class Sched:
122
  def __init__(self,model,frac,cs,dev,beta=0.95,sim_hist=128,min_sim=8):
123
  self.frac,self.cs,self.dev,self.beta=frac,cs,dev,beta
 
172
  for _,ids in self.m2i.items(): ok[ids[:,None],ids[None,:]]=True
173
  self.similarity=torch.where(ok,S,torch.zeros_like(S))
174
  def mask_jaccard(self):
 
175
  if self.prev_act.sum()==0: return 0.0
176
  i=(self.act&self.prev_act).sum().item()
177
  u=(self.act|self.prev_act).sum().item()
178
  return i/max(u,1)
179
 
 
 
 
180
  class GraphRelaxer:
 
181
  def __init__(self, sched, alpha=0.1, iters=3):
182
  self.sched,self.alpha,self.iters=sched,alpha,iters
183
  @torch.no_grad()
 
196
  Wf=W.reshape(nc,-1); Wa=(S_n@Wf).view(nc,cs,di)
197
  W[li]=(1-self.alpha)*W[li]+self.alpha*Wa[li]
198
  m.weight.data=W.view(m.out_features,di)
199
+ deltas[m]=W[li]-W_before
200
  return deltas
201
 
202
  class RollRelaxer:
 
203
  def __init__(self, sched, alpha=0.1, iters=3):
204
  self.sched,self.alpha,self.iters=sched,alpha,iters
205
  @torch.no_grad()
 
219
  return deltas
220
 
221
  class RandomRelaxer:
 
222
  def __init__(self, sched, alpha=0.1, iters=3):
223
  self.sched,self.alpha,self.iters=sched,alpha,iters
224
  self._rand_sim=None
 
226
  if self._rand_sim is not None: return self._rand_sim
227
  S=self.sched.similarity
228
  if S is None: return None
 
229
  R=torch.rand_like(S)*S.abs().mean()
230
  R.fill_diagonal_(0)
231
  ok=torch.zeros_like(R,dtype=torch.bool)
 
252
  return deltas
253
 
254
  class ShuffledGraphRelaxer:
 
255
  def __init__(self, sched, alpha=0.1, iters=3):
256
  self.sched,self.alpha,self.iters=sched,alpha,iters
257
  self._shuf_sim=None
 
260
  S=self.sched.similarity
261
  if S is None: return None
262
  Ss=S.clone()
 
263
  for _,ids in self.sched.m2i.items():
264
  n=len(ids)
265
+ block=Ss[ids][:,ids].clone()
 
266
  perm=torch.randperm(n,device=S.device)
267
  block=block[perm][:,perm]
268
  block.fill_diagonal_(0)
 
289
  return deltas
290
 
291
  class NullRelaxer:
 
292
  def relax(self): return {}
293
 
 
 
 
294
  class CAdam:
295
  def __init__(self,model,lr=3e-4,cs=64):
296
  self.model,self.lr,self.cs=model,lr,cs
 
316
  s,e=c*self.cs,(c+1)*self.cs
317
  p.data[s:e].sub_(m[s:e]/(torch.sqrt(v[s:e])+1e-8),alpha=self.lr)
318
 
 
 
 
319
  @torch.no_grad()
320
  def ev(model,corpus,bs,n=20,seed=9999):
321
  model.eval(); ls=[model(*corpus.get_batch("val",bs,mg(seed+i)))[1].item() for i in range(n)]
322
  model.train(); a=sum(ls)/len(ls); return a,math.exp(min(a,20))
323
 
 
 
 
 
324
  def compute_relaxer_diagnostics(model, sched, relaxer_deltas, x, y, corpus, bs, cs):
325
  """
326
  Compare relaxer delta on inactive chunks to what dense gradient would have been.
 
328
  """
329
  if not relaxer_deltas: return None, None
330
 
331
+ # Need gradients enabled for the dense forward/backward
332
  for m in gsl(model): m.se=False
333
  for p in model.parameters(): p.grad=None
334
+ with torch.enable_grad():
335
+ _,lo=model(x,y)
336
+ lo.backward()
337
 
338
  cos_sims=[]; mag_ratios=[]
339
+ with torch.no_grad():
340
+ for m,delta in relaxer_deltas.items():
341
+ if m not in sched.m2i: continue
342
+ ids=sched.m2i[m]; nc=len(ids); di=m.weight.shape[1]
343
+ la=sched.act[ids]; li=~la
344
+ if li.sum()==0 or m.weight.grad is None: continue
345
+ dense_g=m.weight.grad.view(nc,cs,di)[li]
346
+ d_flat=delta.reshape(-1); g_flat=dense_g.reshape(-1)
347
+ dn=d_flat.norm(); gn=g_flat.norm()
348
+ if dn>1e-12 and gn>1e-12:
349
+ cos_sims.append(F.cosine_similarity(d_flat.unsqueeze(0),g_flat.unsqueeze(0)).item())
350
+ mag_ratios.append((dn/gn).item())
351
+
 
 
 
 
352
  for m in gsl(model): m.se=True
353
  for p in model.parameters(): p.grad=None
354
 
355
  if not cos_sims: return None, None
356
  return sum(cos_sims)/len(cos_sims), sum(mag_ratios)/len(mag_ratios)
357
 
 
 
 
358
  def run1(mode, steps, bs, bsz, nl, nh, d, cs, af, wu, an, lr, dev, seed,
359
  alpha=0.1, iters=3, diag_interval=100):
360
  torch.manual_seed(seed); random.seed(seed)
 
371
  if is_sparse:
372
  sched=Sched(model,af,cs,dev)
373
  elif needs_relax:
 
374
  sched=Sched(model,af,cs,dev)
375
 
376
  opt=CAdam(model,lr,cs)
377
 
 
378
  if not needs_relax:
379
  relaxer=NullRelaxer()
380
  elif "random" in mode:
 
402
  for m in gsl(model): m.se=True; m.sdx=False
403
  else:
404
  for m in gsl(model): m.se=False; m.ac=None
 
405
  if sched:
406
  sched.choose(step,wu,an)
407
 
 
413
 
414
  opt.step()
415
 
 
416
  relax_deltas={}
417
  if needs_relax and step>=wu+an:
 
418
  if is_dense and sched:
419
  for m,ids in sched.m2i.items():
420
  m.ac=sched.m2l[m][sched.act[ids]]
 
422
  if is_dense and sched:
423
  for m in gsl(model): m.ac=None
424
 
 
425
  if step%50==0:
426
  vl,_=ev(model,corpus,bs,n=10,seed=7777)
427
  val_curve.append((step,vl))
 
455
  sl=(sum((x-ml)**2 for x in vls)/max(1,len(vls)-1))**0.5
456
  return {"ml":ml,"sl":sl,"rs":rs,"ms":sum(r["ms"] for r in rs)/len(rs)}
457
 
 
 
 
458
  def main():
459
  p=argparse.ArgumentParser()
460
  p.add_argument("--device",default="cuda"); p.add_argument("--steps",type=int,default=1000)
 
474
  base=dict(steps=a.steps,bs=a.bs,bsz=a.bsz,nl=a.nl,nh=a.nh,d=a.d,cs=a.cs,af=a.af,
475
  wu=a.wu,an=a.an,lr=a.lr,dev=a.device,alpha=0.1,iters=3)
476
 
 
477
  configs=[
478
  ("dense", "dense"),
479
  ("dense+relax_graph", "dense+relax_graph"),
 
500
  r=R[name]
501
  print(f"{name:<25} | {r['ml']:.4f} ± {r['sl']:.4f} | {r['ms']:>7.1f}",flush=True)
502
 
 
503
  print(f"\n--- Alpha sweep (ema+relax_graph, 5 seeds) ---",flush=True)
504
  print(f"{'alpha':>6} | {'Val Loss':>20} | {'ms/step':>8}",flush=True)
505
  print("-"*42,flush=True)
 
509
  alpha_results[alpha]=r
510
  print(f"{alpha:>6.2f} | {r['ml']:.4f} ± {r['sl']:.4f} | {r['ms']:>7.1f}",flush=True)
511
 
 
512
  print(f"\n--- Diagnostics (grad_cos, mag_ratio) ---",flush=True)
513
  for name in ["ema+relax_graph","ema+relax_roll","ema+relax_random","ema+relax_shuffled"]:
514
  if name not in R: continue
 
520
  gc_m=sum(gc_all)/len(gc_all); mr_m=sum(mr_all)/len(mr_all)
521
  print(f" {name:<25}: grad_cos={gc_m:.4f} mag_ratio={mr_m:.4f}",flush=True)
522
 
 
523
  all_results={"configs":R,"alpha_sweep":alpha_results}
524
  with open("exp5.json","w") as f:
525
  json.dump(all_results,f,indent=2,default=str)