# AutoMICE / app.py — Hugging Face Space by namijiang98
# Commit e4366df: "Patch gradio_client schema bug (TypeError: argument of type bool)"
"""AutoMICE — Hugging Face Spaces demo.
Run a low-resolution Swin UNETR segmentation on a single mouse micro-CT
volume uploaded by the user. The full-resolution pipeline lives in the
official Docker image; this demo trades accuracy for speed on free CPU
hardware.
"""
from __future__ import annotations
import os
import tempfile
from io import BytesIO
from typing import Tuple
import gradio as gr
# ---------------------------------------------------------------------------
# Work around a known crash in `gradio_client.utils.json_schema_to_python_type`:
#     TypeError: argument of type 'bool' is not iterable
# It happens when a JSON schema contains `additionalProperties: true/false`
# (a bool instead of a dict). Hugging Face Spaces touches /api/info on every
# page load, so this would 500 the whole demo. Patch defensively.
# ---------------------------------------------------------------------------
def _non_dict_schema_guard(func):
    """Wrap a gradio_client schema helper so non-dict schemas yield ``"Any"``.

    JSON Schema allows a bare ``true``/``false`` wherever a sub-schema is
    expected, while the wrapped helpers assume a dict (hence the crash
    described above). Dict schemas are forwarded unchanged, together with any
    extra positional/keyword arguments (e.g. ``defs``).

    Args:
        func: The original helper, called as ``func(schema, *args, **kwargs)``.

    Returns:
        A wrapper with ``func``'s metadata (via ``functools.wraps``).
    """
    import functools

    @functools.wraps(func)
    def guarded(schema, *args, **kwargs):
        if not isinstance(schema, dict):
            return "Any"
        return func(schema, *args, **kwargs)

    return guarded


def _patch_gradio_client_schema_bug() -> None:
    """Install ``_non_dict_schema_guard`` on the affected gradio_client helpers.

    Each helper is patched independently. A helper that does not exist in the
    installed gradio_client version is skipped without blocking the other
    patch. Problems are logged, never raised, because the demo must always
    start.
    """
    import logging

    logger = logging.getLogger(__name__)
    try:
        from gradio_client import utils as gc_utils
    except ImportError:  # pragma: no cover - never block app startup
        logger.warning("gradio_client not importable; schema patch skipped.", exc_info=True)
        return
    # Rebinding the module attribute (as the original patch did) assumes the
    # helpers reach each other through module globals, so their internal /
    # recursive calls also go through the guard — confirm on gradio_client upgrades.
    for attr_name in ("get_type", "_json_schema_to_python_type"):
        original = getattr(gc_utils, attr_name, None)
        if not callable(original):
            logger.warning(
                "gradio_client.utils.%s not found; schema patch skipped for it.",
                attr_name,
            )
            continue
        setattr(gc_utils, attr_name, _non_dict_schema_guard(original))


_patch_gradio_client_schema_bug()
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from monai import transforms
from monai.inferers import sliding_window_inference
from monai.networks.nets import SwinUNETR
from PIL import Image
from scipy import ndimage as ndi
# Hugging Face Hub location of the trained weights. Both values can be
# overridden through environment variables.
HF_REPO_ID = os.getenv("AUTOMICE_HF_REPO", "namijiang98/AutoMICE")
HF_MODEL_FILENAME = os.getenv("AUTOMICE_HF_FILENAME", "model.pt")

# Label value -> organ name (label 0 is background).
LABEL_NAMES = dict(
    enumerate(
        (
            "background",
            "bladder",
            "lung",
            "heart",
            "liver",
            "intestine",
            "kidney",
            "spleen",
        )
    )
)

# RGB colour per label value; row i is the colour for LABEL_NAMES[i].
LABEL_COLORS = np.asarray(
    (
        (0, 0, 0),  # background
        (255, 215, 0),  # bladder
        (135, 206, 250),  # lung
        (220, 20, 60),  # heart
        (210, 105, 30),  # liver
        (124, 252, 0),  # intestine
        (186, 85, 211),  # kidney
        (255, 105, 180),  # spleen
    ),
    dtype=np.uint8,
)

# Demo-only preprocessing / inference settings.
DEMO_SPACING = (0.4, 0.4, 0.4)  # target pixdim for Spacingd (units of the image affine)
DEMO_MAX_DIM = 192  # longest spatial side allowed before the volume is downscaled
DEFAULT_OVERLAP = 0.5  # initial value of the sliding-window overlap slider
ROI_SIZE = (96, 96, 96)  # sliding-window patch size

