AGILLM4_add_v47_floor_tmux_launcher
Browse files- AGILLM-4.md +2 -1
- README.md +3 -0
- launch_agillm4_4090_floor_from_v47.sh +19 -0
- nB300_agillm4.py +22 -3
AGILLM-4.md
CHANGED
|
@@ -39,7 +39,8 @@ AGILLM4_4090_WARMSTART_FROM=/workspace/agillm-4/agillm4_floor_seed_from_v3_v47.p
|
|
| 39 |
AGILLM4_4090_PRESET=agillm4_floor \
|
| 40 |
AGILLM4_4090_BLOCK=512 \
|
| 41 |
AGILLM4_4090_TOKEN_PARAM_RATIO=100 \
|
| 42 |
-
|
|
|
|
| 43 |
```
|
| 44 |
|
| 45 |
Important: `--sat_every 1 --nat_every 4` keeps SAT trained every step and NAT active on a cadence that fits 24GB cards. On B200/B300 use `--nat_every 1` for full AR+SAT+NAT every step. The AGILLM-4 code now backprops AR, SAT, and NAT sequentially, so the objective remains joint while peak VRAM is lower than holding all activation graphs at once.
|
|
|
|
| 39 |
AGILLM4_4090_PRESET=agillm4_floor \
|
| 40 |
AGILLM4_4090_BLOCK=512 \
|
| 41 |
AGILLM4_4090_TOKEN_PARAM_RATIO=100 \
|
| 42 |
+
tmux new-session -d -s agillm4_floor \
|
| 43 |
+
/workspace/agillm-4/launch_agillm4_4090_floor_from_v47.sh
|
| 44 |
```
|
| 45 |
|
| 46 |
Important: `--sat_every 1 --nat_every 4` keeps SAT trained every step and NAT active on a cadence that fits 24GB cards. On B200/B300 use `--nat_every 1` for full AR+SAT+NAT every step. The AGILLM-4 code now backprops AR, SAT, and NAT sequentially, so the objective remains joint while peak VRAM is lower than holding all activation graphs at once.
|
README.md
CHANGED
|
@@ -31,5 +31,8 @@ against SDPA before using it for a real run.
|
|
| 31 |
On RTX 4090-class 24GB cards, `run_agillm4_4090_longblock.sh` now defaults to
|
| 32 |
`agillm4_floor` instead of the AGILLM-3-sized `large` preset. Override
|
| 33 |
`AGILLM4_4090_BLOCK` upward only after the first floor run is stable.
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
Current harvest status from n1.py is tracked in [N1_HARVEST.md](N1_HARVEST.md).
|
|
|
|
| 31 |
On RTX 4090-class 24GB cards, `run_agillm4_4090_longblock.sh` now defaults to
|
| 32 |
`agillm4_floor` instead of the AGILLM-3-sized `large` preset. Override
|
| 33 |
`AGILLM4_4090_BLOCK` upward only after the first floor run is stable.
|
| 34 |
+
For the current v47 seed, launch tmux with
|
| 35 |
+
`/workspace/agillm-4/launch_agillm4_4090_floor_from_v47.sh`; it writes
|
| 36 |
+
`/workspace/agillm4_floor_train.log`.
|
| 37 |
|
| 38 |
Current harvest status from n1.py is tracked in [N1_HARVEST.md](N1_HARVEST.md).
|
launch_agillm4_4090_floor_from_v47.sh
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -Eeuo pipefail
|
| 3 |
+
|
| 4 |
+
LOG="${AGILLM4_FLOOR_LOG:-/workspace/agillm4_floor_train.log}"
|
| 5 |
+
mkdir -p "$(dirname "$LOG")"
|
| 6 |
+
exec >> "$LOG" 2>&1
|
| 7 |
+
|
| 8 |
+
echo "LAUNCH_AGILLM4_4090_FLOOR_FROM_V47 $(date -u +%Y-%m-%dT%H:%M:%SZ) host=$(hostname)"
|
| 9 |
+
|
| 10 |
+
export AGILLM4_4090_WARMSTART_FROM="${AGILLM4_4090_WARMSTART_FROM:-/workspace/agillm-4/agillm4_floor_seed_from_v3_v47.pt}"
|
| 11 |
+
export AGILLM4_4090_PRESET="${AGILLM4_4090_PRESET:-agillm4_floor}"
|
| 12 |
+
export AGILLM4_4090_BLOCK="${AGILLM4_4090_BLOCK:-512}"
|
| 13 |
+
export AGILLM4_4090_TOKEN_PARAM_RATIO="${AGILLM4_4090_TOKEN_PARAM_RATIO:-100}"
|
| 14 |
+
export AGILLM4_4090_NAT_EVERY="${AGILLM4_4090_NAT_EVERY:-4}"
|
| 15 |
+
export AGILLM4_4090_NAT_MAX_TOKENS="${AGILLM4_4090_NAT_MAX_TOKENS:-512}"
|
| 16 |
+
export AGILLM4_4090_SAVE_EVERY_SEC="${AGILLM4_4090_SAVE_EVERY_SEC:-21600}"
|
| 17 |
+
export AGILLM4_4090_SAVE_DIR="${AGILLM4_4090_SAVE_DIR:-/workspace/agillm4_4090_ckpts}"
|
| 18 |
+
|
| 19 |
+
exec /workspace/agillm-4/run_agillm4_4090_longblock.sh
|
nB300_agillm4.py
CHANGED
|
@@ -1561,8 +1561,12 @@ def sat_mask(n, block=SAT_BLOCK):
|
|
| 1561 |
|
| 1562 |
def sat_mask_cached(new_len: int, cached_len: int, block=SAT_BLOCK):
|
| 1563 |
total_len = cached_len + new_len
|
| 1564 |
-
|
| 1565 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1566 |
|
| 1567 |
|
| 1568 |
# βββββββββββββββββββββββββ Checkpoint helpers βββββββββββββββββββββββββ
|
|
@@ -2480,9 +2484,23 @@ def infer(args):
|
|
| 2480 |
else:
|
| 2481 |
cached_len = ids.size(1)
|
| 2482 |
h, kvs = core(ids, sat_mask(ids.size(1)), use_cache=True, total_seq_len=cached_len)
|
|
|
|
| 2483 |
added = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2484 |
while added < args.max_new:
|
| 2485 |
-
logits_all, gate = sat_h(
|
| 2486 |
stride = SAT_BLOCK if (not args.var or gate is None) else (gate.softmax(-1).multinomial(1).item() + 1)
|
| 2487 |
stride = min(int(stride), logits_all.size(1))
|
| 2488 |
new_tokens = []
|
|
@@ -2499,6 +2517,7 @@ def infer(args):
|
|
| 2499 |
mask = sat_mask_cached(new_ids.size(1), cached_len)
|
| 2500 |
h, kvs = core(new_ids, mask, kv_caches=kvs, use_cache=True, total_seq_len=ids.size(1))
|
| 2501 |
cached_len = ids.size(1)
|
|
|
|
| 2502 |
elapsed = time.time() - start
|
| 2503 |
gen_tokens = len(ids[0]) - prompt_len
|
| 2504 |
tok_per_sec = gen_tokens / elapsed if elapsed > 0 else 0
|
|
|
|
| 1561 |
|
| 1562 |
def sat_mask_cached(new_len: int, cached_len: int, block=SAT_BLOCK):
|
| 1563 |
total_len = cached_len + new_len
|
| 1564 |
+
q_idx = torch.arange(cached_len, total_len, device=DEV).unsqueeze(1)
|
| 1565 |
+
k_idx = torch.arange(total_len, device=DEV).unsqueeze(0)
|
| 1566 |
+
q_grp = q_idx // block
|
| 1567 |
+
k_grp = k_idx // block
|
| 1568 |
+
allow = q_grp >= k_grp
|
| 1569 |
+
return torch.where(allow, 0.0, float("-inf")).unsqueeze(0).unsqueeze(0)
|
| 1570 |
|
| 1571 |
|
| 1572 |
# βββββββββββββββββββββββββ Checkpoint helpers βββββββββββββββββββββββββ
|
|
|
|
| 2484 |
else:
|
| 2485 |
cached_len = ids.size(1)
|
| 2486 |
h, kvs = core(ids, sat_mask(ids.size(1)), use_cache=True, total_seq_len=cached_len)
|
| 2487 |
+
h_buffer = h[:, -SAT_BLOCK:]
|
| 2488 |
added = 0
|
| 2489 |
+
stop = False
|
| 2490 |
+
|
| 2491 |
+
# Align to block boundary if prompt is off-boundary
|
| 2492 |
+
if ids.size(1) % SAT_BLOCK != 0:
|
| 2493 |
+
logits = ar_h(h)[:, -1]
|
| 2494 |
+
logits = _apply_penalties(logits, ids, args.penalty_last_n, args.repetition_penalty, args.presence_penalty, args.frequency_penalty)
|
| 2495 |
+
nxt = _sample(logits, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy)
|
| 2496 |
+
ids = torch.cat([ids, nxt], 1)
|
| 2497 |
+
added += 1
|
| 2498 |
+
h, kvs = core(nxt, None, kv_caches=kvs, use_cache=True, total_seq_len=ids.size(1))
|
| 2499 |
+
cached_len = ids.size(1)
|
| 2500 |
+
h_buffer = torch.cat([h_buffer, h], dim=1)[:, -SAT_BLOCK:]
|
| 2501 |
+
|
| 2502 |
while added < args.max_new:
|
| 2503 |
+
logits_all, gate = sat_h(h_buffer)
|
| 2504 |
stride = SAT_BLOCK if (not args.var or gate is None) else (gate.softmax(-1).multinomial(1).item() + 1)
|
| 2505 |
stride = min(int(stride), logits_all.size(1))
|
| 2506 |
new_tokens = []
|
|
|
|
| 2517 |
mask = sat_mask_cached(new_ids.size(1), cached_len)
|
| 2518 |
h, kvs = core(new_ids, mask, kv_caches=kvs, use_cache=True, total_seq_len=ids.size(1))
|
| 2519 |
cached_len = ids.size(1)
|
| 2520 |
+
h_buffer = torch.cat([h_buffer, h], dim=1)[:, -SAT_BLOCK:]
|
| 2521 |
elapsed = time.time() - start
|
| 2522 |
gen_tokens = len(ids[0]) - prompt_len
|
| 2523 |
tok_per_sec = gen_tokens / elapsed if elapsed > 0 else 0
|