VictorLJZ commited on
Commit
707b83e
·
1 Parent(s): 26b938b
Files changed (1) hide show
  1. medrax/tools/medsam2.py +326 -0
medrax/tools/medsam2.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Optional, Tuple, Type, Any
2
+ from pathlib import Path
3
+ import uuid
4
+ import tempfile
5
+ import numpy as np
6
+ import torch
7
+ import matplotlib.pyplot as plt
8
+ from PIL import Image
9
+ import cv2
10
+ import sys
11
+ import os
12
+
13
+ from pydantic import BaseModel, Field
14
+ from langchain_core.callbacks import (
15
+ AsyncCallbackManagerForToolRun,
16
+ CallbackManagerForToolRun,
17
+ )
18
+ from langchain_core.tools import BaseTool
19
+
20
+
21
+ class MedSAM2Input(BaseModel):
22
+ """Input schema for the MedSAM2 Tool."""
23
+
24
+ image_path: str = Field(..., description="Path to the medical image file to be segmented")
25
+ prompt_type: str = Field(
26
+ "box",
27
+ description="Type of prompt: 'box' for bounding box, 'point' for point click, or 'auto' for automatic segmentation",
28
+ )
29
+ prompt_coords: Optional[List[int]] = Field(
30
+ None,
31
+ description="Prompt coordinates: [x1,y1,x2,y2] for box prompt or [x,y] for point prompt. Leave None for auto segmentation",
32
+ )
33
+ slice_index: Optional[int] = Field(
34
+ None,
35
+ description="Specific slice index for 3D volumes (0-based). If None, processes middle slice",
36
+ )
37
+
38
+
39
+ class MedSAM2Tool(BaseTool):
40
+ """Advanced medical image segmentation tool using MedSAM2.
41
+
42
+ This tool provides state-of-the-art medical image segmentation capabilities using
43
+ the MedSAM2 model, which is specifically adapted for medical imaging from Meta's SAM2.
44
+ Supports interactive prompting with boxes, points, or automatic segmentation.
45
+ """
46
+
47
+ name: str = "medsam2_segmentation"
48
+ description: str = (
49
+ "Advanced medical image segmentation using MedSAM2 (Segment Anything Model 2 for Medical Images). "
50
+ "Supports interactive prompting with box coordinates, point clicks, or automatic segmentation. "
51
+ "Can handle 2D medical images and 3D volumes. Returns segmentation masks and visualization overlays. "
52
+ "Prompt types: 'box' with [x1,y1,x2,y2] coordinates, 'point' with [x,y] coordinates, or 'auto' for automatic. "
53
+ "Example: {'image_path': '/path/to/image.png', 'prompt_type': 'box', 'prompt_coords': [100,100,200,200]}"
54
+ )
55
+ args_schema: Type[BaseModel] = MedSAM2Input
56
+
57
+ predictor: Any = None
58
+ device: str = "cuda"
59
+ temp_dir: Path = None
60
+ model_dir: Path = None
61
+
62
+ def __init__(
63
+ self,
64
+ model_dir: str,
65
+ device: Optional[str] = "cuda",
66
+ temp_dir: Optional[str] = None,
67
+ model_cfg: str = "sam2.1_hiera_t512.yaml",
68
+ checkpoint: str = "MedSAM2_latest.pt",
69
+ ):
70
+ """Initialize the MedSAM2 tool."""
71
+ super().__init__()
72
+ self.device = device
73
+ self.model_dir = Path(model_dir)
74
+ self.temp_dir = Path(temp_dir if temp_dir else tempfile.mkdtemp())
75
+ self.temp_dir.mkdir(exist_ok=True)
76
+
77
+ # Add MedSAM2 to Python path
78
+ medsam2_path = self.model_dir / "MedSAM2"
79
+ if medsam2_path.exists():
80
+ sys.path.insert(0, str(medsam2_path))
81
+ else:
82
+ raise FileNotFoundError(f"MedSAM2 not found at {medsam2_path}. Please run git clone in {model_dir}")
83
+
84
+ try:
85
+ # Import MedSAM2 modules
86
+ from sam2.build_sam import build_sam2
87
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
88
+
89
+ # Build model
90
+ checkpoint_path = medsam2_path / "checkpoints" / checkpoint
91
+
92
+ if not checkpoint_path.exists():
93
+ raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}. Please run download.sh")
94
+
95
+ # Build model using config path relative to sam2 package (MedSAM2 sets up Hydra config paths automatically)
96
+ config_path = f"configs/{model_cfg.replace('.yaml', '')}"
97
+ sam2_model = build_sam2(config_path, str(checkpoint_path), device=device)
98
+ self.predictor = SAM2ImagePredictor(sam2_model)
99
+
100
+ print(f"MedSAM2 model loaded successfully on {device}")
101
+
102
+ except Exception as e:
103
+ raise RuntimeError(f"Failed to initialize MedSAM2: {str(e)}")
104
+
105
+ def _load_image(self, image_path: str) -> np.ndarray:
106
+ """Load and preprocess image for medical analysis."""
107
+ try:
108
+ # Handle different image formats
109
+ if image_path.lower().endswith('.dcm'):
110
+ # DICOM files - would need DICOM processor
111
+ raise ValueError("DICOM files not directly supported. Please convert to standard image format first.")
112
+
113
+ # Load standard image formats
114
+ image = Image.open(image_path)
115
+
116
+ # For medical images, convert to grayscale first if needed, then to RGB
117
+ if image.mode == 'L': # Grayscale
118
+ # Convert grayscale to RGB for SAM2
119
+ image = image.convert('RGB')
120
+ elif image.mode != 'RGB':
121
+ if image.mode == 'RGBA':
122
+ # Create white background for RGBA
123
+ background = Image.new('RGB', image.size, (255, 255, 255))
124
+ background.paste(image, mask=image.split()[-1])
125
+ image = background
126
+ else:
127
+ image = image.convert('RGB')
128
+
129
+ # Convert to numpy array
130
+ image_np = np.array(image)
131
+
132
+ # Ensure image is in proper range [0, 255]
133
+ if image_np.max() <= 1.0:
134
+ image_np = (image_np * 255).astype(np.uint8)
135
+ else:
136
+ image_np = image_np.astype(np.uint8)
137
+
138
+ return image_np
139
+
140
+ except Exception as e:
141
+ raise ValueError(f"Failed to load image {image_path}: {str(e)}")
142
+
143
+ def _process_prompts(self, prompt_type: str, prompt_coords: Optional[List[int]], image_shape: Tuple[int, int]):
144
+ """Process and validate prompts."""
145
+ if prompt_type == "auto":
146
+ return None, None, None
147
+
148
+ if prompt_coords is None:
149
+ if prompt_type != "auto":
150
+ raise ValueError(f"Prompt coordinates required for prompt type '{prompt_type}'")
151
+ return None, None, None
152
+
153
+ if prompt_type == "box":
154
+ if len(prompt_coords) != 4:
155
+ raise ValueError("Box prompt requires 4 coordinates: [x1,y1,x2,y2]")
156
+
157
+ x1, y1, x2, y2 = prompt_coords
158
+ # Validate coordinates
159
+ if x1 >= x2 or y1 >= y2:
160
+ raise ValueError("Invalid box coordinates: x1 < x2 and y1 < y2 required")
161
+
162
+ input_box = np.array([[x1, y1, x2, y2]])
163
+ return input_box, None, None
164
+
165
+ elif prompt_type == "point":
166
+ if len(prompt_coords) != 2:
167
+ raise ValueError("Point prompt requires 2 coordinates: [x,y]")
168
+
169
+ x, y = prompt_coords
170
+ input_point = np.array([[x, y]])
171
+ input_label = np.array([1]) # Positive point
172
+ return None, input_point, input_label
173
+
174
+ else:
175
+ raise ValueError(f"Unknown prompt type: {prompt_type}")
176
+
177
+ def _create_visualization(self, image: np.ndarray, masks: np.ndarray, prompt_info: Dict) -> str:
178
+ """Create visualization of segmentation results."""
179
+ plt.figure(figsize=(12, 8))
180
+
181
+ # Display original image
182
+ plt.subplot(1, 2, 1)
183
+ plt.imshow(image)
184
+ plt.title("Original Image")
185
+ plt.axis('off')
186
+
187
+ # Display segmentation overlay
188
+ plt.subplot(1, 2, 2)
189
+ plt.imshow(image)
190
+
191
+ # Overlay masks
192
+ if len(masks) > 0:
193
+ # Use the best mask (first one returned by SAM2)
194
+ mask = masks[0]
195
+ # Convert mask to boolean and ensure proper shape
196
+ mask_bool = mask.astype(bool)
197
+ colored_mask = np.zeros((*mask_bool.shape, 4))
198
+ colored_mask[mask_bool] = [1, 0, 0, 0.5] # Red with transparency
199
+ plt.imshow(colored_mask)
200
+
201
+ # Add prompt visualization
202
+ if prompt_info.get('box') is not None:
203
+ box = prompt_info['box'][0]
204
+ x1, y1, x2, y2 = box
205
+ plt.plot([x1, x2, x2, x1, x1], [y1, y1, y2, y2, y1], 'g-', linewidth=2)
206
+ plt.plot([x1, x2, x2, x1, x1], [y1, y1, y2, y2, y1], 'g-', linewidth=2, label='Box Prompt')
207
+
208
+ if prompt_info.get('point') is not None:
209
+ point = prompt_info['point'][0]
210
+ plt.plot(point[0], point[1], 'go', markersize=10, label='Point Prompt')
211
+
212
+ plt.title("Segmentation Result")
213
+ plt.axis('off')
214
+ if prompt_info.get('box') is not None or prompt_info.get('point') is not None:
215
+ plt.legend()
216
+
217
+ # Save visualization
218
+ viz_path = self.temp_dir / f"medsam2_result_{uuid.uuid4().hex[:8]}.png"
219
+ plt.savefig(viz_path, bbox_inches='tight', dpi=150)
220
+ plt.close()
221
+
222
+ return str(viz_path)
223
+
224
+ def _run(
225
+ self,
226
+ image_path: str,
227
+ prompt_type: str = "box",
228
+ prompt_coords: Optional[List[int]] = None,
229
+ slice_index: Optional[int] = None,
230
+ run_manager: Optional[CallbackManagerForToolRun] = None,
231
+ ) -> Dict[str, Any]:
232
+ """Run MedSAM2 segmentation on the input image."""
233
+ try:
234
+ # Load image
235
+ image = self._load_image(image_path)
236
+
237
+ # Set image for predictor
238
+ self.predictor.set_image(image)
239
+
240
+ # Process prompts
241
+ input_box, input_point, input_label = self._process_prompts(
242
+ prompt_type, prompt_coords, image.shape[:2]
243
+ )
244
+
245
+ # Run inference
246
+ if prompt_type == "auto":
247
+ # For auto segmentation, try multiple approaches and select best result
248
+ h, w = image.shape[:2]
249
+
250
+ # Try multiple points in key areas for medical images
251
+ sample_points = np.array([
252
+ [w//3, h//3], # Upper left lung area
253
+ [2*w//3, h//3], # Upper right lung area
254
+ [w//2, 2*h//3], # Lower center area
255
+ ])
256
+ sample_labels = np.array([1, 1, 1]) # All positive points
257
+
258
+ masks, scores, logits = self.predictor.predict(
259
+ point_coords=sample_points,
260
+ point_labels=sample_labels,
261
+ multimask_output=True,
262
+ )
263
+ else:
264
+ masks, scores, logits = self.predictor.predict(
265
+ point_coords=input_point,
266
+ point_labels=input_label,
267
+ box=input_box,
268
+ multimask_output=True,
269
+ )
270
+
271
+ # Create visualization
272
+ prompt_info = {
273
+ 'box': input_box,
274
+ 'point': input_point,
275
+ 'type': prompt_type
276
+ }
277
+ viz_path = self._create_visualization(image, masks, prompt_info)
278
+
279
+ # Process results (exclude large mask arrays to avoid token limits)
280
+ results = {
281
+ "success": True,
282
+ "confidence_scores": scores.tolist() if hasattr(scores, 'tolist') else list(scores),
283
+ "visualization_path": viz_path,
284
+ "num_masks": len(masks),
285
+ "best_mask_score": float(scores[0]) if len(scores) > 0 else 0.0,
286
+ "mask_summary": {
287
+ "total_masks": len(masks),
288
+ "mask_shapes": [list(mask.shape) for mask in masks],
289
+ "segmented_area_pixels": [int(mask.sum()) for mask in masks]
290
+ },
291
+ # Include metadata in the main results
292
+ "metadata": {
293
+ "image_path": image_path,
294
+ "image_shape": list(image.shape),
295
+ "prompt_type": prompt_type,
296
+ "prompt_coords": prompt_coords,
297
+ "device": self.device,
298
+ "num_masks_generated": len(masks),
299
+ "analysis_status": "completed",
300
+ }
301
+ }
302
+
303
+ return results
304
+
305
+ except Exception as e:
306
+ error_result = {
307
+ "error": str(e),
308
+ "success": False,
309
+ "metadata": {
310
+ "image_path": image_path,
311
+ "analysis_status": "failed",
312
+ "error_details": str(e),
313
+ }
314
+ }
315
+ return error_result
316
+
317
+ async def _arun(
318
+ self,
319
+ image_path: str,
320
+ prompt_type: str = "box",
321
+ prompt_coords: Optional[List[int]] = None,
322
+ slice_index: Optional[int] = None,
323
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
324
+ ) -> Dict[str, Any]:
325
+ """Async version of _run."""
326
+ return self._run(image_path, prompt_type, prompt_coords, slice_index, run_manager)