nileshhanotia's picture
Update app.py
1245dad verified
"""
app.py
======
Mutation Explainability Intelligence System
Gradio Space — explanation-first clinical variant analysis
Three models:
nileshhanotia/mutation-predictor-splice
nileshhanotia/mutation-predictor-v4
nileshhanotia/mutation-pathogenicity-predictor
Explanation ALWAYS precedes prediction panel.
"""
from __future__ import annotations
import io
import json
import logging
import os
import sys
import tempfile
import traceback
import gradio as gr
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.colors import LinearSegmentedColormap
def _fig_to_pil(fig):
    """Render a matplotlib figure to a PIL image — required for gr.Image in Gradio 4.44.

    The figure is written as a PNG (dpi=110, tight bounding box, the figure's
    own face colour) into an in-memory buffer, decoded with Pillow, and then
    closed so that figures do not accumulate between requests.

    Args:
        fig: The matplotlib figure to rasterise. It is closed before this
            function returns, whether or not rendering succeeded.

    Returns:
        A PIL image that is independent of the temporary buffer.
    """
    try:
        # Imported lazily so Pillow is only required once a plot is rendered.
        from PIL import Image as _PILImage

        with io.BytesIO() as buf:
            fig.savefig(buf, format="png", dpi=110, bbox_inches="tight",
                        facecolor=fig.get_facecolor())
            buf.seek(0)
            # copy() detaches the image from ``buf`` before the buffer is closed.
            return _PILImage.open(buf).copy()
    finally:
        # Release the figure even when savefig/decoding fails.
        plt.close(fig)
import requests
import time
from functools import lru_cache
# ── project imports ───────────────────────────────────────────────────────────
from model_loader import (
ModelRegistry,
encode_for_v2,
find_mutation_pos,
)
from explainability_engine import (
extract_splice_signals,
extract_v4_signals,
extract_classic_signals,
compute_cross_model_analysis,
)
from decision_engine import build_decision, DecisionResult
# Configure root logging once at import time: INFO level, with timestamp,
# padded level name, logger name and message separated by " | ".
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
)
# Module logger used by the pipeline for error reporting.
logger = logging.getLogger("mutation_xai")
# ── Global registry (lazy) ────────────────────────────────────────────────────
# Shared model registry. HF_TOKEN is read from the environment and is None
# when the variable is not set.
REGISTRY = ModelRegistry(hf_token=os.environ.get("HF_TOKEN"))
# ── Ensembl fetch ─────────────────────────────────────────────────────────────
# Base URL of the Ensembl REST endpoint used to download reference sequence.
ENSEMBL_URL = "https://rest.ensembl.org/sequence/region/human"
# Number of flanking bases fetched on each side of the variant position.
WINDOW_HALF = 49 # 49 + 1 + 49 = 99 bp (matches all three models)
@lru_cache(maxsize=256)
def _fetch_ensembl(chrom: str, start: int, end: int) -> str:
    """Download the forward-strand sequence ``chrom:start..end`` from Ensembl.

    Successful results are memoised by ``lru_cache`` (keyed on the three
    arguments). Failures are raised as exceptions and therefore never cached,
    so a transient rate limit cannot poison the cache with an empty sequence.

    Args:
        chrom: Chromosome name, with or without a leading ``chr`` prefix.
        start: First base of the region (passed to Ensembl as given).
        end: Last base of the region (passed to Ensembl as given).

    Returns:
        The upper-cased sequence from the response's ``seq`` field, or ``""``
        if the response carries no ``seq`` field.

    Raises:
        RuntimeError: If the request still fails on the third attempt, or if
            the server keeps answering HTTP 429 for all three attempts.
    """
    chrom = chrom.strip()
    # Remove a literal "chr" prefix in any letter case. str.lstrip("chrCHR")
    # strips a *set of characters*, not a prefix, so it is avoided here.
    if chrom[:3].lower() == "chr":
        chrom = chrom[3:].strip()
    url = f"{ENSEMBL_URL}/{chrom}:{start}..{end}:1"

    max_attempts = 3
    for attempt in range(max_attempts):
        is_last = attempt == max_attempts - 1
        try:
            response = requests.get(
                url, params={"content-type": "application/json"}, timeout=15)
            if response.status_code == 429:
                if not is_last:
                    # Retry-After may be fractional (e.g. "40.0"), so parse
                    # it as float and fall back to 5 s when it is malformed.
                    try:
                        delay = float(response.headers.get("Retry-After", 5))
                    except (TypeError, ValueError):
                        delay = 5.0
                    time.sleep(max(0.0, delay))
                continue
            response.raise_for_status()
            payload = response.json()
            if isinstance(payload, list):
                payload = payload[0]
            return payload.get("seq", "").upper()
        except Exception as exc:
            # Network boundary: retry with exponential backoff, and surface
            # the last failure as RuntimeError with the cause chained.
            if is_last:
                raise RuntimeError(f"Ensembl API failed: {exc}") from exc
            time.sleep(1.5 * (2 ** attempt))
    # Only reachable when every attempt was rate limited.
    raise RuntimeError(
        f"Ensembl API failed: still rate limited (HTTP 429) after {max_attempts} attempts")
