"""AutoMICE — Hugging Face Spaces demo. Run a low-resolution Swin UNETR segmentation on a single mouse micro-CT volume uploaded by the user. The full-resolution pipeline lives in the official Docker image; this demo trades accuracy for speed on free CPU hardware. """ from __future__ import annotations import os import tempfile from io import BytesIO from typing import Tuple import gradio as gr # --------------------------------------------------------------------------- # Work around a known crash in `gradio_client.utils.json_schema_to_python_type`: # TypeError: argument of type 'bool' is not iterable # It happens when a JSON schema contains `additionalProperties: true/false` # (a bool instead of a dict). Hugging Face Spaces touches /api/info on every # page load, so this would 500 the whole demo. Patch defensively. # --------------------------------------------------------------------------- try: from gradio_client import utils as _gc_utils _orig_json_to_python = _gc_utils._json_schema_to_python_type _orig_get_type = _gc_utils.get_type def _safe_get_type(schema): if not isinstance(schema, dict): return "Any" return _orig_get_type(schema) def _safe_json_to_python(schema, defs=None): if not isinstance(schema, dict): return "Any" return _orig_json_to_python(schema, defs) _gc_utils.get_type = _safe_get_type _gc_utils._json_schema_to_python_type = _safe_json_to_python except Exception: # pragma: no cover - never block app startup pass import matplotlib.pyplot as plt import nibabel as nib import numpy as np import torch from huggingface_hub import hf_hub_download from monai import transforms from monai.inferers import sliding_window_inference from monai.networks.nets import SwinUNETR from PIL import Image from scipy import ndimage as ndi HF_REPO_ID = os.environ.get("AUTOMICE_HF_REPO", "namijiang98/AutoMICE") HF_MODEL_FILENAME = os.environ.get("AUTOMICE_HF_FILENAME", "model.pt") LABEL_NAMES = { 0: "background", 1: "bladder", 2: "lung", 3: "heart", 4: "liver", 5: "intestine", 6: "kidney", 7: "spleen", } LABEL_COLORS = np.array( [ [0, 0, 0], [255, 215, 0], [135, 206, 250], [220, 20, 60], [210, 105, 30], [124, 252, 0], [186, 85, 211], [255, 105, 180], ], dtype=np.uint8, ) DEMO_SPACING = (0.4, 0.4, 0.4) DEMO_MAX_DIM = 192 DEFAULT_OVERLAP = 0.5 ROI_SIZE = (96, 96, 96) _MODEL: SwinUNETR | None = None def _device() -> torch.device: return torch.device("cuda" if torch.cuda.is_available() else "cpu") def _load_model() -> SwinUNETR: global _MODEL if _MODEL is not None: return _MODEL weights_path = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_MODEL_FILENAME) model = SwinUNETR( img_size=96, in_channels=1, out_channels=8, feature_size=36, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_checkpoint=False, ) ckpt = torch.load(weights_path, map_location="cpu") state_dict = ckpt["state_dict"] if isinstance(ckpt, dict) and "state_dict" in ckpt else ckpt model.load_state_dict(state_dict) model.eval() model.to(_device()) _MODEL = model return model def _build_transform(): return transforms.Compose( [ transforms.LoadImaged(keys=["image"], reader="NibabelReader"), transforms.AddChanneld(keys=["image"]), transforms.Spacingd(keys="image", pixdim=DEMO_SPACING, mode="bilinear"), transforms.ScaleIntensityRanged( keys=["image"], a_min=-1000.0, a_max=5000.0, b_min=0.0, b_max=1.0, clip=True ), transforms.ToTensord(keys=["image"]), ] ) def _clamp_volume(tensor: torch.Tensor) -> torch.Tensor: """Hard limit on volume size so CPU inference finishes in a sensible time.""" _, h, w, d = tensor.shape if max(h, w, d) <= DEMO_MAX_DIM: return tensor scale = DEMO_MAX_DIM / max(h, w, d) new_shape = ( max(1, int(round(h * scale))), max(1, int(round(w * scale))), max(1, int(round(d * scale))), ) return torch.nn.functional.interpolate( tensor.unsqueeze(0), size=new_shape, mode="trilinear", align_corners=False, )[0] def _overlay(image_slice: np.ndarray, seg_slice: np.ndarray) -> Image.Image: """Blend a grayscale CT slice with the coloured segmentation mask.""" img = image_slice.astype(np.float32) img = (img - img.min()) / max(img.ptp(), 1e-6) img_rgb = (np.stack([img, img, img], axis=-1) * 255).astype(np.uint8) mask_rgb = LABEL_COLORS[seg_slice.astype(np.int32)] alpha = (seg_slice > 0).astype(np.float32)[..., None] * 0.45 blended = (img_rgb * (1 - alpha) + mask_rgb * alpha).astype(np.uint8) return Image.fromarray(blended) def _make_montage(volume: np.ndarray, seg: np.ndarray) -> Image.Image: """Three-panel mid-slice montage (axial / coronal / sagittal).""" h, w, d = volume.shape ax_idx, co_idx, sa_idx = h // 2, w // 2, d // 2 panels = [ _overlay(np.rot90(volume[ax_idx, :, :]), np.rot90(seg[ax_idx, :, :])), _overlay(np.rot90(volume[:, co_idx, :]), np.rot90(seg[:, co_idx, :])), _overlay(np.rot90(volume[:, :, sa_idx]), np.rot90(seg[:, :, sa_idx])), ] titles = ["axial mid-slice", "coronal mid-slice", "sagittal mid-slice"] fig, axes = plt.subplots(1, 3, figsize=(12, 4)) for ax, im, title in zip(axes, panels, titles): ax.imshow(im) ax.set_title(title) ax.axis("off") fig.tight_layout() buf = BytesIO() fig.savefig(buf, format="png", dpi=120, bbox_inches="tight") plt.close(fig) buf.seek(0) return Image.open(buf).copy() def _legend_text(seg: np.ndarray) -> str: vals, counts = np.unique(seg, return_counts=True) rows = ["| index | organ | voxels |", "|-------|-------|--------|"] for v, c in zip(vals.tolist(), counts.tolist()): rows.append(f"| {v} | {LABEL_NAMES.get(int(v), 'unknown')} | {int(c)} |") return "\n".join(rows) def run_demo(nifti_file, overlap: float) -> Tuple[Image.Image, str, str]: if nifti_file is None: raise gr.Error("Please upload a .nii or .nii.gz volume first.") path = nifti_file if isinstance(nifti_file, str) else nifti_file.name if not (path.endswith(".nii") or path.endswith(".nii.gz")): raise gr.Error("Only .nii / .nii.gz files are supported.") transform = _build_transform() sample = transform({"image": path}) img_tensor = sample["image"] img_tensor = _clamp_volume(img_tensor) device = _device() model = _load_model() with torch.no_grad(): inputs = img_tensor.unsqueeze(0).to(device) logits = sliding_window_inference( inputs=inputs, roi_size=ROI_SIZE, sw_batch_size=1, predictor=model, overlap=float(overlap), mode="gaussian", ) probs = torch.softmax(logits, dim=1).cpu().numpy() seg = np.argmax(probs, axis=1).astype(np.uint8)[0] ct = inputs.cpu().numpy()[0, 0] montage = _make_montage(ct, seg) original = nib.load(path) target_shape = original.shape if seg.shape != tuple(target_shape): zoom = tuple(t / s for t, s in zip(target_shape, seg.shape)) seg_full = ndi.zoom(seg, zoom, order=0, prefilter=False).astype(np.uint8) else: seg_full = seg out_path = tempfile.mktemp(suffix="_seg.nii.gz") nib.save(nib.Nifti1Image(seg_full, original.affine), out_path) return montage, _legend_text(seg_full), out_path with gr.Blocks( title="AutoMICE — Mouse Micro-CT Segmentation", css="footer {visibility: hidden}", ) as demo: gr.Markdown( """ # AutoMICE — Mouse Micro-CT Multi-Organ Segmentation Automated **bladder · lung · heart · liver · intestine · kidney · spleen** segmentation from a single 3D NIfTI volume using Swin UNETR. > 📄 *Robust Automated Mouse Micro-CT Segmentation Using Swin UNEt TRansformers* — Jiang et al. > 💻 [GitHub](https://github.com/namijiang/AutoMICE) · 🤗 [Model + Docker image](https://huggingface.co/namijiang98/AutoMICE) ⚠️ **This online demo runs in a downsampled / clamped mode for speed.** For full-resolution segmentation please use the Docker image or the Python CLI. """ ) with gr.Row(): with gr.Column(scale=1): in_file = gr.File(label="Mouse CT (.nii / .nii.gz)", file_types=[".nii", ".gz"]) overlap = gr.Slider(0.25, 0.85, value=DEFAULT_OVERLAP, step=0.05, label="Sliding-window overlap") run_btn = gr.Button("Run AutoMICE", variant="primary") with gr.Column(scale=2): out_image = gr.Image(label="Mid-slice overlay", type="pil") out_table = gr.Markdown(label="Voxel counts") out_file = gr.File(label="Download segmentation (NIfTI)") run_btn.click( fn=run_demo, inputs=[in_file, overlap], outputs=[out_image, out_table, out_file], ) gr.Markdown( """ --- ### How to cite ```bibtex @article{jiang2024automice, title = {Robust Automated Mouse Micro-CT Segmentation Using Swin UNEt TRansformers}, author = {Jiang, Lu and Xu, Di and Xu, Qifan and Chatziioannou, Arion and Iwamoto, Keisuke S. and Hui, Susanta and Sheng, Ke}, journal = {Bioengineering}, year = {2024}, doi = {10.3390/bioengineering11121255} } ``` """ ) if __name__ == "__main__": # show_api=False alone is not enough on HF Spaces; the schema patch above # is the real fix. We still hide the auto API page to reduce surface area. demo.queue(max_size=8).launch(show_api=False)