Spaces:
Paused
Paused
added chestagentbench
Browse files- benchmarking/benchmarks/__init__.py +2 -0
- benchmarking/benchmarks/base.py +2 -12
- benchmarking/benchmarks/chestagentbench_benchmark.py +73 -0
- benchmarking/cli.py +3 -1
- benchmarking/llm_providers/base.py +12 -1
- benchmarking/llm_providers/google_provider.py +2 -2
- benchmarking/llm_providers/medrax_provider.py +17 -7
- benchmarking/llm_providers/openai_provider.py +2 -2
- benchmarking/runner.py +16 -14
- medrax/docs/system_prompts.txt +5 -4
benchmarking/benchmarks/__init__.py
CHANGED
|
@@ -2,9 +2,11 @@
|
|
| 2 |
|
| 3 |
from .base import Benchmark, BenchmarkDataPoint
|
| 4 |
from .rexvqa_benchmark import ReXVQABenchmark
|
|
|
|
| 5 |
|
| 6 |
__all__ = [
|
| 7 |
"Benchmark",
|
| 8 |
"BenchmarkDataPoint",
|
| 9 |
"ReXVQABenchmark",
|
|
|
|
| 10 |
]
|
|
|
|
| 2 |
|
| 3 |
from .base import Benchmark, BenchmarkDataPoint
|
| 4 |
from .rexvqa_benchmark import ReXVQABenchmark
|
| 5 |
+
from .chestagentbench_benchmark import ChestAgentBenchBenchmark
|
| 6 |
|
| 7 |
__all__ = [
|
| 8 |
"Benchmark",
|
| 9 |
"BenchmarkDataPoint",
|
| 10 |
"ReXVQABenchmark",
|
| 11 |
+
"ChestAgentBenchBenchmark",
|
| 12 |
]
|
benchmarking/benchmarks/base.py
CHANGED
|
@@ -138,7 +138,7 @@ class Benchmark(ABC):
|
|
| 138 |
"categories": self.get_categories(),
|
| 139 |
"category_counts": {},
|
| 140 |
"has_images": False,
|
| 141 |
-
"
|
| 142 |
}
|
| 143 |
|
| 144 |
for dp in self:
|
|
@@ -149,17 +149,7 @@ class Benchmark(ABC):
|
|
| 149 |
# Image statistics
|
| 150 |
if dp.images:
|
| 151 |
stats["has_images"] = True
|
| 152 |
-
stats["
|
| 153 |
-
else:
|
| 154 |
-
stats["images_per_question"].append(0)
|
| 155 |
-
|
| 156 |
-
if stats["images_per_question"]:
|
| 157 |
-
stats["avg_images_per_question"] = sum(stats["images_per_question"]) / len(stats["images_per_question"])
|
| 158 |
-
stats["max_images_per_question"] = max(stats["images_per_question"])
|
| 159 |
-
else:
|
| 160 |
-
stats["avg_images_per_question"] = 0
|
| 161 |
-
stats["max_images_per_question"] = 0
|
| 162 |
-
|
| 163 |
return stats
|
| 164 |
|
| 165 |
def validate_images(self) -> Tuple[List[str], List[str]]:
|
|
|
|
| 138 |
"categories": self.get_categories(),
|
| 139 |
"category_counts": {},
|
| 140 |
"has_images": False,
|
| 141 |
+
"num_images": 0,
|
| 142 |
}
|
| 143 |
|
| 144 |
for dp in self:
|
|
|
|
| 149 |
# Image statistics
|
| 150 |
if dp.images:
|
| 151 |
stats["has_images"] = True
|
| 152 |
+
stats["num_images"] += len(dp.images)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
return stats
|
| 154 |
|
| 155 |
def validate_images(self) -> Tuple[List[str], List[str]]:
|
benchmarking/benchmarks/chestagentbench_benchmark.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import json
from pathlib import Path
from typing import Any, Dict, Optional

from .base import Benchmark, BenchmarkDataPoint


