CodonTranslator / eval.py
alegendaryfish's picture
Rename public-facing CodonGPT strings to CodonTranslator
4547beb verified
#!/usr/bin/env python
"""
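Teacher-forced (and optional free-run) evaluation of CodonTranslator on a random
subset of your dataset, measuring codon-token cross-entropy and amino-acid (AA)
token accuracy through the same conditioning pathway used in training.
Additional modes: --eval_all streams over every row of --data_path and reports
dataset-level metrics; --export_per_sequence writes a per-sequence CSV for the
splits found under --splits_root.
Input may be a single CSV file or Parquet files given as a directory or glob
(e.g., ./data/val/*.parquet).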
Usage examples:
# CSV input
python eval.py \
--model_path outputs/checkpoint-21000 \
--data_path random_sample_1000.csv \
--embeddings_dir embeddings \
--num_samples 10 \
--batch_size 10 \
--device cuda
# Parquet glob input
python eval.py \
--model_path outputs/checkpoint-21000 \
--data_path "./data/val/*.parquet" \
--embeddings_dir embeddings \
--num_samples 64 \
--batch_size 32 \
--device cuda
"""
import argparse
import json
import logging
import random
from pathlib import Path
from typing import List, Optional, Tuple
import glob
import torch
import torch.nn.functional as F
import pandas as pd
from src.sampler import CodonSampler
from src.dataset import SpeciesEmbeddingStore, StreamSeqDataset, stage_collate_fn
from torch.utils.data import DataLoader
# Root logging configuration for the script: INFO level, timestamped records.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)
# Module-wide logger under a fixed name (not __name__).
logger = logging.getLogger("eval_tf")
def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Options are registered, in order, as: checkpoint/data inputs, general run
    settings, free-run sampling controls, full-dataset streaming, diagnostics,
    per-sequence export, and prefix-capacity controls.

    Returns:
        argparse.Namespace with one attribute per option.
    """
    parser = argparse.ArgumentParser("Teacher-forced evaluation of CodonTranslator")
    add = parser.add_argument

    # --- Checkpoint and input data ---------------------------------------
    add("--model_path", type=str, required=True,
        help="Path to checkpoint dir (with config.json / model.safetensors)")
    add("--data_path", type=str, default=None, required=False,
        help="CSV file or Parquet glob/dir (e.g., ./data/val/*.parquet)")
    # Still accepted so older invocations keep working.
    add("--csv_path", type=str, default=None, required=False,
        help="[Deprecated] CSV with columns: Taxon, protein_seq, cds_DNA")
    add("--embeddings_dir", type=str, default=None,
        help="Species embeddings directory (recommended for parity)")

    # --- General run settings --------------------------------------------
    add("--num_samples", type=int, default=10)
    add("--batch_size", type=int, default=10)
    add("--seed", type=int, default=42)
    add("--device", type=str, default="cuda")
    add("--workers", type=int, default=0,
        help="DataLoader workers for --eval_all streaming mode")

    # --- Free-run sampling ------------------------------------------------
    add("--free_run", action="store_true",
        help="If set, perform real sampling instead of teacher forcing and compare to ground-truth codon sequences")
    add("--temperature", type=float, default=0.8)
    add("--top_k", type=int, default=50)
    add("--top_p", type=float, default=0.9)
    add("--control_mode", type=str, default="fixed", choices=["fixed", "variable"])
    add("--enforce_translation", action="store_true",
        help="Hard-mask decoding to codons matching target amino acid at each position during free-run evaluation")

    # --- Full-dataset streaming evaluation ---------------------------------
    add("--eval_all", action="store_true",
        help="Stream over all rows from --data_path and compute aggregated metrics (memory-safe)")
    add("--max_records", type=int, default=0,
        help="When --eval_all is set: limit to first N samples (0 = all)")

    # --- Diagnostics -------------------------------------------------------
    add("--debug_aa_check", action="store_true",
        help="Print per-sample agreement between CDS→AA (standard code) and provided protein")

    # --- Per-sequence export over standard splits --------------------------
    add("--export_per_sequence", action="store_true",
        help="Process ./data/val and ./data/test parquets in batches and export a per-sequence CSV")
    add("--splits_root", type=str, default="./data",
        help="Root directory that contains val/ and test/ subfolders with parquet files")
    add("--out_csv", type=str, default="outputs/eval_per_sequence.csv",
        help="Output CSV path for per-sequence export")
    add("--export_splits", nargs="+", default=["val", "test"],
        help="Subdirectories under --splits_root to process (default: val test)")
    add("--max_rows_per_split", type=int, default=0,
        help="When --export_per_sequence is set: limit number of rows per split (0 = all)")
    add("--progress", action="store_true",
        help="Show progress bars during per-sequence export")

    # --- Capacity controls -------------------------------------------------
    add("--no_truncation", action="store_true",
        help="Fit prefix caps so generated codon length equals protein length (avoids capacity truncation)")
    add("--species_prefix_cap", type=int, default=0,
        help="When >0 and --no_truncation is set, cap species token prefix to this many tokens; 0 = no species cap")

    return parser.parse_args()
def _is_parquet_path(p: str) -> bool:
lower = p.lower()
return lower.endswith(".parquet") or lower.endswith(".parq")
def _expand_paths(maybe_path_or_glob: str) -> List[str]:
"""Expand a path/glob or directory into a sorted list of files.
Prioritize Parquet when scanning a directory.
"""
paths: List[str] = []
P = Path(maybe_path_or_glob)
if P.is_dir():
paths.extend(sorted(str(x) for x in P.rglob("*.parquet")))
paths.extend(sorted(str(x) for x in P.rglob("*.parq")))
paths.extend(sorted(str(x) for x in P.rglob("*.csv")))
paths.extend(sorted(str(x) for x in P.rglob("*.tsv")))
paths.extend(sorted(str(x) for x in P.rglob("*.csv.gz")))
paths.extend(sorted(str(x) for x in P.rglob("*.tsv.gz")))
else:
paths = sorted(glob.glob(str(P)))
# Dedup while preserving order
out: List[str] = []
seen = set()
for x in paths:
if x not in seen:
out.append(x)
seen.add(x)
return out
def _load_random_samples_from_parquet(files: List[str], num_samples: int, seed: int) -> pd.DataFrame:
    """Collect up to ``num_samples`` rows from Parquet files, one row group at a time.

    Only the ``Taxon``, ``protein_seq`` and ``cds_DNA`` columns are read. File
    order and row-group order within each file are shuffled with
    ``random.Random(seed)``, so the same seed yields the same rows.

    Args:
        files: Candidate paths; entries that are not ``.parquet``/``.parq`` are ignored.
        num_samples: Maximum number of rows to return (negative is treated as 0).
        seed: Seed for file/row-group shuffling and the final row shuffle.

    Returns:
        A DataFrame with the required columns and at most ``num_samples`` rows
        (an empty DataFrame with those columns if nothing was collected).

    Raises:
        ImportError: if ``pyarrow`` cannot be imported.
        FileNotFoundError: if ``files`` contains no Parquet paths.
        ValueError: if a visited Parquet file lacks a required column.
    """
    try:
        import pyarrow.parquet as pq  # type: ignore
    except Exception as e:  # pragma: no cover
        raise ImportError("pyarrow is required to read parquet files") from e
    rng = random.Random(seed)
    req = ["Taxon", "protein_seq", "cds_DNA"]
    # The comprehension builds a private list, so shuffling it never touches
    # the caller's list.
    files = [f for f in files if _is_parquet_path(f)]
    if not files:
        raise FileNotFoundError("No Parquet files found to read")
    rng.shuffle(files)

    collected: List[pd.DataFrame] = []
    remaining = max(0, int(num_samples))
    for fp in files:
        if remaining <= 0:
            break
        pf = pq.ParquetFile(fp)
        missing = sorted(set(req) - set(pf.schema.names))
        if missing:
            raise ValueError(f"Parquet missing required columns {missing} in {fp}")
        n_groups = int(pf.num_row_groups or 0)
        if n_groups <= 0:
            # No row groups means no rows; reading row group 0 would fail.
            continue
        rgs = list(range(n_groups))
        rng.shuffle(rgs)
        for rg in rgs:
            if remaining <= 0:
                break
            df = pf.read_row_group(rg, columns=req).to_pandas()
            if df.empty:
                continue
            if len(df) > remaining:
                # Take a random subset of this row group to hit the target exactly.
                df = df.sample(n=remaining, random_state=rng.randint(0, 2**31 - 1))
            collected.append(df)
            remaining -= len(df)

    if not collected:
        return pd.DataFrame(columns=req)
    out = pd.concat(collected, ignore_index=True)
    # Final shuffle so rows from the same row group are not contiguous.
    out = out.sample(frac=1.0, random_state=seed).reset_index(drop=True)
    # Defensive trim; the loop above should never overshoot.
    if len(out) > num_samples:
        out = out.iloc[:num_samples].reset_index(drop=True)
    return out
def _preferred_pooling(model_dir: Path) -> str:
"""
Best-effort pooling detection:
- First try checkpoint configs for an explicit hint
- Fallback to 'last'
Note: we'll further override this using the embeddings_dir contents if provided.
"""
for cfg_name in ("trainer_config.json", "config.json"):
fp = model_dir / cfg_name
if fp.exists():
try:
with open(fp) as f:
cfg = json.load(f)
return str(cfg.get("species_pooling", "last"))
except Exception:
continue
return "last"
def _detect_pooling_from_embeddings_dir(emb_dir: Path) -> Optional[str]:
"""Detect actual available pooling format from embeddings_dir contents."""
fixed_files = [emb_dir / "species_embeddings.bin", emb_dir / "species_metadata.json", emb_dir / "species_vocab.json"]
seq_files = [emb_dir / "species_tok_emb.bin", emb_dir / "species_index.json", emb_dir / "species_vocab.json"]
if all(p.exists() for p in fixed_files):
return "last"
if all(p.exists() for p in seq_files):
return "sequence"
return None
@torch.no_grad()
def eval_batch(
    sampler: CodonSampler,
    species_store: Optional[SpeciesEmbeddingStore],
    species_names: List[str],
    protein_seqs: List[str],
    dna_cds_list: List[str],
) -> Tuple[List[float], List[float]]:
    """Evaluate one batch in teacher-forced mode.

    Each CDS is trimmed to ``max(min(len(dna) // 3, len(protein)), 1)`` codons,
    encoded with ``sampler.tokenizer``, terminated with EOS and right-padded.
    Conditioning mirrors training: species embeddings come from
    ``species_store`` when given (unknown names map to id -1), otherwise from
    ``sampler._qwen_embed_names`` with sequence pooling; raw protein strings are
    passed as ``cond["protein_seqs"]``.

    Args:
        sampler: Provides ``tokenizer``, ``model`` and ``device``.
        species_store: Optional precomputed species embedding store.
        species_names: One species name per sample.
        protein_seqs: One protein string per sample.
        dna_cds_list: One CDS DNA string per sample.

    Returns:
        ``(per_sample_ce, per_sample_aa_acc)``: the mean cross-entropy over
        supervised codon positions, and the amino-acid accuracy of the argmax
        predictions versus the provided protein. Both default to 0.0 for a
        sample with nothing to score. An empty batch returns ``([], [])``.
    """
    tok = sampler.tokenizer
    pad_id = tok.pad_token_id
    eos_id = tok.eos_token_id
    device = sampler.device

    # Encode each CDS, trimmed to its overlap with the protein, then append EOS.
    codon_ids: List[List[int]] = []
    for dna, prot in zip(dna_cds_list, protein_seqs):
        n_codons = max(min(len(dna) // 3, len(prot)), 1)
        ids = tok.encode_codon_seq(dna[: 3 * n_codons], validate=False)
        ids.append(eos_id)
        codon_ids.append(ids)
    if not codon_ids:
        # Nothing to evaluate; max() over an empty batch would raise below.
        return [], []

    B = len(codon_ids)
    T = max(len(ids) for ids in codon_ids)
    codons = torch.full((B, T), pad_id, dtype=torch.long)
    for i, ids in enumerate(codon_ids):
        codons[i, : len(ids)] = torch.tensor(ids, dtype=torch.long)

    # Training convention: the model predicts each codon after a learned start
    # token, so labels sit at the same positions as the inputs (no shift by 1).
    # The last column is dropped, and PAD/EOS targets are ignored (-100) as in
    # trainer.evaluate().
    input_ids = codons[:, :-1]
    labels_base = codons[:, :-1].clone()
    labels_base[labels_base == pad_id] = -100
    labels_base[labels_base == eos_id] = -100

    # Conditioning dict, built the same way as in training and in the sampler.
    cond = {"control_mode": "fixed"}
    if species_store is not None and species_names:
        sid_list = [species_store.vocab.get(s, -1) for s in species_names]
        num_unknown = sum(1 for x in sid_list if x < 0)
        if num_unknown > 0:
            logger.warning(f"{num_unknown}/{len(sid_list)} species not found in embeddings vocab; using zero embeddings")
        result = species_store.batch_get(sid_list)
        if isinstance(result, tuple):
            sp_tok, _ = result  # token-level species embeddings
            cond["species_tok_emb_src"] = sp_tok.to(device)
            cond["species_tok_emb_tgt"] = sp_tok.to(device)
        else:
            cond["species_emb_src"] = result.to(device)
            cond["species_emb_tgt"] = result.to(device)
    elif species_names:
        # On-the-fly species embeddings (sequence pooling for training parity).
        seq_emb, _lens = sampler._qwen_embed_names(species_names, pooling="sequence")
        seq_emb = seq_emb.to(device)
        cond["species_tok_emb_src"] = seq_emb
        cond["species_tok_emb_tgt"] = seq_emb
    # Match training: pass raw protein sequences; the model tokenizes internally.
    cond["protein_seqs"] = protein_seqs

    input_ids = input_ids.to(device)
    labels_base = labels_base.to(device)
    sampler.model.eval()
    outputs = sampler.model(codon_ids=input_ids, cond=cond, labels=labels_base, return_dict=True)
    logits = outputs["logits"]  # [B, Lmax, V]; Lmax is the post-prefix capacity

    if logger.isEnabledFor(logging.DEBUG):
        try:
            prefix_len = outputs.get("prefix_len", 0)
            if isinstance(prefix_len, torch.Tensor):
                prefix_len_dbg = int(prefix_len.max().item()) if prefix_len.numel() > 0 else 0
            else:
                prefix_len_dbg = int(prefix_len)
            logger.debug(f"Prefix length(max)={prefix_len_dbg}, input_len={input_ids.size(1)}")
        except Exception:
            # Debug-only diagnostics must never interrupt evaluation.
            pass

    # Align labels to the logits length.
    Bsz, Lmax, V = logits.shape
    labels_aligned = torch.full((Bsz, Lmax), -100, dtype=labels_base.dtype, device=logits.device)
    common_cols = min(labels_base.size(1), Lmax)
    if common_cols > 0:
        labels_aligned[:, :common_cols] = labels_base[:, :common_cols].to(logits.device)

    # Per-sample capacity mask: positions at or beyond per_cap are not scored.
    per_cap = outputs.get("per_cap", None)
    per_cap_int = None
    if isinstance(per_cap, torch.Tensor) and per_cap.numel() == Bsz:
        ar = torch.arange(Lmax, device=logits.device).unsqueeze(0)
        cap_mask = ar < per_cap.to(device=logits.device).unsqueeze(1)
        per_cap_int = torch.clamp(per_cap.to(dtype=torch.long, device=logits.device), min=0, max=Lmax)
    else:
        cap_mask = torch.ones((Bsz, Lmax), dtype=torch.bool, device=logits.device)

    labels_masked = labels_aligned.clone()
    labels_masked[~cap_mask] = -100

    # Per-position CE; -100 targets (PAD, EOS, beyond capacity) are ignored.
    loss_flat = F.cross_entropy(
        logits.reshape(-1, V),
        labels_masked.reshape(-1),
        ignore_index=-100,
        reduction="none",
    ).view(Bsz, Lmax)

    preds = logits.argmax(dim=-1)
    num_special = int(getattr(tok, "num_special_tokens", 0) or 0)
    # Supervised = labelled and within capacity (both encoded in labels_masked)
    # and, when the tokenizer reports special tokens, not a special-token id.
    supervised = labels_masked != -100
    if num_special > 0:
        supervised = supervised & (labels_aligned >= num_special)

    codon2aa = tok.codon2aa_char_map() if hasattr(tok, "codon2aa_char_map") else {}
    per_sample_ce: List[float] = []
    per_sample_aa_acc: List[float] = []
    for i in range(B):
        row_mask = supervised[i]
        per_sample_ce.append(loss_flat[i][row_mask].mean().item() if row_mask.any() else 0.0)

        # AA-level accuracy (matches trainer): translate predicted codons within
        # capacity and compare to the protein prefix of the same length.
        aa_acc = 0.0
        if per_cap_int is not None and codon2aa and i < len(protein_seqs):
            cap = int(per_cap_int[i].item())
            if cap > 0:
                mask_row = supervised[i, :cap]
                if mask_row.any():
                    preds_row = preds[i, :cap][mask_row]
                    prot = protein_seqs[i]
                    seq_len = min(len(prot), preds_row.size(0))
                    if seq_len > 0:
                        pred_aa = "".join(codon2aa.get(int(t), "X") for t in preds_row[:seq_len].tolist())
                        aa_acc = sum(p == q for p, q in zip(pred_aa, prot[:seq_len])) / seq_len
        per_sample_aa_acc.append(aa_acc)
    return per_sample_ce, per_sample_aa_acc
def _dna_to_codons(dna: str) -> List[str]:
dna = dna.strip().upper()
return [dna[i:i+3] for i in range(0, len(dna) - (len(dna) % 3), 3)]
def _aa_from_dna_standard(dna: str, tok) -> str:
dna = dna.strip().upper()
gc = getattr(tok, "_genetic_code", {})
aa = []
for j in range(0, len(dna) - (len(dna) % 3), 3):
aa.append(gc.get(dna[j:j+3], 'X'))
return ''.join(aa)
def _aa_agreement(dna: str, protein: str, tok) -> Tuple[float, int, int]:
"""Return (match_ratio, compared_len, first_mismatch_idx or -1) under standard code."""
dna = dna.strip().upper()
protein = protein.strip().upper()
L = min(len(dna) // 3, len(protein))
if L <= 0:
return 0.0, 0, -1
aa_pred = _aa_from_dna_standard(dna[: 3 * L], tok)
truth = protein[:L]
mism_idx = -1
matches = 0
for i, (a, b) in enumerate(zip(aa_pred, truth)):
if a == b:
matches += 1
elif mism_idx < 0:
mism_idx = i
return (matches / L), L, mism_idx
@torch.no_grad()
def eval_streaming_all(
    sampler: CodonSampler,
    species_store: SpeciesEmbeddingStore,
    data_path: str,
    batch_size: int,
    num_workers: int,
    max_records: int = 0,
):
    """Stream over all rows from CSV/Parquet inputs and compute dataset-level metrics.

    Mirrors trainer.evaluate() for parity: codon ids from ``StreamSeqDataset``
    are fed teacher-forced (PAD/EOS targets ignored). Species embeddings come
    from ``species_store`` and raw protein strings are passed in ``cond``.

    Args:
        sampler: Provides ``tokenizer``, ``model`` and ``device``.
        species_store: Species embedding store. Required, because its
            ``embeddings_dir`` supplies ``species_vocab.json`` for the dataset.
        data_path: File, directory or glob (expanded with ``_expand_paths``).
        batch_size: DataLoader batch size.
        num_workers: DataLoader workers (negative treated as 0).
        max_records: Evaluate at most this many samples; 0 means all.

    Returns:
        ``(mean_ce, codon_acc, aa_acc)``. ``mean_ce`` is the batch loss
        weighted by the summed per-sample capacity of batches that returned a
        loss.

    Raises:
        ValueError: if ``species_store`` is None.
        FileNotFoundError: if nothing matches ``data_path``.
    """
    if species_store is None:
        raise ValueError("eval_streaming_all requires a species embedding store (--embeddings_dir)")
    device = sampler.device
    tok = sampler.tokenizer
    pad_id = int(tok.pad_token_id)
    eos_id = int(tok.eos_token_id)
    num_special = int(tok.num_special_tokens)
    codon2aa = tok.codon2aa_char_map()
    limit = int(max_records) if int(max_records) > 0 else None

    paths = _expand_paths(data_path)
    if not paths:
        raise FileNotFoundError(f"No input files matched: {data_path}")
    species_vocab_path = str((Path(species_store.embeddings_dir) / "species_vocab.json").resolve())
    ds = StreamSeqDataset(
        files=paths,
        tokenizer=tok,
        species_vocab_path=species_vocab_path,
        unknown_species_id=0,
        csv_chunksize=200_000,
        shuffle_buffer=0,
        shard_across_ranks=False,
    )
    workers = max(0, int(num_workers))
    dl_kwargs = dict(
        batch_size=int(batch_size),
        shuffle=False,
        drop_last=False,
        num_workers=workers,
        collate_fn=stage_collate_fn,
        # Pinned host memory only helps host-to-GPU copies.
        pin_memory=str(device).startswith("cuda"),
        persistent_workers=workers > 0,
    )
    if workers > 0:
        dl_kwargs["prefetch_factor"] = 4
    loader = DataLoader(ds, **dl_kwargs)

    loss_sum = 0.0
    loss_tokens = 0
    codon_correct = 0
    codon_total = 0
    aa_correct = 0
    aa_total = 0
    seen = 0
    for batch in loader:
        if not batch:
            continue
        if limit is not None and seen >= limit:
            break
        codon_ids = batch["codon_ids"]
        sids = batch.get("species_ids")
        prot_list = batch.get("protein_seqs", [])
        if limit is not None and seen + codon_ids.size(0) > limit:
            # Trim the final batch so exactly `max_records` samples are scored.
            take = limit - seen
            codon_ids = codon_ids[:take]
            sids = sids[:take]
            prot_list = prot_list[:take]
        codon_ids = codon_ids.to(device)

        input_ids = codon_ids[:, :-1]
        labels = input_ids.clone()
        labels[labels == pad_id] = -100
        labels[labels == eos_id] = -100

        cond = {"control_mode": "fixed", "protein_seqs": prot_list}
        sids_list = sids.detach().cpu().tolist() if torch.is_tensor(sids) else [int(x) for x in sids]
        res = species_store.batch_get(sids_list)
        if isinstance(res, tuple):
            sp_tok, _ = res
            cond["species_tok_emb_src"] = sp_tok.to(device)
            cond["species_tok_emb_tgt"] = sp_tok.to(device)
        else:
            cond["species_emb_src"] = res.to(device)
            cond["species_emb_tgt"] = res.to(device)

        out = sampler.model(codon_ids=input_ids, cond=cond, labels=labels, return_dict=True)
        loss = out.get("loss")
        per_cap = out.get("per_cap")
        logits = out.get("logits")

        # Weight each batch's loss by its summed capacity. Tokens are only
        # counted when a loss was returned, so they never dilute the mean.
        if loss is not None and per_cap is not None:
            tokens_in_batch = int(torch.clamp(per_cap.detach(), min=0).sum().item())
            if tokens_in_batch > 0:
                loss_sum += float(loss.detach().item()) * tokens_in_batch
                loss_tokens += tokens_in_batch

        if logits is None or logits.size(1) == 0 or per_cap is None:
            seen += input_ids.size(0)
            continue

        bsz, max_cap = logits.size(0), logits.size(1)
        labels_aligned = torch.full((bsz, max_cap), -100, dtype=labels.dtype, device=labels.device)
        common = min(labels.size(1), max_cap)
        if common > 0:
            labels_aligned[:, :common] = labels[:, :common]
        per_cap_int = torch.clamp(per_cap.to(dtype=torch.long), min=0, max=max_cap).reshape(-1)
        # Ignore every position at or beyond each row's capacity.
        beyond_cap = (
            torch.arange(max_cap, device=labels_aligned.device).unsqueeze(0)
            >= per_cap_int.to(labels_aligned.device).unsqueeze(1)
        )
        labels_aligned[beyond_cap] = -100

        supervised = labels_aligned != -100
        if num_special > 0:
            supervised = supervised & (labels_aligned >= num_special)
        if not supervised.any():
            seen += bsz
            continue

        preds = logits.argmax(dim=-1)
        codon_correct += int((preds[supervised] == labels_aligned[supervised]).sum().item())
        codon_total += int(supervised.sum().item())

        # AA accuracy: translate supervised predictions within capacity and
        # compare against the protein prefix of the same length.
        for row, cap in enumerate(per_cap_int.tolist()):
            if cap <= 0:
                continue
            mask_row = supervised[row, :cap]
            if not mask_row.any():
                continue
            prot = prot_list[row] if (isinstance(prot_list, list) and row < len(prot_list)) else ""
            if not prot:
                continue
            preds_row = preds[row, :cap][mask_row]
            seq_len = min(len(prot), preds_row.size(0))
            pred_aa = "".join(codon2aa.get(int(t), 'X') for t in preds_row[:seq_len].tolist())
            aa_correct += sum(p == q for p, q in zip(pred_aa, prot[:seq_len]))
            aa_total += seq_len
        seen += bsz

    mean_ce = (loss_sum / loss_tokens) if loss_tokens > 0 else 0.0
    codon_acc = (float(codon_correct) / codon_total) if codon_total > 0 else 0.0
    aa_acc = (float(aa_correct) / aa_total) if aa_total > 0 else 0.0
    logger.info(
        f"Full-dataset summary → tokens={loss_tokens} CE={mean_ce:.4f} CODON-acc={codon_acc:.4f} AA-acc={aa_acc:.4f}"
    )
    return mean_ce, codon_acc, aa_acc
@torch.no_grad()
def sample_and_score_batched(
    sampler: CodonSampler,
    species_names: List[str],
    protein_seqs: List[str],
    target_dnas: List[str],
    temperature: float,
    top_k: int,
    top_p: float,
    control_mode: str,
    batch_size: int,
    enforce_translation: bool,
    no_truncation: bool = False,
    species_prefix_cap: int = 64,
) -> Tuple[List[float], List[float]]:
    """Free-run sampling in batches; returns per-sample ``(codon_acc, aa_acc)``.

    This previously duplicated ``generate_and_score_batched`` line for line
    (length bucketing, optional ``no_truncation`` prefix fitting, sampling and
    scoring) and differed only in not returning the generated DNA. It now
    delegates, so both paths share a single implementation; the generated
    sequences are discarded.
    """
    _generated, codon_accs, aa_accs = generate_and_score_batched(
        sampler,
        species_names,
        protein_seqs,
        target_dnas,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        control_mode=control_mode,
        batch_size=batch_size,
        enforce_translation=enforce_translation,
        no_truncation=no_truncation,
        species_prefix_cap=species_prefix_cap,
    )
    return codon_accs, aa_accs
def _fit_prefix_caps_for_length(
    sampler: CodonSampler,
    species_names: List[str],
    protein_seqs: List[str],
    idxs: List[int],
    L: int,
    batch_size: int,
    species_prefix_cap: int,
) -> None:
    """Best-effort adjustment of the model's prefix caps so ``prefix + 1 + L`` fits capacity.

    Capacity is ``sampler.model.max_position_embeddings`` (default 1024). For a
    legacy species store with ``species_prefix_cap > 0``, this sets
    ``max_species_prefix``. It then probes the prefix length with up to 8
    samples of the bucket, shrinking ``max_protein_prefix`` and re-measuring at
    most three times. It mutates ``sampler.model``; the caller must restore the
    caps.
    """
    capacity = int(getattr(sampler.model, "max_position_embeddings", 1024))
    store = getattr(sampler, "species_store", None)
    if store is not None and getattr(store, "is_legacy", False) and int(species_prefix_cap) > 0:
        setattr(sampler.model, "max_species_prefix", int(species_prefix_cap))

    # Representative conditioning for this bucket, used to measure the prefix length.
    probe_idx = idxs[: min(len(idxs), max(1, min(batch_size, 8)))]
    cond_probe = {"control_mode": "fixed", "protein_seqs": [protein_seqs[i] for i in probe_idx]}
    if store is not None:
        sid_list = [store.vocab.get(species_names[i], -1) for i in probe_idx]
        res = store.batch_get(sid_list)
        if isinstance(res, tuple):
            sp_tok, _ = res
            cond_probe["species_tok_emb_src"] = sp_tok.to(sampler.device)
            cond_probe["species_tok_emb_tgt"] = sp_tok.to(sampler.device)
        else:
            cond_probe["species_emb_src"] = res.to(sampler.device)
            cond_probe["species_emb_tgt"] = res.to(sampler.device)

    for _ in range(3):
        out0 = sampler.model(
            codon_ids=torch.zeros(len(probe_idx), 0, dtype=torch.long, device=sampler.device),
            cond=cond_probe,
            return_dict=True,
            use_cache=True,
        )
        pref = out0.get("prefix_len")
        if isinstance(pref, torch.Tensor) and pref.numel() > 0:
            pref_max = int(pref.max().item())
        elif isinstance(pref, int):
            pref_max = int(pref)
        else:
            pref_max = 0
        remaining = capacity - (pref_max + 1)  # +1 for the start token
        if remaining >= int(L):
            return
        need = int(L) - max(0, int(remaining))
        cur_pp = int(getattr(sampler.model, "max_protein_prefix", 0) or 0)
        # NOTE(review): when cur_pp == 0 the fallback works out to L - remaining,
        # i.e. the shortfall itself, rather than "current protein prefix minus the
        # shortfall". This always fits but may over-truncate the protein prefix.
        # Confirm the intended meaning of max_protein_prefix == 0.
        new_pp = max(0, cur_pp - need) if cur_pp > 0 else max(0, pref_max - (capacity - 1 - int(L)))
        setattr(sampler.model, "max_protein_prefix", int(new_pp))


@torch.no_grad()
def generate_and_score_batched(
    sampler: CodonSampler,
    species_names: List[str],
    protein_seqs: List[str],
    target_dnas: List[str],
    temperature: float,
    top_k: int,
    top_p: float,
    control_mode: str,
    batch_size: int,
    enforce_translation: bool,
    no_truncation: bool = False,
    species_prefix_cap: int = 64,
) -> Tuple[List[str], List[float], List[float]]:
    """Free-run sample codon sequences in length buckets and score them against targets.

    Each target length is ``L = min(#codons in target DNA, len(protein))``. Rows
    with no overlap use ``L = 1`` and a single ``"ATG"`` placeholder target. Rows
    sharing ``L`` are sampled together in chunks of ``batch_size``. With
    ``no_truncation`` the model's prefix caps are adjusted per bucket
    (best effort) and always restored afterwards, even if sampling raises.

    Returns:
        ``(generated_dna, codon_accs, aa_accs)``, each aligned with the inputs.
        ``codon_acc`` is the number of exact codon matches divided by ``L``.
        ``aa_acc`` is the number of amino-acid matches divided by ``L``, where a
        target residue outside ``ACDEFGHIKLMNPQRSTVWY`` counts as a match.
    """
    N = len(species_names)
    tgt_lengths: List[int] = []
    tgt_codons_list: List[List[str]] = []
    for prot, dna in zip(protein_seqs, target_dnas):
        cods = _dna_to_codons(dna)
        L = min(len(cods), len(prot))
        if L <= 0:
            L = 1
            cods = ["ATG"]  # harmless placeholder target
        tgt_lengths.append(L)
        tgt_codons_list.append(cods[:L])

    # Bucket row indices by target length so each sample() call has one length.
    buckets: dict[int, List[int]] = {}
    for i, L in enumerate(tgt_lengths):
        buckets.setdefault(L, []).append(i)

    gen_all = [""] * N
    codon_accs = [0.0] * N
    aa_accs = [0.0] * N
    genetic_code = sampler.tokenizer._genetic_code
    canonical = frozenset("ACDEFGHIKLMNPQRSTVWY")

    def dna_to_aa(dna: str) -> str:
        dna = dna.strip().upper()
        return "".join(
            genetic_code.get(dna[j:j + 3], 'X') for j in range(0, len(dna) - (len(dna) % 3), 3)
        )

    for L, idxs in buckets.items():
        prev_sp = getattr(sampler.model, "max_species_prefix", 0)
        prev_pp = getattr(sampler.model, "max_protein_prefix", 0)
        try:
            if no_truncation:
                try:
                    _fit_prefix_caps_for_length(
                        sampler, species_names, protein_seqs, idxs, L, batch_size, species_prefix_cap
                    )
                except Exception as exc:
                    # Best effort: sample with the current caps instead of aborting.
                    logger.warning(
                        "Prefix fitting failed for length %d (%s); sampling with current caps", L, exc
                    )
            for k in range(0, len(idxs), batch_size):
                batch_idx = idxs[k:k + batch_size]
                out = sampler.sample(
                    num_sequences=len(batch_idx),
                    sequence_length=L,
                    species=[species_names[i] for i in batch_idx],
                    protein_sequences=[protein_seqs[i] for i in batch_idx],
                    control_mode=control_mode,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    return_intermediate=False,
                    progress_bar=False,
                    enforce_translation=enforce_translation,
                )
                gen_list: List[str] = out["sequences"]  # DNA strings
                for pos, idx in enumerate(batch_idx):
                    gen_seq = gen_list[pos]
                    gen_all[idx] = gen_seq
                    gen_codons = _dna_to_codons(gen_seq)[:L]
                    # L >= 1 by construction, so the divisions are safe.
                    matches = sum(1 for a, b in zip(gen_codons, tgt_codons_list[idx]) if a == b)
                    codon_accs[idx] = matches / L
                    gen_aa = dna_to_aa("".join(gen_codons))
                    tgt_aa = protein_seqs[idx][:L]
                    aa_matches = sum(1 for a, b in zip(gen_aa, tgt_aa) if (b not in canonical) or (a == b))
                    aa_accs[idx] = aa_matches / L
        finally:
            if no_truncation:
                try:
                    setattr(sampler.model, "max_species_prefix", prev_sp)
                    setattr(sampler.model, "max_protein_prefix", prev_pp)
                except Exception:
                    # Restoration is best effort; never mask an in-flight exception.
                    pass
    return gen_all, codon_accs, aa_accs
def export_per_sequence_over_splits(
sampler: CodonSampler,
splits: List[str],
splits_root: str,
out_csv: str,
batch_size: int,
temperature: float,
top_k: int,
top_p: float,
control_mode: str,
enforce_translation: bool,
progress: bool = False,
max_rows_per_split: int = 0,
no_truncation: bool = False,
species_prefix_cap: int = 0,
) -> None:
"""Process ./data/val and ./data/test (or under splits_root) and write a per-sequence CSV."""
try:
import pyarrow.parquet as pq # type: ignore
except Exception as e:
raise ImportError("pyarrow is required for Parquet evaluation/export") from e
from pathlib import Path as _P
import os as _os
total_written = 0
# Pre-create CSV with header so users can tail it immediately
header_cols = [
"split",
"organism",
"protein_seq",
"codon_seq",
"predicted_seq",
"codon_similarity",
"amino_acid_recovery_rate",
]
_P(out_csv).parent.mkdir(parents=True, exist_ok=True)
if not _P(out_csv).exists() or _os.path.getsize(out_csv) == 0:
with open(out_csv, "w", newline="") as f:
f.write(",".join(header_cols) + "\n")
logging.info(f"Initialized CSV with header → {out_csv}")
for split in splits:
rows_remaining = int(max_rows_per_split) if int(max_rows_per_split) > 0 else None
dir_path = Path(splits_root) / split
files = sorted(str(p) for p in dir_path.glob("*.parquet"))
if not files:
logging.warning(f"No parquet files found in {dir_path}, skipping split {split}")
continue
logging.info(f"Processing split '{split}' with {len(files)} files ...")
try:
from tqdm import tqdm # type: ignore
_wrap = (lambda it, **kw: tqdm(it, **kw)) if progress else (lambda it, **kw: it)
except Exception:
_wrap = (lambda it, **kw: it)
stop_split = False
for fp in _wrap(files, desc=f"{split} files", unit="file"):
if rows_remaining is not None and rows_remaining <= 0:
break
pf = pq.ParquetFile(fp)
nrg = int(pf.num_row_groups or 0)
rgs = list(range(max(nrg, 1)))
# Build a per-file rows progress bar (prefer total rows from metadata when available)
rows_total = None
try:
if pf.metadata is not None:
rows_total = 0
for rg_idx in rgs:
rg_md = pf.metadata.row_group(rg_idx)
if rg_md is not None and rg_md.num_rows is not None:
rows_total += int(rg_md.num_rows)
except Exception:
rows_total = None
rows_pbar = None
if progress:
try:
from tqdm import tqdm # type: ignore
rows_pbar = tqdm(total=rows_total, desc=f"{split}:{Path(fp).name}", unit="rows", leave=False)
except Exception:
rows_pbar = None
for rg in rgs:
if rows_remaining is not None and rows_remaining <= 0:
stop_split = True
break
table = pf.read_row_group(rg, columns=["Taxon", "protein_seq", "cds_DNA"])
df = table.to_pandas()
if df.empty:
continue
species = df["Taxon"].astype(str).tolist()
proteins = df["protein_seq"].astype(str).str.upper().tolist()
dnas = df["cds_DNA"].astype(str).str.upper().tolist()
# Generate predictions and metrics in streaming mini-batches to keep
# memory stable and update progress frequently
N = len(species)
for off in range(0, N, batch_size):
if rows_remaining is not None and rows_remaining <= 0:
stop_split = True
break
sp_b = species[off: off + batch_size]
pr_b = proteins[off: off + batch_size]
dn_b = dnas[off: off + batch_size]
gen_list, codon_accs, aa_accs = generate_and_score_batched(
sampler,
sp_b,
pr_b,
dn_b,
temperature=temperature,
top_k=top_k,
top_p=top_p,
control_mode=control_mode,
batch_size=batch_size,
enforce_translation=enforce_translation,
no_truncation=bool(no_truncation),
species_prefix_cap=int(species_prefix_cap),
)
rows_batch: List[dict] = []
for sp, pr, dn, gen, cacc, aacc in zip(sp_b, pr_b, dn_b, gen_list, codon_accs, aa_accs):
L = min(len(pr), len(dn) // 3)
tgt_dna = dn[: 3 * L]
rows_batch.append({
"split": split,
"organism": sp,
"protein_seq": pr,
"codon_seq": tgt_dna,
"predicted_seq": gen,
"codon_similarity": float(cacc),
"amino_acid_recovery_rate": float(aacc),
})
if rows_batch:
if rows_remaining is not None and len(rows_batch) > rows_remaining:
rows_batch = rows_batch[: rows_remaining]
out_exists = _P(out_csv).exists() and _os.path.getsize(out_csv) > 0
df_out = pd.DataFrame(rows_batch)
_P(out_csv).parent.mkdir(parents=True, exist_ok=True)
df_out.to_csv(out_csv, mode='a', header=not out_exists, index=False)
total_written += len(rows_batch)
if rows_remaining is not None:
rows_remaining -= len(rows_batch)
if rows_pbar is not None:
try:
rows_pbar.update(len(rows_batch))
except Exception:
pass
if rows_remaining is not None and rows_remaining <= 0:
stop_split = True
break
if rows_pbar is not None:
try:
rows_pbar.close()
except Exception:
pass
if stop_split:
break
logging.info(f"Per-sequence export complete → {out_csv} (rows={total_written})")
def _load_eval_subset(args, paths: List[str]) -> pd.DataFrame:
    """Load the random evaluation subset from Parquet files or a single CSV/TSV file.

    If every matched path is Parquet, rows are drawn via
    ``_load_random_samples_from_parquet(paths, num_samples, seed)``. Otherwise the
    first path ending in .csv/.tsv (optionally .gz) is read in full (tab-separated
    for .tsv/.tsv.gz) and ``args.num_samples`` rows are sampled without replacement
    with the global ``random`` module. When ``args.num_samples`` exceeds the row
    count it is clamped in place.

    Args:
        args: Parsed CLI namespace; reads ``num_samples`` and ``seed``.
        paths: Non-empty list of expanded input paths.

    Returns:
        DataFrame of sampled rows; downstream code reads the ``Taxon``,
        ``protein_seq`` and ``cds_DNA`` columns.

    Raises:
        ValueError: no CSV/TSV among non-Parquet inputs, missing required CSV
            columns, or zero rows loaded.
    """
    if all(_is_parquet_path(p) for p in paths):
        logger.info(f"Reading up to {args.num_samples} samples from {len(paths)} parquet files ...")
        df_s = _load_random_samples_from_parquet(paths, int(args.num_samples), int(args.seed))
    else:
        # Back-compat single-file behavior: if several files match, use the first CSV/TSV.
        csv_file = next(
            (p for p in paths if p.lower().endswith((".csv", ".tsv", ".csv.gz", ".tsv.gz"))),
            None,
        )
        if csv_file is None:
            raise ValueError(f"Unsupported input for --data_path: {paths[0]}")
        logger.info(f"Reading CSV file: {csv_file}")
        # pandas defaults to comma separation; TSV files need an explicit tab
        # separator (gzip compression is inferred from the ".gz" extension).
        sep = "\t" if csv_file.lower().endswith((".tsv", ".tsv.gz")) else ","
        df = pd.read_csv(csv_file, sep=sep)
        required = {"Taxon", "protein_seq", "cds_DNA"}
        missing = required - set(df.columns)
        if missing:
            raise ValueError(f"CSV missing required columns: {sorted(missing)}")
        if args.num_samples > len(df):
            logger.warning(f"num_samples {args.num_samples} > CSV rows {len(df)}; reducing")
            args.num_samples = len(df)
        # Random sample without replacement
        indices = random.sample(range(len(df)), args.num_samples)
        df_s = df.iloc[indices].reset_index(drop=True)
    if len(df_s) == 0:
        raise ValueError("No samples loaded from the provided data_path")
    return df_s


def _teacher_forced_codon_acc_batch(sampler, species_store, sp_b, pr_b, dn_b) -> List[float]:
    """Compute per-sample teacher-forced codon accuracy for one mini-batch.

    Each CDS is trimmed to ``max(min(len(dna) // 3, len(protein)), 1)`` codons,
    tokenized with ``encode_codon_seq``, suffixed with EOS and right-padded.
    Pad and EOS labels are ignored (-100). A position counts as supervised when
    its label is not ignored, it lies inside the model-reported ``per_cap``
    window, and (if the tokenizer exposes ``num_special_tokens``) its label is
    not a special token. Accuracy is the fraction of supervised positions where
    ``argmax(logits)`` equals the label; rows with no supervised positions get 0.0.

    Returns:
        One float per input row. If the model output lacks ``logits`` or
        ``per_cap``, every entry is NaN so callers stay row-aligned with their
        CE / AA-accuracy lists.
    """
    tok = sampler.tokenizer
    pad_id = tok.pad_token_id
    eos_id = tok.eos_token_id

    codon_ids_local = []
    for dna, prot in zip(dn_b, pr_b):
        n_codons = max(min(len(dna) // 3, len(prot)), 1)
        ids = tok.encode_codon_seq(dna[: 3 * n_codons], validate=False)
        ids.append(eos_id)
        codon_ids_local.append(ids)

    n_rows = len(codon_ids_local)
    max_len = max(len(x) for x in codon_ids_local)
    codons_b = torch.full((n_rows, max_len), pad_id, dtype=torch.long)
    for j, ids in enumerate(codon_ids_local):
        codons_b[j, : len(ids)] = torch.tensor(ids, dtype=torch.long)

    input_ids_b = codons_b[:, :-1].to(sampler.device)
    labels_b = codons_b[:, :-1].clone()
    labels_b[labels_b == pad_id] = -100
    labels_b[labels_b == eos_id] = -100

    cond_b = {"control_mode": "fixed"}
    if species_store is not None and sp_b:
        sids_b = [species_store.vocab.get(s, -1) for s in sp_b]
        res_b = species_store.batch_get(sids_b)
        # batch_get returns either a tuple (token-level embeddings first) or a
        # single fixed-size embedding tensor, depending on the store's pooling.
        if isinstance(res_b, tuple):
            sp_tok_b, _ = res_b
            cond_b["species_tok_emb_src"] = sp_tok_b.to(sampler.device)
            cond_b["species_tok_emb_tgt"] = sp_tok_b.to(sampler.device)
        else:
            cond_b["species_emb_src"] = res_b.to(sampler.device)
            cond_b["species_emb_tgt"] = res_b.to(sampler.device)
    cond_b["protein_seqs"] = pr_b

    # Evaluation only: skip autograd bookkeeping to keep peak memory down.
    with torch.no_grad():
        out_b = sampler.model(
            codon_ids=input_ids_b,
            cond=cond_b,
            labels=labels_b.to(sampler.device),
            return_dict=True,
        )
    logits_b = out_b["logits"]
    per_cap_b = out_b.get("per_cap")
    if logits_b is None or per_cap_b is None:
        return [float("nan")] * n_rows

    dev = logits_b.device
    n_logit_rows, n_logit_cols = logits_b.size(0), logits_b.size(1)
    # Align labels to the logits' column count (pad with -100 or truncate).
    labels_aligned_b = torch.full((n_logit_rows, n_logit_cols), -100, dtype=labels_b.dtype, device=dev)
    common_cols = min(labels_b.size(1), n_logit_cols)
    if common_cols > 0:
        labels_aligned_b[:, :common_cols] = labels_b.to(dev)[:, :common_cols]
    cap_mask_b = torch.arange(n_logit_cols, device=dev).unsqueeze(0) < per_cap_b.to(device=dev).unsqueeze(1)

    supervised_b = (labels_aligned_b != -100) & cap_mask_b
    num_special = int(getattr(tok, "num_special_tokens", 0) or 0)
    if num_special > 0:
        supervised_b = supervised_b & (labels_aligned_b >= num_special)

    preds_b = logits_b.argmax(dim=-1)
    hits = ((preds_b == labels_aligned_b) & supervised_b).sum(dim=1).tolist()
    denoms = supervised_b.sum(dim=1).tolist()
    return [h / d if d > 0 else 0.0 for h, d in zip(hits, denoms)]


def main():
    """CLI entry point for CodonTranslator evaluation.

    Builds the sampler (optionally with a species embedding store whose pooling
    is taken from the checkpoint, or from the embeddings dir when they differ),
    then runs exactly one mode:
      * ``--export_per_sequence``: per-sequence CSV export over parquet splits;
      * ``--eval_all`` (teacher-forced only): streaming dataset-level metrics;
      * default: teacher-forced CE / codon-acc / AA-acc on a random subset;
      * ``--free_run``: sampled sequences scored against the ground-truth CDS.
    """
    import math

    args = parse_args()
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    model_dir = Path(args.model_path)
    pooling = _preferred_pooling(model_dir)
    logger.info(f"Preferred species_pooling from checkpoint: {pooling}")

    # Set up species store (recommended for parity)
    species_store = None
    if args.embeddings_dir:
        emb_dir = Path(args.embeddings_dir)
        detected = _detect_pooling_from_embeddings_dir(emb_dir)
        if detected is not None and detected != pooling:
            logger.info(f"Overriding pooling from checkpoint ({pooling}) → embeddings_dir format ({detected})")
            pooling = detected
        species_store = SpeciesEmbeddingStore(args.embeddings_dir, pooling=pooling)
        logger.info(f"Loaded species store with {len(species_store.vocab)} species (pooling={pooling})")

    # Load sampler/model (uses same construction as sampling)
    sampler = CodonSampler(
        model_path=args.model_path,
        device=("cuda" if args.device == "cuda" and torch.cuda.is_available() else "cpu"),
        species_store=species_store,
    )

    if bool(args.export_per_sequence):
        export_per_sequence_over_splits(
            sampler,
            splits=list(args.export_splits),
            splits_root=str(args.splits_root),
            out_csv=str(args.out_csv),
            batch_size=int(args.batch_size),
            temperature=float(args.temperature),
            top_k=int(args.top_k),
            top_p=float(args.top_p),
            control_mode=str(args.control_mode),
            enforce_translation=bool(args.enforce_translation),
            progress=bool(args.progress),
            max_rows_per_split=int(args.max_rows_per_split),
            no_truncation=bool(args.no_truncation),
            species_prefix_cap=int(args.species_prefix_cap),
        )
        return

    data_path = args.data_path or args.csv_path
    if data_path is None:
        raise SystemExit("Please provide --data_path (CSV or Parquet glob/dir). --csv_path remains as a deprecated alias.")
    # Expand paths to decide CSV vs Parquet
    paths = _expand_paths(data_path)
    if not paths:
        raise FileNotFoundError(f"No input files matched: {data_path}")

    # --eval_all streams the whole dataset itself; dispatch before loading a
    # random subset that would otherwise be read and then discarded.
    if not args.free_run and bool(args.eval_all):
        if not args.embeddings_dir:
            raise SystemExit("--eval_all requires --embeddings_dir for species vocab/embeddings")
        # Stream the entire dataset and compute dataset-level metrics (training-parity)
        eval_streaming_all(
            sampler,
            species_store if species_store is not None else SpeciesEmbeddingStore(args.embeddings_dir, pooling=pooling),
            data_path,
            batch_size=int(args.batch_size),
            num_workers=int(args.workers),
            max_records=int(args.max_records),
        )
        return

    df_s = _load_eval_subset(args, paths)
    logger.info(f"Loaded {len(df_s)} samples for evaluation")
    species = df_s["Taxon"].astype(str).tolist()
    proteins = df_s["protein_seq"].astype(str).str.upper().tolist()
    dnas = df_s["cds_DNA"].astype(str).str.upper().tolist()

    if not args.free_run:
        # Optional: print per-sample CDS→AA agreement (standard code)
        if bool(args.debug_aa_check):
            for idx, (sp, pr, dn) in enumerate(zip(species, proteins, dnas), start=1):
                ratio, Lcmp, first_bad = _aa_agreement(dn, pr, sampler.tokenizer)
                flag = "OK" if ratio == 1.0 and Lcmp > 0 else ("EMPTY" if Lcmp == 0 else "MISMATCH")
                extra = f" first_mismatch={first_bad}" if first_bad >= 0 else ""
                logger.info(f"AA-CHECK Sample {idx:02d}: {flag} match={ratio:.3f} len={Lcmp}{extra} Taxon={sp}")

        # Teacher-forced evaluation (random subset)
        per_ce_all: List[float] = []
        per_aa_acc_all: List[float] = []
        per_codon_acc_all: List[float] = []
        bs = max(1, int(args.batch_size))
        for i in range(0, len(species), bs):
            sp_b = species[i:i + bs]
            pr_b = proteins[i:i + bs]
            dn_b = dnas[i:i + bs]
            ce, aa_acc = eval_batch(sampler, species_store, sp_b, pr_b, dn_b)
            per_ce_all.extend(ce)
            per_aa_acc_all.extend(aa_acc)
            # Separate forward with the same masking rules; returns one value per
            # row (NaN when unavailable) so the three lists stay aligned.
            per_codon_acc_all.extend(_teacher_forced_codon_acc_batch(sampler, species_store, sp_b, pr_b, dn_b))

        for idx, (ce, aa, ca) in enumerate(zip(per_ce_all, per_aa_acc_all, per_codon_acc_all), start=1):
            logger.info(f"Sample {idx:02d}: CE={ce:.4f} CODON-acc={ca:.4f} AA-acc={aa:.4f}")
        if per_ce_all:
            mean_ce = sum(per_ce_all) / len(per_ce_all)
            mean_aa = sum(per_aa_acc_all) / len(per_aa_acc_all) if per_aa_acc_all else 0.0
            valid_codon = [c for c in per_codon_acc_all if not math.isnan(c)]
            mean_codon = sum(valid_codon) / len(valid_codon) if valid_codon else 0.0
            logger.info(f"Summary over {len(per_ce_all)} samples → mean CE={mean_ce:.4f}, mean CODON-acc={mean_codon:.4f}, mean AA-acc={mean_aa:.4f}")
    else:
        # Free-run sampling evaluation vs ground-truth DNA (codon-level), batched
        codon_accs, aa_accs = sample_and_score_batched(
            sampler,
            species,
            proteins,
            dnas,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            control_mode=args.control_mode,
            batch_size=int(args.batch_size),
            enforce_translation=bool(args.enforce_translation),
            no_truncation=bool(args.no_truncation),
            species_prefix_cap=int(args.species_prefix_cap),
        )
        for idx, (cacc, aacc) in enumerate(zip(codon_accs, aa_accs), start=1):
            logger.info(f"Sample {idx:02d}: CODON-acc={cacc:.4f} AA-acc={aacc:.4f}")
        if codon_accs:
            mean_c = sum(codon_accs) / len(codon_accs)
            mean_a = sum(aa_accs) / len(aa_accs) if aa_accs else 0.0
            logger.info(f"Summary over {len(codon_accs)} samples → mean CODON-acc={mean_c:.4f}, mean AA-acc={mean_a:.4f}")


if __name__ == "__main__":
    main()