iszt committed on
Commit
a99dcb7
·
verified ·
1 Parent(s): bbeaaf5

added mask support, bug fixes

Browse files
Files changed (1) hide show
  1. image_processing_eye_gpu.py +440 -151
image_processing_eye_gpu.py CHANGED
@@ -39,7 +39,11 @@ except ImportError:
39
  # =============================================================================
40
 
41
  def _pil_to_tensor(image: "Image.Image") -> torch.Tensor:
42
- """Convert PIL Image to tensor (C, H, W) in [0, 1]."""
 
 
 
 
43
  if not PIL_AVAILABLE:
44
  raise ImportError("PIL is required to process PIL Images")
45
 
@@ -63,7 +67,11 @@ def _pil_to_tensor(image: "Image.Image") -> torch.Tensor:
63
 
64
 
65
  def _numpy_to_tensor(arr: "np.ndarray") -> torch.Tensor:
66
- """Convert numpy array to tensor (C, H, W) in [0, 1]."""
 
 
 
 
67
  if not NUMPY_AVAILABLE:
68
  raise ImportError("NumPy is required to process numpy arrays")
69
 
@@ -89,8 +97,14 @@ def standardize_input(
89
  images: Union[torch.Tensor, List[torch.Tensor], "Image.Image", List["Image.Image"], "np.ndarray", List["np.ndarray"]],
90
  device: Optional[torch.device] = None,
91
  ) -> torch.Tensor:
92
- """
93
- Convert input images to standardized tensor format.
 
 
 
 
 
 
94
 
95
  Args:
96
  images: Input as:
@@ -151,18 +165,89 @@ def standardize_input(
151
 
152
  return images
153
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
- def rgb_to_grayscale(images: torch.Tensor) -> torch.Tensor:
 
 
 
 
 
 
 
 
 
 
 
156
  """
157
- Convert RGB images to grayscale using luminance formula.
158
 
159
- Y = 0.299 * R + 0.587 * G + 0.114 * B
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
  Args:
162
- images: Tensor of shape (B, 3, H, W)
163
 
164
  Returns:
165
- Tensor of shape (B, 1, H, W)
166
  """
167
  # Luminance weights
168
  weights = torch.tensor([0.299, 0.587, 0.114], device=images.device, dtype=images.dtype)
@@ -177,11 +262,15 @@ def rgb_to_grayscale(images: torch.Tensor) -> torch.Tensor:
177
  # =============================================================================
178
 
179
  def create_sobel_kernels(device: torch.device, dtype: torch.dtype) -> tuple:
180
- """
181
- Create Sobel kernels for gradient computation.
 
 
 
182
 
183
  Returns:
184
- Tuple of (sobel_x, sobel_y) kernels, each of shape (1, 1, 3, 3)
 
185
  """
186
  sobel_x = torch.tensor([
187
  [-1, 0, 1],
@@ -199,14 +288,16 @@ def create_sobel_kernels(device: torch.device, dtype: torch.dtype) -> tuple:
199
 
200
 
201
  def compute_gradients(grayscale: torch.Tensor) -> tuple:
202
- """
203
- Compute image gradients using Sobel filters.
 
204
 
205
  Args:
206
- grayscale: Tensor of shape (B, 1, H, W)
207
 
208
  Returns:
209
- Tuple of (grad_x, grad_y, grad_magnitude)
 
210
  """
211
  sobel_x, sobel_y = create_sobel_kernels(grayscale.device, grayscale.dtype)
212
 
@@ -226,20 +317,28 @@ def compute_radial_symmetry_response(
226
  grad_y: torch.Tensor,
227
  grad_magnitude: torch.Tensor,
228
  ) -> torch.Tensor:
229
- """
230
- Compute radial symmetry response for circle detection.
 
 
 
 
 
 
 
 
231
 
232
- This weights regions that are:
233
- 1. Dark (low intensity - typical of pupil/iris)
234
- 2. Have strong radial gradients pointing inward
235
 
236
  Args:
237
- grayscale: Grayscale image (B, 1, H, W)
238
- grad_x, grad_y: Gradient components
239
- grad_magnitude: Gradient magnitude
 
240
 
241
  Returns:
242
- Radial symmetry response map (B, 1, H, W)
243
  """
244
  B, _, H, W = grayscale.shape
245
  device = grayscale.device
@@ -311,15 +410,21 @@ def compute_radial_symmetry_response(
311
 
312
 
313
  def soft_argmax_2d(response: torch.Tensor, temperature: float = 0.1) -> tuple:
314
- """
315
- Compute soft argmax to find the center coordinates.
 
 
 
 
 
 
316
 
317
  Args:
318
- response: Response map (B, 1, H, W)
319
- temperature: Softmax temperature (lower = sharper)
320
 
321
  Returns:
322
- Tuple of (cx, cy) each of shape (B,)
323
  """
324
  B, _, H, W = response.shape
325
  device = response.device
@@ -347,17 +452,19 @@ def estimate_eye_center(
347
  images: torch.Tensor,
348
  softmax_temperature: float = 0.1,
349
  ) -> tuple:
350
- """
351
- Estimate the center of the eye region in each image.
 
352
 
353
  Args:
354
- images: RGB images of shape (B, 3, H, W)
355
- softmax_temperature: Temperature for soft argmax (lower = sharper peak detection,
356
- higher = more averaging). Typical range: 0.01-1.0. Default 0.1 works well
357
- for most fundus images. Use higher values (0.3-0.5) for noisy images.
 
358
 
359
  Returns:
360
- Tuple of (cx, cy) each of shape (B,) in pixel coordinates
361
  """
362
  grayscale = rgb_to_grayscale(images)
363
  grad_x, grad_y, grad_magnitude = compute_gradients(grayscale)
@@ -380,19 +487,28 @@ def estimate_radius(
380
  min_radius_frac: float = 0.1,
381
  max_radius_frac: float = 0.5,
382
  ) -> torch.Tensor:
383
- """
384
- Estimate the radius of the eye region by analyzing radial intensity profiles.
 
 
 
 
 
 
 
 
 
385
 
386
  Args:
387
- images: RGB images (B, 3, H, W)
388
- cx, cy: Center coordinates (B,)
389
- num_radii: Number of radius samples
390
- num_angles: Number of angular samples
391
- min_radius_frac: Minimum radius as fraction of image size
392
- max_radius_frac: Maximum radius as fraction of image size
393
 
