""" app.py ====== Mutation Explainability Intelligence System Gradio Space — explanation-first clinical variant analysis Three models: nileshhanotia/mutation-predictor-splice nileshhanotia/mutation-predictor-v4 nileshhanotia/mutation-pathogenicity-predictor Explanation ALWAYS precedes prediction panel. """ from __future__ import annotations import io import json import logging import os import sys import tempfile import traceback import gradio as gr import numpy as np import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import matplotlib.gridspec as gridspec from matplotlib.colors import LinearSegmentedColormap def _fig_to_pil(fig): """Render matplotlib figure to PIL Image — required for gr.Image in Gradio 4.44.""" buf = io.BytesIO() fig.savefig(buf, format="png", dpi=110, bbox_inches="tight", facecolor=fig.get_facecolor()) buf.seek(0) from PIL import Image as _PILImage img = _PILImage.open(buf).copy() plt.close(fig) return img import requests import time from functools import lru_cache # ── project imports ─────────────────────────────────────────────────────────── from model_loader import ( ModelRegistry, encode_for_v2, find_mutation_pos, ) from explainability_engine import ( extract_splice_signals, extract_v4_signals, extract_classic_signals, compute_cross_model_analysis, ) from decision_engine import build_decision, DecisionResult logging.basicConfig( level=logging.INFO, format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s", ) logger = logging.getLogger("mutation_xai") # ── Global registry (lazy) ──────────────────────────────────────────────────── REGISTRY = ModelRegistry(hf_token=os.environ.get("HF_TOKEN")) # ── Ensembl fetch ───────────────────────────────────────────────────────────── ENSEMBL_URL = "https://rest.ensembl.org/sequence/region/human" WINDOW_HALF = 49 # 49 + 1 + 49 = 99 bp (matches all three models) @lru_cache(maxsize=256) def _fetch_ensembl(chrom: str, start: int, end: int) -> str: chrom = chrom.lstrip("chrCHR").strip() region = f"{chrom}:{start}..{end}:1" url = f"{ENSEMBL_URL}/{region}" for attempt in range(3): try: r = requests.get(url, params={"content-type": "application/json"}, timeout=15) if r.status_code == 429: time.sleep(int(r.headers.get("Retry-After", 5))) continue r.raise_for_status() data = r.json() if isinstance(data, list): data = data[0] return data.get("seq", "").upper() except Exception as e: if attempt == 2: raise RuntimeError(f"Ensembl API failed: {e}") time.sleep(1.5 * (2 ** attempt)) return "" def fetch_window(chrom: str, pos: int) -> tuple[str, str, int]: """ Returns (ref_seq_99bp, mut_seq_placeholder, mutation_pos_in_window). Caller must insert the alt base into mut_seq at mutation_pos. """ chrom_clean = chrom.lstrip("chrCHR").strip() start = max(1, pos - WINDOW_HALF) end = pos + WINDOW_HALF seq = _fetch_ensembl(chrom_clean, start, end) if len(seq) < 1: raise ValueError(f"Empty sequence returned for chr{chrom}:{start}-{end}") # Pad/trim to 99 seq = (seq + "N" * 99)[:99] mut_pos = pos - start # 0-indexed position within window mut_pos = max(0, min(98, mut_pos)) return seq, mut_pos # ═══════════════════════════════════════════════════════════════════════════════ # Visualisation helpers # ═══════════════════════════════════════════════════════════════════════════════ _BG = "#0D1117" _TEXT = "#E6EDF3" _MUTED = "#7D8590" _BLUE = "#58A6FF" _GREEN = "#3FB950" _RED = "#F85149" _ORG = "#D29922" _CMAP_ACTIVATION = LinearSegmentedColormap.from_list( "act", [(0.04,0.22,0.47),(0.96,0.96,0.96),(0.72,0.05,0.12)], N=256) _CMAP_SPLICE = LinearSegmentedColormap.from_list( "splice", [(0.0,"#f7f7f7"),(0.3,"#fee08b"),(0.6,"#fc8d59"),(1.0,"#d73027")]) def _fig_base(w=15, h=2.8): fig, ax = plt.subplots(figsize=(w, h), facecolor=_BG) ax.set_facecolor(_BG) return fig, ax def _style_ax(ax, title): ax.set_title(title, color=_TEXT, fontsize=9, loc="left", pad=4, fontweight="bold") for sp in ["top","right"]: ax.spines[sp].set_visible(False) ax.spines["left"].set_color("#333") ax.spines["bottom"].set_color("#333") ax.tick_params(colors=_TEXT, labelsize=7) def plot_activation_heatmap(profile: np.ndarray, mutation_pos: int, label: str, prob: float): imp = profile.copy() if imp.max() > 0: imp /= imp.max() fig, ax = _fig_base(15, 2.5) im = ax.imshow(imp[np.newaxis,:], aspect="auto", cmap=_CMAP_ACTIVATION, vmin=0, vmax=1, extent=[-0.5, 98.5, 0, 1]) if mutation_pos >= 0: ax.axvline(x=mutation_pos, color=_GREEN, linewidth=2.0, linestyle="--", label=f"Mutation pos {mutation_pos}") ax.legend(fontsize=8, facecolor=_BG, labelcolor=_TEXT, framealpha=0.6, loc="upper right") cb = fig.colorbar(im, ax=ax, pad=0.01) cb.set_label("Activation intensity", color=_TEXT, fontsize=8) cb.ax.tick_params(colors=_TEXT, labelsize=7) ax.set_xlabel("Nucleotide position (99 bp window)", color=_TEXT, fontsize=9) ax.set_xticks(range(0, 99, 10)) ax.set_yticks([]) _style_ax(ax, f"CNN conv3 Activation — {label} (prob={prob:.4f})") fig.tight_layout() return _fig_to_pil(fig) def plot_splice_heatmap(ref_seq: str, mutation_pos: int): seq = (ref_seq.upper() + "N" * 99)[:99] scores = np.zeros(99) donors, acceptors = [], [] for i in range(len(seq)-1): if seq[i:i+2] == "GT": donors.append(i) if seq[i:i+2] == "AG": acceptors.append(i) for p in donors: for d in range(-8,9): if 0 <= p+d < 99: scores[p+d] = max(scores[p+d], 0.5) for p in acceptors: for d in range(-8,9): if 0 <= p+d < 99: scores[p+d] = max(scores[p+d], 0.5) for p in donors: if 0 <= p < 99: scores[p] = 1.0 for p in acceptors: if 0 <= p < 99: scores[p] = max(scores[p], 0.8) fig, ax = _fig_base(15, 2.5) im = ax.imshow(scores[np.newaxis,:], aspect="auto", cmap=_CMAP_SPLICE, vmin=0, vmax=1, extent=[-0.5, 98.5, 0, 1]) if mutation_pos >= 0: ax.axvline(x=mutation_pos, color=_BLUE, linewidth=2.0, linestyle="--", label=f"Mutation pos {mutation_pos}") ax.legend(fontsize=8, facecolor=_BG, labelcolor=_TEXT, framealpha=0.6, loc="upper right") cb = fig.colorbar(im, ax=ax, pad=0.01) cb.set_label("Splice risk", color=_TEXT, fontsize=8) cb.ax.tick_params(colors=_TEXT, labelsize=7) ax.set_xlabel("Nucleotide position (99 bp window)", color=_TEXT, fontsize=9) ax.set_xticks(range(0, 99, 10)) ax.set_yticks([]) _style_ax(ax, "Splice Distance Risk — GT donor / AG acceptor signals") fig.tight_layout() return _fig_to_pil(fig) def plot_gradient_heatmap(attr: np.ndarray, mutation_pos: int, label: str): fig, ax = _fig_base(15, 2.5) im = ax.imshow(attr[np.newaxis,:], aspect="auto", cmap="PuOr", vmin=0, vmax=1, extent=[-0.5, 98.5, 0, 1]) if mutation_pos >= 0: ax.axvline(x=mutation_pos, color=_GREEN, linewidth=2.0, linestyle="--", label=f"Mutation pos {mutation_pos}") ax.legend(fontsize=8, facecolor=_BG, labelcolor=_TEXT, framealpha=0.6, loc="upper right") cb = fig.colorbar(im, ax=ax, pad=0.01) cb.set_label("Gradient attribution", color=_TEXT, fontsize=8) cb.ax.tick_params(colors=_TEXT, labelsize=7) ax.set_xlabel("Nucleotide position", color=_TEXT, fontsize=9) ax.set_xticks(range(0, 99, 10)) ax.set_yticks([]) _style_ax(ax, f"Gradient Attribution Map — {label}") fig.tight_layout() return _fig_to_pil(fig) def plot_counterfactual(cf_table: list[dict], orig_prob: float, cf_delta: float): if not cf_table: fig, ax = plt.subplots(figsize=(8, 3), facecolor=_BG) ax.text(0.5, 0.5, "No counterfactual data", ha="center", va="center", color=_TEXT, fontsize=12) ax.axis("off") return _fig_to_pil(fig) labels = [r["mutation"] for r in cf_table] probs = [r["probability"] for r in cf_table] max_p, min_p = max(probs), min(probs) colors = [_RED if p == max_p else (_BLUE if p == min_p else "#74add1") for p in probs] fig, ax = plt.subplots(figsize=(10, 3.5), facecolor=_BG) ax.set_facecolor(_BG) bars = ax.bar(labels, probs, color=colors, edgecolor="#444", linewidth=0.7) ax.axhline(0.5, color=_MUTED, linestyle="--", linewidth=1.0, label="Decision boundary (0.5)") ax.axhline(orig_prob, color=_ORG, linestyle="-.", linewidth=1.5, label=f"Original mutation ({orig_prob:.3f})") ax.set_ylim(0, 1.05) ax.set_xlabel("Alternative mutation", color=_TEXT, fontsize=10) ax.set_ylabel("Pathogenicity probability", color=_TEXT, fontsize=10) ax.tick_params(colors=_TEXT) ax.legend(fontsize=8, facecolor=_BG, labelcolor=_TEXT, framealpha=0.5) for b, p in zip(bars, probs): ax.text(b.get_x() + b.get_width()/2, b.get_height()+0.01, f"{p:.3f}", ha="center", va="bottom", fontsize=8, color=_TEXT) for sp in ["top","right"]: ax.spines[sp].set_visible(False) ax.spines["left"].set_color("#333") ax.spines["bottom"].set_color("#333") ax.set_title( f"Counterfactual Analysis | Δ={cf_delta:.4f} | " f"range {min_p:.3f}–{max_p:.3f}", color=_TEXT, fontsize=10, loc="left") fig.tight_layout() return _fig_to_pil(fig) def plot_ablation(ablation: dict): labels = [ "Splice features\n(donor/acceptor/region)", "Region features\n(exon/intron flags)", "Mutation type\n(one-hot)", ] deltas = [ablation["splice_causal_effect"], ablation["region_causal_effect"], ablation["mutation_causal_effect"]] pcts = [ablation["splice_pct"], ablation["region_pct"], ablation["mutation_pct"]] colors = [_RED, _ORG, _BLUE] fig, ax = plt.subplots(figsize=(9, 3.0), facecolor=_BG) ax.set_facecolor(_BG) bars = ax.barh(labels, deltas, color=colors, edgecolor="#444", linewidth=0.6) ax.set_xlabel("Probability delta when ablated (causal effect)", color=_TEXT, fontsize=9) ax.tick_params(colors=_TEXT, labelsize=8) ax.set_title( f"Feature Ablation | baseline prob={ablation['baseline_probability']:.4f}", color=_TEXT, fontsize=10, loc="left") for b, d, p in zip(bars, deltas, pcts): ax.text(b.get_width()+0.002, b.get_y()+b.get_height()/2, f" Δ{d:.4f} ({p}%)", va="center", fontsize=9, color=_TEXT) ax.set_xlim(0, max(deltas+[0.01]) * 1.6) for sp in ["top","right"]: ax.spines[sp].set_visible(False) ax.spines["left"].set_color("#333") ax.spines["bottom"].set_color("#333") fig.tight_layout() return _fig_to_pil(fig) def plot_xai_metrics(xai): """Radar-style bar chart of explainability metrics.""" labels = ["Model\nAgreement", "XAI\nStrength", "CF\nMagnitude", "Locality\nScore", "Concentration\nIndex"] values = [ xai.model_agreement, xai.explainability_strength, min(xai.counterfactual_magnitude / 0.4, 1.0), xai.cross_model_locality_score, xai.signal_concentration_index, ] colors = [_GREEN if v >= 0.65 else (_ORG if v >= 0.40 else _RED) for v in values] fig, ax = plt.subplots(figsize=(10, 3.0), facecolor=_BG) ax.set_facecolor(_BG) bars = ax.bar(labels, values, color=colors, edgecolor="#444", linewidth=0.6, width=0.5) ax.axhline(0.65, color=_GREEN, linestyle="--", linewidth=0.8, alpha=0.6, label="High (≥0.65)") ax.axhline(0.40, color=_ORG, linestyle="--", linewidth=0.8, alpha=0.6, label="Moderate (≥0.40)") ax.set_ylim(0, 1.1) ax.set_ylabel("Score (0–1)", color=_TEXT, fontsize=9) ax.tick_params(colors=_TEXT, labelsize=8) ax.legend(fontsize=8, facecolor=_BG, labelcolor=_TEXT, framealpha=0.4, loc="upper right") for b, v in zip(bars, values): ax.text(b.get_x()+b.get_width()/2, b.get_height()+0.02, f"{v:.3f}", ha="center", fontsize=9, color=_TEXT) for sp in ["top","right"]: ax.spines[sp].set_visible(False) ax.spines["left"].set_color("#333") ax.spines["bottom"].set_color("#333") ax.set_title("Explainability Metrics Panel", color=_TEXT, fontsize=10, loc="left") fig.tight_layout() return _fig_to_pil(fig) # ═══════════════════════════════════════════════════════════════════════════════ # Core pipeline # ═══════════════════════════════════════════════════════════════════════════════ def run_pipeline( chrom: str, position: str, ref_base: str, alt_base: str, exon_flag: int, intron_flag: int, ): """Main Gradio callback. Returns all outputs.""" chrom = chrom.strip() ref_base = ref_base.strip().upper() alt_base = alt_base.strip().upper() try: pos = int(position.strip().replace(",","")) except ValueError: return _error(f"Invalid position: '{position}'") for b, name in [(ref_base,"Reference"),(alt_base,"Alternate")]: if b not in "ACGT" or len(b) != 1: return _error(f"{name} base must be A, C, G, or T. Got: '{b}'") if ref_base == alt_base: return _error("Reference and alternate bases are identical.") try: ref_seq, mutation_pos = fetch_window(chrom, pos) # Validate reference base actual_ref = ref_seq[mutation_pos].upper() if actual_ref != ref_base: return _error( f"Reference mismatch at chr{chrom}:{pos}: " f"genome has '{actual_ref}', you entered '{ref_base}'." ) # Build mutated sequence mut_seq = ref_seq[:mutation_pos] + alt_base + ref_seq[mutation_pos+1:] splice_sig = extract_splice_signals( REGISTRY.splice, ref_seq, mut_seq, exon_flag, intron_flag) v4_sig = extract_v4_signals( REGISTRY.v4, ref_seq, mut_seq, exon_flag, intron_flag) classic_sig = extract_classic_signals( REGISTRY.classic, ref_seq, mut_seq) xai = compute_cross_model_analysis(splice_sig, v4_sig, classic_sig, mutation_pos) result = build_decision( chrom=chrom, pos=pos, ref=ref_base, alt=alt_base, ref_seq=ref_seq, mut_seq=mut_seq, mutation_pos=mutation_pos, splice=splice_sig, v4=v4_sig, classic=classic_sig, xai=xai, ) plots = _build_all_plots(result) json_str = result.to_json() json_file = _write_json_file(json_str) demo_banner = ( "\n> ⚠️ **DEMO MODE** — models are running with random weights. " "Place real checkpoints or ensure HF_TOKEN is set.\n" if REGISTRY.demo_mode else "" ) summary_md = _build_summary_md(result, demo_banner) return ( summary_md, # 0: explanation summary (FIRST) result.final_explanation, # 1: final explanation text plots["xai_metrics"], # 2: XAI metrics panel plots["splice_activation"], # 3: splice conv3 heatmap plots["splice_heatmap"], # 4: splice distance heatmap plots["v4_activation"], # 5: v4 conv3 heatmap plots["classic_activation"], # 6: classic conv3 heatmap plots["v4_gradient"], # 7: v4 gradient attribution plots["splice_gradient"], # 8: splice gradient attribution plots["counterfactual"], # 9: counterfactual chart plots["ablation"], # 10: ablation chart json_str, # 11: JSON report json_file, # 12: download file ) except Exception as exc: logger.error("Pipeline error: %s\n%s", exc, traceback.format_exc()) return _error(f"Error: {exc}\n\n```\n{traceback.format_exc()}\n```") def _build_all_plots(r: DecisionResult) -> dict: mp = r.mutation_pos return { "xai_metrics": plot_xai_metrics(r.xai), "splice_activation": plot_activation_heatmap( r.splice.conv3_profile, mp, "Splice Model", r.splice.probability), "splice_heatmap": plot_splice_heatmap(r.ref_seq, mp), "v4_activation": plot_activation_heatmap( r.v4.conv3_profile, mp, "V4 Model", r.v4.probability), "classic_activation": plot_activation_heatmap( r.classic.conv3_profile, mp, "Classic Model", r.classic.probability), "v4_gradient": plot_gradient_heatmap( r.v4.gradient_attribution, mp, "V4 Model"), "splice_gradient": plot_gradient_heatmap( r.splice.gradient_attribution, mp, "Splice Model"), "counterfactual": plot_counterfactual( r.splice.counterfactual_table, r.splice.probability, r.splice.counterfactual_delta), "ablation": plot_ablation(r.splice.ablation), } def _write_json_file(json_str: str) -> str: tmp = tempfile.NamedTemporaryFile(suffix=".json", delete=False, mode="w", encoding="utf-8") tmp.write(json_str) tmp.close() return tmp.name def _build_summary_md(r: DecisionResult, demo_banner: str) -> str: mech_icon = { "Splice-driven": "🔀", "Protein-driven": "🧬", "Consensus": "✅", "Ambiguous": "⚠️", }.get(r.dominant_mechanism, "❓") tier_icon = { "PATHOGENIC": "🔴", "LIKELY PATHOGENIC": "🟠", "POSSIBLY PATHOGENIC": "🟡", "LIKELY BENIGN": "🟢", "BENIGN": "🟢", }.get(r.risk_tier, "⚪") conf_icon = {"High": "🔵", "Moderate": "🟡", "Low": "🔴"}.get(r.confidence, "⚪") return f"""{demo_banner} ## {tier_icon} Risk Tier: **{r.risk_tier}** | Field | Value | |---|---| | **Variant** | `chr{r.chrom}:g.{r.pos}{r.ref}>{r.alt}` | | **Unified Probability** | `{r.unified_probability:.4f}` | | **Dominant Mechanism** | {mech_icon} {r.dominant_mechanism} | | **Confidence** | {conf_icon} {r.confidence} | | **Splice Model** | `{r.splice.probability:.4f}` — {r.splice.risk_tier} | | **V4 Model** | `{r.v4.probability:.4f}` | | **Classic Model** | `{r.classic.probability:.4f}` | --- ### Explainability Metrics | Metric | Value | |---|---| | **Mutation Peak Ratio** | `{r.xai.mutation_peak_ratio:.4f}` | | **Counterfactual Magnitude** | `{r.xai.counterfactual_magnitude:.4f}` | | **Cross-Model Locality** | `{r.xai.cross_model_locality_score:.4f}` | | **Signal Concentration** | `{r.xai.signal_concentration_index:.4f}` | | **XAI Strength Score** | `{r.xai.explainability_strength:.4f}` | | **Activation Pattern** | `{r.xai.activation_pattern_type}` | | **Model Agreement** | `{r.xai.model_agreement:.4f}` | --- ### Interpretation Briefs **Splice:** {r.splice_analysis[:300]}{'…' if len(r.splice_analysis)>300 else ''} **Protein:** {r.protein_analysis[:250]}{'…' if len(r.protein_analysis)>250 else ''} **Agreement:** {r.agreement_analysis[:250]}{'…' if len(r.agreement_analysis)>250 else ''} """ def _error(msg: str): empties = [None] * 9 return ( f"❌ **Error**\n\n{msg}", "", empty, empty, empty, empty, empty, empty, empty, empty, empty, "{}", None, ) # ═══════════════════════════════════════════════════════════════════════════════ # Gradio UI # ═══════════════════════════════════════════════════════════════════════════════ CSS = """ @import url('https://fonts.googleapis.com/css2?family