| """ |
| model.py -- SpikeWhaleLM: combined architecture from SpikeTransformer (My Project) + NanoWhale. |
| |
| Architecture flow: |
| Embedding |
| -> Engram delta (N-gram memory, My Project) |
| -> [expand to hc_mult copies if HC enabled] |
| -> N x TransformerBlock: |
| HC pre-op (NanoWhale) -> RMSNorm -> MLA+DERF+XSA Attention (combined) |
| -> HC post-op |
| HC pre-op -> RMSNorm -> MoE FFN w/ shared expert (NanoWhale) |
| -> HC post-op |
| -> [mean-pool hc_mult copies if HC enabled] |
| -> RMSNorm |
| -> LM head + MTP heads (NanoWhale) |
| |
| Component origins: |
| RMSNorm, RotaryEmbedding -- both (standard) |
| Engram / DERFContextGate -- My Project |
| MLADerfXSAAttention -- MLA from NanoWhale + DERF+XSA from My Project |
| SparseMoEFFN w/ shared expert -- NanoWhale MoE structure + My Project aux loss |
| HyperConnectionLayer -- NanoWhale |
| SpikeWhaleLM + MTP heads -- NanoWhale |
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Optional, Tuple, List |
| from transformers import PreTrainedModel |
| from transformers.modeling_outputs import CausalLMOutputWithPast |
| from torch.utils.checkpoint import checkpoint as gradient_checkpoint |
|
|
| from config import SpikeWhaleConfig |
|
|
|
|
| |
| |
| |
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Computes ``x / sqrt(mean(x ** 2) + eps) * weight`` with a learnable
    per-channel ``weight`` (initialized to ones). There is no mean-centering
    and no bias term.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight
|
|
|
|
class RotaryEmbedding(nn.Module):
    """RoPE for the rope partition of Q and K (qk_rope_head_dim dims only).

    Cos/sin tables are precomputed for positions ``[0, max_positions)``. The
    first half of the rotated channels is paired with the second half
    (``x1 = x[..., :dim/2]``, ``x2 = x[..., dim/2:]``).

    Args:
        dim: Number of rotated channels; must be even.
        max_positions: Number of precomputed positions. A position id
            ``>= max_positions`` raises an indexing error in ``forward``.
        theta: RoPE base frequency.

    Raises:
        ValueError: if ``dim`` is odd (the half/half pairing needs equal halves).
    """

    def __init__(self, dim: int, max_positions: int = 4096, theta: float = 10000.0):
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(f"RotaryEmbedding dim must be even, got {dim}")
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        t = torch.arange(max_positions).float()
        freqs = torch.outer(t, inv_freq)
        self.register_buffer("cos_cache", freqs.cos())
        self.register_buffer("sin_cache", freqs.sin())

    def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
        """
        x: [B, H, S, rope_dim]
        position_ids: [B, S]
        Returns the rotated tensor with the same shape and dtype as ``x``.
        """
        # Cast the (fp32) tables to x's dtype. Without this, a bf16/fp16 input is
        # type-promoted to fp32, and the caller's torch.cat with the unrotated
        # low-precision partition (and any later SDPA call) sees mixed dtypes.
        cos = self.cos_cache[position_ids].unsqueeze(1).to(x.dtype)
        sin = self.sin_cache[position_ids].unsqueeze(1).to(x.dtype)
        d = cos.shape[-1]
        x1, x2 = x[..., :d], x[..., d:]
        return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
|
|
|
|
| |
| |
| |
|
|
class TokenCompressor(nn.Module):
    """Frozen random linear map from ``embed_dim`` down to ``compress_dim``.

    The bias-free projection weight is drawn from N(0, 0.02^2) and then frozen
    with ``requires_grad_(False)``, so this module is a fixed sketch of its
    input rather than a trained layer.
    """

    def __init__(self, embed_dim: int, compress_dim: int):
        super().__init__()
        linear = nn.Linear(embed_dim, compress_dim, bias=False)
        nn.init.normal_(linear.weight, std=0.02)
        linear.weight.requires_grad_(False)
        self.proj = linear

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project the last dimension of ``x`` from embed_dim to compress_dim."""
        return self.proj(x)
