# Mutation_XAI / explainability_renderer.py
# Author: nileshhanotia — commit c938648 (verified)
"""explainability_renderer.py — PeVe v1.1"""
from __future__ import annotations

import html

import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.colors import LinearSegmentedColormap

from config import (BAND_COLORS, WINDOW_BP, SPLICE_PROB_HIGH, SPLICE_PROB_MODERATE,
                    ACTIVATION_NORM_HIGH, ACTIVATION_NORM_MODERATE, BIOCHEMICAL_RISK_ACTIVE)
from decision_engine import SynthesisResult, SpliceLayerOutput, ContextLayerOutput, ProteinLayerOutput
# Badge colour for each dominant-mechanism label (looked up by
# render_summary_card; unknown mechanisms fall back to grey there).
MECH_COLORS = dict(
    RNA_Splicing="#d73027",
    Protein_Biochemical="#4575b4",
    Sequence_Context="#74add1",
    Mechanism_Ambiguity="#f46d43",
    Protein_Truncation="#313695",
    Insufficient_Evidence="#aaaaaa",
    Conflict_Manual_Review="#762a83",
    Out_Of_Scope="#cccccc",
)
def _variant_label(chrom, pos, ref, alt):
    """Return the ``chrN:pos REF>ALT`` header string for a variant.

    The ``chr`` prefix is added only when ``chrom`` does not already carry it
    (case-insensitive), so ``"17"`` and ``"chr17"`` both render as ``chr17``
    instead of the latter becoming ``chrchr17``.
    """
    chrom_str = str(chrom)
    if not chrom_str.lower().startswith("chr"):
        chrom_str = f"chr{chrom_str}"
    return f"{chrom_str}:{pos} {ref}>{alt}"


def render_summary_card(result, chrom, pos, ref, alt):
    """Render a compact text card summarising the synthesis result for one variant.

    Args:
        result: synthesis result; this function reads ``dominant_mechanism``,
            ``final_classification`` and ``conflict_report.requires_manual_review``.
        chrom, pos, ref, alt: variant coordinates shown in the card header.
            ``chrom`` may be given with or without a ``chr`` prefix.

    Returns:
        A matplotlib Figure with the axes hidden and four lines of text.
    """
    needs_review = result.conflict_report.requires_manual_review
    fig, ax = plt.subplots(figsize=(8, 2.8))
    ax.axis("off")
    fig.patch.set_facecolor("#f8f9fa")
    ax.set_facecolor("#f8f9fa")

    # Badge colour for the mechanism; unknown mechanisms fall back to grey.
    mcolor = MECH_COLORS.get(result.dominant_mechanism, "#888")
    rcolor = "#d73027" if needs_review else "#1a9850"

    ax.text(0.02, 0.88, _variant_label(chrom, pos, ref, alt), transform=ax.transAxes,
            fontsize=13, fontweight="bold", color="#222")
    ax.text(0.02, 0.62, f"Dominant: {result.dominant_mechanism.replace('_',' ')}",
            transform=ax.transAxes, fontsize=11, color="white",
            bbox=dict(facecolor=mcolor, boxstyle="round,pad=0.3", edgecolor="none"))
    ax.text(0.02, 0.35, f"Classification: {result.final_classification}",
            transform=ax.transAxes, fontsize=10, color="#333")
    label = "⛔ MANUAL REVIEW REQUIRED" if needs_review else "✓ No major conflicts"
    ax.text(0.02, 0.12, label, transform=ax.transAxes, fontsize=9, color=rcolor, fontstyle="italic")

    fig.tight_layout()
    return fig
def _prepare_saliency(saliency_map, width):
    """Min-max normalise a saliency vector and resample it to ``width`` points.

    Returns None when ``saliency_map`` is None or empty. A constant map is
    returned un-normalised (there is no range to scale by). When the length
    differs from ``width`` the values are linearly interpolated over a shared
    [0, 1] coordinate.
    """
    if saliency_map is None or len(saliency_map) == 0:
        return None
    sal = np.array(saliency_map, dtype=float)
    lo, hi = sal.min(), sal.max()
    if hi > lo:
        sal = (sal - lo) / (hi - lo)
    if len(sal) != width:
        sal = np.interp(np.linspace(0, 1, width), np.linspace(0, 1, len(sal)), sal)
    return sal


