#!/usr/bin/env python
"""
Teacher-forced (and optional free-run) evaluation on a random subset of your dataset
to measure codon token cross-entropy and AA token accuracy, using the same
conditioning pathway as training.

Supports either a CSV file or Parquet input via a directory/glob (e.g., ./data/val/*.parquet).

Usage examples:

  # CSV input
  python eval.py \
    --model_path outputs/checkpoint-21000 \
    --data_path random_sample_1000.csv \
    --embeddings_dir embeddings \
    --num_samples 10 \
    --batch_size 10 \
    --device cuda

  # Parquet glob input
  python eval.py \
    --model_path outputs/checkpoint-21000 \
    --data_path "./data/val/*.parquet" \
    --embeddings_dir embeddings \
    --num_samples 64 \
    --batch_size 32 \
    --device cuda
"""

import argparse
import json
import logging
import random
from pathlib import Path
from typing import List, Optional, Tuple
import glob

import torch
import torch.nn.functional as F
import pandas as pd

from src.sampler import CodonSampler
from src.dataset import SpeciesEmbeddingStore, StreamSeqDataset, stage_collate_fn
from torch.utils.data import DataLoader

# basicConfig at import time is acceptable here because this module is a CLI entry point.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger("eval_tf")


def parse_args() -> argparse.Namespace:
    """Build and parse the CLI.

    Argument groups: input data (CSV file or Parquet glob/dir), free-run
    sampling controls, full-dataset streaming evaluation, per-sequence
    export, and capacity controls.
    """
    p = argparse.ArgumentParser("Teacher-forced evaluation of CodonTranslator")
    p.add_argument("--model_path", required=True, type=str,
                   help="Path to checkpoint dir (with config.json / model.safetensors)")
    # Input data: CSV file or Parquet glob/dir
    p.add_argument("--data_path", required=False, type=str, default=None,
                   help="CSV file or Parquet glob/dir (e.g., ./data/val/*.parquet)")
    # Back-compat: --csv_path still accepted (deprecated)
    p.add_argument("--csv_path", required=False, type=str, default=None,
                   help="[Deprecated] CSV with columns: Taxon, protein_seq, cds_DNA")
    p.add_argument("--embeddings_dir", type=str, default=None,
                   help="Species embeddings directory (recommended for parity)")
    p.add_argument("--num_samples", type=int, default=10)
    p.add_argument("--batch_size", type=int, default=10)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--device", type=str, default="cuda")
    p.add_argument("--workers", type=int, default=0,
                   help="DataLoader workers for --eval_all streaming mode")
    # Free-run (sampling) evaluation options
    p.add_argument("--free_run", action="store_true",
                   help="If set, perform real sampling instead of teacher forcing and compare to ground-truth codon sequences")
    p.add_argument("--temperature", type=float, default=0.8)
    p.add_argument("--top_k", type=int, default=50)
    p.add_argument("--top_p", type=float, default=0.9)
    p.add_argument("--control_mode", type=str, choices=["fixed","variable"], default="fixed")
    p.add_argument("--enforce_translation", action="store_true",
                   help="Hard-mask decoding to codons matching target amino acid at each position during free-run evaluation")
    # Full-dataset streaming eval (no sampling)
    p.add_argument("--eval_all", action="store_true",
                   help="Stream over all rows from --data_path and compute aggregated metrics (memory-safe)")
    p.add_argument("--max_records", type=int, default=0,
                   help="When --eval_all is set: limit to first N samples (0 = all)")
    p.add_argument("--debug_aa_check", action="store_true",
                   help="Print per-sample agreement between CDS→AA (standard code) and provided protein")
    # Per-sequence export over standard splits ./data/val and ./data/test
    p.add_argument("--export_per_sequence", action="store_true",
                   help="Process ./data/val and ./data/test parquets in batches and export a per-sequence CSV")
    p.add_argument("--splits_root", type=str, default="./data",
                   help="Root directory that contains val/ and test/ subfolders with parquet files")
    p.add_argument("--out_csv", type=str, default="outputs/eval_per_sequence.csv",
                   help="Output CSV path for per-sequence export")
    p.add_argument("--export_splits", nargs="+", default=["val", "test"],
                   help="Subdirectories under --splits_root to process (default: val test)")
    p.add_argument("--max_rows_per_split", type=int, default=0,
                   help="When --export_per_sequence is set: limit number of rows per split (0 = all)")
    p.add_argument("--progress", action="store_true",
                   help="Show progress bars during per-sequence export")
    # Capacity and evaluation controls
    p.add_argument("--no_truncation", action="store_true",
                   help="Fit prefix caps so generated codon length equals protein length (avoids capacity truncation)")
    p.add_argument("--species_prefix_cap", type=int, default=0,
                   help="When >0 and --no_truncation is set, cap species token prefix to this many tokens; 0 = no species cap")
    return p.parse_args()


def _is_parquet_path(p: str) -> bool:
    # Cheap extension-only check; file contents are not inspected.
    lower = p.lower()
    return lower.endswith(".parquet") or lower.endswith(".parq")


def _expand_paths(maybe_path_or_glob: str) -> List[str]:
    """Expand a path/glob or directory into a sorted list of files.

    Prioritize Parquet when scanning a directory.
    """
    paths: List[str] = []
    P = Path(maybe_path_or_glob)
    if P.is_dir():
        # Recursive directory scan: Parquet variants first, then CSV/TSV (plain or gzipped).
        paths.extend(sorted(str(x) for x in P.rglob("*.parquet")))
        paths.extend(sorted(str(x) for x in P.rglob("*.parq")))
        paths.extend(sorted(str(x) for x in P.rglob("*.csv")))
        paths.extend(sorted(str(x) for x in P.rglob("*.tsv")))
        paths.extend(sorted(str(x) for x in P.rglob("*.csv.gz")))
        paths.extend(sorted(str(x) for x in P.rglob("*.tsv.gz")))
    else:
        # Otherwise treat the argument as a literal path or a glob pattern.
        paths = sorted(glob.glob(str(P)))
    # Dedup while preserving order
    out: List[str] = []
    seen = set()
    for x in paths:
        if x not in seen:
            out.append(x)
            seen.add(x)
    return out


