# LCVC-DeepFuse / app.py
# Hugging Face Space source file (commit bb86306, "Update app.py", 23.3 kB).
from __future__ import annotations
import html
import traceback
import gradio as gr
from PIL import Image
from src.config import CLASS_DISPLAY_NAMES, CLASS_NAMES
from src.modeling import predict, weighted_ensemble_cam
# Stylesheet handed to gr.Blocks(css=...). The selectors below target the
# elem_id / elem_classes values assigned to the components in the layout
# (#mri_upload, #heatmap_preview, #upload_card, #heatmap_checkbox, #run_button)
# and the class names used by the HTML card builders in this module.
#
# The font @import is the first rule in the sheet on purpose: CSS only honours
# @import rules that precede every other rule (apart from @charset / @layer).
# It used to sit after the :root block, where browsers silently ignore it and
# the "DM Sans" / "DM Mono" families fall back to the system fonts.
CUSTOM_CSS = """
@import url('https://fonts.googleapis.com/css2?family=DM+Sans:wght@300;400;500;600;700&family=DM+Mono:wght@400;500&display=swap');
/* ── Design tokens ─────────────────────────────────────────────────── */
:root {
--bg: #080c14;
--surface: #0f1520;
--surface-2: #151d2c;
--border: rgba(255,255,255,.08);
--border-2: rgba(255,255,255,.13);
--text: #f1f5f9;
--text-2: #8a96aa;
--text-3: #5a6478;
--accent: #7c6fff;
--accent-2: #40c4ff;
--radius: 14px;
--radius-sm: 8px;
--font: "DM Sans", ui-sans-serif, system-ui, sans-serif;
--font-mono: "DM Mono", ui-monospace, monospace;
}
/* ── Base ──────────────────────────────────────────────────────────── */
html, body, .gradio-container {
background: var(--bg) !important;
color: var(--text) !important;
font-family: var(--font) !important;
}
.gradio-container {
max-width: 1180px !important;
margin: 0 auto !important;
padding: 0 20px !important;
}
/* Kill Gradio's default chrome */
.gradio-container .block,
.gradio-container .form,
.gradio-container .panel,
.gradio-container .wrap,
.gradio-container .gr-box {
background: transparent !important;
border: none !important;
box-shadow: none !important;
}
.gradio-container label,
.gradio-container label span,
.gradio-container .info,
.gradio-container p,
.gradio-container span {
color: var(--text-2) !important;
font-family: var(--font) !important;
}
/* ── Hero ──────────────────────────────────────────────────────────── */
.hero {
padding: 36px 0 28px;
border-bottom: 1px solid var(--border);
margin-bottom: 28px;
}
.hero h1 {
margin: 0 !important;
font-size: 2rem !important;
font-weight: 700 !important;
letter-spacing: -0.03em !important;
color: var(--text) !important;
}
.hero p {
margin: 4px 0 0 !important;
font-size: .95rem !important;
color: var(--text-2) !important;
}
.chip-row {
display: flex;
flex-wrap: wrap;
gap: 8px;
margin-top: 20px;
}
.chip {
display: inline-flex;
align-items: center;
height: 30px;
padding: 0 12px;
border-radius: 99px;
font-size: .8rem;
font-weight: 500;
background: var(--surface-2);
border: 1px solid var(--border-2);
color: var(--text-2);
}
.chip:first-child {
background: rgba(124,111,255,.15);
border-color: rgba(124,111,255,.4);
color: #b3acff;
}
/* ── Card ──────────────────────────────────────────────────────────── */
.card {
background: var(--surface) !important;
border: 1px solid var(--border) !important;
border-radius: var(--radius) !important;
padding: 20px !important;
}
.card + .card { margin-top: 12px; }
.card-label {
font-size: .72rem;
font-weight: 600;
text-transform: uppercase;
letter-spacing: .08em;
color: var(--text-3);
margin-bottom: 14px;
}
/* ── Upload / heatmap images ───────────────────────────────────────── */
#mri_upload,
#heatmap_preview {
background: transparent !important;
border: 0 !important;
box-shadow: none !important;
width: 100% !important;
}
#mri_upload > div,
#mri_upload .wrap,
#mri_upload .block,
#mri_upload .form,
#heatmap_preview > div,
#heatmap_preview .wrap,
#heatmap_preview .block,
#heatmap_preview .form {
background: transparent !important;
border: 0 !important;
box-shadow: none !important;
padding: 0 !important;
width: 100% !important;
overflow: visible !important;
}
#mri_upload .image-container,
#mri_upload .upload-container,
#mri_upload [data-testid="image"],
#mri_upload .empty,
#heatmap_preview .image-container,
#heatmap_preview .upload-container,
#heatmap_preview [data-testid="image"],
#heatmap_preview .empty {
width: 100% !important;
border-radius: var(--radius) !important;
background: var(--surface-2) !important;
border: 1px dashed var(--border-2) !important;
box-shadow: none !important;
}
#mri_upload .image-container,
#mri_upload .upload-container,
#mri_upload [data-testid="image"],
#mri_upload .empty {
min-height: 420px !important;
height: 420px !important;
}
#heatmap_preview .image-container,
#heatmap_preview .upload-container,
#heatmap_preview [data-testid="image"],
#heatmap_preview .empty {
min-height: 240px !important;
height: 240px !important;
}
#mri_upload img,
#heatmap_preview img {
width: 100% !important;
height: 100% !important;
object-fit: contain !important;
border-radius: var(--radius) !important;
background: var(--surface-2) !important;
}
/* Extra bottom padding on upload card so source buttons are never clipped */
#upload_card {
padding-bottom: 32px !important;
}
/* ── Upload / paste source buttons ─────────────────────────────────── */
/* Tab strip container */
#mri_upload .tabs,
#mri_upload [role="tablist"],
#mri_upload .tab-nav {
display: flex !important;
flex-direction: row !important;
gap: 8px !important;
padding: 0 0 12px 0 !important;
background: transparent !important;
border: none !important;
opacity: 1 !important;
visibility: visible !important;
position: relative !important;
z-index: 30 !important;
}
/* Each tab / icon button */
#mri_upload [role="tab"],
#mri_upload .tab-nav button,
#mri_upload .tabs button,
#mri_upload button.tab-button {
display: inline-flex !important;
align-items: center !important;
justify-content: center !important;
min-width: 44px !important;
min-height: 44px !important;
padding: 10px 12px !important;
border-radius: var(--radius-sm) !important;
background: #1e2d45 !important;
border: 1px solid rgba(255,255,255,.18) !important;
color: #e2e8f0 !important;
font-size: .85rem !important;
font-weight: 500 !important;
opacity: 1 !important;
visibility: visible !important;
z-index: 30 !important;
cursor: pointer !important;
box-shadow: 0 2px 8px rgba(0,0,0,.35) !important;
transition: background .15s !important;
}
#mri_upload [role="tab"]:hover,
#mri_upload .tabs button:hover {
background: #283d57 !important;
}
/* Force SVG icons inside the buttons to be white/light */
#mri_upload [role="tab"] svg,
#mri_upload .tabs button svg,
#mri_upload button svg {
color: #e2e8f0 !important;
fill: #e2e8f0 !important;
stroke: #e2e8f0 !important;
width: 20px !important;
height: 20px !important;
opacity: 1 !important;
}
/* Active / selected tab */
#mri_upload [role="tab"][aria-selected="true"],
#mri_upload [role="tab"].selected {
background: var(--accent) !important;
border-color: var(--accent) !important;
color: #fff !important;
}
#mri_upload [role="tab"][aria-selected="true"] svg,
#mri_upload [role="tab"].selected svg {
color: #fff !important;
fill: #fff !important;
stroke: #fff !important;
}
/* ── Checkbox ──────────────────────────────────────────────────────── */
#heatmap_checkbox {
margin-top: 12px !important;
}
#heatmap_checkbox label {
display: flex !important;
align-items: center !important;
gap: 10px !important;
padding: 12px 14px !important;
border-radius: var(--radius-sm) !important;
background: var(--surface-2) !important;
border: 1px solid var(--border) !important;
min-height: auto !important;
}
#heatmap_checkbox input[type="checkbox"] {
appearance: none !important;
-webkit-appearance: none !important;
width: 20px !important;
height: 20px !important;
min-width: 20px !important;
border-radius: 6px !important;
background: var(--bg) !important;
border: 1.5px solid var(--border-2) !important;
cursor: pointer !important;
}
#heatmap_checkbox input[type="checkbox"]:checked {
background-color: var(--accent) !important;
border-color: var(--accent) !important;
background-image: url("data:image/svg+xml,%3Csvg width='14' height='14' viewBox='0 0 20 20' fill='none' xmlns='http://www.w3.org/2000/svg'%3E%3Cpath d='M4.2 10.4L8.1 14.2L15.9 5.8' stroke='white' stroke-width='2.5' stroke-linecap='round' stroke-linejoin='round'/%3E%3C/svg%3E") !important;
background-repeat: no-repeat !important;
background-position: center !important;
}
#heatmap_checkbox span {
color: var(--text) !important;
font-size: .88rem !important;
font-weight: 500 !important;
}
/* ── Run button ────────────────────────────────────────────────────── */
#run_button button {
width: 100% !important;
margin-top: 10px !important;
height: 44px !important;
border-radius: var(--radius-sm) !important;
background: var(--accent) !important;
color: #fff !important;
font-weight: 600 !important;
font-size: .92rem !important;
border: none !important;
box-shadow: 0 0 0 1px rgba(124,111,255,.3), 0 4px 16px rgba(124,111,255,.2) !important;
transition: opacity .15s !important;
letter-spacing: -.01em !important;
}
#run_button button:hover {
opacity: .88 !important;
}
/* ── Result card ───────────────────────────────────────────────────── */
.result-card {
background: var(--surface) !important;
border: 1px solid var(--border) !important;
border-radius: var(--radius) !important;
padding: 22px 24px !important;
margin-bottom: 12px;
}
.pred-eyebrow {
font-size: .72rem;
font-weight: 600;
text-transform: uppercase;
letter-spacing: .1em;
color: var(--text-3);
margin-bottom: 8px;
}
.pred-label {
font-size: 2.2rem;
font-weight: 700;
letter-spacing: -0.045em;
color: var(--text);
line-height: 1;
margin-bottom: 6px;
}
.pred-sub {
font-size: .85rem;
color: var(--text-2);
margin-bottom: 20px;
}
.pred-sub b {
color: #b3acff;
font-weight: 600;
}
/* Metric row — always 3 columns */
.metric-grid {
display: grid !important;
grid-template-columns: repeat(3, 1fr) !important;
gap: 10px !important;
}
.metric {
padding: 12px 14px;
border-radius: var(--radius-sm);
background: var(--surface-2);
border: 1px solid var(--border);
}
.metric .k {
font-size: .72rem;
font-weight: 600;
text-transform: uppercase;
letter-spacing: .07em;
color: var(--text-3);
margin-bottom: 5px;
}
.metric .v {
font-size: 1.05rem;
font-weight: 700;
color: var(--text);
font-family: var(--font-mono);
}
.metric .v.confidence { color: #b3acff; }
/* ── Probability card ──────────────────────────────────────────────── */
.prob-card {
background: var(--surface) !important;
border: 1px solid var(--border) !important;
border-radius: var(--radius) !important;
padding: 22px 24px !important;
margin-bottom: 12px;
}
.prob-title {
font-size: .72rem;
font-weight: 600;
text-transform: uppercase;
letter-spacing: .1em;
color: var(--text-3);
margin-bottom: 16px;
}
.prob-item {
margin-bottom: 14px;
}
.prob-head {
display: flex;
justify-content: space-between;
align-items: baseline;
margin-bottom: 6px;
}
.prob-label {
font-size: .88rem;
font-weight: 500;
color: var(--text-2);
}
.prob-item.top .prob-label { color: var(--text); }
.prob-percent {
font-size: .82rem;
font-weight: 600;
color: var(--text-3);
font-family: var(--font-mono);
}
.prob-item.top .prob-percent { color: #b3acff; }
.prob-track {
height: 4px;
border-radius: 99px;
background: var(--surface-2);
overflow: hidden;
}
.prob-fill {
height: 100%;
border-radius: 99px;
background: linear-gradient(90deg, var(--accent), var(--accent-2));
}
/* ── Details card ──────────────────────────────────────────────────── */
.details-card {
background: var(--surface) !important;
border: 1px solid var(--border) !important;
border-radius: var(--radius) !important;
padding: 22px 24px !important;
margin-bottom: 12px;
}
.details-title {
font-size: .72rem;
font-weight: 600;
text-transform: uppercase;
letter-spacing: .1em;
color: var(--text-3);
margin-bottom: 6px;
}
.details-subtitle {
font-size: .82rem;
color: var(--text-3);
margin-bottom: 18px;
line-height: 1.5;
}
/* Member row — 4-column grid */
.member-row {
display: grid !important;
grid-template-columns: 1.8fr 1fr 1fr 1fr !important;
gap: 10px !important;
align-items: start !important;
padding: 14px !important;
border-radius: var(--radius-sm) !important;
background: var(--surface-2) !important;
border: 1px solid var(--border) !important;
margin-bottom: 10px !important;
}
.member-name {
font-size: .9rem;
font-weight: 600;
color: var(--text);
line-height: 1.3;
}
.member-meta {
margin-top: 4px;
font-size: .78rem;
color: var(--text-3);
}
.detail-pill {
padding: 10px 12px;
border-radius: var(--radius-sm);
background: var(--bg);
border: 1px solid var(--border);
}
.detail-key {
font-size: .68rem;
font-weight: 600;
text-transform: uppercase;
letter-spacing: .08em;
color: var(--text-3);
margin-bottom: 4px;
}
.detail-value {
font-size: .9rem;
font-weight: 600;
color: var(--text);
font-family: var(--font-mono);
}
.detail-value.accent { color: #b3acff; }
.vote-track {
margin-top: 6px;
height: 3px;
border-radius: 99px;
background: rgba(255,255,255,.08);
overflow: hidden;
}
.vote-fill {
height: 100%;
border-radius: 99px;
background: linear-gradient(90deg, var(--accent), var(--accent-2));
}
/* ── Responsive ────────────────────────────────────────────────────── */
@media (max-width: 860px) {
.metric-grid { grid-template-columns: repeat(3, 1fr) !important; }
.member-row { grid-template-columns: 1fr 1fr !important; }
}
@media (max-width: 560px) {
.metric-grid { grid-template-columns: 1fr 1fr !important; }
.member-row { grid-template-columns: 1fr !important; }
.pred-label { font-size: 1.7rem; }
}
"""
# One "chip" <span> per class, emitted in CLASS_NAMES order; the display name
# is HTML-escaped before it is embedded in the markup.
CLASS_CHIPS_HTML = "".join(
    [
        "<span class='chip'>" + html.escape(CLASS_DISPLAY_NAMES[key]) + "</span>"
        for key in CLASS_NAMES
    ]
)

# Static page header: title, tagline and the chip row built above.
HERO_HTML = (
    """
<div class="hero">
<h1>LCVC DeepFuse</h1>
<p>Brain MRI Ensemble Classifier · EfficientNet-B0 + MobileNetV3-Small</p>
<div class="chip-row">"""
    + CLASS_CHIPS_HTML
    + """</div>
</div>
"""
)
def _pct(value: float) -> str:
    """Render a fraction as a percentage with two decimals, e.g. 0.5 -> "50.00%".

    The argument is passed through ``float()`` first, so anything ``float()``
    accepts (ints, numeric strings, numpy scalars) can be formatted.
    """
    scaled = 100.0 * float(value)
    return "{:.2f}%".format(scaled)
def _empty_prediction_card() -> str:
    """Return the placeholder "Top Prediction" card shown before any run.

    The markup mirrors the populated card (same classes and three metric
    tiles) but carries a "Waiting" label and em-dash values.
    """
    placeholder_markup = """
<div class="result-card">
<div class="pred-eyebrow">Top Prediction</div>
<div class="pred-label">Waiting</div>
<div class="pred-sub">Upload an MRI image, then run prediction.</div>
<div class="metric-grid">
<div class="metric"><div class="k">Input size</div><div class="v">—</div></div>
<div class="metric"><div class="k">Model votes</div><div class="v">3</div></div>
<div class="metric"><div class="k">Confidence</div><div class="v confidence">—</div></div>
</div>
</div>
"""
    return placeholder_markup
def _prediction_card(label: str, confidence: float, image: Image.Image) -> str:
    """Build the "Top Prediction" card HTML for a finished prediction.

    Args:
        label: Display name of the predicted class; HTML-escaped before use.
        confidence: Fraction formatted through ``_pct`` (shown twice in the card).
        image: Source image whose ``size`` fills the "Input size" tile.
            ``None`` is tolerated and rendered as ``0×0``.

    Returns:
        The card markup as a string.
    """
    if image is None:
        width, height = 0, 0
    else:
        width, height = image.size
    safe_label = html.escape(label)
    confidence_text = _pct(confidence)
    return f"""
<div class="result-card">
<div class="pred-eyebrow">Top Prediction</div>
<div class="pred-label">{safe_label}</div>
<div class="pred-sub">Confidence: <b>{confidence_text}</b></div>
<div class="metric-grid">
<div class="metric"><div class="k">Input size</div><div class="v">{width}×{height}</div></div>
<div class="metric"><div class="k">Model votes</div><div class="v">3</div></div>
<div class="metric"><div class="k">Confidence</div><div class="v confidence">{confidence_text}</div></div>
</div>
</div>
"""
def _probabilities_card(probabilities: dict[str, float] | None = None, top_class: str | None = None) -> str:
    """Build the "Class Probabilities" card: one labelled bar per class.

    Args:
        probabilities: Mapping of class name to probability as a fraction.
            A missing or empty mapping is replaced by zeros for every entry
            of ``CLASS_NAMES``; classes absent from the mapping count as 0.
        top_class: Class name whose row receives the extra ``top`` CSS class.

    Returns:
        The card markup, rows ordered from highest to lowest probability.
    """
    probabilities = probabilities or {name: 0.0 for name in CLASS_NAMES}

    # (class name, display name, probability) for every known class.
    ranked = [
        (key, CLASS_DISPLAY_NAMES[key], float(probabilities.get(key, 0.0)))
        for key in CLASS_NAMES
    ]
    # Stable descending sort: ties keep their CLASS_NAMES order.
    ranked = sorted(ranked, key=lambda entry: entry[2], reverse=True)

    def render(key: str, display: str, probability: float) -> str:
        # Clamp to 0..100 so the bar width is always a valid CSS percentage.
        percent = max(0.0, min(100.0, probability * 100.0))
        top_cls = " top" if key == top_class else ""
        return f"""
<div class="prob-item{top_cls}">
<div class="prob-head">
<span class="prob-label">{html.escape(display)}</span>
<span class="prob-percent">{percent:.2f}%</span>
</div>
<div class="prob-track">
<div class="prob-fill" style="width:{percent:.2f}%"></div>
</div>
</div>
"""

    rows_markup = "".join(render(*entry) for entry in ranked)
    return f"""
<div class="prob-card">
<div class="prob-title">Class Probabilities</div>
{rows_markup}
</div>
"""
def _empty_details_card() -> str:
    """Return the placeholder "Model Contributions" card shown before any run.

    It contains a single member row with em-dash values in the three
    detail pills (Predicts, Confidence, Weighted vote).
    """
    placeholder_markup = """
<div class="details-card">
<div class="details-title">Model Contributions</div>
<div class="details-subtitle">Run prediction to see each checkpoint's confidence, weight, and vote strength.</div>
<div class="member-row">
<div>
<div class="member-name">Waiting for image</div>
<div class="member-meta">EfficientNet-B0 + MobileNetV3-Small</div>
</div>
<div class="detail-pill"><div class="detail-key">Predicts</div><div class="detail-value">—</div></div>
<div class="detail-pill"><div class="detail-key">Confidence</div><div class="detail-value">—</div></div>
<div class="detail-pill"><div class="detail-key">Weighted vote</div><div class="detail-value accent">—</div></div>
</div>
</div>
"""
    return placeholder_markup
def _details_card(member_df) -> str:
    """Build the "Model Contributions" card from a per-member table.

    Args:
        member_df: Table-like object supporting ``len()`` and ``iterrows()``
            whose rows support ``.get(key, default)``. The columns read are
            ``"member"``, ``"weight"``, ``"member prediction"`` and
            ``"member confidence"``; missing ones fall back to defaults.
            ``None`` or an empty table yields the placeholder card.

    Returns:
        The card markup with one row per ensemble member.
    """
    if member_df is None or not len(member_df):
        return _empty_details_card()

    def render(record) -> str:
        name = html.escape(str(record.get("member", "Model")))
        weight = float(record.get("weight", 0.0))
        predicted = html.escape(str(record.get("member prediction", "—")))
        conf = float(record.get("member confidence", 0.0))
        # Weighted vote is clamped to 0..1 before being shown as a percentage
        # so the progress bar width stays within bounds.
        vote_pct = max(0.0, min(1.0, weight * conf)) * 100.0
        weight_pct = weight * 100.0
        conf_pct = conf * 100.0
        return f"""
<div class="member-row">
<div>
<div class="member-name">{name}</div>
<div class="member-meta">Weight: {weight_pct:.1f}%</div>
</div>
<div class="detail-pill">
<div class="detail-key">Predicts</div>
<div class="detail-value">{predicted}</div>
</div>
<div class="detail-pill">
<div class="detail-key">Confidence</div>
<div class="detail-value">{conf_pct:.1f}%</div>
</div>
<div class="detail-pill">
<div class="detail-key">Weighted vote</div>
<div class="detail-value accent">{vote_pct:.1f}%</div>
<div class="vote-track"><div class="vote-fill" style="width:{vote_pct:.1f}%"></div></div>
</div>
</div>
"""

    members_markup = "".join(render(record) for _, record in member_df.iterrows())
    return f"""
<div class="details-card">
<div class="details-title">Model Contributions</div>
<div class="details-subtitle">Weighted vote = confidence × ensemble weight.</div>
{members_markup}
</div>
"""
def run_prediction(image: Image.Image, make_heatmap: bool):
    """Click handler: classify *image* and build the HTML for the result cards.

    Args:
        image: Uploaded image; ``None`` means nothing was uploaded.
        make_heatmap: When true, also request a heatmap from
            ``weighted_ensemble_cam`` for the predicted class.

    Returns:
        A 4-tuple ``(prediction_html, probabilities_html, details_html,
        heatmap)``; ``heatmap`` is ``None`` when *make_heatmap* is false.

    Raises:
        gr.Error: If no image was supplied, or if any step fails. A
            ``FileNotFoundError`` is surfaced with its own message; any other
            exception is reported as "Prediction failed" followed by a
            three-frame traceback.
    """
    if image is None:
        raise gr.Error("Upload an MRI image first.")

    try:
        result = predict(image)
        heatmap = None
        if make_heatmap:
            heatmap = weighted_ensemble_cam(image, result.predicted_class)
        cards = (
            _prediction_card(result.predicted_display, result.confidence, image),
            _probabilities_card(result.probabilities, result.predicted_class),
            _details_card(result.member_df),
        )
    except FileNotFoundError as exc:
        raise gr.Error(str(exc)) from exc
    except Exception as exc:
        detail = traceback.format_exc(limit=3)
        raise gr.Error(f"Prediction failed: {exc}\n\n{detail}") from exc
    else:
        return (*cards, heatmap)
# Page layout and event wiring for the Gradio app, exposed as module-level `demo`.
# Component creation order and `with` nesting determine the rendered layout.
# The elem_id / elem_classes values are the hooks that CUSTOM_CSS selects on
# (#upload_card, #mri_upload, #heatmap_checkbox, #run_button, #heatmap_preview, .card).
# NOTE(review): the nesting below is reconstructed; in particular, confirm that
# the checkbox and the run button belong inside the upload card group.
with gr.Blocks(
    css=CUSTOM_CSS,
    theme=gr.themes.Base(primary_hue="violet", secondary_hue="cyan", neutral_hue="slate"),
    title="LCVC DeepFuse",
) as demo:
    gr.HTML(HERO_HTML)
    with gr.Row(equal_height=False):
        # Left column: image input, heatmap option, run button and heatmap preview.
        with gr.Column(scale=4, min_width=320):
            with gr.Group(elem_classes=["card", "upload-panel"], elem_id="upload_card"):
                gr.HTML('<div class="card-label">MRI Image</div>')
                # height=420 matches the 420px rules for #mri_upload in CUSTOM_CSS.
                image_input = gr.Image(
                    label="Upload MRI Image",
                    show_label=False,
                    type="pil",
                    height=420,
                    sources=["upload", "clipboard"],
                    elem_id="mri_upload",
                )
                # Heatmap generation is opt-in (unchecked by default).
                heatmap_toggle = gr.Checkbox(
                    value=False,
                    label="Generate Grad-CAM heatmap",
                    elem_id="heatmap_checkbox",
                )
                run_button = gr.Button("Run Prediction", variant="primary", elem_id="run_button")
            with gr.Group(elem_classes=["card", "heatmap-panel"]):
                gr.HTML('<div class="card-label">Grad-CAM Heatmap</div>')
                # height=240 matches the 240px rules for #heatmap_preview in CUSTOM_CSS.
                heatmap_output = gr.Image(
                    label="Heatmap",
                    show_label=False,
                    type="pil",
                    height=240,
                    elem_id="heatmap_preview",
                )
        # Right column: the three HTML cards, initialised with placeholder markup.
        with gr.Column(scale=6, min_width=380):
            prediction_html_out = gr.HTML(_empty_prediction_card())
            probabilities_html_out = gr.HTML(_probabilities_card())
            details_html_out = gr.HTML(_empty_details_card())
    # The outputs list is in the same order as the 4-tuple returned by run_prediction.
    run_button.click(
        fn=run_prediction,
        inputs=[image_input, heatmap_toggle],
        outputs=[prediction_html_out, probabilities_html_out, details_html_out, heatmap_output],
    )
if __name__ == "__main__":
    # Script entry point: enable the request queue (max_size=16), then start
    # the server on whatever object queue() hands back.
    queued_demo = demo.queue(max_size=16)
    queued_demo.launch()