VictorLJZ committed
Commit 35945d9 · Parents: 0d14c76 ab428dd

Merge pull request #16 from bowang-lab/victor
.gitignore CHANGED
@@ -175,4 +175,6 @@ temp/
 
 hf_files/
 medrax-pdfs/
-model-weights/
+model-weights/
+
+.DS_Store
benchmarking/__init__.py ADDED
@@ -0,0 +1 @@
+"""Benchmarking pipeline for MedRAX2 and other medical AI models."""
benchmarking/benchmarks/__init__.py ADDED
@@ -0,0 +1,12 @@
+"""Benchmark abstractions for medical AI evaluation."""
+
+from .base import Benchmark, BenchmarkDataPoint
+from .rexvqa_benchmark import ReXVQABenchmark
+from .chestagentbench_benchmark import ChestAgentBenchBenchmark
+
+__all__ = [
+    "Benchmark",
+    "BenchmarkDataPoint",
+    "ReXVQABenchmark",
+    "ChestAgentBenchBenchmark",
+]
benchmarking/benchmarks/base.py ADDED
@@ -0,0 +1,172 @@
+"""Base class for benchmarks."""
+
+from abc import ABC, abstractmethod
+from typing import Dict, List, Optional, Any, Iterator, Tuple
+from dataclasses import dataclass
+from pathlib import Path
+
+
+@dataclass
+class BenchmarkDataPoint:
+    """A single data point from a benchmark."""
+    id: str
+    text: str  # The question/prompt
+    images: Optional[List[str]] = None  # List of image paths
+    correct_answer: Optional[str] = None  # Ground truth answer
+    case_id: Optional[str] = None  # For grouping related questions
+    category: Optional[str] = None  # Type of question/task
+    metadata: Optional[Dict[str, Any]] = None  # Additional metadata
+
+
+class Benchmark(ABC):
+    """Abstract base class for benchmarks.
+
+    This class defines the interface for all benchmarks, standardizing
+    how data is loaded and accessed across different benchmark datasets.
+    """
+
+    def __init__(self, data_dir: str, **kwargs):
+        """Initialize the benchmark.
+
+        Args:
+            data_dir (str): Directory containing benchmark data
+            **kwargs: Additional configuration parameters
+        """
+        self.data_dir = Path(data_dir)
+        self.config = kwargs
+        self.data_points = []
+        self._load_data()
+
+    @abstractmethod
+    def _load_data(self) -> None:
+        """Load benchmark data from the data directory."""
+        pass
+
+    def get_data_point(self, index: int) -> BenchmarkDataPoint:
+        """Get a specific data point by index.
+
+        Args:
+            index (int): Index of the data point to retrieve
+
+        Returns:
+            BenchmarkDataPoint: The data point at the given index
+        """
+        if index < 0 or index >= len(self.data_points):
+            raise IndexError(f"Index {index} out of range for {len(self.data_points)} data points")
+
+        return self.data_points[index]
+
+    def get_subset(self, indices: List[int]) -> List[BenchmarkDataPoint]:
+        """Get a subset of data points by indices.
+
+        Args:
+            indices (List[int]): List of indices to retrieve
+
+        Returns:
+            List[BenchmarkDataPoint]: List of data points at the given indices
+        """
+        return [self.get_data_point(i) for i in indices]
+
+    def get_by_category(self, category: str) -> List[BenchmarkDataPoint]:
+        """Get all data points of a specific category.
+
+        Args:
+            category (str): Category to filter by
+
+        Returns:
+            List[BenchmarkDataPoint]: List of data points in the category
+        """
+        return [dp for dp in self if dp.category == category]
+
+    def get_by_case_id(self, case_id: str) -> List[BenchmarkDataPoint]:
+        """Get all data points for a specific case.
+
+        Args:
+            case_id (str): Case ID to filter by
+
+        Returns:
+            List[BenchmarkDataPoint]: List of data points for the case
+        """
+        return [dp for dp in self if dp.case_id == case_id]
+
+    def __str__(self) -> str:
+        """String representation of the benchmark."""
+        return f"{self.__class__.__name__}(data_dir={self.data_dir}, size={len(self)})"
+
+    def __len__(self) -> int:
+        """Return the number of data points in the benchmark."""
+        return len(self.data_points)
+
+    def __iter__(self) -> Iterator[BenchmarkDataPoint]:
+        """Iterate over all data points in the benchmark."""
+        for i in range(len(self)):
+            yield self.get_data_point(i)
+
+    def get_categories(self) -> List[str]:
+        """Get all unique categories in the benchmark.
+
+        Returns:
+            List[str]: List of unique categories
+        """
+        categories = set()
+        for dp in self:
+            if dp.category:
+                categories.add(dp.category)
+        return sorted(categories)
+
+    def get_case_ids(self) -> List[str]:
+        """Get all unique case IDs in the benchmark.
+
+        Returns:
+            List[str]: List of unique case IDs
+        """
+        case_ids = set()
+        for dp in self:
+            if dp.case_id:
+                case_ids.add(dp.case_id)
+        return sorted(case_ids)
+
+    def get_stats(self) -> Dict[str, Any]:
+        """Get statistics about the benchmark.
+
+        Returns:
+            Dict[str, Any]: Dictionary containing benchmark statistics
+        """
+        stats = {
+            "total_questions": len(self),
+            "total_cases": len(self.get_case_ids()),
+            "categories": self.get_categories(),
+            "category_counts": {},
+            "has_images": False,
+            "num_images": 0,
+        }
+
+        for dp in self:
+            # Category counts
+            if dp.category:
+                stats["category_counts"][dp.category] = stats["category_counts"].get(dp.category, 0) + 1
+
+            # Image statistics
+            if dp.images:
+                stats["has_images"] = True
+                stats["num_images"] += len(dp.images)
+        return stats
+
+    def validate_images(self) -> Tuple[List[str], List[str]]:
+        """Validate that all image paths exist.
+
+        Returns:
+            Tuple[List[str], List[str]]: Tuple of (valid_image_paths, invalid_image_paths)
+        """
+        valid_images = []
+        invalid_images = []
+
+        for dp in self:
+            if dp.images:
+                for image_path in dp.images:
+                    if Path(image_path).exists():
+                        valid_images.append(image_path)
+                    else:
+                        invalid_images.append(image_path)
+
+        return valid_images, invalid_images
benchmarking/benchmarks/chestagentbench_benchmark.py ADDED
@@ -0,0 +1,72 @@
+import json
+from pathlib import Path
+from typing import Dict, Optional, Any
+from .base import Benchmark, BenchmarkDataPoint
+
+class ChestAgentBenchBenchmark(Benchmark):
+    """ChestAgentBench benchmark for complex CXR interpretation and reasoning.
+
+    Loads the dataset from a local metadata.jsonl file and parses each entry into a BenchmarkDataPoint.
+    """
+    def __init__(self, data_dir: str, **kwargs):
+        self.max_questions = kwargs.get("max_questions", None)
+        super().__init__(data_dir, **kwargs)
+
+    def _load_data(self) -> None:
+        metadata_path = Path(self.data_dir) / "metadata.jsonl"
+        if not metadata_path.exists():
+            raise FileNotFoundError(f"Could not find metadata.jsonl in {self.data_dir}")
+        print(f"Loading ChestAgentBench from local file: {metadata_path}")
+        self.data_points = []
+        with open(metadata_path, "r", encoding="utf-8") as f:
+            for i, line in enumerate(f):
+                if self.max_questions and i >= self.max_questions:
+                    break
+                try:
+                    item = json.loads(line)
+                    data_point = self._parse_item(item, i)
+                    if data_point:
+                        self.data_points.append(data_point)
+                except Exception as e:
+                    print(f"Error loading item {i}: {e}")
+                    continue
+
+    def _parse_item(self, item: Dict[str, Any], index: int) -> Optional[BenchmarkDataPoint]:
+        # Use full_question_id or question_id if available, else fall back to the index
+        question_id = item.get("full_question_id") or item.get("question_id") or f"chestagentbench_{index}"
+        question = item.get("question", "")
+        correct_answer = item.get("answer", "")
+        explanation = item.get("explanation", "")
+        images = item.get("images", [])
+        case_id = item.get("case_id", "")
+        category = item.get("categories", "")
+        # Compose the question text (options are embedded in the question string)
+        question_with_options = question
+        # Map image paths to the local figures directory
+        local_images = None
+        if images:
+            figures_dir = Path(self.data_dir) / "figures"
+            local_images = []
+            for img in images:
+                # Handle relative paths like "figures/11583/figure_1.jpg"
+                if img.startswith("figures/"):
+                    # Remove the "figures/" prefix and construct the full path
+                    relative_path = img[8:]
+                    full_path = figures_dir / relative_path
+                    local_images.append(str(full_path))
+                else:
+                    # Fall back to the original logic
+                    local_images.append(str(figures_dir / Path(img).name))
+        # Metadata
+        metadata = dict(item)
+        metadata["explanation"] = explanation
+        metadata["dataset"] = "chestagentbench"
+        return BenchmarkDataPoint(
+            id=question_id,
+            text=question_with_options,
+            images=local_images,
+            correct_answer=correct_answer,
+            metadata=metadata,
+            case_id=case_id,
+            category=category,
+        )
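For orientation, a metadata.jsonl entry consistent with the fields this parser reads might look like the following (values are illustrative, not taken from the dataset):

    {"full_question_id": "11583_detection_1", "question": "Which finding is present? A) ... B) ...",
     "answer": "B", "explanation": "...", "images": ["figures/11583/figure_1.jpg"],
     "case_id": "11583", "categories": "detection"}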
benchmarking/benchmarks/rexvqa_benchmark.py ADDED
@@ -0,0 +1,203 @@
+"""ReXVQA benchmark implementation."""
+
+import json
+import os
+from typing import Dict, List, Optional, Any
+from datasets import load_dataset
+from .base import Benchmark, BenchmarkDataPoint
+from pathlib import Path
+
+
+class ReXVQABenchmark(Benchmark):
+    """ReXVQA benchmark for chest radiology visual question answering.
+
+    ReXVQA is a large-scale VQA dataset for chest radiology comprising approximately
+    696,000 questions paired with 160,000 chest X-rays. It tests 5 core radiological
+    reasoning skills: presence assessment, location analysis, negation detection,
+    differential diagnosis, and geometric reasoning.
+
+    The dataset consists of two separate HuggingFace datasets:
+    - ReXVQA: Contains questions, answers, and metadata
+    - ReXGradient-160K: Contains metadata only (images are in separate part files)
+
+    Paper: https://arxiv.org/abs/2506.04353
+    Dataset: https://huggingface.co/datasets/rajpurkarlab/ReXVQA
+    Images: https://huggingface.co/datasets/rajpurkarlab/ReXGradient-160K
+    """
+
+    def __init__(self, data_dir: str, **kwargs):
+        """Initialize the ReXVQA benchmark.
+
+        Args:
+            data_dir (str): Directory containing test_vqa_data.json
+            **kwargs: Additional configuration parameters
+                split (str): Dataset split to use (default: 'test')
+                cache_dir (str): Directory for caching HuggingFace datasets
+                trust_remote_code (bool): Whether to trust remote code (default: False)
+                max_questions (int): Maximum number of questions to load (default: None, load all)
+                images_dir (str): Directory containing extracted PNG images
+        """
+        self.split = kwargs.get("split", "test")
+        self.cache_dir = kwargs.get("cache_dir", None)
+        self.trust_remote_code = kwargs.get("trust_remote_code", False)
+        self.max_questions = kwargs.get("max_questions", None)
+        self.images_dir = kwargs.get("images_dir", "benchmarking/data/rexvqa/images/deid_png")
+        self.image_dataset = None
+        self.image_mapping = {}  # Maps study_id to image data
+
+        super().__init__(data_dir, **kwargs)
+
+    def _load_data(self) -> None:
+        """Load ReXVQA data from a local JSON file."""
+        try:
+            # Construct the path to the JSON file inside data_dir
+            json_file_path = os.path.join(str(self.data_dir), "test_vqa_data.json")
+
+            # Check that the file exists
+            if not os.path.exists(json_file_path):
+                raise FileNotFoundError(f"Could not find test_vqa_data.json in the expected location: {json_file_path}")
+
+            print(f"Loading ReXVQA {self.split} split from local JSON file: {json_file_path}")
+
+            # Load the JSON file directly
+            with open(json_file_path, 'r', encoding='utf-8') as f:
+                questions_data = json.load(f)
+
+            # ReXVQA format: {question_id: {question_data}, ...}
+            questions_list = []
+            for question_id, question_data in questions_data.items():
+                # Add the question_id to the question_data for reference
+                question_data['id'] = question_id
+                questions_list.append(question_data)
+
+            print(f"Loaded {len(questions_list)} questions from local JSON file")
+
+            # Load the image metadata dataset from ReXGradient-160K (metadata only)
+            print("Loading ReXGradient-160K metadata dataset...")
+            try:
+                self.image_dataset = load_dataset(
+                    "rajpurkarlab/ReXGradient-160K",
+                    split="test",
+                    cache_dir=self.cache_dir,
+                    trust_remote_code=self.trust_remote_code
+                )
+                print(f"Loaded {len(self.image_dataset)} image metadata entries from ReXGradient-160K")
+
+                # Create a mapping from study_id to image metadata
+                self._create_image_mapping()
+
+            except Exception as e:
+                print(f"Warning: Could not load ReXGradient-160K dataset: {e}")
+                print("Proceeding without images...")
+                self.load_images = False
+
+            self.data_points = []
+
+            # Process questions (limit if max_questions is specified)
+            questions_to_process = questions_list
+            if self.max_questions:
+                questions_to_process = questions_list[:min(self.max_questions, len(questions_list))]
+
+            for i, item in enumerate(questions_to_process):
+                try:
+                    data_point = self._parse_rexvqa_item(item, i)
+                    if data_point:
+                        self.data_points.append(data_point)
+
+                except Exception as e:
+                    print(f"Error loading item {i}: {e}")
+                    continue
+
+        except Exception as e:
+            raise RuntimeError(f"Failed to load ReXVQA dataset: {e}") from e
+
+    def _create_image_mapping(self) -> None:
+        """Create a mapping from study_id to image metadata."""
+        if not self.image_dataset:
+            return
+
+        print("Creating image mapping...")
+
+        for item in self.image_dataset:
+            study_instance_uid = item.get("StudyInstanceUid", "")
+            if study_instance_uid:
+                # Store the image metadata for this study using StudyInstanceUid as the key
+                if study_instance_uid not in self.image_mapping:
+                    self.image_mapping[study_instance_uid] = []
+                self.image_mapping[study_instance_uid].append(item)
+
+        print(f"Created image mapping for {len(self.image_mapping)} studies")
+
+    def _parse_rexvqa_item(self, item: Dict[str, Any], index: int) -> Optional[BenchmarkDataPoint]:
+        """Parse a ReXVQA dataset item.
+
+        Args:
+            item (Dict[str, Any]): Dataset item from the JSON file
+            index (int): Item index
+
+        Returns:
+            Optional[BenchmarkDataPoint]: Parsed data point
+        """
+        # Extract basic information
+        question_id = item.get("id", f"rexvqa_{self.split}_{index}")
+        question = item.get("question", "")
+
+        # Handle multiple-choice options
+        options = item.get("options", [])
+        if options:
+            # Append options to the question for the multiple-choice format
+            question_with_options = question + "\n\nOptions:\n" + "\n".join(options)
+        else:
+            question_with_options = question
+
+        # Get the correct answer
+        correct_answer = item.get("correct_answer", "")
+
+        if not question:
+            return None
+
+        # Handle images using the ImagePath field
+        images = None
+        if self.images_dir and "ImagePath" in item and item["ImagePath"]:
+            images = []
+            for rel_path in item["ImagePath"]:
+                # Strip any leading '.' and '/' characters (e.g. "../" prefixes)
+                norm_rel_path = rel_path.lstrip("./")
+                # Join with the parent of the images_dir root
+                full_path = str(Path(self.images_dir).parent / norm_rel_path)
+                images.append(full_path)
+
+        # Extract metadata
+        metadata = {
+            "dataset": "rexvqa",
+            "split": self.split,
+            "study_id": item.get("study_id", ""),
+            "study_instance_uid": item.get("StudyInstanceUid", ""),
+            "reasoning_type": item.get("task_name", ""),  # task_name maps to reasoning_type
+            "category": item.get("category", ""),
+            "class": item.get("class", ""),
+            "subcategory": item.get("subcategory", ""),
+            "patient_id": item.get("PatientID", ""),
+            "patient_age": item.get("PatientAge", ""),
+            "patient_sex": item.get("PatientSex", ""),
+            "study_date": item.get("StudyDate", ""),
+            "indication": item.get("Indication", ""),
+            "findings": item.get("Findings", ""),
+            "impression": item.get("Impression", ""),
+            "image_modality": item.get("ImageModality", []),
+            "image_view_position": item.get("ImageViewPosition", []),
+            "correct_answer_explanation": item.get("correct_answer_explanation", ""),
+        }
+
+        case_id = item.get("study_id", "")
+        category = item.get("task_name", "")
+
+        return BenchmarkDataPoint(
+            id=question_id,
+            text=question_with_options,
+            images=images,
+            correct_answer=correct_answer,
+            metadata=metadata,
+            case_id=case_id,
+            category=category,
+        )
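A minimal usage sketch (the data-directory layout is an assumption based on the defaults above):

    from benchmarking.benchmarks import ReXVQABenchmark

    benchmark = ReXVQABenchmark(
        data_dir="benchmarking/data/rexvqa",  # must contain test_vqa_data.json
        split="test",
        max_questions=100,
    )
    print(benchmark.get_stats())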
benchmarking/cli.py ADDED
@@ -0,0 +1,141 @@
+"""Command-line interface for the benchmarking pipeline."""
+
+import argparse
+import sys
+
+from .llm_providers import *
+from .benchmarks import *
+from .runner import BenchmarkRunner, BenchmarkRunConfig
+
+
+def create_llm_provider(model_name: str, provider_type: str, **kwargs) -> LLMProvider:
+    """Create an LLM provider based on the model name and type.
+
+    Args:
+        model_name (str): Name of the model
+        provider_type (str): Type of provider (openai, google, openrouter, medrax)
+        **kwargs: Additional configuration parameters
+
+    Returns:
+        LLMProvider: The configured LLM provider
+    """
+    provider_map = {
+        "openai": OpenAIProvider,
+        "google": GoogleProvider,
+        "openrouter": OpenRouterProvider,
+        "medrax": MedRAXProvider,
+    }
+
+    if provider_type not in provider_map:
+        raise ValueError(f"Unknown provider type: {provider_type}. Available: {list(provider_map.keys())}")
+
+    provider_class = provider_map[provider_type]
+    return provider_class(model_name, **kwargs)
+
+
+def create_benchmark(benchmark_name: str, data_dir: str, **kwargs) -> Benchmark:
+    """Create a benchmark based on the benchmark name.
+
+    Args:
+        benchmark_name (str): Name of the benchmark
+        data_dir (str): Directory containing benchmark data
+        **kwargs: Additional configuration parameters
+
+    Returns:
+        Benchmark: The configured benchmark
+    """
+    benchmark_map = {
+        "rexvqa": ReXVQABenchmark,
+        "chestagentbench": ChestAgentBenchBenchmark,
+    }
+
+    if benchmark_name not in benchmark_map:
+        raise ValueError(f"Unknown benchmark: {benchmark_name}. Available: {list(benchmark_map.keys())}")
+
+    benchmark_class = benchmark_map[benchmark_name]
+    return benchmark_class(data_dir, **kwargs)
+
+
+def run_benchmark_command(args) -> None:
+    """Run a benchmark."""
+    print(f"Running benchmark: {args.benchmark} with model: {args.model}")
+
+    # Create the LLM provider
+    provider_kwargs = {}
+
+    llm_provider = create_llm_provider(args.model, args.provider, **provider_kwargs)
+
+    # Create the benchmark
+    benchmark_kwargs = {}
+
+    benchmark = create_benchmark(args.benchmark, args.data_dir, **benchmark_kwargs)
+
+    # Create the runner config
+    config = BenchmarkRunConfig(
+        provider_name=args.provider,
+        model_name=args.model,
+        benchmark_name=args.benchmark,
+        output_dir=args.output_dir,
+        max_questions=args.max_questions,
+        temperature=args.temperature,
+        top_p=args.top_p,
+        max_tokens=args.max_tokens
+    )
+
+    # Run the benchmark
+    runner = BenchmarkRunner(config)
+    summary = runner.run_benchmark(llm_provider, benchmark)
+
+    print("\n" + "=" * 50)
+    print("BENCHMARK COMPLETED")
+    print("=" * 50)
+
+    # Check whether the benchmark run was successful
+    if "error" in summary:
+        print(f"Error: {summary['error']}")
+        return
+
+    # Print results
+    print(f"Model: {args.model}")
+    print(f"Benchmark: {args.benchmark}")
+    print(f"Total Questions: {summary['results']['total_questions']}")
+    print(f"Correct Answers: {summary['results']['correct_answers']}")
+    print(f"Overall Accuracy: {summary['results']['accuracy']:.2f}%")
+    print(f"Total Duration: {summary['results']['total_duration']:.2f}s")
+    print(f"Results saved to: {summary['results_file']}")
+
+
+def main():
+    """Main CLI entry point."""
+    parser = argparse.ArgumentParser(description="MedRAX Benchmarking Pipeline")
+    subparsers = parser.add_subparsers(dest="command", help="Available commands")
+
+    # Run benchmark command
+    run_parser = subparsers.add_parser("run", help="Run a benchmark")
+    run_parser.add_argument("--model", required=True, help="Model name (e.g., gpt-4o, gemini-2.5-pro)")
+    run_parser.add_argument("--provider", required=True, choices=["openai", "google", "openrouter", "medrax"], help="LLM provider")
+    run_parser.add_argument("--benchmark", required=True, choices=["rexvqa", "chestagentbench"], help="Benchmark to run")
+    run_parser.add_argument("--data-dir", required=True, help="Directory containing benchmark data")
+    run_parser.add_argument("--output-dir", default="benchmark_results", help="Output directory for results")
+    run_parser.add_argument("--max-questions", type=int, help="Maximum number of questions to process")
+    run_parser.add_argument("--temperature", type=float, default=0.7, help="Model temperature")
+    run_parser.add_argument("--top-p", type=float, default=0.95, help="Top-p value")
+    run_parser.add_argument("--max-tokens", type=int, default=1000, help="Maximum tokens per response")
+
+    run_parser.set_defaults(func=run_benchmark_command)
+
+    args = parser.parse_args()
+
+    if args.command is None:
+        parser.print_help()
+        return
+
+    try:
+        args.func(args)
+    except Exception as e:
+        print(f"Error: {e}")
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    main()
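Assuming the package is importable from the repository root, a typical invocation looks like this (paths and model name are illustrative):

    python -m benchmarking.cli run \
        --model gpt-4o \
        --provider openai \
        --benchmark rexvqa \
        --data-dir benchmarking/data/rexvqa \
        --output-dir benchmark_results \
        --max-questions 50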
benchmarking/data/rexvqa/download_rexgradient_images.py ADDED
@@ -0,0 +1,174 @@
+#!/usr/bin/env python3
+"""
+Utility script to download and extract ReXGradient-160K images.
+
+This script helps users download the actual PNG images from the ReXGradient-160K dataset,
+which are stored as part files on HuggingFace and need to be concatenated and extracted.
+
+Usage:
+    python download_rexgradient_images.py --output_dir /path/to/images
+"""
+
+import argparse
+import shutil
+import subprocess
+from pathlib import Path
+from huggingface_hub import hf_hub_download, list_repo_files
+import requests
+from tqdm import tqdm
+
+
+def download_file(url, output_path, chunk_size=8192):
+    """Download a file with a progress bar."""
+    response = requests.get(url, stream=True)
+    total_size = int(response.headers.get('content-length', 0))
+
+    with open(output_path, 'wb') as f:
+        with tqdm(total=total_size, unit='B', unit_scale=True, desc=output_path.name) as pbar:
+            for chunk in response.iter_content(chunk_size=chunk_size):
+                if chunk:
+                    f.write(chunk)
+                    pbar.update(len(chunk))
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Download ReXGradient-160K images")
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        required=True,
+        help="Directory to save extracted images"
+    )
+    parser.add_argument(
+        "--repo_id",
+        type=str,
+        default="rajpurkarlab/ReXGradient-160K",
+        help="HuggingFace repository ID"
+    )
+    parser.add_argument(
+        "--skip_download",
+        action="store_true",
+        help="Skip downloading and only extract if files exist"
+    )
+
+    args = parser.parse_args()
+
+    output_dir = Path(args.output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    print(f"Output directory: {output_dir}")
+
+    # Check whether the license needs to be accepted first
+    print("Note: You may need to accept the dataset license on HuggingFace first:")
+    print(f"Visit: https://huggingface.co/datasets/{args.repo_id}")
+    print("Click 'Access repository' and accept the license agreement.")
+    print()
+
+    try:
+        # List files in the repository
+        print("Listing files in repository...")
+        files = list_repo_files(args.repo_id, repo_type='dataset')
+        part_files = [f for f in files if f.startswith("deid_png.part")]
+
+        if not part_files:
+            print("No part files found. The images might be in a different format.")
+            print("Available files:")
+            for f in files:
+                print(f"  - {f}")
+            return
+
+        print(f"Found {len(part_files)} part files:")
+        for f in part_files:
+            print(f"  - {f}")
+
+        # Download part files
+        if not args.skip_download:
+            print("\nDownloading part files...")
+            for part_file in part_files:
+                output_path = output_dir / part_file
+                if output_path.exists():
+                    print(f"Skipping {part_file} (already exists)")
+                    continue
+
+                print(f"Downloading {part_file}...")
+                try:
+                    hf_hub_download(
+                        repo_id=args.repo_id,
+                        filename=part_file,
+                        local_dir=output_dir,
+                        local_dir_use_symlinks=False,
+                        repo_type='dataset'
+                    )
+                except Exception as e:
+                    print(f"Error downloading {part_file}: {e}")
+                    print("You may need to accept the license agreement on HuggingFace.")
+                    return
+
+        # Concatenate part files
+        tar_path = output_dir / "deid_png.tar"
+        if not tar_path.exists():
+            print("\nConcatenating part files...")
+            with open(tar_path, 'wb') as tar_file:
+                for part_file in sorted(part_files):
+                    part_path = output_dir / part_file
+                    if part_path.exists():
+                        print(f"Adding {part_file}...")
+                        with open(part_path, 'rb') as f:
+                            # Stream in chunks instead of reading the whole part into memory
+                            shutil.copyfileobj(f, tar_file)
+                    else:
+                        print(f"Warning: {part_file} not found, skipping...")
+        else:
+            print(f"Tar file already exists: {tar_path}")
+
+    # Extract the tar file
+        if tar_path.exists():
+            print("\nExtracting images...")
+            images_dir = output_dir / "images"
+            images_dir.mkdir(exist_ok=True)
+
+            # Check if already extracted (images may sit in nested per-study directories)
+            if any(images_dir.rglob("*.png")):
+                print("Images already extracted.")
+            else:
+                try:
+                    subprocess.run([
+                        "tar", "-xf", str(tar_path),
+                        "-C", str(images_dir)
+                    ], check=True)
+                    print("Extraction completed!")
+                except subprocess.CalledProcessError as e:
+                    print(f"Error extracting tar file: {e}")
+                    return
+                except FileNotFoundError:
+                    print("Error: 'tar' command not found. Please install tar or extract manually.")
+                    return
+
+            # Count extracted images (search recursively)
+            png_files = list(images_dir.rglob("*.png"))
+            print(f"Extracted {len(png_files)} PNG images to {images_dir}")
+
+            # Show some example filenames
+            if png_files:
+                print("\nExample image filenames:")
+                for f in png_files[:5]:
+                    print(f"  - {f.name}")
+                if len(png_files) > 5:
+                    print(f"  ... and {len(png_files) - 5} more")
+
+            print("\nSetup complete! Use this directory as images_dir in ReXVQABenchmark:")
+            print(f"images_dir='{images_dir}'")
+
+    except Exception as e:
+        print(f"Error: {e}")
+        print("\nManual setup instructions:")
+        print("1. Visit https://huggingface.co/datasets/rajpurkarlab/ReXGradient-160K")
+        print("2. Accept the license agreement")
+        print("3. Download the deid_png.part* files")
+        print("4. Concatenate: cat deid_png.part* > deid_png.tar")
+        print("5. Extract: tar -xf deid_png.tar")
+        print("6. Use the extracted directory as images_dir")
+
+
+if __name__ == "__main__":
+    main()
benchmarking/llm_providers/__init__.py ADDED
@@ -0,0 +1,17 @@
+"""LLM provider abstractions for benchmarking."""
+
+from .base import LLMProvider, LLMRequest, LLMResponse
+from .openai_provider import OpenAIProvider
+from .google_provider import GoogleProvider
+from .medrax_provider import MedRAXProvider
+from .openrouter_provider import OpenRouterProvider
+
+__all__ = [
+    "LLMProvider",
+    "LLMRequest",
+    "LLMResponse",
+    "OpenAIProvider",
+    "GoogleProvider",
+    "MedRAXProvider",
+    "OpenRouterProvider",
+]
benchmarking/llm_providers/base.py ADDED
@@ -0,0 +1,127 @@
+"""Base class for LLM providers."""
+
+from abc import ABC, abstractmethod
+from typing import Dict, List, Optional, Any
+from dataclasses import dataclass
+import base64
+from pathlib import Path
+from medrax.utils.utils import load_prompts_from_file
+
+
+@dataclass
+class LLMRequest:
+    """Request to an LLM provider."""
+    text: str
+    images: Optional[List[str]] = None  # List of image paths
+    temperature: float = 0.7
+    top_p: float = 0.95
+    max_tokens: int = 5000
+    additional_params: Optional[Dict[str, Any]] = None
+
+
+@dataclass
+class LLMResponse:
+    """Response from an LLM provider."""
+    content: str
+    usage: Optional[Dict[str, Any]] = None
+    duration: Optional[float] = None
+    raw_response: Optional[Any] = None
+
+
+class LLMProvider(ABC):
+    """Abstract base class for LLM providers.
+
+    This class defines the interface for all LLM providers, standardizing
+    text + image input -> text output across different models and APIs.
+    """
+
+    def __init__(self, model_name: str, **kwargs):
+        """Initialize the LLM provider.
+
+        Args:
+            model_name (str): Name of the model to use
+            **kwargs: Additional configuration parameters
+        """
+        self.model_name = model_name
+        self.config = kwargs
+
+        # Always load the system prompt from file
+        try:
+            prompts = load_prompts_from_file("medrax/docs/system_prompts.txt")
+            self.system_prompt = prompts.get("CHESTAGENTBENCH_PROMPT", None)
+            if self.system_prompt is None:
+                print("Warning: System prompt not found in medrax/docs/system_prompts.txt.")
+        except Exception as e:
+            print(f"Error loading system prompt: {e}")
+            self.system_prompt = None
+
+        self._setup()
+
+    @abstractmethod
+    def _setup(self) -> None:
+        """Set up the provider (API keys, client initialization, etc.)."""
+        pass
+
+    @abstractmethod
+    def generate_response(self, request: LLMRequest) -> LLMResponse:
+        """Generate a response from the LLM.
+
+        Args:
+            request (LLMRequest): The request containing text, images, and parameters
+
+        Returns:
+            LLMResponse: The response from the LLM
+        """
+        pass
+
+    def test_connection(self) -> bool:
+        """Test the connection to the LLM provider.
+
+        Returns:
+            bool: True if the connection is successful, False otherwise
+        """
+        try:
+            # Simple test request
+            test_request = LLMRequest(
+                text="Hello",
+                temperature=0.5,
+                max_tokens=1000
+            )
+            response = self.generate_response(test_request)
+            return response.content is not None and len(response.content.strip()) > 0
+        except Exception as e:
+            print(f"Connection test failed: {e}")
+            return False
+
+    def _encode_image(self, image_path: str) -> str:
+        """Encode an image as a base64 string.
+
+        Args:
+            image_path (str): Path to the image file
+
+        Returns:
+            str: Base64-encoded image string
+        """
+        with open(image_path, "rb") as image_file:
+            return base64.b64encode(image_file.read()).decode('utf-8')
+
+    def _validate_image_paths(self, image_paths: List[str]) -> List[str]:
+        """Validate that image paths exist and are readable.
+
+        Args:
+            image_paths (List[str]): List of image paths to validate
+
+        Returns:
+            List[str]: List of valid image paths
+        """
+        valid_paths = []
+        for path in image_paths:
+            if Path(path).exists() and Path(path).is_file():
+                valid_paths.append(path)
+            else:
+                print(f"Warning: Image path does not exist: {path}")
+        return valid_paths
+
+    def __str__(self) -> str:
+        """String representation of the provider."""
+        return f"{self.__class__.__name__}(model={self.model_name})"
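As a reference point for the abstract interface, here is a minimal sketch of a stub provider (hypothetical, not part of this commit; useful for dry-running the pipeline without API keys):

    import time
    from benchmarking.llm_providers.base import LLMProvider, LLMRequest, LLMResponse

    class EchoProvider(LLMProvider):
        """Stub provider that echoes the prompt; no client setup required."""

        def _setup(self) -> None:
            pass  # nothing to initialize

        def generate_response(self, request: LLMRequest) -> LLMResponse:
            start = time.time()
            return LLMResponse(
                content=f"echo: {request.text[:80]}",
                duration=time.time() - start,
            )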
benchmarking/llm_providers/google_provider.py ADDED
@@ -0,0 +1,104 @@
+"""Google LLM provider implementation using langchain_google_genai."""
+
+import os
+import time
+from tenacity import retry, wait_exponential, stop_after_attempt
+from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_core.messages import HumanMessage, SystemMessage
+
+from .base import LLMProvider, LLMRequest, LLMResponse
+
+
+class GoogleProvider(LLMProvider):
+    """Google LLM provider for Gemini models using langchain_google_genai."""
+
+    def _setup(self) -> None:
+        """Set up the Google langchain client."""
+        api_key = os.getenv("GOOGLE_API_KEY")
+        if not api_key:
+            raise ValueError("GOOGLE_API_KEY environment variable is required")
+
+        # Create a ChatGoogleGenerativeAI instance
+        self.client = ChatGoogleGenerativeAI(
+            model=self.model_name,
+            google_api_key=api_key
+        )
+
+    @retry(wait=wait_exponential(multiplier=1, min=4, max=10), stop=stop_after_attempt(3))
+    def generate_response(self, request: LLMRequest) -> LLMResponse:
+        """Generate a response using langchain Google Gemini.
+
+        Args:
+            request (LLMRequest): The request containing text, images, and parameters
+
+        Returns:
+            LLMResponse: The response from Google Gemini
+        """
+        start_time = time.time()
+
+        # Build messages
+        messages = []
+
+        # Add the system prompt if provided
+        if self.system_prompt:
+            messages.append(SystemMessage(content=self.system_prompt))
+
+        # Construct content for multimodal input
+        if request.images:
+            # For multimodal content, use a list of typed content parts
+            content_parts = [{"type": "text", "text": request.text}]
+
+            # Add images if provided
+            valid_images = self._validate_image_paths(request.images)
+            for image_path in valid_images:
+                try:
+                    # For langchain Google, pass image data as base64
+                    image_b64 = self._encode_image(image_path)
+                    content_parts.append({
+                        "type": "image_url",
+                        "image_url": f"data:image/jpeg;base64,{image_b64}"
+                    })
+                except Exception as e:
+                    print(f"Error reading image {image_path}: {e}")
+
+            messages.append(HumanMessage(content=content_parts))
+        else:
+            # Text-only message
+            messages.append(HumanMessage(content=request.text))
+
+        # Make the API call using langchain
+        try:
+            # Update client parameters for this request
+            self.client.temperature = request.temperature
+            self.client.max_output_tokens = request.max_tokens
+            self.client.top_p = request.top_p
+
+            response = self.client.invoke(messages)
+
+            duration = time.time() - start_time
+
+            # Extract the response content
+            content = response.content if response.content else ""
+
+            # Get usage information if available
+            usage = {}
+            if hasattr(response, 'usage_metadata') and response.usage_metadata:
+                usage = {
+                    "prompt_tokens": response.usage_metadata.get("input_tokens", 0),
+                    "completion_tokens": response.usage_metadata.get("output_tokens", 0),
+                    "total_tokens": response.usage_metadata.get("total_tokens", 0)
+                }
+
+            return LLMResponse(
+                content=content,
+                usage=usage,
+                duration=duration,
+                raw_response=response
+            )
+
+        except Exception as e:
+            return LLMResponse(
+                content=f"Error: {str(e)}",
+                duration=time.time() - start_time,
+                raw_response=None
+            )
benchmarking/llm_providers/medrax_provider.py ADDED
@@ -0,0 +1,187 @@
+"""MedRAX LLM provider implementation."""
+
+import time
+import shutil
+from pathlib import Path
+
+from .base import LLMProvider, LLMRequest, LLMResponse
+
+from medrax.rag.rag import RAGConfig
+from main import initialize_agent
+
+
+class MedRAXProvider(LLMProvider):
+    """MedRAX LLM provider that uses the full MedRAX agent system."""
+
+    def __init__(self, model_name: str, **kwargs):
+        """Initialize the MedRAX provider.
+
+        Args:
+            model_name (str): Base LLM model name (e.g., "gpt-4.1-2025-04-14")
+            **kwargs: Additional configuration parameters
+        """
+        self.model_name = model_name
+        self.agent = None
+        self.tools_dict = None
+
+        super().__init__(model_name, **kwargs)
+
+    def _setup(self) -> None:
+        """Set up the MedRAX agent system."""
+        try:
+            print("Starting server...")
+
+            selected_tools = [
+                # "ImageVisualizerTool",            # For displaying images in the UI
+                # "DicomProcessorTool",             # For processing DICOM medical image files
+                # "TorchXRayVisionClassifierTool",  # For classifying chest X-rays using TorchXRayVision
+                # "ArcPlusClassifierTool",          # For advanced chest X-ray classification using ArcPlus
+                # "ChestXRaySegmentationTool",      # For segmenting anatomical regions in chest X-rays
+                # "ChestXRayReportGeneratorTool",   # For generating medical reports from X-rays
+                # "XRayVQATool",                    # For visual question answering on X-rays
+                # "LlavaMedTool",                   # For multimodal medical image understanding
+                # "XRayPhraseGroundingTool",        # For locating described features in X-rays
+                # "ChestXRayGeneratorTool",         # For generating synthetic chest X-rays
+                "WebBrowserTool",                   # For web browsing and search capabilities
+                "MedicalRAGTool",                   # For retrieval-augmented generation with medical knowledge
+                # "PythonSandboxTool",              # Python sandbox tool
+            ]
+
+            rag_config = RAGConfig(
+                model="command-a-03-2025",          # Chat model for generating responses
+                embedding_model="embed-v4.0",       # Embedding model for the RAG system
+                rerank_model="rerank-v3.5",         # Reranking model for the RAG system
+                temperature=0.3,
+                pinecone_index_name="medrax2",      # Name for the Pinecone index
+                chunk_size=1500,
+                chunk_overlap=300,
+                retriever_k=7,
+                local_docs_dir="rag_docs",          # Path of the documents for RAG
+                huggingface_datasets=["VictorLJZ/medrax2"],  # List of HuggingFace datasets to load
+                dataset_split="train",              # Which split of the datasets to use
+            )
+
+            # Prepare any additional model-specific kwargs
+            model_kwargs = {}
+
+            agent, tools_dict = initialize_agent(
+                prompt_file="medrax/docs/system_prompts.txt",
+                tools_to_use=selected_tools,
+                model_dir="/model-weights",
+                temp_dir="temp",                    # Path of the temporary directory
+                device="cpu",
+                model=self.model_name,              # e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
+                temperature=0.7,
+                top_p=0.95,
+                model_kwargs=model_kwargs,
+                rag_config=rag_config,
+                debug=True,
+            )
+
+            self.agent = agent
+            self.tools_dict = tools_dict
+
+            print(f"MedRAX agent initialized with tools: {list(self.tools_dict.keys())}")
+
+        except Exception as e:
+            print(f"Error initializing MedRAX agent: {e}")
+            raise
+
+    def generate_response(self, request: LLMRequest) -> LLMResponse:
+        """Generate a response using the MedRAX agent.
+
+        Args:
+            request (LLMRequest): The request containing text, images, and parameters
+
+        Returns:
+            LLMResponse: The response from the MedRAX agent
+        """
+        start_time = time.time()
+
+        if self.agent is None:
+            return LLMResponse(
+                content="Error: MedRAX agent not initialized",
+                duration=time.time() - start_time,
+                raw_response=None
+            )
+
+        try:
+            # Build messages for the agent
+            messages = []
+            thread_id = str(int(time.time() * 1000))  # Unique thread ID
+
+            # Copy images to the session temp directory and provide paths
+            image_paths = []
+            if request.images:
+                valid_images = self._validate_image_paths(request.images)
+                print(f"Processing {len(valid_images)} images")
+                Path("temp").mkdir(parents=True, exist_ok=True)  # Ensure the temp directory exists
+                for i, image_path in enumerate(valid_images):
+                    print(f"Original image path: {image_path}")
+                    # Copy the image to the session temp directory
+                    dest_path = Path("temp") / f"image_{i}_{Path(image_path).name}"
+                    print(f"Destination path: {dest_path}")
+                    shutil.copy2(image_path, dest_path)
+                    image_paths.append(str(dest_path))
+
+                    # Verify the file exists after the copy
+                    if not dest_path.exists():
+                        print(f"ERROR: File not found after copy: {dest_path}")
+                    else:
+                        print(f"File successfully copied: {dest_path}")
+
+                    # Add an image path message for tools
+                    messages.append({
+                        "role": "user",
+                        "content": f"image_path: {dest_path}"
+                    })
+
+                    # Add image content for the multimodal LLM
+                    img_base64 = self._encode_image(image_path)
+
+                    messages.append({
+                        "role": "user",
+                        "content": [{
+                            "type": "image_url",
+                            "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"}
+                        }]
+                    })
+
+            # Add the text message
+            messages.append({
+                "role": "user",
+                "content": [{
+                    "type": "text",
+                    "text": request.text
+                }]
+            })
+
+            # Run the agent
+            response_content = ""
+            for chunk in self.agent.workflow.stream(
+                {"messages": messages},
+                {"configurable": {"thread_id": thread_id}},
+                stream_mode="updates"
+            ):
+                if isinstance(chunk, dict):
+                    for node_name, node_output in chunk.items():
+                        if "messages" in node_output:
+                            for msg in node_output["messages"]:
+                                if hasattr(msg, 'content') and msg.content:
+                                    response_content += str(msg.content)
+
+            duration = time.time() - start_time
+
+            return LLMResponse(
+                content=response_content.strip(),
+                usage={"agent_tools": list(self.tools_dict.keys())},
+                duration=duration,
+                raw_response={"thread_id": thread_id, "image_paths": image_paths}
+            )
+
+        except Exception as e:
+            return LLMResponse(
+                content=f"Error: {str(e)}",
+                duration=time.time() - start_time,
+                raw_response=None
+            )
benchmarking/llm_providers/openai_provider.py ADDED
@@ -0,0 +1,113 @@
+"""OpenAI LLM provider implementation using langchain_openai."""
+
+import os
+import time
+from tenacity import retry, wait_exponential, stop_after_attempt
+from langchain_openai import ChatOpenAI
+from langchain_core.messages import HumanMessage, SystemMessage
+
+from .base import LLMProvider, LLMRequest, LLMResponse
+
+
+class OpenAIProvider(LLMProvider):
+    """OpenAI LLM provider for GPT models using langchain_openai."""
+
+    def _setup(self) -> None:
+        """Set up the OpenAI langchain client."""
+        api_key = os.getenv("OPENAI_API_KEY")
+        if not api_key:
+            raise ValueError("OPENAI_API_KEY environment variable is required")
+
+        base_url = os.getenv("OPENAI_BASE_URL")
+
+        # Create a ChatOpenAI instance
+        kwargs = {
+            "model": self.model_name,
+            "api_key": api_key,
+        }
+
+        if base_url:
+            kwargs["base_url"] = base_url
+
+        self.client = ChatOpenAI(**kwargs)
+
+    @retry(wait=wait_exponential(multiplier=1, min=4, max=10), stop=stop_after_attempt(3))
+    def generate_response(self, request: LLMRequest) -> LLMResponse:
+        """Generate a response using langchain OpenAI.
+
+        Args:
+            request (LLMRequest): The request containing text, images, and parameters
+
+        Returns:
+            LLMResponse: The response from OpenAI
+        """
+        start_time = time.time()
+
+        # Build messages
+        messages = []
+
+        # Add the system prompt if provided
+        if self.system_prompt:
+            messages.append(SystemMessage(content=self.system_prompt))
+
+        # Build the user message content
+        user_content = []
+        user_content.append({
+            "type": "text",
+            "text": request.text
+        })
+
+        # Add images if provided
+        if request.images:
+            valid_images = self._validate_image_paths(request.images)
+            for image_path in valid_images:
+                try:
+                    image_b64 = self._encode_image(image_path)
+                    user_content.append({
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/jpeg;base64,{image_b64}",
+                            "detail": "high"
+                        }
+                    })
+                except Exception as e:
+                    print(f"Error reading image {image_path}: {e}")
+
+        messages.append(HumanMessage(content=user_content))
+
+        # Make the API call using langchain
+        try:
+            # Update client parameters for this request
+            self.client.temperature = request.temperature
+            self.client.max_tokens = request.max_tokens
+            self.client.top_p = request.top_p
+
+            response = self.client.invoke(messages)
+
+            duration = time.time() - start_time
+
+            # Extract the response content
+            content = response.content if response.content else ""
+
+            # Get usage information if available
+            usage = {}
+            if hasattr(response, 'usage_metadata') and response.usage_metadata:
+                usage = {
+                    "prompt_tokens": response.usage_metadata.get("input_tokens", 0),
+                    "completion_tokens": response.usage_metadata.get("output_tokens", 0),
+                    "total_tokens": response.usage_metadata.get("total_tokens", 0)
+                }
+
+            return LLMResponse(
+                content=content,
+                usage=usage,
+                duration=duration,
+                raw_response=response
+            )
+
+        except Exception as e:
+            return LLMResponse(
+                content=f"Error: {str(e)}",
+                duration=time.time() - start_time,
+                raw_response=None
+            )
benchmarking/llm_providers/openrouter_provider.py ADDED
@@ -0,0 +1,89 @@
+"""OpenRouter LLM provider implementation using the OpenAI SDK."""
+
+import os
+import time
+from tenacity import retry, wait_exponential, stop_after_attempt
+from openai import OpenAI
+
+from .base import LLMProvider, LLMRequest, LLMResponse
+
+
+class OpenRouterProvider(LLMProvider):
+    """LLM provider using the OpenRouter API via the OpenAI SDK."""
+
+    def _setup(self) -> None:
+        """Set up the OpenRouter client."""
+        api_key = os.getenv("OPENROUTER_API_KEY")
+        if not api_key:
+            raise ValueError("OPENROUTER_API_KEY environment variable is required for OpenRouter.")
+        base_url = os.getenv("OPENROUTER_BASE_URL", "https://openrouter.ai/api/v1")
+        # Use the OpenAI SDK with the OpenRouter endpoint
+        self.client = OpenAI(api_key=api_key, base_url=base_url)
+
+    @retry(wait=wait_exponential(multiplier=1, min=4, max=10), stop=stop_after_attempt(3))
+    def generate_response(self, request: LLMRequest) -> LLMResponse:
+        """Generate a response using an OpenRouter model via the OpenAI SDK.
+
+        Args:
+            request (LLMRequest): The request containing text, images, and parameters
+
+        Returns:
+            LLMResponse: The response from OpenRouter
+        """
+        start_time = time.time()
+
+        # Build messages
+        messages = []
+        if self.system_prompt:
+            messages.append({"role": "system", "content": self.system_prompt})
+
+        user_content = []
+        user_content.append({"type": "text", "text": request.text})
+
+        # Add images if provided
+        if request.images:
+            valid_images = self._validate_image_paths(request.images)
+            for image_path in valid_images:
+                try:
+                    image_b64 = self._encode_image(image_path)
+                    user_content.append({
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/jpeg;base64,{image_b64}",
+                            "detail": "high"
+                        }
+                    })
+                except Exception as e:
+                    print(f"Error reading image {image_path}: {e}")
+
+        messages.append({"role": "user", "content": user_content})
+
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=messages,
+                temperature=request.temperature,
+                top_p=request.top_p,
+                max_tokens=request.max_tokens,
+                **(request.additional_params or {})
+            )
+            duration = time.time() - start_time
+            content = response.choices[0].message.content if response.choices else ""
+            usage = {}
+            if hasattr(response, 'usage') and response.usage:
+                usage = {
+                    "prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
+                    "completion_tokens": getattr(response.usage, "completion_tokens", 0),
+                    "total_tokens": getattr(response.usage, "total_tokens", 0)
+                }
+            return LLMResponse(
+                content=content,
+                usage=usage,
+                duration=duration,
+                raw_response=response
+            )
+        except Exception as e:
+            return LLMResponse(
+                content=f"Error: {str(e)}",
+                duration=time.time() - start_time,
+                raw_response=None
+            )
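The only configuration this provider reads is environment variables, so a run can be set up entirely from the shell (the base URL override is optional):

    export OPENROUTER_API_KEY=...                              # required
    export OPENROUTER_BASE_URL=https://openrouter.ai/api/v1   # optional, this is the default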
benchmarking/runner.py ADDED
@@ -0,0 +1,377 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Main test runner for benchmarking pipeline."""
2
+
3
+ import json
4
+ import time
5
+ import logging
6
+ from datetime import datetime
7
+ from pathlib import Path
8
+ from typing import Dict, Optional, Any
9
+ from dataclasses import dataclass
10
+ from tqdm import tqdm
11
+ import re
12
+ from .llm_providers import LLMProvider, LLMRequest, LLMResponse
13
+ from .benchmarks import Benchmark, BenchmarkDataPoint
14
+
15
+
16
+ @dataclass
17
+ class BenchmarkResult:
18
+ """Result of running a benchmark on a single data point."""
19
+ data_point_id: str
20
+ question: str
21
+ model_answer: str
22
+ correct_answer: str
23
+ is_correct: bool
24
+ duration: float
25
+ usage: Optional[Dict[str, Any]] = None
26
+ error: Optional[str] = None
27
+ metadata: Optional[Dict[str, Any]] = None
28
+
29
+
30
+ @dataclass
31
+ class BenchmarkRunConfig:
32
+ """Configuration for a benchmark run."""
33
+ provider_name: str
34
+ model_name: str
35
+ benchmark_name: str
36
+ output_dir: str
37
+ max_questions: Optional[int] = None
38
+ temperature: float = 0.7
39
+ top_p: float = 0.95
40
+ max_tokens: int = 5000
41
+ additional_params: Optional[Dict[str, Any]] = None
42
+
43
+
44
+ class BenchmarkRunner:
45
+ """Main class for running benchmarks against LLM providers."""
46
+
47
+ def __init__(self, config: BenchmarkRunConfig):
48
+ """Initialize the benchmark runner.
49
+
50
+ Args:
51
+ config (BenchmarkRunConfig): Configuration for the benchmark run
52
+ """
53
+ self.config = config
54
+ self.results = []
55
+ self.output_dir = Path(config.output_dir)
56
+ self.output_dir.mkdir(parents=True, exist_ok=True)
57
+
58
+ # Generate unique run ID
59
+ self.run_id = f"{config.benchmark_name}_{config.provider_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
60
+
61
+ # Set up logging
62
+ self._setup_logging()
63
+
64
+ self.logger.info(f"Initialized benchmark runner with ID: {self.run_id}")
65
+
66
+ def _setup_logging(self) -> None:
67
+ """Set up logging configuration."""
68
+ log_file = self.output_dir / f"benchmark_run_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
69
+
70
+ # Create logger
71
+ self.logger = logging.getLogger(f"benchmark_runner_{self.run_id}")
72
+ self.logger.setLevel(logging.INFO)
73
+
74
+ # Create handlers
75
+ file_handler = logging.FileHandler(log_file)
76
+ console_handler = logging.StreamHandler()
77
+
78
+ # Create formatter
79
+ formatter = logging.Formatter(
80
+ '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
81
+ )
82
+ file_handler.setFormatter(formatter)
83
+ console_handler.setFormatter(formatter)
84
+
85
+ # Add handlers to logger
86
+ self.logger.addHandler(file_handler)
87
+ self.logger.addHandler(console_handler)
88
+
89
+    def run_benchmark(
+        self,
+        llm_provider: LLMProvider,
+        benchmark: Benchmark,
+    ) -> Dict[str, Any]:
+        """Run a benchmark against an LLM provider.
+
+        Args:
+            llm_provider (LLMProvider): The LLM provider to test
+            benchmark (Benchmark): The benchmark to run
+
+        Returns:
+            Dict[str, Any]: Summary of benchmark results
+        """
+        self.logger.info(f"Starting benchmark run: {self.run_id}")
+        self.logger.info(f"Model: {llm_provider.model_name}")
+        self.logger.info(f"Benchmark: {type(benchmark).__name__}")
+
+        # Test provider connection
+        if not llm_provider.test_connection():
+            self.logger.error("LLM provider connection test failed")
+            return {"error": "LLM provider connection test failed"}
+
+        # Determine how many data points to process
+        total_questions = len(benchmark)
+        max_questions = self.config.max_questions or total_questions
+        end_index = min(max_questions, total_questions)
+
+        self.logger.info(f"Processing questions 0 to {end_index - 1} of {total_questions}")
+
+        # Initialize counters
+        processed = 0
+        correct = 0
+        total_duration = 0.0
+
+        # Process each data point
+        for i in tqdm(range(end_index), desc="Processing questions"):
+            try:
+                data_point = benchmark.get_data_point(i)
+
+                # Run the model on this data point
+                result = self._process_data_point(llm_provider, data_point)
+
+                # Update counters
+                processed += 1
+                if result.is_correct:
+                    correct += 1
+                total_duration += result.duration
+
+                # Add to results
+                self.results.append(result)
+
+                # Periodically checkpoint and log progress
+                if processed % 10 == 0:
+                    self._save_intermediate_results()
+                    accuracy = (correct / processed) * 100
+                    avg_duration = total_duration / processed
+
+                    self.logger.info(
+                        f"Progress: {processed}/{end_index} | "
+                        f"Accuracy: {accuracy:.2f}% | "
+                        f"Avg Duration: {avg_duration:.2f}s"
+                    )
+
+            except Exception as e:
+                self.logger.error(f"Error processing data point {i}: {e}")
+                # Record an error result so the failure is visible in the output
+                error_result = BenchmarkResult(
+                    data_point_id=f"error_{i}",
+                    question="",
+                    model_answer="",
+                    correct_answer="",
+                    is_correct=False,
+                    duration=0.0,
+                    error=str(e)
+                )
+                self.results.append(error_result)
+                continue
+
+        # Save final results
+        summary = self._save_final_results(benchmark)
+
+        self.logger.info(f"Benchmark run completed: {self.run_id}")
+        self.logger.info(f"Final accuracy: {summary['results']['accuracy']:.2f}%")
+        self.logger.info(f"Total duration: {summary['results']['total_duration']:.2f}s")
+
+        return summary
+
+    def _process_data_point(
+        self,
+        llm_provider: LLMProvider,
+        data_point: BenchmarkDataPoint,
+    ) -> BenchmarkResult:
+        """Process a single data point.
+
+        Args:
+            llm_provider (LLMProvider): The LLM provider to use
+            data_point (BenchmarkDataPoint): The data point to process
+
+        Returns:
+            BenchmarkResult: Result of processing the data point
+        """
+        start_time = time.time()
+
+        try:
+            # Create request
+            request = LLMRequest(
+                text=data_point.text,
+                images=data_point.images,
+                temperature=self.config.temperature,
+                top_p=self.config.top_p,
+                max_tokens=self.config.max_tokens,
+                additional_params=self.config.additional_params
+            )
+
+            # Get response from LLM
+            response: LLMResponse = llm_provider.generate_response(request)
+
+            # Extract answer (this may need customization based on benchmark)
+            model_answer = self._extract_answer(response.content)
+
+            # Check if correct
+            is_correct = self._is_correct_answer(model_answer, data_point.correct_answer)
+
+            duration = time.time() - start_time
+
+            return BenchmarkResult(
+                data_point_id=data_point.id,
+                question=data_point.text,
+                model_answer=model_answer,
+                correct_answer=data_point.correct_answer,
+                is_correct=is_correct,
+                duration=duration,
+                usage=response.usage,
+                metadata={
+                    "data_point_metadata": data_point.metadata,
+                    "case_id": data_point.case_id,
+                    "category": data_point.category,
+                    "raw_response": response.content,
+                }
+            )
+
+        except Exception as e:
+            duration = time.time() - start_time
+            return BenchmarkResult(
+                data_point_id=data_point.id,
+                question=data_point.text,
+                model_answer="",
+                correct_answer=data_point.correct_answer,
+                is_correct=False,
+                duration=duration,
+                error=str(e),
+                metadata={
+                    "data_point_metadata": data_point.metadata,
+                    "case_id": data_point.case_id,
+                    "category": data_point.category,
+                }
+            )
+
+    def _extract_answer(self, response_text: str) -> str:
+        """Extract the answer from the model response.
+
+        Args:
+            response_text (str): The full response text from the model
+
+        Returns:
+            str: The extracted answer
+        """
+        # Look for the '<|A|>' final-answer format; take the last occurrence,
+        # since the prompt asks for the marker at the end of the response
+        matches = re.findall(r'<\|([A-F])\|>', response_text)
+        if matches:
+            return matches[-1].upper()
+
+        # If no pattern matches, return the full response
+        return response_text.strip()
+
+    def _is_correct_answer(self, model_answer: str, correct_answer: str) -> bool:
+        """Check if the model answer is correct.
+
+        Args:
+            model_answer (str): The model's answer
+            correct_answer (str): The correct answer
+
+        Returns:
+            bool: True if the answer is correct
+        """
+        if not model_answer or not correct_answer:
+            return False
+
+        # For multiple choice, compare just the letter
+        model_clean = model_answer.strip().upper()
+        correct_clean = correct_answer.strip().upper()
+
+        # Extract just the first letter for comparison
+        model_letter = model_clean[0] if model_clean else ""
+        correct_letter = correct_clean[0] if correct_clean else ""
+
+        return model_letter == correct_letter
+
+    def _save_intermediate_results(self, results_file: Optional[Path] = None) -> None:
+        """Save results to disk.
+
+        Args:
+            results_file (Optional[Path]): Target file; defaults to this
+                run's intermediate results file.
+        """
+        if results_file is None:
+            results_file = self.output_dir / f"{self.run_id}_intermediate.json"
+
+        # Convert results to serializable format
+        results_data = []
+        for result in self.results:
+            results_data.append({
+                "data_point_id": result.data_point_id,
+                "question": result.question,
+                "model_answer": result.model_answer,
+                "correct_answer": result.correct_answer,
+                "is_correct": result.is_correct,
+                "duration": result.duration,
+                "usage": result.usage,
+                "error": result.error,
+                "metadata": result.metadata,
+            })
+
+        with open(results_file, 'w') as f:
+            json.dump(results_data, f, indent=2)
+
+    def _save_final_results(self, benchmark: Benchmark) -> Dict[str, Any]:
+        """Save final results and return summary.
+
+        Args:
+            benchmark (Benchmark): The benchmark that was run
+
+        Returns:
+            Dict[str, Any]: Summary of results
+        """
+        # Save detailed results to the final results file
+        results_file = self.output_dir / f"{self.run_id}_results.json"
+        self._save_intermediate_results(results_file)
+
+        # Calculate summary statistics
+        total_questions = len(self.results)
+        correct_answers = sum(1 for r in self.results if r.is_correct)
+        total_duration = sum(r.duration for r in self.results)
+
+        accuracy = (correct_answers / total_questions) * 100 if total_questions > 0 else 0
+
+        # Calculate per-category accuracy
+        category_stats = {}
+        for result in self.results:
+            if result.metadata and result.metadata.get("category"):
+                category = result.metadata["category"]
+                if category not in category_stats:
+                    category_stats[category] = {"correct": 0, "total": 0}
+                category_stats[category]["total"] += 1
+                if result.is_correct:
+                    category_stats[category]["correct"] += 1
+
+        # Calculate accuracy for each category
+        category_accuracies = {}
+        for category, stats in category_stats.items():
+            category_accuracies[category] = (stats["correct"] / stats["total"]) * 100
+
+        # Create summary
+        summary = {
+            "run_id": self.run_id,
+            "timestamp": datetime.now().isoformat(),
+            "config": {
+                "model_name": self.config.model_name,
+                "benchmark_name": self.config.benchmark_name,
+                "temperature": self.config.temperature,
+                "max_tokens": self.config.max_tokens,
+            },
+            "benchmark_info": {
+                "total_size": len(benchmark),
+                "processed_questions": total_questions,
+            },
+            "results": {
+                "accuracy": accuracy,
+                "correct_answers": correct_answers,
+                "total_questions": total_questions,
+                "total_duration": total_duration,
+                "avg_duration_per_question": total_duration / total_questions if total_questions > 0 else 0,
+                "category_accuracies": category_accuracies,
+            },
+            "results_file": str(results_file),
+        }
+
+        # Save summary
+        summary_file = self.output_dir / f"{self.run_id}_summary.json"
+        with open(summary_file, 'w') as f:
+            json.dump(summary, f, indent=2)
+
+        return summary
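
A minimal usage sketch for this runner (assumptions: the class is named BenchmarkRunner, it lives at benchmarking/runner.py, and the remaining BenchmarkRunConfig fields have defaults; the provider and benchmark objects come from LLMProvider and Benchmark implementations not shown here):

from benchmarking.runner import BenchmarkRunner, BenchmarkRunConfig  # assumed module path

config = BenchmarkRunConfig(
    benchmark_name="chestagentbench",
    provider_name="openai",
    model_name="gpt-4o",
    output_dir="results",
    max_questions=100,  # None processes every question in the benchmark
    temperature=0.0,
)
runner = BenchmarkRunner(config)
# provider: any LLMProvider with test_connection()/generate_response();
# benchmark: any Benchmark subclass, constructed elsewhere
summary = runner.run_benchmark(llm_provider=provider, benchmark=benchmark)
print(f"Accuracy: {summary['results']['accuracy']:.2f}%")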
main.py CHANGED
@@ -9,7 +9,6 @@ The system uses OpenAI's language models for reasoning and can be configured
 with different model weights, tools, and parameters.
 """
 
-import os
 import warnings
 from typing import Dict, List, Optional, Any
 from dotenv import load_dotenv
@@ -175,14 +174,6 @@ if __name__ == "__main__":
     # Prepare any additional model-specific kwargs
     model_kwargs = {}
 
-    # Set up API keys for the web browser tool
-    # You'll need to set these environment variables:
-    # - GOOGLE_SEARCH_API_KEY: Your Google Custom Search API key
-    # - GOOGLE_SEARCH_ENGINE_ID: Your Google Custom Search Engine ID
-    # - COHERE_API_KEY: Your Cohere API key
-    # - OPENAI_API_KEY: Your OpenAI API key
-    # - PINECONE_API_KEY: Your Pinecone API key
-
     agent, tools_dict = initialize_agent(
        prompt_file="medrax/docs/system_prompts.txt",
        tools_to_use=selected_tools,
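
With the env-var comment block dropped from main.py, the required keys now live only in the environment or a .env file picked up by load_dotenv. A quick sanity check like the following can fail fast when keys are missing (a sketch; which keys are actually required depends on the tools selected):

import os
from dotenv import load_dotenv

load_dotenv()
required = [
    "OPENAI_API_KEY",
    "GOOGLE_SEARCH_API_KEY",
    "GOOGLE_SEARCH_ENGINE_ID",
    "COHERE_API_KEY",
    "PINECONE_API_KEY",
]
missing = [key for key in required if not os.getenv(key)]
if missing:
    raise RuntimeError(f"Missing environment variables: {', '.join(missing)}")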
medrax/docs/system_prompts.txt CHANGED
@@ -1,20 +1,26 @@
 [MEDICAL_ASSISTANT]
 You are an expert medical AI assistant who can answer any medical questions and analyze medical images similar to a doctor.
 Solve using your own vision and reasoning and use tools to complement your reasoning.
-Make multiple tool calls in parallel or sequence as needed for comprehensive answers.
-Critically think about and criticize the tool outputs.
+You can make multiple tool calls in parallel or in sequence as needed for comprehensive answers.
+Think critically about and criticize the tool outputs.
 If you need to look up some information before asking a follow up question, you are allowed to do that.
 
 CITATION REQUIREMENTS:
-- When referencing information from the RAG and web search tools, ALWAYS include numbered citations [1], [2], [3], etc.
-- Use citations immediately after making claims or statements based on the above tool results
-- Be consistent with citation numbering throughout your response
-- Only cite sources that actually contain the information you're referencing
+- When referencing information from RAG and/or web search tools, ALWAYS include numbered citations [1], [2], [3], etc.
+- Use citations immediately after making claims or statements based on the above tool results.
+- Be consistent with citation numbering throughout your response.
+- Only cite sources that actually contain the information you're referencing.
 
 Examples:
 - "According to recent research [1], chest X-rays can show signs of pneumonia..."
 - "The medical literature indicates [2] that this condition typically presents with..."
 - "Based on clinical guidelines [3], the recommended treatment approach is..."
 
-[GENERAL_ASSISTANT]
-You are a helpful AI assistant. Your role is to assist users with a wide range of tasks and questions, providing accurate and useful information on various topics.
+[CHESTAGENTBENCH_PROMPT]
+You are an expert medical AI assistant who can answer any medical questions and analyze medical images similar to a doctor.
+Solve using your own vision and reasoning and use tools (if available) to complement your reasoning.
+You can make multiple tool calls in parallel or in sequence as needed for comprehensive answers.
+Think critically about and criticize the tool outputs.
+If you need to look up some information before asking a follow up question, you are allowed to do that.
+When encountering a multiple-choice question, your final response should end with "Final answer: <|A|>" from the list of possible choices A, B, C, D, E, F.
+It is extremely important that you strictly answer in the format mentioned above.
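
The "Final answer: <|A|>" convention in [CHESTAGENTBENCH_PROMPT] is what the runner's _extract_answer regex targets. A standalone sketch of the same parsing logic:

import re
from typing import Optional

def parse_final_answer(response: str) -> Optional[str]:
    """Return the letter from the last '<|X|>' marker, or None if absent."""
    matches = re.findall(r"<\|([A-F])\|>", response)
    return matches[-1] if matches else None

assert parse_final_answer("The findings suggest pneumonia. Final answer: <|C|>") == "C"
assert parse_final_answer("No marker here.") is None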
medrax/models/model_factory.py CHANGED
@@ -28,7 +28,11 @@ class ModelFactory:
             "env_key": "OPENAI_API_KEY",
             "base_url_key": "OPENAI_BASE_URL",
         },
-        "gemini": {"class": ChatGoogleGenerativeAI, "env_key": "GOOGLE_API_KEY"},
+        "gemini": {
+            "class": ChatGoogleGenerativeAI,
+            "env_key": "GOOGLE_API_KEY",
+            "base_url_key": "GOOGLE_BASE_URL",
+        },
         "openrouter": {
             "class": ChatOpenAI,  # OpenRouter uses OpenAI-compatible interface
             "env_key": "OPENROUTER_API_KEY",
@@ -36,8 +40,8 @@ class ModelFactory:
             "default_base_url": "https://openrouter.ai/api/v1",
         },
         "grok": {
-        "class": ChatXAI,
-        "env_key": "XAI_API_KEY",
+            "class": ChatXAI,
+            "env_key": "XAI_API_KEY",
         }
         # Add more providers with default configurations here
     }
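
For context, provider entries such as the new gemini one are presumably resolved along these lines (a sketch only; the dict name PROVIDERS and the lookup code are assumptions, since the factory's resolution logic is not part of this diff):

import os

entry = ModelFactory.PROVIDERS["gemini"]       # dict attribute name is an assumption
llm_class = entry["class"]                     # ChatGoogleGenerativeAI
api_key = os.environ[entry["env_key"]]         # reads GOOGLE_API_KEY
base_url = os.getenv(entry["base_url_key"])    # optional GOOGLE_BASE_URL override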
pyproject.toml CHANGED
@@ -72,6 +72,8 @@ dependencies = [
     "langchain-google-genai>=0.1.0",
     "ray>=2.9.0",
     "langchain-sandbox>=0.0.6",
+    "seaborn>=0.12.0",
+    "huggingface_hub>=0.17.0",
     "iopath>=0.1.10",
 ]
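
The two new dependencies presumably support the benchmarking pipeline: huggingface_hub for fetching benchmark assets and seaborn for plotting results. For example, category accuracies from a run's summary JSON could be visualized like this (a sketch; the summary file path is hypothetical):

import json
import seaborn as sns
import matplotlib.pyplot as plt

with open("results/run_summary.json") as f:  # hypothetical summary path
    summary = json.load(f)

acc = summary["results"]["category_accuracies"]
sns.barplot(x=list(acc.keys()), y=list(acc.values()))
plt.ylabel("Accuracy (%)")
plt.xticks(rotation=45, ha="right")
plt.tight_layout()
plt.savefig("category_accuracies.png")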