Public CodonTranslator model and training code release
Browse files. This view is limited to 50 files because it contains too many changes. See the raw diff for the full changeset.
- CodonTranslator/__init__.py +4 -0
- CodonTranslator/__pycache__/__init__.cpython-312.pyc +0 -0
- CodonTranslator/__pycache__/layers.cpython-312.pyc +0 -0
- CodonTranslator/__pycache__/models.cpython-312.pyc +0 -0
- CodonTranslator/__pycache__/tokenizer.cpython-312.pyc +0 -0
- CodonTranslator/__pycache__/translator.cpython-312.pyc +0 -0
- CodonTranslator/layers.py +239 -0
- CodonTranslator/models.py +306 -0
- CodonTranslator/tokenizer.py +183 -0
- CodonTranslator/translator.py +479 -0
- LICENSE +21 -0
- README.md +115 -0
- __pycache__/precompute_embeddings.cpython-312.pyc +0 -0
- __pycache__/resplit_data_v3.cpython-312.pyc +0 -0
- __pycache__/sampling.cpython-312.pyc +0 -0
- __pycache__/train.cpython-312.pyc +0 -0
- batch_eval.py +382 -0
- codontranslator/__init__.py +3 -0
- environment.yml +20 -0
- eval.py +1239 -0
- final_model/config.json +17 -0
- final_model/model.safetensors +3 -0
- final_model/trainer_config.json +17 -0
- final_model/trainer_state.json +4 -0
- final_model/vocab.json +78 -0
- precompute_embeddings.py +503 -0
- pyproject.toml +24 -0
- requirements.txt +12 -0
- resplit_data_v3.py +1444 -0
- sampling.py +314 -0
- slurm/rebuild_data_v3_cpu.sbatch +98 -0
- slurm/submit_train_v3_h200_8x_chain.sh +24 -0
- slurm/train_v3_h200_8x_single.sbatch +165 -0
- src/__init__.py +33 -0
- src/__pycache__/__init__.cpython-312.pyc +0 -0
- src/__pycache__/dataset.cpython-312.pyc +0 -0
- src/__pycache__/layers.cpython-312.pyc +0 -0
- src/__pycache__/models.cpython-312.pyc +0 -0
- src/__pycache__/sampler.cpython-312.pyc +0 -0
- src/__pycache__/tokenizer.cpython-312.pyc +0 -0
- src/__pycache__/trainer.cpython-312.pyc +0 -0
- src/dataset.py +833 -0
- src/layers.py +384 -0
- src/models.py +490 -0
- src/sampler.py +696 -0
- src/tokenizer.py +324 -0
- src/trainer.py +1230 -0
- train.py +352 -0
- training_checkpoints/checkpoint-71000/config.json +17 -0
- training_checkpoints/checkpoint-71000/model.safetensors +3 -0
CodonTranslator/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Public package interface: re-export the top-level translator class."""
from .translator import CodonTranslator

__all__ = ["CodonTranslator"]
|
| 4 |
+
|
CodonTranslator/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (222 Bytes). View file
|
|
|
CodonTranslator/__pycache__/layers.cpython-312.pyc
ADDED
|
Binary file (17.4 kB). View file
|
|
|
CodonTranslator/__pycache__/models.cpython-312.pyc
ADDED
|
Binary file (20.1 kB). View file
|
|
|
CodonTranslator/__pycache__/tokenizer.cpython-312.pyc
ADDED
|
Binary file (11.2 kB). View file
|
|
|
CodonTranslator/__pycache__/translator.cpython-312.pyc
ADDED
|
Binary file (29.9 kB). View file
|
|
|
CodonTranslator/layers.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Minimal attention/norm/FFN blocks used by the translator backbone
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import math
|
| 5 |
+
from typing import Optional, Tuple
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from torch.nn.attention import SDPBackend, sdpa_kernel
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean-centering).

    Scales the last dimension of the input by the reciprocal RMS of that
    dimension plus a learned per-channel gain.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        # eps keeps the rsqrt finite for all-zero inputs.
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.eps) * self.weight
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
|
| 25 |
+
x1 = x[..., ::2]
|
| 26 |
+
x2 = x[..., 1::2]
|
| 27 |
+
x_rot = torch.zeros_like(x)
|
| 28 |
+
x_rot[..., ::2] = -x2
|
| 29 |
+
x_rot[..., 1::2] = x1
|
| 30 |
+
return x * cos + x_rot * sin
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class GroupedQueryAttention(nn.Module):
    """Grouped-query attention (GQA) with RoPE and an optional KV cache.

    Queries use ``num_heads`` heads while keys/values use only
    ``num_kv_groups`` heads; each KV head serves
    ``num_heads // num_kv_groups`` query heads, shrinking the KV cache.
    Parameter names (Wq/Wk/Wv/out_proj) match released checkpoints.
    """

    def __init__(self, dim: int, num_heads: int, num_kv_groups: int, dropout: float = 0.0, qk_norm: bool = False):
        super().__init__()
        assert num_heads % max(1, num_kv_groups) == 0
        self.dim = dim
        self.num_heads = int(num_heads)
        self.num_kv_groups = max(1, int(num_kv_groups))
        self.group_size = self.num_heads // self.num_kv_groups
        assert dim % num_heads == 0
        self.head_dim = dim // num_heads
        self.dropout = dropout

        self.Wq = nn.Linear(dim, self.num_heads * self.head_dim, bias=False)
        self.Wk = nn.Linear(dim, self.num_kv_groups * self.head_dim, bias=False)
        self.Wv = nn.Linear(dim, self.num_kv_groups * self.head_dim, bias=False)
        self.out_proj = nn.Linear(self.num_heads * self.head_dim, dim, bias=False)

        # Optional per-head RMS normalization of queries and keys.
        self.q_norm = RMSNorm(self.head_dim) if qk_norm else None
        self.k_norm = RMSNorm(self.head_dim) if qk_norm else None

        # RoPE (cos, sin) tables memoized per (length, device, dtype).
        self._rope_cache: dict[tuple[int, torch.device, torch.dtype], tuple[torch.Tensor, torch.Tensor]] = {}

    def _rope_cos_sin(self, T: int, device: torch.device, dtype: torch.dtype):
        """Return cached (cos, sin) RoPE tables of shape (1, 1, T, head_dim)."""
        key = (T, device, dtype)
        cached = self._rope_cache.get(key)
        if cached is not None:
            return cached
        dim_half = self.head_dim // 2
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_half, device=device, dtype=torch.float32) / dim_half))
        t = torch.arange(T, device=device, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)
        # Duplicate each frequency so the table lines up with interleaved pairs.
        cos = torch.cos(freqs).repeat_interleave(2, dim=-1)
        sin = torch.sin(freqs).repeat_interleave(2, dim=-1)
        cos = cos.to(dtype).unsqueeze(0).unsqueeze(0)
        sin = sin.to(dtype).unsqueeze(0).unsqueeze(0)
        self._rope_cache[key] = (cos, sin)
        return cos, sin

    def forward(self, x: torch.Tensor, past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = False, position_offset: int | torch.Tensor = 0):
        """Self-attention over ``x``; returns (out, (k, v)) when ``use_cache``.

        ``position_offset`` may be a scalar (shared RoPE offset) or a per-sample
        LongTensor of shape (B,). The returned cache holds the compact
        ``num_kv_groups``-head keys/values.
        """
        B, T_new, _ = x.shape
        q = self.Wq(x).view(B, T_new, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
        k = self.Wk(x).view(B, T_new, self.num_kv_groups, self.head_dim).transpose(1, 2).contiguous()
        v = self.Wv(x).view(B, T_new, self.num_kv_groups, self.head_dim).transpose(1, 2).contiguous()

        if self.q_norm is not None:
            q = self.q_norm(q)
        if self.k_norm is not None:
            k = self.k_norm(k)

        if isinstance(position_offset, int):
            cos, sin = self._rope_cos_sin(position_offset + T_new, x.device, q.dtype)
            if position_offset > 0:
                cos = cos[:, :, position_offset: position_offset + T_new, :]
                sin = sin[:, :, position_offset: position_offset + T_new, :]
            q = _apply_rope(q, cos, sin)
            k = _apply_rope(k, cos, sin)
        else:
            # Per-sample offsets: gather a (B, 1, T_new, head_dim) table.
            off = position_offset.to(device=x.device, dtype=torch.long)
            max_off = int(off.max().item())
            cos_all, sin_all = self._rope_cos_sin(max_off + T_new, x.device, q.dtype)
            ar = torch.arange(T_new, device=x.device, dtype=torch.long)
            idx = (off.unsqueeze(1) + ar.unsqueeze(0))
            cos_b = cos_all.squeeze(0).squeeze(0)[idx].unsqueeze(1)
            sin_b = sin_all.squeeze(0).squeeze(0)[idx].unsqueeze(1)
            q = _apply_rope(q, cos_b, sin_b)
            k = _apply_rope(k, cos_b, sin_b)

        if past_kv is not None:
            k_p, v_p = past_kv
            k = torch.cat([k_p, k], dim=2)
            v = torch.cat([v_p, v], dim=2)

        # BUG FIX: F.scaled_dot_product_attention does not broadcast the head
        # dimension, so with 1 < num_kv_groups < num_heads the shared KV heads
        # must be materialized per query head before attention. The cache
        # returned below still stores the compact (num_kv_groups) tensors.
        if self.group_size > 1:
            k_att = k.repeat_interleave(self.group_size, dim=1)
            v_att = v.repeat_interleave(self.group_size, dim=1)
        else:
            k_att, v_att = k, v

        is_causal = past_kv is None
        # Prefer Flash, then MemEff, then Math; allow FP32 via Math
        with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
            if x.device.type == "cuda" and q.dtype not in (torch.float16, torch.bfloat16):
                amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
                with torch.amp.autocast(device_type="cuda", dtype=amp_dtype):
                    out = F.scaled_dot_product_attention(
                        q, k_att, v_att, dropout_p=self.dropout if self.training else 0.0, is_causal=is_causal
                    )
            else:
                out = F.scaled_dot_product_attention(
                    q, k_att, v_att, dropout_p=self.dropout if self.training else 0.0, is_causal=is_causal
                )
        out = out.transpose(1, 2).contiguous().view(B, T_new, self.dim)
        # Keep dtype consistent with the (possibly fp32) out_proj weights after
        # autocast, mirroring MultiHeadAttention's behavior.
        if out.dtype != x.dtype:
            out = out.to(x.dtype)
        out = self.out_proj(out)
        if use_cache:
            return out, (k, v)
        return out
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class SwiGLU(nn.Module):
    """SwiGLU feed-forward block (w1/w2/w3 naming kept for checkpoint compat).

    - w1: Linear(dim -> hidden), gate branch
    - w3: Linear(dim -> hidden), value branch
    - w2: Linear(hidden -> dim), down projection
    Forward computes ``w2(dropout(silu(w1(x)) * w3(x)))``.
    """

    def __init__(self, dim: int, hidden_mult: float = 4.0, dropout: float = 0.0):
        super().__init__()
        hidden_dim = int(dim * hidden_mult)
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(self.dropout(gate * value))
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention then SwiGLU FFN, each residual."""

    def __init__(self, dim: int, num_heads: int, mlp_ratio: float, dropout: float = 0.0, num_kv_groups: Optional[int] = None, qk_norm: bool = False, attn_type: str = "mha"):
        super().__init__()
        if attn_type == "gqa":
            groups = num_kv_groups or num_heads
            self.attn = GroupedQueryAttention(dim, num_heads=num_heads, num_kv_groups=groups, dropout=dropout)
        else:
            self.attn = MultiHeadAttention(dim=dim, num_heads=num_heads, dropout=dropout)
        self.ffn = SwiGLU(dim, hidden_mult=mlp_ratio, dropout=dropout)
        self.ln1 = RMSNorm(dim)
        self.ln2 = RMSNorm(dim)

    def forward(self, x: torch.Tensor, past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = False, position_offset: int = 0):
        attn_out = self.attn(self.ln1(x), past_kv=past_kv, use_cache=use_cache, position_offset=position_offset)
        kv = None
        if use_cache:
            attn_out, kv = attn_out
        x = x + attn_out
        x = x + self.ffn(self.ln2(x))
        return (x, kv) if use_cache else x
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class MultiHeadAttention(nn.Module):
|
| 167 |
+
"""Standard MHA with fused qkv and RoPE, SDPA backend selection.
|
| 168 |
+
Matches checkpoint naming: qkv (dim->3*dim) and out_proj (dim->dim).
|
| 169 |
+
"""
|
| 170 |
+
|
| 171 |
+
def __init__(self, dim: int, num_heads: int, dropout: float = 0.0, use_rope: bool = True):
|
| 172 |
+
super().__init__()
|
| 173 |
+
assert dim % num_heads == 0, f"dim {dim} must be divisible by num_heads {num_heads}"
|
| 174 |
+
self.dim = dim
|
| 175 |
+
self.num_heads = num_heads
|
| 176 |
+
self.head_dim = dim // num_heads
|
| 177 |
+
self.dropout = dropout
|
| 178 |
+
self.use_rope = use_rope
|
| 179 |
+
|
| 180 |
+
self.qkv = nn.Linear(dim, 3 * dim, bias=False)
|
| 181 |
+
self.out_proj = nn.Linear(dim, dim, bias=False)
|
| 182 |
+
|
| 183 |
+
self._rope_cache: dict[tuple[int, torch.device, torch.dtype], tuple[torch.Tensor, torch.Tensor]] = {}
|
| 184 |
+
|
| 185 |
+
def _rope_cos_sin(self, T: int, device: torch.device, dtype: torch.dtype):
|
| 186 |
+
key = (T, device, dtype)
|
| 187 |
+
cached = self._rope_cache.get(key)
|
| 188 |
+
if cached is not None:
|
| 189 |
+
return cached
|
| 190 |
+
dim_half = self.head_dim // 2
|
| 191 |
+
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_half, device=device, dtype=torch.float32) / dim_half))
|
| 192 |
+
t = torch.arange(T, device=device, dtype=torch.float32)
|
| 193 |
+
freqs = torch.outer(t, inv_freq)
|
| 194 |
+
cos = torch.cos(freqs).repeat_interleave(2, dim=-1)
|
| 195 |
+
sin = torch.sin(freqs).repeat_interleave(2, dim=-1)
|
| 196 |
+
cos = cos.to(dtype).unsqueeze(0).unsqueeze(0)
|
| 197 |
+
sin = sin.to(dtype).unsqueeze(0).unsqueeze(0)
|
| 198 |
+
self._rope_cache[key] = (cos, sin)
|
| 199 |
+
return cos, sin
|
| 200 |
+
|
| 201 |
+
def forward(self, x: torch.Tensor, past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = False, position_offset: int = 0):
|
| 202 |
+
B, T_new, _ = x.shape
|
| 203 |
+
qkv = self.qkv(x).view(B, T_new, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
| 204 |
+
q, k_new, v_new = qkv[0].contiguous(), qkv[1].contiguous(), qkv[2].contiguous()
|
| 205 |
+
|
| 206 |
+
if self.use_rope:
|
| 207 |
+
cos, sin = self._rope_cos_sin(position_offset + T_new, x.device, q.dtype)
|
| 208 |
+
if position_offset > 0:
|
| 209 |
+
cos = cos[:, :, position_offset: position_offset + T_new, :]
|
| 210 |
+
sin = sin[:, :, position_offset: position_offset + T_new, :]
|
| 211 |
+
q = _apply_rope(q, cos, sin)
|
| 212 |
+
k_new = _apply_rope(k_new, cos, sin)
|
| 213 |
+
|
| 214 |
+
if past_kv is not None:
|
| 215 |
+
k, v = past_kv
|
| 216 |
+
k = torch.cat([k, k_new], dim=2)
|
| 217 |
+
v = torch.cat([v, v_new], dim=2)
|
| 218 |
+
else:
|
| 219 |
+
k, v = k_new, v_new
|
| 220 |
+
|
| 221 |
+
is_causal = past_kv is None
|
| 222 |
+
with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
|
| 223 |
+
if x.device.type == "cuda" and q.dtype not in (torch.float16, torch.bfloat16):
|
| 224 |
+
amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
| 225 |
+
with torch.amp.autocast(device_type="cuda", dtype=amp_dtype):
|
| 226 |
+
out = F.scaled_dot_product_attention(
|
| 227 |
+
q, k, v, dropout_p=self.dropout if self.training else 0.0, is_causal=is_causal
|
| 228 |
+
)
|
| 229 |
+
else:
|
| 230 |
+
out = F.scaled_dot_product_attention(
|
| 231 |
+
q, k, v, dropout_p=self.dropout if self.training else 0.0, is_causal=is_causal
|
| 232 |
+
)
|
| 233 |
+
out = out.transpose(1, 2).contiguous().view(B, T_new, self.dim)
|
| 234 |
+
if out.dtype != x.dtype:
|
| 235 |
+
out = out.to(x.dtype)
|
| 236 |
+
out = self.out_proj(out)
|
| 237 |
+
if use_cache:
|
| 238 |
+
return out, (k, v)
|
| 239 |
+
return out
|
CodonTranslator/models.py
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Optional, Dict, Any, Tuple, List
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
import torch.nn.utils.rnn as rnn_utils
|
| 9 |
+
|
| 10 |
+
from .layers import RMSNorm, TransformerBlock
|
| 11 |
+
from .tokenizer import SpecialIds
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class FrozenESMCEncoder(nn.Module):
    """Frozen ESM-C protein encoder wrapper.

    Requires the ``esm`` package; construction raises ImportError when it is
    missing (despite the original "stays inactive" phrasing, __init__ does
    not degrade gracefully — it raises).

    All parameters are frozen (requires_grad=False) and the module is kept
    in eval mode; forward passes run under no_grad.
    """

    def __init__(self, model_name: str = "esmc_300m", device: str = "cuda", dtype: str = "bf16"):
        super().__init__()
        self.model_name = model_name
        # Fall back to CPU when CUDA is unavailable, whatever was requested.
        self._device = torch.device(device if torch.cuda.is_available() else "cpu")
        # bf16/fp16 select an autocast dtype; anything else disables autocast.
        self._autocast_dtype = torch.bfloat16 if dtype == "bf16" else (torch.float16 if dtype == "fp16" else None)
        try:
            from esm.models.esmc import ESMC  # type: ignore
            from esm.utils.constants.models import ESMC_300M, ESMC_600M  # type: ignore
        except Exception as e:
            raise ImportError(
                "ESM is required for CodonTranslator. Please install 'esm>=3.2.0'."
            ) from e
        # D_esm is the encoder's hidden width, used to size the projection
        # that follows in TranslatorBackbone.
        if self.model_name == "esmc_300m":
            const = ESMC_300M; self.D_esm = 960
        elif self.model_name == "esmc_600m":
            const = ESMC_600M; self.D_esm = 1152
        else:
            raise ValueError(f"Unknown ESM model: {self.model_name}")
        self.model = ESMC.from_pretrained(model_name=const, device=self._device)
        self.tokenizer = self.model.tokenizer
        # Freeze every weight; this encoder is never trained here.
        for p in self.parameters():
            p.requires_grad_(False)
        self.eval()

    @torch.no_grad()
    def tokenize(self, sequences: List[str], max_length: Optional[int] = None, add_special_tokens: bool = True, return_tensors: str = "pt"):
        """Tokenize protein strings and pad to a batch.

        Returns (input_ids, attention_mask); attention_mask is True on
        non-pad positions. Sequences longer than ``max_length`` are
        truncated (after special tokens are added, so BOS may survive but
        EOS can be cut — NOTE(review): confirm this truncation is intended).
        """
        if self.model is None:
            raise RuntimeError("ESM model not available")
        from esm.utils import encoding  # type: ignore
        from esm.utils.misc import stack_variable_length_tensors  # type: ignore
        pad = self.tokenizer.pad_token_id
        toks = []
        for s in sequences:
            t = encoding.tokenize_sequence(s, self.tokenizer, add_special_tokens=add_special_tokens)
            if max_length is not None and len(t) > max_length:
                t = t[:max_length]
            toks.append(t)
        input_ids = stack_variable_length_tensors(toks, constant_value=pad)
        attention_mask = (input_ids != pad)
        return input_ids, attention_mask

    @torch.no_grad()
    def encode_from_ids(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.BoolTensor] = None, return_dict: bool = True):
        """Run the frozen encoder on pre-tokenized ids.

        Returns a dict with per-token ``embeddings`` and the (device-moved)
        ``attention_mask``. Autocast is applied on CUDA when configured.
        """
        if self.model is None:
            raise RuntimeError("ESM model not available")
        device = self.model.device
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device) if attention_mask is not None else None
        if self._autocast_dtype is not None and device.type == "cuda":
            with torch.amp.autocast('cuda', dtype=self._autocast_dtype):
                outputs = self.model.forward(sequence_tokens=input_ids, sequence_id=attention_mask)
        else:
            outputs = self.model.forward(sequence_tokens=input_ids, sequence_id=attention_mask)
        return {"embeddings": outputs.embeddings, "attention_mask": attention_mask}

    def strip_special_tokens(self, embeddings: torch.FloatTensor, attention_mask: Optional[torch.BoolTensor] = None):
        """Drop the leading/trailing special-token embeddings.

        Assumes exactly one BOS and one EOS per sequence (hence the ``- 2``);
        returns (stripped_embeddings, per-sequence residue lengths).
        """
        if attention_mask is not None:
            lengths = attention_mask.sum(dim=1) - 2
            lengths = lengths.clamp(min=1)
        else:
            B, L, D = embeddings.shape
            lengths = torch.full((B,), L - 2, device=embeddings.device)
        stripped = embeddings[:, 1:-1, :]
        return stripped, lengths
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class TranslatorBackbone(nn.Module):
    """Decoder-only transformer over codon tokens with an optional
    conditioning prefix (species embedding and/or ESM protein embedding).

    Layout of one sequence fed to the blocks:
        [species/protein prefix] [learned start embedding] [codon embeddings]
    The prefix length varies per sample; the forward pass slices codon-only
    logits back out by per-sample offsets.
    """

    def __init__(
        self,
        vocab_size: int = 79,
        hidden_size: int = 960,
        num_layers: int = 24,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        max_position_embeddings: int = 4096,
        dropout: float = 0.1,
        layer_norm_eps: float = 1e-6,
        num_special_tokens: int = 13,
        special_ids: Optional[SpecialIds] = None,
        esm_model_name: str = "esmc_300m",
        esm_device: str = "cuda",
        esm_dtype: str = "bf16",
        max_protein_prefix: int = 0,
        max_species_prefix: int = 0,
        prepend_species: bool = True,
        prepend_protein: bool = True,
        species_embedding_dim: int = 1024,
        attn_impl: str = "gqa",
        num_kv_groups: int = 0,
    ):
        super().__init__()
        self.vocab_size = int(vocab_size)
        self.hidden_size = int(hidden_size)
        self.num_layers = int(num_layers)
        self.num_heads = int(num_heads)
        self.max_position_embeddings = int(max_position_embeddings)
        self.special_ids = special_ids or SpecialIds()
        self.num_special_tokens = int(num_special_tokens)

        self.token_embed = nn.Embedding(self.vocab_size, self.hidden_size)

        # Optional ESM protein encoder
        self.esm = None
        self.esm_ln = None
        if prepend_protein and esm_model_name:
            # Enforce ESM presence – raise if missing
            self.esm = FrozenESMCEncoder(esm_model_name, esm_device, esm_dtype)
            # Project ESM width -> model width, with nonlinearity + LayerNorm.
            self.esm_ln = nn.Sequential(
                nn.Linear(self.esm.D_esm, self.hidden_size, bias=False),
                nn.ReLU(),
                nn.LayerNorm(self.hidden_size),
            )
        self.species_embedding_dim = species_embedding_dim if prepend_species else 0
        self.species_ln = None
        if prepend_species:
            # Same projection pattern for the external species embedding.
            self.species_ln = nn.Sequential(
                nn.Linear(self.species_embedding_dim, self.hidden_size, bias=False),
                nn.ReLU(),
                nn.LayerNorm(self.hidden_size),
            )

        # 0 means "no cap" for both prefix length limits.
        self.max_protein_prefix = int(max_protein_prefix) if max_protein_prefix is not None else 0
        self.max_species_prefix = int(max_species_prefix) if max_species_prefix is not None else 0
        self.prepend_species = bool(prepend_species)
        self.prepend_protein = bool(prepend_protein) and (self.esm is not None)

        # Learned embedding marking the start of the codon region.
        self.start_embed = nn.Parameter(torch.zeros(1, 1, self.hidden_size))
        nn.init.normal_(self.start_embed, mean=0.0, std=0.02)

        self.attn_impl = str(attn_impl)
        kv_groups = int(num_kv_groups)
        self.blocks = nn.ModuleList([
            TransformerBlock(
                dim=self.hidden_size,
                num_heads=self.num_heads,
                mlp_ratio=mlp_ratio,
                dropout=dropout,
                # kv_groups <= 0 falls back to full MHA-equivalent grouping.
                num_kv_groups=(kv_groups if (kv_groups > 0 and self.attn_impl == "gqa") else None),
                qk_norm=False,
                attn_type=("mha" if self.attn_impl == "mha" else "gqa"),
            )
            for _ in range(self.num_layers)
        ])

        self.ln_f = RMSNorm(self.hidden_size, eps=layer_norm_eps)
        self.lm_head = nn.Linear(self.hidden_size, self.vocab_size, bias=False)
        self.gradient_checkpointing = False  # NOTE(review): set but never read here

    def embed_tokens(self, token_ids: torch.Tensor) -> torch.Tensor:
        """Embed codon ids, moving them to the embedding weight's device."""
        device = self.token_embed.weight.device
        return self.token_embed(token_ids.to(device))

    def build_prefix(
        self,
        batch_size: int,
        device: torch.device,
        species_tok_emb: Optional[torch.Tensor] = None,
        species_emb: Optional[torch.Tensor] = None,
        protein_input: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Assemble the conditioning prefix and per-sample valid lengths.

        Returns (prefix, lengths) where prefix is (B, Lp, H) right-padded with
        zeros and lengths counts the nonzero (valid) positions per sample.
        """
        parts: list[torch.Tensor] = []
        if self.prepend_species and self.species_ln is not None:
            if species_emb is not None:
                S = self.species_ln(species_emb.to(device=device, dtype=next(self.parameters()).dtype).unsqueeze(1))
                # NOTE(review): the species embedding is appended TWICE in
                # both branches — presumably deliberate (duplicated species
                # slot matching the trained checkpoints), but verify.
                parts.append(S)
                parts.append(S)
            elif species_tok_emb is not None:
                S = species_tok_emb
                if getattr(self, "max_species_prefix", 0) > 0 and S.size(1) > self.max_species_prefix:
                    S = S[:, : self.max_species_prefix, :]
                S = self.species_ln(S.to(device=device, dtype=next(self.parameters()).dtype))
                parts.append(S)
                parts.append(S)

        if self.prepend_protein and self.esm is not None and protein_input is not None:
            prot_ids, prot_mask = protein_input
            esm_out = self.esm.encode_from_ids(prot_ids, prot_mask, return_dict=True)
            P, lengths = self.esm.strip_special_tokens(esm_out["embeddings"], prot_mask)
            if getattr(self, "max_protein_prefix", 0) > 0 and P.size(1) > self.max_protein_prefix:
                P = P[:, : self.max_protein_prefix, :]
                lengths = lengths.clamp(max=self.max_protein_prefix) if lengths is not None else None
            if P.size(1) > 0:
                P = self.esm_ln(P.to(device=device, dtype=next(self.parameters()).dtype))
                if lengths is not None:
                    # Zero out padded protein positions so the nonzero-sum
                    # validity heuristic below can see them as invalid.
                    Lp = P.size(1)
                    ar = torch.arange(Lp, device=device).unsqueeze(0)
                    valid = ar < lengths.unsqueeze(1)
                    P = P * valid.unsqueeze(-1)
                parts.append(P)

        if len(parts) == 0:
            empty = torch.zeros(batch_size, 0, self.hidden_size, device=device, dtype=next(self.parameters()).dtype)
            return empty, torch.zeros(batch_size, dtype=torch.long, device=device)

        prefix = torch.cat(parts, dim=1)
        # Validity is inferred from nonzero feature sums; an exactly-zero
        # valid embedding would be miscounted (unlikely after LayerNorm).
        with torch.no_grad():
            valid = (prefix.abs().sum(dim=-1) > 0)
            lengths = valid.sum(dim=1).to(torch.long)
        # Reserve one position for the start embedding.
        prefix_budget = max(0, int(self.max_position_embeddings) - 1)
        allow = torch.minimum(lengths, torch.tensor(prefix_budget, device=lengths.device, dtype=lengths.dtype))
        Lp_max = int(allow.max().item()) if allow.numel() > 0 else 0
        if prefix.size(1) > Lp_max:
            # Compact to the longest allowed prefix, left-aligned per sample.
            trimmed = prefix.new_zeros(prefix.size(0), Lp_max, prefix.size(2))
            for b in range(prefix.size(0)):
                lb = int(allow[b].item())
                if lb > 0:
                    trimmed[b, :lb, :] = prefix[b, :lb, :]
            prefix = trimmed
            lengths = allow
        else:
            lengths = allow
        return prefix, lengths

    def forward(self, codon_ids: torch.Tensor, cond: Optional[Dict[str, Any]] = None, labels: Optional[torch.Tensor] = None, return_dict: bool = True, use_cache: bool = False, past_kv: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, position_offset: int = 0) -> Dict[str, torch.Tensor]:
        """Run the backbone.

        Two paths:
        - incremental (``past_kv`` given, codon_len > 0): feed only the new
          token(s) through cached blocks; "logits" is empty and
          "next_logits" holds the last position's distribution;
        - full: build prefix + start + (window-capped) codon embeddings,
          returning per-codon "logits" and the "next_logits" continuation.
        ``labels`` and ``return_dict`` are accepted but unused here —
        NOTE(review): loss computation presumably lives in the trainer.
        """
        batch_size, codon_len = codon_ids.shape
        device = codon_ids.device
        species_tok_emb = cond.get("species_tok_emb") if cond else None
        species_emb = cond.get("species_emb") if cond else None
        protein_input = cond.get("protein_input") if cond else None

        # Build prefix
        prefix, prefix_lengths = self.build_prefix(batch_size, device, species_tok_emb=species_tok_emb, species_emb=species_emb, protein_input=protein_input)
        start = self.start_embed.expand(batch_size, 1, -1)

        # KV cache path for incremental generation
        if past_kv is not None and codon_len > 0:
            x = self.embed_tokens(codon_ids)
            present_kv: List[Tuple[torch.Tensor, torch.Tensor]] = []
            for i, block in enumerate(self.blocks):
                kv_i = past_kv[i] if i < len(past_kv) else None
                out_blk = block(x, past_kv=kv_i, use_cache=True, position_offset=position_offset)
                x, kv_out = out_blk
                present_kv.append(kv_out)
            x = self.ln_f(x)
            logits_step = self.lm_head(x)
            # "logits" deliberately empty: only the next-token head matters here.
            return {"logits": logits_step[:, 0:0, :], "next_logits": logits_step[:, -1, :], "present_kv": present_kv, "prefix_len": prefix_lengths}

        # Non-incremental: build prefix+start+codon window
        codon_lens = torch.as_tensor([codon_len] * batch_size, device=device)
        capacity = max(0, int(self.max_position_embeddings))
        # Per-sample budget left for codons after prefix + start token.
        budget_after_prefix = torch.clamp(torch.as_tensor(capacity, device=device) - (prefix_lengths + 1), min=0)
        per_cap = torch.minimum(budget_after_prefix, codon_lens)
        max_cap = int(per_cap.max().item()) if per_cap.numel() > 0 else 0
        codon_emb = self.embed_tokens(codon_ids[:, :max_cap]) if max_cap > 0 else torch.zeros(batch_size, 0, self.hidden_size, device=device, dtype=start.dtype)
        seqs = []
        for b in range(batch_size):
            lp = int(prefix_lengths[b].item())
            cap = int(per_cap[b].item())
            parts = []
            if lp > 0:
                parts.append(prefix[b, :lp, :])
            parts.append(start[b, 0:1, :])
            if cap > 0:
                parts.append(codon_emb[b, :cap, :])
            seqs.append(torch.cat(parts, dim=0))
        x = rnn_utils.pad_sequence(seqs, batch_first=True)

        present_kv_list: List[Tuple[torch.Tensor, torch.Tensor]] = []
        for block in self.blocks:
            blk_out = block(x, use_cache=use_cache, position_offset=0)
            if use_cache:
                x, kv = blk_out
                present_kv_list.append(kv)
            else:
                x = blk_out
        x = self.ln_f(x)
        logits_full = self.lm_head(x)

        # Slice codon-region logits back out by per-sample prefix offsets.
        next_logits_list = []
        if max_cap == 0:
            codon_logits = logits_full[:, 0:0, :]
            for b in range(batch_size):
                lp = int(prefix_lengths[b].item())
                pos_next = lp  # position of the start embedding's output
                next_logits_list.append(logits_full[b, pos_next, :] if pos_next < logits_full.size(1) else logits_full[b, -1, :])
            next_logits = torch.stack(next_logits_list, dim=0)
        else:
            slices = []
            for b in range(batch_size):
                lp = int(prefix_lengths[b].item())
                cap = int(per_cap[b].item())
                sl = logits_full[b, lp : lp + cap, :] if cap > 0 else logits_full.new_zeros(0, self.vocab_size)
                slices.append(sl)
                pos_next = lp + cap
                # Out-of-range next position falls back to zeros, not the
                # last logits — NOTE(review): asymmetric with the max_cap==0
                # branch above; verify intended.
                next_logits_list.append(logits_full[b, pos_next, :] if pos_next < logits_full.size(1) else logits_full.new_zeros(self.vocab_size))
            codon_logits = rnn_utils.pad_sequence(slices, batch_first=True)
            next_logits = torch.stack(next_logits_list, dim=0)
        out = {"logits": codon_logits, "next_logits": next_logits, "prefix_len": prefix_lengths}
        if use_cache:
            out["present_kv"] = present_kv_list
        return out
|
CodonTranslator/tokenizer.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Minimal copy of CodonTokenizer from src/tokenizer.py to keep the package self-contained.
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Dict, List, Optional, Tuple, Any
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass(frozen=True)
class SpecialIds:
    """Integer ids of the four reserved tokens, in canonical order."""

    pad: int = 0
    unk: int = 1
    bos: int = 2
    eos: int = 3

    def to_dict(self) -> Dict[str, int]:
        """Return the ids as a plain ``{name: id}`` mapping."""
        return {"pad": self.pad, "unk": self.unk, "bos": self.bos, "eos": self.eos}


class CodonTokenizer:
    """Self-contained codon-level tokenizer: 4 special tokens + 64 codons.

    The vocabulary is laid out as ``[pad, unk, bos, eos, AAA, AAC, ...]`` with
    the 64 codons enumerated in lexicographic order over the bases A, C, G, T.
    Amino-acid lookup tables (codon id -> AA char and AA char -> codon ids) are
    derived from the standard genetic code; ``*`` marks stop codons.
    """

    __slots__ = (
        "codons",
        "_special_token_str",
        "vocab",
        "ids_to_tokens",
        "_special_ids",
        "_num_special_tokens",
        "_genetic_code",
        "_codon2aa_char",
        "_aa2codons_char",
    )

    def __init__(
        self,
        pad_token: str = "<pad>",
        unk_token: str = "<unk>",
        bos_token: str = "<bos>",
        eos_token: str = "<stop>",
        **_: Any,
    ) -> None:
        bases = ("A", "C", "G", "T")
        self.codons: List[str] = [a + b + c for a in bases for b in bases for c in bases]

        special_tokens = [pad_token, unk_token, bos_token, eos_token]
        self._special_token_str = {"pad": pad_token, "unk": unk_token, "bos": bos_token, "eos": eos_token}

        # Specials first, then codons; each token's id is simply its insertion
        # order (equivalent to, but clearer than, the previous
        # ``len(special_tokens) + (len(self.vocab) - len(special_tokens))``).
        self.vocab: Dict[str, int] = {}
        for i, tok in enumerate(special_tokens):
            self.vocab[tok] = i
        for codon in self.codons:
            self.vocab[codon] = len(self.vocab)

        self.ids_to_tokens: Dict[int, str] = {v: k for k, v in self.vocab.items()}

        self._special_ids = SpecialIds(
            pad=self.vocab[pad_token],
            unk=self.vocab[unk_token],
            bos=self.vocab[bos_token],
            eos=self.vocab[eos_token],
        )
        self._num_special_tokens = len(special_tokens)

        # Standard genetic code; "*" marks the three stop codons.
        self._genetic_code: Dict[str, str] = {
            "TTT": "F", "TTC": "F", "TTA": "L", "TTG": "L",
            "TCT": "S", "TCC": "S", "TCA": "S", "TCG": "S",
            "TAT": "Y", "TAC": "Y", "TAA": "*", "TAG": "*",
            "TGT": "C", "TGC": "C", "TGA": "*", "TGG": "W",
            "CTT": "L", "CTC": "L", "CTA": "L", "CTG": "L",
            "CCT": "P", "CCC": "P", "CCA": "P", "CCG": "P",
            "CAT": "H", "CAC": "H", "CAA": "Q", "CAG": "Q",
            "CGT": "R", "CGC": "R", "CGA": "R", "CGG": "R",
            "ATT": "I", "ATC": "I", "ATA": "I", "ATG": "M",
            "ACT": "T", "ACC": "T", "ACA": "T", "ACG": "T",
            "AAT": "N", "AAC": "N", "AAA": "K", "AAG": "K",
            "AGT": "S", "AGC": "S", "AGA": "R", "AGG": "R",
            "GTT": "V", "GTC": "V", "GTA": "V", "GTG": "V",
            "GCT": "A", "GCC": "A", "GCA": "A", "GCG": "A",
            "GAT": "D", "GAC": "D", "GAA": "E", "GAG": "E",
            "GGT": "G", "GGC": "G", "GGA": "G", "GGG": "G",
        }

        self._rebuild_aa_maps()

    def _rebuild_aa_maps(self) -> None:
        """Recompute the codon-id -> AA and AA -> codon-ids maps from the current vocab.

        Shared by ``__init__`` and ``from_pretrained`` (which may have remapped
        codon ids from a saved vocabulary); previously this loop was duplicated
        verbatim in both places.
        """
        self._codon2aa_char = {}
        self._aa2codons_char = {ch: [] for ch in "ACDEFGHIKLMNPQRSTVWY*"}
        for codon in self.codons:
            cid = self.vocab[codon]
            # "X" flags a codon missing from the genetic-code table (should not
            # happen with the built-in table, but keeps lookup total).
            aa = self._genetic_code.get(codon, "X")
            self._codon2aa_char[cid] = aa
            if aa in self._aa2codons_char:
                self._aa2codons_char[aa].append(cid)

    @property
    def vocab_size(self) -> int:
        return len(self.vocab)

    @property
    def special_ids(self) -> SpecialIds:
        return self._special_ids

    @property
    def num_special_tokens(self) -> int:
        return self._num_special_tokens

    @property
    def pad_token_id(self) -> int:
        return self._special_ids.pad

    @property
    def eos_token_id(self) -> int:
        return self._special_ids.eos

    # helpers
    def codon_vocab(self) -> Dict[str, int]:
        """Return a ``{codon: id}`` mapping for the 64 codon tokens only."""
        return {c: self.vocab[c] for c in self.codons}

    def codon2aa_char_map(self) -> Dict[int, str]:
        """Return a copy of the codon-id -> amino-acid-char mapping."""
        return dict(self._codon2aa_char)

    def aa2codons_char_map(self) -> Dict[str, List[int]]:
        """Return a deep-ish copy of the amino-acid-char -> codon-ids mapping."""
        return {k: v[:] for k, v in self._aa2codons_char.items()}

    # decoding
    def decode_codon_seq(self, token_ids: List[int]) -> str:
        """Concatenate codon tokens into a DNA string, silently dropping specials."""
        parts: List[str] = []
        nst = self._num_special_tokens
        for tid in token_ids:
            if tid >= nst:
                tok = self.ids_to_tokens.get(tid)
                if tok is not None:
                    parts.append(tok)
        return "".join(parts)

    # persistence
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Write the vocab and special-token strings to ``<prefix->vocab.json``.

        Returns a 1-tuple with the written file path (HF-tokenizer convention).
        """
        os.makedirs(save_directory, exist_ok=True)
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )
        payload = {
            "vocab": self.vocab,
            "special_token_str": self._special_token_str,
        }
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False, indent=2, sort_keys=True)
        return (vocab_file,)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs: Any) -> "CodonTokenizer":
        """Load a tokenizer from ``<dir>/vocab.json``, or return defaults if absent.

        Accepts both the payload format written by :meth:`save_vocabulary` and
        a bare ``{token: id}`` mapping.
        """
        vocab_path = Path(pretrained_model_name_or_path) / "vocab.json"
        tok = cls(**kwargs)
        if not vocab_path.exists():
            return tok
        with open(vocab_path, "r", encoding="utf-8") as f:
            save_data = json.load(f)
        vocab = save_data["vocab"] if isinstance(save_data, dict) and "vocab" in save_data else save_data
        tok.vocab = {str(k): int(v) for k, v in vocab.items()}
        tok.ids_to_tokens = {int(v): str(k) for k, v in tok.vocab.items()}
        sts = save_data.get("special_token_str", tok._special_token_str) if isinstance(save_data, dict) else tok._special_token_str
        tok._special_token_str.update(sts)

        def _id_for(name: str, default_val: int) -> int:
            sym = tok._special_token_str[name]
            return int(tok.vocab.get(sym, default_val))

        tok._special_ids = SpecialIds(
            pad=_id_for("pad", 0),
            unk=_id_for("unk", 1),
            bos=_id_for("bos", 2),
            eos=_id_for("eos", 3),
        )
        # If the special ids form a contiguous 0..m prefix, trust that count;
        # otherwise fall back to the canonical 4.
        ids = [tok._special_ids.pad, tok._special_ids.unk, tok._special_ids.bos, tok._special_ids.eos]
        m = max(ids)
        tok._num_special_tokens = m + 1 if ids == list(range(m + 1)) else 4
        # Rebuild AA helper maps against the (possibly remapped) vocabulary.
        tok._rebuild_aa_maps()
        return tok
|
CodonTranslator/translator.py
ADDED
|
@@ -0,0 +1,479 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 6 |
+
import logging
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
import numpy as np
|
| 11 |
+
from safetensors.torch import load_file
|
| 12 |
+
|
| 13 |
+
from .models import TranslatorBackbone
|
| 14 |
+
from .tokenizer import CodonTokenizer
|
| 15 |
+
# no external store at inference; species embeddings computed via Qwen
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class CodonTranslator:
|
| 19 |
+
"""
|
| 20 |
+
High-level sampling wrapper for trained checkpoints with a simple API:
|
| 21 |
+
|
| 22 |
+
from CodonTranslator import CodonTranslator
|
| 23 |
+
model = CodonTranslator.from_pretrained(model_path)
|
| 24 |
+
dna = model.sampling(species="Homo sapiens", protein_seq="M...", enforce_mapping=True)
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
def __init__(self, model_dir: Union[str, Path], device: str = "cuda", use_gbif: bool = False):
|
| 28 |
+
self.model_dir = Path(model_dir)
|
| 29 |
+
self.device = torch.device(device if torch.cuda.is_available() or device == "cpu" else "cpu")
|
| 30 |
+
self.tokenizer = CodonTokenizer.from_pretrained(str(self.model_dir))
|
| 31 |
+
self.V = int(self.tokenizer.vocab_size)
|
| 32 |
+
self._eos_id = int(self.tokenizer.eos_token_id)
|
| 33 |
+
self._pad_id = int(self.tokenizer.pad_token_id)
|
| 34 |
+
self._num_special = int(self.tokenizer.num_special_tokens)
|
| 35 |
+
|
| 36 |
+
# Load config
|
| 37 |
+
cfg_path = self.model_dir / "trainer_config.json"
|
| 38 |
+
if not cfg_path.exists():
|
| 39 |
+
cfg_path = self.model_dir / "config.json"
|
| 40 |
+
with open(cfg_path, "r") as f:
|
| 41 |
+
self.config = json.load(f)
|
| 42 |
+
|
| 43 |
+
# Build model and load weights
|
| 44 |
+
state = self._load_state_dict()
|
| 45 |
+
arch = self._infer_arch_from_state_dict(state)
|
| 46 |
+
self.model = TranslatorBackbone(
|
| 47 |
+
vocab_size=self.V,
|
| 48 |
+
hidden_size=int(arch["hidden_size"]),
|
| 49 |
+
num_layers=int(arch["num_layers"]),
|
| 50 |
+
num_heads=int(arch["num_heads"]),
|
| 51 |
+
mlp_ratio=float(arch.get("mlp_ratio", 4.0)),
|
| 52 |
+
max_position_embeddings=int(arch["max_position_embeddings"]),
|
| 53 |
+
num_special_tokens=self._num_special,
|
| 54 |
+
special_ids=self.tokenizer.special_ids,
|
| 55 |
+
prepend_species=bool(arch.get("prepend_species", True)),
|
| 56 |
+
prepend_protein=bool(arch.get("prepend_protein", False)),
|
| 57 |
+
species_embedding_dim=int(self.config.get("species_embedding_dim", 1024)),
|
| 58 |
+
esm_model_name=str(arch.get("esm_model_name", "esmc_300m")),
|
| 59 |
+
esm_device=str(arch.get("esm_device", "cuda")),
|
| 60 |
+
esm_dtype=str(arch.get("esm_dtype", "bf16")),
|
| 61 |
+
max_protein_prefix=int(arch.get("max_protein_prefix", 0)),
|
| 62 |
+
max_species_prefix=int(arch.get("max_species_prefix", 0)),
|
| 63 |
+
attn_impl=str(arch.get("attn_impl", "gqa")),
|
| 64 |
+
num_kv_groups=int(arch.get("num_kv_groups", 0)),
|
| 65 |
+
)
|
| 66 |
+
missing, unexpected = self.model.load_state_dict(state, strict=False)
|
| 67 |
+
if len(unexpected) > 0:
|
| 68 |
+
# non-fatal
|
| 69 |
+
pass
|
| 70 |
+
self.model.to(self.device).eval()
|
| 71 |
+
|
| 72 |
+
# Static masks
|
| 73 |
+
self._allowed_fixed = torch.ones(self.V, dtype=torch.bool, device=self.device)
|
| 74 |
+
self._allowed_fixed[:self._num_special] = False
|
| 75 |
+
self._allowed_variable = torch.ones(self.V, dtype=torch.bool, device=self.device)
|
| 76 |
+
self._allowed_variable[:self._num_special] = False
|
| 77 |
+
self._allowed_variable[self._eos_id] = True
|
| 78 |
+
|
| 79 |
+
# Species taxonomy: either query GBIF (if allowed) or use raw names.
|
| 80 |
+
self._use_gbif = bool(use_gbif)
|
| 81 |
+
self._taxonomy_cache: Dict[str, str] = {}
|
| 82 |
+
|
| 83 |
+
# ---- constructors ----
|
| 84 |
+
@classmethod
|
| 85 |
+
def from_pretrained(cls, model_path: Union[str, Path], device: str = "cuda", use_gbif: bool = False) -> "CodonTranslator":
|
| 86 |
+
return cls(model_path, device=device, use_gbif=use_gbif)
|
| 87 |
+
|
| 88 |
+
# ---- sampling APIs ----
|
| 89 |
+
@torch.no_grad()
|
| 90 |
+
def sampling(self, species: str, protein_seq: str, enforce_mapping: bool = False, temperature: float = 1.0, top_k: Optional[int] = None, top_p: Optional[float] = None, seed: Optional[int] = None, use_kv_cache: bool = True) -> str:
|
| 91 |
+
out = self.batch_inference(
|
| 92 |
+
species=[species],
|
| 93 |
+
protein_seqs=[protein_seq],
|
| 94 |
+
enforce_mapping=enforce_mapping,
|
| 95 |
+
temperature=temperature,
|
| 96 |
+
top_k=top_k,
|
| 97 |
+
top_p=top_p,
|
| 98 |
+
seed=seed,
|
| 99 |
+
use_kv_cache=use_kv_cache,
|
| 100 |
+
)
|
| 101 |
+
return out[0]
|
| 102 |
+
|
| 103 |
+
@torch.no_grad()
|
| 104 |
+
def batch_inference(
|
| 105 |
+
self,
|
| 106 |
+
species: List[str],
|
| 107 |
+
protein_seqs: List[str],
|
| 108 |
+
enforce_mapping: bool = False,
|
| 109 |
+
temperature: float = 1.0,
|
| 110 |
+
top_k: Optional[int] = None,
|
| 111 |
+
top_p: Optional[float] = None,
|
| 112 |
+
seed: Optional[int] = None,
|
| 113 |
+
use_kv_cache: bool = True,
|
| 114 |
+
micro_batch_size: int = 1,
|
| 115 |
+
) -> List[str]:
|
| 116 |
+
"""Generate DNA for a list of protein sequences, using micro-batching to limit memory.
|
| 117 |
+
|
| 118 |
+
- micro_batch_size: number of samples to process at once (default=1 for low memory)
|
| 119 |
+
"""
|
| 120 |
+
assert len(species) == len(protein_seqs), "species and protein_seqs length must match"
|
| 121 |
+
mb = max(1, int(micro_batch_size))
|
| 122 |
+
if len(species) <= mb:
|
| 123 |
+
return self._batch_inference_core(
|
| 124 |
+
species=species,
|
| 125 |
+
protein_seqs=protein_seqs,
|
| 126 |
+
enforce_mapping=enforce_mapping,
|
| 127 |
+
temperature=temperature,
|
| 128 |
+
top_k=top_k,
|
| 129 |
+
top_p=top_p,
|
| 130 |
+
seed=seed,
|
| 131 |
+
use_kv_cache=use_kv_cache,
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
outputs: List[str] = []
|
| 135 |
+
for start in range(0, len(species), mb):
|
| 136 |
+
end = min(start + mb, len(species))
|
| 137 |
+
chunk_out = self._batch_inference_core(
|
| 138 |
+
species=species[start:end],
|
| 139 |
+
protein_seqs=protein_seqs[start:end],
|
| 140 |
+
enforce_mapping=enforce_mapping,
|
| 141 |
+
temperature=temperature,
|
| 142 |
+
top_k=top_k,
|
| 143 |
+
top_p=top_p,
|
| 144 |
+
seed=seed,
|
| 145 |
+
use_kv_cache=use_kv_cache,
|
| 146 |
+
)
|
| 147 |
+
outputs.extend(chunk_out)
|
| 148 |
+
return outputs
|
| 149 |
+
|
| 150 |
+
    @torch.no_grad()
    def _batch_inference_core(
        self,
        species: List[str],
        protein_seqs: List[str],
        enforce_mapping: bool = False,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        seed: Optional[int] = None,
        use_kv_cache: bool = True,
    ) -> List[str]:
        """Autoregressively sample one codon sequence per (species, protein) pair.

        Conditioning: Qwen species embeddings always; protein tokens via the
        backbone's ESM module when present. Exactly ``len(protein_seq)`` codons
        are sampled for each row; rows past their target length are forced to
        PAD. Special tokens are stripped before decoding to DNA.

        When ``enforce_mapping`` is True, each step is constrained to codons
        translating to the target amino acid; otherwise a per-sequence AA
        accuracy report is printed at the end.
        """
        if seed is not None:
            torch.manual_seed(int(seed))
            np.random.seed(int(seed))
        B = len(species)
        assert B == len(protein_seqs), "species and protein_seqs length must match"
        target_lens = torch.tensor([len(s) for s in protein_seqs], device=self.device, dtype=torch.long)
        # Number of decode steps = longest target protein (shorter rows get PAD).
        T_codons = int(target_lens.max().item())

        # Prepare conditioning
        cond: Dict[str, Any] = {"control_mode": "fixed"}

        # Species embeddings via Qwen3-Embedding (variable-length token sequences)
        # NOTE(review): `lengths` is computed but never used below.
        q_tok, lengths = self._qwen_embed_names(species, pooling="sequence")  # [B, L, D]
        # Always surface a message so users can see species embeddings are used
        print(f"[CodonTranslator] Species embeddings (Qwen) computed: shape={tuple(q_tok.shape)}")
        cond["species_tok_emb"] = q_tok.to(self.device)

        # Protein input via ESM (if available) – let model tokenize internally
        if getattr(self.model, "esm", None) is not None:
            # Tokenize AA sequences with model.esm
            # +2 presumably accounts for BOS/EOS tokens — TODO confirm against the ESM tokenizer.
            max_len_tokens = (getattr(self.model, "max_protein_prefix", 0) + 2) if getattr(self.model, "max_protein_prefix", 0) > 0 else None
            prot_ids, prot_mask = self.model.esm.tokenize(protein_seqs, max_length=max_len_tokens)
            cond["protein_input"] = (prot_ids.to(self.device), prot_mask.to(self.device))

        # Start generation with empty context to build KV cache and initial logits
        input_ids = torch.zeros(B, 0, dtype=torch.long, device=self.device)
        out_prefill = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=use_kv_cache)
        kv = out_prefill.get("present_kv") if use_kv_cache else None
        logits = out_prefill.get("next_logits")
        assert logits is not None
        # Report prefix length to prove species/protein prefixes were incorporated
        try:
            pref = out_prefill.get("prefix_len")
            if pref is not None:
                lst = pref.detach().cpu().tolist()
                # NOTE(review): "(species,species,protein)" looks like a typo for
                # "(species, protein)"; the runtime string is left untouched here.
                print(f"[CodonTranslator] Prefix lengths (species,species,protein): {lst}")
        except Exception:
            pass

        # Codon-only mask (no specials, no EOS): sequence length is fixed by the protein.
        allowed = self._allowed_fixed
        # NOTE(review): `finished` is never read below; appears to be a leftover.
        finished = torch.zeros(B, dtype=torch.bool, device=self.device)

        aa2codons = self.tokenizer.aa2codons_char_map()

        rng = range(T_codons)
        # Greedy mode: temperature <= 0 selects argmax deterministically
        greedy_mode = (temperature is not None and float(temperature) <= 0.0)
        for step in rng:
            # Disallow special tokens at every step.
            logits = logits.masked_fill(~allowed, float("-inf"))

            # Stop sampling per-sample once reaching its target length; force PAD
            done_now = (torch.tensor(step, device=self.device) >= target_lens)
            if done_now.any():
                logits[done_now] = float("-inf")
                logits[done_now, self._pad_id] = 0.0

            # Enforce codon ↔ AA mapping at this step
            if enforce_mapping:
                aas_now = [seq[step] if step < len(seq) else None for seq in protein_seqs]
                mask = torch.zeros_like(logits, dtype=torch.bool)
                for i, a in enumerate(aas_now):
                    if a is None:
                        # Row already past its protein: any codon is formally allowed
                        # (the PAD forcing above dominates anyway).
                        mask[i, self._num_special:self.V] = True
                    else:
                        valid = aa2codons.get(a, [])
                        if len(valid) == 0:
                            # Unknown AA character: fall back to all codons.
                            mask[i, self._num_special:self.V] = True
                        else:
                            mask[i, valid] = True
                logits = logits.masked_fill(~mask, float("-inf"))

            # Temperature / top-k / top-p shaping (skipped entirely in greedy mode
            # for temperature; filters still apply if requested).
            if not greedy_mode and temperature != 1.0:
                logits = logits / float(temperature)
            if top_k is not None:
                logits = self._top_k_filtering(logits, int(top_k))
            if top_p is not None:
                logits = self._top_p_filtering(logits, float(top_p))

            if greedy_mode:
                next_tok = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                probs = F.softmax(logits, dim=-1)
                next_tok = torch.multinomial(probs, num_samples=1)

            input_ids = torch.cat([input_ids, next_tok], dim=1)

            if use_kv_cache:
                # Incremental decode: feed only the new token, positioned after
                # the conditioning prefix plus all previously generated tokens.
                pos_offset = int(out_prefill.get("prefix_len").max().item()) + input_ids.size(1) - 1 if isinstance(out_prefill, dict) and ("prefix_len" in out_prefill) else input_ids.size(1) - 1
                out_inc = self.model(
                    codon_ids=next_tok,
                    cond=None,
                    return_dict=True,
                    use_cache=True,
                    past_kv=kv,
                    position_offset=pos_offset,
                )
                kv = out_inc.get("present_kv")
                logits = out_inc.get("next_logits")
            else:
                # Recompute full forward with prefix+all generated tokens
                out_full = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=False)
                logits = out_full.get("next_logits")

        # Build DNA strings, dropping specials
        output_token_rows: List[List[int]] = []
        for i, row in enumerate(input_ids.tolist()):
            toks: List[int] = []
            for t in row:
                if t == self._pad_id:
                    continue
                if t == self._eos_id:
                    break
                if t >= self._num_special and t < self.V:
                    toks.append(int(t))
            # Clip to the target protein length as a final safety net.
            toks = toks[: int(target_lens[i].item())]
            output_token_rows.append(toks)
        sequences = [self.tokenizer.decode_codon_seq(row) for row in output_token_rows]

        # If not enforcing mapping, report AA token accuracy vs provided targets
        if not enforce_mapping:
            for i, dna in enumerate(sequences):
                tgt = protein_seqs[i]
                gen_aa = self._dna_to_aa(dna)
                L = min(len(gen_aa), len(tgt))
                if L == 0:
                    acc = 0.0; num = 0; den = 0
                else:
                    num = sum(1 for a, b in zip(gen_aa[:L], tgt[:L]) if a == b)
                    den = L
                    acc = num / den
                print(f"[CodonTranslator] AA token accuracy seq_{i+1}: {acc:.4f} ({num}/{den})")
        return sequences
|
| 294 |
+
|
| 295 |
+
# ---- helpers ----
|
| 296 |
+
def _load_state_dict(self) -> Dict[str, torch.Tensor]:
|
| 297 |
+
st_p = self.model_dir / "model.safetensors"
|
| 298 |
+
if st_p.exists():
|
| 299 |
+
return load_file(st_p)
|
| 300 |
+
pt_p = self.model_dir / "pytorch_model.bin"
|
| 301 |
+
if pt_p.exists():
|
| 302 |
+
return torch.load(pt_p, map_location="cpu")
|
| 303 |
+
raise FileNotFoundError(f"No model.safetensors or pytorch_model.bin in {self.model_dir}")
|
| 304 |
+
|
| 305 |
+
def _infer_arch_from_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
|
| 306 |
+
arch: Dict[str, Any] = {}
|
| 307 |
+
if "lm_head.weight" in state_dict:
|
| 308 |
+
arch["hidden_size"] = int(state_dict["lm_head.weight"].shape[1])
|
| 309 |
+
else:
|
| 310 |
+
for k, v in state_dict.items():
|
| 311 |
+
if k.endswith("ln_f.weight"):
|
| 312 |
+
arch["hidden_size"] = int(v.shape[0])
|
| 313 |
+
break
|
| 314 |
+
cfg = self.config or {}
|
| 315 |
+
if "hidden_size" in cfg:
|
| 316 |
+
arch["hidden_size"] = int(cfg["hidden_size"]) # type: ignore
|
| 317 |
+
if "hidden_size" not in arch:
|
| 318 |
+
arch["hidden_size"] = int(cfg.get("hidden_size", 750))
|
| 319 |
+
H = int(arch["hidden_size"])
|
| 320 |
+
|
| 321 |
+
max_block = -1
|
| 322 |
+
for k in state_dict.keys():
|
| 323 |
+
if k.startswith("blocks."):
|
| 324 |
+
idx = int(k.split(".")[1])
|
| 325 |
+
if idx > max_block:
|
| 326 |
+
max_block = idx
|
| 327 |
+
arch["num_layers"] = (max_block + 1) if max_block >= 0 else int(cfg.get("num_hidden_layers", 12))
|
| 328 |
+
if "num_hidden_layers" in cfg:
|
| 329 |
+
arch["num_layers"] = int(cfg["num_hidden_layers"]) # type: ignore
|
| 330 |
+
|
| 331 |
+
# mlp ratio
|
| 332 |
+
w1_key = next((k for k in state_dict.keys() if k.endswith("ffn.w1.weight")), None)
|
| 333 |
+
if w1_key is not None:
|
| 334 |
+
arch["mlp_ratio"] = float(int(state_dict[w1_key].shape[0]) / H)
|
| 335 |
+
else:
|
| 336 |
+
arch["mlp_ratio"] = float(cfg.get("mlp_ratio", 4.0))
|
| 337 |
+
|
| 338 |
+
# heads: pick divisor
|
| 339 |
+
cfg_heads = cfg.get("num_attention_heads")
|
| 340 |
+
if isinstance(cfg_heads, int) and cfg_heads > 0 and H % cfg_heads == 0:
|
| 341 |
+
arch["num_heads"] = int(cfg_heads)
|
| 342 |
+
else:
|
| 343 |
+
for h in (16, 15, 12, 10, 8, 6, 5, 4, 3, 2, 1):
|
| 344 |
+
if H % h == 0:
|
| 345 |
+
arch["num_heads"] = h
|
| 346 |
+
break
|
| 347 |
+
|
| 348 |
+
arch["prepend_species"] = bool(cfg.get("prepend_species", any(k.startswith("species_ln.") for k in state_dict.keys())))
|
| 349 |
+
has_esm = any(k.startswith("esm_ln.") for k in state_dict.keys()) or any(k.startswith("esm.") for k in state_dict.keys())
|
| 350 |
+
arch["prepend_protein"] = bool(cfg.get("prepend_protein", bool(has_esm)))
|
| 351 |
+
arch["esm_model_name"] = str(cfg.get("esm_model_name", "esmc_300m"))
|
| 352 |
+
arch["esm_device"] = str(cfg.get("esm_device", "cuda"))
|
| 353 |
+
arch["esm_dtype"] = str(cfg.get("esm_dtype", "bf16")).lower()
|
| 354 |
+
arch["max_protein_prefix"] = int(cfg.get("max_protein_prefix", 0))
|
| 355 |
+
arch["max_species_prefix"] = int(cfg.get("max_species_prefix", 0))
|
| 356 |
+
arch["max_position_embeddings"] = int(cfg.get("max_length", cfg.get("max_position_embeddings", 2048)))
|
| 357 |
+
arch["attn_impl"] = str(cfg.get("attn_impl", "gqa"))
|
| 358 |
+
arch["num_kv_groups"] = int(cfg.get("num_kv_groups", 0))
|
| 359 |
+
return arch
|
| 360 |
+
|
| 361 |
+
# --- filtering helpers
|
| 362 |
+
@staticmethod
|
| 363 |
+
def _ensure_2d_logits(logits: torch.Tensor) -> torch.Tensor:
|
| 364 |
+
return logits if logits.dim() == 2 else logits.unsqueeze(0)
|
| 365 |
+
|
| 366 |
+
@staticmethod
|
| 367 |
+
def _top_k_filtering(logits: torch.Tensor, k: int) -> torch.Tensor:
|
| 368 |
+
x = CodonTranslator._ensure_2d_logits(logits)
|
| 369 |
+
k = max(1, min(int(k), x.size(-1)))
|
| 370 |
+
values, _ = torch.topk(x, k, dim=-1)
|
| 371 |
+
min_values = values[:, -1].unsqueeze(-1)
|
| 372 |
+
x = torch.where(x < min_values, torch.full_like(x, float('-inf')), x)
|
| 373 |
+
return x if logits.dim() == 2 else x.squeeze(0)
|
| 374 |
+
|
| 375 |
+
@staticmethod
|
| 376 |
+
def _top_p_filtering(logits: torch.Tensor, p: float) -> torch.Tensor:
|
| 377 |
+
if p >= 1.0:
|
| 378 |
+
return logits
|
| 379 |
+
if p <= 0.0:
|
| 380 |
+
return torch.full_like(logits, float('-inf'))
|
| 381 |
+
x = CodonTranslator._ensure_2d_logits(logits)
|
| 382 |
+
sorted_logits, sorted_indices = torch.sort(x, descending=True, dim=-1)
|
| 383 |
+
probs = torch.softmax(sorted_logits, dim=-1)
|
| 384 |
+
cumprobs = torch.cumsum(probs, dim=-1)
|
| 385 |
+
to_remove = cumprobs > p
|
| 386 |
+
# Avoid overlapping memory writes by cloning the RHS
|
| 387 |
+
to_remove = to_remove.to(torch.bool)
|
| 388 |
+
to_remove[:, 1:] = to_remove[:, :-1].clone()
|
| 389 |
+
to_remove[:, 0] = False
|
| 390 |
+
mask = torch.zeros_like(x, dtype=torch.bool).scatter(-1, sorted_indices, to_remove)
|
| 391 |
+
x = torch.where(mask, torch.full_like(x, float('-inf')), x)
|
| 392 |
+
return x if logits.dim() == 2 else x.squeeze(0)
|
| 393 |
+
|
| 394 |
+
# --- Qwen embedding fallback for species text ---
|
| 395 |
+
    def _qwen_embed_names(self, names: List[str], pooling: str = "sequence") -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Embed species names with Qwen3-Embedding-0.6B.

        Names are first expanded to taxonomy texts (GBIF-enriched when enabled),
        wrapped in an instruct prompt, then encoded; token embeddings are
        L2-normalized along the last dim.

        Returns:
            (hidden, lengths): per-token embeddings of shape [B, L, D] (L is the
            padded token length) and the per-row count of real (non-pad) tokens.

        NOTE(review): the `pooling` parameter is accepted but never used — only
        per-token "sequence" embeddings are produced.
        NOTE(review): tokenizer and model are re-downloaded/instantiated on every
        call; consider caching them on `self` if this is called repeatedly.
        """
        from transformers import AutoTokenizer, AutoModel
        tokenizer = AutoTokenizer.from_pretrained(
            "Qwen/Qwen3-Embedding-0.6B", trust_remote_code=True, padding_side="left"
        )
        # fp16 on GPU, fp32 on CPU (fp16 matmuls are poorly supported on CPU).
        dtype = torch.float16 if self.device.type == "cuda" else torch.float32
        model = AutoModel.from_pretrained(
            "Qwen/Qwen3-Embedding-0.6B", dtype=dtype, trust_remote_code=True
        ).to(self.device).eval()
        task = (
            "Given a species taxonomy information, generate a biological embedding "
            "representing its taxonomic and evolutionary characteristics"
        )
        queries = self._resolve_taxonomy_texts(names)
        texts = [f"Instruct: {task}\nQuery: {q}" for q in queries]
        inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512).to(self.device)
        out = model(**inputs)
        h = torch.nn.functional.normalize(out.last_hidden_state, p=2, dim=-1)
        attn = inputs["attention_mask"]
        # sequence embeddings padded to same length by tokenizer padding
        return h, torch.sum(attn, dim=1)
|
| 416 |
+
|
| 417 |
+
def _taxonomy_lookup(self, name: str) -> str:
|
| 418 |
+
if name in self._taxonomy_cache:
|
| 419 |
+
return self._taxonomy_cache[name]
|
| 420 |
+
if self._use_gbif:
|
| 421 |
+
try:
|
| 422 |
+
import requests
|
| 423 |
+
resp = requests.get("https://api.gbif.org/v1/species/match", params={"name": name}, timeout=5)
|
| 424 |
+
if resp.status_code == 200:
|
| 425 |
+
data = resp.json()
|
| 426 |
+
if data.get("matchType") != "NONE":
|
| 427 |
+
parts = []
|
| 428 |
+
taxonomy = []
|
| 429 |
+
for rank in ["kingdom", "phylum", "class", "order", "family", "genus", "species"]:
|
| 430 |
+
if rank in data and data[rank]:
|
| 431 |
+
taxonomy.append(data[rank])
|
| 432 |
+
if taxonomy:
|
| 433 |
+
parts.append("Taxonomy: " + " > ".join(taxonomy))
|
| 434 |
+
if "vernacularName" in data and data["vernacularName"]:
|
| 435 |
+
parts.append(f"Common name: {data['vernacularName']}")
|
| 436 |
+
if "confidence" in data:
|
| 437 |
+
parts.append(f"Match confidence: {data['confidence']}%")
|
| 438 |
+
if "status" in data:
|
| 439 |
+
parts.append(f"Status: {data['status']}")
|
| 440 |
+
desc = ". ".join(parts) if parts else name
|
| 441 |
+
self._taxonomy_cache[name] = desc
|
| 442 |
+
return desc
|
| 443 |
+
except Exception:
|
| 444 |
+
pass
|
| 445 |
+
return name
|
| 446 |
+
|
| 447 |
+
def _resolve_taxonomy_texts(self, names: List[str]) -> List[str]:
|
| 448 |
+
"""Resolve taxonomy strings for a batch of species names.
|
| 449 |
+
If a taxonomy DB is present, pull from it. Otherwise batch-query GBIF
|
| 450 |
+
(one request per species) and cache results. Always returns a list of
|
| 451 |
+
strings aligned to `names`.
|
| 452 |
+
"""
|
| 453 |
+
results: List[str] = []
|
| 454 |
+
# Batch “query”: loop per-name; still batched at the embedding stage
|
| 455 |
+
fetched = 0
|
| 456 |
+
for s in names:
|
| 457 |
+
txt = self._taxonomy_lookup(s)
|
| 458 |
+
if s in self._taxonomy_cache:
|
| 459 |
+
fetched += 1
|
| 460 |
+
results.append(txt)
|
| 461 |
+
if self._use_gbif:
|
| 462 |
+
print(f"[CodonTranslator] Taxonomy texts resolved (GBIF={'on' if self._use_gbif else 'off'}): {fetched}/{len(names)} fetched")
|
| 463 |
+
return results
|
| 464 |
+
|
| 465 |
+
@staticmethod
|
| 466 |
+
def _dna_to_aa(dna_seq: str) -> str:
|
| 467 |
+
g = {
|
| 468 |
+
'TTT': 'F', 'TTC': 'F', 'TTA': 'L', 'TTG': 'L', 'TCT': 'S', 'TCC': 'S', 'TCA': 'S', 'TCG': 'S',
|
| 469 |
+
'TAT': 'Y', 'TAC': 'Y', 'TAA': '*', 'TAG': '*', 'TGT': 'C', 'TGC': 'C', 'TGA': '*', 'TGG': 'W',
|
| 470 |
+
'CTT': 'L', 'CTC': 'L', 'CTA': 'L', 'CTG': 'L', 'CCT': 'P', 'CCC': 'P', 'CCA': 'P', 'CCG': 'P',
|
| 471 |
+
'CAT': 'H', 'CAC': 'H', 'CAA': 'Q', 'CAG': 'Q', 'CGT': 'R', 'CGC': 'R', 'CGA': 'R', 'CGG': 'R',
|
| 472 |
+
'ATT': 'I', 'ATC': 'I', 'ATA': 'I', 'ATG': 'M', 'ACT': 'T', 'ACC': 'T', 'ACA': 'T', 'ACG': 'T',
|
| 473 |
+
'AAT': 'N', 'AAC': 'N', 'AAA': 'K', 'AAG': 'K', 'AGT': 'S', 'AGC': 'S', 'AGA': 'R', 'AGG': 'R',
|
| 474 |
+
'GTT': 'V', 'GTC': 'V', 'GTA': 'V', 'GTG': 'V', 'GCT': 'A', 'GCC': 'A', 'GCA': 'A', 'GCG': 'A',
|
| 475 |
+
'GAT': 'D', 'GAC': 'D', 'GAA': 'E', 'GAG': 'E', 'GGT': 'G', 'GGC': 'G', 'GGA': 'G', 'GGG': 'G'
|
| 476 |
+
}
|
| 477 |
+
L = len(dna_seq) // 3
|
| 478 |
+
aa = [g.get(dna_seq[3*i:3*i+3], 'X') for i in range(L)]
|
| 479 |
+
return ''.join(aa)
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2026 CodonTranslator authors
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
library_name: pytorch
|
| 4 |
+
tags:
|
| 5 |
+
- biology
|
| 6 |
+
- dna
|
| 7 |
+
- codon-optimization
|
| 8 |
+
- protein-conditioned-generation
|
| 9 |
+
- fsdp
|
| 10 |
+
datasets:
|
| 11 |
+
- alegendaryfish/CodonTranslator-data
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
# CodonTranslator
|
| 15 |
+
|
| 16 |
+
CodonTranslator is a protein-conditioned codon sequence generation model trained on the representative-only `data_v3` release.
|
| 17 |
+
|
| 18 |
+
This repository is the public model and training-code release. It contains:
|
| 19 |
+
|
| 20 |
+
- `final_model/`: inference-ready weights
|
| 21 |
+
- `training_checkpoints/checkpoint-71000/`: a resumable training checkpoint
|
| 22 |
+
- `src/`, `train.py`, `sampling.py`: training and inference code
|
| 23 |
+
- `resplit_data_v3.py`: the `data_v3` reconstruction pipeline
|
| 24 |
+
- `slurm/`: the single-node H200 training and data rebuild submission scripts
|
| 25 |
+
- `CodonTranslator/` and `pyproject.toml`: a lightweight packaged inference wrapper
|
| 26 |
+
|
| 27 |
+
## Training configuration
|
| 28 |
+
|
| 29 |
+
- Architecture: `hidden=750`, `layers=20`, `heads=15`, `mlp_ratio=3.2`
|
| 30 |
+
- Attention: `mha`
|
| 31 |
+
- Precision: `bf16`
|
| 32 |
+
- Parallelism: FSDP full shard
|
| 33 |
+
- Effective global batch: `1536`
|
| 34 |
+
- Weight decay: `1e-4`
|
| 35 |
+
- Dataset: `alegendaryfish/CodonTranslator-data`
|
| 36 |
+
|
| 37 |
+
## Dataset release
|
| 38 |
+
|
| 39 |
+
The corresponding public dataset and species embedding release is:
|
| 40 |
+
|
| 41 |
+
- `alegendaryfish/CodonTranslator-data`
|
| 42 |
+
|
| 43 |
+
That dataset repo contains:
|
| 44 |
+
|
| 45 |
+
- final representative-only `train/`, `val/`, `test/` parquet shards
|
| 46 |
+
- `embeddings_v2/`
|
| 47 |
+
- split audit files and reconstruction metadata
|
| 48 |
+
|
| 49 |
+
## Quick start
|
| 50 |
+
|
| 51 |
+
### Install
|
| 52 |
+
|
| 53 |
+
```bash
|
| 54 |
+
git clone https://huggingface.co/alegendaryfish/CodonTranslator
|
| 55 |
+
cd CodonTranslator
|
| 56 |
+
pip install -r requirements.txt
|
| 57 |
+
pip install -e .
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
Both import styles are supported:
|
| 61 |
+
|
| 62 |
+
```python
|
| 63 |
+
from CodonTranslator import CodonTranslator
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
```python
|
| 67 |
+
from codontranslator import CodonTranslator
|
| 68 |
+
```
|
| 69 |
+
|
| 70 |
+
### Train
|
| 71 |
+
|
| 72 |
+
```bash
|
| 73 |
+
python train.py \
|
| 74 |
+
--train_data /path/to/train \
|
| 75 |
+
--val_data /path/to/val \
|
| 76 |
+
--embeddings_dir /path/to/embeddings_v2 \
|
| 77 |
+
--output_dir outputs \
|
| 78 |
+
--fsdp \
|
| 79 |
+
--bf16 \
|
| 80 |
+
--attn mha \
|
| 81 |
+
--hidden 750 \
|
| 82 |
+
--layers 20 \
|
| 83 |
+
--heads 15 \
|
| 84 |
+
--mlp_ratio 3.2 \
|
| 85 |
+
--batch_size 48 \
|
| 86 |
+
--grad_accum 4 \
|
| 87 |
+
--epochs 3 \
|
| 88 |
+
--lr 7e-5 \
|
| 89 |
+
--weight_decay 1e-4
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
### Sample
|
| 93 |
+
|
| 94 |
+
```bash
|
| 95 |
+
python sampling.py \
|
| 96 |
+
--model_path final_model \
|
| 97 |
+
--embeddings_dir /path/to/embeddings_v2 \
|
| 98 |
+
--species "Panicum hallii" \
|
| 99 |
+
--protein_sequence "MSEQUENCE" \
|
| 100 |
+
--strict_species_lookup
|
| 101 |
+
```
|
| 102 |
+
|
| 103 |
+
## Notes
|
| 104 |
+
|
| 105 |
+
- Training uses precomputed `embeddings_v2` for species conditioning.
|
| 106 |
+
- The data split is built in protein space with MMseqs clustering and binomial-species test holdout.
|
| 107 |
+
- `checkpoint-71000` is included for training resumption; `final_model/` is the recommended inference entrypoint.
|
| 108 |
+
- For compatibility, released model directories contain both `trainer_config.json` and `config.json`.
|
| 109 |
+
|
| 110 |
+
## Sampling arguments
|
| 111 |
+
|
| 112 |
+
- `enforce_mapping`: when `True`, each generated codon is constrained to encode the provided amino acid at that position.
|
| 113 |
+
- `temperature`: softmax temperature. Lower values are more deterministic; `0` selects argmax greedily.
|
| 114 |
+
- `top_k`: keep only the `k` highest-logit codon candidates before sampling.
|
| 115 |
+
- `top_p`: nucleus sampling threshold; keep the smallest probability mass whose cumulative sum reaches `p`.
|
__pycache__/precompute_embeddings.cpython-312.pyc
ADDED
|
Binary file (24.1 kB). View file
|
|
|
__pycache__/resplit_data_v3.cpython-312.pyc
ADDED
|
Binary file (57.6 kB). View file
|
|
|
__pycache__/sampling.cpython-312.pyc
ADDED
|
Binary file (17.7 kB). View file
|
|
|
__pycache__/train.cpython-312.pyc
ADDED
|
Binary file (21.8 kB). View file
|
|
|
batch_eval.py
ADDED
|
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Run eval.py across all checkpoints and datasets in parallel (multi-GPU),
|
| 4 |
+
and collect results to ./eval.csv.
|
| 5 |
+
|
| 6 |
+
- Discovers checkpoints under outputs/checkpoint-*
|
| 7 |
+
- Evaluates on: data/test/*.parquet and data/val/*.parquet
|
| 8 |
+
- Uses up to N GPUs concurrently (default: 4) by setting CUDA_VISIBLE_DEVICES
|
| 9 |
+
- Parses the "Summary ..." line(s) from eval.py logs
|
| 10 |
+
- Appends rows to ./eval.csv
|
| 11 |
+
|
| 12 |
+
Example:
|
| 13 |
+
python batch_eval.py \
|
| 14 |
+
--outputs_dir outputs \
|
| 15 |
+
--embeddings_dir embeddings \
|
| 16 |
+
--datasets data/test/*.parquet data/val/*.parquet \
|
| 17 |
+
--splits test val \
|
| 18 |
+
--num_samples 12800 \
|
| 19 |
+
--batch_size 4 \
|
| 20 |
+
--gpus 0 1 2 3 \
|
| 21 |
+
--eval_script eval.py \
|
| 22 |
+
--device cuda
|
| 23 |
+
|
| 24 |
+
Notes:
|
| 25 |
+
- This script *does not* modify your eval.py. It just orchestrates/launches it.
|
| 26 |
+
- Requires Python 3.8+ (standard library only).
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
import argparse
|
| 30 |
+
import csv
|
| 31 |
+
import os
|
| 32 |
+
import re
|
| 33 |
+
import sys
|
| 34 |
+
import time
|
| 35 |
+
import glob
|
| 36 |
+
import queue
|
| 37 |
+
import threading
|
| 38 |
+
import subprocess
|
| 39 |
+
from pathlib import Path
|
| 40 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# Summary-line patterns emitted by eval.py.
# Teacher-forced (random-subset) mode:
#   "Summary over N samples → CE=... CODON-acc=... AA-acc=..."
TF_SUMMARY_RE = re.compile(
    r"Summary over\s+(\d+)\s+samples\s+→.*?CE=([-\d\.eE]+).*?CODON-acc=([-\d\.eE]+).*?AA-acc=([-\d\.eE]+)"
)
# Streaming --eval_all mode:
#   "Full-dataset summary ... tokens=N CE=... CODON-acc=... AA-acc=..."
EVALALL_SUMMARY_RE = re.compile(
    r"Full-dataset summary.*?tokens=(\d+).*?CE=([-\d\.eE]+).*?CODON-acc=([-\d\.eE]+).*?AA-acc=([-\d\.eE]+)"
)

# Column order for eval.csv rows (consumed by ensure_csv / append_row).
CSV_FIELDS = [
    # Run identity
    "timestamp_iso",
    "model_path",
    "checkpoint_step",
    "split",
    "data_path",
    # Eval configuration
    "num_samples",
    "batch_size",
    "seed",
    "eval_all",
    "gpu_id",
    # Results
    "runtime_sec",
    "tokens",
    "mean_ce",
    "mean_codon_acc",
    "mean_aa_acc",
    # Bookkeeping
    "status",
    "error",
    "command",
]
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def parse_args():
    """Parse command-line options for the batch evaluator."""
    parser = argparse.ArgumentParser(description="Parallel evaluator for CodonGPT checkpoints.")
    # Input locations
    parser.add_argument("--outputs_dir", type=str, default="outputs/", help="Folder containing checkpoint-* subdirs.")
    parser.add_argument("--embeddings_dir", type=str, default="embeddings/", help="Embeddings dir to pass to eval.py")
    parser.add_argument("--datasets", nargs="+", default=["data/test/*.parquet", "data/val/*.parquet"],
                        help="One or more dataset globs.")
    parser.add_argument("--splits", nargs="+", default=["test", "val"],
                        help="Split names aligned with --datasets (same length).")
    # Evaluation parameters forwarded to eval.py
    parser.add_argument("--num_samples", type=int, default=12800, help="num_samples for eval.py (random subset mode)")
    parser.add_argument("--batch_size", type=int, default=4, help="batch_size for eval.py")
    parser.add_argument("--seed", type=int, default=42, help="seed for eval.py")
    parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"], help="Device flag for eval.py")
    # Orchestration
    parser.add_argument("--gpus", nargs="+", default=["0", "1", "2", "3"], help="GPU IDs to use (as CUDA_VISIBLE_DEVICES)")
    parser.add_argument("--eval_script", type=str, default="eval.py", help="Path to eval.py")
    parser.add_argument("--csv_path", type=str, default="eval.csv", help="Output CSV file")
    parser.add_argument("--eval_all", action="store_true",
                        help="Use eval.py --eval_all (streaming, no num_samples). If set, ignores --num_samples.")
    parser.add_argument("--workers", type=int, default=4,
                        help="--workers passed to eval.py when --eval_all is set.")
    parser.add_argument("--dry_run", action="store_true", help="List planned runs but do not execute.")
    # Filtering / resume options
    parser.add_argument("--start_after_step", type=int, default=-1,
                        help="Only evaluate checkpoints with step > this value (e.g., 73700)")
    parser.add_argument("--end_step", type=int, default=-1,
                        help="If >0, only evaluate checkpoints with step <= this value")
    parser.add_argument("--skip_existing", dest="skip_existing", action="store_true", default=True,
                        help="Skip tasks already recorded as OK in csv_path")
    parser.add_argument("--no-skip-existing", dest="skip_existing", action="store_false",
                        help="Do not skip existing OK rows; re-run everything in range")
    return parser.parse_args()
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def natural_step(dirpath: Path) -> int:
    """Extract the integer step from a dir name like 'checkpoint-21000'.

    Returns -1 when the name does not contain a checkpoint step.
    """
    match = re.search(r"checkpoint-(\d+)", dirpath.name)
    if match is None:
        return -1
    return int(match.group(1))
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def discover_checkpoints(outputs_dir: str) -> list[Path]:
    """Return checkpoint-* dirs under *outputs_dir* that look like real
    checkpoints (config + weights present), sorted by training step."""

    def _step_of(ckpt: Path) -> int:
        # Same extraction as natural_step, inlined so sorting is self-contained.
        m = re.search(r"checkpoint-(\d+)", ckpt.name)
        return int(m.group(1)) if m else -1

    candidates = [
        Path(entry)
        for entry in glob.glob(os.path.join(outputs_dir, "checkpoint-*"))
        if os.path.isdir(entry)
    ]
    candidates.sort(key=_step_of)

    def _looks_complete(ckpt: Path) -> bool:
        # A usable checkpoint needs some config file and some weights file.
        config_ok = (ckpt / "config.json").exists() or (ckpt / "trainer_config.json").exists()
        weights_ok = (ckpt / "model.safetensors").exists() or (ckpt / "pytorch_model.bin").exists()
        return config_ok and weights_ok

    return [ckpt for ckpt in candidates if _looks_complete(ckpt)]
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def build_cmd(py_exec: str,
              eval_script: str,
              model_path: str,
              data_path: str,
              embeddings_dir: str,
              device: str,
              num_samples: int,
              batch_size: int,
              seed: int,
              eval_all: bool,
              workers: int) -> list[str]:
    """Assemble the eval.py command line for one evaluation run."""
    base = [
        py_exec, eval_script,
        "--model_path", model_path,
        "--data_path", data_path,
        "--embeddings_dir", embeddings_dir,
        "--batch_size", str(batch_size),
        "--device", device,
        "--seed", str(seed),
    ]
    # Streaming full-dataset mode and random-subset mode are mutually exclusive.
    if eval_all:
        tail = ["--eval_all", "--workers", str(workers)]
    else:
        tail = ["--num_samples", str(num_samples)]
    return base + tail
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def parse_metrics(stdout: str, stderr: str) -> dict:
    """Extract summary metrics from eval.py output.

    Returns a dict with keys tokens / mean_ce / mean_codon_acc / mean_aa_acc
    (all strings; ``tokens`` is empty for the teacher-forced format).
    Raises ValueError when neither summary format is present.
    """
    combined = f"{stdout}\n{stderr}"

    # --eval_all streaming format carries a token count.
    match = EVALALL_SUMMARY_RE.search(combined)
    if match:
        tokens, ce, codon_acc, aa_acc = match.groups()
        return {"tokens": tokens, "mean_ce": ce, "mean_codon_acc": codon_acc, "mean_aa_acc": aa_acc}

    # Teacher-forced (random subset) format: the sample count is not recorded.
    match = TF_SUMMARY_RE.search(combined)
    if match:
        _sample_count, ce, codon_acc, aa_acc = match.groups()
        return {"tokens": "", "mean_ce": ce, "mean_codon_acc": codon_acc, "mean_aa_acc": aa_acc}

    raise ValueError("Could not find summary line in eval.py output.")
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def run_one(task: dict, gpu_queue: "queue.Queue[str]", csv_lock: threading.Lock) -> dict:
    """
    Execute one eval.py call using a GPU from the queue. Returns a row dict for CSV.

    The GPU id is always returned to the pool, and every failure mode (launch
    error, unparsable output, non-zero exit) is reported via the row's
    status/error fields rather than raised, so one bad run cannot crash the
    whole batch when the orchestrator calls ``future.result()``.
    (csv_lock is accepted for interface compatibility; the caller serializes
    CSV writes itself.)
    """
    gpu_id = gpu_queue.get()  # blocks until a GPU id is available
    start = time.time()
    status = "OK"
    err_text = ""
    # Default metrics so the row is always well-formed, even on failure.
    metrics = {"tokens": "", "mean_ce": "", "mean_codon_acc": "", "mean_aa_acc": ""}

    try:
        env = os.environ.copy()
        # Pin the subprocess to a single GPU
        env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
        env.setdefault("TOKENIZERS_PARALLELISM", "false")
        env.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")

        result = None
        try:
            result = subprocess.run(
                task["cmd"],
                env=env,
                capture_output=True,
                text=True,
                check=False,
            )
        except Exception as e:
            # BUGFIX: a launch failure (missing interpreter/script, OSError)
            # previously propagated out of the worker and aborted the whole
            # orchestrator via future.result(). Record it as a FAIL row instead.
            status = "FAIL"
            err_text = f"Failed to launch eval subprocess: {e}"

        if result is not None:
            try:
                metrics = parse_metrics(result.stdout, result.stderr)
            except Exception as e:
                status = "FAIL"
                err_text = f"{e}\n--- STDOUT ---\n{result.stdout}\n--- STDERR ---\n{result.stderr}"
                metrics = {"tokens": "", "mean_ce": "", "mean_codon_acc": "", "mean_aa_acc": ""}

            if result.returncode != 0 and status == "OK":
                status = "FAIL"
                err_text = f"Non-zero exit code {result.returncode}\n--- STDOUT ---\n{result.stdout}\n--- STDERR ---\n{result.stderr}"

    finally:
        runtime = time.time() - start
        gpu_queue.put(gpu_id)  # release GPU

    row = {
        "timestamp_iso": time.strftime("%Y-%m-%dT%H:%M:%S"),
        "model_path": task["model_path"],
        "checkpoint_step": task["step"],
        "split": task["split"],
        "data_path": task["data_path"],
        "num_samples": task["num_samples"] if not task["eval_all"] else "",
        "batch_size": task["batch_size"],
        "seed": task["seed"],
        "eval_all": str(task["eval_all"]),
        "gpu_id": str(gpu_id),
        "runtime_sec": f"{runtime:.2f}",
        "tokens": metrics.get("tokens", ""),
        "mean_ce": metrics.get("mean_ce", ""),
        "mean_codon_acc": metrics.get("mean_codon_acc", ""),
        "mean_aa_acc": metrics.get("mean_aa_acc", ""),
        "status": status,
        "error": err_text.strip(),
        "command": " ".join(task["cmd"]),
    }
    return row
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def ensure_csv(path: str):
    """Create the results CSV with a header row if it is missing or empty."""
    exists_nonempty = os.path.exists(path) and os.path.getsize(path) > 0
    if exists_nonempty:
        return
    with open(path, "w", newline="") as handle:
        csv.DictWriter(handle, fieldnames=CSV_FIELDS).writeheader()
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def read_completed_keys(path: str) -> set[tuple[int, str, str]]:
    """
    Return (step, split, data_path) triples for rows already recorded as OK.

    A missing, empty, or malformed CSV yields an empty (or partial) set:
    resume logic is deliberately best-effort.
    """
    done: set[tuple[int, str, str]] = set()
    if not os.path.exists(path) or os.path.getsize(path) == 0:
        return done
    try:
        with open(path, "r", newline="") as handle:
            for record in csv.DictReader(handle):
                if (record.get("status") or "").strip().upper() != "OK":
                    continue
                try:
                    step = int(record.get("checkpoint_step", "-1"))
                except ValueError:
                    # Row with a non-numeric step: ignore it.
                    continue
                done.add((step, record.get("split", ""), record.get("data_path", "")))
    except Exception:
        # Malformed CSV: keep whatever was parsed before the error.
        pass
    return done
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def append_row(path: str, row: dict, lock: threading.Lock):
    """Append one result row to the CSV; *lock* serializes concurrent writers."""
    with lock, open(path, "a", newline="") as handle:
        csv.DictWriter(handle, fieldnames=CSV_FIELDS).writerow(row)
        handle.flush()
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def main():
    """Orchestrate parallel eval.py runs over all discovered checkpoints."""
    args = parse_args()

    if len(args.datasets) != len(args.splits):
        print("ERROR: --datasets and --splits must have the same length.", file=sys.stderr)
        sys.exit(2)

    checkpoints = discover_checkpoints(args.outputs_dir)
    if not checkpoints:
        print(f"No checkpoints found in {args.outputs_dir}/checkpoint-*", file=sys.stderr)
        sys.exit(1)

    print(f"Discovered {len(checkpoints)} checkpoints.")
    ds_pairs = list(zip(args.splits, args.datasets))
    print(f"Datasets: {', '.join([f'{s}:{d}' for s, d in ds_pairs])}")
    print(f"GPUs: {', '.join(args.gpus)}")
    print(f"Writing results to: {args.csv_path}")
    if args.start_after_step >= 0:
        print(f"Filtering: step > {args.start_after_step}")
    if args.end_step > 0:
        print(f"Filtering: step <= {args.end_step}")
    print(f"Skip existing OK rows in CSV: {args.skip_existing}")

    # Build the work list: one task per (checkpoint, dataset) pair that
    # passes the step filters and is not already recorded as OK.
    py_exec = sys.executable
    completed_keys = read_completed_keys(args.csv_path) if args.skip_existing else set()
    tasks = []
    for checkpoint_dir in checkpoints:
        step = natural_step(checkpoint_dir)
        if args.start_after_step >= 0 and step <= args.start_after_step:
            continue
        if args.end_step > 0 and step > args.end_step:
            continue
        for split, data_path in ds_pairs:
            if (step, split, data_path) in completed_keys:
                continue
            tasks.append({
                "model_path": str(checkpoint_dir),
                "step": step,
                "split": split,
                "data_path": data_path,
                "num_samples": args.num_samples,
                "batch_size": args.batch_size,
                "seed": args.seed,
                "eval_all": args.eval_all,
                "cmd": build_cmd(
                    py_exec=py_exec,
                    eval_script=args.eval_script,
                    model_path=str(checkpoint_dir),
                    data_path=data_path,
                    embeddings_dir=args.embeddings_dir,
                    device=args.device,
                    num_samples=args.num_samples,
                    batch_size=args.batch_size,
                    seed=args.seed,
                    eval_all=args.eval_all,
                    workers=args.workers,
                ),
            })

    if args.dry_run:
        for task in tasks:
            print(f"[DRY RUN] GPU=? step={task['step']} split={task['split']} -> {' '.join(task['cmd'])}")
        print(f"Planned runs: {len(tasks)}")
        return

    ensure_csv(args.csv_path)
    csv_lock = threading.Lock()

    # GPU pool: workers borrow ids from this queue, so at most one eval
    # subprocess runs per GPU at a time.
    gpu_queue: "queue.Queue[str]" = queue.Queue()
    for gid in args.gpus:
        gpu_queue.put(str(gid))

    max_workers = max(1, len(args.gpus))
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(run_one, task, gpu_queue, csv_lock) for task in tasks]
        total = len(futures)
        completed = 0
        for future in as_completed(futures):
            row = future.result()
            append_row(args.csv_path, row, csv_lock)
            completed += 1
            if row["status"] == "OK":
                print(f"[{completed}/{total}] ✅ step={row['checkpoint_step']} split={row['split']} "
                      f"CE={row['mean_ce']} CODON={row['mean_codon_acc']} AA={row['mean_aa_acc']} "
                      f"gpu={row['gpu_id']} in {row['runtime_sec']}s")
            else:
                print(f"[{completed}/{total}] ❌ step={row['checkpoint_step']} split={row['split']} "
                      f"gpu={row['gpu_id']} See CSV 'error' column for details.")

    print(f"Done. Results appended to {args.csv_path}")
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
# Script entry point: run the batch evaluator only when executed directly.
if __name__ == "__main__":
    main()
|
codontranslator/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from CodonTranslator import CodonTranslator
|
| 2 |
+
|
| 3 |
+
__all__ = ["CodonTranslator"]
|
environment.yml
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: codontranslator
|
| 2 |
+
channels:
|
| 3 |
+
- conda-forge
|
| 4 |
+
- pytorch
|
| 5 |
+
- nvidia
|
| 6 |
+
dependencies:
|
| 7 |
+
- python=3.12
|
| 8 |
+
- pip
|
| 9 |
+
- pytorch>=2.4
|
| 10 |
+
- pandas>=2.3
|
| 11 |
+
- pyarrow>=21.0
|
| 12 |
+
- duckdb>=1.5
|
| 13 |
+
- biopython>=1.85
|
| 14 |
+
- pip:
|
| 15 |
+
- transformers>=4.57.0
|
| 16 |
+
- esm>=3.2.3
|
| 17 |
+
- safetensors>=0.7.0
|
| 18 |
+
- huggingface-hub>=0.36.0
|
| 19 |
+
- accelerate>=1.9.0
|
| 20 |
+
- wandb>=0.21.0
|
eval.py
ADDED
|
@@ -0,0 +1,1239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
"""
|
| 3 |
+
Teacher-forced (and optional free-run) evaluation on a random subset of your
|
| 4 |
+
dataset to measure codon token cross-entropy and AA token accuracy, using the
|
| 5 |
+
same conditioning pathway as training.
|
| 6 |
+
|
| 7 |
+
Supports either a CSV file or Parquet input via a directory/glob (e.g.,
|
| 8 |
+
./data/val/*.parquet).
|
| 9 |
+
|
| 10 |
+
Usage examples:
|
| 11 |
+
# CSV input
|
| 12 |
+
python eval.py \
|
| 13 |
+
--model_path outputs/checkpoint-21000 \
|
| 14 |
+
--data_path random_sample_1000.csv \
|
| 15 |
+
--embeddings_dir embeddings \
|
| 16 |
+
--num_samples 10 \
|
| 17 |
+
--batch_size 10 \
|
| 18 |
+
--device cuda
|
| 19 |
+
|
| 20 |
+
# Parquet glob input
|
| 21 |
+
python eval.py \
|
| 22 |
+
--model_path outputs/checkpoint-21000 \
|
| 23 |
+
--data_path "./data/val/*.parquet" \
|
| 24 |
+
--embeddings_dir embeddings \
|
| 25 |
+
--num_samples 64 \
|
| 26 |
+
--batch_size 32 \
|
| 27 |
+
--device cuda
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
import argparse
|
| 31 |
+
import json
|
| 32 |
+
import logging
|
| 33 |
+
import random
|
| 34 |
+
from pathlib import Path
|
| 35 |
+
from typing import List, Optional, Tuple
|
| 36 |
+
import glob
|
| 37 |
+
|
| 38 |
+
import torch
|
| 39 |
+
import torch.nn.functional as F
|
| 40 |
+
import pandas as pd
|
| 41 |
+
|
| 42 |
+
from src.sampler import CodonSampler
|
| 43 |
+
from src.dataset import SpeciesEmbeddingStore, StreamSeqDataset, stage_collate_fn
|
| 44 |
+
from torch.utils.data import DataLoader
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# Configure root logging once for this script; module logger is named "eval_tf".
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
)
logger = logging.getLogger("eval_tf")
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def parse_args():
    """Build and parse the CLI for teacher-forced / free-run CodonGPT evaluation.

    Returns the parsed ``argparse.Namespace``. Option names, defaults and help
    text define the script's public CLI surface.
    """
    parser = argparse.ArgumentParser("Teacher-forced evaluation of CodonGPT")

    # --- model & input data ---
    parser.add_argument("--model_path", required=True, type=str,
                        help="Path to checkpoint dir (with config.json / model.safetensors)")
    # New-style input: a CSV file or a Parquet glob/directory.
    parser.add_argument("--data_path", required=False, type=str, default=None,
                        help="CSV file or Parquet glob/dir (e.g., ./data/val/*.parquet)")
    # Kept for backward compatibility; superseded by --data_path.
    parser.add_argument("--csv_path", required=False, type=str, default=None,
                        help="[Deprecated] CSV with columns: Taxon, protein_seq, cds_DNA")
    parser.add_argument("--embeddings_dir", type=str, default=None,
                        help="Species embeddings directory (recommended for parity)")

    # --- sampling-size / runtime knobs ---
    parser.add_argument("--num_samples", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=10)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--workers", type=int, default=0,
                        help="DataLoader workers for --eval_all streaming mode")

    # --- free-run (sampling) evaluation ---
    parser.add_argument("--free_run", action="store_true",
                        help="If set, perform real sampling instead of teacher forcing and compare to ground-truth codon sequences")
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top_k", type=int, default=50)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--control_mode", type=str, choices=["fixed","variable"], default="fixed")
    parser.add_argument("--enforce_translation", action="store_true",
                        help="Hard-mask decoding to codons matching target amino acid at each position during free-run evaluation")

    # --- full-dataset streaming evaluation ---
    parser.add_argument("--eval_all", action="store_true",
                        help="Stream over all rows from --data_path and compute aggregated metrics (memory-safe)")
    parser.add_argument("--max_records", type=int, default=0,
                        help="When --eval_all is set: limit to first N samples (0 = all)")
    parser.add_argument("--debug_aa_check", action="store_true",
                        help="Print per-sample agreement between CDS→AA (standard code) and provided protein")

    # --- per-sequence CSV export over standard splits ---
    parser.add_argument("--export_per_sequence", action="store_true",
                        help="Process ./data/val and ./data/test parquets in batches and export a per-sequence CSV")
    parser.add_argument("--splits_root", type=str, default="./data",
                        help="Root directory that contains val/ and test/ subfolders with parquet files")
    parser.add_argument("--out_csv", type=str, default="outputs/eval_per_sequence.csv",
                        help="Output CSV path for per-sequence export")
    parser.add_argument("--export_splits", nargs="+", default=["val", "test"],
                        help="Subdirectories under --splits_root to process (default: val test)")
    parser.add_argument("--max_rows_per_split", type=int, default=0,
                        help="When --export_per_sequence is set: limit number of rows per split (0 = all)")
    parser.add_argument("--progress", action="store_true",
                        help="Show progress bars during per-sequence export")

    # --- capacity / truncation controls ---
    parser.add_argument("--no_truncation", action="store_true",
                        help="Fit prefix caps so generated codon length equals protein length (avoids capacity truncation)")
    parser.add_argument("--species_prefix_cap", type=int, default=0,
                        help="When >0 and --no_truncation is set, cap species token prefix to this many tokens; 0 = no species cap")

    return parser.parse_args()
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def _is_parquet_path(p: str) -> bool:
|
| 111 |
+
lower = p.lower()
|
| 112 |
+
return lower.endswith(".parquet") or lower.endswith(".parq")
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def _expand_paths(maybe_path_or_glob: str) -> List[str]:
|
| 116 |
+
"""Expand a path/glob or directory into a sorted list of files.
|
| 117 |
+
Prioritize Parquet when scanning a directory.
|
| 118 |
+
"""
|
| 119 |
+
paths: List[str] = []
|
| 120 |
+
P = Path(maybe_path_or_glob)
|
| 121 |
+
if P.is_dir():
|
| 122 |
+
paths.extend(sorted(str(x) for x in P.rglob("*.parquet")))
|
| 123 |
+
paths.extend(sorted(str(x) for x in P.rglob("*.parq")))
|
| 124 |
+
paths.extend(sorted(str(x) for x in P.rglob("*.csv")))
|
| 125 |
+
paths.extend(sorted(str(x) for x in P.rglob("*.tsv")))
|
| 126 |
+
paths.extend(sorted(str(x) for x in P.rglob("*.csv.gz")))
|
| 127 |
+
paths.extend(sorted(str(x) for x in P.rglob("*.tsv.gz")))
|
| 128 |
+
else:
|
| 129 |
+
paths = sorted(glob.glob(str(P)))
|
| 130 |
+
# Dedup while preserving order
|
| 131 |
+
out: List[str] = []
|
| 132 |
+
seen = set()
|
| 133 |
+
for x in paths:
|
| 134 |
+
if x not in seen:
|
| 135 |
+
out.append(x)
|
| 136 |
+
seen.add(x)
|
| 137 |
+
return out
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def _load_random_samples_from_parquet(files: List[str], num_samples: int, seed: int) -> pd.DataFrame:
    """Randomly gather up to *num_samples* rows from Parquet files, row group by row group.

    Only the required columns (Taxon, protein_seq, cds_DNA) are read. Files and
    row groups are visited in a seeded random order so coverage is reasonable
    without loading whole files into memory.

    Raises FileNotFoundError when no Parquet files are given, ValueError when a
    file lacks a required column, and ImportError when pyarrow is unavailable.
    """
    try:
        import pyarrow.parquet as pq  # type: ignore
    except Exception as e:  # pragma: no cover
        raise ImportError("pyarrow is required to read parquet files") from e

    rng = random.Random(seed)
    req = ["Taxon", "protein_seq", "cds_DNA"]
    parquet_files = [f for f in files if _is_parquet_path(f)]
    if not parquet_files:
        raise FileNotFoundError("No Parquet files found to read")
    parquet_files = parquet_files.copy()
    rng.shuffle(parquet_files)

    frames: List[pd.DataFrame] = []
    budget = int(max(0, num_samples))
    for fp in parquet_files:
        if budget <= 0:
            break
        pf = pq.ParquetFile(fp)
        n_groups = int(pf.num_row_groups or 0)
        group_order = list(range(n_groups)) if n_groups > 0 else [0]
        rng.shuffle(group_order)
        # Fail fast if this file lacks any required column.
        cols = [c for c in req if c in pf.schema.names]
        if len(cols) < len(req):
            missing = sorted(set(req) - set(cols))
            raise ValueError(f"Parquet missing required columns {missing} in {fp}")
        for rg in group_order:
            if budget <= 0:
                break
            df = pf.read_row_group(rg, columns=cols).to_pandas(types_mapper=None)
            if df.empty:
                continue
            if len(df) > budget:
                df = df.sample(n=budget, random_state=rng.randint(0, 2**31 - 1))
            frames.append(df)
            budget -= len(df)

    if not frames:
        return pd.DataFrame(columns=req)
    out = pd.concat(frames, ignore_index=True)
    # Shuffle the pooled rows so order does not reflect file/row-group layout.
    out = out.sample(frac=1.0, random_state=seed).reset_index(drop=True)
    # Trim any overshoot back to the requested count.
    if len(out) > num_samples:
        out = out.iloc[:num_samples].reset_index(drop=True)
    return out
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def _preferred_pooling(model_dir: Path) -> str:
|
| 197 |
+
"""
|
| 198 |
+
Best-effort pooling detection:
|
| 199 |
+
- First try checkpoint configs for an explicit hint
|
| 200 |
+
- Fallback to 'last'
|
| 201 |
+
Note: we'll further override this using the embeddings_dir contents if provided.
|
| 202 |
+
"""
|
| 203 |
+
for cfg_name in ("trainer_config.json", "config.json"):
|
| 204 |
+
fp = model_dir / cfg_name
|
| 205 |
+
if fp.exists():
|
| 206 |
+
try:
|
| 207 |
+
with open(fp) as f:
|
| 208 |
+
cfg = json.load(f)
|
| 209 |
+
return str(cfg.get("species_pooling", "last"))
|
| 210 |
+
except Exception:
|
| 211 |
+
continue
|
| 212 |
+
return "last"
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def _detect_pooling_from_embeddings_dir(emb_dir: Path) -> Optional[str]:
|
| 216 |
+
"""Detect actual available pooling format from embeddings_dir contents."""
|
| 217 |
+
fixed_files = [emb_dir / "species_embeddings.bin", emb_dir / "species_metadata.json", emb_dir / "species_vocab.json"]
|
| 218 |
+
seq_files = [emb_dir / "species_tok_emb.bin", emb_dir / "species_index.json", emb_dir / "species_vocab.json"]
|
| 219 |
+
if all(p.exists() for p in fixed_files):
|
| 220 |
+
return "last"
|
| 221 |
+
if all(p.exists() for p in seq_files):
|
| 222 |
+
return "sequence"
|
| 223 |
+
return None
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
@torch.no_grad()
def eval_batch(
    sampler: CodonSampler,
    species_store: Optional[SpeciesEmbeddingStore],
    species_names: List[str],
    protein_seqs: List[str],
    dna_cds_list: List[str],
) -> Tuple[List[float], List[float]]:
    """Evaluate a batch in teacher-forced mode.

    Builds padded codon-id inputs from the ground-truth CDS sequences, runs the
    model with the same conditioning pathway as training (species embeddings +
    raw protein sequences), and scores the logits against the ground truth.

    Returns per-sample (avg_ce_loss, aa_token_acc).

    Note: this function previously also computed a per-sample codon accuracy
    that was never returned; that dead code (and an unused padding-mask tensor)
    has been removed without changing the returned values.
    """
    tok = sampler.tokenizer
    pad_id = tok.pad_token_id
    eos_id = tok.eos_token_id

    # Encode DNA to codon ids; trim each sample to min(DNA codons, protein AA)
    # so inputs and protein conditioning stay aligned.
    codon_ids = []
    seq_lens = []
    for dna, prot in zip(dna_cds_list, protein_seqs):
        C_dna = len(dna) // 3
        C_prot = len(prot)
        C = max(min(C_dna, C_prot), 1)
        dna_trim = dna[: 3 * C]
        ids = tok.encode_codon_seq(dna_trim, validate=False)
        ids.append(eos_id)
        codon_ids.append(ids)
        seq_lens.append(len(ids))

    B = len(codon_ids)
    T = max(seq_lens)
    codons = torch.full((B, T), pad_id, dtype=torch.long)
    for i, ids in enumerate(codon_ids):
        L = len(ids)
        codons[i, :L] = torch.tensor(ids, dtype=torch.long)

    # inputs/labels aligned to training convention:
    # model predicts next codon after a learned start token; labels are the
    # same positions as inputs (not shifted by 1), with PAD/EOS masked out.
    input_ids = codons[:, :-1]
    labels_base = codons[:, :-1].clone()
    # Mask out PAD and EOS like trainer.evaluate()
    labels_base[labels_base == pad_id] = -100
    labels_base[labels_base == eos_id] = -100

    # Build conditioning dict similar to training and sampler
    cond = {"control_mode": "fixed"}

    if species_store is not None and species_names:
        sid_list = [species_store.vocab.get(s, -1) for s in species_names]
        num_unknown = sum(1 for x in sid_list if x < 0)
        if num_unknown > 0:
            logger.warning(f"{num_unknown}/{len(sid_list)} species not found in embeddings vocab; using zero embeddings")
        result = species_store.batch_get(sid_list)
        if isinstance(result, tuple):
            sp_tok, _ = result  # [B, Ls, Ds]
            cond["species_tok_emb_src"] = sp_tok.to(sampler.device)
            cond["species_tok_emb_tgt"] = sp_tok.to(sampler.device)
        else:
            sp = result  # [B, Ds]
            cond["species_emb_src"] = sp.to(sampler.device)
            cond["species_emb_tgt"] = sp.to(sampler.device)
    elif species_names:
        # On-the-fly species embeddings using Qwen (sequence pooling for training parity)
        seq_emb, _lens = sampler._qwen_embed_names(species_names, pooling="sequence")
        seq_emb = seq_emb.to(sampler.device)
        cond["species_tok_emb_src"] = seq_emb
        cond["species_tok_emb_tgt"] = seq_emb

    # Match training: pass raw protein sequences; the model tokenizes internally
    cond["protein_seqs"] = protein_seqs

    # Move tensors to device
    device = sampler.device
    input_ids = input_ids.to(device)
    labels_base = labels_base.to(device)

    sampler.model.eval()
    outputs = sampler.model(codon_ids=input_ids, cond=cond, labels=labels_base, return_dict=True)
    logits = outputs["logits"]  # [B, Lmax, V] aligned to per-sample capacity after prefix
    try:
        prefix_len = outputs.get("prefix_len", 0)
        if isinstance(prefix_len, torch.Tensor):
            prefix_len_dbg = int(prefix_len.max().item()) if prefix_len.numel() > 0 else 0
        else:
            prefix_len_dbg = int(prefix_len)
        logger.debug(f"Prefix length(max)={prefix_len_dbg}, input_len={input_ids.size(1)}")
    except Exception:
        pass

    # Align labels/masks to logits length and per-sample caps
    Bsz, Lmax, V = logits.size(0), logits.size(1), logits.size(2)
    labels_aligned = torch.full((Bsz, Lmax), -100, dtype=labels_base.dtype, device=logits.device)
    common_cols = min(labels_base.size(1), Lmax)
    if common_cols > 0:
        labels_aligned[:, :common_cols] = labels_base[:, :common_cols]
    per_cap = outputs.get("per_cap", None)
    if isinstance(per_cap, torch.Tensor) and per_cap.numel() == Bsz:
        ar = torch.arange(Lmax, device=logits.device).unsqueeze(0)
        cap_mask = ar < per_cap.to(device=logits.device).unsqueeze(1)  # [B,Lmax]
    else:
        cap_mask = torch.ones_like(labels_aligned, dtype=torch.bool, device=logits.device)

    # Mask labels beyond per-cap to -100 so CE ignores them
    labels_masked = labels_aligned.clone().to(device=logits.device)
    labels_masked[~cap_mask] = -100

    # Cross-entropy per position (include EOS target; ignore PAD)
    loss_flat = F.cross_entropy(
        logits.reshape(-1, V),
        labels_masked.reshape(-1),
        ignore_index=-100,
        reduction="none",
    ).view(Bsz, Lmax)

    preds = logits.argmax(dim=-1)
    num_special = int(getattr(tok, "num_special_tokens", 0) or 0)
    supervised = (labels_masked != -100) & cap_mask
    if num_special > 0:
        supervised = supervised & (labels_aligned >= num_special)

    per_sample_ce: List[float] = []
    per_sample_aa_acc: List[float] = []
    codon2aa = tok.codon2aa_char_map() if hasattr(tok, "codon2aa_char_map") else {}
    per_cap_int = None
    if isinstance(per_cap, torch.Tensor) and per_cap.numel() == Bsz:
        per_cap_int = torch.clamp(per_cap.to(dtype=torch.long, device=logits.device), min=0, max=Lmax)

    for i in range(B):
        # Average CE over valid positions
        valid = (labels_masked[i] != -100) & cap_mask[i]
        if num_special > 0:
            valid = valid & (labels_aligned[i] >= num_special)
        ce = (loss_flat[i][valid].mean().item() if valid.any() else 0.0)
        per_sample_ce.append(ce)

        # AA-level accuracy per sample (match trainer)
        aa_acc = 0.0
        if per_cap_int is not None and codon2aa and i < len(protein_seqs):
            cap = int(per_cap_int[i].item())
            if cap > 0:
                mask_row = supervised[i, :cap]
                if mask_row.any():
                    preds_row = preds[i, :cap][mask_row]
                    prot = protein_seqs[i]
                    seq_len = min(len(prot), preds_row.size(0))
                    if seq_len > 0:
                        pred_aa = ''.join(codon2aa.get(int(t.item()), 'X') for t in preds_row[:seq_len])
                        truth_aa = prot[:seq_len]
                        aa_matches = sum(1 for j in range(seq_len) if pred_aa[j] == truth_aa[j])
                        aa_acc = aa_matches / seq_len
        per_sample_aa_acc.append(aa_acc)

    return per_sample_ce, per_sample_aa_acc
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
def _dna_to_codons(dna: str) -> List[str]:
|
| 393 |
+
dna = dna.strip().upper()
|
| 394 |
+
return [dna[i:i+3] for i in range(0, len(dna) - (len(dna) % 3), 3)]
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
def _aa_from_dna_standard(dna: str, tok) -> str:
|
| 398 |
+
dna = dna.strip().upper()
|
| 399 |
+
gc = getattr(tok, "_genetic_code", {})
|
| 400 |
+
aa = []
|
| 401 |
+
for j in range(0, len(dna) - (len(dna) % 3), 3):
|
| 402 |
+
aa.append(gc.get(dna[j:j+3], 'X'))
|
| 403 |
+
return ''.join(aa)
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
def _aa_agreement(dna: str, protein: str, tok) -> Tuple[float, int, int]:
    """Return (match_ratio, compared_len, first_mismatch_idx or -1) under standard code."""
    seq = dna.strip().upper()
    truth_full = protein.strip().upper()
    compared = min(len(seq) // 3, len(truth_full))
    if compared <= 0:
        return 0.0, 0, -1
    predicted = _aa_from_dna_standard(seq[: 3 * compared], tok)
    truth = truth_full[:compared]
    matches = 0
    first_bad = -1
    for idx, (a, b) in enumerate(zip(predicted, truth)):
        if a == b:
            matches += 1
        elif first_bad < 0:
            first_bad = idx
    return matches / compared, compared, first_bad
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
@torch.no_grad()
def eval_streaming_all(
    sampler: CodonSampler,
    species_store: SpeciesEmbeddingStore,
    data_path: str,
    batch_size: int,
    num_workers: int,
    max_records: int = 0,
):
    """Stream over all rows from CSV/Parquet inputs and compute dataset-level metrics.

    Mirrors trainer.evaluate() for parity: token-weighted mean cross-entropy,
    codon-token accuracy over supervised positions, and AA-level accuracy.

    Returns (mean_ce, codon_acc, aa_acc).

    Fixes vs. the original: the module-level ``_expand_paths`` is reused instead
    of a byte-identical inner copy, and the inner per-batch size no longer
    shadows the ``batch_size`` parameter.
    """
    device = sampler.device
    tok = sampler.tokenizer
    pad_id = int(tok.pad_token_id)
    eos_id = int(tok.eos_token_id)
    num_special = int(tok.num_special_tokens)
    codon2aa = tok.codon2aa_char_map()

    # Build streaming dataset and loader (reuse the module-level path expander).
    paths = _expand_paths(data_path)
    if not paths:
        raise FileNotFoundError(f"No input files matched: {data_path}")

    species_vocab_path = str((Path(species_store.embeddings_dir) / "species_vocab.json").resolve())
    ds = StreamSeqDataset(
        files=paths,
        tokenizer=tok,
        species_vocab_path=species_vocab_path,
        unknown_species_id=0,
        csv_chunksize=200_000,
        shuffle_buffer=0,
        shard_across_ranks=False,
    )
    _dl_kwargs = dict(
        batch_size=int(batch_size),
        shuffle=False,
        drop_last=False,
        num_workers=int(max(0, num_workers)),
        collate_fn=stage_collate_fn,
        pin_memory=True,
        persistent_workers=(int(num_workers) > 0),
    )
    if int(num_workers) > 0:
        _dl_kwargs["prefetch_factor"] = 4
    loader = DataLoader(ds, **_dl_kwargs)

    # Running aggregates (token-weighted CE; codon- and AA-level accuracy).
    loss_sum = 0.0
    loss_tokens = 0
    codon_correct = 0
    codon_total = 0
    aa_correct = 0
    aa_total = 0

    seen = 0
    for batch in loader:
        if not batch:
            continue
        if int(max_records) > 0 and seen >= int(max_records):
            break
        codon_ids = batch["codon_ids"].to(device)
        input_ids = codon_ids[:, :-1]
        labels = codon_ids[:, :-1].clone()
        labels[labels == pad_id] = -100
        labels[labels == eos_id] = -100

        # Build cond using species_store and protein_seqs
        cond = {"control_mode": "fixed", "protein_seqs": batch.get("protein_seqs", [])}
        sids = batch.get("species_ids")
        if torch.is_tensor(sids):
            sids_list = sids.detach().cpu().tolist()
        else:
            sids_list = [int(x) for x in sids]
        res = species_store.batch_get(sids_list)
        if isinstance(res, tuple):
            sp_tok, _ = res
            cond["species_tok_emb_src"] = sp_tok.to(device)
            cond["species_tok_emb_tgt"] = sp_tok.to(device)
        else:
            cond["species_emb_src"] = res.to(device)
            cond["species_emb_tgt"] = res.to(device)

        out = sampler.model(codon_ids=input_ids, cond=cond, labels=labels, return_dict=True)
        loss = out.get("loss")
        per_cap = out.get("per_cap")
        logits = out.get("logits")

        # Token-weighted CE accumulation (weights come from per-sample caps).
        tokens_in_batch = 0
        if per_cap is not None:
            tokens_in_batch = int(torch.clamp(per_cap.detach(), min=0).sum().item())
            loss_tokens += tokens_in_batch
        if loss is not None and tokens_in_batch > 0:
            loss_sum += float(loss.detach().item()) * tokens_in_batch

        if logits is None or logits.size(1) == 0 or per_cap is None:
            seen += input_ids.size(0)
            continue
        max_cap = logits.size(1)
        bsz = logits.size(0)  # renamed: do not shadow the batch_size parameter
        labels_aligned = torch.full((bsz, max_cap), -100, dtype=labels.dtype, device=labels.device)
        common = min(labels.size(1), max_cap)
        if common > 0:
            labels_aligned[:, :common] = labels[:, :common]
        per_cap_int = torch.clamp(per_cap.to(dtype=torch.long), min=0, max=max_cap)
        for row in range(bsz):
            cap = int(per_cap_int[row].item())
            if cap < max_cap:
                labels_aligned[row, cap:] = -100
        supervised = labels_aligned != -100
        if num_special > 0:
            supervised = supervised & (labels_aligned >= num_special)
        if not supervised.any():
            seen += bsz
            continue
        preds = logits.argmax(dim=-1)
        codon_correct += int((preds[supervised] == labels_aligned[supervised]).sum().item())
        codon_total += int(supervised.sum().item())

        # AA-level accuracy against the provided protein strings
        prot_list = cond.get("protein_seqs", [])
        for row in range(bsz):
            cap = int(per_cap_int[row].item())
            if cap <= 0:
                continue
            mask_row = supervised[row, :cap]
            if not mask_row.any():
                continue
            preds_row = preds[row, :cap][mask_row]
            prot = prot_list[row] if (isinstance(prot_list, list) and row < len(prot_list)) else ""
            if not prot:
                continue
            seq_len = min(len(prot), preds_row.size(0))
            if seq_len <= 0:
                continue
            pred_aa = ''.join(codon2aa.get(int(t.item()), 'X') for t in preds_row[:seq_len])
            truth_aa = prot[:seq_len]
            aa_correct += sum(1 for i in range(seq_len) if pred_aa[i] == truth_aa[i])
            aa_total += seq_len
        seen += bsz

    mean_ce = (loss_sum / loss_tokens) if loss_tokens > 0 else 0.0
    codon_acc = (float(codon_correct) / codon_total) if codon_total > 0 else 0.0
    aa_acc = (float(aa_correct) / aa_total) if aa_total > 0 else 0.0
    logger.info(
        f"Full-dataset summary → tokens={loss_tokens} CE={mean_ce:.4f} CODON-acc={codon_acc:.4f} AA-acc={aa_acc:.4f}"
    )
    return mean_ce, codon_acc, aa_acc
|
| 594 |
+
|
| 595 |
+
|
| 596 |
+
@torch.no_grad()
def sample_and_score_batched(
    sampler: CodonSampler,
    species_names: List[str],
    protein_seqs: List[str],
    target_dnas: List[str],
    temperature: float,
    top_k: int,
    top_p: float,
    control_mode: str,
    batch_size: int,
    enforce_translation: bool,
    no_truncation: bool = False,
    species_prefix_cap: int = 64,
) -> Tuple[List[float], List[float]]:
    """Free-run sampling in batches; returns per-sample (codon_acc, aa_acc).

    Samples are bucketed by target length (in codons) so every model call in a
    bucket shares one ``sequence_length``; per-sample accuracies are written
    back to the original input order.

    Args:
        sampler: Loaded CodonSampler wrapping model/tokenizer/species store.
        species_names: Taxon name per sample.
        protein_seqs: Target protein per sample (upper-case amino acids).
        target_dnas: Ground-truth CDS DNA per sample.
        temperature, top_k, top_p: Sampling hyper-parameters.
        control_mode: Conditioning mode forwarded to ``sampler.sample``.
        batch_size: Mini-batch size used within each length bucket.
        enforce_translation: Forwarded to ``sampler.sample``.
        no_truncation: When True, temporarily shrink the model's protein-prefix
            cap so prefix + 1 + L fits ``max_position_embeddings``; caps are
            restored after each bucket.
        species_prefix_cap: Cap on species prefix tokens (legacy stores only).

    Returns:
        ``(codon_accs, aa_accs)`` — per-sample exact-codon match rate and
        amino-acid match rate, each in the original sample order.
    """
    N = len(species_names)
    # Target length per sample = min(#codons in DNA, #AAs in protein). Degenerate
    # inputs fall back to a single ATG so downstream code never sees L == 0.
    tgt_lengths = []
    tgt_codons_list = []
    for prot, dna in zip(protein_seqs, target_dnas):
        cods = _dna_to_codons(dna)
        L = min(len(cods), len(prot))
        if L <= 0:
            L = 1
            cods = ["ATG"]  # harmless default
        tgt_lengths.append(L)
        tgt_codons_list.append(cods[:L])

    # Bucket indices by target length to maximize batching
    buckets: dict[int, List[int]] = {}
    for i, L in enumerate(tgt_lengths):
        buckets.setdefault(L, []).append(i)

    codon_accs = [0.0] * N
    aa_accs = [0.0] * N

    # Hoisted loop invariant (was rebuilt once per scored sample): canonical
    # amino-acid alphabet — non-canonical target AAs count as "match any".
    canonical = set("ACDEFGHIKLMNPQRSTVWY")

    # Helper AA translation via the tokenizer's genetic-code table.
    vocab = sampler.tokenizer._genetic_code

    def dna_to_aa(dna: str) -> str:
        dna = dna.strip().upper()
        aa = []
        for j in range(0, len(dna) - (len(dna) % 3), 3):
            aa.append(vocab.get(dna[j:j+3], 'X'))
        return ''.join(aa)

    for L, idxs in buckets.items():
        # Optionally tighten protein prefix so prefix+start+L ≤ capacity (species kept full unless capped)
        prev_sp = getattr(sampler.model, "max_species_prefix", 0)
        prev_pp = getattr(sampler.model, "max_protein_prefix", 0)
        if bool(no_truncation):
            # Best-effort probe: on any failure the caps are simply left as-is.
            try:
                capacity = int(getattr(sampler.model, "max_position_embeddings", 1024))
                # If requested, apply a species token cap; otherwise keep as-is
                store = getattr(sampler, "species_store", None)
                if store is not None and getattr(store, "is_legacy", False) and int(species_prefix_cap) > 0:
                    setattr(sampler.model, "max_species_prefix", int(species_prefix_cap))
                # Build a representative cond for this bucket to measure exact prefix length
                batch_idx_probe = idxs[: min(len(idxs), max(1, min(batch_size, 8)))]
                sp_probe = [species_names[i] for i in batch_idx_probe]
                pr_probe = [protein_seqs[i] for i in batch_idx_probe]
                # Map species to ids via store vocab
                cond_probe = {"control_mode": "fixed", "protein_seqs": pr_probe}
                if store is not None:
                    sid_list = [store.vocab.get(s, -1) for s in sp_probe]
                    res = store.batch_get(sid_list)
                    if isinstance(res, tuple):
                        # Tuple result → per-token species embeddings.
                        sp_tok, _ = res
                        cond_probe["species_tok_emb_src"] = sp_tok.to(sampler.device)
                        cond_probe["species_tok_emb_tgt"] = sp_tok.to(sampler.device)
                    else:
                        # Single tensor → pooled (fixed) species embedding.
                        cond_probe["species_emb_src"] = res.to(sampler.device)
                        cond_probe["species_emb_tgt"] = res.to(sampler.device)
                # Iteratively reduce protein prefix cap until remaining ≥ L
                for _ in range(3):
                    out0 = sampler.model(
                        codon_ids=torch.zeros(len(batch_idx_probe), 0, dtype=torch.long, device=sampler.device),
                        cond=cond_probe,
                        return_dict=True,
                        use_cache=True,
                    )
                    pref = out0.get("prefix_len")
                    if isinstance(pref, torch.Tensor) and pref.numel() > 0:
                        pref_max = int(pref.max().item())
                    else:
                        pref_max = int(pref) if isinstance(pref, int) else 0
                    remaining = capacity - (pref_max + 1)
                    if remaining >= int(L):
                        break
                    need = int(L) - max(0, int(remaining))
                    cur_pp = int(getattr(sampler.model, "max_protein_prefix", 0) or 0)
                    new_pp = max(0, cur_pp - need) if cur_pp > 0 else max(0, pref_max - (capacity - 1 - int(L)))
                    setattr(sampler.model, "max_protein_prefix", int(new_pp))
            except Exception:
                pass
        # Process in mini-batches
        for k in range(0, len(idxs), batch_size):
            batch_idx = idxs[k:k+batch_size]
            sp_b = [species_names[i] for i in batch_idx]
            pr_b = [protein_seqs[i] for i in batch_idx]
            # Sample in one call
            out = sampler.sample(
                num_sequences=len(batch_idx),
                sequence_length=L,
                species=sp_b,
                protein_sequences=pr_b,
                control_mode=control_mode,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                return_intermediate=False,
                progress_bar=False,
                enforce_translation=enforce_translation,
            )
            gen_list: List[str] = out["sequences"]  # DNA strings
            # Score each generated sequence against its ground truth
            for pos, idx in enumerate(batch_idx):
                tgt_codons = tgt_codons_list[idx]
                gen_codons = _dna_to_codons(gen_list[pos])[:L]
                matches = sum(1 for a, b in zip(gen_codons, tgt_codons) if a == b)
                codon_accs[idx] = (matches / L) if L > 0 else 0.0
                gen_aa = dna_to_aa(''.join(gen_codons))
                tgt_aa = protein_seqs[idx][:L]
                # Treat non-canonical AA in target as "match any"
                aa_matches = sum(1 for a, b in zip(gen_aa, tgt_aa) if (b not in canonical) or (a == b))
                aa_accs[idx] = (aa_matches / L) if L > 0 else 0.0
        # Restore caps
        if bool(no_truncation):
            try:
                setattr(sampler.model, "max_species_prefix", prev_sp)
                setattr(sampler.model, "max_protein_prefix", prev_pp)
            except Exception:
                pass

    return codon_accs, aa_accs
|
| 732 |
+
|
| 733 |
+
|
| 734 |
+
@torch.no_grad()
def generate_and_score_batched(
    sampler: CodonSampler,
    species_names: List[str],
    protein_seqs: List[str],
    target_dnas: List[str],
    temperature: float,
    top_k: int,
    top_p: float,
    control_mode: str,
    batch_size: int,
    enforce_translation: bool,
    no_truncation: bool = False,
    species_prefix_cap: int = 64,
) -> Tuple[List[str], List[float], List[float]]:
    """Like sample_and_score_batched but also returns generated DNA sequences per sample.

    Args:
        sampler: Loaded CodonSampler wrapping model/tokenizer/species store.
        species_names: Taxon name per sample.
        protein_seqs: Target protein per sample (upper-case amino acids).
        target_dnas: Ground-truth CDS DNA per sample.
        temperature, top_k, top_p: Sampling hyper-parameters.
        control_mode: Conditioning mode forwarded to ``sampler.sample``.
        batch_size: Mini-batch size used within each length bucket.
        enforce_translation: Forwarded to ``sampler.sample``.
        no_truncation: When True, temporarily shrink the model's protein-prefix
            cap so the full generation fits ``max_position_embeddings``.
        species_prefix_cap: Cap on species prefix tokens (legacy stores only).

    Returns:
        ``(gen_all, codon_accs, aa_accs)`` — generated DNA string plus codon-
        and AA-level match rates, all in the original sample order.
    """
    N = len(species_names)
    # Target length = min(#codons, #AAs); degenerate inputs fall back to one ATG.
    tgt_lengths = []
    tgt_codons_list = []
    for prot, dna in zip(protein_seqs, target_dnas):
        cods = _dna_to_codons(dna)
        L = min(len(cods), len(prot))
        if L <= 0:
            L = 1
            cods = ["ATG"]
        tgt_lengths.append(L)
        tgt_codons_list.append(cods[:L])

    # Bucket sample indices by target length to maximize batching.
    buckets: dict[int, List[int]] = {}
    for i, L in enumerate(tgt_lengths):
        buckets.setdefault(L, []).append(i)

    gen_all = [""] * N
    codon_accs = [0.0] * N
    aa_accs = [0.0] * N

    # Hoisted loop invariant (was rebuilt once per scored sample): canonical
    # amino-acid alphabet — non-canonical target AAs count as "match any".
    canonical = set("ACDEFGHIKLMNPQRSTVWY")

    # Helper AA translation via the tokenizer's genetic-code table.
    vocab = sampler.tokenizer._genetic_code

    def dna_to_aa(dna: str) -> str:
        dna = dna.strip().upper()
        aa = []
        for j in range(0, len(dna) - (len(dna) % 3), 3):
            aa.append(vocab.get(dna[j:j+3], 'X'))
        return ''.join(aa)

    for L, idxs in buckets.items():
        # Remember current prefix caps so they can be restored after the bucket.
        prev_sp = getattr(sampler.model, "max_species_prefix", 0)
        prev_pp = getattr(sampler.model, "max_protein_prefix", 0)
        if bool(no_truncation):
            # Best-effort probe: on any failure the caps are simply left as-is.
            try:
                capacity = int(getattr(sampler.model, "max_position_embeddings", 1024))
                store = getattr(sampler, "species_store", None)
                if store is not None and getattr(store, "is_legacy", False) and int(species_prefix_cap) > 0:
                    setattr(sampler.model, "max_species_prefix", int(species_prefix_cap))
                # Representative probe batch to measure the exact prefix length.
                batch_idx_probe = idxs[: min(len(idxs), max(1, min(batch_size, 8)))]
                sp_probe = [species_names[i] for i in batch_idx_probe]
                pr_probe = [protein_seqs[i] for i in batch_idx_probe]
                cond_probe = {"control_mode": "fixed", "protein_seqs": pr_probe}
                if store is not None:
                    sid_list = [store.vocab.get(s, -1) for s in sp_probe]
                    res = store.batch_get(sid_list)
                    if isinstance(res, tuple):
                        # Tuple result → per-token species embeddings.
                        sp_tok, _ = res
                        cond_probe["species_tok_emb_src"] = sp_tok.to(sampler.device)
                        cond_probe["species_tok_emb_tgt"] = sp_tok.to(sampler.device)
                    else:
                        # Single tensor → pooled (fixed) species embedding.
                        cond_probe["species_emb_src"] = res.to(sampler.device)
                        cond_probe["species_emb_tgt"] = res.to(sampler.device)
                # Iteratively reduce protein prefix cap until remaining ≥ L.
                for _ in range(3):
                    out0 = sampler.model(
                        codon_ids=torch.zeros(len(batch_idx_probe), 0, dtype=torch.long, device=sampler.device),
                        cond=cond_probe,
                        return_dict=True,
                        use_cache=True,
                    )
                    pref = out0.get("prefix_len")
                    pref_max = int(pref.max().item()) if isinstance(pref, torch.Tensor) and pref.numel() > 0 else (int(pref) if isinstance(pref, int) else 0)
                    remaining = capacity - (pref_max + 1)
                    if remaining >= int(L):
                        break
                    need = int(L) - max(0, int(remaining))
                    cur_pp = int(getattr(sampler.model, "max_protein_prefix", 0) or 0)
                    new_pp = max(0, cur_pp - need) if cur_pp > 0 else max(0, pref_max - (capacity - 1 - int(L)))
                    setattr(sampler.model, "max_protein_prefix", int(new_pp))
            except Exception:
                pass
        # Process this bucket in mini-batches.
        for k in range(0, len(idxs), batch_size):
            batch_idx = idxs[k:k+batch_size]
            sp_b = [species_names[i] for i in batch_idx]
            pr_b = [protein_seqs[i] for i in batch_idx]
            out = sampler.sample(
                num_sequences=len(batch_idx),
                sequence_length=L,
                species=sp_b,
                protein_sequences=pr_b,
                control_mode=control_mode,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                return_intermediate=False,
                progress_bar=False,
                enforce_translation=enforce_translation,
            )
            gen_list: List[str] = out["sequences"]
            # Record generations and score against ground truth.
            for pos, idx in enumerate(batch_idx):
                gen_seq = gen_list[pos]
                gen_all[idx] = gen_seq
                tgt_codons = tgt_codons_list[idx]
                gen_codons = _dna_to_codons(gen_seq)[:L]
                matches = sum(1 for a, b in zip(gen_codons, tgt_codons) if a == b)
                codon_accs[idx] = (matches / L) if L > 0 else 0.0
                gen_aa = dna_to_aa(''.join(gen_codons))
                tgt_aa = protein_seqs[idx][:L]
                aa_matches = sum(1 for a, b in zip(gen_aa, tgt_aa) if (b not in canonical) or (a == b))
                aa_accs[idx] = (aa_matches / L) if L > 0 else 0.0
        # Restore caps touched by the no_truncation probe.
        if bool(no_truncation):
            try:
                setattr(sampler.model, "max_species_prefix", prev_sp)
                setattr(sampler.model, "max_protein_prefix", prev_pp)
            except Exception:
                pass

    return gen_all, codon_accs, aa_accs
|
| 857 |
+
|
| 858 |
+
|
| 859 |
+
def export_per_sequence_over_splits(
    sampler: CodonSampler,
    splits: List[str],
    splits_root: str,
    out_csv: str,
    batch_size: int,
    temperature: float,
    top_k: int,
    top_p: float,
    control_mode: str,
    enforce_translation: bool,
    progress: bool = False,
    max_rows_per_split: int = 0,
    no_truncation: bool = False,
    species_prefix_cap: int = 0,
) -> None:
    """Process ./data/val and ./data/test (or under splits_root) and write a per-sequence CSV.

    Streams each split's Parquet files row-group by row-group, generates
    predictions in mini-batches via ``generate_and_score_batched``, and appends
    one CSV row per sequence (split, organism, protein, target/predicted DNA,
    codon similarity, AA recovery). ``max_rows_per_split`` > 0 caps rows per
    split; 0 means unlimited. Appends to ``out_csv`` if it already has content.
    """
    try:
        import pyarrow.parquet as pq  # type: ignore
    except Exception as e:
        raise ImportError("pyarrow is required for Parquet evaluation/export") from e

    from pathlib import Path as _P
    import os as _os
    total_written = 0
    # Pre-create CSV with header so users can tail it immediately
    header_cols = [
        "split",
        "organism",
        "protein_seq",
        "codon_seq",
        "predicted_seq",
        "codon_similarity",
        "amino_acid_recovery_rate",
    ]
    _P(out_csv).parent.mkdir(parents=True, exist_ok=True)
    if not _P(out_csv).exists() or _os.path.getsize(out_csv) == 0:
        with open(out_csv, "w", newline="") as f:
            f.write(",".join(header_cols) + "\n")
        logging.info(f"Initialized CSV with header → {out_csv}")
    for split in splits:
        # None means "no per-split row cap"; otherwise a countdown budget.
        rows_remaining = int(max_rows_per_split) if int(max_rows_per_split) > 0 else None
        dir_path = Path(splits_root) / split
        files = sorted(str(p) for p in dir_path.glob("*.parquet"))
        if not files:
            logging.warning(f"No parquet files found in {dir_path}, skipping split {split}")
            continue
        logging.info(f"Processing split '{split}' with {len(files)} files ...")
        # tqdm is optional: fall back to a pass-through wrapper when unavailable.
        try:
            from tqdm import tqdm  # type: ignore
            _wrap = (lambda it, **kw: tqdm(it, **kw)) if progress else (lambda it, **kw: it)
        except Exception:
            _wrap = (lambda it, **kw: it)
        stop_split = False
        for fp in _wrap(files, desc=f"{split} files", unit="file"):
            if rows_remaining is not None and rows_remaining <= 0:
                break
            pf = pq.ParquetFile(fp)
            nrg = int(pf.num_row_groups or 0)
            # max(nrg, 1) so a file reporting zero row groups is still read once.
            rgs = list(range(max(nrg, 1)))
            # Build a per-file rows progress bar (prefer total rows from metadata when available)
            rows_total = None
            try:
                if pf.metadata is not None:
                    rows_total = 0
                    for rg_idx in rgs:
                        rg_md = pf.metadata.row_group(rg_idx)
                        if rg_md is not None and rg_md.num_rows is not None:
                            rows_total += int(rg_md.num_rows)
            except Exception:
                rows_total = None
            rows_pbar = None
            if progress:
                try:
                    from tqdm import tqdm  # type: ignore
                    rows_pbar = tqdm(total=rows_total, desc=f"{split}:{Path(fp).name}", unit="rows", leave=False)
                except Exception:
                    rows_pbar = None

            for rg in rgs:
                if rows_remaining is not None and rows_remaining <= 0:
                    stop_split = True
                    break
                table = pf.read_row_group(rg, columns=["Taxon", "protein_seq", "cds_DNA"])
                df = table.to_pandas()
                if df.empty:
                    continue
                species = df["Taxon"].astype(str).tolist()
                proteins = df["protein_seq"].astype(str).str.upper().tolist()
                dnas = df["cds_DNA"].astype(str).str.upper().tolist()

                # Generate predictions and metrics in streaming mini-batches to keep
                # memory stable and update progress frequently
                N = len(species)
                for off in range(0, N, batch_size):
                    if rows_remaining is not None and rows_remaining <= 0:
                        stop_split = True
                        break
                    sp_b = species[off: off + batch_size]
                    pr_b = proteins[off: off + batch_size]
                    dn_b = dnas[off: off + batch_size]
                    gen_list, codon_accs, aa_accs = generate_and_score_batched(
                        sampler,
                        sp_b,
                        pr_b,
                        dn_b,
                        temperature=temperature,
                        top_k=top_k,
                        top_p=top_p,
                        control_mode=control_mode,
                        batch_size=batch_size,
                        enforce_translation=enforce_translation,
                        no_truncation=bool(no_truncation),
                        species_prefix_cap=int(species_prefix_cap),
                    )
                    rows_batch: List[dict] = []
                    for sp, pr, dn, gen, cacc, aacc in zip(sp_b, pr_b, dn_b, gen_list, codon_accs, aa_accs):
                        # Target DNA trimmed to min(protein, codon) length, same
                        # comparison window the scorer used.
                        L = min(len(pr), len(dn) // 3)
                        tgt_dna = dn[: 3 * L]
                        rows_batch.append({
                            "split": split,
                            "organism": sp,
                            "protein_seq": pr,
                            "codon_seq": tgt_dna,
                            "predicted_seq": gen,
                            "codon_similarity": float(cacc),
                            "amino_acid_recovery_rate": float(aacc),
                        })
                    if rows_batch:
                        # Respect the per-split row budget before writing.
                        if rows_remaining is not None and len(rows_batch) > rows_remaining:
                            rows_batch = rows_batch[: rows_remaining]
                        out_exists = _P(out_csv).exists() and _os.path.getsize(out_csv) > 0
                        df_out = pd.DataFrame(rows_batch)
                        _P(out_csv).parent.mkdir(parents=True, exist_ok=True)
                        df_out.to_csv(out_csv, mode='a', header=not out_exists, index=False)
                        total_written += len(rows_batch)
                        if rows_remaining is not None:
                            rows_remaining -= len(rows_batch)
                        if rows_pbar is not None:
                            try:
                                rows_pbar.update(len(rows_batch))
                            except Exception:
                                pass
                    if rows_remaining is not None and rows_remaining <= 0:
                        stop_split = True
                        break
            if rows_pbar is not None:
                try:
                    rows_pbar.close()
                except Exception:
                    pass
            if stop_split:
                break
    logging.info(f"Per-sequence export complete → {out_csv} (rows={total_written})")
|
| 1013 |
+
|
| 1014 |
+
|
| 1015 |
+
def main():
    """CLI entry point: load model/species store, then run one of three modes.

    Modes (chosen from parsed args):
      1. ``--export_per_sequence`` → per-sequence CSV export over data splits.
      2. Teacher-forced evaluation (default) on a random subset, or the full
         dataset when ``--eval_all`` is set.
      3. ``--free_run`` → free-run sampling scored against ground-truth DNA.
    """
    args = parse_args()
    random.seed(args.seed)
    torch.manual_seed(args.seed)

    model_dir = Path(args.model_path)
    pooling = _preferred_pooling(model_dir)
    logger.info(f"Preferred species_pooling from checkpoint: {pooling}")

    # Set up species store (recommended for parity)
    species_store = None
    if args.embeddings_dir:
        emb_dir = Path(args.embeddings_dir)
        # Embeddings-dir format wins over the checkpoint's preference when they disagree.
        detected = _detect_pooling_from_embeddings_dir(emb_dir)
        if detected is not None and detected != pooling:
            logger.info(f"Overriding pooling from checkpoint ({pooling}) → embeddings_dir format ({detected})")
            pooling = detected
        species_store = SpeciesEmbeddingStore(args.embeddings_dir, pooling=pooling)
        logger.info(f"Loaded species store with {len(species_store.vocab)} species (pooling={pooling})")

    # Load sampler/model (uses same construction as sampling)
    sampler = CodonSampler(
        model_path=args.model_path,
        device=("cuda" if args.device == "cuda" and torch.cuda.is_available() else "cpu"),
        species_store=species_store,
    )

    # Load input data and sample rows
    if bool(args.export_per_sequence):
        # Mode 1: CSV export — runs and returns without touching --data_path.
        export_per_sequence_over_splits(
            sampler,
            splits=list(args.export_splits),
            splits_root=str(args.splits_root),
            out_csv=str(args.out_csv),
            batch_size=int(args.batch_size),
            temperature=float(args.temperature),
            top_k=int(args.top_k),
            top_p=float(args.top_p),
            control_mode=str(args.control_mode),
            enforce_translation=bool(args.enforce_translation),
            progress=bool(args.progress),
            max_rows_per_split=int(args.max_rows_per_split),
            no_truncation=bool(args.no_truncation),
            species_prefix_cap=int(args.species_prefix_cap),
        )
        return

    data_path = args.data_path or args.csv_path
    if data_path is None:
        raise SystemExit("Please provide --data_path (CSV or Parquet glob/dir). --csv_path remains as a deprecated alias.")

    # Expand paths to decide CSV vs Parquet
    paths = _expand_paths(data_path)
    if not paths:
        raise FileNotFoundError(f"No input files matched: {data_path}")

    if all(_is_parquet_path(p) for p in paths):
        logger.info(f"Reading up to {args.num_samples} samples from {len(paths)} parquet files ...")
        df_s = _load_random_samples_from_parquet(paths, int(args.num_samples), int(args.seed))
    else:
        # Fallback to CSV/TSV single file behavior (back-compat). If multiple files match, use the first.
        csv_file = None
        for pth in paths:
            if pth.lower().endswith((".csv", ".tsv", ".csv.gz", ".tsv.gz")):
                csv_file = pth
                break
        if csv_file is None:
            raise ValueError(f"Unsupported input for --data_path: {paths[0]}")
        logger.info(f"Reading CSV file: {csv_file}")
        df = pd.read_csv(csv_file)
        required = {"Taxon", "protein_seq", "cds_DNA"}
        if not required.issubset(set(df.columns)):
            missing = required - set(df.columns)
            raise ValueError(f"CSV missing required columns: {sorted(missing)}")
        if args.num_samples > len(df):
            logger.warning(f"num_samples {args.num_samples} > CSV rows {len(df)}; reducing")
            args.num_samples = len(df)
        # Random sample without replacement
        indices = random.sample(range(len(df)), args.num_samples)
        df_s = df.iloc[indices].reset_index(drop=True)

    if len(df_s) == 0:
        raise ValueError("No samples loaded from the provided data_path")

    logger.info(f"Loaded {len(df_s)} samples for evaluation")

    species = df_s["Taxon"].astype(str).tolist()
    proteins = df_s["protein_seq"].astype(str).str.upper().tolist()
    dnas = df_s["cds_DNA"].astype(str).str.upper().tolist()

    if not args.free_run:
        if bool(args.eval_all):
            if not args.embeddings_dir:
                raise SystemExit("--eval_all requires --embeddings_dir for species vocab/embeddings")
            # Stream the entire dataset and compute dataset-level metrics (training-parity)
            eval_streaming_all(
                sampler,
                species_store if species_store is not None else SpeciesEmbeddingStore(args.embeddings_dir, pooling=pooling),
                data_path,
                batch_size=int(args.batch_size),
                num_workers=int(args.workers),
                max_records=int(args.max_records),
            )
            return
        # Optional: print per-sample CDS→AA agreement (standard code)
        if bool(args.debug_aa_check):
            for idx, (sp, pr, dn) in enumerate(zip(species, proteins, dnas), start=1):
                ratio, Lcmp, first_bad = _aa_agreement(dn, pr, sampler.tokenizer)
                flag = "OK" if ratio == 1.0 and Lcmp > 0 else ("EMPTY" if Lcmp == 0 else "MISMATCH")
                extra = f" first_mismatch={first_bad}" if first_bad >= 0 else ""
                logger.info(f"AA-CHECK Sample {idx:02d}: {flag} match={ratio:.3f} len={Lcmp}{extra} Taxon={sp}")
        # (No dataset-level filtering to keep evaluation simple.)
        # Teacher-forced evaluation (random subset)
        per_ce_all: List[float] = []
        per_aa_acc_all: List[float] = []
        per_codon_acc_all: List[float] = []
        bs = max(1, int(args.batch_size))
        for i in range(0, len(species), bs):
            sp_b = species[i:i+bs]
            pr_b = proteins[i:i+bs]
            dn_b = dnas[i:i+bs]
            ce, aa_acc = eval_batch(sampler, species_store, sp_b, pr_b, dn_b)
            # Also compute per-sample codon-acc using the same batch forward for consistency
            # Re-run lightweight preds for codon-acc is unnecessary because eval_batch already
            # computed supervised mask and preds internally; instead, recompute quickly here
            # by calling eval_batch and deriving codon-acc inside it. For simplicity and clarity
            # we re-derive codon-acc below using the same masking rules.
            per_ce_all.extend(ce)
            per_aa_acc_all.extend(aa_acc)

            # Derive codon-acc for this batch
            # Prepare a mirrored forward to access logits and masks (small overhead acceptable)
            tok = sampler.tokenizer
            pad_id = tok.pad_token_id
            eos_id = tok.eos_token_id
            codon_ids_local = []
            for dna, prot in zip(dn_b, pr_b):
                # Trim to min(codon, protein) length, then append EOS.
                C_dna = len(dna) // 3
                C_prot = len(prot)
                C = max(min(C_dna, C_prot), 1)
                dna_trim = dna[: 3 * C]
                ids = tok.encode_codon_seq(dna_trim, validate=False)
                ids.append(eos_id)
                codon_ids_local.append(ids)
            B_b = len(codon_ids_local)
            T_b = max(len(x) for x in codon_ids_local)
            # Right-pad each row to the batch max length.
            codons_b = torch.full((B_b, T_b), pad_id, dtype=torch.long)
            mask_b = torch.zeros((B_b, T_b), dtype=torch.bool)
            for j, ids in enumerate(codon_ids_local):
                Lb = len(ids)
                codons_b[j, :Lb] = torch.tensor(ids, dtype=torch.long)
                mask_b[j, :Lb] = True
            input_ids_b = codons_b[:, :-1].to(sampler.device)
            labels_b = codons_b[:, :-1].clone()
            # -100 = ignored position (pad and EOS are never supervised).
            labels_b[labels_b == pad_id] = -100
            labels_b[labels_b == eos_id] = -100
            cond_b = {"control_mode": "fixed"}
            if species_store is not None and sp_b:
                sids_b = [species_store.vocab.get(s, -1) for s in sp_b]
                res_b = species_store.batch_get(sids_b)
                if isinstance(res_b, tuple):
                    # Tuple result → per-token species embeddings.
                    sp_tok_b, _ = res_b
                    cond_b["species_tok_emb_src"] = sp_tok_b.to(sampler.device)
                    cond_b["species_tok_emb_tgt"] = sp_tok_b.to(sampler.device)
                else:
                    # Single tensor → pooled (fixed) species embedding.
                    sp_fix_b = res_b
                    cond_b["species_emb_src"] = sp_fix_b.to(sampler.device)
                    cond_b["species_emb_tgt"] = sp_fix_b.to(sampler.device)
            cond_b["protein_seqs"] = pr_b
            out_b = sampler.model(codon_ids=input_ids_b, cond=cond_b, labels=labels_b.to(sampler.device), return_dict=True)
            logits_b = out_b["logits"]
            per_cap_b = out_b.get("per_cap")
            if logits_b is not None and per_cap_b is not None:
                # Align labels to the logits length, mask beyond each row's cap,
                # and score argmax predictions on supervised positions only.
                Bsz, Lmax, V = logits_b.size(0), logits_b.size(1), logits_b.size(2)
                labels_aligned_b = torch.full((Bsz, Lmax), -100, dtype=labels_b.dtype, device=logits_b.device)
                common_cols_b = min(labels_b.size(1), Lmax)
                if common_cols_b > 0:
                    labels_aligned_b[:, :common_cols_b] = labels_b.to(logits_b.device)[:, :common_cols_b]
                ar = torch.arange(Lmax, device=logits_b.device).unsqueeze(0)
                cap_mask_b = ar < per_cap_b.to(device=logits_b.device).unsqueeze(1)
                labels_masked_b = labels_aligned_b.clone()
                labels_masked_b[~cap_mask_b] = -100
                preds_b = logits_b.argmax(dim=-1)
                num_special = int(getattr(tok, "num_special_tokens", 0) or 0)
                supervised_b = (labels_masked_b != -100) & cap_mask_b
                if num_special > 0:
                    # Exclude special-token labels from the accuracy denominator.
                    supervised_b = supervised_b & (labels_aligned_b >= num_special)
                for r in range(Bsz):
                    denom = int(supervised_b[r].sum().item())
                    cod_acc = (float((preds_b[r][supervised_b[r]] == labels_aligned_b[r][supervised_b[r]]).sum().item()) / denom) if denom > 0 else 0.0
                    per_codon_acc_all.append(cod_acc)

    for idx, (ce, aa, ca) in enumerate(zip(per_ce_all, per_aa_acc_all, per_codon_acc_all), start=1):
            logger.info(f"Sample {idx:02d}: CE={ce:.4f} CODON-acc={ca:.4f} AA-acc={aa:.4f}")
        if per_ce_all:
            mean_ce = sum(per_ce_all) / len(per_ce_all)
            mean_aa = sum(per_aa_acc_all) / len(per_aa_acc_all) if per_aa_acc_all else 0.0
            mean_codon = sum(per_codon_acc_all) / len(per_codon_acc_all) if per_codon_acc_all else 0.0
            logger.info(f"Summary over {len(per_ce_all)} samples → mean CE={mean_ce:.4f}, mean CODON-acc={mean_codon:.4f}, mean AA-acc={mean_aa:.4f}")
    else:
        # Free-run sampling evaluation vs ground-truth DNA (codon-level), batched
        codon_accs, aa_accs = sample_and_score_batched(
            sampler,
            species,
            proteins,
            dnas,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            control_mode=args.control_mode,
            batch_size=int(args.batch_size),
            enforce_translation=bool(args.enforce_translation),
            no_truncation=bool(args.no_truncation),
            species_prefix_cap=int(args.species_prefix_cap),
        )
        for idx, (cacc, aacc) in enumerate(zip(codon_accs, aa_accs), start=1):
            logger.info(f"Sample {idx:02d}: CODON-acc={cacc:.4f} AA-acc={aacc:.4f}")
        if codon_accs:
            mean_c = sum(codon_accs) / len(codon_accs)
            mean_a = sum(aa_accs) / len(aa_accs)
            logger.info(f"Summary over {len(codon_accs)} samples → mean CODON-acc={mean_c:.4f}, mean AA-acc={mean_a:.4f}")
|
| 1236 |
+
|
| 1237 |
+
|
| 1238 |
+
# Script entry point guard: run the evaluation CLI only when executed directly.
if __name__ == "__main__":
    main()
|
final_model/config.json
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"max_length": 2048,
|
| 3 |
+
"max_species_prefix": 0,
|
| 4 |
+
"max_protein_prefix": 1024,
|
| 5 |
+
"hidden_size": 750,
|
| 6 |
+
"num_hidden_layers": 20,
|
| 7 |
+
"num_attention_heads": 15,
|
| 8 |
+
"mlp_ratio": 3.2,
|
| 9 |
+
"prepend_species": true,
|
| 10 |
+
"prepend_protein": true,
|
| 11 |
+
"species_embedding_dim": 1024,
|
| 12 |
+
"esm_model_name": "esmc_300m",
|
| 13 |
+
"esm_device": "cuda:0",
|
| 14 |
+
"esm_dtype": "bf16",
|
| 15 |
+
"attn_impl": "mha",
|
| 16 |
+
"num_kv_groups": 5
|
| 17 |
+
}
|
final_model/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5af6fe27a93e8a5edf622131b8fff74240f90db036a95697cfe4f28af1d23ef9
|
| 3 |
+
size 1284544520
|
final_model/trainer_config.json
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"max_length": 2048,
|
| 3 |
+
"max_species_prefix": 0,
|
| 4 |
+
"max_protein_prefix": 1024,
|
| 5 |
+
"hidden_size": 750,
|
| 6 |
+
"num_hidden_layers": 20,
|
| 7 |
+
"num_attention_heads": 15,
|
| 8 |
+
"mlp_ratio": 3.2,
|
| 9 |
+
"prepend_species": true,
|
| 10 |
+
"prepend_protein": true,
|
| 11 |
+
"species_embedding_dim": 1024,
|
| 12 |
+
"esm_model_name": "esmc_300m",
|
| 13 |
+
"esm_device": "cuda:0",
|
| 14 |
+
"esm_dtype": "bf16",
|
| 15 |
+
"attn_impl": "mha",
|
| 16 |
+
"num_kv_groups": 5
|
| 17 |
+
}
|
final_model/trainer_state.json
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"epoch": 2,
|
| 3 |
+
"global_step": 120513
|
| 4 |
+
}
|
final_model/vocab.json
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"special_token_str": {
|
| 3 |
+
"bos": "<bos>",
|
| 4 |
+
"eos": "<stop>",
|
| 5 |
+
"pad": "<pad>",
|
| 6 |
+
"unk": "<unk>"
|
| 7 |
+
},
|
| 8 |
+
"vocab": {
|
| 9 |
+
"<bos>": 2,
|
| 10 |
+
"<pad>": 0,
|
| 11 |
+
"<stop>": 3,
|
| 12 |
+
"<unk>": 1,
|
| 13 |
+
"AAA": 4,
|
| 14 |
+
"AAC": 5,
|
| 15 |
+
"AAG": 6,
|
| 16 |
+
"AAT": 7,
|
| 17 |
+
"ACA": 8,
|
| 18 |
+
"ACC": 9,
|
| 19 |
+
"ACG": 10,
|
| 20 |
+
"ACT": 11,
|
| 21 |
+
"AGA": 12,
|
| 22 |
+
"AGC": 13,
|
| 23 |
+
"AGG": 14,
|
| 24 |
+
"AGT": 15,
|
| 25 |
+
"ATA": 16,
|
| 26 |
+
"ATC": 17,
|
| 27 |
+
"ATG": 18,
|
| 28 |
+
"ATT": 19,
|
| 29 |
+
"CAA": 20,
|
| 30 |
+
"CAC": 21,
|
| 31 |
+
"CAG": 22,
|
| 32 |
+
"CAT": 23,
|
| 33 |
+
"CCA": 24,
|
| 34 |
+
"CCC": 25,
|
| 35 |
+
"CCG": 26,
|
| 36 |
+
"CCT": 27,
|
| 37 |
+
"CGA": 28,
|
| 38 |
+
"CGC": 29,
|
| 39 |
+
"CGG": 30,
|
| 40 |
+
"CGT": 31,
|
| 41 |
+
"CTA": 32,
|
| 42 |
+
"CTC": 33,
|
| 43 |
+
"CTG": 34,
|
| 44 |
+
"CTT": 35,
|
| 45 |
+
"GAA": 36,
|
| 46 |
+
"GAC": 37,
|
| 47 |
+
"GAG": 38,
|
| 48 |
+
"GAT": 39,
|
| 49 |
+
"GCA": 40,
|
| 50 |
+
"GCC": 41,
|
| 51 |
+
"GCG": 42,
|
| 52 |
+
"GCT": 43,
|
| 53 |
+
"GGA": 44,
|
| 54 |
+
"GGC": 45,
|
| 55 |
+
"GGG": 46,
|
| 56 |
+
"GGT": 47,
|
| 57 |
+
"GTA": 48,
|
| 58 |
+
"GTC": 49,
|
| 59 |
+
"GTG": 50,
|
| 60 |
+
"GTT": 51,
|
| 61 |
+
"TAA": 52,
|
| 62 |
+
"TAC": 53,
|
| 63 |
+
"TAG": 54,
|
| 64 |
+
"TAT": 55,
|
| 65 |
+
"TCA": 56,
|
| 66 |
+
"TCC": 57,
|
| 67 |
+
"TCG": 58,
|
| 68 |
+
"TCT": 59,
|
| 69 |
+
"TGA": 60,
|
| 70 |
+
"TGC": 61,
|
| 71 |
+
"TGG": 62,
|
| 72 |
+
"TGT": 63,
|
| 73 |
+
"TTA": 64,
|
| 74 |
+
"TTC": 65,
|
| 75 |
+
"TTG": 66,
|
| 76 |
+
"TTT": 67
|
| 77 |
+
}
|
| 78 |
+
}
|
precompute_embeddings.py
ADDED
|
@@ -0,0 +1,503 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
"""
|
| 3 |
+
Precompute species embeddings for CodonTranslator training.
|
| 4 |
+
Protein embeddings are now computed on-the-fly using integrated ESM-C model.
|
| 5 |
+
|
| 6 |
+
Steps:
|
| 7 |
+
1. Build taxonomy database from GBIF API
|
| 8 |
+
2. Generate species embeddings using Qwen3-Embedding-0.6B
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
import json
|
| 13 |
+
import logging
|
| 14 |
+
import argparse
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from typing import Dict, List, Optional, Tuple
|
| 17 |
+
import glob
|
| 18 |
+
import requests
|
| 19 |
+
import time
|
| 20 |
+
from collections import defaultdict
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
import pandas as pd
|
| 24 |
+
import torch
|
| 25 |
+
from tqdm import tqdm
|
| 26 |
+
|
| 27 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 28 |
+
logger = logging.getLogger(__name__)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def build_taxonomy_database(species_list: List[str]) -> Dict[str, str]:
    """Query the GBIF species-match API for phylogenetic taxonomy of species.

    Builds one human-readable taxonomic description per species; downstream
    code feeds these strings to the species-embedding model.

    Args:
        species_list: Species names (e.g. "Homo sapiens"). Empty entries and
            duplicates are skipped.

    Returns:
        Mapping from species name to description. Species with no GBIF match,
        a non-200 response, or a request error get a fallback string that
        still contains the species name, so every input species has text.
    """
    taxonomy_db: Dict[str, str] = {}
    base_url = "https://api.gbif.org/v1/species/match"

    logger.info(f"Building taxonomy database for {len(species_list)} species...")
    for species in tqdm(species_list, desc="Querying GBIF"):
        if not species or species in taxonomy_db:
            continue

        try:
            # An explicit timeout prevents a single stalled request from
            # hanging the whole precompute job (requests defaults to no
            # timeout at all).
            response = requests.get(base_url, params={"name": species}, timeout=30)
            if response.status_code == 200:
                data = response.json()
                if data.get("matchType") != "NONE":
                    # Build comprehensive taxonomy description
                    parts = []

                    # Add scientific classification, most general rank first.
                    taxonomy = []
                    for rank in ["kingdom", "phylum", "class", "order", "family", "genus", "species"]:
                        if rank in data and data[rank]:
                            taxonomy.append(data[rank])

                    if taxonomy:
                        parts.append("Taxonomy: " + " > ".join(taxonomy))

                    # Add common name if available
                    if "vernacularName" in data and data["vernacularName"]:
                        parts.append(f"Common name: {data['vernacularName']}")

                    # Add confidence score
                    if "confidence" in data:
                        parts.append(f"Match confidence: {data['confidence']}%")

                    # Add status (accepted, synonym, etc.)
                    if "status" in data:
                        parts.append(f"Status: {data['status']}")

                    # Combine all parts into comprehensive description
                    taxonomy_db[species] = ". ".join(parts) if parts else species
                else:
                    # No match found - use species name with indicator
                    taxonomy_db[species] = f"Species: {species} (no GBIF match)"
            else:
                taxonomy_db[species] = f"Species: {species} (query failed)"

            # Rate limiting: be polite to the public GBIF API.
            time.sleep(0.1)
        except Exception as e:
            # Best-effort by design: a failed lookup must not abort the run,
            # so record a fallback string and continue with the next species.
            logger.warning(f"Error querying GBIF for {species}: {e}")
            taxonomy_db[species] = f"Species: {species} (error)"

    logger.info(f"Taxonomy database built with {len(taxonomy_db)} entries")
    return taxonomy_db
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def generate_species_embeddings_qwen(
    species_list: List[str],
    taxonomy_db: Dict[str, str],
    device: str = "cuda",
    pooling: str = "last"  # 'last' -> single vector; 'sequence'/'none' -> variable-length tokens
) -> Tuple[Dict[str, int], Dict[int, np.ndarray]]:
    """
    Generate species embeddings using Qwen3-Embedding-0.6B.
    - pooling='last': returns one vector per species (fixed size)
    - pooling='none': returns variable-length token embeddings per species

    Args:
        species_list: Species names to embed, one embedding per entry.
        taxonomy_db: Species name -> taxonomy description text; missing keys
            fall back to the species name itself.
        device: Torch device string for tokenized inputs and the model.
        pooling: 'last' pools the last valid token; any other value keeps
            all token embeddings (normalized per token).

    Returns:
        (species_vocab, species_embeddings) where species_vocab maps species
        name -> integer id (its position in species_list) and
        species_embeddings maps that id -> numpy array ([D] for 'last',
        [L, D] otherwise).

    NOTE(review): if species_list contains duplicates, species_vocab keeps
    only the last index per name while species_embeddings keeps every index,
    so the two dicts diverge — callers pass de-duplicated lists; confirm.
    """
    import torch.nn.functional as F
    from transformers import AutoTokenizer, AutoModel

    def last_token_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Pool by taking the last valid token's embedding."""
        # With left padding every sequence ends at position -1; otherwise the
        # last valid position is derived from each row's attention-mask sum.
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            return last_hidden_states[:, -1]
        else:
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_states.shape[0]
            return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]

    def get_detailed_instruct(task_description: str, query: str) -> str:
        """Format the input with instruction for better embedding quality."""
        return f'Instruct: {task_description}\nQuery: {query}'

    logger.info("Loading Qwen3-Embedding-0.6B model...")
    model_name = "Qwen/Qwen3-Embedding-0.6B"

    # Initialize with left padding for last token pooling
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, padding_side='left')
    # fp16 inference; model is frozen (eval mode) — gradients are never needed.
    model = AutoModel.from_pretrained(model_name, torch_dtype=torch.float16, trust_remote_code=True).to(device).eval()

    species_vocab: Dict[str, int] = {}
    species_embeddings: Dict[int, np.ndarray] = {}

    # Task description for species embedding
    task = "Given a species taxonomy information, generate a biological embedding representing its taxonomic and evolutionary characteristics"

    for idx, species in enumerate(tqdm(species_list, desc="Generating embeddings")):
        # Get comprehensive taxonomy string from GBIF query results
        taxonomy_str = taxonomy_db.get(species, species)

        # Format with instruction for better semantic understanding
        input_text = get_detailed_instruct(task, taxonomy_str)

        # Generate embeddings (one species per forward pass)
        with torch.no_grad():
            inputs = tokenizer(
                input_text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            )
            inputs = {k: v.to(device) for k, v in inputs.items()}

            outputs = model(**inputs)
            hidden = outputs.last_hidden_state  # [1, L, D]
            if pooling == 'last':
                pooled_embedding = last_token_pool(hidden, inputs['attention_mask'])
                # L2-normalize so embeddings live on the unit sphere.
                normalized_embedding = F.normalize(pooled_embedding, p=2, dim=1)
                species_embedding = normalized_embedding.squeeze(0).cpu().numpy()  # [D]
            else:
                # Variable-length token embeddings (normalize per token)
                tok = hidden.squeeze(0)  # [L, D]
                tok = F.normalize(tok, p=2, dim=-1)
                species_embedding = tok.cpu().numpy()  # [L, D]

        species_vocab[species] = idx
        species_embeddings[idx] = species_embedding

    logger.info(f"Generated {'fixed-size' if pooling=='last' else 'variable-length'} embeddings for {len(species_vocab)} species")
    return species_vocab, species_embeddings
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def save_species_embeddings_memmap(
    species_vocab: Dict[str, int],
    species_embeddings: Dict[int, np.ndarray],
    output_dir: str
) -> None:
    """Save fixed-size species embeddings as a memory-mapped file.

    Writes three artifacts into ``output_dir``:
      - species_vocab.json: species name -> row id
      - species_embeddings.bin: float32 memmap of shape (num_species, dim)
      - species_metadata.json: shape/provenance metadata

    Args:
        species_vocab: Species name -> integer id (row index in the memmap).
        species_embeddings: Integer id -> 1-D embedding vector; all vectors
            must share the same dimension.
        output_dir: Destination directory, created if missing.

    Raises:
        ValueError: If ``species_embeddings`` is empty (there would be no
            dimension to infer and nothing to write).
    """
    if not species_embeddings:
        # next(iter(...)) below would otherwise raise a bare StopIteration.
        raise ValueError("species_embeddings is empty; nothing to save")

    os.makedirs(output_dir, exist_ok=True)

    # Save vocabulary
    vocab_path = os.path.join(output_dir, "species_vocab.json")
    with open(vocab_path, 'w') as f:
        json.dump(species_vocab, f, indent=2)

    # All embeddings should have the same dimension now
    num_species = len(species_embeddings)
    embed_dim = next(iter(species_embeddings.values())).shape[0]  # Should be 1024

    # Create memmap for fixed-size embeddings
    emb_path = os.path.join(output_dir, "species_embeddings.bin")
    mmap = np.memmap(emb_path, dtype=np.float32, mode='w+', shape=(num_species, embed_dim))

    # Store embeddings directly by ID (id doubles as the row index)
    for species_id, emb in species_embeddings.items():
        mmap[species_id] = emb.astype(np.float32)

    # Explicitly flush dirty pages to disk; relying only on `del` / garbage
    # collection to write the data out is implicit and fragile.
    mmap.flush()
    del mmap

    # Save metadata
    metadata = {
        "num_species": num_species,
        "embedding_dim": embed_dim,
        "embedding_type": "fixed_size",
        "pooling_method": "last_token",
        "normalization": "L2",
        "model": "Qwen/Qwen3-Embedding-0.6B"
    }

    metadata_path = os.path.join(output_dir, "species_metadata.json")
    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=2)

    logger.info(f"Saved {num_species} fixed-size species embeddings to {emb_path}")
    logger.info(f"Embedding dimension: {embed_dim}")
    logger.info(f"Saved metadata to {metadata_path}")
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def save_species_token_embeddings_memmap(
    species_vocab: Dict[str, int],
    species_tok_embeddings: Dict[int, np.ndarray],
    output_dir: str,
    dtype: str = 'float32'
) -> None:
    """Save variable-length token embeddings into a flat memmap with an index.

    Token rows for every species are concatenated into one flat array; a
    JSON index records each species' (offset, length) slice.

    Writes into ``output_dir``:
      - species_vocab.json: species name -> integer id
      - species_tok_emb.bin: flat memmap of shape (total_tokens, dim)
      - species_index.json: str(id) -> {"offset": int, "length": int}
      - metadata.json: dim/dtype/provenance metadata

    Args:
        species_vocab: Species name -> integer id.
        species_tok_embeddings: Integer id -> [L, D] token-embedding array.
        output_dir: Destination directory, created if missing.
        dtype: On-disk dtype, 'float32' or anything else for float16.

    Raises:
        ValueError: If ``species_tok_embeddings`` is empty.
    """
    if not species_tok_embeddings:
        # next(iter(...)) below would otherwise raise a bare StopIteration.
        raise ValueError("species_tok_embeddings is empty; nothing to save")

    os.makedirs(output_dir, exist_ok=True)

    # Save vocabulary
    vocab_path = os.path.join(output_dir, "species_vocab.json")
    with open(vocab_path, 'w') as f:
        json.dump(species_vocab, f, indent=2)

    # Resolve the on-disk dtype once instead of re-evaluating the ternary
    # for every species inside the write loop.
    np_dtype = np.float32 if dtype == 'float32' else np.float16

    # Compute totals and dims
    embed_dim = next(iter(species_tok_embeddings.values())).shape[1]
    total_tokens = int(sum(v.shape[0] for v in species_tok_embeddings.values()))

    emb_path = os.path.join(output_dir, "species_tok_emb.bin")
    mmap = np.memmap(emb_path, dtype=np_dtype, mode='w+', shape=(total_tokens, embed_dim))

    # Build index while packing each species' rows contiguously.
    index = {}
    offset = 0
    for sid, arr in species_tok_embeddings.items():
        L = int(arr.shape[0])
        mmap[offset: offset + L] = arr.astype(np_dtype)
        index[str(sid)] = {"offset": offset, "length": L}
        offset += L

    # Explicitly flush dirty pages before dropping the reference; relying on
    # `del` alone to persist the data is implicit and fragile.
    mmap.flush()
    del mmap

    with open(os.path.join(output_dir, "species_index.json"), 'w') as f:
        json.dump(index, f, indent=2)

    meta = {
        "embedding_dim": embed_dim,
        "dtype": dtype,
        "total_tokens": total_tokens,
        "embedding_type": "variable_length",
        "pooling_method": "none",
        "model": "Qwen/Qwen3-Embedding-0.6B"
    }
    with open(os.path.join(output_dir, "metadata.json"), 'w') as f:
        json.dump(meta, f, indent=2)
    logger.info(f"Saved variable-length species token embeddings to {emb_path} with {total_tokens} tokens total")
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def filter_sequences_by_length(df: pd.DataFrame, max_protein_length: int = 2048) -> pd.DataFrame:
    """Drop sequences that are too long to fit in GPU memory during training.

    Rows are removed when 'protein_seq' exceeds ``max_protein_length`` or
    'cds_DNA' exceeds three times that (3 nucleotides per residue). Columns
    that are absent are simply ignored.
    """
    initial_count = len(df)

    # Cap protein length when the column exists.
    if 'protein_seq' in df.columns:
        keep_protein = df['protein_seq'].str.len() <= max_protein_length
        df = df[keep_protein]

    # Cap coding-sequence length (3 bases per amino acid) when present.
    if 'cds_DNA' in df.columns:
        keep_cds = df['cds_DNA'].str.len() <= max_protein_length * 3
        df = df[keep_cds]

    final_count = len(df)
    if final_count < initial_count:
        logger.info(f"Filtered from {initial_count} to {final_count} sequences (max_protein_length={max_protein_length})")

    return df
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def collect_unique_values_from_shards(
    shards_glob: str,
    column: str,
    max_items: Optional[int] = None
) -> List[str]:
    """Stream over Parquet shards and gather the unique values of one column.

    Scans shard files matching ``shards_glob``, resolving the column name
    case-insensitively against each shard's schema, and returns the sorted
    unique values (truncated to ``max_items`` when given).

    Raises:
        ValueError: If no files match ``shards_glob``.
    """
    shard_files = sorted(glob.glob(shards_glob))
    if not shard_files:
        raise ValueError(f"No parquet files found matching {shards_glob}")

    logger.info(f"Scanning {len(shard_files)} shards for unique {column} values...")

    seen = set()
    for shard_file in tqdm(shard_files, desc=f"Collecting {column}"):
        # Some datasets use different casing (e.g., 'taxon' vs 'Taxon');
        # resolve against the shard's actual schema when pyarrow is usable.
        try:
            import pyarrow.parquet as pq  # type: ignore
            schema_names = set(pq.ParquetFile(shard_file).schema.names)
            if column in schema_names:
                resolved = column
            else:
                resolved = {n.lower(): n for n in schema_names}.get(column.lower(), column)
        except Exception:
            resolved = column

        frame = pd.read_parquet(shard_file, columns=[resolved])
        # Canonicalize to the requested column name for downstream logic.
        if resolved != column and resolved in frame.columns and column not in frame.columns:
            frame = frame.rename(columns={resolved: column})
        seen.update(frame[column].dropna().unique())

        # Stop scanning once enough distinct values were collected.
        if max_items and len(seen) >= max_items:
            break

    result = sorted(seen)
    if max_items:
        result = result[:max_items]
    logger.info(f"Collected {len(result)} unique {column} values")
    return result
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def collect_stage1_species(shards_glob: str) -> List[str]:
    """Return the unique species ('Taxon' column) found in Stage-1 shards."""
    return collect_unique_values_from_shards(shards_glob, column="Taxon")
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def prepare_species_from_stage1_shards(
    shards_glob: str,
    output_dir: str,
    device: str = "cuda",
    resume: bool = False,
    species_pooling: str = "last"
) -> None:
    """End-to-end species embedding generation from Stage-1 shards.

    Collects unique species from the shards, looks up their taxonomy via
    GBIF (cached on disk), embeds them with Qwen, and saves the result in
    the layout matching ``species_pooling``.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Nothing to do when resuming and the vocabulary is already on disk.
    vocab_path = os.path.join(output_dir, "species_vocab.json")
    if resume and os.path.exists(vocab_path):
        logger.info("Species embeddings already exist. Skipping generation.")
        return

    species_list = collect_stage1_species(shards_glob)
    logger.info(f"Found {len(species_list)} unique species in shards")

    # Reuse the cached GBIF taxonomy lookup when resuming; otherwise query
    # the API and cache the result for subsequent runs.
    taxonomy_cache_path = os.path.join(output_dir, "taxonomy_database.json")
    if resume and os.path.exists(taxonomy_cache_path):
        logger.info("Loading cached taxonomy database...")
        with open(taxonomy_cache_path, 'r') as fh:
            taxonomy_db = json.load(fh)
    else:
        taxonomy_db = build_taxonomy_database(species_list)
        with open(taxonomy_cache_path, 'w') as fh:
            json.dump(taxonomy_db, fh, indent=2)

    species_vocab, species_embeddings = generate_species_embeddings_qwen(
        species_list, taxonomy_db, device, pooling=species_pooling
    )

    # Fixed-size vectors and variable-length token embeddings use
    # different on-disk layouts.
    if species_pooling == 'last':
        save_species_embeddings_memmap(species_vocab, species_embeddings, output_dir)
    else:
        save_species_token_embeddings_memmap(species_vocab, species_embeddings, output_dir)

    logger.info("Species embedding preparation complete")
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
def create_precomputed_dataset(
    input_csv: Optional[str],
    output_dir: str,
    device: str = "cuda",
    batch_size: int = 50,
    max_protein_length: int = 2048,
    resume: bool = False,
    species_pooling: str = "last"
):
    """
    Create embedding dataset with species-only precomputation.
    Protein embeddings will be computed on-the-fly during training.

    Args:
        input_csv: Path to a CSV or Parquet file with a 'Taxon' (or 'taxon')
            species column. Typed Optional, but `.endswith` below crashes on
            None — main() only routes here when --input_csv is truthy.
        output_dir: Destination directory for embeddings and metadata.
        device: Device for the Qwen embedding model.
        batch_size: Accepted for CLI compatibility; not used in this
            species-only path.
        max_protein_length: Length cap applied by filter_sequences_by_length.
        resume: Skip entirely when species_vocab.json already exists.
        species_pooling: 'last' -> fixed-size vectors; otherwise
            variable-length token embeddings.

    NOTE(review): when species_pooling != 'last',
    save_species_token_embeddings_memmap writes output_dir/metadata.json and
    the "Save metadata" block below overwrites that same file — confirm this
    overwrite is intended.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Skip if resuming and files exist
    if resume and os.path.exists(os.path.join(output_dir, "species_vocab.json")):
        logger.info("Precomputed dataset already exists. Use --resume=False to regenerate.")
        return

    # Load data (format is chosen by file extension)
    logger.info(f"Loading data from {input_csv}...")
    if input_csv.endswith('.parquet'):
        df = pd.read_parquet(input_csv)
    else:
        df = pd.read_csv(input_csv)

    # Accept either 'Taxon' or 'taxon' as the species column.
    if "Taxon" not in df.columns and "taxon" in df.columns:
        df = df.rename(columns={"taxon": "Taxon"})

    # Filter sequences by length
    df = filter_sequences_by_length(df, max_protein_length)

    # === Species Embeddings ===
    logger.info("=== Generating Species Embeddings ===")
    unique_species = df["Taxon"].dropna().unique().tolist()
    logger.info(f"Found {len(unique_species)} unique species")

    # Build taxonomy database (queries the GBIF API)
    taxonomy_db = build_taxonomy_database(unique_species)

    # Save taxonomy database
    taxonomy_path = os.path.join(output_dir, "taxonomy_database.json")
    with open(taxonomy_path, 'w') as f:
        json.dump(taxonomy_db, f, indent=2)

    # Generate species embeddings
    species_vocab, species_embeddings = generate_species_embeddings_qwen(
        unique_species, taxonomy_db, device, pooling=species_pooling
    )

    if species_pooling == 'last':
        save_species_embeddings_memmap(species_vocab, species_embeddings, output_dir)
    else:
        save_species_token_embeddings_memmap(species_vocab, species_embeddings, output_dir)

    # Save metadata
    metadata = {
        "num_sequences": len(df),
        "num_species": len(unique_species),
        "species_embedding_model": "Qwen/Qwen3-Embedding-0.6B",
        "species_embedding_dim": 1024,  # Qwen3 dimension
        "max_protein_length": max_protein_length,
    }

    with open(os.path.join(output_dir, "metadata.json"), 'w') as f:
        json.dump(metadata, f, indent=2)

    logger.info(f"Dataset creation completed. Species embeddings are precomputed.")
    logger.info("Protein embeddings will be computed on-the-fly during training.")
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
def _build_arg_parser() -> argparse.ArgumentParser:
    """Construct the CLI argument parser for the precompute script."""
    parser = argparse.ArgumentParser(description="Precompute species embeddings for CodonTranslator")

    # Data source options
    parser.add_argument("--input_csv", type=str,
                       help="Path to input CSV/Parquet file")
    parser.add_argument("--from_stage1_shards", action="store_true",
                       help="Generate from Stage-1 Parquet shards instead of CSV")
    parser.add_argument("--stage1_shards_glob", type=str, default="./data/shards/*.parquet",
                       help="Glob pattern for Stage-1 shards")

    # Output
    parser.add_argument("--output_dir", type=str, required=True,
                       help="Output directory for precomputed embeddings")

    # Processing options
    parser.add_argument("--device", type=str, default="cuda",
                       help="Device for model inference")
    parser.add_argument("--batch_size", type=int, default=50,
                       help="Batch size for embedding generation")
    parser.add_argument("--max_protein_length", type=int, default=2048,
                       help="Maximum protein sequence length")
    parser.add_argument("--resume", action="store_true",
                       help="Resume from checkpoint if available")
    parser.add_argument("--species_pooling", type=str, choices=["last", "sequence", "none"], default="last",
                       help="'last' for single-token; 'sequence' for variable-length token embeddings")

    return parser


def main():
    """CLI entry point: parse arguments and route to the matching generator."""
    args = _build_arg_parser().parse_args()

    # Route to appropriate function: shard mode wins over CSV mode.
    if args.from_stage1_shards:
        prepare_species_from_stage1_shards(
            args.stage1_shards_glob,
            args.output_dir,
            device=args.device,
            resume=args.resume,
            species_pooling=args.species_pooling,
        )
    elif args.input_csv:
        create_precomputed_dataset(
            args.input_csv,
            args.output_dir,
            device=args.device,
            batch_size=args.batch_size,
            max_protein_length=args.max_protein_length,
            resume=args.resume,
            species_pooling=args.species_pooling,
        )
    else:
        raise ValueError("Must specify either --input_csv or --from_stage1_shards")

    logger.info("Precomputation complete!")


if __name__ == "__main__":
    main()
|
pyproject.toml
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "CodonTranslator"
|
| 7 |
+
version = "0.1.1"
|
| 8 |
+
description = "Sampling codon sequences conditioned on species and protein using a GPT model"
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
requires-python = ">=3.9"
|
| 11 |
+
license = {text = "MIT"}
|
| 12 |
+
authors = [{name = "CodonTranslator Team"}]
|
| 13 |
+
dependencies = [
|
| 14 |
+
"torch>=2.4",
|
| 15 |
+
"transformers>=4.57.0",
|
| 16 |
+
"esm>=3.2.3",
|
| 17 |
+
"safetensors>=0.7.0",
|
| 18 |
+
"numpy>=2.2.0",
|
| 19 |
+
"huggingface-hub>=0.36.0",
|
| 20 |
+
]
|
| 21 |
+
|
| 22 |
+
[tool.setuptools]
|
| 23 |
+
package-dir = {"" = "."}
|
| 24 |
+
packages = ["CodonTranslator", "codontranslator"]
|
requirements.txt
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.4
|
| 2 |
+
transformers>=4.57.0
|
| 3 |
+
esm>=3.2.3
|
| 4 |
+
safetensors>=0.7.0
|
| 5 |
+
numpy>=2.2.0
|
| 6 |
+
huggingface-hub>=0.36.0
|
| 7 |
+
accelerate>=1.9.0
|
| 8 |
+
pyarrow>=21.0.0
|
| 9 |
+
pandas>=2.3.0
|
| 10 |
+
duckdb>=1.5.0
|
| 11 |
+
biopython>=1.85
|
| 12 |
+
wandb>=0.21.0
|
resplit_data_v3.py
ADDED
|
@@ -0,0 +1,1444 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Resplit `data_v2/` into leakage-safe `data_v3_rebuild/` using MMseqs2 clustering.
|
| 4 |
+
|
| 5 |
+
Default policy for the current rebuild:
|
| 6 |
+
- Cluster `protein_seq` with MMseqs2 `linclust`
|
| 7 |
+
- Define species by normalized binomial name (`genus species`)
|
| 8 |
+
- Test species are exactly the normalized species present in `data_v2/test`
|
| 9 |
+
- Validation is cluster-unseen but species-seen
|
| 10 |
+
- Mixed seen/heldout clusters keep heldout rows in test and drop seen rows
|
| 11 |
+
|
| 12 |
+
Typical usage (end-to-end):
|
| 13 |
+
python resplit_data_v3.py all --threads 32 --split-memory-limit 120G --num-shards 256
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import argparse
|
| 19 |
+
import json
|
| 20 |
+
import os
|
| 21 |
+
import shutil
|
| 22 |
+
import stat
|
| 23 |
+
import subprocess
|
| 24 |
+
import sys
|
| 25 |
+
import time
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _default_mmseqs_path() -> str:
|
| 31 |
+
cand = Path("MMseqs2/build/bin/mmseqs")
|
| 32 |
+
if cand.exists():
|
| 33 |
+
return str(cand)
|
| 34 |
+
return "mmseqs"
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _run(cmd: List[str], *, cwd: Optional[str] = None, env: Optional[dict] = None) -> None:
|
| 38 |
+
pretty = " ".join(cmd)
|
| 39 |
+
print(f"+ {pretty}", flush=True)
|
| 40 |
+
subprocess.run(cmd, cwd=cwd, env=env, check=True)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _sql_escape_path(path: str) -> str:
|
| 44 |
+
return path.replace("'", "''")
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _expand_parquet_inputs(inp: str) -> List[str]:
|
| 48 |
+
import glob
|
| 49 |
+
|
| 50 |
+
p = Path(inp)
|
| 51 |
+
if p.exists() and p.is_dir():
|
| 52 |
+
files = sorted(str(x) for x in p.rglob("*.parquet"))
|
| 53 |
+
else:
|
| 54 |
+
files = sorted(glob.glob(inp))
|
| 55 |
+
|
| 56 |
+
seen = set()
|
| 57 |
+
out: List[str] = []
|
| 58 |
+
for f in files:
|
| 59 |
+
if f not in seen:
|
| 60 |
+
out.append(f)
|
| 61 |
+
seen.add(f)
|
| 62 |
+
return out
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _duckdb_parquet_source(inp: str, limit_files: int = 0) -> str:
    """Build a DuckDB ``read_parquet([...])`` expression covering the files
    behind *inp*, optionally truncated to the first *limit_files* entries.

    Raises SystemExit when the pattern/directory matches nothing.
    """
    files = _expand_parquet_inputs(inp)
    if not files:
        raise SystemExit(f"No parquet files found for {inp!r}")
    if limit_files and int(limit_files) > 0:
        files = files[: int(limit_files)]
    quoted_paths = ["'" + _sql_escape_path(fp) + "'" for fp in files]
    return "read_parquet([" + ", ".join(quoted_paths) + "])"
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def _mem_total_bytes() -> Optional[int]:
|
| 76 |
+
try:
|
| 77 |
+
with open("/proc/meminfo", "r", encoding="utf-8") as f:
|
| 78 |
+
for line in f:
|
| 79 |
+
if line.startswith("MemTotal:"):
|
| 80 |
+
parts = line.split()
|
| 81 |
+
kb = int(parts[1])
|
| 82 |
+
return kb * 1024
|
| 83 |
+
except OSError:
|
| 84 |
+
return None
|
| 85 |
+
except (ValueError, IndexError):
|
| 86 |
+
return None
|
| 87 |
+
return None
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def _parse_mmseqs_bytes(s: str) -> Optional[int]:
|
| 91 |
+
s = (s or "").strip()
|
| 92 |
+
if not s:
|
| 93 |
+
return None
|
| 94 |
+
up = s.upper()
|
| 95 |
+
suffix = up[-1]
|
| 96 |
+
num_part = up[:-1]
|
| 97 |
+
unit = suffix
|
| 98 |
+
if suffix == "B" and len(up) >= 2 and up[-2] in "KMGT":
|
| 99 |
+
unit = up[-2]
|
| 100 |
+
num_part = up[:-2]
|
| 101 |
+
if unit not in "BKMGT":
|
| 102 |
+
return None
|
| 103 |
+
try:
|
| 104 |
+
val = float(num_part)
|
| 105 |
+
except ValueError:
|
| 106 |
+
return None
|
| 107 |
+
mult = {"B": 1, "K": 1024, "M": 1024**2, "G": 1024**3, "T": 1024**4}[unit]
|
| 108 |
+
return int(val * mult)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def _format_bytes(n: int) -> str:
|
| 112 |
+
for unit, div in [("TiB", 1024**4), ("GiB", 1024**3), ("MiB", 1024**2), ("KiB", 1024)]:
|
| 113 |
+
if n >= div:
|
| 114 |
+
return f"{n / div:.1f}{unit}"
|
| 115 |
+
return f"{n}B"
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def _seq_id_sql() -> str:
|
| 119 |
+
# Keep the stable row identifier aligned with the existing pipeline.
|
| 120 |
+
return "coalesce(protein_refseq_id, '') || '|' || coalesce(RefseqID, '')"
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def _taxon_norm_sql(col: str = "taxon") -> str:
|
| 124 |
+
return f"regexp_replace(lower(trim(coalesce({col}, ''))), '\\\\s+', ' ', 'g')"
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def _species_key_sql(mode: str, col: str = "taxon") -> str:
    """SQL expression computing the species grouping key from *col*.

    mode='taxon' keeps the full normalized taxon string; mode='binomial'
    truncates to the first two words (genus + species) whenever a space is
    present.  Any other mode raises ValueError.
    """
    norm = _taxon_norm_sql(col)
    if mode == "taxon":
        return norm
    if mode != "binomial":
        raise ValueError(f"Unsupported species key mode: {mode}")
    first_two = f"split_part({norm}, ' ', 1) || ' ' || split_part({norm}, ' ', 2)"
    return f"CASE WHEN strpos({norm}, ' ') > 0 THEN {first_two} ELSE {norm} END"
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def _protein_norm_sql(col: str = "protein_seq") -> str:
|
| 142 |
+
cleaned = f"regexp_replace(upper(coalesce({col}, '')), '\\\\s+', '', 'g')"
|
| 143 |
+
no_stop = f"regexp_replace({cleaned}, '[_*]+$', '')"
|
| 144 |
+
return f"regexp_replace({no_stop}, '[^A-Z]', 'X', 'g')"
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def _cds_norm_sql(col: str = "cds_DNA") -> str:
|
| 148 |
+
cleaned = f"regexp_replace(upper(coalesce({col}, '')), '\\\\s+', '', 'g')"
|
| 149 |
+
return f"regexp_replace({cleaned}, '[^ACGTN]', 'N', 'g')"
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def _seq_expr_sql(seq_space: str) -> str:
    """Return the normalization SQL for *seq_space* ('protein' or 'cds')."""
    builders = {
        "protein": (_protein_norm_sql, "protein_seq"),
        "cds": (_cds_norm_sql, "cds_DNA"),
    }
    entry = builders.get(seq_space)
    if entry is None:
        raise ValueError(f"Unsupported seq space: {seq_space}")
    build, column = entry
    return build(column)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def _seq_space_input_col(seq_space: str) -> str:
|
| 161 |
+
if seq_space == "protein":
|
| 162 |
+
return "protein_seq"
|
| 163 |
+
if seq_space == "cds":
|
| 164 |
+
return "cds_DNA"
|
| 165 |
+
raise ValueError(f"Unsupported seq space: {seq_space}")
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def _mmseqs_dbtype(seq_space: str) -> str:
|
| 169 |
+
if seq_space == "protein":
|
| 170 |
+
return "1"
|
| 171 |
+
if seq_space == "cds":
|
| 172 |
+
return "2"
|
| 173 |
+
raise ValueError(f"Unsupported seq space: {seq_space}")
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def _default_max_input_seq_len(seq_space: str) -> int:
|
| 177 |
+
if seq_space == "protein":
|
| 178 |
+
# MMseqs linclust hit an internal SW bug on a tiny tail of ultra-long proteins
|
| 179 |
+
# (~39k aa+). Filtering this tail removes <0.01% of rows and keeps the run stable.
|
| 180 |
+
return 20_000
|
| 181 |
+
return 0
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def _ensure_mmseqs_ready(mmseqs: str) -> Tuple[str, Dict[str, str]]:
|
| 185 |
+
path = Path(mmseqs)
|
| 186 |
+
env = os.environ.copy()
|
| 187 |
+
|
| 188 |
+
if path.exists():
|
| 189 |
+
mode = path.stat().st_mode
|
| 190 |
+
if not (mode & stat.S_IXUSR):
|
| 191 |
+
path.chmod(mode | stat.S_IXUSR)
|
| 192 |
+
|
| 193 |
+
py = Path(sys.executable).resolve()
|
| 194 |
+
env_root = py.parent.parent
|
| 195 |
+
conda_root = env_root.parent.parent if env_root.parent.name == "envs" else env_root.parent
|
| 196 |
+
lib_candidates = [env_root / "lib", conda_root / "lib"]
|
| 197 |
+
libs = [str(p) for p in lib_candidates if p.exists()]
|
| 198 |
+
if libs:
|
| 199 |
+
current = env.get("LD_LIBRARY_PATH", "")
|
| 200 |
+
env["LD_LIBRARY_PATH"] = ":".join(libs + ([current] if current else []))
|
| 201 |
+
|
| 202 |
+
return str(path if path.exists() else mmseqs), env
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def _ensure_output_parent(path: Path) -> None:
|
| 206 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def cmd_make_fasta(args: argparse.Namespace) -> None:
    """Export normalized sequences from the input parquet files as a FASTA file.

    Uses DuckDB's CSV COPY with a newline "delimiter" and no quoting so each
    row emits exactly two lines: ``>seq_id`` then the normalized sequence.
    Rows with a NULL raw column, an empty normalized sequence, a degenerate
    seq_id, or a sequence longer than the cap are skipped.
    """
    out_fasta = Path(args.output_fasta)
    _ensure_output_parent(out_fasta)

    # Imported lazily so non-DuckDB subcommands don't need the package.
    import duckdb

    con = duckdb.connect()
    con.execute(f"PRAGMA threads={int(args.threads)};")
    con.execute("PRAGMA enable_progress_bar=true;")

    source_sql = _duckdb_parquet_source(args.input_glob, int(args.limit_files))
    out_path = _sql_escape_path(str(out_fasta))
    seq_id = _seq_id_sql()
    seq_expr = _seq_expr_sql(args.seq_space)
    raw_col = _seq_space_input_col(args.seq_space)
    # A non-positive CLI value means "use the per-seq-space default cap".
    max_input_seq_len = int(args.max_input_seq_len)
    if max_input_seq_len <= 0:
        max_input_seq_len = _default_max_input_seq_len(args.seq_space)
    len_filter = (
        f"AND length({seq_expr}) <= {max_input_seq_len}"
        if max_input_seq_len > 0
        else ""
    )

    # The embedded '\n' is a real newline in the SQL text: with QUOTE/ESCAPE
    # disabled, the CSV writer then emits header and sequence on separate
    # lines, i.e. valid FASTA.
    sql = f"""
    COPY (
        SELECT
            '>' || ({seq_id}) AS header,
            {seq_expr} AS seq
        FROM {source_sql}
        WHERE {raw_col} IS NOT NULL
          AND length({seq_expr}) > 0
          {len_filter}
          AND length(({seq_id})) > 1
        {f"LIMIT {int(args.limit_rows)}" if args.limit_rows and int(args.limit_rows) > 0 else ""}
    )
    TO '{out_path}'
    (FORMAT CSV, DELIMITER '\n', QUOTE '', ESCAPE '', HEADER FALSE);
    """
    t0 = time.time()
    con.execute(sql)
    print(
        f"Wrote FASTA: {out_fasta} seq_space={args.seq_space} "
        f"max_input_seq_len={max_input_seq_len if max_input_seq_len > 0 else 'none'} "
        f"(elapsed_s={time.time() - t0:.1f})"
    )
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def cmd_mmseqs_cluster(args: argparse.Namespace) -> None:
    """Cluster the FASTA with MMseqs2 linclust and dump cluster membership.

    Pipeline: createdb -> linclust -> createtsv.  Outputs ``clu.tsv`` in the
    workdir with (cluster_representative, member) pairs, one per line.
    Raises SystemExit when the input FASTA is missing; subprocess failures
    propagate as CalledProcessError via ``_run``.
    """
    mmseqs, env = _ensure_mmseqs_ready(args.mmseqs)
    workdir = Path(args.workdir)
    workdir.mkdir(parents=True, exist_ok=True)

    fasta = Path(args.fasta)
    if not fasta.exists():
        raise SystemExit(f"FASTA not found: {fasta}")

    seqdb = workdir / "seqdb"
    clu = workdir / "clu"
    tmp = workdir / "tmp"
    tsv = workdir / "clu.tsv"

    if args.overwrite:
        # MMseqs databases are a file plus several sidecar files; remove all
        # of them (or the whole directory for tmp) before re-running.
        for p in (seqdb, clu, tmp, tsv):
            if p.is_dir():
                shutil.rmtree(p, ignore_errors=True)
            else:
                for suffix in ("", ".dbtype", ".index", ".lookup", ".source"):
                    try:
                        os.remove(str(p) + suffix)
                    except OSError:
                        pass

    tmp.mkdir(parents=True, exist_ok=True)

    # --shuffle 0 keeps input order; --createdb-mode 1 avoids copying the
    # FASTA (soft-link mode).
    _run(
        [
            mmseqs,
            "createdb",
            str(fasta),
            str(seqdb),
            "--dbtype",
            _mmseqs_dbtype(args.seq_space),
            "--shuffle",
            "0",
            "--createdb-mode",
            "1",
            "--threads",
            str(int(args.threads)),
        ],
        env=env,
    )

    linclust_cmd = [
        mmseqs,
        "linclust",
        str(seqdb),
        str(clu),
        str(tmp),
        "--min-seq-id",
        str(float(args.min_seq_id)),
        "-c",
        str(float(args.coverage)),
        "--cov-mode",
        str(int(args.cov_mode)),
        "--cluster-mode",
        str(int(args.cluster_mode)),
        "--threads",
        str(int(args.threads)),
        "--max-seq-len",
        str(int(args.max_seq_len)),
        "--remove-tmp-files",
        "1" if args.remove_tmp_files else "0",
    ]
    if args.split_memory_limit:
        # Warn (but still proceed) when the requested split limit exceeds
        # physical RAM: MMseqs2 may then under-split and crash mid-run.
        mem_total = _mem_total_bytes()
        limit_bytes = _parse_mmseqs_bytes(args.split_memory_limit)
        if mem_total and limit_bytes and limit_bytes > mem_total:
            print(
                f"WARNING: --split-memory-limit={args.split_memory_limit} ({_format_bytes(limit_bytes)}) "
                f"exceeds system MemTotal ({_format_bytes(mem_total)}). "
                "MMseqs2 may under-split and crash; consider lowering it or leaving it empty.",
                file=sys.stderr,
                flush=True,
            )
        linclust_cmd += ["--split-memory-limit", str(args.split_memory_limit)]
    if args.kmer_per_seq_scale is not None:
        linclust_cmd += ["--kmer-per-seq-scale", str(float(args.kmer_per_seq_scale))]

    _run(linclust_cmd, env=env)
    # createtsv with seqdb twice resolves both columns to sequence headers.
    _run([mmseqs, "createtsv", str(seqdb), str(seqdb), str(clu), str(tsv)], env=env)
    print(f"Wrote cluster TSV: {tsv}")
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def cmd_make_seq_cluster(args: argparse.Namespace) -> None:
    """Convert the MMseqs2 cluster TSV into a (seq_id, cluster_id) parquet.

    The TSV is (cluster_id, seq_id) per line; DISTINCT collapses any
    duplicate pairs.  Raises SystemExit when the TSV is missing.
    """
    import duckdb

    tsv = Path(args.cluster_tsv)
    if not tsv.exists():
        raise SystemExit(f"Cluster TSV not found: {tsv}")
    out = Path(args.output_parquet)
    _ensure_output_parent(out)

    con = duckdb.connect()
    con.execute(f"PRAGMA threads={int(args.threads)};")
    con.execute("PRAGMA enable_progress_bar=true;")

    tsv_path = _sql_escape_path(str(tsv))
    out_path = _sql_escape_path(str(out))

    # NOTE: '\\t' in this Python source puts the two characters backslash+t
    # into the SQL text, which DuckDB's read_csv accepts as the tab delimiter.
    sql = f"""
    COPY (
        SELECT DISTINCT
            seq_id,
            cluster_id
        FROM read_csv(
            '{tsv_path}',
            delim='\\t',
            header=false,
            columns={{'cluster_id':'VARCHAR','seq_id':'VARCHAR'}}
        )
    )
    TO '{out_path}'
    (FORMAT PARQUET);
    """
    t0 = time.time()
    con.execute(sql)
    print(f"Wrote seq→cluster parquet: {out} (elapsed_s={time.time() - t0:.1f})")
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
def _write_cluster_split_parquet(
    con,
    *,
    cluster_split_path: Path,
    seed: int,
    val_frac: float,
) -> Dict[str, int]:
    """Assign every non-heldout cluster to 'train' or 'val' and stream the
    (cluster_id, split) mapping to a parquet file.

    Greedy policy: clusters are visited in a seed-derived pseudo-random order
    and sent to 'val' until roughly ``val_frac`` of non-heldout rows are
    covered — but only if every species in the cluster still has more than
    one unassigned cluster, so no species ends up exclusively in val.

    Expects temp tables ``cluster_flags`` (cluster_id, n_test, n_total, ...)
    and ``cluster_counts`` (cluster_id, species_key, n) to already exist on
    *con*.  Returns summary counters (row/cluster totals per split).
    """
    import pyarrow as pa
    import pyarrow.parquet as pq

    cluster_split_path.parent.mkdir(parents=True, exist_ok=True)
    if cluster_split_path.exists():
        cluster_split_path.unlink()

    # Only clusters with zero heldout-species rows (n_test = 0) participate;
    # the rest are routed to test/drop elsewhere.
    total_seen_rows = int(
        con.execute(
            "SELECT coalesce(sum(n_total), 0)::BIGINT FROM cluster_flags WHERE n_test = 0"
        ).fetchone()[0]
    )
    target_val_rows = int(total_seen_rows * float(val_frac))

    # Per-species count of clusters not yet claimed by val; decremented each
    # time one of the species' clusters is assigned to val.
    species_remaining = {
        species_key: int(n_clusters)
        for species_key, n_clusters in con.execute(
            """
            SELECT
                cc.species_key,
                count(*)::BIGINT AS n_clusters
            FROM cluster_counts cc
            JOIN cluster_flags cf USING (cluster_id)
            WHERE cf.n_test = 0
            GROUP BY cc.species_key
            """
        ).fetchall()
    }

    # One row per (cluster, species); ORDER BY the seeded hash gives a
    # deterministic shuffle, with rows of the same cluster kept adjacent.
    cur = con.execute(
        f"""
        SELECT
            cf.cluster_id,
            cf.n_total,
            abs(hash(cf.cluster_id || ':{seed}')) AS rnd,
            cc.species_key
        FROM cluster_flags cf
        JOIN cluster_counts cc USING (cluster_id)
        WHERE cf.n_test = 0
        ORDER BY rnd, cf.cluster_id, cc.species_key
        """
    )

    writer = None
    batch_cluster_ids: List[str] = []
    batch_splits: List[str] = []
    val_rows = 0
    train_clusters = 0
    val_clusters = 0
    # Accumulator for the cluster currently being scanned off the cursor.
    current_cluster: Optional[str] = None
    current_n_total = 0
    current_species: List[str] = []

    def flush_current() -> None:
        """Decide the split for the accumulated cluster and buffer the row."""
        nonlocal writer, val_rows, train_clusters, val_clusters
        nonlocal current_cluster, current_n_total, current_species
        if current_cluster is None:
            return
        # Eligible for val only while under quota AND every species in the
        # cluster would still have at least one other cluster left.
        can_val = (
            val_rows < target_val_rows
            and all(species_remaining.get(species_key, 0) > 1 for species_key in current_species)
        )
        split = "val" if can_val else "train"
        if can_val:
            for species_key in current_species:
                species_remaining[species_key] -= 1
            val_rows += int(current_n_total)
            val_clusters += 1
        else:
            train_clusters += 1

        batch_cluster_ids.append(current_cluster)
        batch_splits.append(split)
        # Flush to parquet in 200k-row batches to bound memory.
        if len(batch_cluster_ids) >= 200_000:
            table = pa.table({"cluster_id": batch_cluster_ids, "split": batch_splits})
            if writer is None:
                writer = pq.ParquetWriter(str(cluster_split_path), table.schema)
            writer.write_table(table)
            batch_cluster_ids.clear()
            batch_splits.clear()

    # Group-by-adjacent-cluster scan: rows for one cluster arrive together,
    # so a cluster_id change means the previous cluster is complete.
    while True:
        rows = cur.fetchmany(200_000)
        if not rows:
            break
        for cluster_id, n_total, _rnd, species_key in rows:
            cluster_id = str(cluster_id)
            species_key = str(species_key)
            if current_cluster is None:
                current_cluster = cluster_id
                current_n_total = int(n_total)
                current_species = [species_key]
                continue
            if cluster_id != current_cluster:
                flush_current()
                current_cluster = cluster_id
                current_n_total = int(n_total)
                current_species = [species_key]
                continue
            current_species.append(species_key)

    flush_current()
    if batch_cluster_ids:
        table = pa.table({"cluster_id": batch_cluster_ids, "split": batch_splits})
        if writer is None:
            writer = pq.ParquetWriter(str(cluster_split_path), table.schema)
        writer.write_table(table)
    elif writer is None:
        # No clusters at all: still emit a valid, correctly-typed empty file.
        empty = pa.table(
            {
                "cluster_id": pa.array([], type=pa.string()),
                "split": pa.array([], type=pa.string()),
            }
        )
        writer = pq.ParquetWriter(str(cluster_split_path), empty.schema)
        writer.write_table(empty)
    if writer is not None:
        writer.close()

    return {
        "nonheldout_total_rows": total_seen_rows,
        "target_val_rows": target_val_rows,
        "actual_val_rows": val_rows,
        "train_clusters": train_clusters,
        "val_clusters": val_clusters,
    }
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
def cmd_make_seq_split(args: argparse.Namespace) -> None:
    """Compute the final per-sequence split (train/val/test/drop) parquet.

    Stages:
      1. ``heldout_species``: species present in the heldout test set.
      2. ``cluster_counts`` / ``cluster_flags``: per-cluster species row
         counts and heldout-contamination flags.
      3. ``_write_cluster_split_parquet``: assign non-heldout clusters to
         train/val.
      4. A single COPY query labels every sequence, applying leakage guards:
         unclustered rows drop; heldout-species rows go to test; mixed
         seen/heldout clusters keep only the heldout rows; identical
         normalized proteins may not straddle test/train or train/val;
         ambiguous seq_ids (multiple species or splits) drop.
    """
    import duckdb

    seq_cluster = Path(args.seq_cluster_parquet)
    if not seq_cluster.exists():
        raise SystemExit(f"seq_cluster parquet not found: {seq_cluster}")

    out = Path(args.output_parquet)
    cluster_split = Path(args.cluster_split_parquet)
    _ensure_output_parent(out)
    _ensure_output_parent(cluster_split)

    con = duckdb.connect()
    con.execute(f"PRAGMA threads={int(args.threads)};")
    con.execute("PRAGMA enable_progress_bar=true;")

    input_sql = _duckdb_parquet_source(args.input_glob, int(args.limit_files))
    heldout_sql = _duckdb_parquet_source(args.heldout_test_glob, 0)
    seq_cluster_path = _sql_escape_path(str(seq_cluster))
    out_path = _sql_escape_path(str(out))

    seq_id = _seq_id_sql()
    species_key = _species_key_sql(args.species_key_mode, "taxon")
    protein_norm = _protein_norm_sql("protein_seq")

    # Stage 1: normalized species keys of the heldout test set.
    con.execute(
        f"""
        CREATE TEMP TABLE heldout_species AS
        SELECT DISTINCT {species_key} AS species_key
        FROM {heldout_sql}
        WHERE {species_key} != '';
        """
    )

    # Stage 2a: rows per (cluster, species) over valid input rows.
    con.execute(
        f"""
        CREATE TEMP TABLE cluster_counts AS
        WITH base AS (
            SELECT
                {seq_id} AS seq_id,
                {species_key} AS species_key
            FROM {input_sql}
            WHERE length(({seq_id})) > 1
              AND {species_key} != ''
        )
        SELECT
            sc.cluster_id,
            base.species_key,
            count(*)::BIGINT AS n
        FROM base
        JOIN read_parquet('{seq_cluster_path}') sc USING (seq_id)
        GROUP BY sc.cluster_id, base.species_key;
        """
    )

    # Stage 2b: per-cluster totals; n_test > 0 marks heldout contamination.
    con.execute(
        """
        CREATE TEMP TABLE cluster_flags AS
        SELECT
            cluster_id,
            sum(CASE WHEN species_key IN (SELECT species_key FROM heldout_species) THEN n ELSE 0 END)::BIGINT AS n_test,
            sum(CASE WHEN species_key NOT IN (SELECT species_key FROM heldout_species) THEN n ELSE 0 END)::BIGINT AS n_seen,
            sum(n)::BIGINT AS n_total,
            count(*)::BIGINT AS n_species
        FROM cluster_counts
        GROUP BY cluster_id;
        """
    )

    # Stage 3: greedy train/val assignment of uncontaminated clusters.
    t0 = time.time()
    split_summary = _write_cluster_split_parquet(
        con,
        cluster_split_path=cluster_split,
        seed=int(args.seed),
        val_frac=float(args.val_frac),
    )
    print(
        "Cluster assignment summary: "
        f"train_clusters={split_summary['train_clusters']:,} "
        f"val_clusters={split_summary['val_clusters']:,} "
        f"target_val_rows={split_summary['target_val_rows']:,} "
        f"actual_val_rows={split_summary['actual_val_rows']:,} "
        f"(elapsed_s={time.time() - t0:.1f})"
    )

    # Stage 4: label every sequence and apply the leakage guards.
    cluster_split_path = _sql_escape_path(str(cluster_split))
    con.execute(
        f"""
        COPY (
            WITH base AS (
                SELECT DISTINCT
                    {seq_id} AS seq_id,
                    {species_key} AS species_key,
                    {protein_norm} AS protein_norm
                FROM {input_sql}
                WHERE length(({seq_id})) > 1
                  AND {species_key} != ''
            ),
            joined AS (
                SELECT
                    base.seq_id,
                    base.species_key,
                    base.protein_norm,
                    sc.cluster_id
                FROM base
                LEFT JOIN read_parquet('{seq_cluster_path}') sc USING (seq_id)
            ),
            labeled AS (
                SELECT
                    j.seq_id,
                    j.species_key,
                    j.protein_norm,
                    CASE
                        WHEN j.cluster_id IS NULL THEN 'drop'
                        WHEN j.species_key IN (SELECT species_key FROM heldout_species) THEN 'test'
                        WHEN coalesce(cf.n_test, 0) > 0 THEN 'drop'
                        ELSE coalesce(cs.split, 'drop')
                    END AS split
                FROM joined j
                LEFT JOIN cluster_flags cf USING (cluster_id)
                LEFT JOIN read_parquet('{cluster_split_path}') cs USING (cluster_id)
            ),
            protein_flags AS (
                SELECT
                    protein_norm,
                    max(CASE WHEN split = 'test' THEN 1 ELSE 0 END) AS has_test,
                    max(CASE WHEN split = 'train' THEN 1 ELSE 0 END) AS has_train
                FROM labeled
                WHERE length(protein_norm) > 0
                GROUP BY protein_norm
            ),
            guarded AS (
                SELECT
                    l.seq_id,
                    l.species_key,
                    CASE
                        WHEN l.split = 'drop' THEN 'drop'
                        WHEN length(l.protein_norm) = 0 THEN l.split
                        WHEN coalesce(pf.has_test, 0) = 1 AND l.split IN ('train', 'val') THEN 'drop'
                        WHEN coalesce(pf.has_train, 0) = 1 AND l.split = 'val' THEN 'drop'
                        ELSE l.split
                    END AS split
                FROM labeled l
                LEFT JOIN protein_flags pf USING (protein_norm)
            ),
            dedup AS (
                SELECT
                    seq_id,
                    CASE
                        WHEN count(DISTINCT species_key) > 1 THEN 'drop'
                        WHEN count(DISTINCT split) > 1 THEN 'drop'
                        ELSE any_value(split)
                    END AS split
                FROM guarded
                GROUP BY seq_id
            )
            SELECT seq_id, split FROM dedup
        )
        TO '{out_path}'
        (FORMAT PARQUET);
        """
    )

    # Report row counts per split by re-joining the input against the output.
    rows = con.execute(
        f"""
        WITH base AS (
            SELECT {seq_id} AS seq_id
            FROM {input_sql}
        )
        SELECT s.split, count(*)::BIGINT AS n_rows
        FROM base
        JOIN read_parquet('{out_path}') s USING (seq_id)
        GROUP BY s.split
        ORDER BY n_rows DESC;
        """
    ).fetchall()
    print("Split summary (rows):")
    for split, n in rows:
        print(f"  {split}\t{n:,}")

    print(f"Wrote cluster→split parquet: {cluster_split}")
    print(f"Wrote seq→split parquet: {out}")
|
| 696 |
+
|
| 697 |
+
|
| 698 |
+
def cmd_write_data_v3(args: argparse.Namespace) -> None:
    """Materialize the data_v3 train/val/test parquet directories.

    Reads the source parquet shards (``--input-glob``), joins them against the
    precomputed seq_id→split mapping (``--seq-split-parquet``) and writes one
    hash-sharded parquet directory per split under ``--output-root``.  With
    ``--representatives-only`` (the default) only one seq_id per MMseqs
    cluster is kept, which additionally requires ``--seq-cluster-parquet``.

    Raises SystemExit on missing inputs, pre-existing output dirs without
    ``--overwrite``, or a non-positive ``--num-shards``.
    """
    import duckdb

    seq_split = Path(args.seq_split_parquet)
    if not seq_split.exists():
        raise SystemExit(f"seq_split parquet not found: {seq_split}")
    seq_cluster = Path(args.seq_cluster_parquet)
    # The cluster mapping is only consulted in representatives-only mode.
    if args.representatives_only and not seq_cluster.exists():
        raise SystemExit(f"seq_cluster parquet not found: {seq_cluster}")

    out_root = Path(args.output_root)
    out_root.mkdir(parents=True, exist_ok=True)
    (out_root / "_work").mkdir(parents=True, exist_ok=True)

    # Refuse to clobber existing split outputs unless --overwrite was given;
    # on overwrite, remove and recreate each split directory from scratch.
    for split_dir in (out_root / "train", out_root / "val", out_root / "test"):
        if split_dir.exists():
            if not args.overwrite:
                raise SystemExit(f"Output split directory exists: {split_dir} (pass --overwrite)")
            shutil.rmtree(split_dir)
        split_dir.mkdir(parents=True, exist_ok=True)

    con = duckdb.connect()
    con.execute(f"PRAGMA threads={int(args.threads)};")
    con.execute("PRAGMA enable_progress_bar=true;")

    # SQL fragments built by module-level helpers: a parquet-scan source
    # expression, escaped literal paths, and the seq_id expression.
    input_sql = _duckdb_parquet_source(args.input_glob, int(args.limit_files))
    seq_split_path = _sql_escape_path(str(seq_split))
    seq_cluster_path = _sql_escape_path(str(seq_cluster))
    seq_id = _seq_id_sql()

    num_shards = int(args.num_shards)
    if num_shards <= 0:
        raise SystemExit("--num-shards must be > 0")

    for split in ("train", "val", "test"):
        out_dir = _sql_escape_path(str(out_root / split))
        if args.representatives_only:
            # One representative per cluster: min(seq_id) among members of
            # this split (deterministic choice).
            target_seq_ids_sql = f"""
                SELECT min(s.seq_id) AS seq_id
                FROM read_parquet('{seq_split_path}') s
                JOIN read_parquet('{seq_cluster_path}') sc USING (seq_id)
                WHERE s.split = '{split}'
                GROUP BY sc.cluster_id
            """
        else:
            # Keep every seq_id assigned to this split.
            target_seq_ids_sql = f"""
                SELECT DISTINCT s.seq_id
                FROM read_parquet('{seq_split_path}') s
                WHERE s.split = '{split}'
            """
        # COPY the selected rows out, partitioned into num_shards buckets by
        # a stable hash of seq_id.  The QUALIFY clause deduplicates source
        # rows so each seq_id is written at most once.
        sql = f"""
            COPY (
                WITH target_seq_ids AS (
                    {target_seq_ids_sql}
                ),
                rows AS (
                    SELECT
                        p.*,
                        abs(hash({seq_id})) % {num_shards} AS shard
                    FROM {input_sql} p
                    JOIN target_seq_ids t
                      ON t.seq_id = ({seq_id})
                    QUALIFY row_number() OVER (PARTITION BY ({seq_id}) ORDER BY ({seq_id})) = 1
                )
                SELECT * FROM rows
            )
            TO '{out_dir}'
            (FORMAT PARQUET, PARTITION_BY (shard));
        """
        t0 = time.time()
        con.execute(sql)
        print(
            f"Wrote {split} parquets to {out_root / split} "
            f"representatives_only={bool(args.representatives_only)} "
            f"(elapsed_s={time.time() - t0:.1f})"
        )
|
| 774 |
+
|
| 775 |
+
|
| 776 |
+
def cmd_verify(args: argparse.Namespace) -> None:
    """Verify leakage/species constraints of the split and emit an audit report.

    Runs a battery of DuckDB checks over the source parquet shards, the
    seq_id→cluster mapping and the seq_id→split mapping:

    hard constraints (any non-zero count fails verification):
      * no cluster spans more than one non-drop split
      * no 'test' row belongs to a non-heldout ("seen") species
      * every 'val' species also occurs in 'train'
      * no exact (normalized) protein overlap between train/val or train/test

    plus informational audit counts (mixed clusters, dropped rows, exact
    protein collisions across species / heldout boundary).

    Optionally writes the combined report to ``--report-json``.  Raises
    SystemExit on missing inputs or failed verification.
    """
    import duckdb

    seq_cluster = Path(args.seq_cluster_parquet)
    seq_split = Path(args.seq_split_parquet)
    if not seq_cluster.exists():
        raise SystemExit(f"seq_cluster parquet not found: {seq_cluster}")
    if not seq_split.exists():
        raise SystemExit(f"seq_split parquet not found: {seq_split}")

    con = duckdb.connect()
    con.execute(f"PRAGMA threads={int(args.threads)};")
    con.execute("PRAGMA enable_progress_bar=true;")

    # SQL fragments from module-level helpers; heldout source is never
    # file-limited (limit 0) so the species set is complete.
    input_sql = _duckdb_parquet_source(args.input_glob, int(args.limit_files))
    heldout_sql = _duckdb_parquet_source(args.heldout_test_glob, 0)
    seq_cluster_path = _sql_escape_path(str(seq_cluster))
    seq_split_path = _sql_escape_path(str(seq_split))

    seq_id = _seq_id_sql()
    species_key = _species_key_sql(args.species_key_mode, "taxon")
    protein_norm = _protein_norm_sql("protein_seq")

    # Temp table 1: the set of species designated as heldout test species.
    con.execute(
        f"""
        CREATE TEMP TABLE heldout_species AS
        SELECT DISTINCT {species_key} AS species_key
        FROM {heldout_sql}
        WHERE {species_key} != '';
        """
    )
    # Temp table 2: per (cluster, species) row counts over valid rows
    # (seq_id longer than 1 char, non-empty species key).
    con.execute(
        f"""
        CREATE TEMP TABLE cluster_counts AS
        WITH base AS (
            SELECT
                {seq_id} AS seq_id,
                {species_key} AS species_key
            FROM {input_sql}
            WHERE length(({seq_id})) > 1
              AND {species_key} != ''
        )
        SELECT
            sc.cluster_id,
            base.species_key,
            count(*)::BIGINT AS n
        FROM base
        JOIN read_parquet('{seq_cluster_path}') sc USING (seq_id)
        GROUP BY sc.cluster_id, base.species_key;
        """
    )
    # Temp table 3: per-cluster totals split by heldout ("test") vs seen
    # species membership; consumed by the mixed-cluster audits below.
    con.execute(
        """
        CREATE TEMP TABLE cluster_flags AS
        SELECT
            cluster_id,
            sum(CASE WHEN species_key IN (SELECT species_key FROM heldout_species) THEN n ELSE 0 END)::BIGINT AS n_test,
            sum(CASE WHEN species_key NOT IN (SELECT species_key FROM heldout_species) THEN n ELSE 0 END)::BIGINT AS n_seen,
            sum(n)::BIGINT AS n_total,
            count(*)::BIGINT AS n_species
        FROM cluster_counts
        GROUP BY cluster_id;
        """
    )

    # Per-split counts of distinct seq_ids in the mapping itself.
    split_seq_ids = {
        split: int(n)
        for split, n in con.execute(
            f"""
            SELECT split, count(*)::BIGINT AS n
            FROM read_parquet('{seq_split_path}')
            GROUP BY split
            """
        ).fetchall()
    }
    # Per-split counts of source rows that join onto the mapping.
    split_rows = {
        split: int(n)
        for split, n in con.execute(
            f"""
            WITH base AS (
                SELECT {seq_id} AS seq_id FROM {input_sql}
            )
            SELECT s.split, count(*)::BIGINT AS n
            FROM base
            JOIN read_parquet('{seq_split_path}') s USING (seq_id)
            GROUP BY s.split
            """
        ).fetchall()
    }

    # Hard check 1: clusters whose members land in >1 split (ignoring drop).
    bad_clusters = int(
        con.execute(
            f"""
            WITH keep AS (
                SELECT sc.cluster_id, ss.split
                FROM read_parquet('{seq_cluster_path}') sc
                JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
                WHERE ss.split != 'drop'
            )
            SELECT count(*)::BIGINT
            FROM (
                SELECT cluster_id
                FROM keep
                GROUP BY cluster_id
                HAVING count(DISTINCT split) > 1
            );
            """
        ).fetchone()[0]
    )
    print(f"clusters_spanning_splits(excluding drop) = {bad_clusters}")

    # Hard check 2: test rows whose species is NOT a heldout species.
    bad_test = int(
        con.execute(
            f"""
            WITH base AS (
                SELECT {seq_id} AS seq_id, {species_key} AS species_key
                FROM {input_sql}
            )
            SELECT count(*)::BIGINT
            FROM base
            JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
            WHERE ss.split = 'test'
              AND base.species_key NOT IN (SELECT species_key FROM heldout_species);
            """
        ).fetchone()[0]
    )
    print(f"test_rows_with_seen_species = {bad_test}")

    # Hard check 3: val species that never appear in train.
    bad_val_species = int(
        con.execute(
            f"""
            WITH base AS (
                SELECT DISTINCT {seq_id} AS seq_id, {species_key} AS species_key
                FROM {input_sql}
            ),
            labeled AS (
                SELECT base.species_key, ss.split
                FROM base
                JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
                WHERE ss.split IN ('train', 'val')
            ),
            train_species AS (SELECT DISTINCT species_key FROM labeled WHERE split = 'train'),
            val_species AS (SELECT DISTINCT species_key FROM labeled WHERE split = 'val')
            SELECT count(*)::BIGINT
            FROM (SELECT species_key FROM val_species EXCEPT SELECT species_key FROM train_species);
            """
        ).fetchone()[0]
    )
    print(f"val_species_not_in_train = {bad_val_species}")

    # Hard check 4a: exact normalized-protein overlap between train and val.
    protein_overlap_train_val = int(
        con.execute(
            f"""
            WITH base AS (
                SELECT DISTINCT {seq_id} AS seq_id, {protein_norm} AS protein_norm
                FROM {input_sql}
                WHERE length({protein_norm}) > 0
            ),
            labeled AS (
                SELECT base.protein_norm, ss.split
                FROM base
                JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
                WHERE ss.split IN ('train', 'val')
            ),
            train_p AS (SELECT DISTINCT protein_norm FROM labeled WHERE split = 'train'),
            val_p AS (SELECT DISTINCT protein_norm FROM labeled WHERE split = 'val')
            SELECT count(*)::BIGINT
            FROM (SELECT protein_norm FROM train_p INTERSECT SELECT protein_norm FROM val_p);
            """
        ).fetchone()[0]
    )
    # Hard check 4b: exact normalized-protein overlap between train and test.
    protein_overlap_train_test = int(
        con.execute(
            f"""
            WITH base AS (
                SELECT DISTINCT {seq_id} AS seq_id, {protein_norm} AS protein_norm
                FROM {input_sql}
                WHERE length({protein_norm}) > 0
            ),
            labeled AS (
                SELECT base.protein_norm, ss.split
                FROM base
                JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
                WHERE ss.split IN ('train', 'test')
            ),
            train_p AS (SELECT DISTINCT protein_norm FROM labeled WHERE split = 'train'),
            test_p AS (SELECT DISTINCT protein_norm FROM labeled WHERE split = 'test')
            SELECT count(*)::BIGINT
            FROM (SELECT protein_norm FROM train_p INTERSECT SELECT protein_norm FROM test_p);
            """
        ).fetchone()[0]
    )
    print(f"exact_protein_overlap_train_val = {protein_overlap_train_val}")
    print(f"exact_protein_overlap_train_test = {protein_overlap_train_test}")

    # Audit: clusters containing both heldout-species and seen-species rows.
    mixed_test_clusters = int(
        con.execute(
            "SELECT count(*)::BIGINT FROM cluster_flags WHERE n_test > 0 AND n_seen > 0"
        ).fetchone()[0]
    )
    # Audit: exact protein groups shared by heldout and seen species.
    exact_holdout_seen_conflicts = int(
        con.execute(
            f"""
            WITH base AS (
                SELECT DISTINCT
                    {protein_norm} AS protein_norm,
                    {species_key} AS species_key
                FROM {input_sql}
                WHERE length({protein_norm}) > 0
                  AND {species_key} != ''
            )
            SELECT count(*)::BIGINT
            FROM (
                SELECT protein_norm
                FROM base
                GROUP BY protein_norm
                HAVING count(DISTINCT CASE WHEN species_key IN (SELECT species_key FROM heldout_species) THEN species_key END) > 0
                   AND count(DISTINCT CASE WHEN species_key NOT IN (SELECT species_key FROM heldout_species) THEN species_key END) > 0
            );
            """
        ).fetchone()[0]
    )
    # Audit: seen-species seq_ids dropped because their exact protein also
    # occurs in a heldout species.
    dropped_seen_rows_exact_holdout = int(
        con.execute(
            f"""
            WITH base AS (
                SELECT DISTINCT
                    {seq_id} AS seq_id,
                    {species_key} AS species_key,
                    {protein_norm} AS protein_norm
                FROM {input_sql}
                WHERE length(({seq_id})) > 1
                  AND {species_key} != ''
            ),
            conflict_proteins AS (
                SELECT protein_norm
                FROM base
                WHERE length(protein_norm) > 0
                GROUP BY protein_norm
                HAVING count(DISTINCT CASE WHEN species_key IN (SELECT species_key FROM heldout_species) THEN species_key END) > 0
                   AND count(DISTINCT CASE WHEN species_key NOT IN (SELECT species_key FROM heldout_species) THEN species_key END) > 0
            )
            SELECT count(*)::BIGINT
            FROM base
            JOIN conflict_proteins USING (protein_norm)
            JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
            WHERE ss.split = 'drop'
              AND base.species_key NOT IN (SELECT species_key FROM heldout_species);
            """
        ).fetchone()[0]
    )
    # Audit: dropped seq_ids whose exact protein also occurs in train.
    dropped_val_rows_exact_train = int(
        con.execute(
            f"""
            WITH base AS (
                SELECT DISTINCT
                    {seq_id} AS seq_id,
                    {protein_norm} AS protein_norm
                FROM {input_sql}
                WHERE length(({seq_id})) > 1
                  AND length({protein_norm}) > 0
            ),
            labeled AS (
                SELECT base.protein_norm, ss.split
                FROM base
                JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
            ),
            train_proteins AS (
                SELECT DISTINCT protein_norm
                FROM labeled
                WHERE split = 'train'
            )
            SELECT count(*)::BIGINT
            FROM base
            JOIN train_proteins USING (protein_norm)
            JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
            WHERE ss.split = 'drop';
            """
        ).fetchone()[0]
    )
    # Audit: seen-species SOURCE ROWS dropped because their cluster also
    # contains heldout-species members (non-DISTINCT base → counts rows).
    dropped_seen_rows_mixed = int(
        con.execute(
            f"""
            WITH base AS (
                SELECT {seq_id} AS seq_id, {species_key} AS species_key
                FROM {input_sql}
            )
            SELECT count(*)::BIGINT
            FROM base
            JOIN read_parquet('{seq_cluster_path}') sc USING (seq_id)
            JOIN cluster_flags cf USING (cluster_id)
            JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
            WHERE ss.split = 'drop'
              AND cf.n_test > 0
              AND base.species_key NOT IN (SELECT species_key FROM heldout_species);
            """
        ).fetchone()[0]
    )
    # Same audit as above but over DISTINCT seq_ids instead of rows.
    dropped_seen_seqids_mixed = int(
        con.execute(
            f"""
            WITH base AS (
                SELECT DISTINCT {seq_id} AS seq_id, {species_key} AS species_key
                FROM {input_sql}
            )
            SELECT count(*)::BIGINT
            FROM base
            JOIN read_parquet('{seq_cluster_path}') sc USING (seq_id)
            JOIN cluster_flags cf USING (cluster_id)
            JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
            WHERE ss.split = 'drop'
              AND cf.n_test > 0
              AND base.species_key NOT IN (SELECT species_key FROM heldout_species);
            """
        ).fetchone()[0]
    )
    # Audit: exact proteins appearing in more than one species.
    same_protein_multi_species = int(
        con.execute(
            f"""
            WITH base AS (
                SELECT DISTINCT
                    {protein_norm} AS protein_norm,
                    {species_key} AS species_key
                FROM {input_sql}
                WHERE length({protein_norm}) > 0
                  AND {species_key} != ''
            )
            SELECT count(*)::BIGINT
            FROM (
                SELECT protein_norm
                FROM base
                GROUP BY protein_norm
                HAVING count(DISTINCT species_key) > 1
            );
            """
        ).fetchone()[0]
    )
    # Audit: exact proteins crossing the heldout/seen species boundary.
    same_protein_cross_holdout = int(
        con.execute(
            f"""
            WITH base AS (
                SELECT DISTINCT
                    {protein_norm} AS protein_norm,
                    {species_key} AS species_key
                FROM {input_sql}
                WHERE length({protein_norm}) > 0
                  AND {species_key} != ''
            )
            SELECT count(*)::BIGINT
            FROM (
                SELECT protein_norm
                FROM base
                GROUP BY protein_norm
                HAVING count(DISTINCT CASE WHEN species_key IN (SELECT species_key FROM heldout_species) THEN species_key END) > 0
                   AND count(DISTINCT CASE WHEN species_key NOT IN (SELECT species_key FROM heldout_species) THEN species_key END) > 0
            );
            """
        ).fetchone()[0]
    )

    # Assemble the JSON-serializable audit report.
    report = {
        "parameters": {
            "input_glob": args.input_glob,
            "heldout_test_glob": args.heldout_test_glob,
            "seq_cluster_parquet": str(seq_cluster),
            "seq_split_parquet": str(seq_split),
            "seq_space": args.seq_space,
            "species_key_mode": args.species_key_mode,
            "limit_files": int(args.limit_files),
        },
        "split_seq_ids": split_seq_ids,
        "split_rows": split_rows,
        "verification": {
            "clusters_spanning_splits_excluding_drop": bad_clusters,
            "test_rows_with_seen_species": bad_test,
            "val_species_not_in_train": bad_val_species,
            "exact_protein_overlap_train_val": protein_overlap_train_val,
            "exact_protein_overlap_train_test": protein_overlap_train_test,
        },
        "audit": {
            "mixed_test_clusters": mixed_test_clusters,
            "exact_protein_cross_holdout_seen_groups": exact_holdout_seen_conflicts,
            "dropped_seen_rows_from_exact_protein_holdout_overlap": dropped_seen_rows_exact_holdout,
            "dropped_rows_from_exact_protein_train_overlap": dropped_val_rows_exact_train,
            "dropped_seen_rows_from_mixed_test_clusters": dropped_seen_rows_mixed,
            "dropped_seen_seqids_from_mixed_test_clusters": dropped_seen_seqids_mixed,
            "same_protein_multi_species_exact_matches": same_protein_multi_species,
            "same_protein_cross_holdout_species_exact_matches": same_protein_cross_holdout,
        },
    }

    if args.report_json:
        report_path = Path(args.report_json)
        report_path.parent.mkdir(parents=True, exist_ok=True)
        with open(report_path, "w", encoding="utf-8") as f:
            json.dump(report, f, indent=2, sort_keys=True)
        print(f"Wrote audit report: {report_path}")

    # Any non-zero hard-constraint count is a verification failure.
    if (
        bad_clusters != 0
        or bad_test != 0
        or bad_val_species != 0
        or protein_overlap_train_val != 0
        or protein_overlap_train_test != 0
    ):
        raise SystemExit("Verification FAILED (see counts above).")
    print("Verification OK.")
|
| 1183 |
+
|
| 1184 |
+
|
| 1185 |
+
def build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser for the data_v3 resplit pipeline.

    Subcommands: make-fasta, mmseqs-cluster, make-seq-cluster, make-seq-split,
    write-data-v3, verify, and 'all' (which chains every step in order).
    Each subparser sets ``func`` so ``main`` can dispatch via ``args.func``.
    """
    ap = argparse.ArgumentParser(
        description="Resplit data_v2 to data_v3_rebuild using MMseqs2 protein clustering."
    )
    sub = ap.add_subparsers(dest="cmd", required=True)

    # --- make-fasta: parquet shards → MMseqs input FASTA ---
    p = sub.add_parser("make-fasta", help="Generate MMseqs FASTA from parquet shards.")
    p.add_argument("--input-glob", type=str, default="data_v2/*/*.parquet")
    p.add_argument("--output-fasta", type=str, default="data_v3_rebuild/_work/mmseqs_input.fasta")
    p.add_argument("--seq-space", type=str, choices=["protein", "cds"], default="protein")
    p.add_argument(
        "--max-input-seq-len",
        type=int,
        default=0,
        help="Drop sequences longer than this from the MMseqs input FASTA (0=use seq-space default).",
    )
    p.add_argument("--threads", type=int, default=32)
    p.add_argument("--limit-files", type=int, default=0, help="Only read the first N parquet files (0=all)")
    p.add_argument("--limit-rows", type=int, default=0, help="Debug: limit number of rows written (0=all)")
    p.set_defaults(func=cmd_make_fasta)

    # --- mmseqs-cluster: run createdb + linclust ---
    p = sub.add_parser("mmseqs-cluster", help="Run MMseqs2 createdb+linclust and emit clustering TSV.")
    p.add_argument("--mmseqs", type=str, default=_default_mmseqs_path())
    p.add_argument("--fasta", type=str, default="data_v3_rebuild/_work/mmseqs_input.fasta")
    p.add_argument("--workdir", type=str, default="data_v3_rebuild/_work/mmseqs")
    p.add_argument("--seq-space", type=str, choices=["protein", "cds"], default="protein")
    p.add_argument("--threads", type=int, default=32)
    p.add_argument("--min-seq-id", type=float, default=0.90)
    p.add_argument("-c", "--coverage", type=float, default=0.80)
    p.add_argument("--cov-mode", type=int, default=2, help="2=enforce representative/query coverage")
    p.add_argument("--cluster-mode", type=int, default=2, help="2=greedy clustering by sequence length")
    p.add_argument("--max-seq-len", type=int, default=200000)
    p.add_argument(
        "--kmer-per-seq-scale",
        type=float,
        default=None,
        help="Optional MMseqs2 override; leave empty to use MMseqs defaults.",
    )
    p.add_argument("--split-memory-limit", type=str, default="", help="e.g. 120G (empty=use MMseqs default)")
    # Paired flags sharing one dest; --remove-tmp-files is the default.
    g = p.add_mutually_exclusive_group()
    g.add_argument(
        "--remove-tmp-files",
        dest="remove_tmp_files",
        action="store_true",
        default=True,
        help="Remove MMseqs2 tmp files (default).",
    )
    g.add_argument(
        "--keep-tmp-files",
        dest="remove_tmp_files",
        action="store_false",
        help="Keep MMseqs2 tmp files.",
    )
    p.add_argument("--overwrite", action="store_true")
    p.set_defaults(func=cmd_mmseqs_cluster)

    # --- make-seq-cluster: MMseqs TSV → seq_id→cluster_id parquet ---
    p = sub.add_parser("make-seq-cluster", help="Convert MMseqs TSV to parquet mapping seq_id→cluster_id.")
    p.add_argument("--cluster-tsv", type=str, default="data_v3_rebuild/_work/mmseqs/clu.tsv")
    p.add_argument("--output-parquet", type=str, default="data_v3_rebuild/_work/seq_cluster.parquet")
    p.add_argument("--threads", type=int, default=32)
    p.set_defaults(func=cmd_make_seq_cluster)

    # --- make-seq-split: assign each seq_id a split label ---
    p = sub.add_parser(
        "make-seq-split",
        help="Create seq_id→{train,val,test,drop} using cluster assignments and heldout species.",
    )
    p.add_argument("--input-glob", type=str, default="data_v2/*/*.parquet")
    p.add_argument("--heldout-test-glob", type=str, default="data_v2/test/*.parquet")
    p.add_argument("--species-key-mode", type=str, choices=["binomial", "taxon"], default="binomial")
    p.add_argument("--seq-cluster-parquet", type=str, default="data_v3_rebuild/_work/seq_cluster.parquet")
    p.add_argument("--cluster-split-parquet", type=str, default="data_v3_rebuild/_work/cluster_split.parquet")
    p.add_argument("--output-parquet", type=str, default="data_v3_rebuild/_work/seq_split.parquet")
    p.add_argument("--val-frac", type=float, default=0.01)
    p.add_argument("--seed", type=int, default=13)
    p.add_argument("--threads", type=int, default=32)
    p.add_argument("--limit-files", type=int, default=0, help="Only read the first N parquet files (0=all)")
    p.set_defaults(func=cmd_make_seq_split)

    # --- write-data-v3: materialize split parquet directories ---
    p = sub.add_parser("write-data-v3", help="Write data_v3 parquet directories from seq_split mapping.")
    p.add_argument("--input-glob", type=str, default="data_v2/*/*.parquet")
    p.add_argument("--seq-cluster-parquet", type=str, default="data_v3_rebuild/_work/seq_cluster.parquet")
    p.add_argument("--seq-split-parquet", type=str, default="data_v3_rebuild/_work/seq_split.parquet")
    p.add_argument("--output-root", type=str, default="data_v3_rebuild")
    p.add_argument("--num-shards", type=int, default=256, help="Partition each split into N shards")
    p.add_argument("--threads", type=int, default=32)
    p.add_argument("--limit-files", type=int, default=0, help="Only read the first N parquet files (0=all)")
    # Paired flags sharing one dest; representatives-only is the default.
    g = p.add_mutually_exclusive_group()
    g.add_argument(
        "--representatives-only",
        dest="representatives_only",
        action="store_true",
        default=True,
        help="Write only one representative seq_id per MMseqs cluster (default).",
    )
    g.add_argument(
        "--all-cluster-members",
        dest="representatives_only",
        action="store_false",
        help="Write all seq_ids assigned to the split instead of one representative per cluster.",
    )
    p.add_argument("--overwrite", action="store_true")
    p.set_defaults(func=cmd_write_data_v3)

    # --- verify: leakage/species-constraint checks + audit report ---
    p = sub.add_parser("verify", help="Verify leakage/species constraints and write an audit report.")
    p.add_argument("--input-glob", type=str, default="data_v2/*/*.parquet")
    p.add_argument("--heldout-test-glob", type=str, default="data_v2/test/*.parquet")
    p.add_argument("--species-key-mode", type=str, choices=["binomial", "taxon"], default="binomial")
    p.add_argument("--seq-space", type=str, choices=["protein", "cds"], default="protein")
    p.add_argument("--seq-cluster-parquet", type=str, default="data_v3_rebuild/_work/seq_cluster.parquet")
    p.add_argument("--seq-split-parquet", type=str, default="data_v3_rebuild/_work/seq_split.parquet")
    p.add_argument("--report-json", type=str, default="data_v3_rebuild/_work/split_report.json")
    p.add_argument("--threads", type=int, default=32)
    p.add_argument("--limit-files", type=int, default=0, help="Only read the first N parquet files (0=all)")
    p.set_defaults(func=cmd_verify)

    # --- all: chain the whole pipeline with shared arguments ---
    p = sub.add_parser(
        "all",
        help="Run the full pipeline: make-fasta → mmseqs-cluster → make-seq-cluster → make-seq-split → write-data-v3 → verify.",
    )
    p.add_argument("--input-glob", type=str, default="data_v2/*/*.parquet")
    p.add_argument("--heldout-test-glob", type=str, default="data_v2/test/*.parquet")
    p.add_argument("--output-root", type=str, default="data_v3_rebuild")
    p.add_argument("--seq-space", type=str, choices=["protein", "cds"], default="protein")
    p.add_argument("--species-key-mode", type=str, choices=["binomial", "taxon"], default="binomial")
    p.add_argument(
        "--max-input-seq-len",
        type=int,
        default=0,
        help="Drop sequences longer than this from the MMseqs input FASTA (0=use seq-space default).",
    )
    p.add_argument("--threads", type=int, default=32)
    p.add_argument("--limit-files", type=int, default=0, help="Only read the first N parquet files (0=all)")
    p.add_argument("--num-shards", type=int, default=256)
    g = p.add_mutually_exclusive_group()
    g.add_argument(
        "--representatives-only",
        dest="representatives_only",
        action="store_true",
        default=True,
        help="Write only one representative seq_id per MMseqs cluster (default).",
    )
    g.add_argument(
        "--all-cluster-members",
        dest="representatives_only",
        action="store_false",
        help="Write all seq_ids assigned to the split instead of one representative per cluster.",
    )
    p.add_argument("--mmseqs", type=str, default=_default_mmseqs_path())
    p.add_argument("--min-seq-id", type=float, default=0.90)
    p.add_argument("-c", "--coverage", type=float, default=0.80)
    p.add_argument("--cov-mode", type=int, default=2)
    p.add_argument("--cluster-mode", type=int, default=2)
    p.add_argument("--max-seq-len", type=int, default=200000)
    p.add_argument("--kmer-per-seq-scale", type=float, default=None)
    p.add_argument("--split-memory-limit", type=str, default="")
    p.add_argument("--val-frac", type=float, default=0.01)
    p.add_argument("--seed", type=int, default=13)
    p.add_argument("--overwrite", action="store_true")

    def _run_all(a: argparse.Namespace) -> None:
        # Drive each stage with a synthetic Namespace so every cmd_* sees
        # exactly the arguments its own subparser would have produced.
        out_root = Path(a.output_root)
        work = out_root / "_work"
        fasta = work / "mmseqs_input.fasta"
        mmseqs_work = work / "mmseqs"
        cluster_tsv = mmseqs_work / "clu.tsv"
        seq_cluster = work / "seq_cluster.parquet"
        cluster_split = work / "cluster_split.parquet"
        seq_split = work / "seq_split.parquet"
        report_json = work / "split_report.json"

        cmd_make_fasta(
            argparse.Namespace(
                input_glob=a.input_glob,
                output_fasta=str(fasta),
                seq_space=a.seq_space,
                max_input_seq_len=a.max_input_seq_len,
                threads=a.threads,
                limit_files=a.limit_files,
                limit_rows=0,
            )
        )
        cmd_mmseqs_cluster(
            argparse.Namespace(
                mmseqs=a.mmseqs,
                fasta=str(fasta),
                workdir=str(mmseqs_work),
                seq_space=a.seq_space,
                threads=a.threads,
                min_seq_id=a.min_seq_id,
                coverage=a.coverage,
                cov_mode=a.cov_mode,
                cluster_mode=a.cluster_mode,
                max_seq_len=a.max_seq_len,
                kmer_per_seq_scale=a.kmer_per_seq_scale,
                split_memory_limit=a.split_memory_limit,
                remove_tmp_files=True,
                overwrite=a.overwrite,
            )
        )
        cmd_make_seq_cluster(
            argparse.Namespace(
                cluster_tsv=str(cluster_tsv),
                output_parquet=str(seq_cluster),
                threads=a.threads,
            )
        )
        cmd_make_seq_split(
            argparse.Namespace(
                input_glob=a.input_glob,
                heldout_test_glob=a.heldout_test_glob,
                species_key_mode=a.species_key_mode,
                seq_cluster_parquet=str(seq_cluster),
                cluster_split_parquet=str(cluster_split),
                output_parquet=str(seq_split),
                val_frac=a.val_frac,
                seed=a.seed,
                threads=a.threads,
                limit_files=a.limit_files,
            )
        )
        cmd_write_data_v3(
            argparse.Namespace(
                input_glob=a.input_glob,
                seq_cluster_parquet=str(seq_cluster),
                seq_split_parquet=str(seq_split),
                output_root=str(out_root),
                num_shards=a.num_shards,
                threads=a.threads,
                limit_files=a.limit_files,
                representatives_only=a.representatives_only,
                overwrite=a.overwrite,
            )
        )
        cmd_verify(
            argparse.Namespace(
                input_glob=a.input_glob,
                heldout_test_glob=a.heldout_test_glob,
                species_key_mode=a.species_key_mode,
                seq_space=a.seq_space,
                seq_cluster_parquet=str(seq_cluster),
                seq_split_parquet=str(seq_split),
                report_json=str(report_json),
                threads=a.threads,
                limit_files=a.limit_files,
            )
        )

    p.set_defaults(func=_run_all)
    return ap
|
| 1434 |
+
|
| 1435 |
+
|
| 1436 |
+
def main(argv: Optional[List[str]] = None) -> int:
    """CLI entry point: parse *argv* (or sys.argv) and run the chosen subcommand.

    Each subparser registers its handler via ``set_defaults(func=...)``,
    so dispatch is just calling ``parsed.func``.  Always returns 0; the
    handlers raise on failure.
    """
    parser = build_parser()
    parsed = parser.parse_args(argv)
    parsed.func(parsed)
    return 0
|
| 1441 |
+
|
| 1442 |
+
|
| 1443 |
+
if __name__ == "__main__":
    # Script entry point: propagate main()'s return code as the process exit status.
    raise SystemExit(main())
|
sampling.py
ADDED
|
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
"""
|
| 3 |
+
Sampling script for generating codon sequences from trained CodonGPT models.
|
| 4 |
+
Inputs are prepared exactly like training:
|
| 5 |
+
- Species conditioning via SpeciesEmbeddingStore (fixed-size [B,Ds] or variable-length [B,Ls,Ds])
|
| 6 |
+
- Protein conditioning via raw AA strings (ESM-C tokenization happens inside the model)
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import argparse
|
| 10 |
+
import logging
|
| 11 |
+
import json
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from typing import List, Optional, Union
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
|
| 17 |
+
from src.sampler import CodonSampler
|
| 18 |
+
from src.dataset import SpeciesEmbeddingStore
|
| 19 |
+
|
| 20 |
+
logging.basicConfig(
|
| 21 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 22 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
| 23 |
+
level=logging.INFO,
|
| 24 |
+
)
|
| 25 |
+
logger = logging.getLogger("codongpt.sample")
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def parse_args():
    """Build and parse the CLI for sampling codon sequences.

    Returns the parsed ``argparse.Namespace``.  Several flags accept alias
    spellings (e.g. --model_dir, --num_samples, --taxon) that all map onto
    one canonical ``dest``.
    """
    p = argparse.ArgumentParser(description="Sample codon sequences from CodonGPT model")

    # Model
    p.add_argument("--model_path", "--model_dir", dest="model_path", type=str, required=True,
                   help="Path to trained model checkpoint dir")
    p.add_argument("--device", type=str, default="cuda", help="cuda or cpu")
    p.add_argument("--compile", action="store_true", help="torch.compile the model")

    # Species embeddings
    p.add_argument("--embeddings_dir", type=str, default=None,
                   help="Directory with precomputed variable-length species embeddings (optional; fallback to Qwen if missing/unknown)")
    p.add_argument("--strict_species_lookup", action="store_true",
                   help="When using --embeddings_dir, fail if any requested species name is not an exact key in species_vocab.json")
    p.add_argument("--taxonomy_db", type=str, default=None,
                   help="Optional path to taxonomy_database.json (from precompute) to enrich prompts")

    # Sampling batch size and count
    p.add_argument("--num_sequences", "--num_seq", "--num_samples", type=int, default=1, dest="num_sequences",
                   help="Number of sequences to generate in total")
    p.add_argument("--batch_size", type=int, default=None, help="Batch size for sampling loop")

    # Control mode and length
    p.add_argument("--control_mode", choices=["fixed", "variable"], default="fixed",
                   help="fixed: disallow EOS, generate exactly sequence_length codons; variable: allow EOS")
    p.add_argument("--sequence_length", type=int, default=None,
                   help="Number of CODONS to generate (used as max steps in variable mode). "
                        "If omitted and protein sequences are provided, set to min protein length.")

    # Conditioning (REQUIRED: species and protein)
    p.add_argument("--species", "--taxon", type=str, default=None, dest="species",
                   help="Species name (e.g., 'Homo sapiens'). Replicated if num_sequences>1.")
    p.add_argument("--species_list", type=str, nargs="+", default=None,
                   help="List of species names (must match num_sequences).")

    p.add_argument("--protein_seq", "--protein_sequence", type=str, default=None, dest="protein_seq",
                   help="Protein sequence (AA string). Replicated if num_sequences>1.")
    p.add_argument("--protein_file", type=str, default=None,
                   help="Path to FASTA-like file (each non-header line is a sequence). Must provide at least num_sequences.")

    # Sampling params
    p.add_argument("--temperature", type=float, default=1, help="Sampling temperature")
    p.add_argument("--top_k", type=int, default=50, help="Top-k")
    p.add_argument("--top_p", type=float, default=0.9, help="Top-p (nucleus)")
    p.add_argument("--enforce_translation", action="store_true", default=False,
                   help="Hard-mask codons to match the given protein AA at each position")
    p.add_argument("--seed", type=int, default=None)
    p.add_argument("--save_intermediate", action="store_true", help="Store intermediate token states")

    # Output
    p.add_argument("--output_file", type=str, default=None)
    p.add_argument("--output_format", type=str, default="fasta", choices=["fasta", "csv", "json"])

    # Misc
    p.add_argument("--quiet", action="store_true")
    return p.parse_args()
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def load_protein_sequences(file_path: str) -> List[str]:
    """Read protein sequences from a FASTA-like file.

    Header lines (starting with '>') are ignored; every other non-empty
    line is treated as a complete amino-acid sequence.
    """
    with open(file_path, "r") as handle:
        stripped = (raw.strip() for raw in handle)
        return [text for text in stripped if text and not text.startswith(">")]
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def setup_species_store(embeddings_dir: str) -> SpeciesEmbeddingStore:
    """Open the precomputed species-embedding store under *embeddings_dir*.

    ``pooling="sequence"`` is always requested; the store itself resolves
    whether sequence-format (variable-length) or fixed-size embeddings
    were written, so there is no guessing here.
    """
    store = SpeciesEmbeddingStore(embeddings_dir, pooling="sequence")
    return store
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def save_sequences(
    sequences: List[str],
    output_file: str,
    fmt: str,
    species: Optional[List[str]] = None,
    proteins: Optional[List[str]] = None,
    metadata: Optional[dict] = None,
):
    """Write generated DNA sequences to *output_file* in the requested format.

    Supported formats: "fasta" (headers carry species / protein length when
    available), "csv" (pandas), and "json" (payload with optional metadata).
    Species/protein annotations beyond ``len(sequences)`` are dropped.
    """
    n = len(sequences)

    if fmt == "fasta":
        # Assemble all records first, then write in one buffered call.
        records = []
        for idx, seq in enumerate(sequences):
            parts = [f">seq_{idx}"]
            if species and idx < len(species):
                parts.append(f"|species={species[idx]}")
            if proteins and idx < len(proteins):
                parts.append(f"|protein_len={len(proteins[idx])}")
            records.append("".join(parts) + f"\n{seq}\n")
        with open(output_file, "w") as f:
            f.write("".join(records))
    elif fmt == "csv":
        import pandas as pd
        columns = {"sequence": sequences}
        if species:
            columns["species"] = species[:n]
        if proteins:
            columns["protein_sequence"] = proteins[:n]
        pd.DataFrame(columns).to_csv(output_file, index=False)
    else:
        # json
        payload = {"sequences": sequences, "metadata": metadata or {}}
        if species:
            payload["species"] = species[:n]
        if proteins:
            payload["protein_sequences"] = proteins[:n]
        with open(output_file, "w") as f:
            json.dump(payload, f, indent=2)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def translate_dna_to_aa(dna_seq: str) -> str:
    """Translate a DNA coding sequence into amino acids (standard genetic code).

    Generalized to accept lower- or mixed-case input (normalized to upper
    case before lookup; upper-case input behaves exactly as before).
    Stop codons map to '*', unrecognized triplets to 'X', and a trailing
    partial codon is ignored.
    """
    genetic_code = {
        'TTT': 'F', 'TTC': 'F', 'TTA': 'L', 'TTG': 'L', 'TCT': 'S', 'TCC': 'S', 'TCA': 'S', 'TCG': 'S',
        'TAT': 'Y', 'TAC': 'Y', 'TAA': '*', 'TAG': '*', 'TGT': 'C', 'TGC': 'C', 'TGA': '*', 'TGG': 'W',
        'CTT': 'L', 'CTC': 'L', 'CTA': 'L', 'CTG': 'L', 'CCT': 'P', 'CCC': 'P', 'CCA': 'P', 'CCG': 'P',
        'CAT': 'H', 'CAC': 'H', 'CAA': 'Q', 'CAG': 'Q', 'CGT': 'R', 'CGC': 'R', 'CGA': 'R', 'CGG': 'R',
        'ATT': 'I', 'ATC': 'I', 'ATA': 'I', 'ATG': 'M', 'ACT': 'T', 'ACC': 'T', 'ACA': 'T', 'ACG': 'T',
        'AAT': 'N', 'AAC': 'N', 'AAA': 'K', 'AAG': 'K', 'AGT': 'S', 'AGC': 'S', 'AGA': 'R', 'AGG': 'R',
        'GTT': 'V', 'GTC': 'V', 'GTA': 'V', 'GTG': 'V', 'GCT': 'A', 'GCC': 'A', 'GCA': 'A', 'GCG': 'A',
        'GAT': 'D', 'GAC': 'D', 'GAA': 'E', 'GAG': 'E', 'GGT': 'G', 'GGC': 'G', 'GGA': 'G', 'GGG': 'G'
    }
    dna = dna_seq.upper()
    n_codons = len(dna) // 3  # floor: drop any trailing partial codon
    return ''.join(genetic_code.get(dna[3 * i:3 * i + 3], 'X') for i in range(n_codons))
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def report_token_accuracy(sequences: List[str], target_proteins: List[str]) -> None:
    """Log per-sequence amino-acid agreement between generated DNA and targets.

    Each generated DNA sequence is translated and compared position-wise
    against its target protein over the shorter of the two lengths.
    """
    for idx, dna in enumerate(sequences):
        # Reuse the final target when fewer targets than sequences were provided.
        target = target_proteins[idx] if idx < len(target_proteins) else target_proteins[-1]
        generated = translate_dna_to_aa(dna)
        length = min(len(generated), len(target))
        if length:
            matches = sum(a == b for a, b in zip(generated[:length], target[:length]))
            accuracy = matches / length
        else:
            matches = 0
            accuracy = 0.0
        logger.info(f"AA token accuracy seq_{idx+1}: {accuracy:.4f} ({matches}/{length})")
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def main():
    """End-to-end sampling driver: validate CLI args, prepare conditioning,
    run batched generation, and write/report the results.

    Fix vs. previous version: when --save_intermediate was set without
    --output_file, ``Path(args.output_file)`` was ``Path(None)`` and crashed;
    intermediates now fall back to a default filename in the CWD.
    """
    args = parse_args()

    if args.device == "cuda" and not torch.cuda.is_available():
        raise RuntimeError("CUDA requested but not available")

    if args.seed is not None:
        torch.manual_seed(int(args.seed))

    # Conditioning must be provided – same invariants as training
    have_species_names = bool(args.species_list) or bool(args.species)
    have_protein = bool(args.protein_file) or bool(args.protein_seq)
    if not have_species_names or not have_protein:
        raise ValueError("Sampling requires BOTH species (names) and protein sequence(s).")

    # Species names list (explicit list wins over the single --species flag)
    if args.species_list:
        species_names = list(args.species_list)
    else:
        species_names = [str(args.species)]

    # Protein sequences list (file wins over the single --protein_seq flag)
    if args.protein_file:
        protein_sequences = load_protein_sequences(args.protein_file)
    else:
        protein_sequences = [str(args.protein_seq)]

    # Expand/reconcile counts: a single species/protein is replicated N times;
    # proteins may over-supply (truncated) but species must match N exactly.
    N = int(args.num_sequences)
    if len(species_names) == 1 and N > 1:
        species_names = species_names * N
    if len(protein_sequences) == 1 and N > 1:
        protein_sequences = protein_sequences * N

    if len(species_names) != N:
        raise ValueError(f"species count ({len(species_names)}) must equal num_sequences ({N})")
    if len(protein_sequences) < N:
        raise ValueError(f"protein sequences provided ({len(protein_sequences)}) less than num_sequences ({N})")
    if len(protein_sequences) > N:
        protein_sequences = protein_sequences[:N]

    # If no explicit sequence_length, use min protein length, so every sample has a valid AA at each fixed step
    if args.sequence_length is None:
        args.sequence_length = min(len(s) for s in protein_sequences)
        logger.info(f"Auto-set sequence_length to min protein length: {args.sequence_length} codons")

    if args.sequence_length <= 0:
        raise ValueError("sequence_length must be > 0")

    # Load species store if provided (preferred to exactly match training); unknown species will fallback to Qwen
    species_store = None
    if args.embeddings_dir:
        species_store = setup_species_store(args.embeddings_dir)
        logger.info(f"Loaded species store: {len(species_store.vocab)} species; Ds={species_store.Ds()}")
        if args.strict_species_lookup:
            unknown = sorted({name for name in species_names if name not in species_store.vocab})
            if unknown:
                preview = ", ".join(repr(x) for x in unknown[:5])
                more = "" if len(unknown) <= 5 else f" ... (+{len(unknown) - 5} more)"
                raise ValueError(
                    "strict species lookup failed; these names are not exact keys in species_vocab.json: "
                    f"{preview}{more}"
                )

    sampler = CodonSampler(
        model_path=args.model_path,
        device=args.device,
        compile_model=bool(args.compile),
        species_store=species_store,
        taxonomy_db_path=args.taxonomy_db,
    )

    # Batch loop: default batch size covers all N in one pass.
    batch_size = int(args.batch_size or N)
    all_sequences: List[str] = []
    all_intermediates = []

    total_batches = (N + batch_size - 1) // batch_size
    for start in range(0, N, batch_size):
        end = min(N, start + batch_size)
        bs = end - start
        batch_species = species_names[start:end]
        batch_proteins = protein_sequences[start:end]

        logger.info(f"Sampling batch {start//batch_size + 1}/{total_batches} (B={bs})")

        result = sampler.sample(
            num_sequences=bs,
            sequence_length=int(args.sequence_length),
            species=batch_species,
            protein_sequences=batch_proteins,
            control_mode=str(args.control_mode),
            temperature=float(args.temperature),
            top_k=int(args.top_k),
            top_p=float(args.top_p),
            seed=int(args.seed) if args.seed is not None else None,
            return_intermediate=bool(args.save_intermediate),
            progress_bar=not bool(args.quiet),
            enforce_translation=bool(args.enforce_translation),
        )

        seqs = result["sequences"]  # List[str]
        all_sequences.extend(seqs)
        if args.save_intermediate and "intermediate_states" in result:
            all_intermediates.append(result["intermediate_states"])

    logger.info(f"Generated {len(all_sequences)} sequences.")
    for i, seq in enumerate(all_sequences[:5]):
        logger.info(f"Sequence {i+1} ({len(seq)//3} codons): {seq[:60]}...")

    # Save outputs
    if args.output_file:
        meta = {
            "model_path": args.model_path,
            "temperature": args.temperature,
            "top_k": args.top_k,
            "top_p": args.top_p,
            "control_mode": args.control_mode,
            "sequence_length": int(args.sequence_length),
        }
        save_sequences(
            all_sequences,
            args.output_file,
            args.output_format,
            species=species_names,
            proteins=protein_sequences,
            metadata=meta,
        )
        logger.info(f"Saved sequences to {args.output_file}")

    # Report AA token accuracy when protein targets are given
    report_token_accuracy(all_sequences, protein_sequences)

    if args.save_intermediate and all_intermediates:
        if args.output_file:
            inter_file = Path(args.output_file).with_suffix("").as_posix() + "_intermediate.pt"
        else:
            # Previously crashed on Path(None); fall back to a default name in CWD.
            inter_file = "samples_intermediate.pt"
        torch.save(all_intermediates, inter_file)
        logger.info(f"Saved intermediate states to {inter_file}")

    logger.info("Sampling completed.")
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
if __name__ == "__main__":
    # Script entry point for direct invocation (python sampling.py ...).
    main()
|
slurm/rebuild_data_v3_cpu.sbatch
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Slurm batch job: rebuild the v3 data split on a CPU node.
# All knobs are environment variables with sane defaults, so behavior can be
# overridden at submit time via `sbatch --export=ALL,VAR=value ...`.
#SBATCH --partition=beacon
#SBATCH --qos=high
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=16
#SBATCH --mem=240G
#SBATCH --time=3-00:00:00
#SBATCH --job-name=data_v3_rebuild
#SBATCH --output=%x_%j.out
#SBATCH --error=%x_%j.err

set -euo pipefail

# Paths and tool locations (overridable).
REPO_ROOT=${REPO_ROOT:-/beacon-projects/codon-lm/HE-DLM}
PYTHON_BIN=${PYTHON_BIN:-/beacon-projects/codon-lm/miniconda3/envs/dna/bin/python}
MMSEQS_BIN=${MMSEQS_BIN:-$REPO_ROOT/MMseqs2/build/bin/mmseqs}
# Input/heldout parquet globs and split parameters.
INPUT_GLOB=${INPUT_GLOB:-data_v2/*/*.parquet}
HELDOUT_GLOB=${HELDOUT_GLOB:-data_v2/test/*.parquet}
SEQ_SPACE=${SEQ_SPACE:-protein}
SPECIES_KEY_MODE=${SPECIES_KEY_MODE:-binomial}
MAX_INPUT_SEQ_LEN=${MAX_INPUT_SEQ_LEN:-20000}
# MODE=pilot shrinks the run (see below); MODE=full is the default.
MODE=${MODE:-full}
OUTPUT_ROOT=${OUTPUT_ROOT:-data_v3_rebuild}
LIMIT_FILES=${LIMIT_FILES:-0}
NUM_SHARDS=${NUM_SHARDS:-256}
VAL_FRAC=${VAL_FRAC:-0.01}
THREADS=${THREADS:-${SLURM_CPUS_PER_TASK:-16}}
# MMseqs2 clustering parameters.
MIN_SEQ_ID=${MIN_SEQ_ID:-0.90}
COVERAGE=${COVERAGE:-0.80}
COV_MODE=${COV_MODE:-2}
CLUSTER_MODE=${CLUSTER_MODE:-2}
MAX_SEQ_LEN=${MAX_SEQ_LEN:-200000}
SPLIT_MEMORY_LIMIT=${SPLIT_MEMORY_LIMIT:-180G}
SEED=${SEED:-13}
OVERWRITE=${OVERWRITE:-1}

# Pilot mode: only override values the user left at their defaults.
if [[ "${MODE}" == "pilot" ]]; then
  if [[ "${LIMIT_FILES}" == "0" ]]; then
    LIMIT_FILES=4
  fi
  if [[ "${OUTPUT_ROOT}" == "data_v3_rebuild" ]]; then
    OUTPUT_ROOT=data_v3_pilot
  fi
  if [[ "${NUM_SHARDS}" == "256" ]]; then
    NUM_SHARDS=16
  fi
fi

cd "${REPO_ROOT}"

# Make the conda env's shared libraries visible to subprocesses (e.g. mmseqs).
export LD_LIBRARY_PATH="/beacon-projects/codon-lm/miniconda3/envs/dna/lib:/beacon-projects/codon-lm/miniconda3/lib:${LD_LIBRARY_PATH:-}"

# Ensure the mmseqs binary is executable (fresh checkouts may lose the x bit).
if [[ -f "${MMSEQS_BIN}" && ! -x "${MMSEQS_BIN}" ]]; then
  chmod u+x "${MMSEQS_BIN}"
fi

# Sanity check: fail fast if duckdb/pyarrow are missing from the env.
"${PYTHON_BIN}" - <<'PY'
import duckdb
import pyarrow

print("duckdb", duckdb.__version__)
print("pyarrow", pyarrow.__version__)
PY

# Assemble the resplit command as an array so arguments stay properly quoted.
CMD=(
  "${PYTHON_BIN}" resplit_data_v3.py all
  --input-glob "${INPUT_GLOB}"
  --heldout-test-glob "${HELDOUT_GLOB}"
  --output-root "${OUTPUT_ROOT}"
  --seq-space "${SEQ_SPACE}"
  --species-key-mode "${SPECIES_KEY_MODE}"
  --max-input-seq-len "${MAX_INPUT_SEQ_LEN}"
  --threads "${THREADS}"
  --limit-files "${LIMIT_FILES}"
  --num-shards "${NUM_SHARDS}"
  --mmseqs "${MMSEQS_BIN}"
  --min-seq-id "${MIN_SEQ_ID}"
  --coverage "${COVERAGE}"
  --cov-mode "${COV_MODE}"
  --cluster-mode "${CLUSTER_MODE}"
  --max-seq-len "${MAX_SEQ_LEN}"
  --val-frac "${VAL_FRAC}"
  --seed "${SEED}"
)

# Optional flags appended only when configured.
if [[ -n "${SPLIT_MEMORY_LIMIT}" ]]; then
  CMD+=(--split-memory-limit "${SPLIT_MEMORY_LIMIT}")
fi

if [[ "${OVERWRITE}" == "1" ]]; then
  CMD+=(--overwrite)
fi

# Echo the exact (shell-quoted) command into the job log, then run it.
printf 'Running command:'
printf ' %q' "${CMD[@]}"
printf '\n'
"${CMD[@]}"
|
slurm/submit_train_v3_h200_8x_chain.sh
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Submit SEGMENTS copies of the training sbatch script as a dependency chain:
# each segment starts only after the previous one terminates (afterany),
# so a time-limited job can be continued across allocations.

set -euo pipefail

cd /beacon-projects/codon-lm/HE-DLM

SEGMENTS=${SEGMENTS:-3}
SBATCH_SCRIPT=${SBATCH_SCRIPT:-slurm/train_v3_h200_8x_single.sbatch}

if [[ ! -f "${SBATCH_SCRIPT}" ]]; then
  echo "Missing sbatch script: ${SBATCH_SCRIPT}" >&2
  exit 1
fi

prev_job=""
for idx in $(seq 1 "${SEGMENTS}"); do
  # Build the sbatch flags once; only chained segments carry a dependency.
  submit_args=(--parsable)
  if [[ -n "${prev_job}" ]]; then
    submit_args+=(--dependency=afterany:"${prev_job}")
  fi
  jid=$(sbatch "${submit_args[@]}" "${SBATCH_SCRIPT}")
  echo "submitted segment=${idx} job_id=${jid} dependency=${prev_job:-none}"
  prev_job="${jid}"
done
|
slurm/train_v3_h200_8x_single.sbatch
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Single-node 8x H200 training entrypoint.
# Reserved single-node smoke-run example:
#   sbatch --time=00:45:00 \
#     --export=ALL,OUT_DIR=/beacon-projects/codon-lm/HE-DLM/outputs_v3_rep_h200_8x_reserved_smoke,MAX_STEPS=20,SAVE_STEPS=0,EVAL_INTERVAL=0 \
#     slurm/train_v3_h200_8x_single.sbatch
# Full-run example:
#   sbatch slurm/train_v3_h200_8x_single.sbatch
#
# Suggested W&B overrides:
#   sbatch --export=ALL,WANDB_PROJECT=he-dlm-v3-h200-8x,WANDB_NAME=he-dlm-v3-h200-8x-run1 \
#     slurm/train_v3_h200_8x_single.sbatch
# If the environment is still configured for offline logging, override at submit time:
#   sbatch --export=ALL,WANDB_MODE=online slurm/train_v3_h200_8x_single.sbatch
# This script is pinned to the reserved H200 allocation on ihccs210.
# Do not use QoS=reserved on any other node.
#SBATCH --job-name=train-v3-h200-8x
#SBATCH --partition=beacon
#SBATCH --qos=reserved
#SBATCH --reservation=heng-reservation
#SBATCH --nodelist=ihccs210
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --gres=gpu:nvidia_h200:8
#SBATCH --cpus-per-task=16
#SBATCH --mem=512G
#SBATCH --time=3-00:00:00
#SBATCH --output=%x_%j.out
#SBATCH --error=%x_%j.err

set -euo pipefail

# ~/.bashrc and conda activation reference unset variables; relax nounset
# around them, then restore strict mode.
set +u
source ~/.bashrc
conda activate dna
set -u

cd /beacon-projects/codon-lm/HE-DLM

# Data / output locations (overridable via --export).
TRAIN_DATA=${TRAIN_DATA:-/beacon-projects/codon-lm/HE-DLM/data_v3_rebuild/train}
VAL_DATA=${VAL_DATA:-/beacon-projects/codon-lm/HE-DLM/data_v3_rebuild/val}
EMBED_DIR=${EMBED_DIR:-/beacon-projects/codon-lm/HE-DLM/embeddings_v2}
OUT_DIR=${OUT_DIR:-/beacon-projects/codon-lm/HE-DLM/outputs_v3_rep_h200_8x_single_wd1e-4_bs48ga4}

# W&B identity derives from the output dir so resumed segments share a run.
WANDB_PROJECT=${WANDB_PROJECT:-he-dlm-v3-h200-8x}
WANDB_NAME=${WANDB_NAME:-$(basename "${OUT_DIR}")}
WANDB_RUN_ID=${WANDB_RUN_ID:-$(basename "${OUT_DIR}")}
WANDB_RESUME=${WANDB_RESUME:-allow}
WANDB_DIR=${WANDB_DIR:-${OUT_DIR}/wandb}

# Training hyperparameters and checkpoint/eval cadence.
NPROC_PER_NODE=${NPROC_PER_NODE:-8}
BATCH_SIZE=${BATCH_SIZE:-48}
GRAD_ACCUM=${GRAD_ACCUM:-4}
EVAL_BATCH_SIZE=${EVAL_BATCH_SIZE:-32}
WORKERS=${WORKERS:-0}
EPOCHS=${EPOCHS:-3}
LR=${LR:-7e-5}
WARMUP_RATIO=${WARMUP_RATIO:-0.1}
WEIGHT_DECAY=${WEIGHT_DECAY:-1e-4}
LOGGING_STEPS=${LOGGING_STEPS:-10}
SAVE_STEPS=${SAVE_STEPS:-500}
SAVE_TOTAL_LIMIT=${SAVE_TOTAL_LIMIT:-1000}
EVAL_INTERVAL=${EVAL_INTERVAL:-5000}
EVAL_STEPS=${EVAL_STEPS:-256}
TRAIN_SHUFFLE_BUFFER=${TRAIN_SHUFFLE_BUFFER:-8192}
VAL_SHUFFLE_BUFFER=${VAL_SHUFFLE_BUFFER:-0}
CKPT_RECENT_WINDOW_STEPS=${CKPT_RECENT_WINDOW_STEPS:-2000}
CKPT_RECENT_INTERVAL=${CKPT_RECENT_INTERVAL:-500}
CKPT_ARCHIVE_INTERVAL=${CKPT_ARCHIVE_INTERVAL:-1000}
RESUME_FROM=${RESUME_FROM:-auto}
MAX_STEPS=${MAX_STEPS:-}
MASTER_PORT=${MASTER_PORT:-29500}
GRAD_CKPT=${GRAD_CKPT:-0}

export WANDB_PROJECT WANDB_NAME WANDB_RUN_ID WANDB_RESUME WANDB_DIR
# NCCL / torch.distributed tuning; every value is overridable at submit time.
export NCCL_DEBUG=${NCCL_DEBUG:-WARN}
export TORCH_DISTRIBUTED_DEBUG=${TORCH_DISTRIBUTED_DEBUG:-DETAIL}
export NCCL_P2P_DISABLE=${NCCL_P2P_DISABLE:-0}
export NCCL_IB_DISABLE=${NCCL_IB_DISABLE:-1}
export NCCL_NET_GDR_LEVEL=${NCCL_NET_GDR_LEVEL:-0}
export NCCL_ASYNC_ERROR_HANDLING=${NCCL_ASYNC_ERROR_HANDLING:-1}
export NCCL_SHM_DISABLE=${NCCL_SHM_DISABLE:-1}
export NCCL_CUMEM_HOST_ENABLE=${NCCL_CUMEM_HOST_ENABLE:-1}
export CUDA_DEVICE_MAX_CONNECTIONS=${CUDA_DEVICE_MAX_CONNECTIONS:-1}

mkdir -p "${OUT_DIR}" "${WANDB_DIR}"

# Fail fast on missing inputs before burning the GPU allocation.
if [[ ! -d "${TRAIN_DATA}" ]]; then
  echo "Missing train data dir: ${TRAIN_DATA}" >&2
  exit 1
fi
if [[ ! -d "${VAL_DATA}" ]]; then
  echo "Missing val data dir: ${VAL_DATA}" >&2
  exit 1
fi
if [[ ! -f "${EMBED_DIR}/species_vocab.json" ]]; then
  echo "Missing embeddings vocab: ${EMBED_DIR}/species_vocab.json" >&2
  exit 1
fi

# Log the effective configuration into the job output for reproducibility.
echo "HOST=$(hostname)"
echo "TRAIN_DATA=${TRAIN_DATA}"
echo "VAL_DATA=${VAL_DATA}"
echo "EMBED_DIR=${EMBED_DIR}"
echo "OUT_DIR=${OUT_DIR}"
echo "WANDB_PROJECT=${WANDB_PROJECT} WANDB_NAME=${WANDB_NAME} WANDB_RUN_ID=${WANDB_RUN_ID} WANDB_RESUME=${WANDB_RESUME} WANDB_MODE=${WANDB_MODE:-unset}"
echo "BATCH_SIZE=${BATCH_SIZE} GRAD_ACCUM=${GRAD_ACCUM} EVAL_BATCH_SIZE=${EVAL_BATCH_SIZE} NPROC_PER_NODE=${NPROC_PER_NODE}"
echo "WEIGHT_DECAY=${WEIGHT_DECAY} SAVE_STEPS=${SAVE_STEPS} EVAL_INTERVAL=${EVAL_INTERVAL} MAX_STEPS=${MAX_STEPS:-unset}"
echo "NCCL_P2P_DISABLE=${NCCL_P2P_DISABLE} NCCL_IB_DISABLE=${NCCL_IB_DISABLE} NCCL_SHM_DISABLE=${NCCL_SHM_DISABLE} NCCL_CUMEM_HOST_ENABLE=${NCCL_CUMEM_HOST_ENABLE}"

# GPU diagnostics; best-effort (|| true) so a missing tool doesn't kill the job.
echo "=== GPU inventory ==="
nvidia-smi --query-gpu=index,name,memory.total,driver_version --format=csv,noheader || true
echo "=== GPU topology ==="
nvidia-smi topo -m || true
echo "=== NVLink status ==="
nvidia-smi nvlink -s || true

# Build the torchrun invocation as an array to keep quoting intact.
# NOTE(review): --master_port alongside --standalone — torchrun's standalone
# mode manages its own rendezvous; confirm the port flag is actually needed.
CMD=(
  torchrun
  --standalone
  --nproc_per_node "${NPROC_PER_NODE}"
  --master_port "${MASTER_PORT}"
  train.py
  --train_data "${TRAIN_DATA}"
  --val_data "${VAL_DATA}"
  --embeddings_dir "${EMBED_DIR}"
  --output_dir "${OUT_DIR}"
  --fsdp
  --bf16
  --attn mha
  --hidden 750
  --layers 20
  --heads 15
  --mlp_ratio 3.2
  --batch_size "${BATCH_SIZE}"
  --grad_accum "${GRAD_ACCUM}"
  --eval_batch_size "${EVAL_BATCH_SIZE}"
  --epochs "${EPOCHS}"
  --workers "${WORKERS}"
  --warmup_ratio "${WARMUP_RATIO}"
  --lr "${LR}"
  --weight_decay "${WEIGHT_DECAY}"
  --train_shuffle_buffer "${TRAIN_SHUFFLE_BUFFER}"
  --val_shuffle_buffer "${VAL_SHUFFLE_BUFFER}"
  --logging_steps "${LOGGING_STEPS}"
  --save_steps "${SAVE_STEPS}"
  --save_total_limit "${SAVE_TOTAL_LIMIT}"
  --ckpt_recent_window_steps "${CKPT_RECENT_WINDOW_STEPS}"
  --ckpt_recent_interval "${CKPT_RECENT_INTERVAL}"
  --ckpt_archive_interval "${CKPT_ARCHIVE_INTERVAL}"
  --eval_interval "${EVAL_INTERVAL}"
  --eval_steps "${EVAL_STEPS}"
)

# Optional flags.
if [[ "${RESUME_FROM}" != "none" && -n "${RESUME_FROM}" ]]; then
  CMD+=(--resume_from "${RESUME_FROM}")
fi
if [[ -n "${MAX_STEPS}" ]]; then
  CMD+=(--max_steps "${MAX_STEPS}")
fi
if [[ "${GRAD_CKPT}" == "1" ]]; then
  CMD+=(--grad_ckpt)
fi

# Replace the shell with torchrun so Slurm signals reach the trainer directly.
exec "${CMD[@]}"
|
src/__init__.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
CodonGPT – conditional codon sequence generation (GPT-only).

Public package surface: tokenizer, model, training, sampling and data helpers.
"""

from .tokenizer import CodonTokenizer
from .models import (
    CodonGPT,
)
from .trainer import Trainer, TrainingArguments
from .sampler import CodonSampler, sample_sequences
from .dataset import (
    stage_collate_fn,
    create_precomputed_dataloaders,
)

__version__ = "0.1.0"

__all__ = [
    # Tokenizer
    "CodonTokenizer",
    # Models
    "CodonGPT",
    # Training
    "Trainer",
    "TrainingArguments",
    # Sampling
    "CodonSampler",
    "sample_sequences",
    # Data
    # `stage_collate_fn` is imported above, so export it too: keeping the
    # import list and `__all__` consistent avoids a dangling public name.
    "stage_collate_fn",
    "create_precomputed_dataloaders",
]
|
src/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (637 Bytes). View file
|
|
|
src/__pycache__/dataset.cpython-312.pyc
ADDED
|
Binary file (41 kB). View file
|
|
|
src/__pycache__/layers.cpython-312.pyc
ADDED
|
Binary file (21.5 kB). View file
|
|
|
src/__pycache__/models.cpython-312.pyc
ADDED
|
Binary file (25.8 kB). View file
|
|
|
src/__pycache__/sampler.cpython-312.pyc
ADDED
|
Binary file (36.1 kB). View file
|
|
|
src/__pycache__/tokenizer.cpython-312.pyc
ADDED
|
Binary file (17.1 kB). View file
|
|
|
src/__pycache__/trainer.cpython-312.pyc
ADDED
|
Binary file (65.7 kB). View file
|
|
|
src/dataset.py
ADDED
|
@@ -0,0 +1,833 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/dataset.py
|
| 2 |
+
"""
|
| 3 |
+
Production-ready dataset + dataloader utilities.
|
| 4 |
+
|
| 5 |
+
Rules (because we're adults):
|
| 6 |
+
- Data drives design. Inputs are rows with columns: ["cds_DNA", "protein_seq", "Taxon", (optional) "RefseqID"].
|
| 7 |
+
- Output per sample is a tiny dict the model actually needs. Nothing else.
|
| 8 |
+
- We stream Parquet by row groups, CSV by chunks. No full-file pandas nonsense on big data.
|
| 9 |
+
- We shard by (FSDP rank × dataloader worker). No DistributedSampler needed.
|
| 10 |
+
- We do a simple streaming shuffle buffer for train. Good enough. No fancy "epoch managers".
|
| 11 |
+
|
| 12 |
+
Fields emitted per sample (for collate_fn and trainer):
|
| 13 |
+
{
|
| 14 |
+
"species_name": str,
|
| 15 |
+
"species_id": int,
|
| 16 |
+
"protein_seq": str, # raw AA (ESM tokenized later)
|
| 17 |
+
"aa_len": int,
|
| 18 |
+
"codon_ids": List[int], # tokenized 3-mer ids + EOS at the end
|
| 19 |
+
"refseq_id": str,
|
| 20 |
+
"protein_refseq_id": str,
|
| 21 |
+
"control_mode": "fixed",
|
| 22 |
+
"meta": {"src": "parquet|csv", "file": basename, "row": int}
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
Invariants:
|
| 26 |
+
- cds_DNA length divisible by 3 after trimming to match protein length.
|
| 27 |
+
- DNA uses only ACGT (uppercase). If not, we skip the row. We don't "helpfully fix" broken data.
|
| 28 |
+
- We truncate both DNA and protein to the same min length (codon count).
|
| 29 |
+
- EOS appended to codon_ids; PAD is handled at collate time, not here.
|
| 30 |
+
|
| 31 |
+
Dependencies:
|
| 32 |
+
- pyarrow only if you read parquet. If it isn't installed and you pass parquet files, we fail loudly.
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
from __future__ import annotations
|
| 36 |
+
|
| 37 |
+
import os
|
| 38 |
+
import json
|
| 39 |
+
import glob
|
| 40 |
+
import random
|
| 41 |
+
import logging
|
| 42 |
+
import heapq
|
| 43 |
+
from typing import Dict, List, Any, Optional, Iterable, Tuple
|
| 44 |
+
from pathlib import Path
|
| 45 |
+
|
| 46 |
+
import numpy as np
|
| 47 |
+
import pandas as pd
|
| 48 |
+
import torch
|
| 49 |
+
from torch.utils.data import IterableDataset, Dataset, DataLoader, get_worker_info
|
| 50 |
+
|
| 51 |
+
try:
|
| 52 |
+
from tqdm.auto import tqdm as _tqdm
|
| 53 |
+
except Exception: # pragma: no cover - tqdm might be unavailable in minimal envs
|
| 54 |
+
_tqdm = None
|
| 55 |
+
|
| 56 |
+
logger = logging.getLogger(__name__)
|
| 57 |
+
|
| 58 |
+
# ------------------------------
|
| 59 |
+
# Species Embedding Store (kept simple and stable)
|
| 60 |
+
# ------------------------------
|
| 61 |
+
|
| 62 |
+
class SpeciesEmbeddingStore:
    """Memory-mapped store of per-species embeddings.

    Two on-disk layouts are supported:

    * fixed-size — ``species_metadata.json`` + ``species_embeddings.bin``,
      one row per species (shape ``(num_species, Ds)``).
    * legacy — ``species_index.json`` + ``species_tok_emb.bin``, holding
      variable-length per-species token embeddings addressed by an
      offset/length index.

    Unknown or out-of-range species ids resolve to an all-zero stub rather
    than raising, so lookups are total.
    """

    def __init__(self, embeddings_dir: str, dtype: str = "float32", pin_memory: bool = False, pooling: str = "last"):
        # `pin_memory` is stored but not otherwise used in this class —
        # NOTE(review): presumably consumed by a caller; confirm.
        self.embeddings_dir = Path(embeddings_dir)
        self.pin_memory = bool(pin_memory)
        self.is_legacy = False
        self.pooling = pooling

        vocab_path = self.embeddings_dir / "species_vocab.json"
        if not vocab_path.exists():
            raise FileNotFoundError(f"Species vocabulary not found at {vocab_path}")
        with open(vocab_path, "r") as f:
            self.vocab: Dict[str, int] = json.load(f)

        meta_path = self.embeddings_dir / "species_metadata.json"
        new_emb_path = self.embeddings_dir / "species_embeddings.bin"
        legacy_index = self.embeddings_dir / "species_index.json"
        legacy_emb = self.embeddings_dir / "species_tok_emb.bin"

        # "sequence" pooling needs the variable-length (legacy) layout, so it
        # takes priority whenever the legacy files are present.
        if self.pooling == "sequence" and legacy_index.exists() and legacy_emb.exists():
            self.is_legacy = True
            self._load_legacy_format(dtype)
            return

        if meta_path.exists() and new_emb_path.exists():
            with open(meta_path, "r") as f:
                meta = json.load(f)
            self.num_species = int(meta["num_species"])
            self._ds = int(meta["embedding_dim"])
            self.embedding_type = str(meta.get("embedding_type", "fixed_size"))
            np_dtype = np.float16 if dtype == "float16" else np.float32
            # Read-only memmap: rows are looked up lazily, nothing loaded eagerly.
            self.embeddings = np.memmap(new_emb_path, dtype=np_dtype, mode="r", shape=(self.num_species, self._ds))
            self._np_dtype = np_dtype
            print(f"Loaded fixed-size species embeddings: {len(self.vocab)} species, Ds={self._ds}, dtype={self._np_dtype}")
        else:
            # Fixed-size files absent: fall back to the legacy layout.
            self.is_legacy = True
            self._load_legacy_format(dtype)

    def _load_legacy_format(self, dtype: str):
        """Load the legacy variable-length layout (offset/length indexed)."""
        index_path = self.embeddings_dir / "species_index.json"
        if not index_path.exists():
            raise FileNotFoundError(f"Species index not found at {index_path}")
        with open(index_path, "r") as f:
            raw_index = json.load(f)
        # JSON object keys are always strings; keep lookups string-keyed.
        self.index: Dict[str, Dict[str, int]] = {str(k): v for k, v in raw_index.items()}

        meta_path = self.embeddings_dir / "metadata.json"
        file_dtype = dtype
        if meta_path.exists():
            with open(meta_path, "r") as f:
                meta = json.load(f)
            self._ds = int(meta.get("embedding_dim", 1024))
            # The file's recorded dtype wins over the caller's request.
            file_dtype = str(meta.get("dtype", dtype)).lower()
        else:
            # No metadata file: assume the historical default width.
            self._ds = 1024

        emb_path = self.embeddings_dir / "species_tok_emb.bin"
        if not emb_path.exists():
            raise FileNotFoundError(f"Species embeddings not found at {emb_path}")

        np_dtype = np.float16 if file_dtype == "float16" else np.float32
        itemsize = np.dtype(np_dtype).itemsize
        file_bytes = os.path.getsize(emb_path)
        # Sanity check: the flat file must hold a whole number of Ds-wide rows.
        if file_bytes % (self._ds * itemsize) != 0:
            raise ValueError(f"Emb file size {file_bytes} not divisible by Ds*itemsize ({self._ds}*{itemsize})")
        total_tokens = file_bytes // (self._ds * itemsize)

        self.embeddings = np.memmap(emb_path, dtype=np_dtype, mode="r", shape=(total_tokens, self._ds))
        self._np_dtype = np_dtype
        self.num_species = len(self.vocab)
        print(f"[LEGACY] variable-length embeddings: {len(self.vocab)} species, {total_tokens} tokens total, Ds={self._ds}.")

    def load_vocab(self) -> Dict[str, int]:
        """Return a defensive copy of the species-name -> id vocabulary."""
        return self.vocab.copy()

    def _deterministic_stub(self, length: Optional[int] = None) -> torch.FloatTensor:
        """All-zero embedding used for unknown/out-of-range species ids.

        Shape is ``(1, length, Ds)`` in legacy mode when *length* is truthy,
        otherwise ``(1, Ds)``.
        """
        if self.is_legacy and length:
            t = torch.zeros(1, length, self._ds, dtype=torch.float32)
        else:
            t = torch.zeros(1, self._ds, dtype=torch.float32)
        return t

    def get(self, species_id: int) -> torch.FloatTensor:
        """Fetch one species embedding as float32.

        Returns ``(1, Ds)`` in fixed-size mode and ``(1, L, Ds)`` in legacy
        mode; unknown ids yield a zero stub instead of raising.
        """
        if not self.is_legacy:
            if species_id < 0 or species_id >= getattr(self, "num_species", 0):
                return self._deterministic_stub()
            emb = self.embeddings[species_id]
            # .copy() detaches the row from the read-only memmap before wrapping.
            tensor = torch.from_numpy(np.asarray(emb).copy()).float().unsqueeze(0)
            return tensor
        else:
            sid = str(species_id)
            entry = self.index.get(sid)
            if entry is None:
                # Arbitrary 8-token stub length for unknown legacy species.
                return self._deterministic_stub(length=8)
            offset = int(entry["offset"]); length = int(entry["length"])
            view = self.embeddings[offset: offset + length]
            tensor = torch.from_numpy(np.asarray(view).copy()).float().unsqueeze(0)
            return tensor

    def batch_get(self, species_ids: List[int]) -> Any:
        """Batch lookup.

        Fixed-size mode returns a ``(B, Ds)`` tensor; legacy mode returns a
        ``(padded (B, Ls_max, Ds), lengths (B,))`` pair — callers must branch
        on the layout.
        """
        if torch.is_tensor(species_ids):
            species_ids = species_ids.detach().cpu().tolist()
        else:
            species_ids = [int(x) for x in species_ids]
        B = len(species_ids)
        if not self.is_legacy:
            batch_emb = torch.zeros(B, self._ds, dtype=torch.float32)
            for i, sid in enumerate(species_ids):
                batch_emb[i] = self.get(sid).squeeze(0)
            return batch_emb
        else:
            tensors = [self.get(sid) for sid in species_ids]
            lengths = torch.tensor([t.shape[1] for t in tensors], dtype=torch.long)
            Ls_max = int(lengths.max().item()) if lengths.numel() > 0 else 0
            # Zero-pad every sequence up to the longest one in the batch.
            padded = torch.zeros(B, Ls_max, self._ds, dtype=torch.float32)
            for i, t in enumerate(tensors):
                L = t.shape[1]; padded[i, :L] = t.squeeze(0)
            return padded, lengths

    def Ds(self) -> int:
        """Embedding width (last dimension of every returned embedding)."""
        return self._ds
|
| 182 |
+
|
| 183 |
+
def _is_parquet(path: str) -> bool:
|
| 184 |
+
lower = path.lower()
|
| 185 |
+
return lower.endswith(".parquet") or lower.endswith(".parq")
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def _is_csv(path: str) -> bool:
|
| 189 |
+
lower = path.lower()
|
| 190 |
+
return (
|
| 191 |
+
lower.endswith(".csv")
|
| 192 |
+
or lower.endswith(".tsv")
|
| 193 |
+
or lower.endswith(".csv.gz")
|
| 194 |
+
or lower.endswith(".tsv.gz")
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def _expand_paths(maybe_path_or_glob: str | List[str]) -> List[str]:
|
| 199 |
+
"""
|
| 200 |
+
Expand a path/glob or list of them into a sorted, de-duplicated list of files.
|
| 201 |
+
We prioritize parquet, then csv/tsv.
|
| 202 |
+
"""
|
| 203 |
+
paths: List[str] = []
|
| 204 |
+
if isinstance(maybe_path_or_glob, str):
|
| 205 |
+
p = Path(maybe_path_or_glob)
|
| 206 |
+
if p.is_dir():
|
| 207 |
+
# Scan directory for parquet first, then csv/tsv
|
| 208 |
+
paths.extend(sorted(str(x) for x in p.rglob("*.parquet")))
|
| 209 |
+
paths.extend(sorted(str(x) for x in p.rglob("*.parq")))
|
| 210 |
+
paths.extend(sorted(str(x) for x in p.rglob("*.csv")))
|
| 211 |
+
paths.extend(sorted(str(x) for x in p.rglob("*.tsv")))
|
| 212 |
+
paths.extend(sorted(str(x) for x in p.rglob("*.csv.gz")))
|
| 213 |
+
paths.extend(sorted(str(x) for x in p.rglob("*.tsv.gz")))
|
| 214 |
+
else:
|
| 215 |
+
paths = sorted(glob.glob(str(p)))
|
| 216 |
+
else:
|
| 217 |
+
for it in maybe_path_or_glob:
|
| 218 |
+
paths.extend(_expand_paths(it))
|
| 219 |
+
# Dedup while preserving order
|
| 220 |
+
seen = set()
|
| 221 |
+
out = []
|
| 222 |
+
for x in paths:
|
| 223 |
+
if x not in seen:
|
| 224 |
+
out.append(x)
|
| 225 |
+
seen.add(x)
|
| 226 |
+
if not out:
|
| 227 |
+
raise FileNotFoundError(f"No input files found for: {maybe_path_or_glob}")
|
| 228 |
+
return out
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def _dist_info() -> Tuple[int, int]:
|
| 232 |
+
"""
|
| 233 |
+
Returns (num_global_workers, global_worker_id)
|
| 234 |
+
where global_worker_id = rank * num_workers + worker_id.
|
| 235 |
+
"""
|
| 236 |
+
world_size = 1
|
| 237 |
+
rank = 0
|
| 238 |
+
try:
|
| 239 |
+
import torch.distributed as dist
|
| 240 |
+
|
| 241 |
+
if dist.is_available() and dist.is_initialized():
|
| 242 |
+
world_size = dist.get_world_size()
|
| 243 |
+
rank = dist.get_rank()
|
| 244 |
+
except Exception:
|
| 245 |
+
pass
|
| 246 |
+
wi = get_worker_info()
|
| 247 |
+
nw = wi.num_workers if wi else 1
|
| 248 |
+
wid = wi.id if wi else 0
|
| 249 |
+
return world_size * nw, rank * nw + wid
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
class _ResumeSkipProgress:
|
| 253 |
+
"""Lightweight progress helper for resume skips."""
|
| 254 |
+
|
| 255 |
+
def __init__(self, total: int, label: str):
|
| 256 |
+
self.total = int(max(0, total))
|
| 257 |
+
self.label = label
|
| 258 |
+
self.count = 0
|
| 259 |
+
self._bar = None
|
| 260 |
+
|
| 261 |
+
if self.total <= 0:
|
| 262 |
+
return
|
| 263 |
+
|
| 264 |
+
if _tqdm is not None:
|
| 265 |
+
self._bar = _tqdm(total=self.total, desc=label, unit="sample", dynamic_ncols=True, leave=False)
|
| 266 |
+
else:
|
| 267 |
+
logger.info("%s: skipping %d samples to reach resume cursor", label, self.total)
|
| 268 |
+
|
| 269 |
+
def update(self, n: int = 1):
|
| 270 |
+
if self.total <= 0:
|
| 271 |
+
return
|
| 272 |
+
self.count += int(n)
|
| 273 |
+
if self._bar is not None:
|
| 274 |
+
self._bar.update(n)
|
| 275 |
+
else:
|
| 276 |
+
if self.count == self.total or self.count % 10000 == 0:
|
| 277 |
+
logger.info("%s: skipped %d / %d", self.label, self.count, self.total)
|
| 278 |
+
|
| 279 |
+
def close(self):
|
| 280 |
+
if self.total <= 0:
|
| 281 |
+
return
|
| 282 |
+
if self._bar is not None:
|
| 283 |
+
self._bar.close()
|
| 284 |
+
logger.info("%s: resume skip finished (%d samples)", self.label, self.count)
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
class StreamSeqDataset(IterableDataset):
|
| 288 |
+
"""
|
| 289 |
+
Streaming dataset with **non-overlapping Parquet row-group sharding**.
|
| 290 |
+
|
| 291 |
+
- Accepts list of files (parquet and/or csv/tsv).
|
| 292 |
+
- **Parquet**: we enumerate (file, row_group) tasks and stride them across
|
| 293 |
+
the *global* worker id to avoid duplicates and to keep all ranks busy even
|
| 294 |
+
with few files.
|
| 295 |
+
- **CSV/TSV**: assigned at file granularity (one worker reads a file).
|
| 296 |
+
If you have only a few CSV files and many ranks, some ranks may get no CSV work.
|
| 297 |
+
(Parquet is the recommended format at scale.)
|
| 298 |
+
- CSV is read with pandas chunksize to keep memory usage sane.
|
| 299 |
+
- Each Parquet task reads exactly **one row group** into pandas.
|
| 300 |
+
|
| 301 |
+
Minimal resume support:
|
| 302 |
+
- set_resume_skip(N) skips N yielded samples across the worker's assigned tasks.
|
| 303 |
+
(Use a **per-rank** skip value in your trainer so multi-node resumes stay in lockstep.)
|
| 304 |
+
|
| 305 |
+
Output sample schema:
|
| 306 |
+
{
|
| 307 |
+
"species_name": str,
|
| 308 |
+
"species_id": int,
|
| 309 |
+
"protein_seq": str, # raw AA (ESM tokenized later)
|
| 310 |
+
"aa_len": int,
|
| 311 |
+
"codon_ids": List[int], # tokenized 3-mer ids + EOS at the end
|
| 312 |
+
"refseq_id": str,
|
| 313 |
+
"protein_refseq_id": str,
|
| 314 |
+
"control_mode": "fixed",
|
| 315 |
+
"meta": {"src": "parquet|csv", "file": basename, "row": int}
|
| 316 |
+
}
|
| 317 |
+
"""
|
| 318 |
+
|
| 319 |
+
# Canonical required columns. We also accept common aliases (e.g., 'taxon').
|
| 320 |
+
REQUIRED = ["cds_DNA", "protein_seq", "Taxon"]
|
| 321 |
+
|
| 322 |
+
    def __init__(
        self,
        files: List[str],
        tokenizer,
        species_vocab_path: str,
        unknown_species_id: int = 0,
        csv_chunksize: int = 200_000,
        shuffle_buffer: int = 0,
        seed: int = 1234,
        shard_across_ranks: bool = True,
    ):
        """Build a streaming dataset over parquet/csv shards.

        Args:
            files: Explicit list of input files (parquet and/or csv/tsv).
            tokenizer: Codon tokenizer exposing ``encode_codon_seq`` and an
                EOS id (via ``special_ids`` or ``_special_ids``).
            species_vocab_path: JSON file mapping species name -> integer id.
            unknown_species_id: Fallback id for species absent from the vocab.
            csv_chunksize: pandas ``read_csv`` chunk size (rows), clamped >= 1.
            shuffle_buffer: Streaming shuffle buffer size (0 disables).
            seed: RNG seed for shuffling.
            shard_across_ranks: When False, every rank iterates the full task
                list instead of a disjoint shard.
        """
        super().__init__()
        self.files = files
        self.tok = tokenizer
        with open(species_vocab_path, "r") as f:
            self.species_vocab: Dict[str, int] = json.load(f)
        self.unknown_species_id = int(unknown_species_id)
        self.csv_chunksize = int(max(1, csv_chunksize))
        self.shuffle_buffer = int(max(0, shuffle_buffer))
        self.seed = int(seed)
        # When False, every rank iterates over the full task list instead of
        # taking a disjoint shard. This keeps FSDP collectives aligned during
        # evaluation even if the validation dataset is smaller than WORLD_SIZE.
        self.shard_across_ranks = bool(shard_across_ranks)

        # Minimal resume cursor: samples to skip on the next pass plus the
        # bookkeeping that backs get_stream_position().
        self._resume_skip_n: int = 0
        self._offset_start: int = 0
        self._emitted: int = 0
|
| 351 |
+
|
| 352 |
+
# ---- resume cursor (minimal) ----
|
| 353 |
+
def set_resume_skip(self, n: int) -> None:
|
| 354 |
+
n = int(max(0, n))
|
| 355 |
+
self._resume_skip_n = n
|
| 356 |
+
self._offset_start = n
|
| 357 |
+
self._emitted = 0
|
| 358 |
+
|
| 359 |
+
def get_stream_position(self) -> int:
|
| 360 |
+
# Total yielded so far since dataset creation, including initial skip offset
|
| 361 |
+
return int(self._offset_start + self._emitted)
|
| 362 |
+
|
| 363 |
+
    # ---- core row-wise iterator on a pandas DataFrame ----
    def _iter_df(self, df: pd.DataFrame, src: str, file: str) -> Iterable[Dict[str, Any]]:
        """Validate, clean and tokenize one DataFrame, yielding sample dicts.

        Rows with non-ACGT DNA or zero shared length are silently dropped;
        DNA and protein are truncated to the same codon count before tokenizing.
        Raises ValueError when a required column is missing.
        """
        # Normalize common column aliases before validating.
        # Some shards use lowercase `taxon` instead of `Taxon`.
        if "Taxon" not in df.columns and "taxon" in df.columns:
            df = df.rename(columns={"taxon": "Taxon"})

        # Hard fail if required missing
        for c in self.REQUIRED:
            if c not in df.columns:
                raise ValueError(f"Input missing required column '{c}' in {file}")

        # Normalize & clean: keep only the columns we actually consume.
        df = df[self.REQUIRED + ([c for c in ["RefseqID"] if c in df.columns])]
        df["Taxon"] = df["Taxon"].astype(str).str.strip()
        df["protein_seq"] = df["protein_seq"].astype(str).str.strip().str.upper()
        df["cds_DNA"] = df["cds_DNA"].astype(str).str.strip().str.upper()

        # Filter DNA: ACGT only and length > 0 (broken rows are skipped, not fixed)
        ok_mask = (df["cds_DNA"].str.len() > 0) & df["cds_DNA"].str.fullmatch(r"[ACGT]+", na=False)
        df = df[ok_mask]
        if df.empty:
            return

        # Trim protein/DNA to shared min length (in codons)
        cds_codons = (df["cds_DNA"].str.len() // 3).astype(int)
        prot_len = df["protein_seq"].str.len().astype(int)
        min_len = np.minimum(cds_codons.values, prot_len.values)

        df = df.assign(__min_len=min_len)
        df = df[df["__min_len"] > 0]
        if df.empty:
            return

        # Species id map: unknown names fall back to unknown_species_id.
        def map_species(x: str) -> int:
            try:
                return int(self.species_vocab.get(x, self.unknown_species_id))
            except Exception:
                return self.unknown_species_id

        species_ids = [map_species(x) for x in df["Taxon"].tolist()]
        refseq_col = "RefseqID" if "RefseqID" in df.columns else None

        for i, (row_idx, row) in enumerate(df.iterrows()):
            ml = int(row["__min_len"])
            cds = row["cds_DNA"][: ml * 3]
            prot = row["protein_seq"][: ml]
            if (len(cds) // 3) != len(prot):
                # Defensive: should be unreachable after the trim above.
                continue

            # Tokenize DNA → 3-mer ids; append EOS
            codon_ids = self.tok.encode_codon_seq(cds, validate=False)
            codon_ids.append(
                # Support both public and private tokenizer attribute names.
                self.tok.special_ids.eos if hasattr(self.tok, "special_ids") else self.tok._special_ids.eos
            )

            species_id = species_ids[i]
            # Fall back to "<file stem>:<row>" when no RefseqID column exists.
            ref_id = row[refseq_col] if refseq_col else f"{Path(file).stem}:{int(row_idx)}"

            yield {
                "species_name": row["Taxon"],
                "species_id": int(species_id),
                "protein_seq": prot,
                "aa_len": len(prot),
                "codon_ids": codon_ids,
                "refseq_id": ref_id,
                "protein_refseq_id": ref_id,
                "control_mode": "fixed",
                "meta": {"src": src, "file": os.path.basename(file), "row": int(row_idx)},
            }
|
| 434 |
+
|
| 435 |
+
    # ---- Parquet helpers: enumerate row-group tasks & read one row group ----
    def _enumerate_tasks(self, files: List[str]) -> List[Tuple[str, str, Optional[int], int]]:
        """
        Return a task list of tuples:
          ("parquet", path, row_group_idx, weight) for each row group in each Parquet file
          ("csv", path, None, weight) for each CSV/TSV file

        The weight is a (possibly estimated) row count used for load balancing.

        Raises:
            ImportError: if parquet inputs are present but pyarrow is missing.
        """
        tasks: List[Tuple[str, str, Optional[int], int]] = []
        parquet_files = [f for f in files if _is_parquet(f)]
        csv_files = [f for f in files if _is_csv(f)]

        if parquet_files:
            try:
                import pyarrow.parquet as pq  # type: ignore
            except Exception as e:
                raise ImportError("pyarrow is required to read parquet files") from e

            for fp in parquet_files:
                pf = pq.ParquetFile(fp)
                nrg = int(pf.num_row_groups or 0)
                if nrg <= 0:
                    # Treat as single task if row groups unavailable (unusual)
                    total_rows = pf.metadata.num_rows if pf.metadata and pf.metadata.num_rows is not None else 1
                    tasks.append(("parquet", fp, 0, max(1, int(total_rows))))
                else:
                    for rg in range(nrg):
                        if pf.metadata is not None:
                            rg_meta = pf.metadata.row_group(rg)
                            num_rows = rg_meta.num_rows if rg_meta.num_rows is not None else 0
                        else:
                            num_rows = 0
                        # Weight floor of 1 so empty row groups still get scheduled.
                        tasks.append(("parquet", fp, rg, max(1, int(num_rows))))

        # CSV/TSV files remain file-level tasks
        for fp in csv_files:
            file_size = os.path.getsize(fp)
            # Assume ~256 bytes per record when estimating CSV row counts (empirical default)
            est_rows = max(1, int(file_size // 256))
            tasks.append(("csv", fp, None, est_rows))

        # Keep a deterministic order
        # (files are already sorted by _expand_paths)
        return tasks
|
| 478 |
+
|
| 479 |
+
@staticmethod
|
| 480 |
+
def _balanced_partition(tasks: List[Tuple[str, str, Optional[int], int]], groups: int) -> List[List[Tuple[str, str, Optional[int], int]]]:
|
| 481 |
+
if groups <= 1:
|
| 482 |
+
return [tasks]
|
| 483 |
+
if not tasks:
|
| 484 |
+
return [[] for _ in range(groups)]
|
| 485 |
+
|
| 486 |
+
# Greedy load balancing: assign heavier tasks first to the lightest bucket.
|
| 487 |
+
indexed = [(idx, kind, path, rg, weight) for idx, (kind, path, rg, weight) in enumerate(tasks)]
|
| 488 |
+
tasks_sorted = sorted(
|
| 489 |
+
indexed,
|
| 490 |
+
key=lambda entry: (entry[4], -entry[0]),
|
| 491 |
+
reverse=True,
|
| 492 |
+
)
|
| 493 |
+
|
| 494 |
+
heap: List[Tuple[int, int]] = [(0, bucket_idx) for bucket_idx in range(groups)]
|
| 495 |
+
heapq.heapify(heap)
|
| 496 |
+
buckets: List[List[Tuple[int, str, str, Optional[int], int]]] = [[] for _ in range(groups)]
|
| 497 |
+
|
| 498 |
+
for original_index, kind, path, rg, weight in tasks_sorted:
|
| 499 |
+
load, bucket_idx = heapq.heappop(heap)
|
| 500 |
+
buckets[bucket_idx].append((original_index, kind, path, rg, weight))
|
| 501 |
+
heapq.heappush(heap, (load + weight, bucket_idx))
|
| 502 |
+
|
| 503 |
+
partitions: List[List[Tuple[str, str, Optional[int], int]]] = []
|
| 504 |
+
for bucket in buckets:
|
| 505 |
+
bucket.sort(key=lambda entry: entry[0])
|
| 506 |
+
partitions.append([(kind, path, rg, weight) for (_idx, kind, path, rg, weight) in bucket])
|
| 507 |
+
return partitions
|
| 508 |
+
|
| 509 |
+
    def _parquet_rowgroup_iter(
        self, file: str, row_group_idx: int, cols_cache: Dict[str, List[str]]
    ) -> Iterable[Dict[str, Any]]:
        """Read exactly one Parquet row group and yield cleaned samples.

        *cols_cache* memoizes the per-file column subset so repeated row
        groups of the same file don't re-derive it. Delegates row cleaning
        and tokenization to ``_iter_df``.
        """
        import pyarrow.parquet as pq  # safe: checked in _enumerate_tasks
        pf = pq.ParquetFile(file)
        # Cache the column subset per file so we don't recompute
        if file not in cols_cache:
            names = set(pf.schema.names)
            cols: List[str] = []
            # Required columns, with alias support (notably Taxon vs taxon).
            # NOTE(review): a required column missing here is silently omitted;
            # _iter_df raises ValueError later — confirm this is intended.
            for c in self.REQUIRED:
                if c in names:
                    cols.append(c)
                    continue
                if c == "Taxon" and "taxon" in names:
                    cols.append("taxon")
                    continue
            # Optional debug id
            if "RefseqID" in names:
                cols.append("RefseqID")
            cols_cache[file] = cols
        cols = cols_cache[file]
        table = pf.read_row_group(row_group_idx, columns=cols)
        df = table.to_pandas(types_mapper=None)
        yield from self._iter_df(df, "parquet", file)
|
| 534 |
+
|
| 535 |
+
    def _csv_file_iter(self, file: str) -> Iterable[Dict[str, Any]]:
        """Stream one CSV/TSV file in chunks, yielding cleaned samples.

        ``dtype=str`` + ``keep_default_na=False`` keep sequences as raw strings
        (no NaN coercion). NOTE(review): ``.tsv`` files are read with pandas'
        default comma separator here — confirm TSV inputs are handled upstream.
        """
        # One worker owns this file (non-overlapping assignment)
        for chunk in pd.read_csv(file, chunksize=self.csv_chunksize, dtype=str, keep_default_na=False):
            yield from self._iter_df(chunk, "csv", file)
|
| 539 |
+
|
| 540 |
+
# ---- main iterator ----
|
| 541 |
+
    def __iter__(self):
        """Stream samples for this (rank, dataloader-worker) shard.

        Pipeline:
          1. Derive this worker's global shard id (``gid``) from the
             distributed rank and the local dataloader worker id.
          2. Split the per-rank resume-skip budget evenly across local
             workers so the per-worker skips sum to the per-rank target.
          3. Enumerate tasks (parquet row groups + csv files), balance them
             across shards, and fast-skip whole tasks whose row weights fit
             inside the remaining skip budget.
          4. Row-skip the remainder inside the first partial task, then yield
             samples — optionally through a reservoir-style shuffle buffer of
             size ``self.shuffle_buffer``.
        """
        wi = get_worker_info()
        # wi is None when iterating in the main process (num_workers == 0).
        num_workers = wi.num_workers if wi else 1
        worker_id = wi.id if wi else 0

        num_global, gid = _dist_info()
        if not self.shard_across_ranks:
            # Shard only across local workers; ignore the distributed world.
            num_global = max(1, num_workers)
            gid = worker_id

        workers_per_rank = max(1, num_workers)
        rank = gid // workers_per_rank if self.shard_across_ranks else 0
        world = max(1, num_global // workers_per_rank)

        # Each rank may have a non-zero per-rank resume skip. Split evenly across local
        # dataloader workers so the sum equals the per-rank target, then apply a fast
        # task-level skip to avoid row-by-row scans for huge cursors.
        per_rank_skip = int(self._resume_skip_n)
        base = per_rank_skip // max(1, workers_per_rank)
        rem = per_rank_skip % max(1, workers_per_rank)
        # The first `rem` workers take one extra row so the split is exact.
        local_skip_target = base + (1 if worker_id < rem else 0)
        progress: Optional[_ResumeSkipProgress] = None

        # Build the global task list (parquet row groups + csv files) and shard by gid
        tasks = self._enumerate_tasks(self.files)

        if tasks:
            partitions = self._balanced_partition(tasks, max(1, num_global))
            my_tasks_full = partitions[gid] if gid < len(partitions) else []
        else:
            my_tasks_full = []

        # Only worker 0 renders a progress bar to avoid duplicated output.
        if local_skip_target > 0 and worker_id == 0:
            label = (
                "resume skip" if world == 1 else f"resume skip (rank {rank}/{world})"
            )
            progress = _ResumeSkipProgress(local_skip_target, label)

        # Fast task-level skip: consume whole tasks when their weight is <= remaining skip
        # and only fall back to row-level skipping for the first partial task.
        skip_remaining = int(local_skip_target)
        start_idx = 0
        partial_task_idx = None
        partial_task_kind = None
        partial_task_path = None
        partial_task_rg = None
        if skip_remaining > 0 and my_tasks_full:
            for idx, (kind, path, rg, weight) in enumerate(my_tasks_full):
                w = int(weight) if weight is not None else 0
                if w <= 0:
                    # Unknown/empty weight: cannot fast-skip, leave for row-level skip.
                    continue
                if skip_remaining >= w:
                    # Whole task fits inside the budget: skip it entirely.
                    skip_remaining -= w
                    start_idx = idx + 1
                    if progress is not None:
                        progress.update(w)
                else:
                    # First task that only partially fits: remember it for
                    # row-level skipping below.
                    partial_task_idx = idx
                    partial_task_kind = kind
                    partial_task_path = path
                    partial_task_rg = rg
                    break

        # Slice my task list to start after any fully-skipped tasks
        my_tasks = [(kind, path, rg) for (kind, path, rg, _w) in my_tasks_full[start_idx:]]

        # Per-shard RNG so shuffling is deterministic yet distinct across shards.
        rng = random.Random(self.seed + gid)
        buffer: List[Dict[str, Any]] = []
        bufN = self.shuffle_buffer

        def _drain_buffer():
            # Flush any remaining buffered samples (shuffled when buffering is on).
            if not buffer:
                return
            if bufN > 0:
                rng.shuffle(buffer)
            for it in buffer:
                yield it
            buffer.clear()

        # Skip counter for resume cursor (row-level remainder after task skips)
        skipped = int(local_skip_target - skip_remaining)

        # Cache for per-file Parquet column selection
        cols_cache: Dict[str, List[str]] = {}

        try:
            # If we split a task, handle its partial row-level skip first
            if partial_task_idx is not None and skip_remaining > 0:
                kind = partial_task_kind
                path = partial_task_path
                rg = partial_task_rg
                if kind == "parquet":
                    assert rg is not None
                    row_iter = self._parquet_rowgroup_iter(path, int(rg), cols_cache)
                elif kind == "csv":
                    row_iter = self._csv_file_iter(path)
                else:
                    raise ValueError(f"Unknown task kind: {kind}")

                for sample in row_iter:
                    if skip_remaining > 0:
                        skip_remaining -= 1
                        skipped += 1
                        if progress is not None:
                            progress.update(1)
                        if skip_remaining == 0 and progress is not None:
                            progress.close()
                            progress = None
                        continue
                    # past the partial skip remainder, fall through to normal buffering/yield
                    if bufN <= 0:
                        self._emitted += 1
                        yield sample
                    else:
                        # Reservoir-style yield: swap a random element to the
                        # end and pop it once the buffer is full.
                        buffer.append(sample)
                        if len(buffer) >= bufN:
                            j = rng.randrange(len(buffer))
                            buffer[j], buffer[-1] = buffer[-1], buffer[j]
                            self._emitted += 1
                            yield buffer.pop()

            for (kind, path, rg) in my_tasks:
                if kind == "parquet":
                    assert rg is not None
                    row_iter = self._parquet_rowgroup_iter(path, int(rg), cols_cache)
                elif kind == "csv":
                    row_iter = self._csv_file_iter(path)
                else:
                    raise ValueError(f"Unknown task kind: {kind}")

                for sample in row_iter:
                    # Apply any remaining resume skip across the flattened stream
                    if skip_remaining > 0:
                        skip_remaining -= 1
                        skipped += 1
                        if progress is not None:
                            progress.update(1)
                        if skip_remaining == 0 and progress is not None:
                            # Finish the progress bar once we've consumed the target
                            progress.close()
                            progress = None
                        continue

                    if bufN <= 0:
                        self._emitted += 1
                        yield sample
                    else:
                        buffer.append(sample)
                        if len(buffer) >= bufN:
                            j = rng.randrange(len(buffer))
                            buffer[j], buffer[-1] = buffer[-1], buffer[j]
                            self._emitted += 1
                            yield buffer.pop()

            # Flush leftovers
            for it in _drain_buffer():
                self._emitted += 1
                yield it
        finally:
            if progress is not None:
                progress.close()
            if local_skip_target > 0:
                # Persist any remaining leftover skip (including partial progress) per worker copy
                self._resume_skip_n = max(local_skip_target - skipped, 0)
|
| 705 |
+
|
| 706 |
+
|
| 707 |
+
# ------------------------------
|
| 708 |
+
# Simple collate: end-only pad for codon stream, pass-through everything else
|
| 709 |
+
# ------------------------------
|
| 710 |
+
|
| 711 |
+
def stage_collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Collate dataset samples into a batch dict.

    Codon id lists are right-padded with the tokenizer pad id (0) to the
    longest sequence in the batch; protein sequences stay as raw strings so
    ESM can tokenize them downstream. EOS is already inside each codon list
    (appended by the dataset).
    """
    n = len(batch)
    if n == 0:
        return {}

    # species ids
    species_ids = torch.tensor(
        [int(sample.get("species_id", 0)) for sample in batch], dtype=torch.long
    )

    # raw protein sequences stay as list[str] (ESM handles tokenization)
    protein_seqs = [str(sample.get("protein_seq", "M")) for sample in batch]

    # Right-pad codon ids to the batch max length.
    codon_lists = [sample.get("codon_ids", []) for sample in batch]
    width = max(len(row) for row in codon_lists)
    pad_id = 0  # tokenizer.pad_token_id is 0 in our tokenizer.
    codon_ids = torch.full((n, width), pad_id, dtype=torch.long)
    for row_idx, row in enumerate(codon_lists):
        if row:
            codon_ids[row_idx, : len(row)] = torch.tensor(row, dtype=torch.long)

    collated: Dict[str, Any] = {
        "species_ids": species_ids,
        "protein_seqs": protein_seqs,
        "codon_ids": codon_ids,
        "control_mode": batch[0].get("control_mode", "fixed"),
    }

    # Optional passthroughs (presence checked on the first sample only).
    for key in ("refseq_id", "protein_refseq_id"):
        if key in batch[0]:
            collated[key] = [sample.get(key) for sample in batch]

    return collated
|
| 745 |
+
|
| 746 |
+
def _build_dataset(
    path_or_paths: str | List[str],
    tokenizer,
    species_vocab_path: str,
    shuffle_buffer: int,
    csv_chunksize: int,
    shard_across_ranks: bool = True,
) -> StreamSeqDataset:
    """Construct a StreamSeqDataset over the expanded file list.

    Fixed defaults: unknown species map to id 0 and the shard seed is 1234.
    """
    expanded = _expand_paths(path_or_paths)
    return StreamSeqDataset(
        files=expanded,
        tokenizer=tokenizer,
        species_vocab_path=species_vocab_path,
        unknown_species_id=0,
        csv_chunksize=csv_chunksize,
        shuffle_buffer=shuffle_buffer,
        seed=1234,
        shard_across_ranks=shard_across_ranks,
    )
|
| 765 |
+
|
| 766 |
+
|
| 767 |
+
def create_precomputed_dataloaders(
    train_path: str | List[str],
    val_path: Optional[str | List[str]],
    embeddings_dir: str,
    tokenizer,
    batch_size: int,
    num_workers: int = 4,
    species_pooling: str = "sequence",
    csv_chunksize: int = 200_000,
    train_shuffle_buffer: int = 8192,
    val_shuffle_buffer: int = 0,
) -> Tuple[DataLoader, Optional[DataLoader], SpeciesEmbeddingStore]:
    """
    Build streaming train/val DataLoaders plus the species embedding store.

    Returns:
        - train_loader, val_loader (optional), and the SpeciesEmbeddingStore
    """
    species_store = SpeciesEmbeddingStore(embeddings_dir, pin_memory=True, pooling=species_pooling)
    species_vocab_path = os.path.join(embeddings_dir, "species_vocab.json")
    num_workers = int(max(0, num_workers))

    def _make_dataset(paths, buffer_size):
        # Shared dataset construction for train and val splits.
        return _build_dataset(
            path_or_paths=paths,
            tokenizer=tokenizer,
            species_vocab_path=species_vocab_path,
            shuffle_buffer=int(buffer_size),
            csv_chunksize=int(csv_chunksize),
        )

    train_ds = _make_dataset(train_path, train_shuffle_buffer)
    val_ds = _make_dataset(val_path, val_shuffle_buffer) if val_path else None

    # NOTE: IterableDataset can't be shuffled by DataLoader. We already "shuffle" inside the dataset.
    loader_kwargs = dict(
        num_workers=num_workers,
        collate_fn=stage_collate_fn,
        pin_memory=True,
        persistent_workers=(num_workers > 0),
    )
    if num_workers > 0:
        loader_kwargs["prefetch_factor"] = 4

    # Drop last for train to keep batch shapes stable under FSDP.
    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=False,
        drop_last=True,
        **loader_kwargs,
    )

    if val_ds is not None:
        val_loader = DataLoader(
            val_ds,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
            **loader_kwargs,
        )
    else:
        val_loader = None

    return train_loader, val_loader, species_store
|
src/layers.py
ADDED
|
@@ -0,0 +1,384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Transformer components for CodonGPT.
|
| 3 |
+
Includes RMSNorm, self-attention (SDPA/Flash) with optional mask,
|
| 4 |
+
cross-attention for conditioning memory, SwiGLU FFN, and a basic block.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
from typing import Optional, Tuple
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from torch.nn.attention import SDPBackend, sdpa_kernel # Require recent PyTorch
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Scales the input by the reciprocal RMS of its last dimension, then by a
    learned per-feature gain (initialized to ones).
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply RMS normalization.

        Args:
            x: Input tensor of any shape ending in dim

        Returns:
            Normalized tensor of same shape
        """
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        normalized = x * torch.rsqrt(mean_sq + self.eps)
        return normalized * self.weight
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
|
| 39 |
+
"""Apply rotary embeddings to x: [B,H,T,D]; cos/sin: [1,1,T,D]."""
|
| 40 |
+
x1 = x[..., ::2]
|
| 41 |
+
x2 = x[..., 1::2]
|
| 42 |
+
x_rot = torch.zeros_like(x)
|
| 43 |
+
x_rot[..., ::2] = -x2
|
| 44 |
+
x_rot[..., 1::2] = x1
|
| 45 |
+
return x * cos + x_rot * sin
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class MultiHeadAttention(nn.Module):
|
| 49 |
+
"""Self-attention using PyTorch SDPA kernels (Flash/MemEff/Math) + RoPE.
|
| 50 |
+
- attn_mask: bool [B, T, T] with True = keep, False = block
|
| 51 |
+
- is_causal: whether to apply causal masking internally
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
def __init__(
|
| 55 |
+
self,
|
| 56 |
+
dim: int,
|
| 57 |
+
num_heads: int,
|
| 58 |
+
dropout: float = 0.0,
|
| 59 |
+
use_rope: bool = True,
|
| 60 |
+
):
|
| 61 |
+
super().__init__()
|
| 62 |
+
assert dim % num_heads == 0, f"dim {dim} must be divisible by num_heads {num_heads}"
|
| 63 |
+
self.dim = dim
|
| 64 |
+
self.num_heads = num_heads
|
| 65 |
+
self.head_dim = dim // num_heads
|
| 66 |
+
self.dropout = dropout
|
| 67 |
+
self.use_rope = use_rope
|
| 68 |
+
|
| 69 |
+
self.qkv = nn.Linear(dim, 3 * dim, bias=False)
|
| 70 |
+
self.out_proj = nn.Linear(dim, dim, bias=False)
|
| 71 |
+
self.resid_dropout = nn.Dropout(dropout)
|
| 72 |
+
|
| 73 |
+
# RoPE cache
|
| 74 |
+
self._rope_cache: dict[tuple[int, torch.device, torch.dtype], tuple[torch.Tensor, torch.Tensor]] = {}
|
| 75 |
+
|
| 76 |
+
def _rope_cos_sin(self, T: int, device: torch.device, dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]:
|
| 77 |
+
key = (T, device, dtype)
|
| 78 |
+
cached = self._rope_cache.get(key)
|
| 79 |
+
if cached is not None:
|
| 80 |
+
return cached
|
| 81 |
+
dim_half = self.head_dim // 2
|
| 82 |
+
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_half, device=device, dtype=torch.float32) / dim_half))
|
| 83 |
+
t = torch.arange(T, device=device, dtype=torch.float32)
|
| 84 |
+
freqs = torch.outer(t, inv_freq)
|
| 85 |
+
cos = torch.cos(freqs).repeat_interleave(2, dim=-1)
|
| 86 |
+
sin = torch.sin(freqs).repeat_interleave(2, dim=-1)
|
| 87 |
+
cos = cos.to(dtype).unsqueeze(0).unsqueeze(0) # [1,1,T,D]
|
| 88 |
+
sin = sin.to(dtype).unsqueeze(0).unsqueeze(0)
|
| 89 |
+
self._rope_cache[key] = (cos, sin)
|
| 90 |
+
return cos, sin
|
| 91 |
+
|
| 92 |
+
def forward(
|
| 93 |
+
self,
|
| 94 |
+
x: torch.Tensor,
|
| 95 |
+
past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 96 |
+
return_kv: bool = False,
|
| 97 |
+
position_offset: int = 0,
|
| 98 |
+
) -> "torch.Tensor | Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]":
|
| 99 |
+
"""
|
| 100 |
+
Self-attention with optional KV cache support.
|
| 101 |
+
|
| 102 |
+
Args:
|
| 103 |
+
x: [B, T_new, H]
|
| 104 |
+
past_kv: Optional tuple (k, v), each [B, nH, T_past, Hd]
|
| 105 |
+
return_kv: If True, also return updated (k, v)
|
| 106 |
+
position_offset: Starting position index for RoPE (past length)
|
| 107 |
+
|
| 108 |
+
Returns:
|
| 109 |
+
out or (out, present_kv)
|
| 110 |
+
"""
|
| 111 |
+
B, T_new, _ = x.shape
|
| 112 |
+
|
| 113 |
+
# QKV projections and reshape (ensure contiguous for SDPA kernels)
|
| 114 |
+
qkv = self.qkv(x).view(B, T_new, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
| 115 |
+
q, k_new, v_new = qkv[0].contiguous(), qkv[1].contiguous(), qkv[2].contiguous()
|
| 116 |
+
|
| 117 |
+
# RoPE for new tokens only
|
| 118 |
+
if self.use_rope:
|
| 119 |
+
# Compute cos/sin up to (offset + T_new), then slice the tail for new positions
|
| 120 |
+
cos, sin = self._rope_cos_sin(position_offset + T_new, x.device, q.dtype)
|
| 121 |
+
if position_offset > 0:
|
| 122 |
+
cos = cos[:, :, position_offset: position_offset + T_new, :]
|
| 123 |
+
sin = sin[:, :, position_offset: position_offset + T_new, :]
|
| 124 |
+
# Apply to q and k_new
|
| 125 |
+
q = _apply_rope(q, cos, sin)
|
| 126 |
+
k_new = _apply_rope(k_new, cos, sin)
|
| 127 |
+
|
| 128 |
+
# Concatenate with cache if provided
|
| 129 |
+
if past_kv is not None:
|
| 130 |
+
k_past, v_past = past_kv
|
| 131 |
+
k = torch.cat([k_past, k_new], dim=2)
|
| 132 |
+
v = torch.cat([v_past, v_new], dim=2)
|
| 133 |
+
is_causal = False # No future tokens present; avoid unnecessary masking
|
| 134 |
+
else:
|
| 135 |
+
k, v = k_new, v_new
|
| 136 |
+
is_causal = True
|
| 137 |
+
|
| 138 |
+
# Prefer FlashAttention; fall back to MemEff then Math. Autocast to half/bfloat16 on CUDA.
|
| 139 |
+
backends = [SDPBackend.FLASH_ATTENTION]#, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
|
| 140 |
+
with sdpa_kernel(backends):
|
| 141 |
+
if x.device.type == "cuda" and q.dtype not in (torch.float16, torch.bfloat16):
|
| 142 |
+
amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
| 143 |
+
with torch.amp.autocast(device_type="cuda", dtype=amp_dtype):
|
| 144 |
+
out = F.scaled_dot_product_attention(
|
| 145 |
+
q, k, v,
|
| 146 |
+
dropout_p=self.dropout if self.training else 0.0,
|
| 147 |
+
is_causal=is_causal,
|
| 148 |
+
)
|
| 149 |
+
else:
|
| 150 |
+
out = F.scaled_dot_product_attention(
|
| 151 |
+
q, k, v,
|
| 152 |
+
dropout_p=self.dropout if self.training else 0.0,
|
| 153 |
+
is_causal=is_causal,
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
out = out.transpose(1, 2).contiguous().view(B, T_new, self.dim)
|
| 157 |
+
# Align dtype with residual/Linear weights to avoid bf16/float mismatches
|
| 158 |
+
if out.dtype != x.dtype:
|
| 159 |
+
out = out.to(x.dtype)
|
| 160 |
+
out = self.out_proj(out)
|
| 161 |
+
out = self.resid_dropout(out)
|
| 162 |
+
|
| 163 |
+
if return_kv:
|
| 164 |
+
return out, (k, v)
|
| 165 |
+
return out
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class GroupedQueryAttention(nn.Module):
    """Grouped-Query Attention (GQA) using Flash Attention via PyTorch SDPA.

    - num_heads total query heads
    - num_kv_groups shared K/V groups (num_heads must be divisible by num_kv_groups)
    - Optional q/k RMSNorm
    - Supports RoPE with a scalar or per-sample position_offset (like MHA)
    - Optional KV cache compatible with the existing interface (stores expanded per-head K/V)
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_groups: int,
        dropout: float = 0.0,
        qk_norm: bool = False,
    ) -> None:
        super().__init__()
        assert num_heads % max(1, num_kv_groups) == 0, "num_heads must be divisible by num_kv_groups"
        self.dim = dim
        self.num_heads = int(num_heads)
        self.num_kv_groups = max(1, int(num_kv_groups))
        # Number of query heads that share each K/V group.
        self.group_size = self.num_heads // self.num_kv_groups

        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.head_dim = dim // num_heads
        self.dropout = dropout

        # K/V project to only num_kv_groups heads; queries to all heads.
        self.Wq = nn.Linear(dim, self.num_heads * self.head_dim, bias=False)
        self.Wk = nn.Linear(dim, self.num_kv_groups * self.head_dim, bias=False)
        self.Wv = nn.Linear(dim, self.num_kv_groups * self.head_dim, bias=False)
        self.out_proj = nn.Linear(self.num_heads * self.head_dim, dim, bias=False)

        # Optional per-head-dim RMSNorm on q/k (QK-norm stabilization).
        self.q_norm = RMSNorm(self.head_dim) if qk_norm else None
        self.k_norm = RMSNorm(self.head_dim) if qk_norm else None

        # RoPE cache keyed by (length, device, dtype).
        self._rope_cache: dict[tuple[int, torch.device, torch.dtype], tuple[torch.Tensor, torch.Tensor]] = {}

    def _rope_cos_sin(self, T: int, device: torch.device, dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]:
        """Return cached (cos, sin) rotary tables of shape [1,1,T,head_dim]."""
        key = (T, device, dtype)
        cached = self._rope_cache.get(key)
        if cached is not None:
            return cached
        dim_half = self.head_dim // 2
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_half, device=device, dtype=torch.float32) / dim_half))
        t = torch.arange(T, device=device, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)
        # Each frequency is duplicated for the (even, odd) channel pair.
        cos = torch.cos(freqs).repeat_interleave(2, dim=-1)
        sin = torch.sin(freqs).repeat_interleave(2, dim=-1)
        cos = cos.to(dtype).unsqueeze(0).unsqueeze(0)  # [1,1,T,D]
        sin = sin.to(dtype).unsqueeze(0).unsqueeze(0)
        self._rope_cache[key] = (cos, sin)
        return cos, sin

    def forward(
        self,
        x: torch.Tensor,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        return_kv: bool = False,
        position_offset: int | torch.Tensor = 0,
    ) -> "torch.Tensor | Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]":
        """GQA forward with optional KV cache.

        Args:
            x: [B, T_new, dim]
            past_kv: optional (k, v) cache; stored already expanded per-head,
                each [B, num_heads, T_past, head_dim]
            return_kv: also return the updated (k, v) cache
            position_offset: RoPE start position — an int shared by the batch,
                or a per-sample LongTensor of shape [B]

        Returns:
            out [B, T_new, dim], or (out, present_kv) when return_kv is True.
        """
        B, T_new, _ = x.shape

        # Project to Q, K, V
        q = self.Wq(x).view(B, T_new, self.num_heads, self.head_dim).transpose(1, 2).contiguous()  # [B,H,T,Hd]
        k = self.Wk(x).view(B, T_new, self.num_kv_groups, self.head_dim).transpose(1, 2).contiguous()  # [B,G,T,Hd]
        v = self.Wv(x).view(B, T_new, self.num_kv_groups, self.head_dim).transpose(1, 2).contiguous()  # [B,G,T,Hd]

        # Optional RMSNorm on q/k
        if self.q_norm is not None:
            q = self.q_norm(q)
        if self.k_norm is not None:
            k = self.k_norm(k)

        # RoPE for new tokens only
        if isinstance(position_offset, int):
            # Scalar offset: slice the shared table at [offset, offset+T_new).
            cos, sin = self._rope_cos_sin(position_offset + T_new, x.device, q.dtype)
            if position_offset > 0:
                cos = cos[:, :, position_offset: position_offset + T_new, :]
                sin = sin[:, :, position_offset: position_offset + T_new, :]
            q = _apply_rope(q, cos, sin)
            k = _apply_rope(k, cos, sin)
        else:
            # Per-sample offsets: gather a [B, T_new] index grid into the table.
            off = position_offset.to(device=x.device, dtype=torch.long)
            max_off = int(off.max().item())
            cos_all, sin_all = self._rope_cos_sin(max_off + T_new, x.device, q.dtype)
            ar = torch.arange(T_new, device=x.device, dtype=torch.long)
            idx = (off.unsqueeze(1) + ar.unsqueeze(0))  # [B, T_new]
            cos_b = cos_all.squeeze(0).squeeze(0)[idx].unsqueeze(1)  # [B,1,T,D]
            sin_b = sin_all.squeeze(0).squeeze(0)[idx].unsqueeze(1)
            q = _apply_rope(q, cos_b, sin_b)
            # k has groups dimension [B,G,T,D]; share same offsets per batch
            k = _apply_rope(k, cos_b, sin_b)

        # Expand grouped K/V to per-head by repeating groups
        if self.group_size > 1:
            k_exp = k.repeat_interleave(self.group_size, dim=1)  # [B,H,T,Hd]
            v_exp = v.repeat_interleave(self.group_size, dim=1)  # [B,H,T,Hd]
        else:
            k_exp, v_exp = k, v  # already per-head

        # KV cache: concatenate past along sequence dim
        if past_kv is not None:
            k_past, v_past = past_kv
            k_cat = torch.cat([k_past, k_exp], dim=2)
            v_cat = torch.cat([v_past, v_exp], dim=2)
            # Cache path: queries only attend over past+new, no future tokens.
            is_causal = False
        else:
            k_cat, v_cat = k_exp, v_exp
            is_causal = True

        # Prefer FlashAttention; fall back to MemEff/Math. Ensure CUDA autocast to half/bfloat16 so kernels are available
        with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
            if x.device.type == "cuda" and q.dtype not in (torch.float16, torch.bfloat16):
                amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
                with torch.amp.autocast(device_type="cuda", dtype=amp_dtype):
                    out = torch.nn.functional.scaled_dot_product_attention(
                        q, k_cat, v_cat,
                        dropout_p=self.dropout if self.training else 0.0,
                        is_causal=is_causal,
                    )  # [B,H,T,Hd]
            else:
                out = torch.nn.functional.scaled_dot_product_attention(
                    q, k_cat, v_cat,
                    dropout_p=self.dropout if self.training else 0.0,
                    is_causal=is_causal,
                )  # [B,H,T,Hd]

        out = out.transpose(1, 2).contiguous().view(B, T_new, self.num_heads * self.head_dim)
        # Ensure dtype compatibility for Linear / residual path
        if out.dtype != x.dtype:
            out = out.to(x.dtype)
        out = self.out_proj(out)
        # NOTE(review): unlike MultiHeadAttention there is no residual dropout
        # here — presumably intentional, but worth confirming.

        if return_kv:
            return out, (k_cat, v_cat)
        return out
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
class FeedForward(nn.Module):
    """SwiGLU feed-forward network: w2(silu(w1(x)) * w3(x)) with dropout."""

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        dropout: float = 0.0,
    ):
        super().__init__()

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate projection
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # down projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # up projection

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply feed-forward network.

        Args:
            x: Input tensor [B, T, dim]

        Returns:
            Output tensor [B, T, dim]
        """
        gate = F.silu(self.w1(x))
        hidden = gate * self.w3(x)
        return self.dropout(self.w2(hidden))
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
class TransformerBlock(nn.Module):
    """Pre-norm Transformer block: self-attention + SwiGLU FFN (no cross-attention)."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        num_kv_groups: int | None = None,
        qk_norm: bool = False,
        attn_type: str = "gqa",  # "gqa" or "mha"
    ):
        super().__init__()
        self.norm1 = RMSNorm(dim)
        if attn_type == "mha":
            self.attn = MultiHeadAttention(dim=dim, num_heads=num_heads, dropout=dropout)
            self._attn_is_gqa = False
        else:
            # GQA with one group per head (i.e. no grouping) when num_kv_groups
            # is left unset.
            kv_groups = num_heads if num_kv_groups is None else max(1, int(num_kv_groups))
            self.attn = GroupedQueryAttention(
                dim=dim,
                num_heads=num_heads,
                num_kv_groups=kv_groups,
                dropout=dropout,
                qk_norm=qk_norm,
            )
            self._attn_is_gqa = True
        self.norm2 = RMSNorm(dim)
        self.ffn = FeedForward(dim=dim, hidden_dim=int(dim * mlp_ratio), dropout=dropout)

    def forward(
        self,
        x: torch.Tensor,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
        position_offset: int = 0,
    ) -> "torch.Tensor | Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]":
        """Forward pass; returns (x, present_kv) when KV caching is requested."""
        caching = use_cache or past_kv is not None
        if not caching:
            x = x + self.attn(self.norm1(x))
            x = x + self.ffn(self.norm2(x))
            return x
        attn_out, present_kv = self.attn(
            self.norm1(x), past_kv=past_kv, return_kv=True, position_offset=position_offset
        )
        x = x + attn_out
        x = x + self.ffn(self.norm2(x))
        return x, present_kv
|
src/models.py
ADDED
|
@@ -0,0 +1,490 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Core model architectures for CodonGPT (GPT-only).
|
| 3 |
+
- CodonGPT: Decoder-only GPT with two-species + protein prefix
|
| 4 |
+
Includes a frozen ESM-C encoder for protein conditioning.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
import os
|
| 9 |
+
from typing import Optional, Dict, Any, Tuple, List
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
import torch.utils.checkpoint as checkpoint
|
| 14 |
+
import torch.nn.utils.rnn as rnn_utils
|
| 15 |
+
|
| 16 |
+
from .layers import RMSNorm, TransformerBlock
|
| 17 |
+
from .tokenizer import SpecialIds
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class FrozenESMCEncoder(nn.Module):
    """
    Frozen ESM-C encoder that computes protein embeddings on the fly.

    The encoder is kept on a single GPU per rank (it is not sharded via FSDP);
    every parameter is frozen at construction time.
    """

    def __init__(self, model_name: str = "esmc_300m", device: str = "cuda", dtype: str = "fp16"):
        super().__init__()
        self.model_name = model_name
        # Fall back to CPU transparently when no CUDA device is present.
        self._device = torch.device(device if torch.cuda.is_available() else "cpu")
        # Autocast dtype used during the forward pass; anything other than
        # fp16/bf16 disables autocast entirely.
        self._autocast_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16}.get(dtype)
        self._load_model()
        self.eval()
        for param in self.parameters():
            param.requires_grad_(False)

    def _load_model(self):
        """Instantiate the pretrained ESM-C model and record its hidden width."""
        from esm.models.esmc import ESMC
        from esm.utils.constants.models import ESMC_300M, ESMC_600M

        if self.model_name == "esmc_300m":
            model_const = ESMC_300M
            self.D_esm = 960
        elif self.model_name == "esmc_600m":
            model_const = ESMC_600M
            self.D_esm = 1152
        else:
            raise ValueError(f"Unknown model: {self.model_name}")
        self.model = ESMC.from_pretrained(model_name=model_const, device=self._device)
        self.tokenizer = self.model.tokenizer

    @torch.no_grad()
    def tokenize(self, sequences: List[str], max_length: Optional[int] = None, add_special_tokens: bool = True, return_tensors: str = "pt"):
        """Tokenize protein sequences, truncating each to *max_length* tokens.

        Returns ``(input_ids, attention_mask)`` where the mask flags non-pad
        positions of the padded batch tensor.
        """
        from esm.utils import encoding
        from esm.utils.misc import stack_variable_length_tensors

        pad_id = self.tokenizer.pad_token_id
        per_seq_tokens = []
        for sequence in sequences:
            toks = encoding.tokenize_sequence(sequence, self.tokenizer, add_special_tokens=add_special_tokens)
            if max_length is not None and len(toks) > max_length:
                toks = toks[:max_length]
            per_seq_tokens.append(toks)
        input_ids = stack_variable_length_tensors(per_seq_tokens, constant_value=pad_id)
        return input_ids, input_ids != pad_id

    @torch.no_grad()
    def encode_from_ids(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.BoolTensor] = None, return_dict: bool = True, return_contacts: bool = False):
        """Run the frozen encoder on pre-tokenized ids and return embeddings."""
        device = self.model.device
        input_ids = input_ids.to(device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)
        # Autocast only on CUDA and only when an autocast dtype was configured.
        if self._autocast_dtype is not None and device.type == "cuda":
            with torch.amp.autocast('cuda', dtype=self._autocast_dtype):
                outputs = self.model.forward(sequence_tokens=input_ids, sequence_id=attention_mask)
        else:
            outputs = self.model.forward(sequence_tokens=input_ids, sequence_id=attention_mask)
        if return_dict:
            return {"embeddings": outputs.embeddings, "attention_mask": attention_mask}
        return outputs.embeddings

    def strip_special_tokens(self, embeddings: torch.FloatTensor, attention_mask: Optional[torch.BoolTensor] = None):
        """Drop the BOS/EOS rows and return per-sample residue counts.

        Lengths are derived from the mask when available (minus the two special
        tokens, clamped to at least 1); otherwise every sample is assumed to
        span the full padded width.
        """
        if attention_mask is None:
            B, L, D = embeddings.shape
            lengths = torch.full((B,), L - 2, device=embeddings.device)
        else:
            lengths = (attention_mask.sum(dim=1) - 2).clamp(min=1)
        return embeddings[:, 1:-1, :], lengths
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class CodonGPT(nn.Module):
    """Decoder-only GPT over codon tokens with a LLaVA-style conditioning prefix.

    The input window is laid out as
    ``[species_src][species_tgt][protein tokens][start][codon tokens]``:
    species and protein context enter as continuous embeddings projected to the
    model width (protein embeddings come from a frozen ESM-C encoder), and the
    codon stream uses a learned token table. Loss is computed only over the
    codon-aligned logits.
    """

    def __init__(
        self,
        vocab_size: int = 79,
        hidden_size: int = 960,
        num_layers: int = 24,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        max_position_embeddings: int = 4096,
        dropout: float = 0.1,
        layer_norm_eps: float = 1e-6,
        num_special_tokens: int = 13,
        special_ids: Optional[SpecialIds] = None,
        esm_model_name: str = "esmc_300m",
        esm_device: str = "cuda",
        esm_dtype: str = "fp16",
        max_protein_prefix: int = 0,
        max_species_prefix: int = 0,
        prepend_species: bool = True,
        prepend_protein: bool = True,
        species_embedding_dim: int = 1024,
        attn_impl: str = "gqa",  # "gqa" or "mha"
        num_kv_groups: int = 0,  # for GQA; 0 means default (no grouping)
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.max_position_embeddings = max_position_embeddings

        self.special_ids = special_ids or SpecialIds()
        self.num_special_tokens = num_special_tokens

        # Single embedding table for all tokens (special + codon)
        self.token_embed = nn.Embedding(vocab_size, hidden_size)

        if prepend_protein and esm_model_name:
            self.esm = FrozenESMCEncoder(esm_model_name, esm_device, esm_dtype)
            # Project ESM token embeddings (D_esm) to model hidden size, then normalize
            self.esm_ln = nn.Sequential(
                nn.Linear(self.esm.D_esm, hidden_size, bias=False),
                nn.ReLU(),
                nn.LayerNorm(hidden_size),
            )
        else:
            self.esm = None
            self.esm_ln = None

        self.species_embedding_dim = species_embedding_dim if prepend_species else 0
        if prepend_species:
            # Project species embeddings (fixed or token sequence) from Ds -> H
            self.species_ln = nn.Sequential(
                nn.Linear(self.species_embedding_dim, hidden_size, bias=False),
                nn.ReLU(),
                nn.LayerNorm(hidden_size),
            )
        else:
            self.species_ln = None

        # Optional per-prefix caps; 0 means unlimited (subject to global max length)
        self.max_protein_prefix = int(max_protein_prefix) if max_protein_prefix is not None else 0
        self.max_species_prefix = int(max_species_prefix) if max_species_prefix is not None else 0
        self.prepend_species = bool(prepend_species)
        self.prepend_protein = bool(prepend_protein)

        # Learned start embedding (BOS-less decoding)
        self.start_embed = nn.Parameter(torch.zeros(1, 1, hidden_size))
        nn.init.normal_(self.start_embed, mean=0.0, std=0.02)

        # Attention configuration
        self.attn_impl = str(attn_impl)
        self.num_kv_groups = int(num_kv_groups)
        kv_groups = self.num_kv_groups
        self.blocks = nn.ModuleList([
            TransformerBlock(
                dim=hidden_size,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                dropout=dropout,
                num_kv_groups=(kv_groups if (kv_groups > 0 and attn_impl == "gqa") else None),
                qk_norm=False,
                attn_type=("mha" if self.attn_impl == "mha" else "gqa"),
            ) for _ in range(num_layers)
        ])

        self.ln_f = RMSNorm(hidden_size, eps=layer_norm_eps)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        # Toggled externally (e.g. by the trainer) to enable activation checkpointing.
        self.gradient_checkpointing = False

    def embed_tokens(self, token_ids: torch.Tensor) -> torch.Tensor:
        """Look up codon token embeddings, moving ids to the table's device first."""
        device = self.token_embed.weight.device
        return self.token_embed(token_ids.to(device))

    def build_prefix(
        self,
        batch_size: int,
        device: torch.device,
        species_tok_emb: Optional[torch.Tensor] = None,
        species_emb: Optional[torch.Tensor] = None,
        protein_input: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        species_tok_emb_src: Optional[torch.Tensor] = None,
        species_tok_emb_tgt: Optional[torch.Tensor] = None,
        species_emb_src: Optional[torch.Tensor] = None,
        species_emb_tgt: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build LLaVA-style prefix token embeddings by concatenating
        [species_src]+[species_tgt]+[protein_tokens]. Returns:
        - prefix: [B, Lp, H]
        - prefix_lengths: [B] valid token counts per sample

        Notes:
        - src/tgt-specific inputs take precedence; the shared ``species_tok_emb``/
          ``species_emb`` act as a fallback for both slots.
        - Padded prefix rows are zeroed, and the per-sample lengths are later
          recovered by counting non-zero rows (a zero embedding row would be
          miscounted as padding — accepted limitation of this scheme).
        - The prefix is hard-trimmed so at least one position remains for the
          start token within ``max_position_embeddings``.
        """
        parts: list[torch.Tensor] = []

        # Species: src then tgt (if provided)
        if self.prepend_species and self.species_ln is not None:
            tok_src = species_tok_emb_src if species_tok_emb_src is not None else species_tok_emb
            tok_tgt = species_tok_emb_tgt if species_tok_emb_tgt is not None else species_tok_emb
            emb_src = species_emb_src if species_emb_src is not None else species_emb
            emb_tgt = species_emb_tgt if species_emb_tgt is not None else species_emb

            def _as_tokens(S_tok, S_fix):
                # Fixed-size embedding wins over the token sequence when both are given.
                if S_fix is not None:
                    # [B, Ds] -> [B, 1, H]
                    S = self.species_ln(S_fix.to(device=device, dtype=next(self.parameters()).dtype).unsqueeze(1))
                    return S
                elif S_tok is not None:
                    # [B, Ls, Ds] -> optional cap, then project to H
                    S = S_tok
                    if getattr(self, "max_species_prefix", 0) > 0 and S.size(1) > self.max_species_prefix:
                        S = S[:, : self.max_species_prefix, :]
                    S = S.to(device=device, dtype=next(self.parameters()).dtype)
                    S = self.species_ln(S)
                    return S
                else:
                    return None

            Ssrc = _as_tokens(tok_src, emb_src)
            if Ssrc is not None:
                parts.append(Ssrc)
            Sdst = _as_tokens(tok_tgt, emb_tgt)
            if Sdst is not None:
                parts.append(Sdst)

        # Protein tokens from ESM-C
        if self.prepend_protein and self.esm is not None and protein_input is not None:
            prot_ids, prot_mask = protein_input
            esm_out = self.esm.encode_from_ids(prot_ids, prot_mask, return_dict=True)
            P, lengths = self.esm.strip_special_tokens(esm_out["embeddings"], prot_mask)
            # Optional per-protein capping before projection
            if getattr(self, "max_protein_prefix", 0) > 0 and P.size(1) > self.max_protein_prefix:
                P = P[:, : self.max_protein_prefix, :]
                if lengths is not None:
                    lengths = lengths.clamp(max=self.max_protein_prefix)
            if P.size(1) > 0:
                P = self.esm_ln(P.to(device=device, dtype=next(self.parameters()).dtype))
                # Zero padded rows (per-sample) based on lengths
                if lengths is not None:
                    Lp = P.size(1)
                    ar = torch.arange(Lp, device=device).unsqueeze(0)
                    lengths = lengths.to(device=device)
                    valid = ar < lengths.unsqueeze(1)  # [B,Lp]
                    P = P * valid.unsqueeze(-1)
                parts.append(P)

        if len(parts) == 0:
            # No conditioning at all: empty prefix, zero lengths.
            empty = torch.zeros(batch_size, 0, self.hidden_size, device=device, dtype=next(self.parameters()).dtype)
            return empty, torch.zeros(batch_size, dtype=torch.long, device=device)

        prefix = torch.cat(parts, dim=1) if parts else torch.zeros(batch_size, 0, self.hidden_size, device=device, dtype=next(self.parameters()).dtype)  # [B,Lp,H]
        # Compute per-sample valid lengths: treat zero rows as padding
        with torch.no_grad():
            if prefix.size(1) > 0:
                valid = (prefix.abs().sum(dim=-1) > 0)
                lengths = valid.sum(dim=1).to(torch.long)
            else:
                lengths = torch.zeros(batch_size, dtype=torch.long, device=device)

        # ---- Enforce hard global budget on the prefix itself ----
        # Reserve one position out of the global window for the start token.
        prefix_budget = max(0, int(self.max_position_embeddings) - 1)
        if prefix_budget == 0:
            trimmed = prefix.new_zeros(prefix.size(0), 0, prefix.size(2))
            return trimmed, torch.zeros(prefix.size(0), dtype=torch.long, device=prefix.device)

        allow = torch.minimum(lengths, torch.tensor(prefix_budget, device=lengths.device, dtype=lengths.dtype))
        Lp_max = int(allow.max().item()) if allow.numel() > 0 else 0
        if prefix.size(1) > Lp_max:
            # Re-pack each sample's allowed rows into a tighter [B, Lp_max, H] buffer.
            trimmed = prefix.new_zeros(prefix.size(0), Lp_max, prefix.size(2))
            for b in range(prefix.size(0)):
                lb = int(allow[b].item())
                if lb > 0:
                    trimmed[b, :lb, :] = prefix[b, :lb, :]
            prefix = trimmed
            lengths = allow
        else:
            lengths = allow
        return prefix, lengths

    def forward(
        self,
        codon_ids: torch.Tensor,
        cond: Optional[Dict[str, Any]] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        species_tok_emb: Optional[torch.Tensor] = None,
        protein_input: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        protein_seqs: Optional[List[str]] = None,
        # KV cache options
        use_cache: bool = False,
        past_kv: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        position_offset: int = 0,
    ) -> Dict[str, torch.Tensor]:
        """Run the model.

        Two paths:
        - **Incremental decode** (``past_kv`` given): only the new codon tokens
          are embedded and pushed through the cached blocks; the returned
          ``"logits"`` is an empty slice and ``"next_logits"`` holds the
          prediction for the next token.
        - **Full pass** (training/prefill): the conditioning prefix, the learned
          start embedding and the (per-sample capped) codon embeddings are
          concatenated, and the codon-aligned logits are gathered per sample.

        Values in ``cond`` (when given) override the corresponding keyword
        arguments. When ``labels`` is given, ``"loss"`` is a cross-entropy over
        positions within each sample's cap (others are ignored via -100).
        """
        batch_size, codon_len = codon_ids.shape
        device = codon_ids.device

        # Unpack conditioning
        if cond is not None:
            # NOTE(review): control_mode is read but not used below — presumably
            # consumed elsewhere or reserved for future use; confirm.
            control_mode = cond.get("control_mode", "fixed")
            species_tok_emb_src = cond.get("species_tok_emb_src")
            species_tok_emb_tgt = cond.get("species_tok_emb_tgt")
            species_emb_src = cond.get("species_emb_src")
            species_emb_tgt = cond.get("species_emb_tgt")
            species_tok_emb = cond.get("species_tok_emb")
            species_emb = cond.get("species_emb")
            protein_input = cond.get("protein_input")
            protein_seqs = cond.get("protein_seqs")
        else:
            # Keep the keyword-argument values for species_tok_emb / protein_*;
            # only the cond-exclusive fields default to None.
            species_emb = None
            species_tok_emb_src = None
            species_tok_emb_tgt = None
            species_emb_src = None
            species_emb_tgt = None

        # Tokenize raw protein sequences on demand via the frozen ESM encoder.
        if protein_seqs is not None and protein_input is None:
            if self.esm is not None:
                with torch.no_grad():
                    # Respect per-protein ceiling during tokenization (+2 for BOS/EOS)
                    max_len_tokens = (self.max_protein_prefix + 2) if (getattr(self, "max_protein_prefix", 0) > 0) else None
                    protein_input = self.esm.tokenize(protein_seqs, max_length=max_len_tokens)
            else:
                protein_input = None

        # Fast path: incremental decode using KV cache
        if past_kv is not None:
            # Expect only newly generated codon tokens here
            if codon_ids.numel() == 0:
                # Nothing to do; return a dummy next_logits
                dummy = torch.zeros(batch_size, self.vocab_size, device=device, dtype=self.lm_head.weight.dtype)
                return {"logits": dummy[:, 0:0], "next_logits": dummy}

            x = self.embed_tokens(codon_ids)  # [B, T_new, H]

            present_kv: List[Tuple[torch.Tensor, torch.Tensor]] = []
            for i, block in enumerate(self.blocks):
                kv_i = past_kv[i] if i < len(past_kv) else None
                if self.training and getattr(self, 'gradient_checkpointing', False):
                    def _fn(inp):
                        return block(inp, past_kv=kv_i, use_cache=True, position_offset=position_offset)
                    out_blk = checkpoint.checkpoint(_fn, x, use_reentrant=False)
                else:
                    out_blk = block(x, past_kv=kv_i, use_cache=True, position_offset=position_offset)
                x, kv_out = out_blk  # type: ignore[assignment]
                present_kv.append(kv_out)

            x = self.ln_f(x)
            logits_step = self.lm_head(x)  # [B, T_new, V]
            next_logits = logits_step[:, -1, :]
            # "logits" is intentionally an empty slice on this path.
            out: Dict[str, torch.Tensor] = {"logits": logits_step[:, 0:0, :], "next_logits": next_logits}
            out["present_kv"] = present_kv  # type: ignore[assignment]
            return out if return_dict else logits_step[:, 0:0, :]

        # Standard path: build prefix and full window (training or prefill)
        prefix, prefix_lengths = self.build_prefix(
            batch_size=batch_size,
            device=device,
            species_tok_emb=species_tok_emb,
            species_emb=species_emb if cond is not None else None,
            protein_input=protein_input,
            species_tok_emb_src=species_tok_emb_src,
            species_tok_emb_tgt=species_tok_emb_tgt,
            species_emb_src=species_emb_src,
            species_emb_tgt=species_emb_tgt,
        )

        start = self.start_embed.expand(batch_size, 1, self.hidden_size)  # [B,1,H]

        # Per-sample true codon input lengths (exclude PADs)
        pad_id = int(self.special_ids.pad) if hasattr(self, "special_ids") and self.special_ids is not None else 0
        codon_mask = (codon_ids != pad_id)  # [B, N]
        codon_lens = codon_mask.sum(dim=1)  # [B]

        # Budget remaining after prefix + start
        capacity = max(0, int(self.max_position_embeddings))
        budget_after_prefix = torch.clamp(
            torch.as_tensor(capacity, device=device) - (prefix_lengths + 1),
            min=0,
        )  # [B]
        # Per-sample cap is limited by both budget and available codons
        per_cap = torch.minimum(budget_after_prefix, codon_lens)  # [B]

        # Total valid lengths per sample (prefix + start + capped codon)
        # T is informational only; pad_sequence below determines the actual width.
        valid_lengths = prefix_lengths + 1 + per_cap
        T = int(valid_lengths.max().item()) if valid_lengths.numel() > 0 else (1 + int(codon_lens.max().item()) if codon_lens.numel() > 0 else 1)

        # Embed only the needed codon window for this batch
        max_cap = int(per_cap.max().item()) if per_cap.numel() > 0 else 0
        if max_cap > 0:
            codon_emb = self.embed_tokens(codon_ids[:, :max_cap])  # [B, max_cap, H]
        else:
            codon_emb = torch.zeros(batch_size, 0, self.hidden_size, device=device, dtype=start.dtype)

        # Build sequence per-sample using concat to preserve gradients, then pad
        seqs = []
        for b in range(batch_size):
            lp = int(prefix_lengths[b].item())
            cap = int(per_cap[b].item())
            parts = []
            if lp > 0:
                parts.append(prefix[b, :lp, :])
            parts.append(start[b, 0:1, :])
            if cap > 0:
                parts.append(codon_emb[b, :cap, :])
            seqs.append(torch.cat(parts, dim=0))  # [Lb, H]
        x = rnn_utils.pad_sequence(seqs, batch_first=True)  # [B, T, H]

        present_kv_list: List[Tuple[torch.Tensor, torch.Tensor]] = []
        for block in self.blocks:
            if self.training and getattr(self, 'gradient_checkpointing', False):
                def _fn(inp):
                    return block(inp, use_cache=use_cache, position_offset=0)
                blk_out = checkpoint.checkpoint(_fn, x, use_reentrant=False)
            else:
                blk_out = block(x, use_cache=use_cache, position_offset=0)
            if use_cache:
                x, kv = blk_out  # type: ignore[misc]
                present_kv_list.append(kv)
            else:
                x = blk_out  # type: ignore[assignment]

        x = self.ln_f(x)
        logits_full = self.lm_head(x)  # [B, T, V]

        # Gather codon-aligned logits per sample: positions (lp+1) .. (lp+cap) (skip start)
        next_logits_list = []
        if max_cap == 0:
            # Keep graph by slicing from logits_full
            codon_logits = logits_full[:, 0:0, :]
            for b in range(batch_size):
                lp = int(prefix_lengths[b].item())
                # Last consumed position is the start token at index lp
                pos_next = lp
                if pos_next < logits_full.size(1):
                    next_logits_list.append(logits_full[b, pos_next, :])
                else:
                    next_logits_list.append(logits_full[b, -1, :])
            next_logits = torch.stack(next_logits_list, dim=0)
        else:
            slices = []
            for b in range(batch_size):
                lp = int(prefix_lengths[b].item())
                cap = int(per_cap[b].item())
                # Skip the start position so logits align with labels = codon_ids[:, 1:]
                # NOTE(review): index lp is the start-token position; the slice
                # lp:lp+cap covers predictions for codons 1..cap — verify the
                # trainer's label construction matches this alignment.
                sl = logits_full[b, lp : lp + cap, :] if cap > 0 else logits_full.new_zeros(0, self.vocab_size)
                slices.append(sl)
                # Next-token logits after processing 'cap' codons: last consumed is at lp + cap
                pos_next = lp + cap
                next_logits_list.append(logits_full[b, pos_next, :] if pos_next < logits_full.size(1) else logits_full.new_zeros(self.vocab_size))
            codon_logits = rnn_utils.pad_sequence(slices, batch_first=True)  # [B,max_cap,V]
            next_logits = torch.stack(next_logits_list, dim=0)
        out = {"logits": codon_logits, "next_logits": next_logits}

        if labels is not None:
            # Align labels to per-sample caps: mask out positions >= cap
            if labels.size(1) > 0 and max_cap > 0:
                # Build masked labels with -100 beyond cap per sample
                adj = labels.new_full((batch_size, max_cap), -100)
                for b in range(batch_size):
                    cap = int(per_cap[b].item())
                    if cap > 0:
                        Lb = min(cap, labels.size(1))
                        adj[b, :Lb] = labels[b, :Lb]
                loss = F.cross_entropy(codon_logits.reshape(-1, self.vocab_size), adj.reshape(-1), ignore_index=-100)
            else:
                # Zero loss that still depends on the graph (keeps backward happy).
                loss = codon_logits.sum() * 0.0
            out["loss"] = loss
        # Provide optional debug stats for trainer logging
        out["prefix_len"] = prefix_lengths.detach()
        out["per_cap"] = per_cap.detach()
        if use_cache:
            out["present_kv"] = present_kv_list  # type: ignore[assignment]
        return out if return_dict else codon_logits
|
src/sampler.py
ADDED
|
@@ -0,0 +1,696 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/sampler.py
|
| 2 |
+
"""
|
| 3 |
+
Sampling utilities for CodonGPT.
|
| 4 |
+
|
| 5 |
+
Conditioning invariants:
|
| 6 |
+
- Species context: fixed-size [B, Ds] via species_emb or variable-length [B, Ls, Ds] via species_tok_emb
|
| 7 |
+
- Protein context: raw sequences; the model's Frozen ESM handles tokenization
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
from typing import List, Optional, Dict, Union, Tuple
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
import logging
|
| 14 |
+
import json
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
import torch.nn.functional as F
|
| 19 |
+
import numpy as np
|
| 20 |
+
from safetensors.torch import load_file
|
| 21 |
+
|
| 22 |
+
from .models import CodonGPT
|
| 23 |
+
from .tokenizer import CodonTokenizer
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ----------------------------
|
| 29 |
+
# Logit filtering
|
| 30 |
+
# ----------------------------
|
| 31 |
+
|
| 32 |
+
def _ensure_2d_logits(logits: torch.Tensor) -> torch.Tensor:
|
| 33 |
+
return logits if logits.dim() == 2 else logits.unsqueeze(0)
|
| 34 |
+
|
| 35 |
+
def _top_k_filtering(logits: torch.Tensor, k: int) -> torch.Tensor:
|
| 36 |
+
"""Top-k filtering; logits is [B,V] or [V]."""
|
| 37 |
+
x = _ensure_2d_logits(logits)
|
| 38 |
+
k = max(1, min(int(k), x.size(-1)))
|
| 39 |
+
values, _ = torch.topk(x, k, dim=-1)
|
| 40 |
+
min_values = values[:, -1].unsqueeze(-1)
|
| 41 |
+
x = torch.where(x < min_values, torch.full_like(x, float('-inf')), x)
|
| 42 |
+
return x if logits.dim() == 2 else x.squeeze(0)
|
| 43 |
+
|
| 44 |
+
def _top_p_filtering(logits: torch.Tensor, p: float) -> torch.Tensor:
|
| 45 |
+
"""Top-p (nucleus) filtering; logits is [B,V] or [V]."""
|
| 46 |
+
if p >= 1.0:
|
| 47 |
+
return logits
|
| 48 |
+
if p <= 0.0:
|
| 49 |
+
# You asked for nothing; enjoy the abyss.
|
| 50 |
+
return torch.full_like(logits, float('-inf'))
|
| 51 |
+
x = _ensure_2d_logits(logits)
|
| 52 |
+
sorted_logits, sorted_indices = torch.sort(x, descending=True, dim=-1)
|
| 53 |
+
probs = F.softmax(sorted_logits, dim=-1)
|
| 54 |
+
cumprobs = torch.cumsum(probs, dim=-1)
|
| 55 |
+
to_remove = cumprobs > p
|
| 56 |
+
to_remove[:, 1:] = to_remove[:, :-1].clone()
|
| 57 |
+
to_remove[:, 0] = False
|
| 58 |
+
mask = torch.zeros_like(x, dtype=torch.bool).scatter(-1, sorted_indices, to_remove)
|
| 59 |
+
x = torch.where(mask, torch.full_like(x, float('-inf')), x)
|
| 60 |
+
return x if logits.dim() == 2 else x.squeeze(0)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# ----------------------------
|
| 64 |
+
# Sampler
|
| 65 |
+
# ----------------------------
|
| 66 |
+
|
| 67 |
+
class CodonSampler:
    """
    GPT sampler with conditional generation.

    Requires in model_dir:
      - vocab.json (also searched one directory up as a fallback)
      - model.safetensors (preferred)
        or pytorch_model.bin (legacy)
      - trainer_config.json or config.json
    """

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        species_store=None,  # SpeciesEmbeddingStore; optional if species_emb*/species_tok_emb are passed to sample()
        compile_model: bool = False,
        taxonomy_db_path: Optional[str] = None,
        qwen_max_length: int = 512,
        qwen_batch_size: int = 16,
        **_: dict,  # extra kwargs are deliberately swallowed for forward compatibility
    ):
        """Load tokenizer and weights from ``model_path`` and build the CodonGPT model.

        The model architecture is reconstructed from the checkpoint's state dict
        (see ``_infer_arch_from_state_dict``), with the JSON config taking
        precedence where present.

        Raises:
            FileNotFoundError: if vocab.json or a config JSON cannot be found.
        """
        self.device = torch.device(device)
        self.model_dir = Path(model_path)

        # Required files (allow fallback to parent dir for vocab.json)
        vocab_path = self.model_dir / "vocab.json"
        if not vocab_path.exists():
            parent_vocab = self.model_dir.parent / "vocab.json"
            if parent_vocab.exists():
                vocab_path = parent_vocab
            else:
                raise FileNotFoundError(f"Missing {self.model_dir / 'vocab.json'}")
        # trainer_config.json wins over config.json when both exist.
        trainer_cfg = self.model_dir / "trainer_config.json"
        cfg_path = trainer_cfg if trainer_cfg.exists() else (self.model_dir / "config.json")
        if not cfg_path.exists():
            raise FileNotFoundError(f"Missing trainer_config.json or config.json in {self.model_dir}")

        # Load config
        with open(cfg_path, "r") as f:
            self.config = json.load(f)

        # Tokenizer
        # If vocab was loaded from parent dir, pass that path; else model_dir
        vocab_dir = vocab_path.parent
        self.tokenizer = CodonTokenizer.from_pretrained(str(vocab_dir))
        # Cache frequently used tokenizer constants as plain ints.
        self.V = int(self.tokenizer.vocab_size)
        self._eos_id = int(self.tokenizer.eos_token_id)
        self._pad_id = int(self.tokenizer.pad_token_id)
        self._num_special = int(self.tokenizer.num_special_tokens)

        # Species store (optional if you pass species_emb* directly at sample())
        self.species_store = species_store
        self.species_vocab = (self.species_store.vocab if self.species_store is not None else {})
        self.taxonomy_db_path = taxonomy_db_path
        self.qwen_opts = {
            "max_length": int(qwen_max_length),
            "batch_size": int(qwen_batch_size),
        }
        # Lazy-inited Qwen objects (see _ensure_qwen_loaded)
        self._qwen_tokenizer = None
        self._qwen_model = None

        # Model: architecture is inferred from the checkpoint, then weights loaded
        # non-strictly so frozen/externally-provided submodules may be absent.
        state = self._load_state_dict()
        arch = self._infer_arch_from_state_dict(state)
        self.model = CodonGPT(
            vocab_size=self.V,
            hidden_size=int(arch["hidden_size"]),
            num_layers=int(arch["num_layers"]),
            num_heads=int(arch["num_heads"]),
            mlp_ratio=float(arch["mlp_ratio"]),
            max_position_embeddings=int(arch["max_position_embeddings"]),
            dropout=float(self.config.get("dropout", 0.1)),
            num_special_tokens=self._num_special,
            special_ids=self.tokenizer.special_ids,
            # ESM / prefix options only matter when protein conditioning is on.
            esm_model_name=str(arch["esm_model_name"]) if bool(arch["prepend_protein"]) else None,
            esm_device=str(arch["esm_device"]),
            esm_dtype=str(arch["esm_dtype"]),
            max_protein_prefix=int(arch["max_protein_prefix"]) if bool(arch["prepend_protein"]) else 0,
            max_species_prefix=int(arch["max_species_prefix"]) if bool(arch["prepend_species"]) else 0,
            prepend_species=bool(arch["prepend_species"]),
            prepend_protein=bool(arch["prepend_protein"]),
            species_embedding_dim=int(self.config.get("species_embedding_dim", 1024)),
            attn_impl=str(arch.get("attn_impl", "gqa")),
            num_kv_groups=int(arch.get("num_kv_groups", 0)),
        )
        missing, unexpected = self.model.load_state_dict(state, strict=False)
        if len(unexpected) > 0:
            logger.warning(f"Unexpected keys in state dict: {unexpected[:10]}{'...' if len(unexpected) > 10 else ''}")
        if len(missing) > 0:
            logger.warning(f"Missing keys in state dict: {missing[:10]}{'...' if len(missing) > 10 else ''}")

        if compile_model:
            # If this errors on your PyTorch build, that's on you. No try/except.
            self.model = torch.compile(self.model)  # type: ignore

        self.model.to(self.device).eval()
        logger.info(f"Loaded GPT model from {self.model_dir}")
        # Best-effort arch log; torch.compile wrappers may not expose these attrs.
        try:
            hs = int(getattr(self.model, "hidden_size", -1))
            hh = int(getattr(self.model, "num_heads", -1))
            nl = int(getattr(self.model, "num_layers", -1))
            logger.info(f"Reconstructed arch: hidden={hs} heads={hh} layers={nl}")
        except Exception:
            pass

        # Static masks over the vocabulary, reused every sampling step.
        self._allowed_fixed = torch.ones(self.V, dtype=torch.bool, device=self.device)
        self._allowed_fixed[:self._num_special] = False  # no specials in fixed mode

        self._allowed_variable = torch.ones(self.V, dtype=torch.bool, device=self.device)
        self._allowed_variable[:self._num_special] = False
        self._allowed_variable[self._eos_id] = True  # EOS allowed in variable mode
|
| 181 |
+
|
| 182 |
+
# ----------------------------
|
| 183 |
+
# Loading / arch inference
|
| 184 |
+
# ----------------------------
|
| 185 |
+
|
| 186 |
+
def _load_state_dict(self) -> Dict[str, torch.Tensor]:
|
| 187 |
+
st_p = self.model_dir / "model.safetensors"
|
| 188 |
+
pt_p = self.model_dir / "pytorch_model.bin"
|
| 189 |
+
if st_p.exists():
|
| 190 |
+
return load_file(st_p)
|
| 191 |
+
if pt_p.exists():
|
| 192 |
+
return torch.load(pt_p, map_location="cpu")
|
| 193 |
+
raise FileNotFoundError(f"No model.safetensors or pytorch_model.bin in {self.model_dir}")
|
| 194 |
+
|
| 195 |
+
def _infer_arch_from_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, Union[int, float, bool, str]]:
    """Reconstruct model hyperparameters from checkpoint weights.

    Weight shapes and key names are used as evidence; values present in
    ``self.config`` take precedence over inferred ones where checked.

    Args:
        state_dict: raw parameter-name -> tensor mapping from the checkpoint.

    Returns:
        dict with keys: hidden_size, num_layers, mlp_ratio, num_heads,
        prepend_species, prepend_protein, esm_model_name, esm_device,
        esm_dtype, max_protein_prefix, max_species_prefix,
        max_position_embeddings, attn_impl, num_kv_groups.
    """
    arch: Dict[str, Union[int, float, bool, str]] = {}

    # hidden size: lm_head.weight is [V, H]; fall back to final-norm width.
    if "lm_head.weight" in state_dict:
        arch["hidden_size"] = int(state_dict["lm_head.weight"].shape[1])
    else:
        for k, v in state_dict.items():
            if k.endswith("ln_f.weight"):
                arch["hidden_size"] = int(v.shape[0])
                break
    # Prefer config when present to avoid guessing errors
    cfg = self.config or {}
    if "hidden_size" in cfg:
        arch["hidden_size"] = int(cfg["hidden_size"])  # type: ignore[index]
    if "hidden_size" not in arch:
        arch["hidden_size"] = int(cfg.get("hidden_size", 960))
    H = int(arch["hidden_size"])

    # layers: highest "blocks.<i>." index seen, else config, else 12.
    max_block = -1
    for k in state_dict.keys():
        if k.startswith("blocks."):
            idx = int(k.split(".")[1])
            if idx > max_block:
                max_block = idx
    arch["num_layers"] = (max_block + 1) if max_block >= 0 else int(cfg.get("num_hidden_layers", 12))
    if "num_hidden_layers" in cfg:
        arch["num_layers"] = int(cfg["num_hidden_layers"])  # type: ignore[index]

    # mlp ratio from w1: ffn.w1.weight is [H * ratio, H], so ratio = rows / H.
    w1_key = "blocks.0.ffn.w1.weight" if "blocks.0.ffn.w1.weight" in state_dict else None
    if w1_key is None:
        for i in range(1, 3):
            k = f"blocks.{i}.ffn.w1.weight"
            if k in state_dict:
                w1_key = k
                break
    if w1_key is not None and H > 0:
        arch["mlp_ratio"] = float(int(state_dict[w1_key].shape[0]) / H)
    else:
        arch["mlp_ratio"] = float(cfg.get("mlp_ratio", 4.0))

    # heads – use config if it divides H evenly, else pick the largest
    # divisor of H from a preference list (head count is not recoverable
    # from fused weight shapes alone).
    cfg_heads = cfg.get("num_attention_heads")
    if isinstance(cfg_heads, int) and cfg_heads > 0 and H % cfg_heads == 0:
        arch["num_heads"] = int(cfg_heads)
    else:
        for h in (16, 15, 12, 10, 8, 6, 5, 4, 3, 2, 1):
            if H % h == 0:
                arch["num_heads"] = h
                break

    # conditioning flags from presence of submodules (species_ln / esm_ln / esm.*)
    arch["prepend_species"] = bool(cfg.get("prepend_species", any(k.startswith("species_ln.") for k in state_dict.keys())))
    has_esm = any(k.startswith("esm_ln.") for k in state_dict.keys()) or any(k.startswith("esm.") for k in state_dict.keys())
    arch["prepend_protein"] = bool(cfg.get("prepend_protein", bool(has_esm)))
    arch["esm_model_name"] = str(cfg.get("esm_model_name", "esmc_300m"))
    arch["esm_device"] = str(cfg.get("esm_device", "cuda"))
    arch["esm_dtype"] = str(cfg.get("esm_dtype", "bf16")).lower()
    arch["max_protein_prefix"] = int(cfg.get("max_protein_prefix", 0))
    arch["max_species_prefix"] = int(cfg.get("max_species_prefix", 0))

    # Context window: trainer configs use "max_length", model configs
    # "max_position_embeddings"; either way default to 1024.
    if "max_length" in cfg:
        arch["max_position_embeddings"] = int(cfg.get("max_length", 1024))
    else:
        arch["max_position_embeddings"] = int(cfg.get("max_position_embeddings", 1024))
    # Attention impl and num_kv_groups (from config or infer from weights)
    attn_impl = str(cfg.get("attn_impl", ""))
    num_kv_groups = int(cfg.get("num_kv_groups", 0))
    if not attn_impl:
        # A separate Wk projection indicates GQA; its row count divided by
        # head_dim gives the number of KV groups.
        wk_key = next((k for k in state_dict.keys() if k.endswith("attn.Wk.weight")), None)
        if wk_key is not None:
            attn_impl = "gqa"
            out_ch, _ = state_dict[wk_key].shape
            num_heads = int(arch.get("num_heads", 1))
            head_dim = int(arch["hidden_size"]) // max(1, num_heads)
            if head_dim > 0:
                num_kv_groups = max(1, out_ch // head_dim)
        else:
            attn_impl = "mha"
            num_kv_groups = 0
    arch["attn_impl"] = attn_impl
    arch["num_kv_groups"] = num_kv_groups

    return arch  # type: ignore[return-value]
|
| 281 |
+
|
| 282 |
+
# ----------------------------
|
| 283 |
+
# Public API
|
| 284 |
+
# ----------------------------
|
| 285 |
+
|
| 286 |
+
@torch.no_grad()
def sample(
    self,
    num_sequences: int = 1,
    sequence_length: int = 100,  # target number of codons (fixed mode); max iterations (variable)
    species: Optional[Union[str, List[str]]] = None,
    protein_sequences: Optional[Union[str, List[str]]] = None,
    control_mode: str = "fixed",  # "fixed" or "variable"
    target_protein_length: Optional[int] = None,  # deprecated; alias to sequence_length
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    seed: Optional[int] = None,
    return_intermediate: bool = False,
    progress_bar: bool = False,
    species_emb: Optional[torch.Tensor] = None,  # [B, Ds]
    species_tok_emb: Optional[torch.Tensor] = None,  # [B, Ls, Ds]
    enforce_translation: bool = False,
    codon_enforcement_weight: float = 10.0,  # unused with hard mask; kept for API compatibility
) -> Dict[str, Union[List[str], torch.Tensor, List[bool]]]:
    """Autoregressively sample codon sequences with optional conditioning.

    Species conditioning priority: species_tok_emb > species_emb > species
    names (resolved via the species store and/or Qwen embeddings). Protein
    conditioning passes raw AA strings through to the model.

    Returns a dict with:
        - "sequences": decoded DNA strings (specials/PAD stripped)
        - "input_ids": [B, Lmax] token ids padded with PAD
        - "capacity_truncated": per-sample bool list, True where generation
          was cut off by the position budget rather than EOS/length
        - "intermediate_states": (only if return_intermediate) list of
          [B, t] snapshots per step
    """
    if seed is not None:
        torch.manual_seed(int(seed))
        np.random.seed(int(seed))

    if control_mode not in ("fixed", "variable"):
        raise ValueError(f"control_mode must be 'fixed' or 'variable', got {control_mode}")

    B = int(num_sequences)
    # Deprecated alias: target_protein_length overrides sequence_length when set.
    T_codons = int(sequence_length if target_protein_length is None else target_protein_length)

    # Prepare conditioning
    cond: Dict[str, Union[str, List[str], torch.Tensor]] = {"control_mode": control_mode}

    # Species (priority: provided tensors → names via store)
    if species_tok_emb is not None:
        if species_tok_emb.ndim != 3 or species_tok_emb.size(0) != B:
            raise ValueError("species_tok_emb must be [B, Ls, Ds]")
        st = species_tok_emb.to(self.device)
        cond["species_tok_emb_src"] = st
        cond["species_tok_emb_tgt"] = st
    elif species_emb is not None:
        if species_emb.ndim != 2 or species_emb.size(0) != B:
            raise ValueError("species_emb must be [B, Ds]")
        se = species_emb.to(self.device)
        cond["species_emb_src"] = se
        cond["species_emb_tgt"] = se
    elif species is not None:
        names = [species] * B if isinstance(species, str) else species
        if len(names) != B:
            raise ValueError("Length of species list must match num_sequences")

        # If we have a store (variable-length), use it for known species and compute Qwen embeddings for unknowns.
        if self.species_store is not None:
            ids = [self.species_store.vocab.get(n, -1) for n in names]
            known_mask = [i for i, sid in enumerate(ids) if sid >= 0]
            unk_mask = [i for i, sid in enumerate(ids) if sid < 0]

            # Only variable-length embeddings are supported. If the store is not sequence-based, compute via Qwen for all.
            # NOTE(review): using "is_legacy" to mean "sequence-based" reads
            # inverted — confirm against SpeciesEmbeddingStore.
            use_sequence = bool(getattr(self.species_store, "is_legacy", False))
            if not use_sequence:
                # Fall back to Qwen for everything
                q_tok, q_len = self._qwen_embed_names(names, pooling="sequence")
                cond["species_tok_emb_src"] = q_tok.to(self.device)
                cond["species_tok_emb_tgt"] = q_tok.to(self.device)
            else:
                # list of per-sample [L,D] tensors to be padded later
                seq_list: List[torch.Tensor] = [None] * B  # type: ignore[list-item]
                D = int(getattr(self.species_store, "_ds", 1024))
                # Known via store
                if known_mask:
                    sub_ids = [ids[i] for i in known_mask]
                    result = self.species_store.batch_get(sub_ids)
                    assert isinstance(result, tuple)
                    sp_tok, _ = result
                    for j, i in enumerate(known_mask):
                        row = sp_tok[j]
                        # True length = trailing all-zero rows are padding.
                        nonzero = (row.abs().sum(dim=-1) > 0)
                        L = int(nonzero.sum().item()) if nonzero.any() else int(row.size(0))
                        seq_list[i] = row[:L].to(self.device)
                # Unknown via Qwen
                if unk_mask:
                    unk_names = [names[i] for i in unk_mask]
                    q_tok, q_len = self._qwen_embed_names(unk_names, pooling="sequence")
                    for j, i in enumerate(unk_mask):
                        L = int(q_len[j].item())
                        seq_list[i] = q_tok[j, :L, :].to(self.device)

                # Pad to [B,Lmax,D]
                Lmax = max((t.size(0) for t in seq_list if t is not None), default=0)
                if Lmax == 0:
                    raise RuntimeError("No species embeddings could be constructed.")
                padded = torch.zeros(B, Lmax, D, device=self.device, dtype=seq_list[0].dtype)
                for i, t in enumerate(seq_list):
                    if t is None:
                        continue
                    L = t.size(0)
                    padded[i, :L, :] = t
                cond["species_tok_emb_src"] = padded
                cond["species_tok_emb_tgt"] = padded
        else:
            # No store: compute everything via Qwen (sequence pooling only)
            emb, lengths = self._qwen_embed_names(names, pooling="sequence")
            st = emb.to(self.device, non_blocking=True)
            cond["species_tok_emb_src"] = st
            cond["species_tok_emb_tgt"] = st

    # Protein sequences (raw AA strings; the model handles ESM-C)
    if protein_sequences is not None:
        if isinstance(protein_sequences, list):
            if len(protein_sequences) != B:
                raise ValueError("Length of protein_sequences must match num_sequences")
            cond["protein_seqs"] = protein_sequences
        else:
            cond["protein_seqs"] = [protein_sequences] * B

    # Start with empty codon context; we'll prefill to build KV cache and get first-step logits
    input_ids = torch.empty((B, 0), dtype=torch.long, device=self.device)

    # Capacity probe and fallback: if prefix consumes all budget, cap species/protein prefix temporarily (prefill path)
    pref = None
    try:
        out0 = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=True)
        pref = out0.get("prefix_len") if isinstance(out0, dict) else None
        if pref is not None:
            max_pos = int(getattr(self.model, "max_position_embeddings", 1024))
            remaining0 = max_pos - (pref + 1)
            need_cap = (remaining0 <= 0).any()
        else:
            need_cap = False
        if need_cap:
            # Shrink prefix caps to 256 (then 128) and re-prefill until some
            # generation budget remains.
            prev_sp = int(getattr(self.model, "max_species_prefix", 0))
            prev_pp = int(getattr(self.model, "max_protein_prefix", 0))
            if prev_sp == 0 or prev_sp > 256:
                setattr(self.model, "max_species_prefix", 256)
            if prev_pp == 0 or prev_pp > 256:
                setattr(self.model, "max_protein_prefix", 256)
            out0b = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=True)
            pref = out0b.get("prefix_len") if isinstance(out0b, dict) else None
            if pref is not None:
                remaining0b = max_pos - (pref + 1)
                if (remaining0b <= 0).all():
                    setattr(self.model, "max_species_prefix", 128)
                    setattr(self.model, "max_protein_prefix", 128)
                    out0b = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=True)
                    pref = out0b.get("prefix_len") if isinstance(out0b, dict) else pref
        # Use the prefill output
        # NOTE(review): both branches pick out0 — after the need_cap path the
        # re-capped out0b prefill looks like the intended value here; confirm.
        out_prefill = out0 if pref is None else out0
    except Exception:
        # Fallback without cache
        out_prefill = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=True)
        pref = out_prefill.get("prefix_len") if isinstance(out_prefill, dict) else None

    allowed = self._allowed_variable if control_mode == "variable" else self._allowed_fixed
    finished = torch.zeros(B, dtype=torch.bool, device=self.device)  # EOS reached (variable) OR capacity exhausted
    capacity_truncated = torch.zeros(B, dtype=torch.bool, device=self.device)

    intermediate = [] if return_intermediate else None
    # Per-amino-acid allowed codon token ids, used by enforce_translation.
    aa2codons = self.tokenizer.aa2codons_char_map()

    # If we probed capacity, optionally clamp target codons by available capacity at step 0
    try:
        if pref is not None:
            max_pos = int(getattr(self.model, "max_position_embeddings", 1024))
            remaining = (max_pos - (pref + 1)).clamp(min=0)
            T_codons = int(min(T_codons, int(remaining.max().item())))
    except Exception:
        pass

    # KV cache and initial logits from prefill
    kv = out_prefill.get("present_kv") if isinstance(out_prefill, dict) else None
    logits = out_prefill.get("next_logits") if isinstance(out_prefill, dict) else None
    if kv is None or logits is None:
        # Safety: compute once if not provided
        out_prefill = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=True)
        kv = out_prefill.get("present_kv")
        logits = out_prefill.get("next_logits")
    assert kv is not None and logits is not None
    prefix_len = pref if pref is not None else torch.zeros(B, dtype=torch.long, device=self.device)
    prefill_len = (prefix_len + 1)  # prefix + start

    rng = range(T_codons)
    if progress_bar:
        from tqdm import tqdm
        rng = tqdm(rng, desc="GPT sampling", total=T_codons)

    for step in rng:
        # Enforce global capacity per sample using prefix_len and current generated length
        max_pos = int(getattr(self.model, "max_position_embeddings", 1024))
        # NOTE(review): clamp(max=10**9) only bounds from above; harmless here
        # since only the <= 0 comparison is used, but clamp(min=0) may have
        # been intended.
        remaining_now = (max_pos - prefill_len - input_ids.size(1)).clamp(max=10**9)
        cant_extend = remaining_now <= 0
        newly_blocked = (~finished) & cant_extend
        capacity_truncated = capacity_truncated | newly_blocked
        finished = finished | cant_extend

        # Base mask: disallow specials in fixed, allow EOS in variable.
        logits = logits.masked_fill(~allowed, float("-inf"))

        # If a sample is finished (EOS or capacity), force PAD to keep shapes stable.
        # Decoding will drop PAD anyway.
        if finished.any():
            logits[finished] = float("-inf")
            logits[finished, self._pad_id] = 0.0

        # Optional: enforce codon ↔ AA mapping at this step (hard mask)
        if enforce_translation and ("protein_seqs" in cond):
            aas_now: List[Optional[str]] = []
            prot_list = cond["protein_seqs"]  # type: ignore[index]
            assert isinstance(prot_list, list)
            for i in range(B):
                seq = prot_list[i]
                aas_now.append(seq[step] if step < len(seq) else None)

            mask = torch.zeros_like(logits, dtype=torch.bool)
            for i, a in enumerate(aas_now):
                if a is None:
                    # Past the protein's end: any non-special codon is fine.
                    mask[i, self._num_special:self.V] = True
                else:
                    valid = aa2codons.get(a, [])
                    if len(valid) == 0:
                        # Unknown amino acid letter: fall back to any codon.
                        mask[i, self._num_special:self.V] = True
                    else:
                        mask[i, valid] = True
            logits = logits.masked_fill(~mask, float("-inf"))

        # Temperature + filtering
        if temperature != 1.0:
            logits = logits / float(temperature)
        if top_k is not None:
            logits = _top_k_filtering(logits, int(top_k))
        if top_p is not None:
            logits = _top_p_filtering(logits, float(top_p))

        probs = F.softmax(logits, dim=-1)
        next_tok = torch.multinomial(probs, num_samples=1)  # [B,1]

        if control_mode == "variable":
            # Stop sequences at EOS
            eos_mask = (next_tok.squeeze(-1) == self._eos_id)
            finished = finished | eos_mask

        input_ids = torch.cat([input_ids, next_tok], dim=1)

        if return_intermediate:
            intermediate.append(input_ids.clone())

        # If all sequences are finished, we're done.
        if finished.all():
            break

        # Incremental decode: compute logits for next step and update KV cache
        pos_offset = int(prefill_len.max().item()) + input_ids.size(1) - 1  # use max offset for shared RoPE cache
        out_inc = self.model(
            codon_ids=next_tok,
            cond=None,
            return_dict=True,
            use_cache=True,
            past_kv=kv,
            position_offset=pos_offset,
        )
        kv = out_inc.get("present_kv")
        logits = out_inc.get("next_logits")
        assert kv is not None and logits is not None

    # Build final DNA strings, dropping specials and any PADs we added
    output_token_rows: List[List[int]] = []
    for row in input_ids.tolist():
        toks: List[int] = []
        for t in row:
            if t == self._pad_id:
                continue
            if t == self._eos_id:
                break  # variable mode terminator
            if t >= self._num_special and t < self.V:
                toks.append(int(t))
        if control_mode == "fixed":
            # In fixed mode we *intended* T_codons; if capacity cut us short, it's fine.
            toks = toks[:T_codons]
        output_token_rows.append(toks)

    sequences = [self.tokenizer.decode_codon_seq(row) for row in output_token_rows]

    # Pad variable-length rows for input_ids to avoid tensor construction errors when
    # some samples are capacity-truncated in fixed mode.
    max_len = max((len(r) for r in output_token_rows), default=0)
    if max_len > 0:
        ids_padded = torch.full(
            (len(output_token_rows), max_len),
            self._pad_id,
            device=self.device,
            dtype=torch.long,
        )
        for i, row in enumerate(output_token_rows):
            if len(row) > 0:
                ids_padded[i, : len(row)] = torch.tensor(row, device=self.device, dtype=torch.long)
    else:
        ids_padded = torch.empty((len(output_token_rows), 0), device=self.device, dtype=torch.long)

    result: Dict[str, Union[List[str], torch.Tensor, List[bool]]] = {
        "sequences": sequences,
        "input_ids": ids_padded,
        "capacity_truncated": capacity_truncated.detach().bool().tolist(),
    }
    if return_intermediate:
        result["intermediate_states"] = intermediate  # list[Tensor], length = steps actually taken
    return result
|
| 592 |
+
|
| 593 |
+
# ----------------------------
|
| 594 |
+
# Qwen embedding (inline; no separate module)
|
| 595 |
+
# ----------------------------
|
| 596 |
+
def _ensure_qwen_loaded(self):
    """Lazily build the Qwen3 embedding tokenizer and model on first use."""
    already_loaded = self._qwen_tokenizer is not None and self._qwen_model is not None
    if already_loaded:
        return
    # Imported here so transformers is only required when Qwen is actually used.
    from transformers import AutoTokenizer, AutoModel
    self._qwen_tokenizer = AutoTokenizer.from_pretrained(
        "Qwen/Qwen3-Embedding-0.6B", trust_remote_code=True, padding_side="left"
    )
    # fp16 on CUDA, fp32 elsewhere.
    model_dtype = torch.float32 if self.device.type != "cuda" else torch.float16
    qwen = AutoModel.from_pretrained(
        "Qwen/Qwen3-Embedding-0.6B", torch_dtype=model_dtype, trust_remote_code=True
    )
    self._qwen_model = qwen.to(self.device).eval()
|
| 607 |
+
|
| 608 |
+
@staticmethod
|
| 609 |
+
def _last_token_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
| 610 |
+
left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
|
| 611 |
+
if left_padding:
|
| 612 |
+
return last_hidden_states[:, -1]
|
| 613 |
+
else:
|
| 614 |
+
sequence_lengths = attention_mask.sum(dim=1) - 1
|
| 615 |
+
batch_size = last_hidden_states.shape[0]
|
| 616 |
+
return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
|
| 617 |
+
|
| 618 |
+
@staticmethod
|
| 619 |
+
def _format_instruct(task: str, query: str) -> str:
|
| 620 |
+
return f"Instruct: {task}\nQuery: {query}"
|
| 621 |
+
|
| 622 |
+
@torch.no_grad()
def _qwen_embed_names(self, names: List[str], pooling: str = "sequence") -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Embed species names with Qwen3, returning per-token embeddings.

    Each name is optionally expanded to a taxonomy description via
    ``self.taxonomy_db_path`` (best-effort; failures fall back to the raw
    name), wrapped in an Instruct/Query prompt, and embedded in batches.
    Hidden states are L2-normalized per token.

    Args:
        names: species names to embed.
        pooling: only "sequence" behavior is implemented; the argument is
            currently unused in the body (kept for interface symmetry).

    Returns:
        (padded, lens): padded is [B, Lmax, D] float32 on CPU, zero-padded
        per row; lens is a [B] long tensor of true token lengths.
    """
    # Load taxonomy DB if provided (best-effort; any failure means no expansion).
    taxonomy_db = None
    if self.taxonomy_db_path:
        try:
            with open(self.taxonomy_db_path, "r") as f:
                import json
                taxonomy_db = json.load(f)
        except Exception:
            taxonomy_db = None

    self._ensure_qwen_loaded()
    tokenizer = self._qwen_tokenizer
    model = self._qwen_model
    assert tokenizer is not None and model is not None

    task = (
        "Given a species taxonomy information, generate a biological embedding "
        "representing its taxonomic and evolutionary characteristics"
    )
    texts = [self._format_instruct(task, taxonomy_db.get(s, s) if taxonomy_db else s) for s in names]

    BATCH = int(self.qwen_opts.get("batch_size", 16))
    max_len = int(self.qwen_opts.get("max_length", 512))

    # sequence pooling only: keep every real token's embedding, trimmed by the
    # attention mask, then pad to a rectangular batch.
    seqs: List[torch.Tensor] = []
    lens: List[int] = []
    for i in range(0, len(texts), BATCH):
        chunk = texts[i : i + BATCH]
        inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True, max_length=max_len).to(self.device)
        out = model(**inputs)
        h = torch.nn.functional.normalize(out.last_hidden_state, p=2, dim=-1)  # [B,L,D]
        attn = inputs["attention_mask"]
        for j in range(h.size(0)):
            # True length from the mask; the tokenizer pads on the left, so
            # note the slice h[j, :L, :] takes the first L positions —
            # NOTE(review): with left padding those may be pad positions; confirm.
            L = int(attn[j].sum().item())
            seqs.append(h[j, :L, :].float().cpu())
            lens.append(L)
    # Pad to [B,Lmax,D]
    Lmax = max(lens) if lens else 0
    D = seqs[0].size(1) if seqs else 0
    padded = torch.zeros(len(seqs), Lmax, D)
    for i, t in enumerate(seqs):
        padded[i, : t.size(0), :] = t
    return padded, torch.tensor(lens, dtype=torch.long)
|
| 668 |
+
|
| 669 |
+
# ----------------------------
|
| 670 |
+
# Conditioning helper
|
| 671 |
+
# ----------------------------
|
| 672 |
+
|
| 673 |
+
# (Kept minimal. Species embeddings are prepared inline in sample().)
|
| 674 |
+
|
| 675 |
+
|
| 676 |
+
# ----------------------------
|
| 677 |
+
# Convenience function
|
| 678 |
+
# ----------------------------
|
| 679 |
+
|
| 680 |
+
def sample_sequences(
    model_path: str,
    num_sequences: int = 10,
    sequence_length: int = 100,
    species: Optional[Union[str, List[str]]] = None,
    protein_sequence: Optional[Union[str, List[str]]] = None,
    **kwargs
) -> List[str]:
    """One-shot convenience wrapper: build a CodonSampler and return DNA strings.

    Extra keyword arguments are forwarded to :meth:`CodonSampler.sample`.
    """
    result = CodonSampler(model_path).sample(
        num_sequences=num_sequences,
        sequence_length=sequence_length,
        species=species,
        protein_sequences=protein_sequence,
        **kwargs
    )
    return result["sequences"]  # type: ignore[return-value]
|
src/tokenizer.py
ADDED
|
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/tokenizer.py
|
| 2 |
+
"""
|
| 3 |
+
Codon tokenizer: 3-mer tokens + 4 special tokens.
|
| 4 |
+
|
| 5 |
+
No frameworks, no inheritance chains. Just:
|
| 6 |
+
- encode_codon_seq("ATG...") -> [ids...] (appends EOS outside, not here)
|
| 7 |
+
- decode_codon_seq([ids...]) -> "ATG..."
|
| 8 |
+
- save_vocabulary(dir) / from_pretrained(dir) for reproducible runs
|
| 9 |
+
|
| 10 |
+
Special IDs are fixed and contiguous from 0:
|
| 11 |
+
pad=0, unk=1, bos=2, eos=3
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import json
|
| 17 |
+
import os
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
from typing import Dict, List, Optional, Tuple, Any
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# ------------------------------
|
| 24 |
+
# Special token ids
|
| 25 |
+
# ------------------------------
|
| 26 |
+
|
| 27 |
+
@dataclass(frozen=True)
class SpecialIds:
    """Immutable record of the four special-token ids (contiguous from 0)."""

    pad: int = 0
    unk: int = 1
    bos: int = 2
    eos: int = 3

    def to_dict(self) -> Dict[str, int]:
        """Return a ``{name: id}`` mapping for the four special tokens."""
        return {name: getattr(self, name) for name in ("pad", "unk", "bos", "eos")}
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# ------------------------------
|
| 39 |
+
# Tokenizer
|
| 40 |
+
# ------------------------------
|
| 41 |
+
|
| 42 |
+
class CodonTokenizer:
    """Minimal tokenizer for codon (DNA 3-mer) sequences.

    Vocabulary layout is fixed and deterministic: the four special tokens
    occupy ids 0..3, followed by the 64 codons at ids 4..67 in lexicographic
    (A, C, G, T) order. BOS/EOS are never appended here; callers own that.
    """

    __slots__ = (
        "codons",
        "_special_token_str",
        "vocab",
        "ids_to_tokens",
        "_special_ids",
        "_num_special_tokens",
        "_genetic_code",
        "_codon2aa_char",
        "_aa2codons_char",
    )

    def __init__(
        self,
        pad_token: str = "<pad>",
        unk_token: str = "<unk>",
        bos_token: str = "<bos>",
        eos_token: str = "<stop>",  # human-readable; id is still 3
        **_: Any,  # ignore junk kwargs – we don't play framework games
    ) -> None:
        # 64 codons in lexicographic (A, C, G, T) order
        bases = ("A", "C", "G", "T")
        self.codons: List[str] = [a + b + c for a in bases for b in bases for c in bases]

        # specials come first, contiguous
        special_tokens = [pad_token, unk_token, bos_token, eos_token]
        self._special_token_str = {"pad": pad_token, "unk": unk_token, "bos": bos_token, "eos": eos_token}

        # vocab: specials [0..3], then 64 codons [4..67].
        # The next free id is simply the current size of the mapping.
        self.vocab: Dict[str, int] = {}
        for tok in special_tokens + self.codons:
            self.vocab[tok] = len(self.vocab)

        # reverse map
        self.ids_to_tokens: Dict[int, str] = {v: k for k, v in self.vocab.items()}

        # fixed ids
        self._special_ids = SpecialIds(
            pad=self.vocab[pad_token],
            unk=self.vocab[unk_token],
            bos=self.vocab[bos_token],
            eos=self.vocab[eos_token],
        )
        self._num_special_tokens = len(special_tokens)

        # standard genetic code; '*' marks the three stop codons
        self._genetic_code: Dict[str, str] = {
            "TTT": "F", "TTC": "F", "TTA": "L", "TTG": "L",
            "TCT": "S", "TCC": "S", "TCA": "S", "TCG": "S",
            "TAT": "Y", "TAC": "Y", "TAA": "*", "TAG": "*",
            "TGT": "C", "TGC": "C", "TGA": "*", "TGG": "W",
            "CTT": "L", "CTC": "L", "CTA": "L", "CTG": "L",
            "CCT": "P", "CCC": "P", "CCA": "P", "CCG": "P",
            "CAT": "H", "CAC": "H", "CAA": "Q", "CAG": "Q",
            "CGT": "R", "CGC": "R", "CGA": "R", "CGG": "R",
            "ATT": "I", "ATC": "I", "ATA": "I", "ATG": "M",
            "ACT": "T", "ACC": "T", "ACA": "T", "ACG": "T",
            "AAT": "N", "AAC": "N", "AAA": "K", "AAG": "K",
            "AGT": "S", "AGC": "S", "AGA": "R", "AGG": "R",
            "GTT": "V", "GTC": "V", "GTA": "V", "GTG": "V",
            "GCT": "A", "GCC": "A", "GCA": "A", "GCG": "A",
            "GAT": "D", "GAC": "D", "GAA": "E", "GAG": "E",
            "GGT": "G", "GGC": "G", "GGA": "G", "GGG": "G",
        }

        # Build codon-id <-> amino-acid helper maps. Shared with from_pretrained()
        # so the construction logic lives in exactly one place.
        self._rebuild_helpers()

        # sanity: specials are contiguous 0..3
        ids = list(self._special_ids.to_dict().values())
        if sorted(ids) != list(range(self._num_special_tokens)):
            raise AssertionError("Special token ids must be contiguous starting at 0")

    # ---------- properties ----------
    @property
    def vocab_size(self) -> int:
        """Total vocabulary size (4 specials + 64 codons = 68 by default)."""
        return len(self.vocab)

    @property
    def special_ids(self) -> SpecialIds:
        """The frozen SpecialIds record (pad/unk/bos/eos)."""
        return self._special_ids

    @property
    def num_special_tokens(self) -> int:
        """Number of reserved special-token ids at the front of the vocab."""
        return self._num_special_tokens

    @property
    def pad_token_id(self) -> int:
        return self._special_ids.pad

    @property
    def unk_token_id(self) -> int:
        return self._special_ids.unk

    @property
    def bos_token_id(self) -> int:
        return self._special_ids.bos

    @property
    def eos_token_id(self) -> int:
        return self._special_ids.eos

    # ---------- core API ----------
    def encode_codon_seq(self, seq: str, validate: bool = True) -> List[int]:
        """
        Map DNA (ACGT)^3N to 3-mer ids. We don't append BOS/EOS here.

        Raises ValueError when ``validate`` and the input is not a clean
        ACGT string of length divisible by 3. Unknown 3-mers map to unk.
        """
        s = seq.upper()
        if validate:
            if len(s) % 3 != 0:
                raise ValueError(f"Sequence length {len(s)} not divisible by 3")
            if not _is_acgt(s):
                raise ValueError("Sequence contains invalid nucleotides (only ACGT supported)")
        unk = self._special_ids.unk
        # Fast Python slice loop – good enough. NumPy won't help for tiny strings.
        return [self.vocab.get(s[i : i + 3], unk) for i in range(0, len(s), 3)]

    def decode_codon_seq(self, token_ids: List[int]) -> str:
        """
        Convert codon ids (>= num_special_tokens) back to DNA string.
        Special ids are ignored unless they collide (they don't).
        """
        parts: List[str] = []
        nst = self._num_special_tokens
        for tid in token_ids:
            if tid >= nst:
                tok = self.ids_to_tokens.get(tid)
                if tok is not None:  # should always be a codon
                    parts.append(tok)
        return "".join(parts)

    def decode(self, token_ids: List[int], skip_special_tokens: bool = True, **_: Any) -> str:
        """HF-style alias for decode_codon_seq (kept for API parity)."""
        if skip_special_tokens:
            token_ids = [t for t in token_ids if t >= self._num_special_tokens]
        return self.decode_codon_seq(token_ids)

    # ---------- misc helpers ----------
    def codon_vocab(self) -> Dict[str, int]:
        """Return only the codon -> id portion of the vocabulary."""
        return {c: self.vocab[c] for c in self.codons}

    def codon2aa_char_map(self) -> Dict[int, str]:
        """Copy of the codon-id -> amino-acid-char map."""
        return dict(self._codon2aa_char)

    def aa2codons_char_map(self) -> Dict[str, List[int]]:
        """Copy of the amino-acid-char -> [codon-ids] map (lists copied too)."""
        return {k: v[:] for k, v in self._aa2codons_char.items()}

    def aa_to_codon_length(self, aa_seq: str) -> int:
        # You don't count stop unless it's explicitly there.
        return len(aa_seq)

    # HF compatibility stubs (your code calls these in a few places)
    def _tokenize(self, text: str) -> List[str]:
        if len(text) % 3 != 0:
            raise ValueError(f"Text length {len(text)} not divisible by 3")
        return [text[i : i + 3] for i in range(0, len(text), 3)]

    def _convert_token_to_id(self, token: str) -> int:
        return self.vocab.get(token, self._special_ids.unk)

    def _convert_id_to_token(self, index: int) -> str:
        return self.ids_to_tokens.get(index, self._special_token_str["unk"])

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        return "".join(tokens)

    def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]:
        return token_ids_0

    def create_token_type_ids_from_sequences(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]:
        return [0] * len(token_ids_0)

    # ---------- persistence ----------
    def get_vocab(self) -> Dict[str, int]:
        return dict(self.vocab)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Save to JSON with both vocab and special token strings so we can
        reconstruct IDs exactly. Deterministic and stable.
        """
        os.makedirs(save_directory, exist_ok=True)
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )
        payload = {
            "vocab": self.vocab,
            "special_token_str": self._special_token_str,
        }
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False, indent=2, sort_keys=True)
        return (vocab_file,)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs: Any) -> "CodonTokenizer":
        """
        Load from a directory containing vocab.json produced by save_vocabulary().
        We rebuild the SpecialIds from the saved token strings to keep IDs stable.
        """
        vocab_path = Path(pretrained_model_name_or_path) / "vocab.json"
        tok = cls(**kwargs)  # default structure; we'll overwrite below
        if not vocab_path.exists():
            # If nothing to load, return defaults. It keeps the rest of your code happy.
            return tok

        with open(vocab_path, "r", encoding="utf-8") as f:
            save_data = json.load(f)

        if not isinstance(save_data, dict) or "vocab" not in save_data:
            # Old, dumber format: the whole file was the vocab dict
            vocab = save_data
            special_token_str = tok._special_token_str
        else:
            vocab = save_data["vocab"]
            special_token_str = save_data.get("special_token_str", tok._special_token_str)

        # rebuild maps
        tok.vocab = {str(k): int(v) for k, v in vocab.items()}
        tok.ids_to_tokens = {int(v): str(k) for k, v in tok.vocab.items()}

        # reconcile special strings → ids
        if isinstance(special_token_str, dict):
            tok._special_token_str.update({k: v for k, v in special_token_str.items() if k in ("pad", "unk", "bos", "eos")})

        def _id_for(name: str, default_val: int) -> int:
            sym = tok._special_token_str[name]
            return int(tok.vocab.get(sym, default_val))

        tok._special_ids = SpecialIds(
            pad=_id_for("pad", 0),
            unk=_id_for("unk", 1),
            bos=_id_for("bos", 2),
            eos=_id_for("eos", 3),
        )

        # Figure out how many specials to reserve. If the saved mapping had extra junk,
        # we still preserve a contiguous prefix if present. Otherwise default to 4.
        ids = [tok._special_ids.pad, tok._special_ids.unk, tok._special_ids.bos, tok._special_ids.eos]
        m = max(ids)
        tok._num_special_tokens = m + 1 if ids == list(range(m + 1)) else 4

        # Rebuild genetic helpers (cheap)
        tok._rebuild_helpers()
        return tok

    # internal: rebuild helper maps after load
    def _rebuild_helpers(self) -> None:
        """(Re)build the codon-id <-> amino-acid maps from the current vocab.

        Called both from __init__ and after from_pretrained() swaps the vocab,
        so the construction logic lives in exactly one place.
        """
        self._codon2aa_char = {}
        self._aa2codons_char = {ch: [] for ch in "ACDEFGHIKLMNPQRSTVWY*"}
        for codon in self.codons:
            cid = self.vocab.get(codon)
            if cid is None:
                # Tolerate vocabs saved without the full codon set instead of
                # raising KeyError mid-load; those codons simply get no mapping.
                continue
            aa = self._genetic_code.get(codon, "X")
            self._codon2aa_char[cid] = aa
            if aa in self._aa2codons_char:
                self._aa2codons_char[aa].append(cid)
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
# ------------------------------
|
| 316 |
+
# small helpers
|
| 317 |
+
# ------------------------------
|
| 318 |
+
|
| 319 |
+
def _is_acgt(s: str) -> bool:
|
| 320 |
+
# Faster than regex for short strings.
|
| 321 |
+
for ch in s:
|
| 322 |
+
if ch not in ("A", "C", "G", "T"):
|
| 323 |
+
return False
|
| 324 |
+
return True
|
src/trainer.py
ADDED
|
@@ -0,0 +1,1230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/trainer.py
|
| 2 |
+
"""
|
| 3 |
+
FSDP trainer for CodonGPT.
|
| 4 |
+
No frameworks, no sugar. The model computes its own loss.
|
| 5 |
+
|
| 6 |
+
Batch invariants:
|
| 7 |
+
- codon_ids [B, T] (right-padded; EOS already in-sequence)
|
| 8 |
+
- species_ids [B] (SpeciesEmbeddingStore provides fixed-size or sequence embeddings)
|
| 9 |
+
- protein_seqs: list[str] (ESM tokenization happens inside the model)
|
| 10 |
+
|
| 11 |
+
Rules:
|
| 12 |
+
- If your loader is IterableDataset, you MUST set args.max_steps > 0. We don't guess.
|
| 13 |
+
- If you want epoch-based, use a sized dataset; we call len(dataloader).
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import os
|
| 19 |
+
import json
|
| 20 |
+
import math
|
| 21 |
+
import re
|
| 22 |
+
import shutil
|
| 23 |
+
import logging
|
| 24 |
+
import time
|
| 25 |
+
from dataclasses import dataclass
|
| 26 |
+
import datetime
|
| 27 |
+
import warnings
|
| 28 |
+
import importlib.util
|
| 29 |
+
import inspect
|
| 30 |
+
from typing import Any, Callable, Dict, Optional, Tuple, List
|
| 31 |
+
from tqdm import tqdm
|
| 32 |
+
import torch
|
| 33 |
+
import torch.nn as nn
|
| 34 |
+
import torch.distributed as dist
|
| 35 |
+
from torch.utils.data import DataLoader, IterableDataset
|
| 36 |
+
|
| 37 |
+
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
| 38 |
+
from torch.distributed.fsdp import (
|
| 39 |
+
ShardingStrategy,
|
| 40 |
+
MixedPrecision,
|
| 41 |
+
StateDictType,
|
| 42 |
+
FullStateDictConfig,
|
| 43 |
+
FullOptimStateDictConfig,
|
| 44 |
+
)
|
| 45 |
+
from safetensors.torch import save_file, load_file
|
| 46 |
+
import wandb
|
| 47 |
+
|
| 48 |
+
logger = logging.getLogger(__name__)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# ------------------------------
|
| 52 |
+
# Args
|
| 53 |
+
# ------------------------------
|
| 54 |
+
|
| 55 |
+
@dataclass
class TrainingArguments:
    """Flat configuration record for Trainer; plain fields, no validation here."""

    # Output / checkpointing
    output_dir: str = "checkpoints"
    save_steps: int = 1000
    save_total_limit: int = 3
    save_safetensors: bool = True
    # Rolling/archival checkpoint knobs; 0 disables each behavior.
    ckpt_recent_window_steps: int = 0
    ckpt_recent_interval: int = 0
    ckpt_archive_interval: int = 0

    # Schedule
    num_train_epochs: int = 1
    max_steps: int = -1  # required (> 0) when the train dataset is an IterableDataset
    gradient_accumulation_steps: int = 1
    warmup_ratio: float = 0.0
    lr_scheduler_type: str = "cosine"  # "linear" | "cosine" | "constant"
    # For streaming datasets: if max_steps<0 and steps_per_epoch>0, shape schedule using
    # total_steps = num_train_epochs * steps_per_epoch
    steps_per_epoch: int = 0

    # Optimizer (AdamW-style betas; beta2=0.95 is the common LLM setting)
    learning_rate: float = 5e-4
    weight_decay: float = 0.0
    adam_beta1: float = 0.9
    adam_beta2: float = 0.95
    max_grad_norm: float = 1.0

    # Data loading
    per_device_train_batch_size: int = 8
    per_device_eval_batch_size: int = 8
    dataloader_num_workers: int = 0

    # Precision / distributed
    fp16: bool = False
    bf16: bool = False
    fsdp: Optional[str] = None  # "full_shard" or None
    gradient_checkpointing: bool = False

    # Global hard cap on token length (prefix + start + codon)
    max_length: int = 4096

    # ESM settings (metadata only; the model owns the ESM instance)
    esm_model_name: str = "esmc_300m"
    esm_device: str = "cuda"
    esm_dtype: str = "bf16"

    # Logging / eval
    logging_steps: int = 100
    eval_steps: int = 0  # streaming eval: limit number of eval batches when eval dataset is Iterable
    eval_interval: int = 0  # run evaluation every N optimizer steps (0 disables)
    override_lr_on_resume: bool = False
    # Minimal data stream resume cursor (stores total samples yielded so far for train dataset).
    # When provided, we load 'skip_samples' from this JSON at start and set the dataset
    # to skip exactly that many samples on resume. We also update the file in _save_checkpoint().
    data_cursor_path: Optional[str] = None
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
# ------------------------------
|
| 114 |
+
# Trainer
|
| 115 |
+
# ------------------------------
|
| 116 |
+
|
| 117 |
+
class Trainer:
|
| 118 |
+
    def __init__(
        self,
        model: nn.Module,
        args: TrainingArguments,
        data_collator: Optional[Callable] = None,
        train_dataset: Optional[Any] = None,
        eval_dataset: Optional[Any] = None,
        tokenizer: Optional[Any] = None,
        model_init: Optional[Callable[[], nn.Module]] = None,
        compute_metrics: Optional[Callable] = None,
        callbacks: Optional[list] = None,
        optimizers: Tuple[Optional[torch.optim.Optimizer], Optional[Any]] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable] = None,
        species_store=None,
        resume_from_checkpoint: Optional[str] = None,
    ):
        """Set up device, (optional) FSDP wrapping, and AMP state.

        Dataloaders are NOT created here — call attach_dataloaders() afterwards.
        Several HF-Trainer-shaped parameters (data_collator, train_dataset,
        eval_dataset, model_init, compute_metrics, callbacks,
        preprocess_logits_for_metrics) are accepted for signature parity but
        are not stored by this constructor.
        """
        self.model = model
        self.args = args
        self.tokenizer = tokenizer
        # optimizers is an (optimizer, lr_scheduler) pair; either may be None
        # and be created later elsewhere.
        self.optimizer = optimizers[0]
        self.lr_scheduler = optimizers[1]
        self.species_store = species_store

        # Populated by attach_dataloaders().
        self.train_dataloader: Optional[DataLoader] = None
        self.eval_dataloader: Optional[DataLoader] = None

        # Device (robust local rank resolution): prefer LOCAL_RANK, else derive
        # from global RANK modulo the visible GPU count (covers launchers that
        # only export RANK/WORLD_SIZE).
        self.local_rank = 0
        if torch.cuda.is_available():
            lr_env = os.environ.get("LOCAL_RANK")
            if lr_env is not None:
                self.local_rank = int(lr_env)
            else:
                r = int(os.environ.get("RANK", "0"))
                ng = max(1, torch.cuda.device_count())
                self.local_rank = (r % ng)
            self.device = torch.device(f"cuda:{self.local_rank}")
            torch.cuda.set_device(self.device)
            cd = torch.cuda.current_device()
            nm = torch.cuda.get_device_name(cd)
            logger.info(
                f"[dist] RANK={os.environ.get('RANK')} LOCAL_RANK={os.environ.get('LOCAL_RANK')} WORLD_SIZE={os.environ.get('WORLD_SIZE')} "
                f"cuda.count={torch.cuda.device_count()} select={self.device} current={cd} name={nm}"
            )
        else:
            self.device = torch.device("cpu")

        # Gradient checkpointing toggle (model owns the flag); must be set on
        # the unwrapped module, before any FSDP wrapping below.
        base = self._unwrap(self.model)
        if self.args.gradient_checkpointing and hasattr(base, "gradient_checkpointing"):
            base.gradient_checkpointing = True

        # FSDP or single GPU — _setup_fsdp() is expected to move/wrap the model
        # itself; otherwise we place it on the resolved device here.
        if self.args.fsdp:
            self._setup_fsdp()
        else:
            self.model = self.model.to(self.device)

        # AMP setup (use torch.amp APIs; GradScaler on CUDA only).
        # NOTE: fp16 wins over bf16 when both flags are set.
        self._use_amp = (self.device.type == "cuda") and (self.args.fp16 or self.args.bf16)
        self._amp_dtype = torch.float16 if self.args.fp16 else (torch.bfloat16 if self.args.bf16 else None)
        use_cuda = (self.device.type == "cuda")
        # Scaler is only enabled for fp16; bf16 does not need loss scaling.
        self._scaler = torch.amp.GradScaler(device="cuda", enabled=(use_cuda and self.args.fp16))

        # Minimal training progress state (checkpointed elsewhere).
        self.state = {"epoch": 0, "global_step": 0}

        # Defer resume until after dataloaders are attached so scheduler can be shaped.
        self._resume_path: Optional[str] = resume_from_checkpoint
|
| 186 |
+
|
| 187 |
+
# ---- dataloaders ----
|
| 188 |
+
def attach_dataloaders(self, train_loader: DataLoader, eval_loader: Optional[DataLoader] = None):
    """Attach train/eval dataloaders and, if configured, apply a resume cursor.

    If ``args.data_cursor_path`` points at an existing JSON file, the total
    number of already-consumed samples is read from it and forwarded to the
    training dataset via its ``set_resume_skip`` hook, split evenly across the
    current world size (remainder goes to the lowest ranks).

    Args:
        train_loader: Training dataloader. The dataset itself is expected to
            handle distributed sharding; no DistributedSampler is added here.
        eval_loader: Optional held-out dataloader used by ``evaluate()``.
    """
    # Your dataset should handle sharding. We don't wrap with DistributedSampler here.
    self.train_dataloader = train_loader
    self.eval_dataloader = eval_loader
    # Apply minimal resume cursor to the training dataset if configured
    p = getattr(self.args, "data_cursor_path", None)
    if p and os.path.exists(p):
        with open(p, "r") as f:
            js = json.load(f)
        ds = getattr(self.train_dataloader, "dataset", None)
        if hasattr(ds, "set_resume_skip"):
            distributed = dist.is_available() and dist.is_initialized()
            world = dist.get_world_size() if distributed else 1
            rank = dist.get_rank() if distributed else 0

            # Prefer the total cursor and split evenly across current world size.
            # If total is missing, sum any saved per_rank list.
            total: int = 0
            if isinstance(js, dict):
                try:
                    total = int(js.get("skip_samples", 0) or 0)
                except Exception:
                    total = 0
                if total <= 0:
                    # Fallback: legacy checkpoints stored per-rank counts; their
                    # sum is the global cursor regardless of old world size.
                    raw = js.get("per_rank")
                    if isinstance(raw, list) and raw:
                        try:
                            total = int(sum(int(x) for x in raw))
                        except Exception:
                            total = 0

            if total > 0:
                if distributed:
                    # Even split; the first `rem` ranks absorb one extra sample.
                    per = total // max(world, 1)
                    rem = total % max(world, 1)
                    n_rank = per + (1 if rank < rem else 0)
                    ds.set_resume_skip(int(n_rank))
                    if self._is_main():
                        logger.info(
                            "resume cursor: total=%s split across world=%s → rank=%s skip=%s",
                            total, world, rank, n_rank,
                        )
                else:
                    ds.set_resume_skip(int(total))
                    if self._is_main():
                        logger.info("resume cursor: total=%s (single-process) skip=%s", total, total)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
# ---- optim + scheduler ----
|
| 237 |
+
def _create_optimizer_and_scheduler(self):
    """Create the AdamW optimizer (if absent) and shape the LR schedule.

    Parameters are split into two groups: biases and normalization weights
    get ``weight_decay=0.0``; everything else uses ``args.weight_decay``.
    Fused AdamW is requested when CUDA is available and the installed torch
    supports the ``fused`` kwarg.

    The scheduler total is derived from ``max_steps``/``steps_per_epoch``
    for streaming datasets, or from ``len(train_dataloader)`` for sized
    ones; a constant schedule is used when the epoch size is unknown or
    ``lr_scheduler_type == "constant"``.
    """
    if self.optimizer is None:
        decay, no_decay = [], []
        for n, p in self._unwrap(self.model).named_parameters():
            if not p.requires_grad:
                continue
            # Heuristic no-decay bucket: biases and anything that looks like a
            # normalization parameter ("norm" / "ln_" in the name).
            if n.endswith("bias") or "norm" in n.lower() or "ln_" in n.lower():
                no_decay.append(p)
            else:
                decay.append(p)

        opt_kwargs = dict(
            lr=self.args.learning_rate,
            betas=(self.args.adam_beta1, self.args.adam_beta2),
        )
        params = [
            {"params": decay, "weight_decay": self.args.weight_decay},
            {"params": no_decay, "weight_decay": 0.0},
        ]
        # Feature-detect the fused kernel so older torch versions still work.
        sig_adamw = inspect.signature(torch.optim.AdamW)
        if torch.cuda.is_available() and "fused" in sig_adamw.parameters:
            opt_kwargs["fused"] = True  # type: ignore[assignment]
        self.optimizer = torch.optim.AdamW(params, **opt_kwargs)
        # Report fused/foreach settings (rank0 only)
        if self._is_main():
            fused_flag = None
            foreach_flag = None
            if hasattr(self.optimizer, "defaults"):
                fused_flag = self.optimizer.defaults.get("fused")
                foreach_flag = self.optimizer.defaults.get("foreach")
            logger.info(f"AdamW configured: fused={fused_flag} foreach={foreach_flag}")

    # total steps and schedule shape
    ds = getattr(self.train_dataloader, "dataset", None)
    ga = max(1, self.args.gradient_accumulation_steps)
    if isinstance(ds, IterableDataset):
        if self.args.max_steps > 0:
            # Use max_steps to shape the scheduler; allow multiple epochs to re-iterate the stream
            steps_per_epoch = self.args.max_steps
            total_steps = self.args.max_steps
        elif getattr(self.args, "steps_per_epoch", 0) and self.args.steps_per_epoch > 0:
            # steps_per_epoch is already expressed in optimizer steps (train.py accounts for grad_accum)
            steps_per_epoch = max(1, int(self.args.steps_per_epoch))
            total_steps = max(1, self.args.num_train_epochs) * steps_per_epoch
        else:
            # Unknown epoch size; use constant LR without pre-shaped schedule
            self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lambda step: 1.0)
            return
    else:
        # sized dataloader: len(dataloader) is number of batches
        steps_per_epoch = max(len(self.train_dataloader) // ga, 1)
        total_steps = self.args.max_steps if self.args.max_steps > 0 else self.args.num_train_epochs * steps_per_epoch

    warmup = int(self.args.warmup_ratio * total_steps)

    if self.args.lr_scheduler_type == "constant":
        self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lambda step: 1.0)
        return

    def lrs_lambda(step: int) -> float:
        # Linear warmup (floored at 1e-6 so step 0 is not exactly zero LR),
        # then linear or cosine decay over the remaining steps.
        if step < warmup:
            return max(float(step) / max(warmup, 1), 1e-6)
        t = (step - warmup) / max(total_steps - warmup, 1)
        if self.args.lr_scheduler_type == "linear":
            return max(1.0 - t, 0.0)
        # cosine default
        return 0.5 * (1.0 + math.cos(math.pi * t))

    self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lrs_lambda)
|
| 306 |
+
|
| 307 |
+
# ---- training ----
|
| 308 |
+
def train(self) -> Dict[str, float]:
    """Run the full training loop and return final train metrics.

    Handles: deferred checkpoint resume, optimizer/scheduler creation,
    W&B init (rank 0), AMP (GradScaler for fp16), gradient accumulation,
    gradient clipping (FSDP-aware), per-epoch step budgets for streaming
    datasets (with stream refills), periodic logging/eval/checkpointing,
    and hard ``max_steps`` / ``target_total_steps`` horizons.

    Returns:
        Dict with ``train_loss`` (running average at the stopping point, or
        0.0 when exiting via the epoch loop's natural end).
    """
    assert self.train_dataloader is not None, "Call attach_dataloaders() first"
    # If a resume path was provided, load it now (dataloaders are attached).
    if getattr(self, "_resume_path", None):
        self._resume_from(self._resume_path)  # loads model/optimizer/scheduler/state
        self._resume_path = None

    if self.optimizer is None:
        self._create_optimizer_and_scheduler()

    ds = self.train_dataloader.dataset

    # Exact step budget for streaming datasets when max_steps<0 and steps_per_epoch>0
    target_total_steps: Optional[int] = None
    if isinstance(ds, IterableDataset) and int(self.args.max_steps) < 0:
        spe = int(getattr(self.args, "steps_per_epoch", 0) or 0)
        if spe > 0:
            target_total_steps = max(1, int(self.args.num_train_epochs)) * spe

    # Determine total steps for progress bar
    progress_total: Optional[int] = None
    if int(self.args.max_steps) > 0:
        progress_total = int(self.args.max_steps)
    elif isinstance(ds, IterableDataset):
        if target_total_steps is not None:
            progress_total = target_total_steps
    else:
        ga = max(1, self.args.gradient_accumulation_steps)
        steps_per_epoch = max(len(self.train_dataloader) // ga, 1)
        progress_total = max(1, int(self.args.num_train_epochs)) * steps_per_epoch

    # Initialize Weights & Biases (rank0 only)
    if self._is_main():
        if not hasattr(self, "_wandb"):
            proj = os.environ.get("WANDB_PROJECT", "codongpt")
            name = os.environ.get("WANDB_NAME")
            run_id = os.environ.get("WANDB_RUN_ID")
            resume = os.environ.get("WANDB_RESUME")
            wandb_dir = os.environ.get("WANDB_DIR")
            world_size = dist.get_world_size() if dist.is_available() and dist.is_initialized() else int(os.environ.get("WORLD_SIZE", "1"))
            init_kwargs = {
                "project": proj,
                "name": name,
                "config": {
                    "lr": self.args.learning_rate,
                    "warmup_ratio": self.args.warmup_ratio,
                    "scheduler": self.args.lr_scheduler_type,
                    "batch_size": self.args.per_device_train_batch_size,
                    "eval_batch_size": self.args.per_device_eval_batch_size,
                    "grad_accum": self.args.gradient_accumulation_steps,
                    "effective_global_batch": self.args.per_device_train_batch_size * max(1, world_size) * max(1, self.args.gradient_accumulation_steps),
                    "epochs": self.args.num_train_epochs,
                    "steps_per_epoch": getattr(self.args, "steps_per_epoch", 0),
                    "max_steps": self.args.max_steps,
                    "weight_decay": self.args.weight_decay,
                    "world_size": world_size,
                    "output_dir": self.args.output_dir,
                    "fsdp": self.args.fsdp,
                    "bf16": self.args.bf16,
                    "fp16": self.args.fp16,
                },
            }
            if run_id:
                init_kwargs["id"] = run_id
            if resume:
                init_kwargs["resume"] = resume
            if wandb_dir:
                init_kwargs["dir"] = wandb_dir
            self._wandb = wandb.init(**init_kwargs)

    self.model.train()
    grad_accum = max(1, self.args.gradient_accumulation_steps)
    progress = None
    if self._is_main() and progress_total is not None and progress_total > 0:
        progress = tqdm(total=progress_total, initial=int(self.state["global_step"]), desc="Train", dynamic_ncols=True)
    if self.device.type == "cuda" and torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats(self.device)
    world_size = dist.get_world_size() if dist.is_available() and dist.is_initialized() else int(os.environ.get("WORLD_SIZE", "1"))
    # Sequences consumed per optimizer step across all ranks (throughput denominator).
    seqs_per_optimizer_step = (
        int(self.args.per_device_train_batch_size) * max(1, world_size) * grad_accum
    )
    log_window_start = time.perf_counter()
    log_window_optimizer_steps = 0

    for epoch in range(self.state["epoch"], max(1, self.args.num_train_epochs)):
        self.state["epoch"] = epoch
        running_loss = 0.0
        running_count = 0

        train_iter = iter(self.train_dataloader)
        step = 0
        batches_this_epoch = 0
        optimizer_steps_this_epoch = 0
        # If this is a streaming dataset with a shaped schedule, enforce a per-epoch optimizer step budget
        enforce_budget = False
        epoch_budget = None
        ds = self.train_dataloader.dataset
        if isinstance(ds, IterableDataset):
            spe = int(getattr(self.args, "steps_per_epoch", 0) or 0)
            if spe > 0:
                enforce_budget = True
                epoch_budget = int(spe)

        refill_attempts = 0
        max_refills = 64  # avoids infinite loops when dataset is empty

        while True:
            batch, has_batch, local_has_batch = self._next_batch_sync(train_iter)
            if not has_batch:
                # If budget-enforced, attempt to refill the iterator and continue until budget is met.
                if enforce_budget and (epoch_budget is not None) and (optimizer_steps_this_epoch < epoch_budget):
                    if local_has_batch and self._is_main():
                        logger.warning("Rank retained extra batch while peers exhausted stream; dropping to stay in sync")
                    self._barrier()
                    train_iter = iter(self.train_dataloader)
                    refill_attempts += 1
                    if refill_attempts > max_refills:
                        if self._is_main():
                            logger.warning(
                                "Exceeded max refills for epoch %s (steps %s/%s). Ending epoch early.",
                                epoch, optimizer_steps_this_epoch, epoch_budget,
                            )
                        break
                    continue
                else:
                    if local_has_batch and self._is_main():
                        logger.warning("Rank retained extra batch while peers exhausted stream; dropping to stay in sync")
                    break

            batch = self._prepare_batch(batch)
            batches_this_epoch += 1

            codon_ids = batch["codon_ids"].to(self.device)
            # NOTE(review): labels use the SAME slice as input_ids (not shifted
            # by one) — presumably the model shifts internally when computing
            # the causal LM loss; confirm against the model's forward().
            input_ids = codon_ids[:, :-1]
            labels = codon_ids[:, :-1]

            # Mask PAD/EOS in labels
            pad_id = int(self.tokenizer.pad_token_id) if self.tokenizer is not None else 0
            eos_id = int(self.tokenizer.special_ids.eos) if self.tokenizer is not None else -999
            labels = labels.clone()
            labels[labels == pad_id] = -100
            labels[labels == eos_id] = -100

            cond = self._build_cond(batch)

            # autocast context
            use_cuda = (self.device.type == "cuda")
            autocast_dtype = self._amp_dtype
            if autocast_dtype is not None and use_cuda:
                ctx = torch.amp.autocast(device_type="cuda", dtype=autocast_dtype)
            else:
                from contextlib import nullcontext
                ctx = nullcontext()

            with ctx:
                out = self.model(codon_ids=input_ids, cond=cond, labels=labels, return_dict=True)
                loss = out["loss"]

            # Scale by grad_accum so accumulated gradients average, not sum.
            if self._scaler.is_enabled():
                self._scaler.scale(loss / grad_accum).backward()
            else:
                (loss / grad_accum).backward()

            running_loss += float(loss.detach().item())
            running_count += 1

            do_step = ((step + 1) % grad_accum == 0)
            if do_step:
                # Clip
                if self.args.max_grad_norm and self.args.max_grad_norm > 0:
                    if isinstance(self.model, FSDP):
                        # FSDP needs its own clipping over sharded grads.
                        FSDP.clip_grad_norm_(self.model, self.args.max_grad_norm)
                    else:
                        if self._scaler.is_enabled():
                            # Unscale before clipping so the norm is in true units.
                            self._scaler.unscale_(self.optimizer)
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)

                # Step
                if self._scaler.is_enabled():
                    self._scaler.step(self.optimizer)
                    self._scaler.update()
                else:
                    self.optimizer.step()
                if self.lr_scheduler is not None:
                    self.lr_scheduler.step()
                self.optimizer.zero_grad(set_to_none=True)
                self.state["global_step"] += 1
                optimizer_steps_this_epoch += 1
                log_window_optimizer_steps += 1

                # (wandb) Defer logging to the periodic block below

                # Log
                should_log = (self.state["global_step"] % max(1, self.args.logging_steps) == 0)
                peak_alloc_gb = 0.0
                peak_reserved_gb = 0.0
                if should_log:
                    # Collective MAX across ranks — must run on every rank.
                    peak_alloc_gb, peak_reserved_gb = self._max_cuda_peak_gb()
                if self._is_main() and should_log:
                    avg = running_loss / max(running_count, 1)
                    lr = float(self.optimizer.param_groups[0]["lr"])
                    log_epoch = self._epoch_for_logging()
                    elapsed = max(time.perf_counter() - log_window_start, 1e-9)
                    step_time_s = elapsed / max(log_window_optimizer_steps, 1)
                    seq_per_s = (seqs_per_optimizer_step * max(log_window_optimizer_steps, 1)) / elapsed
                    msg = f"epoch {log_epoch} step {self.state['global_step']}: loss={avg:.4f} lr={lr:.6g}"
                    if isinstance(out, dict):
                        pl = out.get("prefix_len")
                        pc = out.get("per_cap")
                        if pl is not None and pc is not None:
                            msg += f" prefix_mean={float(pl.detach().float().mean().item()):.1f} cap_mean={float(pc.detach().float().mean().item()):.1f}"
                    msg += (
                        f" step_time_s={step_time_s:.3f} seq_per_s={seq_per_s:.1f}"
                        f" peak_mem_alloc_gb={peak_alloc_gb:.1f} peak_mem_reserved_gb={peak_reserved_gb:.1f}"
                    )
                    logger.info(msg)
                    if hasattr(self, "_wandb"):
                        wandb.log({
                            "train/loss": float(avg),
                            "train/lr": float(lr),
                            "perf/step_time_s": float(step_time_s),
                            "perf/seq_per_s": float(seq_per_s),
                            "system/peak_mem_alloc_gb": float(peak_alloc_gb),
                            "system/peak_mem_reserved_gb": float(peak_reserved_gb),
                        }, step=self.state["global_step"])
                    running_loss = 0.0
                    running_count = 0
                    log_window_start = time.perf_counter()
                    log_window_optimizer_steps = 0

                # Update progress bar
                if progress is not None:
                    progress.update(1)

                # Stop when budget is reached for streaming schedule
                if target_total_steps is not None and self.state["global_step"] >= target_total_steps:
                    metrics = {"train_loss": running_loss / max(running_count, 1)}
                    self._save_checkpoint("final_model")
                    self._barrier()
                    return metrics

                # Periodic teacher-forced evaluation on the held-out dataset
                should_eval = (
                    self.eval_dataloader is not None and
                    self.args.eval_interval > 0 and
                    (self.state["global_step"] % self.args.eval_interval == 0)
                )
                if should_eval:
                    eval_metrics = self.evaluate()
                    if self._is_main():
                        el = float(eval_metrics.get("eval_loss", 0.0))
                        ea = eval_metrics.get("eval_codon_acc", None)
                        aa = eval_metrics.get("eval_aa_acc", None)
                        if ea is not None and aa is not None:
                            logger.info(f"eval: loss={el:.4f} codon_acc={float(ea):.3f} aa_acc={float(aa):.3f}")
                        elif ea is not None:
                            logger.info(f"eval: loss={el:.4f} codon_acc={float(ea):.3f}")
                        elif aa is not None:
                            logger.info(f"eval: loss={el:.4f} aa_acc={float(aa):.3f}")
                        else:
                            logger.info(f"eval: loss={el:.4f}")
                        if hasattr(self, "_wandb"):
                            log_payload = {"eval/loss": el}
                            if ea is not None:
                                log_payload["eval/codon_acc"] = float(ea)
                            if aa is not None:
                                log_payload["eval/aa_acc"] = float(aa)
                            wandb.log(log_payload, step=self.state["global_step"])

                # Save by step
                if self.args.save_steps > 0 and (self.state["global_step"] % self.args.save_steps == 0):
                    self._save_checkpoint(f"checkpoint-{self.state['global_step']}")

                # Hard horizon for streaming/step-limited runs
                if self.args.max_steps > 0 and self.state["global_step"] >= self.args.max_steps:
                    metrics = {"train_loss": running_loss / max(running_count, 1)}
                    self._save_checkpoint("final_model")
                    self._barrier()
                    if progress is not None:
                        progress.close()
                    return metrics

            step += 1

            # If we enforce a per-epoch budget for streaming datasets, end the epoch once it's reached
            if enforce_budget and (epoch_budget is not None) and (optimizer_steps_this_epoch >= epoch_budget):
                break

        # Epoch summary (rank0 only)
        if self._is_main():
            try:
                eb = int(epoch_budget) if epoch_budget is not None else -1
            except Exception:
                eb = -1
            logger.info(
                "epoch %s completed: optimizer_steps=%s%s",
                self._epoch_for_logging(),
                optimizer_steps_this_epoch,
                (f" / budget {eb}" if eb > 0 else ""),
            )

        # Cross-rank imbalance audit: gather per-rank batch/step counts and
        # warn if any rank diverged from the others this epoch.
        if dist.is_available() and dist.is_initialized():
            gather_device = self.device if self.device.type == "cuda" else torch.device("cpu")
            counts_tensor = torch.tensor(
                [batches_this_epoch, optimizer_steps_this_epoch],
                dtype=torch.long,
                device=gather_device,
            )
            gathered = [torch.zeros_like(counts_tensor) for _ in range(dist.get_world_size())]
            dist.all_gather(gathered, counts_tensor)
            batch_counts = [int(t[0].item()) for t in gathered]
            step_counts = [int(t[1].item()) for t in gathered]
            batch_gap = max(batch_counts) - min(batch_counts)
            step_gap = max(step_counts) - min(step_counts)
            if self._is_main() and (batch_gap > 0 or step_gap > 0):
                logger.warning(
                    "Epoch %s imbalance detected across ranks: batches min=%s max=%s, optimizer steps min=%s max=%s",
                    epoch,
                    min(batch_counts),
                    max(batch_counts),
                    min(step_counts),
                    max(step_counts),
                )

        # Epoch boundary save for sized datasets
        if not isinstance(ds, IterableDataset):
            self._save_checkpoint(f"epoch-{epoch}")

    metrics = {"train_loss": 0.0}
    if progress is not None:
        progress.close()
    self._barrier()
    return metrics
|
| 641 |
+
|
| 642 |
+
# ---- evaluation ----
|
| 643 |
+
def evaluate(self) -> Dict[str, float]:
    """Teacher-forced evaluation on the held-out dataloader.

    Computes a token-weighted eval loss, greedy codon-level accuracy over
    supervised (non-special, within-cap) positions, and — when the tokenizer
    exposes a codon→amino-acid map and the batch carries ``protein_seqs`` —
    amino-acid-level accuracy of the predicted codons' translations.
    All counters are SUM-reduced across ranks before averaging.

    Returns:
        Dict with ``eval_loss`` and, when any positions were counted,
        ``eval_codon_acc`` / ``eval_aa_acc``.
    """
    if self.eval_dataloader is None:
        return {"eval_loss": 0.0}

    self.model.eval()

    loss_sum = 0.0
    loss_tokens = 0
    codon_correct = 0
    codon_total = 0
    aa_correct = 0
    aa_total = 0

    tok = self.tokenizer
    pad_id = int(tok.pad_token_id) if tok is not None else 0
    eos_id = int(tok.special_ids.eos) if tok is not None and hasattr(tok, "special_ids") else -999
    num_special = int(tok.num_special_tokens) if tok is not None else 0
    codon2aa = tok.codon2aa_char_map() if tok is not None and hasattr(tok, "codon2aa_char_map") else {}

    # Streaming eval sets have no length; cap the batch count via eval_steps.
    is_streaming = isinstance(self.eval_dataloader.dataset, IterableDataset)
    max_batches = int(self.args.eval_steps) if (is_streaming and self.args.eval_steps > 0) else None

    with torch.no_grad():
        eval_iter = iter(self.eval_dataloader)
        b_idx = 0
        while True:
            batch, has_batch, local_has_batch = self._next_batch_sync(eval_iter)
            if not has_batch:
                if local_has_batch and self._is_main():
                    logger.debug("eval dataloader: discarded tail batch to stay in sync across ranks")
                break

            if max_batches is not None and b_idx >= max_batches:
                break

            batch = self._prepare_batch(batch)

            codon_ids = batch["codon_ids"].to(self.device)
            # NOTE(review): same unshifted labels convention as train() —
            # presumably the model shifts internally; confirm there.
            input_ids = codon_ids[:, :-1]
            labels = codon_ids[:, :-1]

            labels = labels.clone()
            labels[labels == pad_id] = -100
            labels[labels == eos_id] = -100

            cond = self._build_cond(batch)

            use_cuda = (self.device.type == "cuda")
            autocast_dtype = self._amp_dtype
            if autocast_dtype is not None and use_cuda:
                ctx = torch.amp.autocast(device_type="cuda", dtype=autocast_dtype)
            else:
                from contextlib import nullcontext
                ctx = nullcontext()

            with ctx:
                out = self.model(codon_ids=input_ids, cond=cond, labels=labels, return_dict=True)

            loss = out.get("loss")
            per_cap = out.get("per_cap")
            logits = out.get("logits")

            # Weight the (mean) loss by supervised token count so the final
            # average is per-token, not per-batch.
            tokens_in_batch = 0
            if per_cap is not None:
                tokens_in_batch = int(torch.clamp(per_cap.detach(), min=0).sum().item())
                loss_tokens += tokens_in_batch

            if loss is not None and tokens_in_batch > 0:
                loss_sum += float(loss.detach().item()) * tokens_in_batch

            if logits is None or logits.size(1) == 0 or per_cap is None:
                continue

            max_cap = logits.size(1)
            batch_size = logits.size(0)

            # Align labels to the logits' time dimension, padding with the
            # ignore index, then mask out positions beyond each row's cap.
            labels_aligned = torch.full((batch_size, max_cap), -100, dtype=labels.dtype, device=labels.device)
            common_cols = min(labels.size(1), max_cap)
            if common_cols > 0:
                labels_aligned[:, :common_cols] = labels[:, :common_cols]

            per_cap_int = torch.clamp(per_cap.to(dtype=torch.long), min=0, max=max_cap)
            for row in range(batch_size):
                cap = int(per_cap_int[row].item())
                if cap < max_cap:
                    labels_aligned[row, cap:] = -100

            supervised = labels_aligned != -100
            if num_special > 0:
                # Exclude special-token ids from accuracy accounting.
                supervised = supervised & (labels_aligned >= num_special)
            if not supervised.any():
                continue

            preds = logits.argmax(dim=-1)
            codon_correct += int((preds[supervised] == labels_aligned[supervised]).sum().item())
            codon_total += int(supervised.sum().item())

            if codon2aa and isinstance(batch, dict) and "protein_seqs" in batch:
                prot_list = batch.get("protein_seqs", [])
                for row in range(batch_size):
                    cap = int(per_cap_int[row].item())
                    if cap <= 0:
                        continue
                    mask_row = supervised[row, :cap]
                    if not mask_row.any():
                        continue
                    preds_row = preds[row, :cap][mask_row]
                    prot = prot_list[row] if (isinstance(prot_list, list) and row < len(prot_list)) else ""
                    if not prot:
                        continue
                    seq_len = min(len(prot), preds_row.size(0))
                    if seq_len <= 0:
                        continue
                    # Translate predicted codons; unknown ids map to 'X'.
                    pred_aa = ''.join(codon2aa.get(int(t.item()), 'X') for t in preds_row[:seq_len])
                    truth_aa = prot[:seq_len]
                    aa_correct += sum(1 for i in range(seq_len) if pred_aa[i] == truth_aa[i])
                    aa_total += seq_len

            b_idx += 1

    totals = torch.tensor(
        [loss_sum, loss_tokens, codon_correct, codon_total, aa_correct, aa_total],
        dtype=torch.float64,
        device=self.device,
    )
    if dist.is_available() and dist.is_initialized():
        # Ensure every rank has finished its forward passes before the final
        # metric reduction, otherwise FSDP may still be issuing _all_gather
        # collectives on slower ranks.
        self._barrier()
        dist.all_reduce(totals, op=dist.ReduceOp.SUM)

    loss_sum, loss_tokens, codon_correct, codon_total, aa_correct, aa_total = totals.tolist()

    self.model.train()

    metrics: Dict[str, float] = {"eval_loss": float(loss_sum) / loss_tokens if loss_tokens > 0 else 0.0}
    if codon_total > 0:
        metrics["eval_codon_acc"] = float(codon_correct) / codon_total
    if aa_total > 0:
        metrics["eval_aa_acc"] = float(aa_correct) / aa_total

    self._barrier()
    return metrics
|
| 787 |
+
|
| 788 |
+
# ---- internals ----
|
| 789 |
+
def _setup_fsdp(self):
    """Wrap ``self.model`` in FSDP with FULL_SHARD and mixed precision.

    Initializes the default process group if needed (NCCL on CUDA, gloo on
    CPU, 30-minute timeout when supported). A frozen ``esm`` submodule, if
    present, is excluded from sharding and moved to the device manually.
    """
    # Ensure default process group is initialized (required by FSDP)
    device = self.device
    if dist.is_available() and not dist.is_initialized():
        backend = "nccl" if device.type == "cuda" else "gloo"
        # Older torch versions lack the timeout kwarg; feature-detect it.
        sig = inspect.signature(dist.init_process_group)
        if "timeout" in sig.parameters:
            dist.init_process_group(backend=backend, init_method="env://", timeout=datetime.timedelta(minutes=30))
        else:
            dist.init_process_group(backend=backend, init_method="env://")
    # Params/grads reduced in the AMP dtype; buffers kept in fp32.
    mp = MixedPrecision(
        param_dtype=(torch.float16 if self.args.fp16 else torch.bfloat16 if self.args.bf16 else torch.float32),
        reduce_dtype=(torch.float16 if self.args.fp16 else torch.bfloat16 if self.args.bf16 else torch.float32),
        buffer_dtype=torch.float32,
    )
    logger.info(f"FSDP enabled: sharding={self.args.fsdp} mp_param={mp.param_dtype} mp_reduce={mp.reduce_dtype}")
    # Keep frozen ESM off FSDP if present
    base = self._unwrap(self.model)
    ignored = []
    if hasattr(base, "esm") and isinstance(base.esm, nn.Module):
        ignored.append(base.esm)

    self.model = FSDP(
        self.model,
        device_id=(self.device if device.type == "cuda" else None),
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        mixed_precision=mp,
        ignored_modules=(ignored if ignored else None),
        sync_module_states=True,
    )

    # Place ignored module on device exactly once
    if ignored:
        ignored[0].to(device)
|
| 823 |
+
|
| 824 |
+
def _unwrap(self, module):
|
| 825 |
+
return getattr(module, "module", module)
|
| 826 |
+
|
| 827 |
+
def _prepare_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
|
| 828 |
+
# Species embeddings (fixed-size or sequence)
|
| 829 |
+
if self.species_store is not None and "species_ids" in batch:
|
| 830 |
+
sids = batch["species_ids"]
|
| 831 |
+
if torch.is_tensor(sids):
|
| 832 |
+
sids = sids.detach().cpu().tolist()
|
| 833 |
+
result = self.species_store.batch_get(sids)
|
| 834 |
+
if isinstance(result, tuple):
|
| 835 |
+
sp_tok, _ = result # [B, Ls, Ds]
|
| 836 |
+
batch["species_tok_emb"] = sp_tok.to(self.device, non_blocking=True)
|
| 837 |
+
else:
|
| 838 |
+
sp = result # [B, Ds]
|
| 839 |
+
batch["species_emb"] = sp.to(self.device, non_blocking=True)
|
| 840 |
+
|
| 841 |
+
# Move obvious tensors
|
| 842 |
+
if "codon_ids" in batch and hasattr(batch["codon_ids"], "to"):
|
| 843 |
+
batch["codon_ids"] = batch["codon_ids"].to(self.device, non_blocking=True)
|
| 844 |
+
|
| 845 |
+
return batch
|
| 846 |
+
|
| 847 |
+
def _build_cond(self, batch: Dict[str, Any]) -> Dict[str, Any]:
|
| 848 |
+
cond: Dict[str, Any] = {"control_mode": "fixed"}
|
| 849 |
+
if "species_tok_emb" in batch:
|
| 850 |
+
cond["species_tok_emb_src"] = batch["species_tok_emb"]
|
| 851 |
+
cond["species_tok_emb_tgt"] = batch["species_tok_emb"]
|
| 852 |
+
elif "species_emb" in batch:
|
| 853 |
+
cond["species_emb_src"] = batch["species_emb"]
|
| 854 |
+
cond["species_emb_tgt"] = batch["species_emb"]
|
| 855 |
+
if "protein_seqs" in batch:
|
| 856 |
+
cond["protein_seqs"] = batch["protein_seqs"]
|
| 857 |
+
return cond
|
| 858 |
+
|
| 859 |
+
def _next_batch_sync(self, iterator):
|
| 860 |
+
"""Fetch next batch and drop out early if any rank exhausts its loader."""
|
| 861 |
+
try:
|
| 862 |
+
batch = next(iterator)
|
| 863 |
+
local_has_batch = True
|
| 864 |
+
except StopIteration:
|
| 865 |
+
batch = None
|
| 866 |
+
local_has_batch = False
|
| 867 |
+
|
| 868 |
+
distributed = dist.is_available() and dist.is_initialized()
|
| 869 |
+
has_batch = local_has_batch
|
| 870 |
+
|
| 871 |
+
if distributed:
|
| 872 |
+
flag_device = self.device if self.device.type == "cuda" else torch.device("cpu")
|
| 873 |
+
flag = torch.tensor([1 if local_has_batch else 0], device=flag_device)
|
| 874 |
+
dist.all_reduce(flag, op=dist.ReduceOp.MIN)
|
| 875 |
+
has_batch = bool(flag.item())
|
| 876 |
+
|
| 877 |
+
if not has_batch:
|
| 878 |
+
return None, False, local_has_batch
|
| 879 |
+
|
| 880 |
+
return batch, True, local_has_batch
|
| 881 |
+
|
| 882 |
+
def _is_main(self) -> bool:
|
| 883 |
+
return (not dist.is_available()) or (not dist.is_initialized()) or dist.get_rank() == 0
|
| 884 |
+
|
| 885 |
+
def _barrier(self):
|
| 886 |
+
if dist.is_available() and dist.is_initialized():
|
| 887 |
+
# On NCCL, pass device_ids to avoid rank↔GPU mapping ambiguity when supported
|
| 888 |
+
if self.device.type == "cuda":
|
| 889 |
+
sig = inspect.signature(dist.barrier)
|
| 890 |
+
if "device_ids" in sig.parameters:
|
| 891 |
+
dist.barrier(device_ids=[self.local_rank])
|
| 892 |
+
return
|
| 893 |
+
dist.barrier()
|
| 894 |
+
|
| 895 |
+
def _max_cuda_peak_gb(self) -> Tuple[float, float]:
|
| 896 |
+
if self.device.type != "cuda" or not torch.cuda.is_available():
|
| 897 |
+
return 0.0, 0.0
|
| 898 |
+
vals = torch.tensor(
|
| 899 |
+
[
|
| 900 |
+
float(torch.cuda.max_memory_allocated(self.device)),
|
| 901 |
+
float(torch.cuda.max_memory_reserved(self.device)),
|
| 902 |
+
],
|
| 903 |
+
dtype=torch.float64,
|
| 904 |
+
device=self.device,
|
| 905 |
+
)
|
| 906 |
+
if dist.is_available() and dist.is_initialized():
|
| 907 |
+
dist.all_reduce(vals, op=dist.ReduceOp.MAX)
|
| 908 |
+
scale = float(1024 ** 3)
|
| 909 |
+
return float(vals[0].item() / scale), float(vals[1].item() / scale)
|
| 910 |
+
|
| 911 |
+
# (Per-sample quick eval removed; evaluation now uses held-out dataloader.)
|
| 912 |
+
|
| 913 |
+
def _epoch_for_logging(self) -> int:
|
| 914 |
+
steps_per_epoch = int(getattr(self.args, "steps_per_epoch", 0) or 0)
|
| 915 |
+
if steps_per_epoch > 0:
|
| 916 |
+
est = self.state.get("global_step", 0) // steps_per_epoch
|
| 917 |
+
if self.args.num_train_epochs > 0:
|
| 918 |
+
max_epoch = max(int(self.args.num_train_epochs) - 1, 0)
|
| 919 |
+
if est > max_epoch:
|
| 920 |
+
return max_epoch
|
| 921 |
+
return int(est)
|
| 922 |
+
return int(self.state.get("epoch", 0))
|
| 923 |
+
|
| 924 |
+
# ---- checkpointing ----
|
| 925 |
+
def _save_checkpoint(self, name: str):
    """Write a full checkpoint (model, optimizer, scheduler, trainer metadata) to output_dir/name.

    All ranks must call this together: FSDP state-dict gathering and the data-cursor
    all_gather are collectives. Only the main rank serializes files; the trailing
    barrier keeps the other ranks in lockstep.
    """
    self.state["epoch"] = int(self._epoch_for_logging())
    # All ranks participate in FSDP state_dict collectives; only rank0 writes files.
    out_dir = os.path.join(self.args.output_dir, name)
    os.makedirs(out_dir, exist_ok=True)

    optim_state = None
    if isinstance(self.model, FSDP):
        with warnings.catch_warnings():
            # FSDP's FULL_STATE_DICT path emits FutureWarnings on newer torch; silence them.
            warnings.filterwarnings("ignore", category=FutureWarning)
            with FSDP.state_dict_type(
                self.model,
                StateDictType.FULL_STATE_DICT,
                FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
                FullOptimStateDictConfig(rank0_only=True, offload_to_cpu=True),
            ):
                state = self.model.state_dict()
                # NOTE: Under FSDP, optimizer.state_dict() is sharded per-rank.
                # Use FSDP.optim_state_dict() to materialize a full optimizer state dict (rank0_only).
                if self.optimizer is not None:
                    optim_state = FSDP.optim_state_dict(self.model, self.optimizer)
    else:
        state = self._unwrap(self.model).state_dict()
        if self.optimizer is not None:
            optim_state = self.optimizer.state_dict()

    # Save minimal data cursor (total samples yielded so far) next to output_dir if configured
    per_rank_positions: Optional[List[int]] = None
    p = getattr(self.args, "data_cursor_path", None)
    if p:
        ds = getattr(self.train_dataloader, "dataset", None)
        if hasattr(ds, "get_stream_position"):
            local_pos = int(ds.get_stream_position())
            if dist.is_available() and dist.is_initialized():
                # Collect every rank's stream position so a resume can skip exactly
                # the samples already consumed globally.
                gather_device = self.device if self.device.type == "cuda" else torch.device("cpu")
                tensor = torch.tensor([local_pos], dtype=torch.long, device=gather_device)
                gathered = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
                dist.all_gather(gathered, tensor)
                per_rank_positions = [int(t.item()) for t in gathered]
            else:
                per_rank_positions = [local_pos]

    if not self._is_main():
        # Non-main ranks skip serialization but stay in lockstep
        self._barrier()
        return

    # Rank 0 writes artifacts
    save_file(state, os.path.join(out_dir, "model.safetensors"))

    # Optimizer + scheduler
    if optim_state is not None:
        torch.save(optim_state, os.path.join(out_dir, "optimizer.pt"))
    if self.lr_scheduler is not None:
        torch.save(self.lr_scheduler.state_dict(), os.path.join(out_dir, "scheduler.pt"))

    # Trainer config/state
    base = self._unwrap(self.model)
    # Infer mlp_ratio from first block if present
    mlp_ratio = 4.0
    try:
        if hasattr(base, "blocks") and len(getattr(base, "blocks", [])) > 0:
            w1 = base.blocks[0].ffn.w1.weight  # [H*mlp, H]
            H = int(getattr(base, "hidden_size", w1.shape[1]))
            if H > 0:
                mlp_ratio = float(w1.shape[0]) / float(H)
    except Exception:
        # Best-effort inference only; fall back to the 4.0 default on any shape mismatch.
        pass

    # Architecture + conditioning metadata consumed by the sampler / eval tooling.
    trainer_cfg = {
        # capacity / prefixes
        "max_length": int(self.args.max_length),
        "max_species_prefix": int(getattr(base, "max_species_prefix", 0)),
        "max_protein_prefix": int(getattr(base, "max_protein_prefix", 0)),

        # architecture hints
        "hidden_size": int(getattr(base, "hidden_size", 0)),
        "num_hidden_layers": int(getattr(base, "num_layers", 0)),
        "num_attention_heads": int(getattr(base, "num_heads", 0)),
        "mlp_ratio": float(mlp_ratio),

        # conditioning flags
        "prepend_species": bool(getattr(base, "prepend_species", True)),
        "prepend_protein": bool(getattr(base, "prepend_protein", False)),
        "species_embedding_dim": int(getattr(base, "species_embedding_dim", 1024)),

        # ESM info (even if prepend_protein=False)
        "esm_model_name": str(getattr(self.args, "esm_model_name", "")),
        "esm_device": str(getattr(self.args, "esm_device", "cuda")),
        "esm_dtype": str(getattr(self.args, "esm_dtype", "fp32")).lower(),

        # kernels

        # attention impl
        "attn_impl": str(getattr(base, "attn_impl", "gqa")),
        "num_kv_groups": int(getattr(base, "num_kv_groups", 0)),
    }
    with open(os.path.join(out_dir, "trainer_config.json"), "w") as f:
        json.dump(trainer_cfg, f, indent=2)
    with open(os.path.join(out_dir, "trainer_state.json"), "w") as f:
        json.dump({"epoch": self.state["epoch"], "global_step": self.state["global_step"]}, f, indent=2)

    if p and per_rank_positions is not None:
        payload = {
            "skip_samples": int(sum(per_rank_positions)),
            "per_rank": per_rank_positions,
            "world_size": len(per_rank_positions),
        }
        os.makedirs(os.path.dirname(os.path.abspath(p)), exist_ok=True)
        with open(p, "w") as f:
            json.dump(payload, f)

    # Tokenizer vocab for sampling
    try:
        if self.tokenizer is not None and hasattr(self.tokenizer, "save_vocabulary"):
            self.tokenizer.save_vocabulary(out_dir)
    except Exception as e:
        logger.warning(f"Failed to save vocabulary to {out_dir}: {e}")

    self._prune_checkpoints(self.args.output_dir, self.args.save_total_limit)
    logger.info(f"Saved checkpoint → {out_dir}")

    # Release other ranks
    self._barrier()
|
| 1049 |
+
|
| 1050 |
+
def _resume_from(self, ckpt_dir: str):
    """Restore model, optimizer, scheduler and trainer counters from a checkpoint dir.

    Loads model weights (non-strict), then optionally optimizer/scheduler state
    unless --override_lr_on_resume is set, then trainer_state.json counters.
    If scheduler state was not restored, the scheduler is fast-forwarded to the
    saved global_step so the LR does not restart from warmup.

    Raises:
        FileNotFoundError: if model.safetensors is missing from ckpt_dir.
        RuntimeError: if an FSDP optimizer state dict cannot be converted for loading.
    """
    st_path = os.path.join(ckpt_dir, "model.safetensors")
    if not os.path.exists(st_path):
        raise FileNotFoundError(f"No model.safetensors in {ckpt_dir}")
    state = load_file(st_path)

    if isinstance(self.model, FSDP):
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=FutureWarning)
            # rank0_only=False: every rank loads the full (CPU-offloaded) state dict.
            with FSDP.state_dict_type(
                self.model,
                StateDictType.FULL_STATE_DICT,
                FullStateDictConfig(rank0_only=False, offload_to_cpu=True),
            ):
                self.model.load_state_dict(state, strict=False)
    else:
        self._unwrap(self.model).load_state_dict(state, strict=False)

    scheduler_restored = False

    opt_path = os.path.join(ckpt_dir, "optimizer.pt")
    if os.path.exists(opt_path):
        if self.optimizer is None:
            self._create_optimizer_and_scheduler()
        if not self.args.override_lr_on_resume:
            loaded = torch.load(opt_path, map_location="cpu")
            # Under FSDP, saved optimizer.pt is a full optimizer state dict produced by
            # FSDP.optim_state_dict(). Convert it to a per-rank state dict before loading.
            if isinstance(self.model, FSDP):
                try:
                    loaded = FSDP.optim_state_dict_to_load(self.model, self.optimizer, loaded)
                except Exception as e:
                    msg = (
                        "Failed to convert FSDP optimizer state dict for loading. "
                        "This checkpoint likely contains an incomplete (rank0-only sharded) optimizer.pt from an older version. "
                        "Full optimizer resume is not possible from this checkpoint.\n"
                        f"Underlying error: {e}\n"
                        "Options:\n"
                        " 1) Start a fresh run (new --output_dir), or\n"
                        " 2) Re-run with --override_lr_on_resume to skip optimizer restore (not a full resume)."
                    )
                    if self._is_main():
                        logger.error(msg)
                    raise RuntimeError(msg) from e
            self.optimizer.load_state_dict(loaded)

    sch_path = os.path.join(ckpt_dir, "scheduler.pt")
    if os.path.exists(sch_path):
        if self.lr_scheduler is None:
            self._create_optimizer_and_scheduler()
        if self.lr_scheduler is not None and not self.args.override_lr_on_resume:
            self.lr_scheduler.load_state_dict(torch.load(sch_path, map_location="cpu"))
            scheduler_restored = True

    ts_path = os.path.join(ckpt_dir, "trainer_state.json")
    if os.path.exists(ts_path):
        with open(ts_path, "r") as f:
            ts = json.load(f)
        self.state["epoch"] = int(ts.get("epoch", 0))
        self.state["global_step"] = int(ts.get("global_step", 0))

    # Re-derive the epoch from global_step when a step budget is configured; the
    # stored epoch may be stale relative to the step counter.
    steps_per_epoch = int(getattr(self.args, "steps_per_epoch", 0) or 0)
    if steps_per_epoch > 0:
        inferred_epoch = self.state.get("global_step", 0) // steps_per_epoch
        num_epochs = max(int(self.args.num_train_epochs), 1)
        inferred_epoch = min(inferred_epoch, num_epochs - 1)
        if inferred_epoch != self.state.get("epoch"):
            if self._is_main():
                logger.info(
                    "Adjusting epoch from %s to %s based on global_step %s and steps_per_epoch %s",
                    self.state.get("epoch"),
                    inferred_epoch,
                    self.state.get("global_step"),
                    steps_per_epoch,
                )
            self.state["epoch"] = int(inferred_epoch)

    # If we skipped loading the scheduler state (e.g., different world size or override),
    # fast-forward it to the saved global_step so LR does not restart from warmup.
    if self.lr_scheduler is not None and not scheduler_restored:
        target_step = int(self.state.get("global_step", 0))
        if target_step > 0:
            try:
                # Most schedulers (LambdaLR, CosineAnnealing, etc.) accept an "epoch" kwarg.
                self.lr_scheduler.step(target_step)
            except TypeError:
                # Fallback: advance manually.
                for _ in range(target_step):
                    self.lr_scheduler.step()
            # Ensure optimizer LR reflects the scheduler's current value.
            try:
                last_lrs = self.lr_scheduler.get_last_lr()
            except Exception:
                last_lrs = [group.get("lr") for group in self.optimizer.param_groups]
            if last_lrs:
                for group, lr in zip(self.optimizer.param_groups, last_lrs):
                    group["lr"] = float(lr)

    logger.info(f"Resumed from {ckpt_dir}")
|
| 1150 |
+
|
| 1151 |
+
def _checkpoint_step(self, path: str) -> Optional[int]:
|
| 1152 |
+
m = re.fullmatch(r"checkpoint-(\d+)", os.path.basename(path))
|
| 1153 |
+
if not m:
|
| 1154 |
+
return None
|
| 1155 |
+
return int(m.group(1))
|
| 1156 |
+
|
| 1157 |
+
def _prune_checkpoints(self, root: str, keep: int):
    """Delete old 'checkpoint-<N>' directories under root according to the retention policy.

    Two modes:
      * Tiered retention (when ckpt_recent_window_steps and at least one interval
        are configured): keep checkpoints whose step is a multiple of
        recent_interval inside the recent window, archive_interval outside it,
        always keep the newest, and trim the oldest survivors down to `keep`.
      * Legacy fallback (when tiered selection yields nothing): keep the most
        recent `keep` checkpoints (`keep` <= 0 disables pruning entirely).

    Directories that do not match 'checkpoint-<N>' are never touched.
    """
    if not os.path.isdir(root):
        return

    try:
        subdirs = [
            os.path.join(root, d)
            for d in os.listdir(root)
            if os.path.isdir(os.path.join(root, d))
        ]
    except FileNotFoundError:
        # Root vanished between the isdir check and listdir; nothing to prune.
        return

    # Collect (step, path) for every directory named like a step checkpoint.
    step_dirs: list[tuple[int, str]] = []
    for path in subdirs:
        step = self._checkpoint_step(path)
        if step is not None:
            step_dirs.append((step, path))

    if not step_dirs:
        return

    step_dirs.sort(key=lambda item: item[0])
    latest_step = step_dirs[-1][0]

    # Policy knobs (all optional; 0 disables each).
    recent_window = max(0, int(getattr(self.args, "ckpt_recent_window_steps", 0) or 0))
    recent_interval = max(0, int(getattr(self.args, "ckpt_recent_interval", 0) or 0))
    archive_interval = max(0, int(getattr(self.args, "ckpt_archive_interval", 0) or 0))

    keep_paths: set[str] = set()
    if recent_window > 0 and (recent_interval > 0 or archive_interval > 0):
        if recent_interval <= 0:
            # Default the in-window interval to the save cadence.
            recent_interval = max(1, int(getattr(self.args, "save_steps", 1) or 1))

        for step, path in step_dirs:
            age = latest_step - step
            if age <= recent_window:
                interval = recent_interval
            else:
                interval = archive_interval
            if interval > 0 and (step % interval == 0):
                keep_paths.add(path)

    if not keep_paths:
        # Legacy fallback: keep the most recent N step checkpoints.
        if keep <= 0:
            return
        keep_paths = {path for _, path in step_dirs[-keep:]}
    else:
        # Always preserve the newest checkpoint, even if the interval math misses it.
        keep_paths.add(step_dirs[-1][1])
        if keep > 0:
            # Enforce the hard cap by dropping the oldest survivors first.
            kept = [(step, path) for step, path in step_dirs if path in keep_paths]
            if len(kept) > keep:
                trim = len(kept) - keep
                for _, path in kept[:trim]:
                    keep_paths.discard(path)

    removed = []
    for _, path in step_dirs:
        if path in keep_paths:
            continue
        shutil.rmtree(path, ignore_errors=True)
        removed.append(os.path.basename(path))

    if removed and self._is_main():
        logger.info(
            "Pruned %s checkpoints (latest_step=%s, recent_window=%s, recent_interval=%s, archive_interval=%s)",
            len(removed),
            latest_step,
            recent_window,
            recent_interval,
            archive_interval,
        )
|
train.py
ADDED
|
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
"""
|
| 3 |
+
Minimal, honest training script for CodonGPT on CSV data.
|
| 4 |
+
|
| 5 |
+
- Species conditioning: REQUIRED (precomputed embeddings)
|
| 6 |
+
- Protein conditioning (ESM-C): ENABLED BY DEFAULT. Disable with --no_protein.
|
| 7 |
+
- Global capacity is controlled by --max_length (prefix + start + codon).
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import math
|
| 12 |
+
import argparse
|
| 13 |
+
import logging
|
| 14 |
+
import torch
|
| 15 |
+
|
| 16 |
+
from src import CodonGPT, CodonTokenizer, Trainer, TrainingArguments
|
| 17 |
+
from src.dataset import create_precomputed_dataloaders, SpeciesEmbeddingStore
|
| 18 |
+
|
| 19 |
+
# Configure root logging once at import time so every module logger shares the
# same timestamped format and INFO level.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
# Module-level logger used throughout this training script.
logger = logging.getLogger("codongpt.train")
|
| 25 |
+
|
| 26 |
+
def _describe_sdp_kernels() -> None:
    """Log which SDPA backends (Flash / MemEff / Math) are currently enabled.

    Every access is guarded with attribute probing so older PyTorch builds that
    lack these query functions never raise.
    """
    flash = None
    mem_eff = None
    mathk = None
    backends = getattr(torch, "backends", None)
    tbc = getattr(backends, "cuda", None) if backends is not None else None
    if tbc is not None:
        if hasattr(tbc, "flash_sdp_enabled"):
            flash = tbc.flash_sdp_enabled()
        if hasattr(tbc, "mem_efficient_sdp_enabled"):
            mem_eff = tbc.mem_efficient_sdp_enabled()
        if hasattr(tbc, "math_sdp_enabled"):
            mathk = tbc.math_sdp_enabled()
    logger.info(f"SDP kernels: flash={flash} mem_efficient={mem_eff} math={mathk}")
|
| 38 |
+
|
| 39 |
+
def _print_model_size(model: torch.nn.Module, bf16: bool, fp16: bool) -> None:
    """Log parameter counts plus rough weight/optimizer memory footprints in GB."""
    counts = [(param.numel(), param.requires_grad) for param in model.parameters()]
    total = sum(n for n, _ in counts)
    trainable = sum(n for n, needs_grad in counts if needs_grad)
    # Mixed precision stores weights in 2 bytes/param; full precision uses 4.
    w_bytes = 2 if (bf16 or fp16) else 4
    opt_bytes = 8  # Adam moments in FP32
    weights_gb = total * w_bytes / (1024**3)
    opt_gb = trainable * opt_bytes / (1024**3)
    logger.info(
        f"Model params: total={total:,} trainable={trainable:,} (~{weights_gb:.2f} GB weights, ~{opt_gb:.2f} GB optimizer)"
    )
|
| 49 |
+
|
| 50 |
+
def _speed_toggles():
|
| 51 |
+
if hasattr(torch.backends, "cuda") and hasattr(torch.backends.cuda, "matmul"):
|
| 52 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 53 |
+
if hasattr(torch, "set_float32_matmul_precision"):
|
| 54 |
+
torch.set_float32_matmul_precision("high")
|
| 55 |
+
if hasattr(torch.backends, "cudnn") and hasattr(torch.backends.cudnn, "benchmark"):
|
| 56 |
+
torch.backends.cudnn.benchmark = True
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def parse_args():
    """Build the CLI parser for CodonGPT training and return the parsed namespace."""
    p = argparse.ArgumentParser(description="Train CodonGPT on CSV data")
    add = p.add_argument

    # Data (CSV path or Parquet glob/dir)
    add("--train_data", type=str, default="random_sample_1000.csv",
        help="Training data: CSV file or Parquet glob/dir (e.g., ./data/train_shards/*.parquet)")
    add("--val_data", type=str, default=None,
        help="Validation data: CSV file or Parquet glob/dir")
    add("--embeddings_dir", type=str, default="embeddings",
        help="Dir with species embeddings (species_vocab.json, *.bin/memmap)")

    # Model / capacity
    add("--hidden", type=int, default=750, help="Model hidden size")
    add("--layers", type=int, default=20, help="Number of transformer layers")
    add("--heads", type=int, default=15, help="Number of attention heads")
    add("--attn", type=str, choices=["mha", "gqa"], default="gqa",
        help="Attention implementation: 'mha' or 'gqa'")
    add("--num_kv_groups", type=int, default=5,
        help="GQA: number of KV groups (0 = default/no grouping)")
    add("--mlp_ratio", type=float, default=3.2,
        help="FFN expansion ratio (mlp hidden = ratio * hidden)")
    add("--max_length", type=int, default=2048,
        help="Global max length (prefix + start + codon)")
    add("--max_species_prefix", type=int, default=0,
        help="Cap species prefix tokens (0 = uncapped)")
    add("--max_protein_prefix", type=int, default=1024,
        help="Cap protein prefix tokens (0 = uncapped)")

    # Protein conditioning: always enabled (ESM-C)

    # Training
    add("--output_dir", type=str, default="checkpoints", help="Where to save checkpoints")
    add("--epochs", type=int, default=1, help="Number of training epochs")
    add("--batch_size", type=int, default=20, help="Per-device train batch size")
    add("--eval_batch_size", type=int, default=32, help="Per-device eval batch size")
    add("--workers", type=int, default=4, help="DataLoader workers")
    add("--grad_accum", type=int, default=1, help="Gradient accumulation steps")
    add("--train_shuffle_buffer", type=int, default=0,
        help="Streaming shuffle buffer for training (set 0 when data is pre-shuffled)")
    add("--val_shuffle_buffer", type=int, default=0,
        help="Streaming shuffle buffer for validation (0 disables)")
    add("--csv_chunksize", type=int, default=200_000,
        help="Pandas read_csv chunksize for CSV inputs")

    # Optim / schedule
    add("--lr", type=float, default=1e-4, help="Learning rate")
    add("--warmup_ratio", type=float, default=0.1,
        help="Warmup ratio for LR schedule (0.0-1.0)")
    add("--lr_scheduler", type=str, choices=["linear", "cosine", "constant"], default="linear",
        help="LR schedule applied after warmup; 'linear' decays to zero by the end of training")
    add("--weight_decay", type=float, default=1e-3, help="Weight decay")
    add("--adam_beta1", type=float, default=0.9,
        help="Adam beta1 (momentum) coefficient")
    add("--adam_beta2", type=float, default=0.95,
        help="Adam beta2 (squared-gradient) coefficient")
    add("--logging_steps", type=int, default=20, help="Logging interval (steps)")
    add("--save_steps", type=int, default=10,
        help="Save every N steps (0 disables step-saving)")
    add("--save_total_limit", type=int, default=10,
        help="Keep at most N recent checkpoints")
    add("--ckpt_recent_window_steps", type=int, default=0,
        help="If >0, keep finer-grained checkpoints within this many recent steps")
    add("--ckpt_recent_interval", type=int, default=0,
        help="Retention interval inside the recent checkpoint window (0 disables custom retention)")
    add("--ckpt_archive_interval", type=int, default=0,
        help="Retention interval for checkpoints older than the recent window (0 prunes them)")
    add("--max_steps", type=int, default=-1,
        help="Total training steps. REQUIRED for streaming (IterableDataset)")
    add("--steps_per_epoch", type=int, default=0,
        help="For streaming datasets: shape LR schedule as epochs*steps_per_epoch when max_steps<0")
    add("--max_grad_norm", type=float, default=1.0,
        help="Clip gradients to this global L2 norm; set <=0 to disable")
    add("--override_lr_on_resume", action="store_true",
        help="Do not restore LR/optimizer state on resume (keep current lr)")

    # Resume
    add("--resume_from", type=str, default=None,
        help="Path to checkpoint dir to resume from; pass 'auto' to pick latest in output_dir")

    # Evaluation scheduling
    add("--eval_interval", type=int, default=0,
        help="Run evaluation every N optimizer steps on --val_data (0 disables)")
    add("--eval_steps", type=int, default=5000,
        help="For streaming eval datasets: limit to this many batches (0 = full eval)")

    # Hardware / precision
    add("--device", type=str, default="cuda", help="cuda or cpu")
    add("--bf16", action="store_true", help="bfloat16 mixed precision")
    add("--fp16", action="store_true", help="float16 mixed precision")
    add("--fsdp", action="store_true", help="Enable FSDP full sharding")
    add("--grad_ckpt", action="store_true", help="Enable gradient checkpointing")
    return p.parse_args()
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def main():
|
| 152 |
+
args = parse_args()
|
| 153 |
+
_speed_toggles()
|
| 154 |
+
|
| 155 |
+
if args.device == "cuda" and not torch.cuda.is_available():
|
| 156 |
+
logger.warning("CUDA not available; switching to CPU")
|
| 157 |
+
args.device = "cpu"
|
| 158 |
+
|
| 159 |
+
# Tokenizer
|
| 160 |
+
tok = CodonTokenizer()
|
| 161 |
+
# Ensure output dir exists and persist vocab.json (used by sampler)
|
| 162 |
+
os.makedirs(os.path.abspath(args.output_dir), exist_ok=True)
|
| 163 |
+
tok.save_vocabulary(args.output_dir)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
# Data first — we need Ds for species embeddings
|
| 167 |
+
train_loader, val_loader, species_store = create_precomputed_dataloaders(
|
| 168 |
+
train_path=args.train_data,
|
| 169 |
+
val_path=args.val_data,
|
| 170 |
+
embeddings_dir=args.embeddings_dir,
|
| 171 |
+
tokenizer=tok,
|
| 172 |
+
batch_size=args.batch_size,
|
| 173 |
+
num_workers=args.workers,
|
| 174 |
+
species_pooling="sequence", # prefer variable-length token sequence if available
|
| 175 |
+
csv_chunksize=int(args.csv_chunksize),
|
| 176 |
+
train_shuffle_buffer=int(args.train_shuffle_buffer),
|
| 177 |
+
val_shuffle_buffer=int(args.val_shuffle_buffer),
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
# Estimate steps_per_epoch for streaming schedule shaping if not provided
|
| 181 |
+
steps_per_epoch = int(getattr(args, "steps_per_epoch", 0) or 0)
|
| 182 |
+
total_rows = 0
|
| 183 |
+
paths: list[str] = []
|
| 184 |
+
if steps_per_epoch <= 0 and int(args.max_steps) < 0:
|
| 185 |
+
def _expand_paths(maybe: str | list[str]) -> list[str]:
|
| 186 |
+
import glob as _glob
|
| 187 |
+
from pathlib import Path as _Path
|
| 188 |
+
paths: list[str] = []
|
| 189 |
+
if isinstance(maybe, str):
|
| 190 |
+
p = _Path(maybe)
|
| 191 |
+
if p.is_dir():
|
| 192 |
+
paths.extend(sorted(str(x) for x in p.rglob("*.parquet")))
|
| 193 |
+
else:
|
| 194 |
+
paths = sorted(_glob.glob(str(p)))
|
| 195 |
+
else:
|
| 196 |
+
for it in maybe:
|
| 197 |
+
paths.extend(_expand_paths(it))
|
| 198 |
+
# de-dup
|
| 199 |
+
seen = set(); out = []
|
| 200 |
+
for x in paths:
|
| 201 |
+
if x not in seen:
|
| 202 |
+
out.append(x); seen.add(x)
|
| 203 |
+
return out
|
| 204 |
+
|
| 205 |
+
paths = _expand_paths(args.train_data)
|
| 206 |
+
if paths:
|
| 207 |
+
try:
|
| 208 |
+
import pyarrow.parquet as pq
|
| 209 |
+
for fp in paths:
|
| 210 |
+
if fp.lower().endswith((".parquet", ".parq")):
|
| 211 |
+
pf = pq.ParquetFile(fp)
|
| 212 |
+
md = pf.metadata
|
| 213 |
+
if md is not None:
|
| 214 |
+
total_rows += int(md.num_rows)
|
| 215 |
+
except Exception:
|
| 216 |
+
# Fallback: keep steps_per_epoch at 0 if pyarrow not available
|
| 217 |
+
total_rows = 0
|
| 218 |
+
if total_rows > 0:
|
| 219 |
+
world = int(os.environ.get("WORLD_SIZE", "1"))
|
| 220 |
+
ga = max(1, int(getattr(args, "grad_accum", 1)))
|
| 221 |
+
denom = max(1, int(args.batch_size) * max(1, world) * ga)
|
| 222 |
+
steps_per_epoch = max(1, math.ceil(total_rows / denom))
|
| 223 |
+
logger.info(f"Estimated steps_per_epoch={steps_per_epoch} from {len(paths)} parquet files, total_rows={total_rows}")
|
| 224 |
+
|
| 225 |
+
world = int(os.environ.get("WORLD_SIZE", "1"))
|
| 226 |
+
grad_accum = max(1, int(getattr(args, "grad_accum", 1)))
|
| 227 |
+
effective_global_batch = int(args.batch_size) * max(1, world) * grad_accum
|
| 228 |
+
logger.info(
|
| 229 |
+
"Batch config: per_device_train_batch=%s per_device_eval_batch=%s world_size=%s grad_accum=%s effective_global_batch=%s",
|
| 230 |
+
args.batch_size,
|
| 231 |
+
args.eval_batch_size,
|
| 232 |
+
world,
|
| 233 |
+
grad_accum,
|
| 234 |
+
effective_global_batch,
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
# Resolve per-process CUDA device for ESM (avoid defaulting to cuda:0 on all ranks)
|
| 238 |
+
esm_dev = "cpu"
|
| 239 |
+
if args.device == "cuda" and torch.cuda.is_available():
|
| 240 |
+
lr = int(os.environ.get("LOCAL_RANK", "0"))
|
| 241 |
+
esm_dev = f"cuda:{lr}"
|
| 242 |
+
|
| 243 |
+
# Model — species is always on; protein defaults to ON (can be disabled with --no_protein)
|
| 244 |
+
model = CodonGPT(
|
| 245 |
+
vocab_size=tok.vocab_size,
|
| 246 |
+
num_special_tokens=tok.num_special_tokens,
|
| 247 |
+
special_ids=tok.special_ids,
|
| 248 |
+
hidden_size=args.hidden,
|
| 249 |
+
num_layers=args.layers,
|
| 250 |
+
num_heads=args.heads,
|
| 251 |
+
mlp_ratio=float(args.mlp_ratio),
|
| 252 |
+
max_position_embeddings=args.max_length,
|
| 253 |
+
prepend_species=True,
|
| 254 |
+
prepend_protein=True,
|
| 255 |
+
esm_model_name="esmc_300m",
|
| 256 |
+
esm_device=esm_dev,
|
| 257 |
+
max_protein_prefix=int(args.max_protein_prefix),
|
| 258 |
+
max_species_prefix=int(args.max_species_prefix),
|
| 259 |
+
dropout=0.1,
|
| 260 |
+
species_embedding_dim=int(species_store.Ds()),
|
| 261 |
+
attn_impl=str(args.attn),
|
| 262 |
+
num_kv_groups=int(args.num_kv_groups),
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
# Report model size and SDPA (Flash) kernel configuration
|
| 266 |
+
_print_model_size(model, bf16=bool(args.bf16), fp16=bool(args.fp16))
|
| 267 |
+
_describe_sdp_kernels()
|
| 268 |
+
|
| 269 |
+
# Trainer args
|
| 270 |
+
targs = TrainingArguments(
|
| 271 |
+
output_dir=args.output_dir,
|
| 272 |
+
save_steps=args.save_steps,
|
| 273 |
+
save_total_limit=int(args.save_total_limit),
|
| 274 |
+
ckpt_recent_window_steps=int(args.ckpt_recent_window_steps),
|
| 275 |
+
ckpt_recent_interval=int(args.ckpt_recent_interval),
|
| 276 |
+
ckpt_archive_interval=int(args.ckpt_archive_interval),
|
| 277 |
+
num_train_epochs=args.epochs,
|
| 278 |
+
max_steps=int(args.max_steps),
|
| 279 |
+
gradient_accumulation_steps=int(args.grad_accum),
|
| 280 |
+
warmup_ratio=float(args.warmup_ratio),
|
| 281 |
+
lr_scheduler_type=str(args.lr_scheduler),
|
| 282 |
+
per_device_train_batch_size=args.batch_size,
|
| 283 |
+
per_device_eval_batch_size=args.eval_batch_size,
|
| 284 |
+
dataloader_num_workers=args.workers,
|
| 285 |
+
learning_rate=args.lr,
|
| 286 |
+
weight_decay=args.weight_decay,
|
| 287 |
+
adam_beta1=float(args.adam_beta1),
|
| 288 |
+
adam_beta2=float(args.adam_beta2),
|
| 289 |
+
max_grad_norm=float(args.max_grad_norm),
|
| 290 |
+
logging_steps=args.logging_steps,
|
| 291 |
+
override_lr_on_resume=bool(args.override_lr_on_resume),
|
| 292 |
+
data_cursor_path=os.path.join(os.path.abspath(args.output_dir), "data_cursor.json"),
|
| 293 |
+
fp16=bool(args.fp16),
|
| 294 |
+
bf16=bool(args.bf16),
|
| 295 |
+
fsdp=("full_shard" if args.fsdp else None),
|
| 296 |
+
gradient_checkpointing=bool(args.grad_ckpt),
|
| 297 |
+
max_length=int(args.max_length),
|
| 298 |
+
esm_model_name="esmc_300m",
|
| 299 |
+
esm_device=esm_dev,
|
| 300 |
+
esm_dtype=("bf16" if args.bf16 else ("fp16" if args.fp16 else "fp32")),
|
| 301 |
+
# sampling eval
|
| 302 |
+
eval_interval=int(args.eval_interval),
|
| 303 |
+
eval_steps=int(args.eval_steps),
|
| 304 |
+
steps_per_epoch=int(steps_per_epoch),
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
# Resolve auto-resume if requested
|
| 308 |
+
resume_path = None
|
| 309 |
+
if args.resume_from:
|
| 310 |
+
if args.resume_from == "auto":
|
| 311 |
+
root = os.path.abspath(args.output_dir)
|
| 312 |
+
if os.path.isdir(root):
|
| 313 |
+
try:
|
| 314 |
+
subdirs = []
|
| 315 |
+
for d in os.listdir(root):
|
| 316 |
+
path = os.path.join(root, d)
|
| 317 |
+
if not os.path.isdir(path):
|
| 318 |
+
continue
|
| 319 |
+
if not (
|
| 320 |
+
d == "final_model" or
|
| 321 |
+
d.startswith("checkpoint-")
|
| 322 |
+
):
|
| 323 |
+
continue
|
| 324 |
+
if not (
|
| 325 |
+
os.path.exists(os.path.join(path, "model.safetensors")) or
|
| 326 |
+
os.path.exists(os.path.join(path, "pytorch_model.bin"))
|
| 327 |
+
):
|
| 328 |
+
continue
|
| 329 |
+
subdirs.append(path)
|
| 330 |
+
subdirs.sort(key=lambda d: os.path.getmtime(d), reverse=True)
|
| 331 |
+
resume_path = subdirs[0] if subdirs else None
|
| 332 |
+
except Exception:
|
| 333 |
+
resume_path = None
|
| 334 |
+
else:
|
| 335 |
+
resume_path = args.resume_from
|
| 336 |
+
|
| 337 |
+
trainer = Trainer(
|
| 338 |
+
model=model,
|
| 339 |
+
args=targs,
|
| 340 |
+
tokenizer=tok,
|
| 341 |
+
species_store=species_store,
|
| 342 |
+
resume_from_checkpoint=resume_path,
|
| 343 |
+
)
|
| 344 |
+
trainer.attach_dataloaders(train_loader, val_loader)
|
| 345 |
+
|
| 346 |
+
logger.info("Starting training...")
|
| 347 |
+
trainer.train()
|
| 348 |
+
logger.info("Training finished.")
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
training_checkpoints/checkpoint-71000/config.json
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"max_length": 2048,
|
| 3 |
+
"max_species_prefix": 0,
|
| 4 |
+
"max_protein_prefix": 1024,
|
| 5 |
+
"hidden_size": 750,
|
| 6 |
+
"num_hidden_layers": 20,
|
| 7 |
+
"num_attention_heads": 15,
|
| 8 |
+
"mlp_ratio": 3.2,
|
| 9 |
+
"prepend_species": true,
|
| 10 |
+
"prepend_protein": true,
|
| 11 |
+
"species_embedding_dim": 1024,
|
| 12 |
+
"esm_model_name": "esmc_300m",
|
| 13 |
+
"esm_device": "cuda:0",
|
| 14 |
+
"esm_dtype": "bf16",
|
| 15 |
+
"attn_impl": "mha",
|
| 16 |
+
"num_kv_groups": 5
|
| 17 |
+
}
|
training_checkpoints/checkpoint-71000/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:07bc223f4d934e2baff5a8085a78348766b6a8324aa091a1459fce2b2c6d3837
|
| 3 |
+
size 1284544520
|