drankush-ai commited on
Commit
b0bd5fc
·
verified ·
1 Parent(s): e41d6ed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -336
app.py CHANGED
@@ -7,13 +7,21 @@ from huggingface_hub import snapshot_download
7
  from fastMONAI.vision_all import *
8
  from git import Repo
9
  import os
10
- from fastai.learner import load_learner
11
- from fastai.basics import load_pickle
12
- import pickle
13
- import traceback
14
- import torchio as tio
15
 
16
- # Function to extract slices from mask
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  def extract_slices_from_mask(img, mask_data, view):
18
  """Extract and resize slices from the 3D [W, H, D] image and mask data based on the selected view."""
19
  slices = []
@@ -35,7 +43,6 @@ def extract_slices_from_mask(img, mask_data, view):
35
 
36
  return slices
37
 
38
- # Function to resize and pad slices
39
  def resize_and_pad(slice_img, slice_mask, target_size):
40
  """Resize and pad the image and mask to fit the target size while maintaining the aspect ratio."""
41
  h, w = slice_img.shape
@@ -53,7 +60,6 @@ def resize_and_pad(slice_img, slice_mask, target_size):
53
 
54
  return padded_img, padded_mask
55
 
56
- # Function to normalize image
57
  def normalize_image(slice_img):
58
  """Normalize the image to the range [0, 255] safely."""
59
  slice_img_min, slice_img_max = slice_img.min(), slice_img.max()
@@ -62,7 +68,6 @@ def normalize_image(slice_img):
62
  normalized_img = (slice_img - slice_img_min) / (slice_img_max - slice_img_min) * 255
63
  return normalized_img.astype(np.uint8)
64
 
65
- # Function to get fused image
66
  def get_fused_image(img, pred_mask, view, alpha=0.8):
67
  """Fuse a grayscale image with a mask overlay and flip both horizontally and vertically."""
68
  gray_img_colored = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
@@ -74,369 +79,71 @@ def get_fused_image(img, pred_mask, view, alpha=0.8):
74
  # Flip the fused image vertically and horizontally
75
  fused_flipped = cv2.flip(fused, -1) # Flip both vertically and horizontally
76
 
77
- if view == 'Sagittal':
78
  return fused_flipped
79
- elif view == 'Coronal' or view == 'Axial':
80
  rotated = cv2.flip(cv2.rotate(fused, cv2.ROTATE_90_COUNTERCLOCKWISE), 1)
81
  return rotated
82
 
83
- # Define the inference function
84
- def inference(learn, reorder, resample, org_img, input_img, org_size):
85
- """Perform segmentation using the loaded model."""
86
- print("\n=== Inference Debug ===")
87
- print(f"Initial input shape: {input_img.shape if isinstance(input_img, torch.Tensor) else 'Not a tensor'}")
88
-
89
- # Ensure input_img is a torch.Tensor
90
- if not isinstance(input_img, torch.Tensor):
91
- raise ValueError(f"Expected input_img to be a torch.Tensor, but got {type(input_img)}")
92
-
93
- # Add batch dimension if needed
94
- if input_img.dim() == 4:
95
- input_img = input_img.unsqueeze(0)
96
- print(f"Added batch dim - shape: {input_img.shape}")
97
-
98
- # Verify we now have the correct dimensions
99
- if input_img.dim() != 5:
100
- raise ValueError(f"Input tensor must be 5D after preprocessing, got {input_img.dim()}D tensor with shape {input_img.shape}")
101
-
102
- print(f"Input tensor shape before prediction: {input_img.shape}")
103
-
104
- # Perform the segmentation
105
- with torch.no_grad():
106
- try:
107
- print(f"\nModel input shape before processing: {input_img.shape}")
108
-
109
- # Convert to plain tensor if wrapped in MedImage
110
- if hasattr(input_img, 'data'):
111
- input_img = input_img.data
112
- print(f"Unwrapped MedImage to tensor with shape: {input_img.shape}")
113
-
114
- # Ensure we have exactly 5 dimensions [1,1,320,320,140]
115
- if input_img.dim() == 6:
116
- input_img = input_img.squeeze(0) # Remove extra batch dim
117
- elif input_img.dim() == 4:
118
- input_img = input_img.unsqueeze(0) # Add batch dim
119
-
120
- # Verify final shape
121
- if input_img.shape != torch.Size([1,1,320,320,140]):
122
- input_img = input_img.view(1, 1, 320, 320, 140)
123
- print(f"Reshaped to required dimensions: {input_img.shape}")
124
-
125
- print(f"Final input shape to model: {input_img.shape}")
126
- pred = learn.predict(input_img)
127
- print(f"Prediction output shape: {pred[0].shape if isinstance(pred, (list, tuple)) else pred.shape}")
128
- except Exception as e:
129
- print(f"Error during prediction: {str(e)}")
130
- print(f"Current input shape: {input_img.shape}")
131
- print(f"Input tensor contents: {input_img}")
132
- raise
133
-
134
- # Process the prediction if necessary
135
- mask_data = pred[0] if isinstance(pred, (list, tuple)) else pred
136
-
137
- return mask_data
138
-
139
- # Function for Gradio image segmentation
140
  def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view):
