Scott/Codex commited on
Commit ·
df559be
1
Parent(s): 5c319bd
Upgrade VRAM-first DiffusionBlocks trainer
Browse files- README.md +8 -0
- dblocks_train.py +226 -50
- fused_ce.py +45 -19
- nB300_agillm4_vram_dblock.py +15 -0
- relaunch_agillm4_dblock.sh +3 -1
README.md
CHANGED
|
@@ -44,6 +44,8 @@ whose released code is ViT/classification only.
|
|
| 44 |
- `--tie_weights` now means AR, SAT, and NAT share the embedding projection tensor. This drops the live parameter count from 1,213,418,242 to 716,595,202.
|
| 45 |
- Old untied checkpoint head matrices are intentionally skipped under tied mode; core weights still warm-start and the optimizer can rebuild.
|
| 46 |
- SAT now uses fused vocab-streaming CE in the dblock path, and the dblock step releases AR/SAT activations before moving to the next objective.
|
|
|
|
|
|
|
| 47 |
|
| 48 |
## Honest findings
|
| 49 |
- DiffusionBlocks and gradient-checkpointing are **substitutes** for activation
|
|
@@ -59,4 +61,10 @@ only: old untied AR/SAT/NAT head tensors are skipped when tied heads are active,
|
|
| 59 |
optimizer state is allowed to reset. The priority is lower VRAM over preserving every
|
| 60 |
old training assumption.
|
| 61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
License: Apache-2.0 (matching the upstream method).
|
|
|
|
| 44 |
- `--tie_weights` now means AR, SAT, and NAT share the embedding projection tensor. This drops the live parameter count from 1,213,418,242 to 716,595,202.
|
| 45 |
- Old untied checkpoint head matrices are intentionally skipped under tied mode; core weights still warm-start and the optimizer can rebuild.
|
| 46 |
- SAT now uses fused vocab-streaming CE in the dblock path, and the dblock step releases AR/SAT activations before moving to the next objective.
|
| 47 |
+
- DBlock now uses loss-balanced block scheduling after warmup, per-block EMA diagnostics, sigma-range curriculum, objective weights, and peak VRAM logging.
|
| 48 |
+
- The folded-in DBlock path now builds the dense causal/SAT masks once per objective instead of once per layer, and NAT obeys `--nat_max_tokens` so long-context AR does not force full-context NAT memory.
|
| 49 |
|
| 50 |
## Honest findings
|
| 51 |
- DiffusionBlocks and gradient-checkpointing are **substitutes** for activation
|
|
|
|
| 61 |
optimizer state is allowed to reset. The priority is lower VRAM over preserving every
|
| 62 |
old training assumption.
|
| 63 |
|
| 64 |
+
Upgrade update 2026-05-29: DBlock is no longer just a random-block prototype. The live
|
| 65 |
+
path now has loss-balanced scheduling, sigma curriculum, DBlock objective weights,
|
| 66 |
+
per-block loss/VRAM logging, single-build masks per objective, and NAT token capping.
|
| 67 |
+
These are meant to preserve the VRAM breakthrough while making block-wise training
|
| 68 |
+
less brittle over long runs.
|
| 69 |
+
|
| 70 |
License: Apache-2.0 (matching the upstream method).
|
dblocks_train.py
CHANGED
|
@@ -5,63 +5,239 @@ Block-wise EDM denoising on the real Encoder blocks, supervising AR + SAT(fixed+
|
|
| 5 |
CE. Reuses the live data stream / optimizer / checkpointing of nB300_agillm4.
|
| 6 |
Lazy-imports nB300 inside functions to avoid a circular import.
|
| 7 |
"""
|
| 8 |
-
import math
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
import torch.utils.checkpoint as _ck
|
| 10 |
from fused_ce import fused_ce
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
return
|
| 17 |
-
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
def _dblock_init(core, args):
|
| 21 |
-
B=int(getattr(args,"dblock_blocks",4))
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
print(f"[dblock] DiffusionBlocks mode: {L} layers -> {B} blocks {asg}")
|
| 24 |
-
print(f"[dblock]
|
| 25 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
def _dblock_step(core, ar_h, sat_h, nat_h, opt, scaler, args, ids, state):
|
| 28 |
import nB300_agillm4 as M
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
with M.amp(args.amp):
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
scaler.unscale_(opt)
|
| 65 |
-
nn.utils.clip_grad_norm_([p for g in opt.param_groups for p in g["params"]],1.0)
|
| 66 |
-
scaler.step(opt)
|
| 67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
CE. Reuses the live data stream / optimizer / checkpointing of nB300_agillm4.
|
| 6 |
Lazy-imports nB300 inside functions to avoid a circular import.
|
| 7 |
"""
|
| 8 |
+
import math
|
| 9 |
+
import random
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
import torch.utils.checkpoint as _ck
|
| 15 |
from fused_ce import fused_ce
|
| 16 |
+
|
| 17 |
+
SD = 0.5
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _cdf(x):
|
| 21 |
+
return 0.5 * (1 + math.erf(x / math.sqrt(2)))
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _ppf(p):
|
| 25 |
+
return float(torch.erfinv(torch.tensor(2 * p - 1.0)) * math.sqrt(2))
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _block_sigmas(B, smin=0.002, smax=80.0, pm=-1.2, ps=1.2):
|
| 29 |
+
a, b = _cdf((math.log(smin) - pm) / ps), _cdf((math.log(smax) - pm) / ps)
|
| 30 |
+
return [float(np.exp(pm + ps * _ppf(a + (b - a) * (i / B)))) for i in range(B + 1)]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _edm_pre(s):
|
| 34 |
+
s = s[:, None, None]
|
| 35 |
+
return SD**2 / (s**2 + SD**2), s * SD / (s**2 + SD**2) ** 0.5, 1 / (s**2 + SD**2) ** 0.5
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _edm_w(s, wmax=5.0):
|
| 39 |
+
return float(((s**2 + SD**2) / (s * SD) ** 2).clamp(max=wmax).mean())
|
| 40 |
+
|
| 41 |
|
| 42 |
def _dblock_init(core, args):
|
| 43 |
+
B = int(getattr(args, "dblock_blocks", 4))
|
| 44 |
+
L = len(core.blocks)
|
| 45 |
+
sp = max(1, L // B)
|
| 46 |
+
asg = [list(range(i * sp, (i + 1) * sp)) for i in range(B)]
|
| 47 |
+
asg[-1] = list(range((B - 1) * sp, L))
|
| 48 |
+
bsig = _block_sigmas(B)
|
| 49 |
+
schedule = getattr(args, "dblock_schedule", "loss_balanced")
|
| 50 |
print(f"[dblock] DiffusionBlocks mode: {L} layers -> {B} blocks {asg}")
|
| 51 |
+
print(f"[dblock] schedule={schedule} sigma boundaries: {[round(x, 3) for x in bsig]}")
|
| 52 |
+
return {
|
| 53 |
+
"B": B,
|
| 54 |
+
"assign": asg,
|
| 55 |
+
"bsig": bsig,
|
| 56 |
+
"step": 0,
|
| 57 |
+
"counts": [0 for _ in range(B)],
|
| 58 |
+
"loss_ema": [None for _ in range(B)],
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _choose_block(state, args):
|
| 63 |
+
B = state["B"]
|
| 64 |
+
schedule = str(getattr(args, "dblock_schedule", "loss_balanced") or "loss_balanced").lower()
|
| 65 |
+
step = int(state.get("step", 0))
|
| 66 |
+
counts = state.setdefault("counts", [0 for _ in range(B)])
|
| 67 |
+
emas = state.setdefault("loss_ema", [None for _ in range(B)])
|
| 68 |
+
if schedule == "random":
|
| 69 |
+
return random.randrange(B)
|
| 70 |
+
if schedule == "roundrobin":
|
| 71 |
+
return step % B
|
| 72 |
+
explore = float(getattr(args, "dblock_explore", 0.05))
|
| 73 |
+
warmup = int(getattr(args, "dblock_warmup_steps", max(8, B * 2)))
|
| 74 |
+
if step < warmup or any(c == 0 for c in counts):
|
| 75 |
+
return min(range(B), key=lambda i: (counts[i], i))
|
| 76 |
+
if explore > 0.0 and random.random() < explore:
|
| 77 |
+
return min(range(B), key=lambda i: (counts[i], i))
|
| 78 |
+
return max(range(B), key=lambda i: (-1.0 if emas[i] is None else emas[i], -counts[i]))
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def _sample_sigma(ids, lo, hi, args, state):
|
| 82 |
+
cur_step = int(state.get("step", 0))
|
| 83 |
+
curriculum = int(getattr(args, "dblock_sigma_curriculum_steps", 0))
|
| 84 |
+
if curriculum > 0:
|
| 85 |
+
frac = min(1.0, max(0.05, (cur_step + 1) / float(curriculum)))
|
| 86 |
+
hi = lo * ((hi / max(lo, 1e-8)) ** frac)
|
| 87 |
+
sig_np = np.exp(
|
| 88 |
+
np.random.uniform(
|
| 89 |
+
math.log(max(lo, 1e-4)),
|
| 90 |
+
math.log(max(hi, lo + 1e-4)),
|
| 91 |
+
ids.size(0),
|
| 92 |
+
).astype("float32")
|
| 93 |
+
)
|
| 94 |
+
return torch.from_numpy(sig_np).to(ids.device)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def _maybe_log(state, args, bi, layers, ar_val, sat_val, nat_val, total_val, peak_alloc, peak_reserved):
|
| 98 |
+
log_every = int(getattr(args, "dblock_log_every", 50))
|
| 99 |
+
step = int(state.get("step", 0))
|
| 100 |
+
if log_every <= 0 or step % log_every != 0:
|
| 101 |
+
return
|
| 102 |
+
counts = ",".join(str(x) for x in state.get("counts", []))
|
| 103 |
+
emas = ",".join("nan" if x is None else f"{x:.2f}" for x in state.get("loss_ema", []))
|
| 104 |
+
mem = ""
|
| 105 |
+
if peak_alloc is not None:
|
| 106 |
+
mem = f" peak_alloc={peak_alloc:.2f}GB peak_reserved={peak_reserved:.2f}GB"
|
| 107 |
+
print(
|
| 108 |
+
f"[dblock] step={step} block={bi} layers={layers} "
|
| 109 |
+
f"loss={total_val:.3f} ar={ar_val:.3f} sat={sat_val:.3f} nat={nat_val:.3f} "
|
| 110 |
+
f"counts=[{counts}] ema=[{emas}]{mem}",
|
| 111 |
+
flush=True,
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def _update_stats(state, bi, loss_value):
|
| 116 |
+
B = state["B"]
|
| 117 |
+
counts = state.setdefault("counts", [0 for _ in range(B)])
|
| 118 |
+
emas = state.setdefault("loss_ema", [None for _ in range(B)])
|
| 119 |
+
counts[bi] += 1
|
| 120 |
+
prev = emas[bi]
|
| 121 |
+
beta = 0.96
|
| 122 |
+
emas[bi] = float(loss_value) if prev is None else beta * float(prev) + (1.0 - beta) * float(loss_value)
|
| 123 |
+
state["step"] = int(state.get("step", 0)) + 1
|
| 124 |
+
|
| 125 |
|
| 126 |
def _dblock_step(core, ar_h, sat_h, nat_h, opt, scaler, args, ids, state):
|
| 127 |
import nB300_agillm4 as M
|
| 128 |
+
|
| 129 |
+
if torch.cuda.is_available():
|
| 130 |
+
torch.cuda.reset_peak_memory_stats()
|
| 131 |
+
|
| 132 |
+
B = state["B"]
|
| 133 |
+
asg = state["assign"]
|
| 134 |
+
bs = state["bsig"]
|
| 135 |
+
T = ids.size(1)
|
| 136 |
+
bi = _choose_block(state, args)
|
| 137 |
+
lo, hi = sorted([bs[bi], bs[bi + 1]])
|
| 138 |
+
layers = asg[bi]
|
| 139 |
+
sig = _sample_sigma(ids, lo, hi, args, state)
|
| 140 |
+
cs, co, ci = _edm_pre(sig)
|
| 141 |
+
w = _edm_w(sig, float(getattr(args, "dblock_edm_wmax", 5.0)))
|
| 142 |
+
SATB = M.SAT_BLOCK
|
| 143 |
+
ar_weight = float(getattr(args, "dblock_ar_weight", 1.0))
|
| 144 |
+
sat_weight = float(getattr(args, "dblock_sat_weight", 1.0))
|
| 145 |
+
nat_weight = float(getattr(args, "dblock_nat_weight", 1.0)) * float(getattr(args, "nat_loss_weight", 1.0))
|
| 146 |
+
|
| 147 |
+
ar_val = 0.0
|
| 148 |
+
sat_val = 0.0
|
| 149 |
+
nat_val = 0.0
|
| 150 |
+
|
| 151 |
+
if ar_weight > 0.0:
|
| 152 |
+
causal = M.causal_mask(T)
|
| 153 |
+
with M.amp(args.amp):
|
| 154 |
+
emb = core.emb(ids)
|
| 155 |
+
zt = emb + sig[:, None, None] * torch.randn_like(emb)
|
| 156 |
+
h = ci * zt
|
| 157 |
+
for li in layers:
|
| 158 |
+
h = _ck.checkpoint(core.blocks[li], h, causal, use_reentrant=False)
|
| 159 |
+
Dn = core.ln(cs * zt + co * h)
|
| 160 |
+
ar = ar_weight * w * fused_ce(Dn[:, :-1].contiguous(), ar_h.proj.weight, ids[:, 1:].contiguous())
|
| 161 |
+
ar_val = float(ar.detach())
|
| 162 |
+
scaler.scale(ar).backward()
|
| 163 |
+
del causal, emb, zt, h, Dn, ar
|
| 164 |
+
|
| 165 |
+
do_sat = (not getattr(args, "ar_only", False)) and (
|
| 166 |
+
int(getattr(args, "sat_every", 1)) <= 1 or ((int(state.get("step", 0)) + 1) % int(getattr(args, "sat_every", 1)) == 0)
|
| 167 |
+
)
|
| 168 |
+
if sat_weight > 0.0 and do_sat:
|
| 169 |
+
smask = M.sat_mask(T)
|
| 170 |
+
with M.amp(args.amp):
|
| 171 |
+
emb2 = core.emb(ids)
|
| 172 |
+
zt2 = emb2 + sig[:, None, None] * torch.randn_like(emb2)
|
| 173 |
+
h2 = ci * zt2
|
| 174 |
+
for li in layers:
|
| 175 |
+
h2 = _ck.checkpoint(core.blocks[li], h2, smask, use_reentrant=False)
|
| 176 |
+
Ds = core.ln(cs * zt2 + co * h2)
|
| 177 |
+
last = Ds[:, -SATB:]
|
| 178 |
+
satf = fused_ce(last.contiguous(), sat_h.proj.weight, ids[:, 1 : SATB + 1].contiguous())
|
| 179 |
+
satv = (
|
| 180 |
+
M.EMIT_LAMBDA
|
| 181 |
+
* F.cross_entropy(
|
| 182 |
+
sat_h.gate(Ds[:, 0].float()),
|
| 183 |
+
torch.ones(ids.size(0), dtype=torch.long, device=ids.device),
|
| 184 |
+
)
|
| 185 |
+
if sat_h.gate is not None
|
| 186 |
+
else 0.0
|
| 187 |
+
)
|
| 188 |
+
sat = sat_weight * w * (satf + satv)
|
| 189 |
+
sat_val = float(sat.detach())
|
| 190 |
+
scaler.scale(sat).backward()
|
| 191 |
+
del smask, emb2, zt2, h2, Ds, last, satf, satv, sat
|
| 192 |
+
|
| 193 |
+
do_nat = (
|
| 194 |
+
nat_h is not None
|
| 195 |
+
and nat_weight > 0.0
|
| 196 |
+
and (not getattr(args, "ar_only", False))
|
| 197 |
+
and int(getattr(args, "nat_every", 1)) > 0
|
| 198 |
+
and (
|
| 199 |
+
int(getattr(args, "nat_every", 1)) <= 1
|
| 200 |
+
or ((int(state.get("step", 0)) + 1) % int(getattr(args, "nat_every", 1)) == 0)
|
| 201 |
+
)
|
| 202 |
+
)
|
| 203 |
+
if do_nat:
|
| 204 |
+
ratio = min(max(float(getattr(args, "nat_mask_ratio", 0.5)), 0.05), 0.95)
|
| 205 |
+
nat_ids = M._nat_ids_for_training(ids, int(getattr(args, "nat_max_tokens", 0)))
|
| 206 |
with M.amp(args.amp):
|
| 207 |
+
nat_in = nat_ids.clone()
|
| 208 |
+
m = torch.rand(nat_ids.shape, device=nat_ids.device) < ratio
|
| 209 |
+
if not bool(m.any()):
|
| 210 |
+
m[..., -1] = True
|
| 211 |
+
nat_in[m] = M.BLANK
|
| 212 |
+
hn = core.emb(nat_in)
|
| 213 |
+
for li in layers:
|
| 214 |
+
hn = _ck.checkpoint(core.blocks[li], hn, None, use_reentrant=False)
|
| 215 |
+
Dnat = core.ln(hn)
|
| 216 |
+
nat = nat_weight * fused_ce(Dnat[m], nat_h.proj.weight, nat_ids[m])
|
| 217 |
+
nat_val = float(nat.detach())
|
| 218 |
+
scaler.scale(nat).backward()
|
| 219 |
+
del nat_ids, nat_in, m, hn, Dnat, nat
|
| 220 |
+
|
| 221 |
+
total_val = ar_val + sat_val + nat_val
|
| 222 |
+
if not math.isfinite(total_val):
|
| 223 |
+
opt.zero_grad(set_to_none=True)
|
| 224 |
+
if torch.cuda.is_available():
|
| 225 |
+
torch.cuda.empty_cache()
|
| 226 |
+
print(f"[dblock] non-finite loss {total_val}; skipped optimizer step", flush=True)
|
| 227 |
+
_update_stats(state, bi, total_val)
|
| 228 |
+
return total_val
|
| 229 |
+
|
| 230 |
scaler.unscale_(opt)
|
| 231 |
+
nn.utils.clip_grad_norm_([p for g in opt.param_groups for p in g["params"]], 1.0)
|
| 232 |
+
scaler.step(opt)
|
| 233 |
+
scaler.update()
|
| 234 |
+
opt.zero_grad(set_to_none=True)
|
| 235 |
+
|
| 236 |
+
peak_alloc = None
|
| 237 |
+
peak_reserved = None
|
| 238 |
+
if torch.cuda.is_available():
|
| 239 |
+
peak_alloc = torch.cuda.max_memory_allocated() / (1024**3)
|
| 240 |
+
peak_reserved = torch.cuda.max_memory_reserved() / (1024**3)
|
| 241 |
+
_update_stats(state, bi, total_val)
|
| 242 |
+
_maybe_log(state, args, bi, layers, ar_val, sat_val, nat_val, total_val, peak_alloc, peak_reserved)
|
| 243 |
+
return total_val
|
fused_ce.py
CHANGED
|
@@ -4,28 +4,54 @@ recomputes softmax per vocab-chunk (grad = softmax - onehot). This is the
|
|
| 4 |
DiffusionBlocks 'process in chunks, don't hold the whole thing' idea applied to
|
| 5 |
the output head instead of network depth."""
|
| 6 |
import torch
|
|
|
|
| 7 |
class FusedCE(torch.autograd.Function):
|
| 8 |
@staticmethod
|
| 9 |
def forward(ctx, h, W, tgt, vchunk=16384):
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
@staticmethod
|
| 21 |
def backward(ctx, go):
|
| 22 |
-
h,W,tgt,lse
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
def fused_ce(h, W, tgt, vchunk=16384):
|
| 31 |
-
return FusedCE.apply(h.reshape(-1,h.size(-1)), W, tgt.reshape(-1), vchunk)
|
|
|
|
| 4 |
DiffusionBlocks 'process in chunks, don't hold the whole thing' idea applied to
|
| 5 |
the output head instead of network depth."""
|
| 6 |
import torch
|
| 7 |
+
|
| 8 |
class FusedCE(torch.autograd.Function):
|
| 9 |
@staticmethod
|
| 10 |
def forward(ctx, h, W, tgt, vchunk=16384):
|
| 11 |
+
with torch.cuda.amp.autocast(enabled=False):
|
| 12 |
+
hf = h.float()
|
| 13 |
+
Wf = W.float()
|
| 14 |
+
N, d = h.shape
|
| 15 |
+
V = W.shape[0]
|
| 16 |
+
m = torch.full((N,), -1e30, device=h.device, dtype=torch.float32)
|
| 17 |
+
s = torch.zeros(N, device=h.device, dtype=torch.float32)
|
| 18 |
+
zt = torch.zeros(N, device=h.device, dtype=torch.float32)
|
| 19 |
+
for c in range(0, V, vchunk):
|
| 20 |
+
lg = hf @ Wf[c:c+vchunk].T # [N,vchunk] transient only
|
| 21 |
+
cm = lg.max(1).values
|
| 22 |
+
nm = torch.maximum(m, cm)
|
| 23 |
+
s = s * torch.exp(m - nm) + torch.exp(lg - nm[:, None]).sum(1)
|
| 24 |
+
m = nm
|
| 25 |
+
ic = (tgt >= c) & (tgt < c+vchunk)
|
| 26 |
+
if ic.any():
|
| 27 |
+
zt[ic] = lg[ic, tgt[ic] - c].float()
|
| 28 |
+
lse = m + torch.log(s)
|
| 29 |
+
ctx.save_for_backward(h, W, tgt, lse)
|
| 30 |
+
ctx.vchunk = vchunk
|
| 31 |
+
return (lse - zt).mean()
|
| 32 |
+
|
| 33 |
@staticmethod
|
| 34 |
def backward(ctx, go):
|
| 35 |
+
h, W, tgt, lse = ctx.saved_tensors
|
| 36 |
+
vc = ctx.vchunk
|
| 37 |
+
N, d = h.shape
|
| 38 |
+
V = W.shape[0]
|
| 39 |
+
with torch.cuda.amp.autocast(enabled=False):
|
| 40 |
+
hf = h.float()
|
| 41 |
+
Wc_all = W.float()
|
| 42 |
+
gh = torch.zeros_like(hf)
|
| 43 |
+
gW = torch.zeros(W.shape, device=W.device, dtype=torch.float32)
|
| 44 |
+
sc = float(go) / N
|
| 45 |
+
for c in range(0, V, vc):
|
| 46 |
+
Wc = Wc_all[c:c+vc]
|
| 47 |
+
p = torch.exp(hf @ Wc.T - lse[:, None]) # softmax chunk [N,vchunk]
|
| 48 |
+
ic = (tgt >= c) & (tgt < c+vc)
|
| 49 |
+
if ic.any():
|
| 50 |
+
p[ic, tgt[ic] - c] -= 1.0
|
| 51 |
+
p *= sc
|
| 52 |
+
gh += p @ Wc
|
| 53 |
+
gW[c:c+vc] += p.T @ hf
|
| 54 |
+
return gh.to(h.dtype), gW.to(W.dtype), None, None
|
| 55 |
+
|
| 56 |
def fused_ce(h, W, tgt, vchunk=16384):
|
| 57 |
+
return FusedCE.apply(h.reshape(-1, h.size(-1)), W, tgt.reshape(-1), vchunk)
|
nB300_agillm4_vram_dblock.py
CHANGED
|
@@ -2806,6 +2806,21 @@ def main():
|
|
| 2806 |
help="Fraction of positions masked to BLANK for the NAT mask-predict (CMLM) objective.")
|
| 2807 |
tr.add_argument("--dblock", action="store_true", help="DiffusionBlocks block-wise denoising training (low VRAM).")
|
| 2808 |
tr.add_argument("--dblock_blocks", type=int, default=4, help="Partition layers into this many DiffusionBlocks blocks.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2809 |
tr.add_argument("--reinit_nat", action="store_true",
|
| 2810 |
help="Reinitialize NAT head weights after load (use once when switching to mask-predict).")
|
| 2811 |
tr.add_argument("--seed_nat_from_ar", action="store_true",
|
|
|
|
| 2806 |
help="Fraction of positions masked to BLANK for the NAT mask-predict (CMLM) objective.")
|
| 2807 |
tr.add_argument("--dblock", action="store_true", help="DiffusionBlocks block-wise denoising training (low VRAM).")
|
| 2808 |
tr.add_argument("--dblock_blocks", type=int, default=4, help="Partition layers into this many DiffusionBlocks blocks.")
|
| 2809 |
+
tr.add_argument("--dblock_schedule", choices=["random", "roundrobin", "loss_balanced"], default="loss_balanced",
|
| 2810 |
+
help="How --dblock chooses the next layer block. loss_balanced focuses blocks whose EMA loss is highest after warmup.")
|
| 2811 |
+
tr.add_argument("--dblock_warmup_steps", type=int, default=16,
|
| 2812 |
+
help="Initial DBlock steps spent covering every block before loss-balanced scheduling.")
|
| 2813 |
+
tr.add_argument("--dblock_explore", type=float, default=0.05,
|
| 2814 |
+
help="Exploration rate for loss-balanced DBlock scheduling.")
|
| 2815 |
+
tr.add_argument("--dblock_log_every", type=int, default=25,
|
| 2816 |
+
help="Print DBlock block/loss/VRAM diagnostics every N DBlock steps; 0 disables.")
|
| 2817 |
+
tr.add_argument("--dblock_sigma_curriculum_steps", type=int, default=2000,
|
| 2818 |
+
help="Warm sigma ranges from easy to full span over this many DBlock steps; 0 disables.")
|
| 2819 |
+
tr.add_argument("--dblock_edm_wmax", type=float, default=5.0,
|
| 2820 |
+
help="Cap for EDM loss weighting in DBlock mode.")
|
| 2821 |
+
tr.add_argument("--dblock_ar_weight", type=float, default=1.0)
|
| 2822 |
+
tr.add_argument("--dblock_sat_weight", type=float, default=1.0)
|
| 2823 |
+
tr.add_argument("--dblock_nat_weight", type=float, default=1.0)
|
| 2824 |
tr.add_argument("--reinit_nat", action="store_true",
|
| 2825 |
help="Reinitialize NAT head weights after load (use once when switching to mask-predict).")
|
| 2826 |
tr.add_argument("--seed_nat_from_ar", action="store_true",
|
relaunch_agillm4_dblock.sh
CHANGED
|
@@ -13,7 +13,9 @@ CKPT="$(ls -1t "$SAVE_DIR"/pretrain_step*.pt 2>/dev/null | head -1)"
|
|
| 13 |
exec >> /workspace/agillm4_floor_train.log 2>&1
|
| 14 |
echo "RELAUNCH_AGILLM4_DBLOCK $(date -u +%Y-%m-%dT%H:%M:%SZ) resume=$CKPT --dblock blocks=${AGILLM4_DBLOCKS:-4} tie_weights=1 attn=${AGILLM_ATTN_BACKEND}"
|
| 15 |
exec python -u nB300_agillm4.py train --preset agillm4_floor --resume "$CKPT" \
|
| 16 |
-
--dblock --dblock_blocks "${AGILLM4_DBLOCKS:-4}" --
|
|
|
|
|
|
|
| 17 |
--batch_size 1 --block "${AGILLM4_BLOCK:-1280}" --amp --attn_backend "${AGILLM_ATTN_BACKEND}" --grad_checkpoint \
|
| 18 |
--optimizer paged_adamw8bit --sat_every 1 --nat_every 1 --nat_max_tokens 768 --nat_mask_ratio 0.5 \
|
| 19 |
--token_param_ratio 100 --save_dir "$SAVE_DIR" \
|
|
|
|
| 13 |
exec >> /workspace/agillm4_floor_train.log 2>&1
|
| 14 |
echo "RELAUNCH_AGILLM4_DBLOCK $(date -u +%Y-%m-%dT%H:%M:%SZ) resume=$CKPT --dblock blocks=${AGILLM4_DBLOCKS:-4} tie_weights=1 attn=${AGILLM_ATTN_BACKEND}"
|
| 15 |
exec python -u nB300_agillm4.py train --preset agillm4_floor --resume "$CKPT" \
|
| 16 |
+
--dblock --dblock_blocks "${AGILLM4_DBLOCKS:-4}" --dblock_schedule "${AGILLM4_DBLOCK_SCHEDULE:-loss_balanced}" \
|
| 17 |
+
--dblock_warmup_steps "${AGILLM4_DBLOCK_WARMUP:-16}" --dblock_sigma_curriculum_steps "${AGILLM4_DBLOCK_SIGMA_CURRICULUM:-2000}" \
|
| 18 |
+
--dblock_log_every "${AGILLM4_DBLOCK_LOG_EVERY:-25}" --tie_weights \
|
| 19 |
--batch_size 1 --block "${AGILLM4_BLOCK:-1280}" --amp --attn_backend "${AGILLM_ATTN_BACKEND}" --grad_checkpoint \
|
| 20 |
--optimizer paged_adamw8bit --sat_every 1 --nat_every 1 --nat_max_tokens 768 --nat_mask_ratio 0.5 \
|
| 21 |
--token_param_ratio 100 --save_dir "$SAVE_DIR" \
|