| |
| """ |
| Teacher-forced (and optional free-run) evaluation on a random subset of your |
| dataset to measure codon token cross-entropy and AA token accuracy, using the |
| same conditioning pathway as training. |
| |
| Supports either a CSV file or Parquet input via a directory/glob (e.g., |
| ./data/val/*.parquet). |
| |
| Usage examples: |
| # CSV input |
| python eval.py \ |
| --model_path outputs/checkpoint-21000 \ |
| --data_path random_sample_1000.csv \ |
| --embeddings_dir embeddings \ |
| --num_samples 10 \ |
| --batch_size 10 \ |
| --device cuda |
| |
| # Parquet glob input |
| python eval.py \ |
| --model_path outputs/checkpoint-21000 \ |
| --data_path "./data/val/*.parquet" \ |
| --embeddings_dir embeddings \ |
| --num_samples 64 \ |
| --batch_size 32 \ |
| --device cuda |
| """ |
|
|
| import argparse |
| import json |
| import logging |
| import random |
| from pathlib import Path |
| from typing import List, Optional, Tuple |
| import glob |
|
|
| import torch |
| import torch.nn.functional as F |
| import pandas as pd |
|
|
| from src.sampler import CodonSampler |
| from src.dataset import SpeciesEmbeddingStore, StreamSeqDataset, stage_collate_fn |
| from torch.utils.data import DataLoader |
|
|
|
|
# Root logging configuration shared by every helper in this script; the
# named "eval_tf" logger is used for all progress/summary messages below.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger("eval_tf")
|
|
|
|
def parse_args():
    """Build and parse the CLI for teacher-forced / free-run evaluation."""
    parser = argparse.ArgumentParser("Teacher-forced evaluation of CodonTranslator")

    # Checkpoint location (required).
    parser.add_argument(
        "--model_path", required=True, type=str,
        help="Path to checkpoint dir (with config.json / model.safetensors)",
    )

    # Input data: either a CSV file or a Parquet glob/dir. --csv_path is kept
    # only for backwards compatibility with older invocations.
    parser.add_argument(
        "--data_path", required=False, type=str, default=None,
        help="CSV file or Parquet glob/dir (e.g., ./data/val/*.parquet)",
    )
    parser.add_argument(
        "--csv_path", required=False, type=str, default=None,
        help="[Deprecated] CSV with columns: Taxon, protein_seq, cds_DNA",
    )
    parser.add_argument(
        "--embeddings_dir", type=str, default=None,
        help="Species embeddings directory (recommended for parity)",
    )

    # Sampling of the evaluation subset and runtime knobs.
    parser.add_argument("--num_samples", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=10)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument(
        "--workers", type=int, default=0,
        help="DataLoader workers for --eval_all streaming mode",
    )

    # Free-run (actual sampling) evaluation options.
    parser.add_argument(
        "--free_run", action="store_true",
        help="If set, perform real sampling instead of teacher forcing and compare to ground-truth codon sequences",
    )
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top_k", type=int, default=50)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--control_mode", type=str, choices=["fixed", "variable"], default="fixed")
    parser.add_argument(
        "--enforce_translation", action="store_true",
        help="Hard-mask decoding to codons matching target amino acid at each position during free-run evaluation",
    )

    # Streaming full-dataset evaluation.
    parser.add_argument(
        "--eval_all", action="store_true",
        help="Stream over all rows from --data_path and compute aggregated metrics (memory-safe)",
    )
    parser.add_argument(
        "--max_records", type=int, default=0,
        help="When --eval_all is set: limit to first N samples (0 = all)",
    )
    parser.add_argument(
        "--debug_aa_check", action="store_true",
        help="Print per-sample agreement between CDS→AA (standard code) and provided protein",
    )

    # Per-sequence CSV export over dataset splits.
    parser.add_argument(
        "--export_per_sequence", action="store_true",
        help="Process ./data/val and ./data/test parquets in batches and export a per-sequence CSV",
    )
    parser.add_argument(
        "--splits_root", type=str, default="./data",
        help="Root directory that contains val/ and test/ subfolders with parquet files",
    )
    parser.add_argument(
        "--out_csv", type=str, default="outputs/eval_per_sequence.csv",
        help="Output CSV path for per-sequence export",
    )
    parser.add_argument(
        "--export_splits", nargs="+", default=["val", "test"],
        help="Subdirectories under --splits_root to process (default: val test)",
    )
    parser.add_argument(
        "--max_rows_per_split", type=int, default=0,
        help="When --export_per_sequence is set: limit number of rows per split (0 = all)",
    )
    parser.add_argument(
        "--progress", action="store_true",
        help="Show progress bars during per-sequence export",
    )

    # Prefix-capacity handling during generation.
    parser.add_argument(
        "--no_truncation", action="store_true",
        help="Fit prefix caps so generated codon length equals protein length (avoids capacity truncation)",
    )
    parser.add_argument(
        "--species_prefix_cap", type=int, default=0,
        help="When >0 and --no_truncation is set, cap species token prefix to this many tokens; 0 = no species cap",
    )
    return parser.parse_args()