def fetch_window(chrom: str, pos: int) -> tuple[str, int]:
    """Fetch the reference window around a 1-based genomic position.

    The window spans ``WINDOW_HALF`` bases on each side of ``pos``. When
    ``pos`` is closer than ``WINDOW_HALF`` to the start of the chromosome the
    window starts at base 1, so the variant is no longer centred; the returned
    index accounts for that. Short sequences are right-padded with ``N``.

    Args:
        chrom: Chromosome name, with or without a leading ``chr`` prefix.
        pos: 1-based position of the variant.

    Returns:
        ``(ref_seq, mutation_pos)`` where ``ref_seq`` has exactly
        ``2 * WINDOW_HALF + 1`` characters and ``mutation_pos`` is the
        0-based index of ``pos`` inside it. The caller builds the mutated
        sequence by substituting the alternate base at ``mutation_pos``.

    Raises:
        ValueError: If ``pos`` is smaller than 1 or Ensembl returned an empty
            sequence.
    """
    # Reject invalid coordinates up front instead of silently clamping the
    # index to 0 and analysing the wrong base.
    if pos < 1:
        raise ValueError(f"Position must be >= 1 (1-based coordinate). Got: {pos}")

    chrom_clean = chrom.strip()
    # Strip a literal "chr" prefix; lstrip("chrCHR") would strip a character set.
    if chrom_clean[:3].lower() == "chr":
        chrom_clean = chrom_clean[3:].strip()

    window_len = 2 * WINDOW_HALF + 1
    start = max(1, pos - WINDOW_HALF)
    end = pos + WINDOW_HALF
    seq = _fetch_ensembl(chrom_clean, start, end)
    if not seq:
        raise ValueError(f"Empty sequence returned for chr{chrom_clean}:{start}-{end}")

    # Pad/trim to the fixed window length.
    seq = (seq + "N" * window_len)[:window_len]
    # 0-indexed position within the window, kept inside its bounds.
    mut_pos = max(0, min(window_len - 1, pos - start))
    return seq, mut_pos
# ═══════════════════════════════════════════════════════════════════════════════
# Visualisation helpers
# ═══════════════════════════════════════════════════════════════════════════════
# Shared dark-theme palette for every chart in this module.
_BG = "#0D1117"  # figure and axes background
_TEXT = "#E6EDF3"  # titles, labels and tick text
_MUTED = "#7D8590"  # de-emphasised reference lines
_BLUE = "#58A6FF"  # accent colour
_GREEN = "#3FB950"  # accent colour
_RED = "#F85149"  # accent colour
_ORG = "#D29922"  # accent colour (orange)
# Colormap for activation strips, built from three RGB anchors
# (dark blue, near-white, dark red) and sampled at 256 levels.
_CMAP_ACTIVATION = LinearSegmentedColormap.from_list(
    "act", [(0.04,0.22,0.47),(0.96,0.96,0.96),(0.72,0.05,0.12)], N=256)
