seconds-0 commited on
Commit
622dbbd
·
verified ·
1 Parent(s): dcd5fea

NSA 117M initial export

Browse files
logs/logs_extra_keys.txt CHANGED
@@ -1,49 +1 @@
1
- blocks.0.attn.gate.fc1.bias
2
- blocks.0.attn.gate.fc1.weight
3
- blocks.0.attn.gate.fc2.bias
4
- blocks.0.attn.gate.fc2.weight
5
- blocks.1.attn.gate.fc1.bias
6
- blocks.1.attn.gate.fc1.weight
7
- blocks.1.attn.gate.fc2.bias
8
- blocks.1.attn.gate.fc2.weight
9
- blocks.10.attn.gate.fc1.bias
10
- blocks.10.attn.gate.fc1.weight
11
- blocks.10.attn.gate.fc2.bias
12
- blocks.10.attn.gate.fc2.weight
13
- blocks.11.attn.gate.fc1.bias
14
- blocks.11.attn.gate.fc1.weight
15
- blocks.11.attn.gate.fc2.bias
16
- blocks.11.attn.gate.fc2.weight
17
- blocks.2.attn.gate.fc1.bias
18
- blocks.2.attn.gate.fc1.weight
19
- blocks.2.attn.gate.fc2.bias
20
- blocks.2.attn.gate.fc2.weight
21
- blocks.3.attn.gate.fc1.bias
22
- blocks.3.attn.gate.fc1.weight
23
- blocks.3.attn.gate.fc2.bias
24
- blocks.3.attn.gate.fc2.weight
25
- blocks.4.attn.gate.fc1.bias
26
- blocks.4.attn.gate.fc1.weight
27
- blocks.4.attn.gate.fc2.bias
28
- blocks.4.attn.gate.fc2.weight
29
- blocks.5.attn.gate.fc1.bias
30
- blocks.5.attn.gate.fc1.weight
31
- blocks.5.attn.gate.fc2.bias
32
- blocks.5.attn.gate.fc2.weight
33
- blocks.6.attn.gate.fc1.bias
34
- blocks.6.attn.gate.fc1.weight
35
- blocks.6.attn.gate.fc2.bias
36
- blocks.6.attn.gate.fc2.weight
37
- blocks.7.attn.gate.fc1.bias
38
- blocks.7.attn.gate.fc1.weight
39
- blocks.7.attn.gate.fc2.bias
40
- blocks.7.attn.gate.fc2.weight
41
- blocks.8.attn.gate.fc1.bias
42
- blocks.8.attn.gate.fc1.weight
43
- blocks.8.attn.gate.fc2.bias
44
- blocks.8.attn.gate.fc2.weight
45
- blocks.9.attn.gate.fc1.bias
46
- blocks.9.attn.gate.fc1.weight
47
- blocks.9.attn.gate.fc2.bias
48
- blocks.9.attn.gate.fc2.weight
49
  norm_f.weight
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  norm_f.weight
logs/logs_mapping.json CHANGED
@@ -7,6 +7,10 @@
7
  "model.blocks.0.attn.W_V_cmp.weight",
8
  "model.blocks.0.attn.W_V_sel.weight",
9
  "model.blocks.0.attn.W_V_win.weight",
 
 
 
 
10
  "model.blocks.0.attn.out.weight",
11
  "model.blocks.0.mlp.fc1.weight",
12
  "model.blocks.0.mlp.fc2.weight",
@@ -19,6 +23,10 @@
19
  "model.blocks.1.attn.W_V_cmp.weight",
20
  "model.blocks.1.attn.W_V_sel.weight",
21
  "model.blocks.1.attn.W_V_win.weight",
 
 
 
 
22
  "model.blocks.1.attn.out.weight",
23
  "model.blocks.1.mlp.fc1.weight",
24
  "model.blocks.1.mlp.fc2.weight",
@@ -31,6 +39,10 @@
31
  "model.blocks.10.attn.W_V_cmp.weight",
32
  "model.blocks.10.attn.W_V_sel.weight",
33
  "model.blocks.10.attn.W_V_win.weight",
 
 
 
 
34
  "model.blocks.10.attn.out.weight",
