Spaces:
Sleeping
Sleeping
| """AutoMICE — Hugging Face Spaces demo. | |
| Run a low-resolution Swin UNETR segmentation on a single mouse micro-CT | |
| volume uploaded by the user. The full-resolution pipeline lives in the | |
| official Docker image; this demo trades accuracy for speed on free CPU | |
| hardware. | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import tempfile | |
| from io import BytesIO | |
| from typing import Tuple | |
| import gradio as gr | |
# ---------------------------------------------------------------------------
# Work around a known crash in `gradio_client.utils.json_schema_to_python_type`:
#     TypeError: argument of type 'bool' is not iterable
# It happens when a JSON schema contains `additionalProperties: true/false`
# (a bool instead of a dict). Hugging Face Spaces touches /api/info on every
# page load, so this would 500 the whole demo. Patch defensively.
# ---------------------------------------------------------------------------
try:
    from gradio_client import utils as _gc_utils

    _orig_json_to_python = _gc_utils._json_schema_to_python_type
    _orig_get_type = _gc_utils.get_type

    def _safe_get_type(schema):
        """Return ``"Any"`` for non-dict schemas, else defer to gradio_client."""
        if not isinstance(schema, dict):
            return "Any"
        return _orig_get_type(schema)

    def _safe_json_to_python(schema, defs=None):
        """Return ``"Any"`` for non-dict schemas, else defer to gradio_client."""
        if not isinstance(schema, dict):
            return "Any"
        return _orig_json_to_python(schema, defs)

    _gc_utils.get_type = _safe_get_type
    _gc_utils._json_schema_to_python_type = _safe_json_to_python
except Exception as _patch_error:  # pragma: no cover - never block app startup
    # Still best-effort (the app must start even if the private helpers were
    # renamed or gradio_client is missing), but no longer silent: without the
    # patch /api/info may start failing again, so leave a visible trace.
    import warnings

    warnings.warn(
        f"gradio_client schema patch not applied: {_patch_error!r}",
        RuntimeWarning,
        stacklevel=1,
    )
| import matplotlib.pyplot as plt | |
| import nibabel as nib | |
| import numpy as np | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
| from monai import transforms | |
| from monai.inferers import sliding_window_inference | |
| from monai.networks.nets import SwinUNETR | |
| from PIL import Image | |
| from scipy import ndimage as ndi | |
# Where the checkpoint lives on the Hugging Face Hub; both values can be
# overridden through environment variables.
HF_REPO_ID = os.environ.get("AUTOMICE_HF_REPO", "namijiang98/AutoMICE")
HF_MODEL_FILENAME = os.environ.get("AUTOMICE_HF_FILENAME", "model.pt")

# Label index -> organ name. The index is the channel picked by the argmax
# over the model's 8 output channels.
LABEL_NAMES = dict(
    enumerate(
        (
            "background",
            "bladder",
            "lung",
            "heart",
            "liver",
            "intestine",
            "kidney",
            "spleen",
        )
    )
)

# RGB overlay colour per label index (row i colours label i). Row 0 never
# shows up in overlays because label-0 pixels are blended at zero opacity.
LABEL_COLORS = np.array(
    (
        (0, 0, 0),  # background
        (255, 215, 0),  # bladder
        (135, 206, 250),  # lung
        (220, 20, 60),  # heart
        (210, 105, 30),  # liver
        (124, 252, 0),  # intestine
        (186, 85, 211),  # kidney
        (255, 105, 180),  # spleen
    ),
    dtype=np.uint8,
)

# Demo-mode settings that trade accuracy for CPU speed.
DEMO_SPACING = (0.4, 0.4, 0.4)  # target voxel spacing handed to Spacingd
DEMO_MAX_DIM = 192  # longest spatial side allowed before inference
DEFAULT_OVERLAP = 0.5  # initial value of the sliding-window overlap slider
ROI_SIZE = (96, 96, 96)  # sliding-window patch size

# Lazily built model, cached by _load_model() after the first request.
_MODEL: SwinUNETR | None = None
| def _device() -> torch.device: | |
| return torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
def _load_model() -> SwinUNETR:
    """Return the demo Swin UNETR, building and caching it on first use.

    The checkpoint is fetched with ``hf_hub_download`` from ``HF_REPO_ID`` /
    ``HF_MODEL_FILENAME``. Either a bare state dict or a dict wrapping one
    under the ``"state_dict"`` key is accepted. The network is switched to
    eval mode, moved to ``_device()`` and memoised in the module-level
    ``_MODEL`` so later calls skip the download and construction.
    """
    global _MODEL
    if _MODEL is None:
        checkpoint_file = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_MODEL_FILENAME)
        net = SwinUNETR(
            img_size=96,
            in_channels=1,
            out_channels=8,
            feature_size=36,
            drop_rate=0.0,
            attn_drop_rate=0.0,
            dropout_path_rate=0.0,
            use_checkpoint=False,
        )
        # NOTE(review): depending on the installed torch version's
        # `weights_only` default, torch.load may unpickle arbitrary objects;
        # only point AUTOMICE_HF_REPO at repositories you trust.
        payload = torch.load(checkpoint_file, map_location="cpu")
        if isinstance(payload, dict) and "state_dict" in payload:
            weights = payload["state_dict"]
        else:
            weights = payload
        net.load_state_dict(weights)
        net.eval()
        net.to(_device())
        _MODEL = net
    return _MODEL
def _build_transform():
    """Build the MONAI preprocessing pipeline for one uploaded volume.

    Steps: load the NIfTI file with nibabel, add a leading channel axis,
    resample to ``DEMO_SPACING`` (bilinear), clip intensities to
    [-1000, 5000] and rescale them linearly to [0, 1], then convert to a
    tensor.
    """
    # `AddChanneld` was deprecated and later removed from MONAI (gone in 1.3);
    # keep using it where it exists and fall back to its replacement
    # otherwise, so the demo survives a MONAI upgrade.
    if hasattr(transforms, "AddChanneld"):
        add_channel = transforms.AddChanneld(keys=["image"])
    else:
        add_channel = transforms.EnsureChannelFirstd(keys=["image"], channel_dim="no_channel")
    return transforms.Compose(
        [
            transforms.LoadImaged(keys=["image"], reader="NibabelReader"),
            add_channel,
            transforms.Spacingd(keys="image", pixdim=DEMO_SPACING, mode="bilinear"),
            transforms.ScaleIntensityRanged(
                keys=["image"], a_min=-1000.0, a_max=5000.0, b_min=0.0, b_max=1.0, clip=True
            ),
            transforms.ToTensord(keys=["image"]),
        ]
    )
| def _clamp_volume(tensor: torch.Tensor) -> torch.Tensor: | |
| """Hard limit on volume size so CPU inference finishes in a sensible time.""" | |
| _, h, w, d = tensor.shape | |
| if max(h, w, d) <= DEMO_MAX_DIM: | |
| return tensor | |
| scale = DEMO_MAX_DIM / max(h, w, d) | |
| new_shape = ( | |
| max(1, int(round(h * scale))), | |
| max(1, int(round(w * scale))), | |
| max(1, int(round(d * scale))), | |
| ) | |
| return torch.nn.functional.interpolate( | |
| tensor.unsqueeze(0), | |
| size=new_shape, | |
| mode="trilinear", | |
| align_corners=False, | |
| )[0] | |
def _overlay(image_slice: np.ndarray, seg_slice: np.ndarray) -> Image.Image:
    """Blend a grayscale CT slice with the coloured segmentation mask.

    The intensity slice is min-max normalised to 0-255 grey. Each label is
    coloured via ``LABEL_COLORS``. Labelled (non-zero) pixels are
    alpha-blended at 0.45 opacity; background pixels keep the plain grey
    value.

    Args:
        image_slice: 2-D intensity slice.
        seg_slice: Integer label slice with the same shape as ``image_slice``.

    Returns:
        An RGB ``PIL.Image`` of the blended slice.
    """
    img = image_slice.astype(np.float32)
    # ndarray.ptp() was removed in NumPy 2.0; the np.ptp function works on
    # both 1.x and 2.x. The 1e-6 floor avoids dividing by zero on a
    # constant slice.
    img = (img - img.min()) / max(np.ptp(img), 1e-6)
    img_rgb = (np.stack([img, img, img], axis=-1) * 255).astype(np.uint8)
    mask_rgb = LABEL_COLORS[seg_slice.astype(np.int32)]
    alpha = (seg_slice > 0).astype(np.float32)[..., None] * 0.45
    blended = (img_rgb * (1 - alpha) + mask_rgb * alpha).astype(np.uint8)
    return Image.fromarray(blended)
def _make_montage(volume: np.ndarray, seg: np.ndarray) -> Image.Image:
    """Three-panel mid-slice montage (axial / coronal / sagittal).

    Takes the middle slice along each of the three array axes of ``volume``
    and ``seg`` (same shape expected) and rotates each with ``np.rot90``.
    The labels are overlaid via ``_overlay`` and the three panels are
    rendered side by side to a PNG, which is returned as a PIL image.

    NOTE(review): nothing upstream reorients the volume, so which anatomical
    plane each array axis corresponds to depends on the input file. The
    panel titles assume axes 0 / 1 / 2 are axial / coronal / sagittal;
    confirm this against real data.
    """
    # Use the object-oriented Figure API instead of pyplot. pyplot keeps every
    # open figure in a global registry, so a figure would stay alive if
    # savefig raised before plt.close. pyplot is also discouraged inside web
    # servers such as this Gradio app.
    from matplotlib.figure import Figure

    h, w, d = volume.shape
    ax_idx, co_idx, sa_idx = h // 2, w // 2, d // 2
    panels = [
        _overlay(np.rot90(volume[ax_idx, :, :]), np.rot90(seg[ax_idx, :, :])),
        _overlay(np.rot90(volume[:, co_idx, :]), np.rot90(seg[:, co_idx, :])),
        _overlay(np.rot90(volume[:, :, sa_idx]), np.rot90(seg[:, :, sa_idx])),
    ]
    titles = ["axial mid-slice", "coronal mid-slice", "sagittal mid-slice"]
    fig = Figure(figsize=(12, 4))
    axes = fig.subplots(1, 3)
    for ax, im, title in zip(axes, panels, titles):
        ax.imshow(im)
        ax.set_title(title)
        ax.axis("off")
    fig.tight_layout()
    buf = BytesIO()
    fig.savefig(buf, format="png", dpi=120, bbox_inches="tight")
    buf.seek(0)
    # copy() makes PIL decode now, so the result does not depend on buf.
    return Image.open(buf).copy()
def _legend_text(seg: np.ndarray) -> str:
    """Summarise a label volume as a Markdown table.

    There is one row per distinct label value in ``seg``, in ascending order
    as returned by ``np.unique``. Each row gives the label index, its organ
    name from ``LABEL_NAMES`` (``'unknown'`` if unmapped) and its voxel
    count.
    """
    labels, voxel_counts = np.unique(seg, return_counts=True)
    header = "| index | organ | voxels |\n|-------|-------|--------|"
    body = [
        f"| {label} | {LABEL_NAMES.get(int(label), 'unknown')} | {int(count)} |"
        for label, count in zip(labels.tolist(), voxel_counts.tolist())
    ]
    return "\n".join([header, *body])
def _upload_path(nifti_file) -> str:
    """Return the uploaded file's path, rejecting missing or non-NIfTI input."""
    if nifti_file is None:
        raise gr.Error("Please upload a .nii or .nii.gz volume first.")
    # The upload may arrive as a path string or as an object exposing
    # `.name`; accept both.
    path = nifti_file if isinstance(nifti_file, str) else nifti_file.name
    if not path.endswith((".nii", ".nii.gz")):
        raise gr.Error("Only .nii / .nii.gz files are supported.")
    return path


