Scott/Codex commited on
Commit
690cf55
·
1 Parent(s): 468a571

Add DBlock profiler and speed-tuned batch config

Browse files
README.md CHANGED
@@ -85,4 +85,8 @@ DiffusionBlocks, gradient-checkpointed blocks, tied heads, and structured masks.
85
 
86
  Sublinear coverage update 2026-05-29: the saved AGILLM-4 trainer snapshot now matches the live v2 sparse global memory path. It fixes gathered ALiBi distance, suppresses duplicate local/anchor candidates before softmax, uses hybrid full-span + recent-tail anchors with explicit `--sublinear_sinks` and `--sublinear_recent_anchors`, and includes optional pooled K/V landmark summaries behind `--sublinear_pooled_landmarks`. At the live 128/128/128 profile it keeps deep-past coverage while preserving recent anchors and the same VRAM-first key budget. See `sublinear_improved_snippet.py` for the minimal blocks and `sublinear_improved.py` for the coverage demo/standalone selector.
87
 
 
 
 
 
88
  License: Apache-2.0 (matching the upstream method).
 
85
 
86
  Sublinear coverage update 2026-05-29: the saved AGILLM-4 trainer snapshot now matches the live v2 sparse global memory path. It fixes gathered ALiBi distance, suppresses duplicate local/anchor candidates before softmax, uses hybrid full-span + recent-tail anchors with explicit `--sublinear_sinks` and `--sublinear_recent_anchors`, and includes optional pooled K/V landmark summaries behind `--sublinear_pooled_landmarks`. At the live 128/128/128 profile it keeps deep-past coverage while preserving recent anchors and the same VRAM-first key budget. See `sublinear_improved_snippet.py` for the minimal blocks and `sublinear_improved.py` for the coverage demo/standalone selector.
87
 
88
+
89
+ Profiling/speed update 2026-05-29: added in-process DBlock profiling (`--profile_steps`, `--profile_log_every`) after external ptrace profiling was blocked on Vast. The profile showed the bottleneck is transformer recompute/backward, not fused CE or the optimizer: at B=2 full checkpointing, AR backward averaged ~605 ms/step, AR forward ~184 ms, CE ~4.5 ms, optimizer ~17 ms. Tested speed levers live: no checkpointing OOMed at B=2 and fell to B=1, selective checkpoint stride=2 fit but hugged VRAM and reached ~2.94k tok/s, B=5/6 hit a memory-pressure cliff, while B=4 with full DBlock checkpointing was the best stable official setting (~3.0k tok/s warm window, ~13.2 GB tensor peak / ~17.6 GB reserved, ETA ~269-275 days). The live relaunch now uses `--batch_size 4 --grad_checkpoint --dblock_checkpoint_stride 1` and leaves selective checkpointing available for future context/batch tradeoffs.
90
+
91
+
92
  License: Apache-2.0 (matching the upstream method).