35
  "model.blocks.10.mlp.fc1.weight",
36
  "model.blocks.10.mlp.fc2.weight",
@@ -43,6 +55,10 @@
43
  "model.blocks.11.attn.W_V_cmp.weight",
44
  "model.blocks.11.attn.W_V_sel.weight",
45
  "model.blocks.11.attn.W_V_win.weight",
 
 
 
 
46
  "model.blocks.11.attn.out.weight",
47
  "model.blocks.11.mlp.fc1.weight",
48
  "model.blocks.11.mlp.fc2.weight",
@@ -55,6 +71,10 @@
55
  "model.blocks.2.attn.W_V_cmp.weight",
56
  "model.blocks.2.attn.W_V_sel.weight",
57
  "model.blocks.2.attn.W_V_win.weight",
 
 
 
 
58
  "model.blocks.2.attn.out.weight",
59
  "model.blocks.2.mlp.fc1.weight",
60
  "model.blocks.2.mlp.fc2.weight",
@@ -67,6 +87,10 @@
67
  "model.blocks.3.attn.W_V_cmp.weight",
68
  "model.blocks.3.attn.W_V_sel.weight",
69
  "model.blocks.3.attn.W_V_win.weight",
 
 
 
 
70
  "model.blocks.3.attn.out.weight",
71
  "model.blocks.3.mlp.fc1.weight",
72
  "model.blocks.3.mlp.fc2.weight",
@@ -79,6 +103,10 @@
79
  "model.blocks.4.attn.W_V_cmp.weight",
80
  "model.blocks.4.attn.W_V_sel.weight",
81
  "model.blocks.4.attn.W_V_win.weight",
 
 
 
 
82
  "model.blocks.4.attn.out.weight",
83
  "model.blocks.4.mlp.fc1.weight",
84
  "model.blocks.4.mlp.fc2.weight",
@@ -91,6 +119,10 @@
91
  "model.blocks.5.attn.W_V_cmp.weight",
92
  "model.blocks.5.attn.W_V_sel.weight",
93
  "model.blocks.5.attn.W_V_win.weight",
 
 
 
 
94
  "model.blocks.5.attn.out.weight",
95
  "model.blocks.5.mlp.fc1.weight",
96
  "model.blocks.5.mlp.fc2.weight",
@@ -103,6 +135,10 @@
103
  "model.blocks.6.attn.W_V_cmp.weight",
104
  "model.blocks.6.attn.W_V_sel.weight",
105
  "model.blocks.6.attn.W_V_win.weight",
 
 
 
 
106
  "model.blocks.6.attn.out.weight",
107
  "model.blocks.6.mlp.fc1.weight",
108
  "model.blocks.6.mlp.fc2.weight",
@@ -115,6 +151,10 @@
115
  "model.blocks.7.attn.W_V_cmp.weight",
116
  "model.blocks.7.attn.W_V_sel.weight",
117
  "model.blocks.7.attn.W_V_win.weight",
 
 
 
 
118
  "model.blocks.7.attn.out.weight",
119
  "model.blocks.7.mlp.fc1.weight",
120
  "model.blocks.7.mlp.fc2.weight",
@@ -127,6 +167,10 @@
127
  "model.blocks.8.attn.W_V_cmp.weight",
128
  "model.blocks.8.attn.W_V_sel.weight",
129
  "model.blocks.8.attn.W_V_win.weight",
 
 
 
 
130
  "model.blocks.8.attn.out.weight",
131
  "model.blocks.8.mlp.fc1.weight",
132
  "model.blocks.8.mlp.fc2.weight",
@@ -139,6 +183,10 @@
139
  "model.blocks.9.attn.W_V_cmp.weight",
140
  "model.blocks.9.attn.W_V_sel.weight",
141
  "model.blocks.9.attn.W_V_win.weight",
 
 
 
 
142
  "model.blocks.9.attn.out.weight",
143
  "model.blocks.9.mlp.fc1.weight",
144
  "model.blocks.9.mlp.fc2.weight",
@@ -148,82 +196,10 @@
148
  "model.lm_head.weight"
149
  ],
