Adibvafa commited on
Commit
7f4d4c2
·
1 Parent(s): 0ad3984

Prep for actual benchmarking

Browse files
benchmarking/llm_providers/medrax_provider.py CHANGED
@@ -35,19 +35,19 @@ class MedRAXProvider(LLMProvider):
35
  print("Starting server...")
36
 
37
  selected_tools = [
38
- "ImageVisualizerTool", # For displaying images in the UI
39
- "DicomProcessorTool", # For processing DICOM medical image files
40
  "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
41
- # "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
42
- "ChestXRaySegmentationTool", # For segmenting anatomical regions in chest X-rays
43
  "ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
44
  "XRayVQATool", # For visual question answering on X-rays
45
- "LlavaMedTool", # For multimodal medical image understanding
46
  "XRayPhraseGroundingTool", # For locating described features in X-rays
47
  # "ChestXRayGeneratorTool", # For generating synthetic chest X-rays
48
  "WebBrowserTool", # For web browsing and search capabilities
49
  "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
50
- "PythonSandboxTool", # Add the Python sandbox tool
51
  ]
52
 
53
  rag_config = RAGConfig(
@@ -58,7 +58,7 @@ class MedRAXProvider(LLMProvider):
58
  pinecone_index_name="medrax2", # Name for the Pinecone index
59
  chunk_size=1500,
60
  chunk_overlap=300,
61
- retriever_k=7,
62
  local_docs_dir="rag_docs", # Change this to the path of the documents for RAG
63
  huggingface_datasets=["VictorLJZ/medrax2"], # List of HuggingFace datasets to load
64
  dataset_split="train", # Which split of the datasets to use
@@ -72,9 +72,9 @@ class MedRAXProvider(LLMProvider):
72
  tools_to_use=selected_tools,
73
  model_dir="model-weights",
74
  temp_dir="temp", # Change this to the path of the temporary directory
75
- device="cpu",
76
  model=self.model_name, # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
77
- temperature=0.7,
78
  top_p=0.95,
79
  model_kwargs=model_kwargs,
80
  rag_config=rag_config,
 
35
  print("Starting server...")
36
 
37
  selected_tools = [
38
+ # "ImageVisualizerTool", # For displaying images in the UI
39
+ # "DicomProcessorTool", # For processing DICOM medical image files
40
  "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
41
+ "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
42
+ # "ChestXRaySegmentationTool", # For segmenting anatomical regions in chest X-rays
43
  "ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
44
  "XRayVQATool", # For visual question answering on X-rays
45
+ # "LlavaMedTool", # For multimodal medical image understanding
46
  "XRayPhraseGroundingTool", # For locating described features in X-rays
47
  # "ChestXRayGeneratorTool", # For generating synthetic chest X-rays
48
  "WebBrowserTool", # For web browsing and search capabilities
49
  "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
50
+ # "PythonSandboxTool", # Add the Python sandbox tool
51
  ]
52
 
53
  rag_config = RAGConfig(
 
58
  pinecone_index_name="medrax2", # Name for the Pinecone index
59
  chunk_size=1500,
60
  chunk_overlap=300,
61
+ retriever_k=3,
62
  local_docs_dir="rag_docs", # Change this to the path of the documents for RAG
63
  huggingface_datasets=["VictorLJZ/medrax2"], # List of HuggingFace datasets to load
64
  dataset_split="train", # Which split of the datasets to use
 
72
  tools_to_use=selected_tools,
73
  model_dir="model-weights",
74
  temp_dir="temp", # Change this to the path of the temporary directory
75
+ device="cuda:0",
76
  model=self.model_name, # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
77
+ temperature=0.3,
78
  top_p=0.95,
79
  model_kwargs=model_kwargs,
80
  rag_config=rag_config,
main.py CHANGED
@@ -143,15 +143,15 @@ if __name__ == "__main__":
143
  selected_tools = [
144
  "ImageVisualizerTool", # For displaying images in the UI
145
  # "DicomProcessorTool", # For processing DICOM medical image files
146
- # "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
147
- # "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
148
  # "ChestXRaySegmentationTool", # For segmenting anatomical regions in chest X-rays
149
- # "ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
150
- # "XRayVQATool", # For visual question answering on X-rays
151
  # "LlavaMedTool", # For multimodal medical image understanding
152
- # "XRayPhraseGroundingTool", # For locating described features in X-rays
153
  # "ChestXRayGeneratorTool", # For generating synthetic chest X-rays
154
- "MedSAM2Tool", # For advanced medical image segmentation using MedSAM2
155
  "WebBrowserTool", # For web browsing and search capabilities
156
  "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
157
  # "PythonSandboxTool", # Add the Python sandbox tool
@@ -167,7 +167,7 @@ if __name__ == "__main__":
167
  pinecone_index_name="medrax2", # Name for the Pinecone index
168
  chunk_size=1500,
169
  chunk_overlap=300,
170
- retriever_k=7,
171
  local_docs_dir="rag_docs", # Change this to the path of the documents for RAG
172
  huggingface_datasets=["VictorLJZ/medrax2"], # List of HuggingFace datasets to load
173
  dataset_split="train", # Which split of the datasets to use
@@ -179,10 +179,10 @@ if __name__ == "__main__":
179
  agent, tools_dict = initialize_agent(
180
  prompt_file="medrax/docs/system_prompts.txt",
181
  tools_to_use=selected_tools,
182
- model_dir="model-weights",
183
  temp_dir="temp", # Change this to the path of the temporary directory
184
- device="cuda",
185
- model="grok-4", # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
186
  temperature=0.7,
187
  top_p=0.95,
188
  model_kwargs=model_kwargs,
 
143
  selected_tools = [
144
  "ImageVisualizerTool", # For displaying images in the UI
145
  # "DicomProcessorTool", # For processing DICOM medical image files
146
+ "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
147
+ "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
148
  # "ChestXRaySegmentationTool", # For segmenting anatomical regions in chest X-rays
149
+ "ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
150
+ "XRayVQATool", # For visual question answering on X-rays
151
  # "LlavaMedTool", # For multimodal medical image understanding
152
+ "XRayPhraseGroundingTool", # For locating described features in X-rays
153
  # "ChestXRayGeneratorTool", # For generating synthetic chest X-rays
154
+ # "MedSAM2Tool", # For advanced medical image segmentation using MedSAM2
155
  "WebBrowserTool", # For web browsing and search capabilities
156
  "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
157
  # "PythonSandboxTool", # Add the Python sandbox tool
 
167
  pinecone_index_name="medrax2", # Name for the Pinecone index
168
  chunk_size=1500,
169
  chunk_overlap=300,
170
+ retriever_k=3,
171
  local_docs_dir="rag_docs", # Change this to the path of the documents for RAG
172
  huggingface_datasets=["VictorLJZ/medrax2"], # List of HuggingFace datasets to load
173
  dataset_split="train", # Which split of the datasets to use
 
179
  agent, tools_dict = initialize_agent(
180
  prompt_file="medrax/docs/system_prompts.txt",
181
  tools_to_use=selected_tools,
182
+ model_dir="/model-weights",
183
  temp_dir="temp", # Change this to the path of the temporary directory
184
+ device="cuda:0",
185
+ model="gemini-2.5-pro", # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
186
  temperature=0.7,
187
  top_p=0.95,
188
  model_kwargs=model_kwargs,