# Fetal / app.py — Hugging Face Space by AliMusaRizvi
# (last commit: "Update app.py", 44a888e, verified)
import io
import os
import math
import warnings
import traceback
import cv2
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib
matplotlib.use("Agg") # headless backend, must be set before pyplot import
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import gradio as gr
from PIL import Image
from torchvision.transforms import v2
from huggingface_hub import hf_hub_download
from model_arch import (
FetalMTLModel,
DualFrequencyCascadeGraph,
InferenceConfig,
load_checkpoint,
compute_ctr,
)
# Silence every Python warning process-wide.
warnings.filterwarnings("ignore")

# ── Runtime device ────────────────────────────────────────────────────────────
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ── Checkpoint source — each value can be overridden via a Space env var ──────
MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "AliMusaRizvi/planeclassifier")
CHECKPOINT_FILE = os.getenv("CHECKPOINT_FILE", "best_ema_4heads_20260331_0355.pth")
HF_TOKEN = os.getenv("HF_TOKEN")  # only needed when the model repo is private

# ── ImageNet normalisation statistics, shaped (1, 3, 1, 1) so they broadcast
#    over the channel dimension of a batched image tensor ─────────────────────
_MEAN, _STD = (
    torch.tensor(stats, device=DEVICE).reshape(1, 3, 1, 1)
    for stats in ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
)

# ── Segmentation colour palette (RGB), keyed by predicted class id ───────────
SEG_PALETTE = {
    0: (0, 0, 0),       # background — skipped when drawing overlays
    1: (220, 50, 50),   # cardiac
    2: (50, 100, 220),  # thorax
}

# ── Global model state, filled in by _load_model() ───────────────────────────
_model = _gpu_diff = _cls_names = _brain_names = None
_load_error = None  # formatted traceback if loading failed
# ── Pre-processing ────────────────────────────────────────────────────────────
def _clahe(pil_img: Image.Image) -> Image.Image:
    """Contrast-enhance an image with CLAHE applied to its lightness channel only.

    The image is converted to RGB and then to L*a*b*. The L channel is
    equalised (clip limit 2.0, 8x8 tiles) while a/b are left untouched. The
    result is returned as a new RGB PIL image.
    """
    equaliser = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    rgb = np.array(pil_img.convert("RGB"), dtype=np.uint8)
    lightness, chan_a, chan_b = cv2.split(cv2.cvtColor(rgb, cv2.COLOR_RGB2LAB))
    lab_eq = cv2.merge((equaliser.apply(lightness), chan_a, chan_b))
    return Image.fromarray(cv2.cvtColor(lab_eq, cv2.COLOR_LAB2RGB))
def _preprocess(pil_img: Image.Image, size: int) -> torch.Tensor:
    """Convert a PIL image into a normalised, batched model input on DEVICE.

    Pipeline: CLAHE -> resize to (size, size) -> float in [0, 1] ->
    DualFrequencyCascadeGraph (``_gpu_diff``) -> ImageNet mean/std
    normalisation -> clamp to [-6, 6].

    Requires ``_load_model()`` to have populated ``_gpu_diff``.
    """
    enhanced = _clahe(pil_img.convert("RGB"))
    to_tensor = v2.Compose([
        v2.ToImage(),
        v2.Resize((size, size), antialias=True),
        v2.ToDtype(torch.float32, scale=True),
    ])
    batch = to_tensor(enhanced).unsqueeze(0).to(DEVICE)  # add a batch dim of 1
    with torch.no_grad():
        batch = _gpu_diff(batch)
    normalised = (batch - _MEAN) / _STD
    return normalised.clamp(-6.0, 6.0)
