"""Gradio app: 3D liver CT segmentation with a SwinUNETR model (MONAI).

Loads a SwinUNETR checkpoint, preprocesses an uploaded NIfTI volume to a
fixed 128x128x64 grid, runs sliding-window inference on CPU, and shows a
CT slice / predicted mask / overlay plus a downloadable .nii.gz mask.
"""

import os
import tempfile
from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import torch
from PIL import Image
from monai.inferers import sliding_window_inference
from monai.networks.nets import SwinUNETR
from monai.transforms import (
    Compose,
    CropForegroundd,
    EnsureChannelFirstd,
    EnsureTyped,
    LoadImaged,
    Orientationd,
    Resized,
    ScaleIntensityRanged,
    Spacingd,
)

print("Starting app...")

# ----------------- DEVICE -----------------
device = torch.device("cpu")
print(f"Using device: {device}")

# ----------------- MODEL -----------------
# NOTE: SwinUNETR in current MONAI versions does NOT take `patch_size` or `window_size`.
# Use img_size consistent with your pre-processing (Resized to 128x128x64).
model = SwinUNETR(
    img_size=(128, 128, 64),
    in_channels=1,
    out_channels=2,
    depths=(2, 2, 2, 2),
    num_heads=(3, 6, 12, 24),
    feature_size=48,
    norm_name="instance",
    use_checkpoint=False,
    spatial_dims=3,
).to(device)

ckpt_path = "best_metric_model.pth"
if os.path.exists(ckpt_path):
    try:
        # SECURITY NOTE: torch.load unpickles arbitrary objects — only load
        # checkpoints from a trusted source (consider weights_only=True on
        # torch >= 2.0 if the checkpoint is a plain state_dict).
        state = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(state)
        print("Model loaded successfully.")
    except Exception as e:
        print(f"ERROR loading model weights: {e}")
else:
    print(f"WARNING: checkpoint '{ckpt_path}' not found in Space.")

model.eval()

# ----------------- TRANSFORMS -----------------
test_transforms = Compose(
    [
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys=["image"]),
        Orientationd(keys=["image"], axcodes="RAS"),
        Spacingd(keys=["image"], pixdim=(1.5, 1.5, 1.0), mode="bilinear"),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-200,
            a_max=200,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image"], source_key="image", allow_smaller=False),
        Resized(keys=["image"], spatial_size=(128, 128, 64)),
        EnsureTyped(keys=["image"], dtype=torch.float32),
    ]
)


def _get_path_from_gradio_file(file_obj):
    """Convert the Gradio file object into a real path on disk.

    Handles dicts (as passed by HF Spaces), tempfile-like objects with a
    ``.name`` attribute, and plain string paths.

    Raises:
        ValueError: if ``file_obj`` is none of the supported shapes.
    """
    if file_obj is None:
        return None

    # Case 1: dict (HF Spaces often passes this)
    if isinstance(file_obj, dict):
        if file_obj.get("path"):
            return file_obj["path"]
        if file_obj.get("name"):
            return file_obj["name"]
        # If we only have raw bytes, write them to a temp file
        if file_obj.get("data") is not None:
            tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".nii.gz")
            tmp.write(file_obj["data"])
            tmp.flush()
            tmp.close()
            return tmp.name

    # Case 2: tempfile-like with .name
    if hasattr(file_obj, "name"):
        return file_obj.name

    # Case 3: already a string path (local testing)
    if isinstance(file_obj, str):
        return file_obj

    raise ValueError(f"Unsupported file object type: {type(file_obj)}")


def _error_image(msg: str):
    """Create a simple image with an error message so the UI never looks
    'empty' when something goes wrong."""
    fig, ax = plt.subplots(figsize=(8, 3))
    ax.text(0.5, 0.5, msg, ha="center", va="center", color="red", fontsize=12)
    ax.axis("off")
    buf = BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    buf.seek(0)
    img = np.array(Image.open(buf))
    plt.close(fig)
    return img


