drankush-ai commited on
Commit
5638d3f
·
verified ·
1 Parent(s): 9d8ce80

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +157 -195
app.py CHANGED
@@ -11,195 +11,129 @@ from fastai.learner import load_learner
11
  from fastai.basics import load_pickle
12
  import pickle
13
  import traceback
 
14
 
15
- # -----------------------------
16
- # Utility functions
17
- # -----------------------------
18
-
19
  def extract_slices_from_mask(img, mask_data, view):
20
  """Extract and resize slices from the 3D [W, H, D] image and mask data based on the selected view."""
21
  slices = []
22
  target_size = (320, 320)
23
- for idx in range(img.shape[6] if view == "Sagittal" else img.shape[21] if view == "Axial" else img.shape):
 
24
  if view == "Sagittal":
25
  slice_img, slice_mask = img[:, :, idx], mask_data[:, :, idx]
26
  elif view == "Axial":
27
  slice_img, slice_mask = img[:, idx, :], mask_data[:, idx, :]
28
  elif view == "Coronal":
29
  slice_img, slice_mask = img[idx, :, :], mask_data[idx, :, :]
 
30
  slice_img = np.fliplr(np.rot90(slice_img, -1))
31
  slice_mask = np.fliplr(np.rot90(slice_mask, -1))
 
32
  slice_img_resized, slice_mask_resized = resize_and_pad(slice_img, slice_mask, target_size)
33
  slices.append((slice_img_resized, slice_mask_resized))
 
34
  return slices
35
 
 
36
  def resize_and_pad(slice_img, slice_mask, target_size):
37
  """Resize and pad the image and mask to fit the target size while maintaining the aspect ratio."""
38
  h, w = slice_img.shape
39
- scale = min(target_size / w, target_size[21] / h)
40
  new_w, new_h = int(w * scale), int(h * scale)
 
