Spaces:
Runtime error
Runtime error
| """ | |
| explainability_engine.py | |
| ======================== | |
| Extract ALL internal explainability signals from each of the three models. | |
| No signal is simplified or omitted. | |
| Splice model signals: | |
| - probability | |
| - conv3 activation norm vector (99,) | |
| - mutation-centered activation peak | |
| - splice aura distance (donor / acceptor) | |
| - counterfactual delta (all alternative bases) | |
| - feature ablation response (splice / region / mutation groups) | |
| - risk tier classification | |
| V4 model signals: | |
| - probability | |
| - importance head vector (via conv3 hook — identical architecture) | |
| - mutation-centered importance density | |
| Classic model signals: | |
| - probability | |
| - importance head output (scalar) | |
| - region importance (exon / intron) | |
| - conv3 activation norm vector (99,) | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| from dataclasses import dataclass, field | |
| from typing import Optional | |
| import numpy as np | |
| import torch | |
| from model_loader import ( | |
| MutationPredictorCNN_v2, | |
| MutationPredictorCNN_v4, | |
| MutationPredictorClassic, | |
| ModelRegistry, | |
| encode_for_v2, | |
| encode_for_v4, | |
| find_mutation_pos, | |
| ALL_BASES, | |
| MUT_TYPES, | |
| ) | |
| logger = logging.getLogger("mutation_xai.xai") | |
| # ───────────────────────────────────────────────────────────────────────────── | |
| # Shared helpers | |
| # ───────────────────────────────────────────────────────────────────────────── | |
| def _conv3_activation_norm(model: torch.nn.Module, x: torch.Tensor, | |
| forward_fn) -> np.ndarray: | |
| """ | |
| Register a forward hook on model.conv3, run forward_fn(x), return | |
| L2-normalised per-position activation norm vector of shape (99,). | |
| """ | |
| activations: dict = {} | |
| def _hook(module, inp, out): | |
| activations["conv3"] = out.detach() | |
| hook = model.conv3.register_forward_hook(_hook) | |
| try: | |
| with torch.no_grad(): | |
| forward_fn(x) | |
| finally: | |
| hook.remove() | |
| act = activations.get("conv3") | |
| if act is None: | |
| return np.zeros(99) | |
| # act shape: (1, 256, 99) | |
| norm = act.squeeze(0).norm(dim=0).numpy() # (99,) | |
| if norm.max() > 0: | |
| norm = norm / norm.max() | |
| return norm | |
| def _gradient_attribution(model: torch.nn.Module, enc: torch.Tensor, | |
| forward_fn_grad) -> np.ndarray: | |
| """ | |
| Compute input-gradient attribution for the sequence portion. | |
| Returns normalised per-position attribution of shape (99,). | |
| """ | |
| model.eval() | |
| enc_leaf = enc.clone().detach().requires_grad_(True) | |
| logit = forward_fn_grad(enc_leaf) | |
| model.zero_grad() | |
| logit.backward() | |
| grad = enc_leaf.grad | |
| if grad is None: | |
| return np.zeros(99) | |
| seq_grad = grad[:1089].view(99, 11) | |
| attr = seq_grad.abs().norm(dim=1).detach().numpy() | |
| if attr.max() > 0: | |
| attr = attr / attr.max() | |
| return attr | |
| def _mutation_peak_ratio(profile: np.ndarray, mutation_pos: int) -> float: | |
| """ | |
| peak_signal / mean_signal, where peak_signal is the profile value at | |
| mutation_pos. Returns 0.0 if mutation_pos < 0 or mean == 0. | |
| """ | |
| if mutation_pos < 0 or mutation_pos >= len(profile): | |
| return 0.0 | |
| mean_sig = float(profile.mean()) | |
| if mean_sig == 0: | |
| return 0.0 | |
| return float(profile[mutation_pos]) / mean_sig | |
| def _signal_concentration_index(profile: np.ndarray, mutation_pos: int, | |
| window: int = 10) -> float: | |
| """ | |
| Fraction of total activation energy within Β±window of mutation_pos. | |
| Ranges 0β1; 1.0 = perfectly concentrated. | |
| """ | |
| if mutation_pos < 0: | |
| return 0.0 | |
| total = float(profile.sum()) | |
| if total == 0: | |
| return 0.0 | |
| lo = max(0, mutation_pos - window) | |
| hi = min(len(profile), mutation_pos + window + 1) | |
| local = float(profile[lo:hi].sum()) | |
| return local / total | |
| def _splice_distances(ref_seq: str, mutation_pos: int): | |
| """ | |
| Scan ref_seq for GT (donor) and AG (acceptor) dinucleotides. | |
| Returns (dist_donor, dist_acceptor, nearest_donor_pos, nearest_acceptor_pos). | |
| Any value may be None if no site found. | |
| """ | |
| seq = (ref_seq.upper() + "N" * 99)[:99] | |
| donors, acceptors = [], [] | |
| for i in range(len(seq) - 1): | |
| if seq[i:i+2] == "GT": donors.append(i) | |
| if seq[i:i+2] == "AG": acceptors.append(i) | |
| if mutation_pos < 0: | |
| return None, None, None, None | |
| dist_d = nearest_d = None | |
| dist_a = nearest_a = None | |
| if donors: | |
| pairs = sorted([(abs(mutation_pos - p), p) for p in donors]) | |
| dist_d, nearest_d = pairs[0] | |
| if acceptors: | |
| pairs = sorted([(abs(mutation_pos - p), p) for p in acceptors]) | |
| dist_a, nearest_a = pairs[0] | |
| return dist_d, dist_a, nearest_d, nearest_a | |
| def _classify_splice_risk(distance: Optional[int]) -> str: | |
| if distance is None: return "UNKNOWN" | |
| if distance <= 2: return "CRITICAL SPLICE SITE" | |
| if distance <= 8: return "SPLICE REGION" | |
| return "NON-SPLICE" | |
| def _classify_risk_tier(prob: float) -> tuple[str, str]: | |
| if prob >= 0.90: return "PATHOGENIC", "Very high confidence" | |
| if prob >= 0.70: return "LIKELY PATHOGENIC", "High confidence" | |
| if prob >= 0.50: return "POSSIBLY PATHOGENIC", "Moderate confidence" | |
| if prob >= 0.20: return "LIKELY BENIGN", "Low pathogenic signal" | |
| return "BENIGN", "Very low pathogenic signal" | |
| # ───────────────────────────────────────────────────────────────────────────── | |
| # Signal dataclasses | |
| # ───────────────────────────────────────────────────────────────────────────── | |
@dataclass
class SpliceSignals:
    """
    Container for every explainability signal extracted from the splice model.

    Declared as a dataclass so it can be built with keyword arguments, which is
    how ``extract_splice_signals`` constructs it.
    """
    probability: float                 # sigmoid of the model logit
    risk_tier: str                     # tier label from _classify_risk_tier
    tier_desc: str                     # description accompanying risk_tier
    conv3_norm: np.ndarray             # (99,)
    gradient_attribution: np.ndarray   # (99,)
    mutation_pos: int                  # value returned by find_mutation_pos
    mutation_peak_ratio: float
    signal_concentration: float
    imp_score: float                   # importance_head output
    region_imp: np.ndarray             # (2,) [exon, intron]
    splice_imp: np.ndarray             # (3,) [donor, acc, region]
    dist_donor: Optional[int]          # None when no donor site was found
    dist_acceptor: Optional[int]       # None when no acceptor site was found
    nearest_donor: Optional[int]
    nearest_acceptor: Optional[int]
    splice_risk_donor: str
    splice_risk_acceptor: str
    counterfactual: dict               # all-base CF results
    ablation: dict                     # feature ablation deltas
    splice_aura_score: float           # proximity-weighted splice signal
@dataclass
class V4Signals:
    """
    Container for the explainability signals extracted from the V4 model.

    Declared as a dataclass so it can be built with keyword arguments, which is
    how ``extract_v4_signals`` constructs it.
    """
    probability: float                 # sigmoid of the model logit
    conv3_norm: np.ndarray             # (99,)
    gradient_attribution: np.ndarray   # (99,)
    mutation_pos: int                  # value returned by find_mutation_pos
    mutation_peak_ratio: float
    signal_concentration: float
@dataclass
class ClassicSignals:
    """
    Container for the explainability signals extracted from the Classic model.

    Declared as a dataclass so it can be built with keyword arguments, which is
    how ``extract_classic_signals`` constructs it.
    """
    probability: float           # sigmoid of the model logit
    conv3_norm: np.ndarray       # (99,)
    importance_head: float       # scalar importance_head output
    region_imp: np.ndarray       # (2,) [exon, intron]
    mutation_pos: int            # value returned by find_mutation_pos
    mutation_peak_ratio: float
    signal_concentration: float
| # ───────────────────────────────────────────────────────────────────────────── | |
| # ① Extract Splice Signals | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def extract_splice_signals(model: MutationPredictorCNN_v2,
                           ref_seq: str, mut_seq: str,
                           exon_flag: int, intron_flag: int) -> SpliceSignals:
    """
    Run the splice model on one variant and collect all of its
    explainability signals.

    Steps: base forward pass (probability plus the three auxiliary heads),
    conv3 activation profile, input-gradient attribution, mutation-centred
    metrics, donor/acceptor distances, splice aura score, counterfactual
    table and feature ablation.

    Args:
        model: Splice model; called with a batch of one flat encoding and
            expected to return ``(logit, importance, region_imp, splice_imp)``.
        ref_seq: Reference sequence.
        mut_seq: Mutated sequence.
        exon_flag: Exon flag forwarded to ``encode_for_v2``.
        intron_flag: Intron flag forwarded to ``encode_for_v2``.

    Returns:
        A populated ``SpliceSignals`` instance.

    Side effects:
        Leaves ``model`` in eval mode.
    """
    # _gradient_attribution switches the model to eval mode further down;
    # doing it up front makes every pass in this function use the same mode.
    model.eval()

    enc = encode_for_v2(ref_seq, mut_seq, exon_flag, intron_flag)

    def _forward(flat: torch.Tensor):
        # The model consumes a batch, so add the batch dimension here.
        return model(flat.unsqueeze(0))

    # -- base forward pass ---------------------------------------------------
    with torch.no_grad():
        logit, imp_t, r_imp_t, s_imp_t = _forward(enc)
        prob = float(torch.sigmoid(logit).item())
        imp_score = float(imp_t.item())
        # .cpu() is a no-op on CPU tensors and keeps .numpy() device-safe.
        region_imp = r_imp_t[0].detach().cpu().numpy()
        splice_imp = s_imp_t[0].detach().cpu().numpy()

    tier, tier_desc = _classify_risk_tier(prob)
    mutation_pos = find_mutation_pos(ref_seq, mut_seq)

    # -- conv3 activation norm -----------------------------------------------
    conv3_norm = _conv3_activation_norm(model, enc, _forward)

    # -- gradient attribution ------------------------------------------------
    def _forward_logit(leaf: torch.Tensor):
        logit_g, _, _, _ = _forward(leaf)
        return logit_g

    grad_attr = _gradient_attribution(model, enc, _forward_logit)

    # -- mutation-peak derived metrics ---------------------------------------
    mpr = _mutation_peak_ratio(conv3_norm, mutation_pos)
    sci = _signal_concentration_index(conv3_norm, mutation_pos)

    # -- splice distances ----------------------------------------------------
    dist_d, dist_a, nearest_d, nearest_a = _splice_distances(ref_seq, mutation_pos)
    risk_d = _classify_splice_risk(dist_d)
    risk_a = _classify_splice_risk(dist_a)

    # -- splice aura score: proximity-weighted composite ---------------------
    def _proximity_weight(dist):
        # Same distance bands as _classify_splice_risk.
        if dist is None:
            return 0.0
        if dist <= 2:
            return 1.0
        if dist <= 8:
            return 0.5
        return 0.1

    aura = (
        _proximity_weight(dist_d) * float(splice_imp[0]) +
        _proximity_weight(dist_a) * float(splice_imp[1]) +
        float(splice_imp[2]) * 0.3
    ) / 1.6  # rough normalisation; the clip below enforces [0, 1]
    aura = float(np.clip(aura, 0.0, 1.0))

    # -- counterfactual analysis ---------------------------------------------
    cf = _counterfactual_splice(model, ref_seq, mut_seq, mutation_pos,
                                exon_flag, intron_flag, prob)

    # -- feature ablation ----------------------------------------------------
    abl = _ablation_splice(model, enc, prob)

    return SpliceSignals(
        probability=prob, risk_tier=tier, tier_desc=tier_desc,
        conv3_norm=conv3_norm, gradient_attribution=grad_attr,
        mutation_pos=mutation_pos,
        mutation_peak_ratio=mpr, signal_concentration=sci,
        imp_score=imp_score, region_imp=region_imp, splice_imp=splice_imp,
        dist_donor=dist_d, dist_acceptor=dist_a,
        nearest_donor=nearest_d, nearest_acceptor=nearest_a,
        splice_risk_donor=risk_d, splice_risk_acceptor=risk_a,
        counterfactual=cf, ablation=abl,
        splice_aura_score=aura,
    )
def _counterfactual_splice(model: MutationPredictorCNN_v2,
                           ref_seq: str, mut_seq: str,
                           mutation_pos: int, exon_flag: int,
                           intron_flag: int, orig_prob: float) -> dict:
    """
    Score every alternative base at the mutated position with the splice model.

    For each base in ``ALL_BASES`` other than the reference base, the reference
    sequence is edited at ``mutation_pos``, re-encoded and scored.

    Returns:
        A dict with the rounded original probability, the reference base, the
        per-substitution table (highest probability first) and the max / min /
        range of all probabilities including ``orig_prob``. When
        ``mutation_pos`` is outside ``ref_seq``, a dict with an ``"error"``
        message and the unrounded original probability is returned instead.
    """
    position_known = 0 <= mutation_pos < len(ref_seq)
    if not position_known:
        return {"error": "mutation position not detected",
                "original_probability": orig_prob}

    ref_base = ref_seq[mutation_pos].upper()
    left = ref_seq[:mutation_pos]
    right = ref_seq[mutation_pos+1:]

    table = []
    for alt in ALL_BASES:
        if alt == ref_base:
            continue
        variant = left + alt + right
        enc_alt = encode_for_v2(ref_seq, variant, exon_flag, intron_flag)
        with torch.no_grad():
            alt_logit, _, _, _ = model(enc_alt.unsqueeze(0))
            alt_prob = float(torch.sigmoid(alt_logit).item())
        table.append({
            "mutation": f"{ref_base}>{alt}",
            "alt_base": alt,
            "probability": round(alt_prob, 4),
        })

    # The extremes are taken over the substitutions plus the original call.
    pool = [row["probability"] for row in table]
    pool.append(orig_prob)
    highest = max(pool)
    lowest = min(pool)
    spread = highest - lowest

    ranked = sorted(table, key=lambda row: row["probability"], reverse=True)
    return {
        "original_probability": round(orig_prob, 4),
        "ref_base": ref_base,
        "table": ranked,
        "max_probability": round(highest, 4),
        "min_probability": round(lowest, 4),
        "probability_range": round(spread, 4),
        "counterfactual_delta": round(abs(spread), 4),
    }
| def _ablation_splice(model: MutationPredictorCNN_v2, | |
| enc: torch.Tensor, prob_base: float) -> dict: | |
| def _prob(e): | |
| with torch.no_grad(): | |
| logit, _, _, _ = model(e.unsqueeze(0)) | |
| return float(torch.sigmoid(logit).item()) | |
| enc_no_splice = enc.clone(); enc_no_splice[1103:1106] = 0.0 | |
| enc_no_region = enc.clone(); enc_no_region[1101:1103] = 0.0 | |
| enc_no_mut = enc.clone(); enc_no_mut[1089:1101] = 0.0 | |
| enc_no_seq = enc.clone(); enc_no_seq[:1089] = 0.0 | |
| d_splice = round(abs(prob_base - _prob(enc_no_splice)), 4) | |
| d_region = round(abs(prob_base - _prob(enc_no_region)), 4) | |
| d_mut = round(abs(prob_base - _prob(enc_no_mut)), 4) | |
| d_seq = round(abs(prob_base - _prob(enc_no_seq)), 4) | |
| total = d_splice + d_region + d_mut + d_seq | |
| def _pct(v): return round(v / total * 100, 1) if total > 0 else 0.0 | |
| return { | |
| "baseline_probability": round(prob_base, 4), | |
| "splice_delta": d_splice, "splice_pct": _pct(d_splice), | |
| "region_delta": d_region, "region_pct": _pct(d_region), | |
| "mutation_delta": d_mut, "mutation_pct": _pct(d_mut), | |
| "sequence_delta": d_seq, "sequence_pct": _pct(d_seq), | |
| "dominant_feature": max( | |
| [("Splice features", d_splice), ("Region flags", d_region), | |
| ("Mutation type", d_mut), ("Sequence context", d_seq)], | |
| key=lambda x: x[1] | |
| )[0], | |
| } | |
| # ───────────────────────────────────────────────────────────────────────────── | |
| # ② Extract V4 Signals | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def extract_v4_signals(model: MutationPredictorCNN_v4,
                       ref_seq: str, mut_seq: str,
                       exon_flag: int, intron_flag: int) -> V4Signals:
    """
    Run the V4 model on one variant and collect its explainability signals.

    Steps: base forward pass (probability), conv3 activation profile,
    input-gradient attribution through the sequence tensor only, and the
    mutation-centred metrics derived from the conv3 profile.

    Args:
        model: V4 model; called as ``model(seq, mut_onehot, region, splice)``
            and expected to return a single-element logit.
        ref_seq: Reference sequence.
        mut_seq: Mutated sequence.
        exon_flag: Exon flag forwarded to ``encode_for_v4``.
        intron_flag: Intron flag forwarded to ``encode_for_v4``.

    Returns:
        A populated ``V4Signals`` instance.

    Side effects:
        Leaves ``model`` in eval mode and populates parameter gradients via
        the backward pass.
    """
    # Switch to eval before the first pass so the probability, the conv3
    # profile and the gradients are all computed in the same mode.
    model.eval()

    seq_t, mut_oh, region_t, splice_t = encode_for_v4(ref_seq, mut_seq,
                                                      exon_flag, intron_flag)

    def _forward(seq_in: torch.Tensor):
        # Only the sequence tensor varies; the other inputs stay fixed.
        return model(seq_in, mut_oh, region_t, splice_t)

    # -- base forward ---------------------------------------------------------
    with torch.no_grad():
        logit = _forward(seq_t)
        prob = float(torch.sigmoid(logit).item())

    mutation_pos = find_mutation_pos(ref_seq, mut_seq)

    # -- conv3 activation norm ------------------------------------------------
    conv3_norm = _conv3_activation_norm(model, seq_t, _forward)

    # -- gradient attribution: through the sequence tensor only ---------------
    seq_leaf = seq_t.clone().detach().requires_grad_(True)
    logit_g = _forward(seq_leaf)
    model.zero_grad()
    logit_g.backward()
    grad = seq_leaf.grad  # presumably (1, 11, 99) — TODO confirm
    if grad is not None:
        # L2 norm per position across the channel dimension; the norm is
        # sign-invariant, so no abs() is needed.
        grad_attr = grad.squeeze(0).norm(dim=0).cpu().numpy()
        peak = grad_attr.max()
        if peak > 0:
            grad_attr = grad_attr / peak
    else:
        grad_attr = np.zeros(99)

    mpr = _mutation_peak_ratio(conv3_norm, mutation_pos)
    sci = _signal_concentration_index(conv3_norm, mutation_pos)

    return V4Signals(
        probability=prob,
        conv3_norm=conv3_norm, gradient_attribution=grad_attr,
        mutation_pos=mutation_pos,
        mutation_peak_ratio=mpr, signal_concentration=sci,
    )
| # ───────────────────────────────────────────────────────────────────────────── | |
| # ③ Extract Classic Signals | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def extract_classic_signals(model: MutationPredictorClassic,
                            ref_seq: str, mut_seq: str,
                            exon_flag: int, intron_flag: int) -> ClassicSignals:
    """
    Run the Classic model on one variant and collect its explainability
    signals: probability, importance-head scalar, region importance, conv3
    activation profile and the two mutation-centred metrics.

    The model is called with a batch of one flat encoding produced by
    ``encode_for_v2`` and is expected to return
    ``(logit, importance, region_imp)``.
    """
    encoded = encode_for_v2(ref_seq, mut_seq, exon_flag, intron_flag)

    def _run(flat):
        # The model consumes a batch, so add the batch dimension here.
        return model(flat.unsqueeze(0))

    # -- base forward ---------------------------------------------------------
    with torch.no_grad():
        batch = encoded.unsqueeze(0)
        logit, importance_t, region_t = model(batch)
        probability = float(torch.sigmoid(logit).item())
        importance = float(importance_t.item())
        region_importance = region_t[0].numpy()

    position = find_mutation_pos(ref_seq, mut_seq)

    # -- conv3 activation norm ------------------------------------------------
    profile = _conv3_activation_norm(model, encoded, _run)

    peak_ratio = _mutation_peak_ratio(profile, position)
    concentration = _signal_concentration_index(profile, position)

    return ClassicSignals(
        probability=probability,
        conv3_norm=profile,
        importance_head=importance,
        region_imp=region_importance,
        mutation_pos=position,
        mutation_peak_ratio=peak_ratio,
        signal_concentration=concentration,
    )
| # ───────────────────────────────────────────────────────────────────────────── | |
| # Cross-model analysis | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def compute_cross_model_analysis(splice: SpliceSignals,
                                 v4: V4Signals,
                                 classic: ClassicSignals) -> dict:
    """
    Combine the per-model signals into the five XAI Engine metrics plus a
    model-agreement summary.

    Metrics:
        1. Mutation peak ratio — mean over the three models.
        2. Counterfactual magnitude — ``counterfactual_delta`` from the splice
           model's counterfactual dict (0.0 if absent).
        3. Cross-model locality — mean pairwise Pearson correlation of the
           three conv3 profiles; a pair with a constant profile counts as 0.0.
        4. Signal concentration index — mean over the three models.
        5. Explainability strength score — 0.35 * peak ratio (capped at 3x)
           + 0.35 * counterfactual magnitude + 0.30 * locality mapped to 0-1.

    Returns:
        A dict of the rounded metrics, the activation pattern type
        ("Sharp" / "Broad" / "Flat"), the standard deviation and agreement
        level of the three probabilities, and the raw profiles (keys prefixed
        with ``_``) for plotting.
    """
    # 1. Mutation Peak Ratio: average across models
    mpr_avg = float(np.mean([
        splice.mutation_peak_ratio,
        v4.mutation_peak_ratio,
        classic.mutation_peak_ratio,
    ]))

    # 2. Counterfactual magnitude: from the splice model (has full CF data)
    cf_mag = float(splice.counterfactual.get("counterfactual_delta", 0.0))

    # 3. Cross-model locality score: are activation peaks aligned across
    #    models? Correlate every pair of conv3_norm profiles.
    profiles = [splice.conv3_norm, v4.conv3_norm, classic.conv3_norm]
    cors = []
    for i, a in enumerate(profiles):
        for b in profiles[i + 1:]:
            # corrcoef is undefined for a constant profile; score it as 0.
            if a.std() > 0 and b.std() > 0:
                cors.append(float(np.corrcoef(a, b)[0, 1]))
            else:
                cors.append(0.0)
    cross_locality = float(np.clip(np.mean(cors), -1.0, 1.0))

    # 4. Signal concentration index: average across models
    sci_avg = float(np.mean([
        splice.signal_concentration,
        v4.signal_concentration,
        classic.signal_concentration,
    ]))

    # 5. Explainability Strength Score (0-1)
    mpr_norm = float(np.clip(mpr_avg / 3.0, 0.0, 1.0))  # >3x peak = full score
    cf_norm = float(np.clip(cf_mag, 0.0, 1.0))
    loc_norm = float(np.clip((cross_locality + 1.0) / 2.0, 0.0, 1.0))
    ess = 0.35 * mpr_norm + 0.35 * cf_norm + 0.30 * loc_norm
    ess = float(np.clip(ess, 0.0, 1.0))

    # Activation pattern type, judged on the splice model's profile by how
    # many positions exceed half of the peak.
    peak = float(np.max(splice.conv3_norm))
    if not peak > 0:
        # A profile without any positive activation has no peak at all, so it
        # must not be reported as "Sharp".
        pattern = "Flat"
    else:
        above_half = int(np.sum(splice.conv3_norm > 0.5 * peak))
        if above_half <= 5:
            pattern = "Sharp"
        elif above_half <= 25:
            pattern = "Broad"
        else:
            pattern = "Flat"

    # Per-model probability agreement
    probs = [splice.probability, v4.probability, classic.probability]
    prob_std = float(np.std(probs))

    return {
        "mutation_peak_ratio": round(mpr_avg, 4),
        "counterfactual_magnitude": round(cf_mag, 4),
        "cross_model_locality_score": round(cross_locality, 4),
        "signal_concentration_index": round(sci_avg, 4),
        "explainability_strength_score": round(ess, 4),
        "activation_pattern_type": pattern,
        "prob_std": round(prob_std, 4),
        "model_agreement": _agreement_level(prob_std),
        # raw profiles for plotting
        "_splice_norm": splice.conv3_norm,
        "_v4_norm": v4.conv3_norm,
        "_classic_norm": classic.conv3_norm,
        "_splice_grad": splice.gradient_attribution,
        "_v4_grad": v4.gradient_attribution,
    }
| def _agreement_level(std: float) -> str: | |
| if std < 0.05: return "Strong" | |
| if std < 0.12: return "Moderate" | |
| return "Weak" | |