OpenTransformer committed on
Commit
18b3e9e
·
verified ·
1 Parent(s): 421314d

Harvest fused QKV projection from n1

Browse files
AGILLM-4.md CHANGED
@@ -167,6 +167,20 @@ python /workspace/agillm-4/verify_m_fold_agillm4.py \
167
  --backends manual,sdpa
168
  ```
169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  See `N1_HARVEST.md` for the staged port order.
171
 
172
  ## Intelligence per FLOP
 
167
  --backends manual,sdpa
168
  ```
169
 
170
+ Second harvested feature: fused QKV projection. AGILLM-4 now uses one
171
+ `Linear(d, 3d)` and chunks the result instead of running three separate
172
+ `Linear(d, d)` projections. Legacy `q.weight/k.weight/v.weight` checkpoints
173
+ load into the fused `qkv.weight` layout, and matching legacy AdamW moments are
174
+ remapped on full resume.
175
+
176
+ Verification:
177
+
178
+ ```bash
179
+ python /workspace/agillm-4/verify_qkv_agillm4.py \
180
+ --presets pico_1x,micro_3x \
181
+ --backends manual,sdpa,sublinear
182
+ ```
183
+
184
  See `N1_HARVEST.md` for the staged port order.
185
 
186
  ## Intelligence per FLOP
N1_HARVEST.md CHANGED
@@ -34,15 +34,43 @@ python agillm-4/verify_m_fold_agillm4.py \
34
  The verifier checks forward output, loss, input gradients, parameter gradients,
35
  cached append equivalence, cache key width, and metric-cache invalidation.
36
 
37
- ## Next Candidates
38
-
39
  ### 2. Fused QKV Projection
40
 
 
 
41
  n1 fuses separate `q/k/v` linear layers into one `qkv` linear while keeping
42
- checkpoint compatibility by folding old state-dict keys on load. This should be
43
- the next port after M-fold because it reduces three projection GEMMs to one.
 
 
 
 
 
 
44
 
45
- Risk: checkpoint key migration. Keep this separate from the M-fold port.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  ### 3. Combined ALiBi + Mask Cache
48
 
 
34
  The verifier checks forward output, loss, input gradients, parameter gradients,
35
  cached append equivalence, cache key width, and metric-cache invalidation.
36
 
 
 
37
  ### 2. Fused QKV Projection
38
 
39
+ Status: done.
40
+
41
  n1 fuses separate `q/k/v` linear layers into one `qkv` linear while keeping
42
+ checkpoint compatibility by folding old state-dict keys on load. AGILLM-4 now
43
+ does the same. The parameter count and function are unchanged:
44
+
45
+ ```text
46
+ [x Wq.T, x Wk.T, x Wv.T] == split(x [Wq; Wk; Wv].T)
47
+ ```
48
+
49
+ Checkpoint compatibility:
50
 
51
+ - legacy `*.q.weight`, `*.k.weight`, `*.v.weight` triples load into
52
+ `*.qkv.weight`
53
+ - warm-start shape filtering fuses legacy triples before filtering
54
+ - legacy AdamW q/k/v moment tensors are concatenated into qkv optimizer state
55
+ when a full resume can be proven to match the old parameter layout
56
+ - if optimizer remap cannot be proven, model weights still load and optimizer
57
+ state is reset with a warning
58
+
59
+ Verification:
60
+
61
+ ```bash
62
+ python agillm-4/verify_qkv_agillm4.py \
63
+ --presets pico_1x,micro_3x \
64
+ --backends manual,sdpa,sublinear \
65
+ --cached_len 8 \
66
+ --new_len 4
67
+ ```
68
+
69
+ The verifier checks fused-vs-unfused forward output, loss, input gradients,
70
+ parameter gradients, strict legacy state-dict loading, `_safe_load_any`
71
+ warm-start loading, and optimizer-state remap.
72
+
73
+ ## Next Candidates
74
 
75
  ### 3. Combined ALiBi + Mask Cache
76
 
README.md CHANGED
@@ -20,6 +20,7 @@ and extended for:
20
  - AR+SAT every step with sequential backward to reduce peak VRAM
21
  - SDPA and experimental sublinear local+landmark attention backends
22
  - exact M-fold expansion attention harvested from n1.py, with local verifier
 
23
  - profiling tools for memory, throughput, AR cost, SAT cost, and optimizer cost
24
  - synthetic long-context curriculum generation for recall and multi-hop tests
25
 
 
20
  - AR+SAT every step with sequential backward to reduce peak VRAM
21
  - SDPA and experimental sublinear local+landmark attention backends
22
  - exact M-fold expansion attention harvested from n1.py, with local verifier
23
+ - fused QKV projection harvested from n1.py, with legacy checkpoint loading
24
  - profiling tools for memory, throughput, AR cost, SAT cost, and optimizer cost
25
  - synthetic long-context curriculum generation for recall and multi-hop tests
26
 