def _load_random_samples_from_parquet(files: List[str], num_samples: int, seed: int) -> pd.DataFrame:
    """Collect up to num_samples rows from a list of Parquet files, reading by row group.

    Reads only the required columns and shuffles files/row-groups for decent coverage.
    """
    try:
        import pyarrow.parquet as pq  # type: ignore
    except Exception as e:  # pragma: no cover
        raise ImportError("pyarrow is required to read parquet files") from e
    # Dedicated RNG so row sampling is reproducible for a given seed.
    rng = random.Random(seed)
    req = ["Taxon", "protein_seq", "cds_DNA"]
    files = [f for f in files if _is_parquet_path(f)]
    if not files:
        raise FileNotFoundError("No Parquet files found to read")
    files = files.copy()
    rng.shuffle(files)
    collected: List[pd.DataFrame] = []
    remaining = int(max(0, num_samples))
    for fp in files:
        if remaining <= 0:
            break
        pf = pq.ParquetFile(fp)
        nrg = int(pf.num_row_groups or 0)
        if nrg <= 0:
            rgs = [0]
        else:
            rgs = list(range(nrg))
            rng.shuffle(rgs)  # visit row groups in random order for coverage
        # Only keep columns that exist in this file
        cols = [c for c in req if c in pf.schema.names]
        if len(cols) < len(req):
            missing = sorted(set(req) - set(cols))
            raise ValueError(f"Parquet missing required columns {missing} in {fp}")
        for rg in rgs:
            if remaining <= 0:
                break
            table = pf.read_row_group(rg, columns=cols)
            df = table.to_pandas(types_mapper=None)
            if df.empty:
                continue
            if len(df) > remaining:
                # Down-sample this row group so we never exceed the request.
                df = df.sample(n=remaining, random_state=rng.randint(0, 2**31 - 1))
            collected.append(df)
            remaining -= len(df)
    if not collected:
        return pd.DataFrame(columns=req)
    out = pd.concat(collected, ignore_index=True)
    # Final shuffle for randomness
    out = out.sample(frac=1.0, random_state=seed).reset_index(drop=True)
    # If we somehow overshot, trim
    if len(out) > num_samples:
        out = out.iloc[:num_samples].reset_index(drop=True)
    return out


def _preferred_pooling(model_dir: Path) -> str:
    """
    Best-effort pooling detection:
    - First try checkpoint configs for an explicit hint
    - Fallback to 'last'
    Note: we'll further override this using the embeddings_dir contents if provided.
    """
    for cfg_name in ("trainer_config.json", "config.json"):
        fp = model_dir / cfg_name
        if fp.exists():
            try:
                with open(fp) as f:
                    cfg = json.load(f)
                return str(cfg.get("species_pooling", "last"))
            except Exception:
                # Unreadable/invalid config: fall through to the next candidate file.
                continue
    return "last"


def _detect_pooling_from_embeddings_dir(emb_dir: Path) -> Optional[str]:
    """Detect actual available pooling format from embeddings_dir contents."""
    # The two on-disk layouts are identified by their file triplets:
    # fixed per-species vectors ("last") vs per-token sequences ("sequence").
    fixed_files = [emb_dir / "species_embeddings.bin", emb_dir / "species_metadata.json", emb_dir / "species_vocab.json"]
    seq_files = [emb_dir / "species_tok_emb.bin", emb_dir / "species_index.json", emb_dir / "species_vocab.json"]
    if all(p.exists() for p in fixed_files):
        return "last"
    if all(p.exists() for p in seq_files):
        return "sequence"
    return None


@torch.no_grad()
def eval_batch(
    sampler: CodonSampler,
    species_store: Optional[SpeciesEmbeddingStore],
    species_names: List[str],
    protein_seqs: List[str],
    dna_cds_list: List[str],
) -> Tuple[List[float], List[float]]:
    """Evaluate a batch in teacher-forced mode.

    Returns per-sample (avg_ce_loss, aa_token_acc).
    """
    tok = sampler.tokenizer
    pad_id = tok.pad_token_id
    eos_id = tok.eos_token_id
    # Encode DNA to codon ids and align lengths (trim to min protein length)
    codon_ids = []
    seq_lens = []
    for dna, prot in zip(dna_cds_list, protein_seqs):
        # Trim to min length between DNA codons and protein AA
        C_dna = len(dna) // 3
        C_prot = len(prot)
        C = max(min(C_dna, C_prot), 1)
        dna_trim = dna[: 3 * C]
        ids = tok.encode_codon_seq(dna_trim, validate=False)
        ids.append(eos_id)
        codon_ids.append(ids)
        seq_lens.append(len(ids))
    B = len(codon_ids)
    T = max(seq_lens)
    codons = torch.full((B, T), pad_id, dtype=torch.long)
    # NOTE(review): 'mask' is built here but never consumed later in this function.
    mask = torch.zeros((B, T), dtype=torch.bool)
    for i, ids in enumerate(codon_ids):
        L = len(ids)
        codons[i, :L] = torch.tensor(ids, dtype=torch.long)
        mask[i, :L] = True
    # inputs/labels aligned to training convention:
    # model predicts next codon after a learned start token; labels are the
    # same positions as inputs (not shifted by 1), with PAD/EOS masked out.
    input_ids = codons[:, :-1]
    labels_base = codons[:, :-1].clone()
    # Mask out PAD and EOS like trainer.evaluate()
    labels_base[labels_base == pad_id] = -100
    labels_base[labels_base == eos_id] = -100
    # Build conditioning dict similar to training and sampler
    cond = {"control_mode": "fixed"}
    if species_store is not None and species_names:
        # Look up precomputed species embeddings; unknown species map to id -1.
        sid_list = [species_store.vocab.get(s, -1) for s in species_names]
        num_unknown = sum(1 for x in sid_list if x < 0)
        if num_unknown > 0:
            logger.warning(f"{num_unknown}/{len(sid_list)} species not found in embeddings vocab; using zero embeddings")
        result = species_store.batch_get(sid_list)
        if isinstance(result, tuple):
            sp_tok, _ = result  # [B, Ls, Ds]
            cond["species_tok_emb_src"] = sp_tok.to(sampler.device)
            cond["species_tok_emb_tgt"] = sp_tok.to(sampler.device)
        else:
            sp = result  # [B, Ds]
            cond["species_emb_src"] = sp.to(sampler.device)
            cond["species_emb_tgt"] = sp.to(sampler.device)
    elif species_names:
        # On-the-fly species embeddings using Qwen (sequence pooling for training parity)
        seq_emb, _lens = sampler._qwen_embed_names(species_names, pooling="sequence")
        seq_emb = seq_emb.to(sampler.device)
        cond["species_tok_emb_src"] = seq_emb
        cond["species_tok_emb_tgt"] = seq_emb
    # Match training: pass raw protein sequences; the model tokenizes internally
    cond["protein_seqs"] = protein_seqs
    # Move tensors to device
    device = sampler.device
    input_ids = input_ids.to(device)
    labels_base = labels_base.to(device)
    sampler.model.eval()
    outputs = sampler.model(codon_ids=input_ids, cond=cond, labels=labels_base, return_dict=True)
    logits = outputs["logits"]  # [B, Lmax, V] aligned to per-sample capacity after prefix
    try:
        # Debug-only: report the (max) conditioning prefix length alongside input length.
        prefix_len = outputs.get("prefix_len", 0)
        if isinstance(prefix_len, torch.Tensor):
            prefix_len_dbg = int(prefix_len.max().item()) if prefix_len.numel() > 0 else 0
        else:
            prefix_len_dbg = int(prefix_len)
        logger.debug(f"Prefix length(max)={prefix_len_dbg}, input_len={input_ids.size(1)}")
    except Exception:
        pass
    # Align labels/masks to logits length and per-sample caps
    Bsz, Lmax, V = logits.size(0), logits.size(1), logits.size(2)
    labels_aligned = torch.full((Bsz, Lmax), -100, dtype=labels_base.dtype, device=logits.device)
    common_cols = min(labels_base.size(1), Lmax)
    if common_cols > 0:
        labels_aligned[:, :common_cols] = labels_base[:, :common_cols]
    per_cap = outputs.get("per_cap", None)
    if isinstance(per_cap, torch.Tensor) and per_cap.numel() == Bsz:
        ar = torch.arange(Lmax, device=logits.device).unsqueeze(0)
        cap_mask = ar < per_cap.to(device=logits.device).unsqueeze(1)  # [B,Lmax]
    else:
        # No per-sample caps reported: supervise every aligned position.
        cap_mask = torch.ones_like(labels_aligned, dtype=torch.bool, device=logits.device)
    # Mask labels beyond per-cap to -100 so CE ignores them
    labels_masked = labels_aligned.clone().to(device=logits.device)
    labels_masked[~cap_mask] = -100
    # Cross-entropy per sample (include EOS target; ignore PAD)
    loss_flat = F.cross_entropy(
        logits.reshape(-1, V),
        labels_masked.reshape(-1),
        ignore_index=-100,
        reduction="none",
    ).view(Bsz, Lmax)
    # Accuracy per sample
    preds = logits.argmax(dim=-1)
    num_special = int(getattr(tok, "num_special_tokens", 0) or 0)
    supervised = (labels_masked != -100) & cap_mask
    if num_special > 0:
        # Only score positions whose target is a real codon token (skip specials).
        supervised = supervised & (labels_aligned >= num_special)
    correct = (preds == labels_aligned) & supervised
    per_sample_ce: List[float] = []
    # NOTE(review): 'per_sample_acc' is never appended to or returned — the
    # codon-level accuracy 'acc' computed in the loop below is currently dropped;
    # only CE and AA accuracy are returned, consistent with the docstring.
    per_sample_acc: List[float] = []
    per_sample_aa_acc: List[float] = []
    codon2aa = tok.codon2aa_char_map() if hasattr(tok, "codon2aa_char_map") else {}
    per_cap = outputs.get("per_cap", None)
    per_cap_int = None
    if isinstance(per_cap, torch.Tensor) and per_cap.numel() == Bsz:
        per_cap_int = torch.clamp(per_cap.to(dtype=torch.long, device=logits.device), min=0, max=Lmax)
    for i in range(B):
        # Average CE over valid positions
        valid = (labels_masked[i] != -100) & cap_mask[i]
        if num_special > 0:
            valid = valid & (labels_aligned[i] >= num_special)
        ce = (loss_flat[i][valid].mean().item() if valid.any() else 0.0)
        per_sample_ce.append(ce)
        # Codon-level accuracy over supervised positions
        denom = supervised[i].sum().item()
        acc = (correct[i].sum().item() / denom) if denom > 0 else 0.0
        # AA-level accuracy per sample (match trainer)
        aa_acc = 0.0
        if per_cap_int is not None and codon2aa and i < len(protein_seqs):
            cap = int(per_cap_int[i].item())
            if cap > 0:
                mask_row = supervised[i, :cap]
                if mask_row.any():
                    preds_row = preds[i, :cap][mask_row]
                    prot = protein_seqs[i]
                    seq_len = min(len(prot), preds_row.size(0))
                    if seq_len > 0:
                        # Translate predicted codons to AA characters and compare to the target protein.
                        pred_aa = ''.join(codon2aa.get(int(t.item()), 'X') for t in preds_row[:seq_len])
                        truth_aa = prot[:seq_len]
                        aa_matches = sum(1 for j in range(seq_len) if pred_aa[j] == truth_aa[j])
                        aa_acc = aa_matches / seq_len
        per_sample_aa_acc.append(aa_acc)
    return per_sample_ce, per_sample_aa_acc


