""" Core model architectures for CodonTranslator. - CodonTranslatorModel: decoder-only backbone with species + protein prefixing Includes a frozen ESM-C encoder for protein conditioning. """ import math import os from typing import Optional, Dict, Any, Tuple, List import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint as checkpoint import torch.nn.utils.rnn as rnn_utils from .layers import RMSNorm, TransformerBlock from .tokenizer import SpecialIds class FrozenESMCEncoder(nn.Module): """ Frozen ESM-C encoder that computes protein embeddings on the fly. Kept on single GPU per rank (not distributed via FSDP). """ def __init__(self, model_name: str = "esmc_300m", device: str = "cuda", dtype: str = "fp16"): super().__init__() self.model_name = model_name self._device = torch.device(device if torch.cuda.is_available() else "cpu") if dtype == "fp16": self._autocast_dtype = torch.float16 elif dtype == "bf16": self._autocast_dtype = torch.bfloat16 else: self._autocast_dtype = None self._load_model() self.eval() for p in self.parameters(): p.requires_grad_(False) def _load_model(self): from esm.models.esmc import ESMC from esm.utils.constants.models import ESMC_300M, ESMC_600M if self.model_name == "esmc_300m": model_const = ESMC_300M self.D_esm = 960 elif self.model_name == "esmc_600m": model_const = ESMC_600M self.D_esm = 1152 else: raise ValueError(f"Unknown model: {self.model_name}") self.model = ESMC.from_pretrained(model_name=model_const, device=self._device) self.tokenizer = self.model.tokenizer @torch.no_grad() def tokenize(self, sequences: List[str], max_length: Optional[int] = None, add_special_tokens: bool = True, return_tensors: str = "pt"): from esm.utils import encoding from esm.utils.misc import stack_variable_length_tensors pad = self.tokenizer.pad_token_id tokenized_seqs = [] for seq in sequences: tokens = encoding.tokenize_sequence(seq, self.tokenizer, add_special_tokens=add_special_tokens) if max_length is not None and len(tokens) > max_length: tokens = tokens[:max_length] tokenized_seqs.append(tokens) input_ids = stack_variable_length_tensors(tokenized_seqs, constant_value=pad) attention_mask = (input_ids != pad) return input_ids, attention_mask @torch.no_grad() def encode_from_ids(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.BoolTensor] = None, return_dict: bool = True, return_contacts: bool = False): device = self.model.device input_ids = input_ids.to(device) if attention_mask is not None: attention_mask = attention_mask.to(device) if self._autocast_dtype is not None and device.type == "cuda": with torch.amp.autocast('cuda', dtype=self._autocast_dtype): outputs = self.model.forward(sequence_tokens=input_ids, sequence_id=attention_mask) else: outputs = self.model.forward(sequence_tokens=input_ids, sequence_id=attention_mask) embeddings = outputs.embeddings if return_dict: return {"embeddings": embeddings, "attention_mask": attention_mask} else: return embeddings def strip_special_tokens(self, embeddings: torch.FloatTensor, attention_mask: Optional[torch.BoolTensor] = None): if attention_mask is not None: lengths = attention_mask.sum(dim=1) - 2 lengths = lengths.clamp(min=1) else: B, L, D = embeddings.shape lengths = torch.full((B,), L - 2, device=embeddings.device) stripped = embeddings[:, 1:-1, :] return stripped, lengths class CodonTranslatorModel(nn.Module): def __init__( self, vocab_size: int = 79, hidden_size: int = 960, num_layers: int = 24, num_heads: int = 16, mlp_ratio: float = 4.0, max_position_embeddings: 
class CodonTranslatorModel(nn.Module):
    """Decoder-only codon language model conditioned on species embeddings and an ESM-C protein prefix."""

    def __init__(
        self,
        vocab_size: int = 79,
        hidden_size: int = 960,
        num_layers: int = 24,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        max_position_embeddings: int = 4096,
        dropout: float = 0.1,
        layer_norm_eps: float = 1e-6,
        num_special_tokens: int = 13,
        special_ids: Optional[SpecialIds] = None,
        esm_model_name: str = "esmc_300m",
        esm_device: str = "cuda",
        esm_dtype: str = "fp16",
        max_protein_prefix: int = 0,
        max_species_prefix: int = 0,
        prepend_species: bool = True,
        prepend_protein: bool = True,
        species_embedding_dim: int = 1024,
        attn_impl: str = "gqa",  # "gqa" or "mha"
        num_kv_groups: int = 0,  # for GQA; 0 means default (no grouping)
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.max_position_embeddings = max_position_embeddings
        self.special_ids = special_ids or SpecialIds()
        self.num_special_tokens = num_special_tokens

        # Single embedding table for all tokens (special + codon)
        self.token_embed = nn.Embedding(vocab_size, hidden_size)

        if prepend_protein and esm_model_name:
            self.esm = FrozenESMCEncoder(esm_model_name, esm_device, esm_dtype)
            # Project ESM token embeddings (D_esm) to model hidden size, then normalize
            self.esm_ln = nn.Sequential(
                nn.Linear(self.esm.D_esm, hidden_size, bias=False),
                nn.ReLU(),
                nn.LayerNorm(hidden_size),
            )
        else:
            self.esm = None
            self.esm_ln = None

        self.species_embedding_dim = species_embedding_dim if prepend_species else 0
        if prepend_species:
            # Project species embeddings (fixed or token sequence) from Ds -> H
            self.species_ln = nn.Sequential(
                nn.Linear(self.species_embedding_dim, hidden_size, bias=False),
                nn.ReLU(),
                nn.LayerNorm(hidden_size),
            )
        else:
            self.species_ln = None

        # Optional per-prefix caps; 0 means unlimited (subject to global max length)
        self.max_protein_prefix = int(max_protein_prefix) if max_protein_prefix is not None else 0
        self.max_species_prefix = int(max_species_prefix) if max_species_prefix is not None else 0
        self.prepend_species = bool(prepend_species)
        self.prepend_protein = bool(prepend_protein)

        # Learned start embedding (BOS-less decoding)
        self.start_embed = nn.Parameter(torch.zeros(1, 1, hidden_size))
        nn.init.normal_(self.start_embed, mean=0.0, std=0.02)

        # Attention configuration
        self.attn_impl = str(attn_impl)
        self.num_kv_groups = int(num_kv_groups)
        kv_groups = self.num_kv_groups

        self.blocks = nn.ModuleList([
            TransformerBlock(
                dim=hidden_size,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                dropout=dropout,
                num_kv_groups=(kv_groups if (kv_groups > 0 and attn_impl == "gqa") else None),
                qk_norm=False,
                attn_type=("mha" if self.attn_impl == "mha" else "gqa"),
            )
            for _ in range(num_layers)
        ])
        self.ln_f = RMSNorm(hidden_size, eps=layer_norm_eps)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.gradient_checkpointing = False

    def embed_tokens(self, token_ids: torch.Tensor) -> torch.Tensor:
        device = self.token_embed.weight.device
        return self.token_embed(token_ids.to(device))
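
    # ----------------------------------------------------------------------- #
    # Input layout (a summary of build_prefix/forward below, added for
    # orientation): every training/prefill window is assembled per sample as
    #
    #     [species_src][species_tgt][protein tokens][START][codon tokens ...]
    #
    # where the species and protein parts form the "prefix", START is the
    # learned start embedding, and the whole window must fit inside
    # max_position_embeddings. For example, with max_position_embeddings=4096
    # and a 300-token prefix, at most 4096 - 300 - 1 = 3795 codon positions
    # remain for that sample.
    # ----------------------------------------------------------------------- #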
    def build_prefix(
        self,
        batch_size: int,
        device: torch.device,
        species_tok_emb: Optional[torch.Tensor] = None,
        species_emb: Optional[torch.Tensor] = None,
        protein_input: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        species_tok_emb_src: Optional[torch.Tensor] = None,
        species_tok_emb_tgt: Optional[torch.Tensor] = None,
        species_emb_src: Optional[torch.Tensor] = None,
        species_emb_tgt: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build LLaVA-style prefix token embeddings by concatenating
        [species_src] + [species_tgt] + [protein_tokens].

        Returns:
            - prefix: [B, Lp, H]
            - prefix_lengths: [B] valid token counts per sample
        """
        parts: List[torch.Tensor] = []

        # Species: src then tgt (if provided)
        if self.prepend_species and self.species_ln is not None:
            tok_src = species_tok_emb_src if species_tok_emb_src is not None else species_tok_emb
            tok_tgt = species_tok_emb_tgt if species_tok_emb_tgt is not None else species_tok_emb
            emb_src = species_emb_src if species_emb_src is not None else species_emb
            emb_tgt = species_emb_tgt if species_emb_tgt is not None else species_emb

            def _as_tokens(S_tok, S_fix):
                if S_fix is not None:
                    # [B, Ds] -> [B, 1, H]
                    S = self.species_ln(S_fix.to(device=device, dtype=next(self.parameters()).dtype).unsqueeze(1))
                    return S
                elif S_tok is not None:
                    # [B, Ls, Ds] -> optional cap, then project to H
                    S = S_tok
                    if getattr(self, "max_species_prefix", 0) > 0 and S.size(1) > self.max_species_prefix:
                        S = S[:, : self.max_species_prefix, :]
                    S = S.to(device=device, dtype=next(self.parameters()).dtype)
                    S = self.species_ln(S)
                    return S
                else:
                    return None

            Ssrc = _as_tokens(tok_src, emb_src)
            if Ssrc is not None:
                parts.append(Ssrc)
            Sdst = _as_tokens(tok_tgt, emb_tgt)
            if Sdst is not None:
                parts.append(Sdst)

        # Protein tokens from ESM-C
        if self.prepend_protein and self.esm is not None and protein_input is not None:
            prot_ids, prot_mask = protein_input
            esm_out = self.esm.encode_from_ids(prot_ids, prot_mask, return_dict=True)
            P, lengths = self.esm.strip_special_tokens(esm_out["embeddings"], prot_mask)
            # Optional per-protein capping before projection
            if getattr(self, "max_protein_prefix", 0) > 0 and P.size(1) > self.max_protein_prefix:
                P = P[:, : self.max_protein_prefix, :]
                if lengths is not None:
                    lengths = lengths.clamp(max=self.max_protein_prefix)
            if P.size(1) > 0:
                P = self.esm_ln(P.to(device=device, dtype=next(self.parameters()).dtype))
                # Zero padded rows (per-sample) based on lengths
                if lengths is not None:
                    Lp = P.size(1)
                    ar = torch.arange(Lp, device=device).unsqueeze(0)
                    lengths = lengths.to(device=device)
                    valid = ar < lengths.unsqueeze(1)  # [B, Lp]
                    P = P * valid.unsqueeze(-1)
                parts.append(P)

        if len(parts) == 0:
            empty = torch.zeros(batch_size, 0, self.hidden_size, device=device, dtype=next(self.parameters()).dtype)
            return empty, torch.zeros(batch_size, dtype=torch.long, device=device)

        prefix = torch.cat(parts, dim=1)  # [B, Lp, H]

        # Compute per-sample valid lengths: treat zero rows as padding
        with torch.no_grad():
            if prefix.size(1) > 0:
                valid = (prefix.abs().sum(dim=-1) > 0)
                lengths = valid.sum(dim=1).to(torch.long)
            else:
                lengths = torch.zeros(batch_size, dtype=torch.long, device=device)

        # ---- Enforce hard global budget on the prefix itself ----
        prefix_budget = max(0, int(self.max_position_embeddings) - 1)
        if prefix_budget == 0:
            trimmed = prefix.new_zeros(prefix.size(0), 0, prefix.size(2))
            return trimmed, torch.zeros(prefix.size(0), dtype=torch.long, device=prefix.device)
        allow = torch.minimum(lengths, torch.tensor(prefix_budget, device=lengths.device, dtype=lengths.dtype))
        Lp_max = int(allow.max().item()) if allow.numel() > 0 else 0
        if prefix.size(1) > Lp_max:
            trimmed = prefix.new_zeros(prefix.size(0), Lp_max, prefix.size(2))
            for b in range(prefix.size(0)):
                lb = int(allow[b].item())
                if lb > 0:
                    trimmed[b, :lb, :] = prefix[b, :lb, :]
            prefix = trimmed
            lengths = allow
        else:
            lengths = allow
        return prefix, lengths
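
    # ----------------------------------------------------------------------- #
    # Worked example for build_prefix (illustrative numbers, not taken from any
    # real batch): with fixed species embeddings for source and target (one
    # projected token each) and two proteins of 120 and 95 residues, the
    # protein part is padded to 120 rows, so
    #     prefix         -> [2, 2 + 120, H] = [2, 122, H]
    #     prefix_lengths -> [122, 97]
    # and the 25 padded protein rows of the shorter sample are zeroed so the
    # zero-row length detection above counts only valid positions.
    # ----------------------------------------------------------------------- #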
    def forward(
        self,
        codon_ids: torch.Tensor,
        cond: Optional[Dict[str, Any]] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        species_tok_emb: Optional[torch.Tensor] = None,
        protein_input: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        protein_seqs: Optional[List[str]] = None,
        # KV cache options
        use_cache: bool = False,
        past_kv: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        position_offset: int = 0,
    ) -> Dict[str, torch.Tensor]:
        """Decode over [prefix][START][codon window], or over new tokens only when past_kv is given.

        Returns a dict with "logits" (codon-aligned, [B, max_cap, V]; empty on cached steps),
        "next_logits" ([B, V]), plus "loss" when labels are given and "present_kv" when caching.
        """
        batch_size, codon_len = codon_ids.shape
        device = codon_ids.device

        # Unpack conditioning
        if cond is not None:
            control_mode = cond.get("control_mode", "fixed")
            species_tok_emb_src = cond.get("species_tok_emb_src")
            species_tok_emb_tgt = cond.get("species_tok_emb_tgt")
            species_emb_src = cond.get("species_emb_src")
            species_emb_tgt = cond.get("species_emb_tgt")
            species_tok_emb = cond.get("species_tok_emb")
            species_emb = cond.get("species_emb")
            protein_input = cond.get("protein_input")
            protein_seqs = cond.get("protein_seqs")
        else:
            species_emb = None
            species_tok_emb_src = None
            species_tok_emb_tgt = None
            species_emb_src = None
            species_emb_tgt = None

        if protein_seqs is not None and protein_input is None:
            if self.esm is not None:
                with torch.no_grad():
                    # Respect per-protein ceiling during tokenization (+2 for BOS/EOS)
                    max_len_tokens = (self.max_protein_prefix + 2) if (getattr(self, "max_protein_prefix", 0) > 0) else None
                    protein_input = self.esm.tokenize(protein_seqs, max_length=max_len_tokens)
            else:
                protein_input = None

        # Fast path: incremental decode using KV cache
        if past_kv is not None:
            # Expect only newly generated codon tokens here
            if codon_ids.numel() == 0:
                # Nothing to do; keep "logits" empty but 3-D for shape consistency
                dummy = torch.zeros(batch_size, self.vocab_size, device=device, dtype=self.lm_head.weight.dtype)
                return {"logits": dummy.new_zeros(batch_size, 0, self.vocab_size), "next_logits": dummy}
            x = self.embed_tokens(codon_ids)  # [B, T_new, H]
            present_kv: List[Tuple[torch.Tensor, torch.Tensor]] = []
            for i, block in enumerate(self.blocks):
                kv_i = past_kv[i] if i < len(past_kv) else None
                if self.training and getattr(self, 'gradient_checkpointing', False):
                    # Bind loop variables as defaults so checkpoint recomputation
                    # reruns this layer's block/cache rather than the last one's.
                    def _fn(inp, blk=block, kv=kv_i):
                        return blk(inp, past_kv=kv, use_cache=True, position_offset=position_offset)
                    out_blk = checkpoint.checkpoint(_fn, x, use_reentrant=False)
                else:
                    out_blk = block(x, past_kv=kv_i, use_cache=True, position_offset=position_offset)
                x, kv_out = out_blk  # type: ignore[assignment]
                present_kv.append(kv_out)
            x = self.ln_f(x)
            logits_step = self.lm_head(x)  # [B, T_new, V]
            next_logits = logits_step[:, -1, :]
            out: Dict[str, torch.Tensor] = {"logits": logits_step[:, 0:0, :], "next_logits": next_logits}
            out["present_kv"] = present_kv  # type: ignore[assignment]
            return out if return_dict else logits_step[:, 0:0, :]
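
        # ------------------------------------------------------------------- #
        # Cached decoding contract (summary of the fast path above; the exact
        # positional handling lives in TransformerBlock): a prefill call with
        # use_cache=True returns out["present_kv"], a per-layer list of (k, v)
        # tensors covering the prefill window (prefix + START + any codons
        # fed). Each subsequent call passes only the newly sampled codon ids as
        # codon_ids, the previous present_kv as past_kv, and position_offset
        # set to the number of positions already held in the cache.
        # ------------------------------------------------------------------- #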
        # Standard path: build prefix and full window (training or prefill)
        prefix, prefix_lengths = self.build_prefix(
            batch_size=batch_size,
            device=device,
            species_tok_emb=species_tok_emb,
            species_emb=species_emb if cond is not None else None,
            protein_input=protein_input,
            species_tok_emb_src=species_tok_emb_src,
            species_tok_emb_tgt=species_tok_emb_tgt,
            species_emb_src=species_emb_src,
            species_emb_tgt=species_emb_tgt,
        )
        start = self.start_embed.expand(batch_size, 1, self.hidden_size)  # [B, 1, H]

        # Per-sample true codon input lengths (exclude PADs)
        pad_id = int(self.special_ids.pad) if hasattr(self, "special_ids") and self.special_ids is not None else 0
        codon_mask = (codon_ids != pad_id)  # [B, N]
        codon_lens = codon_mask.sum(dim=1)  # [B]

        # Budget remaining after prefix + start
        capacity = max(0, int(self.max_position_embeddings))
        budget_after_prefix = torch.clamp(
            torch.as_tensor(capacity, device=device) - (prefix_lengths + 1),
            min=0,
        )  # [B]
        # Per-sample cap is limited by both budget and available codons
        per_cap = torch.minimum(budget_after_prefix, codon_lens)  # [B]
        # Total valid lengths per sample (prefix + start + capped codon)
        valid_lengths = prefix_lengths + 1 + per_cap
        T = int(valid_lengths.max().item()) if valid_lengths.numel() > 0 else (
            1 + int(codon_lens.max().item()) if codon_lens.numel() > 0 else 1
        )

        # Embed only the needed codon window for this batch
        max_cap = int(per_cap.max().item()) if per_cap.numel() > 0 else 0
        if max_cap > 0:
            codon_emb = self.embed_tokens(codon_ids[:, :max_cap])  # [B, max_cap, H]
        else:
            codon_emb = torch.zeros(batch_size, 0, self.hidden_size, device=device, dtype=start.dtype)

        # Build sequence per-sample using concat to preserve gradients, then pad
        seqs = []
        for b in range(batch_size):
            lp = int(prefix_lengths[b].item())
            cap = int(per_cap[b].item())
            parts = []
            if lp > 0:
                parts.append(prefix[b, :lp, :])
            parts.append(start[b, 0:1, :])
            if cap > 0:
                parts.append(codon_emb[b, :cap, :])
            seqs.append(torch.cat(parts, dim=0))  # [Lb, H]
        x = rnn_utils.pad_sequence(seqs, batch_first=True)  # [B, T, H]

        present_kv_list: List[Tuple[torch.Tensor, torch.Tensor]] = []
        for block in self.blocks:
            if self.training and getattr(self, 'gradient_checkpointing', False):
                # Bind the loop variable as a default so checkpoint recomputation
                # reruns this layer rather than the last one in the loop.
                def _fn(inp, blk=block):
                    return blk(inp, use_cache=use_cache, position_offset=0)
                blk_out = checkpoint.checkpoint(_fn, x, use_reentrant=False)
            else:
                blk_out = block(x, use_cache=use_cache, position_offset=0)
            if use_cache:
                x, kv = blk_out  # type: ignore[misc]
                present_kv_list.append(kv)
            else:
                x = blk_out  # type: ignore[assignment]
        x = self.ln_f(x)
        logits_full = self.lm_head(x)  # [B, T, V]

        # Gather codon-aligned logits per sample: positions lp .. (lp + cap - 1),
        # so the logit at the START position is the prediction for the first codon
        next_logits_list = []
        if max_cap == 0:
            # Keep graph by slicing from logits_full
            codon_logits = logits_full[:, 0:0, :]
            for b in range(batch_size):
                lp = int(prefix_lengths[b].item())
                # Last consumed position is the start token at index lp
                pos_next = lp
                if pos_next < logits_full.size(1):
                    next_logits_list.append(logits_full[b, pos_next, :])
                else:
                    next_logits_list.append(logits_full[b, -1, :])
            next_logits = torch.stack(next_logits_list, dim=0)
        else:
            slices = []
            for b in range(batch_size):
                lp = int(prefix_lengths[b].item())
                cap = int(per_cap[b].item())
                # Slice begins at the START position, so codon_logits[b, j] is the
                # prediction for the j-th codon of the window
                sl = logits_full[b, lp : lp + cap, :] if cap > 0 else logits_full.new_zeros(0, self.vocab_size)
                slices.append(sl)
                # Next-token logits after processing 'cap' codons: last consumed is at lp + cap
                pos_next = lp + cap
                next_logits_list.append(
                    logits_full[b, pos_next, :] if pos_next < logits_full.size(1)
                    else logits_full.new_zeros(self.vocab_size)
                )
            codon_logits = rnn_utils.pad_sequence(slices, batch_first=True)  # [B, max_cap, V]
            next_logits = torch.stack(next_logits_list, dim=0)

        out = {"logits": codon_logits, "next_logits": next_logits}

        if labels is not None:
            # Align labels to per-sample caps: mask out positions >= cap
            if labels.size(1) > 0 and max_cap > 0:
                # Build masked labels with -100 beyond cap per sample
                adj = labels.new_full((batch_size, max_cap), -100)
                for b in range(batch_size):
                    cap = int(per_cap[b].item())
                    if cap > 0:
                        Lb = min(cap, labels.size(1))
                        adj[b, :Lb] = labels[b, :Lb]
                loss = F.cross_entropy(codon_logits.reshape(-1, self.vocab_size), adj.reshape(-1), ignore_index=-100)
            else:
                loss = codon_logits.sum() * 0.0
            out["loss"] = loss
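
        # --------------------------------------------------------------- #
        # Worked example (illustrative numbers): with
        # max_position_embeddings=4096, prefix_lengths=[122, 97] and
        # codon_lens=[1500, 4200], the per-sample caps are
        # min(4096 - 122 - 1, 1500) = 1500 and min(4096 - 97 - 1, 4200) = 3998,
        # so max_cap = 3998 and codon_logits is [2, 3998, V]; label positions
        # beyond each cap are filled with -100 and ignored by the cross-entropy.
        # --------------------------------------------------------------- #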
        # Provide optional debug stats for trainer logging
        out["prefix_len"] = prefix_lengths.detach()
        out["per_cap"] = per_cap.detach()
        if use_cache:
            out["present_kv"] = present_kv_list  # type: ignore[assignment]
        return out if return_dict else codon_logits
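

# --------------------------------------------------------------------------- #
# Minimal end-to-end sketch (illustrative only, not executed on import): the
# conditioning keys and output fields mirror forward() above, but the species
# vectors, protein sequence, codon ids, and the codon-aligned label convention
# are assumptions for demonstration, not values from the real pipeline.
# --------------------------------------------------------------------------- #
def _codon_translator_usage_sketch():  # pragma: no cover - documentation example
    model = CodonTranslatorModel(num_layers=2)
    B = 1
    cond = {
        "species_emb_src": torch.randn(B, 1024),  # fixed source-species embedding [B, Ds]
        "species_emb_tgt": torch.randn(B, 1024),  # fixed target-species embedding [B, Ds]
        "protein_seqs": ["MKTAYIAKQR"],           # hypothetical protein; encoded by the frozen ESM-C
    }
    # Hypothetical codon window; ids >= num_special_tokens=13 assumed to be codons.
    codon_ids = torch.randint(13, 79, (B, 32))
    # Training-style call: codon-aligned logits plus a masked cross-entropy loss
    # (labels here follow the alignment used by the logit slice in forward()).
    out = model(codon_ids, cond=cond, labels=codon_ids)
    loss = out["loss"]
    # Generation-style call: prefill with prefix + START only, then decode one
    # codon at a time against the returned KV cache.
    prefill = model(torch.zeros(B, 0, dtype=torch.long), cond=cond, use_cache=True)
    past_kv = prefill["present_kv"]
    next_id = prefill["next_logits"].argmax(dim=-1, keepdim=True)
    offset = int(prefill["prefix_len"][0].item()) + 1  # prefix + START already in the cache
    step = model(next_id, past_kv=past_kv, position_offset=offset)
    return loss, step["next_logits"]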