Spaces:
Sleeping
Sleeping
final updates
Browse files
benchmarking/benchmarks/base.py
CHANGED
|
@@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
|
|
| 4 |
from typing import Dict, List, Optional, Any, Iterator, Tuple
|
| 5 |
from dataclasses import dataclass
|
| 6 |
from pathlib import Path
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
@dataclass
|
|
@@ -31,17 +32,31 @@ class Benchmark(ABC):
|
|
| 31 |
Args:
|
| 32 |
data_dir (str): Directory containing benchmark data
|
| 33 |
**kwargs: Additional configuration parameters
|
|
|
|
| 34 |
"""
|
| 35 |
self.data_dir = Path(data_dir)
|
| 36 |
self.config = kwargs
|
| 37 |
self.data_points = []
|
| 38 |
self._load_data()
|
|
|
|
| 39 |
|
| 40 |
@abstractmethod
|
| 41 |
def _load_data(self) -> None:
|
| 42 |
"""Load benchmark data from the data directory."""
|
| 43 |
pass
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
def get_data_point(self, index: int) -> BenchmarkDataPoint:
|
| 46 |
"""Get a specific data point by index.
|
| 47 |
|
|
|
|
| 4 |
from typing import Dict, List, Optional, Any, Iterator, Tuple
|
| 5 |
from dataclasses import dataclass
|
| 6 |
from pathlib import Path
|
| 7 |
+
import random
|
| 8 |
|
| 9 |
|
| 10 |
@dataclass
|
|
|
|
| 32 |
Args:
|
| 33 |
data_dir (str): Directory containing benchmark data
|
| 34 |
**kwargs: Additional configuration parameters
|
| 35 |
+
random_seed (int, optional): Passed through **kwargs. Seed used to shuffle the loaded data points; if omitted or None, the data is not shuffled.
|
| 36 |
"""
|
| 37 |
self.data_dir = Path(data_dir)
|
| 38 |
self.config = kwargs
|
| 39 |
self.data_points = []
|
| 40 |
self._load_data()
|
| 41 |
+
self._shuffle_data()
|
| 42 |
|
| 43 |
@abstractmethod
|
| 44 |
def _load_data(self) -> None:
|
| 45 |
"""Load benchmark data from the data directory."""
|
| 46 |
pass
|
| 47 |
|
| 48 |
+
def _shuffle_data(self) -> None:
|
| 49 |
+
"""Shuffle the data points if a random seed is provided.
|
| 50 |
+
|
| 51 |
+
This method is called automatically after data loading to ensure
|
| 52 |
+
reproducible benchmark runs when a random seed is specified.
|
| 53 |
+
"""
|
| 54 |
+
random_seed = self.config.get("random_seed", None)
|
| 55 |
+
if random_seed is not None:
|
| 56 |
+
random.seed(random_seed)
|
| 57 |
+
random.shuffle(self.data_points)
|
| 58 |
+
print(f"Shuffled {len(self.data_points)} data points with seed {random_seed}")
|
| 59 |
+
|
| 60 |
def get_data_point(self, index: int) -> BenchmarkDataPoint:
|
| 61 |
"""Get a specific data point by index.
|
| 62 |
|
benchmarking/benchmarks/chestagentbench_benchmark.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
import json
|
| 2 |
-
import random
|
| 3 |
from pathlib import Path
|
| 4 |
from typing import Dict, Optional, Any
|
| 5 |
from .base import Benchmark, BenchmarkDataPoint
|
|
@@ -31,10 +30,6 @@ class ChestAgentBenchBenchmark(Benchmark):
|
|
| 31 |
except Exception as e:
|
| 32 |
print(f"Error loading item {i}: {e}")
|
| 33 |
continue
|
| 34 |
-
|
| 35 |
-
# Shuffle the final data
|
| 36 |
-
random.seed(42)
|
| 37 |
-
random.shuffle(self.data_points)
|
| 38 |
|
| 39 |
def _parse_item(self, item: Dict[str, Any], index: int) -> Optional[BenchmarkDataPoint]:
|
| 40 |
# Use full_question_id or question_id if available, else fallback
|
|
|
|
| 1 |
import json
|
|
|
|
| 2 |
from pathlib import Path
|
| 3 |
from typing import Dict, Optional, Any
|
| 4 |
from .base import Benchmark, BenchmarkDataPoint
|
|
|
|
| 30 |
except Exception as e:
|
| 31 |
print(f"Error loading item {i}: {e}")
|
| 32 |
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
def _parse_item(self, item: Dict[str, Any], index: int) -> Optional[BenchmarkDataPoint]:
|
| 35 |
# Use full_question_id or question_id if available, else fallback
|
benchmarking/cli.py
CHANGED
|
@@ -73,6 +73,8 @@ def run_benchmark_command(args) -> None:
|
|
| 73 |
|
| 74 |
# Create benchmark
|
| 75 |
benchmark_kwargs = {}
|
|
|
|
|
|
|
| 76 |
|
| 77 |
benchmark = create_benchmark(benchmark_name=args.benchmark, data_dir=args.data_dir, **benchmark_kwargs)
|
| 78 |
|
|
@@ -141,6 +143,8 @@ def main():
|
|
| 141 |
help="Top-p nucleus sampling parameter (default: 0.95)")
|
| 142 |
run_parser.add_argument("--max-tokens", type=int, default=5000,
|
| 143 |
help="Maximum tokens per model response (default: 5000)")
|
|
|
|
|
|
|
| 144 |
|
| 145 |
run_parser.set_defaults(func=run_benchmark_command)
|
| 146 |
|
|
|
|
| 73 |
|
| 74 |
# Create benchmark
|
| 75 |
benchmark_kwargs = {}
|
| 76 |
+
if args.random_seed is not None:
|
| 77 |
+
benchmark_kwargs["random_seed"] = args.random_seed
|
| 78 |
|
| 79 |
benchmark = create_benchmark(benchmark_name=args.benchmark, data_dir=args.data_dir, **benchmark_kwargs)
|
| 80 |
|
|
|
|
| 143 |
help="Top-p nucleus sampling parameter (default: 0.95)")
|
| 144 |
run_parser.add_argument("--max-tokens", type=int, default=5000,
|
| 145 |
help="Maximum tokens per model response (default: 5000)")
|
| 146 |
+
run_parser.add_argument("--random-seed", type=int, default=42,
|
| 147 |
+
help="Random seed for shuffling benchmark data (enables reproducible runs, default: 42)")
|
| 148 |
|
| 149 |
run_parser.set_defaults(func=run_benchmark_command)
|
| 150 |
|
interface.py
CHANGED
|
@@ -192,7 +192,11 @@ class ChatInterface:
|
|
| 192 |
tool_args = pending_call["args"]
|
| 193 |
|
| 194 |
try:
|
| 195 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
tool_output_str = json.dumps(tool_output_json, indent=2)
|
| 197 |
except (json.JSONDecodeError, TypeError):
|
| 198 |
tool_output_str = str(msg.content)
|
|
@@ -217,10 +221,11 @@ class ChatInterface:
|
|
| 217 |
|
| 218 |
if tool_name == "image_visualizer":
|
| 219 |
try:
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
|
|
|
| 224 |
if isinstance(result, dict) and "image_path" in result:
|
| 225 |
self.display_file_path = result["image_path"]
|
| 226 |
chat_history.append(
|
|
|
|
| 192 |
tool_args = pending_call["args"]
|
| 193 |
|
| 194 |
try:
|
| 195 |
+
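# Handle case where tool returns tuple (output, metadata). NOTE(review): ast.literal_eval raises ValueError/SyntaxError when the content is not a Python literal, and the except clause below catches only json.JSONDecodeError and TypeError - confirm this is intended.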
|
| 196 |
+
content = msg.content
|
| 197 |
+
content_tuple = ast.literal_eval(content)
|
| 198 |
+
content = json.dumps(content_tuple[0])
|
| 199 |
+
tool_output_json = json.loads(content)
|
| 200 |
tool_output_str = json.dumps(tool_output_json, indent=2)
|
| 201 |
except (json.JSONDecodeError, TypeError):
|
| 202 |
tool_output_str = str(msg.content)
|
|
|
|
| 221 |
|
| 222 |
if tool_name == "image_visualizer":
|
| 223 |
try:
|
| 224 |
+
# Handle case where tool returns tuple (output, metadata)
|
| 225 |
+
content = msg.content
|
| 226 |
+
content_tuple = ast.literal_eval(content)
|
| 227 |
+
result = content_tuple[0]
|
| 228 |
+
|
| 229 |
if isinstance(result, dict) and "image_path" in result:
|
| 230 |
self.display_file_path = result["image_path"]
|
| 231 |
chat_history.append(
|
medrax/tools/__init__.py
CHANGED
|
@@ -5,7 +5,7 @@ from .report_generation import *
|
|
| 5 |
from .segmentation import *
|
| 6 |
from .vqa import *
|
| 7 |
from .grounding import *
|
| 8 |
-
from .
|
| 9 |
from .dicom import *
|
| 10 |
from .utils import *
|
| 11 |
from .rag import *
|
|
|
|
| 5 |
from .segmentation import *
|
| 6 |
from .vqa import *
|
| 7 |
from .grounding import *
|
| 8 |
+
from .xray_generation import *
|
| 9 |
from .dicom import *
|
| 10 |
from .utils import *
|
| 11 |
from .rag import *
|
medrax/tools/classification/arcplus.py
CHANGED
|
@@ -345,7 +345,8 @@ class ArcPlusClassifierTool(BaseTool):
|
|
| 345 |
predictions = predictions[: len(self.disease_list)]
|
| 346 |
|
| 347 |
# Create output dictionary mapping disease names to probabilities
|
| 348 |
-
|
|
|
|
| 349 |
|
| 350 |
metadata = {
|
| 351 |
"image_path": image_path,
|
|
|
|
| 345 |
predictions = predictions[: len(self.disease_list)]
|
| 346 |
|
| 347 |
# Create output dictionary mapping disease names to probabilities
|
| 348 |
+
# Convert numpy floats to native Python floats for proper serialization
|
| 349 |
+
output = dict(zip(self.disease_list, [float(pred) for pred in predictions]))
|
| 350 |
|
| 351 |
metadata = {
|
| 352 |
"image_path": image_path,
|
medrax/tools/segmentation/medsam2.py
CHANGED
|
@@ -15,7 +15,7 @@ from langchain_core.callbacks import (
|
|
| 15 |
from langchain_core.tools import BaseTool
|
| 16 |
|
| 17 |
# Add MedSAM2 to Python path for proper module resolution
|
| 18 |
-
medsam2_path = str(Path(__file__).parent.parent.parent / "MedSAM2")
|
| 19 |
if medsam2_path not in sys.path:
|
| 20 |
sys.path.append(medsam2_path)
|
| 21 |
|
|
@@ -93,7 +93,7 @@ class MedSAM2Tool(BaseTool):
|
|
| 93 |
if GlobalHydra.instance().is_initialized():
|
| 94 |
GlobalHydra.instance().clear()
|
| 95 |
|
| 96 |
-
config_dir = Path(__file__).parent.parent.parent / "MedSAM2" / "sam2" / "configs"
|
| 97 |
initialize_config_dir(config_dir=str(config_dir), version_base="1.2")
|
| 98 |
|
| 99 |
hf_hub_download(
|
|
|
|
| 15 |
from langchain_core.tools import BaseTool
|
| 16 |
|
| 17 |
# Add MedSAM2 to Python path for proper module resolution
|
| 18 |
+
medsam2_path = str(Path(__file__).parent.parent.parent.parent / "MedSAM2")
|
| 19 |
if medsam2_path not in sys.path:
|
| 20 |
sys.path.append(medsam2_path)
|
| 21 |
|
|
|
|
| 93 |
if GlobalHydra.instance().is_initialized():
|
| 94 |
GlobalHydra.instance().clear()
|
| 95 |
|
| 96 |
+
config_dir = Path(__file__).parent.parent.parent.parent / "MedSAM2" / "sam2" / "configs"
|
| 97 |
initialize_config_dir(config_dir=str(config_dir), version_base="1.2")
|
| 98 |
|
| 99 |
hf_hub_download(
|