local_sweep_after_qkv_sdpa.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "alloc_gb": 0.027,
4
+ "amp": false,
5
+ "attn_backend": "sdpa",
6
+ "batch_size": 1,
7
+ "block": 64,
8
+ "elapsed_s": 0.538,
9
+ "error": null,
10
+ "grad_checkpoint": true,
11
+ "loss": 18.4698,
12
+ "ok": true,
13
+ "peak_alloc_gb": 0.031,
14
+ "peak_reserved_gb": 0.057,
15
+ "reserved_gb": 0.057,
16
+ "sublinear_chunk": 128,
17
+ "sublinear_max_anchors": 256,
18
+ "sublinear_stride": 64,
19
+ "sublinear_window": 256,
20
+ "tokens_per_s_synthetic": 119.1
21
+ }
22
+ ]
local_sweep_after_qkv_sublinear.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "alloc_gb": 0.027,
4
+ "amp": false,
5
+ "attn_backend": "sublinear",
6
+ "batch_size": 1,
7
+ "block": 64,
8
+ "elapsed_s": 0.698,
9
+ "error": null,
10
+ "grad_checkpoint": true,
11
+ "loss": 18.251,
12
+ "ok": true,
13
+ "peak_alloc_gb": 0.031,
14
+ "peak_reserved_gb": 0.057,
15
+ "reserved_gb": 0.057,
16
+ "sublinear_chunk": 16,
17
+ "sublinear_max_anchors": 16,
18
+ "sublinear_stride": 8,
19
+ "sublinear_window": 16,
20
+ "tokens_per_s_synthetic": 91.7
21
+ }
22
+ ]
local_verify_m_fold_after_qkv_agillm4.json ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "backend": "manual",
4
+ "d": 32,
5
+ "dk": 16,
6
+ "expected_k_width": 16,
7
+ "heads": 2,
8
+ "ok": true,
9
+ "preset": "pico_1x",
10
+ "rank": 16,
11
+ "rows": {
12
+ "cache_k_width": 16.0,
13
+ "cache_v_width": 16.0,
14
+ "cached_append_forward": 5.960464477539063e-08,
15
+ "causal_alibi_forward": 0.0,
16
+ "causal_alibi_loss": 0.0,
17
+ "causal_alibi_param_grad": 0.0,
18
+ "causal_alibi_x_grad": 0.0,
19
+ "none_forward": 0.0,
20
+ "none_loss": 0.0,
21
+ "none_param_grad": 0.0,
22
+ "none_x_grad": 0.0,
23
+ "sat_alibi_forward": 0.0,
24
+ "sat_alibi_loss": 0.0,
25
+ "sat_alibi_param_grad": 0.0,
26
+ "sat_alibi_x_grad": 0.0
27
+ },
28
+ "tol": 0.0002
29
+ },
30
+ {
31
+ "backend": "sdpa",
32
+ "d": 32,
33
+ "dk": 16,
34
+ "expected_k_width": 16,
35
+ "heads": 2,
36
+ "ok": true,
37
+ "preset": "pico_1x",
38
+ "rank": 16,
39
+ "rows": {
40
+ "cache_k_width": 16.0,
41
+ "cache_v_width": 16.0,
42
+ "cached_append_forward": 5.960464477539063e-08,
43
+ "causal_alibi_forward": 7.450580596923828e-08,
44
+ "causal_alibi_loss": 0.0,
45
+ "causal_alibi_param_grad": 1.862645149230957e-09,
46
+ "causal_alibi_x_grad": 2.3283064365386963e-10,
47
+ "none_forward": 8.940696716308594e-08,
48
+ "none_loss": 0.0,
49
+ "none_param_grad": 9.313225746154785e-10,
50
+ "none_x_grad": 8.731149137020111e-11,
51
+ "sat_alibi_forward": 1.1920928955078125e-07,
52
+ "sat_alibi_loss": 1.862645149230957e-09,
53
+ "sat_alibi_param_grad": 9.313225746154785e-10,
54
+ "sat_alibi_x_grad": 2.3283064365386963e-10
55
+ },
56
+ "tol": 0.0002
57
+ },
58
+ {
59
+ "backend": "manual",
60
+ "d": 128,
61
+ "dk": 16,
62
+ "expected_k_width": 16,
63
+ "heads": 8,
64
+ "ok": true,
65
+ "preset": "micro_3x",
66
+ "rank": 48,
67
+ "rows": {
68
+ "cache_k_width": 16.0,
69
+ "cache_v_width": 16.0,
70
+ "cached_append_forward": 6.51925802230835e-08,
71
+ "causal_alibi_forward": 5.960464477539063e-08,
72
+ "causal_alibi_loss": 0.0,
73
+ "causal_alibi_param_grad": 4.656612873077393e-10,
74
+ "causal_alibi_x_grad": 5.820766091346741e-11,
75
+ "metric_cache_cleared_on_train": 0.0,
76
+ "metric_cache_reused": 0.0,
77
+ "none_forward": 5.960464477539063e-08,
78
+ "none_loss": 0.0,
79
+ "none_param_grad": 2.3283064365386963e-10,
80
+ "none_x_grad": 1.4551915228366852e-11,
81
+ "sat_alibi_forward": 1.1920928955078125e-07,
82
+ "sat_alibi_loss": 1.862645149230957e-09,
83
+ "sat_alibi_param_grad": 5.820766091346741e-10,
84
+ "sat_alibi_x_grad": 5.820766091346741e-11
85
+ },
86
+ "tol": 0.0002
87
+ },
88
+ {
89
+ "backend": "sdpa",
90
+ "d": 128,
91
+ "dk": 16,
92
+ "expected_k_width": 16,
93
+ "heads": 8,
94
+ "ok": true,
95
+ "preset": "micro_3x",
96
+ "rank": 48,
97
+ "rows": {
98
+ "cache_k_width": 16.0,
99
+ "cache_v_width": 16.0,
100
+ "cached_append_forward": 7.450580596923828e-08,
101
+ "causal_alibi_forward": 1.043081283569336e-07,
102
+ "causal_alibi_loss": 0.0,
103
+ "causal_alibi_param_grad": 4.656612873077393e-10,
104
+ "causal_alibi_x_grad": 8.731149137020111e-11,
105
+ "metric_cache_cleared_on_train": 0.0,
106
+ "metric_cache_reused": 0.0,
107
+ "none_forward": 8.940696716308594e-08,
108
+ "none_loss": 0.0,
109
+ "none_param_grad": 2.3283064365386963e-10,
110
+ "none_x_grad": 1.8189894035458565e-11,
111
+ "sat_alibi_forward": 1.1920928955078125e-07,
112
+ "sat_alibi_loss": 0.0,
113
+ "sat_alibi_param_grad": 6.984919309616089e-10,
114
+ "sat_alibi_x_grad": 7.275957614183426e-11
115
+ },
116
+ "tol": 0.0002
117
+ }
118
+ ]
local_verify_m_fold_after_qkv_fix_agillm4.json ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "backend": "manual",
4
+ "d": 32,
5
+ "dk": 16,
6
+ "expected_k_width": 16,
7
+ "heads": 2,
8
+ "ok": true,
9
+ "preset": "pico_1x",
10
+ "rank": 16,
11
+ "rows": {
12
+ "cache_k_width": 16.0,
13
+ "cache_v_width": 16.0,
14
+ "cached_append_forward": 5.960464477539063e-08,
15
+ "causal_alibi_forward": 0.0,
16
+ "causal_alibi_loss": 0.0,
17
+ "causal_alibi_param_grad": 0.0,
18
+ "causal_alibi_x_grad": 0.0,
19
+ "none_forward": 0.0,
20
+ "none_loss": 0.0,
21
+ "none_param_grad": 0.0,
22
+ "none_x_grad": 0.0,
23
+ "sat_alibi_forward": 0.0,
24
+ "sat_alibi_loss": 0.0,
25
+ "sat_alibi_param_grad": 0.0,
26
+ "sat_alibi_x_grad": 0.0
27
+ },
28
+ "tol": 0.0002
29
+ },
30
+ {
31
+ "backend": "sdpa",
32
+ "d": 32,
33
+ "dk": 16,
34
+ "expected_k_width": 16,
35
+ "heads": 2,
36
+ "ok": true,
37
+ "preset": "pico_1x",
38
+ "rank": 16,
39
+ "rows": {
40
+ "cache_k_width": 16.0,
41
+ "cache_v_width": 16.0,
42
+ "cached_append_forward": 5.960464477539063e-08,
43
+ "causal_alibi_forward": 7.450580596923828e-08,
44
+ "causal_alibi_loss": 0.0,
45
+ "causal_alibi_param_grad": 1.862645149230957e-09,
46
+ "causal_alibi_x_grad": 2.3283064365386963e-10,
47
+ "none_forward": 8.940696716308594e-08,
48
+ "none_loss": 0.0,
49
+ "none_param_grad": 9.313225746154785e-10,
50
+ "none_x_grad": 8.731149137020111e-11,
51
+ "sat_alibi_forward": 1.1920928955078125e-07,
52
+ "sat_alibi_loss": 1.862645149230957e-09,
53
+ "sat_alibi_param_grad": 9.313225746154785e-10,
54
+ "sat_alibi_x_grad": 2.3283064365386963e-10
55
+ },
56
+ "tol": 0.0002
57
+ },
58
+ {
59
+ "backend": "manual",
60
+ "d": 128,
61
+ "dk": 16,
62
+ "expected_k_width": 16,
63
+ "heads": 8,
64
+ "ok": true,
65
+ "preset": "micro_3x",
66
+ "rank": 48,
67
+ "rows": {
68
+ "cache_k_width": 16.0,
69
+ "cache_v_width": 16.0,
70
+ "cached_append_forward": 6.51925802230835e-08,
71
+ "causal_alibi_forward": 5.960464477539063e-08,
72
+ "causal_alibi_loss": 0.0,
73
+ "causal_alibi_param_grad": 4.656612873077393e-10,
74
+ "causal_alibi_x_grad": 5.820766091346741e-11,
75
+ "metric_cache_cleared_on_train": 0.0,
76
+ "metric_cache_reused": 0.0,
77
+ "none_forward": 5.960464477539063e-08,
78
+ "none_loss": 0.0,
79
+ "none_param_grad": 2.3283064365386963e-10,
80
+ "none_x_grad": 1.4551915228366852e-11,
81
+ "sat_alibi_forward": 1.1920928955078125e-07,
82
+ "sat_alibi_loss": 1.862645149230957e-09,
83
+ "sat_alibi_param_grad": 5.820766091346741e-10,
84
+ "sat_alibi_x_grad": 5.820766091346741e-11
85
+ },
86
+ "tol": 0.0002
87
+ },
88
+ {
89
+ "backend": "sdpa",
90
+ "d": 128,
91
+ "dk": 16,
92
+ "expected_k_width": 16,
93
+ "heads": 8,
94
+ "ok": true,
95
+ "preset": "micro_3x",
96
+ "rank": 48,
97
+ "rows": {
98
+ "cache_k_width": 16.0,
99
+ "cache_v_width": 16.0,
100
+ "cached_append_forward": 7.450580596923828e-08,
101
+ "causal_alibi_forward": 1.043081283569336e-07,
102
+ "causal_alibi_loss": 0.0,
103
+ "causal_alibi_param_grad": 4.656612873077393e-10,
104
+ "causal_alibi_x_grad": 8.731149137020111e-11,
105
+ "metric_cache_cleared_on_train": 0.0,
106
+ "metric_cache_reused": 0.0,
107
+ "none_forward": 8.940696716308594e-08,
108
+ "none_loss": 0.0,
109
+ "none_param_grad": 2.3283064365386963e-10,
110
+ "none_x_grad": 1.8189894035458565e-11,
111
+ "sat_alibi_forward": 1.1920928955078125e-07,
112
+ "sat_alibi_loss": 0.0,
113
+ "sat_alibi_param_grad": 6.984919309616089e-10,
114
+ "sat_alibi_x_grad": 7.275957614183426e-11
115
+ },
116
+ "tol": 0.0002
117
+ }
118
+ ]
local_verify_qkv_agillm4.json ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "backend": "manual",
4
+ "d": 32,
5
+ "dk": 16,
6
+ "heads": 2,
7
+ "ok": true,
8
+ "preset": "pico_1x",
9
+ "rank": 16,
10
+ "rows": {
11
+ "causal_alibi_forward": 0.0,
12
+ "causal_alibi_loss": 0.0,
13
+ "causal_alibi_param_grad": 0.0,
14
+ "causal_alibi_x_grad": 2.3283064365386963e-10,
15
+ "legacy_load_forward": 0.0,
16
+ "legacy_load_missing_unexpected": 0.0,
17
+ "legacy_load_qkv_weight": 0.0,
18
+ "none_forward": 0.0,
19
+ "none_loss": 0.0,
20
+ "none_param_grad": 0.0,
21
+ "none_x_grad": 5.820766091346741e-11,
22
+ "optimizer_remap": 0.0,
23
+ "safe_load_any_loaded": 0.0,
24
+ "safe_load_any_qkv": 0.0,
25
+ "sat_alibi_forward": 0.0,
26
+ "sat_alibi_loss": 0.0,
27
+ "sat_alibi_param_grad": 0.0,
28
+ "sat_alibi_x_grad": 2.3283064365386963e-10
29
+ },
30
+ "tol": 0.0002
31
+ },
32
+ {
33
+ "backend": "sdpa",
34
+ "d": 32,
35
+ "dk": 16,
36
+ "heads": 2,
37
+ "ok": true,
38
+ "preset": "pico_1x",
39
+ "rank": 16,
40
+ "rows": {
41
+ "causal_alibi_forward": 0.0,
42
+ "causal_alibi_loss": 0.0,
43
+ "causal_alibi_param_grad": 0.0,
44
+ "causal_alibi_x_grad": 2.3283064365386963e-10,
45
+ "legacy_load_forward": 0.0,
46
+ "legacy_load_missing_unexpected": 0.0,
47
+ "legacy_load_qkv_weight": 0.0,
48
+ "none_forward": 0.0,
49
+ "none_loss": 0.0,
50
+ "none_param_grad": 0.0,
51
+ "none_x_grad": 5.820766091346741e-11,
52
+ "safe_load_any_loaded": 0.0,
53
+ "safe_load_any_qkv": 0.0,
54
+ "sat_alibi_forward": 0.0,
55
+ "sat_alibi_loss": 0.0,
56
+ "sat_alibi_param_grad": 0.0,
57
+ "sat_alibi_x_grad": 2.3283064365386963e-10
58
+ },
59
+ "tol": 0.0002
60
+ },
61
+ {
62
+ "backend": "manual",
63
+ "d": 128,
64
+ "dk": 16,
65
+ "heads": 8,
66
+ "ok": true,
67
+ "preset": "micro_3x",
68
+ "rank": 48,
69
+ "rows": {
70
+ "causal_alibi_forward": 0.0,
71
+ "causal_alibi_loss": 0.0,
72
+ "causal_alibi_param_grad": 0.0,
73
+ "causal_alibi_x_grad": 5.820766091346741e-11,
74
+ "legacy_load_forward": 0.0,
75
+ "legacy_load_missing_unexpected": 0.0,
76
+ "legacy_load_qkv_weight": 0.0,
77
+ "none_forward": 0.0,
78
+ "none_loss": 0.0,
79
+ "none_param_grad": 0.0,
80
+ "none_x_grad": 1.4551915228366852e-11,
81
+ "safe_load_any_loaded": 0.0,
82
+ "safe_load_any_qkv": 0.0,
83
+ "sat_alibi_forward": 0.0,
84
+ "sat_alibi_loss": 0.0,
85
+ "sat_alibi_param_grad": 0.0,
86
+ "sat_alibi_x_grad": 5.820766091346741e-11
87
+ },
88
+ "tol": 0.0002
89
+ },
90
+ {
91
+ "backend": "sdpa",
92
+ "d": 128,
93
+ "dk": 16,
94
+ "heads": 8,
95
+ "ok": true,
96
+ "preset": "micro_3x",
97
+ "rank": 48,
98
+ "rows": {
99
+ "causal_alibi_forward": 0.0,
100
+ "causal_alibi_loss": 0.0,
101
+ "causal_alibi_param_grad": 0.0,
102
+ "causal_alibi_x_grad": 5.820766091346741e-11,
103
+ "legacy_load_forward": 0.0,
104
+ "legacy_load_missing_unexpected": 0.0,
105
+ "legacy_load_qkv_weight": 0.0,
106
+ "none_forward": 0.0,
107
+ "none_loss": 0.0,
108
+ "none_param_grad": 0.0,
109
+ "none_x_grad": 1.4551915228366852e-11,
110
+ "safe_load_any_loaded": 0.0,
111
+ "safe_load_any_qkv": 0.0,
112
+ "sat_alibi_forward": 0.0,
113
+ "sat_alibi_loss": 0.0,
114
+ "sat_alibi_param_grad": 0.0,
115
+ "sat_alibi_x_grad": 5.820766091346741e-11
116
+ },
117
+ "tol": 0.0002
118
+ }
119
+ ]
local_verify_qkv_all_backends_agillm4.json ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "backend": "manual",
4
+ "d": 32,
5
+ "dk": 16,
6
+ "heads": 2,
7
+ "ok": true,
8
+ "preset": "pico_1x",
9
+ "rank": 16,
10
+ "rows": {
11
+ "causal_alibi_forward": 0.0,
12
+ "causal_alibi_loss": 0.0,
13
+ "causal_alibi_param_grad": 0.0,
14
+ "causal_alibi_x_grad": 2.3283064365386963e-10,
15
+ "legacy_load_forward": 0.0,
16
+ "legacy_load_missing_unexpected": 0.0,
17
+ "legacy_load_qkv_weight": 0.0,
18
+ "none_forward": 0.0,
19
+ "none_loss": 0.0,
20
+ "none_param_grad": 0.0,
21
+ "none_x_grad": 5.820766091346741e-11,
22
+ "optimizer_remap": 0.0,
23
+ "safe_load_any_loaded": 0.0,
24
+ "safe_load_any_qkv": 0.0,
25
+ "sat_alibi_forward": 0.0,
26
+ "sat_alibi_loss": 0.0,
27
+ "sat_alibi_param_grad": 0.0,
28
+ "sat_alibi_x_grad": 2.3283064365386963e-10
29
+ },
30
+ "tol": 0.0002
31
+ },
32
+ {
33
+ "backend": "sdpa",
34
+ "d": 32,
35
+ "dk": 16,
36
+ "heads": 2,
37
+ "ok": true,
38
+ "preset": "pico_1x",
39
+ "rank": 16,
40
+ "rows": {
41
+ "causal_alibi_forward": 0.0,
42
+ "causal_alibi_loss": 0.0,
43
+ "causal_alibi_param_grad": 0.0,
44
+ "causal_alibi_x_grad": 2.3283064365386963e-10,
45
+ "legacy_load_forward": 0.0,
46
+ "legacy_load_missing_unexpected": 0.0,
47
+ "legacy_load_qkv_weight": 0.0,
48
+ "none_forward": 0.0,
49
+ "none_loss": 0.0,
50
+ "none_param_grad": 0.0,
51
+ "none_x_grad": 5.820766091346741e-11,
52
+ "safe_load_any_loaded": 0.0,
53
+ "safe_load_any_qkv": 0.0,
54
+ "sat_alibi_forward": 0.0,
55
+ "sat_alibi_loss": 0.0,
56
+ "sat_alibi_param_grad": 0.0,
57
+ "sat_alibi_x_grad": 2.3283064365386963e-10
58
+ },
59
+ "tol": 0.0002
60
+ },
61
+ {
62
+ "backend": "sublinear",
63
+ "d": 32,
64
+ "dk": 16,
65
+ "heads": 2,
66
+ "ok": true,
67
+ "preset": "pico_1x",
68
+ "rank": 16,
69
+ "rows": {
70
+ "causal_alibi_forward": 0.0,
71
+ "causal_alibi_loss": 0.0,
72
+ "causal_alibi_param_grad": 0.0,
73
+ "causal_alibi_x_grad": 2.3283064365386963e-10,
74
+ "legacy_load_forward": 0.0,
75
+ "legacy_load_missing_unexpected": 0.0,
76
+ "legacy_load_qkv_weight": 0.0,
77
+ "none_forward": 0.0,
78
+ "none_loss": 0.0,
79
+ "none_param_grad": 0.0,
80
+ "none_x_grad": 5.820766091346741e-11,
81
+ "safe_load_any_loaded": 0.0,
82
+ "safe_load_any_qkv": 0.0,
83
+ "sat_alibi_forward": 0.0,
84
+ "sat_alibi_loss": 0.0,
85
+ "sat_alibi_param_grad": 0.0,
86
+ "sat_alibi_x_grad": 2.3283064365386963e-10
87
+ },
88
+ "tol": 0.0002
89
+ },
90
+ {
91
+ "backend": "manual",
92
+ "d": 128,
93
+ "dk": 16,
94
+ "heads": 8,
95
+ "ok": true,
96
+ "preset": "micro_3x",
97
+ "rank": 48,
98
+ "rows": {
99
+ "causal_alibi_forward": 0.0,
100
+ "causal_alibi_loss": 0.0,
101
+ "causal_alibi_param_grad": 0.0,
102
+ "causal_alibi_x_grad": 5.820766091346741e-11,
103
+ "legacy_load_forward": 0.0,
104
+ "legacy_load_missing_unexpected": 0.0,
105
+ "legacy_load_qkv_weight": 0.0,
106
+ "none_forward": 0.0,
107
+ "none_loss": 0.0,
108
+ "none_param_grad": 0.0,
109
+ "none_x_grad": 1.4551915228366852e-11,
110
+ "safe_load_any_loaded": 0.0,
111
+ "safe_load_any_qkv": 0.0,
112
+ "sat_alibi_forward": 0.0,
113
+ "sat_alibi_loss": 0.0,
114
+ "sat_alibi_param_grad": 0.0,
115
+ "sat_alibi_x_grad": 5.820766091346741e-11
116
+ },
117
+ "tol": 0.0002
118
+ },
119
+ {
120
+ "backend": "sdpa",
121
+ "d": 128,
122
+ "dk": 16,
123
+ "heads": 8,
124
+ "ok": true,
125
+ "preset": "micro_3x",
126
+ "rank": 48,
127
+ "rows": {
128
+ "causal_alibi_forward": 0.0,
129
+ "causal_alibi_loss": 0.0,
130
+ "causal_alibi_param_grad": 0.0,
131
+ "causal_alibi_x_grad": 5.820766091346741e-11,
132
+ "legacy_load_forward": 0.0,
133
+ "legacy_load_missing_unexpected": 0.0,
134
+ "legacy_load_qkv_weight": 0.0,
135
+ "none_forward": 0.0,
136
+ "none_loss": 0.0,
137
+ "none_param_grad": 0.0,
138
+ "none_x_grad": 1.4551915228366852e-11,
139
+ "safe_load_any_loaded": 0.0,
140
+ "safe_load_any_qkv": 0.0,
141
+ "sat_alibi_forward": 0.0,
142
+ "sat_alibi_loss": 0.0,
143
+ "sat_alibi_param_grad": 0.0,
144
+ "sat_alibi_x_grad": 5.820766091346741e-11
145
+ },
146
+ "tol": 0.0002
147
+ },
148
+ {
149
+ "backend": "sublinear",
150
+ "d": 128,
151
+ "dk": 16,
152
+ "heads": 8,
153
+ "ok": true,
154
+ "preset": "micro_3x",
155
+ "rank": 48,
156
+ "rows": {
157
+ "causal_alibi_forward": 0.0,
158
+ "causal_alibi_loss": 0.0,
159
+ "causal_alibi_param_grad": 0.0,
160
+ "causal_alibi_x_grad": 5.820766091346741e-11,
161
+ "legacy_load_forward": 0.0,
162
+ "legacy_load_missing_unexpected": 0.0,
163
+ "legacy_load_qkv_weight": 0.0,
164
+ "none_forward": 0.0,
165
+ "none_loss": 0.0,
166
+ "none_param_grad": 0.0,
167
+ "none_x_grad": 1.8189894035458565e-11,
168
+ "safe_load_any_loaded": 0.0,
169
+ "safe_load_any_qkv": 0.0,
170
+ "sat_alibi_forward": 0.0,
171
+ "sat_alibi_loss": 0.0,
172
+ "sat_alibi_param_grad": 0.0,
173
+ "sat_alibi_x_grad": 5.820766091346741e-11
174
+ },
175
+ "tol": 0.0002
176
+ }
177
+ ]
local_verify_qkv_sublinear_agillm4.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "backend": "sublinear",
4
+ "d": 32,
5
+ "dk": 16,
6
+ "heads": 2,
7
+ "ok": true,
8
+ "preset": "pico_1x",
9
+ "rank": 16,
10
+ "rows": {
11
+ "causal_alibi_forward": 0.0,
12
+ "causal_alibi_loss": 0.0,
13
+ "causal_alibi_param_grad": 0.0,
14
+ "causal_alibi_x_grad": 2.3283064365386963e-10,
15
+ "legacy_load_forward": 0.0,
16
+ "legacy_load_missing_unexpected": 0.0,
17
+ "legacy_load_qkv_weight": 0.0,
18
+ "none_forward": 0.0,
19
+ "none_loss": 0.0,
20
+ "none_param_grad": 0.0,
21
+ "none_x_grad": 5.820766091346741e-11,
22
+ "safe_load_any_loaded": 0.0,
23
+ "safe_load_any_qkv": 0.0,
24
+ "sat_alibi_forward": 0.0,
25
+ "sat_alibi_loss": 0.0,
26
+ "sat_alibi_param_grad": 0.0,
27
+ "sat_alibi_x_grad": 2.3283064365386963e-10
28
+ },
29
+ "tol": 0.0002
30
+ }
31
+ ]
nB300_agillm4.py CHANGED
@@ -4,7 +4,7 @@
4
  # Enhanced inference: checkpoint name, tok/s, UK time
