File size: 13,399 Bytes
2d8da02
 
5343ca4
2d8da02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5343ca4
2d8da02
 
 
5343ca4
2d8da02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
#!/usr/bin/env python
"""
Sampling script for generating codon sequences from trained CodonTranslator models.
Inputs are prepared exactly like training:
- Species conditioning via SpeciesEmbeddingStore (fixed-size [B,Ds] or variable-length [B,Ls,Ds])
- Protein conditioning via raw AA strings (ESM-C tokenization happens inside the model)
"""

import argparse
import logging
import json
from pathlib import Path
from typing import List, Optional, Union

import torch

from src.sampler import CodonSampler
from src.dataset import SpeciesEmbeddingStore

# Configure root logging once at import time so every module logger in this
# process inherits the same timestamped format.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
# Script-local logger, namespaced under "codontranslator" for filtering.
logger = logging.getLogger("codontranslator.sample")


def parse_args():
    """Parse command-line arguments for codon-sequence sampling.

    Returns:
        argparse.Namespace with model, conditioning, sampling, and output
        options. Note: species/protein conditioning is *required* by the
        pipeline but validated in main(), not here, so a clearer error
        message can be raised.
    """
    p = argparse.ArgumentParser(description="Sample codon sequences from a CodonTranslator model")

    # Model
    # Multiple option aliases map onto a single dest for CLI convenience.
    p.add_argument("--model_path", "--model_dir", dest="model_path", type=str, required=True,
                   help="Path to trained model checkpoint dir")
    p.add_argument("--device", type=str, default="cuda", help="cuda or cpu")
    p.add_argument("--compile", action="store_true", help="torch.compile the model")

    # Species embeddings
    p.add_argument("--embeddings_dir", type=str, default=None,
                   help="Directory with precomputed variable-length species embeddings (optional; fallback to Qwen if missing/unknown)")
    p.add_argument("--strict_species_lookup", action="store_true",
                   help="When using --embeddings_dir, fail if any requested species name is not an exact key in species_vocab.json")
    p.add_argument("--taxonomy_db", type=str, default=None,
                   help="Optional path to taxonomy_database.json (from precompute) to enrich prompts")

    # Sampling batch size and count
    p.add_argument("--num_sequences", "--num_seq", "--num_samples", type=int, default=1, dest="num_sequences",
                   help="Number of sequences to generate in total")
    p.add_argument("--batch_size", type=int, default=None, help="Batch size for sampling loop")

    # Control mode and length
    p.add_argument("--control_mode", choices=["fixed", "variable"], default="fixed",
                   help="fixed: disallow EOS, generate exactly sequence_length codons; variable: allow EOS")
    p.add_argument("--sequence_length", type=int, default=None,
                   help="Number of CODONS to generate (used as max steps in variable mode). "
                        "If omitted and protein sequences are provided, set to min protein length.")

    # Conditioning (REQUIRED: species and protein)
    # Either a single value (replicated across samples) or an explicit list.
    p.add_argument("--species", "--taxon", type=str, default=None, dest="species",
                   help="Species name (e.g., 'Homo sapiens'). Replicated if num_sequences>1.")
    p.add_argument("--species_list", type=str, nargs="+", default=None,
                   help="List of species names (must match num_sequences).")

    p.add_argument("--protein_seq", "--protein_sequence", type=str, default=None, dest="protein_seq",
                   help="Protein sequence (AA string). Replicated if num_sequences>1.")
    p.add_argument("--protein_file", type=str, default=None,
                   help="Path to FASTA-like file (each non-header line is a sequence). Must provide at least num_sequences.")

    # Sampling params
    p.add_argument("--temperature", type=float, default=1, help="Sampling temperature")
    p.add_argument("--top_k", type=int, default=50, help="Top-k")
    p.add_argument("--top_p", type=float, default=0.9, help="Top-p (nucleus)")
    p.add_argument("--enforce_translation", action="store_true", default=False,
                   help="Hard-mask codons to match the given protein AA at each position")
    p.add_argument("--seed", type=int, default=None)
    p.add_argument("--save_intermediate", action="store_true", help="Store intermediate token states")

    # Output
    p.add_argument("--output_file", type=str, default=None)
    p.add_argument("--output_format", type=str, default="fasta", choices=["fasta", "csv", "json"])

    # Misc
    p.add_argument("--quiet", action="store_true")
    return p.parse_args()


