#!/usr/bin/env python
"""
Sampling script for generating codon sequences from trained CodonTranslator models.
Inputs are prepared exactly like training:
- Species conditioning via SpeciesEmbeddingStore (fixed-size [B,Ds] or variable-length [B,Ls,Ds])
- Protein conditioning via raw AA strings (ESM-C tokenization happens inside the model)
"""
import argparse
import logging
import json
from pathlib import Path
from typing import List, Optional, Union
import torch
from src.sampler import CodonSampler
from src.dataset import SpeciesEmbeddingStore
# Configure root logging once at import time; all loggers in this process
# inherit the timestamped format below.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
# Module-level logger used throughout this script.
logger = logging.getLogger("codontranslator.sample")
def parse_args():
    """Build and parse the CLI arguments for sampling.

    Returns:
        argparse.Namespace with model, species/protein conditioning,
        sampling-parameter, and output options.
    """
    p = argparse.ArgumentParser(description="Sample codon sequences from a CodonTranslator model")
    # Model
    p.add_argument("--model_path", "--model_dir", dest="model_path", type=str, required=True,
                   help="Path to trained model checkpoint dir")
    p.add_argument("--device", type=str, default="cuda", help="cuda or cpu")
    p.add_argument("--compile", action="store_true", help="torch.compile the model")
    # Species embeddings
    p.add_argument("--embeddings_dir", type=str, default=None,
                   help="Directory with precomputed variable-length species embeddings (optional; fallback to Qwen if missing/unknown)")
    p.add_argument("--strict_species_lookup", action="store_true",
                   help="When using --embeddings_dir, fail if any requested species name is not an exact key in species_vocab.json")
    p.add_argument("--taxonomy_db", type=str, default=None,
                   help="Optional path to taxonomy_database.json (from precompute) to enrich prompts")
    # Sampling batch size and count
    p.add_argument("--num_sequences", "--num_seq", "--num_samples", type=int, default=1, dest="num_sequences",
                   help="Number of sequences to generate in total")
    p.add_argument("--batch_size", type=int, default=None, help="Batch size for sampling loop")
    # Control mode and length
    p.add_argument("--control_mode", choices=["fixed", "variable"], default="fixed",
                   help="fixed: disallow EOS, generate exactly sequence_length codons; variable: allow EOS")
    p.add_argument("--sequence_length", type=int, default=None,
                   help="Number of CODONS to generate (used as max steps in variable mode). "
                        "If omitted and protein sequences are provided, set to min protein length.")
    # Conditioning (REQUIRED: species and protein)
    p.add_argument("--species", "--taxon", type=str, default=None, dest="species",
                   help="Species name (e.g., 'Homo sapiens'). Replicated if num_sequences>1.")
    p.add_argument("--species_list", type=str, nargs="+", default=None,
                   help="List of species names (must match num_sequences).")
    p.add_argument("--protein_seq", "--protein_sequence", type=str, default=None, dest="protein_seq",
                   help="Protein sequence (AA string). Replicated if num_sequences>1.")
    p.add_argument("--protein_file", type=str, default=None,
                   help="Path to FASTA-like file (each non-header line is a sequence). Must provide at least num_sequences.")
    # Sampling params
    # Default must be a float literal: argparse applies `type` only to string
    # values from the command line, so a default of `1` would stay an int.
    p.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature")
    p.add_argument("--top_k", type=int, default=50, help="Top-k")
    p.add_argument("--top_p", type=float, default=0.9, help="Top-p (nucleus)")
    p.add_argument("--enforce_translation", action="store_true", default=False,
                   help="Hard-mask codons to match the given protein AA at each position")
    p.add_argument("--seed", type=int, default=None)
    p.add_argument("--save_intermediate", action="store_true", help="Store intermediate token states")
    # Output
    p.add_argument("--output_file", type=str, default=None)
    p.add_argument("--output_format", type=str, default="fasta", choices=["fasta", "csv", "json"])
    # Misc
    p.add_argument("--quiet", action="store_true")
    return p.parse_args()
def load_protein_sequences(file_path: str) -> List[str]:
    """Read protein sequences from a FASTA-like file.

    Every non-empty line that does not start with '>' is treated as one
    amino-acid sequence; header lines are skipped.
    """
    with open(file_path, "r") as handle:
        return [
            stripped
            for raw in handle
            if (stripped := raw.strip()) and not stripped.startswith(">")
        ]
def setup_species_store(embeddings_dir: str) -> SpeciesEmbeddingStore:
    """Open the species embedding store rooted at *embeddings_dir*.

    No guessing here: pooling="sequence" selects the variable-length
    (sequence) format when the store contains it, otherwise the store
    falls back to its fixed-size embeddings.
    """
    store = SpeciesEmbeddingStore(embeddings_dir, pooling="sequence")
    return store
