victorli commited on
Commit
fbcbf94
·
1 Parent(s): e97f266

fixed up some things and corrected agent init logic

Browse files
Files changed (3) hide show
  1. .gitignore +2 -1
  2. main.py +25 -20
  3. medrax/tools/__init__.py +0 -1
.gitignore CHANGED
@@ -179,4 +179,5 @@ model-weights/
179
 
180
  .DS_Store
181
 
182
- benchmarking/data/
 
 
179
 
180
  .DS_Store
181
 
182
+ benchmarking/data/
183
+ model_cache/
main.py CHANGED
@@ -10,6 +10,7 @@ with different model weights, tools, and parameters.
10
  """
11
 
12
  import warnings
 
13
  from typing import Dict, List, Optional, Any
14
  from dotenv import load_dotenv
15
  from transformers import logging
@@ -33,11 +34,11 @@ _ = load_dotenv()
33
  def initialize_agent(
34
  prompt_file: str,
35
  tools_to_use: Optional[List[str]] = None,
36
- model_dir: str = "/scratch/ssd004/scratch/victorli/model-weights",
37
  temp_dir: str = "temp",
38
  device: str = "cpu",
39
- model: str = "gpt-4.1-2025-04-14",
40
- temperature: float = 0.7,
41
  top_p: float = 0.95,
42
  rag_config: Optional[RAGConfig] = None,
43
  model_kwargs: Dict[str, Any] = {},
@@ -93,21 +94,25 @@ def initialize_agent(
93
  device=device, cache_dir=model_dir, temp_dir=temp_dir
94
  ),
95
  "MedGemmaVQATool": lambda: MedGemmaAPIClientTool(cache_dir=model_dir, device=device, api_url=MEDGEMMA_API_URL)
96
- }
97
-
98
- try:
99
- tools_dict["PythonSandboxTool"] = create_python_sandbox()
100
- except Exception as e:
101
- print(f"Error creating PythonSandboxTool: {e}")
102
- print("Skipping PythonSandboxTool")
103
 
104
  # Initialize only selected tools or all if none specified
105
  tools_dict: Dict[str, BaseTool] = {}
106
- tools_to_use = tools_to_use or all_tools.keys()
 
107
  for tool_name in tools_to_use:
 
 
108
  if tool_name in all_tools:
109
  tools_dict[tool_name] = all_tools[tool_name]()
110
 
 
 
 
 
 
 
 
111
  # Set up checkpointing for conversation state
112
  checkpointer = MemorySaver()
113
 
@@ -146,20 +151,20 @@ if __name__ == "__main__":
146
  selected_tools = [
147
  "ImageVisualizerTool", # For displaying images in the UI
148
  # "DicomProcessorTool", # For processing DICOM medical image files
 
 
149
  "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
150
  "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
151
- "ChestXRaySegmentationTool", # For segmenting anatomical regions in chest X-rays
152
- "ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
153
  "XRayVQATool", # For visual question answering on X-rays
154
  # "LlavaMedTool", # For multimodal medical image understanding
155
  "XRayPhraseGroundingTool", # For locating described features in X-rays
156
- # "ChestXRayGeneratorTool", # For generating synthetic chest X-rays
157
  # "MedSAM2Tool", # For advanced medical image segmentation using MedSAM2
158
  # "WebBrowserTool", # For web browsing and search capabilities
 
159
  # "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
160
  # "PythonSandboxTool", # Add the Python sandbox tool
161
- "MedGemmaVQATool" # Google MedGemma VQA tool
162
- "DuckDuckGoSearchTool", # For privacy-focused web search using DuckDuckGo
163
  ]
164
 
165
  # Setup the MedGemma environment if the MedGemmaVQATool is selected
@@ -188,11 +193,11 @@ if __name__ == "__main__":
188
  agent, tools_dict = initialize_agent(
189
  prompt_file="medrax/docs/system_prompts.txt",
190
  tools_to_use=selected_tools,
191
- model_dir="/model-weights",
192
  temp_dir="temp", # Change this to the path of the temporary directory
193
- device="cuda:0",
194
- model="gpt-4.1-2025-04-14", # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
195
- temperature=0.7,
196
  top_p=0.95,
197
  model_kwargs=model_kwargs,
198
  rag_config=rag_config,
 
10
  """
