victorli committed on
Commit
16278b5
·
1 Parent(s): 5f69e37

Fixed ReXVQA benchmark and added image normalization handling for tools

Browse files
benchmarking/benchmarks/rexvqa_benchmark.py CHANGED
@@ -46,10 +46,10 @@ class ReXVQABenchmark(Benchmark):
46
  self.image_dataset = None
47
  self.image_mapping = {} # Maps study_id to image data
48
 
49
- super().__init__(data_dir, **kwargs)
 
50
 
51
- # Set images_dir after parent initialization
52
- self.images_dir = f"{self.data_dir}/images/deid_png"
53
 
54
  @staticmethod
55
  def download_rexgradient_images(output_dir: str = "benchmarking/data/rexvqa", repo_id: str = "rajpurkarlab/ReXGradient-160K", test_only: bool = True):
 
46
  self.image_dataset = None
47
  self.image_mapping = {} # Maps study_id to image data
48
 
49
+ # Set images_dir BEFORE parent initialization to avoid AttributeError
50
+ self.images_dir = f"{data_dir}/images/deid_png"
51
 
52
+ super().__init__(data_dir, **kwargs)
 
53
 
54
  @staticmethod
55
  def download_rexgradient_images(output_dir: str = "benchmarking/data/rexvqa", repo_id: str = "rajpurkarlab/ReXGradient-160K", test_only: bool = True):
medrax/tools/classification/torchxrayvision.py CHANGED
@@ -12,6 +12,8 @@ from langchain_core.callbacks import (
12
  )
13
  from langchain_core.tools import BaseTool
14
 
 
 
15
 
16
  class TorchXRayVisionInput(BaseModel):
17
  """Input for TorchXRayVision chest X-ray analysis tools. Only supports JPG or PNG images."""
@@ -76,7 +78,9 @@ class TorchXRayVisionClassifierTool(BaseTool):
76
  ValueError: If the image cannot be properly loaded or processed.
77
  """
78
  img = skimage.io.imread(image_path)
79
- img = xrv.datasets.normalize(img, 255)
 
 
80
 
81
  if len(img.shape) > 2:
82
  img = img[:, :, 0]
 
12
  )
13
  from langchain_core.tools import BaseTool
14
 
15
+ from medrax.utils.utils import preprocess_medical_image
16
+
17
 
18
  class TorchXRayVisionInput(BaseModel):
19
  """Input for TorchXRayVision chest X-ray analysis tools. Only supports JPG or PNG images."""
 
78
  ValueError: If the image cannot be properly loaded or processed.
79
  """
80
  img = skimage.io.imread(image_path)
81
+
82
+ # Use robust normalization that handles both 8-bit and 16-bit images
83
+ img = preprocess_medical_image(img, target_range=(-1024.0, 1024.0))
84
 
85
  if len(img.shape) > 2:
86
  img = img[:, :, 0]
medrax/tools/segmentation/segmentation.py CHANGED
@@ -20,6 +20,8 @@ from langchain_core.callbacks import (
20
  )
21
  from langchain_core.tools import BaseTool
22
 
 
 
23
 
24
  class ChestXRaySegmentationInput(BaseModel):
25
  """Input schema for the Chest X-ray Segmentation Tool."""
@@ -246,7 +248,8 @@ class ChestXRaySegmentationTool(BaseTool):
246
  if len(original_img.shape) > 2:
247
  original_img = original_img[:, :, 0]
248
 
249
- img = xrv.datasets.normalize(original_img, 255)
 
250
  img = img[None, ...]
251
  img = self.transform(img)
252
  img = torch.from_numpy(img)
 
20
  )
21
  from langchain_core.tools import BaseTool
22
 
23
+ from medrax.utils.utils import preprocess_medical_image
24
+
25
 
26
  class ChestXRaySegmentationInput(BaseModel):
27
  """Input schema for the Chest X-ray Segmentation Tool."""
 
248
  if len(original_img.shape) > 2:
249
  original_img = original_img[:, :, 0]
250
 