dblocks_train.py CHANGED
@@ -7,6 +7,8 @@ Lazy-imports nB300 inside functions to avoid a circular import.
7
  """
8
  import math
9
  import random
 
 
10
  import numpy as np
11
  import torch
12
  import torch.nn as nn
@@ -17,6 +19,63 @@ from fused_ce import fused_ce
17
  SD = 0.5
18
 
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  def _cdf(x):
21
  return 0.5 * (1 + math.erf(x / math.sqrt(2)))
22
 
@@ -129,6 +188,17 @@ def _run_block(block, x, mask, use_checkpoint):
129
  return block(x, mask)
130
 
131
 
 
 
 
 
 
 
 
 
 
 
 
132
  def _sample_token_loss_inputs(hidden, targets, max_tokens):
133
  max_tokens = int(max_tokens or 0)
134
  if max_tokens <= 0:
@@ -173,9 +243,12 @@ def _choose_objectives(state, args, ar_weight, sat_weight, nat_weight, do_sat_pe
173
  def _dblock_step(core, ar_h, sat_h, nat_h, opt, scaler, args, ids, state):
174
  import nB300_agillm4 as M
175
 
 
 
176
  if torch.cuda.is_available():
177
  torch.cuda.reset_peak_memory_stats()
178
 
 
179
  B = state["B"]
180
  asg = state["assign"]
181
  bs = state["bsig"]
@@ -206,6 +279,7 @@ def _dblock_step(core, ar_h, sat_h, nat_h, opt, scaler, args, ids, state):
206
  run_ar, run_sat, run_nat, objective = _choose_objectives(
207
  state, args, ar_weight, sat_weight, nat_weight, do_sat_periodic, do_nat_periodic
208
  )
 
209
 
210
  ar_val = 0.0
211
  sat_val = 0.0
@@ -213,34 +287,44 @@ def _dblock_step(core, ar_h, sat_h, nat_h, opt, scaler, args, ids, state):
213
 
214
  if run_ar:
215
  causal = M.causal_mask(T, structured=M.use_structured_masks(args))
 
216
  with M.amp(args.amp):
217
  emb = core.emb(ids)
218
  zt = emb + sig[:, None, None] * torch.randn_like(emb)
219
  h = ci * zt
220
- for li in layers:
221
- h = _run_block(core.blocks[li], h, causal, use_layer_checkpoint)
222
  Dn = core.ln(cs * zt + co * h)
 
 
223
  ar_hidden, ar_targets, ar_used, ar_total = _sample_token_loss_inputs(
224
  Dn[:, :-1], ids[:, 1:], int(getattr(args, "dblock_ar_loss_tokens", 0))
225
  )
226
  ar = ar_weight * w * fused_ce(ar_hidden, ar_h.proj.weight, ar_targets)
227
  ar_val = float(ar.detach())
 
 
228
  scaler.scale(ar).backward()
 
229
  del causal, emb, zt, h, Dn, ar_hidden, ar_targets, ar, ar_used, ar_total
230
 
231
  if run_sat:
232
  smask = M.sat_mask(T, structured=M.use_structured_masks(args))
 
233
  with M.amp(args.amp):
234
  emb2 = core.emb(ids)
235
  zt2 = emb2 + sig[:, None, None] * torch.randn_like(emb2)
236
  h2 = ci * zt2
237
- for li in layers:
238
- h2 = _run_block(core.blocks[li], h2, smask, use_layer_checkpoint)
239
  Ds = core.ln(cs * zt2 + co * h2)
240
  last = Ds[:, -SATB:]
241
- sat_hidden, sat_targets, sat_used, sat_total = _sample_token_loss_inputs(
242
- last, ids[:, 1 : SATB + 1], int(getattr(args, "dblock_sat_loss_tokens", 0))
243
- )
 
 
 
244
  satf = fused_ce(sat_hidden, sat_h.proj.weight, sat_targets)
245
  satv = (
246
  M.EMIT_LAMBDA
@@ -252,13 +336,17 @@ def _dblock_step(core, ar_h, sat_h, nat_h, opt, scaler, args, ids, state):
252
  else 0.0
253
  )
254
  sat = sat_weight * w * (satf + satv)
 
255
  sat_val = float(sat.detach())
 
256
  scaler.scale(sat).backward()
 
257
  del smask, emb2, zt2, h2, Ds, last, sat_hidden, sat_targets, satf, satv, sat
258
 
259
  if run_nat:
260
  ratio = min(max(float(getattr(args, "nat_mask_ratio", 0.5)), 0.05), 0.95)
261
  nat_ids = M._nat_ids_for_training(ids, int(getattr(args, "nat_max_tokens", 0)))
 
262
  with M.amp(args.amp):
263
  nat_in = nat_ids.clone()
264
  m = torch.rand(nat_ids.shape, device=nat_ids.device) < ratio
@@ -266,9 +354,11 @@ def _dblock_step(core, ar_h, sat_h, nat_h, opt, scaler, args, ids, state):
266
  m[..., -1] = True
267
  nat_in[m] = M.BLANK
268
  hn = core.emb(nat_in)
269
- for li in layers:
270
- hn = _run_block(core.blocks[li], hn, None, use_layer_checkpoint)
271
  Dnat = core.ln(hn)
 
 
272
  nat_hidden = Dnat[m]
273
  nat_targets = nat_ids[m]
274
  nat_hidden, nat_targets, nat_used, nat_total = _sample_token_loss_inputs(
@@ -276,7 +366,10 @@ def _dblock_step(core, ar_h, sat_h, nat_h, opt, scaler, args, ids, state):
276
  )
277
  nat = nat_weight * fused_ce(nat_hidden, nat_h.proj.weight, nat_targets)
278
  nat_val = float(nat.detach())
 
 
279
  scaler.scale(nat).backward()
 
280
  del nat_ids, nat_in, m, hn, Dnat, nat_hidden, nat_targets, nat, nat_used, nat_total
281
 
282
  total_val = ar_val + sat_val + nat_val
@@ -285,20 +378,26 @@ def _dblock_step(core, ar_h, sat_h, nat_h, opt, scaler, args, ids, state):
285
  if torch.cuda.is_available():
286
  torch.cuda.empty_cache()
287
  print(f"[dblock] non-finite loss {total_val}; skipped optimizer step", flush=True)
 
 
288
  _update_stats(state, bi, total_val)
289
  return total_val
290
 
 
291
  scaler.unscale_(opt)
292
  nn.utils.clip_grad_norm_([p for g in opt.param_groups for p in g["params"]], 1.0)
293
  scaler.step(opt)
294
  scaler.update()
295
  opt.zero_grad(set_to_none=True)
 
296
 
297
  peak_alloc = None
298
  peak_reserved = None
299
  if torch.cuda.is_available():
300
  peak_alloc = torch.cuda.max_memory_allocated() / (1024**3)
301
  peak_reserved = torch.cuda.max_memory_reserved() / (1024**3)
 
 
302
  _update_stats(state, bi, total_val)
303
  _maybe_log(state, args, bi, layers, ar_val, sat_val, nat_val, total_val, peak_alloc, peak_reserved, objective=objective)
304
  return total_val
 
7
  """
