#!/usr/bin/env python
"""
Sampling script for generating codon sequences from trained CodonTranslator models.

Inputs are prepared exactly like training:
- Species conditioning via SpeciesEmbeddingStore (fixed-size [B,Ds] or
  variable-length [B,Ls,Ds])
- Protein conditioning via raw AA strings (ESM-C tokenization happens inside
  the model)
"""
import argparse
import logging
import json
from pathlib import Path
from typing import List, Optional, Union

import torch

from src.sampler import CodonSampler
from src.dataset import SpeciesEmbeddingStore

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger("codontranslator.sample")

# Standard genetic code, codon (DNA alphabet) -> one-letter amino acid.
# Hoisted to module level so translate_dna_to_aa() does not rebuild the
# 64-entry dict on every call (it is invoked once per generated sequence).
_GENETIC_CODE = {
    'TTT': 'F', 'TTC': 'F', 'TTA': 'L', 'TTG': 'L',
    'TCT': 'S', 'TCC': 'S', 'TCA': 'S', 'TCG': 'S',
    'TAT': 'Y', 'TAC': 'Y', 'TAA': '*', 'TAG': '*',
    'TGT': 'C', 'TGC': 'C', 'TGA': '*', 'TGG': 'W',
    'CTT': 'L', 'CTC': 'L', 'CTA': 'L', 'CTG': 'L',
    'CCT': 'P', 'CCC': 'P', 'CCA': 'P', 'CCG': 'P',
    'CAT': 'H', 'CAC': 'H', 'CAA': 'Q', 'CAG': 'Q',
    'CGT': 'R', 'CGC': 'R', 'CGA': 'R', 'CGG': 'R',
    'ATT': 'I', 'ATC': 'I', 'ATA': 'I', 'ATG': 'M',
    'ACT': 'T', 'ACC': 'T', 'ACA': 'T', 'ACG': 'T',
    'AAT': 'N', 'AAC': 'N', 'AAA': 'K', 'AAG': 'K',
    'AGT': 'S', 'AGC': 'S', 'AGA': 'R', 'AGG': 'R',
    'GTT': 'V', 'GTC': 'V', 'GTA': 'V', 'GTG': 'V',
    'GCT': 'A', 'GCC': 'A', 'GCA': 'A', 'GCG': 'A',
    'GAT': 'D', 'GAC': 'D', 'GAA': 'E', 'GAG': 'E',
    'GGT': 'G', 'GGC': 'G', 'GGA': 'G', 'GGG': 'G',
}