394
  Returns:
395
- Estimated radius for each image (B,)
396
  """
397
  B, _, H, W = images.shape
398
  device = images.device
@@ -472,18 +588,26 @@ def compute_crop_box(
472
  scale_factor: float = 1.1,
473
  allow_overflow: bool = False,
474
  ) -> tuple:
475
- """
476
- Compute square bounding box for cropping.
 
 
 
 
 
 
 
 
477
 
478
  Args:
479
- cx, cy: Center coordinates (B,)
480
- radius: Estimated radius (B,)
481
- H, W: Image dimensions
482
- scale_factor: Multiply radius by this factor for padding
483
- allow_overflow: If True, don't clamp box to image bounds (for pre-cropped images)
484
 
485
  Returns:
486
- Tuple of (x1, y1, x2, y2) each of shape (B,)
487
  """
488
  # Compute half side length
489
  half_side = radius * scale_factor
@@ -536,19 +660,23 @@ def batch_crop_and_resize(
536
  output_size: int,
537
  padding_mode: str = 'border',
538
  ) -> torch.Tensor:
539
- """
540
- Crop and resize images using grid_sample for GPU efficiency.
 
 
 
 
 
 
541
 
542
  Args:
543
- images: Input images (B, C, H, W)
544
- x1, y1, x2, y2: Crop coordinates (B,) - can extend beyond image bounds
545
- output_size: Output square size
546
- padding_mode: How to handle out-of-bounds sampling:
547
- - 'border': repeat edge pixels (default)
548
- - 'zeros': fill with black (useful for pre-cropped images)
549
 
550
  Returns:
551
- Cropped and resized images (B, C, output_size, output_size)
552
  """
553
  B, C, H, W = images.shape
554
  device = images.device
@@ -584,13 +712,101 @@ def batch_crop_and_resize(
584
 
585
  return cropped
586
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
587
 
588
  # =============================================================================
589
  # PHASE 4: CLAHE (Torch-Native)
590
  # =============================================================================
591
 
592
  def _srgb_to_linear(rgb: torch.Tensor) -> torch.Tensor:
593
- """Convert sRGB to linear RGB."""
 
 
 
594
  threshold = 0.04045
595
  linear = torch.where(
596
  rgb <= threshold,
@@ -601,7 +817,11 @@ def _srgb_to_linear(rgb: torch.Tensor) -> torch.Tensor:
601
 
602
 
603
  def _linear_to_srgb(linear: torch.Tensor) -> torch.Tensor:
604
- """Convert linear RGB to sRGB."""
 
 
 
 
605
  threshold = 0.0031308
606
  srgb = torch.where(
607
  linear <= threshold,
@@ -612,21 +832,26 @@ def _linear_to_srgb(linear: torch.Tensor) -> torch.Tensor:
612
 
613
 
614
  def rgb_to_lab(images: torch.Tensor) -> tuple:
615
- """
616
- Convert sRGB images to CIE LAB color space.
 
 
617
 
618
- This is a proper LAB conversion that:
619
- 1. Converts sRGB to linear RGB
620
- 2. Converts linear RGB to XYZ
621
- 3. Converts XYZ to LAB
 
 
622
 
623
  Args:
624
- images: RGB images (B, C, H, W) in [0, 1] sRGB
625
 
626
  Returns:
627
- Tuple of (L, a, b) where:
628
- - L: Luminance in [0, 1] (normalized from [0, 100])
629
- - a, b: Chrominance (normalized to roughly [-0.5, 0.5])
 
630
  """
631
  device = images.device
632
  dtype = images.dtype
@@ -679,15 +904,18 @@ def rgb_to_lab(images: torch.Tensor) -> tuple:
679
 
680
 
681
  def lab_to_rgb(L: torch.Tensor, a: torch.Tensor, b_ch: torch.Tensor) -> torch.Tensor:
682
- """
683
- Convert CIE LAB to sRGB.
 
 
684
 
685
  Args:
686
- L: Luminance in [0, 1] (normalized from [0, 100])
687
- a, b_ch: Chrominance (normalized, roughly [0, 1])
 
688
 
689
  Returns:
690
- RGB images (B, 3, H, W) in [0, 1] sRGB
691
  """
692
  # Denormalize
693
  L_lab = L * 100.0
@@ -735,15 +963,22 @@ def compute_histogram(
735
  tensor: torch.Tensor,
736
  num_bins: int = 256,
737
  ) -> torch.Tensor:
738
- """
739
- Compute histogram for a batch of single-channel images.
 
 
 
 
 
 
 
740
 
741
  Args:
742
- tensor: Input tensor (B, 1, H, W) with values in [0, 1]
743
- num_bins: Number of histogram bins
744
 
745
  Returns:
746
- Histograms (B, num_bins)
747
  """
748
  B = tensor.shape[0]
749
  device = tensor.device
@@ -770,16 +1005,21 @@ def clahe_single_tile(
770
  clip_limit: float,
771
  num_bins: int = 256,
772
  ) -> torch.Tensor:
773
- """
774
- Apply CLAHE to a single tile.
 
 
 
 
 
775
 
776
  Args:
777
- tile: Input tile (B, 1, tile_h, tile_w)
778
- clip_limit: Histogram clip limit
779
- num_bins: Number of histogram bins
780
 
781
  Returns:
782
- CDF lookup table (B, num_bins)
783
  """
784
  B, _, tile_h, tile_w = tile.shape
785
  device = tile.device
@@ -815,17 +1055,30 @@ def apply_clahe_vectorized(
815
  clip_limit: float = 2.0,
816
  num_bins: int = 256,
817
  ) -> torch.Tensor:
818
- """
819
- Vectorized CLAHE implementation (more efficient for GPU).
 
 
 
 
 
 
 
 
 
 
 
 
 
820
 
821
  Args:
822
- images: Input images (B, C, H, W)
823
- grid_size: Number of tiles in each dimension
824
- clip_limit: Histogram clip limit
825
- num_bins: Number of histogram bins
826
 
827
  Returns:
