Amodit committed on
Commit
92db826
·
verified ·
1 Parent(s): 2e21ef0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -39
app.py CHANGED
@@ -23,39 +23,41 @@ from monai.transforms import (
23
  EnsureTyped,
24
  )
25
 
26
- print("Loading model...")
27
- # HF Spaces: assume CPU, GPU not guaranteed
 
28
  device = torch.device("cpu")
29
- print(f"Device: {device}")
30
 
31
- # --------- MODEL DEFINITION (must match training) ----------
 
 
32
  model = SwinUNETR(
 
33
  in_channels=1,
34
  out_channels=2,
35
- patch_size=2,
36
  depths=(2, 2, 2, 2),
37
  num_heads=(3, 6, 12, 24),
38
- window_size=7,
39
  feature_size=48,
40
  norm_name="instance",
41
  use_checkpoint=False,
42
  spatial_dims=3,
43
  ).to(device)
44
 
45
- model_path = "best_metric_model.pth"
46
- if os.path.exists(model_path):
47
  try:
48
- state = torch.load(model_path, map_location=device)
49
  model.load_state_dict(state)
50
- print("Model loaded!")
51
  except Exception as e:
52
- print(f"Error loading model: {e}")
53
  else:
54
- print("WARNING: Model file not found")
55
 
56
  model.eval()
57
 