# ── Lazy model loader ─────────────────────────────────────────────────────────
def _load_model():
    """Download the checkpoint and build the model + frequency pre-processor once.

    Idempotent: returns immediately if an earlier call already succeeded
    (``_model`` is set) or failed (``_load_error`` is set). On success it
    populates ``_model``, ``_gpu_diff``, ``_cls_names`` and ``_brain_names``.
    On any exception it stores the formatted traceback in ``_load_error``
    instead of raising, so the UI can report it.
    """
    global _model, _gpu_diff, _cls_names, _brain_names, _load_error
    if _model is not None or _load_error is not None:
        return
    try:
        print(f"[startup] Downloading checkpoint from {MODEL_REPO_ID}/{CHECKPOINT_FILE} …")
        ckpt_path = hf_hub_download(
            repo_id=MODEL_REPO_ID,
            filename=CHECKPOINT_FILE,
            token=HF_TOKEN,
        )
        # Peek at the stored class lists to size the heads.
        # SECURITY: weights_only=False unpickles arbitrary objects, and
        # MODEL_REPO_ID is env-overridable — only point it at a trusted repo.
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        num_cls = len(ckpt.get("class_names") or []) or 11
        num_brain = len(ckpt.get("brain_class_names") or []) or 7
        # load_checkpoint is handed the path (so presumably re-reads the file);
        # drop this copy first so two full checkpoints are not resident at once.
        del ckpt
        print(f"[startup] Building FetalMTLModel (cls={num_cls}, brain={num_brain}) …")
        model, cls_names, brain_names = load_checkpoint(
            ckpt_path, num_cls=num_cls, num_brain=num_brain, device=DEVICE)
        # NOTE(review): built with learnable=True, but no weights are loaded
        # into this module here — confirm that fresh initialisation is intended.
        gpu_diff = DualFrequencyCascadeGraph(
            low_freq_ratio=0.3, patch_size=16, n_graph_iter=3, learnable=True
        ).to(DEVICE).eval()
    except Exception:
        _load_error = traceback.format_exc()
        print(f"[startup] FAILED to load model:\n{_load_error}")
    else:
        # Publish only once every piece was built. _model goes last because
        # it is the "ready" flag the request handlers check.
        _gpu_diff = gpu_diff
        _cls_names, _brain_names = cls_names, brain_names
        _model = model
        print(f"[startup] Model ready on {DEVICE}")
        print(f"[startup] Plane classes : {_cls_names}")
        print(f"[startup] Brain classes : {_brain_names}")
# ── Visualisation helpers ─────────────────────────────────────────────────────
def _bar_chart(class_names, probs, title="") -> Image.Image:
    """Render a horizontal confidence bar chart and return it as a PIL image.

    Args:
        class_names: Bar labels, one per entry in ``probs``.
        probs: Confidence values (the x-axis spans 0 to 1.12). The highest
            bar(s) are drawn green and the rest blue, each annotated with its
            value to 3 decimals.
        title: Optional chart title; omitted when empty.

    Returns:
        The chart rasterised as a PNG-decoded PIL image at 100 dpi.
    """
    height = max(3.0, len(class_names) * 0.45)  # ~0.45 in per bar, at least 3 in
    fig, ax = plt.subplots(figsize=(7, height))
    try:
        peak = max(probs, default=0.0)  # default keeps an empty chart from crashing
        colors = ["#4CAF50" if p == peak else "#2196F3" for p in probs]
        bars = ax.barh(class_names, probs, color=colors, edgecolor="k", linewidth=0.4)
        for bar, p in zip(bars, probs):
            ax.text(bar.get_width() + 0.01, bar.get_y() + bar.get_height() / 2,
                    f"{p:.3f}", va="center", fontsize=8)
        ax.set_xlim(0, 1.12)  # headroom to the right of 1.0 for the value labels
        ax.set_xlabel("Confidence")
        if title:
            ax.set_title(title, fontsize=10)
        ax.grid(axis="x", alpha=0.3)
        # Use the figure's own method, not plt.tight_layout(): pyplot's
        # "current figure" is global state, and handlers can run concurrently.
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=100, bbox_inches="tight")
    finally:
        # Always deregister the figure from pyplot, even if drawing failed.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf).copy()
def _seg_overlay(orig_np: np.ndarray, mask_np: np.ndarray, alpha=0.45, palette=None) -> np.ndarray:
    """Alpha-blend class colours onto an RGB image wherever the mask is non-zero.

    Args:
        orig_np: (H, W, 3) uint8 RGB image. It is not modified.
        mask_np: (H, W) integer class map. Pixels of class 0 (background)
            keep their original colour.
        alpha: Weight of the class colour (0 = invisible, 1 = opaque).
        palette: Mapping of class id -> (R, G, B). Defaults to ``SEG_PALETTE``.

    Returns:
        A new uint8 array with the same shape as ``orig_np``. Blended values
        are truncated, not rounded.

    Raises:
        ValueError: if ``mask_np`` does not match the image's height and width.
    """
    if palette is None:
        palette = SEG_PALETTE
    if mask_np.shape != orig_np.shape[:2]:
        raise ValueError(
            f"mask shape {mask_np.shape} does not match image shape {orig_np.shape[:2]}"
        )
    overlay = orig_np.copy()
    for cls_id, rgb in palette.items():
        if cls_id == 0:  # background stays untouched
            continue
        region = mask_np == cls_id
        overlay[region] = (
            alpha * np.asarray(rgb, dtype=np.float32)
            + (1 - alpha) * orig_np[region].astype(np.float32)
        ).astype(np.uint8)
    return overlay
