Spaces:
Paused
Paused
working version
Browse files- interface.py +4 -28
- main.py +3 -3
- medrax/tools/medsam2.py +60 -53
interface.py
CHANGED
|
@@ -218,6 +218,9 @@ class ChatInterface:
|
|
| 218 |
if tool_name == "image_visualizer":
|
| 219 |
try:
|
| 220 |
result = json.loads(msg.content)
|
|
|
|
|
|
|
|
|
|
| 221 |
if isinstance(result, dict) and "image_path" in result:
|
| 222 |
self.display_file_path = result["image_path"]
|
| 223 |
chat_history.append(
|
|
@@ -229,33 +232,6 @@ class ChatInterface:
|
|
| 229 |
yield chat_history, self.display_file_path, ""
|
| 230 |
except (json.JSONDecodeError, TypeError):
|
| 231 |
pass
|
| 232 |
-
|
| 233 |
-
elif tool_name == "medsam2_segmentation":
|
| 234 |
-
try:
|
| 235 |
-
result = json.loads(msg.content)
|
| 236 |
-
# Handle both single dict and potential list format
|
| 237 |
-
if isinstance(result, list) and len(result) > 0:
|
| 238 |
-
result = result[0]
|
| 239 |
-
|
| 240 |
-
if isinstance(result, dict):
|
| 241 |
-
# Look for visualization path in multiple possible keys
|
| 242 |
-
viz_path = None
|
| 243 |
-
for key in ["visualization_path", "image_path", "segmentation_image_path"]:
|
| 244 |
-
if key in result:
|
| 245 |
-
viz_path = result[key]
|
| 246 |
-
break
|
| 247 |
-
|
| 248 |
-
if viz_path:
|
| 249 |
-
self.display_file_path = viz_path
|
| 250 |
-
chat_history.append(
|
| 251 |
-
ChatMessage(
|
| 252 |
-
role="assistant",
|
| 253 |
-
content={"path": self.display_file_path},
|
| 254 |
-
)
|
| 255 |
-
)
|
| 256 |
-
yield chat_history, self.display_file_path, ""
|
| 257 |
-
except (json.JSONDecodeError, TypeError):
|
| 258 |
-
pass
|
| 259 |
|
| 260 |
except Exception as e:
|
| 261 |
chat_history.append(
|
|
@@ -358,4 +334,4 @@ def create_demo(agent, tools_dict):
|
|
| 358 |
clear_btn.click(clear_chat, outputs=[chatbot, image_display])
|
| 359 |
new_thread_btn.click(new_thread, outputs=[chatbot, image_display])
|
| 360 |
|
| 361 |
-
return demo
|
|
|
|
| 218 |
if tool_name == "image_visualizer":
|
| 219 |
try:
|
| 220 |
result = json.loads(msg.content)
|
| 221 |
+
# Handle case where tool returns array [output, metadata]
|
| 222 |
+
if isinstance(result, list) and len(result) > 0:
|
| 223 |
+
result = result[0] # Take the first element (output)
|
| 224 |
if isinstance(result, dict) and "image_path" in result:
|
| 225 |
self.display_file_path = result["image_path"]
|
| 226 |
chat_history.append(
|
|
|
|
| 232 |
yield chat_history, self.display_file_path, ""
|
| 233 |
except (json.JSONDecodeError, TypeError):
|
| 234 |
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
|
| 236 |
except Exception as e:
|
| 237 |
chat_history.append(
|
|
|
|
| 334 |
clear_btn.click(clear_chat, outputs=[chatbot, image_display])
|
| 335 |
new_thread_btn.click(new_thread, outputs=[chatbot, image_display])
|
| 336 |
|
| 337 |
+
return demo
|
main.py
CHANGED
|
@@ -141,7 +141,7 @@ if __name__ == "__main__":
|
|
| 141 |
# Example: initialize with only specific tools
|
| 142 |
# Here three tools are commented out, you can uncomment them to use them
|
| 143 |
selected_tools = [
|
| 144 |
-
|
| 145 |
# "DicomProcessorTool", # For processing DICOM medical image files
|
| 146 |
# "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
|
| 147 |
# "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
|
|
@@ -152,8 +152,8 @@ if __name__ == "__main__":
|
|
| 152 |
# "XRayPhraseGroundingTool", # For locating described features in X-rays
|
| 153 |
# "ChestXRayGeneratorTool", # For generating synthetic chest X-rays
|
| 154 |
"MedSAM2Tool", # For advanced medical image segmentation using MedSAM2
|
| 155 |
-
"WebBrowserTool", # For web browsing and search capabilities
|
| 156 |
-
"MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
|
| 157 |
# "PythonSandboxTool", # Add the Python sandbox tool
|
| 158 |
]
|
| 159 |
|
|
|
|
| 141 |
# Example: initialize with only specific tools
|
| 142 |
# Here three tools are commented out, you can uncomment them to use them
|
| 143 |
selected_tools = [
|
| 144 |
+
"ImageVisualizerTool", # For displaying images in the UI
|
| 145 |
# "DicomProcessorTool", # For processing DICOM medical image files
|
| 146 |
# "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
|
| 147 |
# "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
|
|
|
|
| 152 |
# "XRayPhraseGroundingTool", # For locating described features in X-rays
|
| 153 |
# "ChestXRayGeneratorTool", # For generating synthetic chest X-rays
|
| 154 |
"MedSAM2Tool", # For advanced medical image segmentation using MedSAM2
|
| 155 |
+
# "WebBrowserTool", # For web browsing and search capabilities
|
| 156 |
+
# "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
|
| 157 |
# "PythonSandboxTool", # Add the Python sandbox tool
|
| 158 |
]
|
| 159 |
|
medrax/tools/medsam2.py
CHANGED
|
@@ -170,47 +170,55 @@ class MedSAM2Tool(BaseTool):
|
|
| 170 |
|
| 171 |
def _create_visualization(self, image: np.ndarray, masks: np.ndarray, prompt_info: Dict) -> str:
|
| 172 |
"""Create visualization of segmentation results."""
|
| 173 |
-
plt.figure(figsize=(
|
| 174 |
|
| 175 |
-
#
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
|
| 181 |
-
#
|
| 182 |
-
plt.
|
| 183 |
-
plt.imshow(image)
|
| 184 |
|
| 185 |
-
#
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
|
| 195 |
-
# Add prompt visualization
|
| 196 |
if prompt_info.get('box') is not None:
|
| 197 |
box = prompt_info['box'][0]
|
| 198 |
x1, y1, x2, y2 = box
|
| 199 |
-
plt.plot([x1, x2, x2, x1, x1], [y1, y1, y2, y2, y1], 'g-', linewidth=2)
|
| 200 |
plt.plot([x1, x2, x2, x1, x1], [y1, y1, y2, y2, y1], 'g-', linewidth=2, label='Box Prompt')
|
| 201 |
|
| 202 |
if prompt_info.get('point') is not None:
|
| 203 |
point = prompt_info['point'][0]
|
| 204 |
plt.plot(point[0], point[1], 'go', markersize=10, label='Point Prompt')
|
| 205 |
|
| 206 |
-
plt.title("Segmentation
|
| 207 |
-
plt.
|
| 208 |
-
|
| 209 |
-
plt.legend()
|
| 210 |
|
| 211 |
-
# Save visualization
|
| 212 |
viz_path = self.temp_dir / f"medsam2_result_{uuid.uuid4().hex[:8]}.png"
|
| 213 |
-
plt.savefig(viz_path, bbox_inches='tight', dpi=
|
| 214 |
plt.close()
|
| 215 |
|
| 216 |
return str(viz_path)
|
|
@@ -222,7 +230,7 @@ class MedSAM2Tool(BaseTool):
|
|
| 222 |
prompt_coords: Optional[List[int]] = None,
|
| 223 |
slice_index: Optional[int] = None,
|
| 224 |
run_manager: Optional[CallbackManagerForToolRun] = None,
|
| 225 |
-
) -> Dict[str, Any]:
|
| 226 |
"""Run MedSAM2 segmentation on the input image."""
|
| 227 |
try:
|
| 228 |
# Load image
|
|
@@ -266,15 +274,15 @@ class MedSAM2Tool(BaseTool):
|
|
| 266 |
prompt_info = {
|
| 267 |
'box': input_box,
|
| 268 |
'point': input_point,
|
| 269 |
-
'type': prompt_type
|
|
|
|
| 270 |
}
|
| 271 |
viz_path = self._create_visualization(image, masks, prompt_info)
|
| 272 |
|
| 273 |
-
#
|
| 274 |
-
|
| 275 |
-
"
|
| 276 |
"confidence_scores": scores.tolist() if hasattr(scores, 'tolist') else list(scores),
|
| 277 |
-
"visualization_path": viz_path,
|
| 278 |
"num_masks": len(masks),
|
| 279 |
"best_mask_score": float(scores[0]) if len(scores) > 0 else 0.0,
|
| 280 |
"mask_summary": {
|
|
@@ -282,31 +290,30 @@ class MedSAM2Tool(BaseTool):
|
|
| 282 |
"mask_shapes": [list(mask.shape) for mask in masks],
|
| 283 |
"segmented_area_pixels": [int(mask.sum()) for mask in masks]
|
| 284 |
},
|
| 285 |
-
# Include metadata in the main results
|
| 286 |
-
"metadata": {
|
| 287 |
-
"image_path": image_path,
|
| 288 |
-
"image_shape": list(image.shape),
|
| 289 |
-
"prompt_type": prompt_type,
|
| 290 |
-
"prompt_coords": prompt_coords,
|
| 291 |
-
"device": self.device,
|
| 292 |
-
"num_masks_generated": len(masks),
|
| 293 |
-
"analysis_status": "completed",
|
| 294 |
-
}
|
| 295 |
}
|
| 296 |
|
| 297 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
|
| 299 |
except Exception as e:
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
"
|
| 303 |
-
"
|
| 304 |
-
|
| 305 |
-
"analysis_status": "failed",
|
| 306 |
-
"error_details": str(e),
|
| 307 |
-
}
|
| 308 |
}
|
| 309 |
-
return
|
| 310 |
|
| 311 |
async def _arun(
|
| 312 |
self,
|
|
@@ -315,6 +322,6 @@ class MedSAM2Tool(BaseTool):
|
|
| 315 |
prompt_coords: Optional[List[int]] = None,
|
| 316 |
slice_index: Optional[int] = None,
|
| 317 |
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
| 318 |
-
) -> Dict[str, Any]:
|
| 319 |
"""Async version of _run."""
|
| 320 |
return self._run(image_path, prompt_type, prompt_coords, slice_index, run_manager)
|
|
|
|
| 170 |
|
| 171 |
def _create_visualization(self, image: np.ndarray, masks: np.ndarray, prompt_info: Dict) -> str:
|
| 172 |
"""Create visualization of segmentation results."""
|
| 173 |
+
plt.figure(figsize=(10, 10))
|
| 174 |
|
| 175 |
+
# Convert RGB image to grayscale for background display
|
| 176 |
+
if len(image.shape) == 3:
|
| 177 |
+
# Convert RGB to grayscale using standard luminance formula
|
| 178 |
+
gray_image = 0.299 * image[:,:,0] + 0.587 * image[:,:,1] + 0.114 * image[:,:,2]
|
| 179 |
+
else:
|
| 180 |
+
gray_image = image
|
| 181 |
+
|
| 182 |
+
# Display grayscale background
|
| 183 |
+
plt.imshow(
|
| 184 |
+
gray_image, cmap="gray", extent=[0, image.shape[1], image.shape[0], 0]
|
| 185 |
+
)
|
| 186 |
|
| 187 |
+
# Generate color palette for multiple masks
|
| 188 |
+
colors = plt.cm.rainbow(np.linspace(0, 1, len(masks)))
|
|
|
|
| 189 |
|
| 190 |
+
# Process and overlay each mask
|
| 191 |
+
for idx, (mask, color) in enumerate(zip(masks, colors)):
|
| 192 |
+
if mask.sum() > 0:
|
| 193 |
+
# Convert mask to boolean and ensure proper shape
|
| 194 |
+
mask_bool = mask.astype(bool)
|
| 195 |
+
colored_mask = np.zeros((*mask_bool.shape, 4))
|
| 196 |
+
colored_mask[mask_bool] = (*color[:3], 0.3) # 30% transparency like segmentation tool
|
| 197 |
+
plt.imshow(
|
| 198 |
+
colored_mask, extent=[0, image.shape[1], image.shape[0], 0]
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
# Add legend entry for each mask
|
| 202 |
+
mask_label = f"Mask {idx + 1} (score: {prompt_info.get('scores', [0])[idx] if idx < len(prompt_info.get('scores', [])) else 0:.3f})"
|
| 203 |
+
plt.plot([], [], color=color, label=mask_label, linewidth=3)
|
| 204 |
|
| 205 |
+
# Add prompt visualization with consistent styling
|
| 206 |
if prompt_info.get('box') is not None:
|
| 207 |
box = prompt_info['box'][0]
|
| 208 |
x1, y1, x2, y2 = box
|
|
|
|
| 209 |
plt.plot([x1, x2, x2, x1, x1], [y1, y1, y2, y2, y1], 'g-', linewidth=2, label='Box Prompt')
|
| 210 |
|
| 211 |
if prompt_info.get('point') is not None:
|
| 212 |
point = prompt_info['point'][0]
|
| 213 |
plt.plot(point[0], point[1], 'go', markersize=10, label='Point Prompt')
|
| 214 |
|
| 215 |
+
plt.title("Segmentation Overlay")
|
| 216 |
+
plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
|
| 217 |
+
plt.axis("off")
|
|
|
|
| 218 |
|
| 219 |
+
# Save visualization with higher DPI like segmentation tool
|
| 220 |
viz_path = self.temp_dir / f"medsam2_result_{uuid.uuid4().hex[:8]}.png"
|
| 221 |
+
plt.savefig(viz_path, bbox_inches='tight', dpi=300)
|
| 222 |
plt.close()
|
| 223 |
|
| 224 |
return str(viz_path)
|
|
|
|
| 230 |
prompt_coords: Optional[List[int]] = None,
|
| 231 |
slice_index: Optional[int] = None,
|
| 232 |
run_manager: Optional[CallbackManagerForToolRun] = None,
|
| 233 |
+
) -> Tuple[Dict[str, Any], Dict]:
|
| 234 |
"""Run MedSAM2 segmentation on the input image."""
|
| 235 |
try:
|
| 236 |
# Load image
|
|
|
|
| 274 |
prompt_info = {
|
| 275 |
'box': input_box,
|
| 276 |
'point': input_point,
|
| 277 |
+
'type': prompt_type,
|
| 278 |
+
'scores': scores # Add scores for legend display
|
| 279 |
}
|
| 280 |
viz_path = self._create_visualization(image, masks, prompt_info)
|
| 281 |
|
| 282 |
+
# Create output dictionary (main results)
|
| 283 |
+
output = {
|
| 284 |
+
"segmentation_image_path": viz_path,
|
| 285 |
"confidence_scores": scores.tolist() if hasattr(scores, 'tolist') else list(scores),
|
|
|
|
| 286 |
"num_masks": len(masks),
|
| 287 |
"best_mask_score": float(scores[0]) if len(scores) > 0 else 0.0,
|
| 288 |
"mask_summary": {
|
|
|
|
| 290 |
"mask_shapes": [list(mask.shape) for mask in masks],
|
| 291 |
"segmented_area_pixels": [int(mask.sum()) for mask in masks]
|
| 292 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
}
|
| 294 |
|
| 295 |
+
# Create metadata dictionary
|
| 296 |
+
metadata = {
|
| 297 |
+
"image_path": image_path,
|
| 298 |
+
"segmentation_image_path": viz_path,
|
| 299 |
+
"image_shape": list(image.shape),
|
| 300 |
+
"prompt_type": prompt_type,
|
| 301 |
+
"prompt_coords": prompt_coords,
|
| 302 |
+
"device": self.device,
|
| 303 |
+
"num_masks_generated": len(masks),
|
| 304 |
+
"analysis_status": "completed",
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
+
return output, metadata
|
| 308 |
|
| 309 |
except Exception as e:
|
| 310 |
+
error_output = {"error": str(e)}
|
| 311 |
+
error_metadata = {
|
| 312 |
+
"image_path": image_path,
|
| 313 |
+
"analysis_status": "failed",
|
| 314 |
+
"error_details": str(e),
|
|
|
|
|
|
|
|
|
|
| 315 |
}
|
| 316 |
+
return error_output, error_metadata
|
| 317 |
|
| 318 |
async def _arun(
|
| 319 |
self,
|
|
|
|
| 322 |
prompt_coords: Optional[List[int]] = None,
|
| 323 |
slice_index: Optional[int] = None,
|
| 324 |
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
| 325 |
+
) -> Tuple[Dict[str, Any], Dict]:
|
| 326 |
"""Async version of _run."""
|
| 327 |
return self._run(image_path, prompt_type, prompt_coords, slice_index, run_manager)
|