protloc-ai / scripts /test_variant_effect.py
Tanoj22
Initial commit: ProtLoc-AI project setup and core app
cb6f1ba
"""
Smoke tests for VariantEffectPredictor.
Test cases:
- scan: original single-mutation workflow + scan_single_mutations demo
- nls: NLS-targeted disruption in a high-nucleus protein (single + multi mutation)
- membrane: TM-targeted disruption in a high-membrane protein
- combined: apply top 3 loss-direction mutations from scan simultaneously
- all: run everything
"""
from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple

import pandas as pd
# ROOT is this file's grandparent directory (the repo root, since the script
# lives in <root>/scripts/).
ROOT = Path(__file__).resolve().parent.parent
# Put the root on the import path (only if it is not there yet) so the ``src``
# package resolves when this script is executed directly.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
from src.analysis.variant_effect import VariantEffectPredictor  # noqa: E402

# Residues the membrane case treats as hydrophobic candidates for mutation.
HYDROPHOBIC_TM = {"L", "V", "I", "A", "F"}
def _predict_map(predictor: VariantEffectPredictor, seq: str) -> Dict[str, float]:
    """Return the predictor's ``{location: probability}`` map for one sequence.

    The sequence embedding gets a leading batch axis of size one, is pushed
    through the predictor's probability head, and the single resulting row is
    converted back into a label-keyed dict.

    NOTE(review): relies on the predictor's private helpers
    ``_predict_proba_from_embeddings`` and ``_dict_from_probs``.
    """
    single_row_batch = predictor.embed_sequence(seq)[None, :]
    batch_probs = predictor._predict_proba_from_embeddings(single_row_batch)
    return predictor._dict_from_probs(batch_probs[0])
def _pick_protein_with_high_label(
    df: pd.DataFrame,
    predictor: VariantEffectPredictor,
    *,
    label: str,
    min_prob: float,
    max_scan: int,
    min_len: int = 120,
    max_len: int = 700,
) -> Tuple[str, str, Dict[str, float]]:
    """Return the shortest eligible protein the model confidently assigns to ``label``.

    Candidates are the rows of ``df`` with a non-missing ``Sequence``. When
    ``df`` has a ``label`` column, they are further restricted to rows where it
    equals 1. Only rows whose normalized sequence (stripped, uppercased) has a
    length in ``[min_len, max_len]`` are kept. Candidates are scored
    shortest-first; ties keep their file order. The first candidate whose
    predicted ``P(label)`` is strictly greater than ``min_prob`` is returned.

    Args:
        df: dataset with at least ``ACC`` and ``Sequence`` columns.
        predictor: model wrapper used to score each candidate.
        label: location name to look up in the probability map.
        min_prob: exclusive probability threshold for ``label``.
        max_scan: maximum number of candidates to score.
        min_len: inclusive minimum normalized sequence length.
        max_len: inclusive maximum normalized sequence length.

    Returns:
        ``(accession, normalized_sequence, probability_map)`` for the first hit.

    Raises:
        ValueError: if ``df`` lacks the ``ACC`` or ``Sequence`` column.
        RuntimeError: if none of the first ``max_scan`` candidates qualifies.
    """
    required = {"ACC", "Sequence"}
    missing = required.difference(df.columns)
    if missing:
        raise ValueError(f"Dataset missing columns: {sorted(missing)}")

    # Drop missing sequences up front: astype(str) would turn NaN into "nan".
    pool = df.dropna(subset=["Sequence"]).copy()
    if label in pool.columns:
        pool = pool[pool[label] == 1].copy()
    # Normalize once so the length bounds apply to exactly the string we score.
    pool["_seq"] = pool["Sequence"].astype(str).str.strip().str.upper()
    pool["seq_len"] = pool["_seq"].str.len()
    pool = pool[(pool["seq_len"] >= min_len) & (pool["seq_len"] <= max_len)]
    # Stable sort: equal-length candidates keep file order -> reproducible pick.
    pool = pool.sort_values("seq_len", ascending=True, kind="stable")

    tried = 0
    for acc, seq in zip(pool["ACC"], pool["_seq"]):
        if tried >= max_scan:
            break
        tried += 1
        pmap = _predict_map(predictor, seq)
        if float(pmap.get(label, 0.0)) > float(min_prob):
            return str(acc), seq, pmap
    raise RuntimeError(f"No protein found with P({label}) > {min_prob} in first {tried} candidates.")