8
  import math
9
  import random
10
+ import time
11
+ from collections import defaultdict
12
  import numpy as np
13
  import torch
14
  import torch.nn as nn
 
19
  SD = 0.5
20
 
21
 
22
+
23
+
24
+ def _profile_active(state, args):
25
+ limit = int(getattr(args, "profile_steps", 0) or 0)
26
+ return limit > 0 and int(state.get("profile_n", 0)) < limit
27
+
28
+
29
+ def _profile_add(state, name, seconds):
30
+ if seconds is None:
31
+ return
32
+ prof = state.setdefault("profile_times", defaultdict(float))
33
+ prof[name] += float(seconds)
34
+
35
+
36
+ def _profile_tic(enabled):
37
+ if not enabled:
38
+ return None
39
+ if torch.cuda.is_available():
40
+ torch.cuda.synchronize()
41
+ return time.perf_counter()
42
+
43
+
44
+ def _profile_toc(state, name, start):
45
+ if start is None:
46
+ return
47
+ if torch.cuda.is_available():
48
+ torch.cuda.synchronize()
49
+ _profile_add(state, name, time.perf_counter() - start)
50
+
51
+
52
+ def _profile_step_done(state, args):
53
+ limit = int(getattr(args, "profile_steps", 0) or 0)
54
+ if limit <= 0:
55
+ return
56
+ n_prev = int(state.get("profile_n", 0))
57
+ if n_prev >= limit:
58
+ return
59
+ state["profile_n"] = n_prev + 1
60
+ n = int(state["profile_n"])
61
+ log_every = max(1, int(getattr(args, "profile_log_every", 25) or 25))
62
+ if n % log_every != 0 and n != limit:
63
+ return
64
+ times = state.get("profile_times", {})
65
+ keys = [
66
+ "data_stream", "tensor", "setup",
67
+ "ar_forward", "ar_ce", "ar_backward",
68
+ "sat_forward", "sat_ce", "sat_backward",
69
+ "nat_forward", "nat_ce", "nat_backward",
70
+ "opt_step", "step_total",
71
+ ]
72
+ parts = []
73
+ for key in keys:
74
+ val = float(times.get(key, 0.0)) * 1000.0 / max(1, n)
75
+ if val > 0.01:
76
+ parts.append(f"{key}={val:.2f}ms")
77
+ print(f"[profile] n={n}/{limit} avg " + " ".join(parts), flush=True)
78
+
79
  def _cdf(x):
80
  return 0.5 * (1 + math.erf(x / math.sqrt(2)))
81
 
 
188
  return block(x, mask)
189
 
190
 
