""" GLADIUS v2.0 — Hybrid Attention (SLA2-inspired) The core attention mechanism: α-blended softmax + linear attention. Linear path: O(n) — cheap background awareness of all tokens. Softmax path: O(n·k) — precise attention for important token pairs. α: Per-token, learned. High when precision matters, low for routine. This is the SLA2 principle applied to GLADIUS: O = α ⊙ softmax_attention(Q, K_important, V) + (1-α) ⊙ linear_attention(Q, K_all, V) Reference: Ali's SLA2 attention pipeline diagram (ali-ref/img-09.jpg) """ import torch import torch.nn as nn import torch.nn.functional as F import math from .config import KernelConfig class RoPE(nn.Module): """Rotary Position Embeddings (Su et al., 2021). Applied to Q and K before attention computation. Does not interfere with our additive temporal encoding (they operate in different subspaces — RoPE is rotational, time encoding is additive). """ def __init__(self, head_dim: int, max_seq_len: int = 2048): super().__init__() # Precompute frequency bands inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim)) self.register_buffer('inv_freq', inv_freq) # Precompute rotation matrices for max_seq_len t = torch.arange(max_seq_len).float() freqs = torch.einsum('i,j->ij', t, inv_freq) emb = torch.cat([freqs, freqs], dim=-1) self.register_buffer('cos_cached', emb.cos()) self.register_buffer('sin_cached', emb.sin()) def forward(self, x: torch.Tensor, seq_len: int) -> torch.Tensor: """Apply rotary embeddings. x: (batch, heads, seq, head_dim)""" cos = self.cos_cached[:seq_len].unsqueeze(0).unsqueeze(0) sin = self.sin_cached[:seq_len].unsqueeze(0).unsqueeze(0) return (x * cos) + (self._rotate_half(x) * sin) @staticmethod def _rotate_half(x: torch.Tensor) -> torch.Tensor: x1, x2 = x.chunk(2, dim=-1) return torch.cat([-x2, x1], dim=-1) class HybridAttention(nn.Module): """ SLA2-inspired hybrid attention. Every token gets linear attention (cheap, global context). Important tokens ALSO get softmax attention (expensive, precise). The blend ratio α is learned per-token. argmax_attention: α = argmax_{blend} S(blend | hidden_state) """ def __init__(self, config: KernelConfig, layer_idx: int = 0): super().__init__() self.config = config self.layer_idx = layer_idx self.num_heads = config.num_heads self.head_dim = config.head_dim self.hidden_dim = config.hidden_dim # Projections self.q_proj = nn.Linear(config.hidden_dim, config.hidden_dim, bias=False) self.k_proj = nn.Linear(config.hidden_dim, config.hidden_dim, bias=False) self.v_proj = nn.Linear(config.hidden_dim, config.hidden_dim, bias=False) self.o_proj = nn.Linear(config.hidden_dim, config.hidden_dim, bias=False) # Learned blend ratio: per-head α router # Input: hidden state → Output: per-head scalar in [0, 1] self.alpha_router = nn.Sequential( nn.Linear(config.hidden_dim, config.num_heads), nn.Sigmoid() ) # RoPE self.rope = RoPE(config.head_dim, config.max_seq_len) # QK-Clip: softcap for attention logit stability (Gemma 2 / Kimi K2) # Smooth capping: logits = cap * tanh(logits / cap) # Prevents attention logit explosion at scale. None = disabled (backward-compatible). self.qk_softcap = getattr(config, 'qk_softcap', None) # Linear attention feature map: elu(x) + 1 (Katharopoulos et al., 2020) # Makes dot products non-negative for valid linear attention self._init_weights() def _init_weights(self): for proj in [self.q_proj, self.k_proj, self.v_proj, self.o_proj]: nn.init.normal_(proj.weight, std=0.02) # Initialize alpha toward 0.5 (balanced blend) nn.init.zeros_(self.alpha_router[0].bias) def forward( self, x: torch.Tensor, mask: torch.Tensor | None = None, memory_keys: torch.Tensor | None = None, memory_values: torch.Tensor | None = None, ) -> torch.Tensor: """ Args: x: (batch, seq_len, hidden_dim) mask: (batch, 1, seq_len, seq_len) causal mask memory_keys: Optional hot memory keys to attend over memory_values: Optional hot memory values Returns: (batch, seq_len, hidden_dim) """ B, S, D = x.shape # Project to multi-head Q = self.q_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2) K = self.k_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2) V = self.v_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2) # Apply RoPE to Q and K Q = self.rope(Q, S) K = self.rope(K, S) # === Linear Attention Path (O(n)) === # Feature map: elu(x) + 1 → non-negative Q_lin = F.elu(Q) + 1 # (B, H, S, D) K_lin = F.elu(K) + 1 # Causal linear attention via cumulative sum # KV = K_lin^T @ V accumulated causally # For simplicity in skeleton, use full (non-causal) linear attention # TODO: Replace with causal linear attention for autoregressive generation KV_lin = torch.matmul(K_lin.transpose(-2, -1), V) # (B, H, D, D) Z_lin = K_lin.transpose(-2, -1).sum(dim=-1, keepdim=True) # normalizer O_linear = torch.matmul(Q_lin, KV_lin) / (torch.matmul(Q_lin, Z_lin) + 1e-6) # === Softmax Attention Path (O(n²) but precise) === scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim) # QK-Clip: prevent attention logit explosion (Gemma 2 / Kimi K2 style) if self.qk_softcap is not None and self.qk_softcap > 0: scores = self.qk_softcap * torch.tanh(scores / self.qk_softcap) if mask is not None: scores = scores.masked_fill(mask == 0, float('-inf')) attn_weights = F.softmax(scores, dim=-1) O_softmax = torch.matmul(attn_weights, V) # === Blend === # α per-token, per-head: (B, S, num_heads) → (B, num_heads, S, 1) alpha = self.alpha_router(x) # (B, S, H) alpha = alpha.permute(0, 2, 1).unsqueeze(-1) # (B, H, S, 1) O = alpha * O_softmax + (1 - alpha) * O_linear # Reshape back O = O.transpose(1, 2).contiguous().view(B, S, D) return self.o_proj(O) class SwiGLU(nn.Module): """SwiGLU FFN block (Shazeer, 2020). Used in LLaMA, Mistral, etc.""" def __init__(self, config: KernelConfig): super().__init__() self.gate_proj = nn.Linear(config.hidden_dim, config.ffn_dim, bias=False) self.up_proj = nn.Linear(config.hidden_dim, config.ffn_dim, bias=False) self.down_proj = nn.Linear(config.ffn_dim, config.hidden_dim, bias=False) self._init_weights() def _init_weights(self): for proj in [self.gate_proj, self.up_proj, self.down_proj]: nn.init.normal_(proj.weight, std=0.02) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) class RMSNorm(nn.Module): """Root Mean Square Layer Normalization (Zhang & Sennrich, 2019).""" def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.weight = nn.Parameter(torch.ones(dim)) self.eps = eps def forward(self, x: torch.Tensor) -> torch.Tensor: norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) return x * norm * self.weight class TransformerLayer(nn.Module): """Single transformer layer: RMSNorm → HybridAttn → RMSNorm → SwiGLU.""" def __init__(self, config: KernelConfig, layer_idx: int = 0): super().__init__() self.attention = HybridAttention(config, layer_idx) self.ffn = SwiGLU(config) self.attn_norm = RMSNorm(config.hidden_dim) self.ffn_norm = RMSNorm(config.hidden_dim) def forward( self, x: torch.Tensor, mask: torch.Tensor | None = None, ) -> torch.Tensor: # Pre-norm residual connections x = x + self.attention(self.attn_norm(x), mask=mask) x = x + self.ffn(self.ffn_norm(x)) return x