def render_saliency_heatmap(splice, ref, alt):
    """Draw the RNA saliency map as a one-row heatmap across the input window.

    Args:
        splice: splice-layer output; this function reads ``saliency_map`` and
            ``splice_prob``.
        ref, alt: alleles used in the mutation-marker legend label.

    Returns:
        A matplotlib Figure. When no saliency map is available a placeholder
        message is shown instead of the heatmap.
    """
    fig, ax = plt.subplots(figsize=(10, 1.8))
    sal = _prepare_saliency(splice.saliency_map, WINDOW_BP)
    if sal is not None:
        cmap = LinearSegmentedColormap.from_list("sal", ["#f7fbff", "#6baed6", "#08519c", "#d73027"])
        ax.imshow(sal.reshape(1, -1), aspect="auto", cmap=cmap, vmin=0, vmax=1,
                  extent=[0, WINDOW_BP, 0, 1])
    else:
        ax.text(0.5, 0.5, "Saliency map unavailable", ha="center", va="center",
                transform=ax.transAxes, color="#aaa")
        ax.set_facecolor("#f0f0f0")

    # Pin the x-range in both branches. Without an image the axes keep
    # matplotlib's default range, so the marker below would fall outside it
    # and the axis would be rescaled around that single x value.
    ax.set_xlim(0, WINDOW_BP)
    # The mutation sits at the centre of the window.
    ax.axvline(x=WINDOW_BP // 2, color="#d73027", linewidth=2.5, linestyle="--", label=f"{ref}>{alt}")
    ax.set_xlabel(f"Position in {WINDOW_BP}bp window", fontsize=9)
    ax.set_yticks([])
    ax.set_title(f"RNA Saliency | splice_prob={splice.splice_prob:.3f}", fontsize=10)
    ax.legend(loc="upper right", fontsize=8)
    fig.tight_layout()
    return fig
def _activation_profile(norm, peak, width, sigma=30.0):
    """Return ``(x, profile)`` for a Gaussian bump of height ``norm`` centred on ``peak``.

    ``x`` is ``np.arange(width)`` and
    ``profile = norm * exp(-0.5 * ((x - peak) / sigma) ** 2)``.
    """
    x = np.arange(width)
    return x, norm * np.exp(-0.5 * ((x - peak) / sigma) ** 2)


def render_activation_peak(context, ref, alt):
    """Plot the sequence-context activation over the input window.

    NOTE: the plotted curve is *synthetic*. It is a Gaussian (sigma = 30
    positions) built only from two scalars: ``context.activation_norm`` sets
    the height and ``context.activation_peak_position`` sets the centre. It
    illustrates those two numbers and is not a per-position activation trace.

    Args:
        context: context-layer output; this function reads ``activation_norm``
            and ``activation_peak_position``.
        ref, alt: alleles used in the mutation-marker legend label.

    Returns:
        A matplotlib Figure with the profile, the mutation and peak markers,
        and the moderate/high activation thresholds.
    """
    norm = context.activation_norm
    peak = context.activation_peak_position
    x, profile = _activation_profile(norm, peak, WINDOW_BP)

    fig, ax = plt.subplots(figsize=(10, 2.2))
    ax.fill_between(x, profile, alpha=0.35, color="#4575b4")
    ax.plot(x, profile, color="#4575b4", linewidth=1.5, label="Activation profile")
    ax.axvline(x=WINDOW_BP // 2, color="#d73027", linewidth=2, linestyle="--", label=f"Mutation ({ref}>{alt})")
    ax.axvline(x=peak, color="#1a9850", linewidth=1.5, linestyle=":", label=f"Peak (pos={peak})")
    ax.axhline(y=ACTIVATION_NORM_MODERATE, color="#fc8d59", linewidth=1, linestyle="--", alpha=0.7,
               label=f"Active thresh ({ACTIVATION_NORM_MODERATE})")
    ax.axhline(y=ACTIVATION_NORM_HIGH, color="#d73027", linewidth=1, linestyle="--", alpha=0.7,
               label=f"High thresh ({ACTIVATION_NORM_HIGH})")

    ax.set_xlim(0, WINDOW_BP)
    # Leave headroom above the curve and keep both threshold lines visible
    # even when a threshold exceeds 1.0.
    ax.set_ylim(0, max(1.0, norm + 0.1,
                       ACTIVATION_NORM_MODERATE + 0.1, ACTIVATION_NORM_HIGH + 0.1))
    ax.set_xlabel(f"Position in {WINDOW_BP}bp window", fontsize=9)
    ax.set_ylabel("Activation", fontsize=9)
    ax.set_title(f"Sequence Context Activation | norm={norm:.3f}, peak={peak}", fontsize=10)
    ax.legend(loc="upper right", fontsize=7)
    fig.tight_layout()
    return fig
def render_shap_bar(protein):
    """Draw a horizontal bar chart of the protein layer's SHAP contributions.

    Args:
        protein: protein-layer output; this function reads
            ``shap_feature_contributions`` (a mapping of feature name to
            contribution) and ``biochemical_risk_score``.

    Returns:
        A matplotlib Figure. Positive contributions are drawn red
        ("Pathogenic") and zero/negative ones blue ("Benign"). An empty or
        missing mapping yields a placeholder figure.
    """
    shap = protein.shap_feature_contributions
    if not shap:
        fig, ax = plt.subplots(figsize=(6, 2))
        ax.text(0.5, 0.5, "SHAP values unavailable", ha="center", va="center",
                transform=ax.transAxes, color="#aaa")
        ax.axis("off")
        fig.tight_layout()
        return fig

    feats = list(shap)
    vals = [shap[f] for f in feats]
    colors = ["#d73027" if v > 0 else "#4575b4" for v in vals]

    # Grow the figure with the number of features so bars stay readable.
    fig, ax = plt.subplots(figsize=(7, max(2.5, 0.5 * len(feats) + 1)))
    bars = ax.barh(feats, vals, color=colors, edgecolor="white", height=0.6)
    ax.axvline(x=0, color="#333", linewidth=1)
    ax.set_xlabel("SHAP contribution (positive=pathogenic)", fontsize=9)
    ax.set_title(f"Layer 3 Features | biochemical_risk={protein.biochemical_risk_score:.3f}", fontsize=10)

    # Offset each value label by a fixed number of points from the bar end.
    # A data-unit offset would vary with the SHAP value scale: it would
    # overlap bars for tiny values and be invisible for large ones.
    for bar, v in zip(bars, vals):
        ax.annotate(f"{v:+.3f}", xy=(v, bar.get_y() + bar.get_height() / 2),
                    xytext=(3 if v >= 0 else -3, 0), textcoords="offset points",
                    va="center", ha="left" if v >= 0 else "right", fontsize=8)
    # Leave horizontal room so the labels are not pushed off the axes.
    ax.margins(x=0.15)

    ax.legend(handles=[mpatches.Patch(color="#d73027", label="Pathogenic"),
                       mpatches.Patch(color="#4575b4", label="Benign")],
              fontsize=8, loc="lower right")
    fig.tight_layout()
    return fig
def _value_label_layout(value, inside_threshold=0.4):
    """Return ``(x, ha, color)`` for the numeric label of a 0-1 gauge bar.

    When the bar is longer than ``inside_threshold``, the label is white and
    right-aligned just inside the bar end, so it sits on the coloured bar.
    Otherwise it is dark and left-aligned just past the bar end, so it sits
    on the light track. The anchor position is clamped to [0, 1] to match the
    gauge axis.
    """
    shown = min(max(value, 0.0), 1.0)
    if shown > inside_threshold:
        return shown - 0.01, "right", "white"
    return shown + 0.01, "left", "#333"


def render_band_gauges(result, splice, context, protein):
    """Draw three side-by-side 0-1 gauges for the splice, context and protein layers.

    Args:
        result: synthesis result; this function reads
            ``activation_levels.splice_band``, ``.context_band`` and
            ``.protein_active``.
        splice: reads ``splice_prob``.
        context: reads ``activation_norm``.
        protein: reads ``biochemical_risk_score``.

    Returns:
        A matplotlib Figure. Each gauge shows the value as a bar coloured by
        its band (via ``BAND_COLORS``), dashed threshold markers, and the band
        name. The x-range is fixed to [0, 1], so out-of-range bars are clipped.
    """
    fig, axes = plt.subplots(1, 3, figsize=(10, 2))
    fig.patch.set_facecolor("#f8f9fa")
    levels = result.activation_levels
    datasets = [
        ("RNA Splice", splice.splice_prob,
         [(SPLICE_PROB_HIGH, "High"), (SPLICE_PROB_MODERATE, "Moderate")], levels.splice_band),
        ("Seq Context", context.activation_norm,
         [(ACTIVATION_NORM_HIGH, "High"), (ACTIVATION_NORM_MODERATE, "Moderate")], levels.context_band),
        ("Protein", protein.biochemical_risk_score,
         [(BIOCHEMICAL_RISK_ACTIVE, "Active")],
         "Active" if levels.protein_active else "Inactive"),
    ]
    for ax, (title, value, bands, cband) in zip(axes, datasets):
        ax.set_facecolor("#f8f9fa")
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.axis("off")
        ax.set_title(title, fontsize=9, pad=4)
        bar_color = BAND_COLORS.get(cband, "#888")
        # Grey full-width track, then the value bar drawn on top of it.
        ax.barh(0.3, 1.0, height=0.25, color="#e0e0e0", left=0, align="edge")
        ax.barh(0.3, value, height=0.25, color=bar_color, left=0, align="edge", alpha=0.85)
        # Only the numeric threshold is drawn; the band name is not used here.
        for thresh, _band_name in bands:
            ax.axvline(x=thresh, color="#333", linewidth=1.2, linestyle="--", alpha=0.7)
            ax.text(thresh, 0.58, f"{thresh}", ha="center", fontsize=6.5, color="#333")
        label_x, label_ha, label_color = _value_label_layout(value)
        ax.text(label_x, 0.3 + 0.125, f"{value:.3f}", va="center", ha=label_ha,
                fontsize=8, fontweight="bold", color=label_color)
        ax.text(0.5, 0.08, cband, ha="center", fontsize=9, color=bar_color,
                fontweight="bold", transform=ax.transAxes)
    fig.suptitle("Mechanism Activation Bands", fontsize=10, y=1.02)
    fig.tight_layout()
    return fig
def _strip_tier_prefix(text, prefix):
    """Drop a leading tier tag such as ``"MAJOR: "`` from a conflict message.

    Only a leading occurrence is removed; the same text appearing later in
    the message is kept.
    """
    text = str(text)
    return text[len(prefix):] if text.startswith(prefix) else text


def render_conflict_table(result):
    """Render the conflict report as an HTML snippet (header line plus table).

    Args:
        result: synthesis result; this function reads ``conflict_report``'s
            ``major_conflicts``, ``minor_conflicts``, ``requires_manual_review``,
            ``conflict_score_major`` and ``conflict_score_minor``.

    Returns:
        An HTML string with one table row per conflict (major rows first), or
        a single "No conflicts detected" row when both lists are empty.
        Conflict messages are HTML-escaped, so characters such as ``<``,
        ``>`` and ``&`` in them display literally instead of breaking the markup.
    """
    report = result.conflict_report
    tiers = (
        (report.major_conflicts, "MAJOR: ", "#d73027", "⛔ MAJOR"),
        (report.minor_conflicts, "MINOR: ", "#fc8d59", "⚠ MINOR"),
    )
    rows = []
    for conflicts, prefix, color, tag in tiers:
        for c in conflicts or ():
            desc = html.escape(_strip_tier_prefix(c, prefix))
            rows.append(f'<tr><td style="color:{color};font-weight:bold;padding:4px 8px">{tag}</td>'
                        f'<td style="padding:4px 8px">{desc}</td></tr>')
    if not rows:
        rows = ['<tr><td colspan="2" style="color:#1a9850;padding:4px 8px">✓ No conflicts detected</td></tr>']

    hc = "#d73027" if report.requires_manual_review else "#1a9850"
    rt = "MANUAL REVIEW REQUIRED" if report.requires_manual_review else "No Review Required"
    body = "".join(rows)
    return f"""<div style="font-family:monospace;font-size:13px">
<div style="color:{hc};font-weight:bold;margin-bottom:8px">
{rt} ({report.conflict_score_major} major, {report.conflict_score_minor} minor)
</div>
<table style="border-collapse:collapse;width:100%;background:#fafafa">
<thead><tr style="background:#eee"><th style="padding:4px 8px;text-align:left">Tier</th>
<th style="padding:4px 8px;text-align:left">Description</th></tr></thead>
<tbody>{body}</tbody>
</table></div>"""