141
  """Predict function using the learner and other resources."""
142
 
143
- if view is None:
144
  view = 'Sagittal'
145
 
146
  img_path = Path(fileobj.name)
 
147
  save_fn = 'pred_' + img_path.stem
148
  save_path = save_dir / save_fn
 
 
 
 
149
 
150
- # Read the medical image - handle variable return values
151
- try:
152
- print(f"\nReading medical image: {img_path}")
153
- result = med_img_reader(str(img_path),
154
- reorder=reorder,
155
- resample=resample,
156
- only_tensor=False,
157
- dtype=torch.Tensor)
158
- print(f"med_img_reader returned type: {type(result)}")
159
- if isinstance(result, tuple):
160
- print(f"Tuple length: {len(result)}")
161
- for i, item in enumerate(result):
162
- print(f"Item {i} type: {type(item)}")
163
- if hasattr(item, 'shape'):
164
- print(f"Item {i} shape: {item.shape}")
165
-
166
- # Handle different return formats
167
- if isinstance(result, tuple):
168
- if len(result) == 3:
169
- org_img, input_img, org_size = result
170
- elif len(result) == 2:
171
- org_img, metadata_dict = result
172
- # Extract the tensor from the ScalarImage object
173
- input_img = org_img.data # This is the actual tensor!
174
- org_size = org_img.shape[1:] # Infer from org_img
175
- else:
176
- raise ValueError(f"Unexpected number of return values: {len(result)}")
177
- else:
178
- # Single value returned (shouldn't happen with only_tensor=False)
179
- raise ValueError(f"Expected tuple but got {type(result)}")
180
-
181
- except Exception as e:
182
- print(f"DEBUG: Error in med_img_reader: {str(e)}")
183
- if 'result' in locals():
184
- print(f"DEBUG: med_img_reader returned: {type(result)}")
185
- if hasattr(result, '__len__'):
186
- print(f"DEBUG: result length: {len(result)}")
187
- raise ValueError(f"Error reading medical image: {str(e)}")
188
-
189
- # Ensure input_img is a torch.Tensor
190
- if not isinstance(input_img, torch.Tensor):
191
- raise ValueError(f"Expected input_img to be a torch.Tensor, but got {type(input_img)}")
192
-
193
- # Perform inference
194
- print(f"\nInput tensor shape before inference: {input_img.shape}")
195
-
196
- # Convert to plain tensor if wrapped in MedImage
197
- if hasattr(input_img, 'data'):
198
- input_img = input_img.data
199
- print(f"Unwrapped MedImage to tensor with shape: {input_img.shape}")
200
-
201
- # Directly reshape to required 5D format [1,1,320,320,140]
202
- try:
203
- input_img = input_img.reshape(1, 1, 320, 320, 140)
204
- print(f"Reshaped to required dimensions: {input_img.shape}")
205
- except Exception as e:
206
- raise ValueError(f"Could not reshape tensor from {input_img.shape} to [1,1,320,320,140]: {str(e)}")
207
  mask_data = inference(learn, reorder=reorder, resample=resample,
208
  org_img=org_img, input_img=input_img,
209
- org_size=org_size)
210
- print(f"Raw mask data shape: {mask_data.shape}")
211
 
212
- # Handle orientation if needed
213
- if hasattr(org_img, 'orientation') and "".join(org_img.orientation) == "LSA":
214
- print("Applying LSA orientation correction")
215
  mask_data = mask_data.permute(0,1,3,2)
216
  mask_data = torch.flip(mask_data[0], dims=[1])
217
  mask_data = torch.Tensor(mask_data)[None]
218
- print(f"Orientation corrected shape: {mask_data.shape}")
219
 
220
- # Save the prediction
221
  img = org_img.data
222
  org_img.set_data(mask_data)
223
  org_img.save(save_path)
224
 
225
- # Extract slices and create fused images
226
- slices = extract_slices_from_mask(img[0].numpy(), mask_data[0].numpy(), view)
227
  fused_images = [(get_fused_image(
228
- normalize_image(slice_img),
229
  slice_mask, view))
230
  for slice_img, slice_mask in slices]
