alegendaryfish commited on
Commit
2d8da02
·
verified ·
1 Parent(s): f672c5d

Public CodonTranslator model and training code release

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. CodonTranslator/__init__.py +4 -0
  2. CodonTranslator/__pycache__/__init__.cpython-312.pyc +0 -0
  3. CodonTranslator/__pycache__/layers.cpython-312.pyc +0 -0
  4. CodonTranslator/__pycache__/models.cpython-312.pyc +0 -0
  5. CodonTranslator/__pycache__/tokenizer.cpython-312.pyc +0 -0
  6. CodonTranslator/__pycache__/translator.cpython-312.pyc +0 -0
  7. CodonTranslator/layers.py +239 -0
  8. CodonTranslator/models.py +306 -0
  9. CodonTranslator/tokenizer.py +183 -0
  10. CodonTranslator/translator.py +479 -0
  11. LICENSE +21 -0
  12. README.md +115 -0
  13. __pycache__/precompute_embeddings.cpython-312.pyc +0 -0
  14. __pycache__/resplit_data_v3.cpython-312.pyc +0 -0
  15. __pycache__/sampling.cpython-312.pyc +0 -0
  16. __pycache__/train.cpython-312.pyc +0 -0
  17. batch_eval.py +382 -0
  18. codontranslator/__init__.py +3 -0
  19. environment.yml +20 -0
  20. eval.py +1239 -0
  21. final_model/config.json +17 -0
  22. final_model/model.safetensors +3 -0
  23. final_model/trainer_config.json +17 -0
  24. final_model/trainer_state.json +4 -0
  25. final_model/vocab.json +78 -0
  26. precompute_embeddings.py +503 -0
  27. pyproject.toml +24 -0
  28. requirements.txt +12 -0
  29. resplit_data_v3.py +1444 -0
  30. sampling.py +314 -0
  31. slurm/rebuild_data_v3_cpu.sbatch +98 -0
  32. slurm/submit_train_v3_h200_8x_chain.sh +24 -0
  33. slurm/train_v3_h200_8x_single.sbatch +165 -0
  34. src/__init__.py +33 -0
  35. src/__pycache__/__init__.cpython-312.pyc +0 -0
  36. src/__pycache__/dataset.cpython-312.pyc +0 -0
  37. src/__pycache__/layers.cpython-312.pyc +0 -0
  38. src/__pycache__/models.cpython-312.pyc +0 -0
  39. src/__pycache__/sampler.cpython-312.pyc +0 -0
  40. src/__pycache__/tokenizer.cpython-312.pyc +0 -0
  41. src/__pycache__/trainer.cpython-312.pyc +0 -0
  42. src/dataset.py +833 -0
  43. src/layers.py +384 -0
  44. src/models.py +490 -0
  45. src/sampler.py +696 -0
  46. src/tokenizer.py +324 -0
  47. src/trainer.py +1230 -0
  48. train.py +352 -0
  49. training_checkpoints/checkpoint-71000/config.json +17 -0
  50. training_checkpoints/checkpoint-71000/model.safetensors +3 -0
CodonTranslator/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .translator import CodonTranslator
2
+
3
+ __all__ = ["CodonTranslator"]
4
+
CodonTranslator/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (222 Bytes). View file
 
CodonTranslator/__pycache__/layers.cpython-312.pyc ADDED
Binary file (17.4 kB). View file
 
CodonTranslator/__pycache__/models.cpython-312.pyc ADDED
Binary file (20.1 kB). View file
 
CodonTranslator/__pycache__/tokenizer.cpython-312.pyc ADDED
Binary file (11.2 kB). View file
 
CodonTranslator/__pycache__/translator.cpython-312.pyc ADDED
Binary file (29.9 kB). View file
 
CodonTranslator/layers.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Minimal attention/norm/FFN blocks used by the translator backbone
2
+ from __future__ import annotations
3
+
4
+ import math
5
+ from typing import Optional, Tuple
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torch.nn.attention import SDPBackend, sdpa_kernel
11
+
12
+
13
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization: no mean subtraction, no bias.

    Scales each feature vector by the reciprocal RMS of its components,
    then applies a learned per-dimension gain.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        return x * inv_rms * self.weight
22
+
23
+
24
+ def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
25
+ x1 = x[..., ::2]
26
+ x2 = x[..., 1::2]
27
+ x_rot = torch.zeros_like(x)
28
+ x_rot[..., ::2] = -x2
29
+ x_rot[..., 1::2] = x1
30
+ return x * cos + x_rot * sin
31
+
32
+
33
+ class GroupedQueryAttention(nn.Module):
34
+ def __init__(self, dim: int, num_heads: int, num_kv_groups: int, dropout: float = 0.0, qk_norm: bool = False):
35
+ super().__init__()
36
+ assert num_heads % max(1, num_kv_groups) == 0
37
+ self.dim = dim
38
+ self.num_heads = int(num_heads)
39
+ self.num_kv_groups = max(1, int(num_kv_groups))
40
+ self.group_size = self.num_heads // self.num_kv_groups
41
+ assert dim % num_heads == 0
42
+ self.head_dim = dim // num_heads
43
+ self.dropout = dropout
44
+
45
+ self.Wq = nn.Linear(dim, self.num_heads * self.head_dim, bias=False)
46
+ self.Wk = nn.Linear(dim, self.num_kv_groups * self.head_dim, bias=False)
47
+ self.Wv = nn.Linear(dim, self.num_kv_groups * self.head_dim, bias=False)
48
+ self.out_proj = nn.Linear(self.num_heads * self.head_dim, dim, bias=False)
49
+
50
+ self.q_norm = RMSNorm(self.head_dim) if qk_norm else None
51
+ self.k_norm = RMSNorm(self.head_dim) if qk_norm else None
52
+
53
+ self._rope_cache: dict[tuple[int, torch.device, torch.dtype], tuple[torch.Tensor, torch.Tensor]] = {}
54
+
55
+ def _rope_cos_sin(self, T: int, device: torch.device, dtype: torch.dtype):
56
+ key = (T, device, dtype)
57
+ cached = self._rope_cache.get(key)
58
+ if cached is not None:
59
+ return cached
60
+ dim_half = self.head_dim // 2
61
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_half, device=device, dtype=torch.float32) / dim_half))
62
+ t = torch.arange(T, device=device, dtype=torch.float32)
63
+ freqs = torch.outer(t, inv_freq)
64
+ cos = torch.cos(freqs).repeat_interleave(2, dim=-1)
65
+ sin = torch.sin(freqs).repeat_interleave(2, dim=-1)
66
+ cos = cos.to(dtype).unsqueeze(0).unsqueeze(0)
67
+ sin = sin.to(dtype).unsqueeze(0).unsqueeze(0)
68
+ self._rope_cache[key] = (cos, sin)
69
+ return cos, sin
70
+
71
+ def forward(self, x: torch.Tensor, past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = False, position_offset: int | torch.Tensor = 0):
72
+ B, T_new, _ = x.shape
73
+ q = self.Wq(x).view(B, T_new, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
74
+ k = self.Wk(x).view(B, T_new, self.num_kv_groups, self.head_dim).transpose(1, 2).contiguous()
75
+ v = self.Wv(x).view(B, T_new, self.num_kv_groups, self.head_dim).transpose(1, 2).contiguous()
76
+
77
+ if self.q_norm is not None:
78
+ q = self.q_norm(q)
79
+ if self.k_norm is not None:
80
+ k = self.k_norm(k)
81
+
82
+ if isinstance(position_offset, int):
83
+ cos, sin = self._rope_cos_sin(position_offset + T_new, x.device, q.dtype)
84
+ if position_offset > 0:
85
+ cos = cos[:, :, position_offset: position_offset + T_new, :]
86
+ sin = sin[:, :, position_offset: position_offset + T_new, :]
87
+ q = _apply_rope(q, cos, sin)
88
+ k = _apply_rope(k, cos, sin)
89
+ else:
90
+ off = position_offset.to(device=x.device, dtype=torch.long)
91
+ max_off = int(off.max().item())
92
+ cos_all, sin_all = self._rope_cos_sin(max_off + T_new, x.device, q.dtype)
93
+ ar = torch.arange(T_new, device=x.device, dtype=torch.long)
94
+ idx = (off.unsqueeze(1) + ar.unsqueeze(0))
95
+ cos_b = cos_all.squeeze(0).squeeze(0)[idx].unsqueeze(1)
96
+ sin_b = sin_all.squeeze(0).squeeze(0)[idx].unsqueeze(1)
97
+ q = _apply_rope(q, cos_b, sin_b)
98
+ k = _apply_rope(k, cos_b, sin_b)
99
+
100
+ if past_kv is not None:
101
+ k_p, v_p = past_kv
102
+ k = torch.cat([k_p, k], dim=2)
103
+ v = torch.cat([v_p, v], dim=2)
104
+
105
+ is_causal = past_kv is None
106
+ # Prefer Flash, then MemEff, then Math; allow FP32 via Math
107
+ with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
108
+ if x.device.type == "cuda" and q.dtype not in (torch.float16, torch.bfloat16):
109
+ amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
110
+ with torch.amp.autocast(device_type="cuda", dtype=amp_dtype):
111
+ out = F.scaled_dot_product_attention(
112
+ q, k, v, dropout_p=self.dropout if self.training else 0.0, is_causal=is_causal
113
+ )
114
+ else:
115
+ out = F.scaled_dot_product_attention(
116
+ q, k, v, dropout_p=self.dropout if self.training else 0.0, is_causal=is_causal
117
+ )
118
+ out = out.transpose(1, 2).contiguous().view(B, T_new, self.dim)
119
+ out = self.out_proj(out)
120
+ if use_cache:
121
+ return out, (k, v)
122
+ return out
123
+
124
+
125
class SwiGLU(nn.Module):
    """SwiGLU FFN with parameter names matching checkpoints (w1, w2, w3):
    - w1: Linear(dim -> hidden)
    - w2: Linear(hidden -> dim)
    - w3: Linear(dim -> hidden)
    Forward: w2(dropout(silu(w1(x)) * w3(x)))
    """

    def __init__(self, dim: int, hidden_mult: float = 4.0, dropout: float = 0.0):
        super().__init__()
        hidden = int(dim * hidden_mult)
        # Declaration order (w1, w2, w3) is kept so parameter init matches checkpoints.
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(self.dropout(gate * value))
142
+
143
+
144
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: RMSNorm -> attention -> residual, then
    RMSNorm -> SwiGLU -> residual. Attention is GQA or standard MHA by config.
    """

    def __init__(self, dim: int, num_heads: int, mlp_ratio: float, dropout: float = 0.0, num_kv_groups: Optional[int] = None, qk_norm: bool = False, attn_type: str = "mha"):
        super().__init__()
        if attn_type == "gqa":
            groups = num_kv_groups or num_heads
            self.attn = GroupedQueryAttention(dim, num_heads=num_heads, num_kv_groups=groups, dropout=dropout)
        else:
            self.attn = MultiHeadAttention(dim=dim, num_heads=num_heads, dropout=dropout)
        self.ffn = SwiGLU(dim, hidden_mult=mlp_ratio, dropout=dropout)
        self.ln1 = RMSNorm(dim)
        self.ln2 = RMSNorm(dim)

    def forward(self, x: torch.Tensor, past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = False, position_offset: int = 0):
        attn_out = self.attn(self.ln1(x), past_kv=past_kv, use_cache=use_cache, position_offset=position_offset)
        kv = None
        if use_cache:
            # attention returned (output, (k, v)) — unpack the cache
            attn_out, kv = attn_out
        x = x + attn_out
        x = x + self.ffn(self.ln2(x))
        return (x, kv) if use_cache else x
164
+
165
+
166
class MultiHeadAttention(nn.Module):
    """Standard multi-head attention with fused qkv projection, interleaved
    RoPE, explicit SDPA backend selection, and an optional KV cache.

    Parameter naming (``qkv``: dim -> 3*dim, ``out_proj``: dim -> dim) matches
    the released checkpoints and must not be changed.
    """

    def __init__(self, dim: int, num_heads: int, dropout: float = 0.0, use_rope: bool = True):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} must be divisible by num_heads {num_heads}"
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.dropout = dropout
        self.use_rope = use_rope

        # Fused query/key/value projection; no biases anywhere.
        self.qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.out_proj = nn.Linear(dim, dim, bias=False)

        # RoPE cos/sin tables cached per (length, device, dtype).
        self._rope_cache: dict[tuple[int, torch.device, torch.dtype], tuple[torch.Tensor, torch.Tensor]] = {}

    def _rope_cos_sin(self, T: int, device: torch.device, dtype: torch.dtype):
        """Return (cos, sin), each shaped (1, 1, T, head_dim), built lazily and cached."""
        key = (T, device, dtype)
        cached = self._rope_cache.get(key)
        if cached is not None:
            return cached
        dim_half = self.head_dim // 2
        # Standard RoPE frequency schedule with base 10000.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_half, device=device, dtype=torch.float32) / dim_half))
        t = torch.arange(T, device=device, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)
        # Interleave so adjacent (even, odd) feature pairs share an angle — matches _apply_rope.
        cos = torch.cos(freqs).repeat_interleave(2, dim=-1)
        sin = torch.sin(freqs).repeat_interleave(2, dim=-1)
        cos = cos.to(dtype).unsqueeze(0).unsqueeze(0)
        sin = sin.to(dtype).unsqueeze(0).unsqueeze(0)
        self._rope_cache[key] = (cos, sin)
        return cos, sin

    def forward(self, x: torch.Tensor, past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = False, position_offset: int = 0):
        """Attend over x (B, T_new, dim).

        With past_kv, T_new is typically 1 (incremental decode) and attention
        is non-causal over the concatenated history. Returns out, or
        (out, (k, v)) when use_cache is True.
        """
        B, T_new, _ = x.shape
        # (B, T, 3, H, hd) -> (3, B, H, T, hd), then split into q/k/v.
        qkv = self.qkv(x).view(B, T_new, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k_new, v_new = qkv[0].contiguous(), qkv[1].contiguous(), qkv[2].contiguous()

        if self.use_rope:
            # Build the table up to the absolute end position, then slice the
            # window for the new tokens only (cached k already carries its RoPE).
            cos, sin = self._rope_cos_sin(position_offset + T_new, x.device, q.dtype)
            if position_offset > 0:
                cos = cos[:, :, position_offset: position_offset + T_new, :]
                sin = sin[:, :, position_offset: position_offset + T_new, :]
            q = _apply_rope(q, cos, sin)
            k_new = _apply_rope(k_new, cos, sin)

        if past_kv is not None:
            k, v = past_kv
            k = torch.cat([k, k_new], dim=2)
            v = torch.cat([v, v_new], dim=2)
        else:
            k, v = k_new, v_new

        # Causal mask only on the full (prefill) pass; with a cache each new
        # query may attend to the whole history.
        is_causal = past_kv is None
        # Prefer Flash, then MemEff, then Math; Math keeps an FP32 fallback.
        with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
            if x.device.type == "cuda" and q.dtype not in (torch.float16, torch.bfloat16):
                # FP32 on CUDA: autocast to a half precision so Flash/MemEff can run.
                amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
                with torch.amp.autocast(device_type="cuda", dtype=amp_dtype):
                    out = F.scaled_dot_product_attention(
                        q, k, v, dropout_p=self.dropout if self.training else 0.0, is_causal=is_causal
                    )
            else:
                out = F.scaled_dot_product_attention(
                    q, k, v, dropout_p=self.dropout if self.training else 0.0, is_causal=is_causal
                )
        out = out.transpose(1, 2).contiguous().view(B, T_new, self.dim)
        if out.dtype != x.dtype:
            # Undo the autocast downcast before the output projection.
            out = out.to(x.dtype)
        out = self.out_proj(out)
        if use_cache:
            return out, (k, v)
        return out
CodonTranslator/models.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Optional, Dict, Any, Tuple, List
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torch.nn.utils.rnn as rnn_utils
9
+
10
+ from .layers import RMSNorm, TransformerBlock
11
+ from .tokenizer import SpecialIds
12
+
13
+
14
class FrozenESMCEncoder(nn.Module):
    """Frozen ESM-C protein encoder wrapper.

    NOTE(review): this does NOT stay inactive when `esm` is missing — the
    constructor raises ImportError. All parameters are frozen and the module
    is put in eval mode at construction.
    """
    def __init__(self, model_name: str = "esmc_300m", device: str = "cuda", dtype: str = "bf16"):
        super().__init__()
        self.model_name = model_name
        # Silently falls back to CPU when CUDA is unavailable, whatever was requested.
        self._device = torch.device(device if torch.cuda.is_available() else "cpu")
        # None means "run in full precision" (any dtype string other than bf16/fp16).
        self._autocast_dtype = torch.bfloat16 if dtype == "bf16" else (torch.float16 if dtype == "fp16" else None)
        try:
            from esm.models.esmc import ESMC  # type: ignore
            from esm.utils.constants.models import ESMC_300M, ESMC_600M  # type: ignore
        except Exception as e:
            raise ImportError(
                "ESM is required for CodonTranslator. Please install 'esm>=3.2.0'."
            ) from e
        # D_esm is the encoder's output embedding width, consumed by esm_ln downstream.
        if self.model_name == "esmc_300m":
            const = ESMC_300M; self.D_esm = 960
        elif self.model_name == "esmc_600m":
            const = ESMC_600M; self.D_esm = 1152
        else:
            raise ValueError(f"Unknown ESM model: {self.model_name}")
        self.model = ESMC.from_pretrained(model_name=const, device=self._device)
        self.tokenizer = self.model.tokenizer
        # Freeze everything: this encoder is inference-only.
        for p in self.parameters():
            p.requires_grad_(False)
        self.eval()

    @torch.no_grad()
    def tokenize(self, sequences: List[str], max_length: Optional[int] = None, add_special_tokens: bool = True, return_tensors: str = "pt"):
        """Tokenize protein strings into padded (input_ids, attention_mask).

        return_tensors is accepted for API familiarity but unused — output is
        always PyTorch tensors.
        """
        if self.model is None:
            raise RuntimeError("ESM model not available")
        from esm.utils import encoding  # type: ignore
        from esm.utils.misc import stack_variable_length_tensors  # type: ignore
        pad = self.tokenizer.pad_token_id
        toks = []
        for s in sequences:
            t = encoding.tokenize_sequence(s, self.tokenizer, add_special_tokens=add_special_tokens)
            # Hard truncation; NOTE(review): may drop the trailing special token.
            if max_length is not None and len(t) > max_length:
                t = t[:max_length]
            toks.append(t)
        input_ids = stack_variable_length_tensors(toks, constant_value=pad)
        attention_mask = (input_ids != pad)
        return input_ids, attention_mask

    @torch.no_grad()
    def encode_from_ids(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.BoolTensor] = None, return_dict: bool = True):
        """Run the frozen encoder; returns {"embeddings", "attention_mask"}.

        return_dict is accepted but the return shape is always a dict.
        """
        if self.model is None:
            raise RuntimeError("ESM model not available")
        device = self.model.device
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device) if attention_mask is not None else None
        if self._autocast_dtype is not None and device.type == "cuda":
            with torch.amp.autocast('cuda', dtype=self._autocast_dtype):
                outputs = self.model.forward(sequence_tokens=input_ids, sequence_id=attention_mask)
        else:
            outputs = self.model.forward(sequence_tokens=input_ids, sequence_id=attention_mask)
        return {"embeddings": outputs.embeddings, "attention_mask": attention_mask}

    def strip_special_tokens(self, embeddings: torch.FloatTensor, attention_mask: Optional[torch.BoolTensor] = None):
        """Drop the first and last positions (assumed BOS/EOS — TODO confirm
        against the ESM tokenizer) and return (stripped, per-sample lengths).
        """
        if attention_mask is not None:
            # -2 accounts for the two special positions; clamp guards empty rows.
            lengths = attention_mask.sum(dim=1) - 2
            lengths = lengths.clamp(min=1)
        else:
            B, L, D = embeddings.shape
            lengths = torch.full((B,), L - 2, device=embeddings.device)
        stripped = embeddings[:, 1:-1, :]
        return stripped, lengths
80
+
81
+
82
class TranslatorBackbone(nn.Module):
    """Decoder-only transformer over codon tokens, conditioned by a soft
    prefix of (optional) species embeddings and (optional) frozen ESM-C
    protein embeddings, followed by a learned start embedding.

    Sequence layout per sample: [species prefix][protein prefix][start][codons].
    """
    def __init__(
        self,
        vocab_size: int = 79,
        hidden_size: int = 960,
        num_layers: int = 24,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        max_position_embeddings: int = 4096,
        dropout: float = 0.1,
        layer_norm_eps: float = 1e-6,
        num_special_tokens: int = 13,
        special_ids: Optional[SpecialIds] = None,
        esm_model_name: str = "esmc_300m",
        esm_device: str = "cuda",
        esm_dtype: str = "bf16",
        max_protein_prefix: int = 0,   # 0 means "no cap" on protein prefix length
        max_species_prefix: int = 0,   # 0 means "no cap" on species token prefix
        prepend_species: bool = True,
        prepend_protein: bool = True,
        species_embedding_dim: int = 1024,
        attn_impl: str = "gqa",        # anything other than "mha" selects GQA
        num_kv_groups: int = 0,        # 0 -> GQA degenerates to one KV head per query head
    ):
        super().__init__()
        self.vocab_size = int(vocab_size)
        self.hidden_size = int(hidden_size)
        self.num_layers = int(num_layers)
        self.num_heads = int(num_heads)
        self.max_position_embeddings = int(max_position_embeddings)
        self.special_ids = special_ids or SpecialIds()
        self.num_special_tokens = int(num_special_tokens)

        self.token_embed = nn.Embedding(self.vocab_size, self.hidden_size)

        # Optional ESM protein encoder
        self.esm = None
        self.esm_ln = None
        if prepend_protein and esm_model_name:
            # Enforce ESM presence – raise if missing
            self.esm = FrozenESMCEncoder(esm_model_name, esm_device, esm_dtype)
            # Projects ESM width -> model width; name kept for checkpoint compat.
            self.esm_ln = nn.Sequential(
                nn.Linear(self.esm.D_esm, self.hidden_size, bias=False),
                nn.ReLU(),
                nn.LayerNorm(self.hidden_size),
            )
        self.species_embedding_dim = species_embedding_dim if prepend_species else 0
        self.species_ln = None
        if prepend_species:
            self.species_ln = nn.Sequential(
                nn.Linear(self.species_embedding_dim, self.hidden_size, bias=False),
                nn.ReLU(),
                nn.LayerNorm(self.hidden_size),
            )

        self.max_protein_prefix = int(max_protein_prefix) if max_protein_prefix is not None else 0
        self.max_species_prefix = int(max_species_prefix) if max_species_prefix is not None else 0
        self.prepend_species = bool(prepend_species)
        # Protein prepending is only effective when the ESM encoder was built.
        self.prepend_protein = bool(prepend_protein) and (self.esm is not None)

        # Learned "start of coding sequence" embedding, inserted after the prefix.
        self.start_embed = nn.Parameter(torch.zeros(1, 1, self.hidden_size))
        nn.init.normal_(self.start_embed, mean=0.0, std=0.02)

        self.attn_impl = str(attn_impl)
        kv_groups = int(num_kv_groups)
        self.blocks = nn.ModuleList([
            TransformerBlock(
                dim=self.hidden_size,
                num_heads=self.num_heads,
                mlp_ratio=mlp_ratio,
                dropout=dropout,
                num_kv_groups=(kv_groups if (kv_groups > 0 and self.attn_impl == "gqa") else None),
                qk_norm=False,
                attn_type=("mha" if self.attn_impl == "mha" else "gqa"),
            )
            for _ in range(self.num_layers)
        ])

        self.ln_f = RMSNorm(self.hidden_size, eps=layer_norm_eps)
        self.lm_head = nn.Linear(self.hidden_size, self.vocab_size, bias=False)
        self.gradient_checkpointing = False

    def embed_tokens(self, token_ids: torch.Tensor) -> torch.Tensor:
        """Embed codon ids, moving them to the embedding table's device first."""
        device = self.token_embed.weight.device
        return self.token_embed(token_ids.to(device))

    def build_prefix(
        self,
        batch_size: int,
        device: torch.device,
        species_tok_emb: Optional[torch.Tensor] = None,
        species_emb: Optional[torch.Tensor] = None,
        protein_input: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Assemble the conditioning prefix.

        Returns (prefix, lengths): prefix is (B, Lp, hidden) right-padded with
        zeros; lengths are the number of valid (non-zero) positions per sample.
        Positions beyond the per-sample budget (max_position_embeddings - 1,
        reserving one slot for the start embedding) are zeroed out.
        """
        parts: list[torch.Tensor] = []
        if self.prepend_species and self.species_ln is not None:
            if species_emb is not None:
                # Single pooled species embedding -> one prefix position.
                S = self.species_ln(species_emb.to(device=device, dtype=next(self.parameters()).dtype).unsqueeze(1))
                # NOTE(review): S is appended twice — the species block appears
                # two times in the prefix. This matches the other branch below,
                # so it may be the training-time layout; confirm before changing.
                parts.append(S)
                parts.append(S)
            elif species_tok_emb is not None:
                S = species_tok_emb
                if getattr(self, "max_species_prefix", 0) > 0 and S.size(1) > self.max_species_prefix:
                    S = S[:, : self.max_species_prefix, :]
                S = self.species_ln(S.to(device=device, dtype=next(self.parameters()).dtype))
                # NOTE(review): doubled append, same as above — verify intent.
                parts.append(S)
                parts.append(S)

        if self.prepend_protein and self.esm is not None and protein_input is not None:
            prot_ids, prot_mask = protein_input
            esm_out = self.esm.encode_from_ids(prot_ids, prot_mask, return_dict=True)
            P, lengths = self.esm.strip_special_tokens(esm_out["embeddings"], prot_mask)
            if getattr(self, "max_protein_prefix", 0) > 0 and P.size(1) > self.max_protein_prefix:
                P = P[:, : self.max_protein_prefix, :]
                lengths = lengths.clamp(max=self.max_protein_prefix) if lengths is not None else None
            if P.size(1) > 0:
                P = self.esm_ln(P.to(device=device, dtype=next(self.parameters()).dtype))
                if lengths is not None:
                    # Zero out padded positions so the validity scan below sees them as empty.
                    Lp = P.size(1)
                    ar = torch.arange(Lp, device=device).unsqueeze(0)
                    valid = ar < lengths.unsqueeze(1)
                    P = P * valid.unsqueeze(-1)
                parts.append(P)

        if len(parts) == 0:
            empty = torch.zeros(batch_size, 0, self.hidden_size, device=device, dtype=next(self.parameters()).dtype)
            return empty, torch.zeros(batch_size, dtype=torch.long, device=device)

        prefix = torch.cat(parts, dim=1)
        with torch.no_grad():
            # A position is "valid" iff it is not all-zero; relies on the masking above.
            valid = (prefix.abs().sum(dim=-1) > 0)
            lengths = valid.sum(dim=1).to(torch.long)
        # Reserve one position for the start embedding.
        prefix_budget = max(0, int(self.max_position_embeddings) - 1)
        allow = torch.minimum(lengths, torch.tensor(prefix_budget, device=lengths.device, dtype=lengths.dtype))
        Lp_max = int(allow.max().item()) if allow.numel() > 0 else 0
        if prefix.size(1) > Lp_max:
            # Re-pack each sample's allowed positions into a tighter buffer.
            trimmed = prefix.new_zeros(prefix.size(0), Lp_max, prefix.size(2))
            for b in range(prefix.size(0)):
                lb = int(allow[b].item())
                if lb > 0:
                    trimmed[b, :lb, :] = prefix[b, :lb, :]
            prefix = trimmed
            lengths = allow
        else:
            lengths = allow
        return prefix, lengths

    def forward(self, codon_ids: torch.Tensor, cond: Dict[str, Any] = None, labels: Optional[torch.Tensor] = None, return_dict: bool = True, use_cache: bool = False, past_kv: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, position_offset: int = 0) -> Dict[str, torch.Tensor]:
        """Run the backbone.

        cond may carry "species_tok_emb", "species_emb", "protein_input".
        labels and return_dict are accepted but unused here. Returns a dict
        with "logits" (per-codon positions), "next_logits" (next-token head),
        "prefix_len", and "present_kv" when caching.
        """
        batch_size, codon_len = codon_ids.shape
        device = codon_ids.device
        species_tok_emb = cond.get("species_tok_emb") if cond else None
        species_emb = cond.get("species_emb") if cond else None
        protein_input = cond.get("protein_input") if cond else None

        # Build prefix
        prefix, prefix_lengths = self.build_prefix(batch_size, device, species_tok_emb=species_tok_emb, species_emb=species_emb, protein_input=protein_input)
        start = self.start_embed.expand(batch_size, 1, -1)

        # KV cache path for incremental generation
        if past_kv is not None and codon_len > 0:
            x = self.embed_tokens(codon_ids)
            present_kv: List[Tuple[torch.Tensor, torch.Tensor]] = []
            for i, block in enumerate(self.blocks):
                kv_i = past_kv[i] if i < len(past_kv) else None
                out_blk = block(x, past_kv=kv_i, use_cache=True, position_offset=position_offset)
                x, kv_out = out_blk
                present_kv.append(kv_out)
            x = self.ln_f(x)
            logits_step = self.lm_head(x)
            # "logits" is intentionally empty on this path; consumers use "next_logits".
            return {"logits": logits_step[:, 0:0, :], "next_logits": logits_step[:, -1, :], "present_kv": present_kv, "prefix_len": prefix_lengths}

        # Non-incremental: build prefix+start+codon window
        codon_lens = torch.as_tensor([codon_len] * batch_size, device=device)
        capacity = max(0, int(self.max_position_embeddings))
        # Per-sample codon budget after accounting for prefix + start slot.
        budget_after_prefix = torch.clamp(torch.as_tensor(capacity, device=device) - (prefix_lengths + 1), min=0)
        per_cap = torch.minimum(budget_after_prefix, codon_lens)
        max_cap = int(per_cap.max().item()) if per_cap.numel() > 0 else 0
        codon_emb = self.embed_tokens(codon_ids[:, :max_cap]) if max_cap > 0 else torch.zeros(batch_size, 0, self.hidden_size, device=device, dtype=start.dtype)
        # Assemble each sample's [prefix | start | codons] then right-pad to a batch.
        seqs = []
        for b in range(batch_size):
            lp = int(prefix_lengths[b].item())
            cap = int(per_cap[b].item())
            parts = []
            if lp > 0:
                parts.append(prefix[b, :lp, :])
            parts.append(start[b, 0:1, :])
            if cap > 0:
                parts.append(codon_emb[b, :cap, :])
            seqs.append(torch.cat(parts, dim=0))
        x = rnn_utils.pad_sequence(seqs, batch_first=True)

        present_kv_list: List[Tuple[torch.Tensor, torch.Tensor]] = []
        for block in self.blocks:
            blk_out = block(x, use_cache=use_cache, position_offset=0)
            if use_cache:
                x, kv = blk_out
                present_kv_list.append(kv)
            else:
                x = blk_out
        x = self.ln_f(x)
        logits_full = self.lm_head(x)

        # Re-slice per sample: positions lp..lp+cap-1 predict the codon tokens;
        # position lp+cap-1's successor ("next") continues generation.
        next_logits_list = []
        if max_cap == 0:
            codon_logits = logits_full[:, 0:0, :]
            for b in range(batch_size):
                lp = int(prefix_lengths[b].item())
                pos_next = lp
                next_logits_list.append(logits_full[b, pos_next, :] if pos_next < logits_full.size(1) else logits_full[b, -1, :])
            next_logits = torch.stack(next_logits_list, dim=0)
        else:
            slices = []
            for b in range(batch_size):
                lp = int(prefix_lengths[b].item())
                cap = int(per_cap[b].item())
                sl = logits_full[b, lp : lp + cap, :] if cap > 0 else logits_full.new_zeros(0, self.vocab_size)
                slices.append(sl)
                pos_next = lp + cap
                # Out-of-range next position yields a zero vector (sample is at capacity).
                next_logits_list.append(logits_full[b, pos_next, :] if pos_next < logits_full.size(1) else logits_full.new_zeros(self.vocab_size))
            codon_logits = rnn_utils.pad_sequence(slices, batch_first=True)
            next_logits = torch.stack(next_logits_list, dim=0)
        out = {"logits": codon_logits, "next_logits": next_logits, "prefix_len": prefix_lengths}
        if use_cache:
            out["present_kv"] = present_kv_list
        return out
CodonTranslator/tokenizer.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Minimal copy of CodonTokenizer from src/tokenizer.py to keep the package self-contained.
2
+ from __future__ import annotations
3
+
4
+ import json
5
+ import os
6
+ from dataclasses import dataclass
7
+ from pathlib import Path
8
+ from typing import Dict, List, Optional, Tuple, Any
9
+
10
+
11
@dataclass(frozen=True)
class SpecialIds:
    """Immutable ids of the four special tokens (pad/unk/bos/eos)."""

    pad: int = 0
    unk: int = 1
    bos: int = 2
    eos: int = 3

    def to_dict(self) -> Dict[str, int]:
        """Return the ids keyed by role name, in pad/unk/bos/eos order."""
        return {name: getattr(self, name) for name in ("pad", "unk", "bos", "eos")}
20
+
21
+
22
class CodonTokenizer:
    """Tokenizer over 4 special tokens plus the 64 DNA codons.

    Ids 0..3 are pad/unk/bos/eos; codon ids follow in lexicographic
    (A < C < G < T) order, so "AAA" is id 4. Also exposes codon<->amino-acid
    maps built from the standard genetic code.
    """

    __slots__ = (
        "codons",
        "_special_token_str",
        "vocab",
        "ids_to_tokens",
        "_special_ids",
        "_num_special_tokens",
        "_genetic_code",
        "_codon2aa_char",
        "_aa2codons_char",
    )

    def __init__(
        self,
        pad_token: str = "<pad>",
        unk_token: str = "<unk>",
        bos_token: str = "<bos>",
        eos_token: str = "<stop>",
        **_: Any,
    ) -> None:
        bases = ("A", "C", "G", "T")
        # All 64 codons, lexicographic in (A, C, G, T) base order.
        self.codons: List[str] = [a + b + c for a in bases for b in bases for c in bases]

        special_tokens = [pad_token, unk_token, bos_token, eos_token]
        self._special_token_str = {"pad": pad_token, "unk": unk_token, "bos": bos_token, "eos": eos_token}

        # Sequential id assignment: specials first (0..3), then codons (4..67).
        # (The original used the obfuscated no-op expression
        # `len(special_tokens) + (len(self.vocab) - len(special_tokens))`,
        # which equals len(self.vocab).)
        self.vocab: Dict[str, int] = {}
        for tok in special_tokens:
            self.vocab[tok] = len(self.vocab)
        for codon in self.codons:
            self.vocab[codon] = len(self.vocab)

        self.ids_to_tokens: Dict[int, str] = {v: k for k, v in self.vocab.items()}

        self._special_ids = SpecialIds(
            pad=self.vocab[pad_token],
            unk=self.vocab[unk_token],
            bos=self.vocab[bos_token],
            eos=self.vocab[eos_token],
        )
        self._num_special_tokens = len(special_tokens)

        # Standard genetic code: codon -> one-letter amino acid ('*' = stop).
        self._genetic_code: Dict[str, str] = {
            "TTT": "F", "TTC": "F", "TTA": "L", "TTG": "L",
            "TCT": "S", "TCC": "S", "TCA": "S", "TCG": "S",
            "TAT": "Y", "TAC": "Y", "TAA": "*", "TAG": "*",
            "TGT": "C", "TGC": "C", "TGA": "*", "TGG": "W",
            "CTT": "L", "CTC": "L", "CTA": "L", "CTG": "L",
            "CCT": "P", "CCC": "P", "CCA": "P", "CCG": "P",
            "CAT": "H", "CAC": "H", "CAA": "Q", "CAG": "Q",
            "CGT": "R", "CGC": "R", "CGA": "R", "CGG": "R",
            "ATT": "I", "ATC": "I", "ATA": "I", "ATG": "M",
            "ACT": "T", "ACC": "T", "ACA": "T", "ACG": "T",
            "AAT": "N", "AAC": "N", "AAA": "K", "AAG": "K",
            "AGT": "S", "AGC": "S", "AGA": "R", "AGG": "R",
            "GTT": "V", "GTC": "V", "GTA": "V", "GTG": "V",
            "GCT": "A", "GCC": "A", "GCA": "A", "GCG": "A",
            "GAT": "D", "GAC": "D", "GAA": "E", "GAG": "E",
            "GGT": "G", "GGC": "G", "GGA": "G", "GGG": "G",
        }

        # Derived id-level maps: codon id -> AA char, and AA char -> codon ids.
        self._codon2aa_char: Dict[int, str] = {}
        self._aa2codons_char: Dict[str, List[int]] = {ch: [] for ch in "ACDEFGHIKLMNPQRSTVWY*"}
        for codon in self.codons:
            cid = self.vocab[codon]
            aa = self._genetic_code.get(codon, "X")
            self._codon2aa_char[cid] = aa
            if aa in self._aa2codons_char:
                self._aa2codons_char[aa].append(cid)

    @property
    def vocab_size(self) -> int:
        """Total number of tokens (68 for the default vocabulary)."""
        return len(self.vocab)

    @property
    def special_ids(self) -> SpecialIds:
        return self._special_ids

    @property
    def num_special_tokens(self) -> int:
        return self._num_special_tokens

    @property
    def pad_token_id(self) -> int:
        return self._special_ids.pad

    @property
    def eos_token_id(self) -> int:
        return self._special_ids.eos

    # helpers
    def codon_vocab(self) -> Dict[str, int]:
        """Return only the codon portion of the vocabulary (copy)."""
        return {c: self.vocab[c] for c in self.codons}

    def codon2aa_char_map(self) -> Dict[int, str]:
        """Return a copy of the codon-id -> amino-acid-char map."""
        return dict(self._codon2aa_char)

    def aa2codons_char_map(self) -> Dict[str, List[int]]:
        """Return a deep-enough copy of the amino-acid-char -> codon-ids map."""
        return {k: v[:] for k, v in self._aa2codons_char.items()}

    # decoding
    def decode_codon_seq(self, token_ids: List[int]) -> str:
        """Concatenate the codon strings for the given ids, silently skipping
        special-token ids and any id not in the vocabulary."""
        parts: List[str] = []
        nst = self._num_special_tokens
        for tid in token_ids:
            if tid >= nst:
                tok = self.ids_to_tokens.get(tid)
                if tok is not None:
                    parts.append(tok)
        return "".join(parts)

    # persistence
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Write vocab + special-token strings to '<prefix->vocab.json'.

        Returns a 1-tuple with the written path (HF-tokenizer convention).
        """
        os.makedirs(save_directory, exist_ok=True)
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )
        payload = {
            "vocab": self.vocab,
            "special_token_str": self._special_token_str,
        }
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False, indent=2, sort_keys=True)
        return (vocab_file,)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs: Any) -> "CodonTokenizer":
        """Load a tokenizer from a directory containing vocab.json.

        If the file is missing, returns a default-constructed tokenizer.
        Accepts both the payload format written by save_vocabulary and a bare
        {token: id} mapping.
        """
        vocab_path = Path(pretrained_model_name_or_path) / "vocab.json"
        tok = cls(**kwargs)
        if not vocab_path.exists():
            return tok
        with open(vocab_path, "r", encoding="utf-8") as f:
            save_data = json.load(f)
        vocab = save_data["vocab"] if isinstance(save_data, dict) and "vocab" in save_data else save_data
        tok.vocab = {str(k): int(v) for k, v in vocab.items()}
        tok.ids_to_tokens = {int(v): str(k) for k, v in tok.vocab.items()}
        sts = save_data.get("special_token_str", tok._special_token_str) if isinstance(save_data, dict) else tok._special_token_str
        tok._special_token_str.update(sts)

        def _id_for(name: str, default_val: int) -> int:
            sym = tok._special_token_str[name]
            return int(tok.vocab.get(sym, default_val))

        tok._special_ids = SpecialIds(
            pad=_id_for("pad", 0),
            unk=_id_for("unk", 1),
            bos=_id_for("bos", 2),
            eos=_id_for("eos", 3),
        )
        # If special ids form a contiguous 0..m block, everything above m is a
        # codon; otherwise fall back to the default count of 4.
        ids = [tok._special_ids.pad, tok._special_ids.unk, tok._special_ids.bos, tok._special_ids.eos]
        m = max(ids)
        tok._num_special_tokens = m + 1 if ids == list(range(m + 1)) else 4
        # rebuild helpers against the loaded ids
        tok._codon2aa_char = {}
        tok._aa2codons_char = {ch: [] for ch in "ACDEFGHIKLMNPQRSTVWY*"}
        for codon in tok.codons:
            cid = tok.vocab[codon]
            aa = tok._genetic_code.get(codon, "X")
            tok._codon2aa_char[cid] = aa
            if aa in tok._aa2codons_char:
                tok._aa2codons_char[aa].append(cid)
        return tok
CodonTranslator/translator.py ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from pathlib import Path
5
+ from typing import Any, Dict, List, Optional, Tuple, Union
6
+ import logging
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import numpy as np
11
+ from safetensors.torch import load_file
12
+
13
+ from .models import TranslatorBackbone
14
+ from .tokenizer import CodonTokenizer
15
+ # no external store at inference; species embeddings computed via Qwen
16
+
17
+
18
+ class CodonTranslator:
19
+ """
20
+ High-level sampling wrapper for trained checkpoints with a simple API:
21
+
22
+ from CodonTranslator import CodonTranslator
23
+ model = CodonTranslator.from_pretrained(model_path)
24
+ dna = model.sampling(species="Homo sapiens", protein_seq="M...", enforce_mapping=True)
25
+ """
26
+
27
    def __init__(self, model_dir: Union[str, Path], device: str = "cuda", use_gbif: bool = False):
        """Load a trained checkpoint directory into an inference-ready wrapper.

        Args:
            model_dir: directory containing the tokenizer vocab, a config file
                (``trainer_config.json`` preferred, ``config.json`` fallback)
                and weights (``model.safetensors`` or ``pytorch_model.bin``).
            device: requested device; falls back to CPU when CUDA is
                unavailable.
            use_gbif: if True, species names are enriched via the GBIF
                taxonomy API before embedding.
        """
        self.model_dir = Path(model_dir)
        # NOTE(review): only CUDA availability is checked here; any other
        # non-"cpu" device string (e.g. "mps") is coerced to CPU as well.
        self.device = torch.device(device if torch.cuda.is_available() or device == "cpu" else "cpu")
        self.tokenizer = CodonTokenizer.from_pretrained(str(self.model_dir))
        self.V = int(self.tokenizer.vocab_size)
        self._eos_id = int(self.tokenizer.eos_token_id)
        self._pad_id = int(self.tokenizer.pad_token_id)
        self._num_special = int(self.tokenizer.num_special_tokens)

        # Load config
        cfg_path = self.model_dir / "trainer_config.json"
        if not cfg_path.exists():
            cfg_path = self.model_dir / "config.json"
        with open(cfg_path, "r") as f:
            self.config = json.load(f)

        # Build model and load weights; architecture hyper-parameters are
        # inferred from tensor shapes in the checkpoint, with config overrides.
        state = self._load_state_dict()
        arch = self._infer_arch_from_state_dict(state)
        self.model = TranslatorBackbone(
            vocab_size=self.V,
            hidden_size=int(arch["hidden_size"]),
            num_layers=int(arch["num_layers"]),
            num_heads=int(arch["num_heads"]),
            mlp_ratio=float(arch.get("mlp_ratio", 4.0)),
            max_position_embeddings=int(arch["max_position_embeddings"]),
            num_special_tokens=self._num_special,
            special_ids=self.tokenizer.special_ids,
            prepend_species=bool(arch.get("prepend_species", True)),
            prepend_protein=bool(arch.get("prepend_protein", False)),
            species_embedding_dim=int(self.config.get("species_embedding_dim", 1024)),
            esm_model_name=str(arch.get("esm_model_name", "esmc_300m")),
            esm_device=str(arch.get("esm_device", "cuda")),
            esm_dtype=str(arch.get("esm_dtype", "bf16")),
            max_protein_prefix=int(arch.get("max_protein_prefix", 0)),
            max_species_prefix=int(arch.get("max_species_prefix", 0)),
            attn_impl=str(arch.get("attn_impl", "gqa")),
            num_kv_groups=int(arch.get("num_kv_groups", 0)),
        )
        # strict=False tolerates missing/extra keys (e.g. auxiliary modules
        # not serialized in the checkpoint). Mismatches are deliberately
        # ignored rather than raised.
        missing, unexpected = self.model.load_state_dict(state, strict=False)
        if len(unexpected) > 0:
            # non-fatal
            pass
        self.model.to(self.device).eval()

        # Static masks over the vocabulary:
        # fixed-length decoding forbids all special tokens; variable-length
        # decoding additionally re-allows EOS.
        self._allowed_fixed = torch.ones(self.V, dtype=torch.bool, device=self.device)
        self._allowed_fixed[:self._num_special] = False
        self._allowed_variable = torch.ones(self.V, dtype=torch.bool, device=self.device)
        self._allowed_variable[:self._num_special] = False
        self._allowed_variable[self._eos_id] = True

        # Species taxonomy: either query GBIF (if allowed) or use raw names.
        self._use_gbif = bool(use_gbif)
        self._taxonomy_cache: Dict[str, str] = {}
82
+
83
+ # ---- constructors ----
84
+ @classmethod
85
+ def from_pretrained(cls, model_path: Union[str, Path], device: str = "cuda", use_gbif: bool = False) -> "CodonTranslator":
86
+ return cls(model_path, device=device, use_gbif=use_gbif)
87
+
88
+ # ---- sampling APIs ----
89
+ @torch.no_grad()
90
+ def sampling(self, species: str, protein_seq: str, enforce_mapping: bool = False, temperature: float = 1.0, top_k: Optional[int] = None, top_p: Optional[float] = None, seed: Optional[int] = None, use_kv_cache: bool = True) -> str:
91
+ out = self.batch_inference(
92
+ species=[species],
93
+ protein_seqs=[protein_seq],
94
+ enforce_mapping=enforce_mapping,
95
+ temperature=temperature,
96
+ top_k=top_k,
97
+ top_p=top_p,
98
+ seed=seed,
99
+ use_kv_cache=use_kv_cache,
100
+ )
101
+ return out[0]
102
+
103
+ @torch.no_grad()
104
+ def batch_inference(
105
+ self,
106
+ species: List[str],
107
+ protein_seqs: List[str],
108
+ enforce_mapping: bool = False,
109
+ temperature: float = 1.0,
110
+ top_k: Optional[int] = None,
111
+ top_p: Optional[float] = None,
112
+ seed: Optional[int] = None,
113
+ use_kv_cache: bool = True,
114
+ micro_batch_size: int = 1,
115
+ ) -> List[str]:
116
+ """Generate DNA for a list of protein sequences, using micro-batching to limit memory.
117
+
118
+ - micro_batch_size: number of samples to process at once (default=1 for low memory)
119
+ """
120
+ assert len(species) == len(protein_seqs), "species and protein_seqs length must match"
121
+ mb = max(1, int(micro_batch_size))
122
+ if len(species) <= mb:
123
+ return self._batch_inference_core(
124
+ species=species,
125
+ protein_seqs=protein_seqs,
126
+ enforce_mapping=enforce_mapping,
127
+ temperature=temperature,
128
+ top_k=top_k,
129
+ top_p=top_p,
130
+ seed=seed,
131
+ use_kv_cache=use_kv_cache,
132
+ )
133
+
134
+ outputs: List[str] = []
135
+ for start in range(0, len(species), mb):
136
+ end = min(start + mb, len(species))
137
+ chunk_out = self._batch_inference_core(
138
+ species=species[start:end],
139
+ protein_seqs=protein_seqs[start:end],
140
+ enforce_mapping=enforce_mapping,
141
+ temperature=temperature,
142
+ top_k=top_k,
143
+ top_p=top_p,
144
+ seed=seed,
145
+ use_kv_cache=use_kv_cache,
146
+ )
147
+ outputs.extend(chunk_out)
148
+ return outputs
149
+
150
    @torch.no_grad()
    def _batch_inference_core(
        self,
        species: List[str],
        protein_seqs: List[str],
        enforce_mapping: bool = False,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        seed: Optional[int] = None,
        use_kv_cache: bool = True,
    ) -> List[str]:
        """Autoregressively decode one micro-batch of DNA sequences.

        Exactly one codon is sampled per amino acid of each target protein;
        samples that have reached their target length are forced to PAD for
        the remaining steps. Returns decoded DNA strings aligned with the
        inputs.
        """
        if seed is not None:
            torch.manual_seed(int(seed))
            np.random.seed(int(seed))
        B = len(species)
        assert B == len(protein_seqs), "species and protein_seqs length must match"
        # Decode as many steps as the longest protein in the batch.
        target_lens = torch.tensor([len(s) for s in protein_seqs], device=self.device, dtype=torch.long)
        T_codons = int(target_lens.max().item())

        # Prepare conditioning
        cond: Dict[str, Any] = {"control_mode": "fixed"}

        # Species embeddings via Qwen3-Embedding (variable-length token sequences)
        # NOTE(review): `lengths` is computed but never used below.
        q_tok, lengths = self._qwen_embed_names(species, pooling="sequence") # [B, L, D]
        # Always surface a message so users can see species embeddings are used
        print(f"[CodonTranslator] Species embeddings (Qwen) computed: shape={tuple(q_tok.shape)}")
        cond["species_tok_emb"] = q_tok.to(self.device)

        # Protein input via ESM (if available) – let model tokenize internally
        if getattr(self.model, "esm", None) is not None:
            # Tokenize AA sequences with model.esm
            # The +2 presumably reserves room for ESM BOS/EOS tokens — TODO confirm.
            max_len_tokens = (getattr(self.model, "max_protein_prefix", 0) + 2) if getattr(self.model, "max_protein_prefix", 0) > 0 else None
            prot_ids, prot_mask = self.model.esm.tokenize(protein_seqs, max_length=max_len_tokens)
            cond["protein_input"] = (prot_ids.to(self.device), prot_mask.to(self.device))

        # Start generation with empty context to build KV cache and initial logits
        input_ids = torch.zeros(B, 0, dtype=torch.long, device=self.device)
        out_prefill = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=use_kv_cache)
        kv = out_prefill.get("present_kv") if use_kv_cache else None
        logits = out_prefill.get("next_logits")
        assert logits is not None
        # Report prefix length to prove species/protein prefixes were incorporated
        # NOTE(review): the "(species,species,protein)" label looks like a typo
        # for "(total,species,protein)" — confirm against the model's
        # prefix_len semantics before relying on this log line.
        try:
            pref = out_prefill.get("prefix_len")
            if pref is not None:
                lst = pref.detach().cpu().tolist()
                print(f"[CodonTranslator] Prefix lengths (species,species,protein): {lst}")
        except Exception:
            pass

        # Fixed-length decoding: special tokens are masked out for all steps.
        allowed = self._allowed_fixed
        # NOTE(review): `finished` is never read afterwards.
        finished = torch.zeros(B, dtype=torch.bool, device=self.device)

        aa2codons = self.tokenizer.aa2codons_char_map()

        rng = range(T_codons)
        # Greedy mode: temperature <= 0 selects argmax deterministically
        greedy_mode = (temperature is not None and float(temperature) <= 0.0)
        for step in rng:
            logits = logits.masked_fill(~allowed, float("-inf"))

            # Stop sampling per-sample once reaching its target length; force PAD
            done_now = (torch.tensor(step, device=self.device) >= target_lens)
            if done_now.any():
                # Setting PAD's logit to 0 after -inf everywhere else makes
                # PAD the only finite (hence certain) choice.
                logits[done_now] = float("-inf")
                logits[done_now, self._pad_id] = 0.0

            # Enforce codon ↔ AA mapping at this step
            if enforce_mapping:
                aas_now = [seq[step] if step < len(seq) else None for seq in protein_seqs]
                mask = torch.zeros_like(logits, dtype=torch.bool)
                for i, a in enumerate(aas_now):
                    if a is None:
                        # Past this protein's end: any non-special codon is fine.
                        mask[i, self._num_special:self.V] = True
                    else:
                        valid = aa2codons.get(a, [])
                        if len(valid) == 0:
                            # Unknown amino-acid letter: fall back to all codons.
                            mask[i, self._num_special:self.V] = True
                        else:
                            mask[i, valid] = True
                logits = logits.masked_fill(~mask, float("-inf"))

            # Distribution shaping only applies in sampling mode.
            if not greedy_mode and temperature != 1.0:
                logits = logits / float(temperature)
            if top_k is not None:
                logits = self._top_k_filtering(logits, int(top_k))
            if top_p is not None:
                logits = self._top_p_filtering(logits, float(top_p))

            if greedy_mode:
                next_tok = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                probs = F.softmax(logits, dim=-1)
                next_tok = torch.multinomial(probs, num_samples=1)

            input_ids = torch.cat([input_ids, next_tok], dim=1)

            if use_kv_cache:
                # Absolute position of the new token = max prefix length
                # + number of generated tokens - 1.
                pos_offset = int(out_prefill.get("prefix_len").max().item()) + input_ids.size(1) - 1 if isinstance(out_prefill, dict) and ("prefix_len" in out_prefill) else input_ids.size(1) - 1
                out_inc = self.model(
                    codon_ids=next_tok,
                    cond=None,
                    return_dict=True,
                    use_cache=True,
                    past_kv=kv,
                    position_offset=pos_offset,
                )
                kv = out_inc.get("present_kv")
                logits = out_inc.get("next_logits")
            else:
                # Recompute full forward with prefix+all generated tokens
                out_full = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=False)
                logits = out_full.get("next_logits")

        # Build DNA strings, dropping specials
        output_token_rows: List[List[int]] = []
        for i, row in enumerate(input_ids.tolist()):
            toks: List[int] = []
            for t in row:
                if t == self._pad_id:
                    continue
                if t == self._eos_id:
                    break
                if t >= self._num_special and t < self.V:
                    toks.append(int(t))
            # Defensive truncation to the target protein length.
            toks = toks[: int(target_lens[i].item())]
            output_token_rows.append(toks)
        sequences = [self.tokenizer.decode_codon_seq(row) for row in output_token_rows]

        # If not enforcing mapping, report AA token accuracy vs provided targets
        if not enforce_mapping:
            for i, dna in enumerate(sequences):
                tgt = protein_seqs[i]
                gen_aa = self._dna_to_aa(dna)
                L = min(len(gen_aa), len(tgt))
                if L == 0:
                    acc = 0.0; num = 0; den = 0
                else:
                    num = sum(1 for a, b in zip(gen_aa[:L], tgt[:L]) if a == b)
                    den = L
                    acc = num / den
                print(f"[CodonTranslator] AA token accuracy seq_{i+1}: {acc:.4f} ({num}/{den})")
        return sequences
294
+
295
+ # ---- helpers ----
296
+ def _load_state_dict(self) -> Dict[str, torch.Tensor]:
297
+ st_p = self.model_dir / "model.safetensors"
298
+ if st_p.exists():
299
+ return load_file(st_p)
300
+ pt_p = self.model_dir / "pytorch_model.bin"
301
+ if pt_p.exists():
302
+ return torch.load(pt_p, map_location="cpu")
303
+ raise FileNotFoundError(f"No model.safetensors or pytorch_model.bin in {self.model_dir}")
304
+
305
    def _infer_arch_from_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """Reconstruct architecture hyper-parameters from checkpoint tensors.

        Tensor shapes are the primary evidence (hidden size, layer count,
        MLP ratio); values present in ``self.config`` override or fill the
        gaps. Returns a dict consumed by the TranslatorBackbone constructor.
        """
        arch: Dict[str, Any] = {}
        # Hidden size: width of the LM head, else of the final layer norm.
        if "lm_head.weight" in state_dict:
            arch["hidden_size"] = int(state_dict["lm_head.weight"].shape[1])
        else:
            for k, v in state_dict.items():
                if k.endswith("ln_f.weight"):
                    arch["hidden_size"] = int(v.shape[0])
                    break
        cfg = self.config or {}
        # Config wins over the shape-derived value when both exist.
        if "hidden_size" in cfg:
            arch["hidden_size"] = int(cfg["hidden_size"]) # type: ignore
        if "hidden_size" not in arch:
            arch["hidden_size"] = int(cfg.get("hidden_size", 750))
        H = int(arch["hidden_size"])

        # Layer count: 1 + highest index seen among "blocks.<i>." keys.
        max_block = -1
        for k in state_dict.keys():
            if k.startswith("blocks."):
                idx = int(k.split(".")[1])
                if idx > max_block:
                    max_block = idx
        arch["num_layers"] = (max_block + 1) if max_block >= 0 else int(cfg.get("num_hidden_layers", 12))
        if "num_hidden_layers" in cfg:
            arch["num_layers"] = int(cfg["num_hidden_layers"]) # type: ignore

        # mlp ratio: FFN expansion factor from the first FFN weight's rows.
        w1_key = next((k for k in state_dict.keys() if k.endswith("ffn.w1.weight")), None)
        if w1_key is not None:
            arch["mlp_ratio"] = float(int(state_dict[w1_key].shape[0]) / H)
        else:
            arch["mlp_ratio"] = float(cfg.get("mlp_ratio", 4.0))

        # heads: pick divisor — trust the config when it divides H evenly,
        # otherwise take the largest candidate divisor of H.
        cfg_heads = cfg.get("num_attention_heads")
        if isinstance(cfg_heads, int) and cfg_heads > 0 and H % cfg_heads == 0:
            arch["num_heads"] = int(cfg_heads)
        else:
            for h in (16, 15, 12, 10, 8, 6, 5, 4, 3, 2, 1):
                if H % h == 0:
                    arch["num_heads"] = h
                    break

        # Conditioning prefixes: the presence of species/ESM layer-norm
        # parameters in the checkpoint implies the prefix was trained.
        arch["prepend_species"] = bool(cfg.get("prepend_species", any(k.startswith("species_ln.") for k in state_dict.keys())))
        has_esm = any(k.startswith("esm_ln.") for k in state_dict.keys()) or any(k.startswith("esm.") for k in state_dict.keys())
        arch["prepend_protein"] = bool(cfg.get("prepend_protein", bool(has_esm)))
        arch["esm_model_name"] = str(cfg.get("esm_model_name", "esmc_300m"))
        arch["esm_device"] = str(cfg.get("esm_device", "cuda"))
        arch["esm_dtype"] = str(cfg.get("esm_dtype", "bf16")).lower()
        arch["max_protein_prefix"] = int(cfg.get("max_protein_prefix", 0))
        arch["max_species_prefix"] = int(cfg.get("max_species_prefix", 0))
        # "max_length" is the trainer-side name; "max_position_embeddings"
        # the model-side one.
        arch["max_position_embeddings"] = int(cfg.get("max_length", cfg.get("max_position_embeddings", 2048)))
        arch["attn_impl"] = str(cfg.get("attn_impl", "gqa"))
        arch["num_kv_groups"] = int(cfg.get("num_kv_groups", 0))
        return arch
360
+
361
+ # --- filtering helpers
362
+ @staticmethod
363
+ def _ensure_2d_logits(logits: torch.Tensor) -> torch.Tensor:
364
+ return logits if logits.dim() == 2 else logits.unsqueeze(0)
365
+
366
+ @staticmethod
367
+ def _top_k_filtering(logits: torch.Tensor, k: int) -> torch.Tensor:
368
+ x = CodonTranslator._ensure_2d_logits(logits)
369
+ k = max(1, min(int(k), x.size(-1)))
370
+ values, _ = torch.topk(x, k, dim=-1)
371
+ min_values = values[:, -1].unsqueeze(-1)
372
+ x = torch.where(x < min_values, torch.full_like(x, float('-inf')), x)
373
+ return x if logits.dim() == 2 else x.squeeze(0)
374
+
375
+ @staticmethod
376
+ def _top_p_filtering(logits: torch.Tensor, p: float) -> torch.Tensor:
377
+ if p >= 1.0:
378
+ return logits
379
+ if p <= 0.0:
380
+ return torch.full_like(logits, float('-inf'))
381
+ x = CodonTranslator._ensure_2d_logits(logits)
382
+ sorted_logits, sorted_indices = torch.sort(x, descending=True, dim=-1)
383
+ probs = torch.softmax(sorted_logits, dim=-1)
384
+ cumprobs = torch.cumsum(probs, dim=-1)
385
+ to_remove = cumprobs > p
386
+ # Avoid overlapping memory writes by cloning the RHS
387
+ to_remove = to_remove.to(torch.bool)
388
+ to_remove[:, 1:] = to_remove[:, :-1].clone()
389
+ to_remove[:, 0] = False
390
+ mask = torch.zeros_like(x, dtype=torch.bool).scatter(-1, sorted_indices, to_remove)
391
+ x = torch.where(mask, torch.full_like(x, float('-inf')), x)
392
+ return x if logits.dim() == 2 else x.squeeze(0)
393
+
394
+ # --- Qwen embedding fallback for species text ---
395
+ def _qwen_embed_names(self, names: List[str], pooling: str = "sequence") -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
396
+ from transformers import AutoTokenizer, AutoModel
397
+ tokenizer = AutoTokenizer.from_pretrained(
398
+ "Qwen/Qwen3-Embedding-0.6B", trust_remote_code=True, padding_side="left"
399
+ )
400
+ dtype = torch.float16 if self.device.type == "cuda" else torch.float32
401
+ model = AutoModel.from_pretrained(
402
+ "Qwen/Qwen3-Embedding-0.6B", dtype=dtype, trust_remote_code=True
403
+ ).to(self.device).eval()
404
+ task = (
405
+ "Given a species taxonomy information, generate a biological embedding "
406
+ "representing its taxonomic and evolutionary characteristics"
407
+ )
408
+ queries = self._resolve_taxonomy_texts(names)
409
+ texts = [f"Instruct: {task}\nQuery: {q}" for q in queries]
410
+ inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512).to(self.device)
411
+ out = model(**inputs)
412
+ h = torch.nn.functional.normalize(out.last_hidden_state, p=2, dim=-1)
413
+ attn = inputs["attention_mask"]
414
+ # sequence embeddings padded to same length by tokenizer padding
415
+ return h, torch.sum(attn, dim=1)
416
+
417
+ def _taxonomy_lookup(self, name: str) -> str:
418
+ if name in self._taxonomy_cache:
419
+ return self._taxonomy_cache[name]
420
+ if self._use_gbif:
421
+ try:
422
+ import requests
423
+ resp = requests.get("https://api.gbif.org/v1/species/match", params={"name": name}, timeout=5)
424
+ if resp.status_code == 200:
425
+ data = resp.json()
426
+ if data.get("matchType") != "NONE":
427
+ parts = []
428
+ taxonomy = []
429
+ for rank in ["kingdom", "phylum", "class", "order", "family", "genus", "species"]:
430
+ if rank in data and data[rank]:
431
+ taxonomy.append(data[rank])
432
+ if taxonomy:
433
+ parts.append("Taxonomy: " + " > ".join(taxonomy))
434
+ if "vernacularName" in data and data["vernacularName"]:
435
+ parts.append(f"Common name: {data['vernacularName']}")
436
+ if "confidence" in data:
437
+ parts.append(f"Match confidence: {data['confidence']}%")
438
+ if "status" in data:
439
+ parts.append(f"Status: {data['status']}")
440
+ desc = ". ".join(parts) if parts else name
441
+ self._taxonomy_cache[name] = desc
442
+ return desc
443
+ except Exception:
444
+ pass
445
+ return name
446
+
447
+ def _resolve_taxonomy_texts(self, names: List[str]) -> List[str]:
448
+ """Resolve taxonomy strings for a batch of species names.
449
+ If a taxonomy DB is present, pull from it. Otherwise batch-query GBIF
450
+ (one request per species) and cache results. Always returns a list of
451
+ strings aligned to `names`.
452
+ """
453
+ results: List[str] = []
454
+ # Batch “query”: loop per-name; still batched at the embedding stage
455
+ fetched = 0
456
+ for s in names:
457
+ txt = self._taxonomy_lookup(s)
458
+ if s in self._taxonomy_cache:
459
+ fetched += 1
460
+ results.append(txt)
461
+ if self._use_gbif:
462
+ print(f"[CodonTranslator] Taxonomy texts resolved (GBIF={'on' if self._use_gbif else 'off'}): {fetched}/{len(names)} fetched")
463
+ return results
464
+
465
+ @staticmethod
466
+ def _dna_to_aa(dna_seq: str) -> str:
467
+ g = {
468
+ 'TTT': 'F', 'TTC': 'F', 'TTA': 'L', 'TTG': 'L', 'TCT': 'S', 'TCC': 'S', 'TCA': 'S', 'TCG': 'S',
469
+ 'TAT': 'Y', 'TAC': 'Y', 'TAA': '*', 'TAG': '*', 'TGT': 'C', 'TGC': 'C', 'TGA': '*', 'TGG': 'W',
470
+ 'CTT': 'L', 'CTC': 'L', 'CTA': 'L', 'CTG': 'L', 'CCT': 'P', 'CCC': 'P', 'CCA': 'P', 'CCG': 'P',
471
+ 'CAT': 'H', 'CAC': 'H', 'CAA': 'Q', 'CAG': 'Q', 'CGT': 'R', 'CGC': 'R', 'CGA': 'R', 'CGG': 'R',
472
+ 'ATT': 'I', 'ATC': 'I', 'ATA': 'I', 'ATG': 'M', 'ACT': 'T', 'ACC': 'T', 'ACA': 'T', 'ACG': 'T',
473
+ 'AAT': 'N', 'AAC': 'N', 'AAA': 'K', 'AAG': 'K', 'AGT': 'S', 'AGC': 'S', 'AGA': 'R', 'AGG': 'R',
474
+ 'GTT': 'V', 'GTC': 'V', 'GTA': 'V', 'GTG': 'V', 'GCT': 'A', 'GCC': 'A', 'GCA': 'A', 'GCG': 'A',
475
+ 'GAT': 'D', 'GAC': 'D', 'GAA': 'E', 'GAG': 'E', 'GGT': 'G', 'GGC': 'G', 'GGA': 'G', 'GGG': 'G'
476
+ }
477
+ L = len(dna_seq) // 3
478
+ aa = [g.get(dna_seq[3*i:3*i+3], 'X') for i in range(L)]
479
+ return ''.join(aa)
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2026 CodonTranslator authors
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ library_name: pytorch
4
+ tags:
5
+ - biology
6
+ - dna
7
+ - codon-optimization
8
+ - protein-conditioned-generation
9
+ - fsdp
10
+ datasets:
11
+ - alegendaryfish/CodonTranslator-data
12
+ ---
13
+
14
+ # CodonTranslator
15
+
16
+ CodonTranslator is a protein-conditioned codon sequence generation model trained on the representative-only `data_v3` release.
17
+
18
+ This repository is the public model and training-code release. It contains:
19
+
20
+ - `final_model/`: inference-ready weights
21
+ - `training_checkpoints/checkpoint-71000/`: a resumable training checkpoint
22
+ - `src/`, `train.py`, `sampling.py`: training and inference code
23
+ - `resplit_data_v3.py`: the `data_v3` reconstruction pipeline
24
+ - `slurm/`: the single-node H200 training and data rebuild submission scripts
25
+ - `CodonTranslator/` and `pyproject.toml`: a lightweight packaged inference wrapper
26
+
27
+ ## Training configuration
28
+
29
+ - Architecture: `hidden=750`, `layers=20`, `heads=15`, `mlp_ratio=3.2`
30
+ - Attention: `mha`
31
+ - Precision: `bf16`
32
+ - Parallelism: FSDP full shard
33
+ - Effective global batch: `1536`
34
+ - Weight decay: `1e-4`
35
+ - Dataset: `alegendaryfish/CodonTranslator-data`
36
+
37
+ ## Dataset release
38
+
39
+ The corresponding public dataset and species embedding release is:
40
+
41
+ - `alegendaryfish/CodonTranslator-data`
42
+
43
+ That dataset repo contains:
44
+
45
+ - final representative-only `train/`, `val/`, `test/` parquet shards
46
+ - `embeddings_v2/`
47
+ - split audit files and reconstruction metadata
48
+
49
+ ## Quick start
50
+
51
+ ### Install
52
+
53
+ ```bash
54
+ git clone https://huggingface.co/alegendaryfish/CodonTranslator
55
+ cd CodonTranslator
56
+ pip install -r requirements.txt
57
+ pip install -e .
58
+ ```
59
+
60
+ Both import styles are supported:
61
+
62
+ ```python
63
+ from CodonTranslator import CodonTranslator
64
+ ```
65
+
66
+ ```python
67
+ from codontranslator import CodonTranslator
68
+ ```
69
+
70
+ ### Train
71
+
72
+ ```bash
73
+ python train.py \
74
+ --train_data /path/to/train \
75
+ --val_data /path/to/val \
76
+ --embeddings_dir /path/to/embeddings_v2 \
77
+ --output_dir outputs \
78
+ --fsdp \
79
+ --bf16 \
80
+ --attn mha \
81
+ --hidden 750 \
82
+ --layers 20 \
83
+ --heads 15 \
84
+ --mlp_ratio 3.2 \
85
+ --batch_size 48 \
86
+ --grad_accum 4 \
87
+ --epochs 3 \
88
+ --lr 7e-5 \
89
+ --weight_decay 1e-4
90
+ ```
91
+
92
+ ### Sample
93
+
94
+ ```bash
95
+ python sampling.py \
96
+ --model_path final_model \
97
+ --embeddings_dir /path/to/embeddings_v2 \
98
+ --species "Panicum hallii" \
99
+ --protein_sequence "MSEQUENCE" \
100
+ --strict_species_lookup
101
+ ```
102
+
103
+ ## Notes
104
+
105
+ - Training uses precomputed `embeddings_v2` for species conditioning.
106
+ - The data split is built in protein space with MMseqs clustering and binomial-species test holdout.
107
+ - `checkpoint-71000` is included for training resumption; `final_model/` is the recommended inference entrypoint.
108
+ - For compatibility, released model directories contain both `trainer_config.json` and `config.json`.
109
+
110
+ ## Sampling arguments
111
+
112
+ - `enforce_mapping`: when `True`, each generated codon is constrained to encode the provided amino acid at that position.
113
+ - `temperature`: softmax temperature. Lower values are more deterministic; `0` selects argmax greedily.
114
+ - `top_k`: keep only the `k` highest-logit codon candidates before sampling.
115
+ - `top_p`: nucleus sampling threshold; keep the smallest probability mass whose cumulative sum reaches `p`.
__pycache__/precompute_embeddings.cpython-312.pyc ADDED
Binary file (24.1 kB). View file
 
__pycache__/resplit_data_v3.cpython-312.pyc ADDED
Binary file (57.6 kB). View file
 
__pycache__/sampling.cpython-312.pyc ADDED
Binary file (17.7 kB). View file
 
__pycache__/train.cpython-312.pyc ADDED
Binary file (21.8 kB). View file
 
batch_eval.py ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Run eval.py across all checkpoints and datasets in parallel (multi-GPU),
4
+ and collect results to ./eval.csv.
5
+
6
+ - Discovers checkpoints under outputs/checkpoint-*
7
+ - Evaluates on: data/test/*.parquet and data/val/*.parquet
8
+ - Uses up to N GPUs concurrently (default: 4) by setting CUDA_VISIBLE_DEVICES
9
+ - Parses the "Summary ..." line(s) from eval.py logs
10
+ - Appends rows to ./eval.csv
11
+
12
+ Example:
13
+ python batch_eval.py \
14
+ --outputs_dir outputs \
15
+ --embeddings_dir embeddings \
16
+ --datasets data/test/*.parquet data/val/*.parquet \
17
+ --splits test val \
18
+ --num_samples 12800 \
19
+ --batch_size 4 \
20
+ --gpus 0 1 2 3 \
21
+ --eval_script eval.py \
22
+ --device cuda
23
+
24
+ Notes:
25
+ - This script *does not* modify your eval.py. It just orchestrates/launches it.
26
+ - Requires Python 3.8+ (standard library only).
27
+ """
28
+
29
+ import argparse
30
+ import csv
31
+ import os
32
+ import re
33
+ import sys
34
+ import time
35
+ import glob
36
+ import queue
37
+ import threading
38
+ import subprocess
39
+ from pathlib import Path
40
+ from concurrent.futures import ThreadPoolExecutor, as_completed
41
+
42
+
43
# Parses the teacher-forced summary line emitted by eval.py in random-subset
# mode; captures: sample count, CE, codon accuracy, AA accuracy.
TF_SUMMARY_RE = re.compile(
    r"Summary over\s+(\d+)\s+samples\s+→.*?CE=([-\d\.eE]+).*?CODON-acc=([-\d\.eE]+).*?AA-acc=([-\d\.eE]+)"
)
# Parses the streaming full-dataset summary emitted by eval.py --eval_all;
# captures: token count, CE, codon accuracy, AA accuracy.
EVALALL_SUMMARY_RE = re.compile(
    r"Full-dataset summary.*?tokens=(\d+).*?CE=([-\d\.eE]+).*?CODON-acc=([-\d\.eE]+).*?AA-acc=([-\d\.eE]+)"
)

# Column order for rows appended to the results CSV.
CSV_FIELDS = [
    "timestamp_iso",
    "model_path",
    "checkpoint_step",
    "split",
    "data_path",
    "num_samples",
    "batch_size",
    "seed",
    "eval_all",
    "gpu_id",
    "runtime_sec",
    "tokens",
    "mean_ce",
    "mean_codon_acc",
    "mean_aa_acc",
    "status",
    "error",
    "command",
]
70
+
71
+
72
def parse_args():
    """Build and parse the CLI for the batch evaluator.

    Every option has a default, so the script is runnable with no
    arguments at all.
    """
    parser = argparse.ArgumentParser(description="Parallel evaluator for CodonGPT checkpoints.")
    # Checkpoint discovery and data selection.
    parser.add_argument("--outputs_dir", type=str, default="outputs/", help="Folder containing checkpoint-* subdirs.")
    parser.add_argument("--embeddings_dir", type=str, default="embeddings/", help="Embeddings dir to pass to eval.py")
    parser.add_argument("--datasets", nargs="+", default=["data/test/*.parquet", "data/val/*.parquet"],
                        help="One or more dataset globs.")
    parser.add_argument("--splits", nargs="+", default=["test", "val"],
                        help="Split names aligned with --datasets (same length).")
    # Per-run eval.py parameters.
    parser.add_argument("--num_samples", type=int, default=12800, help="num_samples for eval.py (random subset mode)")
    parser.add_argument("--batch_size", type=int, default=4, help="batch_size for eval.py")
    parser.add_argument("--seed", type=int, default=42, help="seed for eval.py")
    parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"], help="Device flag for eval.py")
    parser.add_argument("--gpus", nargs="+", default=["0", "1", "2", "3"], help="GPU IDs to use (as CUDA_VISIBLE_DEVICES)")
    parser.add_argument("--eval_script", type=str, default="eval.py", help="Path to eval.py")
    parser.add_argument("--csv_path", type=str, default="eval.csv", help="Output CSV file")
    parser.add_argument("--eval_all", action="store_true",
                        help="Use eval.py --eval_all (streaming, no num_samples). If set, ignores --num_samples.")
    parser.add_argument("--workers", type=int, default=4,
                        help="--workers passed to eval.py when --eval_all is set.")
    parser.add_argument("--dry_run", action="store_true", help="List planned runs but do not execute.")
    # Filtering / resume options.
    parser.add_argument("--start_after_step", type=int, default=-1,
                        help="Only evaluate checkpoints with step > this value (e.g., 73700)")
    parser.add_argument("--end_step", type=int, default=-1,
                        help="If >0, only evaluate checkpoints with step <= this value")
    parser.add_argument("--skip_existing", dest="skip_existing", action="store_true", default=True,
                        help="Skip tasks already recorded as OK in csv_path")
    parser.add_argument("--no-skip-existing", dest="skip_existing", action="store_false",
                        help="Do not skip existing OK rows; re-run everything in range")
    return parser.parse_args()
102
+
103
+
104
def natural_step(dirpath: Path) -> int:
    """Return the integer step encoded in a checkpoint directory name.

    'checkpoint-21000' -> 21000; names without a 'checkpoint-<digits>'
    component yield -1.
    """
    match = re.search(r"checkpoint-(\d+)", dirpath.name)
    if match is None:
        return -1
    return int(match.group(1))
111
+
112
+
113
def discover_checkpoints(outputs_dir: str) -> list[Path]:
    """Find checkpoint directories under *outputs_dir*, ordered by step number.

    Only directories that contain both a config file and a weights file are
    returned, so half-written or pruned checkpoints are skipped.
    """
    candidates = [
        Path(p)
        for p in glob.glob(os.path.join(outputs_dir, "checkpoint-*"))
        if os.path.isdir(p)
    ]
    candidates.sort(key=natural_step)

    def _looks_complete(ckpt: Path) -> bool:
        # A usable checkpoint needs some config plus some weights file.
        config_ok = (ckpt / "config.json").exists() or (ckpt / "trainer_config.json").exists()
        weights_ok = (ckpt / "model.safetensors").exists() or (ckpt / "pytorch_model.bin").exists()
        return config_ok and weights_ok

    return [c for c in candidates if _looks_complete(c)]
126
+
127
+
128
def build_cmd(py_exec: str,
              eval_script: str,
              model_path: str,
              data_path: str,
              embeddings_dir: str,
              device: str,
              num_samples: int,
              batch_size: int,
              seed: int,
              eval_all: bool,
              workers: int) -> list[str]:
    """Assemble the eval.py command line for one checkpoint/dataset pair.

    In --eval_all (streaming) mode the sample count is irrelevant and the
    DataLoader worker count is forwarded instead.
    """
    argv = [
        py_exec, eval_script,
        "--model_path", model_path,
        "--data_path", data_path,
        "--embeddings_dir", embeddings_dir,
        "--batch_size", str(batch_size),
        "--device", device,
        "--seed", str(seed),
    ]
    if eval_all:
        argv.extend(["--eval_all", "--workers", str(workers)])
    else:
        argv.extend(["--num_samples", str(num_samples)])
    return argv
151
+
152
+
153
def parse_metrics(stdout: str, stderr: str) -> dict:
    """Extract the summary metrics printed by eval.py.

    Scans the combined stdout+stderr text for either the --eval_all summary
    or the teacher-forced (random subset) summary, and returns a dict with
    keys: tokens, mean_ce, mean_codon_acc, mean_aa_acc (all strings; tokens
    is empty for the teacher-forced format).

    Raises:
        ValueError: when neither summary format is present.
    """
    combined = stdout + "\n" + stderr

    # Streaming (--eval_all) format takes precedence.
    hit = EVALALL_SUMMARY_RE.search(combined)
    if hit is not None:
        tokens, ce, codon, aa = hit.groups()
        return {"tokens": tokens, "mean_ce": ce, "mean_codon_acc": codon, "mean_aa_acc": aa}

    # Fall back to the random-subset teacher-forced summary.
    hit = TF_SUMMARY_RE.search(combined)
    if hit is not None:
        _samples, ce, codon, aa = hit.groups()
        return {"tokens": "", "mean_ce": ce, "mean_codon_acc": codon, "mean_aa_acc": aa}

    raise ValueError("Could not find summary line in eval.py output.")
174
+
175
+
176
def run_one(task: dict, gpu_queue: "queue.Queue[str]", csv_lock: threading.Lock) -> dict:
    """
    Execute one eval.py call using a GPU from the queue. Returns a row dict for CSV.

    Never raises: any unexpected failure (e.g. an OSError while spawning the
    subprocess) is converted into a FAIL row so the surrounding
    ThreadPoolExecutor loop keeps running, a record still lands in the CSV,
    and the GPU is always returned to the pool.
    """
    gpu_id = gpu_queue.get()  # blocks until a GPU id is available
    start = time.time()
    status = "OK"
    err_text = ""
    # Pre-initialize so the row can always be built, even if the subprocess
    # never ran (previously these names were unbound on an early exception).
    metrics = {"tokens": "", "mean_ce": "", "mean_codon_acc": "", "mean_aa_acc": ""}

    try:
        env = os.environ.copy()
        # Pin the subprocess to a single GPU
        env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
        env.setdefault("TOKENIZERS_PARALLELISM", "false")
        env.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")

        result = subprocess.run(
            task["cmd"],
            env=env,
            capture_output=True,
            text=True,
            check=False,
        )

        try:
            metrics = parse_metrics(result.stdout, result.stderr)
        except Exception as e:
            status = "FAIL"
            err_text = f"{e}\n--- STDOUT ---\n{result.stdout}\n--- STDERR ---\n{result.stderr}"
            metrics = {"tokens": "", "mean_ce": "", "mean_codon_acc": "", "mean_aa_acc": ""}

        if result.returncode != 0 and status == "OK":
            status = "FAIL"
            err_text = f"Non-zero exit code {result.returncode}\n--- STDOUT ---\n{result.stdout}\n--- STDERR ---\n{result.stderr}"

    except Exception as e:
        # Bug fix: previously there was no except clause here, so an
        # unexpected error propagated out of the worker thread, crashed the
        # caller's fut.result() loop, and left no record in the CSV.
        status = "FAIL"
        err_text = f"Unexpected error while running eval: {e!r}"
    finally:
        runtime = time.time() - start
        gpu_queue.put(gpu_id)  # release GPU

    row = {
        "timestamp_iso": time.strftime("%Y-%m-%dT%H:%M:%S"),
        "model_path": task["model_path"],
        "checkpoint_step": task["step"],
        "split": task["split"],
        "data_path": task["data_path"],
        "num_samples": task["num_samples"] if not task["eval_all"] else "",
        "batch_size": task["batch_size"],
        "seed": task["seed"],
        "eval_all": str(task["eval_all"]),
        "gpu_id": str(gpu_id),
        "runtime_sec": f"{runtime:.2f}",
        "tokens": metrics.get("tokens", ""),
        "mean_ce": metrics.get("mean_ce", ""),
        "mean_codon_acc": metrics.get("mean_codon_acc", ""),
        "mean_aa_acc": metrics.get("mean_aa_acc", ""),
        "status": status,
        "error": err_text.strip(),
        "command": " ".join(task["cmd"]),
    }
    return row
236
+
237
+
238
def ensure_csv(path: str):
    """Write the CSV header exactly once: only when the file is absent or empty."""
    exists_nonempty = os.path.exists(path) and os.path.getsize(path) > 0
    if exists_nonempty:
        return
    with open(path, "w", newline="") as f:
        csv.DictWriter(f, fieldnames=CSV_FIELDS).writeheader()
245
+
246
+
247
def read_completed_keys(path: str) -> set[tuple[int, str, str]]:
    """
    Read existing CSV and return a set of (step, split, data_path) for rows with status == 'OK'.
    If CSV does not exist, returns empty set.
    """
    done: set[tuple[int, str, str]] = set()
    if not os.path.exists(path) or os.path.getsize(path) == 0:
        return done
    try:
        with open(path, "r", newline="") as f:
            for record in csv.DictReader(f):
                if (record.get("status") or "").strip().upper() != "OK":
                    continue
                try:
                    step = int(record.get("checkpoint_step", "-1"))
                except ValueError:
                    # Unparseable step: skip the row rather than abort resume.
                    continue
                done.add((step, record.get("split", ""), record.get("data_path", "")))
    except Exception:
        # If CSV is malformed, resume logic is best-effort
        pass
    return done
271
+
272
+
273
def append_row(path: str, row: dict, lock: threading.Lock):
    """Append one result row to the CSV; *lock* serializes concurrent writers."""
    with lock, open(path, "a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=CSV_FIELDS)
        writer.writerow(row)
        f.flush()
279
+
280
+
281
def main():
    """Discover checkpoints, fan eval.py runs out across a GPU pool, and log results to CSV.

    Flow: validate args -> list checkpoints -> build one task per
    (checkpoint, dataset) pair (minus already-OK rows when resuming) ->
    optionally dry-run -> run tasks via a ThreadPoolExecutor where each
    worker borrows a GPU id from a shared queue.
    """
    args = parse_args()

    # --datasets and --splits are zipped pairwise below, so they must align.
    if len(args.datasets) != len(args.splits):
        print("ERROR: --datasets and --splits must have the same length.", file=sys.stderr)
        sys.exit(2)

    checkpoints = discover_checkpoints(args.outputs_dir)
    if not checkpoints:
        print(f"No checkpoints found in {args.outputs_dir}/checkpoint-*", file=sys.stderr)
        sys.exit(1)

    print(f"Discovered {len(checkpoints)} checkpoints.")
    ds_pairs = list(zip(args.splits, args.datasets))
    print(f"Datasets: {', '.join([f'{s}:{d}' for s, d in ds_pairs])}")
    print(f"GPUs: {', '.join(args.gpus)}")
    print(f"Writing results to: {args.csv_path}")
    if args.start_after_step >= 0:
        print(f"Filtering: step > {args.start_after_step}")
    if args.end_step > 0:
        print(f"Filtering: step <= {args.end_step}")
    print(f"Skip existing OK rows in CSV: {args.skip_existing}")

    # Build task list
    py_exec = sys.executable
    tasks = []
    # Resume support: rows already recorded as OK are skipped unless disabled.
    completed_keys = read_completed_keys(args.csv_path) if args.skip_existing else set()
    for ckpt in checkpoints:
        step = natural_step(ckpt)
        # Apply step filters
        if args.start_after_step >= 0 and step <= args.start_after_step:
            continue
        if args.end_step > 0 and step > args.end_step:
            continue
        for split, data_path in ds_pairs:
            # Skip if already evaluated with OK status
            if (step, split, data_path) in completed_keys:
                continue
            cmd = build_cmd(
                py_exec=py_exec,
                eval_script=args.eval_script,
                model_path=str(ckpt),
                data_path=data_path,
                embeddings_dir=args.embeddings_dir,
                device=args.device,
                num_samples=args.num_samples,
                batch_size=args.batch_size,
                seed=args.seed,
                eval_all=args.eval_all,
                workers=args.workers,
            )
            tasks.append({
                "model_path": str(ckpt),
                "step": step,
                "split": split,
                "data_path": data_path,
                "num_samples": args.num_samples,
                "batch_size": args.batch_size,
                "seed": args.seed,
                "eval_all": args.eval_all,
                "cmd": cmd,
            })

    # Dry run listing
    if args.dry_run:
        for t in tasks:
            print(f"[DRY RUN] GPU=? step={t['step']} split={t['split']} -> {' '.join(t['cmd'])}")
        print(f"Planned runs: {len(tasks)}")
        return

    # Prepare CSV
    ensure_csv(args.csv_path)
    csv_lock = threading.Lock()

    # GPU pool: each worker thread takes an id before launching a subprocess
    # and returns it afterwards, so at most one eval runs per GPU at a time.
    gpu_queue: "queue.Queue[str]" = queue.Queue()
    for gid in args.gpus:
        gpu_queue.put(str(gid))

    # Execute with up to len(gpus) concurrent workers
    max_workers = max(1, len(args.gpus))
    with ThreadPoolExecutor(max_workers=max_workers) as ex:
        futures = [ex.submit(run_one, t, gpu_queue, csv_lock) for t in tasks]
        completed = 0
        total = len(futures)
        # Rows are appended as they finish, so partial progress survives a crash.
        for fut in as_completed(futures):
            row = fut.result()
            append_row(args.csv_path, row, csv_lock)
            completed += 1
            if row["status"] == "OK":
                print(f"[{completed}/{total}] ✅ step={row['checkpoint_step']} split={row['split']} "
                      f"CE={row['mean_ce']} CODON={row['mean_codon_acc']} AA={row['mean_aa_acc']} "
                      f"gpu={row['gpu_id']} in {row['runtime_sec']}s")
            else:
                print(f"[{completed}/{total}] ❌ step={row['checkpoint_step']} split={row['split']} "
                      f"gpu={row['gpu_id']} See CSV 'error' column for details.")

    print(f"Done. Results appended to {args.csv_path}")
379
+
380
+
381
# Script entry point guard: only run the orchestration when executed directly.
if __name__ == "__main__":
    main()
codontranslator/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
# Thin compatibility shim: re-export the CodonTranslator class so the package
# is importable under the lowercase name `codontranslator` as well.
from CodonTranslator import CodonTranslator

__all__ = ["CodonTranslator"]
environment.yml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: codontranslator
2
+ channels:
3
+ - conda-forge
4
+ - pytorch
5
+ - nvidia
6
+ dependencies:
7
+ - python=3.12
8
+ - pip
9
+ - pytorch>=2.4
10
+ - pandas>=2.3
11
+ - pyarrow>=21.0
12
+ - duckdb>=1.5
13
+ - biopython>=1.85
14
+ - pip:
15
+ - transformers>=4.57.0
16
+ - esm>=3.2.3
17
+ - safetensors>=0.7.0
18
+ - huggingface-hub>=0.36.0
19
+ - accelerate>=1.9.0
20
+ - wandb>=0.21.0
eval.py ADDED
@@ -0,0 +1,1239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ """
3
+ Teacher-forced (and optional free-run) evaluation on a random subset of your
4
+ dataset to measure codon token cross-entropy and AA token accuracy, using the
5
+ same conditioning pathway as training.
6
+
7
+ Supports either a CSV file or Parquet input via a directory/glob (e.g.,
8
+ ./data/val/*.parquet).
9
+
10
+ Usage examples:
11
+ # CSV input
12
+ python eval.py \
13
+ --model_path outputs/checkpoint-21000 \
14
+ --data_path random_sample_1000.csv \
15
+ --embeddings_dir embeddings \
16
+ --num_samples 10 \
17
+ --batch_size 10 \
18
+ --device cuda
19
+
20
+ # Parquet glob input
21
+ python eval.py \
22
+ --model_path outputs/checkpoint-21000 \
23
+ --data_path "./data/val/*.parquet" \
24
+ --embeddings_dir embeddings \
25
+ --num_samples 64 \
26
+ --batch_size 32 \
27
+ --device cuda
28
+ """
29
+
30
+ import argparse
31
+ import json
32
+ import logging
33
+ import random
34
+ from pathlib import Path
35
+ from typing import List, Optional, Tuple
36
+ import glob
37
+
38
+ import torch
39
+ import torch.nn.functional as F
40
+ import pandas as pd
41
+
42
+ from src.sampler import CodonSampler
43
+ from src.dataset import SpeciesEmbeddingStore, StreamSeqDataset, stage_collate_fn
44
+ from torch.utils.data import DataLoader
45
+
46
+
47
# Module-wide logging configuration; all eval helpers log through `eval_tf`.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger("eval_tf")
53
+
54
+
55
def parse_args() -> argparse.Namespace:
    """Parse command-line options for teacher-forced / free-run / streaming eval.

    Exactly one of the mutually informal modes is typically active:
    the default random-subset teacher-forced eval, --free_run sampling,
    --eval_all streaming, or --export_per_sequence CSV export.
    """
    p = argparse.ArgumentParser("Teacher-forced evaluation of CodonGPT")
    p.add_argument("--model_path", required=True, type=str,
                   help="Path to checkpoint dir (with config.json / model.safetensors)")
    # Input data: CSV file or Parquet glob/dir
    p.add_argument("--data_path", required=False, type=str, default=None,
                   help="CSV file or Parquet glob/dir (e.g., ./data/val/*.parquet)")
    # Back-compat: --csv_path still accepted (deprecated)
    p.add_argument("--csv_path", required=False, type=str, default=None,
                   help="[Deprecated] CSV with columns: Taxon, protein_seq, cds_DNA")
    p.add_argument("--embeddings_dir", type=str, default=None,
                   help="Species embeddings directory (recommended for parity)")
    p.add_argument("--num_samples", type=int, default=10)
    p.add_argument("--batch_size", type=int, default=10)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--device", type=str, default="cuda")
    p.add_argument("--workers", type=int, default=0,
                   help="DataLoader workers for --eval_all streaming mode")
    # Free-run (sampling) evaluation options
    p.add_argument("--free_run", action="store_true",
                   help="If set, perform real sampling instead of teacher forcing and compare to ground-truth codon sequences")
    p.add_argument("--temperature", type=float, default=0.8)
    p.add_argument("--top_k", type=int, default=50)
    p.add_argument("--top_p", type=float, default=0.9)
    p.add_argument("--control_mode", type=str, choices=["fixed","variable"], default="fixed")
    p.add_argument("--enforce_translation", action="store_true",
                   help="Hard-mask decoding to codons matching target amino acid at each position during free-run evaluation")
    # Full-dataset streaming eval (no sampling)
    p.add_argument("--eval_all", action="store_true",
                   help="Stream over all rows from --data_path and compute aggregated metrics (memory-safe)")
    p.add_argument("--max_records", type=int, default=0,
                   help="When --eval_all is set: limit to first N samples (0 = all)")
    p.add_argument("--debug_aa_check", action="store_true",
                   help="Print per-sample agreement between CDS→AA (standard code) and provided protein")
    # Per-sequence export over standard splits ./data/val and ./data/test
    p.add_argument("--export_per_sequence", action="store_true",
                   help="Process ./data/val and ./data/test parquets in batches and export a per-sequence CSV")
    p.add_argument("--splits_root", type=str, default="./data",
                   help="Root directory that contains val/ and test/ subfolders with parquet files")
    p.add_argument("--out_csv", type=str, default="outputs/eval_per_sequence.csv",
                   help="Output CSV path for per-sequence export")
    p.add_argument("--export_splits", nargs="+", default=["val", "test"],
                   help="Subdirectories under --splits_root to process (default: val test)")
    p.add_argument("--max_rows_per_split", type=int, default=0,
                   help="When --export_per_sequence is set: limit number of rows per split (0 = all)")
    p.add_argument("--progress", action="store_true",
                   help="Show progress bars during per-sequence export")
    # Capacity and evaluation controls
    p.add_argument("--no_truncation", action="store_true",
                   help="Fit prefix caps so generated codon length equals protein length (avoids capacity truncation)")
    p.add_argument("--species_prefix_cap", type=int, default=0,
                   help="When >0 and --no_truncation is set, cap species token prefix to this many tokens; 0 = no species cap")
    return p.parse_args()
108
+
109
+
110
+ def _is_parquet_path(p: str) -> bool:
111
+ lower = p.lower()
112
+ return lower.endswith(".parquet") or lower.endswith(".parq")
113
+
114
+
115
+ def _expand_paths(maybe_path_or_glob: str) -> List[str]:
116
+ """Expand a path/glob or directory into a sorted list of files.
117
+ Prioritize Parquet when scanning a directory.
118
+ """
119
+ paths: List[str] = []
120
+ P = Path(maybe_path_or_glob)
121
+ if P.is_dir():
122
+ paths.extend(sorted(str(x) for x in P.rglob("*.parquet")))
123
+ paths.extend(sorted(str(x) for x in P.rglob("*.parq")))
124
+ paths.extend(sorted(str(x) for x in P.rglob("*.csv")))
125
+ paths.extend(sorted(str(x) for x in P.rglob("*.tsv")))
126
+ paths.extend(sorted(str(x) for x in P.rglob("*.csv.gz")))
127
+ paths.extend(sorted(str(x) for x in P.rglob("*.tsv.gz")))
128
+ else:
129
+ paths = sorted(glob.glob(str(P)))
130
+ # Dedup while preserving order
131
+ out: List[str] = []
132
+ seen = set()
133
+ for x in paths:
134
+ if x not in seen:
135
+ out.append(x)
136
+ seen.add(x)
137
+ return out
138
+
139
+
140
def _load_random_samples_from_parquet(files: List[str], num_samples: int, seed: int) -> pd.DataFrame:
    """Collect up to num_samples rows from a list of Parquet files, reading by row group.
    Reads only the required columns and shuffles files/row-groups for decent coverage.

    Raises ImportError when pyarrow is unavailable, FileNotFoundError when no
    Parquet files are in *files*, and ValueError when a file lacks one of the
    required columns (Taxon, protein_seq, cds_DNA).
    """
    try:
        import pyarrow.parquet as pq  # type: ignore
    except Exception as e:  # pragma: no cover
        raise ImportError("pyarrow is required to read parquet files") from e

    # Dedicated RNG seeded once so file/row-group order is reproducible.
    rng = random.Random(seed)
    req = ["Taxon", "protein_seq", "cds_DNA"]
    files = [f for f in files if _is_parquet_path(f)]
    if not files:
        raise FileNotFoundError("No Parquet files found to read")
    files = files.copy()
    rng.shuffle(files)

    collected: List[pd.DataFrame] = []
    remaining = int(max(0, num_samples))
    for fp in files:
        if remaining <= 0:
            break
        pf = pq.ParquetFile(fp)
        nrg = int(pf.num_row_groups or 0)
        if nrg <= 0:
            # NOTE(review): presumably a defensive fallback for files that
            # report zero row groups — confirm pyarrow can return 0 here.
            rgs = [0]
        else:
            rgs = list(range(nrg))
        rng.shuffle(rgs)
        # Only keep columns that exist in this file
        cols = [c for c in req if c in pf.schema.names]
        if len(cols) < len(req):
            missing = sorted(set(req) - set(cols))
            raise ValueError(f"Parquet missing required columns {missing} in {fp}")
        for rg in rgs:
            if remaining <= 0:
                break
            table = pf.read_row_group(rg, columns=cols)
            df = table.to_pandas(types_mapper=None)
            if df.empty:
                continue
            # Down-sample the row group if it alone would exceed the quota.
            if len(df) > remaining:
                df = df.sample(n=remaining, random_state=rng.randint(0, 2**31 - 1))
            collected.append(df)
            remaining -= len(df)
    if not collected:
        return pd.DataFrame(columns=req)
    out = pd.concat(collected, ignore_index=True)
    # Final shuffle for randomness
    out = out.sample(frac=1.0, random_state=seed).reset_index(drop=True)
    # If we somehow overshot, trim
    if len(out) > num_samples:
        out = out.iloc[:num_samples].reset_index(drop=True)
    return out
194
+
195
+
196
+ def _preferred_pooling(model_dir: Path) -> str:
197
+ """
198
+ Best-effort pooling detection:
199
+ - First try checkpoint configs for an explicit hint
200
+ - Fallback to 'last'
201
+ Note: we'll further override this using the embeddings_dir contents if provided.
202
+ """
203
+ for cfg_name in ("trainer_config.json", "config.json"):
204
+ fp = model_dir / cfg_name
205
+ if fp.exists():
206
+ try:
207
+ with open(fp) as f:
208
+ cfg = json.load(f)
209
+ return str(cfg.get("species_pooling", "last"))
210
+ except Exception:
211
+ continue
212
+ return "last"
213
+
214
+
215
+ def _detect_pooling_from_embeddings_dir(emb_dir: Path) -> Optional[str]:
216
+ """Detect actual available pooling format from embeddings_dir contents."""
217
+ fixed_files = [emb_dir / "species_embeddings.bin", emb_dir / "species_metadata.json", emb_dir / "species_vocab.json"]
218
+ seq_files = [emb_dir / "species_tok_emb.bin", emb_dir / "species_index.json", emb_dir / "species_vocab.json"]
219
+ if all(p.exists() for p in fixed_files):
220
+ return "last"
221
+ if all(p.exists() for p in seq_files):
222
+ return "sequence"
223
+ return None
224
+
225
+
226
@torch.no_grad()
def eval_batch(
    sampler: CodonSampler,
    species_store: Optional[SpeciesEmbeddingStore],
    species_names: List[str],
    protein_seqs: List[str],
    dna_cds_list: List[str],
) -> Tuple[List[float], List[float]]:
    """Evaluate a batch in teacher-forced mode.

    Returns per-sample (avg_ce_loss, aa_token_acc).

    The three list arguments are parallel (one entry per sample). Species
    conditioning comes from *species_store* when available, otherwise from
    on-the-fly Qwen embeddings of the species names.
    """
    tok = sampler.tokenizer
    pad_id = tok.pad_token_id
    eos_id = tok.eos_token_id

    # Encode DNA to codon ids and align lengths (trim to min protein length)
    codon_ids = []
    seq_lens = []
    for dna, prot in zip(dna_cds_list, protein_seqs):
        # Trim to min length between DNA codons and protein AA
        C_dna = len(dna) // 3
        C_prot = len(prot)
        # Clamp to at least one codon so empty inputs don't produce empty rows.
        C = max(min(C_dna, C_prot), 1)
        dna_trim = dna[: 3 * C]
        ids = tok.encode_codon_seq(dna_trim, validate=False)
        ids.append(eos_id)
        codon_ids.append(ids)
        seq_lens.append(len(ids))

    # Right-pad all sequences to the batch maximum with PAD tokens.
    B = len(codon_ids)
    T = max(seq_lens)
    codons = torch.full((B, T), pad_id, dtype=torch.long)
    mask = torch.zeros((B, T), dtype=torch.bool)
    for i, ids in enumerate(codon_ids):
        L = len(ids)
        codons[i, :L] = torch.tensor(ids, dtype=torch.long)
        mask[i, :L] = True

    # inputs/labels aligned to training convention:
    # model predicts next codon after a learned start token; labels are the
    # same positions as inputs (not shifted by 1), with PAD/EOS masked out.
    input_ids = codons[:, :-1]
    labels_base = codons[:, :-1].clone()
    # Mask out PAD and EOS like trainer.evaluate()
    labels_base[labels_base == pad_id] = -100
    labels_base[labels_base == eos_id] = -100

    # Build conditioning dict similar to training and sampler
    cond = {"control_mode": "fixed"}

    if species_store is not None and species_names:
        # Unknown species map to id -1; the store presumably returns zero
        # embeddings for those — TODO confirm against SpeciesEmbeddingStore.
        sid_list = [species_store.vocab.get(s, -1) for s in species_names]
        num_unknown = sum(1 for x in sid_list if x < 0)
        if num_unknown > 0:
            logger.warning(f"{num_unknown}/{len(sid_list)} species not found in embeddings vocab; using zero embeddings")
        result = species_store.batch_get(sid_list)
        # A tuple result indicates per-token ("sequence") embeddings; a bare
        # tensor indicates fixed per-species vectors.
        if isinstance(result, tuple):
            sp_tok, _ = result  # [B, Ls, Ds]
            cond["species_tok_emb_src"] = sp_tok.to(sampler.device)
            cond["species_tok_emb_tgt"] = sp_tok.to(sampler.device)
        else:
            sp = result  # [B, Ds]
            cond["species_emb_src"] = sp.to(sampler.device)
            cond["species_emb_tgt"] = sp.to(sampler.device)
    elif species_names:
        # On-the-fly species embeddings using Qwen (sequence pooling for training parity)
        seq_emb, _lens = sampler._qwen_embed_names(species_names, pooling="sequence")
        seq_emb = seq_emb.to(sampler.device)
        cond["species_tok_emb_src"] = seq_emb
        cond["species_tok_emb_tgt"] = seq_emb

    # Match training: pass raw protein sequences; the model tokenizes internally
    cond["protein_seqs"] = protein_seqs

    # Move tensors to device
    device = sampler.device
    input_ids = input_ids.to(device)
    labels_base = labels_base.to(device)

    sampler.model.eval()
    outputs = sampler.model(codon_ids=input_ids, cond=cond, labels=labels_base, return_dict=True)
    logits = outputs["logits"]  # [B, Lmax, V] aligned to per-sample capacity after prefix
    # Debug-only: log the (max) conditioning prefix length; failures here are
    # non-fatal by design.
    try:
        prefix_len = outputs.get("prefix_len", 0)
        if isinstance(prefix_len, torch.Tensor):
            prefix_len_dbg = int(prefix_len.max().item()) if prefix_len.numel() > 0 else 0
        else:
            prefix_len_dbg = int(prefix_len)
        logger.debug(f"Prefix length(max)={prefix_len_dbg}, input_len={input_ids.size(1)}")
    except Exception:
        pass

    # Align labels/masks to logits length and per-sample caps
    Bsz, Lmax, V = logits.size(0), logits.size(1), logits.size(2)
    labels_aligned = torch.full((Bsz, Lmax), -100, dtype=labels_base.dtype, device=logits.device)
    common_cols = min(labels_base.size(1), Lmax)
    if common_cols > 0:
        labels_aligned[:, :common_cols] = labels_base[:, :common_cols]
    per_cap = outputs.get("per_cap", None)
    if isinstance(per_cap, torch.Tensor) and per_cap.numel() == Bsz:
        ar = torch.arange(Lmax, device=logits.device).unsqueeze(0)
        cap_mask = ar < per_cap.to(device=logits.device).unsqueeze(1)  # [B,Lmax]
    else:
        # No per-sample capacity info: treat every position as in-capacity.
        cap_mask = torch.ones_like(labels_aligned, dtype=torch.bool, device=logits.device)

    # Mask labels beyond per-cap to -100 so CE ignores them
    labels_masked = labels_aligned.clone().to(device=logits.device)
    labels_masked[~cap_mask] = -100

    # Cross-entropy per sample (include EOS target; ignore PAD)
    loss_flat = F.cross_entropy(
        logits.reshape(-1, V),
        labels_masked.reshape(-1),
        ignore_index=-100,
        reduction="none",
    ).view(Bsz, Lmax)

    # Accuracy per sample
    preds = logits.argmax(dim=-1)
    # Token ids below num_special_tokens are specials; exclude them from
    # supervision so accuracy covers real codons only.
    num_special = int(getattr(tok, "num_special_tokens", 0) or 0)
    supervised = (labels_masked != -100) & cap_mask
    if num_special > 0:
        supervised = supervised & (labels_aligned >= num_special)
    correct = (preds == labels_aligned) & supervised

    per_sample_ce: List[float] = []
    # NOTE(review): per_sample_acc is populated nowhere and never returned —
    # the codon-level `acc` computed below is dead code; consider removing.
    per_sample_acc: List[float] = []
    per_sample_aa_acc: List[float] = []
    codon2aa = tok.codon2aa_char_map() if hasattr(tok, "codon2aa_char_map") else {}
    per_cap = outputs.get("per_cap", None)
    per_cap_int = None
    if isinstance(per_cap, torch.Tensor) and per_cap.numel() == Bsz:
        per_cap_int = torch.clamp(per_cap.to(dtype=torch.long, device=logits.device), min=0, max=Lmax)

    for i in range(B):
        # Average CE over valid positions
        valid = (labels_masked[i] != -100) & cap_mask[i]
        if num_special > 0:
            valid = valid & (labels_aligned[i] >= num_special)
        ce = (loss_flat[i][valid].mean().item() if valid.any() else 0.0)
        per_sample_ce.append(ce)

        # Codon-level accuracy over supervised positions
        denom = supervised[i].sum().item()
        acc = (correct[i].sum().item() / denom) if denom > 0 else 0.0
        # AA-level accuracy per sample (match trainer)
        aa_acc = 0.0
        if per_cap_int is not None and codon2aa and i < len(protein_seqs):
            cap = int(per_cap_int[i].item())
            if cap > 0:
                mask_row = supervised[i, :cap]
                if mask_row.any():
                    # Translate predicted codons to AA chars and compare
                    # against the ground-truth protein prefix.
                    preds_row = preds[i, :cap][mask_row]
                    prot = protein_seqs[i]
                    seq_len = min(len(prot), preds_row.size(0))
                    if seq_len > 0:
                        pred_aa = ''.join(codon2aa.get(int(t.item()), 'X') for t in preds_row[:seq_len])
                        truth_aa = prot[:seq_len]
                        aa_matches = sum(1 for j in range(seq_len) if pred_aa[j] == truth_aa[j])
                        aa_acc = aa_matches / seq_len
        per_sample_aa_acc.append(aa_acc)

    return per_sample_ce, per_sample_aa_acc
390
+
391
+
392
+ def _dna_to_codons(dna: str) -> List[str]:
393
+ dna = dna.strip().upper()
394
+ return [dna[i:i+3] for i in range(0, len(dna) - (len(dna) % 3), 3)]
395
+
396
+
397
+ def _aa_from_dna_standard(dna: str, tok) -> str:
398
+ dna = dna.strip().upper()
399
+ gc = getattr(tok, "_genetic_code", {})
400
+ aa = []
401
+ for j in range(0, len(dna) - (len(dna) % 3), 3):
402
+ aa.append(gc.get(dna[j:j+3], 'X'))
403
+ return ''.join(aa)
404
+
405
+
406
def _aa_agreement(dna: str, protein: str, tok) -> Tuple[float, int, int]:
    """Return (match_ratio, compared_len, first_mismatch_idx or -1) under standard code."""
    seq = dna.strip().upper()
    prot = protein.strip().upper()
    compared = min(len(seq) // 3, len(prot))
    if compared <= 0:
        return 0.0, 0, -1
    predicted = _aa_from_dna_standard(seq[: 3 * compared], tok)
    expected = prot[:compared]
    first_bad = -1
    hits = 0
    for idx, (a, b) in enumerate(zip(predicted, expected)):
        if a == b:
            hits += 1
        elif first_bad < 0:
            first_bad = idx
    return hits / compared, compared, first_bad
423
+
424
+
425
+ @torch.no_grad()
426
+ def eval_streaming_all(
427
+ sampler: CodonSampler,
428
+ species_store: SpeciesEmbeddingStore,
429
+ data_path: str,
430
+ batch_size: int,
431
+ num_workers: int,
432
+ max_records: int = 0,
433
+ ):
434
+ """Stream over all rows from CSV/Parquet inputs and compute dataset-level metrics.
435
+
436
+ Mirrors trainer.evaluate() for parity.
437
+ """
438
+ device = sampler.device
439
+ tok = sampler.tokenizer
440
+ pad_id = int(tok.pad_token_id)
441
+ eos_id = int(tok.eos_token_id)
442
+ num_special = int(tok.num_special_tokens)
443
+ codon2aa = tok.codon2aa_char_map()
444
+
445
+ # Build streaming dataset and loader
446
+ from pathlib import Path as _Path
447
+ import glob as _glob
448
+ def _expand(pat: str) -> List[str]:
449
+ P = _Path(pat)
450
+ if P.is_dir():
451
+ paths: List[str] = []
452
+ paths.extend(sorted(str(x) for x in P.rglob("*.parquet")))
453
+ paths.extend(sorted(str(x) for x in P.rglob("*.parq")))
454
+ paths.extend(sorted(str(x) for x in P.rglob("*.csv")))
455
+ paths.extend(sorted(str(x) for x in P.rglob("*.tsv")))
456
+ paths.extend(sorted(str(x) for x in P.rglob("*.csv.gz")))
457
+ paths.extend(sorted(str(x) for x in P.rglob("*.tsv.gz")))
458
+ else:
459
+ paths = sorted(_glob.glob(str(P)))
460
+ # de-dup
461
+ seen = set(); out = []
462
+ for x in paths:
463
+ if x not in seen:
464
+ out.append(x); seen.add(x)
465
+ return out
466
+
467
+ paths = _expand(data_path)
468
+ if not paths:
469
+ raise FileNotFoundError(f"No input files matched: {data_path}")
470
+
471
+ species_vocab_path = str((Path(species_store.embeddings_dir) / "species_vocab.json").resolve())
472
+ ds = StreamSeqDataset(
473
+ files=paths,
474
+ tokenizer=tok,
475
+ species_vocab_path=species_vocab_path,
476
+ unknown_species_id=0,
477
+ csv_chunksize=200_000,
478
+ shuffle_buffer=0,
479
+ shard_across_ranks=False,
480
+ )
481
+ _dl_kwargs = dict(
482
+ batch_size=int(batch_size),
483
+ shuffle=False,
484
+ drop_last=False,
485
+ num_workers=int(max(0, num_workers)),
486
+ collate_fn=stage_collate_fn,
487
+ pin_memory=True,
488
+ persistent_workers=(int(num_workers) > 0),
489
+ )
490
+ if int(num_workers) > 0:
491
+ _dl_kwargs["prefetch_factor"] = 4
492
+ loader = DataLoader(ds, **_dl_kwargs)
493
+
494
+ loss_sum = 0.0
495
+ loss_tokens = 0
496
+ codon_correct = 0
497
+ codon_total = 0
498
+ aa_correct = 0
499
+ aa_total = 0
500
+
501
+ seen = 0
502
+ for batch in loader:
503
+ if not batch:
504
+ continue
505
+ if int(max_records) > 0 and seen >= int(max_records):
506
+ break
507
+ codon_ids = batch["codon_ids"].to(device)
508
+ input_ids = codon_ids[:, :-1]
509
+ labels = codon_ids[:, :-1].clone()
510
+ labels[labels == pad_id] = -100
511
+ labels[labels == eos_id] = -100
512
+
513
+ # Build cond using species_store and protein_seqs
514
+ cond = {"control_mode": "fixed", "protein_seqs": batch.get("protein_seqs", [])}
515
+ sids = batch.get("species_ids")
516
+ if torch.is_tensor(sids):
517
+ sids_list = sids.detach().cpu().tolist()
518
+ else:
519
+ sids_list = [int(x) for x in sids]
520
+ res = species_store.batch_get(sids_list)
521
+ if isinstance(res, tuple):
522
+ sp_tok, _ = res
523
+ cond["species_tok_emb_src"] = sp_tok.to(device)
524
+ cond["species_tok_emb_tgt"] = sp_tok.to(device)
525
+ else:
526
+ cond["species_emb_src"] = res.to(device)
527
+ cond["species_emb_tgt"] = res.to(device)
528
+
529
+ out = sampler.model(codon_ids=input_ids, cond=cond, labels=labels, return_dict=True)
530
+ loss = out.get("loss")
531
+ per_cap = out.get("per_cap")
532
+ logits = out.get("logits")
533
+
534
+ tokens_in_batch = 0
535
+ if per_cap is not None:
536
+ tokens_in_batch = int(torch.clamp(per_cap.detach(), min=0).sum().item())
537
+ loss_tokens += tokens_in_batch
538
+ if loss is not None and tokens_in_batch > 0:
539
+ loss_sum += float(loss.detach().item()) * tokens_in_batch
540
+
541
+ if logits is None or logits.size(1) == 0 or per_cap is None:
542
+ seen += input_ids.size(0)
543
+ continue
544
+ max_cap = logits.size(1)
545
+ batch_size = logits.size(0)
546
+ labels_aligned = torch.full((batch_size, max_cap), -100, dtype=labels.dtype, device=labels.device)
547
+ common = min(labels.size(1), max_cap)
548
+ if common > 0:
549
+ labels_aligned[:, :common] = labels[:, :common]
550
+ per_cap_int = torch.clamp(per_cap.to(dtype=torch.long), min=0, max=max_cap)
551
+ for row in range(batch_size):
552
+ cap = int(per_cap_int[row].item())
553
+ if cap < max_cap:
554
+ labels_aligned[row, cap:] = -100
555
+ supervised = labels_aligned != -100
556
+ if num_special > 0:
557
+ supervised = supervised & (labels_aligned >= num_special)
558
+ if not supervised.any():
559
+ seen += batch_size
560
+ continue
561
+ preds = logits.argmax(dim=-1)
562
+ codon_correct += int((preds[supervised] == labels_aligned[supervised]).sum().item())
563
+ codon_total += int(supervised.sum().item())
564
+
565
+ # protein list
566
+ prot_list = cond.get("protein_seqs", [])
567
+ for row in range(batch_size):
568
+ cap = int(per_cap_int[row].item())
569
+ if cap <= 0:
570
+ continue
571
+ mask_row = supervised[row, :cap]
572
+ if not mask_row.any():
573
+ continue
574
+ preds_row = preds[row, :cap][mask_row]
575
+ prot = prot_list[row] if (isinstance(prot_list, list) and row < len(prot_list)) else ""
576
+ if not prot:
577
+ continue
578
+ seq_len = min(len(prot), preds_row.size(0))
579
+ if seq_len <= 0:
580
+ continue
581
+ pred_aa = ''.join(codon2aa.get(int(t.item()), 'X') for t in preds_row[:seq_len])
582
+ truth_aa = prot[:seq_len]
583
+ aa_correct += sum(1 for i in range(seq_len) if pred_aa[i] == truth_aa[i])
584
+ aa_total += seq_len
585
+ seen += batch_size
586
+
587
+ mean_ce = (loss_sum / loss_tokens) if loss_tokens > 0 else 0.0
588
+ codon_acc = (float(codon_correct) / codon_total) if codon_total > 0 else 0.0
589
+ aa_acc = (float(aa_correct) / aa_total) if aa_total > 0 else 0.0
590
+ logger.info(
591
+ f"Full-dataset summary → tokens={loss_tokens} CE={mean_ce:.4f} CODON-acc={codon_acc:.4f} AA-acc={aa_acc:.4f}"
592
+ )
593
+ return mean_ce, codon_acc, aa_acc
594
+
595
+
596
@torch.no_grad()
def sample_and_score_batched(
    sampler: CodonSampler,
    species_names: List[str],
    protein_seqs: List[str],
    target_dnas: List[str],
    temperature: float,
    top_k: int,
    top_p: float,
    control_mode: str,
    batch_size: int,
    enforce_translation: bool,
    no_truncation: bool = False,
    species_prefix_cap: int = 64,
) -> Tuple[List[float], List[float]]:
    """Free-run sampling in batches; returns per-sample (codon_acc, aa_acc).

    For each (species, protein, target DNA) triple, samples a codon sequence
    of the target length and scores it against the ground truth:

    * codon_acc — fraction of positions whose sampled codon equals the target
      codon exactly.
    * aa_acc — fraction of positions whose translated amino acid matches the
      target protein; a non-canonical residue in the target counts as a match
      for any prediction.

    Samples are bucketed by target length so each ``sampler.sample`` call uses
    a single ``sequence_length``. When ``no_truncation`` is set, the model's
    protein-prefix cap is temporarily tightened (best effort) so that
    prefix + start token + length fits the positional capacity; the original
    caps are restored in a ``finally`` block, so they survive even if sampling
    raises (previously restoration only happened on the success path).
    """
    N = len(species_names)
    # Target length per sample: min(#codons in DNA, #AAs in protein), clamped to ≥ 1.
    tgt_lengths = []
    tgt_codons_list = []
    for prot, dna in zip(protein_seqs, target_dnas):
        cods = _dna_to_codons(dna)
        L = min(len(cods), len(prot))
        if L <= 0:
            L = 1
            cods = ["ATG"]  # harmless default for degenerate/empty inputs
        tgt_lengths.append(L)
        tgt_codons_list.append(cods[:L])

    # Bucket indices by target length to maximize batching.
    buckets: dict[int, List[int]] = {}
    for i, L in enumerate(tgt_lengths):
        buckets.setdefault(L, []).append(i)

    codon_accs = [0.0] * N
    aa_accs = [0.0] * N

    # Canonical amino-acid alphabet, hoisted out of the scoring loop
    # (previously rebuilt once per sample).
    canonical = set("ACDEFGHIKLMNPQRSTVWY")

    # Helper AA translation using the tokenizer's genetic-code table.
    vocab = sampler.tokenizer._genetic_code

    def dna_to_aa(dna: str) -> str:
        dna = dna.strip().upper()
        aa = []
        for j in range(0, len(dna) - (len(dna) % 3), 3):
            aa.append(vocab.get(dna[j:j+3], 'X'))
        return ''.join(aa)

    for L, idxs in buckets.items():
        # Remember current prefix caps so they can be restored after this bucket.
        prev_sp = getattr(sampler.model, "max_species_prefix", 0)
        prev_pp = getattr(sampler.model, "max_protein_prefix", 0)
        try:
            if bool(no_truncation):
                # Optionally tighten protein prefix so prefix+start+L ≤ capacity
                # (species kept full unless capped). Probing is best-effort.
                try:
                    capacity = int(getattr(sampler.model, "max_position_embeddings", 1024))
                    # If requested, apply a species token cap; otherwise keep as-is.
                    store = getattr(sampler, "species_store", None)
                    if store is not None and getattr(store, "is_legacy", False) and int(species_prefix_cap) > 0:
                        setattr(sampler.model, "max_species_prefix", int(species_prefix_cap))
                    # Build a small representative cond for this bucket to measure
                    # the exact prefix length.
                    batch_idx_probe = idxs[: min(len(idxs), max(1, min(batch_size, 8)))]
                    sp_probe = [species_names[i] for i in batch_idx_probe]
                    pr_probe = [protein_seqs[i] for i in batch_idx_probe]
                    # Map species to ids via the store's vocab.
                    cond_probe = {"control_mode": "fixed", "protein_seqs": pr_probe}
                    if store is not None:
                        sid_list = [store.vocab.get(s, -1) for s in sp_probe]
                        res = store.batch_get(sid_list)
                        if isinstance(res, tuple):
                            # Tuple result ⇒ per-token species embeddings.
                            sp_tok, _ = res
                            cond_probe["species_tok_emb_src"] = sp_tok.to(sampler.device)
                            cond_probe["species_tok_emb_tgt"] = sp_tok.to(sampler.device)
                        else:
                            cond_probe["species_emb_src"] = res.to(sampler.device)
                            cond_probe["species_emb_tgt"] = res.to(sampler.device)
                    # Iteratively reduce the protein prefix cap until remaining ≥ L.
                    for _ in range(3):
                        out0 = sampler.model(
                            codon_ids=torch.zeros(len(batch_idx_probe), 0, dtype=torch.long, device=sampler.device),
                            cond=cond_probe,
                            return_dict=True,
                            use_cache=True,
                        )
                        pref = out0.get("prefix_len")
                        if isinstance(pref, torch.Tensor) and pref.numel() > 0:
                            pref_max = int(pref.max().item())
                        else:
                            pref_max = int(pref) if isinstance(pref, int) else 0
                        remaining = capacity - (pref_max + 1)
                        if remaining >= int(L):
                            break
                        need = int(L) - max(0, int(remaining))
                        cur_pp = int(getattr(sampler.model, "max_protein_prefix", 0) or 0)
                        new_pp = max(0, cur_pp - need) if cur_pp > 0 else max(0, pref_max - (capacity - 1 - int(L)))
                        setattr(sampler.model, "max_protein_prefix", int(new_pp))
                except Exception:
                    pass  # best-effort probing; fall back to current caps
            # Process this bucket in mini-batches.
            for k in range(0, len(idxs), batch_size):
                batch_idx = idxs[k:k+batch_size]
                sp_b = [species_names[i] for i in batch_idx]
                pr_b = [protein_seqs[i] for i in batch_idx]
                # Sample in one call.
                out = sampler.sample(
                    num_sequences=len(batch_idx),
                    sequence_length=L,
                    species=sp_b,
                    protein_sequences=pr_b,
                    control_mode=control_mode,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    return_intermediate=False,
                    progress_bar=False,
                    enforce_translation=enforce_translation,
                )
                gen_list: List[str] = out["sequences"]  # DNA strings
                # Score each sample in the mini-batch.
                for pos, idx in enumerate(batch_idx):
                    tgt_codons = tgt_codons_list[idx]
                    gen_codons = _dna_to_codons(gen_list[pos])[:L]
                    matches = sum(1 for a, b in zip(gen_codons, tgt_codons) if a == b)
                    codon_accs[idx] = (matches / L) if L > 0 else 0.0
                    gen_aa = dna_to_aa(''.join(gen_codons))
                    tgt_aa = protein_seqs[idx][:L]
                    # Treat non-canonical AA in the target as "match any".
                    aa_matches = sum(1 for a, b in zip(gen_aa, tgt_aa) if (b not in canonical) or (a == b))
                    aa_accs[idx] = (aa_matches / L) if L > 0 else 0.0
        finally:
            # Restore caps even if sampling raised, so the model is not left
            # with a tightened prefix for subsequent callers.
            if bool(no_truncation):
                try:
                    setattr(sampler.model, "max_species_prefix", prev_sp)
                    setattr(sampler.model, "max_protein_prefix", prev_pp)
                except Exception:
                    pass

    return codon_accs, aa_accs
732
+
733
+
734
@torch.no_grad()
def generate_and_score_batched(
    sampler: CodonSampler,
    species_names: List[str],
    protein_seqs: List[str],
    target_dnas: List[str],
    temperature: float,
    top_k: int,
    top_p: float,
    control_mode: str,
    batch_size: int,
    enforce_translation: bool,
    no_truncation: bool = False,
    species_prefix_cap: int = 64,
) -> Tuple[List[str], List[float], List[float]]:
    """Like sample_and_score_batched but also returns generated DNA sequences per sample.

    Returns ``(gen_dna, codon_accs, aa_accs)`` where ``gen_dna[i]`` is the
    sampled DNA string for sample ``i`` and the accuracy lists follow the same
    definition as in ``sample_and_score_batched``. The temporary prefix-cap
    overrides used for ``no_truncation`` are restored in a ``finally`` block,
    so they survive even if sampling raises (previously restoration only
    happened on the success path).
    """
    N = len(species_names)
    # Target length per sample: min(#codons in DNA, #AAs in protein), clamped to ≥ 1.
    tgt_lengths = []
    tgt_codons_list = []
    for prot, dna in zip(protein_seqs, target_dnas):
        cods = _dna_to_codons(dna)
        L = min(len(cods), len(prot))
        if L <= 0:
            L = 1
            cods = ["ATG"]  # harmless default for degenerate/empty inputs
        tgt_lengths.append(L)
        tgt_codons_list.append(cods[:L])

    # Bucket indices by target length to maximize batching.
    buckets: dict[int, List[int]] = {}
    for i, L in enumerate(tgt_lengths):
        buckets.setdefault(L, []).append(i)

    gen_all = [""] * N
    codon_accs = [0.0] * N
    aa_accs = [0.0] * N

    # Canonical amino-acid alphabet, hoisted out of the scoring loop
    # (previously rebuilt once per sample).
    canonical = set("ACDEFGHIKLMNPQRSTVWY")

    # Helper AA translation using the tokenizer's genetic-code table.
    vocab = sampler.tokenizer._genetic_code

    def dna_to_aa(dna: str) -> str:
        dna = dna.strip().upper()
        aa = []
        for j in range(0, len(dna) - (len(dna) % 3), 3):
            aa.append(vocab.get(dna[j:j+3], 'X'))
        return ''.join(aa)

    for L, idxs in buckets.items():
        # Remember current prefix caps so they can be restored after this bucket.
        prev_sp = getattr(sampler.model, "max_species_prefix", 0)
        prev_pp = getattr(sampler.model, "max_protein_prefix", 0)
        try:
            if bool(no_truncation):
                # Best-effort: tighten protein prefix so prefix+start+L ≤ capacity.
                try:
                    capacity = int(getattr(sampler.model, "max_position_embeddings", 1024))
                    store = getattr(sampler, "species_store", None)
                    if store is not None and getattr(store, "is_legacy", False) and int(species_prefix_cap) > 0:
                        setattr(sampler.model, "max_species_prefix", int(species_prefix_cap))
                    # Probe with a small representative batch to measure prefix length.
                    batch_idx_probe = idxs[: min(len(idxs), max(1, min(batch_size, 8)))]
                    sp_probe = [species_names[i] for i in batch_idx_probe]
                    pr_probe = [protein_seqs[i] for i in batch_idx_probe]
                    cond_probe = {"control_mode": "fixed", "protein_seqs": pr_probe}
                    if store is not None:
                        sid_list = [store.vocab.get(s, -1) for s in sp_probe]
                        res = store.batch_get(sid_list)
                        if isinstance(res, tuple):
                            # Tuple result ⇒ per-token species embeddings.
                            sp_tok, _ = res
                            cond_probe["species_tok_emb_src"] = sp_tok.to(sampler.device)
                            cond_probe["species_tok_emb_tgt"] = sp_tok.to(sampler.device)
                        else:
                            cond_probe["species_emb_src"] = res.to(sampler.device)
                            cond_probe["species_emb_tgt"] = res.to(sampler.device)
                    # Iteratively reduce the protein prefix cap until remaining ≥ L.
                    for _ in range(3):
                        out0 = sampler.model(
                            codon_ids=torch.zeros(len(batch_idx_probe), 0, dtype=torch.long, device=sampler.device),
                            cond=cond_probe,
                            return_dict=True,
                            use_cache=True,
                        )
                        pref = out0.get("prefix_len")
                        if isinstance(pref, torch.Tensor) and pref.numel() > 0:
                            pref_max = int(pref.max().item())
                        else:
                            pref_max = int(pref) if isinstance(pref, int) else 0
                        remaining = capacity - (pref_max + 1)
                        if remaining >= int(L):
                            break
                        need = int(L) - max(0, int(remaining))
                        cur_pp = int(getattr(sampler.model, "max_protein_prefix", 0) or 0)
                        new_pp = max(0, cur_pp - need) if cur_pp > 0 else max(0, pref_max - (capacity - 1 - int(L)))
                        setattr(sampler.model, "max_protein_prefix", int(new_pp))
                except Exception:
                    pass  # best-effort probing; fall back to current caps
            # Process this bucket in mini-batches.
            for k in range(0, len(idxs), batch_size):
                batch_idx = idxs[k:k+batch_size]
                sp_b = [species_names[i] for i in batch_idx]
                pr_b = [protein_seqs[i] for i in batch_idx]
                out = sampler.sample(
                    num_sequences=len(batch_idx),
                    sequence_length=L,
                    species=sp_b,
                    protein_sequences=pr_b,
                    control_mode=control_mode,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    return_intermediate=False,
                    progress_bar=False,
                    enforce_translation=enforce_translation,
                )
                gen_list: List[str] = out["sequences"]
                # Record generated DNA and score each sample.
                for pos, idx in enumerate(batch_idx):
                    gen_seq = gen_list[pos]
                    gen_all[idx] = gen_seq
                    tgt_codons = tgt_codons_list[idx]
                    gen_codons = _dna_to_codons(gen_seq)[:L]
                    matches = sum(1 for a, b in zip(gen_codons, tgt_codons) if a == b)
                    codon_accs[idx] = (matches / L) if L > 0 else 0.0
                    gen_aa = dna_to_aa(''.join(gen_codons))
                    tgt_aa = protein_seqs[idx][:L]
                    # Treat non-canonical AA in the target as "match any".
                    aa_matches = sum(1 for a, b in zip(gen_aa, tgt_aa) if (b not in canonical) or (a == b))
                    aa_accs[idx] = (aa_matches / L) if L > 0 else 0.0
        finally:
            # Restore caps even if sampling raised.
            if bool(no_truncation):
                try:
                    setattr(sampler.model, "max_species_prefix", prev_sp)
                    setattr(sampler.model, "max_protein_prefix", prev_pp)
                except Exception:
                    pass

    return gen_all, codon_accs, aa_accs
857
+
858
+
859
def export_per_sequence_over_splits(
    sampler: CodonSampler,
    splits: List[str],
    splits_root: str,
    out_csv: str,
    batch_size: int,
    temperature: float,
    top_k: int,
    top_p: float,
    control_mode: str,
    enforce_translation: bool,
    progress: bool = False,
    max_rows_per_split: int = 0,
    no_truncation: bool = False,
    species_prefix_cap: int = 0,
) -> None:
    """Process ./data/val and ./data/test (or under splits_root) and write a per-sequence CSV.

    For each split, streams parquet files row-group by row-group, generates a
    codon sequence per row via ``generate_and_score_batched``, and appends one
    CSV row per sample with the generated DNA plus codon/AA accuracy metrics.
    The CSV is created with its header up front so it can be tailed while the
    export runs. ``max_rows_per_split > 0`` caps the number of exported rows
    per split; 0 means unlimited.
    """
    try:
        import pyarrow.parquet as pq  # type: ignore
    except Exception as e:
        raise ImportError("pyarrow is required for Parquet evaluation/export") from e

    from pathlib import Path as _P
    import os as _os
    total_written = 0
    # Pre-create CSV with header so users can tail it immediately
    header_cols = [
        "split",
        "organism",
        "protein_seq",
        "codon_seq",
        "predicted_seq",
        "codon_similarity",
        "amino_acid_recovery_rate",
    ]
    _P(out_csv).parent.mkdir(parents=True, exist_ok=True)
    if not _P(out_csv).exists() or _os.path.getsize(out_csv) == 0:
        with open(out_csv, "w", newline="") as f:
            f.write(",".join(header_cols) + "\n")
        logging.info(f"Initialized CSV with header → {out_csv}")
    for split in splits:
        # rows_remaining is None when the split is uncapped.
        rows_remaining = int(max_rows_per_split) if int(max_rows_per_split) > 0 else None
        dir_path = Path(splits_root) / split
        files = sorted(str(p) for p in dir_path.glob("*.parquet"))
        if not files:
            logging.warning(f"No parquet files found in {dir_path}, skipping split {split}")
            continue
        logging.info(f"Processing split '{split}' with {len(files)} files ...")
        # Wrap the file iterator with tqdm only when progress is requested and
        # tqdm is importable; otherwise iterate plainly.
        try:
            from tqdm import tqdm  # type: ignore
            _wrap = (lambda it, **kw: tqdm(it, **kw)) if progress else (lambda it, **kw: it)
        except Exception:
            _wrap = (lambda it, **kw: it)
        stop_split = False
        for fp in _wrap(files, desc=f"{split} files", unit="file"):
            if rows_remaining is not None and rows_remaining <= 0:
                break
            pf = pq.ParquetFile(fp)
            nrg = int(pf.num_row_groups or 0)
            rgs = list(range(max(nrg, 1)))
            # Build a per-file rows progress bar (prefer total rows from metadata when available)
            rows_total = None
            try:
                if pf.metadata is not None:
                    rows_total = 0
                    for rg_idx in rgs:
                        rg_md = pf.metadata.row_group(rg_idx)
                        if rg_md is not None and rg_md.num_rows is not None:
                            rows_total += int(rg_md.num_rows)
            except Exception:
                rows_total = None
            rows_pbar = None
            if progress:
                try:
                    from tqdm import tqdm  # type: ignore
                    rows_pbar = tqdm(total=rows_total, desc=f"{split}:{Path(fp).name}", unit="rows", leave=False)
                except Exception:
                    rows_pbar = None

            for rg in rgs:
                if rows_remaining is not None and rows_remaining <= 0:
                    stop_split = True
                    break
                # Only the three columns needed for generation/scoring are read.
                table = pf.read_row_group(rg, columns=["Taxon", "protein_seq", "cds_DNA"])
                df = table.to_pandas()
                if df.empty:
                    continue
                species = df["Taxon"].astype(str).tolist()
                proteins = df["protein_seq"].astype(str).str.upper().tolist()
                dnas = df["cds_DNA"].astype(str).str.upper().tolist()

                # Generate predictions and metrics in streaming mini-batches to keep
                # memory stable and update progress frequently
                N = len(species)
                for off in range(0, N, batch_size):
                    if rows_remaining is not None and rows_remaining <= 0:
                        stop_split = True
                        break
                    sp_b = species[off: off + batch_size]
                    pr_b = proteins[off: off + batch_size]
                    dn_b = dnas[off: off + batch_size]
                    gen_list, codon_accs, aa_accs = generate_and_score_batched(
                        sampler,
                        sp_b,
                        pr_b,
                        dn_b,
                        temperature=temperature,
                        top_k=top_k,
                        top_p=top_p,
                        control_mode=control_mode,
                        batch_size=batch_size,
                        enforce_translation=enforce_translation,
                        no_truncation=bool(no_truncation),
                        species_prefix_cap=int(species_prefix_cap),
                    )
                    rows_batch: List[dict] = []
                    for sp, pr, dn, gen, cacc, aacc in zip(sp_b, pr_b, dn_b, gen_list, codon_accs, aa_accs):
                        # Trim the target DNA to the comparable length (whole codons,
                        # bounded by the protein length).
                        L = min(len(pr), len(dn) // 3)
                        tgt_dna = dn[: 3 * L]
                        rows_batch.append({
                            "split": split,
                            "organism": sp,
                            "protein_seq": pr,
                            "codon_seq": tgt_dna,
                            "predicted_seq": gen,
                            "codon_similarity": float(cacc),
                            "amino_acid_recovery_rate": float(aacc),
                        })
                    if rows_batch:
                        if rows_remaining is not None and len(rows_batch) > rows_remaining:
                            rows_batch = rows_batch[: rows_remaining]
                        # Append to the CSV; header only if the file is missing/empty
                        # (normally pre-created above, so header=False here).
                        out_exists = _P(out_csv).exists() and _os.path.getsize(out_csv) > 0
                        df_out = pd.DataFrame(rows_batch)
                        _P(out_csv).parent.mkdir(parents=True, exist_ok=True)
                        df_out.to_csv(out_csv, mode='a', header=not out_exists, index=False)
                        total_written += len(rows_batch)
                        if rows_remaining is not None:
                            rows_remaining -= len(rows_batch)
                        if rows_pbar is not None:
                            try:
                                rows_pbar.update(len(rows_batch))
                            except Exception:
                                pass
                        if rows_remaining is not None and rows_remaining <= 0:
                            stop_split = True
                            break
            if rows_pbar is not None:
                try:
                    rows_pbar.close()
                except Exception:
                    pass
            if stop_split:
                break
    logging.info(f"Per-sequence export complete → {out_csv} (rows={total_written})")
1013
+
1014
+
1015
def main():
    """CLI entry point for model evaluation.

    Modes (selected by CLI flags):
      * --export_per_sequence: stream parquet splits and write a per-sequence CSV.
      * --eval_all (teacher-forced): stream the whole dataset for training-parity metrics.
      * default teacher-forced: evaluate a random sample with CE / codon-acc / AA-acc.
      * --free_run: sample sequences and score them against ground-truth DNA.
    """
    args = parse_args()
    random.seed(args.seed)
    torch.manual_seed(args.seed)

    model_dir = Path(args.model_path)
    pooling = _preferred_pooling(model_dir)
    logger.info(f"Preferred species_pooling from checkpoint: {pooling}")

    # Set up species store (recommended for parity)
    species_store = None
    if args.embeddings_dir:
        emb_dir = Path(args.embeddings_dir)
        # The embeddings directory's on-disk format wins over the checkpoint hint.
        detected = _detect_pooling_from_embeddings_dir(emb_dir)
        if detected is not None and detected != pooling:
            logger.info(f"Overriding pooling from checkpoint ({pooling}) → embeddings_dir format ({detected})")
            pooling = detected
        species_store = SpeciesEmbeddingStore(args.embeddings_dir, pooling=pooling)
        logger.info(f"Loaded species store with {len(species_store.vocab)} species (pooling={pooling})")

    # Load sampler/model (uses same construction as sampling)
    sampler = CodonSampler(
        model_path=args.model_path,
        device=("cuda" if args.device == "cuda" and torch.cuda.is_available() else "cpu"),
        species_store=species_store,
    )

    # Load input data and sample rows
    if bool(args.export_per_sequence):
        export_per_sequence_over_splits(
            sampler,
            splits=list(args.export_splits),
            splits_root=str(args.splits_root),
            out_csv=str(args.out_csv),
            batch_size=int(args.batch_size),
            temperature=float(args.temperature),
            top_k=int(args.top_k),
            top_p=float(args.top_p),
            control_mode=str(args.control_mode),
            enforce_translation=bool(args.enforce_translation),
            progress=bool(args.progress),
            max_rows_per_split=int(args.max_rows_per_split),
            no_truncation=bool(args.no_truncation),
            species_prefix_cap=int(args.species_prefix_cap),
        )
        return

    data_path = args.data_path or args.csv_path
    if data_path is None:
        raise SystemExit("Please provide --data_path (CSV or Parquet glob/dir). --csv_path remains as a deprecated alias.")

    # Expand paths to decide CSV vs Parquet
    paths = _expand_paths(data_path)
    if not paths:
        raise FileNotFoundError(f"No input files matched: {data_path}")

    if all(_is_parquet_path(p) for p in paths):
        logger.info(f"Reading up to {args.num_samples} samples from {len(paths)} parquet files ...")
        df_s = _load_random_samples_from_parquet(paths, int(args.num_samples), int(args.seed))
    else:
        # Fallback to CSV/TSV single file behavior (back-compat). If multiple files match, use the first.
        csv_file = None
        for pth in paths:
            if pth.lower().endswith((".csv", ".tsv", ".csv.gz", ".tsv.gz")):
                csv_file = pth
                break
        if csv_file is None:
            raise ValueError(f"Unsupported input for --data_path: {paths[0]}")
        logger.info(f"Reading CSV file: {csv_file}")
        df = pd.read_csv(csv_file)
        required = {"Taxon", "protein_seq", "cds_DNA"}
        if not required.issubset(set(df.columns)):
            missing = required - set(df.columns)
            raise ValueError(f"CSV missing required columns: {sorted(missing)}")
        if args.num_samples > len(df):
            logger.warning(f"num_samples {args.num_samples} > CSV rows {len(df)}; reducing")
            args.num_samples = len(df)
        # Random sample without replacement
        indices = random.sample(range(len(df)), args.num_samples)
        df_s = df.iloc[indices].reset_index(drop=True)

    if len(df_s) == 0:
        raise ValueError("No samples loaded from the provided data_path")

    logger.info(f"Loaded {len(df_s)} samples for evaluation")

    species = df_s["Taxon"].astype(str).tolist()
    proteins = df_s["protein_seq"].astype(str).str.upper().tolist()
    dnas = df_s["cds_DNA"].astype(str).str.upper().tolist()

    if not args.free_run:
        if bool(args.eval_all):
            if not args.embeddings_dir:
                raise SystemExit("--eval_all requires --embeddings_dir for species vocab/embeddings")
            # Stream the entire dataset and compute dataset-level metrics (training-parity)
            eval_streaming_all(
                sampler,
                species_store if species_store is not None else SpeciesEmbeddingStore(args.embeddings_dir, pooling=pooling),
                data_path,
                batch_size=int(args.batch_size),
                num_workers=int(args.workers),
                max_records=int(args.max_records),
            )
            return
        # Optional: print per-sample CDS→AA agreement (standard code)
        if bool(args.debug_aa_check):
            for idx, (sp, pr, dn) in enumerate(zip(species, proteins, dnas), start=1):
                ratio, Lcmp, first_bad = _aa_agreement(dn, pr, sampler.tokenizer)
                flag = "OK" if ratio == 1.0 and Lcmp > 0 else ("EMPTY" if Lcmp == 0 else "MISMATCH")
                extra = f" first_mismatch={first_bad}" if first_bad >= 0 else ""
                logger.info(f"AA-CHECK Sample {idx:02d}: {flag} match={ratio:.3f} len={Lcmp}{extra} Taxon={sp}")
        # (No dataset-level filtering to keep evaluation simple.)
        # Teacher-forced evaluation (random subset)
        per_ce_all: List[float] = []
        per_aa_acc_all: List[float] = []
        per_codon_acc_all: List[float] = []
        bs = max(1, int(args.batch_size))
        for i in range(0, len(species), bs):
            sp_b = species[i:i+bs]
            pr_b = proteins[i:i+bs]
            dn_b = dnas[i:i+bs]
            ce, aa_acc = eval_batch(sampler, species_store, sp_b, pr_b, dn_b)
            # Also compute per-sample codon-acc using the same batch forward for consistency
            # Re-run lightweight preds for codon-acc is unnecessary because eval_batch already
            # computed supervised mask and preds internally; instead, recompute quickly here
            # by calling eval_batch and deriving codon-acc inside it. For simplicity and clarity
            # we re-derive codon-acc below using the same masking rules.
            per_ce_all.extend(ce)
            per_aa_acc_all.extend(aa_acc)

            # Derive codon-acc for this batch
            # Prepare a mirrored forward to access logits and masks (small overhead acceptable)
            tok = sampler.tokenizer
            pad_id = tok.pad_token_id
            eos_id = tok.eos_token_id
            codon_ids_local = []
            for dna, prot in zip(dn_b, pr_b):
                # Trim DNA to whole codons bounded by the protein length (≥ 1).
                C_dna = len(dna) // 3
                C_prot = len(prot)
                C = max(min(C_dna, C_prot), 1)
                dna_trim = dna[: 3 * C]
                ids = tok.encode_codon_seq(dna_trim, validate=False)
                ids.append(eos_id)
                codon_ids_local.append(ids)
            B_b = len(codon_ids_local)
            T_b = max(len(x) for x in codon_ids_local)
            codons_b = torch.full((B_b, T_b), pad_id, dtype=torch.long)
            mask_b = torch.zeros((B_b, T_b), dtype=torch.bool)
            for j, ids in enumerate(codon_ids_local):
                Lb = len(ids)
                codons_b[j, :Lb] = torch.tensor(ids, dtype=torch.long)
                mask_b[j, :Lb] = True
            input_ids_b = codons_b[:, :-1].to(sampler.device)
            labels_b = codons_b[:, :-1].clone()
            # Pad and EOS positions are excluded from supervision.
            labels_b[labels_b == pad_id] = -100
            labels_b[labels_b == eos_id] = -100
            cond_b = {"control_mode": "fixed"}
            if species_store is not None and sp_b:
                sids_b = [species_store.vocab.get(s, -1) for s in sp_b]
                res_b = species_store.batch_get(sids_b)
                if isinstance(res_b, tuple):
                    # Tuple result ⇒ per-token species embeddings.
                    sp_tok_b, _ = res_b
                    cond_b["species_tok_emb_src"] = sp_tok_b.to(sampler.device)
                    cond_b["species_tok_emb_tgt"] = sp_tok_b.to(sampler.device)
                else:
                    sp_fix_b = res_b
                    cond_b["species_emb_src"] = sp_fix_b.to(sampler.device)
                    cond_b["species_emb_tgt"] = sp_fix_b.to(sampler.device)
            cond_b["protein_seqs"] = pr_b
            out_b = sampler.model(codon_ids=input_ids_b, cond=cond_b, labels=labels_b.to(sampler.device), return_dict=True)
            logits_b = out_b["logits"]
            per_cap_b = out_b.get("per_cap")
            if logits_b is not None and per_cap_b is not None:
                Bsz, Lmax, V = logits_b.size(0), logits_b.size(1), logits_b.size(2)
                # Align labels to the (possibly shorter/longer) logits width.
                labels_aligned_b = torch.full((Bsz, Lmax), -100, dtype=labels_b.dtype, device=logits_b.device)
                common_cols_b = min(labels_b.size(1), Lmax)
                if common_cols_b > 0:
                    labels_aligned_b[:, :common_cols_b] = labels_b.to(logits_b.device)[:, :common_cols_b]
                ar = torch.arange(Lmax, device=logits_b.device).unsqueeze(0)
                cap_mask_b = ar < per_cap_b.to(device=logits_b.device).unsqueeze(1)
                labels_masked_b = labels_aligned_b.clone()
                labels_masked_b[~cap_mask_b] = -100
                preds_b = logits_b.argmax(dim=-1)
                num_special = int(getattr(tok, "num_special_tokens", 0) or 0)
                # Supervised positions: labeled, within per-sample cap, non-special.
                supervised_b = (labels_masked_b != -100) & cap_mask_b
                if num_special > 0:
                    supervised_b = supervised_b & (labels_aligned_b >= num_special)
                for r in range(Bsz):
                    denom = int(supervised_b[r].sum().item())
                    cod_acc = (float((preds_b[r][supervised_b[r]] == labels_aligned_b[r][supervised_b[r]]).sum().item()) / denom) if denom > 0 else 0.0
                    per_codon_acc_all.append(cod_acc)

        # NOTE(review): zip truncates if per_codon_acc_all is shorter than the
        # other lists (possible when a batch produced no logits/per_cap).
        for idx, (ce, aa, ca) in enumerate(zip(per_ce_all, per_aa_acc_all, per_codon_acc_all), start=1):
            logger.info(f"Sample {idx:02d}: CE={ce:.4f} CODON-acc={ca:.4f} AA-acc={aa:.4f}")
        if per_ce_all:
            mean_ce = sum(per_ce_all) / len(per_ce_all)
            mean_aa = sum(per_aa_acc_all) / len(per_aa_acc_all) if per_aa_acc_all else 0.0
            mean_codon = sum(per_codon_acc_all) / len(per_codon_acc_all) if per_codon_acc_all else 0.0
            logger.info(f"Summary over {len(per_ce_all)} samples → mean CE={mean_ce:.4f}, mean CODON-acc={mean_codon:.4f}, mean AA-acc={mean_aa:.4f}")
    else:
        # Free-run sampling evaluation vs ground-truth DNA (codon-level), batched
        codon_accs, aa_accs = sample_and_score_batched(
            sampler,
            species,
            proteins,
            dnas,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            control_mode=args.control_mode,
            batch_size=int(args.batch_size),
            enforce_translation=bool(args.enforce_translation),
            no_truncation=bool(args.no_truncation),
            species_prefix_cap=int(args.species_prefix_cap),
        )
        for idx, (cacc, aacc) in enumerate(zip(codon_accs, aa_accs), start=1):
            logger.info(f"Sample {idx:02d}: CODON-acc={cacc:.4f} AA-acc={aacc:.4f}")
        if codon_accs:
            mean_c = sum(codon_accs) / len(codon_accs)
            mean_a = sum(aa_accs) / len(aa_accs)
            logger.info(f"Summary over {len(codon_accs)} samples → mean CODON-acc={mean_c:.4f}, mean AA-acc={mean_a:.4f}")
1236
+
1237
+
1238
# Script entry point.
if __name__ == "__main__":
    main()
final_model/config.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "max_length": 2048,
3
+ "max_species_prefix": 0,
4
+ "max_protein_prefix": 1024,
5
+ "hidden_size": 750,
6
+ "num_hidden_layers": 20,
7
+ "num_attention_heads": 15,
8
+ "mlp_ratio": 3.2,
9
+ "prepend_species": true,
10
+ "prepend_protein": true,
11
+ "species_embedding_dim": 1024,
12
+ "esm_model_name": "esmc_300m",
13
+ "esm_device": "cuda:0",
14
+ "esm_dtype": "bf16",
15
+ "attn_impl": "mha",
16
+ "num_kv_groups": 5
17
+ }
final_model/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5af6fe27a93e8a5edf622131b8fff74240f90db036a95697cfe4f28af1d23ef9
3
+ size 1284544520
final_model/trainer_config.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "max_length": 2048,
3
+ "max_species_prefix": 0,
4
+ "max_protein_prefix": 1024,
5
+ "hidden_size": 750,
6
+ "num_hidden_layers": 20,
7
+ "num_attention_heads": 15,
8
+ "mlp_ratio": 3.2,
9
+ "prepend_species": true,
10
+ "prepend_protein": true,
11
+ "species_embedding_dim": 1024,
12
+ "esm_model_name": "esmc_300m",
13
+ "esm_device": "cuda:0",
14
+ "esm_dtype": "bf16",
15
+ "attn_impl": "mha",
16
+ "num_kv_groups": 5
17
+ }
final_model/trainer_state.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "epoch": 2,
3
+ "global_step": 120513
4
+ }
final_model/vocab.json ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "special_token_str": {
3
+ "bos": "<bos>",
4
+ "eos": "<stop>",
5
+ "pad": "<pad>",
6
+ "unk": "<unk>"
7
+ },
8
+ "vocab": {
9
+ "<bos>": 2,
10
+ "<pad>": 0,
11
+ "<stop>": 3,
12
+ "<unk>": 1,
13
+ "AAA": 4,
14
+ "AAC": 5,
15
+ "AAG": 6,
16
+ "AAT": 7,
17
+ "ACA": 8,
18
+ "ACC": 9,
19
+ "ACG": 10,
20
+ "ACT": 11,
21
+ "AGA": 12,
22
+ "AGC": 13,
23
+ "AGG": 14,
24
+ "AGT": 15,
25
+ "ATA": 16,
26
+ "ATC": 17,
27
+ "ATG": 18,
28
+ "ATT": 19,
29
+ "CAA": 20,
30
+ "CAC": 21,
31
+ "CAG": 22,
32
+ "CAT": 23,
33
+ "CCA": 24,
34
+ "CCC": 25,
35
+ "CCG": 26,
36
+ "CCT": 27,
37
+ "CGA": 28,
38
+ "CGC": 29,
39
+ "CGG": 30,
40
+ "CGT": 31,
41
+ "CTA": 32,
42
+ "CTC": 33,
43
+ "CTG": 34,
44
+ "CTT": 35,
45
+ "GAA": 36,
46
+ "GAC": 37,
47
+ "GAG": 38,
48
+ "GAT": 39,
49
+ "GCA": 40,
50
+ "GCC": 41,
51
+ "GCG": 42,
52
+ "GCT": 43,
53
+ "GGA": 44,
54
+ "GGC": 45,
55
+ "GGG": 46,
56
+ "GGT": 47,
57
+ "GTA": 48,
58
+ "GTC": 49,
59
+ "GTG": 50,
60
+ "GTT": 51,
61
+ "TAA": 52,
62
+ "TAC": 53,
63
+ "TAG": 54,
64
+ "TAT": 55,
65
+ "TCA": 56,
66
+ "TCC": 57,
67
+ "TCG": 58,
68
+ "TCT": 59,
69
+ "TGA": 60,
70
+ "TGC": 61,
71
+ "TGG": 62,
72
+ "TGT": 63,
73
+ "TTA": 64,
74
+ "TTC": 65,
75
+ "TTG": 66,
76
+ "TTT": 67
77
+ }
78
+ }
precompute_embeddings.py ADDED
@@ -0,0 +1,503 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ """
3
+ Precompute species embeddings for CodonTranslator training.
4
+ Protein embeddings are now computed on-the-fly using integrated ESM-C model.
5
+
6
+ Steps:
7
+ 1. Build taxonomy database from GBIF API
8
+ 2. Generate species embeddings using Qwen3-Embedding-0.6B
9
+ """
10
+
11
+ import os
12
+ import json
13
+ import logging
14
+ import argparse
15
+ from pathlib import Path
16
+ from typing import Dict, List, Optional, Tuple
17
+ import glob
18
+ import requests
19
+ import time
20
+ from collections import defaultdict
21
+
22
+ import numpy as np
23
+ import pandas as pd
24
+ import torch
25
+ from tqdm import tqdm
26
+
27
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
def build_taxonomy_database(species_list: List[str]) -> Dict[str, str]:
    """Query GBIF API for comprehensive phylogenetic taxonomy of species.

    Creates detailed taxonomic descriptions for better species embeddings.

    Args:
        species_list: Species names to look up; empty strings and duplicates
            are skipped.

    Returns:
        Mapping from species name to a text description. When GBIF has no
        match, the query fails, or an exception occurs, the value falls back
        to the bare name plus a marker suffix.
    """
    taxonomy_db: Dict[str, str] = {}
    base_url = "https://api.gbif.org/v1/species/match"

    logger.info(f"Building taxonomy database for {len(species_list)} species...")
    for species in tqdm(species_list, desc="Querying GBIF"):
        if not species or species in taxonomy_db:
            continue

        try:
            # Bounded timeout so one stuck request cannot hang the whole build.
            response = requests.get(base_url, params={"name": species}, timeout=30)
            if response.status_code == 200:
                data = response.json()
                if data.get("matchType") != "NONE":
                    # Build comprehensive taxonomy description
                    parts = []

                    # Scientific classification, most general rank first.
                    taxonomy = [
                        data[rank]
                        for rank in ("kingdom", "phylum", "class", "order", "family", "genus", "species")
                        if rank in data and data[rank]
                    ]
                    if taxonomy:
                        parts.append("Taxonomy: " + " > ".join(taxonomy))

                    # Add common name if available
                    if data.get("vernacularName"):
                        parts.append(f"Common name: {data['vernacularName']}")

                    # Add confidence score
                    if "confidence" in data:
                        parts.append(f"Match confidence: {data['confidence']}%")

                    # Add status (accepted, synonym, etc.)
                    if "status" in data:
                        parts.append(f"Status: {data['status']}")

                    # Combine all parts into comprehensive description
                    taxonomy_db[species] = ". ".join(parts) if parts else species
                else:
                    # No match found - use species name with indicator
                    taxonomy_db[species] = f"Species: {species} (no GBIF match)"
            else:
                taxonomy_db[species] = f"Species: {species} (query failed)"

            # Rate limiting
            time.sleep(0.1)
        except Exception as e:
            # Best-effort: record the failure marker and keep going.
            logger.warning(f"Error querying GBIF for {species}: {e}")
            taxonomy_db[species] = f"Species: {species} (error)"

    logger.info(f"Taxonomy database built with {len(taxonomy_db)} entries")
    return taxonomy_db
89
+
90
+
91
def generate_species_embeddings_qwen(
    species_list: List[str],
    taxonomy_db: Dict[str, str],
    device: str = "cuda",
    pooling: str = "last"  # 'last' -> single vector; 'sequence'/'none' -> variable-length tokens
) -> Tuple[Dict[str, int], Dict[int, np.ndarray]]:
    """
    Generate species embeddings using Qwen3-Embedding-0.6B.
    - pooling='last': returns one vector per species (fixed size)
    - pooling='none': returns variable-length token embeddings per species

    Returns a (species -> id) vocab and an (id -> embedding array) dict.
    NOTE(review): embeddings run one species at a time (batch size 1); if
    species_list contains duplicates, the vocab keeps only the last id while
    an embedding is stored for every index — confirm inputs are unique.
    """
    import torch.nn.functional as F
    from transformers import AutoTokenizer, AutoModel

    def last_token_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Pool by taking the last valid token's embedding."""
        # With left padding the last column is always valid for every row.
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            return last_hidden_states[:, -1]
        else:
            # Right padding: index each row at its last non-pad position.
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_states.shape[0]
            return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]

    def get_detailed_instruct(task_description: str, query: str) -> str:
        """Format the input with instruction for better embedding quality."""
        return f'Instruct: {task_description}\nQuery: {query}'

    logger.info("Loading Qwen3-Embedding-0.6B model...")
    model_name = "Qwen/Qwen3-Embedding-0.6B"

    # Initialize with left padding for last token pooling
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, padding_side='left')
    model = AutoModel.from_pretrained(model_name, torch_dtype=torch.float16, trust_remote_code=True).to(device).eval()

    species_vocab = {}       # species name -> integer id (enumeration order)
    species_embeddings = {}  # integer id -> np.ndarray ([D] or [L, D])

    # Task description for species embedding
    task = "Given a species taxonomy information, generate a biological embedding representing its taxonomic and evolutionary characteristics"

    for idx, species in enumerate(tqdm(species_list, desc="Generating embeddings")):
        # Get comprehensive taxonomy string from GBIF query results
        # (falls back to the bare species name when no entry exists).
        taxonomy_str = taxonomy_db.get(species, species)

        # Format with instruction for better semantic understanding
        input_text = get_detailed_instruct(task, taxonomy_str)

        # Generate embeddings
        with torch.no_grad():
            inputs = tokenizer(
                input_text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            )
            inputs = {k: v.to(device) for k, v in inputs.items()}

            outputs = model(**inputs)
            hidden = outputs.last_hidden_state  # [1, L, D]
            if pooling == 'last':
                # Single L2-normalized vector per species.
                pooled_embedding = last_token_pool(hidden, inputs['attention_mask'])
                normalized_embedding = F.normalize(pooled_embedding, p=2, dim=1)
                species_embedding = normalized_embedding.squeeze(0).cpu().numpy()  # [D]
            else:
                # Variable-length token embeddings (normalize per token)
                tok = hidden.squeeze(0)  # [L, D]
                tok = F.normalize(tok, p=2, dim=-1)
                species_embedding = tok.cpu().numpy()  # [L, D]

        species_vocab[species] = idx
        species_embeddings[idx] = species_embedding

    logger.info(f"Generated {'fixed-size' if pooling=='last' else 'variable-length'} embeddings for {len(species_vocab)} species")
    return species_vocab, species_embeddings
167
+
168
+
169
def save_species_embeddings_memmap(
    species_vocab: Dict[str, int],
    species_embeddings: Dict[int, np.ndarray],
    output_dir: str
) -> None:
    """Save fixed-size species embeddings as memory-mapped file.

    Writes three artifacts into *output_dir*:
      - species_vocab.json: species name -> integer id
      - species_embeddings.bin: float32 memmap of shape (num_species, embed_dim)
      - species_metadata.json: shape/provenance metadata

    Args:
        species_vocab: Species name -> id mapping (ids assumed dense in
            [0, num_species), as produced by the Qwen generator).
        species_embeddings: id -> 1-D embedding vector, all the same length.
        output_dir: Destination directory (created if missing).

    Raises:
        ValueError: if *species_embeddings* is empty.
    """
    if not species_embeddings:
        # Previously this crashed later with an opaque StopIteration.
        raise ValueError("species_embeddings is empty; nothing to save")

    os.makedirs(output_dir, exist_ok=True)

    # Save vocabulary
    vocab_path = os.path.join(output_dir, "species_vocab.json")
    with open(vocab_path, 'w') as f:
        json.dump(species_vocab, f, indent=2)

    # All embeddings should have the same dimension now
    num_species = len(species_embeddings)
    embed_dim = next(iter(species_embeddings.values())).shape[0]  # Should be 1024

    # Create memmap for fixed-size embeddings
    emb_path = os.path.join(output_dir, "species_embeddings.bin")
    mmap = np.memmap(emb_path, dtype=np.float32, mode='w+', shape=(num_species, embed_dim))

    # Store embeddings directly by ID
    for species_id, emb in species_embeddings.items():
        mmap[species_id] = emb.astype(np.float32)

    # Explicitly flush before dropping the handle so the data reaches disk
    # even if the memmap finalizer is delayed.
    mmap.flush()
    del mmap

    # Save metadata
    metadata = {
        "num_species": num_species,
        "embedding_dim": embed_dim,
        "embedding_type": "fixed_size",
        "pooling_method": "last_token",
        "normalization": "L2",
        "model": "Qwen/Qwen3-Embedding-0.6B"
    }

    metadata_path = os.path.join(output_dir, "species_metadata.json")
    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=2)

    logger.info(f"Saved {num_species} fixed-size species embeddings to {emb_path}")
    logger.info(f"Embedding dimension: {embed_dim}")
    logger.info(f"Saved metadata to {metadata_path}")
214
+
215
+
216
def save_species_token_embeddings_memmap(
    species_vocab: Dict[str, int],
    species_tok_embeddings: Dict[int, np.ndarray],
    output_dir: str,
    dtype: str = 'float32'
) -> None:
    """Save variable-length token embeddings into a flat memmap with index.

    Layout: all per-species [L_i, D] arrays are concatenated row-wise into
    species_tok_emb.bin; species_index.json maps each species id to its
    (offset, length) slice in that flat buffer.

    Args:
        species_vocab: Species name -> id mapping.
        species_tok_embeddings: id -> [L, D] token-embedding array.
        output_dir: Destination directory (created if missing).
        dtype: On-disk dtype, 'float32' or anything else for float16.

    Raises:
        ValueError: if *species_tok_embeddings* is empty.
    """
    if not species_tok_embeddings:
        # Previously this crashed later with an opaque StopIteration.
        raise ValueError("species_tok_embeddings is empty; nothing to save")

    os.makedirs(output_dir, exist_ok=True)

    # Save vocabulary
    vocab_path = os.path.join(output_dir, "species_vocab.json")
    with open(vocab_path, 'w') as f:
        json.dump(species_vocab, f, indent=2)

    # Resolve the on-disk dtype once instead of re-deciding per write.
    np_dtype = np.float32 if dtype == 'float32' else np.float16

    # Compute totals and dims
    embed_dim = next(iter(species_tok_embeddings.values())).shape[1]
    total_tokens = int(sum(v.shape[0] for v in species_tok_embeddings.values()))

    emb_path = os.path.join(output_dir, "species_tok_emb.bin")
    mmap = np.memmap(emb_path, dtype=np_dtype, mode='w+', shape=(total_tokens, embed_dim))

    # Build index of each species' slice in the flat buffer.
    index = {}
    offset = 0
    for sid, arr in species_tok_embeddings.items():
        L = int(arr.shape[0])
        mmap[offset: offset + L] = arr.astype(np_dtype)
        index[str(sid)] = {"offset": offset, "length": L}
        offset += L

    # Explicitly flush so the buffer reaches disk before we drop the handle.
    mmap.flush()
    del mmap

    with open(os.path.join(output_dir, "species_index.json"), 'w') as f:
        json.dump(index, f, indent=2)

    meta = {
        "embedding_dim": embed_dim,
        "dtype": dtype,
        "total_tokens": total_tokens,
        "embedding_type": "variable_length",
        "pooling_method": "none",
        "model": "Qwen/Qwen3-Embedding-0.6B"
    }
    with open(os.path.join(output_dir, "metadata.json"), 'w') as f:
        json.dump(meta, f, indent=2)
    logger.info(f"Saved variable-length species token embeddings to {emb_path} with {total_tokens} tokens total")
262
+
263
+
264
def filter_sequences_by_length(df: pd.DataFrame, max_protein_length: int = 2048) -> pd.DataFrame:
    """Drop rows whose protein (or matching CDS) exceeds the length cap.

    Guards against CUDA OOM during training; columns absent from *df*
    are simply not filtered on.
    """
    n_before = len(df)

    # Cap the amino-acid sequence length.
    if 'protein_seq' in df.columns:
        df = df[df['protein_seq'].str.len() <= max_protein_length]

    # Cap the nucleotide sequence at three bases per residue.
    if 'cds_DNA' in df.columns:
        df = df[df['cds_DNA'].str.len() <= max_protein_length * 3]

    n_after = len(df)
    if n_after < n_before:
        logger.info(f"Filtered from {n_before} to {n_after} sequences (max_protein_length={max_protein_length})")

    return df
282
+
283
+
284
def collect_unique_values_from_shards(
    shards_glob: str,
    column: str,
    max_items: Optional[int] = None
) -> List[str]:
    """Stream over Parquet shards to collect unique values from a column."""
    shard_files = sorted(glob.glob(shards_glob))
    if not shard_files:
        raise ValueError(f"No parquet files found matching {shards_glob}")

    logger.info(f"Scanning {len(shard_files)} shards for unique {column} values...")

    unique_values = set()
    for shard_file in tqdm(shard_files, desc=f"Collecting {column}"):
        # Column names may differ only by case across datasets ('taxon' vs
        # 'Taxon'); resolve against the shard schema before reading.
        try:
            import pyarrow.parquet as pq  # type: ignore
            schema_names = set(pq.ParquetFile(shard_file).schema.names)
            resolved = column
            if resolved not in schema_names:
                by_lower = {name.lower(): name for name in schema_names}
                resolved = by_lower.get(column.lower(), column)
        except Exception:
            resolved = column

        shard_df = pd.read_parquet(shard_file, columns=[resolved])
        # Canonicalize to the requested column name for downstream logic.
        if resolved != column and resolved in shard_df.columns and column not in shard_df.columns:
            shard_df = shard_df.rename(columns={resolved: column})
        unique_values.update(shard_df[column].dropna().unique())

        # Stop scanning once enough distinct values have been seen.
        if max_items and len(unique_values) >= max_items:
            break

    ordered = sorted(unique_values)
    result = ordered[:max_items] if max_items else ordered
    logger.info(f"Collected {len(result)} unique {column} values")
    return result
323
+
324
+
325
def collect_stage1_species(shards_glob: str) -> List[str]:
    """Extract unique species from Stage-1 shards (the 'Taxon' column)."""
    species_column = "Taxon"
    return collect_unique_values_from_shards(shards_glob, species_column)
328
+
329
+
330
def prepare_species_from_stage1_shards(
    shards_glob: str,
    output_dir: str,
    device: str = "cuda",
    resume: bool = False,
    species_pooling: str = "last"
) -> None:
    """End-to-end species embedding generation from Stage-1 shards.

    Collects unique species, builds (or reuses) the GBIF taxonomy cache,
    embeds each species with Qwen3, and writes the memmap artifacts.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Fast exit on resume when the final vocabulary artifact already exists.
    if resume and os.path.exists(os.path.join(output_dir, "species_vocab.json")):
        logger.info("Species embeddings already exist. Skipping generation.")
        return

    species_names = collect_stage1_species(shards_glob)
    logger.info(f"Found {len(species_names)} unique species in shards")

    # Taxonomy descriptions are expensive (one GBIF call per species) — cache them.
    cache_file = os.path.join(output_dir, "taxonomy_database.json")
    if resume and os.path.exists(cache_file):
        logger.info("Loading cached taxonomy database...")
        with open(cache_file, 'r') as f:
            taxonomy_db = json.load(f)
    else:
        taxonomy_db = build_taxonomy_database(species_names)
        with open(cache_file, 'w') as f:
            json.dump(taxonomy_db, f, indent=2)

    species_vocab, species_embeddings = generate_species_embeddings_qwen(
        species_names, taxonomy_db, device, pooling=species_pooling
    )

    # Fixed-size vectors and variable-length tokens use different on-disk layouts.
    if species_pooling == 'last':
        writer = save_species_embeddings_memmap
    else:
        writer = save_species_token_embeddings_memmap
    writer(species_vocab, species_embeddings, output_dir)

    logger.info("Species embedding preparation complete")
373
+
374
+
375
def create_precomputed_dataset(
    input_csv: Optional[str],
    output_dir: str,
    device: str = "cuda",
    batch_size: int = 50,
    max_protein_length: int = 2048,
    resume: bool = False,
    species_pooling: str = "last"
):
    """
    Create embedding dataset with species-only precomputation.
    Protein embeddings will be computed on-the-fly during training.

    Args:
        input_csv: Path to a CSV or Parquet table with a 'Taxon'/'taxon'
            species column. Required at runtime despite the Optional
            annotation used by the CLI router.
        output_dir: Destination for embeddings, taxonomy cache, metadata.
        device: Device for the embedding model.
        batch_size: Currently unused placeholder kept for CLI compatibility.
        max_protein_length: Length cap applied before counting sequences.
        resume: Skip entirely when species_vocab.json already exists.
        species_pooling: 'last' for fixed-size vectors, anything else for
            variable-length token embeddings.

    Raises:
        ValueError: if *input_csv* is None.
    """
    if input_csv is None:
        # Fail with a clear message instead of AttributeError on .endswith.
        raise ValueError("input_csv is required for create_precomputed_dataset")

    os.makedirs(output_dir, exist_ok=True)

    # Skip if resuming and files exist
    if resume and os.path.exists(os.path.join(output_dir, "species_vocab.json")):
        logger.info("Precomputed dataset already exists. Use --resume=False to regenerate.")
        return

    # Load data (Parquet detected by extension; everything else read as CSV).
    logger.info(f"Loading data from {input_csv}...")
    if input_csv.endswith('.parquet'):
        df = pd.read_parquet(input_csv)
    else:
        df = pd.read_csv(input_csv)

    # Accept either 'Taxon' or 'taxon' as the species column.
    if "Taxon" not in df.columns and "taxon" in df.columns:
        df = df.rename(columns={"taxon": "Taxon"})

    # Filter sequences by length
    df = filter_sequences_by_length(df, max_protein_length)

    # === Species Embeddings ===
    logger.info("=== Generating Species Embeddings ===")
    unique_species = df["Taxon"].dropna().unique().tolist()
    logger.info(f"Found {len(unique_species)} unique species")

    # Build taxonomy database
    taxonomy_db = build_taxonomy_database(unique_species)

    # Save taxonomy database
    taxonomy_path = os.path.join(output_dir, "taxonomy_database.json")
    with open(taxonomy_path, 'w') as f:
        json.dump(taxonomy_db, f, indent=2)

    # Generate species embeddings
    species_vocab, species_embeddings = generate_species_embeddings_qwen(
        unique_species, taxonomy_db, device, pooling=species_pooling
    )

    if species_pooling == 'last':
        save_species_embeddings_memmap(species_vocab, species_embeddings, output_dir)
    else:
        save_species_token_embeddings_memmap(species_vocab, species_embeddings, output_dir)

    # Save metadata
    metadata = {
        "num_sequences": len(df),
        "num_species": len(unique_species),
        "species_embedding_model": "Qwen/Qwen3-Embedding-0.6B",
        "species_embedding_dim": 1024,  # Qwen3 dimension
        "max_protein_length": max_protein_length,
    }

    with open(os.path.join(output_dir, "metadata.json"), 'w') as f:
        json.dump(metadata, f, indent=2)

    logger.info(f"Dataset creation completed. Species embeddings are precomputed.")
    logger.info("Protein embeddings will be computed on-the-fly during training using integrated ESM-C.")
446
+
447
+
448
def main():
    """CLI entry point: precompute species embeddings from shards or a table."""
    parser = argparse.ArgumentParser(description="Precompute species embeddings for CodonTranslator")

    # Data source options
    parser.add_argument("--input_csv", type=str,
                        help="Path to input CSV/Parquet file")
    parser.add_argument("--from_stage1_shards", action="store_true",
                        help="Generate from Stage-1 Parquet shards instead of CSV")
    parser.add_argument("--stage1_shards_glob", type=str, default="./data/shards/*.parquet",
                        help="Glob pattern for Stage-1 shards")

    # Output
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Output directory for precomputed embeddings")

    # Processing options
    parser.add_argument("--device", type=str, default="cuda",
                        help="Device for model inference")
    parser.add_argument("--batch_size", type=int, default=50,
                        help="Batch size for embedding generation")
    parser.add_argument("--max_protein_length", type=int, default=2048,
                        help="Maximum protein sequence length")
    parser.add_argument("--resume", action="store_true",
                        help="Resume from checkpoint if available")
    parser.add_argument("--species_pooling", type=str, choices=["last", "sequence", "none"], default="last",
                        help="'last' for single-token; 'sequence' for variable-length token embeddings")

    args = parser.parse_args()

    # Dispatch: shard mode takes precedence, then the single-table mode.
    if args.from_stage1_shards:
        prepare_species_from_stage1_shards(
            args.stage1_shards_glob,
            args.output_dir,
            device=args.device,
            resume=args.resume,
            species_pooling=args.species_pooling,
        )
    elif args.input_csv:
        create_precomputed_dataset(
            args.input_csv,
            args.output_dir,
            device=args.device,
            batch_size=args.batch_size,
            max_protein_length=args.max_protein_length,
            resume=args.resume,
            species_pooling=args.species_pooling,
        )
    else:
        raise ValueError("Must specify either --input_csv or --from_stage1_shards")

    logger.info("Precomputation complete!")
500
+
501
+
502
+ if __name__ == "__main__":
503
+ main()
pyproject.toml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "CodonTranslator"
7
+ version = "0.1.1"
8
+ description = "Sampling codon sequences conditioned on species and protein using a GPT model"
9
+ readme = "README.md"
10
+ requires-python = ">=3.9"
11
+ license = {text = "MIT"}
12
+ authors = [{name = "CodonTranslator Team"}]
13
+ dependencies = [
14
+ "torch>=2.4",
15
+ "transformers>=4.57.0",
16
+ "esm>=3.2.3",
17
+ "safetensors>=0.7.0",
18
+ "numpy>=2.2.0",
19
+ "huggingface-hub>=0.36.0",
20
+ ]
21
+
22
+ [tool.setuptools]
23
+ package-dir = {"" = "."}
24
+ packages = ["CodonTranslator", "codontranslator"]
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=2.4
2
+ transformers>=4.57.0
3
+ esm>=3.2.3
4
+ safetensors>=0.7.0
5
+ numpy>=2.2.0
6
+ huggingface-hub>=0.36.0
7
+ accelerate>=1.9.0
8
+ pyarrow>=21.0.0
9
+ pandas>=2.3.0
10
+ duckdb>=1.5.0
11
+ biopython>=1.85
12
+ wandb>=0.21.0
resplit_data_v3.py ADDED
@@ -0,0 +1,1444 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Resplit `data_v2/` into leakage-safe `data_v3_rebuild/` using MMseqs2 clustering.
4
+
5
+ Default policy for the current rebuild:
6
+ - Cluster `protein_seq` with MMseqs2 `linclust`
7
+ - Define species by normalized binomial name (`genus species`)
8
+ - Test species are exactly the normalized species present in `data_v2/test`
9
+ - Validation is cluster-unseen but species-seen
10
+ - Mixed seen/heldout clusters keep heldout rows in test and drop seen rows
11
+
12
+ Typical usage (end-to-end):
13
+ python resplit_data_v3.py all --threads 32 --split-memory-limit 120G --num-shards 256
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import argparse
19
+ import json
20
+ import os
21
+ import shutil
22
+ import stat
23
+ import subprocess
24
+ import sys
25
+ import time
26
+ from pathlib import Path
27
+ from typing import Dict, Iterable, List, Optional, Sequence, Tuple
28
+
29
+
30
+ def _default_mmseqs_path() -> str:
31
+ cand = Path("MMseqs2/build/bin/mmseqs")
32
+ if cand.exists():
33
+ return str(cand)
34
+ return "mmseqs"
35
+
36
+
37
+ def _run(cmd: List[str], *, cwd: Optional[str] = None, env: Optional[dict] = None) -> None:
38
+ pretty = " ".join(cmd)
39
+ print(f"+ {pretty}", flush=True)
40
+ subprocess.run(cmd, cwd=cwd, env=env, check=True)
41
+
42
+
43
+ def _sql_escape_path(path: str) -> str:
44
+ return path.replace("'", "''")
45
+
46
+
47
+ def _expand_parquet_inputs(inp: str) -> List[str]:
48
+ import glob
49
+
50
+ p = Path(inp)
51
+ if p.exists() and p.is_dir():
52
+ files = sorted(str(x) for x in p.rglob("*.parquet"))
53
+ else:
54
+ files = sorted(glob.glob(inp))
55
+
56
+ seen = set()
57
+ out: List[str] = []
58
+ for f in files:
59
+ if f not in seen:
60
+ out.append(f)
61
+ seen.add(f)
62
+ return out
63
+
64
+
65
def _duckdb_parquet_source(inp: str, limit_files: int = 0) -> str:
    """Build a DuckDB `read_parquet([...])` clause over the expanded inputs.

    Exits the program when no parquet files match; an optional positive
    *limit_files* truncates the file list.
    """
    paths = _expand_parquet_inputs(inp)
    if not paths:
        raise SystemExit(f"No parquet files found for {inp!r}")
    if limit_files and int(limit_files) > 0:
        paths = paths[: int(limit_files)]
    file_list = ", ".join(f"'{_sql_escape_path(p)}'" for p in paths)
    return f"read_parquet([{file_list}])"
73
+
74
+
75
+ def _mem_total_bytes() -> Optional[int]:
76
+ try:
77
+ with open("/proc/meminfo", "r", encoding="utf-8") as f:
78
+ for line in f:
79
+ if line.startswith("MemTotal:"):
80
+ parts = line.split()
81
+ kb = int(parts[1])
82
+ return kb * 1024
83
+ except OSError:
84
+ return None
85
+ except (ValueError, IndexError):
86
+ return None
87
+ return None
88
+
89
+
90
+ def _parse_mmseqs_bytes(s: str) -> Optional[int]:
91
+ s = (s or "").strip()
92
+ if not s:
93
+ return None
94
+ up = s.upper()
95
+ suffix = up[-1]
96
+ num_part = up[:-1]
97
+ unit = suffix
98
+ if suffix == "B" and len(up) >= 2 and up[-2] in "KMGT":
99
+ unit = up[-2]
100
+ num_part = up[:-2]
101
+ if unit not in "BKMGT":
102
+ return None
103
+ try:
104
+ val = float(num_part)
105
+ except ValueError:
106
+ return None
107
+ mult = {"B": 1, "K": 1024, "M": 1024**2, "G": 1024**3, "T": 1024**4}[unit]
108
+ return int(val * mult)
109
+
110
+
111
+ def _format_bytes(n: int) -> str:
112
+ for unit, div in [("TiB", 1024**4), ("GiB", 1024**3), ("MiB", 1024**2), ("KiB", 1024)]:
113
+ if n >= div:
114
+ return f"{n / div:.1f}{unit}"
115
+ return f"{n}B"
116
+
117
+
118
+ def _seq_id_sql() -> str:
119
+ # Keep the stable row identifier aligned with the existing pipeline.
120
+ return "coalesce(protein_refseq_id, '') || '|' || coalesce(RefseqID, '')"
121
+
122
+
123
+ def _taxon_norm_sql(col: str = "taxon") -> str:
124
+ return f"regexp_replace(lower(trim(coalesce({col}, ''))), '\\\\s+', ' ', 'g')"
125
+
126
+
127
def _species_key_sql(mode: str, col: str = "taxon") -> str:
    """SQL for the species grouping key.

    mode='taxon' uses the full normalized name; mode='binomial' keeps only
    the first two words (genus + species) when a space is present.
    """
    norm = _taxon_norm_sql(col)
    if mode == "taxon":
        return norm
    if mode == "binomial":
        first_two = f"split_part({norm}, ' ', 1) || ' ' || split_part({norm}, ' ', 2)"
        return (
            f"CASE "
            f"WHEN strpos({norm}, ' ') > 0 "
            f"THEN {first_two} "
            f"ELSE {norm} END"
        )
    raise ValueError(f"Unsupported species key mode: {mode}")
139
+
140
+
141
+ def _protein_norm_sql(col: str = "protein_seq") -> str:
142
+ cleaned = f"regexp_replace(upper(coalesce({col}, '')), '\\\\s+', '', 'g')"
143
+ no_stop = f"regexp_replace({cleaned}, '[_*]+$', '')"
144
+ return f"regexp_replace({no_stop}, '[^A-Z]', 'X', 'g')"
145
+
146
+
147
+ def _cds_norm_sql(col: str = "cds_DNA") -> str:
148
+ cleaned = f"regexp_replace(upper(coalesce({col}, '')), '\\\\s+', '', 'g')"
149
+ return f"regexp_replace({cleaned}, '[^ACGTN]', 'N', 'g')"
150
+
151
+
152
def _seq_expr_sql(seq_space: str) -> str:
    """Normalized-sequence SQL expression for the chosen sequence space."""
    builders = {
        "protein": lambda: _protein_norm_sql("protein_seq"),
        "cds": lambda: _cds_norm_sql("cds_DNA"),
    }
    try:
        return builders[seq_space]()
    except KeyError:
        raise ValueError(f"Unsupported seq space: {seq_space}") from None
158
+
159
+
160
+ def _seq_space_input_col(seq_space: str) -> str:
161
+ if seq_space == "protein":
162
+ return "protein_seq"
163
+ if seq_space == "cds":
164
+ return "cds_DNA"
165
+ raise ValueError(f"Unsupported seq space: {seq_space}")
166
+
167
+
168
+ def _mmseqs_dbtype(seq_space: str) -> str:
169
+ if seq_space == "protein":
170
+ return "1"
171
+ if seq_space == "cds":
172
+ return "2"
173
+ raise ValueError(f"Unsupported seq space: {seq_space}")
174
+
175
+
176
+ def _default_max_input_seq_len(seq_space: str) -> int:
177
+ if seq_space == "protein":
178
+ # MMseqs linclust hit an internal SW bug on a tiny tail of ultra-long proteins
179
+ # (~39k aa+). Filtering this tail removes <0.01% of rows and keeps the run stable.
180
+ return 20_000
181
+ return 0
182
+
183
+
184
+ def _ensure_mmseqs_ready(mmseqs: str) -> Tuple[str, Dict[str, str]]:
185
+ path = Path(mmseqs)
186
+ env = os.environ.copy()
187
+
188
+ if path.exists():
189
+ mode = path.stat().st_mode
190
+ if not (mode & stat.S_IXUSR):
191
+ path.chmod(mode | stat.S_IXUSR)
192
+
193
+ py = Path(sys.executable).resolve()
194
+ env_root = py.parent.parent
195
+ conda_root = env_root.parent.parent if env_root.parent.name == "envs" else env_root.parent
196
+ lib_candidates = [env_root / "lib", conda_root / "lib"]
197
+ libs = [str(p) for p in lib_candidates if p.exists()]
198
+ if libs:
199
+ current = env.get("LD_LIBRARY_PATH", "")
200
+ env["LD_LIBRARY_PATH"] = ":".join(libs + ([current] if current else []))
201
+
202
+ return str(path if path.exists() else mmseqs), env
203
+
204
+
205
+ def _ensure_output_parent(path: Path) -> None:
206
+ path.parent.mkdir(parents=True, exist_ok=True)
207
+
208
+
209
def cmd_make_fasta(args: argparse.Namespace) -> None:
    """Export normalized sequences from parquet shards as a FASTA file.

    Uses a DuckDB COPY with DELIMITER '\\n' and no quoting so that each
    output record becomes two physical lines: the '>'-prefixed seq_id
    header followed by the normalized sequence — i.e. valid FASTA for
    MMseqs2 ``createdb``.

    Rows are kept only when the raw sequence column is non-NULL, the
    normalized sequence is non-empty, the seq_id is non-trivial, and the
    normalized length does not exceed the (seq-space-dependent) cap.
    """
    out_fasta = Path(args.output_fasta)
    _ensure_output_parent(out_fasta)

    # Imported lazily so the CLI works without duckdb for other subcommands.
    import duckdb

    con = duckdb.connect()
    con.execute(f"PRAGMA threads={int(args.threads)};")
    con.execute("PRAGMA enable_progress_bar=true;")

    source_sql = _duckdb_parquet_source(args.input_glob, int(args.limit_files))
    out_path = _sql_escape_path(str(out_fasta))
    seq_id = _seq_id_sql()
    seq_expr = _seq_expr_sql(args.seq_space)
    raw_col = _seq_space_input_col(args.seq_space)
    # 0 (or negative) means "use the per-seq-space default cap".
    max_input_seq_len = int(args.max_input_seq_len)
    if max_input_seq_len <= 0:
        max_input_seq_len = _default_max_input_seq_len(args.seq_space)
    len_filter = (
        f"AND length({seq_expr}) <= {max_input_seq_len}"
        if max_input_seq_len > 0
        else ""
    )

    sql = f"""
    COPY (
        SELECT
            '>' || ({seq_id}) AS header,
            {seq_expr} AS seq
        FROM {source_sql}
        WHERE {raw_col} IS NOT NULL
          AND length({seq_expr}) > 0
          {len_filter}
          AND length(({seq_id})) > 1
        {f"LIMIT {int(args.limit_rows)}" if args.limit_rows and int(args.limit_rows) > 0 else ""}
    )
    TO '{out_path}'
    (FORMAT CSV, DELIMITER '\n', QUOTE '', ESCAPE '', HEADER FALSE);
    """
    t0 = time.time()
    con.execute(sql)
    print(
        f"Wrote FASTA: {out_fasta} seq_space={args.seq_space} "
        f"max_input_seq_len={max_input_seq_len if max_input_seq_len > 0 else 'none'} "
        f"(elapsed_s={time.time() - t0:.1f})"
    )
255
+
256
+
257
def cmd_mmseqs_cluster(args: argparse.Namespace) -> None:
    """Run MMseqs2 createdb + linclust + createtsv on the exported FASTA.

    Produces ``<workdir>/clu.tsv`` with one (representative, member) pair
    per line.  With --overwrite, previously generated DB/TSV artifacts in
    the workdir are removed first.
    """
    mmseqs, env = _ensure_mmseqs_ready(args.mmseqs)
    workdir = Path(args.workdir)
    workdir.mkdir(parents=True, exist_ok=True)

    fasta = Path(args.fasta)
    if not fasta.exists():
        raise SystemExit(f"FASTA not found: {fasta}")

    # MMseqs artifact paths inside the work directory.
    seqdb = workdir / "seqdb"
    clu = workdir / "clu"
    tmp = workdir / "tmp"
    tsv = workdir / "clu.tsv"

    if args.overwrite:
        for p in (seqdb, clu, tmp, tsv):
            if p.is_dir():
                shutil.rmtree(p, ignore_errors=True)
            else:
                # MMseqs DBs are a family of sibling files; best-effort
                # removal of each known suffix.
                for suffix in ("", ".dbtype", ".index", ".lookup", ".source"):
                    try:
                        os.remove(str(p) + suffix)
                    except OSError:
                        pass

    tmp.mkdir(parents=True, exist_ok=True)

    # createdb-mode 1 + shuffle 0 keep input order / avoid rewriting the FASTA.
    _run(
        [
            mmseqs,
            "createdb",
            str(fasta),
            str(seqdb),
            "--dbtype",
            _mmseqs_dbtype(args.seq_space),
            "--shuffle",
            "0",
            "--createdb-mode",
            "1",
            "--threads",
            str(int(args.threads)),
        ],
        env=env,
    )

    linclust_cmd = [
        mmseqs,
        "linclust",
        str(seqdb),
        str(clu),
        str(tmp),
        "--min-seq-id",
        str(float(args.min_seq_id)),
        "-c",
        str(float(args.coverage)),
        "--cov-mode",
        str(int(args.cov_mode)),
        "--cluster-mode",
        str(int(args.cluster_mode)),
        "--threads",
        str(int(args.threads)),
        "--max-seq-len",
        str(int(args.max_seq_len)),
        "--remove-tmp-files",
        "1" if args.remove_tmp_files else "0",
    ]
    if args.split_memory_limit:
        # Sanity-check the user-provided limit against system memory; an
        # over-large limit makes MMseqs under-split and can crash the run.
        mem_total = _mem_total_bytes()
        limit_bytes = _parse_mmseqs_bytes(args.split_memory_limit)
        if mem_total and limit_bytes and limit_bytes > mem_total:
            print(
                f"WARNING: --split-memory-limit={args.split_memory_limit} ({_format_bytes(limit_bytes)}) "
                f"exceeds system MemTotal ({_format_bytes(mem_total)}). "
                "MMseqs2 may under-split and crash; consider lowering it or leaving it empty.",
                file=sys.stderr,
                flush=True,
            )
        linclust_cmd += ["--split-memory-limit", str(args.split_memory_limit)]
    if args.kmer_per_seq_scale is not None:
        linclust_cmd += ["--kmer-per-seq-scale", str(float(args.kmer_per_seq_scale))]

    _run(linclust_cmd, env=env)
    # Emit the clustering as a 2-column TSV (representative, member).
    _run([mmseqs, "createtsv", str(seqdb), str(seqdb), str(clu), str(tsv)], env=env)
    print(f"Wrote cluster TSV: {tsv}")
341
+
342
+
343
def cmd_make_seq_cluster(args: argparse.Namespace) -> None:
    """Convert the MMseqs cluster TSV into a seq_id→cluster_id parquet.

    The TSV columns are (cluster_id, seq_id), i.e. the cluster is keyed by
    its representative sequence id.  DISTINCT guards against duplicate
    pairs in the TSV.
    """
    import duckdb

    tsv = Path(args.cluster_tsv)
    if not tsv.exists():
        raise SystemExit(f"Cluster TSV not found: {tsv}")
    out = Path(args.output_parquet)
    _ensure_output_parent(out)

    con = duckdb.connect()
    con.execute(f"PRAGMA threads={int(args.threads)};")
    con.execute("PRAGMA enable_progress_bar=true;")

    tsv_path = _sql_escape_path(str(tsv))
    out_path = _sql_escape_path(str(out))

    sql = f"""
    COPY (
        SELECT DISTINCT
            seq_id,
            cluster_id
        FROM read_csv(
            '{tsv_path}',
            delim='\\t',
            header=false,
            columns={{'cluster_id':'VARCHAR','seq_id':'VARCHAR'}}
        )
    )
    TO '{out_path}'
    (FORMAT PARQUET);
    """
    t0 = time.time()
    con.execute(sql)
    print(f"Wrote seq→cluster parquet: {out} (elapsed_s={time.time() - t0:.1f})")
377
+
378
+
379
def _write_cluster_split_parquet(
    con,
    *,
    cluster_split_path: Path,
    seed: int,
    val_frac: float,
) -> Dict[str, int]:
    """Greedily assign non-test clusters to 'val'/'train', streaming the
    cluster_id→split mapping to a parquet file.

    Requires temp tables ``cluster_flags`` (cluster_id, n_test, n_total, …)
    and ``cluster_counts`` (cluster_id, species_key, n) to exist on *con*.
    Only clusters with no held-out-species rows (n_test = 0) are considered.

    Clusters are visited in a seeded pseudo-random order (hash of
    cluster_id and *seed*).  A cluster is sent to 'val' while the running
    val row total is below ``val_frac`` of all eligible rows AND every
    species in the cluster still has more than one eligible cluster
    remaining — ensuring each species keeps at least one train cluster.

    Returns summary counters: total eligible rows, target/actual val rows,
    and train/val cluster counts.
    """
    import pyarrow as pa
    import pyarrow.parquet as pq

    cluster_split_path.parent.mkdir(parents=True, exist_ok=True)
    if cluster_split_path.exists():
        cluster_split_path.unlink()

    # Row total across clusters that contain no held-out-species rows.
    total_seen_rows = int(
        con.execute(
            "SELECT coalesce(sum(n_total), 0)::BIGINT FROM cluster_flags WHERE n_test = 0"
        ).fetchone()[0]
    )
    target_val_rows = int(total_seen_rows * float(val_frac))

    # Per-species count of still-assignable clusters; decremented each time
    # one of the species' clusters goes to val.
    species_remaining = {
        species_key: int(n_clusters)
        for species_key, n_clusters in con.execute(
            """
            SELECT
                cc.species_key,
                count(*)::BIGINT AS n_clusters
            FROM cluster_counts cc
            JOIN cluster_flags cf USING (cluster_id)
            WHERE cf.n_test = 0
            GROUP BY cc.species_key
            """
        ).fetchall()
    }

    # One row per (cluster, species).  ORDER BY rnd, cluster_id keeps all
    # rows of a cluster adjacent while randomizing cluster order by seed.
    cur = con.execute(
        f"""
        SELECT
            cf.cluster_id,
            cf.n_total,
            abs(hash(cf.cluster_id || ':{seed}')) AS rnd,
            cc.species_key
        FROM cluster_flags cf
        JOIN cluster_counts cc USING (cluster_id)
        WHERE cf.n_test = 0
        ORDER BY rnd, cf.cluster_id, cc.species_key
        """
    )

    writer = None  # ParquetWriter, created lazily on first flush
    batch_cluster_ids: List[str] = []
    batch_splits: List[str] = []
    val_rows = 0
    train_clusters = 0
    val_clusters = 0
    # State for the cluster currently being accumulated from the cursor.
    current_cluster: Optional[str] = None
    current_n_total = 0
    current_species: List[str] = []

    def flush_current() -> None:
        # Decide the split for the accumulated cluster, buffer the
        # assignment, and spill the buffer to parquet every 200k rows.
        nonlocal writer, val_rows, train_clusters, val_clusters
        nonlocal current_cluster, current_n_total, current_species
        if current_cluster is None:
            return
        can_val = (
            val_rows < target_val_rows
            and all(species_remaining.get(species_key, 0) > 1 for species_key in current_species)
        )
        split = "val" if can_val else "train"
        if can_val:
            for species_key in current_species:
                species_remaining[species_key] -= 1
            val_rows += int(current_n_total)
            val_clusters += 1
        else:
            train_clusters += 1

        batch_cluster_ids.append(current_cluster)
        batch_splits.append(split)
        if len(batch_cluster_ids) >= 200_000:
            table = pa.table({"cluster_id": batch_cluster_ids, "split": batch_splits})
            if writer is None:
                writer = pq.ParquetWriter(str(cluster_split_path), table.schema)
            writer.write_table(table)
            batch_cluster_ids.clear()
            batch_splits.clear()

    # Stream the cursor, grouping adjacent rows that share a cluster_id.
    while True:
        rows = cur.fetchmany(200_000)
        if not rows:
            break
        for cluster_id, n_total, _rnd, species_key in rows:
            cluster_id = str(cluster_id)
            species_key = str(species_key)
            if current_cluster is None:
                current_cluster = cluster_id
                current_n_total = int(n_total)
                current_species = [species_key]
                continue
            if cluster_id != current_cluster:
                flush_current()
                current_cluster = cluster_id
                current_n_total = int(n_total)
                current_species = [species_key]
                continue
            current_species.append(species_key)

    flush_current()
    # Flush the tail batch; if nothing was ever written, still emit an
    # empty parquet file with the expected schema.
    if batch_cluster_ids:
        table = pa.table({"cluster_id": batch_cluster_ids, "split": batch_splits})
        if writer is None:
            writer = pq.ParquetWriter(str(cluster_split_path), table.schema)
        writer.write_table(table)
    elif writer is None:
        empty = pa.table(
            {
                "cluster_id": pa.array([], type=pa.string()),
                "split": pa.array([], type=pa.string()),
            }
        )
        writer = pq.ParquetWriter(str(cluster_split_path), empty.schema)
        writer.write_table(empty)
    if writer is not None:
        writer.close()

    return {
        "nonheldout_total_rows": total_seen_rows,
        "target_val_rows": target_val_rows,
        "actual_val_rows": val_rows,
        "train_clusters": train_clusters,
        "val_clusters": val_clusters,
    }
512
+
513
+
514
def cmd_make_seq_split(args: argparse.Namespace) -> None:
    """Assign every seq_id to train/val/test/drop and write both mappings.

    Steps:
      1. heldout_species: species present in the held-out test glob.
      2. cluster_counts / cluster_flags: per-(cluster, species) row counts
         and per-cluster totals split into held-out (n_test) vs seen rows.
      3. Greedy cluster-level train/val assignment via
         _write_cluster_split_parquet (writes cluster_split parquet).
      4. Final per-seq_id labeling with leakage guards:
           - unclustered seq_ids -> drop
           - held-out-species rows -> test
           - seen-species rows in clusters that also contain held-out
             rows -> drop
           - exact-protein collisions: train/val rows sharing a protein
             with test -> drop; val rows sharing a protein with train -> drop
           - seq_ids mapping to multiple species or multiple splits -> drop
    """
    import duckdb

    seq_cluster = Path(args.seq_cluster_parquet)
    if not seq_cluster.exists():
        raise SystemExit(f"seq_cluster parquet not found: {seq_cluster}")

    out = Path(args.output_parquet)
    cluster_split = Path(args.cluster_split_parquet)
    _ensure_output_parent(out)
    _ensure_output_parent(cluster_split)

    con = duckdb.connect()
    con.execute(f"PRAGMA threads={int(args.threads)};")
    con.execute("PRAGMA enable_progress_bar=true;")

    input_sql = _duckdb_parquet_source(args.input_glob, int(args.limit_files))
    heldout_sql = _duckdb_parquet_source(args.heldout_test_glob, 0)
    seq_cluster_path = _sql_escape_path(str(seq_cluster))
    out_path = _sql_escape_path(str(out))

    seq_id = _seq_id_sql()
    species_key = _species_key_sql(args.species_key_mode, "taxon")
    protein_norm = _protein_norm_sql("protein_seq")

    # Species defining the held-out test set.
    con.execute(
        f"""
        CREATE TEMP TABLE heldout_species AS
        SELECT DISTINCT {species_key} AS species_key
        FROM {heldout_sql}
        WHERE {species_key} != '';
        """
    )

    # Row counts per (cluster, species) over valid input rows.
    con.execute(
        f"""
        CREATE TEMP TABLE cluster_counts AS
        WITH base AS (
            SELECT
                {seq_id} AS seq_id,
                {species_key} AS species_key
            FROM {input_sql}
            WHERE length(({seq_id})) > 1
              AND {species_key} != ''
        )
        SELECT
            sc.cluster_id,
            base.species_key,
            count(*)::BIGINT AS n
        FROM base
        JOIN read_parquet('{seq_cluster_path}') sc USING (seq_id)
        GROUP BY sc.cluster_id, base.species_key;
        """
    )

    # Per-cluster totals: held-out rows (n_test) vs seen rows (n_seen).
    con.execute(
        """
        CREATE TEMP TABLE cluster_flags AS
        SELECT
            cluster_id,
            sum(CASE WHEN species_key IN (SELECT species_key FROM heldout_species) THEN n ELSE 0 END)::BIGINT AS n_test,
            sum(CASE WHEN species_key NOT IN (SELECT species_key FROM heldout_species) THEN n ELSE 0 END)::BIGINT AS n_seen,
            sum(n)::BIGINT AS n_total,
            count(*)::BIGINT AS n_species
        FROM cluster_counts
        GROUP BY cluster_id;
        """
    )

    t0 = time.time()
    split_summary = _write_cluster_split_parquet(
        con,
        cluster_split_path=cluster_split,
        seed=int(args.seed),
        val_frac=float(args.val_frac),
    )
    print(
        "Cluster assignment summary: "
        f"train_clusters={split_summary['train_clusters']:,} "
        f"val_clusters={split_summary['val_clusters']:,} "
        f"target_val_rows={split_summary['target_val_rows']:,} "
        f"actual_val_rows={split_summary['actual_val_rows']:,} "
        f"(elapsed_s={time.time() - t0:.1f})"
    )

    cluster_split_path = _sql_escape_path(str(cluster_split))
    # Label each seq_id, then apply protein-level and ambiguity guards.
    con.execute(
        f"""
        COPY (
            WITH base AS (
                SELECT DISTINCT
                    {seq_id} AS seq_id,
                    {species_key} AS species_key,
                    {protein_norm} AS protein_norm
                FROM {input_sql}
                WHERE length(({seq_id})) > 1
                  AND {species_key} != ''
            ),
            joined AS (
                SELECT
                    base.seq_id,
                    base.species_key,
                    base.protein_norm,
                    sc.cluster_id
                FROM base
                LEFT JOIN read_parquet('{seq_cluster_path}') sc USING (seq_id)
            ),
            labeled AS (
                SELECT
                    j.seq_id,
                    j.species_key,
                    j.protein_norm,
                    CASE
                        WHEN j.cluster_id IS NULL THEN 'drop'
                        WHEN j.species_key IN (SELECT species_key FROM heldout_species) THEN 'test'
                        WHEN coalesce(cf.n_test, 0) > 0 THEN 'drop'
                        ELSE coalesce(cs.split, 'drop')
                    END AS split
                FROM joined j
                LEFT JOIN cluster_flags cf USING (cluster_id)
                LEFT JOIN read_parquet('{cluster_split_path}') cs USING (cluster_id)
            ),
            protein_flags AS (
                SELECT
                    protein_norm,
                    max(CASE WHEN split = 'test' THEN 1 ELSE 0 END) AS has_test,
                    max(CASE WHEN split = 'train' THEN 1 ELSE 0 END) AS has_train
                FROM labeled
                WHERE length(protein_norm) > 0
                GROUP BY protein_norm
            ),
            guarded AS (
                SELECT
                    l.seq_id,
                    l.species_key,
                    CASE
                        WHEN l.split = 'drop' THEN 'drop'
                        WHEN length(l.protein_norm) = 0 THEN l.split
                        WHEN coalesce(pf.has_test, 0) = 1 AND l.split IN ('train', 'val') THEN 'drop'
                        WHEN coalesce(pf.has_train, 0) = 1 AND l.split = 'val' THEN 'drop'
                        ELSE l.split
                    END AS split
                FROM labeled l
                LEFT JOIN protein_flags pf USING (protein_norm)
            ),
            dedup AS (
                SELECT
                    seq_id,
                    CASE
                        WHEN count(DISTINCT species_key) > 1 THEN 'drop'
                        WHEN count(DISTINCT split) > 1 THEN 'drop'
                        ELSE any_value(split)
                    END AS split
                FROM guarded
                GROUP BY seq_id
            )
            SELECT seq_id, split FROM dedup
        )
        TO '{out_path}'
        (FORMAT PARQUET);
        """
    )

    # Per-split row counts over the raw input (a seq_id may cover many rows).
    rows = con.execute(
        f"""
        WITH base AS (
            SELECT {seq_id} AS seq_id
            FROM {input_sql}
        )
        SELECT s.split, count(*)::BIGINT AS n_rows
        FROM base
        JOIN read_parquet('{out_path}') s USING (seq_id)
        GROUP BY s.split
        ORDER BY n_rows DESC;
        """
    ).fetchall()
    print("Split summary (rows):")
    for split, n in rows:
        print(f"  {split}\t{n:,}")

    print(f"Wrote cluster→split parquet: {cluster_split}")
    print(f"Wrote seq→split parquet: {out}")
696
+
697
+
698
def cmd_write_data_v3(args: argparse.Namespace) -> None:
    """Materialize the final train/val/test parquet shards.

    Joins raw input rows against the seq_id→split mapping and writes each
    split into hash-sharded, partition-by-shard parquet directories under
    *output_root*.  With --representatives-only, only one seq_id (the
    lexicographic min) per cluster is kept in each split.  Duplicate rows
    per seq_id are collapsed to one via QUALIFY row_number() = 1.
    """
    import duckdb

    seq_split = Path(args.seq_split_parquet)
    if not seq_split.exists():
        raise SystemExit(f"seq_split parquet not found: {seq_split}")
    seq_cluster = Path(args.seq_cluster_parquet)
    # The cluster mapping is only needed for representative selection.
    if args.representatives_only and not seq_cluster.exists():
        raise SystemExit(f"seq_cluster parquet not found: {seq_cluster}")

    out_root = Path(args.output_root)
    out_root.mkdir(parents=True, exist_ok=True)
    (out_root / "_work").mkdir(parents=True, exist_ok=True)

    # Refuse to clobber existing split outputs unless --overwrite was given.
    for split_dir in (out_root / "train", out_root / "val", out_root / "test"):
        if split_dir.exists():
            if not args.overwrite:
                raise SystemExit(f"Output split directory exists: {split_dir} (pass --overwrite)")
            shutil.rmtree(split_dir)
        split_dir.mkdir(parents=True, exist_ok=True)

    con = duckdb.connect()
    con.execute(f"PRAGMA threads={int(args.threads)};")
    con.execute("PRAGMA enable_progress_bar=true;")

    input_sql = _duckdb_parquet_source(args.input_glob, int(args.limit_files))
    seq_split_path = _sql_escape_path(str(seq_split))
    seq_cluster_path = _sql_escape_path(str(seq_cluster))
    seq_id = _seq_id_sql()

    num_shards = int(args.num_shards)
    if num_shards <= 0:
        raise SystemExit("--num-shards must be > 0")

    for split in ("train", "val", "test"):
        out_dir = _sql_escape_path(str(out_root / split))
        if args.representatives_only:
            # One representative seq_id per cluster within this split.
            target_seq_ids_sql = f"""
                SELECT min(s.seq_id) AS seq_id
                FROM read_parquet('{seq_split_path}') s
                JOIN read_parquet('{seq_cluster_path}') sc USING (seq_id)
                WHERE s.split = '{split}'
                GROUP BY sc.cluster_id
            """
        else:
            target_seq_ids_sql = f"""
                SELECT DISTINCT s.seq_id
                FROM read_parquet('{seq_split_path}') s
                WHERE s.split = '{split}'
            """
        sql = f"""
        COPY (
            WITH target_seq_ids AS (
                {target_seq_ids_sql}
            ),
            rows AS (
                SELECT
                    p.*,
                    abs(hash({seq_id})) % {num_shards} AS shard
                FROM {input_sql} p
                JOIN target_seq_ids t
                  ON t.seq_id = ({seq_id})
                QUALIFY row_number() OVER (PARTITION BY ({seq_id}) ORDER BY ({seq_id})) = 1
            )
            SELECT * FROM rows
        )
        TO '{out_dir}'
        (FORMAT PARQUET, PARTITION_BY (shard));
        """
        t0 = time.time()
        con.execute(sql)
        print(
            f"Wrote {split} parquets to {out_root / split} "
            f"representatives_only={bool(args.representatives_only)} "
            f"(elapsed_s={time.time() - t0:.1f})"
        )
774
+
775
+
776
def cmd_verify(args: argparse.Namespace) -> None:
    """Audit the split for leakage and fail if any invariant is violated.

    Recomputes heldout_species / cluster_counts / cluster_flags from the
    raw inputs, then checks:
      * no cluster spans more than one non-drop split,
      * no test rows belong to a non-held-out species,
      * every val species also appears in train,
      * no exact protein sequence is shared between train/val or train/test.
    Also collects informational audit counters (rows dropped by each
    guard, cross-species exact-protein duplicates) and optionally writes
    everything as a JSON report.  Exits via SystemExit on failure.
    """
    import duckdb

    seq_cluster = Path(args.seq_cluster_parquet)
    seq_split = Path(args.seq_split_parquet)
    if not seq_cluster.exists():
        raise SystemExit(f"seq_cluster parquet not found: {seq_cluster}")
    if not seq_split.exists():
        raise SystemExit(f"seq_split parquet not found: {seq_split}")

    con = duckdb.connect()
    con.execute(f"PRAGMA threads={int(args.threads)};")
    con.execute("PRAGMA enable_progress_bar=true;")

    input_sql = _duckdb_parquet_source(args.input_glob, int(args.limit_files))
    heldout_sql = _duckdb_parquet_source(args.heldout_test_glob, 0)
    seq_cluster_path = _sql_escape_path(str(seq_cluster))
    seq_split_path = _sql_escape_path(str(seq_split))

    seq_id = _seq_id_sql()
    species_key = _species_key_sql(args.species_key_mode, "taxon")
    protein_norm = _protein_norm_sql("protein_seq")

    # Rebuild the same temp tables cmd_make_seq_split used, so the audit is
    # independent of any state from the split run.
    con.execute(
        f"""
        CREATE TEMP TABLE heldout_species AS
        SELECT DISTINCT {species_key} AS species_key
        FROM {heldout_sql}
        WHERE {species_key} != '';
        """
    )
    con.execute(
        f"""
        CREATE TEMP TABLE cluster_counts AS
        WITH base AS (
            SELECT
                {seq_id} AS seq_id,
                {species_key} AS species_key
            FROM {input_sql}
            WHERE length(({seq_id})) > 1
              AND {species_key} != ''
        )
        SELECT
            sc.cluster_id,
            base.species_key,
            count(*)::BIGINT AS n
        FROM base
        JOIN read_parquet('{seq_cluster_path}') sc USING (seq_id)
        GROUP BY sc.cluster_id, base.species_key;
        """
    )
    con.execute(
        """
        CREATE TEMP TABLE cluster_flags AS
        SELECT
            cluster_id,
            sum(CASE WHEN species_key IN (SELECT species_key FROM heldout_species) THEN n ELSE 0 END)::BIGINT AS n_test,
            sum(CASE WHEN species_key NOT IN (SELECT species_key FROM heldout_species) THEN n ELSE 0 END)::BIGINT AS n_seen,
            sum(n)::BIGINT AS n_total,
            count(*)::BIGINT AS n_species
        FROM cluster_counts
        GROUP BY cluster_id;
        """
    )

    # Distinct seq_ids per split, and raw input rows per split.
    split_seq_ids = {
        split: int(n)
        for split, n in con.execute(
            f"""
            SELECT split, count(*)::BIGINT AS n
            FROM read_parquet('{seq_split_path}')
            GROUP BY split
            """
        ).fetchall()
    }
    split_rows = {
        split: int(n)
        for split, n in con.execute(
            f"""
            WITH base AS (
                SELECT {seq_id} AS seq_id FROM {input_sql}
            )
            SELECT s.split, count(*)::BIGINT AS n
            FROM base
            JOIN read_parquet('{seq_split_path}') s USING (seq_id)
            GROUP BY s.split
            """
        ).fetchall()
    }

    # Invariant: no cluster contributes to more than one non-drop split.
    bad_clusters = int(
        con.execute(
            f"""
            WITH keep AS (
                SELECT sc.cluster_id, ss.split
                FROM read_parquet('{seq_cluster_path}') sc
                JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
                WHERE ss.split != 'drop'
            )
            SELECT count(*)::BIGINT
            FROM (
                SELECT cluster_id
                FROM keep
                GROUP BY cluster_id
                HAVING count(DISTINCT split) > 1
            );
            """
        ).fetchone()[0]
    )
    print(f"clusters_spanning_splits(excluding drop) = {bad_clusters}")

    # Invariant: test rows only come from held-out species.
    bad_test = int(
        con.execute(
            f"""
            WITH base AS (
                SELECT {seq_id} AS seq_id, {species_key} AS species_key
                FROM {input_sql}
            )
            SELECT count(*)::BIGINT
            FROM base
            JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
            WHERE ss.split = 'test'
              AND base.species_key NOT IN (SELECT species_key FROM heldout_species);
            """
        ).fetchone()[0]
    )
    print(f"test_rows_with_seen_species = {bad_test}")

    # Invariant: every species seen in val also appears in train.
    bad_val_species = int(
        con.execute(
            f"""
            WITH base AS (
                SELECT DISTINCT {seq_id} AS seq_id, {species_key} AS species_key
                FROM {input_sql}
            ),
            labeled AS (
                SELECT base.species_key, ss.split
                FROM base
                JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
                WHERE ss.split IN ('train', 'val')
            ),
            train_species AS (SELECT DISTINCT species_key FROM labeled WHERE split = 'train'),
            val_species AS (SELECT DISTINCT species_key FROM labeled WHERE split = 'val')
            SELECT count(*)::BIGINT
            FROM (SELECT species_key FROM val_species EXCEPT SELECT species_key FROM train_species);
            """
        ).fetchone()[0]
    )
    print(f"val_species_not_in_train = {bad_val_species}")

    # Invariant: no exact normalized-protein overlap across train/val.
    protein_overlap_train_val = int(
        con.execute(
            f"""
            WITH base AS (
                SELECT DISTINCT {seq_id} AS seq_id, {protein_norm} AS protein_norm
                FROM {input_sql}
                WHERE length({protein_norm}) > 0
            ),
            labeled AS (
                SELECT base.protein_norm, ss.split
                FROM base
                JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
                WHERE ss.split IN ('train', 'val')
            ),
            train_p AS (SELECT DISTINCT protein_norm FROM labeled WHERE split = 'train'),
            val_p AS (SELECT DISTINCT protein_norm FROM labeled WHERE split = 'val')
            SELECT count(*)::BIGINT
            FROM (SELECT protein_norm FROM train_p INTERSECT SELECT protein_norm FROM val_p);
            """
        ).fetchone()[0]
    )
    # Invariant: no exact normalized-protein overlap across train/test.
    protein_overlap_train_test = int(
        con.execute(
            f"""
            WITH base AS (
                SELECT DISTINCT {seq_id} AS seq_id, {protein_norm} AS protein_norm
                FROM {input_sql}
                WHERE length({protein_norm}) > 0
            ),
            labeled AS (
                SELECT base.protein_norm, ss.split
                FROM base
                JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
                WHERE ss.split IN ('train', 'test')
            ),
            train_p AS (SELECT DISTINCT protein_norm FROM labeled WHERE split = 'train'),
            test_p AS (SELECT DISTINCT protein_norm FROM labeled WHERE split = 'test')
            SELECT count(*)::BIGINT
            FROM (SELECT protein_norm FROM train_p INTERSECT SELECT protein_norm FROM test_p);
            """
        ).fetchone()[0]
    )
    print(f"exact_protein_overlap_train_val = {protein_overlap_train_val}")
    print(f"exact_protein_overlap_train_test = {protein_overlap_train_test}")

    # Informational audit counters (do not fail verification).
    mixed_test_clusters = int(
        con.execute(
            "SELECT count(*)::BIGINT FROM cluster_flags WHERE n_test > 0 AND n_seen > 0"
        ).fetchone()[0]
    )
    # Protein groups present in both held-out and seen species.
    exact_holdout_seen_conflicts = int(
        con.execute(
            f"""
            WITH base AS (
                SELECT DISTINCT
                    {protein_norm} AS protein_norm,
                    {species_key} AS species_key
                FROM {input_sql}
                WHERE length({protein_norm}) > 0
                  AND {species_key} != ''
            )
            SELECT count(*)::BIGINT
            FROM (
                SELECT protein_norm
                FROM base
                GROUP BY protein_norm
                HAVING count(DISTINCT CASE WHEN species_key IN (SELECT species_key FROM heldout_species) THEN species_key END) > 0
                   AND count(DISTINCT CASE WHEN species_key NOT IN (SELECT species_key FROM heldout_species) THEN species_key END) > 0
            );
            """
        ).fetchone()[0]
    )
    # Seen-species rows dropped because their exact protein also occurs in
    # a held-out species.
    dropped_seen_rows_exact_holdout = int(
        con.execute(
            f"""
            WITH base AS (
                SELECT DISTINCT
                    {seq_id} AS seq_id,
                    {species_key} AS species_key,
                    {protein_norm} AS protein_norm
                FROM {input_sql}
                WHERE length(({seq_id})) > 1
                  AND {species_key} != ''
            ),
            conflict_proteins AS (
                SELECT protein_norm
                FROM base
                WHERE length(protein_norm) > 0
                GROUP BY protein_norm
                HAVING count(DISTINCT CASE WHEN species_key IN (SELECT species_key FROM heldout_species) THEN species_key END) > 0
                   AND count(DISTINCT CASE WHEN species_key NOT IN (SELECT species_key FROM heldout_species) THEN species_key END) > 0
            )
            SELECT count(*)::BIGINT
            FROM base
            JOIN conflict_proteins USING (protein_norm)
            JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
            WHERE ss.split = 'drop'
              AND base.species_key NOT IN (SELECT species_key FROM heldout_species);
            """
        ).fetchone()[0]
    )
    # Rows dropped because their exact protein also occurs in train.
    dropped_val_rows_exact_train = int(
        con.execute(
            f"""
            WITH base AS (
                SELECT DISTINCT
                    {seq_id} AS seq_id,
                    {protein_norm} AS protein_norm
                FROM {input_sql}
                WHERE length(({seq_id})) > 1
                  AND length({protein_norm}) > 0
            ),
            labeled AS (
                SELECT base.protein_norm, ss.split
                FROM base
                JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
            ),
            train_proteins AS (
                SELECT DISTINCT protein_norm
                FROM labeled
                WHERE split = 'train'
            )
            SELECT count(*)::BIGINT
            FROM base
            JOIN train_proteins USING (protein_norm)
            JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
            WHERE ss.split = 'drop';
            """
        ).fetchone()[0]
    )
    # Seen-species rows / seq_ids dropped because their cluster mixes
    # held-out and seen species.
    dropped_seen_rows_mixed = int(
        con.execute(
            f"""
            WITH base AS (
                SELECT {seq_id} AS seq_id, {species_key} AS species_key
                FROM {input_sql}
            )
            SELECT count(*)::BIGINT
            FROM base
            JOIN read_parquet('{seq_cluster_path}') sc USING (seq_id)
            JOIN cluster_flags cf USING (cluster_id)
            JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
            WHERE ss.split = 'drop'
              AND cf.n_test > 0
              AND base.species_key NOT IN (SELECT species_key FROM heldout_species);
            """
        ).fetchone()[0]
    )
    dropped_seen_seqids_mixed = int(
        con.execute(
            f"""
            WITH base AS (
                SELECT DISTINCT {seq_id} AS seq_id, {species_key} AS species_key
                FROM {input_sql}
            )
            SELECT count(*)::BIGINT
            FROM base
            JOIN read_parquet('{seq_cluster_path}') sc USING (seq_id)
            JOIN cluster_flags cf USING (cluster_id)
            JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
            WHERE ss.split = 'drop'
              AND cf.n_test > 0
              AND base.species_key NOT IN (SELECT species_key FROM heldout_species);
            """
        ).fetchone()[0]
    )
    # Exact proteins shared by multiple species, and across the held-out
    # boundary, regardless of split outcome.
    same_protein_multi_species = int(
        con.execute(
            f"""
            WITH base AS (
                SELECT DISTINCT
                    {protein_norm} AS protein_norm,
                    {species_key} AS species_key
                FROM {input_sql}
                WHERE length({protein_norm}) > 0
                  AND {species_key} != ''
            )
            SELECT count(*)::BIGINT
            FROM (
                SELECT protein_norm
                FROM base
                GROUP BY protein_norm
                HAVING count(DISTINCT species_key) > 1
            );
            """
        ).fetchone()[0]
    )
    same_protein_cross_holdout = int(
        con.execute(
            f"""
            WITH base AS (
                SELECT DISTINCT
                    {protein_norm} AS protein_norm,
                    {species_key} AS species_key
                FROM {input_sql}
                WHERE length({protein_norm}) > 0
                  AND {species_key} != ''
            )
            SELECT count(*)::BIGINT
            FROM (
                SELECT protein_norm
                FROM base
                GROUP BY protein_norm
                HAVING count(DISTINCT CASE WHEN species_key IN (SELECT species_key FROM heldout_species) THEN species_key END) > 0
                   AND count(DISTINCT CASE WHEN species_key NOT IN (SELECT species_key FROM heldout_species) THEN species_key END) > 0
            );
            """
        ).fetchone()[0]
    )

    # Structured report: inputs, counts, hard invariants, soft audit stats.
    report = {
        "parameters": {
            "input_glob": args.input_glob,
            "heldout_test_glob": args.heldout_test_glob,
            "seq_cluster_parquet": str(seq_cluster),
            "seq_split_parquet": str(seq_split),
            "seq_space": args.seq_space,
            "species_key_mode": args.species_key_mode,
            "limit_files": int(args.limit_files),
        },
        "split_seq_ids": split_seq_ids,
        "split_rows": split_rows,
        "verification": {
            "clusters_spanning_splits_excluding_drop": bad_clusters,
            "test_rows_with_seen_species": bad_test,
            "val_species_not_in_train": bad_val_species,
            "exact_protein_overlap_train_val": protein_overlap_train_val,
            "exact_protein_overlap_train_test": protein_overlap_train_test,
        },
        "audit": {
            "mixed_test_clusters": mixed_test_clusters,
            "exact_protein_cross_holdout_seen_groups": exact_holdout_seen_conflicts,
            "dropped_seen_rows_from_exact_protein_holdout_overlap": dropped_seen_rows_exact_holdout,
            "dropped_rows_from_exact_protein_train_overlap": dropped_val_rows_exact_train,
            "dropped_seen_rows_from_mixed_test_clusters": dropped_seen_rows_mixed,
            "dropped_seen_seqids_from_mixed_test_clusters": dropped_seen_seqids_mixed,
            "same_protein_multi_species_exact_matches": same_protein_multi_species,
            "same_protein_cross_holdout_species_exact_matches": same_protein_cross_holdout,
        },
    }

    if args.report_json:
        report_path = Path(args.report_json)
        report_path.parent.mkdir(parents=True, exist_ok=True)
        with open(report_path, "w", encoding="utf-8") as f:
            json.dump(report, f, indent=2, sort_keys=True)
        print(f"Wrote audit report: {report_path}")

    # Any non-zero hard invariant fails the run.
    if (
        bad_clusters != 0
        or bad_test != 0
        or bad_val_species != 0
        or protein_overlap_train_val != 0
        or protein_overlap_train_test != 0
    ):
        raise SystemExit("Verification FAILED (see counts above).")
    print("Verification OK.")
1183
+
1184
+
1185
def build_parser() -> argparse.ArgumentParser:
    """Build the multi-subcommand CLI for the data_v2 → data_v3 resplit pipeline.

    Each subparser registers its handler via ``set_defaults(func=...)`` so
    ``main`` can dispatch with ``args.func(args)``.  The ``all`` subcommand
    chains every stage with shared settings.
    """
    ap = argparse.ArgumentParser(
        description="Resplit data_v2 to data_v3_rebuild using MMseqs2 protein clustering."
    )
    sub = ap.add_subparsers(dest="cmd", required=True)

    # Stage 1: export sequences from parquet shards into a FASTA for MMseqs2.
    p = sub.add_parser("make-fasta", help="Generate MMseqs FASTA from parquet shards.")
    p.add_argument("--input-glob", type=str, default="data_v2/*/*.parquet")
    p.add_argument("--output-fasta", type=str, default="data_v3_rebuild/_work/mmseqs_input.fasta")
    p.add_argument("--seq-space", type=str, choices=["protein", "cds"], default="protein")
    p.add_argument(
        "--max-input-seq-len",
        type=int,
        default=0,
        help="Drop sequences longer than this from the MMseqs input FASTA (0=use seq-space default).",
    )
    p.add_argument("--threads", type=int, default=32)
    p.add_argument("--limit-files", type=int, default=0, help="Only read the first N parquet files (0=all)")
    p.add_argument("--limit-rows", type=int, default=0, help="Debug: limit number of rows written (0=all)")
    p.set_defaults(func=cmd_make_fasta)

    # Stage 2: cluster the FASTA with MMseqs2 createdb + linclust.
    p = sub.add_parser("mmseqs-cluster", help="Run MMseqs2 createdb+linclust and emit clustering TSV.")
    p.add_argument("--mmseqs", type=str, default=_default_mmseqs_path())
    p.add_argument("--fasta", type=str, default="data_v3_rebuild/_work/mmseqs_input.fasta")
    p.add_argument("--workdir", type=str, default="data_v3_rebuild/_work/mmseqs")
    p.add_argument("--seq-space", type=str, choices=["protein", "cds"], default="protein")
    p.add_argument("--threads", type=int, default=32)
    p.add_argument("--min-seq-id", type=float, default=0.90)
    p.add_argument("-c", "--coverage", type=float, default=0.80)
    p.add_argument("--cov-mode", type=int, default=2, help="2=enforce representative/query coverage")
    p.add_argument("--cluster-mode", type=int, default=2, help="2=greedy clustering by sequence length")
    p.add_argument("--max-seq-len", type=int, default=200000)
    p.add_argument(
        "--kmer-per-seq-scale",
        type=float,
        default=None,
        help="Optional MMseqs2 override; leave empty to use MMseqs defaults.",
    )
    p.add_argument("--split-memory-limit", type=str, default="", help="e.g. 120G (empty=use MMseqs default)")
    # Paired flags writing to one dest: --remove-tmp-files (default) vs --keep-tmp-files.
    g = p.add_mutually_exclusive_group()
    g.add_argument(
        "--remove-tmp-files",
        dest="remove_tmp_files",
        action="store_true",
        default=True,
        help="Remove MMseqs2 tmp files (default).",
    )
    g.add_argument(
        "--keep-tmp-files",
        dest="remove_tmp_files",
        action="store_false",
        help="Keep MMseqs2 tmp files.",
    )
    p.add_argument("--overwrite", action="store_true")
    p.set_defaults(func=cmd_mmseqs_cluster)

    # Stage 3: convert the clustering TSV into a seq_id→cluster_id parquet.
    p = sub.add_parser("make-seq-cluster", help="Convert MMseqs TSV to parquet mapping seq_id→cluster_id.")
    p.add_argument("--cluster-tsv", type=str, default="data_v3_rebuild/_work/mmseqs/clu.tsv")
    p.add_argument("--output-parquet", type=str, default="data_v3_rebuild/_work/seq_cluster.parquet")
    p.add_argument("--threads", type=int, default=32)
    p.set_defaults(func=cmd_make_seq_cluster)

    # Stage 4: assign each seq_id to train/val/test/drop from cluster + species constraints.
    p = sub.add_parser(
        "make-seq-split",
        help="Create seq_id→{train,val,test,drop} using cluster assignments and heldout species.",
    )
    p.add_argument("--input-glob", type=str, default="data_v2/*/*.parquet")
    p.add_argument("--heldout-test-glob", type=str, default="data_v2/test/*.parquet")
    p.add_argument("--species-key-mode", type=str, choices=["binomial", "taxon"], default="binomial")
    p.add_argument("--seq-cluster-parquet", type=str, default="data_v3_rebuild/_work/seq_cluster.parquet")
    p.add_argument("--cluster-split-parquet", type=str, default="data_v3_rebuild/_work/cluster_split.parquet")
    p.add_argument("--output-parquet", type=str, default="data_v3_rebuild/_work/seq_split.parquet")
    p.add_argument("--val-frac", type=float, default=0.01)
    p.add_argument("--seed", type=int, default=13)
    p.add_argument("--threads", type=int, default=32)
    p.add_argument("--limit-files", type=int, default=0, help="Only read the first N parquet files (0=all)")
    p.set_defaults(func=cmd_make_seq_split)

    # Stage 5: materialize the data_v3 parquet directories from the split mapping.
    p = sub.add_parser("write-data-v3", help="Write data_v3 parquet directories from seq_split mapping.")
    p.add_argument("--input-glob", type=str, default="data_v2/*/*.parquet")
    p.add_argument("--seq-cluster-parquet", type=str, default="data_v3_rebuild/_work/seq_cluster.parquet")
    p.add_argument("--seq-split-parquet", type=str, default="data_v3_rebuild/_work/seq_split.parquet")
    p.add_argument("--output-root", type=str, default="data_v3_rebuild")
    p.add_argument("--num-shards", type=int, default=256, help="Partition each split into N shards")
    p.add_argument("--threads", type=int, default=32)
    p.add_argument("--limit-files", type=int, default=0, help="Only read the first N parquet files (0=all)")
    # Paired flags writing to one dest: --representatives-only (default) vs --all-cluster-members.
    g = p.add_mutually_exclusive_group()
    g.add_argument(
        "--representatives-only",
        dest="representatives_only",
        action="store_true",
        default=True,
        help="Write only one representative seq_id per MMseqs cluster (default).",
    )
    g.add_argument(
        "--all-cluster-members",
        dest="representatives_only",
        action="store_false",
        help="Write all seq_ids assigned to the split instead of one representative per cluster.",
    )
    p.add_argument("--overwrite", action="store_true")
    p.set_defaults(func=cmd_write_data_v3)

    # Stage 6: audit the final split for leakage/species constraints.
    p = sub.add_parser("verify", help="Verify leakage/species constraints and write an audit report.")
    p.add_argument("--input-glob", type=str, default="data_v2/*/*.parquet")
    p.add_argument("--heldout-test-glob", type=str, default="data_v2/test/*.parquet")
    p.add_argument("--species-key-mode", type=str, choices=["binomial", "taxon"], default="binomial")
    p.add_argument("--seq-space", type=str, choices=["protein", "cds"], default="protein")
    p.add_argument("--seq-cluster-parquet", type=str, default="data_v3_rebuild/_work/seq_cluster.parquet")
    p.add_argument("--seq-split-parquet", type=str, default="data_v3_rebuild/_work/seq_split.parquet")
    p.add_argument("--report-json", type=str, default="data_v3_rebuild/_work/split_report.json")
    p.add_argument("--threads", type=int, default=32)
    p.add_argument("--limit-files", type=int, default=0, help="Only read the first N parquet files (0=all)")
    p.set_defaults(func=cmd_verify)

    # Convenience: run every stage end-to-end with shared settings.
    p = sub.add_parser(
        "all",
        help="Run the full pipeline: make-fasta → mmseqs-cluster → make-seq-cluster → make-seq-split → write-data-v3 → verify.",
    )
    p.add_argument("--input-glob", type=str, default="data_v2/*/*.parquet")
    p.add_argument("--heldout-test-glob", type=str, default="data_v2/test/*.parquet")
    p.add_argument("--output-root", type=str, default="data_v3_rebuild")
    p.add_argument("--seq-space", type=str, choices=["protein", "cds"], default="protein")
    p.add_argument("--species-key-mode", type=str, choices=["binomial", "taxon"], default="binomial")
    p.add_argument(
        "--max-input-seq-len",
        type=int,
        default=0,
        help="Drop sequences longer than this from the MMseqs input FASTA (0=use seq-space default).",
    )
    p.add_argument("--threads", type=int, default=32)
    p.add_argument("--limit-files", type=int, default=0, help="Only read the first N parquet files (0=all)")
    p.add_argument("--num-shards", type=int, default=256)
    g = p.add_mutually_exclusive_group()
    g.add_argument(
        "--representatives-only",
        dest="representatives_only",
        action="store_true",
        default=True,
        help="Write only one representative seq_id per MMseqs cluster (default).",
    )
    g.add_argument(
        "--all-cluster-members",
        dest="representatives_only",
        action="store_false",
        help="Write all seq_ids assigned to the split instead of one representative per cluster.",
    )
    p.add_argument("--mmseqs", type=str, default=_default_mmseqs_path())
    p.add_argument("--min-seq-id", type=float, default=0.90)
    p.add_argument("-c", "--coverage", type=float, default=0.80)
    p.add_argument("--cov-mode", type=int, default=2)
    p.add_argument("--cluster-mode", type=int, default=2)
    p.add_argument("--max-seq-len", type=int, default=200000)
    p.add_argument("--kmer-per-seq-scale", type=float, default=None)
    p.add_argument("--split-memory-limit", type=str, default="")
    p.add_argument("--val-frac", type=float, default=0.01)
    p.add_argument("--seed", type=int, default=13)
    p.add_argument("--overwrite", action="store_true")

    def _run_all(a: argparse.Namespace) -> None:
        """Handler for 'all': run each stage handler with a synthesized Namespace."""
        # Working paths shared by all stages, rooted at --output-root/_work.
        out_root = Path(a.output_root)
        work = out_root / "_work"
        fasta = work / "mmseqs_input.fasta"
        mmseqs_work = work / "mmseqs"
        cluster_tsv = mmseqs_work / "clu.tsv"
        seq_cluster = work / "seq_cluster.parquet"
        cluster_split = work / "cluster_split.parquet"
        seq_split = work / "seq_split.parquet"
        report_json = work / "split_report.json"

        cmd_make_fasta(
            argparse.Namespace(
                input_glob=a.input_glob,
                output_fasta=str(fasta),
                seq_space=a.seq_space,
                max_input_seq_len=a.max_input_seq_len,
                threads=a.threads,
                limit_files=a.limit_files,
                # No row limit in the full pipeline (debug-only knob).
                limit_rows=0,
            )
        )
        cmd_mmseqs_cluster(
            argparse.Namespace(
                mmseqs=a.mmseqs,
                fasta=str(fasta),
                workdir=str(mmseqs_work),
                seq_space=a.seq_space,
                threads=a.threads,
                min_seq_id=a.min_seq_id,
                coverage=a.coverage,
                cov_mode=a.cov_mode,
                cluster_mode=a.cluster_mode,
                max_seq_len=a.max_seq_len,
                kmer_per_seq_scale=a.kmer_per_seq_scale,
                split_memory_limit=a.split_memory_limit,
                # Always clean MMseqs tmp files when chaining the full pipeline.
                remove_tmp_files=True,
                overwrite=a.overwrite,
            )
        )
        cmd_make_seq_cluster(
            argparse.Namespace(
                cluster_tsv=str(cluster_tsv),
                output_parquet=str(seq_cluster),
                threads=a.threads,
            )
        )
        cmd_make_seq_split(
            argparse.Namespace(
                input_glob=a.input_glob,
                heldout_test_glob=a.heldout_test_glob,
                species_key_mode=a.species_key_mode,
                seq_cluster_parquet=str(seq_cluster),
                cluster_split_parquet=str(cluster_split),
                output_parquet=str(seq_split),
                val_frac=a.val_frac,
                seed=a.seed,
                threads=a.threads,
                limit_files=a.limit_files,
            )
        )
        cmd_write_data_v3(
            argparse.Namespace(
                input_glob=a.input_glob,
                seq_cluster_parquet=str(seq_cluster),
                seq_split_parquet=str(seq_split),
                output_root=str(out_root),
                num_shards=a.num_shards,
                threads=a.threads,
                limit_files=a.limit_files,
                representatives_only=a.representatives_only,
                overwrite=a.overwrite,
            )
        )
        cmd_verify(
            argparse.Namespace(
                input_glob=a.input_glob,
                heldout_test_glob=a.heldout_test_glob,
                species_key_mode=a.species_key_mode,
                seq_space=a.seq_space,
                seq_cluster_parquet=str(seq_cluster),
                seq_split_parquet=str(seq_split),
                report_json=str(report_json),
                threads=a.threads,
                limit_files=a.limit_files,
            )
        )

    p.set_defaults(func=_run_all)
    return ap
1434
+
1435
+
1436
def main(argv: Optional[List[str]] = None) -> int:
    """Parse CLI arguments and dispatch to the chosen subcommand.

    Args:
        argv: Optional argument list (defaults to ``sys.argv[1:]``).

    Returns:
        0 on success; subcommands signal failure by raising.
    """
    parser = build_parser()
    parsed = parser.parse_args(argv)
    # Every subparser registered its handler through set_defaults(func=...).
    parsed.func(parsed)
    return 0
1441
+
1442
+
1443
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())
sampling.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ """
3
+ Sampling script for generating codon sequences from trained CodonGPT models.
4
+ Inputs are prepared exactly like training:
5
+ - Species conditioning via SpeciesEmbeddingStore (fixed-size [B,Ds] or variable-length [B,Ls,Ds])
6
+ - Protein conditioning via raw AA strings (ESM-C tokenization happens inside the model)
7
+ """
8
+
9
+ import argparse
10
+ import logging
11
+ import json
12
+ from pathlib import Path
13
+ from typing import List, Optional, Union
14
+
15
+ import torch
16
+
17
+ from src.sampler import CodonSampler
18
+ from src.dataset import SpeciesEmbeddingStore
19
+
20
+ logging.basicConfig(
21
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
22
+ datefmt="%m/%d/%Y %H:%M:%S",
23
+ level=logging.INFO,
24
+ )
25
+ logger = logging.getLogger("codongpt.sample")
26
+
27
+
28
def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for codon-sequence sampling.

    Species and protein conditioning are both required at runtime (validated
    in main(), not here, so --help works without them).
    """
    p = argparse.ArgumentParser(description="Sample codon sequences from CodonGPT model")

    # Model
    p.add_argument("--model_path", "--model_dir", dest="model_path", type=str, required=True,
                   help="Path to trained model checkpoint dir")
    p.add_argument("--device", type=str, default="cuda", help="cuda or cpu")
    p.add_argument("--compile", action="store_true", help="torch.compile the model")

    # Species embeddings
    p.add_argument("--embeddings_dir", type=str, default=None,
                   help="Directory with precomputed variable-length species embeddings (optional; fallback to Qwen if missing/unknown)")
    p.add_argument("--strict_species_lookup", action="store_true",
                   help="When using --embeddings_dir, fail if any requested species name is not an exact key in species_vocab.json")
    p.add_argument("--taxonomy_db", type=str, default=None,
                   help="Optional path to taxonomy_database.json (from precompute) to enrich prompts")

    # Sampling batch size and count
    p.add_argument("--num_sequences", "--num_seq", "--num_samples", type=int, default=1, dest="num_sequences",
                   help="Number of sequences to generate in total")
    p.add_argument("--batch_size", type=int, default=None, help="Batch size for sampling loop")

    # Control mode and length
    p.add_argument("--control_mode", choices=["fixed", "variable"], default="fixed",
                   help="fixed: disallow EOS, generate exactly sequence_length codons; variable: allow EOS")
    p.add_argument("--sequence_length", type=int, default=None,
                   help="Number of CODONS to generate (used as max steps in variable mode). "
                        "If omitted and protein sequences are provided, set to min protein length.")

    # Conditioning (REQUIRED: species and protein)
    p.add_argument("--species", "--taxon", type=str, default=None, dest="species",
                   help="Species name (e.g., 'Homo sapiens'). Replicated if num_sequences>1.")
    p.add_argument("--species_list", type=str, nargs="+", default=None,
                   help="List of species names (must match num_sequences).")

    p.add_argument("--protein_seq", "--protein_sequence", type=str, default=None, dest="protein_seq",
                   help="Protein sequence (AA string). Replicated if num_sequences>1.")
    p.add_argument("--protein_file", type=str, default=None,
                   help="Path to FASTA-like file (each non-header line is a sequence). Must provide at least num_sequences.")

    # Sampling params
    p.add_argument("--temperature", type=float, default=1, help="Sampling temperature")
    p.add_argument("--top_k", type=int, default=50, help="Top-k")
    p.add_argument("--top_p", type=float, default=0.9, help="Top-p (nucleus)")
    p.add_argument("--enforce_translation", action="store_true", default=False,
                   help="Hard-mask codons to match the given protein AA at each position")
    p.add_argument("--seed", type=int, default=None)
    p.add_argument("--save_intermediate", action="store_true", help="Store intermediate token states")

    # Output
    p.add_argument("--output_file", type=str, default=None)
    p.add_argument("--output_format", type=str, default="fasta", choices=["fasta", "csv", "json"])

    # Misc
    p.add_argument("--quiet", action="store_true")
    return p.parse_args()
84
+
85
+
86
def load_protein_sequences(file_path: str) -> List[str]:
    """Load protein sequences from a FASTA-like file.

    Every non-empty line that does not start with '>' is treated as ONE
    complete amino-acid sequence; header lines are ignored and wrapped
    multi-line FASTA records are NOT merged.

    Args:
        file_path: Path to the FASTA-like text file.

    Returns:
        List of amino-acid sequence strings, in file order.
    """
    seqs: List[str] = []
    # Explicit encoding so parsing does not depend on the locale default
    # (matches the utf-8 convention used elsewhere in this repo's file I/O).
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line and not line.startswith(">"):
                seqs.append(line)
    return seqs
95
+
96
+
97
def setup_species_store(embeddings_dir: str) -> SpeciesEmbeddingStore:
    """Open the species-embedding store used for conditioning.

    The store is always opened with pooling="sequence" — no format guessing:
    if sequence-format (variable-length) embeddings were stored it picks
    them, otherwise the fixed-size format is used.
    """
    store = SpeciesEmbeddingStore(embeddings_dir, pooling="sequence")
    return store
101
+
102
+
103
def save_sequences(
    sequences: List[str],
    output_file: str,
    fmt: str,
    species: Optional[List[str]] = None,
    proteins: Optional[List[str]] = None,
    metadata: Optional[dict] = None,
):
    """Write generated DNA sequences to disk in FASTA, CSV, or JSON form.

    Args:
        sequences: Generated DNA strings.
        output_file: Destination path.
        fmt: "fasta", "csv", or "json" (any other value falls through to json).
        species: Optional per-sequence species names (aligned by index).
        proteins: Optional per-sequence target proteins; only their lengths
            are recorded in FASTA headers, full strings in csv/json.
        metadata: Optional run metadata, embedded in the JSON payload only.
    """
    if fmt == "fasta":
        # Explicit utf-8 so output does not depend on the locale default.
        with open(output_file, "w", encoding="utf-8") as f:
            for i, seq in enumerate(sequences):
                header = f">seq_{i}"
                if species and i < len(species):
                    header += f"|species={species[i]}"
                if proteins and i < len(proteins):
                    header += f"|protein_len={len(proteins[i])}"
                f.write(f"{header}\n{seq}\n")
        return

    if fmt == "csv":
        # Local import keeps pandas optional for fasta/json users.
        import pandas as pd
        data = {"sequence": sequences}
        if species:
            data["species"] = species[:len(sequences)]
        if proteins:
            data["protein_sequence"] = proteins[:len(sequences)]
        pd.DataFrame(data).to_csv(output_file, index=False)
        return

    # json (default for any other fmt value)
    payload = {"sequences": sequences, "metadata": metadata or {}}
    if species:
        payload["species"] = species[:len(sequences)]
    if proteins:
        payload["protein_sequences"] = proteins[:len(sequences)]
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)
140
+
141
+
142
def translate_dna_to_aa(dna_seq: str) -> str:
    """Translate a DNA coding sequence into amino acids (standard genetic code).

    Codons are read in-frame from position 0; a trailing partial codon is
    ignored. Unrecognized codons (e.g. lowercase or containing N) become
    'X', and stop codons map to '*'.
    """
    codon_table = {
        'TTT': 'F', 'TTC': 'F', 'TTA': 'L', 'TTG': 'L', 'TCT': 'S', 'TCC': 'S', 'TCA': 'S', 'TCG': 'S',
        'TAT': 'Y', 'TAC': 'Y', 'TAA': '*', 'TAG': '*', 'TGT': 'C', 'TGC': 'C', 'TGA': '*', 'TGG': 'W',
        'CTT': 'L', 'CTC': 'L', 'CTA': 'L', 'CTG': 'L', 'CCT': 'P', 'CCC': 'P', 'CCA': 'P', 'CCG': 'P',
        'CAT': 'H', 'CAC': 'H', 'CAA': 'Q', 'CAG': 'Q', 'CGT': 'R', 'CGC': 'R', 'CGA': 'R', 'CGG': 'R',
        'ATT': 'I', 'ATC': 'I', 'ATA': 'I', 'ATG': 'M', 'ACT': 'T', 'ACC': 'T', 'ACA': 'T', 'ACG': 'T',
        'AAT': 'N', 'AAC': 'N', 'AAA': 'K', 'AAG': 'K', 'AGT': 'S', 'AGC': 'S', 'AGA': 'R', 'AGG': 'R',
        'GTT': 'V', 'GTC': 'V', 'GTA': 'V', 'GTG': 'V', 'GCT': 'A', 'GCC': 'A', 'GCA': 'A', 'GCG': 'A',
        'GAT': 'D', 'GAC': 'D', 'GAA': 'E', 'GAG': 'E', 'GGT': 'G', 'GGC': 'G', 'GGA': 'G', 'GGG': 'G'
    }
    residues = []
    # Step through complete codons only; len(dna_seq) // 3 full triplets.
    for start in range(0, 3 * (len(dna_seq) // 3), 3):
        residues.append(codon_table.get(dna_seq[start:start + 3], 'X'))
    return ''.join(residues)
157
+
158
+
159
def report_token_accuracy(sequences: List[str], target_proteins: List[str]) -> None:
    """Log per-sequence amino-acid accuracy of generated DNA vs. target proteins.

    Each DNA string is translated with translate_dna_to_aa() and compared
    position-wise against its target (the LAST target is reused when fewer
    targets than sequences were given). Accuracy is computed over the
    overlapping prefix length only.

    Args:
        sequences: Generated DNA strings.
        target_proteins: Target amino-acid strings.
    """
    if not target_proteins:
        # Guard: the fallback target_proteins[-1] below would raise IndexError.
        logger.warning("report_token_accuracy: no target proteins given; skipping report")
        return
    for i, dna in enumerate(sequences):
        tgt = target_proteins[i] if i < len(target_proteins) else target_proteins[-1]
        gen_aa = translate_dna_to_aa(dna)
        L = min(len(gen_aa), len(tgt))
        if L == 0:
            acc = 0.0
            num = 0
            den = 0
        else:
            matches = sum(1 for a, b in zip(gen_aa[:L], tgt[:L]) if a == b)
            acc = matches / L
            num = matches
            den = L
        logger.info(f"AA token accuracy seq_{i+1}: {acc:.4f} ({num}/{den})")
170
+
171
+
172
def main():
    """CLI entrypoint: validate conditioning inputs, run batched sampling, save and report."""
    args = parse_args()

    if args.device == "cuda" and not torch.cuda.is_available():
        raise RuntimeError("CUDA requested but not available")

    if args.seed is not None:
        torch.manual_seed(int(args.seed))

    # Conditioning must be provided – same invariants as training
    have_species_names = bool(args.species_list) or bool(args.species)
    have_protein = bool(args.protein_file) or bool(args.protein_seq)
    if not have_species_names or not have_protein:
        raise ValueError("Sampling requires BOTH species (names) and protein sequence(s).")

    # Species names list
    if args.species_list:
        species_names = list(args.species_list)
    else:
        species_names = [str(args.species)]

    # Protein sequences list
    if args.protein_file:
        protein_sequences = load_protein_sequences(args.protein_file)
    else:
        protein_sequences = [str(args.protein_seq)]

    # Expand/reconcile counts: singletons are replicated; protein lists may
    # over-supply and are truncated, but species must match exactly.
    N = int(args.num_sequences)
    if len(species_names) == 1 and N > 1:
        species_names = species_names * N
    if len(protein_sequences) == 1 and N > 1:
        protein_sequences = protein_sequences * N

    if len(species_names) != N:
        raise ValueError(f"species count ({len(species_names)}) must equal num_sequences ({N})")
    if len(protein_sequences) < N:
        raise ValueError(f"protein sequences provided ({len(protein_sequences)}) less than num_sequences ({N})")
    if len(protein_sequences) > N:
        protein_sequences = protein_sequences[:N]

    # If no explicit sequence_length, use min protein length, so every sample has a valid AA at each fixed step
    if args.sequence_length is None:
        args.sequence_length = min(len(s) for s in protein_sequences)
        logger.info(f"Auto-set sequence_length to min protein length: {args.sequence_length} codons")

    if args.sequence_length <= 0:
        raise ValueError("sequence_length must be > 0")

    # Load species store if provided (preferred to exactly match training); unknown species will fallback to Qwen
    species_store = None
    if args.embeddings_dir:
        species_store = setup_species_store(args.embeddings_dir)
        logger.info(f"Loaded species store: {len(species_store.vocab)} species; Ds={species_store.Ds()}")
        if args.strict_species_lookup:
            # Exact-key lookup only; report up to 5 offending names.
            unknown = sorted({name for name in species_names if name not in species_store.vocab})
            if unknown:
                preview = ", ".join(repr(x) for x in unknown[:5])
                more = "" if len(unknown) <= 5 else f" ... (+{len(unknown) - 5} more)"
                raise ValueError(
                    "strict species lookup failed; these names are not exact keys in species_vocab.json: "
                    f"{preview}{more}"
                )

    sampler = CodonSampler(
        model_path=args.model_path,
        device=args.device,
        compile_model=bool(args.compile),
        species_store=species_store,
        taxonomy_db_path=args.taxonomy_db,
    )

    # Batch loop (default: everything in one batch when --batch_size is unset)
    batch_size = int(args.batch_size or N)
    all_sequences: List[str] = []
    all_intermediates = []

    total_batches = (N + batch_size - 1) // batch_size
    for start in range(0, N, batch_size):
        end = min(N, start + batch_size)
        bs = end - start
        batch_species = species_names[start:end]
        batch_proteins = protein_sequences[start:end]

        logger.info(f"Sampling batch {start//batch_size + 1}/{total_batches} (B={bs})")

        result = sampler.sample(
            num_sequences=bs,
            sequence_length=int(args.sequence_length),
            species=batch_species,
            protein_sequences=batch_proteins,
            control_mode=str(args.control_mode),
            temperature=float(args.temperature),
            top_k=int(args.top_k),
            top_p=float(args.top_p),
            seed=int(args.seed) if args.seed is not None else None,
            return_intermediate=bool(args.save_intermediate),
            progress_bar=not bool(args.quiet),
            enforce_translation=bool(args.enforce_translation),
        )

        seqs = result["sequences"]  # List[str]
        all_sequences.extend(seqs)
        if args.save_intermediate and "intermediate_states" in result:
            all_intermediates.append(result["intermediate_states"])

    logger.info(f"Generated {len(all_sequences)} sequences.")
    # Preview only the first few sequences to keep logs short.
    for i, seq in enumerate(all_sequences[:5]):
        logger.info(f"Sequence {i+1} ({len(seq)//3} codons): {seq[:60]}...")

    # Save outputs
    if args.output_file:
        meta = {
            "model_path": args.model_path,
            "temperature": args.temperature,
            "top_k": args.top_k,
            "top_p": args.top_p,
            "control_mode": args.control_mode,
            "sequence_length": int(args.sequence_length),
        }
        save_sequences(
            all_sequences,
            args.output_file,
            args.output_format,
            species=species_names,
            proteins=protein_sequences,
            metadata=meta,
        )
        logger.info(f"Saved sequences to {args.output_file}")

    # Report AA token accuracy when protein targets are given
    report_token_accuracy(all_sequences, protein_sequences)

    if args.save_intermediate and all_intermediates:
        # NOTE(review): Path(args.output_file) requires --output_file to be set
        # when --save_intermediate is used — confirm whether that combo should
        # be validated earlier.
        inter_file = Path(args.output_file).with_suffix("").as_posix() + "_intermediate.pt"
        torch.save(all_intermediates, inter_file)
        logger.info(f"Saved intermediate states to {inter_file}")

    logger.info("Sampling completed.")
311
+
312
+
313
if __name__ == "__main__":
    # Script entrypoint.
    main()
slurm/rebuild_data_v3_cpu.sbatch ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# CPU-only Slurm job that runs the full data_v2 -> data_v3 resplit pipeline
# (resplit_data_v3.py all). All knobs below are environment-overridable at
# submit time, e.g. sbatch --export=ALL,MODE=pilot ...
#SBATCH --partition=beacon
#SBATCH --qos=high
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=16
#SBATCH --mem=240G
#SBATCH --time=3-00:00:00
#SBATCH --job-name=data_v3_rebuild
#SBATCH --output=%x_%j.out
#SBATCH --error=%x_%j.err

set -euo pipefail

# Paths and pipeline defaults (each ${VAR:-default} is overridable via env).
REPO_ROOT=${REPO_ROOT:-/beacon-projects/codon-lm/HE-DLM}
PYTHON_BIN=${PYTHON_BIN:-/beacon-projects/codon-lm/miniconda3/envs/dna/bin/python}
MMSEQS_BIN=${MMSEQS_BIN:-$REPO_ROOT/MMseqs2/build/bin/mmseqs}
INPUT_GLOB=${INPUT_GLOB:-data_v2/*/*.parquet}
HELDOUT_GLOB=${HELDOUT_GLOB:-data_v2/test/*.parquet}
SEQ_SPACE=${SEQ_SPACE:-protein}
SPECIES_KEY_MODE=${SPECIES_KEY_MODE:-binomial}
MAX_INPUT_SEQ_LEN=${MAX_INPUT_SEQ_LEN:-20000}
MODE=${MODE:-full}
OUTPUT_ROOT=${OUTPUT_ROOT:-data_v3_rebuild}
LIMIT_FILES=${LIMIT_FILES:-0}
NUM_SHARDS=${NUM_SHARDS:-256}
VAL_FRAC=${VAL_FRAC:-0.01}
# Default thread count follows the Slurm allocation when present.
THREADS=${THREADS:-${SLURM_CPUS_PER_TASK:-16}}
MIN_SEQ_ID=${MIN_SEQ_ID:-0.90}
COVERAGE=${COVERAGE:-0.80}
COV_MODE=${COV_MODE:-2}
CLUSTER_MODE=${CLUSTER_MODE:-2}
MAX_SEQ_LEN=${MAX_SEQ_LEN:-200000}
SPLIT_MEMORY_LIMIT=${SPLIT_MEMORY_LIMIT:-180G}
SEED=${SEED:-13}
OVERWRITE=${OVERWRITE:-1}

# Pilot mode shrinks the run (fewer input files/shards, separate output dir),
# but only when the corresponding variable still holds its default.
if [[ "${MODE}" == "pilot" ]]; then
  if [[ "${LIMIT_FILES}" == "0" ]]; then
    LIMIT_FILES=4
  fi
  if [[ "${OUTPUT_ROOT}" == "data_v3_rebuild" ]]; then
    OUTPUT_ROOT=data_v3_pilot
  fi
  if [[ "${NUM_SHARDS}" == "256" ]]; then
    NUM_SHARDS=16
  fi
fi

cd "${REPO_ROOT}"

# Ensure the conda env's shared libraries are resolvable for duckdb/pyarrow.
export LD_LIBRARY_PATH="/beacon-projects/codon-lm/miniconda3/envs/dna/lib:/beacon-projects/codon-lm/miniconda3/lib:${LD_LIBRARY_PATH:-}"

# The checked-in mmseqs binary can lose its execute bit; restore it if needed.
if [[ -f "${MMSEQS_BIN}" && ! -x "${MMSEQS_BIN}" ]]; then
  chmod u+x "${MMSEQS_BIN}"
fi

# Fail fast if required Python deps are missing before the multi-day run.
"${PYTHON_BIN}" - <<'PY'
import duckdb
import pyarrow

print("duckdb", duckdb.__version__)
print("pyarrow", pyarrow.__version__)
PY

# Assemble the pipeline command as an array so arguments stay properly quoted.
CMD=(
  "${PYTHON_BIN}" resplit_data_v3.py all
  --input-glob "${INPUT_GLOB}"
  --heldout-test-glob "${HELDOUT_GLOB}"
  --output-root "${OUTPUT_ROOT}"
  --seq-space "${SEQ_SPACE}"
  --species-key-mode "${SPECIES_KEY_MODE}"
  --max-input-seq-len "${MAX_INPUT_SEQ_LEN}"
  --threads "${THREADS}"
  --limit-files "${LIMIT_FILES}"
  --num-shards "${NUM_SHARDS}"
  --mmseqs "${MMSEQS_BIN}"
  --min-seq-id "${MIN_SEQ_ID}"
  --coverage "${COVERAGE}"
  --cov-mode "${COV_MODE}"
  --cluster-mode "${CLUSTER_MODE}"
  --max-seq-len "${MAX_SEQ_LEN}"
  --val-frac "${VAL_FRAC}"
  --seed "${SEED}"
)

# Optional flags.
if [[ -n "${SPLIT_MEMORY_LIMIT}" ]]; then
  CMD+=(--split-memory-limit "${SPLIT_MEMORY_LIMIT}")
fi

if [[ "${OVERWRITE}" == "1" ]]; then
  CMD+=(--overwrite)
fi

# Echo the exact (shell-quoted) command for reproducibility, then run it.
printf 'Running command:'
printf ' %q' "${CMD[@]}"
printf '\n'
"${CMD[@]}"
slurm/submit_train_v3_h200_8x_chain.sh ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Submit SEGMENTS chained copies of the training sbatch script. Each segment
# depends (afterany) on the previous one, so a long run continues across
# walltime limits by resuming from the shared output directory.

set -euo pipefail

cd /beacon-projects/codon-lm/HE-DLM

SEGMENTS=${SEGMENTS:-3}
SBATCH_SCRIPT=${SBATCH_SCRIPT:-slurm/train_v3_h200_8x_single.sbatch}

if [[ ! -f "${SBATCH_SCRIPT}" ]]; then
  echo "Missing sbatch script: ${SBATCH_SCRIPT}" >&2
  exit 1
fi

prev_job=""
for segment in $(seq 1 "${SEGMENTS}"); do
  # First segment has no dependency; later ones wait on the prior job id.
  submit_args=(--parsable)
  if [[ -n "${prev_job}" ]]; then
    submit_args+=(--dependency=afterany:"${prev_job}")
  fi
  job_id=$(sbatch "${submit_args[@]}" "${SBATCH_SCRIPT}")
  echo "submitted segment=${segment} job_id=${job_id} dependency=${prev_job:-none}"
  prev_job="${job_id}"
done
slurm/train_v3_h200_8x_single.sbatch ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Single-node 8x H200 training entrypoint.
# Reserved single-node smoke-run example:
#   sbatch --time=00:45:00 \
#     --export=ALL,OUT_DIR=/beacon-projects/codon-lm/HE-DLM/outputs_v3_rep_h200_8x_reserved_smoke,MAX_STEPS=20,SAVE_STEPS=0,EVAL_INTERVAL=0 \
#     slurm/train_v3_h200_8x_single.sbatch
# Full-run example:
#   sbatch slurm/train_v3_h200_8x_single.sbatch
#
# Suggested W&B overrides:
#   sbatch --export=ALL,WANDB_PROJECT=he-dlm-v3-h200-8x,WANDB_NAME=he-dlm-v3-h200-8x-run1 \
#     slurm/train_v3_h200_8x_single.sbatch
# If the environment is still configured for offline logging, override at submit time:
#   sbatch --export=ALL,WANDB_MODE=online slurm/train_v3_h200_8x_single.sbatch
# This script is pinned to the reserved H200 allocation on ihccs210.
# Do not use QoS=reserved on any other node.
#SBATCH --job-name=train-v3-h200-8x
#SBATCH --partition=beacon
#SBATCH --qos=reserved
#SBATCH --reservation=heng-reservation
#SBATCH --nodelist=ihccs210
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --gres=gpu:nvidia_h200:8
#SBATCH --cpus-per-task=16
#SBATCH --mem=512G
#SBATCH --time=3-00:00:00
#SBATCH --output=%x_%j.out
#SBATCH --error=%x_%j.err

set -euo pipefail

# ~/.bashrc and conda activation reference unset variables; relax nounset briefly.
set +u
source ~/.bashrc
conda activate dna
set -u

cd /beacon-projects/codon-lm/HE-DLM

# Data and output locations (override via environment).
TRAIN_DATA=${TRAIN_DATA:-/beacon-projects/codon-lm/HE-DLM/data_v3_rebuild/train}
VAL_DATA=${VAL_DATA:-/beacon-projects/codon-lm/HE-DLM/data_v3_rebuild/val}
EMBED_DIR=${EMBED_DIR:-/beacon-projects/codon-lm/HE-DLM/embeddings_v2}
OUT_DIR=${OUT_DIR:-/beacon-projects/codon-lm/HE-DLM/outputs_v3_rep_h200_8x_single_wd1e-4_bs48ga4}

# W&B run identity defaults to the output-dir basename so chained resubmits
# resume the same run (WANDB_RESUME=allow).
WANDB_PROJECT=${WANDB_PROJECT:-he-dlm-v3-h200-8x}
WANDB_NAME=${WANDB_NAME:-$(basename "${OUT_DIR}")}
WANDB_RUN_ID=${WANDB_RUN_ID:-$(basename "${OUT_DIR}")}
WANDB_RESUME=${WANDB_RESUME:-allow}
WANDB_DIR=${WANDB_DIR:-${OUT_DIR}/wandb}

# Training/schedule knobs (all env-overridable).
NPROC_PER_NODE=${NPROC_PER_NODE:-8}
BATCH_SIZE=${BATCH_SIZE:-48}
GRAD_ACCUM=${GRAD_ACCUM:-4}
EVAL_BATCH_SIZE=${EVAL_BATCH_SIZE:-32}
WORKERS=${WORKERS:-0}
EPOCHS=${EPOCHS:-3}
LR=${LR:-7e-5}
WARMUP_RATIO=${WARMUP_RATIO:-0.1}
WEIGHT_DECAY=${WEIGHT_DECAY:-1e-4}
LOGGING_STEPS=${LOGGING_STEPS:-10}
SAVE_STEPS=${SAVE_STEPS:-500}
SAVE_TOTAL_LIMIT=${SAVE_TOTAL_LIMIT:-1000}
EVAL_INTERVAL=${EVAL_INTERVAL:-5000}
EVAL_STEPS=${EVAL_STEPS:-256}
TRAIN_SHUFFLE_BUFFER=${TRAIN_SHUFFLE_BUFFER:-8192}
VAL_SHUFFLE_BUFFER=${VAL_SHUFFLE_BUFFER:-0}
CKPT_RECENT_WINDOW_STEPS=${CKPT_RECENT_WINDOW_STEPS:-2000}
CKPT_RECENT_INTERVAL=${CKPT_RECENT_INTERVAL:-500}
CKPT_ARCHIVE_INTERVAL=${CKPT_ARCHIVE_INTERVAL:-1000}
RESUME_FROM=${RESUME_FROM:-auto}
MAX_STEPS=${MAX_STEPS:-}
MASTER_PORT=${MASTER_PORT:-29500}
GRAD_CKPT=${GRAD_CKPT:-0}

export WANDB_PROJECT WANDB_NAME WANDB_RUN_ID WANDB_RESUME WANDB_DIR
# NCCL / torch.distributed debug knobs; every value overridable at submit time.
export NCCL_DEBUG=${NCCL_DEBUG:-WARN}
export TORCH_DISTRIBUTED_DEBUG=${TORCH_DISTRIBUTED_DEBUG:-DETAIL}
export NCCL_P2P_DISABLE=${NCCL_P2P_DISABLE:-0}
export NCCL_IB_DISABLE=${NCCL_IB_DISABLE:-1}
export NCCL_NET_GDR_LEVEL=${NCCL_NET_GDR_LEVEL:-0}
export NCCL_ASYNC_ERROR_HANDLING=${NCCL_ASYNC_ERROR_HANDLING:-1}
export NCCL_SHM_DISABLE=${NCCL_SHM_DISABLE:-1}
export NCCL_CUMEM_HOST_ENABLE=${NCCL_CUMEM_HOST_ENABLE:-1}
export CUDA_DEVICE_MAX_CONNECTIONS=${CUDA_DEVICE_MAX_CONNECTIONS:-1}

mkdir -p "${OUT_DIR}" "${WANDB_DIR}"

# Fail fast on missing inputs before burning GPU time.
if [[ ! -d "${TRAIN_DATA}" ]]; then
  echo "Missing train data dir: ${TRAIN_DATA}" >&2
  exit 1
fi
if [[ ! -d "${VAL_DATA}" ]]; then
  echo "Missing val data dir: ${VAL_DATA}" >&2
  exit 1
fi
if [[ ! -f "${EMBED_DIR}/species_vocab.json" ]]; then
  echo "Missing embeddings vocab: ${EMBED_DIR}/species_vocab.json" >&2
  exit 1
fi

# Echo the effective configuration into the job log for reproducibility.
echo "HOST=$(hostname)"
echo "TRAIN_DATA=${TRAIN_DATA}"
echo "VAL_DATA=${VAL_DATA}"
echo "EMBED_DIR=${EMBED_DIR}"
echo "OUT_DIR=${OUT_DIR}"
echo "WANDB_PROJECT=${WANDB_PROJECT} WANDB_NAME=${WANDB_NAME} WANDB_RUN_ID=${WANDB_RUN_ID} WANDB_RESUME=${WANDB_RESUME} WANDB_MODE=${WANDB_MODE:-unset}"
echo "BATCH_SIZE=${BATCH_SIZE} GRAD_ACCUM=${GRAD_ACCUM} EVAL_BATCH_SIZE=${EVAL_BATCH_SIZE} NPROC_PER_NODE=${NPROC_PER_NODE}"
echo "WEIGHT_DECAY=${WEIGHT_DECAY} SAVE_STEPS=${SAVE_STEPS} EVAL_INTERVAL=${EVAL_INTERVAL} MAX_STEPS=${MAX_STEPS:-unset}"
echo "NCCL_P2P_DISABLE=${NCCL_P2P_DISABLE} NCCL_IB_DISABLE=${NCCL_IB_DISABLE} NCCL_SHM_DISABLE=${NCCL_SHM_DISABLE} NCCL_CUMEM_HOST_ENABLE=${NCCL_CUMEM_HOST_ENABLE}"

# GPU diagnostics are best-effort (|| true) so they never fail the job.
echo "=== GPU inventory ==="
nvidia-smi --query-gpu=index,name,memory.total,driver_version --format=csv,noheader || true
echo "=== GPU topology ==="
nvidia-smi topo -m || true
echo "=== NVLink status ==="
nvidia-smi nvlink -s || true

# Assemble the torchrun command as an array to preserve quoting.
CMD=(
  torchrun
  --standalone
  --nproc_per_node "${NPROC_PER_NODE}"
  --master_port "${MASTER_PORT}"
  train.py
  --train_data "${TRAIN_DATA}"
  --val_data "${VAL_DATA}"
  --embeddings_dir "${EMBED_DIR}"
  --output_dir "${OUT_DIR}"
  --fsdp
  --bf16
  --attn mha
  --hidden 750
  --layers 20
  --heads 15
  --mlp_ratio 3.2
  --batch_size "${BATCH_SIZE}"
  --grad_accum "${GRAD_ACCUM}"
  --eval_batch_size "${EVAL_BATCH_SIZE}"
  --epochs "${EPOCHS}"
  --workers "${WORKERS}"
  --warmup_ratio "${WARMUP_RATIO}"
  --lr "${LR}"
  --weight_decay "${WEIGHT_DECAY}"
  --train_shuffle_buffer "${TRAIN_SHUFFLE_BUFFER}"
  --val_shuffle_buffer "${VAL_SHUFFLE_BUFFER}"
  --logging_steps "${LOGGING_STEPS}"
  --save_steps "${SAVE_STEPS}"
  --save_total_limit "${SAVE_TOTAL_LIMIT}"
  --ckpt_recent_window_steps "${CKPT_RECENT_WINDOW_STEPS}"
  --ckpt_recent_interval "${CKPT_RECENT_INTERVAL}"
  --ckpt_archive_interval "${CKPT_ARCHIVE_INTERVAL}"
  --eval_interval "${EVAL_INTERVAL}"
  --eval_steps "${EVAL_STEPS}"
)

# Optional flags appended only when requested.
if [[ "${RESUME_FROM}" != "none" && -n "${RESUME_FROM}" ]]; then
  CMD+=(--resume_from "${RESUME_FROM}")
fi
if [[ -n "${MAX_STEPS}" ]]; then
  CMD+=(--max_steps "${MAX_STEPS}")
fi
if [[ "${GRAD_CKPT}" == "1" ]]; then
  CMD+=(--grad_ckpt)
fi

# exec replaces the shell so Slurm signals reach torchrun directly.
exec "${CMD[@]}"
src/__init__.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CodonGPT – conditional codon sequence generation (GPT-only).
3
+ """
4
+
5
+ from .tokenizer import CodonTokenizer
6
+ from .models import (
7
+ CodonGPT,
8
+ )
9
+ from .trainer import Trainer, TrainingArguments
10
+ from .sampler import CodonSampler, sample_sequences
11
+ from .dataset import (
12
+ stage_collate_fn,
13
+ create_precomputed_dataloaders,
14
+ )
15
+
16
+ __version__ = "0.1.0"
17
+
18
+ __all__ = [
19
+ # Tokenizer
20
+ "CodonTokenizer",
21
+ # Models
22
+ "CodonGPT",
23
+ # Training
24
+ "Trainer",
25
+ "TrainingArguments",
26
+ # Sampling
27
+ "CodonSampler",
28
+ "sample_sequences",
29
+ # Data
30
+ # "stage_collate_fn",
31
+ "create_precomputed_dataloaders",
32
+ # Noise
33
+ ]
src/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (637 Bytes). View file
 
src/__pycache__/dataset.cpython-312.pyc ADDED
Binary file (41 kB). View file
 
src/__pycache__/layers.cpython-312.pyc ADDED
Binary file (21.5 kB). View file
 
src/__pycache__/models.cpython-312.pyc ADDED
Binary file (25.8 kB). View file
 
src/__pycache__/sampler.cpython-312.pyc ADDED
Binary file (36.1 kB). View file
 
src/__pycache__/tokenizer.cpython-312.pyc ADDED
Binary file (17.1 kB). View file
 
src/__pycache__/trainer.cpython-312.pyc ADDED
Binary file (65.7 kB). View file
 
src/dataset.py ADDED
@@ -0,0 +1,833 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/dataset.py
2
+ """
3
+ Production-ready dataset + dataloader utilities.
4
+
5
+ Rules (because we're adults):
6
+ - Data drives design. Inputs are rows with columns: ["cds_DNA", "protein_seq", "Taxon", (optional) "RefseqID"].
7
+ - Output per sample is a tiny dict the model actually needs. Nothing else.
8
+ - We stream Parquet by row groups, CSV by chunks. No full-file pandas nonsense on big data.
9
+ - We shard by (FSDP rank × dataloader worker). No DistributedSampler needed.
10
+ - We do a simple streaming shuffle buffer for train. Good enough. No fancy "epoch managers".
11
+
12
+ Fields emitted per sample (for collate_fn and trainer):
13
+ {
14
+ "species_name": str,
15
+ "species_id": int,
16
+ "protein_seq": str, # raw AA (ESM tokenized later)
17
+ "aa_len": int,
18
+ "codon_ids": List[int], # tokenized 3-mer ids + EOS at the end
19
+ "refseq_id": str,
20
+ "protein_refseq_id": str,
21
+ "control_mode": "fixed",
22
+ "meta": {"src": "parquet|csv", "file": basename, "row": int}
23
+ }
24
+
25
+ Invariants:
26
+ - cds_DNA length divisible by 3 after trimming to match protein length.
27
+ - DNA uses only ACGT (uppercase). If not, we skip the row. We don't "helpfully fix" broken data.
28
+ - We truncate both DNA and protein to the same min length (codon count).
29
+ - EOS appended to codon_ids; PAD is handled at collate time, not here.
30
+
31
+ Dependencies:
32
+ - pyarrow only if you read parquet. If it isn't installed and you pass parquet files, we fail loudly.
33
+ """
34
+
35
+ from __future__ import annotations
36
+
37
+ import os
38
+ import json
39
+ import glob
40
+ import random
41
+ import logging
42
+ import heapq
43
+ from typing import Dict, List, Any, Optional, Iterable, Tuple
44
+ from pathlib import Path
45
+
46
+ import numpy as np
47
+ import pandas as pd
48
+ import torch
49
+ from torch.utils.data import IterableDataset, Dataset, DataLoader, get_worker_info
50
+
51
+ try:
52
+ from tqdm.auto import tqdm as _tqdm
53
+ except Exception: # pragma: no cover - tqdm might be unavailable in minimal envs
54
+ _tqdm = None
55
+
56
+ logger = logging.getLogger(__name__)
57
+
58
+ # ------------------------------
59
+ # Species Embedding Store (kept simple and stable)
60
+ # ------------------------------
61
+
62
class SpeciesEmbeddingStore:
    """
    Read-only store of precomputed species embeddings backed by np.memmap.

    Two on-disk layouts are supported:
      - New "fixed_size" layout: `species_metadata.json` + `species_embeddings.bin`
        holding one Ds-dim vector per species (row index == species id).
      - Legacy variable-length layout: `species_index.json` (per-species
        offset/length into a flat token stream) + `species_tok_emb.bin`.

    `species_vocab.json` (name -> id) is required in both layouts.
    Unknown/out-of-range ids resolve to deterministic zero stubs rather than raising.
    """

    def __init__(self, embeddings_dir: str, dtype: str = "float32", pin_memory: bool = False, pooling: str = "last"):
        """
        Open the store rooted at `embeddings_dir`.

        Args:
            embeddings_dir: directory containing vocab/metadata/embedding files.
            dtype: requested dtype ("float16" or "float32") for the memmap;
                the legacy path may override it from `metadata.json`.
            pin_memory: stored as a flag only; no pinning happens in this class.
            pooling: "sequence" prefers the legacy per-token layout when both
                layouts are present on disk.

        Raises:
            FileNotFoundError: if `species_vocab.json` (or the files required by
                the selected layout) is missing.
        """
        self.embeddings_dir = Path(embeddings_dir)
        self.pin_memory = bool(pin_memory)
        self.is_legacy = False
        self.pooling = pooling

        vocab_path = self.embeddings_dir / "species_vocab.json"
        if not vocab_path.exists():
            raise FileNotFoundError(f"Species vocabulary not found at {vocab_path}")
        with open(vocab_path, "r") as f:
            self.vocab: Dict[str, int] = json.load(f)

        meta_path = self.embeddings_dir / "species_metadata.json"
        new_emb_path = self.embeddings_dir / "species_embeddings.bin"
        legacy_index = self.embeddings_dir / "species_index.json"
        legacy_emb = self.embeddings_dir / "species_tok_emb.bin"

        # "sequence" pooling explicitly requests the per-token legacy layout
        # when it is available, even if the new layout also exists.
        if self.pooling == "sequence" and legacy_index.exists() and legacy_emb.exists():
            self.is_legacy = True
            self._load_legacy_format(dtype)
            return

        if meta_path.exists() and new_emb_path.exists():
            with open(meta_path, "r") as f:
                meta = json.load(f)
            self.num_species = int(meta["num_species"])
            self._ds = int(meta["embedding_dim"])
            self.embedding_type = str(meta.get("embedding_type", "fixed_size"))
            np_dtype = np.float16 if dtype == "float16" else np.float32
            # Memory-mapped [num_species, Ds] matrix; never fully loaded into RAM.
            self.embeddings = np.memmap(new_emb_path, dtype=np_dtype, mode="r", shape=(self.num_species, self._ds))
            self._np_dtype = np_dtype
            # NOTE(review): uses print() while the module defines a logger — consider logger.info.
            print(f"Loaded fixed-size species embeddings: {len(self.vocab)} species, Ds={self._ds}, dtype={self._np_dtype}")
        else:
            self.is_legacy = True
            self._load_legacy_format(dtype)

    def _load_legacy_format(self, dtype: str):
        """
        Load the legacy variable-length layout.

        Builds `self.index` (species-id string -> {"offset", "length"} into the
        flat token stream) and memmaps `species_tok_emb.bin` as [total_tokens, Ds].
        The file size must be an exact multiple of Ds * itemsize.
        """
        index_path = self.embeddings_dir / "species_index.json"
        if not index_path.exists():
            raise FileNotFoundError(f"Species index not found at {index_path}")
        with open(index_path, "r") as f:
            raw_index = json.load(f)
        self.index: Dict[str, Dict[str, int]] = {str(k): v for k, v in raw_index.items()}

        meta_path = self.embeddings_dir / "metadata.json"
        file_dtype = dtype
        if meta_path.exists():
            with open(meta_path, "r") as f:
                meta = json.load(f)
            self._ds = int(meta.get("embedding_dim", 1024))
            # The on-disk dtype wins over the caller's request for legacy files.
            file_dtype = str(meta.get("dtype", dtype)).lower()
        else:
            self._ds = 1024

        emb_path = self.embeddings_dir / "species_tok_emb.bin"
        if not emb_path.exists():
            raise FileNotFoundError(f"Species embeddings not found at {emb_path}")

        np_dtype = np.float16 if file_dtype == "float16" else np.float32
        itemsize = np.dtype(np_dtype).itemsize
        file_bytes = os.path.getsize(emb_path)
        if file_bytes % (self._ds * itemsize) != 0:
            raise ValueError(f"Emb file size {file_bytes} not divisible by Ds*itemsize ({self._ds}*{itemsize})")
        total_tokens = file_bytes // (self._ds * itemsize)

        self.embeddings = np.memmap(emb_path, dtype=np_dtype, mode="r", shape=(total_tokens, self._ds))
        self._np_dtype = np_dtype
        self.num_species = len(self.vocab)
        print(f"[LEGACY] variable-length embeddings: {len(self.vocab)} species, {total_tokens} tokens total, Ds={self._ds}.")

    def load_vocab(self) -> Dict[str, int]:
        """Return a defensive copy of the species name -> id vocabulary."""
        return self.vocab.copy()

    def _deterministic_stub(self, length: Optional[int] = None) -> torch.FloatTensor:
        """Zero-filled fallback embedding: [1, length, Ds] (legacy) or [1, Ds]."""
        if self.is_legacy and length:
            t = torch.zeros(1, length, self._ds, dtype=torch.float32)
        else:
            t = torch.zeros(1, self._ds, dtype=torch.float32)
        return t

    def get(self, species_id: int) -> torch.FloatTensor:
        """
        Fetch one species embedding as float32 with a leading batch dim.

        Returns [1, Ds] in the fixed-size layout, [1, L, Ds] in the legacy
        layout. Unknown or out-of-range ids return a zero stub (legacy stub
        uses L=8).
        """
        if not self.is_legacy:
            if species_id < 0 or species_id >= getattr(self, "num_species", 0):
                return self._deterministic_stub()
            emb = self.embeddings[species_id]
            # Copy out of the memmap so the returned tensor owns its storage.
            tensor = torch.from_numpy(np.asarray(emb).copy()).float().unsqueeze(0)
            return tensor
        else:
            sid = str(species_id)
            entry = self.index.get(sid)
            if entry is None:
                return self._deterministic_stub(length=8)
            offset = int(entry["offset"]); length = int(entry["length"])
            view = self.embeddings[offset: offset + length]
            tensor = torch.from_numpy(np.asarray(view).copy()).float().unsqueeze(0)
            return tensor

    def batch_get(self, species_ids: List[int]) -> Any:
        """
        Fetch embeddings for a batch of species ids.

        Fixed-size layout: returns a [B, Ds] float32 tensor.
        Legacy layout: returns (padded [B, Ls_max, Ds] tensor, lengths [B]) with
        zero right-padding. Accepts a tensor or any iterable of ints.
        """
        if torch.is_tensor(species_ids):
            species_ids = species_ids.detach().cpu().tolist()
        else:
            species_ids = [int(x) for x in species_ids]
        B = len(species_ids)
        if not self.is_legacy:
            batch_emb = torch.zeros(B, self._ds, dtype=torch.float32)
            for i, sid in enumerate(species_ids):
                batch_emb[i] = self.get(sid).squeeze(0)
            return batch_emb
        else:
            tensors = [self.get(sid) for sid in species_ids]
            lengths = torch.tensor([t.shape[1] for t in tensors], dtype=torch.long)
            Ls_max = int(lengths.max().item()) if lengths.numel() > 0 else 0
            padded = torch.zeros(B, Ls_max, self._ds, dtype=torch.float32)
            for i, t in enumerate(tensors):
                L = t.shape[1]; padded[i, :L] = t.squeeze(0)
            return padded, lengths

    def Ds(self) -> int:
        """Embedding dimensionality of the stored vectors."""
        return self._ds
182
+
183
+ def _is_parquet(path: str) -> bool:
184
+ lower = path.lower()
185
+ return lower.endswith(".parquet") or lower.endswith(".parq")
186
+
187
+
188
+ def _is_csv(path: str) -> bool:
189
+ lower = path.lower()
190
+ return (
191
+ lower.endswith(".csv")
192
+ or lower.endswith(".tsv")
193
+ or lower.endswith(".csv.gz")
194
+ or lower.endswith(".tsv.gz")
195
+ )
196
+
197
+
198
+ def _expand_paths(maybe_path_or_glob: str | List[str]) -> List[str]:
199
+ """
200
+ Expand a path/glob or list of them into a sorted, de-duplicated list of files.
201
+ We prioritize parquet, then csv/tsv.
202
+ """
203
+ paths: List[str] = []
204
+ if isinstance(maybe_path_or_glob, str):
205
+ p = Path(maybe_path_or_glob)
206
+ if p.is_dir():
207
+ # Scan directory for parquet first, then csv/tsv
208
+ paths.extend(sorted(str(x) for x in p.rglob("*.parquet")))
209
+ paths.extend(sorted(str(x) for x in p.rglob("*.parq")))
210
+ paths.extend(sorted(str(x) for x in p.rglob("*.csv")))
211
+ paths.extend(sorted(str(x) for x in p.rglob("*.tsv")))
212
+ paths.extend(sorted(str(x) for x in p.rglob("*.csv.gz")))
213
+ paths.extend(sorted(str(x) for x in p.rglob("*.tsv.gz")))
214
+ else:
215
+ paths = sorted(glob.glob(str(p)))
216
+ else:
217
+ for it in maybe_path_or_glob:
218
+ paths.extend(_expand_paths(it))
219
+ # Dedup while preserving order
220
+ seen = set()
221
+ out = []
222
+ for x in paths:
223
+ if x not in seen:
224
+ out.append(x)
225
+ seen.add(x)
226
+ if not out:
227
+ raise FileNotFoundError(f"No input files found for: {maybe_path_or_glob}")
228
+ return out
229
+
230
+
231
+ def _dist_info() -> Tuple[int, int]:
232
+ """
233
+ Returns (num_global_workers, global_worker_id)
234
+ where global_worker_id = rank * num_workers + worker_id.
235
+ """
236
+ world_size = 1
237
+ rank = 0
238
+ try:
239
+ import torch.distributed as dist
240
+
241
+ if dist.is_available() and dist.is_initialized():
242
+ world_size = dist.get_world_size()
243
+ rank = dist.get_rank()
244
+ except Exception:
245
+ pass
246
+ wi = get_worker_info()
247
+ nw = wi.num_workers if wi else 1
248
+ wid = wi.id if wi else 0
249
+ return world_size * nw, rank * nw + wid
250
+
251
+
252
+ class _ResumeSkipProgress:
253
+ """Lightweight progress helper for resume skips."""
254
+
255
+ def __init__(self, total: int, label: str):
256
+ self.total = int(max(0, total))
257
+ self.label = label
258
+ self.count = 0
259
+ self._bar = None
260
+
261
+ if self.total <= 0:
262
+ return
263
+
264
+ if _tqdm is not None:
265
+ self._bar = _tqdm(total=self.total, desc=label, unit="sample", dynamic_ncols=True, leave=False)
266
+ else:
267
+ logger.info("%s: skipping %d samples to reach resume cursor", label, self.total)
268
+
269
+ def update(self, n: int = 1):
270
+ if self.total <= 0:
271
+ return
272
+ self.count += int(n)
273
+ if self._bar is not None:
274
+ self._bar.update(n)
275
+ else:
276
+ if self.count == self.total or self.count % 10000 == 0:
277
+ logger.info("%s: skipped %d / %d", self.label, self.count, self.total)
278
+
279
+ def close(self):
280
+ if self.total <= 0:
281
+ return
282
+ if self._bar is not None:
283
+ self._bar.close()
284
+ logger.info("%s: resume skip finished (%d samples)", self.label, self.count)
285
+
286
+
287
class StreamSeqDataset(IterableDataset):
    """
    Streaming dataset with **non-overlapping Parquet row-group sharding**.

    - Accepts list of files (parquet and/or csv/tsv).
    - **Parquet**: we enumerate (file, row_group) tasks and stride them across
      the *global* worker id to avoid duplicates and to keep all ranks busy even
      with few files.
    - **CSV/TSV**: assigned at file granularity (one worker reads a file).
      If you have only a few CSV files and many ranks, some ranks may get no CSV work.
      (Parquet is the recommended format at scale.)
    - CSV is read with pandas chunksize to keep memory usage sane.
    - Each Parquet task reads exactly **one row group** into pandas.

    Minimal resume support:
    - set_resume_skip(N) skips N yielded samples across the worker's assigned tasks.
      (Use a **per-rank** skip value in your trainer so multi-node resumes stay in lockstep.)

    Output sample schema:
      {
        "species_name": str,
        "species_id": int,
        "protein_seq": str,          # raw AA (ESM tokenized later)
        "aa_len": int,
        "codon_ids": List[int],      # tokenized 3-mer ids + EOS at the end
        "refseq_id": str,
        "protein_refseq_id": str,
        "control_mode": "fixed",
        "meta": {"src": "parquet|csv", "file": basename, "row": int}
      }
    """

    # Canonical required columns. We also accept common aliases (e.g., 'taxon').
    REQUIRED = ["cds_DNA", "protein_seq", "Taxon"]

    def __init__(
        self,
        files: List[str],
        tokenizer,
        species_vocab_path: str,
        unknown_species_id: int = 0,
        csv_chunksize: int = 200_000,
        shuffle_buffer: int = 0,
        seed: int = 1234,
        shard_across_ranks: bool = True,
    ):
        """
        Args:
            files: concrete parquet/csv file paths (already expanded).
            tokenizer: codon tokenizer exposing `encode_codon_seq` and an EOS id
                via `special_ids.eos` (or legacy `_special_ids.eos`).
            species_vocab_path: JSON mapping species name -> integer id.
            unknown_species_id: id used for species missing from the vocab.
            csv_chunksize: pandas chunk size for CSV streaming (clamped to >= 1).
            shuffle_buffer: size of the streaming shuffle buffer (0 disables).
            seed: base RNG seed; offset by the global worker id per stream.
            shard_across_ranks: see comment below.
        """
        super().__init__()
        self.files = files
        self.tok = tokenizer
        with open(species_vocab_path, "r") as f:
            self.species_vocab: Dict[str, int] = json.load(f)
        self.unknown_species_id = int(unknown_species_id)
        self.csv_chunksize = int(max(1, csv_chunksize))
        self.shuffle_buffer = int(max(0, shuffle_buffer))
        self.seed = int(seed)
        # When False, every rank iterates over the full task list instead of
        # taking a disjoint shard. This keeps FSDP collectives aligned during
        # evaluation even if the validation dataset is smaller than WORLD_SIZE.
        self.shard_across_ranks = bool(shard_across_ranks)

        # Minimal resume cursor
        self._resume_skip_n: int = 0
        self._offset_start: int = 0
        self._emitted: int = 0

    # ---- resume cursor (minimal) ----
    def set_resume_skip(self, n: int) -> None:
        """Arm the stream to silently skip the next `n` samples (clamped to >= 0)."""
        n = int(max(0, n))
        self._resume_skip_n = n
        self._offset_start = n
        self._emitted = 0

    def get_stream_position(self) -> int:
        """Total samples yielded since dataset creation, including the initial skip offset."""
        return int(self._offset_start + self._emitted)

    # ---- core row-wise iterator on a pandas DataFrame ----
    def _iter_df(self, df: pd.DataFrame, src: str, file: str) -> Iterable[Dict[str, Any]]:
        """Validate/clean one DataFrame and yield per-row sample dicts (see class docstring)."""
        # Normalize common column aliases before validating.
        # Some shards use lowercase `taxon` instead of `Taxon`.
        if "Taxon" not in df.columns and "taxon" in df.columns:
            df = df.rename(columns={"taxon": "Taxon"})

        # Hard fail if required missing
        for c in self.REQUIRED:
            if c not in df.columns:
                raise ValueError(f"Input missing required column '{c}' in {file}")

        # Normalize & clean
        df = df[self.REQUIRED + ([c for c in ["RefseqID"] if c in df.columns])]
        df["Taxon"] = df["Taxon"].astype(str).str.strip()
        df["protein_seq"] = df["protein_seq"].astype(str).str.strip().str.upper()
        df["cds_DNA"] = df["cds_DNA"].astype(str).str.strip().str.upper()

        # Filter DNA: ACGT only and length > 0
        ok_mask = (df["cds_DNA"].str.len() > 0) & df["cds_DNA"].str.fullmatch(r"[ACGT]+", na=False)
        df = df[ok_mask]
        if df.empty:
            return

        # Trim protein/DNA to shared min length (in codons)
        cds_codons = (df["cds_DNA"].str.len() // 3).astype(int)
        prot_len = df["protein_seq"].str.len().astype(int)
        min_len = np.minimum(cds_codons.values, prot_len.values)

        # BUGFIX: `df.assign(__min_len=...)` would hit CPython private name
        # mangling (keyword identifiers inside a class body become
        # `_StreamSeqDataset__min_len`), mismatching the `"__min_len"` string
        # lookups below. Passing the name via **kwargs bypasses mangling and is
        # otherwise behavior-identical.
        df = df.assign(**{"__min_len": min_len})
        df = df[df["__min_len"] > 0]
        if df.empty:
            return

        # Species id map
        def map_species(x: str) -> int:
            try:
                return int(self.species_vocab.get(x, self.unknown_species_id))
            except Exception:
                return self.unknown_species_id

        species_ids = [map_species(x) for x in df["Taxon"].tolist()]
        refseq_col = "RefseqID" if "RefseqID" in df.columns else None

        for i, (row_idx, row) in enumerate(df.iterrows()):
            ml = int(row["__min_len"])
            cds = row["cds_DNA"][: ml * 3]
            prot = row["protein_seq"][: ml]
            if (len(cds) // 3) != len(prot):
                continue

            # Tokenize DNA → 3-mer ids; append EOS
            codon_ids = self.tok.encode_codon_seq(cds, validate=False)
            codon_ids.append(
                self.tok.special_ids.eos if hasattr(self.tok, "special_ids") else self.tok._special_ids.eos
            )

            species_id = species_ids[i]
            ref_id = row[refseq_col] if refseq_col else f"{Path(file).stem}:{int(row_idx)}"

            yield {
                "species_name": row["Taxon"],
                "species_id": int(species_id),
                "protein_seq": prot,
                "aa_len": len(prot),
                "codon_ids": codon_ids,
                "refseq_id": ref_id,
                "protein_refseq_id": ref_id,
                "control_mode": "fixed",
                "meta": {"src": src, "file": os.path.basename(file), "row": int(row_idx)},
            }

    # ---- Parquet helpers: enumerate row-group tasks & read one row group ----
    def _enumerate_tasks(self, files: List[str]) -> List[Tuple[str, str, Optional[int], int]]:
        """
        Return a task list of tuples:
          ("parquet", path, row_group_idx, weight) for each row group in each Parquet file
          ("csv", path, None, weight) for each CSV/TSV file

        `weight` is the (estimated) row count, clamped to >= 1, and is used by
        `_balanced_partition` and the resume fast-skip.
        """
        tasks: List[Tuple[str, str, Optional[int], int]] = []
        parquet_files = [f for f in files if _is_parquet(f)]
        csv_files = [f for f in files if _is_csv(f)]

        if parquet_files:
            try:
                import pyarrow.parquet as pq  # type: ignore
            except Exception as e:
                raise ImportError("pyarrow is required to read parquet files") from e

            for fp in parquet_files:
                pf = pq.ParquetFile(fp)
                nrg = int(pf.num_row_groups or 0)
                if nrg <= 0:
                    # Treat as single task if row groups unavailable (unusual)
                    total_rows = pf.metadata.num_rows if pf.metadata and pf.metadata.num_rows is not None else 1
                    tasks.append(("parquet", fp, 0, max(1, int(total_rows))))
                else:
                    for rg in range(nrg):
                        if pf.metadata is not None:
                            rg_meta = pf.metadata.row_group(rg)
                            num_rows = rg_meta.num_rows if rg_meta.num_rows is not None else 0
                        else:
                            num_rows = 0
                        tasks.append(("parquet", fp, rg, max(1, int(num_rows))))

        # CSV/TSV files remain file-level tasks
        for fp in csv_files:
            file_size = os.path.getsize(fp)
            # Assume ~256 bytes per record when estimating CSV row counts (empirical default)
            est_rows = max(1, int(file_size // 256))
            tasks.append(("csv", fp, None, est_rows))

        # Keep a deterministic order
        # (files are already sorted by _expand_paths)
        return tasks

    @staticmethod
    def _balanced_partition(tasks: List[Tuple[str, str, Optional[int], int]], groups: int) -> List[List[Tuple[str, str, Optional[int], int]]]:
        """
        Split `tasks` into `groups` buckets with roughly equal total weight.

        Greedy longest-processing-time heuristic: heaviest tasks first, each to
        the currently lightest bucket. Each bucket is returned in the tasks'
        original enumeration order so iteration stays deterministic.
        """
        if groups <= 1:
            return [tasks]
        if not tasks:
            return [[] for _ in range(groups)]

        # Greedy load balancing: assign heavier tasks first to the lightest bucket.
        indexed = [(idx, kind, path, rg, weight) for idx, (kind, path, rg, weight) in enumerate(tasks)]
        tasks_sorted = sorted(
            indexed,
            key=lambda entry: (entry[4], -entry[0]),
            reverse=True,
        )

        heap: List[Tuple[int, int]] = [(0, bucket_idx) for bucket_idx in range(groups)]
        heapq.heapify(heap)
        buckets: List[List[Tuple[int, str, str, Optional[int], int]]] = [[] for _ in range(groups)]

        for original_index, kind, path, rg, weight in tasks_sorted:
            load, bucket_idx = heapq.heappop(heap)
            buckets[bucket_idx].append((original_index, kind, path, rg, weight))
            heapq.heappush(heap, (load + weight, bucket_idx))

        partitions: List[List[Tuple[str, str, Optional[int], int]]] = []
        for bucket in buckets:
            bucket.sort(key=lambda entry: entry[0])
            partitions.append([(kind, path, rg, weight) for (_idx, kind, path, rg, weight) in bucket])
        return partitions

    def _parquet_rowgroup_iter(
        self, file: str, row_group_idx: int, cols_cache: Dict[str, List[str]]
    ) -> Iterable[Dict[str, Any]]:
        """Read exactly one Parquet row group into pandas and yield its samples."""
        import pyarrow.parquet as pq  # safe: checked in _enumerate_tasks
        pf = pq.ParquetFile(file)
        # Cache the column subset per file so we don't recompute
        if file not in cols_cache:
            names = set(pf.schema.names)
            cols: List[str] = []
            # Required columns, with alias support (notably Taxon vs taxon).
            for c in self.REQUIRED:
                if c in names:
                    cols.append(c)
                    continue
                if c == "Taxon" and "taxon" in names:
                    cols.append("taxon")
                    continue
            # Optional debug id
            if "RefseqID" in names:
                cols.append("RefseqID")
            cols_cache[file] = cols
        cols = cols_cache[file]
        table = pf.read_row_group(row_group_idx, columns=cols)
        df = table.to_pandas(types_mapper=None)
        yield from self._iter_df(df, "parquet", file)

    def _csv_file_iter(self, file: str) -> Iterable[Dict[str, Any]]:
        """Stream one CSV file in chunks and yield its samples."""
        # One worker owns this file (non-overlapping assignment)
        for chunk in pd.read_csv(file, chunksize=self.csv_chunksize, dtype=str, keep_default_na=False):
            yield from self._iter_df(chunk, "csv", file)

    # ---- main iterator ----
    def __iter__(self):
        """
        Yield samples for this (rank, dataloader-worker) shard.

        Applies the armed resume skip first at task granularity (using task
        weights), then row-by-row inside at most one partially-skipped task,
        and finally streams the remaining tasks through an optional shuffle
        buffer.
        """
        wi = get_worker_info()
        num_workers = wi.num_workers if wi else 1
        worker_id = wi.id if wi else 0

        num_global, gid = _dist_info()
        if not self.shard_across_ranks:
            # Evaluation mode: every rank sees the full stream; only local
            # dataloader workers split the work.
            num_global = max(1, num_workers)
            gid = worker_id

        workers_per_rank = max(1, num_workers)
        rank = gid // workers_per_rank if self.shard_across_ranks else 0
        world = max(1, num_global // workers_per_rank)

        # Each rank may have a non-zero per-rank resume skip. Split evenly across local
        # dataloader workers so the sum equals the per-rank target, then apply a fast
        # task-level skip to avoid row-by-row scans for huge cursors.
        per_rank_skip = int(self._resume_skip_n)
        base = per_rank_skip // max(1, workers_per_rank)
        rem = per_rank_skip % max(1, workers_per_rank)
        local_skip_target = base + (1 if worker_id < rem else 0)
        progress: Optional[_ResumeSkipProgress] = None

        # Build the global task list (parquet row groups + csv files) and shard by gid
        tasks = self._enumerate_tasks(self.files)

        if tasks:
            partitions = self._balanced_partition(tasks, max(1, num_global))
            my_tasks_full = partitions[gid] if gid < len(partitions) else []
        else:
            my_tasks_full = []

        if local_skip_target > 0 and worker_id == 0:
            label = (
                "resume skip" if world == 1 else f"resume skip (rank {rank}/{world})"
            )
            progress = _ResumeSkipProgress(local_skip_target, label)

        # Fast task-level skip: consume whole tasks when their weight is <= remaining skip
        # and only fall back to row-level skipping for the first partial task.
        skip_remaining = int(local_skip_target)
        start_idx = 0
        partial_task_idx = None
        partial_task_kind = None
        partial_task_path = None
        partial_task_rg = None
        if skip_remaining > 0 and my_tasks_full:
            for idx, (kind, path, rg, weight) in enumerate(my_tasks_full):
                w = int(weight) if weight is not None else 0
                if w <= 0:
                    continue
                if skip_remaining >= w:
                    skip_remaining -= w
                    start_idx = idx + 1
                    if progress is not None:
                        progress.update(w)
                else:
                    partial_task_idx = idx
                    partial_task_kind = kind
                    partial_task_path = path
                    partial_task_rg = rg
                    break

        # BUGFIX: the partial task (when it will be consumed by the dedicated
        # pre-loop below) must be excluded from the main task loop; previously
        # `start_idx` still pointed at it, so its samples were yielded twice.
        will_handle_partial = partial_task_idx is not None and skip_remaining > 0
        resume_start = (partial_task_idx + 1) if will_handle_partial else start_idx

        # Slice my task list to start after any fully-skipped (and pre-loop-handled) tasks
        my_tasks = [(kind, path, rg) for (kind, path, rg, _w) in my_tasks_full[resume_start:]]

        rng = random.Random(self.seed + gid)
        buffer: List[Dict[str, Any]] = []
        bufN = self.shuffle_buffer

        def _drain_buffer():
            # Flush whatever remains in the shuffle buffer (shuffled once).
            if not buffer:
                return
            if bufN > 0:
                rng.shuffle(buffer)
            for it in buffer:
                yield it
            buffer.clear()

        # Skip counter for resume cursor (row-level remainder after task skips)
        skipped = int(local_skip_target - skip_remaining)

        # Cache for per-file Parquet column selection
        cols_cache: Dict[str, List[str]] = {}

        try:
            # If we split a task, handle its partial row-level skip first
            if will_handle_partial:
                kind = partial_task_kind
                path = partial_task_path
                rg = partial_task_rg
                if kind == "parquet":
                    assert rg is not None
                    row_iter = self._parquet_rowgroup_iter(path, int(rg), cols_cache)
                elif kind == "csv":
                    row_iter = self._csv_file_iter(path)
                else:
                    raise ValueError(f"Unknown task kind: {kind}")

                for sample in row_iter:
                    if skip_remaining > 0:
                        skip_remaining -= 1
                        skipped += 1
                        if progress is not None:
                            progress.update(1)
                        if skip_remaining == 0 and progress is not None:
                            progress.close()
                            progress = None
                        continue
                    # past the partial skip remainder, fall through to normal buffering/yield
                    if bufN <= 0:
                        self._emitted += 1
                        yield sample
                    else:
                        buffer.append(sample)
                        if len(buffer) >= bufN:
                            j = rng.randrange(len(buffer))
                            buffer[j], buffer[-1] = buffer[-1], buffer[j]
                            self._emitted += 1
                            yield buffer.pop()

            for (kind, path, rg) in my_tasks:
                if kind == "parquet":
                    assert rg is not None
                    row_iter = self._parquet_rowgroup_iter(path, int(rg), cols_cache)
                elif kind == "csv":
                    row_iter = self._csv_file_iter(path)
                else:
                    raise ValueError(f"Unknown task kind: {kind}")

                for sample in row_iter:
                    # Apply any remaining resume skip across the flattened stream
                    if skip_remaining > 0:
                        skip_remaining -= 1
                        skipped += 1
                        if progress is not None:
                            progress.update(1)
                        if skip_remaining == 0 and progress is not None:
                            # Finish the progress bar once we've consumed the target
                            progress.close()
                            progress = None
                        continue

                    if bufN <= 0:
                        self._emitted += 1
                        yield sample
                    else:
                        buffer.append(sample)
                        if len(buffer) >= bufN:
                            j = rng.randrange(len(buffer))
                            buffer[j], buffer[-1] = buffer[-1], buffer[j]
                            self._emitted += 1
                            yield buffer.pop()

            # Flush leftovers
            for it in _drain_buffer():
                self._emitted += 1
                yield it
        finally:
            if progress is not None:
                progress.close()
            if local_skip_target > 0:
                # Persist any remaining leftover skip (including partial progress) per worker copy
                self._resume_skip_n = max(local_skip_target - skipped, 0)
705
+
706
+
707
+ # ------------------------------
708
+ # Simple collate: end-only pad for codon stream, pass-through everything else
709
+ # ------------------------------
710
+
711
def stage_collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Collate a list of sample dicts into one batch dict.

    Codon ids are right-padded with pad id 0 (the tokenizer's pad id); EOS is
    already appended per sample by the dataset. Protein sequences remain raw
    strings (ESM tokenizes them downstream). Returns {} for an empty batch.
    """
    if not batch:
        return {}

    n_samples = len(batch)

    # species ids as a long tensor
    species_ids = torch.tensor(
        [int(sample.get("species_id", 0)) for sample in batch], dtype=torch.long
    )

    # raw protein sequences stay as list[str] (ESM handles tokenization)
    protein_seqs = [str(sample.get("protein_seq", "M")) for sample in batch]

    # Right-pad codon id rows to the longest sequence in the batch (pad id 0).
    codon_rows = [sample.get("codon_ids", []) for sample in batch]
    longest = max(len(row) for row in codon_rows)
    codon_ids = torch.zeros(n_samples, longest, dtype=torch.long)
    for row_idx, row in enumerate(codon_rows):
        if row:
            codon_ids[row_idx, : len(row)] = torch.tensor(row, dtype=torch.long)

    collated: Dict[str, Any] = {
        "species_ids": species_ids,
        "protein_seqs": protein_seqs,
        "codon_ids": codon_ids,
        "control_mode": batch[0].get("control_mode", "fixed"),
    }

    # Optional passthroughs
    for key in ("refseq_id", "protein_refseq_id"):
        if key in batch[0]:
            collated[key] = [sample.get(key) for sample in batch]

    return collated
745
+
746
def _build_dataset(
    path_or_paths: str | List[str],
    tokenizer,
    species_vocab_path: str,
    shuffle_buffer: int,
    csv_chunksize: int,
    shard_across_ranks: bool = True,
) -> StreamSeqDataset:
    """Expand *path_or_paths* into concrete files and wrap them in a StreamSeqDataset."""
    return StreamSeqDataset(
        files=_expand_paths(path_or_paths),
        tokenizer=tokenizer,
        species_vocab_path=species_vocab_path,
        unknown_species_id=0,
        csv_chunksize=csv_chunksize,
        shuffle_buffer=shuffle_buffer,
        seed=1234,
        shard_across_ranks=shard_across_ranks,
    )
765
+
766
+
767
def create_precomputed_dataloaders(
    train_path: str | List[str],
    val_path: Optional[str | List[str]],
    embeddings_dir: str,
    tokenizer,
    batch_size: int,
    num_workers: int = 4,
    species_pooling: str = "sequence",
    csv_chunksize: int = 200_000,
    train_shuffle_buffer: int = 8192,
    val_shuffle_buffer: int = 0,
) -> Tuple[DataLoader, Optional[DataLoader], SpeciesEmbeddingStore]:
    """
    Build streaming train/val dataloaders plus the species embedding store.

    Returns:
        (train_loader, val_loader_or_None, species_store)
    """
    species_store = SpeciesEmbeddingStore(embeddings_dir, pin_memory=True, pooling=species_pooling)
    species_vocab_path = os.path.join(embeddings_dir, "species_vocab.json")
    workers = int(max(0, num_workers))

    def _make_dataset(paths, buffer_size):
        # Shared construction for train/val; only paths and shuffle buffer differ.
        return _build_dataset(
            path_or_paths=paths,
            tokenizer=tokenizer,
            species_vocab_path=species_vocab_path,
            shuffle_buffer=int(buffer_size),
            csv_chunksize=int(csv_chunksize),
        )

    train_ds = _make_dataset(train_path, train_shuffle_buffer)
    val_ds = _make_dataset(val_path, val_shuffle_buffer) if val_path else None

    # NOTE: IterableDataset can't be shuffled by DataLoader. We already "shuffle"
    # inside the dataset's streaming buffer.
    loader_kwargs = dict(
        num_workers=workers,
        collate_fn=stage_collate_fn,
        pin_memory=True,
        persistent_workers=(workers > 0),
    )
    if workers > 0:
        loader_kwargs["prefetch_factor"] = 4

    # Drop last for train to keep batch shapes stable under FSDP.
    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=False,
        drop_last=True,
        **loader_kwargs,
    )

    val_loader = (
        DataLoader(
            val_ds,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
            **loader_kwargs,
        )
        if val_ds is not None
        else None
    )

    return train_loader, val_loader, species_store
src/layers.py ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Transformer components for CodonGPT.
3
+ Includes RMSNorm, self-attention (SDPA/Flash) with optional mask,
4
+ cross-attention for conditioning memory, SwiGLU FFN, and a basic block.
5
+ """
6
+
7
+ import math
8
+ from typing import Optional, Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from torch.nn.attention import SDPBackend, sdpa_kernel # Require recent PyTorch
14
+
15
+
16
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean-centering, learned scale only)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Scale *x* by the reciprocal RMS of its last dimension.

        Args:
            x: Input tensor of any shape ending in dim

        Returns:
            Normalized tensor of same shape
        """
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        return x * inv_rms * self.weight
36
+
37
+
38
+ def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
39
+ """Apply rotary embeddings to x: [B,H,T,D]; cos/sin: [1,1,T,D]."""
40
+ x1 = x[..., ::2]
41
+ x2 = x[..., 1::2]
42
+ x_rot = torch.zeros_like(x)
43
+ x_rot[..., ::2] = -x2
44
+ x_rot[..., 1::2] = x1
45
+ return x * cos + x_rot * sin
46
+
47
+
48
+ class MultiHeadAttention(nn.Module):
49
+ """Self-attention using PyTorch SDPA kernels (Flash/MemEff/Math) + RoPE.
50
+ - attn_mask: bool [B, T, T] with True = keep, False = block
51
+ - is_causal: whether to apply causal masking internally
52
+ """
53
+
54
+ def __init__(
55
+ self,
56
+ dim: int,
57
+ num_heads: int,
58
+ dropout: float = 0.0,
59
+ use_rope: bool = True,
60
+ ):
61
+ super().__init__()
62
+ assert dim % num_heads == 0, f"dim {dim} must be divisible by num_heads {num_heads}"
63
+ self.dim = dim
64
+ self.num_heads = num_heads
65
+ self.head_dim = dim // num_heads
66
+ self.dropout = dropout
67
+ self.use_rope = use_rope
68
+
69
+ self.qkv = nn.Linear(dim, 3 * dim, bias=False)
70
+ self.out_proj = nn.Linear(dim, dim, bias=False)
71
+ self.resid_dropout = nn.Dropout(dropout)
72
+
73
+ # RoPE cache
74
+ self._rope_cache: dict[tuple[int, torch.device, torch.dtype], tuple[torch.Tensor, torch.Tensor]] = {}
75
+
76
+ def _rope_cos_sin(self, T: int, device: torch.device, dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]:
77
+ key = (T, device, dtype)
78
+ cached = self._rope_cache.get(key)
79
+ if cached is not None:
80
+ return cached
81
+ dim_half = self.head_dim // 2
82
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_half, device=device, dtype=torch.float32) / dim_half))
83
+ t = torch.arange(T, device=device, dtype=torch.float32)
84
+ freqs = torch.outer(t, inv_freq)
85
+ cos = torch.cos(freqs).repeat_interleave(2, dim=-1)
86
+ sin = torch.sin(freqs).repeat_interleave(2, dim=-1)
87
+ cos = cos.to(dtype).unsqueeze(0).unsqueeze(0) # [1,1,T,D]
88
+ sin = sin.to(dtype).unsqueeze(0).unsqueeze(0)
89
+ self._rope_cache[key] = (cos, sin)
90
+ return cos, sin
91
+
92
+ def forward(
93
+ self,
94
+ x: torch.Tensor,
95
+ past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
96
+ return_kv: bool = False,
97
+ position_offset: int = 0,
98
+ ) -> "torch.Tensor | Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]":
99
+ """
100
+ Self-attention with optional KV cache support.
101
+
102
+ Args:
103
+ x: [B, T_new, H]
104
+ past_kv: Optional tuple (k, v), each [B, nH, T_past, Hd]
105
+ return_kv: If True, also return updated (k, v)
106
+ position_offset: Starting position index for RoPE (past length)
107
+
108
+ Returns:
109
+ out or (out, present_kv)
110
+ """
111
+ B, T_new, _ = x.shape
112
+
113
+ # QKV projections and reshape (ensure contiguous for SDPA kernels)
114
+ qkv = self.qkv(x).view(B, T_new, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
115
+ q, k_new, v_new = qkv[0].contiguous(), qkv[1].contiguous(), qkv[2].contiguous()
116
+
117
+ # RoPE for new tokens only
118
+ if self.use_rope:
119
+ # Compute cos/sin up to (offset + T_new), then slice the tail for new positions
120
+ cos, sin = self._rope_cos_sin(position_offset + T_new, x.device, q.dtype)
121
+ if position_offset > 0:
122
+ cos = cos[:, :, position_offset: position_offset + T_new, :]
123
+ sin = sin[:, :, position_offset: position_offset + T_new, :]
124
+ # Apply to q and k_new
125
+ q = _apply_rope(q, cos, sin)
126
+ k_new = _apply_rope(k_new, cos, sin)
127
+
128
+ # Concatenate with cache if provided
129
+ if past_kv is not None:
130
+ k_past, v_past = past_kv
131
+ k = torch.cat([k_past, k_new], dim=2)
132
+ v = torch.cat([v_past, v_new], dim=2)
133
+ is_causal = False # No future tokens present; avoid unnecessary masking
134
+ else:
135
+ k, v = k_new, v_new
136
+ is_causal = True
137
+
138
+ # Prefer FlashAttention; fall back to MemEff then Math. Autocast to half/bfloat16 on CUDA.
139
+ backends = [SDPBackend.FLASH_ATTENTION]#, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
140
+ with sdpa_kernel(backends):
141
+ if x.device.type == "cuda" and q.dtype not in (torch.float16, torch.bfloat16):
142
+ amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
143
+ with torch.amp.autocast(device_type="cuda", dtype=amp_dtype):
144
+ out = F.scaled_dot_product_attention(
145
+ q, k, v,
146
+ dropout_p=self.dropout if self.training else 0.0,
147
+ is_causal=is_causal,
148
+ )
149
+ else:
150
+ out = F.scaled_dot_product_attention(
151
+ q, k, v,
152
+ dropout_p=self.dropout if self.training else 0.0,
153
+ is_causal=is_causal,
154
+ )
155
+
156
+ out = out.transpose(1, 2).contiguous().view(B, T_new, self.dim)
157
+ # Align dtype with residual/Linear weights to avoid bf16/float mismatches
158
+ if out.dtype != x.dtype:
159
+ out = out.to(x.dtype)
160
+ out = self.out_proj(out)
161
+ out = self.resid_dropout(out)
162
+
163
+ if return_kv:
164
+ return out, (k, v)
165
+ return out
166
+
167
+
168
+
169
class GroupedQueryAttention(nn.Module):
    """Grouped-Query Attention (GQA) using Flash Attention via PyTorch SDPA.

    - num_heads total query heads
    - num_kv_groups shared K/V groups (num_heads must be divisible by num_kv_groups)
    - Optional q/k RMSNorm
    - Supports RoPE with a scalar or per-sample position_offset (like MHA)
    - Optional KV cache compatible with the existing interface (stores expanded per-head K/V)

    NOTE(review): the cache stores the *expanded* per-head K/V (see forward), so
    caching forfeits GQA's memory saving during decode — presumably a deliberate
    trade-off to keep the cache interface identical to MultiHeadAttention.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_groups: int,
        dropout: float = 0.0,
        qk_norm: bool = False,
    ) -> None:
        super().__init__()
        # max(1, ...) guards the modulo against num_kv_groups == 0.
        assert num_heads % max(1, num_kv_groups) == 0, "num_heads must be divisible by num_kv_groups"
        self.dim = dim
        self.num_heads = int(num_heads)
        self.num_kv_groups = max(1, int(num_kv_groups))
        # How many query heads share each K/V group.
        self.group_size = self.num_heads // self.num_kv_groups

        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.head_dim = dim // num_heads
        self.dropout = dropout

        # K/V project to only num_kv_groups heads — the parameter saving of GQA.
        self.Wq = nn.Linear(dim, self.num_heads * self.head_dim, bias=False)
        self.Wk = nn.Linear(dim, self.num_kv_groups * self.head_dim, bias=False)
        self.Wv = nn.Linear(dim, self.num_kv_groups * self.head_dim, bias=False)
        self.out_proj = nn.Linear(self.num_heads * self.head_dim, dim, bias=False)

        # Optional QK-norm: RMSNorm over the per-head channel dim of q and k.
        self.q_norm = RMSNorm(self.head_dim) if qk_norm else None
        self.k_norm = RMSNorm(self.head_dim) if qk_norm else None

        # RoPE cache: cos/sin tables keyed by (seq_len, device, dtype).
        self._rope_cache: dict[tuple[int, torch.device, torch.dtype], tuple[torch.Tensor, torch.Tensor]] = {}

    def _rope_cos_sin(self, T: int, device: torch.device, dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]:
        """Return cached (cos, sin) RoPE tables of shape [1,1,T,head_dim]."""
        key = (T, device, dtype)
        cached = self._rope_cache.get(key)
        if cached is not None:
            return cached
        dim_half = self.head_dim // 2
        # Standard RoPE inverse-frequency schedule with base 10000.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_half, device=device, dtype=torch.float32) / dim_half))
        t = torch.arange(T, device=device, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)
        # Duplicate each frequency across the interleaved (even, odd) channel pair.
        cos = torch.cos(freqs).repeat_interleave(2, dim=-1)
        sin = torch.sin(freqs).repeat_interleave(2, dim=-1)
        cos = cos.to(dtype).unsqueeze(0).unsqueeze(0)  # [1,1,T,D]
        sin = sin.to(dtype).unsqueeze(0).unsqueeze(0)
        self._rope_cache[key] = (cos, sin)
        return cos, sin

    def forward(
        self,
        x: torch.Tensor,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        return_kv: bool = False,
        position_offset: int | torch.Tensor = 0,
    ) -> "torch.Tensor | Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]":
        """GQA forward.

        Args:
            x: [B, T_new, dim] new-token hidden states.
            past_kv: optional (k, v), each [B, num_heads, T_past, head_dim]
                (already expanded per-head — see note on the class).
            return_kv: also return the concatenated (k, v) for caching.
            position_offset: RoPE start position; an int shared by the batch,
                or a per-sample LongTensor of shape [B].

        Returns:
            out [B, T_new, dim], or (out, (k, v)) when return_kv is True.
        """
        B, T_new, _ = x.shape

        # Project to Q, K, V
        q = self.Wq(x).view(B, T_new, self.num_heads, self.head_dim).transpose(1, 2).contiguous()  # [B,H,T,Hd]
        k = self.Wk(x).view(B, T_new, self.num_kv_groups, self.head_dim).transpose(1, 2).contiguous()  # [B,G,T,Hd]
        v = self.Wv(x).view(B, T_new, self.num_kv_groups, self.head_dim).transpose(1, 2).contiguous()  # [B,G,T,Hd]

        # Optional RMSNorm on q/k
        if self.q_norm is not None:
            q = self.q_norm(q)
        if self.k_norm is not None:
            k = self.k_norm(k)

        # RoPE for new tokens only (the cached past was already rotated when it was new)
        if isinstance(position_offset, int):
            # Scalar offset shared by the whole batch: slice the tail of the table.
            cos, sin = self._rope_cos_sin(position_offset + T_new, x.device, q.dtype)
            if position_offset > 0:
                cos = cos[:, :, position_offset: position_offset + T_new, :]
                sin = sin[:, :, position_offset: position_offset + T_new, :]
            q = _apply_rope(q, cos, sin)
            k = _apply_rope(k, cos, sin)
        else:
            # Per-sample offsets: gather each sample's positions from one shared table.
            off = position_offset.to(device=x.device, dtype=torch.long)
            max_off = int(off.max().item())
            cos_all, sin_all = self._rope_cos_sin(max_off + T_new, x.device, q.dtype)
            ar = torch.arange(T_new, device=x.device, dtype=torch.long)
            idx = (off.unsqueeze(1) + ar.unsqueeze(0))  # [B, T_new]
            cos_b = cos_all.squeeze(0).squeeze(0)[idx].unsqueeze(1)  # [B,1,T,D]
            sin_b = sin_all.squeeze(0).squeeze(0)[idx].unsqueeze(1)
            q = _apply_rope(q, cos_b, sin_b)
            # k has groups dimension [B,G,T,D]; share same offsets per batch
            k = _apply_rope(k, cos_b, sin_b)

        # Expand grouped K/V to per-head by repeating groups
        if self.group_size > 1:
            k_exp = k.repeat_interleave(self.group_size, dim=1)  # [B,H,T,Hd]
            v_exp = v.repeat_interleave(self.group_size, dim=1)  # [B,H,T,Hd]
        else:
            k_exp, v_exp = k, v  # already per-head

        # KV cache: concatenate past along sequence dim
        if past_kv is not None:
            k_past, v_past = past_kv
            k_cat = torch.cat([k_past, k_exp], dim=2)
            v_cat = torch.cat([v_past, v_exp], dim=2)
            # Incremental decode: new queries see the whole history, no mask needed.
            is_causal = False
        else:
            k_cat, v_cat = k_exp, v_exp
            is_causal = True

        # Prefer FlashAttention; fall back to MemEff/Math. Ensure CUDA autocast to half/bfloat16 so kernels are available
        with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
            if x.device.type == "cuda" and q.dtype not in (torch.float16, torch.bfloat16):
                amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
                with torch.amp.autocast(device_type="cuda", dtype=amp_dtype):
                    out = torch.nn.functional.scaled_dot_product_attention(
                        q, k_cat, v_cat,
                        dropout_p=self.dropout if self.training else 0.0,
                        is_causal=is_causal,
                    )  # [B,H,T,Hd]
            else:
                out = torch.nn.functional.scaled_dot_product_attention(
                    q, k_cat, v_cat,
                    dropout_p=self.dropout if self.training else 0.0,
                    is_causal=is_causal,
                )  # [B,H,T,Hd]

        out = out.transpose(1, 2).contiguous().view(B, T_new, self.num_heads * self.head_dim)
        # Ensure dtype compatibility for Linear / residual path
        if out.dtype != x.dtype:
            out = out.to(x.dtype)
        out = self.out_proj(out)

        if return_kv:
            return out, (k_cat, v_cat)
        return out
308
+
309
+
310
+
311
class FeedForward(nn.Module):
    """SwiGLU feed-forward network: w2(silu(w1·x) * w3·x), then dropout."""

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        dropout: float = 0.0,
    ):
        super().__init__()

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the SwiGLU feed-forward transform.

        Args:
            x: Input tensor [B, T, dim]

        Returns:
            Output tensor [B, T, dim]
        """
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.dropout(self.w2(gate * value))
340
+
341
+
342
class TransformerBlock(nn.Module):
    """Pre-norm Transformer block using self-attn + SwiGLU FFN (no cross-attention)."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        num_kv_groups: int | None = None,
        qk_norm: bool = False,
        attn_type: str = "gqa",  # "gqa" or "mha"
    ):
        super().__init__()
        self.norm1 = RMSNorm(dim)
        if attn_type == "mha":
            self.attn = MultiHeadAttention(dim=dim, num_heads=num_heads, dropout=dropout)
            self._attn_is_gqa = False
        else:
            # GQA path; num_kv_groups=None degenerates to one group per head (no sharing).
            groups = num_heads if (num_kv_groups is None) else max(1, int(num_kv_groups))
            self.attn = GroupedQueryAttention(dim=dim, num_heads=num_heads, num_kv_groups=groups, dropout=dropout, qk_norm=qk_norm)
            self._attn_is_gqa = True
        self.norm2 = RMSNorm(dim)
        self.ffn = FeedForward(dim=dim, hidden_dim=int(dim * mlp_ratio), dropout=dropout)

    def forward(
        self,
        x: torch.Tensor,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
        position_offset: int = 0,
    ) -> "torch.Tensor | Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]":
        """Residual pre-norm attention + FFN; returns (x, present_kv) when caching."""
        caching = use_cache or (past_kv is not None)
        if not caching:
            x = x + self.attn(self.norm1(x))
            return x + self.ffn(self.norm2(x))
        attn_out, present_kv = self.attn(
            self.norm1(x), past_kv=past_kv, return_kv=True, position_offset=position_offset
        )
        x = x + attn_out
        x = x + self.ffn(self.norm2(x))
        return x, present_kv
src/models.py ADDED
@@ -0,0 +1,490 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Core model architectures for CodonGPT (GPT-only).
3
+ - CodonGPT: Decoder-only GPT with two-species + protein prefix
4
+ Includes a frozen ESM-C encoder for protein conditioning.
5
+ """
6
+
7
+ import math
8
+ import os
9
+ from typing import Optional, Dict, Any, Tuple, List
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import torch.utils.checkpoint as checkpoint
14
+ import torch.nn.utils.rnn as rnn_utils
15
+
16
+ from .layers import RMSNorm, TransformerBlock
17
+ from .tokenizer import SpecialIds
18
+
19
+
20
class FrozenESMCEncoder(nn.Module):
    """
    Frozen ESM-C encoder that computes protein embeddings on the fly.
    Kept on single GPU per rank (not distributed via FSDP).

    All parameters are frozen after loading; inference helpers are @torch.no_grad().
    """

    def __init__(self, model_name: str = "esmc_300m", device: str = "cuda", dtype: str = "fp16"):
        super().__init__()
        self.model_name = model_name
        # Fall back to CPU when CUDA is unavailable, regardless of the requested device.
        self._device = torch.device(device if torch.cuda.is_available() else "cpu")
        # Autocast dtype used only on CUDA; any other string disables autocast.
        if dtype == "fp16":
            self._autocast_dtype = torch.float16
        elif dtype == "bf16":
            self._autocast_dtype = torch.bfloat16
        else:
            self._autocast_dtype = None
        self._load_model()
        self.eval()
        # Freeze every parameter — this encoder is never trained.
        for p in self.parameters():
            p.requires_grad_(False)

    def _load_model(self):
        """Load the pretrained ESM-C checkpoint and record its embedding width."""
        # Imported lazily so the rest of the package works without the `esm` dependency.
        from esm.models.esmc import ESMC
        from esm.utils.constants.models import ESMC_300M, ESMC_600M
        if self.model_name == "esmc_300m":
            model_const = ESMC_300M
            self.D_esm = 960
        elif self.model_name == "esmc_600m":
            model_const = ESMC_600M
            self.D_esm = 1152
        else:
            raise ValueError(f"Unknown model: {self.model_name}")
        self.model = ESMC.from_pretrained(model_name=model_const, device=self._device)
        self.tokenizer = self.model.tokenizer

    @torch.no_grad()
    def tokenize(self, sequences: List[str], max_length: Optional[int] = None, add_special_tokens: bool = True, return_tensors: str = "pt"):
        """Tokenize protein sequences and pad to a common length.

        Returns (input_ids, attention_mask); mask is True on non-pad positions.
        NOTE(review): `return_tensors` is accepted but unused — tensors are
        always returned; kept for interface compatibility, presumably.
        """
        from esm.utils import encoding
        from esm.utils.misc import stack_variable_length_tensors
        pad = self.tokenizer.pad_token_id
        tokenized_seqs = []
        for seq in sequences:
            tokens = encoding.tokenize_sequence(seq, self.tokenizer, add_special_tokens=add_special_tokens)
            # Hard truncation; when add_special_tokens is True this can drop the EOS token.
            if max_length is not None and len(tokens) > max_length:
                tokens = tokens[:max_length]
            tokenized_seqs.append(tokens)
        input_ids = stack_variable_length_tensors(tokenized_seqs, constant_value=pad)
        attention_mask = (input_ids != pad)
        return input_ids, attention_mask

    @torch.no_grad()
    def encode_from_ids(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.BoolTensor] = None, return_dict: bool = True, return_contacts: bool = False):
        """Run the frozen encoder on pre-tokenized ids.

        Args:
            input_ids: [B, L] token ids (padded).
            attention_mask: [B, L] True on real tokens; forwarded as `sequence_id`.
            return_dict: return {"embeddings", "attention_mask"} instead of a tensor.
            return_contacts: accepted but unused in this implementation.
        """
        device = self.model.device
        input_ids = input_ids.to(device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)
        # Autocast only on CUDA and only when an autocast dtype was configured.
        if self._autocast_dtype is not None and device.type == "cuda":
            with torch.amp.autocast('cuda', dtype=self._autocast_dtype):
                outputs = self.model.forward(sequence_tokens=input_ids, sequence_id=attention_mask)
        else:
            outputs = self.model.forward(sequence_tokens=input_ids, sequence_id=attention_mask)
        embeddings = outputs.embeddings
        if return_dict:
            return {"embeddings": embeddings, "attention_mask": attention_mask}
        else:
            return embeddings

    def strip_special_tokens(self, embeddings: torch.FloatTensor, attention_mask: Optional[torch.BoolTensor] = None):
        """Drop BOS/EOS rows from [B, L, D] embeddings.

        Returns (stripped [B, L-2, D], lengths [B]) where lengths counts real
        residues (mask total minus the 2 special tokens, clamped to >= 1).
        NOTE(review): slicing [:, 1:-1, :] assumes BOS is at position 0 and EOS
        at the last column — for right-padded batches the per-sample EOS sits
        earlier; the `lengths` tensor is what downstream code uses to mask.
        """
        if attention_mask is not None:
            lengths = attention_mask.sum(dim=1) - 2
            lengths = lengths.clamp(min=1)
        else:
            B, L, D = embeddings.shape
            lengths = torch.full((B,), L - 2, device=embeddings.device)
        stripped = embeddings[:, 1:-1, :]
        return stripped, lengths
96
+
97
+
98
+
99
+
100
+ class CodonGPT(nn.Module):
101
    def __init__(
        self,
        vocab_size: int = 79,
        hidden_size: int = 960,
        num_layers: int = 24,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        max_position_embeddings: int = 4096,
        dropout: float = 0.1,
        layer_norm_eps: float = 1e-6,
        num_special_tokens: int = 13,
        special_ids: Optional[SpecialIds] = None,
        esm_model_name: str = "esmc_300m",
        esm_device: str = "cuda",
        esm_dtype: str = "fp16",
        max_protein_prefix: int = 0,
        max_species_prefix: int = 0,
        prepend_species: bool = True,
        prepend_protein: bool = True,
        species_embedding_dim: int = 1024,
        attn_impl: str = "gqa",  # "gqa" or "mha"
        num_kv_groups: int = 0,  # for GQA; 0 means default (no grouping)
    ):
        """Decoder-only codon LM conditioned on species + protein prefixes.

        Args:
            vocab_size: Codon + special token vocabulary size.
            hidden_size: Transformer width.
            num_layers: Number of TransformerBlocks.
            num_heads: Attention heads per block.
            mlp_ratio: FFN hidden width as a multiple of hidden_size.
            max_position_embeddings: Hard cap on total sequence length (prefix + start + codons).
            dropout: Dropout used in attention/FFN.
            layer_norm_eps: Epsilon for the final RMSNorm.
            num_special_tokens: Count of special tokens (stored, not used here).
            special_ids: Token-id bundle; defaults to SpecialIds().
            esm_model_name/esm_device/esm_dtype: Frozen ESM-C encoder config
                (only loaded when prepend_protein is True and a name is given).
            max_protein_prefix/max_species_prefix: Per-prefix caps; 0 = unlimited.
            prepend_species/prepend_protein: Enable the respective prefixes.
            species_embedding_dim: Input width of species embeddings.
            attn_impl: "gqa" or "mha".
            num_kv_groups: K/V groups for GQA; 0 means one group per head.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.max_position_embeddings = max_position_embeddings

        self.special_ids = special_ids or SpecialIds()
        self.num_special_tokens = num_special_tokens

        # Single embedding table for all tokens (special + codon)
        self.token_embed = nn.Embedding(vocab_size, hidden_size)

        if prepend_protein and esm_model_name:
            # Frozen ESM-C encoder; lives on its own device, excluded from FSDP sharding.
            self.esm = FrozenESMCEncoder(esm_model_name, esm_device, esm_dtype)
            # Project ESM token embeddings (D_esm) to model hidden size, then normalize
            self.esm_ln = nn.Sequential(
                nn.Linear(self.esm.D_esm, hidden_size, bias=False),
                nn.ReLU(),
                nn.LayerNorm(hidden_size),
            )
        else:
            self.esm = None
            self.esm_ln = None

        self.species_embedding_dim = species_embedding_dim if prepend_species else 0
        if prepend_species:
            # Project species embeddings (fixed or token sequence) from Ds -> H
            self.species_ln = nn.Sequential(
                nn.Linear(self.species_embedding_dim, hidden_size, bias=False),
                nn.ReLU(),
                nn.LayerNorm(hidden_size),
            )
        else:
            self.species_ln = None

        # Optional per-prefix caps; 0 means unlimited (subject to global max length)
        self.max_protein_prefix = int(max_protein_prefix) if max_protein_prefix is not None else 0
        self.max_species_prefix = int(max_species_prefix) if max_species_prefix is not None else 0
        self.prepend_species = bool(prepend_species)
        self.prepend_protein = bool(prepend_protein)

        # Learned start embedding (BOS-less decoding)
        self.start_embed = nn.Parameter(torch.zeros(1, 1, hidden_size))
        nn.init.normal_(self.start_embed, mean=0.0, std=0.02)

        # Attention configuration
        self.attn_impl = str(attn_impl)
        self.num_kv_groups = int(num_kv_groups)
        kv_groups = self.num_kv_groups
        self.blocks = nn.ModuleList([
            TransformerBlock(
                dim=hidden_size,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                dropout=dropout,
                # Pass groups only for the GQA path; None makes the block default
                # to one group per head (plain MHA-equivalent grouping).
                num_kv_groups=(kv_groups if (kv_groups > 0 and attn_impl == "gqa") else None),
                qk_norm=False,
                attn_type=("mha" if self.attn_impl == "mha" else "gqa"),
            ) for _ in range(num_layers)
        ])

        self.ln_f = RMSNorm(hidden_size, eps=layer_norm_eps)
        # No weight tying between token_embed and lm_head.
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.gradient_checkpointing = False
+ self.gradient_checkpointing = False
190
+
191
+ def embed_tokens(self, token_ids: torch.Tensor) -> torch.Tensor:
192
+ device = self.token_embed.weight.device
193
+ return self.token_embed(token_ids.to(device))
194
+
195
+ def build_prefix(
196
+ self,
197
+ batch_size: int,
198
+ device: torch.device,
199
+ species_tok_emb: Optional[torch.Tensor] = None,
200
+ species_emb: Optional[torch.Tensor] = None,
201
+ protein_input: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
202
+ species_tok_emb_src: Optional[torch.Tensor] = None,
203
+ species_tok_emb_tgt: Optional[torch.Tensor] = None,
204
+ species_emb_src: Optional[torch.Tensor] = None,
205
+ species_emb_tgt: Optional[torch.Tensor] = None,
206
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
207
+ """Build LLaVA-style prefix token embeddings by concatenating
208
+ [species_src]+[species_tgt]+[protein_tokens]. Returns:
209
+ - prefix: [B, Lp, H]
210
+ - prefix_lengths: [B] valid token counts per sample
211
+ """
212
+ parts: list[torch.Tensor] = []
213
+
214
+ # Species: src then tgt (if provided)
215
+ if self.prepend_species and self.species_ln is not None:
216
+ tok_src = species_tok_emb_src if species_tok_emb_src is not None else species_tok_emb
217
+ tok_tgt = species_tok_emb_tgt if species_tok_emb_tgt is not None else species_tok_emb
218
+ emb_src = species_emb_src if species_emb_src is not None else species_emb
219
+ emb_tgt = species_emb_tgt if species_emb_tgt is not None else species_emb
220
+
221
+ def _as_tokens(S_tok, S_fix):
222
+ if S_fix is not None:
223
+ # [B, Ds] -> [B, 1, H]
224
+ S = self.species_ln(S_fix.to(device=device, dtype=next(self.parameters()).dtype).unsqueeze(1))
225
+ return S
226
+ elif S_tok is not None:
227
+ # [B, Ls, Ds] -> optional cap, then project to H
228
+ S = S_tok
229
+ if getattr(self, "max_species_prefix", 0) > 0 and S.size(1) > self.max_species_prefix:
230
+ S = S[:, : self.max_species_prefix, :]
231
+ S = S.to(device=device, dtype=next(self.parameters()).dtype)
232
+ S = self.species_ln(S)
233
+ return S
234
+ else:
235
+ return None
236
+
237
+ Ssrc = _as_tokens(tok_src, emb_src)
238
+ if Ssrc is not None:
239
+ parts.append(Ssrc)
240
+ Sdst = _as_tokens(tok_tgt, emb_tgt)
241
+ if Sdst is not None:
242
+ parts.append(Sdst)
243
+
244
+ # Protein tokens from ESM-C
245
+ if self.prepend_protein and self.esm is not None and protein_input is not None:
246
+ prot_ids, prot_mask = protein_input
247
+ esm_out = self.esm.encode_from_ids(prot_ids, prot_mask, return_dict=True)
248
+ P, lengths = self.esm.strip_special_tokens(esm_out["embeddings"], prot_mask)
249
+ # Optional per-protein capping before projection
250
+ if getattr(self, "max_protein_prefix", 0) > 0 and P.size(1) > self.max_protein_prefix:
251
+ P = P[:, : self.max_protein_prefix, :]
252
+ if lengths is not None:
253
+ lengths = lengths.clamp(max=self.max_protein_prefix)
254
+ if P.size(1) > 0:
255
+ P = self.esm_ln(P.to(device=device, dtype=next(self.parameters()).dtype))
256
+ # Zero padded rows (per-sample) based on lengths
257
+ if lengths is not None:
258
+ Lp = P.size(1)
259
+ ar = torch.arange(Lp, device=device).unsqueeze(0)
260
+ lengths = lengths.to(device=device)
261
+ valid = ar < lengths.unsqueeze(1) # [B,Lp]
262
+ P = P * valid.unsqueeze(-1)
263
+ parts.append(P)
264
+
265
+ if len(parts) == 0:
266
+ empty = torch.zeros(batch_size, 0, self.hidden_size, device=device, dtype=next(self.parameters()).dtype)
267
+ return empty, torch.zeros(batch_size, dtype=torch.long, device=device)
268
+
269
+ prefix = torch.cat(parts, dim=1) if parts else torch.zeros(batch_size, 0, self.hidden_size, device=device, dtype=next(self.parameters()).dtype) # [B,Lp,H]
270
+ # Compute per-sample valid lengths: treat zero rows as padding
271
+ with torch.no_grad():
272
+ if prefix.size(1) > 0:
273
+ valid = (prefix.abs().sum(dim=-1) > 0)
274
+ lengths = valid.sum(dim=1).to(torch.long)
275
+ else:
276
+ lengths = torch.zeros(batch_size, dtype=torch.long, device=device)
277
+
278
+ # ---- Enforce hard global budget on the prefix itself ----
279
+ prefix_budget = max(0, int(self.max_position_embeddings) - 1)
280
+ if prefix_budget == 0:
281
+ trimmed = prefix.new_zeros(prefix.size(0), 0, prefix.size(2))
282
+ return trimmed, torch.zeros(prefix.size(0), dtype=torch.long, device=prefix.device)
283
+
284
+ allow = torch.minimum(lengths, torch.tensor(prefix_budget, device=lengths.device, dtype=lengths.dtype))
285
+ Lp_max = int(allow.max().item()) if allow.numel() > 0 else 0
286
+ if prefix.size(1) > Lp_max:
287
+ trimmed = prefix.new_zeros(prefix.size(0), Lp_max, prefix.size(2))
288
+ for b in range(prefix.size(0)):
289
+ lb = int(allow[b].item())
290
+ if lb > 0:
291
+ trimmed[b, :lb, :] = prefix[b, :lb, :]
292
+ prefix = trimmed
293
+ lengths = allow
294
+ else:
295
+ lengths = allow
296
+ return prefix, lengths
297
+
298
+ def forward(
299
+ self,
300
+ codon_ids: torch.Tensor,
301
+ cond: Dict[str, Any] = None,
302
+ labels: Optional[torch.Tensor] = None,
303
+ return_dict: bool = True,
304
+ species_tok_emb: Optional[torch.Tensor] = None,
305
+ protein_input: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
306
+ protein_seqs: Optional[List[str]] = None,
307
+ # KV cache options
308
+ use_cache: bool = False,
309
+ past_kv: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
310
+ position_offset: int = 0,
311
+ ) -> Dict[str, torch.Tensor]:
312
+ batch_size, codon_len = codon_ids.shape
313
+ device = codon_ids.device
314
+
315
+ # Unpack conditioning
316
+ if cond is not None:
317
+ control_mode = cond.get("control_mode", "fixed")
318
+ species_tok_emb_src = cond.get("species_tok_emb_src")
319
+ species_tok_emb_tgt = cond.get("species_tok_emb_tgt")
320
+ species_emb_src = cond.get("species_emb_src")
321
+ species_emb_tgt = cond.get("species_emb_tgt")
322
+ species_tok_emb = cond.get("species_tok_emb")
323
+ species_emb = cond.get("species_emb")
324
+ protein_input = cond.get("protein_input")
325
+ protein_seqs = cond.get("protein_seqs")
326
+ else:
327
+ species_emb = None
328
+ species_tok_emb_src = None
329
+ species_tok_emb_tgt = None
330
+ species_emb_src = None
331
+ species_emb_tgt = None
332
+
333
+ if protein_seqs is not None and protein_input is None:
334
+ if self.esm is not None:
335
+ with torch.no_grad():
336
+ # Respect per-protein ceiling during tokenization (+2 for BOS/EOS)
337
+ max_len_tokens = (self.max_protein_prefix + 2) if (getattr(self, "max_protein_prefix", 0) > 0) else None
338
+ protein_input = self.esm.tokenize(protein_seqs, max_length=max_len_tokens)
339
+ else:
340
+ protein_input = None
341
+
342
+ # Fast path: incremental decode using KV cache
343
+ if past_kv is not None:
344
+ # Expect only newly generated codon tokens here
345
+ if codon_ids.numel() == 0:
346
+ # Nothing to do; return a dummy next_logits
347
+ dummy = torch.zeros(batch_size, self.vocab_size, device=device, dtype=self.lm_head.weight.dtype)
348
+ return {"logits": dummy[:, 0:0], "next_logits": dummy}
349
+
350
+ x = self.embed_tokens(codon_ids) # [B, T_new, H]
351
+
352
+ present_kv: List[Tuple[torch.Tensor, torch.Tensor]] = []
353
+ for i, block in enumerate(self.blocks):
354
+ kv_i = past_kv[i] if i < len(past_kv) else None
355
+ if self.training and getattr(self, 'gradient_checkpointing', False):
356
+ def _fn(inp):
357
+ return block(inp, past_kv=kv_i, use_cache=True, position_offset=position_offset)
358
+ out_blk = checkpoint.checkpoint(_fn, x, use_reentrant=False)
359
+ else:
360
+ out_blk = block(x, past_kv=kv_i, use_cache=True, position_offset=position_offset)
361
+ x, kv_out = out_blk # type: ignore[assignment]
362
+ present_kv.append(kv_out)
363
+
364
+ x = self.ln_f(x)
365
+ logits_step = self.lm_head(x) # [B, T_new, V]
366
+ next_logits = logits_step[:, -1, :]
367
+ out: Dict[str, torch.Tensor] = {"logits": logits_step[:, 0:0, :], "next_logits": next_logits}
368
+ out["present_kv"] = present_kv # type: ignore[assignment]
369
+ return out if return_dict else logits_step[:, 0:0, :]
370
+
371
+ # Standard path: build prefix and full window (training or prefill)
372
+ prefix, prefix_lengths = self.build_prefix(
373
+ batch_size=batch_size,
374
+ device=device,
375
+ species_tok_emb=species_tok_emb,
376
+ species_emb=species_emb if cond is not None else None,
377
+ protein_input=protein_input,
378
+ species_tok_emb_src=species_tok_emb_src,
379
+ species_tok_emb_tgt=species_tok_emb_tgt,
380
+ species_emb_src=species_emb_src,
381
+ species_emb_tgt=species_emb_tgt,
382
+ )
383
+
384
+ start = self.start_embed.expand(batch_size, 1, self.hidden_size) # [B,1,H]
385
+
386
+ # Per-sample true codon input lengths (exclude PADs)
387
+ pad_id = int(self.special_ids.pad) if hasattr(self, "special_ids") and self.special_ids is not None else 0
388
+ codon_mask = (codon_ids != pad_id) # [B, N]
389
+ codon_lens = codon_mask.sum(dim=1) # [B]
390
+
391
+ # Budget remaining after prefix + start
392
+ capacity = max(0, int(self.max_position_embeddings))
393
+ budget_after_prefix = torch.clamp(
394
+ torch.as_tensor(capacity, device=device) - (prefix_lengths + 1),
395
+ min=0,
396
+ ) # [B]
397
+ # Per-sample cap is limited by both budget and available codons
398
+ per_cap = torch.minimum(budget_after_prefix, codon_lens) # [B]
399
+
400
+ # Total valid lengths per sample (prefix + start + capped codon)
401
+ valid_lengths = prefix_lengths + 1 + per_cap
402
+ T = int(valid_lengths.max().item()) if valid_lengths.numel() > 0 else (1 + int(codon_lens.max().item()) if codon_lens.numel() > 0 else 1)
403
+
404
+ # Embed only the needed codon window for this batch
405
+ max_cap = int(per_cap.max().item()) if per_cap.numel() > 0 else 0
406
+ if max_cap > 0:
407
+ codon_emb = self.embed_tokens(codon_ids[:, :max_cap]) # [B, max_cap, H]
408
+ else:
409
+ codon_emb = torch.zeros(batch_size, 0, self.hidden_size, device=device, dtype=start.dtype)
410
+
411
+ # Build sequence per-sample using concat to preserve gradients, then pad
412
+ seqs = []
413
+ for b in range(batch_size):
414
+ lp = int(prefix_lengths[b].item())
415
+ cap = int(per_cap[b].item())
416
+ parts = []
417
+ if lp > 0:
418
+ parts.append(prefix[b, :lp, :])
419
+ parts.append(start[b, 0:1, :])
420
+ if cap > 0:
421
+ parts.append(codon_emb[b, :cap, :])
422
+ seqs.append(torch.cat(parts, dim=0)) # [Lb, H]
423
+ x = rnn_utils.pad_sequence(seqs, batch_first=True) # [B, T, H]
424
+
425
+ present_kv_list: List[Tuple[torch.Tensor, torch.Tensor]] = []
426
+ for block in self.blocks:
427
+ if self.training and getattr(self, 'gradient_checkpointing', False):
428
+ def _fn(inp):
429
+ return block(inp, use_cache=use_cache, position_offset=0)
430
+ blk_out = checkpoint.checkpoint(_fn, x, use_reentrant=False)
431
+ else:
432
+ blk_out = block(x, use_cache=use_cache, position_offset=0)
433
+ if use_cache:
434
+ x, kv = blk_out # type: ignore[misc]
435
+ present_kv_list.append(kv)
436
+ else:
437
+ x = blk_out # type: ignore[assignment]
438
+
439
+ x = self.ln_f(x)
440
+ logits_full = self.lm_head(x) # [B, T, V]
441
+
442
+ # Gather codon-aligned logits per sample: positions (lp+1) .. (lp+cap) (skip start)
443
+ next_logits_list = []
444
+ if max_cap == 0:
445
+ # Keep graph by slicing from logits_full
446
+ codon_logits = logits_full[:, 0:0, :]
447
+ for b in range(batch_size):
448
+ lp = int(prefix_lengths[b].item())
449
+ # Last consumed position is the start token at index lp
450
+ pos_next = lp
451
+ if pos_next < logits_full.size(1):
452
+ next_logits_list.append(logits_full[b, pos_next, :])
453
+ else:
454
+ next_logits_list.append(logits_full[b, -1, :])
455
+ next_logits = torch.stack(next_logits_list, dim=0)
456
+ else:
457
+ slices = []
458
+ for b in range(batch_size):
459
+ lp = int(prefix_lengths[b].item())
460
+ cap = int(per_cap[b].item())
461
+ # Skip the start position so logits align with labels = codon_ids[:, 1:]
462
+ sl = logits_full[b, lp : lp + cap, :] if cap > 0 else logits_full.new_zeros(0, self.vocab_size)
463
+ slices.append(sl)
464
+ # Next-token logits after processing 'cap' codons: last consumed is at lp + cap
465
+ pos_next = lp + cap
466
+ next_logits_list.append(logits_full[b, pos_next, :] if pos_next < logits_full.size(1) else logits_full.new_zeros(self.vocab_size))
467
+ codon_logits = rnn_utils.pad_sequence(slices, batch_first=True) # [B,max_cap,V]
468
+ next_logits = torch.stack(next_logits_list, dim=0)
469
+ out = {"logits": codon_logits, "next_logits": next_logits}
470
+
471
+ if labels is not None:
472
+ # Align labels to per-sample caps: mask out positions >= cap
473
+ if labels.size(1) > 0 and max_cap > 0:
474
+ # Build masked labels with -100 beyond cap per sample
475
+ adj = labels.new_full((batch_size, max_cap), -100)
476
+ for b in range(batch_size):
477
+ cap = int(per_cap[b].item())
478
+ if cap > 0:
479
+ Lb = min(cap, labels.size(1))
480
+ adj[b, :Lb] = labels[b, :Lb]
481
+ loss = F.cross_entropy(codon_logits.reshape(-1, self.vocab_size), adj.reshape(-1), ignore_index=-100)
482
+ else:
483
+ loss = codon_logits.sum() * 0.0
484
+ out["loss"] = loss
485
+ # Provide optional debug stats for trainer logging
486
+ out["prefix_len"] = prefix_lengths.detach()
487
+ out["per_cap"] = per_cap.detach()
488
+ if use_cache:
489
+ out["present_kv"] = present_kv_list # type: ignore[assignment]
490
+ return out if return_dict else codon_logits
src/sampler.py ADDED
@@ -0,0 +1,696 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/sampler.py
2
+ """
3
+ Sampling utilities for CodonGPT.
4
+
5
+ Conditioning invariants:
6
+ - Species context: fixed-size [B, Ds] via species_emb or variable-length [B, Ls, Ds] via species_tok_emb
7
+ - Protein context: raw sequences; the model's Frozen ESM handles tokenization
8
+ """
9
+
10
+ from __future__ import annotations
11
+ from typing import List, Optional, Dict, Union, Tuple
12
+ from pathlib import Path
13
+ import logging
14
+ import json
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ import numpy as np
20
+ from safetensors.torch import load_file
21
+
22
+ from .models import CodonGPT
23
+ from .tokenizer import CodonTokenizer
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ # ----------------------------
29
+ # Logit filtering
30
+ # ----------------------------
31
+
32
+ def _ensure_2d_logits(logits: torch.Tensor) -> torch.Tensor:
33
+ return logits if logits.dim() == 2 else logits.unsqueeze(0)
34
+
35
+ def _top_k_filtering(logits: torch.Tensor, k: int) -> torch.Tensor:
36
+ """Top-k filtering; logits is [B,V] or [V]."""
37
+ x = _ensure_2d_logits(logits)
38
+ k = max(1, min(int(k), x.size(-1)))
39
+ values, _ = torch.topk(x, k, dim=-1)
40
+ min_values = values[:, -1].unsqueeze(-1)
41
+ x = torch.where(x < min_values, torch.full_like(x, float('-inf')), x)
42
+ return x if logits.dim() == 2 else x.squeeze(0)
43
+
44
+ def _top_p_filtering(logits: torch.Tensor, p: float) -> torch.Tensor:
45
+ """Top-p (nucleus) filtering; logits is [B,V] or [V]."""
46
+ if p >= 1.0:
47
+ return logits
48
+ if p <= 0.0:
49
+ # You asked for nothing; enjoy the abyss.
50
+ return torch.full_like(logits, float('-inf'))
51
+ x = _ensure_2d_logits(logits)
52
+ sorted_logits, sorted_indices = torch.sort(x, descending=True, dim=-1)
53
+ probs = F.softmax(sorted_logits, dim=-1)
54
+ cumprobs = torch.cumsum(probs, dim=-1)
55
+ to_remove = cumprobs > p
56
+ to_remove[:, 1:] = to_remove[:, :-1].clone()
57
+ to_remove[:, 0] = False
58
+ mask = torch.zeros_like(x, dtype=torch.bool).scatter(-1, sorted_indices, to_remove)
59
+ x = torch.where(mask, torch.full_like(x, float('-inf')), x)
60
+ return x if logits.dim() == 2 else x.squeeze(0)
61
+
62
+
63
+ # ----------------------------
64
+ # Sampler
65
+ # ----------------------------
66
+
67
+ class CodonSampler:
68
+ """
69
+ GPT sampler with conditional generation.
70
+
71
+ Requires in model_dir:
72
+ - vocab.json
73
+ - model.safetensors (preferred)
74
+ or pytorch_model.bin (legacy)
75
+ - trainer_config.json or config.json
76
+ """
77
+
78
    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        species_store=None,  # SpeciesEmbeddingStore
        compile_model: bool = False,
        taxonomy_db_path: Optional[str] = None,
        qwen_max_length: int = 512,
        qwen_batch_size: int = 16,
        **_: dict,
    ):
        """Load tokenizer, config, and GPT weights from *model_path* and build the sampler.

        Args:
            model_path: Directory containing vocab.json (or its parent),
                trainer_config.json / config.json, and model.safetensors or
                pytorch_model.bin.
            device: torch device string for the GPT model and masks.
            species_store: Optional SpeciesEmbeddingStore used to resolve
                species names in sample(); may be None if embeddings are
                passed directly.
            compile_model: Wrap the model in torch.compile (no fallback on
                failure, by design).
            taxonomy_db_path: Optional JSON mapping species name -> taxonomy
                text for Qwen embedding.
            qwen_max_length: Tokenizer truncation length for Qwen embedding.
            qwen_batch_size: Batch size for Qwen embedding.
            **_: Extra kwargs are accepted and ignored (forward compatibility).

        Raises:
            FileNotFoundError: If vocab.json or a config file is missing.
        """
        self.device = torch.device(device)
        self.model_dir = Path(model_path)

        # Required files (allow fallback to parent dir for vocab.json)
        vocab_path = self.model_dir / "vocab.json"
        if not vocab_path.exists():
            parent_vocab = self.model_dir.parent / "vocab.json"
            if parent_vocab.exists():
                vocab_path = parent_vocab
            else:
                raise FileNotFoundError(f"Missing {self.model_dir / 'vocab.json'}")
        # trainer_config.json takes priority over config.json when both exist.
        trainer_cfg = self.model_dir / "trainer_config.json"
        cfg_path = trainer_cfg if trainer_cfg.exists() else (self.model_dir / "config.json")
        if not cfg_path.exists():
            raise FileNotFoundError(f"Missing trainer_config.json or config.json in {self.model_dir}")

        # Load config
        with open(cfg_path, "r") as f:
            self.config = json.load(f)

        # Tokenizer
        # If vocab was loaded from parent dir, pass that path; else model_dir
        vocab_dir = vocab_path.parent
        self.tokenizer = CodonTokenizer.from_pretrained(str(vocab_dir))
        self.V = int(self.tokenizer.vocab_size)
        self._eos_id = int(self.tokenizer.eos_token_id)
        self._pad_id = int(self.tokenizer.pad_token_id)
        self._num_special = int(self.tokenizer.num_special_tokens)

        # Species store (optional if you pass species_emb* directly at sample())
        self.species_store = species_store
        self.species_vocab = (self.species_store.vocab if self.species_store is not None else {})
        self.taxonomy_db_path = taxonomy_db_path
        self.qwen_opts = {
            "max_length": int(qwen_max_length),
            "batch_size": int(qwen_batch_size),
        }
        # Lazy-inited Qwen objects
        self._qwen_tokenizer = None
        self._qwen_model = None

        # Model: weights are loaded first so the architecture can be
        # reconstructed from tensor shapes when config keys are absent.
        state = self._load_state_dict()
        arch = self._infer_arch_from_state_dict(state)
        self.model = CodonGPT(
            vocab_size=self.V,
            hidden_size=int(arch["hidden_size"]),
            num_layers=int(arch["num_layers"]),
            num_heads=int(arch["num_heads"]),
            mlp_ratio=float(arch["mlp_ratio"]),
            max_position_embeddings=int(arch["max_position_embeddings"]),
            dropout=float(self.config.get("dropout", 0.1)),
            num_special_tokens=self._num_special,
            special_ids=self.tokenizer.special_ids,
            # ESM-related options only apply when protein conditioning is on.
            esm_model_name=str(arch["esm_model_name"]) if bool(arch["prepend_protein"]) else None,
            esm_device=str(arch["esm_device"]),
            esm_dtype=str(arch["esm_dtype"]),
            max_protein_prefix=int(arch["max_protein_prefix"]) if bool(arch["prepend_protein"]) else 0,
            max_species_prefix=int(arch["max_species_prefix"]) if bool(arch["prepend_species"]) else 0,
            prepend_species=bool(arch["prepend_species"]),
            prepend_protein=bool(arch["prepend_protein"]),
            species_embedding_dim=int(self.config.get("species_embedding_dim", 1024)),
            attn_impl=str(arch.get("attn_impl", "gqa")),
            num_kv_groups=int(arch.get("num_kv_groups", 0)),
        )
        # strict=False: mismatched keys are tolerated but logged below.
        missing, unexpected = self.model.load_state_dict(state, strict=False)
        if len(unexpected) > 0:
            logger.warning(f"Unexpected keys in state dict: {unexpected[:10]}{'...' if len(unexpected) > 10 else ''}")
        if len(missing) > 0:
            logger.warning(f"Missing keys in state dict: {missing[:10]}{'...' if len(missing) > 10 else ''}")

        if compile_model:
            # If this errors on your PyTorch build, that's on you. No try/except.
            self.model = torch.compile(self.model)  # type: ignore

        self.model.to(self.device).eval()
        logger.info(f"Loaded GPT model from {self.model_dir}")
        try:
            hs = int(getattr(self.model, "hidden_size", -1))
            hh = int(getattr(self.model, "num_heads", -1))
            nl = int(getattr(self.model, "num_layers", -1))
            logger.info(f"Reconstructed arch: hidden={hs} heads={hh} layers={nl}")
        except Exception:
            pass

        # Static masks over the vocabulary used during sampling:
        # fixed mode forbids all special tokens; variable mode re-allows EOS.
        self._allowed_fixed = torch.ones(self.V, dtype=torch.bool, device=self.device)
        self._allowed_fixed[:self._num_special] = False  # no specials in fixed mode

        self._allowed_variable = torch.ones(self.V, dtype=torch.bool, device=self.device)
        self._allowed_variable[:self._num_special] = False
        self._allowed_variable[self._eos_id] = True  # EOS allowed in variable mode
181
+
182
+ # ----------------------------
183
+ # Loading / arch inference
184
+ # ----------------------------
185
+
186
+ def _load_state_dict(self) -> Dict[str, torch.Tensor]:
187
+ st_p = self.model_dir / "model.safetensors"
188
+ pt_p = self.model_dir / "pytorch_model.bin"
189
+ if st_p.exists():
190
+ return load_file(st_p)
191
+ if pt_p.exists():
192
+ return torch.load(pt_p, map_location="cpu")
193
+ raise FileNotFoundError(f"No model.safetensors or pytorch_model.bin in {self.model_dir}")
194
+
195
    def _infer_arch_from_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, Union[int, float, bool, str]]:
        """Reconstruct model hyperparameters from weight shapes, letting config override.

        General pattern: each field is first guessed from tensor shapes/key
        names in *state_dict*, then overridden by ``self.config`` when the
        corresponding key is present.
        """
        arch: Dict[str, Union[int, float, bool, str]] = {}

        # hidden size: lm_head is [V, H]; otherwise fall back to ln_f's width.
        if "lm_head.weight" in state_dict:
            arch["hidden_size"] = int(state_dict["lm_head.weight"].shape[1])
        else:
            for k, v in state_dict.items():
                if k.endswith("ln_f.weight"):
                    arch["hidden_size"] = int(v.shape[0])
                    break
        # Prefer config when present to avoid guessing errors
        cfg = self.config or {}
        if "hidden_size" in cfg:
            arch["hidden_size"] = int(cfg["hidden_size"])  # type: ignore[index]
        if "hidden_size" not in arch:
            arch["hidden_size"] = int(cfg.get("hidden_size", 960))
        H = int(arch["hidden_size"])

        # layers: highest index seen in "blocks.<i>." keys, plus one.
        max_block = -1
        for k in state_dict.keys():
            if k.startswith("blocks."):
                idx = int(k.split(".")[1])
                if idx > max_block:
                    max_block = idx
        arch["num_layers"] = (max_block + 1) if max_block >= 0 else int(cfg.get("num_hidden_layers", 12))
        if "num_hidden_layers" in cfg:
            arch["num_layers"] = int(cfg["num_hidden_layers"])  # type: ignore[index]

        # mlp ratio from w1: ffn.w1 is [intermediate, H], so ratio = rows / H.
        w1_key = "blocks.0.ffn.w1.weight" if "blocks.0.ffn.w1.weight" in state_dict else None
        if w1_key is None:
            for i in range(1, 3):
                k = f"blocks.{i}.ffn.w1.weight"
                if k in state_dict:
                    w1_key = k
                    break
        if w1_key is not None and H > 0:
            arch["mlp_ratio"] = float(int(state_dict[w1_key].shape[0]) / H)
        else:
            arch["mlp_ratio"] = float(cfg.get("mlp_ratio", 4.0))

        # heads – pick a divisor of H (first match from the preference list;
        # h=1 always divides, so num_heads is always set).
        cfg_heads = cfg.get("num_attention_heads")
        if isinstance(cfg_heads, int) and cfg_heads > 0 and H % cfg_heads == 0:
            arch["num_heads"] = int(cfg_heads)
        else:
            for h in (16, 15, 12, 10, 8, 6, 5, 4, 3, 2, 1):
                if H % h == 0:
                    arch["num_heads"] = h
                    break

        # conditioning flags from presence of submodules
        arch["prepend_species"] = bool(cfg.get("prepend_species", any(k.startswith("species_ln.") for k in state_dict.keys())))
        has_esm = any(k.startswith("esm_ln.") for k in state_dict.keys()) or any(k.startswith("esm.") for k in state_dict.keys())
        arch["prepend_protein"] = bool(cfg.get("prepend_protein", bool(has_esm)))
        arch["esm_model_name"] = str(cfg.get("esm_model_name", "esmc_300m"))
        arch["esm_device"] = str(cfg.get("esm_device", "cuda"))
        arch["esm_dtype"] = str(cfg.get("esm_dtype", "bf16")).lower()
        arch["max_protein_prefix"] = int(cfg.get("max_protein_prefix", 0))
        arch["max_species_prefix"] = int(cfg.get("max_species_prefix", 0))

        # "max_length" is the trainer-config name; "max_position_embeddings"
        # the model-config name. The former wins when both are present.
        if "max_length" in cfg:
            arch["max_position_embeddings"] = int(cfg.get("max_length", 1024))
        else:
            arch["max_position_embeddings"] = int(cfg.get("max_position_embeddings", 1024))
        # Attention impl and num_kv_groups (from config or infer from weights)
        attn_impl = str(cfg.get("attn_impl", ""))
        num_kv_groups = int(cfg.get("num_kv_groups", 0))
        if not attn_impl:
            # A separate Wk projection implies grouped-query attention; its
            # output width divided by head_dim gives the KV group count.
            wk_key = next((k for k in state_dict.keys() if k.endswith("attn.Wk.weight")), None)
            if wk_key is not None:
                attn_impl = "gqa"
                out_ch, _ = state_dict[wk_key].shape
                num_heads = int(arch.get("num_heads", 1))
                head_dim = int(arch["hidden_size"]) // max(1, num_heads)
                if head_dim > 0:
                    num_kv_groups = max(1, out_ch // head_dim)
            else:
                attn_impl = "mha"
                num_kv_groups = 0
        arch["attn_impl"] = attn_impl
        arch["num_kv_groups"] = num_kv_groups

        return arch  # type: ignore[return-value]
281
+
282
+ # ----------------------------
283
+ # Public API
284
+ # ----------------------------
285
+
286
+ @torch.no_grad()
287
+ def sample(
288
+ self,
289
+ num_sequences: int = 1,
290
+ sequence_length: int = 100, # target number of codons (fixed mode); max iterations (variable)
291
+ species: Optional[Union[str, List[str]]] = None,
292
+ protein_sequences: Optional[Union[str, List[str]]] = None,
293
+ control_mode: str = "fixed", # "fixed" or "variable"
294
+ target_protein_length: Optional[int] = None, # deprecated; alias to sequence_length
295
+ temperature: float = 1.0,
296
+ top_k: Optional[int] = None,
297
+ top_p: Optional[float] = None,
298
+ seed: Optional[int] = None,
299
+ return_intermediate: bool = False,
300
+ progress_bar: bool = False,
301
+ species_emb: Optional[torch.Tensor] = None, # [B, Ds]
302
+ species_tok_emb: Optional[torch.Tensor] = None, # [B, Ls, Ds]
303
+ enforce_translation: bool = False,
304
+ codon_enforcement_weight: float = 10.0, # unused with hard mask; kept for API compatibility
305
+ ) -> Dict[str, Union[List[str], torch.Tensor, List[bool]]]:
306
+
307
+ if seed is not None:
308
+ torch.manual_seed(int(seed))
309
+ np.random.seed(int(seed))
310
+
311
+ if control_mode not in ("fixed", "variable"):
312
+ raise ValueError(f"control_mode must be 'fixed' or 'variable', got {control_mode}")
313
+
314
+ B = int(num_sequences)
315
+ T_codons = int(sequence_length if target_protein_length is None else target_protein_length)
316
+
317
+ # Prepare conditioning
318
+ cond: Dict[str, Union[str, List[str], torch.Tensor]] = {"control_mode": control_mode}
319
+
320
+ # Species (priority: provided tensors → names via store)
321
+ if species_tok_emb is not None:
322
+ if species_tok_emb.ndim != 3 or species_tok_emb.size(0) != B:
323
+ raise ValueError("species_tok_emb must be [B, Ls, Ds]")
324
+ st = species_tok_emb.to(self.device)
325
+ cond["species_tok_emb_src"] = st
326
+ cond["species_tok_emb_tgt"] = st
327
+ elif species_emb is not None:
328
+ if species_emb.ndim != 2 or species_emb.size(0) != B:
329
+ raise ValueError("species_emb must be [B, Ds]")
330
+ se = species_emb.to(self.device)
331
+ cond["species_emb_src"] = se
332
+ cond["species_emb_tgt"] = se
333
+ elif species is not None:
334
+ names = [species] * B if isinstance(species, str) else species
335
+ if len(names) != B:
336
+ raise ValueError("Length of species list must match num_sequences")
337
+
338
+ # If we have a store (variable-length), use it for known species and compute Qwen embeddings for unknowns.
339
+ if self.species_store is not None:
340
+ ids = [self.species_store.vocab.get(n, -1) for n in names]
341
+ known_mask = [i for i, sid in enumerate(ids) if sid >= 0]
342
+ unk_mask = [i for i, sid in enumerate(ids) if sid < 0]
343
+
344
+ # Only variable-length embeddings are supported. If the store is not sequence-based, compute via Qwen for all.
345
+ use_sequence = bool(getattr(self.species_store, "is_legacy", False))
346
+ if not use_sequence:
347
+ # Fall back to Qwen for everything
348
+ q_tok, q_len = self._qwen_embed_names(names, pooling="sequence")
349
+ cond["species_tok_emb_src"] = q_tok.to(self.device)
350
+ cond["species_tok_emb_tgt"] = q_tok.to(self.device)
351
+ else:
352
+ # list of per-sample [L,D] tensors to be padded later
353
+ seq_list: List[torch.Tensor] = [None] * B # type: ignore[list-item]
354
+ D = int(getattr(self.species_store, "_ds", 1024))
355
+ # Known via store
356
+ if known_mask:
357
+ sub_ids = [ids[i] for i in known_mask]
358
+ result = self.species_store.batch_get(sub_ids)
359
+ assert isinstance(result, tuple)
360
+ sp_tok, _ = result
361
+ for j, i in enumerate(known_mask):
362
+ row = sp_tok[j]
363
+ nonzero = (row.abs().sum(dim=-1) > 0)
364
+ L = int(nonzero.sum().item()) if nonzero.any() else int(row.size(0))
365
+ seq_list[i] = row[:L].to(self.device)
366
+ # Unknown via Qwen
367
+ if unk_mask:
368
+ unk_names = [names[i] for i in unk_mask]
369
+ q_tok, q_len = self._qwen_embed_names(unk_names, pooling="sequence")
370
+ for j, i in enumerate(unk_mask):
371
+ L = int(q_len[j].item())
372
+ seq_list[i] = q_tok[j, :L, :].to(self.device)
373
+
374
+ # Pad to [B,Lmax,D]
375
+ Lmax = max((t.size(0) for t in seq_list if t is not None), default=0)
376
+ if Lmax == 0:
377
+ raise RuntimeError("No species embeddings could be constructed.")
378
+ padded = torch.zeros(B, Lmax, D, device=self.device, dtype=seq_list[0].dtype)
379
+ for i, t in enumerate(seq_list):
380
+ if t is None:
381
+ continue
382
+ L = t.size(0)
383
+ padded[i, :L, :] = t
384
+ cond["species_tok_emb_src"] = padded
385
+ cond["species_tok_emb_tgt"] = padded
386
+ else:
387
+ # No store: compute everything via Qwen (sequence pooling only)
388
+ emb, lengths = self._qwen_embed_names(names, pooling="sequence")
389
+ st = emb.to(self.device, non_blocking=True)
390
+ cond["species_tok_emb_src"] = st
391
+ cond["species_tok_emb_tgt"] = st
392
+
393
+ # Protein sequences (raw AA strings; the model handles ESM-C)
394
+ if protein_sequences is not None:
395
+ if isinstance(protein_sequences, list):
396
+ if len(protein_sequences) != B:
397
+ raise ValueError("Length of protein_sequences must match num_sequences")
398
+ cond["protein_seqs"] = protein_sequences
399
+ else:
400
+ cond["protein_seqs"] = [protein_sequences] * B
401
+
402
+ # Start with empty codon context; we'll prefill to build KV cache and get first-step logits
403
+ input_ids = torch.empty((B, 0), dtype=torch.long, device=self.device)
404
+
405
+ # Capacity probe and fallback: if prefix consumes all budget, cap species/protein prefix temporarily (prefill path)
406
+ pref = None
407
+ try:
408
+ out0 = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=True)
409
+ pref = out0.get("prefix_len") if isinstance(out0, dict) else None
410
+ if pref is not None:
411
+ max_pos = int(getattr(self.model, "max_position_embeddings", 1024))
412
+ remaining0 = max_pos - (pref + 1)
413
+ need_cap = (remaining0 <= 0).any()
414
+ else:
415
+ need_cap = False
416
+ if need_cap:
417
+ prev_sp = int(getattr(self.model, "max_species_prefix", 0))
418
+ prev_pp = int(getattr(self.model, "max_protein_prefix", 0))
419
+ if prev_sp == 0 or prev_sp > 256:
420
+ setattr(self.model, "max_species_prefix", 256)
421
+ if prev_pp == 0 or prev_pp > 256:
422
+ setattr(self.model, "max_protein_prefix", 256)
423
+ out0b = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=True)
424
+ pref = out0b.get("prefix_len") if isinstance(out0b, dict) else None
425
+ if pref is not None:
426
+ remaining0b = max_pos - (pref + 1)
427
+ if (remaining0b <= 0).all():
428
+ setattr(self.model, "max_species_prefix", 128)
429
+ setattr(self.model, "max_protein_prefix", 128)
430
+ out0b = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=True)
431
+ pref = out0b.get("prefix_len") if isinstance(out0b, dict) else pref
432
+ # Use the prefill output
433
+ out_prefill = out0 if pref is None else out0
434
+ except Exception:
435
+ # Fallback without cache
436
+ out_prefill = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=True)
437
+ pref = out_prefill.get("prefix_len") if isinstance(out_prefill, dict) else None
438
+
439
+ allowed = self._allowed_variable if control_mode == "variable" else self._allowed_fixed
440
+ finished = torch.zeros(B, dtype=torch.bool, device=self.device) # EOS reached (variable) OR capacity exhausted
441
+ capacity_truncated = torch.zeros(B, dtype=torch.bool, device=self.device)
442
+
443
+ intermediate = [] if return_intermediate else None
444
+ aa2codons = self.tokenizer.aa2codons_char_map()
445
+
446
+ # If we probed capacity, optionally clamp target codons by available capacity at step 0
447
+ try:
448
+ if pref is not None:
449
+ max_pos = int(getattr(self.model, "max_position_embeddings", 1024))
450
+ remaining = (max_pos - (pref + 1)).clamp(min=0)
451
+ T_codons = int(min(T_codons, int(remaining.max().item())))
452
+ except Exception:
453
+ pass
454
+
455
+ # KV cache and initial logits from prefill
456
+ kv = out_prefill.get("present_kv") if isinstance(out_prefill, dict) else None
457
+ logits = out_prefill.get("next_logits") if isinstance(out_prefill, dict) else None
458
+ if kv is None or logits is None:
459
+ # Safety: compute once if not provided
460
+ out_prefill = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=True)
461
+ kv = out_prefill.get("present_kv")
462
+ logits = out_prefill.get("next_logits")
463
+ assert kv is not None and logits is not None
464
+ prefix_len = pref if pref is not None else torch.zeros(B, dtype=torch.long, device=self.device)
465
+ prefill_len = (prefix_len + 1) # prefix + start
466
+
467
+ rng = range(T_codons)
468
+ if progress_bar:
469
+ from tqdm import tqdm
470
+ rng = tqdm(rng, desc="GPT sampling", total=T_codons)
471
+
472
+ for step in rng:
473
+ # Enforce global capacity per sample using prefix_len and current generated length
474
+ max_pos = int(getattr(self.model, "max_position_embeddings", 1024))
475
+ remaining_now = (max_pos - prefill_len - input_ids.size(1)).clamp(max=10**9)
476
+ cant_extend = remaining_now <= 0
477
+ newly_blocked = (~finished) & cant_extend
478
+ capacity_truncated = capacity_truncated | newly_blocked
479
+ finished = finished | cant_extend
480
+
481
+ # Base mask: disallow specials in fixed, allow EOS in variable.
482
+ logits = logits.masked_fill(~allowed, float("-inf"))
483
+
484
+ # If a sample is finished (EOS or capacity), force PAD to keep shapes stable.
485
+ # Decoding will drop PAD anyway.
486
+ if finished.any():
487
+ logits[finished] = float("-inf")
488
+ logits[finished, self._pad_id] = 0.0
489
+
490
+ # Optional: enforce codon ↔ AA mapping at this step (hard mask)
491
+ if enforce_translation and ("protein_seqs" in cond):
492
+ aas_now: List[Optional[str]] = []
493
+ prot_list = cond["protein_seqs"] # type: ignore[index]
494
+ assert isinstance(prot_list, list)
495
+ for i in range(B):
496
+ seq = prot_list[i]
497
+ aas_now.append(seq[step] if step < len(seq) else None)
498
+
499
+ mask = torch.zeros_like(logits, dtype=torch.bool)
500
+ for i, a in enumerate(aas_now):
501
+ if a is None:
502
+ mask[i, self._num_special:self.V] = True
503
+ else:
504
+ valid = aa2codons.get(a, [])
505
+ if len(valid) == 0:
506
+ mask[i, self._num_special:self.V] = True
507
+ else:
508
+ mask[i, valid] = True
509
+ logits = logits.masked_fill(~mask, float("-inf"))
510
+
511
+ # Temperature + filtering
512
+ if temperature != 1.0:
513
+ logits = logits / float(temperature)
514
+ if top_k is not None:
515
+ logits = _top_k_filtering(logits, int(top_k))
516
+ if top_p is not None:
517
+ logits = _top_p_filtering(logits, float(top_p))
518
+
519
+ probs = F.softmax(logits, dim=-1)
520
+ next_tok = torch.multinomial(probs, num_samples=1) # [B,1]
521
+
522
+ if control_mode == "variable":
523
+ # Stop sequences at EOS
524
+ eos_mask = (next_tok.squeeze(-1) == self._eos_id)
525
+ finished = finished | eos_mask
526
+
527
+ input_ids = torch.cat([input_ids, next_tok], dim=1)
528
+
529
+ if return_intermediate:
530
+ intermediate.append(input_ids.clone())
531
+
532
+ # If all sequences are finished, we're done.
533
+ if finished.all():
534
+ break
535
+
536
+ # Incremental decode: compute logits for next step and update KV cache
537
+ pos_offset = int(prefill_len.max().item()) + input_ids.size(1) - 1 # use max offset for shared RoPE cache
538
+ out_inc = self.model(
539
+ codon_ids=next_tok,
540
+ cond=None,
541
+ return_dict=True,
542
+ use_cache=True,
543
+ past_kv=kv,
544
+ position_offset=pos_offset,
545
+ )
546
+ kv = out_inc.get("present_kv")
547
+ logits = out_inc.get("next_logits")
548
+ assert kv is not None and logits is not None
549
+
550
+ # Build final DNA strings, dropping specials and any PADs we added
551
+ output_token_rows: List[List[int]] = []
552
+ for row in input_ids.tolist():
553
+ toks: List[int] = []
554
+ for t in row:
555
+ if t == self._pad_id:
556
+ continue
557
+ if t == self._eos_id:
558
+ break # variable mode terminator
559
+ if t >= self._num_special and t < self.V:
560
+ toks.append(int(t))
561
+ if control_mode == "fixed":
562
+ # In fixed mode we *intended* T_codons; if capacity cut us short, it's fine.
563
+ toks = toks[:T_codons]
564
+ output_token_rows.append(toks)
565
+
566
+ sequences = [self.tokenizer.decode_codon_seq(row) for row in output_token_rows]
567
+
568
+ # Pad variable-length rows for input_ids to avoid tensor construction errors when
569
+ # some samples are capacity-truncated in fixed mode.
570
+ max_len = max((len(r) for r in output_token_rows), default=0)
571
+ if max_len > 0:
572
+ ids_padded = torch.full(
573
+ (len(output_token_rows), max_len),
574
+ self._pad_id,
575
+ device=self.device,
576
+ dtype=torch.long,
577
+ )
578
+ for i, row in enumerate(output_token_rows):
579
+ if len(row) > 0:
580
+ ids_padded[i, : len(row)] = torch.tensor(row, device=self.device, dtype=torch.long)
581
+ else:
582
+ ids_padded = torch.empty((len(output_token_rows), 0), device=self.device, dtype=torch.long)
583
+
584
+ result: Dict[str, Union[List[str], torch.Tensor, List[bool]]] = {
585
+ "sequences": sequences,
586
+ "input_ids": ids_padded,
587
+ "capacity_truncated": capacity_truncated.detach().bool().tolist(),
588
+ }
589
+ if return_intermediate:
590
+ result["intermediate_states"] = intermediate # list[Tensor], length = steps actually taken
591
+ return result
592
+
593
+ # ----------------------------
594
+ # Qwen embedding (inline; no separate module)
595
+ # ----------------------------
596
+ def _ensure_qwen_loaded(self):
597
+ if self._qwen_tokenizer is not None and self._qwen_model is not None:
598
+ return
599
+ from transformers import AutoTokenizer, AutoModel
600
+ self._qwen_tokenizer = AutoTokenizer.from_pretrained(
601
+ "Qwen/Qwen3-Embedding-0.6B", trust_remote_code=True, padding_side="left"
602
+ )
603
+ dtype = torch.float16 if self.device.type == "cuda" else torch.float32
604
+ self._qwen_model = AutoModel.from_pretrained(
605
+ "Qwen/Qwen3-Embedding-0.6B", torch_dtype=dtype, trust_remote_code=True
606
+ ).to(self.device).eval()
607
+
608
+ @staticmethod
609
+ def _last_token_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
610
+ left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
611
+ if left_padding:
612
+ return last_hidden_states[:, -1]
613
+ else:
614
+ sequence_lengths = attention_mask.sum(dim=1) - 1
615
+ batch_size = last_hidden_states.shape[0]
616
+ return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
617
+
618
+ @staticmethod
619
+ def _format_instruct(task: str, query: str) -> str:
620
+ return f"Instruct: {task}\nQuery: {query}"
621
+
622
    @torch.no_grad()
    def _qwen_embed_names(self, names: List[str], pooling: str = "sequence") -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Embed species *names* with the Qwen3 embedding model.

        Each name is looked up in the taxonomy DB (if configured) to obtain a
        richer description, wrapped in the instruct prompt, and embedded.
        Only per-token ("sequence") pooling is implemented here; the *pooling*
        argument is accepted but not otherwise consulted.

        Returns:
            (padded, lens): padded is [B, Lmax, D] float32 on CPU with
            L2-normalized hidden states and zero padding; lens is [B] long
            with each sample's true token count.
        """
        # Load taxonomy DB if provided; any failure silently falls back to the
        # raw species name (best-effort enrichment, not a hard dependency).
        taxonomy_db = None
        if self.taxonomy_db_path:
            try:
                with open(self.taxonomy_db_path, "r") as f:
                    import json
                    taxonomy_db = json.load(f)
            except Exception:
                taxonomy_db = None

        self._ensure_qwen_loaded()
        tokenizer = self._qwen_tokenizer
        model = self._qwen_model
        assert tokenizer is not None and model is not None

        task = (
            "Given a species taxonomy information, generate a biological embedding "
            "representing its taxonomic and evolutionary characteristics"
        )
        texts = [self._format_instruct(task, taxonomy_db.get(s, s) if taxonomy_db else s) for s in names]

        BATCH = int(self.qwen_opts.get("batch_size", 16))
        max_len = int(self.qwen_opts.get("max_length", 512))

        # sequence pooling only
        seqs: List[torch.Tensor] = []
        lens: List[int] = []
        for i in range(0, len(texts), BATCH):
            chunk = texts[i : i + BATCH]
            inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True, max_length=max_len).to(self.device)
            out = model(**inputs)
            # L2-normalize hidden states so downstream dot products are cosine-like.
            h = torch.nn.functional.normalize(out.last_hidden_state, p=2, dim=-1)  # [B,L,D]
            attn = inputs["attention_mask"]
            for j in range(h.size(0)):
                # Keep only the attended prefix of each row (drop padding).
                # NOTE(review): the tokenizer pads on the left, so slicing
                # [:L] assumes right padding here — verify mask layout.
                L = int(attn[j].sum().item())
                seqs.append(h[j, :L, :].float().cpu())
                lens.append(L)
        # Pad to [B,Lmax,D]
        Lmax = max(lens) if lens else 0
        D = seqs[0].size(1) if seqs else 0
        padded = torch.zeros(len(seqs), Lmax, D)
        for i, t in enumerate(seqs):
            padded[i, : t.size(0), :] = t
        return padded, torch.tensor(lens, dtype=torch.long)
668
+
669
+ # ----------------------------
670
+ # Conditioning helper
671
+ # ----------------------------
672
+
673
+ # (Kept minimal. Species embeddings are prepared inline in sample().)
674
+
675
+
676
+ # ----------------------------
677
+ # Convenience function
678
+ # ----------------------------
679
+
680
def sample_sequences(
    model_path: str,
    num_sequences: int = 10,
    sequence_length: int = 100,
    species: Optional[Union[str, List[str]]] = None,
    protein_sequence: Optional[Union[str, List[str]]] = None,
    **kwargs
) -> List[str]:
    """One-shot convenience wrapper: build a CodonSampler from *model_path*,
    sample, and return only the decoded DNA strings.

    Extra keyword arguments are forwarded to CodonSampler.sample().
    """
    result = CodonSampler(model_path).sample(
        num_sequences=num_sequences,
        sequence_length=sequence_length,
        species=species,
        protein_sequences=protein_sequence,
        **kwargs
    )
    return result["sequences"]  # type: ignore[return-value]
src/tokenizer.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/tokenizer.py
2
+ """
3
+ Codon tokenizer: 3-mer tokens + 4 special tokens.
4
+
5
+ No frameworks, no inheritance chains. Just:
6
+ - encode_codon_seq("ATG...") -> [ids...] (appends EOS outside, not here)
7
+ - decode_codon_seq([ids...]) -> "ATG..."
8
+ - save_vocabulary(dir) / from_pretrained(dir) for reproducible runs
9
+
10
+ Special IDs are fixed and contiguous from 0:
11
+ pad=0, unk=1, bos=2, eos=3
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import json
17
+ import os
18
+ from dataclasses import dataclass
19
+ from pathlib import Path
20
+ from typing import Dict, List, Optional, Tuple, Any
21
+
22
+
23
+ # ------------------------------
24
+ # Special token ids
25
+ # ------------------------------
26
+
27
@dataclass(frozen=True)
class SpecialIds:
    """Fixed ids of the four special tokens (contiguous from 0)."""

    pad: int = 0
    unk: int = 1
    bos: int = 2
    eos: int = 3

    def to_dict(self) -> Dict[str, int]:
        """Return the special ids as a plain ``{name: id}`` mapping."""
        return {name: getattr(self, name) for name in ("pad", "unk", "bos", "eos")}
36
+
37
+
38
+ # ------------------------------
39
+ # Tokenizer
40
+ # ------------------------------
41
+
42
class CodonTokenizer:
    """Minimal tokenizer for codon (DNA 3-mer) sequences.

    Vocabulary layout is fixed and deterministic:
      - 4 special tokens at ids [0..3]: pad, unk, bos, eos
      - 64 codons at ids [4..67], bases enumerated in order A, C, G, T

    ``encode_codon_seq`` / ``decode_codon_seq`` convert between DNA strings
    and id lists; BOS/EOS are never added here.  ``save_vocabulary`` /
    ``from_pretrained`` round-trip the vocab through JSON so ids stay stable
    across runs.
    """

    __slots__ = (
        "codons",
        "_special_token_str",
        "vocab",
        "ids_to_tokens",
        "_special_ids",
        "_num_special_tokens",
        "_genetic_code",
        "_codon2aa_char",
        "_aa2codons_char",
    )

    def __init__(
        self,
        pad_token: str = "<pad>",
        unk_token: str = "<unk>",
        bos_token: str = "<bos>",
        eos_token: str = "<stop>",  # human-readable; id is still 3
        **_: Any,  # ignore junk kwargs – we don't play framework games
    ) -> None:
        # 64 codons, enumerated A<C<G<T so ids are deterministic.
        bases = ("A", "C", "G", "T")
        self.codons: List[str] = [a + b + c for a in bases for b in bases for c in bases]

        # Specials come first, contiguous from 0.
        special_tokens = [pad_token, unk_token, bos_token, eos_token]
        self._special_token_str = {"pad": pad_token, "unk": unk_token, "bos": bos_token, "eos": eos_token}

        # vocab: specials [0..3], then 64 codons [4..67].
        # Each new token simply gets the next free id.
        self.vocab: Dict[str, int] = {}
        for tok in special_tokens:
            self.vocab[tok] = len(self.vocab)
        for codon in self.codons:
            self.vocab[codon] = len(self.vocab)

        # reverse map id -> token string
        self.ids_to_tokens: Dict[int, str] = {v: k for k, v in self.vocab.items()}

        # fixed special ids (resolved from the strings so renames stay safe)
        self._special_ids = SpecialIds(
            pad=self.vocab[pad_token],
            unk=self.vocab[unk_token],
            bos=self.vocab[bos_token],
            eos=self.vocab[eos_token],
        )
        self._num_special_tokens = len(special_tokens)

        # Standard genetic code: codon string -> single-letter amino acid
        # ('*' marks stop codons).
        self._genetic_code: Dict[str, str] = {
            "TTT": "F", "TTC": "F", "TTA": "L", "TTG": "L",
            "TCT": "S", "TCC": "S", "TCA": "S", "TCG": "S",
            "TAT": "Y", "TAC": "Y", "TAA": "*", "TAG": "*",
            "TGT": "C", "TGC": "C", "TGA": "*", "TGG": "W",
            "CTT": "L", "CTC": "L", "CTA": "L", "CTG": "L",
            "CCT": "P", "CCC": "P", "CCA": "P", "CCG": "P",
            "CAT": "H", "CAC": "H", "CAA": "Q", "CAG": "Q",
            "CGT": "R", "CGC": "R", "CGA": "R", "CGG": "R",
            "ATT": "I", "ATC": "I", "ATA": "I", "ATG": "M",
            "ACT": "T", "ACC": "T", "ACA": "T", "ACG": "T",
            "AAT": "N", "AAC": "N", "AAA": "K", "AAG": "K",
            "AGT": "S", "AGC": "S", "AGA": "R", "AGG": "R",
            "GTT": "V", "GTC": "V", "GTA": "V", "GTG": "V",
            "GCT": "A", "GCC": "A", "GCA": "A", "GCG": "A",
            "GAT": "D", "GAC": "D", "GAA": "E", "GAG": "E",
            "GGT": "G", "GGC": "G", "GGA": "G", "GGG": "G",
        }

        # Precompute codon-id <-> amino-acid helper maps.
        self._rebuild_helpers()

        # sanity: specials must be contiguous 0..3
        ids = list(self._special_ids.to_dict().values())
        if sorted(ids) != list(range(self._num_special_tokens)):
            raise AssertionError("Special token ids must be contiguous starting at 0")

    # ---------- properties ----------
    @property
    def vocab_size(self) -> int:
        """Total vocabulary size (specials + 64 codons = 68 by default)."""
        return len(self.vocab)

    @property
    def special_ids(self) -> SpecialIds:
        return self._special_ids

    @property
    def num_special_tokens(self) -> int:
        return self._num_special_tokens

    @property
    def pad_token_id(self) -> int:
        return self._special_ids.pad

    @property
    def unk_token_id(self) -> int:
        return self._special_ids.unk

    @property
    def bos_token_id(self) -> int:
        return self._special_ids.bos

    @property
    def eos_token_id(self) -> int:
        return self._special_ids.eos

    # ---------- core API ----------
    def encode_codon_seq(self, seq: str, validate: bool = True) -> List[int]:
        """
        Map DNA (ACGT)^3N to 3-mer ids. We don't append BOS/EOS here.

        Raises ValueError when *validate* is True and the sequence length is
        not a multiple of 3 or contains non-ACGT characters.  Unknown codons
        (only possible with validate=False) map to the unk id.
        """
        s = seq.upper()
        if validate:
            if len(s) % 3 != 0:
                raise ValueError(f"Sequence length {len(s)} not divisible by 3")
            if not _is_acgt(s):
                raise ValueError("Sequence contains invalid nucleotides (only ACGT supported)")
        out: List[int] = []
        # Fast Python slice loop – good enough. NumPy won't help for tiny strings.
        for i in range(0, len(s), 3):
            codon = s[i : i + 3]
            out.append(self.vocab.get(codon, self._special_ids.unk))
        return out

    def decode_codon_seq(self, token_ids: List[int]) -> str:
        """
        Convert codon ids (>= num_special_tokens) back to DNA string.
        Special ids are ignored unless they collide (they don't).
        """
        parts: List[str] = []
        nst = self._num_special_tokens
        for tid in token_ids:
            if tid >= nst:
                tok = self.ids_to_tokens.get(tid)
                if tok is not None:  # should always be a codon
                    parts.append(tok)
        return "".join(parts)

    def decode(self, token_ids: List[int], skip_special_tokens: bool = True, **_: Any) -> str:
        """HF-style decode; kept for API parity with older call sites."""
        if skip_special_tokens:
            token_ids = [t for t in token_ids if t >= self._num_special_tokens]
        return self.decode_codon_seq(token_ids)

    # ---------- misc helpers ----------
    def codon_vocab(self) -> Dict[str, int]:
        """Codon-only slice of the vocabulary (no specials)."""
        return {c: self.vocab[c] for c in self.codons}

    def codon2aa_char_map(self) -> Dict[int, str]:
        """Copy of the codon-id -> amino-acid-letter map."""
        return dict(self._codon2aa_char)

    def aa2codons_char_map(self) -> Dict[str, List[int]]:
        """Copy of the amino-acid-letter -> codon-ids map."""
        return {k: v[:] for k, v in self._aa2codons_char.items()}

    def aa_to_codon_length(self, aa_seq: str) -> int:
        # You don't count stop unless it's explicitly there.
        return len(aa_seq)

    # HF compatibility stubs (a few call sites expect these)
    def _tokenize(self, text: str) -> List[str]:
        if len(text) % 3 != 0:
            raise ValueError(f"Text length {len(text)} not divisible by 3")
        return [text[i : i + 3] for i in range(0, len(text), 3)]

    def _convert_token_to_id(self, token: str) -> int:
        return self.vocab.get(token, self._special_ids.unk)

    def _convert_id_to_token(self, index: int) -> str:
        return self.ids_to_tokens.get(index, self._special_token_str["unk"])

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        return "".join(tokens)

    def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]:
        return token_ids_0

    def create_token_type_ids_from_sequences(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]:
        return [0] * len(token_ids_0)

    # ---------- persistence ----------
    def get_vocab(self) -> Dict[str, int]:
        return dict(self.vocab)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Save to JSON with both vocab and special token strings so we can
        reconstruct IDs exactly. Deterministic and stable.

        Returns a 1-tuple with the written file path (HF convention).
        """
        os.makedirs(save_directory, exist_ok=True)
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )
        payload = {
            "vocab": self.vocab,
            "special_token_str": self._special_token_str,
        }
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False, indent=2, sort_keys=True)
        return (vocab_file,)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs: Any) -> "CodonTokenizer":
        """
        Load from a directory containing vocab.json produced by save_vocabulary().
        We rebuild the SpecialIds from the saved token strings to keep IDs stable.
        Falls back to a default tokenizer when no vocab.json exists.
        """
        vocab_path = Path(pretrained_model_name_or_path) / "vocab.json"
        tok = cls(**kwargs)  # default structure; we'll overwrite below
        if not vocab_path.exists():
            # If nothing to load, return defaults. It keeps the rest of the code happy.
            return tok

        with open(vocab_path, "r", encoding="utf-8") as f:
            save_data = json.load(f)

        if not isinstance(save_data, dict) or "vocab" not in save_data:
            # Old, dumber format: the whole file was the vocab dict
            vocab = save_data
            special_token_str = tok._special_token_str
        else:
            vocab = save_data["vocab"]
            special_token_str = save_data.get("special_token_str", tok._special_token_str)

        # rebuild maps
        tok.vocab = {str(k): int(v) for k, v in vocab.items()}
        tok.ids_to_tokens = {int(v): str(k) for k, v in tok.vocab.items()}

        # reconcile special strings → ids
        if isinstance(special_token_str, dict):
            tok._special_token_str.update({k: v for k, v in special_token_str.items() if k in ("pad", "unk", "bos", "eos")})

        def _id_for(name: str, default_val: int) -> int:
            sym = tok._special_token_str[name]
            return int(tok.vocab.get(sym, default_val))

        tok._special_ids = SpecialIds(
            pad=_id_for("pad", 0),
            unk=_id_for("unk", 1),
            bos=_id_for("bos", 2),
            eos=_id_for("eos", 3),
        )

        # Figure out how many specials to reserve. If the saved mapping had extra junk,
        # we still preserve a contiguous prefix if present. Otherwise default to 4.
        ids = [tok._special_ids.pad, tok._special_ids.unk, tok._special_ids.bos, tok._special_ids.eos]
        m = max(ids)
        tok._num_special_tokens = m + 1 if ids == list(range(m + 1)) else 4

        # Rebuild genetic helpers (cheap)
        tok._rebuild_helpers()
        return tok

    # internal: rebuild helper maps after load
    def _rebuild_helpers(self) -> None:
        """Recompute codon<->amino-acid maps from the current vocab."""
        self._codon2aa_char = {}
        self._aa2codons_char = {ch: [] for ch in "ACDEFGHIKLMNPQRSTVWY*"}
        for codon in self.codons:
            cid = self.vocab.get(codon)
            if cid is None:
                # A loaded vocab may be partial/malformed; skip rather than KeyError.
                continue
            aa = self._genetic_code.get(codon, "X")
            self._codon2aa_char[cid] = aa
            if aa in self._aa2codons_char:
                self._aa2codons_char[aa].append(cid)
313
+
314
+
315
+ # ------------------------------
316
+ # small helpers
317
+ # ------------------------------
318
+
319
+ def _is_acgt(s: str) -> bool:
320
+ # Faster than regex for short strings.
321
+ for ch in s:
322
+ if ch not in ("A", "C", "G", "T"):
323
+ return False
324
+ return True
src/trainer.py ADDED
@@ -0,0 +1,1230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/trainer.py
2
+ """
3
+ FSDP trainer for CodonGPT.
4
+ No frameworks, no sugar. The model computes its own loss.
5
+
6
+ Batch invariants:
7
+ - codon_ids [B, T] (right-padded; EOS already in-sequence)
8
+ - species_ids [B] (SpeciesEmbeddingStore provides fixed-size or sequence embeddings)
9
+ - protein_seqs: list[str] (ESM tokenization happens inside the model)
10
+
11
+ Rules:
12
+ - If your loader is IterableDataset, you MUST set args.max_steps > 0. We don't guess.
13
+ - If you want epoch-based, use a sized dataset; we call len(dataloader).
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import os
19
+ import json
20
+ import math
21
+ import re
22
+ import shutil
23
+ import logging
24
+ import time
25
+ from dataclasses import dataclass
26
+ import datetime
27
+ import warnings
28
+ import importlib.util
29
+ import inspect
30
+ from typing import Any, Callable, Dict, Optional, Tuple, List
31
+ from tqdm import tqdm
32
+ import torch
33
+ import torch.nn as nn
34
+ import torch.distributed as dist
35
+ from torch.utils.data import DataLoader, IterableDataset
36
+
37
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
38
+ from torch.distributed.fsdp import (
39
+ ShardingStrategy,
40
+ MixedPrecision,
41
+ StateDictType,
42
+ FullStateDictConfig,
43
+ FullOptimStateDictConfig,
44
+ )
45
+ from safetensors.torch import save_file, load_file
46
+ import wandb
47
+
48
+ logger = logging.getLogger(__name__)
49
+
50
+
51
+ # ------------------------------
52
+ # Args
53
+ # ------------------------------
54
+
55
@dataclass
class TrainingArguments:
    """Flat bag of hyperparameters consumed by Trainer.

    Grouped by concern: checkpoint output, schedule shape, optimizer,
    data loading, precision/distribution, and logging/eval cadence.
    Defaults mirror the values used by the training scripts.
    """

    # Output / checkpointing
    output_dir: str = "checkpoints"
    save_steps: int = 1000
    save_total_limit: int = 3
    save_safetensors: bool = True
    # Retention policy knobs (0 disables each behavior)
    ckpt_recent_window_steps: int = 0
    ckpt_recent_interval: int = 0
    ckpt_archive_interval: int = 0

    # Schedule
    num_train_epochs: int = 1
    max_steps: int = -1  # required (>0) when the train dataset is an IterableDataset
    gradient_accumulation_steps: int = 1
    warmup_ratio: float = 0.0
    lr_scheduler_type: str = "cosine"  # "linear" | "cosine" | "constant"
    # For streaming datasets: if max_steps<0 and steps_per_epoch>0, shape schedule using
    # total_steps = num_train_epochs * steps_per_epoch
    steps_per_epoch: int = 0

    # Optimizer (AdamW)
    learning_rate: float = 5e-4
    weight_decay: float = 0.0
    adam_beta1: float = 0.9
    adam_beta2: float = 0.95
    max_grad_norm: float = 1.0  # gradient clipping threshold (<=0 disables)

    # Data
    per_device_train_batch_size: int = 8
    per_device_eval_batch_size: int = 8
    dataloader_num_workers: int = 0

    # Precision / dist
    fp16: bool = False
    bf16: bool = False
    fsdp: Optional[str] = None  # "full_shard" or None
    gradient_checkpointing: bool = False

    # Global hard cap (prefix + start + codon)
    max_length: int = 4096

    # ESM (metadata only; model owns ESM)
    esm_model_name: str = "esmc_300m"
    esm_device: str = "cuda"
    esm_dtype: str = "bf16"

    # Logging / eval
    logging_steps: int = 100
    eval_steps: int = 0  # streaming eval: limit number of eval batches when eval dataset is Iterable
    eval_interval: int = 0  # run evaluation every N optimizer steps (0 disables)
    override_lr_on_resume: bool = False
    # Minimal data stream resume cursor (stores total samples yielded so far for train dataset).
    # When provided, we load 'skip_samples' from this JSON at start and set the dataset
    # to skip exactly that many samples on resume. We also update the file in _save_checkpoint().
    data_cursor_path: Optional[str] = None
111
+
112
+
113
+ # ------------------------------
114
+ # Trainer
115
+ # ------------------------------
116
+
117
+ class Trainer:
118
    def __init__(
        self,
        model: nn.Module,
        args: TrainingArguments,
        data_collator: Optional[Callable] = None,
        train_dataset: Optional[Any] = None,
        eval_dataset: Optional[Any] = None,
        tokenizer: Optional[Any] = None,
        model_init: Optional[Callable[[], nn.Module]] = None,
        compute_metrics: Optional[Callable] = None,
        callbacks: Optional[list] = None,
        optimizers: Tuple[Optional[torch.optim.Optimizer], Optional[Any]] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable] = None,
        species_store=None,
        resume_from_checkpoint: Optional[str] = None,
    ):
        """Set up device placement, (optional) FSDP wrapping, and AMP state.

        Note: data_collator, train_dataset, eval_dataset, model_init,
        compute_metrics, callbacks, and preprocess_logits_for_metrics are
        accepted for HF-Trainer signature parity but are not stored by this
        constructor. Dataloaders come in via attach_dataloaders().
        Resuming from *resume_from_checkpoint* is deferred until train() so
        the scheduler can be shaped after dataloaders are attached.
        """
        self.model = model
        self.args = args
        self.tokenizer = tokenizer
        # (optimizer, lr_scheduler) pair; either may be None and will be
        # created lazily in _create_optimizer_and_scheduler().
        self.optimizer = optimizers[0]
        self.lr_scheduler = optimizers[1]
        self.species_store = species_store

        self.train_dataloader: Optional[DataLoader] = None
        self.eval_dataloader: Optional[DataLoader] = None

        # Device (robust local rank resolution)
        self.local_rank = 0
        if torch.cuda.is_available():
            lr_env = os.environ.get("LOCAL_RANK")
            if lr_env is not None:
                self.local_rank = int(lr_env)
            else:
                # No LOCAL_RANK (e.g. plain srun): derive it from global RANK
                # modulo the number of visible GPUs.
                r = int(os.environ.get("RANK", "0"))
                ng = max(1, torch.cuda.device_count())
                self.local_rank = (r % ng)
            self.device = torch.device(f"cuda:{self.local_rank}")
            torch.cuda.set_device(self.device)
            cd = torch.cuda.current_device()
            nm = torch.cuda.get_device_name(cd)
            logger.info(
                f"[dist] RANK={os.environ.get('RANK')} LOCAL_RANK={os.environ.get('LOCAL_RANK')} WORLD_SIZE={os.environ.get('WORLD_SIZE')} "
                f"cuda.count={torch.cuda.device_count()} select={self.device} current={cd} name={nm}"
            )
        else:
            self.device = torch.device("cpu")

        # Gradient checkpointing toggle (model owns the flag)
        base = self._unwrap(self.model)
        if self.args.gradient_checkpointing and hasattr(base, "gradient_checkpointing"):
            base.gradient_checkpointing = True

        # FSDP or single GPU
        if self.args.fsdp:
            self._setup_fsdp()
        else:
            self.model = self.model.to(self.device)

        # AMP setup (use torch.amp APIs; GradScaler on CUDA only).
        # bf16 does not need loss scaling, so the scaler is enabled for fp16 only.
        self._use_amp = (self.device.type == "cuda") and (self.args.fp16 or self.args.bf16)
        self._amp_dtype = torch.float16 if self.args.fp16 else (torch.bfloat16 if self.args.bf16 else None)
        use_cuda = (self.device.type == "cuda")
        self._scaler = torch.amp.GradScaler(device="cuda", enabled=(use_cuda and self.args.fp16))

        # Minimal training state; updated as training progresses.
        self.state = {"epoch": 0, "global_step": 0}

        # Defer resume until after dataloaders are attached so scheduler can be shaped.
        self._resume_path: Optional[str] = resume_from_checkpoint
186
+
187
+ # ---- dataloaders ----
188
    def attach_dataloaders(self, train_loader: DataLoader, eval_loader: Optional[DataLoader] = None):
        """Attach the train (and optional eval) dataloaders and apply the resume cursor.

        If args.data_cursor_path points at an existing JSON file and the train
        dataset implements set_resume_skip(), the saved sample count is split
        across the current world size so each rank fast-forwards its share of
        the stream.
        """
        # Your dataset should handle sharding. We don't wrap with DistributedSampler here.
        self.train_dataloader = train_loader
        self.eval_dataloader = eval_loader
        # Apply minimal resume cursor to the training dataset if configured
        p = getattr(self.args, "data_cursor_path", None)
        if p and os.path.exists(p):
            with open(p, "r") as f:
                js = json.load(f)
            ds = getattr(self.train_dataloader, "dataset", None)
            if hasattr(ds, "set_resume_skip"):
                distributed = dist.is_available() and dist.is_initialized()
                world = dist.get_world_size() if distributed else 1
                rank = dist.get_rank() if distributed else 0

                # Prefer the total cursor and split evenly across current world size.
                # If total is missing, sum any saved per_rank list.
                total: int = 0
                if isinstance(js, dict):
                    try:
                        total = int(js.get("skip_samples", 0) or 0)
                    except Exception:
                        # Malformed cursor value: treat as "no resume".
                        total = 0
                    if total <= 0:
                        raw = js.get("per_rank")
                        if isinstance(raw, list) and raw:
                            try:
                                total = int(sum(int(x) for x in raw))
                            except Exception:
                                total = 0

                if total > 0:
                    if distributed:
                        # Even split with the remainder going to the lowest ranks,
                        # so the per-rank skips sum exactly to `total`.
                        per = total // max(world, 1)
                        rem = total % max(world, 1)
                        n_rank = per + (1 if rank < rem else 0)
                        ds.set_resume_skip(int(n_rank))
                        if self._is_main():
                            logger.info(
                                "resume cursor: total=%s split across world=%s → rank=%s skip=%s",
                                total, world, rank, n_rank,
                            )
                    else:
                        ds.set_resume_skip(int(total))
                        if self._is_main():
                            logger.info("resume cursor: total=%s (single-process) skip=%s", total, total)
234
+
235
+
236
+ # ---- optim + scheduler ----
237
    def _create_optimizer_and_scheduler(self):
        """Build AdamW (if not supplied) and shape the LR schedule.

        Weight decay is applied only to non-bias, non-norm parameters.
        Schedule shaping:
          - IterableDataset + max_steps>0: total_steps = max_steps
          - IterableDataset + steps_per_epoch>0: total_steps = epochs * steps_per_epoch
          - IterableDataset, neither known: constant LR (no pre-shaped schedule)
          - sized dataloader: derived from len(dataloader) // grad_accum
        NOTE(review): any lr_scheduler passed in via `optimizers` is
        overwritten here — confirm that is intended.
        """
        if self.optimizer is None:
            # Split params into decay / no-decay groups by name heuristics.
            decay, no_decay = [], []
            for n, p in self._unwrap(self.model).named_parameters():
                if not p.requires_grad:
                    continue
                if n.endswith("bias") or "norm" in n.lower() or "ln_" in n.lower():
                    no_decay.append(p)
                else:
                    decay.append(p)

            opt_kwargs = dict(
                lr=self.args.learning_rate,
                betas=(self.args.adam_beta1, self.args.adam_beta2),
            )
            params = [
                {"params": decay, "weight_decay": self.args.weight_decay},
                {"params": no_decay, "weight_decay": 0.0},
            ]
            # Opt into the fused AdamW kernel only when this torch build supports it.
            sig_adamw = inspect.signature(torch.optim.AdamW)
            if torch.cuda.is_available() and "fused" in sig_adamw.parameters:
                opt_kwargs["fused"] = True  # type: ignore[assignment]
            self.optimizer = torch.optim.AdamW(params, **opt_kwargs)
            # Report fused/foreach settings (rank0 only)
            if self._is_main():
                fused_flag = None
                foreach_flag = None
                if hasattr(self.optimizer, "defaults"):
                    fused_flag = self.optimizer.defaults.get("fused")
                    foreach_flag = self.optimizer.defaults.get("foreach")
                logger.info(f"AdamW configured: fused={fused_flag} foreach={foreach_flag}")

        # total steps and schedule shape
        ds = getattr(self.train_dataloader, "dataset", None)
        ga = max(1, self.args.gradient_accumulation_steps)
        if isinstance(ds, IterableDataset):
            if self.args.max_steps > 0:
                # Use max_steps to shape the scheduler; allow multiple epochs to re-iterate the stream
                steps_per_epoch = self.args.max_steps
                total_steps = self.args.max_steps
            elif getattr(self.args, "steps_per_epoch", 0) and self.args.steps_per_epoch > 0:
                # steps_per_epoch is already expressed in optimizer steps (train.py accounts for grad_accum)
                steps_per_epoch = max(1, int(self.args.steps_per_epoch))
                total_steps = max(1, self.args.num_train_epochs) * steps_per_epoch
            else:
                # Unknown epoch size; use constant LR without pre-shaped schedule
                self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lambda step: 1.0)
                return
        else:
            # sized dataloader: len(dataloader) is number of batches
            steps_per_epoch = max(len(self.train_dataloader) // ga, 1)
            total_steps = self.args.max_steps if self.args.max_steps > 0 else self.args.num_train_epochs * steps_per_epoch

        warmup = int(self.args.warmup_ratio * total_steps)

        if self.args.lr_scheduler_type == "constant":
            self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lambda step: 1.0)
            return

        def lrs_lambda(step: int) -> float:
            # Linear warmup (floored at 1e-6 so step 0 isn't an exact zero LR),
            # then linear or cosine decay over the remaining steps.
            if step < warmup:
                return max(float(step) / max(warmup, 1), 1e-6)
            t = (step - warmup) / max(total_steps - warmup, 1)
            if self.args.lr_scheduler_type == "linear":
                return max(1.0 - t, 0.0)
            # cosine default
            return 0.5 * (1.0 + math.cos(math.pi * t))

        self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lrs_lambda)
306
+
307
+ # ---- training ----
308
+ def train(self) -> Dict[str, float]:
309
+ assert self.train_dataloader is not None, "Call attach_dataloaders() first"
310
+ # If a resume path was provided, load it now (dataloaders are attached).
311
+ if getattr(self, "_resume_path", None):
312
+ self._resume_from(self._resume_path) # loads model/optimizer/scheduler/state
313
+ self._resume_path = None
314
+
315
+ if self.optimizer is None:
316
+ self._create_optimizer_and_scheduler()
317
+
318
+ ds = self.train_dataloader.dataset
319
+
320
+ # Exact step budget for streaming datasets when max_steps<0 and steps_per_epoch>0
321
+ target_total_steps: Optional[int] = None
322
+ if isinstance(ds, IterableDataset) and int(self.args.max_steps) < 0:
323
+ spe = int(getattr(self.args, "steps_per_epoch", 0) or 0)
324
+ if spe > 0:
325
+ target_total_steps = max(1, int(self.args.num_train_epochs)) * spe
326
+
327
+ # Determine total steps for progress bar
328
+ progress_total: Optional[int] = None
329
+ if int(self.args.max_steps) > 0:
330
+ progress_total = int(self.args.max_steps)
331
+ elif isinstance(ds, IterableDataset):
332
+ if target_total_steps is not None:
333
+ progress_total = target_total_steps
334
+ else:
335
+ ga = max(1, self.args.gradient_accumulation_steps)
336
+ steps_per_epoch = max(len(self.train_dataloader) // ga, 1)
337
+ progress_total = max(1, int(self.args.num_train_epochs)) * steps_per_epoch
338
+
339
+ # Initialize Weights & Biases (rank0 only)
340
+ if self._is_main():
341
+ if not hasattr(self, "_wandb"):
342
+ proj = os.environ.get("WANDB_PROJECT", "codongpt")
343
+ name = os.environ.get("WANDB_NAME")
344
+ run_id = os.environ.get("WANDB_RUN_ID")
345
+ resume = os.environ.get("WANDB_RESUME")
346
+ wandb_dir = os.environ.get("WANDB_DIR")
347
+ world_size = dist.get_world_size() if dist.is_available() and dist.is_initialized() else int(os.environ.get("WORLD_SIZE", "1"))
348
+ init_kwargs = {
349
+ "project": proj,
350
+ "name": name,
351
+ "config": {
352
+ "lr": self.args.learning_rate,
353
+ "warmup_ratio": self.args.warmup_ratio,
354
+ "scheduler": self.args.lr_scheduler_type,
355
+ "batch_size": self.args.per_device_train_batch_size,
356
+ "eval_batch_size": self.args.per_device_eval_batch_size,
357
+ "grad_accum": self.args.gradient_accumulation_steps,
358
+ "effective_global_batch": self.args.per_device_train_batch_size * max(1, world_size) * max(1, self.args.gradient_accumulation_steps),
359
+ "epochs": self.args.num_train_epochs,
360
+ "steps_per_epoch": getattr(self.args, "steps_per_epoch", 0),
361
+ "max_steps": self.args.max_steps,
362
+ "weight_decay": self.args.weight_decay,
363
+ "world_size": world_size,
364
+ "output_dir": self.args.output_dir,
365
+ "fsdp": self.args.fsdp,
366
+ "bf16": self.args.bf16,
367
+ "fp16": self.args.fp16,
368
+ },
369
+ }
370
+ if run_id:
371
+ init_kwargs["id"] = run_id
372
+ if resume:
373
+ init_kwargs["resume"] = resume
374
+ if wandb_dir:
375
+ init_kwargs["dir"] = wandb_dir
376
+ self._wandb = wandb.init(**init_kwargs)
377
+
378
+ self.model.train()
379
+ grad_accum = max(1, self.args.gradient_accumulation_steps)
380
+ progress = None
381
+ if self._is_main() and progress_total is not None and progress_total > 0:
382
+ progress = tqdm(total=progress_total, initial=int(self.state["global_step"]), desc="Train", dynamic_ncols=True)
383
+ if self.device.type == "cuda" and torch.cuda.is_available():
384
+ torch.cuda.reset_peak_memory_stats(self.device)
385
+ world_size = dist.get_world_size() if dist.is_available() and dist.is_initialized() else int(os.environ.get("WORLD_SIZE", "1"))
386
+ seqs_per_optimizer_step = (
387
+ int(self.args.per_device_train_batch_size) * max(1, world_size) * grad_accum
388
+ )
389
+ log_window_start = time.perf_counter()
390
+ log_window_optimizer_steps = 0
391
+
392
+ for epoch in range(self.state["epoch"], max(1, self.args.num_train_epochs)):
393
+ self.state["epoch"] = epoch
394
+ running_loss = 0.0
395
+ running_count = 0
396
+
397
+ train_iter = iter(self.train_dataloader)
398
+ step = 0
399
+ batches_this_epoch = 0
400
+ optimizer_steps_this_epoch = 0
401
+ # If this is a streaming dataset with a shaped schedule, enforce a per-epoch optimizer step budget
402
+ enforce_budget = False
403
+ epoch_budget = None
404
+ ds = self.train_dataloader.dataset
405
+ if isinstance(ds, IterableDataset):
406
+ spe = int(getattr(self.args, "steps_per_epoch", 0) or 0)
407
+ if spe > 0:
408
+ enforce_budget = True
409
+ epoch_budget = int(spe)
410
+
411
+ refill_attempts = 0
412
+ max_refills = 64 # avoids infinite loops when dataset is empty
413
+
414
+ while True:
415
+ batch, has_batch, local_has_batch = self._next_batch_sync(train_iter)
416
+ if not has_batch:
417
+ # If budget-enforced, attempt to refill the iterator and continue until budget is met.
418
+ if enforce_budget and (epoch_budget is not None) and (optimizer_steps_this_epoch < epoch_budget):
419
+ if local_has_batch and self._is_main():
420
+ logger.warning("Rank retained extra batch while peers exhausted stream; dropping to stay in sync")
421
+ self._barrier()
422
+ train_iter = iter(self.train_dataloader)
423
+ refill_attempts += 1
424
+ if refill_attempts > max_refills:
425
+ if self._is_main():
426
+ logger.warning(
427
+ "Exceeded max refills for epoch %s (steps %s/%s). Ending epoch early.",
428
+ epoch, optimizer_steps_this_epoch, epoch_budget,
429
+ )
430
+ break
431
+ continue
432
+ else:
433
+ if local_has_batch and self._is_main():
434
+ logger.warning("Rank retained extra batch while peers exhausted stream; dropping to stay in sync")
435
+ break
436
+
437
+ batch = self._prepare_batch(batch)
438
+ batches_this_epoch += 1
439
+
440
+ codon_ids = batch["codon_ids"].to(self.device)
441
+ input_ids = codon_ids[:, :-1]
442
+ labels = codon_ids[:, :-1]
443
+
444
+ # Mask PAD/EOS in labels
445
+ pad_id = int(self.tokenizer.pad_token_id) if self.tokenizer is not None else 0
446
+ eos_id = int(self.tokenizer.special_ids.eos) if self.tokenizer is not None else -999
447
+ labels = labels.clone()
448
+ labels[labels == pad_id] = -100
449
+ labels[labels == eos_id] = -100
450
+
451
+ cond = self._build_cond(batch)
452
+
453
+ # autocast context
454
+ use_cuda = (self.device.type == "cuda")
455
+ autocast_dtype = self._amp_dtype
456
+ if autocast_dtype is not None and use_cuda:
457
+ ctx = torch.amp.autocast(device_type="cuda", dtype=autocast_dtype)
458
+ else:
459
+ from contextlib import nullcontext
460
+ ctx = nullcontext()
461
+
462
+ with ctx:
463
+ out = self.model(codon_ids=input_ids, cond=cond, labels=labels, return_dict=True)
464
+ loss = out["loss"]
465
+
466
+ if self._scaler.is_enabled():
467
+ self._scaler.scale(loss / grad_accum).backward()
468
+ else:
469
+ (loss / grad_accum).backward()
470
+
471
+ running_loss += float(loss.detach().item())
472
+ running_count += 1
473
+
474
+ do_step = ((step + 1) % grad_accum == 0)
475
+ if do_step:
476
+ # Clip
477
+ if self.args.max_grad_norm and self.args.max_grad_norm > 0:
478
+ if isinstance(self.model, FSDP):
479
+ FSDP.clip_grad_norm_(self.model, self.args.max_grad_norm)
480
+ else:
481
+ if self._scaler.is_enabled():
482
+ self._scaler.unscale_(self.optimizer)
483
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)
484
+
485
+ # Step
486
+ if self._scaler.is_enabled():
487
+ self._scaler.step(self.optimizer)
488
+ self._scaler.update()
489
+ else:
490
+ self.optimizer.step()
491
+ if self.lr_scheduler is not None:
492
+ self.lr_scheduler.step()
493
+ self.optimizer.zero_grad(set_to_none=True)
494
+ self.state["global_step"] += 1
495
+ optimizer_steps_this_epoch += 1
496
+ log_window_optimizer_steps += 1
497
+
498
+ # (wandb) Defer logging to the periodic block below
499
+
500
+ # Log
501
+ should_log = (self.state["global_step"] % max(1, self.args.logging_steps) == 0)
502
+ peak_alloc_gb = 0.0
503
+ peak_reserved_gb = 0.0
504
+ if should_log:
505
+ peak_alloc_gb, peak_reserved_gb = self._max_cuda_peak_gb()
506
+ if self._is_main() and should_log:
507
+ avg = running_loss / max(running_count, 1)
508
+ lr = float(self.optimizer.param_groups[0]["lr"])
509
+ log_epoch = self._epoch_for_logging()
510
+ elapsed = max(time.perf_counter() - log_window_start, 1e-9)
511
+ step_time_s = elapsed / max(log_window_optimizer_steps, 1)
512
+ seq_per_s = (seqs_per_optimizer_step * max(log_window_optimizer_steps, 1)) / elapsed
513
+ msg = f"epoch {log_epoch} step {self.state['global_step']}: loss={avg:.4f} lr={lr:.6g}"
514
+ if isinstance(out, dict):
515
+ pl = out.get("prefix_len")
516
+ pc = out.get("per_cap")
517
+ if pl is not None and pc is not None:
518
+ msg += f" prefix_mean={float(pl.detach().float().mean().item()):.1f} cap_mean={float(pc.detach().float().mean().item()):.1f}"
519
+ msg += (
520
+ f" step_time_s={step_time_s:.3f} seq_per_s={seq_per_s:.1f}"
521
+ f" peak_mem_alloc_gb={peak_alloc_gb:.1f} peak_mem_reserved_gb={peak_reserved_gb:.1f}"
522
+ )
523
+ logger.info(msg)
524
+ if hasattr(self, "_wandb"):
525
+ wandb.log({
526
+ "train/loss": float(avg),
527
+ "train/lr": float(lr),
528
+ "perf/step_time_s": float(step_time_s),
529
+ "perf/seq_per_s": float(seq_per_s),
530
+ "system/peak_mem_alloc_gb": float(peak_alloc_gb),
531
+ "system/peak_mem_reserved_gb": float(peak_reserved_gb),
532
+ }, step=self.state["global_step"])
533
+ running_loss = 0.0
534
+ running_count = 0
535
+ log_window_start = time.perf_counter()
536
+ log_window_optimizer_steps = 0
537
+
538
+ # Update progress bar
539
+ if progress is not None:
540
+ progress.update(1)
541
+
542
+ # Stop when budget is reached for streaming schedule
543
+ if target_total_steps is not None and self.state["global_step"] >= target_total_steps:
544
+ metrics = {"train_loss": running_loss / max(running_count, 1)}
545
+ self._save_checkpoint("final_model")
546
+ self._barrier()
547
+ return metrics
548
+
549
+ # Periodic teacher-forced evaluation on the held-out dataset
550
+ should_eval = (
551
+ self.eval_dataloader is not None and
552
+ self.args.eval_interval > 0 and
553
+ (self.state["global_step"] % self.args.eval_interval == 0)
554
+ )
555
+ if should_eval:
556
+ eval_metrics = self.evaluate()
557
+ if self._is_main():
558
+ el = float(eval_metrics.get("eval_loss", 0.0))
559
+ ea = eval_metrics.get("eval_codon_acc", None)
560
+ aa = eval_metrics.get("eval_aa_acc", None)
561
+ if ea is not None and aa is not None:
562
+ logger.info(f"eval: loss={el:.4f} codon_acc={float(ea):.3f} aa_acc={float(aa):.3f}")
563
+ elif ea is not None:
564
+ logger.info(f"eval: loss={el:.4f} codon_acc={float(ea):.3f}")
565
+ elif aa is not None:
566
+ logger.info(f"eval: loss={el:.4f} aa_acc={float(aa):.3f}")
567
+ else:
568
+ logger.info(f"eval: loss={el:.4f}")
569
+ if hasattr(self, "_wandb"):
570
+ log_payload = {"eval/loss": el}
571
+ if ea is not None:
572
+ log_payload["eval/codon_acc"] = float(ea)
573
+ if aa is not None:
574
+ log_payload["eval/aa_acc"] = float(aa)
575
+ wandb.log(log_payload, step=self.state["global_step"])
576
+
577
+ # Save by step
578
+ if self.args.save_steps > 0 and (self.state["global_step"] % self.args.save_steps == 0):
579
+ self._save_checkpoint(f"checkpoint-{self.state['global_step']}")
580
+
581
+ # Hard horizon for streaming/step-limited runs
582
+ if self.args.max_steps > 0 and self.state["global_step"] >= self.args.max_steps:
583
+ metrics = {"train_loss": running_loss / max(running_count, 1)}
584
+ self._save_checkpoint("final_model")
585
+ self._barrier()
586
+ if progress is not None:
587
+ progress.close()
588
+ return metrics
589
+
590
+ step += 1
591
+
592
+ # If we enforce a per-epoch budget for streaming datasets, end the epoch once it's reached
593
+ if enforce_budget and (epoch_budget is not None) and (optimizer_steps_this_epoch >= epoch_budget):
594
+ break
595
+
596
+ # Epoch summary (rank0 only)
597
+ if self._is_main():
598
+ try:
599
+ eb = int(epoch_budget) if epoch_budget is not None else -1
600
+ except Exception:
601
+ eb = -1
602
+ logger.info(
603
+ "epoch %s completed: optimizer_steps=%s%s",
604
+ self._epoch_for_logging(),
605
+ optimizer_steps_this_epoch,
606
+ (f" / budget {eb}" if eb > 0 else ""),
607
+ )
608
+
609
+ if dist.is_available() and dist.is_initialized():
610
+ gather_device = self.device if self.device.type == "cuda" else torch.device("cpu")
611
+ counts_tensor = torch.tensor(
612
+ [batches_this_epoch, optimizer_steps_this_epoch],
613
+ dtype=torch.long,
614
+ device=gather_device,
615
+ )
616
+ gathered = [torch.zeros_like(counts_tensor) for _ in range(dist.get_world_size())]
617
+ dist.all_gather(gathered, counts_tensor)
618
+ batch_counts = [int(t[0].item()) for t in gathered]
619
+ step_counts = [int(t[1].item()) for t in gathered]
620
+ batch_gap = max(batch_counts) - min(batch_counts)
621
+ step_gap = max(step_counts) - min(step_counts)
622
+ if self._is_main() and (batch_gap > 0 or step_gap > 0):
623
+ logger.warning(
624
+ "Epoch %s imbalance detected across ranks: batches min=%s max=%s, optimizer steps min=%s max=%s",
625
+ epoch,
626
+ min(batch_counts),
627
+ max(batch_counts),
628
+ min(step_counts),
629
+ max(step_counts),
630
+ )
631
+
632
+ # Epoch boundary save for sized datasets
633
+ if not isinstance(ds, IterableDataset):
634
+ self._save_checkpoint(f"epoch-{epoch}")
635
+
636
+ metrics = {"train_loss": 0.0}
637
+ if progress is not None:
638
+ progress.close()
639
+ self._barrier()
640
+ return metrics
641
+
642
+ # ---- evaluation ----
643
def evaluate(self) -> Dict[str, float]:
    """Run a teacher-forced evaluation pass over the held-out dataloader.

    Accumulates token-weighted loss, codon-level accuracy, and (when the
    tokenizer exposes a codon→AA map and the batch carries protein strings)
    amino-acid-level accuracy. All counters are summed across ranks before
    the final metrics are computed.

    Returns:
        Dict with "eval_loss" and, when supervised tokens were seen,
        "eval_codon_acc" / "eval_aa_acc".
    """
    if self.eval_dataloader is None:
        return {"eval_loss": 0.0}

    self.model.eval()

    # Running sums; reduced across ranks at the end.
    loss_sum = 0.0
    loss_tokens = 0
    codon_correct = 0
    codon_total = 0
    aa_correct = 0
    aa_total = 0

    tok = self.tokenizer
    pad_id = int(tok.pad_token_id) if tok is not None else 0
    eos_id = int(tok.special_ids.eos) if tok is not None and hasattr(tok, "special_ids") else -999
    num_special = int(tok.num_special_tokens) if tok is not None else 0
    codon2aa = tok.codon2aa_char_map() if tok is not None and hasattr(tok, "codon2aa_char_map") else {}

    # Streaming datasets have no length; cap the pass with eval_steps.
    is_streaming = isinstance(self.eval_dataloader.dataset, IterableDataset)
    max_batches = int(self.args.eval_steps) if (is_streaming and self.args.eval_steps > 0) else None

    with torch.no_grad():
        eval_iter = iter(self.eval_dataloader)
        b_idx = 0
        while True:
            # Drop the tail batch on any rank once one rank runs dry, so all
            # ranks stay in lockstep for the collectives below.
            batch, has_batch, local_has_batch = self._next_batch_sync(eval_iter)
            if not has_batch:
                if local_has_batch and self._is_main():
                    logger.debug("eval dataloader: discarded tail batch to stay in sync across ranks")
                break

            if max_batches is not None and b_idx >= max_batches:
                break

            batch = self._prepare_batch(batch)

            codon_ids = batch["codon_ids"].to(self.device)
            input_ids = codon_ids[:, :-1]
            # NOTE(review): labels use the same slice as input_ids (no shift here);
            # presumably the model shifts labels internally — confirm against the model's forward.
            labels = codon_ids[:, :-1]

            # Ignore PAD/EOS positions in the loss/accuracy (-100 convention).
            labels = labels.clone()
            labels[labels == pad_id] = -100
            labels[labels == eos_id] = -100

            cond = self._build_cond(batch)

            # Mirror the training autocast setup.
            use_cuda = (self.device.type == "cuda")
            autocast_dtype = self._amp_dtype
            if autocast_dtype is not None and use_cuda:
                ctx = torch.amp.autocast(device_type="cuda", dtype=autocast_dtype)
            else:
                from contextlib import nullcontext
                ctx = nullcontext()

            with ctx:
                out = self.model(codon_ids=input_ids, cond=cond, labels=labels, return_dict=True)

            loss = out.get("loss")
            per_cap = out.get("per_cap")  # per-row count of supervised positions
            logits = out.get("logits")

            # Weight the (mean) batch loss by its token count so the final
            # eval_loss is a proper token-weighted average.
            tokens_in_batch = 0
            if per_cap is not None:
                tokens_in_batch = int(torch.clamp(per_cap.detach(), min=0).sum().item())
                loss_tokens += tokens_in_batch

            if loss is not None and tokens_in_batch > 0:
                loss_sum += float(loss.detach().item()) * tokens_in_batch

            # NOTE(review): the `continue`s below skip the `b_idx += 1` at the
            # bottom, so max_batches only counts batches that yielded accuracy
            # stats — verify this is the intended budget semantics.
            if logits is None or logits.size(1) == 0 or per_cap is None:
                continue

            max_cap = logits.size(1)
            batch_size = logits.size(0)

            # Align labels to the logits' time dimension, padding with -100.
            labels_aligned = torch.full((batch_size, max_cap), -100, dtype=labels.dtype, device=labels.device)
            common_cols = min(labels.size(1), max_cap)
            if common_cols > 0:
                labels_aligned[:, :common_cols] = labels[:, :common_cols]

            # Mask out anything beyond each row's supervised capacity.
            per_cap_int = torch.clamp(per_cap.to(dtype=torch.long), min=0, max=max_cap)
            for row in range(batch_size):
                cap = int(per_cap_int[row].item())
                if cap < max_cap:
                    labels_aligned[row, cap:] = -100

            supervised = labels_aligned != -100
            if num_special > 0:
                # Only score real codon tokens, not special tokens.
                supervised = supervised & (labels_aligned >= num_special)
            if not supervised.any():
                continue

            preds = logits.argmax(dim=-1)
            codon_correct += int((preds[supervised] == labels_aligned[supervised]).sum().item())
            codon_total += int(supervised.sum().item())

            # Amino-acid accuracy: translate predicted codons and compare to
            # the ground-truth protein string, position by position.
            if codon2aa and isinstance(batch, dict) and "protein_seqs" in batch:
                prot_list = batch.get("protein_seqs", [])
                for row in range(batch_size):
                    cap = int(per_cap_int[row].item())
                    if cap <= 0:
                        continue
                    mask_row = supervised[row, :cap]
                    if not mask_row.any():
                        continue
                    preds_row = preds[row, :cap][mask_row]
                    prot = prot_list[row] if (isinstance(prot_list, list) and row < len(prot_list)) else ""
                    if not prot:
                        continue
                    seq_len = min(len(prot), preds_row.size(0))
                    if seq_len <= 0:
                        continue
                    # Unknown codon ids map to 'X' (never matches a real AA).
                    pred_aa = ''.join(codon2aa.get(int(t.item()), 'X') for t in preds_row[:seq_len])
                    truth_aa = prot[:seq_len]
                    aa_correct += sum(1 for i in range(seq_len) if pred_aa[i] == truth_aa[i])
                    aa_total += seq_len

            b_idx += 1

    # Sum the six counters across all ranks in one collective.
    totals = torch.tensor(
        [loss_sum, loss_tokens, codon_correct, codon_total, aa_correct, aa_total],
        dtype=torch.float64,
        device=self.device,
    )
    if dist.is_available() and dist.is_initialized():
        # Ensure every rank has finished its forward passes before the final
        # metric reduction, otherwise FSDP may still be issuing _all_gather
        # collectives on slower ranks.
        self._barrier()
        dist.all_reduce(totals, op=dist.ReduceOp.SUM)

    loss_sum, loss_tokens, codon_correct, codon_total, aa_correct, aa_total = totals.tolist()

    # Restore training mode before returning control to the train loop.
    self.model.train()

    metrics: Dict[str, float] = {"eval_loss": float(loss_sum) / loss_tokens if loss_tokens > 0 else 0.0}
    if codon_total > 0:
        metrics["eval_codon_acc"] = float(codon_correct) / codon_total
    if aa_total > 0:
        metrics["eval_aa_acc"] = float(aa_correct) / aa_total

    self._barrier()
    return metrics
787
+
788
+ # ---- internals ----
789
def _setup_fsdp(self):
    """Wrap ``self.model`` in FSDP (FULL_SHARD) with mixed precision.

    Initializes the default process group if needed, configures param/reduce
    dtypes from the fp16/bf16 training flags, and excludes a frozen ESM
    submodule (if present) from sharding.
    """
    # Ensure default process group is initialized (required by FSDP)
    device = self.device
    if dist.is_available() and not dist.is_initialized():
        backend = "nccl" if device.type == "cuda" else "gloo"
        # Older torch versions do not accept a `timeout` kwarg; probe first.
        sig = inspect.signature(dist.init_process_group)
        if "timeout" in sig.parameters:
            dist.init_process_group(backend=backend, init_method="env://", timeout=datetime.timedelta(minutes=30))
        else:
            dist.init_process_group(backend=backend, init_method="env://")
    # Params and gradient reductions run in half precision when requested;
    # buffers stay fp32.
    mp = MixedPrecision(
        param_dtype=(torch.float16 if self.args.fp16 else torch.bfloat16 if self.args.bf16 else torch.float32),
        reduce_dtype=(torch.float16 if self.args.fp16 else torch.bfloat16 if self.args.bf16 else torch.float32),
        buffer_dtype=torch.float32,
    )
    logger.info(f"FSDP enabled: sharding={self.args.fsdp} mp_param={mp.param_dtype} mp_reduce={mp.reduce_dtype}")
    # Keep frozen ESM off FSDP if present
    base = self._unwrap(self.model)
    ignored = []
    if hasattr(base, "esm") and isinstance(base.esm, nn.Module):
        ignored.append(base.esm)

    self.model = FSDP(
        self.model,
        device_id=(self.device if device.type == "cuda" else None),
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        mixed_precision=mp,
        ignored_modules=(ignored if ignored else None),
        sync_module_states=True,
    )

    # Place ignored module on device exactly once
    if ignored:
        ignored[0].to(device)
823
+
824
+ def _unwrap(self, module):
825
+ return getattr(module, "module", module)
826
+
827
+ def _prepare_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
828
+ # Species embeddings (fixed-size or sequence)
829
+ if self.species_store is not None and "species_ids" in batch:
830
+ sids = batch["species_ids"]
831
+ if torch.is_tensor(sids):
832
+ sids = sids.detach().cpu().tolist()
833
+ result = self.species_store.batch_get(sids)
834
+ if isinstance(result, tuple):
835
+ sp_tok, _ = result # [B, Ls, Ds]
836
+ batch["species_tok_emb"] = sp_tok.to(self.device, non_blocking=True)
837
+ else:
838
+ sp = result # [B, Ds]
839
+ batch["species_emb"] = sp.to(self.device, non_blocking=True)
840
+
841
+ # Move obvious tensors
842
+ if "codon_ids" in batch and hasattr(batch["codon_ids"], "to"):
843
+ batch["codon_ids"] = batch["codon_ids"].to(self.device, non_blocking=True)
844
+
845
+ return batch
846
+
847
+ def _build_cond(self, batch: Dict[str, Any]) -> Dict[str, Any]:
848
+ cond: Dict[str, Any] = {"control_mode": "fixed"}
849
+ if "species_tok_emb" in batch:
850
+ cond["species_tok_emb_src"] = batch["species_tok_emb"]
851
+ cond["species_tok_emb_tgt"] = batch["species_tok_emb"]
852
+ elif "species_emb" in batch:
853
+ cond["species_emb_src"] = batch["species_emb"]
854
+ cond["species_emb_tgt"] = batch["species_emb"]
855
+ if "protein_seqs" in batch:
856
+ cond["protein_seqs"] = batch["protein_seqs"]
857
+ return cond
858
+
859
+ def _next_batch_sync(self, iterator):
860
+ """Fetch next batch and drop out early if any rank exhausts its loader."""
861
+ try:
862
+ batch = next(iterator)
863
+ local_has_batch = True
864
+ except StopIteration:
865
+ batch = None
866
+ local_has_batch = False
867
+
868
+ distributed = dist.is_available() and dist.is_initialized()
869
+ has_batch = local_has_batch
870
+
871
+ if distributed:
872
+ flag_device = self.device if self.device.type == "cuda" else torch.device("cpu")
873
+ flag = torch.tensor([1 if local_has_batch else 0], device=flag_device)
874
+ dist.all_reduce(flag, op=dist.ReduceOp.MIN)
875
+ has_batch = bool(flag.item())
876
+
877
+ if not has_batch:
878
+ return None, False, local_has_batch
879
+
880
+ return batch, True, local_has_batch
881
+
882
+ def _is_main(self) -> bool:
883
+ return (not dist.is_available()) or (not dist.is_initialized()) or dist.get_rank() == 0
884
+
885
+ def _barrier(self):
886
+ if dist.is_available() and dist.is_initialized():
887
+ # On NCCL, pass device_ids to avoid rank↔GPU mapping ambiguity when supported
888
+ if self.device.type == "cuda":
889
+ sig = inspect.signature(dist.barrier)
890
+ if "device_ids" in sig.parameters:
891
+ dist.barrier(device_ids=[self.local_rank])
892
+ return
893
+ dist.barrier()
894
+
895
+ def _max_cuda_peak_gb(self) -> Tuple[float, float]:
896
+ if self.device.type != "cuda" or not torch.cuda.is_available():
897
+ return 0.0, 0.0
898
+ vals = torch.tensor(
899
+ [
900
+ float(torch.cuda.max_memory_allocated(self.device)),
901
+ float(torch.cuda.max_memory_reserved(self.device)),
902
+ ],
903
+ dtype=torch.float64,
904
+ device=self.device,
905
+ )
906
+ if dist.is_available() and dist.is_initialized():
907
+ dist.all_reduce(vals, op=dist.ReduceOp.MAX)
908
+ scale = float(1024 ** 3)
909
+ return float(vals[0].item() / scale), float(vals[1].item() / scale)
910
+
911
+ # (Per-sample quick eval removed; evaluation now uses held-out dataloader.)
912
+
913
+ def _epoch_for_logging(self) -> int:
914
+ steps_per_epoch = int(getattr(self.args, "steps_per_epoch", 0) or 0)
915
+ if steps_per_epoch > 0:
916
+ est = self.state.get("global_step", 0) // steps_per_epoch
917
+ if self.args.num_train_epochs > 0:
918
+ max_epoch = max(int(self.args.num_train_epochs) - 1, 0)
919
+ if est > max_epoch:
920
+ return max_epoch
921
+ return int(est)
922
+ return int(self.state.get("epoch", 0))
923
+
924
+ # ---- checkpointing ----
925
def _save_checkpoint(self, name: str):
    """Write a full checkpoint (model, optimizer, scheduler, configs) to
    ``output_dir/name``.

    All ranks must call this: FSDP state-dict gathering is collective.
    Only rank 0 serializes files; other ranks wait at the final barrier.
    """
    self.state["epoch"] = int(self._epoch_for_logging())
    # All ranks participate in FSDP state_dict collectives; only rank0 writes files.
    out_dir = os.path.join(self.args.output_dir, name)
    os.makedirs(out_dir, exist_ok=True)

    optim_state = None
    if isinstance(self.model, FSDP):
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=FutureWarning)
            # Gather a full (unsharded) state dict, materialized on rank0's CPU.
            with FSDP.state_dict_type(
                self.model,
                StateDictType.FULL_STATE_DICT,
                FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
                FullOptimStateDictConfig(rank0_only=True, offload_to_cpu=True),
            ):
                state = self.model.state_dict()
                # NOTE: Under FSDP, optimizer.state_dict() is sharded per-rank.
                # Use FSDP.optim_state_dict() to materialize a full optimizer state dict (rank0_only).
                if self.optimizer is not None:
                    optim_state = FSDP.optim_state_dict(self.model, self.optimizer)
    else:
        state = self._unwrap(self.model).state_dict()
        if self.optimizer is not None:
            optim_state = self.optimizer.state_dict()

    # Save minimal data cursor (total samples yielded so far) next to output_dir if configured
    per_rank_positions: Optional[List[int]] = None
    p = getattr(self.args, "data_cursor_path", None)
    if p:
        ds = getattr(self.train_dataloader, "dataset", None)
        if hasattr(ds, "get_stream_position"):
            local_pos = int(ds.get_stream_position())
            if dist.is_available() and dist.is_initialized():
                # Collect every rank's stream position so a resume can
                # fast-forward each rank's shard correctly.
                gather_device = self.device if self.device.type == "cuda" else torch.device("cpu")
                tensor = torch.tensor([local_pos], dtype=torch.long, device=gather_device)
                gathered = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
                dist.all_gather(gathered, tensor)
                per_rank_positions = [int(t.item()) for t in gathered]
            else:
                per_rank_positions = [local_pos]

    if not self._is_main():
        # Non-main ranks skip serialization but stay in lockstep
        self._barrier()
        return

    # Rank 0 writes artifacts
    save_file(state, os.path.join(out_dir, "model.safetensors"))

    # Optimizer + scheduler
    if optim_state is not None:
        torch.save(optim_state, os.path.join(out_dir, "optimizer.pt"))
    if self.lr_scheduler is not None:
        torch.save(self.lr_scheduler.state_dict(), os.path.join(out_dir, "scheduler.pt"))

    # Trainer config/state
    base = self._unwrap(self.model)
    # Infer mlp_ratio from first block if present
    mlp_ratio = 4.0
    try:
        if hasattr(base, "blocks") and len(getattr(base, "blocks", [])) > 0:
            w1 = base.blocks[0].ffn.w1.weight  # [H*mlp, H]
            H = int(getattr(base, "hidden_size", w1.shape[1]))
            if H > 0:
                mlp_ratio = float(w1.shape[0]) / float(H)
    except Exception:
        pass

    # Everything a later sampling/eval run needs to rebuild the model.
    trainer_cfg = {
        # capacity / prefixes
        "max_length": int(self.args.max_length),
        "max_species_prefix": int(getattr(base, "max_species_prefix", 0)),
        "max_protein_prefix": int(getattr(base, "max_protein_prefix", 0)),

        # architecture hints
        "hidden_size": int(getattr(base, "hidden_size", 0)),
        "num_hidden_layers": int(getattr(base, "num_layers", 0)),
        "num_attention_heads": int(getattr(base, "num_heads", 0)),
        "mlp_ratio": float(mlp_ratio),

        # conditioning flags
        "prepend_species": bool(getattr(base, "prepend_species", True)),
        "prepend_protein": bool(getattr(base, "prepend_protein", False)),
        "species_embedding_dim": int(getattr(base, "species_embedding_dim", 1024)),

        # ESM info (even if prepend_protein=False)
        "esm_model_name": str(getattr(self.args, "esm_model_name", "")),
        "esm_device": str(getattr(self.args, "esm_device", "cuda")),
        "esm_dtype": str(getattr(self.args, "esm_dtype", "fp32")).lower(),

        # kernels

        # attention impl
        "attn_impl": str(getattr(base, "attn_impl", "gqa")),
        "num_kv_groups": int(getattr(base, "num_kv_groups", 0)),
    }
    with open(os.path.join(out_dir, "trainer_config.json"), "w") as f:
        json.dump(trainer_cfg, f, indent=2)
    with open(os.path.join(out_dir, "trainer_state.json"), "w") as f:
        json.dump({"epoch": self.state["epoch"], "global_step": self.state["global_step"]}, f, indent=2)

    if p and per_rank_positions is not None:
        payload = {
            "skip_samples": int(sum(per_rank_positions)),
            "per_rank": per_rank_positions,
            "world_size": len(per_rank_positions),
        }
        os.makedirs(os.path.dirname(os.path.abspath(p)), exist_ok=True)
        with open(p, "w") as f:
            json.dump(payload, f)

    # Tokenizer vocab for sampling
    try:
        if self.tokenizer is not None and hasattr(self.tokenizer, "save_vocabulary"):
            self.tokenizer.save_vocabulary(out_dir)
    except Exception as e:
        logger.warning(f"Failed to save vocabulary to {out_dir}: {e}")

    self._prune_checkpoints(self.args.output_dir, self.args.save_total_limit)
    logger.info(f"Saved checkpoint → {out_dir}")

    # Release other ranks
    self._barrier()
1049
+
1050
def _resume_from(self, ckpt_dir: str):
    """Restore model, optimizer, scheduler, and trainer state from *ckpt_dir*.

    Handles both FSDP-wrapped and plain models. When the scheduler state is
    not restored (override or missing file), the scheduler is fast-forwarded
    to the saved global_step so the LR does not restart from warmup.

    Raises:
        FileNotFoundError: when ``model.safetensors`` is missing.
        RuntimeError: when an FSDP optimizer state dict cannot be converted
            (legacy rank0-only sharded optimizer.pt).
    """
    st_path = os.path.join(ckpt_dir, "model.safetensors")
    if not os.path.exists(st_path):
        raise FileNotFoundError(f"No model.safetensors in {ckpt_dir}")
    state = load_file(st_path)

    if isinstance(self.model, FSDP):
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=FutureWarning)
            # rank0_only=False: every rank loads the full state dict.
            with FSDP.state_dict_type(
                self.model,
                StateDictType.FULL_STATE_DICT,
                FullStateDictConfig(rank0_only=False, offload_to_cpu=True),
            ):
                self.model.load_state_dict(state, strict=False)
    else:
        self._unwrap(self.model).load_state_dict(state, strict=False)


    scheduler_restored = False

    opt_path = os.path.join(ckpt_dir, "optimizer.pt")
    if os.path.exists(opt_path):
        if self.optimizer is None:
            self._create_optimizer_and_scheduler()
        if not self.args.override_lr_on_resume:
            loaded = torch.load(opt_path, map_location="cpu")
            # Under FSDP, saved optimizer.pt is a full optimizer state dict produced by
            # FSDP.optim_state_dict(). Convert it to a per-rank state dict before loading.
            if isinstance(self.model, FSDP):
                try:
                    loaded = FSDP.optim_state_dict_to_load(self.model, self.optimizer, loaded)
                except Exception as e:
                    msg = (
                        "Failed to convert FSDP optimizer state dict for loading. "
                        "This checkpoint likely contains an incomplete (rank0-only sharded) optimizer.pt from an older version. "
                        "Full optimizer resume is not possible from this checkpoint.\n"
                        f"Underlying error: {e}\n"
                        "Options:\n"
                        " 1) Start a fresh run (new --output_dir), or\n"
                        " 2) Re-run with --override_lr_on_resume to skip optimizer restore (not a full resume)."
                    )
                    if self._is_main():
                        logger.error(msg)
                    raise RuntimeError(msg) from e
            self.optimizer.load_state_dict(loaded)

    sch_path = os.path.join(ckpt_dir, "scheduler.pt")
    if os.path.exists(sch_path):
        if self.lr_scheduler is None:
            self._create_optimizer_and_scheduler()
        if self.lr_scheduler is not None and not self.args.override_lr_on_resume:
            self.lr_scheduler.load_state_dict(torch.load(sch_path, map_location="cpu"))
            scheduler_restored = True

    ts_path = os.path.join(ckpt_dir, "trainer_state.json")
    if os.path.exists(ts_path):
        with open(ts_path, "r") as f:
            ts = json.load(f)
        self.state["epoch"] = int(ts.get("epoch", 0))
        self.state["global_step"] = int(ts.get("global_step", 0))

        # Prefer the epoch implied by global_step when a streaming schedule
        # defines steps_per_epoch; the stored epoch may lag behind.
        steps_per_epoch = int(getattr(self.args, "steps_per_epoch", 0) or 0)
        if steps_per_epoch > 0:
            inferred_epoch = self.state.get("global_step", 0) // steps_per_epoch
            num_epochs = max(int(self.args.num_train_epochs), 1)
            inferred_epoch = min(inferred_epoch, num_epochs - 1)
            if inferred_epoch != self.state.get("epoch"):
                if self._is_main():
                    logger.info(
                        "Adjusting epoch from %s to %s based on global_step %s and steps_per_epoch %s",
                        self.state.get("epoch"),
                        inferred_epoch,
                        self.state.get("global_step"),
                        steps_per_epoch,
                    )
                self.state["epoch"] = int(inferred_epoch)

    # If we skipped loading the scheduler state (e.g., different world size or override),
    # fast-forward it to the saved global_step so LR does not restart from warmup.
    if self.lr_scheduler is not None and not scheduler_restored:
        target_step = int(self.state.get("global_step", 0))
        if target_step > 0:
            try:
                # Most schedulers (LambdaLR, CosineAnnealing, etc.) accept an "epoch" kwarg.
                self.lr_scheduler.step(target_step)
            except TypeError:
                # Fallback: advance manually.
                for _ in range(target_step):
                    self.lr_scheduler.step()
            # Ensure optimizer LR reflects the scheduler's current value.
            try:
                last_lrs = self.lr_scheduler.get_last_lr()
            except Exception:
                last_lrs = [group.get("lr") for group in self.optimizer.param_groups]
            if last_lrs:
                for group, lr in zip(self.optimizer.param_groups, last_lrs):
                    group["lr"] = float(lr)

    logger.info(f"Resumed from {ckpt_dir}")
1150
+
1151
+ def _checkpoint_step(self, path: str) -> Optional[int]:
1152
+ m = re.fullmatch(r"checkpoint-(\d+)", os.path.basename(path))
1153
+ if not m:
1154
+ return None
1155
+ return int(m.group(1))
1156
+
1157
def _prune_checkpoints(self, root: str, keep: int):
    """Delete stale ``checkpoint-<N>`` directories under *root*.

    Retention policy (when ``ckpt_recent_window_steps`` plus a recent/archive
    interval are configured): keep dense checkpoints inside the recent window
    and sparse ones outside it, always preserving the newest; then trim the
    oldest kept entries down to *keep*. Otherwise fall back to keeping the
    most recent *keep* checkpoints. Non-``checkpoint-*`` dirs (epoch saves,
    final_model) are never touched.
    """
    if not os.path.isdir(root):
        return

    try:
        subdirs = [
            os.path.join(root, d)
            for d in os.listdir(root)
            if os.path.isdir(os.path.join(root, d))
        ]
    except FileNotFoundError:
        # Directory vanished between the isdir check and the listing.
        return

    # Only directories that match checkpoint-<step> are candidates.
    step_dirs: list[tuple[int, str]] = []
    for path in subdirs:
        step = self._checkpoint_step(path)
        if step is not None:
            step_dirs.append((step, path))

    if not step_dirs:
        return

    step_dirs.sort(key=lambda item: item[0])
    latest_step = step_dirs[-1][0]

    recent_window = max(0, int(getattr(self.args, "ckpt_recent_window_steps", 0) or 0))
    recent_interval = max(0, int(getattr(self.args, "ckpt_recent_interval", 0) or 0))
    archive_interval = max(0, int(getattr(self.args, "ckpt_archive_interval", 0) or 0))

    keep_paths: set[str] = set()
    if recent_window > 0 and (recent_interval > 0 or archive_interval > 0):
        if recent_interval <= 0:
            # Default the dense interval to the save cadence.
            recent_interval = max(1, int(getattr(self.args, "save_steps", 1) or 1))

        for step, path in step_dirs:
            age = latest_step - step
            # Dense interval inside the recent window, sparse outside it.
            if age <= recent_window:
                interval = recent_interval
            else:
                interval = archive_interval
            if interval > 0 and (step % interval == 0):
                keep_paths.add(path)

    if not keep_paths:
        # Legacy fallback: keep the most recent N step checkpoints.
        if keep <= 0:
            return
        keep_paths = {path for _, path in step_dirs[-keep:]}
    else:
        # Always preserve the newest checkpoint, even if the interval math misses it.
        keep_paths.add(step_dirs[-1][1])
        if keep > 0:
            # Enforce the hard cap by dropping the oldest kept entries.
            kept = [(step, path) for step, path in step_dirs if path in keep_paths]
            if len(kept) > keep:
                trim = len(kept) - keep
                for _, path in kept[:trim]:
                    keep_paths.discard(path)

    removed = []
    for _, path in step_dirs:
        if path in keep_paths:
            continue
        shutil.rmtree(path, ignore_errors=True)
        removed.append(os.path.basename(path))

    if removed and self._is_main():
        logger.info(
            "Pruned %s checkpoints (latest_step=%s, recent_window=%s, recent_interval=%s, archive_interval=%s)",
            len(removed),
            latest_step,
            recent_window,
            recent_interval,
            archive_interval,
        )
train.py ADDED
@@ -0,0 +1,352 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python
"""
Minimal, honest training script for CodonGPT on CSV data.

- Species conditioning: REQUIRED (precomputed embeddings)
- Protein conditioning (ESM-C): ALWAYS ENABLED. This script exposes no
  flag to turn it off (no ``--no_protein`` option exists; ``main`` passes
  ``prepend_protein=True`` unconditionally).
- Global capacity is controlled by --max_length (prefix + start + codon).
"""

import os
import math
import argparse
import logging
import torch

from src import CodonGPT, CodonTokenizer, Trainer, TrainingArguments
from src.dataset import create_precomputed_dataloaders, SpeciesEmbeddingStore

# Module-wide logging setup; the trainer and helpers below log through
# the "codongpt.train" logger configured here.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger("codongpt.train")
+
26
def _describe_sdp_kernels() -> None:
    """Log which scaled-dot-product-attention backends are enabled.

    Queries the Flash / memory-efficient / math SDPA toggles on
    ``torch.backends.cuda`` when the running PyTorch exposes them; any
    value that cannot be queried is reported as ``None`` so this never
    raises on older PyTorch builds.
    """
    flash = mem_eff = mathk = None
    tbc = getattr(getattr(torch, "backends", None), "cuda", None)
    if tbc is not None:
        if hasattr(tbc, "flash_sdp_enabled"):
            flash = tbc.flash_sdp_enabled()
        if hasattr(tbc, "mem_efficient_sdp_enabled"):
            mem_eff = tbc.mem_efficient_sdp_enabled()
        if hasattr(tbc, "math_sdp_enabled"):
            mathk = tbc.math_sdp_enabled()
    # Lazy %-style args: formatting is skipped when INFO is disabled.
    logger.info("SDP kernels: flash=%s mem_efficient=%s math=%s", flash, mem_eff, mathk)
38
+
39
def _print_model_size(model: torch.nn.Module, bf16: bool, fp16: bool) -> None:
    """Log parameter counts and rough memory footprints for *model*.

    Weight bytes assume 2 bytes/param under bf16 or fp16 and 4 otherwise
    (fp32); the optimizer estimate assumes Adam keeps two fp32 moments
    (8 bytes) per trainable parameter.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
    bytes_per_weight = 2 if (bf16 or fp16) else 4
    bytes_per_opt_state = 8  # Adam moments in FP32
    weights_gb = total * bytes_per_weight / (1024**3)
    opt_gb = trainable * bytes_per_opt_state / (1024**3)
    logger.info(
        f"Model params: total={total:,} trainable={trainable:,} (~{weights_gb:.2f} GB weights, ~{opt_gb:.2f} GB optimizer)"
    )
49
+
50
def _speed_toggles():
    """Enable throughput-oriented backend knobs where this PyTorch build
    exposes them: TF32 matmuls, 'high' float32 matmul precision, and
    cuDNN benchmark autotuning. Each toggle is guarded so older builds
    are left untouched."""
    backends = torch.backends
    if hasattr(backends, "cuda") and hasattr(backends.cuda, "matmul"):
        backends.cuda.matmul.allow_tf32 = True
    if hasattr(torch, "set_float32_matmul_precision"):
        torch.set_float32_matmul_precision("high")
    if hasattr(backends, "cudnn") and hasattr(backends.cudnn, "benchmark"):
        backends.cudnn.benchmark = True
57
+
58
+
59
def parse_args():
    """Build and evaluate the CLI for the training script.

    Returns the parsed ``argparse.Namespace``. Argument names, defaults,
    choices, and help strings define the script's public interface.
    """
    parser = argparse.ArgumentParser(description="Train CodonGPT on CSV data")

    # --- Data inputs (CSV file or Parquet glob/dir) ---
    parser.add_argument("--train_data", type=str, default="random_sample_1000.csv",
                        help="Training data: CSV file or Parquet glob/dir (e.g., ./data/train_shards/*.parquet)")
    parser.add_argument("--val_data", type=str, default=None,
                        help="Validation data: CSV file or Parquet glob/dir")
    parser.add_argument("--embeddings_dir", type=str, default="embeddings",
                        help="Dir with species embeddings (species_vocab.json, *.bin/memmap)")

    # --- Model architecture / sequence capacity ---
    parser.add_argument("--hidden", type=int, default=750, help="Model hidden size")
    parser.add_argument("--layers", type=int, default=20, help="Number of transformer layers")
    parser.add_argument("--heads", type=int, default=15, help="Number of attention heads")
    parser.add_argument("--attn", type=str, choices=["mha", "gqa"], default="gqa", help="Attention implementation: 'mha' or 'gqa'")
    parser.add_argument("--num_kv_groups", type=int, default=5, help="GQA: number of KV groups (0 = default/no grouping)")
    parser.add_argument("--mlp_ratio", type=float, default=3.2, help="FFN expansion ratio (mlp hidden = ratio * hidden)")
    parser.add_argument("--max_length", type=int, default=2048,
                        help="Global max length (prefix + start + codon)")
    parser.add_argument("--max_species_prefix", type=int, default=0,
                        help="Cap species prefix tokens (0 = uncapped)")
    parser.add_argument("--max_protein_prefix", type=int, default=1024,
                        help="Cap protein prefix tokens (0 = uncapped)")

    # Protein conditioning: always enabled (ESM-C)

    # --- Training loop / batching ---
    parser.add_argument("--output_dir", type=str, default="checkpoints", help="Where to save checkpoints")
    parser.add_argument("--epochs", type=int, default=1, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=20, help="Per-device train batch size")
    parser.add_argument("--eval_batch_size", type=int, default=32, help="Per-device eval batch size")
    parser.add_argument("--workers", type=int, default=4, help="DataLoader workers")
    parser.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation steps")
    parser.add_argument("--train_shuffle_buffer", type=int, default=0,
                        help="Streaming shuffle buffer for training (set 0 when data is pre-shuffled)")
    parser.add_argument("--val_shuffle_buffer", type=int, default=0,
                        help="Streaming shuffle buffer for validation (0 disables)")
    parser.add_argument("--csv_chunksize", type=int, default=200_000,
                        help="Pandas read_csv chunksize for CSV inputs")

    # --- Optimizer / LR schedule / checkpoint retention ---
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--warmup_ratio", type=float, default=0.1, help="Warmup ratio for LR schedule (0.0-1.0)")
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        choices=["linear", "cosine", "constant"],
        default="linear",
        help="LR schedule applied after warmup; 'linear' decays to zero by the end of training",
    )
    parser.add_argument("--weight_decay", type=float, default=1e-3, help="Weight decay")
    parser.add_argument("--adam_beta1", type=float, default=0.9,
                        help="Adam beta1 (momentum) coefficient")
    parser.add_argument("--adam_beta2", type=float, default=0.95,
                        help="Adam beta2 (squared-gradient) coefficient")
    parser.add_argument("--logging_steps", type=int, default=20, help="Logging interval (steps)")
    parser.add_argument("--save_steps", type=int, default=10, help="Save every N steps (0 disables step-saving)")
    parser.add_argument("--save_total_limit", type=int, default=10, help="Keep at most N recent checkpoints")
    parser.add_argument("--ckpt_recent_window_steps", type=int, default=0,
                        help="If >0, keep finer-grained checkpoints within this many recent steps")
    parser.add_argument("--ckpt_recent_interval", type=int, default=0,
                        help="Retention interval inside the recent checkpoint window (0 disables custom retention)")
    parser.add_argument("--ckpt_archive_interval", type=int, default=0,
                        help="Retention interval for checkpoints older than the recent window (0 prunes them)")
    parser.add_argument("--max_steps", type=int, default=-1,
                        help="Total training steps. REQUIRED for streaming (IterableDataset)")
    parser.add_argument("--steps_per_epoch", type=int, default=0,
                        help="For streaming datasets: shape LR schedule as epochs*steps_per_epoch when max_steps<0")
    parser.add_argument("--max_grad_norm", type=float, default=1.0,
                        help="Clip gradients to this global L2 norm; set <=0 to disable")
    parser.add_argument("--override_lr_on_resume", action="store_true",
                        help="Do not restore LR/optimizer state on resume (keep current lr)")

    # --- Resume ---
    parser.add_argument("--resume_from", type=str, default=None,
                        help="Path to checkpoint dir to resume from; pass 'auto' to pick latest in output_dir")

    # --- Evaluation scheduling ---
    parser.add_argument("--eval_interval", type=int, default=0,
                        help="Run evaluation every N optimizer steps on --val_data (0 disables)")
    parser.add_argument("--eval_steps", type=int, default=5000,
                        help="For streaming eval datasets: limit to this many batches (0 = full eval)")

    # --- Hardware / precision ---
    parser.add_argument("--device", type=str, default="cuda", help="cuda or cpu")
    parser.add_argument("--bf16", action="store_true", help="bfloat16 mixed precision")
    parser.add_argument("--fp16", action="store_true", help="float16 mixed precision")
    parser.add_argument("--fsdp", action="store_true", help="Enable FSDP full sharding")
    parser.add_argument("--grad_ckpt", action="store_true", help="Enable gradient checkpointing")

    return parser.parse_args()
149
+
150
+
151
def main():
    """Entry point: wire up tokenizer, data, model, and trainer, then train.

    Sequence of operations:
      1. Parse CLI args and enable fast-math backend toggles.
      2. Build the codon tokenizer and persist its vocab to --output_dir.
      3. Create (streaming) dataloaders plus the species-embedding store.
      4. Estimate steps_per_epoch from parquet metadata when not given.
      5. Construct CodonGPT + TrainingArguments, resolve resume path, train.
    """
    args = parse_args()
    _speed_toggles()

    # Fall back to CPU rather than crashing when CUDA was requested but absent.
    if args.device == "cuda" and not torch.cuda.is_available():
        logger.warning("CUDA not available; switching to CPU")
        args.device = "cpu"

    # Tokenizer
    tok = CodonTokenizer()
    # Ensure output dir exists and persist vocab.json (used by sampler)
    os.makedirs(os.path.abspath(args.output_dir), exist_ok=True)
    tok.save_vocabulary(args.output_dir)


    # Data first — we need the species-embedding dimensionality (Ds) before
    # the model can be constructed below.
    train_loader, val_loader, species_store = create_precomputed_dataloaders(
        train_path=args.train_data,
        val_path=args.val_data,
        embeddings_dir=args.embeddings_dir,
        tokenizer=tok,
        batch_size=args.batch_size,
        num_workers=args.workers,
        species_pooling="sequence", # prefer variable-length token sequence if available
        csv_chunksize=int(args.csv_chunksize),
        train_shuffle_buffer=int(args.train_shuffle_buffer),
        val_shuffle_buffer=int(args.val_shuffle_buffer),
    )

    # Estimate steps_per_epoch for streaming schedule shaping if not provided.
    # Only attempted when neither --steps_per_epoch nor --max_steps pins it.
    steps_per_epoch = int(getattr(args, "steps_per_epoch", 0) or 0)
    total_rows = 0
    paths: list[str] = []
    if steps_per_epoch <= 0 and int(args.max_steps) < 0:
        def _expand_paths(maybe: str | list[str]) -> list[str]:
            # Expand a path, glob, dir, or list thereof into a de-duplicated,
            # order-preserving list of files. NOTE: the local `paths` here
            # deliberately shadows the outer variable of the same name.
            import glob as _glob
            from pathlib import Path as _Path
            paths: list[str] = []
            if isinstance(maybe, str):
                p = _Path(maybe)
                if p.is_dir():
                    paths.extend(sorted(str(x) for x in p.rglob("*.parquet")))
                else:
                    paths = sorted(_glob.glob(str(p)))
            else:
                for it in maybe:
                    paths.extend(_expand_paths(it))
            # de-dup
            seen = set(); out = []
            for x in paths:
                if x not in seen:
                    out.append(x); seen.add(x)
            return out

        paths = _expand_paths(args.train_data)
        if paths:
            try:
                import pyarrow.parquet as pq
                # Sum row counts from parquet footers only — cheap, no data read.
                for fp in paths:
                    if fp.lower().endswith((".parquet", ".parq")):
                        pf = pq.ParquetFile(fp)
                        md = pf.metadata
                        if md is not None:
                            total_rows += int(md.num_rows)
            except Exception:
                # Fallback: keep steps_per_epoch at 0 if pyarrow not available
                total_rows = 0
        if total_rows > 0:
            # Steps = rows / (per-device batch * world size * grad accumulation).
            world = int(os.environ.get("WORLD_SIZE", "1"))
            ga = max(1, int(getattr(args, "grad_accum", 1)))
            denom = max(1, int(args.batch_size) * max(1, world) * ga)
            steps_per_epoch = max(1, math.ceil(total_rows / denom))
            logger.info(f"Estimated steps_per_epoch={steps_per_epoch} from {len(paths)} parquet files, total_rows={total_rows}")

    # Log the effective global batch so multi-node runs are auditable.
    world = int(os.environ.get("WORLD_SIZE", "1"))
    grad_accum = max(1, int(getattr(args, "grad_accum", 1)))
    effective_global_batch = int(args.batch_size) * max(1, world) * grad_accum
    logger.info(
        "Batch config: per_device_train_batch=%s per_device_eval_batch=%s world_size=%s grad_accum=%s effective_global_batch=%s",
        args.batch_size,
        args.eval_batch_size,
        world,
        grad_accum,
        effective_global_batch,
    )

    # Resolve per-process CUDA device for ESM (avoid defaulting to cuda:0 on all ranks)
    esm_dev = "cpu"
    if args.device == "cuda" and torch.cuda.is_available():
        lr = int(os.environ.get("LOCAL_RANK", "0"))
        esm_dev = f"cuda:{lr}"

    # Model — species conditioning is always on; protein conditioning is
    # always on as well (this script defines no flag to disable it).
    # NOTE(review): species_store.Ds() presumably returns the species
    # embedding width — confirm against src.dataset.SpeciesEmbeddingStore.
    model = CodonGPT(
        vocab_size=tok.vocab_size,
        num_special_tokens=tok.num_special_tokens,
        special_ids=tok.special_ids,
        hidden_size=args.hidden,
        num_layers=args.layers,
        num_heads=args.heads,
        mlp_ratio=float(args.mlp_ratio),
        max_position_embeddings=args.max_length,
        prepend_species=True,
        prepend_protein=True,
        esm_model_name="esmc_300m",
        esm_device=esm_dev,
        max_protein_prefix=int(args.max_protein_prefix),
        max_species_prefix=int(args.max_species_prefix),
        dropout=0.1,
        species_embedding_dim=int(species_store.Ds()),
        attn_impl=str(args.attn),
        num_kv_groups=int(args.num_kv_groups),
    )

    # Report model size and SDPA (Flash) kernel configuration
    _print_model_size(model, bf16=bool(args.bf16), fp16=bool(args.fp16))
    _describe_sdp_kernels()

    # Trainer args — mirrors the CLI one-to-one; see parse_args for semantics.
    targs = TrainingArguments(
        output_dir=args.output_dir,
        save_steps=args.save_steps,
        save_total_limit=int(args.save_total_limit),
        ckpt_recent_window_steps=int(args.ckpt_recent_window_steps),
        ckpt_recent_interval=int(args.ckpt_recent_interval),
        ckpt_archive_interval=int(args.ckpt_archive_interval),
        num_train_epochs=args.epochs,
        max_steps=int(args.max_steps),
        gradient_accumulation_steps=int(args.grad_accum),
        warmup_ratio=float(args.warmup_ratio),
        lr_scheduler_type=str(args.lr_scheduler),
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        dataloader_num_workers=args.workers,
        learning_rate=args.lr,
        weight_decay=args.weight_decay,
        adam_beta1=float(args.adam_beta1),
        adam_beta2=float(args.adam_beta2),
        max_grad_norm=float(args.max_grad_norm),
        logging_steps=args.logging_steps,
        override_lr_on_resume=bool(args.override_lr_on_resume),
        data_cursor_path=os.path.join(os.path.abspath(args.output_dir), "data_cursor.json"),
        fp16=bool(args.fp16),
        bf16=bool(args.bf16),
        fsdp=("full_shard" if args.fsdp else None),
        gradient_checkpointing=bool(args.grad_ckpt),
        max_length=int(args.max_length),
        esm_model_name="esmc_300m",
        esm_device=esm_dev,
        # ESM runs in the matching reduced precision when mixed precision is on.
        esm_dtype=("bf16" if args.bf16 else ("fp16" if args.fp16 else "fp32")),
        # sampling eval
        eval_interval=int(args.eval_interval),
        eval_steps=int(args.eval_steps),
        steps_per_epoch=int(steps_per_epoch),
    )

    # Resolve auto-resume if requested: 'auto' picks the most recently
    # modified checkpoint dir under output_dir that actually contains weights.
    resume_path = None
    if args.resume_from:
        if args.resume_from == "auto":
            root = os.path.abspath(args.output_dir)
            if os.path.isdir(root):
                try:
                    subdirs = []
                    for d in os.listdir(root):
                        path = os.path.join(root, d)
                        if not os.path.isdir(path):
                            continue
                        # Only consider conventional checkpoint dir names.
                        if not (
                            d == "final_model" or
                            d.startswith("checkpoint-")
                        ):
                            continue
                        # Skip dirs without a saved weights file.
                        if not (
                            os.path.exists(os.path.join(path, "model.safetensors")) or
                            os.path.exists(os.path.join(path, "pytorch_model.bin"))
                        ):
                            continue
                        subdirs.append(path)
                    subdirs.sort(key=lambda d: os.path.getmtime(d), reverse=True)
                    resume_path = subdirs[0] if subdirs else None
                except Exception:
                    # Best-effort: any filesystem error just disables auto-resume.
                    resume_path = None
        else:
            resume_path = args.resume_from

    trainer = Trainer(
        model=model,
        args=targs,
        tokenizer=tok,
        species_store=species_store,
        resume_from_checkpoint=resume_path,
    )
    trainer.attach_dataloaders(train_loader, val_loader)

    logger.info("Starting training...")
    trainer.train()
    logger.info("Training finished.")


if __name__ == "__main__":
    main()
training_checkpoints/checkpoint-71000/config.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "max_length": 2048,
3
+ "max_species_prefix": 0,
4
+ "max_protein_prefix": 1024,
5
+ "hidden_size": 750,
6
+ "num_hidden_layers": 20,
7
+ "num_attention_heads": 15,
8
+ "mlp_ratio": 3.2,
9
+ "prepend_species": true,
10
+ "prepend_protein": true,
11
+ "species_embedding_dim": 1024,
12
+ "esm_model_name": "esmc_300m",
13
+ "esm_device": "cuda:0",
14
+ "esm_dtype": "bf16",
15
+ "attn_impl": "mha",
16
+ "num_kv_groups": 5
17
+ }
training_checkpoints/checkpoint-71000/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:07bc223f4d934e2baff5a8085a78348766b6a8324aa091a1459fce2b2c6d3837
3
+ size 1284544520