Harvest exact M-fold attention from n1
Changed files:

- AGILLM-4.md (+27 -0)
- N1_HARVEST.md (+78 -0)
- README.md (+3 -0)
- local_sweep_after_mfold_sdpa.json (+22 -0)
- local_sweep_after_mfold_sublinear.json (+22 -0)
- local_verify_m_fold_agillm4.json (+118 -0)
- nB300_agillm4.py (+55 -2)
- verify_m_fold_agillm4.py (+162 -0)
AGILLM-4.md
CHANGED
@@ -142,6 +142,33 @@ python /workspace/agillm-4/block_sweep_agillm4.py \
 If it is stable, then compare it against SDPA at the same block size with
 `profile_agillm4.py`.
 
+## n1.py Harvest
+
+AGILLM-4 is now starting to import exact, proof-backed improvements from
+`C:\Users\Scott\Downloads\n1.py` while keeping the AGILLM-4 model-scale and
+long-context branch intact.
+
+First harvested feature: exact M-fold expansion attention. When `rank > d_k`,
+the trainer now uses:
+
+```text
+(q @ U) @ (k @ U).T == q @ (U @ U.T) @ k.T
+```
+
+This preserves the function while keeping score/cache key width at `d_k`
+instead of `rank`. The inference path caches `U @ U.T`; the training path
+recomputes it so gradients through `U` remain exact.
+
+Verification:
+
+```bash
+python /workspace/agillm-4/verify_m_fold_agillm4.py \
+  --presets pico_1x,micro_3x \
+  --backends manual,sdpa
+```
+
+See `N1_HARVEST.md` for the staged port order.
+
 ## Intelligence per FLOP
 
 Compute reduction is not enough by itself. AGILLM-4 should spend saved FLOPs on
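The folded form rests entirely on the identity above, which is easy to sanity-check outside the trainer. A minimal standalone sketch (shapes illustrative, float64 so the comparison is tight):

```python
# Check (q @ U) @ (k @ U).T == q @ (U @ U.T) @ k.T for an expansion rank r > d_k.
import torch

torch.manual_seed(0)
d_k, rank, n = 16, 48, 10
q = torch.randn(n, d_k, dtype=torch.float64)
k = torch.randn(n, d_k, dtype=torch.float64)
U = torch.randn(d_k, rank, dtype=torch.float64)

expanded = (q @ U) @ (k @ U).T  # scores computed at rank width
folded = q @ (U @ U.T) @ k.T    # same scores, never leaving d_k width
assert torch.allclose(expanded, folded)
```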
N1_HARVEST.md
ADDED
# n1.py Harvest Plan

Goal: move proven, no-quality-loss trainer improvements from
`C:\Users\Scott\Downloads\n1.py` into AGILLM-4 without replacing the AGILLM-4
long-context/model-scale branch.

## Ported

### 1. Exact M-Fold Expansion Attention

Status: done.

For ranks where `rank > d_k`, AGILLM-4 now computes:

```text
(q @ U) @ (k @ U).T == q @ (U @ U.T) @ k.T
```

This keeps attention scores and KV-cache keys at `d_k` width instead of
`rank` width while preserving the exact expanded-attention function. The
training path recomputes `U @ U.T` with gradients, and the inference/no-grad
path caches the metric until `U` changes.

Verification:

```bash
python agillm-4/verify_m_fold_agillm4.py \
  --presets pico_1x,micro_3x \
  --backends manual,sdpa \
  --cached_len 8 \
  --new_len 4
```

The verifier checks forward output, loss, input gradients, parameter gradients,
cached append equivalence, cache key width, and metric-cache invalidation.

## Next Candidates

### 2. Fused QKV Projection

n1 fuses separate `q/k/v` linear layers into one `qkv` linear while keeping
checkpoint compatibility by folding old state-dict keys on load. This should be
the next port after M-fold because it reduces three projection GEMMs to one.
A rough sketch of the key folding follows.

Risk: checkpoint key migration. Keep this separate from the M-fold port.

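The fold at load time could look like this; the key names here are illustrative assumptions, not the actual n1 or AGILLM-4 state-dict layout:

```python
# Hypothetical migration: concatenate legacy q/k/v weights into one fused
# qkv weight. nn.Linear stores (out_features, in_features), so the fused
# projection is a concatenation along dim 0.
import torch

def fold_qkv_keys(state_dict: dict, prefix: str = "attn.") -> dict:
    q_key, k_key, v_key = (f"{prefix}{n}.weight" for n in ("q", "k", "v"))
    if q_key in state_dict:  # legacy checkpoint detected
        state_dict[f"{prefix}qkv.weight"] = torch.cat(
            [state_dict.pop(q_key), state_dict.pop(k_key), state_dict.pop(v_key)],
            dim=0,
        )
    return state_dict
```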
### 3. Combined ALiBi + Mask Cache

n1 pre-folds ALiBi into the mask once per encoder forward instead of rebuilding
the same layer-independent bias in every block.

Risk: cache semantics differ for KV decode, where the ALiBi slice changes.

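A sketch of the pre-fold, assuming the standard ALiBi slope schedule (the real n1 helper may differ):

```python
# Build causal mask + ALiBi bias once per forward; every block adds the same
# tensor to its scores, so there is no reason to rebuild it per layer.
import torch

def combined_alibi_mask(n_heads: int, seq: int, device=None) -> torch.Tensor:
    slopes = torch.tensor(
        [2.0 ** (-8.0 * (i + 1) / n_heads) for i in range(n_heads)], device=device
    )
    pos = torch.arange(seq, device=device)
    rel = pos.view(1, -1) - pos.view(-1, 1)            # k_pos - q_pos
    bias = slopes.view(-1, 1, 1) * rel                 # negative into the past
    causal = torch.where(rel > 0, float("-inf"), 0.0)  # block future keys
    return (bias + causal).unsqueeze(0)                # (1, heads, seq, seq)
```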
### 4. SAT Speculative Inference

n1 has proof-covered SAT-draft / AR-verify speculative decoding. This belongs
in AGILLM-4 after the SFT result tells us whether chat turns are sane.

Risk: inference control flow and cache rollback complexity.

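The control-flow shape, reduced to its core; `draft_tokens` and `verify_logits` are placeholder callables, and n1's actual accept rule and cache rollback are more involved:

```python
# Greedy draft-and-verify step: accept the drafted prefix the AR model agrees
# with, and replace the first disagreement with the AR model's own token.
def speculative_step(draft_tokens, verify_logits, ctx, k: int = 4):
    draft = draft_tokens(ctx, k)        # k proposed token ids (SAT draft)
    logits = verify_logits(ctx, draft)  # AR logits at each drafted position
    out = []
    for i, tok in enumerate(draft):
        ar_tok = int(logits[i].argmax())
        out.append(ar_tok)
        if ar_tok != tok:               # first mismatch: stop here; any cache
            break                       # entries past this point must roll back
    return out
```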
### 5. Compact Checkpoint

n1 can compact `U` spectra post-training and save compatible checkpoints.

Risk: optimizer state must be dropped or remapped carefully; do this only as a
separate command, never during a live training run.

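For the M-fold path specifically, one compaction is exact, since scores depend on `U` only through `U @ U.T`; whether n1 compacts this way is not stated here, so treat this as a sketch:

```python
# Replace a (d_k, rank) expansion matrix with a (d_k, d_k) equivalent:
# U @ U.T == Uu diag(S)^2 Uu.T, so U' = Uu * S reproduces the exact metric.
import torch

def compact_U(U: torch.Tensor) -> torch.Tensor:
    Uu, S, _ = torch.linalg.svd(U, full_matrices=False)
    return Uu * S  # columns scaled by singular values; metric unchanged

U = torch.randn(16, 48, dtype=torch.float64)
assert torch.allclose(U @ U.T, compact_U(U) @ compact_U(U).T)
```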
### 6. KV Cache Buffer

n1 replaces repeated decode-time `torch.cat` cache growth with preallocated KV
buffers.

Risk: cache object type touches AR, SAT, and future spec decoding paths.

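The shape of the change, assuming a `(batch, heads, seq, width)` layout; the real cache object carries more state:

```python
# Preallocate once, then write each decode step's K/V slice in place instead
# of growing the cache with torch.cat every step.
import torch

class KVBuffer:
    def __init__(self, batch: int, heads: int, capacity: int, width: int, **kw):
        self.k = torch.empty(batch, heads, capacity, width, **kw)
        self.v = torch.empty_like(self.k)
        self.len = 0

    def append(self, k_new: torch.Tensor, v_new: torch.Tensor):
        n = k_new.size(2)
        self.k[:, :, self.len : self.len + n] = k_new  # in-place, no realloc
        self.v[:, :, self.len : self.len + n] = v_new
        self.len += n
        return self.k[:, :, : self.len], self.v[:, :, : self.len]
```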
## Rule

Every harvested feature needs its own AGILLM-4 verifier or profile artifact.
Do not rely on n1's proof suite alone after adapting the implementation.
README.md
CHANGED
@@ -19,9 +19,12 @@ and extended for:
 - longer block-size work on 24GB, B200, and B300 class GPUs
 - AR+SAT every step with sequential backward to reduce peak VRAM
 - SDPA and experimental sublinear local+landmark attention backends
+- exact M-fold expansion attention harvested from n1.py, with a local verifier
 - profiling tools for memory, throughput, AR cost, SAT cost, and optimizer cost
 - synthetic long-context curriculum generation for recall and multi-hop tests
 
 Start with [AGILLM-4.md](AGILLM-4.md) for the training plan and command
 recipes. The current sublinear backend is intentionally experimental: profile it
 against SDPA before using it for a real run.
+
+Current harvest status from n1.py is tracked in [N1_HARVEST.md](N1_HARVEST.md).
local_sweep_after_mfold_sdpa.json
ADDED
[
  {
    "alloc_gb": 0.027,
    "amp": false,
    "attn_backend": "sdpa",
    "batch_size": 1,
    "block": 64,
    "elapsed_s": 0.534,
    "error": null,
    "grad_checkpoint": true,
    "loss": 18.8005,
    "ok": true,
    "peak_alloc_gb": 0.031,
    "peak_reserved_gb": 0.057,
    "reserved_gb": 0.057,
    "sublinear_chunk": 128,
    "sublinear_max_anchors": 256,
    "sublinear_stride": 64,
    "sublinear_window": 256,
    "tokens_per_s_synthetic": 119.8
  }
]
local_sweep_after_mfold_sublinear.json
ADDED
[
  {
    "alloc_gb": 0.027,
    "amp": false,
    "attn_backend": "sublinear",
    "batch_size": 1,
    "block": 64,
    "elapsed_s": 0.635,
    "error": null,
    "grad_checkpoint": true,
    "loss": 18.7676,
    "ok": true,
    "peak_alloc_gb": 0.031,
    "peak_reserved_gb": 0.057,
    "reserved_gb": 0.057,
    "sublinear_chunk": 16,
    "sublinear_max_anchors": 16,
    "sublinear_stride": 8,
    "sublinear_window": 16,
    "tokens_per_s_synthetic": 100.8
  }
]
local_verify_m_fold_agillm4.json
ADDED
[
  {
    "backend": "manual",
    "d": 32,
    "dk": 16,
    "expected_k_width": 16,
    "heads": 2,
    "ok": true,
    "preset": "pico_1x",
    "rank": 16,
    "rows": {
      "cache_k_width": 16.0,
      "cache_v_width": 16.0,
      "cached_append_forward": 5.960464477539063e-08,
      "causal_alibi_forward": 0.0,
      "causal_alibi_loss": 0.0,
      "causal_alibi_param_grad": 0.0,
      "causal_alibi_x_grad": 0.0,
      "none_forward": 0.0,
      "none_loss": 0.0,
      "none_param_grad": 0.0,
      "none_x_grad": 0.0,
      "sat_alibi_forward": 0.0,
      "sat_alibi_loss": 0.0,
      "sat_alibi_param_grad": 0.0,
      "sat_alibi_x_grad": 0.0
    },
    "tol": 0.0002
  },
  {
    "backend": "sdpa",
    "d": 32,
    "dk": 16,
    "expected_k_width": 16,
    "heads": 2,
    "ok": true,
    "preset": "pico_1x",
    "rank": 16,
    "rows": {
      "cache_k_width": 16.0,
      "cache_v_width": 16.0,
      "cached_append_forward": 5.960464477539063e-08,
      "causal_alibi_forward": 7.450580596923828e-08,
      "causal_alibi_loss": 0.0,
      "causal_alibi_param_grad": 1.862645149230957e-09,
      "causal_alibi_x_grad": 2.3283064365386963e-10,
      "none_forward": 8.940696716308594e-08,
      "none_loss": 0.0,
      "none_param_grad": 9.313225746154785e-10,
      "none_x_grad": 6.548361852765083e-11,
      "sat_alibi_forward": 1.1920928955078125e-07,
      "sat_alibi_loss": 1.862645149230957e-09,
      "sat_alibi_param_grad": 9.313225746154785e-10,
      "sat_alibi_x_grad": 3.4924596548080444e-10
    },
    "tol": 0.0002
  },
  {
    "backend": "manual",
    "d": 128,
    "dk": 16,
    "expected_k_width": 16,
    "heads": 8,
    "ok": true,
    "preset": "micro_3x",
    "rank": 48,
    "rows": {
      "cache_k_width": 16.0,
      "cache_v_width": 16.0,
      "cached_append_forward": 6.51925802230835e-08,
      "causal_alibi_forward": 5.960464477539063e-08,
      "causal_alibi_loss": 0.0,
      "causal_alibi_param_grad": 4.656612873077393e-10,
      "causal_alibi_x_grad": 5.820766091346741e-11,
      "metric_cache_cleared_on_train": 0.0,
      "metric_cache_reused": 0.0,
      "none_forward": 5.960464477539063e-08,
      "none_loss": 0.0,
      "none_param_grad": 2.3283064365386963e-10,
      "none_x_grad": 2.1827872842550278e-11,
      "sat_alibi_forward": 1.1920928955078125e-07,
      "sat_alibi_loss": 1.862645149230957e-09,
      "sat_alibi_param_grad": 5.820766091346741e-10,
      "sat_alibi_x_grad": 7.275957614183426e-11
    },
    "tol": 0.0002
  },
  {
    "backend": "sdpa",
    "d": 128,
    "dk": 16,
    "expected_k_width": 16,
    "heads": 8,
    "ok": true,
    "preset": "micro_3x",
    "rank": 48,
    "rows": {
      "cache_k_width": 16.0,
      "cache_v_width": 16.0,
      "cached_append_forward": 7.450580596923828e-08,
      "causal_alibi_forward": 1.043081283569336e-07,
      "causal_alibi_loss": 0.0,
      "causal_alibi_param_grad": 4.656612873077393e-10,
      "causal_alibi_x_grad": 9.458744898438454e-11,
      "metric_cache_cleared_on_train": 0.0,
      "metric_cache_reused": 0.0,
      "none_forward": 8.940696716308594e-08,
      "none_loss": 0.0,
      "none_param_grad": 2.3283064365386963e-10,
      "none_x_grad": 2.9103830456733704e-11,
      "sat_alibi_forward": 1.1920928955078125e-07,
      "sat_alibi_loss": 0.0,
      "sat_alibi_param_grad": 6.984919309616089e-10,
      "sat_alibi_x_grad": 1.1641532182693481e-10
    },
    "tol": 0.0002
  }
]
nB300_agillm4.py
CHANGED
@@ -1154,6 +1154,15 @@ class TuneableAttentionMHA(nn.Module):
         nn.init.orthogonal_(self.U)
         self.proj = nn.Linear(h * self.dk, d, bias=False)
         self.drop = nn.Dropout(0.1)
+        # Exact n1 harvest: for expansion ranks, (q @ U) @ (k @ U).T is
+        # q @ (U @ U.T) @ k.T. This keeps score/cache width at d_k with no
+        # quality change. Inference caches the metric and training recomputes
+        # it so gradients through U are unchanged.
+        self._metric_cache: Optional[torch.Tensor] = None
+        self._metric_cache_ver: int = -1
+        self._metric_cache_param_id: int = -1
+        self._metric_cache_data_ptr: int = -1
+        self._metric_cache_shape: Tuple[int, int] = (-1, -1)
 
     def _proj_qk(self, x):
         B, N, _ = x.shape
@@ -1163,6 +1172,44 @@ class TuneableAttentionMHA(nn.Module):
         B, N, _ = x.shape
         return x.view(B, N, self.h, self.dk).transpose(1, 2)
 
+    def _reshape_heads(self, x):
+        B, N, _ = x.shape
+        return x.view(B, N, self.h, self.dk).transpose(1, 2)
+
+    def _get_metric(self) -> torch.Tensor:
+        if torch.is_grad_enabled():
+            return self.U @ self.U.T
+        cur_ver = self.U._version
+        cur_param_id = id(self.U)
+        cur_data_ptr = int(self.U.data_ptr())
+        cur_shape = tuple(self.U.shape)
+        cache = self._metric_cache
+        if (
+            cache is None
+            or cache.dtype != self.U.dtype
+            or cache.device != self.U.device
+            or self._metric_cache_ver != cur_ver
+            or self._metric_cache_param_id != cur_param_id
+            or self._metric_cache_data_ptr != cur_data_ptr
+            or self._metric_cache_shape != cur_shape
+        ):
+            cache = (self.U @ self.U.T).detach()
+            self._metric_cache = cache
+            self._metric_cache_ver = cur_ver
+            self._metric_cache_param_id = cur_param_id
+            self._metric_cache_data_ptr = cur_data_ptr
+            self._metric_cache_shape = cur_shape
+        return cache
+
+    def train(self, mode: bool = True):
+        if mode:
+            self._metric_cache = None
+            self._metric_cache_ver = -1
+            self._metric_cache_param_id = -1
+            self._metric_cache_data_ptr = -1
+            self._metric_cache_shape = (-1, -1)
+        return super().train(mode)
+
     def _sublinear_attention(self, q, k, v, attn_mask=None):
         """Local-window + landmark attention: O(N * (window + N/stride))."""
         bsz, heads, q_len, _ = q.shape
@@ -1226,9 +1273,15 @@ class TuneableAttentionMHA(nn.Module):
         return torch.cat(outputs, dim=2)
 
     def forward(self, x, mask=None, rel_bias_tokens=None, kv_cache=None, use_cache=False):
-        q = self._proj_qk(self.q(x))
-        k_new = self._proj_qk(self.k(x))
+        q_lin = self.q(x)
+        k_lin = self.k(x)
         v_new = self._reshape_v(self.v(x))
+        if self.r > self.dk:
+            q = self._reshape_heads(q_lin) @ self._get_metric()
+            k_new = self._reshape_heads(k_lin)
+        else:
+            q = self._proj_qk(q_lin)
+            k_new = self._proj_qk(k_lin)
         if kv_cache is None:
             k, v = k_new, v_new
         else:
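One reason the cache keys above are sufficient across optimizer steps: in-place parameter updates bump the tensor's `_version` counter, so a post-step forward under `no_grad` recomputes the metric rather than serving a stale one. A quick sketch of that behavior:

```python
# In-place updates (what optimizer steps do) increment _version, one of the
# invalidation keys checked by _get_metric.
import torch

U = torch.nn.Parameter(torch.randn(4, 8))
before = U._version
with torch.no_grad():
    U.add_(1e-3)            # stand-in for an optimizer step
assert U._version > before  # a stale U @ U.T cannot be served
```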
verify_m_fold_agillm4.py
ADDED
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
import math
import os
from pathlib import Path

os.environ.setdefault("AGILLM_SYNTHETIC_TOKENIZER", "1")

import torch

import nB300_agillm4 as nb


def causal_mask_cached(new_len: int, cached_len: int):
    total = cached_len + new_len
    q_pos = torch.arange(cached_len, total, device=nb.DEV).view(new_len, 1)
    k_pos = torch.arange(total, device=nb.DEV).view(1, total)
    mask = torch.where(k_pos > q_pos, float("-inf"), 0.0)
    return mask.view(1, 1, new_len, total)


def old_expanded_forward(mha: nb.TuneableAttentionMHA, x: torch.Tensor, mask=None, rel_bias_tokens=None):
    bsz, seq, _ = x.shape
    q = mha._reshape_heads(mha.q(x)) @ mha.U
    k = mha._reshape_heads(mha.k(x)) @ mha.U
    v = mha._reshape_v(mha.v(x))
    att = (q @ k.transpose(-1, -2)) / math.sqrt(mha.dk)
    if mha.use_relpos and rel_bias_tokens is not None:
        att = att + nb.alibi_bias(mha.h, rel_bias_tokens)[:, :, -seq:, :]
    if mask is not None:
        att = att + mask
    z = (att.softmax(-1) @ v).transpose(1, 2).reshape(bsz, seq, -1)
    return mha.drop(mha.proj(z))


def max_param_grad_diff(mha: nb.TuneableAttentionMHA, ref_grads: dict[str, torch.Tensor]) -> float:
    out = 0.0
    for name, param in mha.named_parameters():
        if param.grad is None:
            continue
        out = max(out, (param.grad.detach() - ref_grads[name]).abs().max().item())
    return out


def verify_case(args, preset: str, backend: str) -> dict:
    torch.manual_seed(args.seed)
    cfg = nb.PRESETS[preset].copy()
    d, h, r = cfg["d"], cfg["heads"], cfg["rank"]
    seq = args.cached_len + args.new_len
    mha = nb.TuneableAttentionMHA(d, h, r, attn_backend=backend).to(nb.DEV).eval()
    rows = {}

    for case_name, mask, rel_tokens in [
        ("none", None, None),
        ("causal_alibi", nb.causal_mask(seq), seq),
        ("sat_alibi", nb.sat_mask(seq), seq),
    ]:
        x_new = torch.randn(2, seq, d, device=nb.DEV, requires_grad=True)
        x_old = x_new.detach().clone().requires_grad_(True)
        y_new = mha(x_new, mask, rel_bias_tokens=rel_tokens)
        y_old = old_expanded_forward(mha, x_old, mask=mask, rel_bias_tokens=rel_tokens)
        loss_new = y_new.square().mean()
        loss_old = y_old.square().mean()
        loss_new.backward()
        new_x_grad = x_new.grad.detach().clone()
        new_param_grads = {
            name: param.grad.detach().clone()
            for name, param in mha.named_parameters()
            if param.grad is not None
        }
        mha.zero_grad(set_to_none=True)
        loss_old.backward()
        old_x_grad = x_old.grad.detach().clone()
        rows[f"{case_name}_forward"] = (y_new - y_old).abs().max().item()
        rows[f"{case_name}_loss"] = abs(loss_new.item() - loss_old.item())
        rows[f"{case_name}_x_grad"] = (new_x_grad - old_x_grad).abs().max().item()
        rows[f"{case_name}_param_grad"] = max_param_grad_diff(mha, new_param_grads)
        mha.zero_grad(set_to_none=True)

    with torch.no_grad():
        prefix = torch.randn(1, args.cached_len, d, device=nb.DEV)
        append = torch.randn(1, args.new_len, d, device=nb.DEV)
        full = torch.cat([prefix, append], dim=1)
        y_full = mha(full, nb.causal_mask(seq), rel_bias_tokens=seq)[:, args.cached_len :]
        _, kvs = mha(prefix, nb.causal_mask(args.cached_len), rel_bias_tokens=args.cached_len, use_cache=True)
        y_cached, kvs2 = mha(
            append,
            causal_mask_cached(args.new_len, args.cached_len),
            rel_bias_tokens=seq,
            kv_cache=kvs,
            use_cache=True,
        )
        k_cached, v_cached = kvs2
        rows["cached_append_forward"] = (y_full - y_cached).abs().max().item()
        rows["cache_k_width"] = float(k_cached.size(-1))
        rows["cache_v_width"] = float(v_cached.size(-1))

        if r > d // h:
            _ = mha(prefix, None, use_cache=False)
            first_cache = mha._metric_cache
            _ = mha(prefix, None, use_cache=False)
            second_cache = mha._metric_cache
            rows["metric_cache_reused"] = 0.0 if (first_cache is second_cache and first_cache is not None) else 1.0
            mha.train(True)
            rows["metric_cache_cleared_on_train"] = 0.0 if mha._metric_cache is None else 1.0

    tol = args.tol
    ok = True
    numeric_rows = {}
    for key, value in rows.items():
        numeric_rows[key] = value
        if key in {"cache_k_width", "cache_v_width"}:
            continue
        ok = ok and value <= tol

    expected_k_width = d // h if r > (d // h) else r
    ok = ok and int(rows["cache_k_width"]) == expected_k_width
    ok = ok and int(rows["cache_v_width"]) == d // h
    return {
        "preset": preset,
        "backend": backend,
        "d": d,
        "heads": h,
        "rank": r,
        "dk": d // h,
        "expected_k_width": expected_k_width,
        "ok": ok,
        "tol": tol,
        "rows": numeric_rows,
    }


def main() -> int:
    parser = argparse.ArgumentParser(description="Verify AGILLM-4 exact M-fold attention harvest from n1.py")
    parser.add_argument("--presets", default="pico_1x,micro_3x")
    parser.add_argument("--backends", default="manual,sdpa")
    parser.add_argument("--cached_len", type=int, default=8)
    parser.add_argument("--new_len", type=int, default=4)
    parser.add_argument("--seed", type=int, default=1234)
    parser.add_argument("--tol", type=float, default=2e-4)
    parser.add_argument("--json_out", default="")
    args = parser.parse_args()

    results = []
    all_ok = True
    for preset in [item.strip() for item in args.presets.split(",") if item.strip()]:
        for backend in [item.strip() for item in args.backends.split(",") if item.strip()]:
            result = verify_case(args, preset, backend)
            results.append(result)
            all_ok = all_ok and result["ok"]
            print(json.dumps(result, sort_keys=True), flush=True)
    if args.json_out:
        Path(args.json_out).write_text(json.dumps(results, indent=2, sort_keys=True), encoding="utf-8")
    return 0 if all_ok else 1


if __name__ == "__main__":
    raise SystemExit(main())