drankush-ai committed on
Commit
27298f5
Β·
verified Β·
1 Parent(s): a818d45

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -200
app.py CHANGED
@@ -10,6 +10,7 @@ import os
10
  from fastai.learner import load_learner
11
  from fastai.basics import load_pickle
12
  import pickle
 
13
 
14
  # Function to extract slices from mask
15
  def extract_slices_from_mask(img, mask_data, view):
@@ -83,61 +84,21 @@ def inference(learn, reorder, resample, org_img, input_img, org_size):
83
  """Perform segmentation using the loaded model."""
84
  # Ensure input_img is a torch.Tensor
85
  if not isinstance(input_img, torch.Tensor):
86
- # If input_img is not a tensor, try to extract the tensor data
87
- if hasattr(input_img, 'data'):
88
- input_tensor = input_img.data
89
- else:
90
- raise ValueError(f"Expected input_img to be a torch.Tensor or have a 'data' attribute, but got {type(input_img)}")
91
- else:
92
- input_tensor = input_img
93
 
94
  # Perform the segmentation
95
  with torch.no_grad():
96
- pred = learn.predict(input_tensor)
97
 
98
  # Process the prediction if necessary
99
- mask_data = pred[0] # Assuming the first element of the prediction is the mask
100
 
101
  return mask_data
102
 
103
- # Function to extract tensor from different data types
104
- def extract_tensor_from_data(data, fallback_path=None):
105
- """Extract tensor from various data types returned by med_img_reader."""
106
- print(f"DEBUG: Extracting tensor from data type: {type(data)}")
107
-
108
- if isinstance(data, torch.Tensor):
109
- return data
110
- elif isinstance(data, dict):
111
- print(f"DEBUG: Dictionary keys: {list(data.keys())}")
112
- # Try common key names for tensor data
113
- possible_keys = ['data', 'tensor', 'image', 'img', 'input', 'volume']
114
- for key in possible_keys:
115
- if key in data:
116
- tensor_data = data[key]
117
- if isinstance(tensor_data, torch.Tensor):
118
- print(f"DEBUG: Found tensor in key '{key}'")
119
- return tensor_data
120
- elif hasattr(tensor_data, 'data') and isinstance(tensor_data.data, torch.Tensor):
121
- print(f"DEBUG: Found tensor in key '{key}'.data")
122
- return tensor_data.data
123
-
124
- # If no tensor found in dict, try to load directly
125
- if fallback_path:
126
- print("DEBUG: No tensor found in dict, trying direct load")
127
- try:
128
- direct_tensor = med_img_reader(fallback_path, only_tensor=True, dtype=torch.Tensor)
129
- return direct_tensor
130
- except Exception as e:
131
- print(f"DEBUG: Direct load failed: {e}")
132
- elif hasattr(data, 'data') and isinstance(data.data, torch.Tensor):
133
- print("DEBUG: Found tensor in data.data attribute")
134
- return data.data
135
- elif hasattr(data, 'tensor') and isinstance(data.tensor, torch.Tensor):
136
- print("DEBUG: Found tensor in tensor attribute")
137
- return data.tensor
138
-
139
- raise ValueError(f"Cannot extract tensor from data of type {type(data)}")
140
-
141
  # Function for Gradio image segmentation
142
  def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view):
143
  """Predict function using the learner and other resources."""
@@ -146,159 +107,68 @@ def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view)
146
  view = 'Sagittal'
147
 
148
  img_path = Path(fileobj.name)
149
-
150
- # Convert PosixPath to string
151
- img_path_str = str(img_path)
152
-
153
  save_fn = 'pred_' + img_path.stem
154
  save_path = save_dir / save_fn
155
 
156
- # Try different approaches to load the image
157
- org_img = None
158
- input_tensor = None
159
- org_size = None
160
-
161
  try:
162
- # First try with only_tensor=False to get all values
163
- result = med_img_reader(img_path_str,
164
- reorder=reorder,
165
- resample=resample,
166
- only_tensor=False,
167
- dtype=torch.Tensor)
168
 
169
- print(f"DEBUG: med_img_reader returned type: {type(result)}")
170
-
171
- # Handle different return types
172
  if isinstance(result, tuple):
173
- print(f"DEBUG: Tuple length: {len(result)}")
174
- if len(result) >= 2:
175
- org_img = result[0]
176
- input_img = result[1]
177
- org_size = result[2] if len(result) > 2 else None
178
-
179
- # Extract tensor from input_img
180
- input_tensor = extract_tensor_from_data(input_img, img_path_str)
181
-
182
- # Handle org_size if not provided
183
- if org_size is None and hasattr(input_tensor, 'shape'):
184
- org_size = input_tensor.shape[1:]
185
  else:
186
- raise ValueError(f"Unexpected tuple length from med_img_reader: {len(result)}")
187
-
188
- elif isinstance(result, dict):
189
- print(f"DEBUG: Dictionary keys: {list(result.keys())}")
190
-
191
- # Try to extract org_img
192
- org_img_keys = ['org_img', 'original', 'original_image', 'source']
193
- for key in org_img_keys:
194
- if key in result:
195
- org_img = result[key]
196
- break
197
-
198
- # Extract input tensor
199
- input_tensor = extract_tensor_from_data(result, img_path_str)
200
-
201
- # Extract org_size
202
- if 'org_size' in result:
203
- org_size = result['org_size']
204
- elif hasattr(input_tensor, 'shape'):
205
- org_size = input_tensor.shape[1:]
206
-
207
  else:
208
- # If it's neither tuple nor dict, try to extract tensor directly
209
- input_tensor = extract_tensor_from_data(result, img_path_str)
210
-
211
- # If org_img is still None, try to create it
212
- if org_img is None:
213
- try:
214
- from torchio import ScalarImage
215
- org_img = ScalarImage(img_path_str)
216
- except Exception as e:
217
- print(f"DEBUG: Failed to create ScalarImage: {e}")
218
- # Create a minimal org_img object
219
- class MinimalImage:
220
- def __init__(self, tensor, path):
221
- self.data = tensor
222
- self.path = path
223
- self.orientation = "RAS" # Default orientation
224
-
225
- def set_data(self, data):
226
- self.data = data
227
-
228
- def save(self, path):
229
- # Minimal save implementation
230
- pass
231
-
232
- org_img = MinimalImage(input_tensor, img_path_str)
233
-
234
- # Ensure org_size is set
235
- if org_size is None and hasattr(input_tensor, 'shape'):
236
- org_size = input_tensor.shape[1:]
237
 
238
  except Exception as e:
239
- print(f"DEBUG: Error in med_img_reader handling: {str(e)}")
240
- # Final fallback: try to load tensor directly
241
- try:
242
- input_tensor = med_img_reader(img_path_str,
243
- reorder=reorder,
244
- resample=resample,
245
- only_tensor=True,
246
- dtype=torch.Tensor)
247
-
248
- # Create minimal org_img
249
- class MinimalImage:
250
- def __init__(self, tensor, path):
251
- self.data = tensor
252
- self.path = path
253
- self.orientation = "RAS"
254
-
255
- def set_data(self, data):
256
- self.data = data
257
-
258
- def save(self, path):
259
- pass
260
-
261
- org_img = MinimalImage(input_tensor, img_path_str)
262
- org_size = input_tensor.shape[1:] if hasattr(input_tensor, 'shape') else None
263
-
264
- except Exception as e2:
265
- raise ValueError(f"Failed to load image: {str(e2)}")
266
 
267
- # Ensure we have a valid input tensor
268
- if input_tensor is None:
269
- raise ValueError("Could not extract input tensor from loaded image data")
270
-
271
- print(f"DEBUG: Final input_tensor type: {type(input_tensor)}, shape: {getattr(input_tensor, 'shape', 'No shape')}")
272
 
273
  # Perform inference
274
  mask_data = inference(learn, reorder=reorder, resample=resample,
275
- org_img=org_img, input_img=input_tensor,
276
  org_size=org_size)
277
 
278
- # Handle orientation if org_img has orientation attribute
279
- if hasattr(org_img, 'orientation') and "".join(getattr(org_img, 'orientation', '')) == "LSA":
280
  mask_data = mask_data.permute(0,1,3,2)
281
  mask_data = torch.flip(mask_data[0], dims=[1])
282
  mask_data = torch.Tensor(mask_data)[None]
283
 
284
- # Extract data from org_img properly
285
- if hasattr(org_img, 'data'):
286
- img = org_img.data
287
- else:
288
- img = input_tensor # Use input_tensor as fallback
289
-
290
- # Set data and save
291
- if hasattr(org_img, 'set_data'):
292
- org_img.set_data(mask_data)
293
- if hasattr(org_img, 'save'):
294
- org_img.save(save_path)
295
 
296
- slices = extract_slices_from_mask(img[0], mask_data[0], view)
 
297
  fused_images = [(get_fused_image(
298
- normalize_image(slice_img), # Normalize safely
299
  slice_mask, view))
300
  for slice_img, slice_mask in slices]
301
 
 
302
  volume = compute_binary_tumor_volume(org_img)
303
 
304
  return fused_images, round(volume, 2)