191
+ def _dblock_checkpoint_this_layer(args, base_enabled, layer_pos):
192
+ if not base_enabled:
193
+ return False
194
+ stride = int(getattr(args, "dblock_checkpoint_stride", 1) or 1)
195
+ if stride <= 0:
196
+ return False
197
+ if stride == 1:
198
+ return True
199
+ return (int(layer_pos) % stride) == 0
200
+
201
+
202
  def _sample_token_loss_inputs(hidden, targets, max_tokens):
203
  max_tokens = int(max_tokens or 0)
204
  if max_tokens <= 0:
 
243
  def _dblock_step(core, ar_h, sat_h, nat_h, opt, scaler, args, ids, state):
244
  import nB300_agillm4 as M
245
 
246
+ prof = _profile_active(state, args)
247
+ _step_t = _profile_tic(prof)
248
  if torch.cuda.is_available():
249
  torch.cuda.reset_peak_memory_stats()
250
 
251
+ _setup_t = _profile_tic(prof)
252
  B = state["B"]
253
  asg = state["assign"]
254
  bs = state["bsig"]
 
279
  run_ar, run_sat, run_nat, objective = _choose_objectives(
280
  state, args, ar_weight, sat_weight, nat_weight, do_sat_periodic, do_nat_periodic
281
  )
282
+ _profile_toc(state, "setup", _setup_t)
283
 
284
  ar_val = 0.0
285
  sat_val = 0.0
 
287
 
288
  if run_ar:
289
  causal = M.causal_mask(T, structured=M.use_structured_masks(args))
290
+ _t = _profile_tic(prof)
291
  with M.amp(args.amp):
292
  emb = core.emb(ids)
293
  zt = emb + sig[:, None, None] * torch.randn_like(emb)
294
  h = ci * zt
295
+ for lpos, li in enumerate(layers):
296
+ h = _run_block(core.blocks[li], h, causal, _dblock_checkpoint_this_layer(args, use_layer_checkpoint, lpos))
297
  Dn = core.ln(cs * zt + co * h)
298
+ _profile_toc(state, "ar_forward", _t)
299
+ _t = _profile_tic(prof)
300
  ar_hidden, ar_targets, ar_used, ar_total = _sample_token_loss_inputs(
301
  Dn[:, :-1], ids[:, 1:], int(getattr(args, "dblock_ar_loss_tokens", 0))
302
  )
303
  ar = ar_weight * w * fused_ce(ar_hidden, ar_h.proj.weight, ar_targets)
304
  ar_val = float(ar.detach())
305
+ _profile_toc(state, "ar_ce", _t)
306
+ _t = _profile_tic(prof)
307
  scaler.scale(ar).backward()
308
+ _profile_toc(state, "ar_backward", _t)
309
  del causal, emb, zt, h, Dn, ar_hidden, ar_targets, ar, ar_used, ar_total
310
 
311
  if run_sat:
312
  smask = M.sat_mask(T, structured=M.use_structured_masks(args))
313
+ _t = _profile_tic(prof)
314
  with M.amp(args.amp):
315
  emb2 = core.emb(ids)
316
  zt2 = emb2 + sig[:, None, None] * torch.randn_like(emb2)
317
  h2 = ci * zt2
318
+ for lpos, li in enumerate(layers):
319
+ h2 = _run_block(core.blocks[li], h2, smask, _dblock_checkpoint_this_layer(args, use_layer_checkpoint, lpos))
320
  Ds = core.ln(cs * zt2 + co * h2)
321
  last = Ds[:, -SATB:]
322
+ _profile_toc(state, "sat_forward", _t)
323
+ _t = _profile_tic(prof)
324
+ sat_hidden, sat_targets, sat_used, sat_total = _sample_token_loss_inputs(
325
+ last, ids[:, 1 : SATB + 1], int(getattr(args, "dblock_sat_loss_tokens", 0))
326
+ )
327
+ with M.amp(args.amp):
328
  satf = fused_ce(sat_hidden, sat_h.proj.weight, sat_targets)
329
  satv = (
330
  M.EMIT_LAMBDA
 
336
  else 0.0
337
  )
338
  sat = sat_weight * w * (satf + satv)
339
+ _profile_toc(state, "sat_ce", _t)
340
  sat_val = float(sat.detach())
341
+ _t = _profile_tic(prof)
342
  scaler.scale(sat).backward()
343
+ _profile_toc(state, "sat_backward", _t)
344
  del smask, emb2, zt2, h2, Ds, last, sat_hidden, sat_targets, satf, satv, sat
345
 