def load_protein_sequences(file_path: str) -> List[str]:
    """Read protein sequences from a FASTA-like file.

    Header lines (starting with '>') and blank lines are skipped; every
    remaining stripped line is treated as one amino-acid sequence.
    """
    with open(file_path, "r") as handle:
        stripped = (raw.strip() for raw in handle)
        return [entry for entry in stripped if entry and not entry.startswith(">")]


def setup_species_store(embeddings_dir: str) -> SpeciesEmbeddingStore:
    """Build the species embedding store used for conditioning.

    Requests "sequence" pooling so variable-length embeddings are used when
    stored; otherwise the store serves the fixed-size format. No further
    guessing is done here.
    """
    store = SpeciesEmbeddingStore(embeddings_dir, pooling="sequence")
    return store


def save_sequences(
    sequences: List[str],
    output_file: str,
    fmt: str,
    species: Optional[List[str]] = None,
    proteins: Optional[List[str]] = None,
    metadata: Optional[dict] = None,
):
    """Write generated sequences to ``output_file`` as FASTA, CSV, or JSON.

    Optional species/protein annotations are attached per sequence and are
    truncated to the number of sequences actually written; ``metadata`` is
    only emitted in the JSON format.
    """
    count = len(sequences)

    if fmt == "fasta":
        with open(output_file, "w") as out:
            for idx, seq in enumerate(sequences):
                header = f">seq_{idx}"
                if species and idx < len(species):
                    header += f"|species={species[idx]}"
                if proteins and idx < len(proteins):
                    header += f"|protein_len={len(proteins[idx])}"
                out.write(f"{header}\n{seq}\n")
        return

    if fmt == "csv":
        # Local import keeps pandas optional for the FASTA/JSON paths.
        import pandas as pd
        columns = {"sequence": sequences}
        if species:
            columns["species"] = species[:count]
        if proteins:
            columns["protein_sequence"] = proteins[:count]
        pd.DataFrame(columns).to_csv(output_file, index=False)
        return

    # Default: JSON payload.
    payload = {"sequences": sequences, "metadata": metadata or {}}
    if species:
        payload["species"] = species[:count]
    if proteins:
        payload["protein_sequences"] = proteins[:count]
    with open(output_file, "w") as out:
        json.dump(payload, out, indent=2)


def translate_dna_to_aa(dna_seq: str) -> str:
    """Translate a DNA sequence into amino acids using the standard genetic code.

    Reads non-overlapping codons (3-mers) from the start of the sequence;
    any trailing partial codon is dropped. Stop codons map to '*' and any
    codon not in the table (e.g. containing 'N') maps to 'X'.

    Generalization over the previous version: input case is normalized, so
    lowercase/mixed-case nucleotides are accepted. Uppercase input behaves
    exactly as before.

    Args:
        dna_seq: DNA string; length need not be a multiple of 3.

    Returns:
        Amino-acid string of length len(dna_seq) // 3.
    """
    g = {
        'TTT': 'F', 'TTC': 'F', 'TTA': 'L', 'TTG': 'L', 'TCT': 'S', 'TCC': 'S', 'TCA': 'S', 'TCG': 'S',
        'TAT': 'Y', 'TAC': 'Y', 'TAA': '*', 'TAG': '*', 'TGT': 'C', 'TGC': 'C', 'TGA': '*', 'TGG': 'W',
        'CTT': 'L', 'CTC': 'L', 'CTA': 'L', 'CTG': 'L', 'CCT': 'P', 'CCC': 'P', 'CCA': 'P', 'CCG': 'P',
        'CAT': 'H', 'CAC': 'H', 'CAA': 'Q', 'CAG': 'Q', 'CGT': 'R', 'CGC': 'R', 'CGA': 'R', 'CGG': 'R',
        'ATT': 'I', 'ATC': 'I', 'ATA': 'I', 'ATG': 'M', 'ACT': 'T', 'ACC': 'T', 'ACA': 'T', 'ACG': 'T',
        'AAT': 'N', 'AAC': 'N', 'AAA': 'K', 'AAG': 'K', 'AGT': 'S', 'AGC': 'S', 'AGA': 'R', 'AGG': 'R',
        'GTT': 'V', 'GTC': 'V', 'GTA': 'V', 'GTG': 'V', 'GCT': 'A', 'GCC': 'A', 'GCA': 'A', 'GCG': 'A',
        'GAT': 'D', 'GAC': 'D', 'GAA': 'E', 'GAG': 'E', 'GGT': 'G', 'GGC': 'G', 'GGA': 'G', 'GGG': 'G'
    }
    seq = dna_seq.upper()  # accept lowercase input; table keys are uppercase
    n_codons = len(seq) // 3
    return ''.join(g.get(seq[3 * i:3 * i + 3], 'X') for i in range(n_codons))