# Lazily populated model cache; None until the model is first loaded.
_MODEL: SwinUNETR | None = None
def _device() -> torch.device:
    """Return the default CUDA device when one is available, otherwise the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
def _load_model() -> SwinUNETR:
    """Return the Swin UNETR model, downloading and building it on first use.

    The model instance is cached in the module-level ``_MODEL`` and reused on
    later calls. Weights are fetched with ``hf_hub_download`` from
    ``HF_REPO_ID`` / ``HF_MODEL_FILENAME``. A checkpoint dict with a
    ``"state_dict"`` entry is unwrapped; anything else is treated as a bare
    state dict. The model is put in eval mode and moved to ``_device()``.
    """
    global _MODEL
    if _MODEL is not None:
        return _MODEL
    weights_path = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_MODEL_FILENAME)
    model = SwinUNETR(
        # The network sees exactly the sliding-window patches, so tie its
        # input size to ROI_SIZE. Tie its output channels to the label table
        # so the two cannot drift apart (same values as before: 96 and 8).
        img_size=ROI_SIZE,
        in_channels=1,
        out_channels=len(LABEL_NAMES),
        feature_size=36,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        dropout_path_rate=0.0,
        use_checkpoint=False,
    )
    # SECURITY NOTE(review): torch.load unpickles the file. Only point
    # AUTOMICE_HF_REPO / AUTOMICE_HF_FILENAME at trusted repositories.
    # weights_only=True would be safer, but it may reject checkpoints that
    # pickle non-tensor objects; verify the published checkpoint before enabling it.
    ckpt = torch.load(weights_path, map_location="cpu")
    state_dict = ckpt["state_dict"] if isinstance(ckpt, dict) and "state_dict" in ckpt else ckpt
    model.load_state_dict(state_dict)
    model.eval()
    model.to(_device())
    # Publish to the cache only once the model is fully initialised.
    _MODEL = model
    return model
def _build_transform():
    """Return the MONAI dictionary pipeline applied to ``{"image": path}`` inputs.

    The steps run in this order:
    1. Read the file with the nibabel reader.
    2. Prepend a channel axis.
    3. Resample to ``DEMO_SPACING`` with bilinear interpolation.
    4. Linearly map intensities in [-1000, 5000] onto [0, 1], clipping values outside.
    5. Convert to a tensor.
    """
    key = ("image",)
    # NOTE(review): AddChanneld is deprecated in later MONAI releases
    # (EnsureChannelFirstd replaces it) — confirm against the pinned version.
    pipeline = [
        transforms.LoadImaged(keys=key, reader="NibabelReader"),
        transforms.AddChanneld(keys=key),
        transforms.Spacingd(keys=key, pixdim=DEMO_SPACING, mode="bilinear"),
        transforms.ScaleIntensityRanged(
            keys=key,
            a_min=-1000.0,
            a_max=5000.0,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        transforms.ToTensord(keys=key),
    ]
    return transforms.Compose(pipeline)
def _clamp_volume(tensor: torch.Tensor, max_dim: int | None = None) -> torch.Tensor:
    """Hard limit on volume size so CPU inference finishes in a sensible time.

    Args:
        tensor: Channel-first volume of shape ``(C, H, W, D)``.
        max_dim: Largest allowed spatial side. Defaults to ``DEMO_MAX_DIM``.

    Returns:
        ``tensor`` itself when its longest spatial side is already within the
        limit. Otherwise, a trilinearly resized copy whose longest side equals
        ``max_dim``, with aspect ratio kept and every side at least 1.

    Raises:
        ValueError: if ``tensor`` is not 4-D or ``max_dim`` is below 1.
    """
    limit = DEMO_MAX_DIM if max_dim is None else int(max_dim)
    if limit < 1:
        raise ValueError(f"max_dim must be >= 1, got {limit}")
    if tensor.ndim != 4:
        raise ValueError(f"expected a (C, H, W, D) tensor, got shape {tuple(tensor.shape)}")
    spatial = tuple(tensor.shape[1:])
    longest = max(spatial)
    if longest <= limit:
        return tensor
    scale = limit / longest
    new_shape = tuple(max(1, int(round(side * scale))) for side in spatial)
    # interpolate expects a batch axis: (1, C, H, W, D) in, then drop it again.
    return torch.nn.functional.interpolate(
        tensor.unsqueeze(0),
        size=new_shape,
        mode="trilinear",
        align_corners=False,
    )[0]
def _overlay(image_slice: np.ndarray, seg_slice: np.ndarray) -> Image.Image:
    """Blend a grayscale CT slice with the coloured segmentation mask.

    Args:
        image_slice: 2D intensity slice, min-max normalised to [0, 255] here.
        seg_slice: Integer label slice of the same shape. Each value indexes a
            row of ``LABEL_COLORS``.

    Returns:
        An RGB PIL image. Pixels with a non-zero label are blended 55% grey /
        45% label colour; label-0 pixels keep the grey value.
    """
    img = image_slice.astype(np.float32)
    # Use np.ptp() rather than ndarray.ptp(): the method form was removed in
    # NumPy 2.0. The 1e-6 floor avoids dividing by zero on a constant slice.
    img = (img - img.min()) / max(float(np.ptp(img)), 1e-6)
    img_rgb = (np.stack([img, img, img], axis=-1) * 255).astype(np.uint8)
    mask_rgb = LABEL_COLORS[seg_slice.astype(np.int32)]
    alpha = (seg_slice > 0).astype(np.float32)[..., None] * 0.45
    blended = (img_rgb * (1 - alpha) + mask_rgb * alpha).astype(np.uint8)
    return Image.fromarray(blended)
def _make_montage(volume: np.ndarray, seg: np.ndarray) -> Image.Image:
    """Render a 1x3 PNG montage of the centre slice along each array axis.

    Panel k takes the middle index (from ``volume.shape``) along axis k of both
    ``volume`` and ``seg``. Each slice is rotated with ``np.rot90`` and blended
    by ``_overlay``.

    NOTE(review): the titles assume array axes 0/1/2 are axial/coronal/
    sagittal. The preprocessing does not reorient the image, so for a
    RAS-ordered NIfTI axis 0 would be the left-right (sagittal) direction —
    confirm the labels against real scans.
    """
    centres = [extent // 2 for extent in volume.shape]
    ax_mid, co_mid, sa_mid = centres
    panels = []
    for axis, mid in enumerate((ax_mid, co_mid, sa_mid)):
        ct_plane = np.rot90(np.take(volume, mid, axis=axis))
        label_plane = np.rot90(np.take(seg, mid, axis=axis))
        panels.append(_overlay(ct_plane, label_plane))
    titles = ["axial mid-slice", "coronal mid-slice", "sagittal mid-slice"]

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    for panel_ax, panel, title in zip(axes, panels, titles):
        panel_ax.imshow(panel)
        panel_ax.set_title(title)
        panel_ax.axis("off")
    fig.tight_layout()

    png = BytesIO()
    fig.savefig(png, format="png", dpi=120, bbox_inches="tight")
    plt.close(fig)
    png.seek(0)
    # copy() forces PIL to decode now, detaching the result from the buffer.
    return Image.open(png).copy()
def _legend_text(seg: np.ndarray) -> str:
    """Return a Markdown table with the voxel count of every label in ``seg``.

    Rows follow ascending label value (the order produced by ``np.unique``).
    Labels missing from ``LABEL_NAMES`` are shown as ``unknown``.
    """
    labels, voxel_counts = np.unique(seg, return_counts=True)
    header = ["| index | organ | voxels |", "|-------|-------|--------|"]
    body = [
        f"| {label} | {LABEL_NAMES.get(int(label), 'unknown')} | {int(count)} |"
        for label, count in zip(labels.tolist(), voxel_counts.tolist())
    ]
    return "\n".join(header + body)
def run_demo(nifti_file, overlap: float) -> Tuple[Image.Image, str, str]:
    """Segment one uploaded NIfTI volume and produce the three UI outputs.

    Args:
        nifti_file: The Gradio upload: a filesystem path (str) or a file
            wrapper exposing ``.name``. ``None`` when nothing was uploaded.
        overlap: Sliding-window overlap fraction passed to
            ``sliding_window_inference``.

    Returns:
        ``(montage, legend_markdown, seg_path)``:
        - a mid-slice overlay image;
        - a Markdown voxel-count table;
        - the path of a new ``*_seg.nii.gz`` label volume, resampled
          (nearest neighbour) to the input's first three array dimensions and
          saved with the input's affine.

    Raises:
        gr.Error: if no file was uploaded or the name does not end in
            ``.nii`` / ``.nii.gz``.
    """
    if nifti_file is None:
        raise gr.Error("Please upload a .nii or .nii.gz volume first.")
    path = nifti_file if isinstance(nifti_file, str) else nifti_file.name
    if not path.endswith((".nii", ".nii.gz")):
        raise gr.Error("Only .nii / .nii.gz files are supported.")

    # Preprocess (load, resample, normalise), then cap the size for CPU inference.
    sample = _build_transform()({"image": path})
    img_tensor = _clamp_volume(sample["image"])

    device = _device()
    model = _load_model()
    with torch.no_grad():
        inputs = img_tensor.unsqueeze(0).to(device)
        logits = sliding_window_inference(
            inputs=inputs,
            roi_size=ROI_SIZE,
            sw_batch_size=1,
            predictor=model,
            overlap=float(overlap),
            mode="gaussian",
        )
        # Softmax is monotonic, so taking argmax of the logits gives the same
        # labels. It also avoids copying a full float probability volume per
        # class to the host.
        seg = torch.argmax(logits, dim=1).cpu().numpy()[0].astype(np.uint8)
    ct = inputs.cpu().numpy()[0, 0]
    montage = _make_montage(ct, seg)

    # Map labels back onto the uploaded array grid. order=0 (nearest
    # neighbour) keeps the labels discrete.
    original = nib.load(path)
    target_shape = tuple(original.shape[:3])
    if seg.shape != target_shape:
        zoom = tuple(t / s for t, s in zip(target_shape, seg.shape))
        seg_full = ndi.zoom(seg, zoom, order=0, prefilter=False).astype(np.uint8)
    else:
        seg_full = seg

    # mkstemp creates the file atomically, unlike the deprecated and race-prone
    # mktemp. Close the descriptor because nibabel reopens the file by name.
    fd, out_path = tempfile.mkstemp(suffix="_seg.nii.gz")
    os.close(fd)
    nib.save(nib.Nifti1Image(seg_full, original.affine), out_path)
    return montage, _legend_text(seg_full), out_path
# Gradio UI. Components render in creation order inside the nested
# Row/Column contexts, so statement order here defines the page layout.
with gr.Blocks(
    title="AutoMICE — Mouse Micro-CT Segmentation",
    css="footer {visibility: hidden}",  # hides the default Gradio footer
) as demo:
    # Header: title, paper/code links, and the reduced-resolution warning.
    gr.Markdown(
        """