def _predict_labels(inputs: torch.Tensor, overlap: float) -> np.ndarray:
    """Segment a ``(1, C, H, W, D)`` batch and return the ``(H, W, D)`` uint8 label map."""
    model = _load_model()
    with torch.no_grad():
        logits = sliding_window_inference(
            inputs=inputs,
            roi_size=ROI_SIZE,
            sw_batch_size=1,
            predictor=model,
            overlap=float(overlap),
            mode="gaussian",
        )
        # Softmax is monotonic per voxel, so the argmax of the raw logits
        # picks the same class without building a full probability volume.
        labels = torch.argmax(logits, dim=1)[0]
    return labels.cpu().numpy().astype(np.uint8)


def _save_on_original_grid(seg: np.ndarray, path: str) -> Tuple[np.ndarray, str]:
    """Resample ``seg`` onto the voxel grid of the NIfTI at ``path`` and save it.

    A nearest-neighbour zoom (order=0) keeps the labels integral. The result
    uses the source image's affine and is written to a fresh temporary
    ``*_seg.nii.gz`` file.

    Returns:
        ``(seg_on_original_grid, output_path)``.
    """
    original = nib.load(path)
    target_shape = tuple(original.shape)
    if seg.shape != target_shape:
        zoom = tuple(t / s for t, s in zip(target_shape, seg.shape))
        seg = ndi.zoom(seg, zoom, order=0, prefilter=False).astype(np.uint8)
    # mkstemp creates the file atomically. The deprecated tempfile.mktemp only
    # returns a name, which another process could claim before we write it.
    fd, out_path = tempfile.mkstemp(suffix="_seg.nii.gz")
    os.close(fd)
    nib.save(nib.Nifti1Image(seg, original.affine), out_path)
    return seg, out_path


