Spaces:
Paused
Paused
Adibvafa
committed on
Commit
·
a7d0aad
1
Parent(s):
5084d75
Fix prompt load
Browse files
- benchmarking/cli.py +47 -27
- benchmarking/llm_providers/base.py +6 -4
- benchmarking/llm_providers/medrax_provider.py +5 -3
- main.py +4 -1
benchmarking/cli.py
CHANGED
|
@@ -3,34 +3,40 @@
|
|
| 3 |
import argparse
|
| 4 |
import sys
|
| 5 |
|
| 6 |
-
from .llm_providers import
|
| 7 |
from .benchmarks import *
|
| 8 |
from .runner import BenchmarkRunner, BenchmarkRunConfig
|
| 9 |
|
| 10 |
|
| 11 |
-
def create_llm_provider(model_name: str, provider_type: str, **kwargs) -> LLMProvider:
|
| 12 |
"""Create an LLM provider based on the model name and type.
|
| 13 |
|
| 14 |
Args:
|
| 15 |
model_name (str): Name of the model
|
| 16 |
-
provider_type (str): Type of provider (openai, google, openrouter,
|
|
|
|
| 17 |
**kwargs: Additional configuration parameters
|
| 18 |
|
| 19 |
Returns:
|
| 20 |
LLMProvider: The configured LLM provider
|
| 21 |
"""
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
|
| 36 |
def create_benchmark(benchmark_name: str, data_dir: str, **kwargs) -> Benchmark:
|
|
@@ -63,12 +69,12 @@ def run_benchmark_command(args) -> None:
|
|
| 63 |
# Create LLM provider
|
| 64 |
provider_kwargs = {}
|
| 65 |
|
| 66 |
-
llm_provider = create_llm_provider(args.model, args.provider, **provider_kwargs)
|
| 67 |
|
| 68 |
# Create benchmark
|
| 69 |
benchmark_kwargs = {}
|
| 70 |
|
| 71 |
-
benchmark = create_benchmark(args.benchmark, args.data_dir, **benchmark_kwargs)
|
| 72 |
|
| 73 |
# Create runner config
|
| 74 |
config = BenchmarkRunConfig(
|
|
@@ -111,16 +117,30 @@ def main():
|
|
| 111 |
subparsers = parser.add_subparsers(dest="command", help="Available commands")
|
| 112 |
|
| 113 |
# Run benchmark command
|
| 114 |
-
run_parser = subparsers.add_parser("run", help="Run a benchmark")
|
| 115 |
-
run_parser.add_argument("--model", required=True,
|
| 116 |
-
|
| 117 |
-
run_parser.add_argument("--
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
run_parser.add_argument("--
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
run_parser.add_argument("--
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
run_parser.set_defaults(func=run_benchmark_command)
|
| 126 |
|
|
|
|
| 3 |
import argparse
|
| 4 |
import sys
|
| 5 |
|
| 6 |
+
from .llm_providers.base import LLMProvider
|
| 7 |
from .benchmarks import *
|
| 8 |
from .runner import BenchmarkRunner, BenchmarkRunConfig
|
| 9 |
|
| 10 |
|
| 11 |
+
def create_llm_provider(model_name: str, provider_type: str, system_prompt: str, **kwargs) -> LLMProvider:
|
| 12 |
"""Create an LLM provider based on the model name and type.
|
| 13 |
|
| 14 |
Args:
|
| 15 |
model_name (str): Name of the model
|
| 16 |
+
provider_type (str): Type of provider (openai, google, openrouter, medrax)
|
| 17 |
+
system_prompt (str): System prompt identifier to load from file
|
| 18 |
**kwargs: Additional configuration parameters
|
| 19 |
|
| 20 |
Returns:
|
| 21 |
LLMProvider: The configured LLM provider
|
| 22 |
"""
|
| 23 |
+
# Lazy imports to avoid slow startup
|
| 24 |
+
if provider_type == "openai":
|
| 25 |
+
from .llm_providers.openai_provider import OpenAIProvider
|
| 26 |
+
provider_class = OpenAIProvider
|
| 27 |
+
elif provider_type == "google":
|
| 28 |
+
from .llm_providers.google_provider import GoogleProvider
|
| 29 |
+
provider_class = GoogleProvider
|
| 30 |
+
elif provider_type == "openrouter":
|
| 31 |
+
from .llm_providers.openrouter_provider import OpenRouterProvider
|
| 32 |
+
provider_class = OpenRouterProvider
|
| 33 |
+
elif provider_type == "medrax":
|
| 34 |
+
from .llm_providers.medrax_provider import MedRAXProvider
|
| 35 |
+
provider_class = MedRAXProvider
|
| 36 |
+
else:
|
| 37 |
+
raise ValueError(f"Unknown provider type: {provider_type}. Available: openai, google, openrouter, medrax")
|
| 38 |
+
|
| 39 |
+
return provider_class(model_name, system_prompt, **kwargs)
|
| 40 |
|
| 41 |
|
| 42 |
def create_benchmark(benchmark_name: str, data_dir: str, **kwargs) -> Benchmark:
|
|
|
|
| 69 |
# Create LLM provider
|
| 70 |
provider_kwargs = {}
|
| 71 |
|
| 72 |
+
llm_provider = create_llm_provider(model_name=args.model, provider_type=args.provider, system_prompt=args.system_prompt, **provider_kwargs)
|
| 73 |
|
| 74 |
# Create benchmark
|
| 75 |
benchmark_kwargs = {}
|
| 76 |
|
| 77 |
+
benchmark = create_benchmark(benchmark_name=args.benchmark, data_dir=args.data_dir, **benchmark_kwargs)
|
| 78 |
|
| 79 |
# Create runner config
|
| 80 |
config = BenchmarkRunConfig(
|
|
|
|
| 117 |
subparsers = parser.add_subparsers(dest="command", help="Available commands")
|
| 118 |
|
| 119 |
# Run benchmark command
|
| 120 |
+
run_parser = subparsers.add_parser("run", help="Run a benchmark evaluation")
|
| 121 |
+
run_parser.add_argument("--model", required=True,
|
| 122 |
+
help="Model name (e.g., gpt-4o, gpt-4.1-2025-04-14, gemini-2.5-pro)")
|
| 123 |
+
run_parser.add_argument("--provider", required=True,
|
| 124 |
+
choices=["openai", "google", "openrouter", "medrax"],
|
| 125 |
+
help="LLM provider to use")
|
| 126 |
+
run_parser.add_argument("--system-prompt", required=True,
|
| 127 |
+
choices=["MEDICAL_ASSISTANT", "CHESTAGENTBENCH_PROMPT"],
|
| 128 |
+
help="System prompt: MEDICAL_ASSISTANT (general) or CHESTAGENTBENCH_PROMPT (benchmarks)")
|
| 129 |
+
run_parser.add_argument("--benchmark", required=True,
|
| 130 |
+
choices=["rexvqa", "chestagentbench"],
|
| 131 |
+
help="Benchmark dataset: rexvqa (radiology VQA) or chestagentbench (chest X-ray reasoning)")
|
| 132 |
+
run_parser.add_argument("--data-dir", required=True,
|
| 133 |
+
help="Directory containing benchmark data files")
|
| 134 |
+
run_parser.add_argument("--output-dir", default="benchmark_results",
|
| 135 |
+
help="Output directory for results (default: benchmark_results)")
|
| 136 |
+
run_parser.add_argument("--max-questions", type=int,
|
| 137 |
+
help="Maximum number of questions to process (default: all)")
|
| 138 |
+
run_parser.add_argument("--temperature", type=float, default=0.7,
|
| 139 |
+
help="Model temperature for response generation (default: 0.7)")
|
| 140 |
+
run_parser.add_argument("--top-p", type=float, default=0.95,
|
| 141 |
+
help="Top-p nucleus sampling parameter (default: 0.95)")
|
| 142 |
+
run_parser.add_argument("--max-tokens", type=int, default=1000,
|
| 143 |
+
help="Maximum tokens per model response (default: 1000)")
|
| 144 |
|
| 145 |
run_parser.set_defaults(func=run_benchmark_command)
|
| 146 |
|
benchmarking/llm_providers/base.py
CHANGED
|
@@ -35,22 +35,24 @@ class LLMProvider(ABC):
|
|
| 35 |
text + image input -> text output across different models and APIs.
|
| 36 |
"""
|
| 37 |
|
| 38 |
-
def __init__(self, model_name: str, **kwargs):
|
| 39 |
"""Initialize the LLM provider.
|
| 40 |
|
| 41 |
Args:
|
| 42 |
model_name (str): Name of the model to use
|
|
|
|
| 43 |
**kwargs: Additional configuration parameters
|
| 44 |
"""
|
| 45 |
self.model_name = model_name
|
| 46 |
self.config = kwargs
|
|
|
|
| 47 |
|
| 48 |
-
#
|
| 49 |
try:
|
| 50 |
prompts = load_prompts_from_file("medrax/docs/system_prompts.txt")
|
| 51 |
-
self.system_prompt = prompts.get(
|
| 52 |
if self.system_prompt is None:
|
| 53 |
-
print(f"Warning: System prompt not found in medrax/docs/system_prompts.txt.")
|
| 54 |
except Exception as e:
|
| 55 |
print(f"Error loading system prompt: {e}")
|
| 56 |
self.system_prompt = None
|
|
|
|
| 35 |
text + image input -> text output across different models and APIs.
|
| 36 |
"""
|
| 37 |
|
| 38 |
+
def __init__(self, model_name: str, system_prompt: str, **kwargs):
|
| 39 |
"""Initialize the LLM provider.
|
| 40 |
|
| 41 |
Args:
|
| 42 |
model_name (str): Name of the model to use
|
| 43 |
+
system_prompt (str): System prompt identifier to load from file
|
| 44 |
**kwargs: Additional configuration parameters
|
| 45 |
"""
|
| 46 |
self.model_name = model_name
|
| 47 |
self.config = kwargs
|
| 48 |
+
self.prompt_name = system_prompt # Store the original prompt identifier
|
| 49 |
|
| 50 |
+
# Load system prompt content from file
|
| 51 |
try:
|
| 52 |
prompts = load_prompts_from_file("medrax/docs/system_prompts.txt")
|
| 53 |
+
self.system_prompt = prompts.get(system_prompt, None)
|
| 54 |
if self.system_prompt is None:
|
| 55 |
+
print(f"Warning: System prompt '{system_prompt}' not found in medrax/docs/system_prompts.txt.")
|
| 56 |
except Exception as e:
|
| 57 |
print(f"Error loading system prompt: {e}")
|
| 58 |
self.system_prompt = None
|
benchmarking/llm_providers/medrax_provider.py
CHANGED
|
@@ -13,18 +13,19 @@ from main import initialize_agent
|
|
| 13 |
class MedRAXProvider(LLMProvider):
|
| 14 |
"""MedRAX LLM provider that uses the full MedRAX agent system."""
|
| 15 |
|
| 16 |
-
def __init__(self, model_name: str, **kwargs):
|
| 17 |
"""Initialize MedRAX provider.
|
| 18 |
|
| 19 |
Args:
|
| 20 |
model_name (str): Base LLM model name (e.g., "gpt-4.1-2025-04-14")
|
|
|
|
| 21 |
**kwargs: Additional configuration parameters
|
| 22 |
"""
|
| 23 |
self.model_name = model_name
|
| 24 |
self.agent = None
|
| 25 |
self.tools_dict = None
|
| 26 |
-
|
| 27 |
-
super().__init__(model_name, **kwargs)
|
| 28 |
|
| 29 |
def _setup(self) -> None:
|
| 30 |
"""Set up MedRAX agent system."""
|
|
@@ -75,6 +76,7 @@ class MedRAXProvider(LLMProvider):
|
|
| 75 |
top_p=0.95,
|
| 76 |
model_kwargs=model_kwargs,
|
| 77 |
rag_config=rag_config,
|
|
|
|
| 78 |
debug=True,
|
| 79 |
)
|
| 80 |
|
|
|
|
| 13 |
class MedRAXProvider(LLMProvider):
|
| 14 |
"""MedRAX LLM provider that uses the full MedRAX agent system."""
|
| 15 |
|
| 16 |
+
def __init__(self, model_name: str, system_prompt: str, **kwargs):
|
| 17 |
"""Initialize MedRAX provider.
|
| 18 |
|
| 19 |
Args:
|
| 20 |
model_name (str): Base LLM model name (e.g., "gpt-4.1-2025-04-14")
|
| 21 |
+
system_prompt (str): System prompt to use
|
| 22 |
**kwargs: Additional configuration parameters
|
| 23 |
"""
|
| 24 |
self.model_name = model_name
|
| 25 |
self.agent = None
|
| 26 |
self.tools_dict = None
|
| 27 |
+
|
| 28 |
+
super().__init__(model_name, system_prompt, **kwargs)
|
| 29 |
|
| 30 |
def _setup(self) -> None:
|
| 31 |
"""Set up MedRAX agent system."""
|
|
|
|
| 76 |
top_p=0.95,
|
| 77 |
model_kwargs=model_kwargs,
|
| 78 |
rag_config=rag_config,
|
| 79 |
+
system_prompt=self.prompt_name,
|
| 80 |
debug=True,
|
| 81 |
)
|
| 82 |
|
main.py
CHANGED
|
@@ -41,6 +41,7 @@ def initialize_agent(
|
|
| 41 |
top_p: float = 0.95,
|
| 42 |
rag_config: Optional[RAGConfig] = None,
|
| 43 |
model_kwargs: Dict[str, Any] = {},
|
|
|
|
| 44 |
debug: bool = False,
|
| 45 |
):
|
| 46 |
"""Initialize the MedRAX agent with specified tools and configuration.
|
|
@@ -56,6 +57,7 @@ def initialize_agent(
|
|
| 56 |
top_p (float, optional): Top P for the model. Defaults to 0.95.
|
| 57 |
rag_config (RAGConfig, optional): Configuration for the RAG tool. Defaults to None.
|
| 58 |
model_kwargs (dict, optional): Additional keyword arguments for model.
|
|
|
|
| 59 |
debug (bool, optional): Whether to enable debug mode. Defaults to False.
|
| 60 |
|
| 61 |
Returns:
|
|
@@ -63,7 +65,7 @@ def initialize_agent(
|
|
| 63 |
"""
|
| 64 |
# Load system prompts from file
|
| 65 |
prompts = load_prompts_from_file(prompt_file)
|
| 66 |
-
prompt = prompts[
|
| 67 |
|
| 68 |
all_tools = {
|
| 69 |
"TorchXRayVisionClassifierTool": lambda: TorchXRayVisionClassifierTool(device=device),
|
|
@@ -186,6 +188,7 @@ if __name__ == "__main__":
|
|
| 186 |
model_kwargs=model_kwargs,
|
| 187 |
rag_config=rag_config,
|
| 188 |
debug=True,
|
|
|
|
| 189 |
)
|
| 190 |
|
| 191 |
# Create and launch the web interface
|
|
|
|
| 41 |
top_p: float = 0.95,
|
| 42 |
rag_config: Optional[RAGConfig] = None,
|
| 43 |
model_kwargs: Dict[str, Any] = {},
|
| 44 |
+
system_prompt: str = "MEDICAL_ASSISTANT",
|
| 45 |
debug: bool = False,
|
| 46 |
):
|
| 47 |
"""Initialize the MedRAX agent with specified tools and configuration.
|
|
|
|
| 57 |
top_p (float, optional): Top P for the model. Defaults to 0.95.
|
| 58 |
rag_config (RAGConfig, optional): Configuration for the RAG tool. Defaults to None.
|
| 59 |
model_kwargs (dict, optional): Additional keyword arguments for model.
|
| 60 |
+
system_prompt (str, optional): System prompt to use. Defaults to "MEDICAL_ASSISTANT".
|
| 61 |
debug (bool, optional): Whether to enable debug mode. Defaults to False.
|
| 62 |
|
| 63 |
Returns:
|
|
|
|
| 65 |
"""
|
| 66 |
# Load system prompts from file
|
| 67 |
prompts = load_prompts_from_file(prompt_file)
|
| 68 |
+
prompt = prompts[system_prompt]
|
| 69 |
|
| 70 |
all_tools = {
|
| 71 |
"TorchXRayVisionClassifierTool": lambda: TorchXRayVisionClassifierTool(device=device),
|
|
|
|
| 188 |
model_kwargs=model_kwargs,
|
| 189 |
rag_config=rag_config,
|
| 190 |
debug=True,
|
| 191 |
+
system_prompt="MEDICAL_ASSISTANT",
|
| 192 |
)
|
| 193 |
|
| 194 |
# Create and launch the web interface
|