828
- CLAHE-enhanced images (B, C, H, W)
829
  """
830
  B, C, H, W = images.shape
831
  device = images.device
@@ -955,17 +1208,17 @@ def resize_images(
955
  mode: str = 'bilinear',
956
  antialias: bool = True,
957
  ) -> torch.Tensor:
958
- """
959
- Resize images to target size.
960
 
961
  Args:
962
- images: Input images (B, C, H, W)
963
- size: Target size (square)
964
- mode: Interpolation mode
965
- antialias: Whether to use antialiasing
 
966
 
967
  Returns:
968
- Resized images (B, C, size, size)
969
  """
970
  return F.interpolate(
971
  images,
@@ -982,17 +1235,17 @@ def normalize_images(
982
  std: Optional[List[float]] = None,
983
  mode: str = 'imagenet',
984
  ) -> torch.Tensor:
985
- """
986
- Normalize images.
987
 
988
  Args:
989
- images: Input images (B, C, H, W) in [0, 1]
990
- mean: Custom mean (per channel)
991
- std: Custom std (per channel)
992
- mode: 'imagenet', 'none', or 'custom'
 
993
 
994
  Returns:
995
- Normalized images
996
  """
997
  if mode == 'none':
998
  return images
@@ -1020,16 +1273,30 @@ def normalize_images(
1020
  # =============================================================================
1021
 
1022
  class EyeCLAHEImageProcessor(BaseImageProcessor):
1023
- """
1024
- GPU-native image processor for Color Fundus Photography (CFP) images.
1025
-
1026
- This processor:
1027
- 1. Localizes the eye region using gradient-based radial symmetry
1028
- 2. Crops to a border-minimized square centered on the eye
1029
- 3. Applies CLAHE for contrast enhancement
1030
- 4. Resizes and normalizes for vision model input
1031
-
1032
- All operations are implemented in pure PyTorch and are CUDA-compatible.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1033
  """
1034
 
1035
  model_input_names = ["pixel_values"]
@@ -1092,30 +1359,40 @@ class EyeCLAHEImageProcessor(BaseImageProcessor):
1092
  def preprocess(
1093
  self,
1094
  images,
 
1095
  return_tensors: str = "pt",
1096
  device: Optional[Union[str, torch.device]] = None,
1097
  **kwargs,
1098
  ) -> BatchFeature:
1099
- """
1100
- Preprocess images for model input.
 
 
 
1101
 
1102
  Args:
1103
- images: Input images in any of these formats:
1104
- - torch.Tensor: (C,H,W), (B,C,H,W), or list of tensors
1105
- - PIL.Image.Image: single image or list of images
1106
- - numpy.ndarray: (H,W,C), (B,H,W,C), or list of arrays
1107
- return_tensors: Return type (only "pt" supported)
1108
- device: Target device for processing (e.g., "cuda", "cpu")
 
 
 
1109
 
1110
  Returns:
1111
- BatchFeature with keys:
1112
- - 'pixel_values': Processed images (B, C, size, size)
1113
- - 'scale_x', 'scale_y': Scale factors for coordinate mapping (B,)
1114
- - 'offset_x', 'offset_y': Offsets for coordinate mapping (B,)
1115
-
1116
- To map coordinates from processed image back to original:
1117
- orig_x = offset_x + cropped_x * scale_x
1118
- orig_y = offset_y + cropped_y * scale_y
 
 
 
1119
  """
1120
  if return_tensors != "pt":
1121
  raise ValueError("Only 'pt' (PyTorch) tensors are supported")
@@ -1133,6 +1410,9 @@ class EyeCLAHEImageProcessor(BaseImageProcessor):
1133
 
1134
  # Standardize input
1135
  images = standardize_input(images, device)
 
 
 
1136
  B, C, H_orig, W_orig = images.shape
1137
 
1138
  if self.do_crop:
@@ -1164,6 +1444,13 @@ class EyeCLAHEImageProcessor(BaseImageProcessor):
1164
  # Use 'zeros' padding when allow_overflow is True to fill out-of-bounds with black
1165
  padding_mode = 'zeros' if self.allow_overflow else 'border'
1166
  images = batch_crop_and_resize(images, x1, y1, x2, y2, self.size, padding_mode=padding_mode)
 
 
 
 
 
 
 
1167
  else:
1168
  # Just resize - no crop
1169
  # Compute coordinate mapping for direct resize
@@ -1173,6 +1460,10 @@ class EyeCLAHEImageProcessor(BaseImageProcessor):
1173
  offset_y = torch.zeros(B, device=device, dtype=images.dtype)
1174
  images = resize_images(images, self.size)
1175
 
 
 
 
 
1176
  # Apply CLAHE
1177
  if self.do_clahe:
1178
  images = apply_clahe_vectorized(
@@ -1190,25 +1481,23 @@ class EyeCLAHEImageProcessor(BaseImageProcessor):
1190
  )
1191
 
1192
  # Return with coordinate mapping information (flattened structure)
1193
- return BatchFeature(
1194
- data={
1195
- "pixel_values": images,
1196
- "scale_x": scale_x,
1197
- "scale_y": scale_y,
1198
- "offset_x": offset_x,
1199
- "offset_y": offset_y,
1200
- },
1201
- tensor_type="pt"
1202
- )
1203
 
1204
  def __call__(
1205
  self,
1206
  images: Union[torch.Tensor, List[torch.Tensor]],
1207
  **kwargs,
1208
  ) -> BatchFeature:
1209
- """
1210
- Process images (alias for preprocess).
1211
- """
1212
  return self.preprocess(images, **kwargs)
1213
 
1214
 
 
39
  # =============================================================================
40
 
41
  def _pil_to_tensor(image: "Image.Image") -> torch.Tensor:
42
+ """Convert a single PIL Image to a float32 tensor of shape (C, H, W) in [0, 1].
43
+
44
+ Converts to RGB if not already. Uses numpy as intermediate when available,
45
+ otherwise falls back to manual pixel extraction.
46
+ """
47
  if not PIL_AVAILABLE:
48
  raise ImportError("PIL is required to process PIL Images")
49
 
 
67
 
68
 
69
  def _numpy_to_tensor(arr: "np.ndarray") -> torch.Tensor:
70
+ """Convert a single numpy array to a float32 tensor of shape (C, H, W) in [0, 1].
71
+
72
+ Handles grayscale (H, W), HWC (H, W, C) with C in {1, 3, 4}, and uint8/float inputs.
73
+ Makes a copy to avoid sharing memory with the source array.
74
+ """
75
  if not NUMPY_AVAILABLE:
76
  raise ImportError("NumPy is required to process numpy arrays")
77
 
 
97
  images: Union[torch.Tensor, List[torch.Tensor], "Image.Image", List["Image.Image"], "np.ndarray", List["np.ndarray"]],
98
  device: Optional[torch.device] = None,
99
  ) -> torch.Tensor:
100
+ """Convert heterogeneous image inputs to a standardized (B, C, H, W) float32 tensor in [0, 1].
101
+
102
+ Accepts torch.Tensor, PIL.Image, numpy.ndarray, or lists thereof. Integer-typed
103
+ inputs (uint8) are scaled to [0, 1]. The output is clamped to [0, 1].
104
+
105
+ Note: All images in a list must have the same spatial dimensions (required by torch.stack).
106
+ A single numpy array with ndim==3 is treated as a single HWC image if the last dimension
107
+ is in {1, 3, 4}; otherwise it falls through to the tensor path (assumed CHW).
108
 