# Colormap for the splice-risk strip: (value, colour) stops at 0.0, 0.3, 0.6 and 1.0.
_CMAP_SPLICE = LinearSegmentedColormap.from_list(
    "splice", [(0.0,"#f7f7f7"),(0.3,"#fee08b"),(0.6,"#fc8d59"),(1.0,"#d73027")])
def _fig_base(w=15, h=2.8):
    """Create a single-axes figure of ``w`` x ``h`` inches on the dark background.

    Returns:
        The ``(figure, axes)`` pair, both with their face colour set to ``_BG``.
    """
    figure, axes = plt.subplots(figsize=(w, h), facecolor=_BG)
    axes.set_facecolor(_BG)
    return figure, axes
def _style_ax(ax, title):
    """Apply the shared title, spine and tick styling to ``ax``.

    The title is left-aligned and bold, the top and right spines are hidden,
    the remaining two spines are drawn in dark grey, and ticks use the
    palette's text colour.
    """
    ax.set_title(title, color=_TEXT, fontsize=9, loc="left", pad=4, fontweight="bold")
    for hidden_side in ("top", "right"):
        ax.spines[hidden_side].set_visible(False)
    for visible_side in ("left", "bottom"):
        ax.spines[visible_side].set_color("#333")
    ax.tick_params(colors=_TEXT, labelsize=7)
def plot_activation_heatmap(profile: np.ndarray, mutation_pos: int,
                            label: str, prob: float):
    """Draw a one-row heatmap of an activation profile.

    The profile is scaled by its maximum (when that maximum is positive) and
    drawn across a fixed x-axis spanning positions 0..98, whatever its own
    length. A dashed vertical line marks the mutation.

    Args:
        profile: 1-D activation values; the caller's array is not modified.
        mutation_pos: Window index of the mutation; no marker is drawn when
            it is negative.
        label: Model name shown in the title.
        prob: Model probability shown in the title.

    Returns:
        The rendered chart as a PIL image.
    """
    # Work on a float copy: this protects the caller's array and keeps the
    # in-place division valid for integer-typed input.
    imp = np.array(profile, dtype=float)
    if imp.max() > 0:
        imp /= imp.max()
    fig, ax = _fig_base(15, 2.5)
    im = ax.imshow(imp[np.newaxis, :], aspect="auto", cmap=_CMAP_ACTIVATION,
                   vmin=0, vmax=1, extent=[-0.5, 98.5, 0, 1])
    if mutation_pos >= 0:
        ax.axvline(x=mutation_pos, color=_GREEN, linewidth=2.0, linestyle="--",
                   label=f"Mutation pos {mutation_pos}")
        ax.legend(fontsize=8, facecolor=_BG, labelcolor=_TEXT, framealpha=0.6,
                  loc="upper right")
    cb = fig.colorbar(im, ax=ax, pad=0.01)
    cb.set_label("Activation intensity", color=_TEXT, fontsize=8)
    cb.ax.tick_params(colors=_TEXT, labelsize=7)
    ax.set_xlabel("Nucleotide position (99 bp window)", color=_TEXT, fontsize=9)
    ax.set_xticks(range(0, 99, 10))
    ax.set_yticks([])
    _style_ax(ax, f"CNN conv3 Activation — {label} (prob={prob:.4f})")
    fig.tight_layout()
    return _fig_to_pil(fig)