def _dna_to_codons(dna: str) -> List[str]:
    # Split DNA into complete 3-mers; any trailing partial codon is dropped.
    dna = dna.strip().upper()
    return [dna[i:i+3] for i in range(0, len(dna) - (len(dna) % 3), 3)]


def _aa_from_dna_standard(dna: str, tok) -> str:
    # Translate DNA to an AA string via the tokenizer's genetic-code table; 'X' for unknown codons.
    dna = dna.strip().upper()
    gc = getattr(tok, "_genetic_code", {})
    aa = []
    for j in range(0, len(dna) - (len(dna) % 3), 3):
        aa.append(gc.get(dna[j:j+3], 'X'))
    return ''.join(aa)


def _aa_agreement(dna: str, protein: str, tok) -> Tuple[float, int, int]:
    """Return (match_ratio, compared_len, first_mismatch_idx or -1) under standard code."""
    dna = dna.strip().upper()
    protein = protein.strip().upper()
    L = min(len(dna) // 3, len(protein))
    if L <= 0:
        return 0.0, 0, -1
    aa_pred = _aa_from_dna_standard(dna[: 3 * L], tok)
    truth = protein[:L]
    mism_idx = -1
    matches = 0
    for i, (a, b) in enumerate(zip(aa_pred, truth)):
        if a == b:
            matches += 1
        elif mism_idx < 0:
            # Record only the first mismatching position.
            mism_idx = i
    return (matches / L), L, mism_idx


@torch.no_grad()
def eval_streaming_all(
    sampler: CodonSampler,
    species_store: SpeciesEmbeddingStore,
    data_path: str,
    batch_size: int,
    num_workers: int,
    max_records: int = 0,
):
    """Stream over all rows from CSV/Parquet inputs and compute dataset-level metrics.

    Mirrors trainer.evaluate() for parity.
    """
    device = sampler.device
    tok = sampler.tokenizer
    pad_id = int(tok.pad_token_id)
    eos_id = int(tok.eos_token_id)
    num_special = int(tok.num_special_tokens)
    codon2aa = tok.codon2aa_char_map()
    # Build streaming dataset and loader
    from pathlib import Path as _Path
    import glob as _glob

    def _expand(pat: str) -> List[str]:
        # Local duplicate of module-level _expand_paths: dir scan (Parquet first) or glob.
        P = _Path(pat)
        if P.is_dir():
            paths: List[str] = []
            paths.extend(sorted(str(x) for x in P.rglob("*.parquet")))
            paths.extend(sorted(str(x) for x in P.rglob("*.parq")))
            paths.extend(sorted(str(x) for x in P.rglob("*.csv")))
            paths.extend(sorted(str(x) for x in P.rglob("*.tsv")))
            paths.extend(sorted(str(x) for x in P.rglob("*.csv.gz")))
            paths.extend(sorted(str(x) for x in P.rglob("*.tsv.gz")))
        else:
            paths = sorted(_glob.glob(str(P)))
        # de-dup
        seen = set(); out = []
        for x in paths:
            if x not in seen:
                out.append(x); seen.add(x)
        return out

    paths = _expand(data_path)
    if not paths:
        raise FileNotFoundError(f"No input files matched: {data_path}")
    species_vocab_path = str((Path(species_store.embeddings_dir) / "species_vocab.json").resolve())
    ds = StreamSeqDataset(
        files=paths,
        tokenizer=tok,
        species_vocab_path=species_vocab_path,
        unknown_species_id=0,
        csv_chunksize=200_000,
        shuffle_buffer=0,            # deterministic pass: no shuffling
        shard_across_ranks=False,    # single-process evaluation
    )
    _dl_kwargs = dict(
        batch_size=int(batch_size),
        shuffle=False,
        drop_last=False,
        num_workers=int(max(0, num_workers)),
        collate_fn=stage_collate_fn,
        pin_memory=True,
        persistent_workers=(int(num_workers) > 0),
    )
    if int(num_workers) > 0:
        _dl_kwargs["prefetch_factor"] = 4
    loader = DataLoader(ds, **_dl_kwargs)
    # Running aggregates over the whole stream.
    loss_sum = 0.0
    loss_tokens = 0
    codon_correct = 0
    codon_total = 0
    aa_correct = 0
    aa_total = 0
    seen = 0
    for batch in loader:
        if not batch:
            continue
        if int(max_records) > 0 and seen >= int(max_records):
            break
        codon_ids = batch["codon_ids"].to(device)
        input_ids = codon_ids[:, :-1]
        labels = codon_ids[:, :-1].clone()
        labels[labels == pad_id] = -100
        labels[labels == eos_id] = -100
        # Build cond using species_store and protein_seqs
        cond = {"control_mode": "fixed", "protein_seqs": batch.get("protein_seqs", [])}
        sids = batch.get("species_ids")
        if torch.is_tensor(sids):
            sids_list = sids.detach().cpu().tolist()
        else:
            sids_list = [int(x) for x in sids]
        res = species_store.batch_get(sids_list)
        if isinstance(res, tuple):
            # Per-token species embeddings (sequence pooling layout).
            sp_tok, _ = res
            cond["species_tok_emb_src"] = sp_tok.to(device)
            cond["species_tok_emb_tgt"] = sp_tok.to(device)
        else:
            # Fixed per-species vectors (legacy layout).
            cond["species_emb_src"] = res.to(device)
            cond["species_emb_tgt"] = res.to(device)
        out = sampler.model(codon_ids=input_ids, cond=cond, labels=labels, return_dict=True)
        loss = out.get("loss")
        per_cap = out.get("per_cap")
        logits = out.get("logits")
        tokens_in_batch = 0
        if per_cap is not None:
            tokens_in_batch = int(torch.clamp(per_cap.detach(), min=0).sum().item())
            loss_tokens += tokens_in_batch
        if loss is not None and tokens_in_batch > 0:
            # Token-weighted loss accumulation so batches of different size average correctly.
            loss_sum += float(loss.detach().item()) * tokens_in_batch
        if logits is None or logits.size(1) == 0 or per_cap is None:
            seen += input_ids.size(0)
            continue
        max_cap = logits.size(1)
        # NOTE(review): the 'batch_size' parameter is shadowed here; harmless since
        # the DataLoader was already constructed above, but worth renaming eventually.
        batch_size = logits.size(0)
        labels_aligned = torch.full((batch_size, max_cap), -100, dtype=labels.dtype, device=labels.device)
        common = min(labels.size(1), max_cap)
        if common > 0:
            labels_aligned[:, :common] = labels[:, :common]
        per_cap_int = torch.clamp(per_cap.to(dtype=torch.long), min=0, max=max_cap)
        for row in range(batch_size):
            cap = int(per_cap_int[row].item())
            if cap < max_cap:
                # Ignore positions beyond this sample's capacity.
                labels_aligned[row, cap:] = -100
        supervised = labels_aligned != -100
        if num_special > 0:
            supervised = supervised & (labels_aligned >= num_special)
        if not supervised.any():
            seen += batch_size
            continue
        preds = logits.argmax(dim=-1)
        codon_correct += int((preds[supervised] == labels_aligned[supervised]).sum().item())
        codon_total += int(supervised.sum().item())
        # protein list
        prot_list = cond.get("protein_seqs", [])
        for row in range(batch_size):
            cap = int(per_cap_int[row].item())
            if cap <= 0:
                continue
            mask_row = supervised[row, :cap]
            if not mask_row.any():
                continue
            preds_row = preds[row, :cap][mask_row]
            prot = prot_list[row] if (isinstance(prot_list, list) and row < len(prot_list)) else ""
            if not prot:
                continue
            seq_len = min(len(prot), preds_row.size(0))
            if seq_len <= 0:
                continue
            # Translate predicted codons and compare to the provided protein string.
            pred_aa = ''.join(codon2aa.get(int(t.item()), 'X') for t in preds_row[:seq_len])
            truth_aa = prot[:seq_len]
            aa_correct += sum(1 for i in range(seq_len) if pred_aa[i] == truth_aa[i])
            aa_total += seq_len
        seen += batch_size
    mean_ce = (loss_sum / loss_tokens) if loss_tokens > 0 else 0.0
    codon_acc = (float(codon_correct) / codon_total) if codon_total > 0 else 0.0
    aa_acc = (float(aa_correct) / aa_total) if aa_total > 0 else 0.0
    logger.info(
        f"Full-dataset summary → tokens={loss_tokens} CE={mean_ce:.4f} CODON-acc={codon_acc:.4f} AA-acc={aa_acc:.4f}"
    )
    return mean_ce, codon_acc, aa_acc


@torch.no_grad()
def sample_and_score_batched(
    sampler: CodonSampler,
    species_names: List[str],
    protein_seqs: List[str],
    target_dnas: List[str],
    temperature: float,
    top_k: int,
    top_p: float,
    control_mode: str,
    batch_size: int,
    enforce_translation: bool,
    no_truncation: bool = False,
    species_prefix_cap: int = 64,
) -> Tuple[List[float], List[float]]:
    """Free-run sampling in batches; returns per-sample (codon_acc, aa_acc)."""
    N = len(species_names)
    # Compute target lengths in codons (min of DNA and AA lengths)
    tgt_lengths = []
    tgt_codons_list = []
    for prot, dna in zip(protein_seqs, target_dnas):
        cods = _dna_to_codons(dna)
        L = min(len(cods), len(prot))
        if L <= 0:
            L = 1
            cods = ["ATG"]  # harmless default
        tgt_lengths.append(L)
        tgt_codons_list.append(cods[:L])
    # Bucket indices by target length to maximize batching
    buckets: dict[int, List[int]] = {}
    for i, L in enumerate(tgt_lengths):
        buckets.setdefault(L, []).append(i)
    codon_accs = [0.0] * N
    aa_accs = [0.0] * N
    # Helper AA translation
    vocab = sampler.tokenizer._genetic_code

    def dna_to_aa(dna: str) -> str:
        dna = dna.strip().upper()
        aa = []
        for j in range(0, len(dna) - (len(dna) % 3), 3):
            aa.append(vocab.get(dna[j:j+3], 'X'))
        return ''.join(aa)

    for L, idxs in buckets.items():
        # Optionally tighten protein prefix so prefix+start+L ≤ capacity (species kept full unless capped)
        prev_sp = getattr(sampler.model, "max_species_prefix", 0)
        prev_pp = getattr(sampler.model, "max_protein_prefix", 0)
        if bool(no_truncation):
            try:
                capacity = int(getattr(sampler.model, "max_position_embeddings", 1024))
                # If requested, apply a species token cap; otherwise keep as-is
                store = getattr(sampler, "species_store", None)
                if store is not None and getattr(store, "is_legacy", False) and int(species_prefix_cap) > 0:
                    setattr(sampler.model, "max_species_prefix", int(species_prefix_cap))
                # Build a representative cond for this bucket to measure exact prefix length
                batch_idx_probe = idxs[: min(len(idxs), max(1, min(batch_size, 8)))]
                sp_probe = [species_names[i] for i in batch_idx_probe]
                pr_probe = [protein_seqs[i] for i in batch_idx_probe]
                # Map species to ids via store vocab
                cond_probe = {"control_mode": "fixed", "protein_seqs": pr_probe}
                if store is not None:
                    sid_list = [store.vocab.get(s, -1) for s in sp_probe]
                    res = store.batch_get(sid_list)
                    if isinstance(res, tuple):
                        sp_tok, _ = res
                        cond_probe["species_tok_emb_src"] = sp_tok.to(sampler.device)
                        cond_probe["species_tok_emb_tgt"] = sp_tok.to(sampler.device)
                    else:
                        cond_probe["species_emb_src"] = res.to(sampler.device)
                        cond_probe["species_emb_tgt"] = res.to(sampler.device)
                # Iteratively reduce protein prefix cap until remaining ≥ L
                for _ in range(3):
                    # Probe with zero codon ids: the forward pass reports the prefix length.
                    out0 = sampler.model(
                        codon_ids=torch.zeros(len(batch_idx_probe), 0, dtype=torch.long, device=sampler.device),
                        cond=cond_probe,
                        return_dict=True,
                        use_cache=True,
                    )
                    pref = out0.get("prefix_len")
                    if isinstance(pref, torch.Tensor) and pref.numel() > 0:
                        pref_max = int(pref.max().item())
                    else:
                        pref_max = int(pref) if isinstance(pref, int) else 0
                    remaining = capacity - (pref_max + 1)
                    if remaining >= int(L):
                        break
                    need = int(L) - max(0, int(remaining))
                    cur_pp = int(getattr(sampler.model, "max_protein_prefix", 0) or 0)
                    new_pp = max(0, cur_pp - need) if cur_pp > 0 else max(0, pref_max - (capacity - 1 - int(L)))
                    setattr(sampler.model, "max_protein_prefix", int(new_pp))
            except Exception:
                # Best-effort tuning only; sampling proceeds with whatever caps are set.
                pass
        # Process in mini-batches
        for k in range(0, len(idxs), batch_size):
            batch_idx = idxs[k:k+batch_size]
            sp_b = [species_names[i] for i in batch_idx]
            pr_b = [protein_seqs[i] for i in batch_idx]
            # Sample in one call
            out = sampler.sample(
                num_sequences=len(batch_idx),
                sequence_length=L,
                species=sp_b,
                protein_sequences=pr_b,
                control_mode=control_mode,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                return_intermediate=False,
                progress_bar=False,
                enforce_translation=enforce_translation,
            )
            gen_list: List[str] = out["sequences"]  # DNA strings
            # Score each
            for pos, idx in enumerate(batch_idx):
                tgt_codons = tgt_codons_list[idx]
                gen_codons = _dna_to_codons(gen_list[pos])[:L]
                matches = sum(1 for a,b in zip(gen_codons, tgt_codons) if a == b)
                codon_accs[idx] = (matches / L) if L > 0 else 0.0
                gen_aa = dna_to_aa(''.join(gen_codons))
                tgt_aa = protein_seqs[idx][:L]
                # Treat non-canonical AA in target as "match any"
                canonical = set("ACDEFGHIKLMNPQRSTVWY")
                aa_matches = sum(1 for a,b in zip(gen_aa, tgt_aa) if (b not in canonical) or (a == b))
                aa_accs[idx] = (aa_matches / L) if L > 0 else 0.0
        # Restore caps
        if bool(no_truncation):
            try:
                setattr(sampler.model, "max_species_prefix", prev_sp)
                setattr(sampler.model, "max_protein_prefix", prev_pp)
            except Exception:
                pass
    return codon_accs, aa_accs


@torch.no_grad()
def generate_and_score_batched(
    sampler: CodonSampler,
    species_names: List[str],
    protein_seqs: List[str],
    target_dnas: List[str],
    temperature: float,
    top_k: int,
    top_p: float,
    control_mode: str,
    batch_size: int,
    enforce_translation: bool,
    no_truncation: bool = False,
    species_prefix_cap: int = 64,
) -> Tuple[List[str], List[float], List[float]]:
    """Like sample_and_score_batched but also returns generated DNA sequences per sample."""
    N = len(species_names)
    tgt_lengths = []
    tgt_codons_list = []
    for prot, dna in zip(protein_seqs, target_dnas):
        cods = _dna_to_codons(dna)
        L = min(len(cods), len(prot))
        if L <= 0:
            L = 1
            cods = ["ATG"]
        tgt_lengths.append(L)
        tgt_codons_list.append(cods[:L])
    # Bucket by target length so each sampler call shares one sequence_length.
    buckets: dict[int, List[int]] = {}
    for i, L in enumerate(tgt_lengths):
        buckets.setdefault(L, []).append(i)
    gen_all = [""] * N
    codon_accs = [0.0] * N
    aa_accs = [0.0] * N
    vocab = sampler.tokenizer._genetic_code

    def dna_to_aa(dna: str) -> str:
        dna = dna.strip().upper()
        aa = []
        for j in range(0, len(dna) - (len(dna) % 3), 3):
            aa.append(vocab.get(dna[j:j+3], 'X'))
        return ''.join(aa)

    for L, idxs in buckets.items():
        # Same capacity-fitting dance as sample_and_score_batched (see comments there).
        prev_sp = getattr(sampler.model, "max_species_prefix", 0)
        prev_pp = getattr(sampler.model, "max_protein_prefix", 0)
        if bool(no_truncation):
            try:
                capacity = int(getattr(sampler.model, "max_position_embeddings", 1024))
                store = getattr(sampler, "species_store", None)
                if store is not None and getattr(store, "is_legacy", False) and int(species_prefix_cap) > 0:
                    setattr(sampler.model, "max_species_prefix", int(species_prefix_cap))
                batch_idx_probe = idxs[: min(len(idxs), max(1, min(batch_size, 8)))]
                sp_probe = [species_names[i] for i in batch_idx_probe]
                pr_probe = [protein_seqs[i] for i in batch_idx_probe]
                cond_probe = {"control_mode": "fixed", "protein_seqs": pr_probe}
                if store is not None:
                    sid_list = [store.vocab.get(s, -1) for s in sp_probe]
                    res = store.batch_get(sid_list)
                    if isinstance(res, tuple):
                        sp_tok, _ = res
                        cond_probe["species_tok_emb_src"] = sp_tok.to(sampler.device)
                        cond_probe["species_tok_emb_tgt"] = sp_tok.to(sampler.device)
                    else:
                        cond_probe["species_emb_src"] = res.to(sampler.device)
                        cond_probe["species_emb_tgt"] = res.to(sampler.device)
                for _ in range(3):
                    out0 = sampler.model(
                        codon_ids=torch.zeros(len(batch_idx_probe), 0, dtype=torch.long, device=sampler.device),
                        cond=cond_probe,
                        return_dict=True,
                        use_cache=True,
                    )
                    pref = out0.get("prefix_len")
                    pref_max = int(pref.max().item()) if isinstance(pref, torch.Tensor) and pref.numel() > 0 else (int(pref) if isinstance(pref, int) else 0)
                    remaining = capacity - (pref_max + 1)
                    if remaining >= int(L):
                        break
                    need = int(L) - max(0, int(remaining))
                    cur_pp = int(getattr(sampler.model, "max_protein_prefix", 0) or 0)
                    new_pp = max(0, cur_pp - need) if cur_pp > 0 else max(0, pref_max - (capacity - 1 - int(L)))
                    setattr(sampler.model, "max_protein_prefix", int(new_pp))
            except Exception:
                pass
        for k in range(0, len(idxs), batch_size):
            batch_idx = idxs[k:k+batch_size]
            sp_b = [species_names[i] for i in batch_idx]
            pr_b = [protein_seqs[i] for i in batch_idx]
            out = sampler.sample(
                num_sequences=len(batch_idx),
                sequence_length=L,
                species=sp_b,
                protein_sequences=pr_b,
                control_mode=control_mode,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                return_intermediate=False,
                progress_bar=False,
                enforce_translation=enforce_translation,
            )
            gen_list: List[str] = out["sequences"]
            for pos, idx in enumerate(batch_idx):
                gen_seq = gen_list[pos]
                gen_all[idx] = gen_seq
                tgt_codons = tgt_codons_list[idx]
                gen_codons = _dna_to_codons(gen_seq)[:L]
                matches = sum(1 for a,b in zip(gen_codons, tgt_codons) if a == b)
                codon_accs[idx] = (matches / L) if L > 0 else 0.0
                gen_aa = dna_to_aa(''.join(gen_codons))
                tgt_aa = protein_seqs[idx][:L]
                # Non-canonical target AA counts as a match (same rule as sample_and_score_batched).
                canonical = set("ACDEFGHIKLMNPQRSTVWY")
                aa_matches = sum(1 for a,b in zip(gen_aa, tgt_aa) if (b not in canonical) or (a == b))
                aa_accs[idx] = (aa_matches / L) if L > 0 else 0.0
        if bool(no_truncation):
            try:
                setattr(sampler.model, "max_species_prefix", prev_sp)
                setattr(sampler.model, "max_protein_prefix", prev_pp)
            except Exception:
                pass
    return gen_all, codon_accs, aa_accs


def export_per_sequence_over_splits(
    sampler: CodonSampler,
    splits: List[str],
    splits_root: str,
    out_csv: str,
    batch_size: int,
    temperature: float,
    top_k: int,
    top_p: float,
    control_mode: str,
    enforce_translation: bool,
    progress: bool = False,
    max_rows_per_split: int = 0,
    no_truncation: bool = False,
    species_prefix_cap: int = 0,
) -> None:
    """Process ./data/val and ./data/test (or under splits_root) and write a per-sequence CSV."""
    try:
        import pyarrow.parquet as pq  # type: ignore
    except Exception as e:
        raise ImportError("pyarrow is required for Parquet evaluation/export") from e
    from pathlib import Path as _P
    import os as _os
    total_written = 0
    # Pre-create CSV with header so users can tail it immediately
    header_cols = [
        "split", "organism", "protein_seq", "codon_seq", "predicted_seq",
        "codon_similarity", "amino_acid_recovery_rate",
    ]
    _P(out_csv).parent.mkdir(parents=True, exist_ok=True)
    if not _P(out_csv).exists() or _os.path.getsize(out_csv) == 0:
        with open(out_csv, "w", newline="") as f:
            f.write(",".join(header_cols) + "\n")
        # NOTE(review): this function uses the root 'logging' module rather than
        # the module-level 'logger' used elsewhere in this file.
        logging.info(f"Initialized CSV with header → {out_csv}")
    for split in splits:
        rows_remaining = int(max_rows_per_split) if int(max_rows_per_split) > 0 else None
        dir_path = Path(splits_root) / split
        files = sorted(str(p) for p in dir_path.glob("*.parquet"))
        if not files:
            logging.warning(f"No parquet files found in {dir_path}, skipping split {split}")
            continue
        logging.info(f"Processing split '{split}' with {len(files)} files ...")
        try:
            from tqdm import tqdm  # type: ignore
            _wrap = (lambda it, **kw: tqdm(it, **kw)) if progress else (lambda it, **kw: it)
        except Exception:
            # tqdm unavailable: iterate without a progress bar.
            _wrap = (lambda it, **kw: it)
        stop_split = False
        for fp in _wrap(files, desc=f"{split} files", unit="file"):
            if rows_remaining is not None and rows_remaining <= 0:
                break
            pf = pq.ParquetFile(fp)
            nrg = int(pf.num_row_groups or 0)
            rgs = list(range(max(nrg, 1)))
            # Build a per-file rows progress bar (prefer total rows from metadata when available)
            rows_total = None
            try:
                if pf.metadata is not None:
                    rows_total = 0
                    for rg_idx in rgs:
                        rg_md = pf.metadata.row_group(rg_idx)
                        if rg_md is not None and rg_md.num_rows is not None:
                            rows_total += int(rg_md.num_rows)
            except Exception:
                rows_total = None
            rows_pbar = None
            if progress:
                try:
                    from tqdm import tqdm  # type: ignore
                    rows_pbar = tqdm(total=rows_total, desc=f"{split}:{Path(fp).name}", unit="rows", leave=False)
                except Exception:
                    rows_pbar = None
            for rg in rgs:
                if rows_remaining is not None and rows_remaining <= 0:
                    stop_split = True
                    break
                table = pf.read_row_group(rg, columns=["Taxon", "protein_seq", "cds_DNA"])
                df = table.to_pandas()
                if df.empty:
                    continue
                species = df["Taxon"].astype(str).tolist()
                proteins = df["protein_seq"].astype(str).str.upper().tolist()
                dnas = df["cds_DNA"].astype(str).str.upper().tolist()
                # Generate predictions and metrics in streaming mini-batches to keep
                # memory stable and update progress frequently
                N = len(species)
                for off in range(0, N, batch_size):
                    if rows_remaining is not None and rows_remaining <= 0:
                        stop_split = True
                        break
                    sp_b = species[off: off + batch_size]
                    pr_b = proteins[off: off + batch_size]
                    dn_b = dnas[off: off + batch_size]
                    gen_list, codon_accs, aa_accs = generate_and_score_batched(
                        sampler, sp_b, pr_b, dn_b,
                        temperature=temperature, top_k=top_k, top_p=top_p,
                        control_mode=control_mode, batch_size=batch_size,
                        enforce_translation=enforce_translation,
                        no_truncation=bool(no_truncation),
                        species_prefix_cap=int(species_prefix_cap),
                    )
                    rows_batch: List[dict] = []
                    for sp, pr, dn, gen, cacc, aacc in zip(sp_b, pr_b, dn_b, gen_list, codon_accs, aa_accs):
                        # Trim target DNA to the length actually scored (min of AA and codon counts).
                        L = min(len(pr), len(dn) // 3)
                        tgt_dna = dn[: 3 * L]
                        rows_batch.append({
                            "split": split,
                            "organism": sp,
                            "protein_seq": pr,
                            "codon_seq": tgt_dna,
                            "predicted_seq": gen,
                            "codon_similarity": float(cacc),
                            "amino_acid_recovery_rate": float(aacc),
                        })
                    if rows_batch:
                        if rows_remaining is not None and len(rows_batch) > rows_remaining:
                            rows_batch = rows_batch[: rows_remaining]
                        # Append incrementally so partial results survive interruption.
                        out_exists = _P(out_csv).exists() and _os.path.getsize(out_csv) > 0
                        df_out = pd.DataFrame(rows_batch)
                        _P(out_csv).parent.mkdir(parents=True, exist_ok=True)
                        df_out.to_csv(out_csv, mode='a', header=not out_exists, index=False)
                        total_written += len(rows_batch)
                        if rows_remaining is not None:
                            rows_remaining -= len(rows_batch)
                        if rows_pbar is not None:
                            try:
                                rows_pbar.update(len(rows_batch))
                            except Exception:
                                pass
                    if rows_remaining is not None and rows_remaining <= 0:
                        stop_split = True
                        break
            if rows_pbar is not None:
                try:
                    rows_pbar.close()
                except Exception:
                    pass
            # Row budget for this split exhausted → stop scanning further files.
            if stop_split:
                break

    logging.info(f"Per-sequence export complete → {out_csv} (rows={total_written})")


def main() -> None:
    """CLI entry point: load the model/species store, then dispatch to one of
    the evaluation modes (per-sequence export, streaming eval-all, teacher-forced
    subset eval, or free-run sampling eval) based on the parsed arguments."""
    args = parse_args()
    random.seed(args.seed)
    torch.manual_seed(args.seed)

    model_dir = Path(args.model_path)
    # Pooling mode is read from the checkpoint first; the embeddings dir (if
    # given) may override it below so store and model stay in agreement.
    pooling = _preferred_pooling(model_dir)
    logger.info(f"Preferred species_pooling from checkpoint: {pooling}")

    # Set up species store (recommended for parity)
    species_store: Optional[SpeciesEmbeddingStore] = None
    if args.embeddings_dir:
        emb_dir = Path(args.embeddings_dir)
        detected = _detect_pooling_from_embeddings_dir(emb_dir)
        # The on-disk embedding format wins over the checkpoint's preference.
        if detected is not None and detected != pooling:
            logger.info(f"Overriding pooling from checkpoint ({pooling}) → embeddings_dir format ({detected})")
            pooling = detected
        species_store = SpeciesEmbeddingStore(args.embeddings_dir, pooling=pooling)
        logger.info(f"Loaded species store with {len(species_store.vocab)} species (pooling={pooling})")

    # Load sampler/model (uses same construction as sampling)
    sampler = CodonSampler(
        model_path=args.model_path,
        device=("cuda" if args.device == "cuda" and torch.cuda.is_available() else "cpu"),
        species_store=species_store,
    )

    # Mode 1: per-sequence CSV export over the standard val/test parquet splits.
    # This path returns immediately and ignores --data_path.
    if bool(args.export_per_sequence):
        export_per_sequence_over_splits(
            sampler,
            splits=list(args.export_splits),
            splits_root=str(args.splits_root),
            out_csv=str(args.out_csv),
            batch_size=int(args.batch_size),
            temperature=float(args.temperature),
            top_k=int(args.top_k),
            top_p=float(args.top_p),
            control_mode=str(args.control_mode),
            enforce_translation=bool(args.enforce_translation),
            progress=bool(args.progress),
            max_rows_per_split=int(args.max_rows_per_split),
            no_truncation=bool(args.no_truncation),
            species_prefix_cap=int(args.species_prefix_cap),
        )
        return

    # --csv_path is kept only as a deprecated alias for --data_path.
    data_path = args.data_path or args.csv_path
    if data_path is None:
        raise SystemExit("Please provide --data_path (CSV or Parquet glob/dir). --csv_path remains as a deprecated alias.")

    # Expand paths to decide CSV vs Parquet
    paths = _expand_paths(data_path)
    if not paths:
        raise FileNotFoundError(f"No input files matched: {data_path}")
    if all(_is_parquet_path(p) for p in paths):
        logger.info(f"Reading up to {args.num_samples} samples from {len(paths)} parquet files ...")
        df_s = _load_random_samples_from_parquet(paths, int(args.num_samples), int(args.seed))
    else:
        # Fallback to CSV/TSV single file behavior (back-compat). If multiple files match, use the first.
        csv_file: Optional[str] = None
        for pth in paths:
            if pth.lower().endswith((".csv", ".tsv", ".csv.gz", ".tsv.gz")):
                csv_file = pth
                break
        if csv_file is None:
            raise ValueError(f"Unsupported input for --data_path: {paths[0]}")
        logger.info(f"Reading CSV file: {csv_file}")
        df = pd.read_csv(csv_file)
        required = {"Taxon", "protein_seq", "cds_DNA"}
        if not required.issubset(set(df.columns)):
            missing = required - set(df.columns)
            raise ValueError(f"CSV missing required columns: {sorted(missing)}")
        if args.num_samples > len(df):
            logger.warning(f"num_samples {args.num_samples} > CSV rows {len(df)}; reducing")
            args.num_samples = len(df)
        # Random sample without replacement
        indices = random.sample(range(len(df)), args.num_samples)
        df_s = df.iloc[indices].reset_index(drop=True)

    if len(df_s) == 0:
        raise ValueError("No samples loaded from the provided data_path")
    logger.info(f"Loaded {len(df_s)} samples for evaluation")

    # Normalize sequences to uppercase so codon/AA comparisons are case-insensitive.
    species = df_s["Taxon"].astype(str).tolist()
    proteins = df_s["protein_seq"].astype(str).str.upper().tolist()
    dnas = df_s["cds_DNA"].astype(str).str.upper().tolist()

    if not args.free_run:
        # Mode 2: full-dataset streaming evaluation (no subsampling).
        # NOTE(review): --eval_all is only honored when --free_run is NOT set.
        if bool(args.eval_all):
            if not args.embeddings_dir:
                raise SystemExit("--eval_all requires --embeddings_dir for species vocab/embeddings")
            # Stream the entire dataset and compute dataset-level metrics (training-parity)
            eval_streaming_all(
                sampler,
                species_store if species_store is not None else SpeciesEmbeddingStore(args.embeddings_dir, pooling=pooling),
                data_path,
                batch_size=int(args.batch_size),
                num_workers=int(args.workers),
                max_records=int(args.max_records),
            )
            return
        # Optional: print per-sample CDS→AA agreement (standard code)
        if bool(args.debug_aa_check):
            for idx, (sp, pr, dn) in enumerate(zip(species, proteins, dnas), start=1):
                ratio, Lcmp, first_bad = _aa_agreement(dn, pr, sampler.tokenizer)
                flag = "OK" if ratio == 1.0 and Lcmp > 0 else ("EMPTY" if Lcmp == 0 else "MISMATCH")
                extra = f" first_mismatch={first_bad}" if first_bad >= 0 else ""
                logger.info(f"AA-CHECK Sample {idx:02d}: {flag} match={ratio:.3f} len={Lcmp}{extra} Taxon={sp}")
        # (No dataset-level filtering to keep evaluation simple.)
        # Mode 3: teacher-forced evaluation on the random subset.
        per_ce_all: List[float] = []
        per_aa_acc_all: List[float] = []
        per_codon_acc_all: List[float] = []
        bs = max(1, int(args.batch_size))
        for i in range(0, len(species), bs):
            sp_b = species[i:i+bs]
            pr_b = proteins[i:i+bs]
            dn_b = dnas[i:i+bs]
            ce, aa_acc = eval_batch(sampler, species_store, sp_b, pr_b, dn_b)
            # Also compute per-sample codon-acc using the same batch forward for consistency
            # Re-run lightweight preds for codon-acc is unnecessary because eval_batch already
            # computed supervised mask and preds internally; instead, recompute quickly here
            # by calling eval_batch and deriving codon-acc inside it. For simplicity and clarity
            # we re-derive codon-acc below using the same masking rules.
            per_ce_all.extend(ce)
            per_aa_acc_all.extend(aa_acc)
            # Derive codon-acc for this batch
            # Prepare a mirrored forward to access logits and masks (small overhead acceptable)
            tok = sampler.tokenizer
            pad_id = tok.pad_token_id
            eos_id = tok.eos_token_id
            # Re-encode each target CDS, trimmed to the common length of the DNA
            # (in codons) and the protein (in residues), then append EOS.
            codon_ids_local: List[List[int]] = []
            for dna, prot in zip(dn_b, pr_b):
                C_dna = len(dna) // 3
                C_prot = len(prot)
                C = max(min(C_dna, C_prot), 1)
                dna_trim = dna[: 3 * C]
                ids = tok.encode_codon_seq(dna_trim, validate=False)
                ids.append(eos_id)
                codon_ids_local.append(ids)
            # Right-pad to a (B, T) tensor with a matching validity mask.
            B_b = len(codon_ids_local)
            T_b = max(len(x) for x in codon_ids_local)
            codons_b = torch.full((B_b, T_b), pad_id, dtype=torch.long)
            mask_b = torch.zeros((B_b, T_b), dtype=torch.bool)
            for j, ids in enumerate(codon_ids_local):
                Lb = len(ids)
                codons_b[j, :Lb] = torch.tensor(ids, dtype=torch.long)
                mask_b[j, :Lb] = True
            # NOTE(review): labels are the same (unshifted) positions as the
            # inputs, with PAD/EOS masked to -100 — presumably the model shifts
            # labels internally; confirm this matches the training loss pathway.
            input_ids_b = codons_b[:, :-1].to(sampler.device)
            labels_b = codons_b[:, :-1].clone()
            labels_b[labels_b == pad_id] = -100
            labels_b[labels_b == eos_id] = -100
            # Build the same conditioning dict as training/sampling: either
            # per-token species embeddings (tuple result) or a fixed vector.
            cond_b = {"control_mode": "fixed"}
            if species_store is not None and sp_b:
                sids_b = [species_store.vocab.get(s, -1) for s in sp_b]
                res_b = species_store.batch_get(sids_b)
                if isinstance(res_b, tuple):
                    sp_tok_b, _ = res_b
                    cond_b["species_tok_emb_src"] = sp_tok_b.to(sampler.device)
                    cond_b["species_tok_emb_tgt"] = sp_tok_b.to(sampler.device)
                else:
                    sp_fix_b = res_b
                    cond_b["species_emb_src"] = sp_fix_b.to(sampler.device)
                    cond_b["species_emb_tgt"] = sp_fix_b.to(sampler.device)
            cond_b["protein_seqs"] = pr_b
            out_b = sampler.model(codon_ids=input_ids_b, cond=cond_b, labels=labels_b.to(sampler.device), return_dict=True)
            logits_b = out_b["logits"]
            # NOTE(review): "per_cap" appears to be a per-sample supervised
            # length cap from the model output — verify against the model's
            # forward; when it is absent this batch contributes no codon-acc
            # entries, so the per-sample zip below can silently shorten.
            per_cap_b = out_b.get("per_cap")
            if logits_b is not None and per_cap_b is not None:
                Bsz, Lmax, V = logits_b.size(0), logits_b.size(1), logits_b.size(2)
                # Align labels to the logits' time dimension (pad with -100).
                labels_aligned_b = torch.full((Bsz, Lmax), -100, dtype=labels_b.dtype, device=logits_b.device)
                common_cols_b = min(labels_b.size(1), Lmax)
                if common_cols_b > 0:
                    labels_aligned_b[:, :common_cols_b] = labels_b.to(logits_b.device)[:, :common_cols_b]
                # Mask out positions beyond each sample's cap.
                ar = torch.arange(Lmax, device=logits_b.device).unsqueeze(0)
                cap_mask_b = ar < per_cap_b.to(device=logits_b.device).unsqueeze(1)
                labels_masked_b = labels_aligned_b.clone()
                labels_masked_b[~cap_mask_b] = -100
                preds_b = logits_b.argmax(dim=-1)
                # Exclude special-token ids from the supervised set so accuracy
                # is measured on real codon tokens only.
                num_special = int(getattr(tok, "num_special_tokens", 0) or 0)
                supervised_b = (labels_masked_b != -100) & cap_mask_b
                if num_special > 0:
                    supervised_b = supervised_b & (labels_aligned_b >= num_special)
                # Per-sample greedy-argmax accuracy over supervised positions.
                for r in range(Bsz):
                    denom = int(supervised_b[r].sum().item())
                    cod_acc = (float((preds_b[r][supervised_b[r]] == labels_aligned_b[r][supervised_b[r]]).sum().item()) / denom) if denom > 0 else 0.0
                    per_codon_acc_all.append(cod_acc)
        for idx, (ce, aa, ca) in enumerate(zip(per_ce_all, per_aa_acc_all, per_codon_acc_all), start=1):
            logger.info(f"Sample {idx:02d}: CE={ce:.4f} CODON-acc={ca:.4f} AA-acc={aa:.4f}")
        if per_ce_all:
            mean_ce = sum(per_ce_all) / len(per_ce_all)
            mean_aa = sum(per_aa_acc_all) / len(per_aa_acc_all) if per_aa_acc_all else 0.0
            mean_codon = sum(per_codon_acc_all) / len(per_codon_acc_all) if per_codon_acc_all else 0.0
            logger.info(f"Summary over {len(per_ce_all)} samples → mean CE={mean_ce:.4f}, mean CODON-acc={mean_codon:.4f}, mean AA-acc={mean_aa:.4f}")
    else:
        # Mode 4: free-run sampling evaluation vs ground-truth DNA (codon-level), batched
        codon_accs, aa_accs = sample_and_score_batched(
            sampler,
            species, proteins, dnas,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            control_mode=args.control_mode,
            batch_size=int(args.batch_size),
            enforce_translation=bool(args.enforce_translation),
            no_truncation=bool(args.no_truncation),
            species_prefix_cap=int(args.species_prefix_cap),
        )
        for idx, (cacc, aacc) in enumerate(zip(codon_accs, aa_accs), start=1):
            logger.info(f"Sample {idx:02d}: CODON-acc={cacc:.4f} AA-acc={aacc:.4f}")
        if codon_accs:
            mean_c = sum(codon_accs) / len(codon_accs)
            mean_a = sum(aa_accs) / len(aa_accs)
            logger.info(f"Summary over {len(codon_accs)} samples → mean CODON-acc={mean_c:.4f}, mean AA-acc={mean_a:.4f}")


if __name__ == "__main__":
    main()