150
  "missing": [
151
- "model.blocks.0.attn.g1.weight",
152
- "model.blocks.0.attn.g2.weight",
153
- "model.blocks.1.attn.g1.weight",
154
- "model.blocks.1.attn.g2.weight",
155
- "model.blocks.10.attn.g1.weight",
156
- "model.blocks.10.attn.g2.weight",
157
- "model.blocks.11.attn.g1.weight",
158
- "model.blocks.11.attn.g2.weight",
159
- "model.blocks.2.attn.g1.weight",
160
- "model.blocks.2.attn.g2.weight",
161
- "model.blocks.3.attn.g1.weight",
162
- "model.blocks.3.attn.g2.weight",
163
- "model.blocks.4.attn.g1.weight",
164
- "model.blocks.4.attn.g2.weight",
165
- "model.blocks.5.attn.g1.weight",
166
- "model.blocks.5.attn.g2.weight",
167
- "model.blocks.6.attn.g1.weight",
168
- "model.blocks.6.attn.g2.weight",
169
- "model.blocks.7.attn.g1.weight",
170
- "model.blocks.7.attn.g2.weight",
171
- "model.blocks.8.attn.g1.weight",
172
- "model.blocks.8.attn.g2.weight",
173
- "model.blocks.9.attn.g1.weight",
174
- "model.blocks.9.attn.g2.weight",
175
  "model.norm.bias",
176
  "model.norm.weight"
177
  ],
178
  "extra": [
179
- "blocks.0.attn.gate.fc1.bias",
180
- "blocks.0.attn.gate.fc1.weight",
181
- "blocks.0.attn.gate.fc2.bias",
182
- "blocks.0.attn.gate.fc2.weight",
183
- "blocks.1.attn.gate.fc1.bias",
184
- "blocks.1.attn.gate.fc1.weight",
185
- "blocks.1.attn.gate.fc2.bias",
186
- "blocks.1.attn.gate.fc2.weight",
187
- "blocks.10.attn.gate.fc1.bias",
188
- "blocks.10.attn.gate.fc1.weight",
189
- "blocks.10.attn.gate.fc2.bias",
190
- "blocks.10.attn.gate.fc2.weight",
191
- "blocks.11.attn.gate.fc1.bias",
192
- "blocks.11.attn.gate.fc1.weight",
193
- "blocks.11.attn.gate.fc2.bias",
194
- "blocks.11.attn.gate.fc2.weight",
195
- "blocks.2.attn.gate.fc1.bias",
196
- "blocks.2.attn.gate.fc1.weight",
197
- "blocks.2.attn.gate.fc2.bias",
198
- "blocks.2.attn.gate.fc2.weight",
199
- "blocks.3.attn.gate.fc1.bias",
200
- "blocks.3.attn.gate.fc1.weight",
201
- "blocks.3.attn.gate.fc2.bias",
202
- "blocks.3.attn.gate.fc2.weight",
203
- "blocks.4.attn.gate.fc1.bias",
204
- "blocks.4.attn.gate.fc1.weight",
205
- "blocks.4.attn.gate.fc2.bias",
206
- "blocks.4.attn.gate.fc2.weight",
207
- "blocks.5.attn.gate.fc1.bias",
208
- "blocks.5.attn.gate.fc1.weight",
209
- "blocks.5.attn.gate.fc2.bias",
210
- "blocks.5.attn.gate.fc2.weight",
211
- "blocks.6.attn.gate.fc1.bias",
212
- "blocks.6.attn.gate.fc1.weight",
213
- "blocks.6.attn.gate.fc2.bias",
214
- "blocks.6.attn.gate.fc2.weight",
215
- "blocks.7.attn.gate.fc1.bias",
216
- "blocks.7.attn.gate.fc1.weight",
217
- "blocks.7.attn.gate.fc2.bias",
218
- "blocks.7.attn.gate.fc2.weight",
219
- "blocks.8.attn.gate.fc1.bias",
220
- "blocks.8.attn.gate.fc1.weight",
221
- "blocks.8.attn.gate.fc2.bias",
222
- "blocks.8.attn.gate.fc2.weight",
223
- "blocks.9.attn.gate.fc1.bias",
224
- "blocks.9.attn.gate.fc1.weight",
225
- "blocks.9.attn.gate.fc2.bias",
226
- "blocks.9.attn.gate.fc2.weight",
227
  "norm_f.weight"
228
  ]
