# Amodit's picture
# Update app.py
# 92db826 verified
import os
import tempfile
from io import BytesIO
import gradio as gr
import torch
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from monai.networks.nets import SwinUNETR
from monai.inferers import sliding_window_inference
from monai.transforms import (
Compose,
LoadImaged,
EnsureChannelFirstd,
Orientationd,
Spacingd,
ScaleIntensityRanged,
CropForegroundd,
Resized,
EnsureTyped,
)
print("Starting app...")

# ----------------- DEVICE -----------------
# Inference is pinned to the CPU; checkpoint tensors are mapped onto it below.
device = torch.device("cpu")
print(f"Using device: {device}")

# ----------------- MODEL -----------------
# NOTE: SwinUNETR in current MONAI versions does NOT take `patch_size` or `window_size`.
# Use img_size consistent with your pre-processing (Resized to 128x128x64).
# NOTE(review): newer MONAI releases are believed to deprecate `img_size` --
# confirm against the MONAI version pinned for this app.
model = SwinUNETR(
    spatial_dims=3,
    img_size=(128, 128, 64),
    in_channels=1,
    out_channels=2,
    feature_size=48,
    depths=(2, 2, 2, 2),
    num_heads=(3, 6, 12, 24),
    norm_name="instance",
    use_checkpoint=False,
)
model = model.to(device)

# Weight loading is best-effort: a missing or unreadable checkpoint is only
# reported on stdout, and the app keeps running with the freshly
# initialised (untrained) network.
ckpt_path = "best_metric_model.pth"
if not os.path.exists(ckpt_path):
    print(f"WARNING: checkpoint '{ckpt_path}' not found in Space.")
else:
    try:
        state = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(state)
        print("Model loaded successfully.")
    except Exception as exc:
        print(f"ERROR loading model weights: {exc}")

# Switch to inference mode regardless of whether weights were loaded.
model.eval()
# ----------------- TRANSFORMS -----------------
# Inference-time preprocessing. The pipeline is called with a dict of the
# form {"image": <path to NIfTI file>} and every step operates on that key.
_preprocessing_steps = [
    # Read the file and put the channel axis first.
    LoadImaged(keys=["image"]),
    EnsureChannelFirstd(keys=["image"]),
    # Bring every volume to the same orientation and voxel spacing.
    Orientationd(keys=["image"], axcodes="RAS"),
    Spacingd(keys=["image"], pixdim=(1.5, 1.5, 1.0), mode="bilinear"),
    # Clip intensities to [-200, 200] and rescale them to [0, 1].
    ScaleIntensityRanged(
        keys=["image"],
        a_min=-200,
        a_max=200,
        b_min=0.0,
        b_max=1.0,
        clip=True,
    ),
    # Crop to the foreground of the image itself, then force a fixed size.
    CropForegroundd(keys=["image"], source_key="image", allow_smaller=False),
    Resized(keys=["image"], spatial_size=(128, 128, 64)),
    EnsureTyped(keys=["image"], dtype=torch.float32),
]
test_transforms = Compose(_preprocessing_steps)
def _get_path_from_gradio_file(file_obj):
"""
Convert the Gradio file object into a real path on disk.
Handles dicts, tempfiles, and plain string paths.
"""
if file_obj is None:
return None
# Case 1: dict (HF Spaces often passes this)
if isinstance(file_obj, dict):
if "path" in file_obj and file_obj["path"]:
return file_obj["path"]
if "name" in file_obj and file_obj["name"]:
return file_obj["name"]
# If we only have raw bytes, write them to a temp file
if "data" in file_obj and file_obj["data"] is not None:
suffix = ".nii.gz"
tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
tmp.write(file_obj["data"])
tmp.flush()
tmp.close()
return tmp.name
# Case 2: tempfile-like with .name
if hasattr(file_obj, "name"):
return file_obj.name
# Case 3: already a string path (local testing)
if isinstance(file_obj, str):
return file_obj
raise ValueError(f"Unsupported file object type: {type(file_obj)}")
def _error_image(msg: str):
    """
    Render ``msg`` as centred red text on an otherwise empty figure.

    Used as a placeholder result so the image output of the UI is never
    left blank when a request fails.

    Args:
        msg: The message to display.

    Returns:
        The rendered figure as a numpy array, decoded from an in-memory PNG.
    """
    figure, axis = plt.subplots(figsize=(8, 3))
    axis.text(0.5, 0.5, msg, ha="center", va="center", color="red", fontsize=12)
    axis.axis("off")

    # Rasterise through an in-memory PNG; the figure is no longer needed
    # once it has been written out.
    png_buffer = BytesIO()
    figure.savefig(png_buffer, format="png", bbox_inches="tight")
    plt.close(figure)

    png_buffer.seek(0)
    return np.array(Image.open(png_buffer))
# ----------------- INFERENCE -----------------
def _render_slice_figure(vol_display, pred_np, idx):
    """Plot CT slice, predicted mask and overlay side by side; return the image array."""
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    try:
        ct_slice = vol_display[:, :, idx]
        mask_slice = pred_np[:, :, idx]

        axes[0].imshow(ct_slice, cmap="gray")
        axes[0].set_title("CT Slice")

        axes[1].imshow(mask_slice, cmap="Reds", vmin=0, vmax=1)
        axes[1].set_title("Predicted Liver Mask")

        axes[2].imshow(ct_slice, cmap="gray")
        axes[2].imshow(mask_slice, cmap="Greens", alpha=0.5, vmin=0, vmax=1)
        axes[2].set_title("Overlay")

        for ax in axes:
            ax.axis("off")

        # Lay out this figure explicitly instead of relying on pyplot's
        # global "current figure", which other requests may have changed.
        fig.tight_layout()

        # Convert figure to numpy image
        buf = BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight")
        buf.seek(0)
        return np.array(Image.open(buf))
    finally:
        # Always release the figure, even when plotting or saving fails;
        # pyplot keeps figures alive until they are closed.
        plt.close(fig)


