# CodonTranslator / src / layers.py
# (commit 1cd3d3f, "Rename public-facing CodonGPT strings to CodonTranslator", by alegendaryfish)
"""
Transformer components for CodonTranslator.
Includes RMSNorm, self-attention (SDPA/Flash) with optional mask,
cross-attention for conditioning memory, SwiGLU FFN, and a basic block.
"""
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel # Require recent PyTorch
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (no mean-centering, no bias)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps  # numerical floor added to the mean square before rsqrt
        self.weight = nn.Parameter(torch.ones(dim))  # learnable per-channel gain

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Scale `x` by the reciprocal RMS of its last dimension, then by the gain.

        Args:
            x: Input tensor of any shape ending in `dim`
        Returns:
            Normalized tensor of the same shape
        """
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.eps) * self.weight
def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
"""Apply rotary embeddings to x: [B,H,T,D]; cos/sin: [1,1,T,D]."""
x1 = x[..., ::2]
x2 = x[..., 1::2]
x_rot = torch.zeros_like(x)
x_rot[..., ::2] = -x2
x_rot[..., 1::2] = x1
return x * cos + x_rot * sin
class MultiHeadAttention(nn.Module):
"""Self-attention using PyTorch SDPA kernels (Flash/MemEff/Math) + RoPE.
- attn_mask: bool [B, T, T] with True = keep, False = block
- is_causal: whether to apply causal masking internally
"""
def __init__(
self,
dim: int,
num_heads: int,
dropout: float = 0.0,
use_rope: bool = True,
):
super().__init__()
assert dim % num_heads == 0, f"dim {dim} must be divisible by num_heads {num_heads}"
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.dropout = dropout
self.use_rope = use_rope
self.qkv = nn.Linear(dim, 3 * dim, bias=False)
self.out_proj = nn.Linear(dim, dim, bias=False)
self.resid_dropout = nn.Dropout(dropout)
# RoPE cache
self._rope_cache: dict[tuple[int, torch.device, torch.dtype], tuple[torch.Tensor, torch.Tensor]] = {}
def _rope_cos_sin(self, T: int, device: torch.device, dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]:
key = (T, device, dtype)
cached = self._rope_cache.get(key)
if cached is not None:
return cached
dim_half = self.head_dim // 2
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_half, device=device, dtype=torch.float32) / dim_half))
t = torch.arange(T, device=device, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)
cos = torch.cos(freqs).repeat_interleave(2, dim=-1)
sin = torch.sin(freqs).repeat_interleave(2, dim=-1)
cos = cos.to(dtype).unsqueeze(0).unsqueeze(0) # [1,1,T,D]
sin = sin.to(dtype).unsqueeze(0).unsqueeze(0)
self._rope_cache[key] = (cos, sin)
return cos, sin
def forward(
self,
x: torch.Tensor,
past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
return_kv: bool = False,
position_offset: int = 0,
) -> "torch.Tensor | Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]":
"""
Self-attention with optional KV cache support.
Args:
x: [B, T_new, H]
past_kv: Optional tuple (k, v), each [B, nH, T_past, Hd]
return_kv: If True, also return updated (k, v)
position_offset: Starting position index for RoPE (past length)
Returns:
out or (out, present_kv)
"""
B, T_new, _ = x.shape
# QKV projections and reshape (ensure contiguous for SDPA kernels)
qkv = self.qkv(x).view(B, T_new, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k_new, v_new = qkv[0].contiguous(), qkv[1].contiguous(), qkv[2].contiguous()
# RoPE for new tokens only
if self.use_rope:
# Compute cos/sin up to (offset + T_new), then slice the tail for new positions
cos, sin = self._rope_cos_sin(position_offset + T_new, x.device, q.dtype)
if position_offset > 0:
cos = cos[:, :, position_offset: position_offset + T_new, :]
sin = sin[:, :, position_offset: position_offset + T_new, :]
# Apply to q and k_new
q = _apply_rope(q, cos, sin)
k_new = _apply_rope(k_new, cos, sin)
# Concatenate with cache if provided
if past_kv is not None:
k_past, v_past = past_kv
k = torch.cat([k_past, k_new], dim=2)
v = torch.cat([v_past, v_new], dim=2)
is_causal = False # No future tokens present; avoid unnecessary masking
else:
k, v = k_new, v_new
is_causal = True
# Prefer FlashAttention; fall back to MemEff then Math. Autocast to half/bfloat16 on CUDA.
backends = [SDPBackend.FLASH_ATTENTION]#, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
with sdpa_kernel(backends):
if x.device.type == "cuda" and q.dtype not in (torch.float16, torch.bfloat16):
amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
with torch.amp.autocast(device_type="cuda", dtype=amp_dtype):
out = F.scaled_dot_product_attention(
q, k, v,
dropout_p=self.dropout if self.training else 0.0,
is_causal=is_causal,
)
else:
out = F.scaled_dot_product_attention(
q, k, v,
dropout_p=self.dropout if self.training else 0.0,
is_causal=is_causal,
)
out = out.transpose(1, 2).contiguous().view(B, T_new, self.dim)
# Align dtype with residual/Linear weights to avoid bf16/float mismatches
if out.dtype != x.dtype:
out = out.to(x.dtype)
out = self.out_proj(out)
out = self.resid_dropout(out)
if return_kv:
return out, (k, v)
return out
class GroupedQueryAttention(nn.Module):
    """Grouped-Query Attention (GQA) using Flash Attention via PyTorch SDPA.
    - num_heads total query heads
    - num_kv_groups shared K/V groups (num_heads must be divisible by num_kv_groups)
    - Optional q/k RMSNorm
    - Supports RoPE with a scalar or per-sample position_offset (like MHA)
    - Optional KV cache compatible with the existing interface (stores expanded per-head K/V)
    """
    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_groups: int,
        dropout: float = 0.0,
        qk_norm: bool = False,
    ) -> None:
        super().__init__()
        assert num_heads % max(1, num_kv_groups) == 0, "num_heads must be divisible by num_kv_groups"
        self.dim = dim
        self.num_heads = int(num_heads)
        # Clamp to at least one group; group_size = query heads per shared K/V group.
        self.num_kv_groups = max(1, int(num_kv_groups))
        self.group_size = self.num_heads // self.num_kv_groups
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.head_dim = dim // num_heads
        self.dropout = dropout
        # Q projects to all heads; K/V project only to the (smaller) group count.
        self.Wq = nn.Linear(dim, self.num_heads * self.head_dim, bias=False)
        self.Wk = nn.Linear(dim, self.num_kv_groups * self.head_dim, bias=False)
        self.Wv = nn.Linear(dim, self.num_kv_groups * self.head_dim, bias=False)
        self.out_proj = nn.Linear(self.num_heads * self.head_dim, dim, bias=False)
        # Optional per-head-dim RMSNorm applied to queries/keys before RoPE.
        self.q_norm = RMSNorm(self.head_dim) if qk_norm else None
        self.k_norm = RMSNorm(self.head_dim) if qk_norm else None
        # RoPE cache
        self._rope_cache: dict[tuple[int, torch.device, torch.dtype], tuple[torch.Tensor, torch.Tensor]] = {}
    def _rope_cos_sin(self, T: int, device: torch.device, dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]:
        """Return cached (cos, sin) RoPE tables of shape [1,1,T,head_dim]."""
        key = (T, device, dtype)
        cached = self._rope_cache.get(key)
        if cached is not None:
            return cached
        dim_half = self.head_dim // 2
        # Standard RoPE frequencies 10000^(-2i/d), computed in float32 for accuracy.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_half, device=device, dtype=torch.float32) / dim_half))
        t = torch.arange(T, device=device, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)
        # Duplicate each frequency so the tables align with interleaved (even, odd) pairs.
        cos = torch.cos(freqs).repeat_interleave(2, dim=-1)
        sin = torch.sin(freqs).repeat_interleave(2, dim=-1)
        cos = cos.to(dtype).unsqueeze(0).unsqueeze(0)  # [1,1,T,D]
        sin = sin.to(dtype).unsqueeze(0).unsqueeze(0)
        # NOTE(review): cache is unbounded across distinct T; incremental decoding adds
        # one entry per new length — confirm the memory growth is acceptable.
        self._rope_cache[key] = (cos, sin)
        return cos, sin
    def forward(
        self,
        x: torch.Tensor,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        return_kv: bool = False,
        position_offset: int | torch.Tensor = 0,
    ) -> "torch.Tensor | Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]":
        """GQA forward pass with optional KV cache.

        Args:
            x: Input of shape [B, T_new, dim].
            past_kv: Optional (k, v) from a previous call, each already expanded
                to per-head layout [B, num_heads, T_past, head_dim].
            return_kv: If True, also return the updated (k, v).
            position_offset: RoPE start position; an int shared by the batch, or
                a per-sample tensor of shape [B].
        Returns:
            out of shape [B, T_new, dim], or (out, (k, v)) when return_kv is True.
        """
        B, T_new, _ = x.shape
        # Project to Q, K, V
        q = self.Wq(x).view(B, T_new, self.num_heads, self.head_dim).transpose(1, 2).contiguous() # [B,H,T,Hd]
        k = self.Wk(x).view(B, T_new, self.num_kv_groups, self.head_dim).transpose(1, 2).contiguous() # [B,G,T,Hd]
        v = self.Wv(x).view(B, T_new, self.num_kv_groups, self.head_dim).transpose(1, 2).contiguous() # [B,G,T,Hd]
        # Optional RMSNorm on q/k
        if self.q_norm is not None:
            q = self.q_norm(q)
        if self.k_norm is not None:
            k = self.k_norm(k)
        # RoPE for new tokens only
        if isinstance(position_offset, int):
            # Scalar offset: build tables up to offset+T_new, then slice the tail.
            cos, sin = self._rope_cos_sin(position_offset + T_new, x.device, q.dtype)
            if position_offset > 0:
                cos = cos[:, :, position_offset: position_offset + T_new, :]
                sin = sin[:, :, position_offset: position_offset + T_new, :]
            q = _apply_rope(q, cos, sin)
            k = _apply_rope(k, cos, sin)
        else:
            # Per-sample offsets: gather each sample's position rows from shared tables.
            off = position_offset.to(device=x.device, dtype=torch.long)
            max_off = int(off.max().item())
            cos_all, sin_all = self._rope_cos_sin(max_off + T_new, x.device, q.dtype)
            ar = torch.arange(T_new, device=x.device, dtype=torch.long)
            idx = (off.unsqueeze(1) + ar.unsqueeze(0))  # [B, T_new]
            cos_b = cos_all.squeeze(0).squeeze(0)[idx].unsqueeze(1)  # [B,1,T,D]
            sin_b = sin_all.squeeze(0).squeeze(0)[idx].unsqueeze(1)
            q = _apply_rope(q, cos_b, sin_b)
            # k has groups dimension [B,G,T,D]; share same offsets per batch
            k = _apply_rope(k, cos_b, sin_b)
        # Expand grouped K/V to per-head by repeating groups
        if self.group_size > 1:
            k_exp = k.repeat_interleave(self.group_size, dim=1)  # [B,H,T,Hd]
            v_exp = v.repeat_interleave(self.group_size, dim=1)  # [B,H,T,Hd]
        else:
            k_exp, v_exp = k, v  # already per-head
        # KV cache: concatenate past along sequence dim
        if past_kv is not None:
            k_past, v_past = past_kv
            k_cat = torch.cat([k_past, k_exp], dim=2)
            v_cat = torch.cat([v_past, v_exp], dim=2)
            # NOTE(review): disabling causality here assumes incremental decode
            # (T_new == 1); with T_new > 1 new tokens see each other — confirm callers.
            is_causal = False
        else:
            k_cat, v_cat = k_exp, v_exp
            is_causal = True
        # Prefer FlashAttention; fall back to MemEff/Math. Ensure CUDA autocast to half/bfloat16 so kernels are available
        with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
            if x.device.type == "cuda" and q.dtype not in (torch.float16, torch.bfloat16):
                amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
                with torch.amp.autocast(device_type="cuda", dtype=amp_dtype):
                    out = torch.nn.functional.scaled_dot_product_attention(
                        q, k_cat, v_cat,
                        dropout_p=self.dropout if self.training else 0.0,
                        is_causal=is_causal,
                    )  # [B,H,T,Hd]
            else:
                out = torch.nn.functional.scaled_dot_product_attention(
                    q, k_cat, v_cat,
                    dropout_p=self.dropout if self.training else 0.0,
                    is_causal=is_causal,
                )  # [B,H,T,Hd]
        out = out.transpose(1, 2).contiguous().view(B, T_new, self.num_heads * self.head_dim)
        # Ensure dtype compatibility for Linear / residual path
        if out.dtype != x.dtype:
            out = out.to(x.dtype)
        out = self.out_proj(out)
        if return_kv:
            return out, (k_cat, v_cat)
        return out
class FeedForward(nn.Module):
    """SwiGLU feed-forward: w2(silu(w1(x)) * w3(x)), with dropout on the output."""

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate projection
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # down projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # value projection
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the SwiGLU feed-forward transform.

        Args:
            x: Input tensor [B, T, dim]
        Returns:
            Output tensor [B, T, dim]
        """
        gate = F.silu(self.w1(x))
        hidden = gate * self.w3(x)
        return self.dropout(self.w2(hidden))
