Amodit committed on
Commit
2e21ef0
·
verified ·
1 Parent(s): f91c0f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -78
app.py CHANGED
@@ -1,36 +1,52 @@
 
 
 
 
1
  import gradio as gr
2
  import torch
3
  import nibabel as nib
4
  import numpy as np
 
 
 
5
  from monai.networks.nets import SwinUNETR
6
  from monai.inferers import sliding_window_inference
7
  from monai.transforms import (
8
- Compose, LoadImaged, EnsureChannelFirstd, Orientationd,
9
- Spacingd, ScaleIntensityRanged, CropForegroundd, Resized, EnsureTyped
 
 
 
 
 
 
 
10
  )
11
- import matplotlib.pyplot as plt
12
- from io import BytesIO
13
- from PIL import Image
14
- import os
15
- import tempfile
16
 
17
  print("Loading model...")
18
- device = torch.device("cpu") # Keep CPU for stability
 
19
  print(f"Device: {device}")
20
 
21
- # Model config
22
  model = SwinUNETR(
23
- img_size=(128, 128, 128),
24
  in_channels=1,
25
  out_channels=2,
 
 
 
 
26
  feature_size=48,
 
 
27
  spatial_dims=3,
28
  ).to(device)
29
 
30
  model_path = "best_metric_model.pth"
31
  if os.path.exists(model_path):
32
  try:
33
- model.load_state_dict(torch.load(model_path, map_location=device, weights_only=False))
 
34
  print("Model loaded!")
35
  except Exception as e:
36
  print(f"Error loading model: {e}")
@@ -39,110 +55,148 @@ else:
39
 
40
  model.eval()
41
 
42
- # --- CORRECTED TRANSFORMS ---
43
- # Using LoadImaged (dictionary version) to match the rest of the pipeline
44
- test_transforms = Compose([
45
- LoadImaged(keys=["image"]), # FIXED: Dictionary transform
46
- EnsureChannelFirstd(keys=["image"]),
47
- Orientationd(keys=["image"], axcodes="RAS"),
48
- Spacingd(keys=["image"], pixdim=(1.5, 1.5, 1.5), mode="bilinear"),
49
- ScaleIntensityRanged(keys=["image"], a_min=-200, a_max=200, b_min=0.0, b_max=1.0, clip=True),
50
- CropForegroundd(keys=["image"], source_key="image", allow_smaller=False),
51
- Resized(keys=["image"], spatial_size=(128, 128, 128)),
52
- EnsureTyped(keys=["image"], dtype=torch.float32),
53
- ])
 
 
 
 
 
 
 
 
 
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  def segment_liver(file_obj, slice_num=64):
56
  try:
57
  if file_obj is None:
58
  return None, None
59
-
60
- # Handle Gradio file object
61
- file_path = file_obj.name if hasattr(file_obj, "name") else file_obj
62
  print(f"Processing: {file_path}")
63
-
64
- # Load and preprocess
 
 
 
 
 
 
 
65
  data_dict = {"image": file_path}
66
  data_dict = test_transforms(data_dict)
67
- volume = data_dict["image"].unsqueeze(0).to(device)
68
-
69
  print(f"Input shape: {volume.shape}")
70
-
71
  # Inference
72
  with torch.no_grad():
73
  outputs = sliding_window_inference(
74
- volume, (96, 96, 96), 1, model, overlap=0.25
 
 
 
 
75
  )
76
- pred = torch.argmax(outputs, dim=1).float()
77
-
78
- # Visualization
79
  vol_np = volume[0, 0].cpu().numpy()
80
- pred_np = pred[0, 0].cpu().numpy()
81
-
82
- # Normalize volume
83
  vol_display = (vol_np - vol_np.min()) / (vol_np.max() - vol_np.min() + 1e-8)
84
-
85
  # Slice selection
86
  z_dim = vol_np.shape[2]
87
- slice_idx = int(slice_num)
88
- if slice_idx >= z_dim: slice_idx = z_dim // 2
89
-
 
 
 
90
  # Plot
91
  fig, axes = plt.subplots(1, 3, figsize=(15, 5))
92
-
93
- # CT
94
- axes[0].imshow(vol_display[:, :, slice_idx], cmap='gray')
95
- axes[0].set_title('CT Scan')
96
- axes[0].axis('off')
97
-
98
- # Prediction
99
- axes[1].imshow(pred_np[:, :, slice_idx], cmap='Reds', vmin=0, vmax=1)
100
- axes[1].set_title('Liver Prediction')
101
- axes[1].axis('off')
102
-
103
- # Overlay
104
- axes[2].imshow(vol_display[:, :, slice_idx], cmap='gray')
105
- axes[2].imshow(pred_np[:, :, slice_idx], cmap='Greens', alpha=0.5, vmin=0, vmax=1)
106
- axes[2].set_title('Overlay')
107
- axes[2].axis('off')
108
-
109
  plt.tight_layout()