41
  resized_img = cv2.resize(slice_img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
42
  resized_mask = cv2.resize(slice_mask, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
43
- pad_w = (target_size - new_w) // 2
44
- pad_h = (target_size[21] - new_h) // 2
45
- padded_img = np.pad(resized_img, ((pad_h, target_size[21] - new_h - pad_h), (pad_w, target_size - new_w - pad_w)),
46
- mode='constant', constant_values=0)
47
- padded_mask = np.pad(resized_mask, ((pad_h, target_size[21] - new_h - pad_h), (pad_w, target_size - new_w - pad_w)),
48
- mode='constant', constant_values=0)
 
49
  return padded_img, padded_mask
50
 
 
51
  def normalize_image(slice_img):
52
  """Normalize the image to the range [0, 255] safely."""
53
  slice_img_min, slice_img_max = slice_img.min(), slice_img.max()
54
- if slice_img_min == slice_img_max:
55
  return np.zeros_like(slice_img, dtype=np.uint8)
56
  normalized_img = (slice_img - slice_img_min) / (slice_img_max - slice_img_min) * 255
57
  return normalized_img.astype(np.uint8)
58
 
 
59
  def get_fused_image(img, pred_mask, view, alpha=0.8):
60
  """Fuse a grayscale image with a mask overlay and flip both horizontally and vertically."""
61
  gray_img_colored = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
62
  mask_color = np.array([255, 0, 0])
63
  colored_mask = (pred_mask[..., None] * mask_color).astype(np.uint8)
 
64
  fused = cv2.addWeighted(gray_img_colored, alpha, colored_mask, 1 - alpha, 0)
65
- fused_flipped = cv2.flip(fused, -1)
 
 
 
66
  if view == 'Sagittal':
67
  return fused_flipped
68
  elif view == 'Coronal' or view == 'Axial':
69
  rotated = cv2.flip(cv2.rotate(fused, cv2.ROTATE_90_COUNTERCLOCKWISE), 1)
70
  return rotated
71
 
72
- # -----------------------------
73
- # Robust ZNormalization patch
74
- # -----------------------------
75
-
76
- def apply_znorm_compat_patches(learn):
77
- """
78
- Make ZNormalization resilient to missing attributes on unpickled transforms and
79
- repair existing instances inside the learner pipelines.
80
- """
81
- # Import here to avoid import-time side effects before libs are available
82
- from fastMONAI.vision_augmentation import ZNormalization, _do_z_normalization
83
- from fastMONAI.vision_core import MedImage
84
-
85
- # Robust encodes that tolerates missing attrs on older pickles
86
- def _robust_znorm_encodes(self, o: (MedImage)):
87
- masking_method = getattr(self, 'masking_method', None)
88
- channel_wise = getattr(self, 'channel_wise', True)
89
- return MedImage.create(_do_z_normalization(o, masking_method, channel_wise))
90
-
91
- # Monkey-patch class method so any new instance uses the robust encodes
92
- ZNormalization.encodes = _robust_znorm_encodes
93
-
94
- # Ensure attributes exist on already loaded instances inside learner pipelines
95
- def _ensure_attrs_on_pipeline(pipeline):
96
- if not hasattr(pipeline, 'fs'):
97
- return
98
- for tfm in pipeline.fs:
99
- if isinstance(tfm, ZNormalization):
100
- if not hasattr(tfm, 'masking_method'):
101
- tfm.masking_method = None
102
- if not hasattr(tfm, 'channel_wise'):
103
- tfm.channel_wise = True
104
-
105
- for p in (getattr(learn.dls, 'after_item', None),
106
- getattr(learn.dls, 'before_batch', None),
107
- getattr(learn.dls, 'after_batch', None)):
108
- if p is not None:
109
- _ensure_attrs_on_pipeline(p)
110
-
111
- # -----------------------------
112
- # Inference
113
- # -----------------------------
114
-
115
- def _direct_model_inference(learn, x_tensor):
116
- """
117
- Direct model inference bypassing fastai's predict() in case transforms fail.
118
- Applies a simple channel-wise Z-normalization over non-zero voxels.
119
- """
120
- with torch.no_grad():
121
- x = x_tensor
122
- if x.dim() == 4:
123
- x = x.unsqueeze(0) # [1, C, W, H, D]
124
- x_ = x.clone()
125
- # Channel-wise z-score over non-zero voxels
126
- for c in range(x_.shape[21]):
127
- vol = x_[:, c]
128
- nz = vol != 0
129
- if nz.any():
130
- m = vol[nz].mean()
131
- s = vol[nz].std()
132
- if s > 0:
133
- vol[nz] = (vol[nz] - m) / s
134
- else:
135
- vol[nz] = 0
136
- x_[:, c] = vol
137
- learn.model.eval()
138
- logits = learn.model(x_) # [B, classes, W, H, D]
139
- pred = torch.argmax(logits, dim=1) # [B, W, H, D]
140
- return pred # [B, W, H, D]
141
-
142
  def inference(learn, reorder, resample, org_img, input_img, org_size):
143
- """Perform segmentation using the loaded model, with robust fallback."""
 
144
  if not isinstance(input_img, torch.Tensor):
145
  raise ValueError(f"Expected input_img to be a torch.Tensor, but got {type(input_img)}")
146
- # Try the standard fastai predict path first (applies the learner's transforms)
147
- # If transforms are not compatible due to pickled attribute differences, fallback to direct model inference.
148
- try:
149
- with torch.no_grad():
150
- pred = learn.predict(input_img) # returns (decoded, raw, probs/...), depending on setup
151
- # Extract mask-like output
152
- mask_data = pred if isinstance(pred, (list, tuple)) else pred
153
- if isinstance(mask_data, torch.Tensor):
154
- if mask_data.dim() == 3:
155
- mask_data = mask_data.unsqueeze(0) # [1, W, H, D]
156
- elif mask_data.dim() == 4 and mask_data.shape != 1:
157
- # If [B, W, H, D] with B>1, keep as is; downstream expects batch dim
158
- pass
159
- return mask_data
160
- # If not a tensor, try to coerce, else fallback
161
- except Exception as e:
162
- print(f"[WARN] learn.predict failed, falling back to direct model inference. Reason: {e}")
163
-
164
- # Fallback: call the model directly on the tensor
165
- pred = _direct_model_inference(learn, input_img) # [B, W, H, D]
166
- return pred
167
-
168
- # -----------------------------
169
- # App prediction wrapper
170
- # -----------------------------
171
-
172
  def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view):
173
  """Predict function using the learner and other resources."""
 
174
  if view is None:
175
  view = 'Sagittal'
176
-
177
  img_path = Path(fileobj.name)
178
  save_fn = 'pred_' + img_path.stem
179
  save_path = save_dir / save_fn
180
-
181
- # Read medical image with robust handling of return formats
182
  try:
183
- result = med_img_reader(str(img_path),
184
- reorder=reorder,
185
- resample=resample,
186
- only_tensor=False,
187
- dtype=torch.Tensor)
 
 
188
  if isinstance(result, tuple):
189
  if len(result) == 3:
190
  org_img, input_img, org_size = result
191
- # Depending on fastMONAI version, input_img may be a ScalarImage; if so, extract tensor
192
- if hasattr(input_img, 'data') and not isinstance(input_img, torch.Tensor):
193
- input_img = input_img.data.type(torch.float)
194
  elif len(result) == 2:
195
- # Some versions return (org_img, metadata_dict)
196
  org_img, metadata_dict = result
197
- input_img = org_img.data.type(torch.float) if hasattr(org_img, 'data') else org_img
198
- org_size = org_img.shape[1:] if hasattr(org_img, 'shape') else None
 
199
  else:
200
  raise ValueError(f"Unexpected number of return values: {len(result)}")
201
  else:
 
202
  raise ValueError(f"Expected tuple but got {type(result)}")
 
203
  except Exception as e:
204
  print(f"DEBUG: Error in med_img_reader: {str(e)}")
205
  if 'result' in locals():
@@ -207,48 +141,40 @@ def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view)
207
  if hasattr(result, '__len__'):
