"""
Core model architectures for CodonTranslator.

- CodonTranslatorModel: decoder-only backbone with species + protein prefix.
  Includes a frozen ESM-C encoder for protein conditioning.
"""

import math
import os
from typing import Optional, Dict, Any, Tuple, List

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import torch.nn.utils.rnn as rnn_utils

from .layers import RMSNorm, TransformerBlock
from .tokenizer import SpecialIds

class FrozenESMCEncoder(nn.Module):
    """
    Frozen ESM-C encoder that computes protein embeddings on the fly.
    Kept on a single GPU per rank (not distributed via FSDP).
    """

    def __init__(self, model_name: str = "esmc_300m", device: str = "cuda", dtype: str = "fp16"):
        super().__init__()
        self.model_name = model_name
        self._device = torch.device(device if torch.cuda.is_available() else "cpu")
        if dtype == "fp16":
            self._autocast_dtype = torch.float16
        elif dtype == "bf16":
            self._autocast_dtype = torch.bfloat16
        else:
            self._autocast_dtype = None
        self._load_model()
        # Used purely as a feature extractor: keep in eval mode and freeze all parameters.
        self.eval()
        for p in self.parameters():
            p.requires_grad_(False)

    def _load_model(self):
        from esm.models.esmc import ESMC
        from esm.utils.constants.models import ESMC_300M, ESMC_600M

        if self.model_name == "esmc_300m":
            model_const = ESMC_300M
            self.D_esm = 960
        elif self.model_name == "esmc_600m":
            model_const = ESMC_600M
            self.D_esm = 1152
        else:
            raise ValueError(f"Unknown ESM-C model: {self.model_name}")
        self.model = ESMC.from_pretrained(model_name=model_const, device=self._device)
        self.tokenizer = self.model.tokenizer

    @torch.no_grad()
    def tokenize(self, sequences: List[str], max_length: Optional[int] = None, add_special_tokens: bool = True, return_tensors: str = "pt"):
        from esm.utils import encoding
        from esm.utils.misc import stack_variable_length_tensors

        pad = self.tokenizer.pad_token_id
        tokenized_seqs = []
        for seq in sequences:
            tokens = encoding.tokenize_sequence(seq, self.tokenizer, add_special_tokens=add_special_tokens)
            if max_length is not None and len(tokens) > max_length:
                tokens = tokens[:max_length]
            tokenized_seqs.append(tokens)
        # Right-pad to a rectangular batch; the mask marks non-pad positions.
        input_ids = stack_variable_length_tensors(tokenized_seqs, constant_value=pad)
        attention_mask = (input_ids != pad)
        return input_ids, attention_mask

    @torch.no_grad()
    def encode_from_ids(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.BoolTensor] = None, return_dict: bool = True, return_contacts: bool = False):
        device = self.model.device
        input_ids = input_ids.to(device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)
        # Run the frozen encoder, under autocast when a reduced-precision dtype was requested.
        if self._autocast_dtype is not None and device.type == "cuda":
            with torch.amp.autocast("cuda", dtype=self._autocast_dtype):
                outputs = self.model.forward(sequence_tokens=input_ids, sequence_id=attention_mask)
        else:
            outputs = self.model.forward(sequence_tokens=input_ids, sequence_id=attention_mask)
        embeddings = outputs.embeddings
        if return_dict:
            return {"embeddings": embeddings, "attention_mask": attention_mask}
        return embeddings

    def strip_special_tokens(self, embeddings: torch.FloatTensor, attention_mask: Optional[torch.BoolTensor] = None):
        # ESM-C prepends BOS and appends EOS; drop both and report per-sequence valid lengths.
        if attention_mask is not None:
            lengths = attention_mask.sum(dim=1) - 2
            lengths = lengths.clamp(min=1)
        else:
            B, L, D = embeddings.shape
            lengths = torch.full((B,), L - 2, device=embeddings.device)
        stripped = embeddings[:, 1:-1, :]
        return stripped, lengths
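
# Illustrative usage of FrozenESMCEncoder (a minimal sketch; the sequences below are
# hypothetical, and a working `esm` install plus a CUDA device are assumed):
#
#   encoder = FrozenESMCEncoder("esmc_300m", device="cuda", dtype="fp16")
#   ids, mask = encoder.tokenize(["MKTAYIAKQR", "MVLSPADKTN"])
#   out = encoder.encode_from_ids(ids, mask)                       # {"embeddings": [B, L, D_esm], ...}
#   emb, lengths = encoder.strip_special_tokens(out["embeddings"], out["attention_mask"])
#   # emb: [B, L - 2, D_esm] with BOS/EOS removed; lengths: valid residues per sequence.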


class CodonTranslatorModel(nn.Module):
    """
    Decoder-only codon language model conditioned on a LLaVA-style prefix of
    source/target species embeddings and frozen ESM-C protein embeddings.
    """

    def __init__(
        self,
        vocab_size: int = 79,
        hidden_size: int = 960,
        num_layers: int = 24,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        max_position_embeddings: int = 4096,
        dropout: float = 0.1,
        layer_norm_eps: float = 1e-6,
        num_special_tokens: int = 13,
        special_ids: Optional[SpecialIds] = None,
        esm_model_name: str = "esmc_300m",
        esm_device: str = "cuda",
        esm_dtype: str = "fp16",
        max_protein_prefix: int = 0,
        max_species_prefix: int = 0,
        prepend_species: bool = True,
        prepend_protein: bool = True,
        species_embedding_dim: int = 1024,
        attn_impl: str = "gqa",
        num_kv_groups: int = 0,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.max_position_embeddings = max_position_embeddings

        self.special_ids = special_ids or SpecialIds()
        self.num_special_tokens = num_special_tokens

        # Codon token embedding table.
        self.token_embed = nn.Embedding(vocab_size, hidden_size)

        # Frozen ESM-C encoder plus a projection from D_esm to the decoder hidden size.
        if prepend_protein and esm_model_name:
            self.esm = FrozenESMCEncoder(esm_model_name, esm_device, esm_dtype)
            self.esm_ln = nn.Sequential(
                nn.Linear(self.esm.D_esm, hidden_size, bias=False),
                nn.ReLU(),
                nn.LayerNorm(hidden_size),
            )
        else:
            self.esm = None
            self.esm_ln = None

        # Projection for fixed-size species embeddings into the decoder hidden size.
        self.species_embedding_dim = species_embedding_dim if prepend_species else 0
        if prepend_species:
            self.species_ln = nn.Sequential(
                nn.Linear(self.species_embedding_dim, hidden_size, bias=False),
                nn.ReLU(),
                nn.LayerNorm(hidden_size),
            )
        else:
            self.species_ln = None

        # Prefix budgets (0 disables truncation of the corresponding prefix component).
        self.max_protein_prefix = int(max_protein_prefix) if max_protein_prefix is not None else 0
        self.max_species_prefix = int(max_species_prefix) if max_species_prefix is not None else 0
        self.prepend_species = bool(prepend_species)
        self.prepend_protein = bool(prepend_protein)

        # Learned start-of-sequence embedding placed between the prefix and the codon tokens.
        self.start_embed = nn.Parameter(torch.zeros(1, 1, hidden_size))
        nn.init.normal_(self.start_embed, mean=0.0, std=0.02)

        # Decoder blocks (grouped-query attention by default, multi-head attention optional).
        self.attn_impl = str(attn_impl)
        self.num_kv_groups = int(num_kv_groups)
        kv_groups = self.num_kv_groups
        self.blocks = nn.ModuleList([
            TransformerBlock(
                dim=hidden_size,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                dropout=dropout,
                num_kv_groups=(kv_groups if (kv_groups > 0 and attn_impl == "gqa") else None),
                qk_norm=False,
                attn_type=("mha" if self.attn_impl == "mha" else "gqa"),
            ) for _ in range(num_layers)
        ])

        self.ln_f = RMSNorm(hidden_size, eps=layer_norm_eps)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.gradient_checkpointing = False
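
    # Illustrative construction (a sketch; the argument values below are assumptions,
    # not project defaults beyond those in the signature above):
    #
    #   model = CodonTranslatorModel(
    #       vocab_size=79,
    #       hidden_size=960,
    #       num_layers=24,
    #       prepend_species=True,
    #       prepend_protein=False,   # skip loading the frozen ESM-C encoder
    #       attn_impl="gqa",
    #       num_kv_groups=4,         # hypothetical KV-group count for GQA
    #   )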

    def embed_tokens(self, token_ids: torch.Tensor) -> torch.Tensor:
        # Look up codon embeddings, moving ids to the embedding table's device first.
        device = self.token_embed.weight.device
        return self.token_embed(token_ids.to(device))

    def build_prefix(
        self,
        batch_size: int,
        device: torch.device,
        species_tok_emb: Optional[torch.Tensor] = None,
        species_emb: Optional[torch.Tensor] = None,
        protein_input: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        species_tok_emb_src: Optional[torch.Tensor] = None,
        species_tok_emb_tgt: Optional[torch.Tensor] = None,
        species_emb_src: Optional[torch.Tensor] = None,
        species_emb_tgt: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build LLaVA-style prefix token embeddings by concatenating
        [species_src] + [species_tgt] + [protein_tokens].

        Returns:
            prefix: [B, Lp, H] prefix embeddings, zero-padded per sample
            prefix_lengths: [B] valid prefix token counts per sample
        """
        parts: List[torch.Tensor] = []

        # Species conditioning: per-sample token sequences or fixed-size vectors
        # projected to a single prefix token each.
        if self.prepend_species and self.species_ln is not None:
            tok_src = species_tok_emb_src if species_tok_emb_src is not None else species_tok_emb
            tok_tgt = species_tok_emb_tgt if species_tok_emb_tgt is not None else species_tok_emb
            emb_src = species_emb_src if species_emb_src is not None else species_emb
            emb_tgt = species_emb_tgt if species_emb_tgt is not None else species_emb

            def _as_tokens(S_tok, S_fix):
                if S_fix is not None:
                    # Fixed-size species vector -> one prefix token.
                    S = self.species_ln(S_fix.to(device=device, dtype=next(self.parameters()).dtype).unsqueeze(1))
                    return S
                elif S_tok is not None:
                    # Pre-embedded species token sequence, optionally truncated to its budget.
                    S = S_tok
                    if getattr(self, "max_species_prefix", 0) > 0 and S.size(1) > self.max_species_prefix:
                        S = S[:, : self.max_species_prefix, :]
                    S = S.to(device=device, dtype=next(self.parameters()).dtype)
                    S = self.species_ln(S)
                    return S
                else:
                    return None

            Ssrc = _as_tokens(tok_src, emb_src)
            if Ssrc is not None:
                parts.append(Ssrc)
            Sdst = _as_tokens(tok_tgt, emb_tgt)
            if Sdst is not None:
                parts.append(Sdst)

        # Protein conditioning: frozen ESM-C embeddings projected to the decoder hidden size.
        if self.prepend_protein and self.esm is not None and protein_input is not None:
            prot_ids, prot_mask = protein_input
            esm_out = self.esm.encode_from_ids(prot_ids, prot_mask, return_dict=True)
            P, lengths = self.esm.strip_special_tokens(esm_out["embeddings"], prot_mask)
            # Optionally truncate the protein prefix to its budget.
            if getattr(self, "max_protein_prefix", 0) > 0 and P.size(1) > self.max_protein_prefix:
                P = P[:, : self.max_protein_prefix, :]
                if lengths is not None:
                    lengths = lengths.clamp(max=self.max_protein_prefix)
            if P.size(1) > 0:
                P = self.esm_ln(P.to(device=device, dtype=next(self.parameters()).dtype))
                # Zero out padded positions so they can be detected when counting valid rows below.
                if lengths is not None:
                    Lp = P.size(1)
                    ar = torch.arange(Lp, device=device).unsqueeze(0)
                    lengths = lengths.to(device=device)
                    valid = ar < lengths.unsqueeze(1)
                    P = P * valid.unsqueeze(-1)
                parts.append(P)

        if len(parts) == 0:
            empty = torch.zeros(batch_size, 0, self.hidden_size, device=device, dtype=next(self.parameters()).dtype)
            return empty, torch.zeros(batch_size, dtype=torch.long, device=device)

        prefix = torch.cat(parts, dim=1)
        # Valid prefix length per sample is inferred from non-zero rows
        # (padded protein positions were zeroed above).
        with torch.no_grad():
            if prefix.size(1) > 0:
                valid = (prefix.abs().sum(dim=-1) > 0)
                lengths = valid.sum(dim=1).to(torch.long)
            else:
                lengths = torch.zeros(batch_size, dtype=torch.long, device=device)

        # Reserve one position for the START embedding; trim the prefix to fit the context window.
        prefix_budget = max(0, int(self.max_position_embeddings) - 1)
        if prefix_budget == 0:
            trimmed = prefix.new_zeros(prefix.size(0), 0, prefix.size(2))
            return trimmed, torch.zeros(prefix.size(0), dtype=torch.long, device=prefix.device)

        allow = torch.minimum(lengths, torch.tensor(prefix_budget, device=lengths.device, dtype=lengths.dtype))
        Lp_max = int(allow.max().item()) if allow.numel() > 0 else 0
        if prefix.size(1) > Lp_max:
            trimmed = prefix.new_zeros(prefix.size(0), Lp_max, prefix.size(2))
            for b in range(prefix.size(0)):
                lb = int(allow[b].item())
                if lb > 0:
                    trimmed[b, :lb, :] = prefix[b, :lb, :]
            prefix = trimmed
            lengths = allow
        else:
            lengths = allow
        return prefix, lengths
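
    # Sketch of the per-sample sequence fed to the decoder in forward() (illustrative):
    #
    #   [species_src][species_tgt][protein ...][START][codon_0 ... codon_{cap-1}]
    #    \________ prefix_lengths[b] ________/
    #
    # The output at the START position predicts codon_0, and the output at codon_t
    # predicts codon_{t+1}; forward() slices logits accordingly, so codon_logits[b, t]
    # scores codon t under teacher forcing.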

    def forward(
        self,
        codon_ids: torch.Tensor,
        cond: Optional[Dict[str, Any]] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        species_tok_emb: Optional[torch.Tensor] = None,
        protein_input: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        protein_seqs: Optional[List[str]] = None,
        # Incremental decoding arguments.
        use_cache: bool = False,
        past_kv: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        position_offset: int = 0,
    ) -> Dict[str, torch.Tensor]:
        """
        Teacher-forced scoring of `codon_ids` (with an optional loss from `labels`),
        or a single incremental decoding step when `past_kv` is provided.

        Returns a dict with "logits" ([B, T_codon, V] teacher-forcing logits),
        "next_logits" ([B, V]), and optionally "loss", "present_kv",
        "prefix_len", and "per_cap".
        """
        batch_size, codon_len = codon_ids.shape
        device = codon_ids.device

        # Unpack conditioning inputs; values in `cond` take precedence over keyword arguments.
        if cond is not None:
            control_mode = cond.get("control_mode", "fixed")
            species_tok_emb_src = cond.get("species_tok_emb_src")
            species_tok_emb_tgt = cond.get("species_tok_emb_tgt")
            species_emb_src = cond.get("species_emb_src")
            species_emb_tgt = cond.get("species_emb_tgt")
            species_tok_emb = cond.get("species_tok_emb")
            species_emb = cond.get("species_emb")
            protein_input = cond.get("protein_input")
            protein_seqs = cond.get("protein_seqs")
        else:
            species_emb = None
            species_tok_emb_src = None
            species_tok_emb_tgt = None
            species_emb_src = None
            species_emb_tgt = None

        # Tokenize raw protein sequences on the fly when only strings were provided.
        if protein_seqs is not None and protein_input is None:
            if self.esm is not None:
                with torch.no_grad():
                    # +2 accounts for the BOS/EOS tokens that are stripped again later.
                    max_len_tokens = (self.max_protein_prefix + 2) if (getattr(self, "max_protein_prefix", 0) > 0) else None
                    protein_input = self.esm.tokenize(protein_seqs, max_length=max_len_tokens)
            else:
                protein_input = None

        # Incremental decoding path: the prefix and earlier codons are already encoded in
        # `past_kv`, so only the newly provided codon tokens are run through the decoder.
        if past_kv is not None:
            if codon_ids.numel() == 0:
                # Nothing new to decode; return empty step logits.
                dummy = torch.zeros(batch_size, self.vocab_size, device=device, dtype=self.lm_head.weight.dtype)
                return {"logits": dummy[:, 0:0], "next_logits": dummy}

            x = self.embed_tokens(codon_ids)

            present_kv: List[Tuple[torch.Tensor, torch.Tensor]] = []
            for i, block in enumerate(self.blocks):
                kv_i = past_kv[i] if i < len(past_kv) else None
                if self.training and getattr(self, "gradient_checkpointing", False):
                    def _fn(inp):
                        return block(inp, past_kv=kv_i, use_cache=True, position_offset=position_offset)
                    out_blk = checkpoint.checkpoint(_fn, x, use_reentrant=False)
                else:
                    out_blk = block(x, past_kv=kv_i, use_cache=True, position_offset=position_offset)
                x, kv_out = out_blk
                present_kv.append(kv_out)

            x = self.ln_f(x)
            logits_step = self.lm_head(x)
            next_logits = logits_step[:, -1, :]
            out: Dict[str, torch.Tensor] = {"logits": logits_step[:, 0:0, :], "next_logits": next_logits}
            out["present_kv"] = present_kv
            return out if return_dict else logits_step[:, 0:0, :]

        # Build the conditioning prefix (species + protein embeddings).
        prefix, prefix_lengths = self.build_prefix(
            batch_size=batch_size,
            device=device,
            species_tok_emb=species_tok_emb,
            species_emb=species_emb if cond is not None else None,
            protein_input=protein_input,
            species_tok_emb_src=species_tok_emb_src,
            species_tok_emb_tgt=species_tok_emb_tgt,
            species_emb_src=species_emb_src,
            species_emb_tgt=species_emb_tgt,
        )

        start = self.start_embed.expand(batch_size, 1, self.hidden_size)

        # Count valid (non-pad) codon tokens per sample.
        pad_id = int(self.special_ids.pad) if hasattr(self, "special_ids") and self.special_ids is not None else 0
        codon_mask = (codon_ids != pad_id)
        codon_lens = codon_mask.sum(dim=1)

        # Context budget left for codons after the prefix and the START embedding.
        capacity = max(0, int(self.max_position_embeddings))
        budget_after_prefix = torch.clamp(
            torch.as_tensor(capacity, device=device) - (prefix_lengths + 1),
            min=0,
        )
        # Number of codon tokens that actually fit for each sample.
        per_cap = torch.minimum(budget_after_prefix, codon_lens)

        # Embed only as many codon positions as the longest sample needs.
        max_cap = int(per_cap.max().item()) if per_cap.numel() > 0 else 0
        if max_cap > 0:
            codon_emb = self.embed_tokens(codon_ids[:, :max_cap])
        else:
            codon_emb = torch.zeros(batch_size, 0, self.hidden_size, device=device, dtype=start.dtype)

        # Pack each sample as [prefix, START, codons] and right-pad to the batch max length.
        seqs = []
        for b in range(batch_size):
            lp = int(prefix_lengths[b].item())
            cap = int(per_cap[b].item())
            parts = []
            if lp > 0:
                parts.append(prefix[b, :lp, :])
            parts.append(start[b, 0:1, :])
            if cap > 0:
                parts.append(codon_emb[b, :cap, :])
            seqs.append(torch.cat(parts, dim=0))
        x = rnn_utils.pad_sequence(seqs, batch_first=True)

        present_kv_list: List[Tuple[torch.Tensor, torch.Tensor]] = []
        for block in self.blocks:
            if self.training and getattr(self, "gradient_checkpointing", False):
                def _fn(inp):
                    return block(inp, use_cache=use_cache, position_offset=0)
                blk_out = checkpoint.checkpoint(_fn, x, use_reentrant=False)
            else:
                blk_out = block(x, use_cache=use_cache, position_offset=0)
            if use_cache:
                x, kv = blk_out
                present_kv_list.append(kv)
            else:
                x = blk_out

        x = self.ln_f(x)
        logits_full = self.lm_head(x)

        # Slice per-sample codon logits out of the packed sequence: the output at
        # position lp (the START token) predicts codon 0, and position lp + t predicts codon t.
        next_logits_list = []
        if max_cap == 0:
            # No codon positions in this batch; only the next-token logits after START exist.
            codon_logits = logits_full[:, 0:0, :]
            for b in range(batch_size):
                lp = int(prefix_lengths[b].item())
                pos_next = lp
                if pos_next < logits_full.size(1):
                    next_logits_list.append(logits_full[b, pos_next, :])
                else:
                    next_logits_list.append(logits_full[b, -1, :])
            next_logits = torch.stack(next_logits_list, dim=0)
        else:
            slices = []
            for b in range(batch_size):
                lp = int(prefix_lengths[b].item())
                cap = int(per_cap[b].item())
                # Teacher-forcing logits for codons 0..cap-1 of sample b.
                sl = logits_full[b, lp : lp + cap, :] if cap > 0 else logits_full.new_zeros(0, self.vocab_size)
                slices.append(sl)
                # Logits for the token that would follow the last provided codon.
                pos_next = lp + cap
                next_logits_list.append(logits_full[b, pos_next, :] if pos_next < logits_full.size(1) else logits_full.new_zeros(self.vocab_size))
            codon_logits = rnn_utils.pad_sequence(slices, batch_first=True)
            next_logits = torch.stack(next_logits_list, dim=0)
        out = {"logits": codon_logits, "next_logits": next_logits}

        if labels is not None:
            # Align labels with the truncated codon positions; padded positions stay at -100.
            if labels.size(1) > 0 and max_cap > 0:
                adj = labels.new_full((batch_size, max_cap), -100)
                for b in range(batch_size):
                    cap = int(per_cap[b].item())
                    if cap > 0:
                        Lb = min(cap, labels.size(1))
                        adj[b, :Lb] = labels[b, :Lb]
                loss = F.cross_entropy(codon_logits.reshape(-1, self.vocab_size), adj.reshape(-1), ignore_index=-100)
            else:
                # Keep the graph connected even when there is nothing to score.
                loss = codon_logits.sum() * 0.0
            out["loss"] = loss

        out["prefix_len"] = prefix_lengths.detach()
        out["per_cap"] = per_cap.detach()
        if use_cache:
            out["present_kv"] = present_kv_list
        return out if return_dict else codon_logits
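
# Illustrative end-to-end usage (a minimal sketch; tensor shapes, dict keys, and the
# choice to condition on precomputed species vectors rather than ESM-C protein
# embeddings are assumptions for the example):
#
#   model = CodonTranslatorModel(prepend_protein=False).cuda()
#   codon_ids = torch.randint(13, 79, (2, 128), device="cuda")     # hypothetical codon token ids
#   cond = {
#       "species_emb_src": torch.randn(2, 1024, device="cuda"),    # fixed-size species embeddings
#       "species_emb_tgt": torch.randn(2, 1024, device="cuda"),
#   }
#   out = model(codon_ids, cond=cond, labels=codon_ids)
#   out["loss"].backward()                                         # teacher-forcing loss
#
#   # For generation, call once with use_cache=True to obtain out["present_kv"], then
#   # feed one new codon per step with past_kv=out["present_kv"] and position_offset
#   # advanced past the positions already consumed.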
|