251
+ # Use robust normalization that handles both 8-bit and 16-bit images
252
+ img = preprocess_medical_image(original_img)
253
  img = img[None, ...]
254
  img = self.transform(img)
255
  img = torch.from_numpy(img)
medrax/utils/utils.py CHANGED
@@ -1,6 +1,90 @@
1
  import os
2
  import json
3
- from typing import Dict, List
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
 
6
  def load_prompts_from_file(file_path: str) -> Dict[str, str]:
 
1
  import os
2
  import json
3
+ import numpy as np
4
+ from typing import Dict, List, Union, Tuple
5
+
6
+
7
def preprocess_medical_image(
    image: np.ndarray,
    target_range: Tuple[float, float] = (0.0, 1.0),
    clip_values: bool = True
) -> np.ndarray:
    """
    Rescale an image into ``target_range`` using a bit depth guessed from its maximum.

    The intensity ceiling is inferred from the largest pixel value: 255 when the
    maximum is <= 255, 65535 when it is <= 65535, and the maximum itself otherwise.
    The image minimum is mapped to ``target_range[0]`` and the inferred ceiling to
    ``target_range[1]``. An image with no usable span (every pixel equal to the
    inferred ceiling) is mapped entirely to ``target_range[0]``.

    NOTE: a float image that is already scaled to [0, 1] has a maximum <= 255 and is
    therefore treated as 8-bit data; it is not detected as "already normalized".

    Args:
        image (np.ndarray): Input image array (2D or 3D)
        target_range (Tuple[float, float]): Target range for normalization (default: (0.0, 1.0))
        clip_values (bool): Whether to clip values to target range (default: True)

    Returns:
        np.ndarray: float32 image rescaled to the target range

    Raises:
        ValueError: If image is empty or contains NaN/infinite values
        ValueError: If target_range is invalid
    """
    if image.size == 0:
        raise ValueError("Input image is empty")

    if len(target_range) != 2 or target_range[0] >= target_range[1]:
        raise ValueError("target_range must be a tuple of (min, max) where min < max")

    # Convert to float for processing
    image = image.astype(np.float32)

    # Non-finite pixels would poison min/max and turn the whole output into NaN,
    # so reject them up front as the documented "invalid values" case.
    if not np.isfinite(image).all():
        raise ValueError("Input image contains NaN or infinite values")

    # Auto-detect bit depth based on maximum value
    max_val = np.max(image)
    min_val = np.min(image)

    # Determine the expected maximum value based on bit depth
    if max_val <= 255:
        # 8-bit image
        expected_max = 255.0
    elif max_val <= 65535:
        # 16-bit image
        expected_max = 65535.0
    else:
        # Values beyond the 16-bit range: fall back to the actual maximum
        expected_max = max_val

    # Normalize to 0-1 range first. Guard on the real denominator: it is zero when
    # every pixel equals the inferred ceiling (e.g. an all-255 image), which would
    # otherwise be a division by zero.
    value_span = expected_max - min_val
    if value_span > 0:
        image = (image - min_val) / value_span
    else:
        # Handle edge case where image has no contrast
        image = np.zeros_like(image)

    # Scale to target range
    target_min, target_max = target_range
    image = image * (target_max - target_min) + target_min

    # Clip values if requested
    if clip_values:
        image = np.clip(image, target_min, target_max)

    return image
71
+
72
+
73
def normalize_medical_image_for_torchxrayvision(image: np.ndarray) -> np.ndarray:
    """
    Rescale an image to the -1024 to 1024 range used for TorchXRayVision models.

    Thin convenience wrapper: it delegates to preprocess_medical_image with a
    fixed target range and leaves every other option at that function's default.

    Args:
        image (np.ndarray): Input image array (2D or 3D)

    Returns:
        np.ndarray: Output of preprocess_medical_image for the (-1024.0, 1024.0) range
    """
    xrv_range = (-1024.0, 1024.0)
    return preprocess_medical_image(image, target_range=xrv_range)
88
 
89
 
90
  def load_prompts_from_file(file_path: str) -> Dict[str, str]: