| |
| """ |
| Sampling script for generating codon sequences from trained CodonTranslator models. |
| Inputs are prepared exactly like training: |
| - Species conditioning via SpeciesEmbeddingStore (fixed-size [B,Ds] or variable-length [B,Ls,Ds]) |
| - Protein conditioning via raw AA strings (ESM-C tokenization happens inside the model) |
| """ |
|
|
| import argparse |
| import logging |
| import json |
| from pathlib import Path |
| from typing import List, Optional, Union |
|
|
| import torch |
|
|
| from src.sampler import CodonSampler |
| from src.dataset import SpeciesEmbeddingStore |
|
|
# Configure root logging once at import time so every module logger inherits
# the same timestamped format.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
)

# Module-level logger for this sampling script.
logger = logging.getLogger("codontranslator.sample")
|
|
|
|
def parse_args():
    """Build the sampling CLI and return the parsed argparse namespace."""
    parser = argparse.ArgumentParser(description="Sample codon sequences from a CodonTranslator model")

    # --- model / runtime ---
    parser.add_argument("--model_path", "--model_dir", dest="model_path", required=True, type=str,
                        help="Path to trained model checkpoint dir")
    parser.add_argument("--device", default="cuda", type=str, help="cuda or cpu")
    parser.add_argument("--compile", action="store_true", help="torch.compile the model")

    # --- species embedding sources ---
    parser.add_argument("--embeddings_dir", default=None, type=str,
                        help="Directory with precomputed variable-length species embeddings (optional; fallback to Qwen if missing/unknown)")
    parser.add_argument("--strict_species_lookup", action="store_true",
                        help="When using --embeddings_dir, fail if any requested species name is not an exact key in species_vocab.json")
    parser.add_argument("--taxonomy_db", default=None, type=str,
                        help="Optional path to taxonomy_database.json (from precompute) to enrich prompts")

    # --- how many sequences / batching ---
    parser.add_argument("--num_sequences", "--num_seq", "--num_samples", dest="num_sequences", default=1, type=int,
                        help="Number of sequences to generate in total")
    parser.add_argument("--batch_size", default=None, type=int, help="Batch size for sampling loop")

    # --- length control ---
    parser.add_argument("--control_mode", default="fixed", choices=["fixed", "variable"],
                        help="fixed: disallow EOS, generate exactly sequence_length codons; variable: allow EOS")
    parser.add_argument("--sequence_length", default=None, type=int,
                        help="Number of CODONS to generate (used as max steps in variable mode). "
                             "If omitted and protein sequences are provided, set to min protein length.")

    # --- conditioning inputs (species names and protein AA strings) ---
    parser.add_argument("--species", "--taxon", dest="species", default=None, type=str,
                        help="Species name (e.g., 'Homo sapiens'). Replicated if num_sequences>1.")
    parser.add_argument("--species_list", nargs="+", default=None, type=str,
                        help="List of species names (must match num_sequences).")

    parser.add_argument("--protein_seq", "--protein_sequence", dest="protein_seq", default=None, type=str,
                        help="Protein sequence (AA string). Replicated if num_sequences>1.")
    parser.add_argument("--protein_file", default=None, type=str,
                        help="Path to FASTA-like file (each non-header line is a sequence). Must provide at least num_sequences.")

    # --- decoding knobs ---
    parser.add_argument("--temperature", default=1, type=float, help="Sampling temperature")
    parser.add_argument("--top_k", default=50, type=int, help="Top-k")
    parser.add_argument("--top_p", default=0.9, type=float, help="Top-p (nucleus)")
    parser.add_argument("--enforce_translation", action="store_true", default=False,
                        help="Hard-mask codons to match the given protein AA at each position")
    parser.add_argument("--seed", default=None, type=int)
    parser.add_argument("--save_intermediate", action="store_true", help="Store intermediate token states")

    # --- output ---
    parser.add_argument("--output_file", default=None, type=str)
    parser.add_argument("--output_format", default="fasta", type=str, choices=["fasta", "csv", "json"])

    parser.add_argument("--quiet", action="store_true")
    return parser.parse_args()
|
|
|
|
def load_protein_sequences(file_path: str) -> List[str]:
    """Read protein sequences from a FASTA-like file.

    Every stripped, non-empty line that does not begin with '>' is treated
    as one protein sequence; header lines and blank lines are skipped.
    """
    with open(file_path, "r") as handle:
        stripped = (raw.strip() for raw in handle)
        return [line for line in stripped if line and not line.startswith(">")]
|
|
|
|
def setup_species_store(embeddings_dir: str) -> SpeciesEmbeddingStore:
    """Build the species embedding store over *embeddings_dir*.

    Sequence pooling is requested so variable-length embeddings are used
    when the directory provides them.
    """
    store = SpeciesEmbeddingStore(embeddings_dir, pooling="sequence")
    return store