229
  }
 
7
  "model.blocks.0.attn.W_V_cmp.weight",
8
  "model.blocks.0.attn.W_V_sel.weight",
9
  "model.blocks.0.attn.W_V_win.weight",
10
+ "model.blocks.0.attn.gate.fc1.bias",
11
+ "model.blocks.0.attn.gate.fc1.weight",
12
+ "model.blocks.0.attn.gate.fc2.bias",
13
+ "model.blocks.0.attn.gate.fc2.weight",
14
  "model.blocks.0.attn.out.weight",
15
  "model.blocks.0.mlp.fc1.weight",
16
  "model.blocks.0.mlp.fc2.weight",
 
23
  "model.blocks.1.attn.W_V_cmp.weight",
24
  "model.blocks.1.attn.W_V_sel.weight",
25
  "model.blocks.1.attn.W_V_win.weight",
26
+ "model.blocks.1.attn.gate.fc1.bias",
27
+ "model.blocks.1.attn.gate.fc1.weight",
28
+ "model.blocks.1.attn.gate.fc2.bias",
29
+ "model.blocks.1.attn.gate.fc2.weight",
30
  "model.blocks.1.attn.out.weight",
31
  "model.blocks.1.mlp.fc1.weight",
32
  "model.blocks.1.mlp.fc2.weight",
 
39
  "model.blocks.10.attn.W_V_cmp.weight",
40
  "model.blocks.10.attn.W_V_sel.weight",
41
  "model.blocks.10.attn.W_V_win.weight",
42
+ "model.blocks.10.attn.gate.fc1.bias",
43
+ "model.blocks.10.attn.gate.fc1.weight",
44
+ "model.blocks.10.attn.gate.fc2.bias",
45
+ "model.blocks.10.attn.gate.fc2.weight",
46
  "model.blocks.10.attn.out.weight",
47
  "model.blocks.10.mlp.fc1.weight",
48
  "model.blocks.10.mlp.fc2.weight",
 
55
  "model.blocks.11.attn.W_V_cmp.weight",
56
  "model.blocks.11.attn.W_V_sel.weight",
57
  "model.blocks.11.attn.W_V_win.weight",
58
+ "model.blocks.11.attn.gate.fc1.bias",
59
+ "model.blocks.11.attn.gate.fc1.weight",
60
+ "model.blocks.11.attn.gate.fc2.bias",
61
+ "model.blocks.11.attn.gate.fc2.weight",
62
  "model.blocks.11.attn.out.weight",
63
  "model.blocks.11.mlp.fc1.weight",
64
  "model.blocks.11.mlp.fc2.weight",
 
71
  "model.blocks.2.attn.W_V_cmp.weight",
72
  "model.blocks.2.attn.W_V_sel.weight",
73
  "model.blocks.2.attn.W_V_win.weight",
74
+ "model.blocks.2.attn.gate.fc1.bias",
75
+ "model.blocks.2.attn.gate.fc1.weight",
76
+ "model.blocks.2.attn.gate.fc2.bias",
77
+ "model.blocks.2.attn.gate.fc2.weight",
78
  "model.blocks.2.attn.out.weight",
79
  "model.blocks.2.mlp.fc1.weight",
80
  "model.blocks.2.mlp.fc2.weight",
 
87
  "model.blocks.3.attn.W_V_cmp.weight",
88
  "model.blocks.3.attn.W_V_sel.weight",
89
  "model.blocks.3.attn.W_V_win.weight",
90
+ "model.blocks.3.attn.gate.fc1.bias",
91
+ "model.blocks.3.attn.gate.fc1.weight",
92
+ "model.blocks.3.attn.gate.fc2.bias",
93
+ "model.blocks.3.attn.gate.fc2.weight",
94
  "model.blocks.3.attn.out.weight",
95
  "model.blocks.3.mlp.fc1.weight",
96
  "model.blocks.3.mlp.fc2.weight",
 
103
  "model.blocks.4.attn.W_V_cmp.weight",
104
  "model.blocks.4.attn.W_V_sel.weight",
105
  "model.blocks.4.attn.W_V_win.weight",
