"""
Core model architectures for CodonTranslator.
- CodonTranslatorModel: decoder-only backbone with species + protein prefix
Includes a frozen ESM-C encoder for protein conditioning.
"""
import math
import os
from typing import Optional, Dict, Any, Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import torch.nn.utils.rnn as rnn_utils
from .layers import RMSNorm, TransformerBlock
from .tokenizer import SpecialIds
class FrozenESMCEncoder(nn.Module):
    """Frozen ESM-C encoder that computes protein embeddings on the fly.

    Kept on single GPU per rank (not distributed via FSDP).

    All parameters have ``requires_grad`` disabled and the module is pinned to
    eval mode: ``train()`` is overridden so that a parent module switching to
    training mode does not flip this encoder out of eval.
    """

    def __init__(self, model_name: str = "esmc_300m", device: str = "cuda", dtype: str = "fp16"):
        """Load the pretrained encoder and freeze it.

        Args:
            model_name: ``"esmc_300m"`` or ``"esmc_600m"``.
            device: Requested device; replaced by ``"cpu"`` whenever CUDA is
                not available.
            dtype: ``"fp16"`` or ``"bf16"`` selects the CUDA autocast dtype used
                in ``encode_from_ids``; any other value disables autocast.

        Raises:
            ValueError: If ``model_name`` is not one of the supported names.
        """
        super().__init__()
        self.model_name = model_name
        self._device = torch.device(device if torch.cuda.is_available() else "cpu")
        # Unrecognised dtype strings map to None, i.e. no autocast.
        self._autocast_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16}.get(dtype)
        self._load_model()
        self.eval()
        for p in self.parameters():
            p.requires_grad_(False)

    def train(self, mode: bool = True):
        """Ignore ``mode`` and keep the encoder in eval mode.

        ``nn.Module.train`` recurses into children, so without this override a
        call to ``parent.train()`` would put the frozen encoder into training
        mode as well.
        """
        return super().train(False)

    def _load_model(self):
        """Instantiate the pretrained ESM-C model and record its embedding width."""
        # Imported lazily so the module can be imported without the `esm` package.
        from esm.models.esmc import ESMC
        from esm.utils.constants.models import ESMC_300M, ESMC_600M

        if self.model_name == "esmc_300m":
            model_const = ESMC_300M
            self.D_esm = 960
        elif self.model_name == "esmc_600m":
            model_const = ESMC_600M
            self.D_esm = 1152
        else:
            raise ValueError(f"Unknown model: {self.model_name}")
        self.model = ESMC.from_pretrained(model_name=model_const, device=self._device)
        self.tokenizer = self.model.tokenizer

    @torch.no_grad()
    def tokenize(self, sequences: List[str], max_length: Optional[int] = None, add_special_tokens: bool = True, return_tensors: str = "pt"):
        """Tokenize protein sequences into a padded id tensor and its mask.

        Args:
            sequences: Amino-acid strings.
            max_length: If given, each token sequence is cut to its first
                ``max_length`` tokens.
            add_special_tokens: Forwarded to the ESM tokenization helper.
            return_tensors: Accepted for interface compatibility; not used.

        Returns:
            ``(input_ids, attention_mask)`` where ``attention_mask`` is True at
            every position whose id differs from the tokenizer's pad id.
        """
        from esm.utils import encoding
        from esm.utils.misc import stack_variable_length_tensors

        pad = self.tokenizer.pad_token_id
        tokenized_seqs = []
        for seq in sequences:
            tokens = encoding.tokenize_sequence(seq, self.tokenizer, add_special_tokens=add_special_tokens)
            # NOTE(review): truncation keeps the leading tokens only, so a
            # trailing special token (if the tokenizer appended one) is dropped
            # for truncated sequences -- confirm this is intended.
            if max_length is not None and len(tokens) > max_length:
                tokens = tokens[:max_length]
            tokenized_seqs.append(tokens)
        input_ids = stack_variable_length_tensors(tokenized_seqs, constant_value=pad)
        attention_mask = (input_ids != pad)
        return input_ids, attention_mask

    @torch.no_grad()
    def encode_from_ids(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.BoolTensor] = None, return_dict: bool = True, return_contacts: bool = False):
        """Run the encoder on token ids and return its embeddings.

        Args:
            input_ids: Token ids; moved to the encoder's device.
            attention_mask: Optional mask; moved to the encoder's device and
                passed to the model as ``sequence_id``.
            return_dict: If True return a dict with ``"embeddings"`` and
                ``"attention_mask"``, otherwise only the embeddings tensor.
            return_contacts: Accepted for interface compatibility; not used.
        """
        from contextlib import nullcontext

        device = self.model.device
        input_ids = input_ids.to(device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)

        # Autocast only applies on CUDA and only when a reduced dtype was requested.
        use_autocast = self._autocast_dtype is not None and device.type == "cuda"
        ctx = torch.amp.autocast('cuda', dtype=self._autocast_dtype) if use_autocast else nullcontext()
        with ctx:
            outputs = self.model.forward(sequence_tokens=input_ids, sequence_id=attention_mask)

        embeddings = outputs.embeddings
        if return_dict:
            return {"embeddings": embeddings, "attention_mask": attention_mask}
        return embeddings

    def strip_special_tokens(self, embeddings: torch.FloatTensor, attention_mask: Optional[torch.BoolTensor] = None):
        """Drop the first and last position and report per-sample lengths.

        Assumes every sequence is framed by one leading and one trailing
        special token.

        Args:
            embeddings: Tensor indexed as ``[batch, position, feature]``.
            attention_mask: Optional mask; when given, each length is the number
                of True entries minus 2, floored at 1.

        Returns:
            ``(stripped, lengths)`` with ``stripped = embeddings[:, 1:-1, :]``.
            Without a mask every length is ``L - 2`` floored at 0, so a tensor
            too short to hold both special tokens never yields a negative length.
        """
        if attention_mask is not None:
            lengths = (attention_mask.sum(dim=1) - 2).clamp(min=1)
        else:
            B, L, _ = embeddings.shape
            lengths = torch.full((B,), max(L - 2, 0), device=embeddings.device)
        stripped = embeddings[:, 1:-1, :]
        return stripped, lengths