def _top_residues(residue_scores: Sequence[Tuple[int, str, float]], n: int = 20) -> List[Tuple[int, str, float]]:
return sorted(residue_scores, key=lambda x: abs(float(x[2])), reverse=True)[:n]
def _print_effect(name: str, predictor: VariantEffectPredictor, effect: Dict[str, Any]) -> None:
    """Print a titled, human-readable report for ``effect`` followed by its raw JSON.

    The JSON dump passes ``default=str``, so values the ``json`` module cannot
    encode natively are rendered with ``str()`` instead of raising.
    """
    banner = f"\n=== {name} ==="
    print(banner)
    report = predictor.format_report(effect)
    print(report)
    print("\n--- Raw JSON ---")
    raw_json = json.dumps(effect, indent=2, default=str)
    print(raw_json)
def _run_scan_case(
    predictor: VariantEffectPredictor,
    df: pd.DataFrame,
    max_scan: int,
) -> Dict[str, Any]:
    """Test case 1: one targeted mutation plus a coarse single-mutation scan.

    Picks a 120-500 aa protein with P(Nucleus) > 0.6 and ranks its residues by
    absolute integrated-gradients attribution toward "Nucleus". It mutates the
    first K (to A) or R (to W) among the top 40 and prints the effect. It then
    scans positions 1-50 in steps of 5 and prints the 10 most impactful
    variants.

    Returns:
        ``{"acc", "sequence", "scan"}`` so the combined case can reuse the scan.

    Raises:
        RuntimeError: if none of the top 40 attributed residues is K or R
            (selection errors from the protein picker propagate as well).
    """
    acc, seq, pmap = _pick_protein_with_high_label(
        df,
        predictor,
        label="Nucleus",
        min_prob=0.6,
        max_scan=max_scan,
        min_len=120,
        max_len=500,
    )
    print(f"\n[scan] Selected ACC={acc}, length={len(seq)}, P(Nucleus)={pmap.get('Nucleus', 0):.4f}")

    ig = predictor.interpreter.get_integrated_gradients(seq, target_location="Nucleus")
    # Replacement residue for each basic residue we are willing to mutate.
    substitution_for = {"K": "A", "R": "W"}
    single_mut: Optional[Tuple[int, str, str]] = None
    for position, aa, _score in _top_residues(ig["residue_scores"], n=40):
        residue = str(aa).upper()
        replacement = substitution_for.get(residue)
        if replacement is not None:
            single_mut = (int(position), residue, replacement)
            break
    if single_mut is None:
        raise RuntimeError("No K/R among top important Nucleus residues for scan case.")

    mut_pos, mut_from, mut_to = single_mut
    print(f"[scan] Chosen mutation from important residue: {mut_pos}{mut_from}>{mut_to}")
    effect = predictor.predict_variant_effect(seq, [single_mut])
    _print_effect("TEST CASE 1 - Existing scan workflow", predictor, effect)

    print("\n[scan] Running scan_single_mutations on positions 1-50 (step=5) ...")
    scan = predictor.scan_single_mutations(
        seq,
        region_start=1,
        region_end=min(50, len(seq)),
        step=5,
        top_k=20,
    )
    print(f"[scan] Scored variants: {scan['total_variants_scored']} in {scan['time_seconds']:.2f}s")
    print("[scan] Top 10 impactful mutations:")
    top_rows = scan["top_mutations"] or []
    for row in top_rows[:10]:
        line = (
            f" {row['position']}{row['original_aa']}>{row['mutant_aa']} | "
            f"max_delta={row['max_delta']:+.4f} | "
            f"{row['most_affected_location']} ({row['direction']})"
        )
        print(line)
    return {"acc": acc, "sequence": seq, "scan": scan}
