# File size: 8,684 Bytes (c938648)
"""explainability_renderer.py — PeVe v1.1"""
from __future__ import annotations
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.colors import LinearSegmentedColormap
from config import (BAND_COLORS, WINDOW_BP, SPLICE_PROB_HIGH, SPLICE_PROB_MODERATE,
                    ACTIVATION_NORM_HIGH, ACTIVATION_NORM_MODERATE, BIOCHEMICAL_RISK_ACTIVE)
from decision_engine import SynthesisResult, SpliceLayerOutput, ContextLayerOutput, ProteinLayerOutput

# Hex colour used for each dominant-mechanism label.
MECH_COLORS = dict(
    RNA_Splicing="#d73027",
    Protein_Biochemical="#4575b4",
    Sequence_Context="#74add1",
    Mechanism_Ambiguity="#f46d43",
    Protein_Truncation="#313695",
    Insufficient_Evidence="#aaaaaa",
    Conflict_Manual_Review="#762a83",
    Out_Of_Scope="#cccccc",
)

def render_summary_card(result, chrom, pos, ref, alt):
    """Draw a small text-only card summarising one variant's synthesis result.

    The card stacks four lines: the variant heading, the dominant mechanism in
    a colour-coded badge, the final classification, and a review flag.

    Args:
        result: Object exposing ``dominant_mechanism``, ``final_classification``
            and ``conflict_report.requires_manual_review``.
        chrom: Chromosome label, interpolated after ``chr`` in the heading.
        pos: Variant position shown in the heading.
        ref: Reference allele shown in the heading.
        alt: Alternate allele shown in the heading.

    Returns:
        The matplotlib figure containing the card.
    """
    background = "#f8f9fa"
    fig, ax = plt.subplots(figsize=(8, 2.8))
    ax.axis("off")
    fig.patch.set_facecolor(background)
    ax.set_facecolor(background)

    # Unknown mechanisms fall back to a neutral grey badge.
    badge_color = MECH_COLORS.get(result.dominant_mechanism, "#888")
    if result.conflict_report.requires_manual_review:
        flag_color = "#d73027"
        flag_text = "⛔ MANUAL REVIEW REQUIRED"
    else:
        flag_color = "#1a9850"
        flag_text = "✓ No major conflicts"

    ax.text(
        0.02, 0.88, f"chr{chrom}:{pos}  {ref}>{alt}",
        transform=ax.transAxes, fontsize=13, fontweight="bold", color="#222",
    )
    ax.text(
        0.02, 0.62, f"Dominant:  {result.dominant_mechanism.replace('_',' ')}",
        transform=ax.transAxes, fontsize=11, color="white",
        bbox=dict(facecolor=badge_color, boxstyle="round,pad=0.3", edgecolor="none"),
    )
    ax.text(
        0.02, 0.35, f"Classification:  {result.final_classification}",
        transform=ax.transAxes, fontsize=10, color="#333",
    )
    ax.text(
        0.02, 0.12, flag_text,
        transform=ax.transAxes, fontsize=9, color=flag_color, fontstyle="italic",
    )
    plt.tight_layout()
    return fig

def render_saliency_heatmap(splice, ref, alt):
    """Render the splice layer's saliency map as a one-row heatmap.

    The map is min-max scaled to [0, 1] and, when its length differs from
    ``WINDOW_BP``, linearly resampled to that length. If no map is available a
    placeholder message is drawn instead. A dashed marker labelled
    ``ref>alt`` is always drawn at the window midpoint.

    Args:
        splice: Object exposing ``saliency_map`` (sequence of numbers or
            ``None``) and ``splice_prob`` (shown in the title).
        ref: Reference allele, used in the marker label.
        alt: Alternate allele, used in the marker label.

    Returns:
        The matplotlib figure containing the heatmap.
    """
    fig, ax = plt.subplots(figsize=(10, 1.8))
    if splice.saliency_map is not None and len(splice.saliency_map) > 0:
        sal = np.array(splice.saliency_map, dtype=float)
        lo, hi = sal.min(), sal.max()
        # A flat map is left unscaled; scaling it would divide by zero.
        if hi > lo:
            sal = (sal - lo) / (hi - lo)
        if len(sal) != WINDOW_BP:
            # Stretch or shrink the map so it spans the whole window.
            src = np.linspace(0, 1, len(sal))
            dst = np.linspace(0, 1, WINDOW_BP)
            sal = np.interp(dst, src, sal)
        cmap = LinearSegmentedColormap.from_list(
            "sal", ["#f7fbff", "#6baed6", "#08519c", "#d73027"]
        )
        ax.imshow(
            sal.reshape(1, -1), aspect="auto", cmap=cmap,
            vmin=0, vmax=1, extent=[0, WINDOW_BP, 0, 1],
        )
    else:
        ax.text(0.5, 0.5, "Saliency map unavailable", ha="center", va="center",
                transform=ax.transAxes, color="#aaa")
        ax.set_facecolor("#f0f0f0")
    ax.axvline(x=WINDOW_BP // 2, color="#d73027", linewidth=2.5, linestyle="--",
               label=f"{ref}>{alt}")
    # Label follows the configured window size instead of a hard-coded 401.
    ax.set_xlabel(f"Position in {WINDOW_BP}bp window", fontsize=9)
    ax.set_yticks([])
    ax.set_title(f"RNA Saliency  |  splice_prob={splice.splice_prob:.3f}", fontsize=10)
    ax.legend(loc="upper right", fontsize=8)
    plt.tight_layout()
    return fig

def render_activation_peak(context, ref, alt):
    """Plot an illustrative activation profile for the sequence-context layer.

    The curve is not measured data: it is a Gaussian-shaped bump centred on
    ``context.activation_peak_position`` with a fixed width of 30 positions
    and a height of ``context.activation_norm``. Vertical markers show the
    window midpoint (labelled as the mutation) and the peak; horizontal
    markers show the moderate and high activation thresholds.

    Args:
        context: Object exposing ``activation_peak_position`` and
            ``activation_norm``.
        ref: Reference allele, used in the mutation marker label.
        alt: Alternate allele, used in the mutation marker label.

    Returns:
        The matplotlib figure containing the plot.
    """
    fig, ax = plt.subplots(figsize=(10, 2.2))
    x = np.arange(WINDOW_BP)
    peak = context.activation_peak_position
    norm = context.activation_norm
    profile = norm * np.exp(-0.5 * ((x - peak) / 30) ** 2)

    ax.fill_between(x, profile, alpha=0.35, color="#4575b4")
    ax.plot(x, profile, color="#4575b4", linewidth=1.5, label="Activation profile")
    ax.axvline(x=WINDOW_BP // 2, color="#d73027", linewidth=2, linestyle="--",
               label=f"Mutation ({ref}>{alt})")
    ax.axvline(x=peak, color="#1a9850", linewidth=1.5, linestyle=":",
               label=f"Peak (pos={peak})")
    ax.axhline(y=ACTIVATION_NORM_MODERATE, color="#fc8d59", linewidth=1, linestyle="--",
               alpha=0.7, label=f"Active thresh ({ACTIVATION_NORM_MODERATE})")
    ax.axhline(y=ACTIVATION_NORM_HIGH, color="#d73027", linewidth=1, linestyle="--",
               alpha=0.7, label=f"High thresh ({ACTIVATION_NORM_HIGH})")

    ax.set_xlim(0, WINDOW_BP)
    # Keep at least a 0..1 range, with headroom when the norm exceeds it.
    ax.set_ylim(0, max(1.0, norm + 0.1))
    # Label follows the configured window size instead of a hard-coded 401.
    ax.set_xlabel(f"Position in {WINDOW_BP}bp window", fontsize=9)
    ax.set_ylabel("Activation", fontsize=9)
    ax.set_title(
        f"Sequence Context Activation  |  norm={context.activation_norm:.3f}, peak={peak}",
        fontsize=10,
    )
    ax.legend(loc="upper right", fontsize=7)
    plt.tight_layout()
    return fig

def render_shap_bar(protein):
    """Draw a horizontal bar chart of per-feature SHAP contributions.

    Bars for values greater than zero are red, all others blue; each bar is
    annotated with its signed value. When no contributions are present a
    placeholder figure with a short message is returned instead.

    Args:
        protein: Object exposing ``shap_feature_contributions`` (mapping of
            feature name to value, possibly empty or ``None``) and
            ``biochemical_risk_score`` (shown in the title).

    Returns:
        The matplotlib figure containing the chart or the placeholder.
    """
    contributions = protein.shap_feature_contributions
    if not contributions:
        fig, ax = plt.subplots(figsize=(6, 2))
        ax.text(0.5, 0.5, "SHAP values unavailable", ha="center", va="center",
                transform=ax.transAxes, color="#aaa")
        ax.axis("off")
        plt.tight_layout()
        return fig

    red, blue = "#d73027", "#4575b4"
    names = list(contributions.keys())
    values = [contributions[name] for name in names]
    bar_colors = [red if value > 0 else blue for value in values]

    # Figure height grows with the number of features, never below 2.5 in.
    fig, ax = plt.subplots(figsize=(7, max(2.5, 0.5 * len(names) + 1)))
    bars = ax.barh(names, values, color=bar_colors, edgecolor="white", height=0.6)
    ax.axvline(x=0, color="#333", linewidth=1)
    ax.set_xlabel("SHAP contribution (positive=pathogenic)", fontsize=9)
    ax.set_title(
        f"Layer 3 Features  |  biochemical_risk={protein.biochemical_risk_score:.3f}",
        fontsize=10,
    )

    # Place each value label just beyond the end of its bar.
    for bar, value in zip(bars, values):
        if value >= 0:
            label_x, align = value + 0.005, "left"
        else:
            label_x, align = value - 0.005, "right"
        label_y = bar.get_y() + bar.get_height() / 2
        ax.text(label_x, label_y, f"{value:+.3f}", va="center", ha=align, fontsize=8)

    legend_handles = [
        mpatches.Patch(color=red, label="Pathogenic"),
        mpatches.Patch(color=blue, label="Benign"),
    ]
    ax.legend(handles=legend_handles, fontsize=8, loc="lower right")
    plt.tight_layout()
    return fig

