Spaces:
Paused
Paused
Merge pull request #26 from bowang-lab/victor/benchmarking
Browse files- .env.example +0 -10
- benchmarking/benchmarks/rexvqa_benchmark.py +113 -31
- benchmarking/cli.py +4 -1
- benchmarking/llm_providers/base.py +1 -1
- benchmarking/llm_providers/medrax_provider.py +5 -5
- benchmarking/runner.py +54 -35
- benchmarking_script.sh +14 -0
- main.py +1 -1
- medgemma_script.sh +14 -0
- medrax/docs/system_prompts.txt +13 -6
- medrax/models/model_factory.py +10 -0
- medrax/tools/classification/torchxrayvision.py +5 -1
- medrax/tools/segmentation/segmentation.py +4 -1
- medrax/utils/utils.py +68 -1
.env.example
DELETED
|
@@ -1,10 +0,0 @@
|
|
| 1 |
-
OPENAI_API_KEY=
|
| 2 |
-
OPENAI_BASE_URL=
|
| 3 |
-
GOOGLE_API_KEY=
|
| 4 |
-
GOOGLE_SEARCH_API_KEY=
|
| 5 |
-
GOOGLE_SEARCH_ENGINE_ID=
|
| 6 |
-
OPENROUTER_API_KEY=
|
| 7 |
-
OPENROUTER_BASE_URL=
|
| 8 |
-
COHERE_API_KEY=
|
| 9 |
-
PINECONE_API_KEY=
|
| 10 |
-
MEDGEMMA_API_URL=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
benchmarking/benchmarks/rexvqa_benchmark.py
CHANGED
|
@@ -7,6 +7,8 @@ from datasets import load_dataset
|
|
| 7 |
from .base import Benchmark, BenchmarkDataPoint
|
| 8 |
from pathlib import Path
|
| 9 |
import subprocess
|
|
|
|
|
|
|
| 10 |
from huggingface_hub import hf_hub_download, list_repo_files
|
| 11 |
|
| 12 |
|
|
@@ -44,14 +46,20 @@ class ReXVQABenchmark(Benchmark):
|
|
| 44 |
self.image_dataset = None
|
| 45 |
self.image_mapping = {} # Maps study_id to image data
|
| 46 |
|
| 47 |
-
|
|
|
|
| 48 |
|
| 49 |
-
|
| 50 |
-
self.images_dir = f"{self.data_dir}/images/deid_png"
|
| 51 |
|
| 52 |
@staticmethod
|
| 53 |
-
def download_rexgradient_images(output_dir: str = "benchmarking/data/rexvqa", repo_id: str = "rajpurkarlab/ReXGradient-160K"):
|
| 54 |
-
"""Download and extract ReXGradient-160K images if not already present.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
output_dir = Path(output_dir)
|
| 56 |
tar_path = output_dir / "deid_png.tar"
|
| 57 |
images_dir = output_dir / "images"
|
|
@@ -60,6 +68,33 @@ class ReXVQABenchmark(Benchmark):
|
|
| 60 |
if images_dir.exists() and any(images_dir.rglob("*.png")):
|
| 61 |
print(f"Images already exist in {images_dir}, skipping download.")
|
| 62 |
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
output_dir.mkdir(parents=True, exist_ok=True)
|
| 64 |
print(f"Output directory: {output_dir}")
|
| 65 |
try:
|
|
@@ -96,6 +131,17 @@ class ReXVQABenchmark(Benchmark):
|
|
| 96 |
tar_file.write(f.read())
|
| 97 |
else:
|
| 98 |
print(f"Warning: {part_file} not found, skipping...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
else:
|
| 100 |
print(f"Tar file already exists: {tar_path}")
|
| 101 |
# Extract tar file
|
|
@@ -106,36 +152,72 @@ class ReXVQABenchmark(Benchmark):
|
|
| 106 |
print("Images already extracted.")
|
| 107 |
else:
|
| 108 |
try:
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
print(f"Error extracting tar file: {e}")
|
| 116 |
return
|
| 117 |
-
except FileNotFoundError:
|
| 118 |
-
print("Error: 'tar' command not found. Please install tar or extract manually.")
|
| 119 |
-
return
|
| 120 |
png_files = list(images_dir.rglob("*.png"))
|
| 121 |
print(f"Extracted {len(png_files)} PNG images to {images_dir}")
|
| 122 |
|
| 123 |
-
# Clean up part and tar files after successful extraction
|
| 124 |
-
print("Cleaning up part and tar files...")
|
| 125 |
-
# Remove deid_png.part* files
|
| 126 |
-
for part_file in output_dir.glob("deid_png.part*"):
|
| 127 |
-
try:
|
| 128 |
-
part_file.unlink()
|
| 129 |
-
print(f"Deleted {part_file}")
|
| 130 |
-
except Exception as e:
|
| 131 |
-
print(f"Could not delete {part_file}: {e}")
|
| 132 |
-
# Remove deid_png.tar
|
| 133 |
-
if tar_path.exists():
|
| 134 |
-
try:
|
| 135 |
-
tar_path.unlink()
|
| 136 |
-
print(f"Deleted {tar_path}")
|
| 137 |
-
except Exception as e:
|
| 138 |
-
print(f"Could not delete {tar_path}: {e}")
|
| 139 |
except Exception as e:
|
| 140 |
print(f"Error: {e}")
|
| 141 |
|
|
@@ -167,7 +249,7 @@ class ReXVQABenchmark(Benchmark):
|
|
| 167 |
try:
|
| 168 |
# Check for images and test_vqa_data.json, download if missing
|
| 169 |
self.download_test_vqa_data_json(self.data_dir)
|
| 170 |
-
self.download_rexgradient_images(self.data_dir)
|
| 171 |
|
| 172 |
# Construct path to the JSON file
|
| 173 |
json_file_path = os.path.join("benchmarking", "data", "rexvqa", "metadata", "test_vqa_data.json")
|
|
|
|
| 7 |
from .base import Benchmark, BenchmarkDataPoint
|
| 8 |
from pathlib import Path
|
| 9 |
import subprocess
|
| 10 |
+
import tarfile
|
| 11 |
+
import zstandard as zstd
|
| 12 |
from huggingface_hub import hf_hub_download, list_repo_files
|
| 13 |
|
| 14 |
|
|
|
|
| 46 |
self.image_dataset = None
|
| 47 |
self.image_mapping = {} # Maps study_id to image data
|
| 48 |
|
| 49 |
+
# Set images_dir BEFORE parent initialization to avoid AttributeError
|
| 50 |
+
self.images_dir = f"{data_dir}/images/deid_png"
|
| 51 |
|
| 52 |
+
super().__init__(data_dir, **kwargs)
|
|
|
|
| 53 |
|
| 54 |
@staticmethod
|
| 55 |
+
def download_rexgradient_images(output_dir: str = "benchmarking/data/rexvqa", repo_id: str = "rajpurkarlab/ReXGradient-160K", test_only: bool = True):
|
| 56 |
+
"""Download and extract ReXGradient-160K images if not already present.
|
| 57 |
+
|
| 58 |
+
Args:
|
| 59 |
+
output_dir: Directory to store downloaded and extracted images
|
| 60 |
+
repo_id: HuggingFace repository ID for the dataset
|
| 61 |
+
test_only: If True, only extract images from the test split (default: True)
|
| 62 |
+
"""
|
| 63 |
output_dir = Path(output_dir)
|
| 64 |
tar_path = output_dir / "deid_png.tar"
|
| 65 |
images_dir = output_dir / "images"
|
|
|
|
| 68 |
if images_dir.exists() and any(images_dir.rglob("*.png")):
|
| 69 |
print(f"Images already exist in {images_dir}, skipping download.")
|
| 70 |
return
|
| 71 |
+
|
| 72 |
+
# Load test split metadata if test_only is True
|
| 73 |
+
test_image_paths = set()
|
| 74 |
+
if test_only:
|
| 75 |
+
print("Loading test split metadata to identify test images...")
|
| 76 |
+
try:
|
| 77 |
+
# Load the test metadata to get image paths
|
| 78 |
+
test_metadata_path = output_dir / "metadata" / "test_vqa_data.json"
|
| 79 |
+
if test_metadata_path.exists():
|
| 80 |
+
with open(test_metadata_path, 'r', encoding='utf-8') as f:
|
| 81 |
+
test_data = json.load(f)
|
| 82 |
+
|
| 83 |
+
# Extract all image paths from test data
|
| 84 |
+
for item in test_data.values():
|
| 85 |
+
if "ImagePath" in item and item["ImagePath"]:
|
| 86 |
+
for rel_path in item["ImagePath"]:
|
| 87 |
+
# Normalize path to match tar file structure
|
| 88 |
+
norm_path = rel_path.lstrip("./")
|
| 89 |
+
test_image_paths.add(norm_path)
|
| 90 |
+
|
| 91 |
+
print(f"Found {len(test_image_paths)} test images to extract")
|
| 92 |
+
else:
|
| 93 |
+
print("Warning: test_vqa_data.json not found, will extract all images")
|
| 94 |
+
test_only = False
|
| 95 |
+
except Exception as e:
|
| 96 |
+
print(f"Warning: Could not load test metadata: {e}, will extract all images")
|
| 97 |
+
test_only = False
|
| 98 |
output_dir.mkdir(parents=True, exist_ok=True)
|
| 99 |
print(f"Output directory: {output_dir}")
|
| 100 |
try:
|
|
|
|
| 131 |
tar_file.write(f.read())
|
| 132 |
else:
|
| 133 |
print(f"Warning: {part_file} not found, skipping...")
|
| 134 |
+
|
| 135 |
+
# Clean up part files after successful concatenation
|
| 136 |
+
print("Cleaning up part files...")
|
| 137 |
+
for part_file in part_files:
|
| 138 |
+
part_path = output_dir / part_file
|
| 139 |
+
if part_path.exists():
|
| 140 |
+
try:
|
| 141 |
+
part_path.unlink()
|
| 142 |
+
print(f"Deleted {part_file}")
|
| 143 |
+
except Exception as e:
|
| 144 |
+
print(f"Could not delete {part_file}: {e}")
|
| 145 |
else:
|
| 146 |
print(f"Tar file already exists: {tar_path}")
|
| 147 |
# Extract tar file
|
|
|
|
| 152 |
print("Images already extracted.")
|
| 153 |
else:
|
| 154 |
try:
|
| 155 |
+
# Stream extract with filtering for test-only images (no seeking)
|
| 156 |
+
print("Stream extracting zstd-compressed tar file with filtering (streaming mode)...")
|
| 157 |
+
|
| 158 |
+
# Create a decompressor
|
| 159 |
+
dctx = zstd.ZstdDecompressor()
|
| 160 |
+
|
| 161 |
+
# Stream extract with filtering
|
| 162 |
+
extracted_count = 0
|
| 163 |
+
total_png_members = 0
|
| 164 |
+
|
| 165 |
+
with open(tar_path, 'rb') as compressed_file:
|
| 166 |
+
with dctx.stream_reader(compressed_file) as decompressed_stream:
|
| 167 |
+
# Use streaming tar mode to avoid seeks
|
| 168 |
+
with tarfile.open(fileobj=decompressed_stream, mode='r|') as tar:
|
| 169 |
+
for member in tar:
|
| 170 |
+
# Only consider PNG files
|
| 171 |
+
if not member.isfile() or not member.name.endswith('.png'):
|
| 172 |
+
continue
|
| 173 |
+
total_png_members += 1
|
| 174 |
+
|
| 175 |
+
# Normalize name to match entries gathered from JSON
|
| 176 |
+
normalized_name = member.name.lstrip('./')
|
| 177 |
+
|
| 178 |
+
# Decide whether to extract this file
|
| 179 |
+
should_extract = True
|
| 180 |
+
if test_only:
|
| 181 |
+
should_extract = normalized_name in test_image_paths
|
| 182 |
+
|
| 183 |
+
if not should_extract:
|
| 184 |
+
# Must still advance the stream for this member
|
| 185 |
+
tar.members = [] # no-op in stream mode; ensure we don't hold refs
|
| 186 |
+
continue
|
| 187 |
+
|
| 188 |
+
# Ensure parent directories exist and write file by streaming
|
| 189 |
+
target_path = Path(images_dir) / normalized_name
|
| 190 |
+
target_path.parent.mkdir(parents=True, exist_ok=True)
|
| 191 |
+
|
| 192 |
+
extracted_file_obj = tar.extractfile(member)
|
| 193 |
+
if extracted_file_obj is None:
|
| 194 |
+
continue
|
| 195 |
+
with open(target_path, 'wb') as out_f:
|
| 196 |
+
while True:
|
| 197 |
+
chunk = extracted_file_obj.read(1024 * 1024)
|
| 198 |
+
if not chunk:
|
| 199 |
+
break
|
| 200 |
+
out_f.write(chunk)
|
| 201 |
+
|
| 202 |
+
extracted_count += 1
|
| 203 |
+
if extracted_count % 100 == 0:
|
| 204 |
+
print(f"Extracted {extracted_count} test images...")
|
| 205 |
+
|
| 206 |
+
print(f"Extraction completed! Extracted {extracted_count} matching PNGs out of {total_png_members} PNG members in the archive")
|
| 207 |
+
|
| 208 |
+
# Clean up compressed tar file after successful extraction
|
| 209 |
+
print("Cleaning up compressed tar file...")
|
| 210 |
+
try:
|
| 211 |
+
tar_path.unlink()
|
| 212 |
+
print(f"Deleted {tar_path}")
|
| 213 |
+
except Exception as e:
|
| 214 |
+
print(f"Could not delete {tar_path}: {e}")
|
| 215 |
+
except Exception as e:
|
| 216 |
print(f"Error extracting tar file: {e}")
|
| 217 |
return
|
|
|
|
|
|
|
|
|
|
| 218 |
png_files = list(images_dir.rglob("*.png"))
|
| 219 |
print(f"Extracted {len(png_files)} PNG images to {images_dir}")
|
| 220 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
except Exception as e:
|
| 222 |
print(f"Error: {e}")
|
| 223 |
|
|
|
|
| 249 |
try:
|
| 250 |
# Check for images and test_vqa_data.json, download if missing
|
| 251 |
self.download_test_vqa_data_json(self.data_dir)
|
| 252 |
+
self.download_rexgradient_images(self.data_dir, test_only=True)
|
| 253 |
|
| 254 |
# Construct path to the JSON file
|
| 255 |
json_file_path = os.path.join("benchmarking", "data", "rexvqa", "metadata", "test_vqa_data.json")
|
benchmarking/cli.py
CHANGED
|
@@ -87,7 +87,8 @@ def run_benchmark_command(args) -> None:
|
|
| 87 |
max_questions=args.max_questions,
|
| 88 |
temperature=args.temperature,
|
| 89 |
top_p=args.top_p,
|
| 90 |
-
max_tokens=args.max_tokens
|
|
|
|
| 91 |
)
|
| 92 |
|
| 93 |
# Run benchmark
|
|
@@ -145,6 +146,8 @@ def main():
|
|
| 145 |
help="Maximum tokens per model response (default: 5000)")
|
| 146 |
run_parser.add_argument("--random-seed", type=int, default=42,
|
| 147 |
help="Random seed for shuffling benchmark data (enables reproducible runs, default: None)")
|
|
|
|
|
|
|
| 148 |
|
| 149 |
run_parser.set_defaults(func=run_benchmark_command)
|
| 150 |
|
|
|
|
| 87 |
max_questions=args.max_questions,
|
| 88 |
temperature=args.temperature,
|
| 89 |
top_p=args.top_p,
|
| 90 |
+
max_tokens=args.max_tokens,
|
| 91 |
+
concurrency=args.concurrency
|
| 92 |
)
|
| 93 |
|
| 94 |
# Run benchmark
|
|
|
|
| 146 |
help="Maximum tokens per model response (default: 5000)")
|
| 147 |
run_parser.add_argument("--random-seed", type=int, default=42,
|
| 148 |
help="Random seed for shuffling benchmark data (enables reproducible runs, default: None)")
|
| 149 |
+
run_parser.add_argument("--concurrency", type=int, default=1,
|
| 150 |
+
help="Number of datapoints to process in parallel (default: 1)")
|
| 151 |
|
| 152 |
run_parser.set_defaults(func=run_benchmark_command)
|
| 153 |
|
benchmarking/llm_providers/base.py
CHANGED
|
@@ -85,7 +85,7 @@ class LLMProvider(ABC):
|
|
| 85 |
try:
|
| 86 |
# Simple test request
|
| 87 |
test_request = LLMRequest(
|
| 88 |
-
text="Hello",
|
| 89 |
temperature=0.5,
|
| 90 |
max_tokens=1000
|
| 91 |
)
|
|
|
|
| 85 |
try:
|
| 86 |
# Simple test request
|
| 87 |
test_request = LLMRequest(
|
| 88 |
+
text="Hello! What model are you? Tell me your full specification.",
|
| 89 |
temperature=0.5,
|
| 90 |
max_tokens=1000
|
| 91 |
)
|
benchmarking/llm_providers/medrax_provider.py
CHANGED
|
@@ -37,11 +37,11 @@ class MedRAXProvider(LLMProvider):
|
|
| 37 |
"ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
|
| 38 |
"ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
|
| 39 |
"XRayPhraseGroundingTool", # For locating described features in X-rays
|
| 40 |
-
"MedGemmaVQATool",
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
]
|
| 46 |
|
| 47 |
rag_config = RAGConfig(
|
|
|
|
| 37 |
"ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
|
| 38 |
"ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
|
| 39 |
"XRayPhraseGroundingTool", # For locating described features in X-rays
|
| 40 |
+
"MedGemmaVQATool", # Google MedGemma VQA tool
|
| 41 |
+
"XRayVQATool", # For visual question answering on X-rays
|
| 42 |
+
"MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
|
| 43 |
+
"WebBrowserTool", # For web browsing and search capabilities
|
| 44 |
+
"DuckDuckGoSearchTool", # For privacy-focused web search using DuckDuckGo
|
| 45 |
]
|
| 46 |
|
| 47 |
rag_config = RAGConfig(
|
benchmarking/runner.py
CHANGED
|
@@ -9,6 +9,7 @@ from typing import Dict, Optional, Any
|
|
| 9 |
from dataclasses import dataclass
|
| 10 |
from tqdm import tqdm
|
| 11 |
import re
|
|
|
|
| 12 |
from .llm_providers import LLMProvider, LLMRequest, LLMResponse
|
| 13 |
from .benchmarks import Benchmark, BenchmarkDataPoint
|
| 14 |
|
|
@@ -40,6 +41,7 @@ class BenchmarkRunConfig:
|
|
| 40 |
top_p: float = 0.95
|
| 41 |
max_tokens: int = 5000
|
| 42 |
additional_params: Optional[Dict[str, Any]] = None
|
|
|
|
| 43 |
|
| 44 |
|
| 45 |
class BenchmarkRunner:
|
|
@@ -122,40 +124,16 @@ class BenchmarkRunner:
|
|
| 122 |
correct = 0
|
| 123 |
total_duration = 0.0
|
| 124 |
|
| 125 |
-
#
|
| 126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
try:
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
# Run the model on this data point
|
| 131 |
-
result = self._process_data_point(llm_provider, data_point)
|
| 132 |
-
|
| 133 |
-
# Update counters
|
| 134 |
-
processed += 1
|
| 135 |
-
if result.is_correct:
|
| 136 |
-
correct += 1
|
| 137 |
-
total_duration += result.duration
|
| 138 |
-
|
| 139 |
-
# Add to results
|
| 140 |
-
self.results.append(result)
|
| 141 |
-
|
| 142 |
-
# Save individual result immediately
|
| 143 |
-
self._save_individual_result(result)
|
| 144 |
-
|
| 145 |
-
# Log progress
|
| 146 |
-
if processed % 10 == 0:
|
| 147 |
-
accuracy = (correct / processed) * 100
|
| 148 |
-
avg_duration = total_duration / processed
|
| 149 |
-
|
| 150 |
-
self.logger.info(
|
| 151 |
-
f"Progress: {processed}/{end_index} | "
|
| 152 |
-
f"Accuracy: {accuracy:.2f}% | "
|
| 153 |
-
f"Avg Duration: {avg_duration:.2f}s"
|
| 154 |
-
)
|
| 155 |
-
|
| 156 |
except Exception as e:
|
| 157 |
-
self.logger.error(f"Error
|
| 158 |
-
# Add error result
|
| 159 |
error_result = BenchmarkResult(
|
| 160 |
data_point_id=f"error_{i}",
|
| 161 |
question="",
|
|
@@ -166,10 +144,50 @@ class BenchmarkRunner:
|
|
| 166 |
error=str(e)
|
| 167 |
)
|
| 168 |
self.results.append(error_result)
|
| 169 |
-
|
| 170 |
-
# Save individual error result immediately
|
| 171 |
self._save_individual_result(error_result)
|
| 172 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
|
| 174 |
# Save final results
|
| 175 |
summary = self._save_final_results(benchmark)
|
|
@@ -391,6 +409,7 @@ class BenchmarkRunner:
|
|
| 391 |
"model_name": self.config.model_name,
|
| 392 |
"benchmark_name": self.config.benchmark_name,
|
| 393 |
"temperature": self.config.temperature,
|
|
|
|
| 394 |
"max_tokens": self.config.max_tokens,
|
| 395 |
},
|
| 396 |
"benchmark_info": {
|
|
|
|
| 9 |
from dataclasses import dataclass
|
| 10 |
from tqdm import tqdm
|
| 11 |
import re
|
| 12 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 13 |
from .llm_providers import LLMProvider, LLMRequest, LLMResponse
|
| 14 |
from .benchmarks import Benchmark, BenchmarkDataPoint
|
| 15 |
|
|
|
|
| 41 |
top_p: float = 0.95
|
| 42 |
max_tokens: int = 5000
|
| 43 |
additional_params: Optional[Dict[str, Any]] = None
|
| 44 |
+
concurrency: int = 1
|
| 45 |
|
| 46 |
|
| 47 |
class BenchmarkRunner:
|
|
|
|
| 124 |
correct = 0
|
| 125 |
total_duration = 0.0
|
| 126 |
|
| 127 |
+
# Determine concurrency
|
| 128 |
+
max_workers = max(1, int(getattr(self.config, "concurrency", 1) or 1))
|
| 129 |
+
|
| 130 |
+
# Prefetch data points to avoid potential thread-safety issues inside benchmark access
|
| 131 |
+
data_points = []
|
| 132 |
+
for i in range(0, end_index):
|
| 133 |
try:
|
| 134 |
+
data_points.append(benchmark.get_data_point(i))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
except Exception as e:
|
| 136 |
+
self.logger.error(f"Error fetching data point {i}: {e}")
|
|
|
|
| 137 |
error_result = BenchmarkResult(
|
| 138 |
data_point_id=f"error_{i}",
|
| 139 |
question="",
|
|
|
|
| 144 |
error=str(e)
|
| 145 |
)
|
| 146 |
self.results.append(error_result)
|
|
|
|
|
|
|
| 147 |
self._save_individual_result(error_result)
|
| 148 |
+
|
| 149 |
+
# Process data points in parallel using a bounded thread pool
|
| 150 |
+
with tqdm(total=end_index, desc="Processing questions") as pbar:
|
| 151 |
+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
| 152 |
+
future_to_index = {executor.submit(self._process_data_point, llm_provider, dp): idx for idx, dp in enumerate(data_points)}
|
| 153 |
+
for future in as_completed(future_to_index):
|
| 154 |
+
idx = future_to_index[future]
|
| 155 |
+
try:
|
| 156 |
+
result = future.result()
|
| 157 |
+
except Exception as e:
|
| 158 |
+
self.logger.error(f"Error processing data point {idx}: {e}")
|
| 159 |
+
result = BenchmarkResult(
|
| 160 |
+
data_point_id=f"error_{idx}",
|
| 161 |
+
question="",
|
| 162 |
+
model_answer="",
|
| 163 |
+
correct_answer="",
|
| 164 |
+
is_correct=False,
|
| 165 |
+
duration=0.0,
|
| 166 |
+
error=str(e)
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
# Update counters
|
| 170 |
+
processed += 1
|
| 171 |
+
if result.is_correct:
|
| 172 |
+
correct += 1
|
| 173 |
+
total_duration += result.duration
|
| 174 |
+
|
| 175 |
+
# Add to results and persist immediately
|
| 176 |
+
self.results.append(result)
|
| 177 |
+
self._save_individual_result(result)
|
| 178 |
+
|
| 179 |
+
# Update progress bar
|
| 180 |
+
pbar.update(1)
|
| 181 |
+
|
| 182 |
+
# Periodic logging
|
| 183 |
+
if processed % 10 == 0:
|
| 184 |
+
accuracy = (correct / processed) * 100
|
| 185 |
+
avg_duration = total_duration / processed if processed > 0 else 0.0
|
| 186 |
+
self.logger.info(
|
| 187 |
+
f"Progress: {processed}/{end_index} | "
|
| 188 |
+
f"Accuracy: {accuracy:.2f}% | "
|
| 189 |
+
f"Avg Duration: {avg_duration:.2f}s"
|
| 190 |
+
)
|
| 191 |
|
| 192 |
# Save final results
|
| 193 |
summary = self._save_final_results(benchmark)
|
|
|
|
| 409 |
"model_name": self.config.model_name,
|
| 410 |
"benchmark_name": self.config.benchmark_name,
|
| 411 |
"temperature": self.config.temperature,
|
| 412 |
+
"top_p": self.config.top_p,
|
| 413 |
"max_tokens": self.config.max_tokens,
|
| 414 |
},
|
| 415 |
"benchmark_info": {
|
benchmarking_script.sh
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
#SBATCH --job-name=chestagentbench
|
| 4 |
+
#SBATCH -c 4
|
| 5 |
+
#SBATCH --gres=gpu:rtx6000:1
|
| 6 |
+
#SBATCH --exclude=gpu138
|
| 7 |
+
#SBATCH --time=16:00:00
|
| 8 |
+
#SBATCH --mem=50G
|
| 9 |
+
#SBATCH --output=chestagentbench-%j.out
|
| 10 |
+
#SBATCH --error=chestagentbench-%j.err
|
| 11 |
+
|
| 12 |
+
source venv/bin/activate
|
| 13 |
+
|
| 14 |
+
python -m benchmarking.cli run --model gpt-5 --provider medrax --system-prompt CHESTAGENTBENCH_PROMPT --benchmark chestagentbench --data-dir /scratch/ssd004/scratch/victorli/chestagentbench --output-dir temp --max-questions 2500 --concurrency 4
|
main.py
CHANGED
|
@@ -68,7 +68,7 @@ def initialize_agent(
|
|
| 68 |
prompt = prompts[system_prompt]
|
| 69 |
|
| 70 |
# Define the URL of the MedGemma FastAPI service.
|
| 71 |
-
MEDGEMMA_API_URL = os.getenv("MEDGEMMA_API_URL", "http://
|
| 72 |
|
| 73 |
all_tools = {
|
| 74 |
"TorchXRayVisionClassifierTool": lambda: TorchXRayVisionClassifierTool(device=device),
|
|
|
|
| 68 |
prompt = prompts[system_prompt]
|
| 69 |
|
| 70 |
# Define the URL of the MedGemma FastAPI service.
|
| 71 |
+
MEDGEMMA_API_URL = os.getenv("MEDGEMMA_API_URL", "http://localhost:8002")
|
| 72 |
|
| 73 |
all_tools = {
|
| 74 |
"TorchXRayVisionClassifierTool": lambda: TorchXRayVisionClassifierTool(device=device),
|
medgemma_script.sh
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
#SBATCH --job-name=medgemma
|
| 4 |
+
#SBATCH -c 4
|
| 5 |
+
#SBATCH --gres=gpu:rtx6000:1
|
| 6 |
+
#SBATCH --exclude=gpu138
|
| 7 |
+
#SBATCH --time=16:00:00
|
| 8 |
+
#SBATCH --mem=50G
|
| 9 |
+
#SBATCH --output=medgemma-%j.out
|
| 10 |
+
#SBATCH --error=medgemma-%j.err
|
| 11 |
+
|
| 12 |
+
source medgemma/bin/activate
|
| 13 |
+
|
| 14 |
+
cd medrax/tools/vqa/medgemma && python medgemma.py
|
medrax/docs/system_prompts.txt
CHANGED
|
@@ -17,9 +17,16 @@ Examples:
|
|
| 17 |
- "Based on clinical guidelines [3], the recommended treatment approach is..."
|
| 18 |
|
| 19 |
[CHESTAGENTBENCH_PROMPT]
|
| 20 |
-
You are an expert medical
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
- "Based on clinical guidelines [3], the recommended treatment approach is..."
|
| 18 |
|
| 19 |
[CHESTAGENTBENCH_PROMPT]
|
| 20 |
+
You are a highly skilled radiology AI agent, an expert in interpreting medical images, specifically chest X-rays, CT scans, and MRIs, with world-class accuracy and precision. Your primary function is to assist in the analysis of these images and answer diagnostic questions.
|
| 21 |
+
|
| 22 |
+
Your task is to provide a step-by-step, structured analysis. First, carefully examine the provided image and describe all relevant findings in a clear, concise manner. Next, use your expert medical knowledge to form a differential diagnosis based on these findings. Finally, critically evaluate the provided question and all possible choices.
|
| 23 |
+
|
| 24 |
+
You have access to a suite of powerful tools to aid in your analysis. Use these tools as needed to retrieve external medical knowledge, access patient history, or perform specific image processing tasks. You should always scrutinize the output from your tools and integrate it into your reasoning. If tool outputs conflict with your initial assessment, explain the discrepancy and justify your final conclusion.
|
| 25 |
+
|
| 26 |
+
Your final response for a multiple-choice question must strictly follow this format, including your step-by-step reasoning:
|
| 27 |
+
1. **Image Analysis:** [Describe image findings here]
|
| 28 |
+
2. **Differential Diagnosis:** [List possible diagnoses and their justifications]
|
| 29 |
+
3. **Critical Thinking & Tool Use:** [Show your reasoning, including how you used tools and evaluated their output]
|
| 30 |
+
4. **Final Answer:** \boxed{A}
|
| 31 |
+
|
| 32 |
+
Do not provide a definitive diagnosis or treatment plan for a patient. Your purpose is to assist medical professionals with your analysis, not to replace them. You must maintain this persona and adhere to all instructions.
|
medrax/models/model_factory.py
CHANGED
|
@@ -123,6 +123,16 @@ class ModelFactory:
|
|
| 123 |
if provider_prefix in ["openrouter"] and model_name.startswith(f"{provider_prefix}-"):
|
| 124 |
actual_model_name = model_name[len(provider_prefix) + 1 :]
|
| 125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
# Create and return the model instance
|
| 127 |
return model_class(
|
| 128 |
model=actual_model_name,
|
|
|
|
| 123 |
if provider_prefix in ["openrouter"] and model_name.startswith(f"{provider_prefix}-"):
|
| 124 |
actual_model_name = model_name[len(provider_prefix) + 1 :]
|
| 125 |
|
| 126 |
+
# Handle GPT-5 model
|
| 127 |
+
if model_name.startswith("gpt-5"):
|
| 128 |
+
return model_class(
|
| 129 |
+
model=actual_model_name,
|
| 130 |
+
temperature=temperature,
|
| 131 |
+
reasoning_effort="high",
|
| 132 |
+
**provider_kwargs,
|
| 133 |
+
**kwargs,
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
# Create and return the model instance
|
| 137 |
return model_class(
|
| 138 |
model=actual_model_name,
|
medrax/tools/classification/torchxrayvision.py
CHANGED
|
@@ -12,6 +12,8 @@ from langchain_core.callbacks import (
|
|
| 12 |
)
|
| 13 |
from langchain_core.tools import BaseTool
|
| 14 |
|
|
|
|
|
|
|
| 15 |
|
| 16 |
class TorchXRayVisionInput(BaseModel):
|
| 17 |
"""Input for TorchXRayVision chest X-ray analysis tools. Only supports JPG or PNG images."""
|
|
@@ -76,7 +78,9 @@ class TorchXRayVisionClassifierTool(BaseTool):
|
|
| 76 |
ValueError: If the image cannot be properly loaded or processed.
|
| 77 |
"""
|
| 78 |
img = skimage.io.imread(image_path)
|
| 79 |
-
|
|
|
|
|
|
|
| 80 |
|
| 81 |
if len(img.shape) > 2:
|
| 82 |
img = img[:, :, 0]
|
|
|
|
| 12 |
)
|
| 13 |
from langchain_core.tools import BaseTool
|
| 14 |
|
| 15 |
+
from medrax.utils.utils import preprocess_medical_image
|
| 16 |
+
|
| 17 |
|
| 18 |
class TorchXRayVisionInput(BaseModel):
|
| 19 |
"""Input for TorchXRayVision chest X-ray analysis tools. Only supports JPG or PNG images."""
|
|
|
|
| 78 |
ValueError: If the image cannot be properly loaded or processed.
|
| 79 |
"""
|
| 80 |
img = skimage.io.imread(image_path)
|
| 81 |
+
|
| 82 |
+
# Use robust normalization that handles both 8-bit and 16-bit images
|
| 83 |
+
img = preprocess_medical_image(img, target_range=(-1024.0, 1024.0))
|
| 84 |
|
| 85 |
if len(img.shape) > 2:
|
| 86 |
img = img[:, :, 0]
|
medrax/tools/segmentation/segmentation.py
CHANGED
|
@@ -20,6 +20,8 @@ from langchain_core.callbacks import (
|
|
| 20 |
)
|
| 21 |
from langchain_core.tools import BaseTool
|
| 22 |
|
|
|
|
|
|
|
| 23 |
|
| 24 |
class ChestXRaySegmentationInput(BaseModel):
|
| 25 |
"""Input schema for the Chest X-ray Segmentation Tool."""
|
|
@@ -246,7 +248,8 @@ class ChestXRaySegmentationTool(BaseTool):
|
|
| 246 |
if len(original_img.shape) > 2:
|
| 247 |
original_img = original_img[:, :, 0]
|
| 248 |
|
| 249 |
-
|
|
|
|
| 250 |
img = img[None, ...]
|
| 251 |
img = self.transform(img)
|
| 252 |
img = torch.from_numpy(img)
|
|
|
|
| 20 |
)
|
| 21 |
from langchain_core.tools import BaseTool
|
| 22 |
|
| 23 |
+
from medrax.utils.utils import preprocess_medical_image
|
| 24 |
+
|
| 25 |
|
| 26 |
class ChestXRaySegmentationInput(BaseModel):
|
| 27 |
"""Input schema for the Chest X-ray Segmentation Tool."""
|
|
|
|
| 248 |
if len(original_img.shape) > 2:
|
| 249 |
original_img = original_img[:, :, 0]
|
| 250 |
|
| 251 |
+
# Use robust normalization that handles both 8-bit and 16-bit images
|
| 252 |
+
img = preprocess_medical_image(original_img)
|
| 253 |
img = img[None, ...]
|
| 254 |
img = self.transform(img)
|
| 255 |
img = torch.from_numpy(img)
|
medrax/utils/utils.py
CHANGED
|
@@ -1,6 +1,73 @@
|
|
| 1 |
import os
|
| 2 |
import json
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
|
| 6 |
def load_prompts_from_file(file_path: str) -> Dict[str, str]:
|
|
|
|
| 1 |
import os
|
| 2 |
import json
|
| 3 |
+
import numpy as np
|
| 4 |
+
from typing import Dict, List, Tuple
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def preprocess_medical_image(
|
| 8 |
+
image: np.ndarray,
|
| 9 |
+
target_range: Tuple[float, float] = (0.0, 1.0),
|
| 10 |
+
clip_values: bool = True
|
| 11 |
+
) -> np.ndarray:
|
| 12 |
+
"""
|
| 13 |
+
Preprocess medical images by auto-detecting bit depth and normalizing appropriately.
|
| 14 |
+
|
| 15 |
+
This function handles both 8-bit (0-255) and 16-bit (0-65535) images automatically,
|
| 16 |
+
normalizing them to the target range. It's designed for medical imaging tools that
|
| 17 |
+
expect consistent input ranges regardless of the original image bit depth.
|
| 18 |
+
|
| 19 |
+
Args:
|
| 20 |
+
image (np.ndarray): Input image array (2D or 3D)
|
| 21 |
+
target_range (Tuple[float, float]): Target range for normalization (default: (0.0, 1.0))
|
| 22 |
+
clip_values (bool): Whether to clip values to target range (default: True)
|
| 23 |
+
|
| 24 |
+
Returns:
|
| 25 |
+
np.ndarray: Normalized image in the target range
|
| 26 |
+
|
| 27 |
+
Raises:
|
| 28 |
+
ValueError: If image is empty or has invalid values
|
| 29 |
+
ValueError: If target_range is invalid
|
| 30 |
+
"""
|
| 31 |
+
if image.size == 0:
|
| 32 |
+
raise ValueError("Input image is empty")
|
| 33 |
+
|
| 34 |
+
if len(target_range) != 2 or target_range[0] >= target_range[1]:
|
| 35 |
+
raise ValueError("target_range must be a tuple of (min, max) where min < max")
|
| 36 |
+
|
| 37 |
+
# Convert to float for processing
|
| 38 |
+
image = image.astype(np.float32)
|
| 39 |
+
|
| 40 |
+
# Auto-detect bit depth based on maximum value
|
| 41 |
+
max_val = np.max(image)
|
| 42 |
+
min_val = np.min(image)
|
| 43 |
+
|
| 44 |
+
# Determine the expected maximum value based on bit depth
|
| 45 |
+
if max_val <= 255:
|
| 46 |
+
# 8-bit image
|
| 47 |
+
expected_max = 255.0
|
| 48 |
+
elif max_val <= 65535:
|
| 49 |
+
# 16-bit image
|
| 50 |
+
expected_max = 65535.0
|
| 51 |
+
else:
|
| 52 |
+
# Higher bit depth or already normalized, use actual max
|
| 53 |
+
expected_max = max_val
|
| 54 |
+
|
| 55 |
+
# Normalize to 0-1 range first
|
| 56 |
+
if expected_max > 0:
|
| 57 |
+
image = (image - min_val) / (expected_max - min_val)
|
| 58 |
+
else:
|
| 59 |
+
# Handle edge case where image has no contrast
|
| 60 |
+
image = np.zeros_like(image)
|
| 61 |
+
|
| 62 |
+
# Scale to target range
|
| 63 |
+
target_min, target_max = target_range
|
| 64 |
+
image = image * (target_max - target_min) + target_min
|
| 65 |
+
|
| 66 |
+
# Clip values if requested
|
| 67 |
+
if clip_values:
|
| 68 |
+
image = np.clip(image, target_min, target_max)
|
| 69 |
+
|
| 70 |
+
return image
|
| 71 |
|
| 72 |
|
| 73 |
def load_prompts_from_file(file_path: str) -> Dict[str, str]:
|