VictorLJZ commited on
Commit
587ddab
·
1 Parent(s): 4216843
benchmarking/cli.py CHANGED
@@ -139,8 +139,8 @@ def main():
139
  help="Model temperature for response generation (default: 0.7)")
140
  run_parser.add_argument("--top-p", type=float, default=0.95,
141
  help="Top-p nucleus sampling parameter (default: 0.95)")
142
- run_parser.add_argument("--max-tokens", type=int, default=1000,
143
- help="Maximum tokens per model response (default: 1000)")
144
 
145
  run_parser.set_defaults(func=run_benchmark_command)
146
 
 
139
  help="Model temperature for response generation (default: 0.7)")
140
  run_parser.add_argument("--top-p", type=float, default=0.95,
141
  help="Top-p nucleus sampling parameter (default: 0.95)")
142
+ run_parser.add_argument("--max-tokens", type=int, default=5000,
143
+ help="Maximum tokens per model response (default: 5000)")
144
 
145
  run_parser.set_defaults(func=run_benchmark_command)
146
 
benchmarking/llm_providers/medrax_provider.py CHANGED
@@ -33,19 +33,19 @@ class MedRAXProvider(LLMProvider):
33
  print("Starting server...")
34
 
35
  selected_tools = [
36
- # "ImageVisualizerTool", # For displaying images in the UI
37
- # "DicomProcessorTool", # For processing DICOM medical image files
38
- # "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
39
  # "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
40
- # "ChestXRaySegmentationTool", # For segmenting anatomical regions in chest X-rays
41
- # "ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
42
- # "XRayVQATool", # For visual question answering on X-rays
43
- # "LlavaMedTool", # For multimodal medical image understanding
44
- # "XRayPhraseGroundingTool", # For locating described features in X-rays
45
  # "ChestXRayGeneratorTool", # For generating synthetic chest X-rays
46
  "WebBrowserTool", # For web browsing and search capabilities
47
  "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
48
- # "PythonSandboxTool", # Add the Python sandbox tool
49
  ]
50
 
51
  rag_config = RAGConfig(
@@ -68,7 +68,7 @@ class MedRAXProvider(LLMProvider):
68
  agent, tools_dict = initialize_agent(
69
  prompt_file="medrax/docs/system_prompts.txt",
70
  tools_to_use=selected_tools,
71
- model_dir="/model-weights",
72
  temp_dir="temp", # Change this to the path of the temporary directory
73
  device="cpu",
74
  model=self.model_name, # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
 
33
  print("Starting server...")
34
 
35
  selected_tools = [
36
+ "ImageVisualizerTool", # For displaying images in the UI
37
+ "DicomProcessorTool", # For processing DICOM medical image files
38
+ "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
39
  # "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
40
+ "ChestXRaySegmentationTool", # For segmenting anatomical regions in chest X-rays
41
+ "ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
42
+ "XRayVQATool", # For visual question answering on X-rays
43
+ "LlavaMedTool", # For multimodal medical image understanding
44
+ "XRayPhraseGroundingTool", # For locating described features in X-rays
45
  # "ChestXRayGeneratorTool", # For generating synthetic chest X-rays
46
  "WebBrowserTool", # For web browsing and search capabilities
47
  "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
48
+ "PythonSandboxTool", # Add the Python sandbox tool
49
  ]
50
 
51
  rag_config = RAGConfig(
 
68
  agent, tools_dict = initialize_agent(
69
  prompt_file="medrax/docs/system_prompts.txt",
70
  tools_to_use=selected_tools,
71
+ model_dir="model-weights",
72
  temp_dir="temp", # Change this to the path of the temporary directory
73
  device="cpu",
74
  model=self.model_name, # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
medrax/agent/agent.py CHANGED
@@ -1,12 +1,8 @@
1
- import json
2
- import operator
3
  from pathlib import Path
4
  from dotenv import load_dotenv
5
- from datetime import datetime
6
- from typing import List, Dict, Any, TypedDict, Annotated, Optional
7
 
8
  from langgraph.prebuilt import create_react_agent
9
- from langchain_core.messages import AnyMessage
10
  from langgraph.prebuilt.chat_agent_executor import AgentState
11
  from langchain_core.language_models import BaseLanguageModel
12
  from langchain_core.tools import BaseTool
 
 
 
1
  from pathlib import Path
2
  from dotenv import load_dotenv
3
+ from typing import List, Any, TypedDict, Optional
 
4
 
5
  from langgraph.prebuilt import create_react_agent
 
6
  from langgraph.prebuilt.chat_agent_executor import AgentState
7
  from langchain_core.language_models import BaseLanguageModel
8
  from langchain_core.tools import BaseTool
medrax/tools/rag.py CHANGED
@@ -48,14 +48,14 @@ class RAGTool(BaseTool):
48
  self.rag = CohereRAG(config)
49
  self.chain = self.rag.initialize_rag(with_memory=True)
50
 
51
- def _run(self, query: str) -> Tuple[Dict[str, Any], Dict[str, Any]]:
52
  """Execute the RAG tool with the given query.
53
 
54
  Args:
55
  query (str): Medical question to answer
56
 
57
  Returns:
58
- Tuple[Dict[str, Any], Dict[str, Any]]: Output dictionary and metadata dictionary
59
  """
60
  try:
61
  result = self.chain.invoke({"query": query})
@@ -87,14 +87,14 @@ class RAGTool(BaseTool):
87
  }
88
  return output, metadata
89
 
90
- async def _arun(self, query: str) -> Tuple[Dict[str, Any], Dict[str, Any]]:
91
  """Async version of _run.
92
 
93
  Args:
94
  query (str): Medical question to answer
95
 
96
  Returns:
97
- Tuple[Dict[str, Any], Dict[str, Any]]: Output dictionary and metadata dictionary
98
 
99
  Raises:
100
  NotImplementedError: Async not implemented yet
 
48
  self.rag = CohereRAG(config)
49
  self.chain = self.rag.initialize_rag(with_memory=True)
50
 
51
+ def _run(self, query: str) -> Tuple[Dict[str, Any], Dict]:
52
  """Execute the RAG tool with the given query.
53
 
54
  Args:
55
  query (str): Medical question to answer
56
 
57
  Returns:
58
+ Tuple[Dict[str, Any], Dict]: Output dictionary and metadata dictionary
59
  """
60
  try:
61
  result = self.chain.invoke({"query": query})
 
87
  }
88
  return output, metadata
89
 
90
+ async def _arun(self, query: str) -> Tuple[Dict[str, Any], Dict]:
91
  """Async version of _run.
92
 
93
  Args:
94
  query (str): Medical question to answer
95
 
96
  Returns:
97
+ Tuple[Dict[str, Any], Dict]: Output dictionary and metadata dictionary
98
 
99
  Raises:
100
  NotImplementedError: Async not implemented yet