# AutoMICE — Mouse Micro-CT Multi-Organ Segmentation
Automated **bladder · lung · heart · liver · intestine · kidney · spleen**
segmentation from a single 3D NIfTI volume using Swin UNETR.
> 📄 *Robust Automated Mouse Micro-CT Segmentation Using Swin UNEt TRansformers* — Jiang et al.
> 💻 [GitHub](https://github.com/namijiang/AutoMICE) · 🤗 [Model + Docker image](https://huggingface.co/namijiang98/AutoMICE)
⚠️ **This online demo runs in a downsampled / clamped mode for speed.**
For full-resolution segmentation please use the Docker image or the Python CLI.
"""
    )
    with gr.Row():
        # Left column: inputs and the trigger button.
        with gr.Column(scale=1):
            # ".gz" lets .nii.gz uploads through the picker; run_demo re-checks
            # the full suffix and rejects other .gz files.
            in_file = gr.File(label="Mouse CT (.nii / .nii.gz)", file_types=[".nii", ".gz"])
            overlap = gr.Slider(0.25, 0.85, value=DEFAULT_OVERLAP, step=0.05, label="Sliding-window overlap")
            run_btn = gr.Button("Run AutoMICE", variant="primary")
        # Right column: the three outputs returned by run_demo, in order.
        with gr.Column(scale=2):
            out_image = gr.Image(label="Mid-slice overlay", type="pil")
            out_table = gr.Markdown(label="Voxel counts")
            out_file = gr.File(label="Download segmentation (NIfTI)")
    # run_demo(nifti_file, overlap) -> (montage, legend markdown, seg path).
    run_btn.click(
        fn=run_demo,
        inputs=[in_file, overlap],
        outputs=[out_image, out_table, out_file],
    )
    # Footer: citation block.
    gr.Markdown(
        """
---
### How to cite
```bibtex
@article{jiang2024automice,
title = {Robust Automated Mouse Micro-CT Segmentation Using Swin UNEt TRansformers},
author = {Jiang, Lu and Xu, Di and Xu, Qifan and Chatziioannou, Arion and
Iwamoto, Keisuke S. and Hui, Susanta and Sheng, Ke},
journal = {Bioengineering},
year = {2024},
doi = {10.3390/bioengineering11121255}
}
```
"""
    )
if __name__ == "__main__":
    # The gradio_client schema patch near the top of this file is what keeps
    # /api/info working on HF Spaces. show_api=False on its own is not enough,
    # and is kept only to hide the auto-generated API page. The request queue
    # is capped at 8 pending events.
    app = demo.queue(max_size=8)
    app.launch(show_api=False)