106
+ "model.blocks.4.attn.gate.fc1.bias",
107
+ "model.blocks.4.attn.gate.fc1.weight",
108
+ "model.blocks.4.attn.gate.fc2.bias",
109
+ "model.blocks.4.attn.gate.fc2.weight",
110
  "model.blocks.4.attn.out.weight",
111
  "model.blocks.4.mlp.fc1.weight",
112
  "model.blocks.4.mlp.fc2.weight",
 
119
  "model.blocks.5.attn.W_V_cmp.weight",
120
  "model.blocks.5.attn.W_V_sel.weight",
121
  "model.blocks.5.attn.W_V_win.weight",
122
+ "model.blocks.5.attn.gate.fc1.bias",
123
+ "model.blocks.5.attn.gate.fc1.weight",
124
+ "model.blocks.5.attn.gate.fc2.bias",
125
+ "model.blocks.5.attn.gate.fc2.weight",
126
  "model.blocks.5.attn.out.weight",
127
  "model.blocks.5.mlp.fc1.weight",
128
  "model.blocks.5.mlp.fc2.weight",
 
135
  "model.blocks.6.attn.W_V_cmp.weight",
136
  "model.blocks.6.attn.W_V_sel.weight",
137
  "model.blocks.6.attn.W_V_win.weight",
138
+ "model.blocks.6.attn.gate.fc1.bias",
139
+ "model.blocks.6.attn.gate.fc1.weight",
140
+ "model.blocks.6.attn.gate.fc2.bias",
141
+ "model.blocks.6.attn.gate.fc2.weight",
142
  "model.blocks.6.attn.out.weight",
143
  "model.blocks.6.mlp.fc1.weight",
144
  "model.blocks.6.mlp.fc2.weight",
 
151
  "model.blocks.7.attn.W_V_cmp.weight",
152
  "model.blocks.7.attn.W_V_sel.weight",
153
  "model.blocks.7.attn.W_V_win.weight",
154
+ "model.blocks.7.attn.gate.fc1.bias",
155
+ "model.blocks.7.attn.gate.fc1.weight",
156
+ "model.blocks.7.attn.gate.fc2.bias",
157
+ "model.blocks.7.attn.gate.fc2.weight",
158
  "model.blocks.7.attn.out.weight",
159
  "model.blocks.7.mlp.fc1.weight",
160
  "model.blocks.7.mlp.fc2.weight",
 
167
  "model.blocks.8.attn.W_V_cmp.weight",
168
  "model.blocks.8.attn.W_V_sel.weight",
169
  "model.blocks.8.attn.W_V_win.weight",
170
+ "model.blocks.8.attn.gate.fc1.bias",
171
+ "model.blocks.8.attn.gate.fc1.weight",
172
+ "model.blocks.8.attn.gate.fc2.bias",
173
+ "model.blocks.8.attn.gate.fc2.weight",
174
  "model.blocks.8.attn.out.weight",
175
  "model.blocks.8.mlp.fc1.weight",
176
  "model.blocks.8.mlp.fc2.weight",
 
183
  "model.blocks.9.attn.W_V_cmp.weight",
184
  "model.blocks.9.attn.W_V_sel.weight",
185
  "model.blocks.9.attn.W_V_win.weight",
186
+ "model.blocks.9.attn.gate.fc1.bias",
187
+ "model.blocks.9.attn.gate.fc1.weight",
188
+ "model.blocks.9.attn.gate.fc2.bias",
189
+ "model.blocks.9.attn.gate.fc2.weight",
190
  "model.blocks.9.attn.out.weight",
191
  "model.blocks.9.mlp.fc1.weight",
192
  "model.blocks.9.mlp.fc2.weight",
 
196
  "model.lm_head.weight"
197
  ],
198
  "missing": [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  "model.norm.bias",
200
  "model.norm.weight"
201
  ],
202
  "extra": [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  "norm_f.weight"
204
  ]
205
  }
