Harvest fused QKV projection from n1
Browse files- AGILLM-4.md +14 -0
- N1_HARVEST.md +33 -5
- README.md +1 -0
- local_sweep_after_qkv_sdpa.json +22 -0
- local_sweep_after_qkv_sublinear.json +22 -0
- local_verify_m_fold_after_qkv_agillm4.json +118 -0
- local_verify_m_fold_after_qkv_fix_agillm4.json +118 -0
- local_verify_qkv_agillm4.json +119 -0
- local_verify_qkv_all_backends_agillm4.json +177 -0
- local_verify_qkv_sublinear_agillm4.json +31 -0
- nB300_agillm4.py +186 -14
- verify_m_fold_agillm4.py +4 -3
- verify_qkv_agillm4.py +246 -0
AGILLM-4.md
CHANGED
|
@@ -167,6 +167,20 @@ python /workspace/agillm-4/verify_m_fold_agillm4.py \
|
|
| 167 |
--backends manual,sdpa
|
| 168 |
```
|
| 169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
See `N1_HARVEST.md` for the staged port order.
|
| 171 |
|
| 172 |
## Intelligence per FLOP
|
|
|
|
| 167 |
--backends manual,sdpa
|
| 168 |
```
|
| 169 |
|
| 170 |
+
Second harvested feature: fused QKV projection. AGILLM-4 now uses one
|
| 171 |
+
`Linear(d, 3d)` and chunks the result instead of running three separate
|
| 172 |
+
`Linear(d, d)` projections. Legacy `q.weight/k.weight/v.weight` checkpoints
|
| 173 |
+
load into the fused `qkv.weight` layout, and matching legacy AdamW moments are
|
| 174 |
+
remapped on full resume.
|
| 175 |
+
|
| 176 |
+
Verification:
|
| 177 |
+
|
| 178 |
+
```bash
|
| 179 |
+
python /workspace/agillm-4/verify_qkv_agillm4.py \
|
| 180 |
+
--presets pico_1x,micro_3x \
|
| 181 |
+
--backends manual,sdpa,sublinear
|
| 182 |
+
```
|
| 183 |
+
|
| 184 |
See `N1_HARVEST.md` for the staged port order.
|
| 185 |
|
| 186 |
## Intelligence per FLOP
|
N1_HARVEST.md
CHANGED
|
@@ -34,15 +34,43 @@ python agillm-4/verify_m_fold_agillm4.py \
|
|
| 34 |
The verifier checks forward output, loss, input gradients, parameter gradients,
|
| 35 |
cached append equivalence, cache key width, and metric-cache invalidation.
|
| 36 |
|
| 37 |
-
## Next Candidates
|
| 38 |
-
|
| 39 |
### 2. Fused QKV Projection
|
| 40 |
|
|
|
|
|
|
|
| 41 |
n1 fuses separate `q/k/v` linear layers into one `qkv` linear while keeping
|
| 42 |
-
checkpoint compatibility by folding old state-dict keys on load.
|
| 43 |
-
the
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
### 3. Combined ALiBi + Mask Cache
|
| 48 |
|
|
|
|
| 34 |
The verifier checks forward output, loss, input gradients, parameter gradients,
|
| 35 |
cached append equivalence, cache key width, and metric-cache invalidation.
|
| 36 |
|
|
|
|
|
|
|
| 37 |
### 2. Fused QKV Projection
|
| 38 |
|
| 39 |
+
Status: done.
|
| 40 |
+
|
| 41 |
n1 fuses separate `q/k/v` linear layers into one `qkv` linear while keeping
|
| 42 |
+
checkpoint compatibility by folding old state-dict keys on load. AGILLM-4 now
|
| 43 |
+
does the same. The parameter count and function are unchanged:
|
| 44 |
+
|
| 45 |
+
```text
|
| 46 |
+
[x Wq.T, x Wk.T, x Wv.T] == split(x [Wq; Wk; Wv].T)
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
Checkpoint compatibility:
|
| 50 |
|
| 51 |
+
- legacy `*.q.weight`, `*.k.weight`, `*.v.weight` triples load into
|
| 52 |
+
`*.qkv.weight`
|
| 53 |
+
- warm-start shape filtering fuses legacy triples before filtering
|
| 54 |
+
- legacy AdamW q/k/v moment tensors are concatenated into qkv optimizer state
|
| 55 |
+
when a full resume can be proven to match the old parameter layout
|
| 56 |
+
- if optimizer remap cannot be proven, model weights still load and optimizer
|
| 57 |
+
state is reset with a warning
|
| 58 |
+
|
| 59 |
+
Verification:
|
| 60 |
+
|
| 61 |
+
```bash
|
| 62 |
+
python agillm-4/verify_qkv_agillm4.py \
|
| 63 |
+
--presets pico_1x,micro_3x \
|
| 64 |
+
--backends manual,sdpa,sublinear \
|
| 65 |
+
--cached_len 8 \
|
| 66 |
+
--new_len 4
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
The verifier checks fused-vs-unfused forward output, loss, input gradients,
|
| 70 |
+
parameter gradients, strict legacy state-dict loading, `_safe_load_any`
|
| 71 |
+
warm-start loading, and optimizer-state remap.
|
| 72 |
+
|
| 73 |
+
## Next Candidates
|
| 74 |
|
| 75 |
### 3. Combined ALiBi + Mask Cache
|
| 76 |
|
README.md
CHANGED
|
@@ -20,6 +20,7 @@ and extended for:
|
|
| 20 |
- AR+SAT every step with sequential backward to reduce peak VRAM
|
| 21 |
- SDPA and experimental sublinear local+landmark attention backends
|
| 22 |
- exact M-fold expansion attention harvested from n1.py, with local verifier
|
|
|
|
| 23 |
- profiling tools for memory, throughput, AR cost, SAT cost, and optimizer cost
|
| 24 |
- synthetic long-context curriculum generation for recall and multi-hop tests
|
| 25 |
|
|
|
|
| 20 |
- AR+SAT every step with sequential backward to reduce peak VRAM
|
| 21 |
- SDPA and experimental sublinear local+landmark attention backends
|
| 22 |
- exact M-fold expansion attention harvested from n1.py, with local verifier
|
| 23 |
+
- fused QKV projection harvested from n1.py, with legacy checkpoint loading
|
| 24 |
- profiling tools for memory, throughput, AR cost, SAT cost, and optimizer cost
|
| 25 |
- synthetic long-context curriculum generation for recall and multi-hop tests
|
| 26 |
|
local_sweep_after_qkv_sdpa.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"alloc_gb": 0.027,
|
| 4 |
+
"amp": false,
|
| 5 |
+
"attn_backend": "sdpa",
|
| 6 |
+
"batch_size": 1,
|
| 7 |
+
"block": 64,
|
| 8 |
+
"elapsed_s": 0.538,
|
| 9 |
+
"error": null,
|
| 10 |
+
"grad_checkpoint": true,
|
| 11 |
+
"loss": 18.4698,
|
| 12 |
+
"ok": true,
|
| 13 |
+
"peak_alloc_gb": 0.031,
|
| 14 |
+
"peak_reserved_gb": 0.057,
|
| 15 |
+
"reserved_gb": 0.057,
|
| 16 |
+
"sublinear_chunk": 128,
|
| 17 |
+
"sublinear_max_anchors": 256,
|
| 18 |
+
"sublinear_stride": 64,
|
| 19 |
+
"sublinear_window": 256,
|
| 20 |
+
"tokens_per_s_synthetic": 119.1
|
| 21 |
+
}
|
| 22 |
+
]
|
local_sweep_after_qkv_sublinear.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"alloc_gb": 0.027,
|
| 4 |
+
"amp": false,
|
| 5 |
+
"attn_backend": "sublinear",
|
| 6 |
+
"batch_size": 1,
|
| 7 |
+
"block": 64,
|
| 8 |
+
"elapsed_s": 0.698,
|
| 9 |
+
"error": null,
|
| 10 |
+
"grad_checkpoint": true,
|
| 11 |
+
"loss": 18.251,
|
| 12 |
+
"ok": true,
|
| 13 |
+
"peak_alloc_gb": 0.031,
|
| 14 |
+
"peak_reserved_gb": 0.057,
|
| 15 |
+
"reserved_gb": 0.057,
|
| 16 |
+
"sublinear_chunk": 16,
|
| 17 |
+
"sublinear_max_anchors": 16,
|
| 18 |
+
"sublinear_stride": 8,
|
| 19 |
+
"sublinear_window": 16,
|
| 20 |
+
"tokens_per_s_synthetic": 91.7
|
| 21 |
+
}
|
| 22 |
+
]
|
local_verify_m_fold_after_qkv_agillm4.json
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"backend": "manual",
|
| 4 |
+
"d": 32,
|
| 5 |
+
"dk": 16,
|
| 6 |
+
"expected_k_width": 16,
|
| 7 |
+
"heads": 2,
|
| 8 |
+
"ok": true,
|
| 9 |
+
"preset": "pico_1x",
|
| 10 |
+
"rank": 16,
|
| 11 |
+
"rows": {
|
| 12 |
+
"cache_k_width": 16.0,
|
| 13 |
+
"cache_v_width": 16.0,
|
| 14 |
+
"cached_append_forward": 5.960464477539063e-08,
|
| 15 |
+
"causal_alibi_forward": 0.0,
|
| 16 |
+
"causal_alibi_loss": 0.0,
|
| 17 |
+
"causal_alibi_param_grad": 0.0,
|
| 18 |
+
"causal_alibi_x_grad": 0.0,
|
| 19 |
+
"none_forward": 0.0,
|
| 20 |
+
"none_loss": 0.0,
|
| 21 |
+
"none_param_grad": 0.0,
|
| 22 |
+
"none_x_grad": 0.0,
|
| 23 |
+
"sat_alibi_forward": 0.0,
|
| 24 |
+
"sat_alibi_loss": 0.0,
|
| 25 |
+
"sat_alibi_param_grad": 0.0,
|
| 26 |
+
"sat_alibi_x_grad": 0.0
|
| 27 |
+
},
|
| 28 |
+
"tol": 0.0002
|
| 29 |
+
},
|
| 30 |
+
{
|
| 31 |
+
"backend": "sdpa",
|
| 32 |
+
"d": 32,
|
| 33 |
+
"dk": 16,
|
| 34 |
+
"expected_k_width": 16,
|
| 35 |
+
"heads": 2,
|
| 36 |
+
"ok": true,
|
| 37 |
+
"preset": "pico_1x",
|
| 38 |
+
"rank": 16,
|
| 39 |
+
"rows": {
|
| 40 |
+
"cache_k_width": 16.0,
|
| 41 |
+
"cache_v_width": 16.0,
|
| 42 |
+
"cached_append_forward": 5.960464477539063e-08,
|
| 43 |
+
"causal_alibi_forward": 7.450580596923828e-08,
|
| 44 |
+
"causal_alibi_loss": 0.0,
|
| 45 |
+
"causal_alibi_param_grad": 1.862645149230957e-09,
|
| 46 |
+
"causal_alibi_x_grad": 2.3283064365386963e-10,
|
| 47 |
+
"none_forward": 8.940696716308594e-08,
|
| 48 |
+
"none_loss": 0.0,
|
| 49 |
+
"none_param_grad": 9.313225746154785e-10,
|
| 50 |
+
"none_x_grad": 8.731149137020111e-11,
|
| 51 |
+
"sat_alibi_forward": 1.1920928955078125e-07,
|
| 52 |
+
"sat_alibi_loss": 1.862645149230957e-09,
|
| 53 |
+
"sat_alibi_param_grad": 9.313225746154785e-10,
|
| 54 |
+
"sat_alibi_x_grad": 2.3283064365386963e-10
|
| 55 |
+
},
|
| 56 |
+
"tol": 0.0002
|
| 57 |
+
},
|
| 58 |
+
{
|
| 59 |
+
"backend": "manual",
|
| 60 |
+
"d": 128,
|
| 61 |
+
"dk": 16,
|
| 62 |
+
"expected_k_width": 16,
|
| 63 |
+
"heads": 8,
|
| 64 |
+
"ok": true,
|
| 65 |
+
"preset": "micro_3x",
|
| 66 |
+
"rank": 48,
|
| 67 |
+
"rows": {
|
| 68 |
+
"cache_k_width": 16.0,
|
| 69 |
+
"cache_v_width": 16.0,
|
| 70 |
+
"cached_append_forward": 6.51925802230835e-08,
|
| 71 |
+
"causal_alibi_forward": 5.960464477539063e-08,
|
| 72 |
+
"causal_alibi_loss": 0.0,
|
| 73 |
+
"causal_alibi_param_grad": 4.656612873077393e-10,
|
| 74 |
+
"causal_alibi_x_grad": 5.820766091346741e-11,
|
| 75 |
+
"metric_cache_cleared_on_train": 0.0,
|
| 76 |
+
"metric_cache_reused": 0.0,
|
| 77 |
+
"none_forward": 5.960464477539063e-08,
|
| 78 |
+
"none_loss": 0.0,
|
| 79 |
+
"none_param_grad": 2.3283064365386963e-10,
|
| 80 |
+
"none_x_grad": 1.4551915228366852e-11,
|
| 81 |
+
"sat_alibi_forward": 1.1920928955078125e-07,
|
| 82 |
+
"sat_alibi_loss": 1.862645149230957e-09,
|
| 83 |
+
"sat_alibi_param_grad": 5.820766091346741e-10,
|
| 84 |
+
"sat_alibi_x_grad": 5.820766091346741e-11
|
| 85 |
+
},
|
| 86 |
+
"tol": 0.0002
|
| 87 |
+
},
|
| 88 |
+
{
|
| 89 |
+
"backend": "sdpa",
|
| 90 |
+
"d": 128,
|
| 91 |
+
"dk": 16,
|
| 92 |
+
"expected_k_width": 16,
|
| 93 |
+
"heads": 8,
|
| 94 |
+
"ok": true,
|
| 95 |
+
"preset": "micro_3x",
|
| 96 |
+
"rank": 48,
|
| 97 |
+
"rows": {
|
| 98 |
+
"cache_k_width": 16.0,
|
| 99 |
+
"cache_v_width": 16.0,
|
| 100 |
+
"cached_append_forward": 7.450580596923828e-08,
|
| 101 |
+
"causal_alibi_forward": 1.043081283569336e-07,
|
| 102 |
+
"causal_alibi_loss": 0.0,
|
| 103 |
+
"causal_alibi_param_grad": 4.656612873077393e-10,
|
| 104 |
+
"causal_alibi_x_grad": 8.731149137020111e-11,
|
| 105 |
+
"metric_cache_cleared_on_train": 0.0,
|
| 106 |
+
"metric_cache_reused": 0.0,
|
| 107 |
+
"none_forward": 8.940696716308594e-08,
|
| 108 |
+
"none_loss": 0.0,
|
| 109 |
+
"none_param_grad": 2.3283064365386963e-10,
|
| 110 |
+
"none_x_grad": 1.8189894035458565e-11,
|
| 111 |
+
"sat_alibi_forward": 1.1920928955078125e-07,
|
| 112 |
+
"sat_alibi_loss": 0.0,
|
| 113 |
+
"sat_alibi_param_grad": 6.984919309616089e-10,
|
| 114 |
+
"sat_alibi_x_grad": 7.275957614183426e-11
|
| 115 |
+
},
|
| 116 |
+
"tol": 0.0002
|
| 117 |
+
}
|
| 118 |
+
]
|
local_verify_m_fold_after_qkv_fix_agillm4.json
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"backend": "manual",
|
| 4 |
+
"d": 32,
|
| 5 |
+
"dk": 16,
|
| 6 |
+
"expected_k_width": 16,
|
| 7 |
+
"heads": 2,
|
| 8 |
+
"ok": true,
|
| 9 |
+
"preset": "pico_1x",
|
| 10 |
+
"rank": 16,
|
| 11 |
+
"rows": {
|
| 12 |
+
"cache_k_width": 16.0,
|
| 13 |
+
"cache_v_width": 16.0,
|
| 14 |
+
"cached_append_forward": 5.960464477539063e-08,
|
| 15 |
+
"causal_alibi_forward": 0.0,
|
| 16 |
+
"causal_alibi_loss": 0.0,
|
| 17 |
+
"causal_alibi_param_grad": 0.0,
|
| 18 |
+
"causal_alibi_x_grad": 0.0,
|
| 19 |
+
"none_forward": 0.0,
|
| 20 |
+
"none_loss": 0.0,
|
| 21 |
+
"none_param_grad": 0.0,
|
| 22 |
+
"none_x_grad": 0.0,
|
| 23 |
+
"sat_alibi_forward": 0.0,
|
| 24 |
+
"sat_alibi_loss": 0.0,
|
| 25 |
+
"sat_alibi_param_grad": 0.0,
|
| 26 |
+
"sat_alibi_x_grad": 0.0
|
| 27 |
+
},
|
| 28 |
+
"tol": 0.0002
|
| 29 |
+
},
|
| 30 |
+
{
|
| 31 |
+
"backend": "sdpa",
|
| 32 |
+
"d": 32,
|
| 33 |
+
"dk": 16,
|
| 34 |
+
"expected_k_width": 16,
|
| 35 |
+
"heads": 2,
|
| 36 |
+
"ok": true,
|
| 37 |
+
"preset": "pico_1x",
|
| 38 |
+
"rank": 16,
|
| 39 |
+
"rows": {
|
| 40 |
+
"cache_k_width": 16.0,
|
| 41 |
+
"cache_v_width": 16.0,
|
| 42 |
+
"cached_append_forward": 5.960464477539063e-08,
|
| 43 |
+
"causal_alibi_forward": 7.450580596923828e-08,
|
| 44 |
+
"causal_alibi_loss": 0.0,
|
| 45 |
+
"causal_alibi_param_grad": 1.862645149230957e-09,
|
| 46 |
+
"causal_alibi_x_grad": 2.3283064365386963e-10,
|
| 47 |
+
"none_forward": 8.940696716308594e-08,
|
| 48 |
+
"none_loss": 0.0,
|
| 49 |
+
"none_param_grad": 9.313225746154785e-10,
|
| 50 |
+
"none_x_grad": 8.731149137020111e-11,
|
| 51 |
+
"sat_alibi_forward": 1.1920928955078125e-07,
|
| 52 |
+
"sat_alibi_loss": 1.862645149230957e-09,
|
| 53 |
+
"sat_alibi_param_grad": 9.313225746154785e-10,
|
| 54 |
+
"sat_alibi_x_grad": 2.3283064365386963e-10
|
| 55 |
+
},
|
| 56 |
+
"tol": 0.0002
|
| 57 |
+
},
|
| 58 |
+
{
|
| 59 |
+
"backend": "manual",
|
| 60 |
+
"d": 128,
|
| 61 |
+
"dk": 16,
|
| 62 |
+
"expected_k_width": 16,
|
| 63 |
+
"heads": 8,
|
| 64 |
+
"ok": true,
|
| 65 |
+
"preset": "micro_3x",
|
| 66 |
+
"rank": 48,
|
| 67 |
+
"rows": {
|
| 68 |
+
"cache_k_width": 16.0,
|
| 69 |
+
"cache_v_width": 16.0,
|
| 70 |
+
"cached_append_forward": 6.51925802230835e-08,
|
| 71 |
+
"causal_alibi_forward": 5.960464477539063e-08,
|
| 72 |
+
"causal_alibi_loss": 0.0,
|
| 73 |
+
"causal_alibi_param_grad": 4.656612873077393e-10,
|
| 74 |
+
"causal_alibi_x_grad": 5.820766091346741e-11,
|
| 75 |
+
"metric_cache_cleared_on_train": 0.0,
|
| 76 |
+
"metric_cache_reused": 0.0,
|
| 77 |
+
"none_forward": 5.960464477539063e-08,
|
| 78 |
+
"none_loss": 0.0,
|
| 79 |
+
"none_param_grad": 2.3283064365386963e-10,
|
| 80 |
+
"none_x_grad": 1.4551915228366852e-11,
|
| 81 |
+
"sat_alibi_forward": 1.1920928955078125e-07,
|
| 82 |
+
"sat_alibi_loss": 1.862645149230957e-09,
|
| 83 |
+
"sat_alibi_param_grad": 5.820766091346741e-10,
|
| 84 |
+
"sat_alibi_x_grad": 5.820766091346741e-11
|
| 85 |
+
},
|
| 86 |
+
"tol": 0.0002
|
| 87 |
+
},
|
| 88 |
+
{
|
| 89 |
+
"backend": "sdpa",
|
| 90 |
+
"d": 128,
|
| 91 |
+
"dk": 16,
|
| 92 |
+
"expected_k_width": 16,
|
| 93 |
+
"heads": 8,
|
| 94 |
+
"ok": true,
|
| 95 |
+
"preset": "micro_3x",
|
| 96 |
+
"rank": 48,
|
| 97 |
+
"rows": {
|
| 98 |
+
"cache_k_width": 16.0,
|
| 99 |
+
"cache_v_width": 16.0,
|
| 100 |
+
"cached_append_forward": 7.450580596923828e-08,
|
| 101 |
+
"causal_alibi_forward": 1.043081283569336e-07,
|
| 102 |
+
"causal_alibi_loss": 0.0,
|
| 103 |
+
"causal_alibi_param_grad": 4.656612873077393e-10,
|
| 104 |
+
"causal_alibi_x_grad": 8.731149137020111e-11,
|
| 105 |
+
"metric_cache_cleared_on_train": 0.0,
|
| 106 |
+
"metric_cache_reused": 0.0,
|
| 107 |
+
"none_forward": 8.940696716308594e-08,
|
| 108 |
+
"none_loss": 0.0,
|
| 109 |
+
"none_param_grad": 2.3283064365386963e-10,
|
| 110 |
+
"none_x_grad": 1.8189894035458565e-11,
|
| 111 |
+
"sat_alibi_forward": 1.1920928955078125e-07,
|
| 112 |
+
"sat_alibi_loss": 0.0,
|
| 113 |
+
"sat_alibi_param_grad": 6.984919309616089e-10,
|
| 114 |
+
"sat_alibi_x_grad": 7.275957614183426e-11
|
| 115 |
+
},
|
| 116 |
+
"tol": 0.0002
|
| 117 |
+
}
|
| 118 |
+
]
|
local_verify_qkv_agillm4.json
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"backend": "manual",
|
| 4 |
+
"d": 32,
|
| 5 |
+
"dk": 16,
|
| 6 |
+
"heads": 2,
|
| 7 |
+
"ok": true,
|
| 8 |
+
"preset": "pico_1x",
|
| 9 |
+
"rank": 16,
|
| 10 |
+
"rows": {
|
| 11 |
+
"causal_alibi_forward": 0.0,
|
| 12 |
+
"causal_alibi_loss": 0.0,
|
| 13 |
+
"causal_alibi_param_grad": 0.0,
|
| 14 |
+
"causal_alibi_x_grad": 2.3283064365386963e-10,
|
| 15 |
+
"legacy_load_forward": 0.0,
|
| 16 |
+
"legacy_load_missing_unexpected": 0.0,
|
| 17 |
+
"legacy_load_qkv_weight": 0.0,
|
| 18 |
+
"none_forward": 0.0,
|
| 19 |
+
"none_loss": 0.0,
|
| 20 |
+
"none_param_grad": 0.0,
|
| 21 |
+
"none_x_grad": 5.820766091346741e-11,
|
| 22 |
+
"optimizer_remap": 0.0,
|
| 23 |
+
"safe_load_any_loaded": 0.0,
|
| 24 |
+
"safe_load_any_qkv": 0.0,
|
| 25 |
+
"sat_alibi_forward": 0.0,
|
| 26 |
+
"sat_alibi_loss": 0.0,
|
| 27 |
+
"sat_alibi_param_grad": 0.0,
|
| 28 |
+
"sat_alibi_x_grad": 2.3283064365386963e-10
|
| 29 |
+
},
|
| 30 |
+
"tol": 0.0002
|
| 31 |
+
},
|
| 32 |
+
{
|
| 33 |
+
"backend": "sdpa",
|
| 34 |
+
"d": 32,
|
| 35 |
+
"dk": 16,
|
| 36 |
+
"heads": 2,
|
| 37 |
+
"ok": true,
|
| 38 |
+
"preset": "pico_1x",
|
| 39 |
+
"rank": 16,
|
| 40 |
+
"rows": {
|
| 41 |
+
"causal_alibi_forward": 0.0,
|
| 42 |
+
"causal_alibi_loss": 0.0,
|
| 43 |
+
"causal_alibi_param_grad": 0.0,
|
| 44 |
+
"causal_alibi_x_grad": 2.3283064365386963e-10,
|
| 45 |
+
"legacy_load_forward": 0.0,
|
| 46 |
+
"legacy_load_missing_unexpected": 0.0,
|
| 47 |
+
"legacy_load_qkv_weight": 0.0,
|
| 48 |
+
"none_forward": 0.0,
|
| 49 |
+
"none_loss": 0.0,
|
| 50 |
+
"none_param_grad": 0.0,
|
| 51 |
+
"none_x_grad": 5.820766091346741e-11,
|
| 52 |
+
"safe_load_any_loaded": 0.0,
|
| 53 |
+
"safe_load_any_qkv": 0.0,
|
| 54 |
+
"sat_alibi_forward": 0.0,
|
| 55 |
+
"sat_alibi_loss": 0.0,
|
| 56 |
+
"sat_alibi_param_grad": 0.0,
|
| 57 |
+
"sat_alibi_x_grad": 2.3283064365386963e-10
|
| 58 |
+
},
|
| 59 |
+
"tol": 0.0002
|
| 60 |
+
},
|
| 61 |
+
{
|
| 62 |
+
"backend": "manual",
|
| 63 |
+
"d": 128,
|
| 64 |
+
"dk": 16,
|
| 65 |
+
"heads": 8,
|
| 66 |
+
"ok": true,
|
| 67 |
+
"preset": "micro_3x",
|
| 68 |
+
"rank": 48,
|
| 69 |
+
"rows": {
|
| 70 |
+
"causal_alibi_forward": 0.0,
|
| 71 |
+
"causal_alibi_loss": 0.0,
|
| 72 |
+
"causal_alibi_param_grad": 0.0,
|
| 73 |
+
"causal_alibi_x_grad": 5.820766091346741e-11,
|
| 74 |
+
"legacy_load_forward": 0.0,
|
| 75 |
+
"legacy_load_missing_unexpected": 0.0,
|
| 76 |
+
"legacy_load_qkv_weight": 0.0,
|
| 77 |
+
"none_forward": 0.0,
|
| 78 |
+
"none_loss": 0.0,
|
| 79 |
+
"none_param_grad": 0.0,
|
| 80 |
+
"none_x_grad": 1.4551915228366852e-11,
|
| 81 |
+
"safe_load_any_loaded": 0.0,
|
| 82 |
+
"safe_load_any_qkv": 0.0,
|
| 83 |
+
"sat_alibi_forward": 0.0,
|
| 84 |
+
"sat_alibi_loss": 0.0,
|
| 85 |
+
"sat_alibi_param_grad": 0.0,
|
| 86 |
+
"sat_alibi_x_grad": 5.820766091346741e-11
|
| 87 |
+
},
|
| 88 |
+
"tol": 0.0002
|
| 89 |
+
},
|
| 90 |
+
{
|
| 91 |
+
"backend": "sdpa",
|
| 92 |
+
"d": 128,
|
| 93 |
+
"dk": 16,
|
| 94 |
+
"heads": 8,
|
| 95 |
+
"ok": true,
|
| 96 |
+
"preset": "micro_3x",
|
| 97 |
+
"rank": 48,
|
| 98 |
+
"rows": {
|
| 99 |
+
"causal_alibi_forward": 0.0,
|
| 100 |
+
"causal_alibi_loss": 0.0,
|
| 101 |
+
"causal_alibi_param_grad": 0.0,
|
| 102 |
+
"causal_alibi_x_grad": 5.820766091346741e-11,
|
| 103 |
+
"legacy_load_forward": 0.0,
|
| 104 |
+
"legacy_load_missing_unexpected": 0.0,
|
| 105 |
+
"legacy_load_qkv_weight": 0.0,
|
| 106 |
+
"none_forward": 0.0,
|
| 107 |
+
"none_loss": 0.0,
|
| 108 |
+
"none_param_grad": 0.0,
|
| 109 |
+
"none_x_grad": 1.4551915228366852e-11,
|
| 110 |
+
"safe_load_any_loaded": 0.0,
|
| 111 |
+
"safe_load_any_qkv": 0.0,
|
| 112 |
+
"sat_alibi_forward": 0.0,
|
| 113 |
+
"sat_alibi_loss": 0.0,
|
| 114 |
+
"sat_alibi_param_grad": 0.0,
|
| 115 |
+
"sat_alibi_x_grad": 5.820766091346741e-11
|
| 116 |
+
},
|
| 117 |
+
"tol": 0.0002
|
| 118 |
+
}
|
| 119 |
+
]
|
local_verify_qkv_all_backends_agillm4.json
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"backend": "manual",
|
| 4 |
+
"d": 32,
|
| 5 |
+
"dk": 16,
|
| 6 |
+
"heads": 2,
|
| 7 |
+
"ok": true,
|
| 8 |
+
"preset": "pico_1x",
|
| 9 |
+
"rank": 16,
|
| 10 |
+
"rows": {
|
| 11 |
+
"causal_alibi_forward": 0.0,
|
| 12 |
+
"causal_alibi_loss": 0.0,
|
| 13 |
+
"causal_alibi_param_grad": 0.0,
|
| 14 |
+
"causal_alibi_x_grad": 2.3283064365386963e-10,
|
| 15 |
+
"legacy_load_forward": 0.0,
|
| 16 |
+
"legacy_load_missing_unexpected": 0.0,
|
| 17 |
+
"legacy_load_qkv_weight": 0.0,
|
| 18 |
+
"none_forward": 0.0,
|
| 19 |
+
"none_loss": 0.0,
|
| 20 |
+
"none_param_grad": 0.0,
|
| 21 |
+
"none_x_grad": 5.820766091346741e-11,
|
| 22 |
+
"optimizer_remap": 0.0,
|
| 23 |
+
"safe_load_any_loaded": 0.0,
|
| 24 |
+
"safe_load_any_qkv": 0.0,
|
| 25 |
+
"sat_alibi_forward": 0.0,
|
| 26 |
+
"sat_alibi_loss": 0.0,
|
| 27 |
+
"sat_alibi_param_grad": 0.0,
|
| 28 |
+
"sat_alibi_x_grad": 2.3283064365386963e-10
|
| 29 |
+
},
|
| 30 |
+
"tol": 0.0002
|
| 31 |
+
},
|
| 32 |
+
{
|
| 33 |
+
"backend": "sdpa",
|
| 34 |
+
"d": 32,
|
| 35 |
+
"dk": 16,
|
| 36 |
+
"heads": 2,
|
| 37 |
+
"ok": true,
|
| 38 |
+
"preset": "pico_1x",
|
| 39 |
+
"rank": 16,
|
| 40 |
+
"rows": {
|
| 41 |
+
"causal_alibi_forward": 0.0,
|
| 42 |
+
"causal_alibi_loss": 0.0,
|
| 43 |
+
"causal_alibi_param_grad": 0.0,
|
| 44 |
+
"causal_alibi_x_grad": 2.3283064365386963e-10,
|
| 45 |
+
"legacy_load_forward": 0.0,
|
| 46 |
+
"legacy_load_missing_unexpected": 0.0,
|
| 47 |
+
"legacy_load_qkv_weight": 0.0,
|
| 48 |
+
"none_forward": 0.0,
|
| 49 |
+
"none_loss": 0.0,
|
| 50 |
+
"none_param_grad": 0.0,
|
| 51 |
+
"none_x_grad": 5.820766091346741e-11,
|
| 52 |
+
"safe_load_any_loaded": 0.0,
|
| 53 |
+
"safe_load_any_qkv": 0.0,
|
| 54 |
+
"sat_alibi_forward": 0.0,
|
| 55 |
+
"sat_alibi_loss": 0.0,
|
| 56 |
+
"sat_alibi_param_grad": 0.0,
|
| 57 |
+
"sat_alibi_x_grad": 2.3283064365386963e-10
|
| 58 |
+
},
|
| 59 |
+
"tol": 0.0002
|
| 60 |
+
},
|
| 61 |
+
{
|
| 62 |
+
"backend": "sublinear",
|
| 63 |
+
"d": 32,
|
| 64 |
+
"dk": 16,
|
| 65 |
+
"heads": 2,
|
| 66 |
+
"ok": true,
|
| 67 |
+
"preset": "pico_1x",
|
| 68 |
+
"rank": 16,
|
| 69 |
+
"rows": {
|
| 70 |
+
"causal_alibi_forward": 0.0,
|
| 71 |
+
"causal_alibi_loss": 0.0,
|
| 72 |
+
"causal_alibi_param_grad": 0.0,
|
| 73 |
+
"causal_alibi_x_grad": 2.3283064365386963e-10,
|
| 74 |
+
"legacy_load_forward": 0.0,
|
| 75 |
+
"legacy_load_missing_unexpected": 0.0,
|
| 76 |
+
"legacy_load_qkv_weight": 0.0,
|
| 77 |
+
"none_forward": 0.0,
|
| 78 |
+
"none_loss": 0.0,
|
| 79 |
+
"none_param_grad": 0.0,
|
| 80 |
+
"none_x_grad": 5.820766091346741e-11,
|
| 81 |
+
"safe_load_any_loaded": 0.0,
|
| 82 |
+
"safe_load_any_qkv": 0.0,
|
| 83 |
+
"sat_alibi_forward": 0.0,
|
| 84 |
+
"sat_alibi_loss": 0.0,
|
| 85 |
+
"sat_alibi_param_grad": 0.0,
|
| 86 |
+
"sat_alibi_x_grad": 2.3283064365386963e-10
|
| 87 |
+
},
|
| 88 |
+
"tol": 0.0002
|
| 89 |
+
},
|
| 90 |
+
{
|
| 91 |
+
"backend": "manual",
|
| 92 |
+
"d": 128,
|
| 93 |
+
"dk": 16,
|
| 94 |
+
"heads": 8,
|
| 95 |
+
"ok": true,
|
| 96 |
+
"preset": "micro_3x",
|
| 97 |
+
"rank": 48,
|
| 98 |
+
"rows": {
|
| 99 |
+
"causal_alibi_forward": 0.0,
|
| 100 |
+
"causal_alibi_loss": 0.0,
|
| 101 |
+
"causal_alibi_param_grad": 0.0,
|
| 102 |
+
"causal_alibi_x_grad": 5.820766091346741e-11,
|
| 103 |
+
"legacy_load_forward": 0.0,
|
| 104 |
+
"legacy_load_missing_unexpected": 0.0,
|
| 105 |
+
"legacy_load_qkv_weight": 0.0,
|
| 106 |
+
"none_forward": 0.0,
|
| 107 |
+
"none_loss": 0.0,
|
| 108 |
+
"none_param_grad": 0.0,
|
| 109 |
+
"none_x_grad": 1.4551915228366852e-11,
|
| 110 |
+
"safe_load_any_loaded": 0.0,
|
| 111 |
+
"safe_load_any_qkv": 0.0,
|
| 112 |
+
"sat_alibi_forward": 0.0,
|
| 113 |
+
"sat_alibi_loss": 0.0,
|
| 114 |
+
"sat_alibi_param_grad": 0.0,
|
| 115 |
+
"sat_alibi_x_grad": 5.820766091346741e-11
|
| 116 |
+
},
|
| 117 |
+
"tol": 0.0002
|
| 118 |
+
},
|
| 119 |
+
{
|
| 120 |
+
"backend": "sdpa",
|
| 121 |
+
"d": 128,
|
| 122 |
+
"dk": 16,
|
| 123 |
+
"heads": 8,
|
| 124 |
+
"ok": true,
|
| 125 |
+
"preset": "micro_3x",
|
| 126 |
+
"rank": 48,
|
| 127 |
+
"rows": {
|
| 128 |
+
"causal_alibi_forward": 0.0,
|
| 129 |
+
"causal_alibi_loss": 0.0,
|
| 130 |
+
"causal_alibi_param_grad": 0.0,
|
| 131 |
+
"causal_alibi_x_grad": 5.820766091346741e-11,
|
| 132 |
+
"legacy_load_forward": 0.0,
|
| 133 |
+
"legacy_load_missing_unexpected": 0.0,
|
| 134 |
+
"legacy_load_qkv_weight": 0.0,
|
| 135 |
+
"none_forward": 0.0,
|
| 136 |
+
"none_loss": 0.0,
|
| 137 |
+
"none_param_grad": 0.0,
|
| 138 |
+
"none_x_grad": 1.4551915228366852e-11,
|
| 139 |
+
"safe_load_any_loaded": 0.0,
|
| 140 |
+
"safe_load_any_qkv": 0.0,
|
| 141 |
+
"sat_alibi_forward": 0.0,
|
| 142 |
+
"sat_alibi_loss": 0.0,
|
| 143 |
+
"sat_alibi_param_grad": 0.0,
|
| 144 |
+
"sat_alibi_x_grad": 5.820766091346741e-11
|
| 145 |
+
},
|
| 146 |
+
"tol": 0.0002
|
| 147 |
+
},
|
| 148 |
+
{
|
| 149 |
+
"backend": "sublinear",
|
| 150 |
+
"d": 128,
|
| 151 |
+
"dk": 16,
|
| 152 |
+
"heads": 8,
|
| 153 |
+
"ok": true,
|
| 154 |
+
"preset": "micro_3x",
|
| 155 |
+
"rank": 48,
|
| 156 |
+
"rows": {
|
| 157 |
+
"causal_alibi_forward": 0.0,
|
| 158 |
+
"causal_alibi_loss": 0.0,
|
| 159 |
+
"causal_alibi_param_grad": 0.0,
|
| 160 |
+
"causal_alibi_x_grad": 5.820766091346741e-11,
|
| 161 |
+
"legacy_load_forward": 0.0,
|
| 162 |
+
"legacy_load_missing_unexpected": 0.0,
|
| 163 |
+
"legacy_load_qkv_weight": 0.0,
|
| 164 |
+
"none_forward": 0.0,
|
| 165 |
+
"none_loss": 0.0,
|
| 166 |
+
"none_param_grad": 0.0,
|
| 167 |
+
"none_x_grad": 1.8189894035458565e-11,
|
| 168 |
+
"safe_load_any_loaded": 0.0,
|
| 169 |
+
"safe_load_any_qkv": 0.0,
|
| 170 |
+
"sat_alibi_forward": 0.0,
|
| 171 |
+
"sat_alibi_loss": 0.0,
|
| 172 |
+
"sat_alibi_param_grad": 0.0,
|
| 173 |
+
"sat_alibi_x_grad": 5.820766091346741e-11
|
| 174 |
+
},
|
| 175 |
+
"tol": 0.0002
|
| 176 |
+
}
|
| 177 |
+
]
|
local_verify_qkv_sublinear_agillm4.json
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"backend": "sublinear",
|
| 4 |
+
"d": 32,
|
| 5 |
+
"dk": 16,
|
| 6 |
+
"heads": 2,
|
| 7 |
+
"ok": true,
|
| 8 |
+
"preset": "pico_1x",
|
| 9 |
+
"rank": 16,
|
| 10 |
+
"rows": {
|
| 11 |
+
"causal_alibi_forward": 0.0,
|
| 12 |
+
"causal_alibi_loss": 0.0,
|
| 13 |
+
"causal_alibi_param_grad": 0.0,
|
| 14 |
+
"causal_alibi_x_grad": 2.3283064365386963e-10,
|
| 15 |
+
"legacy_load_forward": 0.0,
|
| 16 |
+
"legacy_load_missing_unexpected": 0.0,
|
| 17 |
+
"legacy_load_qkv_weight": 0.0,
|
| 18 |
+
"none_forward": 0.0,
|
| 19 |
+
"none_loss": 0.0,
|
| 20 |
+
"none_param_grad": 0.0,
|
| 21 |
+
"none_x_grad": 5.820766091346741e-11,
|
| 22 |
+
"safe_load_any_loaded": 0.0,
|
| 23 |
+
"safe_load_any_qkv": 0.0,
|
| 24 |
+
"sat_alibi_forward": 0.0,
|
| 25 |
+
"sat_alibi_loss": 0.0,
|
| 26 |
+
"sat_alibi_param_grad": 0.0,
|
| 27 |
+
"sat_alibi_x_grad": 2.3283064365386963e-10
|
| 28 |
+
},
|
| 29 |
+
"tol": 0.0002
|
| 30 |
+
}
|
| 31 |
+
]
|
nB300_agillm4.py
CHANGED
|
@@ -4,7 +4,7 @@
|
|
| 4 |
# Enhanced inference: checkpoint name, tok/s, UK time
|
| 5 |
|
| 6 |
from __future__ import annotations
|
| 7 |
-
import argparse, json, math, pathlib, random, time, os, sys, threading, hashlib, re, subprocess
|
| 8 |
from pathlib import Path
|
| 9 |
from contextlib import nullcontext
|
| 10 |
from typing import Dict, Any, List, Optional, Tuple
|
|
@@ -1147,9 +1147,10 @@ class TuneableAttentionMHA(nn.Module):
|
|
| 1147 |
self.sublinear_stride = max(0, int(sublinear_stride))
|
| 1148 |
self.sublinear_max_anchors = max(0, int(sublinear_max_anchors))
|
| 1149 |
self.sublinear_chunk = max(1, int(sublinear_chunk))
|
| 1150 |
-
|
| 1151 |
-
|
| 1152 |
-
|
|
|
|
| 1153 |
self.U = nn.Parameter(torch.randn(self.dk, r))
|
| 1154 |
nn.init.orthogonal_(self.U)
|
| 1155 |
self.proj = nn.Linear(h * self.dk, d, bias=False)
|
|
@@ -1164,6 +1165,25 @@ class TuneableAttentionMHA(nn.Module):
|
|
| 1164 |
self._metric_cache_data_ptr: int = -1
|
| 1165 |
self._metric_cache_shape: Tuple[int, int] = (-1, -1)
|
| 1166 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1167 |
def _proj_qk(self, x):
|
| 1168 |
B, N, _ = x.shape
|
| 1169 |
return (x.view(B, N, self.h, self.dk).transpose(1, 2) @ self.U)
|
|
@@ -1219,9 +1239,10 @@ class TuneableAttentionMHA(nn.Module):
|
|
| 1219 |
outputs = []
|
| 1220 |
scale = 1.0 / math.sqrt(self.dk)
|
| 1221 |
|
| 1222 |
-
|
|
|
|
| 1223 |
anchors = torch.arange(
|
| 1224 |
-
|
| 1225 |
k_len,
|
| 1226 |
self.sublinear_stride,
|
| 1227 |
device=device,
|
|
@@ -1273,9 +1294,8 @@ class TuneableAttentionMHA(nn.Module):
|
|
| 1273 |
return torch.cat(outputs, dim=2)
|
| 1274 |
|
| 1275 |
def forward(self, x, mask=None, rel_bias_tokens=None, kv_cache=None, use_cache=False):
|
| 1276 |
-
q_lin = self.
|
| 1277 |
-
|
| 1278 |
-
v_new = self._reshape_v(self.v(x))
|
| 1279 |
if self.r > self.dk:
|
| 1280 |
q = self._reshape_heads(q_lin) @ self._get_metric()
|
| 1281 |
k_new = self._reshape_heads(k_lin)
|
|
@@ -1498,6 +1518,141 @@ def _strip_orig_mod_prefix(state: dict) -> dict:
|
|
| 1498 |
for k, v in state.items()
|
| 1499 |
}
|
| 1500 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1501 |
def save_delta(core, ar_h, sat_h, step: int, seen_tok: int, save_dir: pathlib.Path, phase_name: str):
|
| 1502 |
"""Save weight-only delta in background thread. Non-blocking."""
|
| 1503 |
global _delta_thread
|
|
@@ -1552,7 +1707,7 @@ def load_delta(path: pathlib.Path, core, ar_h, sat_h):
|
|
| 1552 |
ck = torch.load(path, map_location="cpu", weights_only=False)
|
| 1553 |
if not ck.get("delta"):
|
| 1554 |
raise ValueError(f"{path.name} is not a delta checkpoint")
|
| 1555 |
-
core.load_state_dict(_strip_orig_mod_prefix(ck["weights"]["core"]))
|
| 1556 |
ar_h.load_state_dict(_strip_orig_mod_prefix(ck["weights"]["ar"]))
|
| 1557 |
sat_h.load_state_dict(_strip_orig_mod_prefix(ck["weights"]["sat"]))
|
| 1558 |
return ck.get("step", 0), ck.get("seen_tok", 0)
|
|
@@ -1586,11 +1741,25 @@ def load_ckpt(path, core, ar_h, sat_h, opt, scaler):
|
|
| 1586 |
p = _resolve_ckpt(path) or path
|
| 1587 |
ck = _try_load(p, map_location="cpu")
|
| 1588 |
if ck is None: raise FileNotFoundError(f"No valid checkpoint at {p}")
|
| 1589 |
-
core.load_state_dict(_strip_orig_mod_prefix(ck["core"]))
|
| 1590 |
ar_h.load_state_dict(_strip_orig_mod_prefix(ck["ar"]))
|
| 1591 |
sat_h.load_state_dict(_strip_orig_mod_prefix(ck["sat"]))
|
| 1592 |
-
|
| 1593 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1594 |
# Restore tokenizer from checkpoint if available
|
| 1595 |
if "tokenizer_json" in ck:
|
| 1596 |
try:
|
|
@@ -1614,8 +1783,11 @@ def _safe_load_any(path: pathlib.Path, tgt: nn.Module, key: str | None = None):
|
|
| 1614 |
sd = ck.get(key, ck) if key else ck
|
| 1615 |
if isinstance(sd, dict) and "state_dict" in sd: sd = sd["state_dict"]
|
| 1616 |
sd = _strip_orig_mod_prefix(sd)
|
|
|
|
|
|
|
|
|
|
| 1617 |
tgt_sd = tgt.state_dict()
|
| 1618 |
-
filt = {k: v for k, v in sd.items() if k in tgt_sd and v.shape == tgt_sd[k].shape}
|
| 1619 |
if filt: tgt.load_state_dict(filt, strict=False)
|
| 1620 |
return len(filt)
|
| 1621 |
|
|
|
|
| 4 |
# Enhanced inference: checkpoint name, tok/s, UK time
|
| 5 |
|
| 6 |
from __future__ import annotations
|
| 7 |
+
import argparse, copy, json, math, pathlib, random, time, os, sys, threading, hashlib, re, subprocess
|
| 8 |
from pathlib import Path
|
| 9 |
from contextlib import nullcontext
|
| 10 |
from typing import Dict, Any, List, Optional, Tuple
|
|
|
|
| 1147 |
self.sublinear_stride = max(0, int(sublinear_stride))
|
| 1148 |
self.sublinear_max_anchors = max(0, int(sublinear_max_anchors))
|
| 1149 |
self.sublinear_chunk = max(1, int(sublinear_chunk))
|
| 1150 |
+
# Exact n1 harvest: one fused QKV projection is mathematically the same
|
| 1151 |
+
# as three independent bias-free Linear(d, d) projections with their
|
| 1152 |
+
# weights stacked along out_features.
|
| 1153 |
+
self.qkv = nn.Linear(d, 3 * d, bias=False)
|
| 1154 |
self.U = nn.Parameter(torch.randn(self.dk, r))
|
| 1155 |
nn.init.orthogonal_(self.U)
|
| 1156 |
self.proj = nn.Linear(h * self.dk, d, bias=False)
|
|
|
|
| 1165 |
self._metric_cache_data_ptr: int = -1
|
| 1166 |
self._metric_cache_shape: Tuple[int, int] = (-1, -1)
|
| 1167 |
|
| 1168 |
+
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
|
| 1169 |
+
missing_keys, unexpected_keys, error_msgs):
|
| 1170 |
+
qkv_key = prefix + "qkv.weight"
|
| 1171 |
+
if qkv_key not in state_dict:
|
| 1172 |
+
qk = prefix + "q.weight"
|
| 1173 |
+
kk = prefix + "k.weight"
|
| 1174 |
+
vk = prefix + "v.weight"
|
| 1175 |
+
if qk in state_dict and kk in state_dict and vk in state_dict:
|
| 1176 |
+
fused = _cat_legacy_weight_blocks([state_dict[qk], state_dict[kk], state_dict[vk]])
|
| 1177 |
+
if fused is not None:
|
| 1178 |
+
state_dict[qkv_key] = fused
|
| 1179 |
+
state_dict.pop(qk)
|
| 1180 |
+
state_dict.pop(kk)
|
| 1181 |
+
state_dict.pop(vk)
|
| 1182 |
+
return super()._load_from_state_dict(
|
| 1183 |
+
state_dict, prefix, local_metadata, strict,
|
| 1184 |
+
missing_keys, unexpected_keys, error_msgs,
|
| 1185 |
+
)
|
| 1186 |
+
|
| 1187 |
def _proj_qk(self, x):
|
| 1188 |
B, N, _ = x.shape
|
| 1189 |
return (x.view(B, N, self.h, self.dk).transpose(1, 2) @ self.U)
|
|
|
|
| 1239 |
outputs = []
|
| 1240 |
scale = 1.0 / math.sqrt(self.dk)
|
| 1241 |
|
| 1242 |
+
anchor_start = self.sublinear_stride - 1
|
| 1243 |
+
if self.sublinear_stride > 0 and self.sublinear_max_anchors > 0 and anchor_start < k_len:
|
| 1244 |
anchors = torch.arange(
|
| 1245 |
+
anchor_start,
|
| 1246 |
k_len,
|
| 1247 |
self.sublinear_stride,
|
| 1248 |
device=device,
|
|
|
|
| 1294 |
return torch.cat(outputs, dim=2)
|
| 1295 |
|
| 1296 |
def forward(self, x, mask=None, rel_bias_tokens=None, kv_cache=None, use_cache=False):
|
| 1297 |
+
q_lin, k_lin, v_lin = self.qkv(x).chunk(3, dim=-1)
|
| 1298 |
+
v_new = self._reshape_v(v_lin)
|
|
|
|
| 1299 |
if self.r > self.dk:
|
| 1300 |
q = self._reshape_heads(q_lin) @ self._get_metric()
|
| 1301 |
k_new = self._reshape_heads(k_lin)
|
|
|
|
| 1518 |
for k, v in state.items()
|
| 1519 |
}
|
| 1520 |
|
| 1521 |
+
def _cat_legacy_weight_blocks(blocks: list) -> Optional[torch.Tensor]:
|
| 1522 |
+
if not blocks or not all(torch.is_tensor(t) for t in blocks):
|
| 1523 |
+
return None
|
| 1524 |
+
first = blocks[0]
|
| 1525 |
+
tail_shape = tuple(first.shape[1:])
|
| 1526 |
+
if any(t.dtype != first.dtype or t.device != first.device for t in blocks):
|
| 1527 |
+
return None
|
| 1528 |
+
if any(t.ndim != first.ndim or tuple(t.shape[1:]) != tail_shape for t in blocks):
|
| 1529 |
+
return None
|
| 1530 |
+
return torch.cat(blocks, dim=0).contiguous()
|
| 1531 |
+
|
| 1532 |
+
def _fuse_qkv_in_state_dict(sd: dict) -> dict:
|
| 1533 |
+
"""Fold legacy q/k/v.weight triples into qkv.weight before loading/filtering."""
|
| 1534 |
+
if not isinstance(sd, dict):
|
| 1535 |
+
return sd
|
| 1536 |
+
prefixes = set()
|
| 1537 |
+
for key in list(sd.keys()):
|
| 1538 |
+
for suffix in (".q.weight", ".k.weight", ".v.weight"):
|
| 1539 |
+
if isinstance(key, str) and key.endswith(suffix):
|
| 1540 |
+
prefixes.add(key[: -len(suffix)])
|
| 1541 |
+
for prefix in prefixes:
|
| 1542 |
+
qk, kk, vk = prefix + ".q.weight", prefix + ".k.weight", prefix + ".v.weight"
|
| 1543 |
+
fk = prefix + ".qkv.weight"
|
| 1544 |
+
if qk in sd and kk in sd and vk in sd and fk not in sd:
|
| 1545 |
+
fused = _cat_legacy_weight_blocks([sd[qk], sd[kk], sd[vk]])
|
| 1546 |
+
if fused is not None:
|
| 1547 |
+
sd[fk] = fused
|
| 1548 |
+
sd.pop(qk)
|
| 1549 |
+
sd.pop(kk)
|
| 1550 |
+
sd.pop(vk)
|
| 1551 |
+
return sd
|
| 1552 |
+
|
| 1553 |
+
def _split_qkv_in_state_dict_for_test(sd: dict) -> dict:
|
| 1554 |
+
out = dict(sd)
|
| 1555 |
+
for key in list(out.keys()):
|
| 1556 |
+
if not isinstance(key, str) or not key.endswith(".qkv.weight"):
|
| 1557 |
+
continue
|
| 1558 |
+
base = key[: -len(".qkv.weight")]
|
| 1559 |
+
q, k, v = out.pop(key).chunk(3, dim=0)
|
| 1560 |
+
out[base + ".q.weight"] = q.clone()
|
| 1561 |
+
out[base + ".k.weight"] = k.clone()
|
| 1562 |
+
out[base + ".v.weight"] = v.clone()
|
| 1563 |
+
return out
|
| 1564 |
+
|
| 1565 |
+
def _clone_opt_value(value):
|
| 1566 |
+
if torch.is_tensor(value):
|
| 1567 |
+
return value.detach().clone()
|
| 1568 |
+
return copy.deepcopy(value)
|
| 1569 |
+
|
| 1570 |
+
def _optimizer_param_name_lookup(core, ar_h, sat_h) -> dict[int, str]:
|
| 1571 |
+
out = {}
|
| 1572 |
+
for prefix, module in (("core", core), ("ar", ar_h), ("sat", sat_h)):
|
| 1573 |
+
for name, param in module.named_parameters():
|
| 1574 |
+
out.setdefault(id(param), f"{prefix}.{name}")
|
| 1575 |
+
return out
|
| 1576 |
+
|
| 1577 |
+
def _optimizer_group_param_names(opt, core, ar_h, sat_h) -> List[List[str]]:
|
| 1578 |
+
lookup = _optimizer_param_name_lookup(core, ar_h, sat_h)
|
| 1579 |
+
return [
|
| 1580 |
+
[lookup.get(id(param), f"<unknown:{id(param)}>") for param in group["params"]]
|
| 1581 |
+
for group in opt.param_groups
|
| 1582 |
+
]
|
| 1583 |
+
|
| 1584 |
+
def _legacy_names_for_current_param(name: str) -> List[str]:
|
| 1585 |
+
if name.endswith(".qkv.weight"):
|
| 1586 |
+
base = name[: -len(".qkv.weight")]
|
| 1587 |
+
return [base + ".q.weight", base + ".k.weight", base + ".v.weight"]
|
| 1588 |
+
return [name]
|
| 1589 |
+
|
| 1590 |
+
def _fuse_legacy_optimizer_param_state(states: List[dict]) -> Optional[dict]:
|
| 1591 |
+
if len(states) < 2 or any(not isinstance(state, dict) for state in states):
|
| 1592 |
+
return None
|
| 1593 |
+
common = set(states[0])
|
| 1594 |
+
for state in states[1:]:
|
| 1595 |
+
common &= set(state)
|
| 1596 |
+
out = {}
|
| 1597 |
+
for key in common:
|
| 1598 |
+
vals = [state[key] for state in states]
|
| 1599 |
+
if all(torch.is_tensor(v) for v in vals):
|
| 1600 |
+
shape = vals[0].shape
|
| 1601 |
+
if vals[0].ndim > 0 and all(v.shape == shape for v in vals[1:]):
|
| 1602 |
+
out[key] = torch.cat([v.detach().clone() for v in vals], dim=0).contiguous()
|
| 1603 |
+
else:
|
| 1604 |
+
out[key] = vals[0].detach().clone()
|
| 1605 |
+
else:
|
| 1606 |
+
out[key] = copy.deepcopy(vals[0])
|
| 1607 |
+
return out
|
| 1608 |
+
|
| 1609 |
+
def _fuse_legacy_qkv_optimizer_state(opt_state: dict, opt, core, ar_h, sat_h) -> Optional[dict]:
|
| 1610 |
+
"""Remap pre-QKV-fusion AdamW state to the current fused parameter layout."""
|
| 1611 |
+
if not isinstance(opt_state, dict) or "state" not in opt_state or "param_groups" not in opt_state:
|
| 1612 |
+
return None
|
| 1613 |
+
current_sd = opt.state_dict()
|
| 1614 |
+
current_names = _optimizer_group_param_names(opt, core, ar_h, sat_h)
|
| 1615 |
+
legacy_names = [
|
| 1616 |
+
[legacy for name in group_names for legacy in _legacy_names_for_current_param(name)]
|
| 1617 |
+
for group_names in current_names
|
| 1618 |
+
]
|
| 1619 |
+
if len(legacy_names) != len(opt_state.get("param_groups", [])):
|
| 1620 |
+
return None
|
| 1621 |
+
|
| 1622 |
+
legacy_name_to_pid = {}
|
| 1623 |
+
for group_idx, names in enumerate(legacy_names):
|
| 1624 |
+
old_params = list(opt_state["param_groups"][group_idx].get("params", []))
|
| 1625 |
+
if len(names) != len(old_params):
|
| 1626 |
+
return None
|
| 1627 |
+
for name, pid in zip(names, old_params):
|
| 1628 |
+
legacy_name_to_pid[name] = pid
|
| 1629 |
+
|
| 1630 |
+
new_groups = []
|
| 1631 |
+
for group_idx, current_group in enumerate(current_sd["param_groups"]):
|
| 1632 |
+
new_group = copy.deepcopy(opt_state["param_groups"][group_idx])
|
| 1633 |
+
new_group["params"] = list(current_group["params"])
|
| 1634 |
+
if "param_names" in new_group:
|
| 1635 |
+
new_group["param_names"] = list(current_names[group_idx])
|
| 1636 |
+
new_groups.append(new_group)
|
| 1637 |
+
|
| 1638 |
+
old_states = opt_state.get("state", {})
|
| 1639 |
+
new_states = {}
|
| 1640 |
+
for group_names, current_group in zip(current_names, current_sd["param_groups"]):
|
| 1641 |
+
for name, new_pid in zip(group_names, current_group["params"]):
|
| 1642 |
+
legacy_set = _legacy_names_for_current_param(name)
|
| 1643 |
+
if len(legacy_set) > 1:
|
| 1644 |
+
old_pids = [legacy_name_to_pid.get(legacy) for legacy in legacy_set]
|
| 1645 |
+
if all(pid in old_states for pid in old_pids):
|
| 1646 |
+
fused = _fuse_legacy_optimizer_param_state([old_states[pid] for pid in old_pids])
|
| 1647 |
+
if fused is not None:
|
| 1648 |
+
new_states[new_pid] = fused
|
| 1649 |
+
continue
|
| 1650 |
+
old_pid = legacy_name_to_pid.get(name)
|
| 1651 |
+
if old_pid in old_states:
|
| 1652 |
+
new_states[new_pid] = {key: _clone_opt_value(value) for key, value in old_states[old_pid].items()}
|
| 1653 |
+
|
| 1654 |
+
return {"state": new_states, "param_groups": new_groups}
|
| 1655 |
+
|
| 1656 |
def save_delta(core, ar_h, sat_h, step: int, seen_tok: int, save_dir: pathlib.Path, phase_name: str):
|
| 1657 |
"""Save weight-only delta in background thread. Non-blocking."""
|
| 1658 |
global _delta_thread
|
|
|
|
| 1707 |
ck = torch.load(path, map_location="cpu", weights_only=False)
|
| 1708 |
if not ck.get("delta"):
|
| 1709 |
raise ValueError(f"{path.name} is not a delta checkpoint")
|
| 1710 |
+
core.load_state_dict(_fuse_qkv_in_state_dict(_strip_orig_mod_prefix(ck["weights"]["core"])))
|
| 1711 |
ar_h.load_state_dict(_strip_orig_mod_prefix(ck["weights"]["ar"]))
|
| 1712 |
sat_h.load_state_dict(_strip_orig_mod_prefix(ck["weights"]["sat"]))
|
| 1713 |
return ck.get("step", 0), ck.get("seen_tok", 0)
|
|
|
|
| 1741 |
p = _resolve_ckpt(path) or path
|
| 1742 |
ck = _try_load(p, map_location="cpu")
|
| 1743 |
if ck is None: raise FileNotFoundError(f"No valid checkpoint at {p}")
|
| 1744 |
+
core.load_state_dict(_fuse_qkv_in_state_dict(_strip_orig_mod_prefix(ck["core"])))
|
| 1745 |
ar_h.load_state_dict(_strip_orig_mod_prefix(ck["ar"]))
|
| 1746 |
sat_h.load_state_dict(_strip_orig_mod_prefix(ck["sat"]))
|
| 1747 |
+
try:
|
| 1748 |
+
opt.load_state_dict(ck["opt"])
|
| 1749 |
+
except Exception as exc:
|
| 1750 |
+
fused_opt = _fuse_legacy_qkv_optimizer_state(ck.get("opt"), opt, core, ar_h, sat_h)
|
| 1751 |
+
if fused_opt is not None:
|
| 1752 |
+
try:
|
| 1753 |
+
opt.load_state_dict(fused_opt)
|
| 1754 |
+
print("[ckpt] Converted legacy q/k/v optimizer state to fused qkv layout")
|
| 1755 |
+
except Exception as exc2:
|
| 1756 |
+
print(f"[ckpt] WARNING: optimizer state incompatible; resetting optimizer ({type(exc).__name__}: {exc}; qkv remap failed: {type(exc2).__name__}: {exc2})")
|
| 1757 |
+
else:
|
| 1758 |
+
print(f"[ckpt] WARNING: optimizer state incompatible; resetting optimizer ({type(exc).__name__}: {exc})")
|
| 1759 |
+
try:
|
| 1760 |
+
scaler.load_state_dict(ck["scaler"])
|
| 1761 |
+
except Exception as exc:
|
| 1762 |
+
print(f"[ckpt] WARNING: scaler state incompatible; resetting scaler ({type(exc).__name__}: {exc})")
|
| 1763 |
# Restore tokenizer from checkpoint if available
|
| 1764 |
if "tokenizer_json" in ck:
|
| 1765 |
try:
|
|
|
|
| 1783 |
sd = ck.get(key, ck) if key else ck
|
| 1784 |
if isinstance(sd, dict) and "state_dict" in sd: sd = sd["state_dict"]
|
| 1785 |
sd = _strip_orig_mod_prefix(sd)
|
| 1786 |
+
sd = _fuse_qkv_in_state_dict(dict(sd)) if isinstance(sd, dict) else sd
|
| 1787 |
+
if not isinstance(sd, dict):
|
| 1788 |
+
return 0
|
| 1789 |
tgt_sd = tgt.state_dict()
|
| 1790 |
+
filt = {k: v for k, v in sd.items() if k in tgt_sd and hasattr(v, "shape") and v.shape == tgt_sd[k].shape}
|
| 1791 |
if filt: tgt.load_state_dict(filt, strict=False)
|
| 1792 |
return len(filt)
|
| 1793 |
|
verify_m_fold_agillm4.py
CHANGED
|
@@ -24,9 +24,10 @@ def causal_mask_cached(new_len: int, cached_len: int):
|
|
| 24 |
|
| 25 |
def old_expanded_forward(mha: nb.TuneableAttentionMHA, x: torch.Tensor, mask=None, rel_bias_tokens=None):
|
| 26 |
bsz, seq, _ = x.shape
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
|
|
|
| 30 |
att = (q @ k.transpose(-1, -2)) / math.sqrt(mha.dk)
|
| 31 |
if mha.use_relpos and rel_bias_tokens is not None:
|
| 32 |
att = att + nb.alibi_bias(mha.h, rel_bias_tokens)[:, :, -seq:, :]
|
|
|
|
| 24 |
|
| 25 |
def old_expanded_forward(mha: nb.TuneableAttentionMHA, x: torch.Tensor, mask=None, rel_bias_tokens=None):
|
| 26 |
bsz, seq, _ = x.shape
|
| 27 |
+
q_lin, k_lin, v_lin = mha.qkv(x).chunk(3, dim=-1)
|
| 28 |
+
q = mha._reshape_heads(q_lin) @ mha.U
|
| 29 |
+
k = mha._reshape_heads(k_lin) @ mha.U
|
| 30 |
+
v = mha._reshape_v(v_lin)
|
| 31 |
att = (q @ k.transpose(-1, -2)) / math.sqrt(mha.dk)
|
| 32 |
if mha.use_relpos and rel_bias_tokens is not None:
|
| 33 |
att = att + nb.alibi_bias(mha.h, rel_bias_tokens)[:, :, -seq:, :]
|
verify_qkv_agillm4.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import copy
|
| 6 |
+
import json
|
| 7 |
+
import math
|
| 8 |
+
import os
|
| 9 |
+
import tempfile
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from types import SimpleNamespace
|
| 12 |
+
|
| 13 |
+
os.environ.setdefault("AGILLM_SYNTHETIC_TOKENIZER", "1")
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
|
| 18 |
+
import nB300_agillm4 as nb
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def unfused_reference(mha: nb.TuneableAttentionMHA, x: torch.Tensor, mask=None, rel_bias_tokens=None):
|
| 22 |
+
bsz, seq, _ = x.shape
|
| 23 |
+
wq, wk, wv = mha.qkv.weight.chunk(3, dim=0)
|
| 24 |
+
q_lin = x @ wq.T
|
| 25 |
+
k_lin = x @ wk.T
|
| 26 |
+
v_lin = x @ wv.T
|
| 27 |
+
v = mha._reshape_v(v_lin)
|
| 28 |
+
if mha.r > mha.dk:
|
| 29 |
+
q = mha._reshape_heads(q_lin) @ mha._get_metric()
|
| 30 |
+
k = mha._reshape_heads(k_lin)
|
| 31 |
+
else:
|
| 32 |
+
q = mha._reshape_heads(q_lin) @ mha.U
|
| 33 |
+
k = mha._reshape_heads(k_lin) @ mha.U
|
| 34 |
+
|
| 35 |
+
attn_bias = None
|
| 36 |
+
if mha.use_relpos and rel_bias_tokens is not None:
|
| 37 |
+
attn_bias = nb.alibi_bias(mha.h, rel_bias_tokens)[:, :, -seq:, :]
|
| 38 |
+
if mask is not None:
|
| 39 |
+
attn_bias = mask if attn_bias is None else attn_bias + mask
|
| 40 |
+
|
| 41 |
+
if mha.attn_backend == "sdpa":
|
| 42 |
+
try:
|
| 43 |
+
z = F.scaled_dot_product_attention(
|
| 44 |
+
q, k, v,
|
| 45 |
+
attn_mask=attn_bias,
|
| 46 |
+
dropout_p=0.0,
|
| 47 |
+
scale=1.0 / math.sqrt(mha.dk),
|
| 48 |
+
)
|
| 49 |
+
except TypeError:
|
| 50 |
+
q_scaled = q * math.sqrt(q.size(-1) / mha.dk)
|
| 51 |
+
z = F.scaled_dot_product_attention(q_scaled, k, v, attn_mask=attn_bias, dropout_p=0.0)
|
| 52 |
+
elif mha.attn_backend == "sublinear":
|
| 53 |
+
z = mha._sublinear_attention(q, k, v, attn_mask=attn_bias)
|
| 54 |
+
else:
|
| 55 |
+
att = (q @ k.transpose(-1, -2)) / math.sqrt(mha.dk)
|
| 56 |
+
if attn_bias is not None:
|
| 57 |
+
att = att + attn_bias
|
| 58 |
+
z = att.softmax(-1) @ v
|
| 59 |
+
z = z.transpose(1, 2).reshape(bsz, seq, -1)
|
| 60 |
+
return mha.drop(mha.proj(z))
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def max_param_grad_diff(module, ref_grads: dict[str, torch.Tensor]) -> float:
|
| 64 |
+
out = 0.0
|
| 65 |
+
for name, param in module.named_parameters():
|
| 66 |
+
if param.grad is None:
|
| 67 |
+
continue
|
| 68 |
+
out = max(out, (param.grad.detach() - ref_grads[name]).abs().max().item())
|
| 69 |
+
return out
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def optimizer_for(core, ar_h, sat_h):
|
| 73 |
+
args = SimpleNamespace(optimizer="adamw")
|
| 74 |
+
return nb.make_optimizer(args, core, ar_h, sat_h, nb.LR_CORE, nb.LR_HEAD)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def split_qkv_optimizer_state_for_test(opt_state: dict, group_names: list[list[str]]) -> dict:
|
| 78 |
+
out = {"state": {}, "param_groups": []}
|
| 79 |
+
next_pid = 0
|
| 80 |
+
for src_group, names in zip(opt_state["param_groups"], group_names):
|
| 81 |
+
dst_group = copy.deepcopy(src_group)
|
| 82 |
+
dst_params = []
|
| 83 |
+
for src_pid, name in zip(src_group["params"], names):
|
| 84 |
+
src_state = opt_state.get("state", {}).get(src_pid, {})
|
| 85 |
+
if name.endswith(".qkv.weight"):
|
| 86 |
+
base = name[: -len(".qkv.weight")]
|
| 87 |
+
legacy_names = [base + ".q.weight", base + ".k.weight", base + ".v.weight"]
|
| 88 |
+
split_states = [{}, {}, {}]
|
| 89 |
+
for key, value in src_state.items():
|
| 90 |
+
if torch.is_tensor(value) and value.ndim > 0 and value.shape[0] % 3 == 0:
|
| 91 |
+
chunks = value.detach().chunk(3, dim=0)
|
| 92 |
+
for idx in range(3):
|
| 93 |
+
split_states[idx][key] = chunks[idx].clone().contiguous()
|
| 94 |
+
else:
|
| 95 |
+
for idx in range(3):
|
| 96 |
+
split_states[idx][key] = nb._clone_opt_value(value)
|
| 97 |
+
for legacy_name, split_state in zip(legacy_names, split_states):
|
| 98 |
+
dst_params.append(next_pid)
|
| 99 |
+
if split_state:
|
| 100 |
+
out["state"][next_pid] = split_state
|
| 101 |
+
next_pid += 1
|
| 102 |
+
else:
|
| 103 |
+
dst_params.append(next_pid)
|
| 104 |
+
if src_state:
|
| 105 |
+
out["state"][next_pid] = {key: nb._clone_opt_value(value) for key, value in src_state.items()}
|
| 106 |
+
next_pid += 1
|
| 107 |
+
dst_group["params"] = dst_params
|
| 108 |
+
out["param_groups"].append(dst_group)
|
| 109 |
+
return out
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def run_optimizer_remap_check(cfg: dict) -> dict:
|
| 113 |
+
core = nb.Encoder(cfg, attn_backend="manual").to(nb.DEV).train()
|
| 114 |
+
ar_h = nb.ARHead(cfg["d"]).to(nb.DEV).train()
|
| 115 |
+
sat_h = nb.SATHead(cfg["d"]).to(nb.DEV).train()
|
| 116 |
+
opt = optimizer_for(core, ar_h, sat_h)
|
| 117 |
+
ids = torch.randint(0, nb.VOCAB, (1, 8), device=nb.DEV)
|
| 118 |
+
loss = ar_h(core(ids, nb.causal_mask(ids.size(1)))).float().square().mean()
|
| 119 |
+
loss.backward()
|
| 120 |
+
opt.step()
|
| 121 |
+
opt.zero_grad(set_to_none=True)
|
| 122 |
+
current_names = nb._optimizer_group_param_names(opt, core, ar_h, sat_h)
|
| 123 |
+
legacy_opt = split_qkv_optimizer_state_for_test(opt.state_dict(), current_names)
|
| 124 |
+
|
| 125 |
+
fresh = nb.Encoder(cfg, attn_backend="manual").to(nb.DEV)
|
| 126 |
+
fresh_ar = nb.ARHead(cfg["d"]).to(nb.DEV)
|
| 127 |
+
fresh_sat = nb.SATHead(cfg["d"]).to(nb.DEV)
|
| 128 |
+
fresh_opt = optimizer_for(fresh, fresh_ar, fresh_sat)
|
| 129 |
+
fused = nb._fuse_legacy_qkv_optimizer_state(legacy_opt, fresh_opt, fresh, fresh_ar, fresh_sat)
|
| 130 |
+
ok = fused is not None
|
| 131 |
+
if ok:
|
| 132 |
+
fresh_opt.load_state_dict(fused)
|
| 133 |
+
loaded = fresh_opt.state_dict()
|
| 134 |
+
qkv_shape_ok = False
|
| 135 |
+
names = nb._optimizer_group_param_names(fresh_opt, fresh, fresh_ar, fresh_sat)
|
| 136 |
+
for group, group_names in zip(loaded["param_groups"], names):
|
| 137 |
+
for pid, name in zip(group["params"], group_names):
|
| 138 |
+
if name.endswith(".qkv.weight") and pid in loaded["state"]:
|
| 139 |
+
exp_avg = loaded["state"][pid].get("exp_avg")
|
| 140 |
+
qkv_shape_ok = torch.is_tensor(exp_avg) and tuple(exp_avg.shape) == tuple(dict(fresh.named_parameters())[name[len("core."):]].shape)
|
| 141 |
+
ok = ok and qkv_shape_ok
|
| 142 |
+
return {"optimizer_remap": 0.0 if ok else 1.0}
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def verify_case(args, preset: str, backend: str) -> dict:
|
| 146 |
+
torch.manual_seed(args.seed)
|
| 147 |
+
cfg = nb.PRESETS[preset].copy()
|
| 148 |
+
d, h, r = cfg["d"], cfg["heads"], cfg["rank"]
|
| 149 |
+
seq = args.cached_len + args.new_len
|
| 150 |
+
mha = nb.TuneableAttentionMHA(d, h, r, attn_backend=backend).to(nb.DEV).eval()
|
| 151 |
+
rows = {}
|
| 152 |
+
for case_name, mask, rel_tokens in [
|
| 153 |
+
("none", None, None),
|
| 154 |
+
("causal_alibi", nb.causal_mask(seq), seq),
|
| 155 |
+
("sat_alibi", nb.sat_mask(seq), seq),
|
| 156 |
+
]:
|
| 157 |
+
x_fused = torch.randn(2, seq, d, device=nb.DEV, requires_grad=True)
|
| 158 |
+
x_ref = x_fused.detach().clone().requires_grad_(True)
|
| 159 |
+
y_fused = mha(x_fused, mask, rel_bias_tokens=rel_tokens)
|
| 160 |
+
y_ref = unfused_reference(mha, x_ref, mask=mask, rel_bias_tokens=rel_tokens)
|
| 161 |
+
loss_fused = y_fused.square().mean()
|
| 162 |
+
loss_ref = y_ref.square().mean()
|
| 163 |
+
loss_fused.backward()
|
| 164 |
+
fused_x_grad = x_fused.grad.detach().clone()
|
| 165 |
+
fused_param_grads = {
|
| 166 |
+
name: param.grad.detach().clone()
|
| 167 |
+
for name, param in mha.named_parameters()
|
| 168 |
+
if param.grad is not None
|
| 169 |
+
}
|
| 170 |
+
mha.zero_grad(set_to_none=True)
|
| 171 |
+
loss_ref.backward()
|
| 172 |
+
ref_x_grad = x_ref.grad.detach().clone()
|
| 173 |
+
rows[f"{case_name}_forward"] = (y_fused - y_ref).abs().max().item()
|
| 174 |
+
rows[f"{case_name}_loss"] = abs(loss_fused.item() - loss_ref.item())
|
| 175 |
+
rows[f"{case_name}_x_grad"] = (fused_x_grad - ref_x_grad).abs().max().item()
|
| 176 |
+
rows[f"{case_name}_param_grad"] = max_param_grad_diff(mha, fused_param_grads)
|
| 177 |
+
mha.zero_grad(set_to_none=True)
|
| 178 |
+
|
| 179 |
+
legacy_sd = nb._split_qkv_in_state_dict_for_test(mha.state_dict())
|
| 180 |
+
fresh = nb.TuneableAttentionMHA(d, h, r, attn_backend=backend).to(nb.DEV).eval()
|
| 181 |
+
missing, unexpected = fresh.load_state_dict(dict(legacy_sd), strict=True)
|
| 182 |
+
rows["legacy_load_missing_unexpected"] = float(len(missing) + len(unexpected))
|
| 183 |
+
rows["legacy_load_qkv_weight"] = (fresh.qkv.weight.detach() - mha.qkv.weight.detach()).abs().max().item()
|
| 184 |
+
with torch.no_grad():
|
| 185 |
+
x = torch.randn(2, seq, d, device=nb.DEV)
|
| 186 |
+
rows["legacy_load_forward"] = (fresh(x) - mha(x)).abs().max().item()
|
| 187 |
+
|
| 188 |
+
core = nb.Encoder(cfg, attn_backend=backend).to(nb.DEV).eval()
|
| 189 |
+
legacy_core_sd = nb._split_qkv_in_state_dict_for_test(core.state_dict())
|
| 190 |
+
dst = nb.Encoder(cfg, attn_backend=backend).to(nb.DEV).eval()
|
| 191 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
| 192 |
+
ckpt = Path(tmpdir) / "legacy_core.pt"
|
| 193 |
+
torch.save({"core": legacy_core_sd}, ckpt)
|
| 194 |
+
loaded = nb._safe_load_any(ckpt, dst, key="core")
|
| 195 |
+
rows["safe_load_any_loaded"] = 0.0 if loaded > 0 else 1.0
|
| 196 |
+
rows["safe_load_any_qkv"] = max(
|
| 197 |
+
(a - b).abs().max().item()
|
| 198 |
+
for (name, a), b in [
|
| 199 |
+
((name, param.detach()), dict(dst.named_parameters())[name].detach())
|
| 200 |
+
for name, param in core.named_parameters()
|
| 201 |
+
if name.endswith(".qkv.weight")
|
| 202 |
+
]
|
| 203 |
+
)
|
| 204 |
+
if preset == "pico_1x" and backend == "manual":
|
| 205 |
+
rows.update(run_optimizer_remap_check(cfg))
|
| 206 |
+
|
| 207 |
+
ok = all(value <= args.tol for value in rows.values())
|
| 208 |
+
return {
|
| 209 |
+
"preset": preset,
|
| 210 |
+
"backend": backend,
|
| 211 |
+
"d": d,
|
| 212 |
+
"heads": h,
|
| 213 |
+
"rank": r,
|
| 214 |
+
"dk": d // h,
|
| 215 |
+
"ok": ok,
|
| 216 |
+
"tol": args.tol,
|
| 217 |
+
"rows": rows,
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def main() -> int:
|
| 222 |
+
parser = argparse.ArgumentParser(description="Verify AGILLM-4 fused QKV harvest from n1.py")
|
| 223 |
+
parser.add_argument("--presets", default="pico_1x,micro_3x")
|
| 224 |
+
parser.add_argument("--backends", default="manual,sdpa")
|
| 225 |
+
parser.add_argument("--cached_len", type=int, default=8)
|
| 226 |
+
parser.add_argument("--new_len", type=int, default=4)
|
| 227 |
+
parser.add_argument("--seed", type=int, default=5678)
|
| 228 |
+
parser.add_argument("--tol", type=float, default=2e-4)
|
| 229 |
+
parser.add_argument("--json_out", default="")
|
| 230 |
+
args = parser.parse_args()
|
| 231 |
+
|
| 232 |
+
results = []
|
| 233 |
+
all_ok = True
|
| 234 |
+
for preset in [item.strip() for item in args.presets.split(",") if item.strip()]:
|
| 235 |
+
for backend in [item.strip() for item in args.backends.split(",") if item.strip()]:
|
| 236 |
+
result = verify_case(args, preset, backend)
|
| 237 |
+
results.append(result)
|
| 238 |
+
all_ok = all_ok and result["ok"]
|
| 239 |
+
print(json.dumps(result, sort_keys=True), flush=True)
|
| 240 |
+
if args.json_out:
|
| 241 |
+
Path(args.json_out).write_text(json.dumps(results, indent=2, sort_keys=True), encoding="utf-8")
|
| 242 |
+
return 0 if all_ok else 1
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
if __name__ == "__main__":
|
| 246 |
+
raise SystemExit(main())
|