58
- # --------- PREPROCESSING (mirror training pipeline) ----------
59
  test_transforms = Compose(
60
  [
61
  LoadImaged(keys=["image"]),
@@ -79,44 +81,75 @@ test_transforms = Compose(
79
 
80
  def _get_path_from_gradio_file(file_obj):
81
  """
82
- Gradio / HF can pass:
83
- - dict with "name"
84
- - tempfile-like object with .name
85
- - plain string path (local)
86
  """
87
  if file_obj is None:
88
  return None
89
 
 
90
  if isinstance(file_obj, dict):
91
- return file_obj.get("name")
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  if hasattr(file_obj, "name"):
93
  return file_obj.name
 
 
94
  if isinstance(file_obj, str):
95
  return file_obj
 
96
  raise ValueError(f"Unsupported file object type: {type(file_obj)}")
97
 
98
 
99
- # --------- INFERENCE FUNCTION ----------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  def segment_liver(file_obj, slice_num=64):
101
  try:
102
  if file_obj is None:
103
- return None, None
104
 
105
  file_path = _get_path_from_gradio_file(file_obj)
106
- print(f"Processing: {file_path}")
107
 
108
  if file_path is None or not os.path.exists(file_path):
109
- raise FileNotFoundError("Uploaded file path not found")
110
 
111
- # Manual extension validation (since we removed file_types)
112
  if not (file_path.endswith(".nii") or file_path.endswith(".nii.gz")):
113
  raise ValueError("Invalid file type. Please upload a .nii or .nii.gz NIfTI file.")
114
 
115
  # Preprocess
116
  data_dict = {"image": file_path}
117
  data_dict = test_transforms(data_dict)
118
- volume = data_dict["image"].unsqueeze(0).to(device) # [1,1,H,W,D]
119
- print(f"Input shape: {volume.shape}")
120
 
121
  # Inference
122
  with torch.no_grad():
@@ -127,7 +160,7 @@ def segment_liver(file_obj, slice_num=64):
127
  predictor=model,
128
  overlap=0.25,
129
  )
130
- pred = torch.argmax(outputs, dim=1).float() # [1,H,W,D]
131
 
132
  vol_np = volume[0, 0].cpu().numpy()
133
  pred_np = pred[0].cpu().numpy()
@@ -135,23 +168,21 @@ def segment_liver(file_obj, slice_num=64):
135
  # Normalize CT for display
136
  vol_display = (vol_np - vol_np.min()) / (vol_np.max() - vol_np.min() + 1e-8)
137
 
138
- # Slice selection
139
  z_dim = vol_np.shape[2]
140
  idx = int(slice_num)
141
- if idx < 0:
142
- idx = 0
143
- if idx >= z_dim:
144
  idx = z_dim // 2
145
 
146
- # Plot
147
  fig, axes = plt.subplots(1, 3, figsize=(15, 5))
148
 
149
  axes[0].imshow(vol_display[:, :, idx], cmap="gray")
150
- axes[0].set_title("CT Scan")
151
  axes[0].axis("off")
152
 
153
  axes[1].imshow(pred_np[:, :, idx], cmap="Reds", vmin=0, vmax=1)
154
- axes[1].set_title("Liver Prediction")
155
  axes[1].axis("off")
156
 
157
  axes[2].imshow(vol_display[:, :, idx], cmap="gray")
@@ -161,29 +192,30 @@ def segment_liver(file_obj, slice_num=64):
161
 
162
  plt.tight_layout()
163
 
164
- # Convert figure to numpy image for Gradio
165
  buf = BytesIO()
166
  fig.savefig(buf, format="png", bbox_inches="tight")
167
  buf.seek(0)
168
  img = np.array(Image.open(buf))
169
  plt.close(fig)
170
 
171
- # Save prediction as NIfTI for download
172
  pred_nii = nib.Nifti1Image(pred_np.astype(np.uint8), np.eye(4))
173
  out_path = tempfile.mktemp(suffix=".nii.gz")
174
  nib.save(pred_nii, out_path)
175
 
176
- print("Success!")
177
  return img, out_path
178
 
179
  except Exception as e:
180
- print(f"Error in segment_liver: {e}")
181
  import traceback
 
 
182
  traceback.print_exc()
183
- return None, None
184
 
185
 
186
- # --------- GRADIO INTERFACE ----------
187
  iface = gr.Interface(
188
  fn=segment_liver,
189
  inputs=[
@@ -199,4 +231,5 @@ iface = gr.Interface(
199
  )
200
 
201
  if __name__ == "__main__":
 
202
  iface.launch()
 
23
  EnsureTyped,
24
  )
25
 
26
print("Starting app...")

# ----------------- DEVICE -----------------
# HF Spaces free tier: GPU is not guaranteed, so run on CPU.
device = torch.device("cpu")
print(f"Using device: {device}")

# ----------------- MODEL -----------------
# NOTE: SwinUNETR in current MONAI versions does NOT take `patch_size` or
# `window_size`. `img_size` must agree with the pre-processing resize
# (128 x 128 x 64) used by the test transforms.
model = SwinUNETR(
    img_size=(128, 128, 64),
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    feature_size=48,
    depths=(2, 2, 2, 2),
    num_heads=(3, 6, 12, 24),
    norm_name="instance",
    use_checkpoint=False,
).to(device)

# Load trained weights if the checkpoint was uploaded with the Space.
ckpt_path = "best_metric_model.pth"
if not os.path.exists(ckpt_path):
    print(f"WARNING: checkpoint '{ckpt_path}' not found in Space.")
else:
    try:
        # NOTE(review): torch.load unpickles arbitrary objects; for a plain
        # state_dict consider weights_only=True — confirm the torch version
        # on the Space supports it before changing.
        state = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(state)
        print("Model loaded successfully.")
    except Exception as e:
        print(f"ERROR loading model weights: {e}")

model.eval()
59
 
60
+ # ----------------- TRANSFORMS -----------------
61
  test_transforms = Compose(
62
  [
63
  LoadImaged(keys=["image"]),
 
81
 
82
  def _get_path_from_gradio_file(file_obj):
83
  """
84
+ Convert the Gradio file object into a real path on disk.
85
+ Handles dicts, tempfiles, and plain string paths.
 
 
86
  """
87
  if file_obj is None:
88
  return None
89
 
90
+ # Case 1: dict (HF Spaces often passes this)
91
  if isinstance(file_obj, dict):
92
+ if "path" in file_obj and file_obj["path"]:
93
+ return file_obj["path"]
94
+ if "name" in file_obj and file_obj["name"]:
95
+ return file_obj["name"]
96
+ # If we only have raw bytes, write them to a temp file
97
+ if "data" in file_obj and file_obj["data"] is not None:
98
+ suffix = ".nii.gz"
99
+ tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
100
+ tmp.write(file_obj["data"])
101
+ tmp.flush()
102
+ tmp.close()
103
+ return tmp.name
104
+
105
+ # Case 2: tempfile-like with .name
106
  if hasattr(file_obj, "name"):
107
  return file_obj.name
108
+
109
+ # Case 3: already a string path (local testing)
110
  if isinstance(file_obj, str):
111
  return file_obj
112
+
113
  raise ValueError(f"Unsupported file object type: {type(file_obj)}")
114
 
115
 
116
def _error_image(msg: str):
    """
    Render *msg* onto a small matplotlib figure and return it as a numpy
    image array, so the Gradio output pane always shows feedback instead
    of appearing empty when something goes wrong.
    """
    fig, ax = plt.subplots(figsize=(8, 3))
    ax.text(0.5, 0.5, msg, ha="center", va="center", color="red", fontsize=12)
    ax.axis("off")

    # Round-trip through an in-memory PNG to obtain a plain ndarray.
    buffer = BytesIO()
    fig.savefig(buffer, format="png", bbox_inches="tight")
    plt.close(fig)
    buffer.seek(0)
    return np.array(Image.open(buffer))
130
+
131
+
132
+ # ----------------- INFERENCE -----------------
133
  def segment_liver(file_obj, slice_num=64):
134
  try:
135
  if file_obj is None:
136
+ return _error_image("No file uploaded."), None
137
 
138
  file_path = _get_path_from_gradio_file(file_obj)
139
+ print(f"[segment_liver] file_path = {file_path}")
140
 
141
  if file_path is None or not os.path.exists(file_path):
142
+ raise FileNotFoundError("Uploaded file path not found on server.")
143
 
144
+ # Manual extension check
145
  if not (file_path.endswith(".nii") or file_path.endswith(".nii.gz")):
146
  raise ValueError("Invalid file type. Please upload a .nii or .nii.gz NIfTI file.")
147
 
148
  # Preprocess
149
  data_dict = {"image": file_path}
150
  data_dict = test_transforms(data_dict)
151
+ volume = data_dict["image"].unsqueeze(0).to(device) # [1, 1, H, W, D]
152
+ print(f"[segment_liver] preprocessed volume shape: {volume.shape}")
153
 
154
  # Inference
155
  with torch.no_grad():
 
160
  predictor=model,
161
  overlap=0.25,
162
  )
163
+ pred = torch.argmax(outputs, dim=1).float() # [1, H, W, D]
164
 
165
  vol_np = volume[0, 0].cpu().numpy()
166
  pred_np = pred[0].cpu().numpy()
 
168
  # Normalize CT for display
169
  vol_display = (vol_np - vol_np.min()) / (vol_np.max() - vol_np.min() + 1e-8)
170
 
171
+ # Handle slice index safely
172
  z_dim = vol_np.shape[2]
173
  idx = int(slice_num)
174
+ if idx < 0 or idx >= z_dim:
 
 
175
  idx = z_dim // 2
176
 
177
+ # Plot CT / mask / overlay
178
  fig, axes = plt.subplots(1, 3, figsize=(15, 5))
179
 
180
  axes[0].imshow(vol_display[:, :, idx], cmap="gray")
181
+ axes[0].set_title("CT Slice")
182
  axes[0].axis("off")
183
 
184
  axes[1].imshow(pred_np[:, :, idx], cmap="Reds", vmin=0, vmax=1)
185
+ axes[1].set_title("Predicted Liver Mask")
186
  axes[1].axis("off")
187
 
188
  axes[2].imshow(vol_display[:, :, idx], cmap="gray")
 
192
 
193
  plt.tight_layout()
194
 
195
+ # Convert figure to numpy image
196
  buf = BytesIO()
197
  fig.savefig(buf, format="png", bbox_inches="tight")
198
  buf.seek(0)
199
  img = np.array(Image.open(buf))
200
  plt.close(fig)
201
 
202
+ # Save prediction mask as NIfTI for download
203
  pred_nii = nib.Nifti1Image(pred_np.astype(np.uint8), np.eye(4))
204
  out_path = tempfile.mktemp(suffix=".nii.gz")
205
  nib.save(pred_nii, out_path)
206
 
207
+ print("[segment_liver] success.")
208
  return img, out_path
209
 
210
  except Exception as e:
 
211
  import traceback
212
+
213
+ print("[segment_liver] ERROR:", e)
214
  traceback.print_exc()
215
+ return _error_image(f"Error: {e}"), None
216
 
217
 
218
+ # ----------------- GRADIO UI -----------------
219
  iface = gr.Interface(
220
  fn=segment_liver,
221
  inputs=[
 
231
  )
232
 
233
if __name__ == "__main__":
    # On HF Spaces you can pin the host/port explicitly:
    #   iface.launch(server_name="0.0.0.0", server_port=7860)
    iface.launch()