def parse_args():
    """Build and parse the command-line interface for the sampling script."""
    p = argparse.ArgumentParser(description="Sample codon sequences from a CodonTranslator model")

    # Model
    p.add_argument("--model_path", "--model_dir", dest="model_path", type=str, required=True,
                   help="Path to trained model checkpoint dir")
    p.add_argument("--device", type=str, default="cuda", help="cuda or cpu")
    p.add_argument("--compile", action="store_true", help="torch.compile the model")

    # Species embeddings
    p.add_argument("--embeddings_dir", type=str, default=None,
                   help="Directory with precomputed variable-length species embeddings "
                        "(optional; fallback to Qwen if missing/unknown)")
    p.add_argument("--strict_species_lookup", action="store_true",
                   help="When using --embeddings_dir, fail if any requested species name "
                        "is not an exact key in species_vocab.json")
    p.add_argument("--taxonomy_db", type=str, default=None,
                   help="Optional path to taxonomy_database.json (from precompute) to enrich prompts")

    # Sampling batch size and count
    p.add_argument("--num_sequences", "--num_seq", "--num_samples", type=int, default=1,
                   dest="num_sequences", help="Number of sequences to generate in total")
    p.add_argument("--batch_size", type=int, default=None, help="Batch size for sampling loop")

    # Control mode and length
    p.add_argument("--control_mode", choices=["fixed", "variable"], default="fixed",
                   help="fixed: disallow EOS, generate exactly sequence_length codons; "
                        "variable: allow EOS")
    p.add_argument("--sequence_length", type=int, default=None,
                   help="Number of CODONS to generate (used as max steps in variable mode). "
                        "If omitted and protein sequences are provided, set to min protein length.")

    # Conditioning (REQUIRED: species and protein)
    p.add_argument("--species", "--taxon", type=str, default=None, dest="species",
                   help="Species name (e.g., 'Homo sapiens'). Replicated if num_sequences>1.")
    p.add_argument("--species_list", type=str, nargs="+", default=None,
                   help="List of species names (must match num_sequences).")
    p.add_argument("--protein_seq", "--protein_sequence", type=str, default=None, dest="protein_seq",
                   help="Protein sequence (AA string). Replicated if num_sequences>1.")
    p.add_argument("--protein_file", type=str, default=None,
                   help="Path to FASTA-like file (each non-header line is a sequence). "
                        "Must provide at least num_sequences.")

    # Sampling params
    p.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature")
    p.add_argument("--top_k", type=int, default=50, help="Top-k")
    p.add_argument("--top_p", type=float, default=0.9, help="Top-p (nucleus)")
    p.add_argument("--enforce_translation", action="store_true", default=False,
                   help="Hard-mask codons to match the given protein AA at each position")
    p.add_argument("--seed", type=int, default=None)
    p.add_argument("--save_intermediate", action="store_true",
                   help="Store intermediate token states")

    # Output
    p.add_argument("--output_file", type=str, default=None)
    p.add_argument("--output_format", type=str, default="fasta",
                   choices=["fasta", "csv", "json"])

    # Misc
    p.add_argument("--quiet", action="store_true")
    return p.parse_args()


def load_protein_sequences(file_path: str) -> List[str]:
    """Load protein sequences: every non-'>' line is a sequence.

    NOTE(review): this treats each physical line as one complete sequence, so
    multi-line FASTA records are NOT merged — this matches the documented
    "FASTA-like" input contract in --protein_file's help text.
    """
    with open(file_path, "r") as f:
        return [
            stripped
            for stripped in (line.strip() for line in f)
            if stripped and not stripped.startswith(">")
        ]


def setup_species_store(embeddings_dir: str) -> SpeciesEmbeddingStore:
    """Load species embedding store (prefer variable-length if available)."""
    # We don't guess. If you stored sequence-format, this will pick it; else fixed-size.
    return SpeciesEmbeddingStore(embeddings_dir, pooling="sequence")


def save_sequences(
    sequences: List[str],
    output_file: str,
    fmt: str,
    species: Optional[List[str]] = None,
    proteins: Optional[List[str]] = None,
    metadata: Optional[dict] = None,
):
    """Write generated sequences to disk in FASTA, CSV, or JSON format.

    Args:
        sequences: Generated DNA sequences (one string per sample).
        output_file: Destination path.
        fmt: One of "fasta", "csv", "json".
        species: Optional per-sequence species names for headers/columns.
        proteins: Optional per-sequence target proteins (lengths recorded in FASTA headers).
        metadata: Optional run metadata; embedded only in JSON output.
    """
    if fmt == "fasta":
        with open(output_file, "w") as f:
            for i, seq in enumerate(sequences):
                header = f">seq_{i}"
                if species and i < len(species):
                    header += f"|species={species[i]}"
                if proteins and i < len(proteins):
                    header += f"|protein_len={len(proteins[i])}"
                f.write(f"{header}\n{seq}\n")
        return

    if fmt == "csv":
        import pandas as pd  # local import keeps pandas optional for non-CSV output
        data = {"sequence": sequences}
        if species:
            data["species"] = species[:len(sequences)]
        if proteins:
            data["protein_sequence"] = proteins[:len(sequences)]
        pd.DataFrame(data).to_csv(output_file, index=False)
        return

    # json
    payload = {"sequences": sequences, "metadata": metadata or {}}
    if species:
        payload["species"] = species[:len(sequences)]
    if proteins:
        payload["protein_sequences"] = proteins[:len(sequences)]
    with open(output_file, "w") as f:
        json.dump(payload, f, indent=2)