5
 
6
  from __future__ import annotations
7
- import argparse, json, math, pathlib, random, time, os, sys, threading, hashlib, re, subprocess
8
  from pathlib import Path
9
  from contextlib import nullcontext
10
  from typing import Dict, Any, List, Optional, Tuple
@@ -1147,9 +1147,10 @@ class TuneableAttentionMHA(nn.Module):
1147
  self.sublinear_stride = max(0, int(sublinear_stride))
1148
  self.sublinear_max_anchors = max(0, int(sublinear_max_anchors))
1149
  self.sublinear_chunk = max(1, int(sublinear_chunk))
1150
- self.q = nn.Linear(d, d, bias=False)
1151
- self.k = nn.Linear(d, d, bias=False)
1152
- self.v = nn.Linear(d, d, bias=False)
 
1153
  self.U = nn.Parameter(torch.randn(self.dk, r))
1154
  nn.init.orthogonal_(self.U)
1155
  self.proj = nn.Linear(h * self.dk, d, bias=False)
@@ -1164,6 +1165,25 @@ class TuneableAttentionMHA(nn.Module):
1164
  self._metric_cache_data_ptr: int = -1
1165
  self._metric_cache_shape: Tuple[int, int] = (-1, -1)
1166
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1167
  def _proj_qk(self, x):
