Spaces:
Paused
Paused
victorli
commited on
Commit
·
fbcbf94
1
Parent(s):
e97f266
fixed up some things and corrected agent init logic
Browse files- .gitignore +2 -1
- main.py +25 -20
- medrax/tools/__init__.py +0 -1
.gitignore
CHANGED
|
@@ -179,4 +179,5 @@ model-weights/
|
|
| 179 |
|
| 180 |
.DS_Store
|
| 181 |
|
| 182 |
-
benchmarking/data/
|
|
|
|
|
|
| 179 |
|
| 180 |
.DS_Store
|
| 181 |
|
| 182 |
+
benchmarking/data/
|
| 183 |
+
model_cache/
|
main.py
CHANGED
|
@@ -10,6 +10,7 @@ with different model weights, tools, and parameters.
|
|
| 10 |
"""
|
| 11 |
|
| 12 |
import warnings
|
|
|
|
| 13 |
from typing import Dict, List, Optional, Any
|
| 14 |
from dotenv import load_dotenv
|
| 15 |
from transformers import logging
|
|
@@ -33,11 +34,11 @@ _ = load_dotenv()
|
|
| 33 |
def initialize_agent(
|
| 34 |
prompt_file: str,
|
| 35 |
tools_to_use: Optional[List[str]] = None,
|
| 36 |
-
model_dir: str = "
|
| 37 |
temp_dir: str = "temp",
|
| 38 |
device: str = "cpu",
|
| 39 |
-
model: str = "
|
| 40 |
-
temperature: float = 0
|
| 41 |
top_p: float = 0.95,
|
| 42 |
rag_config: Optional[RAGConfig] = None,
|
| 43 |
model_kwargs: Dict[str, Any] = {},
|
|
@@ -93,21 +94,25 @@ def initialize_agent(
|
|
| 93 |
device=device, cache_dir=model_dir, temp_dir=temp_dir
|
| 94 |
),
|
| 95 |
"MedGemmaVQATool": lambda: MedGemmaAPIClientTool(cache_dir=model_dir, device=device, api_url=MEDGEMMA_API_URL)
|
| 96 |
-
}
|
| 97 |
-
|
| 98 |
-
try:
|
| 99 |
-
tools_dict["PythonSandboxTool"] = create_python_sandbox()
|
| 100 |
-
except Exception as e:
|
| 101 |
-
print(f"Error creating PythonSandboxTool: {e}")
|
| 102 |
-
print("Skipping PythonSandboxTool")
|
| 103 |
|
| 104 |
# Initialize only selected tools or all if none specified
|
| 105 |
tools_dict: Dict[str, BaseTool] = {}
|
| 106 |
-
|
|
|
|
| 107 |
for tool_name in tools_to_use:
|
|
|
|
|
|
|
| 108 |
if tool_name in all_tools:
|
| 109 |
tools_dict[tool_name] = all_tools[tool_name]()
|
| 110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
# Set up checkpointing for conversation state
|
| 112 |
checkpointer = MemorySaver()
|
| 113 |
|
|
@@ -146,20 +151,20 @@ if __name__ == "__main__":
|
|
| 146 |
selected_tools = [
|
| 147 |
"ImageVisualizerTool", # For displaying images in the UI
|
| 148 |
# "DicomProcessorTool", # For processing DICOM medical image files
|
|
|
|
|
|
|
| 149 |
"TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
|
| 150 |
"ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
|
| 151 |
-
"
|
| 152 |
-
"ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
|
| 153 |
"XRayVQATool", # For visual question answering on X-rays
|
| 154 |
# "LlavaMedTool", # For multimodal medical image understanding
|
| 155 |
"XRayPhraseGroundingTool", # For locating described features in X-rays
|
| 156 |
-
|
| 157 |
# "MedSAM2Tool", # For advanced medical image segmentation using MedSAM2
|
| 158 |
# "WebBrowserTool", # For web browsing and search capabilities
|
|
|
|
| 159 |
# "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
|
| 160 |
# "PythonSandboxTool", # Add the Python sandbox tool
|
| 161 |
-
"MedGemmaVQATool" # Google MedGemma VQA tool
|
| 162 |
-
"DuckDuckGoSearchTool", # For privacy-focused web search using DuckDuckGo
|
| 163 |
]
|
| 164 |
|
| 165 |
# Setup the MedGemma environment if the MedGemmaVQATool is selected
|
|
@@ -188,11 +193,11 @@ if __name__ == "__main__":
|
|
| 188 |
agent, tools_dict = initialize_agent(
|
| 189 |
prompt_file="medrax/docs/system_prompts.txt",
|
| 190 |
tools_to_use=selected_tools,
|
| 191 |
-
model_dir="
|
| 192 |
temp_dir="temp", # Change this to the path of the temporary directory
|
| 193 |
-
device="
|
| 194 |
-
model="
|
| 195 |
-
temperature=0
|
| 196 |
top_p=0.95,
|
| 197 |
model_kwargs=model_kwargs,
|
| 198 |
rag_config=rag_config,
|
|
|
|
| 10 |
"""
|
| 11 |
|
| 12 |
import warnings
|
| 13 |
+
import os
|
| 14 |
from typing import Dict, List, Optional, Any
|
| 15 |
from dotenv import load_dotenv
|
| 16 |
from transformers import logging
|
|
|
|
| 34 |
def initialize_agent(
|
| 35 |
prompt_file: str,
|
| 36 |
tools_to_use: Optional[List[str]] = None,
|
| 37 |
+
model_dir: str = "model-weights",
|
| 38 |
temp_dir: str = "temp",
|
| 39 |
device: str = "cpu",
|
| 40 |
+
model: str = "gemini-2.5-pro",
|
| 41 |
+
temperature: float = 1.0,
|
| 42 |
top_p: float = 0.95,
|
| 43 |
rag_config: Optional[RAGConfig] = None,
|
| 44 |
model_kwargs: Dict[str, Any] = {},
|
|
|
|
| 94 |
device=device, cache_dir=model_dir, temp_dir=temp_dir
|
| 95 |
),
|
| 96 |
"MedGemmaVQATool": lambda: MedGemmaAPIClientTool(cache_dir=model_dir, device=device, api_url=MEDGEMMA_API_URL)
|
| 97 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
# Initialize only selected tools or all if none specified
|
| 100 |
tools_dict: Dict[str, BaseTool] = {}
|
| 101 |
+
if tools_to_use is None:
|
| 102 |
+
tools_to_use = []
|
| 103 |
for tool_name in tools_to_use:
|
| 104 |
+
if tool_name == "PythonSandboxTool":
|
| 105 |
+
continue
|
| 106 |
if tool_name in all_tools:
|
| 107 |
tools_dict[tool_name] = all_tools[tool_name]()
|
| 108 |
|
| 109 |
+
# Try to create the PythonSandboxTool
|
| 110 |
+
try:
|
| 111 |
+
tools_dict["PythonSandboxTool"] = create_python_sandbox()
|
| 112 |
+
except Exception as e:
|
| 113 |
+
print(f"Error creating PythonSandboxTool: {e}")
|
| 114 |
+
print("Skipping PythonSandboxTool")
|
| 115 |
+
|
| 116 |
# Set up checkpointing for conversation state
|
| 117 |
checkpointer = MemorySaver()
|
| 118 |
|
|
|
|
| 151 |
selected_tools = [
|
| 152 |
"ImageVisualizerTool", # For displaying images in the UI
|
| 153 |
# "DicomProcessorTool", # For processing DICOM medical image files
|
| 154 |
+
# "ChestXRayGeneratorTool", # For generating synthetic chest X-rays
|
| 155 |
+
"ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
|
| 156 |
"TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
|
| 157 |
"ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
|
| 158 |
+
"MedGemmaVQATool" # Google MedGemma VQA tool
|
|
|
|
| 159 |
"XRayVQATool", # For visual question answering on X-rays
|
| 160 |
# "LlavaMedTool", # For multimodal medical image understanding
|
| 161 |
"XRayPhraseGroundingTool", # For locating described features in X-rays
|
| 162 |
+
"ChestXRaySegmentationTool", # For segmenting anatomical regions in chest X-rays
|
| 163 |
# "MedSAM2Tool", # For advanced medical image segmentation using MedSAM2
|
| 164 |
# "WebBrowserTool", # For web browsing and search capabilities
|
| 165 |
+
"DuckDuckGoSearchTool", # For privacy-focused web search using DuckDuckGo
|
| 166 |
# "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
|
| 167 |
# "PythonSandboxTool", # Add the Python sandbox tool
|
|
|
|
|
|
|
| 168 |
]
|
| 169 |
|
| 170 |
# Setup the MedGemma environment if the MedGemmaVQATool is selected
|
|
|
|
| 193 |
agent, tools_dict = initialize_agent(
|
| 194 |
prompt_file="medrax/docs/system_prompts.txt",
|
| 195 |
tools_to_use=selected_tools,
|
| 196 |
+
model_dir="model-weights",
|
| 197 |
temp_dir="temp", # Change this to the path of the temporary directory
|
| 198 |
+
device="cpu",
|
| 199 |
+
model="gemini-2.5-pro", # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
|
| 200 |
+
temperature=1.0,
|
| 201 |
top_p=0.95,
|
| 202 |
model_kwargs=model_kwargs,
|
| 203 |
rag_config=rag_config,
|
medrax/tools/__init__.py
CHANGED
|
@@ -11,4 +11,3 @@ from .utils import *
|
|
| 11 |
from .rag import *
|
| 12 |
from .browsing import *
|
| 13 |
from .python_tool import *
|
| 14 |
-
from .medsam2 import *
|
|
|
|
| 11 |
from .rag import *
|
| 12 |
from .browsing import *
|
| 13 |
from .python_tool import *
|
|
|