Scott/Codex committed on
Commit ·
690cf55
1
Parent(s): 468a571
Add DBlock profiler and speed-tuned batch config
Browse files- README.md +4 -0
- dblocks_train.py +108 -9
- nB300_agillm4_vram_dblock.py +26 -0
- relaunch_agillm4_dblock_sg2.sh +4 -3
README.md
CHANGED
|
@@ -85,4 +85,8 @@ DiffusionBlocks, gradient-checkpointed blocks, tied heads, and structured masks.
|
|
| 85 |
|
| 86 |
Sublinear coverage update 2026-05-29: the saved AGILLM-4 trainer snapshot now matches the live v2 sparse global memory path. It fixes gathered ALiBi distance, suppresses duplicate local/anchor candidates before softmax, uses hybrid full-span + recent-tail anchors with explicit `--sublinear_sinks` and `--sublinear_recent_anchors`, and includes optional pooled K/V landmark summaries behind `--sublinear_pooled_landmarks`. At the live 128/128/128 profile it keeps deep-past coverage while preserving recent anchors and the same VRAM-first key budget. See `sublinear_improved_snippet.py` for the minimal blocks and `sublinear_improved.py` for the coverage demo/standalone selector.
|
| 87 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
License: Apache-2.0 (matching the upstream method).
|
|
|
|
| 85 |
|
| 86 |
Sublinear coverage update 2026-05-29: the saved AGILLM-4 trainer snapshot now matches the live v2 sparse global memory path. It fixes gathered ALiBi distance, suppresses duplicate local/anchor candidates before softmax, uses hybrid full-span + recent-tail anchors with explicit `--sublinear_sinks` and `--sublinear_recent_anchors`, and includes optional pooled K/V landmark summaries behind `--sublinear_pooled_landmarks`. At the live 128/128/128 profile it keeps deep-past coverage while preserving recent anchors and the same VRAM-first key budget. See `sublinear_improved_snippet.py` for the minimal blocks and `sublinear_improved.py` for the coverage demo/standalone selector.
|
| 87 |
|
| 88 |
+
|
| 89 |
+
Profiling/speed update 2026-05-29: added in-process DBlock profiling (`--profile_steps`, `--profile_log_every`) after external ptrace profiling was blocked on Vast. The profile showed that the bottleneck is transformer recompute/backward, not fused CE or the optimizer: at B=2 with full checkpointing, AR backward averaged ~605 ms/step, AR forward ~184 ms, CE ~4.5 ms, and the optimizer ~17 ms. Speed levers tested live: disabling checkpointing OOMed at B=2 and fell back to B=1; selective checkpointing with stride=2 fit but hugged the VRAM limit and reached ~2.94k tok/s; B=5/6 hit a memory-pressure cliff; and B=4 with full DBlock checkpointing was the best stable official setting (~3.0k tok/s warm window, ~13.2 GB tensor peak / ~17.6 GB reserved, ETA ~269-275 days). The live relaunch now uses `--batch_size 4 --grad_checkpoint --dblock_checkpoint_stride 1` and leaves selective checkpointing available for future context/batch tradeoffs.
|
| 90 |
+
|
| 91 |
+
|
| 92 |
License: Apache-2.0 (matching the upstream method).
|
dblocks_train.py
CHANGED
|
@@ -7,6 +7,8 @@ Lazy-imports nB300 inside functions to avoid a circular import.
|
|
| 7 |
"""
|
| 8 |
import math
|
| 9 |
import random
|
|
|
|
|
|
|
| 10 |
import numpy as np
|
| 11 |
import torch
|
| 12 |
import torch.nn as nn
|
|
@@ -17,6 +19,63 @@ from fused_ce import fused_ce
|
|
| 17 |
SD = 0.5
|
| 18 |
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
def _cdf(x):
|
| 21 |
return 0.5 * (1 + math.erf(x / math.sqrt(2)))
|
| 22 |
|
|
@@ -129,6 +188,17 @@ def _run_block(block, x, mask, use_checkpoint):
|
|
| 129 |
return block(x, mask)
|
| 130 |
|
| 131 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
def _sample_token_loss_inputs(hidden, targets, max_tokens):
|
| 133 |
max_tokens = int(max_tokens or 0)
|
| 134 |
if max_tokens <= 0:
|
|
@@ -173,9 +243,12 @@ def _choose_objectives(state, args, ar_weight, sat_weight, nat_weight, do_sat_pe
|
|
| 173 |
def _dblock_step(core, ar_h, sat_h, nat_h, opt, scaler, args, ids, state):
|
| 174 |
import nB300_agillm4 as M
|
| 175 |
|
|
|
|
|
|
|
| 176 |
if torch.cuda.is_available():
|
| 177 |
torch.cuda.reset_peak_memory_stats()
|
| 178 |
|
|
|
|
| 179 |
B = state["B"]
|
| 180 |
asg = state["assign"]
|
| 181 |
bs = state["bsig"]
|
|
@@ -206,6 +279,7 @@ def _dblock_step(core, ar_h, sat_h, nat_h, opt, scaler, args, ids, state):
|
|
| 206 |
run_ar, run_sat, run_nat, objective = _choose_objectives(
|
| 207 |
state, args, ar_weight, sat_weight, nat_weight, do_sat_periodic, do_nat_periodic
|
| 208 |
)
|
|
|
|
| 209 |
|
| 210 |
ar_val = 0.0
|
| 211 |
sat_val = 0.0
|
|
@@ -213,34 +287,44 @@ def _dblock_step(core, ar_h, sat_h, nat_h, opt, scaler, args, ids, state):
|
|
| 213 |
|
| 214 |
if run_ar:
|
| 215 |
causal = M.causal_mask(T, structured=M.use_structured_masks(args))
|
|
|
|
| 216 |
with M.amp(args.amp):
|
| 217 |
emb = core.emb(ids)
|
| 218 |
zt = emb + sig[:, None, None] * torch.randn_like(emb)
|
| 219 |
h = ci * zt
|
| 220 |
-
for li in layers:
|
| 221 |
-
h = _run_block(core.blocks[li], h, causal, use_layer_checkpoint)
|
| 222 |
Dn = core.ln(cs * zt + co * h)
|
|
|
|
|
|
|
| 223 |
ar_hidden, ar_targets, ar_used, ar_total = _sample_token_loss_inputs(
|
| 224 |
Dn[:, :-1], ids[:, 1:], int(getattr(args, "dblock_ar_loss_tokens", 0))
|
| 225 |
)
|
| 226 |
ar = ar_weight * w * fused_ce(ar_hidden, ar_h.proj.weight, ar_targets)
|
| 227 |
ar_val = float(ar.detach())
|
|
|
|
|
|
|
| 228 |
scaler.scale(ar).backward()
|
|
|
|
| 229 |
del causal, emb, zt, h, Dn, ar_hidden, ar_targets, ar, ar_used, ar_total
|
| 230 |
|
| 231 |
if run_sat:
|
| 232 |
smask = M.sat_mask(T, structured=M.use_structured_masks(args))
|
|
|
|
| 233 |
with M.amp(args.amp):
|
| 234 |
emb2 = core.emb(ids)
|
| 235 |
zt2 = emb2 + sig[:, None, None] * torch.randn_like(emb2)
|
| 236 |
h2 = ci * zt2
|
| 237 |
-
for li in layers:
|
| 238 |
-
h2 = _run_block(core.blocks[li], h2, smask, use_layer_checkpoint)
|
| 239 |
Ds = core.ln(cs * zt2 + co * h2)
|
| 240 |
last = Ds[:, -SATB:]
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
|
|
|
|
|
|
|
|
|
| 244 |
satf = fused_ce(sat_hidden, sat_h.proj.weight, sat_targets)
|
| 245 |
satv = (
|
| 246 |
M.EMIT_LAMBDA
|
|
@@ -252,13 +336,17 @@ def _dblock_step(core, ar_h, sat_h, nat_h, opt, scaler, args, ids, state):
|
|
| 252 |
else 0.0
|
| 253 |
)
|
| 254 |
sat = sat_weight * w * (satf + satv)
|
|
|
|
| 255 |
sat_val = float(sat.detach())
|
|
|
|
| 256 |
scaler.scale(sat).backward()
|
|
|
|
| 257 |
del smask, emb2, zt2, h2, Ds, last, sat_hidden, sat_targets, satf, satv, sat
|
| 258 |
|
| 259 |
if run_nat:
|
| 260 |
ratio = min(max(float(getattr(args, "nat_mask_ratio", 0.5)), 0.05), 0.95)
|
| 261 |
nat_ids = M._nat_ids_for_training(ids, int(getattr(args, "nat_max_tokens", 0)))
|
|
|
|
| 262 |
with M.amp(args.amp):
|
| 263 |
nat_in = nat_ids.clone()
|
| 264 |
m = torch.rand(nat_ids.shape, device=nat_ids.device) < ratio
|
|
@@ -266,9 +354,11 @@ def _dblock_step(core, ar_h, sat_h, nat_h, opt, scaler, args, ids, state):
|
|
| 266 |
m[..., -1] = True
|
| 267 |
nat_in[m] = M.BLANK
|
| 268 |
hn = core.emb(nat_in)
|
| 269 |
-
for li in layers:
|
| 270 |
-
hn = _run_block(core.blocks[li], hn, None, use_layer_checkpoint)
|
| 271 |
Dnat = core.ln(hn)
|
|
|
|
|
|
|
| 272 |
nat_hidden = Dnat[m]
|
| 273 |
nat_targets = nat_ids[m]
|
| 274 |
nat_hidden, nat_targets, nat_used, nat_total = _sample_token_loss_inputs(
|
|
@@ -276,7 +366,10 @@ def _dblock_step(core, ar_h, sat_h, nat_h, opt, scaler, args, ids, state):
|
|
| 276 |
)
|
| 277 |
nat = nat_weight * fused_ce(nat_hidden, nat_h.proj.weight, nat_targets)
|
| 278 |
nat_val = float(nat.detach())
|
|
|
|
|
|
|
| 279 |
scaler.scale(nat).backward()
|
|
|
|
| 280 |
del nat_ids, nat_in, m, hn, Dnat, nat_hidden, nat_targets, nat, nat_used, nat_total
|
| 281 |
|
| 282 |
total_val = ar_val + sat_val + nat_val
|
|
@@ -285,20 +378,26 @@ def _dblock_step(core, ar_h, sat_h, nat_h, opt, scaler, args, ids, state):
|
|
| 285 |
if torch.cuda.is_available():
|
| 286 |
torch.cuda.empty_cache()
|
| 287 |
print(f"[dblock] non-finite loss {total_val}; skipped optimizer step", flush=True)
|
|
|
|
|
|
|
| 288 |
_update_stats(state, bi, total_val)
|
| 289 |
return total_val
|
| 290 |
|
|
|
|
| 291 |
scaler.unscale_(opt)
|
| 292 |
nn.utils.clip_grad_norm_([p for g in opt.param_groups for p in g["params"]], 1.0)
|
| 293 |
scaler.step(opt)
|
| 294 |
scaler.update()
|
| 295 |
opt.zero_grad(set_to_none=True)
|
|
|
|
| 296 |
|
| 297 |
peak_alloc = None
|
| 298 |
peak_reserved = None
|
| 299 |
if torch.cuda.is_available():
|
| 300 |
peak_alloc = torch.cuda.max_memory_allocated() / (1024**3)
|
| 301 |
peak_reserved = torch.cuda.max_memory_reserved() / (1024**3)
|
|
|
|
|
|
|
| 302 |
_update_stats(state, bi, total_val)
|
| 303 |
_maybe_log(state, args, bi, layers, ar_val, sat_val, nat_val, total_val, peak_alloc, peak_reserved, objective=objective)
|
| 304 |
return total_val
|
|
|
|
| 7 |
"""
|
| 8 |
import math
|
| 9 |
import random
|
| 10 |
+
import time
|
| 11 |
+
from collections import defaultdict
|
| 12 |
import numpy as np
|
| 13 |
import torch
|
| 14 |
import torch.nn as nn
|
|
|
|
| 19 |
SD = 0.5
|
| 20 |
|
| 21 |
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _profile_active(state, args):
|
| 25 |
+
limit = int(getattr(args, "profile_steps", 0) or 0)
|
| 26 |
+
return limit > 0 and int(state.get("profile_n", 0)) < limit
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _profile_add(state, name, seconds):
|
| 30 |
+
if seconds is None:
|
| 31 |
+
return
|
| 32 |
+
prof = state.setdefault("profile_times", defaultdict(float))
|
| 33 |
+
prof[name] += float(seconds)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _profile_tic(enabled):
|
| 37 |
+
if not enabled:
|
| 38 |
+
return None
|
| 39 |
+
if torch.cuda.is_available():
|
| 40 |
+
torch.cuda.synchronize()
|
| 41 |
+
return time.perf_counter()
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _profile_toc(state, name, start):
    """Stop a timer started by ``_profile_tic`` and record the elapsed time.

    Does nothing when ``start`` is ``None`` (profiling was disabled at tic
    time). Otherwise synchronizes CUDA when it is available, then adds the
    elapsed ``time.perf_counter()`` seconds to ``name`` via ``_profile_add``.
    """
    if start is not None:
        if torch.cuda.is_available():
            # Synchronize before reading the clock, mirroring _profile_tic.
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - start
        _profile_add(state, name, elapsed)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _profile_step_done(state, args):
