OpenTransformer committed on
Commit
5c06a05
·
verified ·
1 Parent(s): 60f0b3c

AGILLM4_add_v47_floor_tmux_launcher

Browse files
AGILLM-4.md CHANGED
@@ -39,7 +39,8 @@ AGILLM4_4090_WARMSTART_FROM=/workspace/agillm-4/agillm4_floor_seed_from_v3_v47.p
39
  AGILLM4_4090_PRESET=agillm4_floor \
40
  AGILLM4_4090_BLOCK=512 \
41
  AGILLM4_4090_TOKEN_PARAM_RATIO=100 \
42
- bash /workspace/agillm-4/run_agillm4_4090_longblock.sh
 
43
  ```
44
 
45
  Important: `--sat_every 1 --nat_every 4` keeps SAT trained every step and NAT active on a cadence that fits 24GB cards. On B200/B300 use `--nat_every 1` for full AR+SAT+NAT every step. The AGILLM-4 code now backprops AR, SAT, and NAT sequentially, so the objective remains joint while peak VRAM is lower than holding all activation graphs at once.
 
39
  AGILLM4_4090_PRESET=agillm4_floor \
40
  AGILLM4_4090_BLOCK=512 \
41
  AGILLM4_4090_TOKEN_PARAM_RATIO=100 \
42
+ tmux new-session -d -s agillm4_floor \
43
+ /workspace/agillm-4/launch_agillm4_4090_floor_from_v47.sh
44
  ```
45
 
46
  Important: `--sat_every 1 --nat_every 4` keeps SAT trained every step and NAT active on a cadence that fits 24GB cards. On B200/B300 use `--nat_every 1` for full AR+SAT+NAT every step. The AGILLM-4 code now backprops AR, SAT, and NAT sequentially, so the objective remains joint while peak VRAM is lower than holding all activation graphs at once.
README.md CHANGED
@@ -31,5 +31,8 @@ against SDPA before using it for a real run.
31
  On RTX 4090-class 24GB cards, `run_agillm4_4090_longblock.sh` now defaults to
32
  `agillm4_floor` instead of the AGILLM-3-sized `large` preset. Override
33
  `AGILLM4_4090_BLOCK` upward only after the first floor run is stable.
 
 
 
34
 
35
  Current harvest status from n1.py is tracked in [N1_HARVEST.md](N1_HARVEST.md).
 
31
  On RTX 4090-class 24GB cards, `run_agillm4_4090_longblock.sh` now defaults to
32
  `agillm4_floor` instead of the AGILLM-3-sized `large` preset. Override
33
  `AGILLM4_4090_BLOCK` upward only after the first floor run is stable.
34
+ For the current v47 seed, launch tmux with
35
+ `/workspace/agillm-4/launch_agillm4_4090_floor_from_v47.sh`; it writes
36
+ `/workspace/agillm4_floor_train.log`.
37
 
38
  Current harvest status from n1.py is tracked in [N1_HARVEST.md](N1_HARVEST.md).
launch_agillm4_4090_floor_from_v47.sh ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -Eeuo pipefail
3
+
4
+ LOG="${AGILLM4_FLOOR_LOG:-/workspace/agillm4_floor_train.log}"
5
+ mkdir -p "$(dirname "$LOG")"
6
+ exec >> "$LOG" 2>&1
7
+
8
+ echo "LAUNCH_AGILLM4_4090_FLOOR_FROM_V47 $(date -u +%Y-%m-%dT%H:%M:%SZ) host=$(hostname)"
9
+
10
+ export AGILLM4_4090_WARMSTART_FROM="${AGILLM4_4090_WARMSTART_FROM:-/workspace/agillm-4/agillm4_floor_seed_from_v3_v47.pt}"
11
+ export AGILLM4_4090_PRESET="${AGILLM4_4090_PRESET:-agillm4_floor}"
12
+ export AGILLM4_4090_BLOCK="${AGILLM4_4090_BLOCK:-512}"
13
+ export AGILLM4_4090_TOKEN_PARAM_RATIO="${AGILLM4_4090_TOKEN_PARAM_RATIO:-100}"
14
+ export AGILLM4_4090_NAT_EVERY="${AGILLM4_4090_NAT_EVERY:-4}"
15
+ export AGILLM4_4090_NAT_MAX_TOKENS="${AGILLM4_4090_NAT_MAX_TOKENS:-512}"
16
+ export AGILLM4_4090_SAVE_EVERY_SEC="${AGILLM4_4090_SAVE_EVERY_SEC:-21600}"
17
+ export AGILLM4_4090_SAVE_DIR="${AGILLM4_4090_SAVE_DIR:-/workspace/agillm4_4090_ckpts}"
18
+
19
+ exec /workspace/agillm-4/run_agillm4_4090_longblock.sh
nB300_agillm4.py CHANGED
@@ -1561,8 +1561,12 @@ def sat_mask(n, block=SAT_BLOCK):
1561
 