def save_sequences(
    sequences: List[str],
    output_file: str,
    fmt: str,
    species: Optional[List[str]] = None,
    proteins: Optional[List[str]] = None,
    metadata: Optional[dict] = None,
):
    """Persist generated sequences to *output_file* in the requested format.

    fmt is one of "fasta", "csv", "json"; species/proteins, when given,
    are attached per-sequence (truncated to the number of sequences).
    """
    count = len(sequences)
    if fmt == "fasta":
        with open(output_file, "w") as out:
            for idx, seq in enumerate(sequences):
                # Header fields are pipe-separated: >seq_i|species=...|protein_len=...
                fields = [f">seq_{idx}"]
                if species and idx < len(species):
                    fields.append(f"species={species[idx]}")
                if proteins and idx < len(proteins):
                    fields.append(f"protein_len={len(proteins[idx])}")
                out.write("|".join(fields) + f"\n{seq}\n")
    elif fmt == "csv":
        import pandas as pd
        columns = {"sequence": sequences}
        if species:
            columns["species"] = species[:count]
        if proteins:
            columns["protein_sequence"] = proteins[:count]
        pd.DataFrame(columns).to_csv(output_file, index=False)
    else:
        # json (the only remaining choice accepted by the CLI)
        payload = {"sequences": sequences, "metadata": metadata or {}}
        if species:
            payload["species"] = species[:count]
        if proteins:
            payload["protein_sequences"] = proteins[:count]
        with open(output_file, "w") as out:
            json.dump(payload, out, indent=2)
def translate_dna_to_aa(dna_seq: str) -> str:
    """Translate a DNA string codon-by-codon via the standard genetic code.

    Trailing bases that do not complete a codon are ignored; unrecognized
    codons (e.g. containing 'N') become 'X'; stop codons map to '*'.
    """
    table = {
        'TTT': 'F', 'TTC': 'F', 'TTA': 'L', 'TTG': 'L', 'TCT': 'S', 'TCC': 'S', 'TCA': 'S', 'TCG': 'S',
        'TAT': 'Y', 'TAC': 'Y', 'TAA': '*', 'TAG': '*', 'TGT': 'C', 'TGC': 'C', 'TGA': '*', 'TGG': 'W',
        'CTT': 'L', 'CTC': 'L', 'CTA': 'L', 'CTG': 'L', 'CCT': 'P', 'CCC': 'P', 'CCA': 'P', 'CCG': 'P',
        'CAT': 'H', 'CAC': 'H', 'CAA': 'Q', 'CAG': 'Q', 'CGT': 'R', 'CGC': 'R', 'CGA': 'R', 'CGG': 'R',
        'ATT': 'I', 'ATC': 'I', 'ATA': 'I', 'ATG': 'M', 'ACT': 'T', 'ACC': 'T', 'ACA': 'T', 'ACG': 'T',
        'AAT': 'N', 'AAC': 'N', 'AAA': 'K', 'AAG': 'K', 'AGT': 'S', 'AGC': 'S', 'AGA': 'R', 'AGG': 'R',
        'GTT': 'V', 'GTC': 'V', 'GTA': 'V', 'GTG': 'V', 'GCT': 'A', 'GCC': 'A', 'GCA': 'A', 'GCG': 'A',
        'GAT': 'D', 'GAC': 'D', 'GAA': 'E', 'GAG': 'E', 'GGT': 'G', 'GGC': 'G', 'GGA': 'G', 'GGG': 'G'
    }
    n_codons = len(dna_seq) // 3
    return ''.join(table.get(dna_seq[3 * k:3 * k + 3], 'X') for k in range(n_codons))
def report_token_accuracy(sequences: List[str], target_proteins: List[str]) -> None:
    """Log per-sequence amino-acid identity between translated DNA and its target.

    Each generated DNA sequence is translated and compared position-wise to
    the corresponding target protein, over the shorter of the two lengths.
    """
    for idx, dna in enumerate(sequences):
        # If there are fewer targets than sequences, reuse the last target.
        target = target_proteins[idx] if idx < len(target_proteins) else target_proteins[-1]
        translated = translate_dna_to_aa(dna)
        span = min(len(translated), len(target))
        matches = sum(a == b for a, b in zip(translated[:span], target[:span]))
        accuracy = matches / span if span else 0.0
        logger.info(f"AA token accuracy seq_{idx+1}: {accuracy:.4f} ({matches}/{span})")
