# Hugging Face Spaces page-status residue from the scrape (non-code):
# "Spaces: Running / Running"
import os
import tempfile
from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import torch
from PIL import Image
from monai.inferers import sliding_window_inference
from monai.networks.nets import SwinUNETR
from monai.transforms import (
    Compose,
    CropForegroundd,
    EnsureChannelFirstd,
    EnsureTyped,
    LoadImaged,
    Orientationd,
    Resized,
    ScaleIntensityRanged,
    Spacingd,
)
print("Starting app...")

# ----------------- DEVICE -----------------
# Force CPU inference for this deployment.
device = torch.device("cpu")
print(f"Using device: {device}")

# ----------------- MODEL -----------------
# NOTE: SwinUNETR in current MONAI versions does NOT take `patch_size` or
# `window_size`; img_size must agree with the preprocessing target
# (volumes are Resized to 128x128x64 before inference).
model = SwinUNETR(
    img_size=(128, 128, 64),
    in_channels=1,
    out_channels=2,
    depths=(2, 2, 2, 2),
    num_heads=(3, 6, 12, 24),
    feature_size=48,
    norm_name="instance",
    use_checkpoint=False,
    spatial_dims=3,
).to(device)

ckpt_path = "best_metric_model.pth"
if not os.path.exists(ckpt_path):
    print(f"WARNING: checkpoint '{ckpt_path}' not found in Space.")
else:
    try:
        # NOTE(review): consider torch.load(..., weights_only=True) for
        # state-dict checkpoints once the pinned torch version supports it.
        state = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(state)
        print("Model loaded successfully.")
    except Exception as e:
        # Best effort: keep the app alive (with untrained weights) so the
        # UI can surface the problem instead of crashing at import time.
        print(f"ERROR loading model weights: {e}")

model.eval()
# ----------------- TRANSFORMS -----------------
# Dictionary-based MONAI pipeline applied to {"image": <nifti path>}:
# load -> channel-first -> RAS orientation -> resample to 1.5 x 1.5 x 1.0
# spacing -> clip intensities to [-200, 200] and rescale into [0, 1]
# -> crop to foreground -> resize to the model's 128x128x64 input
# -> float32 tensor.
test_transforms = Compose(
    [
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys=["image"]),
        Orientationd(keys=["image"], axcodes="RAS"),
        Spacingd(keys=["image"], pixdim=(1.5, 1.5, 1.0), mode="bilinear"),
        ScaleIntensityRanged(
            keys=["image"], a_min=-200, a_max=200, b_min=0.0, b_max=1.0, clip=True
        ),
        CropForegroundd(keys=["image"], source_key="image", allow_smaller=False),
        Resized(keys=["image"], spatial_size=(128, 128, 64)),
        EnsureTyped(keys=["image"], dtype=torch.float32),
    ]
)
| def _get_path_from_gradio_file(file_obj): | |
| """ | |
| Convert the Gradio file object into a real path on disk. | |
| Handles dicts, tempfiles, and plain string paths. | |
| """ | |
| if file_obj is None: | |
| return None | |
| # Case 1: dict (HF Spaces often passes this) | |
| if isinstance(file_obj, dict): | |
| if "path" in file_obj and file_obj["path"]: | |
| return file_obj["path"] | |
| if "name" in file_obj and file_obj["name"]: | |
| return file_obj["name"] | |
| # If we only have raw bytes, write them to a temp file | |
| if "data" in file_obj and file_obj["data"] is not None: | |
| suffix = ".nii.gz" | |
| tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix) | |
| tmp.write(file_obj["data"]) | |
| tmp.flush() | |
| tmp.close() | |
| return tmp.name | |
| # Case 2: tempfile-like with .name | |
| if hasattr(file_obj, "name"): | |
| return file_obj.name | |
| # Case 3: already a string path (local testing) | |
| if isinstance(file_obj, str): | |
| return file_obj | |
| raise ValueError(f"Unsupported file object type: {type(file_obj)}") | |
def _error_image(msg: str):
    """Render *msg* in red on a small figure and return it as a numpy image.

    Keeps the Gradio image output from ever looking blank when something
    goes wrong during inference.
    """
    fig = plt.figure(figsize=(8, 3))
    ax = fig.add_subplot(111)
    ax.axis("off")
    ax.text(0.5, 0.5, msg, ha="center", va="center", color="red", fontsize=12)

    buffer = BytesIO()
    fig.savefig(buffer, format="png", bbox_inches="tight")
    plt.close(fig)
    buffer.seek(0)
    return np.array(Image.open(buffer))