@@ -308,6 +178,7 @@ def load_system_resources(models_path, learner_fn='heart_model.pkl', variables_f
308
  """Load the model and other required resources."""
309
  try:
310
  learn = load_learner(models_path / learner_fn)
 
311
  except Exception as e:
312
  raise ValueError(f"Error loading the model: {str(e)}")
313
 
@@ -335,38 +206,62 @@ def load_system_resources(models_path, learner_fn='heart_model.pkl', variables_f
335
  return learn, reorder, resample
336
 
337
  # Initialize the system
338
- clone_dir = Path.cwd() / 'clone_dir'
339
- URI = os.getenv('PAT_Token_URI')
340
-
341
- if not URI:
342
- raise ValueError("PAT_Token_URI environment variable is not set")
343
 
344
- if os.path.exists(clone_dir):
345
- pass
346
- else:
347
- Repo.clone_from(URI, clone_dir)
348
 
349
- models_path = clone_dir
350
- save_dir = Path.cwd() / 'hs_pred'
351
- save_dir.mkdir(parents=True, exist_ok=True)
 
 
 
352
 
353
- # Load the model and other required resources
354
- learn, reorder, resample = load_system_resources(models_path=models_path)
 
 
355
 
356
- # Gradio interface setup
357
- output_text = gr.Textbox(label="Volume of the Left Atrium (mL):")
 
 
 
358
 
359
- view_selector = gr.Radio(choices=["Axial", "Coronal", "Sagittal"], value='Sagittal', label="Select View (Sagittal by default)")
 
 
360
 
361
- # Ensure the example file path is correct
362
- example_path = str(clone_dir / "sample.nii.gz")
 
 
 
 
 
 
363
 
364
- demo = gr.Interface(
365
- fn=lambda fileobj, view='Sagittal': gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view),
366
- inputs=["file", view_selector],
367
- outputs=[gr.Gallery(label="Click an Image, and use Arrow Keys to scroll slices", columns=3, height=450), output_text],
368
- examples=[[example_path, 'Sagittal']], # Added view parameter to example
369
- allow_flagging='never')
 
 
 
 
 
 
 
 
 
370
 
371
  # Launch the Gradio interface
 
372
  demo.launch()
 
10
  from fastai.learner import load_learner
11
  from fastai.basics import load_pickle
12
  import pickle
13
+ import traceback
14
 
15
  # Function to extract slices from mask
16
  def extract_slices_from_mask(img, mask_data, view):
 
84
  """Perform segmentation using the loaded model."""
85
  # Ensure input_img is a torch.Tensor
86
  if not isinstance(input_img, torch.Tensor):
87
+ raise ValueError(f"Expected input_img to be a torch.Tensor, but got {type(input_img)}")
88
+
89
+ # Add batch dimension if needed
90
+ if input_img.dim() == 4:
91
+ input_img = input_img.unsqueeze(0)
 
 
92
 
93
  # Perform the segmentation
94
  with torch.no_grad():
95
+ pred = learn.predict(input_img)
96
 
97
  # Process the prediction if necessary
98
+ mask_data = pred[0] if isinstance(pred, (list, tuple)) else pred
99
 
100
  return mask_data
101
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  # Function for Gradio image segmentation
103
  def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view):
104
  """Predict function using the learner and other resources."""
 
107
  view = 'Sagittal'
108
 
109
  img_path = Path(fileobj.name)
 
 
 
 
110
  save_fn = 'pred_' + img_path.stem
111
  save_path = save_dir / save_fn
112
 
113
+ # Read the medical image - handle variable return values
 
 
 
 
114
  try:
115
+ result = med_img_reader(str(img_path),
116
+ reorder=reorder,
117
+ resample=resample,
118
+ only_tensor=False,
119
+ dtype=torch.Tensor)
 
120
 
121
+ # Handle different return formats
 
 
122
  if isinstance(result, tuple):
123
+ if len(result) == 3:
124
+ org_img, input_img, org_size = result
125
+ elif len(result) == 2:
126
+ org_img, metadata_dict = result
127
+ # Extract the tensor from the ScalarImage object
128
+ input_img = org_img.data # This is the actual tensor!
129
+ org_size = org_img.shape[1:] # Infer from org_img
 
 
 
 
 
130
  else:
131
+ raise ValueError(f"Unexpected number of return values: {len(result)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  else:
133
+ # Single value returned (shouldn't happen with only_tensor=False)
134
+ raise ValueError(f"Expected tuple but got {type(result)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
  except Exception as e:
137
+ print(f"DEBUG: Error in med_img_reader: {str(e)}")
138
+ if 'result' in locals():
139
+ print(f"DEBUG: med_img_reader returned: {type(result)}")
140
+ if hasattr(result, '__len__'):
141
+ print(f"DEBUG: result length: {len(result)}")
142
+ raise ValueError(f"Error reading medical image: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
+ # Ensure input_img is a torch.Tensor
145
+ if not isinstance(input_img, torch.Tensor):
146
+ raise ValueError(f"Expected input_img to be a torch.Tensor, but got {type(input_img)}")
 
 
147
 
148
  # Perform inference
149
  mask_data = inference(learn, reorder=reorder, resample=resample,
150
+ org_img=org_img, input_img=input_img,
151
  org_size=org_size)
152
 
153
+ # Handle orientation if needed
154
+ if hasattr(org_img, 'orientation') and "".join(org_img.orientation) == "LSA":
155
  mask_data = mask_data.permute(0,1,3,2)
156
  mask_data = torch.flip(mask_data[0], dims=[1])
157
  mask_data = torch.Tensor(mask_data)[None]
158
 
159
+ # Save the prediction
160
+ img = org_img.data
161
+ org_img.set_data(mask_data)
162
+ org_img.save(save_path)
 
 
 
 
 
 
 
163
 
164
+ # Extract slices and create fused images
165
+ slices = extract_slices_from_mask(img[0].numpy(), mask_data[0].numpy(), view)
166
  fused_images = [(get_fused_image(
167
+ normalize_image(slice_img),
168
  slice_mask, view))
169
  for slice_img, slice_mask in slices]
170
 
171
+ # Compute volume
172
  volume = compute_binary_tumor_volume(org_img)
173
 
174
  return fused_images, round(volume, 2)
 
178
  """Load the model and other required resources."""
179
  try:
180
  learn = load_learner(models_path / learner_fn)
181
+ print(f"βœ… Model loaded from {models_path / learner_fn}")
182
  except Exception as e:
183
  raise ValueError(f"Error loading the model: {str(e)}")
184
 
 
206
  return learn, reorder, resample
207
 
208
  # Initialize the system
209
+ try:
210
+ print("πŸš€ Initializing application...")
211
+
212
+ clone_dir = Path.cwd() / 'clone_dir'
213
+ URI = os.getenv('PAT_Token_URI')
214
 
215
+ if not URI:
216
+ raise ValueError("PAT_Token_URI environment variable is not set")
 
 
217
 
218
+ if os.path.exists(clone_dir):
219
+ print(f"βœ… Repository already exists at {clone_dir}")
220
+ else:
221
+ print(f"πŸ“₯ Cloning repository from {URI}")
222
+ Repo.clone_from(URI, clone_dir)
223
+ print(f"βœ… Repository cloned to {clone_dir}")
224
 
225
+ models_path = clone_dir
226
+ save_dir = Path.cwd() / 'hs_pred'
227
+ save_dir.mkdir(parents=True, exist_ok=True)
228
+ print(f"βœ… Save directory created at {save_dir}")
229
 
230
+ # Load the model and other required resources
231
+ print("πŸ” Loading model and resources...")
232
+ learn, reorder, resample = load_system_resources(models_path=models_path)
233
+ print(f"βœ… Model loaded successfully")
234
+ print(f"βœ… Reorder: {reorder}, Resample: {resample}")
235
 
236
+ # Gradio interface setup
237
+ output_text = gr.Textbox(label="Volume of the Left Atrium (mL):")
238
+ view_selector = gr.Radio(choices=["Axial", "Coronal", "Sagittal"], value='Sagittal', label="Select View (Sagittal by default)")
239
 
240
+ # Ensure the example file path is correct
241
+ example_path = str(clone_dir / "sample.nii.gz")
242
+ if not os.path.exists(example_path):
243
+ print(f"⚠️ Example file not found: {example_path}")
244
+ examples = None
245
+ else:
246
+ print(f"βœ… Example file found: {example_path}")
247
+ examples = [[example_path]]
248
 
249
+ demo = gr.Interface(
250
+ fn=lambda fileobj, view='Sagittal': gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view),
251
+ inputs=["file", view_selector],
252
+ outputs=[gr.Gallery(label="Click an Image, and use Arrow Keys to scroll slices", columns=3, height=450), output_text],
253
+ examples=examples,
254
+ allow_flagging='never')
255
+
256
+ print("βœ… Gradio interface initialized successfully")
257
+
258
+ except Exception as e:
259
+ print(f"❌ Error during initialization: {str(e)}")
260
+ print(f"Error type: {type(e).__name__}")
261
+ traceback.print_exc()
262
+ # Exit with error code
263
+ exit(1)
264
 
265
  # Launch the Gradio interface
266
+ print("🌐 Launching Gradio interface...")
267
  demo.launch()