""" Smoke tests for VariantEffectPredictor. Test cases: - scan: original single-mutation workflow + scan_single_mutations demo - nls: NLS-targeted disruption in a high-nucleus protein (single + multi mutation) - membrane: TM-targeted disruption in a high-membrane protein - combined: apply top 3 loss-direction mutations from scan simultaneously - all: run everything """ from __future__ import annotations import argparse import json import sys from pathlib import Path from typing import Any, Dict, List, Optional, Sequence, Tuple import pandas as pd ROOT = Path(__file__).resolve().parent.parent if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) from src.analysis.variant_effect import VariantEffectPredictor # noqa: E402 HYDROPHOBIC_TM = set("LVIAF") def _predict_map(predictor: VariantEffectPredictor, seq: str) -> Dict[str, float]: emb = predictor.embed_sequence(seq) probs = predictor._predict_proba_from_embeddings(emb[None, :])[0] return predictor._dict_from_probs(probs) def _pick_protein_with_high_label( df: pd.DataFrame, predictor: VariantEffectPredictor, *, label: str, min_prob: float, max_scan: int, min_len: int = 120, max_len: int = 700, ) -> Tuple[str, str, Dict[str, float]]: required = {"ACC", "Sequence"} missing = required.difference(df.columns) if missing: raise ValueError(f"Dataset missing columns: {sorted(missing)}") pool = df.copy() if label in pool.columns: pool = pool[pool[label] == 1].copy() pool["seq_len"] = pool["Sequence"].astype(str).str.len() pool = pool[(pool["seq_len"] >= min_len) & (pool["seq_len"] <= max_len)].sort_values("seq_len", ascending=True) tried = 0 for _, row in pool.iterrows(): if tried >= max_scan: break tried += 1 seq = str(row["Sequence"]).upper().strip() pmap = _predict_map(predictor, seq) if float(pmap.get(label, 0.0)) > float(min_prob): return str(row["ACC"]), seq, pmap raise RuntimeError(f"No protein found with P({label}) > {min_prob} in first {tried} candidates.") def _top_residues(residue_scores: Sequence[Tuple[int, str, float]], n: int = 20) -> List[Tuple[int, str, float]]: return sorted(residue_scores, key=lambda x: abs(float(x[2])), reverse=True)[:n] def _print_effect(name: str, predictor: VariantEffectPredictor, effect: Dict[str, Any]) -> None: print(f"\n=== {name} ===") print(predictor.format_report(effect)) print("\n--- Raw JSON ---") print(json.dumps(effect, indent=2, default=str)) def _run_scan_case( predictor: VariantEffectPredictor, df: pd.DataFrame, max_scan: int, ) -> Dict[str, Any]: acc, seq, pmap = _pick_protein_with_high_label( df, predictor, label="Nucleus", min_prob=0.6, max_scan=max_scan, min_len=120, max_len=500, ) print(f"\n[scan] Selected ACC={acc}, length={len(seq)}, P(Nucleus)={pmap.get('Nucleus', 0):.4f}") ig = predictor.interpreter.get_integrated_gradients(seq, target_location="Nucleus") ranked = _top_residues(ig["residue_scores"], n=40) single_mut: Optional[Tuple[int, str, str]] = None for pos, aa, _ in ranked: aa_u = str(aa).upper() if aa_u == "K": single_mut = (int(pos), "K", "A") break if aa_u == "R": single_mut = (int(pos), "R", "W") break if single_mut is None: raise RuntimeError("No K/R among top important Nucleus residues for scan case.") print(f"[scan] Chosen mutation from important residue: {single_mut[0]}{single_mut[1]}>{single_mut[2]}") effect = predictor.predict_variant_effect(seq, [single_mut]) _print_effect("TEST CASE 1 - Existing scan workflow", predictor, effect) print("\n[scan] Running scan_single_mutations on positions 1-50 (step=5) ...") scan = predictor.scan_single_mutations( seq, region_start=1, region_end=min(50, len(seq)), step=5, top_k=20, ) print(f"[scan] Scored variants: {scan['total_variants_scored']} in {scan['time_seconds']:.2f}s") print("[scan] Top 10 impactful mutations:") for row in (scan["top_mutations"] or [])[:10]: print( f" {row['position']}{row['original_aa']}>{row['mutant_aa']} | " f"max_delta={row['max_delta']:+.4f} | " f"{row['most_affected_location']} ({row['direction']})" ) return {"acc": acc, "sequence": seq, "scan": scan} def _run_nls_case( predictor: VariantEffectPredictor, df: pd.DataFrame, max_scan: int, ) -> Dict[str, Any]: acc, seq, pmap = _pick_protein_with_high_label( df, predictor, label="Nucleus", min_prob=0.7, max_scan=max_scan, min_len=120, max_len=700, ) print(f"\n[nls] Selected ACC={acc}, length={len(seq)}, P(Nucleus)={pmap.get('Nucleus', 0):.4f}") ig = predictor.interpreter.get_integrated_gradients(seq, target_location="Nucleus") top20 = _top_residues(ig["residue_scores"], n=20) kr_top = [(int(p), str(a).upper(), float(s)) for p, a, s in top20 if str(a).upper() in {"K", "R"}] if not kr_top: raise RuntimeError("No K/R in top 20 important Nucleus residues.") anchor_pos, anchor_aa, _ = kr_top[0] single_mut = (anchor_pos, anchor_aa, "A") nearby: List[Tuple[int, str, str]] = [single_mut] for p, a, _s in kr_top[1:]: if abs(p - anchor_pos) <= 8: nearby.append((p, a, "A")) if len(nearby) >= 3: break if len(nearby) < 2: for i, aa in enumerate(seq, start=1): if aa in {"K", "R"} and abs(i - anchor_pos) <= 12 and i != anchor_pos: nearby.append((i, aa, "A")) if len(nearby) >= 3: break print(f"[nls] Single mutation: {single_mut[0]}{single_mut[1]}>{single_mut[2]}") print(f"[nls] Multi mutation set: {', '.join(f'{p}{o}>{m}' for p, o, m in nearby)}") eff_single = predictor.predict_variant_effect(seq, [single_mut]) eff_multi = predictor.predict_variant_effect(seq, nearby[:3]) _print_effect("TEST CASE 2A - NLS disruption (single)", predictor, eff_single) _print_effect("TEST CASE 2B - NLS disruption (multi nearby)", predictor, eff_multi) print("\n[nls] Comparison summary:") print( f" Single delta Nucleus: {eff_single['deltas'].get('Nucleus', 0.0):+.4f} | " f"Multi delta Nucleus: {eff_multi['deltas'].get('Nucleus', 0.0):+.4f}" ) print( f" Single risk: {eff_single['mislocalization_risk']} | " f"Multi risk: {eff_multi['mislocalization_risk']}" ) return {"acc": acc, "sequence": seq, "single": eff_single, "multi": eff_multi} def _run_membrane_case( predictor: VariantEffectPredictor, df: pd.DataFrame, max_scan: int, ) -> Dict[str, Any]: acc, seq, pmap = _pick_protein_with_high_label( df, predictor, label="Membrane", min_prob=0.7, max_scan=max_scan, min_len=120, max_len=900, ) print(f"\n[membrane] Selected ACC={acc}, length={len(seq)}, P(Membrane)={pmap.get('Membrane', 0):.4f}") ig = predictor.interpreter.get_integrated_gradients(seq, target_location="Membrane") top30 = _top_residues(ig["residue_scores"], n=30) chosen: Optional[Tuple[int, str, str]] = None for p, aa, _ in top30: a = str(aa).upper() if a not in HYDROPHOBIC_TM: continue if a == "V": chosen = (int(p), "V", "K") # requested example else: chosen = (int(p), a, "D") # requested style (e.g., L->D) break if chosen is None: raise RuntimeError("No hydrophobic residue (L/V/I/A/F) in top membrane-important positions.") print(f"[membrane] Mutation: {chosen[0]}{chosen[1]}>{chosen[2]}") effect = predictor.predict_variant_effect(seq, [chosen]) _print_effect("TEST CASE 3 - Transmembrane disruption", predictor, effect) return {"acc": acc, "sequence": seq, "effect": effect} def _run_combined_case( predictor: VariantEffectPredictor, accession: str, sequence: str, scan: Mapping[str, Any], ) -> Dict[str, Any]: top = list(scan.get("top_mutations") or []) losses = [x for x in top if str(x.get("direction", "")) == "loss"] if len(losses) < 3: raise RuntimeError("Need at least 3 loss-direction mutations in scan results for combined test.") sel = losses[:3] muts = [(int(x["position"]), str(x["original_aa"]), str(x["mutant_aa"])) for x in sel] print(f"\n[combined] Using scan-case protein ACC={accession}, length={len(sequence)}") print("[combined] Using top 3 loss-direction mutations:") for m in muts: print(f" - {m[0]}{m[1]}>{m[2]}") combined = predictor.predict_variant_effect(sequence, muts) _print_effect("TEST CASE 4 - Combined top-loss mutations", predictor, combined) print("\n[combined] Individual vs combined most-affected deltas:") for row in sel: print( f" {row['position']}{row['original_aa']}>{row['mutant_aa']}: " f"{row['most_affected_location']} {float(row['max_delta']):+.4f}" ) print( f" Combined: {combined['most_affected_location']} " f"{float(combined['max_delta']):+.4f}" ) return {"acc": accession, "sequence": sequence, "combined": combined, "mutations": muts} def main() -> None: p = argparse.ArgumentParser(description="Smoke tests for VariantEffectPredictor.") p.add_argument("--classifier-path", type=Path, default=ROOT / "models" / "best_model.pt") p.add_argument("--csv-path", type=Path, default=ROOT / "data" / "processed" / "deeploc_multilabel.csv") p.add_argument("--device", default="cuda", help="Device (default: cuda).") p.add_argument("--max-scan", type=int, default=400) p.add_argument( "--test-case", choices=["all", "scan", "nls", "membrane", "combined"], default="all", help="Which test workflow to run (default: all).", ) args = p.parse_args() classifier_path = args.classifier_path if args.classifier_path.is_absolute() else (ROOT / args.classifier_path).resolve() csv_path = args.csv_path if args.csv_path.is_absolute() else (ROOT / args.csv_path).resolve() if not classifier_path.is_file(): raise FileNotFoundError(f"Missing classifier: {classifier_path}") if not csv_path.is_file(): raise FileNotFoundError(f"Missing dataset CSV: {csv_path}") predictor = VariantEffectPredictor(classifier_path=classifier_path, device=args.device) df = pd.read_csv(csv_path) want_all = args.test_case == "all" scan_bundle: Optional[Dict[str, Any]] = None if want_all or args.test_case == "scan": scan_bundle = _run_scan_case(predictor, df, max_scan=args.max_scan) if want_all or args.test_case == "nls": _run_nls_case(predictor, df, max_scan=args.max_scan) if want_all or args.test_case == "membrane": _run_membrane_case(predictor, df, max_scan=args.max_scan) if want_all or args.test_case == "combined": if scan_bundle is None: scan_bundle = _run_scan_case(predictor, df, max_scan=args.max_scan) _run_combined_case( predictor, accession=str(scan_bundle["acc"]), sequence=str(scan_bundle["sequence"]), scan=scan_bundle["scan"], ) if __name__ == "__main__": main()