# project2 / app.py
# Uploaded by VedikaP ("Upload 3 files", commit 9fd74fd, verified).
import gradio as gr
import numpy as np
import cv2
import json
import os
from tensorflow.keras.models import load_model
# ─── Load model ───────────────────────────────────────────────────────────────
# Resolve the weights file next to this script rather than the current working
# directory, so the app also starts when launched from another directory.
_MODEL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "custom_cnn.h5")
# compile=False: this app only runs inference, so the optimizer/loss settings
# stored in the .h5 file are not needed and are not deserialised.
model = load_model(_MODEL_PATH, compile=False)
IMG_SIZE = 224  # side length (pixels) of the square image fed to the network
NUM_OUTPUTS = model.output_shape[-1]  # auto-detects 3-class (cluster) or 16-class model
# ─── Class / cluster labels ───────────────────────────────────────────────────
# Priority 1: class_labels.json saved alongside the model (from the 16-class notebook)
# Priority 2: generic "Cluster i" names, the fallback for the 3-class K-Means model
_LABELS_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "class_labels.json")
if os.path.exists(_LABELS_PATH):
    # Explicit UTF-8: labels are compared against dropdown names that contain
    # non-ASCII en dashes (e.g. "Arnold–Chiari"), and the platform default
    # encoding is not guaranteed to be UTF-8.
    with open(_LABELS_PATH, encoding="utf-8") as f:
        CLASS_NAMES = list(json.load(f)["classes"])
    # predict() indexes CLASS_NAMES by model output position, so keep exactly one
    # name per output: drop surplus names and pad a short list with "Class i".
    CLASS_NAMES = CLASS_NAMES[:NUM_OUTPUTS] + [
        f"Class {i}" for i in range(len(CLASS_NAMES), NUM_OUTPUTS)
    ]
else:
    # 3-class K-Means cluster model fallback
    CLASS_NAMES = [f"Cluster {i}" for i in range(NUM_OUTPUTS)]