109
  Args:
110
  images: Input as:
 
165
 
166
  return images
167
 
168
def standardize_mask_input(
    masks: Union[
        torch.Tensor,
        List[torch.Tensor],
        "Image.Image",
        List["Image.Image"],
        "np.ndarray",
        List["np.ndarray"],
    ],
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Convert heterogeneous mask inputs to a standardized (B, 1, H, W) tensor.

    Unlike ``standardize_input``, this preserves the original dtype (typically integer
    label values) and does **not** normalize to [0, 1].

    Accepts torch.Tensor, PIL.Image, numpy.ndarray, or lists thereof.
    A single 2-D input is treated as (H, W) and expanded to (1, 1, H, W).

    Args:
        masks: Input masks in any supported format.
        device: Target device.

    Returns:
        Tensor of shape (B, 1, H, W) with original dtype preserved.

    Raises:
        ImportError: If a PIL mask is supplied but NumPy is unavailable
            (the PIL -> tensor conversion goes through NumPy).
        TypeError: If a list element is not a PIL Image, numpy array, or tensor.
        ValueError: If the resulting tensor is not 2-D, 3-D, or 4-D.
    """
    # Promote single inputs to one-element lists so they share the list path.
    if PIL_AVAILABLE and isinstance(masks, Image.Image):
        masks = [masks]

    if NUMPY_AVAILABLE and isinstance(masks, np.ndarray) and masks.ndim == 2:
        masks = [masks]

    # Convert list inputs
    if isinstance(masks, list):
        converted = []
        for m in masks:
            if PIL_AVAILABLE and isinstance(m, Image.Image):
                # PIL mask -> numpy -> tensor. NumPy is required for this path;
                # raise a clear ImportError instead of a NameError on `np`.
                if not NUMPY_AVAILABLE:
                    raise ImportError("NumPy is required to process PIL masks")
                converted.append(torch.from_numpy(np.array(m)))
            elif NUMPY_AVAILABLE and isinstance(m, np.ndarray):
                # Copy to a contiguous buffer: torch.from_numpy rejects
                # negative-stride views (e.g. flipped arrays), and copying
                # avoids sharing memory with the caller's array (consistent
                # with _numpy_to_tensor).
                converted.append(torch.from_numpy(np.ascontiguousarray(m)))
            elif isinstance(m, torch.Tensor):
                converted.append(m)
            else:
                raise TypeError(f"Unsupported mask type: {type(m)}")

        masks = torch.stack(converted)

    elif NUMPY_AVAILABLE and isinstance(masks, np.ndarray):
        masks = torch.from_numpy(np.ascontiguousarray(masks))

    # At this point masks is a torch.Tensor; normalize rank to (B, 1, H, W).
    if masks.dim() == 2:
        # (H, W) -> (1, 1, H, W)
        masks = masks.unsqueeze(0).unsqueeze(0)
    elif masks.dim() == 3:
        # (B, H, W) -> (B, 1, H, W)
        masks = masks.unsqueeze(1)
    elif masks.dim() == 4:
        # Assume already (B, C, H, W)
        pass
    else:
        raise ValueError(f"Invalid mask shape: {masks.shape}")

    # Move to device
    if device is not None:
        masks = masks.to(device)

    return masks
241
+
242
+
243
+ def rgb_to_grayscale(images: torch.Tensor) -> torch.Tensor:
244
+ """Convert RGB images to grayscale via ITU-R BT.601 luminance: Y = 0.299R + 0.587G + 0.114B.
245
 
246
  Args:
247
+ images: Tensor of shape (B, 3, H, W) in any value range.
248
 
249
  Returns:
250
+ Tensor of shape (B, 1, H, W) in the same value range as input.
251
  """
252
  # Luminance weights
253
  weights = torch.tensor([0.299, 0.587, 0.114], device=images.device, dtype=images.dtype)
 
262
  # =============================================================================
263
 
264
  def create_sobel_kernels(device: torch.device, dtype: torch.dtype) -> tuple:
265
+ """Create 3x3 Sobel edge-detection kernels for horizontal and vertical gradients.
266
+
267
+ Args:
268
+ device: Target device for the kernels.
269
+ dtype: Target dtype for the kernels.
270
 
271
  Returns:
272
+ Tuple of (sobel_x, sobel_y) kernels, each of shape (1, 1, 3, 3),
273
+ suitable for use with ``F.conv2d`` on single-channel input.
274
  """
275
  sobel_x = torch.tensor([
276
  [-1, 0, 1],
 
288
 
289
 
290
  def compute_gradients(grayscale: torch.Tensor) -> tuple:
291
+ """Compute horizontal and vertical image gradients using 3x3 Sobel filters.
292
+
293
+ Uses reflect-free padding=1 (zero-padded convolution) to maintain spatial size.
294
 
295
  Args:
296
+ grayscale: Single-channel images of shape (B, 1, H, W).
297
 
298
  Returns:
299
+ Tuple of (grad_x, grad_y, grad_magnitude), each (B, 1, H, W).
300
+ ``grad_magnitude`` = sqrt(grad_x^2 + grad_y^2 + 1e-8).
301
  """
302
  sobel_x, sobel_y = create_sobel_kernels(grayscale.device, grayscale.dtype)
303
 
 
317
  grad_y: torch.Tensor,
318
  grad_magnitude: torch.Tensor,
319
  ) -> torch.Tensor:
320
+ """Compute a radial-symmetry response map for circular-region detection.
321
+
322
+ The algorithm:
323
+ 1. Estimates an initial center as the intensity-weighted center of mass of
324
+ dark regions (squared inverse intensity).
325
+ 2. For each pixel, computes the dot product between the normalized gradient
326
+ vector and the unit vector pointing toward the estimated center.
327
+ 3. Weights this alignment score by gradient magnitude and darkness.
328
+ 4. Smooths the response with a separable Gaussian whose sigma is
329
+ proportional to the image size (kernel_size = max(H,W)//8, sigma = kernel_size/6).
330
 
331
+ High response indicates pixels whose gradients point radially inward toward
332
+ a dark center characteristic of the fundus disc boundary.
 
333
 
334
  Args:
335
+ grayscale: Grayscale images (B, 1, H, W) in [0, 1].
336
+ grad_x: Horizontal gradient (B, 1, H, W).
337
+ grad_y: Vertical gradient (B, 1, H, W).
338
+ grad_magnitude: Gradient magnitude (B, 1, H, W).
339
 
340
  Returns:
341
+ Smoothed radial symmetry response map (B, 1, H, W).
342
  """
343
  B, _, H, W = grayscale.shape
344
  device = grayscale.device
 
410
 
411
 
412
  def soft_argmax_2d(response: torch.Tensor, temperature: float = 0.1) -> tuple:
413
+ """Find the sub-pixel peak location in a response map via softmax-weighted coordinates.
414
+
415
+ Divides the flattened response by ``temperature`` before applying softmax, then
416
+ computes the weighted mean of the (x, y) coordinate grids. Lower temperature yields
417
+ a sharper, more argmax-like result; higher temperature yields a broader average.
418
+
419
+ Caution: Very low temperatures (< 0.01) combined with large response magnitudes
420
+ can cause numerical overflow in the softmax exponential.
421
 
422
  Args:
423
+ response: Response map (B, 1, H, W).
424
+ temperature: Softmax temperature. Default 0.1.
425
 
426
  Returns:
427
+ Tuple of (cx, cy), each of shape (B,), in pixel coordinates.
428
  """
429
  B, _, H, W = response.shape
430
  device = response.device
 
452
  images: torch.Tensor,
453
  softmax_temperature: float = 0.1,
454
  ) -> tuple:
455
+ """Estimate the center of the fundus/eye disc in each image.
456
+
457
+ Pipeline: RGB → grayscale → Sobel gradients → radial symmetry response → soft argmax.
458
 
459
  Args:
460
+ images: RGB images of shape (B, 3, H, W) in [0, 1].
461
+ softmax_temperature: Temperature for the soft-argmax peak finder.
462
+ Lower values (0.01-0.1) give sharper localization; higher values
463
+ (0.3-0.5) give broader averaging, useful for noisy or low-contrast images.
464
+ Default 0.1.
465
 
466
  Returns:
467
+ Tuple of (cx, cy), each of shape (B,), in pixel coordinates.
468
  """
469
  grayscale = rgb_to_grayscale(images)
470
  grad_x, grad_y, grad_magnitude = compute_gradients(grayscale)
 
487
  min_radius_frac: float = 0.1,
488
  max_radius_frac: float = 0.5,
489
  ) -> torch.Tensor:
490
+ """Estimate the radius of the fundus disc by analyzing radial intensity profiles.
491
+
492
+ Samples grayscale intensity along ``num_angles`` rays emanating from ``(cx, cy)``
493
+ at ``num_radii`` radial distances. The per-radius mean intensity across all angles
494
+ gives a 1-D radial profile. The discrete derivative of this profile is linearly
495
+ weighted by radius (range 0.5–1.5) to bias toward the outer fundus boundary
496
+ rather than the smaller pupil boundary. The radius at the strongest weighted
497
+ negative gradient is selected as the disc edge.
498
+
499
+ Uses ``F.grid_sample`` with bilinear interpolation and border padding for
500
+ sub-pixel sampling.
501
 
502
  Args:
503
+ images: RGB images (B, 3, H, W) in [0, 1].
504
+ cx, cy: Center coordinates (B,) in pixel units.
505
+ num_radii: Number of radial sample points. Default 100.
506
+ num_angles: Number of angular sample rays. Default 36.
507
+ min_radius_frac: Minimum search radius as fraction of min(H, W). Default 0.1.
508
+ max_radius_frac: Maximum search radius as fraction of min(H, W). Default 0.5.
509
 
510
  Returns:
511
+ Estimated radius for each image (B,), clamped to [min_radius, max_radius].
512
  """
513
  B, _, H, W = images.shape
514
  device = images.device
 
588
  scale_factor: float = 1.1,
589
  allow_overflow: bool = False,
590
  ) -> tuple:
591
+ """Compute a square bounding box centered on the detected eye.
592
+
593
+ The half-side length is ``radius * scale_factor``. When ``allow_overflow`` is
594
+ False, the box is clamped to the image bounds and then made square by shrinking
595
+ to the shorter side and re-centering. The resulting box is guaranteed to be
596
+ square and fully within [0, W-1] x [0, H-1].
597
+
598
+ When ``allow_overflow`` is True the raw (possibly out-of-bounds) box is
599
+ returned, which is useful for images where the fundus disc is partially
600
+ clipped; out-of-bounds regions will be zero-filled during grid_sample.
601
 
602
  Args:
603
+ cx, cy: Detected eye center coordinates (B,).
604
+ radius: Estimated disc radius (B,).
605
+ H, W: Spatial dimensions of the source images.
606
+ scale_factor: Padding multiplier applied to ``radius``. Default 1.1.
607
+ allow_overflow: Skip clamping / squareness enforcement. Default False.
608
 
609
  Returns:
610
+ Tuple of (x1, y1, x2, y2), each of shape (B,), in pixel coordinates.
611
  """
612
  # Compute half side length
613
  half_side = radius * scale_factor
 
660
  output_size: int,
661
  padding_mode: str = 'border',
662
  ) -> torch.Tensor:
663
+ """Crop and resize images to a square using ``F.grid_sample`` (GPU-friendly).
664
+
665
+ Builds a regular output grid in [0, 1]^2, maps it to the source rectangle
666
+ [x1, x2] x [y1, y2] via affine scaling, normalizes to [-1, 1] for
667
+ ``grid_sample``, and samples with bilinear interpolation (``align_corners=True``).
668
+
669
+ Crop coordinates may extend beyond image bounds; the ``padding_mode``
670
+ controls how out-of-bounds pixels are filled.
671
 
672
  Args:
673
+ images: Input images (B, C, H, W).
674
+ x1, y1, x2, y2: Crop box corners (B,). May exceed [0, W-1] / [0, H-1].
675
+ output_size: Side length of the square output.
676
+ padding_mode: ``'border'`` (repeat edge, default) or ``'zeros'`` (black fill).
 
 
677
 
678
  Returns:
679
+ Cropped and resized images (B, C, output_size, output_size).
680
  """
681
  B, C, H, W = images.shape
682
  device = images.device
 
712
 
713
  return cropped
714
 
715
def batch_crop_and_resize_mask(
    masks: torch.Tensor,  # (B, 1, H, W)
    x1: torch.Tensor,
    y1: torch.Tensor,
    x2: torch.Tensor,
    y2: torch.Tensor,
    output_size: int,
    padding_mode: str = "zeros",
) -> torch.Tensor:
    """Crop and resize segmentation masks using nearest-neighbor sampling.

    Same spatial transform as ``batch_crop_and_resize`` but uses ``mode='nearest'``
    to preserve discrete label values. The output is rounded and cast to ``torch.long``
    to guard against floating-point drift in ``grid_sample``.

    Args:
        masks: Label masks (B, 1, H, W) — any dtype (converted to float internally).
        x1, y1, x2, y2: Crop box corners (B,). May exceed image bounds.
        output_size: Side length of the square output.
        padding_mode: ``'zeros'`` (background = 0, default) or ``'border'`` (repeat edge).

    Returns:
        Cropped and resized masks (B, 1, output_size, output_size) as ``torch.long``.
    """
    B, C, H, W = masks.shape
    device = masks.device

    # grid_sample requires floating point input
    masks_f = masks.float()

    # Regular output grid in [0, 1]
    coords = torch.linspace(0, 1, output_size, device=device)
    out_y, out_x = torch.meshgrid(coords, coords, indexing="ij")
    out_grid = torch.stack([out_x, out_y], dim=-1)          # (S, S, 2)
    out_grid = out_grid.unsqueeze(0).expand(B, -1, -1, -1)  # (B, S, S, 2)

    # Reshape crop boxes for broadcasting against the grid
    x1 = x1.view(B, 1, 1, 1)
    y1 = y1.view(B, 1, 1, 1)
    x2 = x2.view(B, 1, 1, 1)
    y2 = y2.view(B, 1, 1, 1)

    # Map [0, 1] -> source pixel coordinates
    sample_x = x1 + out_grid[..., 0:1] * (x2 - x1)
    sample_y = y1 + out_grid[..., 1:2] * (y2 - y1)

    # Pixel coordinates -> [-1, 1] (align_corners=True convention)
    sample_x = 2.0 * sample_x / (W - 1) - 1.0
    sample_y = 2.0 * sample_y / (H - 1) - 1.0

    grid = torch.cat([sample_x, sample_y], dim=-1)

    # Nearest-neighbor sampling with caller-specified padding
    cropped = F.grid_sample(
        masks_f,
        grid,
        mode="nearest",
        padding_mode=padding_mode,
        align_corners=True,
    )

    # Round before converting to handle floating point errors from grid_sample.
    # Even with mode="nearest", grid_sample can produce values like 0.9999999
    # which would truncate to 0 instead of rounding to 1.
    return cropped.round().long()
800
 
801
  # =============================================================================
802
  # PHASE 4: CLAHE (Torch-Native)
803
  # =============================================================================
804
 
805
  def _srgb_to_linear(rgb: torch.Tensor) -> torch.Tensor:
806
+ """Apply the sRGB electro-optical transfer function (EOTF) to convert sRGB to linear RGB.
807
+
808
+ Uses the IEC 61966-2-1 piecewise formula with threshold 0.04045.
809
+ """
810
  threshold = 0.04045
811
  linear = torch.where(
812
  rgb <= threshold,
 
817
 
818
 
819
  def _linear_to_srgb(linear: torch.Tensor) -> torch.Tensor:
820
+ """Apply the inverse sRGB EOTF to convert linear RGB to sRGB.
821
+
822
+ Uses the IEC 61966-2-1 piecewise formula with threshold 0.0031308.
823
+ Input must be non-negative; negative values will produce NaN from the power function.
824
+ """
825
  threshold = 0.0031308
826
  srgb = torch.where(
827
  linear <= threshold,
 
832
 
833
 
834
  def rgb_to_lab(images: torch.Tensor) -> tuple:
835
+ """Convert sRGB images to CIE LAB colour space (D65 illuminant).
836
+
837
+ Conversion chain: sRGB → linear RGB → CIE XYZ → CIE LAB.
838
+ The raw LAB values are rescaled for internal convenience:
839
 
840
+ - L [0, 100] → L / 100 → [0, 1]
841
+ - a ~[-128, 127] → a / 256 + 0.5 → ~[0, 1]
842
+ - b ~[-128, 127] → b / 256 + 0.5 → ~[0, 1]
843
+
844
+ These normalised values are **not** standard LAB; use ``lab_to_rgb`` to
845
+ invert them back to sRGB.
846
 
847
  Args:
848
+ images: RGB images (B, 3, H, W) in [0, 1] sRGB.
849
 
850
  Returns:
851
+ Tuple of (L, a, b_ch), each (B, 1, H, W):
852
+ - L: Normalised luminance in [0, 1].
853
+ - a: Normalised green–red chrominance, roughly [0, 1].
854
+ - b_ch: Normalised blue–yellow chrominance, roughly [0, 1].
855
  """
856
  device = images.device
857
  dtype = images.dtype
 
904
 
905
 
906
  def lab_to_rgb(L: torch.Tensor, a: torch.Tensor, b_ch: torch.Tensor) -> torch.Tensor:
907
+ """Convert normalised CIE LAB back to sRGB (inverse of ``rgb_to_lab``).
908
+
909
+ Denormalisation: L*100, (a-0.5)*256, (b_ch-0.5)*256, then LAB → XYZ → linear RGB → sRGB.
910
+ Output is clamped to [0, 1].
911
 
912
  Args:
913
+ L: Normalised luminance (B, 1, H, W) in [0, 1].
914
+ a: Normalised green–red chrominance (B, 1, H, W), roughly [0, 1].
915
+ b_ch: Normalised blue–yellow chrominance (B, 1, H, W), roughly [0, 1].
916
 
917
  Returns:
918
+ sRGB images (B, 3, H, W) clamped to [0, 1].
919
  """
920
  # Denormalize
921
  L_lab = L * 100.0
 
963
  tensor: torch.Tensor,
964
  num_bins: int = 256,
965
  ) -> torch.Tensor:
966
+ """Compute per-image histograms for a batch of single-channel images.
967
+
968
+ Bins are uniformly spaced over [0, 1]. Each pixel is assigned to a bin via
969
+ ``floor(value * (num_bins - 1))``, accumulated with ``scatter_add`` in a
970
+ per-sample loop.
971
+
972
+ Note: This function is used only by ``clahe_single_tile``.
973
+ The vectorized CLAHE path (``apply_clahe_vectorized``) computes histograms
974
+ inline for better GPU efficiency.
975
 
976
  Args:
977
+ tensor: Input (B, 1, H, W) with values in [0, 1].
978
+ num_bins: Number of histogram bins. Default 256.
979
 
980
  Returns:
981
+ Histograms of shape (B, num_bins), dtype matching input.
982
  """
983
  B = tensor.shape[0]
984
  device = tensor.device
 
1005
  clip_limit: float,
1006
  num_bins: int = 256,
1007
  ) -> torch.Tensor:
1008
+ """Compute the clipped-and-redistributed CDF for a single CLAHE tile.
1009
+
1010
+ Clips the histogram so no bin exceeds ``clip_limit * num_pixels / num_bins``,
1011
+ redistributes the excess uniformly, then computes and min-max normalises the CDF.
1012
+
1013
+ Note: This function is not used by the main pipeline — see
1014
+ ``apply_clahe_vectorized`` which processes all tiles in a single pass.
1015
 
1016
  Args:
1017
+ tile: Single-channel tile images (B, 1, tile_h, tile_w) in [0, 1].
1018
+ clip_limit: Relative clip limit (higher = less contrast limiting).
1019
+ num_bins: Number of histogram bins. Default 256.
1020
 
1021
  Returns:
1022
+ Normalised CDF lookup table (B, num_bins) in [0, 1].
1023
  """
1024
  B, _, tile_h, tile_w = tile.shape
1025
  device = tile.device
 
1055
  clip_limit: float = 2.0,
1056
  num_bins: int = 256,
1057
  ) -> torch.Tensor:
1058
+ """Fully-vectorized CLAHE (Contrast Limited Adaptive Histogram Equalisation).
1059
+
1060
+ For RGB input, converts to CIE LAB, applies CLAHE to the L channel only,
1061
+ then converts back to sRGB. For single-channel input, operates directly.
1062
+
1063
+ Algorithm:
1064
+ 1. Pads the luminance channel to be divisible by ``grid_size`` (reflect padding).
1065
+ 2. Reshapes into ``grid_size x grid_size`` non-overlapping tiles.
1066
+ 3. Computes a histogram per tile via ``scatter_add_`` (fully batched, no loops).
1067
+ 4. Clips each histogram at ``clip_limit * num_pixels / num_bins`` and
1068
+ redistributes excess counts uniformly across all bins.
1069
+ 5. Computes the cumulative distribution function (CDF) per tile and
1070
+ min-max normalises it to [0, 1].
1071
+ 6. Maps each output pixel to the four surrounding tile centres and
1072
+ bilinearly interpolates their CDF values for a smooth result.
1073
 
1074
  Args:
1075
+ images: Input images (B, C, H, W) in [0, 1]. C must be 1 or 3.
1076
+ grid_size: Tile grid resolution (tiles per axis). Default 8.
1077
+ clip_limit: Relative clip limit for histogram clipping. Default 2.0.
1078
+ num_bins: Number of histogram bins. Default 256.
1079
 
1080
  Returns:
1081
+ CLAHE-enhanced images (B, C, H, W) in [0, 1].
1082
  """
1083
  B, C, H, W = images.shape
1084
  device = images.device
 
1208
  mode: str = 'bilinear',
1209
  antialias: bool = True,
1210
  ) -> torch.Tensor:
1211
+ """Resize images to a square target size using ``F.interpolate``.
 
1212
 
1213
  Args:
1214
+ images: Input images (B, C, H, W). Must be float for bilinear/bicubic modes.
1215
+ size: Target side length (output is always square).
1216
+ mode: Interpolation mode (``'bilinear'``, ``'bicubic'``, ``'nearest'``, etc.).
1217
+ Default ``'bilinear'``.
1218
+ antialias: Enable antialiasing for bilinear/bicubic downscaling. Default True.
1219
 
1220
  Returns:
1221
+ Resized images (B, C, size, size).
1222
  """
1223
  return F.interpolate(
1224
  images,
 
1235
  std: Optional[List[float]] = None,
1236
  mode: str = 'imagenet',
1237
  ) -> torch.Tensor:
1238
+ """Channel-wise normalisation: ``(image - mean) / std``.
 
1239
 
1240
  Args:
1241
+ images: Input images (B, C, H, W) in [0, 1].
1242
+ mean: Per-channel means (length C). Required when ``mode='custom'``.
1243
+ std: Per-channel stds (length C). Required when ``mode='custom'``.
1244
+ mode: ``'imagenet'`` (uses ImageNet stats), ``'none'`` (identity), or
1245
+ ``'custom'`` (uses caller-supplied mean/std). Default ``'imagenet'``.
1246
 
1247
  Returns:
1248
+ Normalised images (B, C, H, W). Range depends on mean/std.
1249
  """
1250
  if mode == 'none':
1251
  return images
 
1273
  # =============================================================================
1274
 
1275
  class EyeCLAHEImageProcessor(BaseImageProcessor):
1276
+ """GPU-native Hugging Face image processor for Colour Fundus Photography (CFP).
1277
+
1278
+ Processing pipeline (all steps optional via constructor flags):
1279
+
1280
+ 1. **Eye localisation** (``do_crop=True``): detects the fundus disc centre via
1281
+ gradient-based radial symmetry (dark-region centre-of-mass → Sobel gradients
 1282
+ → radial alignment score → Gaussian smoothing → soft argmax) and estimates the
1283
+ disc radius from the strongest negative radial intensity gradient.
1284
+ 2. **Square crop & resize**: crops a square region around the detected disc
1285
+ (``radius * crop_scale_factor``), optionally allowing overflow beyond image
1286
+ bounds (``allow_overflow``), then resamples to ``size x size`` via bilinear
1287
+ ``grid_sample``. When ``do_crop=False``, the whole image is resized directly.
1288
+ 3. **CLAHE** (``do_clahe=True``): applies Contrast Limited Adaptive Histogram
1289
+ Equalisation to the CIE LAB luminance channel, using a fully-vectorized
1290
+ tile-based implementation with bilinear CDF interpolation.
1291
+ 4. **Normalisation**: channel-wise ``(image - mean) / std`` with configurable
1292
+ mode (ImageNet, custom, or none).
1293
+
1294
+ The processor also returns per-image coordinate-mapping scalars (``scale_x/y``,
1295
+ ``offset_x/y``) so that predictions in processed-image space can be mapped back
1296
+ to original pixel coordinates.
1297
+
1298
+ All operations are pure PyTorch — no OpenCV, PIL, or NumPy at runtime — and are
1299
+ CUDA-compatible and batch-friendly.
1300
  """
1301
 
1302
  model_input_names = ["pixel_values"]
 
1359
  def preprocess(
1360
  self,
1361
  images,
1362
+ masks=None,
1363
  return_tensors: str = "pt",
1364
  device: Optional[Union[str, torch.device]] = None,
1365
  **kwargs,
1366
  ) -> BatchFeature:
1367
+ """Run the full preprocessing pipeline on a batch of images.
1368
+
1369
+ Accepts any combination of torch.Tensor, PIL.Image, or numpy.ndarray inputs
1370
+ (see ``standardize_input`` for format details). Optionally processes
1371
+ accompanying segmentation masks with matching spatial transforms.
1372
 
1373
  Args:
1374
+ images: Input images in any supported format.
1375
+ masks: Optional segmentation masks in any format accepted by
1376
+ ``standardize_mask_input``. They undergo the same crop/resize as images
1377
+ (nearest-neighbour interpolation, label-preserving). Returned as
1378
+ ``torch.long`` under the ``"mask"`` key (or ``None`` if not provided).
1379
+ return_tensors: Only ``"pt"`` is supported.
1380
+ device: Device for all tensor operations (e.g. ``"cuda:0"``).
1381
+ Defaults to the device of the input tensor, or CPU for PIL/numpy.
1382
+ **kwargs: Passed through to ``BaseImageProcessor``.
1383
 
1384
  Returns:
1385
+ ``BatchFeature`` with keys:
1386
+
1387
+ - ``pixel_values`` (B, 3, size, size): Processed float32 images.
1388
+ - ``mask`` (B, 1, size, size) or ``None``: Processed long masks.
1389
+ - ``scale_x``, ``scale_y`` (B,): Per-image scale factors.
1390
+ - ``offset_x``, ``offset_y`` (B,): Per-image offsets.
1391
+
1392
+ Coordinate mapping from processed to original pixel space::
1393
+
1394
+ orig_x = offset_x + proc_x * scale_x
1395
+ orig_y = offset_y + proc_y * scale_y
1396
  """
1397
  if return_tensors != "pt":
1398
  raise ValueError("Only 'pt' (PyTorch) tensors are supported")
 
1410
 
1411
  # Standardize input
1412
  images = standardize_input(images, device)
1413
+ if masks is not None:
1414
+ masks = standardize_mask_input(masks, device)
1415
+
1416
  B, C, H_orig, W_orig = images.shape
1417
 
1418
  if self.do_crop:
 
1444
  # Use 'zeros' padding when allow_overflow is True to fill out-of-bounds with black
1445
  padding_mode = 'zeros' if self.allow_overflow else 'border'
1446
  images = batch_crop_and_resize(images, x1, y1, x2, y2, self.size, padding_mode=padding_mode)
1447
+
1448
+ if masks is not None:
1449
+ masks = batch_crop_and_resize_mask(
1450
+ masks, x1, y1, x2, y2,
1451
+ self.size,
1452
+ padding_mode=padding_mode,
1453
+ )
1454
  else:
1455
  # Just resize - no crop
1456
  # Compute coordinate mapping for direct resize
 
1460
  offset_y = torch.zeros(B, device=device, dtype=images.dtype)
1461
  images = resize_images(images, self.size)
1462
 
1463
+ if masks is not None:
1464
+ # F.interpolate requires float input; cast, resize, then restore long
1465
+ masks = resize_images(masks.float(), self.size, mode="nearest", antialias=False).round().long()
1466
+
1467
  # Apply CLAHE
1468
  if self.do_clahe:
1469
  images = apply_clahe_vectorized(
 
1481
  )
1482
 
1483
  # Return with coordinate mapping information (flattened structure)
1484
+ data = {
1485
+ "pixel_values": images,
1486
+ "scale_x": scale_x,
1487
+ "scale_y": scale_y,
1488
+ "offset_x": offset_x,
1489
+ "offset_y": offset_y,
1490
+ }
1491
+ if masks is not None:
1492
+ data["mask"] = masks
1493
+ return BatchFeature(data=data, tensor_type="pt")
1494
 
1495
  def __call__(
1496
  self,
1497
  images: Union[torch.Tensor, List[torch.Tensor]],
1498
  **kwargs,
1499
  ) -> BatchFeature:
1500
+ """Alias for ``preprocess`` — enables ``processor(images, ...)`` call syntax."""
 
 
1501
  return self.preprocess(images, **kwargs)
1502
 
1503