class CodonTranslatorModel(nn.Module):
    """Decoder-only codon model conditioned on a species + protein prefix.

    The input sequence fed to the transformer blocks is, per sample,
    ``[prefix tokens] + [learned start embedding] + [codon embeddings]``.
    The prefix is built from projected species embeddings and, when enabled,
    projected per-residue embeddings from a frozen ESM-C encoder.
    """

    def __init__(
        self,
        vocab_size: int = 79,
        hidden_size: int = 960,
        num_layers: int = 24,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        max_position_embeddings: int = 4096,
        dropout: float = 0.1,
        layer_norm_eps: float = 1e-6,
        num_special_tokens: int = 13,
        # Quoted so the annotation is not evaluated when the class is defined.
        special_ids: "Optional[SpecialIds]" = None,
        esm_model_name: str = "esmc_300m",
        esm_device: str = "cuda",
        esm_dtype: str = "fp16",
        max_protein_prefix: int = 0,
        max_species_prefix: int = 0,
        prepend_species: bool = True,
        prepend_protein: bool = True,
        species_embedding_dim: int = 1024,
        attn_impl: str = "gqa",  # "gqa" or "mha"
        num_kv_groups: int = 0,  # for GQA; 0 means default (no grouping)
    ):
        """Build the embedding table, prefix projections, blocks and LM head.

        Args:
            vocab_size: Size of the single token table (special + codon ids).
            hidden_size: Model width ``H``.
            num_layers: Number of transformer blocks.
            num_heads: Attention heads per block.
            mlp_ratio: Forwarded to each ``TransformerBlock``.
            max_position_embeddings: Global cap on prefix + start + codons.
            dropout: Forwarded to each ``TransformerBlock``.
            layer_norm_eps: Epsilon of the final ``RMSNorm``.
            num_special_tokens: Stored on the module; not otherwise used here.
            special_ids: Special-token ids; defaults to ``SpecialIds()``.
            esm_model_name / esm_device / esm_dtype: Forwarded to
                ``FrozenESMCEncoder`` when ``prepend_protein`` is set and the
                name is non-empty.
            max_protein_prefix: Cap on protein prefix tokens; 0 means unlimited.
            max_species_prefix: Cap on species prefix tokens; 0 means unlimited.
            prepend_species: Enable the species part of the prefix.
            prepend_protein: Enable the protein part of the prefix.
            species_embedding_dim: Input width of the species projection.
            attn_impl: ``"mha"`` selects MHA blocks; anything else selects GQA.
            num_kv_groups: KV group count, only passed on for GQA when > 0.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.max_position_embeddings = max_position_embeddings
        self.special_ids = special_ids or SpecialIds()
        self.num_special_tokens = num_special_tokens

        # Single embedding table for all tokens (special + codon)
        self.token_embed = nn.Embedding(vocab_size, hidden_size)

        if prepend_protein and esm_model_name:
            self.esm = FrozenESMCEncoder(esm_model_name, esm_device, esm_dtype)
            # Project ESM token embeddings (D_esm) to model hidden size, then normalize
            self.esm_ln = nn.Sequential(
                nn.Linear(self.esm.D_esm, hidden_size, bias=False),
                nn.ReLU(),
                nn.LayerNorm(hidden_size),
            )
        else:
            self.esm = None
            self.esm_ln = None

        self.species_embedding_dim = species_embedding_dim if prepend_species else 0
        if prepend_species:
            # Project species embeddings (fixed or token sequence) from Ds -> H
            self.species_ln = nn.Sequential(
                nn.Linear(self.species_embedding_dim, hidden_size, bias=False),
                nn.ReLU(),
                nn.LayerNorm(hidden_size),
            )
        else:
            self.species_ln = None

        # Optional per-prefix caps; 0 means unlimited (subject to global max length)
        self.max_protein_prefix = int(max_protein_prefix) if max_protein_prefix is not None else 0
        self.max_species_prefix = int(max_species_prefix) if max_species_prefix is not None else 0
        self.prepend_species = bool(prepend_species)
        self.prepend_protein = bool(prepend_protein)

        # Learned start embedding (BOS-less decoding)
        self.start_embed = nn.Parameter(torch.zeros(1, 1, hidden_size))
        nn.init.normal_(self.start_embed, mean=0.0, std=0.02)

        # Attention configuration
        self.attn_impl = str(attn_impl)
        self.num_kv_groups = int(num_kv_groups)
        kv_groups = self.num_kv_groups
        self.blocks = nn.ModuleList([
            TransformerBlock(
                dim=hidden_size,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                dropout=dropout,
                num_kv_groups=(kv_groups if (kv_groups > 0 and attn_impl == "gqa") else None),
                qk_norm=False,
                attn_type=("mha" if self.attn_impl == "mha" else "gqa"),
            ) for _ in range(num_layers)
        ])
        self.ln_f = RMSNorm(hidden_size, eps=layer_norm_eps)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.gradient_checkpointing = False

    def embed_tokens(self, token_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings, moving the ids to the table's device first."""
        device = self.token_embed.weight.device
        return self.token_embed(token_ids.to(device))

    def _project_species(
        self,
        tok_emb: Optional[torch.Tensor],
        fixed_emb: Optional[torch.Tensor],
        device: torch.device,
        dtype: torch.dtype,
    ) -> Optional[torch.Tensor]:
        """Project one species input to hidden size; a fixed embedding wins over tokens."""
        if fixed_emb is not None:
            # [B, Ds] -> [B, 1, H]
            return self.species_ln(fixed_emb.to(device=device, dtype=dtype).unsqueeze(1))
        if tok_emb is None:
            return None
        # [B, Ls, Ds] -> optional cap, then project to H
        cap = getattr(self, "max_species_prefix", 0)
        if cap > 0 and tok_emb.size(1) > cap:
            tok_emb = tok_emb[:, :cap, :]
        return self.species_ln(tok_emb.to(device=device, dtype=dtype))

    def _run_block(self, block, x: torch.Tensor, **block_kwargs):
        """Call one block, through activation checkpointing when enabled in training.

        The checkpointed function is re-executed during backward. Building the
        closure here, where ``block`` and ``block_kwargs`` are per-call locals,
        guarantees the recomputation uses the same block as the forward pass.
        A closure defined inside the layer loop would see the loop's final
        values instead.
        """
        if self.training and getattr(self, "gradient_checkpointing", False):
            def _fn(inp):
                return block(inp, **block_kwargs)
            return checkpoint.checkpoint(_fn, x, use_reentrant=False)
        return block(x, **block_kwargs)

    def build_prefix(
        self,
        batch_size: int,
        device: torch.device,
        species_tok_emb: Optional[torch.Tensor] = None,
        species_emb: Optional[torch.Tensor] = None,
        protein_input: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        species_tok_emb_src: Optional[torch.Tensor] = None,
        species_tok_emb_tgt: Optional[torch.Tensor] = None,
        species_emb_src: Optional[torch.Tensor] = None,
        species_emb_tgt: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build LLaVA-style prefix token embeddings by concatenating
        [species_src]+[species_tgt]+[protein_tokens].

        The ``*_src`` / ``*_tgt`` species inputs fall back to the un-suffixed
        ``species_tok_emb`` / ``species_emb``, so supplying only the un-suffixed
        input contributes it for both the source and the target slot.

        Returns:
          - prefix: [B, Lp, H]
          - prefix_lengths: [B] valid token counts per sample (rows that are
            entirely zero are counted as padding), capped at
            ``max_position_embeddings - 1`` to leave room for the start token.
        """
        param_dtype = next(self.parameters()).dtype
        parts: List[torch.Tensor] = []

        # Species: src then tgt (if provided)
        if self.prepend_species and self.species_ln is not None:
            tok_src = species_tok_emb_src if species_tok_emb_src is not None else species_tok_emb
            tok_tgt = species_tok_emb_tgt if species_tok_emb_tgt is not None else species_tok_emb
            emb_src = species_emb_src if species_emb_src is not None else species_emb
            emb_tgt = species_emb_tgt if species_emb_tgt is not None else species_emb
            for tok, fixed in ((tok_src, emb_src), (tok_tgt, emb_tgt)):
                projected = self._project_species(tok, fixed, device, param_dtype)
                if projected is not None:
                    parts.append(projected)

        # Protein tokens from ESM-C
        if self.prepend_protein and self.esm is not None and protein_input is not None:
            prot_ids, prot_mask = protein_input
            esm_out = self.esm.encode_from_ids(prot_ids, prot_mask, return_dict=True)
            P, lengths = self.esm.strip_special_tokens(esm_out["embeddings"], prot_mask)
            # Optional per-protein capping before projection
            if getattr(self, "max_protein_prefix", 0) > 0 and P.size(1) > self.max_protein_prefix:
                P = P[:, : self.max_protein_prefix, :]
                if lengths is not None:
                    lengths = lengths.clamp(max=self.max_protein_prefix)
            if P.size(1) > 0:
                P = self.esm_ln(P.to(device=device, dtype=param_dtype))
                # Zero padded rows (per-sample) based on lengths
                if lengths is not None:
                    positions = torch.arange(P.size(1), device=device).unsqueeze(0)
                    valid = positions < lengths.to(device=device).unsqueeze(1)  # [B,Lp]
                    P = P * valid.unsqueeze(-1)
                parts.append(P)

        if not parts:
            empty = torch.zeros(batch_size, 0, self.hidden_size, device=device, dtype=param_dtype)
            return empty, torch.zeros(batch_size, dtype=torch.long, device=device)

        prefix = torch.cat(parts, dim=1)  # [B,Lp,H]

        # Compute per-sample valid lengths: treat zero rows as padding
        with torch.no_grad():
            if prefix.size(1) > 0:
                lengths = (prefix.abs().sum(dim=-1) > 0).sum(dim=1).to(torch.long)
            else:
                lengths = torch.zeros(batch_size, dtype=torch.long, device=device)

        # ---- Enforce hard global budget on the prefix itself ----
        prefix_budget = max(0, int(self.max_position_embeddings) - 1)
        if prefix_budget == 0:
            trimmed = prefix.new_zeros(prefix.size(0), 0, prefix.size(2))
            return trimmed, torch.zeros(prefix.size(0), dtype=torch.long, device=prefix.device)

        allow = lengths.clamp(max=prefix_budget)
        Lp_max = int(allow.max().item()) if allow.numel() > 0 else 0
        if prefix.size(1) > Lp_max:
            # Keep the first allow[b] rows of every sample and zero the rest.
            keep = torch.arange(Lp_max, device=prefix.device).unsqueeze(0) < allow.to(prefix.device).unsqueeze(1)
            prefix = prefix[:, :Lp_max, :].masked_fill(~keep.unsqueeze(-1), 0)
        return prefix, allow

    def _forward_incremental(
        self,
        codon_ids: torch.Tensor,
        past_kv: List[Tuple[torch.Tensor, torch.Tensor]],
        position_offset: int,
        return_dict: bool,
    ):
        """Decode only the newly supplied codon tokens against a KV cache."""
        batch_size = codon_ids.size(0)
        device = codon_ids.device
        if codon_ids.numel() == 0:
            # Nothing to do; return a dummy next_logits
            dummy = torch.zeros(batch_size, self.vocab_size, device=device, dtype=self.lm_head.weight.dtype)
            return {"logits": dummy[:, 0:0], "next_logits": dummy}

        x = self.embed_tokens(codon_ids)  # [B, T_new, H]
        present_kv: List[Tuple[torch.Tensor, torch.Tensor]] = []
        for i, block in enumerate(self.blocks):
            kv_i = past_kv[i] if i < len(past_kv) else None
            x, kv_out = self._run_block(
                block, x, past_kv=kv_i, use_cache=True, position_offset=position_offset
            )
            present_kv.append(kv_out)

        logits_step = self.lm_head(self.ln_f(x))  # [B, T_new, V]
        empty_logits = logits_step[:, 0:0, :]
        out = {
            "logits": empty_logits,
            "next_logits": logits_step[:, -1, :],
            "present_kv": present_kv,
        }
        return out if return_dict else empty_logits

    def forward(
        self,
        codon_ids: torch.Tensor,
        cond: Optional[Dict[str, Any]] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        species_tok_emb: Optional[torch.Tensor] = None,
        protein_input: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        protein_seqs: Optional[List[str]] = None,
        # KV cache options
        use_cache: bool = False,
        past_kv: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        position_offset: int = 0,
    ) -> Dict[str, torch.Tensor]:
        """Run training/prefill over the full window, or one cached decode step.

        Args:
            codon_ids: ``[B, N]`` codon token ids; positions equal to the pad id
                are not counted towards a sample's codon length.
            cond: Optional conditioning dict. Keys read: ``species_tok_emb``,
                ``species_emb``, their ``_src`` / ``_tgt`` variants,
                ``protein_input`` and ``protein_seqs``.
            labels: Optional targets compared position-wise with ``logits``
                (``labels[b, j]`` against row ``j``); positions at or beyond a
                sample's codon cap are ignored.
            return_dict: If False, return only the codon logits tensor.
            species_tok_emb / protein_input / protein_seqs: Used when ``cond``
                is None or does not provide a non-None value for that key.
            use_cache: In the standard path, also return per-block KV entries.
            past_kv: If given, take the incremental path: only ``codon_ids`` is
                embedded and no prefix is built.
            position_offset: Forwarded to the blocks in the incremental path.

        Returns:
            Dict with ``"logits"`` and ``"next_logits"``; plus ``"loss"`` when
            labels are given, ``"prefix_len"`` and ``"per_cap"`` in the standard
            path, and ``"present_kv"`` when a cache is produced.
        """
        batch_size, _ = codon_ids.shape
        device = codon_ids.device

        # Fast path: incremental decode using KV cache (no conditioning needed)
        if past_kv is not None:
            return self._forward_incremental(codon_ids, past_kv, position_offset, return_dict)

        # Unpack conditioning
        species_emb = None
        species_tok_emb_src = None
        species_tok_emb_tgt = None
        species_emb_src = None
        species_emb_tgt = None
        if cond is not None:
            def _pick(key, fallback):
                # A missing/None entry must not clobber an explicit keyword argument.
                value = cond.get(key)
                return fallback if value is None else value

            species_tok_emb_src = cond.get("species_tok_emb_src")
            species_tok_emb_tgt = cond.get("species_tok_emb_tgt")
            species_emb_src = cond.get("species_emb_src")
            species_emb_tgt = cond.get("species_emb_tgt")
            species_emb = cond.get("species_emb")
            species_tok_emb = _pick("species_tok_emb", species_tok_emb)
            protein_input = _pick("protein_input", protein_input)
            protein_seqs = _pick("protein_seqs", protein_seqs)

        if protein_seqs is not None and protein_input is None and self.esm is not None:
            with torch.no_grad():
                # Respect per-protein ceiling during tokenization (+2 for BOS/EOS)
                max_len_tokens = (self.max_protein_prefix + 2) if (getattr(self, "max_protein_prefix", 0) > 0) else None
                protein_input = self.esm.tokenize(protein_seqs, max_length=max_len_tokens)

        # Standard path: build prefix and full window (training or prefill)
        prefix, prefix_lengths = self.build_prefix(
            batch_size=batch_size,
            device=device,
            species_tok_emb=species_tok_emb,
            species_emb=species_emb,
            protein_input=protein_input,
            species_tok_emb_src=species_tok_emb_src,
            species_tok_emb_tgt=species_tok_emb_tgt,
            species_emb_src=species_emb_src,
            species_emb_tgt=species_emb_tgt,
        )
        start = self.start_embed.expand(batch_size, 1, self.hidden_size)  # [B,1,H]

        # Per-sample true codon input lengths (exclude PADs)
        pad_id = int(self.special_ids.pad) if hasattr(self, "special_ids") and self.special_ids is not None else 0
        codon_lens = (codon_ids != pad_id).sum(dim=1)  # [B]

        # Budget remaining after prefix + start
        capacity = max(0, int(self.max_position_embeddings))
        budget_after_prefix = torch.clamp(
            torch.as_tensor(capacity, device=device) - (prefix_lengths + 1),
            min=0,
        )  # [B]
        # Per-sample cap is limited by both budget and available codons
        per_cap = torch.minimum(budget_after_prefix, codon_lens)  # [B]

        # Embed only the needed codon window for this batch
        max_cap = int(per_cap.max().item()) if per_cap.numel() > 0 else 0
        if max_cap > 0:
            codon_emb = self.embed_tokens(codon_ids[:, :max_cap])  # [B, max_cap, H]
        else:
            codon_emb = torch.zeros(batch_size, 0, self.hidden_size, device=device, dtype=start.dtype)

        # One host transfer per tensor instead of one `.item()` per sample.
        lp_list = prefix_lengths.tolist()
        cap_list = per_cap.tolist()

        # Build sequence per-sample using concat to preserve gradients, then pad
        seqs = []
        for b, (lp, cap) in enumerate(zip(lp_list, cap_list)):
            pieces = []
            if lp > 0:
                pieces.append(prefix[b, :lp, :])
            pieces.append(start[b, 0:1, :])
            if cap > 0:
                pieces.append(codon_emb[b, :cap, :])
            seqs.append(torch.cat(pieces, dim=0))  # [Lb, H]
        x = rnn_utils.pad_sequence(seqs, batch_first=True)  # [B, T, H]

        present_kv_list: List[Tuple[torch.Tensor, torch.Tensor]] = []
        for block in self.blocks:
            blk_out = self._run_block(block, x, use_cache=use_cache, position_offset=0)
            if use_cache:
                x, kv = blk_out
                present_kv_list.append(kv)
            else:
                x = blk_out

        x = self.ln_f(x)
        logits_full = self.lm_head(x)  # [B, T, V]
        total_len = logits_full.size(1)

        # Gather codon-aligned logits per sample from positions lp .. lp+cap-1.
        # Index lp holds the start token, so row j of a slice is the output at
        # the start token (j = 0) or at input codon j-1 (j > 0).
        if max_cap == 0:
            # Keep graph by slicing from logits_full
            codon_logits = logits_full[:, 0:0, :]
            # Last consumed position is the start token at index lp
            next_logits = torch.stack(
                [logits_full[b, lp if lp < total_len else -1, :] for b, lp in enumerate(lp_list)],
                dim=0,
            )
        else:
            slices = []
            next_logits_list = []
            for b, (lp, cap) in enumerate(zip(lp_list, cap_list)):
                if cap > 0:
                    slices.append(logits_full[b, lp : lp + cap, :])
                else:
                    slices.append(logits_full.new_zeros(0, self.vocab_size))
                # Next-token logits after processing 'cap' codons: last consumed is at lp + cap
                pos_next = lp + cap
                if pos_next < total_len:
                    next_logits_list.append(logits_full[b, pos_next, :])
                else:
                    next_logits_list.append(logits_full.new_zeros(self.vocab_size))
            codon_logits = rnn_utils.pad_sequence(slices, batch_first=True)  # [B,max_cap,V]
            next_logits = torch.stack(next_logits_list, dim=0)

        out = {"logits": codon_logits, "next_logits": next_logits}

        if labels is not None:
            if labels.size(1) > 0 and max_cap > 0:
                # Copy the available label columns, then ignore (-100) every
                # position at or beyond the sample's own codon cap.
                adj = labels.new_full((batch_size, max_cap), -100)
                width = min(max_cap, labels.size(1))
                adj[:, :width] = labels[:, :width]
                columns = torch.arange(max_cap, device=adj.device).unsqueeze(0)
                adj.masked_fill_(columns >= per_cap.to(adj.device).unsqueeze(1), -100)
                loss = F.cross_entropy(codon_logits.reshape(-1, self.vocab_size), adj.reshape(-1), ignore_index=-100)
            else:
                # Zero loss that is still connected to the graph
                loss = codon_logits.sum() * 0.0
            out["loss"] = loss

        # Provide optional debug stats for trainer logging
        out["prefix_len"] = prefix_lengths.detach()
        out["per_cap"] = per_cap.detach()
        if use_cache:
            out["present_kv"] = present_kv_list
        return out if return_dict else codon_logits
|