File size: 7,277 Bytes
2e21ef0
 
 
 
d791fee
 
 
 
2e21ef0
 
 
d791fee
 
f999ee3
2e21ef0
 
 
 
 
 
 
 
 
f999ee3
d791fee
92db826
 
 
2e21ef0
92db826
d791fee
92db826
 
 
d791fee
92db826
d791fee
 
2e21ef0
 
d791fee
2e21ef0
 
d791fee
 
 
92db826
 
f91c0f5
92db826
2e21ef0
92db826
f91c0f5
92db826
f91c0f5
92db826
f999ee3
d791fee
 
92db826
2e21ef0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d791fee
2e21ef0
 
92db826
 
2e21ef0
 
 
 
92db826
2e21ef0
92db826
 
 
 
 
 
 
 
 
 
 
 
 
 
2e21ef0
 
92db826
 
2e21ef0
 
92db826
2e21ef0
 
 
92db826
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f91c0f5
d791fee
f91c0f5
92db826
2e21ef0
 
92db826
2e21ef0
 
92db826
2e21ef0
92db826
2e21ef0
 
 
 
39f51a5
 
92db826
 
2e21ef0
de1f585
39f51a5
 
2e21ef0
 
 
 
 
39f51a5
92db826
2e21ef0
39f51a5
2e21ef0
 
 
de1f585
2e21ef0
92db826
f91c0f5
2e21ef0
92db826
2e21ef0
 
92db826
39f51a5
2e21ef0
 
92db826
2e21ef0
 
 
92db826
2e21ef0
 
 
 
 
 
 
39f51a5
2e21ef0
92db826
d791fee
2e21ef0
d791fee
2e21ef0
 
 
92db826
39f51a5
2e21ef0
39f51a5
2e21ef0
92db826
39f51a5
2e21ef0
d791fee
de1f585
92db826
 
de1f585
92db826
d791fee
2e21ef0
92db826
39f51a5
 
 
2e21ef0
 
39f51a5
 
2e21ef0
 
39f51a5
2e21ef0
 
39f51a5
d791fee
 
92db826
2e21ef0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
import os
import tempfile
from io import BytesIO

import gradio as gr
import torch
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

from monai.networks.nets import SwinUNETR
from monai.inferers import sliding_window_inference
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    Orientationd,
    Spacingd,
    ScaleIntensityRanged,
    CropForegroundd,
    Resized,
    EnsureTyped,
)

print("Starting app...")

# ----------------- DEVICE -----------------
# Inference is pinned to the CPU; no GPU detection is attempted.
device = torch.device("cpu")
print(f"Using device: {device}")

# ----------------- MODEL -----------------
# NOTE: the MONAI SwinUNETR this was written against takes no `patch_size` or
# `window_size` arguments. Use img_size consistent with your pre-processing
# (Resized to 128x128x64).
model = SwinUNETR(
    img_size=(128, 128, 64),
    in_channels=1,
    out_channels=2,
    depths=(2, 2, 2, 2),
    num_heads=(3, 6, 12, 24),
    feature_size=48,
    norm_name="instance",
    use_checkpoint=False,
    spatial_dims=3,
).to(device)

ckpt_path = "best_metric_model.pth"


def _load_checkpoint(net, path):
    """
    Load a state dict from ``path`` into ``net`` on a best-effort basis.

    A missing file or a failed load is reported on stdout and otherwise
    ignored, so the app still starts. In that case the network runs without
    (fully) loaded trained weights, and its predictions are not meaningful.

    The loaded dict only lives in this function's scope. ``load_state_dict``
    copies the tensors into the model's own parameters, so binding the dict
    to a module-level name would keep a second full copy of the weights in
    memory for the lifetime of the process.
    """
    if not os.path.exists(path):
        print(f"WARNING: checkpoint '{path}' not found in Space.")
        return
    try:
        # NOTE: torch.load unpickles the file; only point it at trusted,
        # locally shipped checkpoints.
        state = torch.load(path, map_location=device)
        net.load_state_dict(state)
        print("Model loaded successfully.")
    except Exception as e:
        print(f"ERROR loading model weights: {e}")


_load_checkpoint(model, ckpt_path)

model.eval()

# ----------------- TRANSFORMS -----------------
# Every transform operates on the single "image" entry of the data dict.
_IMAGE_KEYS = ("image",)

# Intensity window: input values in [-200, 200] are mapped linearly onto
# [0, 1] and anything outside is clipped (presumably Hounsfield units for CT
# input — confirm against the training pipeline).
_intensity_window = ScaleIntensityRanged(
    keys=_IMAGE_KEYS,
    a_min=-200,
    a_max=200,
    b_min=0.0,
    b_max=1.0,
    clip=True,
)