def translate_dna_to_aa(dna_seq: str) -> str:
    """Translate DNA (3-mer) using the standard genetic code.

    Unknown/invalid codons map to 'X'; trailing bases that do not form a full
    codon are ignored.
    """
    n_codons = len(dna_seq) // 3
    return ''.join(
        _GENETIC_CODE.get(dna_seq[3 * i:3 * i + 3], 'X') for i in range(n_codons)
    )


def report_token_accuracy(sequences: List[str], target_proteins: List[str]) -> None:
    """Log per-sequence AA identity between translated output and its target protein.

    If there are fewer targets than sequences, the last target is reused
    (mirrors the replication semantics used when expanding inputs in main()).
    """
    for i, dna in enumerate(sequences):
        tgt = target_proteins[i] if i < len(target_proteins) else target_proteins[-1]
        gen_aa = translate_dna_to_aa(dna)
        L = min(len(gen_aa), len(tgt))
        if L == 0:
            acc = 0.0; num = 0; den = 0
        else:
            matches = sum(1 for a, b in zip(gen_aa[:L], tgt[:L]) if a == b)
            acc = matches / L; num = matches; den = L
        logger.info(f"AA token accuracy seq_{i+1}: {acc:.4f} ({num}/{den})")


def _resolve_conditioning(args) -> tuple:
    """Validate and expand species/protein conditioning to exactly num_sequences entries."""
    # Conditioning must be provided – same invariants as training
    have_species_names = bool(args.species_list) or bool(args.species)
    have_protein = bool(args.protein_file) or bool(args.protein_seq)
    if not have_species_names or not have_protein:
        raise ValueError("Sampling requires BOTH species (names) and protein sequence(s).")

    species_names = list(args.species_list) if args.species_list else [str(args.species)]
    if args.protein_file:
        protein_sequences = load_protein_sequences(args.protein_file)
    else:
        protein_sequences = [str(args.protein_seq)]

    # Expand singletons; enforce exact species count and at-least protein count
    N = int(args.num_sequences)
    if len(species_names) == 1 and N > 1:
        species_names = species_names * N
    if len(protein_sequences) == 1 and N > 1:
        protein_sequences = protein_sequences * N
    if len(species_names) != N:
        raise ValueError(f"species count ({len(species_names)}) must equal num_sequences ({N})")
    if len(protein_sequences) < N:
        raise ValueError(f"protein sequences provided ({len(protein_sequences)}) less than num_sequences ({N})")
    if len(protein_sequences) > N:
        protein_sequences = protein_sequences[:N]
    return species_names, protein_sequences


def _maybe_load_species_store(args, species_names: List[str]):
    """Load the species embedding store if --embeddings_dir was given; enforce strict lookup."""
    if not args.embeddings_dir:
        return None
    species_store = setup_species_store(args.embeddings_dir)
    logger.info(f"Loaded species store: {len(species_store.vocab)} species; Ds={species_store.Ds()}")
    if args.strict_species_lookup:
        unknown = sorted({name for name in species_names if name not in species_store.vocab})
        if unknown:
            preview = ", ".join(repr(x) for x in unknown[:5])
            more = "" if len(unknown) <= 5 else f" ... (+{len(unknown) - 5} more)"
            raise ValueError(
                "strict species lookup failed; these names are not exact keys in species_vocab.json: "
                f"{preview}{more}"
            )
    return species_store