# ─── Which actual pathology classes are dominant in each cluster ──────────────
# Cluster name → its five most frequent ground-truth pathologies, listed most
# frequent first as (class name, fraction) pairs; each cluster's fractions sum
# to 1.0. The numbers come from comparing K-Means cluster assignments with the
# ground-truth labels.
# TODO: REPLACE these lists with the real counts from your own cluster analysis notebook.
CLUSTER_DOMINANT = {
    "Cluster 0": [
        ("Normal",                       0.38),
        ("Mild Ventriculomegaly",        0.22),
        ("Arnold–Chiari Malformation",   0.15),
        ("Moderate Ventriculomegaly",    0.14),
        ("Hydranencephaly",              0.11),
    ],
    "Cluster 1": [
        ("Severe Ventriculomegaly",      0.35),
        ("Dandy–Walker Malformation",    0.25),
        ("Holoprosencephaly",            0.18),
        ("Agenesis of Corpus Callosum",  0.13),
        ("Intracranial Tumors",          0.09),
    ],
    "Cluster 2": [
        ("Intracranial Tumors",          0.30),
        ("Intracranial Hemorrhages",     0.28),
        ("Holoprosencephaly",            0.20),
        ("Dandy–Walker Malformation",    0.12),
        ("Agenesis of Corpus Callosum",  0.10),
    ],
}
# Where Panel 2's "dominant classes" come from:
#   True  → the model's own top-5 softmax outputs (per-class model);
#   False → the pre-computed CLUSTER_DOMINANT table for the predicted cluster.
# A cluster lookup only makes sense when every label is a CLUSTER_DOMINANT key.
# For the built-in "Cluster i" names this yields the same split as counting
# outputs (3 outputs → lookup, 16 outputs → softmax), and it also correctly uses
# softmax for a small model whose real class names come from class_labels.json,
# which would otherwise only get placeholder "Class j" rows.
USE_SOFTMAX_DOMINANT = not all(name in CLUSTER_DOMINANT for name in CLASS_NAMES)
# ─── Dropdown choices for the ground-truth class ─────────────────────────────
# 16 entries: fifteen class names followed by the "Unknown / Not provided"
# sentinel, which is the dropdown's default value and which predict() treats
# as "no label given".
ALL_GT_CLASSES = [
    "Normal",
    "Mild Ventriculomegaly",
    "Moderate Ventriculomegaly",
    "Severe Ventriculomegaly",
    "Arnold–Chiari Malformation",
    "Hydranencephaly",
    "Agenesis of Corpus Callosum",
    "Dandy–Walker Malformation",
    "Intracranial Tumors",
    "Intracranial Hemorrhages",
    "Holoprosencephaly",
    "Cerebellar Hypoplasia",
    "Microcephaly",
    "Macrocephaly",
    "Lissencephaly",
    # Sentinel — must stay the last entry and match predict()'s check.
    "Unknown / Not provided",
]
# ─── Preprocessing — mirrors the paper §3B pipeline ──────────────────────────
def preprocess(image: np.ndarray) -> np.ndarray:
    """Apply the §3B pipeline: grayscale → Gaussian blur → median filter → CLAHE → [0, 1].

    Args:
        image: H×W grayscale or H×W×C image (C = 3 for RGB, 4 for RGBA; any
            other channel count uses channel 0). Non-uint8 input is converted
            safely: float images whose maximum is ≤ 1.0 are rescaled to
            [0, 255], and all non-uint8 values are clipped to [0, 255] before
            the 8-bit conversion that CLAHE requires.

    Returns:
        H×W×3 float32 array with values in [0, 1] (the grayscale result
        replicated into three channels), or None when *image* is None.
    """
    if image is None:
        return None

    img = np.asarray(image)
    if img.dtype != np.uint8:
        # A bare astype(np.uint8) would wrap values above 255 and turn a
        # [0, 1] float image into an almost all-black picture.
        as_float = img.astype(np.float32)
        if np.issubdtype(img.dtype, np.floating) and as_float.size and as_float.max() <= 1.0:
            as_float *= 255.0
        img = np.clip(as_float, 0.0, 255.0).astype(np.uint8)

    # To grayscale
    if img.ndim == 3 and img.shape[2] == 3:
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    elif img.ndim == 3 and img.shape[2] == 4:
        # RGBA (e.g. PNG with alpha): combine the colour channels instead of
        # keeping only the red one.
        gray = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
    else:
        gray = img if img.ndim == 2 else img[:, :, 0]
    # Channel slices are strided views; give OpenCV a contiguous 2-D buffer.
    gray = np.ascontiguousarray(gray)

    # §3B-2: Gaussian blur followed by a 5×5 median filter
    blurred = cv2.GaussianBlur(gray, (5, 5), sigmaX=1.0)
    median = cv2.medianBlur(blurred, 5)

    # §3B-3: CLAHE (clip limit 2.0, 8×8 tiles)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced = clahe.apply(median)

    # Back to 3-channel RGB float32 in [0, 1] for the network input
    return cv2.cvtColor(enhanced, cv2.COLOR_GRAY2RGB).astype(np.float32) / 255.0
# ─── Emoji badges for ranks ───────────────────────────────────────────────────
# RANK_EMOJI[i] is the badge shown next to rank i + 1 in the top-5 panel.
RANK_EMOJI = ["🥇", "🥈", "🥉", "4️⃣", "5️⃣"]
# ─── Progress-bar helper ──────────────────────────────────────────────────────
def pct_bar(value: float, width: int = 28) -> str:
    """Render *value* (a fraction in [0, 1]) as a text bar exactly *width* chars long.

    The bar is ``round(value * width)`` full blocks ("█") followed by light-shade
    padding ("░"). Values outside [0, 1] are clamped and NaN is drawn as an
    empty bar, so a bad score can never change the bar's length (which would
    misalign the columns that follow it). A negative *width* yields "".

    >>> pct_bar(0.5, width=4)
    '██░░'
    """
    fraction = float(value)
    if np.isnan(fraction):
        fraction = 0.0
    fraction = min(max(fraction, 0.0), 1.0)
    width = max(int(width), 0)
    filled = round(fraction * width)
    return "█" * filled + "░" * (width - filled)
# ─── Main prediction function ─────────────────────────────────────────────────
_BOX_INNER = 39  # minimum text width inside the Panel 1 box (43 chars incl. borders)


def _class_name(idx):
    """Return the label for model output *idx*, or "Class idx" if CLASS_NAMES is too short."""
    return CLASS_NAMES[idx] if idx < len(CLASS_NAMES) else f"Class {idx}"


def _boxed(title, rows):
    """Draw *title* and *rows* inside a box-drawing frame, widening it for long text."""
    inner = max([_BOX_INNER, len(title)] + [len(row) for row in rows])
    rule = "─" * (inner + 2)
    return [
        f"┌{rule}┐",
        f"│ {title:<{inner}} │",
        f"├{rule}┤",
        *(f"│ {row:<{inner}} │" for row in rows),
        f"└{rule}┘",
    ]


def _format_prediction_panel(probs, pred_idx, pred_label, confidence):
    """Panel 1: boxed winning label/confidence plus every output's probability."""
    kind = "class" if USE_SOFTMAX_DOMINANT else "cluster"
    lines = _boxed(
        "PREDICTED CLUSTER / CLASS",
        [pred_label, f"Confidence : {confidence:>6.2f}%"],
    )
    lines += ["", f"All {kind} probabilities:", "─" * 43]
    for i, p in enumerate(probs):
        marker = "  ◀ PREDICTED" if i == pred_idx else ""
        lines.append(f"  {_class_name(i):<35} {p * 100:5.1f}%{marker}")
    return "\n".join(lines)


def _dominant_classes(probs, top5_idx, pred_label):
    """Return ([(name, score), ...], source note) for the top-5 dominant classes."""
    if USE_SOFTMAX_DOMINANT:
        # Per-class model: dominant = the model's own top-5 outputs.
        dominant = [(_class_name(int(i)), float(probs[i])) for i in top5_idx]
        return dominant, f"(direct softmax outputs from {NUM_OUTPUTS}-class model)"
    # Cluster model: look up pre-computed dominant pathologies (placeholder rows
    # when the predicted cluster is missing from the table).
    dominant = CLUSTER_DOMINANT.get(
        pred_label,
        [(f"Class {j}", 0.2) for j in range(5)],
    )
    return dominant, f"(pathologies most common in {pred_label})"


def _format_top5_panel(dominant, source_note):
    """Panel 2: ranked dominant classes with a progress bar and percentage each."""
    lines = [f"TOP 5 DOMINANT PATHOLOGY CLASSES {source_note}", "─" * 63, ""]
    for rank, (cname, score) in enumerate(dominant):
        badge = RANK_EMOJI[rank] if rank < len(RANK_EMOJI) else f"{rank + 1}."
        lines.append(f"  {badge}  {cname:<40} {pct_bar(score)} {score * 100:5.1f}%")
    return "\n".join(lines)


def _format_actual_panel(actual_class, dominant, pred_label, confidence):
    """Panel 3: compare the user-supplied ground-truth class with the prediction."""
    if not actual_class or actual_class == "Unknown / Not provided":
        return "\n".join([
            "ℹ️ No ground-truth label provided.",
            "",
            "Select the actual class from the dropdown",
            "on the left to see a correctness check.",
        ])

    dominant_names = [name for name, _ in dominant]
    if USE_SOFTMAX_DOMINANT:
        # Per-class model: direct label match.
        if actual_class == pred_label:
            verdict = "✅ CORRECT PREDICTION"
        else:
            verdict = f"❌ INCORRECT (model predicted '{pred_label}')"
        list_title = "Top-5 predicted classes:"
    else:
        # Cluster model: soft match — is the actual class among the cluster's top-5?
        if actual_class in dominant_names:
            rank_pos = dominant_names.index(actual_class) + 1
            verdict = f"✅ CORRECT CLUSTER ('{actual_class}' is #{rank_pos} in {pred_label})"
        else:
            verdict = (
                f"⚠️ PARTIAL MISS ('{actual_class}' not in top-5 of {pred_label})\n"
                f"     This may indicate a cluster assignment issue or borderline case."
            )
        list_title = "Top-5 dominant classes in predicted cluster:"

    lines = [
        "GROUND TRUTH vs PREDICTION",
        "─" * 43,
        "",
        f"  Actual class : {actual_class}",
        f"  Predicted    : {pred_label}  ({confidence:.1f}%)",
        "",
        f"  {verdict}",
        "",
        "─" * 43,
        list_title,
    ]
    for rank, (cname, score) in enumerate(dominant, start=1):
        tick = "✓ " if cname == actual_class else "  "
        lines.append(f"  {tick}{rank}. {cname:<38} {score * 100:.1f}%")
    return "\n".join(lines)