1168
  B, N, _ = x.shape
1169
  return (x.view(B, N, self.h, self.dk).transpose(1, 2) @ self.U)
@@ -1219,9 +1239,10 @@ class TuneableAttentionMHA(nn.Module):
1219
  outputs = []
1220
  scale = 1.0 / math.sqrt(self.dk)
1221
 
1222
- if self.sublinear_stride > 0 and self.sublinear_max_anchors > 0:
 
1223
  anchors = torch.arange(
1224
- self.sublinear_stride - 1,
1225
  k_len,
1226
  self.sublinear_stride,
1227
  device=device,
@@ -1273,9 +1294,8 @@ class TuneableAttentionMHA(nn.Module):
1273
  return torch.cat(outputs, dim=2)
1274
 
1275
  def forward(self, x, mask=None, rel_bias_tokens=None, kv_cache=None, use_cache=False):
1276
- q_lin = self.q(x)
1277
- k_lin = self.k(x)
1278
- v_new = self._reshape_v(self.v(x))
1279
  if self.r > self.dk:
1280
  q = self._reshape_heads(q_lin) @ self._get_metric()
1281
  k_new = self._reshape_heads(k_lin)
@@ -1498,6 +1518,141 @@ def _strip_orig_mod_prefix(state: dict) -> dict:
1498
  for k, v in state.items()
1499
  }
1500
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1501
  def save_delta(core, ar_h, sat_h, step: int, seen_tok: int, save_dir: pathlib.Path, phase_name: str):