def plot_splice_heatmap(ref_seq: str, mutation_pos: int):
    """Draw a heuristic splice-risk strip for the 99 bp reference window.

    The score is derived purely from string matching on ``ref_seq`` (no model
    is involved). Every ``GT`` dinucleotide counts as a donor signal and every
    ``AG`` as an acceptor signal:

    * 1.0 at the first base of a ``GT``,
    * 0.8 at the first base of an ``AG``,
    * 0.5 within 8 positions either side of any such site,
    * 0.0 everywhere else.

    Args:
        ref_seq: Reference sequence; upper-cased and padded/trimmed to 99
            characters with ``N``.
        mutation_pos: Window index of the mutation; no marker is drawn when
            it is negative.

    Returns:
        The rendered chart as a PIL image.
    """
    seq = (ref_seq.upper() + "N" * 99)[:99]
    donors = [i for i in range(len(seq) - 1) if seq[i:i + 2] == "GT"]
    acceptors = [i for i in range(len(seq) - 1) if seq[i:i + 2] == "AG"]

    scores = np.zeros(99)
    # Flanking region: +/- 8 positions around every donor or acceptor site.
    for site in donors + acceptors:
        lo, hi = max(0, site - 8), min(99, site + 9)
        scores[lo:hi] = np.maximum(scores[lo:hi], 0.5)
    # The sites themselves outrank the flanks.
    for site in donors:
        scores[site] = 1.0
    for site in acceptors:
        scores[site] = max(scores[site], 0.8)

    fig, ax = _fig_base(15, 2.5)
    im = ax.imshow(scores[np.newaxis, :], aspect="auto", cmap=_CMAP_SPLICE,
                   vmin=0, vmax=1, extent=[-0.5, 98.5, 0, 1])
    if mutation_pos >= 0:
        ax.axvline(x=mutation_pos, color=_BLUE, linewidth=2.0, linestyle="--",
                   label=f"Mutation pos {mutation_pos}")
        ax.legend(fontsize=8, facecolor=_BG, labelcolor=_TEXT, framealpha=0.6,
                  loc="upper right")
    cb = fig.colorbar(im, ax=ax, pad=0.01)
    cb.set_label("Splice risk", color=_TEXT, fontsize=8)
    cb.ax.tick_params(colors=_TEXT, labelsize=7)
    ax.set_xlabel("Nucleotide position (99 bp window)", color=_TEXT, fontsize=9)
    ax.set_xticks(range(0, 99, 10))
    ax.set_yticks([])
    _style_ax(ax, "Splice Distance Risk — GT donor / AG acceptor signals")
    fig.tight_layout()
    return _fig_to_pil(fig)
def plot_gradient_heatmap(attr: np.ndarray, mutation_pos: int, label: str):
    """Draw a one-row heatmap of per-position gradient attribution.

    Args:
        attr: 1-D attribution values, drawn across a fixed x-axis spanning
            positions 0..98.
        mutation_pos: Window index of the mutation; no marker is drawn when
            it is negative.
        label: Model name shown in the title.

    Returns:
        The rendered chart as a PIL image.
    """
    fig, ax = _fig_base(15, 2.5)
    # NOTE(review): the colour scale is pinned to [0, 1], so this presumably
    # expects ``attr`` to be normalised already — confirm against the caller.
    im = ax.imshow(attr[np.newaxis, :], aspect="auto", cmap="PuOr",
                   vmin=0, vmax=1, extent=[-0.5, 98.5, 0, 1])
    if mutation_pos >= 0:
        ax.axvline(x=mutation_pos, color=_GREEN, linewidth=2.0, linestyle="--",
                   label=f"Mutation pos {mutation_pos}")
        ax.legend(fontsize=8, facecolor=_BG, labelcolor=_TEXT, framealpha=0.6,
                  loc="upper right")
    cb = fig.colorbar(im, ax=ax, pad=0.01)
    cb.set_label("Gradient attribution", color=_TEXT, fontsize=8)
    cb.ax.tick_params(colors=_TEXT, labelsize=7)
    ax.set_xlabel("Nucleotide position", color=_TEXT, fontsize=9)
    ax.set_xticks(range(0, 99, 10))
    ax.set_yticks([])
    _style_ax(ax, f"Gradient Attribution Map — {label}")
    fig.tight_layout()
    return _fig_to_pil(fig)