|
|
|
|
def save_sequences(
    sequences: List[str],
    output_file: str,
    fmt: str,
    species: Optional[List[str]] = None,
    proteins: Optional[List[str]] = None,
    metadata: Optional[dict] = None,
):
    """Write generated sequences to *output_file* as fasta, csv, or json.

    Per-record species / protein annotations are attached when provided.
    Any format other than 'fasta' or 'csv' falls through to JSON output.
    """
    n = len(sequences)

    if fmt == "fasta":
        with open(output_file, "w") as out:
            for idx, seq in enumerate(sequences):
                parts = [f">seq_{idx}"]
                if species and idx < len(species):
                    parts.append(f"|species={species[idx]}")
                if proteins and idx < len(proteins):
                    parts.append(f"|protein_len={len(proteins[idx])}")
                out.write(f"{''.join(parts)}\n{seq}\n")
        return

    if fmt == "csv":
        # Imported lazily so the fasta/json paths do not require pandas.
        import pandas as pd
        columns = {"sequence": sequences}
        if species:
            columns["species"] = species[:n]
        if proteins:
            columns["protein_sequence"] = proteins[:n]
        pd.DataFrame(columns).to_csv(output_file, index=False)
        return

    # Fallback: JSON payload carrying sequences plus optional annotations.
    payload = {"sequences": sequences, "metadata": metadata or {}}
    if species:
        payload["species"] = species[:n]
    if proteins:
        payload["protein_sequences"] = proteins[:n]
    with open(output_file, "w") as out:
        json.dump(payload, out, indent=2)
|
|
|
|
def translate_dna_to_aa(dna_seq: str) -> str:
    """Translate DNA (3-mer) using the standard genetic code.

    Unrecognized codons map to 'X'; any trailing partial codon is ignored.
    """
    bases = "TCAG"
    # Standard genetic code in TCAG-major codon order: entry k corresponds to
    # the k-th codon of (b1, b2, b3) with each base cycling through T,C,A,G.
    amino_acids = "FFLLSSSSYY**CC*WLLLLPPPPHHQQRRRRIIIMTTTTNNKKSSRRVVVVAAAADDEEGGGG"
    all_codons = (b1 + b2 + b3 for b1 in bases for b2 in bases for b3 in bases)
    codon_table = dict(zip(all_codons, amino_acids))

    n_codons = len(dna_seq) // 3
    return ''.join(
        codon_table.get(dna_seq[3 * k:3 * k + 3], 'X') for k in range(n_codons)
    )
|
|
|
|
def report_token_accuracy(sequences: List[str], target_proteins: List[str]) -> None:
    """Log per-sequence amino-acid accuracy of generated DNA vs. target proteins.

    Each DNA sequence is translated with the standard genetic code and compared
    position-by-position against the corresponding target protein; the last
    target is reused when there are fewer targets than sequences. Accuracy is
    computed over the overlap of the two lengths; an empty overlap logs 0.0 (0/0).
    """
    if not target_proteins:
        # Bug fix: target_proteins[-1] would raise IndexError on an empty list.
        logger.warning("report_token_accuracy: no target proteins given; skipping")
        return
    for i, dna in enumerate(sequences):
        tgt = target_proteins[i] if i < len(target_proteins) else target_proteins[-1]
        gen_aa = translate_dna_to_aa(dna)
        overlap = min(len(gen_aa), len(tgt))
        if overlap == 0:
            matches, acc = 0, 0.0
        else:
            matches = sum(1 for a, b in zip(gen_aa[:overlap], tgt[:overlap]) if a == b)
            acc = matches / overlap
        # Lazy %-formatting so the message is only built if the log is emitted.
        logger.info("AA token accuracy seq_%d: %.4f (%d/%d)", i + 1, acc, matches, overlap)