1502
  """Save weight-only delta in background thread. Non-blocking."""
1503
  global _delta_thread
@@ -1552,7 +1707,7 @@ def load_delta(path: pathlib.Path, core, ar_h, sat_h):
1552
  ck = torch.load(path, map_location="cpu", weights_only=False)
1553
  if not ck.get("delta"):
1554
  raise ValueError(f"{path.name} is not a delta checkpoint")
1555
- core.load_state_dict(_strip_orig_mod_prefix(ck["weights"]["core"]))
1556
  ar_h.load_state_dict(_strip_orig_mod_prefix(ck["weights"]["ar"]))
1557
  sat_h.load_state_dict(_strip_orig_mod_prefix(ck["weights"]["sat"]))
1558
  return ck.get("step", 0), ck.get("seen_tok", 0)
@@ -1586,11 +1741,25 @@ def load_ckpt(path, core, ar_h, sat_h, opt, scaler):
1586
  p = _resolve_ckpt(path) or path
1587
  ck = _try_load(p, map_location="cpu")
1588
  if ck is None: raise FileNotFoundError(f"No valid checkpoint at {p}")
1589
- core.load_state_dict(_strip_orig_mod_prefix(ck["core"]))
1590
  ar_h.load_state_dict(_strip_orig_mod_prefix(ck["ar"]))
1591
  sat_h.load_state_dict(_strip_orig_mod_prefix(ck["sat"]))
