Spaces:
Running
Running
| import gradio as gr | |
| import numpy as np | |
| import cv2 | |
| import json | |
| import os | |
| from tensorflow.keras.models import load_model | |
| # βββ Load model βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| model = load_model("model.h5") | |
| IMG_SIZE = 224 | |
| NUM_OUTPUTS = model.output_shape[-1] # auto-detects 3-class or 16-class | |
| # βββ Class / cluster labels βββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Priority 1: class_labels.json saved alongside the model (from the 16-class notebook) | |
| # Priority 2: fallback cluster names for the 3-class K-Means model | |
| if os.path.exists("class_labels.json"): | |
| with open("class_labels.json") as f: | |
| CLASS_NAMES = json.load(f)["classes"] | |
| else: | |
| # 3-class K-Means cluster model fallback | |
| CLASS_NAMES = [f"Cluster {i}" for i in range(NUM_OUTPUTS)] | |
| # βββ Which actual pathology classes are dominant in each cluster ββββββββββββββ | |
| # These come from analysing your K-Means cluster assignments vs ground-truth labels. | |
| # REPLACE these lists with the real counts from your own cluster analysis notebook. | |
| CLUSTER_DOMINANT = { | |
| "Cluster 0": [ | |
| ("Normal", 0.38), | |
| ("Mild Ventriculomegaly", 0.22), | |
| ("ArnoldβChiari Malformation",0.15), | |
| ("Moderate Ventriculomegaly", 0.14), | |
| ("Hydranencephaly", 0.11), | |
| ], | |
| "Cluster 1": [ | |
| ("Severe Ventriculomegaly", 0.35), | |
| ("DandyβWalker Malformation", 0.25), | |
| ("Holoprosencephaly", 0.18), | |
| ("Agenesis of Corpus Callosum",0.13), | |
| ("Intracranial Tumors", 0.09), | |
| ], | |
| "Cluster 2": [ | |
| ("Intracranial Tumors", 0.30), | |
| ("Intracranial Hemorrhages", 0.28), | |
| ("Holoprosencephaly", 0.20), | |
| ("DandyβWalker Malformation", 0.12), | |
| ("Agenesis of Corpus Callosum",0.10), | |
| ], | |
| } | |
| # For the 16-class model, dominant "classes in cluster" = top-5 softmax outputs | |
| USE_SOFTMAX_DOMINANT = (NUM_OUTPUTS > 3) | |
| # βββ All 16 ground-truth class names for the dropdown ββββββββββββββββββββββββ | |
| ALL_GT_CLASSES = [ | |
| "Normal", | |
| "Mild Ventriculomegaly", | |
| "Moderate Ventriculomegaly", | |
| "Severe Ventriculomegaly", | |
| "ArnoldβChiari Malformation", | |
| "Hydranencephaly", | |
| "Agenesis of Corpus Callosum", | |
| "DandyβWalker Malformation", | |
| "Intracranial Tumors", | |
| "Intracranial Hemorrhages", | |
| "Holoprosencephaly", | |
| "Cerebellar Hypoplasia", | |
| "Microcephaly", | |
| "Macrocephaly", | |
| "Lissencephaly", | |
| "Unknown / Not provided", | |
| ] | |
| # βββ Preprocessing β mirrors the paper Β§3B pipeline ββββββββββββββββββββββββββ | |
| def preprocess(image: np.ndarray) -> np.ndarray: | |
| """Gaussian blur β median filter β CLAHE β normalize [0,1].""" | |
| if image is None: | |
| return None | |
| img = image.astype(np.uint8) | |
| # To grayscale | |
| if img.ndim == 3 and img.shape[2] == 3: | |
| gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) | |
| else: | |
| gray = img if img.ndim == 2 else img[:, :, 0] | |
| # Β§3B-2: Gaussian + median | |
| blurred = cv2.GaussianBlur(gray, (5, 5), sigmaX=1.0) | |
| median = cv2.medianBlur(blurred, 5) | |
| # Β§3B-3: CLAHE | |
| clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) | |
| enhanced = clahe.apply(median) | |
| # Back to RGB float32 [0,1] | |
| rgb = cv2.cvtColor(enhanced, cv2.COLOR_GRAY2RGB).astype(np.float32) / 255.0 | |
| return rgb | |
| # βββ EMOJI badges for ranks βββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| RANK_EMOJI = ["π₯", "π₯", "π₯", "4οΈβ£", "5οΈβ£"] | |
| # βββ Progress-bar helper ββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def pct_bar(value: float, width: int = 28) -> str: | |
| filled = round(value * width) | |
| return "β" * filled + "β" * (width - filled) | |
| # βββ Main prediction function βββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def predict(image, actual_class): | |
| if image is None: | |
| empty = "Upload an ultrasound image to begin." | |
| return empty, empty, empty | |
| # ββ Preprocess & predict ββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| proc = preprocess(image) | |
| resized = cv2.resize(proc, (IMG_SIZE, IMG_SIZE)) | |
| inp = np.expand_dims(resized, axis=0) | |
| probs = model.predict(inp, verbose=0)[0] # shape: (num_classes,) | |
| top5_idx = np.argsort(probs)[::-1][:5] | |
| pred_idx = top5_idx[0] | |
| pred_label = CLASS_NAMES[pred_idx] | |
| confidence = probs[pred_idx] * 100.0 | |
| # ββ Panel 1: Prediction cluster βββββββββββββββββββββββββββββββββββββββββββ | |
| cluster_lines = [ | |
| "βββββββββββββββββββββββββββββββββββββββββββ", | |
| f"β PREDICTED CLUSTER / CLASS β", | |
| "βββββββββββββββββββββββββββββββββββββββββββ€", | |
| f"β {pred_label:<39} β", | |
| f"β Confidence : {confidence:>6.2f}% β", | |
| "βββββββββββββββββββββββββββββββββββββββββββ", | |
| "", | |
| "All cluster probabilities:", | |
| "β" * 43, | |
| ] | |
| for i, (cname, p) in enumerate(zip(CLASS_NAMES, probs)): | |
| marker = " β PREDICTED" if i == pred_idx else "" | |
| cluster_lines.append( | |
| f" {cname:<35} {p*100:5.1f}%{marker}" | |
| ) | |
| cluster_text = "\n".join(cluster_lines) | |
| # ββ Panel 2: Top-5 dominant classes ββββββββββββββββββββββββββββββββββββββ | |
| if USE_SOFTMAX_DOMINANT: | |
| # 16-class model β dominant = top-5 softmax outputs | |
| dominant = [(CLASS_NAMES[i], float(probs[i])) for i in top5_idx] | |
| source_note = f"(direct softmax outputs from {NUM_OUTPUTS}-class model)" | |
| else: | |
| # 3-class cluster model β look up pre-computed dominant pathologies | |
| dominant = CLUSTER_DOMINANT.get( | |
| pred_label, | |
| [(f"Class {j}", 0.2) for j in range(5)] | |
| ) | |
| source_note = f"(pathologies most common in {pred_label})" | |
| top5_lines = [ | |
| f"TOP 5 DOMINANT PATHOLOGY CLASSES {source_note}", | |
| "β" * 63, | |
| "", | |
| ] | |
| for rank, (cname, score) in enumerate(dominant): | |
| bar = pct_bar(score) | |
| emoji = RANK_EMOJI[rank] | |
| top5_lines.append( | |
| f" {emoji} {cname:<40} {bar} {score*100:5.1f}%" | |
| ) | |
| top5_text = "\n".join(top5_lines) | |
| # ββ Panel 3: Actual class comparison βββββββββββββββββββββββββββββββββββββ | |
| if not actual_class or actual_class == "Unknown / Not provided": | |
| actual_lines = [ | |
| "βΉοΈ No ground-truth label provided.", | |
| "", | |
| "Select the actual class from the dropdown", | |
| "on the left to see a correctness check.", | |
| ] | |
| else: | |
| # For cluster model: check if actual class appears in the top-5 dominant list | |
| dominant_names = [d[0] for d in dominant] | |
| in_top5 = actual_class in dominant_names | |
| # For 16-class model: direct label match | |
| if USE_SOFTMAX_DOMINANT: | |
| correct = (actual_class == pred_label) | |
| match_str = "β CORRECT PREDICTION" if correct else f"β INCORRECT (model predicted '{pred_label}')" | |
| else: | |
| # Cluster model: soft match β is the actual class in the cluster's top-5? | |
| if in_top5: | |
| rank_pos = dominant_names.index(actual_class) + 1 | |
| match_str = f"β CORRECT CLUSTER ('{actual_class}' is #{rank_pos} in {pred_label})" | |
| else: | |
| match_str = ( | |
| f"β οΈ PARTIAL MISS ('{actual_class}' not in top-5 of {pred_label})\n" | |
| f" This may indicate a cluster assignment issue or borderline case." | |
| ) | |
| actual_lines = [ | |
| "GROUND TRUTH vs PREDICTION", | |
| "β" * 43, | |
| "", | |
| f" Actual class : {actual_class}", | |
| f" Predicted : {pred_label} ({confidence:.1f}%)", | |
| "", | |
| f" {match_str}", | |
| "", | |
| "β" * 43, | |
| "Top-5 dominant classes in predicted cluster:", | |
| ] | |
| for rank, (cname, score) in enumerate(dominant): | |
| tick = " β" if cname == actual_class else " " | |
| actual_lines.append(f" {tick} {rank+1}. {cname:<38} {score*100:.1f}%") | |
| actual_text = "\n".join(actual_lines) | |
| return cluster_text, top5_text, actual_text | |
| # βββ Gradio UI ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| CSS = """ | |
| body, .gradio-container { background: #0d1117 !important; } | |
| .gr-box, .gr-panel { background: #161b22 !important; border: 1px solid #30363d !important; } | |
| .gr-button { background: #238636 !important; color: #fff !important; border: none !important; } | |
| .gr-button:hover { background: #2ea043 !important; } | |
| .output-text textarea { font-family: 'Courier New', monospace !important; font-size: 13px !important; | |
| background: #0d1117 !important; color: #e6edf3 !important; | |
| border: 1px solid #30363d !important; } | |
| label span { color: #8b949e !important; } | |
| h1, h2, h3 { color: #e6edf3 !important; } | |
| """ | |
| with gr.Blocks(css=CSS, title="Fetal Brain MRI Classifier π§ ") as demo: | |
| gr.Markdown(""" | |
| # π§ Fetal Brain MRI Classifier | |
| #### Ultrasound anomaly detection β Standard CNN / Xception transfer learning | |
| Upload a fetal ultrasound image, optionally select the known ground-truth class, then click **Submit**. | |
| """) | |
| with gr.Row(): | |
| # ββ Left column: inputs ββββββββββββββββββββββββββββββββββββββββββββββ | |
| with gr.Column(scale=1): | |
| image_input = gr.Image( | |
| type="numpy", | |
| label="Ultrasound Image", | |
| image_mode="RGB", | |
| ) | |
| actual_input = gr.Dropdown( | |
| choices=ALL_GT_CLASSES, | |
| value="Unknown / Not provided", | |
| label="Actual Ground-Truth Class (optional)", | |
| ) | |
| with gr.Row(): | |
| clear_btn = gr.Button("Clear") | |
| submit_btn = gr.Button("Submit", variant="primary") | |
| # ββ Right column: outputs ββββββββββββββββββββββββββββββββββββββββββββ | |
| with gr.Column(scale=2): | |
| cluster_out = gr.Textbox( | |
| label="π Predicted Cluster / Class", | |
| lines=14, | |
| interactive=False, | |
| ) | |
| top5_out = gr.Textbox( | |
| label="π Top 5 Dominant Pathology Classes", | |
| lines=10, | |
| interactive=False, | |
| ) | |
| actual_out = gr.Textbox( | |
| label="β Actual Class Comparison", | |
| lines=12, | |
| interactive=False, | |
| ) | |
| # ββ Wire up events βββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| submit_btn.click( | |
| fn=predict, | |
| inputs=[image_input, actual_input], | |
| outputs=[cluster_out, top5_out, actual_out], | |
| ) | |
| clear_btn.click( | |
| fn=lambda: (None, "Unknown / Not provided", "", "", ""), | |
| inputs=[], | |
| outputs=[image_input, actual_input, cluster_out, top5_out, actual_out], | |
| ) | |
| demo.launch() |