def plot_counterfactual(cf_table: list[dict], orig_prob: float, cf_delta: float):
    """Bar chart of pathogenicity probability for each alternative mutation.

    The highest bar is drawn in red, the lowest in blue and the rest in a
    neutral blue. Horizontal lines mark the 0.5 decision boundary and the
    probability of the original mutation.

    Args:
        cf_table: Rows with a ``"mutation"`` label and a ``"probability"``
            value. When empty, a placeholder image is returned instead.
        orig_prob: Probability of the original mutation.
        cf_delta: Counterfactual delta shown in the title.

    Returns:
        The rendered chart as a PIL image.
    """
    if not cf_table:
        fig, ax = plt.subplots(figsize=(8, 3), facecolor=_BG)
        ax.text(0.5, 0.5, "No counterfactual data", ha="center", va="center",
                color=_TEXT, fontsize=12)
        ax.axis("off")
        return _fig_to_pil(fig)

    labels = [row["mutation"] for row in cf_table]
    probs = [row["probability"] for row in cf_table]
    max_p, min_p = max(probs), min(probs)
    colors = [_RED if p == max_p else (_BLUE if p == min_p else "#74add1")
              for p in probs]

    fig, ax = plt.subplots(figsize=(10, 3.5), facecolor=_BG)
    ax.set_facecolor(_BG)
    bars = ax.bar(labels, probs, color=colors, edgecolor="#444", linewidth=0.7)
    ax.axhline(0.5, color=_MUTED, linestyle="--", linewidth=1.0,
               label="Decision boundary (0.5)")
    ax.axhline(orig_prob, color=_ORG, linestyle="-.", linewidth=1.5,
               label=f"Original mutation ({orig_prob:.3f})")
    ax.set_ylim(0, 1.05)
    ax.set_xlabel("Alternative mutation", color=_TEXT, fontsize=10)
    ax.set_ylabel("Pathogenicity probability", color=_TEXT, fontsize=10)
    ax.tick_params(colors=_TEXT)
    ax.legend(fontsize=8, facecolor=_BG, labelcolor=_TEXT, framealpha=0.5)
    # Print each probability just above its bar.
    for b, p in zip(bars, probs):
        ax.text(b.get_x() + b.get_width() / 2, b.get_height() + 0.01,
                f"{p:.3f}", ha="center", va="bottom", fontsize=8, color=_TEXT)
    for sp in ["top", "right"]:
        ax.spines[sp].set_visible(False)
    ax.spines["left"].set_color("#333")
    ax.spines["bottom"].set_color("#333")
    ax.set_title(
        f"Counterfactual Analysis | Δ={cf_delta:.4f} | "
        f"range {min_p:.3f}–{max_p:.3f}",
        color=_TEXT, fontsize=10, loc="left")
    fig.tight_layout()
    return _fig_to_pil(fig)
def plot_ablation(ablation: dict):
    """Horizontal bar chart of the probability change per ablated feature group.

    Args:
        ablation: Mapping that provides ``splice_causal_effect``,
            ``region_causal_effect``, ``mutation_causal_effect`` (bar
            lengths), ``splice_pct``, ``region_pct``, ``mutation_pct``
            (printed next to each bar) and ``baseline_probability`` (title).

    Returns:
        The rendered chart as a PIL image.
    """
    labels = [
        "Splice features\n(donor/acceptor/region)",
        "Region features\n(exon/intron flags)",
        "Mutation type\n(one-hot)",
    ]
    deltas = [ablation["splice_causal_effect"],
              ablation["region_causal_effect"],
              ablation["mutation_causal_effect"]]
    pcts = [ablation["splice_pct"], ablation["region_pct"], ablation["mutation_pct"]]
    colors = [_RED, _ORG, _BLUE]

    fig, ax = plt.subplots(figsize=(9, 3.0), facecolor=_BG)
    ax.set_facecolor(_BG)
    bars = ax.barh(labels, deltas, color=colors, edgecolor="#444", linewidth=0.6)
    ax.set_xlabel("Probability delta when ablated (causal effect)", color=_TEXT, fontsize=9)
    ax.tick_params(colors=_TEXT, labelsize=8)
    ax.set_title(
        f"Feature Ablation | baseline prob={ablation['baseline_probability']:.4f}",
        color=_TEXT, fontsize=10, loc="left")
    for b, d, p in zip(bars, deltas, pcts):
        # Anchor the label right of the zero line for negative bars so it
        # does not overlap the bar itself.
        ax.text(max(b.get_width(), 0.0) + 0.002, b.get_y() + b.get_height() / 2,
                f" Δ{d:.4f} ({p}%)", va="center", fontsize=9, color=_TEXT)
    # Leave head-room on the right for the labels, and extend the axis to the
    # left when an effect is negative so that bar stays visible. For
    # non-negative deltas the left limit remains 0.
    ax.set_xlim(min(0.0, min(deltas)) * 1.6, max(deltas + [0.01]) * 1.6)
    for sp in ["top", "right"]:
        ax.spines[sp].set_visible(False)
    ax.spines["left"].set_color("#333")
    ax.spines["bottom"].set_color("#333")
    fig.tight_layout()
    return _fig_to_pil(fig)
