drankush-ai commited on
Commit
2fecb9a
·
verified ·
1 Parent(s): 87d8e32

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +174 -101
app.py CHANGED
@@ -12,127 +12,194 @@ from fastai.basics import load_pickle
12
  import pickle
13
  import traceback
14
 
15
- # Function to extract slices from mask
 
 
 
16
  def extract_slices_from_mask(img, mask_data, view):
17
  """Extract and resize slices from the 3D [W, H, D] image and mask data based on the selected view."""
18
  slices = []
19
  target_size = (320, 320)
20
-
21
- for idx in range(img.shape[2] if view == "Sagittal" else img.shape[1] if view == "Axial" else img.shape[0]):
22
  if view == "Sagittal":
23
  slice_img, slice_mask = img[:, :, idx], mask_data[:, :, idx]
24
  elif view == "Axial":
25
  slice_img, slice_mask = img[:, idx, :], mask_data[:, idx, :]
26
  elif view == "Coronal":
27
  slice_img, slice_mask = img[idx, :, :], mask_data[idx, :, :]
28
-
29
  slice_img = np.fliplr(np.rot90(slice_img, -1))
30
  slice_mask = np.fliplr(np.rot90(slice_mask, -1))
31
-
32
  slice_img_resized, slice_mask_resized = resize_and_pad(slice_img, slice_mask, target_size)
33
  slices.append((slice_img_resized, slice_mask_resized))
34
-
35
  return slices
36
 
37
- # Function to resize and pad slices
38
  def resize_and_pad(slice_img, slice_mask, target_size):
39
  """Resize and pad the image and mask to fit the target size while maintaining the aspect ratio."""
40
  h, w = slice_img.shape
41
- scale = min(target_size[0] / w, target_size[1] / h)
42
  new_w, new_h = int(w * scale), int(h * scale)
43
-
44
  resized_img = cv2.resize(slice_img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
45
  resized_mask = cv2.resize(slice_mask, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
46
-
47
- pad_w = (target_size[0] - new_w) // 2
48
- pad_h = (target_size[1] - new_h) // 2
49
-
50
- padded_img = np.pad(resized_img, ((pad_h, target_size[1] - new_h - pad_h), (pad_w, target_size[0] - new_w - pad_w)), mode='constant', constant_values=0)
51
- padded_mask = np.pad(resized_mask, ((pad_h, target_size[1] - new_h - pad_h), (pad_w, target_size[0] - new_w - pad_w)), mode='constant', constant_values=0)
52
-
53
  return padded_img, padded_mask
54
 
55
- # Function to normalize image
56
  def normalize_image(slice_img):
57
  """Normalize the image to the range [0, 255] safely."""
58
  slice_img_min, slice_img_max = slice_img.min(), slice_img.max()
59
- if slice_img_min == slice_img_max: # Avoid division by zero
60
  return np.zeros_like(slice_img, dtype=np.uint8)
61
  normalized_img = (slice_img - slice_img_min) / (slice_img_max - slice_img_min) * 255
62
  return normalized_img.astype(np.uint8)
63
 
64
- # Function to get fused image
65
  def get_fused_image(img, pred_mask, view, alpha=0.8):
66
  """Fuse a grayscale image with a mask overlay and flip both horizontally and vertically."""
67
  gray_img_colored = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
68
  mask_color = np.array([255, 0, 0])
69
  colored_mask = (pred_mask[..., None] * mask_color).astype(np.uint8)
70
-
71
  fused = cv2.addWeighted(gray_img_colored, alpha, colored_mask, 1 - alpha, 0)
72
-
73
- # Flip the fused image vertically and horizontally
74
- fused_flipped = cv2.flip(fused, -1) # Flip both vertically and horizontally
75
-
76
  if view == 'Sagittal':
77
  return fused_flipped
78
  elif view == 'Coronal' or view == 'Axial':
79
  rotated = cv2.flip(cv2.rotate(fused, cv2.ROTATE_90_COUNTERCLOCKWISE), 1)
80
  return rotated
81
 
82
- # Define the inference function
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  def inference(learn, reorder, resample, org_img, input_img, org_size):
84
- """Perform segmentation using the loaded model."""
85
- # Ensure input_img is a torch.Tensor
86
  if not isinstance(input_img, torch.Tensor):
87
  raise ValueError(f"Expected input_img to be a torch.Tensor, but got {type(input_img)}")
88
-
89
- # Add batch dimension if needed
90
- if input_img.dim() == 4:
91
- input_img = input_img.unsqueeze(0)
92
-
93
- # Perform the segmentation
94
- with torch.no_grad():
95
- pred = learn.predict(input_img)
96
-
97
- # Process the prediction if necessary
98
- mask_data = pred[0] if isinstance(pred, (list, tuple)) else pred
99
-
100
- return mask_data
101
-
102
- # Function for Gradio image segmentation
 
 
 
 
 
 
 
 
 
 
 
103
  def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view):
104
  """Predict function using the learner and other resources."""
105
-
106
  if view is None:
107
  view = 'Sagittal'
108
-
109
  img_path = Path(fileobj.name)
110
  save_fn = 'pred_' + img_path.stem
111
  save_path = save_dir / save_fn
112
-
113
- # Read the medical image - handle variable return values
114
  try:
115
- result = med_img_reader(str(img_path),
116
- reorder=reorder,
117
- resample=resample,
118
- only_tensor=False,
119
- dtype=torch.Tensor)
120
-
121
- # Handle different return formats
122
  if isinstance(result, tuple):
123
  if len(result) == 3:
124
  org_img, input_img, org_size = result
 
 
 
125
  elif len(result) == 2:
 
126
  org_img, metadata_dict = result
127
- # Extract the tensor from the ScalarImage object
128
- input_img = org_img.data # This is the actual tensor!
129
- org_size = org_img.shape[1:] # Infer from org_img
130
  else:
131
  raise ValueError(f"Unexpected number of return values: {len(result)}")
132
  else:
133
- # Single value returned (shouldn't happen with only_tensor=False)
134
  raise ValueError(f"Expected tuple but got {type(result)}")
135
-
136
  except Exception as e:
137
  print(f"DEBUG: Error in med_img_reader: {str(e)}")
138
  if 'result' in locals():
@@ -140,40 +207,48 @@ def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view)
140
  if hasattr(result, '__len__'):