1562
  def sat_mask_cached(new_len: int, cached_len: int, block=SAT_BLOCK):
1563
  total_len = cached_len + new_len
1564
- mask = torch.zeros((1, 1, new_len, total_len), device=DEV)
1565
- return mask
 
 
 
 
1566
 
1567
 
1568
  # ───────────────────────── Checkpoint helpers ─────────────────────────
@@ -2480,9 +2484,23 @@ def infer(args):
2480
  else:
2481
  cached_len = ids.size(1)
2482
  h, kvs = core(ids, sat_mask(ids.size(1)), use_cache=True, total_seq_len=cached_len)
 
2483
  added = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
2484
  while added < args.max_new:
2485
- logits_all, gate = sat_h(h[:, -SAT_BLOCK:])
2486
  stride = SAT_BLOCK if (not args.var or gate is None) else (gate.softmax(-1).multinomial(1).item() + 1)
2487
  stride = min(int(stride), logits_all.size(1))
2488
  new_tokens = []
@@ -2499,6 +2517,7 @@ def infer(args):
2499
  mask = sat_mask_cached(new_ids.size(1), cached_len)
2500
  h, kvs = core(new_ids, mask, kv_caches=kvs, use_cache=True, total_seq_len=ids.size(1))
2501
  cached_len = ids.size(1)
 
2502
  elapsed = time.time() - start
2503
  gen_tokens = len(ids[0]) - prompt_len
2504
  tok_per_sec = gen_tokens / elapsed if elapsed > 0 else 0
 
1561
 
1562
  def sat_mask_cached(new_len: int, cached_len: int, block=SAT_BLOCK):
1563
  total_len = cached_len + new_len
1564
+ q_idx = torch.arange(cached_len, total_len, device=DEV).unsqueeze(1)
1565
+ k_idx = torch.arange(total_len, device=DEV).unsqueeze(0)
1566
+ q_grp = q_idx // block
1567
+ k_grp = k_idx // block
1568
+ allow = q_grp >= k_grp
1569
+ return torch.where(allow, 0.0, float("-inf")).unsqueeze(0).unsqueeze(0)
1570
 
1571
 
1572
  # ───────────────────────── Checkpoint helpers ─────────────────────────
 
2484
  else:
2485
  cached_len = ids.size(1)
2486
  h, kvs = core(ids, sat_mask(ids.size(1)), use_cache=True, total_seq_len=cached_len)
2487
+ h_buffer = h[:, -SAT_BLOCK:]
2488
  added = 0
2489
+ stop = False
2490
+
2491
+ # Align to block boundary if prompt is off-boundary
2492
+ if ids.size(1) % SAT_BLOCK != 0:
2493
+ logits = ar_h(h)[:, -1]
2494
+ logits = _apply_penalties(logits, ids, args.penalty_last_n, args.repetition_penalty, args.presence_penalty, args.frequency_penalty)
2495
+ nxt = _sample(logits, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy)
2496
+ ids = torch.cat([ids, nxt], 1)
2497
+ added += 1
2498
+ h, kvs = core(nxt, None, kv_caches=kvs, use_cache=True, total_seq_len=ids.size(1))
2499
+ cached_len = ids.size(1)
2500
+ h_buffer = torch.cat([h_buffer, h], dim=1)[:, -SAT_BLOCK:]
2501
+
2502
  while added < args.max_new:
2503
+ logits_all, gate = sat_h(h_buffer)
2504
  stride = SAT_BLOCK if (not args.var or gate is None) else (gate.softmax(-1).multinomial(1).item() + 1)
2505
  stride = min(int(stride), logits_all.size(1))
2506
  new_tokens = []
 
2517
  mask = sat_mask_cached(new_ids.size(1), cached_len)
2518
  h, kvs = core(new_ids, mask, kv_caches=kvs, use_cache=True, total_seq_len=ids.size(1))
2519
  cached_len = ids.size(1)
2520
+ h_buffer = torch.cat([h_buffer, h], dim=1)[:, -SAT_BLOCK:]
2521
  elapsed = time.time() - start
2522
  gen_tokens = len(ids[0]) - prompt_len
2523
  tok_per_sec = gen_tokens / elapsed if elapsed > 0 else 0