File size: 4,170 Bytes
cb6f1ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""
Smoke test for residue-level interpretability.

From project root:
    .\\venv\\Scripts\\python.exe scripts\\test_interpretability.py
"""

from __future__ import annotations

import sys
import time
from pathlib import Path
from typing import List, Sequence, Tuple

import pandas as pd

# Directory two levels above this script file (the script's parent's parent).
ROOT = Path(__file__).resolve().parent.parent
# Make that directory importable so the `src...` import below resolves when
# this file is executed directly as a script.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from src.models.interpretability import ProteinInterpreter  # noqa: E402


def _select_short_membrane_sequence(csv_path: Path) -> Tuple[str, str]:
    """Pick the shortest membrane-positive sequence from the dataset CSV.

    The shortest sequence is chosen to keep ESM + integrated-gradients
    runtime manageable for a smoke test.

    Args:
        csv_path: Path to a CSV containing at least the columns ``ACC``,
            ``Sequence`` and ``Membrane``.

    Returns:
        ``(acc, sequence)`` for the shortest row whose ``Membrane`` value
        equals 1 and whose sequence is present and non-blank. When several
        rows share the minimum length, the first one in file order wins.

    Raises:
        ValueError: If a required column is missing, if no row has
            ``Membrane == 1``, or if none of those rows has a usable sequence.
    """
    df = pd.read_csv(csv_path)
    required = {"ACC", "Sequence", "Membrane"}
    missing = required.difference(df.columns)
    if missing:
        raise ValueError(f"Dataset missing required columns: {sorted(missing)}")

    mem = df[df["Membrane"] == 1]
    if mem.empty:
        raise ValueError("No membrane-positive sample found in dataset.")

    # Empty CSV cells are read as NaN. Converting those with astype(str) would
    # produce the literal text "nan" (length 3), which would then be selected
    # as the "shortest sequence". Drop them before measuring lengths.
    mem = mem[mem["Sequence"].notna()]
    cleaned = mem["Sequence"].astype(str).str.strip()
    usable = cleaned.str.len() > 0
    mem, cleaned = mem[usable], cleaned[usable]
    if mem.empty:
        raise ValueError(
            "No membrane-positive sample with a non-empty sequence found in dataset."
        )

    # argmin returns the first occurrence of the minimum, so ties are resolved
    # deterministically by file order. Positional indexing avoids relying on
    # the index labels being unique.
    position = int(cleaned.str.len().to_numpy().argmin())
    return str(mem["ACC"].iloc[position]), str(cleaned.iloc[position])


def _top_k(
    residue_scores: Sequence[Tuple[int, str, float]],
    k: int = 10,
    by_abs: bool = False,
) -> List[Tuple[int, str, float]]:
    """Return the ``k`` highest-scoring entries, best first.

    Each entry is a tuple whose third element is the score. With
    ``by_abs=True`` entries are ranked by the magnitude of the score instead
    of its signed value. Entries with equal rank keep their input order.
    """

    def _signed(entry: Tuple[int, str, float]) -> float:
        return float(entry[2])

    def _magnitude(entry: Tuple[int, str, float]) -> float:
        return abs(float(entry[2]))

    rank_key = _magnitude if by_abs else _signed

    # Sort a copy so the caller's sequence is never mutated.
    ordered = list(residue_scores)
    ordered.sort(key=rank_key, reverse=True)
    return ordered[:k]


def main() -> None:
    """Run the interpretability smoke test end to end.

    Selects one membrane-positive sequence from the processed dataset, builds
    a ``ProteinInterpreter`` from the saved checkpoint, and then prints, in
    order: the top attention scores, the top integrated-gradient attributions
    for the ``Membrane`` target, the identified hot regions, the known-signal
    checks, and the location of the saved attribution plot. Finishes by
    printing the total elapsed wall-clock time.

    Raises:
        FileNotFoundError: If the model checkpoint or the dataset CSV is
            missing under ``ROOT``.
    """
    t0 = time.perf_counter()

    model_path = ROOT / "models" / "best_model.pt"
    csv_path = ROOT / "data" / "processed" / "deeploc_multilabel.csv"
    plot_path = ROOT / "plots" / "test_attribution.png"

    # Check the inputs before touching the filesystem so a failed precondition
    # does not leave an empty plots directory behind.
    if not model_path.is_file():
        raise FileNotFoundError(f"Missing model checkpoint: {model_path}")
    if not csv_path.is_file():
        raise FileNotFoundError(f"Missing dataset CSV: {csv_path}")
    plot_path.parent.mkdir(parents=True, exist_ok=True)

    acc, seq = _select_short_membrane_sequence(csv_path)
    print(f"Selected sample ACC={acc}, length={len(seq)}")

    # Local import: torch is only needed here to decide which device to use.
    import torch

    interpreter = ProteinInterpreter(
        classifier_path=model_path,
        esm_model_name="facebook/esm2_t33_650M_UR50D",
        device="cuda" if torch.cuda.is_available() else "cpu",
    )

    print("\n1) Attention scores")
    attention_out = interpreter.get_attention_scores(seq)
    top_attn = _top_k(attention_out["residue_scores"], k=10, by_abs=False)
    for pos, aa, score in top_attn:
        print(f"  pos={pos:4d} aa={aa} attention={score:.6f}")

    print("\n2) Integrated gradients for target='Membrane'")
    ig_out = interpreter.get_integrated_gradients(seq, target_location="Membrane")
    # Rank by magnitude so strongly negative attributions are shown as well.
    top_ig = _top_k(ig_out["residue_scores"], k=10, by_abs=True)
    for pos, aa, score in top_ig:
        print(f"  pos={pos:4d} aa={aa} attribution={score:.6f}")

    print("\n3) Hot regions")
    hot_regions = interpreter.identify_hot_regions(
        ig_out["residue_scores"], window_size=10, top_percentile=90
    )
    if not hot_regions:
        print("  No hot regions found.")
    for i, r in enumerate(hot_regions, start=1):
        subseq = r["subsequence"]
        # Only mark the preview as truncated when it actually was.
        suffix = "..." if len(subseq) > 30 else ""
        print(
            f"  region {i}: start={r['start']}, end={r['end']}, "
            f"avg_score={r['avg_score']:.6f}, subseq={subseq[:30]}{suffix}"
        )

    print("\n4) Known signal checks")
    signal_checks = interpreter.validate_against_known_signals(seq, hot_regions)
    for name, payload in signal_checks.items():
        print(
            f"  {name}: detected={payload['detected']}, "
            f"region={payload['region']}, overlap_with_attribution={payload['overlap_with_attribution']}"
        )

    print("\n5) Attribution visualization")
    interpreter.visualize_attribution(
        sequence=seq,
        residue_scores=ig_out["residue_scores"],
        hot_regions=hot_regions,
        output_path=plot_path,
    )
    print(f"  Saved plot: {plot_path}")

    elapsed = time.perf_counter() - t0
    print(f"\nTotal time taken: {elapsed:.2f} seconds")


# Run the smoke test only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()