def report_token_accuracy(sequences: List[str], target_proteins: List[str]) -> None:
    """Log per-sequence amino-acid accuracy of generated DNA vs. target proteins.

    Each DNA sequence is translated and compared position-by-position with
    its target protein (if there are fewer targets than sequences, the last
    target is reused). Accuracy is computed over the overlapping prefix;
    an empty overlap is reported as 0.0 (0/0).
    """
    for idx, dna in enumerate(sequences):
        if idx < len(target_proteins):
            target = target_proteins[idx]
        else:
            target = target_proteins[-1]
        translated = translate_dna_to_aa(dna)
        overlap = min(len(translated), len(target))
        if overlap == 0:
            matches, accuracy = 0, 0.0
        else:
            matches = sum(a == b for a, b in zip(translated[:overlap], target[:overlap]))
            accuracy = matches / overlap
        logger.info(f"AA token accuracy seq_{idx+1}: {accuracy:.4f} ({matches}/{overlap})")


def main():
    """CLI entry point: validate conditioning inputs, run batched sampling, save results."""
    args = parse_args()

    # Fail fast on an impossible device request rather than inside torch.
    if args.device == "cuda" and not torch.cuda.is_available():
        raise RuntimeError("CUDA requested but not available")

    if args.seed is not None:
        torch.manual_seed(int(args.seed))

    # Conditioning must be provided – same invariants as training
    have_species_names = bool(args.species_list) or bool(args.species)
    have_protein = bool(args.protein_file) or bool(args.protein_seq)
    if not have_species_names or not have_protein:
        raise ValueError("Sampling requires BOTH species (names) and protein sequence(s).")

    # Species names list
    # --species_list takes precedence over the single --species value.
    if args.species_list:
        species_names = list(args.species_list)
    else:
        species_names = [str(args.species)]

    # Protein sequences list
    # --protein_file takes precedence over the single --protein_seq value.
    if args.protein_file:
        protein_sequences = load_protein_sequences(args.protein_file)
    else:
        protein_sequences = [str(args.protein_seq)]

    # Expand/reconcile counts
    # A single species/protein is replicated N times; lists must then match N.
    N = int(args.num_sequences)
    if len(species_names) == 1 and N > 1:
        species_names = species_names * N
    if len(protein_sequences) == 1 and N > 1:
        protein_sequences = protein_sequences * N

    if len(species_names) != N:
        raise ValueError(f"species count ({len(species_names)}) must equal num_sequences ({N})")
    if len(protein_sequences) < N:
        raise ValueError(f"protein sequences provided ({len(protein_sequences)}) less than num_sequences ({N})")
    if len(protein_sequences) > N:
        # Extra proteins are silently truncated to the requested count.
        protein_sequences = protein_sequences[:N]

    # If no explicit sequence_length, use min protein length, so every sample has a valid AA at each fixed step
    if args.sequence_length is None:
        args.sequence_length = min(len(s) for s in protein_sequences)
        logger.info(f"Auto-set sequence_length to min protein length: {args.sequence_length} codons")

    if args.sequence_length <= 0:
        raise ValueError("sequence_length must be > 0")

    # Load species store if provided (preferred to exactly match training); unknown species will fallback to Qwen
    species_store = None
    if args.embeddings_dir:
        species_store = setup_species_store(args.embeddings_dir)
        logger.info(f"Loaded species store: {len(species_store.vocab)} species; Ds={species_store.Ds()}")
        if args.strict_species_lookup:
            # Strict mode: every requested name must be an exact vocab key.
            unknown = sorted({name for name in species_names if name not in species_store.vocab})
            if unknown:
                preview = ", ".join(repr(x) for x in unknown[:5])
                more = "" if len(unknown) <= 5 else f" ... (+{len(unknown) - 5} more)"
                raise ValueError(
                    "strict species lookup failed; these names are not exact keys in species_vocab.json: "
                    f"{preview}{more}"
                )

    # Sampler owns model loading, device placement, and optional compilation.
    sampler = CodonSampler(
        model_path=args.model_path,
        device=args.device,
        compile_model=bool(args.compile),
        species_store=species_store,
        taxonomy_db_path=args.taxonomy_db,
    )

    # Batch loop
    # Default batch size is N, i.e. one batch for everything.
    batch_size = int(args.batch_size or N)
    all_sequences: List[str] = []
    all_intermediates = []

    total_batches = (N + batch_size - 1) // batch_size
    for start in range(0, N, batch_size):
        end = min(N, start + batch_size)
        bs = end - start
        batch_species = species_names[start:end]
        batch_proteins = protein_sequences[start:end]

        logger.info(f"Sampling batch {start//batch_size + 1}/{total_batches} (B={bs})")

        result = sampler.sample(
            num_sequences=bs,
            sequence_length=int(args.sequence_length),
            species=batch_species,
            protein_sequences=batch_proteins,
            control_mode=str(args.control_mode),
            temperature=float(args.temperature),
            top_k=int(args.top_k),
            top_p=float(args.top_p),
            seed=int(args.seed) if args.seed is not None else None,
            return_intermediate=bool(args.save_intermediate),
            progress_bar=not bool(args.quiet),
            enforce_translation=bool(args.enforce_translation),
        )

        seqs = result["sequences"]  # List[str]
        all_sequences.extend(seqs)
        if args.save_intermediate and "intermediate_states" in result:
            all_intermediates.append(result["intermediate_states"])

    logger.info(f"Generated {len(all_sequences)} sequences.")
    # Preview the first few sequences (truncated to 60 nt) for a quick sanity check.
    for i, seq in enumerate(all_sequences[:5]):
        logger.info(f"Sequence {i+1} ({len(seq)//3} codons): {seq[:60]}...")

    # Save outputs
    # NOTE(review): accuracy reporting and intermediate-state saving only run
    # when --output_file is given — presumably intentional; confirm.
    if args.output_file:
        meta = {
            "model_path": args.model_path,
            "temperature": args.temperature,
            "top_k": args.top_k,
            "top_p": args.top_p,
            "control_mode": args.control_mode,
            "sequence_length": int(args.sequence_length),
        }
        save_sequences(
            all_sequences,
            args.output_file,
            args.output_format,
            species=species_names,
            proteins=protein_sequences,
            metadata=meta,
        )
        logger.info(f"Saved sequences to {args.output_file}")

        # Report AA token accuracy when protein targets are given
        report_token_accuracy(all_sequences, protein_sequences)

        if args.save_intermediate and all_intermediates:
            # Intermediate token states go next to the output file with an _intermediate.pt suffix.
            inter_file = Path(args.output_file).with_suffix("").as_posix() + "_intermediate.pt"
            torch.save(all_intermediates, inter_file)
            logger.info(f"Saved intermediate states to {inter_file}")

    logger.info("Sampling completed.")


if __name__ == "__main__":
    main()