|
| 53 |
+
limit = int(getattr(args, "profile_steps", 0) or 0)
|
| 54 |
+
if limit <= 0:
|
| 55 |
+
return
|
| 56 |
+
n_prev = int(state.get("profile_n", 0))
|
| 57 |
+
if n_prev >= limit:
|
| 58 |
+
return
|
| 59 |
+
state["profile_n"] = n_prev + 1
|
| 60 |
+
n = int(state["profile_n"])
|
| 61 |
+
log_every = max(1, int(getattr(args, "profile_log_every", 25) or 25))
|
| 62 |
+
if n % log_every != 0 and n != limit:
|
| 63 |
+
return
|
| 64 |
+
times = state.get("profile_times", {})
|
| 65 |
+
keys = [
|
| 66 |
+
"data_stream", "tensor", "setup",
|
| 67 |
+
"ar_forward", "ar_ce", "ar_backward",
|
| 68 |
+
"sat_forward", "sat_ce", "sat_backward",
|
| 69 |
+
"nat_forward", "nat_ce", "nat_backward",
|
| 70 |
+
"opt_step", "step_total",
|
| 71 |
+
]
|
| 72 |
+
parts = []
|
| 73 |
+
for key in keys:
|
| 74 |
+
val = float(times.get(key, 0.0)) * 1000.0 / max(1, n)
|
| 75 |
+
if val > 0.01:
|
| 76 |
+
parts.append(f"{key}={val:.2f}ms")
|
| 77 |
+
print(f"[profile] n={n}/{limit} avg " + " ".join(parts), flush=True)
|
| 78 |
+
|
| 79 |
def _cdf(x):
|
| 80 |
return 0.5 * (1 + math.erf(x / math.sqrt(2)))
|
| 81 |
|
|
|
|
| 188 |
return block(x, mask)
|
| 189 |
|
| 190 |
|
| 191 |
+
def _dblock_checkpoint_this_layer(args, base_enabled, layer_pos):
|
| 192 |
+
if not base_enabled:
|
| 193 |
+
return False
|
| 194 |
+
stride = int(getattr(args, "dblock_checkpoint_stride", 1) or 1)
|
| 195 |
+
if stride <= 0:
|
| 196 |
+
return False
|
| 197 |
+
if stride == 1:
|
| 198 |
+
return True
|
| 199 |
+
return (int(layer_pos) % stride) == 0
|
| 200 |
+
|
| 201 |
+
|
| 202 |
def _sample_token_loss_inputs(hidden, targets, max_tokens):
|
| 203 |
max_tokens = int(max_tokens or 0)
|
| 204 |
if max_tokens <= 0:
|
|
|
|
| 243 |
def _dblock_step(core, ar_h, sat_h, nat_h, opt, scaler, args, ids, state):
|
| 244 |
import nB300_agillm4 as M
|
| 245 |
|
| 246 |
+
prof = _profile_active(state, args)
|
| 247 |
+
_step_t = _profile_tic(prof)
|
| 248 |
if torch.cuda.is_available():
|
| 249 |
torch.cuda.reset_peak_memory_stats()
|
| 250 |
|
| 251 |
+
_setup_t = _profile_tic(prof)
|
| 252 |
B = state["B"]
|
| 253 |
asg = state["assign"]
|
| 254 |
bs = state["bsig"]
|
|
|
|
| 279 |
run_ar, run_sat, run_nat, objective = _choose_objectives(
|
| 280 |
state, args, ar_weight, sat_weight, nat_weight, do_sat_periodic, do_nat_periodic
|
| 281 |
)
|
| 282 |
+
_profile_toc(state, "setup", _setup_t)
|
| 283 |
|
| 284 |
ar_val = 0.0
|
| 285 |
sat_val = 0.0
|
|
|
|
| 287 |
|
| 288 |
if run_ar:
|
| 289 |
causal = M.causal_mask(T, structured=M.use_structured_masks(args))
|
| 290 |
+
_t = _profile_tic(prof)
|
| 291 |
with M.amp(args.amp):
|
| 292 |
emb = core.emb(ids)
|
| 293 |
zt = emb + sig[:, None, None] * torch.randn_like(emb)
|
| 294 |
h = ci * zt
|
| 295 |
+
for lpos, li in enumerate(layers):
|
| 296 |
+
h = _run_block(core.blocks[li], h, causal, _dblock_checkpoint_this_layer(args, use_layer_checkpoint, lpos))
|
| 297 |
Dn = core.ln(cs * zt + co * h)
|
| 298 |
+
_profile_toc(state, "ar_forward", _t)
|
| 299 |
+
_t = _profile_tic(prof)
|
| 300 |
ar_hidden, ar_targets, ar_used, ar_total = _sample_token_loss_inputs(
|
| 301 |
Dn[:, :-1], ids[:, 1:], int(getattr(args, "dblock_ar_loss_tokens", 0))
|
| 302 |
)
|
| 303 |
ar = ar_weight * w * fused_ce(ar_hidden, ar_h.proj.weight, ar_targets)
|
| 304 |
ar_val = float(ar.detach())
|
| 305 |
+
_profile_toc(state, "ar_ce", _t)
|
| 306 |
+
_t = _profile_tic(prof)
|
| 307 |
scaler.scale(ar).backward()
|
| 308 |
+
_profile_toc(state, "ar_backward", _t)
|
| 309 |
del causal, emb, zt, h, Dn, ar_hidden, ar_targets, ar, ar_used, ar_total
|
| 310 |
|
| 311 |
if run_sat:
|
| 312 |
smask = M.sat_mask(T, structured=M.use_structured_masks(args))
|
| 313 |
+
_t = _profile_tic(prof)
|
| 314 |
with M.amp(args.amp):
|
| 315 |
emb2 = core.emb(ids)
|
| 316 |
zt2 = emb2 + sig[:, None, None] * torch.randn_like(emb2)
|
| 317 |
h2 = ci * zt2
|
| 318 |
+
for lpos, li in enumerate(layers):
|
| 319 |
+
h2 = _run_block(core.blocks[li], h2, smask, _dblock_checkpoint_this_layer(args, use_layer_checkpoint, lpos))
|
| 320 |
Ds = core.ln(cs * zt2 + co * h2)
|
| 321 |
last = Ds[:, -SATB:]
|
| 322 |
+
_profile_toc(state, "sat_forward", _t)
|
| 323 |
+
_t = _profile_tic(prof)
|
| 324 |
+
sat_hidden, sat_targets, sat_used, sat_total = _sample_token_loss_inputs(
|
| 325 |
+
last, ids[:, 1 : SATB + 1], int(getattr(args, "dblock_sat_loss_tokens", 0))
|
| 326 |
+
)
|
| 327 |
+
with M.amp(args.amp):
|
| 328 |
satf = fused_ce(sat_hidden, sat_h.proj.weight, sat_targets)
|
| 329 |
satv = (
|
| 330 |
M.EMIT_LAMBDA
|
|
|
|
| 336 |
else 0.0
|
| 337 |
)
|
| 338 |
sat = sat_weight * w * (satf + satv)
|
| 339 |
+
_profile_toc(state, "sat_ce", _t)
|
| 340 |
sat_val = float(sat.detach())
|
| 341 |
+
_t = _profile_tic(prof)
|
| 342 |
scaler.scale(sat).backward()
|
| 343 |
+
_profile_toc(state, "sat_backward", _t)
|
| 344 |
del smask, emb2, zt2, h2, Ds, last, sat_hidden, sat_targets, satf, satv, sat
|
| 345 |
|
| 346 |
if run_nat:
|
| 347 |
ratio = min(max(float(getattr(args, "nat_mask_ratio", 0.5)), 0.05), 0.95)
|
| 348 |
nat_ids = M._nat_ids_for_training(ids, int(getattr(args, "nat_max_tokens", 0)))
|
| 349 |
+
_t = _profile_tic(prof)
|
| 350 |
with M.amp(args.amp):
|
| 351 |
nat_in = nat_ids.clone()
|
| 352 |
m = torch.rand(nat_ids.shape, device=nat_ids.device) < ratio
|
|
|
|
| 354 |
m[..., -1] = True
|
| 355 |
nat_in[m] = M.BLANK
|
| 356 |
hn = core.emb(nat_in)
|
| 357 |
+
for lpos, li in enumerate(layers):
|
| 358 |
+
hn = _run_block(core.blocks[li], hn, None, _dblock_checkpoint_this_layer(args, use_layer_checkpoint, lpos))
|
| 359 |
Dnat = core.ln(hn)
|
| 360 |
+
_profile_toc(state, "nat_forward", _t)
|
| 361 |
+
_t = _profile_tic(prof)
|
| 362 |
nat_hidden = Dnat[m]
|
| 363 |
nat_targets = nat_ids[m]
|
| 364 |
nat_hidden, nat_targets, nat_used, nat_total = _sample_token_loss_inputs(
|
|
|
|
| 366 |
)
|
| 367 |
nat = nat_weight * fused_ce(nat_hidden, nat_h.proj.weight, nat_targets)
|
| 368 |
nat_val = float(nat.detach())
|
| 369 |
+
_profile_toc(state, "nat_ce", _t)
|
| 370 |
+
_t = _profile_tic(prof)
|
| 371 |
scaler.scale(nat).backward()
|
| 372 |
+
_profile_toc(state, "nat_backward", _t)
|
| 373 |
del nat_ids, nat_in, m, hn, Dnat, nat_hidden, nat_targets, nat, nat_used, nat_total
|
| 374 |
|
| 375 |
total_val = ar_val + sat_val + nat_val
|
|
|
|
| 378 |
if torch.cuda.is_available():
|
| 379 |
torch.cuda.empty_cache()
|
| 380 |
print(f"[dblock] non-finite loss {total_val}; skipped optimizer step", flush=True)
|
| 381 |
+
_profile_toc(state, "step_total", _step_t)
|
| 382 |
+
_profile_step_done(state, args)
|
| 383 |
_update_stats(state, bi, total_val)
|
| 384 |
return total_val
|
| 385 |
|
| 386 |
+
_t = _profile_tic(prof)
|
| 387 |
scaler.unscale_(opt)
|
| 388 |
nn.utils.clip_grad_norm_([p for g in opt.param_groups for p in g["params"]], 1.0)
|
| 389 |
scaler.step(opt)
|
| 390 |
scaler.update()
|
| 391 |
opt.zero_grad(set_to_none=True)
|
| 392 |
+
_profile_toc(state, "opt_step", _t)
|
| 393 |
|
| 394 |
peak_alloc = None
|
| 395 |
peak_reserved = None
|
| 396 |
if torch.cuda.is_available():
|
| 397 |
peak_alloc = torch.cuda.max_memory_allocated() / (1024**3)
|
| 398 |
peak_reserved = torch.cuda.max_memory_reserved() / (1024**3)
|
| 399 |
+
_profile_toc(state, "step_total", _step_t)
|
| 400 |
+
_profile_step_done(state, args)
|
| 401 |
_update_stats(state, bi, total_val)
|
| 402 |
_maybe_log(state, args, bi, layers, ar_val, sat_val, nat_val, total_val, peak_alloc, peak_reserved, objective=objective)
|
| 403 |
return total_val
|
nB300_agillm4_vram_dblock.py
CHANGED
|
@@ -2324,17 +2324,37 @@ def _train_phase(
|
|
| 2324 |
except Exception:
|
| 2325 |
pass
|
| 2326 |
while seen_tok < total_tokens_needed:
|
|
|
|
|
|
|
| 2327 |
try:
|
| 2328 |
while len(buf) < BLOCK:
|
| 2329 |
buf.append(next(stream))
|
| 2330 |
except StopIteration:
|
| 2331 |
break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2332 |
seq = buf[:BLOCK]
|
| 2333 |
buf = buf[BLOCK:]
|
| 2334 |
batch_accum.append(seq)
|
| 2335 |
if len(batch_accum) < BATCH:
|
| 2336 |
continue
|
|
|
|
| 2337 |
ids = torch.tensor(batch_accum, device=DEV)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2338 |
batch_accum = []
|
| 2339 |
tgt_ar = ids.clone()
|
| 2340 |
try:
|
|
@@ -2980,6 +3000,10 @@ def main():
|
|
| 2980 |
help="Print lightweight trainer heartbeat/status lines every N seconds; 0 disables.")
|
| 2981 |
tr.add_argument("--empty_cache_every_steps", type=int, default=0,
|
| 2982 |
help="Call torch.cuda.empty_cache() every N train steps; useful for VRAM-first runs where lower reserved VRAM matters more than speed.")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2983 |
tr.add_argument("--delta_every_steps", type=int, default=DEFAULT_DELTA_STEPS, help="Weight-only delta save every N steps (0=off)")
|
| 2984 |
tr.add_argument("--delta_max_keep", type=int, default=DEFAULT_MAX_DELTAS, help="Max delta checkpoints to keep")
|
| 2985 |
tr.add_argument("--resume_delta", type=str, help="Resume from a delta (weight-only, no optimizer state)")
|
|
@@ -3013,6 +3037,8 @@ def main():
|
|
| 3013 |
help="Exploration rate for loss-balanced DBlock scheduling.")
|
| 3014 |
tr.add_argument("--dblock_log_every", type=int, default=25,
|
| 3015 |
help="Print DBlock block/loss/VRAM diagnostics every N DBlock steps; 0 disables.")
|
|
|
|
|
|
|
| 3016 |
tr.add_argument("--dblock_sigma_curriculum_steps", type=int, default=2000,
|
| 3017 |
help="Warm sigma ranges from easy to full span over this many DBlock steps; 0 disables.")
|
| 3018 |
tr.add_argument("--dblock_edm_wmax", type=float, default=5.0,
|
|
|
|
| 2324 |
except Exception:
|
| 2325 |
pass
|
| 2326 |
while seen_tok < total_tokens_needed:
|
| 2327 |
+
_profile_batch = _DBS is not None and int(getattr(args, "profile_steps", 0) or 0) > 0 and int(_DBS.get("profile_n", 0)) < int(getattr(args, "profile_steps", 0) or 0)
|
| 2328 |
+
_data_t = time.perf_counter() if _profile_batch else None
|
| 2329 |
try:
|
| 2330 |
while len(buf) < BLOCK:
|
| 2331 |
buf.append(next(stream))
|
| 2332 |
except StopIteration:
|
| 2333 |
break
|
| 2334 |
+
if _profile_batch:
|
| 2335 |
+
try:
|
| 2336 |
+
import dblocks_train as _db_prof
|
| 2337 |
+
_db_prof._profile_add(_DBS, "data_stream", time.perf_counter() - _data_t)
|
| 2338 |
+
except Exception:
|
| 2339 |
+
pass
|
| 2340 |
seq = buf[:BLOCK]
|
| 2341 |
buf = buf[BLOCK:]
|
| 2342 |
batch_accum.append(seq)
|
| 2343 |
if len(batch_accum) < BATCH:
|
| 2344 |
continue
|
| 2345 |
+
_tensor_t = time.perf_counter() if _profile_batch else None
|
| 2346 |
ids = torch.tensor(batch_accum, device=DEV)
|
| 2347 |
+
if _profile_batch:
|
| 2348 |
+
if DEV.type == "cuda":
|
| 2349 |
+
try:
|
| 2350 |
+
torch.cuda.synchronize()
|
| 2351 |
+
except Exception:
|
| 2352 |
+
pass
|
| 2353 |
+
try:
|
| 2354 |
+
import dblocks_train as _db_prof
|
| 2355 |
+
_db_prof._profile_add(_DBS, "tensor", time.perf_counter() - _tensor_t)
|
| 2356 |
+
except Exception:
|
| 2357 |
+
pass
|
| 2358 |
batch_accum = []
|
| 2359 |
tgt_ar = ids.clone()
|
| 2360 |
try:
|
|
|
|
| 3000 |
help="Print lightweight trainer heartbeat/status lines every N seconds; 0 disables.")
|
| 3001 |
tr.add_argument("--empty_cache_every_steps", type=int, default=0,
|
| 3002 |
help="Call torch.cuda.empty_cache() every N train steps; useful for VRAM-first runs where lower reserved VRAM matters more than speed.")
|
| 3003 |
+
tr.add_argument("--profile_steps", type=int, default=0,
|
| 3004 |
+
help="Profile the first N DBlock training steps with in-process CUDA timers; 0 disables.")
|
| 3005 |
+
tr.add_argument("--profile_log_every", type=int, default=25,
|
| 3006 |
+
help="Print averaged profiler timings every N profiled steps.")
|
| 3007 |
tr.add_argument("--delta_every_steps", type=int, default=DEFAULT_DELTA_STEPS, help="Weight-only delta save every N steps (0=off)")
|
| 3008 |
tr.add_argument("--delta_max_keep", type=int, default=DEFAULT_MAX_DELTAS, help="Max delta checkpoints to keep")
|
| 3009 |
tr.add_argument("--resume_delta", type=str, help="Resume from a delta (weight-only, no optimizer state)")
|
|
|
|
| 3037 |
help="Exploration rate for loss-balanced DBlock scheduling.")
|
| 3038 |
tr.add_argument("--dblock_log_every", type=int, default=25,
|
| 3039 |
help="Print DBlock block/loss/VRAM diagnostics every N DBlock steps; 0 disables.")
|
| 3040 |
+
tr.add_argument("--dblock_checkpoint_stride", type=int, default=1,
|
| 3041 |
+
help="With --grad_checkpoint in --dblock mode, checkpoint one layer every N selected block layers; 1=all layers, 2=alternate, 0=off.")
|
| 3042 |
tr.add_argument("--dblock_sigma_curriculum_steps", type=int, default=2000,
|
| 3043 |
help="Warm sigma ranges from easy to full span over this many DBlock steps; 0 disables.")
|
| 3044 |
tr.add_argument("--dblock_edm_wmax", type=float, default=5.0,
|
relaunch_agillm4_dblock_sg2.sh
CHANGED
|
@@ -10,15 +10,16 @@ export AGILLM_ATTN_BACKEND=sublinear
|
|
| 10 |
SAVE_DIR=/workspace/agillm4_4090_ckpts
|
| 11 |
CKPT="$(ls -1t "$SAVE_DIR"/pretrain_step*.pt 2>/dev/null | head -1)"
|
| 12 |
exec >> /workspace/agillm4_floor_train.log 2>&1
|
| 13 |
-
echo "RELAUNCH_AGILLM4_DBLOCK_SG2 $(date -u +%Y-%m-%dT%H:%M:%SZ) resume=$CKPT (
|
| 14 |
exec python -u nB300_agillm4.py train --preset agillm4_floor --resume "$CKPT" \
|
| 15 |
--dblock --dblock_blocks 4 --dblock_schedule loss_balanced --dblock_warmup_steps 16 \
|
| 16 |
--dblock_sigma_curriculum_steps 2000 --dblock_log_every 25 --dblock_objective_mode stochastic \
|
| 17 |
--dblock_ar_prob 0.85 --dblock_sat_prob 0.075 --dblock_nat_prob 0.075 \
|
| 18 |
--dblock_ar_loss_tokens 512 --dblock_sat_loss_tokens 0 --dblock_nat_loss_tokens 512 \
|
| 19 |
-
--tie_weights --batch_size
|
| 20 |
--sublinear_window 128 --sublinear_stride 128 --sublinear_max_anchors 128 --sublinear_chunk 128 \
|
| 21 |
--sublinear_sinks 4 --sublinear_recent_anchors 64 --no-sublinear_pooled_landmarks \
|
| 22 |
-
--grad_checkpoint --optimizer paged_adamw8bit --sat_every 4 --nat_every 4 --nat_max_tokens 768 --nat_mask_ratio 0.5 \
|
| 23 |
--token_param_ratio 100 --save_dir "$SAVE_DIR" --save_every_sec 86400 --heartbeat_every_sec 300 \
|
|
|
|
| 24 |
--empty_cache_every_steps 0 --delta_every_steps 25000 --delta_max_keep 1 --max_ckpts 1
|
|
|
|
| 10 |
SAVE_DIR=/workspace/agillm4_4090_ckpts
|
| 11 |
CKPT="$(ls -1t "$SAVE_DIR"/pretrain_step*.pt 2>/dev/null | head -1)"
|
| 12 |
exec >> /workspace/agillm4_floor_train.log 2>&1
|
| 13 |
+
echo "RELAUNCH_AGILLM4_DBLOCK_SG2 $(date -u +%Y-%m-%dT%H:%M:%SZ) resume=$CKPT (batch4 official speed-optimized + sublinear v2)"
|
| 14 |
exec python -u nB300_agillm4.py train --preset agillm4_floor --resume "$CKPT" \
|
| 15 |
--dblock --dblock_blocks 4 --dblock_schedule loss_balanced --dblock_warmup_steps 16 \
|
| 16 |
--dblock_sigma_curriculum_steps 2000 --dblock_log_every 25 --dblock_objective_mode stochastic \
|
| 17 |
--dblock_ar_prob 0.85 --dblock_sat_prob 0.075 --dblock_nat_prob 0.075 \
|
| 18 |
--dblock_ar_loss_tokens 512 --dblock_sat_loss_tokens 0 --dblock_nat_loss_tokens 512 \
|
| 19 |
+
--tie_weights --batch_size 4 --block 1280 --amp --attn_backend sublinear \
|
| 20 |
--sublinear_window 128 --sublinear_stride 128 --sublinear_max_anchors 128 --sublinear_chunk 128 \
|
| 21 |
--sublinear_sinks 4 --sublinear_recent_anchors 64 --no-sublinear_pooled_landmarks \
|
| 22 |
+
--grad_checkpoint --dblock_checkpoint_stride 1 --optimizer paged_adamw8bit --sat_every 4 --nat_every 4 --nat_max_tokens 768 --nat_mask_ratio 0.5 \
|
| 23 |
--token_param_ratio 100 --save_dir "$SAVE_DIR" --save_every_sec 86400 --heartbeat_every_sec 300 \
|
| 24 |
+
\
|
| 25 |
--empty_cache_every_steps 0 --delta_every_steps 25000 --delta_max_keep 1 --max_ckpts 1
|