def run_demo(nifti_file, overlap: float) -> Tuple[Image.Image, str, str]:
    """Segment one uploaded NIfTI volume and produce the demo outputs.

    Args:
        nifti_file: Value of the ``gr.File`` input: a path string, an object
            exposing ``.name``, or ``None`` when nothing was uploaded.
        overlap: Sliding-window overlap fraction from the UI slider.

    Returns:
        ``(montage, legend_markdown, seg_path)``:
        - ``montage``: mid-slice overlays drawn on the preprocessed volume.
        - ``legend_markdown``: per-label voxel counts on the original grid.
        - ``seg_path``: path of the saved segmentation NIfTI.

    Raises:
        gr.Error: if no file was uploaded or its name does not end in
            ``.nii`` / ``.nii.gz``.
    """
    path = _upload_path(nifti_file)
    sample = _build_transform()({"image": path})
    img_tensor = _clamp_volume(sample["image"])
    inputs = img_tensor.unsqueeze(0).to(_device())
    seg = _predict_labels(inputs, overlap)
    # The montage uses the preprocessed (resampled, rescaled) volume, which
    # shares seg's voxel grid.
    ct = inputs.cpu().numpy()[0, 0]
    montage = _make_montage(ct, seg)
    seg_full, out_path = _save_on_original_grid(seg, path)
    return montage, _legend_text(seg_full), out_path