231
 
232
- # Compute volume
233
  volume = compute_binary_tumor_volume(org_img)
234
 
235
  return fused_images, round(volume, 2)
236
 
237
- # Function to load system resources
238
- def load_system_resources(models_path, learner_fn='heart_model.pkl', variables_fn='vars.pkl'):
239
- """Load the model and other required resources."""
240
- try:
241
- learn = load_learner(models_path / learner_fn)
242
- print(f"✅ Model loaded from {models_path / learner_fn}")
243
- except Exception as e:
244
- raise ValueError(f"Error loading the model: {str(e)}")
245
-
246
- try:
247
- with open(models_path / variables_fn, 'rb') as f:
248
- variables = pickle.load(f)
249
-
250
- if not isinstance(variables, list) or len(variables) != 3:
251
- raise ValueError(f"vars.pkl does not contain the expected list format. Found: {variables}")
252
-
253
- # Assuming the format is [shape, reorder, resample]
254
- shape = variables[0]
255
- reorder = variables[1]
256
- resample = variables[2]
257
-
258
- if not isinstance(reorder, bool):
259
- raise ValueError(f"vars.pkl does not contain a valid 'reorder' value. Found: {reorder}")
260
-
261
- if not isinstance(resample, list) or len(resample) != 3:
262
- raise ValueError(f"vars.pkl does not contain a valid 'resample' value. Found: {resample}")
263
-
264
- except Exception as e:
265
- raise ValueError(f"Error loading variables: {str(e)}")
266
-
267
- return learn, reorder, resample
268
-
269
  # Initialize the system
270
- try:
271
- print("🚀 Initializing application...")
272
-
273
- clone_dir = Path.cwd() / 'clone_dir'
274
- URI = os.getenv('PAT_Token_URI')
275
 
276
- if not URI:
277
- raise ValueError("PAT_Token_URI environment variable is not set")
 
 
278
 
279
- if os.path.exists(clone_dir):
280
- print(f" Repository already exists at {clone_dir}")
281
- else:
282
- print(f"📥 Cloning repository from {URI}")
283
- Repo.clone_from(URI, clone_dir)
284
- print(f"✅ Repository cloned to {clone_dir}")
285
-
286
- # ===== MONKEY PATCH FOR ZNORMALIZATION ERROR =====
287
- # This is the ONLY patch needed - it directly fixes the root cause
288
- from fastMONAI.vision_augmentation import ZNormalization
289
- from fastMONAI.vision_core import MedImage
290
-
291
- # Save the original encodes method
292
- original_encodes = ZNormalization.encodes
293
-
294
- def patched_encodes(self, o):
295
- """Patched version of ZNormalization.encodes with custom normalization."""
296
- print("\n=== Custom ZNormalization ===")
297
- print(f"Input type: {type(o)}")
298
-
299
- try:
300
- # Get the input tensor data
301
- if hasattr(o, 'data') and isinstance(o.data, torch.Tensor):
302
- tensor_data = o.data
303
- elif isinstance(o, torch.Tensor):
304
- tensor_data = o
305
- else:
306
- print("Input is not a tensor - converting")
307
- tensor_data = torch.tensor(o)
308
-
309
- print(f"Input shape: {tensor_data.shape}")
310
- print(f"Input dtype: {tensor_data.dtype}")
311
-
312
- # Ensure we have the right dimensions
313
- if tensor_data.dim() == 6:
314
- tensor_data = tensor_data.squeeze(0)
315
- print(f"Removed extra dim - new shape: {tensor_data.shape}")
316
-
317
- # Calculate mean and std
318
- mean = tensor_data.mean()
319
- std = tensor_data.std()
320
- print(f"Calculated mean: {mean:.4f}, std: {std:.4f}")
321
-
322
- # Avoid division by zero
323
- if std == 0:
324
- std = 1.0
325
- print("Warning: std is 0, setting to 1.0")
326
-
327
- # Apply normalization
328
- normalized = (tensor_data - mean) / std
329
- print(f"Normalized tensor shape: {normalized.shape}")
330
-
331
- # Return as plain tensor to avoid MedImage wrapping issues
332
- return normalized
333
-
334
- except Exception as e:
335
- print(f"Error in custom normalization: {str(e)}")
336
- print("Fallback - returning input as-is")
337
- return MedImage.create(o)
338
-
339
- print("=== End Custom ZNormalization ===\n")
340
 
341
- # Add PadOrCrop debug prints
342
- from fastMONAI.vision_augmentation import PadOrCrop
343
- original_padorcrop_encodes = PadOrCrop.encodes
344
 
