"""Gradio demo for a multi-task fetal-ultrasound model.

The app exposes four tabs, one per model head (plane classification, heart
segmentation, Down-syndrome markers, brain anomaly classification).  The
checkpoint is fetched from the Hugging Face Hub on first use.
"""

import io
import math
import os
import threading
import traceback
import warnings

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib
matplotlib.use("Agg")  # headless backend, must be set before pyplot import
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import gradio as gr
from PIL import Image
from torchvision.transforms import v2
from huggingface_hub import hf_hub_download

from model_arch import (
    FetalMTLModel,
    DualFrequencyCascadeGraph,
    InferenceConfig,
    load_checkpoint,
    compute_ctr,
)

warnings.filterwarnings("ignore")

# ── Runtime device ────────────────────────────────────────────────────────────
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ── Checkpoint source — override via Space env vars ───────────────────────────
MODEL_REPO_ID = os.environ.get("MODEL_REPO_ID", "AliMusaRizvi/planeclassifier")
CHECKPOINT_FILE = os.environ.get("CHECKPOINT_FILE", "best_ema_4heads_20260331_0355.pth")
HF_TOKEN = os.environ.get("HF_TOKEN", None)  # needed for private repos

# ── ImageNet normalisation tensors ────────────────────────────────────────────
# Created directly on DEVICE, shaped to broadcast over (N, 3, H, W) batches.
_MEAN = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).view(1, 3, 1, 1)
_STD = torch.tensor([0.229, 0.224, 0.225], device=DEVICE).view(1, 3, 1, 1)

# ── Segmentation colour palette (RGB) ────────────────────────────────────────
SEG_PALETTE = {0: (0, 0, 0), 1: (220, 50, 50), 2: (50, 100, 220)}

# ── Global model state ────────────────────────────────────────────────────────
_model = None
_gpu_diff = None
_cls_names = None
_brain_names = None
_load_error = None
# Gradio may run handlers on worker threads; serialise the one-time model load.
_load_lock = threading.Lock()


# ── Pre-processing ────────────────────────────────────────────────────────────
def _clahe(pil_img: Image.Image) -> Image.Image:
    """Return an RGB copy of *pil_img* with CLAHE applied to the LAB L channel."""
    rgb = np.array(pil_img.convert("RGB"), dtype=np.uint8)
    lab = cv2.cvtColor(rgb, cv2.COLOR_RGB2LAB)
    lightness, chan_a, chan_b = cv2.split(lab)
    equaliser = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    merged = cv2.merge((equaliser.apply(lightness), chan_a, chan_b))
    return Image.fromarray(cv2.cvtColor(merged, cv2.COLOR_LAB2RGB))


def _preprocess(pil_img: Image.Image, size: int) -> torch.Tensor:
    """CLAHE → resize → DualFrequencyGraph → ImageNet-norm → clamp.

    Returns a ``(1, 3, size, size)`` float tensor on ``DEVICE`` (assuming the
    frequency-graph module preserves shape — TODO confirm against model_arch).

    Raises:
        RuntimeError: if the model (and its pre-processing module) is not loaded.
    """
    if _gpu_diff is None:
        raise RuntimeError("Model is not loaded; call _load_model() first.")

    img = _clahe(pil_img)  # _clahe converts to RGB itself
    tfm = v2.Compose([
        v2.ToImage(),
        v2.Resize((size, size), antialias=True),
        v2.ToDtype(torch.float32, scale=True),
    ])
    with torch.no_grad():
        img_t = tfm(img).unsqueeze(0).to(DEVICE)
        img_t = _gpu_diff(img_t)
        img_t = (img_t - _MEAN) / _STD
        return img_t.clamp(-6.0, 6.0)


