# protloc-ai/src/models/interpretability.py
# Tanoj22
# Force add src/models and src/data code files
# fe5a903
"""
Residue-level interpretation utilities for protein localization predictions.
This module uses the full ESM encoder (not precomputed embeddings) so gradients
can flow from classifier outputs back to token embeddings.
"""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type
import contextlib
import importlib
import io
import logging
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib import patches
from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase
from src.utils.device import resolve_torch_device
from .classifier import ProteinLocalizationClassifier
def _load_integrated_gradients() -> Optional[Type[Any]]:
    """Return Captum's ``IntegratedGradients`` class, or ``None`` if Captum is missing.

    The module is resolved at call time through :mod:`importlib` so that static
    analysis of this file does not require ``captum`` to be installed.
    """
    try:
        return getattr(importlib.import_module("captum.attr"), "IntegratedGradients")
    except ImportError:  # pragma: no cover
        return None
# ``None`` when Captum cannot be imported; checked before running Integrated Gradients.
IntegratedGradients = _load_integrated_gradients()

# One-letter residue alphabets consulted by the sequence-motif heuristics.
HYDROPHOBIC_AA: set[str] = set("AVILMFWP")  # hydrophobic residues
MTP_ENRICHED_AA: set[str] = set("RSLA")  # residues enriched in mitochondrial transit peptides
POSITIVE_AA: set[str] = set("KR")  # positively charged residues
@dataclass
class _SignalDetection:
    """Result of one sequence-motif heuristic in ``validate_against_known_signals``."""

    # True when the heuristic found a matching region.
    detected: bool
    # 1-based inclusive (start, end) of the match, or None when nothing was found.
    region: Optional[Tuple[int, int]]
    # True when ``region`` intersects at least one attribution hot region.
    overlap_with_attribution: bool
