diff --git a/CodonTranslator/__init__.py b/CodonTranslator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3cf62dbb572174a737e4e34ce44b68e9592d7d77 --- /dev/null +++ b/CodonTranslator/__init__.py @@ -0,0 +1,4 @@ +from .translator import CodonTranslator + +__all__ = ["CodonTranslator"] + diff --git a/CodonTranslator/__pycache__/__init__.cpython-312.pyc b/CodonTranslator/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b8765e93aee39a45404725bdf97b2cbf0989fd5 Binary files /dev/null and b/CodonTranslator/__pycache__/__init__.cpython-312.pyc differ diff --git a/CodonTranslator/__pycache__/layers.cpython-312.pyc b/CodonTranslator/__pycache__/layers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94cf39683e2fc389c80d6c57d1ec92dbfc85f7fd Binary files /dev/null and b/CodonTranslator/__pycache__/layers.cpython-312.pyc differ diff --git a/CodonTranslator/__pycache__/models.cpython-312.pyc b/CodonTranslator/__pycache__/models.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c37d1369ec5c27dedaef533ef2e67a95dd02fa08 Binary files /dev/null and b/CodonTranslator/__pycache__/models.cpython-312.pyc differ diff --git a/CodonTranslator/__pycache__/tokenizer.cpython-312.pyc b/CodonTranslator/__pycache__/tokenizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f7dfb0a761ecefeeabce7f3f4676f7a3f9093b9 Binary files /dev/null and b/CodonTranslator/__pycache__/tokenizer.cpython-312.pyc differ diff --git a/CodonTranslator/__pycache__/translator.cpython-312.pyc b/CodonTranslator/__pycache__/translator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28e059b12f510b3f5d6c277aceffd898b9d98ff3 Binary files /dev/null and b/CodonTranslator/__pycache__/translator.cpython-312.pyc differ diff --git a/CodonTranslator/layers.py b/CodonTranslator/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..2f460d0637f7cfcef3b1b1a940eb8e35dc992c4a --- /dev/null +++ b/CodonTranslator/layers.py @@ -0,0 +1,239 @@ +# Minimal attention/norm/FFN blocks used by the translator backbone +from __future__ import annotations + +import math +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.attention import SDPBackend, sdpa_kernel + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + return x * norm * self.weight + + +def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: + x1 = x[..., ::2] + x2 = x[..., 1::2] + x_rot = torch.zeros_like(x) + x_rot[..., ::2] = -x2 + x_rot[..., 1::2] = x1 + return x * cos + x_rot * sin + + +class GroupedQueryAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_groups: int, dropout: float = 0.0, qk_norm: bool = False): + super().__init__() + assert num_heads % max(1, num_kv_groups) == 0 + self.dim = dim + self.num_heads = int(num_heads) + self.num_kv_groups = max(1, int(num_kv_groups)) + self.group_size = self.num_heads // self.num_kv_groups + assert dim % num_heads == 0 + self.head_dim = dim // num_heads + self.dropout = dropout + + self.Wq = nn.Linear(dim, self.num_heads * self.head_dim, bias=False) + 
self.Wk = nn.Linear(dim, self.num_kv_groups * self.head_dim, bias=False) + self.Wv = nn.Linear(dim, self.num_kv_groups * self.head_dim, bias=False) + self.out_proj = nn.Linear(self.num_heads * self.head_dim, dim, bias=False) + + self.q_norm = RMSNorm(self.head_dim) if qk_norm else None + self.k_norm = RMSNorm(self.head_dim) if qk_norm else None + + self._rope_cache: dict[tuple[int, torch.device, torch.dtype], tuple[torch.Tensor, torch.Tensor]] = {} + + def _rope_cos_sin(self, T: int, device: torch.device, dtype: torch.dtype): + key = (T, device, dtype) + cached = self._rope_cache.get(key) + if cached is not None: + return cached + dim_half = self.head_dim // 2 + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_half, device=device, dtype=torch.float32) / dim_half)) + t = torch.arange(T, device=device, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + cos = torch.cos(freqs).repeat_interleave(2, dim=-1) + sin = torch.sin(freqs).repeat_interleave(2, dim=-1) + cos = cos.to(dtype).unsqueeze(0).unsqueeze(0) + sin = sin.to(dtype).unsqueeze(0).unsqueeze(0) + self._rope_cache[key] = (cos, sin) + return cos, sin + + def forward(self, x: torch.Tensor, past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = False, position_offset: int | torch.Tensor = 0): + B, T_new, _ = x.shape + q = self.Wq(x).view(B, T_new, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + k = self.Wk(x).view(B, T_new, self.num_kv_groups, self.head_dim).transpose(1, 2).contiguous() + v = self.Wv(x).view(B, T_new, self.num_kv_groups, self.head_dim).transpose(1, 2).contiguous() + + if self.q_norm is not None: + q = self.q_norm(q) + if self.k_norm is not None: + k = self.k_norm(k) + + if isinstance(position_offset, int): + cos, sin = self._rope_cos_sin(position_offset + T_new, x.device, q.dtype) + if position_offset > 0: + cos = cos[:, :, position_offset: position_offset + T_new, :] + sin = sin[:, :, position_offset: position_offset + T_new, :] + q = _apply_rope(q, cos, sin) + k = _apply_rope(k, cos, sin) + else: + off = position_offset.to(device=x.device, dtype=torch.long) + max_off = int(off.max().item()) + cos_all, sin_all = self._rope_cos_sin(max_off + T_new, x.device, q.dtype) + ar = torch.arange(T_new, device=x.device, dtype=torch.long) + idx = (off.unsqueeze(1) + ar.unsqueeze(0)) + cos_b = cos_all.squeeze(0).squeeze(0)[idx].unsqueeze(1) + sin_b = sin_all.squeeze(0).squeeze(0)[idx].unsqueeze(1) + q = _apply_rope(q, cos_b, sin_b) + k = _apply_rope(k, cos_b, sin_b) + + if past_kv is not None: + k_p, v_p = past_kv + k = torch.cat([k_p, k], dim=2) + v = torch.cat([v_p, v], dim=2) + + is_causal = past_kv is None + # Prefer Flash, then MemEff, then Math; allow FP32 via Math + with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]): + if x.device.type == "cuda" and q.dtype not in (torch.float16, torch.bfloat16): + amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + with torch.amp.autocast(device_type="cuda", dtype=amp_dtype): + out = F.scaled_dot_product_attention( + q, k, v, dropout_p=self.dropout if self.training else 0.0, is_causal=is_causal + ) + else: + out = F.scaled_dot_product_attention( + q, k, v, dropout_p=self.dropout if self.training else 0.0, is_causal=is_causal + ) + out = out.transpose(1, 2).contiguous().view(B, T_new, self.dim) + out = self.out_proj(out) + if use_cache: + return out, (k, v) + return out + + +class SwiGLU(nn.Module): + """SwiGLU FFN with parameter names matching checkpoints (w1, w2, w3): + - w1: 
Linear(dim -> hidden) + - w2: Linear(hidden -> dim) + - w3: Linear(dim -> hidden) + Forward: w2(silu(w1(x)) * w3(x)) + """ + def __init__(self, dim: int, hidden_mult: float = 4.0, dropout: float = 0.0): + super().__init__() + hidden = int(dim * hidden_mult) + self.w1 = nn.Linear(dim, hidden, bias=False) + self.w2 = nn.Linear(hidden, dim, bias=False) + self.w3 = nn.Linear(dim, hidden, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w2(self.dropout(F.silu(self.w1(x)) * self.w3(x))) + + +class TransformerBlock(nn.Module): + def __init__(self, dim: int, num_heads: int, mlp_ratio: float, dropout: float = 0.0, num_kv_groups: Optional[int] = None, qk_norm: bool = False, attn_type: str = "mha"): + super().__init__() + if attn_type == "gqa": + self.attn = GroupedQueryAttention(dim, num_heads=num_heads, num_kv_groups=(num_kv_groups or num_heads), dropout=dropout) + else: + self.attn = MultiHeadAttention(dim=dim, num_heads=num_heads, dropout=dropout) + self.ffn = SwiGLU(dim, hidden_mult=mlp_ratio, dropout=dropout) + self.ln1 = RMSNorm(dim) + self.ln2 = RMSNorm(dim) + + def forward(self, x: torch.Tensor, past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = False, position_offset: int = 0): + a = self.attn(self.ln1(x), past_kv=past_kv, use_cache=use_cache, position_offset=position_offset) + if use_cache: + a, kv = a + x = x + a + x = x + self.ffn(self.ln2(x)) + if use_cache: + return x, kv + return x + + +class MultiHeadAttention(nn.Module): + """Standard MHA with fused qkv and RoPE, SDPA backend selection. + Matches checkpoint naming: qkv (dim->3*dim) and out_proj (dim->dim). + """ + + def __init__(self, dim: int, num_heads: int, dropout: float = 0.0, use_rope: bool = True): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} must be divisible by num_heads {num_heads}" + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.dropout = dropout + self.use_rope = use_rope + + self.qkv = nn.Linear(dim, 3 * dim, bias=False) + self.out_proj = nn.Linear(dim, dim, bias=False) + + self._rope_cache: dict[tuple[int, torch.device, torch.dtype], tuple[torch.Tensor, torch.Tensor]] = {} + + def _rope_cos_sin(self, T: int, device: torch.device, dtype: torch.dtype): + key = (T, device, dtype) + cached = self._rope_cache.get(key) + if cached is not None: + return cached + dim_half = self.head_dim // 2 + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_half, device=device, dtype=torch.float32) / dim_half)) + t = torch.arange(T, device=device, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + cos = torch.cos(freqs).repeat_interleave(2, dim=-1) + sin = torch.sin(freqs).repeat_interleave(2, dim=-1) + cos = cos.to(dtype).unsqueeze(0).unsqueeze(0) + sin = sin.to(dtype).unsqueeze(0).unsqueeze(0) + self._rope_cache[key] = (cos, sin) + return cos, sin + + def forward(self, x: torch.Tensor, past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = False, position_offset: int = 0): + B, T_new, _ = x.shape + qkv = self.qkv(x).view(B, T_new, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k_new, v_new = qkv[0].contiguous(), qkv[1].contiguous(), qkv[2].contiguous() + + if self.use_rope: + cos, sin = self._rope_cos_sin(position_offset + T_new, x.device, q.dtype) + if position_offset > 0: + cos = cos[:, :, position_offset: position_offset + T_new, :] + sin = sin[:, :, position_offset: position_offset + T_new, :] + q = _apply_rope(q, cos, sin) + k_new = 
_apply_rope(k_new, cos, sin) + + if past_kv is not None: + k, v = past_kv + k = torch.cat([k, k_new], dim=2) + v = torch.cat([v, v_new], dim=2) + else: + k, v = k_new, v_new + + is_causal = past_kv is None + with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]): + if x.device.type == "cuda" and q.dtype not in (torch.float16, torch.bfloat16): + amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + with torch.amp.autocast(device_type="cuda", dtype=amp_dtype): + out = F.scaled_dot_product_attention( + q, k, v, dropout_p=self.dropout if self.training else 0.0, is_causal=is_causal + ) + else: + out = F.scaled_dot_product_attention( + q, k, v, dropout_p=self.dropout if self.training else 0.0, is_causal=is_causal + ) + out = out.transpose(1, 2).contiguous().view(B, T_new, self.dim) + if out.dtype != x.dtype: + out = out.to(x.dtype) + out = self.out_proj(out) + if use_cache: + return out, (k, v) + return out diff --git a/CodonTranslator/models.py b/CodonTranslator/models.py new file mode 100644 index 0000000000000000000000000000000000000000..a56578ccedf3c270a192eb30d6ca50359a9f48a6 --- /dev/null +++ b/CodonTranslator/models.py @@ -0,0 +1,306 @@ +from __future__ import annotations + +from typing import Optional, Dict, Any, Tuple, List + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.utils.rnn as rnn_utils + +from .layers import RMSNorm, TransformerBlock +from .tokenizer import SpecialIds + + +class FrozenESMCEncoder(nn.Module): + """Optional ESM-C encoder; if esm isn't available, stays inactive.""" + def __init__(self, model_name: str = "esmc_300m", device: str = "cuda", dtype: str = "bf16"): + super().__init__() + self.model_name = model_name + self._device = torch.device(device if torch.cuda.is_available() else "cpu") + self._autocast_dtype = torch.bfloat16 if dtype == "bf16" else (torch.float16 if dtype == "fp16" else None) + try: + from esm.models.esmc import ESMC # type: ignore + from esm.utils.constants.models import ESMC_300M, ESMC_600M # type: ignore + except Exception as e: + raise ImportError( + "ESM is required for CodonTranslator. Please install 'esm>=3.2.0'." 
+ ) from e + if self.model_name == "esmc_300m": + const = ESMC_300M; self.D_esm = 960 + elif self.model_name == "esmc_600m": + const = ESMC_600M; self.D_esm = 1152 + else: + raise ValueError(f"Unknown ESM model: {self.model_name}") + self.model = ESMC.from_pretrained(model_name=const, device=self._device) + self.tokenizer = self.model.tokenizer + for p in self.parameters(): + p.requires_grad_(False) + self.eval() + + @torch.no_grad() + def tokenize(self, sequences: List[str], max_length: Optional[int] = None, add_special_tokens: bool = True, return_tensors: str = "pt"): + if self.model is None: + raise RuntimeError("ESM model not available") + from esm.utils import encoding # type: ignore + from esm.utils.misc import stack_variable_length_tensors # type: ignore + pad = self.tokenizer.pad_token_id + toks = [] + for s in sequences: + t = encoding.tokenize_sequence(s, self.tokenizer, add_special_tokens=add_special_tokens) + if max_length is not None and len(t) > max_length: + t = t[:max_length] + toks.append(t) + input_ids = stack_variable_length_tensors(toks, constant_value=pad) + attention_mask = (input_ids != pad) + return input_ids, attention_mask + + @torch.no_grad() + def encode_from_ids(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.BoolTensor] = None, return_dict: bool = True): + if self.model is None: + raise RuntimeError("ESM model not available") + device = self.model.device + input_ids = input_ids.to(device) + attention_mask = attention_mask.to(device) if attention_mask is not None else None + if self._autocast_dtype is not None and device.type == "cuda": + with torch.amp.autocast('cuda', dtype=self._autocast_dtype): + outputs = self.model.forward(sequence_tokens=input_ids, sequence_id=attention_mask) + else: + outputs = self.model.forward(sequence_tokens=input_ids, sequence_id=attention_mask) + return {"embeddings": outputs.embeddings, "attention_mask": attention_mask} + + def strip_special_tokens(self, embeddings: torch.FloatTensor, attention_mask: Optional[torch.BoolTensor] = None): + if attention_mask is not None: + lengths = attention_mask.sum(dim=1) - 2 + lengths = lengths.clamp(min=1) + else: + B, L, D = embeddings.shape + lengths = torch.full((B,), L - 2, device=embeddings.device) + stripped = embeddings[:, 1:-1, :] + return stripped, lengths + + +class TranslatorBackbone(nn.Module): + def __init__( + self, + vocab_size: int = 79, + hidden_size: int = 960, + num_layers: int = 24, + num_heads: int = 16, + mlp_ratio: float = 4.0, + max_position_embeddings: int = 4096, + dropout: float = 0.1, + layer_norm_eps: float = 1e-6, + num_special_tokens: int = 13, + special_ids: Optional[SpecialIds] = None, + esm_model_name: str = "esmc_300m", + esm_device: str = "cuda", + esm_dtype: str = "bf16", + max_protein_prefix: int = 0, + max_species_prefix: int = 0, + prepend_species: bool = True, + prepend_protein: bool = True, + species_embedding_dim: int = 1024, + attn_impl: str = "gqa", + num_kv_groups: int = 0, + ): + super().__init__() + self.vocab_size = int(vocab_size) + self.hidden_size = int(hidden_size) + self.num_layers = int(num_layers) + self.num_heads = int(num_heads) + self.max_position_embeddings = int(max_position_embeddings) + self.special_ids = special_ids or SpecialIds() + self.num_special_tokens = int(num_special_tokens) + + self.token_embed = nn.Embedding(self.vocab_size, self.hidden_size) + + # Optional ESM protein encoder + self.esm = None + self.esm_ln = None + if prepend_protein and esm_model_name: + # Enforce ESM presence – raise if missing + 
self.esm = FrozenESMCEncoder(esm_model_name, esm_device, esm_dtype) + self.esm_ln = nn.Sequential( + nn.Linear(self.esm.D_esm, self.hidden_size, bias=False), + nn.ReLU(), + nn.LayerNorm(self.hidden_size), + ) + self.species_embedding_dim = species_embedding_dim if prepend_species else 0 + self.species_ln = None + if prepend_species: + self.species_ln = nn.Sequential( + nn.Linear(self.species_embedding_dim, self.hidden_size, bias=False), + nn.ReLU(), + nn.LayerNorm(self.hidden_size), + ) + + self.max_protein_prefix = int(max_protein_prefix) if max_protein_prefix is not None else 0 + self.max_species_prefix = int(max_species_prefix) if max_species_prefix is not None else 0 + self.prepend_species = bool(prepend_species) + self.prepend_protein = bool(prepend_protein) and (self.esm is not None) + + self.start_embed = nn.Parameter(torch.zeros(1, 1, self.hidden_size)) + nn.init.normal_(self.start_embed, mean=0.0, std=0.02) + + self.attn_impl = str(attn_impl) + kv_groups = int(num_kv_groups) + self.blocks = nn.ModuleList([ + TransformerBlock( + dim=self.hidden_size, + num_heads=self.num_heads, + mlp_ratio=mlp_ratio, + dropout=dropout, + num_kv_groups=(kv_groups if (kv_groups > 0 and self.attn_impl == "gqa") else None), + qk_norm=False, + attn_type=("mha" if self.attn_impl == "mha" else "gqa"), + ) + for _ in range(self.num_layers) + ]) + + self.ln_f = RMSNorm(self.hidden_size, eps=layer_norm_eps) + self.lm_head = nn.Linear(self.hidden_size, self.vocab_size, bias=False) + self.gradient_checkpointing = False + + def embed_tokens(self, token_ids: torch.Tensor) -> torch.Tensor: + device = self.token_embed.weight.device + return self.token_embed(token_ids.to(device)) + + def build_prefix( + self, + batch_size: int, + device: torch.device, + species_tok_emb: Optional[torch.Tensor] = None, + species_emb: Optional[torch.Tensor] = None, + protein_input: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + parts: list[torch.Tensor] = [] + if self.prepend_species and self.species_ln is not None: + if species_emb is not None: + S = self.species_ln(species_emb.to(device=device, dtype=next(self.parameters()).dtype).unsqueeze(1)) + parts.append(S) + parts.append(S) + elif species_tok_emb is not None: + S = species_tok_emb + if getattr(self, "max_species_prefix", 0) > 0 and S.size(1) > self.max_species_prefix: + S = S[:, : self.max_species_prefix, :] + S = self.species_ln(S.to(device=device, dtype=next(self.parameters()).dtype)) + parts.append(S) + parts.append(S) + + if self.prepend_protein and self.esm is not None and protein_input is not None: + prot_ids, prot_mask = protein_input + esm_out = self.esm.encode_from_ids(prot_ids, prot_mask, return_dict=True) + P, lengths = self.esm.strip_special_tokens(esm_out["embeddings"], prot_mask) + if getattr(self, "max_protein_prefix", 0) > 0 and P.size(1) > self.max_protein_prefix: + P = P[:, : self.max_protein_prefix, :] + lengths = lengths.clamp(max=self.max_protein_prefix) if lengths is not None else None + if P.size(1) > 0: + P = self.esm_ln(P.to(device=device, dtype=next(self.parameters()).dtype)) + if lengths is not None: + Lp = P.size(1) + ar = torch.arange(Lp, device=device).unsqueeze(0) + valid = ar < lengths.unsqueeze(1) + P = P * valid.unsqueeze(-1) + parts.append(P) + + if len(parts) == 0: + empty = torch.zeros(batch_size, 0, self.hidden_size, device=device, dtype=next(self.parameters()).dtype) + return empty, torch.zeros(batch_size, dtype=torch.long, device=device) + + prefix = torch.cat(parts, dim=1) + with 
torch.no_grad(): + valid = (prefix.abs().sum(dim=-1) > 0) + lengths = valid.sum(dim=1).to(torch.long) + prefix_budget = max(0, int(self.max_position_embeddings) - 1) + allow = torch.minimum(lengths, torch.tensor(prefix_budget, device=lengths.device, dtype=lengths.dtype)) + Lp_max = int(allow.max().item()) if allow.numel() > 0 else 0 + if prefix.size(1) > Lp_max: + trimmed = prefix.new_zeros(prefix.size(0), Lp_max, prefix.size(2)) + for b in range(prefix.size(0)): + lb = int(allow[b].item()) + if lb > 0: + trimmed[b, :lb, :] = prefix[b, :lb, :] + prefix = trimmed + lengths = allow + else: + lengths = allow + return prefix, lengths + + def forward(self, codon_ids: torch.Tensor, cond: Dict[str, Any] = None, labels: Optional[torch.Tensor] = None, return_dict: bool = True, use_cache: bool = False, past_kv: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, position_offset: int = 0) -> Dict[str, torch.Tensor]: + batch_size, codon_len = codon_ids.shape + device = codon_ids.device + species_tok_emb = cond.get("species_tok_emb") if cond else None + species_emb = cond.get("species_emb") if cond else None + protein_input = cond.get("protein_input") if cond else None + + # Build prefix + prefix, prefix_lengths = self.build_prefix(batch_size, device, species_tok_emb=species_tok_emb, species_emb=species_emb, protein_input=protein_input) + start = self.start_embed.expand(batch_size, 1, -1) + + # KV cache path for incremental generation + if past_kv is not None and codon_len > 0: + x = self.embed_tokens(codon_ids) + present_kv: List[Tuple[torch.Tensor, torch.Tensor]] = [] + for i, block in enumerate(self.blocks): + kv_i = past_kv[i] if i < len(past_kv) else None + out_blk = block(x, past_kv=kv_i, use_cache=True, position_offset=position_offset) + x, kv_out = out_blk + present_kv.append(kv_out) + x = self.ln_f(x) + logits_step = self.lm_head(x) + return {"logits": logits_step[:, 0:0, :], "next_logits": logits_step[:, -1, :], "present_kv": present_kv, "prefix_len": prefix_lengths} + + # Non-incremental: build prefix+start+codon window + codon_lens = torch.as_tensor([codon_len] * batch_size, device=device) + capacity = max(0, int(self.max_position_embeddings)) + budget_after_prefix = torch.clamp(torch.as_tensor(capacity, device=device) - (prefix_lengths + 1), min=0) + per_cap = torch.minimum(budget_after_prefix, codon_lens) + max_cap = int(per_cap.max().item()) if per_cap.numel() > 0 else 0 + codon_emb = self.embed_tokens(codon_ids[:, :max_cap]) if max_cap > 0 else torch.zeros(batch_size, 0, self.hidden_size, device=device, dtype=start.dtype) + seqs = [] + for b in range(batch_size): + lp = int(prefix_lengths[b].item()) + cap = int(per_cap[b].item()) + parts = [] + if lp > 0: + parts.append(prefix[b, :lp, :]) + parts.append(start[b, 0:1, :]) + if cap > 0: + parts.append(codon_emb[b, :cap, :]) + seqs.append(torch.cat(parts, dim=0)) + x = rnn_utils.pad_sequence(seqs, batch_first=True) + + present_kv_list: List[Tuple[torch.Tensor, torch.Tensor]] = [] + for block in self.blocks: + blk_out = block(x, use_cache=use_cache, position_offset=0) + if use_cache: + x, kv = blk_out + present_kv_list.append(kv) + else: + x = blk_out + x = self.ln_f(x) + logits_full = self.lm_head(x) + + next_logits_list = [] + if max_cap == 0: + codon_logits = logits_full[:, 0:0, :] + for b in range(batch_size): + lp = int(prefix_lengths[b].item()) + pos_next = lp + next_logits_list.append(logits_full[b, pos_next, :] if pos_next < logits_full.size(1) else logits_full[b, -1, :]) + next_logits = torch.stack(next_logits_list, dim=0) + 
else: + slices = [] + for b in range(batch_size): + lp = int(prefix_lengths[b].item()) + cap = int(per_cap[b].item()) + sl = logits_full[b, lp : lp + cap, :] if cap > 0 else logits_full.new_zeros(0, self.vocab_size) + slices.append(sl) + pos_next = lp + cap + next_logits_list.append(logits_full[b, pos_next, :] if pos_next < logits_full.size(1) else logits_full.new_zeros(self.vocab_size)) + codon_logits = rnn_utils.pad_sequence(slices, batch_first=True) + next_logits = torch.stack(next_logits_list, dim=0) + out = {"logits": codon_logits, "next_logits": next_logits, "prefix_len": prefix_lengths} + if use_cache: + out["present_kv"] = present_kv_list + return out diff --git a/CodonTranslator/tokenizer.py b/CodonTranslator/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..87b43cc1efb3484f98a54318097b7274e84c9369 --- /dev/null +++ b/CodonTranslator/tokenizer.py @@ -0,0 +1,183 @@ +# Minimal copy of CodonTokenizer from src/tokenizer.py to keep the package self-contained. +from __future__ import annotations + +import json +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Any + + +@dataclass(frozen=True) +class SpecialIds: + pad: int = 0 + unk: int = 1 + bos: int = 2 + eos: int = 3 + + def to_dict(self) -> Dict[str, int]: + return {"pad": self.pad, "unk": self.unk, "bos": self.bos, "eos": self.eos} + + +class CodonTokenizer: + __slots__ = ( + "codons", + "_special_token_str", + "vocab", + "ids_to_tokens", + "_special_ids", + "_num_special_tokens", + "_genetic_code", + "_codon2aa_char", + "_aa2codons_char", + ) + + def __init__( + self, + pad_token: str = "", + unk_token: str = "", + bos_token: str = "", + eos_token: str = "", + **_: Any, + ) -> None: + bases = ("A", "C", "G", "T") + self.codons: List[str] = [a + b + c for a in bases for b in bases for c in bases] + + special_tokens = [pad_token, unk_token, bos_token, eos_token] + self._special_token_str = {"pad": pad_token, "unk": unk_token, "bos": bos_token, "eos": eos_token} + + self.vocab: Dict[str, int] = {} + for i, tok in enumerate(special_tokens): + self.vocab[tok] = i + for codon in self.codons: + self.vocab[codon] = len(special_tokens) + (len(self.vocab) - len(special_tokens)) + + self.ids_to_tokens: Dict[int, str] = {v: k for k, v in self.vocab.items()} + + self._special_ids = SpecialIds( + pad=self.vocab[pad_token], + unk=self.vocab[unk_token], + bos=self.vocab[bos_token], + eos=self.vocab[eos_token], + ) + self._num_special_tokens = len(special_tokens) + + self._genetic_code: Dict[str, str] = { + "TTT": "F", "TTC": "F", "TTA": "L", "TTG": "L", + "TCT": "S", "TCC": "S", "TCA": "S", "TCG": "S", + "TAT": "Y", "TAC": "Y", "TAA": "*", "TAG": "*", + "TGT": "C", "TGC": "C", "TGA": "*", "TGG": "W", + "CTT": "L", "CTC": "L", "CTA": "L", "CTG": "L", + "CCT": "P", "CCC": "P", "CCA": "P", "CCG": "P", + "CAT": "H", "CAC": "H", "CAA": "Q", "CAG": "Q", + "CGT": "R", "CGC": "R", "CGA": "R", "CGG": "R", + "ATT": "I", "ATC": "I", "ATA": "I", "ATG": "M", + "ACT": "T", "ACC": "T", "ACA": "T", "ACG": "T", + "AAT": "N", "AAC": "N", "AAA": "K", "AAG": "K", + "AGT": "S", "AGC": "S", "AGA": "R", "AGG": "R", + "GTT": "V", "GTC": "V", "GTA": "V", "GTG": "V", + "GCT": "A", "GCC": "A", "GCA": "A", "GCG": "A", + "GAT": "D", "GAC": "D", "GAA": "E", "GAG": "E", + "GGT": "G", "GGC": "G", "GGA": "G", "GGG": "G", + } + + self._codon2aa_char: Dict[int, str] = {} + self._aa2codons_char: Dict[str, List[int]] = {ch: [] for ch in "ACDEFGHIKLMNPQRSTVWY*"} + for codon in 
self.codons: + cid = self.vocab[codon] + aa = self._genetic_code.get(codon, "X") + self._codon2aa_char[cid] = aa + if aa in self._aa2codons_char: + self._aa2codons_char[aa].append(cid) + + @property + def vocab_size(self) -> int: + return len(self.vocab) + + @property + def special_ids(self) -> SpecialIds: + return self._special_ids + + @property + def num_special_tokens(self) -> int: + return self._num_special_tokens + + @property + def pad_token_id(self) -> int: + return self._special_ids.pad + + @property + def eos_token_id(self) -> int: + return self._special_ids.eos + + # helpers + def codon_vocab(self) -> Dict[str, int]: + return {c: self.vocab[c] for c in self.codons} + + def codon2aa_char_map(self) -> Dict[int, str]: + return dict(self._codon2aa_char) + + def aa2codons_char_map(self) -> Dict[str, List[int]]: + return {k: v[:] for k, v in self._aa2codons_char.items()} + + # decoding + def decode_codon_seq(self, token_ids: List[int]) -> str: + parts: List[str] = [] + nst = self._num_special_tokens + for tid in token_ids: + if tid >= nst: + tok = self.ids_to_tokens.get(tid) + if tok is not None: + parts.append(tok) + return "".join(parts) + + # persistence + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + os.makedirs(save_directory, exist_ok=True) + vocab_file = os.path.join( + save_directory, + (filename_prefix + "-" if filename_prefix else "") + "vocab.json", + ) + payload = { + "vocab": self.vocab, + "special_token_str": self._special_token_str, + } + with open(vocab_file, "w", encoding="utf-8") as f: + json.dump(payload, f, ensure_ascii=False, indent=2, sort_keys=True) + return (vocab_file,) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs: Any) -> "CodonTokenizer": + vocab_path = Path(pretrained_model_name_or_path) / "vocab.json" + tok = cls(**kwargs) + if not vocab_path.exists(): + return tok + with open(vocab_path, "r", encoding="utf-8") as f: + save_data = json.load(f) + vocab = save_data["vocab"] if isinstance(save_data, dict) and "vocab" in save_data else save_data + tok.vocab = {str(k): int(v) for k, v in vocab.items()} + tok.ids_to_tokens = {int(v): str(k) for k, v in tok.vocab.items()} + sts = save_data.get("special_token_str", tok._special_token_str) if isinstance(save_data, dict) else tok._special_token_str + tok._special_token_str.update(sts) + def _id_for(name: str, default_val: int) -> int: + sym = tok._special_token_str[name] + return int(tok.vocab.get(sym, default_val)) + tok._special_ids = SpecialIds( + pad=_id_for("pad", 0), + unk=_id_for("unk", 1), + bos=_id_for("bos", 2), + eos=_id_for("eos", 3), + ) + ids = [tok._special_ids.pad, tok._special_ids.unk, tok._special_ids.bos, tok._special_ids.eos] + m = max(ids) + tok._num_special_tokens = m + 1 if ids == list(range(m + 1)) else 4 + # rebuild helpers + tok._codon2aa_char = {} + tok._aa2codons_char = {ch: [] for ch in "ACDEFGHIKLMNPQRSTVWY*"} + for codon in tok.codons: + cid = tok.vocab[codon] + aa = tok._genetic_code.get(codon, "X") + tok._codon2aa_char[cid] = aa + if aa in tok._aa2codons_char: + tok._aa2codons_char[aa].append(cid) + return tok diff --git a/CodonTranslator/translator.py b/CodonTranslator/translator.py new file mode 100644 index 0000000000000000000000000000000000000000..01d754d0739d403d40e1d2454927737251243e6f --- /dev/null +++ b/CodonTranslator/translator.py @@ -0,0 +1,479 @@ +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any, Dict, List, Optional, 
Tuple, Union +import logging + +import torch +import torch.nn.functional as F +import numpy as np +from safetensors.torch import load_file + +from .models import TranslatorBackbone +from .tokenizer import CodonTokenizer + # no external store at inference; species embeddings computed via Qwen + + +class CodonTranslator: + """ + High-level sampling wrapper for trained checkpoints with a simple API: + + from CodonTranslator import CodonTranslator + model = CodonTranslator.from_pretrained(model_path) + dna = model.sampling(species="Homo sapiens", protein_seq="M...", enforce_mapping=True) + """ + + def __init__(self, model_dir: Union[str, Path], device: str = "cuda", use_gbif: bool = False): + self.model_dir = Path(model_dir) + self.device = torch.device(device if torch.cuda.is_available() or device == "cpu" else "cpu") + self.tokenizer = CodonTokenizer.from_pretrained(str(self.model_dir)) + self.V = int(self.tokenizer.vocab_size) + self._eos_id = int(self.tokenizer.eos_token_id) + self._pad_id = int(self.tokenizer.pad_token_id) + self._num_special = int(self.tokenizer.num_special_tokens) + + # Load config + cfg_path = self.model_dir / "trainer_config.json" + if not cfg_path.exists(): + cfg_path = self.model_dir / "config.json" + with open(cfg_path, "r") as f: + self.config = json.load(f) + + # Build model and load weights + state = self._load_state_dict() + arch = self._infer_arch_from_state_dict(state) + self.model = TranslatorBackbone( + vocab_size=self.V, + hidden_size=int(arch["hidden_size"]), + num_layers=int(arch["num_layers"]), + num_heads=int(arch["num_heads"]), + mlp_ratio=float(arch.get("mlp_ratio", 4.0)), + max_position_embeddings=int(arch["max_position_embeddings"]), + num_special_tokens=self._num_special, + special_ids=self.tokenizer.special_ids, + prepend_species=bool(arch.get("prepend_species", True)), + prepend_protein=bool(arch.get("prepend_protein", False)), + species_embedding_dim=int(self.config.get("species_embedding_dim", 1024)), + esm_model_name=str(arch.get("esm_model_name", "esmc_300m")), + esm_device=str(arch.get("esm_device", "cuda")), + esm_dtype=str(arch.get("esm_dtype", "bf16")), + max_protein_prefix=int(arch.get("max_protein_prefix", 0)), + max_species_prefix=int(arch.get("max_species_prefix", 0)), + attn_impl=str(arch.get("attn_impl", "gqa")), + num_kv_groups=int(arch.get("num_kv_groups", 0)), + ) + missing, unexpected = self.model.load_state_dict(state, strict=False) + if len(unexpected) > 0: + # non-fatal + pass + self.model.to(self.device).eval() + + # Static masks + self._allowed_fixed = torch.ones(self.V, dtype=torch.bool, device=self.device) + self._allowed_fixed[:self._num_special] = False + self._allowed_variable = torch.ones(self.V, dtype=torch.bool, device=self.device) + self._allowed_variable[:self._num_special] = False + self._allowed_variable[self._eos_id] = True + + # Species taxonomy: either query GBIF (if allowed) or use raw names. 
+ self._use_gbif = bool(use_gbif) + self._taxonomy_cache: Dict[str, str] = {} + + # ---- constructors ---- + @classmethod + def from_pretrained(cls, model_path: Union[str, Path], device: str = "cuda", use_gbif: bool = False) -> "CodonTranslator": + return cls(model_path, device=device, use_gbif=use_gbif) + + # ---- sampling APIs ---- + @torch.no_grad() + def sampling(self, species: str, protein_seq: str, enforce_mapping: bool = False, temperature: float = 1.0, top_k: Optional[int] = None, top_p: Optional[float] = None, seed: Optional[int] = None, use_kv_cache: bool = True) -> str: + out = self.batch_inference( + species=[species], + protein_seqs=[protein_seq], + enforce_mapping=enforce_mapping, + temperature=temperature, + top_k=top_k, + top_p=top_p, + seed=seed, + use_kv_cache=use_kv_cache, + ) + return out[0] + + @torch.no_grad() + def batch_inference( + self, + species: List[str], + protein_seqs: List[str], + enforce_mapping: bool = False, + temperature: float = 1.0, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + seed: Optional[int] = None, + use_kv_cache: bool = True, + micro_batch_size: int = 1, + ) -> List[str]: + """Generate DNA for a list of protein sequences, using micro-batching to limit memory. + + - micro_batch_size: number of samples to process at once (default=1 for low memory) + """ + assert len(species) == len(protein_seqs), "species and protein_seqs length must match" + mb = max(1, int(micro_batch_size)) + if len(species) <= mb: + return self._batch_inference_core( + species=species, + protein_seqs=protein_seqs, + enforce_mapping=enforce_mapping, + temperature=temperature, + top_k=top_k, + top_p=top_p, + seed=seed, + use_kv_cache=use_kv_cache, + ) + + outputs: List[str] = [] + for start in range(0, len(species), mb): + end = min(start + mb, len(species)) + chunk_out = self._batch_inference_core( + species=species[start:end], + protein_seqs=protein_seqs[start:end], + enforce_mapping=enforce_mapping, + temperature=temperature, + top_k=top_k, + top_p=top_p, + seed=seed, + use_kv_cache=use_kv_cache, + ) + outputs.extend(chunk_out) + return outputs + + @torch.no_grad() + def _batch_inference_core( + self, + species: List[str], + protein_seqs: List[str], + enforce_mapping: bool = False, + temperature: float = 1.0, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + seed: Optional[int] = None, + use_kv_cache: bool = True, + ) -> List[str]: + if seed is not None: + torch.manual_seed(int(seed)) + np.random.seed(int(seed)) + B = len(species) + assert B == len(protein_seqs), "species and protein_seqs length must match" + target_lens = torch.tensor([len(s) for s in protein_seqs], device=self.device, dtype=torch.long) + T_codons = int(target_lens.max().item()) + + # Prepare conditioning + cond: Dict[str, Any] = {"control_mode": "fixed"} + + # Species embeddings via Qwen3-Embedding (variable-length token sequences) + q_tok, lengths = self._qwen_embed_names(species, pooling="sequence") # [B, L, D] + # Always surface a message so users can see species embeddings are used + print(f"[CodonTranslator] Species embeddings (Qwen) computed: shape={tuple(q_tok.shape)}") + cond["species_tok_emb"] = q_tok.to(self.device) + + # Protein input via ESM (if available) – let model tokenize internally + if getattr(self.model, "esm", None) is not None: + # Tokenize AA sequences with model.esm + max_len_tokens = (getattr(self.model, "max_protein_prefix", 0) + 2) if getattr(self.model, "max_protein_prefix", 0) > 0 else None + prot_ids, prot_mask = 
self.model.esm.tokenize(protein_seqs, max_length=max_len_tokens) + cond["protein_input"] = (prot_ids.to(self.device), prot_mask.to(self.device)) + + # Start generation with empty context to build KV cache and initial logits + input_ids = torch.zeros(B, 0, dtype=torch.long, device=self.device) + out_prefill = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=use_kv_cache) + kv = out_prefill.get("present_kv") if use_kv_cache else None + logits = out_prefill.get("next_logits") + assert logits is not None + # Report prefix length to prove species/protein prefixes were incorporated + try: + pref = out_prefill.get("prefix_len") + if pref is not None: + lst = pref.detach().cpu().tolist() + print(f"[CodonTranslator] Prefix lengths (species,species,protein): {lst}") + except Exception: + pass + + allowed = self._allowed_fixed + finished = torch.zeros(B, dtype=torch.bool, device=self.device) + + aa2codons = self.tokenizer.aa2codons_char_map() + + rng = range(T_codons) + # Greedy mode: temperature <= 0 selects argmax deterministically + greedy_mode = (temperature is not None and float(temperature) <= 0.0) + for step in rng: + logits = logits.masked_fill(~allowed, float("-inf")) + + # Stop sampling per-sample once reaching its target length; force PAD + done_now = (torch.tensor(step, device=self.device) >= target_lens) + if done_now.any(): + logits[done_now] = float("-inf") + logits[done_now, self._pad_id] = 0.0 + + # Enforce codon ↔ AA mapping at this step + if enforce_mapping: + aas_now = [seq[step] if step < len(seq) else None for seq in protein_seqs] + mask = torch.zeros_like(logits, dtype=torch.bool) + for i, a in enumerate(aas_now): + if a is None: + mask[i, self._num_special:self.V] = True + else: + valid = aa2codons.get(a, []) + if len(valid) == 0: + mask[i, self._num_special:self.V] = True + else: + mask[i, valid] = True + logits = logits.masked_fill(~mask, float("-inf")) + + if not greedy_mode and temperature != 1.0: + logits = logits / float(temperature) + if top_k is not None: + logits = self._top_k_filtering(logits, int(top_k)) + if top_p is not None: + logits = self._top_p_filtering(logits, float(top_p)) + + if greedy_mode: + next_tok = torch.argmax(logits, dim=-1, keepdim=True) + else: + probs = F.softmax(logits, dim=-1) + next_tok = torch.multinomial(probs, num_samples=1) + + input_ids = torch.cat([input_ids, next_tok], dim=1) + + if use_kv_cache: + pos_offset = int(out_prefill.get("prefix_len").max().item()) + input_ids.size(1) - 1 if isinstance(out_prefill, dict) and ("prefix_len" in out_prefill) else input_ids.size(1) - 1 + out_inc = self.model( + codon_ids=next_tok, + cond=None, + return_dict=True, + use_cache=True, + past_kv=kv, + position_offset=pos_offset, + ) + kv = out_inc.get("present_kv") + logits = out_inc.get("next_logits") + else: + # Recompute full forward with prefix+all generated tokens + out_full = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=False) + logits = out_full.get("next_logits") + + # Build DNA strings, dropping specials + output_token_rows: List[List[int]] = [] + for i, row in enumerate(input_ids.tolist()): + toks: List[int] = [] + for t in row: + if t == self._pad_id: + continue + if t == self._eos_id: + break + if t >= self._num_special and t < self.V: + toks.append(int(t)) + toks = toks[: int(target_lens[i].item())] + output_token_rows.append(toks) + sequences = [self.tokenizer.decode_codon_seq(row) for row in output_token_rows] + + # If not enforcing mapping, report AA token accuracy vs provided targets + if 
not enforce_mapping: + for i, dna in enumerate(sequences): + tgt = protein_seqs[i] + gen_aa = self._dna_to_aa(dna) + L = min(len(gen_aa), len(tgt)) + if L == 0: + acc = 0.0; num = 0; den = 0 + else: + num = sum(1 for a, b in zip(gen_aa[:L], tgt[:L]) if a == b) + den = L + acc = num / den + print(f"[CodonTranslator] AA token accuracy seq_{i+1}: {acc:.4f} ({num}/{den})") + return sequences + + # ---- helpers ---- + def _load_state_dict(self) -> Dict[str, torch.Tensor]: + st_p = self.model_dir / "model.safetensors" + if st_p.exists(): + return load_file(st_p) + pt_p = self.model_dir / "pytorch_model.bin" + if pt_p.exists(): + return torch.load(pt_p, map_location="cpu") + raise FileNotFoundError(f"No model.safetensors or pytorch_model.bin in {self.model_dir}") + + def _infer_arch_from_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]: + arch: Dict[str, Any] = {} + if "lm_head.weight" in state_dict: + arch["hidden_size"] = int(state_dict["lm_head.weight"].shape[1]) + else: + for k, v in state_dict.items(): + if k.endswith("ln_f.weight"): + arch["hidden_size"] = int(v.shape[0]) + break + cfg = self.config or {} + if "hidden_size" in cfg: + arch["hidden_size"] = int(cfg["hidden_size"]) # type: ignore + if "hidden_size" not in arch: + arch["hidden_size"] = int(cfg.get("hidden_size", 750)) + H = int(arch["hidden_size"]) + + max_block = -1 + for k in state_dict.keys(): + if k.startswith("blocks."): + idx = int(k.split(".")[1]) + if idx > max_block: + max_block = idx + arch["num_layers"] = (max_block + 1) if max_block >= 0 else int(cfg.get("num_hidden_layers", 12)) + if "num_hidden_layers" in cfg: + arch["num_layers"] = int(cfg["num_hidden_layers"]) # type: ignore + + # mlp ratio + w1_key = next((k for k in state_dict.keys() if k.endswith("ffn.w1.weight")), None) + if w1_key is not None: + arch["mlp_ratio"] = float(int(state_dict[w1_key].shape[0]) / H) + else: + arch["mlp_ratio"] = float(cfg.get("mlp_ratio", 4.0)) + + # heads: pick divisor + cfg_heads = cfg.get("num_attention_heads") + if isinstance(cfg_heads, int) and cfg_heads > 0 and H % cfg_heads == 0: + arch["num_heads"] = int(cfg_heads) + else: + for h in (16, 15, 12, 10, 8, 6, 5, 4, 3, 2, 1): + if H % h == 0: + arch["num_heads"] = h + break + + arch["prepend_species"] = bool(cfg.get("prepend_species", any(k.startswith("species_ln.") for k in state_dict.keys()))) + has_esm = any(k.startswith("esm_ln.") for k in state_dict.keys()) or any(k.startswith("esm.") for k in state_dict.keys()) + arch["prepend_protein"] = bool(cfg.get("prepend_protein", bool(has_esm))) + arch["esm_model_name"] = str(cfg.get("esm_model_name", "esmc_300m")) + arch["esm_device"] = str(cfg.get("esm_device", "cuda")) + arch["esm_dtype"] = str(cfg.get("esm_dtype", "bf16")).lower() + arch["max_protein_prefix"] = int(cfg.get("max_protein_prefix", 0)) + arch["max_species_prefix"] = int(cfg.get("max_species_prefix", 0)) + arch["max_position_embeddings"] = int(cfg.get("max_length", cfg.get("max_position_embeddings", 2048))) + arch["attn_impl"] = str(cfg.get("attn_impl", "gqa")) + arch["num_kv_groups"] = int(cfg.get("num_kv_groups", 0)) + return arch + + # --- filtering helpers + @staticmethod + def _ensure_2d_logits(logits: torch.Tensor) -> torch.Tensor: + return logits if logits.dim() == 2 else logits.unsqueeze(0) + + @staticmethod + def _top_k_filtering(logits: torch.Tensor, k: int) -> torch.Tensor: + x = CodonTranslator._ensure_2d_logits(logits) + k = max(1, min(int(k), x.size(-1))) + values, _ = torch.topk(x, k, dim=-1) + min_values = values[:, 
-1].unsqueeze(-1) + x = torch.where(x < min_values, torch.full_like(x, float('-inf')), x) + return x if logits.dim() == 2 else x.squeeze(0) + + @staticmethod + def _top_p_filtering(logits: torch.Tensor, p: float) -> torch.Tensor: + if p >= 1.0: + return logits + if p <= 0.0: + return torch.full_like(logits, float('-inf')) + x = CodonTranslator._ensure_2d_logits(logits) + sorted_logits, sorted_indices = torch.sort(x, descending=True, dim=-1) + probs = torch.softmax(sorted_logits, dim=-1) + cumprobs = torch.cumsum(probs, dim=-1) + to_remove = cumprobs > p + # Avoid overlapping memory writes by cloning the RHS + to_remove = to_remove.to(torch.bool) + to_remove[:, 1:] = to_remove[:, :-1].clone() + to_remove[:, 0] = False + mask = torch.zeros_like(x, dtype=torch.bool).scatter(-1, sorted_indices, to_remove) + x = torch.where(mask, torch.full_like(x, float('-inf')), x) + return x if logits.dim() == 2 else x.squeeze(0) + + # --- Qwen embedding fallback for species text --- + def _qwen_embed_names(self, names: List[str], pooling: str = "sequence") -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + from transformers import AutoTokenizer, AutoModel + tokenizer = AutoTokenizer.from_pretrained( + "Qwen/Qwen3-Embedding-0.6B", trust_remote_code=True, padding_side="left" + ) + dtype = torch.float16 if self.device.type == "cuda" else torch.float32 + model = AutoModel.from_pretrained( + "Qwen/Qwen3-Embedding-0.6B", dtype=dtype, trust_remote_code=True + ).to(self.device).eval() + task = ( + "Given a species taxonomy information, generate a biological embedding " + "representing its taxonomic and evolutionary characteristics" + ) + queries = self._resolve_taxonomy_texts(names) + texts = [f"Instruct: {task}\nQuery: {q}" for q in queries] + inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512).to(self.device) + out = model(**inputs) + h = torch.nn.functional.normalize(out.last_hidden_state, p=2, dim=-1) + attn = inputs["attention_mask"] + # sequence embeddings padded to same length by tokenizer padding + return h, torch.sum(attn, dim=1) + + def _taxonomy_lookup(self, name: str) -> str: + if name in self._taxonomy_cache: + return self._taxonomy_cache[name] + if self._use_gbif: + try: + import requests + resp = requests.get("https://api.gbif.org/v1/species/match", params={"name": name}, timeout=5) + if resp.status_code == 200: + data = resp.json() + if data.get("matchType") != "NONE": + parts = [] + taxonomy = [] + for rank in ["kingdom", "phylum", "class", "order", "family", "genus", "species"]: + if rank in data and data[rank]: + taxonomy.append(data[rank]) + if taxonomy: + parts.append("Taxonomy: " + " > ".join(taxonomy)) + if "vernacularName" in data and data["vernacularName"]: + parts.append(f"Common name: {data['vernacularName']}") + if "confidence" in data: + parts.append(f"Match confidence: {data['confidence']}%") + if "status" in data: + parts.append(f"Status: {data['status']}") + desc = ". ".join(parts) if parts else name + self._taxonomy_cache[name] = desc + return desc + except Exception: + pass + return name + + def _resolve_taxonomy_texts(self, names: List[str]) -> List[str]: + """Resolve taxonomy strings for a batch of species names. + If a taxonomy DB is present, pull from it. Otherwise batch-query GBIF + (one request per species) and cache results. Always returns a list of + strings aligned to `names`. 
+ """ + results: List[str] = [] + # Batch “query”: loop per-name; still batched at the embedding stage + fetched = 0 + for s in names: + txt = self._taxonomy_lookup(s) + if s in self._taxonomy_cache: + fetched += 1 + results.append(txt) + if self._use_gbif: + print(f"[CodonTranslator] Taxonomy texts resolved (GBIF={'on' if self._use_gbif else 'off'}): {fetched}/{len(names)} fetched") + return results + + @staticmethod + def _dna_to_aa(dna_seq: str) -> str: + g = { + 'TTT': 'F', 'TTC': 'F', 'TTA': 'L', 'TTG': 'L', 'TCT': 'S', 'TCC': 'S', 'TCA': 'S', 'TCG': 'S', + 'TAT': 'Y', 'TAC': 'Y', 'TAA': '*', 'TAG': '*', 'TGT': 'C', 'TGC': 'C', 'TGA': '*', 'TGG': 'W', + 'CTT': 'L', 'CTC': 'L', 'CTA': 'L', 'CTG': 'L', 'CCT': 'P', 'CCC': 'P', 'CCA': 'P', 'CCG': 'P', + 'CAT': 'H', 'CAC': 'H', 'CAA': 'Q', 'CAG': 'Q', 'CGT': 'R', 'CGC': 'R', 'CGA': 'R', 'CGG': 'R', + 'ATT': 'I', 'ATC': 'I', 'ATA': 'I', 'ATG': 'M', 'ACT': 'T', 'ACC': 'T', 'ACA': 'T', 'ACG': 'T', + 'AAT': 'N', 'AAC': 'N', 'AAA': 'K', 'AAG': 'K', 'AGT': 'S', 'AGC': 'S', 'AGA': 'R', 'AGG': 'R', + 'GTT': 'V', 'GTC': 'V', 'GTA': 'V', 'GTG': 'V', 'GCT': 'A', 'GCC': 'A', 'GCA': 'A', 'GCG': 'A', + 'GAT': 'D', 'GAC': 'D', 'GAA': 'E', 'GAG': 'E', 'GGT': 'G', 'GGC': 'G', 'GGA': 'G', 'GGG': 'G' + } + L = len(dna_seq) // 3 + aa = [g.get(dna_seq[3*i:3*i+3], 'X') for i in range(L)] + return ''.join(aa) diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..419f9cb4a80ca7cb5ae97e3d31b80b4c2ce286ce --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2026 CodonTranslator authors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c6c6961270d190eb964de33e58c81a9180592a47 --- /dev/null +++ b/README.md @@ -0,0 +1,115 @@ +--- +license: mit +library_name: pytorch +tags: + - biology + - dna + - codon-optimization + - protein-conditioned-generation + - fsdp +datasets: + - alegendaryfish/CodonTranslator-data +--- + +# CodonTranslator + +CodonTranslator is a protein-conditioned codon sequence generation model trained on the representative-only `data_v3` release. + +This repository is the public model and training-code release. 
It contains: + +- `final_model/`: inference-ready weights +- `training_checkpoints/checkpoint-71000/`: a resumable training checkpoint +- `src/`, `train.py`, `sampling.py`: training and inference code +- `resplit_data_v3.py`: the `data_v3` reconstruction pipeline +- `slurm/`: the single-node H200 training and data rebuild submission scripts +- `CodonTranslator/` and `pyproject.toml`: a lightweight packaged inference wrapper + +## Training configuration + +- Architecture: `hidden=750`, `layers=20`, `heads=15`, `mlp_ratio=3.2` +- Attention: `mha` +- Precision: `bf16` +- Parallelism: FSDP full shard +- Effective global batch: `1536` +- Weight decay: `1e-4` +- Dataset: `alegendaryfish/CodonTranslator-data` + +## Dataset release + +The corresponding public dataset and species embedding release is: + +- `alegendaryfish/CodonTranslator-data` + +That dataset repo contains: + +- final representative-only `train/`, `val/`, `test/` parquet shards +- `embeddings_v2/` +- split audit files and reconstruction metadata + +## Quick start + +### Install + +```bash +git clone https://huggingface.co/alegendaryfish/CodonTranslator +cd CodonTranslator +pip install -r requirements.txt +pip install -e . +``` + +Both import styles are supported: + +```python +from CodonTranslator import CodonTranslator +``` + +```python +from codontranslator import CodonTranslator +``` + +### Train + +```bash +python train.py \ + --train_data /path/to/train \ + --val_data /path/to/val \ + --embeddings_dir /path/to/embeddings_v2 \ + --output_dir outputs \ + --fsdp \ + --bf16 \ + --attn mha \ + --hidden 750 \ + --layers 20 \ + --heads 15 \ + --mlp_ratio 3.2 \ + --batch_size 48 \ + --grad_accum 4 \ + --epochs 3 \ + --lr 7e-5 \ + --weight_decay 1e-4 +``` + +### Sample + +```bash +python sampling.py \ + --model_path final_model \ + --embeddings_dir /path/to/embeddings_v2 \ + --species "Panicum hallii" \ + --protein_sequence "MSEQUENCE" \ + --strict_species_lookup +``` + +## Notes + +- Training uses precomputed `embeddings_v2` for species conditioning. +- The data split is built in protein space with MMseqs clustering and binomial-species test holdout. +- `checkpoint-71000` is included for training resumption; `final_model/` is the recommended inference entrypoint. +- For compatibility, released model directories contain both `trainer_config.json` and `config.json`. + +## Sampling arguments + +- `enforce_mapping`: when `True`, each generated codon is constrained to encode the provided amino acid at that position. +- `temperature`: softmax temperature. Lower values are more deterministic; `0` selects argmax greedily. +- `top_k`: keep only the `k` highest-logit codon candidates before sampling. +- `top_p`: nucleus sampling threshold; keep the smallest probability mass whose cumulative sum reaches `p`. 
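+### Python API example
+
+A minimal usage sketch of the packaged inference wrapper described above. The model directory and protein sequences are placeholders; the method names and arguments (`from_pretrained`, `sampling`, `batch_inference`) follow `CodonTranslator/translator.py` in this release.
+
+```python
+from CodonTranslator import CodonTranslator
+
+# Load the released weights (a directory containing model.safetensors and the config files)
+model = CodonTranslator.from_pretrained("final_model", device="cuda")
+
+# Single-sequence generation; enforce_mapping constrains each codon to encode the target amino acid
+dna = model.sampling(
+    species="Homo sapiens",
+    protein_seq="MKTAYIAKQR",   # placeholder protein sequence
+    enforce_mapping=True,
+    temperature=0.8,
+    top_p=0.9,
+    seed=42,
+)
+
+# Batched generation; micro_batch_size limits how many samples are decoded at once
+dnas = model.batch_inference(
+    species=["Homo sapiens", "Escherichia coli"],
+    protein_seqs=["MKTAYIAKQR", "MSEQUENCE"],
+    enforce_mapping=True,
+    micro_batch_size=1,
+)
+```
+
+Passing `temperature=0` switches to greedy (argmax) decoding, and `use_gbif=True` in `from_pretrained` enables GBIF taxonomy lookups for species conditioning.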
diff --git a/__pycache__/precompute_embeddings.cpython-312.pyc b/__pycache__/precompute_embeddings.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a739404aabd7f009152a56910eba5d4bbf2cd3b1 Binary files /dev/null and b/__pycache__/precompute_embeddings.cpython-312.pyc differ diff --git a/__pycache__/resplit_data_v3.cpython-312.pyc b/__pycache__/resplit_data_v3.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..602e6e28acb3329ba041da32c0754e40fc5036d2 Binary files /dev/null and b/__pycache__/resplit_data_v3.cpython-312.pyc differ diff --git a/__pycache__/sampling.cpython-312.pyc b/__pycache__/sampling.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a1725c916e5cab8fb6faa988b3ee7865cac73c7 Binary files /dev/null and b/__pycache__/sampling.cpython-312.pyc differ diff --git a/__pycache__/train.cpython-312.pyc b/__pycache__/train.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0820dbc70d357aed6aa83e7c4977a0b8ea3c117 Binary files /dev/null and b/__pycache__/train.cpython-312.pyc differ diff --git a/batch_eval.py b/batch_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..2b75872bed5ba71dbe06ec0d6ae67d104b316e8c --- /dev/null +++ b/batch_eval.py @@ -0,0 +1,382 @@ +#!/usr/bin/env python3 +""" +Run eval.py across all checkpoints and datasets in parallel (multi-GPU), +and collect results to ./eval.csv. + +- Discovers checkpoints under outputs/checkpoint-* +- Evaluates on: data/test/*.parquet and data/val/*.parquet +- Uses up to N GPUs concurrently (default: 4) by setting CUDA_VISIBLE_DEVICES +- Parses the "Summary ..." line(s) from eval.py logs +- Appends rows to ./eval.csv + +Example: + python batch_eval.py \ + --outputs_dir outputs \ + --embeddings_dir embeddings \ + --datasets data/test/*.parquet data/val/*.parquet \ + --splits test val \ + --num_samples 12800 \ + --batch_size 4 \ + --gpus 0 1 2 3 \ + --eval_script eval.py \ + --device cuda + +Notes: +- This script *does not* modify your eval.py. It just orchestrates/launches it. +- Requires Python 3.8+ (standard library only). 
+""" + +import argparse +import csv +import os +import re +import sys +import time +import glob +import queue +import threading +import subprocess +from pathlib import Path +from concurrent.futures import ThreadPoolExecutor, as_completed + + +TF_SUMMARY_RE = re.compile( + r"Summary over\s+(\d+)\s+samples\s+→.*?CE=([-\d\.eE]+).*?CODON-acc=([-\d\.eE]+).*?AA-acc=([-\d\.eE]+)" +) +EVALALL_SUMMARY_RE = re.compile( + r"Full-dataset summary.*?tokens=(\d+).*?CE=([-\d\.eE]+).*?CODON-acc=([-\d\.eE]+).*?AA-acc=([-\d\.eE]+)" +) + +CSV_FIELDS = [ + "timestamp_iso", + "model_path", + "checkpoint_step", + "split", + "data_path", + "num_samples", + "batch_size", + "seed", + "eval_all", + "gpu_id", + "runtime_sec", + "tokens", + "mean_ce", + "mean_codon_acc", + "mean_aa_acc", + "status", + "error", + "command", +] + + +def parse_args(): + p = argparse.ArgumentParser(description="Parallel evaluator for CodonGPT checkpoints.") + p.add_argument("--outputs_dir", type=str, default="outputs/", help="Folder containing checkpoint-* subdirs.") + p.add_argument("--embeddings_dir", type=str, default="embeddings/", help="Embeddings dir to pass to eval.py") + p.add_argument("--datasets", nargs="+", default=["data/test/*.parquet", "data/val/*.parquet"], + help="One or more dataset globs.") + p.add_argument("--splits", nargs="+", default=["test", "val"], + help="Split names aligned with --datasets (same length).") + p.add_argument("--num_samples", type=int, default=12800, help="num_samples for eval.py (random subset mode)") + p.add_argument("--batch_size", type=int, default=4, help="batch_size for eval.py") + p.add_argument("--seed", type=int, default=42, help="seed for eval.py") + p.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"], help="Device flag for eval.py") + p.add_argument("--gpus", nargs="+", default=["0", "1", "2", "3"], help="GPU IDs to use (as CUDA_VISIBLE_DEVICES)") + p.add_argument("--eval_script", type=str, default="eval.py", help="Path to eval.py") + p.add_argument("--csv_path", type=str, default="eval.csv", help="Output CSV file") + p.add_argument("--eval_all", action="store_true", + help="Use eval.py --eval_all (streaming, no num_samples). If set, ignores --num_samples.") + p.add_argument("--workers", type=int, default=4, + help="--workers passed to eval.py when --eval_all is set.") + p.add_argument("--dry_run", action="store_true", help="List planned runs but do not execute.") + # New: filtering / resume options + p.add_argument("--start_after_step", type=int, default=-1, + help="Only evaluate checkpoints with step > this value (e.g., 73700)") + p.add_argument("--end_step", type=int, default=-1, + help="If >0, only evaluate checkpoints with step <= this value") + p.add_argument("--skip_existing", dest="skip_existing", action="store_true", default=True, + help="Skip tasks already recorded as OK in csv_path") + p.add_argument("--no-skip-existing", dest="skip_existing", action="store_false", + help="Do not skip existing OK rows; re-run everything in range") + return p.parse_args() + + +def natural_step(dirpath: Path) -> int: + """ + Extract integer step from a checkpoint dir name like 'checkpoint-21000'. + Returns -1 if not found. 
+ """ + m = re.search(r"checkpoint-(\d+)", dirpath.name) + return int(m.group(1)) if m else -1 + + +def discover_checkpoints(outputs_dir: str) -> list[Path]: + paths = sorted( + (Path(p) for p in glob.glob(os.path.join(outputs_dir, "checkpoint-*")) if os.path.isdir(p)), + key=lambda p: natural_step(p), + ) + # Optional: filter only dirs that look like real checkpoints + filtered = [] + for p in paths: + has_config = (p / "config.json").exists() or (p / "trainer_config.json").exists() + has_weights = (p / "model.safetensors").exists() or (p / "pytorch_model.bin").exists() + if has_config and has_weights: + filtered.append(p) + return filtered + + +def build_cmd(py_exec: str, + eval_script: str, + model_path: str, + data_path: str, + embeddings_dir: str, + device: str, + num_samples: int, + batch_size: int, + seed: int, + eval_all: bool, + workers: int) -> list[str]: + cmd = [py_exec, eval_script, + "--model_path", model_path, + "--data_path", data_path, + "--embeddings_dir", embeddings_dir, + "--batch_size", str(batch_size), + "--device", device, + "--seed", str(seed)] + if eval_all: + cmd += ["--eval_all", "--workers", str(workers)] + else: + cmd += ["--num_samples", str(num_samples)] + return cmd + + +def parse_metrics(stdout: str, stderr: str) -> dict: + """ + Return dict with keys: tokens, mean_ce, mean_codon_acc, mean_aa_acc (strings), + or raise ValueError if no summary line was found. + """ + text = stdout + "\n" + stderr + + # Try eval_all format first + m = EVALALL_SUMMARY_RE.search(text) + if m: + tokens, ce, codon, aa = m.groups() + return {"tokens": tokens, "mean_ce": ce, "mean_codon_acc": codon, "mean_aa_acc": aa} + + # Try teacher-forced (random-subset) summary + m = TF_SUMMARY_RE.search(text) + if m: + _samples, ce, codon, aa = m.groups() + return {"tokens": "", "mean_ce": ce, "mean_codon_acc": codon, "mean_aa_acc": aa} + + # Not found + raise ValueError("Could not find summary line in eval.py output.") + + +def run_one(task: dict, gpu_queue: "queue.Queue[str]", csv_lock: threading.Lock) -> dict: + """ + Execute one eval.py call using a GPU from the queue. Returns a row dict for CSV. 
+ """ + gpu_id = gpu_queue.get() # blocks until a GPU id is available + start = time.time() + status = "OK" + err_text = "" + + try: + env = os.environ.copy() + # Pin the subprocess to a single GPU + env["CUDA_VISIBLE_DEVICES"] = str(gpu_id) + env.setdefault("TOKENIZERS_PARALLELISM", "false") + env.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID") + + result = subprocess.run( + task["cmd"], + env=env, + capture_output=True, + text=True, + check=False, + ) + + try: + metrics = parse_metrics(result.stdout, result.stderr) + except Exception as e: + status = "FAIL" + err_text = f"{e}\n--- STDOUT ---\n{result.stdout}\n--- STDERR ---\n{result.stderr}" + metrics = {"tokens": "", "mean_ce": "", "mean_codon_acc": "", "mean_aa_acc": ""} + + if result.returncode != 0 and status == "OK": + status = "FAIL" + err_text = f"Non-zero exit code {result.returncode}\n--- STDOUT ---\n{result.stdout}\n--- STDERR ---\n{result.stderr}" + + finally: + runtime = time.time() - start + gpu_queue.put(gpu_id) # release GPU + + row = { + "timestamp_iso": time.strftime("%Y-%m-%dT%H:%M:%S"), + "model_path": task["model_path"], + "checkpoint_step": task["step"], + "split": task["split"], + "data_path": task["data_path"], + "num_samples": task["num_samples"] if not task["eval_all"] else "", + "batch_size": task["batch_size"], + "seed": task["seed"], + "eval_all": str(task["eval_all"]), + "gpu_id": str(gpu_id), + "runtime_sec": f"{runtime:.2f}", + "tokens": metrics.get("tokens", ""), + "mean_ce": metrics.get("mean_ce", ""), + "mean_codon_acc": metrics.get("mean_codon_acc", ""), + "mean_aa_acc": metrics.get("mean_aa_acc", ""), + "status": status, + "error": err_text.strip(), + "command": " ".join(task["cmd"]), + } + return row + + +def ensure_csv(path: str): + """Create CSV with header if it does not exist.""" + need_header = not os.path.exists(path) or os.path.getsize(path) == 0 + if need_header: + with open(path, "w", newline="") as f: + w = csv.DictWriter(f, fieldnames=CSV_FIELDS) + w.writeheader() + + +def read_completed_keys(path: str) -> set[tuple[int, str, str]]: + """ + Read existing CSV and return a set of (step, split, data_path) for rows with status == 'OK'. + If CSV does not exist, returns empty set. 
+ """ + keys: set[tuple[int, str, str]] = set() + if not os.path.exists(path) or os.path.getsize(path) == 0: + return keys + try: + with open(path, "r", newline="") as f: + r = csv.DictReader(f) + for row in r: + if (row.get("status") or "").strip().upper() == "OK": + try: + step = int(row.get("checkpoint_step", "-1")) + except ValueError: + continue + split = row.get("split", "") + data_path = row.get("data_path", "") + keys.add((step, split, data_path)) + except Exception: + # If CSV is malformed, resume logic is best-effort + pass + return keys + + +def append_row(path: str, row: dict, lock: threading.Lock): + with lock: + with open(path, "a", newline="") as f: + w = csv.DictWriter(f, fieldnames=CSV_FIELDS) + w.writerow(row) + f.flush() + + +def main(): + args = parse_args() + + if len(args.datasets) != len(args.splits): + print("ERROR: --datasets and --splits must have the same length.", file=sys.stderr) + sys.exit(2) + + checkpoints = discover_checkpoints(args.outputs_dir) + if not checkpoints: + print(f"No checkpoints found in {args.outputs_dir}/checkpoint-*", file=sys.stderr) + sys.exit(1) + + print(f"Discovered {len(checkpoints)} checkpoints.") + ds_pairs = list(zip(args.splits, args.datasets)) + print(f"Datasets: {', '.join([f'{s}:{d}' for s, d in ds_pairs])}") + print(f"GPUs: {', '.join(args.gpus)}") + print(f"Writing results to: {args.csv_path}") + if args.start_after_step >= 0: + print(f"Filtering: step > {args.start_after_step}") + if args.end_step > 0: + print(f"Filtering: step <= {args.end_step}") + print(f"Skip existing OK rows in CSV: {args.skip_existing}") + + # Build task list + py_exec = sys.executable + tasks = [] + completed_keys = read_completed_keys(args.csv_path) if args.skip_existing else set() + for ckpt in checkpoints: + step = natural_step(ckpt) + # Apply step filters + if args.start_after_step >= 0 and step <= args.start_after_step: + continue + if args.end_step > 0 and step > args.end_step: + continue + for split, data_path in ds_pairs: + # Skip if already evaluated with OK status + if (step, split, data_path) in completed_keys: + continue + cmd = build_cmd( + py_exec=py_exec, + eval_script=args.eval_script, + model_path=str(ckpt), + data_path=data_path, + embeddings_dir=args.embeddings_dir, + device=args.device, + num_samples=args.num_samples, + batch_size=args.batch_size, + seed=args.seed, + eval_all=args.eval_all, + workers=args.workers, + ) + tasks.append({ + "model_path": str(ckpt), + "step": step, + "split": split, + "data_path": data_path, + "num_samples": args.num_samples, + "batch_size": args.batch_size, + "seed": args.seed, + "eval_all": args.eval_all, + "cmd": cmd, + }) + + # Dry run listing + if args.dry_run: + for t in tasks: + print(f"[DRY RUN] GPU=? 
step={t['step']} split={t['split']} -> {' '.join(t['cmd'])}") + print(f"Planned runs: {len(tasks)}") + return + + # Prepare CSV + ensure_csv(args.csv_path) + csv_lock = threading.Lock() + + # GPU pool + gpu_queue: "queue.Queue[str]" = queue.Queue() + for gid in args.gpus: + gpu_queue.put(str(gid)) + + # Execute with up to len(gpus) concurrent workers + max_workers = max(1, len(args.gpus)) + with ThreadPoolExecutor(max_workers=max_workers) as ex: + futures = [ex.submit(run_one, t, gpu_queue, csv_lock) for t in tasks] + completed = 0 + total = len(futures) + for fut in as_completed(futures): + row = fut.result() + append_row(args.csv_path, row, csv_lock) + completed += 1 + if row["status"] == "OK": + print(f"[{completed}/{total}] ✅ step={row['checkpoint_step']} split={row['split']} " + f"CE={row['mean_ce']} CODON={row['mean_codon_acc']} AA={row['mean_aa_acc']} " + f"gpu={row['gpu_id']} in {row['runtime_sec']}s") + else: + print(f"[{completed}/{total}] ❌ step={row['checkpoint_step']} split={row['split']} " + f"gpu={row['gpu_id']} See CSV 'error' column for details.") + + print(f"Done. Results appended to {args.csv_path}") + + +if __name__ == "__main__": + main() diff --git a/codontranslator/__init__.py b/codontranslator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..065d44be4d3bfd998f0d5b42b30dbe4e8124ba30 --- /dev/null +++ b/codontranslator/__init__.py @@ -0,0 +1,3 @@ +from CodonTranslator import CodonTranslator + +__all__ = ["CodonTranslator"] diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..121b9a8be76185770430ae3b946ee1c562ae0a16 --- /dev/null +++ b/environment.yml @@ -0,0 +1,20 @@ +name: codontranslator +channels: + - conda-forge + - pytorch + - nvidia +dependencies: + - python=3.12 + - pip + - pytorch>=2.4 + - pandas>=2.3 + - pyarrow>=21.0 + - duckdb>=1.5 + - biopython>=1.85 + - pip: + - transformers>=4.57.0 + - esm>=3.2.3 + - safetensors>=0.7.0 + - huggingface-hub>=0.36.0 + - accelerate>=1.9.0 + - wandb>=0.21.0 diff --git a/eval.py b/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..30820f527ec00f8a93169cff7e46f25e9eab44d6 --- /dev/null +++ b/eval.py @@ -0,0 +1,1239 @@ +#!/usr/bin/env python +""" +Teacher-forced (and optional free-run) evaluation on a random subset of your +dataset to measure codon token cross-entropy and AA token accuracy, using the +same conditioning pathway as training. + +Supports either a CSV file or Parquet input via a directory/glob (e.g., +./data/val/*.parquet). 
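+
+Input files are expected to provide the columns Taxon, protein_seq and cds_DNA
+(the same required-column checks applied below for CSV and Parquet inputs).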
+ +Usage examples: + # CSV input + python eval.py \ + --model_path outputs/checkpoint-21000 \ + --data_path random_sample_1000.csv \ + --embeddings_dir embeddings \ + --num_samples 10 \ + --batch_size 10 \ + --device cuda + + # Parquet glob input + python eval.py \ + --model_path outputs/checkpoint-21000 \ + --data_path "./data/val/*.parquet" \ + --embeddings_dir embeddings \ + --num_samples 64 \ + --batch_size 32 \ + --device cuda +""" + +import argparse +import json +import logging +import random +from pathlib import Path +from typing import List, Optional, Tuple +import glob + +import torch +import torch.nn.functional as F +import pandas as pd + +from src.sampler import CodonSampler +from src.dataset import SpeciesEmbeddingStore, StreamSeqDataset, stage_collate_fn +from torch.utils.data import DataLoader + + +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, +) +logger = logging.getLogger("eval_tf") + + +def parse_args(): + p = argparse.ArgumentParser("Teacher-forced evaluation of CodonGPT") + p.add_argument("--model_path", required=True, type=str, + help="Path to checkpoint dir (with config.json / model.safetensors)") + # Input data: CSV file or Parquet glob/dir + p.add_argument("--data_path", required=False, type=str, default=None, + help="CSV file or Parquet glob/dir (e.g., ./data/val/*.parquet)") + # Back-compat: --csv_path still accepted (deprecated) + p.add_argument("--csv_path", required=False, type=str, default=None, + help="[Deprecated] CSV with columns: Taxon, protein_seq, cds_DNA") + p.add_argument("--embeddings_dir", type=str, default=None, + help="Species embeddings directory (recommended for parity)") + p.add_argument("--num_samples", type=int, default=10) + p.add_argument("--batch_size", type=int, default=10) + p.add_argument("--seed", type=int, default=42) + p.add_argument("--device", type=str, default="cuda") + p.add_argument("--workers", type=int, default=0, + help="DataLoader workers for --eval_all streaming mode") + # Free-run (sampling) evaluation options + p.add_argument("--free_run", action="store_true", + help="If set, perform real sampling instead of teacher forcing and compare to ground-truth codon sequences") + p.add_argument("--temperature", type=float, default=0.8) + p.add_argument("--top_k", type=int, default=50) + p.add_argument("--top_p", type=float, default=0.9) + p.add_argument("--control_mode", type=str, choices=["fixed","variable"], default="fixed") + p.add_argument("--enforce_translation", action="store_true", + help="Hard-mask decoding to codons matching target amino acid at each position during free-run evaluation") + # Full-dataset streaming eval (no sampling) + p.add_argument("--eval_all", action="store_true", + help="Stream over all rows from --data_path and compute aggregated metrics (memory-safe)") + p.add_argument("--max_records", type=int, default=0, + help="When --eval_all is set: limit to first N samples (0 = all)") + p.add_argument("--debug_aa_check", action="store_true", + help="Print per-sample agreement between CDS→AA (standard code) and provided protein") + # Per-sequence export over standard splits ./data/val and ./data/test + p.add_argument("--export_per_sequence", action="store_true", + help="Process ./data/val and ./data/test parquets in batches and export a per-sequence CSV") + p.add_argument("--splits_root", type=str, default="./data", + help="Root directory that contains val/ and test/ subfolders with parquet files") + 
p.add_argument("--out_csv", type=str, default="outputs/eval_per_sequence.csv", + help="Output CSV path for per-sequence export") + p.add_argument("--export_splits", nargs="+", default=["val", "test"], + help="Subdirectories under --splits_root to process (default: val test)") + p.add_argument("--max_rows_per_split", type=int, default=0, + help="When --export_per_sequence is set: limit number of rows per split (0 = all)") + p.add_argument("--progress", action="store_true", + help="Show progress bars during per-sequence export") + # Capacity and evaluation controls + p.add_argument("--no_truncation", action="store_true", + help="Fit prefix caps so generated codon length equals protein length (avoids capacity truncation)") + p.add_argument("--species_prefix_cap", type=int, default=0, + help="When >0 and --no_truncation is set, cap species token prefix to this many tokens; 0 = no species cap") + return p.parse_args() + + +def _is_parquet_path(p: str) -> bool: + lower = p.lower() + return lower.endswith(".parquet") or lower.endswith(".parq") + + +def _expand_paths(maybe_path_or_glob: str) -> List[str]: + """Expand a path/glob or directory into a sorted list of files. + Prioritize Parquet when scanning a directory. + """ + paths: List[str] = [] + P = Path(maybe_path_or_glob) + if P.is_dir(): + paths.extend(sorted(str(x) for x in P.rglob("*.parquet"))) + paths.extend(sorted(str(x) for x in P.rglob("*.parq"))) + paths.extend(sorted(str(x) for x in P.rglob("*.csv"))) + paths.extend(sorted(str(x) for x in P.rglob("*.tsv"))) + paths.extend(sorted(str(x) for x in P.rglob("*.csv.gz"))) + paths.extend(sorted(str(x) for x in P.rglob("*.tsv.gz"))) + else: + paths = sorted(glob.glob(str(P))) + # Dedup while preserving order + out: List[str] = [] + seen = set() + for x in paths: + if x not in seen: + out.append(x) + seen.add(x) + return out + + +def _load_random_samples_from_parquet(files: List[str], num_samples: int, seed: int) -> pd.DataFrame: + """Collect up to num_samples rows from a list of Parquet files, reading by row group. + Reads only the required columns and shuffles files/row-groups for decent coverage. 
+ """ + try: + import pyarrow.parquet as pq # type: ignore + except Exception as e: # pragma: no cover + raise ImportError("pyarrow is required to read parquet files") from e + + rng = random.Random(seed) + req = ["Taxon", "protein_seq", "cds_DNA"] + files = [f for f in files if _is_parquet_path(f)] + if not files: + raise FileNotFoundError("No Parquet files found to read") + files = files.copy() + rng.shuffle(files) + + collected: List[pd.DataFrame] = [] + remaining = int(max(0, num_samples)) + for fp in files: + if remaining <= 0: + break + pf = pq.ParquetFile(fp) + nrg = int(pf.num_row_groups or 0) + if nrg <= 0: + rgs = [0] + else: + rgs = list(range(nrg)) + rng.shuffle(rgs) + # Only keep columns that exist in this file + cols = [c for c in req if c in pf.schema.names] + if len(cols) < len(req): + missing = sorted(set(req) - set(cols)) + raise ValueError(f"Parquet missing required columns {missing} in {fp}") + for rg in rgs: + if remaining <= 0: + break + table = pf.read_row_group(rg, columns=cols) + df = table.to_pandas(types_mapper=None) + if df.empty: + continue + if len(df) > remaining: + df = df.sample(n=remaining, random_state=rng.randint(0, 2**31 - 1)) + collected.append(df) + remaining -= len(df) + if not collected: + return pd.DataFrame(columns=req) + out = pd.concat(collected, ignore_index=True) + # Final shuffle for randomness + out = out.sample(frac=1.0, random_state=seed).reset_index(drop=True) + # If we somehow overshot, trim + if len(out) > num_samples: + out = out.iloc[:num_samples].reset_index(drop=True) + return out + + +def _preferred_pooling(model_dir: Path) -> str: + """ + Best-effort pooling detection: + - First try checkpoint configs for an explicit hint + - Fallback to 'last' + Note: we'll further override this using the embeddings_dir contents if provided. + """ + for cfg_name in ("trainer_config.json", "config.json"): + fp = model_dir / cfg_name + if fp.exists(): + try: + with open(fp) as f: + cfg = json.load(f) + return str(cfg.get("species_pooling", "last")) + except Exception: + continue + return "last" + + +def _detect_pooling_from_embeddings_dir(emb_dir: Path) -> Optional[str]: + """Detect actual available pooling format from embeddings_dir contents.""" + fixed_files = [emb_dir / "species_embeddings.bin", emb_dir / "species_metadata.json", emb_dir / "species_vocab.json"] + seq_files = [emb_dir / "species_tok_emb.bin", emb_dir / "species_index.json", emb_dir / "species_vocab.json"] + if all(p.exists() for p in fixed_files): + return "last" + if all(p.exists() for p in seq_files): + return "sequence" + return None + + +@torch.no_grad() +def eval_batch( + sampler: CodonSampler, + species_store: Optional[SpeciesEmbeddingStore], + species_names: List[str], + protein_seqs: List[str], + dna_cds_list: List[str], +) -> Tuple[List[float], List[float]]: + """Evaluate a batch in teacher-forced mode. + + Returns per-sample (avg_ce_loss, aa_token_acc). 
+ """ + tok = sampler.tokenizer + pad_id = tok.pad_token_id + eos_id = tok.eos_token_id + + # Encode DNA to codon ids and align lengths (trim to min protein length) + codon_ids = [] + seq_lens = [] + for dna, prot in zip(dna_cds_list, protein_seqs): + # Trim to min length between DNA codons and protein AA + C_dna = len(dna) // 3 + C_prot = len(prot) + C = max(min(C_dna, C_prot), 1) + dna_trim = dna[: 3 * C] + ids = tok.encode_codon_seq(dna_trim, validate=False) + ids.append(eos_id) + codon_ids.append(ids) + seq_lens.append(len(ids)) + + B = len(codon_ids) + T = max(seq_lens) + codons = torch.full((B, T), pad_id, dtype=torch.long) + mask = torch.zeros((B, T), dtype=torch.bool) + for i, ids in enumerate(codon_ids): + L = len(ids) + codons[i, :L] = torch.tensor(ids, dtype=torch.long) + mask[i, :L] = True + + # inputs/labels aligned to training convention: + # model predicts next codon after a learned start token; labels are the + # same positions as inputs (not shifted by 1), with PAD/EOS masked out. + input_ids = codons[:, :-1] + labels_base = codons[:, :-1].clone() + # Mask out PAD and EOS like trainer.evaluate() + labels_base[labels_base == pad_id] = -100 + labels_base[labels_base == eos_id] = -100 + + # Build conditioning dict similar to training and sampler + cond = {"control_mode": "fixed"} + + if species_store is not None and species_names: + sid_list = [species_store.vocab.get(s, -1) for s in species_names] + num_unknown = sum(1 for x in sid_list if x < 0) + if num_unknown > 0: + logger.warning(f"{num_unknown}/{len(sid_list)} species not found in embeddings vocab; using zero embeddings") + result = species_store.batch_get(sid_list) + if isinstance(result, tuple): + sp_tok, _ = result # [B, Ls, Ds] + cond["species_tok_emb_src"] = sp_tok.to(sampler.device) + cond["species_tok_emb_tgt"] = sp_tok.to(sampler.device) + else: + sp = result # [B, Ds] + cond["species_emb_src"] = sp.to(sampler.device) + cond["species_emb_tgt"] = sp.to(sampler.device) + elif species_names: + # On-the-fly species embeddings using Qwen (sequence pooling for training parity) + seq_emb, _lens = sampler._qwen_embed_names(species_names, pooling="sequence") + seq_emb = seq_emb.to(sampler.device) + cond["species_tok_emb_src"] = seq_emb + cond["species_tok_emb_tgt"] = seq_emb + + # Match training: pass raw protein sequences; the model tokenizes internally + cond["protein_seqs"] = protein_seqs + + # Move tensors to device + device = sampler.device + input_ids = input_ids.to(device) + labels_base = labels_base.to(device) + + sampler.model.eval() + outputs = sampler.model(codon_ids=input_ids, cond=cond, labels=labels_base, return_dict=True) + logits = outputs["logits"] # [B, Lmax, V] aligned to per-sample capacity after prefix + try: + prefix_len = outputs.get("prefix_len", 0) + if isinstance(prefix_len, torch.Tensor): + prefix_len_dbg = int(prefix_len.max().item()) if prefix_len.numel() > 0 else 0 + else: + prefix_len_dbg = int(prefix_len) + logger.debug(f"Prefix length(max)={prefix_len_dbg}, input_len={input_ids.size(1)}") + except Exception: + pass + + # Align labels/masks to logits length and per-sample caps + Bsz, Lmax, V = logits.size(0), logits.size(1), logits.size(2) + labels_aligned = torch.full((Bsz, Lmax), -100, dtype=labels_base.dtype, device=logits.device) + common_cols = min(labels_base.size(1), Lmax) + if common_cols > 0: + labels_aligned[:, :common_cols] = labels_base[:, :common_cols] + per_cap = outputs.get("per_cap", None) + if isinstance(per_cap, torch.Tensor) and per_cap.numel() == Bsz: + ar = 
torch.arange(Lmax, device=logits.device).unsqueeze(0) + cap_mask = ar < per_cap.to(device=logits.device).unsqueeze(1) # [B,Lmax] + else: + cap_mask = torch.ones_like(labels_aligned, dtype=torch.bool, device=logits.device) + + # Mask labels beyond per-cap to -100 so CE ignores them + labels_masked = labels_aligned.clone().to(device=logits.device) + labels_masked[~cap_mask] = -100 + + # Cross-entropy per sample (include EOS target; ignore PAD) + loss_flat = F.cross_entropy( + logits.reshape(-1, V), + labels_masked.reshape(-1), + ignore_index=-100, + reduction="none", + ).view(Bsz, Lmax) + + # Accuracy per sample + preds = logits.argmax(dim=-1) + num_special = int(getattr(tok, "num_special_tokens", 0) or 0) + supervised = (labels_masked != -100) & cap_mask + if num_special > 0: + supervised = supervised & (labels_aligned >= num_special) + correct = (preds == labels_aligned) & supervised + + per_sample_ce: List[float] = [] + per_sample_acc: List[float] = [] + per_sample_aa_acc: List[float] = [] + codon2aa = tok.codon2aa_char_map() if hasattr(tok, "codon2aa_char_map") else {} + per_cap = outputs.get("per_cap", None) + per_cap_int = None + if isinstance(per_cap, torch.Tensor) and per_cap.numel() == Bsz: + per_cap_int = torch.clamp(per_cap.to(dtype=torch.long, device=logits.device), min=0, max=Lmax) + + for i in range(B): + # Average CE over valid positions + valid = (labels_masked[i] != -100) & cap_mask[i] + if num_special > 0: + valid = valid & (labels_aligned[i] >= num_special) + ce = (loss_flat[i][valid].mean().item() if valid.any() else 0.0) + per_sample_ce.append(ce) + + # Codon-level accuracy over supervised positions + denom = supervised[i].sum().item() + acc = (correct[i].sum().item() / denom) if denom > 0 else 0.0 + # AA-level accuracy per sample (match trainer) + aa_acc = 0.0 + if per_cap_int is not None and codon2aa and i < len(protein_seqs): + cap = int(per_cap_int[i].item()) + if cap > 0: + mask_row = supervised[i, :cap] + if mask_row.any(): + preds_row = preds[i, :cap][mask_row] + prot = protein_seqs[i] + seq_len = min(len(prot), preds_row.size(0)) + if seq_len > 0: + pred_aa = ''.join(codon2aa.get(int(t.item()), 'X') for t in preds_row[:seq_len]) + truth_aa = prot[:seq_len] + aa_matches = sum(1 for j in range(seq_len) if pred_aa[j] == truth_aa[j]) + aa_acc = aa_matches / seq_len + per_sample_aa_acc.append(aa_acc) + + return per_sample_ce, per_sample_aa_acc + + +def _dna_to_codons(dna: str) -> List[str]: + dna = dna.strip().upper() + return [dna[i:i+3] for i in range(0, len(dna) - (len(dna) % 3), 3)] + + +def _aa_from_dna_standard(dna: str, tok) -> str: + dna = dna.strip().upper() + gc = getattr(tok, "_genetic_code", {}) + aa = [] + for j in range(0, len(dna) - (len(dna) % 3), 3): + aa.append(gc.get(dna[j:j+3], 'X')) + return ''.join(aa) + + +def _aa_agreement(dna: str, protein: str, tok) -> Tuple[float, int, int]: + """Return (match_ratio, compared_len, first_mismatch_idx or -1) under standard code.""" + dna = dna.strip().upper() + protein = protein.strip().upper() + L = min(len(dna) // 3, len(protein)) + if L <= 0: + return 0.0, 0, -1 + aa_pred = _aa_from_dna_standard(dna[: 3 * L], tok) + truth = protein[:L] + mism_idx = -1 + matches = 0 + for i, (a, b) in enumerate(zip(aa_pred, truth)): + if a == b: + matches += 1 + elif mism_idx < 0: + mism_idx = i + return (matches / L), L, mism_idx + + +@torch.no_grad() +def eval_streaming_all( + sampler: CodonSampler, + species_store: SpeciesEmbeddingStore, + data_path: str, + batch_size: int, + num_workers: int, + max_records: int = 0, +): 
+ """Stream over all rows from CSV/Parquet inputs and compute dataset-level metrics. + + Mirrors trainer.evaluate() for parity. + """ + device = sampler.device + tok = sampler.tokenizer + pad_id = int(tok.pad_token_id) + eos_id = int(tok.eos_token_id) + num_special = int(tok.num_special_tokens) + codon2aa = tok.codon2aa_char_map() + + # Build streaming dataset and loader + from pathlib import Path as _Path + import glob as _glob + def _expand(pat: str) -> List[str]: + P = _Path(pat) + if P.is_dir(): + paths: List[str] = [] + paths.extend(sorted(str(x) for x in P.rglob("*.parquet"))) + paths.extend(sorted(str(x) for x in P.rglob("*.parq"))) + paths.extend(sorted(str(x) for x in P.rglob("*.csv"))) + paths.extend(sorted(str(x) for x in P.rglob("*.tsv"))) + paths.extend(sorted(str(x) for x in P.rglob("*.csv.gz"))) + paths.extend(sorted(str(x) for x in P.rglob("*.tsv.gz"))) + else: + paths = sorted(_glob.glob(str(P))) + # de-dup + seen = set(); out = [] + for x in paths: + if x not in seen: + out.append(x); seen.add(x) + return out + + paths = _expand(data_path) + if not paths: + raise FileNotFoundError(f"No input files matched: {data_path}") + + species_vocab_path = str((Path(species_store.embeddings_dir) / "species_vocab.json").resolve()) + ds = StreamSeqDataset( + files=paths, + tokenizer=tok, + species_vocab_path=species_vocab_path, + unknown_species_id=0, + csv_chunksize=200_000, + shuffle_buffer=0, + shard_across_ranks=False, + ) + _dl_kwargs = dict( + batch_size=int(batch_size), + shuffle=False, + drop_last=False, + num_workers=int(max(0, num_workers)), + collate_fn=stage_collate_fn, + pin_memory=True, + persistent_workers=(int(num_workers) > 0), + ) + if int(num_workers) > 0: + _dl_kwargs["prefetch_factor"] = 4 + loader = DataLoader(ds, **_dl_kwargs) + + loss_sum = 0.0 + loss_tokens = 0 + codon_correct = 0 + codon_total = 0 + aa_correct = 0 + aa_total = 0 + + seen = 0 + for batch in loader: + if not batch: + continue + if int(max_records) > 0 and seen >= int(max_records): + break + codon_ids = batch["codon_ids"].to(device) + input_ids = codon_ids[:, :-1] + labels = codon_ids[:, :-1].clone() + labels[labels == pad_id] = -100 + labels[labels == eos_id] = -100 + + # Build cond using species_store and protein_seqs + cond = {"control_mode": "fixed", "protein_seqs": batch.get("protein_seqs", [])} + sids = batch.get("species_ids") + if torch.is_tensor(sids): + sids_list = sids.detach().cpu().tolist() + else: + sids_list = [int(x) for x in sids] + res = species_store.batch_get(sids_list) + if isinstance(res, tuple): + sp_tok, _ = res + cond["species_tok_emb_src"] = sp_tok.to(device) + cond["species_tok_emb_tgt"] = sp_tok.to(device) + else: + cond["species_emb_src"] = res.to(device) + cond["species_emb_tgt"] = res.to(device) + + out = sampler.model(codon_ids=input_ids, cond=cond, labels=labels, return_dict=True) + loss = out.get("loss") + per_cap = out.get("per_cap") + logits = out.get("logits") + + tokens_in_batch = 0 + if per_cap is not None: + tokens_in_batch = int(torch.clamp(per_cap.detach(), min=0).sum().item()) + loss_tokens += tokens_in_batch + if loss is not None and tokens_in_batch > 0: + loss_sum += float(loss.detach().item()) * tokens_in_batch + + if logits is None or logits.size(1) == 0 or per_cap is None: + seen += input_ids.size(0) + continue + max_cap = logits.size(1) + batch_size = logits.size(0) + labels_aligned = torch.full((batch_size, max_cap), -100, dtype=labels.dtype, device=labels.device) + common = min(labels.size(1), max_cap) + if common > 0: + labels_aligned[:, 
:common] = labels[:, :common] + per_cap_int = torch.clamp(per_cap.to(dtype=torch.long), min=0, max=max_cap) + for row in range(batch_size): + cap = int(per_cap_int[row].item()) + if cap < max_cap: + labels_aligned[row, cap:] = -100 + supervised = labels_aligned != -100 + if num_special > 0: + supervised = supervised & (labels_aligned >= num_special) + if not supervised.any(): + seen += batch_size + continue + preds = logits.argmax(dim=-1) + codon_correct += int((preds[supervised] == labels_aligned[supervised]).sum().item()) + codon_total += int(supervised.sum().item()) + + # protein list + prot_list = cond.get("protein_seqs", []) + for row in range(batch_size): + cap = int(per_cap_int[row].item()) + if cap <= 0: + continue + mask_row = supervised[row, :cap] + if not mask_row.any(): + continue + preds_row = preds[row, :cap][mask_row] + prot = prot_list[row] if (isinstance(prot_list, list) and row < len(prot_list)) else "" + if not prot: + continue + seq_len = min(len(prot), preds_row.size(0)) + if seq_len <= 0: + continue + pred_aa = ''.join(codon2aa.get(int(t.item()), 'X') for t in preds_row[:seq_len]) + truth_aa = prot[:seq_len] + aa_correct += sum(1 for i in range(seq_len) if pred_aa[i] == truth_aa[i]) + aa_total += seq_len + seen += batch_size + + mean_ce = (loss_sum / loss_tokens) if loss_tokens > 0 else 0.0 + codon_acc = (float(codon_correct) / codon_total) if codon_total > 0 else 0.0 + aa_acc = (float(aa_correct) / aa_total) if aa_total > 0 else 0.0 + logger.info( + f"Full-dataset summary → tokens={loss_tokens} CE={mean_ce:.4f} CODON-acc={codon_acc:.4f} AA-acc={aa_acc:.4f}" + ) + return mean_ce, codon_acc, aa_acc + + +@torch.no_grad() +def sample_and_score_batched( + sampler: CodonSampler, + species_names: List[str], + protein_seqs: List[str], + target_dnas: List[str], + temperature: float, + top_k: int, + top_p: float, + control_mode: str, + batch_size: int, + enforce_translation: bool, + no_truncation: bool = False, + species_prefix_cap: int = 64, +) -> Tuple[List[float], List[float]]: + """Free-run sampling in batches; returns per-sample (codon_acc, aa_acc).""" + N = len(species_names) + # Compute target lengths in codons (min of DNA and AA lengths) + tgt_lengths = [] + tgt_codons_list = [] + for prot, dna in zip(protein_seqs, target_dnas): + cods = _dna_to_codons(dna) + L = min(len(cods), len(prot)) + if L <= 0: + L = 1 + cods = ["ATG"] # harmless default + tgt_lengths.append(L) + tgt_codons_list.append(cods[:L]) + + # Bucket indices by target length to maximize batching + buckets: dict[int, List[int]] = {} + for i, L in enumerate(tgt_lengths): + buckets.setdefault(L, []).append(i) + + codon_accs = [0.0] * N + aa_accs = [0.0] * N + + # Helper AA translation + vocab = sampler.tokenizer._genetic_code + def dna_to_aa(dna: str) -> str: + dna = dna.strip().upper() + aa = [] + for j in range(0, len(dna) - (len(dna) % 3), 3): + aa.append(vocab.get(dna[j:j+3], 'X')) + return ''.join(aa) + + for L, idxs in buckets.items(): + # Optionally tighten protein prefix so prefix+start+L ≤ capacity (species kept full unless capped) + prev_sp = getattr(sampler.model, "max_species_prefix", 0) + prev_pp = getattr(sampler.model, "max_protein_prefix", 0) + if bool(no_truncation): + try: + capacity = int(getattr(sampler.model, "max_position_embeddings", 1024)) + # If requested, apply a species token cap; otherwise keep as-is + store = getattr(sampler, "species_store", None) + if store is not None and getattr(store, "is_legacy", False) and int(species_prefix_cap) > 0: + setattr(sampler.model, 
"max_species_prefix", int(species_prefix_cap)) + # Build a representative cond for this bucket to measure exact prefix length + batch_idx_probe = idxs[: min(len(idxs), max(1, min(batch_size, 8)))] + sp_probe = [species_names[i] for i in batch_idx_probe] + pr_probe = [protein_seqs[i] for i in batch_idx_probe] + # Map species to ids via store vocab + cond_probe = {"control_mode": "fixed", "protein_seqs": pr_probe} + if store is not None: + sid_list = [store.vocab.get(s, -1) for s in sp_probe] + res = store.batch_get(sid_list) + if isinstance(res, tuple): + sp_tok, _ = res + cond_probe["species_tok_emb_src"] = sp_tok.to(sampler.device) + cond_probe["species_tok_emb_tgt"] = sp_tok.to(sampler.device) + else: + cond_probe["species_emb_src"] = res.to(sampler.device) + cond_probe["species_emb_tgt"] = res.to(sampler.device) + # Iteratively reduce protein prefix cap until remaining ≥ L + for _ in range(3): + out0 = sampler.model( + codon_ids=torch.zeros(len(batch_idx_probe), 0, dtype=torch.long, device=sampler.device), + cond=cond_probe, + return_dict=True, + use_cache=True, + ) + pref = out0.get("prefix_len") + if isinstance(pref, torch.Tensor) and pref.numel() > 0: + pref_max = int(pref.max().item()) + else: + pref_max = int(pref) if isinstance(pref, int) else 0 + remaining = capacity - (pref_max + 1) + if remaining >= int(L): + break + need = int(L) - max(0, int(remaining)) + cur_pp = int(getattr(sampler.model, "max_protein_prefix", 0) or 0) + new_pp = max(0, cur_pp - need) if cur_pp > 0 else max(0, pref_max - (capacity - 1 - int(L))) + setattr(sampler.model, "max_protein_prefix", int(new_pp)) + except Exception: + pass + # Process in mini-batches + for k in range(0, len(idxs), batch_size): + batch_idx = idxs[k:k+batch_size] + sp_b = [species_names[i] for i in batch_idx] + pr_b = [protein_seqs[i] for i in batch_idx] + # Sample in one call + out = sampler.sample( + num_sequences=len(batch_idx), + sequence_length=L, + species=sp_b, + protein_sequences=pr_b, + control_mode=control_mode, + temperature=temperature, + top_k=top_k, + top_p=top_p, + return_intermediate=False, + progress_bar=False, + enforce_translation=enforce_translation, + ) + gen_list: List[str] = out["sequences"] # DNA strings + # Score each + for pos, idx in enumerate(batch_idx): + tgt_codons = tgt_codons_list[idx] + gen_codons = _dna_to_codons(gen_list[pos])[:L] + matches = sum(1 for a,b in zip(gen_codons, tgt_codons) if a == b) + codon_accs[idx] = (matches / L) if L > 0 else 0.0 + gen_aa = dna_to_aa(''.join(gen_codons)) + tgt_aa = protein_seqs[idx][:L] + # Treat non-canonical AA in target as "match any" + canonical = set("ACDEFGHIKLMNPQRSTVWY") + aa_matches = sum(1 for a,b in zip(gen_aa, tgt_aa) if (b not in canonical) or (a == b)) + aa_accs[idx] = (aa_matches / L) if L > 0 else 0.0 + # Restore caps + if bool(no_truncation): + try: + setattr(sampler.model, "max_species_prefix", prev_sp) + setattr(sampler.model, "max_protein_prefix", prev_pp) + except Exception: + pass + + return codon_accs, aa_accs + + +@torch.no_grad() +def generate_and_score_batched( + sampler: CodonSampler, + species_names: List[str], + protein_seqs: List[str], + target_dnas: List[str], + temperature: float, + top_k: int, + top_p: float, + control_mode: str, + batch_size: int, + enforce_translation: bool, + no_truncation: bool = False, + species_prefix_cap: int = 64, +) -> Tuple[List[str], List[float], List[float]]: + """Like sample_and_score_batched but also returns generated DNA sequences per sample.""" + N = len(species_names) + tgt_lengths = [] + 
tgt_codons_list = [] + for prot, dna in zip(protein_seqs, target_dnas): + cods = _dna_to_codons(dna) + L = min(len(cods), len(prot)) + if L <= 0: + L = 1 + cods = ["ATG"] + tgt_lengths.append(L) + tgt_codons_list.append(cods[:L]) + + buckets: dict[int, List[int]] = {} + for i, L in enumerate(tgt_lengths): + buckets.setdefault(L, []).append(i) + + gen_all = [""] * N + codon_accs = [0.0] * N + aa_accs = [0.0] * N + + vocab = sampler.tokenizer._genetic_code + def dna_to_aa(dna: str) -> str: + dna = dna.strip().upper() + aa = [] + for j in range(0, len(dna) - (len(dna) % 3), 3): + aa.append(vocab.get(dna[j:j+3], 'X')) + return ''.join(aa) + + for L, idxs in buckets.items(): + prev_sp = getattr(sampler.model, "max_species_prefix", 0) + prev_pp = getattr(sampler.model, "max_protein_prefix", 0) + if bool(no_truncation): + try: + capacity = int(getattr(sampler.model, "max_position_embeddings", 1024)) + store = getattr(sampler, "species_store", None) + if store is not None and getattr(store, "is_legacy", False) and int(species_prefix_cap) > 0: + setattr(sampler.model, "max_species_prefix", int(species_prefix_cap)) + batch_idx_probe = idxs[: min(len(idxs), max(1, min(batch_size, 8)))] + sp_probe = [species_names[i] for i in batch_idx_probe] + pr_probe = [protein_seqs[i] for i in batch_idx_probe] + cond_probe = {"control_mode": "fixed", "protein_seqs": pr_probe} + if store is not None: + sid_list = [store.vocab.get(s, -1) for s in sp_probe] + res = store.batch_get(sid_list) + if isinstance(res, tuple): + sp_tok, _ = res + cond_probe["species_tok_emb_src"] = sp_tok.to(sampler.device) + cond_probe["species_tok_emb_tgt"] = sp_tok.to(sampler.device) + else: + cond_probe["species_emb_src"] = res.to(sampler.device) + cond_probe["species_emb_tgt"] = res.to(sampler.device) + for _ in range(3): + out0 = sampler.model( + codon_ids=torch.zeros(len(batch_idx_probe), 0, dtype=torch.long, device=sampler.device), + cond=cond_probe, + return_dict=True, + use_cache=True, + ) + pref = out0.get("prefix_len") + pref_max = int(pref.max().item()) if isinstance(pref, torch.Tensor) and pref.numel() > 0 else (int(pref) if isinstance(pref, int) else 0) + remaining = capacity - (pref_max + 1) + if remaining >= int(L): + break + need = int(L) - max(0, int(remaining)) + cur_pp = int(getattr(sampler.model, "max_protein_prefix", 0) or 0) + new_pp = max(0, cur_pp - need) if cur_pp > 0 else max(0, pref_max - (capacity - 1 - int(L))) + setattr(sampler.model, "max_protein_prefix", int(new_pp)) + except Exception: + pass + for k in range(0, len(idxs), batch_size): + batch_idx = idxs[k:k+batch_size] + sp_b = [species_names[i] for i in batch_idx] + pr_b = [protein_seqs[i] for i in batch_idx] + out = sampler.sample( + num_sequences=len(batch_idx), + sequence_length=L, + species=sp_b, + protein_sequences=pr_b, + control_mode=control_mode, + temperature=temperature, + top_k=top_k, + top_p=top_p, + return_intermediate=False, + progress_bar=False, + enforce_translation=enforce_translation, + ) + gen_list: List[str] = out["sequences"] + for pos, idx in enumerate(batch_idx): + gen_seq = gen_list[pos] + gen_all[idx] = gen_seq + tgt_codons = tgt_codons_list[idx] + gen_codons = _dna_to_codons(gen_seq)[:L] + matches = sum(1 for a,b in zip(gen_codons, tgt_codons) if a == b) + codon_accs[idx] = (matches / L) if L > 0 else 0.0 + gen_aa = dna_to_aa(''.join(gen_codons)) + tgt_aa = protein_seqs[idx][:L] + canonical = set("ACDEFGHIKLMNPQRSTVWY") + aa_matches = sum(1 for a,b in zip(gen_aa, tgt_aa) if (b not in canonical) or (a == b)) + aa_accs[idx] = 
(aa_matches / L) if L > 0 else 0.0 + if bool(no_truncation): + try: + setattr(sampler.model, "max_species_prefix", prev_sp) + setattr(sampler.model, "max_protein_prefix", prev_pp) + except Exception: + pass + + return gen_all, codon_accs, aa_accs + + +def export_per_sequence_over_splits( + sampler: CodonSampler, + splits: List[str], + splits_root: str, + out_csv: str, + batch_size: int, + temperature: float, + top_k: int, + top_p: float, + control_mode: str, + enforce_translation: bool, + progress: bool = False, + max_rows_per_split: int = 0, + no_truncation: bool = False, + species_prefix_cap: int = 0, +) -> None: + """Process ./data/val and ./data/test (or under splits_root) and write a per-sequence CSV.""" + try: + import pyarrow.parquet as pq # type: ignore + except Exception as e: + raise ImportError("pyarrow is required for Parquet evaluation/export") from e + + from pathlib import Path as _P + import os as _os + total_written = 0 + # Pre-create CSV with header so users can tail it immediately + header_cols = [ + "split", + "organism", + "protein_seq", + "codon_seq", + "predicted_seq", + "codon_similarity", + "amino_acid_recovery_rate", + ] + _P(out_csv).parent.mkdir(parents=True, exist_ok=True) + if not _P(out_csv).exists() or _os.path.getsize(out_csv) == 0: + with open(out_csv, "w", newline="") as f: + f.write(",".join(header_cols) + "\n") + logging.info(f"Initialized CSV with header → {out_csv}") + for split in splits: + rows_remaining = int(max_rows_per_split) if int(max_rows_per_split) > 0 else None + dir_path = Path(splits_root) / split + files = sorted(str(p) for p in dir_path.glob("*.parquet")) + if not files: + logging.warning(f"No parquet files found in {dir_path}, skipping split {split}") + continue + logging.info(f"Processing split '{split}' with {len(files)} files ...") + try: + from tqdm import tqdm # type: ignore + _wrap = (lambda it, **kw: tqdm(it, **kw)) if progress else (lambda it, **kw: it) + except Exception: + _wrap = (lambda it, **kw: it) + stop_split = False + for fp in _wrap(files, desc=f"{split} files", unit="file"): + if rows_remaining is not None and rows_remaining <= 0: + break + pf = pq.ParquetFile(fp) + nrg = int(pf.num_row_groups or 0) + rgs = list(range(max(nrg, 1))) + # Build a per-file rows progress bar (prefer total rows from metadata when available) + rows_total = None + try: + if pf.metadata is not None: + rows_total = 0 + for rg_idx in rgs: + rg_md = pf.metadata.row_group(rg_idx) + if rg_md is not None and rg_md.num_rows is not None: + rows_total += int(rg_md.num_rows) + except Exception: + rows_total = None + rows_pbar = None + if progress: + try: + from tqdm import tqdm # type: ignore + rows_pbar = tqdm(total=rows_total, desc=f"{split}:{Path(fp).name}", unit="rows", leave=False) + except Exception: + rows_pbar = None + + for rg in rgs: + if rows_remaining is not None and rows_remaining <= 0: + stop_split = True + break + table = pf.read_row_group(rg, columns=["Taxon", "protein_seq", "cds_DNA"]) + df = table.to_pandas() + if df.empty: + continue + species = df["Taxon"].astype(str).tolist() + proteins = df["protein_seq"].astype(str).str.upper().tolist() + dnas = df["cds_DNA"].astype(str).str.upper().tolist() + + # Generate predictions and metrics in streaming mini-batches to keep + # memory stable and update progress frequently + N = len(species) + for off in range(0, N, batch_size): + if rows_remaining is not None and rows_remaining <= 0: + stop_split = True + break + sp_b = species[off: off + batch_size] + pr_b = proteins[off: off + batch_size] 
+ dn_b = dnas[off: off + batch_size] + gen_list, codon_accs, aa_accs = generate_and_score_batched( + sampler, + sp_b, + pr_b, + dn_b, + temperature=temperature, + top_k=top_k, + top_p=top_p, + control_mode=control_mode, + batch_size=batch_size, + enforce_translation=enforce_translation, + no_truncation=bool(no_truncation), + species_prefix_cap=int(species_prefix_cap), + ) + rows_batch: List[dict] = [] + for sp, pr, dn, gen, cacc, aacc in zip(sp_b, pr_b, dn_b, gen_list, codon_accs, aa_accs): + L = min(len(pr), len(dn) // 3) + tgt_dna = dn[: 3 * L] + rows_batch.append({ + "split": split, + "organism": sp, + "protein_seq": pr, + "codon_seq": tgt_dna, + "predicted_seq": gen, + "codon_similarity": float(cacc), + "amino_acid_recovery_rate": float(aacc), + }) + if rows_batch: + if rows_remaining is not None and len(rows_batch) > rows_remaining: + rows_batch = rows_batch[: rows_remaining] + out_exists = _P(out_csv).exists() and _os.path.getsize(out_csv) > 0 + df_out = pd.DataFrame(rows_batch) + _P(out_csv).parent.mkdir(parents=True, exist_ok=True) + df_out.to_csv(out_csv, mode='a', header=not out_exists, index=False) + total_written += len(rows_batch) + if rows_remaining is not None: + rows_remaining -= len(rows_batch) + if rows_pbar is not None: + try: + rows_pbar.update(len(rows_batch)) + except Exception: + pass + if rows_remaining is not None and rows_remaining <= 0: + stop_split = True + break + if rows_pbar is not None: + try: + rows_pbar.close() + except Exception: + pass + if stop_split: + break + logging.info(f"Per-sequence export complete → {out_csv} (rows={total_written})") + + +def main(): + args = parse_args() + random.seed(args.seed) + torch.manual_seed(args.seed) + + model_dir = Path(args.model_path) + pooling = _preferred_pooling(model_dir) + logger.info(f"Preferred species_pooling from checkpoint: {pooling}") + + # Set up species store (recommended for parity) + species_store = None + if args.embeddings_dir: + emb_dir = Path(args.embeddings_dir) + detected = _detect_pooling_from_embeddings_dir(emb_dir) + if detected is not None and detected != pooling: + logger.info(f"Overriding pooling from checkpoint ({pooling}) → embeddings_dir format ({detected})") + pooling = detected + species_store = SpeciesEmbeddingStore(args.embeddings_dir, pooling=pooling) + logger.info(f"Loaded species store with {len(species_store.vocab)} species (pooling={pooling})") + + # Load sampler/model (uses same construction as sampling) + sampler = CodonSampler( + model_path=args.model_path, + device=("cuda" if args.device == "cuda" and torch.cuda.is_available() else "cpu"), + species_store=species_store, + ) + + # Load input data and sample rows + if bool(args.export_per_sequence): + export_per_sequence_over_splits( + sampler, + splits=list(args.export_splits), + splits_root=str(args.splits_root), + out_csv=str(args.out_csv), + batch_size=int(args.batch_size), + temperature=float(args.temperature), + top_k=int(args.top_k), + top_p=float(args.top_p), + control_mode=str(args.control_mode), + enforce_translation=bool(args.enforce_translation), + progress=bool(args.progress), + max_rows_per_split=int(args.max_rows_per_split), + no_truncation=bool(args.no_truncation), + species_prefix_cap=int(args.species_prefix_cap), + ) + return + + data_path = args.data_path or args.csv_path + if data_path is None: + raise SystemExit("Please provide --data_path (CSV or Parquet glob/dir). 
--csv_path remains as a deprecated alias.") + + # Expand paths to decide CSV vs Parquet + paths = _expand_paths(data_path) + if not paths: + raise FileNotFoundError(f"No input files matched: {data_path}") + + if all(_is_parquet_path(p) for p in paths): + logger.info(f"Reading up to {args.num_samples} samples from {len(paths)} parquet files ...") + df_s = _load_random_samples_from_parquet(paths, int(args.num_samples), int(args.seed)) + else: + # Fallback to CSV/TSV single file behavior (back-compat). If multiple files match, use the first. + csv_file = None + for pth in paths: + if pth.lower().endswith((".csv", ".tsv", ".csv.gz", ".tsv.gz")): + csv_file = pth + break + if csv_file is None: + raise ValueError(f"Unsupported input for --data_path: {paths[0]}") + logger.info(f"Reading CSV file: {csv_file}") + df = pd.read_csv(csv_file) + required = {"Taxon", "protein_seq", "cds_DNA"} + if not required.issubset(set(df.columns)): + missing = required - set(df.columns) + raise ValueError(f"CSV missing required columns: {sorted(missing)}") + if args.num_samples > len(df): + logger.warning(f"num_samples {args.num_samples} > CSV rows {len(df)}; reducing") + args.num_samples = len(df) + # Random sample without replacement + indices = random.sample(range(len(df)), args.num_samples) + df_s = df.iloc[indices].reset_index(drop=True) + + if len(df_s) == 0: + raise ValueError("No samples loaded from the provided data_path") + + logger.info(f"Loaded {len(df_s)} samples for evaluation") + + species = df_s["Taxon"].astype(str).tolist() + proteins = df_s["protein_seq"].astype(str).str.upper().tolist() + dnas = df_s["cds_DNA"].astype(str).str.upper().tolist() + + if not args.free_run: + if bool(args.eval_all): + if not args.embeddings_dir: + raise SystemExit("--eval_all requires --embeddings_dir for species vocab/embeddings") + # Stream the entire dataset and compute dataset-level metrics (training-parity) + eval_streaming_all( + sampler, + species_store if species_store is not None else SpeciesEmbeddingStore(args.embeddings_dir, pooling=pooling), + data_path, + batch_size=int(args.batch_size), + num_workers=int(args.workers), + max_records=int(args.max_records), + ) + return + # Optional: print per-sample CDS→AA agreement (standard code) + if bool(args.debug_aa_check): + for idx, (sp, pr, dn) in enumerate(zip(species, proteins, dnas), start=1): + ratio, Lcmp, first_bad = _aa_agreement(dn, pr, sampler.tokenizer) + flag = "OK" if ratio == 1.0 and Lcmp > 0 else ("EMPTY" if Lcmp == 0 else "MISMATCH") + extra = f" first_mismatch={first_bad}" if first_bad >= 0 else "" + logger.info(f"AA-CHECK Sample {idx:02d}: {flag} match={ratio:.3f} len={Lcmp}{extra} Taxon={sp}") + # (No dataset-level filtering to keep evaluation simple.) + # Teacher-forced evaluation (random subset) + per_ce_all: List[float] = [] + per_aa_acc_all: List[float] = [] + per_codon_acc_all: List[float] = [] + bs = max(1, int(args.batch_size)) + for i in range(0, len(species), bs): + sp_b = species[i:i+bs] + pr_b = proteins[i:i+bs] + dn_b = dnas[i:i+bs] + ce, aa_acc = eval_batch(sampler, species_store, sp_b, pr_b, dn_b) + # Also compute per-sample codon-acc using the same batch forward for consistency + # Re-run lightweight preds for codon-acc is unnecessary because eval_batch already + # computed supervised mask and preds internally; instead, recompute quickly here + # by calling eval_batch and deriving codon-acc inside it. For simplicity and clarity + # we re-derive codon-acc below using the same masking rules. 
+ per_ce_all.extend(ce) + per_aa_acc_all.extend(aa_acc) + + # Derive codon-acc for this batch + # Prepare a mirrored forward to access logits and masks (small overhead acceptable) + tok = sampler.tokenizer + pad_id = tok.pad_token_id + eos_id = tok.eos_token_id + codon_ids_local = [] + for dna, prot in zip(dn_b, pr_b): + C_dna = len(dna) // 3 + C_prot = len(prot) + C = max(min(C_dna, C_prot), 1) + dna_trim = dna[: 3 * C] + ids = tok.encode_codon_seq(dna_trim, validate=False) + ids.append(eos_id) + codon_ids_local.append(ids) + B_b = len(codon_ids_local) + T_b = max(len(x) for x in codon_ids_local) + codons_b = torch.full((B_b, T_b), pad_id, dtype=torch.long) + mask_b = torch.zeros((B_b, T_b), dtype=torch.bool) + for j, ids in enumerate(codon_ids_local): + Lb = len(ids) + codons_b[j, :Lb] = torch.tensor(ids, dtype=torch.long) + mask_b[j, :Lb] = True + input_ids_b = codons_b[:, :-1].to(sampler.device) + labels_b = codons_b[:, :-1].clone() + labels_b[labels_b == pad_id] = -100 + labels_b[labels_b == eos_id] = -100 + cond_b = {"control_mode": "fixed"} + if species_store is not None and sp_b: + sids_b = [species_store.vocab.get(s, -1) for s in sp_b] + res_b = species_store.batch_get(sids_b) + if isinstance(res_b, tuple): + sp_tok_b, _ = res_b + cond_b["species_tok_emb_src"] = sp_tok_b.to(sampler.device) + cond_b["species_tok_emb_tgt"] = sp_tok_b.to(sampler.device) + else: + sp_fix_b = res_b + cond_b["species_emb_src"] = sp_fix_b.to(sampler.device) + cond_b["species_emb_tgt"] = sp_fix_b.to(sampler.device) + cond_b["protein_seqs"] = pr_b + out_b = sampler.model(codon_ids=input_ids_b, cond=cond_b, labels=labels_b.to(sampler.device), return_dict=True) + logits_b = out_b["logits"] + per_cap_b = out_b.get("per_cap") + if logits_b is not None and per_cap_b is not None: + Bsz, Lmax, V = logits_b.size(0), logits_b.size(1), logits_b.size(2) + labels_aligned_b = torch.full((Bsz, Lmax), -100, dtype=labels_b.dtype, device=logits_b.device) + common_cols_b = min(labels_b.size(1), Lmax) + if common_cols_b > 0: + labels_aligned_b[:, :common_cols_b] = labels_b.to(logits_b.device)[:, :common_cols_b] + ar = torch.arange(Lmax, device=logits_b.device).unsqueeze(0) + cap_mask_b = ar < per_cap_b.to(device=logits_b.device).unsqueeze(1) + labels_masked_b = labels_aligned_b.clone() + labels_masked_b[~cap_mask_b] = -100 + preds_b = logits_b.argmax(dim=-1) + num_special = int(getattr(tok, "num_special_tokens", 0) or 0) + supervised_b = (labels_masked_b != -100) & cap_mask_b + if num_special > 0: + supervised_b = supervised_b & (labels_aligned_b >= num_special) + for r in range(Bsz): + denom = int(supervised_b[r].sum().item()) + cod_acc = (float((preds_b[r][supervised_b[r]] == labels_aligned_b[r][supervised_b[r]]).sum().item()) / denom) if denom > 0 else 0.0 + per_codon_acc_all.append(cod_acc) + + for idx, (ce, aa, ca) in enumerate(zip(per_ce_all, per_aa_acc_all, per_codon_acc_all), start=1): + logger.info(f"Sample {idx:02d}: CE={ce:.4f} CODON-acc={ca:.4f} AA-acc={aa:.4f}") + if per_ce_all: + mean_ce = sum(per_ce_all) / len(per_ce_all) + mean_aa = sum(per_aa_acc_all) / len(per_aa_acc_all) if per_aa_acc_all else 0.0 + mean_codon = sum(per_codon_acc_all) / len(per_codon_acc_all) if per_codon_acc_all else 0.0 + logger.info(f"Summary over {len(per_ce_all)} samples → mean CE={mean_ce:.4f}, mean CODON-acc={mean_codon:.4f}, mean AA-acc={mean_aa:.4f}") + else: + # Free-run sampling evaluation vs ground-truth DNA (codon-level), batched + codon_accs, aa_accs = sample_and_score_batched( + sampler, + species, + proteins, + dnas, + 
temperature=args.temperature, + top_k=args.top_k, + top_p=args.top_p, + control_mode=args.control_mode, + batch_size=int(args.batch_size), + enforce_translation=bool(args.enforce_translation), + no_truncation=bool(args.no_truncation), + species_prefix_cap=int(args.species_prefix_cap), + ) + for idx, (cacc, aacc) in enumerate(zip(codon_accs, aa_accs), start=1): + logger.info(f"Sample {idx:02d}: CODON-acc={cacc:.4f} AA-acc={aacc:.4f}") + if codon_accs: + mean_c = sum(codon_accs) / len(codon_accs) + mean_a = sum(aa_accs) / len(aa_accs) + logger.info(f"Summary over {len(codon_accs)} samples → mean CODON-acc={mean_c:.4f}, mean AA-acc={mean_a:.4f}") + + +if __name__ == "__main__": + main() diff --git a/final_model/config.json b/final_model/config.json new file mode 100644 index 0000000000000000000000000000000000000000..6cc9978ec27383e958dfa2b8653d337c382322af --- /dev/null +++ b/final_model/config.json @@ -0,0 +1,17 @@ +{ + "max_length": 2048, + "max_species_prefix": 0, + "max_protein_prefix": 1024, + "hidden_size": 750, + "num_hidden_layers": 20, + "num_attention_heads": 15, + "mlp_ratio": 3.2, + "prepend_species": true, + "prepend_protein": true, + "species_embedding_dim": 1024, + "esm_model_name": "esmc_300m", + "esm_device": "cuda:0", + "esm_dtype": "bf16", + "attn_impl": "mha", + "num_kv_groups": 5 +} \ No newline at end of file diff --git a/final_model/model.safetensors b/final_model/model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..ccb8b0b56cf805fef2e64295495944b4d4003323 --- /dev/null +++ b/final_model/model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5af6fe27a93e8a5edf622131b8fff74240f90db036a95697cfe4f28af1d23ef9 +size 1284544520 diff --git a/final_model/trainer_config.json b/final_model/trainer_config.json new file mode 100644 index 0000000000000000000000000000000000000000..6cc9978ec27383e958dfa2b8653d337c382322af --- /dev/null +++ b/final_model/trainer_config.json @@ -0,0 +1,17 @@ +{ + "max_length": 2048, + "max_species_prefix": 0, + "max_protein_prefix": 1024, + "hidden_size": 750, + "num_hidden_layers": 20, + "num_attention_heads": 15, + "mlp_ratio": 3.2, + "prepend_species": true, + "prepend_protein": true, + "species_embedding_dim": 1024, + "esm_model_name": "esmc_300m", + "esm_device": "cuda:0", + "esm_dtype": "bf16", + "attn_impl": "mha", + "num_kv_groups": 5 +} \ No newline at end of file diff --git a/final_model/trainer_state.json b/final_model/trainer_state.json new file mode 100644 index 0000000000000000000000000000000000000000..b3dba237ade214bc8ef5aaa2d004acaa4760578b --- /dev/null +++ b/final_model/trainer_state.json @@ -0,0 +1,4 @@ +{ + "epoch": 2, + "global_step": 120513 +} \ No newline at end of file diff --git a/final_model/vocab.json b/final_model/vocab.json new file mode 100644 index 0000000000000000000000000000000000000000..01c8e6b032b471eac80fe56cf65ee81e62f49921 --- /dev/null +++ b/final_model/vocab.json @@ -0,0 +1,78 @@ +{ + "special_token_str": { + "bos": "", + "eos": "", + "pad": "", + "unk": "" + }, + "vocab": { + "": 2, + "": 0, + "": 3, + "": 1, + "AAA": 4, + "AAC": 5, + "AAG": 6, + "AAT": 7, + "ACA": 8, + "ACC": 9, + "ACG": 10, + "ACT": 11, + "AGA": 12, + "AGC": 13, + "AGG": 14, + "AGT": 15, + "ATA": 16, + "ATC": 17, + "ATG": 18, + "ATT": 19, + "CAA": 20, + "CAC": 21, + "CAG": 22, + "CAT": 23, + "CCA": 24, + "CCC": 25, + "CCG": 26, + "CCT": 27, + "CGA": 28, + "CGC": 29, + "CGG": 30, + "CGT": 31, + "CTA": 32, + "CTC": 33, + "CTG": 34, + "CTT": 35, + "GAA": 36, + "GAC": 37, + "GAG": 
38, + "GAT": 39, + "GCA": 40, + "GCC": 41, + "GCG": 42, + "GCT": 43, + "GGA": 44, + "GGC": 45, + "GGG": 46, + "GGT": 47, + "GTA": 48, + "GTC": 49, + "GTG": 50, + "GTT": 51, + "TAA": 52, + "TAC": 53, + "TAG": 54, + "TAT": 55, + "TCA": 56, + "TCC": 57, + "TCG": 58, + "TCT": 59, + "TGA": 60, + "TGC": 61, + "TGG": 62, + "TGT": 63, + "TTA": 64, + "TTC": 65, + "TTG": 66, + "TTT": 67 + } +} \ No newline at end of file diff --git a/precompute_embeddings.py b/precompute_embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..57051e998c21f54d1c50ad076b09476969516d8a --- /dev/null +++ b/precompute_embeddings.py @@ -0,0 +1,503 @@ +#!/usr/bin/env python +""" +Precompute species embeddings for CodonTranslator training. +Protein embeddings are now computed on-the-fly using integrated ESM-C model. + +Steps: +1. Build taxonomy database from GBIF API +2. Generate species embeddings using Qwen3-Embedding-0.6B +""" + +import os +import json +import logging +import argparse +from pathlib import Path +from typing import Dict, List, Optional, Tuple +import glob +import requests +import time +from collections import defaultdict + +import numpy as np +import pandas as pd +import torch +from tqdm import tqdm + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + + +def build_taxonomy_database(species_list: List[str]) -> Dict[str, str]: + """Query GBIF API for comprehensive phylogenetic taxonomy of species. + + Creates detailed taxonomic descriptions for better species embeddings. + """ + taxonomy_db = {} + base_url = "https://api.gbif.org/v1/species/match" + + logger.info(f"Building taxonomy database for {len(species_list)} species...") + for species in tqdm(species_list, desc="Querying GBIF"): + if not species or species in taxonomy_db: + continue + + try: + response = requests.get(base_url, params={"name": species}) + if response.status_code == 200: + data = response.json() + if data.get("matchType") != "NONE": + # Build comprehensive taxonomy description + parts = [] + + # Add scientific classification + taxonomy = [] + for rank in ["kingdom", "phylum", "class", "order", "family", "genus", "species"]: + if rank in data and data[rank]: + taxonomy.append(data[rank]) + + if taxonomy: + parts.append("Taxonomy: " + " > ".join(taxonomy)) + + # Add common name if available + if "vernacularName" in data and data["vernacularName"]: + parts.append(f"Common name: {data['vernacularName']}") + + # Add confidence score + if "confidence" in data: + parts.append(f"Match confidence: {data['confidence']}%") + + # Add status (accepted, synonym, etc.) + if "status" in data: + parts.append(f"Status: {data['status']}") + + # Combine all parts into comprehensive description + taxonomy_db[species] = ". 
".join(parts) if parts else species + else: + # No match found - use species name with indicator + taxonomy_db[species] = f"Species: {species} (no GBIF match)" + else: + taxonomy_db[species] = f"Species: {species} (query failed)" + + # Rate limiting + time.sleep(0.1) + except Exception as e: + logger.warning(f"Error querying GBIF for {species}: {e}") + taxonomy_db[species] = f"Species: {species} (error)" + + logger.info(f"Taxonomy database built with {len(taxonomy_db)} entries") + return taxonomy_db + + +def generate_species_embeddings_qwen( + species_list: List[str], + taxonomy_db: Dict[str, str], + device: str = "cuda", + pooling: str = "last" # 'last' -> single vector; 'sequence'/'none' -> variable-length tokens +) -> Tuple[Dict[str, int], Dict[int, np.ndarray]]: + """ + Generate species embeddings using Qwen3-Embedding-0.6B. + - pooling='last': returns one vector per species (fixed size) + - pooling='none': returns variable-length token embeddings per species + """ + import torch.nn.functional as F + from transformers import AutoTokenizer, AutoModel + + def last_token_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + """Pool by taking the last valid token's embedding.""" + left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0]) + if left_padding: + return last_hidden_states[:, -1] + else: + sequence_lengths = attention_mask.sum(dim=1) - 1 + batch_size = last_hidden_states.shape[0] + return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths] + + def get_detailed_instruct(task_description: str, query: str) -> str: + """Format the input with instruction for better embedding quality.""" + return f'Instruct: {task_description}\nQuery: {query}' + + logger.info("Loading Qwen3-Embedding-0.6B model...") + model_name = "Qwen/Qwen3-Embedding-0.6B" + + # Initialize with left padding for last token pooling + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, padding_side='left') + model = AutoModel.from_pretrained(model_name, torch_dtype=torch.float16, trust_remote_code=True).to(device).eval() + + species_vocab = {} + species_embeddings = {} + + # Task description for species embedding + task = "Given a species taxonomy information, generate a biological embedding representing its taxonomic and evolutionary characteristics" + + for idx, species in enumerate(tqdm(species_list, desc="Generating embeddings")): + # Get comprehensive taxonomy string from GBIF query results + taxonomy_str = taxonomy_db.get(species, species) + + # Format with instruction for better semantic understanding + input_text = get_detailed_instruct(task, taxonomy_str) + + # Generate embeddings + with torch.no_grad(): + inputs = tokenizer( + input_text, + return_tensors="pt", + padding=True, + truncation=True, + max_length=512 + ) + inputs = {k: v.to(device) for k, v in inputs.items()} + + outputs = model(**inputs) + hidden = outputs.last_hidden_state # [1, L, D] + if pooling == 'last': + pooled_embedding = last_token_pool(hidden, inputs['attention_mask']) + normalized_embedding = F.normalize(pooled_embedding, p=2, dim=1) + species_embedding = normalized_embedding.squeeze(0).cpu().numpy() # [D] + else: + # Variable-length token embeddings (normalize per token) + tok = hidden.squeeze(0) # [L, D] + tok = F.normalize(tok, p=2, dim=-1) + species_embedding = tok.cpu().numpy() # [L, D] + + species_vocab[species] = idx + species_embeddings[idx] = species_embedding + + logger.info(f"Generated {'fixed-size' if 
pooling=='last' else 'variable-length'} embeddings for {len(species_vocab)} species") + return species_vocab, species_embeddings + + +def save_species_embeddings_memmap( + species_vocab: Dict[str, int], + species_embeddings: Dict[int, np.ndarray], + output_dir: str +) -> None: + """Save fixed-size species embeddings as memory-mapped file.""" + os.makedirs(output_dir, exist_ok=True) + + # Save vocabulary + vocab_path = os.path.join(output_dir, "species_vocab.json") + with open(vocab_path, 'w') as f: + json.dump(species_vocab, f, indent=2) + + # All embeddings should have the same dimension now + num_species = len(species_embeddings) + embed_dim = next(iter(species_embeddings.values())).shape[0] # Should be 1024 + + # Create memmap for fixed-size embeddings + emb_path = os.path.join(output_dir, "species_embeddings.bin") + mmap = np.memmap(emb_path, dtype=np.float32, mode='w+', shape=(num_species, embed_dim)) + + # Store embeddings directly by ID + for species_id, emb in species_embeddings.items(): + mmap[species_id] = emb.astype(np.float32) + + # Flush to disk + del mmap + + # Save metadata + metadata = { + "num_species": num_species, + "embedding_dim": embed_dim, + "embedding_type": "fixed_size", + "pooling_method": "last_token", + "normalization": "L2", + "model": "Qwen/Qwen3-Embedding-0.6B" + } + + metadata_path = os.path.join(output_dir, "species_metadata.json") + with open(metadata_path, 'w') as f: + json.dump(metadata, f, indent=2) + + logger.info(f"Saved {num_species} fixed-size species embeddings to {emb_path}") + logger.info(f"Embedding dimension: {embed_dim}") + logger.info(f"Saved metadata to {metadata_path}") + + +def save_species_token_embeddings_memmap( + species_vocab: Dict[str, int], + species_tok_embeddings: Dict[int, np.ndarray], + output_dir: str, + dtype: str = 'float32' +) -> None: + """Save variable-length token embeddings into a flat memmap with index.""" + os.makedirs(output_dir, exist_ok=True) + + # Save vocabulary + vocab_path = os.path.join(output_dir, "species_vocab.json") + with open(vocab_path, 'w') as f: + json.dump(species_vocab, f, indent=2) + + # Compute totals and dims + embed_dim = next(iter(species_tok_embeddings.values())).shape[1] + total_tokens = int(sum(v.shape[0] for v in species_tok_embeddings.values())) + + emb_path = os.path.join(output_dir, "species_tok_emb.bin") + mmap = np.memmap(emb_path, dtype=np.float32 if dtype=='float32' else np.float16, mode='w+', shape=(total_tokens, embed_dim)) + + # Build index + index = {} + offset = 0 + for sid, arr in species_tok_embeddings.items(): + L = int(arr.shape[0]) + mmap[offset: offset + L] = arr.astype(np.float32 if dtype=='float32' else np.float16) + index[str(sid)] = {"offset": offset, "length": L} + offset += L + + del mmap + + with open(os.path.join(output_dir, "species_index.json"), 'w') as f: + json.dump(index, f, indent=2) + + meta = { + "embedding_dim": embed_dim, + "dtype": dtype, + "total_tokens": total_tokens, + "embedding_type": "variable_length", + "pooling_method": "none", + "model": "Qwen/Qwen3-Embedding-0.6B" + } + with open(os.path.join(output_dir, "metadata.json"), 'w') as f: + json.dump(meta, f, indent=2) + logger.info(f"Saved variable-length species token embeddings to {emb_path} with {total_tokens} tokens total") + + +def filter_sequences_by_length(df: pd.DataFrame, max_protein_length: int = 2048) -> pd.DataFrame: + """Filter sequences to prevent CUDA OOM during training.""" + initial_count = len(df) + + # Filter by protein length + if 'protein_seq' in df.columns: + df = 
df[df['protein_seq'].str.len() <= max_protein_length] + + # Filter by CDS length (3x protein length) + if 'cds_DNA' in df.columns: + max_cds_length = max_protein_length * 3 + df = df[df['cds_DNA'].str.len() <= max_cds_length] + + final_count = len(df) + if final_count < initial_count: + logger.info(f"Filtered from {initial_count} to {final_count} sequences (max_protein_length={max_protein_length})") + + return df + + +def collect_unique_values_from_shards( + shards_glob: str, + column: str, + max_items: Optional[int] = None +) -> List[str]: + """Stream over Parquet shards to collect unique values from a column.""" + unique_values = set() + shard_files = sorted(glob.glob(shards_glob)) + + if not shard_files: + raise ValueError(f"No parquet files found matching {shards_glob}") + + logger.info(f"Scanning {len(shard_files)} shards for unique {column} values...") + + for shard_file in tqdm(shard_files, desc=f"Collecting {column}"): + # Some datasets use different casing (e.g., 'taxon' vs 'Taxon'). Resolve robustly. + try: + import pyarrow.parquet as pq # type: ignore + pf = pq.ParquetFile(shard_file) + names = set(pf.schema.names) + resolved = column + if resolved not in names: + lower_map = {n.lower(): n for n in names} + resolved = lower_map.get(column.lower(), column) + except Exception: + resolved = column + + df = pd.read_parquet(shard_file, columns=[resolved]) + # Canonicalize to the requested column name for downstream logic. + if resolved != column and resolved in df.columns and column not in df.columns: + df = df.rename(columns={resolved: column}) + unique_values.update(df[column].dropna().unique()) + + if max_items and len(unique_values) >= max_items: + break + + result = sorted(list(unique_values))[:max_items] if max_items else sorted(list(unique_values)) + logger.info(f"Collected {len(result)} unique {column} values") + return result + + +def collect_stage1_species(shards_glob: str) -> List[str]: + """Extract unique species from Stage-1 shards.""" + return collect_unique_values_from_shards(shards_glob, "Taxon") + + +def prepare_species_from_stage1_shards( + shards_glob: str, + output_dir: str, + device: str = "cuda", + resume: bool = False, + species_pooling: str = "last" +) -> None: + """End-to-end species embedding generation from Stage-1 shards.""" + os.makedirs(output_dir, exist_ok=True) + + # Check for existing files + vocab_path = os.path.join(output_dir, "species_vocab.json") + if resume and os.path.exists(vocab_path): + logger.info("Species embeddings already exist. 
Skipping generation.") + return + + # Collect unique species + species_list = collect_stage1_species(shards_glob) + logger.info(f"Found {len(species_list)} unique species in shards") + + # Build taxonomy database + taxonomy_cache_path = os.path.join(output_dir, "taxonomy_database.json") + if resume and os.path.exists(taxonomy_cache_path): + logger.info("Loading cached taxonomy database...") + with open(taxonomy_cache_path, 'r') as f: + taxonomy_db = json.load(f) + else: + taxonomy_db = build_taxonomy_database(species_list) + with open(taxonomy_cache_path, 'w') as f: + json.dump(taxonomy_db, f, indent=2) + + # Generate embeddings + species_vocab, species_embeddings = generate_species_embeddings_qwen( + species_list, taxonomy_db, device, pooling=species_pooling + ) + + # Save per requested pooling + if species_pooling == 'last': + save_species_embeddings_memmap(species_vocab, species_embeddings, output_dir) + else: + save_species_token_embeddings_memmap(species_vocab, species_embeddings, output_dir) + + logger.info("Species embedding preparation complete") + + +def create_precomputed_dataset( + input_csv: Optional[str], + output_dir: str, + device: str = "cuda", + batch_size: int = 50, + max_protein_length: int = 2048, + resume: bool = False, + species_pooling: str = "last" +): + """ + Create embedding dataset with species-only precomputation. + Protein embeddings will be computed on-the-fly during training. + """ + os.makedirs(output_dir, exist_ok=True) + + # Skip if resuming and files exist + if resume and os.path.exists(os.path.join(output_dir, "species_vocab.json")): + logger.info("Precomputed dataset already exists. Use --resume=False to regenerate.") + return + + # Load data + logger.info(f"Loading data from {input_csv}...") + if input_csv.endswith('.parquet'): + df = pd.read_parquet(input_csv) + else: + df = pd.read_csv(input_csv) + + # Accept either 'Taxon' or 'taxon' as the species column. + if "Taxon" not in df.columns and "taxon" in df.columns: + df = df.rename(columns={"taxon": "Taxon"}) + + # Filter sequences by length + df = filter_sequences_by_length(df, max_protein_length) + + # === Species Embeddings === + logger.info("=== Generating Species Embeddings ===") + unique_species = df["Taxon"].dropna().unique().tolist() + logger.info(f"Found {len(unique_species)} unique species") + + # Build taxonomy database + taxonomy_db = build_taxonomy_database(unique_species) + + # Save taxonomy database + taxonomy_path = os.path.join(output_dir, "taxonomy_database.json") + with open(taxonomy_path, 'w') as f: + json.dump(taxonomy_db, f, indent=2) + + # Generate species embeddings + species_vocab, species_embeddings = generate_species_embeddings_qwen( + unique_species, taxonomy_db, device, pooling=species_pooling + ) + + if species_pooling == 'last': + save_species_embeddings_memmap(species_vocab, species_embeddings, output_dir) + else: + save_species_token_embeddings_memmap(species_vocab, species_embeddings, output_dir) + + # Save metadata + metadata = { + "num_sequences": len(df), + "num_species": len(unique_species), + "species_embedding_model": "Qwen/Qwen3-Embedding-0.6B", + "species_embedding_dim": 1024, # Qwen3 dimension + "max_protein_length": max_protein_length, + } + + with open(os.path.join(output_dir, "metadata.json"), 'w') as f: + json.dump(metadata, f, indent=2) + + logger.info(f"Dataset creation completed. 
Species embeddings are precomputed.") + logger.info("Protein embeddings will be computed on-the-fly during training using integrated ESM-C.") + + +def main(): + parser = argparse.ArgumentParser(description="Precompute species embeddings for CodonTranslator") + + # Data source options + parser.add_argument("--input_csv", type=str, + help="Path to input CSV/Parquet file") + parser.add_argument("--from_stage1_shards", action="store_true", + help="Generate from Stage-1 Parquet shards instead of CSV") + parser.add_argument("--stage1_shards_glob", type=str, default="./data/shards/*.parquet", + help="Glob pattern for Stage-1 shards") + + # Output + parser.add_argument("--output_dir", type=str, required=True, + help="Output directory for precomputed embeddings") + + # Processing options + parser.add_argument("--device", type=str, default="cuda", + help="Device for model inference") + parser.add_argument("--batch_size", type=int, default=50, + help="Batch size for embedding generation") + parser.add_argument("--max_protein_length", type=int, default=2048, + help="Maximum protein sequence length") + parser.add_argument("--resume", action="store_true", + help="Resume from checkpoint if available") + parser.add_argument("--species_pooling", type=str, choices=["last", "sequence", "none"], default="last", + help="'last' for single-token; 'sequence' for variable-length token embeddings") + + args = parser.parse_args() + + # Route to appropriate function + if args.from_stage1_shards: + prepare_species_from_stage1_shards( + args.stage1_shards_glob, + args.output_dir, + args.device, + args.resume, + args.species_pooling + ) + elif args.input_csv: + create_precomputed_dataset( + args.input_csv, + args.output_dir, + args.device, + args.batch_size, + args.max_protein_length, + args.resume, + args.species_pooling + ) + else: + raise ValueError("Must specify either --input_csv or --from_stage1_shards") + + logger.info("Precomputation complete!") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..f2ca62a8e7c88bea722c9904927f325e8feb5797 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,24 @@ +[build-system] +requires = ["setuptools", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "CodonTranslator" +version = "0.1.1" +description = "Sampling codon sequences conditioned on species and protein using a GPT model" +readme = "README.md" +requires-python = ">=3.9" +license = {text = "MIT"} +authors = [{name = "CodonTranslator Team"}] +dependencies = [ + "torch>=2.4", + "transformers>=4.57.0", + "esm>=3.2.3", + "safetensors>=0.7.0", + "numpy>=2.2.0", + "huggingface-hub>=0.36.0", +] + +[tool.setuptools] +package-dir = {"" = "."} +packages = ["CodonTranslator", "codontranslator"] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..7fcada2447db6f14738b39c14a8a71497e08df26 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,12 @@ +torch>=2.4 +transformers>=4.57.0 +esm>=3.2.3 +safetensors>=0.7.0 +numpy>=2.2.0 +huggingface-hub>=0.36.0 +accelerate>=1.9.0 +pyarrow>=21.0.0 +pandas>=2.3.0 +duckdb>=1.5.0 +biopython>=1.85 +wandb>=0.21.0 diff --git a/resplit_data_v3.py b/resplit_data_v3.py new file mode 100644 index 0000000000000000000000000000000000000000..260403a7be0d89c5bb0b5dc0ff9e78da34187d54 --- /dev/null +++ b/resplit_data_v3.py @@ -0,0 +1,1444 @@ +#!/usr/bin/env python3 +""" +Resplit `data_v2/` into leakage-safe `data_v3_rebuild/` using 
MMseqs2 clustering. + +Default policy for the current rebuild: + - Cluster `protein_seq` with MMseqs2 `linclust` + - Define species by normalized binomial name (`genus species`) + - Test species are exactly the normalized species present in `data_v2/test` + - Validation is cluster-unseen but species-seen + - Mixed seen/heldout clusters keep heldout rows in test and drop seen rows + +Typical usage (end-to-end): + python resplit_data_v3.py all --threads 32 --split-memory-limit 120G --num-shards 256 +""" + +from __future__ import annotations + +import argparse +import json +import os +import shutil +import stat +import subprocess +import sys +import time +from pathlib import Path +from typing import Dict, Iterable, List, Optional, Sequence, Tuple + + +def _default_mmseqs_path() -> str: + cand = Path("MMseqs2/build/bin/mmseqs") + if cand.exists(): + return str(cand) + return "mmseqs" + + +def _run(cmd: List[str], *, cwd: Optional[str] = None, env: Optional[dict] = None) -> None: + pretty = " ".join(cmd) + print(f"+ {pretty}", flush=True) + subprocess.run(cmd, cwd=cwd, env=env, check=True) + + +def _sql_escape_path(path: str) -> str: + return path.replace("'", "''") + + +def _expand_parquet_inputs(inp: str) -> List[str]: + import glob + + p = Path(inp) + if p.exists() and p.is_dir(): + files = sorted(str(x) for x in p.rglob("*.parquet")) + else: + files = sorted(glob.glob(inp)) + + seen = set() + out: List[str] = [] + for f in files: + if f not in seen: + out.append(f) + seen.add(f) + return out + + +def _duckdb_parquet_source(inp: str, limit_files: int = 0) -> str: + files = _expand_parquet_inputs(inp) + if not files: + raise SystemExit(f"No parquet files found for {inp!r}") + if limit_files and int(limit_files) > 0: + files = files[: int(limit_files)] + quoted = ", ".join(f"'{_sql_escape_path(fp)}'" for fp in files) + return f"read_parquet([{quoted}])" + + +def _mem_total_bytes() -> Optional[int]: + try: + with open("/proc/meminfo", "r", encoding="utf-8") as f: + for line in f: + if line.startswith("MemTotal:"): + parts = line.split() + kb = int(parts[1]) + return kb * 1024 + except OSError: + return None + except (ValueError, IndexError): + return None + return None + + +def _parse_mmseqs_bytes(s: str) -> Optional[int]: + s = (s or "").strip() + if not s: + return None + up = s.upper() + suffix = up[-1] + num_part = up[:-1] + unit = suffix + if suffix == "B" and len(up) >= 2 and up[-2] in "KMGT": + unit = up[-2] + num_part = up[:-2] + if unit not in "BKMGT": + return None + try: + val = float(num_part) + except ValueError: + return None + mult = {"B": 1, "K": 1024, "M": 1024**2, "G": 1024**3, "T": 1024**4}[unit] + return int(val * mult) + + +def _format_bytes(n: int) -> str: + for unit, div in [("TiB", 1024**4), ("GiB", 1024**3), ("MiB", 1024**2), ("KiB", 1024)]: + if n >= div: + return f"{n / div:.1f}{unit}" + return f"{n}B" + + +def _seq_id_sql() -> str: + # Keep the stable row identifier aligned with the existing pipeline. 
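+ # The key concatenates the coalesced protein_refseq_id and RefseqID columns with a '|'
+ # separator; downstream filters on length(seq_id) > 1 drop rows where both fields are
+ # empty (i.e. the key would be a bare '|').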
+ return "coalesce(protein_refseq_id, '') || '|' || coalesce(RefseqID, '')" + + +def _taxon_norm_sql(col: str = "taxon") -> str: + return f"regexp_replace(lower(trim(coalesce({col}, ''))), '\\\\s+', ' ', 'g')" + + +def _species_key_sql(mode: str, col: str = "taxon") -> str: + norm = _taxon_norm_sql(col) + if mode == "taxon": + return norm + if mode == "binomial": + return ( + f"CASE " + f"WHEN strpos({norm}, ' ') > 0 " + f"THEN split_part({norm}, ' ', 1) || ' ' || split_part({norm}, ' ', 2) " + f"ELSE {norm} END" + ) + raise ValueError(f"Unsupported species key mode: {mode}") + + +def _protein_norm_sql(col: str = "protein_seq") -> str: + cleaned = f"regexp_replace(upper(coalesce({col}, '')), '\\\\s+', '', 'g')" + no_stop = f"regexp_replace({cleaned}, '[_*]+$', '')" + return f"regexp_replace({no_stop}, '[^A-Z]', 'X', 'g')" + + +def _cds_norm_sql(col: str = "cds_DNA") -> str: + cleaned = f"regexp_replace(upper(coalesce({col}, '')), '\\\\s+', '', 'g')" + return f"regexp_replace({cleaned}, '[^ACGTN]', 'N', 'g')" + + +def _seq_expr_sql(seq_space: str) -> str: + if seq_space == "protein": + return _protein_norm_sql("protein_seq") + if seq_space == "cds": + return _cds_norm_sql("cds_DNA") + raise ValueError(f"Unsupported seq space: {seq_space}") + + +def _seq_space_input_col(seq_space: str) -> str: + if seq_space == "protein": + return "protein_seq" + if seq_space == "cds": + return "cds_DNA" + raise ValueError(f"Unsupported seq space: {seq_space}") + + +def _mmseqs_dbtype(seq_space: str) -> str: + if seq_space == "protein": + return "1" + if seq_space == "cds": + return "2" + raise ValueError(f"Unsupported seq space: {seq_space}") + + +def _default_max_input_seq_len(seq_space: str) -> int: + if seq_space == "protein": + # MMseqs linclust hit an internal SW bug on a tiny tail of ultra-long proteins + # (~39k aa+). Filtering this tail removes <0.01% of rows and keeps the run stable. 
+ return 20_000 + return 0 + + +def _ensure_mmseqs_ready(mmseqs: str) -> Tuple[str, Dict[str, str]]: + path = Path(mmseqs) + env = os.environ.copy() + + if path.exists(): + mode = path.stat().st_mode + if not (mode & stat.S_IXUSR): + path.chmod(mode | stat.S_IXUSR) + + py = Path(sys.executable).resolve() + env_root = py.parent.parent + conda_root = env_root.parent.parent if env_root.parent.name == "envs" else env_root.parent + lib_candidates = [env_root / "lib", conda_root / "lib"] + libs = [str(p) for p in lib_candidates if p.exists()] + if libs: + current = env.get("LD_LIBRARY_PATH", "") + env["LD_LIBRARY_PATH"] = ":".join(libs + ([current] if current else [])) + + return str(path if path.exists() else mmseqs), env + + +def _ensure_output_parent(path: Path) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + + +def cmd_make_fasta(args: argparse.Namespace) -> None: + out_fasta = Path(args.output_fasta) + _ensure_output_parent(out_fasta) + + import duckdb + + con = duckdb.connect() + con.execute(f"PRAGMA threads={int(args.threads)};") + con.execute("PRAGMA enable_progress_bar=true;") + + source_sql = _duckdb_parquet_source(args.input_glob, int(args.limit_files)) + out_path = _sql_escape_path(str(out_fasta)) + seq_id = _seq_id_sql() + seq_expr = _seq_expr_sql(args.seq_space) + raw_col = _seq_space_input_col(args.seq_space) + max_input_seq_len = int(args.max_input_seq_len) + if max_input_seq_len <= 0: + max_input_seq_len = _default_max_input_seq_len(args.seq_space) + len_filter = ( + f"AND length({seq_expr}) <= {max_input_seq_len}" + if max_input_seq_len > 0 + else "" + ) + + sql = f""" + COPY ( + SELECT + '>' || ({seq_id}) AS header, + {seq_expr} AS seq + FROM {source_sql} + WHERE {raw_col} IS NOT NULL + AND length({seq_expr}) > 0 + {len_filter} + AND length(({seq_id})) > 1 + {f"LIMIT {int(args.limit_rows)}" if args.limit_rows and int(args.limit_rows) > 0 else ""} + ) + TO '{out_path}' + (FORMAT CSV, DELIMITER '\n', QUOTE '', ESCAPE '', HEADER FALSE); + """ + t0 = time.time() + con.execute(sql) + print( + f"Wrote FASTA: {out_fasta} seq_space={args.seq_space} " + f"max_input_seq_len={max_input_seq_len if max_input_seq_len > 0 else 'none'} " + f"(elapsed_s={time.time() - t0:.1f})" + ) + + +def cmd_mmseqs_cluster(args: argparse.Namespace) -> None: + mmseqs, env = _ensure_mmseqs_ready(args.mmseqs) + workdir = Path(args.workdir) + workdir.mkdir(parents=True, exist_ok=True) + + fasta = Path(args.fasta) + if not fasta.exists(): + raise SystemExit(f"FASTA not found: {fasta}") + + seqdb = workdir / "seqdb" + clu = workdir / "clu" + tmp = workdir / "tmp" + tsv = workdir / "clu.tsv" + + if args.overwrite: + for p in (seqdb, clu, tmp, tsv): + if p.is_dir(): + shutil.rmtree(p, ignore_errors=True) + else: + for suffix in ("", ".dbtype", ".index", ".lookup", ".source"): + try: + os.remove(str(p) + suffix) + except OSError: + pass + + tmp.mkdir(parents=True, exist_ok=True) + + _run( + [ + mmseqs, + "createdb", + str(fasta), + str(seqdb), + "--dbtype", + _mmseqs_dbtype(args.seq_space), + "--shuffle", + "0", + "--createdb-mode", + "1", + "--threads", + str(int(args.threads)), + ], + env=env, + ) + + linclust_cmd = [ + mmseqs, + "linclust", + str(seqdb), + str(clu), + str(tmp), + "--min-seq-id", + str(float(args.min_seq_id)), + "-c", + str(float(args.coverage)), + "--cov-mode", + str(int(args.cov_mode)), + "--cluster-mode", + str(int(args.cluster_mode)), + "--threads", + str(int(args.threads)), + "--max-seq-len", + str(int(args.max_seq_len)), + "--remove-tmp-files", + "1" if args.remove_tmp_files else 
"0", + ] + if args.split_memory_limit: + mem_total = _mem_total_bytes() + limit_bytes = _parse_mmseqs_bytes(args.split_memory_limit) + if mem_total and limit_bytes and limit_bytes > mem_total: + print( + f"WARNING: --split-memory-limit={args.split_memory_limit} ({_format_bytes(limit_bytes)}) " + f"exceeds system MemTotal ({_format_bytes(mem_total)}). " + "MMseqs2 may under-split and crash; consider lowering it or leaving it empty.", + file=sys.stderr, + flush=True, + ) + linclust_cmd += ["--split-memory-limit", str(args.split_memory_limit)] + if args.kmer_per_seq_scale is not None: + linclust_cmd += ["--kmer-per-seq-scale", str(float(args.kmer_per_seq_scale))] + + _run(linclust_cmd, env=env) + _run([mmseqs, "createtsv", str(seqdb), str(seqdb), str(clu), str(tsv)], env=env) + print(f"Wrote cluster TSV: {tsv}") + + +def cmd_make_seq_cluster(args: argparse.Namespace) -> None: + import duckdb + + tsv = Path(args.cluster_tsv) + if not tsv.exists(): + raise SystemExit(f"Cluster TSV not found: {tsv}") + out = Path(args.output_parquet) + _ensure_output_parent(out) + + con = duckdb.connect() + con.execute(f"PRAGMA threads={int(args.threads)};") + con.execute("PRAGMA enable_progress_bar=true;") + + tsv_path = _sql_escape_path(str(tsv)) + out_path = _sql_escape_path(str(out)) + + sql = f""" + COPY ( + SELECT DISTINCT + seq_id, + cluster_id + FROM read_csv( + '{tsv_path}', + delim='\\t', + header=false, + columns={{'cluster_id':'VARCHAR','seq_id':'VARCHAR'}} + ) + ) + TO '{out_path}' + (FORMAT PARQUET); + """ + t0 = time.time() + con.execute(sql) + print(f"Wrote seq→cluster parquet: {out} (elapsed_s={time.time() - t0:.1f})") + + +def _write_cluster_split_parquet( + con, + *, + cluster_split_path: Path, + seed: int, + val_frac: float, +) -> Dict[str, int]: + import pyarrow as pa + import pyarrow.parquet as pq + + cluster_split_path.parent.mkdir(parents=True, exist_ok=True) + if cluster_split_path.exists(): + cluster_split_path.unlink() + + total_seen_rows = int( + con.execute( + "SELECT coalesce(sum(n_total), 0)::BIGINT FROM cluster_flags WHERE n_test = 0" + ).fetchone()[0] + ) + target_val_rows = int(total_seen_rows * float(val_frac)) + + species_remaining = { + species_key: int(n_clusters) + for species_key, n_clusters in con.execute( + """ + SELECT + cc.species_key, + count(*)::BIGINT AS n_clusters + FROM cluster_counts cc + JOIN cluster_flags cf USING (cluster_id) + WHERE cf.n_test = 0 + GROUP BY cc.species_key + """ + ).fetchall() + } + + cur = con.execute( + f""" + SELECT + cf.cluster_id, + cf.n_total, + abs(hash(cf.cluster_id || ':{seed}')) AS rnd, + cc.species_key + FROM cluster_flags cf + JOIN cluster_counts cc USING (cluster_id) + WHERE cf.n_test = 0 + ORDER BY rnd, cf.cluster_id, cc.species_key + """ + ) + + writer = None + batch_cluster_ids: List[str] = [] + batch_splits: List[str] = [] + val_rows = 0 + train_clusters = 0 + val_clusters = 0 + current_cluster: Optional[str] = None + current_n_total = 0 + current_species: List[str] = [] + + def flush_current() -> None: + nonlocal writer, val_rows, train_clusters, val_clusters + nonlocal current_cluster, current_n_total, current_species + if current_cluster is None: + return + can_val = ( + val_rows < target_val_rows + and all(species_remaining.get(species_key, 0) > 1 for species_key in current_species) + ) + split = "val" if can_val else "train" + if can_val: + for species_key in current_species: + species_remaining[species_key] -= 1 + val_rows += int(current_n_total) + val_clusters += 1 + else: + train_clusters += 1 + + 
batch_cluster_ids.append(current_cluster) + batch_splits.append(split) + if len(batch_cluster_ids) >= 200_000: + table = pa.table({"cluster_id": batch_cluster_ids, "split": batch_splits}) + if writer is None: + writer = pq.ParquetWriter(str(cluster_split_path), table.schema) + writer.write_table(table) + batch_cluster_ids.clear() + batch_splits.clear() + + while True: + rows = cur.fetchmany(200_000) + if not rows: + break + for cluster_id, n_total, _rnd, species_key in rows: + cluster_id = str(cluster_id) + species_key = str(species_key) + if current_cluster is None: + current_cluster = cluster_id + current_n_total = int(n_total) + current_species = [species_key] + continue + if cluster_id != current_cluster: + flush_current() + current_cluster = cluster_id + current_n_total = int(n_total) + current_species = [species_key] + continue + current_species.append(species_key) + + flush_current() + if batch_cluster_ids: + table = pa.table({"cluster_id": batch_cluster_ids, "split": batch_splits}) + if writer is None: + writer = pq.ParquetWriter(str(cluster_split_path), table.schema) + writer.write_table(table) + elif writer is None: + empty = pa.table( + { + "cluster_id": pa.array([], type=pa.string()), + "split": pa.array([], type=pa.string()), + } + ) + writer = pq.ParquetWriter(str(cluster_split_path), empty.schema) + writer.write_table(empty) + if writer is not None: + writer.close() + + return { + "nonheldout_total_rows": total_seen_rows, + "target_val_rows": target_val_rows, + "actual_val_rows": val_rows, + "train_clusters": train_clusters, + "val_clusters": val_clusters, + } + + +def cmd_make_seq_split(args: argparse.Namespace) -> None: + import duckdb + + seq_cluster = Path(args.seq_cluster_parquet) + if not seq_cluster.exists(): + raise SystemExit(f"seq_cluster parquet not found: {seq_cluster}") + + out = Path(args.output_parquet) + cluster_split = Path(args.cluster_split_parquet) + _ensure_output_parent(out) + _ensure_output_parent(cluster_split) + + con = duckdb.connect() + con.execute(f"PRAGMA threads={int(args.threads)};") + con.execute("PRAGMA enable_progress_bar=true;") + + input_sql = _duckdb_parquet_source(args.input_glob, int(args.limit_files)) + heldout_sql = _duckdb_parquet_source(args.heldout_test_glob, 0) + seq_cluster_path = _sql_escape_path(str(seq_cluster)) + out_path = _sql_escape_path(str(out)) + + seq_id = _seq_id_sql() + species_key = _species_key_sql(args.species_key_mode, "taxon") + protein_norm = _protein_norm_sql("protein_seq") + + con.execute( + f""" + CREATE TEMP TABLE heldout_species AS + SELECT DISTINCT {species_key} AS species_key + FROM {heldout_sql} + WHERE {species_key} != ''; + """ + ) + + con.execute( + f""" + CREATE TEMP TABLE cluster_counts AS + WITH base AS ( + SELECT + {seq_id} AS seq_id, + {species_key} AS species_key + FROM {input_sql} + WHERE length(({seq_id})) > 1 + AND {species_key} != '' + ) + SELECT + sc.cluster_id, + base.species_key, + count(*)::BIGINT AS n + FROM base + JOIN read_parquet('{seq_cluster_path}') sc USING (seq_id) + GROUP BY sc.cluster_id, base.species_key; + """ + ) + + con.execute( + """ + CREATE TEMP TABLE cluster_flags AS + SELECT + cluster_id, + sum(CASE WHEN species_key IN (SELECT species_key FROM heldout_species) THEN n ELSE 0 END)::BIGINT AS n_test, + sum(CASE WHEN species_key NOT IN (SELECT species_key FROM heldout_species) THEN n ELSE 0 END)::BIGINT AS n_seen, + sum(n)::BIGINT AS n_total, + count(*)::BIGINT AS n_species + FROM cluster_counts + GROUP BY cluster_id; + """ + ) + + t0 = time.time() + split_summary = 
_write_cluster_split_parquet( + con, + cluster_split_path=cluster_split, + seed=int(args.seed), + val_frac=float(args.val_frac), + ) + print( + "Cluster assignment summary: " + f"train_clusters={split_summary['train_clusters']:,} " + f"val_clusters={split_summary['val_clusters']:,} " + f"target_val_rows={split_summary['target_val_rows']:,} " + f"actual_val_rows={split_summary['actual_val_rows']:,} " + f"(elapsed_s={time.time() - t0:.1f})" + ) + + cluster_split_path = _sql_escape_path(str(cluster_split)) + con.execute( + f""" + COPY ( + WITH base AS ( + SELECT DISTINCT + {seq_id} AS seq_id, + {species_key} AS species_key, + {protein_norm} AS protein_norm + FROM {input_sql} + WHERE length(({seq_id})) > 1 + AND {species_key} != '' + ), + joined AS ( + SELECT + base.seq_id, + base.species_key, + base.protein_norm, + sc.cluster_id + FROM base + LEFT JOIN read_parquet('{seq_cluster_path}') sc USING (seq_id) + ), + labeled AS ( + SELECT + j.seq_id, + j.species_key, + j.protein_norm, + CASE + WHEN j.cluster_id IS NULL THEN 'drop' + WHEN j.species_key IN (SELECT species_key FROM heldout_species) THEN 'test' + WHEN coalesce(cf.n_test, 0) > 0 THEN 'drop' + ELSE coalesce(cs.split, 'drop') + END AS split + FROM joined j + LEFT JOIN cluster_flags cf USING (cluster_id) + LEFT JOIN read_parquet('{cluster_split_path}') cs USING (cluster_id) + ), + protein_flags AS ( + SELECT + protein_norm, + max(CASE WHEN split = 'test' THEN 1 ELSE 0 END) AS has_test, + max(CASE WHEN split = 'train' THEN 1 ELSE 0 END) AS has_train + FROM labeled + WHERE length(protein_norm) > 0 + GROUP BY protein_norm + ), + guarded AS ( + SELECT + l.seq_id, + l.species_key, + CASE + WHEN l.split = 'drop' THEN 'drop' + WHEN length(l.protein_norm) = 0 THEN l.split + WHEN coalesce(pf.has_test, 0) = 1 AND l.split IN ('train', 'val') THEN 'drop' + WHEN coalesce(pf.has_train, 0) = 1 AND l.split = 'val' THEN 'drop' + ELSE l.split + END AS split + FROM labeled l + LEFT JOIN protein_flags pf USING (protein_norm) + ), + dedup AS ( + SELECT + seq_id, + CASE + WHEN count(DISTINCT species_key) > 1 THEN 'drop' + WHEN count(DISTINCT split) > 1 THEN 'drop' + ELSE any_value(split) + END AS split + FROM guarded + GROUP BY seq_id + ) + SELECT seq_id, split FROM dedup + ) + TO '{out_path}' + (FORMAT PARQUET); + """ + ) + + rows = con.execute( + f""" + WITH base AS ( + SELECT {seq_id} AS seq_id + FROM {input_sql} + ) + SELECT s.split, count(*)::BIGINT AS n_rows + FROM base + JOIN read_parquet('{out_path}') s USING (seq_id) + GROUP BY s.split + ORDER BY n_rows DESC; + """ + ).fetchall() + print("Split summary (rows):") + for split, n in rows: + print(f" {split}\t{n:,}") + + print(f"Wrote cluster→split parquet: {cluster_split}") + print(f"Wrote seq→split parquet: {out}") + + +def cmd_write_data_v3(args: argparse.Namespace) -> None: + import duckdb + + seq_split = Path(args.seq_split_parquet) + if not seq_split.exists(): + raise SystemExit(f"seq_split parquet not found: {seq_split}") + seq_cluster = Path(args.seq_cluster_parquet) + if args.representatives_only and not seq_cluster.exists(): + raise SystemExit(f"seq_cluster parquet not found: {seq_cluster}") + + out_root = Path(args.output_root) + out_root.mkdir(parents=True, exist_ok=True) + (out_root / "_work").mkdir(parents=True, exist_ok=True) + + for split_dir in (out_root / "train", out_root / "val", out_root / "test"): + if split_dir.exists(): + if not args.overwrite: + raise SystemExit(f"Output split directory exists: {split_dir} (pass --overwrite)") + shutil.rmtree(split_dir) + 
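+ # Recreate each split directory from scratch so stale shards from an earlier run cannot
+ # mix with the new PARTITION_BY(shard) parquet output.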
split_dir.mkdir(parents=True, exist_ok=True) + + con = duckdb.connect() + con.execute(f"PRAGMA threads={int(args.threads)};") + con.execute("PRAGMA enable_progress_bar=true;") + + input_sql = _duckdb_parquet_source(args.input_glob, int(args.limit_files)) + seq_split_path = _sql_escape_path(str(seq_split)) + seq_cluster_path = _sql_escape_path(str(seq_cluster)) + seq_id = _seq_id_sql() + + num_shards = int(args.num_shards) + if num_shards <= 0: + raise SystemExit("--num-shards must be > 0") + + for split in ("train", "val", "test"): + out_dir = _sql_escape_path(str(out_root / split)) + if args.representatives_only: + target_seq_ids_sql = f""" + SELECT min(s.seq_id) AS seq_id + FROM read_parquet('{seq_split_path}') s + JOIN read_parquet('{seq_cluster_path}') sc USING (seq_id) + WHERE s.split = '{split}' + GROUP BY sc.cluster_id + """ + else: + target_seq_ids_sql = f""" + SELECT DISTINCT s.seq_id + FROM read_parquet('{seq_split_path}') s + WHERE s.split = '{split}' + """ + sql = f""" + COPY ( + WITH target_seq_ids AS ( + {target_seq_ids_sql} + ), + rows AS ( + SELECT + p.*, + abs(hash({seq_id})) % {num_shards} AS shard + FROM {input_sql} p + JOIN target_seq_ids t + ON t.seq_id = ({seq_id}) + QUALIFY row_number() OVER (PARTITION BY ({seq_id}) ORDER BY ({seq_id})) = 1 + ) + SELECT * FROM rows + ) + TO '{out_dir}' + (FORMAT PARQUET, PARTITION_BY (shard)); + """ + t0 = time.time() + con.execute(sql) + print( + f"Wrote {split} parquets to {out_root / split} " + f"representatives_only={bool(args.representatives_only)} " + f"(elapsed_s={time.time() - t0:.1f})" + ) + + +def cmd_verify(args: argparse.Namespace) -> None: + import duckdb + + seq_cluster = Path(args.seq_cluster_parquet) + seq_split = Path(args.seq_split_parquet) + if not seq_cluster.exists(): + raise SystemExit(f"seq_cluster parquet not found: {seq_cluster}") + if not seq_split.exists(): + raise SystemExit(f"seq_split parquet not found: {seq_split}") + + con = duckdb.connect() + con.execute(f"PRAGMA threads={int(args.threads)};") + con.execute("PRAGMA enable_progress_bar=true;") + + input_sql = _duckdb_parquet_source(args.input_glob, int(args.limit_files)) + heldout_sql = _duckdb_parquet_source(args.heldout_test_glob, 0) + seq_cluster_path = _sql_escape_path(str(seq_cluster)) + seq_split_path = _sql_escape_path(str(seq_split)) + + seq_id = _seq_id_sql() + species_key = _species_key_sql(args.species_key_mode, "taxon") + protein_norm = _protein_norm_sql("protein_seq") + + con.execute( + f""" + CREATE TEMP TABLE heldout_species AS + SELECT DISTINCT {species_key} AS species_key + FROM {heldout_sql} + WHERE {species_key} != ''; + """ + ) + con.execute( + f""" + CREATE TEMP TABLE cluster_counts AS + WITH base AS ( + SELECT + {seq_id} AS seq_id, + {species_key} AS species_key + FROM {input_sql} + WHERE length(({seq_id})) > 1 + AND {species_key} != '' + ) + SELECT + sc.cluster_id, + base.species_key, + count(*)::BIGINT AS n + FROM base + JOIN read_parquet('{seq_cluster_path}') sc USING (seq_id) + GROUP BY sc.cluster_id, base.species_key; + """ + ) + con.execute( + """ + CREATE TEMP TABLE cluster_flags AS + SELECT + cluster_id, + sum(CASE WHEN species_key IN (SELECT species_key FROM heldout_species) THEN n ELSE 0 END)::BIGINT AS n_test, + sum(CASE WHEN species_key NOT IN (SELECT species_key FROM heldout_species) THEN n ELSE 0 END)::BIGINT AS n_seen, + sum(n)::BIGINT AS n_total, + count(*)::BIGINT AS n_species + FROM cluster_counts + GROUP BY cluster_id; + """ + ) + + split_seq_ids = { + split: int(n) + for split, n in con.execute( + f""" + SELECT 
split, count(*)::BIGINT AS n + FROM read_parquet('{seq_split_path}') + GROUP BY split + """ + ).fetchall() + } + split_rows = { + split: int(n) + for split, n in con.execute( + f""" + WITH base AS ( + SELECT {seq_id} AS seq_id FROM {input_sql} + ) + SELECT s.split, count(*)::BIGINT AS n + FROM base + JOIN read_parquet('{seq_split_path}') s USING (seq_id) + GROUP BY s.split + """ + ).fetchall() + } + + bad_clusters = int( + con.execute( + f""" + WITH keep AS ( + SELECT sc.cluster_id, ss.split + FROM read_parquet('{seq_cluster_path}') sc + JOIN read_parquet('{seq_split_path}') ss USING (seq_id) + WHERE ss.split != 'drop' + ) + SELECT count(*)::BIGINT + FROM ( + SELECT cluster_id + FROM keep + GROUP BY cluster_id + HAVING count(DISTINCT split) > 1 + ); + """ + ).fetchone()[0] + ) + print(f"clusters_spanning_splits(excluding drop) = {bad_clusters}") + + bad_test = int( + con.execute( + f""" + WITH base AS ( + SELECT {seq_id} AS seq_id, {species_key} AS species_key + FROM {input_sql} + ) + SELECT count(*)::BIGINT + FROM base + JOIN read_parquet('{seq_split_path}') ss USING (seq_id) + WHERE ss.split = 'test' + AND base.species_key NOT IN (SELECT species_key FROM heldout_species); + """ + ).fetchone()[0] + ) + print(f"test_rows_with_seen_species = {bad_test}") + + bad_val_species = int( + con.execute( + f""" + WITH base AS ( + SELECT DISTINCT {seq_id} AS seq_id, {species_key} AS species_key + FROM {input_sql} + ), + labeled AS ( + SELECT base.species_key, ss.split + FROM base + JOIN read_parquet('{seq_split_path}') ss USING (seq_id) + WHERE ss.split IN ('train', 'val') + ), + train_species AS (SELECT DISTINCT species_key FROM labeled WHERE split = 'train'), + val_species AS (SELECT DISTINCT species_key FROM labeled WHERE split = 'val') + SELECT count(*)::BIGINT + FROM (SELECT species_key FROM val_species EXCEPT SELECT species_key FROM train_species); + """ + ).fetchone()[0] + ) + print(f"val_species_not_in_train = {bad_val_species}") + + protein_overlap_train_val = int( + con.execute( + f""" + WITH base AS ( + SELECT DISTINCT {seq_id} AS seq_id, {protein_norm} AS protein_norm + FROM {input_sql} + WHERE length({protein_norm}) > 0 + ), + labeled AS ( + SELECT base.protein_norm, ss.split + FROM base + JOIN read_parquet('{seq_split_path}') ss USING (seq_id) + WHERE ss.split IN ('train', 'val') + ), + train_p AS (SELECT DISTINCT protein_norm FROM labeled WHERE split = 'train'), + val_p AS (SELECT DISTINCT protein_norm FROM labeled WHERE split = 'val') + SELECT count(*)::BIGINT + FROM (SELECT protein_norm FROM train_p INTERSECT SELECT protein_norm FROM val_p); + """ + ).fetchone()[0] + ) + protein_overlap_train_test = int( + con.execute( + f""" + WITH base AS ( + SELECT DISTINCT {seq_id} AS seq_id, {protein_norm} AS protein_norm + FROM {input_sql} + WHERE length({protein_norm}) > 0 + ), + labeled AS ( + SELECT base.protein_norm, ss.split + FROM base + JOIN read_parquet('{seq_split_path}') ss USING (seq_id) + WHERE ss.split IN ('train', 'test') + ), + train_p AS (SELECT DISTINCT protein_norm FROM labeled WHERE split = 'train'), + test_p AS (SELECT DISTINCT protein_norm FROM labeled WHERE split = 'test') + SELECT count(*)::BIGINT + FROM (SELECT protein_norm FROM train_p INTERSECT SELECT protein_norm FROM test_p); + """ + ).fetchone()[0] + ) + print(f"exact_protein_overlap_train_val = {protein_overlap_train_val}") + print(f"exact_protein_overlap_train_test = {protein_overlap_train_test}") + + mixed_test_clusters = int( + con.execute( + "SELECT count(*)::BIGINT FROM cluster_flags WHERE n_test > 0 AND 
n_seen > 0" + ).fetchone()[0] + ) + exact_holdout_seen_conflicts = int( + con.execute( + f""" + WITH base AS ( + SELECT DISTINCT + {protein_norm} AS protein_norm, + {species_key} AS species_key + FROM {input_sql} + WHERE length({protein_norm}) > 0 + AND {species_key} != '' + ) + SELECT count(*)::BIGINT + FROM ( + SELECT protein_norm + FROM base + GROUP BY protein_norm + HAVING count(DISTINCT CASE WHEN species_key IN (SELECT species_key FROM heldout_species) THEN species_key END) > 0 + AND count(DISTINCT CASE WHEN species_key NOT IN (SELECT species_key FROM heldout_species) THEN species_key END) > 0 + ); + """ + ).fetchone()[0] + ) + dropped_seen_rows_exact_holdout = int( + con.execute( + f""" + WITH base AS ( + SELECT DISTINCT + {seq_id} AS seq_id, + {species_key} AS species_key, + {protein_norm} AS protein_norm + FROM {input_sql} + WHERE length(({seq_id})) > 1 + AND {species_key} != '' + ), + conflict_proteins AS ( + SELECT protein_norm + FROM base + WHERE length(protein_norm) > 0 + GROUP BY protein_norm + HAVING count(DISTINCT CASE WHEN species_key IN (SELECT species_key FROM heldout_species) THEN species_key END) > 0 + AND count(DISTINCT CASE WHEN species_key NOT IN (SELECT species_key FROM heldout_species) THEN species_key END) > 0 + ) + SELECT count(*)::BIGINT + FROM base + JOIN conflict_proteins USING (protein_norm) + JOIN read_parquet('{seq_split_path}') ss USING (seq_id) + WHERE ss.split = 'drop' + AND base.species_key NOT IN (SELECT species_key FROM heldout_species); + """ + ).fetchone()[0] + ) + dropped_val_rows_exact_train = int( + con.execute( + f""" + WITH base AS ( + SELECT DISTINCT + {seq_id} AS seq_id, + {protein_norm} AS protein_norm + FROM {input_sql} + WHERE length(({seq_id})) > 1 + AND length({protein_norm}) > 0 + ), + labeled AS ( + SELECT base.protein_norm, ss.split + FROM base + JOIN read_parquet('{seq_split_path}') ss USING (seq_id) + ), + train_proteins AS ( + SELECT DISTINCT protein_norm + FROM labeled + WHERE split = 'train' + ) + SELECT count(*)::BIGINT + FROM base + JOIN train_proteins USING (protein_norm) + JOIN read_parquet('{seq_split_path}') ss USING (seq_id) + WHERE ss.split = 'drop'; + """ + ).fetchone()[0] + ) + dropped_seen_rows_mixed = int( + con.execute( + f""" + WITH base AS ( + SELECT {seq_id} AS seq_id, {species_key} AS species_key + FROM {input_sql} + ) + SELECT count(*)::BIGINT + FROM base + JOIN read_parquet('{seq_cluster_path}') sc USING (seq_id) + JOIN cluster_flags cf USING (cluster_id) + JOIN read_parquet('{seq_split_path}') ss USING (seq_id) + WHERE ss.split = 'drop' + AND cf.n_test > 0 + AND base.species_key NOT IN (SELECT species_key FROM heldout_species); + """ + ).fetchone()[0] + ) + dropped_seen_seqids_mixed = int( + con.execute( + f""" + WITH base AS ( + SELECT DISTINCT {seq_id} AS seq_id, {species_key} AS species_key + FROM {input_sql} + ) + SELECT count(*)::BIGINT + FROM base + JOIN read_parquet('{seq_cluster_path}') sc USING (seq_id) + JOIN cluster_flags cf USING (cluster_id) + JOIN read_parquet('{seq_split_path}') ss USING (seq_id) + WHERE ss.split = 'drop' + AND cf.n_test > 0 + AND base.species_key NOT IN (SELECT species_key FROM heldout_species); + """ + ).fetchone()[0] + ) + same_protein_multi_species = int( + con.execute( + f""" + WITH base AS ( + SELECT DISTINCT + {protein_norm} AS protein_norm, + {species_key} AS species_key + FROM {input_sql} + WHERE length({protein_norm}) > 0 + AND {species_key} != '' + ) + SELECT count(*)::BIGINT + FROM ( + SELECT protein_norm + FROM base + GROUP BY protein_norm + HAVING count(DISTINCT 
species_key) > 1 + ); + """ + ).fetchone()[0] + ) + same_protein_cross_holdout = int( + con.execute( + f""" + WITH base AS ( + SELECT DISTINCT + {protein_norm} AS protein_norm, + {species_key} AS species_key + FROM {input_sql} + WHERE length({protein_norm}) > 0 + AND {species_key} != '' + ) + SELECT count(*)::BIGINT + FROM ( + SELECT protein_norm + FROM base + GROUP BY protein_norm + HAVING count(DISTINCT CASE WHEN species_key IN (SELECT species_key FROM heldout_species) THEN species_key END) > 0 + AND count(DISTINCT CASE WHEN species_key NOT IN (SELECT species_key FROM heldout_species) THEN species_key END) > 0 + ); + """ + ).fetchone()[0] + ) + + report = { + "parameters": { + "input_glob": args.input_glob, + "heldout_test_glob": args.heldout_test_glob, + "seq_cluster_parquet": str(seq_cluster), + "seq_split_parquet": str(seq_split), + "seq_space": args.seq_space, + "species_key_mode": args.species_key_mode, + "limit_files": int(args.limit_files), + }, + "split_seq_ids": split_seq_ids, + "split_rows": split_rows, + "verification": { + "clusters_spanning_splits_excluding_drop": bad_clusters, + "test_rows_with_seen_species": bad_test, + "val_species_not_in_train": bad_val_species, + "exact_protein_overlap_train_val": protein_overlap_train_val, + "exact_protein_overlap_train_test": protein_overlap_train_test, + }, + "audit": { + "mixed_test_clusters": mixed_test_clusters, + "exact_protein_cross_holdout_seen_groups": exact_holdout_seen_conflicts, + "dropped_seen_rows_from_exact_protein_holdout_overlap": dropped_seen_rows_exact_holdout, + "dropped_rows_from_exact_protein_train_overlap": dropped_val_rows_exact_train, + "dropped_seen_rows_from_mixed_test_clusters": dropped_seen_rows_mixed, + "dropped_seen_seqids_from_mixed_test_clusters": dropped_seen_seqids_mixed, + "same_protein_multi_species_exact_matches": same_protein_multi_species, + "same_protein_cross_holdout_species_exact_matches": same_protein_cross_holdout, + }, + } + + if args.report_json: + report_path = Path(args.report_json) + report_path.parent.mkdir(parents=True, exist_ok=True) + with open(report_path, "w", encoding="utf-8") as f: + json.dump(report, f, indent=2, sort_keys=True) + print(f"Wrote audit report: {report_path}") + + if ( + bad_clusters != 0 + or bad_test != 0 + or bad_val_species != 0 + or protein_overlap_train_val != 0 + or protein_overlap_train_test != 0 + ): + raise SystemExit("Verification FAILED (see counts above).") + print("Verification OK.") + + +def build_parser() -> argparse.ArgumentParser: + ap = argparse.ArgumentParser( + description="Resplit data_v2 to data_v3_rebuild using MMseqs2 protein clustering." 
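+ # Each pipeline stage is exposed as its own subcommand so a stage can be rerun on its
+ # own; the `all` subcommand chains them end-to-end.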
+ ) + sub = ap.add_subparsers(dest="cmd", required=True) + + p = sub.add_parser("make-fasta", help="Generate MMseqs FASTA from parquet shards.") + p.add_argument("--input-glob", type=str, default="data_v2/*/*.parquet") + p.add_argument("--output-fasta", type=str, default="data_v3_rebuild/_work/mmseqs_input.fasta") + p.add_argument("--seq-space", type=str, choices=["protein", "cds"], default="protein") + p.add_argument( + "--max-input-seq-len", + type=int, + default=0, + help="Drop sequences longer than this from the MMseqs input FASTA (0=use seq-space default).", + ) + p.add_argument("--threads", type=int, default=32) + p.add_argument("--limit-files", type=int, default=0, help="Only read the first N parquet files (0=all)") + p.add_argument("--limit-rows", type=int, default=0, help="Debug: limit number of rows written (0=all)") + p.set_defaults(func=cmd_make_fasta) + + p = sub.add_parser("mmseqs-cluster", help="Run MMseqs2 createdb+linclust and emit clustering TSV.") + p.add_argument("--mmseqs", type=str, default=_default_mmseqs_path()) + p.add_argument("--fasta", type=str, default="data_v3_rebuild/_work/mmseqs_input.fasta") + p.add_argument("--workdir", type=str, default="data_v3_rebuild/_work/mmseqs") + p.add_argument("--seq-space", type=str, choices=["protein", "cds"], default="protein") + p.add_argument("--threads", type=int, default=32) + p.add_argument("--min-seq-id", type=float, default=0.90) + p.add_argument("-c", "--coverage", type=float, default=0.80) + p.add_argument("--cov-mode", type=int, default=2, help="2=enforce representative/query coverage") + p.add_argument("--cluster-mode", type=int, default=2, help="2=greedy clustering by sequence length") + p.add_argument("--max-seq-len", type=int, default=200000) + p.add_argument( + "--kmer-per-seq-scale", + type=float, + default=None, + help="Optional MMseqs2 override; leave empty to use MMseqs defaults.", + ) + p.add_argument("--split-memory-limit", type=str, default="", help="e.g. 
120G (empty=use MMseqs default)") + g = p.add_mutually_exclusive_group() + g.add_argument( + "--remove-tmp-files", + dest="remove_tmp_files", + action="store_true", + default=True, + help="Remove MMseqs2 tmp files (default).", + ) + g.add_argument( + "--keep-tmp-files", + dest="remove_tmp_files", + action="store_false", + help="Keep MMseqs2 tmp files.", + ) + p.add_argument("--overwrite", action="store_true") + p.set_defaults(func=cmd_mmseqs_cluster) + + p = sub.add_parser("make-seq-cluster", help="Convert MMseqs TSV to parquet mapping seq_id→cluster_id.") + p.add_argument("--cluster-tsv", type=str, default="data_v3_rebuild/_work/mmseqs/clu.tsv") + p.add_argument("--output-parquet", type=str, default="data_v3_rebuild/_work/seq_cluster.parquet") + p.add_argument("--threads", type=int, default=32) + p.set_defaults(func=cmd_make_seq_cluster) + + p = sub.add_parser( + "make-seq-split", + help="Create seq_id→{train,val,test,drop} using cluster assignments and heldout species.", + ) + p.add_argument("--input-glob", type=str, default="data_v2/*/*.parquet") + p.add_argument("--heldout-test-glob", type=str, default="data_v2/test/*.parquet") + p.add_argument("--species-key-mode", type=str, choices=["binomial", "taxon"], default="binomial") + p.add_argument("--seq-cluster-parquet", type=str, default="data_v3_rebuild/_work/seq_cluster.parquet") + p.add_argument("--cluster-split-parquet", type=str, default="data_v3_rebuild/_work/cluster_split.parquet") + p.add_argument("--output-parquet", type=str, default="data_v3_rebuild/_work/seq_split.parquet") + p.add_argument("--val-frac", type=float, default=0.01) + p.add_argument("--seed", type=int, default=13) + p.add_argument("--threads", type=int, default=32) + p.add_argument("--limit-files", type=int, default=0, help="Only read the first N parquet files (0=all)") + p.set_defaults(func=cmd_make_seq_split) + + p = sub.add_parser("write-data-v3", help="Write data_v3 parquet directories from seq_split mapping.") + p.add_argument("--input-glob", type=str, default="data_v2/*/*.parquet") + p.add_argument("--seq-cluster-parquet", type=str, default="data_v3_rebuild/_work/seq_cluster.parquet") + p.add_argument("--seq-split-parquet", type=str, default="data_v3_rebuild/_work/seq_split.parquet") + p.add_argument("--output-root", type=str, default="data_v3_rebuild") + p.add_argument("--num-shards", type=int, default=256, help="Partition each split into N shards") + p.add_argument("--threads", type=int, default=32) + p.add_argument("--limit-files", type=int, default=0, help="Only read the first N parquet files (0=all)") + g = p.add_mutually_exclusive_group() + g.add_argument( + "--representatives-only", + dest="representatives_only", + action="store_true", + default=True, + help="Write only one representative seq_id per MMseqs cluster (default).", + ) + g.add_argument( + "--all-cluster-members", + dest="representatives_only", + action="store_false", + help="Write all seq_ids assigned to the split instead of one representative per cluster.", + ) + p.add_argument("--overwrite", action="store_true") + p.set_defaults(func=cmd_write_data_v3) + + p = sub.add_parser("verify", help="Verify leakage/species constraints and write an audit report.") + p.add_argument("--input-glob", type=str, default="data_v2/*/*.parquet") + p.add_argument("--heldout-test-glob", type=str, default="data_v2/test/*.parquet") + p.add_argument("--species-key-mode", type=str, choices=["binomial", "taxon"], default="binomial") + p.add_argument("--seq-space", type=str, choices=["protein", "cds"], 
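+ # For verify, --seq-space is only echoed into the audit report's parameters; the
+ # leakage checks themselves always compare normalized protein_seq values.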
default="protein") + p.add_argument("--seq-cluster-parquet", type=str, default="data_v3_rebuild/_work/seq_cluster.parquet") + p.add_argument("--seq-split-parquet", type=str, default="data_v3_rebuild/_work/seq_split.parquet") + p.add_argument("--report-json", type=str, default="data_v3_rebuild/_work/split_report.json") + p.add_argument("--threads", type=int, default=32) + p.add_argument("--limit-files", type=int, default=0, help="Only read the first N parquet files (0=all)") + p.set_defaults(func=cmd_verify) + + p = sub.add_parser( + "all", + help="Run the full pipeline: make-fasta → mmseqs-cluster → make-seq-cluster → make-seq-split → write-data-v3 → verify.", + ) + p.add_argument("--input-glob", type=str, default="data_v2/*/*.parquet") + p.add_argument("--heldout-test-glob", type=str, default="data_v2/test/*.parquet") + p.add_argument("--output-root", type=str, default="data_v3_rebuild") + p.add_argument("--seq-space", type=str, choices=["protein", "cds"], default="protein") + p.add_argument("--species-key-mode", type=str, choices=["binomial", "taxon"], default="binomial") + p.add_argument( + "--max-input-seq-len", + type=int, + default=0, + help="Drop sequences longer than this from the MMseqs input FASTA (0=use seq-space default).", + ) + p.add_argument("--threads", type=int, default=32) + p.add_argument("--limit-files", type=int, default=0, help="Only read the first N parquet files (0=all)") + p.add_argument("--num-shards", type=int, default=256) + g = p.add_mutually_exclusive_group() + g.add_argument( + "--representatives-only", + dest="representatives_only", + action="store_true", + default=True, + help="Write only one representative seq_id per MMseqs cluster (default).", + ) + g.add_argument( + "--all-cluster-members", + dest="representatives_only", + action="store_false", + help="Write all seq_ids assigned to the split instead of one representative per cluster.", + ) + p.add_argument("--mmseqs", type=str, default=_default_mmseqs_path()) + p.add_argument("--min-seq-id", type=float, default=0.90) + p.add_argument("-c", "--coverage", type=float, default=0.80) + p.add_argument("--cov-mode", type=int, default=2) + p.add_argument("--cluster-mode", type=int, default=2) + p.add_argument("--max-seq-len", type=int, default=200000) + p.add_argument("--kmer-per-seq-scale", type=float, default=None) + p.add_argument("--split-memory-limit", type=str, default="") + p.add_argument("--val-frac", type=float, default=0.01) + p.add_argument("--seed", type=int, default=13) + p.add_argument("--overwrite", action="store_true") + + def _run_all(a: argparse.Namespace) -> None: + out_root = Path(a.output_root) + work = out_root / "_work" + fasta = work / "mmseqs_input.fasta" + mmseqs_work = work / "mmseqs" + cluster_tsv = mmseqs_work / "clu.tsv" + seq_cluster = work / "seq_cluster.parquet" + cluster_split = work / "cluster_split.parquet" + seq_split = work / "seq_split.parquet" + report_json = work / "split_report.json" + + cmd_make_fasta( + argparse.Namespace( + input_glob=a.input_glob, + output_fasta=str(fasta), + seq_space=a.seq_space, + max_input_seq_len=a.max_input_seq_len, + threads=a.threads, + limit_files=a.limit_files, + limit_rows=0, + ) + ) + cmd_mmseqs_cluster( + argparse.Namespace( + mmseqs=a.mmseqs, + fasta=str(fasta), + workdir=str(mmseqs_work), + seq_space=a.seq_space, + threads=a.threads, + min_seq_id=a.min_seq_id, + coverage=a.coverage, + cov_mode=a.cov_mode, + cluster_mode=a.cluster_mode, + max_seq_len=a.max_seq_len, + kmer_per_seq_scale=a.kmer_per_seq_scale, + 
split_memory_limit=a.split_memory_limit, + remove_tmp_files=True, + overwrite=a.overwrite, + ) + ) + cmd_make_seq_cluster( + argparse.Namespace( + cluster_tsv=str(cluster_tsv), + output_parquet=str(seq_cluster), + threads=a.threads, + ) + ) + cmd_make_seq_split( + argparse.Namespace( + input_glob=a.input_glob, + heldout_test_glob=a.heldout_test_glob, + species_key_mode=a.species_key_mode, + seq_cluster_parquet=str(seq_cluster), + cluster_split_parquet=str(cluster_split), + output_parquet=str(seq_split), + val_frac=a.val_frac, + seed=a.seed, + threads=a.threads, + limit_files=a.limit_files, + ) + ) + cmd_write_data_v3( + argparse.Namespace( + input_glob=a.input_glob, + seq_cluster_parquet=str(seq_cluster), + seq_split_parquet=str(seq_split), + output_root=str(out_root), + num_shards=a.num_shards, + threads=a.threads, + limit_files=a.limit_files, + representatives_only=a.representatives_only, + overwrite=a.overwrite, + ) + ) + cmd_verify( + argparse.Namespace( + input_glob=a.input_glob, + heldout_test_glob=a.heldout_test_glob, + species_key_mode=a.species_key_mode, + seq_space=a.seq_space, + seq_cluster_parquet=str(seq_cluster), + seq_split_parquet=str(seq_split), + report_json=str(report_json), + threads=a.threads, + limit_files=a.limit_files, + ) + ) + + p.set_defaults(func=_run_all) + return ap + + +def main(argv: Optional[List[str]] = None) -> int: + ap = build_parser() + args = ap.parse_args(argv) + args.func(args) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/sampling.py b/sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..c5cdbd4d837b541218572038830104bf8b5d3d2b --- /dev/null +++ b/sampling.py @@ -0,0 +1,314 @@ +#!/usr/bin/env python +""" +Sampling script for generating codon sequences from trained CodonGPT models. 
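+
+Example invocation (illustrative; the checkpoint path and sequences are placeholders,
+and the conditioning inputs are described below):
+
+    python sampling.py \
+        --model_path <checkpoint_dir> \
+        --species "Homo sapiens" \
+        --protein_seq MKTAYIAKQR \
+        --num_sequences 4 --control_mode fixed \
+        --temperature 0.8 --top_k 50 --top_p 0.9 \
+        --output_file samples.fasta --output_format fasta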
+Inputs are prepared exactly like training: +- Species conditioning via SpeciesEmbeddingStore (fixed-size [B,Ds] or variable-length [B,Ls,Ds]) +- Protein conditioning via raw AA strings (ESM-C tokenization happens inside the model) +""" + +import argparse +import logging +import json +from pathlib import Path +from typing import List, Optional, Union + +import torch + +from src.sampler import CodonSampler +from src.dataset import SpeciesEmbeddingStore + +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, +) +logger = logging.getLogger("codongpt.sample") + + +def parse_args(): + p = argparse.ArgumentParser(description="Sample codon sequences from CodonGPT model") + + # Model + p.add_argument("--model_path", "--model_dir", dest="model_path", type=str, required=True, + help="Path to trained model checkpoint dir") + p.add_argument("--device", type=str, default="cuda", help="cuda or cpu") + p.add_argument("--compile", action="store_true", help="torch.compile the model") + + # Species embeddings + p.add_argument("--embeddings_dir", type=str, default=None, + help="Directory with precomputed variable-length species embeddings (optional; fallback to Qwen if missing/unknown)") + p.add_argument("--strict_species_lookup", action="store_true", + help="When using --embeddings_dir, fail if any requested species name is not an exact key in species_vocab.json") + p.add_argument("--taxonomy_db", type=str, default=None, + help="Optional path to taxonomy_database.json (from precompute) to enrich prompts") + + # Sampling batch size and count + p.add_argument("--num_sequences", "--num_seq", "--num_samples", type=int, default=1, dest="num_sequences", + help="Number of sequences to generate in total") + p.add_argument("--batch_size", type=int, default=None, help="Batch size for sampling loop") + + # Control mode and length + p.add_argument("--control_mode", choices=["fixed", "variable"], default="fixed", + help="fixed: disallow EOS, generate exactly sequence_length codons; variable: allow EOS") + p.add_argument("--sequence_length", type=int, default=None, + help="Number of CODONS to generate (used as max steps in variable mode). " + "If omitted and protein sequences are provided, set to min protein length.") + + # Conditioning (REQUIRED: species and protein) + p.add_argument("--species", "--taxon", type=str, default=None, dest="species", + help="Species name (e.g., 'Homo sapiens'). Replicated if num_sequences>1.") + p.add_argument("--species_list", type=str, nargs="+", default=None, + help="List of species names (must match num_sequences).") + + p.add_argument("--protein_seq", "--protein_sequence", type=str, default=None, dest="protein_seq", + help="Protein sequence (AA string). Replicated if num_sequences>1.") + p.add_argument("--protein_file", type=str, default=None, + help="Path to FASTA-like file (each non-header line is a sequence). 
Must provide at least num_sequences.") + + # Sampling params + p.add_argument("--temperature", type=float, default=1, help="Sampling temperature") + p.add_argument("--top_k", type=int, default=50, help="Top-k") + p.add_argument("--top_p", type=float, default=0.9, help="Top-p (nucleus)") + p.add_argument("--enforce_translation", action="store_true", default=False, + help="Hard-mask codons to match the given protein AA at each position") + p.add_argument("--seed", type=int, default=None) + p.add_argument("--save_intermediate", action="store_true", help="Store intermediate token states") + + # Output + p.add_argument("--output_file", type=str, default=None) + p.add_argument("--output_format", type=str, default="fasta", choices=["fasta", "csv", "json"]) + + # Misc + p.add_argument("--quiet", action="store_true") + return p.parse_args() + + +def load_protein_sequences(file_path: str) -> List[str]: + """Load protein sequences: every non-'>' line is a sequence.""" + seqs: List[str] = [] + with open(file_path, "r") as f: + for line in f: + line = line.strip() + if line and not line.startswith(">"): + seqs.append(line) + return seqs + + +def setup_species_store(embeddings_dir: str) -> SpeciesEmbeddingStore: + """Load species embedding store (prefer variable-length if available).""" + # We don't guess. If you stored sequence-format, this will pick it; else fixed-size. + return SpeciesEmbeddingStore(embeddings_dir, pooling="sequence") + + +def save_sequences( + sequences: List[str], + output_file: str, + fmt: str, + species: Optional[List[str]] = None, + proteins: Optional[List[str]] = None, + metadata: Optional[dict] = None, +): + if fmt == "fasta": + with open(output_file, "w") as f: + for i, seq in enumerate(sequences): + header = f">seq_{i}" + if species and i < len(species): + header += f"|species={species[i]}" + if proteins and i < len(proteins): + header += f"|protein_len={len(proteins[i])}" + f.write(f"{header}\n{seq}\n") + return + + if fmt == "csv": + import pandas as pd + data = {"sequence": sequences} + if species: + data["species"] = species[:len(sequences)] + if proteins: + data["protein_sequence"] = proteins[:len(sequences)] + pd.DataFrame(data).to_csv(output_file, index=False) + return + + # json + payload = {"sequences": sequences, "metadata": metadata or {}} + if species: + payload["species"] = species[:len(sequences)] + if proteins: + payload["protein_sequences"] = proteins[:len(sequences)] + with open(output_file, "w") as f: + json.dump(payload, f, indent=2) + + +def translate_dna_to_aa(dna_seq: str) -> str: + """Translate DNA (3-mer) using the standard genetic code.""" + g = { + 'TTT': 'F', 'TTC': 'F', 'TTA': 'L', 'TTG': 'L', 'TCT': 'S', 'TCC': 'S', 'TCA': 'S', 'TCG': 'S', + 'TAT': 'Y', 'TAC': 'Y', 'TAA': '*', 'TAG': '*', 'TGT': 'C', 'TGC': 'C', 'TGA': '*', 'TGG': 'W', + 'CTT': 'L', 'CTC': 'L', 'CTA': 'L', 'CTG': 'L', 'CCT': 'P', 'CCC': 'P', 'CCA': 'P', 'CCG': 'P', + 'CAT': 'H', 'CAC': 'H', 'CAA': 'Q', 'CAG': 'Q', 'CGT': 'R', 'CGC': 'R', 'CGA': 'R', 'CGG': 'R', + 'ATT': 'I', 'ATC': 'I', 'ATA': 'I', 'ATG': 'M', 'ACT': 'T', 'ACC': 'T', 'ACA': 'T', 'ACG': 'T', + 'AAT': 'N', 'AAC': 'N', 'AAA': 'K', 'AAG': 'K', 'AGT': 'S', 'AGC': 'S', 'AGA': 'R', 'AGG': 'R', + 'GTT': 'V', 'GTC': 'V', 'GTA': 'V', 'GTG': 'V', 'GCT': 'A', 'GCC': 'A', 'GCA': 'A', 'GCG': 'A', + 'GAT': 'D', 'GAC': 'D', 'GAA': 'E', 'GAG': 'E', 'GGT': 'G', 'GGC': 'G', 'GGA': 'G', 'GGG': 'G' + } + L = len(dna_seq) // 3 + aa = [g.get(dna_seq[3*i:3*i+3], 'X') for i in range(L)] + return ''.join(aa) + + +def 
report_token_accuracy(sequences: List[str], target_proteins: List[str]) -> None: + for i, dna in enumerate(sequences): + tgt = target_proteins[i] if i < len(target_proteins) else target_proteins[-1] + gen_aa = translate_dna_to_aa(dna) + L = min(len(gen_aa), len(tgt)) + if L == 0: + acc = 0.0; num = 0; den = 0 + else: + matches = sum(1 for a, b in zip(gen_aa[:L], tgt[:L]) if a == b) + acc = matches / L; num = matches; den = L + logger.info(f"AA token accuracy seq_{i+1}: {acc:.4f} ({num}/{den})") + + +def main(): + args = parse_args() + + if args.device == "cuda" and not torch.cuda.is_available(): + raise RuntimeError("CUDA requested but not available") + + if args.seed is not None: + torch.manual_seed(int(args.seed)) + + # Conditioning must be provided – same invariants as training + have_species_names = bool(args.species_list) or bool(args.species) + have_protein = bool(args.protein_file) or bool(args.protein_seq) + if not have_species_names or not have_protein: + raise ValueError("Sampling requires BOTH species (names) and protein sequence(s).") + + # Species names list + if args.species_list: + species_names = list(args.species_list) + else: + species_names = [str(args.species)] + + # Protein sequences list + if args.protein_file: + protein_sequences = load_protein_sequences(args.protein_file) + else: + protein_sequences = [str(args.protein_seq)] + + # Expand/reconcile counts + N = int(args.num_sequences) + if len(species_names) == 1 and N > 1: + species_names = species_names * N + if len(protein_sequences) == 1 and N > 1: + protein_sequences = protein_sequences * N + + if len(species_names) != N: + raise ValueError(f"species count ({len(species_names)}) must equal num_sequences ({N})") + if len(protein_sequences) < N: + raise ValueError(f"protein sequences provided ({len(protein_sequences)}) less than num_sequences ({N})") + if len(protein_sequences) > N: + protein_sequences = protein_sequences[:N] + + # If no explicit sequence_length, use min protein length, so every sample has a valid AA at each fixed step + if args.sequence_length is None: + args.sequence_length = min(len(s) for s in protein_sequences) + logger.info(f"Auto-set sequence_length to min protein length: {args.sequence_length} codons") + + if args.sequence_length <= 0: + raise ValueError("sequence_length must be > 0") + + # Load species store if provided (preferred to exactly match training); unknown species will fallback to Qwen + species_store = None + if args.embeddings_dir: + species_store = setup_species_store(args.embeddings_dir) + logger.info(f"Loaded species store: {len(species_store.vocab)} species; Ds={species_store.Ds()}") + if args.strict_species_lookup: + unknown = sorted({name for name in species_names if name not in species_store.vocab}) + if unknown: + preview = ", ".join(repr(x) for x in unknown[:5]) + more = "" if len(unknown) <= 5 else f" ... 
(+{len(unknown) - 5} more)" + raise ValueError( + "strict species lookup failed; these names are not exact keys in species_vocab.json: " + f"{preview}{more}" + ) + + sampler = CodonSampler( + model_path=args.model_path, + device=args.device, + compile_model=bool(args.compile), + species_store=species_store, + taxonomy_db_path=args.taxonomy_db, + ) + + # Batch loop + batch_size = int(args.batch_size or N) + all_sequences: List[str] = [] + all_intermediates = [] + + total_batches = (N + batch_size - 1) // batch_size + for start in range(0, N, batch_size): + end = min(N, start + batch_size) + bs = end - start + batch_species = species_names[start:end] + batch_proteins = protein_sequences[start:end] + + logger.info(f"Sampling batch {start//batch_size + 1}/{total_batches} (B={bs})") + + result = sampler.sample( + num_sequences=bs, + sequence_length=int(args.sequence_length), + species=batch_species, + protein_sequences=batch_proteins, + control_mode=str(args.control_mode), + temperature=float(args.temperature), + top_k=int(args.top_k), + top_p=float(args.top_p), + seed=int(args.seed) if args.seed is not None else None, + return_intermediate=bool(args.save_intermediate), + progress_bar=not bool(args.quiet), + enforce_translation=bool(args.enforce_translation), + ) + + seqs = result["sequences"] # List[str] + all_sequences.extend(seqs) + if args.save_intermediate and "intermediate_states" in result: + all_intermediates.append(result["intermediate_states"]) + + logger.info(f"Generated {len(all_sequences)} sequences.") + for i, seq in enumerate(all_sequences[:5]): + logger.info(f"Sequence {i+1} ({len(seq)//3} codons): {seq[:60]}...") + + # Save outputs + if args.output_file: + meta = { + "model_path": args.model_path, + "temperature": args.temperature, + "top_k": args.top_k, + "top_p": args.top_p, + "control_mode": args.control_mode, + "sequence_length": int(args.sequence_length), + } + save_sequences( + all_sequences, + args.output_file, + args.output_format, + species=species_names, + proteins=protein_sequences, + metadata=meta, + ) + logger.info(f"Saved sequences to {args.output_file}") + + # Report AA token accuracy when protein targets are given + report_token_accuracy(all_sequences, protein_sequences) + + if args.save_intermediate and all_intermediates: + inter_file = Path(args.output_file).with_suffix("").as_posix() + "_intermediate.pt" + torch.save(all_intermediates, inter_file) + logger.info(f"Saved intermediate states to {inter_file}") + + logger.info("Sampling completed.") + + +if __name__ == "__main__": + main() diff --git a/slurm/rebuild_data_v3_cpu.sbatch b/slurm/rebuild_data_v3_cpu.sbatch new file mode 100644 index 0000000000000000000000000000000000000000..11b7732a4699fddae76f4a8d5fc7920dd7646b3d --- /dev/null +++ b/slurm/rebuild_data_v3_cpu.sbatch @@ -0,0 +1,98 @@ +#!/bin/bash +#SBATCH --partition=beacon +#SBATCH --qos=high +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=16 +#SBATCH --mem=240G +#SBATCH --time=3-00:00:00 +#SBATCH --job-name=data_v3_rebuild +#SBATCH --output=%x_%j.out +#SBATCH --error=%x_%j.err + +set -euo pipefail + +REPO_ROOT=${REPO_ROOT:-/beacon-projects/codon-lm/HE-DLM} +PYTHON_BIN=${PYTHON_BIN:-/beacon-projects/codon-lm/miniconda3/envs/dna/bin/python} +MMSEQS_BIN=${MMSEQS_BIN:-$REPO_ROOT/MMseqs2/build/bin/mmseqs} +INPUT_GLOB=${INPUT_GLOB:-data_v2/*/*.parquet} +HELDOUT_GLOB=${HELDOUT_GLOB:-data_v2/test/*.parquet} +SEQ_SPACE=${SEQ_SPACE:-protein} +SPECIES_KEY_MODE=${SPECIES_KEY_MODE:-binomial} +MAX_INPUT_SEQ_LEN=${MAX_INPUT_SEQ_LEN:-20000} 
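+# MODE=pilot shrinks the run for a quick smoke test (the block below caps the input
+# file count, writes to a separate OUTPUT_ROOT, and uses fewer shards). Illustrative submit:
+#   MODE=pilot sbatch slurm/rebuild_data_v3_cpu.sbatch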
+MODE=${MODE:-full} +OUTPUT_ROOT=${OUTPUT_ROOT:-data_v3_rebuild} +LIMIT_FILES=${LIMIT_FILES:-0} +NUM_SHARDS=${NUM_SHARDS:-256} +VAL_FRAC=${VAL_FRAC:-0.01} +THREADS=${THREADS:-${SLURM_CPUS_PER_TASK:-16}} +MIN_SEQ_ID=${MIN_SEQ_ID:-0.90} +COVERAGE=${COVERAGE:-0.80} +COV_MODE=${COV_MODE:-2} +CLUSTER_MODE=${CLUSTER_MODE:-2} +MAX_SEQ_LEN=${MAX_SEQ_LEN:-200000} +SPLIT_MEMORY_LIMIT=${SPLIT_MEMORY_LIMIT:-180G} +SEED=${SEED:-13} +OVERWRITE=${OVERWRITE:-1} + +if [[ "${MODE}" == "pilot" ]]; then + if [[ "${LIMIT_FILES}" == "0" ]]; then + LIMIT_FILES=4 + fi + if [[ "${OUTPUT_ROOT}" == "data_v3_rebuild" ]]; then + OUTPUT_ROOT=data_v3_pilot + fi + if [[ "${NUM_SHARDS}" == "256" ]]; then + NUM_SHARDS=16 + fi +fi + +cd "${REPO_ROOT}" + +export LD_LIBRARY_PATH="/beacon-projects/codon-lm/miniconda3/envs/dna/lib:/beacon-projects/codon-lm/miniconda3/lib:${LD_LIBRARY_PATH:-}" + +if [[ -f "${MMSEQS_BIN}" && ! -x "${MMSEQS_BIN}" ]]; then + chmod u+x "${MMSEQS_BIN}" +fi + +"${PYTHON_BIN}" - <<'PY' +import duckdb +import pyarrow + +print("duckdb", duckdb.__version__) +print("pyarrow", pyarrow.__version__) +PY + +CMD=( + "${PYTHON_BIN}" resplit_data_v3.py all + --input-glob "${INPUT_GLOB}" + --heldout-test-glob "${HELDOUT_GLOB}" + --output-root "${OUTPUT_ROOT}" + --seq-space "${SEQ_SPACE}" + --species-key-mode "${SPECIES_KEY_MODE}" + --max-input-seq-len "${MAX_INPUT_SEQ_LEN}" + --threads "${THREADS}" + --limit-files "${LIMIT_FILES}" + --num-shards "${NUM_SHARDS}" + --mmseqs "${MMSEQS_BIN}" + --min-seq-id "${MIN_SEQ_ID}" + --coverage "${COVERAGE}" + --cov-mode "${COV_MODE}" + --cluster-mode "${CLUSTER_MODE}" + --max-seq-len "${MAX_SEQ_LEN}" + --val-frac "${VAL_FRAC}" + --seed "${SEED}" +) + +if [[ -n "${SPLIT_MEMORY_LIMIT}" ]]; then + CMD+=(--split-memory-limit "${SPLIT_MEMORY_LIMIT}") +fi + +if [[ "${OVERWRITE}" == "1" ]]; then + CMD+=(--overwrite) +fi + +printf 'Running command:' +printf ' %q' "${CMD[@]}" +printf '\n' +"${CMD[@]}" diff --git a/slurm/submit_train_v3_h200_8x_chain.sh b/slurm/submit_train_v3_h200_8x_chain.sh new file mode 100644 index 0000000000000000000000000000000000000000..2fbcca8c5341b7547de337fd0c09d23e7e74a99c --- /dev/null +++ b/slurm/submit_train_v3_h200_8x_chain.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash + +set -euo pipefail + +cd /beacon-projects/codon-lm/HE-DLM + +SEGMENTS=${SEGMENTS:-3} +SBATCH_SCRIPT=${SBATCH_SCRIPT:-slurm/train_v3_h200_8x_single.sbatch} + +if [[ ! -f "${SBATCH_SCRIPT}" ]]; then + echo "Missing sbatch script: ${SBATCH_SCRIPT}" >&2 + exit 1 +fi + +dep="" +for idx in $(seq 1 "${SEGMENTS}"); do + if [[ -n "${dep}" ]]; then + jid=$(sbatch --parsable --dependency=afterany:"${dep}" "${SBATCH_SCRIPT}") + else + jid=$(sbatch --parsable "${SBATCH_SCRIPT}") + fi + echo "submitted segment=${idx} job_id=${jid} dependency=${dep:-none}" + dep="${jid}" +done diff --git a/slurm/train_v3_h200_8x_single.sbatch b/slurm/train_v3_h200_8x_single.sbatch new file mode 100644 index 0000000000000000000000000000000000000000..063135921140f0180800efe2367a1e49e1b2f665 --- /dev/null +++ b/slurm/train_v3_h200_8x_single.sbatch @@ -0,0 +1,165 @@ +#!/bin/bash +# Single-node 8x H200 training entrypoint. 
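+# For runs longer than one allocation, segments can be chained with
+# slurm/submit_train_v3_h200_8x_chain.sh, which resubmits this script with afterany
+# dependencies; each segment then starts with the default RESUME_FROM=auto.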
+# Reserved single-node smoke-run example: +# sbatch --time=00:45:00 \ +# --export=ALL,OUT_DIR=/beacon-projects/codon-lm/HE-DLM/outputs_v3_rep_h200_8x_reserved_smoke,MAX_STEPS=20,SAVE_STEPS=0,EVAL_INTERVAL=0 \ +# slurm/train_v3_h200_8x_single.sbatch +# Full-run example: +# sbatch slurm/train_v3_h200_8x_single.sbatch +# +# Suggested W&B overrides: +# sbatch --export=ALL,WANDB_PROJECT=he-dlm-v3-h200-8x,WANDB_NAME=he-dlm-v3-h200-8x-run1 \ +# slurm/train_v3_h200_8x_single.sbatch +# If the environment is still configured for offline logging, override at submit time: +# sbatch --export=ALL,WANDB_MODE=online slurm/train_v3_h200_8x_single.sbatch +# This script is pinned to the reserved H200 allocation on ihccs210. +# Do not use QoS=reserved on any other node. +#SBATCH --job-name=train-v3-h200-8x +#SBATCH --partition=beacon +#SBATCH --qos=reserved +#SBATCH --reservation=heng-reservation +#SBATCH --nodelist=ihccs210 +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --gres=gpu:nvidia_h200:8 +#SBATCH --cpus-per-task=16 +#SBATCH --mem=512G +#SBATCH --time=3-00:00:00 +#SBATCH --output=%x_%j.out +#SBATCH --error=%x_%j.err + +set -euo pipefail + +set +u +source ~/.bashrc +conda activate dna +set -u + +cd /beacon-projects/codon-lm/HE-DLM + +TRAIN_DATA=${TRAIN_DATA:-/beacon-projects/codon-lm/HE-DLM/data_v3_rebuild/train} +VAL_DATA=${VAL_DATA:-/beacon-projects/codon-lm/HE-DLM/data_v3_rebuild/val} +EMBED_DIR=${EMBED_DIR:-/beacon-projects/codon-lm/HE-DLM/embeddings_v2} +OUT_DIR=${OUT_DIR:-/beacon-projects/codon-lm/HE-DLM/outputs_v3_rep_h200_8x_single_wd1e-4_bs48ga4} + +WANDB_PROJECT=${WANDB_PROJECT:-he-dlm-v3-h200-8x} +WANDB_NAME=${WANDB_NAME:-$(basename "${OUT_DIR}")} +WANDB_RUN_ID=${WANDB_RUN_ID:-$(basename "${OUT_DIR}")} +WANDB_RESUME=${WANDB_RESUME:-allow} +WANDB_DIR=${WANDB_DIR:-${OUT_DIR}/wandb} + +NPROC_PER_NODE=${NPROC_PER_NODE:-8} +BATCH_SIZE=${BATCH_SIZE:-48} +GRAD_ACCUM=${GRAD_ACCUM:-4} +EVAL_BATCH_SIZE=${EVAL_BATCH_SIZE:-32} +WORKERS=${WORKERS:-0} +EPOCHS=${EPOCHS:-3} +LR=${LR:-7e-5} +WARMUP_RATIO=${WARMUP_RATIO:-0.1} +WEIGHT_DECAY=${WEIGHT_DECAY:-1e-4} +LOGGING_STEPS=${LOGGING_STEPS:-10} +SAVE_STEPS=${SAVE_STEPS:-500} +SAVE_TOTAL_LIMIT=${SAVE_TOTAL_LIMIT:-1000} +EVAL_INTERVAL=${EVAL_INTERVAL:-5000} +EVAL_STEPS=${EVAL_STEPS:-256} +TRAIN_SHUFFLE_BUFFER=${TRAIN_SHUFFLE_BUFFER:-8192} +VAL_SHUFFLE_BUFFER=${VAL_SHUFFLE_BUFFER:-0} +CKPT_RECENT_WINDOW_STEPS=${CKPT_RECENT_WINDOW_STEPS:-2000} +CKPT_RECENT_INTERVAL=${CKPT_RECENT_INTERVAL:-500} +CKPT_ARCHIVE_INTERVAL=${CKPT_ARCHIVE_INTERVAL:-1000} +RESUME_FROM=${RESUME_FROM:-auto} +MAX_STEPS=${MAX_STEPS:-} +MASTER_PORT=${MASTER_PORT:-29500} +GRAD_CKPT=${GRAD_CKPT:-0} + +export WANDB_PROJECT WANDB_NAME WANDB_RUN_ID WANDB_RESUME WANDB_DIR +export NCCL_DEBUG=${NCCL_DEBUG:-WARN} +export TORCH_DISTRIBUTED_DEBUG=${TORCH_DISTRIBUTED_DEBUG:-DETAIL} +export NCCL_P2P_DISABLE=${NCCL_P2P_DISABLE:-0} +export NCCL_IB_DISABLE=${NCCL_IB_DISABLE:-1} +export NCCL_NET_GDR_LEVEL=${NCCL_NET_GDR_LEVEL:-0} +export NCCL_ASYNC_ERROR_HANDLING=${NCCL_ASYNC_ERROR_HANDLING:-1} +export NCCL_SHM_DISABLE=${NCCL_SHM_DISABLE:-1} +export NCCL_CUMEM_HOST_ENABLE=${NCCL_CUMEM_HOST_ENABLE:-1} +export CUDA_DEVICE_MAX_CONNECTIONS=${CUDA_DEVICE_MAX_CONNECTIONS:-1} + +mkdir -p "${OUT_DIR}" "${WANDB_DIR}" + +if [[ ! -d "${TRAIN_DATA}" ]]; then + echo "Missing train data dir: ${TRAIN_DATA}" >&2 + exit 1 +fi +if [[ ! -d "${VAL_DATA}" ]]; then + echo "Missing val data dir: ${VAL_DATA}" >&2 + exit 1 +fi +if [[ ! 
-f "${EMBED_DIR}/species_vocab.json" ]]; then + echo "Missing embeddings vocab: ${EMBED_DIR}/species_vocab.json" >&2 + exit 1 +fi + +echo "HOST=$(hostname)" +echo "TRAIN_DATA=${TRAIN_DATA}" +echo "VAL_DATA=${VAL_DATA}" +echo "EMBED_DIR=${EMBED_DIR}" +echo "OUT_DIR=${OUT_DIR}" +echo "WANDB_PROJECT=${WANDB_PROJECT} WANDB_NAME=${WANDB_NAME} WANDB_RUN_ID=${WANDB_RUN_ID} WANDB_RESUME=${WANDB_RESUME} WANDB_MODE=${WANDB_MODE:-unset}" +echo "BATCH_SIZE=${BATCH_SIZE} GRAD_ACCUM=${GRAD_ACCUM} EVAL_BATCH_SIZE=${EVAL_BATCH_SIZE} NPROC_PER_NODE=${NPROC_PER_NODE}" +echo "WEIGHT_DECAY=${WEIGHT_DECAY} SAVE_STEPS=${SAVE_STEPS} EVAL_INTERVAL=${EVAL_INTERVAL} MAX_STEPS=${MAX_STEPS:-unset}" +echo "NCCL_P2P_DISABLE=${NCCL_P2P_DISABLE} NCCL_IB_DISABLE=${NCCL_IB_DISABLE} NCCL_SHM_DISABLE=${NCCL_SHM_DISABLE} NCCL_CUMEM_HOST_ENABLE=${NCCL_CUMEM_HOST_ENABLE}" + +echo "=== GPU inventory ===" +nvidia-smi --query-gpu=index,name,memory.total,driver_version --format=csv,noheader || true +echo "=== GPU topology ===" +nvidia-smi topo -m || true +echo "=== NVLink status ===" +nvidia-smi nvlink -s || true + +CMD=( + torchrun + --standalone + --nproc_per_node "${NPROC_PER_NODE}" + --master_port "${MASTER_PORT}" + train.py + --train_data "${TRAIN_DATA}" + --val_data "${VAL_DATA}" + --embeddings_dir "${EMBED_DIR}" + --output_dir "${OUT_DIR}" + --fsdp + --bf16 + --attn mha + --hidden 750 + --layers 20 + --heads 15 + --mlp_ratio 3.2 + --batch_size "${BATCH_SIZE}" + --grad_accum "${GRAD_ACCUM}" + --eval_batch_size "${EVAL_BATCH_SIZE}" + --epochs "${EPOCHS}" + --workers "${WORKERS}" + --warmup_ratio "${WARMUP_RATIO}" + --lr "${LR}" + --weight_decay "${WEIGHT_DECAY}" + --train_shuffle_buffer "${TRAIN_SHUFFLE_BUFFER}" + --val_shuffle_buffer "${VAL_SHUFFLE_BUFFER}" + --logging_steps "${LOGGING_STEPS}" + --save_steps "${SAVE_STEPS}" + --save_total_limit "${SAVE_TOTAL_LIMIT}" + --ckpt_recent_window_steps "${CKPT_RECENT_WINDOW_STEPS}" + --ckpt_recent_interval "${CKPT_RECENT_INTERVAL}" + --ckpt_archive_interval "${CKPT_ARCHIVE_INTERVAL}" + --eval_interval "${EVAL_INTERVAL}" + --eval_steps "${EVAL_STEPS}" +) + +if [[ "${RESUME_FROM}" != "none" && -n "${RESUME_FROM}" ]]; then + CMD+=(--resume_from "${RESUME_FROM}") +fi +if [[ -n "${MAX_STEPS}" ]]; then + CMD+=(--max_steps "${MAX_STEPS}") +fi +if [[ "${GRAD_CKPT}" == "1" ]]; then + CMD+=(--grad_ckpt) +fi + +exec "${CMD[@]}" diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..39f7c1726cca443def9a622f69fce51132906906 --- /dev/null +++ b/src/__init__.py @@ -0,0 +1,33 @@ +""" +CodonGPT – conditional codon sequence generation (GPT-only). 
+""" + +from .tokenizer import CodonTokenizer +from .models import ( + CodonGPT, +) +from .trainer import Trainer, TrainingArguments +from .sampler import CodonSampler, sample_sequences +from .dataset import ( + stage_collate_fn, + create_precomputed_dataloaders, +) + +__version__ = "0.1.0" + +__all__ = [ + # Tokenizer + "CodonTokenizer", + # Models + "CodonGPT", + # Training + "Trainer", + "TrainingArguments", + # Sampling + "CodonSampler", + "sample_sequences", + # Data + # "stage_collate_fn", + "create_precomputed_dataloaders", + # Noise +] diff --git a/src/__pycache__/__init__.cpython-312.pyc b/src/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..00a2b48e7e5d1c09b398af2bb1386283944a01ab Binary files /dev/null and b/src/__pycache__/__init__.cpython-312.pyc differ diff --git a/src/__pycache__/dataset.cpython-312.pyc b/src/__pycache__/dataset.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..522abf45ecd00a1712bb354e394d26d521155942 Binary files /dev/null and b/src/__pycache__/dataset.cpython-312.pyc differ diff --git a/src/__pycache__/layers.cpython-312.pyc b/src/__pycache__/layers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..176727ca168b9f34f469b649153a764b0593d326 Binary files /dev/null and b/src/__pycache__/layers.cpython-312.pyc differ diff --git a/src/__pycache__/models.cpython-312.pyc b/src/__pycache__/models.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ad206fe76dc5c887cd1fc2c30234181e773f74c Binary files /dev/null and b/src/__pycache__/models.cpython-312.pyc differ diff --git a/src/__pycache__/sampler.cpython-312.pyc b/src/__pycache__/sampler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..045cc7672658e38bdb31a5957e4f1b2d9d89d730 Binary files /dev/null and b/src/__pycache__/sampler.cpython-312.pyc differ diff --git a/src/__pycache__/tokenizer.cpython-312.pyc b/src/__pycache__/tokenizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6f72a1f50795d6c16ff09cd9f266958483a7f01 Binary files /dev/null and b/src/__pycache__/tokenizer.cpython-312.pyc differ diff --git a/src/__pycache__/trainer.cpython-312.pyc b/src/__pycache__/trainer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79e4e3569dd6635b9888e93794401b16006d458a Binary files /dev/null and b/src/__pycache__/trainer.cpython-312.pyc differ diff --git a/src/dataset.py b/src/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..acbd226fc4ab3573bcd6a91688908890168b3348 --- /dev/null +++ b/src/dataset.py @@ -0,0 +1,833 @@ +# src/dataset.py +""" +Production-ready dataset + dataloader utilities. + +Rules (because we're adults): +- Data drives design. Inputs are rows with columns: ["cds_DNA", "protein_seq", "Taxon", (optional) "RefseqID"]. +- Output per sample is a tiny dict the model actually needs. Nothing else. +- We stream Parquet by row groups, CSV by chunks. No full-file pandas nonsense on big data. +- We shard by (FSDP rank × dataloader worker). No DistributedSampler needed. +- We do a simple streaming shuffle buffer for train. Good enough. No fancy "epoch managers". 
+ +Fields emitted per sample (for collate_fn and trainer): + { + "species_name": str, + "species_id": int, + "protein_seq": str, # raw AA (ESM tokenized later) + "aa_len": int, + "codon_ids": List[int], # tokenized 3-mer ids + EOS at the end + "refseq_id": str, + "protein_refseq_id": str, + "control_mode": "fixed", + "meta": {"src": "parquet|csv", "file": basename, "row": int} + } + +Invariants: +- cds_DNA length divisible by 3 after trimming to match protein length. +- DNA uses only ACGT (uppercase). If not, we skip the row. We don't "helpfully fix" broken data. +- We truncate both DNA and protein to the same min length (codon count). +- EOS appended to codon_ids; PAD is handled at collate time, not here. + +Dependencies: +- pyarrow only if you read parquet. If it isn't installed and you pass parquet files, we fail loudly. +""" + +from __future__ import annotations + +import os +import json +import glob +import random +import logging +import heapq +from typing import Dict, List, Any, Optional, Iterable, Tuple +from pathlib import Path + +import numpy as np +import pandas as pd +import torch +from torch.utils.data import IterableDataset, Dataset, DataLoader, get_worker_info + +try: + from tqdm.auto import tqdm as _tqdm +except Exception: # pragma: no cover - tqdm might be unavailable in minimal envs + _tqdm = None + +logger = logging.getLogger(__name__) + +# ------------------------------ +# Species Embedding Store (kept simple and stable) +# ------------------------------ + +class SpeciesEmbeddingStore: + def __init__(self, embeddings_dir: str, dtype: str = "float32", pin_memory: bool = False, pooling: str = "last"): + self.embeddings_dir = Path(embeddings_dir) + self.pin_memory = bool(pin_memory) + self.is_legacy = False + self.pooling = pooling + + vocab_path = self.embeddings_dir / "species_vocab.json" + if not vocab_path.exists(): + raise FileNotFoundError(f"Species vocabulary not found at {vocab_path}") + with open(vocab_path, "r") as f: + self.vocab: Dict[str, int] = json.load(f) + + meta_path = self.embeddings_dir / "species_metadata.json" + new_emb_path = self.embeddings_dir / "species_embeddings.bin" + legacy_index = self.embeddings_dir / "species_index.json" + legacy_emb = self.embeddings_dir / "species_tok_emb.bin" + + if self.pooling == "sequence" and legacy_index.exists() and legacy_emb.exists(): + self.is_legacy = True + self._load_legacy_format(dtype) + return + + if meta_path.exists() and new_emb_path.exists(): + with open(meta_path, "r") as f: + meta = json.load(f) + self.num_species = int(meta["num_species"]) + self._ds = int(meta["embedding_dim"]) + self.embedding_type = str(meta.get("embedding_type", "fixed_size")) + np_dtype = np.float16 if dtype == "float16" else np.float32 + self.embeddings = np.memmap(new_emb_path, dtype=np_dtype, mode="r", shape=(self.num_species, self._ds)) + self._np_dtype = np_dtype + print(f"Loaded fixed-size species embeddings: {len(self.vocab)} species, Ds={self._ds}, dtype={self._np_dtype}") + else: + self.is_legacy = True + self._load_legacy_format(dtype) + + def _load_legacy_format(self, dtype: str): + index_path = self.embeddings_dir / "species_index.json" + if not index_path.exists(): + raise FileNotFoundError(f"Species index not found at {index_path}") + with open(index_path, "r") as f: + raw_index = json.load(f) + self.index: Dict[str, Dict[str, int]] = {str(k): v for k, v in raw_index.items()} + + meta_path = self.embeddings_dir / "metadata.json" + file_dtype = dtype + if meta_path.exists(): + with open(meta_path, "r") as f: + meta = 
json.load(f) + self._ds = int(meta.get("embedding_dim", 1024)) + file_dtype = str(meta.get("dtype", dtype)).lower() + else: + self._ds = 1024 + + emb_path = self.embeddings_dir / "species_tok_emb.bin" + if not emb_path.exists(): + raise FileNotFoundError(f"Species embeddings not found at {emb_path}") + + np_dtype = np.float16 if file_dtype == "float16" else np.float32 + itemsize = np.dtype(np_dtype).itemsize + file_bytes = os.path.getsize(emb_path) + if file_bytes % (self._ds * itemsize) != 0: + raise ValueError(f"Emb file size {file_bytes} not divisible by Ds*itemsize ({self._ds}*{itemsize})") + total_tokens = file_bytes // (self._ds * itemsize) + + self.embeddings = np.memmap(emb_path, dtype=np_dtype, mode="r", shape=(total_tokens, self._ds)) + self._np_dtype = np_dtype + self.num_species = len(self.vocab) + print(f"[LEGACY] variable-length embeddings: {len(self.vocab)} species, {total_tokens} tokens total, Ds={self._ds}.") + + def load_vocab(self) -> Dict[str, int]: + return self.vocab.copy() + + def _deterministic_stub(self, length: int = None) -> torch.FloatTensor: + if self.is_legacy and length: + t = torch.zeros(1, length, self._ds, dtype=torch.float32) + else: + t = torch.zeros(1, self._ds, dtype=torch.float32) + return t + + def get(self, species_id: int) -> torch.FloatTensor: + if not self.is_legacy: + if species_id < 0 or species_id >= getattr(self, "num_species", 0): + return self._deterministic_stub() + emb = self.embeddings[species_id] + tensor = torch.from_numpy(np.asarray(emb).copy()).float().unsqueeze(0) + return tensor + else: + sid = str(species_id) + entry = self.index.get(sid) + if entry is None: + return self._deterministic_stub(length=8) + offset = int(entry["offset"]); length = int(entry["length"]) + view = self.embeddings[offset: offset + length] + tensor = torch.from_numpy(np.asarray(view).copy()).float().unsqueeze(0) + return tensor + + def batch_get(self, species_ids: List[int]) -> Any: + if torch.is_tensor(species_ids): + species_ids = species_ids.detach().cpu().tolist() + else: + species_ids = [int(x) for x in species_ids] + B = len(species_ids) + if not self.is_legacy: + batch_emb = torch.zeros(B, self._ds, dtype=torch.float32) + for i, sid in enumerate(species_ids): + batch_emb[i] = self.get(sid).squeeze(0) + return batch_emb + else: + tensors = [self.get(sid) for sid in species_ids] + lengths = torch.tensor([t.shape[1] for t in tensors], dtype=torch.long) + Ls_max = int(lengths.max().item()) if lengths.numel() > 0 else 0 + padded = torch.zeros(B, Ls_max, self._ds, dtype=torch.float32) + for i, t in enumerate(tensors): + L = t.shape[1]; padded[i, :L] = t.squeeze(0) + return padded, lengths + + def Ds(self) -> int: + return self._ds + +def _is_parquet(path: str) -> bool: + lower = path.lower() + return lower.endswith(".parquet") or lower.endswith(".parq") + + +def _is_csv(path: str) -> bool: + lower = path.lower() + return ( + lower.endswith(".csv") + or lower.endswith(".tsv") + or lower.endswith(".csv.gz") + or lower.endswith(".tsv.gz") + ) + + +def _expand_paths(maybe_path_or_glob: str | List[str]) -> List[str]: + """ + Expand a path/glob or list of them into a sorted, de-duplicated list of files. + We prioritize parquet, then csv/tsv. 
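+    For example (illustrative), passing "data_v3_rebuild/train" returns the sorted
+    parquet shards found under that directory, while "data_v2/*/*.parquet" is expanded via glob.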
+ """ + paths: List[str] = [] + if isinstance(maybe_path_or_glob, str): + p = Path(maybe_path_or_glob) + if p.is_dir(): + # Scan directory for parquet first, then csv/tsv + paths.extend(sorted(str(x) for x in p.rglob("*.parquet"))) + paths.extend(sorted(str(x) for x in p.rglob("*.parq"))) + paths.extend(sorted(str(x) for x in p.rglob("*.csv"))) + paths.extend(sorted(str(x) for x in p.rglob("*.tsv"))) + paths.extend(sorted(str(x) for x in p.rglob("*.csv.gz"))) + paths.extend(sorted(str(x) for x in p.rglob("*.tsv.gz"))) + else: + paths = sorted(glob.glob(str(p))) + else: + for it in maybe_path_or_glob: + paths.extend(_expand_paths(it)) + # Dedup while preserving order + seen = set() + out = [] + for x in paths: + if x not in seen: + out.append(x) + seen.add(x) + if not out: + raise FileNotFoundError(f"No input files found for: {maybe_path_or_glob}") + return out + + +def _dist_info() -> Tuple[int, int]: + """ + Returns (num_global_workers, global_worker_id) + where global_worker_id = rank * num_workers + worker_id. + """ + world_size = 1 + rank = 0 + try: + import torch.distributed as dist + + if dist.is_available() and dist.is_initialized(): + world_size = dist.get_world_size() + rank = dist.get_rank() + except Exception: + pass + wi = get_worker_info() + nw = wi.num_workers if wi else 1 + wid = wi.id if wi else 0 + return world_size * nw, rank * nw + wid + + +class _ResumeSkipProgress: + """Lightweight progress helper for resume skips.""" + + def __init__(self, total: int, label: str): + self.total = int(max(0, total)) + self.label = label + self.count = 0 + self._bar = None + + if self.total <= 0: + return + + if _tqdm is not None: + self._bar = _tqdm(total=self.total, desc=label, unit="sample", dynamic_ncols=True, leave=False) + else: + logger.info("%s: skipping %d samples to reach resume cursor", label, self.total) + + def update(self, n: int = 1): + if self.total <= 0: + return + self.count += int(n) + if self._bar is not None: + self._bar.update(n) + else: + if self.count == self.total or self.count % 10000 == 0: + logger.info("%s: skipped %d / %d", self.label, self.count, self.total) + + def close(self): + if self.total <= 0: + return + if self._bar is not None: + self._bar.close() + logger.info("%s: resume skip finished (%d samples)", self.label, self.count) + + +class StreamSeqDataset(IterableDataset): + """ + Streaming dataset with **non-overlapping Parquet row-group sharding**. + + - Accepts list of files (parquet and/or csv/tsv). + - **Parquet**: we enumerate (file, row_group) tasks and stride them across + the *global* worker id to avoid duplicates and to keep all ranks busy even + with few files. + - **CSV/TSV**: assigned at file granularity (one worker reads a file). + If you have only a few CSV files and many ranks, some ranks may get no CSV work. + (Parquet is the recommended format at scale.) + - CSV is read with pandas chunksize to keep memory usage sane. + - Each Parquet task reads exactly **one row group** into pandas. + + Minimal resume support: + - set_resume_skip(N) skips N yielded samples across the worker's assigned tasks. + (Use a **per-rank** skip value in your trainer so multi-node resumes stay in lockstep.) 
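+
+    Illustrative resume usage (the per-rank sample counter is assumed to come from the trainer):
+
+        ds = StreamSeqDataset(files, tokenizer, species_vocab_path="embeddings_v2/species_vocab.json")
+        ds.set_resume_skip(samples_consumed_on_this_rank)
+        loader = DataLoader(ds, batch_size=48, collate_fn=stage_collate_fn, drop_last=True)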
+ + Output sample schema: + { + "species_name": str, + "species_id": int, + "protein_seq": str, # raw AA (ESM tokenized later) + "aa_len": int, + "codon_ids": List[int], # tokenized 3-mer ids + EOS at the end + "refseq_id": str, + "protein_refseq_id": str, + "control_mode": "fixed", + "meta": {"src": "parquet|csv", "file": basename, "row": int} + } + """ + + # Canonical required columns. We also accept common aliases (e.g., 'taxon'). + REQUIRED = ["cds_DNA", "protein_seq", "Taxon"] + + def __init__( + self, + files: List[str], + tokenizer, + species_vocab_path: str, + unknown_species_id: int = 0, + csv_chunksize: int = 200_000, + shuffle_buffer: int = 0, + seed: int = 1234, + shard_across_ranks: bool = True, + ): + super().__init__() + self.files = files + self.tok = tokenizer + with open(species_vocab_path, "r") as f: + self.species_vocab: Dict[str, int] = json.load(f) + self.unknown_species_id = int(unknown_species_id) + self.csv_chunksize = int(max(1, csv_chunksize)) + self.shuffle_buffer = int(max(0, shuffle_buffer)) + self.seed = int(seed) + # When False, every rank iterates over the full task list instead of + # taking a disjoint shard. This keeps FSDP collectives aligned during + # evaluation even if the validation dataset is smaller than WORLD_SIZE. + self.shard_across_ranks = bool(shard_across_ranks) + + # Minimal resume cursor + self._resume_skip_n: int = 0 + self._offset_start: int = 0 + self._emitted: int = 0 + + # ---- resume cursor (minimal) ---- + def set_resume_skip(self, n: int) -> None: + n = int(max(0, n)) + self._resume_skip_n = n + self._offset_start = n + self._emitted = 0 + + def get_stream_position(self) -> int: + # Total yielded so far since dataset creation, including initial skip offset + return int(self._offset_start + self._emitted) + + # ---- core row-wise iterator on a pandas DataFrame ---- + def _iter_df(self, df: pd.DataFrame, src: str, file: str) -> Iterable[Dict[str, Any]]: + # Normalize common column aliases before validating. + # Some shards use lowercase `taxon` instead of `Taxon`. 
+ if "Taxon" not in df.columns and "taxon" in df.columns: + df = df.rename(columns={"taxon": "Taxon"}) + + # Hard fail if required missing + for c in self.REQUIRED: + if c not in df.columns: + raise ValueError(f"Input missing required column '{c}' in {file}") + + # Normalize & clean + df = df[self.REQUIRED + ([c for c in ["RefseqID"] if c in df.columns])] + df["Taxon"] = df["Taxon"].astype(str).str.strip() + df["protein_seq"] = df["protein_seq"].astype(str).str.strip().str.upper() + df["cds_DNA"] = df["cds_DNA"].astype(str).str.strip().str.upper() + + # Filter DNA: ACGT only and length > 0 + ok_mask = (df["cds_DNA"].str.len() > 0) & df["cds_DNA"].str.fullmatch(r"[ACGT]+", na=False) + df = df[ok_mask] + if df.empty: + return + + # Trim protein/DNA to shared min length (in codons) + cds_codons = (df["cds_DNA"].str.len() // 3).astype(int) + prot_len = df["protein_seq"].str.len().astype(int) + min_len = np.minimum(cds_codons.values, prot_len.values) + + df = df.assign(__min_len=min_len) + df = df[df["__min_len"] > 0] + if df.empty: + return + + # Species id map + def map_species(x: str) -> int: + try: + return int(self.species_vocab.get(x, self.unknown_species_id)) + except Exception: + return self.unknown_species_id + + species_ids = [map_species(x) for x in df["Taxon"].tolist()] + refseq_col = "RefseqID" if "RefseqID" in df.columns else None + + for i, (row_idx, row) in enumerate(df.iterrows()): + ml = int(row["__min_len"]) + cds = row["cds_DNA"][: ml * 3] + prot = row["protein_seq"][: ml] + if (len(cds) // 3) != len(prot): + continue + + # Tokenize DNA → 3-mer ids; append EOS + codon_ids = self.tok.encode_codon_seq(cds, validate=False) + codon_ids.append( + self.tok.special_ids.eos if hasattr(self.tok, "special_ids") else self.tok._special_ids.eos + ) + + species_id = species_ids[i] + ref_id = row[refseq_col] if refseq_col else f"{Path(file).stem}:{int(row_idx)}" + + yield { + "species_name": row["Taxon"], + "species_id": int(species_id), + "protein_seq": prot, + "aa_len": len(prot), + "codon_ids": codon_ids, + "refseq_id": ref_id, + "protein_refseq_id": ref_id, + "control_mode": "fixed", + "meta": {"src": src, "file": os.path.basename(file), "row": int(row_idx)}, + } + + # ---- Parquet helpers: enumerate row-group tasks & read one row group ---- + def _enumerate_tasks(self, files: List[str]) -> List[Tuple[str, str, Optional[int], int]]: + """ + Return a task list of tuples: + ("parquet", path, row_group_idx, weight) for each row group in each Parquet file + ("csv", path, None, weight) for each CSV/TSV file + """ + tasks: List[Tuple[str, str, Optional[int], int]] = [] + parquet_files = [f for f in files if _is_parquet(f)] + csv_files = [f for f in files if _is_csv(f)] + + if parquet_files: + try: + import pyarrow.parquet as pq # type: ignore + except Exception as e: + raise ImportError("pyarrow is required to read parquet files") from e + + for fp in parquet_files: + pf = pq.ParquetFile(fp) + nrg = int(pf.num_row_groups or 0) + if nrg <= 0: + # Treat as single task if row groups unavailable (unusual) + total_rows = pf.metadata.num_rows if pf.metadata and pf.metadata.num_rows is not None else 1 + tasks.append(("parquet", fp, 0, max(1, int(total_rows)))) + else: + for rg in range(nrg): + if pf.metadata is not None: + rg_meta = pf.metadata.row_group(rg) + num_rows = rg_meta.num_rows if rg_meta.num_rows is not None else 0 + else: + num_rows = 0 + tasks.append(("parquet", fp, rg, max(1, int(num_rows)))) + + # CSV/TSV files remain file-level tasks + for fp in csv_files: + file_size = 
os.path.getsize(fp) + # Assume ~256 bytes per record when estimating CSV row counts (empirical default) + est_rows = max(1, int(file_size // 256)) + tasks.append(("csv", fp, None, est_rows)) + + # Keep a deterministic order + # (files are already sorted by _expand_paths) + return tasks + + @staticmethod + def _balanced_partition(tasks: List[Tuple[str, str, Optional[int], int]], groups: int) -> List[List[Tuple[str, str, Optional[int], int]]]: + if groups <= 1: + return [tasks] + if not tasks: + return [[] for _ in range(groups)] + + # Greedy load balancing: assign heavier tasks first to the lightest bucket. + indexed = [(idx, kind, path, rg, weight) for idx, (kind, path, rg, weight) in enumerate(tasks)] + tasks_sorted = sorted( + indexed, + key=lambda entry: (entry[4], -entry[0]), + reverse=True, + ) + + heap: List[Tuple[int, int]] = [(0, bucket_idx) for bucket_idx in range(groups)] + heapq.heapify(heap) + buckets: List[List[Tuple[int, str, str, Optional[int], int]]] = [[] for _ in range(groups)] + + for original_index, kind, path, rg, weight in tasks_sorted: + load, bucket_idx = heapq.heappop(heap) + buckets[bucket_idx].append((original_index, kind, path, rg, weight)) + heapq.heappush(heap, (load + weight, bucket_idx)) + + partitions: List[List[Tuple[str, str, Optional[int], int]]] = [] + for bucket in buckets: + bucket.sort(key=lambda entry: entry[0]) + partitions.append([(kind, path, rg, weight) for (_idx, kind, path, rg, weight) in bucket]) + return partitions + + def _parquet_rowgroup_iter( + self, file: str, row_group_idx: int, cols_cache: Dict[str, List[str]] + ) -> Iterable[Dict[str, Any]]: + import pyarrow.parquet as pq # safe: checked in _enumerate_tasks + pf = pq.ParquetFile(file) + # Cache the column subset per file so we don't recompute + if file not in cols_cache: + names = set(pf.schema.names) + cols: List[str] = [] + # Required columns, with alias support (notably Taxon vs taxon). + for c in self.REQUIRED: + if c in names: + cols.append(c) + continue + if c == "Taxon" and "taxon" in names: + cols.append("taxon") + continue + # Optional debug id + if "RefseqID" in names: + cols.append("RefseqID") + cols_cache[file] = cols + cols = cols_cache[file] + table = pf.read_row_group(row_group_idx, columns=cols) + df = table.to_pandas(types_mapper=None) + yield from self._iter_df(df, "parquet", file) + + def _csv_file_iter(self, file: str) -> Iterable[Dict[str, Any]]: + # One worker owns this file (non-overlapping assignment) + for chunk in pd.read_csv(file, chunksize=self.csv_chunksize, dtype=str, keep_default_na=False): + yield from self._iter_df(chunk, "csv", file) + + # ---- main iterator ---- + def __iter__(self): + wi = get_worker_info() + num_workers = wi.num_workers if wi else 1 + worker_id = wi.id if wi else 0 + + num_global, gid = _dist_info() + if not self.shard_across_ranks: + num_global = max(1, num_workers) + gid = worker_id + + workers_per_rank = max(1, num_workers) + rank = gid // workers_per_rank if self.shard_across_ranks else 0 + world = max(1, num_global // workers_per_rank) + + # Each rank may have a non-zero per-rank resume skip. Split evenly across local + # dataloader workers so the sum equals the per-rank target, then apply a fast + # task-level skip to avoid row-by-row scans for huge cursors. 
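+        # Worked example (illustrative): per_rank_skip=10 with 3 local workers gives
+        # base=3, rem=1, so worker 0 skips 4 samples and workers 1-2 skip 3 each (10 total).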
+ per_rank_skip = int(self._resume_skip_n) + base = per_rank_skip // max(1, workers_per_rank) + rem = per_rank_skip % max(1, workers_per_rank) + local_skip_target = base + (1 if worker_id < rem else 0) + progress: Optional[_ResumeSkipProgress] = None + + # Build the global task list (parquet row groups + csv files) and shard by gid + tasks = self._enumerate_tasks(self.files) + + if tasks: + partitions = self._balanced_partition(tasks, max(1, num_global)) + my_tasks_full = partitions[gid] if gid < len(partitions) else [] + else: + my_tasks_full = [] + + if local_skip_target > 0 and worker_id == 0: + label = ( + "resume skip" if world == 1 else f"resume skip (rank {rank}/{world})" + ) + progress = _ResumeSkipProgress(local_skip_target, label) + + # Fast task-level skip: consume whole tasks when their weight is <= remaining skip + # and only fall back to row-level skipping for the first partial task. + skip_remaining = int(local_skip_target) + start_idx = 0 + partial_task_idx = None + partial_task_kind = None + partial_task_path = None + partial_task_rg = None + if skip_remaining > 0 and my_tasks_full: + for idx, (kind, path, rg, weight) in enumerate(my_tasks_full): + w = int(weight) if weight is not None else 0 + if w <= 0: + continue + if skip_remaining >= w: + skip_remaining -= w + start_idx = idx + 1 + if progress is not None: + progress.update(w) + else: + partial_task_idx = idx + partial_task_kind = kind + partial_task_path = path + partial_task_rg = rg + break + + # Slice my task list to start after any fully-skipped tasks + my_tasks = [(kind, path, rg) for (kind, path, rg, _w) in my_tasks_full[start_idx:]] + + rng = random.Random(self.seed + gid) + buffer: List[Dict[str, Any]] = [] + bufN = self.shuffle_buffer + + def _drain_buffer(): + if not buffer: + return + if bufN > 0: + rng.shuffle(buffer) + for it in buffer: + yield it + buffer.clear() + + # Skip counter for resume cursor (row-level remainder after task skips) + skipped = int(local_skip_target - skip_remaining) + + # Cache for per-file Parquet column selection + cols_cache: Dict[str, List[str]] = {} + + try: + # If we split a task, handle its partial row-level skip first + if partial_task_idx is not None and skip_remaining > 0: + kind = partial_task_kind + path = partial_task_path + rg = partial_task_rg + if kind == "parquet": + assert rg is not None + row_iter = self._parquet_rowgroup_iter(path, int(rg), cols_cache) + elif kind == "csv": + row_iter = self._csv_file_iter(path) + else: + raise ValueError(f"Unknown task kind: {kind}") + + for sample in row_iter: + if skip_remaining > 0: + skip_remaining -= 1 + skipped += 1 + if progress is not None: + progress.update(1) + if skip_remaining == 0 and progress is not None: + progress.close() + progress = None + continue + # past the partial skip remainder, fall through to normal buffering/yield + if bufN <= 0: + self._emitted += 1 + yield sample + else: + buffer.append(sample) + if len(buffer) >= bufN: + j = rng.randrange(len(buffer)) + buffer[j], buffer[-1] = buffer[-1], buffer[j] + self._emitted += 1 + yield buffer.pop() + + for (kind, path, rg) in my_tasks: + if kind == "parquet": + assert rg is not None + row_iter = self._parquet_rowgroup_iter(path, int(rg), cols_cache) + elif kind == "csv": + row_iter = self._csv_file_iter(path) + else: + raise ValueError(f"Unknown task kind: {kind}") + + for sample in row_iter: + # Apply any remaining resume skip across the flattened stream + if skip_remaining > 0: + skip_remaining -= 1 + skipped += 1 + if progress is not None: + 
progress.update(1) + if skip_remaining == 0 and progress is not None: + # Finish the progress bar once we've consumed the target + progress.close() + progress = None + continue + + if bufN <= 0: + self._emitted += 1 + yield sample + else: + buffer.append(sample) + if len(buffer) >= bufN: + j = rng.randrange(len(buffer)) + buffer[j], buffer[-1] = buffer[-1], buffer[j] + self._emitted += 1 + yield buffer.pop() + + # Flush leftovers + for it in _drain_buffer(): + self._emitted += 1 + yield it + finally: + if progress is not None: + progress.close() + if local_skip_target > 0: + # Persist any remaining leftover skip (including partial progress) per worker copy + self._resume_skip_n = max(local_skip_target - skipped, 0) + + +# ------------------------------ +# Simple collate: end-only pad for codon stream, pass-through everything else +# ------------------------------ + +def stage_collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]: + B = len(batch) + if B == 0: + return {} + + # species ids + species_ids = torch.tensor([int(x.get("species_id", 0)) for x in batch], dtype=torch.long) + + # raw protein sequences stay as list[str] (ESM handles tokenization) + protein_seqs = [str(x.get("protein_seq", "M")) for x in batch] + + # Build padded codon ids (right padding). Keep EOS inside the sequence (already appended in dataset). + codon_lists = [x.get("codon_ids", []) for x in batch] + max_len = max(len(c) for c in codon_lists) + pad_id = 0 # tokenizer.pad_token_id is 0 in our tokenizer. + codon_ids = torch.full((B, max_len), pad_id, dtype=torch.long) + for i, row in enumerate(codon_lists): + if len(row) > 0: + codon_ids[i, : len(row)] = torch.tensor(row, dtype=torch.long) + + out: Dict[str, Any] = { + "species_ids": species_ids, + "protein_seqs": protein_seqs, + "codon_ids": codon_ids, + "control_mode": batch[0].get("control_mode", "fixed"), + } + + # Optional passthroughs + if "refseq_id" in batch[0]: + out["refseq_id"] = [x.get("refseq_id") for x in batch] + if "protein_refseq_id" in batch[0]: + out["protein_refseq_id"] = [x.get("protein_refseq_id") for x in batch] + + return out + +def _build_dataset( + path_or_paths: str | List[str], + tokenizer, + species_vocab_path: str, + shuffle_buffer: int, + csv_chunksize: int, + shard_across_ranks: bool = True, +) -> StreamSeqDataset: + files = _expand_paths(path_or_paths) + return StreamSeqDataset( + files=files, + tokenizer=tokenizer, + species_vocab_path=species_vocab_path, + unknown_species_id=0, + csv_chunksize=csv_chunksize, + shuffle_buffer=shuffle_buffer, + seed=1234, + shard_across_ranks=shard_across_ranks, + ) + + +def create_precomputed_dataloaders( + train_path: str | List[str], + val_path: Optional[str | List[str]], + embeddings_dir: str, + tokenizer, + batch_size: int, + num_workers: int = 4, + species_pooling: str = "sequence", + csv_chunksize: int = 200_000, + train_shuffle_buffer: int = 8192, + val_shuffle_buffer: int = 0, +) -> Tuple[DataLoader, Optional[DataLoader], SpeciesEmbeddingStore]: + """ + Returns: + - train_loader, val_loader (optional), and the SpeciesEmbeddingStore + """ + species_store = SpeciesEmbeddingStore(embeddings_dir, pin_memory=True, pooling=species_pooling) + species_vocab_path = os.path.join(embeddings_dir, "species_vocab.json") + num_workers = int(max(0, num_workers)) + + train_ds = _build_dataset( + path_or_paths=train_path, + tokenizer=tokenizer, + species_vocab_path=species_vocab_path, + shuffle_buffer=int(train_shuffle_buffer), + csv_chunksize=int(csv_chunksize), + ) + val_ds = None + if val_path: + 
val_ds = _build_dataset( + path_or_paths=val_path, + tokenizer=tokenizer, + species_vocab_path=species_vocab_path, + shuffle_buffer=int(val_shuffle_buffer), + csv_chunksize=int(csv_chunksize), + ) + + # NOTE: IterableDataset can't be shuffled by DataLoader. We already "shuffle" inside the dataset. + kwargs_common = dict( + num_workers=num_workers, + collate_fn=stage_collate_fn, + pin_memory=True, + persistent_workers=(num_workers > 0), + ) + if num_workers > 0: + kwargs_common["prefetch_factor"] = 4 + + # Drop last for train to keep batch shapes stable under FSDP. + train_loader = DataLoader( + train_ds, + batch_size=batch_size, + shuffle=False, + drop_last=True, + **kwargs_common, + ) + + val_loader = None + if val_ds is not None: + val_loader = DataLoader( + val_ds, + batch_size=batch_size, + shuffle=False, + drop_last=False, + **kwargs_common, + ) + + return train_loader, val_loader, species_store diff --git a/src/layers.py b/src/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..e23b9b442c8b156bfa07afd7c93cb4087b0b0fd9 --- /dev/null +++ b/src/layers.py @@ -0,0 +1,384 @@ +""" +Transformer components for CodonGPT. +Includes RMSNorm, self-attention (SDPA/Flash) with optional mask, +cross-attention for conditioning memory, SwiGLU FFN, and a basic block. +""" + +import math +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.attention import SDPBackend, sdpa_kernel # Require recent PyTorch + + +class RMSNorm(nn.Module): + """Root Mean Square Layer Normalization.""" + + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Apply RMS normalization. + + Args: + x: Input tensor of any shape ending in dim + + Returns: + Normalized tensor of same shape + """ + norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + return x * norm * self.weight + + +def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: + """Apply rotary embeddings to x: [B,H,T,D]; cos/sin: [1,1,T,D].""" + x1 = x[..., ::2] + x2 = x[..., 1::2] + x_rot = torch.zeros_like(x) + x_rot[..., ::2] = -x2 + x_rot[..., 1::2] = x1 + return x * cos + x_rot * sin + + +class MultiHeadAttention(nn.Module): + """Self-attention using PyTorch SDPA kernels (Flash/MemEff/Math) + RoPE. 
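+
+    Illustrative KV-cache usage (names are placeholders):
+
+        y, kv = attn(x_prefix, return_kv=True)                       # prefill, causal
+        y_t, kv = attn(x_new, past_kv=kv, return_kv=True,
+                       position_offset=x_prefix.shape[1])            # decode one step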
+ - attn_mask: bool [B, T, T] with True = keep, False = block + - is_causal: whether to apply causal masking internally + """ + + def __init__( + self, + dim: int, + num_heads: int, + dropout: float = 0.0, + use_rope: bool = True, + ): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} must be divisible by num_heads {num_heads}" + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.dropout = dropout + self.use_rope = use_rope + + self.qkv = nn.Linear(dim, 3 * dim, bias=False) + self.out_proj = nn.Linear(dim, dim, bias=False) + self.resid_dropout = nn.Dropout(dropout) + + # RoPE cache + self._rope_cache: dict[tuple[int, torch.device, torch.dtype], tuple[torch.Tensor, torch.Tensor]] = {} + + def _rope_cos_sin(self, T: int, device: torch.device, dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]: + key = (T, device, dtype) + cached = self._rope_cache.get(key) + if cached is not None: + return cached + dim_half = self.head_dim // 2 + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_half, device=device, dtype=torch.float32) / dim_half)) + t = torch.arange(T, device=device, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + cos = torch.cos(freqs).repeat_interleave(2, dim=-1) + sin = torch.sin(freqs).repeat_interleave(2, dim=-1) + cos = cos.to(dtype).unsqueeze(0).unsqueeze(0) # [1,1,T,D] + sin = sin.to(dtype).unsqueeze(0).unsqueeze(0) + self._rope_cache[key] = (cos, sin) + return cos, sin + + def forward( + self, + x: torch.Tensor, + past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + return_kv: bool = False, + position_offset: int = 0, + ) -> "torch.Tensor | Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]": + """ + Self-attention with optional KV cache support. + + Args: + x: [B, T_new, H] + past_kv: Optional tuple (k, v), each [B, nH, T_past, Hd] + return_kv: If True, also return updated (k, v) + position_offset: Starting position index for RoPE (past length) + + Returns: + out or (out, present_kv) + """ + B, T_new, _ = x.shape + + # QKV projections and reshape (ensure contiguous for SDPA kernels) + qkv = self.qkv(x).view(B, T_new, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k_new, v_new = qkv[0].contiguous(), qkv[1].contiguous(), qkv[2].contiguous() + + # RoPE for new tokens only + if self.use_rope: + # Compute cos/sin up to (offset + T_new), then slice the tail for new positions + cos, sin = self._rope_cos_sin(position_offset + T_new, x.device, q.dtype) + if position_offset > 0: + cos = cos[:, :, position_offset: position_offset + T_new, :] + sin = sin[:, :, position_offset: position_offset + T_new, :] + # Apply to q and k_new + q = _apply_rope(q, cos, sin) + k_new = _apply_rope(k_new, cos, sin) + + # Concatenate with cache if provided + if past_kv is not None: + k_past, v_past = past_kv + k = torch.cat([k_past, k_new], dim=2) + v = torch.cat([v_past, v_new], dim=2) + is_causal = False # No future tokens present; avoid unnecessary masking + else: + k, v = k_new, v_new + is_causal = True + + # Prefer FlashAttention; fall back to MemEff then Math. Autocast to half/bfloat16 on CUDA. 
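+        # Note: only the Flash backend is listed below, so SDPA will raise where
+        # Flash is unavailable (e.g. CPU); on CUDA, fp32 inputs are autocast to
+        # fp16/bf16 first so a Flash kernel is likely to be eligible.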
+ backends = [SDPBackend.FLASH_ATTENTION]#, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH] + with sdpa_kernel(backends): + if x.device.type == "cuda" and q.dtype not in (torch.float16, torch.bfloat16): + amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + with torch.amp.autocast(device_type="cuda", dtype=amp_dtype): + out = F.scaled_dot_product_attention( + q, k, v, + dropout_p=self.dropout if self.training else 0.0, + is_causal=is_causal, + ) + else: + out = F.scaled_dot_product_attention( + q, k, v, + dropout_p=self.dropout if self.training else 0.0, + is_causal=is_causal, + ) + + out = out.transpose(1, 2).contiguous().view(B, T_new, self.dim) + # Align dtype with residual/Linear weights to avoid bf16/float mismatches + if out.dtype != x.dtype: + out = out.to(x.dtype) + out = self.out_proj(out) + out = self.resid_dropout(out) + + if return_kv: + return out, (k, v) + return out + + + +class GroupedQueryAttention(nn.Module): + """Grouped-Query Attention (GQA) using Flash Attention via PyTorch SDPA. + + - num_heads total query heads + - num_kv_groups shared K/V groups (num_heads must be divisible by num_kv_groups) + - Optional q/k RMSNorm + - Supports RoPE with a scalar or per-sample position_offset (like MHA) + - Optional KV cache compatible with the existing interface (stores expanded per-head K/V) + """ + + def __init__( + self, + dim: int, + num_heads: int, + num_kv_groups: int, + dropout: float = 0.0, + qk_norm: bool = False, + ) -> None: + super().__init__() + assert num_heads % max(1, num_kv_groups) == 0, "num_heads must be divisible by num_kv_groups" + self.dim = dim + self.num_heads = int(num_heads) + self.num_kv_groups = max(1, int(num_kv_groups)) + self.group_size = self.num_heads // self.num_kv_groups + + assert dim % num_heads == 0, "dim must be divisible by num_heads" + self.head_dim = dim // num_heads + self.dropout = dropout + + self.Wq = nn.Linear(dim, self.num_heads * self.head_dim, bias=False) + self.Wk = nn.Linear(dim, self.num_kv_groups * self.head_dim, bias=False) + self.Wv = nn.Linear(dim, self.num_kv_groups * self.head_dim, bias=False) + self.out_proj = nn.Linear(self.num_heads * self.head_dim, dim, bias=False) + + self.q_norm = RMSNorm(self.head_dim) if qk_norm else None + self.k_norm = RMSNorm(self.head_dim) if qk_norm else None + + # RoPE cache + self._rope_cache: dict[tuple[int, torch.device, torch.dtype], tuple[torch.Tensor, torch.Tensor]] = {} + + def _rope_cos_sin(self, T: int, device: torch.device, dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]: + key = (T, device, dtype) + cached = self._rope_cache.get(key) + if cached is not None: + return cached + dim_half = self.head_dim // 2 + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_half, device=device, dtype=torch.float32) / dim_half)) + t = torch.arange(T, device=device, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + cos = torch.cos(freqs).repeat_interleave(2, dim=-1) + sin = torch.sin(freqs).repeat_interleave(2, dim=-1) + cos = cos.to(dtype).unsqueeze(0).unsqueeze(0) # [1,1,T,D] + sin = sin.to(dtype).unsqueeze(0).unsqueeze(0) + self._rope_cache[key] = (cos, sin) + return cos, sin + + def forward( + self, + x: torch.Tensor, + past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + return_kv: bool = False, + position_offset: int | torch.Tensor = 0, + ) -> "torch.Tensor | Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]": + B, T_new, _ = x.shape + + # Project to Q, K, V + q = self.Wq(x).view(B, T_new, self.num_heads, self.head_dim).transpose(1, 
2).contiguous() # [B,H,T,Hd] + k = self.Wk(x).view(B, T_new, self.num_kv_groups, self.head_dim).transpose(1, 2).contiguous() # [B,G,T,Hd] + v = self.Wv(x).view(B, T_new, self.num_kv_groups, self.head_dim).transpose(1, 2).contiguous() # [B,G,T,Hd] + + # Optional RMSNorm on q/k + if self.q_norm is not None: + q = self.q_norm(q) + if self.k_norm is not None: + k = self.k_norm(k) + + # RoPE for new tokens only + if isinstance(position_offset, int): + cos, sin = self._rope_cos_sin(position_offset + T_new, x.device, q.dtype) + if position_offset > 0: + cos = cos[:, :, position_offset: position_offset + T_new, :] + sin = sin[:, :, position_offset: position_offset + T_new, :] + q = _apply_rope(q, cos, sin) + k = _apply_rope(k, cos, sin) + else: + off = position_offset.to(device=x.device, dtype=torch.long) + max_off = int(off.max().item()) + cos_all, sin_all = self._rope_cos_sin(max_off + T_new, x.device, q.dtype) + ar = torch.arange(T_new, device=x.device, dtype=torch.long) + idx = (off.unsqueeze(1) + ar.unsqueeze(0)) # [B, T_new] + cos_b = cos_all.squeeze(0).squeeze(0)[idx].unsqueeze(1) # [B,1,T,D] + sin_b = sin_all.squeeze(0).squeeze(0)[idx].unsqueeze(1) + q = _apply_rope(q, cos_b, sin_b) + # k has groups dimension [B,G,T,D]; share same offsets per batch + k = _apply_rope(k, cos_b, sin_b) + + # Expand grouped K/V to per-head by repeating groups + if self.group_size > 1: + k_exp = k.repeat_interleave(self.group_size, dim=1) # [B,H,T,Hd] + v_exp = v.repeat_interleave(self.group_size, dim=1) # [B,H,T,Hd] + else: + k_exp, v_exp = k, v # already per-head + + # KV cache: concatenate past along sequence dim + if past_kv is not None: + k_past, v_past = past_kv + k_cat = torch.cat([k_past, k_exp], dim=2) + v_cat = torch.cat([v_past, v_exp], dim=2) + is_causal = False + else: + k_cat, v_cat = k_exp, v_exp + is_causal = True + + # Prefer FlashAttention; fall back to MemEff/Math. Ensure CUDA autocast to half/bfloat16 so kernels are available + with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]): + if x.device.type == "cuda" and q.dtype not in (torch.float16, torch.bfloat16): + amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + with torch.amp.autocast(device_type="cuda", dtype=amp_dtype): + out = torch.nn.functional.scaled_dot_product_attention( + q, k_cat, v_cat, + dropout_p=self.dropout if self.training else 0.0, + is_causal=is_causal, + ) # [B,H,T,Hd] + else: + out = torch.nn.functional.scaled_dot_product_attention( + q, k_cat, v_cat, + dropout_p=self.dropout if self.training else 0.0, + is_causal=is_causal, + ) # [B,H,T,Hd] + + out = out.transpose(1, 2).contiguous().view(B, T_new, self.num_heads * self.head_dim) + # Ensure dtype compatibility for Linear / residual path + if out.dtype != x.dtype: + out = out.to(x.dtype) + out = self.out_proj(out) + + if return_kv: + return out, (k_cat, v_cat) + return out + + + +class FeedForward(nn.Module): + """Feed-forward network with optional GLU activation.""" + + def __init__( + self, + dim: int, + hidden_dim: int, + dropout: float = 0.0, + ): + super().__init__() + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + self.dropout = nn.Dropout(dropout) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Apply feed-forward network. 
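+        This is the SwiGLU form: w2(silu(w1(x)) * w3(x)).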
+ + Args: + x: Input tensor [B, T, dim] + + Returns: + Output tensor [B, T, dim] + """ + + return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x))) + + +class TransformerBlock(nn.Module): + """Pre-norm Transformer block using self-attn + SwiGLU FFN (no cross-attention).""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + dropout: float = 0.0, + num_kv_groups: int | None = None, + qk_norm: bool = False, + attn_type: str = "gqa", # "gqa" or "mha" + ): + super().__init__() + self.norm1 = RMSNorm(dim) + if attn_type == "mha": + self.attn = MultiHeadAttention(dim=dim, num_heads=num_heads, dropout=dropout) + self._attn_is_gqa = False + else: + # Use Grouped-Query Attention (defaults to no grouping when num_kv_groups is None) + kv_groups = num_heads if (num_kv_groups is None) else max(1, int(num_kv_groups)) + self.attn = GroupedQueryAttention(dim=dim, num_heads=num_heads, num_kv_groups=kv_groups, dropout=dropout, qk_norm=qk_norm) + self._attn_is_gqa = True + self.norm2 = RMSNorm(dim) + self.ffn = FeedForward(dim=dim, hidden_dim=int(dim * mlp_ratio), dropout=dropout) + + def forward( + self, + x: torch.Tensor, + past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + position_offset: int = 0, + ) -> "torch.Tensor | Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]": + """Forward pass with optional KV caching.""" + if use_cache or (past_kv is not None): + attn_out = self.attn(self.norm1(x), past_kv=past_kv, return_kv=True, position_offset=position_offset) + x = x + attn_out[0] + x = x + self.ffn(self.norm2(x)) + return x, attn_out[1] + else: + x = x + self.attn(self.norm1(x)) + x = x + self.ffn(self.norm2(x)) + return x diff --git a/src/models.py b/src/models.py new file mode 100644 index 0000000000000000000000000000000000000000..64932b66797e9b12afa1a4468c48121993056d7d --- /dev/null +++ b/src/models.py @@ -0,0 +1,490 @@ +""" +Core model architectures for CodonGPT (GPT-only). +- CodonGPT: Decoder-only GPT with two-species + protein prefix +Includes a frozen ESM-C encoder for protein conditioning. +""" + +import math +import os +from typing import Optional, Dict, Any, Tuple, List +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +import torch.nn.utils.rnn as rnn_utils + +from .layers import RMSNorm, TransformerBlock +from .tokenizer import SpecialIds + + +class FrozenESMCEncoder(nn.Module): + """ + Frozen ESM-C encoder that computes protein embeddings on the fly. + Kept on single GPU per rank (not distributed via FSDP). 
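+
+    Typical flow (a sketch; the protein string is illustrative):
+        ids, mask = encoder.tokenize(["MKTAYIA"])
+        out = encoder.encode_from_ids(ids, mask)      # {"embeddings": [B, L, D_esm], ...}
+        emb, lens = encoder.strip_special_tokens(out["embeddings"], mask)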
+ """ + + def __init__(self, model_name: str = "esmc_300m", device: str = "cuda", dtype: str = "fp16"): + super().__init__() + self.model_name = model_name + self._device = torch.device(device if torch.cuda.is_available() else "cpu") + if dtype == "fp16": + self._autocast_dtype = torch.float16 + elif dtype == "bf16": + self._autocast_dtype = torch.bfloat16 + else: + self._autocast_dtype = None + self._load_model() + self.eval() + for p in self.parameters(): + p.requires_grad_(False) + + def _load_model(self): + from esm.models.esmc import ESMC + from esm.utils.constants.models import ESMC_300M, ESMC_600M + if self.model_name == "esmc_300m": + model_const = ESMC_300M + self.D_esm = 960 + elif self.model_name == "esmc_600m": + model_const = ESMC_600M + self.D_esm = 1152 + else: + raise ValueError(f"Unknown model: {self.model_name}") + self.model = ESMC.from_pretrained(model_name=model_const, device=self._device) + self.tokenizer = self.model.tokenizer + + @torch.no_grad() + def tokenize(self, sequences: List[str], max_length: Optional[int] = None, add_special_tokens: bool = True, return_tensors: str = "pt"): + from esm.utils import encoding + from esm.utils.misc import stack_variable_length_tensors + pad = self.tokenizer.pad_token_id + tokenized_seqs = [] + for seq in sequences: + tokens = encoding.tokenize_sequence(seq, self.tokenizer, add_special_tokens=add_special_tokens) + if max_length is not None and len(tokens) > max_length: + tokens = tokens[:max_length] + tokenized_seqs.append(tokens) + input_ids = stack_variable_length_tensors(tokenized_seqs, constant_value=pad) + attention_mask = (input_ids != pad) + return input_ids, attention_mask + + @torch.no_grad() + def encode_from_ids(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.BoolTensor] = None, return_dict: bool = True, return_contacts: bool = False): + device = self.model.device + input_ids = input_ids.to(device) + if attention_mask is not None: + attention_mask = attention_mask.to(device) + if self._autocast_dtype is not None and device.type == "cuda": + with torch.amp.autocast('cuda', dtype=self._autocast_dtype): + outputs = self.model.forward(sequence_tokens=input_ids, sequence_id=attention_mask) + else: + outputs = self.model.forward(sequence_tokens=input_ids, sequence_id=attention_mask) + embeddings = outputs.embeddings + if return_dict: + return {"embeddings": embeddings, "attention_mask": attention_mask} + else: + return embeddings + + def strip_special_tokens(self, embeddings: torch.FloatTensor, attention_mask: Optional[torch.BoolTensor] = None): + if attention_mask is not None: + lengths = attention_mask.sum(dim=1) - 2 + lengths = lengths.clamp(min=1) + else: + B, L, D = embeddings.shape + lengths = torch.full((B,), L - 2, device=embeddings.device) + stripped = embeddings[:, 1:-1, :] + return stripped, lengths + + + + +class CodonGPT(nn.Module): + def __init__( + self, + vocab_size: int = 79, + hidden_size: int = 960, + num_layers: int = 24, + num_heads: int = 16, + mlp_ratio: float = 4.0, + max_position_embeddings: int = 4096, + dropout: float = 0.1, + layer_norm_eps: float = 1e-6, + num_special_tokens: int = 13, + special_ids: Optional[SpecialIds] = None, + esm_model_name: str = "esmc_300m", + esm_device: str = "cuda", + esm_dtype: str = "fp16", + max_protein_prefix: int = 0, + max_species_prefix: int = 0, + prepend_species: bool = True, + prepend_protein: bool = True, + species_embedding_dim: int = 1024, + attn_impl: str = "gqa", # "gqa" or "mha" + num_kv_groups: int = 0, # for GQA; 0 means default (no 
grouping) + ): + super().__init__() + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_layers = num_layers + self.num_heads = num_heads + self.max_position_embeddings = max_position_embeddings + + self.special_ids = special_ids or SpecialIds() + self.num_special_tokens = num_special_tokens + + # Single embedding table for all tokens (special + codon) + self.token_embed = nn.Embedding(vocab_size, hidden_size) + + if prepend_protein and esm_model_name: + self.esm = FrozenESMCEncoder(esm_model_name, esm_device, esm_dtype) + # Project ESM token embeddings (D_esm) to model hidden size, then normalize + self.esm_ln = nn.Sequential( + nn.Linear(self.esm.D_esm, hidden_size, bias=False), + nn.ReLU(), + nn.LayerNorm(hidden_size), + ) + else: + self.esm = None + self.esm_ln = None + + self.species_embedding_dim = species_embedding_dim if prepend_species else 0 + if prepend_species: + # Project species embeddings (fixed or token sequence) from Ds -> H + self.species_ln = nn.Sequential( + nn.Linear(self.species_embedding_dim, hidden_size, bias=False), + nn.ReLU(), + nn.LayerNorm(hidden_size), + ) + else: + self.species_ln = None + + # Optional per-prefix caps; 0 means unlimited (subject to global max length) + self.max_protein_prefix = int(max_protein_prefix) if max_protein_prefix is not None else 0 + self.max_species_prefix = int(max_species_prefix) if max_species_prefix is not None else 0 + self.prepend_species = bool(prepend_species) + self.prepend_protein = bool(prepend_protein) + + # Learned start embedding (BOS-less decoding) + self.start_embed = nn.Parameter(torch.zeros(1, 1, hidden_size)) + nn.init.normal_(self.start_embed, mean=0.0, std=0.02) + + + # Attention configuration + self.attn_impl = str(attn_impl) + self.num_kv_groups = int(num_kv_groups) + kv_groups = self.num_kv_groups + self.blocks = nn.ModuleList([ + TransformerBlock( + dim=hidden_size, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + dropout=dropout, + num_kv_groups=(kv_groups if (kv_groups > 0 and attn_impl == "gqa") else None), + qk_norm=False, + attn_type=("mha" if self.attn_impl == "mha" else "gqa"), + ) for _ in range(num_layers) + ]) + + self.ln_f = RMSNorm(hidden_size, eps=layer_norm_eps) + self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False) + self.gradient_checkpointing = False + + def embed_tokens(self, token_ids: torch.Tensor) -> torch.Tensor: + device = self.token_embed.weight.device + return self.token_embed(token_ids.to(device)) + + def build_prefix( + self, + batch_size: int, + device: torch.device, + species_tok_emb: Optional[torch.Tensor] = None, + species_emb: Optional[torch.Tensor] = None, + protein_input: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + species_tok_emb_src: Optional[torch.Tensor] = None, + species_tok_emb_tgt: Optional[torch.Tensor] = None, + species_emb_src: Optional[torch.Tensor] = None, + species_emb_tgt: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Build LLaVA-style prefix token embeddings by concatenating + [species_src]+[species_tgt]+[protein_tokens]. 
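+        Each part is projected to the model hidden size H before concatenation;
+        parts that are not provided are simply skipped.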
Returns: + - prefix: [B, Lp, H] + - prefix_lengths: [B] valid token counts per sample + """ + parts: list[torch.Tensor] = [] + + # Species: src then tgt (if provided) + if self.prepend_species and self.species_ln is not None: + tok_src = species_tok_emb_src if species_tok_emb_src is not None else species_tok_emb + tok_tgt = species_tok_emb_tgt if species_tok_emb_tgt is not None else species_tok_emb + emb_src = species_emb_src if species_emb_src is not None else species_emb + emb_tgt = species_emb_tgt if species_emb_tgt is not None else species_emb + + def _as_tokens(S_tok, S_fix): + if S_fix is not None: + # [B, Ds] -> [B, 1, H] + S = self.species_ln(S_fix.to(device=device, dtype=next(self.parameters()).dtype).unsqueeze(1)) + return S + elif S_tok is not None: + # [B, Ls, Ds] -> optional cap, then project to H + S = S_tok + if getattr(self, "max_species_prefix", 0) > 0 and S.size(1) > self.max_species_prefix: + S = S[:, : self.max_species_prefix, :] + S = S.to(device=device, dtype=next(self.parameters()).dtype) + S = self.species_ln(S) + return S + else: + return None + + Ssrc = _as_tokens(tok_src, emb_src) + if Ssrc is not None: + parts.append(Ssrc) + Sdst = _as_tokens(tok_tgt, emb_tgt) + if Sdst is not None: + parts.append(Sdst) + + # Protein tokens from ESM-C + if self.prepend_protein and self.esm is not None and protein_input is not None: + prot_ids, prot_mask = protein_input + esm_out = self.esm.encode_from_ids(prot_ids, prot_mask, return_dict=True) + P, lengths = self.esm.strip_special_tokens(esm_out["embeddings"], prot_mask) + # Optional per-protein capping before projection + if getattr(self, "max_protein_prefix", 0) > 0 and P.size(1) > self.max_protein_prefix: + P = P[:, : self.max_protein_prefix, :] + if lengths is not None: + lengths = lengths.clamp(max=self.max_protein_prefix) + if P.size(1) > 0: + P = self.esm_ln(P.to(device=device, dtype=next(self.parameters()).dtype)) + # Zero padded rows (per-sample) based on lengths + if lengths is not None: + Lp = P.size(1) + ar = torch.arange(Lp, device=device).unsqueeze(0) + lengths = lengths.to(device=device) + valid = ar < lengths.unsqueeze(1) # [B,Lp] + P = P * valid.unsqueeze(-1) + parts.append(P) + + if len(parts) == 0: + empty = torch.zeros(batch_size, 0, self.hidden_size, device=device, dtype=next(self.parameters()).dtype) + return empty, torch.zeros(batch_size, dtype=torch.long, device=device) + + prefix = torch.cat(parts, dim=1) if parts else torch.zeros(batch_size, 0, self.hidden_size, device=device, dtype=next(self.parameters()).dtype) # [B,Lp,H] + # Compute per-sample valid lengths: treat zero rows as padding + with torch.no_grad(): + if prefix.size(1) > 0: + valid = (prefix.abs().sum(dim=-1) > 0) + lengths = valid.sum(dim=1).to(torch.long) + else: + lengths = torch.zeros(batch_size, dtype=torch.long, device=device) + + # ---- Enforce hard global budget on the prefix itself ---- + prefix_budget = max(0, int(self.max_position_embeddings) - 1) + if prefix_budget == 0: + trimmed = prefix.new_zeros(prefix.size(0), 0, prefix.size(2)) + return trimmed, torch.zeros(prefix.size(0), dtype=torch.long, device=prefix.device) + + allow = torch.minimum(lengths, torch.tensor(prefix_budget, device=lengths.device, dtype=lengths.dtype)) + Lp_max = int(allow.max().item()) if allow.numel() > 0 else 0 + if prefix.size(1) > Lp_max: + trimmed = prefix.new_zeros(prefix.size(0), Lp_max, prefix.size(2)) + for b in range(prefix.size(0)): + lb = int(allow[b].item()) + if lb > 0: + trimmed[b, :lb, :] = prefix[b, :lb, :] + prefix = trimmed + lengths = 
allow + else: + lengths = allow + return prefix, lengths + + def forward( + self, + codon_ids: torch.Tensor, + cond: Dict[str, Any] = None, + labels: Optional[torch.Tensor] = None, + return_dict: bool = True, + species_tok_emb: Optional[torch.Tensor] = None, + protein_input: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + protein_seqs: Optional[List[str]] = None, + # KV cache options + use_cache: bool = False, + past_kv: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, + position_offset: int = 0, + ) -> Dict[str, torch.Tensor]: + batch_size, codon_len = codon_ids.shape + device = codon_ids.device + + # Unpack conditioning + if cond is not None: + control_mode = cond.get("control_mode", "fixed") + species_tok_emb_src = cond.get("species_tok_emb_src") + species_tok_emb_tgt = cond.get("species_tok_emb_tgt") + species_emb_src = cond.get("species_emb_src") + species_emb_tgt = cond.get("species_emb_tgt") + species_tok_emb = cond.get("species_tok_emb") + species_emb = cond.get("species_emb") + protein_input = cond.get("protein_input") + protein_seqs = cond.get("protein_seqs") + else: + species_emb = None + species_tok_emb_src = None + species_tok_emb_tgt = None + species_emb_src = None + species_emb_tgt = None + + if protein_seqs is not None and protein_input is None: + if self.esm is not None: + with torch.no_grad(): + # Respect per-protein ceiling during tokenization (+2 for BOS/EOS) + max_len_tokens = (self.max_protein_prefix + 2) if (getattr(self, "max_protein_prefix", 0) > 0) else None + protein_input = self.esm.tokenize(protein_seqs, max_length=max_len_tokens) + else: + protein_input = None + + # Fast path: incremental decode using KV cache + if past_kv is not None: + # Expect only newly generated codon tokens here + if codon_ids.numel() == 0: + # Nothing to do; return a dummy next_logits + dummy = torch.zeros(batch_size, self.vocab_size, device=device, dtype=self.lm_head.weight.dtype) + return {"logits": dummy[:, 0:0], "next_logits": dummy} + + x = self.embed_tokens(codon_ids) # [B, T_new, H] + + present_kv: List[Tuple[torch.Tensor, torch.Tensor]] = [] + for i, block in enumerate(self.blocks): + kv_i = past_kv[i] if i < len(past_kv) else None + if self.training and getattr(self, 'gradient_checkpointing', False): + def _fn(inp): + return block(inp, past_kv=kv_i, use_cache=True, position_offset=position_offset) + out_blk = checkpoint.checkpoint(_fn, x, use_reentrant=False) + else: + out_blk = block(x, past_kv=kv_i, use_cache=True, position_offset=position_offset) + x, kv_out = out_blk # type: ignore[assignment] + present_kv.append(kv_out) + + x = self.ln_f(x) + logits_step = self.lm_head(x) # [B, T_new, V] + next_logits = logits_step[:, -1, :] + out: Dict[str, torch.Tensor] = {"logits": logits_step[:, 0:0, :], "next_logits": next_logits} + out["present_kv"] = present_kv # type: ignore[assignment] + return out if return_dict else logits_step[:, 0:0, :] + + # Standard path: build prefix and full window (training or prefill) + prefix, prefix_lengths = self.build_prefix( + batch_size=batch_size, + device=device, + species_tok_emb=species_tok_emb, + species_emb=species_emb if cond is not None else None, + protein_input=protein_input, + species_tok_emb_src=species_tok_emb_src, + species_tok_emb_tgt=species_tok_emb_tgt, + species_emb_src=species_emb_src, + species_emb_tgt=species_emb_tgt, + ) + + start = self.start_embed.expand(batch_size, 1, self.hidden_size) # [B,1,H] + + # Per-sample true codon input lengths (exclude PADs) + pad_id = int(self.special_ids.pad) if hasattr(self, 
"special_ids") and self.special_ids is not None else 0 + codon_mask = (codon_ids != pad_id) # [B, N] + codon_lens = codon_mask.sum(dim=1) # [B] + + # Budget remaining after prefix + start + capacity = max(0, int(self.max_position_embeddings)) + budget_after_prefix = torch.clamp( + torch.as_tensor(capacity, device=device) - (prefix_lengths + 1), + min=0, + ) # [B] + # Per-sample cap is limited by both budget and available codons + per_cap = torch.minimum(budget_after_prefix, codon_lens) # [B] + + # Total valid lengths per sample (prefix + start + capped codon) + valid_lengths = prefix_lengths + 1 + per_cap + T = int(valid_lengths.max().item()) if valid_lengths.numel() > 0 else (1 + int(codon_lens.max().item()) if codon_lens.numel() > 0 else 1) + + # Embed only the needed codon window for this batch + max_cap = int(per_cap.max().item()) if per_cap.numel() > 0 else 0 + if max_cap > 0: + codon_emb = self.embed_tokens(codon_ids[:, :max_cap]) # [B, max_cap, H] + else: + codon_emb = torch.zeros(batch_size, 0, self.hidden_size, device=device, dtype=start.dtype) + + # Build sequence per-sample using concat to preserve gradients, then pad + seqs = [] + for b in range(batch_size): + lp = int(prefix_lengths[b].item()) + cap = int(per_cap[b].item()) + parts = [] + if lp > 0: + parts.append(prefix[b, :lp, :]) + parts.append(start[b, 0:1, :]) + if cap > 0: + parts.append(codon_emb[b, :cap, :]) + seqs.append(torch.cat(parts, dim=0)) # [Lb, H] + x = rnn_utils.pad_sequence(seqs, batch_first=True) # [B, T, H] + + present_kv_list: List[Tuple[torch.Tensor, torch.Tensor]] = [] + for block in self.blocks: + if self.training and getattr(self, 'gradient_checkpointing', False): + def _fn(inp): + return block(inp, use_cache=use_cache, position_offset=0) + blk_out = checkpoint.checkpoint(_fn, x, use_reentrant=False) + else: + blk_out = block(x, use_cache=use_cache, position_offset=0) + if use_cache: + x, kv = blk_out # type: ignore[misc] + present_kv_list.append(kv) + else: + x = blk_out # type: ignore[assignment] + + x = self.ln_f(x) + logits_full = self.lm_head(x) # [B, T, V] + + # Gather codon-aligned logits per sample: positions (lp+1) .. 
(lp+cap) (skip start) + next_logits_list = [] + if max_cap == 0: + # Keep graph by slicing from logits_full + codon_logits = logits_full[:, 0:0, :] + for b in range(batch_size): + lp = int(prefix_lengths[b].item()) + # Last consumed position is the start token at index lp + pos_next = lp + if pos_next < logits_full.size(1): + next_logits_list.append(logits_full[b, pos_next, :]) + else: + next_logits_list.append(logits_full[b, -1, :]) + next_logits = torch.stack(next_logits_list, dim=0) + else: + slices = [] + for b in range(batch_size): + lp = int(prefix_lengths[b].item()) + cap = int(per_cap[b].item()) + # Skip the start position so logits align with labels = codon_ids[:, 1:] + sl = logits_full[b, lp : lp + cap, :] if cap > 0 else logits_full.new_zeros(0, self.vocab_size) + slices.append(sl) + # Next-token logits after processing 'cap' codons: last consumed is at lp + cap + pos_next = lp + cap + next_logits_list.append(logits_full[b, pos_next, :] if pos_next < logits_full.size(1) else logits_full.new_zeros(self.vocab_size)) + codon_logits = rnn_utils.pad_sequence(slices, batch_first=True) # [B,max_cap,V] + next_logits = torch.stack(next_logits_list, dim=0) + out = {"logits": codon_logits, "next_logits": next_logits} + + if labels is not None: + # Align labels to per-sample caps: mask out positions >= cap + if labels.size(1) > 0 and max_cap > 0: + # Build masked labels with -100 beyond cap per sample + adj = labels.new_full((batch_size, max_cap), -100) + for b in range(batch_size): + cap = int(per_cap[b].item()) + if cap > 0: + Lb = min(cap, labels.size(1)) + adj[b, :Lb] = labels[b, :Lb] + loss = F.cross_entropy(codon_logits.reshape(-1, self.vocab_size), adj.reshape(-1), ignore_index=-100) + else: + loss = codon_logits.sum() * 0.0 + out["loss"] = loss + # Provide optional debug stats for trainer logging + out["prefix_len"] = prefix_lengths.detach() + out["per_cap"] = per_cap.detach() + if use_cache: + out["present_kv"] = present_kv_list # type: ignore[assignment] + return out if return_dict else codon_logits diff --git a/src/sampler.py b/src/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..28583045a8dea85a1af69cc0bf8a9d7995043c67 --- /dev/null +++ b/src/sampler.py @@ -0,0 +1,696 @@ +# src/sampler.py +""" +Sampling utilities for CodonGPT. 
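+
+Provides top-k / top-p logit filtering helpers, the CodonSampler class
+(checkpoint loading plus autoregressive decoding), and a sample_sequences()
+convenience wrapper.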
+ +Conditioning invariants: +- Species context: fixed-size [B, Ds] via species_emb or variable-length [B, Ls, Ds] via species_tok_emb +- Protein context: raw sequences; the model's Frozen ESM handles tokenization +""" + +from __future__ import annotations +from typing import List, Optional, Dict, Union, Tuple +from pathlib import Path +import logging +import json + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from safetensors.torch import load_file + +from .models import CodonGPT +from .tokenizer import CodonTokenizer + +logger = logging.getLogger(__name__) + + +# ---------------------------- +# Logit filtering +# ---------------------------- + +def _ensure_2d_logits(logits: torch.Tensor) -> torch.Tensor: + return logits if logits.dim() == 2 else logits.unsqueeze(0) + +def _top_k_filtering(logits: torch.Tensor, k: int) -> torch.Tensor: + """Top-k filtering; logits is [B,V] or [V].""" + x = _ensure_2d_logits(logits) + k = max(1, min(int(k), x.size(-1))) + values, _ = torch.topk(x, k, dim=-1) + min_values = values[:, -1].unsqueeze(-1) + x = torch.where(x < min_values, torch.full_like(x, float('-inf')), x) + return x if logits.dim() == 2 else x.squeeze(0) + +def _top_p_filtering(logits: torch.Tensor, p: float) -> torch.Tensor: + """Top-p (nucleus) filtering; logits is [B,V] or [V].""" + if p >= 1.0: + return logits + if p <= 0.0: + # You asked for nothing; enjoy the abyss. + return torch.full_like(logits, float('-inf')) + x = _ensure_2d_logits(logits) + sorted_logits, sorted_indices = torch.sort(x, descending=True, dim=-1) + probs = F.softmax(sorted_logits, dim=-1) + cumprobs = torch.cumsum(probs, dim=-1) + to_remove = cumprobs > p + to_remove[:, 1:] = to_remove[:, :-1].clone() + to_remove[:, 0] = False + mask = torch.zeros_like(x, dtype=torch.bool).scatter(-1, sorted_indices, to_remove) + x = torch.where(mask, torch.full_like(x, float('-inf')), x) + return x if logits.dim() == 2 else x.squeeze(0) + + +# ---------------------------- +# Sampler +# ---------------------------- + +class CodonSampler: + """ + GPT sampler with conditional generation. 
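+
+    Example (a sketch; the checkpoint directory and conditioning values are
+    illustrative):
+
+        sampler = CodonSampler("runs/codon_gpt", device="cuda")
+        out = sampler.sample(
+            num_sequences=2,
+            sequence_length=10,
+            species="Homo sapiens",
+            protein_sequences="MKTAYIAKQR",
+            control_mode="fixed",
+            top_p=0.9,
+        )
+        dna_strings = out["sequences"]  # list[str], one DNA string per sample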
+ + Requires in model_dir: + - vocab.json + - model.safetensors (preferred) + or pytorch_model.bin (legacy) + - trainer_config.json or config.json + """ + + def __init__( + self, + model_path: str, + device: str = "cuda", + species_store=None, # SpeciesEmbeddingStore + compile_model: bool = False, + taxonomy_db_path: Optional[str] = None, + qwen_max_length: int = 512, + qwen_batch_size: int = 16, + **_: dict, + ): + self.device = torch.device(device) + self.model_dir = Path(model_path) + + # Required files (allow fallback to parent dir for vocab.json) + vocab_path = self.model_dir / "vocab.json" + if not vocab_path.exists(): + parent_vocab = self.model_dir.parent / "vocab.json" + if parent_vocab.exists(): + vocab_path = parent_vocab + else: + raise FileNotFoundError(f"Missing {self.model_dir / 'vocab.json'}") + trainer_cfg = self.model_dir / "trainer_config.json" + cfg_path = trainer_cfg if trainer_cfg.exists() else (self.model_dir / "config.json") + if not cfg_path.exists(): + raise FileNotFoundError(f"Missing trainer_config.json or config.json in {self.model_dir}") + + # Load config + with open(cfg_path, "r") as f: + self.config = json.load(f) + + # Tokenizer + # If vocab was loaded from parent dir, pass that path; else model_dir + vocab_dir = vocab_path.parent + self.tokenizer = CodonTokenizer.from_pretrained(str(vocab_dir)) + self.V = int(self.tokenizer.vocab_size) + self._eos_id = int(self.tokenizer.eos_token_id) + self._pad_id = int(self.tokenizer.pad_token_id) + self._num_special = int(self.tokenizer.num_special_tokens) + + # Species store (optional if you pass species_emb* directly at sample()) + self.species_store = species_store + self.species_vocab = (self.species_store.vocab if self.species_store is not None else {}) + self.taxonomy_db_path = taxonomy_db_path + self.qwen_opts = { + "max_length": int(qwen_max_length), + "batch_size": int(qwen_batch_size), + } + # Lazy-inited Qwen objects + self._qwen_tokenizer = None + self._qwen_model = None + + # Model + state = self._load_state_dict() + arch = self._infer_arch_from_state_dict(state) + self.model = CodonGPT( + vocab_size=self.V, + hidden_size=int(arch["hidden_size"]), + num_layers=int(arch["num_layers"]), + num_heads=int(arch["num_heads"]), + mlp_ratio=float(arch["mlp_ratio"]), + max_position_embeddings=int(arch["max_position_embeddings"]), + dropout=float(self.config.get("dropout", 0.1)), + num_special_tokens=self._num_special, + special_ids=self.tokenizer.special_ids, + esm_model_name=str(arch["esm_model_name"]) if bool(arch["prepend_protein"]) else None, + esm_device=str(arch["esm_device"]), + esm_dtype=str(arch["esm_dtype"]), + max_protein_prefix=int(arch["max_protein_prefix"]) if bool(arch["prepend_protein"]) else 0, + max_species_prefix=int(arch["max_species_prefix"]) if bool(arch["prepend_species"]) else 0, + prepend_species=bool(arch["prepend_species"]), + prepend_protein=bool(arch["prepend_protein"]), + species_embedding_dim=int(self.config.get("species_embedding_dim", 1024)), + attn_impl=str(arch.get("attn_impl", "gqa")), + num_kv_groups=int(arch.get("num_kv_groups", 0)), + ) + missing, unexpected = self.model.load_state_dict(state, strict=False) + if len(unexpected) > 0: + logger.warning(f"Unexpected keys in state dict: {unexpected[:10]}{'...' if len(unexpected) > 10 else ''}") + if len(missing) > 0: + logger.warning(f"Missing keys in state dict: {missing[:10]}{'...' if len(missing) > 10 else ''}") + + if compile_model: + # If this errors on your PyTorch build, that's on you. No try/except. 
+ self.model = torch.compile(self.model) # type: ignore + + self.model.to(self.device).eval() + logger.info(f"Loaded GPT model from {self.model_dir}") + try: + hs = int(getattr(self.model, "hidden_size", -1)) + hh = int(getattr(self.model, "num_heads", -1)) + nl = int(getattr(self.model, "num_layers", -1)) + logger.info(f"Reconstructed arch: hidden={hs} heads={hh} layers={nl}") + except Exception: + pass + + # Static masks + self._allowed_fixed = torch.ones(self.V, dtype=torch.bool, device=self.device) + self._allowed_fixed[:self._num_special] = False # no specials in fixed mode + + self._allowed_variable = torch.ones(self.V, dtype=torch.bool, device=self.device) + self._allowed_variable[:self._num_special] = False + self._allowed_variable[self._eos_id] = True # EOS allowed in variable mode + + # ---------------------------- + # Loading / arch inference + # ---------------------------- + + def _load_state_dict(self) -> Dict[str, torch.Tensor]: + st_p = self.model_dir / "model.safetensors" + pt_p = self.model_dir / "pytorch_model.bin" + if st_p.exists(): + return load_file(st_p) + if pt_p.exists(): + return torch.load(pt_p, map_location="cpu") + raise FileNotFoundError(f"No model.safetensors or pytorch_model.bin in {self.model_dir}") + + def _infer_arch_from_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, Union[int, float, bool, str]]: + arch: Dict[str, Union[int, float, bool, str]] = {} + + # hidden size + if "lm_head.weight" in state_dict: + arch["hidden_size"] = int(state_dict["lm_head.weight"].shape[1]) + else: + for k, v in state_dict.items(): + if k.endswith("ln_f.weight"): + arch["hidden_size"] = int(v.shape[0]) + break + # Prefer config when present to avoid guessing errors + cfg = self.config or {} + if "hidden_size" in cfg: + arch["hidden_size"] = int(cfg["hidden_size"]) # type: ignore[index] + if "hidden_size" not in arch: + arch["hidden_size"] = int(cfg.get("hidden_size", 960)) + H = int(arch["hidden_size"]) + + # layers + max_block = -1 + for k in state_dict.keys(): + if k.startswith("blocks."): + idx = int(k.split(".")[1]) + if idx > max_block: + max_block = idx + arch["num_layers"] = (max_block + 1) if max_block >= 0 else int(cfg.get("num_hidden_layers", 12)) + if "num_hidden_layers" in cfg: + arch["num_layers"] = int(cfg["num_hidden_layers"]) # type: ignore[index] + + # mlp ratio from w1 + w1_key = "blocks.0.ffn.w1.weight" if "blocks.0.ffn.w1.weight" in state_dict else None + if w1_key is None: + for i in range(1, 3): + k = f"blocks.{i}.ffn.w1.weight" + if k in state_dict: + w1_key = k + break + if w1_key is not None and H > 0: + arch["mlp_ratio"] = float(int(state_dict[w1_key].shape[0]) / H) + else: + arch["mlp_ratio"] = float(cfg.get("mlp_ratio", 4.0)) + + # heads – pick a divisor of H + cfg_heads = cfg.get("num_attention_heads") + if isinstance(cfg_heads, int) and cfg_heads > 0 and H % cfg_heads == 0: + arch["num_heads"] = int(cfg_heads) + else: + for h in (16, 15, 12, 10, 8, 6, 5, 4, 3, 2, 1): + if H % h == 0: + arch["num_heads"] = h + break + + # conditioning flags from presence of submodules + arch["prepend_species"] = bool(cfg.get("prepend_species", any(k.startswith("species_ln.") for k in state_dict.keys()))) + has_esm = any(k.startswith("esm_ln.") for k in state_dict.keys()) or any(k.startswith("esm.") for k in state_dict.keys()) + arch["prepend_protein"] = bool(cfg.get("prepend_protein", bool(has_esm))) + arch["esm_model_name"] = str(cfg.get("esm_model_name", "esmc_300m")) + arch["esm_device"] = str(cfg.get("esm_device", "cuda")) + 
arch["esm_dtype"] = str(cfg.get("esm_dtype", "bf16")).lower() + arch["max_protein_prefix"] = int(cfg.get("max_protein_prefix", 0)) + arch["max_species_prefix"] = int(cfg.get("max_species_prefix", 0)) + + if "max_length" in cfg: + arch["max_position_embeddings"] = int(cfg.get("max_length", 1024)) + else: + arch["max_position_embeddings"] = int(cfg.get("max_position_embeddings", 1024)) + # Attention impl and num_kv_groups (from config or infer from weights) + attn_impl = str(cfg.get("attn_impl", "")) + num_kv_groups = int(cfg.get("num_kv_groups", 0)) + if not attn_impl: + wk_key = next((k for k in state_dict.keys() if k.endswith("attn.Wk.weight")), None) + if wk_key is not None: + attn_impl = "gqa" + out_ch, _ = state_dict[wk_key].shape + num_heads = int(arch.get("num_heads", 1)) + head_dim = int(arch["hidden_size"]) // max(1, num_heads) + if head_dim > 0: + num_kv_groups = max(1, out_ch // head_dim) + else: + attn_impl = "mha" + num_kv_groups = 0 + arch["attn_impl"] = attn_impl + arch["num_kv_groups"] = num_kv_groups + + return arch # type: ignore[return-value] + + # ---------------------------- + # Public API + # ---------------------------- + + @torch.no_grad() + def sample( + self, + num_sequences: int = 1, + sequence_length: int = 100, # target number of codons (fixed mode); max iterations (variable) + species: Optional[Union[str, List[str]]] = None, + protein_sequences: Optional[Union[str, List[str]]] = None, + control_mode: str = "fixed", # "fixed" or "variable" + target_protein_length: Optional[int] = None, # deprecated; alias to sequence_length + temperature: float = 1.0, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + seed: Optional[int] = None, + return_intermediate: bool = False, + progress_bar: bool = False, + species_emb: Optional[torch.Tensor] = None, # [B, Ds] + species_tok_emb: Optional[torch.Tensor] = None, # [B, Ls, Ds] + enforce_translation: bool = False, + codon_enforcement_weight: float = 10.0, # unused with hard mask; kept for API compatibility + ) -> Dict[str, Union[List[str], torch.Tensor, List[bool]]]: + + if seed is not None: + torch.manual_seed(int(seed)) + np.random.seed(int(seed)) + + if control_mode not in ("fixed", "variable"): + raise ValueError(f"control_mode must be 'fixed' or 'variable', got {control_mode}") + + B = int(num_sequences) + T_codons = int(sequence_length if target_protein_length is None else target_protein_length) + + # Prepare conditioning + cond: Dict[str, Union[str, List[str], torch.Tensor]] = {"control_mode": control_mode} + + # Species (priority: provided tensors → names via store) + if species_tok_emb is not None: + if species_tok_emb.ndim != 3 or species_tok_emb.size(0) != B: + raise ValueError("species_tok_emb must be [B, Ls, Ds]") + st = species_tok_emb.to(self.device) + cond["species_tok_emb_src"] = st + cond["species_tok_emb_tgt"] = st + elif species_emb is not None: + if species_emb.ndim != 2 or species_emb.size(0) != B: + raise ValueError("species_emb must be [B, Ds]") + se = species_emb.to(self.device) + cond["species_emb_src"] = se + cond["species_emb_tgt"] = se + elif species is not None: + names = [species] * B if isinstance(species, str) else species + if len(names) != B: + raise ValueError("Length of species list must match num_sequences") + + # If we have a store (variable-length), use it for known species and compute Qwen embeddings for unknowns. 
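+            # Resolution order for species conditioning: explicit tensors are
+            # handled above, then store lookups by name, then Qwen text
+            # embeddings of the taxonomy string for names the store lacks.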
+ if self.species_store is not None: + ids = [self.species_store.vocab.get(n, -1) for n in names] + known_mask = [i for i, sid in enumerate(ids) if sid >= 0] + unk_mask = [i for i, sid in enumerate(ids) if sid < 0] + + # Only variable-length embeddings are supported. If the store is not sequence-based, compute via Qwen for all. + use_sequence = bool(getattr(self.species_store, "is_legacy", False)) + if not use_sequence: + # Fall back to Qwen for everything + q_tok, q_len = self._qwen_embed_names(names, pooling="sequence") + cond["species_tok_emb_src"] = q_tok.to(self.device) + cond["species_tok_emb_tgt"] = q_tok.to(self.device) + else: + # list of per-sample [L,D] tensors to be padded later + seq_list: List[torch.Tensor] = [None] * B # type: ignore[list-item] + D = int(getattr(self.species_store, "_ds", 1024)) + # Known via store + if known_mask: + sub_ids = [ids[i] for i in known_mask] + result = self.species_store.batch_get(sub_ids) + assert isinstance(result, tuple) + sp_tok, _ = result + for j, i in enumerate(known_mask): + row = sp_tok[j] + nonzero = (row.abs().sum(dim=-1) > 0) + L = int(nonzero.sum().item()) if nonzero.any() else int(row.size(0)) + seq_list[i] = row[:L].to(self.device) + # Unknown via Qwen + if unk_mask: + unk_names = [names[i] for i in unk_mask] + q_tok, q_len = self._qwen_embed_names(unk_names, pooling="sequence") + for j, i in enumerate(unk_mask): + L = int(q_len[j].item()) + seq_list[i] = q_tok[j, :L, :].to(self.device) + + # Pad to [B,Lmax,D] + Lmax = max((t.size(0) for t in seq_list if t is not None), default=0) + if Lmax == 0: + raise RuntimeError("No species embeddings could be constructed.") + padded = torch.zeros(B, Lmax, D, device=self.device, dtype=seq_list[0].dtype) + for i, t in enumerate(seq_list): + if t is None: + continue + L = t.size(0) + padded[i, :L, :] = t + cond["species_tok_emb_src"] = padded + cond["species_tok_emb_tgt"] = padded + else: + # No store: compute everything via Qwen (sequence pooling only) + emb, lengths = self._qwen_embed_names(names, pooling="sequence") + st = emb.to(self.device, non_blocking=True) + cond["species_tok_emb_src"] = st + cond["species_tok_emb_tgt"] = st + + # Protein sequences (raw AA strings; the model handles ESM-C) + if protein_sequences is not None: + if isinstance(protein_sequences, list): + if len(protein_sequences) != B: + raise ValueError("Length of protein_sequences must match num_sequences") + cond["protein_seqs"] = protein_sequences + else: + cond["protein_seqs"] = [protein_sequences] * B + + # Start with empty codon context; we'll prefill to build KV cache and get first-step logits + input_ids = torch.empty((B, 0), dtype=torch.long, device=self.device) + + # Capacity probe and fallback: if prefix consumes all budget, cap species/protein prefix temporarily (prefill path) + pref = None + try: + out0 = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=True) + pref = out0.get("prefix_len") if isinstance(out0, dict) else None + if pref is not None: + max_pos = int(getattr(self.model, "max_position_embeddings", 1024)) + remaining0 = max_pos - (pref + 1) + need_cap = (remaining0 <= 0).any() + else: + need_cap = False + if need_cap: + prev_sp = int(getattr(self.model, "max_species_prefix", 0)) + prev_pp = int(getattr(self.model, "max_protein_prefix", 0)) + if prev_sp == 0 or prev_sp > 256: + setattr(self.model, "max_species_prefix", 256) + if prev_pp == 0 or prev_pp > 256: + setattr(self.model, "max_protein_prefix", 256) + out0b = self.model(codon_ids=input_ids, cond=cond, 
return_dict=True, use_cache=True) + pref = out0b.get("prefix_len") if isinstance(out0b, dict) else None + if pref is not None: + remaining0b = max_pos - (pref + 1) + if (remaining0b <= 0).all(): + setattr(self.model, "max_species_prefix", 128) + setattr(self.model, "max_protein_prefix", 128) + out0b = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=True) + pref = out0b.get("prefix_len") if isinstance(out0b, dict) else pref + # Use the prefill output + out_prefill = out0 if pref is None else out0 + except Exception: + # Fallback without cache + out_prefill = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=True) + pref = out_prefill.get("prefix_len") if isinstance(out_prefill, dict) else None + + allowed = self._allowed_variable if control_mode == "variable" else self._allowed_fixed + finished = torch.zeros(B, dtype=torch.bool, device=self.device) # EOS reached (variable) OR capacity exhausted + capacity_truncated = torch.zeros(B, dtype=torch.bool, device=self.device) + + intermediate = [] if return_intermediate else None + aa2codons = self.tokenizer.aa2codons_char_map() + + # If we probed capacity, optionally clamp target codons by available capacity at step 0 + try: + if pref is not None: + max_pos = int(getattr(self.model, "max_position_embeddings", 1024)) + remaining = (max_pos - (pref + 1)).clamp(min=0) + T_codons = int(min(T_codons, int(remaining.max().item()))) + except Exception: + pass + + # KV cache and initial logits from prefill + kv = out_prefill.get("present_kv") if isinstance(out_prefill, dict) else None + logits = out_prefill.get("next_logits") if isinstance(out_prefill, dict) else None + if kv is None or logits is None: + # Safety: compute once if not provided + out_prefill = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=True) + kv = out_prefill.get("present_kv") + logits = out_prefill.get("next_logits") + assert kv is not None and logits is not None + prefix_len = pref if pref is not None else torch.zeros(B, dtype=torch.long, device=self.device) + prefill_len = (prefix_len + 1) # prefix + start + + rng = range(T_codons) + if progress_bar: + from tqdm import tqdm + rng = tqdm(rng, desc="GPT sampling", total=T_codons) + + for step in rng: + # Enforce global capacity per sample using prefix_len and current generated length + max_pos = int(getattr(self.model, "max_position_embeddings", 1024)) + remaining_now = (max_pos - prefill_len - input_ids.size(1)).clamp(max=10**9) + cant_extend = remaining_now <= 0 + newly_blocked = (~finished) & cant_extend + capacity_truncated = capacity_truncated | newly_blocked + finished = finished | cant_extend + + # Base mask: disallow specials in fixed, allow EOS in variable. + logits = logits.masked_fill(~allowed, float("-inf")) + + # If a sample is finished (EOS or capacity), force PAD to keep shapes stable. + # Decoding will drop PAD anyway. 
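+            # (-inf everywhere with a lone 0 at PAD makes softmax put all mass on
+            # PAD for finished rows, so multinomial can never re-open them.)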
+ if finished.any(): + logits[finished] = float("-inf") + logits[finished, self._pad_id] = 0.0 + + # Optional: enforce codon ↔ AA mapping at this step (hard mask) + if enforce_translation and ("protein_seqs" in cond): + aas_now: List[Optional[str]] = [] + prot_list = cond["protein_seqs"] # type: ignore[index] + assert isinstance(prot_list, list) + for i in range(B): + seq = prot_list[i] + aas_now.append(seq[step] if step < len(seq) else None) + + mask = torch.zeros_like(logits, dtype=torch.bool) + for i, a in enumerate(aas_now): + if a is None: + mask[i, self._num_special:self.V] = True + else: + valid = aa2codons.get(a, []) + if len(valid) == 0: + mask[i, self._num_special:self.V] = True + else: + mask[i, valid] = True + logits = logits.masked_fill(~mask, float("-inf")) + + # Temperature + filtering + if temperature != 1.0: + logits = logits / float(temperature) + if top_k is not None: + logits = _top_k_filtering(logits, int(top_k)) + if top_p is not None: + logits = _top_p_filtering(logits, float(top_p)) + + probs = F.softmax(logits, dim=-1) + next_tok = torch.multinomial(probs, num_samples=1) # [B,1] + + if control_mode == "variable": + # Stop sequences at EOS + eos_mask = (next_tok.squeeze(-1) == self._eos_id) + finished = finished | eos_mask + + input_ids = torch.cat([input_ids, next_tok], dim=1) + + if return_intermediate: + intermediate.append(input_ids.clone()) + + # If all sequences are finished, we're done. + if finished.all(): + break + + # Incremental decode: compute logits for next step and update KV cache + pos_offset = int(prefill_len.max().item()) + input_ids.size(1) - 1 # use max offset for shared RoPE cache + out_inc = self.model( + codon_ids=next_tok, + cond=None, + return_dict=True, + use_cache=True, + past_kv=kv, + position_offset=pos_offset, + ) + kv = out_inc.get("present_kv") + logits = out_inc.get("next_logits") + assert kv is not None and logits is not None + + # Build final DNA strings, dropping specials and any PADs we added + output_token_rows: List[List[int]] = [] + for row in input_ids.tolist(): + toks: List[int] = [] + for t in row: + if t == self._pad_id: + continue + if t == self._eos_id: + break # variable mode terminator + if t >= self._num_special and t < self.V: + toks.append(int(t)) + if control_mode == "fixed": + # In fixed mode we *intended* T_codons; if capacity cut us short, it's fine. + toks = toks[:T_codons] + output_token_rows.append(toks) + + sequences = [self.tokenizer.decode_codon_seq(row) for row in output_token_rows] + + # Pad variable-length rows for input_ids to avoid tensor construction errors when + # some samples are capacity-truncated in fixed mode. 
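+        # e.g. kept rows of lengths [50, 47, 50] become a [3, 50] LongTensor,
+        # right-padded with pad_token_id.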
+ max_len = max((len(r) for r in output_token_rows), default=0) + if max_len > 0: + ids_padded = torch.full( + (len(output_token_rows), max_len), + self._pad_id, + device=self.device, + dtype=torch.long, + ) + for i, row in enumerate(output_token_rows): + if len(row) > 0: + ids_padded[i, : len(row)] = torch.tensor(row, device=self.device, dtype=torch.long) + else: + ids_padded = torch.empty((len(output_token_rows), 0), device=self.device, dtype=torch.long) + + result: Dict[str, Union[List[str], torch.Tensor, List[bool]]] = { + "sequences": sequences, + "input_ids": ids_padded, + "capacity_truncated": capacity_truncated.detach().bool().tolist(), + } + if return_intermediate: + result["intermediate_states"] = intermediate # list[Tensor], length = steps actually taken + return result + + # ---------------------------- + # Qwen embedding (inline; no separate module) + # ---------------------------- + def _ensure_qwen_loaded(self): + if self._qwen_tokenizer is not None and self._qwen_model is not None: + return + from transformers import AutoTokenizer, AutoModel + self._qwen_tokenizer = AutoTokenizer.from_pretrained( + "Qwen/Qwen3-Embedding-0.6B", trust_remote_code=True, padding_side="left" + ) + dtype = torch.float16 if self.device.type == "cuda" else torch.float32 + self._qwen_model = AutoModel.from_pretrained( + "Qwen/Qwen3-Embedding-0.6B", torch_dtype=dtype, trust_remote_code=True + ).to(self.device).eval() + + @staticmethod + def _last_token_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0]) + if left_padding: + return last_hidden_states[:, -1] + else: + sequence_lengths = attention_mask.sum(dim=1) - 1 + batch_size = last_hidden_states.shape[0] + return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths] + + @staticmethod + def _format_instruct(task: str, query: str) -> str: + return f"Instruct: {task}\nQuery: {query}" + + @torch.no_grad() + def _qwen_embed_names(self, names: List[str], pooling: str = "sequence") -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + # Load taxonomy DB if provided + taxonomy_db = None + if self.taxonomy_db_path: + try: + with open(self.taxonomy_db_path, "r") as f: + import json + taxonomy_db = json.load(f) + except Exception: + taxonomy_db = None + + self._ensure_qwen_loaded() + tokenizer = self._qwen_tokenizer + model = self._qwen_model + assert tokenizer is not None and model is not None + + task = ( + "Given a species taxonomy information, generate a biological embedding " + "representing its taxonomic and evolutionary characteristics" + ) + texts = [self._format_instruct(task, taxonomy_db.get(s, s) if taxonomy_db else s) for s in names] + + BATCH = int(self.qwen_opts.get("batch_size", 16)) + max_len = int(self.qwen_opts.get("max_length", 512)) + + # sequence pooling only + seqs: List[torch.Tensor] = [] + lens: List[int] = [] + for i in range(0, len(texts), BATCH): + chunk = texts[i : i + BATCH] + inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True, max_length=max_len).to(self.device) + out = model(**inputs) + h = torch.nn.functional.normalize(out.last_hidden_state, p=2, dim=-1) # [B,L,D] + attn = inputs["attention_mask"] + for j in range(h.size(0)): + L = int(attn[j].sum().item()) + seqs.append(h[j, :L, :].float().cpu()) + lens.append(L) + # Pad to [B,Lmax,D] + Lmax = max(lens) if lens else 0 + D = seqs[0].size(1) if seqs else 0 + padded = torch.zeros(len(seqs), Lmax, D) + for i, t 
in enumerate(seqs): + padded[i, : t.size(0), :] = t + return padded, torch.tensor(lens, dtype=torch.long) + + # ---------------------------- + # Conditioning helper + # ---------------------------- + + # (Kept minimal. Species embeddings are prepared inline in sample().) + + +# ---------------------------- +# Convenience function +# ---------------------------- + +def sample_sequences( + model_path: str, + num_sequences: int = 10, + sequence_length: int = 100, + species: Optional[Union[str, List[str]]] = None, + protein_sequence: Optional[Union[str, List[str]]] = None, + **kwargs +) -> List[str]: + sampler = CodonSampler(model_path) + out = sampler.sample( + num_sequences=num_sequences, + sequence_length=sequence_length, + species=species, + protein_sequences=protein_sequence, + **kwargs + ) + return out["sequences"] # type: ignore[return-value] diff --git a/src/tokenizer.py b/src/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..cb1f7c4db6edd6ba747ad216112b506429e757da --- /dev/null +++ b/src/tokenizer.py @@ -0,0 +1,324 @@ +# src/tokenizer.py +""" +Codon tokenizer: 3-mer tokens + 4 special tokens. + +No frameworks, no inheritance chains. Just: +- encode_codon_seq("ATG...") -> [ids...] (appends EOS outside, not here) +- decode_codon_seq([ids...]) -> "ATG..." +- save_vocabulary(dir) / from_pretrained(dir) for reproducible runs + +Special IDs are fixed and contiguous from 0: + pad=0, unk=1, bos=2, eos=3 +""" + +from __future__ import annotations + +import json +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Any + + +# ------------------------------ +# Special token ids +# ------------------------------ + +@dataclass(frozen=True) +class SpecialIds: + pad: int = 0 + unk: int = 1 + bos: int = 2 + eos: int = 3 + + def to_dict(self) -> Dict[str, int]: + return {"pad": self.pad, "unk": self.unk, "bos": self.bos, "eos": self.eos} + + +# ------------------------------ +# Tokenizer +# ------------------------------ + +class CodonTokenizer: + """Minimal tokenizer for codon (DNA 3-mer) sequences.""" + + __slots__ = ( + "codons", + "_special_token_str", + "vocab", + "ids_to_tokens", + "_special_ids", + "_num_special_tokens", + "_genetic_code", + "_codon2aa_char", + "_aa2codons_char", + ) + + def __init__( + self, + pad_token: str = "", + unk_token: str = "", + bos_token: str = "", + eos_token: str = "", # human-readable; id is still 3 + **_: Any, # ignore junk kwargs – we don't play framework games + ) -> None: + # 64 codons + bases = ("A", "C", "G", "T") + self.codons: List[str] = [a + b + c for a in bases for b in bases for c in bases] + + # specials come first, contiguous + special_tokens = [pad_token, unk_token, bos_token, eos_token] + self._special_token_str = {"pad": pad_token, "unk": unk_token, "bos": bos_token, "eos": eos_token} + + # vocab: specials [0..3], then 64 codons [4..67] + self.vocab: Dict[str, int] = {} + for i, tok in enumerate(special_tokens): + self.vocab[tok] = i + for codon in self.codons: + self.vocab[codon] = len(special_tokens) + (len(self.vocab) - len(special_tokens)) + + # reverse map + self.ids_to_tokens: Dict[int, str] = {v: k for k, v in self.vocab.items()} + + # fixed ids + self._special_ids = SpecialIds( + pad=self.vocab[pad_token], + unk=self.vocab[unk_token], + bos=self.vocab[bos_token], + eos=self.vocab[eos_token], + ) + self._num_special_tokens = len(special_tokens) + + # genetic code (char) + self._genetic_code: Dict[str, str] = { + "TTT": "F", "TTC": "F", 
"TTA": "L", "TTG": "L", + "TCT": "S", "TCC": "S", "TCA": "S", "TCG": "S", + "TAT": "Y", "TAC": "Y", "TAA": "*", "TAG": "*", + "TGT": "C", "TGC": "C", "TGA": "*", "TGG": "W", + "CTT": "L", "CTC": "L", "CTA": "L", "CTG": "L", + "CCT": "P", "CCC": "P", "CCA": "P", "CCG": "P", + "CAT": "H", "CAC": "H", "CAA": "Q", "CAG": "Q", + "CGT": "R", "CGC": "R", "CGA": "R", "CGG": "R", + "ATT": "I", "ATC": "I", "ATA": "I", "ATG": "M", + "ACT": "T", "ACC": "T", "ACA": "T", "ACG": "T", + "AAT": "N", "AAC": "N", "AAA": "K", "AAG": "K", + "AGT": "S", "AGC": "S", "AGA": "R", "AGG": "R", + "GTT": "V", "GTC": "V", "GTA": "V", "GTG": "V", + "GCT": "A", "GCC": "A", "GCA": "A", "GCG": "A", + "GAT": "D", "GAC": "D", "GAA": "E", "GAG": "E", + "GGT": "G", "GGC": "G", "GGA": "G", "GGG": "G", + } + + # precompute char helpers + self._codon2aa_char: Dict[int, str] = {} + self._aa2codons_char: Dict[str, List[int]] = {ch: [] for ch in "ACDEFGHIKLMNPQRSTVWY*"} + for codon in self.codons: + cid = self.vocab[codon] + aa = self._genetic_code.get(codon, "X") + self._codon2aa_char[cid] = aa + if aa in self._aa2codons_char: + self._aa2codons_char[aa].append(cid) + + # sanity: specials are contiguous 0..3 + ids = list(self._special_ids.to_dict().values()) + if sorted(ids) != list(range(self._num_special_tokens)): + raise AssertionError("Special token ids must be contiguous starting at 0") + + # ---------- properties ---------- + @property + def vocab_size(self) -> int: + return len(self.vocab) + + @property + def special_ids(self) -> SpecialIds: + return self._special_ids + + @property + def num_special_tokens(self) -> int: + return self._num_special_tokens + + @property + def pad_token_id(self) -> int: + return self._special_ids.pad + + @property + def unk_token_id(self) -> int: + return self._special_ids.unk + + @property + def bos_token_id(self) -> int: + return self._special_ids.bos + + @property + def eos_token_id(self) -> int: + return self._special_ids.eos + + # ---------- core API ---------- + def encode_codon_seq(self, seq: str, validate: bool = True) -> List[int]: + """ + Map DNA (ACGT)^3N to 3-mer ids. We don't append BOS/EOS here. + """ + s = seq.upper() + if validate: + if len(s) % 3 != 0: + raise ValueError(f"Sequence length {len(s)} not divisible by 3") + if not _is_acgt(s): + raise ValueError("Sequence contains invalid nucleotides (only ACGT supported)") + out: List[int] = [] + # Fast Python slice loop – good enough. NumPy won't help for tiny strings. + for i in range(0, len(s), 3): + codon = s[i : i + 3] + out.append(self.vocab.get(codon, self._special_ids.unk)) + return out + + def decode_codon_seq(self, token_ids: List[int]) -> str: + """ + Convert codon ids (>= num_special_tokens) back to DNA string. + Special ids are ignored unless they collide (they don't). 
+ """ + parts: List[str] = [] + nst = self._num_special_tokens + for tid in token_ids: + if tid >= nst: + tok = self.ids_to_tokens.get(tid) + if tok is not None: # should always be a codon + parts.append(tok) + return "".join(parts) + + def decode(self, token_ids: List[int], skip_special_tokens: bool = True, **_: Any) -> str: + # kept for API parity with your old code + if skip_special_tokens: + token_ids = [t for t in token_ids if t >= self._num_special_tokens] + return self.decode_codon_seq(token_ids) + + # ---------- misc helpers ---------- + def codon_vocab(self) -> Dict[str, int]: + return {c: self.vocab[c] for c in self.codons} + + def codon2aa_char_map(self) -> Dict[int, str]: + return dict(self._codon2aa_char) + + def aa2codons_char_map(self) -> Dict[str, List[int]]: + return {k: v[:] for k, v in self._aa2codons_char.items()} + + def aa_to_codon_length(self, aa_seq: str) -> int: + # You don't count stop unless it's explicitly there. + return len(aa_seq) + + # HF compatibility stubs (your code calls these in a few places) + def _tokenize(self, text: str) -> List[str]: + if len(text) % 3 != 0: + raise ValueError(f"Text length {len(text)} not divisible by 3") + return [text[i : i + 3] for i in range(0, len(text), 3)] + + def _convert_token_to_id(self, token: str) -> int: + return self.vocab.get(token, self._special_ids.unk) + + def _convert_id_to_token(self, index: int) -> str: + return self.ids_to_tokens.get(index, self._special_token_str["unk"]) + + def convert_tokens_to_string(self, tokens: List[str]) -> str: + return "".join(tokens) + + def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]: + return token_ids_0 + + def create_token_type_ids_from_sequences(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]: + return [0] * len(token_ids_0) + + # ---------- persistence ---------- + def get_vocab(self) -> Dict[str, int]: + return dict(self.vocab) + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + """ + Save to JSON with both vocab and special token strings so we can + reconstruct IDs exactly. Deterministic and stable. + """ + os.makedirs(save_directory, exist_ok=True) + vocab_file = os.path.join( + save_directory, + (filename_prefix + "-" if filename_prefix else "") + "vocab.json", + ) + payload = { + "vocab": self.vocab, + "special_token_str": self._special_token_str, + } + with open(vocab_file, "w", encoding="utf-8") as f: + json.dump(payload, f, ensure_ascii=False, indent=2, sort_keys=True) + return (vocab_file,) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs: Any) -> "CodonTokenizer": + """ + Load from a directory containing vocab.json produced by save_vocabulary(). + We rebuild the SpecialIds from the saved token strings to keep IDs stable. + """ + vocab_path = Path(pretrained_model_name_or_path) / "vocab.json" + tok = cls(**kwargs) # default structure; we'll overwrite below + if not vocab_path.exists(): + # If nothing to load, return defaults. It keeps the rest of your code happy. 
+ return tok + + with open(vocab_path, "r", encoding="utf-8") as f: + save_data = json.load(f) + + if not isinstance(save_data, dict) or "vocab" not in save_data: + # Old, dumber format: the whole file was the vocab dict + vocab = save_data + special_token_str = tok._special_token_str + else: + vocab = save_data["vocab"] + special_token_str = save_data.get("special_token_str", tok._special_token_str) + + # rebuild maps + tok.vocab = {str(k): int(v) for k, v in vocab.items()} + tok.ids_to_tokens = {int(v): str(k) for k, v in tok.vocab.items()} + + # reconcile special strings → ids + if isinstance(special_token_str, dict): + tok._special_token_str.update({k: v for k, v in special_token_str.items() if k in ("pad", "unk", "bos", "eos")}) + + def _id_for(name: str, default_val: int) -> int: + sym = tok._special_token_str[name] + return int(tok.vocab.get(sym, default_val)) + + tok._special_ids = SpecialIds( + pad=_id_for("pad", 0), + unk=_id_for("unk", 1), + bos=_id_for("bos", 2), + eos=_id_for("eos", 3), + ) + + # Figure out how many specials to reserve. If the saved mapping had extra junk, + # we still preserve a contiguous prefix if present. Otherwise default to 4. + ids = [tok._special_ids.pad, tok._special_ids.unk, tok._special_ids.bos, tok._special_ids.eos] + m = max(ids) + tok._num_special_tokens = m + 1 if ids == list(range(m + 1)) else 4 + + # Rebuild genetic helpers (cheap) + tok._rebuild_helpers() + return tok + + # internal: rebuild helper maps after load + def _rebuild_helpers(self) -> None: + self._codon2aa_char = {} + self._aa2codons_char = {ch: [] for ch in "ACDEFGHIKLMNPQRSTVWY*"} + for codon in self.codons: + cid = self.vocab[codon] + aa = self._genetic_code.get(codon, "X") + self._codon2aa_char[cid] = aa + if aa in self._aa2codons_char: + self._aa2codons_char[aa].append(cid) + + +# ------------------------------ +# small helpers +# ------------------------------ + +def _is_acgt(s: str) -> bool: + # Faster than regex for short strings. + for ch in s: + if ch not in ("A", "C", "G", "T"): + return False + return True diff --git a/src/trainer.py b/src/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..8b5b6ffef811ac493f30fc04ddab28702f7b57b0 --- /dev/null +++ b/src/trainer.py @@ -0,0 +1,1230 @@ +# src/trainer.py +""" +FSDP trainer for CodonGPT. +No frameworks, no sugar. The model computes its own loss. + +Batch invariants: +- codon_ids [B, T] (right-padded; EOS already in-sequence) +- species_ids [B] (SpeciesEmbeddingStore provides fixed-size or sequence embeddings) +- protein_seqs: list[str] (ESM tokenization happens inside the model) + +Rules: +- If your loader is IterableDataset, you MUST set args.max_steps > 0. We don't guess. +- If you want epoch-based, use a sized dataset; we call len(dataloader). 
+""" + +from __future__ import annotations + +import os +import json +import math +import re +import shutil +import logging +import time +from dataclasses import dataclass +import datetime +import warnings +import importlib.util +import inspect +from typing import Any, Callable, Dict, Optional, Tuple, List +from tqdm import tqdm +import torch +import torch.nn as nn +import torch.distributed as dist +from torch.utils.data import DataLoader, IterableDataset + +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import ( + ShardingStrategy, + MixedPrecision, + StateDictType, + FullStateDictConfig, + FullOptimStateDictConfig, +) +from safetensors.torch import save_file, load_file +import wandb + +logger = logging.getLogger(__name__) + + +# ------------------------------ +# Args +# ------------------------------ + +@dataclass +class TrainingArguments: + # Output + output_dir: str = "checkpoints" + save_steps: int = 1000 + save_total_limit: int = 3 + save_safetensors: bool = True + ckpt_recent_window_steps: int = 0 + ckpt_recent_interval: int = 0 + ckpt_archive_interval: int = 0 + + # Schedule + num_train_epochs: int = 1 + max_steps: int = -1 # required for IterableDataset + gradient_accumulation_steps: int = 1 + warmup_ratio: float = 0.0 + lr_scheduler_type: str = "cosine" # "linear" | "cosine" | "constant" + # For streaming datasets: if max_steps<0 and steps_per_epoch>0, shape schedule using + # total_steps = num_train_epochs * steps_per_epoch + steps_per_epoch: int = 0 + + # Optim + learning_rate: float = 5e-4 + weight_decay: float = 0.0 + adam_beta1: float = 0.9 + adam_beta2: float = 0.95 + max_grad_norm: float = 1.0 + + # Data + per_device_train_batch_size: int = 8 + per_device_eval_batch_size: int = 8 + dataloader_num_workers: int = 0 + + # Precision / dist + fp16: bool = False + bf16: bool = False + fsdp: Optional[str] = None # "full_shard" or None + gradient_checkpointing: bool = False + + # Global hard cap (prefix + start + codon) + max_length: int = 4096 + + # ESM (metadata only; model owns ESM) + esm_model_name: str = "esmc_300m" + esm_device: str = "cuda" + esm_dtype: str = "bf16" + + # Logging / eval + logging_steps: int = 100 + eval_steps: int = 0 # streaming eval: limit number of eval batches when eval dataset is Iterable + eval_interval: int = 0 # run evaluation every N optimizer steps (0 disables) + override_lr_on_resume: bool = False + # Minimal data stream resume cursor (stores total samples yielded so far for train dataset). + # When provided, we load 'skip_samples' from this JSON at start and set the dataset + # to skip exactly that many samples on resume. We also update the file in _save_checkpoint(). 
+ data_cursor_path: Optional[str] = None + + +# ------------------------------ +# Trainer +# ------------------------------ + +class Trainer: + def __init__( + self, + model: nn.Module, + args: TrainingArguments, + data_collator: Optional[Callable] = None, + train_dataset: Optional[Any] = None, + eval_dataset: Optional[Any] = None, + tokenizer: Optional[Any] = None, + model_init: Optional[Callable[[], nn.Module]] = None, + compute_metrics: Optional[Callable] = None, + callbacks: Optional[list] = None, + optimizers: Tuple[Optional[torch.optim.Optimizer], Optional[Any]] = (None, None), + preprocess_logits_for_metrics: Optional[Callable] = None, + species_store=None, + resume_from_checkpoint: Optional[str] = None, + ): + self.model = model + self.args = args + self.tokenizer = tokenizer + self.optimizer = optimizers[0] + self.lr_scheduler = optimizers[1] + self.species_store = species_store + + self.train_dataloader: Optional[DataLoader] = None + self.eval_dataloader: Optional[DataLoader] = None + + # Device (robust local rank resolution) + self.local_rank = 0 + if torch.cuda.is_available(): + lr_env = os.environ.get("LOCAL_RANK") + if lr_env is not None: + self.local_rank = int(lr_env) + else: + r = int(os.environ.get("RANK", "0")) + ng = max(1, torch.cuda.device_count()) + self.local_rank = (r % ng) + self.device = torch.device(f"cuda:{self.local_rank}") + torch.cuda.set_device(self.device) + cd = torch.cuda.current_device() + nm = torch.cuda.get_device_name(cd) + logger.info( + f"[dist] RANK={os.environ.get('RANK')} LOCAL_RANK={os.environ.get('LOCAL_RANK')} WORLD_SIZE={os.environ.get('WORLD_SIZE')} " + f"cuda.count={torch.cuda.device_count()} select={self.device} current={cd} name={nm}" + ) + else: + self.device = torch.device("cpu") + + # Gradient checkpointing toggle (model owns the flag) + base = self._unwrap(self.model) + if self.args.gradient_checkpointing and hasattr(base, "gradient_checkpointing"): + base.gradient_checkpointing = True + + # FSDP or single GPU + if self.args.fsdp: + self._setup_fsdp() + else: + self.model = self.model.to(self.device) + + # AMP setup (use torch.amp APIs; GradScaler on CUDA only) + self._use_amp = (self.device.type == "cuda") and (self.args.fp16 or self.args.bf16) + self._amp_dtype = torch.float16 if self.args.fp16 else (torch.bfloat16 if self.args.bf16 else None) + use_cuda = (self.device.type == "cuda") + self._scaler = torch.amp.GradScaler(device="cuda", enabled=(use_cuda and self.args.fp16)) + + self.state = {"epoch": 0, "global_step": 0} + + # Defer resume until after dataloaders are attached so scheduler can be shaped. + self._resume_path: Optional[str] = resume_from_checkpoint + + # ---- dataloaders ---- + def attach_dataloaders(self, train_loader: DataLoader, eval_loader: Optional[DataLoader] = None): + # Your dataset should handle sharding. We don't wrap with DistributedSampler here. + self.train_dataloader = train_loader + self.eval_dataloader = eval_loader + # Apply minimal resume cursor to the training dataset if configured + p = getattr(self.args, "data_cursor_path", None) + if p and os.path.exists(p): + with open(p, "r") as f: + js = json.load(f) + ds = getattr(self.train_dataloader, "dataset", None) + if hasattr(ds, "set_resume_skip"): + distributed = dist.is_available() and dist.is_initialized() + world = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + # Prefer the total cursor and split evenly across current world size. + # If total is missing, sum any saved per_rank list. 
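+                # Worked example of the split below: total=10, world=4 → per=2, rem=2,
+                # so ranks 0-3 skip [3, 3, 2, 2] samples respectively.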
+ total: int = 0 + if isinstance(js, dict): + try: + total = int(js.get("skip_samples", 0) or 0) + except Exception: + total = 0 + if total <= 0: + raw = js.get("per_rank") + if isinstance(raw, list) and raw: + try: + total = int(sum(int(x) for x in raw)) + except Exception: + total = 0 + + if total > 0: + if distributed: + per = total // max(world, 1) + rem = total % max(world, 1) + n_rank = per + (1 if rank < rem else 0) + ds.set_resume_skip(int(n_rank)) + if self._is_main(): + logger.info( + "resume cursor: total=%s split across world=%s → rank=%s skip=%s", + total, world, rank, n_rank, + ) + else: + ds.set_resume_skip(int(total)) + if self._is_main(): + logger.info("resume cursor: total=%s (single-process) skip=%s", total, total) + + + # ---- optim + scheduler ---- + def _create_optimizer_and_scheduler(self): + if self.optimizer is None: + decay, no_decay = [], [] + for n, p in self._unwrap(self.model).named_parameters(): + if not p.requires_grad: + continue + if n.endswith("bias") or "norm" in n.lower() or "ln_" in n.lower(): + no_decay.append(p) + else: + decay.append(p) + + opt_kwargs = dict( + lr=self.args.learning_rate, + betas=(self.args.adam_beta1, self.args.adam_beta2), + ) + params = [ + {"params": decay, "weight_decay": self.args.weight_decay}, + {"params": no_decay, "weight_decay": 0.0}, + ] + sig_adamw = inspect.signature(torch.optim.AdamW) + if torch.cuda.is_available() and "fused" in sig_adamw.parameters: + opt_kwargs["fused"] = True # type: ignore[assignment] + self.optimizer = torch.optim.AdamW(params, **opt_kwargs) + # Report fused/foreach settings (rank0 only) + if self._is_main(): + fused_flag = None + foreach_flag = None + if hasattr(self.optimizer, "defaults"): + fused_flag = self.optimizer.defaults.get("fused") + foreach_flag = self.optimizer.defaults.get("foreach") + logger.info(f"AdamW configured: fused={fused_flag} foreach={foreach_flag}") + + # total steps and schedule shape + ds = getattr(self.train_dataloader, "dataset", None) + ga = max(1, self.args.gradient_accumulation_steps) + if isinstance(ds, IterableDataset): + if self.args.max_steps > 0: + # Use max_steps to shape the scheduler; allow multiple epochs to re-iterate the stream + steps_per_epoch = self.args.max_steps + total_steps = self.args.max_steps + elif getattr(self.args, "steps_per_epoch", 0) and self.args.steps_per_epoch > 0: + # steps_per_epoch is already expressed in optimizer steps (train.py accounts for grad_accum) + steps_per_epoch = max(1, int(self.args.steps_per_epoch)) + total_steps = max(1, self.args.num_train_epochs) * steps_per_epoch + else: + # Unknown epoch size; use constant LR without pre-shaped schedule + self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lambda step: 1.0) + return + else: + # sized dataloader: len(dataloader) is number of batches + steps_per_epoch = max(len(self.train_dataloader) // ga, 1) + total_steps = self.args.max_steps if self.args.max_steps > 0 else self.args.num_train_epochs * steps_per_epoch + + warmup = int(self.args.warmup_ratio * total_steps) + + if self.args.lr_scheduler_type == "constant": + self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lambda step: 1.0) + return + + def lrs_lambda(step: int) -> float: + if step < warmup: + return max(float(step) / max(warmup, 1), 1e-6) + t = (step - warmup) / max(total_steps - warmup, 1) + if self.args.lr_scheduler_type == "linear": + return max(1.0 - t, 0.0) + # cosine default + return 0.5 * (1.0 + math.cos(math.pi * t)) + + self.lr_scheduler = 
torch.optim.lr_scheduler.LambdaLR(self.optimizer, lrs_lambda) + + # ---- training ---- + def train(self) -> Dict[str, float]: + assert self.train_dataloader is not None, "Call attach_dataloaders() first" + # If a resume path was provided, load it now (dataloaders are attached). + if getattr(self, "_resume_path", None): + self._resume_from(self._resume_path) # loads model/optimizer/scheduler/state + self._resume_path = None + + if self.optimizer is None: + self._create_optimizer_and_scheduler() + + ds = self.train_dataloader.dataset + + # Exact step budget for streaming datasets when max_steps<0 and steps_per_epoch>0 + target_total_steps: Optional[int] = None + if isinstance(ds, IterableDataset) and int(self.args.max_steps) < 0: + spe = int(getattr(self.args, "steps_per_epoch", 0) or 0) + if spe > 0: + target_total_steps = max(1, int(self.args.num_train_epochs)) * spe + + # Determine total steps for progress bar + progress_total: Optional[int] = None + if int(self.args.max_steps) > 0: + progress_total = int(self.args.max_steps) + elif isinstance(ds, IterableDataset): + if target_total_steps is not None: + progress_total = target_total_steps + else: + ga = max(1, self.args.gradient_accumulation_steps) + steps_per_epoch = max(len(self.train_dataloader) // ga, 1) + progress_total = max(1, int(self.args.num_train_epochs)) * steps_per_epoch + + # Initialize Weights & Biases (rank0 only) + if self._is_main(): + if not hasattr(self, "_wandb"): + proj = os.environ.get("WANDB_PROJECT", "codongpt") + name = os.environ.get("WANDB_NAME") + run_id = os.environ.get("WANDB_RUN_ID") + resume = os.environ.get("WANDB_RESUME") + wandb_dir = os.environ.get("WANDB_DIR") + world_size = dist.get_world_size() if dist.is_available() and dist.is_initialized() else int(os.environ.get("WORLD_SIZE", "1")) + init_kwargs = { + "project": proj, + "name": name, + "config": { + "lr": self.args.learning_rate, + "warmup_ratio": self.args.warmup_ratio, + "scheduler": self.args.lr_scheduler_type, + "batch_size": self.args.per_device_train_batch_size, + "eval_batch_size": self.args.per_device_eval_batch_size, + "grad_accum": self.args.gradient_accumulation_steps, + "effective_global_batch": self.args.per_device_train_batch_size * max(1, world_size) * max(1, self.args.gradient_accumulation_steps), + "epochs": self.args.num_train_epochs, + "steps_per_epoch": getattr(self.args, "steps_per_epoch", 0), + "max_steps": self.args.max_steps, + "weight_decay": self.args.weight_decay, + "world_size": world_size, + "output_dir": self.args.output_dir, + "fsdp": self.args.fsdp, + "bf16": self.args.bf16, + "fp16": self.args.fp16, + }, + } + if run_id: + init_kwargs["id"] = run_id + if resume: + init_kwargs["resume"] = resume + if wandb_dir: + init_kwargs["dir"] = wandb_dir + self._wandb = wandb.init(**init_kwargs) + + self.model.train() + grad_accum = max(1, self.args.gradient_accumulation_steps) + progress = None + if self._is_main() and progress_total is not None and progress_total > 0: + progress = tqdm(total=progress_total, initial=int(self.state["global_step"]), desc="Train", dynamic_ncols=True) + if self.device.type == "cuda" and torch.cuda.is_available(): + torch.cuda.reset_peak_memory_stats(self.device) + world_size = dist.get_world_size() if dist.is_available() and dist.is_initialized() else int(os.environ.get("WORLD_SIZE", "1")) + seqs_per_optimizer_step = ( + int(self.args.per_device_train_batch_size) * max(1, world_size) * grad_accum + ) + log_window_start = time.perf_counter() + log_window_optimizer_steps = 0 + + for epoch in 
range(self.state["epoch"], max(1, self.args.num_train_epochs)): + self.state["epoch"] = epoch + running_loss = 0.0 + running_count = 0 + + train_iter = iter(self.train_dataloader) + step = 0 + batches_this_epoch = 0 + optimizer_steps_this_epoch = 0 + # If this is a streaming dataset with a shaped schedule, enforce a per-epoch optimizer step budget + enforce_budget = False + epoch_budget = None + ds = self.train_dataloader.dataset + if isinstance(ds, IterableDataset): + spe = int(getattr(self.args, "steps_per_epoch", 0) or 0) + if spe > 0: + enforce_budget = True + epoch_budget = int(spe) + + refill_attempts = 0 + max_refills = 64 # avoids infinite loops when dataset is empty + + while True: + batch, has_batch, local_has_batch = self._next_batch_sync(train_iter) + if not has_batch: + # If budget-enforced, attempt to refill the iterator and continue until budget is met. + if enforce_budget and (epoch_budget is not None) and (optimizer_steps_this_epoch < epoch_budget): + if local_has_batch and self._is_main(): + logger.warning("Rank retained extra batch while peers exhausted stream; dropping to stay in sync") + self._barrier() + train_iter = iter(self.train_dataloader) + refill_attempts += 1 + if refill_attempts > max_refills: + if self._is_main(): + logger.warning( + "Exceeded max refills for epoch %s (steps %s/%s). Ending epoch early.", + epoch, optimizer_steps_this_epoch, epoch_budget, + ) + break + continue + else: + if local_has_batch and self._is_main(): + logger.warning("Rank retained extra batch while peers exhausted stream; dropping to stay in sync") + break + + batch = self._prepare_batch(batch) + batches_this_epoch += 1 + + codon_ids = batch["codon_ids"].to(self.device) + input_ids = codon_ids[:, :-1] + labels = codon_ids[:, :-1] + + # Mask PAD/EOS in labels + pad_id = int(self.tokenizer.pad_token_id) if self.tokenizer is not None else 0 + eos_id = int(self.tokenizer.special_ids.eos) if self.tokenizer is not None else -999 + labels = labels.clone() + labels[labels == pad_id] = -100 + labels[labels == eos_id] = -100 + + cond = self._build_cond(batch) + + # autocast context + use_cuda = (self.device.type == "cuda") + autocast_dtype = self._amp_dtype + if autocast_dtype is not None and use_cuda: + ctx = torch.amp.autocast(device_type="cuda", dtype=autocast_dtype) + else: + from contextlib import nullcontext + ctx = nullcontext() + + with ctx: + out = self.model(codon_ids=input_ids, cond=cond, labels=labels, return_dict=True) + loss = out["loss"] + + if self._scaler.is_enabled(): + self._scaler.scale(loss / grad_accum).backward() + else: + (loss / grad_accum).backward() + + running_loss += float(loss.detach().item()) + running_count += 1 + + do_step = ((step + 1) % grad_accum == 0) + if do_step: + # Clip + if self.args.max_grad_norm and self.args.max_grad_norm > 0: + if isinstance(self.model, FSDP): + FSDP.clip_grad_norm_(self.model, self.args.max_grad_norm) + else: + if self._scaler.is_enabled(): + self._scaler.unscale_(self.optimizer) + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm) + + # Step + if self._scaler.is_enabled(): + self._scaler.step(self.optimizer) + self._scaler.update() + else: + self.optimizer.step() + if self.lr_scheduler is not None: + self.lr_scheduler.step() + self.optimizer.zero_grad(set_to_none=True) + self.state["global_step"] += 1 + optimizer_steps_this_epoch += 1 + log_window_optimizer_steps += 1 + + # (wandb) Defer logging to the periodic block below + + # Log + should_log = (self.state["global_step"] % max(1, 
self.args.logging_steps) == 0) + peak_alloc_gb = 0.0 + peak_reserved_gb = 0.0 + if should_log: + peak_alloc_gb, peak_reserved_gb = self._max_cuda_peak_gb() + if self._is_main() and should_log: + avg = running_loss / max(running_count, 1) + lr = float(self.optimizer.param_groups[0]["lr"]) + log_epoch = self._epoch_for_logging() + elapsed = max(time.perf_counter() - log_window_start, 1e-9) + step_time_s = elapsed / max(log_window_optimizer_steps, 1) + seq_per_s = (seqs_per_optimizer_step * max(log_window_optimizer_steps, 1)) / elapsed + msg = f"epoch {log_epoch} step {self.state['global_step']}: loss={avg:.4f} lr={lr:.6g}" + if isinstance(out, dict): + pl = out.get("prefix_len") + pc = out.get("per_cap") + if pl is not None and pc is not None: + msg += f" prefix_mean={float(pl.detach().float().mean().item()):.1f} cap_mean={float(pc.detach().float().mean().item()):.1f}" + msg += ( + f" step_time_s={step_time_s:.3f} seq_per_s={seq_per_s:.1f}" + f" peak_mem_alloc_gb={peak_alloc_gb:.1f} peak_mem_reserved_gb={peak_reserved_gb:.1f}" + ) + logger.info(msg) + if hasattr(self, "_wandb"): + wandb.log({ + "train/loss": float(avg), + "train/lr": float(lr), + "perf/step_time_s": float(step_time_s), + "perf/seq_per_s": float(seq_per_s), + "system/peak_mem_alloc_gb": float(peak_alloc_gb), + "system/peak_mem_reserved_gb": float(peak_reserved_gb), + }, step=self.state["global_step"]) + running_loss = 0.0 + running_count = 0 + log_window_start = time.perf_counter() + log_window_optimizer_steps = 0 + + # Update progress bar + if progress is not None: + progress.update(1) + + # Stop when budget is reached for streaming schedule + if target_total_steps is not None and self.state["global_step"] >= target_total_steps: + metrics = {"train_loss": running_loss / max(running_count, 1)} + self._save_checkpoint("final_model") + self._barrier() + return metrics + + # Periodic teacher-forced evaluation on the held-out dataset + should_eval = ( + self.eval_dataloader is not None and + self.args.eval_interval > 0 and + (self.state["global_step"] % self.args.eval_interval == 0) + ) + if should_eval: + eval_metrics = self.evaluate() + if self._is_main(): + el = float(eval_metrics.get("eval_loss", 0.0)) + ea = eval_metrics.get("eval_codon_acc", None) + aa = eval_metrics.get("eval_aa_acc", None) + if ea is not None and aa is not None: + logger.info(f"eval: loss={el:.4f} codon_acc={float(ea):.3f} aa_acc={float(aa):.3f}") + elif ea is not None: + logger.info(f"eval: loss={el:.4f} codon_acc={float(ea):.3f}") + elif aa is not None: + logger.info(f"eval: loss={el:.4f} aa_acc={float(aa):.3f}") + else: + logger.info(f"eval: loss={el:.4f}") + if hasattr(self, "_wandb"): + log_payload = {"eval/loss": el} + if ea is not None: + log_payload["eval/codon_acc"] = float(ea) + if aa is not None: + log_payload["eval/aa_acc"] = float(aa) + wandb.log(log_payload, step=self.state["global_step"]) + + # Save by step + if self.args.save_steps > 0 and (self.state["global_step"] % self.args.save_steps == 0): + self._save_checkpoint(f"checkpoint-{self.state['global_step']}") + + # Hard horizon for streaming/step-limited runs + if self.args.max_steps > 0 and self.state["global_step"] >= self.args.max_steps: + metrics = {"train_loss": running_loss / max(running_count, 1)} + self._save_checkpoint("final_model") + self._barrier() + if progress is not None: + progress.close() + return metrics + + step += 1 + + # If we enforce a per-epoch budget for streaming datasets, end the epoch once it's reached + if enforce_budget and (epoch_budget is not None) and 
(optimizer_steps_this_epoch >= epoch_budget): + break + + # Epoch summary (rank0 only) + if self._is_main(): + try: + eb = int(epoch_budget) if epoch_budget is not None else -1 + except Exception: + eb = -1 + logger.info( + "epoch %s completed: optimizer_steps=%s%s", + self._epoch_for_logging(), + optimizer_steps_this_epoch, + (f" / budget {eb}" if eb > 0 else ""), + ) + + if dist.is_available() and dist.is_initialized(): + gather_device = self.device if self.device.type == "cuda" else torch.device("cpu") + counts_tensor = torch.tensor( + [batches_this_epoch, optimizer_steps_this_epoch], + dtype=torch.long, + device=gather_device, + ) + gathered = [torch.zeros_like(counts_tensor) for _ in range(dist.get_world_size())] + dist.all_gather(gathered, counts_tensor) + batch_counts = [int(t[0].item()) for t in gathered] + step_counts = [int(t[1].item()) for t in gathered] + batch_gap = max(batch_counts) - min(batch_counts) + step_gap = max(step_counts) - min(step_counts) + if self._is_main() and (batch_gap > 0 or step_gap > 0): + logger.warning( + "Epoch %s imbalance detected across ranks: batches min=%s max=%s, optimizer steps min=%s max=%s", + epoch, + min(batch_counts), + max(batch_counts), + min(step_counts), + max(step_counts), + ) + + # Epoch boundary save for sized datasets + if not isinstance(ds, IterableDataset): + self._save_checkpoint(f"epoch-{epoch}") + + metrics = {"train_loss": 0.0} + if progress is not None: + progress.close() + self._barrier() + return metrics + + # ---- evaluation ---- + def evaluate(self) -> Dict[str, float]: + if self.eval_dataloader is None: + return {"eval_loss": 0.0} + + self.model.eval() + + loss_sum = 0.0 + loss_tokens = 0 + codon_correct = 0 + codon_total = 0 + aa_correct = 0 + aa_total = 0 + + tok = self.tokenizer + pad_id = int(tok.pad_token_id) if tok is not None else 0 + eos_id = int(tok.special_ids.eos) if tok is not None and hasattr(tok, "special_ids") else -999 + num_special = int(tok.num_special_tokens) if tok is not None else 0 + codon2aa = tok.codon2aa_char_map() if tok is not None and hasattr(tok, "codon2aa_char_map") else {} + + is_streaming = isinstance(self.eval_dataloader.dataset, IterableDataset) + max_batches = int(self.args.eval_steps) if (is_streaming and self.args.eval_steps > 0) else None + + with torch.no_grad(): + eval_iter = iter(self.eval_dataloader) + b_idx = 0 + while True: + batch, has_batch, local_has_batch = self._next_batch_sync(eval_iter) + if not has_batch: + if local_has_batch and self._is_main(): + logger.debug("eval dataloader: discarded tail batch to stay in sync across ranks") + break + + if max_batches is not None and b_idx >= max_batches: + break + + batch = self._prepare_batch(batch) + + codon_ids = batch["codon_ids"].to(self.device) + input_ids = codon_ids[:, :-1] + labels = codon_ids[:, :-1] + + labels = labels.clone() + labels[labels == pad_id] = -100 + labels[labels == eos_id] = -100 + + cond = self._build_cond(batch) + + use_cuda = (self.device.type == "cuda") + autocast_dtype = self._amp_dtype + if autocast_dtype is not None and use_cuda: + ctx = torch.amp.autocast(device_type="cuda", dtype=autocast_dtype) + else: + from contextlib import nullcontext + ctx = nullcontext() + + with ctx: + out = self.model(codon_ids=input_ids, cond=cond, labels=labels, return_dict=True) + + loss = out.get("loss") + per_cap = out.get("per_cap") + logits = out.get("logits") + + tokens_in_batch = 0 + if per_cap is not None: + tokens_in_batch = int(torch.clamp(per_cap.detach(), min=0).sum().item()) + loss_tokens += tokens_in_batch 
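+                        # Token counts accumulate here so the per-batch losses can later be
+                        # combined into a token-weighted mean (eval_loss = loss_sum / loss_tokens).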
+ + if loss is not None and tokens_in_batch > 0: + loss_sum += float(loss.detach().item()) * tokens_in_batch + + if logits is None or logits.size(1) == 0 or per_cap is None: + continue + + max_cap = logits.size(1) + batch_size = logits.size(0) + + labels_aligned = torch.full((batch_size, max_cap), -100, dtype=labels.dtype, device=labels.device) + common_cols = min(labels.size(1), max_cap) + if common_cols > 0: + labels_aligned[:, :common_cols] = labels[:, :common_cols] + + per_cap_int = torch.clamp(per_cap.to(dtype=torch.long), min=0, max=max_cap) + for row in range(batch_size): + cap = int(per_cap_int[row].item()) + if cap < max_cap: + labels_aligned[row, cap:] = -100 + + supervised = labels_aligned != -100 + if num_special > 0: + supervised = supervised & (labels_aligned >= num_special) + if not supervised.any(): + continue + + preds = logits.argmax(dim=-1) + codon_correct += int((preds[supervised] == labels_aligned[supervised]).sum().item()) + codon_total += int(supervised.sum().item()) + + if codon2aa and isinstance(batch, dict) and "protein_seqs" in batch: + prot_list = batch.get("protein_seqs", []) + for row in range(batch_size): + cap = int(per_cap_int[row].item()) + if cap <= 0: + continue + mask_row = supervised[row, :cap] + if not mask_row.any(): + continue + preds_row = preds[row, :cap][mask_row] + prot = prot_list[row] if (isinstance(prot_list, list) and row < len(prot_list)) else "" + if not prot: + continue + seq_len = min(len(prot), preds_row.size(0)) + if seq_len <= 0: + continue + pred_aa = ''.join(codon2aa.get(int(t.item()), 'X') for t in preds_row[:seq_len]) + truth_aa = prot[:seq_len] + aa_correct += sum(1 for i in range(seq_len) if pred_aa[i] == truth_aa[i]) + aa_total += seq_len + + b_idx += 1 + + totals = torch.tensor( + [loss_sum, loss_tokens, codon_correct, codon_total, aa_correct, aa_total], + dtype=torch.float64, + device=self.device, + ) + if dist.is_available() and dist.is_initialized(): + # Ensure every rank has finished its forward passes before the final + # metric reduction, otherwise FSDP may still be issuing _all_gather + # collectives on slower ranks. 
+ self._barrier() + dist.all_reduce(totals, op=dist.ReduceOp.SUM) + + loss_sum, loss_tokens, codon_correct, codon_total, aa_correct, aa_total = totals.tolist() + + self.model.train() + + metrics: Dict[str, float] = {"eval_loss": float(loss_sum) / loss_tokens if loss_tokens > 0 else 0.0} + if codon_total > 0: + metrics["eval_codon_acc"] = float(codon_correct) / codon_total + if aa_total > 0: + metrics["eval_aa_acc"] = float(aa_correct) / aa_total + + self._barrier() + return metrics + + # ---- internals ---- + def _setup_fsdp(self): + # Ensure default process group is initialized (required by FSDP) + device = self.device + if dist.is_available() and not dist.is_initialized(): + backend = "nccl" if device.type == "cuda" else "gloo" + sig = inspect.signature(dist.init_process_group) + if "timeout" in sig.parameters: + dist.init_process_group(backend=backend, init_method="env://", timeout=datetime.timedelta(minutes=30)) + else: + dist.init_process_group(backend=backend, init_method="env://") + mp = MixedPrecision( + param_dtype=(torch.float16 if self.args.fp16 else torch.bfloat16 if self.args.bf16 else torch.float32), + reduce_dtype=(torch.float16 if self.args.fp16 else torch.bfloat16 if self.args.bf16 else torch.float32), + buffer_dtype=torch.float32, + ) + logger.info(f"FSDP enabled: sharding={self.args.fsdp} mp_param={mp.param_dtype} mp_reduce={mp.reduce_dtype}") + # Keep frozen ESM off FSDP if present + base = self._unwrap(self.model) + ignored = [] + if hasattr(base, "esm") and isinstance(base.esm, nn.Module): + ignored.append(base.esm) + + self.model = FSDP( + self.model, + device_id=(self.device if device.type == "cuda" else None), + sharding_strategy=ShardingStrategy.FULL_SHARD, + mixed_precision=mp, + ignored_modules=(ignored if ignored else None), + sync_module_states=True, + ) + + # Place ignored module on device exactly once + if ignored: + ignored[0].to(device) + + def _unwrap(self, module): + return getattr(module, "module", module) + + def _prepare_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]: + # Species embeddings (fixed-size or sequence) + if self.species_store is not None and "species_ids" in batch: + sids = batch["species_ids"] + if torch.is_tensor(sids): + sids = sids.detach().cpu().tolist() + result = self.species_store.batch_get(sids) + if isinstance(result, tuple): + sp_tok, _ = result # [B, Ls, Ds] + batch["species_tok_emb"] = sp_tok.to(self.device, non_blocking=True) + else: + sp = result # [B, Ds] + batch["species_emb"] = sp.to(self.device, non_blocking=True) + + # Move obvious tensors + if "codon_ids" in batch and hasattr(batch["codon_ids"], "to"): + batch["codon_ids"] = batch["codon_ids"].to(self.device, non_blocking=True) + + return batch + + def _build_cond(self, batch: Dict[str, Any]) -> Dict[str, Any]: + cond: Dict[str, Any] = {"control_mode": "fixed"} + if "species_tok_emb" in batch: + cond["species_tok_emb_src"] = batch["species_tok_emb"] + cond["species_tok_emb_tgt"] = batch["species_tok_emb"] + elif "species_emb" in batch: + cond["species_emb_src"] = batch["species_emb"] + cond["species_emb_tgt"] = batch["species_emb"] + if "protein_seqs" in batch: + cond["protein_seqs"] = batch["protein_seqs"] + return cond + + def _next_batch_sync(self, iterator): + """Fetch next batch and drop out early if any rank exhausts its loader.""" + try: + batch = next(iterator) + local_has_batch = True + except StopIteration: + batch = None + local_has_batch = False + + distributed = dist.is_available() and dist.is_initialized() + has_batch = local_has_batch + + if 
distributed: + flag_device = self.device if self.device.type == "cuda" else torch.device("cpu") + flag = torch.tensor([1 if local_has_batch else 0], device=flag_device) + dist.all_reduce(flag, op=dist.ReduceOp.MIN) + has_batch = bool(flag.item()) + + if not has_batch: + return None, False, local_has_batch + + return batch, True, local_has_batch + + def _is_main(self) -> bool: + return (not dist.is_available()) or (not dist.is_initialized()) or dist.get_rank() == 0 + + def _barrier(self): + if dist.is_available() and dist.is_initialized(): + # On NCCL, pass device_ids to avoid rank↔GPU mapping ambiguity when supported + if self.device.type == "cuda": + sig = inspect.signature(dist.barrier) + if "device_ids" in sig.parameters: + dist.barrier(device_ids=[self.local_rank]) + return + dist.barrier() + + def _max_cuda_peak_gb(self) -> Tuple[float, float]: + if self.device.type != "cuda" or not torch.cuda.is_available(): + return 0.0, 0.0 + vals = torch.tensor( + [ + float(torch.cuda.max_memory_allocated(self.device)), + float(torch.cuda.max_memory_reserved(self.device)), + ], + dtype=torch.float64, + device=self.device, + ) + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(vals, op=dist.ReduceOp.MAX) + scale = float(1024 ** 3) + return float(vals[0].item() / scale), float(vals[1].item() / scale) + + # (Per-sample quick eval removed; evaluation now uses held-out dataloader.) + + def _epoch_for_logging(self) -> int: + steps_per_epoch = int(getattr(self.args, "steps_per_epoch", 0) or 0) + if steps_per_epoch > 0: + est = self.state.get("global_step", 0) // steps_per_epoch + if self.args.num_train_epochs > 0: + max_epoch = max(int(self.args.num_train_epochs) - 1, 0) + if est > max_epoch: + return max_epoch + return int(est) + return int(self.state.get("epoch", 0)) + + # ---- checkpointing ---- + def _save_checkpoint(self, name: str): + self.state["epoch"] = int(self._epoch_for_logging()) + # All ranks participate in FSDP state_dict collectives; only rank0 writes files. + out_dir = os.path.join(self.args.output_dir, name) + os.makedirs(out_dir, exist_ok=True) + + optim_state = None + if isinstance(self.model, FSDP): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=FutureWarning) + with FSDP.state_dict_type( + self.model, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(rank0_only=True, offload_to_cpu=True), + FullOptimStateDictConfig(rank0_only=True, offload_to_cpu=True), + ): + state = self.model.state_dict() + # NOTE: Under FSDP, optimizer.state_dict() is sharded per-rank. + # Use FSDP.optim_state_dict() to materialize a full optimizer state dict (rank0_only). 
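+                    # All ranks participate in this collective call; with rank0_only=True
+                    # only rank 0 ends up holding the assembled optimizer state dict.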
+ if self.optimizer is not None: + optim_state = FSDP.optim_state_dict(self.model, self.optimizer) + else: + state = self._unwrap(self.model).state_dict() + if self.optimizer is not None: + optim_state = self.optimizer.state_dict() + + # Save minimal data cursor (total samples yielded so far) next to output_dir if configured + per_rank_positions: Optional[List[int]] = None + p = getattr(self.args, "data_cursor_path", None) + if p: + ds = getattr(self.train_dataloader, "dataset", None) + if hasattr(ds, "get_stream_position"): + local_pos = int(ds.get_stream_position()) + if dist.is_available() and dist.is_initialized(): + gather_device = self.device if self.device.type == "cuda" else torch.device("cpu") + tensor = torch.tensor([local_pos], dtype=torch.long, device=gather_device) + gathered = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())] + dist.all_gather(gathered, tensor) + per_rank_positions = [int(t.item()) for t in gathered] + else: + per_rank_positions = [local_pos] + + if not self._is_main(): + # Non-main ranks skip serialization but stay in lockstep + self._barrier() + return + + # Rank 0 writes artifacts + save_file(state, os.path.join(out_dir, "model.safetensors")) + + # Optimizer + scheduler + if optim_state is not None: + torch.save(optim_state, os.path.join(out_dir, "optimizer.pt")) + if self.lr_scheduler is not None: + torch.save(self.lr_scheduler.state_dict(), os.path.join(out_dir, "scheduler.pt")) + + # Trainer config/state + base = self._unwrap(self.model) + # Infer mlp_ratio from first block if present + mlp_ratio = 4.0 + try: + if hasattr(base, "blocks") and len(getattr(base, "blocks", [])) > 0: + w1 = base.blocks[0].ffn.w1.weight # [H*mlp, H] + H = int(getattr(base, "hidden_size", w1.shape[1])) + if H > 0: + mlp_ratio = float(w1.shape[0]) / float(H) + except Exception: + pass + + trainer_cfg = { + # capacity / prefixes + "max_length": int(self.args.max_length), + "max_species_prefix": int(getattr(base, "max_species_prefix", 0)), + "max_protein_prefix": int(getattr(base, "max_protein_prefix", 0)), + + # architecture hints + "hidden_size": int(getattr(base, "hidden_size", 0)), + "num_hidden_layers": int(getattr(base, "num_layers", 0)), + "num_attention_heads": int(getattr(base, "num_heads", 0)), + "mlp_ratio": float(mlp_ratio), + + # conditioning flags + "prepend_species": bool(getattr(base, "prepend_species", True)), + "prepend_protein": bool(getattr(base, "prepend_protein", False)), + "species_embedding_dim": int(getattr(base, "species_embedding_dim", 1024)), + + # ESM info (even if prepend_protein=False) + "esm_model_name": str(getattr(self.args, "esm_model_name", "")), + "esm_device": str(getattr(self.args, "esm_device", "cuda")), + "esm_dtype": str(getattr(self.args, "esm_dtype", "fp32")).lower(), + + # kernels + + # attention impl + "attn_impl": str(getattr(base, "attn_impl", "gqa")), + "num_kv_groups": int(getattr(base, "num_kv_groups", 0)), + } + with open(os.path.join(out_dir, "trainer_config.json"), "w") as f: + json.dump(trainer_cfg, f, indent=2) + with open(os.path.join(out_dir, "trainer_state.json"), "w") as f: + json.dump({"epoch": self.state["epoch"], "global_step": self.state["global_step"]}, f, indent=2) + + if p and per_rank_positions is not None: + payload = { + "skip_samples": int(sum(per_rank_positions)), + "per_rank": per_rank_positions, + "world_size": len(per_rank_positions), + } + os.makedirs(os.path.dirname(os.path.abspath(p)), exist_ok=True) + with open(p, "w") as f: + json.dump(payload, f) + + # Tokenizer vocab for sampling + 
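# (save_vocabulary() writes vocab.json next to the checkpoint so the tokenizer can be rebuilt via from_pretrained() at sampling time.) +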
try: + if self.tokenizer is not None and hasattr(self.tokenizer, "save_vocabulary"): + self.tokenizer.save_vocabulary(out_dir) + except Exception as e: + logger.warning(f"Failed to save vocabulary to {out_dir}: {e}") + + self._prune_checkpoints(self.args.output_dir, self.args.save_total_limit) + logger.info(f"Saved checkpoint → {out_dir}") + + # Release other ranks + self._barrier() + + def _resume_from(self, ckpt_dir: str): + st_path = os.path.join(ckpt_dir, "model.safetensors") + if not os.path.exists(st_path): + raise FileNotFoundError(f"No model.safetensors in {ckpt_dir}") + state = load_file(st_path) + + if isinstance(self.model, FSDP): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=FutureWarning) + with FSDP.state_dict_type( + self.model, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(rank0_only=False, offload_to_cpu=True), + ): + self.model.load_state_dict(state, strict=False) + else: + self._unwrap(self.model).load_state_dict(state, strict=False) + + + scheduler_restored = False + + opt_path = os.path.join(ckpt_dir, "optimizer.pt") + if os.path.exists(opt_path): + if self.optimizer is None: + self._create_optimizer_and_scheduler() + if not self.args.override_lr_on_resume: + loaded = torch.load(opt_path, map_location="cpu") + # Under FSDP, saved optimizer.pt is a full optimizer state dict produced by + # FSDP.optim_state_dict(). Convert it to a per-rank state dict before loading. + if isinstance(self.model, FSDP): + try: + loaded = FSDP.optim_state_dict_to_load(self.model, self.optimizer, loaded) + except Exception as e: + msg = ( + "Failed to convert FSDP optimizer state dict for loading. " + "This checkpoint likely contains an incomplete (rank0-only sharded) optimizer.pt from an older version. " + "Full optimizer resume is not possible from this checkpoint.\n" + f"Underlying error: {e}\n" + "Options:\n" + " 1) Start a fresh run (new --output_dir), or\n" + " 2) Re-run with --override_lr_on_resume to skip optimizer restore (not a full resume)." + ) + if self._is_main(): + logger.error(msg) + raise RuntimeError(msg) from e + self.optimizer.load_state_dict(loaded) + + sch_path = os.path.join(ckpt_dir, "scheduler.pt") + if os.path.exists(sch_path): + if self.lr_scheduler is None: + self._create_optimizer_and_scheduler() + if self.lr_scheduler is not None and not self.args.override_lr_on_resume: + self.lr_scheduler.load_state_dict(torch.load(sch_path, map_location="cpu")) + scheduler_restored = True + + ts_path = os.path.join(ckpt_dir, "trainer_state.json") + if os.path.exists(ts_path): + with open(ts_path, "r") as f: + ts = json.load(f) + self.state["epoch"] = int(ts.get("epoch", 0)) + self.state["global_step"] = int(ts.get("global_step", 0)) + + steps_per_epoch = int(getattr(self.args, "steps_per_epoch", 0) or 0) + if steps_per_epoch > 0: + inferred_epoch = self.state.get("global_step", 0) // steps_per_epoch + num_epochs = max(int(self.args.num_train_epochs), 1) + inferred_epoch = min(inferred_epoch, num_epochs - 1) + if inferred_epoch != self.state.get("epoch"): + if self._is_main(): + logger.info( + "Adjusting epoch from %s to %s based on global_step %s and steps_per_epoch %s", + self.state.get("epoch"), + inferred_epoch, + self.state.get("global_step"), + steps_per_epoch, + ) + self.state["epoch"] = int(inferred_epoch) + + # If we skipped loading the scheduler state (e.g., different world size or override), + # fast-forward it to the saved global_step so LR does not restart from warmup. 
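+        # Calling lr_scheduler.step(n) with an explicit index is how LambdaLR-style schedulers
+        # jump ahead; the TypeError fallback below covers schedulers that reject an argument.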
+ if self.lr_scheduler is not None and not scheduler_restored: + target_step = int(self.state.get("global_step", 0)) + if target_step > 0: + try: + # Most schedulers (LambdaLR, CosineAnnealing, etc.) accept an "epoch" kwarg. + self.lr_scheduler.step(target_step) + except TypeError: + # Fallback: advance manually. + for _ in range(target_step): + self.lr_scheduler.step() + # Ensure optimizer LR reflects the scheduler's current value. + try: + last_lrs = self.lr_scheduler.get_last_lr() + except Exception: + last_lrs = [group.get("lr") for group in self.optimizer.param_groups] + if last_lrs: + for group, lr in zip(self.optimizer.param_groups, last_lrs): + group["lr"] = float(lr) + + logger.info(f"Resumed from {ckpt_dir}") + + def _checkpoint_step(self, path: str) -> Optional[int]: + m = re.fullmatch(r"checkpoint-(\d+)", os.path.basename(path)) + if not m: + return None + return int(m.group(1)) + + def _prune_checkpoints(self, root: str, keep: int): + if not os.path.isdir(root): + return + + try: + subdirs = [ + os.path.join(root, d) + for d in os.listdir(root) + if os.path.isdir(os.path.join(root, d)) + ] + except FileNotFoundError: + return + + step_dirs: list[tuple[int, str]] = [] + for path in subdirs: + step = self._checkpoint_step(path) + if step is not None: + step_dirs.append((step, path)) + + if not step_dirs: + return + + step_dirs.sort(key=lambda item: item[0]) + latest_step = step_dirs[-1][0] + + recent_window = max(0, int(getattr(self.args, "ckpt_recent_window_steps", 0) or 0)) + recent_interval = max(0, int(getattr(self.args, "ckpt_recent_interval", 0) or 0)) + archive_interval = max(0, int(getattr(self.args, "ckpt_archive_interval", 0) or 0)) + + keep_paths: set[str] = set() + if recent_window > 0 and (recent_interval > 0 or archive_interval > 0): + if recent_interval <= 0: + recent_interval = max(1, int(getattr(self.args, "save_steps", 1) or 1)) + + for step, path in step_dirs: + age = latest_step - step + if age <= recent_window: + interval = recent_interval + else: + interval = archive_interval + if interval > 0 and (step % interval == 0): + keep_paths.add(path) + + if not keep_paths: + # Legacy fallback: keep the most recent N step checkpoints. + if keep <= 0: + return + keep_paths = {path for _, path in step_dirs[-keep:]} + else: + # Always preserve the newest checkpoint, even if the interval math misses it. + keep_paths.add(step_dirs[-1][1]) + if keep > 0: + kept = [(step, path) for step, path in step_dirs if path in keep_paths] + if len(kept) > keep: + trim = len(kept) - keep + for _, path in kept[:trim]: + keep_paths.discard(path) + + removed = [] + for _, path in step_dirs: + if path in keep_paths: + continue + shutil.rmtree(path, ignore_errors=True) + removed.append(os.path.basename(path)) + + if removed and self._is_main(): + logger.info( + "Pruned %s checkpoints (latest_step=%s, recent_window=%s, recent_interval=%s, archive_interval=%s)", + len(removed), + latest_step, + recent_window, + recent_interval, + archive_interval, + ) diff --git a/train.py b/train.py new file mode 100644 index 0000000000000000000000000000000000000000..9fbae8db30912511c2aba807c03fe09296bef300 --- /dev/null +++ b/train.py @@ -0,0 +1,352 @@ +#!/usr/bin/env python +""" +Minimal, honest training script for CodonGPT on CSV data. + +- Species conditioning: REQUIRED (precomputed embeddings) +- Protein conditioning (ESM-C): ENABLED BY DEFAULT. Disable with --no_protein. +- Global capacity is controlled by --max_length (prefix + start + codon). 
+""" + +import os +import math +import argparse +import logging +import torch + +from src import CodonGPT, CodonTokenizer, Trainer, TrainingArguments +from src.dataset import create_precomputed_dataloaders, SpeciesEmbeddingStore + +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, +) +logger = logging.getLogger("codongpt.train") + +def _describe_sdp_kernels() -> None: + # Log the enabled SDPA backends (Flash/MemEff/Math) without raising on older PyTorch + flash = None; mem_eff = None; mathk = None + if hasattr(torch, 'backends') and hasattr(torch.backends, 'cuda'): + tbc = torch.backends.cuda + if hasattr(tbc, 'flash_sdp_enabled'): + flash = tbc.flash_sdp_enabled() + if hasattr(tbc, 'mem_efficient_sdp_enabled'): + mem_eff = tbc.mem_efficient_sdp_enabled() + if hasattr(tbc, 'math_sdp_enabled'): + mathk = tbc.math_sdp_enabled() + logger.info(f"SDP kernels: flash={flash} mem_efficient={mem_eff} math={mathk}") + +def _print_model_size(model: torch.nn.Module, bf16: bool, fp16: bool) -> None: + total = sum(p.numel() for p in model.parameters()) + trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) + w_bytes = 2 if (bf16 or fp16) else 4 + opt_bytes = 8 # Adam moments in FP32 + weights_gb = total * w_bytes / (1024**3) + opt_gb = trainable * opt_bytes / (1024**3) + logger.info( + f"Model params: total={total:,} trainable={trainable:,} (~{weights_gb:.2f} GB weights, ~{opt_gb:.2f} GB optimizer)" + ) + +def _speed_toggles(): + if hasattr(torch.backends, "cuda") and hasattr(torch.backends.cuda, "matmul"): + torch.backends.cuda.matmul.allow_tf32 = True + if hasattr(torch, "set_float32_matmul_precision"): + torch.set_float32_matmul_precision("high") + if hasattr(torch.backends, "cudnn") and hasattr(torch.backends.cudnn, "benchmark"): + torch.backends.cudnn.benchmark = True + + +def parse_args(): + p = argparse.ArgumentParser(description="Train CodonGPT on CSV data") + # Data (CSV path or Parquet glob/dir) + p.add_argument("--train_data", type=str, default="random_sample_1000.csv", + help="Training data: CSV file or Parquet glob/dir (e.g., ./data/train_shards/*.parquet)") + p.add_argument("--val_data", type=str, default=None, + help="Validation data: CSV file or Parquet glob/dir") + p.add_argument("--embeddings_dir", type=str, default="embeddings", + help="Dir with species embeddings (species_vocab.json, *.bin/memmap)") + + # Model / capacity + p.add_argument("--hidden", type=int, default=750, help="Model hidden size") + p.add_argument("--layers", type=int, default=20, help="Number of transformer layers") + p.add_argument("--heads", type=int, default=15, help="Number of attention heads") + p.add_argument("--attn", type=str, choices=["mha", "gqa"], default="gqa", help="Attention implementation: 'mha' or 'gqa'") + p.add_argument("--num_kv_groups", type=int, default=5, help="GQA: number of KV groups (0 = default/no grouping)") + p.add_argument("--mlp_ratio", type=float, default=3.2, help="FFN expansion ratio (mlp hidden = ratio * hidden)") + p.add_argument("--max_length", type=int, default=2048, + help="Global max length (prefix + start + codon)") + p.add_argument("--max_species_prefix", type=int, default=0, + help="Cap species prefix tokens (0 = uncapped)") + p.add_argument("--max_protein_prefix", type=int, default=1024, + help="Cap protein prefix tokens (0 = uncapped)") + + # Protein conditioning: always enabled (ESM-C) + + # Training + p.add_argument("--output_dir", type=str, 
default="checkpoints", help="Where to save checkpoints") + p.add_argument("--epochs", type=int, default=1, help="Number of training epochs") + p.add_argument("--batch_size", type=int, default=20, help="Per-device train batch size") + p.add_argument("--eval_batch_size", type=int, default=32, help="Per-device eval batch size") + p.add_argument("--workers", type=int, default=4, help="DataLoader workers") + p.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation steps") + p.add_argument("--train_shuffle_buffer", type=int, default=0, + help="Streaming shuffle buffer for training (set 0 when data is pre-shuffled)") + p.add_argument("--val_shuffle_buffer", type=int, default=0, + help="Streaming shuffle buffer for validation (0 disables)") + p.add_argument("--csv_chunksize", type=int, default=200_000, + help="Pandas read_csv chunksize for CSV inputs") + + # Optim / schedule + p.add_argument("--lr", type=float, default=1e-4, help="Learning rate") + p.add_argument("--warmup_ratio", type=float, default=0.1, help="Warmup ratio for LR schedule (0.0-1.0)") + p.add_argument( + "--lr_scheduler", + type=str, + choices=["linear", "cosine", "constant"], + default="linear", + help="LR schedule applied after warmup; 'linear' decays to zero by the end of training", + ) + p.add_argument("--weight_decay", type=float, default=1e-3, help="Weight decay") + p.add_argument("--adam_beta1", type=float, default=0.9, + help="Adam beta1 (momentum) coefficient") + p.add_argument("--adam_beta2", type=float, default=0.95, + help="Adam beta2 (squared-gradient) coefficient") + p.add_argument("--logging_steps", type=int, default=20, help="Logging interval (steps)") + p.add_argument("--save_steps", type=int, default=10, help="Save every N steps (0 disables step-saving)") + p.add_argument("--save_total_limit", type=int, default=10, help="Keep at most N recent checkpoints") + p.add_argument("--ckpt_recent_window_steps", type=int, default=0, + help="If >0, keep finer-grained checkpoints within this many recent steps") + p.add_argument("--ckpt_recent_interval", type=int, default=0, + help="Retention interval inside the recent checkpoint window (0 disables custom retention)") + p.add_argument("--ckpt_archive_interval", type=int, default=0, + help="Retention interval for checkpoints older than the recent window (0 prunes them)") + p.add_argument("--max_steps", type=int, default=-1, + help="Total training steps. 
REQUIRED for streaming (IterableDataset)") + p.add_argument("--steps_per_epoch", type=int, default=0, + help="For streaming datasets: shape LR schedule as epochs*steps_per_epoch when max_steps<0") + p.add_argument("--max_grad_norm", type=float, default=1.0, + help="Clip gradients to this global L2 norm; set <=0 to disable") + p.add_argument("--override_lr_on_resume", action="store_true", + help="Do not restore LR/optimizer state on resume (keep current lr)") + + # Resume + p.add_argument("--resume_from", type=str, default=None, + help="Path to checkpoint dir to resume from; pass 'auto' to pick latest in output_dir") + + # Evaluation scheduling + p.add_argument("--eval_interval", type=int, default=0, + help="Run evaluation every N optimizer steps on --val_data (0 disables)") + p.add_argument("--eval_steps", type=int, default=5000, + help="For streaming eval datasets: limit to this many batches (0 = full eval)") + + # Hardware / precision + p.add_argument("--device", type=str, default="cuda", help="cuda or cpu") + p.add_argument("--bf16", action="store_true", help="bfloat16 mixed precision") + p.add_argument("--fp16", action="store_true", help="float16 mixed precision") + p.add_argument("--fsdp", action="store_true", help="Enable FSDP full sharding") + p.add_argument("--grad_ckpt", action="store_true", help="Enable gradient checkpointing") + return p.parse_args() + + +def main(): + args = parse_args() + _speed_toggles() + + if args.device == "cuda" and not torch.cuda.is_available(): + logger.warning("CUDA not available; switching to CPU") + args.device = "cpu" + + # Tokenizer + tok = CodonTokenizer() + # Ensure output dir exists and persist vocab.json (used by sampler) + os.makedirs(os.path.abspath(args.output_dir), exist_ok=True) + tok.save_vocabulary(args.output_dir) + + + # Data first — we need Ds for species embeddings + train_loader, val_loader, species_store = create_precomputed_dataloaders( + train_path=args.train_data, + val_path=args.val_data, + embeddings_dir=args.embeddings_dir, + tokenizer=tok, + batch_size=args.batch_size, + num_workers=args.workers, + species_pooling="sequence", # prefer variable-length token sequence if available + csv_chunksize=int(args.csv_chunksize), + train_shuffle_buffer=int(args.train_shuffle_buffer), + val_shuffle_buffer=int(args.val_shuffle_buffer), + ) + + # Estimate steps_per_epoch for streaming schedule shaping if not provided + steps_per_epoch = int(getattr(args, "steps_per_epoch", 0) or 0) + total_rows = 0 + paths: list[str] = [] + if steps_per_epoch <= 0 and int(args.max_steps) < 0: + def _expand_paths(maybe: str | list[str]) -> list[str]: + import glob as _glob + from pathlib import Path as _Path + paths: list[str] = [] + if isinstance(maybe, str): + p = _Path(maybe) + if p.is_dir(): + paths.extend(sorted(str(x) for x in p.rglob("*.parquet"))) + else: + paths = sorted(_glob.glob(str(p))) + else: + for it in maybe: + paths.extend(_expand_paths(it)) + # de-dup + seen = set(); out = [] + for x in paths: + if x not in seen: + out.append(x); seen.add(x) + return out + + paths = _expand_paths(args.train_data) + if paths: + try: + import pyarrow.parquet as pq + for fp in paths: + if fp.lower().endswith((".parquet", ".parq")): + pf = pq.ParquetFile(fp) + md = pf.metadata + if md is not None: + total_rows += int(md.num_rows) + except Exception: + # Fallback: keep steps_per_epoch at 0 if pyarrow not available + total_rows = 0 + if total_rows > 0: + world = int(os.environ.get("WORLD_SIZE", "1")) + ga = max(1, int(getattr(args, "grad_accum", 1))) + denom = 
max(1, int(args.batch_size) * max(1, world) * ga) +            steps_per_epoch = max(1, math.ceil(total_rows / denom)) +            logger.info(f"Estimated steps_per_epoch={steps_per_epoch} from {len(paths)} parquet files, total_rows={total_rows}") + +    world = int(os.environ.get("WORLD_SIZE", "1")) +    grad_accum = max(1, int(getattr(args, "grad_accum", 1))) +    effective_global_batch = int(args.batch_size) * max(1, world) * grad_accum +    logger.info( +        "Batch config: per_device_train_batch=%s per_device_eval_batch=%s world_size=%s grad_accum=%s effective_global_batch=%s", +        args.batch_size, +        args.eval_batch_size, +        world, +        grad_accum, +        effective_global_batch, +    ) + +    # Resolve per-process CUDA device for ESM (avoid defaulting to cuda:0 on all ranks) +    esm_dev = "cpu" +    if args.device == "cuda" and torch.cuda.is_available(): +        local_rank = int(os.environ.get("LOCAL_RANK", "0")) +        esm_dev = f"cuda:{local_rank}" + +    # Model: species and protein conditioning (ESM-C prefix) are always enabled +    model = CodonGPT( +        vocab_size=tok.vocab_size, +        num_special_tokens=tok.num_special_tokens, +        special_ids=tok.special_ids, +        hidden_size=args.hidden, +        num_layers=args.layers, +        num_heads=args.heads, +        mlp_ratio=float(args.mlp_ratio), +        max_position_embeddings=args.max_length, +        prepend_species=True, +        prepend_protein=True, +        esm_model_name="esmc_300m", +        esm_device=esm_dev, +        max_protein_prefix=int(args.max_protein_prefix), +        max_species_prefix=int(args.max_species_prefix), +        dropout=0.1, +        species_embedding_dim=int(species_store.Ds()), +        attn_impl=str(args.attn), +        num_kv_groups=int(args.num_kv_groups), +    ) + +    # Report model size and SDPA (Flash) kernel configuration +    _print_model_size(model, bf16=bool(args.bf16), fp16=bool(args.fp16)) +    _describe_sdp_kernels() + +    # Trainer args +    targs = TrainingArguments( +        output_dir=args.output_dir, +        save_steps=args.save_steps, +        save_total_limit=int(args.save_total_limit), +        ckpt_recent_window_steps=int(args.ckpt_recent_window_steps), +        ckpt_recent_interval=int(args.ckpt_recent_interval), +        ckpt_archive_interval=int(args.ckpt_archive_interval), +        num_train_epochs=args.epochs, +        max_steps=int(args.max_steps), +        gradient_accumulation_steps=int(args.grad_accum), +        warmup_ratio=float(args.warmup_ratio), +        lr_scheduler_type=str(args.lr_scheduler), +        per_device_train_batch_size=args.batch_size, +        per_device_eval_batch_size=args.eval_batch_size, +        dataloader_num_workers=args.workers, +        learning_rate=args.lr, +        weight_decay=args.weight_decay, +        adam_beta1=float(args.adam_beta1), +        adam_beta2=float(args.adam_beta2), +        max_grad_norm=float(args.max_grad_norm), +        logging_steps=args.logging_steps, +        override_lr_on_resume=bool(args.override_lr_on_resume), +        data_cursor_path=os.path.join(os.path.abspath(args.output_dir), "data_cursor.json"), +        fp16=bool(args.fp16), +        bf16=bool(args.bf16), +        fsdp=("full_shard" if args.fsdp else None), +        gradient_checkpointing=bool(args.grad_ckpt), +        max_length=int(args.max_length), +        esm_model_name="esmc_300m", +        esm_device=esm_dev, +        esm_dtype=("bf16" if args.bf16 else ("fp16" if args.fp16 else "fp32")), +        # sampling eval +        eval_interval=int(args.eval_interval), +        eval_steps=int(args.eval_steps), +        steps_per_epoch=int(steps_per_epoch), +    ) + +    # Resolve auto-resume if requested +    resume_path = None +    if args.resume_from: +        if args.resume_from == "auto": +            root = os.path.abspath(args.output_dir) +            if os.path.isdir(root): +                try: +                    subdirs = [] +                    for d in os.listdir(root): +                        path = os.path.join(root, d) +                        if not os.path.isdir(path): +                            continue +                        if not ( +                            d == 
"final_model" or + d.startswith("checkpoint-") + ): + continue + if not ( + os.path.exists(os.path.join(path, "model.safetensors")) or + os.path.exists(os.path.join(path, "pytorch_model.bin")) + ): + continue + subdirs.append(path) + subdirs.sort(key=lambda d: os.path.getmtime(d), reverse=True) + resume_path = subdirs[0] if subdirs else None + except Exception: + resume_path = None + else: + resume_path = args.resume_from + + trainer = Trainer( + model=model, + args=targs, + tokenizer=tok, + species_store=species_store, + resume_from_checkpoint=resume_path, + ) + trainer.attach_dataloaders(train_loader, val_loader) + + logger.info("Starting training...") + trainer.train() + logger.info("Training finished.") + + +if __name__ == "__main__": + main() diff --git a/training_checkpoints/checkpoint-71000/config.json b/training_checkpoints/checkpoint-71000/config.json new file mode 100644 index 0000000000000000000000000000000000000000..6cc9978ec27383e958dfa2b8653d337c382322af --- /dev/null +++ b/training_checkpoints/checkpoint-71000/config.json @@ -0,0 +1,17 @@ +{ + "max_length": 2048, + "max_species_prefix": 0, + "max_protein_prefix": 1024, + "hidden_size": 750, + "num_hidden_layers": 20, + "num_attention_heads": 15, + "mlp_ratio": 3.2, + "prepend_species": true, + "prepend_protein": true, + "species_embedding_dim": 1024, + "esm_model_name": "esmc_300m", + "esm_device": "cuda:0", + "esm_dtype": "bf16", + "attn_impl": "mha", + "num_kv_groups": 5 +} \ No newline at end of file diff --git a/training_checkpoints/checkpoint-71000/model.safetensors b/training_checkpoints/checkpoint-71000/model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..65b479f85fb07898fc19b4a228d8cf97f79bab70 --- /dev/null +++ b/training_checkpoints/checkpoint-71000/model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:07bc223f4d934e2baff5a8085a78348766b6a8324aa091a1459fce2b2c6d3837 +size 1284544520 diff --git a/training_checkpoints/checkpoint-71000/optimizer.pt b/training_checkpoints/checkpoint-71000/optimizer.pt new file mode 100644 index 0000000000000000000000000000000000000000..a47de35d78b8eca2b01d22c0dcccea9f1b866fbf --- /dev/null +++ b/training_checkpoints/checkpoint-71000/optimizer.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:751570fed64f000a53218f2c9a7e47a4503a302760f1c0d6b52b63ce4a25cec8 +size 1237115851 diff --git a/training_checkpoints/checkpoint-71000/scheduler.pt b/training_checkpoints/checkpoint-71000/scheduler.pt new file mode 100644 index 0000000000000000000000000000000000000000..85116e69b0cd5d20493bbec8a9fcc3e07b38d34c --- /dev/null +++ b/training_checkpoints/checkpoint-71000/scheduler.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bdca58db103d9ad6aba34334e8a03e08e780b7fe95ef0677f2519e7b16023ff8 +size 1465 diff --git a/training_checkpoints/checkpoint-71000/trainer_config.json b/training_checkpoints/checkpoint-71000/trainer_config.json new file mode 100644 index 0000000000000000000000000000000000000000..6cc9978ec27383e958dfa2b8653d337c382322af --- /dev/null +++ b/training_checkpoints/checkpoint-71000/trainer_config.json @@ -0,0 +1,17 @@ +{ + "max_length": 2048, + "max_species_prefix": 0, + "max_protein_prefix": 1024, + "hidden_size": 750, + "num_hidden_layers": 20, + "num_attention_heads": 15, + "mlp_ratio": 3.2, + "prepend_species": true, + "prepend_protein": true, + "species_embedding_dim": 1024, + "esm_model_name": "esmc_300m", + "esm_device": "cuda:0", + "esm_dtype": "bf16", + 
"attn_impl": "mha", + "num_kv_groups": 5 +} \ No newline at end of file diff --git a/training_checkpoints/checkpoint-71000/trainer_state.json b/training_checkpoints/checkpoint-71000/trainer_state.json new file mode 100644 index 0000000000000000000000000000000000000000..043df067d4643ba0278770e8db814d9313342f6c --- /dev/null +++ b/training_checkpoints/checkpoint-71000/trainer_state.json @@ -0,0 +1,4 @@ +{ + "epoch": 2, + "global_step": 71000 +} \ No newline at end of file diff --git a/training_checkpoints/checkpoint-71000/vocab.json b/training_checkpoints/checkpoint-71000/vocab.json new file mode 100644 index 0000000000000000000000000000000000000000..01c8e6b032b471eac80fe56cf65ee81e62f49921 --- /dev/null +++ b/training_checkpoints/checkpoint-71000/vocab.json @@ -0,0 +1,78 @@ +{ + "special_token_str": { + "bos": "", + "eos": "", + "pad": "", + "unk": "" + }, + "vocab": { + "": 2, + "": 0, + "": 3, + "": 1, + "AAA": 4, + "AAC": 5, + "AAG": 6, + "AAT": 7, + "ACA": 8, + "ACC": 9, + "ACG": 10, + "ACT": 11, + "AGA": 12, + "AGC": 13, + "AGG": 14, + "AGT": 15, + "ATA": 16, + "ATC": 17, + "ATG": 18, + "ATT": 19, + "CAA": 20, + "CAC": 21, + "CAG": 22, + "CAT": 23, + "CCA": 24, + "CCC": 25, + "CCG": 26, + "CCT": 27, + "CGA": 28, + "CGC": 29, + "CGG": 30, + "CGT": 31, + "CTA": 32, + "CTC": 33, + "CTG": 34, + "CTT": 35, + "GAA": 36, + "GAC": 37, + "GAG": 38, + "GAT": 39, + "GCA": 40, + "GCC": 41, + "GCG": 42, + "GCT": 43, + "GGA": 44, + "GGC": 45, + "GGG": 46, + "GGT": 47, + "GTA": 48, + "GTC": 49, + "GTG": 50, + "GTT": 51, + "TAA": 52, + "TAC": 53, + "TAG": 54, + "TAT": 55, + "TCA": 56, + "TCC": 57, + "TCG": 58, + "TCT": 59, + "TGA": 60, + "TGC": 61, + "TGG": 62, + "TGT": 63, + "TTA": 64, + "TTC": 65, + "TTG": 66, + "TTT": 67 + } +} \ No newline at end of file