Spaces:
Running
Running
| """ | |
| Smoke test for residue-level interpretability. | |
| From project root: | |
| .\\venv\\Scripts\\python.exe scripts\\test_interpretability.py | |
| """ | |
| from __future__ import annotations | |
| import sys | |
| import time | |
| from pathlib import Path | |
| from typing import List, Sequence, Tuple | |
| import pandas as pd | |
# Project root = parent of the directory that holds this script. It is pushed
# onto sys.path so that the ``from src.models...`` import below resolves when
# the script is launched directly (e.g. ``python scripts/test_interpretability.py``).
_script_path = Path(__file__).resolve()
ROOT = _script_path.parent.parent
_root_str = str(ROOT)
if _root_str not in sys.path:
    sys.path.insert(0, _root_str)
| from src.models.interpretability import ProteinInterpreter # noqa: E402 | |
| def _select_short_membrane_sequence(csv_path: Path) -> Tuple[str, str]: | |
| df = pd.read_csv(csv_path) | |
| required = {"ACC", "Sequence", "Membrane"} | |
| missing = required.difference(df.columns) | |
| if missing: | |
| raise ValueError(f"Dataset missing required columns: {sorted(missing)}") | |
| mem = df[df["Membrane"] == 1].copy() | |
| if mem.empty: | |
| raise ValueError("No membrane-positive sample found in dataset.") | |
| mem["seq_len"] = mem["Sequence"].astype(str).str.len() | |
| # Prefer short sequences to keep ESM+IG runtime manageable. | |
| short = mem[mem["seq_len"] <= 160] | |
| pool = short if not short.empty else mem | |
| row = pool.sort_values("seq_len", ascending=True).iloc[0] | |
| return str(row["ACC"]), str(row["Sequence"]) | |
| def _top_k( | |
| residue_scores: Sequence[Tuple[int, str, float]], | |
| k: int = 10, | |
| by_abs: bool = False, | |
| ) -> List[Tuple[int, str, float]]: | |
| if by_abs: | |
| ranked = sorted(residue_scores, key=lambda x: abs(float(x[2])), reverse=True) | |
| else: | |
| ranked = sorted(residue_scores, key=lambda x: float(x[2]), reverse=True) | |
| return ranked[:k] | |
def main() -> None:
    """Run the residue-level interpretability smoke test end to end.

    Steps:
      1. Pick a short membrane-positive sequence from
         ``data/processed/deeploc_multilabel.csv``.
      2. Load ``ProteinInterpreter`` with ``models/best_model.pt``.
      3. Print the top attention residues and the top Integrated-Gradients
         residues for the ``Membrane`` target.
      4. Print hot regions and known-signal checks.
      5. Save an attribution plot to ``plots/test_attribution.png``.

    Total wall-clock time is printed at the end.

    Raises:
        FileNotFoundError: if the model checkpoint or the dataset CSV is missing.
    """
    t0 = time.perf_counter()
    model_path = ROOT / "models" / "best_model.pt"
    csv_path = ROOT / "data" / "processed" / "deeploc_multilabel.csv"
    plot_path = ROOT / "plots" / "test_attribution.png"

    # Validate inputs before touching the filesystem so a missing file does not
    # leave an empty plots/ directory behind.
    if not model_path.is_file():
        raise FileNotFoundError(f"Missing model checkpoint: {model_path}")
    if not csv_path.is_file():
        raise FileNotFoundError(f"Missing dataset CSV: {csv_path}")
    plot_path.parent.mkdir(parents=True, exist_ok=True)

    acc, seq = _select_short_membrane_sequence(csv_path)
    print(f"Selected sample ACC={acc}, length={len(seq)}")

    # torch is only needed here to choose the device.
    import torch

    interpreter = ProteinInterpreter(
        classifier_path=model_path,
        esm_model_name="facebook/esm2_t33_650M_UR50D",
        device="cuda" if torch.cuda.is_available() else "cpu",
    )

    print("\n1) Attention scores")
    attention_out = interpreter.get_attention_scores(seq)
    top_attn = _top_k(attention_out["residue_scores"], k=10, by_abs=False)
    for pos, aa, score in top_attn:
        print(f" pos={pos:4d} aa={aa} attention={score:.6f}")

    print("\n2) Integrated gradients for target='Membrane'")
    ig_out = interpreter.get_integrated_gradients(seq, target_location="Membrane")
    # Rank by magnitude: strongly negative attributions are as informative as
    # strongly positive ones.
    top_ig = _top_k(ig_out["residue_scores"], k=10, by_abs=True)
    for pos, aa, score in top_ig:
        print(f" pos={pos:4d} aa={aa} attribution={score:.6f}")

    print("\n3) Hot regions")
    hot_regions = interpreter.identify_hot_regions(
        ig_out["residue_scores"], window_size=10, top_percentile=90
    )
    if not hot_regions:
        print(" No hot regions found.")
    for i, r in enumerate(hot_regions, start=1):
        subseq = r["subsequence"]
        # Only mark truncation when something was actually cut off.
        preview = subseq if len(subseq) <= 30 else f"{subseq[:30]}..."
        print(
            f" region {i}: start={r['start']}, end={r['end']}, "
            f"avg_score={r['avg_score']:.6f}, subseq={preview}"
        )

    print("\n4) Known signal checks")
    signal_checks = interpreter.validate_against_known_signals(seq, hot_regions)
    for name, payload in signal_checks.items():
        print(
            f" {name}: detected={payload['detected']}, "
            f"region={payload['region']}, overlap_with_attribution={payload['overlap_with_attribution']}"
        )

    print("\n5) Attribution visualization")
    interpreter.visualize_attribution(
        sequence=seq,
        residue_scores=ig_out["residue_scores"],
        hot_regions=hot_regions,
        output_path=plot_path,
    )
    print(f" Saved plot: {plot_path}")

    elapsed = time.perf_counter() - t0
    print(f"\nTotal time taken: {elapsed:.2f} seconds")


if __name__ == "__main__":
    main()