Spaces:
Sleeping
Sleeping
| import io | |
| import os | |
| import math | |
| import warnings | |
| import traceback | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| import matplotlib | |
| matplotlib.use("Agg") # headless backend, must be set before pyplot import | |
| import matplotlib.pyplot as plt | |
| import matplotlib.patches as mpatches | |
| import gradio as gr | |
| from PIL import Image | |
| from torchvision.transforms import v2 | |
| from huggingface_hub import hf_hub_download | |
| from model_arch import ( | |
| FetalMTLModel, | |
| DualFrequencyCascadeGraph, | |
| InferenceConfig, | |
| load_checkpoint, | |
| compute_ctr, | |
| ) | |
# NOTE(review): this mutes every warning process-wide, not only the noisy
# third-party ones — narrow the filter if that is not intended.
warnings.filterwarnings("ignore")

# ── Runtime device ────────────────────────────────────────────────────────────
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ── Checkpoint source — override via Space env vars ───────────────────────────
MODEL_REPO_ID = os.environ.get("MODEL_REPO_ID", "AliMusaRizvi/planeclassifier")
CHECKPOINT_FILE = os.environ.get("CHECKPOINT_FILE", "best_ema_4heads_20260331_0355.pth")
# Needed for private repos.  A blank secret is treated as "no token" so an
# empty string is never handed to the hub client as a credential.
HF_TOKEN = os.environ.get("HF_TOKEN") or None

# ── ImageNet normalisation tensors ────────────────────────────────────────────
# Shaped (1, 3, 1, 1) so they broadcast over the batched image in _preprocess.
_MEAN = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).view(1, 3, 1, 1)
_STD = torch.tensor([0.229, 0.224, 0.225], device=DEVICE).view(1, 3, 1, 1)

# ── Segmentation colour palette (RGB) ─────────────────────────────────────────
# Mask class id -> colour.  Ids 1 and 2 are labelled "Cardiac" and "Thorax" in
# the segmentation legend; id 0 is never painted by _seg_overlay.
SEG_PALETTE = {0: (0, 0, 0), 1: (220, 50, 50), 2: (50, 100, 220)}

# ── Global model state ────────────────────────────────────────────────────────
# Filled in by _load_model().  _load_error holds the formatted traceback when
# loading failed and stays None otherwise.
_model = None
_gpu_diff = None
_cls_names = None
_brain_names = None
_load_error = None
# ── Pre-processing ────────────────────────────────────────────────────────────
def _clahe(pil_img: Image.Image) -> Image.Image:
    """Return a contrast-enhanced RGB copy of ``pil_img``.

    The image is converted to RGB, moved to LAB colour space, equalised with
    CLAHE (clip limit 2.0, 8x8 tiles) on the lightness channel only, and
    converted back to RGB.  The two colour channels pass through untouched.
    """
    rgb = np.array(pil_img.convert("RGB"), dtype=np.uint8)
    lightness, chan_a, chan_b = cv2.split(cv2.cvtColor(rgb, cv2.COLOR_RGB2LAB))
    equaliser = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced_lab = cv2.merge((equaliser.apply(lightness), chan_a, chan_b))
    return Image.fromarray(cv2.cvtColor(enhanced_lab, cv2.COLOR_LAB2RGB))
def _preprocess(pil_img: Image.Image, size: int) -> torch.Tensor:
    """CLAHE → resize → DualFrequencyGraph → ImageNet-norm → clamp.

    Args:
        pil_img: Input image; converted to RGB before processing.
        size: Side length of the square the image is resized to.

    Returns:
        A float tensor on ``DEVICE`` with a leading batch dimension of 1,
        normalised with ``_MEAN``/``_STD`` and clamped to [-6, 6].

    Raises:
        RuntimeError: If the frequency-graph module has not been created yet
            (i.e. the model has not been loaded).
    """
    if _gpu_diff is None:
        # Fail with a readable message instead of "'NoneType' is not callable".
        raise RuntimeError("Model is not loaded; cannot preprocess the image.")

    enhanced = _clahe(pil_img.convert("RGB"))
    to_tensor = v2.Compose([
        v2.ToImage(),
        v2.Resize((size, size), antialias=True),
        v2.ToDtype(torch.float32, scale=True),  # uint8 [0, 255] -> float [0, 1]
    ])
    batch = to_tensor(enhanced).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        batch = _gpu_diff(batch)
    batch = (batch - _MEAN) / _STD
    return batch.clamp(-6.0, 6.0)