def predict(image, actual_class):
    """Classify one ultrasound image and build the three text panels for the UI.

    Args:
        image: image array from ``gr.Image(type="numpy")``, or None when nothing
            has been uploaded yet.
        actual_class: ground-truth class chosen in the dropdown; empty/None or
            "Unknown / Not provided" means no label was given.

    Returns:
        A 3-tuple of strings — (prediction panel, top-5 panel, ground-truth
        comparison panel) — one per output Textbox.
    """
    if image is None:
        empty = "Upload an ultrasound image to begin."
        return empty, empty, empty

    # ── Preprocess & predict ──────────────────────────────────────────────────
    proc = preprocess(image)
    resized = cv2.resize(proc, (IMG_SIZE, IMG_SIZE))
    batch = np.expand_dims(resized, axis=0)  # add batch axis → (1, IMG_SIZE, IMG_SIZE, 3)
    probs = model.predict(batch, verbose=0)[0]  # shape: (NUM_OUTPUTS,)

    top5_idx = np.argsort(probs)[::-1][:5]
    pred_idx = int(top5_idx[0])
    pred_label = _class_name(pred_idx)
    confidence = float(probs[pred_idx]) * 100.0

    dominant, source_note = _dominant_classes(probs, top5_idx, pred_label)
    return (
        _format_prediction_panel(probs, pred_idx, pred_label, confidence),
        _format_top5_panel(dominant, source_note),
        _format_actual_panel(actual_class, dominant, pred_label, confidence),
    )
# ─── Gradio UI ────────────────────────────────────────────────────────────────
# Dark theme. `.output-text textarea` gives the result panels a monospace font,
# which the box-drawing and column alignment in predict()'s output rely on; the
# class is attached to the output Textboxes via elem_classes below.
CSS = """
body, .gradio-container { background: #0d1117 !important; }
.gr-box, .gr-panel { background: #161b22 !important; border: 1px solid #30363d !important; }
.gr-button { background: #238636 !important; color: #fff !important; border: none !important; }
.gr-button:hover { background: #2ea043 !important; }
.output-text textarea { font-family: 'Courier New', monospace !important; font-size: 13px !important;
background: #0d1117 !important; color: #e6edf3 !important;
border: 1px solid #30363d !important; }
label span { color: #8b949e !important; }
h1, h2, h3 { color: #e6edf3 !important; }
"""

with gr.Blocks(css=CSS, title="Fetal Brain MRI Classifier 🧠") as demo:
    # Markdown lines are kept at column 0 so they are not rendered as a code block.
    gr.Markdown("""
# 🧠 Fetal Brain MRI Classifier
#### Ultrasound anomaly detection — Standard CNN / Xception transfer learning
Upload a fetal ultrasound image, optionally select the known ground-truth class, then click **Submit**.
""")
    with gr.Row():
        # ── Left column: inputs ──────────────────────────────────────────────
        with gr.Column(scale=1):
            image_input = gr.Image(
                type="numpy",
                label="Ultrasound Image",
                image_mode="RGB",
            )
            actual_input = gr.Dropdown(
                choices=ALL_GT_CLASSES,
                value="Unknown / Not provided",
                label="Actual Ground-Truth Class (optional)",
            )
            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit", variant="primary")
        # ── Right column: outputs ────────────────────────────────────────────
        with gr.Column(scale=2):
            cluster_out = gr.Textbox(
                label="🏆 Predicted Cluster / Class",
                lines=14,
                interactive=False,
                elem_classes=["output-text"],
            )
            top5_out = gr.Textbox(
                label="📊 Top 5 Dominant Pathology Classes",
                lines=10,
                interactive=False,
                elem_classes=["output-text"],
            )
            actual_out = gr.Textbox(
                label="✅ Actual Class Comparison",
                lines=12,
                interactive=False,
                elem_classes=["output-text"],
            )

    # ── Wire up events ───────────────────────────────────────────────────────
    submit_btn.click(
        fn=predict,
        inputs=[image_input, actual_input],
        outputs=[cluster_out, top5_out, actual_out],
    )
    # Reset: clear the image, restore the dropdown default, blank all panels.
    clear_btn.click(
        fn=lambda: (None, "Unknown / Not provided", "", "", ""),
        inputs=[],
        outputs=[image_input, actual_input, cluster_out, top5_out, actual_out],
    )

if __name__ == "__main__":
    demo.launch()