class TransformerBlock(nn.Module):
    """Pre-norm Transformer block using self-attn + SwiGLU FFN (no cross-attention)."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        num_kv_groups: int | None = None,
        qk_norm: bool = False,
        attn_type: str = "gqa",  # "gqa" or "mha"
    ):
        super().__init__()
        self.norm1 = RMSNorm(dim)
        use_mha = attn_type == "mha"
        if use_mha:
            self.attn = MultiHeadAttention(dim=dim, num_heads=num_heads, dropout=dropout)
        else:
            # Grouped-Query Attention; one group per head (i.e. no sharing) by default.
            groups = num_heads if num_kv_groups is None else max(1, int(num_kv_groups))
            self.attn = GroupedQueryAttention(
                dim=dim,
                num_heads=num_heads,
                num_kv_groups=groups,
                dropout=dropout,
                qk_norm=qk_norm,
            )
        self._attn_is_gqa = not use_mha
        self.norm2 = RMSNorm(dim)
        self.ffn = FeedForward(dim=dim, hidden_dim=int(dim * mlp_ratio), dropout=dropout)

    def forward(
        self,
        x: torch.Tensor,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
        position_offset: int = 0,
    ) -> "torch.Tensor | Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]":
        """Forward pass with optional KV caching."""
        caching = use_cache or past_kv is not None
        normed = self.norm1(x)
        if caching:
            attn_out, present = self.attn(
                normed, past_kv=past_kv, return_kv=True, position_offset=position_offset
            )
            x = x + attn_out
            x = x + self.ffn(self.norm2(x))
            return x, present
        x = x + self.attn(normed)
        x = x + self.ffn(self.norm2(x))
        return x