def _run_nls_case(
    predictor: VariantEffectPredictor,
    df: pd.DataFrame,
    max_scan: int,
) -> Dict[str, Any]:
    """Test case 2: alanine-substitute putative NLS lysines/arginines (single vs. multi).

    Picks a 120-700 aa protein with P(Nucleus) > 0.7. The K/R with the largest
    absolute integrated-gradients attribution (within the top 20 residues) is
    the anchor.

    - Case 2A mutates only the anchor to A.
    - Case 2B mutates up to three K/R to A: the anchor plus further top-20 K/R
      within 8 residues of it. If that yields no partner, the K/R residues
      closest to the anchor within 12 residues are used (ties go to the
      N-terminal side). A warning is printed if 2B still ends up with only
      the anchor.

    Returns:
        ``{"acc", "sequence", "single", "multi"}`` holding both effect dicts.

    Raises:
        RuntimeError: if no K/R is among the top 20 attributed residues
            (selection errors from the protein picker propagate as well).
    """
    max_cluster = 3
    acc, seq, pmap = _pick_protein_with_high_label(
        df,
        predictor,
        label="Nucleus",
        min_prob=0.7,
        max_scan=max_scan,
        min_len=120,
        max_len=700,
    )
    print(f"\n[nls] Selected ACC={acc}, length={len(seq)}, P(Nucleus)={pmap.get('Nucleus', 0):.4f}")
    ig = predictor.interpreter.get_integrated_gradients(seq, target_location="Nucleus")
    top20 = _top_residues(ig["residue_scores"], n=20)
    kr_top = [(int(p), str(a).upper(), float(s)) for p, a, s in top20 if str(a).upper() in {"K", "R"}]
    if not kr_top:
        raise RuntimeError("No K/R in top 20 important Nucleus residues.")

    anchor_pos, anchor_aa, _ = kr_top[0]
    single_mut = (anchor_pos, anchor_aa, "A")
    nearby: List[Tuple[int, str, str]] = [single_mut]
    taken = {anchor_pos}
    # First choice: other highly attributed K/R close to the anchor.
    for p, a, _s in kr_top[1:]:
        if len(nearby) >= max_cluster:
            break
        if abs(p - anchor_pos) <= 8 and p not in taken:
            nearby.append((p, a, "A"))
            taken.add(p)
    if len(nearby) < 2:
        # Fallback: any K/R within +/-12 residues, nearest to the anchor first,
        # so the set stays as tight a cluster as possible.
        window = [
            (i, aa)
            for i, aa in enumerate(seq, start=1)
            if aa in {"K", "R"} and abs(i - anchor_pos) <= 12 and i not in taken
        ]
        window.sort(key=lambda item: (abs(item[0] - anchor_pos), item[0]))
        for i, aa in window[: max_cluster - len(nearby)]:
            nearby.append((i, aa, "A"))
            taken.add(i)
    if len(nearby) < 2:
        print("[nls] WARNING: no additional K/R near the anchor; multi-mutation set equals the single mutation.")

    print(f"[nls] Single mutation: {single_mut[0]}{single_mut[1]}>{single_mut[2]}")
    print(f"[nls] Multi mutation set: {', '.join(f'{p}{o}>{m}' for p, o, m in nearby)}")
    eff_single = predictor.predict_variant_effect(seq, [single_mut])
    eff_multi = predictor.predict_variant_effect(seq, nearby)
    _print_effect("TEST CASE 2A - NLS disruption (single)", predictor, eff_single)
    _print_effect("TEST CASE 2B - NLS disruption (multi nearby)", predictor, eff_multi)
    print("\n[nls] Comparison summary:")
    print(
        f" Single delta Nucleus: {eff_single['deltas'].get('Nucleus', 0.0):+.4f} | "
        f"Multi delta Nucleus: {eff_multi['deltas'].get('Nucleus', 0.0):+.4f}"
    )
    print(
        f" Single risk: {eff_single['mislocalization_risk']} | "
        f"Multi risk: {eff_multi['mislocalization_risk']}"
    )
    return {"acc": acc, "sequence": seq, "single": eff_single, "multi": eff_multi}