|
|
|
|
| def _is_parquet_path(p: str) -> bool: |
| lower = p.lower() |
| return lower.endswith(".parquet") or lower.endswith(".parq") |
|
|
|
|
| def _expand_paths(maybe_path_or_glob: str) -> List[str]: |
| """Expand a path/glob or directory into a sorted list of files. |
| Prioritize Parquet when scanning a directory. |
| """ |
| paths: List[str] = [] |
| P = Path(maybe_path_or_glob) |
| if P.is_dir(): |
| paths.extend(sorted(str(x) for x in P.rglob("*.parquet"))) |
| paths.extend(sorted(str(x) for x in P.rglob("*.parq"))) |
| paths.extend(sorted(str(x) for x in P.rglob("*.csv"))) |
| paths.extend(sorted(str(x) for x in P.rglob("*.tsv"))) |
| paths.extend(sorted(str(x) for x in P.rglob("*.csv.gz"))) |
| paths.extend(sorted(str(x) for x in P.rglob("*.tsv.gz"))) |
| else: |
| paths = sorted(glob.glob(str(P))) |
| |
| out: List[str] = [] |
| seen = set() |
| for x in paths: |
| if x not in seen: |
| out.append(x) |
| seen.add(x) |
| return out |
|
|
|
|
def _load_random_samples_from_parquet(files: List[str], num_samples: int, seed: int) -> pd.DataFrame:
    """Collect up to num_samples rows from a list of Parquet files, reading by row group.
    Reads only the required columns and shuffles files/row-groups for decent coverage.
    """
    try:
        import pyarrow.parquet as pq
    except Exception as e:
        raise ImportError("pyarrow is required to read parquet files") from e

    rng = random.Random(seed)
    required = ["Taxon", "protein_seq", "cds_DNA"]

    parquet_files = [f for f in files if _is_parquet_path(f)]
    if not parquet_files:
        raise FileNotFoundError("No Parquet files found to read")
    parquet_files = parquet_files.copy()
    rng.shuffle(parquet_files)

    frames: List[pd.DataFrame] = []
    budget = int(max(0, num_samples))
    for fp in parquet_files:
        if budget <= 0:
            break
        pf = pq.ParquetFile(fp)
        n_groups = int(pf.num_row_groups or 0)
        # Visit row groups in random order; fall back to group 0 when the
        # count is unavailable.
        row_groups = list(range(n_groups)) if n_groups > 0 else [0]
        rng.shuffle(row_groups)

        cols = [c for c in required if c in pf.schema.names]
        if len(cols) < len(required):
            missing = sorted(set(required) - set(cols))
            raise ValueError(f"Parquet missing required columns {missing} in {fp}")
        for rg in row_groups:
            if budget <= 0:
                break
            frame = pf.read_row_group(rg, columns=cols).to_pandas(types_mapper=None)
            if frame.empty:
                continue
            # Down-sample an over-full row group to exactly the remaining budget.
            if len(frame) > budget:
                frame = frame.sample(n=budget, random_state=rng.randint(0, 2**31 - 1))
            frames.append(frame)
            budget -= len(frame)

    if not frames:
        return pd.DataFrame(columns=required)
    out = pd.concat(frames, ignore_index=True)
    # Shuffle across files/row-groups, then clip to the requested count.
    out = out.sample(frac=1.0, random_state=seed).reset_index(drop=True)
    if len(out) > num_samples:
        out = out.iloc[:num_samples].reset_index(drop=True)
    return out