110
-
111
- # Save image
112
  buf = BytesIO()
113
- plt.savefig(buf, format='png', bbox_inches='tight')
114
  buf.seek(0)
115
- img = Image.open(buf)
116
- plt.close()
117
-
118
- # Save prediction as NIfTI
119
  pred_nii = nib.Nifti1Image(pred_np.astype(np.uint8), np.eye(4))
120
- out_path = tempfile.mktemp(suffix='.nii.gz')
121
  nib.save(pred_nii, out_path)
122
-
123
  print("Success!")
124
  return img, out_path
125
-
126
  except Exception as e:
127
- print(f"Error: {e}")
128
  import traceback
129
  traceback.print_exc()
130
  return None, None
131
 
132
- # Interface
 
133
  iface = gr.Interface(
134
  fn=segment_liver,
135
  inputs=[
136
- gr.File(label="Upload NIfTI"),
137
- gr.Slider(0, 127, value=64, label="Slice")
138
  ],
139
  outputs=[
140
- gr.Image(label="Result"),
141
- gr.File(label="Download Mask")
142
  ],
143
- title="Liver Segmentation (95.5% Dice)",
144
- description="Upload file from TestVolumes folder"
145
  )
146
 
147
  if __name__ == "__main__":
148
- iface.launch(server_name="0.0.0.0", server_port=7860)
 
1
+ import os
2
+ import tempfile
3
+ from io import BytesIO
4
+
5
  import gradio as gr
6
  import torch
7
  import nibabel as nib
8
  import numpy as np
9
+ import matplotlib.pyplot as plt
10
+ from PIL import Image
11
+
12
  from monai.networks.nets import SwinUNETR
13
  from monai.inferers import sliding_window_inference
14
  from monai.transforms import (
15
+ Compose,
16
+ LoadImaged,
17
+ EnsureChannelFirstd,
18
+ Orientationd,
19
+ Spacingd,
20
+ ScaleIntensityRanged,
21
+ CropForegroundd,
22
+ Resized,
23
+ EnsureTyped,
24
  )
 
 
 
 
 
25
 
26
print("Loading model...")
# HF Spaces free tier has no guaranteed GPU, so inference is pinned to CPU.
device = torch.device("cpu")
print(f"Device: {device}")

# --------- MODEL DEFINITION (must match training) ----------
# NOTE(review): these hyperparameters (patch_size / window_size / depths /
# num_heads) are assumed to mirror the training run and the installed MONAI
# version's SwinUNETR signature -- confirm against the training script.
_model_config = dict(
    in_channels=1,
    out_channels=2,
    patch_size=2,
    depths=(2, 2, 2, 2),
    num_heads=(3, 6, 12, 24),
    window_size=7,
    feature_size=48,
    norm_name="instance",
    use_checkpoint=False,
    spatial_dims=3,
)
model = SwinUNETR(**_model_config).to(device)
44
 
45
  model_path = "best_metric_model.pth"
46
  if os.path.exists(model_path):
47
  try:
48
+ state = torch.load(model_path, map_location=device)
49
+ model.load_state_dict(state)
50
  print("Model loaded!")
51
  except Exception as e:
52
  print(f"Error loading model: {e}")
 
55
 
56
  model.eval()
57
 
58
# --------- PREPROCESSING (mirror training pipeline) ----------
# NOTE(review): pixdim=(1.5, 1.5, 1.0) and the 128x128x64 resize are assumed
# to match the training transforms -- confirm against the training script.
_preprocess_steps = [
    # Dictionary-style transforms so the pipeline operates on {"image": path}.
    LoadImaged(keys=["image"]),
    EnsureChannelFirstd(keys=["image"]),
    Orientationd(keys=["image"], axcodes="RAS"),
    Spacingd(keys=["image"], pixdim=(1.5, 1.5, 1.0), mode="bilinear"),
    # Window the CT to [-200, 200] HU and rescale to [0, 1].
    ScaleIntensityRanged(
        keys=["image"], a_min=-200, a_max=200, b_min=0.0, b_max=1.0, clip=True
    ),
    CropForegroundd(keys=["image"], source_key="image", allow_smaller=False),
    Resized(keys=["image"], spatial_size=(128, 128, 64)),
    EnsureTyped(keys=["image"], dtype=torch.float32),
]
test_transforms = Compose(_preprocess_steps)
78
+
79
 