346
  if run_nat:
347
  ratio = min(max(float(getattr(args, "nat_mask_ratio", 0.5)), 0.05), 0.95)
348
  nat_ids = M._nat_ids_for_training(ids, int(getattr(args, "nat_max_tokens", 0)))
349
+ _t = _profile_tic(prof)
350
  with M.amp(args.amp):
351
  nat_in = nat_ids.clone()
352
  m = torch.rand(nat_ids.shape, device=nat_ids.device) < ratio
 
354
  m[..., -1] = True
355
  nat_in[m] = M.BLANK
356
  hn = core.emb(nat_in)
357
+ for lpos, li in enumerate(layers):
358
+ hn = _run_block(core.blocks[li], hn, None, _dblock_checkpoint_this_layer(args, use_layer_checkpoint, lpos))
359
  Dnat = core.ln(hn)
360
+ _profile_toc(state, "nat_forward", _t)
361
+ _t = _profile_tic(prof)
362
  nat_hidden = Dnat[m]
363
  nat_targets = nat_ids[m]
364
  nat_hidden, nat_targets, nat_used, nat_total = _sample_token_loss_inputs(
 
366
  )
367
  nat = nat_weight * fused_ce(nat_hidden, nat_h.proj.weight, nat_targets)
368
  nat_val = float(nat.detach())
369
+ _profile_toc(state, "nat_ce", _t)
370
+ _t = _profile_tic(prof)
371
  scaler.scale(nat).backward()
372
+ _profile_toc(state, "nat_backward", _t)
373
  del nat_ids, nat_in, m, hn, Dnat, nat_hidden, nat_targets, nat, nat_used, nat_total
374
 
375
  total_val = ar_val + sat_val + nat_val
 
378
  if torch.cuda.is_available():
379
  torch.cuda.empty_cache()
380
  print(f"[dblock] non-finite loss {total_val}; skipped optimizer step", flush=True)
381
+ _profile_toc(state, "step_total", _step_t)
382
+ _profile_step_done(state, args)
383
  _update_stats(state, bi, total_val)
384
  return total_val
385
 
386
+ _t = _profile_tic(prof)
387
  scaler.unscale_(opt)
388
  nn.utils.clip_grad_norm_([p for g in opt.param_groups for p in g["params"]], 1.0)
389
  scaler.step(opt)
390
  scaler.update()
391
  opt.zero_grad(set_to_none=True)
392
+ _profile_toc(state, "opt_step", _t)
393
 
394
  peak_alloc = None
395
  peak_reserved = None
396
  if torch.cuda.is_available():
397
  peak_alloc = torch.cuda.max_memory_allocated() / (1024**3)
398
  peak_reserved = torch.cuda.max_memory_reserved() / (1024**3)
399
+ _profile_toc(state, "step_total", _step_t)
400
+ _profile_step_done(state, args)
401
  _update_stats(state, bi, total_val)
402
  _maybe_log(state, args, bi, layers, ar_val, sat_val, nat_val, total_val, peak_alloc, peak_reserved, objective=objective)
403
  return total_val
nB300_agillm4_vram_dblock.py CHANGED
@@ -2324,17 +2324,37 @@ def _train_phase(
2324
  except Exception:
2325
  pass
2326
  while seen_tok < total_tokens_needed:
 
 
2327
  try:
2328
  while len(buf) < BLOCK:
2329
  buf.append(next(stream))
2330
  except StopIteration:
2331
  break
 
 
 
 
 
 
2332
  seq = buf[:BLOCK]
2333
  buf = buf[BLOCK:]
2334
  batch_accum.append(seq)
2335
  if len(batch_accum) < BATCH:
2336
  continue
 
2337
  ids = torch.tensor(batch_accum, device=DEV)
 
 
 
 
 
 
 
 
 
 
 
2338
  batch_accum = []
2339
  tgt_ar = ids.clone()
2340
  try:
@@ -2980,6 +3000,10 @@ def main():
2980
  help="Print lightweight trainer heartbeat/status lines every N seconds; 0 disables.")
2981
  tr.add_argument("--empty_cache_every_steps", type=int, default=0,
2982
  help="Call torch.cuda.empty_cache() every N train steps; useful for VRAM-first runs where lower reserved VRAM matters more than speed.")
 
 
 
 