def _pil_figure(fig, dpi=120) -> Image.Image:
    """Rasterise a Matplotlib figure to a PIL image and close the figure.

    Args:
        fig: The figure to render. It is always closed, even if saving
            fails, so pyplot does not accumulate open figures across requests.
        dpi: Output resolution (defaults to 120).

    Returns:
        A PNG-decoded PIL image detached from the in-memory buffer.
    """
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format="png", dpi=dpi, bbox_inches="tight")
    finally:
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf).copy()
def _guard():
    """Return a Markdown error message if the model is unavailable, else None.

    If no load has been attempted yet — e.g. the module was imported rather
    than run as ``__main__``, so the eager ``_load_model()`` call never ran —
    the load is triggered here. A load failure is reported using the *end* of
    the stored traceback, which is where the actual exception message is.
    """
    if _model is None and _load_error is None:
        _load_model()
    if _load_error:
        return (
            "**Model failed to load.**\n\n"
            "Set the `MODEL_REPO_ID` and `CHECKPOINT_FILE` Space secrets.\n\n"
            f"```\n{_load_error[-600:]}\n```"
        )
    if _model is None:
        return "**Model is still loading — please wait a moment.**"
    return None
# ── Head 1: Plane Classification ─────────────────────────────────────────────
def predict_plane(pil_img):
    """Gradio handler for head 1: classify the standard scan plane.

    Returns:
        ``(markdown_text, chart)``. ``chart`` is a PIL bar chart of every
        class probability. It is None when the model is unavailable, no
        image was supplied, or inference failed; the text then explains why.
    """
    if (err := _guard()):
        return err, None
    if pil_img is None:
        return "Upload an image first.", None
    try:
        img_t = _preprocess(pil_img, InferenceConfig.BACKBONE_IMG_SIZE)
        with torch.no_grad():
            out = _model(img_t, task="cls")
        probs = F.softmax(out["logits"], dim=1).squeeze(0).cpu().tolist()
        pred = max(range(len(probs)), key=probs.__getitem__)
        label = _cls_names[pred]
        conf = probs[pred]
        text = f"**Predicted plane:** {label.replace('-', ' ').title()}\n**Confidence:** {conf:.1%}"
        chart = _bar_chart(
            [n.replace("-", " ").title() for n in _cls_names],
            probs,
            "Plane Classification Probabilities",
        )
        return text, chart
    except Exception:
        # Show the tail: format_exc() ends with the exception type and message.
        return f"**Inference error:**\n```\n{traceback.format_exc()[-400:]}\n```", None
# ── Head 2: Heart Segmentation ────────────────────────────────────────────────
def predict_segmentation(pil_img):
    """Gradio handler for head 2: segment cardiac/thoracic regions and report the CTR.

    Returns:
        ``(overlay_image, markdown_text)``. The image is None when the model
        is unavailable, no image was supplied, or inference failed; the text
        then carries the explanation.
    """
    if (err := _guard()):
        return None, err
    if pil_img is None:
        return None, "Upload an image first."
    try:
        # NOTE(review): hard-coded, unlike the other heads, which read their
        # input size from InferenceConfig.
        sz = 384
        img_t = _preprocess(pil_img, sz)
        with torch.no_grad():
            out = _model(img_t, task="seg")
        # Per-pixel class id = argmax over the channel dimension.
        mask = out["seg"].argmax(1).squeeze(0).cpu().numpy()
        orig = np.array(pil_img.convert("RGB").resize((sz, sz)))
        legend = [
            mpatches.Patch(color=tuple(c / 255 for c in SEG_PALETTE[1]), label="Cardiac"),
            mpatches.Patch(color=tuple(c / 255 for c in SEG_PALETTE[2]), label="Thorax"),
        ]
        fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(11, 5))
        try:
            ax0.imshow(orig)
            ax0.set_title("Input")
            ax0.axis("off")
            ax1.imshow(_seg_overlay(orig, mask))
            ax1.legend(handles=legend, loc="lower right", fontsize=8)
            ax1.set_title("Segmentation")
            ax1.axis("off")
            seg_img = _pil_figure(fig)
        finally:
            # Release the figure even if drawing fails; closing twice is a no-op.
            plt.close(fig)
        ctr = compute_ctr(mask)
        ctr_lines = []
        if ctr["CTR_area"] is not None:
            ctr_lines.append(f"**CTR (area-based):** {ctr['CTR_area']:.4f}")
            # Guard separately so a missing diameter ratio cannot crash formatting.
            if ctr["CTR_diam"] is not None:
                ctr_lines.append(f"**CTR (diameter):** {ctr['CTR_diam']:.4f}")
        ctr_lines.append(f"**Assessment:** {ctr['flag']}")
        return seg_img, "\n".join(ctr_lines)
    except Exception:
        # Show the tail: format_exc() ends with the exception type and message.
        return None, f"**Inference error:**\n```\n{traceback.format_exc()[-400:]}\n```"