208
  print(f"DEBUG: result length: {len(result)}")
209
  raise ValueError(f"Error reading medical image: {str(e)}")
210
-
211
- # Ensure tensor with appropriate dims [C, W, H, D] or [1, C, W, H, D]
212
  if not isinstance(input_img, torch.Tensor):
213
  raise ValueError(f"Expected input_img to be a torch.Tensor, but got {type(input_img)}")
214
- if input_img.dim() == 4:
215
- x_in = input_img.unsqueeze(0) # [1, C, W, H, D]
216
- else:
217
- x_in = input_img
218
-
219
- # Perform inference (predict or direct fallback)
220
  mask_data = inference(learn, reorder=reorder, resample=resample,
221
- org_img=org_img, input_img=x_in,
222
  org_size=org_size)
223
-
224
- # Optional orientation handling (example logic retained)
225
- if hasattr(org_img, 'orientation') and "".join(org_img.orientation) == "LSA":
226
- mask_data = mask_data.permute(0, 1, 3, 2)
227
- mask_data = torch.flip(mask_data, dims=[21])
228
  mask_data = torch.Tensor(mask_data)[None]
229
 
230
- # Save the prediction as a TorchIO image using original image object
231
- img = org_img.data if hasattr(org_img, 'data') else None
232
  org_img.set_data(mask_data)
233
  org_img.save(save_path)
234
 
235
- # Build fused gallery
236
- base_img_np = img.numpy() if img is not None else input_img.numpy()
237
- pred_np = mask_data.detach().cpu().numpy().astype(np.uint8)
238
- slices = extract_slices_from_mask(base_img_np, pred_np, view)
239
- fused_images = [
240
- get_fused_image(normalize_image(slice_img), slice_mask, view)
241
- for slice_img, slice_mask in slices
242
- ]
243
-
244
  # Compute volume
245
  volume = compute_binary_tumor_volume(org_img)
246
- return fused_images, round(volume, 2)
247
 
248
- # -----------------------------
249
- # Resources
250
- # -----------------------------
251
 
 
252
  def load_system_resources(models_path, learner_fn='heart_model.pkl', variables_fn='vars.pkl'):
253
  """Load the model and other required resources."""
254
  try:
@@ -260,44 +186,33 @@ def load_system_resources(models_path, learner_fn='heart_model.pkl', variables_f
260
  try:
261
  with open(models_path / variables_fn, 'rb') as f:
262
  variables = pickle.load(f)
263
-
264
- # Accept dict or list/tuple; be defensive against shape changes
265
- shape = None
266
- if isinstance(variables, dict):
267
- reorder = bool(variables.get('reorder', False))
268
- resample = variables.get('resample', [1.25, 1.25, 1.37])
269
- elif isinstance(variables, (list, tuple)):
270
- if len(variables) >= 3:
271
- shape, reorder, resample = variables, variables[19], variables[20]
272
- elif len(variables) == 2:
273
- reorder, resample = variables, variables[19]
274
- else:
275
- raise ValueError(f"vars.pkl has unexpected length: {len(variables)}")
276
- else:
277
- raise ValueError(f"vars.pkl has unexpected type: {type(variables)}")
278
-
279
- # Validate types strictly
280
  if not isinstance(reorder, bool):
281
- raise ValueError(f"vars.pkl 'reorder' must be bool, got {type(reorder)}")
282
- if not (isinstance(resample, (list, tuple)) and len(resample) == 3):
283
- raise ValueError(f"vars.pkl 'resample' must be a list/tuple of len 3, got {resample}")
284
- resample = list(resample)
285
-
286
  except Exception as e:
287
  raise ValueError(f"Error loading variables: {str(e)}")
288
 
289
  return learn, reorder, resample
290
 
291
-
292
- # -----------------------------
293
- # Main
294
- # -----------------------------
295
-
296
  try:
297
  print("🚀 Initializing application...")
298
-
299
  clone_dir = Path.cwd() / 'clone_dir'
300
  URI = os.getenv('PAT_Token_URI')
 
301
  if not URI:
302
  raise ValueError("PAT_Token_URI environment variable is not set")
303
 
@@ -307,26 +222,72 @@ try:
307
  print(f"📥 Cloning repository from {URI}")
308
  Repo.clone_from(URI, clone_dir)
309
  print(f"✅ Repository cloned to {clone_dir}")
310
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
  models_path = clone_dir
312
  save_dir = Path.cwd() / 'hs_pred'
313
  save_dir.mkdir(parents=True, exist_ok=True)
314
  print(f"✅ Save directory created at {save_dir}")
315
 
316
- # Load model and variables
317
  print("🔍 Loading model and resources...")
318
  learn, reorder, resample = load_system_resources(models_path=models_path)
319
  print(f"✅ Model loaded successfully")
320
  print(f"✅ Reorder: {reorder}, Resample: {resample}")
321
 
322
- # Apply robust ZNormalization compatibility patches BEFORE any predict is called
323
- apply_znorm_compat_patches(learn)
324
- print("✅ Applied ZNormalization compatibility patches")
325
-
326
- # Gradio UI
327
  output_text = gr.Textbox(label="Volume of the Left Atrium (mL):")
328
  view_selector = gr.Radio(choices=["Axial", "Coronal", "Sagittal"], value='Sagittal', label="Select View (Sagittal by default)")
329
 
 
330
  example_path = str(clone_dir / "sample.nii.gz")
331
  if not os.path.exists(example_path):
332
  print(f"⚠️ Example file not found: {example_path}")
@@ -340,17 +301,18 @@ try:
340
  inputs=["file", view_selector],
341
  outputs=[gr.Gallery(label="Click an Image, and use Arrow Keys to scroll slices", columns=3, height=450), output_text],
342
  examples=examples,
343
- allow_flagging='never',
344
- cache_examples=False # avoid eager caching running predict at startup
345
- )
346
-
347
  print("✅ Gradio interface initialized successfully")
348
-
349
  except Exception as e:
350
  print(f"❌ Error during initialization: {str(e)}")
351
  print(f"Error type: {type(e).__name__}")
352
  traceback.print_exc()
 
353
  exit(1)
354
 
 
355
  print("🌐 Launching Gradio interface...")
356
- demo.launch()
 
11
  from fastai.basics import load_pickle
12
  import pickle
13
  import traceback
14
+ import torchio as tio # Added for the patch
15
 
16
+ # Function to extract slices from mask
 
 
 
17
def extract_slices_from_mask(img, mask_data, view):
    """Extract and resize slices from the 3D [W, H, D] image and mask data based on the selected view."""
    target_size = (320, 320)
    slices = []

    # Number of slices along the axis that corresponds to the requested view
    # (any unrecognized view falls back to the Coronal axis count).
    if view == "Sagittal":
        n_slices = img.shape[2]
    elif view == "Axial":
        n_slices = img.shape[1]
    else:
        n_slices = img.shape[0]

    for idx in range(n_slices):
        if view == "Sagittal":
            raw_img, raw_mask = img[:, :, idx], mask_data[:, :, idx]
        elif view == "Axial":
            raw_img, raw_mask = img[:, idx, :], mask_data[:, idx, :]
        elif view == "Coronal":
            raw_img, raw_mask = img[idx, :, :], mask_data[idx, :, :]

        # Match the display orientation: rotate clockwise once, then mirror left-right.
        oriented_img = np.fliplr(np.rot90(raw_img, -1))
        oriented_mask = np.fliplr(np.rot90(raw_mask, -1))

        slices.append(resize_and_pad(oriented_img, oriented_mask, target_size))

    return slices