1592
- opt.load_state_dict(ck["opt"])
1593
- scaler.load_state_dict(ck["scaler"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1594
  # Restore tokenizer from checkpoint if available
1595
  if "tokenizer_json" in ck:
1596
  try:
@@ -1614,8 +1783,11 @@ def _safe_load_any(path: pathlib.Path, tgt: nn.Module, key: str | None = None):
1614
  sd = ck.get(key, ck) if key else ck
1615
  if isinstance(sd, dict) and "state_dict" in sd: sd = sd["state_dict"]
1616
  sd = _strip_orig_mod_prefix(sd)
 
 
 
1617
  tgt_sd = tgt.state_dict()
1618
- filt = {k: v for k, v in sd.items() if k in tgt_sd and v.shape == tgt_sd[k].shape}
1619
  if filt: tgt.load_state_dict(filt, strict=False)
1620
  return len(filt)
1621
 
 
4
  # Enhanced inference: checkpoint name, tok/s, UK time
5
 
6
  from __future__ import annotations
7
+ import argparse, copy, json, math, pathlib, random, time, os, sys, threading, hashlib, re, subprocess
8
  from pathlib import Path
9
  from contextlib import nullcontext
10
  from typing import Dict, Any, List, Optional, Tuple
 
1147
  self.sublinear_stride = max(0, int(sublinear_stride))
1148
  self.sublinear_max_anchors = max(0, int(sublinear_max_anchors))
1149
  self.sublinear_chunk = max(1, int(sublinear_chunk))
1150
+ # Exact n1 harvest: one fused QKV projection is mathematically the same
1151
+ # as three independent bias-free Linear(d, d) projections with their
1152
+ # weights stacked along out_features.
1153
+ self.qkv = nn.Linear(d, 3 * d, bias=False)
1154
  self.U = nn.Parameter(torch.randn(self.dk, r))
1155
  nn.init.orthogonal_(self.U)
1156
  self.proj = nn.Linear(h * self.dk, d, bias=False)
 
1165
  self._metric_cache_data_ptr: int = -1
1166
  self._metric_cache_shape: Tuple[int, int] = (-1, -1)
1167
 
1168
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
1169
+ missing_keys, unexpected_keys, error_msgs):
1170
+ qkv_key = prefix + "qkv.weight"
1171
+ if qkv_key not in state_dict:
1172
+ qk = prefix + "q.weight"
1173
+ kk = prefix + "k.weight"
1174
+ vk = prefix + "v.weight"
1175
+ if qk in state_dict and kk in state_dict and vk in state_dict:
1176
+ fused = _cat_legacy_weight_blocks([state_dict[qk], state_dict[kk], state_dict[vk]])
1177
+ if fused is not None:
1178
+ state_dict[qkv_key] = fused
1179
+ state_dict.pop(qk)
1180
+ state_dict.pop(kk)
1181
+ state_dict.pop(vk)
1182
+ return super()._load_from_state_dict(
1183
+ state_dict, prefix, local_metadata, strict,
1184
+ missing_keys, unexpected_keys, error_msgs,
1185
+ )
1186
+
1187
  def _proj_qk(self, x):
1188
  B, N, _ = x.shape
1189
  return (x.view(B, N, self.h, self.dk).transpose(1, 2) @ self.U)
 
1239
  outputs = []
1240
  scale = 1.0 / math.sqrt(self.dk)
1241
 
1242
+ anchor_start = self.sublinear_stride - 1
1243
+ if self.sublinear_stride > 0 and self.sublinear_max_anchors > 0 and anchor_start < k_len:
1244
  anchors = torch.arange(
1245
+ anchor_start,
1246
  k_len,
1247
  self.sublinear_stride,
1248
  device=device,
 
1294
  return torch.cat(outputs, dim=2)
1295
 
1296
  def forward(self, x, mask=None, rel_bias_tokens=None, kv_cache=None, use_cache=False):
1297
+ q_lin, k_lin, v_lin = self.qkv(x).chunk(3, dim=-1)
1298
+ v_new = self._reshape_v(v_lin)
 
1299
  if self.r > self.dk:
1300
  q = self._reshape_heads(q_lin) @ self._get_metric()
1301
  k_new = self._reshape_heads(k_lin)
 
1518
  for k, v in state.items()
1519
  }
1520
 
1521
+ def _cat_legacy_weight_blocks(blocks: list) -> Optional[torch.Tensor]:
1522
+ if not blocks or not all(torch.is_tensor(t) for t in blocks):
1523
+ return None
1524
+ first = blocks[0]
1525
+ tail_shape = tuple(first.shape[1:])
1526
+ if any(t.dtype != first.dtype or t.device != first.device for t in blocks):
1527
+ return None
1528
+ if any(t.ndim != first.ndim or tuple(t.shape[1:]) != tail_shape for t in blocks):
1529
+ return None
1530
+ return torch.cat(blocks, dim=0).contiguous()
1531
+
1532
+ def _fuse_qkv_in_state_dict(sd: dict) -> dict:
1533
+ """Fold legacy q/k/v.weight triples into qkv.weight before loading/filtering."""
1534
+ if not isinstance(sd, dict):
1535
+ return sd
1536
+ prefixes = set()
1537
+ for key in list(sd.keys()):
1538
+ for suffix in (".q.weight", ".k.weight", ".v.weight"):
1539
+ if isinstance(key, str) and key.endswith(suffix):
1540
+ prefixes.add(key[: -len(suffix)])
1541
+ for prefix in prefixes:
1542
+ qk, kk, vk = prefix + ".q.weight", prefix + ".k.weight", prefix + ".v.weight"
1543
+ fk = prefix + ".qkv.weight"
1544
+ if qk in sd and kk in sd and vk in sd and fk not in sd:
1545
+ fused = _cat_legacy_weight_blocks([sd[qk], sd[kk], sd[vk]])
1546
+ if fused is not None:
1547
+ sd[fk] = fused
1548
+ sd.pop(qk)
1549
+ sd.pop(kk)
1550
+ sd.pop(vk)
1551
+ return sd
1552
+
1553
+ def _split_qkv_in_state_dict_for_test(sd: dict) -> dict:
1554
+ out = dict(sd)
1555
+ for key in list(out.keys()):
1556
+ if not isinstance(key, str) or not key.endswith(".qkv.weight"):
1557
+ continue
1558
+ base = key[: -len(".qkv.weight")]
1559
+ q, k, v = out.pop(key).chunk(3, dim=0)
1560
+ out[base + ".q.weight"] = q.clone()
1561
+ out[base + ".k.weight"] = k.clone()
1562
+ out[base + ".v.weight"] = v.clone()
1563
+ return out
1564
+
1565
+ def _clone_opt_value(value):
1566
+ if torch.is_tensor(value):
1567
+ return value.detach().clone()
1568
+ return copy.deepcopy(value)
1569
+
1570
+ def _optimizer_param_name_lookup(core, ar_h, sat_h) -> dict[int, str]:
1571
+ out = {}
1572
+ for prefix, module in (("core", core), ("ar", ar_h), ("sat", sat_h)):
1573
+ for name, param in module.named_parameters():
1574
+ out.setdefault(id(param), f"{prefix}.{name}")
1575
+ return out
1576
+
1577
def _optimizer_group_param_names(opt, core, ar_h, sat_h) -> List[List[str]]:
    """Per optimizer param_group, the dotted names of its params in group order."""
    lookup = _optimizer_param_name_lookup(core, ar_h, sat_h)
    groups: List[List[str]] = []
    for group in opt.param_groups:
        names = []
        for param in group["params"]:
            # Unknown params keep a stable placeholder so alignment never shifts.
            names.append(lookup.get(id(param), f"<unknown:{id(param)}>"))
        groups.append(names)
    return groups
1583
+
1584
+ def _legacy_names_for_current_param(name: str) -> List[str]:
1585
+ if name.endswith(".qkv.weight"):
1586
+ base = name[: -len(".qkv.weight")]
1587
+ return [base + ".q.weight", base + ".k.weight", base + ".v.weight"]
1588
+ return [name]
1589
+
1590
+ def _fuse_legacy_optimizer_param_state(states: List[dict]) -> Optional[dict]:
1591
+ if len(states) < 2 or any(not isinstance(state, dict) for state in states):
1592
+ return None
1593
+ common = set(states[0])
1594
+ for state in states[1:]:
1595
+ common &= set(state)
1596
+ out = {}
1597
+ for key in common:
1598
+ vals = [state[key] for state in states]
1599
+ if all(torch.is_tensor(v) for v in vals):
1600
+ shape = vals[0].shape
1601
+ if vals[0].ndim > 0 and all(v.shape == shape for v in vals[1:]):
1602
+ out[key] = torch.cat([v.detach().clone() for v in vals], dim=0).contiguous()
1603
+ else:
1604
+ out[key] = vals[0].detach().clone()
1605
+ else:
1606
+ out[key] = copy.deepcopy(vals[0])
1607
+ return out
1608
+
1609
def _fuse_legacy_qkv_optimizer_state(opt_state: dict, opt, core, ar_h, sat_h) -> Optional[dict]:
    """Remap pre-QKV-fusion AdamW state to the current fused parameter layout.

    The legacy optimizer enumerated q, k, v as three separate parameters; the
    current one has a single qkv parameter per attention module. This aligns
    the legacy state by *position*: current param names are expanded into their
    legacy equivalents, which must match the legacy groups one-to-one.
    Returns a state_dict loadable by ``opt.load_state_dict`` or None when the
    layouts cannot be aligned.
    """
    if not isinstance(opt_state, dict) or "state" not in opt_state or "param_groups" not in opt_state:
        return None
    current_sd = opt.state_dict()
    current_names = _optimizer_group_param_names(opt, core, ar_h, sat_h)
    # For each current group, the legacy parameter names in legacy order
    # (each fused qkv name expands to three consecutive legacy names).
    legacy_names = [
        [legacy for name in group_names for legacy in _legacy_names_for_current_param(name)]
        for group_names in current_names
    ]
    if len(legacy_names) != len(opt_state.get("param_groups", [])):
        return None

    # Positional alignment: legacy param ids were dense indices, so zipping the
    # expanded names with the saved group's ids recovers name -> legacy pid.
    legacy_name_to_pid = {}
    for group_idx, names in enumerate(legacy_names):
        old_params = list(opt_state["param_groups"][group_idx].get("params", []))
        if len(names) != len(old_params):
            # Group sizes disagree -> this is not a pre-fusion checkpoint.
            return None
        for name, pid in zip(names, old_params):
            legacy_name_to_pid[name] = pid

    # Rebuild param_groups: keep legacy hyperparameters (lr, betas, ...) but
    # adopt the current param-id lists (and names, if the optimizer stores them).
    new_groups = []
    for group_idx, current_group in enumerate(current_sd["param_groups"]):
        new_group = copy.deepcopy(opt_state["param_groups"][group_idx])
        new_group["params"] = list(current_group["params"])
        if "param_names" in new_group:
            new_group["param_names"] = list(current_names[group_idx])
        new_groups.append(new_group)

    old_states = opt_state.get("state", {})
    new_states = {}
    for group_names, current_group in zip(current_names, current_sd["param_groups"]):
        for name, new_pid in zip(group_names, current_group["params"]):
            legacy_set = _legacy_names_for_current_param(name)
            if len(legacy_set) > 1:
                # Fused qkv param: merge the three legacy moment states.
                old_pids = [legacy_name_to_pid.get(legacy) for legacy in legacy_set]
                if all(pid in old_states for pid in old_pids):
                    fused = _fuse_legacy_optimizer_param_state([old_states[pid] for pid in old_pids])
                    if fused is not None:
                        new_states[new_pid] = fused
                        continue
                # Fall through: incomplete/unmergeable triple -> try a 1:1 copy
                # under the unexpanded name (usually absent, leaving no state).
            old_pid = legacy_name_to_pid.get(name)
            if old_pid in old_states:
                new_states[new_pid] = {key: _clone_opt_value(value) for key, value in old_states[old_pid].items()}

    return {"state": new_states, "param_groups": new_groups}
1655
+
1656
  def save_delta(core, ar_h, sat_h, step: int, seen_tok: int, save_dir: pathlib.Path, phase_name: str):
1657
  """Save weight-only delta in background thread. Non-blocking."""
1658
  global _delta_thread
 
1707
  ck = torch.load(path, map_location="cpu", weights_only=False)
1708
  if not ck.get("delta"):
1709
  raise ValueError(f"{path.name} is not a delta checkpoint")
1710
+ core.load_state_dict(_fuse_qkv_in_state_dict(_strip_orig_mod_prefix(ck["weights"]["core"])))
1711
  ar_h.load_state_dict(_strip_orig_mod_prefix(ck["weights"]["ar"]))
1712
  sat_h.load_state_dict(_strip_orig_mod_prefix(ck["weights"]["sat"]))
1713
  return ck.get("step", 0), ck.get("seen_tok", 0)
 
1741
  p = _resolve_ckpt(path) or path
1742
  ck = _try_load(p, map_location="cpu")
1743
  if ck is None: raise FileNotFoundError(f"No valid checkpoint at {p}")
1744
+ core.load_state_dict(_fuse_qkv_in_state_dict(_strip_orig_mod_prefix(ck["core"])))
1745
  ar_h.load_state_dict(_strip_orig_mod_prefix(ck["ar"]))
1746
  sat_h.load_state_dict(_strip_orig_mod_prefix(ck["sat"]))
1747
+ try:
1748
+ opt.load_state_dict(ck["opt"])
1749
+ except Exception as exc:
1750
+ fused_opt = _fuse_legacy_qkv_optimizer_state(ck.get("opt"), opt, core, ar_h, sat_h)
1751
+ if fused_opt is not None:
1752
+ try:
1753
+ opt.load_state_dict(fused_opt)
1754
+ print("[ckpt] Converted legacy q/k/v optimizer state to fused qkv layout")
1755
+ except Exception as exc2:
1756
+ print(f"[ckpt] WARNING: optimizer state incompatible; resetting optimizer ({type(exc).__name__}: {exc}; qkv remap failed: {type(exc2).__name__}: {exc2})")
1757
+ else:
1758
+ print(f"[ckpt] WARNING: optimizer state incompatible; resetting optimizer ({type(exc).__name__}: {exc})")
1759
+ try:
1760
+ scaler.load_state_dict(ck["scaler"])
1761
+ except Exception as exc:
1762
+ print(f"[ckpt] WARNING: scaler state incompatible; resetting scaler ({type(exc).__name__}: {exc})")
1763
  # Restore tokenizer from checkpoint if available
1764
  if "tokenizer_json" in ck:
1765
  try:
 
1783
  sd = ck.get(key, ck) if key else ck
1784
  if isinstance(sd, dict) and "state_dict" in sd: sd = sd["state_dict"]
1785
  sd = _strip_orig_mod_prefix(sd)
1786
+ sd = _fuse_qkv_in_state_dict(dict(sd)) if isinstance(sd, dict) else sd
1787
+ if not isinstance(sd, dict):
1788
+ return 0
1789
  tgt_sd = tgt.state_dict()
1790
+ filt = {k: v for k, v in sd.items() if k in tgt_sd and hasattr(v, "shape") and v.shape == tgt_sd[k].shape}
1791
  if filt: tgt.load_state_dict(filt, strict=False)
1792
  return len(filt)
1793
 
verify_m_fold_agillm4.py CHANGED
@@ -24,9 +24,10 @@ def causal_mask_cached(new_len: int, cached_len: int):
24
 
25
  def old_expanded_forward(mha: nb.TuneableAttentionMHA, x: torch.Tensor, mask=None, rel_bias_tokens=None):
26
  bsz, seq, _ = x.shape
27
- q = mha._reshape_heads(mha.q(x)) @ mha.U
28
- k = mha._reshape_heads(mha.k(x)) @ mha.U
29
- v = mha._reshape_v(mha.v(x))
 
30
  att = (q @ k.transpose(-1, -2)) / math.sqrt(mha.dk)
31
  if mha.use_relpos and rel_bias_tokens is not None:
32
  att = att + nb.alibi_bias(mha.h, rel_bias_tokens)[:, :, -seq:, :]
 
24
 
25
  def old_expanded_forward(mha: nb.TuneableAttentionMHA, x: torch.Tensor, mask=None, rel_bias_tokens=None):
26
  bsz, seq, _ = x.shape
27
+ q_lin, k_lin, v_lin = mha.qkv(x).chunk(3, dim=-1)
28
+ q = mha._reshape_heads(q_lin) @ mha.U
29
+ k = mha._reshape_heads(k_lin) @ mha.U
30
+ v = mha._reshape_v(v_lin)
31
  att = (q @ k.transpose(-1, -2)) / math.sqrt(mha.dk)
32
  if mha.use_relpos and rel_bias_tokens is not None:
33
  att = att + nb.alibi_bias(mha.h, rel_bias_tokens)[:, :, -seq:, :]
verify_qkv_agillm4.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ from __future__ import annotations
3
+
4
+ import argparse
5
+ import copy
6
+ import json
7
+ import math
8
+ import os
9
+ import tempfile
10
+ from pathlib import Path
11
+ from types import SimpleNamespace
12
+
13
+ os.environ.setdefault("AGILLM_SYNTHETIC_TOKENIZER", "1")
14
+
15
+ import torch
16
+ import torch.nn.functional as F
17
+
18
+ import nB300_agillm4 as nb
19
+
20
+
21
def unfused_reference(mha: nb.TuneableAttentionMHA, x: torch.Tensor, mask=None, rel_bias_tokens=None):
    """Reference forward pass that applies q/k/v projections separately.

    Used to verify the fused qkv.weight path: the fused weight is chunked back
    into the three projection matrices and the rest of the attention math
    mirrors the module's own backends. Assumes x is (batch, seq, d) — the
    unpack below fixes rank 3.
    """
    bsz, seq, _ = x.shape
    # Fused weight layout is [Wq; Wk; Wv] stacked along dim 0.
    wq, wk, wv = mha.qkv.weight.chunk(3, dim=0)
    q_lin = x @ wq.T
    k_lin = x @ wk.T
    v_lin = x @ wv.T
    v = mha._reshape_v(v_lin)
    if mha.r > mha.dk:
        # NOTE(review): presumably the expanded-rank (M-fold) path, where the
        # metric U U^T is folded into q only — confirm against the module.
        q = mha._reshape_heads(q_lin) @ mha._get_metric()
        k = mha._reshape_heads(k_lin)
    else:
        # Low-rank path: both q and k are projected through U.
        q = mha._reshape_heads(q_lin) @ mha.U
        k = mha._reshape_heads(k_lin) @ mha.U

    # Combine ALiBi relative-position bias (last `seq` query rows) with the
    # additive attention mask, when either is active.
    attn_bias = None
    if mha.use_relpos and rel_bias_tokens is not None:
        attn_bias = nb.alibi_bias(mha.h, rel_bias_tokens)[:, :, -seq:, :]
    if mask is not None:
        attn_bias = mask if attn_bias is None else attn_bias + mask

    if mha.attn_backend == "sdpa":
        try:
            # Scale by 1/sqrt(dk) explicitly: q's last dim may differ from dk.
            z = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_bias,
                dropout_p=0.0,
                scale=1.0 / math.sqrt(mha.dk),
            )
        except TypeError:
            # Older torch without the `scale` kwarg: pre-scale q so the default
            # 1/sqrt(q.size(-1)) scaling nets out to 1/sqrt(dk).
            q_scaled = q * math.sqrt(q.size(-1) / mha.dk)
            z = F.scaled_dot_product_attention(q_scaled, k, v, attn_mask=attn_bias, dropout_p=0.0)
    elif mha.attn_backend == "sublinear":
        z = mha._sublinear_attention(q, k, v, attn_mask=attn_bias)
    else:
        # Manual backend: explicit softmax(QK^T / sqrt(dk) + bias) V.
        att = (q @ k.transpose(-1, -2)) / math.sqrt(mha.dk)
        if attn_bias is not None:
            att = att + attn_bias
        z = att.softmax(-1) @ v
    # (b, h, s, dv) -> (b, s, h*dv), then output projection + dropout.
    z = z.transpose(1, 2).reshape(bsz, seq, -1)
    return mha.drop(mha.proj(z))