11
 
12
  import warnings
13
+ import os
14
  from typing import Dict, List, Optional, Any
15
  from dotenv import load_dotenv
16
  from transformers import logging
 
34
  def initialize_agent(
35
  prompt_file: str,
36
  tools_to_use: Optional[List[str]] = None,
37
+ model_dir: str = "model-weights",
38
  temp_dir: str = "temp",
39
  device: str = "cpu",
40
+ model: str = "gemini-2.5-pro",
41
+ temperature: float = 1.0,
42
  top_p: float = 0.95,
43
  rag_config: Optional[RAGConfig] = None,
44
  model_kwargs: Dict[str, Any] = {},
 
94
  device=device, cache_dir=model_dir, temp_dir=temp_dir
95
  ),
96
  "MedGemmaVQATool": lambda: MedGemmaAPIClientTool(cache_dir=model_dir, device=device, api_url=MEDGEMMA_API_URL)
97
+ }
 
 
 
 
 
 
98
 
99
  # Initialize only selected tools or all if none specified
100
  tools_dict: Dict[str, BaseTool] = {}
101
+ if tools_to_use is None:
102
+ tools_to_use = []
103
  for tool_name in tools_to_use:
104
+ if tool_name == "PythonSandboxTool":
105
+ continue
106
  if tool_name in all_tools:
107
  tools_dict[tool_name] = all_tools[tool_name]()
108
 
109
+ # Try to create the PythonSandboxTool
110
+ try:
111
+ tools_dict["PythonSandboxTool"] = create_python_sandbox()
112
+ except Exception as e:
113
+ print(f"Error creating PythonSandboxTool: {e}")
114
+ print("Skipping PythonSandboxTool")
115
+
116
  # Set up checkpointing for conversation state
117
  checkpointer = MemorySaver()
118
 
 
151
  selected_tools = [
152
  "ImageVisualizerTool", # For displaying images in the UI
153
  # "DicomProcessorTool", # For processing DICOM medical image files
154
+ # "ChestXRayGeneratorTool", # For generating synthetic chest X-rays
155
+ "ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
156
  "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
157
  "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
158
+ "MedGemmaVQATool" # Google MedGemma VQA tool
 
159
  "XRayVQATool", # For visual question answering on X-rays
160
  # "LlavaMedTool", # For multimodal medical image understanding
161
  "XRayPhraseGroundingTool", # For locating described features in X-rays
162
+ "ChestXRaySegmentationTool", # For segmenting anatomical regions in chest X-rays
163
  # "MedSAM2Tool", # For advanced medical image segmentation using MedSAM2
164
  # "WebBrowserTool", # For web browsing and search capabilities
165
+ "DuckDuckGoSearchTool", # For privacy-focused web search using DuckDuckGo
166
  # "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
167
  # "PythonSandboxTool", # Add the Python sandbox tool
 
 
168
  ]
169
 
170
  # Setup the MedGemma environment if the MedGemmaVQATool is selected
 
193
  agent, tools_dict = initialize_agent(
194
  prompt_file="medrax/docs/system_prompts.txt",
195
  tools_to_use=selected_tools,
196
+ model_dir="model-weights",
197
  temp_dir="temp", # Change this to the path of the temporary directory
198
+ device="cpu",
199
+ model="gemini-2.5-pro", # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
200
+ temperature=1.0,
201
  top_p=0.95,
202
  model_kwargs=model_kwargs,
203
  rag_config=rag_config,
medrax/tools/__init__.py CHANGED
@@ -11,4 +11,3 @@ from .utils import *
11
  from .rag import *
12
  from .browsing import *
13
  from .python_tool import *
14
- from .medsam2 import *
 
11
  from .rag import *
12
  from .browsing import *
13
  from .python_tool import *