141
  print(f"DEBUG: result length: {len(result)}")
142
  raise ValueError(f"Error reading medical image: {str(e)}")
143
-
144
- # Ensure input_img is a torch.Tensor
145
  if not isinstance(input_img, torch.Tensor):
146
  raise ValueError(f"Expected input_img to be a torch.Tensor, but got {type(input_img)}")
147
-
148
- # Perform inference
 
 
 
 
149
  mask_data = inference(learn, reorder=reorder, resample=resample,
150
- org_img=org_img, input_img=input_img,
151
  org_size=org_size)
152
-
153
- # Handle orientation if needed
154
- if hasattr(org_img, 'orientation') and "".join(org_img.orientation) == "LSA":
155
- mask_data = mask_data.permute(0,1,3,2)
156
- mask_data = torch.flip(mask_data[0], dims=[1])
157
  mask_data = torch.Tensor(mask_data)[None]
158
 
159
- # Save the prediction
160
- img = org_img.data
161
  org_img.set_data(mask_data)
162
  org_img.save(save_path)
163
 
164
- # Extract slices and create fused images
165
- slices = extract_slices_from_mask(img[0].numpy(), mask_data[0].numpy(), view)
166
- fused_images = [(get_fused_image(
167
- normalize_image(slice_img),
168
- slice_mask, view))
169
- for slice_img, slice_mask in slices]
170
-
 
 
171
  # Compute volume
172
  volume = compute_binary_tumor_volume(org_img)
173
-
174
  return fused_images, round(volume, 2)
175
 
176
- # Function to load system resources
 
 
 
177
  def load_system_resources(models_path, learner_fn='heart_model.pkl', variables_fn='vars.pkl'):
178
  """Load the model and other required resources."""
179
  try:
@@ -185,33 +260,29 @@ def load_system_resources(models_path, learner_fn='heart_model.pkl', variables_f
185
  try:
186
  with open(models_path / variables_fn, 'rb') as f:
187
  variables = pickle.load(f)
188
-
189
  if not isinstance(variables, list) or len(variables) != 3:
190
  raise ValueError(f"vars.pkl does not contain the expected list format. Found: {variables}")
191
-
192
- # Assuming the format is [shape, reorder, resample]
193
- shape = variables[0]
194
- reorder = variables[1]
195
- resample = variables[2]
196
-
197
  if not isinstance(reorder, bool):
198
  raise ValueError(f"vars.pkl does not contain a valid 'reorder' value. Found: {reorder}")
199
-
200
  if not isinstance(resample, list) or len(resample) != 3:
201
  raise ValueError(f"vars.pkl does not contain a valid 'resample' value. Found: {resample}")
202
-
203
  except Exception as e:
204
  raise ValueError(f"Error loading variables: {str(e)}")
205
 
206
  return learn, reorder, resample
207
 
208
- # Initialize the system
 
 
 
209
  try:
210
  print("🚀 Initializing application...")
211
-
212
  clone_dir = Path.cwd() / 'clone_dir'
213
  URI = os.getenv('PAT_Token_URI')
214
-
215
  if not URI:
216
  raise ValueError("PAT_Token_URI environment variable is not set")
217
 
@@ -227,17 +298,20 @@ try:
227
  save_dir.mkdir(parents=True, exist_ok=True)
228
  print(f"✅ Save directory created at {save_dir}")
229
 
230
- # Load the model and other required resources
231
  print("🔍 Loading model and resources...")
232
  learn, reorder, resample = load_system_resources(models_path=models_path)
233
  print(f"✅ Model loaded successfully")
234
  print(f"✅ Reorder: {reorder}, Resample: {resample}")
235
 
236
- # Gradio interface setup
 
 
 
 
237
  output_text = gr.Textbox(label="Volume of the Left Atrium (mL):")
238
  view_selector = gr.Radio(choices=["Axial", "Coronal", "Sagittal"], value='Sagittal', label="Select View (Sagittal by default)")
239
 
240
- # Ensure the example file path is correct
241
  example_path = str(clone_dir / "sample.nii.gz")
242
  if not os.path.exists(example_path):
243
  print(f"⚠️ Example file not found: {example_path}")
@@ -251,18 +325,17 @@ try:
251
  inputs=["file", view_selector],
252
  outputs=[gr.Gallery(label="Click an Image, and use Arrow Keys to scroll slices", columns=3, height=450), output_text],
253
  examples=examples,
254
- cache_examples=False, # ADD THIS PARAMETER
255
- allow_flagging='never')
256
-
 
257
  print("✅ Gradio interface initialized successfully")
258
-
259
  except Exception as e:
260
  print(f"❌ Error during initialization: {str(e)}")
261
  print(f"Error type: {type(e).__name__}")
262
  traceback.print_exc()
263
- # Exit with error code
264
  exit(1)
265
 
266
- # Launch the Gradio interface
267
  print("🌐 Launching Gradio interface...")
268
- demo.launch()
 
12
  import pickle
13
  import traceback
14
 
15
+ # -----------------------------
16
+ # Utility functions
17
+ # -----------------------------
18
+
19
  def extract_slices_from_mask(img, mask_data, view):
20
  """Extract and resize slices from the 3D [W, H, D] image and mask data based on the selected view."""
21
  slices = []
22
  target_size = (320, 320)
23
+ for idx in range(img.shape[6] if view == "Sagittal" else img.shape[21] if view == "Axial" else img.shape):
 
24
  if view == "Sagittal":
25
  slice_img, slice_mask = img[:, :, idx], mask_data[:, :, idx]
26
  elif view == "Axial":
27
  slice_img, slice_mask = img[:, idx, :], mask_data[:, idx, :]
28
  elif view == "Coronal":
29
  slice_img, slice_mask = img[idx, :, :], mask_data[idx, :, :]
 
30
  slice_img = np.fliplr(np.rot90(slice_img, -1))
31
  slice_mask = np.fliplr(np.rot90(slice_mask, -1))
 
32
  slice_img_resized, slice_mask_resized = resize_and_pad(slice_img, slice_mask, target_size)
33
  slices.append((slice_img_resized, slice_mask_resized))
 
34
  return slices
35
 
 
36
  def resize_and_pad(slice_img, slice_mask, target_size):
37
  """Resize and pad the image and mask to fit the target size while maintaining the aspect ratio."""
38
  h, w = slice_img.shape
39
+ scale = min(target_size / w, target_size[21] / h)
40
  new_w, new_h = int(w * scale), int(h * scale)
 