logs/logs_missing_keys.txt CHANGED
@@ -1,26 +1,2 @@
1
- model.blocks.0.attn.g1.weight
2
- model.blocks.0.attn.g2.weight
3
- model.blocks.1.attn.g1.weight
4
- model.blocks.1.attn.g2.weight
5
- model.blocks.10.attn.g1.weight
6
- model.blocks.10.attn.g2.weight
7
- model.blocks.11.attn.g1.weight
8
- model.blocks.11.attn.g2.weight
9
- model.blocks.2.attn.g1.weight
10
- model.blocks.2.attn.g2.weight
11
- model.blocks.3.attn.g1.weight
12
- model.blocks.3.attn.g2.weight
13
- model.blocks.4.attn.g1.weight
14
- model.blocks.4.attn.g2.weight
15
- model.blocks.5.attn.g1.weight
16
- model.blocks.5.attn.g2.weight
17
- model.blocks.6.attn.g1.weight
18
- model.blocks.6.attn.g2.weight
19
- model.blocks.7.attn.g1.weight
20
- model.blocks.7.attn.g2.weight
21
- model.blocks.8.attn.g1.weight
22
- model.blocks.8.attn.g2.weight
23
- model.blocks.9.attn.g1.weight
24
- model.blocks.9.attn.g2.weight
25
  model.norm.bias
26
  model.norm.weight
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  model.norm.bias
2
  model.norm.weight
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:171c4c893cc17e5a5f4ff0ab2f43cc0a7f48e50b47b48b4401b0e21b7b0eacb4
3
- size 320203152
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:21d3ac54cadd49cc11ea0e88d37874aa3a9391e7b47a2704b6557a0e9229640c
3
+ size 313204736
modeling_nsa.py CHANGED
@@ -9,7 +9,12 @@ from transformers.generation.utils import GenerationMixin
9
  from transformers.modeling_outputs import CausalLMOutput
10
 
11
  from .configuration_nsa import NSAConfig
12
- _HAS_NSA = False # Embedded NSA is provided below; no external import required.
 
 
 
 
 
13
 
14
 
15
  class RMSNorm(nn.Module):
@@ -100,9 +105,12 @@ class EmbeddedNSAAttention(nn.Module):
100
  self.W_V_sel = nn.Linear(dim, n_kv_groups * d_v, bias=False)
101
  self.W_K_win = nn.Linear(dim, n_kv_groups * d_k, bias=False)
102
  self.W_V_win = nn.Linear(dim, n_kv_groups * d_v, bias=False)