def plot_xai_metrics(xai):
    """Bar chart of five explainability scores on a 0–1 axis.

    The bars show ``model_agreement``, ``explainability_strength``,
    ``counterfactual_magnitude`` (divided by 0.4 and capped at 1.0),
    ``cross_model_locality_score`` and ``signal_concentration_index``.
    A bar is green at 0.65 or above, orange at 0.40 or above, and red below.

    Args:
        xai: Object exposing the five attributes listed above.

    Returns:
        The rendered chart as a PIL image.
    """
    labels = ["Model\nAgreement", "XAI\nStrength", "CF\nMagnitude",
              "Locality\nScore", "Concentration\nIndex"]
    values = [
        xai.model_agreement,
        xai.explainability_strength,
        min(xai.counterfactual_magnitude / 0.4, 1.0),
        xai.cross_model_locality_score,
        xai.signal_concentration_index,
    ]
    colors = [_GREEN if v >= 0.65 else (_ORG if v >= 0.40 else _RED) for v in values]

    fig, ax = plt.subplots(figsize=(10, 3.0), facecolor=_BG)
    ax.set_facecolor(_BG)
    bars = ax.bar(labels, values, color=colors, edgecolor="#444", linewidth=0.6, width=0.5)
    # Threshold guides matching the colour bands used for the bars.
    ax.axhline(0.65, color=_GREEN, linestyle="--", linewidth=0.8, alpha=0.6,
               label="High (≥0.65)")
    ax.axhline(0.40, color=_ORG, linestyle="--", linewidth=0.8, alpha=0.6,
               label="Moderate (≥0.40)")
    ax.set_ylim(0, 1.1)
    ax.set_ylabel("Score (0–1)", color=_TEXT, fontsize=9)
    ax.tick_params(colors=_TEXT, labelsize=8)
    ax.legend(fontsize=8, facecolor=_BG, labelcolor=_TEXT, framealpha=0.4, loc="upper right")
    for b, v in zip(bars, values):
        ax.text(b.get_x() + b.get_width() / 2, b.get_height() + 0.02,
                f"{v:.3f}", ha="center", fontsize=9, color=_TEXT)
    for sp in ["top", "right"]:
        ax.spines[sp].set_visible(False)
    ax.spines["left"].set_color("#333")
    ax.spines["bottom"].set_color("#333")
    ax.set_title("Explainability Metrics Panel", color=_TEXT, fontsize=10, loc="left")
    fig.tight_layout()
    return _fig_to_pil(fig)
