OpenTransformer committed
Commit 421314d · verified · 1 Parent(s): 9c63689

Harvest exact M-fold attention from n1

AGILLM-4.md CHANGED
@@ -142,6 +142,33 @@ python /workspace/agillm-4/block_sweep_agillm4.py \
  If it is stable, then compare it against SDPA at the same block size with
  `profile_agillm4.py`.

+ ## n1.py Harvest
+
+ AGILLM-4 is now importing exact, proof-backed improvements from
+ `C:\Users\Scott\Downloads\n1.py` while keeping the AGILLM-4 model-scale and
+ long-context branch intact.
+
+ First harvested feature: exact M-fold expansion attention. When `rank > d_k`,
+ the trainer now uses:
+
+ ```text
+ (q @ U) @ (k @ U).T == q @ (U @ U.T) @ k.T
+ ```
+
+ This preserves the attention function exactly while keeping score and cache-key
+ width at `d_k` instead of `rank`. The inference path caches `U @ U.T`; the
+ training path recomputes it so gradients through `U` remain exact.
+
+ Verification:
+
+ ```bash
+ python /workspace/agillm-4/verify_m_fold_agillm4.py \
+     --presets pico_1x,micro_3x \
+     --backends manual,sdpa
+ ```
+
+ See `N1_HARVEST.md` for the staged port order.
+
  ## Intelligence per FLOP

  Compute reduction is not enough by itself. AGILLM-4 should spend saved FLOPs on
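The identity behind this hunk is easy to check in isolation. Below is a minimal standalone sketch, not part of the commit; the sizes and the random `U` are illustrative stand-ins for a head's learned expansion matrix:

```python
import torch

torch.manual_seed(0)
d_k, rank, n = 16, 48, 8            # expansion case: rank > d_k
q = torch.randn(n, d_k)
k = torch.randn(n, d_k)
U = torch.randn(d_k, rank)          # stand-in for the learned expansion

# Expanded form: lift q and k to width `rank`, then score there.
scores_expanded = (q @ U) @ (k @ U).T

# Folded form: precompute the d_k x d_k metric U @ U.T, keep k at width d_k.
metric = U @ U.T
scores_folded = (q @ metric) @ k.T

# Identical up to float32 rounding (on the order of 1e-5 at these sizes).
print((scores_expanded - scores_folded).abs().max().item())
```

This is why the cache can store keys at `d_k` width: the `rank`-wide projection never needs to materialize on the key side at all.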
N1_HARVEST.md ADDED
@@ -0,0 +1,78 @@
+ # n1.py Harvest Plan
+
+ Goal: move proven, no-quality-loss trainer improvements from
+ `C:\Users\Scott\Downloads\n1.py` into AGILLM-4 without replacing the AGILLM-4
+ long-context/model-scale branch.
+
+ ## Ported
+
+ ### 1. Exact M-Fold Expansion Attention
+
+ Status: done.
+
+ For ranks where `rank > d_k`, AGILLM-4 now computes:
+
+ ```text
+ (q @ U) @ (k @ U).T == q @ (U @ U.T) @ k.T
+ ```
+
+ This keeps attention scores and KV-cache keys at `d_k` width instead of
+ `rank` width while preserving the exact expanded-attention function. The
+ training path recomputes `U @ U.T` with gradients, and inference/no-grad caches
+ the metric until `U` changes.
+
+ Verification:
+
+ ```bash
+ python agillm-4/verify_m_fold_agillm4.py \
+     --presets pico_1x,micro_3x \
+     --backends manual,sdpa \
+     --cached_len 8 \
+     --new_len 4
+ ```
+
+ The verifier checks forward output, loss, input gradients, parameter gradients,
+ cached-append equivalence, cache key width, and metric-cache invalidation.
+
+ ## Next Candidates
+
+ ### 2. Fused QKV Projection
+
+ n1 fuses the separate `q/k/v` linear layers into one `qkv` linear while keeping
+ checkpoint compatibility by folding old state-dict keys on load. This should be
+ the next port after M-fold because it reduces three projection GEMMs to one.
+
+ Risk: checkpoint key migration. Keep this separate from the M-fold port.
+
+ ### 3. Combined ALiBi + Mask Cache
+
+ n1 pre-folds ALiBi into the mask once per encoder forward instead of rebuilding
+ the same layer-independent bias in every block.
+
+ Risk: cache semantics differ for KV decode, where the ALiBi slice changes.
+
+ ### 4. SAT Speculative Inference
+
+ n1 has proof-covered SAT-draft / AR-verify speculative decoding. This belongs
+ in AGILLM-4 after the SFT result tells us whether chat turns are sane.
+
+ Risk: inference control flow and cache rollback complexity.
+
+ ### 5. Compact Checkpoint
+
+ n1 can compact `U` spectra post-training and save compatible checkpoints.
+
+ Risk: optimizer state must be dropped or remapped carefully; do this only as a
+ separate command, never during a live training run.
+
+ ### 6. KV Cache Buffer
+
+ n1 replaces repeated decode-time `torch.cat` cache growth with preallocated KV
+ buffers.
+
+ Risk: the cache object type touches AR, SAT, and future spec-decoding paths.
+
+ ## Rule
+
+ Every harvested feature needs its own AGILLM-4 verifier or profile artifact.
+ Do not rely on n1's proof suite alone after adapting the implementation.
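Candidate 2 is the most mechanical of these ports. A minimal sketch of the usual pattern, assuming standard PyTorch `nn.Module` hooks; the class name, key names, and migration hook below are illustrative, not n1's actual code:

```python
import torch
import torch.nn as nn


class FusedQKV(nn.Module):
    """One GEMM instead of three; folds legacy q/k/v keys on load."""

    def __init__(self, d: int):
        super().__init__()
        self.qkv = nn.Linear(d, 3 * d, bias=False)

    def forward(self, x: torch.Tensor):
        # Single projection, then split back into q, k, v.
        return self.qkv(x).chunk(3, dim=-1)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # Checkpoint migration: concatenate old per-projection weights
        # along the out-features axis into the fused weight, if present.
        old_keys = [f"{prefix}{name}.weight" for name in ("q", "k", "v")]
        if all(key in state_dict for key in old_keys):
            state_dict[f"{prefix}qkv.weight"] = torch.cat(
                [state_dict.pop(key) for key in old_keys], dim=0
            )
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


# Usage: q, k, v = FusedQKV(512)(torch.randn(2, 16, 512))
```

The stated risk lives entirely in that migration hook: the old keys must be remapped before strict key checking runs, which is what `_load_from_state_dict` allows.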
README.md CHANGED
@@ -19,9 +19,12 @@ and extended for:
  - longer block-size work on 24GB, B200, and B300 class GPUs
  - AR+SAT every step with sequential backward to reduce peak VRAM
  - SDPA and experimental sublinear local+landmark attention backends
+ - exact M-fold expansion attention harvested from n1.py, with local verifier
  - profiling tools for memory, throughput, AR cost, SAT cost, and optimizer cost
  - synthetic long-context curriculum generation for recall and multi-hop tests

  Start with [AGILLM-4.md](AGILLM-4.md) for the training plan and command
  recipes. The current sublinear backend is intentionally experimental: profile it
  against SDPA before using it for a real run.
+
+ Current harvest status from n1.py is tracked in [N1_HARVEST.md](N1_HARVEST.md).
local_sweep_after_mfold_sdpa.json ADDED
@@ -0,0 +1,22 @@
+ [
+   {
+     "alloc_gb": 0.027,
+     "amp": false,
+     "attn_backend": "sdpa",
+     "batch_size": 1,
+     "block": 64,
+     "elapsed_s": 0.534,
+     "error": null,
+     "grad_checkpoint": true,
+     "loss": 18.8005,
+     "ok": true,
+     "peak_alloc_gb": 0.031,
+     "peak_reserved_gb": 0.057,
+     "reserved_gb": 0.057,
+     "sublinear_chunk": 128,
+     "sublinear_max_anchors": 256,
+     "sublinear_stride": 64,
+     "sublinear_window": 256,
+     "tokens_per_s_synthetic": 119.8
+   }
+ ]
local_sweep_after_mfold_sublinear.json ADDED
@@ -0,0 +1,22 @@
+ [
+   {
+     "alloc_gb": 0.027,
+     "amp": false,
+     "attn_backend": "sublinear",
+     "batch_size": 1,
+     "block": 64,
+     "elapsed_s": 0.635,
+     "error": null,
+     "grad_checkpoint": true,
+     "loss": 18.7676,
+     "ok": true,
+     "peak_alloc_gb": 0.031,
+     "peak_reserved_gb": 0.057,
+     "reserved_gb": 0.057,
+     "sublinear_chunk": 16,
+     "sublinear_max_anchors": 16,
+     "sublinear_stride": 8,
+     "sublinear_window": 16,
+     "tokens_per_s_synthetic": 100.8
+   }
+ ]
local_verify_m_fold_agillm4.json ADDED
@@ -0,0 +1,118 @@
+ [
+   {
+     "backend": "manual",
+     "d": 32,
+     "dk": 16,
+     "expected_k_width": 16,
+     "heads": 2,
+     "ok": true,
+     "preset": "pico_1x",
+     "rank": 16,
+     "rows": {
+       "cache_k_width": 16.0,
+       "cache_v_width": 16.0,
+       "cached_append_forward": 5.960464477539063e-08,
+       "causal_alibi_forward": 0.0,
+       "causal_alibi_loss": 0.0,
+       "causal_alibi_param_grad": 0.0,
+       "causal_alibi_x_grad": 0.0,
+       "none_forward": 0.0,
+       "none_loss": 0.0,
+       "none_param_grad": 0.0,
+       "none_x_grad": 0.0,
+       "sat_alibi_forward": 0.0,
+       "sat_alibi_loss": 0.0,
+       "sat_alibi_param_grad": 0.0,
+       "sat_alibi_x_grad": 0.0
+     },
+     "tol": 0.0002
+   },
+   {
+     "backend": "sdpa",
+     "d": 32,
+     "dk": 16,
+     "expected_k_width": 16,
+     "heads": 2,
+     "ok": true,
+     "preset": "pico_1x",
+     "rank": 16,
+     "rows": {
+       "cache_k_width": 16.0,
+       "cache_v_width": 16.0,
+       "cached_append_forward": 5.960464477539063e-08,
+       "causal_alibi_forward": 7.450580596923828e-08,
+       "causal_alibi_loss": 0.0,
+       "causal_alibi_param_grad": 1.862645149230957e-09,
+       "causal_alibi_x_grad": 2.3283064365386963e-10,
+       "none_forward": 8.940696716308594e-08,
+       "none_loss": 0.0,
+       "none_param_grad": 9.313225746154785e-10,
+       "none_x_grad": 6.548361852765083e-11,
+       "sat_alibi_forward": 1.1920928955078125e-07,
+       "sat_alibi_loss": 1.862645149230957e-09,
+       "sat_alibi_param_grad": 9.313225746154785e-10,
+       "sat_alibi_x_grad": 3.4924596548080444e-10
+     },
+     "tol": 0.0002
+   },
+   {
+     "backend": "manual",
+     "d": 128,
+     "dk": 16,
+     "expected_k_width": 16,
+     "heads": 8,
+     "ok": true,
+     "preset": "micro_3x",
+     "rank": 48,
+     "rows": {
+       "cache_k_width": 16.0,
+       "cache_v_width": 16.0,
+       "cached_append_forward": 6.51925802230835e-08,
+       "causal_alibi_forward": 5.960464477539063e-08,
+       "causal_alibi_loss": 0.0,
+       "causal_alibi_param_grad": 4.656612873077393e-10,
+       "causal_alibi_x_grad": 5.820766091346741e-11,
+       "metric_cache_cleared_on_train": 0.0,
+       "metric_cache_reused": 0.0,
+       "none_forward": 5.960464477539063e-08,
+       "none_loss": 0.0,
+       "none_param_grad": 2.3283064365386963e-10,
+       "none_x_grad": 2.1827872842550278e-11,
+       "sat_alibi_forward": 1.1920928955078125e-07,
+       "sat_alibi_loss": 1.862645149230957e-09,
+       "sat_alibi_param_grad": 5.820766091346741e-10,
+       "sat_alibi_x_grad": 7.275957614183426e-11
+     },
+     "tol": 0.0002
+   },
+   {
+     "backend": "sdpa",
+     "d": 128,
+     "dk": 16,
+     "expected_k_width": 16,
+     "heads": 8,
+     "ok": true,
+     "preset": "micro_3x",
+     "rank": 48,
+     "rows": {
+       "cache_k_width": 16.0,
+       "cache_v_width": 16.0,
+       "cached_append_forward": 7.450580596923828e-08,
+       "causal_alibi_forward": 1.043081283569336e-07,
+       "causal_alibi_loss": 0.0,
+       "causal_alibi_param_grad": 4.656612873077393e-10,
+       "causal_alibi_x_grad": 9.458744898438454e-11,
+       "metric_cache_cleared_on_train": 0.0,
+       "metric_cache_reused": 0.0,
+       "none_forward": 8.940696716308594e-08,
+       "none_loss": 0.0,
+       "none_param_grad": 2.3283064365386963e-10,
+       "none_x_grad": 2.9103830456733704e-11,
+       "sat_alibi_forward": 1.1920928955078125e-07,
+       "sat_alibi_loss": 0.0,
+       "sat_alibi_param_grad": 6.984919309616089e-10,
+       "sat_alibi_x_grad": 1.1641532182693481e-10
+     },
+     "tol": 0.0002
+   }
+ ]
nB300_agillm4.py CHANGED
@@ -1154,6 +1154,15 @@ class TuneableAttentionMHA(nn.Module):
          nn.init.orthogonal_(self.U)
          self.proj = nn.Linear(h * self.dk, d, bias=False)
          self.drop = nn.Dropout(0.1)
+         # Exact n1 harvest: for expansion ranks, (q @ U) @ (k @ U).T is
+         # q @ (U @ U.T) @ k.T. This keeps score/cache width at d_k with no
+         # quality change. Inference caches the metric and training recomputes
+         # it so gradients through U are unchanged.
+         self._metric_cache: Optional[torch.Tensor] = None
+         self._metric_cache_ver: int = -1
+         self._metric_cache_param_id: int = -1
+         self._metric_cache_data_ptr: int = -1
+         self._metric_cache_shape: Tuple[int, int] = (-1, -1)

      def _proj_qk(self, x):
          B, N, _ = x.shape
@@ -1163,6 +1172,44 @@ class TuneableAttentionMHA(nn.Module):
          B, N, _ = x.shape
          return x.view(B, N, self.h, self.dk).transpose(1, 2)

+     def _reshape_heads(self, x):
+         B, N, _ = x.shape
+         return x.view(B, N, self.h, self.dk).transpose(1, 2)
+
+     def _get_metric(self) -> torch.Tensor:
+         if torch.is_grad_enabled():
+             return self.U @ self.U.T
+         cur_ver = self.U._version
+         cur_param_id = id(self.U)
+         cur_data_ptr = int(self.U.data_ptr())
+         cur_shape = tuple(self.U.shape)
+         cache = self._metric_cache
+         if (
+             cache is None
+             or cache.dtype != self.U.dtype
+             or cache.device != self.U.device
+             or self._metric_cache_ver != cur_ver
+             or self._metric_cache_param_id != cur_param_id
+             or self._metric_cache_data_ptr != cur_data_ptr
+             or self._metric_cache_shape != cur_shape
+         ):
+             cache = (self.U @ self.U.T).detach()
+             self._metric_cache = cache
+             self._metric_cache_ver = cur_ver
+             self._metric_cache_param_id = cur_param_id
+             self._metric_cache_data_ptr = cur_data_ptr
+             self._metric_cache_shape = cur_shape
+         return cache
+
+     def train(self, mode: bool = True):
+         if mode:
+             self._metric_cache = None
+             self._metric_cache_ver = -1
+             self._metric_cache_param_id = -1
+             self._metric_cache_data_ptr = -1
+             self._metric_cache_shape = (-1, -1)
+         return super().train(mode)
+
      def _sublinear_attention(self, q, k, v, attn_mask=None):
          """Local-window + landmark attention: O(N * (window + N/stride))."""
          bsz, heads, q_len, _ = q.shape
@@ -1226,9 +1273,15 @@ class TuneableAttentionMHA(nn.Module):
          return torch.cat(outputs, dim=2)

      def forward(self, x, mask=None, rel_bias_tokens=None, kv_cache=None, use_cache=False):
-         q = self._proj_qk(self.q(x))
-         k_new = self._proj_qk(self.k(x))
+         q_lin = self.q(x)
+         k_lin = self.k(x)
          v_new = self._reshape_v(self.v(x))
+         if self.r > self.dk:
+             q = self._reshape_heads(q_lin) @ self._get_metric()
+             k_new = self._reshape_heads(k_lin)
+         else:
+             q = self._proj_qk(q_lin)
+             k_new = self._proj_qk(k_lin)
          if kv_cache is None:
              k, v = k_new, v_new
          else:
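The invalidation keys in `_get_metric` cover the ways `U` can change between no-grad forwards: in-place updates (`_version`), a swapped parameter object (`id`), moved storage (`data_ptr`), and a resized spectrum (`shape`). The `_version` check relies on PyTorch's per-tensor version counter, which ticks on in-place mutation even under `no_grad`. A standalone illustration, not repo code:

```python
import torch

U = torch.nn.Parameter(torch.randn(16, 48))
print(U._version)      # 0 for a fresh parameter

with torch.no_grad():
    U.add_(0.01)       # in-place, as an optimizer step would mutate it
print(U._version)      # 1 -> a cached U @ U.T is now stale and gets rebuilt
```

Overriding `train(True)` to drop the cache is belt-and-braces on top of this: training recomputes the metric with gradients anyway, so the cache only ever serves no-grad paths.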
verify_m_fold_agillm4.py ADDED
@@ -0,0 +1,162 @@
+ #!/usr/bin/env python3
+ from __future__ import annotations
+
+ import argparse
+ import json
+ import math
+ import os
+ from pathlib import Path
+
+ os.environ.setdefault("AGILLM_SYNTHETIC_TOKENIZER", "1")
+
+ import torch
+
+ import nB300_agillm4 as nb
+
+
+ def causal_mask_cached(new_len: int, cached_len: int):
+     total = cached_len + new_len
+     q_pos = torch.arange(cached_len, total, device=nb.DEV).view(new_len, 1)
+     k_pos = torch.arange(total, device=nb.DEV).view(1, total)
+     mask = torch.where(k_pos > q_pos, float("-inf"), 0.0)
+     return mask.view(1, 1, new_len, total)
+
+
+ def old_expanded_forward(mha: nb.TuneableAttentionMHA, x: torch.Tensor, mask=None, rel_bias_tokens=None):
+     bsz, seq, _ = x.shape
+     q = mha._reshape_heads(mha.q(x)) @ mha.U
+     k = mha._reshape_heads(mha.k(x)) @ mha.U
+     v = mha._reshape_v(mha.v(x))
+     att = (q @ k.transpose(-1, -2)) / math.sqrt(mha.dk)
+     if mha.use_relpos and rel_bias_tokens is not None:
+         att = att + nb.alibi_bias(mha.h, rel_bias_tokens)[:, :, -seq:, :]
+     if mask is not None:
+         att = att + mask
+     z = (att.softmax(-1) @ v).transpose(1, 2).reshape(bsz, seq, -1)
+     return mha.drop(mha.proj(z))
+
+
+ def max_param_grad_diff(mha: nb.TuneableAttentionMHA, ref_grads: dict[str, torch.Tensor]) -> float:
+     out = 0.0
+     for name, param in mha.named_parameters():
+         if param.grad is None:
+             continue
+         out = max(out, (param.grad.detach() - ref_grads[name]).abs().max().item())
+     return out
+
+
+ def verify_case(args, preset: str, backend: str) -> dict:
+     torch.manual_seed(args.seed)
+     cfg = nb.PRESETS[preset].copy()
+     d, h, r = cfg["d"], cfg["heads"], cfg["rank"]
+     seq = args.cached_len + args.new_len
+     mha = nb.TuneableAttentionMHA(d, h, r, attn_backend=backend).to(nb.DEV).eval()
+     rows = {}
+
+     for case_name, mask, rel_tokens in [
+         ("none", None, None),
+         ("causal_alibi", nb.causal_mask(seq), seq),
+         ("sat_alibi", nb.sat_mask(seq), seq),
+     ]:
+         x_new = torch.randn(2, seq, d, device=nb.DEV, requires_grad=True)
+         x_old = x_new.detach().clone().requires_grad_(True)
+         y_new = mha(x_new, mask, rel_bias_tokens=rel_tokens)
+         y_old = old_expanded_forward(mha, x_old, mask=mask, rel_bias_tokens=rel_tokens)
+         loss_new = y_new.square().mean()
+         loss_old = y_old.square().mean()
+         loss_new.backward()
+         new_x_grad = x_new.grad.detach().clone()
+         new_param_grads = {
+             name: param.grad.detach().clone()
+             for name, param in mha.named_parameters()
+             if param.grad is not None
+         }
+         mha.zero_grad(set_to_none=True)
+         loss_old.backward()
+         old_x_grad = x_old.grad.detach().clone()
+         rows[f"{case_name}_forward"] = (y_new - y_old).abs().max().item()
+         rows[f"{case_name}_loss"] = abs(loss_new.item() - loss_old.item())
+         rows[f"{case_name}_x_grad"] = (new_x_grad - old_x_grad).abs().max().item()
+         rows[f"{case_name}_param_grad"] = max_param_grad_diff(mha, new_param_grads)
+         mha.zero_grad(set_to_none=True)
+
+     with torch.no_grad():
+         prefix = torch.randn(1, args.cached_len, d, device=nb.DEV)
+         append = torch.randn(1, args.new_len, d, device=nb.DEV)
+         full = torch.cat([prefix, append], dim=1)
+         y_full = mha(full, nb.causal_mask(seq), rel_bias_tokens=seq)[:, args.cached_len :]
+         _, kvs = mha(prefix, nb.causal_mask(args.cached_len), rel_bias_tokens=args.cached_len, use_cache=True)
+         y_cached, kvs2 = mha(
+             append,
+             causal_mask_cached(args.new_len, args.cached_len),
+             rel_bias_tokens=seq,
+             kv_cache=kvs,
+             use_cache=True,
+         )
+         k_cached, v_cached = kvs2
+         rows["cached_append_forward"] = (y_full - y_cached).abs().max().item()
+         rows["cache_k_width"] = float(k_cached.size(-1))
+         rows["cache_v_width"] = float(v_cached.size(-1))
+
+         if r > d // h:
+             _ = mha(prefix, None, use_cache=False)
+             first_cache = mha._metric_cache
+             _ = mha(prefix, None, use_cache=False)
+             second_cache = mha._metric_cache
+             rows["metric_cache_reused"] = 0.0 if (first_cache is second_cache and first_cache is not None) else 1.0
+             mha.train(True)
+             rows["metric_cache_cleared_on_train"] = 0.0 if mha._metric_cache is None else 1.0
+
+     tol = args.tol
+     ok = True
+     numeric_rows = {}
+     for key, value in rows.items():
+         if key in {"cache_k_width", "cache_v_width"}:
+             numeric_rows[key] = value
+             continue
+         numeric_rows[key] = value
+         ok = ok and value <= tol
+
+     expected_k_width = d // h if r > (d // h) else r
+     ok = ok and int(rows["cache_k_width"]) == expected_k_width
+     ok = ok and int(rows["cache_v_width"]) == d // h
+     return {
+         "preset": preset,
+         "backend": backend,
+         "d": d,
+         "heads": h,
+         "rank": r,
+         "dk": d // h,
+         "expected_k_width": expected_k_width,
+         "ok": ok,
+         "tol": tol,
+         "rows": numeric_rows,
+     }
+
+
+ def main() -> int:
+     parser = argparse.ArgumentParser(description="Verify AGILLM-4 exact M-fold attention harvest from n1.py")
+     parser.add_argument("--presets", default="pico_1x,micro_3x")
+     parser.add_argument("--backends", default="manual,sdpa")
+     parser.add_argument("--cached_len", type=int, default=8)
+     parser.add_argument("--new_len", type=int, default=4)
+     parser.add_argument("--seed", type=int, default=1234)
+     parser.add_argument("--tol", type=float, default=2e-4)
+     parser.add_argument("--json_out", default="")
+     args = parser.parse_args()
+
+     results = []
+     all_ok = True
+     for preset in [item.strip() for item in args.presets.split(",") if item.strip()]:
+         for backend in [item.strip() for item in args.backends.split(",") if item.strip()]:
+             result = verify_case(args, preset, backend)
+             results.append(result)
+             all_ok = all_ok and result["ok"]
+             print(json.dumps(result, sort_keys=True), flush=True)
+     if args.json_out:
+         Path(args.json_out).write_text(json.dumps(results, indent=2, sort_keys=True), encoding="utf-8")
+     return 0 if all_ok else 1
+
+
+ if __name__ == "__main__":
+     raise SystemExit(main())
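The checked-in `local_verify_m_fold_agillm4.json` above is the artifact this script emits through `--json_out`. A typical invocation, using only the flags defined in `main()`:

```bash
python verify_m_fold_agillm4.py \
    --presets pico_1x,micro_3x \
    --backends manual,sdpa \
    --cached_len 8 \
    --new_len 4 \
    --json_out local_verify_m_fold_agillm4.json
# Exit status is 0 only if every preset/backend case reports ok=true.
```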