103
- self.g1 = nn.Linear(dim, max(1, dim // 4), bias=False)
104
- self.g2 = nn.Linear(max(1, dim // 4), 3, bias=False)
105
- nn.init.zeros_(self.g2.weight)
 
 
 
106
  self.out = nn.Linear(n_heads * d_v, dim, bias=False)
107
 
108
  def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -159,11 +167,25 @@ class EmbeddedNSAAttention(nn.Module):
159
  P_w = torch.nn.functional.softmax(logits_w, dim=-1)
160
  O_win = torch.matmul(P_w, Vw)
161
 
162
- # Gate & mix
163
- gate = self.g2(torch.nn.functional.silu(self.g1(x))) # [B,S,3]
164
- gate = torch.nn.functional.softmax(gate, dim=-1)
165
- gc, gs, gw = gate[..., 0:1], gate[..., 1:2], gate[..., 2:3]
166
- O = gc.unsqueeze(1) * O_cmp + gs.unsqueeze(1) * O_sel + gw.unsqueeze(1) * O_win
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  O = O.transpose(1, 2).reshape(B, S, h * dv)
168
  return self.out(O)
169
 
@@ -240,7 +262,24 @@ class NSATinyLM(nn.Module):
240
  import os as _os
241
  # Allow forcing simple fallback via env for integration tests
242
  _force_simple = _os.getenv('NSA_REMOTE_FORCE_SIMPLE', '0').lower() in ('1','true','yes')
243
- if _force_simple == False:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
  self.blocks = nn.ModuleList([
245
  NSABlockRemote(
246
  self.hidden_size,
 
9
  from transformers.modeling_outputs import CausalLMOutput
10
 
11
  from .configuration_nsa import NSAConfig
12
+ _HAS_NSA = False
13
+ try:
14
+ from .nsa.model.llama_block_nsa import LlamaBlockNSA as _VendorNSABlock
15
+ _HAS_NSA = True
16
+ except Exception:
17
+ _VendorNSABlock = None # type: ignore
18
 
19
 
20
  class RMSNorm(nn.Module):
 
105
  self.W_V_sel = nn.Linear(dim, n_kv_groups * d_v, bias=False)
106
  self.W_K_win = nn.Linear(dim, n_kv_groups * d_k, bias=False)
107
  self.W_V_win = nn.Linear(dim, n_kv_groups * d_v, bias=False)
108
+ # Gate MLP operates on per-group pooled Q with width d_k (matches training)
109
+ gate_hidden = max(1, d_k // 2)
110
+ self.gate_fc1 = nn.Linear(d_k, gate_hidden, bias=True)
111
+ self.gate_fc2 = nn.Linear(gate_hidden, 3, bias=True)
112
+ nn.init.xavier_uniform_(self.gate_fc2.weight, gain=0.1)
113
+ nn.init.zeros_(self.gate_fc2.bias)
114
  self.out = nn.Linear(n_heads * d_v, dim, bias=False)
115
 
116
  def forward(self, x: torch.Tensor) -> torch.Tensor:
 
167
  P_w = torch.nn.functional.softmax(logits_w, dim=-1)
168
  O_win = torch.matmul(P_w, Vw)
169
 
170
+ # Gate & mix: compute per-token, per-group gate from pooled Q
171
+ # Pool Q across heads within each kv-group
172
+ # Qr: [B,h,S,dk] -> reshape to [B,G,h_per_group,S,dk] then mean over h_per_group
173
+ G = max(1, self.n_kv_groups)
174
+ h_per_group = max(1, h // G)
175
+ Qg = Qr.view(B, G, h_per_group, S, dk).mean(dim=2) # [B,G,S,dk]
176
+ Qg = Qg.permute(0, 2, 1, 3) # [B,S,G,dk]
177
+ g1 = torch.nn.functional.silu(self.gate_fc1(Qg))
178
+ gate = torch.nn.functional.softmax(self.gate_fc2(g1), dim=-1) # [B,S,G,3]
179
+ gc = gate[..., 0:1].unsqueeze(-1) # [B,S,G,1,1]
180
+ gs = gate[..., 1:2].unsqueeze(-1)
181
+ gw = gate[..., 2:3].unsqueeze(-1)
182
+ # Broadcast group gates to heads within the group
183
+ # Reshape branch outputs to [B,S,G,h_per_group,dv]
184
+ Oc = O_cmp.permute(0,2,1,3).view(B, S, G, h_per_group, dv)
185
+ Os = O_sel.permute(0,2,1,3).view(B, S, G, h_per_group, dv)
186
+ Ow = O_win.permute(0,2,1,3).view(B, S, G, h_per_group, dv)
187
+ O = gc * Oc + gs * Os + gw * Ow
188
+ O = O.reshape(B, S, h, dv).permute(0, 2, 1, 3)
189
  O = O.transpose(1, 2).reshape(B, S, h * dv)
190
  return self.out(O)
191
 
 
262
  import os as _os
263
  # Allow forcing simple fallback via env for integration tests
264
  _force_simple = _os.getenv('NSA_REMOTE_FORCE_SIMPLE', '0').lower() in ('1','true','yes')
265
+ if not _force_simple and _HAS_NSA and _VendorNSABlock is not None:
266
+ # Prefer vendored NSA block to match training semantics and map gate weights
267
+ self.blocks = nn.ModuleList([
268
+ _VendorNSABlock(
269
+ dim=self.hidden_size,
270
+ n_heads=self.num_attention_heads,
271
+ n_kv_groups=self.n_kv_groups,
272
+ d_k=self.d_k,
273
+ d_v=self.d_v,
274
+ l=self.l,
275
+ d=self.d,
276
+ l_sel=self.l_sel,
277
+ n_sel=self.n_sel,
278
+ w=self.w,
279
+ ) for _ in range(self.num_hidden_layers)
280
+ ])
281
+ elif not _force_simple:
282
+ # Fallback to embedded minimal NSA if vendor import failed
283
  self.blocks = nn.ModuleList([
284
  NSABlockRemote(
285
  self.hidden_size,