# ── Lazy model loader ─────────────────────────────────────────────────────────
def _load_model():
    """Download the checkpoint and populate the module-level model state.

    The call is idempotent: it returns immediately once either a model or a
    load error has been recorded.  Failures never propagate; the formatted
    traceback is stored in ``_load_error`` and printed instead.
    """
    global _model, _gpu_diff, _cls_names, _brain_names, _load_error, _MEAN, _STD
    if _model is not None or _load_error is not None:
        return
    try:
        print(f"[startup] Downloading checkpoint from {MODEL_REPO_ID}/{CHECKPOINT_FILE} …")
        ckpt_path = hf_hub_download(
            repo_id=MODEL_REPO_ID,
            filename=CHECKPOINT_FILE,
            token=HF_TOKEN,
        )
        # SECURITY: weights_only=False unpickles arbitrary Python objects, so
        # MODEL_REPO_ID must only ever point at a trusted repository.
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        # Head sizes come from the class-name lists stored in the checkpoint,
        # falling back to 11 plane / 7 brain classes when they are absent.
        num_cls = len(ckpt.get("class_names") or []) or 11
        num_brain = len(ckpt.get("brain_class_names") or []) or 7
        # load_checkpoint() reads the file itself; drop this copy first so the
        # weights are not held in memory twice.
        del ckpt

        print(f"[startup] Building FetalMTLModel (cls={num_cls}, brain={num_brain}) …")
        model, cls_names, brain_names = load_checkpoint(
            ckpt_path, num_cls=num_cls, num_brain=num_brain, device=DEVICE)
        # NOTE(review): this module is built with learnable=True but nothing
        # here loads weights into it — confirm it is meant to run with its
        # initial parameters at inference time.
        gpu_diff = DualFrequencyCascadeGraph(
            low_freq_ratio=0.3, patch_size=16, n_graph_iter=3, learnable=True
        ).to(DEVICE).eval()

        # Move norm tensors to DEVICE after potential CUDA init
        _MEAN = _MEAN.to(DEVICE)
        _STD = _STD.to(DEVICE)

        # Publish only after every step succeeded, and _model last, so no
        # reader can observe a model without its companion state.
        _gpu_diff = gpu_diff
        _cls_names = cls_names
        _brain_names = brain_names
        _model = model

        print(f"[startup] Model ready on {DEVICE}")
        print(f"[startup] Plane classes : {_cls_names}")
        print(f"[startup] Brain classes : {_brain_names}")
    except Exception:
        _load_error = traceback.format_exc()
        print(f"[startup] FAILED to load model:\n{_load_error}")
# ── Visualisation helpers ─────────────────────────────────────────────────────
def _bar_chart(class_names, probs, title="") -> Image.Image:
    """Render class probabilities as a horizontal bar chart.

    Args:
        class_names: Bar labels, one per entry of ``probs``.
        probs: Values plotted on a 0–1 "Confidence" axis.
        title: Optional chart title; omitted when empty.

    Returns:
        The chart as an in-memory PIL image.  Bars equal to the maximum value
        are drawn green, all others blue.
    """
    # Grow the figure with the number of classes so labels do not overlap.
    fig, ax = plt.subplots(figsize=(7, max(3.0, len(class_names) * 0.45)))
    try:
        peak = max(probs, default=None)  # empty input: nothing to highlight
        colors = ["#4CAF50" if p == peak else "#2196F3" for p in probs]
        bars = ax.barh(class_names, probs, color=colors, edgecolor="k", linewidth=0.4)
        for bar, p in zip(bars, probs):
            ax.text(bar.get_width() + 0.01, bar.get_y() + bar.get_height() / 2,
                    f"{p:.3f}", va="center", fontsize=8)
        ax.set_xlim(0, 1.12)  # head-room for the value labels right of the bars
        ax.set_xlabel("Confidence")
        if title:
            ax.set_title(title, fontsize=10)
        ax.grid(axis="x", alpha=0.3)
        # Lay out this figure explicitly; plt.tight_layout() would act on
        # pyplot's global "current figure", which another request may own.
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=100, bbox_inches="tight")
    finally:
        # pyplot keeps a reference to every open figure; always release it,
        # even when plotting fails, so a long-running server does not leak.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf).copy()
