Spaces:
Running
Running
| """ | |
| Residue-level interpretation utilities for protein localization predictions. | |
| This module uses the full ESM encoder (not precomputed embeddings) so gradients | |
| can flow from classifier outputs back to token embeddings. | |
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional, Sequence, Tuple, Type | |
| import contextlib | |
| import importlib | |
| import io | |
| import logging | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import torch | |
| from matplotlib import patches | |
| from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase | |
| from src.utils.device import resolve_torch_device | |
| from .classifier import ProteinLocalizationClassifier | |
| def _load_integrated_gradients() -> Optional[Type[Any]]: | |
| """Load Captum lazily so static analysis does not require ``captum`` to be installed.""" | |
| try: | |
| mod = importlib.import_module("captum.attr") | |
| return getattr(mod, "IntegratedGradients") | |
| except ImportError: # pragma: no cover | |
| return None | |
# Resolved once at import time; ``None`` when ``captum`` could not be imported.
IntegratedGradients = _load_integrated_gradients()
# One-letter residue codes counted as hydrophobic by the sequence heuristics.
HYDROPHOBIC_AA = {"A", "V", "I", "L", "M", "F", "W", "P"}
# Residues whose combined fraction is used by the mitochondrial transit peptide heuristic.
MTP_ENRICHED_AA = {"R", "S", "L", "A"}
# Positively charged residues (lysine, arginine).
POSITIVE_AA = {"K", "R"}
| class _SignalDetection: | |
| detected: bool | |
| region: Tuple[int, int] | None | |
| overlap_with_attribution: bool | |
| class ProteinInterpreter: | |
def __init__(
    self,
    classifier_path: str | Path,
    esm_model_name: str = "facebook/esm2_t33_650M_UR50D",
    device: str | torch.device | None = None,
) -> None:
    """Load the classifier checkpoint and the ESM encoder/tokenizer onto one device.

    Args:
        classifier_path: Path to the classifier checkpoint file. ``~`` is
            expanded and the path is resolved before use.
        esm_model_name: Hugging Face model identifier passed to
            ``AutoTokenizer`` / ``AutoModel``.
        device: Device specification forwarded to ``resolve_torch_device``.

    Raises:
        FileNotFoundError: If ``classifier_path`` is not an existing file.
        ValueError: If the loaded checkpoint is not a ``dict``.
    """
    self.device = resolve_torch_device(device)

    self.classifier_path = Path(classifier_path).expanduser().resolve()
    if not self.classifier_path.is_file():
        raise FileNotFoundError(f"Missing classifier checkpoint: {self.classifier_path}")

    # NOTE(review): ``torch.load`` unpickles the file; only load checkpoints
    # from trusted sources.
    payload = torch.load(self.classifier_path, map_location="cpu")
    if not isinstance(payload, dict):
        raise ValueError("Unsupported classifier checkpoint format.")

    # A checkpoint without a "state_dict" key is treated as the state dict itself.
    weights = payload.get("state_dict", payload)

    # Architecture hyper-parameters fall back to fixed defaults when absent.
    self.classifier = ProteinLocalizationClassifier(
        embedding_dim=int(payload.get("embedding_dim", 1280)),
        num_labels=int(payload.get("num_labels", 11)),
        label_names=payload.get("label_names"),
        dropout_rates=tuple(payload.get("dropout_rates", [0.3, 0.3, 0.2])),
        hidden_dims=tuple(payload.get("hidden_dims", [512, 256, 128])),
    )
    self.classifier.load_state_dict(weights)
    self.classifier.to(self.device).eval()

    self.label_names = list(self.classifier.label_names)
    self.label_to_idx = dict(zip(self.label_names, range(len(self.label_names))))

    self.tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(esm_model_name)

    # Quiet the model-loading logger and swallow stderr while the backbone
    # loads; the previous log level is always restored.
    loader_log = logging.getLogger("transformers.modeling_utils")
    saved_level = loader_log.level
    loader_log.setLevel(logging.ERROR)
    try:
        with contextlib.redirect_stderr(io.StringIO()):
            # Eager attention is requested explicitly; presumably so that
            # attention weights can be returned by the model.
            self.esm_model = AutoModel.from_pretrained(
                esm_model_name,
                attn_implementation="eager",
                ignore_mismatched_sizes=True,
            )
    finally:
        loader_log.setLevel(saved_level)

    self.esm_model.to(self.device).eval()
    self.esm_model_name = esm_model_name
def esm_encoder(self) -> PreTrainedModel:
    """Return the ESM backbone (``AutoModel``) held by this interpreter.

    NOTE(review): this reads like an accessor that may have been intended as a
    ``@property``; as written it must be called as a method -- confirm against callers.
    """
    encoder = self.esm_model
    return encoder
def esm_tokenizer(self) -> PreTrainedTokenizerBase:
    """Return the tokenizer that was loaded alongside the ESM backbone.

    NOTE(review): this reads like an accessor that may have been intended as a
    ``@property``; as written it must be called as a method -- confirm against callers.
    """
    tokenizer = self.tokenizer
    return tokenizer
def mean_pool_embedding(self, sequence: str) -> np.ndarray:
    """Return the mean-pooled ESM representation of one sequence as a 1D ``float32`` array.

    The sequence is upper-cased and stripped, tokenized with ``_tokenize`` and
    run through the encoder under ``torch.no_grad()``. The last hidden state is
    averaged over every token position the tokenizer produced (no masking).
    Per the original author's note, this is meant to match the pooling used for
    the training embeddings.

    Raises:
        ValueError: If the sequence is empty after stripping.
    """
    cleaned = sequence.upper().strip()
    if len(cleaned) == 0:
        raise ValueError("Empty sequence")

    encoded = self._tokenize(cleaned)
    with torch.no_grad():
        hidden = self.esm_model(**encoded, return_dict=True).last_hidden_state
        averaged = hidden.mean(dim=1)

    # Drop the leading batch axis (a single sequence is encoded at a time).
    vector = averaged.detach().cpu().numpy()
    return vector.astype(np.float32).squeeze(0)
| def _tokenize(self, sequence: str) -> Dict[str, torch.Tensor]: | |
| if not sequence or not isinstance(sequence, str): | |
| raise ValueError("sequence must be a non-empty string") | |
| toks = self.tokenizer( | |
| sequence, | |
| return_tensors="pt", | |
| add_special_tokens=True, | |
| truncation=True, | |
| ) | |
| return {k: v.to(self.device) for k, v in toks.items()} | |
| def _trim_special_tokens(self, token_scores: np.ndarray, sequence: str) -> List[Tuple[int, str, float]]: | |
| # ESM tokenization with add_special_tokens=True yields [CLS] + residues + [EOS]. | |
| if token_scores.ndim != 1: | |
| raise ValueError("Expected 1D token score array.") | |
| if token_scores.shape[0] >= len(sequence) + 2: | |
| core = token_scores[1 : 1 + len(sequence)] | |
| else: | |
| # Best-effort fallback if tokenizer behavior differs. | |
| core = token_scores[: len(sequence)] | |
| return [(i + 1, aa, float(core[i])) for i, aa in enumerate(sequence[: len(core)])] | |
| def _clear_cuda_cache() -> None: | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
def get_attention_scores(self, sequence: str) -> Dict[str, Any]:
    """Score residues by last-layer attention averaged over heads.

    Returns:
        A dict with ``"residue_scores"`` (output of ``_trim_special_tokens``)
        and ``"raw_attention"`` (the head-averaged ``[tokens, tokens]`` matrix
        as a NumPy array, special tokens included).

    Raises:
        RuntimeError: If the model returns no attention tensors.
    """
    try:
        encoded = self._tokenize(sequence)
        with torch.no_grad():
            result = self.esm_model(
                input_ids=encoded["input_ids"],
                attention_mask=encoded.get("attention_mask"),
                output_attentions=True,
                return_dict=True,
            )

        layers = result.attentions
        if layers is None or len(layers) == 0:
            raise RuntimeError("ESM model did not return attentions.")

        # Last layer is [batch, heads, tokens, tokens]; take the only batch item
        # and average across heads.
        head_avg = layers[-1][0].mean(dim=0)
        # Collapse the first token axis (presumably the query axis, giving the
        # average attention each token receives -- TODO confirm).
        per_token = head_avg.mean(dim=0).detach().cpu().numpy()

        return {
            "residue_scores": self._trim_special_tokens(per_token, sequence),
            "raw_attention": head_avg.detach().cpu().numpy(),
        }
    finally:
        # Always release cached GPU memory, even when the forward pass fails.
        self._clear_cuda_cache()
def get_integrated_gradients(self, sequence: str, target_location: str) -> Dict[str, Any]:
    """Attribute one class logit to residues with Integrated Gradients over input embeddings.

    The attribution runs from an all-zero embedding baseline to the sequence's
    token embeddings, through the ESM encoder, mean pooling and the classifier.

    Args:
        sequence: Amino-acid sequence to explain.
        target_location: Label name whose logit is attributed; must be one of
            ``self.label_names``.

    Returns:
        A dict with ``"residue_scores"`` (per-token attributions summed over the
        embedding dimension, then passed through ``_trim_special_tokens``) and
        ``"raw_attributions"`` (the full attribution tensor as a NumPy array).

    Raises:
        ImportError: If ``captum`` is not installed.
        ValueError: If ``target_location`` is not a known label.
    """
    if IntegratedGradients is None:
        raise ImportError("captum is required for integrated gradients. Install with `pip install captum`.")
    if target_location not in self.label_to_idx:
        raise ValueError(f"Unknown target_location={target_location!r}. Known labels: {self.label_names}")
    class_index = int(self.label_to_idx[target_location])

    try:
        encoded = self._tokenize(sequence)
        mask = encoded.get("attention_mask")
        embedded = self.esm_model.get_input_embeddings()(encoded["input_ids"])

        def _target_logit(embeds: torch.Tensor, attn: torch.Tensor, label_idx: int) -> torch.Tensor:
            # Same path as prediction: encoder -> mean over tokens -> classifier.
            hidden = self.esm_model(
                inputs_embeds=embeds,
                attention_mask=attn,
                return_dict=True,
            ).last_hidden_state
            return self.classifier(hidden.mean(dim=1))[:, label_idx]

        explainer = IntegratedGradients(_target_logit)
        attributions = explainer.attribute(
            embedded,
            baselines=torch.zeros_like(embedded),
            additional_forward_args=(mask, class_index),
            n_steps=32,
            # One interpolation step per forward pass.
            internal_batch_size=1,
        )

        # Sum over the embedding dimension to get one score per token.
        per_token = attributions.sum(dim=-1).squeeze(0).detach().cpu().numpy()
        return {
            "residue_scores": self._trim_special_tokens(per_token, sequence),
            "raw_attributions": attributions.detach().cpu().numpy(),
        }
    finally:
        self._clear_cuda_cache()
def identify_hot_regions(
    self,
    residue_scores: Sequence[Tuple[int, str, float]],
    window_size: int = 10,
    top_percentile: float = 90.0,
) -> List[Dict[str, Any]]:
    """Find stretches of residues whose sliding-window mean score is in the top percentile.

    Every window of ``window_size`` consecutive residues is averaged; windows
    whose mean reaches the ``top_percentile`` percentile of all window means are
    "hot". Hot windows with consecutive start indices are merged into a region.

    Args:
        residue_scores: ``(position, residue, score)`` tuples in sequence order.
        window_size: Window length in residues (values below 1 are treated as 1).
        top_percentile: Percentile (0-100) of window means used as the threshold.

    Returns:
        A list of dicts with 1-based inclusive ``"start"`` / ``"end"``, the mean
        residue score over the region (``"avg_score"``) and its ``"subsequence"``.
        An empty input gives ``[]``; an input shorter than the window gives a
        single region covering everything.
    """
    if not len(residue_scores):
        return []

    width = max(1, int(window_size))
    scores = np.asarray([float(item[2]) for item in residue_scores], dtype=np.float64)
    letters = "".join(str(item[1]) for item in residue_scores)
    total = scores.size

    def _region(first_start: int, last_start: int) -> Dict[str, Any]:
        # Convert a run of 0-based window starts into a 1-based inclusive region.
        lo = first_start + 1
        hi = min(last_start + width, len(letters))
        return {
            "start": lo,
            "end": hi,
            "avg_score": float(scores[lo - 1 : hi].mean()),
            "subsequence": letters[lo - 1 : hi],
        }

    if total < width:
        # Too short for even one window: report the whole input.
        return [{"start": 1, "end": int(total), "avg_score": float(scores.mean()), "subsequence": letters}]

    window_means = np.asarray(
        [float(scores[k : k + width].mean()) for k in range(total - width + 1)],
        dtype=np.float64,
    )
    cutoff = float(np.percentile(window_means, top_percentile))
    hot = np.flatnonzero(window_means >= cutoff).tolist()
    if not hot:
        return []

    # Group consecutive window starts into (first, last) runs.
    runs: List[Tuple[int, int]] = []
    run_first = run_last = hot[0]
    for idx in hot[1:]:
        if idx > run_last + 1:
            runs.append((run_first, run_last))
            run_first = idx
        run_last = idx
    runs.append((run_first, run_last))

    return [_region(first, last) for first, last in runs]
def validate_against_known_signals(self, sequence: str, hot_regions: Sequence[Dict[str, Any]]) -> Dict[str, Any]:
    """Run simple sequence heuristics for known targeting signals and compare with hot regions.

    Args:
        sequence: Amino-acid sequence (upper-cased internally).
        hot_regions: Region dicts with 1-based inclusive ``"start"`` / ``"end"``.

    Returns:
        A dict keyed by signal name (``"signal_peptide"``,
        ``"mitochondrial_transit_peptide"``, ``"nuclear_localization_signal"``,
        ``"transmembrane_domain"``, ``"er_retention_signal"``). Each value holds
        ``"detected"``, ``"region"`` (1-based inclusive tuple or ``None``) and
        ``"overlap_with_attribution"`` (whether the region intersects any hot region).
    """
    seq = sequence.upper()
    n = len(seq)

    def _overlap(region: Tuple[int, int] | None) -> bool:
        # Two inclusive intervals intersect when the larger start is <= the smaller end.
        if region is None:
            return False
        lo, hi = region
        return any(
            max(lo, int(hot["start"])) <= min(hi, int(hot["end"]))
            for hot in hot_regions
        )

    def _first_hydrophobic_window(widths: Sequence[int], span: int, min_fraction: float) -> Tuple[int, int] | None:
        # Scan widths in the given order, offsets left to right within the first
        # ``span`` residues, and return the first window that is hydrophobic enough.
        for width in widths:
            for offset in range(max(0, span - width + 1)):
                window = seq[offset : offset + width]
                fraction = sum(aa in HYDROPHOBIC_AA for aa in window) / max(width, 1)
                if fraction >= min_fraction:
                    return (offset + 1, offset + width)
        return None

    # Signal peptide heuristic: hydrophobic region in residues 1-30.
    sp_region = _first_hydrophobic_window(range(8, 18), min(30, n), 0.6)

    # Mitochondrial transit peptide: first 10-70 enriched in R/S/L/A and positive residues.
    mtp_region: Tuple[int, int] | None = None
    mtp_end = min(70, n)
    if mtp_end >= 10:
        head = seq[:mtp_end]
        head_len = len(head)
        rsla_frac = sum(aa in MTP_ENRICHED_AA for aa in head) / head_len
        positive_frac = sum(aa in POSITIVE_AA for aa in head) / head_len
        acidic_frac = sum(aa in {"D", "E"} for aa in head) / head_len
        if rsla_frac >= 0.35 and positive_frac >= 0.12 and acidic_frac <= 0.12:
            mtp_region = (1, mtp_end)

    # NLS: >=4 K/R in any 7-mer (first match wins).
    nls_region: Tuple[int, int] | None = next(
        (
            (offset + 1, offset + 7)
            for offset in range(max(0, n - 7 + 1))
            if sum(aa in POSITIVE_AA for aa in seq[offset : offset + 7]) >= 4
        ),
        None,
    )

    # Transmembrane domain: 18-25 residues with strong hydrophobic content,
    # trying the widest window first.
    tm_region = _first_hydrophobic_window(range(25, 17, -1), n, 0.75)

    # ER retention signal: C-terminal KDEL/HDEL.
    er_region: Tuple[int, int] | None = None
    if n >= 4 and (seq.endswith("KDEL") or seq.endswith("HDEL")):
        er_region = (n - 3, n)

    named_regions = [
        ("signal_peptide", sp_region),
        ("mitochondrial_transit_peptide", mtp_region),
        ("nuclear_localization_signal", nls_region),
        ("transmembrane_domain", tm_region),
        ("er_retention_signal", er_region),
    ]
    detections: Dict[str, _SignalDetection] = {
        name: _SignalDetection(region is not None, region, _overlap(region))
        for name, region in named_regions
    }

    report: Dict[str, Any] = {}
    for name, found in detections.items():
        report[name] = {
            "detected": found.detected,
            "region": found.region,
            "overlap_with_attribution": found.overlap_with_attribution,
        }
    return report
def visualize_attribution(
    self,
    sequence: str,
    residue_scores: Sequence[Tuple[int, str, float]],
    hot_regions: Sequence[Dict[str, Any]],
    output_path: str | Path,
) -> None:
    """Render a per-residue attribution heat map (50 residues per row) and save it.

    Each row shows the scores as a diverging colour strip with residue letters,
    black brackets under hot regions and coloured bars for the heuristic signals
    returned by ``validate_against_known_signals``.

    Args:
        sequence: Amino-acid sequence used for residue labels and signal heuristics.
        residue_scores: ``(position, residue, score)`` tuples in sequence order.
        hot_regions: Region dicts with 1-based inclusive ``"start"`` / ``"end"``.
        output_path: Image file to write; parent directories are created.

    Raises:
        ValueError: If ``residue_scores`` is empty.
    """
    if len(residue_scores) == 0:
        raise ValueError("residue_scores is empty")

    out_path = Path(output_path).expanduser().resolve()
    out_path.parent.mkdir(parents=True, exist_ok=True)

    values = np.asarray([float(x[2]) for x in residue_scores], dtype=np.float64)
    # Symmetric colour range around zero; the floor avoids a degenerate range
    # when every score is zero.
    vmax = max(float(np.max(np.abs(values))), 1e-6)

    per_row = 50
    n = len(residue_scores)
    rows = int(np.ceil(n / per_row))
    signals = self.validate_against_known_signals(sequence, hot_regions)
    signal_colors = {
        "signal_peptide": "#2ca02c",
        "mitochondrial_transit_peptide": "#9467bd",
        "nuclear_localization_signal": "#ff7f0e",
        "transmembrane_domain": "#8c564b",
        "er_retention_signal": "#17becf",
    }

    fig, axes = plt.subplots(rows, 1, figsize=(16, max(2.3, rows * 2.4)))
    try:
        axes_arr = np.atleast_1d(axes)
        cmap = plt.get_cmap("coolwarm")

        for r in range(rows):
            ax = axes_arr[r]
            s = r * per_row
            e = min((r + 1) * per_row, n)
            # Cell k of this row is centred on residue position s + 1 + k, so the
            # image spans half a unit beyond the first and last position. This
            # keeps labels on their cells and gives a one-residue row real width.
            left, right = s + 0.5, e + 0.5
            seg_vals = values[s:e][None, :]
            ax.imshow(seg_vals, cmap=cmap, aspect="auto", vmin=-vmax, vmax=vmax, extent=(left, right, 0.0, 1.0))
            ax.set_yticks([])
            ax.set_xlim(left, right)
            ax.set_ylim(-0.55, 1.25)
            ax.set_xticks(np.arange(s + 1, e + 1, 5))
            ax.set_xlabel("Residue position")
            ax.set_title(f"Residues {s + 1}-{e}", fontsize=10, loc="left")

            # amino-acid labels; fall back to the residue stored with the score
            # when ``sequence`` is shorter than ``residue_scores``.
            for i in range(s, e):
                label = sequence[i] if i < len(sequence) else str(residue_scores[i][1])
                ax.text(i + 1, 0.5, label, ha="center", va="center", fontsize=6, color="black")

            # hot-region brackets, clipped to this row and drawn edge to edge so a
            # single-residue region is still visible.
            for region in hot_regions:
                hs, he = int(region["start"]), int(region["end"])
                if he < s + 1 or hs > e:
                    continue
                bs = max(hs, s + 1) - 0.5
                be = min(he, e) + 0.5
                y = -0.12
                ax.plot([bs, be], [y, y], color="black", linewidth=1.5)
                ax.plot([bs, bs], [y, y + 0.08], color="black", linewidth=1.5)
                ax.plot([be, be], [y, y + 0.08], color="black", linewidth=1.5)

            # known signal track
            y0, h = -0.38, 0.14
            for sig_name, payload in signals.items():
                reg = payload.get("region")
                if not payload.get("detected") or reg is None:
                    continue
                ss, se = int(reg[0]), int(reg[1])
                if se < s + 1 or ss > e:
                    continue
                ds = max(ss, s + 1)
                de = min(se, e)
                # Bar covers the cells of residues ds..de inclusive.
                rect = patches.Rectangle(
                    (ds - 0.5, y0), de - ds + 1, h, color=signal_colors[sig_name], alpha=0.85
                )
                ax.add_patch(rect)
                ax.text(ds - 0.5, y0 - 0.03, sig_name.replace("_", " "), fontsize=6, va="top", ha="left")

        mappable = plt.cm.ScalarMappable(cmap=cmap)
        mappable.set_clim(-vmax, vmax)
        cbar = fig.colorbar(mappable, ax=axes_arr.tolist(), fraction=0.02, pad=0.02)
        cbar.set_label("Attribution score")
        fig.suptitle("Residue-level attribution map", fontsize=13, fontweight="bold")
        fig.tight_layout()
        fig.savefig(out_path, dpi=150, bbox_inches="tight")
    finally:
        # Close the figure even when plotting or saving fails, so pyplot does not
        # keep it alive.
        plt.close(fig)