61
+
62
+
63
def max_param_grad_diff(module, ref_grads: dict[str, torch.Tensor]) -> float:
    """Largest absolute elementwise gap between module's current grads and *ref_grads*.

    Parameters with no gradient are skipped; returns 0.0 when nothing compares.
    """
    diffs = [
        (param.grad.detach() - ref_grads[name]).abs().max().item()
        for name, param in module.named_parameters()
        if param.grad is not None
    ]
    return max(diffs, default=0.0)
70
+
71
+
72
def optimizer_for(core, ar_h, sat_h):
    """Build the standard AdamW optimizer over the three modules, matching training."""
    opt_args = SimpleNamespace(optimizer="adamw")
    return nb.make_optimizer(opt_args, core, ar_h, sat_h, nb.LR_CORE, nb.LR_HEAD)
75
+
76
+
77
def split_qkv_optimizer_state_for_test(opt_state: dict, group_names: list[list[str]]) -> dict:
    """Invert the qkv fusion on an optimizer state_dict (test fixture builder).

    Every fused qkv.weight slot becomes three consecutive legacy slots; tensor
    states whose dim 0 is divisible by 3 are chunked, everything else is cloned
    into all three. Param ids are re-numbered densely across all groups.
    """
    result = {"state": {}, "param_groups": []}
    pid_counter = 0
    for group, names in zip(opt_state["param_groups"], group_names):
        new_group = copy.deepcopy(group)
        new_params = []
        for old_pid, name in zip(group["params"], names):
            old_state = opt_state.get("state", {}).get(old_pid, {})
            if name.endswith(".qkv.weight"):
                pieces = [{}, {}, {}]
                for key, value in old_state.items():
                    if torch.is_tensor(value) and value.ndim > 0 and value.shape[0] % 3 == 0:
                        # Moment buffers: split back into q, k, v thirds.
                        for piece, chunk in zip(pieces, value.detach().chunk(3, dim=0)):
                            piece[key] = chunk.clone().contiguous()
                    else:
                        # Scalars ("step") and non-tensors replicate per slot.
                        for piece in pieces:
                            piece[key] = nb._clone_opt_value(value)
                emitted = pieces
            else:
                emitted = [{key: nb._clone_opt_value(value) for key, value in old_state.items()}]
            for state in emitted:
                new_params.append(pid_counter)
                # Empty per-param states are omitted, as torch does.
                if state:
                    result["state"][pid_counter] = state
                pid_counter += 1
        new_group["params"] = new_params
        result["param_groups"].append(new_group)
    return result