345
- def patched_padorcrop_encodes(self, o):
346
- """Patched version of PadOrCrop.encodes that ensures tensor output."""
347
- print("\n=== PadOrCrop Debug ===")
348
- print(f"Input type: {type(o)}")
349
- if hasattr(o, 'shape'): print(f"Input shape: {o.shape}")
350
-
351
- # Safely get attributes with fallbacks
352
- target_shape = getattr(self, 'size', o.shape[1:]) # Default to input shape
353
- padding_mode = getattr(self, 'padding_mode', 'constant')
354
- mask_name = getattr(self, 'mask_name', None)
355
-
356
- print(f"Target shape: {target_shape}")
357
- print(f"Padding mode: {padding_mode}")
358
- print(f"Mask name: {mask_name}")
359
-
360
- try:
361
- # Get the transformation result
362
- result = original_padorcrop_encodes(self, o)
363
- print(f"Initial output type: {type(result)}")
364
-
365
- # Extract tensor data from result
366
- if hasattr(result, 'data'):
367
- print("Extracting .data attribute")
368
- tensor_data = result.data
369
- elif isinstance(result, torch.Tensor):
370
- print("Using tensor directly")
371
- tensor_data = result
372
- else:
373
- print("Unknown output type - using input")
374
- tensor_data = o.data if hasattr(o, 'data') else o
375
-
376
- # Ensure we have a tensor
377
- if not isinstance(tensor_data, torch.Tensor):
378
- print("Converting to tensor")
379
- tensor_data = torch.tensor(tensor_data)
380
-
381
- print(f"Final output shape: {tensor_data.shape}")
382
- return MedImage.create(tensor_data)
383
-
384
- except Exception as e:
385
- print(f"Error in PadOrCrop: {str(e)}")
386
- print("Fallback - returning input as MedImage")
387
- return MedImage.create(o.data if hasattr(o, 'data') else o)
388
-
389
- print("=== End PadOrCrop Debug ===\n")
390
-
391
- PadOrCrop.encodes = patched_padorcrop_encodes
392
- print("✅ Applied PadOrCrop patch with debug prints")
393
-
394
- # Replace the encodes method
395
- ZNormalization.encodes = patched_encodes
396
- print("✅ Applied ZNormalization patch to handle missing attributes in encodes method")
397
- # ===== END OF MONKEY PATCH =====
398
-
399
- models_path = clone_dir
400
- save_dir = Path.cwd() / 'hs_pred'
401
- save_dir.mkdir(parents=True, exist_ok=True)
402
- print(f"✅ Save directory created at {save_dir}")
403
-
404
- # Load the model and other required resources
405
- print("🔍 Loading model and resources...")
406
- learn, reorder, resample = load_system_resources(models_path=models_path)
407
- print(f"✅ Model loaded successfully")
408
- print(f"✅ Reorder: {reorder}, Resample: {resample}")
409
-
410
- # Gradio interface setup
411
- output_text = gr.Textbox(label="Volume of the Left Atrium (mL):")
412
- view_selector = gr.Radio(choices=["Axial", "Coronal", "Sagittal"], value='Sagittal', label="Select View (Sagittal by default)")
413
-
414
- # Ensure the example file path is correct
415
- example_path = str(clone_dir / "sample.nii.gz")
416
- if not os.path.exists(example_path):
417
- print(f"⚠️ Example file not found: {example_path}")
418
- examples = None
419
- else:
420
- print(f"✅ Example file found: {example_path}")
421
- examples = [[example_path]]
422
-
423
- demo = gr.Interface(
424
- fn=lambda fileobj, view='Sagittal': gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view),
425
- inputs=["file", view_selector],
426
- outputs=[gr.Gallery(label="Click an Image, and use Arrow Keys to scroll slices", columns=3, height=450), output_text],
427
- examples=examples,
428
- cache_examples=False, # Disable example caching to prevent errors during startup
429
- allow_flagging='never')
430
-
431
- print("✅ Gradio interface initialized successfully")
432
-
433
- except Exception as e:
434
- print(f"❌ Error during initialization: {str(e)}")
435
- print(f"Error type: {type(e).__name__}")
436
- traceback.print_exc()
437
- # Exit with error code
438
- exit(1)
439
 
440
  # Launch the Gradio interface
441
- print("🌐 Launching Gradio interface...")
442
- demo.launch()
 
7
  from fastMONAI.vision_all import *
8
  from git import Repo
9
  import os
 
 
 
 
 
10
 