# ═══════════════════════════════════════════════════════════════════════════════
# Core pipeline
# ═══════════════════════════════════════════════════════════════════════════════
def run_pipeline(
    chrom: str,
    position: str,
    ref_base: str,
    alt_base: str,
    exon_flag: int,
    intron_flag: int,
):
    """Main Gradio callback: validate the variant, run all models, build outputs.

    Args:
        chrom: Chromosome as typed by the user.
        position: Genomic position; thousands separators (``,``) are allowed.
        ref_base: Reference base; must be one of A, C, G, T.
        alt_base: Alternate base; must be one of A, C, G, T and differ from
            ``ref_base``.
        exon_flag: Passed through to the splice and V4 signal extractors.
        intron_flag: Passed through to the splice and V4 signal extractors.

    Returns:
        A 13-tuple in UI order: summary markdown, final explanation text,
        nine plot images, the JSON report string and the path of the JSON
        download file. On any validation or runtime failure the tuple
        produced by ``_error`` is returned instead; exceptions are not
        propagated to Gradio.
    """
    # Normalise defensively: a non-string or missing value from the UI must
    # produce a readable validation error rather than an AttributeError.
    chrom = str(chrom or "").strip()
    ref_base = str(ref_base or "").strip().upper()
    alt_base = str(alt_base or "").strip().upper()
    try:
        pos = int(str(position).strip().replace(",", ""))
    except ValueError:
        return _error(f"Invalid position: '{position}'")

    for base, name in ((ref_base, "Reference"), (alt_base, "Alternate")):
        if base not in {"A", "C", "G", "T"}:
            return _error(f"{name} base must be A, C, G, or T. Got: '{base}'")
    if ref_base == alt_base:
        return _error("Reference and alternate bases are identical.")

    try:
        ref_seq, mutation_pos = fetch_window(chrom, pos)

        # Validate the user's reference base against the genome.
        actual_ref = ref_seq[mutation_pos].upper()
        if actual_ref != ref_base:
            return _error(
                f"Reference mismatch at chr{chrom}:{pos}: "
                f"genome has '{actual_ref}', you entered '{ref_base}'."
            )

        # Build the mutated sequence by substituting the alternate base.
        mut_seq = ref_seq[:mutation_pos] + alt_base + ref_seq[mutation_pos + 1:]

        splice_sig = extract_splice_signals(
            REGISTRY.splice, ref_seq, mut_seq, exon_flag, intron_flag)
        v4_sig = extract_v4_signals(
            REGISTRY.v4, ref_seq, mut_seq, exon_flag, intron_flag)
        classic_sig = extract_classic_signals(
            REGISTRY.classic, ref_seq, mut_seq)
        xai = compute_cross_model_analysis(splice_sig, v4_sig, classic_sig, mutation_pos)

        result = build_decision(
            chrom=chrom, pos=pos, ref=ref_base, alt=alt_base,
            ref_seq=ref_seq, mut_seq=mut_seq, mutation_pos=mutation_pos,
            splice=splice_sig, v4=v4_sig, classic=classic_sig, xai=xai,
        )
        plots = _build_all_plots(result)
        json_str = result.to_json()
        json_file = _write_json_file(json_str)

        demo_banner = (
            "\n> ⚠️ **DEMO MODE** — models are running with random weights. "
            "Place real checkpoints or ensure HF_TOKEN is set.\n"
            if REGISTRY.demo_mode else ""
        )
        summary_md = _build_summary_md(result, demo_banner)

        return (
            summary_md,                    # 0: explanation summary (FIRST)
            result.final_explanation,      # 1: final explanation text
            plots["xai_metrics"],          # 2: XAI metrics panel
            plots["splice_activation"],    # 3: splice conv3 heatmap
            plots["splice_heatmap"],       # 4: splice distance heatmap
            plots["v4_activation"],        # 5: v4 conv3 heatmap
            plots["classic_activation"],   # 6: classic conv3 heatmap
            plots["v4_gradient"],          # 7: v4 gradient attribution
            plots["splice_gradient"],      # 8: splice gradient attribution
            plots["counterfactual"],       # 9: counterfactual chart
            plots["ablation"],             # 10: ablation chart
            json_str,                      # 11: JSON report
            json_file,                     # 12: download file
        )
    except Exception as exc:
        # Top-level UI boundary: log once and show the failure in the UI
        # instead of letting Gradio surface a bare stack trace.
        tb = traceback.format_exc()
        logger.error("Pipeline error: %s\n%s", exc, tb)
        return _error(f"Error: {exc}\n\n```\n{tb}\n```")
