"""
explainability_engine.py
========================
Extract ALL internal explainability signals from each of the three models.
No signal is simplified or omitted.
Splice model signals:
- probability
- conv3 activation norm vector (99,)
- mutation-centered activation peak
- splice aura distance (donor / acceptor)
- counterfactual delta (all alternative bases)
- feature ablation response (splice / region / mutation groups)
- risk tier classification
V4 model signals:
- probability
- importance head vector (via conv3 hook; identical architecture)
- mutation-centered importance density
Classic model signals:
- probability
- importance head output (scalar)
- region importance (exon / intron)
- conv3 activation norm vector (99,)
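
Typical usage (sketch; model construction depends on model_loader's actual
loading API, so the model objects below are placeholders):
    splice_sig  = extract_splice_signals(splice_model, ref, mut, exon_flag=1, intron_flag=0)
    v4_sig      = extract_v4_signals(v4_model, ref, mut, exon_flag=1, intron_flag=0)
    classic_sig = extract_classic_signals(classic_model, ref, mut, exon_flag=1, intron_flag=0)
    report      = compute_cross_model_analysis(splice_sig, v4_sig, classic_sig)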
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Optional
import numpy as np
import torch
from model_loader import (
MutationPredictorCNN_v2,
MutationPredictorCNN_v4,
MutationPredictorClassic,
ModelRegistry,
encode_for_v2,
encode_for_v4,
find_mutation_pos,
ALL_BASES,
MUT_TYPES,
)
logger = logging.getLogger("mutation_xai.xai")
# ═══════════════════════════════════════════════════════════════════════════════
# Shared helpers
# ═══════════════════════════════════════════════════════════════════════════════
def _conv3_activation_norm(model: torch.nn.Module, x: torch.Tensor,
forward_fn) -> np.ndarray:
"""
Register a forward hook on model.conv3, run forward_fn(x), return
L2-normalised per-position activation norm vector of shape (99,).
"""
activations: dict = {}
def _hook(module, inp, out):
activations["conv3"] = out.detach()
hook = model.conv3.register_forward_hook(_hook)
try:
with torch.no_grad():
forward_fn(x)
finally:
hook.remove()
act = activations.get("conv3")
if act is None:
return np.zeros(99)
# act shape: (1, 256, 99)
norm = act.squeeze(0).norm(dim=0).numpy() # (99,)
if norm.max() > 0:
norm = norm / norm.max()
return norm
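# Note: this hook helper is shared by all three models because each exposes a
# .conv3 layer (the v4 model reuses the same convolutional architecture, per
# the module docstring).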
def _gradient_attribution(model: torch.nn.Module, enc: torch.Tensor,
forward_fn_grad) -> np.ndarray:
"""
Compute input-gradient attribution for the sequence portion.
Returns normalised per-position attribution of shape (99,).
"""
model.eval()
enc_leaf = enc.clone().detach().requires_grad_(True)
logit = forward_fn_grad(enc_leaf)
model.zero_grad()
logit.backward()
grad = enc_leaf.grad
if grad is None:
return np.zeros(99)
    seq_grad = grad[:1089].view(99, 11)   # sequence block: 99 positions x 11 channels
    attr = seq_grad.abs().norm(dim=1).detach().numpy()   # per-position L2 norm
if attr.max() > 0:
attr = attr / attr.max()
return attr
def _mutation_peak_ratio(profile: np.ndarray, mutation_pos: int) -> float:
"""
    peak_signal / mean_signal, where peak_signal is the profile value at
    mutation_pos. Returns 0.0 if mutation_pos is out of range or the mean is 0.
"""
if mutation_pos < 0 or mutation_pos >= len(profile):
return 0.0
mean_sig = float(profile.mean())
if mean_sig == 0:
return 0.0
return float(profile[mutation_pos]) / mean_sig
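# Worked example (hypothetical profile): if the profile mean is 0.2 and the
# value at mutation_pos is 0.8, the peak ratio is 0.8 / 0.2 = 4.0, i.e. the
# mutation site carries four times the average activation.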
def _signal_concentration_index(profile: np.ndarray, mutation_pos: int,
window: int = 10) -> float:
"""
    Fraction of total activation energy within ±window of mutation_pos.
Ranges 0–1; 1.0 = perfectly concentrated.
"""
if mutation_pos < 0:
return 0.0
total = float(profile.sum())
if total == 0:
return 0.0
lo = max(0, mutation_pos - window)
hi = min(len(profile), mutation_pos + window + 1)
local = float(profile[lo:hi].sum())
return local / total
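# Worked example (hypothetical profile): with window=10 the index covers the 21
# positions centred on mutation_pos; if those positions hold 0.6 of a total
# activation energy of 1.0, the concentration index is 0.6.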
def _splice_distances(ref_seq: str, mutation_pos: int):
"""
Scan ref_seq for GT (donor) and AG (acceptor) dinucleotides.
Returns (dist_donor, dist_acceptor, nearest_donor_pos, nearest_acceptor_pos).
    Any value may be None if no site is found or mutation_pos < 0.
    """
    if mutation_pos < 0:
        return None, None, None, None
    seq = (ref_seq.upper() + "N" * 99)[:99]   # pad/trim to the fixed 99-nt window
    donors, acceptors = [], []
    for i in range(len(seq) - 1):
        if seq[i:i+2] == "GT": donors.append(i)
        if seq[i:i+2] == "AG": acceptors.append(i)
dist_d = nearest_d = None
dist_a = nearest_a = None
if donors:
pairs = sorted([(abs(mutation_pos - p), p) for p in donors])
dist_d, nearest_d = pairs[0]
if acceptors:
pairs = sorted([(abs(mutation_pos - p), p) for p in acceptors])
dist_a, nearest_a = pairs[0]
return dist_d, dist_a, nearest_d, nearest_a
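# Worked example (hypothetical sequence): in ref_seq = "AAGTAA...", the nearest
# donor "GT" to a mutation at position 5 starts at position 2, so dist_donor =
# |5 - 2| = 3 and _classify_splice_risk(3) returns "SPLICE REGION".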
def _classify_splice_risk(distance: Optional[int]) -> str:
if distance is None: return "UNKNOWN"
if distance <= 2: return "CRITICAL SPLICE SITE"
if distance <= 8: return "SPLICE REGION"
return "NON-SPLICE"
def _classify_risk_tier(prob: float) -> tuple[str, str]:
if prob >= 0.90: return "PATHOGENIC", "Very high confidence"
if prob >= 0.70: return "LIKELY PATHOGENIC", "High confidence"
if prob >= 0.50: return "POSSIBLY PATHOGENIC", "Moderate confidence"
if prob >= 0.20: return "LIKELY BENIGN", "Low pathogenic signal"
return "BENIGN", "Very low pathogenic signal"
# ═══════════════════════════════════════════════════════════════════════════════
# Signal dataclasses
# ═══════════════════════════════════════════════════════════════════════════════
@dataclass
class SpliceSignals:
probability: float
risk_tier: str
tier_desc: str
conv3_norm: np.ndarray # (99,)
gradient_attribution: np.ndarray # (99,)
mutation_pos: int
mutation_peak_ratio: float
signal_concentration: float
imp_score: float # importance_head output
region_imp: np.ndarray # (2,) [exon, intron]
splice_imp: np.ndarray # (3,) [donor, acc, region]
dist_donor: Optional[int]
dist_acceptor: Optional[int]
nearest_donor: Optional[int]
nearest_acceptor: Optional[int]
splice_risk_donor: str
splice_risk_acceptor: str
counterfactual: dict # all-base CF results
ablation: dict # feature ablation deltas
splice_aura_score: float # proximity-weighted splice signal
@dataclass
class V4Signals:
probability: float
conv3_norm: np.ndarray # (99,)
gradient_attribution: np.ndarray # (99,)
mutation_pos: int
mutation_peak_ratio: float
signal_concentration: float
@dataclass
class ClassicSignals:
probability: float
conv3_norm: np.ndarray # (99,)
importance_head: float # scalar importance_head output
region_imp: np.ndarray # (2,) [exon, intron]
mutation_pos: int
mutation_peak_ratio: float
signal_concentration: float
# ═══════════════════════════════════════════════════════════════════════════════
# ① Extract Splice Signals
# ═══════════════════════════════════════════════════════════════════════════════
def extract_splice_signals(model: MutationPredictorCNN_v2,
ref_seq: str, mut_seq: str,
exon_flag: int, intron_flag: int) -> SpliceSignals:
enc = encode_for_v2(ref_seq, mut_seq, exon_flag, intron_flag)
# ── base forward pass ────────────────────────────────────────────────────
with torch.no_grad():
x = enc.unsqueeze(0)
logit, imp_t, r_imp_t, s_imp_t = model(x)
prob = float(torch.sigmoid(logit).item())
imp_score = float(imp_t.item())
        region_imp = r_imp_t[0].numpy()
        splice_imp = s_imp_t[0].numpy()
tier, tier_desc = _classify_risk_tier(prob)
mutation_pos = find_mutation_pos(ref_seq, mut_seq)
# ── conv3 activation norm ────────────────────────────────────────────────
    conv3_norm = _conv3_activation_norm(
        model, enc,
        lambda x_in: model(x_in.unsqueeze(0))
    )
# ── gradient attribution ─────────────────────────────────────────────────
def _fwd_grad(leaf: torch.Tensor):
logit_g, _, _, _ = model(leaf.unsqueeze(0))
return logit_g
grad_attr = _gradient_attribution(model, enc, _fwd_grad)
# ── mutation-peak derived metrics ─────────────────────────────────────────
mpr = _mutation_peak_ratio(conv3_norm, mutation_pos)
sci = _signal_concentration_index(conv3_norm, mutation_pos)
# ── splice distances ─────────────────────────────────────────────────────
dist_d, dist_a, nearest_d, nearest_a = _splice_distances(ref_seq, mutation_pos)
risk_d = _classify_splice_risk(dist_d)
risk_a = _classify_splice_risk(dist_a)
    # ── splice aura score: proximity-weighted composite ─────────────────────
def _proximity_weight(dist):
if dist is None: return 0.0
if dist <= 2: return 1.0
if dist <= 8: return 0.5
return 0.1
aura = (
_proximity_weight(dist_d) * float(splice_imp[0]) +
_proximity_weight(dist_a) * float(splice_imp[1]) +
float(splice_imp[2]) * 0.3
) / 1.6 # normalise to ~[0,1]
aura = float(np.clip(aura, 0.0, 1.0))
# ── counterfactual analysis ───────────────────────────────────────────────
cf = _counterfactual_splice(model, ref_seq, mut_seq, mutation_pos,
exon_flag, intron_flag, prob)
# ── feature ablation ─────────────────────────────────────────────────────
abl = _ablation_splice(model, enc, prob)
return SpliceSignals(
probability=prob, risk_tier=tier, tier_desc=tier_desc,
conv3_norm=conv3_norm, gradient_attribution=grad_attr,
mutation_pos=mutation_pos,
mutation_peak_ratio=mpr, signal_concentration=sci,
imp_score=imp_score, region_imp=region_imp, splice_imp=splice_imp,
dist_donor=dist_d, dist_acceptor=dist_a,
nearest_donor=nearest_d, nearest_acceptor=nearest_a,
splice_risk_donor=risk_d, splice_risk_acceptor=risk_a,
counterfactual=cf, ablation=abl,
splice_aura_score=aura,
)
def _counterfactual_splice(model: MutationPredictorCNN_v2,
ref_seq: str, mut_seq: str,
mutation_pos: int, exon_flag: int,
intron_flag: int, orig_prob: float) -> dict:
if mutation_pos < 0 or mutation_pos >= len(ref_seq):
return {"error": "mutation position not detected",
"original_probability": orig_prob}
ref_base = ref_seq[mutation_pos].upper()
results = []
for alt in ALL_BASES:
if alt == ref_base:
continue
alt_mut = ref_seq[:mutation_pos] + alt + ref_seq[mutation_pos+1:]
enc_cf = encode_for_v2(ref_seq, alt_mut, exon_flag, intron_flag)
with torch.no_grad():
logit_cf, _, _, _ = model(enc_cf.unsqueeze(0))
p = float(torch.sigmoid(logit_cf).item())
results.append({"mutation": f"{ref_base}>{alt}", "alt_base": alt,
"probability": round(p, 4)})
all_probs = [r["probability"] for r in results] + [orig_prob]
return {
"original_probability": round(orig_prob, 4),
"ref_base": ref_base,
"table": sorted(results, key=lambda x: x["probability"], reverse=True),
"max_probability": round(max(all_probs), 4),
"min_probability": round(min(all_probs), 4),
"probability_range": round(max(all_probs) - min(all_probs), 4),
"counterfactual_delta": round(abs(max(all_probs) - min(all_probs)), 4),
}
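# Interpretation note: "counterfactual_delta" is the spread between the highest
# and lowest probabilities across the original call and every alternative base,
# i.e. how far the prediction can move by changing only the mutated base.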
def _ablation_splice(model: MutationPredictorCNN_v2,
enc: torch.Tensor, prob_base: float) -> dict:
def _prob(e):
with torch.no_grad():
logit, _, _, _ = model(e.unsqueeze(0))
return float(torch.sigmoid(logit).item())
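    # Flat-encoding layout, inferred from the slice indices below (the exact
    # packing is assumed to follow encode_for_v2):
    #   [0:1089)     sequence context, 99 positions x 11 channels
    #   [1089:1101)  mutation-type one-hot (12 entries; presumably MUT_TYPES)
    #   [1101:1103)  region flags (exon / intron)
    #   [1103:1106)  splice features (donor / acceptor / region)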
enc_no_splice = enc.clone(); enc_no_splice[1103:1106] = 0.0
enc_no_region = enc.clone(); enc_no_region[1101:1103] = 0.0
enc_no_mut = enc.clone(); enc_no_mut[1089:1101] = 0.0
enc_no_seq = enc.clone(); enc_no_seq[:1089] = 0.0
d_splice = round(abs(prob_base - _prob(enc_no_splice)), 4)
d_region = round(abs(prob_base - _prob(enc_no_region)), 4)
d_mut = round(abs(prob_base - _prob(enc_no_mut)), 4)
d_seq = round(abs(prob_base - _prob(enc_no_seq)), 4)
total = d_splice + d_region + d_mut + d_seq
def _pct(v): return round(v / total * 100, 1) if total > 0 else 0.0
return {
"baseline_probability": round(prob_base, 4),
"splice_delta": d_splice, "splice_pct": _pct(d_splice),
"region_delta": d_region, "region_pct": _pct(d_region),
"mutation_delta": d_mut, "mutation_pct": _pct(d_mut),
"sequence_delta": d_seq, "sequence_pct": _pct(d_seq),
"dominant_feature": max(
[("Splice features", d_splice), ("Region flags", d_region),
("Mutation type", d_mut), ("Sequence context", d_seq)],
key=lambda x: x[1]
)[0],
}
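# Interpretation note: each ablation delta measures how far the probability
# moves when one feature group is zeroed out; the *_pct fields express each
# delta as a share of the summed deltas, so "dominant_feature" names the group
# this particular prediction leaned on most.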
# ═══════════════════════════════════════════════════════════════════════════════
# ② Extract V4 Signals
# ═══════════════════════════════════════════════════════════════════════════════
def extract_v4_signals(model: MutationPredictorCNN_v4,
ref_seq: str, mut_seq: str,
exon_flag: int, intron_flag: int) -> V4Signals:
seq_t, mut_oh, region_t, splice_t = encode_for_v4(ref_seq, mut_seq,
exon_flag, intron_flag)
# ── base forward ─────────────────────────────────────────────────────────
with torch.no_grad():
logit = model(seq_t, mut_oh, region_t, splice_t)
prob = float(torch.sigmoid(logit).item())
mutation_pos = find_mutation_pos(ref_seq, mut_seq)
# ── conv3 activation norm ────────────────────────────────────────────────
    conv3_norm = _conv3_activation_norm(
        model, seq_t.squeeze(0),
        lambda s: model(s.unsqueeze(0), mut_oh, region_t, splice_t)
    )
    # ── gradient attribution (through the sequence tensor only) ─────────────
model.eval()
seq_leaf = seq_t.clone().detach().requires_grad_(True)
logit_g = model(seq_leaf, mut_oh, region_t, splice_t)
model.zero_grad()
logit_g.backward()
grad = seq_leaf.grad # (1, 11, 99)
if grad is not None:
# L2 norm per position across 11 channels
grad_attr = grad.squeeze(0).abs().norm(dim=0).numpy() # (99,)
if grad_attr.max() > 0:
grad_attr = grad_attr / grad_attr.max()
else:
grad_attr = np.zeros(99)
mpr = _mutation_peak_ratio(conv3_norm, mutation_pos)
sci = _signal_concentration_index(conv3_norm, mutation_pos)
return V4Signals(
probability=prob,
conv3_norm=conv3_norm, gradient_attribution=grad_attr,
mutation_pos=mutation_pos,
mutation_peak_ratio=mpr, signal_concentration=sci,
)
# ═══════════════════════════════════════════════════════════════════════════════
# ③ Extract Classic Signals
# ═══════════════════════════════════════════════════════════════════════════════
def extract_classic_signals(model: MutationPredictorClassic,
ref_seq: str, mut_seq: str,
exon_flag: int, intron_flag: int) -> ClassicSignals:
enc = encode_for_v2(ref_seq, mut_seq, exon_flag, intron_flag)
# ── base forward ─────────────────────────────────────────────────────────
with torch.no_grad():
x = enc.unsqueeze(0)
logit, imp_t, r_imp_t = model(x)
prob = float(torch.sigmoid(logit).item())
imp_score = float(imp_t.item())
        region_imp = r_imp_t[0].numpy()
mutation_pos = find_mutation_pos(ref_seq, mut_seq)
# ── conv3 activation norm ────────────────────────────────────────────────
conv3_norm = _conv3_activation_norm(
model, enc,
lambda x: model(x.unsqueeze(0))
)
mpr = _mutation_peak_ratio(conv3_norm, mutation_pos)
sci = _signal_concentration_index(conv3_norm, mutation_pos)
return ClassicSignals(
probability=prob,
conv3_norm=conv3_norm,
importance_head=imp_score,
region_imp=region_imp,
mutation_pos=mutation_pos,
mutation_peak_ratio=mpr,
signal_concentration=sci,
)
# ═══════════════════════════════════════════════════════════════════════════════
# Cross-model analysis
# ═══════════════════════════════════════════════════════════════════════════════
def compute_cross_model_analysis(splice: SpliceSignals,
v4: V4Signals,
classic: ClassicSignals) -> dict:
"""
Compute all five XAI Engine metrics and cross-model locality score.
"""
    # 1. Mutation Peak Ratio: average across models
mpr_avg = float(np.mean([
splice.mutation_peak_ratio,
v4.mutation_peak_ratio,
classic.mutation_peak_ratio,
]))
    # 2. Counterfactual magnitude: from the splice model (has full CF data)
cf_mag = float(splice.counterfactual.get("counterfactual_delta", 0.0))
# 3. Cross-model locality score
# Are activation peaks aligned across models?
# Compute correlation of all three conv3_norm profiles.
profiles = [splice.conv3_norm, v4.conv3_norm, classic.conv3_norm]
cors = []
for i in range(len(profiles)):
for j in range(i+1, len(profiles)):
a, b = profiles[i], profiles[j]
if a.std() > 0 and b.std() > 0:
cors.append(float(np.corrcoef(a, b)[0, 1]))
else:
cors.append(0.0)
cross_locality = float(np.clip(np.mean(cors), -1.0, 1.0))
    # 4. Signal concentration index: average across models
sci_avg = float(np.mean([
splice.signal_concentration,
v4.signal_concentration,
classic.signal_concentration,
]))
# 5. Explainability Strength Score (0–1)
    mpr_norm = float(np.clip(mpr_avg / 3.0, 0.0, 1.0))   # >3× peak = full score
cf_norm = float(np.clip(cf_mag, 0.0, 1.0))
loc_norm = float(np.clip((cross_locality + 1.0) / 2.0, 0.0, 1.0))
ess = (0.35 * mpr_norm + 0.35 * cf_norm + 0.30 * loc_norm)
ess = float(np.clip(ess, 0.0, 1.0))
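    # Worked example (hypothetical values): mpr_avg = 4.5 -> mpr_norm = 1.0,
    # cf_mag = 0.30 -> cf_norm = 0.30, cross_locality = 0.8 -> loc_norm = 0.9,
    # giving ess = 0.35*1.0 + 0.35*0.30 + 0.30*0.9 = 0.725.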
# Activation pattern type
    peak = float(np.max(splice.conv3_norm))
    above_half = int(np.sum(splice.conv3_norm > 0.5 * peak)) if peak > 0 else 0
if above_half <= 5:
pattern = "Sharp"
elif above_half <= 25:
pattern = "Broad"
else:
pattern = "Flat"
# Per-model probability agreement
probs = [splice.probability, v4.probability, classic.probability]
prob_std = float(np.std(probs))
return {
"mutation_peak_ratio": round(mpr_avg, 4),
"counterfactual_magnitude": round(cf_mag, 4),
"cross_model_locality_score": round(cross_locality, 4),
"signal_concentration_index": round(sci_avg, 4),
"explainability_strength_score": round(ess, 4),
"activation_pattern_type": pattern,
"prob_std": round(prob_std, 4),
"model_agreement": _agreement_level(prob_std),
# raw profiles for plotting
"_splice_norm": splice.conv3_norm,
"_v4_norm": v4.conv3_norm,
"_classic_norm": classic.conv3_norm,
"_splice_grad": splice.gradient_attribution,
"_v4_grad": v4.gradient_attribution,
}
def _agreement_level(std: float) -> str:
if std < 0.05: return "Strong"
if std < 0.12: return "Moderate"
return "Weak"