def _seg_overlay(orig_np: np.ndarray, mask_np: np.ndarray, alpha=0.45,
                 palette=None) -> np.ndarray:
    """Alpha-blend class colours onto a copy of an image.

    Args:
        orig_np: Image array indexed as ``[row, col, channel]``; not modified.
        mask_np: Integer class-id mask whose shape matches the image's first
            two dimensions.
        alpha: Weight of the class colour; ``1 - alpha`` goes to the image.
        palette: Mapping of class id to an RGB tuple.  Defaults to the
            module-level ``SEG_PALETTE``.  Class id 0 is never painted.

    Returns:
        A new array with the same shape and dtype handling as the original
        overlay: blended pixels are written back as ``uint8``.
    """
    if palette is None:
        palette = SEG_PALETTE
    overlay = orig_np.copy()
    for cls_id, rgb in palette.items():
        if cls_id == 0:  # background stays as-is
            continue
        region = mask_np == cls_id
        blended = (
            alpha * np.asarray(rgb, dtype=np.float32)
            + (1 - alpha) * orig_np[region].astype(np.float32)
        )
        # Clip before the cast: an alpha outside [0, 1] would otherwise wrap
        # around in uint8 instead of saturating.
        overlay[region] = np.clip(blended, 0, 255).astype(np.uint8)
    return overlay
def _pil_figure(fig) -> Image.Image:
    """Rasterise a Matplotlib figure to a PIL image and close the figure.

    The figure is saved as a 120-dpi PNG into memory.  It is closed whether or
    not saving succeeds, so a failed render does not leave it registered with
    pyplot.
    """
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format="png", dpi=120, bbox_inches="tight")
    finally:
        plt.close(fig)
    buf.seek(0)
    # copy() forces the pixel data to be read before the buffer goes away.
    return Image.open(buf).copy()
def _guard():
    """Return a Markdown error message if the model cannot serve, else None.

    If no load has been attempted yet, this triggers ``_load_model()`` first.
    The loader is otherwise only invoked from the ``__main__`` guard, so an
    imported module would report "still loading" forever.
    """
    if _model is None and _load_error is None:
        # _load_model() is idempotent and records failures in _load_error.
        _load_model()
    if _load_error:
        return (
            "**Model failed to load.**\n\n"
            "Set the `MODEL_REPO_ID` and `CHECKPOINT_FILE` Space secrets.\n\n"
            # Show the tail of the traceback: that is where the exception type
            # and message are; the head is only the outermost call frames.
            f"```\n{_load_error[-600:]}\n```"
        )
    if _model is None:
        return "**Model is still loading — please wait a moment.**"
    return None
# ── Head 1: Plane Classification ──────────────────────────────────────────────
def predict_plane(pil_img):
    """Classify the imaging plane of an uploaded ultrasound image.

    Args:
        pil_img: PIL image from the Gradio input, or None if nothing was
            uploaded.

    Returns:
        ``(markdown_text, chart_image)``.  On any problem the text carries the
        message and the chart is None; exceptions are never propagated.
    """
    if (err := _guard()):
        return err, None
    if pil_img is None:
        return "Upload an image first.", None
    try:
        img_t = _preprocess(pil_img, InferenceConfig.BACKBONE_IMG_SIZE)
        with torch.no_grad():
            out = _model(img_t, task="cls")
        probs = F.softmax(out["logits"], dim=1).squeeze(0).cpu().tolist()
        pred = max(range(len(probs)), key=probs.__getitem__)  # argmax
        display_names = [name.replace("-", " ").title() for name in _cls_names]
        # Blank line between the two fields: gr.Markdown ignores single
        # newlines, which would run both onto one line.
        text = (
            f"**Predicted plane:** {display_names[pred]}\n\n"
            f"**Confidence:** {probs[pred]:.1%}"
        )
        chart = _bar_chart(display_names, probs, "Plane Classification Probabilities")
        return text, chart
    except Exception:
        # Keep the tail of the traceback — it holds the exception message.
        return f"**Inference error:**\n```\n{traceback.format_exc()[-400:]}\n```", None
