VictorLJZ committed on
Commit
044eaf7
·
1 Parent(s): f06bcdb

final updates

Browse files
benchmarking/benchmarks/base.py CHANGED
@@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
4
  from typing import Dict, List, Optional, Any, Iterator, Tuple
5
  from dataclasses import dataclass
6
  from pathlib import Path
 
7
 
8
 
9
  @dataclass
@@ -31,17 +32,31 @@ class Benchmark(ABC):
31
  Args:
32
  data_dir (str): Directory containing benchmark data
33
  **kwargs: Additional configuration parameters
 
34
  """
35
  self.data_dir = Path(data_dir)
36
  self.config = kwargs
37
  self.data_points = []
38
  self._load_data()
 
39
 
40
  @abstractmethod
41
  def _load_data(self) -> None:
42
  """Load benchmark data from the data directory."""
43
  pass
44
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  def get_data_point(self, index: int) -> BenchmarkDataPoint:
46
  """Get a specific data point by index.
47
 
 
4
  from typing import Dict, List, Optional, Any, Iterator, Tuple
5
  from dataclasses import dataclass
6
  from pathlib import Path
7
+ import random
8
 
9
 
10
  @dataclass
 
32
  Args:
33
  data_dir (str): Directory containing benchmark data
34
  **kwargs: Additional configuration parameters
35
+ random_seed (int): Random seed for shuffling data (default: None, no shuffling)
36
  """
37
  self.data_dir = Path(data_dir)
38
  self.config = kwargs
39
  self.data_points = []
40
  self._load_data()
41
+ self._shuffle_data()
42
 
43
  @abstractmethod
44
  def _load_data(self) -> None:
45
  """Load benchmark data from the data directory."""
46
  pass
47
 
48
+ def _shuffle_data(self) -> None:
49
+ """Shuffle the data points if a random seed is provided.
50
+
51
+ This method is called automatically after data loading to ensure
52
+ reproducible benchmark runs when a random seed is specified.
53
+ """
54
+ random_seed = self.config.get("random_seed", None)
55
+ if random_seed is not None:
56
+ random.seed(random_seed)
57
+ random.shuffle(self.data_points)
58
+ print(f"Shuffled {len(self.data_points)} data points with seed {random_seed}")
59
+
60
  def get_data_point(self, index: int) -> BenchmarkDataPoint:
61
  """Get a specific data point by index.
62
 
benchmarking/benchmarks/chestagentbench_benchmark.py CHANGED
@@ -1,5 +1,4 @@
1
  import json
2
- import random
3
  from pathlib import Path
4
  from typing import Dict, Optional, Any
5
  from .base import Benchmark, BenchmarkDataPoint
@@ -31,10 +30,6 @@ class ChestAgentBenchBenchmark(Benchmark):
31
  except Exception as e:
32
  print(f"Error loading item {i}: {e}")
33
  continue
34
-
35
- # Shuffle the final data
36
- random.seed(42)
37
- random.shuffle(self.data_points)
38
 
39
  def _parse_item(self, item: Dict[str, Any], index: int) -> Optional[BenchmarkDataPoint]:
40
  # Use full_question_id or question_id if available, else fallback
 
1
  import json
 
2
  from pathlib import Path
3
  from typing import Dict, Optional, Any
4
  from .base import Benchmark, BenchmarkDataPoint
 
30
  except Exception as e:
31
  print(f"Error loading item {i}: {e}")
32
  continue
 
 
 
 
33
 
34
  def _parse_item(self, item: Dict[str, Any], index: int) -> Optional[BenchmarkDataPoint]:
35
  # Use full_question_id or question_id if available, else fallback
benchmarking/cli.py CHANGED
@@ -73,6 +73,8 @@ def run_benchmark_command(args) -> None:
73
 
74
  # Create benchmark
75
  benchmark_kwargs = {}
 
 
76
 
77
  benchmark = create_benchmark(benchmark_name=args.benchmark, data_dir=args.data_dir, **benchmark_kwargs)
78
 
@@ -141,6 +143,8 @@ def main():
141
  help="Top-p nucleus sampling parameter (default: 0.95)")
142
  run_parser.add_argument("--max-tokens", type=int, default=5000,
143
  help="Maximum tokens per model response (default: 5000)")
 
 
144
 
145
  run_parser.set_defaults(func=run_benchmark_command)
146
 
 
73
 
74
  # Create benchmark
75
  benchmark_kwargs = {}
76
+ if args.random_seed is not None:
77
+ benchmark_kwargs["random_seed"] = args.random_seed
78
 
79
  benchmark = create_benchmark(benchmark_name=args.benchmark, data_dir=args.data_dir, **benchmark_kwargs)
80
 
 
143
  help="Top-p nucleus sampling parameter (default: 0.95)")
144
  run_parser.add_argument("--max-tokens", type=int, default=5000,
145
  help="Maximum tokens per model response (default: 5000)")