# Inference-time preprocessing: load the NIfTI file, move channels first,
# reorient to RAS, resample to 1.5 x 1.5 x 1.0 spacing, window intensities,
# crop to the foreground, resize to a fixed 128 x 128 x 64 grid and cast to
# float32.
test_transforms = Compose(
    [
        LoadImaged(keys=_IMAGE_KEYS),
        EnsureChannelFirstd(keys=_IMAGE_KEYS),
        Orientationd(keys=_IMAGE_KEYS, axcodes="RAS"),
        Spacingd(keys=_IMAGE_KEYS, pixdim=(1.5, 1.5, 1.0), mode="bilinear"),
        _intensity_window,
        CropForegroundd(keys=_IMAGE_KEYS, source_key="image", allow_smaller=False),
        Resized(keys=_IMAGE_KEYS, spatial_size=(128, 128, 64)),
        EnsureTyped(keys=_IMAGE_KEYS, dtype=torch.float32),
    ]
)


def _get_path_from_gradio_file(file_obj):
    """
    Convert the Gradio file object into a real path on disk.
    Handles dicts, tempfiles, and plain string paths.
    """
    if file_obj is None:
        return None

    # Case 1: dict (HF Spaces often passes this)
    if isinstance(file_obj, dict):
        if "path" in file_obj and file_obj["path"]:
            return file_obj["path"]
        if "name" in file_obj and file_obj["name"]:
            return file_obj["name"]
        # If we only have raw bytes, write them to a temp file
        if "data" in file_obj and file_obj["data"] is not None:
            suffix = ".nii.gz"
            tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
            tmp.write(file_obj["data"])
            tmp.flush()
            tmp.close()
            return tmp.name

    # Case 2: tempfile-like with .name
    if hasattr(file_obj, "name"):
        return file_obj.name

    # Case 3: already a string path (local testing)
    if isinstance(file_obj, str):
        return file_obj

    raise ValueError(f"Unsupported file object type: {type(file_obj)}")


def _error_image(msg: str):
    """
    Render ``msg`` as centred red text on a blank 8x3-inch canvas.

    The canvas is returned as a numpy image array. Callers pass it to the
    Gradio image output in place of a real result, so a failure is visible
    instead of leaving the panel empty.
    """
    figure, axis = plt.subplots(figsize=(8, 3))
    text_style = dict(ha="center", va="center", color="red", fontsize=12)
    axis.text(0.5, 0.5, msg, **text_style)
    axis.axis("off")

    # Rasterise through an in-memory PNG, then decode it back into pixels.
    png_buffer = BytesIO()
    figure.savefig(png_buffer, format="png", bbox_inches="tight")
    png_buffer.seek(0)
    pixels = np.array(Image.open(png_buffer))

    plt.close(figure)
    return pixels


# ----------------- INFERENCE -----------------
def _figure_to_array(fig):
    """Rasterise a Matplotlib figure through an in-memory PNG into a numpy image."""
    buf = BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    buf.seek(0)
    with Image.open(buf) as png:
        return np.array(png)


def _render_slice_panels(vol_display, pred_np, idx):
    """Draw CT slice, predicted mask and overlay side by side for slice ``idx``.

    ``idx`` indexes the last axis of both arrays (``[:, :, idx]``). The
    figure is always closed, even if drawing fails, so repeated requests do
    not accumulate open figures in pyplot's global registry.
    """
    ct_slice = vol_display[:, :, idx]
    mask_slice = pred_np[:, :, idx]

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    try:
        axes[0].imshow(ct_slice, cmap="gray")
        axes[0].set_title("CT Slice")

        axes[1].imshow(mask_slice, cmap="Reds", vmin=0, vmax=1)
        axes[1].set_title("Predicted Liver Mask")

        axes[2].imshow(ct_slice, cmap="gray")
        axes[2].imshow(mask_slice, cmap="Greens", alpha=0.5, vmin=0, vmax=1)
        axes[2].set_title("Overlay")

        for ax in axes:
            ax.axis("off")

        fig.tight_layout()
        return _figure_to_array(fig)
    finally:
        plt.close(fig)


def _image_affine(image):
    """Return the 4x4 voxel-to-world affine carried by a preprocessed image.

    MONAI MetaTensors expose an ``affine`` that the transforms update as they
    reorient, resample, crop and resize. Objects without a usable 4x4 affine
    fall back to the identity, which the original code always used.
    """
    affine = getattr(image, "affine", None)
    if affine is None:
        return np.eye(4)
    if isinstance(affine, torch.Tensor):
        affine = affine.detach().cpu().numpy()
    affine = np.asarray(affine, dtype=np.float64)
    return affine if affine.shape == (4, 4) else np.eye(4)


