""" Transformer components for CodonTranslator. Includes RMSNorm, self-attention (SDPA/Flash) with optional mask, cross-attention for conditioning memory, SwiGLU FFN, and a basic block. """ import math from typing import Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F from torch.nn.attention import SDPBackend, sdpa_kernel # Require recent PyTorch class RMSNorm(nn.Module): """Root Mean Square Layer Normalization.""" def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Apply RMS normalization. Args: x: Input tensor of any shape ending in dim Returns: Normalized tensor of same shape """ norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) return x * norm * self.weight def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: """Apply rotary embeddings to x: [B,H,T,D]; cos/sin: [1,1,T,D].""" x1 = x[..., ::2] x2 = x[..., 1::2] x_rot = torch.zeros_like(x) x_rot[..., ::2] = -x2 x_rot[..., 1::2] = x1 return x * cos + x_rot * sin class MultiHeadAttention(nn.Module): """Self-attention using PyTorch SDPA kernels (Flash/MemEff/Math) + RoPE. - attn_mask: bool [B, T, T] with True = keep, False = block - is_causal: whether to apply causal masking internally """ def __init__( self, dim: int, num_heads: int, dropout: float = 0.0, use_rope: bool = True, ): super().__init__() assert dim % num_heads == 0, f"dim {dim} must be divisible by num_heads {num_heads}" self.dim = dim self.num_heads = num_heads self.head_dim = dim // num_heads self.dropout = dropout self.use_rope = use_rope self.qkv = nn.Linear(dim, 3 * dim, bias=False) self.out_proj = nn.Linear(dim, dim, bias=False) self.resid_dropout = nn.Dropout(dropout) # RoPE cache self._rope_cache: dict[tuple[int, torch.device, torch.dtype], tuple[torch.Tensor, torch.Tensor]] = {} def _rope_cos_sin(self, T: int, device: torch.device, dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]: key = (T, device, dtype) cached = self._rope_cache.get(key) if cached is not None: return cached dim_half = self.head_dim // 2 inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_half, device=device, dtype=torch.float32) / dim_half)) t = torch.arange(T, device=device, dtype=torch.float32) freqs = torch.outer(t, inv_freq) cos = torch.cos(freqs).repeat_interleave(2, dim=-1) sin = torch.sin(freqs).repeat_interleave(2, dim=-1) cos = cos.to(dtype).unsqueeze(0).unsqueeze(0) # [1,1,T,D] sin = sin.to(dtype).unsqueeze(0).unsqueeze(0) self._rope_cache[key] = (cos, sin) return cos, sin def forward( self, x: torch.Tensor, past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, return_kv: bool = False, position_offset: int = 0, ) -> "torch.Tensor | Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]": """ Self-attention with optional KV cache support. Args: x: [B, T_new, H] past_kv: Optional tuple (k, v), each [B, nH, T_past, Hd] return_kv: If True, also return updated (k, v) position_offset: Starting position index for RoPE (past length) Returns: out or (out, present_kv) """ B, T_new, _ = x.shape # QKV projections and reshape (ensure contiguous for SDPA kernels) qkv = self.qkv(x).view(B, T_new, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) q, k_new, v_new = qkv[0].contiguous(), qkv[1].contiguous(), qkv[2].contiguous() # RoPE for new tokens only if self.use_rope: # Compute cos/sin up to (offset + T_new), then slice the tail for new positions cos, sin = self._rope_cos_sin(position_offset + T_new, x.device, q.dtype) if position_offset > 0: cos = cos[:, :, position_offset: position_offset + T_new, :] sin = sin[:, :, position_offset: position_offset + T_new, :] # Apply to q and k_new q = _apply_rope(q, cos, sin) k_new = _apply_rope(k_new, cos, sin) # Concatenate with cache if provided if past_kv is not None: k_past, v_past = past_kv k = torch.cat([k_past, k_new], dim=2) v = torch.cat([v_past, v_new], dim=2) is_causal = False # No future tokens present; avoid unnecessary masking else: k, v = k_new, v_new is_causal = True # Prefer FlashAttention; fall back to MemEff then Math. Autocast to half/bfloat16 on CUDA. backends = [SDPBackend.FLASH_ATTENTION]#, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH] with sdpa_kernel(backends): if x.device.type == "cuda" and q.dtype not in (torch.float16, torch.bfloat16): amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 with torch.amp.autocast(device_type="cuda", dtype=amp_dtype): out = F.scaled_dot_product_attention( q, k, v, dropout_p=self.dropout if self.training else 0.0, is_causal=is_causal, ) else: out = F.scaled_dot_product_attention( q, k, v, dropout_p=self.dropout if self.training else 0.0, is_causal=is_causal, ) out = out.transpose(1, 2).contiguous().view(B, T_new, self.dim) # Align dtype with residual/Linear weights to avoid bf16/float mismatches if out.dtype != x.dtype: out = out.to(x.dtype) out = self.out_proj(out) out = self.resid_dropout(out) if return_kv: return out, (k, v) return out class GroupedQueryAttention(nn.Module): """Grouped-Query Attention (GQA) using Flash Attention via PyTorch SDPA. - num_heads total query heads - num_kv_groups shared K/V groups (num_heads must be divisible by num_kv_groups) - Optional q/k RMSNorm - Supports RoPE with a scalar or per-sample position_offset (like MHA) - Optional KV cache compatible with the existing interface (stores expanded per-head K/V) """ def __init__( self, dim: int, num_heads: int, num_kv_groups: int, dropout: float = 0.0, qk_norm: bool = False, ) -> None: super().__init__() assert num_heads % max(1, num_kv_groups) == 0, "num_heads must be divisible by num_kv_groups" self.dim = dim self.num_heads = int(num_heads) self.num_kv_groups = max(1, int(num_kv_groups)) self.group_size = self.num_heads // self.num_kv_groups assert dim % num_heads == 0, "dim must be divisible by num_heads" self.head_dim = dim // num_heads self.dropout = dropout self.Wq = nn.Linear(dim, self.num_heads * self.head_dim, bias=False) self.Wk = nn.Linear(dim, self.num_kv_groups * self.head_dim, bias=False) self.Wv = nn.Linear(dim, self.num_kv_groups * self.head_dim, bias=False) self.out_proj = nn.Linear(self.num_heads * self.head_dim, dim, bias=False) self.q_norm = RMSNorm(self.head_dim) if qk_norm else None self.k_norm = RMSNorm(self.head_dim) if qk_norm else None # RoPE cache self._rope_cache: dict[tuple[int, torch.device, torch.dtype], tuple[torch.Tensor, torch.Tensor]] = {} def _rope_cos_sin(self, T: int, device: torch.device, dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]: key = (T, device, dtype) cached = self._rope_cache.get(key) if cached is not None: return cached dim_half = self.head_dim // 2 inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_half, device=device, dtype=torch.float32) / dim_half)) t = torch.arange(T, device=device, dtype=torch.float32) freqs = torch.outer(t, inv_freq) cos = torch.cos(freqs).repeat_interleave(2, dim=-1) sin = torch.sin(freqs).repeat_interleave(2, dim=-1) cos = cos.to(dtype).unsqueeze(0).unsqueeze(0) # [1,1,T,D] sin = sin.to(dtype).unsqueeze(0).unsqueeze(0) self._rope_cache[key] = (cos, sin) return cos, sin def forward( self, x: torch.Tensor, past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, return_kv: bool = False, position_offset: int | torch.Tensor = 0, ) -> "torch.Tensor | Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]": B, T_new, _ = x.shape # Project to Q, K, V q = self.Wq(x).view(B, T_new, self.num_heads, self.head_dim).transpose(1, 2).contiguous() # [B,H,T,Hd] k = self.Wk(x).view(B, T_new, self.num_kv_groups, self.head_dim).transpose(1, 2).contiguous() # [B,G,T,Hd] v = self.Wv(x).view(B, T_new, self.num_kv_groups, self.head_dim).transpose(1, 2).contiguous() # [B,G,T,Hd] # Optional RMSNorm on q/k if self.q_norm is not None: q = self.q_norm(q) if self.k_norm is not None: k = self.k_norm(k) # RoPE for new tokens only if isinstance(position_offset, int): cos, sin = self._rope_cos_sin(position_offset + T_new, x.device, q.dtype) if position_offset > 0: cos = cos[:, :, position_offset: position_offset + T_new, :] sin = sin[:, :, position_offset: position_offset + T_new, :] q = _apply_rope(q, cos, sin) k = _apply_rope(k, cos, sin) else: off = position_offset.to(device=x.device, dtype=torch.long) max_off = int(off.max().item()) cos_all, sin_all = self._rope_cos_sin(max_off + T_new, x.device, q.dtype) ar = torch.arange(T_new, device=x.device, dtype=torch.long) idx = (off.unsqueeze(1) + ar.unsqueeze(0)) # [B, T_new] cos_b = cos_all.squeeze(0).squeeze(0)[idx].unsqueeze(1) # [B,1,T,D] sin_b = sin_all.squeeze(0).squeeze(0)[idx].unsqueeze(1) q = _apply_rope(q, cos_b, sin_b) # k has groups dimension [B,G,T,D]; share same offsets per batch k = _apply_rope(k, cos_b, sin_b) # Expand grouped K/V to per-head by repeating groups if self.group_size > 1: k_exp = k.repeat_interleave(self.group_size, dim=1) # [B,H,T,Hd] v_exp = v.repeat_interleave(self.group_size, dim=1) # [B,H,T,Hd] else: k_exp, v_exp = k, v # already per-head # KV cache: concatenate past along sequence dim if past_kv is not None: k_past, v_past = past_kv k_cat = torch.cat([k_past, k_exp], dim=2) v_cat = torch.cat([v_past, v_exp], dim=2) is_causal = False else: k_cat, v_cat = k_exp, v_exp is_causal = True # Prefer FlashAttention; fall back to MemEff/Math. Ensure CUDA autocast to half/bfloat16 so kernels are available with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]): if x.device.type == "cuda" and q.dtype not in (torch.float16, torch.bfloat16): amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 with torch.amp.autocast(device_type="cuda", dtype=amp_dtype): out = torch.nn.functional.scaled_dot_product_attention( q, k_cat, v_cat, dropout_p=self.dropout if self.training else 0.0, is_causal=is_causal, ) # [B,H,T,Hd] else: out = torch.nn.functional.scaled_dot_product_attention( q, k_cat, v_cat, dropout_p=self.dropout if self.training else 0.0, is_causal=is_causal, ) # [B,H,T,Hd] out = out.transpose(1, 2).contiguous().view(B, T_new, self.num_heads * self.head_dim) # Ensure dtype compatibility for Linear / residual path if out.dtype != x.dtype: out = out.to(x.dtype) out = self.out_proj(out) if return_kv: return out, (k_cat, v_cat) return out class FeedForward(nn.Module): """Feed-forward network with optional GLU activation.""" def __init__( self, dim: int, hidden_dim: int, dropout: float = 0.0, ): super().__init__() self.w1 = nn.Linear(dim, hidden_dim, bias=False) self.w2 = nn.Linear(hidden_dim, dim, bias=False) self.w3 = nn.Linear(dim, hidden_dim, bias=False) self.dropout = nn.Dropout(dropout) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Apply feed-forward network. Args: x: Input tensor [B, T, dim] Returns: Output tensor [B, T, dim] """ return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x))) class TransformerBlock(nn.Module): """Pre-norm Transformer block using self-attn + SwiGLU FFN (no cross-attention).""" def __init__( self, dim: int, num_heads: int, mlp_ratio: float = 4.0, dropout: float = 0.0, num_kv_groups: int | None = None, qk_norm: bool = False, attn_type: str = "gqa", # "gqa" or "mha" ): super().__init__() self.norm1 = RMSNorm(dim) if attn_type == "mha": self.attn = MultiHeadAttention(dim=dim, num_heads=num_heads, dropout=dropout) self._attn_is_gqa = False else: # Use Grouped-Query Attention (defaults to no grouping when num_kv_groups is None) kv_groups = num_heads if (num_kv_groups is None) else max(1, int(num_kv_groups)) self.attn = GroupedQueryAttention(dim=dim, num_heads=num_heads, num_kv_groups=kv_groups, dropout=dropout, qk_norm=qk_norm) self._attn_is_gqa = True self.norm2 = RMSNorm(dim) self.ffn = FeedForward(dim=dim, hidden_dim=int(dim * mlp_ratio), dropout=dropout) def forward( self, x: torch.Tensor, past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = False, position_offset: int = 0, ) -> "torch.Tensor | Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]": """Forward pass with optional KV caching.""" if use_cache or (past_kv is not None): attn_out = self.attn(self.norm1(x), past_kv=past_kv, return_kv=True, position_offset=position_offset) x = x + attn_out[0] x = x + self.ffn(self.norm2(x)) return x, attn_out[1] else: x = x + self.attn(self.norm1(x)) x = x + self.ffn(self.norm2(x)) return x