drankush-ai commited on
Commit
36d522b
·
verified ·
1 Parent(s): c34fb59

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +134 -83
app.py CHANGED
@@ -10,7 +10,6 @@ import os
10
  from fastai.learner import load_learner
11
  from fastai.basics import load_pickle
12
  import pickle
13
- import traceback
14
 
15
  # Function to extract slices from mask
16
  def extract_slices_from_mask(img, mask_data, view):
@@ -84,18 +83,20 @@ def inference(learn, reorder, resample, org_img, input_img, org_size):
84
  """Perform segmentation using the loaded model."""
85
  # Ensure input_img is a torch.Tensor
86
  if not isinstance(input_img, torch.Tensor):
87
- raise ValueError(f"Expected input_img to be a torch.Tensor, but got {type(input_img)}")
88
-
89
- # Add batch dimension if needed
90
- if input_img.dim() == 4:
91
- input_img = input_img.unsqueeze(0)
 
 
92
 
93
  # Perform the segmentation
94
  with torch.no_grad():
95
- pred = learn.predict(input_img)
96
 
97
  # Process the prediction if necessary
98
- mask_data = pred[0] if isinstance(pred, (list, tuple)) else pred
99
 
100
  return mask_data
101
 
@@ -107,47 +108,122 @@ def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view)
107
  view = 'Sagittal'
108
 
109
  img_path = Path(fileobj.name)
 
 
 
 
110
  save_fn = 'pred_' + img_path.stem
111
  save_path = save_dir / save_fn
112
 
113
- # Read the medical image
114
  try:
115
- org_img, input_img, org_size = med_img_reader(str(img_path),
116
- reorder=reorder,
117
- resample=resample,
118
- only_tensor=False,
119
- dtype=torch.Tensor)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  except Exception as e:
121
- raise ValueError(f"Error reading medical image: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
- # Ensure input_img is a torch.Tensor
124
- if not isinstance(input_img, torch.Tensor):
125
- raise ValueError(f"Expected input_img to be a torch.Tensor, but got {type(input_img)}")
 
 
 
 
126
 
127
- # Perform inference
128
  mask_data = inference(learn, reorder=reorder, resample=resample,
129
- org_img=org_img, input_img=input_img,
130
  org_size=org_size)
131
 
132
- # Handle orientation if needed
133
- if hasattr(org_img, 'orientation') and "".join(org_img.orientation) == "LSA":
134
  mask_data = mask_data.permute(0,1,3,2)
135
  mask_data = torch.flip(mask_data[0], dims=[1])
136
  mask_data = torch.Tensor(mask_data)[None]
137
 
138
- # Save the prediction
139
- img = org_img.data
 
 
 
 
140
  org_img.set_data(mask_data)
141
  org_img.save(save_path)
142
 
143
- # Extract slices and create fused images
144
- slices = extract_slices_from_mask(img[0].numpy(), mask_data[0].numpy(), view)
145
  fused_images = [(get_fused_image(
146
- normalize_image(slice_img),
147
  slice_mask, view))
148
  for slice_img, slice_mask in slices]
149
 
150
- # Compute volume
151
  volume = compute_binary_tumor_volume(org_img)
152
 
153
  return fused_images, round(volume, 2)
@@ -157,7 +233,6 @@ def load_system_resources(models_path, learner_fn='heart_model.pkl', variables_f
157
  """Load the model and other required resources."""
158
  try:
159
  learn = load_learner(models_path / learner_fn)
160
- print(f"✅ Model loaded from {models_path / learner_fn}")
161
  except Exception as e:
162
  raise ValueError(f"Error loading the model: {str(e)}")
163
 
@@ -185,62 +260,38 @@ def load_system_resources(models_path, learner_fn='heart_model.pkl', variables_f
185
  return learn, reorder, resample
186
 
187
  # Initialize the system
188
- try:
189
- print("🚀 Initializing application...")
190
-
191
- clone_dir = Path.cwd() / 'clone_dir'
192
- URI = os.getenv('PAT_Token_URI')
193
 
194
- if not URI:
195
- raise ValueError("PAT_Token_URI environment variable is not set")
196
 
197
- if os.path.exists(clone_dir):
198
- print(f"✅ Repository already exists at {clone_dir}")
199
- else:
200
- print(f"📥 Cloning repository from {URI}")
201
- Repo.clone_from(URI, clone_dir)
202
- print(f"✅ Repository cloned to {clone_dir}")
203
-
204
- models_path = clone_dir
205
- save_dir = Path.cwd() / 'hs_pred'
206
- save_dir.mkdir(parents=True, exist_ok=True)
207
- print(f"✅ Save directory created at {save_dir}")
208
-
209
- # Load the model and other required resources
210
- print("🔍 Loading model and resources...")
211
- learn, reorder, resample = load_system_resources(models_path=models_path)
212
- print(f" Model loaded successfully")
213
- print(f"✅ Reorder: {reorder}, Resample: {resample}")
214
-
215
- # Gradio interface setup
216
- output_text = gr.Textbox(label="Volume of the Left Atrium (mL):")
217
- view_selector = gr.Radio(choices=["Axial", "Coronal", "Sagittal"], value='Sagittal', label="Select View (Sagittal by default)")
218
-
219
- # Ensure the example file path is correct
220
- example_path = str(clone_dir / "sample.nii.gz")
221
- if not os.path.exists(example_path):
222
- print(f"⚠️ Example file not found: {example_path}")
223
- examples = None
224
- else:
225
- print(f"✅ Example file found: {example_path}")
226
- examples = [[example_path]]
227
-
228
- demo = gr.Interface(
229
- fn=lambda fileobj, view='Sagittal': gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view),
230
- inputs=["file", view_selector],
231
- outputs=[gr.Gallery(label="Click an Image, and use Arrow Keys to scroll slices", columns=3, height=450), output_text],
232
- examples=examples,
233
- allow_flagging='never')
234
-
235
- print("✅ Gradio interface initialized successfully")
236
-
237
- except Exception as e:
238
- print(f"❌ Error during initialization: {str(e)}")
239
- print(f"Error type: {type(e).__name__}")
240
- traceback.print_exc()
241
- # Exit with error code
242
- exit(1)
243
 
244
  # Launch the Gradio interface
245
- print("🌐 Launching Gradio interface...")
246
  demo.launch()
 
10
  from fastai.learner import load_learner
11
  from fastai.basics import load_pickle
12
  import pickle
 
13
 
14
  # Function to extract slices from mask
15
  def extract_slices_from_mask(img, mask_data, view):
 
83
  """Perform segmentation using the loaded model."""
84
  # Ensure input_img is a torch.Tensor
85
  if not isinstance(input_img, torch.Tensor):
86
+ # If input_img is not a tensor, try to extract the tensor data
87
+ if hasattr(input_img, 'data'):
88
+ input_tensor = input_img.data
89
+ else:
90
+ raise ValueError(f"Expected input_img to be a torch.Tensor or have a 'data' attribute, but got {type(input_img)}")
91
+ else:
92
+ input_tensor = input_img
93
 
94
  # Perform the segmentation
95
  with torch.no_grad():
96
+ pred = learn.predict(input_tensor)
97
 
98
  # Process the prediction if necessary
99
+ mask_data = pred[0] # Assuming the first element of the prediction is the mask
100
 
101
  return mask_data
102
 
 
108
  view = 'Sagittal'
109
 
110
  img_path = Path(fileobj.name)
111
+
112
+ # Convert PosixPath to string
113
+ img_path_str = str(img_path)
114
+
115
  save_fn = 'pred_' + img_path.stem
116
  save_path = save_dir / save_fn
117
 
118
+ # Debug: Let's check what med_img_reader actually returns
119
  try:
120
+ # First try with only_tensor=False to get all values
121
+ result = med_img_reader(img_path_str,
122
+ reorder=reorder,
123
+ resample=resample,
124
+ only_tensor=False,
125
+ dtype=torch.Tensor)
126
+
127
+ # Debug print to understand the structure
128
+ print(f"DEBUG: med_img_reader returned type: {type(result)}")
129
+
130
+ # Handle different return types
131
+ if isinstance(result, tuple):
132
+ if len(result) == 3:
133
+ org_img, input_img, org_size = result
134
+ elif len(result) == 2:
135
+ org_img, input_img = result
136
+ # Infer org_size from org_img
137
+ if hasattr(org_img, 'shape'):
138
+ org_size = org_img.shape[1:]
139
+ else:
140
+ org_size = None
141
+ else:
142
+ raise ValueError(f"Unexpected tuple length from med_img_reader: {len(result)}")
143
+ elif isinstance(result, dict):
144
+ # If it's a dict, try to extract the needed values
145
+ # This might happen if the function returns something wrapped
146
+ print(f"DEBUG: Dictionary keys: {result.keys() if isinstance(result, dict) else 'N/A'}")
147
+ # Try to extract values - adjust based on actual keys
148
+ org_img = result.get('org_img', result.get('original', None))
149
+ input_img = result.get('input_img', result.get('input', result.get('data', None)))
150
+ org_size = result.get('org_size', result.get('size', None))
151
+
152
+ if org_img is None or input_img is None:
153
+ # If we can't extract from dict, try getting tensor directly
154
+ input_tensor = med_img_reader(img_path_str,
155
+ reorder=reorder,
156
+ resample=resample,
157
+ only_tensor=True,
158
+ dtype=torch.Tensor)
159
+ # Create dummy org_img
160
+ from torchio import ScalarImage
161
+ org_img = ScalarImage(img_path_str)
162
+ input_img = input_tensor
163
+ org_size = input_tensor.shape[1:] if hasattr(input_tensor, 'shape') else None
164
+ else:
165
+ # If it's neither tuple nor dict, it might be the tensor directly
166
+ # Try to get the full data with only_tensor=True
167
+ input_tensor = med_img_reader(img_path_str,
168
+ reorder=reorder,
169
+ resample=resample,
170
+ only_tensor=True,
171
+ dtype=torch.Tensor)
172
+ # Create dummy org_img
173
+ from torchio import ScalarImage
174
+ org_img = ScalarImage(img_path_str)
175
+ input_img = input_tensor
176
+ org_size = input_tensor.shape[1:] if hasattr(input_tensor, 'shape') else None
177
+
178
  except Exception as e:
179
+ print(f"DEBUG: Error in med_img_reader: {str(e)}")
180
+ # Fallback: try with only_tensor=True
181
+ try:
182
+ input_tensor = med_img_reader(img_path_str,
183
+ reorder=reorder,
184
+ resample=resample,
185
+ only_tensor=True,
186
+ dtype=torch.Tensor)
187
+ # Create dummy org_img
188
+ from torchio import ScalarImage
189
+ org_img = ScalarImage(img_path_str)
190
+ input_img = input_tensor
191
+ org_size = input_tensor.shape[1:] if hasattr(input_tensor, 'shape') else None
192
+ except Exception as e2:
193
+ raise ValueError(f"Failed to load image: {str(e2)}")
194
 
195
+ # Ensure input_img is proper format for inference
196
+ if hasattr(input_img, 'data') and isinstance(input_img.data, torch.Tensor):
197
+ input_tensor = input_img.data
198
+ elif isinstance(input_img, torch.Tensor):
199
+ input_tensor = input_img
200
+ else:
201
+ raise ValueError(f"Cannot extract tensor from input_img of type {type(input_img)}")
202
 
 
203
  mask_data = inference(learn, reorder=reorder, resample=resample,
204
+ org_img=org_img, input_img=input_tensor,
205
  org_size=org_size)
206
 
207
+ if "".join(org_img.orientation) == "LSA":
 
208
  mask_data = mask_data.permute(0,1,3,2)
209
  mask_data = torch.flip(mask_data[0], dims=[1])
210
  mask_data = torch.Tensor(mask_data)[None]
211
 
212
+ # Extract data from org_img properly
213
+ if hasattr(org_img, 'data'):
214
+ img = org_img.data
215
+ else:
216
+ img = org_img
217
+
218
  org_img.set_data(mask_data)
219
  org_img.save(save_path)
220
 
221
+ slices = extract_slices_from_mask(img[0], mask_data[0], view)
 
222
  fused_images = [(get_fused_image(
223
+ normalize_image(slice_img), # Normalize safely
224
  slice_mask, view))
225
  for slice_img, slice_mask in slices]
226
 
 
227
  volume = compute_binary_tumor_volume(org_img)
228
 
229
  return fused_images, round(volume, 2)
 
233
  """Load the model and other required resources."""
234
  try:
235
  learn = load_learner(models_path / learner_fn)
 
236
  except Exception as e:
237
  raise ValueError(f"Error loading the model: {str(e)}")
238
 
 
260
  return learn, reorder, resample
261
 
262
  # Initialize the system
263
+ clone_dir = Path.cwd() / 'clone_dir'
264
+ URI = os.getenv('PAT_Token_URI')
 
 
 
265
 
266
+ if not URI:
267
+ raise ValueError("PAT_Token_URI environment variable is not set")
268
 
269
+ if os.path.exists(clone_dir):
270
+ pass
271
+ else:
272
+ Repo.clone_from(URI, clone_dir)
273
+
274
+ models_path = clone_dir
275
+ save_dir = Path.cwd() / 'hs_pred'
276
+ save_dir.mkdir(parents=True, exist_ok=True)
277
+
278
+ # Load the model and other required resources
279
+ learn, reorder, resample = load_system_resources(models_path=models_path)
280
+
281
+ # Gradio interface setup
282
+ output_text = gr.Textbox(label="Volume of the Left Atrium (mL):")
283
+
284
+ view_selector = gr.Radio(choices=["Axial", "Coronal", "Sagittal"], value='Sagittal', label="Select View (Sagittal by default)")
285
+
286
+ # Ensure the example file path is correct
287
+ example_path = str(clone_dir / "sample.nii.gz")
288
+
289
+ demo = gr.Interface(
290
+ fn=lambda fileobj, view='Sagittal': gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view),
291
+ inputs=["file", view_selector],
292
+ outputs=[gr.Gallery(label="Click an Image, and use Arrow Keys to scroll slices", columns=3, height=450), output_text],
293
+ examples=[[example_path]],
294
+ allow_flagging='never')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
 
296
  # Launch the Gradio interface
 
297
  demo.launch()