Adibvafa committed
Commit a7d0aad · 1 Parent(s): 5084d75

Fix prompt load

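This commit threads a user-selectable system prompt through the benchmarking stack: the CLI gains a required --system-prompt flag, create_llm_provider forwards the identifier to each provider, LLMProvider resolves it from medrax/docs/system_prompts.txt instead of hard-coding CHESTAGENTBENCH_PROMPT, and initialize_agent in main.py accepts it as a keyword argument. Assuming the CLI is launched as a module (the entry point is outside this diff, and the data path below is a placeholder), a run after this change might look like:

    python -m benchmarking.cli run \
        --model gpt-4.1-2025-04-14 \
        --provider medrax \
        --system-prompt CHESTAGENTBENCH_PROMPT \
        --benchmark chestagentbench \
        --data-dir data/chestagentbench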
benchmarking/cli.py CHANGED
@@ -3,34 +3,40 @@
 import argparse
 import sys
 
-from .llm_providers import *
+from .llm_providers.base import LLMProvider
 from .benchmarks import *
 from .runner import BenchmarkRunner, BenchmarkRunConfig
 
 
-def create_llm_provider(model_name: str, provider_type: str, **kwargs) -> LLMProvider:
+def create_llm_provider(model_name: str, provider_type: str, system_prompt: str, **kwargs) -> LLMProvider:
     """Create an LLM provider based on the model name and type.
 
     Args:
         model_name (str): Name of the model
-        provider_type (str): Type of provider (openai, google, openrouter, anthropic, medrax)
+        provider_type (str): Type of provider (openai, google, openrouter, medrax)
+        system_prompt (str): System prompt identifier to load from file
         **kwargs: Additional configuration parameters
 
     Returns:
         LLMProvider: The configured LLM provider
     """
-    provider_map = {
-        "openai": OpenAIProvider,
-        "google": GoogleProvider,
-        "openrouter": OpenRouterProvider,
-        "medrax": MedRAXProvider,
-    }
-
-    if provider_type not in provider_map:
-        raise ValueError(f"Unknown provider type: {provider_type}. Available: {list(provider_map.keys())}")
-
-    provider_class = provider_map[provider_type]
-    return provider_class(model_name, **kwargs)
+    # Lazy imports to avoid slow startup
+    if provider_type == "openai":
+        from .llm_providers.openai_provider import OpenAIProvider
+        provider_class = OpenAIProvider
+    elif provider_type == "google":
+        from .llm_providers.google_provider import GoogleProvider
+        provider_class = GoogleProvider
+    elif provider_type == "openrouter":
+        from .llm_providers.openrouter_provider import OpenRouterProvider
+        provider_class = OpenRouterProvider
+    elif provider_type == "medrax":
+        from .llm_providers.medrax_provider import MedRAXProvider
+        provider_class = MedRAXProvider
+    else:
+        raise ValueError(f"Unknown provider type: {provider_type}. Available: openai, google, openrouter, medrax")
+
+    return provider_class(model_name, system_prompt, **kwargs)
 
 
 def create_benchmark(benchmark_name: str, data_dir: str, **kwargs) -> Benchmark:
@@ -63,12 +69,12 @@ def run_benchmark_command(args) -> None:
     # Create LLM provider
     provider_kwargs = {}
 
-    llm_provider = create_llm_provider(args.model, args.provider, **provider_kwargs)
+    llm_provider = create_llm_provider(model_name=args.model, provider_type=args.provider, system_prompt=args.system_prompt, **provider_kwargs)
 
     # Create benchmark
     benchmark_kwargs = {}
 
-    benchmark = create_benchmark(args.benchmark, args.data_dir, **benchmark_kwargs)
+    benchmark = create_benchmark(benchmark_name=args.benchmark, data_dir=args.data_dir, **benchmark_kwargs)
 
     # Create runner config
     config = BenchmarkRunConfig(
@@ -111,16 +117,30 @@ def main():
     subparsers = parser.add_subparsers(dest="command", help="Available commands")
 
     # Run benchmark command
-    run_parser = subparsers.add_parser("run", help="Run a benchmark")
-    run_parser.add_argument("--model", required=True, help="Model name (e.g., gpt-4o, gemini-2.5-pro)")
-    run_parser.add_argument("--provider", required=True, choices=["openai", "google", "openrouter", "medrax"], help="LLM provider")
-    run_parser.add_argument("--benchmark", required=True, choices=["rexvqa", "chestagentbench"], help="Benchmark to run")
-    run_parser.add_argument("--data-dir", required=True, help="Directory containing benchmark data")
-    run_parser.add_argument("--output-dir", default="benchmark_results", help="Output directory for results")
-    run_parser.add_argument("--max-questions", type=int, help="Maximum number of questions to process")
-    run_parser.add_argument("--temperature", type=float, default=0.7, help="Model temperature")
-    run_parser.add_argument("--top-p", type=float, default=0.95, help="Top-p value")
-    run_parser.add_argument("--max-tokens", type=int, default=1000, help="Maximum tokens per response")
+    run_parser = subparsers.add_parser("run", help="Run a benchmark evaluation")
+    run_parser.add_argument("--model", required=True,
+                            help="Model name (e.g., gpt-4o, gpt-4.1-2025-04-14, gemini-2.5-pro)")
+    run_parser.add_argument("--provider", required=True,
+                            choices=["openai", "google", "openrouter", "medrax"],
+                            help="LLM provider to use")
+    run_parser.add_argument("--system-prompt", required=True,
+                            choices=["MEDICAL_ASSISTANT", "CHESTAGENTBENCH_PROMPT"],
+                            help="System prompt: MEDICAL_ASSISTANT (general) or CHESTAGENTBENCH_PROMPT (benchmarks)")
+    run_parser.add_argument("--benchmark", required=True,
+                            choices=["rexvqa", "chestagentbench"],
+                            help="Benchmark dataset: rexvqa (radiology VQA) or chestagentbench (chest X-ray reasoning)")
+    run_parser.add_argument("--data-dir", required=True,
+                            help="Directory containing benchmark data files")
+    run_parser.add_argument("--output-dir", default="benchmark_results",
+                            help="Output directory for results (default: benchmark_results)")
+    run_parser.add_argument("--max-questions", type=int,
+                            help="Maximum number of questions to process (default: all)")
+    run_parser.add_argument("--temperature", type=float, default=0.7,
+                            help="Model temperature for response generation (default: 0.7)")
+    run_parser.add_argument("--top-p", type=float, default=0.95,
+                            help="Top-p nucleus sampling parameter (default: 0.95)")
+    run_parser.add_argument("--max-tokens", type=int, default=1000,
+                            help="Maximum tokens per model response (default: 1000)")
 
     run_parser.set_defaults(func=run_benchmark_command)
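The removed wildcard import (from .llm_providers import *) pulled in every provider and its SDK at startup; the rewritten factory defers each import until that provider is actually selected. A minimal standalone sketch of the same lazy-dispatch pattern, using illustrative module and class names rather than the repository's exact layout:

    import importlib

    # Provider key -> (module path, class name). Nothing is imported
    # until a key is actually requested.
    _PROVIDERS = {
        "openai": ("llm_providers.openai_provider", "OpenAIProvider"),
        "google": ("llm_providers.google_provider", "GoogleProvider"),
    }

    def make_provider(provider_type, model_name, system_prompt):
        try:
            module_path, class_name = _PROVIDERS[provider_type]
        except KeyError:
            raise ValueError(
                f"Unknown provider type: {provider_type}. "
                f"Available: {', '.join(_PROVIDERS)}"
            )
        # The import happens here, only for the chosen provider.
        provider_class = getattr(importlib.import_module(module_path), class_name)
        return provider_class(model_name, system_prompt)

A lookup table plus importlib would keep the old code's single dispatch point while still gaining the deferred imports of the new if/elif chain.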
benchmarking/llm_providers/base.py CHANGED
@@ -35,22 +35,24 @@ class LLMProvider(ABC):
     text + image input -> text output across different models and APIs.
     """
 
-    def __init__(self, model_name: str, **kwargs):
+    def __init__(self, model_name: str, system_prompt: str, **kwargs):
        """Initialize the LLM provider.
 
        Args:
            model_name (str): Name of the model to use
+            system_prompt (str): System prompt identifier to load from file
            **kwargs: Additional configuration parameters
        """
        self.model_name = model_name
        self.config = kwargs
+        self.prompt_name = system_prompt  # Store the original prompt identifier
 
-        # Always load system prompt from file
+        # Load system prompt content from file
        try:
            prompts = load_prompts_from_file("medrax/docs/system_prompts.txt")
-            self.system_prompt = prompts.get("CHESTAGENTBENCH_PROMPT", None)
+            self.system_prompt = prompts.get(system_prompt, None)
            if self.system_prompt is None:
-                print(f"Warning: System prompt not found in medrax/docs/system_prompts.txt.")
+                print(f"Warning: System prompt '{system_prompt}' not found in medrax/docs/system_prompts.txt.")
        except Exception as e:
            print(f"Error loading system prompt: {e}")
            self.system_prompt = None
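base.py now resolves whichever identifier the caller passes and degrades gracefully: a missing key produces a warning and system_prompt = None rather than an exception, so downstream code must tolerate a None prompt. A self-contained sketch of that lookup behavior (the prompt texts below are hypothetical stand-ins for medrax/docs/system_prompts.txt):

    def resolve_prompt(prompts, name):
        # Mirrors the lookup in LLMProvider.__init__: return the named
        # prompt's text, or None with a warning for unknown identifiers.
        text = prompts.get(name)
        if text is None:
            print(f"Warning: System prompt '{name}' not found.")
        return text

    prompts = {
        "MEDICAL_ASSISTANT": "You are a helpful medical assistant...",
        "CHESTAGENTBENCH_PROMPT": "Answer the multiple-choice question...",
    }

    assert resolve_prompt(prompts, "CHESTAGENTBENCH_PROMPT") is not None
    assert resolve_prompt(prompts, "MISSPELLED_PROMPT") is None  # prints a warning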
benchmarking/llm_providers/medrax_provider.py CHANGED
@@ -13,18 +13,19 @@ from main import initialize_agent
 class MedRAXProvider(LLMProvider):
     """MedRAX LLM provider that uses the full MedRAX agent system."""
 
-    def __init__(self, model_name: str, **kwargs):
+    def __init__(self, model_name: str, system_prompt: str, **kwargs):
        """Initialize MedRAX provider.
 
        Args:
            model_name (str): Base LLM model name (e.g., "gpt-4.1-2025-04-14")
+            system_prompt (str): System prompt to use
            **kwargs: Additional configuration parameters
        """
        self.model_name = model_name
        self.agent = None
        self.tools_dict = None
 
-        super().__init__(model_name, **kwargs)
+        super().__init__(model_name, system_prompt, **kwargs)
 
     def _setup(self) -> None:
        """Set up MedRAX agent system."""
@@ -75,6 +76,7 @@ class MedRAXProvider(LLMProvider):
            top_p=0.95,
            model_kwargs=model_kwargs,
            rag_config=rag_config,
+            system_prompt=self.prompt_name,
            debug=True,
        )
 
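MedRAXProvider hands initialize_agent the identifier (self.prompt_name, stored by the base class) rather than the resolved text, so the agent re-resolves the prompt from its own prompt file. A hypothetical direct construction, normally performed for you by create_llm_provider:

    provider = MedRAXProvider("gpt-4.1-2025-04-14", "CHESTAGENTBENCH_PROMPT")
    provider.prompt_name    # "CHESTAGENTBENCH_PROMPT", forwarded to the agent
    provider.system_prompt  # resolved prompt text, or None if the lookup failed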
main.py CHANGED
@@ -41,6 +41,7 @@ def initialize_agent(
     top_p: float = 0.95,
     rag_config: Optional[RAGConfig] = None,
     model_kwargs: Dict[str, Any] = {},
+    system_prompt: str = "MEDICAL_ASSISTANT",
     debug: bool = False,
 ):
     """Initialize the MedRAX agent with specified tools and configuration.
@@ -56,6 +57,7 @@
         top_p (float, optional): Top P for the model. Defaults to 0.95.
         rag_config (RAGConfig, optional): Configuration for the RAG tool. Defaults to None.
         model_kwargs (dict, optional): Additional keyword arguments for model.
+        system_prompt (str, optional): System prompt to use. Defaults to "MEDICAL_ASSISTANT".
         debug (bool, optional): Whether to enable debug mode. Defaults to False.
 
     Returns:
@@ -63,7 +65,7 @@
     """
     # Load system prompts from file
     prompts = load_prompts_from_file(prompt_file)
-    prompt = prompts["MEDICAL_ASSISTANT"]
+    prompt = prompts[system_prompt]
 
     all_tools = {
         "TorchXRayVisionClassifierTool": lambda: TorchXRayVisionClassifierTool(device=device),
@@ -186,6 +188,7 @@ if __name__ == "__main__":
         model_kwargs=model_kwargs,
         rag_config=rag_config,
         debug=True,
+        system_prompt="MEDICAL_ASSISTANT",
     )
 
     # Create and launch the web interface
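initialize_agent keeps MEDICAL_ASSISTANT as its default, so existing callers are unaffected; only the benchmarking path passes a different key. A sketch of an explicit call, showing only the parameters visible in this diff (everything else keeps the defaults main.py defines):

    initialize_agent(
        system_prompt="CHESTAGENTBENCH_PROMPT",  # any key present in the prompt file
        top_p=0.95,
        debug=False,
    )

Note the asymmetry: main.py indexes prompts[system_prompt] directly and will raise KeyError on an unknown key, whereas base.py uses .get() and merely warns.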