41
  resized_img = cv2.resize(slice_img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
42
  resized_mask = cv2.resize(slice_mask, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
43
+ pad_w = (target_size - new_w) // 2
44
+ pad_h = (target_size[21] - new_h) // 2
45
+ padded_img = np.pad(resized_img, ((pad_h, target_size[21] - new_h - pad_h), (pad_w, target_size - new_w - pad_w)),
46
+ mode='constant', constant_values=0)
47
+ padded_mask = np.pad(resized_mask, ((pad_h, target_size[21] - new_h - pad_h), (pad_w, target_size - new_w - pad_w)),
48
+ mode='constant', constant_values=0)
 
49
  return padded_img, padded_mask
50
 
 
51
  def normalize_image(slice_img):
52
  """Normalize the image to the range [0, 255] safely."""
53
  slice_img_min, slice_img_max = slice_img.min(), slice_img.max()
54
+ if slice_img_min == slice_img_max:
55
  return np.zeros_like(slice_img, dtype=np.uint8)
56
  normalized_img = (slice_img - slice_img_min) / (slice_img_max - slice_img_min) * 255
57
  return normalized_img.astype(np.uint8)
58
 
 
59
  def get_fused_image(img, pred_mask, view, alpha=0.8):
60
  """Fuse a grayscale image with a mask overlay and flip both horizontally and vertically."""
61
  gray_img_colored = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
62
  mask_color = np.array([255, 0, 0])
63
  colored_mask = (pred_mask[..., None] * mask_color).astype(np.uint8)
 
64
  fused = cv2.addWeighted(gray_img_colored, alpha, colored_mask, 1 - alpha, 0)
65
+ fused_flipped = cv2.flip(fused, -1)
 
 
 
66
  if view == 'Sagittal':
67
  return fused_flipped
68
  elif view == 'Coronal' or view == 'Axial':
69
  rotated = cv2.flip(cv2.rotate(fused, cv2.ROTATE_90_COUNTERCLOCKWISE), 1)
70
  return rotated
71
 
72
+ # -----------------------------
73
+ # Robust ZNormalization patch
74
+ # -----------------------------
75
+
76
+ def apply_znorm_compat_patches(learn):
77
+ """
78
+ Make ZNormalization resilient to missing attributes on unpickled transforms and
79
+ repair existing instances inside the learner pipelines.
80
+ """
81
+ # Import here to avoid import-time side effects before libs are available
82
+ from fastMONAI.vision_augmentation import ZNormalization, _do_z_normalization
83
+ from fastMONAI.vision_core import MedImage
84
+
85
+ # Robust encodes that tolerates missing attrs on older pickles
86
+ def _robust_znorm_encodes(self, o: (MedImage)):
87
+ masking_method = getattr(self, 'masking_method', None)
88
+ channel_wise = getattr(self, 'channel_wise', True)
89
+ return MedImage.create(_do_z_normalization(o, masking_method, channel_wise))
90
+
91
+ # Monkey-patch class method so any new instance uses the robust encodes
92
+ ZNormalization.encodes = _robust_znorm_encodes
93
+
94
+ # Ensure attributes exist on already loaded instances inside learner pipelines
95
+ def _ensure_attrs_on_pipeline(pipeline):
96
+ if not hasattr(pipeline, 'fs'):
97
+ return
98
+ for tfm in pipeline.fs:
99
+ if isinstance(tfm, ZNormalization):
100
+ if not hasattr(tfm, 'masking_method'):
101
+ tfm.masking_method = None
102
+ if not hasattr(tfm, 'channel_wise'):
103
+ tfm.channel_wise = True
104
+
105
+ for p in (getattr(learn.dls, 'after_item', None),
106
+ getattr(learn.dls, 'before_batch', None),
107
+ getattr(learn.dls, 'after_batch', None)):
108
+ if p is not None:
109
+ _ensure_attrs_on_pipeline(p)
110
+
111
+ # -----------------------------
112
+ # Inference
113
+ # -----------------------------
114
+
115
+ def _direct_model_inference(learn, x_tensor):
116
+ """
117
+ Direct model inference bypassing fastai's predict() in case transforms fail.
118
+ Applies a simple channel-wise Z-normalization over non-zero voxels.
119
+ """
120
+ with torch.no_grad():
121
+ x = x_tensor
122
+ if x.dim() == 4:
123
+ x = x.unsqueeze(0) # [1, C, W, H, D]
124
+ x_ = x.clone()
125
+ # Channel-wise z-score over non-zero voxels
126
+ for c in range(x_.shape[21]):
127
+ vol = x_[:, c]
128
+ nz = vol != 0
129
+ if nz.any():
130
+ m = vol[nz].mean()
131
+ s = vol[nz].std()
132
+ if s > 0:
133
+ vol[nz] = (vol[nz] - m) / s
134
+ else:
135
+ vol[nz] = 0
136
+ x_[:, c] = vol
137
+ learn.model.eval()
138
+ logits = learn.model(x_) # [B, classes, W, H, D]
139
+ pred = torch.argmax(logits, dim=1) # [B, W, H, D]
140
+ return pred # [B, W, H, D]
141
+
142
  def inference(learn, reorder, resample, org_img, input_img, org_size):
143
+ """Perform segmentation using the loaded model, with robust fallback."""
 
144
  if not isinstance(input_img, torch.Tensor):
145
  raise ValueError(f"Expected input_img to be a torch.Tensor, but got {type(input_img)}")
146
+ # Try the standard fastai predict path first (applies the learner's transforms)
147
+ # If transforms are not compatible due to pickled attribute differences, fallback to direct model inference.
148
+ try:
149
+ with torch.no_grad():
150
+ pred = learn.predict(input_img) # returns (decoded, raw, probs/...), depending on setup
151
+ # Extract mask-like output
152
+ mask_data = pred if isinstance(pred, (list, tuple)) else pred
153
+ if isinstance(mask_data, torch.Tensor):
154
+ if mask_data.dim() == 3:
155
+ mask_data = mask_data.unsqueeze(0) # [1, W, H, D]
156
+ elif mask_data.dim() == 4 and mask_data.shape != 1:
157
+ # If [B, W, H, D] with B>1, keep as is; downstream expects batch dim
158
+ pass
159
+ return mask_data
160
+ # If not a tensor, try to coerce, else fallback
161
+ except Exception as e:
162
+ print(f"[WARN] learn.predict failed, falling back to direct model inference. Reason: {e}")
163
+
164
+ # Fallback: call the model directly on the tensor
165
+ pred = _direct_model_inference(learn, input_img) # [B, W, H, D]
166
+ return pred
167
+
168
+ # -----------------------------
169
+ # App prediction wrapper
170
+ # -----------------------------
171
+
172
  def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view):
173
  """Predict function using the learner and other resources."""
 
174
  if view is None:
175
  view = 'Sagittal'
176
+
177
  img_path = Path(fileobj.name)
178
  save_fn = 'pred_' + img_path.stem
179
  save_path = save_dir / save_fn
180
+
181
+ # Read medical image with robust handling of return formats
182
  try:
183
+ result = med_img_reader(str(img_path),
184
+ reorder=reorder,
185
+ resample=resample,
186
+ only_tensor=False,
187
+ dtype=torch.Tensor)
 
 
188
  if isinstance(result, tuple):
189
  if len(result) == 3:
190
  org_img, input_img, org_size = result
191
+ # Depending on fastMONAI version, input_img may be a ScalarImage; if so, extract tensor
192
+ if hasattr(input_img, 'data') and not isinstance(input_img, torch.Tensor):
193
+ input_img = input_img.data.type(torch.float)
194
  elif len(result) == 2:
195
+ # Some versions return (org_img, metadata_dict)
196
  org_img, metadata_dict = result
197
+ input_img = org_img.data.type(torch.float) if hasattr(org_img, 'data') else org_img
198
+ org_size = org_img.shape[1:] if hasattr(org_img, 'shape') else None
 
199
  else:
200
  raise ValueError(f"Unexpected number of return values: {len(result)}")
201
  else:
 
202
  raise ValueError(f"Expected tuple but got {type(result)}")
 
203
  except Exception as e:
204
  print(f"DEBUG: Error in med_img_reader: {str(e)}")
205
  if 'result' in locals():
 
207
  if hasattr(result, '__len__'):
208
  print(f"DEBUG: result length: {len(result)}")
209
  raise ValueError(f"Error reading medical image: {str(e)}")
210
+
211
+ # Ensure tensor with appropriate dims [C, W, H, D] or [1, C, W, H, D]
212
  if not isinstance(input_img, torch.Tensor):
213
  raise ValueError(f"Expected input_img to be a torch.Tensor, but got {type(input_img)}")
214
+ if input_img.dim() == 4:
215
+ x_in = input_img.unsqueeze(0) # [1, C, W, H, D]
216
+ else:
217
+ x_in = input_img
218
+
219
+ # Perform inference (predict or direct fallback)
220
  mask_data = inference(learn, reorder=reorder, resample=resample,
221
+ org_img=org_img, input_img=x_in,
222
  org_size=org_size)
223
+
224
+ # Optional orientation handling (example logic retained)
225
+ if hasattr(org_img, 'orientation') and "".join(org_img.orientation) == "LSA":
226
+ mask_data = mask_data.permute(0, 1, 3, 2)
227
+ mask_data = torch.flip(mask_data, dims=[21])
228
  mask_data = torch.Tensor(mask_data)[None]
229
 
230
+ # Save the prediction as a TorchIO image using original image object
231
+ img = org_img.data if hasattr(org_img, 'data') else None
232
  org_img.set_data(mask_data)
233
  org_img.save(save_path)
234
 
235
+ # Build fused gallery
236
+ base_img_np = img.numpy() if img is not None else input_img.numpy()
237
+ pred_np = mask_data.detach().cpu().numpy().astype(np.uint8)
238
+ slices = extract_slices_from_mask(base_img_np, pred_np, view)
239
+ fused_images = [
240
+ get_fused_image(normalize_image(slice_img), slice_mask, view)
241
+ for slice_img, slice_mask in slices
242
+ ]
243
+
244
  # Compute volume
245
  volume = compute_binary_tumor_volume(org_img)
 
246
  return fused_images, round(volume, 2)
247
 
248
+ # -----------------------------
249
+ # Resources
250
+ # -----------------------------
251
+
252
  def load_system_resources(models_path, learner_fn='heart_model.pkl', variables_fn='vars.pkl'):
253
  """Load the model and other required resources."""
254
  try:
 
260
  try:
261
  with open(models_path / variables_fn, 'rb') as f:
262
  variables = pickle.load(f)
 
263
  if not isinstance(variables, list) or len(variables) != 3:
264
  raise ValueError(f"vars.pkl does not contain the expected list format. Found: {variables}")
265
+ shape = variables
266
+ reorder = variables[21]
267
+ resample = variables[6]
 
 
 
268
  if not isinstance(reorder, bool):
269
  raise ValueError(f"vars.pkl does not contain a valid 'reorder' value. Found: {reorder}")
 
270
  if not isinstance(resample, list) or len(resample) != 3:
271
  raise ValueError(f"vars.pkl does not contain a valid 'resample' value. Found: {resample}")
 
272
  except Exception as e:
273
  raise ValueError(f"Error loading variables: {str(e)}")
274
 
275
  return learn, reorder, resample
276
 
277
+ # -----------------------------
278
+ # Main
279
+ # -----------------------------
280
+
281
  try:
282
  print("🚀 Initializing application...")
283
+
284
  clone_dir = Path.cwd() / 'clone_dir'
285
  URI = os.getenv('PAT_Token_URI')
 
286
  if not URI:
287
  raise ValueError("PAT_Token_URI environment variable is not set")
288
 
 
298
  save_dir.mkdir(parents=True, exist_ok=True)
299
  print(f"✅ Save directory created at {save_dir}")
300
 
301
+ # Load model and variables
302
  print("🔍 Loading model and resources...")
303
  learn, reorder, resample = load_system_resources(models_path=models_path)
304
  print(f"✅ Model loaded successfully")
305
  print(f"✅ Reorder: {reorder}, Resample: {resample}")
306
 
307
+ # Apply robust ZNormalization compatibility patches BEFORE any predict is called
308
+ apply_znorm_compat_patches(learn)
309
+ print("✅ Applied ZNormalization compatibility patches")
310
+
311
+ # Gradio UI
312
  output_text = gr.Textbox(label="Volume of the Left Atrium (mL):")
313
  view_selector = gr.Radio(choices=["Axial", "Coronal", "Sagittal"], value='Sagittal', label="Select View (Sagittal by default)")
314
 
 
315
  example_path = str(clone_dir / "sample.nii.gz")
316
  if not os.path.exists(example_path):
317
  print(f"⚠️ Example file not found: {example_path}")
 
325
  inputs=["file", view_selector],
326
  outputs=[gr.Gallery(label="Click an Image, and use Arrow Keys to scroll slices", columns=3, height=450), output_text],
327
  examples=examples,
328
+ allow_flagging='never',
329
+ cache_examples=False # avoid eager caching running predict at startup
330
+ )
331
+
332
  print("✅ Gradio interface initialized successfully")
333
+
334
  except Exception as e:
335
  print(f"❌ Error during initialization: {str(e)}")
336
  print(f"Error type: {type(e).__name__}")
337
  traceback.print_exc()
 
338
  exit(1)
339
 
 
340
  print("🌐 Launching Gradio interface...")
341
+ demo.launch()