def _run_membrane_case(
    predictor: VariantEffectPredictor,
    df: pd.DataFrame,
    max_scan: int,
) -> Dict[str, Any]:
    """Test case 3: one substitution aimed at a putative transmembrane segment.

    Picks a 120-900 aa protein with P(Membrane) > 0.7 and ranks residues by
    absolute integrated-gradients attribution toward "Membrane". It takes the
    first hydrophobic residue (L/V/I/A/F) among the top 30 and substitutes a
    charged residue: V becomes K, anything else becomes D (e.g. L>D).

    Returns:
        ``{"acc", "sequence", "effect"}``.

    Raises:
        RuntimeError: if no hydrophobic residue is among the top 30
            (selection errors from the protein picker propagate as well).
    """
    acc, seq, pmap = _pick_protein_with_high_label(
        df,
        predictor,
        label="Membrane",
        min_prob=0.7,
        max_scan=max_scan,
        min_len=120,
        max_len=900,
    )
    print(f"\n[membrane] Selected ACC={acc}, length={len(seq)}, P(Membrane)={pmap.get('Membrane', 0):.4f}")

    ig = predictor.interpreter.get_integrated_gradients(seq, target_location="Membrane")
    ranked = _top_residues(ig["residue_scores"], n=30)
    hit = next(
        ((pos, str(aa).upper()) for pos, aa, _score in ranked if str(aa).upper() in HYDROPHOBIC_TM),
        None,
    )
    if hit is None:
        raise RuntimeError("No hydrophobic residue (L/V/I/A/F) in top membrane-important positions.")

    hit_pos, residue = hit
    chosen: Tuple[int, str, str] = (int(hit_pos), residue, "K" if residue == "V" else "D")
    print(f"[membrane] Mutation: {chosen[0]}{chosen[1]}>{chosen[2]}")
    effect = predictor.predict_variant_effect(seq, [chosen])
    _print_effect("TEST CASE 3 - Transmembrane disruption", predictor, effect)
    return {"acc": acc, "sequence": seq, "effect": effect}
def _run_combined_case(
    predictor: VariantEffectPredictor,
    accession: str,
    sequence: str,
    scan: Mapping[str, Any],
) -> Dict[str, Any]:
    """Test case 4: apply the three best loss-direction scan hits simultaneously.

    Walks ``scan["top_mutations"]`` in the order the scan reported them. It
    keeps the first three rows whose ``direction`` is ``"loss"``, at most one
    per position. The scan can list several substitutions for the same
    residue, and two different substitutions at one position cannot be applied
    together. The combined effect is printed next to each selected mutation's
    individual most-affected delta.

    Args:
        predictor: model wrapper used to score the combined variant.
        accession: accession of the scanned protein (printed and echoed back).
        sequence: the sequence the scan was run on.
        scan: result of ``scan_single_mutations``; only ``top_mutations`` is read.

    Returns:
        ``{"acc", "sequence", "combined", "mutations"}``.

    Raises:
        RuntimeError: if fewer than three loss mutations at distinct positions exist.
    """
    selected: List[Mapping[str, Any]] = []
    used_positions = set()
    for row in scan.get("top_mutations") or []:
        if str(row.get("direction", "")) != "loss":
            continue
        position = int(row["position"])
        if position in used_positions:
            # A better-ranked substitution at this residue was already chosen.
            continue
        used_positions.add(position)
        selected.append(row)
        if len(selected) == 3:
            break
    if len(selected) < 3:
        raise RuntimeError(
            "Need at least 3 loss-direction mutations at distinct positions "
            "in scan results for combined test."
        )

    muts = [(int(x["position"]), str(x["original_aa"]), str(x["mutant_aa"])) for x in selected]
    print(f"\n[combined] Using scan-case protein ACC={accession}, length={len(sequence)}")
    print("[combined] Using top 3 loss-direction mutations:")
    for m in muts:
        print(f" - {m[0]}{m[1]}>{m[2]}")
    combined = predictor.predict_variant_effect(sequence, muts)
    _print_effect("TEST CASE 4 - Combined top-loss mutations", predictor, combined)
    print("\n[combined] Individual vs combined most-affected deltas:")
    for row in selected:
        print(
            f" {row['position']}{row['original_aa']}>{row['mutant_aa']}: "
            f"{row['most_affected_location']} {float(row['max_delta']):+.4f}"
        )
    print(
        f" Combined: {combined['most_affected_location']} "
        f"{float(combined['max_delta']):+.4f}"
    )
    return {"acc": accession, "sequence": sequence, "combined": combined, "mutations": muts}
