import json
import os
import warnings

import cv2
import gradio as gr
import numpy as np
from tensorflow.keras.models import load_model

# ─── Load model ───────────────────────────────────────────────────────────────
model = load_model("custom_cnn.h5")
IMG_SIZE = 224
NUM_OUTPUTS = model.output_shape[-1]  # auto-detects 3-class or 16-class

# ─── Class / cluster labels ───────────────────────────────────────────────────
# Priority 1: class_labels.json saved alongside the model (from the 16-class notebook)
# Priority 2: fallback cluster names for the 3-class K-Means model
CLASS_LABELS_PATH = "class_labels.json"


def _load_class_names(path, num_outputs):
    """Return exactly one display label per model output.

    If ``path`` exists it is read as JSON, either ``{"classes": [...]}``
    (the format the 16-class notebook writes) or a bare list of names.
    Otherwise generic ``"Cluster i"`` names are generated for the 3-class
    K-Means model.

    If the file lists a different number of labels than ``num_outputs``,
    a warning is emitted. The list is then truncated, or padded with
    ``"Class i"``, so that indexing by output position can never raise.
    """
    if not os.path.exists(path):
        # 3-class K-Means cluster model fallback
        return [f"Cluster {i}" for i in range(num_outputs)]

    # Class names may contain non-ASCII characters (e.g. the en dash in
    # "Arnold–Chiari"). Decode as UTF-8 explicitly so they compare equal to
    # the dropdown strings regardless of the platform's default encoding.
    with open(path, encoding="utf-8") as f:
        data = json.load(f)
    names = [str(n) for n in (data["classes"] if isinstance(data, dict) else data)]

    if len(names) != num_outputs:
        warnings.warn(
            f"{path} lists {len(names)} labels but the model has "
            f"{num_outputs} outputs; truncating/padding labels to match."
        )
        names = names[:num_outputs] + [
            f"Class {i}" for i in range(len(names), num_outputs)
        ]
    return names


CLASS_NAMES = _load_class_names(CLASS_LABELS_PATH, NUM_OUTPUTS)

# ─── Which actual pathology classes are dominant in each cluster ──────────────
# These come from analysing your K-Means cluster assignments vs ground-truth labels.
# REPLACE these lists with the real counts from your own cluster analysis notebook.
# Cluster name → [(pathology, score in [0, 1])], highest first. Scores are
# rendered as percentages in the UI.
CLUSTER_DOMINANT = {
    "Cluster 0": [
        ("Normal", 0.38),
        ("Mild Ventriculomegaly", 0.22),
        ("Arnold–Chiari Malformation", 0.15),
        ("Moderate Ventriculomegaly", 0.14),
        ("Hydranencephaly", 0.11),
    ],
    "Cluster 1": [
        ("Severe Ventriculomegaly", 0.35),
        ("Dandy–Walker Malformation", 0.25),
        ("Holoprosencephaly", 0.18),
        ("Agenesis of Corpus Callosum", 0.13),
        ("Intracranial Tumors", 0.09),
    ],
    "Cluster 2": [
        ("Intracranial Tumors", 0.30),
        ("Intracranial Hemorrhages", 0.28),
        ("Holoprosencephaly", 0.20),
        ("Dandy–Walker Malformation", 0.12),
        ("Agenesis of Corpus Callosum", 0.10),
    ],
}

# For the 16-class model, dominant "classes in cluster" = top-5 softmax outputs
USE_SOFTMAX_DOMINANT = (NUM_OUTPUTS > 3)

# ─── All 16 ground-truth class names for the dropdown ────────────────────────
ALL_GT_CLASSES = [
    "Normal",
    "Mild Ventriculomegaly",
    "Moderate Ventriculomegaly",
    "Severe Ventriculomegaly",
    "Arnold–Chiari Malformation",
    "Hydranencephaly",
    "Agenesis of Corpus Callosum",
    "Dandy–Walker Malformation",
    "Intracranial Tumors",
    "Intracranial Hemorrhages",
    "Holoprosencephaly",
    "Cerebellar Hypoplasia",
    "Microcephaly",
    "Macrocephaly",
    "Lissencephaly",
    "Unknown / Not provided",
]


# ─── Preprocessing — mirrors the paper §3B pipeline ──────────────────────────
def preprocess(image: np.ndarray) -> np.ndarray:
    """Gaussian blur → median filter → CLAHE → normalize [0,1].

    Args:
        image: H×W grayscale or H×W×C image (as delivered by ``gr.Image``
            with ``type="numpy"``). Values are cast to ``uint8``.

    Returns:
        H×W×3 float32 array in [0, 1]: the enhanced grayscale image
        replicated across three channels. Returns ``None`` if ``image`` is None.
    """
    if image is None:
        return None
    img = image.astype(np.uint8)

    # To grayscale
    if img.ndim == 2:
        gray = img
    elif img.shape[2] == 3:
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    elif img.shape[2] == 4:
        # RGBA: combine the colour channels instead of keeping only red.
        gray = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
    else:
        # A channel slice is a strided view; OpenCV wants contiguous memory.
        gray = np.ascontiguousarray(img[:, :, 0])

    # §3B-2: Gaussian + median
    blurred = cv2.GaussianBlur(gray, (5, 5), sigmaX=1.0)
    median = cv2.medianBlur(blurred, 5)

    # §3B-3: CLAHE
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced = clahe.apply(median)

    # Back to RGB float32 [0,1]
    return cv2.cvtColor(enhanced, cv2.COLOR_GRAY2RGB).astype(np.float32) / 255.0