|
|
|
|
def main():
    """CLI entry point: validate inputs, build the sampler, generate, and save.

    Both conditioning signals are mandatory: species name(s) and protein AA
    sequence(s). Singleton inputs are replicated to num_sequences. Sampling
    runs in batches of --batch_size (default: all at once); results are
    optionally written to --output_file, and AA translation accuracy against
    the target proteins is always reported.

    Raises:
        RuntimeError: if CUDA is requested but unavailable.
        ValueError: on missing/mismatched conditioning inputs or bad lengths.
    """
    args = parse_args()

    if args.device == "cuda" and not torch.cuda.is_available():
        raise RuntimeError("CUDA requested but not available")

    if args.seed is not None:
        torch.manual_seed(int(args.seed))

    # Species and protein conditioning must both be present (either the
    # scalar flag or the list/file variant counts).
    have_species_names = bool(args.species_list) or bool(args.species)
    have_protein = bool(args.protein_file) or bool(args.protein_seq)
    if not have_species_names or not have_protein:
        raise ValueError("Sampling requires BOTH species (names) and protein sequence(s).")

    if args.species_list:
        species_names = list(args.species_list)
    else:
        species_names = [str(args.species)]

    if args.protein_file:
        protein_sequences = load_protein_sequences(args.protein_file)
    else:
        protein_sequences = [str(args.protein_seq)]

    N = int(args.num_sequences)
    # Bug fix: N <= 0 previously slipped through and crashed later in
    # min() over an empty protein list; fail early with a clear message.
    if N < 1:
        raise ValueError(f"num_sequences must be >= 1, got {N}")

    # Replicate singleton conditioning so every sampled sequence is paired.
    if len(species_names) == 1 and N > 1:
        species_names = species_names * N
    if len(protein_sequences) == 1 and N > 1:
        protein_sequences = protein_sequences * N

    if len(species_names) != N:
        raise ValueError(f"species count ({len(species_names)}) must equal num_sequences ({N})")
    if len(protein_sequences) < N:
        raise ValueError(f"protein sequences provided ({len(protein_sequences)}) less than num_sequences ({N})")
    if len(protein_sequences) > N:
        protein_sequences = protein_sequences[:N]

    # Default generation length: the shortest target protein (in codons).
    if args.sequence_length is None:
        args.sequence_length = min(len(s) for s in protein_sequences)
        logger.info(f"Auto-set sequence_length to min protein length: {args.sequence_length} codons")

    if args.sequence_length <= 0:
        raise ValueError("sequence_length must be > 0")

    # Optional precomputed species embeddings; the sampler falls back to
    # its own encoder when no store is provided.
    species_store = None
    if args.embeddings_dir:
        species_store = setup_species_store(args.embeddings_dir)
        logger.info(f"Loaded species store: {len(species_store.vocab)} species; Ds={species_store.Ds()}")
        if args.strict_species_lookup:
            unknown = sorted({name for name in species_names if name not in species_store.vocab})
            if unknown:
                preview = ", ".join(repr(x) for x in unknown[:5])
                more = "" if len(unknown) <= 5 else f" ... (+{len(unknown) - 5} more)"
                raise ValueError(
                    "strict species lookup failed; these names are not exact keys in species_vocab.json: "
                    f"{preview}{more}"
                )

    sampler = CodonSampler(
        model_path=args.model_path,
        device=args.device,
        compile_model=bool(args.compile),
        species_store=species_store,
        taxonomy_db_path=args.taxonomy_db,
    )

    # Batched sampling loop; batch_size falls back to N (single batch).
    batch_size = int(args.batch_size or N)
    all_sequences: List[str] = []
    all_intermediates = []

    total_batches = (N + batch_size - 1) // batch_size
    for start in range(0, N, batch_size):
        end = min(N, start + batch_size)
        bs = end - start
        batch_species = species_names[start:end]
        batch_proteins = protein_sequences[start:end]

        logger.info(f"Sampling batch {start//batch_size + 1}/{total_batches} (B={bs})")

        result = sampler.sample(
            num_sequences=bs,
            sequence_length=int(args.sequence_length),
            species=batch_species,
            protein_sequences=batch_proteins,
            control_mode=str(args.control_mode),
            temperature=float(args.temperature),
            top_k=int(args.top_k),
            top_p=float(args.top_p),
            seed=int(args.seed) if args.seed is not None else None,
            return_intermediate=bool(args.save_intermediate),
            progress_bar=not bool(args.quiet),
            enforce_translation=bool(args.enforce_translation),
        )

        seqs = result["sequences"]
        all_sequences.extend(seqs)
        if args.save_intermediate and "intermediate_states" in result:
            all_intermediates.append(result["intermediate_states"])

    logger.info(f"Generated {len(all_sequences)} sequences.")
    for i, seq in enumerate(all_sequences[:5]):
        logger.info(f"Sequence {i+1} ({len(seq)//3} codons): {seq[:60]}...")

    if args.output_file:
        meta = {
            "model_path": args.model_path,
            "temperature": args.temperature,
            "top_k": args.top_k,
            "top_p": args.top_p,
            "control_mode": args.control_mode,
            "sequence_length": int(args.sequence_length),
        }
        save_sequences(
            all_sequences,
            args.output_file,
            args.output_format,
            species=species_names,
            proteins=protein_sequences,
            metadata=meta,
        )
        logger.info(f"Saved sequences to {args.output_file}")

    # Always report AA-level fidelity of the generated codons.
    report_token_accuracy(all_sequences, protein_sequences)

    if args.save_intermediate and all_intermediates:
        # Bug fix: Path(None) raised TypeError when --save_intermediate was
        # used without --output_file; fall back to a default basename.
        base = Path(args.output_file) if args.output_file else Path("samples")
        inter_file = base.with_suffix("").as_posix() + "_intermediate.pt"
        torch.save(all_intermediates, inter_file)
        logger.info(f"Saved intermediate states to {inter_file}")

    logger.info("Sampling completed.")
|
|
|
|
# Script entry point: only run sampling when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|