def main():
    """CLI entry point: validate conditioning, run batched sampling, save and report."""
    args = parse_args()
    if args.device == "cuda" and not torch.cuda.is_available():
        raise RuntimeError("CUDA requested but not available")
    if args.seed is not None:
        torch.manual_seed(int(args.seed))
    # Conditioning must be provided – same invariants as training
    have_species_names = bool(args.species_list) or bool(args.species)
    have_protein = bool(args.protein_file) or bool(args.protein_seq)
    if not have_species_names or not have_protein:
        raise ValueError("Sampling requires BOTH species (names) and protein sequence(s).")
    # Species names list (--species_list takes precedence over single --species)
    if args.species_list:
        species_names = list(args.species_list)
    else:
        species_names = [str(args.species)]
    # Protein sequences list (--protein_file takes precedence over single --protein_seq)
    if args.protein_file:
        protein_sequences = load_protein_sequences(args.protein_file)
    else:
        protein_sequences = [str(args.protein_seq)]
    # Expand/reconcile counts: a single species/protein is replicated to N;
    # species must then match N exactly, proteins must cover at least N.
    N = int(args.num_sequences)
    if len(species_names) == 1 and N > 1:
        species_names = species_names * N
    if len(protein_sequences) == 1 and N > 1:
        protein_sequences = protein_sequences * N
    if len(species_names) != N:
        raise ValueError(f"species count ({len(species_names)}) must equal num_sequences ({N})")
    if len(protein_sequences) < N:
        raise ValueError(f"protein sequences provided ({len(protein_sequences)}) less than num_sequences ({N})")
    if len(protein_sequences) > N:
        protein_sequences = protein_sequences[:N]
    # If no explicit sequence_length, use min protein length, so every sample has a valid AA at each fixed step
    if args.sequence_length is None:
        args.sequence_length = min(len(s) for s in protein_sequences)
        logger.info(f"Auto-set sequence_length to min protein length: {args.sequence_length} codons")
    if args.sequence_length <= 0:
        raise ValueError("sequence_length must be > 0")
    # Load species store if provided (preferred to exactly match training); unknown species will fallback to Qwen
    species_store = None
    if args.embeddings_dir:
        species_store = setup_species_store(args.embeddings_dir)
        logger.info(f"Loaded species store: {len(species_store.vocab)} species; Ds={species_store.Ds()}")
        if args.strict_species_lookup:
            # Fail fast on any requested name that is not an exact vocab key.
            unknown = sorted({name for name in species_names if name not in species_store.vocab})
            if unknown:
                preview = ", ".join(repr(x) for x in unknown[:5])
                more = "" if len(unknown) <= 5 else f" ... (+{len(unknown) - 5} more)"
                raise ValueError(
                    "strict species lookup failed; these names are not exact keys in species_vocab.json: "
                    f"{preview}{more}"
                )
    sampler = CodonSampler(
        model_path=args.model_path,
        device=args.device,
        compile_model=bool(args.compile),
        species_store=species_store,
        taxonomy_db_path=args.taxonomy_db,
    )
    # Batch loop
    batch_size = int(args.batch_size or N)
    all_sequences: List[str] = []
    all_intermediates = []
    total_batches = (N + batch_size - 1) // batch_size  # ceil(N / batch_size)
    for start in range(0, N, batch_size):
        end = min(N, start + batch_size)
        bs = end - start
        batch_species = species_names[start:end]
        batch_proteins = protein_sequences[start:end]
        logger.info(f"Sampling batch {start//batch_size + 1}/{total_batches} (B={bs})")
        result = sampler.sample(
            num_sequences=bs,
            sequence_length=int(args.sequence_length),
            species=batch_species,
            protein_sequences=batch_proteins,
            control_mode=str(args.control_mode),
            temperature=float(args.temperature),
            top_k=int(args.top_k),
            top_p=float(args.top_p),
            seed=int(args.seed) if args.seed is not None else None,
            return_intermediate=bool(args.save_intermediate),
            progress_bar=not bool(args.quiet),
            enforce_translation=bool(args.enforce_translation),
        )
        seqs = result["sequences"]  # List[str]
        all_sequences.extend(seqs)
        if args.save_intermediate and "intermediate_states" in result:
            all_intermediates.append(result["intermediate_states"])
    logger.info(f"Generated {len(all_sequences)} sequences.")
    # Log a short preview of the first few generated sequences.
    for i, seq in enumerate(all_sequences[:5]):
        logger.info(f"Sequence {i+1} ({len(seq)//3} codons): {seq[:60]}...")
    # Save outputs
    if args.output_file:
        meta = {
            "model_path": args.model_path,
            "temperature": args.temperature,
            "top_k": args.top_k,
            "top_p": args.top_p,
            "control_mode": args.control_mode,
            "sequence_length": int(args.sequence_length),
        }
        save_sequences(
            all_sequences,
            args.output_file,
            args.output_format,
            species=species_names,
            proteins=protein_sequences,
            metadata=meta,
        )
        logger.info(f"Saved sequences to {args.output_file}")
        # Report AA token accuracy when protein targets are given
        report_token_accuracy(all_sequences, protein_sequences)
        if args.save_intermediate and all_intermediates:
            # NOTE(review): intermediate states are only persisted when
            # --output_file is set, since their path is derived from it —
            # confirm this nesting matches the intended behavior.
            inter_file = Path(args.output_file).with_suffix("").as_posix() + "_intermediate.pt"
            torch.save(all_intermediates, inter_file)
            logger.info(f"Saved intermediate states to {inter_file}")
    logger.info("Sampling completed.")
if __name__ == "__main__":
    main()