2983
  tr.add_argument("--delta_every_steps", type=int, default=DEFAULT_DELTA_STEPS, help="Weight-only delta save every N steps (0=off)")
2984
  tr.add_argument("--delta_max_keep", type=int, default=DEFAULT_MAX_DELTAS, help="Max delta checkpoints to keep")
2985
  tr.add_argument("--resume_delta", type=str, help="Resume from a delta (weight-only, no optimizer state)")
@@ -3013,6 +3037,8 @@ def main():
3013
  help="Exploration rate for loss-balanced DBlock scheduling.")
3014
  tr.add_argument("--dblock_log_every", type=int, default=25,
3015
  help="Print DBlock block/loss/VRAM diagnostics every N DBlock steps; 0 disables.")
 
 
3016
  tr.add_argument("--dblock_sigma_curriculum_steps", type=int, default=2000,
3017
  help="Warm sigma ranges from easy to full span over this many DBlock steps; 0 disables.")
3018
  tr.add_argument("--dblock_edm_wmax", type=float, default=5.0,
 
2324
  except Exception:
2325
  pass
2326
  while seen_tok < total_tokens_needed:
2327
+ _profile_batch = _DBS is not None and int(getattr(args, "profile_steps", 0) or 0) > 0 and int(_DBS.get("profile_n", 0)) < int(getattr(args, "profile_steps", 0) or 0)
2328
+ _data_t = time.perf_counter() if _profile_batch else None
2329
  try:
2330
  while len(buf) < BLOCK:
2331
  buf.append(next(stream))
2332
  except StopIteration:
2333
  break
2334
+ if _profile_batch:
2335
+ try:
2336
+ import dblocks_train as _db_prof
2337
+ _db_prof._profile_add(_DBS, "data_stream", time.perf_counter() - _data_t)
2338
+ except Exception:
2339
+ pass
2340
  seq = buf[:BLOCK]
2341
  buf = buf[BLOCK:]
2342
  batch_accum.append(seq)
2343
  if len(batch_accum) < BATCH:
2344
  continue
2345
+ _tensor_t = time.perf_counter() if _profile_batch else None
2346
  ids = torch.tensor(batch_accum, device=DEV)
2347
+ if _profile_batch:
2348
+ if DEV.type == "cuda":
2349
+ try:
2350
+ torch.cuda.synchronize()
2351
+ except Exception:
2352
+ pass
2353
+ try:
2354
+ import dblocks_train as _db_prof
2355
+ _db_prof._profile_add(_DBS, "tensor", time.perf_counter() - _tensor_t)
2356
+ except Exception:
2357
+ pass
2358
  batch_accum = []
2359
  tgt_ar = ids.clone()
2360
  try:
 
3000
  help="Print lightweight trainer heartbeat/status lines every N seconds; 0 disables.")
3001
  tr.add_argument("--empty_cache_every_steps", type=int, default=0,
3002
  help="Call torch.cuda.empty_cache() every N train steps; useful for VRAM-first runs where lower reserved VRAM matters more than speed.")
3003
+ tr.add_argument("--profile_steps", type=int, default=0,
3004
+ help="Profile the first N DBlock training steps with in-process CUDA timers; 0 disables.")
3005
+ tr.add_argument("--profile_log_every", type=int, default=25,
3006
+ help="Print averaged profiler timings every N profiled steps.")
3007
  tr.add_argument("--delta_every_steps", type=int, default=DEFAULT_DELTA_STEPS, help="Weight-only delta save every N steps (0=off)")
3008
  tr.add_argument("--delta_max_keep", type=int, default=DEFAULT_MAX_DELTAS, help="Max delta checkpoints to keep")
3009
  tr.add_argument("--resume_delta", type=str, help="Resume from a delta (weight-only, no optimizer state)")
 
3037
  help="Exploration rate for loss-balanced DBlock scheduling.")
3038
  tr.add_argument("--dblock_log_every", type=int, default=25,
3039
  help="Print DBlock block/loss/VRAM diagnostics every N DBlock steps; 0 disables.")
3040
+ tr.add_argument("--dblock_checkpoint_stride", type=int, default=1,
3041
+ help="With --grad_checkpoint in --dblock mode, checkpoint one layer every N selected block layers; 1=all layers, 2=alternate, 0=off.")
3042
  tr.add_argument("--dblock_sigma_curriculum_steps", type=int, default=2000,
3043
  help="Warm sigma ranges from easy to full span over this many DBlock steps; 0 disables.")