def main() -> None:
    """Parse CLI options, load the model and dataset, then run the selected case(s).

    ``--test-case all`` runs scan, nls, membrane and combined, in that order.
    The combined case reuses the scan result when the scan case already ran in
    this invocation; otherwise it runs the scan case itself first. Relative
    ``--classifier-path`` / ``--csv-path`` values are resolved against ROOT,
    not the current working directory.

    Raises:
        FileNotFoundError: if the classifier checkpoint or dataset CSV is missing.
    """
    parser = argparse.ArgumentParser(description="Smoke tests for VariantEffectPredictor.")
    parser.add_argument("--classifier-path", type=Path, default=ROOT / "models" / "best_model.pt")
    parser.add_argument("--csv-path", type=Path, default=ROOT / "data" / "processed" / "deeploc_multilabel.csv")
    parser.add_argument("--device", default="cuda", help="Device (default: cuda).")
    parser.add_argument("--max-scan", type=int, default=400)
    parser.add_argument(
        "--test-case",
        choices=["all", "scan", "nls", "membrane", "combined"],
        default="all",
        help="Which test workflow to run (default: all).",
    )
    args = parser.parse_args()

    def _anchor_at_root(path: Path) -> Path:
        # Absolute paths are used verbatim; relative ones are taken from ROOT.
        if path.is_absolute():
            return path
        return (ROOT / path).resolve()

    classifier_path = _anchor_at_root(args.classifier_path)
    csv_path = _anchor_at_root(args.csv_path)
    if not classifier_path.is_file():
        raise FileNotFoundError(f"Missing classifier: {classifier_path}")
    if not csv_path.is_file():
        raise FileNotFoundError(f"Missing dataset CSV: {csv_path}")

    predictor = VariantEffectPredictor(classifier_path=classifier_path, device=args.device)
    df = pd.read_csv(csv_path)

    def _requested(case: str) -> bool:
        return args.test_case in ("all", case)

    scan_bundle: Optional[Dict[str, Any]] = None
    if _requested("scan"):
        scan_bundle = _run_scan_case(predictor, df, max_scan=args.max_scan)
    if _requested("nls"):
        _run_nls_case(predictor, df, max_scan=args.max_scan)
    if _requested("membrane"):
        _run_membrane_case(predictor, df, max_scan=args.max_scan)
    if _requested("combined"):
        # The combined case needs a scan result; produce one if scan was skipped.
        bundle = scan_bundle if scan_bundle is not None else _run_scan_case(predictor, df, max_scan=args.max_scan)
        _run_combined_case(
            predictor,
            accession=str(bundle["acc"]),
            sequence=str(bundle["sequence"]),
            scan=bundle["scan"],
        )


if __name__ == "__main__":
    main()