# ── Head 3: Down Syndrome Markers ─────────────────────────────────────────────
def predict_ds(pil_img):
    """Gradio handler for head 3: Down-syndrome markers (NT thickness, nasal bone).

    Returns:
        ``(heatmap_figure, markdown_text)``. The figure is None when the
        model is unavailable, no image was supplied, or inference failed; the
        text then carries the explanation.
    """
    if (err := _guard()):
        return None, err
    if pil_img is None:
        return None, "Upload an image first."
    try:
        sz = InferenceConfig.DS_IMG_SIZE
        img_t = _preprocess(pil_img, sz)
        # The UI has no NB region input, so an all-zero bbox is passed. Per the
        # original note, the model then falls back to global average features.
        nb_bbox = torch.zeros(1, 4, device=DEVICE)
        with torch.no_grad():
            out = _model(img_t, task="ds", nb_bboxes=nb_bbox)
        pred = out["ds"]
        nt_mm = float(pred["nt_thickness_mm"].squeeze().item())
        nb_prob = (
            float(torch.sigmoid(pred["nb_logit"]).squeeze().item())
            if pred["nb_logit"] is not None else None
        )
        hm_size = InferenceConfig.NT_HEATMAP_SZ
        hm_np = pred["nt_heatmaps"].squeeze(0).cpu().numpy()
        # Sum the first two keypoint heatmaps (plotted as top + bottom KP),
        # then min-max scale to [0, 1] for display.
        hm_vis = hm_np[:2].sum(axis=0)
        hm_vis = (hm_vis - hm_vis.min()) / (hm_vis.max() - hm_vis.min() + 1e-8)
        orig_sm = np.array(pil_img.convert("RGB").resize((hm_size, hm_size)))
        fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(10, 4))
        try:
            ax0.imshow(orig_sm)
            ax0.set_title(f"Input ({hm_size}×{hm_size})")
            ax0.axis("off")
            ax1.imshow(orig_sm)
            ax1.imshow(hm_vis, cmap="hot", alpha=0.55)
            ax1.set_title("NT Keypoint Heatmap (top+bottom KP)")
            ax1.axis("off")
            vis_img = _pil_figure(fig)
        finally:
            # Release the figure even if drawing fails; closing twice is a no-op.
            plt.close(fig)
        nt_cutoff_mm = 3.5  # high-risk NT threshold shown in the UI text
        if nt_mm >= nt_cutoff_mm:
            risk_str = f"⚠ High-risk (NT ≥ {nt_cutoff_mm} mm)"
        else:
            risk_str = f"Normal range (NT < {nt_cutoff_mm} mm)"
        nb_str = (f"**NB present probability:** {nb_prob:.1%}"
                  if nb_prob is not None
                  else "**NB:** N/A (provide NT bounding box for accurate NB prediction)")
        text = (
            f"**NT Thickness:** {nt_mm:.2f} mm — {risk_str}\n\n"
            f"{nb_str}\n\n"
            "*Note: NB accuracy is limited without a NT bounding-box annotation.*"
        )
        return vis_img, text
    except Exception:
        # Show the tail: format_exc() ends with the exception type and message.
        return None, f"**Inference error:**\n```\n{traceback.format_exc()[-400:]}\n```"
# ── Head 4: Brain Anomaly Classification ──────────────────────────────────────
def predict_brain(pil_img):
    """Gradio handler for head 4: classify fetal brain anomalies.

    Returns:
        ``(markdown_text, chart)``. ``chart`` is a PIL bar chart of every
        class probability. It is None when the model is unavailable, no
        image was supplied, or inference failed; the text then explains why.
    """
    if (err := _guard()):
        return err, None
    if pil_img is None:
        return "Upload an image first.", None
    try:
        img_t = _preprocess(pil_img, InferenceConfig.BACKBONE_IMG_SIZE)
        with torch.no_grad():
            out = _model(img_t, task="brain")
        probs = F.softmax(out["brain_logits"], dim=1).squeeze(0).cpu().tolist()
        pred = max(range(len(probs)), key=probs.__getitem__)
        label = _brain_names[pred]
        conf = probs[pred]
        text = (
            f"**Predicted condition:** {label.replace('-', ' ').title()}\n"
            f"**Confidence:** {conf:.1%}"
        )
        chart = _bar_chart(
            [n.replace("-", " ").title() for n in _brain_names],
            probs,
            "Brain Anomaly Probabilities",
        )
        return text, chart
    except Exception:
        # Show the tail: format_exc() ends with the exception type and message.
        return f"**Inference error:**\n```\n{traceback.format_exc()[-400:]}\n```", None