|
|
|
|
class MultiHeadHashLookup(nn.Module):
    """Multi-head hashed n-gram embedding lookup.

    For each order ``n`` in ``1..max_ngram`` and each position ``t >= n - 1``,
    the n-gram ending at ``t`` is hashed per head as the sum of fixed random
    unit projections of the compressed tokens at ``t-n+1..t``. There is one
    projection buffer ``hash_proj_n{n}_p{k}`` per order and slot, of shape
    [num_heads, compress_dim]. The bucket index is
    ``trunc(|h|) % table_size``, and each head looks it up in its own
    ``nn.Embedding(table_size, out_dim)``. All (head, order) contributions
    landing on a position are averaged.

    NOTE(review): the bucket index truncates ``|h|`` to an integer, so every
    hash value with magnitude < 1 lands in bucket 0. If the compressed inputs
    are small, most n-grams may collide there. Confirm the typical magnitude
    of ``h`` before relying on this as a discriminative hash.

    Args:
        num_heads: Number of independent hash heads / tables.
        table_size: Number of rows per table.
        compress_dim: Last-dim size of the ``compressed`` input.
        out_dim: Embedding size returned per position.
        max_ngram: Largest n-gram order hashed.
    """

    def __init__(self, num_heads: int, table_size: int,
                 compress_dim: int, out_dim: int, max_ngram: int = 3):
        super().__init__()
        self.num_heads = num_heads
        self.table_size = table_size
        self.max_ngram = max_ngram
        self.out_dim = out_dim

        self.tables = nn.ModuleList([
            nn.Embedding(table_size, out_dim) for _ in range(num_heads)
        ])
        for t in self.tables:
            nn.init.normal_(t.weight, std=0.01)

        # One fixed unit-norm projection per (n-gram order, slot within the n-gram).
        for n in range(1, max_ngram + 1):
            for k in range(n):
                proj = torch.randn(num_heads, compress_dim)
                proj = proj / (proj.norm(dim=1, keepdim=True) + 1e-8)
                self.register_buffer(f"hash_proj_n{n}_p{k}", proj)

    def forward(self, compressed: torch.Tensor) -> torch.Tensor:
        """
        compressed: [B, S, compress_dim]
        returns: [B, S, out_dim] in ``compressed.dtype``

        All positions are processed in parallel; the outer loop runs
        ``max_ngram`` times, not S times. Positions with no contributing
        n-gram (none exist when S >= 1, since unigrams cover every position)
        would be divided by 1 via the clamp.
        """
        B, S, _ = compressed.shape
        device = compressed.device
        out = torch.zeros(B, S, self.out_dim, device=device, dtype=compressed.dtype)
        # Count of (head, order) lookups added at each position, for averaging.
        norm = torch.zeros(S, device=device)
        # Hash in fp32 regardless of the module's dtype. The projection buffers
        # are cast too: after ``module.to(torch.bfloat16)`` they are bf16, and
        # torch.matmul does not type-promote (fp32 @ bf16 raises).
        compressed_f32 = compressed.float()

        for n in range(1, self.max_ngram + 1):
            if S < n:
                continue
            valid_len = S - n + 1
            start = n - 1  # the first n-gram ends at position n - 1

            h = torch.zeros(B, valid_len, self.num_heads, device=device)
            for k in range(n):
                proj = getattr(self, f"hash_proj_n{n}_p{k}").float()
                h = h + torch.matmul(compressed_f32[:, k:k + valid_len, :], proj.t())

            idx = h.abs().long() % self.table_size

            for head_idx, table in enumerate(self.tables):
                out[:, start:, :] = out[:, start:, :] + table(idx[:, :, head_idx])

            norm[start:] += self.num_heads

        return (out / norm.view(1, -1, 1).clamp(min=1)).to(compressed.dtype)