11
+ #Additional support for local execution:-
12
+ #import pathlib
13
+ #temp = pathlib.PosixPath
14
+ #pathlib.PosixPath = pathlib.WindowsPath
15
+ #pathlib.PosixPath = temp
16
+
17
+ clone_dir = Path.cwd() / 'clone_dir'
18
+ URI = os.getenv('PAT_Token_URI')
19
+
20
+ if os.path.exists(clone_dir):
21
+ pass
22
+ else:
23
+ Repo.clone_from(URI, clone_dir)
24
+
25
  def extract_slices_from_mask(img, mask_data, view):
26
  """Extract and resize slices from the 3D [W, H, D] image and mask data based on the selected view."""
27
  slices = []
 
43
 
44
  return slices
45
 
 
46
  def resize_and_pad(slice_img, slice_mask, target_size):
47
  """Resize and pad the image and mask to fit the target size while maintaining the aspect ratio."""
48
  h, w = slice_img.shape
 
60
 
61
  return padded_img, padded_mask
62
 
 
63
  def normalize_image(slice_img):
64
  """Normalize the image to the range [0, 255] safely."""
65
  slice_img_min, slice_img_max = slice_img.min(), slice_img.max()
 
68
  normalized_img = (slice_img - slice_img_min) / (slice_img_max - slice_img_min) * 255
69
  return normalized_img.astype(np.uint8)
70
 
 
71
  def get_fused_image(img, pred_mask, view, alpha=0.8):
72
  """Fuse a grayscale image with a mask overlay and flip both horizontally and vertically."""
73
  gray_img_colored = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
 
79
  # Flip the fused image vertically and horizontally
80
  fused_flipped = cv2.flip(fused, -1) # Flip both vertically and horizontally
81
 
82
+ if view=='Sagittal':
83
  return fused_flipped
84
+ elif view=='Coronal' or 'Axial':
85
  rotated = cv2.flip(cv2.rotate(fused, cv2.ROTATE_90_COUNTERCLOCKWISE), 1)
86
  return rotated
87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view):
89
  """Predict function using the learner and other resources."""
90
 
91
+ if view == None:
92
  view = 'Sagittal'
93
 
94
  img_path = Path(fileobj.name)
95
+
96
  save_fn = 'pred_' + img_path.stem
97
  save_path = save_dir / save_fn
98
+ org_img, input_img, org_size = med_img_reader(img_path,
99
+ reorder=reorder,
100
+ resample=resample,
101
+ only_tensor=False)
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  mask_data = inference(learn, reorder=reorder, resample=resample,
104
  org_img=org_img, input_img=input_img,
105
+ org_size=org_size).data
 
106
 
107
+ if "".join(org_img.orientation) == "LSA":
 
 
108
  mask_data = mask_data.permute(0,1,3,2)
109
  mask_data = torch.flip(mask_data[0], dims=[1])
110
  mask_data = torch.Tensor(mask_data)[None]
 
111
 
 
112
  img = org_img.data
113
  org_img.set_data(mask_data)
114
  org_img.save(save_path)
115
 
116
+ slices = extract_slices_from_mask(img[0], mask_data[0], view)
 
117
  fused_images = [(get_fused_image(
118
+ normalize_image(slice_img), # Normalize safely
119
  slice_mask, view))
120
  for slice_img, slice_mask in slices]
121
 
 
122
  volume = compute_binary_tumor_volume(org_img)
123
 
124
  return fused_images, round(volume, 2)
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  # Initialize the system
127
+ models_path = Path.cwd() / 'clone_dir'
128
+ save_dir = Path.cwd() / 'hs_pred'
129
+ save_dir.mkdir(parents=True, exist_ok=True)
 
 
130
 
131
+ # Load the model and other required resources
132
+ learn, reorder, resample = load_system_resources(models_path=models_path,
133
+ learner_fn='heart_model.pkl',
134
+ variables_fn='vars.pkl')
135
 
136
+ # Gradio interface setup
137
+ output_text = gr.Textbox(label="Volume of the Left Atrium (mL):")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
+ view_selector = gr.Radio(choices=["Axial", "Coronal", "Sagittal"], value='Sagittal', label="Select View (Sagittal by default)")
 
 
140
 
141
+ demo = gr.Interface(
142
+ fn=lambda fileobj, view='Sagittal': gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view),
143
+ inputs=["file", view_selector],
144
+ outputs=[gr.Gallery(label="Click an Image, and use Arrow Keys to scroll slices", columns=3, height=450), output_text],
145
+ examples=[[str(Path.cwd() /"sample.nii.gz")]],
146
+ allow_flagging='never')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
  # Launch the Gradio interface
149
+ demo.launch()