def _run_batches(args, sampler, species_names: List[str], protein_sequences: List[str]):
    """Run the sampling loop in batches; return (sequences, intermediate_states)."""
    N = int(args.num_sequences)
    batch_size = int(args.batch_size or N)
    all_sequences: List[str] = []
    all_intermediates = []
    total_batches = (N + batch_size - 1) // batch_size
    for start in range(0, N, batch_size):
        end = min(N, start + batch_size)
        bs = end - start
        batch_species = species_names[start:end]
        batch_proteins = protein_sequences[start:end]
        logger.info(f"Sampling batch {start//batch_size + 1}/{total_batches} (B={bs})")
        # NOTE(review): the same --seed is passed to every batch, so batches
        # with identical conditioning produce identical samples — confirm this
        # reproducibility behavior is intended.
        result = sampler.sample(
            num_sequences=bs,
            sequence_length=int(args.sequence_length),
            species=batch_species,
            protein_sequences=batch_proteins,
            control_mode=str(args.control_mode),
            temperature=float(args.temperature),
            top_k=int(args.top_k),
            top_p=float(args.top_p),
            seed=int(args.seed) if args.seed is not None else None,
            return_intermediate=bool(args.save_intermediate),
            progress_bar=not bool(args.quiet),
            enforce_translation=bool(args.enforce_translation),
        )
        all_sequences.extend(result["sequences"])  # List[str]
        if args.save_intermediate and "intermediate_states" in result:
            all_intermediates.append(result["intermediate_states"])
    return all_sequences, all_intermediates


def main():
    args = parse_args()

    if args.device == "cuda" and not torch.cuda.is_available():
        raise RuntimeError("CUDA requested but not available")
    if args.seed is not None:
        torch.manual_seed(int(args.seed))

    species_names, protein_sequences = _resolve_conditioning(args)

    # If no explicit sequence_length, use min protein length, so every sample
    # has a valid AA at each fixed step
    if args.sequence_length is None:
        args.sequence_length = min(len(s) for s in protein_sequences)
        logger.info(f"Auto-set sequence_length to min protein length: {args.sequence_length} codons")
    if args.sequence_length <= 0:
        raise ValueError("sequence_length must be > 0")

    # Load species store if provided (preferred to exactly match training);
    # unknown species will fallback to Qwen
    species_store = _maybe_load_species_store(args, species_names)

    sampler = CodonSampler(
        model_path=args.model_path,
        device=args.device,
        compile_model=bool(args.compile),
        species_store=species_store,
        taxonomy_db_path=args.taxonomy_db,
    )

    all_sequences, all_intermediates = _run_batches(args, sampler, species_names, protein_sequences)

    logger.info(f"Generated {len(all_sequences)} sequences.")
    for i, seq in enumerate(all_sequences[:5]):
        logger.info(f"Sequence {i+1} ({len(seq)//3} codons): {seq[:60]}...")

    # Save outputs
    if args.output_file:
        meta = {
            "model_path": args.model_path,
            "temperature": args.temperature,
            "top_k": args.top_k,
            "top_p": args.top_p,
            "control_mode": args.control_mode,
            "sequence_length": int(args.sequence_length),
        }
        save_sequences(
            all_sequences,
            args.output_file,
            args.output_format,
            species=species_names,
            proteins=protein_sequences,
            metadata=meta,
        )
        logger.info(f"Saved sequences to {args.output_file}")

    # Report AA token accuracy when protein targets are given
    report_token_accuracy(all_sequences, protein_sequences)

    if args.save_intermediate and all_intermediates:
        # BUG FIX: the original called Path(args.output_file) unconditionally,
        # raising TypeError when --save_intermediate was used without
        # --output_file. Guard and warn instead of crashing after sampling.
        if args.output_file:
            inter_file = Path(args.output_file).with_suffix("").as_posix() + "_intermediate.pt"
            torch.save(all_intermediates, inter_file)
            logger.info(f"Saved intermediate states to {inter_file}")
        else:
            logger.warning("--save_intermediate was set but --output_file is missing; intermediate states not saved.")

    logger.info("Sampling completed.")


if __name__ == "__main__":
    main()