# ── Head 2: Heart Segmentation ────────────────────────────────────────────────
def predict_segmentation(pil_img):
    """Segment an uploaded image and report the cardiothoracic ratio.

    Args:
        pil_img: PIL image from the Gradio input, or None if nothing was
            uploaded.

    Returns:
        ``(figure_image, markdown_text)``.  The figure shows the input next to
        the colour overlay.  On any problem the image is None and the text
        carries the message; exceptions are never propagated.
    """
    if (err := _guard()):
        return None, err
    if pil_img is None:
        return None, "Upload an image first."
    try:
        sz = 384
        img_t = _preprocess(pil_img, sz)
        with torch.no_grad():
            out = _model(img_t, task="seg")
        mask = out["seg"].argmax(1).squeeze(0).cpu().numpy()

        # Size the display image from the mask itself so the overlay still
        # lines up if the head emits a resolution other than the input size.
        mask_h, mask_w = mask.shape[-2:]
        orig = np.array(pil_img.convert("RGB").resize((mask_w, mask_h)))

        legend = [
            mpatches.Patch(color=tuple(c / 255 for c in SEG_PALETTE[1]), label="Cardiac"),
            mpatches.Patch(color=tuple(c / 255 for c in SEG_PALETTE[2]), label="Thorax"),
        ]
        fig, (ax_in, ax_seg) = plt.subplots(1, 2, figsize=(11, 5))
        try:
            ax_in.imshow(orig)
            ax_in.set_title("Input")
            ax_in.axis("off")
            ax_seg.imshow(_seg_overlay(orig, mask))
            ax_seg.legend(handles=legend, loc="lower right", fontsize=8)
            ax_seg.set_title("Segmentation")
            ax_seg.axis("off")
            seg_img = _pil_figure(fig)
        finally:
            # Releases the figure if plotting failed part-way; a no-op when
            # _pil_figure has already closed it.
            plt.close(fig)

        ctr = compute_ctr(mask)
        if ctr["CTR_area"] is not None:
            # Blank lines between fields: gr.Markdown ignores single newlines.
            ctr_text = (
                f"**CTR (area-based):** {ctr['CTR_area']:.4f}\n\n"
                f"**CTR (diameter):** {ctr['CTR_diam']:.4f}\n\n"
                f"**Assessment:** {ctr['flag']}"
            )
        else:
            ctr_text = f"**Assessment:** {ctr['flag']}"
        return seg_img, ctr_text
    except Exception:
        # Keep the tail of the traceback — it holds the exception message.
        return None, f"**Inference error:**\n```\n{traceback.format_exc()[-400:]}\n```"