|
|
|
|
class DERFContextGate(nn.Module):
    """Per-channel erf gate on a retrieved embedding, conditioned on context.

    Computes::

        logits = proj([retrieved, x])
        gate   = gamma * (erf(alpha * logits + bias) + 1) / 2
        out    = retrieved * gate

    The ``(erf + 1) / 2`` factor lies in [0, 1]; ``gamma`` (init 1) rescales it.
    With the default ``init_bias = -4`` the erf argument starts strongly
    negative while the projected logits are small, so the gate begins nearly
    closed.
    """

    def __init__(self, obs_size: int, init_bias: float = -4.0):
        super().__init__()
        self.proj = nn.Linear(2 * obs_size, obs_size)
        self.alpha = nn.Parameter(torch.ones(obs_size))
        self.bias = nn.Parameter(torch.full((obs_size,), init_bias))
        self.gamma = nn.Parameter(torch.ones(obs_size))

    def forward(self, retrieved: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        joint = torch.cat((retrieved, x), dim=-1)
        pre_activation = self.alpha * self.proj(joint) + self.bias
        open_fraction = (torch.erf(pre_activation) + 1.0) / 2.0
        return retrieved * (self.gamma * open_fraction)
|
|
|
|
class EngramModule(nn.Module):
    """Hashed n-gram memory ("Engram") producing a gated delta for the embeddings.

    The whole [B, S, H] sequence is processed at once:

    1. ``TokenCompressor`` projects a *detached* copy of ``x`` to
       ``cfg.engram_compress_dim``, so no gradient flows into ``x`` through the
       hash keys.
    2. ``MultiHeadHashLookup`` hashes n-grams (up to ``cfg.engram_max_ngram``)
       of the compressed tokens and returns averaged table embeddings of size
       ``cfg.hidden_size``.
    3. ``DERFContextGate`` gates the retrieved vectors using the non-detached
       ``x``.
    """

    def __init__(self, cfg: SpikeWhaleConfig):
        super().__init__()
        self.compressor = TokenCompressor(cfg.hidden_size, cfg.engram_compress_dim)
        self.lookup = MultiHeadHashLookup(
            num_heads=cfg.engram_num_heads,
            table_size=cfg.engram_table_size,
            compress_dim=cfg.engram_compress_dim,
            out_dim=cfg.hidden_size,
            max_ngram=cfg.engram_max_ngram,
        )
        self.gate = DERFContextGate(cfg.hidden_size, init_bias=cfg.engram_gate_init_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: [B, S, H] -> engram delta with the same shape."""
        hash_keys = self.compressor(x.detach())
        memory = self.lookup(hash_keys)
        return self.gate(memory, x)
|
|
|
|
| |
| |
| |
|
|
class HyperConnectionLayer(nn.Module):
    """Simplified Hyper-Connections router for one sublayer (attention or FFN).

    The residual stream is held as ``hc_mult`` parallel copies shaped
    [B, hc_mult, S, H].

    * ``pre_op`` mixes the copies into a single [B, S, H] sublayer input,
      using softmax-normalized learned weights (one scalar per copy).
    * ``post_op`` adds the sublayer output to every copy, scaled by a second
      set of softmax-normalized per-copy weights.

    Full Hyper-Connections use Sinkhorn-normalized 2D routing matrices; this
    version learns only 1D per-stream weights. ``sinkhorn_iters`` and ``eps``
    are accepted but not used by this implementation.
    """

    def __init__(self, hidden_size: int, hc_mult: int,
                 sinkhorn_iters: int = 20, eps: float = 1e-6):
        super().__init__()
        self.hc_mult = hc_mult
        scale = max(hc_mult, 1)
        # Opposite-slope linear ramps, scaled by 1 / hc_mult, as initial logits.
        self.pre_weight = nn.Parameter(torch.linspace(0.5, -0.5, hc_mult) / scale)
        self.post_weight = nn.Parameter(torch.linspace(-0.5, 0.5, hc_mult) / scale)

    def pre_op(self, copies: torch.Tensor) -> torch.Tensor:
        """Weighted sum over the stream axis: [B, hc_mult, S, H] -> [B, S, H]."""
        mix = self.pre_weight.softmax(dim=0).view(1, -1, 1, 1)
        return (copies * mix).sum(dim=1)

    def post_op(self, copies: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
        """Add ``delta`` [B, S, H] into each copy of ``copies`` [B, hc_mult, S, H].

        Copy ``i`` receives ``softmax(post_weight)[i] * delta``; the result has
        the same shape as ``copies``.
        """
        spread = self.post_weight.softmax(dim=0).view(1, -1, 1, 1)
        return copies + delta.unsqueeze(1) * spread
|
|
|
|
| |
| |
| |
|
|
class MLADerfXSAAttention(nn.Module):
    """
    Multi-Head Latent Attention (NanoWhale) with DERF scores + XSA correction (My Project).

    MLA (from NanoWhale):
        Q:    hidden -> q_lora_rank -> RMSNorm -> num_heads * head_dim (low-rank).
        K, V: hidden -> num_kv_heads * head_dim (direct projection; MQA when
              num_kv_heads == 1). K/V heads are repeated ``kv_groups`` times to
              match the query heads.
        Out:  num_heads * head_dim -> o_lora_rank -> hidden (low-rank).
        Partial RoPE: only the channels after the first ``nope_head_dim`` of
        each Q/K head are rotated.

    DERF (from My Project, ``use_derf``):
        Replaces softmax with ``gamma * (erf(alpha * score + bias) + 1) / 2``,
        zeroes masked entries, then renormalizes each row to sum to 1.
        alpha, bias, gamma are learnable per head.
        NOTE(review): gamma is a per-head constant, so it cancels in the row
        renormalization (up to the 1e-8 epsilon) and receives almost no
        gradient. Confirm this is intended.

    XSA (from My Project, ``use_xsa``):
        After the weighted value sum ``y``, subtract the component of ``y``
        along each query position's own unit-normalized value vector, so the
        output keeps only what is orthogonal to the token's own value.

    Masking (see ``_build_attn_mask``): ``attention_mask`` may be None
    (causal), a 2-D integer/bool padding mask ``[B, L]`` (nonzero = real
    token), or an additive float mask whose entries < -1.0 are masked.
    """

    def __init__(self, cfg: SpikeWhaleConfig):
        super().__init__()
        self.num_heads = cfg.num_attention_heads
        self.num_kv_heads = cfg.num_key_value_heads
        self.head_dim = cfg.head_dim
        self.qk_rope_head_dim = cfg.qk_rope_head_dim
        self.nope_head_dim = cfg.nope_head_dim
        self.hidden_size = cfg.hidden_size
        self.use_derf = cfg.use_derf
        self.use_xsa = cfg.use_xsa
        self.dropout_p = cfg.attention_dropout
        self.kv_groups = self.num_heads // self.num_kv_heads

        # Low-rank query path.
        self.q_a_proj = nn.Linear(cfg.hidden_size, cfg.q_lora_rank, bias=False)
        self.q_a_norm = RMSNorm(cfg.q_lora_rank, cfg.rms_norm_eps)
        self.q_b_proj = nn.Linear(cfg.q_lora_rank, self.num_heads * self.head_dim, bias=False)

        # Direct K/V projections with num_kv_heads heads.
        self.k_proj = nn.Linear(cfg.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(cfg.hidden_size, self.num_kv_heads * self.head_dim, bias=False)

        # Low-rank output path.
        self.o_a_proj = nn.Linear(self.num_heads * self.head_dim, cfg.o_lora_rank, bias=False)
        self.o_b_proj = nn.Linear(cfg.o_lora_rank, cfg.hidden_size, bias=False)

        # RoPE applied to the rope partition only.
        self.rope = RotaryEmbedding(
            self.qk_rope_head_dim,
            max_positions=cfg.max_position_embeddings,
            theta=cfg.rope_theta,
        )

        if self.use_derf:
            self.derf_alpha = nn.Parameter(torch.ones(self.num_heads))
            self.derf_bias = nn.Parameter(torch.zeros(self.num_heads))
            self.derf_gamma = nn.Parameter(torch.ones(self.num_heads))

        nn.init.normal_(self.q_a_proj.weight, std=cfg.initializer_range)
        nn.init.normal_(self.q_b_proj.weight, std=cfg.initializer_range)
        nn.init.normal_(self.k_proj.weight, std=cfg.initializer_range)
        nn.init.normal_(self.v_proj.weight, std=cfg.initializer_range)
        nn.init.normal_(self.o_a_proj.weight, std=cfg.initializer_range)
        nn.init.normal_(self.o_b_proj.weight, std=cfg.initializer_range)

    def _build_attn_mask(
        self,
        attention_mask: Optional[torch.Tensor],
        S: int,
        N: int,
        device: torch.device,
    ) -> torch.Tensor:
        """Return a bool mask that is True where a query must NOT attend.

        The result broadcasts against scores of shape [B, num_heads, S, N],
        where the S queries are the last S of the N keys.

        * ``None``: causal mask aligned to the end of the key axis. A
          multi-token chunk appended to a KV cache therefore cannot see its
          own future tokens; for S == 1 nothing is masked.
        * 2-D non-floating mask ``[B, L]`` (HF-style padding mask, nonzero =
          real token): causal mask OR padded keys. If ``L < N``, the missing
          leading (cached) positions are treated as real tokens. If ``L > N``,
          only the last N columns are used.
        * Anything else: treated as an additive float mask. Entries below -1.0
          are masked and no causal mask is added, so the caller's mask must
          already encode causality.
        """
        if attention_mask is not None and (
            attention_mask.dim() != 2 or attention_mask.is_floating_point()
        ):
            return attention_mask < -1.0

        causal = torch.triu(
            torch.ones(S, N, dtype=torch.bool, device=device), diagonal=N - S + 1
        ).view(1, 1, S, N)
        if attention_mask is None:
            return causal

        keep = attention_mask.to(device) != 0
        L = keep.shape[-1]
        if L < N:
            keep = torch.cat([keep.new_ones(keep.shape[0], N - L), keep], dim=-1)
        elif L > N:
            keep = keep[:, L - N:]
        return causal | ~keep.view(keep.shape[0], 1, 1, N)

    def forward(
        self,
        x: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Args:
            x: [B, S, hidden_size].
            position_ids: [B, S] absolute positions used for RoPE.
            attention_mask: see ``_build_attn_mask``.
            past_key_value: optional cached ``(k, v)``, each
                [B, num_kv_heads, P, head_dim], with RoPE already applied to k.
            use_cache: when True, return the concatenated ``(k, v)`` as
                ``present`` (before kv-group expansion).

        Returns:
            ``(output [B, S, hidden_size], present or None)``.
        """
        B, S, _ = x.shape

        q = self.q_a_norm(self.q_a_proj(x))
        q = self.q_b_proj(q).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)

        k = self.k_proj(x).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Partial RoPE: leave the first nope_head_dim channels unrotated.
        q_nope = q[..., :self.nope_head_dim]
        q_rope = q[..., self.nope_head_dim:]
        k_nope = k[..., :self.nope_head_dim]
        k_rope = k[..., self.nope_head_dim:]

        q_rope = self.rope(q_rope, position_ids)
        k_rope = self.rope(k_rope, position_ids)

        q = torch.cat([q_nope, q_rope], dim=-1)
        k = torch.cat([k_nope, k_rope], dim=-1)

        if past_key_value is not None:
            k = torch.cat([past_key_value[0], k], dim=2)
            v = torch.cat([past_key_value[1], v], dim=2)
        present = (k, v) if use_cache else None
        N = k.shape[2]

        # Repeat each KV head for its group of query heads.
        if self.kv_groups > 1:
            k = k.unsqueeze(2).expand(-1, -1, self.kv_groups, -1, -1).reshape(
                B, self.num_heads, N, self.head_dim)
            v = v.unsqueeze(2).expand(-1, -1, self.kv_groups, -1, -1).reshape(
                B, self.num_heads, N, self.head_dim)

        if self.use_derf:
            scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
            # Same mask rules as the SDPA path (including causal masking of a
            # multi-token chunk appended to a KV cache).
            is_masked = self._build_attn_mask(attention_mask, S, N, scores.device)

            # Finite fill keeps erf well-defined before masked entries are zeroed.
            safe_scores = scores.masked_fill(is_masked, -10000.0)

            a = self.derf_alpha.view(1, -1, 1, 1)
            b = self.derf_bias.view(1, -1, 1, 1)
            g = self.derf_gamma.view(1, -1, 1, 1)

            attn_weights = g * torch.erf(a * safe_scores + b)
            attn_weights = (attn_weights + g) / 2.0
            attn_weights = attn_weights.masked_fill(is_masked, 0.0)
            attn_weights = attn_weights / (attn_weights.sum(dim=-1, keepdim=True) + 1e-8)

            if self.dropout_p > 0 and self.training:
                attn_weights = F.dropout(attn_weights, p=self.dropout_p)

            y = torch.matmul(attn_weights, v)
        else:
            q = q.contiguous()
            k = k.contiguous()
            v = v.contiguous()
            drop = self.dropout_p if self.training else 0.0
            if past_key_value is None and attention_mask is None:
                y = F.scaled_dot_product_attention(q, k, v, is_causal=True, dropout_p=drop)
            else:
                allowed = ~self._build_attn_mask(attention_mask, S, N, q.device)
                # A query row with no allowed key (e.g. a left-padding query under
                # a padding mask) makes SDPA emit NaN. That NaN can then leak into
                # real tokens through V in later layers (0 * NaN = NaN). Let such
                # rows attend to every key instead; they belong to masked
                # positions.
                allowed = allowed | ~allowed.any(dim=-1, keepdim=True)
                y = F.scaled_dot_product_attention(
                    q, k, v, attn_mask=allowed, dropout_p=drop)

        if self.use_xsa:
            # The current tokens are the last S keys/values.
            past_len = N - S
            v_self = v[:, :, past_len:past_len + S, :]
            vn = v_self / (v_self.norm(dim=-1, keepdim=True) + 1e-8)
            projection = (y * vn).sum(dim=-1, keepdim=True) * vn
            y = y - projection

        y = y.transpose(1, 2).contiguous().view(B, S, self.num_heads * self.head_dim)
        y = self.o_b_proj(self.o_a_proj(y))
        return y, present
|
|
|
|
| |
| |
| |
|
|
class ExpertFFN(nn.Module):
    """One bias-free SwiGLU expert: ``down(silu(gate(x)) * up(x))``."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = F.silu(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
|
|
|
|
def sqrtsoftplus(x: torch.Tensor) -> torch.Tensor:
    """Elementwise ``sqrt(softplus(x) + 1e-8)``; NanoWhale expert scoring.

    softplus(x) = log(1 + exp(x)) can underflow to exactly 0 for very negative
    inputs. The 1e-8 floor keeps the result strictly positive and the sqrt
    gradient finite.
    """
    floor = 1e-8
    return F.softplus(x).add(floor).sqrt()
|
|
|
|
class SparseMoEFFN(nn.Module):
    """
    Mixture-of-experts SwiGLU FFN (NanoWhale structure + My Project aux loss).

    * ``n_shared_experts`` experts run on every token. Their outputs are
      summed, and averaged when there is more than one.
    * ``n_routed_experts`` sparse experts; each token is sent to
      ``num_experts_per_tok`` (k) of them.
    * Routing mode is fixed by ``layer_idx``:
        - ``layer_idx < cfg.num_hash_layers``: deterministic "hash" routing
          keyed on the token's position id (or its flat index when
          ``position_ids`` is None). The token goes to experts
          ``(p + j) % n_routed_experts`` for j in [0, k), each with weight 1/k.
          No aux loss.
        - otherwise: learned router. Scores are ``sqrtsoftplus(logits)`` when
          ``cfg.scoring_func == "sqrtsoftplus"``, else softmax. The top-k scores
          are optionally renormalized (``norm_topk_prob``) and then multiplied
          by ``routed_scaling_factor``.
    * Load-balancing aux loss (learned routing only):
      ``n_routed_experts * sum_e(f_e * p_e) * moe_aux_loss_coef``.
      ``f_e`` is the fraction of tokens with expert e in their top-k (so
      sum_e f_e = k). ``p_e`` is the mean *softmax* router probability;
      softmax is used here even when scoring with sqrtsoftplus. The loss is
      stored on the module and read back via ``get_aux_loss()``.
    """

    def __init__(self, cfg: SpikeWhaleConfig, layer_idx: int = 0):
        super().__init__()
        self.n_routed_experts = cfg.n_routed_experts
        self.n_shared_experts = cfg.n_shared_experts
        self.num_experts_per_tok = cfg.num_experts_per_tok
        self.norm_topk_prob = cfg.norm_topk_prob
        self.scoring_func = cfg.scoring_func
        self.routed_scaling_factor = cfg.routed_scaling_factor
        self.use_hash_routing = layer_idx < cfg.num_hash_layers
        self.aux_loss_coef = cfg.moe_aux_loss_coef

        self.router = nn.Linear(cfg.hidden_size, cfg.n_routed_experts, bias=False)
        self.experts = nn.ModuleList(
            ExpertFFN(cfg.hidden_size, cfg.moe_intermediate_size)
            for _ in range(cfg.n_routed_experts)
        )
        if cfg.n_shared_experts > 0:
            self.shared_experts = nn.ModuleList(
                ExpertFFN(cfg.hidden_size, cfg.moe_intermediate_size)
                for _ in range(cfg.n_shared_experts)
            )
        else:
            self.shared_experts = None

        # Written on every forward; None after a hash-routed forward.
        self._last_aux_loss: Optional[torch.Tensor] = None

    def _shared_forward(self, tokens: torch.Tensor) -> torch.Tensor:
        """Sum (or mean, for >1 experts) of the always-on experts; zeros if none."""
        acc = torch.zeros_like(tokens)
        if self.shared_experts:
            for expert in self.shared_experts:
                acc = acc + expert(tokens)
            if len(self.shared_experts) > 1:
                acc = acc / len(self.shared_experts)
        return acc

    def _hash_route(
        self,
        tokens: torch.Tensor,
        n_tokens: int,
        position_ids: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Deterministic position-keyed routing: consecutive experts, uniform weights."""
        dev = tokens.device
        if position_ids is None:
            anchor = (torch.arange(n_tokens, device=dev) % self.n_routed_experts).unsqueeze(1)
        else:
            anchor = (position_ids.reshape(n_tokens, 1) % self.n_routed_experts).long()
        steps = torch.arange(self.num_experts_per_tok, device=dev)
        chosen = (anchor + steps.unsqueeze(0)) % self.n_routed_experts
        mix = torch.ones(n_tokens, self.num_experts_per_tok, device=dev) / self.num_experts_per_tok
        return chosen, mix

    def _learned_route(self, tokens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Router top-k selection; also records the load-balancing aux loss."""
        logits = self.router(tokens)

        if self.scoring_func == "sqrtsoftplus":
            scores = sqrtsoftplus(logits)
        else:
            scores = F.softmax(logits, dim=-1)

        best_scores, chosen = torch.topk(scores, self.num_experts_per_tok, dim=-1)

        if self.norm_topk_prob:
            mix = best_scores / (best_scores.sum(dim=-1, keepdim=True) + 1e-8)
        else:
            mix = best_scores
        mix = mix * self.routed_scaling_factor

        probs = F.softmax(logits, dim=-1)
        dispatched = torch.zeros_like(probs)
        dispatched.scatter_(1, chosen, 1.0)
        load_frac = dispatched.mean(0)
        mean_prob = probs.mean(0)
        self._last_aux_loss = (
            self.n_routed_experts * (load_frac * mean_prob).sum() * self.aux_loss_coef
        )
        return chosen, mix

    def _combine_routed(
        self,
        tokens: torch.Tensor,
        chosen: torch.Tensor,
        mix: torch.Tensor,
    ) -> torch.Tensor:
        """Run each routed expert on its tokens and accumulate weighted outputs."""
        combined = torch.zeros_like(tokens)
        for eid, expert in enumerate(self.experts):
            hits = chosen == eid  # [T, k]
            rows = hits.any(dim=-1)
            if not rows.any():
                continue
            produced = expert(tokens[rows])
            slot = hits[rows].nonzero(as_tuple=False)  # (row-in-subset, k-slot)
            scale = mix[rows][slot[:, 0], slot[:, 1]].unsqueeze(-1)
            combined[rows] = combined[rows] + produced * scale
        return combined

    def forward(self, x: torch.Tensor, position_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
        """x: [B, S, H] (position_ids: [B, S], used only by hash routing) -> [B, S, H]."""
        B, S, H = x.shape
        tokens = x.view(B * S, H)
        n_tokens = B * S

        shared = self._shared_forward(tokens)

        if self.use_hash_routing:
            chosen, mix = self._hash_route(tokens, n_tokens, position_ids)
            self._last_aux_loss = None
        else:
            chosen, mix = self._learned_route(tokens)

        routed = self._combine_routed(tokens, chosen, mix)
        return (routed + shared).view(B, S, H)

    def get_aux_loss(self) -> Optional[torch.Tensor]:
        """Aux loss from the most recent forward (None for hash-routed layers)."""
        return self._last_aux_loss
|
|
|
|
class DenseFFN(nn.Module):
    """Dense bias-free SwiGLU FFN for non-MoE layers.

    Its inner width is ``cfg.moe_intermediate_size``. ``forward`` accepts and
    ignores ``position_ids``, and ``get_aux_loss`` returns None, so it can be
    called exactly like ``SparseMoEFFN``.
    """

    def __init__(self, cfg: SpikeWhaleConfig):
        super().__init__()
        hidden, inner = cfg.hidden_size, cfg.moe_intermediate_size
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)

    def forward(self, x: torch.Tensor, position_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
        activated = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(activated)

    def get_aux_loss(self) -> Optional[torch.Tensor]:
        """Dense layers contribute no auxiliary loss."""
        return None
|
|
|
|
| |
| |
| |
|
|
class TransformerBlock(nn.Module):
    """Pre-norm decoder block: attention sublayer, then FFN sublayer.

    * With ``cfg.use_hyper_connections``, ``x`` is the multi-stream residual
      [B, hc_mult, S, H]. Each sublayer reads it through its
      ``HyperConnectionLayer.pre_op`` and writes back through ``post_op``.
    * Without it, ``x`` is [B, S, H] and each sublayer uses a plain residual add.
    * Attention is ``MLADerfXSAAttention``. The FFN is ``SparseMoEFFN`` when
      ``cfg.use_moe`` and ``layer_idx in cfg.moe_layers``, else ``DenseFFN``.
    * ``hidden_dropout`` is applied to each sublayer output before the add.

    ``forward`` returns ``(x, present_kv_or_None, ffn_aux_loss_or_None)``.
    """

    def __init__(self, cfg: SpikeWhaleConfig, layer_idx: int):
        super().__init__()
        self.use_hc = cfg.use_hyper_connections
        self.hidden_dropout = cfg.hidden_dropout

        self.attn_norm = RMSNorm(cfg.hidden_size, cfg.rms_norm_eps)
        self.attn = MLADerfXSAAttention(cfg)
        self.ffn_norm = RMSNorm(cfg.hidden_size, cfg.rms_norm_eps)

        self.is_moe = bool(cfg.use_moe and layer_idx in cfg.moe_layers)
        self.ffn = SparseMoEFFN(cfg, layer_idx) if self.is_moe else DenseFFN(cfg)

        if self.use_hc:
            hc_args = (cfg.hidden_size, cfg.hc_mult, cfg.hc_sinkhorn_iters, cfg.hc_eps)
            self.hc_attn = HyperConnectionLayer(*hc_args)
            self.hc_ffn = HyperConnectionLayer(*hc_args)

    def forward(
        self,
        x: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple], Optional[torch.Tensor]]:
        # --- attention sublayer ---
        attn_in = self.hc_attn.pre_op(x) if self.use_hc else x
        attn_out, present = self.attn(
            self.attn_norm(attn_in), position_ids, attention_mask, past_key_value, use_cache
        )
        attn_out = F.dropout(attn_out, p=self.hidden_dropout, training=self.training)

        if self.use_hc:
            x = self.hc_attn.post_op(x, attn_out)
            ffn_in = self.hc_ffn.pre_op(x)
        else:
            ffn_in = attn_in + attn_out

        # --- FFN sublayer ---
        ffn_out = self.ffn(self.ffn_norm(ffn_in), position_ids)
        ffn_out = F.dropout(ffn_out, p=self.hidden_dropout, training=self.training)

        x = self.hc_ffn.post_op(x, ffn_out) if self.use_hc else ffn_in + ffn_out

        return x, present, self.ffn.get_aux_loss()
|
|
|
|
| |
| |
| |
|
|
class HRMRefinementBlock(nn.Module):
    """
    HRM-inspired iterative refinement (EXPERIMENTAL). This is not the full
    Hierarchical Reasoning Model, only an iterative-refinement step applied to
    a causal LM's final hidden state.

    Runs ``steps`` updates. At step t the normalized current state is
    concatenated with the original input (the "anchor"), passed through a
    down -> SiLU -> up bottleneck, and added back scaled by ``tanh(gate[t])``.

    ``gate`` starts at zero and ``up.weight`` is zero-initialized, so the block
    is an exact identity as constructed. Every operation acts on each position
    independently, so no information moves across sequence positions. Input
    and output have the same shape (last dim = hidden_size).
    """

    def __init__(self, hidden_size: int, refine_dim: int, steps: int, eps: float = 1e-6):
        super().__init__()
        self.steps = steps
        self.norm = RMSNorm(hidden_size, eps)
        self.down = nn.Linear(2 * hidden_size, refine_dim, bias=False)
        self.up = nn.Linear(refine_dim, hidden_size, bias=False)
        self.gate = nn.Parameter(torch.zeros(steps))
        nn.init.normal_(self.down.weight, std=0.02)
        nn.init.zeros_(self.up.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        state = x
        for t in range(self.steps):
            features = torch.cat([self.norm(state), x], dim=-1)
            delta = self.up(F.silu(self.down(features)))
            state = state + self.gate[t].tanh() * delta
        return state
|
|
|
|
class LatentProjection(nn.Module):
    """Pool a hidden-state sequence into one ``d_latent`` vector per example.

    Computes ``norm(proj2(relu(proj1(mean_over_dim1(x))) ** 2))``. ReLU^2
    gives sparse latent codes. Both projections use Xavier-uniform init (not
    zero) so the latent carries signal from the first step.

    NOTE(review): the mean over dim 1 averages every position, including any
    padding positions. Confirm callers do not rely on masked pooling.
    """

    def __init__(self, hidden_size: int, d_latent: int, eps: float = 1e-6):
        super().__init__()
        self.proj1 = nn.Linear(hidden_size, hidden_size, bias=False)
        self.proj2 = nn.Linear(hidden_size, d_latent, bias=False)
        self.norm = RMSNorm(d_latent, eps)
        for layer in (self.proj1, self.proj2):
            nn.init.xavier_uniform_(layer.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: [B, S, hidden_size] -> [B, d_latent]."""
        summary = x.mean(dim=1)
        sparse_code = torch.relu(self.proj1(summary)).pow(2)
        return self.norm(self.proj2(sparse_code))
|
|
|
|
class LatentInjection(nn.Module):
    """Fold an incoming ``d_latent`` vector into every position of ``x``.

    The latent is lifted to ``hidden_size`` and RMS-normalized. It is then
    broadcast over the sequence axis and added as a ReGLU-gated term:
    ``x + value_proj(inj) * relu(gate_proj(inj))``.

    ``gate_proj`` is initialized from N(0, gate_init^2) with a small
    ``gate_init`` (default 1e-3), so the injection starts close to identity.
    The gate is not exactly zero, so gradient can still reach whatever
    produced ``latent``.
    """

    def __init__(self, hidden_size: int, d_latent: int, eps: float = 1e-6,
                 gate_init: float = 1e-3):
        super().__init__()
        self.up = nn.Linear(d_latent, hidden_size, bias=False)
        self.norm = RMSNorm(hidden_size, eps)
        self.value_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.gate_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.gate_init = gate_init
        nn.init.xavier_uniform_(self.up.weight)
        nn.init.xavier_uniform_(self.value_proj.weight)
        nn.init.normal_(self.gate_proj.weight, std=gate_init)

    def forward(self, x: torch.Tensor, latent: torch.Tensor) -> torch.Tensor:
        """x: [B, S, hidden_size], latent: [B, d_latent] -> same shape as x."""
        lifted = self.norm(self.up(latent)).unsqueeze(1)  # [B, 1, H]: broadcast over S
        return x + self.value_proj(lifted) * torch.relu(self.gate_proj(lifted))
|
|
|
|
class RecursiveLink(nn.Module):
    """Residual ReGLU bridge mapping one latent vector to another of the same size.

    ``z + residual_gate * down(value_proj(LN(z)) * relu(gate_proj(LN(z))))``.
    The hidden width is ``int(d_latent * expansion)``; ``residual_gate`` starts
    at 1. All three projections are bias-free and Xavier-uniform initialized.
    """

    def __init__(self, d_latent: int = 256, expansion: float = 2.0):
        super().__init__()
        width = int(d_latent * expansion)
        self.norm = nn.LayerNorm(d_latent)
        self.value_proj = nn.Linear(d_latent, width, bias=False)
        self.gate_proj = nn.Linear(d_latent, width, bias=False)
        self.down = nn.Linear(width, d_latent, bias=False)
        self.residual_gate = nn.Parameter(torch.ones(1))
        for layer in (self.value_proj, self.gate_proj, self.down):
            nn.init.xavier_uniform_(layer.weight)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        normed = self.norm(z)
        mixed = self.value_proj(normed) * self.gate_proj(normed).relu()
        return z + self.residual_gate * self.down(mixed)
|
|
|
|
class SpikeWhaleModel(nn.Module):
    """Decoder stack without LM head.

    Pipeline: token embedding -> optional Engram delta -> optional latent
    injection -> optional expansion into ``hc_mult`` hyper-connection streams
    -> ``num_hidden_layers`` TransformerBlocks -> mean over streams (if HC)
    -> optional HRM refinement -> final RMSNorm -> optional latent projection.
    """

    def __init__(self, cfg: SpikeWhaleConfig):
        super().__init__()
        self.cfg = cfg
        self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.hidden_size)
        nn.init.normal_(self.embed_tokens.weight, std=cfg.initializer_range)

        self.engram = EngramModule(cfg) if cfg.use_engram else None
        self.layers = nn.ModuleList([
            TransformerBlock(cfg, layer_idx=i)
            for i in range(cfg.num_hidden_layers)
        ])
        self.norm = RMSNorm(cfg.hidden_size, cfg.rms_norm_eps)
        self.hrm_refine = (
            HRMRefinementBlock(cfg.hidden_size, cfg.hrm_refine_dim,
                               cfg.hrm_refine_steps, cfg.rms_norm_eps)
            if getattr(cfg, "use_hrm_refine", False) else None
        )

        # Optional latent I/O: inject a latent at the input, emit one at the output.
        if getattr(cfg, "use_latent_io", False):
            self.latent_inject = LatentInjection(cfg.hidden_size, cfg.d_latent, cfg.rms_norm_eps)
            self.latent_out = LatentProjection(cfg.hidden_size, cfg.d_latent, cfg.rms_norm_eps)
        else:
            self.latent_inject = None
            self.latent_out = None
        self.gradient_checkpointing = False

    def reset_latent_gate(self):
        """Re-initialize the latent-injection gate weights to N(0, gate_init^2).

        No-op when latent I/O is disabled. ``SpikeWhaleLM`` calls this after
        ``post_init()`` so that any weight re-initialization done there does not
        leave the gate at full scale. A small but non-zero gate keeps the
        injection near identity while still passing gradient upstream.
        """
        if self.latent_inject is not None:
            nn.init.normal_(self.latent_inject.gate_proj.weight,
                            std=self.latent_inject.gate_init)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple]] = None,
        use_cache: bool = False,
        inject_latent: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[List[Tuple]], torch.Tensor, Optional[torch.Tensor]]:
        """Run the decoder stack.

        Args:
            input_ids: [B, S] token ids.
            attention_mask: passed unchanged to every attention layer.
            position_ids: [B, S]; defaults to ``past_len .. past_len + S - 1``,
                where ``past_len`` is read from the first layer's cached keys.
            past_key_values: per-layer ``(k, v)`` caches, or None.
            use_cache: collect each layer's ``present`` into a list.
            inject_latent: optional latent folded into the embeddings (ignored
                unless latent I/O is enabled).

        Returns:
            ``(hidden [B, S, H], present_key_values or None, total_aux_loss
            (scalar tensor, sum over MoE layers), out_latent or None)``.
        """
        B, S = input_ids.shape
        device = input_ids.device

        if position_ids is None:
            # Cached keys are [B, num_kv_heads, past_len, head_dim].
            past_len = past_key_values[0][0].shape[2] if past_key_values else 0
            position_ids = torch.arange(
                past_len, past_len + S, device=device
            ).unsqueeze(0).expand(B, -1)

        x = self.embed_tokens(input_ids)

        if self.engram is not None:
            x = x + self.engram(x)

        if self.latent_inject is not None and inject_latent is not None:
            x = self.latent_inject(x, inject_latent)

        # [B, S, H] -> [B, hc_mult, S, H]; clone so the streams are independent storage.
        if self.cfg.use_hyper_connections:
            x = x.unsqueeze(1).expand(-1, self.cfg.hc_mult, -1, -1).clone()

        present_key_values = [] if use_cache else None
        total_aux_loss = torch.tensor(0.0, device=device)

        for layer_idx, layer in enumerate(self.layers):
            pkv = past_key_values[layer_idx] if past_key_values else None

            if self.gradient_checkpointing and self.training:
                # Checkpointed layers run without a KV cache (past=None,
                # use_cache=False), so ``present`` is None here. With use_cache
                # also True, the returned cache list holds None entries.
                x, present, aux_loss = gradient_checkpoint(
                    layer, x, position_ids, attention_mask, None, False,
                    use_reentrant=False,
                )
            else:
                x, present, aux_loss = layer(x, position_ids, attention_mask, pkv, use_cache)

            if use_cache:
                present_key_values.append(present)
            if aux_loss is not None:
                total_aux_loss = total_aux_loss + aux_loss

        # Collapse the hyper-connection streams back to a single [B, S, H] state.
        if self.cfg.use_hyper_connections:
            x = x.mean(dim=1)

        if self.hrm_refine is not None:
            x = self.hrm_refine(x)

        x = self.norm(x)

        out_latent = self.latent_out(x) if self.latent_out is not None else None
        return x, present_key_values, total_aux_loss, out_latent