def render_band_gauges(result, splice, context, protein):
    """Draw three side-by-side gauges, one per evidence layer.

    Each gauge is a grey 0..1 track overlaid with a bar of the layer's score,
    dashed lines at that layer's thresholds, the numeric score, and the band
    name coloured via ``BAND_COLORS``.

    Args:
        result: Object exposing ``activation_levels`` with ``splice_band``,
            ``context_band`` and ``protein_active``.
        splice: Object exposing ``splice_prob``.
        context: Object exposing ``activation_norm``.
        protein: Object exposing ``biochemical_risk_score``.

    Returns:
        The matplotlib figure containing the three gauges.
    """
    background = "#f8f9fa"
    track_bottom, track_height = 0.3, 0.25

    fig, axes = plt.subplots(1, 3, figsize=(10, 2))
    fig.patch.set_facecolor(background)

    levels = result.activation_levels
    # One tuple per gauge: (title, score, [(threshold, name), ...], band label).
    gauges = [
        (
            "RNA Splice",
            splice.splice_prob,
            [(SPLICE_PROB_HIGH, "High"), (SPLICE_PROB_MODERATE, "Moderate")],
            levels.splice_band,
        ),
        (
            "Seq Context",
            context.activation_norm,
            [(ACTIVATION_NORM_HIGH, "High"), (ACTIVATION_NORM_MODERATE, "Moderate")],
            levels.context_band,
        ),
        (
            "Protein",
            protein.biochemical_risk_score,
            [(BIOCHEMICAL_RISK_ACTIVE, "Active")],
            "Active" if levels.protein_active else "Inactive",
        ),
    ]

    for ax, gauge in zip(axes, gauges):
        title, score, thresholds, band = gauge
        ax.set_facecolor(background)
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.axis("off")
        ax.set_title(title, fontsize=9, pad=4)

        fill = BAND_COLORS.get(band, "#888")
        # Grey full-width track first, then the score bar on top of it.
        ax.barh(track_bottom, 1.0, height=track_height, color="#e0e0e0", left=0, align="edge")
        ax.barh(track_bottom, score, height=track_height, color=fill, left=0,
                align="edge", alpha=0.85)

        for cutoff, _name in thresholds:
            ax.axvline(x=cutoff, color="#333", linewidth=1.2, linestyle="--", alpha=0.7)
            ax.text(cutoff, 0.58, f"{cutoff}", ha="center", fontsize=6.5, color="#333")

        # White text reads well on a long bar; dark text when the bar is short.
        score_color = "white" if score > 0.4 else "#333"
        ax.text(max(0.01, score - 0.01), track_bottom + track_height / 2, f"{score:.3f}",
                va="center", fontsize=8, fontweight="bold", color=score_color)
        ax.text(0.5, 0.08, band, ha="center", fontsize=9, color=fill,
                fontweight="bold", transform=ax.transAxes)

    plt.suptitle("Mechanism Activation Bands", fontsize=10, y=1.02)
    plt.tight_layout()
    return fig

def render_conflict_table(result):
    """Build an HTML fragment listing the major and minor conflicts.

    Major conflicts are listed first, then minor ones; the "MAJOR: " and
    "MINOR: " tags are removed from the message text because the tier is shown
    in its own column. Message text is HTML-escaped so that characters such
    as ``<``, ``>`` and ``&`` are displayed literally rather than parsed as
    markup. When there are no conflicts a single "no conflicts" row is
    emitted.

    Args:
        result: Object exposing ``conflict_report`` with ``major_conflicts``
            and ``minor_conflicts`` (iterables of strings),
            ``requires_manual_review``, ``conflict_score_major`` and
            ``conflict_score_minor``.

    Returns:
        A string containing a ``<div>`` with a header line and a table.
    """
    from html import escape  # local import: keeps this block self-contained

    report = result.conflict_report

    def _row(color, tier, tag, message):
        # Strip the tier tag first, then escape what is left for safe display.
        text = escape(message.replace(tag, ""))
        return (f'<tr><td style="color:{color};font-weight:bold;padding:4px 8px">{tier}</td>'
                f'<td style="padding:4px 8px">{text}</td></tr>')

    rows = [_row("#d73027", "⛔ MAJOR", "MAJOR: ", c) for c in report.major_conflicts]
    rows += [_row("#fc8d59", "⚠ MINOR", "MINOR: ", c) for c in report.minor_conflicts]
    if not rows:
        rows = ['<tr><td colspan="2" style="color:#1a9850;padding:4px 8px">✓ No conflicts detected</td></tr>']

    hc = "#d73027" if report.requires_manual_review else "#1a9850"
    rt = "MANUAL REVIEW REQUIRED" if report.requires_manual_review else "No Review Required"
    return f"""<div style="font-family:monospace;font-size:13px">
      <div style="color:{hc};font-weight:bold;margin-bottom:8px">
        {rt} ({report.conflict_score_major} major, {report.conflict_score_minor} minor)
      </div>
      <table style="border-collapse:collapse;width:100%;background:#fafafa">
        <thead><tr style="background:#eee"><th style="padding:4px 8px;text-align:left">Tier</th>
        <th style="padding:4px 8px;text-align:left">Description</th></tr></thead>
        <tbody>{"".join(rows)}</tbody>
      </table></div>"""