""" explainability_engine.py ======================== Extract ALL internal explainability signals from each of the three models. No signal is simplified or omitted. Splice model signals: - probability - conv3 activation norm vector (99,) - mutation-centered activation peak - splice aura distance (donor / acceptor) - counterfactual delta (all alternative bases) - feature ablation response (splice / region / mutation groups) - risk tier classification V4 model signals: - probability - importance head vector (via conv3 hook — identical architecture) - mutation-centered importance density Classic model signals: - probability - importance head output (scalar) - region importance (exon / intron) - conv3 activation norm vector (99,) """ from __future__ import annotations import logging from dataclasses import dataclass, field from typing import Optional import numpy as np import torch from model_loader import ( MutationPredictorCNN_v2, MutationPredictorCNN_v4, MutationPredictorClassic, ModelRegistry, encode_for_v2, encode_for_v4, find_mutation_pos, ALL_BASES, MUT_TYPES, ) logger = logging.getLogger("mutation_xai.xai") # ═══════════════════════════════════════════════════════════════════════════════ # Shared helpers # ═══════════════════════════════════════════════════════════════════════════════ def _conv3_activation_norm(model: torch.nn.Module, x: torch.Tensor, forward_fn) -> np.ndarray: """ Register a forward hook on model.conv3, run forward_fn(x), return L2-normalised per-position activation norm vector of shape (99,). """ activations: dict = {} def _hook(module, inp, out): activations["conv3"] = out.detach() hook = model.conv3.register_forward_hook(_hook) try: with torch.no_grad(): forward_fn(x) finally: hook.remove() act = activations.get("conv3") if act is None: return np.zeros(99) # act shape: (1, 256, 99) norm = act.squeeze(0).norm(dim=0).numpy() # (99,) if norm.max() > 0: norm = norm / norm.max() return norm def _gradient_attribution(model: torch.nn.Module, enc: torch.Tensor, forward_fn_grad) -> np.ndarray: """ Compute input-gradient attribution for the sequence portion. Returns normalised per-position attribution of shape (99,). """ model.eval() enc_leaf = enc.clone().detach().requires_grad_(True) logit = forward_fn_grad(enc_leaf) model.zero_grad() logit.backward() grad = enc_leaf.grad if grad is None: return np.zeros(99) seq_grad = grad[:1089].view(99, 11) attr = seq_grad.abs().norm(dim=1).detach().numpy() if attr.max() > 0: attr = attr / attr.max() return attr def _mutation_peak_ratio(profile: np.ndarray, mutation_pos: int) -> float: """ peak_signal / mean_signal, where peak_signal is the profile value at mutation_pos. Returns 0.0 if mutation_pos < 0 or mean == 0. """ if mutation_pos < 0 or mutation_pos >= len(profile): return 0.0 mean_sig = float(profile.mean()) if mean_sig == 0: return 0.0 return float(profile[mutation_pos]) / mean_sig def _signal_concentration_index(profile: np.ndarray, mutation_pos: int, window: int = 10) -> float: """ Fraction of total activation energy within ±window of mutation_pos. Ranges 0–1; 1.0 = perfectly concentrated. """ if mutation_pos < 0: return 0.0 total = float(profile.sum()) if total == 0: return 0.0 lo = max(0, mutation_pos - window) hi = min(len(profile), mutation_pos + window + 1) local = float(profile[lo:hi].sum()) return local / total def _splice_distances(ref_seq: str, mutation_pos: int): """ Scan ref_seq for GT (donor) and AG (acceptor) dinucleotides. Returns (dist_donor, dist_acceptor, nearest_donor_pos, nearest_acceptor_pos). Any value may be None if no site found. """ seq = (ref_seq.upper() + "N" * 99)[:99] donors, acceptors = [], [] for i in range(len(seq) - 1): if seq[i:i+2] == "GT": donors.append(i) if seq[i:i+2] == "AG": acceptors.append(i) if mutation_pos < 0: return None, None, None, None dist_d = nearest_d = None dist_a = nearest_a = None if donors: pairs = sorted([(abs(mutation_pos - p), p) for p in donors]) dist_d, nearest_d = pairs[0] if acceptors: pairs = sorted([(abs(mutation_pos - p), p) for p in acceptors]) dist_a, nearest_a = pairs[0] return dist_d, dist_a, nearest_d, nearest_a def _classify_splice_risk(distance: Optional[int]) -> str: if distance is None: return "UNKNOWN" if distance <= 2: return "CRITICAL SPLICE SITE" if distance <= 8: return "SPLICE REGION" return "NON-SPLICE" def _classify_risk_tier(prob: float) -> tuple[str, str]: if prob >= 0.90: return "PATHOGENIC", "Very high confidence" if prob >= 0.70: return "LIKELY PATHOGENIC", "High confidence" if prob >= 0.50: return "POSSIBLY PATHOGENIC", "Moderate confidence" if prob >= 0.20: return "LIKELY BENIGN", "Low pathogenic signal" return "BENIGN", "Very low pathogenic signal" # ═══════════════════════════════════════════════════════════════════════════════ # Signal dataclasses # ═══════════════════════════════════════════════════════════════════════════════ @dataclass class SpliceSignals: probability: float risk_tier: str tier_desc: str conv3_norm: np.ndarray # (99,) gradient_attribution: np.ndarray # (99,) mutation_pos: int mutation_peak_ratio: float signal_concentration: float imp_score: float # importance_head output region_imp: np.ndarray # (2,) [exon, intron] splice_imp: np.ndarray # (3,) [donor, acc, region] dist_donor: Optional[int] dist_acceptor: Optional[int] nearest_donor: Optional[int] nearest_acceptor: Optional[int] splice_risk_donor: str splice_risk_acceptor: str counterfactual: dict # all-base CF results ablation: dict # feature ablation deltas splice_aura_score: float # proximity-weighted splice signal @dataclass class V4Signals: probability: float conv3_norm: np.ndarray # (99,) gradient_attribution: np.ndarray # (99,) mutation_pos: int mutation_peak_ratio: float signal_concentration: float @dataclass class ClassicSignals: probability: float conv3_norm: np.ndarray # (99,) importance_head: float # scalar importance_head output region_imp: np.ndarray # (2,) [exon, intron] mutation_pos: int mutation_peak_ratio: float signal_concentration: float # ═══════════════════════════════════════════════════════════════════════════════ # ① Extract Splice Signals # ═══════════════════════════════════════════════════════════════════════════════ def extract_splice_signals(model: MutationPredictorCNN_v2, ref_seq: str, mut_seq: str, exon_flag: int, intron_flag: int) -> SpliceSignals: enc = encode_for_v2(ref_seq, mut_seq, exon_flag, intron_flag) # ── base forward pass ──────────────────────────────────────────────────── with torch.no_grad(): x = enc.unsqueeze(0) logit, imp_t, r_imp_t, s_imp_t = model(x) prob = float(torch.sigmoid(logit).item()) imp_score = float(imp_t.item()) region_imp= r_imp_t[0].numpy() splice_imp= s_imp_t[0].numpy() tier, tier_desc = _classify_risk_tier(prob) mutation_pos = find_mutation_pos(ref_seq, mut_seq) # ── conv3 activation norm ──────────────────────────────────────────────── def _fwd(x_in): return model(x_in.unsqueeze(0)) conv3_norm = _conv3_activation_norm( model, enc, lambda x: model(x.unsqueeze(0)) ) # ── gradient attribution ───────────────────────────────────────────────── def _fwd_grad(leaf: torch.Tensor): logit_g, _, _, _ = model(leaf.unsqueeze(0)) return logit_g grad_attr = _gradient_attribution(model, enc, _fwd_grad) # ── mutation-peak derived metrics ───────────────────────────────────────── mpr = _mutation_peak_ratio(conv3_norm, mutation_pos) sci = _signal_concentration_index(conv3_norm, mutation_pos) # ── splice distances ───────────────────────────────────────────────────── dist_d, dist_a, nearest_d, nearest_a = _splice_distances(ref_seq, mutation_pos) risk_d = _classify_splice_risk(dist_d) risk_a = _classify_splice_risk(dist_a) # ── splice aura score — proximity-weighted composite ──────────────────── def _proximity_weight(dist): if dist is None: return 0.0 if dist <= 2: return 1.0 if dist <= 8: return 0.5 return 0.1 aura = ( _proximity_weight(dist_d) * float(splice_imp[0]) + _proximity_weight(dist_a) * float(splice_imp[1]) + float(splice_imp[2]) * 0.3 ) / 1.6 # normalise to ~[0,1] aura = float(np.clip(aura, 0.0, 1.0)) # ── counterfactual analysis ─────────────────────────────────────────────── cf = _counterfactual_splice(model, ref_seq, mut_seq, mutation_pos, exon_flag, intron_flag, prob) # ── feature ablation ───────────────────────────────────────────────────── abl = _ablation_splice(model, enc, prob) return SpliceSignals( probability=prob, risk_tier=tier, tier_desc=tier_desc, conv3_norm=conv3_norm, gradient_attribution=grad_attr, mutation_pos=mutation_pos, mutation_peak_ratio=mpr, signal_concentration=sci, imp_score=imp_score, region_imp=region_imp, splice_imp=splice_imp, dist_donor=dist_d, dist_acceptor=dist_a, nearest_donor=nearest_d, nearest_acceptor=nearest_a, splice_risk_donor=risk_d, splice_risk_acceptor=risk_a, counterfactual=cf, ablation=abl, splice_aura_score=aura, ) def _counterfactual_splice(model: MutationPredictorCNN_v2, ref_seq: str, mut_seq: str, mutation_pos: int, exon_flag: int, intron_flag: int, orig_prob: float) -> dict: if mutation_pos < 0 or mutation_pos >= len(ref_seq): return {"error": "mutation position not detected", "original_probability": orig_prob} ref_base = ref_seq[mutation_pos].upper() results = [] for alt in ALL_BASES: if alt == ref_base: continue alt_mut = ref_seq[:mutation_pos] + alt + ref_seq[mutation_pos+1:] enc_cf = encode_for_v2(ref_seq, alt_mut, exon_flag, intron_flag) with torch.no_grad(): logit_cf, _, _, _ = model(enc_cf.unsqueeze(0)) p = float(torch.sigmoid(logit_cf).item()) results.append({"mutation": f"{ref_base}>{alt}", "alt_base": alt, "probability": round(p, 4)}) all_probs = [r["probability"] for r in results] + [orig_prob] return { "original_probability": round(orig_prob, 4), "ref_base": ref_base, "table": sorted(results, key=lambda x: x["probability"], reverse=True), "max_probability": round(max(all_probs), 4), "min_probability": round(min(all_probs), 4), "probability_range": round(max(all_probs) - min(all_probs), 4), "counterfactual_delta": round(abs(max(all_probs) - min(all_probs)), 4), } def _ablation_splice(model: MutationPredictorCNN_v2, enc: torch.Tensor, prob_base: float) -> dict: def _prob(e): with torch.no_grad(): logit, _, _, _ = model(e.unsqueeze(0)) return float(torch.sigmoid(logit).item()) enc_no_splice = enc.clone(); enc_no_splice[1103:1106] = 0.0 enc_no_region = enc.clone(); enc_no_region[1101:1103] = 0.0 enc_no_mut = enc.clone(); enc_no_mut[1089:1101] = 0.0 enc_no_seq = enc.clone(); enc_no_seq[:1089] = 0.0 d_splice = round(abs(prob_base - _prob(enc_no_splice)), 4) d_region = round(abs(prob_base - _prob(enc_no_region)), 4) d_mut = round(abs(prob_base - _prob(enc_no_mut)), 4) d_seq = round(abs(prob_base - _prob(enc_no_seq)), 4) total = d_splice + d_region + d_mut + d_seq def _pct(v): return round(v / total * 100, 1) if total > 0 else 0.0 return { "baseline_probability": round(prob_base, 4), "splice_delta": d_splice, "splice_pct": _pct(d_splice), "region_delta": d_region, "region_pct": _pct(d_region), "mutation_delta": d_mut, "mutation_pct": _pct(d_mut), "sequence_delta": d_seq, "sequence_pct": _pct(d_seq), "dominant_feature": max( [("Splice features", d_splice), ("Region flags", d_region), ("Mutation type", d_mut), ("Sequence context", d_seq)], key=lambda x: x[1] )[0], } # ═══════════════════════════════════════════════════════════════════════════════ # ② Extract V4 Signals # ═══════════════════════════════════════════════════════════════════════════════ def extract_v4_signals(model: MutationPredictorCNN_v4, ref_seq: str, mut_seq: str, exon_flag: int, intron_flag: int) -> V4Signals: seq_t, mut_oh, region_t, splice_t = encode_for_v4(ref_seq, mut_seq, exon_flag, intron_flag) # ── base forward ───────────────────────────────────────────────────────── with torch.no_grad(): logit = model(seq_t, mut_oh, region_t, splice_t) prob = float(torch.sigmoid(logit).item()) mutation_pos = find_mutation_pos(ref_seq, mut_seq) # ── conv3 activation norm ──────────────────────────────────────────────── def _fwd_v4(seq_in): return model(seq_in, mut_oh, region_t, splice_t) conv3_norm = _conv3_activation_norm( model, seq_t.squeeze(0), lambda x: model(x.unsqueeze(0), mut_oh, region_t, splice_t) ) # ── gradient attribution — through sequence tensor only ────────────────── model.eval() seq_leaf = seq_t.clone().detach().requires_grad_(True) logit_g = model(seq_leaf, mut_oh, region_t, splice_t) model.zero_grad() logit_g.backward() grad = seq_leaf.grad # (1, 11, 99) if grad is not None: # L2 norm per position across 11 channels grad_attr = grad.squeeze(0).abs().norm(dim=0).numpy() # (99,) if grad_attr.max() > 0: grad_attr = grad_attr / grad_attr.max() else: grad_attr = np.zeros(99) mpr = _mutation_peak_ratio(conv3_norm, mutation_pos) sci = _signal_concentration_index(conv3_norm, mutation_pos) return V4Signals( probability=prob, conv3_norm=conv3_norm, gradient_attribution=grad_attr, mutation_pos=mutation_pos, mutation_peak_ratio=mpr, signal_concentration=sci, ) # ═══════════════════════════════════════════════════════════════════════════════ # ③ Extract Classic Signals # ═══════════════════════════════════════════════════════════════════════════════ def extract_classic_signals(model: MutationPredictorClassic, ref_seq: str, mut_seq: str, exon_flag: int, intron_flag: int) -> ClassicSignals: enc = encode_for_v2(ref_seq, mut_seq, exon_flag, intron_flag) # ── base forward ───────────────────────────────────────────────────────── with torch.no_grad(): x = enc.unsqueeze(0) logit, imp_t, r_imp_t = model(x) prob = float(torch.sigmoid(logit).item()) imp_score = float(imp_t.item()) region_imp= r_imp_t[0].numpy() mutation_pos = find_mutation_pos(ref_seq, mut_seq) # ── conv3 activation norm ──────────────────────────────────────────────── conv3_norm = _conv3_activation_norm( model, enc, lambda x: model(x.unsqueeze(0)) ) mpr = _mutation_peak_ratio(conv3_norm, mutation_pos) sci = _signal_concentration_index(conv3_norm, mutation_pos) return ClassicSignals( probability=prob, conv3_norm=conv3_norm, importance_head=imp_score, region_imp=region_imp, mutation_pos=mutation_pos, mutation_peak_ratio=mpr, signal_concentration=sci, ) # ═══════════════════════════════════════════════════════════════════════════════ # Cross-model analysis # ═══════════════════════════════════════════════════════════════════════════════ def compute_cross_model_analysis(splice: SpliceSignals, v4: V4Signals, classic: ClassicSignals) -> dict: """ Compute all five XAI Engine metrics and cross-model locality score. """ # 1. Mutation Peak Ratio — average across models mpr_avg = float(np.mean([ splice.mutation_peak_ratio, v4.mutation_peak_ratio, classic.mutation_peak_ratio, ])) # 2. Counterfactual magnitude — from splice model (has full CF data) cf_mag = float(splice.counterfactual.get("counterfactual_delta", 0.0)) # 3. Cross-model locality score # Are activation peaks aligned across models? # Compute correlation of all three conv3_norm profiles. profiles = [splice.conv3_norm, v4.conv3_norm, classic.conv3_norm] cors = [] for i in range(len(profiles)): for j in range(i+1, len(profiles)): a, b = profiles[i], profiles[j] if a.std() > 0 and b.std() > 0: cors.append(float(np.corrcoef(a, b)[0, 1])) else: cors.append(0.0) cross_locality = float(np.clip(np.mean(cors), -1.0, 1.0)) # 4. Signal concentration index — average across models sci_avg = float(np.mean([ splice.signal_concentration, v4.signal_concentration, classic.signal_concentration, ])) # 5. Explainability Strength Score (0–1) mpr_norm = float(np.clip(mpr_avg / 3.0, 0.0, 1.0)) # >3× peak = full score cf_norm = float(np.clip(cf_mag, 0.0, 1.0)) loc_norm = float(np.clip((cross_locality + 1.0) / 2.0, 0.0, 1.0)) ess = (0.35 * mpr_norm + 0.35 * cf_norm + 0.30 * loc_norm) ess = float(np.clip(ess, 0.0, 1.0)) # Activation pattern type peak = float(np.max(splice.conv3_norm)) if peak > 0: above_half = int(np.sum(splice.conv3_norm > 0.5 * peak)) above_tenth = int(np.sum(splice.conv3_norm > 0.1 * peak)) else: above_half = above_tenth = 0 if above_half <= 5: pattern = "Sharp" elif above_half <= 25: pattern = "Broad" else: pattern = "Flat" # Per-model probability agreement probs = [splice.probability, v4.probability, classic.probability] prob_std = float(np.std(probs)) return { "mutation_peak_ratio": round(mpr_avg, 4), "counterfactual_magnitude": round(cf_mag, 4), "cross_model_locality_score": round(cross_locality, 4), "signal_concentration_index": round(sci_avg, 4), "explainability_strength_score": round(ess, 4), "activation_pattern_type": pattern, "prob_std": round(prob_std, 4), "model_agreement": _agreement_level(prob_std), # raw profiles for plotting "_splice_norm": splice.conv3_norm, "_v4_norm": v4.conv3_norm, "_classic_norm": classic.conv3_norm, "_splice_grad": splice.gradient_attribution, "_v4_grad": v4.gradient_attribution, } def _agreement_level(std: float) -> str: if std < 0.05: return "Strong" if std < 0.12: return "Moderate" return "Weak"