Commit 89321e2
Junzhe Li committed
Parent(s): b93ad3f

revamped benchmarking suite
Files changed:
- benchmarking/benchmarks/base.py +3 -20
- benchmarking/benchmarks/rexvqa_benchmark.py +2 -2
- benchmarking/cli.py +52 -64
- benchmarking/llm_providers/__init__.py +14 -0
- benchmarking/llm_providers/base.py +44 -24
- benchmarking/llm_providers/google_provider.py +13 -10
- benchmarking/llm_providers/medgemma_provider.py +222 -0
- benchmarking/llm_providers/medrax_provider.py +97 -33
- benchmarking/llm_providers/openai_provider.py +16 -13
- benchmarking/llm_providers/openrouter_provider.py +14 -8
- benchmarking/runner.py +38 -102
- benchmarking/system_prompts.txt +36 -0
benchmarking/benchmarks/base.py
CHANGED
@@ -39,7 +39,7 @@ class Benchmark(ABC):
         self._load_data()
         self._shuffle_data()

-        self.max_questions =
+        self.max_questions = self.config.get("max_questions", None)
         if self.max_questions:
             self.data_points = self.data_points[:self.max_questions]
             print(f"Randomly sampled {self.max_questions} questions from {self.__class__.__name__}")
@@ -51,12 +51,13 @@ class Benchmark(ABC):
         """Load benchmark data from the data directory."""
         pass

-    def _shuffle_data(self
+    def _shuffle_data(self) -> None:
         """Shuffle the data points if a random seed is provided. If no random seed is provided, use 42 as default.

         This method is called automatically after data loading to ensure
         reproducible benchmark runs when a random seed is specified.
         """
+        random_seed = self.config.get("random_seed", 42)
         random.seed(random_seed)
         random.shuffle(self.data_points)
         print(f"Shuffled {len(self.data_points)} data points with seed {random_seed}")
@@ -99,21 +100,3 @@ class Benchmark(ABC):
         for i in range(len(self)):
             yield self.get_data_point(i)

-    def validate_images(self) -> Tuple[List[str], List[str]]:
-        """Validate that all image paths exist.
-
-        Returns:
-            Tuple[List[str], List[str]]: Tuple of (valid_image_paths, invalid_image_paths)
-        """
-        valid_images = []
-        invalid_images = []
-
-        for dp in self:
-            if dp.images:
-                for image_path in dp.images:
-                    if Path(image_path).exists():
-                        valid_images.append(image_path)
-                    else:
-                        invalid_images.append(image_path)
-
-        return valid_images, invalid_images
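Net effect: sampling is now driven by the benchmark's config dict rather than constructor arguments, with a fixed default seed of 42. A standalone sketch of the resulting shuffle-then-truncate behavior (the config dict and data points below are stand-ins, not from the repo):

    import random

    # Minimal sketch of what the new Benchmark constructor does:
    # shuffle deterministically, then slice to max_questions.
    config = {"max_questions": 2, "random_seed": 42}
    data_points = ["q1", "q2", "q3", "q4", "q5"]

    random.seed(config.get("random_seed", 42))   # default seed 42, as in _shuffle_data
    random.shuffle(data_points)

    max_questions = config.get("max_questions", None)
    if max_questions:
        data_points = data_points[:max_questions]  # "randomly sampled" = shuffle + slice

    print(data_points)  # same output for every run with the same seed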
benchmarking/benchmarks/rexvqa_benchmark.py
CHANGED
@@ -48,11 +48,11 @@ class ReXVQABenchmark(Benchmark):
             max_questions (int): Maximum number of questions to load (default: None, load all)
             images_dir (str): Directory containing extracted PNG images (default: None)
         """
-        super().__init__(data_dir, **kwargs)
-
         self.split = kwargs.get("split", "test")
         self.images_dir = f"{data_dir}/images/deid_png"

+        super().__init__(data_dir, **kwargs)
+
     def _load_data(self) -> None:
         """Load ReXVQA data from HuggingFace."""
         try:
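Moving `super().__init__(data_dir, **kwargs)` below the attribute assignments matters because the base constructor now runs `_load_data()` (see base.py above), so any attribute the subclass's loader reads must already exist. A hypothetical minimal repro of the ordering hazard:

    # Sketch only; Base/Sub stand in for Benchmark/ReXVQABenchmark.
    class Base:
        def __init__(self):
            self._load_data()  # the base constructor triggers loading

    class Sub(Base):
        def __init__(self, split: str = "test"):
            # Calling super().__init__() first would raise AttributeError,
            # because _load_data() reads self.split before it is assigned.
            self.split = split
            super().__init__()

        def _load_data(self):
            print(f"loading split={self.split}")

    Sub()  # prints: loading split=test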
benchmarking/cli.py
CHANGED
@@ -8,12 +8,35 @@ from .benchmarks import *
 from .runner import BenchmarkRunner, BenchmarkRunConfig


-def create_llm_provider(model_name: str, provider_type: str, system_prompt: str,
+def create_benchmark(benchmark_name: str, data_dir: str, **kwargs) -> Benchmark:
+    """Create a benchmark based on the benchmark name.
+
+    Args:
+        benchmark_name (str): Name of the benchmark
+        data_dir (str): Directory containing benchmark data
+        **kwargs: Additional configuration parameters
+
+    Returns:
+        Benchmark: The configured benchmark
+    """
+    benchmark_map = {
+        "rexvqa": ReXVQABenchmark,
+        "chestagentbench": ChestAgentBenchBenchmark,
+    }
+
+    if benchmark_name not in benchmark_map:
+        raise ValueError(f"Unknown benchmark: {benchmark_name}. Available: {list(benchmark_map.keys())}")
+
+    benchmark_class = benchmark_map[benchmark_name]
+    return benchmark_class(data_dir, **kwargs)
+
+
+def create_llm_provider(provider_type: str, model_name: str, system_prompt: str, **kwargs) -> LLMProvider:
     """Create an LLM provider based on the model name and type.

     Args:
+        provider_type (str): Type of provider (openai, google, openrouter, medrax, medgemma)
         model_name (str): Name of the model
-        provider_type (str): Type of provider (openai, google, openrouter, medrax)
         system_prompt (str): System prompt identifier to load from file
         **kwargs: Additional configuration parameters

@@ -33,85 +56,50 @@ def create_llm_provider
     elif provider_type == "medrax":
         from .llm_providers.medrax_provider import MedRAXProvider
         provider_class = MedRAXProvider
+    elif provider_type == "medgemma":
+        from .llm_providers.medgemma_provider import MedGemmaProvider
+        provider_class = MedGemmaProvider
     else:
-        raise ValueError(f"Unknown provider type: {provider_type}. Available: openai, google, openrouter, medrax")
+        raise ValueError(f"Unknown provider type: {provider_type}. Available: openai, google, openrouter, medrax, medgemma")

     return provider_class(model_name, system_prompt, **kwargs)


-def create_benchmark(benchmark_name: str, data_dir: str, **kwargs) -> Benchmark:
-    """Create a benchmark based on the benchmark name.
-
-    Args:
-        benchmark_name (str): Name of the benchmark
-        data_dir (str): Directory containing benchmark data
-        **kwargs: Additional configuration parameters
-
-    Returns:
-        Benchmark: The configured benchmark
-    """
-    benchmark_map = {
-        "rexvqa": ReXVQABenchmark,
-        "chestagentbench": ChestAgentBenchBenchmark,
-    }
-
-    if benchmark_name not in benchmark_map:
-        raise ValueError(f"Unknown benchmark: {benchmark_name}. Available: {list(benchmark_map.keys())}")
-
-    benchmark_class = benchmark_map[benchmark_name]
-    return benchmark_class(data_dir, **kwargs)
-
-
 def run_benchmark_command(args) -> None:
     """Run a benchmark."""
-    print(f"Running benchmark: {args.benchmark} with model: {args.model}")
-
-    # Create LLM provider
-    provider_kwargs = {}
-
-    llm_provider = create_llm_provider(model_name=args.model, provider_type=args.provider, system_prompt=args.system_prompt, **provider_kwargs)
+    print(f"Running benchmark: {args.benchmark} with provider: {args.provider}, model: {args.model}")

     # Create benchmark
     benchmark_kwargs = {}
-
-
-
+    benchmark_kwargs["max_questions"] = args.max_questions
+    benchmark_kwargs["random_seed"] = args.random_seed
     benchmark = create_benchmark(benchmark_name=args.benchmark, data_dir=args.data_dir, **benchmark_kwargs)
+
+    # Create LLM provider
+    provider_kwargs = {}
+    provider_kwargs["temperature"] = args.temperature
+    provider_kwargs["top_p"] = args.top_p
+    provider_kwargs["max_tokens"] = args.max_tokens
+    llm_provider = create_llm_provider(provider_type=args.provider, model_name=args.model, system_prompt=args.system_prompt, **provider_kwargs)

     # Create runner config
     config = BenchmarkRunConfig(
+        benchmark_name=args.benchmark,
         provider_name=args.provider,
         model_name=args.model,
-        benchmark_name=args.benchmark,
         output_dir=args.output_dir,
         max_questions=args.max_questions,
         temperature=args.temperature,
         top_p=args.top_p,
         max_tokens=args.max_tokens,
-        concurrency=args.concurrency
+        concurrency=args.concurrency,
+        random_seed=args.random_seed
     )

     # Run benchmark
     runner = BenchmarkRunner(config)
-    summary = runner.run_benchmark(
-
-    print("\n" + "="*50)
-    print("BENCHMARK COMPLETED")
-    print("="*50)
-
-    # Check if benchmark run was successful
-    if "error" in summary:
-        print(f"Error: {summary['error']}")
-        return
-
-    # Print results
-    print(f"Model: {args.model}")
-    print(f"Benchmark: {args.benchmark}")
-    print(f"Total Questions: {summary['results']['total_questions']}")
-    print(f"Correct Answers: {summary['results']['correct_answers']}")
-    print(f"Overall Accuracy: {summary['results']['accuracy']:.2f}%")
-    print(f"Total Duration: {summary['results']['total_duration']:.2f}s")
-    print(f"Results saved to: {summary['results_file']}")
+    summary = runner.run_benchmark(benchmark, llm_provider)
+    print(summary)


 def main():
@@ -121,17 +109,17 @@ def main():

     # Run benchmark command
     run_parser = subparsers.add_parser("run", help="Run a benchmark evaluation")
-    run_parser.add_argument("--
-
+    run_parser.add_argument("--benchmark", required=True,
+                            choices=["rexvqa", "chestagentbench"],
+                            help="Benchmark dataset: rexvqa (radiology VQA) or chestagentbench (chest X-ray reasoning)")
     run_parser.add_argument("--provider", required=True,
-                            choices=["openai", "google", "openrouter", "medrax"],
+                            choices=["openai", "google", "openrouter", "medrax", "medgemma"],
                             help="LLM provider to use")
+    run_parser.add_argument("--model", required=True,
+                            help="Model name (e.g., gpt-4o, gpt-4.1-2025-04-14, gemini-2.5-pro)")
     run_parser.add_argument("--system-prompt", required=True,
                             choices=["MEDICAL_ASSISTANT", "CHESTAGENTBENCH_PROMPT"],
                             help="System prompt: MEDICAL_ASSISTANT (general) or CHESTAGENTBENCH_PROMPT (benchmarks)")
-    run_parser.add_argument("--benchmark", required=True,
-                            choices=["rexvqa", "chestagentbench"],
-                            help="Benchmark dataset: rexvqa (radiology VQA) or chestagentbench (chest X-ray reasoning)")
     run_parser.add_argument("--data-dir", required=True,
                             help="Directory containing benchmark data files")
     run_parser.add_argument("--output-dir", default="benchmark_results",
@@ -144,10 +132,10 @@ def main():
                             help="Top-p nucleus sampling parameter (default: 0.95)")
     run_parser.add_argument("--max-tokens", type=int, default=5000,
                             help="Maximum tokens per model response (default: 5000)")
-    run_parser.add_argument("--random-seed", type=int, default=42,
-                            help="Random seed for shuffling benchmark data (enables reproducible runs, default: None)")
     run_parser.add_argument("--concurrency", type=int, default=1,
                             help="Number of datapoints to process in parallel (default: 1)")
+    run_parser.add_argument("--random-seed", type=int, default=42,
+                            help="Random seed for shuffling benchmark data (enables reproducible runs, default: 42)")

     run_parser.set_defaults(func=run_benchmark_command)
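Taken together, the `run` subcommand now routes sampling options into the benchmark factory and generation options into the provider factory. A sketch of the equivalent direct calls, assuming the `benchmarking` package is importable; the data directory and model name are illustrative, not from the commit:

    from benchmarking.cli import create_benchmark, create_llm_provider
    from benchmarking.runner import BenchmarkRunner, BenchmarkRunConfig

    # Sampling options go to the benchmark...
    benchmark = create_benchmark(
        benchmark_name="rexvqa",
        data_dir="data/rexvqa",     # illustrative path
        max_questions=100,
        random_seed=42,
    )
    # ...generation options go to the provider.
    llm_provider = create_llm_provider(
        provider_type="openai",
        model_name="gpt-4o",        # illustrative model
        system_prompt="MEDICAL_ASSISTANT",
        temperature=0.7,
        top_p=0.95,
        max_tokens=5000,
    )
    config = BenchmarkRunConfig(
        benchmark_name="rexvqa",
        provider_name="openai",
        model_name="gpt-4o",
        output_dir="benchmark_results",
        max_questions=100,
        random_seed=42,
    )
    summary = BenchmarkRunner(config).run_benchmark(benchmark, llm_provider)
    print(summary)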
benchmarking/llm_providers/__init__.py
CHANGED
@@ -5,6 +5,17 @@ from .openai_provider import OpenAIProvider
 from .google_provider import GoogleProvider
 from .medrax_provider import MedRAXProvider
 from .openrouter_provider import OpenRouterProvider
+from .medgemma_provider import MedGemmaProvider
+
+# QwenProvider is optional - only import if dependencies are compatible
+try:
+    from .qwen_provider import QwenProvider
+    QWEN_AVAILABLE = True
+except ImportError as e:
+    QWEN_AVAILABLE = False
+    QwenProvider = None
+    print(f"QwenProvider not available: {e}")
+    print("To use Qwen models, upgrade transformers: pip install --upgrade git+https://github.com/huggingface/transformers")

 __all__ = [
     "LLMProvider",
@@ -14,4 +25,7 @@ __all__ = [
     "GoogleProvider",
     "MedRAXProvider",
     "OpenRouterProvider",
+    "MedGemmaProvider",
+    "QwenProvider",
+    "QWEN_AVAILABLE",
 ]
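Because `QwenProvider` is bound to `None` when the import fails, callers are presumably expected to gate on `QWEN_AVAILABLE` before instantiating it. A sketch, assuming the package layout above; the constructor arguments mirror the other providers and are not confirmed for Qwen:

    from benchmarking.llm_providers import QWEN_AVAILABLE, QwenProvider

    if QWEN_AVAILABLE:
        # Hypothetical arguments, following the (model_name, system_prompt) convention
        provider = QwenProvider("qwen2.5-vl", "MEDICAL_ASSISTANT")
    else:
        print("Qwen support disabled; choose another provider type")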
benchmarking/llm_providers/base.py
CHANGED
@@ -13,10 +13,6 @@ class LLMRequest:
     """Request to an LLM provider."""
     text: str
     images: Optional[List[str]] = None  # List of image paths
-    temperature: float = 0.7
-    top_p: float = 0.95
-    max_tokens: int = 5000
-    additional_params: Optional[Dict[str, Any]] = None


 @dataclass
@@ -44,15 +40,17 @@ class LLMProvider(ABC):
             **kwargs: Additional configuration parameters
         """
         self.model_name = model_name
-        self.
-        self.
+        self.temperature = kwargs.get("temperature", 0.7)
+        self.top_p = kwargs.get("top_p", 0.95)
+        self.max_tokens = kwargs.get("max_tokens", 5000)
+        self.prompt_name = system_prompt

         # Load system prompt content from file
         try:
-            prompts = load_prompts_from_file("
-            self.system_prompt = prompts.get(
+            prompts = load_prompts_from_file("benchmarking/system_prompts.txt")
+            self.system_prompt = prompts.get(self.prompt_name, None)
             if self.system_prompt is None:
-                print(f"Warning: System prompt '{system_prompt}' not found in
+                print(f"Warning: System prompt '{system_prompt}' not found in benchmarking/system_prompts.txt.")
         except Exception as e:
             print(f"Error loading system prompt: {e}")
             self.system_prompt = None
@@ -85,9 +83,7 @@ class LLMProvider(ABC):
         try:
             # Simple test request
             test_request = LLMRequest(
-                text="Hello! What model are you? Tell me your full specification."
-                temperature=0.5,
-                max_tokens=1000
+                text="Hello! What model are you? Tell me your full specification."
             )
             response = self.generate_response(test_request)
             return response.content is not None and len(response.content.strip()) > 0
@@ -95,6 +91,23 @@ class LLMProvider(ABC):
             print(f"Connection test failed: {e}")
             return False

+    def _validate_image_paths(self, image_paths: List[str]) -> List[str]:
+        """Validate that image paths exist and are readable.
+
+        Args:
+            image_paths (List[str]): List of image paths to validate
+
+        Returns:
+            List[str]: List of valid image paths
+        """
+        valid_paths = []
+        for path in image_paths:
+            if Path(path).exists() and Path(path).is_file():
+                valid_paths.append(path)
+            else:
+                print(f"Warning: Image path does not exist: {path}")
+        return valid_paths
+
     def _encode_image(self, image_path: str) -> str:
         """Encode image to base64 string.

@@ -110,23 +123,30 @@ class LLMProvider(ABC):
         except Exception as e:
             print(f"ERROR: _encode_image failed for {image_path} (type: {type(image_path)}): {e}")
             raise
-
-    def
-    """
+
+    def _get_image_mime_type(self, image_path: str) -> str:
+        """Detect the MIME type of an image file.

         Args:
-
+            image_path (str): Path to the image file

         Returns:
-
+            str: MIME type (e.g., 'image/png', 'image/jpeg')
         """
-
-
-
-
-
-
-
+        # Get file extension
+        ext = Path(image_path).suffix.lower()
+
+        # Map extensions to MIME types
+        mime_types = {
+            '.png': 'image/png',
+            '.jpg': 'image/jpeg',
+            '.jpeg': 'image/jpeg',
+            '.gif': 'image/gif',
+            '.webp': 'image/webp',
+            '.bmp': 'image/bmp',
+        }
+
+        return mime_types.get(ext, 'image/png')  # Default to PNG for medical images

     def __str__(self) -> str:
         """String representation of the provider."""
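Every provider diff below switches from a hardcoded data-URL prefix to `_get_image_mime_type`, so the base64 payload is labeled with the actual image format. A standalone sketch of the shared encode-plus-MIME pattern, with the extension map copied from the method above:

    import base64
    from pathlib import Path

    # Mapping copied from _get_image_mime_type; the function name and
    # example file are illustrative, not part of the commit.
    MIME_TYPES = {
        ".png": "image/png", ".jpg": "image/jpeg", ".jpeg": "image/jpeg",
        ".gif": "image/gif", ".webp": "image/webp", ".bmp": "image/bmp",
    }

    def image_data_url(image_path: str) -> str:
        ext = Path(image_path).suffix.lower()
        mime_type = MIME_TYPES.get(ext, "image/png")  # default to PNG for medical images
        b64 = base64.b64encode(Path(image_path).read_bytes()).decode("utf-8")
        return f"data:{mime_type};base64,{b64}"

    # e.g. image_data_url("chest_xray.png") -> "data:image/png;base64,iVBOR..."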
benchmarking/llm_providers/google_provider.py
CHANGED
@@ -14,6 +14,10 @@ class GoogleProvider(LLMProvider):

     def _setup(self) -> None:
         """Set up Google langchain client."""
+        # Set provider name
+        self.provider_name = "google"
+
+        # Get API key from environment variable
         api_key = os.getenv("GOOGLE_API_KEY")
         if not api_key:
             raise ValueError("GOOGLE_API_KEY environment variable is required")
@@ -21,7 +25,10 @@ class GoogleProvider(LLMProvider):
         # Create ChatGoogleGenerativeAI instance
         self.client = ChatGoogleGenerativeAI(
             model=self.model_name,
-            google_api_key=api_key
+            google_api_key=api_key,
+            temperature=self.temperature,
+            max_output_tokens=self.max_tokens,
+            top_p=self.top_p
         )

     @retry(wait=wait_exponential(multiplier=1, min=4, max=10), stop=stop_after_attempt(3))
@@ -54,9 +61,10 @@ class GoogleProvider(LLMProvider):
             try:
                 # For langchain Google, pass image data as base64
                 image_b64 = self._encode_image(image_path)
+                mime_type = self._get_image_mime_type(image_path)
                 content_parts.append({
                     "type": "image_url",
-                    "image_url": f"data:
+                    "image_url": f"data:{mime_type};base64,{image_b64}"
                 })
             except Exception as e:
                 print(f"Error reading image {image_path}: {e}")
@@ -68,18 +76,13 @@ class GoogleProvider(LLMProvider):

         # Make API call using langchain
         try:
-            #
-            self.client.temperature = request.temperature
-            self.client.max_output_tokens = request.max_tokens
-            self.client.top_p = request.top_p
-
+            # Make API call
             response = self.client.invoke(messages)
+            content = response.content if response.content else ""

+            # Calculate duration
             duration = time.time() - start_time

-            # Extract response content
-            content = response.content if response.content else ""
-
             # Get usage information if available
             usage = {}
             if hasattr(response, 'usage_metadata') and response.usage_metadata:
benchmarking/llm_providers/medgemma_provider.py
ADDED
@@ -0,0 +1,222 @@
"""MedGemma LLM provider implementation using the MedGemma FastAPI service."""

import os
import time
import httpx
from typing import Optional
from pathlib import Path
from tenacity import retry, wait_exponential, stop_after_attempt

from .base import LLMProvider, LLMRequest, LLMResponse


class MedGemmaProvider(LLMProvider):
    """MedGemma LLM provider that communicates with the MedGemma FastAPI service.

    This provider wraps Google's MedGemma 4B model as an LLMProvider for benchmarking.
    It communicates with a running MedGemma FastAPI service on localhost:8002.

    MedGemma is a specialized multimodal AI model trained on medical images and text.
    It provides expert-level analysis for chest X-rays, dermatology images,
    ophthalmology images, and histopathology slides.

    Requirements:
        - MedGemma FastAPI service must be running on the configured API URL
        - Default URL: http://localhost:8002
        - Can be overridden via MEDGEMMA_API_URL environment variable
    """

    def __init__(self, model_name: str, system_prompt: str, **kwargs):
        """Initialize MedGemma provider.

        Args:
            model_name (str): Model name (for consistency with other providers)
            system_prompt (str): System prompt identifier to load from file
            **kwargs: Additional configuration parameters
                - api_url: URL of the MedGemma FastAPI service
                - max_new_tokens: Maximum tokens to generate (default: 300)
        """
        # Extract MedGemma-specific config before calling super().__init__
        self.api_url = kwargs.pop('api_url', None) or os.getenv('MEDGEMMA_API_URL', 'http://localhost:8002')
        self.max_new_tokens = kwargs.pop('max_new_tokens', 300)
        self.client = None

        # Call parent constructor
        super().__init__(model_name, system_prompt, **kwargs)

    def _setup(self) -> None:
        """Set up httpx client for communicating with MedGemma API."""
        # Create httpx client with reasonable timeouts
        timeout_config = httpx.Timeout(
            timeout=300.0,  # 5 minutes for inference
            connect=10.0    # 10 seconds to establish connection
        )
        self.client = httpx.Client(timeout=timeout_config)

        # Test connection to MedGemma service
        try:
            response = self.client.get(f"{self.api_url}/docs")
            if response.status_code != 200:
                print(f"Warning: MedGemma API at {self.api_url} may not be running (status: {response.status_code})")
        except httpx.ConnectError:
            print(f"Warning: Could not connect to MedGemma API at {self.api_url}")
            print("Please ensure the MedGemma FastAPI service is running:")
            print(f"  python medrax/tools/vqa/medgemma/medgemma.py")

    @retry(wait=wait_exponential(multiplier=1, min=4, max=10), stop=stop_after_attempt(3))
    def generate_response(self, request: LLMRequest) -> LLMResponse:
        """Generate response using MedGemma API.

        Args:
            request (LLMRequest): The request containing text, images, and parameters

        Returns:
            LLMResponse: The response from MedGemma
        """
        start_time = time.time()

        if self.client is None:
            return LLMResponse(
                content="Error: MedGemma client not initialized",
                duration=time.time() - start_time
            )

        try:
            # Validate and prepare images
            if not request.images:
                return LLMResponse(
                    content="Error: MedGemma requires at least one image",
                    duration=time.time() - start_time
                )

            valid_images = self._validate_image_paths(request.images)
            if not valid_images:
                return LLMResponse(
                    content="Error: No valid image paths provided",
                    duration=time.time() - start_time
                )

            # Prepare multipart form data
            files_to_send = []
            for image_path in valid_images:
                try:
                    # Detect correct MIME type based on file extension
                    ext = Path(image_path).suffix.lower()
                    mime_type = "image/png" if ext == ".png" else "image/jpeg"

                    # Read image file
                    with open(image_path, "rb") as f:
                        image_data = f.read()

                    # Add to files list
                    files_to_send.append(
                        ("images", (os.path.basename(image_path), image_data, mime_type))
                    )
                except Exception as e:
                    print(f"Error reading image {image_path}: {e}")
                    continue

            if not files_to_send:
                return LLMResponse(
                    content="Error: Failed to read any image files",
                    duration=time.time() - start_time
                )

            # Prepare form data
            # Use system_prompt if provided, otherwise use default
            system_prompt_text = self.system_prompt if self.system_prompt else "You are an expert radiologist."

            # Override max_new_tokens if provided in request
            max_tokens = getattr(request, 'max_tokens', self.max_new_tokens)

            data = {
                "prompt": request.text,
                "system_prompt": system_prompt_text,
                "max_new_tokens": max_tokens,
            }

            # Make API request
            response = self.client.post(
                f"{self.api_url}/analyze-images/",
                data=data,
                files=files_to_send,
            )

            # Check for errors
            response.raise_for_status()

            # Parse response
            response_data = response.json()
            content = response_data.get("response", "")
            metadata = response_data.get("metadata", {})

            duration = time.time() - start_time

            # MedGemma doesn't provide token usage, but we can include request info
            usage = {
                "num_images": len(valid_images),
                "max_new_tokens": max_tokens,
            }

            return LLMResponse(
                content=content,
                usage=usage,
                duration=duration
            )

        except httpx.TimeoutException as e:
            duration = time.time() - start_time
            error_msg = f"MedGemma API request timed out after {duration:.1f}s. The server might be overloaded or the model is taking too long to process."
            print(f"Error: {error_msg}")
            return LLMResponse(
                content=f"Error: {error_msg}",
                duration=duration
            )

        except httpx.ConnectError as e:
            duration = time.time() - start_time
            error_msg = f"Could not connect to MedGemma API at {self.api_url}. Please ensure the service is running."
            print(f"Error: {error_msg}")
            return LLMResponse(
                content=f"Error: {error_msg}",
                duration=duration
            )

        except httpx.HTTPStatusError as e:
            duration = time.time() - start_time
            error_msg = f"MedGemma API returned error {e.response.status_code}: {e.response.text}"
            print(f"Error: {error_msg}")
            return LLMResponse(
                content=f"Error: {error_msg}",
                duration=duration
            )

        except Exception as e:
            duration = time.time() - start_time
            error_msg = f"Unexpected error calling MedGemma API: {str(e)}"
            print(f"Error: {error_msg}")
            return LLMResponse(
                content=f"Error: {error_msg}",
                duration=duration
            )

    def test_connection(self) -> bool:
        """Test the connection to the MedGemma API service.

        Returns:
            bool: True if connection is successful and service is responding
        """
        try:
            # Try to access the API docs endpoint
            response = self.client.get(f"{self.api_url}/docs")
            return response.status_code == 200
        except Exception as e:
            print(f"MedGemma connection test failed: {e}")
            return False

    def __del__(self):
        """Clean up httpx client on deletion."""
        if self.client is not None:
            self.client.close()
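The provider is essentially a thin multipart HTTP client around one endpoint. For debugging, the same endpoint can be exercised directly with httpx; this sketch assumes the service is running at the default URL and that a local xray.png exists (both assumptions, not part of the commit):

    import httpx

    # Mirrors the form fields the provider sends to /analyze-images/ above.
    with open("xray.png", "rb") as f:  # hypothetical local test image
        files = [("images", ("xray.png", f.read(), "image/png"))]

    resp = httpx.post(
        "http://localhost:8002/analyze-images/",
        data={
            "prompt": "Is there evidence of pneumothorax?",
            "system_prompt": "You are an expert radiologist.",
            "max_new_tokens": 300,
        },
        files=files,
        timeout=300.0,  # inference can be slow, matching the provider's timeout
    )
    resp.raise_for_status()
    print(resp.json().get("response", ""))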
benchmarking/llm_providers/medrax_provider.py
CHANGED
@@ -21,7 +21,9 @@ class MedRAXProvider(LLMProvider):
             system_prompt (str): System prompt to use
             **kwargs: Additional configuration parameters
         """
-
+        # Set provider name
+        self.provider_name = "medrax"
+
         self.agent = None
         self.tools_dict = None

@@ -33,15 +35,15 @@ class MedRAXProvider(LLMProvider):
         print("Starting server...")

         selected_tools = [
-            "TorchXRayVisionClassifierTool",  # For classifying chest X-ray images using TorchXRayVision
-            "ArcPlusClassifierTool",  # For advanced chest X-ray classification using ArcPlus
-            "ChestXRayReportGeneratorTool",  # For generating medical reports from X-rays
-            "XRayPhraseGroundingTool",  # For locating described features in X-rays
-            "MedGemmaVQATool",  # Google MedGemma VQA tool
-            "XRayVQATool",  # For visual question answering on X-rays
-            "MedicalRAGTool",  # For retrieval-augmented generation with medical knowledge
-            "WebBrowserTool",  # For web browsing and search capabilities
-            "DuckDuckGoSearchTool",  # For privacy-focused web search using DuckDuckGo
+            # "TorchXRayVisionClassifierTool",  # For classifying chest X-ray images using TorchXRayVision
+            # "ArcPlusClassifierTool",  # For advanced chest X-ray classification using ArcPlus
+            # "ChestXRayReportGeneratorTool",  # For generating medical reports from X-rays
+            # "XRayPhraseGroundingTool",  # For locating described features in X-rays
+            # "MedGemmaVQATool",  # Google MedGemma VQA tool
+            # "XRayVQATool",  # For visual question answering on X-rays
+            # "MedicalRAGTool",  # For retrieval-augmented generation with medical knowledge
+            # "WebBrowserTool",  # For web browsing and search capabilities
+            # "DuckDuckGoSearchTool",  # For privacy-focused web search using DuckDuckGo
         ]

         rag_config = RAGConfig(
@@ -62,14 +64,15 @@ class MedRAXProvider(LLMProvider):
         model_kwargs = {}

         agent, tools_dict = initialize_agent(
-            prompt_file="
+            prompt_file="benchmarking/system_prompts.txt",
             tools_to_use=selected_tools,
-            model_dir="/
+            model_dir="/home/lijunzh3/scratch/MedRAX2/model-weights",
             temp_dir="temp",  # Change this to the path of the temporary directory
             device="cuda:0",
             model=self.model_name,  # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
-            temperature=
-            top_p=
+            temperature=self.temperature,
+            top_p=self.top_p,
+            max_tokens=self.max_tokens,
             model_kwargs=model_kwargs,
             rag_config=rag_config,
             system_prompt=self.prompt_name,
@@ -107,32 +110,34 @@ class MedRAXProvider(LLMProvider):
         thread_id = str(int(time.time() * 1000))  # Unique thread ID

         if request.images:
+            # Build multimodal content with text and images
+            content = [{"type": "text", "text": request.text}]
+
+            # Validate image paths
             valid_images = self._validate_image_paths(request.images)
             print(f"Processing {len(valid_images)} images")
+
+            # Add image paths for tools
+            for image_path in valid_images:
+                content.append({"type": "text", "text": f"image_path: {image_path}"})
+
+            # Add image content for multimodal LLM
+            for image_path in valid_images:
                 try:
+                    img_base64 = self._encode_image(image_path)
+                    mime_type = self._get_image_mime_type(image_path)
+                    content.append({
                         "type": "image_url",
-                    "image_url": {"url": f"data:
-                    }
+                        "image_url": {"url": f"data:{mime_type};base64,{img_base64}"}
+                    })
                 except Exception as e:
                     print(f"ERROR: Image encoding failed for {image_path}: {e}")
                     raise
-
-            messages.append(HumanMessage(content=[{
-                "type": "text",
-                "text": request.text
-            }]))
+
+            # Create single multimodal message
+            messages.append(HumanMessage(content=content))
+
         else:
             # If no images, add text as simple string
             messages.append(HumanMessage(content=request.text))
@@ -216,8 +221,67 @@ class MedRAXProvider(LLMProvider):
                     "type": type(msg).__name__,
                     "content": str(msg.content) if hasattr(msg, 'content') else str(msg)
                 }
+
+                # Extract response metadata (reasoning/thinking traces)
+                if hasattr(msg, 'response_metadata') and msg.response_metadata:
+                    try:
+                        msg_info["response_metadata"] = dict(msg.response_metadata)
+
+                        # Extract specific reasoning fields for easier access
+                        # Gemini 2.0 Flash Thinking uses 'thoughts'
+                        if "thoughts" in msg.response_metadata:
+                            msg_info["thinking"] = msg.response_metadata["thoughts"]
+
+                        # DeepSeek-R1 and similar models use 'reasoning_content'
+                        if "reasoning_content" in msg.response_metadata:
+                            msg_info["reasoning"] = msg.response_metadata["reasoning_content"]
+
+                        # Some models expose thinking in other fields
+                        if "extended_thinking" in msg.response_metadata:
+                            msg_info["extended_thinking"] = msg.response_metadata["extended_thinking"]
+                    except Exception as e:
+                        print(f"Warning: Could not serialize response_metadata: {e}")
+
+                # Extract usage metadata (reasoning tokens for o1/o3 models)
+                if hasattr(msg, 'usage_metadata') and msg.usage_metadata:
+                    try:
+                        msg_info["usage_metadata"] = dict(msg.usage_metadata)
+
+                        # Highlight reasoning tokens if present
+                        if isinstance(msg.usage_metadata, dict) and "reasoning_tokens" in msg.usage_metadata:
+                            msg_info["reasoning_tokens"] = msg.usage_metadata["reasoning_tokens"]
+                    except Exception as e:
+                        print(f"Warning: Could not serialize usage_metadata: {e}")
+
+                # Extract additional kwargs (some models put reasoning here)
+                if hasattr(msg, 'additional_kwargs') and msg.additional_kwargs:
+                    try:
+                        # Filter for reasoning-related fields
+                        reasoning_kwargs = {}
+                        for key in ['thinking', 'reasoning', 'thoughts', 'chain_of_thought']:
+                            if key in msg.additional_kwargs:
+                                reasoning_kwargs[key] = msg.additional_kwargs[key]
+
+                        if reasoning_kwargs:
+                            msg_info["additional_reasoning"] = reasoning_kwargs
+
+                        # Include full additional_kwargs for completeness (may contain other useful info)
+                        msg_info["additional_kwargs"] = dict(msg.additional_kwargs)
+                    except Exception as e:
+                        print(f"Warning: Could not serialize additional_kwargs: {e}")
+
                 chunk_messages.append(msg_info)
-
+
+                # Enhanced logging for debugging
+                log_msg = f"Message in chunk: type={msg_info['type']}"
+                if "thinking" in msg_info:
+                    log_msg += f", has_thinking=True (length={len(str(msg_info['thinking']))})"
+                if "reasoning" in msg_info:
+                    log_msg += f", has_reasoning=True (length={len(str(msg_info['reasoning']))})"
+                if "reasoning_tokens" in msg_info:
+                    log_msg += f", reasoning_tokens={msg_info['reasoning_tokens']}"
+                print(log_msg)
+
             serializable_chunk["messages"] = chunk_messages

         return serializable_chunk
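The new serialization block normalizes reasoning traces that different backends expose in different places. A self-contained sketch of that normalization using a stand-in message object (field names taken from the diff; the values are invented for illustration):

    # FakeMsg stands in for a LangChain AIMessage; only attribute access is used.
    class FakeMsg:
        response_metadata = {"reasoning_content": "step 1 ... step 2 ..."}
        usage_metadata = {"reasoning_tokens": 512}
        additional_kwargs = {}

    msg = FakeMsg()
    msg_info = {}

    # Different models put traces in different fields; collect whichever exist.
    if msg.response_metadata:
        if "thoughts" in msg.response_metadata:            # Gemini thinking models
            msg_info["thinking"] = msg.response_metadata["thoughts"]
        if "reasoning_content" in msg.response_metadata:   # DeepSeek-R1-style models
            msg_info["reasoning"] = msg.response_metadata["reasoning_content"]
    if msg.usage_metadata and "reasoning_tokens" in msg.usage_metadata:
        msg_info["reasoning_tokens"] = msg.usage_metadata["reasoning_tokens"]

    print(msg_info)  # {'reasoning': 'step 1 ... step 2 ...', 'reasoning_tokens': 512}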
benchmarking/llm_providers/openai_provider.py
CHANGED
@@ -14,21 +14,28 @@ class OpenAIProvider(LLMProvider):

     def _setup(self) -> None:
         """Set up OpenAI langchain client."""
+        # Set provider name
+        self.provider_name = "openai"
+
+        # Get API key and base URL from environment variables
         api_key = os.getenv("OPENAI_API_KEY")
-        if not api_key:
-            raise ValueError("OPENAI_API_KEY environment variable is required")
-
         base_url = os.getenv("OPENAI_BASE_URL")
+        if not api_key or not base_url:
+            raise ValueError("OPENAI_API_KEY and OPENAI_BASE_URL environment variables are required")

-        #
+        # Construct kwargs for ChatOpenAI instance
         kwargs = {
             "model": self.model_name,
             "api_key": api_key,
+            "temperature": self.temperature,
+            "max_tokens": self.max_tokens
         }
-
         if base_url:
             kwargs["base_url"] = base_url
-
+        if self.model_name.startswith("gpt-5") or self.model_name.startswith("o1") or self.model_name.startswith("o3"):
+            kwargs["reasoning_effort"] = "high"
+
+        # Create ChatOpenAI instance
         self.client = ChatOpenAI(**kwargs)

     @retry(wait=wait_exponential(multiplier=1, min=4, max=10), stop=stop_after_attempt(3))
@@ -63,10 +70,11 @@ class OpenAIProvider(LLMProvider):
             for image_path in valid_images:
                 try:
                     image_b64 = self._encode_image(image_path)
+                    mime_type = self._get_image_mime_type(image_path)
                     user_content.append({
                         "type": "image_url",
                         "image_url": {
-                            "url": f"data:
+                            "url": f"data:{mime_type};base64,{image_b64}",
                             "detail": "high"
                         }
                     })
@@ -75,13 +83,8 @@ class OpenAIProvider(LLMProvider):

         messages.append(HumanMessage(content=user_content))

-        # Make API call
+        # Make API call
         try:
-            # Update client parameters for this request
-            self.client.temperature = request.temperature
-            self.client.max_tokens = request.max_tokens
-            self.client.top_p = request.top_p
-
             response = self.client.invoke(messages)

             duration = time.time() - start_time
benchmarking/llm_providers/openrouter_provider.py
CHANGED
@@ -13,11 +13,16 @@ class OpenRouterProvider(LLMProvider):

     def _setup(self) -> None:
         """Set up OpenRouter client models."""
+        # Set provider name
+        self.provider_name = "openrouter"
+
+        # Get API key and base URL from environment variables
         api_key = os.getenv("OPENROUTER_API_KEY")
-        if not api_key:
-            raise ValueError("OPENROUTER_API_KEY environment variable is required for xAI Grok via OpenRouter.")
         base_url = os.getenv("OPENROUTER_BASE_URL", "https://openrouter.ai/api/v1")
-
+        if not api_key or not base_url:
+            raise ValueError("OPENROUTER_API_KEY and OPENROUTER_BASE_URL environment variables are required")
+
+        # Create OpenAI client with OpenRouter endpoint
         self.client = OpenAI(api_key=api_key, base_url=base_url)

     @retry(wait=wait_exponential(multiplier=1, min=4, max=10), stop=stop_after_attempt(3))
@@ -45,10 +50,11 @@ class OpenRouterProvider(LLMProvider):
             for image_path in valid_images:
                 try:
                     image_b64 = self._encode_image(image_path)
+                    mime_type = self._get_image_mime_type(image_path)
                     user_content.append({
                         "type": "image_url",
                         "image_url": {
-                            "url": f"data:
+                            "url": f"data:{mime_type};base64,{image_b64}",
                             "detail": "high"
                         }
                     })
@@ -57,14 +63,14 @@ class OpenRouterProvider(LLMProvider):

         messages.append({"role": "user", "content": user_content})

+        # Make API call
         try:
             response = self.client.chat.completions.create(
                 model=self.model_name,
                 messages=messages,
-                temperature=
-
-
-                **(request.additional_params or {})
+                temperature=self.temperature,
+                max_tokens=self.max_tokens,
+                top_p=self.top_p
             )
             duration = time.time() - start_time
             content = response.choices[0].message.content if response.choices else ""
benchmarking/runner.py
CHANGED
|
@@ -32,16 +32,17 @@ class BenchmarkResult:
|
|
| 32 |
@dataclass
|
| 33 |
class BenchmarkRunConfig:
|
| 34 |
"""Configuration for a benchmark run."""
|
|
|
|
| 35 |
provider_name: str
|
| 36 |
model_name: str
|
| 37 |
-
benchmark_name: str
|
| 38 |
output_dir: str
|
| 39 |
max_questions: Optional[int] = None
|
| 40 |
temperature: float = 0.7
|
| 41 |
top_p: float = 0.95
|
| 42 |
max_tokens: int = 5000
|
| 43 |
-
additional_params: Optional[Dict[str, Any]] = None
|
| 44 |
concurrency: int = 1
|
|
|
|
|
|
|
| 45 |
|
| 46 |
|
| 47 |
class BenchmarkRunner:
|
|
@@ -59,11 +60,10 @@ class BenchmarkRunner:
|
|
| 59 |
self.output_dir.mkdir(parents=True, exist_ok=True)
|
| 60 |
|
| 61 |
# Generate unique run ID
|
| 62 |
-
self.run_id = f"{config.benchmark_name}_{config.provider_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
| 63 |
|
| 64 |
# Set up logging
|
| 65 |
self._setup_logging()
|
| 66 |
-
|
| 67 |
self.logger.info(f"Initialized benchmark runner with ID: {self.run_id}")
|
| 68 |
|
| 69 |
def _setup_logging(self) -> None:
|
|
@@ -91,34 +91,28 @@ class BenchmarkRunner:
|
|
| 91 |
|
| 92 |
def run_benchmark(
|
| 93 |
self,
|
| 94 |
-
llm_provider: LLMProvider,
|
| 95 |
benchmark: Benchmark,
|
|
|
|
| 96 |
) -> Dict[str, Any]:
|
| 97 |
"""Run a benchmark against an LLM provider.
|
| 98 |
|
| 99 |
Args:
|
| 100 |
-
llm_provider (LLMProvider): The LLM provider to test
|
| 101 |
benchmark (Benchmark): The benchmark to run
|
|
|
|
| 102 |
|
| 103 |
Returns:
|
| 104 |
Dict[str, Any]: Summary of benchmark results
|
| 105 |
"""
|
| 106 |
self.logger.info(f"Starting benchmark run: {self.run_id}")
|
| 107 |
-
self.logger.info(f"Model: {llm_provider.model_name}")
|
| 108 |
self.logger.info(f"Benchmark: {benchmark}")
|
|
|
|
|
|
|
| 109 |
|
| 110 |
# Test provider connection
|
| 111 |
if not llm_provider.test_connection():
|
| 112 |
self.logger.error("LLM provider connection test failed")
|
| 113 |
return {"error": "LLM provider connection test failed"}
|
| 114 |
|
| 115 |
-
# Get data points to process
|
| 116 |
-
total_questions = len(benchmark)
|
| 117 |
-
max_questions = self.config.max_questions or total_questions
|
| 118 |
-
end_index = min(max_questions, total_questions)
|
| 119 |
-
|
| 120 |
-
self.logger.info(f"Processing questions {0} to {end_index-1} of {total_questions}")
|
| 121 |
-
|
| 122 |
# Initialize counters
|
| 123 |
processed = 0
|
| 124 |
correct = 0
|
|
@@ -127,29 +121,10 @@ class BenchmarkRunner:
         # Determine concurrency
         max_workers = max(1, int(getattr(self.config, "concurrency", 1) or 1))
 
-        # Prefetch data points to avoid potential thread-safety issues inside benchmark access
-        data_points = []
-        for i in range(0, end_index):
-            try:
-                data_points.append(benchmark.get_data_point(i))
-            except Exception as e:
-                self.logger.error(f"Error fetching data point {i}: {e}")
-                error_result = BenchmarkResult(
-                    data_point_id=f"error_{i}",
-                    question="",
-                    model_answer="",
-                    correct_answer="",
-                    is_correct=False,
-                    duration=0.0,
-                    error=str(e)
-                )
-                self.results.append(error_result)
-                self._save_individual_result(error_result)
-
         # Process data points in parallel using a bounded thread pool
-        with tqdm(total=
+        with tqdm(total=len(benchmark), desc="Processing questions") as pbar:
             with ThreadPoolExecutor(max_workers=max_workers) as executor:
-                future_to_index = {executor.submit(self._process_data_point,
+                future_to_index = {executor.submit(self._process_data_point, dp, llm_provider): idx for idx, dp in enumerate(benchmark)}
                 for future in as_completed(future_to_index):
                     idx = future_to_index[future]
                     try:
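The replacement loop drives a bounded thread pool straight off the benchmark iterator instead of prefetching into a list. The general pattern, reduced to a self-contained sketch (process and the item list are stand-ins for the LLM call and the data points):

    from concurrent.futures import ThreadPoolExecutor, as_completed
    from tqdm import tqdm

    def process(item):
        return item * 2  # stand-in for one LLM call

    items = list(range(10))  # stand-in for benchmark data points
    with tqdm(total=len(items), desc="Processing questions") as pbar:
        with ThreadPoolExecutor(max_workers=4) as executor:
            futures = {executor.submit(process, it): i for i, it in enumerate(items)}
            for future in as_completed(futures):
                result = future.result()  # raises if the worker raised
                pbar.update(1)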
@@ -184,30 +159,29 @@ class BenchmarkRunner:
             accuracy = (correct / processed) * 100
             avg_duration = total_duration / processed if processed > 0 else 0.0
             self.logger.info(
-                f"Progress: {processed}/{
+                f"Progress: {processed}/{len(benchmark)} | "
                 f"Accuracy: {accuracy:.2f}% | "
                 f"Avg Duration: {avg_duration:.2f}s"
             )
 
         # Save final results
         summary = self._save_final_results(benchmark)
 
         self.logger.info(f"Benchmark run completed: {self.run_id}")
-        self.logger.info(f"
-
+        self.logger.info(f"Summary: {summary}")
 
         return summary
 
     def _process_data_point(
         self,
-        llm_provider: LLMProvider,
         data_point: BenchmarkDataPoint,
+        llm_provider: LLMProvider
     ) -> BenchmarkResult:
         """Process a single data point.
 
         Args:
-            llm_provider (LLMProvider): The LLM provider to use
             data_point (BenchmarkDataPoint): The data point to process
+            llm_provider (LLMProvider): The LLM provider to use
 
         Returns:
             BenchmarkResult: Result of processing the data point
@@ -215,14 +189,10 @@ class BenchmarkRunner:
         start_time = time.time()
 
         try:
-            # Create request
+            # Create request for LLM
             request = LLMRequest(
                 text=data_point.text,
-                images=data_point.images,
-                temperature=self.config.temperature,
-                top_p=self.config.top_p,
-                max_tokens=self.config.max_tokens,
-                additional_params=self.config.additional_params
+                images=data_point.images
             )
 
             # Get response from LLM
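Sampling parameters no longer travel with each request; they are assumed to live on the provider itself, as in the provider hunks earlier in this commit, so a request reduces to the question and its images (values illustrative):

    request = LLMRequest(
        text="Which lobe shows the opacity?",
        images=["/path/to/image.png"],
    )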
@@ -232,10 +202,12 @@ class BenchmarkRunner:
             model_answer = self._extract_answer(response.content)
 
             # Check if correct
-            is_correct =
+            is_correct = model_answer == data_point.correct_answer
 
+            # Calculate duration
             duration = time.time() - start_time
 
+            # Return result
             return BenchmarkResult(
                 data_point_id=data_point.id,
                 question=data_point.text,
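Scoring is now a strict string comparison, so _extract_answer must return exactly the gold answer. Under the removed first-letter heuristic "A." still counted as "A"; here it would not:

    assert "A" == "A"    # scored correct
    assert "A." != "A"   # now scored incorrect unless extraction strips the dot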
@@ -247,8 +219,6 @@ class BenchmarkRunner:
                 chunk_history=response.chunk_history,
                 metadata={
                     "data_point_metadata": data_point.metadata,
-                    "case_id": data_point.case_id,
-                    "category": data_point.category,
                     "raw_response": response.content,
                 }
             )
@@ -265,9 +235,7 @@ class BenchmarkRunner:
                 error=str(e),
                 chunk_history=None,
                 metadata={
-                    "data_point_metadata": data_point.metadata,
-                    "case_id": data_point.case_id,
-                    "category": data_point.category,
+                    "data_point_metadata": data_point.metadata
                 }
             )
@@ -289,29 +257,6 @@ class BenchmarkRunner:
         # If no pattern matches, return the full response
         return response_text.strip()
 
-    def _is_correct_answer(self, model_answer: str, correct_answer: str) -> bool:
-        """Check if the model answer is correct.
-
-        Args:
-            model_answer (str): The model's answer
-            correct_answer (str): The correct answer
-
-        Returns:
-            bool: True if the answer is correct
-        """
-        if not model_answer or not correct_answer:
-            return False
-
-        # For multiple choice, compare just the letter
-        model_clean = model_answer.strip().upper()
-        correct_clean = correct_answer.strip().upper()
-
-        # Extract just the first letter for comparison
-        model_letter = model_clean[0] if model_clean else ""
-        correct_letter = correct_clean[0] if correct_clean else ""
-
-        return model_letter == correct_letter
-
     def _save_individual_result(self, result: BenchmarkResult) -> None:
         """Save a single result to its own JSON file.
@@ -321,12 +266,14 @@ class BenchmarkRunner:
         # Sanitize data_point_id for filename (remove invalid characters)
         safe_id = re.sub(r'[^\w\-_.]', '_', result.data_point_id)
 
+        # Create run_id directory and individual_results subdirectory
+        run_dir = self.output_dir / self.run_id
+        individual_results_dir = run_dir / "individual_results"
+        individual_results_dir.mkdir(parents=True, exist_ok=True)
+
         # Create filename with benchmark name and data point ID
         filename = f"{self.config.benchmark_name}_{safe_id}.json"
-        result_file =
-
-        # Create individual_results directory if it doesn't exist
-        result_file.parent.mkdir(exist_ok=True)
+        result_file = individual_results_dir / filename
 
         # Convert result to serializable format
         result_data = {
@@ -341,7 +288,7 @@ class BenchmarkRunner:
             "usage": result.usage,
             "error": result.error,
             "chunk_history": result.chunk_history,
-            "metadata": result.metadata
+            "metadata": result.metadata,
         }
 
         # Save to file
@@ -357,8 +304,13 @@ class BenchmarkRunner:
         Returns:
             Dict[str, Any]: Summary of results
         """
+        # Create run_id directory and final_results subdirectory
+        run_dir = self.output_dir / self.run_id
+        final_results_dir = run_dir / "final_results"
+        final_results_dir.mkdir(parents=True, exist_ok=True)
+
         # Save detailed results
-        results_file =
+        results_file = final_results_dir / f"{self.run_id}_results.json"
 
         # Convert results to serializable format for final file
         results_data = []
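Together with the individual_results hunk above, each run now gets its own directory tree; with a hypothetical run ID, the layout comes out as:

    results/
    └── rexvqa_openai_gpt-4o_100_20250101_120000/
        ├── individual_results/
        │   └── rexvqa_<data_point_id>.json
        └── final_results/
            ├── <run_id>_results.json
            └── <run_id>_summary.json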
@@ -385,29 +337,14 @@ class BenchmarkRunner:
 
         accuracy = (correct_answers / total_questions) * 100 if total_questions > 0 else 0
 
-        # Calculate per-category accuracy
-        category_stats = {}
-        for result in self.results:
-            if result.metadata and result.metadata.get("category"):
-                category = result.metadata["category"]
-                if category not in category_stats:
-                    category_stats[category] = {"correct": 0, "total": 0}
-                category_stats[category]["total"] += 1
-                if result.is_correct:
-                    category_stats[category]["correct"] += 1
-
-        # Calculate accuracy for each category
-        category_accuracies = {}
-        for category, stats in category_stats.items():
-            category_accuracies[category] = (stats["correct"] / stats["total"]) * 100
-
         # Create summary
         summary = {
             "run_id": self.run_id,
             "timestamp": datetime.now().isoformat(),
             "config": {
-                "model_name": self.config.model_name,
                 "benchmark_name": self.config.benchmark_name,
+                "provider_name": self.config.provider_name,
+                "model_name": self.config.model_name,
                 "temperature": self.config.temperature,
                 "top_p": self.config.top_p,
                 "max_tokens": self.config.max_tokens,
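Since the summary now records the provider alongside the model, downstream tooling can read both back from the saved file; a sketch with a hypothetical run_id, touching only keys visible in this diff:

    import json
    from pathlib import Path

    run_id = "rexvqa_openai_gpt-4o_100_20250101_120000"  # hypothetical
    summary_path = Path("results") / run_id / "final_results" / f"{run_id}_summary.json"
    summary = json.loads(summary_path.read_text())
    print(summary["config"]["provider_name"], summary["config"]["model_name"])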
@@ -422,13 +359,12 @@ class BenchmarkRunner:
                 "total_questions": total_questions,
                 "total_duration": total_duration,
                 "avg_duration_per_question": total_duration / total_questions if total_questions > 0 else 0,
-                "category_accuracies": category_accuracies,
             },
             "results_file": str(results_file),
         }
 
         # Save summary
-        summary_file =
+        summary_file = final_results_dir / f"{self.run_id}_summary.json"
         with open(summary_file, 'w') as f:
             json.dump(summary, f, indent=2)
benchmarking/system_prompts.txt
ADDED
@@ -0,0 +1,36 @@
+[MEDICAL_ASSISTANT]
+You are an expert medical AI assistant who can answer any medical questions and analyze medical images similar to a doctor.
+Solve using your own vision and reasoning and use tools to complement your reasoning.
+You can make multiple tool calls in parallel or in sequence as needed for comprehensive answers.
+Think critically about and criticize the tool outputs.
+If you need to look up some information before asking a follow up question, you are allowed to do that.
+
+CITATION REQUIREMENTS:
+- When referencing information from RAG and/or web search tools, ALWAYS include numbered citations [1], [2], [3], etc.
+- Use citations immediately after making claims or statements based on the above tool results.
+- Be consistent with citation numbering throughout your response.
+- Only cite sources that actually contain the information you're referencing.
+
+Examples:
+- "According to recent research [1], chest X-rays can show signs of pneumonia..."
+- "The medical literature indicates [2] that this condition typically presents with..."
+- "Based on clinical guidelines [3], the recommended treatment approach is..."
+
+[CHESTAGENTBENCH_PROMPT]
+You are a highly skilled radiology AI agent, an expert in interpreting medical images, specifically chest X-rays, CT scans, and MRIs, with world-class accuracy and precision.
+Your primary function is to assist in the analysis of these images and answer diagnostic questions.
+
+Your task is to provide a step-by-step, structured analysis. First, carefully examine the provided image and describe all relevant findings in a clear, concise manner.
+Next, use your expert medical knowledge to form a differential diagnosis based on these findings. Finally, critically evaluate the provided question and all possible choices.
+
+You have access to a suite of powerful tools to aid in your analysis. Use these tools as needed to retrieve external medical knowledge, access patient history, or perform specific image processing tasks.
+You should always scrutinize the output from your tools and integrate it into your reasoning. If tool outputs conflict with your initial assessment, explain the discrepancy and justify your final conclusion.
+You must take care to pass in the image paths exactly or else the tools will not work. Do not mangle up the image paths.
+
+Your final response for a multiple-choice question must strictly follow this format, including your step-by-step reasoning:
+1. **Image Analysis:** [Describe image findings here]
+2. **Differential Diagnosis:** [List possible diagnoses and their justifications]
+3. **Critical Thinking & Tool Use:** [Show your reasoning, including how you used tools and evaluated their output]
+4. **Final Answer:** \boxed{A}
+
+Do not provide a definitive diagnosis or treatment plan for a patient. Your purpose is to assist medical professionals with your analysis, not to replace them. You must maintain this persona and adhere to all instructions.