Scott/Codex commited on
Commit
c3d5043
·
1 Parent(s): 4396074

Add stochastic sparse DBlock speed profile

Browse files
README.md CHANGED
@@ -75,4 +75,12 @@ allocation for long context, and also gathers ALiBi bias directly for selected
75
  local/anchor keys instead of materializing dense `[heads x T x T]` bias tensors.
76
  A trainer heartbeat, post-checkpoint CUDA cache clear, and optional `--empty_cache_every_steps` hook were added for easier long-running Vast monitoring and VRAM-first allocator behavior.
77
 
 
 
 
 
 
 
 
 
78
  License: Apache-2.0 (matching the upstream method).
 
75
  local/anchor keys instead of materializing dense `[heads x T x T]` bias tensors.
76
  A trainer heartbeat, post-checkpoint CUDA cache clear, and optional `--empty_cache_every_steps` hook were added for easier long-running Vast monitoring and VRAM-first allocator behavior.
77
 
78
+ Speed update 2026-05-29: the live Vast line now uses algorithmic speedups rather
79
+ than only hardware-style knobs: stochastic DBlock objective sampling (one sampled
80
+ AR/SAT/NAT objective per step), sampled token-level CE for the large vocab head,
81
+ and a tighter structured-sublinear attention profile (`window=128`, `stride=128`,
82
+ `max_anchors=128`). The first stable live window reached about 2.49k tok/s with
83
+ an ETA around 326 days, under the 1y+90d target, while keeping ctx=1280, B=2,
84
+ DiffusionBlocks, gradient-checkpointed blocks, tied heads, and structured masks.
85
+
86
  License: Apache-2.0 (matching the upstream method).
dblocks_train.py CHANGED
@@ -94,7 +94,7 @@ def _sample_sigma(ids, lo, hi, args, state):
94
  return torch.from_numpy(sig_np).to(ids.device)
95
 
96
 
97
- def _maybe_log(state, args, bi, layers, ar_val, sat_val, nat_val, total_val, peak_alloc, peak_reserved):
98
  log_every = int(getattr(args, "dblock_log_every", 50))
99
  step = int(state.get("step", 0))
100
  if log_every <= 0 or step % log_every != 0:
@@ -105,7 +105,7 @@ def _maybe_log(state, args, bi, layers, ar_val, sat_val, nat_val, total_val, pea
105
  if peak_alloc is not None:
106
  mem = f" peak_alloc={peak_alloc:.2f}GB peak_reserved={peak_reserved:.2f}GB"
107
  print(
108
- f"[dblock] step={step} block={bi} layers={layers} "
109
  f"loss={total_val:.3f} ar={ar_val:.3f} sat={sat_val:.3f} nat={nat_val:.3f} "
110
  f"counts=[{counts}] ema=[{emas}]{mem}",
111
  flush=True,
@@ -123,6 +123,53 @@ def _update_stats(state, bi, loss_value):
123
  state["step"] = int(state.get("step", 0)) + 1
124
 
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  def _dblock_step(core, ar_h, sat_h, nat_h, opt, scaler, args, ids, state):
127
  import nB300_agillm4 as M
128
 
@@ -133,6 +180,7 @@ def _dblock_step(core, ar_h, sat_h, nat_h, opt, scaler, args, ids, state):
133
  asg = state["assign"]
134
  bs = state["bsig"]
135
  T = ids.size(1)
 
136
  bi = _choose_block(state, args)
137
  lo, hi = sorted([bs[bi], bs[bi + 1]])
138
  layers = asg[bi]
@@ -143,39 +191,57 @@ def _dblock_step(core, ar_h, sat_h, nat_h, opt, scaler, args, ids, state):
143
  ar_weight = float(getattr(args, "dblock_ar_weight", 1.0))
144
  sat_weight = float(getattr(args, "dblock_sat_weight", 1.0))
145
  nat_weight = float(getattr(args, "dblock_nat_weight", 1.0)) * float(getattr(args, "nat_loss_weight", 1.0))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
  ar_val = 0.0
148
  sat_val = 0.0
149
  nat_val = 0.0
150
 
151
- if ar_weight > 0.0:
152
  causal = M.causal_mask(T, structured=M.use_structured_masks(args))
153
  with M.amp(args.amp):
154
  emb = core.emb(ids)
155
  zt = emb + sig[:, None, None] * torch.randn_like(emb)
156
  h = ci * zt
157
  for li in layers:
158
- h = _ck.checkpoint(lambda y, block=core.blocks[li]: block(y, causal), h, use_reentrant=False)
159
  Dn = core.ln(cs * zt + co * h)
160
- ar = ar_weight * w * fused_ce(Dn[:, :-1].contiguous(), ar_h.proj.weight, ids[:, 1:].contiguous())
 
 
 
161
  ar_val = float(ar.detach())
162
  scaler.scale(ar).backward()
163
- del causal, emb, zt, h, Dn, ar
164
 
165
- do_sat = (not getattr(args, "ar_only", False)) and (
166
- int(getattr(args, "sat_every", 1)) <= 1 or ((int(state.get("step", 0)) + 1) % int(getattr(args, "sat_every", 1)) == 0)
167
- )
168
- if sat_weight > 0.0 and do_sat:
169
  smask = M.sat_mask(T, structured=M.use_structured_masks(args))
170
  with M.amp(args.amp):
171
  emb2 = core.emb(ids)
172
  zt2 = emb2 + sig[:, None, None] * torch.randn_like(emb2)
173
  h2 = ci * zt2
174
  for li in layers:
175
- h2 = _ck.checkpoint(lambda y, block=core.blocks[li]: block(y, smask), h2, use_reentrant=False)
176
  Ds = core.ln(cs * zt2 + co * h2)
177
  last = Ds[:, -SATB:]
178
- satf = fused_ce(last.contiguous(), sat_h.proj.weight, ids[:, 1 : SATB + 1].contiguous())
 
 
 
179
  satv = (
180
  M.EMIT_LAMBDA
181
  * F.cross_entropy(
@@ -188,19 +254,9 @@ def _dblock_step(core, ar_h, sat_h, nat_h, opt, scaler, args, ids, state):
188
  sat = sat_weight * w * (satf + satv)
189
  sat_val = float(sat.detach())
190
  scaler.scale(sat).backward()
191
- del smask, emb2, zt2, h2, Ds, last, satf, satv, sat
192
 
193
- do_nat = (
194
- nat_h is not None
195
- and nat_weight > 0.0
196
- and (not getattr(args, "ar_only", False))
197
- and int(getattr(args, "nat_every", 1)) > 0
198
- and (
199
- int(getattr(args, "nat_every", 1)) <= 1
200
- or ((int(state.get("step", 0)) + 1) % int(getattr(args, "nat_every", 1)) == 0)
201
- )
202
- )
203
- if do_nat:
204
  ratio = min(max(float(getattr(args, "nat_mask_ratio", 0.5)), 0.05), 0.95)
205
  nat_ids = M._nat_ids_for_training(ids, int(getattr(args, "nat_max_tokens", 0)))
206
  with M.amp(args.amp):
@@ -211,12 +267,17 @@ def _dblock_step(core, ar_h, sat_h, nat_h, opt, scaler, args, ids, state):
211
  nat_in[m] = M.BLANK
212
  hn = core.emb(nat_in)
213
  for li in layers:
214
- hn = _ck.checkpoint(lambda y, block=core.blocks[li]: block(y, None), hn, use_reentrant=False)
215
  Dnat = core.ln(hn)
216
- nat = nat_weight * fused_ce(Dnat[m], nat_h.proj.weight, nat_ids[m])
 
 
 
 
 
217
  nat_val = float(nat.detach())
218
  scaler.scale(nat).backward()
219
- del nat_ids, nat_in, m, hn, Dnat, nat
220
 
221
  total_val = ar_val + sat_val + nat_val
222
  if not math.isfinite(total_val):
@@ -239,5 +300,5 @@ def _dblock_step(core, ar_h, sat_h, nat_h, opt, scaler, args, ids, state):
239
  peak_alloc = torch.cuda.max_memory_allocated() / (1024**3)
240
  peak_reserved = torch.cuda.max_memory_reserved() / (1024**3)
241
  _update_stats(state, bi, total_val)
242
- _maybe_log(state, args, bi, layers, ar_val, sat_val, nat_val, total_val, peak_alloc, peak_reserved)
243
  return total_val
 
94
  return torch.from_numpy(sig_np).to(ids.device)
95
 
96
 
97
+ def _maybe_log(state, args, bi, layers, ar_val, sat_val, nat_val, total_val, peak_alloc, peak_reserved, objective=None):
98
  log_every = int(getattr(args, "dblock_log_every", 50))
99
  step = int(state.get("step", 0))
100
  if log_every <= 0 or step % log_every != 0:
 
105
  if peak_alloc is not None:
106
  mem = f" peak_alloc={peak_alloc:.2f}GB peak_reserved={peak_reserved:.2f}GB"
107
  print(
108
+ f"[dblock] step={step} block={bi} obj={objective or 'mixed'} layers={layers} "
109
  f"loss={total_val:.3f} ar={ar_val:.3f} sat={sat_val:.3f} nat={nat_val:.3f} "
110
  f"counts=[{counts}] ema=[{emas}]{mem}",
111
  flush=True,
 
123
  state["step"] = int(state.get("step", 0)) + 1
124
 
125
 
126
+ def _run_block(block, x, mask, use_checkpoint):
127
+ if use_checkpoint:
128
+ return _ck.checkpoint(lambda y, block=block: block(y, mask), x, use_reentrant=False)
129
+ return block(x, mask)
130
+
131
+
132
+ def _sample_token_loss_inputs(hidden, targets, max_tokens):
133
+ max_tokens = int(max_tokens or 0)
134
+ if max_tokens <= 0:
135
+ return hidden.contiguous(), targets.contiguous(), int(targets.numel()), int(targets.numel())
136
+ flat_targets = targets.reshape(-1)
137
+ total = int(flat_targets.numel())
138
+ if total <= max_tokens:
139
+ return hidden.contiguous(), targets.contiguous(), total, total
140
+ # With-replacement sampling avoids building a full randperm each step; the sampled
141
+ # mean remains an unbiased estimator of the dense token CE mean.
142
+ idx = torch.randint(total, (max_tokens,), device=targets.device)
143
+ flat_hidden = hidden.reshape(total, hidden.size(-1))
144
+ return flat_hidden.index_select(0, idx).contiguous(), flat_targets.index_select(0, idx).contiguous(), int(max_tokens), total
145
+
146
+
147
+ def _choose_objectives(state, args, ar_weight, sat_weight, nat_weight, do_sat_periodic, do_nat_periodic):
148
+ mode = str(getattr(args, "dblock_objective_mode", "periodic") or "periodic").lower()
149
+ if mode != "stochastic":
150
+ return ar_weight > 0.0, sat_weight > 0.0 and do_sat_periodic, nat_weight > 0.0 and do_nat_periodic, "periodic"
151
+ choices = []
152
+ probs = []
153
+ if ar_weight > 0.0:
154
+ choices.append("ar")
155
+ probs.append(max(0.0, float(getattr(args, "dblock_ar_prob", 0.80))))
156
+ if sat_weight > 0.0 and not getattr(args, "ar_only", False):
157
+ choices.append("sat")
158
+ probs.append(max(0.0, float(getattr(args, "dblock_sat_prob", 0.10))))
159
+ if nat_weight > 0.0 and not getattr(args, "ar_only", False):
160
+ choices.append("nat")
161
+ probs.append(max(0.0, float(getattr(args, "dblock_nat_prob", 0.10))))
162
+ if not choices:
163
+ return False, False, False, "none"
164
+ total = sum(probs)
165
+ if total <= 0.0:
166
+ probs = [1.0 / len(choices) for _ in choices]
167
+ else:
168
+ probs = [p / total for p in probs]
169
+ picked = random.choices(choices, weights=probs, k=1)[0]
170
+ return picked == "ar", picked == "sat", picked == "nat", picked
171
+
172
+
173
  def _dblock_step(core, ar_h, sat_h, nat_h, opt, scaler, args, ids, state):
174
  import nB300_agillm4 as M
175
 
 
180
  asg = state["assign"]
181
  bs = state["bsig"]
182
  T = ids.size(1)
183
+ use_layer_checkpoint = bool(getattr(args, "grad_checkpoint", False))
184
  bi = _choose_block(state, args)
185
  lo, hi = sorted([bs[bi], bs[bi + 1]])
186
  layers = asg[bi]
 
191
  ar_weight = float(getattr(args, "dblock_ar_weight", 1.0))
192
  sat_weight = float(getattr(args, "dblock_sat_weight", 1.0))
193
  nat_weight = float(getattr(args, "dblock_nat_weight", 1.0)) * float(getattr(args, "nat_loss_weight", 1.0))
194
+ do_sat_periodic = (not getattr(args, "ar_only", False)) and (
195
+ int(getattr(args, "sat_every", 1)) <= 1 or ((int(state.get("step", 0)) + 1) % int(getattr(args, "sat_every", 1)) == 0)
196
+ )
197
+ do_nat_periodic = (
198
+ nat_h is not None
199
+ and (not getattr(args, "ar_only", False))
200
+ and int(getattr(args, "nat_every", 1)) > 0
201
+ and (
202
+ int(getattr(args, "nat_every", 1)) <= 1
203
+ or ((int(state.get("step", 0)) + 1) % int(getattr(args, "nat_every", 1)) == 0)
204
+ )
205
+ )
206
+ run_ar, run_sat, run_nat, objective = _choose_objectives(
207
+ state, args, ar_weight, sat_weight, nat_weight, do_sat_periodic, do_nat_periodic
208
+ )
209
 
210
  ar_val = 0.0
211
  sat_val = 0.0
212
  nat_val = 0.0
213
 
214
+ if run_ar:
215
  causal = M.causal_mask(T, structured=M.use_structured_masks(args))
216
  with M.amp(args.amp):
217
  emb = core.emb(ids)
218
  zt = emb + sig[:, None, None] * torch.randn_like(emb)
219
  h = ci * zt
220
  for li in layers:
221
+ h = _run_block(core.blocks[li], h, causal, use_layer_checkpoint)
222
  Dn = core.ln(cs * zt + co * h)
223
+ ar_hidden, ar_targets, ar_used, ar_total = _sample_token_loss_inputs(
224
+ Dn[:, :-1], ids[:, 1:], int(getattr(args, "dblock_ar_loss_tokens", 0))
225
+ )
226
+ ar = ar_weight * w * fused_ce(ar_hidden, ar_h.proj.weight, ar_targets)
227
  ar_val = float(ar.detach())
228
  scaler.scale(ar).backward()
229
+ del causal, emb, zt, h, Dn, ar_hidden, ar_targets, ar, ar_used, ar_total
230
 
231
+ if run_sat:
 
 
 
232
  smask = M.sat_mask(T, structured=M.use_structured_masks(args))
233
  with M.amp(args.amp):
234
  emb2 = core.emb(ids)
235
  zt2 = emb2 + sig[:, None, None] * torch.randn_like(emb2)
236
  h2 = ci * zt2
237
  for li in layers:
238
+ h2 = _run_block(core.blocks[li], h2, smask, use_layer_checkpoint)
239
  Ds = core.ln(cs * zt2 + co * h2)
240
  last = Ds[:, -SATB:]
241
+ sat_hidden, sat_targets, sat_used, sat_total = _sample_token_loss_inputs(
242
+ last, ids[:, 1 : SATB + 1], int(getattr(args, "dblock_sat_loss_tokens", 0))
243
+ )
244
+ satf = fused_ce(sat_hidden, sat_h.proj.weight, sat_targets)
245
  satv = (
246
  M.EMIT_LAMBDA
247
  * F.cross_entropy(
 
254
  sat = sat_weight * w * (satf + satv)
255
  sat_val = float(sat.detach())
256
  scaler.scale(sat).backward()
257
+ del smask, emb2, zt2, h2, Ds, last, sat_hidden, sat_targets, satf, satv, sat
258
 
259
+ if run_nat:
 
 
 
 
 
 
 
 
 
 
260
  ratio = min(max(float(getattr(args, "nat_mask_ratio", 0.5)), 0.05), 0.95)
261
  nat_ids = M._nat_ids_for_training(ids, int(getattr(args, "nat_max_tokens", 0)))
262
  with M.amp(args.amp):
 
267
  nat_in[m] = M.BLANK
268
  hn = core.emb(nat_in)
269
  for li in layers:
270
+ hn = _run_block(core.blocks[li], hn, None, use_layer_checkpoint)
271
  Dnat = core.ln(hn)
272
+ nat_hidden = Dnat[m]
273
+ nat_targets = nat_ids[m]
274
+ nat_hidden, nat_targets, nat_used, nat_total = _sample_token_loss_inputs(
275
+ nat_hidden.unsqueeze(0), nat_targets.unsqueeze(0), int(getattr(args, "dblock_nat_loss_tokens", 0))
276
+ )
277
+ nat = nat_weight * fused_ce(nat_hidden, nat_h.proj.weight, nat_targets)
278
  nat_val = float(nat.detach())
279
  scaler.scale(nat).backward()
280
+ del nat_ids, nat_in, m, hn, Dnat, nat_hidden, nat_targets, nat, nat_used, nat_total
281
 
282
  total_val = ar_val + sat_val + nat_val
283
  if not math.isfinite(total_val):
 
300
  peak_alloc = torch.cuda.max_memory_allocated() / (1024**3)
301
  peak_reserved = torch.cuda.max_memory_reserved() / (1024**3)
302
  _update_stats(state, bi, total_val)
303
+ _maybe_log(state, args, bi, layers, ar_val, sat_val, nat_val, total_val, peak_alloc, peak_reserved, objective=objective)
304
  return total_val
nB300_agillm4_vram_dblock.py CHANGED
@@ -2941,6 +2941,17 @@ def main():
2941
  tr.add_argument("--dblock_ar_weight", type=float, default=1.0)
2942
  tr.add_argument("--dblock_sat_weight", type=float, default=1.0)
2943
  tr.add_argument("--dblock_nat_weight", type=float, default=1.0)
 
 
 
 
 
 
 
 
 
 
 
2944
  tr.add_argument("--reinit_nat", action="store_true",
2945
  help="Reinitialize NAT head weights after load (use once when switching to mask-predict).")
2946
  tr.add_argument("--seed_nat_from_ar", action="store_true",
 
2941
  tr.add_argument("--dblock_ar_weight", type=float, default=1.0)
2942
  tr.add_argument("--dblock_sat_weight", type=float, default=1.0)
2943
  tr.add_argument("--dblock_nat_weight", type=float, default=1.0)
2944
+ tr.add_argument("--dblock_objective_mode", choices=["periodic", "stochastic"], default="periodic",
2945
+ help="DBlock objective scheduler. stochastic samples one objective per step to reduce redundant AR/SAT/NAT forwards.")
2946
+ tr.add_argument("--dblock_ar_prob", type=float, default=0.80, help="Stochastic DBlock probability for AR objective.")
2947
+ tr.add_argument("--dblock_sat_prob", type=float, default=0.10, help="Stochastic DBlock probability for SAT objective.")
2948
+ tr.add_argument("--dblock_nat_prob", type=float, default=0.10, help="Stochastic DBlock probability for NAT objective.")
2949
+ tr.add_argument("--dblock_ar_loss_tokens", type=int, default=0,
2950
+ help="If >0, uniformly sample this many AR target positions per DBlock step for stochastic token-level CE.")
2951
+ tr.add_argument("--dblock_sat_loss_tokens", type=int, default=0,
2952
+ help="If >0, uniformly sample this many SAT target positions per DBlock step.")
2953
+ tr.add_argument("--dblock_nat_loss_tokens", type=int, default=0,
2954
+ help="If >0, uniformly sample this many NAT target positions per DBlock step.")
2955
  tr.add_argument("--reinit_nat", action="store_true",
2956
  help="Reinitialize NAT head weights after load (use once when switching to mask-predict).")
2957
  tr.add_argument("--seed_nat_from_ar", action="store_true",
relaunch_agillm4_dblock.sh CHANGED
@@ -9,16 +9,37 @@ export AGILLM_ATTN_BACKEND="${AGILLM_ATTN_BACKEND:-sublinear}"
9
  if [ -f /root/.cache/huggingface/token ]; then export HF_TOKEN="$(tr -d '\r\n' </root/.cache/huggingface/token)"; export HUGGING_FACE_HUB_TOKEN="$HF_TOKEN"; fi
10
  SAVE_DIR=/workspace/agillm4_4090_ckpts
11
  CKPT="$(ls -1t "$SAVE_DIR"/pretrain_step*.pt 2>/dev/null | head -1)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  [ -n "$CKPT" ] || { echo "no ckpt" >&2; exit 1; }
13
  exec >> /workspace/agillm4_floor_train.log 2>&1
14
- echo "RELAUNCH_AGILLM4_DBLOCK $(date -u +%Y-%m-%dT%H:%M:%SZ) resume=$CKPT --dblock blocks=${AGILLM4_DBLOCKS:-4} tie_weights=1 attn=${AGILLM_ATTN_BACKEND}"
15
  exec python -u nB300_agillm4.py train --preset agillm4_floor --resume "$CKPT" \
16
  --dblock --dblock_blocks "${AGILLM4_DBLOCKS:-4}" --dblock_schedule "${AGILLM4_DBLOCK_SCHEDULE:-loss_balanced}" \
17
  --dblock_warmup_steps "${AGILLM4_DBLOCK_WARMUP:-16}" --dblock_sigma_curriculum_steps "${AGILLM4_DBLOCK_SIGMA_CURRICULUM:-2000}" \
18
- --dblock_log_every "${AGILLM4_DBLOCK_LOG_EVERY:-25}" --tie_weights \
19
- --batch_size 1 --block "${AGILLM4_BLOCK:-1280}" --amp --attn_backend "${AGILLM_ATTN_BACKEND}" --grad_checkpoint \
20
- --optimizer paged_adamw8bit --sat_every 1 --nat_every 1 --nat_max_tokens 768 --nat_mask_ratio 0.5 \
 
 
 
21
  --token_param_ratio 100 --save_dir "$SAVE_DIR" \
22
  --save_every_sec 86400 --heartbeat_every_sec "${AGILLM4_HEARTBEAT_EVERY_SEC:-300}" \
23
- --empty_cache_every_steps "${AGILLM4_EMPTY_CACHE_EVERY_STEPS:-1}" \
24
  --delta_every_steps 25000 --delta_max_keep 1 --max_ckpts 1
 
9
  if [ -f /root/.cache/huggingface/token ]; then export HF_TOKEN="$(tr -d '\r\n' </root/.cache/huggingface/token)"; export HUGGING_FACE_HUB_TOKEN="$HF_TOKEN"; fi
10
  SAVE_DIR=/workspace/agillm4_4090_ckpts
11
  CKPT="$(ls -1t "$SAVE_DIR"/pretrain_step*.pt 2>/dev/null | head -1)"
12
+ BATCH_SIZE="${AGILLM4_BATCH_SIZE:-2}"
13
+ SAT_EVERY="${AGILLM4_SAT_EVERY:-4}"
14
+ NAT_EVERY="${AGILLM4_NAT_EVERY:-4}"
15
+ EMPTY_CACHE_EVERY="${AGILLM4_EMPTY_CACHE_EVERY_STEPS:-0}"
16
+ GRAD_CHECKPOINT="${AGILLM4_GRAD_CHECKPOINT:-1}"
17
+ DBLOCK_OBJECTIVE_MODE="${AGILLM4_DBLOCK_OBJECTIVE_MODE:-stochastic}"
18
+ DBLOCK_AR_PROB="${AGILLM4_DBLOCK_AR_PROB:-0.85}"
19
+ DBLOCK_SAT_PROB="${AGILLM4_DBLOCK_SAT_PROB:-0.075}"
20
+ DBLOCK_NAT_PROB="${AGILLM4_DBLOCK_NAT_PROB:-0.075}"
21
+ DBLOCK_AR_LOSS_TOKENS="${AGILLM4_DBLOCK_AR_LOSS_TOKENS:-512}"
22
+ DBLOCK_SAT_LOSS_TOKENS="${AGILLM4_DBLOCK_SAT_LOSS_TOKENS:-0}"
23
+ DBLOCK_NAT_LOSS_TOKENS="${AGILLM4_DBLOCK_NAT_LOSS_TOKENS:-512}"
24
+ SUBLINEAR_WINDOW="${AGILLM4_SUBLINEAR_WINDOW:-128}"
25
+ SUBLINEAR_STRIDE="${AGILLM4_SUBLINEAR_STRIDE:-128}"
26
+ SUBLINEAR_MAX_ANCHORS="${AGILLM4_SUBLINEAR_MAX_ANCHORS:-128}"
27
+ SUBLINEAR_CHUNK="${AGILLM4_SUBLINEAR_CHUNK:-128}"
28
+ GC_FLAG=()
29
+ if [ "$GRAD_CHECKPOINT" = "1" ] || [ "$GRAD_CHECKPOINT" = "true" ] || [ "$GRAD_CHECKPOINT" = "yes" ]; then GC_FLAG=(--grad_checkpoint); fi
30
  [ -n "$CKPT" ] || { echo "no ckpt" >&2; exit 1; }
31
  exec >> /workspace/agillm4_floor_train.log 2>&1
32
+ echo "RELAUNCH_AGILLM4_DBLOCK_SPEED $(date -u +%Y-%m-%dT%H:%M:%SZ) resume=$CKPT --dblock blocks=${AGILLM4_DBLOCKS:-4} tie_weights=1 attn=${AGILLM_ATTN_BACKEND} batch=$BATCH_SIZE sat_every=$SAT_EVERY nat_every=$NAT_EVERY empty_cache_every=$EMPTY_CACHE_EVERY grad_checkpoint=$GRAD_CHECKPOINT objective=$DBLOCK_OBJECTIVE_MODE ar_prob=$DBLOCK_AR_PROB sat_prob=$DBLOCK_SAT_PROB nat_prob=$DBLOCK_NAT_PROB ar_loss_tokens=$DBLOCK_AR_LOSS_TOKENS nat_loss_tokens=$DBLOCK_NAT_LOSS_TOKENS sublinear_window=$SUBLINEAR_WINDOW sublinear_stride=$SUBLINEAR_STRIDE sublinear_max_anchors=$SUBLINEAR_MAX_ANCHORS"
33
  exec python -u nB300_agillm4.py train --preset agillm4_floor --resume "$CKPT" \
34
  --dblock --dblock_blocks "${AGILLM4_DBLOCKS:-4}" --dblock_schedule "${AGILLM4_DBLOCK_SCHEDULE:-loss_balanced}" \
35
  --dblock_warmup_steps "${AGILLM4_DBLOCK_WARMUP:-16}" --dblock_sigma_curriculum_steps "${AGILLM4_DBLOCK_SIGMA_CURRICULUM:-2000}" \
36
+ --dblock_log_every "${AGILLM4_DBLOCK_LOG_EVERY:-25}" --dblock_objective_mode "$DBLOCK_OBJECTIVE_MODE" \
37
+ --dblock_ar_prob "$DBLOCK_AR_PROB" --dblock_sat_prob "$DBLOCK_SAT_PROB" --dblock_nat_prob "$DBLOCK_NAT_PROB" \
38
+ --dblock_ar_loss_tokens "$DBLOCK_AR_LOSS_TOKENS" --dblock_sat_loss_tokens "$DBLOCK_SAT_LOSS_TOKENS" --dblock_nat_loss_tokens "$DBLOCK_NAT_LOSS_TOKENS" \
39
+ --tie_weights \
40
+ --batch_size "$BATCH_SIZE" --block "${AGILLM4_BLOCK:-1280}" --amp --attn_backend "${AGILLM_ATTN_BACKEND}" --sublinear_window "$SUBLINEAR_WINDOW" --sublinear_stride "$SUBLINEAR_STRIDE" --sublinear_max_anchors "$SUBLINEAR_MAX_ANCHORS" --sublinear_chunk "$SUBLINEAR_CHUNK" "${GC_FLAG[@]}" \
41
+ --optimizer paged_adamw8bit --sat_every "$SAT_EVERY" --nat_every "$NAT_EVERY" --nat_max_tokens 768 --nat_mask_ratio 0.5 \
42
  --token_param_ratio 100 --save_dir "$SAVE_DIR" \
43
  --save_every_sec 86400 --heartbeat_every_sec "${AGILLM4_HEARTBEAT_EVERY_SEC:-300}" \
44
+ --empty_cache_every_steps "$EMPTY_CACHE_EVERY" \
45
  --delta_every_steps 25000 --delta_max_keep 1 --max_ckpts 1
relaunch_agillm4_dblock_tied.sh CHANGED
@@ -8,15 +8,36 @@ export AGILLM_ATTN_BACKEND=sublinear
8
  [ -f /root/.cache/huggingface/token ] && { export HF_TOKEN="$(tr -d '\r\n' </root/.cache/huggingface/token)"; export HUGGING_FACE_HUB_TOKEN="$HF_TOKEN"; }
9
  SAVE_DIR=/workspace/agillm4_4090_ckpts
10
  CKPT="$(ls -1t "$SAVE_DIR"/pretrain_step*.pt 2>/dev/null | head -1)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  exec >> /workspace/agillm4_floor_train.log 2>&1
12
- echo "RELAUNCH_AGILLM4_DBLOCK_TIED $(date -u +%Y-%m-%dT%H:%M:%SZ) resume=$CKPT --dblock --tie_weights --attn_backend sublinear (fused_ce fixed)"
13
  exec python -u nB300_agillm4.py train --preset agillm4_floor --resume "$CKPT" \
14
  --dblock --dblock_blocks "${AGILLM4_DBLOCKS:-4}" --dblock_schedule "${AGILLM4_DBLOCK_SCHEDULE:-loss_balanced}" \
15
  --dblock_warmup_steps "${AGILLM4_DBLOCK_WARMUP:-16}" --dblock_sigma_curriculum_steps "${AGILLM4_DBLOCK_SIGMA_CURRICULUM:-2000}" \
16
- --dblock_log_every "${AGILLM4_DBLOCK_LOG_EVERY:-25}" --tie_weights \
17
- --batch_size 1 --block 1280 --amp --attn_backend sublinear --grad_checkpoint \
18
- --optimizer paged_adamw8bit --sat_every 1 --nat_every 1 --nat_max_tokens 768 --nat_mask_ratio 0.5 \
 
 
 
19
  --token_param_ratio 100 --save_dir "$SAVE_DIR" \
20
  --save_every_sec 86400 --heartbeat_every_sec "${AGILLM4_HEARTBEAT_EVERY_SEC:-300}" \
21
- --empty_cache_every_steps "${AGILLM4_EMPTY_CACHE_EVERY_STEPS:-1}" \
22
  --delta_every_steps 25000 --delta_max_keep 1 --max_ckpts 1
 
8
  [ -f /root/.cache/huggingface/token ] && { export HF_TOKEN="$(tr -d '\r\n' </root/.cache/huggingface/token)"; export HUGGING_FACE_HUB_TOKEN="$HF_TOKEN"; }
9
  SAVE_DIR=/workspace/agillm4_4090_ckpts
10
  CKPT="$(ls -1t "$SAVE_DIR"/pretrain_step*.pt 2>/dev/null | head -1)"
11
+ BATCH_SIZE="${AGILLM4_BATCH_SIZE:-2}"
12
+ SAT_EVERY="${AGILLM4_SAT_EVERY:-4}"
13
+ NAT_EVERY="${AGILLM4_NAT_EVERY:-4}"
14
+ EMPTY_CACHE_EVERY="${AGILLM4_EMPTY_CACHE_EVERY_STEPS:-0}"
15
+ GRAD_CHECKPOINT="${AGILLM4_GRAD_CHECKPOINT:-1}"
16
+ DBLOCK_OBJECTIVE_MODE="${AGILLM4_DBLOCK_OBJECTIVE_MODE:-stochastic}"
17
+ DBLOCK_AR_PROB="${AGILLM4_DBLOCK_AR_PROB:-0.85}"
18
+ DBLOCK_SAT_PROB="${AGILLM4_DBLOCK_SAT_PROB:-0.075}"
19
+ DBLOCK_NAT_PROB="${AGILLM4_DBLOCK_NAT_PROB:-0.075}"
20
+ DBLOCK_AR_LOSS_TOKENS="${AGILLM4_DBLOCK_AR_LOSS_TOKENS:-512}"
21
+ DBLOCK_SAT_LOSS_TOKENS="${AGILLM4_DBLOCK_SAT_LOSS_TOKENS:-0}"
22
+ DBLOCK_NAT_LOSS_TOKENS="${AGILLM4_DBLOCK_NAT_LOSS_TOKENS:-512}"
23
+ SUBLINEAR_WINDOW="${AGILLM4_SUBLINEAR_WINDOW:-128}"
24
+ SUBLINEAR_STRIDE="${AGILLM4_SUBLINEAR_STRIDE:-128}"
25
+ SUBLINEAR_MAX_ANCHORS="${AGILLM4_SUBLINEAR_MAX_ANCHORS:-128}"
26
+ SUBLINEAR_CHUNK="${AGILLM4_SUBLINEAR_CHUNK:-128}"
27
+ GC_FLAG=()
28
+ if [ "$GRAD_CHECKPOINT" = "1" ] || [ "$GRAD_CHECKPOINT" = "true" ] || [ "$GRAD_CHECKPOINT" = "yes" ]; then GC_FLAG=(--grad_checkpoint); fi
29
  exec >> /workspace/agillm4_floor_train.log 2>&1
30
+ echo "RELAUNCH_AGILLM4_DBLOCK_TIED_SPEED $(date -u +%Y-%m-%dT%H:%M:%SZ) resume=$CKPT --dblock --tie_weights --attn_backend sublinear batch=$BATCH_SIZE sat_every=$SAT_EVERY nat_every=$NAT_EVERY empty_cache_every=$EMPTY_CACHE_EVERY grad_checkpoint=$GRAD_CHECKPOINT objective=$DBLOCK_OBJECTIVE_MODE ar_prob=$DBLOCK_AR_PROB sat_prob=$DBLOCK_SAT_PROB nat_prob=$DBLOCK_NAT_PROB ar_loss_tokens=$DBLOCK_AR_LOSS_TOKENS nat_loss_tokens=$DBLOCK_NAT_LOSS_TOKENS sublinear_window=$SUBLINEAR_WINDOW sublinear_stride=$SUBLINEAR_STRIDE sublinear_max_anchors=$SUBLINEAR_MAX_ANCHORS"
31
  exec python -u nB300_agillm4.py train --preset agillm4_floor --resume "$CKPT" \
32
  --dblock --dblock_blocks "${AGILLM4_DBLOCKS:-4}" --dblock_schedule "${AGILLM4_DBLOCK_SCHEDULE:-loss_balanced}" \
33
  --dblock_warmup_steps "${AGILLM4_DBLOCK_WARMUP:-16}" --dblock_sigma_curriculum_steps "${AGILLM4_DBLOCK_SIGMA_CURRICULUM:-2000}" \
34
+ --dblock_log_every "${AGILLM4_DBLOCK_LOG_EVERY:-25}" --dblock_objective_mode "$DBLOCK_OBJECTIVE_MODE" \
35
+ --dblock_ar_prob "$DBLOCK_AR_PROB" --dblock_sat_prob "$DBLOCK_SAT_PROB" --dblock_nat_prob "$DBLOCK_NAT_PROB" \
36
+ --dblock_ar_loss_tokens "$DBLOCK_AR_LOSS_TOKENS" --dblock_sat_loss_tokens "$DBLOCK_SAT_LOSS_TOKENS" --dblock_nat_loss_tokens "$DBLOCK_NAT_LOSS_TOKENS" \
37
+ --tie_weights \
38
+ --batch_size "$BATCH_SIZE" --block 1280 --amp --attn_backend sublinear --sublinear_window "$SUBLINEAR_WINDOW" --sublinear_stride "$SUBLINEAR_STRIDE" --sublinear_max_anchors "$SUBLINEAR_MAX_ANCHORS" --sublinear_chunk "$SUBLINEAR_CHUNK" "${GC_FLAG[@]}" \
39
+ --optimizer paged_adamw8bit --sat_every "$SAT_EVERY" --nat_every "$NAT_EVERY" --nat_max_tokens 768 --nat_mask_ratio 0.5 \
40
  --token_param_ratio 100 --save_dir "$SAVE_DIR" \
41
  --save_every_sec 86400 --heartbeat_every_sec "${AGILLM4_HEARTBEAT_EVERY_SEC:-300}" \
42
+ --empty_cache_every_steps "$EMPTY_CACHE_EVERY" \
43
  --delta_every_steps 25000 --delta_max_keep 1 --max_ckpts 1