|
|
|
|
class SpikeWhaleLM(PreTrainedModel):
    """
    Causal LM wrapper: ``SpikeWhaleModel`` + LM head + optional MTP heads.

    Training (forward with labels):
        out = model(input_ids=ids, labels=ids)
        loss = out.loss   # next-token CE + mean MTP CE + MoE aux loss

    Generation:
        out = model(input_ids=ids, use_cache=True)
        past = out.past_key_values
        out2 = model(input_ids=next_id, past_key_values=past, use_cache=True)

    MTP head ``k`` (1-based) reads the same final hidden state as the LM head
    and predicts the token ``k + 1`` positions ahead. The output object also
    carries ``.latent`` (None unless latent I/O is enabled).
    """
    config_class = SpikeWhaleConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["TransformerBlock"]

    def __init__(self, cfg: SpikeWhaleConfig):
        super().__init__(cfg)
        self.model = SpikeWhaleModel(cfg)
        self.lm_head = nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False)
        nn.init.normal_(self.lm_head.weight, std=cfg.initializer_range)

        if cfg.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

        if cfg.num_nextn_predict_layers > 0:
            self.mtp_heads = nn.ModuleList([
                nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False)
                for _ in range(cfg.num_nextn_predict_layers)
            ])
        else:
            self.mtp_heads = None

        self.post_init()
        # Must follow post_init() so the latent gate ends up small again.
        self.model.reset_latent_gate()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def _set_gradient_checkpointing(self, module, value=False):
        # Old-style HF hook signature (``value`` kwarg); toggles the flag on the stack.
        if isinstance(module, SpikeWhaleModel):
            module.gradient_checkpointing = value

    def _compute_lm_loss(
        self,
        logits: torch.Tensor,
        hidden: torch.Tensor,
        labels: torch.Tensor,
        aux_loss: torch.Tensor,
    ) -> torch.Tensor:
        """Next-token CE + averaged MTP CE (if any heads) + MoE aux loss."""
        vocab = logits.size(-1)
        loss = F.cross_entropy(
            logits[..., :-1, :].contiguous().view(-1, vocab),
            labels[..., 1:].contiguous().view(-1),
            ignore_index=-100,
        )

        if self.mtp_heads is not None:
            mtp_sum = torch.tensor(0.0, device=loss.device)
            for depth, head in enumerate(self.mtp_heads, start=1):
                shift = depth + 1  # head `depth` predicts the token depth+1 ahead
                if hidden.size(1) <= shift:
                    continue  # sequence too short for this horizon
                head_logits = head(hidden[..., :-shift, :].contiguous())
                mtp_sum = mtp_sum + F.cross_entropy(
                    head_logits.view(-1, head_logits.size(-1)),
                    labels[..., shift:].contiguous().view(-1),
                    ignore_index=-100,
                )
            loss = loss + mtp_sum / max(len(self.mtp_heads), 1)

        return loss + aux_loss

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple]] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        inject_latent: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        hidden, present_kvs, aux_loss, out_latent = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            inject_latent=inject_latent,
        )

        logits = self.lm_head(hidden)
        loss = (
            self._compute_lm_loss(logits, hidden, labels, aux_loss)
            if labels is not None
            else None
        )

        result = CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=present_kvs,
        )
        result.latent = out_latent
        return result

    def count_parameters(self) -> int:
        """Total parameter count (a tied weight is counted once)."""
        return sum(p.numel() for p in self.parameters())