# File: protloc-ai/scripts/test_interpretability.py
# Commit: cb6f1ba (Tanoj22) - Initial commit: ProtLoc-AI project setup and core app
"""
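Smoke test for residue-level interpretability (attention scores, integrated
gradients, hot regions, known-signal checks, and an attribution plot).
Run from the project root (Windows virtualenv shown):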
.\\venv\\Scripts\\python.exe scripts\\test_interpretability.py
"""
from __future__ import annotations
import sys
import time
from pathlib import Path
from typing import List, Sequence, Tuple
import pandas as pd
# Repository root: the parent of the directory that contains this script.
ROOT = Path(__file__).resolve().parent.parent
# Prepend the repo root to sys.path so ``src.models.interpretability`` (imported
# right below) resolves when this file is executed directly as a script.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
from src.models.interpretability import ProteinInterpreter # noqa: E402
def _select_short_membrane_sequence(csv_path: Path) -> Tuple[str, str]:
    """Return ``(ACC, Sequence)`` for the shortest membrane-positive protein.

    The shortest sequence is chosen to keep the ESM forward passes and the
    integrated-gradients run cheap.

    Args:
        csv_path: CSV file with at least the ``ACC``, ``Sequence`` and
            ``Membrane`` columns. A row counts as membrane-positive when its
            ``Membrane`` value is numerically equal to 1.

    Returns:
        The accession and the (whitespace-stripped) sequence of the selected
        row, both as ``str``. Ties on length go to the first row in file order.

    Raises:
        ValueError: If a required column is missing, if there is no
            membrane-positive row, or if none of them has a non-empty sequence.
    """
    df = pd.read_csv(csv_path)
    required = {"ACC", "Sequence", "Membrane"}
    missing = required.difference(df.columns)
    if missing:
        raise ValueError(f"Dataset missing required columns: {sorted(missing)}")

    # Coerce so labels stored as "1" / "1.0" text still match; junk becomes NaN.
    is_membrane = pd.to_numeric(df["Membrane"], errors="coerce") == 1
    mem = df[is_membrane]
    if mem.empty:
        raise ValueError("No membrane-positive sample found in dataset.")

    # Drop missing/blank sequences *before* measuring length: astype(str) would
    # turn NaN into the 3-character string "nan", which would then be selected
    # as the "shortest" protein.
    seqs = mem["Sequence"].dropna().astype(str).str.strip()
    seqs = seqs[seqs != ""]
    if seqs.empty:
        raise ValueError(
            "No membrane-positive sample with a non-empty sequence found in dataset."
        )

    # The overall shortest sequence always lies inside any "short" length cap,
    # so a plain argmin reproduces the old cap-then-sort choice. idxmin returns
    # the first occurrence, which makes tie-breaking deterministic.
    best = seqs.str.len().idxmin()
    return str(mem.at[best, "ACC"]), seqs.at[best]
def _top_k(
    residue_scores: Sequence[Tuple[int, str, float]],
    k: int = 10,
    by_abs: bool = False,
) -> List[Tuple[int, str, float]]:
    """Return the first ``k`` residue tuples ranked by score, highest first.

    Args:
        residue_scores: ``(position, amino_acid, score)`` tuples.
        k: How many entries to keep (a plain list slice ``[:k]``).
        by_abs: When true, rank by ``abs(score)`` instead of the signed score.

    Returns:
        A new list. Entries with equal rank keep their input order, because
        ``sorted`` is stable.
    """

    def rank_key(entry: Tuple[int, str, float]) -> float:
        score = float(entry[2])
        return abs(score) if by_abs else score

    ordered = sorted(residue_scores, key=rank_key, reverse=True)
    return ordered[:k]
def _print_residues(rows: Sequence[Tuple[int, str, float]], label: str) -> None:
    """Print one ``pos=... aa=... <label>=...`` line per ``(pos, aa, score)`` row."""
    for pos, aa, score in rows:
        print(f" pos={pos:4d} aa={aa} {label}={score:.6f}")


def main() -> None:
    """Run the interpretability smoke test end to end on one membrane protein.

    Picks the shortest membrane-positive sequence from the processed DeepLoc
    CSV and builds a ``ProteinInterpreter`` from the saved classifier. It then
    prints, in order:

    1. the top-10 attention residues,
    2. the top-10 integrated-gradient residues (ranked by magnitude) for the
       ``"Membrane"`` target,
    3. hot regions derived from those IG scores,
    4. known-signal checks against those regions.

    Finally it saves an attribution plot and reports the total wall time.

    Raises:
        FileNotFoundError: If the model checkpoint or the dataset CSV is missing.
        ValueError: Propagated from ``_select_short_membrane_sequence`` when the
            dataset has no usable membrane-positive row.
    """
    import torch  # local import: only needed to choose the compute device

    t0 = time.perf_counter()
    model_path = ROOT / "models" / "best_model.pt"
    csv_path = ROOT / "data" / "processed" / "deeploc_multilabel.csv"
    plot_path = ROOT / "plots" / "test_attribution.png"

    # Validate inputs before touching the filesystem, so a misconfigured run
    # fails without leaving an empty plots/ directory behind.
    if not model_path.is_file():
        raise FileNotFoundError(f"Missing model checkpoint: {model_path}")
    if not csv_path.is_file():
        raise FileNotFoundError(f"Missing dataset CSV: {csv_path}")
    plot_path.parent.mkdir(parents=True, exist_ok=True)

    acc, seq = _select_short_membrane_sequence(csv_path)
    print(f"Selected sample ACC={acc}, length={len(seq)}")

    interpreter = ProteinInterpreter(
        classifier_path=model_path,
        esm_model_name="facebook/esm2_t33_650M_UR50D",
        device="cuda" if torch.cuda.is_available() else "cpu",
    )

    print("\n1) Attention scores")
    attention_out = interpreter.get_attention_scores(seq)
    _print_residues(
        _top_k(attention_out["residue_scores"], k=10, by_abs=False), "attention"
    )

    print("\n2) Integrated gradients for target='Membrane'")
    ig_out = interpreter.get_integrated_gradients(seq, target_location="Membrane")
    # Rank by magnitude so large negative attributions are listed as well.
    _print_residues(
        _top_k(ig_out["residue_scores"], k=10, by_abs=True), "attribution"
    )

    print("\n3) Hot regions")
    hot_regions = interpreter.identify_hot_regions(
        ig_out["residue_scores"], window_size=10, top_percentile=90
    )
    if not hot_regions:
        print(" No hot regions found.")
    for i, region in enumerate(hot_regions, start=1):
        subseq = region["subsequence"]
        # Append "..." only when the 30-character preview actually truncates.
        preview = subseq if len(subseq) <= 30 else f"{subseq[:30]}..."
        print(
            f" region {i}: start={region['start']}, end={region['end']}, "
            f"avg_score={region['avg_score']:.6f}, subseq={preview}"
        )

    print("\n4) Known signal checks")
    signal_checks = interpreter.validate_against_known_signals(seq, hot_regions)
    for name, payload in signal_checks.items():
        print(
            f" {name}: detected={payload['detected']}, "
            f"region={payload['region']}, overlap_with_attribution={payload['overlap_with_attribution']}"
        )

    print("\n5) Attribution visualization")
    interpreter.visualize_attribution(
        sequence=seq,
        residue_scores=ig_out["residue_scores"],
        hot_regions=hot_regions,
        output_path=plot_path,
    )
    print(f" Saved plot: {plot_path}")

    elapsed = time.perf_counter() - t0
    print(f"\nTotal time taken: {elapsed:.2f} seconds")


if __name__ == "__main__":
    main()