# ----------------- INFERENCE -----------------
def segment_liver(file_obj, slice_num=64):
    """Run SwinUNETR liver segmentation on an uploaded NIfTI volume.

    Parameters
    ----------
    file_obj : dict | tempfile-like | str | None
        The Gradio upload (resolved via ``_get_path_from_gradio_file``).
    slice_num : int
        Axial slice index to visualize; out-of-range values fall back to
        the middle slice of the resampled volume.

    Returns
    -------
    tuple
        ``(image, mask_path)``: a numpy image with CT / mask / overlay
        panels and the path to a downloadable ``.nii.gz`` mask, or an
        error image and ``None`` when anything fails.
    """
    try:
        if file_obj is None:
            return _error_image("No file uploaded."), None

        file_path = _get_path_from_gradio_file(file_obj)
        print(f"[segment_liver] file_path = {file_path}")
        if file_path is None or not os.path.exists(file_path):
            raise FileNotFoundError("Uploaded file path not found on server.")

        # Manual extension check: rejecting early gives a clearer message
        # than letting the NIfTI loader fail later.
        if not (file_path.endswith(".nii") or file_path.endswith(".nii.gz")):
            raise ValueError("Invalid file type. Please upload a .nii or .nii.gz NIfTI file.")

        # Preprocess to the fixed (128, 128, 64) grid the model expects.
        data_dict = test_transforms({"image": file_path})
        volume = data_dict["image"].unsqueeze(0).to(device)  # [1, 1, H, W, D]
        print(f"[segment_liver] preprocessed volume shape: {volume.shape}")

        # Sliding-window inference keeps peak memory bounded on CPU.
        with torch.no_grad():
            outputs = sliding_window_inference(
                volume,
                roi_size=(96, 96, 96),
                sw_batch_size=1,
                predictor=model,
                overlap=0.25,
            )
        pred = torch.argmax(outputs, dim=1).float()  # [1, H, W, D]

        vol_np = volume[0, 0].cpu().numpy()
        pred_np = pred[0].cpu().numpy()

        # Min-max normalize the CT for display (epsilon guards flat volumes).
        vol_display = (vol_np - vol_np.min()) / (vol_np.max() - vol_np.min() + 1e-8)

        # Handle the slice index safely: anything out of range shows the
        # middle slice instead of crashing.
        z_dim = vol_np.shape[2]
        idx = int(slice_num)
        if idx < 0 or idx >= z_dim:
            idx = z_dim // 2

        # Three panels: raw CT, predicted mask, overlay.
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        axes[0].imshow(vol_display[:, :, idx], cmap="gray")
        axes[0].set_title("CT Slice")
        axes[0].axis("off")
        axes[1].imshow(pred_np[:, :, idx], cmap="Reds", vmin=0, vmax=1)
        axes[1].set_title("Predicted Liver Mask")
        axes[1].axis("off")
        axes[2].imshow(vol_display[:, :, idx], cmap="gray")
        axes[2].imshow(pred_np[:, :, idx], cmap="Greens", alpha=0.5, vmin=0, vmax=1)
        axes[2].set_title("Overlay")
        axes[2].axis("off")
        plt.tight_layout()

        # Render the figure to a numpy image for the Gradio output.
        buf = BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight")
        buf.seek(0)
        img = np.array(Image.open(buf))
        plt.close(fig)

        # Save the prediction as NIfTI for download.
        # NOTE(review): the mask is in the resampled/cropped model space with
        # an identity affine, not the original scan geometry — confirm that
        # downstream consumers expect this.
        # tempfile.mktemp was race-prone and is deprecated; NamedTemporaryFile
        # actually reserves the path before nib.save overwrites it.
        pred_nii = nib.Nifti1Image(pred_np.astype(np.uint8), np.eye(4))
        with tempfile.NamedTemporaryFile(delete=False, suffix=".nii.gz") as tmp:
            out_path = tmp.name
        nib.save(pred_nii, out_path)

        print("[segment_liver] success.")
        return img, out_path
    except Exception as e:
        # Top-level UI boundary: surface the error in the image output and
        # the Space logs instead of letting Gradio fail opaquely.
        import traceback
        print("[segment_liver] ERROR:", e)
        traceback.print_exc()
        return _error_image(f"Error: {e}"), None
# ----------------- GRADIO UI -----------------
iface = gr.Interface(
    fn=segment_liver,
    inputs=[
        gr.File(label="Upload NIfTI volume (.nii or .nii.gz)"),
        # Preprocessing resizes every volume to depth 64, so valid axial
        # indices are 0..63.  The previous 0..127 range (default 64) made
        # more than half the slider — including its default — silently
        # remap to the middle slice inside segment_liver.
        gr.Slider(0, 63, value=32, step=1, label="Slice index"),
    ],
    outputs=[
        gr.Image(label="Result", type="numpy"),
        gr.File(label="Download Mask (.nii.gz)"),
    ],
    title="Liver Segmentation (SwinUNETR, MONAI)",
    description="Upload a 3D liver CT volume (.nii or .nii.gz). The app runs a SwinUNETR model trained on MSD Task03 Liver.",
)

if __name__ == "__main__":
    # On HF Spaces use: iface.launch(server_name="0.0.0.0", server_port=7860)
    iface.launch()