# ── Lazy model loader ─────────────────────────────────────────────────────────
def _load_model():
    """Download the checkpoint and build the model exactly once.

    Safe to call repeatedly and from several threads: after the first attempt
    (successful or not) the call returns immediately.  On failure the formatted
    traceback is stored in ``_load_error`` instead of being raised.
    """
    global _model, _gpu_diff, _cls_names, _brain_names, _load_error

    with _load_lock:
        if _model is not None or _load_error is not None:
            return
        try:
            print(f"[startup] Downloading checkpoint from {MODEL_REPO_ID}/{CHECKPOINT_FILE} …")
            ckpt_path = hf_hub_download(
                repo_id=MODEL_REPO_ID,
                filename=CHECKPOINT_FILE,
                token=HF_TOKEN,
            )

            # SECURITY NOTE: weights_only=False unpickles arbitrary objects.
            # Only point MODEL_REPO_ID at repositories you trust.
            ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
            num_cls = len(ckpt.get("class_names", [])) or 11
            num_brain = len(ckpt.get("brain_class_names", [])) or 7
            # Only the head sizes were needed; free this copy before
            # load_checkpoint reads the file again.
            del ckpt

            print(f"[startup] Building FetalMTLModel (cls={num_cls}, brain={num_brain}) …")
            model, cls_names, brain_names = load_checkpoint(
                ckpt_path, num_cls=num_cls, num_brain=num_brain, device=DEVICE)

            # NOTE(review): this module is freshly constructed with
            # learnable=True and no weights are loaded into it here — confirm
            # that is intended.
            gpu_diff = DualFrequencyCascadeGraph(
                low_freq_ratio=0.3, patch_size=16, n_graph_iter=3, learnable=True
            ).to(DEVICE).eval()

            # Publish _model last so it is never visible without its companions.
            _gpu_diff = gpu_diff
            _cls_names = cls_names
            _brain_names = brain_names
            _model = model

            print(f"[startup] Model ready on {DEVICE}")
            print(f"[startup] Plane classes : {_cls_names}")
            print(f"[startup] Brain classes : {_brain_names}")
        except Exception:
            # Top-level boundary: record the failure so the UI can report it.
            _load_error = traceback.format_exc()
            print(f"[startup] FAILED to load model:\n{_load_error}")


# ── Visualisation helpers ─────────────────────────────────────────────────────
def _pil_figure(fig, dpi: int = 120) -> Image.Image:
    """Render *fig* to a PIL image and close it, even if rendering fails."""
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format="png", dpi=dpi, bbox_inches="tight")
    finally:
        plt.close(fig)
    buf.seek(0)
    # .copy() detaches the image from the buffer so it can be garbage-collected.
    return Image.open(buf).copy()


def _bar_chart(class_names, probs, title="") -> Image.Image:
    """Return a horizontal bar chart of *probs*, highlighting the maximum."""
    height = max(3.0, len(class_names) * 0.45)
    fig, ax = plt.subplots(figsize=(7, height))
    try:
        peak = max(probs, default=0.0)  # tolerate an empty distribution
        colors = ["#4CAF50" if p == peak else "#2196F3" for p in probs]
        bars = ax.barh(class_names, probs, color=colors, edgecolor="k", linewidth=0.4)
        for bar, p in zip(bars, probs):
            ax.text(bar.get_width() + 0.01, bar.get_y() + bar.get_height() / 2,
                    f"{p:.3f}", va="center", fontsize=8)
        ax.set_xlim(0, 1.12)
        ax.set_xlabel("Confidence")
        if title:
            ax.set_title(title, fontsize=10)
        ax.grid(axis="x", alpha=0.3)
        # Lay out this figure explicitly rather than pyplot's "current" one.
        fig.tight_layout()
    except Exception:
        plt.close(fig)
        raise
    return _pil_figure(fig, dpi=100)


