ukung commited on
Commit
d2ceedd
·
verified ·
1 Parent(s): 82586a0

Upload modeling_tinyv4.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_tinyv4.py +633 -0
modeling_tinyv4.py ADDED
@@ -0,0 +1,633 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Architecture: TinyV4 (ManifoldHC + CSA/HCA attention + DeepSeekMoE + PartialRoPE + MTP)
4
+ HF-compatible: supports trust_remote_code via PretrainedConfig + from_pretrained/save_pretrained.
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from transformers import PretrainedConfig, PreTrainedModel
11
+ from transformers import AutoTokenizer
12
+ from safetensors.torch import load_file as safe_load, save_file as safe_save
13
+ import time
14
+ import math
15
+ import json
16
+ import os
17
+
18
+ # ---- RMSNorm fallback for older PyTorch / CUDA ----
19
+ if hasattr(nn, 'RMSNorm'):
20
+ RMSNorm = nn.RMSNorm
21
+ else:
22
+ class RMSNorm(nn.Module):
23
+ """Manual RMSNorm — works on any device, any PyTorch version."""
24
+ def __init__(self, dim, eps=1e-6):
25
+ super().__init__()
26
+ self.eps = eps
27
+ self.weight = nn.Parameter(torch.ones(dim))
28
+ def forward(self, x):
29
+ norm = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
30
+ return (x.float() * norm).type_as(x) * self.weight
31
+
32
+ # ============================================================
33
+ # TinyV4 Architecture
34
+ # ============================================================
35
+
36
+ class TinyV4Config(PretrainedConfig):
37
+ model_type = "tinyv4"
38
+
39
+ def __init__(
40
+ self,
41
+ vocab_size: int = 1000,
42
+ dim: int = 384,
43
+ depth: int = 8,
44
+ n_hc: int = 2,
45
+ n_routed: int = 8,
46
+ n_active: int = 2,
47
+ n_shared: int = 1,
48
+ expert_intermediate: int = 512,
49
+ csa_m: int = 4,
50
+ csa_topk: int = 32,
51
+ hca_m: int = 16,
52
+ n_win: int = 32,
53
+ n_q_head: int = 8,
54
+ head_dim: int = 64,
55
+ d_c: int = 192,
56
+ n_idx_head: int = 8,
57
+ idx_head_dim: int = 64,
58
+ n_out_group: int = 2,
59
+ d_g: int = 128,
60
+ rope_dim: int = 32,
61
+ mtp_depth: int = 1,
62
+ hash_layers: int = 3,
63
+ max_len: int = 1024,
64
+ sinkhorn_iters: int = 20,
65
+ aux_bias_update: float = 0.001,
66
+ bal_loss_weight: float = 0.0001,
67
+ **kwargs
68
+ ):
69
+ super().__init__(**kwargs)
70
+ self.vocab_size = vocab_size
71
+ self.dim = dim
72
+ self.depth = depth
73
+ self.n_hc = n_hc
74
+ self.n_routed = n_routed
75
+ self.n_active = n_active
76
+ self.n_shared = n_shared
77
+ self.expert_intermediate = expert_intermediate
78
+ self.csa_m = csa_m
79
+ self.csa_topk = csa_topk
80
+ self.hca_m = hca_m
81
+ self.n_win = n_win
82
+ self.n_q_head = n_q_head
83
+ self.head_dim = head_dim
84
+ self.d_c = d_c
85
+ self.n_idx_head = n_idx_head
86
+ self.idx_head_dim = idx_head_dim
87
+ self.n_out_group = n_out_group
88
+ self.d_g = d_g
89
+ self.rope_dim = rope_dim
90
+ self.mtp_depth = mtp_depth
91
+ self.hash_layers = hash_layers
92
+ self.max_len = max_len
93
+ self.sinkhorn_iters = sinkhorn_iters
94
+ self.aux_bias_update = aux_bias_update
95
+ self.bal_loss_weight = bal_loss_weight
96
+
97
+
98
+ def sinkhorn_knopp(B_raw, n_iters=20):
99
+ M = torch.exp(B_raw)
100
+ for _ in range(n_iters):
101
+ M = M / M.sum(dim=-1, keepdim=True).clamp(min=1e-12)
102
+ M = M / M.sum(dim=-2, keepdim=True).clamp(min=1e-12)
103
+ return M
104
+
105
+
106
+ class ManifoldHC(nn.Module):
107
+ def __init__(self, dim, n_hc, n_iters=20):
108
+ super().__init__()
109
+ self.dim = dim; self.n_hc = n_hc; self.n_iters = n_iters
110
+ flat_dim = n_hc * dim
111
+ self.W_pre = nn.Linear(flat_dim, n_hc, bias=False)
112
+ self.W_post = nn.Linear(flat_dim, n_hc, bias=False)
113
+ self.W_res = nn.Linear(flat_dim, n_hc * n_hc, bias=False)
114
+ self.S_pre = nn.Parameter(torch.zeros(1, n_hc))
115
+ self.S_post = nn.Parameter(torch.zeros(1, n_hc))
116
+ self.S_res = nn.Parameter(torch.zeros(1, n_hc * n_hc))
117
+ self.alpha_pre = nn.Parameter(torch.tensor(0.1))
118
+ self.alpha_res = nn.Parameter(torch.tensor(0.1))
119
+ self.alpha_post = nn.Parameter(torch.tensor(0.1))
120
+
121
+ def forward(self, X, sublayer):
122
+ B, T, n_hc, d = X.shape
123
+ flat_dim = n_hc * d
124
+ X_flat = X.reshape(B * T, flat_dim)
125
+ X_norm = F.rms_norm(X_flat, (flat_dim,))
126
+ A_raw = self.alpha_pre * self.W_pre(X_norm) + self.S_pre
127
+ C_raw = self.alpha_post * self.W_post(X_norm) + self.S_post
128
+ B_raw = self.alpha_res * self.W_res(X_norm) + self.S_res
129
+ A = torch.sigmoid(A_raw)
130
+ C = 2.0 * torch.sigmoid(C_raw)
131
+ B_mat = B_raw.reshape(B * T, n_hc, n_hc)
132
+ B_mat = sinkhorn_knopp(B_mat, self.n_iters)
133
+ sublayer_input = torch.einsum('bn,bnd->bd', A, X_flat.reshape(B * T, n_hc, d))
134
+ sublayer_input = sublayer_input.reshape(B, T, d)
135
+ sublayer_output = sublayer(sublayer_input)
136
+ sublayer_output = sublayer_output.reshape(B * T, d)
137
+ residual = torch.bmm(B_mat, X_flat.reshape(B * T, n_hc, d))
138
+ injection = C.unsqueeze(-1) * sublayer_output.unsqueeze(1)
139
+ X_new = residual + injection
140
+ return X_new.reshape(B, T, n_hc, d)
141
+
142
+
143
+ class PartialRoPE(nn.Module):
144
+ def __init__(self, dim, rope_dim, max_len=2048):
145
+ super().__init__()
146
+ self.dim = dim; self.rope_dim = rope_dim; self.max_len = max_len
147
+ theta = 10000.0 ** (-2.0 * torch.arange(0, rope_dim, 2) / rope_dim)
148
+ pos = torch.arange(max_len)
149
+ freqs = torch.outer(pos, theta)
150
+ self.register_buffer('cos', freqs.cos())
151
+ self.register_buffer('sin', freqs.sin())
152
+
153
+ def _rotate(self, x, positions):
154
+ B, H, D = x.shape; r = self.rope_dim
155
+ x_rope = x[..., -r:]; x_pass = x[..., :-r]
156
+ x_rope = x_rope.reshape(B, H, r // 2, 2)
157
+ x1, x2 = x_rope[..., 0], x_rope[..., 1]
158
+ cos = self.cos[positions][:, None, :]; sin = self.sin[positions][:, None, :]
159
+ y1 = x1 * cos - x2 * sin; y2 = x1 * sin + x2 * cos
160
+ y_rope = torch.stack([y1, y2], dim=-1).reshape(B, H, r)
161
+ return torch.cat([x_pass, y_rope], dim=-1)
162
+
163
+ def forward(self, q, k, q_pos=None, k_pos=None):
164
+ if q_pos is None: q_pos = torch.arange(q.shape[0], device=q.device)
165
+ if k_pos is None: k_pos = torch.arange(k.shape[0], device=k.device)
166
+ return self._rotate(q, q_pos), self._rotate(k, k_pos)
167
+
168
+ def inverse(self, x, positions=None):
169
+ if positions is None: positions = torch.arange(x.shape[0], device=x.device)
170
+ B, H, D = x.shape; r = self.rope_dim
171
+ x_rope = x[..., -r:]; x_pass = x[..., :-r]
172
+ x_rope = x_rope.reshape(B, H, r // 2, 2)
173
+ x1, x2 = x_rope[..., 0], x_rope[..., 1]
174
+ cos = self.cos[positions][:, None, :]; sin = self.sin[positions][:, None, :]
175
+ y1 = x1 * cos + x2 * sin; y2 = -x1 * sin + x2 * cos
176
+ y_rope = torch.stack([y1, y2], dim=-1).reshape(B, H, r)
177
+ return torch.cat([x_pass, y_rope], dim=-1)
178
+
179
+
180
+ def compress_kv(C, Z, B_pos, m):
181
+ B, T, c = C.shape
182
+ pad_len = (m - (T % m)) % m
183
+ if pad_len > 0:
184
+ C = F.pad(C, (0, 0, 0, pad_len)); Z = F.pad(Z, (0, 0, 0, pad_len))
185
+ T_pad = T + pad_len; T_comp = T_pad // m
186
+ C_blocks = C.reshape(B, T_comp, m, c); Z_blocks = Z.reshape(B, T_comp, m, c)
187
+ scores = Z_blocks + B_pos[None, None, :, :]
188
+ weights = torch.softmax(scores, dim=2)
189
+ return (weights * C_blocks).sum(dim=2)
190
+
191
+
192
+ def compress_kv_csa(C_a, C_b, Z_a, Z_b, B_a, B_b, m):
193
+ B, T, c = C_a.shape
194
+ pad_len = (m - (T % m)) % m
195
+ if pad_len > 0:
196
+ C_a = F.pad(C_a, (0, 0, 0, pad_len)); C_b = F.pad(C_b, (0, 0, 0, pad_len))
197
+ Z_a = F.pad(Z_a, (0, 0, 0, pad_len)); Z_b = F.pad(Z_b, (0, 0, 0, pad_len))
198
+ T_pad = T + pad_len; T_comp = T_pad // m
199
+ C_a_blocks = C_a.reshape(B, T_comp, m, c); C_b_blocks = C_b.reshape(B, T_comp, m, c)
200
+ Z_a_blocks = Z_a.reshape(B, T_comp, m, c); Z_b_blocks = Z_b.reshape(B, T_comp, m, c)
201
+ C_b_shifted = torch.cat([torch.zeros(B, 1, m, c, device=C_b.device), C_b_blocks[:, :-1]], dim=1)
202
+ Z_b_shifted = torch.cat([torch.full((B, 1, m, c), float('-inf'), device=Z_b.device), Z_b_blocks[:, :-1]], dim=1)
203
+ C_cat = torch.cat([C_a_blocks, C_b_shifted], dim=2)
204
+ Z_cat = torch.cat([Z_a_blocks, Z_b_shifted], dim=2)
205
+ B_cat = torch.cat([B_a, B_b], dim=0)
206
+ scores = Z_cat + B_cat[None, None, :, :]
207
+ weights = torch.softmax(scores, dim=2)
208
+ return (weights * C_cat).sum(dim=2)
209
+
210
+
211
+ class CSA(nn.Module):
212
+ def __init__(self, config):
213
+ super().__init__()
214
+ d = config.dim; c = config.head_dim; n_h = config.n_q_head
215
+ n_h_I = config.n_idx_head; c_I = config.idx_head_dim; d_c = config.d_c
216
+ m = config.csa_m; topk = config.csa_topk; n_win = config.n_win
217
+ g = config.n_out_group; d_g = config.d_g
218
+ self.d, self.c, self.n_h, self.n_h_I, self.c_I = d, c, n_h, n_h_I, c_I
219
+ self.d_c, self.m, self.topk, self.n_win, self.g, self.d_g = d_c, m, topk, n_win, g, d_g
220
+ self.W_aKV = nn.Linear(d, c, bias=False); self.W_bKV = nn.Linear(d, c, bias=False)
221
+ self.W_aZ = nn.Linear(d, c, bias=False); self.W_bZ = nn.Linear(d, c, bias=False)
222
+ self.B_a = nn.Parameter(torch.zeros(m, c)); self.B_b = nn.Parameter(torch.zeros(m, c))
223
+ self.W_idxKV = nn.Linear(d, c_I, bias=False); self.W_idxZ = nn.Linear(d, c_I, bias=False)
224
+ self.B_idx = nn.Parameter(torch.zeros(m, c_I))
225
+ self.W_DQ = nn.Linear(d, d_c, bias=False)
226
+ self.W_IUQ = nn.Linear(d_c, c_I * n_h_I, bias=False)
227
+ self.W_UQ = nn.Linear(d_c, c * n_h, bias=False)
228
+ self.W_w = nn.Linear(d, n_h_I, bias=False)
229
+ self.W_swKV = nn.Linear(d, c, bias=False)
230
+ assert n_h % g == 0
231
+ hpg = n_h // g; god = hpg * c
232
+ self.group_proj = nn.ModuleList([nn.Linear(god, d_g, bias=False) for _ in range(g)])
233
+ self.out_proj = nn.Linear(d_g * g, d, bias=False)
234
+ self.sink_logits = nn.Parameter(torch.zeros(n_h))
235
+ self.rope = PartialRoPE(c, config.rope_dim, config.max_len)
236
+ self.q_norm = RMSNorm(c); self.kv_norm = RMSNorm(c)
237
+
238
+ def forward(self, x):
239
+ B, T, d = x.shape; device = x.device
240
+ m, c, n_h, n_h_I, c_I, topk, n_win = self.m, self.c, self.n_h, self.n_h_I, self.c_I, self.topk, self.n_win
241
+ C_a = self.W_aKV(x); C_b = self.W_bKV(x); Z_a = self.W_aZ(x); Z_b = self.W_bZ(x)
242
+ KV_comp = compress_kv_csa(C_a, C_b, Z_a, Z_b, self.B_a, self.B_b, m)
243
+ T_comp = KV_comp.shape[1]
244
+ C_idx = self.W_idxKV(x); Z_idx = self.W_idxZ(x)
245
+ K_idx_comp = compress_kv(C_idx, Z_idx, self.B_idx, m)
246
+ c_Q = self.W_DQ(x)
247
+ q_I = self.W_IUQ(c_Q).reshape(B, T, n_h_I, c_I)
248
+ q = self.W_UQ(c_Q).reshape(B, T, n_h, c)
249
+ w_I = self.W_w(x)
250
+ idx_scores = torch.einsum('bthc,bsc->bths', q_I, K_idx_comp)
251
+ idx_scores = torch.einsum('bth,bths->bts', F.relu(w_I), F.relu(idx_scores))
252
+ query_block = torch.arange(T, device=device) // m
253
+ causal_mask = query_block[:, None] > torch.arange(T_comp, device=device)[None, :]
254
+ idx_scores = idx_scores.masked_fill(~causal_mask, float('-inf'))
255
+ SW_KV = self.W_swKV(x)
256
+ SW_KV_padded = F.pad(SW_KV, (0, 0, n_win, 0))
257
+ win_indices = torch.arange(n_win, device=device)[None, None, :]
258
+ query_pos = torch.arange(T, device=device)[None, :, None]
259
+ gather_idx = (query_pos + win_indices).clamp(0, T + n_win - 1).expand(B, -1, -1)
260
+ SW_gathered = SW_KV_padded[torch.arange(B, device=device)[:, None, None], gather_idx]
261
+ KV_all = torch.cat([KV_comp.unsqueeze(1).expand(-1, T, -1, -1), SW_gathered], dim=2)
262
+ n_kv = T_comp + n_win
263
+ q = self.q_norm(q.reshape(B * T * n_h, c)).reshape(B, T, n_h, c)
264
+ KV_all = self.kv_norm(KV_all.reshape(B * T * n_kv, c)).reshape(B, T, n_kv, c)
265
+ q_pos = torch.arange(T, device=device).repeat(B)
266
+ comp_positions = (torch.arange(T_comp, device=device) * m + m // 2)
267
+ sw_positions = torch.arange(T, device=device)[:, None] - torch.arange(n_win, device=device)[None, :]
268
+ sw_positions = sw_positions.clamp(min=0)
269
+ kv_positions = torch.cat([comp_positions.unsqueeze(0).expand(T, -1), sw_positions], dim=1)
270
+ kv_pos_flat = kv_positions.reshape(-1).repeat(B)
271
+ q_flat = q.reshape(B * T, n_h, c)
272
+ q_flat = self.rope._rotate(q_flat, q_pos)
273
+ q = q_flat.reshape(B, T, n_h, c)
274
+ kv_flat = KV_all.reshape(B * T * n_kv, 1, c)
275
+ kv_flat = self.rope._rotate(kv_flat, kv_pos_flat)
276
+ KV_all = kv_flat.reshape(B, T, n_kv, c)
277
+ KV_expanded = KV_all.unsqueeze(2).expand(-1, -1, n_h, -1, -1)
278
+ scale = c ** -0.5
279
+ attn_logits = torch.einsum('bthc,bthkc->bthk', q, KV_expanded) * scale
280
+ idx_bias = F.pad(idx_scores, (0, n_win), value=0.0)
281
+ attn_logits = attn_logits + idx_bias[:, :, None, :]
282
+ causal_mask_comp = query_block[:, None] > torch.arange(T_comp, device=device)[None, :]
283
+ causal_mask_all = torch.cat([causal_mask_comp, torch.ones(T, n_win, dtype=torch.bool, device=device)], dim=1)
284
+ attn_logits = attn_logits.masked_fill(~causal_mask_all[None, :, None, :], float('-inf'))
285
+ sink = self.sink_logits[None, None, :, None]
286
+ attn_logits_with_sink = torch.cat([attn_logits, sink.expand(B, T, -1, -1)], dim=-1)
287
+ attn_weights = torch.softmax(attn_logits_with_sink, dim=-1)[..., :n_kv]
288
+ o = torch.einsum('bthk,bthkc->bthc', attn_weights, KV_expanded)
289
+ o_flat = o.reshape(B * T, n_h, c)
290
+ o_pos = torch.arange(T, device=device).repeat(B)
291
+ o_flat = self.rope.inverse(o_flat, o_pos)
292
+ o = o_flat.reshape(B, T, n_h, c)
293
+ hpg = n_h // self.g
294
+ o_groups = o.chunk(self.g, dim=2)
295
+ intermediates = []
296
+ for proj, og in zip(self.group_proj, o_groups):
297
+ intermediates.append(proj(og.reshape(B, T, hpg * c)))
298
+ return self.out_proj(torch.cat(intermediates, dim=-1))
299
+
300
+
301
+ class HCA(nn.Module):
302
+ def __init__(self, config):
303
+ super().__init__()
304
+ d = config.dim; c = config.head_dim; n_h = config.n_q_head
305
+ d_c = config.d_c; m = config.hca_m; n_win = config.n_win
306
+ g = config.n_out_group; d_g = config.d_g
307
+ self.d, self.c, self.n_h, self.d_c, self.m, self.n_win, self.g, self.d_g = d, c, n_h, d_c, m, n_win, g, d_g
308
+ self.W_KV = nn.Linear(d, c, bias=False); self.W_Z = nn.Linear(d, c, bias=False)
309
+ self.B_pos = nn.Parameter(torch.zeros(m, c))
310
+ self.W_DQ = nn.Linear(d, d_c, bias=False)
311
+ self.W_UQ = nn.Linear(d_c, c * n_h, bias=False)
312
+ self.W_swKV = nn.Linear(d, c, bias=False)
313
+ assert n_h % g == 0
314
+ hpg = n_h // g; god = hpg * c
315
+ self.group_proj = nn.ModuleList([nn.Linear(god, d_g, bias=False) for _ in range(g)])
316
+ self.out_proj = nn.Linear(d_g * g, d, bias=False)
317
+ self.sink_logits = nn.Parameter(torch.zeros(n_h))
318
+ self.rope = PartialRoPE(c, config.rope_dim, config.max_len)
319
+ self.q_norm = RMSNorm(c); self.kv_norm = RMSNorm(c)
320
+
321
+ def forward(self, x):
322
+ B, T, d = x.shape; device = x.device
323
+ m, c, n_h, n_win = self.m, self.c, self.n_h, self.n_win
324
+ C = self.W_KV(x); Z = self.W_Z(x)
325
+ KV_comp = compress_kv(C, Z, self.B_pos, m)
326
+ T_comp = KV_comp.shape[1]
327
+ c_Q = self.W_DQ(x)
328
+ q = self.W_UQ(c_Q).reshape(B, T, n_h, c)
329
+ SW_KV = self.W_swKV(x)
330
+ SW_KV_padded = F.pad(SW_KV, (0, 0, n_win, 0))
331
+ win_indices = torch.arange(n_win, device=device)[None, None, :]
332
+ query_pos = torch.arange(T, device=device)[None, :, None]
333
+ gather_idx = (query_pos + win_indices).clamp(0, T + n_win - 1).expand(B, -1, -1)
334
+ SW_gathered = SW_KV_padded[torch.arange(B, device=device)[:, None, None], gather_idx]
335
+ KV_all = torch.cat([KV_comp.unsqueeze(1).expand(-1, T, -1, -1), SW_gathered], dim=2)
336
+ n_kv = T_comp + n_win
337
+ q = self.q_norm(q.reshape(B * T * n_h, c)).reshape(B, T, n_h, c)
338
+ KV_all = self.kv_norm(KV_all.reshape(B * T * n_kv, c)).reshape(B, T, n_kv, c)
339
+ q_pos = torch.arange(T, device=device).repeat(B)
340
+ comp_positions = (torch.arange(T_comp, device=device) * m + m // 2)
341
+ sw_positions = torch.arange(T, device=device)[:, None] - torch.arange(n_win, device=device)[None, :]
342
+ sw_positions = sw_positions.clamp(min=0)
343
+ kv_positions = torch.cat([comp_positions.unsqueeze(0).expand(T, -1), sw_positions], dim=1)
344
+ kv_pos_flat = kv_positions.reshape(-1).repeat(B)
345
+ q_flat = q.reshape(B * T, n_h, c)
346
+ q_flat = self.rope._rotate(q_flat, q_pos)
347
+ q = q_flat.reshape(B, T, n_h, c)
348
+ kv_flat = KV_all.reshape(B * T * n_kv, 1, c)
349
+ kv_flat = self.rope._rotate(kv_flat, kv_pos_flat)
350
+ KV_all = kv_flat.reshape(B, T, n_kv, c)
351
+ KV_expanded = KV_all.unsqueeze(2).expand(-1, -1, n_h, -1, -1)
352
+ scale = c ** -0.5
353
+ attn_logits = torch.einsum('bthc,bthkc->bthk', q, KV_expanded) * scale
354
+ query_block = torch.arange(T, device=device) // m
355
+ causal_mask = (query_block[:, None] > torch.arange(T_comp, device=device)[None, :])
356
+ causal_mask = torch.cat([causal_mask, torch.ones(T, n_win, dtype=torch.bool, device=device)], dim=1)
357
+ attn_logits = attn_logits.masked_fill(~causal_mask[None, :, None, :], float('-inf'))
358
+ sink = self.sink_logits[None, None, :, None]
359
+ attn_logits_with_sink = torch.cat([attn_logits, sink.expand(B, T, -1, -1)], dim=-1)
360
+ attn_weights = torch.softmax(attn_logits_with_sink, dim=-1)[..., :n_kv]
361
+ o = torch.einsum('bthk,bthkc->bthc', attn_weights, KV_expanded)
362
+ o_flat = o.reshape(B * T, n_h, c)
363
+ o_pos = torch.arange(T, device=device).repeat(B)
364
+ o_flat = self.rope.inverse(o_flat, o_pos)
365
+ o = o_flat.reshape(B, T, n_h, c)
366
+ hpg = n_h // self.g
367
+ o_groups = o.chunk(self.g, dim=2)
368
+ intermediates = []
369
+ for proj, og in zip(self.group_proj, o_groups):
370
+ intermediates.append(proj(og.reshape(B, T, hpg * c)))
371
+ return self.out_proj(torch.cat(intermediates, dim=-1))
372
+
373
+
374
+ class Expert(nn.Module):
375
+ def __init__(self, dim, intermediate):
376
+ super().__init__()
377
+ self.gate_proj = nn.Linear(dim, intermediate, bias=False)
378
+ self.up_proj = nn.Linear(dim, intermediate, bias=False)
379
+ self.down_proj = nn.Linear(intermediate, dim, bias=False)
380
+ def forward(self, x):
381
+ gate = torch.clamp(self.gate_proj(x), max=10.0)
382
+ up = torch.clamp(self.up_proj(x), min=-10.0, max=10.0)
383
+ return self.down_proj(F.silu(gate) * up)
384
+
385
+
386
+ class DeepSeekMoE(nn.Module):
387
+ def __init__(self, config, layer_idx):
388
+ super().__init__()
389
+ d = config.dim
390
+ self.use_hash = layer_idx < config.hash_layers
391
+ self.d, self.n_routed, self.n_active = d, config.n_routed, config.n_active
392
+ self.shared_experts = nn.ModuleList([Expert(d, config.expert_intermediate) for _ in range(config.n_shared)])
393
+ self.routed_experts = nn.ModuleList([Expert(d, config.expert_intermediate) for _ in range(config.n_routed)])
394
+ self.gate = nn.Linear(d, config.n_routed, bias=False)
395
+ self.register_buffer('e_bias', torch.zeros(config.n_routed))
396
+ self.register_buffer('expert_counts', torch.zeros(config.n_routed))
397
+
398
+ def forward(self, x):
399
+ B, T, d = x.shape; device = x.device
400
+ shared_out = sum(expert(x) for expert in self.shared_experts)
401
+ if self.use_hash:
402
+ pos = torch.arange(T, device=device)
403
+ expert_idx = pos % self.n_routed
404
+ routed_out = torch.zeros(B, T, d, device=device)
405
+ for e_idx in range(self.n_routed):
406
+ mask = (expert_idx == e_idx).float()
407
+ if mask.sum() > 0:
408
+ routed_out = routed_out + self.routed_experts[e_idx](x * mask[None, :, None]) * mask[None, :, None]
409
+ return shared_out + routed_out, torch.tensor(0.0, device=device)
410
+ gate_out = self.gate(x)
411
+ affinity = torch.sqrt(F.softplus(gate_out)) + self.e_bias
412
+ topk_weights, topk_indices = torch.topk(affinity, self.n_active, dim=-1)
413
+ topk_weights = F.softmax(topk_weights, dim=-1)
414
+ with torch.no_grad():
415
+ counts = torch.zeros(self.n_routed, device=device)
416
+ for k in range(self.n_active):
417
+ counts.scatter_add_(0, topk_indices[..., k].reshape(-1), torch.ones(B * T, device=device))
418
+ self.expert_counts = counts.detach()
419
+ routed_out = torch.zeros(B, T, d, device=device)
420
+ for e_idx in range(self.n_routed):
421
+ mask = (topk_indices == e_idx).any(dim=-1)
422
+ if mask.any():
423
+ weight_mask = (topk_indices == e_idx).float()
424
+ weights = (topk_weights * weight_mask).sum(dim=-1)
425
+ routed_out[mask] = routed_out[mask] + self.routed_experts[e_idx](x[mask]) * weights[mask, None]
426
+ frac = counts / (B * T * self.n_active)
427
+ bal_loss = torch.dot(frac, self.e_bias)
428
+ return shared_out + routed_out, bal_loss
429
+
430
+ def update_bias(self):
431
+ if not self.use_hash:
432
+ with torch.no_grad():
433
+ n_total = self.expert_counts.sum()
434
+ if n_total > 0:
435
+ target = n_total / self.n_routed
436
+ self.e_bias -= 0.001 * (self.expert_counts - target) / max(target, 1)
437
+
438
+
439
+ class TransformerBlock(nn.Module):
440
+ def __init__(self, config, layer_idx):
441
+ super().__init__()
442
+ d = config.dim; n_hc = config.n_hc
443
+ if layer_idx < 2: self.attn = HCA(config)
444
+ elif layer_idx % 2 == 0: self.attn = CSA(config)
445
+ else: self.attn = HCA(config)
446
+ self.mhc_attn = ManifoldHC(d, n_hc, config.sinkhorn_iters)
447
+ self.mhc_ffn = ManifoldHC(d, n_hc, config.sinkhorn_iters)
448
+ self.moe = DeepSeekMoE(config, layer_idx)
449
+
450
+ def forward(self, X):
451
+ X = self.mhc_attn(X, self.attn)
452
+ bl = [torch.tensor(0.0, device=X.device)]
453
+ def moe_fn(x):
454
+ out, b = self.moe(x); bl[0] = b; return out
455
+ X = self.mhc_ffn(X, moe_fn)
456
+ return X, bl[0]
457
+
458
+
459
+ class MTPModule(nn.Module):
460
+ def __init__(self, config):
461
+ super().__init__()
462
+ d = config.dim; n_hc = config.n_hc
463
+ self.proj_in = nn.Linear(d, d, bias=False)
464
+ self.mhc = ManifoldHC(d, n_hc, config.sinkhorn_iters)
465
+ self.attn = HCA(config)
466
+ self.norm = nn.LayerNorm(d)
467
+ self.head = nn.Linear(d, config.vocab_size, bias=False)
468
+
469
+ def forward(self, h, X):
470
+ h_proj = self.proj_in(h)
471
+ X = self.mhc(X, lambda x: self.attn(x))
472
+ return self.head(self.norm(X[:, :, 0, :] + h_proj))
473
+
474
+
475
+ class TinyV4(PreTrainedModel):
476
+ config_class = TinyV4Config
477
+ base_model_prefix = "tinyv4"
478
+ supports_gradient_checkpointing = False
479
+
480
+ def __init__(self, config):
481
+ super().__init__(config)
482
+ d = config.dim; n_hc = config.n_hc
483
+ self.embed = nn.Embedding(config.vocab_size, d)
484
+ self.expand = nn.Linear(d, n_hc * d, bias=False)
485
+ self.blocks = nn.ModuleList([TransformerBlock(config, i) for i in range(config.depth)])
486
+ self.norm = nn.LayerNorm(d)
487
+ self.head = nn.Linear(d, config.vocab_size, bias=False)
488
+ self.mtp = MTPModule(config) if config.mtp_depth > 0 else None
489
+ self.post_init()
490
+
491
+ def forward(self, input_ids):
492
+ B, T = input_ids.shape; d = self.config.dim; n_hc = self.config.n_hc; device = input_ids.device
493
+ x = self.embed(input_ids)
494
+ X = self.expand(x).reshape(B, T, n_hc, d)
495
+ total_bal_loss = torch.tensor(0.0, device=device)
496
+ for block in self.blocks:
497
+ X, bl = block(X); total_bal_loss = total_bal_loss + bl
498
+ h = X[:, :, 0, :]
499
+ logits = self.head(self.norm(h))
500
+ mtp_logits = self.mtp(h, X) if self.mtp else None
501
+ return logits, mtp_logits, total_bal_loss
502
+
503
+ def param_count(self):
504
+ return sum(p.numel() for p in self.parameters())
505
+
506
+ @classmethod
507
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
508
+ """Load TinyV4 from a directory containing model.safetensors + config.json."""
509
+ model_path = pretrained_model_name_or_path
510
+
511
+ # Load config manually (PretrainedConfig.from_pretrained sometimes misses custom fields)
512
+ config_file = os.path.join(model_path, "config.json")
513
+ if not os.path.exists(config_file):
514
+ raise FileNotFoundError(f"config.json not found in {model_path}")
515
+ with open(config_file, "r") as f:
516
+ config_dict = json.load(f)
517
+ config = TinyV4Config(**config_dict)
518
+
519
+ # Create model with config
520
+ model = cls(config)
521
+
522
+ # Load weights
523
+ weights_file = os.path.join(model_path, "model.safetensors")
524
+ if not os.path.exists(weights_file):
525
+ raise FileNotFoundError(f"model.safetensors not found in {model_path}")
526
+
527
+ state_dict = safe_load(weights_file)
528
+ model.load_state_dict(state_dict, strict=False)
529
+
530
+ return model
531
+
532
+ def save_pretrained(self, save_directory, **kwargs):
533
+ """Save TinyV4 config + weights to a directory."""
534
+ os.makedirs(save_directory, exist_ok=True)
535
+
536
+ # Save config
537
+ self.config.save_pretrained(save_directory)
538
+
539
+ # Save weights
540
+ safe_save(self.state_dict(), os.path.join(save_directory, "model.safetensors"))
541
+
542
+
543
+ # ============================================================
544
+ # Auto-search for ~10M config
545
+ # ============================================================
546
+ def search_best_config(target_params=10_000_000, vocab_size=32000):
547
+ """Search for config that gives closest to target_params."""
548
+ best_config = None
549
+ best_diff = float('inf')
550
+
551
+ configs = [
552
+ # With vocab=32000 + tie_embeddings, embedding is only ~2M
553
+ # So we can afford bigger transformer blocks
554
+ TinyV4Config(vocab_size=vocab_size, dim=128, depth=6, n_hc=2, n_routed=4, n_active=2,
555
+ n_shared=1, expert_intermediate=192, csa_m=4, csa_topk=16, hca_m=8,
556
+ n_win=16, n_q_head=4, head_dim=48, d_c=64, n_idx_head=4,
557
+ idx_head_dim=48, n_out_group=2, d_g=64, rope_dim=24, mtp_depth=0,
558
+ hash_layers=2, max_len=512),
559
+ TinyV4Config(vocab_size=vocab_size, dim=128, depth=8, n_hc=2, n_routed=4, n_active=2,
560
+ n_shared=1, expert_intermediate=192, csa_m=4, csa_topk=16, hca_m=8,
561
+ n_win=16, n_q_head=4, head_dim=48, d_c=64, n_idx_head=4,
562
+ idx_head_dim=48, n_out_group=2, d_g=64, rope_dim=24, mtp_depth=0,
563
+ hash_layers=3, max_len=512),
564
+ TinyV4Config(vocab_size=vocab_size, dim=160, depth=4, n_hc=2, n_routed=4, n_active=2,
565
+ n_shared=1, expert_intermediate=256, csa_m=4, csa_topk=16, hca_m=8,
566
+ n_win=16, n_q_head=4, head_dim=48, d_c=64, n_idx_head=4,
567
+ idx_head_dim=48, n_out_group=2, d_g=80, rope_dim=24, mtp_depth=0,
568
+ hash_layers=2, max_len=512),
569
+ TinyV4Config(vocab_size=vocab_size, dim=128, depth=6, n_hc=2, n_routed=6, n_active=2,
570
+ n_shared=1, expert_intermediate=192, csa_m=4, csa_topk=16, hca_m=8,
571
+ n_win=16, n_q_head=4, head_dim=48, d_c=64, n_idx_head=4,
572
+ idx_head_dim=48, n_out_group=2, d_g=64, rope_dim=24, mtp_depth=0,
573
+ hash_layers=2, max_len=512),
574
+ TinyV4Config(vocab_size=vocab_size, dim=96, depth=8, n_hc=2, n_routed=4, n_active=2,
575
+ n_shared=1, expert_intermediate=128, csa_m=4, csa_topk=16, hca_m=8,
576
+ n_win=16, n_q_head=4, head_dim=48, d_c=48, n_idx_head=4,
577
+ idx_head_dim=48, n_out_group=2, d_g=64, rope_dim=24, mtp_depth=0,
578
+ hash_layers=3, max_len=512),
579
+ TinyV4Config(vocab_size=vocab_size, dim=128, depth=6, n_hc=2, n_routed=4, n_active=2,
580
+ n_shared=1, expert_intermediate=256, csa_m=4, csa_topk=16, hca_m=8,
581
+ n_win=16, n_q_head=4, head_dim=48, d_c=64, n_idx_head=4,
582
+ idx_head_dim=48, n_out_group=2, d_g=64, rope_dim=24, mtp_depth=0,
583
+ hash_layers=2, max_len=512),
584
+ # With MTP
585
+ TinyV4Config(vocab_size=vocab_size, dim=128, depth=6, n_hc=2, n_routed=4, n_active=2,
586
+ n_shared=1, expert_intermediate=192, csa_m=4, csa_topk=16, hca_m=8,
587
+ n_win=16, n_q_head=4, head_dim=48, d_c=64, n_idx_head=4,
588
+ idx_head_dim=48, n_out_group=2, d_g=64, rope_dim=24, mtp_depth=1,
589
+ hash_layers=2, max_len=512),
590
+ ]
591
+
592
+ print(f"\n{'='*70}")
593
+ print(f"Searching for config closest to {target_params/1e6:.1f}M params (vocab={vocab_size})")
594
+ print(f"Note: tie_embeddings=True — embed & head share weights")
595
+ print(f"{'='*70}")
596
+
597
+ for cfg in configs:
598
+ model = TinyV4(cfg)
599
+ # Tie embeddings: head.weight = embed.weight
600
+ model.head.weight = model.embed.weight
601
+ n = model.param_count()
602
+ diff = abs(n - target_params)
603
+ pct = (n - target_params) / target_params * 100
604
+ print(f" dim={cfg.dim:3d} depth={cfg.depth} n_routed={cfg.n_routed} expert_int={cfg.expert_intermediate:3d} "
605
+ f"mtp={cfg.mtp_depth} → {n/1e6:.2f}M params ({pct:+.1f}%)")
606
+ if diff < best_diff:
607
+ best_diff = diff
608
+ best_config = cfg
609
+ del model
610
+
611
+ print(f"\n✅ Best config: {best_config.dim}d {best_config.depth}L → "
612
+ f"{TinyV4(best_config).param_count()/1e6:.2f}M params (with tie_embeddings)")
613
+ return best_config
614
+
615
+
616
+ # ============================================================
617
+ # Generation (using HuggingFace tokenizer)
618
+ # ============================================================
619
+ def generate(model, tokenizer, prompt, max_new_tokens=100, temperature=0.8, top_k=50, device='cpu'):
620
+ model.eval()
621
+ input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
622
+ with torch.no_grad():
623
+ for _ in range(max_new_tokens):
624
+ idx = input_ids[:, -model.config.max_len:]
625
+ logits, _, _ = model(idx)
626
+ logits = logits[:, -1, :] / temperature
627
+ if top_k > 0:
628
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
629
+ logits[logits < v[:, [-1]]] = float('-inf')
630
+ probs = torch.softmax(logits, dim=-1)
631
+ next_token = torch.multinomial(probs, num_samples=1)
632
+ input_ids = torch.cat([input_ids, next_token], dim=1)
633
+ return tokenizer.decode(input_ids[0], skip_special_tokens=True)