146
+ run_parser.add_argument("--random-seed", type=int, default=42,
147
+ help="Random seed for shuffling benchmark data (enables reproducible runs, default: None)")
148
 
149
  run_parser.set_defaults(func=run_benchmark_command)
150
 
interface.py CHANGED
@@ -192,7 +192,11 @@ class ChatInterface:
192
  tool_args = pending_call["args"]
193
 
194
  try:
195
- tool_output_json = json.loads(msg.content)
 
 
 
 
196
  tool_output_str = json.dumps(tool_output_json, indent=2)
197
  except (json.JSONDecodeError, TypeError):
198
  tool_output_str = str(msg.content)
@@ -217,10 +221,11 @@ class ChatInterface:
217
 
218
  if tool_name == "image_visualizer":
219
  try:
220
- result = json.loads(msg.content)
221
- # Handle case where tool returns array [output, metadata]
222
- if isinstance(result, list) and len(result) > 0:
223
- result = result[0] # Take the first element (output)
 
224
  if isinstance(result, dict) and "image_path" in result:
225
  self.display_file_path = result["image_path"]
226
  chat_history.append(
 
192
  tool_args = pending_call["args"]
193
 
194
  try:
195
+ # Handle case where tool returns tuple (output, metadata)
196
+ content = msg.content
197
+ content_tuple = ast.literal_eval(content)
198
+ content = json.dumps(content_tuple[0])
199
+ tool_output_json = json.loads(content)
200
  tool_output_str = json.dumps(tool_output_json, indent=2)
201
  except (json.JSONDecodeError, TypeError):
202
  tool_output_str = str(msg.content)
 
221
 
222
  if tool_name == "image_visualizer":
223
  try:
224
+ # Handle case where tool returns tuple (output, metadata)
225
+ content = msg.content
226
+ content_tuple = ast.literal_eval(content)
227
+ result = content_tuple[0]
228
+
229
  if isinstance(result, dict) and "image_path" in result:
230
  self.display_file_path = result["image_path"]
231
  chat_history.append(
medrax/tools/__init__.py CHANGED
@@ -5,7 +5,7 @@ from .report_generation import *
5
  from .segmentation import *
6
  from .vqa import *
7
  from .grounding import *
8
- from .generation import *
9
  from .dicom import *
10
  from .utils import *
11
  from .rag import *
 
5
  from .segmentation import *
6
  from .vqa import *
7
  from .grounding import *
8
+ from .xray_generation import *
9
  from .dicom import *
10
  from .utils import *
11
  from .rag import *
medrax/tools/classification/arcplus.py CHANGED
@@ -345,7 +345,8 @@ class ArcPlusClassifierTool(BaseTool):
345
  predictions = predictions[: len(self.disease_list)]
346
 
347
  # Create output dictionary mapping disease names to probabilities
348
- output = dict(zip(self.disease_list, predictions.astype(float)))
 
349
 
350
  metadata = {
351
  "image_path": image_path,
 
345
  predictions = predictions[: len(self.disease_list)]
346
 
347
  # Create output dictionary mapping disease names to probabilities
348
+ # Convert numpy floats to native Python floats for proper serialization
349
+ output = dict(zip(self.disease_list, [float(pred) for pred in predictions]))
350
 
351
  metadata = {
352
  "image_path": image_path,
medrax/tools/segmentation/medsam2.py CHANGED
@@ -15,7 +15,7 @@ from langchain_core.callbacks import (
15
  from langchain_core.tools import BaseTool
16
 
17
  # Add MedSAM2 to Python path for proper module resolution
18
- medsam2_path = str(Path(__file__).parent.parent.parent / "MedSAM2")
19
  if medsam2_path not in sys.path:
20
  sys.path.append(medsam2_path)
21
 
@@ -93,7 +93,7 @@ class MedSAM2Tool(BaseTool):
93
  if GlobalHydra.instance().is_initialized():
94
  GlobalHydra.instance().clear()
95
 
96
- config_dir = Path(__file__).parent.parent.parent / "MedSAM2" / "sam2" / "configs"
97
  initialize_config_dir(config_dir=str(config_dir), version_base="1.2")
98
 
99
  hf_hub_download(
 
15
  from langchain_core.tools import BaseTool
16
 
17
  # Add MedSAM2 to Python path for proper module resolution
18
+ medsam2_path = str(Path(__file__).parent.parent.parent.parent / "MedSAM2")
19
  if medsam2_path not in sys.path:
20
  sys.path.append(medsam2_path)
21
 
 
93
  if GlobalHydra.instance().is_initialized():
94
  GlobalHydra.instance().clear()
95
 
96
+ config_dir = Path(__file__).parent.parent.parent.parent / "MedSAM2" / "sam2" / "configs"
97
  initialize_config_dir(config_dir=str(config_dir), version_base="1.2")
98
 
99
  hf_hub_download(