def _seg_overlay(orig_np: np.ndarray, mask_np: np.ndarray, alpha=0.45) -> np.ndarray:
    """Alpha-blend the non-background ``SEG_PALETTE`` colours onto *orig_np*.

    If the mask resolution differs from the image it is resized with
    nearest-neighbour interpolation so class ids are preserved.
    """
    height, width = orig_np.shape[:2]
    if mask_np.shape[:2] != (height, width):
        mask_np = cv2.resize(mask_np.astype(np.uint8), (width, height),
                             interpolation=cv2.INTER_NEAREST)

    overlay = orig_np.copy()
    for cls_id, rgb in SEG_PALETTE.items():
        if cls_id == 0:  # background stays untouched
            continue
        region = mask_np == cls_id
        blended = (
            alpha * np.array(rgb, dtype=np.float32)
            + (1 - alpha) * orig_np[region].astype(np.float32)
        )
        overlay[region] = blended.astype(np.uint8)
    return overlay


def _tail_text(text: str, limit: int) -> str:
    """Return the last *limit* characters of *text*.

    Python tracebacks end with the exception type and message, so the tail is
    the part worth showing when space is limited.
    """
    return text[-limit:]


def _inference_error() -> str:
    """Format the exception currently being handled as a Markdown snippet."""
    return f"**Inference error:**\n```\n{_tail_text(traceback.format_exc(), 400)}\n```"


def _guard():
    """Returns an error string if model isn't ready, else None.

    Triggers the lazy model load, so the app also works when this module is
    imported rather than executed as a script.
    """
    _load_model()
    if _load_error:
        return (
            f"**Model failed to load.**\n\n"
            f"Set the `MODEL_REPO_ID` and `CHECKPOINT_FILE` Space secrets.\n\n"
            f"```\n{_tail_text(_load_error, 600)}\n```"
        )
    if _model is None:
        return "**Model is still loading — please wait a moment.**"
    return None


def _classify(pil_img, task, logits_key, class_names, heading, chart_title):
    """Run a classification head and return ``(markdown_text, chart_image)``.

    *task* is forwarded to the model, *logits_key* selects the logits in its
    output dict, and *heading* fills "Predicted <heading>:" in the text.
    """
    img_t = _preprocess(pil_img, InferenceConfig.BACKBONE_IMG_SIZE)
    with torch.no_grad():
        out = _model(img_t, task=task)
        probs = F.softmax(out[logits_key], dim=1).squeeze(0).cpu().tolist()

    names = list(class_names or [])
    if len(names) != len(probs):
        # Checkpoint carried no (or mismatching) names: fall back to indices.
        names = [f"class-{i}" for i in range(len(probs))]
    pretty = [name.replace("-", " ").title() for name in names]

    best = max(range(len(probs)), key=probs.__getitem__)
    text = f"**Predicted {heading}:** {pretty[best]}\n**Confidence:** {probs[best]:.1%}"
    return text, _bar_chart(pretty, probs, chart_title)


# ── Head 1: Plane Classification ─────────────────────────────────────────────
def predict_plane(pil_img):
    """Gradio handler: return ``(markdown, chart)`` for the plane classifier."""
    if (err := _guard()):
        return err, None
    if pil_img is None:
        return "Upload an image first.", None
    try:
        return _classify(pil_img, "cls", "logits", _cls_names,
                         "plane", "Plane Classification Probabilities")
    except Exception:
        return _inference_error(), None


# ── Head 2: Heart Segmentation ────────────────────────────────────────────────
def predict_segmentation(pil_img):
    """Gradio handler: return ``(overlay_image, ctr_markdown)``."""
    if (err := _guard()):
        return None, err
    if pil_img is None:
        return None, "Upload an image first."
    try:
        sz = 384
        img_t = _preprocess(pil_img, sz)
        with torch.no_grad():
            out = _model(img_t, task="seg")
        mask = out["seg"].argmax(1).squeeze(0).cpu().numpy()
        orig = np.array(pil_img.convert("RGB").resize((sz, sz)))
        # Compute the overlay before opening a figure so a failure leaks nothing.
        overlay = _seg_overlay(orig, mask)

        legend = [
            mpatches.Patch(color=tuple(c / 255 for c in SEG_PALETTE[1]), label="Cardiac"),
            mpatches.Patch(color=tuple(c / 255 for c in SEG_PALETTE[2]), label="Thorax"),
        ]
        fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(11, 5))
        try:
            ax0.imshow(orig)
            ax0.set_title("Input")
            ax0.axis("off")
            ax1.imshow(overlay)
            ax1.legend(handles=legend, loc="lower right", fontsize=8)
            ax1.set_title("Segmentation")
            ax1.axis("off")
        except Exception:
            plt.close(fig)
            raise
        seg_img = _pil_figure(fig)

        ctr = compute_ctr(mask)
        if ctr["CTR_area"] is not None:
            ctr_text = (
                f"**CTR (area-based):** {ctr['CTR_area']:.4f}\n"
                f"**CTR (diameter):** {ctr['CTR_diam']:.4f}\n"
                f"**Assessment:** {ctr['flag']}"
            )
        else:
            ctr_text = f"**Assessment:** {ctr['flag']}"
        return seg_img, ctr_text
    except Exception:
        return None, _inference_error()