def _save_mask_as_nifti(pred_np):
    """Write the mask to a new temporary ``.nii.gz`` file and return its path."""
    # NamedTemporaryFile creates the file atomically; tempfile.mktemp only
    # returns a name and is deprecated because another process can claim it.
    out_file = tempfile.NamedTemporaryFile(suffix=".nii.gz", delete=False)
    out_file.close()  # close our handle so nibabel can write to the path
    try:
        # NOTE: the mask is stored with an identity affine, i.e. in the
        # voxel grid of the preprocessed volume, not the original scan.
        pred_nii = nib.Nifti1Image(pred_np.astype(np.uint8), np.eye(4))
        nib.save(pred_nii, out_file.name)
    except Exception:
        # Do not leave an empty/partial file behind on failure.
        try:
            os.remove(out_file.name)
        except OSError:
            pass
        raise
    return out_file.name


def segment_liver(file_obj, slice_num=64):
    """
    Run liver segmentation on an uploaded NIfTI volume.

    The file is preprocessed with ``test_transforms``, passed through
    ``model`` via sliding-window inference, and the arg-max over the output
    channels is taken as the predicted mask.

    Args:
        file_obj: The uploaded file as delivered by Gradio (see
            ``_get_path_from_gradio_file`` for the accepted forms).
        slice_num: Index of the slice along the last spatial axis to
            display. Values outside the preprocessed volume fall back to
            the middle slice.

    Returns:
        A tuple ``(image, mask_path)``. ``image`` is a numpy array showing
        the CT slice, the predicted mask and an overlay; ``mask_path`` is
        the path of a temporary ``.nii.gz`` file holding the full predicted
        mask. On any failure, ``image`` is an error-message image and
        ``mask_path`` is ``None``; exceptions are logged, not raised.
    """
    try:
        if file_obj is None:
            return _error_image("No file uploaded."), None

        file_path = _get_path_from_gradio_file(file_obj)
        print(f"[segment_liver] file_path = {file_path}")
        if file_path is None or not os.path.exists(file_path):
            raise FileNotFoundError("Uploaded file path not found on server.")

        # Manual extension check (case-insensitive, so e.g. ".NII.GZ" passes)
        if not str(file_path).lower().endswith((".nii", ".nii.gz")):
            raise ValueError("Invalid file type. Please upload a .nii or .nii.gz NIfTI file.")

        # Preprocess
        data_dict = test_transforms({"image": file_path})
        volume = data_dict["image"].unsqueeze(0).to(device)  # [1, 1, H, W, D]
        print(f"[segment_liver] preprocessed volume shape: {volume.shape}")

        # Inference
        with torch.no_grad():
            outputs = sliding_window_inference(
                volume,
                roi_size=(96, 96, 96),
                sw_batch_size=1,
                predictor=model,
                overlap=0.25,
            )
            pred = torch.argmax(outputs, dim=1).float()  # [1, H, W, D]

        vol_np = volume[0, 0].cpu().numpy()
        pred_np = pred[0].cpu().numpy()

        # Normalize CT for display; the epsilon guards a constant volume.
        vol_display = (vol_np - vol_np.min()) / (vol_np.max() - vol_np.min() + 1e-8)

        # Handle slice index safely
        z_dim = vol_np.shape[2]
        idx = int(slice_num)
        if idx < 0 or idx >= z_dim:
            idx = z_dim // 2

        # Plot CT / mask / overlay
        img = _render_slice_figure(vol_display, pred_np, idx)

        # Save prediction mask as NIfTI for download
        out_path = _save_mask_as_nifti(pred_np)

        print("[segment_liver] success.")
        return img, out_path
    except Exception as e:
        # UI boundary: report the failure as an image instead of raising.
        import traceback

        print("[segment_liver] ERROR:", e)
        traceback.print_exc()
        return _error_image(f"Error: {e}"), None
# ----------------- GRADIO UI -----------------
# Two inputs (uploaded volume, slice index) map onto segment_liver's
# parameters; its two return values map onto the image and file outputs.
iface = gr.Interface(
    fn=segment_liver,
    inputs=[
        gr.File(label="Upload NIfTI volume (.nii or .nii.gz)"),
        # test_transforms resizes every volume to a depth of 64, so the
        # valid slice indices are 0..63. A larger range would only expose
        # values that segment_liver silently replaces with the middle slice
        # (which is why the default is the middle slice, 32).
        gr.Slider(minimum=0, maximum=63, value=32, step=1, label="Slice index"),
    ],
    outputs=[
        gr.Image(label="Result", type="numpy"),
        gr.File(label="Download Mask (.nii.gz)"),
    ],
    title="Liver Segmentation (SwinUNETR, MONAI)",
    description="Upload a 3D liver CT volume (.nii or .nii.gz). The app runs a SwinUNETR model trained on MSD Task03 Liver.",
)

if __name__ == "__main__":
    # On HF Spaces: iface.launch(server_name="0.0.0.0", server_port=7860)
    iface.launch()