# Gradio UI, in order: a header, an input column and an output column, the
# button wiring to run_demo, and a citation footer.
with gr.Blocks(
    title="AutoMICE — Mouse Micro-CT Segmentation",
    # CSS rule that hides the page's <footer> element.
    css="footer {visibility: hidden}",
) as demo:
    gr.Markdown(
        """
# AutoMICE — Mouse Micro-CT Multi-Organ Segmentation
Automated **bladder · lung · heart · liver · intestine · kidney · spleen**
segmentation from a single 3D NIfTI volume using Swin UNETR.
> 📄 *Robust Automated Mouse Micro-CT Segmentation Using Swin UNEt TRansformers* — Jiang et al.
> 💻 [GitHub](https://github.com/namijiang/AutoMICE) · 🤗 [Model + Docker image](https://huggingface.co/namijiang98/AutoMICE)
⚠️ **This online demo runs in a downsampled / clamped mode for speed.**
For full-resolution segmentation please use the Docker image or the Python CLI.
"""
    )
    with gr.Row():
        # Input column: upload, overlap slider, run button.
        with gr.Column(scale=1):
            # NOTE(review): ".gz" presumably lets any gzip file through the
            # picker; run_demo's suffix check rejects non-.nii/.nii.gz names.
            in_file = gr.File(label="Mouse CT (.nii / .nii.gz)", file_types=[".nii", ".gz"])
            overlap = gr.Slider(0.25, 0.85, value=DEFAULT_OVERLAP, step=0.05, label="Sliding-window overlap")
            run_btn = gr.Button("Run AutoMICE", variant="primary")
        # Output column (scale=2 relative to the input column's 1).
        with gr.Column(scale=2):
            out_image = gr.Image(label="Mid-slice overlay", type="pil")
            out_table = gr.Markdown(label="Voxel counts")
            out_file = gr.File(label="Download segmentation (NIfTI)")
    # run_demo returns (montage, legend markdown, segmentation path), in the
    # same order as the outputs listed here.
    run_btn.click(
        fn=run_demo,
        inputs=[in_file, overlap],
        outputs=[out_image, out_table, out_file],
    )
    gr.Markdown(
        """
---
### How to cite
```bibtex
@article{jiang2024automice,
title = {Robust Automated Mouse Micro-CT Segmentation Using Swin UNEt TRansformers},
author = {Jiang, Lu and Xu, Di and Xu, Qifan and Chatziioannou, Arion and
Iwamoto, Keisuke S. and Hui, Susanta and Sheng, Ke},
journal = {Bioengineering},
year = {2024},
doi = {10.3390/bioengineering11121255}
}
```
"""
    )
if __name__ == "__main__":
    # show_api=False alone is not enough on HF Spaces; the schema patch above
    # is the real fix. We still hide the auto API page to reduce surface area.
    # The request queue is capped at 8 pending events.
    queued_app = demo.queue(max_size=8)
    queued_app.launch(show_api=False)