# ─── EMOJI badges for ranks ───────────────────────────────────────────────────
RANK_EMOJI = ["🥇", "🥈", "🥉", "4️⃣", "5️⃣"]


# ─── Progress-bar helper ──────────────────────────────────────────────────────
def pct_bar(value: float, width: int = 28) -> str:
    """Render a fraction ``value`` in [0, 1] as a ``width``-character text bar.

    Out-of-range values (and NaN) are clamped, so the bar is always exactly
    ``width`` characters long.
    """
    frac = float(value)
    if not frac >= 0.0:  # negative or NaN
        frac = 0.0
    frac = min(frac, 1.0)
    filled = round(frac * width)
    return "█" * filled + "░" * (width - filled)


# ─── Panel formatting helpers ─────────────────────────────────────────────────
_BOX_INNER = 39  # text columns between "│ " and " │" in the prediction box


def _label(idx):
    """Return the display name for model output ``idx``, even if labels are short."""
    return CLASS_NAMES[idx] if idx < len(CLASS_NAMES) else f"Class {idx}"


def _box_row(text):
    """Left-align ``text`` in one box row, truncating it to fit the border."""
    return f"│ {text:<{_BOX_INNER}.{_BOX_INNER}} │"


def _box_rule(left, right):
    """Return a horizontal box border exactly as wide as a ``_box_row``."""
    return left + "─" * (_BOX_INNER + 2) + right


def _format_cluster_panel(probs, pred_idx, pred_label, confidence):
    """Build panel 1: boxed prediction plus every output's probability."""
    lines = [
        _box_rule("┌", "┐"),
        _box_row("PREDICTED CLUSTER / CLASS"),
        _box_rule("├", "┤"),
        _box_row(pred_label),
        _box_row(f"Confidence : {confidence:>6.2f}%"),
        _box_rule("└", "┘"),
        "",
        "All cluster probabilities:",
        "─" * 43,
    ]
    for i, p in enumerate(probs):
        marker = " ◀ PREDICTED" if i == pred_idx else ""
        lines.append(f"  {_label(i):<35} {p * 100:5.1f}%{marker}")
    return "\n".join(lines)


def _dominant_classes(probs, top5_idx, pred_label):
    """Return ``([(name, score), ...], source_note)`` for the top-5 panel.

    The result is capped at ``len(RANK_EMOJI)`` entries so every row has a
    rank badge.
    """
    if USE_SOFTMAX_DOMINANT:
        # 16-class model — dominant = top-5 softmax outputs
        dominant = [(_label(i), float(probs[i])) for i in top5_idx]
        note = f"(direct softmax outputs from {NUM_OUTPUTS}-class model)"
    else:
        # 3-class cluster model — look up pre-computed dominant pathologies
        dominant = CLUSTER_DOMINANT.get(
            pred_label, [(f"Class {j}", 0.2) for j in range(5)]
        )
        note = f"(pathologies most common in {pred_label})"
    return list(dominant[: len(RANK_EMOJI)]), note


def _format_top5_panel(dominant, source_note):
    """Build panel 2: ranked dominant classes with percentage bars."""
    lines = [f"TOP 5 DOMINANT PATHOLOGY CLASSES {source_note}", "─" * 63, ""]
    for emoji, (cname, score) in zip(RANK_EMOJI, dominant):
        lines.append(f"  {emoji}  {cname:<40} {pct_bar(score)} {score * 100:5.1f}%")
    return "\n".join(lines)


def _format_actual_panel(actual_class, dominant, pred_label, confidence):
    """Build panel 3: compare the user-supplied ground truth with the prediction.

    16-class mode requires an exact label match. Cluster mode counts the
    prediction correct when ``actual_class`` is among the cluster's
    dominant pathologies.
    """
    if not actual_class or actual_class == "Unknown / Not provided":
        return "\n".join([
            "ℹ️  No ground-truth label provided.",
            "",
            "Select the actual class from the dropdown",
            "on the left to see a correctness check.",
        ])

    dominant_names = [name for name, _ in dominant]
    if USE_SOFTMAX_DOMINANT:
        # 16-class model: direct label match
        if actual_class == pred_label:
            match_str = "✅ CORRECT PREDICTION"
        else:
            match_str = f"❌ INCORRECT (model predicted '{pred_label}')"
    elif actual_class in dominant_names:
        # Cluster model: soft match — the actual class is in the cluster's top-5
        rank_pos = dominant_names.index(actual_class) + 1
        match_str = f"✅ CORRECT CLUSTER ('{actual_class}' is #{rank_pos} in {pred_label})"
    else:
        match_str = (
            f"⚠️  PARTIAL MISS ('{actual_class}' not in top-5 of {pred_label})\n"
            f"   This may indicate a cluster assignment issue or borderline case."
        )

    heading = (
        "Top-5 predicted classes:"
        if USE_SOFTMAX_DOMINANT
        else "Top-5 dominant classes in predicted cluster:"
    )
    lines = [
        "GROUND TRUTH vs PREDICTION",
        "─" * 43,
        "",
        f"  Actual class : {actual_class}",
        f"  Predicted    : {pred_label}  ({confidence:.1f}%)",
        "",
        f"  {match_str}",
        "",
        "─" * 43,
        heading,
    ]
    for rank, (cname, score) in enumerate(dominant, start=1):
        # Both markers are two characters wide so the rows stay aligned.
        tick = " ✓" if cname == actual_class else "  "
        lines.append(f"  {tick} {rank}. {cname:<38} {score * 100:.1f}%")
    return "\n".join(lines)