def _save_mask_nifti(pred_np, affine):
    """Write the mask as uint8 to a new ``.nii.gz`` temp file and return its path."""
    # NamedTemporaryFile(delete=False) reserves a unique name atomically,
    # unlike the deprecated, race-prone tempfile.mktemp. It is closed before
    # nibabel writes to the path (required on Windows).
    with tempfile.NamedTemporaryFile(suffix=".nii.gz", delete=False) as tmp:
        out_path = tmp.name
    nib.save(nib.Nifti1Image(pred_np.astype(np.uint8), affine), out_path)
    return out_path


def segment_liver(file_obj, slice_num=64):
    """
    Run liver segmentation on an uploaded NIfTI volume.

    Args:
        file_obj: The object Gradio passes for the uploaded file (dict,
            tempfile-like, or path).
        slice_num: Index along the last axis of the preprocessed volume to
            display. Out-of-range values are clamped to the nearest valid
            slice; ``None`` selects the middle slice.

    Returns:
        ``(image, mask_path)``: a numpy image with CT / mask / overlay panels,
        and the path of the predicted mask saved as ``.nii.gz``. The mask is
        written in the preprocessed voxel grid with that grid's affine. On any
        failure, ``(error_image, None)`` is returned instead of raising, so
        the UI always shows something.
    """
    try:
        if file_obj is None:
            return _error_image("No file uploaded."), None

        file_path = _get_path_from_gradio_file(file_obj)
        print(f"[segment_liver] file_path = {file_path}")

        if file_path is None or not os.path.exists(file_path):
            raise FileNotFoundError("Uploaded file path not found on server.")

        # Manual extension check
        if not file_path.endswith((".nii", ".nii.gz")):
            raise ValueError("Invalid file type. Please upload a .nii or .nii.gz NIfTI file.")

        # Preprocess
        image = test_transforms({"image": file_path})["image"]
        volume = image.unsqueeze(0).to(device)  # [1, 1, H, W, D]
        print(f"[segment_liver] preprocessed volume shape: {volume.shape}")

        # Inference.
        # NOTE(review): roi_size (96, 96, 96) differs from the 128x128x64
        # preprocessing grid; sliding_window_inference pads the input up to
        # the ROI along the depth axis. Presumably this matches training —
        # confirm.
        with torch.no_grad():
            outputs = sliding_window_inference(
                volume,
                roi_size=(96, 96, 96),
                sw_batch_size=1,
                predictor=model,
                overlap=0.25,
            )
            pred = torch.argmax(outputs, dim=1).float()  # [1, H, W, D]

        vol_np = volume[0, 0].cpu().numpy()
        pred_np = pred[0].cpu().numpy()

        # Min-max normalise for display only; the epsilon avoids 0/0 on a
        # constant volume.
        vol_display = (vol_np - vol_np.min()) / (vol_np.max() - vol_np.min() + 1e-8)

        # Clamp the requested slice into the valid range rather than jumping
        # to an unrelated slice.
        z_dim = vol_np.shape[2]
        if slice_num is None:
            idx = z_dim // 2
        else:
            idx = min(max(int(slice_num), 0), z_dim - 1)

        img = _render_slice_panels(vol_display, pred_np, idx)
        out_path = _save_mask_nifti(pred_np, _image_affine(image))

        print("[segment_liver] success.")
        return img, out_path

    except Exception as e:
        import traceback

        print("[segment_liver] ERROR:", e)
        traceback.print_exc()
        return _error_image(f"Error: {e}"), None


# ----------------- GRADIO UI -----------------
iface = gr.Interface(
    fn=segment_liver,
    inputs=[
        gr.File(label="Upload NIfTI volume (.nii or .nii.gz)"),
        # Preprocessing resizes every volume to depth 64 (the last entry of
        # Resized(spatial_size=(128, 128, 64))), so valid slice indices are
        # 0..63. Keep this range in sync with that transform. Start in the
        # middle.
        gr.Slider(0, 63, value=32, step=1, label="Slice index"),
    ],
    outputs=[
        gr.Image(label="Result", type="numpy"),
        gr.File(label="Download Mask (.nii.gz)"),
    ],
    title="Liver Segmentation (SwinUNETR, MONAI)",
    description="Upload a 3D liver CT volume (.nii or .nii.gz). The app runs a SwinUNETR model trained on MSD Task03 Liver.",
)

if __name__ == "__main__":
    # On HF Spaces you may need: iface.launch(server_name="0.0.0.0", server_port=7860)
    iface.launch()