""" Residue-level interpretation utilities for protein localization predictions. This module uses the full ESM encoder (not precomputed embeddings) so gradients can flow from classifier outputs back to token embeddings. """ from __future__ import annotations from dataclasses import dataclass from pathlib import Path from typing import Any, Dict, List, Optional, Sequence, Tuple, Type import contextlib import importlib import io import logging import matplotlib.pyplot as plt import numpy as np import torch from matplotlib import patches from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase from src.utils.device import resolve_torch_device from .classifier import ProteinLocalizationClassifier def _load_integrated_gradients() -> Optional[Type[Any]]: """Load Captum lazily so static analysis does not require ``captum`` to be installed.""" try: mod = importlib.import_module("captum.attr") return getattr(mod, "IntegratedGradients") except ImportError: # pragma: no cover return None IntegratedGradients = _load_integrated_gradients() HYDROPHOBIC_AA = set("AVILMFWP") MTP_ENRICHED_AA = set("RSLA") POSITIVE_AA = set("KR") @dataclass class _SignalDetection: detected: bool region: Tuple[int, int] | None overlap_with_attribution: bool class ProteinInterpreter: def __init__( self, classifier_path: str | Path, esm_model_name: str = "facebook/esm2_t33_650M_UR50D", device: str | torch.device | None = None, ) -> None: self.device = resolve_torch_device(device) self.classifier_path = Path(classifier_path).expanduser().resolve() if not self.classifier_path.is_file(): raise FileNotFoundError(f"Missing classifier checkpoint: {self.classifier_path}") checkpoint = torch.load(self.classifier_path, map_location="cpu") if not isinstance(checkpoint, dict): raise ValueError("Unsupported classifier checkpoint format.") state_dict = checkpoint.get("state_dict", checkpoint) embedding_dim = int(checkpoint.get("embedding_dim", 1280)) num_labels = int(checkpoint.get("num_labels", 11)) label_names = checkpoint.get("label_names") dropout_rates = tuple(checkpoint.get("dropout_rates", [0.3, 0.3, 0.2])) hidden_dims = tuple(checkpoint.get("hidden_dims", [512, 256, 128])) self.classifier = ProteinLocalizationClassifier( embedding_dim=embedding_dim, num_labels=num_labels, label_names=label_names, dropout_rates=dropout_rates, hidden_dims=hidden_dims, ) self.classifier.load_state_dict(state_dict) self.classifier.to(self.device).eval() self.label_names = list(self.classifier.label_names) self.label_to_idx = {n: i for i, n in enumerate(self.label_names)} self.tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(esm_model_name) hf_logger = logging.getLogger("transformers.modeling_utils") prev_level = hf_logger.level hf_logger.setLevel(logging.ERROR) try: with contextlib.redirect_stderr(io.StringIO()): self.esm_model = AutoModel.from_pretrained( esm_model_name, attn_implementation="eager", ignore_mismatched_sizes=True, ) finally: hf_logger.setLevel(prev_level) self.esm_model.to(self.device).eval() self.esm_model_name = esm_model_name @property def esm_encoder(self) -> PreTrainedModel: """ESM backbone (``AutoModel``) used for IG and shared with :class:`ProteinRelocalizer` embeddings.""" return self.esm_model @property def esm_tokenizer(self) -> PreTrainedTokenizerBase: """Tokenizer paired with ``esm_encoder``.""" return self.tokenizer def mean_pool_embedding(self, sequence: str) -> np.ndarray: """ Mean-pooled ESM representation for one sequence (same pooling as training embeddings). Runs under ``torch.no_grad()``. """ seq = sequence.upper().strip() if not seq: raise ValueError("Empty sequence") toks = self._tokenize(seq) with torch.no_grad(): out = self.esm_model(**toks, return_dict=True) pooled = out.last_hidden_state.mean(dim=1) return pooled.detach().cpu().numpy().astype(np.float32).squeeze(0) def _tokenize(self, sequence: str) -> Dict[str, torch.Tensor]: if not sequence or not isinstance(sequence, str): raise ValueError("sequence must be a non-empty string") toks = self.tokenizer( sequence, return_tensors="pt", add_special_tokens=True, truncation=True, ) return {k: v.to(self.device) for k, v in toks.items()} def _trim_special_tokens(self, token_scores: np.ndarray, sequence: str) -> List[Tuple[int, str, float]]: # ESM tokenization with add_special_tokens=True yields [CLS] + residues + [EOS]. if token_scores.ndim != 1: raise ValueError("Expected 1D token score array.") if token_scores.shape[0] >= len(sequence) + 2: core = token_scores[1 : 1 + len(sequence)] else: # Best-effort fallback if tokenizer behavior differs. core = token_scores[: len(sequence)] return [(i + 1, aa, float(core[i])) for i, aa in enumerate(sequence[: len(core)])] @staticmethod def _clear_cuda_cache() -> None: if torch.cuda.is_available(): torch.cuda.empty_cache() def get_attention_scores(self, sequence: str) -> Dict[str, Any]: try: toks = self._tokenize(sequence) with torch.no_grad(): outputs = self.esm_model( input_ids=toks["input_ids"], attention_mask=toks.get("attention_mask"), output_attentions=True, return_dict=True, ) if outputs.attentions is None or len(outputs.attentions) == 0: raise RuntimeError("ESM model did not return attentions.") # last layer: [batch, heads, tokens, tokens] last_attn = outputs.attentions[-1][0] mean_attn = last_attn.mean(dim=0) # [tokens, tokens], averaged over heads token_importance = mean_attn.mean(dim=0).detach().cpu().numpy() residue_scores = self._trim_special_tokens(token_importance, sequence) return { "residue_scores": residue_scores, "raw_attention": mean_attn.detach().cpu().numpy(), } finally: self._clear_cuda_cache() def get_integrated_gradients(self, sequence: str, target_location: str) -> Dict[str, Any]: if IntegratedGradients is None: raise ImportError("captum is required for integrated gradients. Install with `pip install captum`.") if target_location not in self.label_to_idx: raise ValueError(f"Unknown target_location={target_location!r}. Known labels: {self.label_names}") target_idx = int(self.label_to_idx[target_location]) try: toks = self._tokenize(sequence) input_ids = toks["input_ids"] attention_mask = toks.get("attention_mask") input_embed = self.esm_model.get_input_embeddings()(input_ids) baseline_embed = torch.zeros_like(input_embed) def forward_func(inputs_embeds: torch.Tensor, attn_mask: torch.Tensor, cls_idx: int) -> torch.Tensor: enc = self.esm_model( inputs_embeds=inputs_embeds, attention_mask=attn_mask, return_dict=True, ) pooled = enc.last_hidden_state.mean(dim=1) logits = self.classifier(pooled) return logits[:, cls_idx] ig = IntegratedGradients(forward_func) attributions = ig.attribute( input_embed, baselines=baseline_embed, additional_forward_args=(attention_mask, target_idx), n_steps=32, internal_batch_size=1, ) token_scores = attributions.sum(dim=-1).squeeze(0).detach().cpu().numpy() residue_scores = self._trim_special_tokens(token_scores, sequence) return { "residue_scores": residue_scores, "raw_attributions": attributions.detach().cpu().numpy(), } finally: self._clear_cuda_cache() def identify_hot_regions( self, residue_scores: Sequence[Tuple[int, str, float]], window_size: int = 10, top_percentile: float = 90.0, ) -> List[Dict[str, Any]]: if len(residue_scores) == 0: return [] window_size = max(1, int(window_size)) vals = np.asarray([float(x[2]) for x in residue_scores], dtype=np.float64) seq = "".join([str(x[1]) for x in residue_scores]) if vals.size < window_size: avg = float(vals.mean()) return [{"start": 1, "end": int(vals.size), "avg_score": avg, "subsequence": seq}] win_avgs = np.asarray( [float(vals[i : i + window_size].mean()) for i in range(0, vals.size - window_size + 1)], dtype=np.float64, ) thr = float(np.percentile(win_avgs, top_percentile)) hot_starts = np.where(win_avgs >= thr)[0].tolist() if not hot_starts: return [] regions: List[Dict[str, Any]] = [] run_start = hot_starts[0] prev = hot_starts[0] for s in hot_starts[1:]: if s <= prev + 1: prev = s continue start = run_start + 1 end = min(prev + window_size, len(seq)) avg_score = float(vals[start - 1 : end].mean()) regions.append({"start": start, "end": end, "avg_score": avg_score, "subsequence": seq[start - 1 : end]}) run_start = s prev = s start = run_start + 1 end = min(prev + window_size, len(seq)) avg_score = float(vals[start - 1 : end].mean()) regions.append({"start": start, "end": end, "avg_score": avg_score, "subsequence": seq[start - 1 : end]}) return regions def validate_against_known_signals(self, sequence: str, hot_regions: Sequence[Dict[str, Any]]) -> Dict[str, Any]: seq = sequence.upper() n = len(seq) def _overlap(region: Tuple[int, int] | None) -> bool: if region is None: return False a1, a2 = region for r in hot_regions: b1, b2 = int(r["start"]), int(r["end"]) if max(a1, b1) <= min(a2, b2): return True return False # Signal peptide heuristic: hydrophobic region in residues 1-30. sp_region: Tuple[int, int] | None = None for w in range(8, 18): for i in range(0, max(0, min(30, n) - w + 1)): frag = seq[i : i + w] hyd_frac = sum(aa in HYDROPHOBIC_AA for aa in frag) / max(w, 1) if hyd_frac >= 0.6: sp_region = (i + 1, i + w) break if sp_region is not None: break # Mitochondrial transit peptide: first 10-70 enriched in R/S/L/A and positive residues. mtp_region: Tuple[int, int] | None = None mtp_end = min(70, n) if mtp_end >= 10: frag = seq[:mtp_end] rsla = sum(aa in MTP_ENRICHED_AA for aa in frag) / len(frag) pos = sum(aa in POSITIVE_AA for aa in frag) / len(frag) acidic = sum(aa in {"D", "E"} for aa in frag) / len(frag) if rsla >= 0.35 and pos >= 0.12 and acidic <= 0.12: mtp_region = (1, mtp_end) # NLS: >=4 K/R in any 7-mer. nls_region: Tuple[int, int] | None = None for i in range(0, max(0, n - 7 + 1)): frag = seq[i : i + 7] if sum(aa in POSITIVE_AA for aa in frag) >= 4: nls_region = (i + 1, i + 7) break # Transmembrane domain: 18-25 residues with strong hydrophobic content. tm_region: Tuple[int, int] | None = None for w in range(25, 17, -1): for i in range(0, max(0, n - w + 1)): frag = seq[i : i + w] hyd_frac = sum(aa in HYDROPHOBIC_AA for aa in frag) / max(w, 1) if hyd_frac >= 0.75: tm_region = (i + 1, i + w) break if tm_region is not None: break # ER retention signal: C-terminal KDEL/HDEL. er_region: Tuple[int, int] | None = None if n >= 4 and (seq.endswith("KDEL") or seq.endswith("HDEL")): er_region = (n - 3, n) out: Dict[str, _SignalDetection] = { "signal_peptide": _SignalDetection(sp_region is not None, sp_region, _overlap(sp_region)), "mitochondrial_transit_peptide": _SignalDetection(mtp_region is not None, mtp_region, _overlap(mtp_region)), "nuclear_localization_signal": _SignalDetection(nls_region is not None, nls_region, _overlap(nls_region)), "transmembrane_domain": _SignalDetection(tm_region is not None, tm_region, _overlap(tm_region)), "er_retention_signal": _SignalDetection(er_region is not None, er_region, _overlap(er_region)), } return { k: { "detected": v.detected, "region": v.region, "overlap_with_attribution": v.overlap_with_attribution, } for k, v in out.items() } def visualize_attribution( self, sequence: str, residue_scores: Sequence[Tuple[int, str, float]], hot_regions: Sequence[Dict[str, Any]], output_path: str | Path, ) -> None: if len(residue_scores) == 0: raise ValueError("residue_scores is empty") out_path = Path(output_path).expanduser().resolve() out_path.parent.mkdir(parents=True, exist_ok=True) values = np.asarray([float(x[2]) for x in residue_scores], dtype=np.float64) vmax = float(np.max(np.abs(values))) if values.size > 0 else 1.0 vmax = max(vmax, 1e-6) per_row = 50 n = len(residue_scores) rows = int(np.ceil(n / per_row)) signals = self.validate_against_known_signals(sequence, hot_regions) signal_colors = { "signal_peptide": "#2ca02c", "mitochondrial_transit_peptide": "#9467bd", "nuclear_localization_signal": "#ff7f0e", "transmembrane_domain": "#8c564b", "er_retention_signal": "#17becf", } fig, axes = plt.subplots(rows, 1, figsize=(16, max(2.3, rows * 2.4))) axes_arr = np.atleast_1d(axes) cmap = plt.get_cmap("coolwarm") for r in range(rows): ax = axes_arr[r] s = r * per_row e = min((r + 1) * per_row, n) seg_vals = values[s:e][None, :] ax.imshow(seg_vals, cmap=cmap, aspect="auto", vmin=-vmax, vmax=vmax, extent=(s + 1, e, 0.0, 1.0)) ax.set_yticks([]) ax.set_xlim(s + 1, e) ax.set_ylim(-0.55, 1.25) ax.set_xticks(np.arange(s + 1, e + 1, 5)) ax.set_xlabel("Residue position") ax.set_title(f"Residues {s + 1}-{e}", fontsize=10, loc="left") # amino-acid labels for i in range(s, e): ax.text(i + 1, 0.5, sequence[i], ha="center", va="center", fontsize=6, color="black") # hot-region brackets for region in hot_regions: hs, he = int(region["start"]), int(region["end"]) if he < s + 1 or hs > e: continue bs = max(hs, s + 1) be = min(he, e) y = -0.12 ax.plot([bs, be], [y, y], color="black", linewidth=1.5) ax.plot([bs, bs], [y, y + 0.08], color="black", linewidth=1.5) ax.plot([be, be], [y, y + 0.08], color="black", linewidth=1.5) # known signal track y0, h = -0.38, 0.14 for sig_name, payload in signals.items(): reg = payload.get("region") if not payload.get("detected") or reg is None: continue ss, se = int(reg[0]), int(reg[1]) if se < s + 1 or ss > e: continue ds = max(ss, s + 1) de = min(se, e) rect = patches.Rectangle((ds, y0), de - ds + 1, h, color=signal_colors[sig_name], alpha=0.85) ax.add_patch(rect) ax.text(ds, y0 - 0.03, sig_name.replace("_", " "), fontsize=6, va="top", ha="left") mappable = plt.cm.ScalarMappable(cmap=cmap) mappable.set_clim(-vmax, vmax) cbar = fig.colorbar(mappable, ax=axes_arr.tolist(), fraction=0.02, pad=0.02) cbar.set_label("Attribution score") fig.suptitle("Residue-level attribution map", fontsize=13, fontweight="bold") fig.tight_layout() fig.savefig(out_path, dpi=150, bbox_inches="tight") plt.close(fig)