80
+ def _get_path_from_gradio_file(file_obj):
81
+ """
82
+ Gradio / HF can pass:
83
+ - dict with "name"
84
+ - tempfile-like object with .name
85
+ - plain string path (local)
86
+ """
87
+ if file_obj is None:
88
+ return None
89
+
90
+ if isinstance(file_obj, dict):
91
+ return file_obj.get("name")
92
+ if hasattr(file_obj, "name"):
93
+ return file_obj.name
94
+ if isinstance(file_obj, str):
95
+ return file_obj
96
+ raise ValueError(f"Unsupported file object type: {type(file_obj)}")
97
+
98
+
99
+ # --------- INFERENCE FUNCTION ----------
100
def segment_liver(file_obj, slice_num=64):
    """Run liver segmentation on an uploaded NIfTI volume.

    Parameters
    ----------
    file_obj : gradio file object | dict | str | None
        Uploaded .nii / .nii.gz CT volume, in any of the shapes Gradio passes.
    slice_num : int
        Requested axial slice for visualization; clamped to the valid range.

    Returns
    -------
    tuple
        (numpy image with CT / prediction / overlay panels, path to the
        predicted mask saved as .nii.gz), or (None, None) on any failure.
    """
    try:
        if file_obj is None:
            return None, None

        file_path = _get_path_from_gradio_file(file_obj)
        print(f"Processing: {file_path}")

        if file_path is None or not os.path.exists(file_path):
            raise FileNotFoundError("Uploaded file path not found")

        # Manual extension validation (gr.File here has no file_types filter).
        if not file_path.endswith((".nii", ".nii.gz")):
            raise ValueError("Invalid file type. Please upload a .nii or .nii.gz NIfTI file.")

        # Preprocess into a [1, 1, H, W, D] float tensor.
        data_dict = test_transforms({"image": file_path})
        volume = data_dict["image"].unsqueeze(0).to(device)
        print(f"Input shape: {volume.shape}")

        # Sliding-window inference keeps peak memory bounded on CPU.
        with torch.no_grad():
            outputs = sliding_window_inference(
                volume,
                roi_size=(96, 96, 96),
                sw_batch_size=1,
                predictor=model,
                overlap=0.25,
            )
        pred = torch.argmax(outputs, dim=1).float()  # [1, H, W, D]

        vol_np = volume[0, 0].cpu().numpy()
        pred_np = pred[0].cpu().numpy()

        # Normalize CT intensities to [0, 1] for display.
        vol_display = (vol_np - vol_np.min()) / (vol_np.max() - vol_np.min() + 1e-8)

        # FIX: previously any out-of-range index silently jumped to the MIDDLE
        # slice (the UI slider goes to 127 while the resized volume has 64
        # z-slices); clamping to the nearest valid slice is less surprising.
        z_dim = vol_np.shape[2]
        idx = min(max(int(slice_num), 0), z_dim - 1)

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        try:
            axes[0].imshow(vol_display[:, :, idx], cmap="gray")
            axes[0].set_title("CT Scan")
            axes[0].axis("off")

            axes[1].imshow(pred_np[:, :, idx], cmap="Reds", vmin=0, vmax=1)
            axes[1].set_title("Liver Prediction")
            axes[1].axis("off")

            axes[2].imshow(vol_display[:, :, idx], cmap="gray")
            axes[2].imshow(pred_np[:, :, idx], cmap="Greens", alpha=0.5, vmin=0, vmax=1)
            axes[2].set_title("Overlay")
            axes[2].axis("off")

            plt.tight_layout()

            # Render the figure to a numpy image for Gradio.
            buf = BytesIO()
            fig.savefig(buf, format="png", bbox_inches="tight")
            buf.seek(0)
            img = np.array(Image.open(buf))
        finally:
            # FIX: close the figure even if rendering fails, so repeated
            # requests don't leak matplotlib figures.
            plt.close(fig)

        # Save the mask for download.
        # NOTE(review): the identity affine discards the input's spacing and
        # orientation; consider propagating the source affine -- confirm what
        # downstream consumers expect.
        pred_nii = nib.Nifti1Image(pred_np.astype(np.uint8), np.eye(4))
        # FIX: tempfile.mktemp is deprecated and race-prone; NamedTemporaryFile
        # reserves the name safely (delete=False so Gradio can serve the file).
        with tempfile.NamedTemporaryFile(suffix=".nii.gz", delete=False) as tmp:
            out_path = tmp.name
        nib.save(pred_nii, out_path)

        print("Success!")
        return img, out_path

    except Exception as e:
        print(f"Error in segment_liver: {e}")
        import traceback
        traceback.print_exc()
        return None, None
184
 
185
+
186
# --------- GRADIO INTERFACE ----------
_inputs = [
    gr.File(label="Upload NIfTI volume (.nii or .nii.gz)"),
    gr.Slider(0, 127, value=64, label="Slice index"),
]
_outputs = [
    gr.Image(label="Result", type="numpy"),
    gr.File(label="Download Mask (.nii.gz)"),
]
iface = gr.Interface(
    fn=segment_liver,
    inputs=_inputs,
    outputs=_outputs,
    title="Liver Segmentation (SwinUNETR, MONAI)",
    description=(
        "Upload a 3D liver CT volume (.nii or .nii.gz). The app runs a "
        "SwinUNETR model trained on MSD Task03 Liver."
    ),
)

if __name__ == "__main__":
    # Spaces launches the app itself; plain launch() is enough here.
    iface.launch()