added mask support, bug fixes
Browse files- image_processing_eye_gpu.py +440 -151
image_processing_eye_gpu.py
CHANGED
|
@@ -39,7 +39,11 @@ except ImportError:
|
|
| 39 |
# =============================================================================
|
| 40 |
|
| 41 |
def _pil_to_tensor(image: "Image.Image") -> torch.Tensor:
|
| 42 |
-
"""Convert PIL Image to tensor (C, H, W) in [0, 1].
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
if not PIL_AVAILABLE:
|
| 44 |
raise ImportError("PIL is required to process PIL Images")
|
| 45 |
|
|
@@ -63,7 +67,11 @@ def _pil_to_tensor(image: "Image.Image") -> torch.Tensor:
|
|
| 63 |
|
| 64 |
|
| 65 |
def _numpy_to_tensor(arr: "np.ndarray") -> torch.Tensor:
|
| 66 |
-
"""Convert numpy array to tensor (C, H, W) in [0, 1].
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
if not NUMPY_AVAILABLE:
|
| 68 |
raise ImportError("NumPy is required to process numpy arrays")
|
| 69 |
|
|
@@ -89,8 +97,14 @@ def standardize_input(
|
|
| 89 |
images: Union[torch.Tensor, List[torch.Tensor], "Image.Image", List["Image.Image"], "np.ndarray", List["np.ndarray"]],
|
| 90 |
device: Optional[torch.device] = None,
|
| 91 |
) -> torch.Tensor:
|
| 92 |
-
"""
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
Args:
|
| 96 |
images: Input as:
|
|
@@ -151,18 +165,89 @@ def standardize_input(
|
|
| 151 |
|
| 152 |
return images
|
| 153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
|
| 155 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
"""
|
| 157 |
-
Convert RGB images to grayscale using luminance formula.
|
| 158 |
|
| 159 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
|
| 161 |
Args:
|
| 162 |
-
images: Tensor of shape (B, 3, H, W)
|
| 163 |
|
| 164 |
Returns:
|
| 165 |
-
Tensor of shape (B, 1, H, W)
|
| 166 |
"""
|
| 167 |
# Luminance weights
|
| 168 |
weights = torch.tensor([0.299, 0.587, 0.114], device=images.device, dtype=images.dtype)
|
|
@@ -177,11 +262,15 @@ def rgb_to_grayscale(images: torch.Tensor) -> torch.Tensor:
|
|
| 177 |
# =============================================================================
|
| 178 |
|
| 179 |
def create_sobel_kernels(device: torch.device, dtype: torch.dtype) -> tuple:
|
| 180 |
-
"""
|
| 181 |
-
|
|
|
|
|
|
|
|
|
|
| 182 |
|
| 183 |
Returns:
|
| 184 |
-
Tuple of (sobel_x, sobel_y) kernels, each of shape (1, 1, 3, 3)
|
|
|
|
| 185 |
"""
|
| 186 |
sobel_x = torch.tensor([
|
| 187 |
[-1, 0, 1],
|
|
@@ -199,14 +288,16 @@ def create_sobel_kernels(device: torch.device, dtype: torch.dtype) -> tuple:
|
|
| 199 |
|
| 200 |
|
| 201 |
def compute_gradients(grayscale: torch.Tensor) -> tuple:
|
| 202 |
-
"""
|
| 203 |
-
|
|
|
|
| 204 |
|
| 205 |
Args:
|
| 206 |
-
grayscale:
|
| 207 |
|
| 208 |
Returns:
|
| 209 |
-
Tuple of (grad_x, grad_y, grad_magnitude)
|
|
|
|
| 210 |
"""
|
| 211 |
sobel_x, sobel_y = create_sobel_kernels(grayscale.device, grayscale.dtype)
|
| 212 |
|
|
@@ -226,20 +317,28 @@ def compute_radial_symmetry_response(
|
|
| 226 |
grad_y: torch.Tensor,
|
| 227 |
grad_magnitude: torch.Tensor,
|
| 228 |
) -> torch.Tensor:
|
| 229 |
-
"""
|
| 230 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
2. Have strong radial gradients pointing inward
|
| 235 |
|
| 236 |
Args:
|
| 237 |
-
grayscale: Grayscale
|
| 238 |
-
grad_x,
|
| 239 |
-
|
|
|
|
| 240 |
|
| 241 |
Returns:
|
| 242 |
-
|
| 243 |
"""
|
| 244 |
B, _, H, W = grayscale.shape
|
| 245 |
device = grayscale.device
|
|
@@ -311,15 +410,21 @@ def compute_radial_symmetry_response(
|
|
| 311 |
|
| 312 |
|
| 313 |
def soft_argmax_2d(response: torch.Tensor, temperature: float = 0.1) -> tuple:
|
| 314 |
-
"""
|
| 315 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 316 |
|
| 317 |
Args:
|
| 318 |
-
response: Response map (B, 1, H, W)
|
| 319 |
-
temperature: Softmax temperature
|
| 320 |
|
| 321 |
Returns:
|
| 322 |
-
Tuple of (cx, cy) each of shape (B,)
|
| 323 |
"""
|
| 324 |
B, _, H, W = response.shape
|
| 325 |
device = response.device
|
|
@@ -347,17 +452,19 @@ def estimate_eye_center(
|
|
| 347 |
images: torch.Tensor,
|
| 348 |
softmax_temperature: float = 0.1,
|
| 349 |
) -> tuple:
|
| 350 |
-
"""
|
| 351 |
-
|
|
|
|
| 352 |
|
| 353 |
Args:
|
| 354 |
-
images: RGB images of shape (B, 3, H, W)
|
| 355 |
-
softmax_temperature: Temperature for
|
| 356 |
-
|
| 357 |
-
|
|
|
|
| 358 |
|
| 359 |
Returns:
|
| 360 |
-
Tuple of (cx, cy) each of shape (B,) in pixel coordinates
|
| 361 |
"""
|
| 362 |
grayscale = rgb_to_grayscale(images)
|
| 363 |
grad_x, grad_y, grad_magnitude = compute_gradients(grayscale)
|
|
@@ -380,19 +487,28 @@ def estimate_radius(
|
|
| 380 |
min_radius_frac: float = 0.1,
|
| 381 |
max_radius_frac: float = 0.5,
|
| 382 |
) -> torch.Tensor:
|
| 383 |
-
"""
|
| 384 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 385 |
|
| 386 |
Args:
|
| 387 |
-
images: RGB images (B, 3, H, W)
|
| 388 |
-
cx, cy: Center coordinates (B,)
|
| 389 |
-
num_radii: Number of
|
| 390 |
-
num_angles: Number of angular
|
| 391 |
-
min_radius_frac: Minimum radius as fraction of
|
| 392 |
-
max_radius_frac: Maximum radius as fraction of
|
| 393 |
|
| 394 |
Returns:
|
| 395 |
-
Estimated radius for each image (B,)
|
| 396 |
"""
|
| 397 |
B, _, H, W = images.shape
|
| 398 |
device = images.device
|
|
@@ -472,18 +588,26 @@ def compute_crop_box(
|
|
| 472 |
scale_factor: float = 1.1,
|
| 473 |
allow_overflow: bool = False,
|
| 474 |
) -> tuple:
|
| 475 |
-
"""
|
| 476 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 477 |
|
| 478 |
Args:
|
| 479 |
-
cx, cy:
|
| 480 |
-
radius: Estimated radius (B,)
|
| 481 |
-
H, W:
|
| 482 |
-
scale_factor:
|
| 483 |
-
allow_overflow:
|
| 484 |
|
| 485 |
Returns:
|
| 486 |
-
Tuple of (x1, y1, x2, y2) each of shape (B,)
|
| 487 |
"""
|
| 488 |
# Compute half side length
|
| 489 |
half_side = radius * scale_factor
|
|
@@ -536,19 +660,23 @@ def batch_crop_and_resize(
|
|
| 536 |
output_size: int,
|
| 537 |
padding_mode: str = 'border',
|
| 538 |
) -> torch.Tensor:
|
| 539 |
-
"""
|
| 540 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 541 |
|
| 542 |
Args:
|
| 543 |
-
images: Input images (B, C, H, W)
|
| 544 |
-
x1, y1, x2, y2: Crop
|
| 545 |
-
output_size:
|
| 546 |
-
padding_mode:
|
| 547 |
-
- 'border': repeat edge pixels (default)
|
| 548 |
-
- 'zeros': fill with black (useful for pre-cropped images)
|
| 549 |
|
| 550 |
Returns:
|
| 551 |
-
Cropped and resized images (B, C, output_size, output_size)
|
| 552 |
"""
|
| 553 |
B, C, H, W = images.shape
|
| 554 |
device = images.device
|
|
@@ -584,13 +712,101 @@ def batch_crop_and_resize(
|
|
| 584 |
|
| 585 |
return cropped
|
| 586 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 587 |
|
| 588 |
# =============================================================================
|
| 589 |
# PHASE 4: CLAHE (Torch-Native)
|
| 590 |
# =============================================================================
|
| 591 |
|
| 592 |
def _srgb_to_linear(rgb: torch.Tensor) -> torch.Tensor:
|
| 593 |
-
"""
|
|
|
|
|
|
|
|
|
|
| 594 |
threshold = 0.04045
|
| 595 |
linear = torch.where(
|
| 596 |
rgb <= threshold,
|
|
@@ -601,7 +817,11 @@ def _srgb_to_linear(rgb: torch.Tensor) -> torch.Tensor:
|
|
| 601 |
|
| 602 |
|
| 603 |
def _linear_to_srgb(linear: torch.Tensor) -> torch.Tensor:
|
| 604 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 605 |
threshold = 0.0031308
|
| 606 |
srgb = torch.where(
|
| 607 |
linear <= threshold,
|
|
@@ -612,21 +832,26 @@ def _linear_to_srgb(linear: torch.Tensor) -> torch.Tensor:
|
|
| 612 |
|
| 613 |
|
| 614 |
def rgb_to_lab(images: torch.Tensor) -> tuple:
|
| 615 |
-
"""
|
| 616 |
-
|
|
|
|
|
|
|
| 617 |
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
|
| 621 |
-
|
|
|
|
|
|
|
| 622 |
|
| 623 |
Args:
|
| 624 |
-
images: RGB images (B,
|
| 625 |
|
| 626 |
Returns:
|
| 627 |
-
Tuple of (L, a,
|
| 628 |
-
- L:
|
| 629 |
-
- a
|
|
|
|
| 630 |
"""
|
| 631 |
device = images.device
|
| 632 |
dtype = images.dtype
|
|
@@ -679,15 +904,18 @@ def rgb_to_lab(images: torch.Tensor) -> tuple:
|
|
| 679 |
|
| 680 |
|
| 681 |
def lab_to_rgb(L: torch.Tensor, a: torch.Tensor, b_ch: torch.Tensor) -> torch.Tensor:
|
| 682 |
-
"""
|
| 683 |
-
|
|
|
|
|
|
|
| 684 |
|
| 685 |
Args:
|
| 686 |
-
L:
|
| 687 |
-
a
|
|
|
|
| 688 |
|
| 689 |
Returns:
|
| 690 |
-
|
| 691 |
"""
|
| 692 |
# Denormalize
|
| 693 |
L_lab = L * 100.0
|
|
@@ -735,15 +963,22 @@ def compute_histogram(
|
|
| 735 |
tensor: torch.Tensor,
|
| 736 |
num_bins: int = 256,
|
| 737 |
) -> torch.Tensor:
|
| 738 |
-
"""
|
| 739 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 740 |
|
| 741 |
Args:
|
| 742 |
-
tensor: Input
|
| 743 |
-
num_bins: Number of histogram bins
|
| 744 |
|
| 745 |
Returns:
|
| 746 |
-
Histograms (B, num_bins)
|
| 747 |
"""
|
| 748 |
B = tensor.shape[0]
|
| 749 |
device = tensor.device
|
|
@@ -770,16 +1005,21 @@ def clahe_single_tile(
|
|
| 770 |
clip_limit: float,
|
| 771 |
num_bins: int = 256,
|
| 772 |
) -> torch.Tensor:
|
| 773 |
-
"""
|
| 774 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 775 |
|
| 776 |
Args:
|
| 777 |
-
tile:
|
| 778 |
-
clip_limit:
|
| 779 |
-
num_bins: Number of histogram bins
|
| 780 |
|
| 781 |
Returns:
|
| 782 |
-
CDF lookup table (B, num_bins)
|
| 783 |
"""
|
| 784 |
B, _, tile_h, tile_w = tile.shape
|
| 785 |
device = tile.device
|
|
@@ -815,17 +1055,30 @@ def apply_clahe_vectorized(
|
|
| 815 |
clip_limit: float = 2.0,
|
| 816 |
num_bins: int = 256,
|
| 817 |
) -> torch.Tensor:
|
| 818 |
-
"""
|
| 819 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 820 |
|
| 821 |
Args:
|
| 822 |
-
images: Input images (B, C, H, W)
|
| 823 |
-
grid_size:
|
| 824 |
-
clip_limit:
|
| 825 |
-
num_bins: Number of histogram bins
|
| 826 |
|
| 827 |
Returns:
|
| 828 |
-
CLAHE-enhanced images (B, C, H, W)
|
| 829 |
"""
|
| 830 |
B, C, H, W = images.shape
|
| 831 |
device = images.device
|
|
@@ -955,17 +1208,17 @@ def resize_images(
|
|
| 955 |
mode: str = 'bilinear',
|
| 956 |
antialias: bool = True,
|
| 957 |
) -> torch.Tensor:
|
| 958 |
-
"""
|
| 959 |
-
Resize images to target size.
|
| 960 |
|
| 961 |
Args:
|
| 962 |
-
images: Input images (B, C, H, W)
|
| 963 |
-
size: Target
|
| 964 |
-
mode: Interpolation mode
|
| 965 |
-
|
|
|
|
| 966 |
|
| 967 |
Returns:
|
| 968 |
-
Resized images (B, C, size, size)
|
| 969 |
"""
|
| 970 |
return F.interpolate(
|
| 971 |
images,
|
|
@@ -982,17 +1235,17 @@ def normalize_images(
|
|
| 982 |
std: Optional[List[float]] = None,
|
| 983 |
mode: str = 'imagenet',
|
| 984 |
) -> torch.Tensor:
|
| 985 |
-
"""
|
| 986 |
-
Normalize images.
|
| 987 |
|
| 988 |
Args:
|
| 989 |
-
images: Input images (B, C, H, W) in [0, 1]
|
| 990 |
-
mean:
|
| 991 |
-
std:
|
| 992 |
-
mode: 'imagenet', 'none', or
|
|
|
|
| 993 |
|
| 994 |
Returns:
|
| 995 |
-
|
| 996 |
"""
|
| 997 |
if mode == 'none':
|
| 998 |
return images
|
|
@@ -1020,16 +1273,30 @@ def normalize_images(
|
|
| 1020 |
# =============================================================================
|
| 1021 |
|
| 1022 |
class EyeCLAHEImageProcessor(BaseImageProcessor):
|
| 1023 |
-
"""
|
| 1024 |
-
|
| 1025 |
-
|
| 1026 |
-
|
| 1027 |
-
1.
|
| 1028 |
-
|
| 1029 |
-
|
| 1030 |
-
|
| 1031 |
-
|
| 1032 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1033 |
"""
|
| 1034 |
|
| 1035 |
model_input_names = ["pixel_values"]
|
|
@@ -1092,30 +1359,40 @@ class EyeCLAHEImageProcessor(BaseImageProcessor):
|
|
| 1092 |
def preprocess(
|
| 1093 |
self,
|
| 1094 |
images,
|
|
|
|
| 1095 |
return_tensors: str = "pt",
|
| 1096 |
device: Optional[Union[str, torch.device]] = None,
|
| 1097 |
**kwargs,
|
| 1098 |
) -> BatchFeature:
|
| 1099 |
-
"""
|
| 1100 |
-
|
|
|
|
|
|
|
|
|
|
| 1101 |
|
| 1102 |
Args:
|
| 1103 |
-
images: Input images in any
|
| 1104 |
-
|
| 1105 |
-
|
| 1106 |
-
-
|
| 1107 |
-
|
| 1108 |
-
|
|
|
|
|
|
|
|
|
|
| 1109 |
|
| 1110 |
Returns:
|
| 1111 |
-
BatchFeature with keys:
|
| 1112 |
-
|
| 1113 |
-
|
| 1114 |
-
|
| 1115 |
-
|
| 1116 |
-
|
| 1117 |
-
|
| 1118 |
-
|
|
|
|
|
|
|
|
|
|
| 1119 |
"""
|
| 1120 |
if return_tensors != "pt":
|
| 1121 |
raise ValueError("Only 'pt' (PyTorch) tensors are supported")
|
|
@@ -1133,6 +1410,9 @@ class EyeCLAHEImageProcessor(BaseImageProcessor):
|
|
| 1133 |
|
| 1134 |
# Standardize input
|
| 1135 |
images = standardize_input(images, device)
|
|
|
|
|
|
|
|
|
|
| 1136 |
B, C, H_orig, W_orig = images.shape
|
| 1137 |
|
| 1138 |
if self.do_crop:
|
|
@@ -1164,6 +1444,13 @@ class EyeCLAHEImageProcessor(BaseImageProcessor):
|
|
| 1164 |
# Use 'zeros' padding when allow_overflow is True to fill out-of-bounds with black
|
| 1165 |
padding_mode = 'zeros' if self.allow_overflow else 'border'
|
| 1166 |
images = batch_crop_and_resize(images, x1, y1, x2, y2, self.size, padding_mode=padding_mode)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1167 |
else:
|
| 1168 |
# Just resize - no crop
|
| 1169 |
# Compute coordinate mapping for direct resize
|
|
@@ -1173,6 +1460,10 @@ class EyeCLAHEImageProcessor(BaseImageProcessor):
|
|
| 1173 |
offset_y = torch.zeros(B, device=device, dtype=images.dtype)
|
| 1174 |
images = resize_images(images, self.size)
|
| 1175 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1176 |
# Apply CLAHE
|
| 1177 |
if self.do_clahe:
|
| 1178 |
images = apply_clahe_vectorized(
|
|
@@ -1190,25 +1481,23 @@ class EyeCLAHEImageProcessor(BaseImageProcessor):
|
|
| 1190 |
)
|
| 1191 |
|
| 1192 |
# Return with coordinate mapping information (flattened structure)
|
| 1193 |
-
|
| 1194 |
-
|
| 1195 |
-
|
| 1196 |
-
|
| 1197 |
-
|
| 1198 |
-
|
| 1199 |
-
|
| 1200 |
-
|
| 1201 |
-
|
| 1202 |
-
)
|
| 1203 |
|
| 1204 |
def __call__(
|
| 1205 |
self,
|
| 1206 |
images: Union[torch.Tensor, List[torch.Tensor]],
|
| 1207 |
**kwargs,
|
| 1208 |
) -> BatchFeature:
|
| 1209 |
-
"""
|
| 1210 |
-
Process images (alias for preprocess).
|
| 1211 |
-
"""
|
| 1212 |
return self.preprocess(images, **kwargs)
|
| 1213 |
|
| 1214 |
|
|
|
|
| 39 |
# =============================================================================
|
| 40 |
|
| 41 |
def _pil_to_tensor(image: "Image.Image") -> torch.Tensor:
|
| 42 |
+
"""Convert a single PIL Image to a float32 tensor of shape (C, H, W) in [0, 1].
|
| 43 |
+
|
| 44 |
+
Converts to RGB if not already. Uses numpy as intermediate when available,
|
| 45 |
+
otherwise falls back to manual pixel extraction.
|
| 46 |
+
"""
|
| 47 |
if not PIL_AVAILABLE:
|
| 48 |
raise ImportError("PIL is required to process PIL Images")
|
| 49 |
|
|
|
|
| 67 |
|
| 68 |
|
| 69 |
def _numpy_to_tensor(arr: "np.ndarray") -> torch.Tensor:
|
| 70 |
+
"""Convert a single numpy array to a float32 tensor of shape (C, H, W) in [0, 1].
|
| 71 |
+
|
| 72 |
+
Handles grayscale (H, W), HWC (H, W, C) with C in {1, 3, 4}, and uint8/float inputs.
|
| 73 |
+
Makes a copy to avoid sharing memory with the source array.
|
| 74 |
+
"""
|
| 75 |
if not NUMPY_AVAILABLE:
|
| 76 |
raise ImportError("NumPy is required to process numpy arrays")
|
| 77 |
|
|
|
|
| 97 |
images: Union[torch.Tensor, List[torch.Tensor], "Image.Image", List["Image.Image"], "np.ndarray", List["np.ndarray"]],
|
| 98 |
device: Optional[torch.device] = None,
|
| 99 |
) -> torch.Tensor:
|
| 100 |
+
"""Convert heterogeneous image inputs to a standardized (B, C, H, W) float32 tensor in [0, 1].
|
| 101 |
+
|
| 102 |
+
Accepts torch.Tensor, PIL.Image, numpy.ndarray, or lists thereof. Integer-typed
|
| 103 |
+
inputs (uint8) are scaled to [0, 1]. The output is clamped to [0, 1].
|
| 104 |
+
|
| 105 |
+
Note: All images in a list must have the same spatial dimensions (required by torch.stack).
|
| 106 |
+
A single numpy array with ndim==3 is treated as a single HWC image if the last dimension
|
| 107 |
+
is in {1, 3, 4}; otherwise it falls through to the tensor path (assumed CHW).
|
| 108 |
|
| 109 |
Args:
|
| 110 |
images: Input as:
|
|
|
|
| 165 |
|
| 166 |
return images
|
| 167 |
|
| 168 |
+
def standardize_mask_input(
|
| 169 |
+
masks: Union[
|
| 170 |
+
torch.Tensor,
|
| 171 |
+
List[torch.Tensor],
|
| 172 |
+
"Image.Image",
|
| 173 |
+
List["Image.Image"],
|
| 174 |
+
"np.ndarray",
|
| 175 |
+
List["np.ndarray"],
|
| 176 |
+
],
|
| 177 |
+
device: Optional[torch.device] = None,
|
| 178 |
+
) -> torch.Tensor:
|
| 179 |
+
"""Convert heterogeneous mask inputs to a standardized (B, 1, H, W) tensor.
|
| 180 |
|
| 181 |
+
Unlike ``standardize_input``, this preserves the original dtype (typically integer
|
| 182 |
+
label values) and does **not** normalize to [0, 1].
|
| 183 |
+
|
| 184 |
+
Accepts torch.Tensor, PIL.Image, numpy.ndarray, or lists thereof.
|
| 185 |
+
A single 2-D input is treated as (H, W) and expanded to (1, 1, H, W).
|
| 186 |
+
|
| 187 |
+
Args:
|
| 188 |
+
masks: Input masks in any supported format.
|
| 189 |
+
device: Target device.
|
| 190 |
+
|
| 191 |
+
Returns:
|
| 192 |
+
Tensor of shape (B, 1, H, W) with original dtype preserved.
|
| 193 |
"""
|
|
|
|
| 194 |
|
| 195 |
+
# Handle single inputs
|
| 196 |
+
if PIL_AVAILABLE and isinstance(masks, Image.Image):
|
| 197 |
+
masks = [masks]
|
| 198 |
+
|
| 199 |
+
if NUMPY_AVAILABLE and isinstance(masks, np.ndarray) and masks.ndim == 2:
|
| 200 |
+
masks = [masks]
|
| 201 |
+
|
| 202 |
+
# Convert list inputs
|
| 203 |
+
if isinstance(masks, list):
|
| 204 |
+
converted = []
|
| 205 |
+
for m in masks:
|
| 206 |
+
if PIL_AVAILABLE and isinstance(m, Image.Image):
|
| 207 |
+
# PIL mask → numpy → tensor
|
| 208 |
+
m = np.array(m)
|
| 209 |
+
converted.append(torch.from_numpy(m))
|
| 210 |
+
elif NUMPY_AVAILABLE and isinstance(m, np.ndarray):
|
| 211 |
+
converted.append(torch.from_numpy(m))
|
| 212 |
+
elif isinstance(m, torch.Tensor):
|
| 213 |
+
converted.append(m)
|
| 214 |
+
else:
|
| 215 |
+
raise TypeError(f"Unsupported mask type: {type(m)}")
|
| 216 |
+
|
| 217 |
+
masks = torch.stack(converted)
|
| 218 |
+
|
| 219 |
+
elif NUMPY_AVAILABLE and isinstance(masks, np.ndarray):
|
| 220 |
+
masks = torch.from_numpy(masks)
|
| 221 |
+
|
| 222 |
+
# At this point masks is a torch.Tensor
|
| 223 |
+
|
| 224 |
+
if masks.dim() == 2:
|
| 225 |
+
# (H, W) → (1, 1, H, W)
|
| 226 |
+
masks = masks.unsqueeze(0).unsqueeze(0)
|
| 227 |
+
elif masks.dim() == 3:
|
| 228 |
+
# (B, H, W) → (B, 1, H, W)
|
| 229 |
+
masks = masks.unsqueeze(1)
|
| 230 |
+
elif masks.dim() == 4:
|
| 231 |
+
# Assume already (B, C, H, W)
|
| 232 |
+
pass
|
| 233 |
+
else:
|
| 234 |
+
raise ValueError(f"Invalid mask shape: {masks.shape}")
|
| 235 |
+
|
| 236 |
+
# Move to device
|
| 237 |
+
if device is not None:
|
| 238 |
+
masks = masks.to(device)
|
| 239 |
+
|
| 240 |
+
return masks
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def rgb_to_grayscale(images: torch.Tensor) -> torch.Tensor:
|
| 244 |
+
"""Convert RGB images to grayscale via ITU-R BT.601 luminance: Y = 0.299R + 0.587G + 0.114B.
|
| 245 |
|
| 246 |
Args:
|
| 247 |
+
images: Tensor of shape (B, 3, H, W) in any value range.
|
| 248 |
|
| 249 |
Returns:
|
| 250 |
+
Tensor of shape (B, 1, H, W) in the same value range as input.
|
| 251 |
"""
|
| 252 |
# Luminance weights
|
| 253 |
weights = torch.tensor([0.299, 0.587, 0.114], device=images.device, dtype=images.dtype)
|
|
|
|
| 262 |
# =============================================================================
|
| 263 |
|
| 264 |
def create_sobel_kernels(device: torch.device, dtype: torch.dtype) -> tuple:
|
| 265 |
+
"""Create 3x3 Sobel edge-detection kernels for horizontal and vertical gradients.
|
| 266 |
+
|
| 267 |
+
Args:
|
| 268 |
+
device: Target device for the kernels.
|
| 269 |
+
dtype: Target dtype for the kernels.
|
| 270 |
|
| 271 |
Returns:
|
| 272 |
+
Tuple of (sobel_x, sobel_y) kernels, each of shape (1, 1, 3, 3),
|
| 273 |
+
suitable for use with ``F.conv2d`` on single-channel input.
|
| 274 |
"""
|
| 275 |
sobel_x = torch.tensor([
|
| 276 |
[-1, 0, 1],
|
|
|
|
| 288 |
|
| 289 |
|
| 290 |
def compute_gradients(grayscale: torch.Tensor) -> tuple:
|
| 291 |
+
"""Compute horizontal and vertical image gradients using 3x3 Sobel filters.
|
| 292 |
+
|
| 293 |
+
Uses reflect-free padding=1 (zero-padded convolution) to maintain spatial size.
|
| 294 |
|
| 295 |
Args:
|
| 296 |
+
grayscale: Single-channel images of shape (B, 1, H, W).
|
| 297 |
|
| 298 |
Returns:
|
| 299 |
+
Tuple of (grad_x, grad_y, grad_magnitude), each (B, 1, H, W).
|
| 300 |
+
``grad_magnitude`` = sqrt(grad_x^2 + grad_y^2 + 1e-8).
|
| 301 |
"""
|
| 302 |
sobel_x, sobel_y = create_sobel_kernels(grayscale.device, grayscale.dtype)
|
| 303 |
|
|
|
|
| 317 |
grad_y: torch.Tensor,
|
| 318 |
grad_magnitude: torch.Tensor,
|
| 319 |
) -> torch.Tensor:
|
| 320 |
+
"""Compute a radial-symmetry response map for circular-region detection.
|
| 321 |
+
|
| 322 |
+
The algorithm:
|
| 323 |
+
1. Estimates an initial center as the intensity-weighted center of mass of
|
| 324 |
+
dark regions (squared inverse intensity).
|
| 325 |
+
2. For each pixel, computes the dot product between the normalized gradient
|
| 326 |
+
vector and the unit vector pointing toward the estimated center.
|
| 327 |
+
3. Weights this alignment score by gradient magnitude and darkness.
|
| 328 |
+
4. Smooths the response with a separable Gaussian whose sigma is
|
| 329 |
+
proportional to the image size (kernel_size = max(H,W)//8, sigma = kernel_size/6).
|
| 330 |
|
| 331 |
+
High response indicates pixels whose gradients point radially inward toward
|
| 332 |
+
a dark center — characteristic of the fundus disc boundary.
|
|
|
|
| 333 |
|
| 334 |
Args:
|
| 335 |
+
grayscale: Grayscale images (B, 1, H, W) in [0, 1].
|
| 336 |
+
grad_x: Horizontal gradient (B, 1, H, W).
|
| 337 |
+
grad_y: Vertical gradient (B, 1, H, W).
|
| 338 |
+
grad_magnitude: Gradient magnitude (B, 1, H, W).
|
| 339 |
|
| 340 |
Returns:
|
| 341 |
+
Smoothed radial symmetry response map (B, 1, H, W).
|
| 342 |
"""
|
| 343 |
B, _, H, W = grayscale.shape
|
| 344 |
device = grayscale.device
|
|
|
|
| 410 |
|
| 411 |
|
| 412 |
def soft_argmax_2d(response: torch.Tensor, temperature: float = 0.1) -> tuple:
|
| 413 |
+
"""Find the sub-pixel peak location in a response map via softmax-weighted coordinates.
|
| 414 |
+
|
| 415 |
+
Divides the flattened response by ``temperature`` before applying softmax, then
|
| 416 |
+
computes the weighted mean of the (x, y) coordinate grids. Lower temperature yields
|
| 417 |
+
a sharper, more argmax-like result; higher temperature yields a broader average.
|
| 418 |
+
|
| 419 |
+
Caution: Very low temperatures (< 0.01) combined with large response magnitudes
|
| 420 |
+
can cause numerical overflow in the softmax exponential.
|
| 421 |
|
| 422 |
Args:
|
| 423 |
+
response: Response map (B, 1, H, W).
|
| 424 |
+
temperature: Softmax temperature. Default 0.1.
|
| 425 |
|
| 426 |
Returns:
|
| 427 |
+
Tuple of (cx, cy), each of shape (B,), in pixel coordinates.
|
| 428 |
"""
|
| 429 |
B, _, H, W = response.shape
|
| 430 |
device = response.device
|
|
|
|
| 452 |
images: torch.Tensor,
|
| 453 |
softmax_temperature: float = 0.1,
|
| 454 |
) -> tuple:
|
| 455 |
+
"""Estimate the center of the fundus/eye disc in each image.
|
| 456 |
+
|
| 457 |
+
Pipeline: RGB → grayscale → Sobel gradients → radial symmetry response → soft argmax.
|
| 458 |
|
| 459 |
Args:
|
| 460 |
+
images: RGB images of shape (B, 3, H, W) in [0, 1].
|
| 461 |
+
softmax_temperature: Temperature for the soft-argmax peak finder.
|
| 462 |
+
Lower values (0.01-0.1) give sharper localization; higher values
|
| 463 |
+
(0.3-0.5) give broader averaging, useful for noisy or low-contrast images.
|
| 464 |
+
Default 0.1.
|
| 465 |
|
| 466 |
Returns:
|
| 467 |
+
Tuple of (cx, cy), each of shape (B,), in pixel coordinates.
|
| 468 |
"""
|
| 469 |
grayscale = rgb_to_grayscale(images)
|
| 470 |
grad_x, grad_y, grad_magnitude = compute_gradients(grayscale)
|
|
|
|
| 487 |
min_radius_frac: float = 0.1,
|
| 488 |
max_radius_frac: float = 0.5,
|
| 489 |
) -> torch.Tensor:
|
| 490 |
+
"""Estimate the radius of the fundus disc by analyzing radial intensity profiles.
|
| 491 |
+
|
| 492 |
+
Samples grayscale intensity along ``num_angles`` rays emanating from ``(cx, cy)``
|
| 493 |
+
at ``num_radii`` radial distances. The per-radius mean intensity across all angles
|
| 494 |
+
gives a 1-D radial profile. The discrete derivative of this profile is linearly
|
| 495 |
+
weighted by radius (range 0.5–1.5) to bias toward the outer fundus boundary
|
| 496 |
+
rather than the smaller pupil boundary. The radius at the strongest weighted
|
| 497 |
+
negative gradient is selected as the disc edge.
|
| 498 |
+
|
| 499 |
+
Uses ``F.grid_sample`` with bilinear interpolation and border padding for
|
| 500 |
+
sub-pixel sampling.
|
| 501 |
|
| 502 |
Args:
|
| 503 |
+
images: RGB images (B, 3, H, W) in [0, 1].
|
| 504 |
+
cx, cy: Center coordinates (B,) in pixel units.
|
| 505 |
+
num_radii: Number of radial sample points. Default 100.
|
| 506 |
+
num_angles: Number of angular sample rays. Default 36.
|
| 507 |
+
min_radius_frac: Minimum search radius as fraction of min(H, W). Default 0.1.
|
| 508 |
+
max_radius_frac: Maximum search radius as fraction of min(H, W). Default 0.5.
|
| 509 |
|
| 510 |
Returns:
|
| 511 |
+
Estimated radius for each image (B,), clamped to [min_radius, max_radius].
|
| 512 |
"""
|
| 513 |
B, _, H, W = images.shape
|
| 514 |
device = images.device
|
|
|
|
| 588 |
scale_factor: float = 1.1,
|
| 589 |
allow_overflow: bool = False,
|
| 590 |
) -> tuple:
|
| 591 |
+
"""Compute a square bounding box centered on the detected eye.
|
| 592 |
+
|
| 593 |
+
The half-side length is ``radius * scale_factor``. When ``allow_overflow`` is
|
| 594 |
+
False, the box is clamped to the image bounds and then made square by shrinking
|
| 595 |
+
to the shorter side and re-centering. The resulting box is guaranteed to be
|
| 596 |
+
square and fully within [0, W-1] x [0, H-1].
|
| 597 |
+
|
| 598 |
+
When ``allow_overflow`` is True the raw (possibly out-of-bounds) box is
|
| 599 |
+
returned, which is useful for images where the fundus disc is partially
|
| 600 |
+
clipped; out-of-bounds regions will be zero-filled during grid_sample.
|
| 601 |
|
| 602 |
Args:
|
| 603 |
+
cx, cy: Detected eye center coordinates (B,).
|
| 604 |
+
radius: Estimated disc radius (B,).
|
| 605 |
+
H, W: Spatial dimensions of the source images.
|
| 606 |
+
scale_factor: Padding multiplier applied to ``radius``. Default 1.1.
|
| 607 |
+
allow_overflow: Skip clamping / squareness enforcement. Default False.
|
| 608 |
|
| 609 |
Returns:
|
| 610 |
+
Tuple of (x1, y1, x2, y2), each of shape (B,), in pixel coordinates.
|
| 611 |
"""
|
| 612 |
# Compute half side length
|
| 613 |
half_side = radius * scale_factor
|
|
|
|
| 660 |
output_size: int,
|
| 661 |
padding_mode: str = 'border',
|
| 662 |
) -> torch.Tensor:
|
| 663 |
+
"""Crop and resize images to a square using ``F.grid_sample`` (GPU-friendly).
|
| 664 |
+
|
| 665 |
+
Builds a regular output grid in [0, 1]^2, maps it to the source rectangle
|
| 666 |
+
[x1, x2] x [y1, y2] via affine scaling, normalizes to [-1, 1] for
|
| 667 |
+
``grid_sample``, and samples with bilinear interpolation (``align_corners=True``).
|
| 668 |
+
|
| 669 |
+
Crop coordinates may extend beyond image bounds; the ``padding_mode``
|
| 670 |
+
controls how out-of-bounds pixels are filled.
|
| 671 |
|
| 672 |
Args:
|
| 673 |
+
images: Input images (B, C, H, W).
|
| 674 |
+
x1, y1, x2, y2: Crop box corners (B,). May exceed [0, W-1] / [0, H-1].
|
| 675 |
+
output_size: Side length of the square output.
|
| 676 |
+
padding_mode: ``'border'`` (repeat edge, default) or ``'zeros'`` (black fill).
|
|
|
|
|
|
|
| 677 |
|
| 678 |
Returns:
|
| 679 |
+
Cropped and resized images (B, C, output_size, output_size).
|
| 680 |
"""
|
| 681 |
B, C, H, W = images.shape
|
| 682 |
device = images.device
|
|
|
|
| 712 |
|
| 713 |
return cropped
|
| 714 |
|
| 715 |
+
#def batch_crop_and_resize_mask(
|
| 716 |
+
# masks: torch.Tensor,
|
| 717 |
+
# x1: torch.Tensor,
|
| 718 |
+
# y1: torch.Tensor,
|
| 719 |
+
# x2: torch.Tensor,
|
| 720 |
+
# y2: torch.Tensor,
|
| 721 |
+
# output_size: int,
|
| 722 |
+
# padding_mode: str = "zeros",
|
| 723 |
+
#) -> torch.Tensor:
|
| 724 |
+
# """
|
| 725 |
+
# Crop and resize masks using nearest-neighbor sampling.
|
| 726 |
+
# """
|
| 727 |
+
# return batch_crop_and_resize(
|
| 728 |
+
# masks,
|
| 729 |
+
# x1, y1, x2, y2,
|
| 730 |
+
# output_size,
|
| 731 |
+
# padding_mode=padding_mode,
|
| 732 |
+
# )
|
| 733 |
+
|
| 734 |
+
def batch_crop_and_resize_mask(
|
| 735 |
+
masks: torch.Tensor, # (B, 1, H, W)
|
| 736 |
+
x1: torch.Tensor,
|
| 737 |
+
y1: torch.Tensor,
|
| 738 |
+
x2: torch.Tensor,
|
| 739 |
+
y2: torch.Tensor,
|
| 740 |
+
output_size: int,
|
| 741 |
+
padding_mode: str = "zeros",
|
| 742 |
+
) -> torch.Tensor:
|
| 743 |
+
"""Crop and resize segmentation masks using nearest-neighbor sampling.
|
| 744 |
+
|
| 745 |
+
Same spatial transform as ``batch_crop_and_resize`` but uses ``mode='nearest'``
|
| 746 |
+
to preserve discrete label values. The output is rounded and cast to ``torch.long``
|
| 747 |
+
to guard against floating-point drift in ``grid_sample``.
|
| 748 |
+
|
| 749 |
+
Args:
|
| 750 |
+
masks: Integer label masks (B, 1, H, W) — any dtype (converted to float internally).
|
| 751 |
+
x1, y1, x2, y2: Crop box corners (B,). May exceed image bounds.
|
| 752 |
+
output_size: Side length of the square output.
|
| 753 |
+
padding_mode: ``'zeros'`` (background = 0, default) or ``'border'`` (repeat edge).
|
| 754 |
+
|
| 755 |
+
Returns:
|
| 756 |
+
Cropped and resized masks (B, 1, output_size, output_size) as ``torch.long``.
|
| 757 |
+
"""
|
| 758 |
+
|
| 759 |
+
B, C, H, W = masks.shape
|
| 760 |
+
device = masks.device
|
| 761 |
+
|
| 762 |
+
# grid_sample requires floating point input
|
| 763 |
+
masks_f = masks.float()
|
| 764 |
+
|
| 765 |
+
# Create output grid in [0, 1]
|
| 766 |
+
coords = torch.linspace(0, 1, output_size, device=device)
|
| 767 |
+
out_y, out_x = torch.meshgrid(coords, coords, indexing="ij")
|
| 768 |
+
out_grid = torch.stack([out_x, out_y], dim=-1) # (S, S, 2)
|
| 769 |
+
out_grid = out_grid.unsqueeze(0).expand(B, -1, -1, -1)
|
| 770 |
+
|
| 771 |
+
# Reshape crop boxes
|
| 772 |
+
x1 = x1.view(B, 1, 1, 1)
|
| 773 |
+
y1 = y1.view(B, 1, 1, 1)
|
| 774 |
+
x2 = x2.view(B, 1, 1, 1)
|
| 775 |
+
y2 = y2.view(B, 1, 1, 1)
|
| 776 |
+
|
| 777 |
+
# Map [0, 1] → pixel coordinates
|
| 778 |
+
sample_x = x1 + out_grid[..., 0:1] * (x2 - x1)
|
| 779 |
+
sample_y = y1 + out_grid[..., 1:2] * (y2 - y1)
|
| 780 |
+
|
| 781 |
+
# Normalize to [-1, 1]
|
| 782 |
+
sample_x = 2.0 * sample_x / (W - 1) - 1.0
|
| 783 |
+
sample_y = 2.0 * sample_y / (H - 1) - 1.0
|
| 784 |
+
|
| 785 |
+
grid = torch.cat([sample_x, sample_y], dim=-1)
|
| 786 |
+
|
| 787 |
+
# Nearest-neighbor sampling with caller-specified padding
|
| 788 |
+
cropped = F.grid_sample(
|
| 789 |
+
masks_f,
|
| 790 |
+
grid,
|
| 791 |
+
mode="nearest",
|
| 792 |
+
padding_mode=padding_mode,
|
| 793 |
+
align_corners=True,
|
| 794 |
+
)
|
| 795 |
+
|
| 796 |
+
# Round before converting to handle floating point errors from grid_sample.
|
| 797 |
+
# Even with mode="nearest", grid_sample can produce values like 0.9999999
|
| 798 |
+
# which would truncate to 0 instead of rounding to 1.
|
| 799 |
+
return cropped.round().long()
|
| 800 |
|
| 801 |
# =============================================================================
|
| 802 |
# PHASE 4: CLAHE (Torch-Native)
|
| 803 |
# =============================================================================
|
| 804 |
|
| 805 |
def _srgb_to_linear(rgb: torch.Tensor) -> torch.Tensor:
|
| 806 |
+
"""Apply the sRGB electro-optical transfer function (EOTF) to convert sRGB to linear RGB.
|
| 807 |
+
|
| 808 |
+
Uses the IEC 61966-2-1 piecewise formula with threshold 0.04045.
|
| 809 |
+
"""
|
| 810 |
threshold = 0.04045
|
| 811 |
linear = torch.where(
|
| 812 |
rgb <= threshold,
|
|
|
|
| 817 |
|
| 818 |
|
| 819 |
def _linear_to_srgb(linear: torch.Tensor) -> torch.Tensor:
|
| 820 |
+
"""Apply the inverse sRGB EOTF to convert linear RGB to sRGB.
|
| 821 |
+
|
| 822 |
+
Uses the IEC 61966-2-1 piecewise formula with threshold 0.0031308.
|
| 823 |
+
Input must be non-negative; negative values will produce NaN from the power function.
|
| 824 |
+
"""
|
| 825 |
threshold = 0.0031308
|
| 826 |
srgb = torch.where(
|
| 827 |
linear <= threshold,
|
|
|
|
| 832 |
|
| 833 |
|
| 834 |
def rgb_to_lab(images: torch.Tensor) -> tuple:
|
| 835 |
+
"""Convert sRGB images to CIE LAB colour space (D65 illuminant).
|
| 836 |
+
|
| 837 |
+
Conversion chain: sRGB → linear RGB → CIE XYZ → CIE LAB.
|
| 838 |
+
The raw LAB values are rescaled for internal convenience:
|
| 839 |
|
| 840 |
+
- L ∈ [0, 100] → L / 100 → [0, 1]
|
| 841 |
+
- a ∈ ~[-128, 127] → a / 256 + 0.5 → ~[0, 1]
|
| 842 |
+
- b ∈ ~[-128, 127] → b / 256 + 0.5 → ~[0, 1]
|
| 843 |
+
|
| 844 |
+
These normalised values are **not** standard LAB; use ``lab_to_rgb`` to
|
| 845 |
+
invert them back to sRGB.
|
| 846 |
|
| 847 |
Args:
|
| 848 |
+
images: RGB images (B, 3, H, W) in [0, 1] sRGB.
|
| 849 |
|
| 850 |
Returns:
|
| 851 |
+
Tuple of (L, a, b_ch), each (B, 1, H, W):
|
| 852 |
+
- L: Normalised luminance in [0, 1].
|
| 853 |
+
- a: Normalised green–red chrominance, roughly [0, 1].
|
| 854 |
+
- b_ch: Normalised blue–yellow chrominance, roughly [0, 1].
|
| 855 |
"""
|
| 856 |
device = images.device
|
| 857 |
dtype = images.dtype
|
|
|
|
| 904 |
|
| 905 |
|
| 906 |
def lab_to_rgb(L: torch.Tensor, a: torch.Tensor, b_ch: torch.Tensor) -> torch.Tensor:
|
| 907 |
+
"""Convert normalised CIE LAB back to sRGB (inverse of ``rgb_to_lab``).
|
| 908 |
+
|
| 909 |
+
Denormalisation: L*100, (a-0.5)*256, (b_ch-0.5)*256, then LAB → XYZ → linear RGB → sRGB.
|
| 910 |
+
Output is clamped to [0, 1].
|
| 911 |
|
| 912 |
Args:
|
| 913 |
+
L: Normalised luminance (B, 1, H, W) in [0, 1].
|
| 914 |
+
a: Normalised green–red chrominance (B, 1, H, W), roughly [0, 1].
|
| 915 |
+
b_ch: Normalised blue–yellow chrominance (B, 1, H, W), roughly [0, 1].
|
| 916 |
|
| 917 |
Returns:
|
| 918 |
+
sRGB images (B, 3, H, W) clamped to [0, 1].
|
| 919 |
"""
|
| 920 |
# Denormalize
|
| 921 |
L_lab = L * 100.0
|
|
|
|
| 963 |
tensor: torch.Tensor,
|
| 964 |
num_bins: int = 256,
|
| 965 |
) -> torch.Tensor:
|
| 966 |
+
"""Compute per-image histograms for a batch of single-channel images.
|
| 967 |
+
|
| 968 |
+
Bins are uniformly spaced over [0, 1]. Each pixel is assigned to a bin via
|
| 969 |
+
``floor(value * (num_bins - 1))``, accumulated with ``scatter_add`` in a
|
| 970 |
+
per-sample loop.
|
| 971 |
+
|
| 972 |
+
Note: This function is used only by ``clahe_single_tile``.
|
| 973 |
+
The vectorized CLAHE path (``apply_clahe_vectorized``) computes histograms
|
| 974 |
+
inline for better GPU efficiency.
|
| 975 |
|
| 976 |
Args:
|
| 977 |
+
tensor: Input (B, 1, H, W) with values in [0, 1].
|
| 978 |
+
num_bins: Number of histogram bins. Default 256.
|
| 979 |
|
| 980 |
Returns:
|
| 981 |
+
Histograms of shape (B, num_bins), dtype matching input.
|
| 982 |
"""
|
| 983 |
B = tensor.shape[0]
|
| 984 |
device = tensor.device
|
|
|
|
| 1005 |
clip_limit: float,
|
| 1006 |
num_bins: int = 256,
|
| 1007 |
) -> torch.Tensor:
|
| 1008 |
+
"""Compute the clipped-and-redistributed CDF for a single CLAHE tile.
|
| 1009 |
+
|
| 1010 |
+
Clips the histogram so no bin exceeds ``clip_limit * num_pixels / num_bins``,
|
| 1011 |
+
redistributes the excess uniformly, then computes and min-max normalises the CDF.
|
| 1012 |
+
|
| 1013 |
+
Note: This function is not used by the main pipeline — see
|
| 1014 |
+
``apply_clahe_vectorized`` which processes all tiles in a single pass.
|
| 1015 |
|
| 1016 |
Args:
|
| 1017 |
+
tile: Single-channel tile images (B, 1, tile_h, tile_w) in [0, 1].
|
| 1018 |
+
clip_limit: Relative clip limit (higher = less contrast limiting).
|
| 1019 |
+
num_bins: Number of histogram bins. Default 256.
|
| 1020 |
|
| 1021 |
Returns:
|
| 1022 |
+
Normalised CDF lookup table (B, num_bins) in [0, 1].
|
| 1023 |
"""
|
| 1024 |
B, _, tile_h, tile_w = tile.shape
|
| 1025 |
device = tile.device
|
|
|
|
| 1055 |
clip_limit: float = 2.0,
|
| 1056 |
num_bins: int = 256,
|
| 1057 |
) -> torch.Tensor:
|
| 1058 |
+
"""Fully-vectorized CLAHE (Contrast Limited Adaptive Histogram Equalisation).
|
| 1059 |
+
|
| 1060 |
+
For RGB input, converts to CIE LAB, applies CLAHE to the L channel only,
|
| 1061 |
+
then converts back to sRGB. For single-channel input, operates directly.
|
| 1062 |
+
|
| 1063 |
+
Algorithm:
|
| 1064 |
+
1. Pads the luminance channel to be divisible by ``grid_size`` (reflect padding).
|
| 1065 |
+
2. Reshapes into ``grid_size x grid_size`` non-overlapping tiles.
|
| 1066 |
+
3. Computes a histogram per tile via ``scatter_add_`` (fully batched, no loops).
|
| 1067 |
+
4. Clips each histogram at ``clip_limit * num_pixels / num_bins`` and
|
| 1068 |
+
redistributes excess counts uniformly across all bins.
|
| 1069 |
+
5. Computes the cumulative distribution function (CDF) per tile and
|
| 1070 |
+
min-max normalises it to [0, 1].
|
| 1071 |
+
6. Maps each output pixel to the four surrounding tile centres and
|
| 1072 |
+
bilinearly interpolates their CDF values for a smooth result.
|
| 1073 |
|
| 1074 |
Args:
|
| 1075 |
+
images: Input images (B, C, H, W) in [0, 1]. C must be 1 or 3.
|
| 1076 |
+
grid_size: Tile grid resolution (tiles per axis). Default 8.
|
| 1077 |
+
clip_limit: Relative clip limit for histogram clipping. Default 2.0.
|
| 1078 |
+
num_bins: Number of histogram bins. Default 256.
|
| 1079 |
|
| 1080 |
Returns:
|
| 1081 |
+
CLAHE-enhanced images (B, C, H, W) in [0, 1].
|
| 1082 |
"""
|
| 1083 |
B, C, H, W = images.shape
|
| 1084 |
device = images.device
|
|
|
|
| 1208 |
mode: str = 'bilinear',
|
| 1209 |
antialias: bool = True,
|
| 1210 |
) -> torch.Tensor:
|
| 1211 |
+
"""Resize images to a square target size using ``F.interpolate``.
|
|
|
|
| 1212 |
|
| 1213 |
Args:
|
| 1214 |
+
images: Input images (B, C, H, W). Must be float for bilinear/bicubic modes.
|
| 1215 |
+
size: Target side length (output is always square).
|
| 1216 |
+
mode: Interpolation mode (``'bilinear'``, ``'bicubic'``, ``'nearest'``, etc.).
|
| 1217 |
+
Default ``'bilinear'``.
|
| 1218 |
+
antialias: Enable antialiasing for bilinear/bicubic downscaling. Default True.
|
| 1219 |
|
| 1220 |
Returns:
|
| 1221 |
+
Resized images (B, C, size, size).
|
| 1222 |
"""
|
| 1223 |
return F.interpolate(
|
| 1224 |
images,
|
|
|
|
| 1235 |
std: Optional[List[float]] = None,
|
| 1236 |
mode: str = 'imagenet',
|
| 1237 |
) -> torch.Tensor:
|
| 1238 |
+
"""Channel-wise normalisation: ``(image - mean) / std``.
|
|
|
|
| 1239 |
|
| 1240 |
Args:
|
| 1241 |
+
images: Input images (B, C, H, W) in [0, 1].
|
| 1242 |
+
mean: Per-channel means (length C). Required when ``mode='custom'``.
|
| 1243 |
+
std: Per-channel stds (length C). Required when ``mode='custom'``.
|
| 1244 |
+
mode: ``'imagenet'`` (uses ImageNet stats), ``'none'`` (identity), or
|
| 1245 |
+
``'custom'`` (uses caller-supplied mean/std). Default ``'imagenet'``.
|
| 1246 |
|
| 1247 |
Returns:
|
| 1248 |
+
Normalised images (B, C, H, W). Range depends on mean/std.
|
| 1249 |
"""
|
| 1250 |
if mode == 'none':
|
| 1251 |
return images
|
|
|
|
| 1273 |
# =============================================================================
|
| 1274 |
|
| 1275 |
class EyeCLAHEImageProcessor(BaseImageProcessor):
|
| 1276 |
+
"""GPU-native Hugging Face image processor for Colour Fundus Photography (CFP).
|
| 1277 |
+
|
| 1278 |
+
Processing pipeline (all steps optional via constructor flags):
|
| 1279 |
+
|
| 1280 |
+
1. **Eye localisation** (``do_crop=True``): detects the fundus disc centre via
|
| 1281 |
+
gradient-based radial symmetry (dark-region centre-of-mass → Sobel gradients →
|
| 1282 |
+
radial alignment score → Gaussian smoothing → soft argmax) and estimates the
|
| 1283 |
+
disc radius from the strongest negative radial intensity gradient.
|
| 1284 |
+
2. **Square crop & resize**: crops a square region around the detected disc
|
| 1285 |
+
(``radius * crop_scale_factor``), optionally allowing overflow beyond image
|
| 1286 |
+
bounds (``allow_overflow``), then resamples to ``size x size`` via bilinear
|
| 1287 |
+
``grid_sample``. When ``do_crop=False``, the whole image is resized directly.
|
| 1288 |
+
3. **CLAHE** (``do_clahe=True``): applies Contrast Limited Adaptive Histogram
|
| 1289 |
+
Equalisation to the CIE LAB luminance channel, using a fully-vectorized
|
| 1290 |
+
tile-based implementation with bilinear CDF interpolation.
|
| 1291 |
+
4. **Normalisation**: channel-wise ``(image - mean) / std`` with configurable
|
| 1292 |
+
mode (ImageNet, custom, or none).
|
| 1293 |
+
|
| 1294 |
+
The processor also returns per-image coordinate-mapping scalars (``scale_x/y``,
|
| 1295 |
+
``offset_x/y``) so that predictions in processed-image space can be mapped back
|
| 1296 |
+
to original pixel coordinates.
|
| 1297 |
+
|
| 1298 |
+
All operations are pure PyTorch — no OpenCV, PIL, or NumPy at runtime — and are
|
| 1299 |
+
CUDA-compatible and batch-friendly.
|
| 1300 |
"""
|
| 1301 |
|
| 1302 |
model_input_names = ["pixel_values"]
|
|
|
|
| 1359 |
def preprocess(
|
| 1360 |
self,
|
| 1361 |
images,
|
| 1362 |
+
masks=None,
|
| 1363 |
return_tensors: str = "pt",
|
| 1364 |
device: Optional[Union[str, torch.device]] = None,
|
| 1365 |
**kwargs,
|
| 1366 |
) -> BatchFeature:
|
| 1367 |
+
"""Run the full preprocessing pipeline on a batch of images.
|
| 1368 |
+
|
| 1369 |
+
Accepts any combination of torch.Tensor, PIL.Image, or numpy.ndarray inputs
|
| 1370 |
+
(see ``standardize_input`` for format details). Optionally processes
|
| 1371 |
+
accompanying segmentation masks with matching spatial transforms.
|
| 1372 |
|
| 1373 |
Args:
|
| 1374 |
+
images: Input images in any supported format.
|
| 1375 |
+
masks: Optional segmentation masks in any format accepted by
|
| 1376 |
+
``standardize_mask_input``. Undergo the same crop/resize as images
|
| 1377 |
+
(nearest-neighbour interpolation, label-preserving). Returned as
|
| 1378 |
+
``torch.long`` under the ``"mask"`` key (or ``None`` if not provided).
|
| 1379 |
+
return_tensors: Only ``"pt"`` is supported.
|
| 1380 |
+
device: Device for all tensor operations (e.g. ``"cuda:0"``).
|
| 1381 |
+
Defaults to the device of the input tensor, or CPU for PIL/numpy.
|
| 1382 |
+
**kwargs: Passed through to ``BaseImageProcessor``.
|
| 1383 |
|
| 1384 |
Returns:
|
| 1385 |
+
``BatchFeature`` with keys:
|
| 1386 |
+
|
| 1387 |
+
- ``pixel_values`` (B, 3, size, size): Processed float32 images.
|
| 1388 |
+
- ``mask`` (B, 1, size, size) or ``None``: Processed long masks.
|
| 1389 |
+
- ``scale_x``, ``scale_y`` (B,): Per-image scale factors.
|
| 1390 |
+
- ``offset_x``, ``offset_y`` (B,): Per-image offsets.
|
| 1391 |
+
|
| 1392 |
+
Coordinate mapping from processed → original pixel space::
|
| 1393 |
+
|
| 1394 |
+
orig_x = offset_x + proc_x * scale_x
|
| 1395 |
+
orig_y = offset_y + proc_y * scale_y
|
| 1396 |
"""
|
| 1397 |
if return_tensors != "pt":
|
| 1398 |
raise ValueError("Only 'pt' (PyTorch) tensors are supported")
|
|
|
|
| 1410 |
|
| 1411 |
# Standardize input
|
| 1412 |
images = standardize_input(images, device)
|
| 1413 |
+
if masks is not None:
|
| 1414 |
+
masks = standardize_mask_input(masks, device)
|
| 1415 |
+
|
| 1416 |
B, C, H_orig, W_orig = images.shape
|
| 1417 |
|
| 1418 |
if self.do_crop:
|
|
|
|
| 1444 |
# Use 'zeros' padding when allow_overflow is True to fill out-of-bounds with black
|
| 1445 |
padding_mode = 'zeros' if self.allow_overflow else 'border'
|
| 1446 |
images = batch_crop_and_resize(images, x1, y1, x2, y2, self.size, padding_mode=padding_mode)
|
| 1447 |
+
|
| 1448 |
+
if masks is not None:
|
| 1449 |
+
masks = batch_crop_and_resize_mask(
|
| 1450 |
+
masks, x1, y1, x2, y2,
|
| 1451 |
+
self.size,
|
| 1452 |
+
padding_mode=padding_mode,
|
| 1453 |
+
)
|
| 1454 |
else:
|
| 1455 |
# Just resize - no crop
|
| 1456 |
# Compute coordinate mapping for direct resize
|
|
|
|
| 1460 |
offset_y = torch.zeros(B, device=device, dtype=images.dtype)
|
| 1461 |
images = resize_images(images, self.size)
|
| 1462 |
|
| 1463 |
+
if masks is not None:
|
| 1464 |
+
# F.interpolate requires float input; cast, resize, then restore long
|
| 1465 |
+
masks = resize_images(masks.float(), self.size, mode="nearest", antialias=False).round().long()
|
| 1466 |
+
|
| 1467 |
# Apply CLAHE
|
| 1468 |
if self.do_clahe:
|
| 1469 |
images = apply_clahe_vectorized(
|
|
|
|
| 1481 |
)
|
| 1482 |
|
| 1483 |
# Return with coordinate mapping information (flattened structure)
|
| 1484 |
+
data = {
|
| 1485 |
+
"pixel_values": images,
|
| 1486 |
+
"scale_x": scale_x,
|
| 1487 |
+
"scale_y": scale_y,
|
| 1488 |
+
"offset_x": offset_x,
|
| 1489 |
+
"offset_y": offset_y,
|
| 1490 |
+
}
|
| 1491 |
+
if masks is not None:
|
| 1492 |
+
data["mask"] = masks
|
| 1493 |
+
return BatchFeature(data=data, tensor_type="pt")
|
| 1494 |
|
| 1495 |
def __call__(
    self,
    images: Union[torch.Tensor, List[torch.Tensor]],
    **kwargs,
) -> BatchFeature:
    """Make the processor directly callable.

    Delegates everything to :meth:`preprocess`, so ``processor(images, ...)``
    and ``processor.preprocess(images, ...)`` are interchangeable.
    """
    # Thin pass-through wrapper; all options are forwarded unchanged.
    return self.preprocess(images, **kwargs)
|
| 1502 |
|
| 1503 |
|