# ── Head 3: Down Syndrome Markers ─────────────────────────────────────────────
def predict_ds(pil_img):
    """Gradio handler: return ``(heatmap_image, markdown)`` for NT/NB markers."""
    if (err := _guard()):
        return None, err
    if pil_img is None:
        return None, "Upload an image first."
    try:
        sz = InferenceConfig.DS_IMG_SIZE
        img_t = _preprocess(pil_img, sz)
        # Zero bbox → model falls back to global average features for NB ROI
        nb_bbox = torch.zeros(1, 4, device=DEVICE)
        with torch.no_grad():
            out = _model(img_t, task="ds", nb_bboxes=nb_bbox)
        pred = out["ds"]

        nt_mm = float(pred["nt_thickness_mm"].squeeze().item())
        nb_prob = (
            float(torch.sigmoid(pred["nb_logit"]).squeeze().item())
            if pred["nb_logit"] is not None else None
        )

        Hm = InferenceConfig.NT_HEATMAP_SZ
        hm_np = pred["nt_heatmaps"].squeeze(0).cpu().numpy()
        # Sum the first two heatmap channels and min-max normalise for display.
        hm_vis = hm_np[0] + hm_np[1]
        hm_vis = (hm_vis - hm_vis.min()) / (hm_vis.max() - hm_vis.min() + 1e-8)
        orig_sm = np.array(pil_img.convert("RGB").resize((Hm, Hm)))

        fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(10, 4))
        try:
            ax0.imshow(orig_sm)
            # Title reflects the configured heatmap size instead of a fixed 96.
            ax0.set_title(f"Input ({Hm}×{Hm})")
            ax0.axis("off")
            ax1.imshow(orig_sm)
            ax1.imshow(hm_vis, cmap="hot", alpha=0.55)
            ax1.set_title("NT Keypoint Heatmap (top+bottom KP)")
            ax1.axis("off")
        except Exception:
            plt.close(fig)
            raise
        vis_img = _pil_figure(fig)

        risk_str = ("⚠ High-risk (NT ≥ 3.5 mm)" if nt_mm >= 3.5
                    else "Normal range (NT < 3.5 mm)")
        nb_str = (f"**NB present probability:** {nb_prob:.1%}"
                  if nb_prob is not None
                  else "**NB:** N/A (provide NT bounding box for accurate NB prediction)")
        text = (
            f"**NT Thickness:** {nt_mm:.2f} mm — {risk_str}\n\n"
            f"{nb_str}\n\n"
            f"*Note: NB accuracy is limited without a NT bounding-box annotation.*"
        )
        return vis_img, text
    except Exception:
        return None, _inference_error()


# ── Head 4: Brain Anomaly Classification ──────────────────────────────────────
def predict_brain(pil_img):
    """Gradio handler: return ``(markdown, chart)`` for the brain classifier."""
    if (err := _guard()):
        return err, None
    if pil_img is None:
        return "Upload an image first.", None
    try:
        return _classify(pil_img, "brain", "brain_logits", _brain_names,
                         "condition", "Brain Anomaly Probabilities")
    except Exception:
        return _inference_error(), None


