VictorLJZ committed on
Commit
fffa1c9
·
1 Parent(s): dacb34b

working version

Browse files
Files changed (3) hide show
  1. interface.py +4 -28
  2. main.py +3 -3
  3. medrax/tools/medsam2.py +60 -53
interface.py CHANGED
@@ -218,6 +218,9 @@ class ChatInterface:
218
  if tool_name == "image_visualizer":
219
  try:
220
  result = json.loads(msg.content)
 
 
 
221
  if isinstance(result, dict) and "image_path" in result:
222
  self.display_file_path = result["image_path"]
223
  chat_history.append(
@@ -229,33 +232,6 @@ class ChatInterface:
229
  yield chat_history, self.display_file_path, ""
230
  except (json.JSONDecodeError, TypeError):
231
  pass
232
-
233
- elif tool_name == "medsam2_segmentation":
234
- try:
235
- result = json.loads(msg.content)
236
- # Handle both single dict and potential list format
237
- if isinstance(result, list) and len(result) > 0:
238
- result = result[0]
239
-
240
- if isinstance(result, dict):
241
- # Look for visualization path in multiple possible keys
242
- viz_path = None
243
- for key in ["visualization_path", "image_path", "segmentation_image_path"]:
244
- if key in result:
245
- viz_path = result[key]
246
- break
247
-
248
- if viz_path:
249
- self.display_file_path = viz_path
250
- chat_history.append(
251
- ChatMessage(
252
- role="assistant",
253
- content={"path": self.display_file_path},
254
- )
255
- )
256
- yield chat_history, self.display_file_path, ""
257
- except (json.JSONDecodeError, TypeError):
258
- pass
259
 
260
  except Exception as e:
261
  chat_history.append(
@@ -358,4 +334,4 @@ def create_demo(agent, tools_dict):
358
  clear_btn.click(clear_chat, outputs=[chatbot, image_display])
359
  new_thread_btn.click(new_thread, outputs=[chatbot, image_display])
360
 
361
- return demo
 
218
  if tool_name == "image_visualizer":
219
  try:
220
  result = json.loads(msg.content)
221
+ # Handle case where tool returns array [output, metadata]
222
+ if isinstance(result, list) and len(result) > 0:
223
+ result = result[0] # Take the first element (output)
224
  if isinstance(result, dict) and "image_path" in result:
225
  self.display_file_path = result["image_path"]
226
  chat_history.append(
 
232
  yield chat_history, self.display_file_path, ""
233
  except (json.JSONDecodeError, TypeError):
234
  pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
 
236
  except Exception as e:
237
  chat_history.append(
 
334
  clear_btn.click(clear_chat, outputs=[chatbot, image_display])
335
  new_thread_btn.click(new_thread, outputs=[chatbot, image_display])
336
 
337
+ return demo
main.py CHANGED
@@ -141,7 +141,7 @@ if __name__ == "__main__":
141
  # Example: initialize with only specific tools
142
  # Here three tools are commented out, you can uncomment them to use them
143
  selected_tools = [
144
- # "ImageVisualizerTool", # For displaying images in the UI
145
  # "DicomProcessorTool", # For processing DICOM medical image files
146
  # "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
147
  # "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
@@ -152,8 +152,8 @@ if __name__ == "__main__":
152
  # "XRayPhraseGroundingTool", # For locating described features in X-rays
153
  # "ChestXRayGeneratorTool", # For generating synthetic chest X-rays
154
  "MedSAM2Tool", # For advanced medical image segmentation using MedSAM2
155
- "WebBrowserTool", # For web browsing and search capabilities
156
- "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
157
  # "PythonSandboxTool", # Add the Python sandbox tool
158
  ]
159
 
 
141
  # Example: initialize with only specific tools
142
  # Here three tools are commented out, you can uncomment them to use them
143
  selected_tools = [
144
+ "ImageVisualizerTool", # For displaying images in the UI
145
  # "DicomProcessorTool", # For processing DICOM medical image files
146
  # "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
147
  # "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
 
152
  # "XRayPhraseGroundingTool", # For locating described features in X-rays
153
  # "ChestXRayGeneratorTool", # For generating synthetic chest X-rays
154
  "MedSAM2Tool", # For advanced medical image segmentation using MedSAM2
155
+ # "WebBrowserTool", # For web browsing and search capabilities
156
+ # "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
157
  # "PythonSandboxTool", # Add the Python sandbox tool
158
  ]
159
 
medrax/tools/medsam2.py CHANGED
@@ -170,47 +170,55 @@ class MedSAM2Tool(BaseTool):
170
 
171
  def _create_visualization(self, image: np.ndarray, masks: np.ndarray, prompt_info: Dict) -> str:
172
  """Create visualization of segmentation results."""
173
- plt.figure(figsize=(12, 8))
174
 
175
- # Display original image
176
- plt.subplot(1, 2, 1)
177
- plt.imshow(image)
178
- plt.title("Original Image")
179
- plt.axis('off')
 
 
 
 
 
 
180
 
181
- # Display segmentation overlay
182
- plt.subplot(1, 2, 2)
183
- plt.imshow(image)
184
 
185
- # Overlay masks
186
- if len(masks) > 0:
187
- # Use the best mask (first one returned by SAM2)
188
- mask = masks[0]
189
- # Convert mask to boolean and ensure proper shape
190
- mask_bool = mask.astype(bool)
191
- colored_mask = np.zeros((*mask_bool.shape, 4))
192
- colored_mask[mask_bool] = [1, 0, 0, 0.5] # Red with transparency
193
- plt.imshow(colored_mask)
 
 
 
 
 
194
 
195
- # Add prompt visualization
196
  if prompt_info.get('box') is not None:
197
  box = prompt_info['box'][0]
198
  x1, y1, x2, y2 = box
199
- plt.plot([x1, x2, x2, x1, x1], [y1, y1, y2, y2, y1], 'g-', linewidth=2)
200
  plt.plot([x1, x2, x2, x1, x1], [y1, y1, y2, y2, y1], 'g-', linewidth=2, label='Box Prompt')
201
 
202
  if prompt_info.get('point') is not None:
203
  point = prompt_info['point'][0]
204
  plt.plot(point[0], point[1], 'go', markersize=10, label='Point Prompt')
205
 
206
- plt.title("Segmentation Result")
207
- plt.axis('off')
208
- if prompt_info.get('box') is not None or prompt_info.get('point') is not None:
209
- plt.legend()
210
 
211
- # Save visualization
212
  viz_path = self.temp_dir / f"medsam2_result_{uuid.uuid4().hex[:8]}.png"
213
- plt.savefig(viz_path, bbox_inches='tight', dpi=150)
214
  plt.close()
215
 
216
  return str(viz_path)
@@ -222,7 +230,7 @@ class MedSAM2Tool(BaseTool):
222
  prompt_coords: Optional[List[int]] = None,
223
  slice_index: Optional[int] = None,
224
  run_manager: Optional[CallbackManagerForToolRun] = None,
225
- ) -> Dict[str, Any]:
226
  """Run MedSAM2 segmentation on the input image."""
227
  try:
228
  # Load image
@@ -266,15 +274,15 @@ class MedSAM2Tool(BaseTool):
266
  prompt_info = {
267
  'box': input_box,
268
  'point': input_point,
269
- 'type': prompt_type
 
270
  }
271
  viz_path = self._create_visualization(image, masks, prompt_info)
272
 
273
- # Process results (exclude large mask arrays to avoid token limits)
274
- results = {
275
- "success": True,
276
  "confidence_scores": scores.tolist() if hasattr(scores, 'tolist') else list(scores),
277
- "visualization_path": viz_path,
278
  "num_masks": len(masks),
279
  "best_mask_score": float(scores[0]) if len(scores) > 0 else 0.0,
280
  "mask_summary": {
@@ -282,31 +290,30 @@ class MedSAM2Tool(BaseTool):
282
  "mask_shapes": [list(mask.shape) for mask in masks],
283
  "segmented_area_pixels": [int(mask.sum()) for mask in masks]
284
  },
285
- # Include metadata in the main results
286
- "metadata": {
287
- "image_path": image_path,
288
- "image_shape": list(image.shape),
289
- "prompt_type": prompt_type,
290
- "prompt_coords": prompt_coords,
291
- "device": self.device,
292
- "num_masks_generated": len(masks),
293
- "analysis_status": "completed",
294
- }
295
  }
296
 
297
- return results
 
 
 
 
 
 
 
 
 
 
 
 
298
 
299
  except Exception as e:
300
- error_result = {
301
- "error": str(e),
302
- "success": False,
303
- "metadata": {
304
- "image_path": image_path,
305
- "analysis_status": "failed",
306
- "error_details": str(e),
307
- }
308
  }
309
- return error_result
310
 
311
  async def _arun(
312
  self,
@@ -315,6 +322,6 @@ class MedSAM2Tool(BaseTool):
315
  prompt_coords: Optional[List[int]] = None,
316
  slice_index: Optional[int] = None,
317
  run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
318
- ) -> Dict[str, Any]:
319
  """Async version of _run."""
320
  return self._run(image_path, prompt_type, prompt_coords, slice_index, run_manager)
 
170
 
171
  def _create_visualization(self, image: np.ndarray, masks: np.ndarray, prompt_info: Dict) -> str:
172
  """Create visualization of segmentation results."""
173
+ plt.figure(figsize=(10, 10))
174
 
175
+ # Convert RGB image to grayscale for background display
176
+ if len(image.shape) == 3:
177
+ # Convert RGB to grayscale using standard luminance formula
178
+ gray_image = 0.299 * image[:,:,0] + 0.587 * image[:,:,1] + 0.114 * image[:,:,2]
179
+ else:
180
+ gray_image = image
181
+
182
+ # Display grayscale background
183
+ plt.imshow(
184
+ gray_image, cmap="gray", extent=[0, image.shape[1], image.shape[0], 0]
185
+ )
186
 
187
+ # Generate color palette for multiple masks
188
+ colors = plt.cm.rainbow(np.linspace(0, 1, len(masks)))
 
189
 
190
+ # Process and overlay each mask
191
+ for idx, (mask, color) in enumerate(zip(masks, colors)):
192
+ if mask.sum() > 0:
193
+ # Convert mask to boolean and ensure proper shape
194
+ mask_bool = mask.astype(bool)
195
+ colored_mask = np.zeros((*mask_bool.shape, 4))
196
+ colored_mask[mask_bool] = (*color[:3], 0.3) # 30% transparency like segmentation tool
197
+ plt.imshow(
198
+ colored_mask, extent=[0, image.shape[1], image.shape[0], 0]
199
+ )
200
+
201
+ # Add legend entry for each mask
202
+ mask_label = f"Mask {idx + 1} (score: {prompt_info.get('scores', [0])[idx] if idx < len(prompt_info.get('scores', [])) else 0:.3f})"
203
+ plt.plot([], [], color=color, label=mask_label, linewidth=3)
204
 
205
+ # Add prompt visualization with consistent styling
206
  if prompt_info.get('box') is not None:
207
  box = prompt_info['box'][0]
208
  x1, y1, x2, y2 = box
 
209
  plt.plot([x1, x2, x2, x1, x1], [y1, y1, y2, y2, y1], 'g-', linewidth=2, label='Box Prompt')
210
 
211
  if prompt_info.get('point') is not None:
212
  point = prompt_info['point'][0]
213
  plt.plot(point[0], point[1], 'go', markersize=10, label='Point Prompt')
214
 
215
+ plt.title("Segmentation Overlay")
216
+ plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
217
+ plt.axis("off")
 
218
 
219
+ # Save visualization with higher DPI like segmentation tool
220
  viz_path = self.temp_dir / f"medsam2_result_{uuid.uuid4().hex[:8]}.png"
221
+ plt.savefig(viz_path, bbox_inches='tight', dpi=300)
222
  plt.close()
223
 
224
  return str(viz_path)
 
230
  prompt_coords: Optional[List[int]] = None,
231
  slice_index: Optional[int] = None,
232
  run_manager: Optional[CallbackManagerForToolRun] = None,
233
+ ) -> Tuple[Dict[str, Any], Dict]:
234
  """Run MedSAM2 segmentation on the input image."""
235
  try:
236
  # Load image
 
274
  prompt_info = {
275
  'box': input_box,
276
  'point': input_point,
277
+ 'type': prompt_type,
278
+ 'scores': scores # Add scores for legend display
279
  }
280
  viz_path = self._create_visualization(image, masks, prompt_info)
281
 
282
+ # Create output dictionary (main results)
283
+ output = {
284
+ "segmentation_image_path": viz_path,
285
  "confidence_scores": scores.tolist() if hasattr(scores, 'tolist') else list(scores),
 
286
  "num_masks": len(masks),
287
  "best_mask_score": float(scores[0]) if len(scores) > 0 else 0.0,
288
  "mask_summary": {
 
290
  "mask_shapes": [list(mask.shape) for mask in masks],
291
  "segmented_area_pixels": [int(mask.sum()) for mask in masks]
292
  },
 
 
 
 
 
 
 
 
 
 
293
  }
294
 
295
+ # Create metadata dictionary
296
+ metadata = {
297
+ "image_path": image_path,
298
+ "segmentation_image_path": viz_path,
299
+ "image_shape": list(image.shape),
300
+ "prompt_type": prompt_type,
301
+ "prompt_coords": prompt_coords,
302
+ "device": self.device,
303
+ "num_masks_generated": len(masks),
304
+ "analysis_status": "completed",
305
+ }
306
+
307
+ return output, metadata
308
 
309
  except Exception as e:
310
+ error_output = {"error": str(e)}
311
+ error_metadata = {
312
+ "image_path": image_path,
313
+ "analysis_status": "failed",
314
+ "error_details": str(e),
 
 
 
315
  }
316
+ return error_output, error_metadata
317
 
318
  async def _arun(
319
  self,
 
322
  prompt_coords: Optional[List[int]] = None,
323
  slice_index: Optional[int] = None,
324
  run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
325
+ ) -> Tuple[Dict[str, Any], Dict]:
326
  """Async version of _run."""
327
  return self._run(image_path, prompt_type, prompt_coords, slice_index, run_manager)