Scott/Codex committed on
Commit
df559be
·
1 Parent(s): 5c319bd

Upgrade VRAM-first DiffusionBlocks trainer

Browse files
README.md CHANGED
@@ -44,6 +44,8 @@ whose released code is ViT/classification only.
44
  - `--tie_weights` now means AR, SAT, and NAT share the embedding projection tensor. This drops the live parameter count from 1,213,418,242 to 716,595,202.
45
  - Old untied checkpoint head matrices are intentionally skipped under tied mode; core weights still warm-start and the optimizer can rebuild.
46
  - SAT now uses fused vocab-streaming CE in the dblock path, and the dblock step releases AR/SAT activations before moving to the next objective.
 
 
47
 
48
  ## Honest findings
49
  - DiffusionBlocks and gradient-checkpointing are **substitutes** for activation
@@ -59,4 +61,10 @@ only: old untied AR/SAT/NAT head tensors are skipped when tied heads are active,
59
  optimizer state is allowed to reset. The priority is lower VRAM over preserving every
60
  old training assumption.
61
 
 
 
 
 
 
 
62
  License: Apache-2.0 (matching the upstream method).
 
44
  - `--tie_weights` now means AR, SAT, and NAT share the embedding projection tensor. This drops the live parameter count from 1,213,418,242 to 716,595,202.
45
  - Old untied checkpoint head matrices are intentionally skipped under tied mode; core weights still warm-start and the optimizer can rebuild.
46
  - SAT now uses fused vocab-streaming CE in the dblock path, and the dblock step releases AR/SAT activations before moving to the next objective.
47
+ - DBlock now uses loss-balanced block scheduling after warmup, per-block EMA diagnostics, sigma-range curriculum, objective weights, and peak VRAM logging.
48
+ - The folded-in DBlock path now builds the dense causal/SAT masks once per objective instead of once per layer, and NAT obeys `--nat_max_tokens` so long-context AR does not force full-context NAT memory.
49
 
50
  ## Honest findings
51
  - DiffusionBlocks and gradient-checkpointing are **substitutes** for activation
 
61
  optimizer state is allowed to reset. The priority is lower VRAM over preserving every
62
  old training assumption.
63
 
64
+ Upgrade note (2026-05-29): DBlock is no longer just a random-block prototype. The live
65
+ path now has loss-balanced scheduling, sigma curriculum, DBlock objective weights,
66
+ per-block loss/VRAM logging, single-build masks per objective, and NAT token capping.
67
+ These are meant to preserve the VRAM breakthrough while making block-wise training
68
+ less brittle over long runs.
69
+
70
  License: Apache-2.0 (matching the upstream method).