37
 
38
+ # Function to resize and pad slices
39
def resize_and_pad(slice_img, slice_mask, target_size):
    """Resize and pad the image and mask to fit the target size while maintaining the aspect ratio."""
    h, w = slice_img.shape
    target_w, target_h = target_size[0], target_size[1]

    # Uniform scale so the resized slice fits inside the target box.
    scale = min(target_w / w, target_h / h)
    new_w = int(w * scale)
    new_h = int(h * scale)

    # Bilinear for intensities; nearest-neighbor keeps mask labels crisp.
    img_small = cv2.resize(slice_img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    mask_small = cv2.resize(slice_mask, (new_w, new_h), interpolation=cv2.INTER_NEAREST)

    # Center the resized slice by splitting the leftover space on each side.
    left = (target_w - new_w) // 2
    top = (target_h - new_h) // 2
    pad_spec = ((top, target_h - new_h - top), (left, target_w - new_w - left))

    padded_img = np.pad(img_small, pad_spec, mode='constant', constant_values=0)
    padded_mask = np.pad(mask_small, pad_spec, mode='constant', constant_values=0)

    return padded_img, padded_mask
55
 
56
+ # Function to normalize image
57
def normalize_image(slice_img):
    """Normalize the image to the range [0, 255] safely and return uint8."""
    lo = slice_img.min()
    hi = slice_img.max()

    # A constant image would divide by zero; map it to all zeros instead.
    if lo == hi:
        return np.zeros_like(slice_img, dtype=np.uint8)

    scaled = (slice_img - lo) / (hi - lo) * 255
    return scaled.astype(np.uint8)
64
 
65
+ # Function to get fused image
66
def get_fused_image(img, pred_mask, view, alpha=0.8):
    """Fuse a grayscale slice with a mask overlay and orient it for the requested view.

    Parameters:
        img: 2D uint8 grayscale slice.
        pred_mask: 2D mask aligned with img (non-zero where segmented).
        view: 'Sagittal', 'Coronal' or 'Axial'.
        alpha: weight of the grayscale image in the blend; the mask gets 1 - alpha.

    Returns:
        The blended BGR image, flipped/rotated to match the viewer orientation.

    Raises:
        ValueError: for an unrecognized view (previously this fell through and
        silently returned None, which broke the gallery downstream).
    """
    gray_img_colored = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    # Mask is painted into channel 0 — NOTE(review): cv2 images are BGR, so
    # [255, 0, 0] renders blue; confirm whether red was intended.
    mask_color = np.array([255, 0, 0])
    colored_mask = (pred_mask[..., None] * mask_color).astype(np.uint8)

    fused = cv2.addWeighted(gray_img_colored, alpha, colored_mask, 1 - alpha, 0)

    if view == 'Sagittal':
        # Flip both vertically and horizontally.
        return cv2.flip(fused, -1)
    if view in ('Coronal', 'Axial'):
        # Rotate 90° counter-clockwise, then mirror horizontally.
        return cv2.flip(cv2.rotate(fused, cv2.ROTATE_90_COUNTERCLOCKWISE), 1)
    raise ValueError(f"Unknown view: {view!r}")
82
 
83
+ # Define the inference function
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
def inference(learn, reorder, resample, org_img, input_img, org_size):
    """Perform segmentation using the loaded model and return the predicted mask.

    reorder/resample/org_img/org_size are accepted for interface compatibility
    but are not used by this implementation.
    """
    if not isinstance(input_img, torch.Tensor):
        raise ValueError(f"Expected input_img to be a torch.Tensor, but got {type(input_img)}")

    # A 4-D [C, W, H, D] volume gets a leading batch dimension.
    x = input_img.unsqueeze(0) if input_img.dim() == 4 else input_img

    # Inference only — no gradients needed.
    with torch.no_grad():
        pred = learn.predict(x)

    # predict may return a (decoded, ...) tuple; keep the first element.
    if isinstance(pred, (list, tuple)):
        return pred[0]
    return pred
102
+
103
+ # Function for Gradio image segmentation
 
 
 
 
 
 
 
 
 
 
 
104
def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view):
    """Predict function using the learner and other resources.

    Reads the uploaded medical image, runs the segmentation model, writes the
    predicted mask under save_dir, and returns (fused slice images, volume)
    for the Gradio gallery and textbox outputs.
    """

    # Default to the sagittal view when the radio selector sent nothing.
    if view is None:
        view = 'Sagittal'

    img_path = Path(fileobj.name)
    save_fn = 'pred_' + img_path.stem
    save_path = save_dir / save_fn

    # Read the medical image - handle variable return values
    try:
        result = med_img_reader(str(img_path),
                                reorder=reorder,
                                resample=resample,
                                only_tensor=False,
                                dtype=torch.Tensor)

        # Handle different return formats across fastMONAI versions.
        if isinstance(result, tuple):
            if len(result) == 3:
                org_img, input_img, org_size = result
            elif len(result) == 2:
                org_img, metadata_dict = result
                # Extract the tensor from the ScalarImage object
                input_img = org_img.data  # This is the actual tensor!
                org_size = org_img.shape[1:]  # Infer from org_img
            else:
                raise ValueError(f"Unexpected number of return values: {len(result)}")
        else:
            # Single value returned (shouldn't happen with only_tensor=False)
            raise ValueError(f"Expected tuple but got {type(result)}")

    except Exception as e:
        # Surface read failures with some debug context before re-raising.
        print(f"DEBUG: Error in med_img_reader: {str(e)}")
        if 'result' in locals():
            if hasattr(result, '__len__'):
                print(f"DEBUG: result length: {len(result)}")
        raise ValueError(f"Error reading medical image: {str(e)}")

    # Ensure input_img is a torch.Tensor before handing it to the model.
    if not isinstance(input_img, torch.Tensor):
        raise ValueError(f"Expected input_img to be a torch.Tensor, but got {type(input_img)}")

    # Perform inference
    mask_data = inference(learn, reorder=reorder, resample=resample,
                          org_img=org_img, input_img=input_img,
                          org_size=org_size)

    # Handle orientation if needed.
    # NOTE(review): the LSA branch flips mask_data[0] (dropping the batch dim)
    # before [None] re-adds it below — confirm this is the intended shape flow.
    if hasattr(org_img, 'orientation') and "".join(org_img.orientation) == "LSA":
        mask_data = mask_data.permute(0,1,3,2)
        mask_data = torch.flip(mask_data[0], dims=[1])
    mask_data = torch.Tensor(mask_data)[None]

    # Save the prediction: keep the original voxels before overwriting
    # org_img's data with the mask for saving.
    img = org_img.data
    org_img.set_data(mask_data)
    org_img.save(save_path)

    # Extract slices and create fused images for the gallery.
    slices = extract_slices_from_mask(img[0].numpy(), mask_data[0].numpy(), view)
    fused_images = [(get_fused_image(
        normalize_image(slice_img),
        slice_mask, view))
        for slice_img, slice_mask in slices]

    # Compute volume (mL) from the saved mask image.
    volume = compute_binary_tumor_volume(org_img)

    return fused_images, round(volume, 2)
 
 
