| |
| """ |
| Architecture: TinyV4 (ManifoldHC + CSA/HCA attention + DeepSeekMoE + PartialRoPE + MTP) |
| HF-compatible: supports trust_remote_code via PretrainedConfig + from_pretrained/save_pretrained. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import PretrainedConfig, PreTrainedModel |
| from transformers import AutoTokenizer |
| from safetensors.torch import load_file as safe_load, save_file as safe_save |
| import time |
| import math |
| import json |
| import os |
|
|
| |
if not hasattr(nn, 'RMSNorm'):
    class RMSNorm(nn.Module):
        """Fallback RMSNorm for PyTorch builds that lack ``nn.RMSNorm``.

        Normalises over the last dimension in float32, casts back to the input
        dtype, then applies a learnable per-channel scale initialised to ones.

        NOTE(review): the native ``nn.RMSNorm`` defaults ``eps`` to
        ``torch.finfo(dtype).eps`` while this fallback defaults to 1e-6, so the
        two branches differ very slightly for near-zero inputs.
        """

        def __init__(self, dim, eps=1e-6):
            super().__init__()
            self.eps = eps
            self.weight = nn.Parameter(torch.ones(dim))

        def forward(self, x):
            x32 = x.float()
            inv_rms = x32.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
            return (x32 * inv_rms).type_as(x) * self.weight
else:
    # Prefer the native implementation when this PyTorch build provides it.
    RMSNorm = nn.RMSNorm
|
|
| |
| |
| |
|
|
class TinyV4Config(PretrainedConfig):
    """Hyper-parameters for the TinyV4 model (an HF ``PretrainedConfig`` subclass).

    Args:
        vocab_size: rows of the token embedding / outputs of the LM head.
        dim: model width.
        depth: number of transformer layers.
        n_hc: number of parallel residual streams mixed by ManifoldHC.
        n_routed, n_active, n_shared: routed experts, routed experts picked per
            token, and always-on shared experts in each MoE layer.
        expert_intermediate: hidden width of every expert.
        csa_m: CSA compression block size.
        csa_topk: NOTE(review): stored, but CSA does not use it for a hard
            top-k selection.
        hca_m: HCA compression block size.
        n_win: sliding-window length (raw KV rows) in CSA/HCA.
        n_q_head, head_dim: query heads and per-head width.
        d_c: width of the low-rank query bottleneck.
        n_idx_head, idx_head_dim: CSA indexer heads and their width.
        n_out_group, d_g: output head groups (``n_q_head`` must be divisible by
            ``n_out_group``) and the per-group projection width.
        rope_dim: number of trailing per-head channels rotated by RoPE.
        mtp_depth: the multi-token-prediction head is built when this is > 0.
        hash_layers: layers with index below this use position-hash routing.
        max_len: length of the precomputed RoPE tables.
        sinkhorn_iters: Sinkhorn iterations in each ManifoldHC.
        aux_bias_update: step size intended for the MoE expert-bias update
            (NOTE(review): confirm ``DeepSeekMoE.update_bias`` reads it).
        bal_loss_weight: weight for the MoE balance loss; not read by the model
            classes — presumably applied by the training loop.
        **kwargs: forwarded to ``PretrainedConfig``.
    """

    model_type = "tinyv4"

    def __init__(
        self,
        vocab_size: int = 1000,
        dim: int = 384,
        depth: int = 8,
        n_hc: int = 2,
        n_routed: int = 8,
        n_active: int = 2,
        n_shared: int = 1,
        expert_intermediate: int = 512,
        csa_m: int = 4,
        csa_topk: int = 32,
        hca_m: int = 16,
        n_win: int = 32,
        n_q_head: int = 8,
        head_dim: int = 64,
        d_c: int = 192,
        n_idx_head: int = 8,
        idx_head_dim: int = 64,
        n_out_group: int = 2,
        d_g: int = 128,
        rope_dim: int = 32,
        mtp_depth: int = 1,
        hash_layers: int = 3,
        max_len: int = 1024,
        sinkhorn_iters: int = 20,
        aux_bias_update: float = 0.001,
        bal_loss_weight: float = 0.0001,
        **kwargs
    ):
        super().__init__(**kwargs)
        # Trunk size (assignment order is kept so to_dict()/config.json key order is unchanged).
        self.vocab_size, self.dim, self.depth, self.n_hc = vocab_size, dim, depth, n_hc
        # Mixture of experts.
        self.n_routed, self.n_active, self.n_shared = n_routed, n_active, n_shared
        self.expert_intermediate = expert_intermediate
        # Compressed attention.
        self.csa_m, self.csa_topk, self.hca_m, self.n_win = csa_m, csa_topk, hca_m, n_win
        self.n_q_head, self.head_dim, self.d_c = n_q_head, head_dim, d_c
        self.n_idx_head, self.idx_head_dim = n_idx_head, idx_head_dim
        self.n_out_group, self.d_g, self.rope_dim = n_out_group, d_g, rope_dim
        # Extra heads, routing schedule and positional range.
        self.mtp_depth, self.hash_layers, self.max_len = mtp_depth, hash_layers, max_len
        # Hyper-connection / load-balancing knobs.
        self.sinkhorn_iters, self.aux_bias_update, self.bal_loss_weight = (
            sinkhorn_iters, aux_bias_update, bal_loss_weight
        )
|
|
|
|
def sinkhorn_knopp(B_raw, n_iters=20):
    """Sinkhorn-Knopp normalisation of ``exp(B_raw)`` over its last two dims.

    Alternately normalises rows (dim -1) and then columns (dim -2) ``n_iters``
    times, so the result is approximately doubly stochastic (columns sum to 1
    after every iteration).

    The iterations run in the log domain (``logsumexp``) so large logits do not
    overflow ``exp`` to ``inf`` (which would turn into NaN after the first
    division). For finite inputs this equals the direct formulation up to
    rounding.

    Args:
        B_raw: logits; the last two dims form the matrices to normalise.
        n_iters: number of row+column passes. With ``n_iters <= 0`` the raw
            ``exp(B_raw)`` is returned, as in the direct formulation.

    Returns:
        Tensor of the same shape as ``B_raw``.
    """
    if n_iters <= 0:
        return torch.exp(B_raw)
    log_M = B_raw
    for _ in range(n_iters):
        log_M = log_M - torch.logsumexp(log_M, dim=-1, keepdim=True)
        log_M = log_M - torch.logsumexp(log_M, dim=-2, keepdim=True)
    return torch.exp(log_M)
|
|
|
|
class ManifoldHC(nn.Module):
    """Manifold-constrained hyper-connection wrapping a single sub-layer.

    The residual state ``X`` holds ``n_hc`` parallel streams of width ``dim``:
    shape (B, T, n_hc, dim). Per token, from the RMS-normalised flattened
    streams it derives
      * ``A`` (n_hc) in (0, 1): weights that collapse the streams into the
        sub-layer input,
      * ``C`` (n_hc) in (0, 2): how strongly the sub-layer output is written
        back into each stream,
      * ``B_mat`` (n_hc x n_hc): a Sinkhorn-normalised mixing matrix applied to
        the streams on the residual path.
    Output: ``B_mat @ X + C * sublayer(A . X)``, shape (B, T, n_hc, dim).
    """

    def __init__(self, dim, n_hc, n_iters=20):
        super().__init__()
        self.dim = dim; self.n_hc = n_hc; self.n_iters = n_iters
        flat_dim = n_hc * dim
        self.W_pre = nn.Linear(flat_dim, n_hc, bias=False)
        self.W_post = nn.Linear(flat_dim, n_hc, bias=False)
        self.W_res = nn.Linear(flat_dim, n_hc * n_hc, bias=False)
        self.S_pre = nn.Parameter(torch.zeros(1, n_hc))
        self.S_post = nn.Parameter(torch.zeros(1, n_hc))
        self.S_res = nn.Parameter(torch.zeros(1, n_hc * n_hc))
        # Small initial gains keep the data-dependent part near the static biases.
        self.alpha_pre = nn.Parameter(torch.tensor(0.1))
        self.alpha_res = nn.Parameter(torch.tensor(0.1))
        self.alpha_post = nn.Parameter(torch.tensor(0.1))

    @staticmethod
    def _rms_norm(x):
        """Parameter-free RMS norm over the last dim.

        Uses ``F.rms_norm`` when available; otherwise computes the same thing
        manually (same default eps, ``torch.finfo(dtype).eps``) so the module
        also runs on PyTorch builds that only have the RMSNorm fallback.
        """
        if hasattr(F, 'rms_norm'):
            return F.rms_norm(x, (x.shape[-1],))
        x32 = x.float()
        eps = torch.finfo(x.dtype).eps
        return (x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)).type_as(x)

    def forward(self, X, sublayer):
        """Apply ``sublayer`` (a (B, T, dim) -> (B, T, dim) callable) through the hyper-connection."""
        B, T, n_hc, d = X.shape
        flat_dim = n_hc * d
        X_flat = X.reshape(B * T, flat_dim)
        X_norm = self._rms_norm(X_flat)
        A_raw = self.alpha_pre * self.W_pre(X_norm) + self.S_pre
        C_raw = self.alpha_post * self.W_post(X_norm) + self.S_post
        B_raw = self.alpha_res * self.W_res(X_norm) + self.S_res
        A = torch.sigmoid(A_raw)
        C = 2.0 * torch.sigmoid(C_raw)
        B_mat = sinkhorn_knopp(B_raw.reshape(B * T, n_hc, n_hc), self.n_iters)
        streams = X_flat.reshape(B * T, n_hc, d)
        sublayer_input = torch.einsum('bn,bnd->bd', A, streams).reshape(B, T, d)
        sublayer_output = sublayer(sublayer_input).reshape(B * T, d)
        residual = torch.bmm(B_mat, streams)
        injection = C.unsqueeze(-1) * sublayer_output.unsqueeze(1)
        return (residual + injection).reshape(B, T, n_hc, d)
|
|
|
|
class PartialRoPE(nn.Module):
    """Rotary position embedding applied to the last ``rope_dim`` channels only.

    The first ``dim - rope_dim`` channels pass through unchanged; the trailing
    ``rope_dim`` channels are rotated as adjacent pairs ``(2i, 2i+1)`` by angle
    ``pos * theta_i``. ``cos``/``sin`` tables for positions ``0 .. max_len-1``
    are precomputed and registered as (persistent) buffers.

    NOTE(review): the exponent is ``-2 * arange(0, rope_dim, 2) / rope_dim``,
    i.e. twice the usual RoPE exponent (effective base 10000**2). Kept as-is
    because the tables are stored in checkpoints — confirm this is intended.
    """

    def __init__(self, dim, rope_dim, max_len=2048):
        super().__init__()
        self.dim = dim; self.rope_dim = rope_dim; self.max_len = max_len
        theta = 10000.0 ** (-2.0 * torch.arange(0, rope_dim, 2) / rope_dim)
        pos = torch.arange(max_len)
        freqs = torch.outer(pos, theta)  # (max_len, rope_dim // 2) for even rope_dim
        self.register_buffer('cos', freqs.cos())
        self.register_buffer('sin', freqs.sin())

    def _rotate_pairs(self, x, positions, inverse):
        """Rotate (or un-rotate) the RoPE slice of ``x`` (N, H, D) at ``positions`` (N,)."""
        N, H, D = x.shape
        r = self.rope_dim
        # Explicit split index: ``x[..., -r:]`` would select *everything* when r == 0.
        split = D - r
        x_pass, x_rope = x[..., :split], x[..., split:]
        pairs = x_rope.reshape(N, H, r // 2, 2)
        x1, x2 = pairs[..., 0], pairs[..., 1]
        cos = self.cos[positions][:, None, :]
        sin = self.sin[positions][:, None, :]
        if inverse:
            sin = -sin  # rotation by -angle
        y1 = x1 * cos - x2 * sin
        y2 = x1 * sin + x2 * cos
        y_rope = torch.stack([y1, y2], dim=-1).reshape(N, H, r)
        return torch.cat([x_pass, y_rope], dim=-1)

    def _rotate(self, x, positions):
        """Apply RoPE to ``x`` (N, H, D); ``positions`` (N,) index the cos/sin tables."""
        return self._rotate_pairs(x, positions, inverse=False)

    def forward(self, q, k, q_pos=None, k_pos=None):
        """Rotate ``q`` and ``k``; positions default to ``arange`` over dim 0."""
        if q_pos is None:
            q_pos = torch.arange(q.shape[0], device=q.device)
        if k_pos is None:
            k_pos = torch.arange(k.shape[0], device=k.device)
        return self._rotate(q, q_pos), self._rotate(k, k_pos)

    def inverse(self, x, positions=None):
        """Undo ``_rotate`` at ``positions`` (defaults to ``arange`` over dim 0)."""
        if positions is None:
            positions = torch.arange(x.shape[0], device=x.device)
        return self._rotate_pairs(x, positions, inverse=True)
|
|
|
|
def compress_kv(C, Z, B_pos, m):
    """Pool every block of ``m`` consecutive positions into one vector.

    Within a block, slot ``i`` contributes ``C`` with a per-channel softmax
    weight computed over ``Z + B_pos[i]``.

    Args:
        C: values to pool, (B, T, c).
        Z: pooling logits, same shape as ``C``.
        B_pos: learned per-slot logit offset, broadcast as (m, c).
        m: block size.

    Returns:
        (B, ceil(T / m), c). When ``T`` is not a multiple of ``m`` the tail is
        padded; padded slots get ``-inf`` logits so they receive zero weight
        and the last block is pooled over its real positions only.
    """
    B, T, c = C.shape
    pad_len = (m - (T % m)) % m
    if pad_len > 0:
        C = F.pad(C, (0, 0, 0, pad_len))
        # -inf (not 0) so padding does not steal softmax mass from real tokens.
        Z = F.pad(Z, (0, 0, 0, pad_len), value=float('-inf'))
    T_comp = (T + pad_len) // m
    C_blocks = C.reshape(B, T_comp, m, c)
    Z_blocks = Z.reshape(B, T_comp, m, c)
    weights = torch.softmax(Z_blocks + B_pos[None, None, :, :], dim=2)
    return (weights * C_blocks).sum(dim=2)
|
|
|
|
def compress_kv_csa(C_a, C_b, Z_a, Z_b, B_a, B_b, m):
    """Overlapping block compression used by CSA.

    Compressed entry ``j`` pools ``2m`` slots: the ``a`` branch of block ``j``
    and the ``b`` branch of block ``j - 1`` (block 0 has no predecessor; those
    slots are zero values with ``-inf`` logits). Weights are a per-channel
    softmax over ``Z + B`` across the ``2m`` slots.

    Args:
        C_a, C_b: values, (B, T, c).
        Z_a, Z_b: pooling logits, same shape as the values.
        B_a, B_b: learned per-slot logit offsets, (m, c) each.
        m: block size.

    Returns:
        (B, ceil(T / m), c). Tail padding (when ``T % m != 0``) gets ``-inf``
        logits, so it never receives weight. All helper tensors are created in
        the inputs' dtype/device, so half-precision inputs stay half precision.
    """
    B, T, c = C_a.shape
    pad_len = (m - (T % m)) % m
    if pad_len > 0:
        C_a = F.pad(C_a, (0, 0, 0, pad_len)); C_b = F.pad(C_b, (0, 0, 0, pad_len))
        Z_a = F.pad(Z_a, (0, 0, 0, pad_len), value=float('-inf'))
        Z_b = F.pad(Z_b, (0, 0, 0, pad_len), value=float('-inf'))
    T_comp = (T + pad_len) // m
    C_a_blocks = C_a.reshape(B, T_comp, m, c); C_b_blocks = C_b.reshape(B, T_comp, m, c)
    Z_a_blocks = Z_a.reshape(B, T_comp, m, c); Z_b_blocks = Z_b.reshape(B, T_comp, m, c)
    # Shift the b-branch one block to the right; block 0 gets an empty (masked) predecessor.
    C_b_shifted = torch.cat([C_b_blocks.new_zeros(B, 1, m, c), C_b_blocks[:, :-1]], dim=1)
    Z_b_shifted = torch.cat([Z_b_blocks.new_full((B, 1, m, c), float('-inf')), Z_b_blocks[:, :-1]], dim=1)
    C_cat = torch.cat([C_a_blocks, C_b_shifted], dim=2)
    Z_cat = torch.cat([Z_a_blocks, Z_b_shifted], dim=2)
    B_cat = torch.cat([B_a, B_b], dim=0)
    weights = torch.softmax(Z_cat + B_cat[None, None, :, :], dim=2)
    return (weights * C_cat).sum(dim=2)
|
|
|
|
class CSA(nn.Module):
    """Attention over overlapping block-compressed KV plus a local sliding window.

    Each query token ``t`` attends to
      * one compressed KV entry per block of ``csa_m`` tokens strictly before
        its own block (``compress_kv_csa``), each biased by an indexer score,
      * the raw KV of the ``n_win`` most recent tokens ``t-n_win+1 .. t``,
      * a learned per-head sink logit that only absorbs probability mass.
    Keys and values share one projection (all heads read the same KV). RoPE
    rotates the last ``rope_dim`` channels of queries and KV; the output is
    rotated back at the query position. Heads are projected per group
    (``n_out_group``) and mixed by ``out_proj``.

    NOTE(review): ``csa_topk`` is stored as ``self.topk`` but not used — the
    indexer acts as an additive soft bias on every compressed entry rather
    than a hard top-k selection.
    """

    def __init__(self, config):
        super().__init__()
        d = config.dim; c = config.head_dim; n_h = config.n_q_head
        n_h_I = config.n_idx_head; c_I = config.idx_head_dim; d_c = config.d_c
        m = config.csa_m; topk = config.csa_topk; n_win = config.n_win
        g = config.n_out_group; d_g = config.d_g
        self.d, self.c, self.n_h, self.n_h_I, self.c_I = d, c, n_h, n_h_I, c_I
        self.d_c, self.m, self.topk, self.n_win, self.g, self.d_g = d_c, m, topk, n_win, g, d_g
        self.W_aKV = nn.Linear(d, c, bias=False); self.W_bKV = nn.Linear(d, c, bias=False)
        self.W_aZ = nn.Linear(d, c, bias=False); self.W_bZ = nn.Linear(d, c, bias=False)
        self.B_a = nn.Parameter(torch.zeros(m, c)); self.B_b = nn.Parameter(torch.zeros(m, c))
        self.W_idxKV = nn.Linear(d, c_I, bias=False); self.W_idxZ = nn.Linear(d, c_I, bias=False)
        self.B_idx = nn.Parameter(torch.zeros(m, c_I))
        self.W_DQ = nn.Linear(d, d_c, bias=False)
        self.W_IUQ = nn.Linear(d_c, c_I * n_h_I, bias=False)
        self.W_UQ = nn.Linear(d_c, c * n_h, bias=False)
        self.W_w = nn.Linear(d, n_h_I, bias=False)
        self.W_swKV = nn.Linear(d, c, bias=False)
        assert n_h % g == 0
        hpg = n_h // g; god = hpg * c
        self.group_proj = nn.ModuleList([nn.Linear(god, d_g, bias=False) for _ in range(g)])
        self.out_proj = nn.Linear(d_g * g, d, bias=False)
        self.sink_logits = nn.Parameter(torch.zeros(n_h))
        self.rope = PartialRoPE(c, config.rope_dim, config.max_len)
        self.q_norm = RMSNorm(c); self.kv_norm = RMSNorm(c)

    @staticmethod
    def _sliding_window(sw_kv, n_win):
        """Gather, for every query ``t``, the KV rows at positions ``t-n_win+1 .. t``.

        Args:
            sw_kv: (B, T, c) per-token window KV.
            n_win: window length.

        Returns:
            ``(gathered, positions, valid)``: gathered (B, T, n_win, c) in
            ascending position order; positions (T, n_win) of each slot
            (clamped to >= 0 for RoPE lookup); valid (T, n_win) bool, False for
            zero-padded slots that fall before position 0.
        """
        T = sw_kv.shape[1]
        device = sw_kv.device
        lead = max(n_win - 1, 0)
        padded = F.pad(sw_kv, (0, 0, lead, 0))
        offsets = torch.arange(n_win, device=device)
        query_pos = torch.arange(T, device=device)
        gathered = padded[:, query_pos[:, None] + offsets[None, :]]
        positions = query_pos[:, None] - lead + offsets[None, :]
        valid = positions >= 0
        return gathered, positions.clamp(min=0), valid

    def forward(self, x):
        """x: (B, T, dim) -> (B, T, dim)."""
        B, T, _ = x.shape
        device = x.device
        m, c, n_h, n_h_I, c_I, n_win = self.m, self.c, self.n_h, self.n_h_I, self.c_I, self.n_win

        # Compressed KV: entry j pools block j (a-branch) and block j-1 (b-branch).
        KV_comp = compress_kv_csa(self.W_aKV(x), self.W_bKV(x), self.W_aZ(x), self.W_bZ(x),
                                  self.B_a, self.B_b, m)                      # (B, T_comp, c)
        T_comp = KV_comp.shape[1]

        # Indexer: per (query, compressed block) relevance, later added to the attention logits.
        K_idx_comp = compress_kv(self.W_idxKV(x), self.W_idxZ(x), self.B_idx, m)  # (B, T_comp, c_I)
        c_Q = self.W_DQ(x)
        q_I = self.W_IUQ(c_Q).reshape(B, T, n_h_I, c_I)
        q = self.W_UQ(c_Q).reshape(B, T, n_h, c)
        idx_scores = torch.einsum('bthc,bsc->bths', q_I, K_idx_comp)
        idx_scores = torch.einsum('bth,bths->bts', F.relu(self.W_w(x)), F.relu(idx_scores))

        # A query in block t // m only sees compressed blocks strictly before its own.
        query_block = torch.arange(T, device=device) // m
        comp_visible = query_block[:, None] > torch.arange(T_comp, device=device)[None, :]  # (T, T_comp)
        idx_scores = idx_scores.masked_fill(~comp_visible, float('-inf'))

        SW_gathered, sw_positions, sw_visible = self._sliding_window(self.W_swKV(x), n_win)

        KV_all = torch.cat([KV_comp.unsqueeze(1).expand(-1, T, -1, -1), SW_gathered], dim=2)
        n_kv = T_comp + n_win
        q = self.q_norm(q.reshape(B * T * n_h, c)).reshape(B, T, n_h, c)
        KV_all = self.kv_norm(KV_all.reshape(B * T * n_kv, c)).reshape(B, T, n_kv, c)

        # RoPE: queries at their own position, compressed entries at their block centre,
        # window entries at the position of the token they were gathered from.
        q_pos = torch.arange(T, device=device).repeat(B)                       # (B*T,)
        comp_positions = torch.arange(T_comp, device=device) * m + m // 2
        kv_positions = torch.cat([comp_positions.unsqueeze(0).expand(T, -1), sw_positions], dim=1)
        q = self.rope._rotate(q.reshape(B * T, n_h, c), q_pos).reshape(B, T, n_h, c)
        KV_all = self.rope._rotate(KV_all.reshape(B * T * n_kv, 1, c),
                                   kv_positions.reshape(-1).repeat(B)).reshape(B, T, n_kv, c)

        # Shared KV for all heads: contract directly instead of materialising a per-head copy.
        attn_logits = torch.einsum('bthc,btkc->bthk', q, KV_all) * c ** -0.5
        attn_logits = attn_logits + F.pad(idx_scores, (0, n_win), value=0.0)[:, :, None, :]
        visible = torch.cat([comp_visible, sw_visible], dim=1)                 # (T, n_kv)
        attn_logits = attn_logits.masked_fill(~visible[None, :, None, :], float('-inf'))
        sink = self.sink_logits[None, None, :, None].expand(B, T, -1, -1)
        attn_weights = torch.softmax(torch.cat([attn_logits, sink], dim=-1), dim=-1)[..., :n_kv]
        o = torch.einsum('bthk,btkc->bthc', attn_weights, KV_all)

        # Undo the query-position rotation; each value contribution stays rotated by (kv_pos - q_pos).
        o = self.rope.inverse(o.reshape(B * T, n_h, c), q_pos).reshape(B, T, n_h, c)

        hpg = n_h // self.g
        mixed = [proj(og.reshape(B, T, hpg * c))
                 for proj, og in zip(self.group_proj, o.chunk(self.g, dim=2))]
        return self.out_proj(torch.cat(mixed, dim=-1))
|
|
|
|
class HCA(nn.Module):
    """Attention over block-compressed KV plus a local sliding window.

    Keys/values (one shared projection) are pooled into one entry per block of
    ``hca_m`` tokens via ``compress_kv``. Each query token ``t`` attends to
    every compressed block strictly before its own block, to the raw KV of the
    ``n_win`` most recent tokens ``t-n_win+1 .. t``, and to a learned per-head
    sink logit that only absorbs probability mass. RoPE rotates the last
    ``rope_dim`` channels of queries and KV; the output is rotated back at the
    query position. Heads are projected per group (``n_out_group``) and mixed
    by ``out_proj``.
    """

    def __init__(self, config):
        super().__init__()
        d = config.dim; c = config.head_dim; n_h = config.n_q_head
        d_c = config.d_c; m = config.hca_m; n_win = config.n_win
        g = config.n_out_group; d_g = config.d_g
        self.d, self.c, self.n_h, self.d_c, self.m, self.n_win, self.g, self.d_g = d, c, n_h, d_c, m, n_win, g, d_g
        self.W_KV = nn.Linear(d, c, bias=False); self.W_Z = nn.Linear(d, c, bias=False)
        self.B_pos = nn.Parameter(torch.zeros(m, c))
        self.W_DQ = nn.Linear(d, d_c, bias=False)
        self.W_UQ = nn.Linear(d_c, c * n_h, bias=False)
        self.W_swKV = nn.Linear(d, c, bias=False)
        assert n_h % g == 0
        hpg = n_h // g; god = hpg * c
        self.group_proj = nn.ModuleList([nn.Linear(god, d_g, bias=False) for _ in range(g)])
        self.out_proj = nn.Linear(d_g * g, d, bias=False)
        self.sink_logits = nn.Parameter(torch.zeros(n_h))
        self.rope = PartialRoPE(c, config.rope_dim, config.max_len)
        self.q_norm = RMSNorm(c); self.kv_norm = RMSNorm(c)

    @staticmethod
    def _sliding_window(sw_kv, n_win):
        """Gather, for every query ``t``, the KV rows at positions ``t-n_win+1 .. t``.

        Args:
            sw_kv: (B, T, c) per-token window KV.
            n_win: window length.

        Returns:
            ``(gathered, positions, valid)``: gathered (B, T, n_win, c) in
            ascending position order; positions (T, n_win) of each slot
            (clamped to >= 0 for RoPE lookup); valid (T, n_win) bool, False for
            zero-padded slots that fall before position 0.
        """
        T = sw_kv.shape[1]
        device = sw_kv.device
        lead = max(n_win - 1, 0)
        padded = F.pad(sw_kv, (0, 0, lead, 0))
        offsets = torch.arange(n_win, device=device)
        query_pos = torch.arange(T, device=device)
        gathered = padded[:, query_pos[:, None] + offsets[None, :]]
        positions = query_pos[:, None] - lead + offsets[None, :]
        valid = positions >= 0
        return gathered, positions.clamp(min=0), valid

    def forward(self, x):
        """x: (B, T, dim) -> (B, T, dim)."""
        B, T, _ = x.shape
        device = x.device
        m, c, n_h, n_win = self.m, self.c, self.n_h, self.n_win

        KV_comp = compress_kv(self.W_KV(x), self.W_Z(x), self.B_pos, m)   # (B, T_comp, c)
        T_comp = KV_comp.shape[1]
        q = self.W_UQ(self.W_DQ(x)).reshape(B, T, n_h, c)

        # A query in block t // m only sees compressed blocks strictly before its own.
        query_block = torch.arange(T, device=device) // m
        comp_visible = query_block[:, None] > torch.arange(T_comp, device=device)[None, :]  # (T, T_comp)
        SW_gathered, sw_positions, sw_visible = self._sliding_window(self.W_swKV(x), n_win)

        KV_all = torch.cat([KV_comp.unsqueeze(1).expand(-1, T, -1, -1), SW_gathered], dim=2)
        n_kv = T_comp + n_win
        q = self.q_norm(q.reshape(B * T * n_h, c)).reshape(B, T, n_h, c)
        KV_all = self.kv_norm(KV_all.reshape(B * T * n_kv, c)).reshape(B, T, n_kv, c)

        # RoPE: queries at their own position, compressed entries at their block centre,
        # window entries at the position of the token they were gathered from.
        q_pos = torch.arange(T, device=device).repeat(B)                       # (B*T,)
        comp_positions = torch.arange(T_comp, device=device) * m + m // 2
        kv_positions = torch.cat([comp_positions.unsqueeze(0).expand(T, -1), sw_positions], dim=1)
        q = self.rope._rotate(q.reshape(B * T, n_h, c), q_pos).reshape(B, T, n_h, c)
        KV_all = self.rope._rotate(KV_all.reshape(B * T * n_kv, 1, c),
                                   kv_positions.reshape(-1).repeat(B)).reshape(B, T, n_kv, c)

        # Shared KV for all heads: contract directly instead of materialising a per-head copy.
        attn_logits = torch.einsum('bthc,btkc->bthk', q, KV_all) * c ** -0.5
        visible = torch.cat([comp_visible, sw_visible], dim=1)                 # (T, n_kv)
        attn_logits = attn_logits.masked_fill(~visible[None, :, None, :], float('-inf'))
        sink = self.sink_logits[None, None, :, None].expand(B, T, -1, -1)
        attn_weights = torch.softmax(torch.cat([attn_logits, sink], dim=-1), dim=-1)[..., :n_kv]
        o = torch.einsum('bthk,btkc->bthc', attn_weights, KV_all)

        # Undo the query-position rotation; each value contribution stays rotated by (kv_pos - q_pos).
        o = self.rope.inverse(o.reshape(B * T, n_h, c), q_pos).reshape(B, T, n_h, c)

        hpg = n_h // self.g
        mixed = [proj(og.reshape(B, T, hpg * c))
                 for proj, og in zip(self.group_proj, o.chunk(self.g, dim=2))]
        return self.out_proj(torch.cat(mixed, dim=-1))
|
|
|
|
class Expert(nn.Module):
    """Bias-free SwiGLU feed-forward expert with clamped pre-activations.

    Computes ``down(silu(min(gate(x), 10)) * clamp(up(x), -10, 10))``.
    """

    def __init__(self, dim, intermediate):
        super().__init__()
        self.gate_proj = nn.Linear(dim, intermediate, bias=False)
        self.up_proj = nn.Linear(dim, intermediate, bias=False)
        self.down_proj = nn.Linear(intermediate, dim, bias=False)

    def forward(self, x):
        # The gate is only capped from above; the linear branch is bounded on both sides.
        gated = F.silu(self.gate_proj(x).clamp(max=10.0))
        value = self.up_proj(x).clamp(min=-10.0, max=10.0)
        return self.down_proj(gated * value)
|
|
|
|
| class DeepSeekMoE(nn.Module): |
| def __init__(self, config, layer_idx): |
| super().__init__() |
| d = config.dim |
| self.use_hash = layer_idx < config.hash_layers |
| self.d, self.n_routed, self.n_active = d, config.n_routed, config.n_active |
| self.shared_experts = nn.ModuleList([Expert(d, config.expert_intermediate) for _ in range(config.n_shared)]) |
| self.routed_experts = nn.ModuleList([Expert(d, config.expert_intermediate) for _ in range(config.n_routed)]) |
| self.gate = nn.Linear(d, config.n_routed, bias=False) |
| self.register_buffer('e_bias', torch.zeros(config.n_routed)) |
| self.register_buffer('expert_counts', torch.zeros(config.n_routed)) |
|
|
| def forward(self, x): |
| B, T, d = x.shape; device = x.device |
| shared_out = sum(expert(x) for expert in self.shared_experts) |
| if self.use_hash: |
| pos = torch.arange(T, device=device) |
| expert_idx = pos % self.n_routed |
| routed_out = torch.zeros(B, T, d, device=device) |
| for e_idx in range(self.n_routed): |
| mask = (expert_idx == e_idx).float() |
| if mask.sum() > 0: |
| routed_out = routed_out + self.routed_experts[e_idx](x * mask[None, :, None]) * mask[None, :, None] |
| return shared_out + routed_out, torch.tensor(0.0, device=device) |
| gate_out = self.gate(x) |
| affinity = torch.sqrt(F.softplus(gate_out)) + self.e_bias |
| topk_weights, topk_indices = torch.topk(affinity, self.n_active, dim=-1) |
| topk_weights = F.softmax(topk_weights, dim=-1) |
| with torch.no_grad(): |
| counts = torch.zeros(self.n_routed, device=device) |
| for k in range(self.n_active): |
| counts.scatter_add_(0, topk_indices[..., k].reshape(-1), torch.ones(B * T, device=device)) |
| self.expert_counts = counts.detach() |
| routed_out = torch.zeros(B, T, d, device=device) |
| for e_idx in range(self.n_routed): |
| mask = (topk_indices == e_idx).any(dim=-1) |
| if mask.any(): |
| weight_mask = (topk_indices == e_idx).float() |
| weights = (topk_weights * weight_mask).sum(dim=-1) |
| routed_out[mask] = routed_out[mask] + self.routed_experts[e_idx](x[mask]) * weights[mask, None] |
| frac = counts / (B * T * self.n_active) |
| bal_loss = torch.dot(frac, self.e_bias) |
| return shared_out + routed_out, bal_loss |
|
|
| def update_bias(self): |
| if not self.use_hash: |
| with torch.no_grad(): |
| n_total = self.expert_counts.sum() |
| if n_total > 0: |
| target = n_total / self.n_routed |
| self.e_bias -= 0.001 * (self.expert_counts - target) / max(target, 1) |
|
|
|
|
class TransformerBlock(nn.Module):
    """One TinyV4 layer: attention then MoE, each wrapped in its own ManifoldHC.

    Layers 0 and 1 use HCA; from layer 2 on, even layers use CSA and odd layers
    use HCA. ``forward`` returns the updated multi-stream state together with
    this layer's MoE balance loss.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        dim, n_hc = config.dim, config.n_hc
        use_csa = layer_idx >= 2 and layer_idx % 2 == 0
        self.attn = CSA(config) if use_csa else HCA(config)
        self.mhc_attn = ManifoldHC(dim, n_hc, config.sinkhorn_iters)
        self.mhc_ffn = ManifoldHC(dim, n_hc, config.sinkhorn_iters)
        self.moe = DeepSeekMoE(config, layer_idx)

    def forward(self, X):
        X = self.mhc_attn(X, self.attn)
        bal_loss = torch.tensor(0.0, device=X.device)

        def moe_sublayer(h):
            # ManifoldHC expects a tensor-returning callable, so the loss is captured on the side.
            nonlocal bal_loss
            out, bal_loss = self.moe(h)
            return out

        X = self.mhc_ffn(X, moe_sublayer)
        return X, bal_loss