dblocks_train.py CHANGED
@@ -5,63 +5,239 @@ Block-wise EDM denoising on the real Encoder blocks, supervising AR + SAT(fixed+
5
  CE. Reuses the live data stream / optimizer / checkpointing of nB300_agillm4.
6
  Lazy-imports nB300 inside functions to avoid a circular import.
7
  """
8
- import math, random, numpy as np, torch, torch.nn as nn, torch.nn.functional as F
 
 
 
 
 
9
  import torch.utils.checkpoint as _ck
10
  from fused_ce import fused_ce
11
- SD=0.5
12
- def _cdf(x): return 0.5*(1+math.erf(x/math.sqrt(2)))
13
- def _ppf(p): return float(torch.erfinv(torch.tensor(2*p-1.0))*math.sqrt(2))
14
- def _block_sigmas(B,smin=0.002,smax=80.0,pm=-1.2,ps=1.2):
15
- a,b=_cdf((math.log(smin)-pm)/ps),_cdf((math.log(smax)-pm)/ps)
16
- return [float(np.exp(pm+ps*_ppf(a+(b-a)*(i/B)))) for i in range(B+1)]
17
- def _edm_pre(s): s=s[:,None,None]; return SD**2/(s**2+SD**2), s*SD/(s**2+SD**2)**0.5, 1/(s**2+SD**2)**0.5
18
- def _edm_w(s,wmax=5.0): return float(((s**2+SD**2)/(s*SD)**2).clamp(max=wmax).mean())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  def _dblock_init(core, args):
21
- B=int(getattr(args,"dblock_blocks",4)); L=len(core.blocks); sp=max(1,L//B)
22
- asg=[list(range(i*sp,(i+1)*sp)) for i in range(B)]; asg[-1]=list(range((B-1)*sp,L))
 
 
 
 
 
23
  print(f"[dblock] DiffusionBlocks mode: {L} layers -> {B} blocks {asg}")
24
- print(f"[dblock] equi-prob sigma boundaries: {[round(x,3) for x in _block_sigmas(B)]}")
25
- return {"B":B,"assign":asg,"bsig":_block_sigmas(B)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  def _dblock_step(core, ar_h, sat_h, nat_h, opt, scaler, args, ids, state):
28
  import nB300_agillm4 as M
29
- B=state["B"]; asg=state["assign"]; bs=state["bsig"]; T=ids.size(1)
30
- bi=random.randrange(B); lo,hi=sorted([bs[bi],bs[bi+1]]); layers=asg[bi]
31
- sig=torch.from_numpy(np.exp(np.random.uniform(math.log(max(lo,1e-4)),math.log(hi),ids.size(0))).astype("float32")).to(ids.device)
32
- cs,co,ci=_edm_pre(sig); w=_edm_w(sig); SATB=M.SAT_BLOCK
33
- # ---- AR: causal diffusion denoise ----
34
- with M.amp(args.amp):
35
- emb=core.emb(ids); zt=emb+sig[:,None,None]*torch.randn_like(emb); h=ci*zt
36
- for li in layers: h=_ck.checkpoint(core.blocks[li], h, M.causal_mask(T), use_reentrant=False)
37
- Dn=core.ln(cs*zt+co*h)
38
- ar=w*fused_ce(Dn[:,:-1].contiguous(), ar_h.proj.weight, ids[:,1:].contiguous())
39
- scaler.scale(ar).backward()
40
- ar_val=float(ar.detach())
41
- del emb, zt, h, Dn, ar
42
- # ---- SAT: block-causal diffusion; fixed proj + variable gate ----
43
- with M.amp(args.amp):
44
- emb2=core.emb(ids); zt2=emb2+sig[:,None,None]*torch.randn_like(emb2); h2=ci*zt2
45
- for li in layers: h2=_ck.checkpoint(core.blocks[li], h2, M.sat_mask(T), use_reentrant=False)
46
- Ds=core.ln(cs*zt2+co*h2); last=Ds[:,-SATB:]
47
- satf=fused_ce(last.contiguous(), sat_h.proj.weight, ids[:,1:SATB+1].contiguous())
48
- satv=(M.EMIT_LAMBDA*F.cross_entropy(sat_h.gate(Ds[:,0].float()), torch.ones(ids.size(0),dtype=torch.long,device=ids.device))) if sat_h.gate is not None else 0.0
49
- sat=w*(satf+satv)
50
- scaler.scale(sat).backward()
51
- sat_val=float(sat.detach())
52
- del emb2, zt2, h2, Ds, last, satf, satv, sat
53
- # ---- NAT: bidirectional mask-predict ----
54
- nat_val=0.0
55
- if nat_h is not None:
56
- ratio=min(max(float(getattr(args,"nat_mask_ratio",0.5)),0.05),0.95)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  with M.amp(args.amp):
58
- nat_ids=ids.clone(); m=torch.rand(ids.shape,device=ids.device)<ratio
59
- if not bool(m.any()): m[...,-1]=True
60
- nat_ids[m]=M.BLANK; hn=core.emb(nat_ids)
61
- for li in layers: hn=_ck.checkpoint(core.blocks[li], hn, None, use_reentrant=False)
62
- Dnat=core.ln(hn)
63
- nat=fused_ce(Dnat[m], nat_h.proj.weight, ids[m]); scaler.scale(nat).backward(); nat_val=float(nat.detach()); del nat_ids, m, hn, Dnat, nat
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  scaler.unscale_(opt)
65
- nn.utils.clip_grad_norm_([p for g in opt.param_groups for p in g["params"]],1.0)
66
- scaler.step(opt); scaler.update(); opt.zero_grad(set_to_none=True)
67
- return ar_val+sat_val+nat_val
 
 
 
 
 
 
 
 
 
 
 
5
  CE. Reuses the live data stream / optimizer / checkpointing of nB300_agillm4.
6
  Lazy-imports nB300 inside functions to avoid a circular import.
7
  """
8
+ import math
9
+ import random
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
  import torch.utils.checkpoint as _ck
15
  from fused_ce import fused_ce
16
+
17
+ SD = 0.5
18
+
19
+
20
+ def _cdf(x):
21
+ return 0.5 * (1 + math.erf(x / math.sqrt(2)))
22
+
23
+
24
+ def _ppf(p):
25
+ return float(torch.erfinv(torch.tensor(2 * p - 1.0)) * math.sqrt(2))
26
+
27
+
28
+ def _block_sigmas(B, smin=0.002, smax=80.0, pm=-1.2, ps=1.2):
29
+ a, b = _cdf((math.log(smin) - pm) / ps), _cdf((math.log(smax) - pm) / ps)
30
+ return [float(np.exp(pm + ps * _ppf(a + (b - a) * (i / B)))) for i in range(B + 1)]
31
+
32
+
33
+ def _edm_pre(s):
34
+ s = s[:, None, None]
35
+ return SD**2 / (s**2 + SD**2), s * SD / (s**2 + SD**2) ** 0.5, 1 / (s**2 + SD**2) ** 0.5
36
+
37
+
38
+ def _edm_w(s, wmax=5.0):
39
+ return float(((s**2 + SD**2) / (s * SD) ** 2).clamp(max=wmax).mean())
40
+
41
 
42
  def _dblock_init(core, args):
43
+ B = int(getattr(args, "dblock_blocks", 4))
44
+ L = len(core.blocks)
45
+ sp = max(1, L // B)
46
+ asg = [list(range(i * sp, (i + 1) * sp)) for i in range(B)]
47
+ asg[-1] = list(range((B - 1) * sp, L))
48
+ bsig = _block_sigmas(B)
49
+ schedule = getattr(args, "dblock_schedule", "loss_balanced")
50
  print(f"[dblock] DiffusionBlocks mode: {L} layers -> {B} blocks {asg}")
51
+ print(f"[dblock] schedule={schedule} sigma boundaries: {[round(x, 3) for x in bsig]}")
52
+ return {
53
+ "B": B,
54
+ "assign": asg,
55
+ "bsig": bsig,
56
+ "step": 0,
57
+ "counts": [0 for _ in range(B)],
58
+ "loss_ema": [None for _ in range(B)],
59
+ }
60
+
61
+
62
+ def _choose_block(state, args):
63
+ B = state["B"]
64
+ schedule = str(getattr(args, "dblock_schedule", "loss_balanced") or "loss_balanced").lower()
65
+ step = int(state.get("step", 0))
66
+ counts = state.setdefault("counts", [0 for _ in range(B)])
67
+ emas = state.setdefault("loss_ema", [None for _ in range(B)])
68
+ if schedule == "random":
69
+ return random.randrange(B)
70
+ if schedule == "roundrobin":
71
+ return step % B
72
+ explore = float(getattr(args, "dblock_explore", 0.05))
73
+ warmup = int(getattr(args, "dblock_warmup_steps", max(8, B * 2)))
74
+ if step < warmup or any(c == 0 for c in counts):
75
+ return min(range(B), key=lambda i: (counts[i], i))
76
+ if explore > 0.0 and random.random() < explore:
77
+ return min(range(B), key=lambda i: (counts[i], i))
78
+ return max(range(B), key=lambda i: (-1.0 if emas[i] is None else emas[i], -counts[i]))
79
+
80
+
81
+ def _sample_sigma(ids, lo, hi, args, state):
82
+ cur_step = int(state.get("step", 0))
83
+ curriculum = int(getattr(args, "dblock_sigma_curriculum_steps", 0))
84
+ if curriculum > 0:
85
+ frac = min(1.0, max(0.05, (cur_step + 1) / float(curriculum)))
86
+ hi = lo * ((hi / max(lo, 1e-8)) ** frac)
87
+ sig_np = np.exp(
88
+ np.random.uniform(
89
+ math.log(max(lo, 1e-4)),
90
+ math.log(max(hi, lo + 1e-4)),
91
+ ids.size(0),
92
+ ).astype("float32")
93
+ )
94
+ return torch.from_numpy(sig_np).to(ids.device)
95
+
96
+
97
+ def _maybe_log(state, args, bi, layers, ar_val, sat_val, nat_val, total_val, peak_alloc, peak_reserved):
98
+ log_every = int(getattr(args, "dblock_log_every", 50))
99
+ step = int(state.get("step", 0))
100
+ if log_every <= 0 or step % log_every != 0:
101
+ return
102
+ counts = ",".join(str(x) for x in state.get("counts", []))
103
+ emas = ",".join("nan" if x is None else f"{x:.2f}" for x in state.get("loss_ema", []))
104
+ mem = ""
105
+ if peak_alloc is not None:
106
+ mem = f" peak_alloc={peak_alloc:.2f}GB peak_reserved={peak_reserved:.2f}GB"
107
+ print(
108
+ f"[dblock] step={step} block={bi} layers={layers} "
109
+ f"loss={total_val:.3f} ar={ar_val:.3f} sat={sat_val:.3f} nat={nat_val:.3f} "
110
+ f"counts=[{counts}] ema=[{emas}]{mem}",
111
+ flush=True,
112
+ )
113
+
114
+
115
+ def _update_stats(state, bi, loss_value):
116
+ B = state["B"]
117
+ counts = state.setdefault("counts", [0 for _ in range(B)])
118
+ emas = state.setdefault("loss_ema", [None for _ in range(B)])
119
+ counts[bi] += 1
120
+ prev = emas[bi]
121
+ beta = 0.96
122
+ emas[bi] = float(loss_value) if prev is None else beta * float(prev) + (1.0 - beta) * float(loss_value)
123
+ state["step"] = int(state.get("step", 0)) + 1
124
+
125
 
126
  def _dblock_step(core, ar_h, sat_h, nat_h, opt, scaler, args, ids, state):
127
  import nB300_agillm4 as M
128
+
129
+ if torch.cuda.is_available():
130
+ torch.cuda.reset_peak_memory_stats()
131
+
132
+ B = state["B"]
133
+ asg = state["assign"]
134
+ bs = state["bsig"]
135
+ T = ids.size(1)
136
+ bi = _choose_block(state, args)
137
+ lo, hi = sorted([bs[bi], bs[bi + 1]])
138
+ layers = asg[bi]
139
+ sig = _sample_sigma(ids, lo, hi, args, state)
140
+ cs, co, ci = _edm_pre(sig)
141
+ w = _edm_w(sig, float(getattr(args, "dblock_edm_wmax", 5.0)))
142
+ SATB = M.SAT_BLOCK
143
+ ar_weight = float(getattr(args, "dblock_ar_weight", 1.0))
144
+ sat_weight = float(getattr(args, "dblock_sat_weight", 1.0))
145
+ nat_weight = float(getattr(args, "dblock_nat_weight", 1.0)) * float(getattr(args, "nat_loss_weight", 1.0))
146
+
147
+ ar_val = 0.0
148
+ sat_val = 0.0
149
+ nat_val = 0.0
150
+
151
+ if ar_weight > 0.0:
152
+ causal = M.causal_mask(T)
153
+ with M.amp(args.amp):
154
+ emb = core.emb(ids)
155
+ zt = emb + sig[:, None, None] * torch.randn_like(emb)
156
+ h = ci * zt
157
+ for li in layers:
158
+ h = _ck.checkpoint(core.blocks[li], h, causal, use_reentrant=False)
159
+ Dn = core.ln(cs * zt + co * h)
160
+ ar = ar_weight * w * fused_ce(Dn[:, :-1].contiguous(), ar_h.proj.weight, ids[:, 1:].contiguous())
161
+ ar_val = float(ar.detach())
162
+ scaler.scale(ar).backward()
163
+ del causal, emb, zt, h, Dn, ar
164
+
165
+ do_sat = (not getattr(args, "ar_only", False)) and (
166
+ int(getattr(args, "sat_every", 1)) <= 1 or ((int(state.get("step", 0)) + 1) % int(getattr(args, "sat_every", 1)) == 0)
167
+ )
168
+ if sat_weight > 0.0 and do_sat:
169
+ smask = M.sat_mask(T)
170
+ with M.amp(args.amp):
171
+ emb2 = core.emb(ids)
172
+ zt2 = emb2 + sig[:, None, None] * torch.randn_like(emb2)
173
+ h2 = ci * zt2
174
+ for li in layers:
175
+ h2 = _ck.checkpoint(core.blocks[li], h2, smask, use_reentrant=False)
176
+ Ds = core.ln(cs * zt2 + co * h2)
177
+ last = Ds[:, -SATB:]
178
+ satf = fused_ce(last.contiguous(), sat_h.proj.weight, ids[:, 1 : SATB + 1].contiguous())
179
+ satv = (
180
+ M.EMIT_LAMBDA
181
+ * F.cross_entropy(
182
+ sat_h.gate(Ds[:, 0].float()),
183
+ torch.ones(ids.size(0), dtype=torch.long, device=ids.device),
184
+ )
185
+ if sat_h.gate is not None
186
+ else 0.0
187
+ )
188
+ sat = sat_weight * w * (satf + satv)
189
+ sat_val = float(sat.detach())
190
+ scaler.scale(sat).backward()
191
+ del smask, emb2, zt2, h2, Ds, last, satf, satv, sat
192
+
193
+ do_nat = (
194
+ nat_h is not None
195
+ and nat_weight > 0.0
196
+ and (not getattr(args, "ar_only", False))
197
+ and int(getattr(args, "nat_every", 1)) > 0
198
+ and (
199
+ int(getattr(args, "nat_every", 1)) <= 1
200
+ or ((int(state.get("step", 0)) + 1) % int(getattr(args, "nat_every", 1)) == 0)
201
+ )
202
+ )
203
+ if do_nat:
204
+ ratio = min(max(float(getattr(args, "nat_mask_ratio", 0.5)), 0.05), 0.95)
205
+ nat_ids = M._nat_ids_for_training(ids, int(getattr(args, "nat_max_tokens", 0)))
206
  with M.amp(args.amp):
207
+ nat_in = nat_ids.clone()
208
+ m = torch.rand(nat_ids.shape, device=nat_ids.device) < ratio
209
+ if not bool(m.any()):
210
+ m[..., -1] = True
211
+ nat_in[m] = M.BLANK
212
+ hn = core.emb(nat_in)
213
+ for li in layers:
214
+ hn = _ck.checkpoint(core.blocks[li], hn, None, use_reentrant=False)
215
+ Dnat = core.ln(hn)
216
+ nat = nat_weight * fused_ce(Dnat[m], nat_h.proj.weight, nat_ids[m])
217
+ nat_val = float(nat.detach())
218
+ scaler.scale(nat).backward()
219
+ del nat_ids, nat_in, m, hn, Dnat, nat
220
+
221
+ total_val = ar_val + sat_val + nat_val
222
+ if not math.isfinite(total_val):
223
+ opt.zero_grad(set_to_none=True)
224
+ if torch.cuda.is_available():
225
+ torch.cuda.empty_cache()
226
+ print(f"[dblock] non-finite loss {total_val}; skipped optimizer step", flush=True)
227
+ _update_stats(state, bi, total_val)
228
+ return total_val
229
+
230
  scaler.unscale_(opt)
231
+ nn.utils.clip_grad_norm_([p for g in opt.param_groups for p in g["params"]], 1.0)
232
+ scaler.step(opt)
233
+ scaler.update()
234
+ opt.zero_grad(set_to_none=True)
235
+
236
+ peak_alloc = None
237
+ peak_reserved = None
238
+ if torch.cuda.is_available():
239
+ peak_alloc = torch.cuda.max_memory_allocated() / (1024**3)
240
+ peak_reserved = torch.cuda.max_memory_reserved() / (1024**3)
241
+ _update_stats(state, bi, total_val)
242
+ _maybe_log(state, args, bi, layers, ar_val, sat_val, nat_val, total_val, peak_alloc, peak_reserved)
243
+ return total_val
fused_ce.py CHANGED
@@ -4,28 +4,54 @@ recomputes softmax per vocab-chunk (grad = softmax - onehot). This is the
4
  DiffusionBlocks 'process in chunks, don't hold the whole thing' idea applied to
5
  the output head instead of network depth."""
6
  import torch
 
7
  class FusedCE(torch.autograd.Function):
8
  @staticmethod
9
  def forward(ctx, h, W, tgt, vchunk=16384):
10
- N,d=h.shape; V=W.shape[0]; hf=h.float()
11
- m=torch.full((N,),-1e30,device=h.device); s=torch.zeros(N,device=h.device); zt=torch.zeros(N,device=h.device)
12
- for c in range(0,V,vchunk):
13
- lg=hf@W[c:c+vchunk].float().T # [N,vchunk] transient only
14
- cm=lg.max(1).values; nm=torch.maximum(m,cm)
15
- s=s*torch.exp(m-nm)+torch.exp(lg-nm[:,None]).sum(1); m=nm
16
- ic=(tgt>=c)&(tgt<c+vchunk)
17
- if ic.any(): zt[ic]=lg[ic,tgt[ic]-c]
18
- lse=m+torch.log(s); ctx.save_for_backward(h,W,tgt,lse); ctx.vchunk=vchunk
19
- return (lse-zt).mean()
 
 
 
 
 
 
 
 
 
 
 
 
20
  @staticmethod
21
  def backward(ctx, go):
22
- h,W,tgt,lse=ctx.saved_tensors; vc=ctx.vchunk; N,d=h.shape; V=W.shape[0]; hf=h.float()
23
- gh=torch.zeros_like(hf); gW=torch.zeros(W.shape,device=W.device,dtype=torch.float32); sc=float(go)/N
24
- for c in range(0,V,vc):
25
- Wc=W[c:c+vc].float(); p=torch.exp(hf@Wc.T-lse[:,None]) # softmax chunk [N,vchunk]
26
- ic=(tgt>=c)&(tgt<c+vc)
27
- if ic.any(): p[ic,tgt[ic]-c]-=1.0
28
- p*=sc; gh+=p@Wc; gW[c:c+vc]+=p.T@hf
29
- return gh.to(h.dtype), gW.to(W.dtype), None, None
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  def fused_ce(h, W, tgt, vchunk=16384):
31
- return FusedCE.apply(h.reshape(-1,h.size(-1)), W, tgt.reshape(-1), vchunk)
 
4
  DiffusionBlocks 'process in chunks, don't hold the whole thing' idea applied to
5
  the output head instead of network depth."""
6
  import torch
7
+
8
  class FusedCE(torch.autograd.Function):
9
  @staticmethod
10
  def forward(ctx, h, W, tgt, vchunk=16384):
11
+ with torch.cuda.amp.autocast(enabled=False):
12
+ hf = h.float()
13
+ Wf = W.float()
14
+ N, d = h.shape
15
+ V = W.shape[0]
16
+ m = torch.full((N,), -1e30, device=h.device, dtype=torch.float32)
17
+ s = torch.zeros(N, device=h.device, dtype=torch.float32)
18
+ zt = torch.zeros(N, device=h.device, dtype=torch.float32)
19
+ for c in range(0, V, vchunk):
20
+ lg = hf @ Wf[c:c+vchunk].T # [N,vchunk] transient only
21
+ cm = lg.max(1).values
22
+ nm = torch.maximum(m, cm)
23
+ s = s * torch.exp(m - nm) + torch.exp(lg - nm[:, None]).sum(1)
24
+ m = nm
25
+ ic = (tgt >= c) & (tgt < c+vchunk)
26
+ if ic.any():
27
+ zt[ic] = lg[ic, tgt[ic] - c].float()
28
+ lse = m + torch.log(s)
29
+ ctx.save_for_backward(h, W, tgt, lse)
30
+ ctx.vchunk = vchunk
31
+ return (lse - zt).mean()
32
+
33
  @staticmethod
34
  def backward(ctx, go):
35
+ h, W, tgt, lse = ctx.saved_tensors
36
+ vc = ctx.vchunk
37
+ N, d = h.shape
38
+ V = W.shape[0]
39
+ with torch.cuda.amp.autocast(enabled=False):
40
+ hf = h.float()
41
+ Wc_all = W.float()
42
+ gh = torch.zeros_like(hf)
43
+ gW = torch.zeros(W.shape, device=W.device, dtype=torch.float32)
44
+ sc = float(go) / N
45
+ for c in range(0, V, vc):
46
+ Wc = Wc_all[c:c+vc]
47
+ p = torch.exp(hf @ Wc.T - lse[:, None]) # softmax chunk [N,vchunk]
48
+ ic = (tgt >= c) & (tgt < c+vc)
49
+ if ic.any():
50
+ p[ic, tgt[ic] - c] -= 1.0
51
+ p *= sc
52
+ gh += p @ Wc
53
+ gW[c:c+vc] += p.T @ hf
54
+ return gh.to(h.dtype), gW.to(W.dtype), None, None
55
+
56
  def fused_ce(h, W, tgt, vchunk=16384):
57
+ return FusedCE.apply(h.reshape(-1, h.size(-1)), W, tgt.reshape(-1), vchunk)
nB300_agillm4_vram_dblock.py CHANGED
@@ -2806,6 +2806,21 @@ def main():
2806
  help="Fraction of positions masked to BLANK for the NAT mask-predict (CMLM) objective.")
2807
  tr.add_argument("--dblock", action="store_true", help="DiffusionBlocks block-wise denoising training (low VRAM).")
2808
  tr.add_argument("--dblock_blocks", type=int, default=4, help="Partition layers into this many DiffusionBlocks blocks.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2809
  tr.add_argument("--reinit_nat", action="store_true",
2810
  help="Reinitialize NAT head weights after load (use once when switching to mask-predict).")
2811
  tr.add_argument("--seed_nat_from_ar", action="store_true",
 
2806
  help="Fraction of positions masked to BLANK for the NAT mask-predict (CMLM) objective.")
2807
  tr.add_argument("--dblock", action="store_true", help="DiffusionBlocks block-wise denoising training (low VRAM).")
2808
  tr.add_argument("--dblock_blocks", type=int, default=4, help="Partition layers into this many DiffusionBlocks blocks.")
2809
+ tr.add_argument("--dblock_schedule", choices=["random", "roundrobin", "loss_balanced"], default="loss_balanced",
2810
+ help="How --dblock chooses the next layer block. loss_balanced focuses blocks whose EMA loss is highest after warmup.")
2811
+ tr.add_argument("--dblock_warmup_steps", type=int, default=16,
2812
+ help="Initial DBlock steps spent covering every block before loss-balanced scheduling.")
2813
+ tr.add_argument("--dblock_explore", type=float, default=0.05,
2814
+ help="Exploration rate for loss-balanced DBlock scheduling.")
2815
+ tr.add_argument("--dblock_log_every", type=int, default=25,
2816
+ help="Print DBlock block/loss/VRAM diagnostics every N DBlock steps; 0 disables.")
2817
+ tr.add_argument("--dblock_sigma_curriculum_steps", type=int, default=2000,
2818
+ help="Warm sigma ranges from easy to full span over this many DBlock steps; 0 disables.")
2819
+ tr.add_argument("--dblock_edm_wmax", type=float, default=5.0,
2820
+ help="Cap for EDM loss weighting in DBlock mode.")
2821
+ tr.add_argument("--dblock_ar_weight", type=float, default=1.0)
2822
+ tr.add_argument("--dblock_sat_weight", type=float, default=1.0)
2823
+ tr.add_argument("--dblock_nat_weight", type=float, default=1.0)
2824
  tr.add_argument("--reinit_nat", action="store_true",
2825
  help="Reinitialize NAT head weights after load (use once when switching to mask-predict).")
2826
  tr.add_argument("--seed_nat_from_ar", action="store_true",
relaunch_agillm4_dblock.sh CHANGED
@@ -13,7 +13,9 @@ CKPT="$(ls -1t "$SAVE_DIR"/pretrain_step*.pt 2>/dev/null | head -1)"
13
  exec >> /workspace/agillm4_floor_train.log 2>&1
14
  echo "RELAUNCH_AGILLM4_DBLOCK $(date -u +%Y-%m-%dT%H:%M:%SZ) resume=$CKPT --dblock blocks=${AGILLM4_DBLOCKS:-4} tie_weights=1 attn=${AGILLM_ATTN_BACKEND}"
15
  exec python -u nB300_agillm4.py train --preset agillm4_floor --resume "$CKPT" \
16
- --dblock --dblock_blocks "${AGILLM4_DBLOCKS:-4}" --tie_weights \
 
 
17
  --batch_size 1 --block "${AGILLM4_BLOCK:-1280}" --amp --attn_backend "${AGILLM_ATTN_BACKEND}" --grad_checkpoint \
18
  --optimizer paged_adamw8bit --sat_every 1 --nat_every 1 --nat_max_tokens 768 --nat_mask_ratio 0.5 \
19
  --token_param_ratio 100 --save_dir "$SAVE_DIR" \
 
13
  exec >> /workspace/agillm4_floor_train.log 2>&1
14
  echo "RELAUNCH_AGILLM4_DBLOCK $(date -u +%Y-%m-%dT%H:%M:%SZ) resume=$CKPT --dblock blocks=${AGILLM4_DBLOCKS:-4} tie_weights=1 attn=${AGILLM_ATTN_BACKEND}"
15
  exec python -u nB300_agillm4.py train --preset agillm4_floor --resume "$CKPT" \
16
+ --dblock --dblock_blocks "${AGILLM4_DBLOCKS:-4}" --dblock_schedule "${AGILLM4_DBLOCK_SCHEDULE:-loss_balanced}" \
17
+ --dblock_warmup_steps "${AGILLM4_DBLOCK_WARMUP:-16}" --dblock_sigma_curriculum_steps "${AGILLM4_DBLOCK_SIGMA_CURRICULUM:-2000}" \
18
+ --dblock_log_every "${AGILLM4_DBLOCK_LOG_EVERY:-25}" --tie_weights \
19
  --batch_size 1 --block "${AGILLM4_BLOCK:-1280}" --amp --attn_backend "${AGILLM_ATTN_BACKEND}" --grad_checkpoint \
20
  --optimizer paged_adamw8bit --sat_every 1 --nat_every 1 --nat_max_tokens 768 --nat_mask_ratio 0.5 \
21
  --token_param_ratio 100 --save_dir "$SAVE_DIR" \