# ── Head 3: Down Syndrome Markers ─────────────────────────────────────────────
def predict_ds(pil_img):
    """Estimate NT thickness and nasal-bone presence for an uploaded image.

    Args:
        pil_img: PIL image from the Gradio input, or None if nothing was
            uploaded.

    Returns:
        ``(figure_image, markdown_text)``.  The figure shows the down-sized
        input next to the summed keypoint heatmap.  On any problem the image
        is None and the text carries the message; exceptions are never
        propagated.
    """
    if (err := _guard()):
        return None, err
    if pil_img is None:
        return None, "Upload an image first."
    try:
        sz = InferenceConfig.DS_IMG_SIZE  # 384
        img_t = _preprocess(pil_img, sz)
        # Zero bbox — model falls back to global average features for NB ROI
        nb_bbox = torch.zeros(1, 4, device=DEVICE)
        with torch.no_grad():
            out = _model(img_t, task="ds", nb_bboxes=nb_bbox)
        pred = out["ds"]

        nt_mm = float(pred["nt_thickness_mm"].squeeze().item())
        nb_prob = (
            float(torch.sigmoid(pred["nb_logit"]).squeeze().item())
            if pred["nb_logit"] is not None else None
        )

        Hm = InferenceConfig.NT_HEATMAP_SZ
        hm_np = pred["nt_heatmaps"].squeeze(0).cpu().numpy()
        # Sum the first two heatmap channels and min-max scale to [0, 1];
        # the epsilon guards against a constant (flat) heatmap.
        hm_vis = hm_np[0] + hm_np[1]
        hm_vis = (hm_vis - hm_vis.min()) / (hm_vis.max() - hm_vis.min() + 1e-8)
        orig_sm = np.array(pil_img.convert("RGB").resize((Hm, Hm)))

        fig, (ax_in, ax_hm) = plt.subplots(1, 2, figsize=(10, 4))
        try:
            ax_in.imshow(orig_sm)
            # Report the size actually used rather than a hard-coded figure.
            ax_in.set_title(f"Input ({Hm}×{Hm})")
            ax_in.axis("off")
            ax_hm.imshow(orig_sm)
            ax_hm.imshow(hm_vis, cmap="hot", alpha=0.55)
            ax_hm.set_title("NT Keypoint Heatmap (top+bottom KP)")
            ax_hm.axis("off")
            vis_img = _pil_figure(fig)
        finally:
            # Releases the figure if plotting failed part-way; a no-op when
            # _pil_figure has already closed it.
            plt.close(fig)

        risk_str = ("⚠ High-risk (NT ≥ 3.5 mm)" if nt_mm >= 3.5
                    else "Normal range (NT < 3.5 mm)")
        nb_str = (f"**NB present probability:** {nb_prob:.1%}"
                  if nb_prob is not None
                  else "**NB:** N/A (provide NT bounding box for accurate NB prediction)")
        text = (
            f"**NT Thickness:** {nt_mm:.2f} mm — {risk_str}\n\n"
            f"{nb_str}\n\n"
            f"*Note: NB accuracy is limited without a NT bounding-box annotation.*"
        )
        return vis_img, text
    except Exception:
        # Keep the tail of the traceback — it holds the exception message.
        return None, f"**Inference error:**\n```\n{traceback.format_exc()[-400:]}\n```"
# ── Head 4: Brain Anomaly Classification ──────────────────────────────────────
def predict_brain(pil_img):
    """Classify an uploaded brain ultrasound into a condition category.

    Args:
        pil_img: PIL image from the Gradio input, or None if nothing was
            uploaded.

    Returns:
        ``(markdown_text, chart_image)``.  On any problem the text carries the
        message and the chart is None; exceptions are never propagated.
    """
    if (err := _guard()):
        return err, None
    if pil_img is None:
        return "Upload an image first.", None
    try:
        img_t = _preprocess(pil_img, InferenceConfig.BACKBONE_IMG_SIZE)
        with torch.no_grad():
            out = _model(img_t, task="brain")
        probs = F.softmax(out["brain_logits"], dim=1).squeeze(0).cpu().tolist()
        pred = max(range(len(probs)), key=probs.__getitem__)  # argmax
        display_names = [name.replace("-", " ").title() for name in _brain_names]
        # Blank line between the two fields: gr.Markdown ignores single
        # newlines, which would run both onto one line.
        text = (
            f"**Predicted condition:** {display_names[pred]}\n\n"
            f"**Confidence:** {probs[pred]:.1%}"
        )
        chart = _bar_chart(display_names, probs, "Brain Anomaly Probabilities")
        return text, chart
    except Exception:
        # Keep the tail of the traceback — it holds the exception message.
        return f"**Inference error:**\n```\n{traceback.format_exc()[-400:]}\n```", None
