""" DeepSeek-style Multi-head Latent Attention (MLA) with RoPE. Key innovations: 1. KV compression to latent space (reduce KV memory) 2. Q stays in full dimension for expressive query space 3. RoPE positional embeddings on Q and K 4. Grouped Query Attention (GQA) for efficiency 5. Learnable head combination weights 6. Numerical stability via pre-norm and scaling """ import torch import torch.nn as nn import torch.nn.functional as F import math def _residual_rms_norm(x, enabled=False, target=1.0, eps=1e-6, cap=None): if not enabled and cap is None: return x rms = x.float().square().mean(dim=-1, keepdim=True).add(eps).sqrt() if enabled: scale = target / rms else: cap_tensor = torch.tensor(float(cap), dtype=rms.dtype, device=rms.device) scale = torch.minimum(torch.ones_like(rms), cap_tensor / rms) return x * scale.to(dtype=x.dtype) class RotaryEmbedding(nn.Module): """Rotary position embeddings used in RoPE with optional YaRN extension. YaRN (Yet another RoPE eXtension) allows context length interpolation via frequency scaling. When yarn_alpha != 1.0 or seq_len > max_seq_length, frequencies are dynamically scaled to support longer sequences. Parameters: dim: Embedding dimension (must be even) rope_scale: Base RoPE scale factor (default: 40) max_seq_length: Original trained sequence length (default: 1024) yarn_alpha: YaRN interpolation factor (default: 1.0, no interpolation) - values < 1.0: aggressive interpolation (faster context expansion) - values > 1.0: conservative interpolation (safer) """ def __init__(self, dim, rope_scale=40.0, max_seq_length=1024, yarn_alpha=1.0): super().__init__() assert dim % 2 == 0, "Dimension must be even for rotary embeddings" self.dim = dim self.rope_scale = rope_scale self.max_seq_length = max_seq_length self.yarn_alpha = yarn_alpha inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer("inv_freq", inv_freq) def _apply_yarn_scaling(self, freqs, seq_len): """Apply YaRN frequency scaling for context extension. Args: freqs: [seq_len, dim] frequency tensor seq_len: Current sequence length Returns: Scaled freqs if yarn is enabled and seq_len > max_seq_length, else original freqs """ # Only apply scaling if sequence exceeds training length or yarn_alpha != 1.0 if self.yarn_alpha == 1.0 and seq_len <= self.max_seq_length: return freqs # YaRN scaling factor: interpolate frequency reduction # scale_factor = (seq_len / max_seq_length) ** (1 / yarn_alpha) # Scales down frequencies to fit longer context while maintaining position distinctions scale_factor = (seq_len / self.max_seq_length) ** (1.0 / self.yarn_alpha) freqs = freqs / scale_factor return freqs def forward(self, seq_len, device): """Generate rotary embeddings for sequence with optional YaRN scaling. Args: seq_len: Current sequence length device: Device to create embeddings on Returns: [seq_len, 2*dim] rotary embeddings (duplicated freqs) """ t = torch.arange(seq_len, device=device).type_as(self.inv_freq) / self.rope_scale freqs = torch.einsum("i,j->ij", t, self.inv_freq) # [seq_len, dim//2] # Apply YaRN frequency scaling if enabled freqs = self._apply_yarn_scaling(freqs, seq_len) return torch.cat((freqs, freqs), dim=-1) # [seq_len, dim] def rotate_half(x): """Rotate half the hidden dims of the input.""" x1, x2 = x.chunk(2, dim=-1) return torch.cat((-x2, x1), dim=-1) def apply_rotary(x, cos, sin): """Apply rotary embeddings to input tensor. Args: x: [B, n_heads, seq_len, head_dim] or similar cos: [seq_len, head_dim] or [1, 1, seq_len, head_dim] sin: [seq_len, head_dim] or [1, 1, seq_len, head_dim] """ # Ensure cos/sin have the right dimensions for broadcasting if cos.dim() == 2: cos = cos.unsqueeze(0).unsqueeze(0) sin = sin.unsqueeze(0).unsqueeze(0) # Handle case where cos/sin may be shorter than x cos = cos[..., :x.shape[-1]] sin = sin[..., :x.shape[-1]] # Split x based on cos dimensions x_rot = x[..., :cos.shape[-1]] x_base = x[..., cos.shape[-1]:] # Apply rotation x_rot = (x_rot * cos) + (rotate_half(x_rot) * sin) # Concatenate rotated and base parts return torch.cat([x_rot, x_base], dim=-1) if x_base.shape[-1] > 0 else x_rot class DeepSeekMLA(nn.Module): """ DeepSeek-style Multi-head Latent Attention (MLA). Architecture: 1. Project input to Query: [B, seq_len, d_model] -> [B, seq_len, d_model] 2. Compress to KV latent: [B, seq_len, d_model] -> [B, seq_len, d_latent_kv] 3. Split into heads for attention 4. Apply RoPE to Q and K 5. Compute attention scores: (Q @ K^T) / sqrt(d_head) 6. Apply softmax and combine with values 7. Concatenate heads and project back to d_model Parameters: d_model: Model dimension d_latent_kv: Latent dimension for KV compression n_heads: Number of attention heads d_rope: Dimension for RoPE (usually == d_head_dim) dropout: Dropout probability gqa_groups: Grouped Query Attention groups (1 = standard MLA, >1 = GQA) """ def __init__(self, d_model, d_latent_kv, n_heads, d_rope, dropout=0.1, gqa_groups=1, rope_scale=40.0, max_seq_length=1024, yarn_alpha=1.0): super().__init__() self.d_model = d_model self.d_latent_kv = d_latent_kv self.n_heads = n_heads self.d_rope = d_rope self.gqa_groups = gqa_groups assert d_model % n_heads == 0, f"d_model ({d_model}) must be divisible by n_heads ({n_heads})" assert d_latent_kv % n_heads == 0, f"d_latent_kv ({d_latent_kv}) must be divisible by n_heads ({n_heads})" self.d_head_full = d_model // n_heads # Full head dimension for Q self.d_head_latent = d_latent_kv // n_heads # Latent head dimension for K/V # Scaling factor for attention scores self.scale = 1.0 / math.sqrt(self.d_head_latent) # Layer norm before attention for stability self.norm = nn.LayerNorm(d_model) # Q projection: d_model -> d_model (full dimension) self.q_proj = nn.Linear(d_model, d_model, bias=False) # K/V projections: d_model -> d_latent_kv (compressed) self.k_proj = nn.Linear(d_model, d_latent_kv, bias=False) self.v_proj = nn.Linear(d_model, d_latent_kv, bias=False) # RoPE for position encoding with YaRN support self.rotary = RotaryEmbedding( d_rope, rope_scale=rope_scale, max_seq_length=max_seq_length, yarn_alpha=yarn_alpha ) # Output projection: d_latent_kv -> d_model self.out_proj = nn.Linear(d_latent_kv, d_model, bias=False) # Head combination weights (learnable scaling per head) self.head_weights = nn.Parameter(torch.ones(n_heads)) # Dropout self.attn_dropout = nn.Dropout(dropout) self.proj_dropout = nn.Dropout(dropout) def forward(self, x, attention_mask=None): """ Args: x: [B, seq_len, d_model] attention_mask: [B, seq_len] (1 = keep, 0 = mask) or [B, 1, seq_len, seq_len] (causal mask) Returns: out: [B, seq_len, d_model] """ B, seq_len, _ = x.shape device = x.device # Pre-norm x_norm = self.norm(x) # Project to Q, K, V spaces q = self.q_proj(x_norm) # [B, seq_len, d_model] k = self.k_proj(x_norm) # [B, seq_len, d_latent_kv] v = self.v_proj(x_norm) # [B, seq_len, d_latent_kv] # ──────────────────────────────────────────────────────────────────────── # Reshape into multi-head format # ──────────────────────────────────────────────────────────────────────── # Q: [B, seq_len, d_model] -> [B, seq_len, n_heads, d_head_full] -> [B, n_heads, seq_len, d_head_full] q = q.view(B, seq_len, self.n_heads, self.d_head_full).transpose(1, 2) # K: [B, seq_len, d_latent_kv] -> [B, seq_len, n_heads, d_head_latent] -> [B, n_heads, seq_len, d_head_latent] k = k.view(B, seq_len, self.n_heads, self.d_head_latent).transpose(1, 2) # V: [B, seq_len, d_latent_kv] -> [B, seq_len, n_heads, d_head_latent] -> [B, n_heads, seq_len, d_head_latent] v = v.view(B, seq_len, self.n_heads, self.d_head_latent).transpose(1, 2) # ──────────────────────────────────────────────────────────────────────── # Apply RoPE to Q and K # ──────────────────────────────────────────────────────────────────────── if self.d_rope > 0: # Generate RoPE embeddings: [seq_len, d_rope] rotary_emb = self.rotary(seq_len, device) # [seq_len, d_rope] cos = torch.cos(rotary_emb).unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, d_rope] sin = torch.sin(rotary_emb).unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, d_rope] # Apply RoPE to Q (only on first d_rope dimensions) q_rope = apply_rotary(q[..., :self.d_rope], cos, sin) # [B, n_heads, seq_len, d_rope] q = torch.cat([q_rope, q[..., self.d_rope:]], dim=-1) # Combine with remaining dims # Apply RoPE to K (only on first d_rope dimensions) k_rope = apply_rotary(k[..., :self.d_rope], cos, sin) # [B, n_heads, seq_len, d_rope] k = torch.cat([k_rope, k[..., self.d_rope:]], dim=-1) # Combine with remaining dims # ──────────────────────────────────────────────────────────────────────── # Compute attention using PyTorch 2.0+ fused scaled_dot_product_attention # ──────────────────────────────────────────────────────────────────────── # Only use first d_head_latent dimensions of Q for attention # K and V are already d_head_latent dimension q_for_attn = q[..., :self.d_head_latent] # [B, n_heads, seq_len, d_head_latent] # Convert attention mask to boolean format for scaled_dot_product_attention # Input mask: 0 = mask (don't attend), 1 = keep (attend) # Boolean mask: False = mask, True = attend attn_mask_bool = None if attention_mask is not None: if attention_mask.dim() == 2: # [B, seq_len] with {0, 1} -> [B, 1, 1, seq_len] with {False, True} attn_mask_bool = attention_mask.bool().unsqueeze(1).unsqueeze(1) else: # Already 4D [B, 1, seq_len, seq_len], just convert to bool attn_mask_bool = attention_mask.bool() # Get dropout probability (0.0 when not training) dropout_p = self.attn_dropout.p if self.training else 0.0 if hasattr(F, "scaled_dot_product_attention"): # Apply fused attention operation when available. out_heads = F.scaled_dot_product_attention( q_for_attn, k, v, attn_mask=attn_mask_bool, dropout_p=dropout_p, scale=None ) # [B, n_heads, seq_len, d_head_latent] else: scores = torch.matmul(q_for_attn, k.transpose(-2, -1)) * self.scale if attn_mask_bool is not None: scores = scores.masked_fill(~attn_mask_bool, torch.finfo(scores.dtype).min) attn_weights = F.softmax(scores, dim=-1) if dropout_p > 0.0: attn_weights = F.dropout(attn_weights, p=dropout_p, training=True) out_heads = torch.matmul(attn_weights, v) # ──────────────────────────────────────────────────────────────────────── # Concatenate heads # ──────────────────────────────────────────────────────────────────────── # [B, seq_len, n_heads, d_head_latent] -> [B, seq_len, d_latent_kv] out_concat = out_heads.transpose(1, 2).reshape(B, seq_len, self.d_latent_kv) # Project back to d_model out = self.out_proj(out_concat) # [B, seq_len, d_model] out = self.proj_dropout(out) return out class AttentionBlock(nn.Module): """ Attention block with pre-norm residual connection and feed-forward network. Structure: Input ├─> Norm ─┬─> MLA ──┬─> Residual Add │ └────────┘ ├────────────────────────────────────> Norm ─┬─> SwiGLU FFN ──┬─> Residual Add │ └───────┘ │ └────────────────────────────────────────────────────────────> Output """ def __init__(self, d_model, d_latent_kv, n_heads, d_rope, d_ff, dropout=0.1, gqa_groups=1, rope_scale=40.0, max_seq_length=1024, yarn_alpha=1.0, residual_rms_norm=False, residual_rms_target=1.0, residual_rms_cap=None, residual_rms_eps=1e-6): super().__init__() self.residual_rms_norm = residual_rms_norm self.residual_rms_target = residual_rms_target self.residual_rms_cap = residual_rms_cap self.residual_rms_eps = residual_rms_eps self.mla = DeepSeekMLA(d_model, d_latent_kv, n_heads, d_rope, dropout, gqa_groups, rope_scale=rope_scale, max_seq_length=max_seq_length, yarn_alpha=yarn_alpha) # SwiGLU feed-forward network self.ff_norm = nn.LayerNorm(d_model) self.ff_gate = nn.Linear(d_model, d_ff, bias=False) self.ff_value = nn.Linear(d_model, d_ff, bias=False) self.ff_out = nn.Linear(d_ff, d_model, bias=False) self.dropout = nn.Dropout(dropout) def forward(self, x, attention_mask=None): """ Args: x: [B, seq_len, d_model] attention_mask: [B, seq_len] or [B, 1, seq_len, seq_len] Returns: out: [B, seq_len, d_model] """ # Attention with residual attn_out = self.mla(x, attention_mask) x = x + self.dropout(attn_out) x = _residual_rms_norm( x, self.residual_rms_norm, self.residual_rms_target, self.residual_rms_eps, self.residual_rms_cap, ) # FFN with residual ff_norm = self.ff_norm(x) ff_gate = self.ff_gate(ff_norm) ff_value = self.ff_value(ff_norm) ff_out = ff_value * F.silu(ff_gate) # SwiGLU activation ff_out = self.ff_out(ff_out) x = x + self.dropout(ff_out) x = _residual_rms_norm( x, self.residual_rms_norm, self.residual_rms_target, self.residual_rms_eps, self.residual_rms_cap, ) return x