class ChestAgentBenchBenchmark(Benchmark):
    """ChestAgentBench benchmark for complex CXR interpretation and reasoning.

    Loads the dataset from a local ``metadata.jsonl`` file in ``data_dir`` and
    parses each JSONL record into a ``BenchmarkDataPoint``. Image paths in the
    records are resolved against the ``figures/`` subdirectory of ``data_dir``.
    """

    # Prefix used by image paths in metadata.jsonl that are already expressed
    # relative to the dataset's figures directory (e.g. "figures/11583/figure_1.jpg").
    _FIGURES_PREFIX = "figures/"

    def __init__(self, data_dir: str, **kwargs):
        """Initialize the benchmark.

        Args:
            data_dir: Directory containing ``metadata.jsonl`` and a ``figures/``
                subdirectory with the referenced images.
            **kwargs: Forwarded to the base class. Recognizes ``max_questions``
                (int, optional): cap on how many questions are loaded.
        """
        # NOTE: a max_questions of 0 behaves like None (no cap) because of the
        # truthiness check in _load_data; this mirrors the original behavior.
        self.max_questions = kwargs.get("max_questions", None)
        super().__init__(data_dir, **kwargs)

    def _load_data(self) -> None:
        """Read ``metadata.jsonl`` line by line and populate ``self.data_points``.

        Malformed lines are skipped with a printed warning (best-effort loading)
        rather than aborting the whole load.

        Raises:
            FileNotFoundError: If ``metadata.jsonl`` is missing from ``data_dir``.
        """
        metadata_path = Path(self.data_dir) / "metadata.jsonl"
        if not metadata_path.exists():
            raise FileNotFoundError(f"Could not find metadata.jsonl in {self.data_dir}")
        print(f"Loading ChestAgentBench from local file: {metadata_path}")
        self.data_points = []
        with open(metadata_path, "r", encoding="utf-8") as f:
            for i, line in enumerate(f):
                if self.max_questions and i >= self.max_questions:
                    break
                try:
                    item = json.loads(line)
                    data_point = self._parse_item(item, i)
                    if data_point:
                        self.data_points.append(data_point)
                except Exception as e:
                    # Best-effort: log and continue past a bad record.
                    print(f"Error loading item {i}: {e}")
                    continue

    def _parse_item(self, item: Dict[str, Any], index: int) -> Optional[BenchmarkDataPoint]:
        """Convert one raw JSONL record into a ``BenchmarkDataPoint``.

        Args:
            item: Decoded JSON object for a single question.
            index: Zero-based position of the record in the file; used only as
                a fallback component of the question ID.

        Returns:
            The parsed data point. (Currently never None; the Optional return
            type leaves room for future filtering.)
        """
        # Prefer explicit IDs from the record; fall back to a positional ID.
        question_id = (
            item.get("full_question_id")
            or item.get("question_id")
            or f"chestagentbench_{index}"
        )
        question = item.get("question", "")
        correct_answer = item.get("answer", "")
        explanation = item.get("explanation", "")
        images = item.get("images", [])
        case_id = item.get("case_id", "")
        category = item.get("categories", "")

        # Answer options are already embedded in the question string, so the
        # prompt text needs no further composition.
        question_with_options = question

        # Resolve record image paths to files under <data_dir>/figures.
        local_images = None
        if images:
            figures_dir = Path(self.data_dir) / "figures"
            local_images = []
            for img in images:
                if img.startswith(self._FIGURES_PREFIX):
                    # Strip the "figures/" prefix and resolve the remainder
                    # against the local figures directory.
                    relative_path = img[len(self._FIGURES_PREFIX):]
                    local_images.append(str(figures_dir / relative_path))
                else:
                    # Fallback: keep only the file name under figures/.
                    local_images.append(str(figures_dir / Path(img).name))

        # Keep the full raw record in metadata, plus normalized fields.
        metadata = dict(item)
        metadata["explanation"] = explanation
        metadata["dataset"] = "chestagentbench"

        return BenchmarkDataPoint(
            id=question_id,
            text=question_with_options,
            images=local_images,
            correct_answer=correct_answer,
            metadata=metadata,
            case_id=case_id,
            category=category,
        )