3044
  tr.add_argument("--dblock_edm_wmax", type=float, default=5.0,
relaunch_agillm4_dblock_sg2.sh CHANGED
@@ -10,15 +10,16 @@ export AGILLM_ATTN_BACKEND=sublinear
10
  SAVE_DIR=/workspace/agillm4_4090_ckpts
11
  CKPT="$(ls -1t "$SAVE_DIR"/pretrain_step*.pt 2>/dev/null | head -1)"
12
  exec >> /workspace/agillm4_floor_train.log 2>&1
13
- echo "RELAUNCH_AGILLM4_DBLOCK_SG2 $(date -u +%Y-%m-%dT%H:%M:%SZ) resume=$CKPT (improved sublinear v2: ALiBi distance + dedupe + hybrid anchors + sinks)"
14
  exec python -u nB300_agillm4.py train --preset agillm4_floor --resume "$CKPT" \
15
  --dblock --dblock_blocks 4 --dblock_schedule loss_balanced --dblock_warmup_steps 16 \
16
  --dblock_sigma_curriculum_steps 2000 --dblock_log_every 25 --dblock_objective_mode stochastic \
17
  --dblock_ar_prob 0.85 --dblock_sat_prob 0.075 --dblock_nat_prob 0.075 \
18
  --dblock_ar_loss_tokens 512 --dblock_sat_loss_tokens 0 --dblock_nat_loss_tokens 512 \
19
- --tie_weights --batch_size 2 --block 1280 --amp --attn_backend sublinear \
20
  --sublinear_window 128 --sublinear_stride 128 --sublinear_max_anchors 128 --sublinear_chunk 128 \
21
  --sublinear_sinks 4 --sublinear_recent_anchors 64 --no-sublinear_pooled_landmarks \
22
- --grad_checkpoint --optimizer paged_adamw8bit --sat_every 4 --nat_every 4 --nat_max_tokens 768 --nat_mask_ratio 0.5 \
23
  --token_param_ratio 100 --save_dir "$SAVE_DIR" --save_every_sec 86400 --heartbeat_every_sec 300 \
 
24
  --empty_cache_every_steps 0 --delta_every_steps 25000 --delta_max_keep 1 --max_ckpts 1
 
10
  SAVE_DIR=/workspace/agillm4_4090_ckpts
11
  CKPT="$(ls -1t "$SAVE_DIR"/pretrain_step*.pt 2>/dev/null | head -1)"
12
  exec >> /workspace/agillm4_floor_train.log 2>&1
13
+ echo "RELAUNCH_AGILLM4_DBLOCK_SG2 $(date -u +%Y-%m-%dT%H:%M:%SZ) resume=$CKPT (batch4 official speed-optimized + sublinear v2)"
14
  exec python -u nB300_agillm4.py train --preset agillm4_floor --resume "$CKPT" \
15
  --dblock --dblock_blocks 4 --dblock_schedule loss_balanced --dblock_warmup_steps 16 \
16
  --dblock_sigma_curriculum_steps 2000 --dblock_log_every 25 --dblock_objective_mode stochastic \
17
  --dblock_ar_prob 0.85 --dblock_sat_prob 0.075 --dblock_nat_prob 0.075 \
18
  --dblock_ar_loss_tokens 512 --dblock_sat_loss_tokens 0 --dblock_nat_loss_tokens 512 \
19
+ --tie_weights --batch_size 4 --block 1280 --amp --attn_backend sublinear \
20
  --sublinear_window 128 --sublinear_stride 128 --sublinear_max_anchors 128 --sublinear_chunk 128 \
21
  --sublinear_sinks 4 --sublinear_recent_anchors 64 --no-sublinear_pooled_landmarks \
22
+ --grad_checkpoint --dblock_checkpoint_stride 1 --optimizer paged_adamw8bit --sat_every 4 --nat_every 4 --nat_max_tokens 768 --nat_mask_ratio 0.5 \
23
  --token_param_ratio 100 --save_dir "$SAVE_DIR" --save_every_sec 86400 --heartbeat_every_sec 300 \
24
+ \
25
  --empty_cache_every_steps 0 --delta_every_steps 25000 --delta_max_keep 1 --max_ckpts 1