176
 
177
+ # Function to load system resources
178
  def load_system_resources(models_path, learner_fn='heart_model.pkl', variables_fn='vars.pkl'):
179
  """Load the model and other required resources."""
180
  try:
 
186
  try:
187
  with open(models_path / variables_fn, 'rb') as f:
188
  variables = pickle.load(f)
189
+
190
+ if not isinstance(variables, list) or len(variables) != 3:
191
+ raise ValueError(f"vars.pkl does not contain the expected list format. Found: {variables}")
192
+
193
+ # Assuming the format is [shape, reorder, resample]
194
+ shape = variables[0]
195
+ reorder = variables[1]
196
+ resample = variables[2]
197
+
 
 
 
 
 
 
 
 
198
  if not isinstance(reorder, bool):
199
+ raise ValueError(f"vars.pkl does not contain a valid 'reorder' value. Found: {reorder}")
200
+
201
+ if not isinstance(resample, list) or len(resample) != 3:
202
+ raise ValueError(f"vars.pkl does not contain a valid 'resample' value. Found: {resample}")
203
+
204
  except Exception as e:
205
  raise ValueError(f"Error loading variables: {str(e)}")
206
 
207
  return learn, reorder, resample
208
 
209
+ # Initialize the system
 
 
 
 
210
  try:
211
  print("🚀 Initializing application...")
212
+
213
  clone_dir = Path.cwd() / 'clone_dir'
214
  URI = os.getenv('PAT_Token_URI')