class ProteinInterpreter:
    """Residue-level interpretation of a trained protein localization classifier.

    Pairs a :class:`ProteinLocalizationClassifier` checkpoint with the full ESM
    encoder so that per-residue scores (attention or Integrated Gradients) can be
    computed, summarized into hot regions, compared against simple sequence-motif
    heuristics, and plotted.
    """

    def __init__(
        self,
        classifier_path: str | Path,
        esm_model_name: str = "facebook/esm2_t33_650M_UR50D",
        device: str | torch.device | None = None,
    ) -> None:
        """Load the classifier checkpoint and the ESM tokenizer/encoder.

        Args:
            classifier_path: Checkpoint file. It must deserialize to a ``dict``:
                either a bare state dict, or a dict holding ``state_dict`` plus
                optional ``embedding_dim``, ``num_labels``, ``label_names``,
                ``dropout_rates`` and ``hidden_dims`` entries. Defaults are used
                for missing keys.
            esm_model_name: Hugging Face id used for both the tokenizer and the encoder.
            device: Device spec passed to ``resolve_torch_device``.

        Raises:
            FileNotFoundError: If ``classifier_path`` is not an existing file.
            ValueError: If the checkpoint is not a ``dict``.
        """
        self.device = resolve_torch_device(device)
        self.classifier_path = Path(classifier_path).expanduser().resolve()
        if not self.classifier_path.is_file():
            raise FileNotFoundError(f"Missing classifier checkpoint: {self.classifier_path}")
        # NOTE(review): depending on the torch version / weights_only default,
        # torch.load may unpickle arbitrary objects; only load trusted checkpoints.
        checkpoint = torch.load(self.classifier_path, map_location="cpu")
        if not isinstance(checkpoint, dict):
            raise ValueError("Unsupported classifier checkpoint format.")
        # Accept both {"state_dict": ..., <metadata>} and a bare state dict.
        state_dict = checkpoint.get("state_dict", checkpoint)
        embedding_dim = int(checkpoint.get("embedding_dim", 1280))
        num_labels = int(checkpoint.get("num_labels", 11))
        label_names = checkpoint.get("label_names")
        dropout_rates = tuple(checkpoint.get("dropout_rates", [0.3, 0.3, 0.2]))
        hidden_dims = tuple(checkpoint.get("hidden_dims", [512, 256, 128]))
        self.classifier = ProteinLocalizationClassifier(
            embedding_dim=embedding_dim,
            num_labels=num_labels,
            label_names=label_names,
            dropout_rates=dropout_rates,
            hidden_dims=hidden_dims,
        )
        self.classifier.load_state_dict(state_dict)
        self.classifier.to(self.device).eval()
        self.label_names = list(self.classifier.label_names)
        self.label_to_idx = {n: i for i, n in enumerate(self.label_names)}
        self.tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(esm_model_name)
        # Quiet Hugging Face load-time logging and stderr while loading the encoder.
        # The logger level is restored even if loading fails.
        hf_logger = logging.getLogger("transformers.modeling_utils")
        prev_level = hf_logger.level
        hf_logger.setLevel(logging.ERROR)
        try:
            with contextlib.redirect_stderr(io.StringIO()):
                # NOTE(review): "eager" attention is presumably requested so that
                # output_attentions=True (used by get_attention_scores) returns weights.
                self.esm_model = AutoModel.from_pretrained(
                    esm_model_name,
                    attn_implementation="eager",
                    ignore_mismatched_sizes=True,
                )
        finally:
            hf_logger.setLevel(prev_level)
        self.esm_model.to(self.device).eval()
        self.esm_model_name = esm_model_name

    @property
    def esm_encoder(self) -> PreTrainedModel:
        """ESM backbone (``AutoModel``) used for IG and shared with :class:`ProteinRelocalizer` embeddings."""
        return self.esm_model

    @property
    def esm_tokenizer(self) -> PreTrainedTokenizerBase:
        """Tokenizer paired with ``esm_encoder``."""
        return self.tokenizer

    def mean_pool_embedding(self, sequence: str) -> np.ndarray:
        """
        Mean-pooled ESM representation for one sequence (same pooling as training embeddings).

        The sequence is upper-cased and stripped first. The mean is taken over every
        token position, including the special tokens added by ``_tokenize``. Runs
        under ``torch.no_grad()`` and returns a 1-D ``float32`` array.

        Raises:
            ValueError: If the sequence is empty after stripping.
        """
        seq = sequence.upper().strip()
        if not seq:
            raise ValueError("Empty sequence")
        toks = self._tokenize(seq)
        with torch.no_grad():
            out = self.esm_model(**toks, return_dict=True)
            pooled = out.last_hidden_state.mean(dim=1)
        return pooled.detach().cpu().numpy().astype(np.float32).squeeze(0)

    def _tokenize(self, sequence: str) -> Dict[str, torch.Tensor]:
        """Tokenize one sequence with special tokens and truncation, on ``self.device``.

        Raises:
            ValueError: If ``sequence`` is not a non-empty string.
        """
        if not sequence or not isinstance(sequence, str):
            raise ValueError("sequence must be a non-empty string")
        toks = self.tokenizer(
            sequence,
            return_tensors="pt",
            add_special_tokens=True,
            truncation=True,
        )
        return {k: v.to(self.device) for k, v in toks.items()}

    def _trim_special_tokens(self, token_scores: np.ndarray, sequence: str) -> List[Tuple[int, str, float]]:
        """Map per-token scores onto residues as ``(position, residue, score)`` tuples.

        ``_tokenize`` always adds special tokens, so the token axis is expected to be
        ``[CLS] + residues + [EOS]``. Positions are 1-based. If the tokenizer truncated
        the input, only the residues that were actually encoded are returned.

        Raises:
            ValueError: If ``token_scores`` is not one-dimensional.
        """
        if token_scores.ndim != 1:
            raise ValueError("Expected 1D token score array.")
        n_res = len(sequence)
        n_tok = int(token_scores.shape[0])
        if n_tok >= n_res + 2:
            core = token_scores[1 : 1 + n_res]
        elif n_tok >= 2:
            # Truncated input is still [CLS] + first (n_tok - 2) residues + [EOS].
            # Drop both special tokens so residue 1 does not receive the CLS score.
            core = token_scores[1 : n_tok - 1]
        else:
            # Degenerate best-effort fallback if tokenizer behavior differs.
            core = token_scores[:n_res]
        return [(i + 1, aa, float(core[i])) for i, aa in enumerate(sequence[: len(core)])]

    @staticmethod
    def _clear_cuda_cache() -> None:
        """Release cached CUDA memory when CUDA is available (no-op otherwise)."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def get_attention_scores(self, sequence: str) -> Dict[str, Any]:
        """Per-residue importance derived from the last ESM attention layer.

        Attention is averaged over heads, then over the first token axis. That axis
        is presumably the query axis under the usual HF ``[batch, heads, query, key]``
        layout (confirm), so each token's score is the mean attention it receives.

        Returns:
            ``{"residue_scores": [(pos, aa, score), ...],
            "raw_attention": head-averaged token x token matrix}``.
        """
        try:
            toks = self._tokenize(sequence)
            with torch.no_grad():
                outputs = self.esm_model(
                    input_ids=toks["input_ids"],
                    attention_mask=toks.get("attention_mask"),
                    output_attentions=True,
                    return_dict=True,
                )
            if outputs.attentions is None or len(outputs.attentions) == 0:
                raise RuntimeError("ESM model did not return attentions.")
            # last layer: [batch, heads, tokens, tokens]
            last_attn = outputs.attentions[-1][0]
            mean_attn = last_attn.mean(dim=0)  # [tokens, tokens], averaged over heads
            token_importance = mean_attn.mean(dim=0).detach().cpu().numpy()
            residue_scores = self._trim_special_tokens(token_importance, sequence)
            return {
                "residue_scores": residue_scores,
                "raw_attention": mean_attn.detach().cpu().numpy(),
            }
        finally:
            self._clear_cuda_cache()

    def get_integrated_gradients(self, sequence: str, target_location: str) -> Dict[str, Any]:
        """Integrated Gradients of the ``target_location`` logit w.r.t. token embeddings.

        The baseline is an all-zero embedding. Attributions are summed over the
        embedding dimension to give one score per token, then mapped to residues.

        Returns:
            ``{"residue_scores": [(pos, aa, score), ...],
            "raw_attributions": full attribution array}``.

        Raises:
            ImportError: If Captum is not installed.
            ValueError: If ``target_location`` is not a known label.
        """
        if IntegratedGradients is None:
            raise ImportError("captum is required for integrated gradients. Install with `pip install captum`.")
        if target_location not in self.label_to_idx:
            raise ValueError(f"Unknown target_location={target_location!r}. Known labels: {self.label_names}")
        target_idx = int(self.label_to_idx[target_location])
        try:
            toks = self._tokenize(sequence)
            input_ids = toks["input_ids"]
            attention_mask = toks.get("attention_mask")
            input_embed = self.esm_model.get_input_embeddings()(input_ids)
            baseline_embed = torch.zeros_like(input_embed)

            def forward_func(inputs_embeds: torch.Tensor, attn_mask: torch.Tensor, cls_idx: int) -> torch.Tensor:
                # Same mean pooling as mean_pool_embedding, followed by the classifier head.
                # NOTE(review): feeding inputs_embeds skips whatever the embedding module
                # does with input_ids (e.g. token-dropout rescaling in some ESM versions);
                # confirm attributions are taken at the same point as predictions.
                enc = self.esm_model(
                    inputs_embeds=inputs_embeds,
                    attention_mask=attn_mask,
                    return_dict=True,
                )
                pooled = enc.last_hidden_state.mean(dim=1)
                logits = self.classifier(pooled)
                return logits[:, cls_idx]

            ig = IntegratedGradients(forward_func)
            attributions = ig.attribute(
                input_embed,
                baselines=baseline_embed,
                additional_forward_args=(attention_mask, target_idx),
                n_steps=32,
                internal_batch_size=1,
            )
            token_scores = attributions.sum(dim=-1).squeeze(0).detach().cpu().numpy()
            residue_scores = self._trim_special_tokens(token_scores, sequence)
            return {
                "residue_scores": residue_scores,
                "raw_attributions": attributions.detach().cpu().numpy(),
            }
        finally:
            self._clear_cuda_cache()

    def identify_hot_regions(
        self,
        residue_scores: Sequence[Tuple[int, str, float]],
        window_size: int = 10,
        top_percentile: float = 90.0,
    ) -> List[Dict[str, Any]]:
        """Find stretches whose sliding-window mean score is in the top percentile.

        Windows whose mean reaches the ``top_percentile`` threshold are marked hot.
        Runs of consecutive hot window starts are merged into one region that spans
        from the first window's start to the last window's end. If the sequence is
        shorter than ``window_size``, the whole sequence is returned as one region.

        Returns:
            A list of dicts with 1-based inclusive ``start``/``end``, ``avg_score``
            (mean residue score within the region) and ``subsequence``.
        """
        if len(residue_scores) == 0:
            return []
        window_size = max(1, int(window_size))
        vals = np.asarray([float(x[2]) for x in residue_scores], dtype=np.float64)
        seq = "".join([str(x[1]) for x in residue_scores])
        if vals.size < window_size:
            avg = float(vals.mean())
            return [{"start": 1, "end": int(vals.size), "avg_score": avg, "subsequence": seq}]
        win_avgs = np.asarray(
            [float(vals[i : i + window_size].mean()) for i in range(0, vals.size - window_size + 1)],
            dtype=np.float64,
        )
        thr = float(np.percentile(win_avgs, top_percentile))
        hot_starts = np.where(win_avgs >= thr)[0].tolist()
        if not hot_starts:
            return []

        def _region(first_start: int, last_start: int) -> Dict[str, Any]:
            # Convert a run of 0-based window starts into a 1-based inclusive region.
            start = first_start + 1
            end = min(last_start + window_size, len(seq))
            avg_score = float(vals[start - 1 : end].mean())
            return {"start": start, "end": end, "avg_score": avg_score, "subsequence": seq[start - 1 : end]}

        regions: List[Dict[str, Any]] = []
        run_start = prev = hot_starts[0]
        for s in hot_starts[1:]:
            if s <= prev + 1:
                prev = s
                continue
            regions.append(_region(run_start, prev))
            run_start = prev = s
        regions.append(_region(run_start, prev))
        return regions

    def validate_against_known_signals(self, sequence: str, hot_regions: Sequence[Dict[str, Any]]) -> Dict[str, Any]:
        """Run simple motif heuristics and check overlap with attribution hot regions.

        Every heuristic reports ``detected``, a 1-based inclusive ``region`` (or
        ``None``), and ``overlap_with_attribution``. The last is true when the region
        intersects any ``hot_regions`` entry (1-based inclusive ``start``/``end``).
        These are coarse composition rules, not trained predictors.
        """
        seq = sequence.upper()
        n = len(seq)

        def _overlap(region: Tuple[int, int] | None) -> bool:
            if region is None:
                return False
            a1, a2 = region
            for r in hot_regions:
                b1, b2 = int(r["start"]), int(r["end"])
                if max(a1, b1) <= min(a2, b2):
                    return True
            return False

        # Signal peptide heuristic: a window of 8-17 residues within residues 1-30,
        # at least 60% hydrophobic. Shortest window first, earliest match wins.
        sp_region: Tuple[int, int] | None = None
        for w in range(8, 18):
            for i in range(0, max(0, min(30, n) - w + 1)):
                frag = seq[i : i + w]
                hyd_frac = sum(aa in HYDROPHOBIC_AA for aa in frag) / max(w, 1)
                if hyd_frac >= 0.6:
                    sp_region = (i + 1, i + w)
                    break
            if sp_region is not None:
                break
        # Mitochondrial transit peptide: the first min(70, n) residues (need >= 10),
        # enriched in R/S/L/A and positive residues, with few acidic residues.
        mtp_region: Tuple[int, int] | None = None
        mtp_end = min(70, n)
        if mtp_end >= 10:
            frag = seq[:mtp_end]
            rsla = sum(aa in MTP_ENRICHED_AA for aa in frag) / len(frag)
            pos = sum(aa in POSITIVE_AA for aa in frag) / len(frag)
            acidic = sum(aa in {"D", "E"} for aa in frag) / len(frag)
            if rsla >= 0.35 and pos >= 0.12 and acidic <= 0.12:
                mtp_region = (1, mtp_end)
        # NLS: >=4 K/R in any 7-mer (first match).
        nls_region: Tuple[int, int] | None = None
        for i in range(0, max(0, n - 7 + 1)):
            frag = seq[i : i + 7]
            if sum(aa in POSITIVE_AA for aa in frag) >= 4:
                nls_region = (i + 1, i + 7)
                break
        # Transmembrane domain: a window of 18-25 residues that is at least 75%
        # hydrophobic. Longest window first, earliest match wins.
        tm_region: Tuple[int, int] | None = None
        for w in range(25, 17, -1):
            for i in range(0, max(0, n - w + 1)):
                frag = seq[i : i + w]
                hyd_frac = sum(aa in HYDROPHOBIC_AA for aa in frag) / max(w, 1)
                if hyd_frac >= 0.75:
                    tm_region = (i + 1, i + w)
                    break
            if tm_region is not None:
                break
        # ER retention signal: C-terminal KDEL/HDEL (last four residues).
        er_region: Tuple[int, int] | None = None
        if n >= 4 and (seq.endswith("KDEL") or seq.endswith("HDEL")):
            er_region = (n - 3, n)
        out: Dict[str, _SignalDetection] = {
            "signal_peptide": _SignalDetection(sp_region is not None, sp_region, _overlap(sp_region)),
            "mitochondrial_transit_peptide": _SignalDetection(mtp_region is not None, mtp_region, _overlap(mtp_region)),
            "nuclear_localization_signal": _SignalDetection(nls_region is not None, nls_region, _overlap(nls_region)),
            "transmembrane_domain": _SignalDetection(tm_region is not None, tm_region, _overlap(tm_region)),
            "er_retention_signal": _SignalDetection(er_region is not None, er_region, _overlap(er_region)),
        }
        return {
            k: {
                "detected": v.detected,
                "region": v.region,
                "overlap_with_attribution": v.overlap_with_attribution,
            }
            for k, v in out.items()
        }

    def visualize_attribution(
        self,
        sequence: str,
        residue_scores: Sequence[Tuple[int, str, float]],
        hot_regions: Sequence[Dict[str, Any]],
        output_path: str | Path,
    ) -> None:
        """Render a per-residue attribution heat map and save it to ``output_path``.

        Residues are drawn 50 per row on a diverging colormap that is symmetric around
        zero. Each row also shows the residue letters, black brackets under each hot
        region, and a track of the heuristics detected by
        ``validate_against_known_signals(sequence, hot_regions)``. Residue ``p`` spans
        the x-interval ``[p - 0.5, p + 0.5]``, so cells, labels, brackets and signal
        bars line up.

        Raises:
            ValueError: If ``residue_scores`` is empty.
        """
        if len(residue_scores) == 0:
            raise ValueError("residue_scores is empty")
        out_path = Path(output_path).expanduser().resolve()
        out_path.parent.mkdir(parents=True, exist_ok=True)
        values = np.asarray([float(x[2]) for x in residue_scores], dtype=np.float64)
        # Letters are taken from residue_scores so labels always match the plotted values.
        residues = [str(x[1]) for x in residue_scores]
        # Ignore NaN/inf when choosing the symmetric color range.
        finite_abs = np.abs(values[np.isfinite(values)])
        vmax = float(finite_abs.max()) if finite_abs.size > 0 else 1.0
        vmax = max(vmax, 1e-6)
        per_row = 50
        n = len(residue_scores)
        rows = int(np.ceil(n / per_row))
        signals = self.validate_against_known_signals(sequence, hot_regions)
        signal_colors = {
            "signal_peptide": "#2ca02c",
            "mitochondrial_transit_peptide": "#9467bd",
            "nuclear_localization_signal": "#ff7f0e",
            "transmembrane_domain": "#8c564b",
            "er_retention_signal": "#17becf",
        }
        fig, axes = plt.subplots(rows, 1, figsize=(16, max(2.3, rows * 2.4)))
        try:
            axes_arr = np.atleast_1d(axes)
            cmap = plt.get_cmap("coolwarm")
            for r in range(rows):
                ax = axes_arr[r]
                s = r * per_row
                e = min((r + 1) * per_row, n)
                # Cell edges at half-integers centre residue p on x == p. This also keeps
                # a single-residue row from collapsing to zero width.
                left, right = s + 0.5, e + 0.5
                seg_vals = values[s:e][None, :]
                ax.imshow(seg_vals, cmap=cmap, aspect="auto", vmin=-vmax, vmax=vmax, extent=(left, right, 0.0, 1.0))
                ax.set_yticks([])
                ax.set_xlim(left, right)
                ax.set_ylim(-0.55, 1.25)
                ax.set_xticks(np.arange(s + 1, e + 1, 5))
                ax.set_xlabel("Residue position")
                ax.set_title(f"Residues {s + 1}-{e}", fontsize=10, loc="left")
                # amino-acid labels, centred in their cells
                for i in range(s, e):
                    ax.text(i + 1, 0.5, residues[i], ha="center", va="center", fontsize=6, color="black")
                # hot-region brackets, clipped to this row and spanning whole cells
                for region in hot_regions:
                    hs, he = int(region["start"]), int(region["end"])
                    if he < s + 1 or hs > e:
                        continue
                    bs = max(hs, s + 1) - 0.5
                    be = min(he, e) + 0.5
                    y = -0.12
                    ax.plot([bs, be], [y, y], color="black", linewidth=1.5)
                    ax.plot([bs, bs], [y, y + 0.08], color="black", linewidth=1.5)
                    ax.plot([be, be], [y, y + 0.08], color="black", linewidth=1.5)
                # known signal track, clipped to this row
                y0, h = -0.38, 0.14
                for sig_name, payload in signals.items():
                    reg = payload.get("region")
                    if not payload.get("detected") or reg is None:
                        continue
                    ss, se = int(reg[0]), int(reg[1])
                    if se < s + 1 or ss > e:
                        continue
                    ds = max(ss, s + 1)
                    de = min(se, e)
                    rect = patches.Rectangle(
                        (ds - 0.5, y0), de - ds + 1, h, color=signal_colors[sig_name], alpha=0.85
                    )
                    ax.add_patch(rect)
                    ax.text(ds - 0.5, y0 - 0.03, sig_name.replace("_", " "), fontsize=6, va="top", ha="left")
            mappable = plt.cm.ScalarMappable(cmap=cmap)
            mappable.set_clim(-vmax, vmax)
            cbar = fig.colorbar(mappable, ax=axes_arr.tolist(), fraction=0.02, pad=0.02)
            cbar.set_label("Attribution score")
            fig.suptitle("Residue-level attribution map", fontsize=13, fontweight="bold")
            fig.tight_layout()
            fig.savefig(out_path, dpi=150, bbox_inches="tight")
        finally:
            # Always release the figure, even if drawing or saving fails.
            plt.close(fig)