def _build_all_plots(r: DecisionResult) -> dict:
    """Render every chart for a decision result.

    Returns:
        A dict of rendered images keyed by ``xai_metrics``,
        ``splice_activation``, ``splice_heatmap``, ``v4_activation``,
        ``classic_activation``, ``v4_gradient``, ``splice_gradient``,
        ``counterfactual`` and ``ablation``. The counterfactual and ablation
        charts are built from the splice model's signals.
    """
    marker = r.mutation_pos
    rendered = {}
    rendered["xai_metrics"] = plot_xai_metrics(r.xai)
    rendered["splice_activation"] = plot_activation_heatmap(
        r.splice.conv3_profile, marker, "Splice Model", r.splice.probability)
    rendered["splice_heatmap"] = plot_splice_heatmap(r.ref_seq, marker)
    rendered["v4_activation"] = plot_activation_heatmap(
        r.v4.conv3_profile, marker, "V4 Model", r.v4.probability)
    rendered["classic_activation"] = plot_activation_heatmap(
        r.classic.conv3_profile, marker, "Classic Model", r.classic.probability)
    rendered["v4_gradient"] = plot_gradient_heatmap(
        r.v4.gradient_attribution, marker, "V4 Model")
    rendered["splice_gradient"] = plot_gradient_heatmap(
        r.splice.gradient_attribution, marker, "Splice Model")
    rendered["counterfactual"] = plot_counterfactual(
        r.splice.counterfactual_table,
        r.splice.probability,
        r.splice.counterfactual_delta)
    rendered["ablation"] = plot_ablation(r.splice.ablation)
    return rendered
def _write_json_file(json_str: str) -> str:
tmp = tempfile.NamedTemporaryFile(suffix=".json", delete=False,
mode="w", encoding="utf-8")
tmp.write(json_str)
tmp.close()
return tmp.name
def _build_summary_md(r: DecisionResult, demo_banner: str) -> str:
mech_icon = {
"Splice-driven": "πŸ”€",
"Protein-driven": "🧬",
"Consensus": "βœ…",
"Ambiguous": "⚠️",
}.get(r.dominant_mechanism, "❓")
tier_icon = {
"PATHOGENIC": "πŸ”΄",
"LIKELY PATHOGENIC": "🟠",
"POSSIBLY PATHOGENIC": "🟑",
"LIKELY BENIGN": "🟒",
"BENIGN": "🟒",
}.get(r.risk_tier, "βšͺ")
conf_icon = {"High": "πŸ”΅", "Moderate": "🟑", "Low": "πŸ”΄"}.get(r.confidence, "βšͺ")
return f"""{demo_banner}
## {tier_icon} Risk Tier: **{r.risk_tier}**
| Field | Value |
|---|---|
| **Variant** | `chr{r.chrom}:g.{r.pos}{r.ref}>{r.alt}` |
| **Unified Probability** | `{r.unified_probability:.4f}` |
| **Dominant Mechanism** | {mech_icon} {r.dominant_mechanism} |
| **Confidence** | {conf_icon} {r.confidence} |
| **Splice Model** | `{r.splice.probability:.4f}` β€” {r.splice.risk_tier} |
| **V4 Model** | `{r.v4.probability:.4f}` |
| **Classic Model** | `{r.classic.probability:.4f}` |
---
### Explainability Metrics
| Metric | Value |
|---|---|
| **Mutation Peak Ratio** | `{r.xai.mutation_peak_ratio:.4f}` |
| **Counterfactual Magnitude** | `{r.xai.counterfactual_magnitude:.4f}` |
| **Cross-Model Locality** | `{r.xai.cross_model_locality_score:.4f}` |
| **Signal Concentration** | `{r.xai.signal_concentration_index:.4f}` |
| **XAI Strength Score** | `{r.xai.explainability_strength:.4f}` |
| **Activation Pattern** | `{r.xai.activation_pattern_type}` |
| **Model Agreement** | `{r.xai.model_agreement:.4f}` |
---
### Interpretation Briefs
**Splice:** {r.splice_analysis[:300]}{'…' if len(r.splice_analysis)>300 else ''}
**Protein:** {r.protein_analysis[:250]}{'…' if len(r.protein_analysis)>250 else ''}
**Agreement:** {r.agreement_analysis[:250]}{'…' if len(r.agreement_analysis)>250 else ''}
"""
def _error(msg: str):
empties = [None] * 9
return (
f"❌ **Error**\n\n{msg}",
"", empty, empty, empty, empty, empty,
empty, empty, empty, empty,
"{}", None,
)
# ═══════════════════════════════════════════════════════════════════════════════
# Gradio UI
# ═══════════════════════════════════════════════════════════════════════════════
CSS = """
@import url('https://fonts.googleapis.com/css2?family