215
+
216
  if not URI:
217
  raise ValueError("PAT_Token_URI environment variable is not set")
218
 
 
222
  print(f"📥 Cloning repository from {URI}")
223
  Repo.clone_from(URI, clone_dir)
224
  print(f"✅ Repository cloned to {clone_dir}")
225
+
226
+ # ===== MONKEY PATCH FOR ZNORMALIZATION ERROR =====
227
+ # This patch fixes the 'ZNormalization' object has no attribute 'masking_method' error
228
+ # Import necessary modules
229
+ from fastMONAI.vision_augmentation import ZNormalization
230
+ from fastMONAI.vision_core import MedImage
231
+
232
+ # Create a patched version of the encodes method
233
+ def patched_encodes(self, o):
234
+ """Patched version of ZNormalization.encodes to handle attribute mismatch."""
235
+ # Handle the attribute mismatch - ensure we have what we need
236
+ if not hasattr(self, 'z_normalization'):
237
+ # Get masking_method with fallback
238
+ masking_method = getattr(self, 'masking_method', None)
239
+ self.z_normalization = tio.ZNormalization(masking_method=masking_method)
240
+
241
+ # Ensure channel_wise exists
242
+ if not hasattr(self, 'channel_wise'):
243
+ self.channel_wise = True
244
+
245
+ # Use the current implementation logic
246
+ try:
247
+ if self.channel_wise:
248
+ o = torch.stack([self.z_normalization(c[None])[0] for c in o])
249
+ else:
250
+ o = self.z_normalization(o)
251
+ except RuntimeError as e:
252
+ if "Standard deviation is 0" in str(e):
253
+ mean = float(o.mean())
254
+ error_msg = (
255
+ f"Standard deviation is 0 for image (mean={mean:.3f}).\n"
256
+ f"This indicates uniform pixel values.\n\n"
257
+ f"Possible causes:\n"
258
+ f"• Corrupted or blank image\n"
259
+ f"• Oversaturated regions\n"
260
+ f"• Background-only regions\n"
261
+ f"• All-zero mask being processed as image\n\n"
262
+ f"Suggested solutions:\n"
263
+ f"• Check image quality and acquisition\n"
264
+ f"• Verify image vs mask data loading"
265
+ )
266
+ raise RuntimeError(error_msg) from e
267
+
268
+ return MedImage.create(o)
269
+
270
+ # Replace the problematic method
271
+ ZNormalization.encodes = patched_encodes
272
+ print("✅ Applied ZNormalization patch to fix attribute error")
273
+ # ===== END OF MONKEY PATCH =====
274
+
275
  models_path = clone_dir
276
  save_dir = Path.cwd() / 'hs_pred'
277
  save_dir.mkdir(parents=True, exist_ok=True)
278
  print(f"✅ Save directory created at {save_dir}")
279
 
280
+ # Load the model and other required resources
281
  print("🔍 Loading model and resources...")
282
  learn, reorder, resample = load_system_resources(models_path=models_path)
283
  print(f"✅ Model loaded successfully")
284
  print(f"✅ Reorder: {reorder}, Resample: {resample}")
285
 
286
+ # Gradio interface setup
 
 
 
 
287
  output_text = gr.Textbox(label="Volume of the Left Atrium (mL):")
288
  view_selector = gr.Radio(choices=["Axial", "Coronal", "Sagittal"], value='Sagittal', label="Select View (Sagittal by default)")
289
 
290
+ # Ensure the example file path is correct
291
  example_path = str(clone_dir / "sample.nii.gz")
292
  if not os.path.exists(example_path):
293
  print(f"⚠️ Example file not found: {example_path}")
 
301
  inputs=["file", view_selector],
302
  outputs=[gr.Gallery(label="Click an Image, and use Arrow Keys to scroll slices", columns=3, height=450), output_text],
303
  examples=examples,
304
+ cache_examples=False, # Disable example caching to prevent errors during startup
305
+ allow_flagging='never')
306
+
 
307
  print("✅ Gradio interface initialized successfully")
308
+
309
  except Exception as e:
310
  print(f"❌ Error during initialization: {str(e)}")
311
  print(f"Error type: {type(e).__name__}")
312
  traceback.print_exc()
313
+ # Exit with error code
314
  exit(1)
315
 
316
+ # Launch the Gradio interface
317
  print("🌐 Launching Gradio interface...")
318
+ demo.launch()