# CodonTranslator/src/models.py
"""
Core model architectures for CodonTranslator.
- CodonTranslatorModel: decoder-only backbone with species + protein prefix
Includes a frozen ESM-C encoder for protein conditioning.
"""
import math
import os
from typing import Optional, Dict, Any, Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import torch.nn.utils.rnn as rnn_utils
from .layers import RMSNorm, TransformerBlock
from .tokenizer import SpecialIds
class FrozenESMCEncoder(nn.Module):
"""
Frozen ESM-C encoder that computes protein embeddings on the fly.
Kept on single GPU per rank (not distributed via FSDP).
"""
def __init__(self, model_name: str = "esmc_300m", device: str = "cuda", dtype: str = "fp16"):
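        """``dtype`` ("fp16"/"bf16") selects the autocast precision; any other value disables autocast."""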
super().__init__()
self.model_name = model_name
self._device = torch.device(device if torch.cuda.is_available() else "cpu")
if dtype == "fp16":
self._autocast_dtype = torch.float16
elif dtype == "bf16":
self._autocast_dtype = torch.bfloat16
else:
self._autocast_dtype = None
self._load_model()
self.eval()
for p in self.parameters():
p.requires_grad_(False)
def _load_model(self):
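        """Instantiate the pretrained ESM-C checkpoint and record its embedding width ``D_esm``."""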
from esm.models.esmc import ESMC
from esm.utils.constants.models import ESMC_300M, ESMC_600M
if self.model_name == "esmc_300m":
model_const = ESMC_300M
self.D_esm = 960
elif self.model_name == "esmc_600m":
model_const = ESMC_600M
self.D_esm = 1152
else:
raise ValueError(f"Unknown model: {self.model_name}")
self.model = ESMC.from_pretrained(model_name=model_const, device=self._device)
self.tokenizer = self.model.tokenizer
@torch.no_grad()
def tokenize(self, sequences: List[str], max_length: Optional[int] = None, add_special_tokens: bool = True, return_tensors: str = "pt"):
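        """
        Tokenize protein sequences into a padded ``[B, L]`` id tensor plus a boolean attention mask.
        Sequences longer than ``max_length`` tokens are truncated; ``return_tensors`` is accepted for
        interface compatibility but tensors are always returned.
        """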
from esm.utils import encoding
from esm.utils.misc import stack_variable_length_tensors
pad = self.tokenizer.pad_token_id
tokenized_seqs = []
for seq in sequences:
tokens = encoding.tokenize_sequence(seq, self.tokenizer, add_special_tokens=add_special_tokens)
if max_length is not None and len(tokens) > max_length:
tokens = tokens[:max_length]
tokenized_seqs.append(tokens)
input_ids = stack_variable_length_tensors(tokenized_seqs, constant_value=pad)
attention_mask = (input_ids != pad)
return input_ids, attention_mask
@torch.no_grad()
def encode_from_ids(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.BoolTensor] = None, return_dict: bool = True, return_contacts: bool = False):
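        """
        Run the frozen ESM-C model on pre-tokenized ids and return per-token embeddings
        (``[B, L, D_esm]``), using CUDA autocast when an autocast dtype is configured
        (``return_contacts`` is currently ignored).
        """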
device = self.model.device
input_ids = input_ids.to(device)
if attention_mask is not None:
attention_mask = attention_mask.to(device)
if self._autocast_dtype is not None and device.type == "cuda":
with torch.amp.autocast('cuda', dtype=self._autocast_dtype):
outputs = self.model.forward(sequence_tokens=input_ids, sequence_id=attention_mask)
else:
outputs = self.model.forward(sequence_tokens=input_ids, sequence_id=attention_mask)
embeddings = outputs.embeddings
if return_dict:
return {"embeddings": embeddings, "attention_mask": attention_mask}
else:
return embeddings
def strip_special_tokens(self, embeddings: torch.FloatTensor, attention_mask: Optional[torch.BoolTensor] = None):
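        """
        Drop the first (BOS) and last rows from ESM-C token embeddings and return
        ``(stripped, lengths)``, where ``lengths`` counts valid residue positions per sample.
        """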
if attention_mask is not None:
lengths = attention_mask.sum(dim=1) - 2
lengths = lengths.clamp(min=1)
else:
B, L, D = embeddings.shape
            lengths = torch.full((B,), L - 2, device=embeddings.device, dtype=torch.long)
stripped = embeddings[:, 1:-1, :]
return stripped, lengths
class CodonTranslatorModel(nn.Module):
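    """
    Decoder-only codon language model with a LLaVA-style conditioning prefix:
    [species_src] + [species_tgt] + [ESM-C protein tokens], followed by a learned
    start embedding and the codon sequence itself.

    Example (illustrative sketch only; shapes assume the constructor defaults, and the
    ``cond`` keys shown are the ones ``forward`` reads, though a trainer may supply others):

        model = CodonTranslatorModel()
        out = model(
            codon_ids,                              # [B, N] codon/special token ids
            cond={
                "species_emb_src": src_emb,         # [B, species_embedding_dim]
                "species_emb_tgt": tgt_emb,         # [B, species_embedding_dim]
                "protein_seqs": protein_strings,    # list of amino-acid strings
            },
            labels=labels,                          # target ids (ignore_index = -100)
        )
        loss, logits = out["loss"], out["logits"]
    """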
def __init__(
self,
vocab_size: int = 79,
hidden_size: int = 960,
num_layers: int = 24,
num_heads: int = 16,
mlp_ratio: float = 4.0,
max_position_embeddings: int = 4096,
dropout: float = 0.1,
layer_norm_eps: float = 1e-6,
num_special_tokens: int = 13,
special_ids: Optional[SpecialIds] = None,
esm_model_name: str = "esmc_300m",
esm_device: str = "cuda",
esm_dtype: str = "fp16",
max_protein_prefix: int = 0,
max_species_prefix: int = 0,
prepend_species: bool = True,
prepend_protein: bool = True,
species_embedding_dim: int = 1024,
attn_impl: str = "gqa", # "gqa" or "mha"
num_kv_groups: int = 0, # for GQA; 0 means default (no grouping)
):
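        """
        Key knobs: ``prepend_species`` / ``prepend_protein`` toggle the conditioning prefix;
        ``max_species_prefix`` / ``max_protein_prefix`` cap each prefix length (0 = unlimited,
        subject to ``max_position_embeddings``); ``attn_impl`` ("gqa" or "mha") and
        ``num_kv_groups`` configure attention in the transformer blocks.
        """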
super().__init__()
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.num_heads = num_heads
self.max_position_embeddings = max_position_embeddings
self.special_ids = special_ids or SpecialIds()
self.num_special_tokens = num_special_tokens
# Single embedding table for all tokens (special + codon)
self.token_embed = nn.Embedding(vocab_size, hidden_size)
if prepend_protein and esm_model_name:
self.esm = FrozenESMCEncoder(esm_model_name, esm_device, esm_dtype)
# Project ESM token embeddings (D_esm) to model hidden size, then normalize
self.esm_ln = nn.Sequential(
nn.Linear(self.esm.D_esm, hidden_size, bias=False),
nn.ReLU(),
nn.LayerNorm(hidden_size),
)
else:
self.esm = None
self.esm_ln = None
self.species_embedding_dim = species_embedding_dim if prepend_species else 0
if prepend_species:
# Project species embeddings (fixed or token sequence) from Ds -> H
self.species_ln = nn.Sequential(
nn.Linear(self.species_embedding_dim, hidden_size, bias=False),
nn.ReLU(),
nn.LayerNorm(hidden_size),
)
else:
self.species_ln = None
# Optional per-prefix caps; 0 means unlimited (subject to global max length)
self.max_protein_prefix = int(max_protein_prefix) if max_protein_prefix is not None else 0
self.max_species_prefix = int(max_species_prefix) if max_species_prefix is not None else 0
self.prepend_species = bool(prepend_species)
self.prepend_protein = bool(prepend_protein)
# Learned start embedding (BOS-less decoding)
self.start_embed = nn.Parameter(torch.zeros(1, 1, hidden_size))
nn.init.normal_(self.start_embed, mean=0.0, std=0.02)
# Attention configuration
self.attn_impl = str(attn_impl)
self.num_kv_groups = int(num_kv_groups)
kv_groups = self.num_kv_groups
self.blocks = nn.ModuleList([
TransformerBlock(
dim=hidden_size,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
dropout=dropout,
num_kv_groups=(kv_groups if (kv_groups > 0 and attn_impl == "gqa") else None),
qk_norm=False,
attn_type=("mha" if self.attn_impl == "mha" else "gqa"),
) for _ in range(num_layers)
])
self.ln_f = RMSNorm(hidden_size, eps=layer_norm_eps)
self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
self.gradient_checkpointing = False
def embed_tokens(self, token_ids: torch.Tensor) -> torch.Tensor:
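        """Embed codon/special token ids, moving them to the embedding table's device first."""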
device = self.token_embed.weight.device
return self.token_embed(token_ids.to(device))
def build_prefix(
self,
batch_size: int,
device: torch.device,
species_tok_emb: Optional[torch.Tensor] = None,
species_emb: Optional[torch.Tensor] = None,
protein_input: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
species_tok_emb_src: Optional[torch.Tensor] = None,
species_tok_emb_tgt: Optional[torch.Tensor] = None,
species_emb_src: Optional[torch.Tensor] = None,
species_emb_tgt: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Build LLaVA-style prefix token embeddings by concatenating
[species_src]+[species_tgt]+[protein_tokens]. Returns:
- prefix: [B, Lp, H]
- prefix_lengths: [B] valid token counts per sample
"""
        parts: List[torch.Tensor] = []
# Species: src then tgt (if provided)
if self.prepend_species and self.species_ln is not None:
tok_src = species_tok_emb_src if species_tok_emb_src is not None else species_tok_emb
tok_tgt = species_tok_emb_tgt if species_tok_emb_tgt is not None else species_tok_emb
emb_src = species_emb_src if species_emb_src is not None else species_emb
emb_tgt = species_emb_tgt if species_emb_tgt is not None else species_emb
def _as_tokens(S_tok, S_fix):
if S_fix is not None:
# [B, Ds] -> [B, 1, H]
S = self.species_ln(S_fix.to(device=device, dtype=next(self.parameters()).dtype).unsqueeze(1))
return S
elif S_tok is not None:
# [B, Ls, Ds] -> optional cap, then project to H
S = S_tok
if getattr(self, "max_species_prefix", 0) > 0 and S.size(1) > self.max_species_prefix:
S = S[:, : self.max_species_prefix, :]
S = S.to(device=device, dtype=next(self.parameters()).dtype)
S = self.species_ln(S)
return S
else:
return None
Ssrc = _as_tokens(tok_src, emb_src)
if Ssrc is not None:
parts.append(Ssrc)
Sdst = _as_tokens(tok_tgt, emb_tgt)
if Sdst is not None:
parts.append(Sdst)
# Protein tokens from ESM-C
if self.prepend_protein and self.esm is not None and protein_input is not None:
prot_ids, prot_mask = protein_input
esm_out = self.esm.encode_from_ids(prot_ids, prot_mask, return_dict=True)
P, lengths = self.esm.strip_special_tokens(esm_out["embeddings"], prot_mask)
# Optional per-protein capping before projection
if getattr(self, "max_protein_prefix", 0) > 0 and P.size(1) > self.max_protein_prefix:
P = P[:, : self.max_protein_prefix, :]
if lengths is not None:
lengths = lengths.clamp(max=self.max_protein_prefix)
if P.size(1) > 0:
P = self.esm_ln(P.to(device=device, dtype=next(self.parameters()).dtype))
# Zero padded rows (per-sample) based on lengths
if lengths is not None:
Lp = P.size(1)
ar = torch.arange(Lp, device=device).unsqueeze(0)
lengths = lengths.to(device=device)
valid = ar < lengths.unsqueeze(1) # [B,Lp]
P = P * valid.unsqueeze(-1)
parts.append(P)
if len(parts) == 0:
empty = torch.zeros(batch_size, 0, self.hidden_size, device=device, dtype=next(self.parameters()).dtype)
return empty, torch.zeros(batch_size, dtype=torch.long, device=device)
        prefix = torch.cat(parts, dim=1)  # [B, Lp, H]; the empty case returned above
# Compute per-sample valid lengths: treat zero rows as padding
with torch.no_grad():
if prefix.size(1) > 0:
valid = (prefix.abs().sum(dim=-1) > 0)
lengths = valid.sum(dim=1).to(torch.long)
else:
lengths = torch.zeros(batch_size, dtype=torch.long, device=device)
# ---- Enforce hard global budget on the prefix itself ----
prefix_budget = max(0, int(self.max_position_embeddings) - 1)
if prefix_budget == 0:
trimmed = prefix.new_zeros(prefix.size(0), 0, prefix.size(2))
return trimmed, torch.zeros(prefix.size(0), dtype=torch.long, device=prefix.device)
allow = torch.minimum(lengths, torch.tensor(prefix_budget, device=lengths.device, dtype=lengths.dtype))
Lp_max = int(allow.max().item()) if allow.numel() > 0 else 0
if prefix.size(1) > Lp_max:
trimmed = prefix.new_zeros(prefix.size(0), Lp_max, prefix.size(2))
for b in range(prefix.size(0)):
lb = int(allow[b].item())
if lb > 0:
trimmed[b, :lb, :] = prefix[b, :lb, :]
prefix = trimmed
lengths = allow
else:
lengths = allow
return prefix, lengths
def forward(
self,
codon_ids: torch.Tensor,
cond: Dict[str, Any] = None,
labels: Optional[torch.Tensor] = None,
return_dict: bool = True,
species_tok_emb: Optional[torch.Tensor] = None,
protein_input: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
protein_seqs: Optional[List[str]] = None,
# KV cache options
use_cache: bool = False,
past_kv: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
position_offset: int = 0,
) -> Dict[str, torch.Tensor]:
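        """
        Two execution paths:
        - Incremental decode (``past_kv`` given): only the newly supplied codon ids are embedded
          and run through the blocks with cached keys/values; returns ``next_logits`` and
          ``present_kv``.
        - Training / prefill: builds the conditioning prefix, appends the start embedding and a
          budget-capped codon window, and returns codon-aligned ``logits`` (plus ``loss`` when
          ``labels`` is provided, and ``present_kv`` when ``use_cache`` is set).
        """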
batch_size, codon_len = codon_ids.shape
device = codon_ids.device
# Unpack conditioning
if cond is not None:
control_mode = cond.get("control_mode", "fixed")
species_tok_emb_src = cond.get("species_tok_emb_src")
species_tok_emb_tgt = cond.get("species_tok_emb_tgt")
species_emb_src = cond.get("species_emb_src")
species_emb_tgt = cond.get("species_emb_tgt")
species_tok_emb = cond.get("species_tok_emb")
species_emb = cond.get("species_emb")
protein_input = cond.get("protein_input")
protein_seqs = cond.get("protein_seqs")
else:
species_emb = None
species_tok_emb_src = None
species_tok_emb_tgt = None
species_emb_src = None
species_emb_tgt = None
if protein_seqs is not None and protein_input is None:
if self.esm is not None:
with torch.no_grad():
# Respect per-protein ceiling during tokenization (+2 for BOS/EOS)
max_len_tokens = (self.max_protein_prefix + 2) if (getattr(self, "max_protein_prefix", 0) > 0) else None
protein_input = self.esm.tokenize(protein_seqs, max_length=max_len_tokens)
else:
protein_input = None
# Fast path: incremental decode using KV cache
if past_kv is not None:
# Expect only newly generated codon tokens here
            if codon_ids.numel() == 0:
                # Nothing to do; return an empty logits window plus a dummy next_logits
                dummy = torch.zeros(batch_size, self.vocab_size, device=device, dtype=self.lm_head.weight.dtype)
                return {"logits": dummy.new_zeros(batch_size, 0, self.vocab_size), "next_logits": dummy}
x = self.embed_tokens(codon_ids) # [B, T_new, H]
present_kv: List[Tuple[torch.Tensor, torch.Tensor]] = []
for i, block in enumerate(self.blocks):
kv_i = past_kv[i] if i < len(past_kv) else None
if self.training and getattr(self, 'gradient_checkpointing', False):
def _fn(inp):
return block(inp, past_kv=kv_i, use_cache=True, position_offset=position_offset)
out_blk = checkpoint.checkpoint(_fn, x, use_reentrant=False)
else:
out_blk = block(x, past_kv=kv_i, use_cache=True, position_offset=position_offset)
x, kv_out = out_blk # type: ignore[assignment]
present_kv.append(kv_out)
x = self.ln_f(x)
logits_step = self.lm_head(x) # [B, T_new, V]
next_logits = logits_step[:, -1, :]
out: Dict[str, torch.Tensor] = {"logits": logits_step[:, 0:0, :], "next_logits": next_logits}
out["present_kv"] = present_kv # type: ignore[assignment]
return out if return_dict else logits_step[:, 0:0, :]
# Standard path: build prefix and full window (training or prefill)
prefix, prefix_lengths = self.build_prefix(
batch_size=batch_size,
device=device,
species_tok_emb=species_tok_emb,
species_emb=species_emb if cond is not None else None,
protein_input=protein_input,
species_tok_emb_src=species_tok_emb_src,
species_tok_emb_tgt=species_tok_emb_tgt,
species_emb_src=species_emb_src,
species_emb_tgt=species_emb_tgt,
)
start = self.start_embed.expand(batch_size, 1, self.hidden_size) # [B,1,H]
# Per-sample true codon input lengths (exclude PADs)
pad_id = int(self.special_ids.pad) if hasattr(self, "special_ids") and self.special_ids is not None else 0
codon_mask = (codon_ids != pad_id) # [B, N]
codon_lens = codon_mask.sum(dim=1) # [B]
# Budget remaining after prefix + start
capacity = max(0, int(self.max_position_embeddings))
budget_after_prefix = torch.clamp(
torch.as_tensor(capacity, device=device) - (prefix_lengths + 1),
min=0,
) # [B]
# Per-sample cap is limited by both budget and available codons
per_cap = torch.minimum(budget_after_prefix, codon_lens) # [B]
# Total valid lengths per sample (prefix + start + capped codon)
valid_lengths = prefix_lengths + 1 + per_cap
T = int(valid_lengths.max().item()) if valid_lengths.numel() > 0 else (1 + int(codon_lens.max().item()) if codon_lens.numel() > 0 else 1)
# Embed only the needed codon window for this batch
max_cap = int(per_cap.max().item()) if per_cap.numel() > 0 else 0
if max_cap > 0:
codon_emb = self.embed_tokens(codon_ids[:, :max_cap]) # [B, max_cap, H]
else:
codon_emb = torch.zeros(batch_size, 0, self.hidden_size, device=device, dtype=start.dtype)
# Build sequence per-sample using concat to preserve gradients, then pad
seqs = []
for b in range(batch_size):
lp = int(prefix_lengths[b].item())
cap = int(per_cap[b].item())
parts = []
if lp > 0:
parts.append(prefix[b, :lp, :])
parts.append(start[b, 0:1, :])
if cap > 0:
parts.append(codon_emb[b, :cap, :])
seqs.append(torch.cat(parts, dim=0)) # [Lb, H]
x = rnn_utils.pad_sequence(seqs, batch_first=True) # [B, T, H]
present_kv_list: List[Tuple[torch.Tensor, torch.Tensor]] = []
for block in self.blocks:
if self.training and getattr(self, 'gradient_checkpointing', False):
def _fn(inp):
return block(inp, use_cache=use_cache, position_offset=0)
blk_out = checkpoint.checkpoint(_fn, x, use_reentrant=False)
else:
blk_out = block(x, use_cache=use_cache, position_offset=0)
if use_cache:
x, kv = blk_out # type: ignore[misc]
present_kv_list.append(kv)
else:
x = blk_out # type: ignore[assignment]
x = self.ln_f(x)
logits_full = self.lm_head(x) # [B, T, V]
        # Gather codon-aligned logits per sample: slice positions lp .. lp+cap-1 (the start-token output and the first cap-1 codon outputs)
next_logits_list = []
if max_cap == 0:
# Keep graph by slicing from logits_full
codon_logits = logits_full[:, 0:0, :]
for b in range(batch_size):
lp = int(prefix_lengths[b].item())
# Last consumed position is the start token at index lp
pos_next = lp
if pos_next < logits_full.size(1):
next_logits_list.append(logits_full[b, pos_next, :])
else:
next_logits_list.append(logits_full[b, -1, :])
next_logits = torch.stack(next_logits_list, dim=0)
else:
slices = []
for b in range(batch_size):
lp = int(prefix_lengths[b].item())
cap = int(per_cap[b].item())
# Skip the start position so logits align with labels = codon_ids[:, 1:]
sl = logits_full[b, lp : lp + cap, :] if cap > 0 else logits_full.new_zeros(0, self.vocab_size)
slices.append(sl)
# Next-token logits after processing 'cap' codons: last consumed is at lp + cap
pos_next = lp + cap
next_logits_list.append(logits_full[b, pos_next, :] if pos_next < logits_full.size(1) else logits_full.new_zeros(self.vocab_size))
codon_logits = rnn_utils.pad_sequence(slices, batch_first=True) # [B,max_cap,V]
next_logits = torch.stack(next_logits_list, dim=0)
out = {"logits": codon_logits, "next_logits": next_logits}
if labels is not None:
# Align labels to per-sample caps: mask out positions >= cap
if labels.size(1) > 0 and max_cap > 0:
# Build masked labels with -100 beyond cap per sample
adj = labels.new_full((batch_size, max_cap), -100)
for b in range(batch_size):
cap = int(per_cap[b].item())
if cap > 0:
Lb = min(cap, labels.size(1))
adj[b, :Lb] = labels[b, :Lb]
loss = F.cross_entropy(codon_logits.reshape(-1, self.vocab_size), adj.reshape(-1), ignore_index=-100)
else:
loss = codon_logits.sum() * 0.0
out["loss"] = loss
# Provide optional debug stats for trainer logging
out["prefix_len"] = prefix_lengths.detach()
out["per_cap"] = per_cap.detach()
if use_cache:
out["present_kv"] = present_kv_list # type: ignore[assignment]
return out if return_dict else codon_logits