# ── Gradio UI ─────────────────────────────────────────────────────────────────
_DISCLAIMER = (
    "> ⚠ **Research demo only** — not validated for clinical use. "
    "Do not use to guide medical decisions. Consult a qualified clinician."
)

with gr.Blocks(title="Fetal Ultrasound MTL", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🫁 Fetal Ultrasound Multi-Task Analysis")
    gr.Markdown(
        "**Backbone:** ViL2-small + DynamicGraphTransformer (GNN) | "
        "**4 diagnostic heads** trained jointly"
    )
    gr.Markdown(_DISCLAIMER)

    with gr.Tabs():
        # ── Head 1 ────────────────────────────────────────────────────────
        with gr.TabItem(" Plane Classification"):
            gr.Markdown(
                "Classify the fetal US image into one of the 11 standard planes "
                "(brain-cb, brain-tv, abdominal, femur, etc.)."
            )
            with gr.Row():
                h1_img = gr.Image(type="pil", label="Upload ultrasound image")
                with gr.Column():
                    h1_text = gr.Markdown()
                    h1_chart = gr.Image(type="pil", label="Class probabilities")
            h1_btn = gr.Button("Classify plane", variant="primary")
            h1_btn.click(predict_plane, inputs=h1_img, outputs=[h1_text, h1_chart])

        # ── Head 2 ────────────────────────────────────────────────────────
        with gr.TabItem(" Heart Segmentation"):
            gr.Markdown(
                "Segment cardiac and thoracic structures from the four-chamber view "
                "and compute the Cardiothoracic Ratio (CTR)."
            )
            with gr.Row():
                h2_img = gr.Image(type="pil", label="Upload cardiac view image")
                with gr.Column():
                    h2_seg = gr.Image(type="pil", label="Segmentation overlay")
                    h2_ctr = gr.Markdown()
            h2_btn = gr.Button("Segment", variant="primary")
            h2_btn.click(predict_segmentation, inputs=h2_img, outputs=[h2_seg, h2_ctr])

        # ── Head 3 ────────────────────────────────────────────────────────
        with gr.TabItem(" Down Syndrome Markers"):
            gr.Markdown(
                "Estimate Nuchal Translucency (NT) thickness and Nasal Bone (NB) presence "
                "from the sagittal facial/neck view. "
                "NT ≥ 3.5 mm is a high-risk threshold."
            )
            with gr.Row():
                h3_img = gr.Image(type="pil", label="Upload sagittal view")
                with gr.Column():
                    h3_vis = gr.Image(type="pil", label="NT heatmap overlay")
                    h3_text = gr.Markdown()
            h3_btn = gr.Button("Analyse", variant="primary")
            h3_btn.click(predict_ds, inputs=h3_img, outputs=[h3_vis, h3_text])

        # ── Head 4 ────────────────────────────────────────────────────────
        with gr.TabItem(" Brain Anomaly"):
            gr.Markdown(
                "Classify fetal brain ultrasound into normal or one of several anomaly "
                "categories (ventriculomegaly, holoprosencephaly, arachnoid cyst, etc.)."
            )
            with gr.Row():
                h4_img = gr.Image(type="pil", label="Upload brain ultrasound")
                with gr.Column():
                    h4_text = gr.Markdown()
                    h4_chart = gr.Image(type="pil", label="Class probabilities")
            h4_btn = gr.Button("Classify", variant="primary")
            h4_btn.click(predict_brain, inputs=h4_img, outputs=[h4_text, h4_chart])

    gr.Markdown(
        "---\n"
        "**Model:** ViL2-small + DynamicGraphTransformer | "
        "**Training platform:** Google Colab A100 | "
        "**Framework:** PyTorch"
    )


if __name__ == "__main__":
    # Load eagerly when run as a script so the first request is not slowed
    # down by the checkpoint download.
    _load_model()
    demo.launch(server_name="0.0.0.0", server_port=7860)