# ----------------- INFERENCE -----------------
def segment_liver(file_obj, slice_num=64):
    """Segment the liver in an uploaded NIfTI CT volume.

    Args:
        file_obj: Gradio file object (dict / tempfile / path string).
        slice_num: axial slice index to display; out-of-range values fall
            back to the middle slice.

    Returns:
        (numpy image of the CT/mask/overlay figure, path to .nii.gz mask),
        or (error image, None) on failure.
    """
    try:
        if file_obj is None:
            return _error_image("No file uploaded."), None

        file_path = _get_path_from_gradio_file(file_obj)
        print(f"[segment_liver] file_path = {file_path}")

        if file_path is None or not os.path.exists(file_path):
            raise FileNotFoundError("Uploaded file path not found on server.")

        # Manual extension check (".nii.gz" ends with ".gz", so both
        # suffixes must be listed explicitly)
        if not file_path.endswith((".nii", ".nii.gz")):
            raise ValueError(
                "Invalid file type. Please upload a .nii or .nii.gz NIfTI file."
            )

        # Preprocess
        data_dict = {"image": file_path}
        data_dict = test_transforms(data_dict)
        volume = data_dict["image"].unsqueeze(0).to(device)  # [1, 1, H, W, D]
        print(f"[segment_liver] preprocessed volume shape: {volume.shape}")

        # Inference
        with torch.no_grad():
            outputs = sliding_window_inference(
                volume,
                roi_size=(96, 96, 96),
                sw_batch_size=1,
                predictor=model,
                overlap=0.25,
            )
            pred = torch.argmax(outputs, dim=1).float()  # [1, H, W, D]

        vol_np = volume[0, 0].cpu().numpy()
        pred_np = pred[0].cpu().numpy()

        # Normalize CT for display
        vol_display = (vol_np - vol_np.min()) / (vol_np.max() - vol_np.min() + 1e-8)

        # Handle slice index safely (depth is fixed to 64 by Resized)
        z_dim = vol_np.shape[2]
        idx = int(slice_num)
        if idx < 0 or idx >= z_dim:
            idx = z_dim // 2

        # Plot CT / mask / overlay
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        axes[0].imshow(vol_display[:, :, idx], cmap="gray")
        axes[0].set_title("CT Slice")
        axes[0].axis("off")

        axes[1].imshow(pred_np[:, :, idx], cmap="Reds", vmin=0, vmax=1)
        axes[1].set_title("Predicted Liver Mask")
        axes[1].axis("off")

        axes[2].imshow(vol_display[:, :, idx], cmap="gray")
        axes[2].imshow(pred_np[:, :, idx], cmap="Greens", alpha=0.5, vmin=0, vmax=1)
        axes[2].set_title("Overlay")
        axes[2].axis("off")
        plt.tight_layout()

        # Convert figure to numpy image
        buf = BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight")
        buf.seek(0)
        img = np.array(Image.open(buf))
        plt.close(fig)

        # Save prediction mask as NIfTI for download.
        # NOTE(review): an identity affine discards the original scan's
        # geometry — the mask is in the resampled 128x128x64 space, not the
        # input space. Confirm downstream consumers expect that.
        pred_nii = nib.Nifti1Image(pred_np.astype(np.uint8), np.eye(4))
        # NamedTemporaryFile avoids the race condition of the deprecated
        # tempfile.mktemp().
        out_tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".nii.gz")
        out_path = out_tmp.name
        out_tmp.close()
        nib.save(pred_nii, out_path)

        print("[segment_liver] success.")
        return img, out_path

    except Exception as e:
        import traceback

        print("[segment_liver] ERROR:", e)
        traceback.print_exc()
        return _error_image(f"Error: {e}"), None


# ----------------- GRADIO UI -----------------
iface = gr.Interface(
    fn=segment_liver,
    inputs=[
        gr.File(label="Upload NIfTI volume (.nii or .nii.gz)"),
        # Depth is fixed to 64 by the Resized transform, so valid indices
        # are 0..63 (the old 0..127 range with default 64 always fell back
        # to the middle slice).
        gr.Slider(0, 63, value=32, label="Slice index"),
    ],
    outputs=[
        gr.Image(label="Result", type="numpy"),
        gr.File(label="Download Mask (.nii.gz)"),
    ],
    title="Liver Segmentation (SwinUNETR, MONAI)",
    description="Upload a 3D liver CT volume (.nii or .nii.gz). The app runs a SwinUNETR model trained on MSD Task03 Liver.",
)

if __name__ == "__main__":
    # On HF Spaces: iface.launch(server_name="0.0.0.0", server_port=7860)
    iface.launch()