Scott/Codex committed on
Commit ·
c3d5043
1
Parent(s): 4396074
Add stochastic sparse DBlock speed profile
Browse files- README.md +8 -0
- dblocks_train.py +89 -28
- nB300_agillm4_vram_dblock.py +11 -0
- relaunch_agillm4_dblock.sh +26 -5
- relaunch_agillm4_dblock_tied.sh +26 -5
README.md
CHANGED
|
@@ -75,4 +75,12 @@ allocation for long context, and also gathers ALiBi bias directly for selected
|
|
| 75 |
local/anchor keys instead of materializing dense `[heads x T x T]` bias tensors.
|
| 76 |
A trainer heartbeat, post-checkpoint CUDA cache clear, and optional `--empty_cache_every_steps` hook were added for easier long-running Vast monitoring and VRAM-first allocator behavior.
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
License: Apache-2.0 (matching the upstream method).
|
|
|
|
| 75 |
local/anchor keys instead of materializing dense `[heads x T x T]` bias tensors.
|
| 76 |
A trainer heartbeat, post-checkpoint CUDA cache clear, and optional `--empty_cache_every_steps` hook were added for easier long-running Vast monitoring and VRAM-first allocator behavior.
|
| 77 |
|
| 78 |
+
Speed update 2026-05-29: the live Vast line now uses algorithmic speedups rather
|
| 79 |
+
than only hardware-style knobs: stochastic DBlock objective sampling (one sampled
|
| 80 |
+
AR/SAT/NAT objective per step), sampled token-level CE for the large vocab head,
|
| 81 |
+
and a tighter structured-sublinear attention profile (`window=128`, `stride=128`,
|
| 82 |
+
`max_anchors=128`). The first stable live window reached about 2.49k tok/s with
|
| 83 |
+
an ETA around 326 days, under the 1y+90d target, while keeping ctx=1280, B=2,
|
| 84 |
+
DiffusionBlocks, gradient-checkpointed blocks, tied heads, and structured masks.
|
| 85 |
+
|
| 86 |
License: Apache-2.0 (matching the upstream method).
|
dblocks_train.py
CHANGED
|
@@ -94,7 +94,7 @@ def _sample_sigma(ids, lo, hi, args, state):
|
|
| 94 |
return torch.from_numpy(sig_np).to(ids.device)
|
| 95 |
|
| 96 |
|
| 97 |
-
def _maybe_log(state, args, bi, layers, ar_val, sat_val, nat_val, total_val, peak_alloc, peak_reserved):
|
| 98 |
log_every = int(getattr(args, "dblock_log_every", 50))
|
| 99 |
step = int(state.get("step", 0))
|
| 100 |
if log_every <= 0 or step % log_every != 0:
|
|
@@ -105,7 +105,7 @@ def _maybe_log(state, args, bi, layers, ar_val, sat_val, nat_val, total_val, pea
|
|
| 105 |
if peak_alloc is not None:
|
| 106 |
mem = f" peak_alloc={peak_alloc:.2f}GB peak_reserved={peak_reserved:.2f}GB"
|
| 107 |
print(
|
| 108 |
-
f"[dblock] step={step} block={bi} layers={layers} "
|
| 109 |
f"loss={total_val:.3f} ar={ar_val:.3f} sat={sat_val:.3f} nat={nat_val:.3f} "
|
| 110 |
f"counts=[{counts}] ema=[{emas}]{mem}",
|
| 111 |
flush=True,
|
|
@@ -123,6 +123,53 @@ def _update_stats(state, bi, loss_value):
|
|
| 123 |
state["step"] = int(state.get("step", 0)) + 1
|
| 124 |
|
| 125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
def _dblock_step(core, ar_h, sat_h, nat_h, opt, scaler, args, ids, state):
|
| 127 |
import nB300_agillm4 as M
|
| 128 |
|
|
@@ -133,6 +180,7 @@ def _dblock_step(core, ar_h, sat_h, nat_h, opt, scaler, args, ids, state):
|
|
| 133 |
asg = state["assign"]
|
| 134 |
bs = state["bsig"]
|
| 135 |
T = ids.size(1)
|
|
|
|
| 136 |
bi = _choose_block(state, args)
|
| 137 |
lo, hi = sorted([bs[bi], bs[bi + 1]])
|
| 138 |
layers = asg[bi]
|
|
@@ -143,39 +191,57 @@ def _dblock_step(core, ar_h, sat_h, nat_h, opt, scaler, args, ids, state):
|
|
| 143 |
ar_weight = float(getattr(args, "dblock_ar_weight", 1.0))
|
| 144 |
sat_weight = float(getattr(args, "dblock_sat_weight", 1.0))
|
| 145 |
nat_weight = float(getattr(args, "dblock_nat_weight", 1.0)) * float(getattr(args, "nat_loss_weight", 1.0))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
|
| 147 |
ar_val = 0.0
|
| 148 |
sat_val = 0.0
|
| 149 |
nat_val = 0.0
|
| 150 |
|
| 151 |
-
if
|
| 152 |
causal = M.causal_mask(T, structured=M.use_structured_masks(args))
|
| 153 |
with M.amp(args.amp):
|
| 154 |
emb = core.emb(ids)
|
| 155 |
zt = emb + sig[:, None, None] * torch.randn_like(emb)
|
| 156 |
h = ci * zt
|
| 157 |
for li in layers:
|
| 158 |
-
h =
|
| 159 |
Dn = core.ln(cs * zt + co * h)
|
| 160 |
-
|
|
|
|
|
|
|
|
|
|
| 161 |
ar_val = float(ar.detach())
|
| 162 |
scaler.scale(ar).backward()
|
| 163 |
-
del causal, emb, zt, h, Dn, ar
|
| 164 |
|
| 165 |
-
|
| 166 |
-
int(getattr(args, "sat_every", 1)) <= 1 or ((int(state.get("step", 0)) + 1) % int(getattr(args, "sat_every", 1)) == 0)
|
| 167 |
-
)
|
| 168 |
-
if sat_weight > 0.0 and do_sat:
|
| 169 |
smask = M.sat_mask(T, structured=M.use_structured_masks(args))
|
| 170 |
with M.amp(args.amp):
|
| 171 |
emb2 = core.emb(ids)
|
| 172 |
zt2 = emb2 + sig[:, None, None] * torch.randn_like(emb2)
|
| 173 |
h2 = ci * zt2
|
| 174 |
for li in layers:
|
| 175 |
-
h2 =
|
| 176 |
Ds = core.ln(cs * zt2 + co * h2)
|
| 177 |
last = Ds[:, -SATB:]
|
| 178 |
-
|
|
|
|
|
|
|
|
|
|
| 179 |
satv = (
|
| 180 |
M.EMIT_LAMBDA
|
| 181 |
* F.cross_entropy(
|
|
@@ -188,19 +254,9 @@ def _dblock_step(core, ar_h, sat_h, nat_h, opt, scaler, args, ids, state):
|
|
| 188 |
sat = sat_weight * w * (satf + satv)
|
| 189 |
sat_val = float(sat.detach())
|
| 190 |
scaler.scale(sat).backward()
|
| 191 |
-
del smask, emb2, zt2, h2, Ds, last, satf, satv, sat
|
| 192 |
|
| 193 |
-
|
| 194 |
-
nat_h is not None
|
| 195 |
-
and nat_weight > 0.0
|
| 196 |
-
and (not getattr(args, "ar_only", False))
|
| 197 |
-
and int(getattr(args, "nat_every", 1)) > 0
|
| 198 |
-
and (
|
| 199 |
-
int(getattr(args, "nat_every", 1)) <= 1
|
| 200 |
-
or ((int(state.get("step", 0)) + 1) % int(getattr(args, "nat_every", 1)) == 0)
|
| 201 |
-
)
|
| 202 |
-
)
|
| 203 |
-
if do_nat:
|
| 204 |
ratio = min(max(float(getattr(args, "nat_mask_ratio", 0.5)), 0.05), 0.95)
|
| 205 |
nat_ids = M._nat_ids_for_training(ids, int(getattr(args, "nat_max_tokens", 0)))
|
| 206 |
with M.amp(args.amp):
|
|
@@ -211,12 +267,17 @@ def _dblock_step(core, ar_h, sat_h, nat_h, opt, scaler, args, ids, state):
|
|
| 211 |
nat_in[m] = M.BLANK
|
| 212 |
hn = core.emb(nat_in)
|
| 213 |
for li in layers:
|
| 214 |
-
hn =
|
| 215 |
Dnat = core.ln(hn)
|
| 216 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
nat_val = float(nat.detach())
|
| 218 |
scaler.scale(nat).backward()
|
| 219 |
-
del nat_ids, nat_in, m, hn, Dnat, nat
|
| 220 |
|
| 221 |
total_val = ar_val + sat_val + nat_val
|
| 222 |
if not math.isfinite(total_val):
|
|
@@ -239,5 +300,5 @@ def _dblock_step(core, ar_h, sat_h, nat_h, opt, scaler, args, ids, state):
|
|
| 239 |
peak_alloc = torch.cuda.max_memory_allocated() / (1024**3)
|
| 240 |
peak_reserved = torch.cuda.max_memory_reserved() / (1024**3)
|
| 241 |
_update_stats(state, bi, total_val)
|
| 242 |
-
_maybe_log(state, args, bi, layers, ar_val, sat_val, nat_val, total_val, peak_alloc, peak_reserved)
|
| 243 |
return total_val
|
|
|
|
| 94 |
return torch.from_numpy(sig_np).to(ids.device)
|
| 95 |
|
| 96 |
|
| 97 |
+
def _maybe_log(state, args, bi, layers, ar_val, sat_val, nat_val, total_val, peak_alloc, peak_reserved, objective=None):
|
| 98 |
log_every = int(getattr(args, "dblock_log_every", 50))
|
| 99 |
step = int(state.get("step", 0))
|
| 100 |
if log_every <= 0 or step % log_every != 0:
|
|
|
|
| 105 |
if peak_alloc is not None:
|
| 106 |
mem = f" peak_alloc={peak_alloc:.2f}GB peak_reserved={peak_reserved:.2f}GB"
|
| 107 |
print(
|
| 108 |
+
f"[dblock] step={step} block={bi} obj={objective or 'mixed'} layers={layers} "
|
| 109 |
f"loss={total_val:.3f} ar={ar_val:.3f} sat={sat_val:.3f} nat={nat_val:.3f} "
|
| 110 |
f"counts=[{counts}] ema=[{emas}]{mem}",
|
| 111 |
flush=True,
|
|
|
|
| 123 |
state["step"] = int(state.get("step", 0)) + 1
|
| 124 |
|
| 125 |
|
| 126 |
+
def _run_block(block, x, mask, use_checkpoint):
|
| 127 |
+
if use_checkpoint:
|
| 128 |
+
return _ck.checkpoint(lambda y, block=block: block(y, mask), x, use_reentrant=False)
|
| 129 |
+
return block(x, mask)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def _sample_token_loss_inputs(hidden, targets, max_tokens):
|
| 133 |
+
max_tokens = int(max_tokens or 0)
|
| 134 |
+
if max_tokens <= 0:
|
| 135 |
+
return hidden.contiguous(), targets.contiguous(), int(targets.numel()), int(targets.numel())
|
| 136 |
+
flat_targets = targets.reshape(-1)
|
| 137 |
+
total = int(flat_targets.numel())
|
| 138 |
+
if total <= max_tokens:
|
| 139 |
+
return hidden.contiguous(), targets.contiguous(), total, total
|
| 140 |
+
# With-replacement sampling avoids building a full randperm each step; the sampled
|
| 141 |
+
# mean remains an unbiased estimator of the dense token CE mean.
|
| 142 |
+
idx = torch.randint(total, (max_tokens,), device=targets.device)
|
| 143 |
+
flat_hidden = hidden.reshape(total, hidden.size(-1))
|
| 144 |
+
return flat_hidden.index_select(0, idx).contiguous(), flat_targets.index_select(0, idx).contiguous(), int(max_tokens), total
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def _choose_objectives(state, args, ar_weight, sat_weight, nat_weight, do_sat_periodic, do_nat_periodic):
|
| 148 |
+
mode = str(getattr(args, "dblock_objective_mode", "periodic") or "periodic").lower()
|
| 149 |
+
if mode != "stochastic":
|
| 150 |
+
return ar_weight > 0.0, sat_weight > 0.0 and do_sat_periodic, nat_weight > 0.0 and do_nat_periodic, "periodic"
|
| 151 |
+
choices = []
|
| 152 |
+
probs = []
|
| 153 |
+
if ar_weight > 0.0:
|
| 154 |
+
choices.append("ar")
|
| 155 |
+
probs.append(max(0.0, float(getattr(args, "dblock_ar_prob", 0.80))))
|
| 156 |
+
if sat_weight > 0.0 and not getattr(args, "ar_only", False):
|
| 157 |
+
choices.append("sat")
|
| 158 |
+
probs.append(max(0.0, float(getattr(args, "dblock_sat_prob", 0.10))))
|
| 159 |
+
if nat_weight > 0.0 and not getattr(args, "ar_only", False):
|
| 160 |
+
choices.append("nat")
|
| 161 |
+
probs.append(max(0.0, float(getattr(args, "dblock_nat_prob", 0.10))))
|
| 162 |
+
if not choices:
|
| 163 |
+
return False, False, False, "none"
|
| 164 |
+
total = sum(probs)
|
| 165 |
+
if total <= 0.0:
|
| 166 |
+
probs = [1.0 / len(choices) for _ in choices]
|
| 167 |
+
else:
|
| 168 |
+
probs = [p / total for p in probs]
|
| 169 |
+
picked = random.choices(choices, weights=probs, k=1)[0]
|
| 170 |
+
return picked == "ar", picked == "sat", picked == "nat", picked
|
| 171 |
+
|
| 172 |
+
|
| 173 |
def _dblock_step(core, ar_h, sat_h, nat_h, opt, scaler, args, ids, state):
|
| 174 |
import nB300_agillm4 as M
|
| 175 |
|
|
|
|
| 180 |
asg = state["assign"]
|
| 181 |
bs = state["bsig"]
|
| 182 |
T = ids.size(1)
|
| 183 |
+
use_layer_checkpoint = bool(getattr(args, "grad_checkpoint", False))
|
| 184 |
bi = _choose_block(state, args)
|
| 185 |
lo, hi = sorted([bs[bi], bs[bi + 1]])
|
| 186 |
layers = asg[bi]
|
|
|
|
| 191 |
ar_weight = float(getattr(args, "dblock_ar_weight", 1.0))
|
| 192 |
sat_weight = float(getattr(args, "dblock_sat_weight", 1.0))
|
| 193 |
nat_weight = float(getattr(args, "dblock_nat_weight", 1.0)) * float(getattr(args, "nat_loss_weight", 1.0))
|
| 194 |
+
do_sat_periodic = (not getattr(args, "ar_only", False)) and (
|
| 195 |
+
int(getattr(args, "sat_every", 1)) <= 1 or ((int(state.get("step", 0)) + 1) % int(getattr(args, "sat_every", 1)) == 0)
|
| 196 |
+
)
|
| 197 |
+
do_nat_periodic = (
|
| 198 |
+
nat_h is not None
|
| 199 |
+
and (not getattr(args, "ar_only", False))
|
| 200 |
+
and int(getattr(args, "nat_every", 1)) > 0
|
| 201 |
+
and (
|
| 202 |
+
int(getattr(args, "nat_every", 1)) <= 1
|
| 203 |
+
or ((int(state.get("step", 0)) + 1) % int(getattr(args, "nat_every", 1)) == 0)
|
| 204 |
+
)
|
| 205 |
+
)
|
| 206 |
+
run_ar, run_sat, run_nat, objective = _choose_objectives(
|
| 207 |
+
state, args, ar_weight, sat_weight, nat_weight, do_sat_periodic, do_nat_periodic
|
| 208 |
+
)
|
| 209 |
|
| 210 |
ar_val = 0.0
|
| 211 |
sat_val = 0.0
|
| 212 |
nat_val = 0.0
|
| 213 |
|
| 214 |
+
if run_ar:
|
| 215 |
causal = M.causal_mask(T, structured=M.use_structured_masks(args))
|
| 216 |
with M.amp(args.amp):
|
| 217 |
emb = core.emb(ids)
|
| 218 |
zt = emb + sig[:, None, None] * torch.randn_like(emb)
|
| 219 |
h = ci * zt
|
| 220 |
for li in layers:
|
| 221 |
+
h = _run_block(core.blocks[li], h, causal, use_layer_checkpoint)
|
| 222 |
Dn = core.ln(cs * zt + co * h)
|
| 223 |
+
ar_hidden, ar_targets, ar_used, ar_total = _sample_token_loss_inputs(
|
| 224 |
+
Dn[:, :-1], ids[:, 1:], int(getattr(args, "dblock_ar_loss_tokens", 0))
|
| 225 |
+
)
|
| 226 |
+
ar = ar_weight * w * fused_ce(ar_hidden, ar_h.proj.weight, ar_targets)
|
| 227 |
ar_val = float(ar.detach())
|
| 228 |
scaler.scale(ar).backward()
|
| 229 |
+
del causal, emb, zt, h, Dn, ar_hidden, ar_targets, ar, ar_used, ar_total
|
| 230 |
|
| 231 |
+
if run_sat:
|
|
|
|
|
|
|
|
|
|
| 232 |
smask = M.sat_mask(T, structured=M.use_structured_masks(args))
|
| 233 |
with M.amp(args.amp):
|
| 234 |
emb2 = core.emb(ids)
|
| 235 |
zt2 = emb2 + sig[:, None, None] * torch.randn_like(emb2)
|
| 236 |
h2 = ci * zt2
|
| 237 |
for li in layers:
|
| 238 |
+
h2 = _run_block(core.blocks[li], h2, smask, use_layer_checkpoint)
|
| 239 |
Ds = core.ln(cs * zt2 + co * h2)
|
| 240 |
last = Ds[:, -SATB:]
|
| 241 |
+
sat_hidden, sat_targets, sat_used, sat_total = _sample_token_loss_inputs(
|
| 242 |
+
last, ids[:, 1 : SATB + 1], int(getattr(args, "dblock_sat_loss_tokens", 0))
|
| 243 |
+
)
|
| 244 |
+
satf = fused_ce(sat_hidden, sat_h.proj.weight, sat_targets)
|
| 245 |
satv = (
|
| 246 |
M.EMIT_LAMBDA
|
| 247 |
* F.cross_entropy(
|
|
|
|
| 254 |
sat = sat_weight * w * (satf + satv)
|
| 255 |
sat_val = float(sat.detach())
|
| 256 |
scaler.scale(sat).backward()
|
| 257 |
+
del smask, emb2, zt2, h2, Ds, last, sat_hidden, sat_targets, satf, satv, sat
|
| 258 |
|
| 259 |
+
if run_nat:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
ratio = min(max(float(getattr(args, "nat_mask_ratio", 0.5)), 0.05), 0.95)
|
| 261 |
nat_ids = M._nat_ids_for_training(ids, int(getattr(args, "nat_max_tokens", 0)))
|
| 262 |
with M.amp(args.amp):
|
|
|
|
| 267 |
nat_in[m] = M.BLANK
|
| 268 |
hn = core.emb(nat_in)
|
| 269 |
for li in layers:
|
| 270 |
+
hn = _run_block(core.blocks[li], hn, None, use_layer_checkpoint)
|
| 271 |
Dnat = core.ln(hn)
|
| 272 |
+
nat_hidden = Dnat[m]
|
| 273 |
+
nat_targets = nat_ids[m]
|
| 274 |
+
nat_hidden, nat_targets, nat_used, nat_total = _sample_token_loss_inputs(
|
| 275 |
+
nat_hidden.unsqueeze(0), nat_targets.unsqueeze(0), int(getattr(args, "dblock_nat_loss_tokens", 0))
|
| 276 |
+
)
|
| 277 |
+
nat = nat_weight * fused_ce(nat_hidden, nat_h.proj.weight, nat_targets)
|
| 278 |
nat_val = float(nat.detach())
|
| 279 |
scaler.scale(nat).backward()
|
| 280 |
+
del nat_ids, nat_in, m, hn, Dnat, nat_hidden, nat_targets, nat, nat_used, nat_total
|
| 281 |
|
| 282 |
total_val = ar_val + sat_val + nat_val
|
| 283 |
if not math.isfinite(total_val):
|
|
|
|
| 300 |
peak_alloc = torch.cuda.max_memory_allocated() / (1024**3)
|
| 301 |
peak_reserved = torch.cuda.max_memory_reserved() / (1024**3)
|
| 302 |
_update_stats(state, bi, total_val)
|
| 303 |
+
_maybe_log(state, args, bi, layers, ar_val, sat_val, nat_val, total_val, peak_alloc, peak_reserved, objective=objective)
|
| 304 |
return total_val
|
nB300_agillm4_vram_dblock.py
CHANGED
|
@@ -2941,6 +2941,17 @@ def main():
|
|
| 2941 |
tr.add_argument("--dblock_ar_weight", type=float, default=1.0)
|
| 2942 |
tr.add_argument("--dblock_sat_weight", type=float, default=1.0)
|
| 2943 |
tr.add_argument("--dblock_nat_weight", type=float, default=1.0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2944 |
tr.add_argument("--reinit_nat", action="store_true",
|
| 2945 |
help="Reinitialize NAT head weights after load (use once when switching to mask-predict).")
|
| 2946 |
tr.add_argument("--seed_nat_from_ar", action="store_true",
|
|
|
|
| 2941 |
tr.add_argument("--dblock_ar_weight", type=float, default=1.0)
|
| 2942 |
tr.add_argument("--dblock_sat_weight", type=float, default=1.0)
|
| 2943 |
tr.add_argument("--dblock_nat_weight", type=float, default=1.0)
|
| 2944 |
+
tr.add_argument("--dblock_objective_mode", choices=["periodic", "stochastic"], default="periodic",
|
| 2945 |
+
help="DBlock objective scheduler. stochastic samples one objective per step to reduce redundant AR/SAT/NAT forwards.")
|
| 2946 |
+
tr.add_argument("--dblock_ar_prob", type=float, default=0.80, help="Stochastic DBlock probability for AR objective.")
|
| 2947 |
+
tr.add_argument("--dblock_sat_prob", type=float, default=0.10, help="Stochastic DBlock probability for SAT objective.")
|
| 2948 |
+
tr.add_argument("--dblock_nat_prob", type=float, default=0.10, help="Stochastic DBlock probability for NAT objective.")
|
| 2949 |
+
tr.add_argument("--dblock_ar_loss_tokens", type=int, default=0,
|
| 2950 |
+
help="If >0, uniformly sample this many AR target positions per DBlock step for stochastic token-level CE.")
|
| 2951 |
+
tr.add_argument("--dblock_sat_loss_tokens", type=int, default=0,
|
| 2952 |
+
help="If >0, uniformly sample this many SAT target positions per DBlock step.")
|
| 2953 |
+
tr.add_argument("--dblock_nat_loss_tokens", type=int, default=0,
|
| 2954 |
+
help="If >0, uniformly sample this many NAT target positions per DBlock step.")
|
| 2955 |
tr.add_argument("--reinit_nat", action="store_true",
|
| 2956 |
help="Reinitialize NAT head weights after load (use once when switching to mask-predict).")
|
| 2957 |
tr.add_argument("--seed_nat_from_ar", action="store_true",
|
relaunch_agillm4_dblock.sh
CHANGED
|
@@ -9,16 +9,37 @@ export AGILLM_ATTN_BACKEND="${AGILLM_ATTN_BACKEND:-sublinear}"
|
|
| 9 |
if [ -f /root/.cache/huggingface/token ]; then export HF_TOKEN="$(tr -d '\r\n' </root/.cache/huggingface/token)"; export HUGGING_FACE_HUB_TOKEN="$HF_TOKEN"; fi
|
| 10 |
SAVE_DIR=/workspace/agillm4_4090_ckpts
|
| 11 |
CKPT="$(ls -1t "$SAVE_DIR"/pretrain_step*.pt 2>/dev/null | head -1)"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
[ -n "$CKPT" ] || { echo "no ckpt" >&2; exit 1; }
|
| 13 |
exec >> /workspace/agillm4_floor_train.log 2>&1
|
| 14 |
-
echo "
|
| 15 |
exec python -u nB300_agillm4.py train --preset agillm4_floor --resume "$CKPT" \
|
| 16 |
--dblock --dblock_blocks "${AGILLM4_DBLOCKS:-4}" --dblock_schedule "${AGILLM4_DBLOCK_SCHEDULE:-loss_balanced}" \
|
| 17 |
--dblock_warmup_steps "${AGILLM4_DBLOCK_WARMUP:-16}" --dblock_sigma_curriculum_steps "${AGILLM4_DBLOCK_SIGMA_CURRICULUM:-2000}" \
|
| 18 |
-
--dblock_log_every "${AGILLM4_DBLOCK_LOG_EVERY:-25}" --
|
| 19 |
-
--
|
| 20 |
-
--
|
|
|
|
|
|
|
|
|
|
| 21 |
--token_param_ratio 100 --save_dir "$SAVE_DIR" \
|
| 22 |
--save_every_sec 86400 --heartbeat_every_sec "${AGILLM4_HEARTBEAT_EVERY_SEC:-300}" \
|
| 23 |
-
--empty_cache_every_steps "$
|
| 24 |
--delta_every_steps 25000 --delta_max_keep 1 --max_ckpts 1
|
|
|
|
| 9 |
if [ -f /root/.cache/huggingface/token ]; then export HF_TOKEN="$(tr -d '\r\n' </root/.cache/huggingface/token)"; export HUGGING_FACE_HUB_TOKEN="$HF_TOKEN"; fi
|
| 10 |
SAVE_DIR=/workspace/agillm4_4090_ckpts
|
| 11 |
CKPT="$(ls -1t "$SAVE_DIR"/pretrain_step*.pt 2>/dev/null | head -1)"
|
| 12 |
+
BATCH_SIZE="${AGILLM4_BATCH_SIZE:-2}"
|
| 13 |
+
SAT_EVERY="${AGILLM4_SAT_EVERY:-4}"
|
| 14 |
+
NAT_EVERY="${AGILLM4_NAT_EVERY:-4}"
|
| 15 |
+
EMPTY_CACHE_EVERY="${AGILLM4_EMPTY_CACHE_EVERY_STEPS:-0}"
|
| 16 |
+
GRAD_CHECKPOINT="${AGILLM4_GRAD_CHECKPOINT:-1}"
|
| 17 |
+
DBLOCK_OBJECTIVE_MODE="${AGILLM4_DBLOCK_OBJECTIVE_MODE:-stochastic}"
|
| 18 |
+
DBLOCK_AR_PROB="${AGILLM4_DBLOCK_AR_PROB:-0.85}"
|
| 19 |
+
DBLOCK_SAT_PROB="${AGILLM4_DBLOCK_SAT_PROB:-0.075}"
|
| 20 |
+
DBLOCK_NAT_PROB="${AGILLM4_DBLOCK_NAT_PROB:-0.075}"
|
| 21 |
+
DBLOCK_AR_LOSS_TOKENS="${AGILLM4_DBLOCK_AR_LOSS_TOKENS:-512}"
|
| 22 |
+
DBLOCK_SAT_LOSS_TOKENS="${AGILLM4_DBLOCK_SAT_LOSS_TOKENS:-0}"
|
| 23 |
+
DBLOCK_NAT_LOSS_TOKENS="${AGILLM4_DBLOCK_NAT_LOSS_TOKENS:-512}"
|
| 24 |
+
SUBLINEAR_WINDOW="${AGILLM4_SUBLINEAR_WINDOW:-128}"
|
| 25 |
+
SUBLINEAR_STRIDE="${AGILLM4_SUBLINEAR_STRIDE:-128}"
|
| 26 |
+
SUBLINEAR_MAX_ANCHORS="${AGILLM4_SUBLINEAR_MAX_ANCHORS:-128}"
|
| 27 |
+
SUBLINEAR_CHUNK="${AGILLM4_SUBLINEAR_CHUNK:-128}"
|
| 28 |
+
GC_FLAG=()
|
| 29 |
+
if [ "$GRAD_CHECKPOINT" = "1" ] || [ "$GRAD_CHECKPOINT" = "true" ] || [ "$GRAD_CHECKPOINT" = "yes" ]; then GC_FLAG=(--grad_checkpoint); fi
|
| 30 |
[ -n "$CKPT" ] || { echo "no ckpt" >&2; exit 1; }
|
| 31 |
exec >> /workspace/agillm4_floor_train.log 2>&1
|
| 32 |
+
echo "RELAUNCH_AGILLM4_DBLOCK_SPEED $(date -u +%Y-%m-%dT%H:%M:%SZ) resume=$CKPT --dblock blocks=${AGILLM4_DBLOCKS:-4} tie_weights=1 attn=${AGILLM_ATTN_BACKEND} batch=$BATCH_SIZE sat_every=$SAT_EVERY nat_every=$NAT_EVERY empty_cache_every=$EMPTY_CACHE_EVERY grad_checkpoint=$GRAD_CHECKPOINT objective=$DBLOCK_OBJECTIVE_MODE ar_prob=$DBLOCK_AR_PROB sat_prob=$DBLOCK_SAT_PROB nat_prob=$DBLOCK_NAT_PROB ar_loss_tokens=$DBLOCK_AR_LOSS_TOKENS nat_loss_tokens=$DBLOCK_NAT_LOSS_TOKENS sublinear_window=$SUBLINEAR_WINDOW sublinear_stride=$SUBLINEAR_STRIDE sublinear_max_anchors=$SUBLINEAR_MAX_ANCHORS"
|
| 33 |
exec python -u nB300_agillm4.py train --preset agillm4_floor --resume "$CKPT" \
|
| 34 |
--dblock --dblock_blocks "${AGILLM4_DBLOCKS:-4}" --dblock_schedule "${AGILLM4_DBLOCK_SCHEDULE:-loss_balanced}" \
|
| 35 |
--dblock_warmup_steps "${AGILLM4_DBLOCK_WARMUP:-16}" --dblock_sigma_curriculum_steps "${AGILLM4_DBLOCK_SIGMA_CURRICULUM:-2000}" \
|
| 36 |
+
--dblock_log_every "${AGILLM4_DBLOCK_LOG_EVERY:-25}" --dblock_objective_mode "$DBLOCK_OBJECTIVE_MODE" \
|
| 37 |
+
--dblock_ar_prob "$DBLOCK_AR_PROB" --dblock_sat_prob "$DBLOCK_SAT_PROB" --dblock_nat_prob "$DBLOCK_NAT_PROB" \
|
| 38 |
+
--dblock_ar_loss_tokens "$DBLOCK_AR_LOSS_TOKENS" --dblock_sat_loss_tokens "$DBLOCK_SAT_LOSS_TOKENS" --dblock_nat_loss_tokens "$DBLOCK_NAT_LOSS_TOKENS" \
|
| 39 |
+
--tie_weights \
|
| 40 |
+
--batch_size "$BATCH_SIZE" --block "${AGILLM4_BLOCK:-1280}" --amp --attn_backend "${AGILLM_ATTN_BACKEND}" --sublinear_window "$SUBLINEAR_WINDOW" --sublinear_stride "$SUBLINEAR_STRIDE" --sublinear_max_anchors "$SUBLINEAR_MAX_ANCHORS" --sublinear_chunk "$SUBLINEAR_CHUNK" "${GC_FLAG[@]}" \
|
| 41 |
+
--optimizer paged_adamw8bit --sat_every "$SAT_EVERY" --nat_every "$NAT_EVERY" --nat_max_tokens 768 --nat_mask_ratio 0.5 \
|
| 42 |
--token_param_ratio 100 --save_dir "$SAVE_DIR" \
|
| 43 |
--save_every_sec 86400 --heartbeat_every_sec "${AGILLM4_HEARTBEAT_EVERY_SEC:-300}" \
|
| 44 |
+
--empty_cache_every_steps "$EMPTY_CACHE_EVERY" \
|
| 45 |
--delta_every_steps 25000 --delta_max_keep 1 --max_ckpts 1
|
relaunch_agillm4_dblock_tied.sh
CHANGED
|
@@ -8,15 +8,36 @@ export AGILLM_ATTN_BACKEND=sublinear
|
|
| 8 |
[ -f /root/.cache/huggingface/token ] && { export HF_TOKEN="$(tr -d '\r\n' </root/.cache/huggingface/token)"; export HUGGING_FACE_HUB_TOKEN="$HF_TOKEN"; }
|
| 9 |
SAVE_DIR=/workspace/agillm4_4090_ckpts
|
| 10 |
CKPT="$(ls -1t "$SAVE_DIR"/pretrain_step*.pt 2>/dev/null | head -1)"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
exec >> /workspace/agillm4_floor_train.log 2>&1
|
| 12 |
-
echo "
|
| 13 |
exec python -u nB300_agillm4.py train --preset agillm4_floor --resume "$CKPT" \
|
| 14 |
--dblock --dblock_blocks "${AGILLM4_DBLOCKS:-4}" --dblock_schedule "${AGILLM4_DBLOCK_SCHEDULE:-loss_balanced}" \
|
| 15 |
--dblock_warmup_steps "${AGILLM4_DBLOCK_WARMUP:-16}" --dblock_sigma_curriculum_steps "${AGILLM4_DBLOCK_SIGMA_CURRICULUM:-2000}" \
|
| 16 |
-
--dblock_log_every "${AGILLM4_DBLOCK_LOG_EVERY:-25}" --
|
| 17 |
-
--
|
| 18 |
-
--
|
|
|
|
|
|
|
|
|
|
| 19 |
--token_param_ratio 100 --save_dir "$SAVE_DIR" \
|
| 20 |
--save_every_sec 86400 --heartbeat_every_sec "${AGILLM4_HEARTBEAT_EVERY_SEC:-300}" \
|
| 21 |
-
--empty_cache_every_steps "$
|
| 22 |
--delta_every_steps 25000 --delta_max_keep 1 --max_ckpts 1
|
|
|
|
| 8 |
[ -f /root/.cache/huggingface/token ] && { export HF_TOKEN="$(tr -d '\r\n' </root/.cache/huggingface/token)"; export HUGGING_FACE_HUB_TOKEN="$HF_TOKEN"; }
|
| 9 |
SAVE_DIR=/workspace/agillm4_4090_ckpts
|
| 10 |
CKPT="$(ls -1t "$SAVE_DIR"/pretrain_step*.pt 2>/dev/null | head -1)"
|
| 11 |
+
BATCH_SIZE="${AGILLM4_BATCH_SIZE:-2}"
|
| 12 |
+
SAT_EVERY="${AGILLM4_SAT_EVERY:-4}"
|
| 13 |
+
NAT_EVERY="${AGILLM4_NAT_EVERY:-4}"
|
| 14 |
+
EMPTY_CACHE_EVERY="${AGILLM4_EMPTY_CACHE_EVERY_STEPS:-0}"
|
| 15 |
+
GRAD_CHECKPOINT="${AGILLM4_GRAD_CHECKPOINT:-1}"
|
| 16 |
+
DBLOCK_OBJECTIVE_MODE="${AGILLM4_DBLOCK_OBJECTIVE_MODE:-stochastic}"
|
| 17 |
+
DBLOCK_AR_PROB="${AGILLM4_DBLOCK_AR_PROB:-0.85}"
|
| 18 |
+
DBLOCK_SAT_PROB="${AGILLM4_DBLOCK_SAT_PROB:-0.075}"
|
| 19 |
+
DBLOCK_NAT_PROB="${AGILLM4_DBLOCK_NAT_PROB:-0.075}"
|
| 20 |
+
DBLOCK_AR_LOSS_TOKENS="${AGILLM4_DBLOCK_AR_LOSS_TOKENS:-512}"
|
| 21 |
+
DBLOCK_SAT_LOSS_TOKENS="${AGILLM4_DBLOCK_SAT_LOSS_TOKENS:-0}"
|
| 22 |
+
DBLOCK_NAT_LOSS_TOKENS="${AGILLM4_DBLOCK_NAT_LOSS_TOKENS:-512}"
|
| 23 |
+
SUBLINEAR_WINDOW="${AGILLM4_SUBLINEAR_WINDOW:-128}"
|
| 24 |
+
SUBLINEAR_STRIDE="${AGILLM4_SUBLINEAR_STRIDE:-128}"
|
| 25 |
+
SUBLINEAR_MAX_ANCHORS="${AGILLM4_SUBLINEAR_MAX_ANCHORS:-128}"
|
| 26 |
+
SUBLINEAR_CHUNK="${AGILLM4_SUBLINEAR_CHUNK:-128}"
|
| 27 |
+
GC_FLAG=()
|
| 28 |
+
if [ "$GRAD_CHECKPOINT" = "1" ] || [ "$GRAD_CHECKPOINT" = "true" ] || [ "$GRAD_CHECKPOINT" = "yes" ]; then GC_FLAG=(--grad_checkpoint); fi
|
| 29 |
exec >> /workspace/agillm4_floor_train.log 2>&1
|
| 30 |
+
echo "RELAUNCH_AGILLM4_DBLOCK_TIED_SPEED $(date -u +%Y-%m-%dT%H:%M:%SZ) resume=$CKPT --dblock --tie_weights --attn_backend sublinear batch=$BATCH_SIZE sat_every=$SAT_EVERY nat_every=$NAT_EVERY empty_cache_every=$EMPTY_CACHE_EVERY grad_checkpoint=$GRAD_CHECKPOINT objective=$DBLOCK_OBJECTIVE_MODE ar_prob=$DBLOCK_AR_PROB sat_prob=$DBLOCK_SAT_PROB nat_prob=$DBLOCK_NAT_PROB ar_loss_tokens=$DBLOCK_AR_LOSS_TOKENS nat_loss_tokens=$DBLOCK_NAT_LOSS_TOKENS sublinear_window=$SUBLINEAR_WINDOW sublinear_stride=$SUBLINEAR_STRIDE sublinear_max_anchors=$SUBLINEAR_MAX_ANCHORS"
|
| 31 |
exec python -u nB300_agillm4.py train --preset agillm4_floor --resume "$CKPT" \
|
| 32 |
--dblock --dblock_blocks "${AGILLM4_DBLOCKS:-4}" --dblock_schedule "${AGILLM4_DBLOCK_SCHEDULE:-loss_balanced}" \
|
| 33 |
--dblock_warmup_steps "${AGILLM4_DBLOCK_WARMUP:-16}" --dblock_sigma_curriculum_steps "${AGILLM4_DBLOCK_SIGMA_CURRICULUM:-2000}" \
|
| 34 |
+
--dblock_log_every "${AGILLM4_DBLOCK_LOG_EVERY:-25}" --dblock_objective_mode "$DBLOCK_OBJECTIVE_MODE" \
|
| 35 |
+
--dblock_ar_prob "$DBLOCK_AR_PROB" --dblock_sat_prob "$DBLOCK_SAT_PROB" --dblock_nat_prob "$DBLOCK_NAT_PROB" \
|
| 36 |
+
--dblock_ar_loss_tokens "$DBLOCK_AR_LOSS_TOKENS" --dblock_sat_loss_tokens "$DBLOCK_SAT_LOSS_TOKENS" --dblock_nat_loss_tokens "$DBLOCK_NAT_LOSS_TOKENS" \
|
| 37 |
+
--tie_weights \
|
| 38 |
+
--batch_size "$BATCH_SIZE" --block 1280 --amp --attn_backend sublinear --sublinear_window "$SUBLINEAR_WINDOW" --sublinear_stride "$SUBLINEAR_STRIDE" --sublinear_max_anchors "$SUBLINEAR_MAX_ANCHORS" --sublinear_chunk "$SUBLINEAR_CHUNK" "${GC_FLAG[@]}" \
|
| 39 |
+
--optimizer paged_adamw8bit --sat_every "$SAT_EVERY" --nat_every "$NAT_EVERY" --nat_max_tokens 768 --nat_mask_ratio 0.5 \
|
| 40 |
--token_param_ratio 100 --save_dir "$SAVE_DIR" \
|
| 41 |
--save_every_sec 86400 --heartbeat_every_sec "${AGILLM4_HEARTBEAT_EVERY_SEC:-300}" \
|
| 42 |
+
--empty_cache_every_steps "$EMPTY_CACHE_EVERY" \
|
| 43 |
--delta_every_steps 25000 --delta_max_keep 1 --max_ckpts 1
|