# ── Gradio UI ─────────────────────────────────────────────────────────────────
# Shown at the top of the page, above all tabs.
_DISCLAIMER = (
    "> ⚠ **Research demo only** — not validated for clinical use. "
    "Do not use to guide medical decisions. Consult a qualified clinician."
)

# One tab per model head.  Each tab pairs an image upload with its outputs and
# a button wired to the matching predict_* function.
# NOTE(review): the page title and tab labels carried decorative glyphs that
# were unreadable in this copy of the file; they are omitted here — restore
# them from the original if wanted.
with gr.Blocks(title="Fetal Ultrasound MTL", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Fetal Ultrasound Multi-Task Analysis")
    gr.Markdown(
        "**Backbone:** ViL2-small + DynamicGraphTransformer (GNN) | "
        "**4 diagnostic heads** trained jointly"
    )
    gr.Markdown(_DISCLAIMER)

    with gr.Tabs():
        # ── Head 1 ────────────────────────────────────────────────────────
        with gr.TabItem("Plane Classification"):
            gr.Markdown(
                "Classify the fetal US image into one of the 11 standard planes "
                "(brain-cb, brain-tv, abdominal, femur, etc.)."
            )
            with gr.Row():
                h1_img = gr.Image(type="pil", label="Upload ultrasound image")
                with gr.Column():
                    h1_text = gr.Markdown()
                    h1_chart = gr.Image(type="pil", label="Class probabilities")
            h1_btn = gr.Button("Classify plane", variant="primary")
            h1_btn.click(predict_plane, inputs=h1_img, outputs=[h1_text, h1_chart])

        # ── Head 2 ────────────────────────────────────────────────────────
        with gr.TabItem("Heart Segmentation"):
            gr.Markdown(
                "Segment cardiac and thoracic structures from the four-chamber view "
                "and compute the Cardiothoracic Ratio (CTR)."
            )
            with gr.Row():
                h2_img = gr.Image(type="pil", label="Upload cardiac view image")
                with gr.Column():
                    h2_seg = gr.Image(type="pil", label="Segmentation overlay")
                    h2_ctr = gr.Markdown()
            h2_btn = gr.Button("Segment", variant="primary")
            h2_btn.click(predict_segmentation, inputs=h2_img, outputs=[h2_seg, h2_ctr])

        # ── Head 3 ────────────────────────────────────────────────────────
        with gr.TabItem("Down Syndrome Markers"):
            gr.Markdown(
                "Estimate Nuchal Translucency (NT) thickness and Nasal Bone (NB) presence "
                "from the sagittal facial/neck view. "
                "NT ≥ 3.5 mm is a high-risk threshold."
            )
            with gr.Row():
                h3_img = gr.Image(type="pil", label="Upload sagittal view")
                with gr.Column():
                    h3_vis = gr.Image(type="pil", label="NT heatmap overlay")
                    h3_text = gr.Markdown()
            h3_btn = gr.Button("Analyse", variant="primary")
            h3_btn.click(predict_ds, inputs=h3_img, outputs=[h3_vis, h3_text])

        # ── Head 4 ────────────────────────────────────────────────────────
        with gr.TabItem("Brain Anomaly"):
            gr.Markdown(
                "Classify fetal brain ultrasound into normal or one of several anomaly "
                "categories (ventriculomegaly, holoprosencephaly, arachnoid cyst, etc.)."
            )
            with gr.Row():
                h4_img = gr.Image(type="pil", label="Upload brain ultrasound")
                with gr.Column():
                    h4_text = gr.Markdown()
                    h4_chart = gr.Image(type="pil", label="Class probabilities")
            h4_btn = gr.Button("Classify", variant="primary")
            h4_btn.click(predict_brain, inputs=h4_img, outputs=[h4_text, h4_chart])

    gr.Markdown(
        "---\n"
        "**Model:** ViL2-small + DynamicGraphTransformer | "
        "**Training platform:** Google Colab A100 | "
        "**Framework:** PyTorch"
    )
if __name__ == "__main__":
    # Load before serving so a checkpoint problem shows up in the startup log;
    # _load_model() records failures instead of raising, so the UI still starts.
    _load_model()
    # Listen on all interfaces, port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)