VictorLJZ committed
Commit 99f2cbc
1 Parent(s): eba86cf

changes so far

.DS_Store ADDED
Binary file (6.15 kB).
 
benchmarking/__init__.py ADDED
@@ -0,0 +1,3 @@
+"""Benchmarking pipeline for MedRAX and other medical AI models."""
+
+__version__ = "1.0.0"
benchmarking/benchmarks/__init__.py ADDED
@@ -0,0 +1,12 @@
+"""Benchmark abstractions for medical AI evaluation."""
+
+from .base import Benchmark, BenchmarkDataPoint
+from .chest_agent_bench import ChestAgentBench
+from .rexvqa_benchmark import ReXVQABenchmark
+
+__all__ = [
+    "Benchmark",
+    "BenchmarkDataPoint",
+    "ChestAgentBench",
+    "ReXVQABenchmark",
+]
benchmarking/benchmarks/base.py ADDED
@@ -0,0 +1,210 @@
+"""Base class for benchmarks."""
+
+from abc import ABC, abstractmethod
+from typing import Dict, List, Optional, Any, Iterator, Tuple
+from dataclasses import dataclass
+from pathlib import Path
+import json
+
+
+@dataclass
+class BenchmarkDataPoint:
+    """A single data point from a benchmark."""
+    id: str
+    text: str  # The question/prompt
+    images: Optional[List[str]] = None  # List of image paths
+    correct_answer: Optional[str] = None  # Ground truth answer
+    case_id: Optional[str] = None  # For grouping related questions
+    category: Optional[str] = None  # Type of question/task
+    metadata: Optional[Dict[str, Any]] = None  # Additional metadata
+
+
+class Benchmark(ABC):
+    """Abstract base class for benchmarks.
+
+    This class defines the interface for all benchmarks, standardizing
+    how data is loaded and accessed across different benchmark datasets.
+    """
+
+    def __init__(self, data_dir: str, **kwargs):
+        """Initialize the benchmark.
+
+        Args:
+            data_dir (str): Directory containing benchmark data
+            **kwargs: Additional configuration parameters
+        """
+        self.data_dir = Path(data_dir)
+        self.config = kwargs
+        self.data_points = []
+        self._load_data()
+
+    @abstractmethod
+    def _load_data(self) -> None:
+        """Load benchmark data from the data directory."""
+        pass
+
+    def get_data_point(self, index: int) -> BenchmarkDataPoint:
+        """Get a specific data point by index.
+
+        Args:
+            index (int): Index of the data point to retrieve
+
+        Returns:
+            BenchmarkDataPoint: The data point at the given index
+        """
+        if index < 0 or index >= len(self.data_points):
+            raise IndexError(f"Index {index} out of range for {len(self.data_points)} data points")
+
+        return self.data_points[index]
+
+    def get_subset(self, indices: List[int]) -> List[BenchmarkDataPoint]:
+        """Get a subset of data points by indices.
+
+        Args:
+            indices (List[int]): List of indices to retrieve
+
+        Returns:
+            List[BenchmarkDataPoint]: List of data points at the given indices
+        """
+        return [self.get_data_point(i) for i in indices]
+
+    def save_subset(self, indices: List[int], output_path: str) -> None:
+        """Save a subset of the benchmark to a file.
+
+        Args:
+            indices (List[int]): Indices of data points to save
+            output_path (str): Path to save the subset
+        """
+        subset = self.get_subset(indices)
+
+        # Convert to serializable format
+        subset_data = []
+        for dp in subset:
+            subset_data.append({
+                "id": dp.id,
+                "text": dp.text,
+                "images": dp.images,
+                "correct_answer": dp.correct_answer,
+                "metadata": dp.metadata,
+                "case_id": dp.case_id,
+                "category": dp.category,
+            })
+
+        with open(output_path, 'w') as f:
+            json.dump(subset_data, f, indent=2)
+
+    def get_by_category(self, category: str) -> List[BenchmarkDataPoint]:
+        """Get all data points of a specific category.
+
+        Args:
+            category (str): Category to filter by
+
+        Returns:
+            List[BenchmarkDataPoint]: List of data points in the category
+        """
+        return [dp for dp in self if dp.category == category]
+
+    def get_by_case_id(self, case_id: str) -> List[BenchmarkDataPoint]:
+        """Get all data points for a specific case.
+
+        Args:
+            case_id (str): Case ID to filter by
+
+        Returns:
+            List[BenchmarkDataPoint]: List of data points for the case
+        """
+        return [dp for dp in self if dp.case_id == case_id]
+
+    def __str__(self) -> str:
+        """String representation of the benchmark."""
+        return f"{self.__class__.__name__}(data_dir={self.data_dir}, size={len(self)})"
+
+    def __len__(self) -> int:
+        """Return the number of data points in the benchmark."""
+        return len(self.data_points)
+
+    def __iter__(self) -> Iterator[BenchmarkDataPoint]:
+        """Iterate over all data points in the benchmark."""
+        for i in range(len(self)):
+            yield self.get_data_point(i)
+
+    def get_categories(self) -> List[str]:
+        """Get all unique categories in the benchmark.
+
+        Returns:
+            List[str]: List of unique categories
+        """
+        categories = set()
+        for dp in self:
+            if dp.category:
+                categories.add(dp.category)
+        return sorted(list(categories))
+
+    def get_case_ids(self) -> List[str]:
+        """Get all unique case IDs in the benchmark.
+
+        Returns:
+            List[str]: List of unique case IDs
+        """
+        case_ids = set()
+        for dp in self:
+            if dp.case_id:
+                case_ids.add(dp.case_id)
+        return sorted(list(case_ids))
+
+    def get_stats(self) -> Dict[str, Any]:
+        """Get statistics about the benchmark.
+
+        Returns:
+            Dict[str, Any]: Dictionary containing benchmark statistics
+        """
+        stats = {
+            "total_questions": len(self),
+            "total_cases": len(self.get_case_ids()),
+            "categories": self.get_categories(),
+            "category_counts": {},
+            "images_per_question": [],
+            "has_images": 0,
+            "no_images": 0,
+        }
+
+        for dp in self:
+            # Category counts
+            if dp.category:
+                stats["category_counts"][dp.category] = stats["category_counts"].get(dp.category, 0) + 1
+
+            # Image statistics
+            if dp.images:
+                stats["images_per_question"].append(len(dp.images))
+                stats["has_images"] += 1
+            else:
+                stats["images_per_question"].append(0)
+                stats["no_images"] += 1
+
+        if stats["images_per_question"]:
+            stats["avg_images_per_question"] = sum(stats["images_per_question"]) / len(stats["images_per_question"])
+            stats["max_images_per_question"] = max(stats["images_per_question"])
+        else:
+            stats["avg_images_per_question"] = 0
+            stats["max_images_per_question"] = 0
+
+        return stats
+
+    def validate_images(self) -> Tuple[List[str], List[str]]:
+        """Validate that all image paths exist.
+
+        Returns:
+            Tuple[List[str], List[str]]: Tuple of (valid_image_paths, invalid_image_paths)
+        """
+        valid_images = []
+        invalid_images = []
+
+        for dp in self:
+            if dp.images:
+                for image_path in dp.images:
+                    if Path(image_path).exists():
+                        valid_images.append(image_path)
+                    else:
+                        invalid_images.append(image_path)
+
+        return valid_images, invalid_images
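For orientation, here is a minimal sketch of how a concrete benchmark could plug into this interface. The JSONLinesBenchmark class and the questions.jsonl layout are hypothetical illustrations, not part of this commit; only the Benchmark and BenchmarkDataPoint API above is taken from the code.

import json
from benchmarking.benchmarks.base import Benchmark, BenchmarkDataPoint

class JSONLinesBenchmark(Benchmark):
    """Hypothetical benchmark: one JSON object per line in <data_dir>/questions.jsonl."""

    def _load_data(self) -> None:
        with open(self.data_dir / "questions.jsonl") as f:
            for line in f:
                record = json.loads(line)
                self.data_points.append(
                    BenchmarkDataPoint(
                        id=record["id"],
                        text=record["question"],
                        correct_answer=record.get("answer"),
                        category=record.get("category"),
                    )
                )

bench = JSONLinesBenchmark("data/my_benchmark")  # placeholder path
print(len(bench), bench.get_categories(), bench.get_stats()["avg_images_per_question"])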
benchmarking/benchmarks/chest_agent_bench.py ADDED
@@ -0,0 +1,200 @@
+"""ChestAgentBench benchmark implementation."""
+
+import json
+import glob
+from pathlib import Path
+from typing import Dict, List, Optional, Any
+
+from .base import Benchmark, BenchmarkDataPoint
+
+
+class ChestAgentBench(Benchmark):
+    """ChestAgentBench benchmark for complex medical reasoning tasks."""
+
+    def __init__(self, data_dir: str, **kwargs):
+        """Initialize ChestAgentBench.
+
+        Args:
+            data_dir (str): Directory containing benchmark data
+            **kwargs: Additional configuration parameters
+        """
+        # Expected structure:
+        # data_dir/
+        #     eurorad_metadata.json  # Case metadata
+        #     questions/
+        #         case_id1/
+        #             case_id1_question1.json
+        #             case_id1_question2.json
+        #         case_id2/
+        #             ...
+        #     figures/
+        #         case_id1/
+        #             figure1.jpg
+        #             figure2.jpg
+        #         case_id2/
+        #             ...
+
+        self.metadata_file = kwargs.get("metadata_file", "eurorad_metadata.json")
+        self.questions_dir = kwargs.get("questions_dir", "questions")
+        self.figures_dir = kwargs.get("figures_dir", "figures")
+
+        super().__init__(data_dir, **kwargs)
+
+    def _load_data(self) -> None:
+        """Load ChestAgentBench data."""
+        # Load case metadata
+        metadata_path = self.data_dir / self.metadata_file
+        if not metadata_path.exists():
+            raise FileNotFoundError(f"Metadata file not found: {metadata_path}")
+
+        with open(metadata_path, 'r') as f:
+            case_metadata = json.load(f)
+
+        # Load questions for each case
+        questions_path = self.data_dir / self.questions_dir
+        if not questions_path.exists():
+            raise FileNotFoundError(f"Questions directory not found: {questions_path}")
+
+        figures_path = self.data_dir / self.figures_dir
+
+        self.data_points = []
+
+        for case_id, case_details in case_metadata.items():
+            # Find all question files for this case
+            case_questions_dir = questions_path / case_id
+            if not case_questions_dir.exists():
+                continue
+
+            question_files = glob.glob(str(case_questions_dir / f"{case_id}_*.json"))
+
+            for question_file in question_files:
+                try:
+                    with open(question_file, 'r') as f:
+                        question_data = json.load(f)
+
+                    question_id = Path(question_file).stem
+
+                    # Parse figure information
+                    images = []
+                    if question_data.get("figures"):
+                        required_figures = self._parse_figures(question_data["figures"])
+
+                        # Find actual image files
+                        case_figures_dir = figures_path / case_id
+                        if case_figures_dir.exists():
+                            for figure_id in required_figures:
+                                # Look for the figure file
+                                figure_files = glob.glob(str(case_figures_dir / f"{figure_id}.*"))
+                                if figure_files:
+                                    images.append(figure_files[0])  # Take the first match
+
+                    # Extract categories from metadata
+                    categories = []
+                    if question_data.get("metadata", {}).get("categories"):
+                        categories = question_data["metadata"]["categories"]
+
+                    category = categories[0] if categories else None
+
+                    # Create data point
+                    data_point = BenchmarkDataPoint(
+                        id=question_id,
+                        text=question_data["question"],
+                        images=images if images else None,
+                        correct_answer=question_data.get("answer", [None])[0],
+                        metadata={
+                            "case_details": case_details,
+                            "question_metadata": question_data.get("metadata", {}),
+                            "explanation": question_data.get("explanation", ""),
+                            "categories": categories,
+                            "figures": question_data.get("figures", []),
+                        },
+                        case_id=case_id,
+                        category=category,
+                    )
+
+                    self.data_points.append(data_point)
+
+                except Exception as e:
+                    print(f"Error loading question {question_file}: {e}")
+                    continue
+
+    def _parse_figures(self, figures_data: Any) -> List[str]:
+        """Parse figure information from question data.
+
+        Args:
+            figures_data: Figure information from question JSON
+
+        Returns:
+            List[str]: List of figure IDs
+        """
+        if isinstance(figures_data, str):
+            try:
+                # Try to parse as JSON
+                figures_list = json.loads(figures_data)
+                return figures_list if isinstance(figures_list, list) else [figures_data]
+            except json.JSONDecodeError:
+                return [figures_data]
+        elif isinstance(figures_data, list):
+            return figures_data
+        else:
+            return [str(figures_data)]
+
+    def get_data_point(self, index: int) -> BenchmarkDataPoint:
+        """Get a specific data point by index.
+
+        Args:
+            index (int): Index of the data point to retrieve
+
+        Returns:
+            BenchmarkDataPoint: The data point at the given index
+        """
+        if index < 0 or index >= len(self.data_points):
+            raise IndexError(f"Index {index} out of range for {len(self.data_points)} data points")
+
+        return self.data_points[index]
+
+    def get_multiple_choice_options(self, data_point: BenchmarkDataPoint) -> List[str]:
+        """Get multiple choice options for a data point.
+
+        Args:
+            data_point (BenchmarkDataPoint): The data point
+
+        Returns:
+            List[str]: List of multiple choice options (A, B, C, D, E, F)
+        """
+        # ChestAgentBench uses A-F multiple choice
+        return ["A", "B", "C", "D", "E", "F"]
+
+    def format_question_with_choices(self, data_point: BenchmarkDataPoint) -> str:
+        """Format question text with multiple choice options.
+
+        Args:
+            data_point (BenchmarkDataPoint): The data point
+
+        Returns:
+            str: Formatted question with choices
+        """
+        question = data_point.text
+
+        # Add instruction for multiple choice format
+        question += "\n\nPlease provide your answer as a single letter (A, B, C, D, E, or F)."
+
+        return question
+
+    def get_category_mapping(self) -> Dict[str, str]:
+        """Get mapping of category names to descriptions.
+
+        Returns:
+            Dict[str, str]: Mapping of category names to descriptions
+        """
+        return {
+            "detection": "Identify and locate specific findings in the chest X-ray",
+            "classification": "Determine whether specific findings are present or absent",
+            "enumeration": "Count the number of target findings in the chest X-ray",
+            "localization": "Locate a given finding in the chest X-ray",
+            "comparison": "Compare the size or position of a specific finding",
+            "relationship": "Determine the relationship between two or more findings",
+            "diagnosis": "Make a diagnosis or determine a treatment plan",
+            "characterization": "Describe specific attributes of findings",
+            "reasoning": "Explain the medical rationale behind findings and conclusions",
+        }
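A rough usage sketch for the loader above; the ./ChestAgentBench path is a placeholder matching the expected directory structure documented in __init__:

from benchmarking.benchmarks import ChestAgentBench

bench = ChestAgentBench("./ChestAgentBench")  # expects eurorad_metadata.json, questions/, figures/
dp = bench.get_data_point(0)
prompt = bench.format_question_with_choices(dp)  # appends the single-letter A-F instruction
print(dp.case_id, dp.category, len(dp.images or []))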
benchmarking/benchmarks/rexvqa_benchmark.py ADDED
@@ -0,0 +1,171 @@
+"""ReXVQA benchmark implementation."""
+
+from typing import Dict, List, Optional, Any
+from datasets import load_dataset
+from .base import Benchmark, BenchmarkDataPoint
+
+
+class ReXVQABenchmark(Benchmark):
+    """ReXVQA benchmark for chest radiology visual question answering.
+
+    ReXVQA is a large-scale VQA dataset for chest radiology comprising approximately
+    696,000 questions paired with 160,000 chest X-rays. It tests 5 core radiological
+    reasoning skills: presence assessment, location analysis, negation detection,
+    differential diagnosis, and geometric reasoning.
+
+    Paper: https://arxiv.org/abs/2506.04353
+    Dataset: https://huggingface.co/datasets/rajpurkarlab/ReXVQA
+    """
+
+    def __init__(self, data_dir: str, **kwargs):
+        """Initialize ReXVQA benchmark.
+
+        Args:
+            data_dir (str): Directory to store/cache downloaded data
+            **kwargs: Additional configuration parameters
+                split (str): Dataset split to use ('validation' or 'test', default: 'validation')
+                cache_dir (str): Directory for caching HuggingFace datasets
+                trust_remote_code (bool): Whether to trust remote code (default: False)
+        """
+        self.split = kwargs.get("split", "validation")
+        self.cache_dir = kwargs.get("cache_dir", None)
+        self.trust_remote_code = kwargs.get("trust_remote_code", False)
+
+        super().__init__(data_dir, **kwargs)
+
+    def _load_data(self) -> None:
+        """Load ReXVQA data from HuggingFace."""
+        try:
+            # Load dataset from HuggingFace
+            print(f"Loading ReXVQA {self.split} split from HuggingFace...")
+
+            dataset = load_dataset(
+                "rajpurkarlab/ReXVQA",
+                split=self.split,
+                cache_dir=self.cache_dir,
+                trust_remote_code=self.trust_remote_code
+            )
+
+            print(f"Loaded {len(dataset)} examples from ReXVQA {self.split} split")
+
+            self.data_points = []
+
+            for i, item in enumerate(dataset):
+                try:
+                    data_point = self._parse_rexvqa_item(item, i)
+                    if data_point:
+                        self.data_points.append(data_point)
+
+                except Exception as e:
+                    print(f"Error loading item {i}: {e}")
+                    continue
+
+        except Exception as e:
+            raise RuntimeError(f"Failed to load ReXVQA dataset: {e}")
+
+    def _parse_rexvqa_item(self, item: Dict[str, Any], index: int) -> Optional[BenchmarkDataPoint]:
+        """Parse a ReXVQA dataset item.
+
+        Args:
+            item (Dict[str, Any]): Dataset item from HuggingFace
+            index (int): Item index
+
+        Returns:
+            Optional[BenchmarkDataPoint]: Parsed data point
+        """
+        # Extract basic information
+        question_id = item.get("id", f"rexvqa_{self.split}_{index}")
+        question = item.get("question", "")
+        answer = item.get("answer", "")
+
+        if not question:
+            return None
+
+        # Handle image
+        images = None
+        if "image" in item and item["image"] is not None:
+            # Save image to local cache directory
+            image_filename = f"{question_id}.png"
+            image_path = self.data_dir / "images" / image_filename
+
+            # Create images directory if it doesn't exist
+            image_path.parent.mkdir(parents=True, exist_ok=True)
+
+            # Save image if it doesn't exist
+            if not image_path.exists():
+                try:
+                    item["image"].save(str(image_path))
+                except Exception as e:
+                    print(f"Error saving image for {question_id}: {e}")
+                    return None
+
+            images = [str(image_path)]
+
+        # Extract metadata
+        metadata = {
+            "dataset": "rexvqa",
+            "split": self.split,
+            "study_id": item.get("study_id", ""),
+            "image_id": item.get("image_id", ""),
+            "reasoning_type": item.get("reasoning_type", ""),
+            "anatomical_location": item.get("anatomical_location", ""),
+            "pathology": item.get("pathology", ""),
+        }
+
+        # Determine category from reasoning type
+        category = item.get("reasoning_type", "")
+
+        # Use study_id as case_id for grouping related questions
+        case_id = item.get("study_id", "")
+
+        return BenchmarkDataPoint(
+            id=question_id,
+            text=question,
+            images=images,
+            correct_answer=answer,
+            metadata=metadata,
+            case_id=case_id,
+            category=category,
+        )
+
+    def get_pathologies(self) -> List[str]:
+        """Get all unique pathologies in the dataset.
+
+        Returns:
+            List[str]: List of unique pathologies
+        """
+        pathologies = set()
+        for dp in self:
+            pathology = dp.metadata.get("pathology", "")
+            if pathology:
+                pathologies.add(pathology)
+        return sorted(list(pathologies))
+
+    def get_by_pathology(self, pathology: str) -> List[BenchmarkDataPoint]:
+        """Get all data points about a specific pathology.
+
+        Args:
+            pathology (str): Pathology to filter by
+
+        Returns:
+            List[BenchmarkDataPoint]: List of data points about the pathology
+        """
+        return [dp for dp in self if dp.metadata.get("pathology", "") == pathology]
+
+    def get_dataset_info(self) -> Dict[str, Any]:
+        """Get information about the ReXVQA dataset.
+
+        Returns:
+            Dict[str, Any]: Dataset information
+        """
+        return {
+            "name": "ReXVQA",
+            "description": "Large-scale Visual Question Answering Benchmark for Chest Radiology",
+            "split": self.split,
+            "size": len(self.data_points),
+            "reasoning_types": self.get_categories(),  # category mirrors reasoning_type; no separate accessor is defined
+            "pathologies": self.get_pathologies(),
+            "categories": self.get_categories(),
+            "paper": "https://arxiv.org/abs/2506.04353",
+            "dataset_url": "https://huggingface.co/datasets/rajpurkarlab/ReXVQA",
+        }
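A short sketch of pulling the validation split through this wrapper; it assumes network access to HuggingFace and whatever credentials the rajpurkarlab/ReXVQA dataset requires, and the local directory is a placeholder:

from benchmarking.benchmarks import ReXVQABenchmark

bench = ReXVQABenchmark(
    "data/rexvqa",  # local directory where decoded images are cached as PNGs
    split="validation",
)
info = bench.get_dataset_info()
print(info["size"], info["categories"])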
benchmarking/cli.py ADDED
@@ -0,0 +1,246 @@
+"""Command-line interface for the benchmarking pipeline."""
+
+import argparse
+import sys
+from pathlib import Path
+from typing import Dict, Any, Optional
+
+from .llm_providers import *
+from .benchmarks import *
+from .runner import BenchmarkRunner, BenchmarkRunConfig
+from .evaluation import BenchmarkEvaluator
+
+
+def create_llm_provider(model_name: str, provider_type: str, **kwargs) -> LLMProvider:
+    """Create an LLM provider based on the model name and type.
+
+    Args:
+        model_name (str): Name of the model
+        provider_type (str): Type of provider (openai, google, openrouter, medrax)
+        **kwargs: Additional configuration parameters
+
+    Returns:
+        LLMProvider: The configured LLM provider
+    """
+    provider_map = {
+        "openai": OpenAIProvider,
+        "google": GoogleProvider,
+        "openrouter": OpenRouterProvider,
+        "medrax": MedRAXProvider,
+    }
+
+    if provider_type not in provider_map:
+        raise ValueError(f"Unknown provider type: {provider_type}. Available: {list(provider_map.keys())}")
+
+    provider_class = provider_map[provider_type]
+    return provider_class(model_name, **kwargs)
+
+
+def create_benchmark(benchmark_name: str, data_dir: str, **kwargs) -> Benchmark:
+    """Create a benchmark based on the benchmark name.
+
+    Args:
+        benchmark_name (str): Name of the benchmark
+        data_dir (str): Directory containing benchmark data
+        **kwargs: Additional configuration parameters
+
+    Returns:
+        Benchmark: The configured benchmark
+    """
+    benchmark_map = {
+        "chest_agent_bench": ChestAgentBench,
+        "rexvqa": ReXVQABenchmark,
+    }
+
+    if benchmark_name not in benchmark_map:
+        raise ValueError(f"Unknown benchmark: {benchmark_name}. Available: {list(benchmark_map.keys())}")
+
+    benchmark_class = benchmark_map[benchmark_name]
+    return benchmark_class(data_dir, **kwargs)
+
+
+def run_benchmark_command(args) -> None:
+    """Run a benchmark."""
+    print(f"Running benchmark: {args.benchmark} with model: {args.model}")
+
+    # Create LLM provider
+    provider_kwargs = {}
+    if args.provider == "medrax":
+        provider_kwargs = {
+            "tools_to_use": args.medrax_tools.split(",") if args.medrax_tools else None,
+            "model_dir": args.model_dir,
+            "temp_dir": args.temp_dir,
+            "device": args.device,
+            "rag_config": None,  # You might want to add RAG config options
+        }
+
+    llm_provider = create_llm_provider(args.model, args.provider, **provider_kwargs)
+
+    # Create benchmark
+    benchmark_kwargs = {}
+
+    benchmark = create_benchmark(args.benchmark, args.data_dir, **benchmark_kwargs)
+
+    # Create runner config
+    config = BenchmarkRunConfig(
+        model_name=args.model,
+        benchmark_name=args.benchmark,
+        output_dir=args.output_dir,
+        max_questions=args.max_questions,
+        start_index=args.start_index,
+        temperature=args.temperature,
+        max_tokens=args.max_tokens,
+        system_prompt=args.system_prompt,
+        save_frequency=args.save_frequency,
+        log_level=args.log_level,
+    )
+
+    # Run benchmark
+    runner = BenchmarkRunner(config)
+    summary = runner.run_benchmark(llm_provider, benchmark)
+
+    print("\n" + "="*50)
+    print("BENCHMARK COMPLETED")
+    print("="*50)
+    print(f"Overall Accuracy: {summary['results']['accuracy']:.2f}%")
+    print(f"Total Questions: {summary['results']['total_questions']}")
+    print(f"Correct Answers: {summary['results']['correct_answers']}")
+    print(f"Total Duration: {summary['results']['total_duration']:.2f}s")
+    print(f"Results saved to: {summary['results_file']}")
+
+
+def evaluate_results_command(args) -> None:
+    """Evaluate benchmark results."""
+    print(f"Evaluating results: {args.results_files}")
+
+    evaluator = BenchmarkEvaluator(args.output_dir)
+
+    if len(args.results_files) == 1:
+        # Single model evaluation
+        evaluation = evaluator.evaluate_single_run(args.results_files[0])
+        print("\n" + "="*50)
+        print("SINGLE MODEL EVALUATION")
+        print("="*50)
+        print(f"Model: {evaluation.model_name}")
+        print(f"Benchmark: {evaluation.benchmark_name}")
+        print(f"Overall Accuracy: {evaluation.overall_accuracy:.2f}%")
+        print(f"Total Questions: {evaluation.total_questions}")
+        print(f"Error Rate: {evaluation.error_rate:.2f}%")
+        print(f"Total Duration: {evaluation.total_duration:.2f}s")
+
+        if evaluation.category_accuracies:
+            print("\nCategory Accuracies:")
+            for category, accuracy in evaluation.category_accuracies.items():
+                print(f"  {category}: {accuracy:.2f}%")
+
+    else:
+        # Multiple model comparison
+        comparison = evaluator.compare_models(args.results_files)
+
+        if "error" in comparison:
+            print(f"Error: {comparison['error']}")
+            return
+
+        print("\n" + "="*50)
+        print("MODEL COMPARISON")
+        print("="*50)
+
+        summary = comparison["summary"]
+        print(f"Models Compared: {summary['models_compared']}")
+        print(f"Best Overall Accuracy: {summary['best_overall_accuracy']:.2f}%")
+        print(f"Accuracy Range: {summary['accuracy_range'][0]:.2f}% - {summary['accuracy_range'][1]:.2f}%")
+
+        best_model = comparison["best_model"]
+        print(f"\nBest Model: {best_model['Model']} ({best_model['Accuracy (%)']:.2f}%)")
+
+        # Generate comprehensive report
+        report_path = evaluator.generate_report(args.results_files, args.report_name)
+        print(f"\nDetailed report saved to: {report_path}")
+
+        # Statistical significance test
+        if args.statistical_test:
+            print("\nRunning statistical significance tests...")
+            sig_results = evaluator.statistical_significance_test(args.results_files)
+            print(f"Found {len(sig_results['comparisons'])} pairwise comparisons")
+
+            for comp in sig_results["comparisons"]:
+                significance = "significant" if comp["significant"] else "not significant"
+                print(f"{comp['model1']} vs {comp['model2']}: {significance} (p={comp['p_value']:.4f})")
+
+
+def list_providers_command(args) -> None:
+    """List available LLM providers."""
+    print("Available LLM Providers:")
+    print("- openai: OpenAI GPT models")
+    print("- google: Google Gemini models")
+    print("- openrouter: OpenRouter API (multiple models)")
+    print("- medrax: MedRAX agent system")
+
+
+def list_benchmarks_command(args) -> None:
+    """List available benchmarks."""
+    print("Available Benchmarks:")
+    print("- chest_agent_bench: ChestAgentBench (complex medical reasoning on chest X-ray cases)")
+    print("- rexvqa: ReXVQA (large-scale chest radiology VQA)")
+
+
+def main():
+    """Main CLI entry point."""
+    parser = argparse.ArgumentParser(description="MedRAX Benchmarking Pipeline")
+    subparsers = parser.add_subparsers(dest="command", help="Available commands")
+
+    # Run benchmark command
+    run_parser = subparsers.add_parser("run", help="Run a benchmark")
+    run_parser.add_argument("--model", required=True, help="Model name (e.g., gpt-4o, gemini-2.5-pro)")
+    run_parser.add_argument("--provider", required=True, choices=["openai", "google", "openrouter", "medrax"], help="LLM provider")
+    run_parser.add_argument("--benchmark", required=True, choices=["chest_agent_bench", "rexvqa"], help="Benchmark to run")
+    run_parser.add_argument("--data-dir", required=True, help="Directory containing benchmark data")
+    run_parser.add_argument("--output-dir", default="benchmark_results", help="Output directory for results")
+    run_parser.add_argument("--max-questions", type=int, help="Maximum number of questions to process")
+    run_parser.add_argument("--start-index", type=int, default=0, help="Starting index for questions")
+    run_parser.add_argument("--temperature", type=float, default=0.7, help="Model temperature")
+    run_parser.add_argument("--max-tokens", type=int, default=1500, help="Maximum tokens per response")
+    run_parser.add_argument("--system-prompt", help="System prompt for the model")
+    run_parser.add_argument("--save-frequency", type=int, default=10, help="Save results every N questions")
+    run_parser.add_argument("--log-level", default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR"])
+
+    # MedRAX-specific arguments
+    run_parser.add_argument("--medrax-tools", help="Comma-separated list of tools for MedRAX (e.g., WebBrowserTool,MedicalRAGTool)")
+    run_parser.add_argument("--model-dir", default="/model-weights", help="Directory containing model weights for MedRAX")
+    run_parser.add_argument("--temp-dir", default="temp", help="Temporary directory for MedRAX")
+    run_parser.add_argument("--device", default="cuda", help="Device for MedRAX models")
+
+    run_parser.set_defaults(func=run_benchmark_command)
+
+    # Evaluate results command
+    eval_parser = subparsers.add_parser("evaluate", help="Evaluate benchmark results")
+    eval_parser.add_argument("results_files", nargs="+", help="Path(s) to results files")
+    eval_parser.add_argument("--output-dir", default="evaluation_results", help="Output directory for evaluation")
+    eval_parser.add_argument("--report-name", default="evaluation_report", help="Name for the evaluation report")
+    eval_parser.add_argument("--statistical-test", action="store_true", help="Run statistical significance tests")
+    eval_parser.set_defaults(func=evaluate_results_command)
+
+    # List providers command
+    list_providers_parser = subparsers.add_parser("list-providers", help="List available LLM providers")
+    list_providers_parser.set_defaults(func=list_providers_command)
+
+    # List benchmarks command
+    list_benchmarks_parser = subparsers.add_parser("list-benchmarks", help="List available benchmarks")
+    list_benchmarks_parser.set_defaults(func=list_benchmarks_command)
+
+    args = parser.parse_args()
+
+    if args.command is None:
+        parser.print_help()
+        return
+
+    try:
+        args.func(args)
+    except Exception as e:
+        print(f"Error: {e}")
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    main()
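As an illustration, and assuming the package is invoked as a module named benchmarking (this commit does not pin down an entry point), a run followed by an evaluation might look like:

python -m benchmarking.cli run --model gpt-4o --provider openai --benchmark rexvqa --data-dir data/rexvqa --max-questions 100
python -m benchmarking.cli evaluate benchmark_results/rexvqa_gpt-4o_20250101_120000_results.json

The results filename follows the runner's run_id pattern (benchmark_model_timestamp); the timestamp shown here is a placeholder.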
benchmarking/evaluation.py ADDED
@@ -0,0 +1,114 @@
+"""Evaluation code for analyzing benchmark results."""
+
+import json
+import pandas as pd
+import numpy as np
+from pathlib import Path
+from typing import Dict, List, Optional, Any, Tuple
+from dataclasses import dataclass
+from collections import defaultdict
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+
+@dataclass
+class EvaluationResult:
+    """Results of evaluating a benchmark run."""
+    model_name: str
+    benchmark_name: str
+    overall_accuracy: float
+    total_questions: int
+    correct_answers: int
+    total_duration: float
+    category_accuracies: Dict[str, float]
+    category_counts: Dict[str, int]
+    error_rate: float
+    avg_duration_per_question: float
+
+
+class BenchmarkEvaluator:
+    """Class for evaluating and comparing benchmark results."""
+
+    def __init__(self, output_dir: str = "evaluation_results"):
+        """Initialize the evaluator.
+
+        Args:
+            output_dir (str): Directory to save evaluation results
+        """
+        self.output_dir = Path(output_dir)
+        self.output_dir.mkdir(parents=True, exist_ok=True)
+
+    def load_results(self, results_file: str) -> Dict[str, Any]:
+        """Load benchmark results from file.
+
+        Args:
+            results_file (str): Path to the results file
+
+        Returns:
+            Dict[str, Any]: Loaded results data
+        """
+        with open(results_file, 'r') as f:
+            return json.load(f)
+
+    def evaluate_single_run(self, results_file: str) -> EvaluationResult:
+        """Evaluate a single benchmark run.
+
+        Args:
+            results_file (str): Path to the results file
+
+        Returns:
+            EvaluationResult: Evaluation results
+        """
+        results = self.load_results(results_file)
+
+        # Calculate basic metrics
+        total_questions = len(results)
+        correct_answers = sum(1 for r in results if r.get("is_correct", False))
+        total_duration = sum(r.get("duration", 0) for r in results)
+        errors = sum(1 for r in results if r.get("error") is not None)
+
+        overall_accuracy = (correct_answers / total_questions) * 100 if total_questions > 0 else 0
+        error_rate = (errors / total_questions) * 100 if total_questions > 0 else 0
+
+        # Calculate per-category metrics
+        category_stats = defaultdict(lambda: {"correct": 0, "total": 0})
+
+        for result in results:
+            metadata = result.get("metadata", {})
+            category = metadata.get("category")
+
+            if category:
+                category_stats[category]["total"] += 1
+                if result.get("is_correct", False):
+                    category_stats[category]["correct"] += 1
+
+        # Calculate category accuracies
+        category_accuracies = {}
+        category_counts = {}
+        for category, stats in category_stats.items():
+            category_accuracies[category] = (stats["correct"] / stats["total"]) * 100
+            category_counts[category] = stats["total"]
+
+        # Extract model and benchmark names (assuming they're in the filename or metadata)
+        results_path = Path(results_file)
+        filename_parts = results_path.stem.split("_")
+
+        model_name = "unknown"
+        benchmark_name = "unknown"
+
+        if len(filename_parts) >= 2:
+            benchmark_name = filename_parts[0]
+            model_name = filename_parts[1]
+
+        return EvaluationResult(
+            model_name=model_name,
+            benchmark_name=benchmark_name,
+            overall_accuracy=overall_accuracy,
+            total_questions=total_questions,
+            correct_answers=correct_answers,
+            total_duration=total_duration,
+            category_accuracies=category_accuracies,
+            category_counts=category_counts,
+            error_rate=error_rate,
+            avg_duration_per_question=total_duration / total_questions if total_questions > 0 else 0,
+        )
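A small sketch of driving the evaluator directly; the results path is a placeholder that follows the runner's naming scheme:

from benchmarking.evaluation import BenchmarkEvaluator

evaluator = BenchmarkEvaluator("evaluation_results")
ev = evaluator.evaluate_single_run("benchmark_results/rexvqa_gpt-4o_20250101_120000_results.json")
print(f"{ev.overall_accuracy:.2f}% over {ev.total_questions} questions ({ev.error_rate:.2f}% errors)")
for category, acc in ev.category_accuracies.items():
    print(f"  {category}: {acc:.2f}%")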
benchmarking/llm_providers/__init__.py CHANGED
@@ -1,6 +1,6 @@
 """LLM provider abstractions for benchmarking."""
 
-from .base import LLMProvider
+from .base import LLMProvider, LLMRequest, LLMResponse
 from .openai_provider import OpenAIProvider
 from .google_provider import GoogleProvider
 from .openrouter_provider import OpenRouterProvider
@@ -8,6 +8,8 @@ from .medrax_provider import MedRAXProvider
 
 __all__ = [
     "LLMProvider",
+    "LLMRequest",
+    "LLMResponse",
     "OpenAIProvider",
     "GoogleProvider",
    "OpenRouterProvider",
benchmarking/llm_providers/base.py CHANGED
@@ -63,6 +63,25 @@ class LLMProvider(ABC):
         """
         pass
 
+    def test_connection(self) -> bool:
+        """Test the connection to the LLM provider.
+
+        Returns:
+            bool: True if connection is successful, False otherwise
+        """
+        try:
+            # Simple test request
+            test_request = LLMRequest(
+                text="Hello",
+                temperature=0.0,
+                max_tokens=10
+            )
+            response = self.generate_response(test_request)
+            return response.content is not None and len(response.content.strip()) > 0
+        except Exception as e:
+            print(f"Connection test failed: {e}")
+            return False
+
     def _encode_image(self, image_path: str) -> str:
         """Encode image to base64 string.
 
benchmarking/llm_providers/medrax_provider.py CHANGED
@@ -13,7 +13,8 @@ from .base import LLMProvider, LLMRequest, LLMResponse
 # Import MedRAX components
 from medrax.agent import Agent
 from medrax.tools import *
-from medrax.utils import load_prompts_from_file, RAGConfig
+from medrax.utils import load_prompts_from_file
+from medrax.rag.rag import RAGConfig
 from medrax.models import ModelFactory
 from langgraph.checkpoint.memory import MemorySaver
 from langchain_core.messages import HumanMessage
benchmarking/runner.py ADDED
@@ -0,0 +1,397 @@
+"""Main test runner for benchmarking pipeline."""
+
+import json
+import time
+import logging
+from datetime import datetime
+from pathlib import Path
+from typing import Dict, Optional, Any
+from dataclasses import dataclass
+from tqdm import tqdm
+import re
+from .llm_providers import LLMProvider, LLMRequest, LLMResponse
+from .benchmarks import Benchmark, BenchmarkDataPoint
+
+
+@dataclass
+class BenchmarkResult:
+    """Result of running a benchmark on a single data point."""
+    data_point_id: str
+    question: str
+    model_answer: str
+    correct_answer: str
+    is_correct: bool
+    duration: float
+    usage: Optional[Dict[str, Any]] = None
+    error: Optional[str] = None
+    metadata: Optional[Dict[str, Any]] = None
+
+
+@dataclass
+class BenchmarkRunConfig:
+    """Configuration for a benchmark run."""
+    model_name: str
+    benchmark_name: str
+    output_dir: str
+    max_questions: Optional[int] = None
+    start_index: int = 0
+    temperature: float = 0.7
+    max_tokens: int = 1500
+    system_prompt: Optional[str] = None
+    save_frequency: int = 10  # Save results every N questions
+    log_level: str = "INFO"
+    additional_params: Optional[Dict[str, Any]] = None
+
+
+class BenchmarkRunner:
+    """Main class for running benchmarks against LLM providers."""
+
+    def __init__(self, config: BenchmarkRunConfig):
+        """Initialize the benchmark runner.
+
+        Args:
+            config (BenchmarkRunConfig): Configuration for the benchmark run
+        """
+        self.config = config
+        self.results = []
+        self.output_dir = Path(config.output_dir)
+        self.output_dir.mkdir(parents=True, exist_ok=True)
+
+        # Generate unique run ID (must exist before logging setup, which names the logger after it)
+        self.run_id = f"{config.benchmark_name}_{config.model_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
+
+        # Set up logging
+        self._setup_logging()
+
+        self.logger.info(f"Initialized benchmark runner with ID: {self.run_id}")
+
+    def _setup_logging(self) -> None:
+        """Set up logging configuration."""
+        log_file = self.output_dir / f"benchmark_run_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
+
+        # Create logger
+        self.logger = logging.getLogger(f"benchmark_runner_{self.run_id}")
+        self.logger.setLevel(getattr(logging, self.config.log_level))
+
+        # Create handlers
+        file_handler = logging.FileHandler(log_file)
+        console_handler = logging.StreamHandler()
+
+        # Create formatter
+        formatter = logging.Formatter(
+            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+        )
+        file_handler.setFormatter(formatter)
+        console_handler.setFormatter(formatter)
+
+        # Add handlers to logger
+        self.logger.addHandler(file_handler)
+        self.logger.addHandler(console_handler)
+
+    def run_benchmark(
+        self,
+        llm_provider: LLMProvider,
+        benchmark: Benchmark,
+    ) -> Dict[str, Any]:
+        """Run a benchmark against an LLM provider.
+
+        Args:
+            llm_provider (LLMProvider): The LLM provider to test
+            benchmark (Benchmark): The benchmark to run
+
+        Returns:
+            Dict[str, Any]: Summary of benchmark results
+        """
+        self.logger.info(f"Starting benchmark run: {self.run_id}")
+        self.logger.info(f"Model: {llm_provider.model_name}")
+        self.logger.info(f"Benchmark: {benchmark}")
+
+        # Test provider connection
+        if not llm_provider.test_connection():
+            self.logger.error("LLM provider connection test failed")
+            return {"error": "LLM provider connection test failed"}
+
+        # Get data points to process
+        total_questions = len(benchmark)
+        max_questions = self.config.max_questions or total_questions
+        end_index = min(self.config.start_index + max_questions, total_questions)
+
+        self.logger.info(f"Processing questions {self.config.start_index} to {end_index-1} of {total_questions}")
+
+        # Initialize counters
+        processed = 0
+        correct = 0
+        total_duration = 0.0
+
+        # Process each data point
+        for i in tqdm(range(self.config.start_index, end_index), desc="Processing questions"):
+            try:
+                data_point = benchmark.get_data_point(i)
+
+                # Run the model on this data point
+                result = self._process_data_point(llm_provider, data_point)
+
+                # Update counters
+                processed += 1
+                if result.is_correct:
+                    correct += 1
+                total_duration += result.duration
+
+                # Add to results
+                self.results.append(result)
+
+                # Log progress
+                if processed % self.config.save_frequency == 0:
+                    self._save_intermediate_results()
+                    accuracy = (correct / processed) * 100
+                    avg_duration = total_duration / processed
+
+                    self.logger.info(
+                        f"Progress: {processed}/{end_index - self.config.start_index} | "
+                        f"Accuracy: {accuracy:.2f}% | "
+                        f"Avg Duration: {avg_duration:.2f}s"
+                    )
+
+            except Exception as e:
+                self.logger.error(f"Error processing data point {i}: {e}")
+                # Add error result
+                error_result = BenchmarkResult(
+                    data_point_id=f"error_{i}",
+                    question="",
+                    model_answer="",
+                    correct_answer="",
+                    is_correct=False,
+                    duration=0.0,
+                    error=str(e)
+                )
+                self.results.append(error_result)
+                continue
+
+        # Save final results
+        summary = self._save_final_results(benchmark)
+
+        self.logger.info(f"Benchmark run completed: {self.run_id}")
+        self.logger.info(f"Final accuracy: {summary['results']['accuracy']:.2f}%")
+        self.logger.info(f"Total duration: {summary['results']['total_duration']:.2f}s")
+
+        return summary
+
+    def _process_data_point(
+        self,
+        llm_provider: LLMProvider,
+        data_point: BenchmarkDataPoint,
+    ) -> BenchmarkResult:
+        """Process a single data point.
+
+        Args:
+            llm_provider (LLMProvider): The LLM provider to use
+            data_point (BenchmarkDataPoint): The data point to process
+
+        Returns:
+            BenchmarkResult: Result of processing the data point
+        """
+        start_time = time.time()
+
+        try:
+            # Create request
+            request = LLMRequest(
+                text=data_point.text,
+                images=data_point.images,
+                system_prompt=self.config.system_prompt,
+                temperature=self.config.temperature,
+                max_tokens=self.config.max_tokens,
+                additional_params=self.config.additional_params
+            )
+
+            # Get response from LLM
+            response: LLMResponse = llm_provider.generate_response(request)
+
+            # Extract answer (this may need customization based on benchmark)
+            model_answer = self._extract_answer(response.content)
+
+            # Check if correct
+            is_correct = self._is_correct_answer(model_answer, data_point.correct_answer)
+
+            duration = time.time() - start_time
+
+            return BenchmarkResult(
+                data_point_id=data_point.id,
+                question=data_point.text,
+                model_answer=model_answer,
+                correct_answer=data_point.correct_answer,
+                is_correct=is_correct,
+                duration=duration,
+                usage=response.usage,
+                metadata={
+                    "data_point_metadata": data_point.metadata,
+                    "case_id": data_point.case_id,
+                    "category": data_point.category,
+                    "raw_response": response.content,
+                }
+            )
+
+        except Exception as e:
+            duration = time.time() - start_time
+            return BenchmarkResult(
+                data_point_id=data_point.id,
+                question=data_point.text,
+                model_answer="",
+                correct_answer=data_point.correct_answer,
+                is_correct=False,
+                duration=duration,
+                error=str(e),
+                metadata={
+                    "data_point_metadata": data_point.metadata,
+                    "case_id": data_point.case_id,
+                    "category": data_point.category,
+                }
+            )
+
+    def _extract_answer(self, response_text: str) -> str:
+        """Extract the answer from the model response.
+
+        Args:
+            response_text (str): The full response text from the model
+
+        Returns:
+            str: The extracted answer
+        """
+        # This is a simple implementation - may need customization per benchmark
+        # For multiple choice, look for single letters A, B, C, D, E, F
+
+        # Look for patterns like "A", "B)", "(C)", "Answer: D", etc.
+        patterns = [
+            r'\b([A-F])\b',  # Single letter
+            r'\b([A-F])\)',  # Letter with closing parenthesis
+            r'\(([A-F])\)',  # Letter in parentheses
+            r'[Aa]nswer\s*:?\s*([A-F])',  # "Answer: X" format
+            r'[Cc]hoice\s*:?\s*([A-F])',  # "Choice: X" format
+        ]
+
+        for pattern in patterns:
+            match = re.search(pattern, response_text)
+            if match:
+                return match.group(1).upper()
+
+        # If no pattern matches, return the first letter found
+        letters = re.findall(r'\b[A-F]\b', response_text)
+        if letters:
+            return letters[0].upper()
+
+        # If no letters found, return the full response (truncated)
+        return response_text.strip()[:100]
+
+    def _is_correct_answer(self, model_answer: str, correct_answer: str) -> bool:
+        """Check if the model answer is correct.
+
+        Args:
+            model_answer (str): The model's answer
+            correct_answer (str): The correct answer
+
+        Returns:
+            bool: True if the answer is correct
+        """
+        if not model_answer or not correct_answer:
+            return False
+
+        # For multiple choice, compare just the letter
+        model_clean = model_answer.strip().upper()
+        correct_clean = correct_answer.strip().upper()
+
+        # Extract just the first letter for comparison
+        model_letter = model_clean[0] if model_clean else ""
+        correct_letter = correct_clean[0] if correct_clean else ""
+
+        return model_letter == correct_letter
+
+    def _save_intermediate_results(self) -> None:
+        """Save intermediate results to disk."""
+        results_file = self.output_dir / f"{self.run_id}_intermediate.json"
+
+        # Convert results to serializable format
+        results_data = []
+        for result in self.results:
+            results_data.append({
+                "data_point_id": result.data_point_id,
+                "question": result.question,
+                "model_answer": result.model_answer,
+                "correct_answer": result.correct_answer,
+                "is_correct": result.is_correct,
+                "duration": result.duration,
+                "usage": result.usage,
+                "error": result.error,
+                "metadata": result.metadata,
+            })
+
+        with open(results_file, 'w') as f:
+            json.dump(results_data, f, indent=2)
+
+    def _save_final_results(self, benchmark: Benchmark) -> Dict[str, Any]:
+        """Save final results and return summary.
+
+        Args:
+            benchmark (Benchmark): The benchmark that was run
+
+        Returns:
+            Dict[str, Any]: Summary of results
+        """
+        # Save detailed results
+        results_file = self.output_dir / f"{self.run_id}_results.json"
+        self._save_intermediate_results()
+
+        # Calculate summary statistics
+        total_questions = len(self.results)
+        correct_answers = sum(1 for r in self.results if r.is_correct)
+        total_duration = sum(r.duration for r in self.results)
+
+        accuracy = (correct_answers / total_questions) * 100 if total_questions > 0 else 0
+
+        # Calculate per-category accuracy
+        category_stats = {}
+        for result in self.results:
+            if result.metadata and result.metadata.get("category"):
+                category = result.metadata["category"]
+                if category not in category_stats:
+                    category_stats[category] = {"correct": 0, "total": 0}
+                category_stats[category]["total"] += 1
+                if result.is_correct:
+                    category_stats[category]["correct"] += 1
+
+        # Calculate accuracy for each category
+        category_accuracies = {}
+        for category, stats in category_stats.items():
+            category_accuracies[category] = (stats["correct"] / stats["total"]) * 100
+
+        # Create summary
+        summary = {
+            "run_id": self.run_id,
+            "timestamp": datetime.now().isoformat(),
+            "config": {
+                "model_name": self.config.model_name,
+                "benchmark_name": self.config.benchmark_name,
+                "temperature": self.config.temperature,
+                "max_tokens": self.config.max_tokens,
+                "system_prompt": self.config.system_prompt,
+            },
+            "benchmark_info": {
+                "total_size": len(benchmark),
+                "processed_questions": total_questions,
+                "start_index": self.config.start_index,
+            },
+            "results": {
+                "accuracy": accuracy,
+                "correct_answers": correct_answers,
+                "total_questions": total_questions,
+                "total_duration": total_duration,
+                "avg_duration_per_question": total_duration / total_questions if total_questions > 0 else 0,
+                "category_accuracies": category_accuracies,
+            },
+            "results_file": str(results_file),
+        }
+
+        # Save summary
+        summary_file = self.output_dir / f"{self.run_id}_summary.json"
+        with open(summary_file, 'w') as f:
+            json.dump(summary, f, indent=2)
+
+        return summary
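Wired together programmatically, mirroring what run_benchmark_command in cli.py does; the provider choice, model name, and data directory are illustrative:

from benchmarking.llm_providers import OpenAIProvider
from benchmarking.benchmarks import ReXVQABenchmark
from benchmarking.runner import BenchmarkRunner, BenchmarkRunConfig

config = BenchmarkRunConfig(
    model_name="gpt-4o",
    benchmark_name="rexvqa",
    output_dir="benchmark_results",
    max_questions=50,
)
runner = BenchmarkRunner(config)
summary = runner.run_benchmark(OpenAIProvider("gpt-4o"), ReXVQABenchmark("data/rexvqa"))
print(summary["results"]["accuracy"], summary["results_file"])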
pyproject.toml CHANGED
@@ -72,6 +72,8 @@ dependencies = [
     "langchain-google-genai>=0.1.0",
     "ray>=2.9.0",
     "langchain-sandbox>=0.0.6",
+    "seaborn>=0.12.0",
+    "huggingface_hub>=0.17.0",
 ]
 
 [project.optional-dependencies]