|
|
|
|
class MTPModule(nn.Module):
    """Multi-token-prediction head.

    Runs one extra HCA (inside a ManifoldHC) over the final multi-stream state
    ``X``, adds a linear projection of ``h``, normalises and projects stream 0
    to vocabulary logits.
    """

    def __init__(self, config):
        super().__init__()
        dim, n_hc = config.dim, config.n_hc
        self.proj_in = nn.Linear(dim, dim, bias=False)
        self.mhc = ManifoldHC(dim, n_hc, config.sinkhorn_iters)
        self.attn = HCA(config)
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, config.vocab_size, bias=False)

    def forward(self, h, X):
        mixed = self.mhc(X, self.attn)
        stream0 = mixed[:, :, 0, :]
        return self.head(self.norm(stream0 + self.proj_in(h)))
|
|
|
|
class TinyV4(PreTrainedModel):
    """TinyV4 causal language model.

    Token embeddings are expanded into ``n_hc`` residual streams, passed through
    ``depth`` TransformerBlocks, and stream 0 is normalised and projected to
    vocabulary logits. When ``mtp_depth > 0`` an MTP head produces a second set
    of logits from the same final state.
    """

    config_class = TinyV4Config
    base_model_prefix = "tinyv4"
    supports_gradient_checkpointing = False

    def __init__(self, config):
        super().__init__(config)
        d = config.dim; n_hc = config.n_hc
        self.embed = nn.Embedding(config.vocab_size, d)
        self.expand = nn.Linear(d, n_hc * d, bias=False)
        self.blocks = nn.ModuleList([TransformerBlock(config, i) for i in range(config.depth)])
        self.norm = nn.LayerNorm(d)
        self.head = nn.Linear(d, config.vocab_size, bias=False)
        self.mtp = MTPModule(config) if config.mtp_depth > 0 else None
        self.post_init()

    def forward(self, input_ids):
        """Run the model.

        Args:
            input_ids: (B, T) token ids.

        Returns:
            ``(logits, mtp_logits, total_bal_loss)``: logits (B, T, vocab);
            MTP logits of the same shape or ``None`` without an MTP head; and
            the sum of every layer's MoE balance loss.
        """
        B, T = input_ids.shape
        d = self.config.dim; n_hc = self.config.n_hc; device = input_ids.device
        x = self.embed(input_ids)
        X = self.expand(x).reshape(B, T, n_hc, d)
        total_bal_loss = torch.tensor(0.0, device=device)
        for block in self.blocks:
            X, bl = block(X)
            total_bal_loss = total_bal_loss + bl
        h = X[:, :, 0, :]
        logits = self.head(self.norm(h))
        mtp_logits = self.mtp(h, X) if self.mtp is not None else None
        return logits, mtp_logits, total_bal_loss

    def param_count(self):
        """Total number of parameters (shared parameters counted once)."""
        return sum(p.numel() for p in self.parameters())

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load TinyV4 from a local directory containing ``config.json`` + ``model.safetensors``.

        Only local directories are supported (no Hub download). ``model_args``
        and ``kwargs`` are accepted for signature compatibility with
        ``PreTrainedModel.from_pretrained`` but are ignored.

        Weights are loaded non-strictly so older checkpoints still load, but any
        missing or unexpected keys are reported with a ``UserWarning`` instead
        of being dropped silently (missing tensors keep their fresh init).

        Raises:
            FileNotFoundError: if ``config.json`` or ``model.safetensors`` is missing.
        """
        import warnings

        model_path = pretrained_model_name_or_path

        config_file = os.path.join(model_path, "config.json")
        if not os.path.exists(config_file):
            raise FileNotFoundError(f"config.json not found in {model_path}")
        with open(config_file, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        config = cls.config_class(**config_dict)

        weights_file = os.path.join(model_path, "model.safetensors")
        if not os.path.exists(weights_file):
            raise FileNotFoundError(f"model.safetensors not found in {model_path}")

        model = cls(config)
        state_dict = safe_load(weights_file)
        result = model.load_state_dict(state_dict, strict=False)
        if result.missing_keys or result.unexpected_keys:
            warnings.warn(
                f"TinyV4.from_pretrained({model_path!r}): "
                f"missing keys {list(result.missing_keys)}, "
                f"unexpected keys {list(result.unexpected_keys)}"
            )
        return model

    def save_pretrained(self, save_directory, **kwargs):
        """Write ``config.json`` and ``model.safetensors`` into ``save_directory``.

        safetensors rejects tensors that share storage (e.g. an LM head tied to
        the embedding) and non-contiguous tensors, so such entries are written
        as independent contiguous copies. ``kwargs`` are accepted for HF
        signature compatibility and ignored.
        """
        os.makedirs(save_directory, exist_ok=True)
        self.config.save_pretrained(save_directory)

        tensors = {}
        seen_storages = set()
        for name, tensor in self.state_dict().items():
            tensor = tensor.detach()
            storage_ptr = tensor.untyped_storage().data_ptr()
            if storage_ptr in seen_storages or not tensor.is_contiguous():
                tensor = tensor.clone(memory_format=torch.contiguous_format)
            else:
                seen_storages.add(storage_ptr)
            tensors[name] = tensor
        safe_save(tensors, os.path.join(save_directory, "model.safetensors"))
|
|
|
|
| |
| |
| |
def search_best_config(target_params=10_000_000, vocab_size=32000, candidates=None):
    """Return the candidate config whose parameter count is closest to ``target_params``.

    Each candidate is instantiated with the LM head tied to the token embedding
    (``model.head.weight = model.embed.weight``) before counting, so the shared
    matrix is counted once. The same tied count is reported for the winner.

    Args:
        target_params: desired total parameter count.
        vocab_size: vocabulary size for the built-in candidate list.
        candidates: optional iterable of ``TinyV4Config`` to search instead of
            the built-in list (``vocab_size`` is then not used).

    Returns:
        The best ``TinyV4Config`` (first one wins ties), or ``None`` when there
        are no candidates.
    """
    if candidates is None:
        base = dict(vocab_size=vocab_size, dim=128, depth=6, n_hc=2, n_routed=4, n_active=2,
                    n_shared=1, expert_intermediate=192, csa_m=4, csa_topk=16, hca_m=8,
                    n_win=16, n_q_head=4, head_dim=48, d_c=64, n_idx_head=4,
                    idx_head_dim=48, n_out_group=2, d_g=64, rope_dim=24, mtp_depth=0,
                    hash_layers=2, max_len=512)
        variants = [
            {},
            dict(depth=8, hash_layers=3),
            dict(dim=160, depth=4, expert_intermediate=256, d_g=80),
            dict(n_routed=6),
            dict(dim=96, depth=8, expert_intermediate=128, d_c=48, hash_layers=3),
            dict(expert_intermediate=256),
            dict(mtp_depth=1),
        ]
        candidates = [TinyV4Config(**{**base, **variant}) for variant in variants]

    best_config, best_n, best_diff = None, None, float('inf')

    print(f"\n{'='*70}")
    print(f"Searching for config closest to {target_params/1e6:.1f}M params (vocab={vocab_size})")
    print("Note: tie_embeddings=True — embed & head share weights")
    print(f"{'='*70}")

    for cfg in candidates:
        model = TinyV4(cfg)
        model.head.weight = model.embed.weight  # tie, so parameters() counts it once
        n = model.param_count()
        diff = abs(n - target_params)
        pct = (n - target_params) / target_params * 100
        print(f"  dim={cfg.dim:3d} depth={cfg.depth} n_routed={cfg.n_routed} expert_int={cfg.expert_intermediate:3d} "
              f"mtp={cfg.mtp_depth} → {n/1e6:.2f}M params ({pct:+.1f}%)")
        if diff < best_diff:
            best_config, best_n, best_diff = cfg, n, diff
        del model

    if best_config is None:
        return None
    # Report the tied count measured above (rebuilding an untied model would over-count).
    print(f"\n✅ Best config: {best_config.dim}d {best_config.depth}L → "
          f"{best_n/1e6:.2f}M params (with tie_embeddings)")
    return best_config
|
|
|
|
| |
| |
| |
def generate(model, tokenizer, prompt, max_new_tokens=100, temperature=0.8, top_k=50, device='cpu'):
    """Autoregressively extend ``prompt`` and return the decoded text.

    Each step feeds at most the last ``model.config.max_len`` tokens (no KV
    cache) and appends one token from the final position's logits.

    Args:
        model: a TinyV4-style model returning ``(logits, mtp_logits, bal_loss)``.
        tokenizer: HF-style tokenizer with ``encode(..., return_tensors="pt")``
            and ``decode``.
        prompt: input text.
        max_new_tokens: number of tokens to append.
        temperature: sampling temperature; ``<= 0`` (or ``None``) selects greedy
            argmax decoding instead of dividing by zero.
        top_k: sample only among the ``top_k`` most likely tokens; ``None`` or
            ``<= 0`` disables the filter.
        device: device the prompt ids are moved to.

    Returns:
        The decoded first sequence (prompt + continuation), special tokens skipped.

    Note: puts ``model`` in eval mode and leaves it there.
    """
    model.eval()
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    max_len = model.config.max_len
    greedy = temperature is None or temperature <= 0
    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits, _, _ = model(input_ids[:, -max_len:])
            logits = logits[:, -1, :]
            if greedy:
                next_token = logits.argmax(dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_k is not None and top_k > 0:
                    kth = torch.topk(logits, min(top_k, logits.size(-1))).values[:, -1:]
                    logits = logits.masked_fill(logits < kth, float('-inf'))
                probs = torch.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat([input_ids, next_token], dim=1)
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)
|
|