110
+
111
+
112
def run_optimizer_remap_check(cfg: dict) -> dict:
    """End-to-end check that legacy (split q/k/v) AdamW state remaps onto the fused layout.

    Builds a model, takes one optimizer step so every parameter has AdamW
    state, synthesizes a legacy-layout optimizer state from it, then verifies
    that ``nb._fuse_legacy_qkv_optimizer_state`` yields state loadable into a
    fresh optimizer with correctly shaped fused-qkv moment buffers.

    Fix vs. previous version: the qkv shape check used to overwrite its flag
    on every iteration, so only the LAST fused parameter was actually
    validated; now every qkv parameter is checked and at least one must exist.

    Returns {"optimizer_remap": 0.0} on success, 1.0 on failure.
    """
    core = nb.Encoder(cfg, attn_backend="manual").to(nb.DEV).train()
    ar_h = nb.ARHead(cfg["d"]).to(nb.DEV).train()
    sat_h = nb.SATHead(cfg["d"]).to(nb.DEV).train()
    opt = optimizer_for(core, ar_h, sat_h)
    ids = torch.randint(0, nb.VOCAB, (1, 8), device=nb.DEV)
    loss = ar_h(core(ids, nb.causal_mask(ids.size(1)))).float().square().mean()
    loss.backward()
    opt.step()  # populate exp_avg / exp_avg_sq for every parameter
    opt.zero_grad(set_to_none=True)
    current_names = nb._optimizer_group_param_names(opt, core, ar_h, sat_h)
    legacy_opt = split_qkv_optimizer_state_for_test(opt.state_dict(), current_names)

    fresh = nb.Encoder(cfg, attn_backend="manual").to(nb.DEV)
    fresh_ar = nb.ARHead(cfg["d"]).to(nb.DEV)
    fresh_sat = nb.SATHead(cfg["d"]).to(nb.DEV)
    fresh_opt = optimizer_for(fresh, fresh_ar, fresh_sat)
    fused = nb._fuse_legacy_qkv_optimizer_state(legacy_opt, fresh_opt, fresh, fresh_ar, fresh_sat)
    ok = fused is not None
    if ok:
        fresh_opt.load_state_dict(fused)
        loaded = fresh_opt.state_dict()
        names = nb._optimizer_group_param_names(fresh_opt, fresh, fresh_ar, fresh_sat)
        fresh_params = dict(fresh.named_parameters())
        qkv_checked = 0
        for group, group_names in zip(loaded["param_groups"], names):
            for pid, name in zip(group["params"], group_names):
                if not name.endswith(".qkv.weight") or pid not in loaded["state"]:
                    continue
                qkv_checked += 1
                exp_avg = loaded["state"][pid].get("exp_avg")
                # qkv params live on the core encoder, so strip the "core." prefix.
                expected = fresh_params[name[len("core."):]].shape
                if not (torch.is_tensor(exp_avg) and tuple(exp_avg.shape) == tuple(expected)):
                    ok = False
        # The check is vacuous unless at least one fused qkv param was seen.
        ok = ok and qkv_checked > 0
    return {"optimizer_remap": 0.0 if ok else 1.0}
143
+
144
+
145
def verify_case(args, preset: str, backend: str) -> dict:
    """Run all fused-QKV equivalence checks for one (preset, backend) pair.

    Covers: forward/loss/gradient parity between the fused module and the
    unfused reference; loading legacy split q/k/v state dicts; the
    ``_safe_load_any`` checkpoint path; and (pico_1x/manual only) the
    optimizer-state remap. Returns a result dict with per-check max-abs-diff
    rows; the case passes when every row is <= args.tol.
    """
    torch.manual_seed(args.seed)
    cfg = nb.PRESETS[preset].copy()
    d, h, r = cfg["d"], cfg["heads"], cfg["rank"]
    seq = args.cached_len + args.new_len
    mha = nb.TuneableAttentionMHA(d, h, r, attn_backend=backend).to(nb.DEV).eval()
    rows = {}
    # Three masking regimes: no mask, causal + ALiBi, SAT mask + ALiBi.
    for case_name, mask, rel_tokens in [
        ("none", None, None),
        ("causal_alibi", nb.causal_mask(seq), seq),
        ("sat_alibi", nb.sat_mask(seq), seq),
    ]:
        # Identical inputs on two independent autograd graphs.
        x_fused = torch.randn(2, seq, d, device=nb.DEV, requires_grad=True)
        x_ref = x_fused.detach().clone().requires_grad_(True)
        y_fused = mha(x_fused, mask, rel_bias_tokens=rel_tokens)
        y_ref = unfused_reference(mha, x_ref, mask=mask, rel_bias_tokens=rel_tokens)
        loss_fused = y_fused.square().mean()
        loss_ref = y_ref.square().mean()
        loss_fused.backward()
        fused_x_grad = x_fused.grad.detach().clone()
        fused_param_grads = {
            name: param.grad.detach().clone()
            for name, param in mha.named_parameters()
            if param.grad is not None
        }
        # Zero the module grads so the reference backward accumulates cleanly,
        # then compare the module's (now reference) grads to the fused snapshot.
        mha.zero_grad(set_to_none=True)
        loss_ref.backward()
        ref_x_grad = x_ref.grad.detach().clone()
        rows[f"{case_name}_forward"] = (y_fused - y_ref).abs().max().item()
        rows[f"{case_name}_loss"] = abs(loss_fused.item() - loss_ref.item())
        rows[f"{case_name}_x_grad"] = (fused_x_grad - ref_x_grad).abs().max().item()
        rows[f"{case_name}_param_grad"] = max_param_grad_diff(mha, fused_param_grads)
        mha.zero_grad(set_to_none=True)

    # Legacy state-dict round trip: strict load must succeed with no
    # missing/unexpected keys (the module fuses q/k/v keys on load —
    # presumably via a load hook; confirm in nB300_agillm4).
    legacy_sd = nb._split_qkv_in_state_dict_for_test(mha.state_dict())
    fresh = nb.TuneableAttentionMHA(d, h, r, attn_backend=backend).to(nb.DEV).eval()
    missing, unexpected = fresh.load_state_dict(dict(legacy_sd), strict=True)
    rows["legacy_load_missing_unexpected"] = float(len(missing) + len(unexpected))
    rows["legacy_load_qkv_weight"] = (fresh.qkv.weight.detach() - mha.qkv.weight.detach()).abs().max().item()
    with torch.no_grad():
        x = torch.randn(2, seq, d, device=nb.DEV)
        rows["legacy_load_forward"] = (fresh(x) - mha(x)).abs().max().item()

    # Checkpoint path: save a legacy-layout core state and reload via
    # _safe_load_any, which should fuse and filter it into the new encoder.
    core = nb.Encoder(cfg, attn_backend=backend).to(nb.DEV).eval()
    legacy_core_sd = nb._split_qkv_in_state_dict_for_test(core.state_dict())
    dst = nb.Encoder(cfg, attn_backend=backend).to(nb.DEV).eval()
    with tempfile.TemporaryDirectory() as tmpdir:
        ckpt = Path(tmpdir) / "legacy_core.pt"
        torch.save({"core": legacy_core_sd}, ckpt)
        loaded = nb._safe_load_any(ckpt, dst, key="core")
    rows["safe_load_any_loaded"] = 0.0 if loaded > 0 else 1.0
    # NOTE(review): this max() raises ValueError if the encoder has no
    # .qkv.weight params — assumed impossible for these presets; confirm.
    rows["safe_load_any_qkv"] = max(
        (a - b).abs().max().item()
        for (name, a), b in [
            ((name, param.detach()), dict(dst.named_parameters())[name].detach())
            for name, param in core.named_parameters()
            if name.endswith(".qkv.weight")
        ]
    )
    # The optimizer remap check is expensive; run it once on the smallest case.
    if preset == "pico_1x" and backend == "manual":
        rows.update(run_optimizer_remap_check(cfg))

    ok = all(value <= args.tol for value in rows.values())
    return {
        "preset": preset,
        "backend": backend,
        "d": d,
        "heads": h,
        "rank": r,
        "dk": d // h,
        "ok": ok,
        "tol": args.tol,
        "rows": rows,
    }
219
+
220
+
221
def main() -> int:
    """CLI entry point: verify every (preset, backend) pair; 0 on full success."""
    parser = argparse.ArgumentParser(description="Verify AGILLM-4 fused QKV harvest from n1.py")
    parser.add_argument("--presets", default="pico_1x,micro_3x")
    parser.add_argument("--backends", default="manual,sdpa")
    parser.add_argument("--cached_len", type=int, default=8)
    parser.add_argument("--new_len", type=int, default=4)
    parser.add_argument("--seed", type=int, default=5678)
    parser.add_argument("--tol", type=float, default=2e-4)
    parser.add_argument("--json_out", default="")
    args = parser.parse_args()

    presets = [token.strip() for token in args.presets.split(",") if token.strip()]
    backends = [token.strip() for token in args.backends.split(",") if token.strip()]
    results = []
    all_ok = True
    for preset in presets:
        for backend in backends:
            outcome = verify_case(args, preset, backend)
            results.append(outcome)
            if not outcome["ok"]:
                all_ok = False
            # Stream each result as it completes so long runs show progress.
            print(json.dumps(outcome, sort_keys=True), flush=True)
    if args.json_out:
        Path(args.json_out).write_text(json.dumps(results, indent=2, sort_keys=True), encoding="utf-8")
    return 0 if all_ok else 1
243
+
244
+
245
if __name__ == "__main__":
    # Propagate main()'s exit status to the shell.
    raise SystemExit(main())