# ─── Main prediction function ─────────────────────────────────────────────────
def predict(image, actual_class):
    """Classify one ultrasound image and render the three output panels.

    Args:
        image: numpy image from ``gr.Image`` or ``None`` when nothing is uploaded.
        actual_class: ground-truth label chosen in the dropdown (may be empty
            or ``"Unknown / Not provided"``).

    Returns:
        ``(cluster_text, top5_text, actual_text)`` strings for the three textboxes.
    """
    if image is None:
        empty = "Upload an ultrasound image to begin."
        return empty, empty, empty

    # ── Preprocess & predict ──────────────────────────────────────────────────
    proc = preprocess(image)
    resized = cv2.resize(proc, (IMG_SIZE, IMG_SIZE))
    inp = np.expand_dims(resized, axis=0)
    probs = model.predict(inp, verbose=0)[0]  # shape: (num_classes,)

    top5_idx = np.argsort(probs)[::-1][:5]
    pred_idx = int(top5_idx[0])
    pred_label = _label(pred_idx)
    confidence = float(probs[pred_idx]) * 100.0

    dominant, source_note = _dominant_classes(probs, top5_idx, pred_label)
    return (
        _format_cluster_panel(probs, pred_idx, pred_label, confidence),
        _format_top5_panel(dominant, source_note),
        _format_actual_panel(actual_class, dominant, pred_label, confidence),
    )


# ─── Gradio UI ────────────────────────────────────────────────────────────────
CSS = """
body, .gradio-container { background: #0d1117 !important; }
.gr-box, .gr-panel { background: #161b22 !important; border: 1px solid #30363d !important; }
.gr-button { background: #238636 !important; color: #fff !important; border: none !important; }
.gr-button:hover { background: #2ea043 !important; }
.output-text textarea {
    font-family: 'Courier New', monospace !important;
    font-size: 13px !important;
    background: #0d1117 !important;
    color: #e6edf3 !important;
    border: 1px solid #30363d !important;
}
label span { color: #8b949e !important; }
h1, h2, h3 { color: #e6edf3 !important; }
"""

with gr.Blocks(css=CSS, title="Fetal Brain MRI Classifier 🧠") as demo:
    gr.Markdown(
        """
# 🧠 Fetal Brain MRI Classifier
#### Ultrasound anomaly detection — Standard CNN / Xception transfer learning

Upload a fetal ultrasound image, optionally select the known ground-truth class, then click **Submit**.
"""
    )

    with gr.Row():
        # ── Left column: inputs ──────────────────────────────────────────────
        with gr.Column(scale=1):
            image_input = gr.Image(
                type="numpy",
                label="Ultrasound Image",
                image_mode="RGB",
            )
            actual_input = gr.Dropdown(
                choices=ALL_GT_CLASSES,
                value="Unknown / Not provided",
                label="Actual Ground-Truth Class (optional)",
            )
            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit", variant="primary")

        # ── Right column: outputs ────────────────────────────────────────────
        # elem_classes attaches the ".output-text" CSS rule (monospace font),
        # which the box-drawing panels need to line up.
        with gr.Column(scale=2):
            cluster_out = gr.Textbox(
                label="🏆 Predicted Cluster / Class",
                lines=14,
                interactive=False,
                elem_classes=["output-text"],
            )
            top5_out = gr.Textbox(
                label="📊 Top 5 Dominant Pathology Classes",
                lines=10,
                interactive=False,
                elem_classes=["output-text"],
            )
            actual_out = gr.Textbox(
                label="✅ Actual Class Comparison",
                lines=12,
                interactive=False,
                elem_classes=["output-text"],
            )

    # ── Wire up events ───────────────────────────────────────────────────────
    submit_btn.click(
        fn=predict,
        inputs=[image_input, actual_input],
        outputs=[cluster_out, top5_out, actual_out],
    )
    clear_btn.click(
        fn=lambda: (None, "Unknown / Not provided", "", "", ""),
        inputs=[],
        outputs=[image_input, actual_input, cluster_out, top5_out, actual_out],
    )

demo.launch()