# ── Gradio UI ─────────────────────────────────────────────────────────────────
# Shown at the top of every page.
_DISCLAIMER = (
    "> ⚠ **Research demo only** — not validated for clinical use. "
    "Do not use to guide medical decisions. Consult a qualified clinician."
)

# Build the UI: one tab per diagnostic head. Each tab has an image input,
# an output column and a button wired to the matching predict_* handler.
with gr.Blocks(title="Fetal Ultrasound MTL", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🫁 Fetal Ultrasound Multi-Task Analysis")
    gr.Markdown(
        "**Backbone:** ViL2-small + DynamicGraphTransformer (GNN) | "
        "**4 diagnostic heads** trained jointly"
    )
    gr.Markdown(_DISCLAIMER)
    with gr.Tabs():
        # ── Head 1 ────────────────────────────────────────────────────────
        with gr.TabItem(" Plane Classification"):
            gr.Markdown(
                "Classify the fetal US image into one of the 11 standard planes "
                "(brain-cb, brain-tv, abdominal, femur, etc.)."
            )
            with gr.Row():
                h1_img = gr.Image(type="pil", label="Upload ultrasound image")
                with gr.Column():
                    h1_text = gr.Markdown()
                    h1_chart = gr.Image(type="pil", label="Class probabilities")
            h1_btn = gr.Button("Classify plane", variant="primary")
            h1_btn.click(predict_plane, inputs=h1_img, outputs=[h1_text, h1_chart])
        # ── Head 2 ────────────────────────────────────────────────────────
        with gr.TabItem(" Heart Segmentation"):
            gr.Markdown(
                "Segment cardiac and thoracic structures from the four-chamber view "
                "and compute the Cardiothoracic Ratio (CTR)."
            )
            with gr.Row():
                h2_img = gr.Image(type="pil", label="Upload cardiac view image")
                with gr.Column():
                    h2_seg = gr.Image(type="pil", label="Segmentation overlay")
                    h2_ctr = gr.Markdown()
            h2_btn = gr.Button("Segment", variant="primary")
            h2_btn.click(predict_segmentation, inputs=h2_img, outputs=[h2_seg, h2_ctr])
        # ── Head 3 ────────────────────────────────────────────────────────
        with gr.TabItem(" Down Syndrome Markers"):
            gr.Markdown(
                "Estimate Nuchal Translucency (NT) thickness and Nasal Bone (NB) presence "
                "from the sagittal facial/neck view. "
                "NT ≥ 3.5 mm is a high-risk threshold."
            )
            with gr.Row():
                h3_img = gr.Image(type="pil", label="Upload sagittal view")
                with gr.Column():
                    h3_vis = gr.Image(type="pil", label="NT heatmap overlay")
                    h3_text = gr.Markdown()
            h3_btn = gr.Button("Analyse", variant="primary")
            h3_btn.click(predict_ds, inputs=h3_img, outputs=[h3_vis, h3_text])
        # ── Head 4 ────────────────────────────────────────────────────────
        with gr.TabItem(" Brain Anomaly"):
            gr.Markdown(
                "Classify fetal brain ultrasound into normal or one of several anomaly "
                "categories (ventriculomegaly, holoprosencephaly, arachnoid cyst, etc.)."
            )
            with gr.Row():
                h4_img = gr.Image(type="pil", label="Upload brain ultrasound")
                with gr.Column():
                    h4_text = gr.Markdown()
                    h4_chart = gr.Image(type="pil", label="Class probabilities")
            h4_btn = gr.Button("Classify", variant="primary")
            h4_btn.click(predict_brain, inputs=h4_img, outputs=[h4_text, h4_chart])
    gr.Markdown(
        "---\n"
        "**Model:** ViL2-small + DynamicGraphTransformer | "
        "**Training platform:** Google Colab A100 | "
        "**Framework:** PyTorch"
    )
if __name__ == "__main__":
    # Load eagerly, before serving, so the first request does not pay the
    # checkpoint download/build cost.
    _load_model()
    demo.launch(
        server_name="0.0.0.0",  # bind all interfaces
        server_port=7860,
    )