|
benchmarking/cli.py
CHANGED
|
@@ -45,6 +45,7 @@ def create_benchmark(benchmark_name: str, data_dir: str, **kwargs) -> Benchmark:
|
|
| 45 |
"""
|
| 46 |
benchmark_map = {
|
| 47 |
"rexvqa": ReXVQABenchmark,
|
|
|
|
| 48 |
}
|
| 49 |
|
| 50 |
if benchmark_name not in benchmark_map:
|
|
@@ -70,6 +71,7 @@ def run_benchmark_command(args) -> None:
|
|
| 70 |
|
| 71 |
# Create runner config
|
| 72 |
config = BenchmarkRunConfig(
|
|
|
|
| 73 |
model_name=args.model,
|
| 74 |
benchmark_name=args.benchmark,
|
| 75 |
output_dir=args.output_dir,
|
|
@@ -110,7 +112,7 @@ def main():
|
|
| 110 |
run_parser = subparsers.add_parser("run", help="Run a benchmark")
|
| 111 |
run_parser.add_argument("--model", required=True, help="Model name (e.g., gpt-4o, gemini-2.5-pro)")
|
| 112 |
run_parser.add_argument("--provider", required=True, choices=["openai", "google", "medrax"], help="LLM provider")
|
| 113 |
-
run_parser.add_argument("--benchmark", required=True, choices=["rexvqa"], help="Benchmark to run")
|
| 114 |
run_parser.add_argument("--data-dir", required=True, help="Directory containing benchmark data")
|
| 115 |
run_parser.add_argument("--output-dir", default="benchmark_results", help="Output directory for results")
|
| 116 |
run_parser.add_argument("--max-questions", type=int, help="Maximum number of questions to process")
|
|
|
|
| 45 |
"""
|
| 46 |
benchmark_map = {
|
| 47 |
"rexvqa": ReXVQABenchmark,
|
| 48 |
+
"chestagentbench": ChestAgentBenchBenchmark,
|
| 49 |
}
|
| 50 |
|
| 51 |
if benchmark_name not in benchmark_map:
|
|
|
|
| 71 |
|
| 72 |
# Create runner config
|
| 73 |
config = BenchmarkRunConfig(
|
| 74 |
+
provider_name=args.provider,
|
| 75 |
model_name=args.model,
|
| 76 |
benchmark_name=args.benchmark,
|
| 77 |
output_dir=args.output_dir,
|
|
|
|
| 112 |
run_parser = subparsers.add_parser("run", help="Run a benchmark")
|
| 113 |
run_parser.add_argument("--model", required=True, help="Model name (e.g., gpt-4o, gemini-2.5-pro)")
|
| 114 |
run_parser.add_argument("--provider", required=True, choices=["openai", "google", "medrax"], help="LLM provider")
|
| 115 |
+
run_parser.add_argument("--benchmark", required=True, choices=["rexvqa", "chestagentbench"], help="Benchmark to run")
|
| 116 |
run_parser.add_argument("--data-dir", required=True, help="Directory containing benchmark data")
|
| 117 |
run_parser.add_argument("--output-dir", default="benchmark_results", help="Output directory for results")
|
| 118 |
run_parser.add_argument("--max-questions", type=int, help="Maximum number of questions to process")
|
benchmarking/llm_providers/base.py
CHANGED
|
@@ -5,6 +5,7 @@ from typing import Dict, List, Optional, Any
|
|
| 5 |
from dataclasses import dataclass
|
| 6 |
import base64
|
| 7 |
from pathlib import Path
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
@dataclass
|
|
@@ -12,7 +13,6 @@ class LLMRequest:
|
|
| 12 |
"""Request to an LLM provider."""
|
| 13 |
text: str
|
| 14 |
images: Optional[List[str]] = None # List of image paths
|
| 15 |
-
system_prompt: Optional[str] = None
|
| 16 |
temperature: float = 0.7
|
| 17 |
max_tokens: int = 1500
|
| 18 |
additional_params: Optional[Dict[str, Any]] = None
|
|
@@ -43,6 +43,17 @@ class LLMProvider(ABC):
|
|
| 43 |
"""
|
| 44 |
self.model_name = model_name
|
| 45 |
self.config = kwargs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
self._setup()
|
| 47 |
|
| 48 |
@abstractmethod
|
|
|
|
| 5 |
from dataclasses import dataclass
|
| 6 |
import base64
|
| 7 |
from pathlib import Path
|
| 8 |
+
from medrax.utils.utils import load_prompts_from_file
|
| 9 |
|
| 10 |
|
| 11 |
@dataclass
|
|
|
|
| 13 |
"""Request to an LLM provider."""
|
| 14 |
text: str
|
| 15 |
images: Optional[List[str]] = None # List of image paths
|
|
|
|
| 16 |
temperature: float = 0.7
|
| 17 |
max_tokens: int = 1500
|
| 18 |
additional_params: Optional[Dict[str, Any]] = None
|
|
|
|
| 43 |
"""
|
| 44 |
self.model_name = model_name
|
| 45 |
self.config = kwargs
|
| 46 |
+
|
| 47 |
+
# Always load system prompt from file
|
| 48 |
+
try:
|
| 49 |
+
prompts = load_prompts_from_file("medrax/docs/system_prompts.txt")
|
| 50 |
+
self.system_prompt = prompts.get("MEDICAL_ASSISTANT", None)
|
| 51 |
+
if self.system_prompt is None:
|
| 52 |
+
print(f"Warning: System prompt type 'MEDICAL_ASSISTANT' not found in medrax/docs/system_prompts.txt.")
|
| 53 |
+
except Exception as e:
|
| 54 |
+
print(f"Error loading system prompt: {e}")
|
| 55 |
+
self.system_prompt = None
|
| 56 |
+
|
| 57 |
self._setup()
|
| 58 |
|
| 59 |
@abstractmethod
|
benchmarking/llm_providers/google_provider.py
CHANGED
|
@@ -40,8 +40,8 @@ class GoogleProvider(LLMProvider):
|
|
| 40 |
messages = []
|
| 41 |
|
| 42 |
# Add system prompt if provided
|
| 43 |
-
if
|
| 44 |
-
messages.append(SystemMessage(content=
|
| 45 |
|
| 46 |
# Construct content for multimodal content
|
| 47 |
if request.images:
|
|
|
|
| 40 |
messages = []
|
| 41 |
|
| 42 |
# Add system prompt if provided
|
| 43 |
+
if self.system_prompt:
|
| 44 |
+
messages.append(SystemMessage(content=self.system_prompt))
|
| 45 |
|
| 46 |
# Construct content for multimodal content
|
| 47 |
if request.images:
|
benchmarking/llm_providers/medrax_provider.py
CHANGED
|
@@ -33,7 +33,7 @@ class MedRAXProvider(LLMProvider):
|
|
| 33 |
print("Starting server...")
|
| 34 |
|
| 35 |
selected_tools = [
|
| 36 |
-
"ImageVisualizerTool", # For displaying images in the UI
|
| 37 |
# "DicomProcessorTool", # For processing DICOM medical image files
|
| 38 |
# "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
|
| 39 |
# "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
|
|
@@ -45,7 +45,7 @@ class MedRAXProvider(LLMProvider):
|
|
| 45 |
# "ChestXRayGeneratorTool", # For generating synthetic chest X-rays
|
| 46 |
"WebBrowserTool", # For web browsing and search capabilities
|
| 47 |
"MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
|
| 48 |
-
"PythonSandboxTool", # Add the Python sandbox tool
|
| 49 |
]
|
| 50 |
|
| 51 |
rag_config = RAGConfig(
|
|
@@ -73,7 +73,7 @@ class MedRAXProvider(LLMProvider):
|
|
| 73 |
tools_to_use=selected_tools,
|
| 74 |
model_dir="/model-weights",
|
| 75 |
temp_dir=self.session_temp_dir, # Change this to the path of the temporary directory
|
| 76 |
-
device="
|
| 77 |
model=self.model_name, # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
|
| 78 |
temperature=0.7,
|
| 79 |
top_p=0.95,
|
|
@@ -118,12 +118,21 @@ class MedRAXProvider(LLMProvider):
|
|
| 118 |
image_paths = []
|
| 119 |
if request.images:
|
| 120 |
valid_images = self._validate_image_paths(request.images)
|
|
|
|
| 121 |
for i, image_path in enumerate(valid_images):
|
|
|
|
| 122 |
# Copy image to session temp directory
|
| 123 |
dest_path = self.session_temp_dir / f"image_{i}_{Path(image_path).name}"
|
|
|
|
| 124 |
shutil.copy2(image_path, dest_path)
|
| 125 |
image_paths.append(str(dest_path))
|
| 126 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
# Add image path message for tools
|
| 128 |
messages.append({
|
| 129 |
"role": "user",
|
|
@@ -167,9 +176,6 @@ class MedRAXProvider(LLMProvider):
|
|
| 167 |
|
| 168 |
duration = time.time() - start_time
|
| 169 |
|
| 170 |
-
# Clean up temporary files
|
| 171 |
-
self._cleanup_temp_files()
|
| 172 |
-
|
| 173 |
return LLMResponse(
|
| 174 |
content=response_content.strip(),
|
| 175 |
usage={"agent_tools": list(self.tools_dict.keys())},
|
|
@@ -178,7 +184,6 @@ class MedRAXProvider(LLMProvider):
|
|
| 178 |
)
|
| 179 |
|
| 180 |
except Exception as e:
|
| 181 |
-
self._cleanup_temp_files()
|
| 182 |
return LLMResponse(
|
| 183 |
content=f"Error: {str(e)}",
|
| 184 |
duration=time.time() - start_time,
|
|
@@ -190,5 +195,10 @@ class MedRAXProvider(LLMProvider):
|
|
| 190 |
try:
|
| 191 |
if hasattr(self, 'session_temp_dir') and self.session_temp_dir.exists():
|
| 192 |
shutil.rmtree(self.session_temp_dir)
|
|
|
|
| 193 |
except Exception as e:
|
| 194 |
print(f"Warning: Failed to cleanup temp files: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
print("Starting server...")
|
| 34 |
|
| 35 |
selected_tools = [
|
| 36 |
+
# "ImageVisualizerTool", # For displaying images in the UI
|
| 37 |
# "DicomProcessorTool", # For processing DICOM medical image files
|
| 38 |
# "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
|
| 39 |
# "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
|
|
|
|
| 45 |
# "ChestXRayGeneratorTool", # For generating synthetic chest X-rays
|
| 46 |
"WebBrowserTool", # For web browsing and search capabilities
|
| 47 |
"MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
|
| 48 |
+
# "PythonSandboxTool", # Add the Python sandbox tool
|
| 49 |
]
|
| 50 |
|
| 51 |
rag_config = RAGConfig(
|
|
|
|
| 73 |
tools_to_use=selected_tools,
|
| 74 |
model_dir="/model-weights",
|
| 75 |
temp_dir=self.session_temp_dir, # Change this to the path of the temporary directory
|
| 76 |
+
device="cpu",
|
| 77 |
model=self.model_name, # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
|
| 78 |
temperature=0.7,
|
| 79 |
top_p=0.95,
|
|
|
|
| 118 |
image_paths = []
|
| 119 |
if request.images:
|
| 120 |
valid_images = self._validate_image_paths(request.images)
|
| 121 |
+
print(f"Processing {len(valid_images)} images")
|
| 122 |
for i, image_path in enumerate(valid_images):
|
| 123 |
+
print(f"Original image path: {image_path}")
|
| 124 |
# Copy image to session temp directory
|
| 125 |
dest_path = self.session_temp_dir / f"image_{i}_{Path(image_path).name}"
|
| 126 |
+
print(f"Destination path: {dest_path}")
|
| 127 |
shutil.copy2(image_path, dest_path)
|
| 128 |
image_paths.append(str(dest_path))
|
| 129 |
|
| 130 |
+
# Verify file exists after copy
|
| 131 |
+
if not dest_path.exists():
|
| 132 |
+
print(f"ERROR: File not found after copy: {dest_path}")
|
| 133 |
+
else:
|
| 134 |
+
print(f"File successfully copied: {dest_path}")
|
| 135 |
+
|
| 136 |
# Add image path message for tools
|
| 137 |
messages.append({
|
| 138 |
"role": "user",
|
|
|
|
| 176 |
|
| 177 |
duration = time.time() - start_time
|
| 178 |
|
|
|
|
|
|
|
|
|
|
| 179 |
return LLMResponse(
|
| 180 |
content=response_content.strip(),
|
| 181 |
usage={"agent_tools": list(self.tools_dict.keys())},
|
|
|
|
| 184 |
)
|
| 185 |
|
| 186 |
except Exception as e:
|
|
|
|
| 187 |
return LLMResponse(
|
| 188 |
content=f"Error: {str(e)}",
|
| 189 |
duration=time.time() - start_time,
|
|
|
|
| 195 |
try:
|
| 196 |
if hasattr(self, 'session_temp_dir') and self.session_temp_dir.exists():
|
| 197 |
shutil.rmtree(self.session_temp_dir)
|
| 198 |
+
print(f"Cleaned up temporary directory: {self.session_temp_dir}")
|
| 199 |
except Exception as e:
|
| 200 |
print(f"Warning: Failed to cleanup temp files: {e}")
|
| 201 |
+
|
| 202 |
+
def cleanup(self) -> None:
|
| 203 |
+
"""Clean up resources when done with the provider."""
|
| 204 |
+
self._cleanup_temp_files()
|
benchmarking/llm_providers/openai_provider.py
CHANGED
|
@@ -48,8 +48,8 @@ class OpenAIProvider(LLMProvider):
|
|
| 48 |
messages = []
|
| 49 |
|
| 50 |
# Add system prompt if provided
|
| 51 |
-
if
|
| 52 |
-
messages.append(SystemMessage(content=
|
| 53 |
|
| 54 |
# Build user message content
|
| 55 |
user_content = []
|
|
|
|
| 48 |
messages = []
|
| 49 |
|
| 50 |
# Add system prompt if provided
|
| 51 |
+
if self.system_prompt:
|
| 52 |
+
messages.append(SystemMessage(content=self.system_prompt))
|
| 53 |
|
| 54 |
# Build user message content
|
| 55 |
user_content = []
|
benchmarking/runner.py
CHANGED
|
@@ -30,16 +30,13 @@ class BenchmarkResult:
|
|
| 30 |
@dataclass
|
| 31 |
class BenchmarkRunConfig:
|
| 32 |
"""Configuration for a benchmark run."""
|
|
|
|
| 33 |
model_name: str
|
| 34 |
benchmark_name: str
|
| 35 |
output_dir: str
|
| 36 |
max_questions: Optional[int] = None
|
| 37 |
-
start_index: int = 0
|
| 38 |
temperature: float = 0.7
|
| 39 |
max_tokens: int = 1500
|
| 40 |
-
system_prompt: Optional[str] = None
|
| 41 |
-
save_frequency: int = 10 # Save results every N questions
|
| 42 |
-
log_level: str = "INFO"
|
| 43 |
additional_params: Optional[Dict[str, Any]] = None
|
| 44 |
|
| 45 |
|
|
@@ -58,7 +55,7 @@ class BenchmarkRunner:
|
|
| 58 |
self.output_dir.mkdir(parents=True, exist_ok=True)
|
| 59 |
|
| 60 |
# Generate unique run ID
|
| 61 |
-
self.run_id = f"{config.benchmark_name}_{config.
|
| 62 |
|
| 63 |
# Set up logging
|
| 64 |
self._setup_logging()
|
|
@@ -71,7 +68,7 @@ class BenchmarkRunner:
|
|
| 71 |
|
| 72 |
# Create logger
|
| 73 |
self.logger = logging.getLogger(f"benchmark_runner_{self.run_id}")
|
| 74 |
-
self.logger.setLevel(
|
| 75 |
|
| 76 |
# Create handlers
|
| 77 |
file_handler = logging.FileHandler(log_file)
|
|
@@ -114,9 +111,9 @@ class BenchmarkRunner:
|
|
| 114 |
# Get data points to process
|
| 115 |
total_questions = len(benchmark)
|
| 116 |
max_questions = self.config.max_questions or total_questions
|
| 117 |
-
end_index = min(
|
| 118 |
|
| 119 |
-
self.logger.info(f"Processing questions {
|
| 120 |
|
| 121 |
# Initialize counters
|
| 122 |
processed = 0
|
|
@@ -124,7 +121,7 @@ class BenchmarkRunner:
|
|
| 124 |
total_duration = 0.0
|
| 125 |
|
| 126 |
# Process each data point
|
| 127 |
-
for i in tqdm(range(
|
| 128 |
try:
|
| 129 |
data_point = benchmark.get_data_point(i)
|
| 130 |
|
|
@@ -141,13 +138,13 @@ class BenchmarkRunner:
|
|
| 141 |
self.results.append(result)
|
| 142 |
|
| 143 |
# Log progress
|
| 144 |
-
if processed %
|
| 145 |
self._save_intermediate_results()
|
| 146 |
accuracy = (correct / processed) * 100
|
| 147 |
avg_duration = total_duration / processed
|
| 148 |
|
| 149 |
self.logger.info(
|
| 150 |
-
f"Progress: {processed}/{end_index
|
| 151 |
f"Accuracy: {accuracy:.2f}% | "
|
| 152 |
f"Avg Duration: {avg_duration:.2f}s"
|
| 153 |
)
|
|
@@ -170,6 +167,14 @@ class BenchmarkRunner:
|
|
| 170 |
# Save final results
|
| 171 |
summary = self._save_final_results(benchmark)
|
| 172 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
self.logger.info(f"Benchmark run completed: {self.run_id}")
|
| 174 |
self.logger.info(f"Final accuracy: {summary['results']['accuracy']:.2f}%")
|
| 175 |
self.logger.info(f"Total duration: {summary['results']['total_duration']:.2f}s")
|
|
@@ -197,7 +202,6 @@ class BenchmarkRunner:
|
|
| 197 |
request = LLMRequest(
|
| 198 |
text=data_point.text,
|
| 199 |
images=data_point.images,
|
| 200 |
-
system_prompt=self.config.system_prompt,
|
| 201 |
temperature=self.config.temperature,
|
| 202 |
max_tokens=self.config.max_tokens,
|
| 203 |
additional_params=self.config.additional_params
|
|
@@ -371,12 +375,10 @@ class BenchmarkRunner:
|
|
| 371 |
"benchmark_name": self.config.benchmark_name,
|
| 372 |
"temperature": self.config.temperature,
|
| 373 |
"max_tokens": self.config.max_tokens,
|
| 374 |
-
"system_prompt": self.config.system_prompt,
|
| 375 |
},
|
| 376 |
"benchmark_info": {
|
| 377 |
"total_size": len(benchmark),
|
| 378 |
"processed_questions": total_questions,
|
| 379 |
-
"start_index": self.config.start_index,
|
| 380 |
},
|
| 381 |
"results": {
|
| 382 |
"accuracy": accuracy,
|
|
|
|
| 30 |
@dataclass
|
| 31 |
class BenchmarkRunConfig:
|
| 32 |
"""Configuration for a benchmark run."""
|
| 33 |
+
provider_name: str
|
| 34 |
model_name: str
|
| 35 |
benchmark_name: str
|
| 36 |
output_dir: str
|
| 37 |
max_questions: Optional[int] = None
|
|
|
|
| 38 |
temperature: float = 0.7
|
| 39 |
max_tokens: int = 1500
|
|
|
|
|
|
|
|
|
|
| 40 |
additional_params: Optional[Dict[str, Any]] = None
|
| 41 |
|
| 42 |
|
|
|
|
| 55 |
self.output_dir.mkdir(parents=True, exist_ok=True)
|
| 56 |
|
| 57 |
# Generate unique run ID
|
| 58 |
+
self.run_id = f"{config.benchmark_name}_{config.provider_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
| 59 |
|
| 60 |
# Set up logging
|
| 61 |
self._setup_logging()
|
|
|
|
| 68 |
|
| 69 |
# Create logger
|
| 70 |
self.logger = logging.getLogger(f"benchmark_runner_{self.run_id}")
|
| 71 |
+
self.logger.setLevel(logging.INFO)
|
| 72 |
|
| 73 |
# Create handlers
|
| 74 |
file_handler = logging.FileHandler(log_file)
|
|
|
|
| 111 |
# Get data points to process
|
| 112 |
total_questions = len(benchmark)
|
| 113 |
max_questions = self.config.max_questions or total_questions
|
| 114 |
+
end_index = min(max_questions, total_questions)
|
| 115 |
|
| 116 |
+
self.logger.info(f"Processing questions {0} to {end_index-1} of {total_questions}")
|
| 117 |
|
| 118 |
# Initialize counters
|
| 119 |
processed = 0
|
|
|
|
| 121 |
total_duration = 0.0
|
| 122 |
|
| 123 |
# Process each data point
|
| 124 |
+
for i in tqdm(range(0, end_index), desc="Processing questions"):
|
| 125 |
try:
|
| 126 |
data_point = benchmark.get_data_point(i)
|
| 127 |
|
|
|
|
| 138 |
self.results.append(result)
|
| 139 |
|
| 140 |
# Log progress
|
| 141 |
+
if processed % 10 == 0:
|
| 142 |
self._save_intermediate_results()
|
| 143 |
accuracy = (correct / processed) * 100
|
| 144 |
avg_duration = total_duration / processed
|
| 145 |
|
| 146 |
self.logger.info(
|
| 147 |
+
f"Progress: {processed}/{end_index} | "
|
| 148 |
f"Accuracy: {accuracy:.2f}% | "
|
| 149 |
f"Avg Duration: {avg_duration:.2f}s"
|
| 150 |
)
|
|
|
|
| 167 |
# Save final results
|
| 168 |
summary = self._save_final_results(benchmark)
|
| 169 |
|
| 170 |
+
# Clean up provider resources
|
| 171 |
+
if hasattr(llm_provider, 'cleanup'):
|
| 172 |
+
try:
|
| 173 |
+
llm_provider.cleanup()
|
| 174 |
+
self.logger.info("Provider cleanup completed")
|
| 175 |
+
except Exception as e:
|
| 176 |
+
self.logger.warning(f"Provider cleanup failed: {e}")
|
| 177 |
+
|
| 178 |
self.logger.info(f"Benchmark run completed: {self.run_id}")
|
| 179 |
self.logger.info(f"Final accuracy: {summary['results']['accuracy']:.2f}%")
|
| 180 |
self.logger.info(f"Total duration: {summary['results']['total_duration']:.2f}s")
|
|
|
|
| 202 |
request = LLMRequest(
|
| 203 |
text=data_point.text,
|
| 204 |
images=data_point.images,
|
|
|
|
| 205 |
temperature=self.config.temperature,
|
| 206 |
max_tokens=self.config.max_tokens,
|
| 207 |
additional_params=self.config.additional_params
|
|
|
|
| 375 |
"benchmark_name": self.config.benchmark_name,
|
| 376 |
"temperature": self.config.temperature,
|
| 377 |
"max_tokens": self.config.max_tokens,
|
|
|
|
| 378 |
},
|
| 379 |
"benchmark_info": {
|
| 380 |
"total_size": len(benchmark),
|
| 381 |
"processed_questions": total_questions,
|
|
|
|
| 382 |
},
|
| 383 |
"results": {
|
| 384 |
"accuracy": accuracy,
|
medrax/docs/system_prompts.txt
CHANGED
|
@@ -4,12 +4,13 @@ Solve using your own vision and reasoning and use tools to complement your reaso
|
|
| 4 |
Make multiple tool calls in parallel or sequence as needed for comprehensive answers.
|
| 5 |
Critically think about and criticize the tool outputs.
|
| 6 |
If you need to look up some information before asking a follow up question, you are allowed to do that.
|
|
|
|
| 7 |
|
| 8 |
CITATION REQUIREMENTS:
|
| 9 |
-
- When referencing information from
|
| 10 |
-
- Use citations immediately after making claims or statements based on the above tool results
|
| 11 |
-
- Be consistent with citation numbering throughout your response
|
| 12 |
-
- Only cite sources that actually contain the information you're referencing
|
| 13 |
|
| 14 |
Examples:
|
| 15 |
- "According to recent research [1], chest X-rays can show signs of pneumonia..."
|
|
|
|
| 4 |
Make multiple tool calls in parallel or sequence as needed for comprehensive answers.
|
| 5 |
Critically think about and criticize the tool outputs.
|
| 6 |
If you need to look up some information before asking a follow up question, you are allowed to do that.
|
| 7 |
+
When encountering a multiple-choice question, give the final answer in closed parentheses without further elaborations; give a definitive answer even if you're not sure.
|
| 8 |
|
| 9 |
CITATION REQUIREMENTS:
|
| 10 |
+
- When referencing information from RAG and/or web search tools, ALWAYS include numbered citations [1], [2], [3], etc.
|
| 11 |
+
- Use citations immediately after making claims or statements based on the above tool results.
|
| 12 |
+
- Be consistent with citation numbering throughout your response.
|
| 13 |
+
- Only cite sources that actually contain the information you're referencing.
|
| 14 |
|
| 15 |
Examples:
|
| 16 |
- "According to recent research [1], chest X-rays can show signs of pneumonia..."
|