VictorLJZ committed on
Commit
dacb34b
·
1 Parent(s): 775f52a

simplified medsam2 setup

Browse files
Files changed (2) hide show
  1. main.py +2 -2
  2. medrax/tools/medsam2.py +21 -30
main.py CHANGED
@@ -87,7 +87,7 @@ def initialize_agent(
87
  "MedicalRAGTool": lambda: RAGTool(config=rag_config),
88
  "WebBrowserTool": lambda: WebBrowserTool(),
89
  "MedSAM2Tool": lambda: MedSAM2Tool(
90
- model_dir=model_dir, device=device, temp_dir=temp_dir
91
  ),
92
  }
93
 
@@ -154,7 +154,7 @@ if __name__ == "__main__":
154
  "MedSAM2Tool", # For advanced medical image segmentation using MedSAM2
155
  "WebBrowserTool", # For web browsing and search capabilities
156
  "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
157
- "PythonSandboxTool", # Add the Python sandbox tool
158
  ]
159
 
160
  # Configure the Retrieval Augmented Generation (RAG) system
 
87
  "MedicalRAGTool": lambda: RAGTool(config=rag_config),
88
  "WebBrowserTool": lambda: WebBrowserTool(),
89
  "MedSAM2Tool": lambda: MedSAM2Tool(
90
+ device=device, cache_dir=model_dir, temp_dir=temp_dir
91
  ),
92
  }
93
 
 
154
  "MedSAM2Tool", # For advanced medical image segmentation using MedSAM2
155
  "WebBrowserTool", # For web browsing and search capabilities
156
  "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
157
+ # "PythonSandboxTool", # Add the Python sandbox tool
158
  ]
159
 
160
  # Configure the Retrieval Augmented Generation (RAG) system
medrax/tools/medsam2.py CHANGED
@@ -3,12 +3,8 @@ from pathlib import Path
3
  import uuid
4
  import tempfile
5
  import numpy as np
6
- import torch
7
  import matplotlib.pyplot as plt
8
  from PIL import Image
9
- import cv2
10
- import sys
11
- import os
12
 
13
  from pydantic import BaseModel, Field
14
  from langchain_core.callbacks import (
@@ -17,6 +13,11 @@ from langchain_core.callbacks import (
17
  )
18
  from langchain_core.tools import BaseTool
19
 
 
 
 
 
 
20
 
21
  class MedSAM2Input(BaseModel):
22
  """Input schema for the MedSAM2 Tool."""
@@ -44,7 +45,7 @@ class MedSAM2Tool(BaseTool):
44
  Supports interactive prompting with boxes, points, or automatic segmentation.
45
  """
46
 
47
- name: str = "medsam2_segmentation"
48
  description: str = (
49
  "Advanced medical image segmentation using MedSAM2 (Segment Anything Model 2 for Medical Images). "
50
  "Supports interactive prompting with box coordinates, point clicks, or automatic segmentation. "
@@ -57,47 +58,37 @@ class MedSAM2Tool(BaseTool):
57
  )
58
  args_schema: Type[BaseModel] = MedSAM2Input
59
 
 
 
 
60
  predictor: Any = None
61
- device: str = "cuda"
62
- temp_dir: Path = None
63
- model_dir: Path = None
64
 
65
  def __init__(
66
  self,
67
- model_dir: str,
68
  device: Optional[str] = "cuda",
 
69
  temp_dir: Optional[str] = None,
 
 
70
  model_cfg: str = "sam2.1_hiera_t512.yaml",
71
- checkpoint: str = "MedSAM2_latest.pt",
72
  ):
73
  """Initialize the MedSAM2 tool."""
74
  super().__init__()
75
  self.device = device
76
- self.model_dir = Path(model_dir)
77
  self.temp_dir = Path(temp_dir if temp_dir else tempfile.mkdtemp())
78
- self.temp_dir.mkdir(exist_ok=True)
79
-
80
- # Add MedSAM2 to Python path
81
- medsam2_path = self.model_dir / "MedSAM2"
82
- if medsam2_path.exists():
83
- sys.path.insert(0, str(medsam2_path))
84
- else:
85
- raise FileNotFoundError(f"MedSAM2 not found at {medsam2_path}. Please run git clone in {model_dir}")
86
 
87
  try:
88
- # Import MedSAM2 modules
89
- from sam2.build_sam import build_sam2
90
- from sam2.sam2_image_predictor import SAM2ImagePredictor
91
-
92
- # Build model
93
- checkpoint_path = medsam2_path / "checkpoints" / checkpoint
94
-
95
- if not checkpoint_path.exists():
96
- raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}. Please run download.sh")
97
 
98
- # Build model using config path relative to sam2 package (MedSAM2 sets up Hydra config paths automatically)
99
  config_path = f"configs/{model_cfg.replace('.yaml', '')}"
100
- sam2_model = build_sam2(config_path, str(checkpoint_path), device=device)
101
  self.predictor = SAM2ImagePredictor(sam2_model)
102
 
103
  print(f"MedSAM2 model loaded successfully on {device}")
 
3
  import uuid
4
  import tempfile
5
  import numpy as np
 
6
  import matplotlib.pyplot as plt
7
  from PIL import Image
 
 
 
8
 
9
  from pydantic import BaseModel, Field
10
  from langchain_core.callbacks import (
 
13
  )
14
  from langchain_core.tools import BaseTool
15
 
16
+ from MedSAM2.sam2.build_sam import build_sam2
17
+ from MedSAM2.sam2.sam2_image_predictor import SAM2ImagePredictor
18
+ from huggingface_hub import hf_hub_download
19
+
20
+
21
 
22
  class MedSAM2Input(BaseModel):
23
  """Input schema for the MedSAM2 Tool."""
 
45
  Supports interactive prompting with boxes, points, or automatic segmentation.
46
  """
47
 
48
+ name: str = "medsam2"
49
  description: str = (
50
  "Advanced medical image segmentation using MedSAM2 (Segment Anything Model 2 for Medical Images). "
51
  "Supports interactive prompting with box coordinates, point clicks, or automatic segmentation. "
 
58
  )
59
  args_schema: Type[BaseModel] = MedSAM2Input
60
 
61
+ device: Optional[str] = "cuda"
62
+ cache_dir: Path = None
63
+ temp_dir: Path = Path("temp")
64
  predictor: Any = None
 
 
 
65
 
66
  def __init__(
67
  self,
 
68
  device: Optional[str] = "cuda",
69
+ cache_dir: str = "/model-weights",
70
  temp_dir: Optional[str] = None,
71
+ model_path: str = "wanglab/MedSAM2",
72
+ model_file: str = "MedSAM2_latest.pt",
73
  model_cfg: str = "sam2.1_hiera_t512.yaml",
74
+ **kwargs,
75
  ):
76
  """Initialize the MedSAM2 tool."""
77
  super().__init__()
78
  self.device = device
79
+ self.cache_dir = Path(cache_dir)
80
  self.temp_dir = Path(temp_dir if temp_dir else tempfile.mkdtemp())
 
 
 
 
 
 
 
 
81
 
82
  try:
83
+ hf_hub_download(
84
+ repo_id=model_path,
85
+ filename=model_file,
86
+ local_dir=self.cache_dir,
87
+ local_dir_use_symlinks=False
88
+ )
 
 
 
89
 
 
90
  config_path = f"configs/{model_cfg.replace('.yaml', '')}"
91
+ sam2_model = build_sam2(config_path, str(self.cache_dir / model_file), device=device)
92
  self.predictor = SAM2ImagePredictor(sam2_model)
93
 
94
  print(f"MedSAM2 model loaded successfully on {device}")