|
|
|
|
| def _preferred_pooling(model_dir: Path) -> str: |
| """ |
| Best-effort pooling detection: |
| - First try checkpoint configs for an explicit hint |
| - Fallback to 'last' |
| Note: we'll further override this using the embeddings_dir contents if provided. |
| """ |
| for cfg_name in ("trainer_config.json", "config.json"): |
| fp = model_dir / cfg_name |
| if fp.exists(): |
| try: |
| with open(fp) as f: |
| cfg = json.load(f) |
| return str(cfg.get("species_pooling", "last")) |
| except Exception: |
| continue |
| return "last" |
|
|
|
|
| def _detect_pooling_from_embeddings_dir(emb_dir: Path) -> Optional[str]: |
| """Detect actual available pooling format from embeddings_dir contents.""" |
| fixed_files = [emb_dir / "species_embeddings.bin", emb_dir / "species_metadata.json", emb_dir / "species_vocab.json"] |
| seq_files = [emb_dir / "species_tok_emb.bin", emb_dir / "species_index.json", emb_dir / "species_vocab.json"] |
| if all(p.exists() for p in fixed_files): |
| return "last" |
| if all(p.exists() for p in seq_files): |
| return "sequence" |
| return None |
|
|
|
|
@torch.no_grad()
def eval_batch(
    sampler: CodonSampler,
    species_store: Optional[SpeciesEmbeddingStore],
    species_names: List[str],
    protein_seqs: List[str],
    dna_cds_list: List[str],
) -> Tuple[List[float], List[float]]:
    """Evaluate a batch in teacher-forced mode.

    Args:
        sampler: wraps the model + tokenizer and owns the target device.
        species_store: optional precomputed species-embedding store; when
            None (or empty names), species names are embedded on the fly.
        species_names: one species label per sample.
        protein_seqs: target amino-acid strings, one per sample.
        dna_cds_list: ground-truth CDS DNA strings, one per sample.

    Returns:
        Per-sample (avg_ce_loss, aa_token_acc).

    Note: a previous revision also computed a per-sample codon accuracy that
    was never returned; that dead code has been removed.
    """
    tok = sampler.tokenizer
    pad_id = tok.pad_token_id
    eos_id = tok.eos_token_id

    # Trim each CDS to the codon count implied by the shorter of (DNA codons,
    # protein length) — at least 1 — then tokenize and append EOS.
    codon_ids = []
    seq_lens = []
    for dna, prot in zip(dna_cds_list, protein_seqs):
        C = max(min(len(dna) // 3, len(prot)), 1)
        ids = tok.encode_codon_seq(dna[: 3 * C], validate=False)
        ids.append(eos_id)
        codon_ids.append(ids)
        seq_lens.append(len(ids))

    # Right-pad into a (B, T) tensor with a boolean validity mask.
    B = len(codon_ids)
    T = max(seq_lens)
    codons = torch.full((B, T), pad_id, dtype=torch.long)
    mask = torch.zeros((B, T), dtype=torch.bool)
    for i, ids in enumerate(codon_ids):
        L = len(ids)
        codons[i, :L] = torch.tensor(ids, dtype=torch.long)
        mask[i, :L] = True

    # Teacher forcing: inputs and labels are the same shifted view; PAD and
    # EOS positions are excluded from the loss via ignore_index=-100.
    input_ids = codons[:, :-1]
    labels_base = codons[:, :-1].clone()
    labels_base[labels_base == pad_id] = -100
    labels_base[labels_base == eos_id] = -100

    cond = {"control_mode": "fixed"}

    # Species conditioning: prefer the precomputed store; otherwise embed the
    # raw species names with the sampler's backbone.
    if species_store is not None and species_names:
        sid_list = [species_store.vocab.get(s, -1) for s in species_names]
        num_unknown = sum(1 for x in sid_list if x < 0)
        if num_unknown > 0:
            logger.warning(f"{num_unknown}/{len(sid_list)} species not found in embeddings vocab; using zero embeddings")
        result = species_store.batch_get(sid_list)
        if isinstance(result, tuple):
            # Per-token (sequence) embeddings.
            sp_tok, _ = result
            cond["species_tok_emb_src"] = sp_tok.to(sampler.device)
            cond["species_tok_emb_tgt"] = sp_tok.to(sampler.device)
        else:
            # Fixed (single-vector) embeddings.
            sp = result
            cond["species_emb_src"] = sp.to(sampler.device)
            cond["species_emb_tgt"] = sp.to(sampler.device)
    elif species_names:
        # No store: embed species names directly (sequence pooling).
        seq_emb, _lens = sampler._qwen_embed_names(species_names, pooling="sequence")
        seq_emb = seq_emb.to(sampler.device)
        cond["species_tok_emb_src"] = seq_emb
        cond["species_tok_emb_tgt"] = seq_emb

    cond["protein_seqs"] = protein_seqs

    device = sampler.device
    input_ids = input_ids.to(device)
    labels_base = labels_base.to(device)

    sampler.model.eval()
    outputs = sampler.model(codon_ids=input_ids, cond=cond, labels=labels_base, return_dict=True)
    logits = outputs["logits"]

    # Best-effort debug log of the conditioning prefix length.
    try:
        prefix_len = outputs.get("prefix_len", 0)
        if isinstance(prefix_len, torch.Tensor):
            prefix_len_dbg = int(prefix_len.max().item()) if prefix_len.numel() > 0 else 0
        else:
            prefix_len_dbg = int(prefix_len)
        logger.debug(f"Prefix length(max)={prefix_len_dbg}, input_len={input_ids.size(1)}")
    except Exception:
        pass

    # Align labels to the logits' time axis (the model may truncate to a
    # per-sample capacity), then mask positions beyond each sample's cap.
    Bsz, Lmax, V = logits.size(0), logits.size(1), logits.size(2)
    labels_aligned = torch.full((Bsz, Lmax), -100, dtype=labels_base.dtype, device=logits.device)
    common_cols = min(labels_base.size(1), Lmax)
    if common_cols > 0:
        labels_aligned[:, :common_cols] = labels_base[:, :common_cols]
    per_cap = outputs.get("per_cap", None)
    if isinstance(per_cap, torch.Tensor) and per_cap.numel() == Bsz:
        ar = torch.arange(Lmax, device=logits.device).unsqueeze(0)
        cap_mask = ar < per_cap.to(device=logits.device).unsqueeze(1)
    else:
        cap_mask = torch.ones_like(labels_aligned, dtype=torch.bool, device=logits.device)

    labels_masked = labels_aligned.clone().to(device=logits.device)
    labels_masked[~cap_mask] = -100

    # Unreduced CE per position; ignore_index positions produce 0 and are
    # excluded from the per-sample mean below.
    loss_flat = F.cross_entropy(
        logits.reshape(-1, V),
        labels_masked.reshape(-1),
        ignore_index=-100,
        reduction="none",
    ).view(Bsz, Lmax)

    preds = logits.argmax(dim=-1)
    num_special = int(getattr(tok, "num_special_tokens", 0) or 0)
    # `supervised` folds in: valid label, inside the per-sample cap, and (when
    # known) not a special token id.
    supervised = (labels_masked != -100) & cap_mask
    if num_special > 0:
        supervised = supervised & (labels_aligned >= num_special)

    per_sample_ce: List[float] = []
    per_sample_aa_acc: List[float] = []
    codon2aa = tok.codon2aa_char_map() if hasattr(tok, "codon2aa_char_map") else {}
    per_cap_int = None
    if isinstance(per_cap, torch.Tensor) and per_cap.numel() == Bsz:
        per_cap_int = torch.clamp(per_cap.to(dtype=torch.long, device=logits.device), min=0, max=Lmax)

    for i in range(B):
        # Mean CE over this sample's supervised positions (0.0 when none).
        valid = supervised[i]
        per_sample_ce.append(loss_flat[i][valid].mean().item() if valid.any() else 0.0)

        # AA accuracy: decode predicted codons to amino acids and compare
        # against the provided protein, position-wise.
        aa_acc = 0.0
        if per_cap_int is not None and codon2aa and i < len(protein_seqs):
            cap = int(per_cap_int[i].item())
            if cap > 0:
                mask_row = supervised[i, :cap]
                if mask_row.any():
                    preds_row = preds[i, :cap][mask_row]
                    prot = protein_seqs[i]
                    seq_len = min(len(prot), preds_row.size(0))
                    if seq_len > 0:
                        pred_aa = ''.join(codon2aa.get(int(t.item()), 'X') for t in preds_row[:seq_len])
                        truth_aa = prot[:seq_len]
                        aa_matches = sum(1 for j in range(seq_len) if pred_aa[j] == truth_aa[j])
                        aa_acc = aa_matches / seq_len
        per_sample_aa_acc.append(aa_acc)

    return per_sample_ce, per_sample_aa_acc
|
|
|
|
| def _dna_to_codons(dna: str) -> List[str]: |
| dna = dna.strip().upper() |
| return [dna[i:i+3] for i in range(0, len(dna) - (len(dna) % 3), 3)] |
|
|
|
|
| def _aa_from_dna_standard(dna: str, tok) -> str: |
| dna = dna.strip().upper() |
| gc = getattr(tok, "_genetic_code", {}) |
| aa = [] |
| for j in range(0, len(dna) - (len(dna) % 3), 3): |
| aa.append(gc.get(dna[j:j+3], 'X')) |
| return ''.join(aa) |
|
|
|
|
def _aa_agreement(dna: str, protein: str, tok) -> Tuple[float, int, int]:
    """Return (match_ratio, compared_len, first_mismatch_idx or -1) under standard code."""
    seq = dna.strip().upper()
    truth_full = protein.strip().upper()
    # Compare only over the region covered by both sequences.
    L = min(len(seq) // 3, len(truth_full))
    if L <= 0:
        return 0.0, 0, -1
    predicted = _aa_from_dna_standard(seq[: 3 * L], tok)
    truth = truth_full[:L]
    matches = 0
    first_mismatch = -1
    for i, (a, b) in enumerate(zip(predicted, truth)):
        if a == b:
            matches += 1
        elif first_mismatch < 0:
            first_mismatch = i
    return matches / L, L, first_mismatch
|
|
|
|
@torch.no_grad()
def eval_streaming_all(
    sampler: CodonSampler,
    species_store: SpeciesEmbeddingStore,
    data_path: str,
    batch_size: int,
    num_workers: int,
    max_records: int = 0,
):
    """Stream over all rows from CSV/Parquet inputs and compute dataset-level metrics.

    Mirrors trainer.evaluate() for parity.

    Args:
        sampler: model/tokenizer wrapper owning the target device.
        species_store: embedding store whose vocab file seeds the dataset.
        data_path: CSV/Parquet file, directory, or glob pattern.
        batch_size: DataLoader batch size.
        num_workers: DataLoader workers (0 = load inline).
        max_records: stop after this many samples (0 = unlimited).

    Returns:
        (mean_ce, codon_acc, aa_acc) aggregated over all supervised tokens.
    """
    device = sampler.device
    tok = sampler.tokenizer
    pad_id = int(tok.pad_token_id)
    eos_id = int(tok.eos_token_id)
    num_special = int(tok.num_special_tokens)
    codon2aa = tok.codon2aa_char_map()

    # Reuse the module-level helper (previously duplicated here inline).
    paths = _expand_paths(data_path)
    if not paths:
        raise FileNotFoundError(f"No input files matched: {data_path}")

    species_vocab_path = str((Path(species_store.embeddings_dir) / "species_vocab.json").resolve())
    ds = StreamSeqDataset(
        files=paths,
        tokenizer=tok,
        species_vocab_path=species_vocab_path,
        unknown_species_id=0,
        csv_chunksize=200_000,
        shuffle_buffer=0,
        shard_across_ranks=False,
    )
    _dl_kwargs = dict(
        batch_size=int(batch_size),
        shuffle=False,
        drop_last=False,
        num_workers=int(max(0, num_workers)),
        collate_fn=stage_collate_fn,
        pin_memory=True,
        persistent_workers=(int(num_workers) > 0),
    )
    if int(num_workers) > 0:
        _dl_kwargs["prefetch_factor"] = 4
    loader = DataLoader(ds, **_dl_kwargs)

    # Running aggregates across the whole stream.
    loss_sum = 0.0       # token-weighted sum of per-batch losses
    loss_tokens = 0      # total capacity-supervised tokens
    codon_correct = 0
    codon_total = 0
    aa_correct = 0
    aa_total = 0

    seen = 0
    for batch in loader:
        if not batch:
            continue
        if int(max_records) > 0 and seen >= int(max_records):
            break
        codon_ids = batch["codon_ids"].to(device)
        # Teacher forcing: shifted view is both input and label; PAD/EOS are
        # excluded via ignore_index=-100.
        input_ids = codon_ids[:, :-1]
        labels = codon_ids[:, :-1].clone()
        labels[labels == pad_id] = -100
        labels[labels == eos_id] = -100

        # Conditioning mirrors training: fixed control + species embeddings
        # looked up by id from the store.
        cond = {"control_mode": "fixed", "protein_seqs": batch.get("protein_seqs", [])}
        sids = batch.get("species_ids")
        if torch.is_tensor(sids):
            sids_list = sids.detach().cpu().tolist()
        else:
            sids_list = [int(x) for x in sids]
        res = species_store.batch_get(sids_list)
        if isinstance(res, tuple):
            sp_tok, _ = res
            cond["species_tok_emb_src"] = sp_tok.to(device)
            cond["species_tok_emb_tgt"] = sp_tok.to(device)
        else:
            cond["species_emb_src"] = res.to(device)
            cond["species_emb_tgt"] = res.to(device)

        out = sampler.model(codon_ids=input_ids, cond=cond, labels=labels, return_dict=True)
        loss = out.get("loss")
        per_cap = out.get("per_cap")
        logits = out.get("logits")

        # Token-weighted loss aggregation (weight = per-sample capacity sum).
        tokens_in_batch = 0
        if per_cap is not None:
            tokens_in_batch = int(torch.clamp(per_cap.detach(), min=0).sum().item())
        loss_tokens += tokens_in_batch
        if loss is not None and tokens_in_batch > 0:
            loss_sum += float(loss.detach().item()) * tokens_in_batch

        if logits is None or logits.size(1) == 0 or per_cap is None:
            seen += input_ids.size(0)
            continue
        max_cap = logits.size(1)
        # NOTE: renamed from `batch_size` — the original shadowed the function
        # parameter of the same name inside this loop.
        bsz = logits.size(0)
        labels_aligned = torch.full((bsz, max_cap), -100, dtype=labels.dtype, device=labels.device)
        common = min(labels.size(1), max_cap)
        if common > 0:
            labels_aligned[:, :common] = labels[:, :common]
        per_cap_int = torch.clamp(per_cap.to(dtype=torch.long), min=0, max=max_cap)
        # Mask out positions beyond each sample's capacity.
        for row in range(bsz):
            cap = int(per_cap_int[row].item())
            if cap < max_cap:
                labels_aligned[row, cap:] = -100
        supervised = labels_aligned != -100
        if num_special > 0:
            # Only score real codon tokens (ids below num_special are specials).
            supervised = supervised & (labels_aligned >= num_special)
        if not supervised.any():
            seen += bsz
            continue
        preds = logits.argmax(dim=-1)
        codon_correct += int((preds[supervised] == labels_aligned[supervised]).sum().item())
        codon_total += int(supervised.sum().item())

        # AA-level accuracy: map predicted codon ids back to amino acids and
        # compare against the provided protein.
        prot_list = cond.get("protein_seqs", [])
        for row in range(bsz):
            cap = int(per_cap_int[row].item())
            if cap <= 0:
                continue
            mask_row = supervised[row, :cap]
            if not mask_row.any():
                continue
            preds_row = preds[row, :cap][mask_row]
            prot = prot_list[row] if (isinstance(prot_list, list) and row < len(prot_list)) else ""
            if not prot:
                continue
            seq_len = min(len(prot), preds_row.size(0))
            if seq_len <= 0:
                continue
            pred_aa = ''.join(codon2aa.get(int(t.item()), 'X') for t in preds_row[:seq_len])
            truth_aa = prot[:seq_len]
            aa_correct += sum(1 for i in range(seq_len) if pred_aa[i] == truth_aa[i])
            aa_total += seq_len
        seen += bsz

    mean_ce = (loss_sum / loss_tokens) if loss_tokens > 0 else 0.0
    codon_acc = (float(codon_correct) / codon_total) if codon_total > 0 else 0.0
    aa_acc = (float(aa_correct) / aa_total) if aa_total > 0 else 0.0
    logger.info(
        f"Full-dataset summary → tokens={loss_tokens} CE={mean_ce:.4f} CODON-acc={codon_acc:.4f} AA-acc={aa_acc:.4f}"
    )
    return mean_ce, codon_acc, aa_acc
|
|
|
|
@torch.no_grad()
def sample_and_score_batched(
    sampler: CodonSampler,
    species_names: List[str],
    protein_seqs: List[str],
    target_dnas: List[str],
    temperature: float,
    top_k: int,
    top_p: float,
    control_mode: str,
    batch_size: int,
    enforce_translation: bool,
    no_truncation: bool = False,
    species_prefix_cap: int = 64,
) -> Tuple[List[float], List[float]]:
    """Free-run sampling in batches; returns per-sample (codon_acc, aa_acc).

    This was previously a line-for-line duplicate of generate_and_score_batched
    minus the returned sequences; it now delegates to that function (same
    bucketing / prefix-cap fitting / sampling / scoring pipeline) and simply
    discards the generated DNA.
    """
    _generated, codon_accs, aa_accs = generate_and_score_batched(
        sampler,
        species_names,
        protein_seqs,
        target_dnas,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        control_mode=control_mode,
        batch_size=batch_size,
        enforce_translation=enforce_translation,
        no_truncation=no_truncation,
        species_prefix_cap=species_prefix_cap,
    )
    return codon_accs, aa_accs
|
|
|
|
@torch.no_grad()
def generate_and_score_batched(
    sampler: CodonSampler,
    species_names: List[str],
    protein_seqs: List[str],
    target_dnas: List[str],
    temperature: float,
    top_k: int,
    top_p: float,
    control_mode: str,
    batch_size: int,
    enforce_translation: bool,
    no_truncation: bool = False,
    species_prefix_cap: int = 64,
) -> Tuple[List[str], List[float], List[float]]:
    """Like sample_and_score_batched but also returns generated DNA sequences per sample.

    Returns (generated_dna_list, codon_acc_list, aa_acc_list), all aligned
    with the input order.
    """
    N = len(species_names)
    # Target length per sample = min(#complete codons in CDS, protein length);
    # degenerate inputs fall back to a single ATG codon.
    tgt_lengths = []
    tgt_codons_list = []
    for prot, dna in zip(protein_seqs, target_dnas):
        cods = _dna_to_codons(dna)
        L = min(len(cods), len(prot))
        if L <= 0:
            L = 1
            cods = ["ATG"]
        tgt_lengths.append(L)
        tgt_codons_list.append(cods[:L])

    # Bucket sample indices by target length so each sampler call generates
    # sequences of a single fixed length.
    buckets: dict[int, List[int]] = {}
    for i, L in enumerate(tgt_lengths):
        buckets.setdefault(L, []).append(i)

    gen_all = [""] * N
    codon_accs = [0.0] * N
    aa_accs = [0.0] * N

    # Translate generated DNA via the tokenizer's genetic-code table
    # ('X' = codon not in the table).
    vocab = sampler.tokenizer._genetic_code
    def dna_to_aa(dna: str) -> str:
        dna = dna.strip().upper()
        aa = []
        for j in range(0, len(dna) - (len(dna) % 3), 3):
            aa.append(vocab.get(dna[j:j+3], 'X'))
        return ''.join(aa)

    for L, idxs in buckets.items():
        # Remember the model's prefix caps so they can be restored after this
        # bucket (the no_truncation path mutates them).
        prev_sp = getattr(sampler.model, "max_species_prefix", 0)
        prev_pp = getattr(sampler.model, "max_protein_prefix", 0)
        if bool(no_truncation):
            # Best-effort: probe the conditioning prefix length with an empty
            # forward pass and shrink the protein prefix cap until L generated
            # tokens fit within the model's positional capacity. Any failure
            # here falls through to sampling with the original caps.
            try:
                capacity = int(getattr(sampler.model, "max_position_embeddings", 1024))
                store = getattr(sampler, "species_store", None)
                if store is not None and getattr(store, "is_legacy", False) and int(species_prefix_cap) > 0:
                    setattr(sampler.model, "max_species_prefix", int(species_prefix_cap))
                # Probe with a small representative sub-batch (at most 8).
                batch_idx_probe = idxs[: min(len(idxs), max(1, min(batch_size, 8)))]
                sp_probe = [species_names[i] for i in batch_idx_probe]
                pr_probe = [protein_seqs[i] for i in batch_idx_probe]
                cond_probe = {"control_mode": "fixed", "protein_seqs": pr_probe}
                if store is not None:
                    sid_list = [store.vocab.get(s, -1) for s in sp_probe]
                    res = store.batch_get(sid_list)
                    if isinstance(res, tuple):
                        sp_tok, _ = res
                        cond_probe["species_tok_emb_src"] = sp_tok.to(sampler.device)
                        cond_probe["species_tok_emb_tgt"] = sp_tok.to(sampler.device)
                    else:
                        cond_probe["species_emb_src"] = res.to(sampler.device)
                        cond_probe["species_emb_tgt"] = res.to(sampler.device)
                # Up to 3 fitting attempts: read prefix_len from a zero-length
                # forward, then reduce max_protein_prefix by the shortfall.
                for _ in range(3):
                    out0 = sampler.model(
                        codon_ids=torch.zeros(len(batch_idx_probe), 0, dtype=torch.long, device=sampler.device),
                        cond=cond_probe,
                        return_dict=True,
                        use_cache=True,
                    )
                    pref = out0.get("prefix_len")
                    pref_max = int(pref.max().item()) if isinstance(pref, torch.Tensor) and pref.numel() > 0 else (int(pref) if isinstance(pref, int) else 0)
                    remaining = capacity - (pref_max + 1)
                    if remaining >= int(L):
                        break
                    need = int(L) - max(0, int(remaining))
                    cur_pp = int(getattr(sampler.model, "max_protein_prefix", 0) or 0)
                    new_pp = max(0, cur_pp - need) if cur_pp > 0 else max(0, pref_max - (capacity - 1 - int(L)))
                    setattr(sampler.model, "max_protein_prefix", int(new_pp))
            except Exception:
                pass
        # Sample this bucket in chunks of batch_size.
        for k in range(0, len(idxs), batch_size):
            batch_idx = idxs[k:k+batch_size]
            sp_b = [species_names[i] for i in batch_idx]
            pr_b = [protein_seqs[i] for i in batch_idx]
            out = sampler.sample(
                num_sequences=len(batch_idx),
                sequence_length=L,
                species=sp_b,
                protein_sequences=pr_b,
                control_mode=control_mode,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                return_intermediate=False,
                progress_bar=False,
                enforce_translation=enforce_translation,
            )
            gen_list: List[str] = out["sequences"]
            # Score each generated sequence against its target.
            for pos, idx in enumerate(batch_idx):
                gen_seq = gen_list[pos]
                gen_all[idx] = gen_seq
                tgt_codons = tgt_codons_list[idx]
                gen_codons = _dna_to_codons(gen_seq)[:L]
                matches = sum(1 for a,b in zip(gen_codons, tgt_codons) if a == b)
                codon_accs[idx] = (matches / L) if L > 0 else 0.0
                gen_aa = dna_to_aa(''.join(gen_codons))
                tgt_aa = protein_seqs[idx][:L]
                # Positions where the target residue is non-canonical are
                # counted as matches (they cannot be meaningfully compared).
                canonical = set("ACDEFGHIKLMNPQRSTVWY")
                aa_matches = sum(1 for a,b in zip(gen_aa, tgt_aa) if (b not in canonical) or (a == b))
                aa_accs[idx] = (aa_matches / L) if L > 0 else 0.0
        if bool(no_truncation):
            # Restore the prefix caps mutated by the fitting loop above.
            try:
                setattr(sampler.model, "max_species_prefix", prev_sp)
                setattr(sampler.model, "max_protein_prefix", prev_pp)
            except Exception:
                pass

    return gen_all, codon_accs, aa_accs
|
|
|
|
| def export_per_sequence_over_splits( |
| sampler: CodonSampler, |
| splits: List[str], |
| splits_root: str, |
| out_csv: str, |
| batch_size: int, |
| temperature: float, |
| top_k: int, |
| top_p: float, |
| control_mode: str, |
| enforce_translation: bool, |
| progress: bool = False, |
| max_rows_per_split: int = 0, |
| no_truncation: bool = False, |
| species_prefix_cap: int = 0, |
| ) -> None: |
| """Process ./data/val and ./data/test (or under splits_root) and write a per-sequence CSV.""" |
| try: |
| import pyarrow.parquet as pq |
| except Exception as e: |
| raise ImportError("pyarrow is required for Parquet evaluation/export") from e |
|
|
| from pathlib import Path as _P |
| import os as _os |
| total_written = 0 |
| |
| header_cols = [ |
| "split", |
| "organism", |
| "protein_seq", |
| "codon_seq", |
| "predicted_seq", |
| "codon_similarity", |
| "amino_acid_recovery_rate", |
| ] |
| _P(out_csv).parent.mkdir(parents=True, exist_ok=True) |
| if not _P(out_csv).exists() or _os.path.getsize(out_csv) == 0: |
| with open(out_csv, "w", newline="") as f: |
| f.write(",".join(header_cols) + "\n") |
| logging.info(f"Initialized CSV with header → {out_csv}") |
| for split in splits: |
| rows_remaining = int(max_rows_per_split) if int(max_rows_per_split) > 0 else None |
| dir_path = Path(splits_root) / split |
| files = sorted(str(p) for p in dir_path.glob("*.parquet")) |
| if not files: |
| logging.warning(f"No parquet files found in {dir_path}, skipping split {split}") |
| continue |
| logging.info(f"Processing split '{split}' with {len(files)} files ...") |
| try: |
| from tqdm import tqdm |
| _wrap = (lambda it, **kw: tqdm(it, **kw)) if progress else (lambda it, **kw: it) |
| except Exception: |
| _wrap = (lambda it, **kw: it) |
| stop_split = False |
| for fp in _wrap(files, desc=f"{split} files", unit="file"): |
| if rows_remaining is not None and rows_remaining <= 0: |
| break |
| pf = pq.ParquetFile(fp) |
| nrg = int(pf.num_row_groups or 0) |
| rgs = list(range(max(nrg, 1))) |
| |
| rows_total = None |
| try: |
| if pf.metadata is not None: |
| rows_total = 0 |
| for rg_idx in rgs: |
| rg_md = pf.metadata.row_group(rg_idx) |
| if rg_md is not None and rg_md.num_rows is not None: |
| rows_total += int(rg_md.num_rows) |
| except Exception: |
| rows_total = None |
| rows_pbar = None |
| if progress: |
| try: |
| from tqdm import tqdm |
| rows_pbar = tqdm(total=rows_total, desc=f"{split}:{Path(fp).name}", unit="rows", leave=False) |
| except Exception: |
| rows_pbar = None |
|
|
| for rg in rgs: |
| if rows_remaining is not None and rows_remaining <= 0: |
| stop_split = True |
| break |
| table = pf.read_row_group(rg, columns=["Taxon", "protein_seq", "cds_DNA"]) |
| df = table.to_pandas() |
| if df.empty: |
| continue |
| species = df["Taxon"].astype(str).tolist() |
| proteins = df["protein_seq"].astype(str).str.upper().tolist() |
| dnas = df["cds_DNA"].astype(str).str.upper().tolist() |
|
|
| |
| |
| N = len(species) |
| for off in range(0, N, batch_size): |
| if rows_remaining is not None and rows_remaining <= 0: |
| stop_split = True |
| break |
| sp_b = species[off: off + batch_size] |
| pr_b = proteins[off: off + batch_size] |
| dn_b = dnas[off: off + batch_size] |
| gen_list, codon_accs, aa_accs = generate_and_score_batched( |
| sampler, |
| sp_b, |
| pr_b, |
| dn_b, |
| temperature=temperature, |
| top_k=top_k, |
| top_p=top_p, |
| control_mode=control_mode, |
| batch_size=batch_size, |
| enforce_translation=enforce_translation, |
| no_truncation=bool(no_truncation), |
| species_prefix_cap=int(species_prefix_cap), |
| ) |
| rows_batch: List[dict] = [] |
| for sp, pr, dn, gen, cacc, aacc in zip(sp_b, pr_b, dn_b, gen_list, codon_accs, aa_accs): |
| L = min(len(pr), len(dn) // 3) |
| tgt_dna = dn[: 3 * L] |
| rows_batch.append({ |
| "split": split, |
| "organism": sp, |
| "protein_seq": pr, |
| "codon_seq": tgt_dna, |
| "predicted_seq": gen, |
| "codon_similarity": float(cacc), |
| "amino_acid_recovery_rate": float(aacc), |
| }) |
| if rows_batch: |
| if rows_remaining is not None and len(rows_batch) > rows_remaining: |
| rows_batch = rows_batch[: rows_remaining] |
| out_exists = _P(out_csv).exists() and _os.path.getsize(out_csv) > 0 |
| df_out = pd.DataFrame(rows_batch) |
| _P(out_csv).parent.mkdir(parents=True, exist_ok=True) |
| df_out.to_csv(out_csv, mode='a', header=not out_exists, index=False) |
| total_written += len(rows_batch) |
| if rows_remaining is not None: |
| rows_remaining -= len(rows_batch) |
| if rows_pbar is not None: |
| try: |
| rows_pbar.update(len(rows_batch)) |
| except Exception: |
| pass |
| if rows_remaining is not None and rows_remaining <= 0: |
| stop_split = True |
| break |
| if rows_pbar is not None: |
| try: |
| rows_pbar.close() |
| except Exception: |
| pass |
| if stop_split: |
| break |
| logging.info(f"Per-sequence export complete → {out_csv} (rows={total_written})") |
|
|
|
|
def main():
    """CLI entry point for CodonTranslator evaluation.

    Runs one of three mutually exclusive modes, selected by flags:

    * ``--export_per_sequence`` — stream the val/test parquet splits and
      append per-sequence generation results to a CSV
      (delegates to :func:`export_per_sequence_over_splits`), then exit.
    * default (teacher-forced) — per-sample cross-entropy, AA-token accuracy
      and codon-token accuracy on a random subset of the data;
      ``--eval_all`` streams the whole dataset instead.
    * ``--free_run`` — autoregressive sampling scored against the reference
      codon sequences.
    """
    args = parse_args()
    # Seed both RNGs so the random subset (and any torch-side sampling)
    # is reproducible across runs.
    random.seed(args.seed)
    torch.manual_seed(args.seed)

    model_dir = Path(args.model_path)
    pooling = _preferred_pooling(model_dir)
    logger.info(f"Preferred species_pooling from checkpoint: {pooling}")

    # The embeddings directory's on-disk pooling format takes precedence over
    # the checkpoint's preference: the store must be read in its own format.
    species_store = None
    if args.embeddings_dir:
        emb_dir = Path(args.embeddings_dir)
        detected = _detect_pooling_from_embeddings_dir(emb_dir)
        if detected is not None and detected != pooling:
            logger.info(f"Overriding pooling from checkpoint ({pooling}) → embeddings_dir format ({detected})")
            pooling = detected
        species_store = SpeciesEmbeddingStore(args.embeddings_dir, pooling=pooling)
        logger.info(f"Loaded species store with {len(species_store.vocab)} species (pooling={pooling})")

    sampler = CodonSampler(
        model_path=args.model_path,
        device=("cuda" if args.device == "cuda" and torch.cuda.is_available() else "cpu"),
        species_store=species_store,
    )

    # Mode 1: per-sequence export handles its own data loading, then exits.
    if bool(args.export_per_sequence):
        export_per_sequence_over_splits(
            sampler,
            splits=list(args.export_splits),
            splits_root=str(args.splits_root),
            out_csv=str(args.out_csv),
            batch_size=int(args.batch_size),
            temperature=float(args.temperature),
            top_k=int(args.top_k),
            top_p=float(args.top_p),
            control_mode=str(args.control_mode),
            enforce_translation=bool(args.enforce_translation),
            progress=bool(args.progress),
            max_rows_per_split=int(args.max_rows_per_split),
            no_truncation=bool(args.no_truncation),
            species_prefix_cap=int(args.species_prefix_cap),
        )
        return

    # --csv_path is kept only as a deprecated alias for --data_path.
    data_path = args.data_path or args.csv_path
    if data_path is None:
        raise SystemExit("Please provide --data_path (CSV or Parquet glob/dir). --csv_path remains as a deprecated alias.")

    paths = _expand_paths(data_path)
    if not paths:
        raise FileNotFoundError(f"No input files matched: {data_path}")

    # Parquet inputs are randomly subsampled across files; otherwise fall
    # back to the first CSV/TSV file among the matched paths.
    if all(_is_parquet_path(p) for p in paths):
        logger.info(f"Reading up to {args.num_samples} samples from {len(paths)} parquet files ...")
        df_s = _load_random_samples_from_parquet(paths, int(args.num_samples), int(args.seed))
    else:
        csv_file = None
        for pth in paths:
            if pth.lower().endswith((".csv", ".tsv", ".csv.gz", ".tsv.gz")):
                csv_file = pth
                break
        if csv_file is None:
            raise ValueError(f"Unsupported input for --data_path: {paths[0]}")
        logger.info(f"Reading CSV file: {csv_file}")
        # FIX: .tsv/.tsv.gz files were accepted above but previously parsed
        # with the default comma separator, collapsing the row into a single
        # column. Choose the delimiter from the extension.
        sep = "\t" if csv_file.lower().endswith((".tsv", ".tsv.gz")) else ","
        df = pd.read_csv(csv_file, sep=sep)
        required = {"Taxon", "protein_seq", "cds_DNA"}
        if not required.issubset(set(df.columns)):
            missing = required - set(df.columns)
            raise ValueError(f"CSV missing required columns: {sorted(missing)}")
        if args.num_samples > len(df):
            logger.warning(f"num_samples {args.num_samples} > CSV rows {len(df)}; reducing")
            args.num_samples = len(df)

        indices = random.sample(range(len(df)), args.num_samples)
        df_s = df.iloc[indices].reset_index(drop=True)

    if len(df_s) == 0:
        raise ValueError("No samples loaded from the provided data_path")

    logger.info(f"Loaded {len(df_s)} samples for evaluation")

    species = df_s["Taxon"].astype(str).tolist()
    proteins = df_s["protein_seq"].astype(str).str.upper().tolist()
    dnas = df_s["cds_DNA"].astype(str).str.upper().tolist()

    if not args.free_run:
        # --- Teacher-forced evaluation ---
        if bool(args.eval_all):
            if not args.embeddings_dir:
                raise SystemExit("--eval_all requires --embeddings_dir for species vocab/embeddings")

            eval_streaming_all(
                sampler,
                species_store if species_store is not None else SpeciesEmbeddingStore(args.embeddings_dir, pooling=pooling),
                data_path,
                batch_size=int(args.batch_size),
                num_workers=int(args.workers),
                max_records=int(args.max_records),
            )
            return

        # Optional sanity check: does each CDS translate back to its protein?
        if bool(args.debug_aa_check):
            for idx, (sp, pr, dn) in enumerate(zip(species, proteins, dnas), start=1):
                ratio, Lcmp, first_bad = _aa_agreement(dn, pr, sampler.tokenizer)
                flag = "OK" if ratio == 1.0 and Lcmp > 0 else ("EMPTY" if Lcmp == 0 else "MISMATCH")
                extra = f" first_mismatch={first_bad}" if first_bad >= 0 else ""
                logger.info(f"AA-CHECK Sample {idx:02d}: {flag} match={ratio:.3f} len={Lcmp}{extra} Taxon={sp}")

        per_ce_all: List[float] = []
        per_aa_acc_all: List[float] = []
        per_codon_acc_all: List[float] = []
        bs = max(1, int(args.batch_size))
        for i in range(0, len(species), bs):
            sp_b = species[i:i+bs]
            pr_b = proteins[i:i+bs]
            dn_b = dnas[i:i+bs]
            # CE and AA accuracy come from the shared eval helper.
            ce, aa_acc = eval_batch(sampler, species_store, sp_b, pr_b, dn_b)
            per_ce_all.extend(ce)
            per_aa_acc_all.extend(aa_acc)

            # Codon-level accuracy: re-run the model teacher-forced and
            # compare argmax predictions against the labels.
            tok = sampler.tokenizer
            pad_id = tok.pad_token_id
            eos_id = tok.eos_token_id
            codon_ids_local = []
            for dna, prot in zip(dn_b, pr_b):
                # Trim the CDS to the shorter of DNA-codon count and protein
                # length (at least one codon) so positions line up.
                C_dna = len(dna) // 3
                C_prot = len(prot)
                C = max(min(C_dna, C_prot), 1)
                dna_trim = dna[: 3 * C]
                ids = tok.encode_codon_seq(dna_trim, validate=False)
                ids.append(eos_id)
                codon_ids_local.append(ids)
            B_b = len(codon_ids_local)
            T_b = max(len(x) for x in codon_ids_local)
            codons_b = torch.full((B_b, T_b), pad_id, dtype=torch.long)
            mask_b = torch.zeros((B_b, T_b), dtype=torch.bool)
            for j, ids in enumerate(codon_ids_local):
                Lb = len(ids)
                codons_b[j, :Lb] = torch.tensor(ids, dtype=torch.long)
                mask_b[j, :Lb] = True
            input_ids_b = codons_b[:, :-1].to(sampler.device)
            # NOTE(review): labels are identical to input_ids (no explicit
            # shift) — assumes the model shifts labels internally like HF
            # causal LMs; confirm against src model implementation.
            labels_b = codons_b[:, :-1].clone()
            labels_b[labels_b == pad_id] = -100
            labels_b[labels_b == eos_id] = -100
            cond_b = {"control_mode": "fixed"}
            if species_store is not None and sp_b:
                sids_b = [species_store.vocab.get(s, -1) for s in sp_b]
                res_b = species_store.batch_get(sids_b)
                if isinstance(res_b, tuple):
                    # Tuple result → per-token species embeddings.
                    sp_tok_b, _ = res_b
                    cond_b["species_tok_emb_src"] = sp_tok_b.to(sampler.device)
                    cond_b["species_tok_emb_tgt"] = sp_tok_b.to(sampler.device)
                else:
                    # Single fixed embedding per species.
                    sp_fix_b = res_b
                    cond_b["species_emb_src"] = sp_fix_b.to(sampler.device)
                    cond_b["species_emb_tgt"] = sp_fix_b.to(sampler.device)
            cond_b["protein_seqs"] = pr_b
            out_b = sampler.model(codon_ids=input_ids_b, cond=cond_b, labels=labels_b.to(sampler.device), return_dict=True)
            logits_b = out_b["logits"]
            per_cap_b = out_b.get("per_cap")
            if logits_b is not None and per_cap_b is not None:
                Bsz, Lmax, V = logits_b.size(0), logits_b.size(1), logits_b.size(2)
                # Align labels to the logits' time dimension; they can
                # differ when the model caps per-sample length.
                labels_aligned_b = torch.full((Bsz, Lmax), -100, dtype=labels_b.dtype, device=logits_b.device)
                common_cols_b = min(labels_b.size(1), Lmax)
                if common_cols_b > 0:
                    labels_aligned_b[:, :common_cols_b] = labels_b.to(logits_b.device)[:, :common_cols_b]
                ar = torch.arange(Lmax, device=logits_b.device).unsqueeze(0)
                cap_mask_b = ar < per_cap_b.to(device=logits_b.device).unsqueeze(1)
                labels_masked_b = labels_aligned_b.clone()
                labels_masked_b[~cap_mask_b] = -100
                preds_b = logits_b.argmax(dim=-1)
                num_special = int(getattr(tok, "num_special_tokens", 0) or 0)
                # Supervised positions: real labels, inside the per-sample
                # cap, and (if known) not special tokens.
                supervised_b = (labels_masked_b != -100) & cap_mask_b
                if num_special > 0:
                    supervised_b = supervised_b & (labels_aligned_b >= num_special)
                for r in range(Bsz):
                    denom = int(supervised_b[r].sum().item())
                    cod_acc = (float((preds_b[r][supervised_b[r]] == labels_aligned_b[r][supervised_b[r]]).sum().item()) / denom) if denom > 0 else 0.0
                    per_codon_acc_all.append(cod_acc)

        # FIX: zipping all three lists truncated to the shortest — when the
        # model returned no per_cap, per_codon_acc_all stayed empty and ALL
        # per-sample lines were silently dropped. Iterate CE/AA pairs and
        # pad any missing codon accuracy with NaN instead.
        for idx, (ce, aa) in enumerate(zip(per_ce_all, per_aa_acc_all), start=1):
            ca = per_codon_acc_all[idx - 1] if idx - 1 < len(per_codon_acc_all) else float("nan")
            logger.info(f"Sample {idx:02d}: CE={ce:.4f} CODON-acc={ca:.4f} AA-acc={aa:.4f}")
        if per_ce_all:
            mean_ce = sum(per_ce_all) / len(per_ce_all)
            mean_aa = sum(per_aa_acc_all) / len(per_aa_acc_all) if per_aa_acc_all else 0.0
            mean_codon = sum(per_codon_acc_all) / len(per_codon_acc_all) if per_codon_acc_all else 0.0
            logger.info(f"Summary over {len(per_ce_all)} samples → mean CE={mean_ce:.4f}, mean CODON-acc={mean_codon:.4f}, mean AA-acc={mean_aa:.4f}")
    else:
        # --- Free-run evaluation: sample autoregressively, score vs reference ---
        codon_accs, aa_accs = sample_and_score_batched(
            sampler,
            species,
            proteins,
            dnas,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            control_mode=args.control_mode,
            batch_size=int(args.batch_size),
            enforce_translation=bool(args.enforce_translation),
            no_truncation=bool(args.no_truncation),
            species_prefix_cap=int(args.species_prefix_cap),
        )
        for idx, (cacc, aacc) in enumerate(zip(codon_accs, aa_accs), start=1):
            logger.info(f"Sample {idx:02d}: CODON-acc={cacc:.4f} AA-acc={aacc:.4f}")
        if codon_accs:
            mean_c = sum(codon_accs) / len(codon_accs)
            mean_a = sum(aa_accs) / len(aa_accs)
            logger.info(f"Summary over {len(codon_accs)} samples → mean CODON-acc={mean_c:.4f}, mean AA-acc={mean_a:.4f}")
|
|
|
|
# Script entry point: run the evaluation CLI when invoked directly
# (imports of this module have no side effects beyond logging setup).
if __name__ == "__main__":
    main()
|
|