| """ |
| Transformer components for CodonTranslator. |
| Includes RMSNorm, self-attention (SDPA/Flash) with optional mask, |
| cross-attention for conditioning memory, SwiGLU FFN, and a basic block. |
| """ |
|
|
| import math |
| from typing import Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.nn.attention import SDPBackend, sdpa_kernel |
|
|
|
|
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (no mean-centering, no bias)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Scale x by the reciprocal of its RMS over the last dim, then by the learned gain.

        Args:
            x: Input tensor whose last dimension has size ``dim``

        Returns:
            Tensor of the same shape as ``x``
        """
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        return self.weight * (x * inv_rms)
|
|
|
|
| def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: |
| """Apply rotary embeddings to x: [B,H,T,D]; cos/sin: [1,1,T,D].""" |
| x1 = x[..., ::2] |
| x2 = x[..., 1::2] |
| x_rot = torch.zeros_like(x) |
| x_rot[..., ::2] = -x2 |
| x_rot[..., 1::2] = x1 |
| return x * cos + x_rot * sin |
|
|
|
|
| class MultiHeadAttention(nn.Module): |
| """Self-attention using PyTorch SDPA kernels (Flash/MemEff/Math) + RoPE. |
| - attn_mask: bool [B, T, T] with True = keep, False = block |
| - is_causal: whether to apply causal masking internally |
| """ |
|
|
| def __init__( |
| self, |
| dim: int, |
| num_heads: int, |
| dropout: float = 0.0, |
| use_rope: bool = True, |
| ): |
| super().__init__() |
| assert dim % num_heads == 0, f"dim {dim} must be divisible by num_heads {num_heads}" |
| self.dim = dim |
| self.num_heads = num_heads |
| self.head_dim = dim // num_heads |
| self.dropout = dropout |
| self.use_rope = use_rope |
|
|
| self.qkv = nn.Linear(dim, 3 * dim, bias=False) |
| self.out_proj = nn.Linear(dim, dim, bias=False) |
| self.resid_dropout = nn.Dropout(dropout) |
|
|
| |
| self._rope_cache: dict[tuple[int, torch.device, torch.dtype], tuple[torch.Tensor, torch.Tensor]] = {} |
|
|
| def _rope_cos_sin(self, T: int, device: torch.device, dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]: |
| key = (T, device, dtype) |
| cached = self._rope_cache.get(key) |
| if cached is not None: |
| return cached |
| dim_half = self.head_dim // 2 |
| inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_half, device=device, dtype=torch.float32) / dim_half)) |
| t = torch.arange(T, device=device, dtype=torch.float32) |
| freqs = torch.outer(t, inv_freq) |
| cos = torch.cos(freqs).repeat_interleave(2, dim=-1) |
| sin = torch.sin(freqs).repeat_interleave(2, dim=-1) |
| cos = cos.to(dtype).unsqueeze(0).unsqueeze(0) |
| sin = sin.to(dtype).unsqueeze(0).unsqueeze(0) |
| self._rope_cache[key] = (cos, sin) |
| return cos, sin |
|
|
| def forward( |
| self, |
| x: torch.Tensor, |
| past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
| return_kv: bool = False, |
| position_offset: int = 0, |
| ) -> "torch.Tensor | Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]": |
| """ |
| Self-attention with optional KV cache support. |
| |
| Args: |
| x: [B, T_new, H] |
| past_kv: Optional tuple (k, v), each [B, nH, T_past, Hd] |
| return_kv: If True, also return updated (k, v) |
| position_offset: Starting position index for RoPE (past length) |
| |
| Returns: |
| out or (out, present_kv) |
| """ |
| B, T_new, _ = x.shape |
|
|
| |
| qkv = self.qkv(x).view(B, T_new, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) |
| q, k_new, v_new = qkv[0].contiguous(), qkv[1].contiguous(), qkv[2].contiguous() |
|
|
| |
| if self.use_rope: |
| |
| cos, sin = self._rope_cos_sin(position_offset + T_new, x.device, q.dtype) |
| if position_offset > 0: |
| cos = cos[:, :, position_offset: position_offset + T_new, :] |
| sin = sin[:, :, position_offset: position_offset + T_new, :] |
| |
| q = _apply_rope(q, cos, sin) |
| k_new = _apply_rope(k_new, cos, sin) |
|
|
| |
| if past_kv is not None: |
| k_past, v_past = past_kv |
| k = torch.cat([k_past, k_new], dim=2) |
| v = torch.cat([v_past, v_new], dim=2) |
| is_causal = False |
| else: |
| k, v = k_new, v_new |
| is_causal = True |
|
|
| |
| backends = [SDPBackend.FLASH_ATTENTION] |
| with sdpa_kernel(backends): |
| if x.device.type == "cuda" and q.dtype not in (torch.float16, torch.bfloat16): |
| amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 |
| with torch.amp.autocast(device_type="cuda", dtype=amp_dtype): |
| out = F.scaled_dot_product_attention( |
| q, k, v, |
| dropout_p=self.dropout if self.training else 0.0, |
| is_causal=is_causal, |
| ) |
| else: |
| out = F.scaled_dot_product_attention( |
| q, k, v, |
| dropout_p=self.dropout if self.training else 0.0, |
| is_causal=is_causal, |
| ) |
|
|
| out = out.transpose(1, 2).contiguous().view(B, T_new, self.dim) |
| |
| if out.dtype != x.dtype: |
| out = out.to(x.dtype) |
| out = self.out_proj(out) |
| out = self.resid_dropout(out) |
|
|
| if return_kv: |
| return out, (k, v) |
| return out |
|
|
|
|
|
|
class GroupedQueryAttention(nn.Module):
    """Grouped-Query Attention (GQA) using Flash Attention via PyTorch SDPA.

    - num_heads total query heads
    - num_kv_groups shared K/V groups (num_heads must be divisible by num_kv_groups)
    - Optional q/k RMSNorm
    - Supports RoPE with a scalar or per-sample position_offset (like MHA)
    - Optional KV cache compatible with the existing interface (stores expanded per-head K/V)
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_groups: int,
        dropout: float = 0.0,
        qk_norm: bool = False,
    ) -> None:
        super().__init__()
        assert num_heads % max(1, num_kv_groups) == 0, "num_heads must be divisible by num_kv_groups"
        self.dim = dim
        self.num_heads = int(num_heads)
        # Clamp to >= 1 so num_kv_groups=0 degrades to a single shared group
        # instead of dividing by zero below.
        self.num_kv_groups = max(1, int(num_kv_groups))
        # Number of query heads that share each K/V group.
        self.group_size = self.num_heads // self.num_kv_groups

        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.head_dim = dim // num_heads
        self.dropout = dropout

        # Q projects to all heads; K/V project only to the (smaller) group count.
        self.Wq = nn.Linear(dim, self.num_heads * self.head_dim, bias=False)
        self.Wk = nn.Linear(dim, self.num_kv_groups * self.head_dim, bias=False)
        self.Wv = nn.Linear(dim, self.num_kv_groups * self.head_dim, bias=False)
        self.out_proj = nn.Linear(self.num_heads * self.head_dim, dim, bias=False)

        # Optional RMSNorm over head_dim applied to q/k ("QK-norm").
        self.q_norm = RMSNorm(self.head_dim) if qk_norm else None
        self.k_norm = RMSNorm(self.head_dim) if qk_norm else None

        # cos/sin tables keyed by (seq_len, device, dtype); grows with the
        # longest sequence length seen per device/dtype combination.
        self._rope_cache: dict[tuple[int, torch.device, torch.dtype], tuple[torch.Tensor, torch.Tensor]] = {}

    def _rope_cos_sin(self, T: int, device: torch.device, dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]:
        """Return cached RoPE cos/sin tables of shape [1, 1, T, head_dim]."""
        key = (T, device, dtype)
        cached = self._rope_cache.get(key)
        if cached is not None:
            return cached
        dim_half = self.head_dim // 2
        # Standard RoPE frequencies (base 10000), computed in fp32 for accuracy.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_half, device=device, dtype=torch.float32) / dim_half))
        t = torch.arange(T, device=device, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)
        # repeat_interleave matches the interleaved pairing used by _apply_rope.
        cos = torch.cos(freqs).repeat_interleave(2, dim=-1)
        sin = torch.sin(freqs).repeat_interleave(2, dim=-1)
        cos = cos.to(dtype).unsqueeze(0).unsqueeze(0)
        sin = sin.to(dtype).unsqueeze(0).unsqueeze(0)
        self._rope_cache[key] = (cos, sin)
        return cos, sin

    def forward(
        self,
        x: torch.Tensor,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        return_kv: bool = False,
        position_offset: int | torch.Tensor = 0,
    ) -> "torch.Tensor | Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]":
        """GQA self-attention with optional KV cache.

        Args:
            x: [B, T_new, dim]
            past_kv: Optional (k, v), each [B, num_heads, T_past, head_dim] —
                already expanded to per-head layout, as returned by this method.
            return_kv: If True, also return the updated (k, v) pair.
            position_offset: RoPE start position; an int applied to the whole
                batch, or a per-sample tensor of shape [B].

        Returns:
            out [B, T_new, dim], or (out, (k, v)) when return_kv is True.
        """
        B, T_new, _ = x.shape

        # Project and reshape to [B, heads_or_groups, T, head_dim].
        q = self.Wq(x).view(B, T_new, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
        k = self.Wk(x).view(B, T_new, self.num_kv_groups, self.head_dim).transpose(1, 2).contiguous()
        v = self.Wv(x).view(B, T_new, self.num_kv_groups, self.head_dim).transpose(1, 2).contiguous()

        # QK-norm (if enabled) is applied before RoPE rotation.
        if self.q_norm is not None:
            q = self.q_norm(q)
        if self.k_norm is not None:
            k = self.k_norm(k)

        if isinstance(position_offset, int):
            # Scalar offset: build tables up to the absolute end position and
            # slice out the rows belonging to the new tokens.
            cos, sin = self._rope_cos_sin(position_offset + T_new, x.device, q.dtype)
            if position_offset > 0:
                cos = cos[:, :, position_offset: position_offset + T_new, :]
                sin = sin[:, :, position_offset: position_offset + T_new, :]
            q = _apply_rope(q, cos, sin)
            k = _apply_rope(k, cos, sin)
        else:
            # Per-sample offsets: gather a [B, T_new] index grid of absolute
            # positions and look up per-sample cos/sin rows ([B, 1, T_new, Hd]).
            off = position_offset.to(device=x.device, dtype=torch.long)
            max_off = int(off.max().item())
            cos_all, sin_all = self._rope_cos_sin(max_off + T_new, x.device, q.dtype)
            ar = torch.arange(T_new, device=x.device, dtype=torch.long)
            idx = (off.unsqueeze(1) + ar.unsqueeze(0))
            cos_b = cos_all.squeeze(0).squeeze(0)[idx].unsqueeze(1)
            sin_b = sin_all.squeeze(0).squeeze(0)[idx].unsqueeze(1)
            q = _apply_rope(q, cos_b, sin_b)
            # Broadcasts over the head dimension (dim 1).
            k = _apply_rope(k, cos_b, sin_b)

        # Expand K/V groups to one per query head so SDPA sees matching head
        # counts. NOTE(review): the cache stores this expanded layout, costing
        # group_size x the memory of a group-level cache.
        if self.group_size > 1:
            k_exp = k.repeat_interleave(self.group_size, dim=1)
            v_exp = v.repeat_interleave(self.group_size, dim=1)
        else:
            k_exp, v_exp = k, v

        if past_kv is not None:
            k_past, v_past = past_kv
            k_cat = torch.cat([k_past, k_exp], dim=2)
            v_cat = torch.cat([v_past, v_exp], dim=2)
            # No mask with a cache: assumes single-step (or fully-visible)
            # decoding; multi-token continuation would need a causal mask.
            is_causal = False
        else:
            k_cat, v_cat = k_exp, v_exp
            is_causal = True

        # Prefer Flash but allow mem-efficient/math fallbacks (e.g. on CPU).
        with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
            if x.device.type == "cuda" and q.dtype not in (torch.float16, torch.bfloat16):
                # Flash kernels require half precision on CUDA; autocast locally.
                amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
                with torch.amp.autocast(device_type="cuda", dtype=amp_dtype):
                    out = torch.nn.functional.scaled_dot_product_attention(
                        q, k_cat, v_cat,
                        dropout_p=self.dropout if self.training else 0.0,
                        is_causal=is_causal,
                    )
            else:
                out = torch.nn.functional.scaled_dot_product_attention(
                    q, k_cat, v_cat,
                    dropout_p=self.dropout if self.training else 0.0,
                    is_causal=is_causal,
                )

        # [B, nH, T, Hd] -> [B, T, dim]
        out = out.transpose(1, 2).contiguous().view(B, T_new, self.num_heads * self.head_dim)
        # Autocast may have produced a half-precision result; restore input dtype.
        if out.dtype != x.dtype:
            out = out.to(x.dtype)
        out = self.out_proj(out)

        if return_kv:
            return out, (k_cat, v_cat)
        return out
|
|
|
|
|
|
class FeedForward(nn.Module):
    """SwiGLU feed-forward block: ``w2(silu(w1(x)) * w3(x))`` with dropout."""

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        # Gate branch (w1), down-projection (w2), and value branch (w3); no biases.
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the SwiGLU transformation.

        Args:
            x: Input tensor [B, T, dim]

        Returns:
            Output tensor [B, T, dim]
        """
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.dropout(self.w2(gate * value))
|
|
|
|
class TransformerBlock(nn.Module):
    """Pre-norm Transformer block: RMSNorm -> self-attn -> RMSNorm -> SwiGLU FFN.

    Residual connections wrap both sub-layers; no cross-attention.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        num_kv_groups: int | None = None,
        qk_norm: bool = False,
        attn_type: str = "gqa",
    ):
        super().__init__()
        self.norm1 = RMSNorm(dim)
        if attn_type == "mha":
            self.attn = MultiHeadAttention(dim=dim, num_heads=num_heads, dropout=dropout)
            self._attn_is_gqa = False
        else:
            # Default to one KV group per head (MHA-equivalent layout) when no
            # explicit group count is supplied.
            kv_groups = num_heads if num_kv_groups is None else max(1, int(num_kv_groups))
            self.attn = GroupedQueryAttention(
                dim=dim,
                num_heads=num_heads,
                num_kv_groups=kv_groups,
                dropout=dropout,
                qk_norm=qk_norm,
            )
            self._attn_is_gqa = True
        self.norm2 = RMSNorm(dim)
        self.ffn = FeedForward(dim=dim, hidden_dim=int(dim * mlp_ratio), dropout=dropout)

    def forward(
        self,
        x: torch.Tensor,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
        position_offset: int = 0,
    ) -> "torch.Tensor | Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]":
        """Forward pass; returns (x, present_kv) when KV caching is active."""
        caching = use_cache or past_kv is not None
        if not caching:
            x = x + self.attn(self.norm1(x))
            return x + self.ffn(self.norm2(x))
        attn_out, present_kv = self.attn(
            self.norm1(x), past_kv=past_kv, return_kv=True, position_offset=position_offset
        )
        x = x + attn_out
        x = x + self.ffn(self.norm2(x))
        return x, present_kv
|
|