iszt committed · verified · Commit 2fe10b5 · Parent(s): 8a38ea0

Upload folder using huggingface_hub

Files changed (3):
  1. README.md (+158)
  2. image_processing_eye_gpu.py (+1186)
  3. preprocessor_config.json (+20)
README.md ADDED
# EyeCLAHEImageProcessor

A GPU-native Hugging Face ImageProcessor for **Color Fundus Photography (CFP)** images, designed for diabetic retinopathy detection and other retinal imaging tasks.

## Features

- **Eye Region Localization**: Automatically detects and centers on the fundus using gradient-based radial symmetry
- **Smart Cropping**: Border-minimized square crop centered on the detected eye
- **CLAHE Enhancement**: Contrast Limited Adaptive Histogram Equalization for improved visibility
- **Pure PyTorch**: No OpenCV/PIL dependencies at runtime; fully GPU-accelerated
- **Batch Processing**: Efficient batched operations for training pipelines
- **Flexible Input**: Accepts PyTorch tensors, PIL Images, and NumPy arrays

## Installation

```bash
pip install transformers torch
```

## Quick Start

```python
from transformers import AutoImageProcessor
from PIL import Image

# Load the processor
processor = AutoImageProcessor.from_pretrained("iszt/eye-clahe-processor", trust_remote_code=True)

# Process a single image
image = Image.open("fundus_image.jpg")
outputs = processor(image, return_tensors="pt")
pixel_values = outputs["pixel_values"]  # Shape: (1, 3, 512, 512)

# Process on GPU
outputs = processor(image, return_tensors="pt", device="cuda")
```

## Batch Processing

```python
import torch
from PIL import Image

# Load multiple images
images = [Image.open(f"image_{i}.jpg") for i in range(8)]

# Process the batch
outputs = processor(images, return_tensors="pt", device="cuda")
pixel_values = outputs["pixel_values"]  # Shape: (8, 3, 512, 512)
```

## With PyTorch Tensors

```python
import torch

# Tensor input: (B, C, H, W) or (C, H, W)
images = torch.rand(4, 3, 512, 512)  # Batch of 4 images

outputs = processor(images, return_tensors="pt")
```

## Configuration Options

| Parameter | Default | Description |
|-----------|---------|-------------|
| `size` | 512 | Output image size (square) |
| `do_crop` | true | Enable eye-centered cropping |
| `do_clahe` | true | Enable CLAHE contrast enhancement |
| `crop_scale_factor` | 1.1 | Padding around the detected eye region |
| `clahe_grid_size` | 8 | CLAHE tile grid size |
| `clahe_clip_limit` | 2.0 | CLAHE histogram clip limit |
| `normalization_mode` | "imagenet" | Normalization: "imagenet", "none", or "custom" |
| `min_radius_frac` | 0.1 | Minimum eye radius as a fraction of the image |
| `max_radius_frac` | 1.2 | Maximum eye radius as a fraction of the image |
| `allow_overflow` | true | Allow the crop box beyond image bounds (fills with black) |
| `softmax_temperature` | 0.1 | Temperature for eye center detection (higher = smoother) |

## Custom Configuration

```python
from transformers import AutoImageProcessor

processor = AutoImageProcessor.from_pretrained(
    "iszt/eye-clahe-processor",
    trust_remote_code=True,
    size=384,
    normalization_mode="imagenet",
    clahe_clip_limit=3.0,
    softmax_temperature=0.3,
)
```

## Processing Pipeline

The processor applies the following steps:

1. **Input Standardization**: Convert PIL/NumPy/Tensor input to a (B, C, H, W) float32 tensor in [0, 1]
2. **Eye Localization**: Detect the fundus center using radial symmetry analysis
3. **Radius Estimation**: Determine the fundus boundary from radial intensity profiles
4. **Crop & Resize**: Extract a square region centered on the eye and resize to the target size
5. **CLAHE**: Apply contrast enhancement in LAB color space (L channel only)
6. **Normalization**: Apply ImageNet normalization (optional)

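The crop in step 4 is plain box arithmetic around the detected center. A minimal scalar sketch (a hypothetical `crop_box` helper for a single image, clamping to the image bounds as when `allow_overflow` is false; the batched tensor version lives in `image_processing_eye_gpu.py`):

```python
def crop_box(cx, cy, radius, height, width, scale_factor=1.1):
    """Square crop box around (cx, cy): pad the radius, clamp to the image
    bounds, then shrink to the smaller side so the box stays square."""
    half = radius * scale_factor
    x1, y1 = max(cx - half, 0), max(cy - half, 0)
    x2, y2 = min(cx + half, width - 1), min(cy + half, height - 1)
    side = min(x2 - x1, y2 - y1)  # keep the box square after clamping
    # Recenter the square on the (possibly shifted) box center
    bx, by = (x1 + x2) / 2, (y1 + y2) / 2
    x1, y1 = max(bx - side / 2, 0), max(by - side / 2, 0)
    return x1, y1, x1 + side, y1 + side

# A well-centered eye: the box is simply radius * scale_factor on each side
box = crop_box(cx=256, cy=256, radius=200, height=512, width=512)
```

For a 512x512 image with the eye centered at (256, 256) and radius 200, this yields a box of roughly (36, 36, 476, 476).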
## Use with Vision Models

```python
import torch
from transformers import AutoImageProcessor, AutoModel
from PIL import Image

# Load the processor and model
processor = AutoImageProcessor.from_pretrained("iszt/eye-clahe-processor", trust_remote_code=True)
model = AutoModel.from_pretrained("google/vit-base-patch16-224")

# Process and run inference
image = Image.open("fundus.jpg")
inputs = processor(image, return_tensors="pt", device="cuda")

# If the processor was loaded with normalization_mode="none", apply the
# ImageNet statistics expected by pretrained models manually:
mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).cuda()
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).cuda()
inputs["pixel_values"] = (inputs["pixel_values"] - mean) / std

with torch.no_grad():
    outputs = model(**inputs)
```

## Technical Details

### Eye Center Detection

Uses a gradient-based radial symmetry approach:
- Computes Sobel gradients
- Weights pixels by darkness (the fundus is typically darker than the background)
- Finds the center where gradients point radially inward
- Uses a soft argmax for sub-pixel accuracy

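The soft-argmax step can be sketched in plain Python (a scalar illustration of the idea, not the batched torch implementation): softmax the response map, then take the weight-averaged pixel coordinates.

```python
import math

def soft_argmax_2d(response, temperature=0.1):
    """Differentiable sub-pixel peak localization: softmax over the response
    map, then the softmax-weighted mean of the pixel coordinates."""
    flat = [v for row in response for v in row]
    m = max(flat)  # subtract the max for numerical stability
    exps = [math.exp((v - m) / temperature) for v in flat]
    total = sum(exps)
    weights = [e / total for e in exps]
    h, w = len(response), len(response[0])
    cx = sum(weights[i * w + j] * j for i in range(h) for j in range(w))
    cy = sum(weights[i * w + j] * i for i in range(h) for j in range(w))
    return cx, cy

# Two equally bright neighbors: the detected center lands halfway between them
response = [[0.0, 0.0, 0.0],
            [0.0, 1.0, 1.0],
            [0.0, 0.0, 0.0]]
cx, cy = soft_argmax_2d(response, temperature=0.05)  # cx ~ 1.5, cy ~ 1.0
```

Unlike a hard `argmax`, this returns fractional coordinates and remains differentiable, which is what gives the detector sub-pixel accuracy.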
### CLAHE Implementation

Pure PyTorch CLAHE with:
- Proper sRGB to CIE LAB conversion
- Vectorized histogram computation using `scatter_add`
- Bilinear interpolation between tile CDFs
- Modifies only the L channel, preserving color information

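The clip-and-redistribute step at the heart of CLAHE can be sketched for a single tile (a plain-Python illustration; the `clahe_tile_mapping` helper is hypothetical, not part of this repo):

```python
def clahe_tile_mapping(hist, clip_limit_frac, num_bins):
    """Clip a tile histogram, redistribute the excess uniformly, and build the
    normalized CDF used to remap intensities (scalar sketch of the idea)."""
    total = sum(hist)
    # Clip limit expressed as a multiple of the mean bin count, as in classic CLAHE
    clip = clip_limit_frac * total / num_bins
    clipped = [min(h, clip) for h in hist]
    excess = total - sum(clipped)
    # Redistribute the clipped excess uniformly across all bins
    clipped = [h + excess / num_bins for h in clipped]
    # Normalized CDF: maps an input bin to an output level in [0, 1]
    cdf, running = [], 0.0
    for h in clipped:
        running += h
        cdf.append(running / total)
    return cdf

# A strongly peaked tile: clipping limits how steep the mapping can get
cdf = clahe_tile_mapping([8, 0, 0, 0], clip_limit_frac=2.0, num_bins=4)
# -> [0.625, 0.75, 0.875, 1.0]
```

Without clipping, the CDF of this tile would jump straight to 1.0 at the first bin, massively amplifying contrast (and noise); the clip limit caps that slope.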
## License

Apache 2.0

## Citation

If you use this processor in your research, please cite:

```bibtex
@software{eye_clahe_processor,
  title={EyeCLAHEImageProcessor: GPU-Native Fundus Image Preprocessing},
  year={2026},
  url={https://huggingface.co/iszt/eye-clahe-processor}
}
```
image_processing_eye_gpu.py ADDED
"""
GPU-Native Eye Image Processor for Color Fundus Photography (CFP) Images.

This module implements a fully PyTorch-based image processor that:
1. Localizes the eye/fundus region using gradient-based radial symmetry
2. Crops to a border-minimized square centered on the eye
3. Applies CLAHE for contrast enhancement
4. Outputs tensors compatible with Hugging Face vision models

Constraints:
- PyTorch only (no OpenCV, PIL, or NumPy required at runtime)
- CUDA-compatible, batch-friendly, deterministic
"""

from typing import Dict, List, Optional, Union
import math

import torch
import torch.nn.functional as F
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature

# Optional imports for broader input support
try:
    from PIL import Image
    PIL_AVAILABLE = True
except ImportError:
    PIL_AVAILABLE = False

try:
    import numpy as np
    NUMPY_AVAILABLE = True
except ImportError:
    NUMPY_AVAILABLE = False

# =============================================================================
# PHASE 1: Input & Tensor Standardization
# =============================================================================

def _pil_to_tensor(image: "Image.Image") -> torch.Tensor:
    """Convert a PIL Image to a tensor (C, H, W) in [0, 1]."""
    if not PIL_AVAILABLE:
        raise ImportError("PIL is required to process PIL Images")

    # Convert to RGB if necessary
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Use numpy as an intermediate if available, otherwise convert manually
    if NUMPY_AVAILABLE:
        arr = np.array(image, dtype=np.float32) / 255.0
        # (H, W, C) -> (C, H, W)
        tensor = torch.from_numpy(arr).permute(2, 0, 1)
    else:
        # Manual conversion without numpy
        width, height = image.size
        pixels = list(image.getdata())
        tensor = torch.tensor(pixels, dtype=torch.float32).view(height, width, 3) / 255.0
        tensor = tensor.permute(2, 0, 1)

    return tensor


def _numpy_to_tensor(arr: "np.ndarray") -> torch.Tensor:
    """Convert a numpy array to a tensor (C, H, W) in [0, 1]."""
    if not NUMPY_AVAILABLE:
        raise ImportError("NumPy is required to process numpy arrays")

    # Handle different array shapes
    if arr.ndim == 2:
        # Grayscale (H, W) -> (H, W, 1)
        arr = arr[..., None]

    if arr.ndim == 3 and arr.shape[-1] in [1, 3, 4]:
        # (H, W, C) -> (C, H, W)
        arr = arr.transpose(2, 0, 1)

    # Convert to float and normalize
    if arr.dtype == np.uint8:
        arr = arr.astype(np.float32) / 255.0
    elif arr.dtype != np.float32:
        arr = arr.astype(np.float32)

    return torch.from_numpy(arr.copy())


def standardize_input(
    images: Union[torch.Tensor, List[torch.Tensor], "Image.Image", List["Image.Image"], "np.ndarray", List["np.ndarray"]],
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """
    Convert input images to a standardized tensor format.

    Args:
        images: Input as:
            - torch.Tensor (C,H,W), (B,C,H,W), or list of tensors
            - PIL.Image.Image or list of PIL Images
            - numpy.ndarray (H,W,C), (B,H,W,C), or list of arrays
        device: Target device (defaults to the input device or CPU)

    Returns:
        Tensor of shape (B, C, H, W) in float32, range [0, 1]
    """
    # Handle single inputs by wrapping them in a list
    if PIL_AVAILABLE and isinstance(images, Image.Image):
        images = [images]
    if NUMPY_AVAILABLE and isinstance(images, np.ndarray) and images.ndim == 3:
        # Could be a single (H,W,C) image or a (B,H,W) grayscale batch -
        # assume single if the last dim looks like a channel count
        if images.shape[-1] in [1, 3, 4]:
            images = [images]

    # Convert list inputs to tensors
    if isinstance(images, list):
        converted = []
        for img in images:
            if PIL_AVAILABLE and isinstance(img, Image.Image):
                converted.append(_pil_to_tensor(img))
            elif NUMPY_AVAILABLE and isinstance(img, np.ndarray):
                converted.append(_numpy_to_tensor(img))
            elif isinstance(img, torch.Tensor):
                t = img if img.dim() == 3 else img.squeeze(0)
                converted.append(t)
            else:
                raise TypeError(f"Unsupported image type: {type(img)}")
        images = torch.stack(converted)
    elif NUMPY_AVAILABLE and isinstance(images, np.ndarray):
        # Batch of numpy arrays (B, H, W, C)
        if images.ndim == 4:
            images = images.transpose(0, 3, 1, 2)  # (B, C, H, W)
        if images.dtype == np.uint8:
            images = images.astype(np.float32) / 255.0
        images = torch.from_numpy(images.copy())

    if images.dim() == 3:
        # Add a batch dimension: (C, H, W) -> (1, C, H, W)
        images = images.unsqueeze(0)

    # Move to the target device if specified
    if device is not None:
        images = images.to(device)

    # Convert to float32 and normalize to [0, 1]
    if images.dtype == torch.uint8:
        images = images.float() / 255.0
    elif images.dtype != torch.float32:
        images = images.float()

    # Clamp to the valid range
    images = images.clamp(0.0, 1.0)

    return images


def rgb_to_grayscale(images: torch.Tensor) -> torch.Tensor:
    """
    Convert RGB images to grayscale using the luminance formula.

    Y = 0.299 * R + 0.587 * G + 0.114 * B

    Args:
        images: Tensor of shape (B, 3, H, W)

    Returns:
        Tensor of shape (B, 1, H, W)
    """
    # Luminance weights
    weights = torch.tensor([0.299, 0.587, 0.114], device=images.device, dtype=images.dtype)
    weights = weights.view(1, 3, 1, 1)

    grayscale = (images * weights).sum(dim=1, keepdim=True)
    return grayscale

# =============================================================================
# PHASE 2: Eye Region Localization (GPU-Safe)
# =============================================================================

def create_sobel_kernels(device: torch.device, dtype: torch.dtype) -> tuple:
    """
    Create Sobel kernels for gradient computation.

    Returns:
        Tuple of (sobel_x, sobel_y) kernels, each of shape (1, 1, 3, 3)
    """
    sobel_x = torch.tensor([
        [-1, 0, 1],
        [-2, 0, 2],
        [-1, 0, 1]
    ], device=device, dtype=dtype).view(1, 1, 3, 3)

    sobel_y = torch.tensor([
        [-1, -2, -1],
        [ 0,  0,  0],
        [ 1,  2,  1]
    ], device=device, dtype=dtype).view(1, 1, 3, 3)

    return sobel_x, sobel_y


def compute_gradients(grayscale: torch.Tensor) -> tuple:
    """
    Compute image gradients using Sobel filters.

    Args:
        grayscale: Tensor of shape (B, 1, H, W)

    Returns:
        Tuple of (grad_x, grad_y, grad_magnitude)
    """
    sobel_x, sobel_y = create_sobel_kernels(grayscale.device, grayscale.dtype)

    # Apply Sobel filters with padding to maintain size
    grad_x = F.conv2d(grayscale, sobel_x, padding=1)
    grad_y = F.conv2d(grayscale, sobel_y, padding=1)

    # Compute the gradient magnitude
    grad_magnitude = torch.sqrt(grad_x ** 2 + grad_y ** 2 + 1e-8)

    return grad_x, grad_y, grad_magnitude


def compute_radial_symmetry_response(
    grayscale: torch.Tensor,
    grad_x: torch.Tensor,
    grad_y: torch.Tensor,
    grad_magnitude: torch.Tensor,
) -> torch.Tensor:
    """
    Compute the radial symmetry response for circle detection.

    This weights regions that are:
    1. Dark (low intensity - typical of pupil/iris)
    2. Crossed by strong radial gradients pointing inward

    Args:
        grayscale: Grayscale image (B, 1, H, W)
        grad_x, grad_y: Gradient components
        grad_magnitude: Gradient magnitude

    Returns:
        Radial symmetry response map (B, 1, H, W)
    """
    B, _, H, W = grayscale.shape
    device = grayscale.device
    dtype = grayscale.dtype

    # Create coordinate grids
    y_coords = torch.arange(H, device=device, dtype=dtype).view(1, 1, H, 1).expand(B, 1, H, W)
    x_coords = torch.arange(W, device=device, dtype=dtype).view(1, 1, 1, W).expand(B, 1, H, W)

    # Compute the center of mass of dark regions as an initial estimate.
    # Invert intensity so dark regions have high weight.
    dark_weight = 1.0 - grayscale
    dark_weight = dark_weight ** 2  # Emphasize darker regions

    # Normalize weights
    weight_sum = dark_weight.sum(dim=(2, 3), keepdim=True) + 1e-8

    # Weighted center of mass
    cx_init = (dark_weight * x_coords).sum(dim=(2, 3), keepdim=True) / weight_sum
    cy_init = (dark_weight * y_coords).sum(dim=(2, 3), keepdim=True) / weight_sum

    # Compute vectors from each pixel to the estimated center
    dx_to_center = cx_init - x_coords
    dy_to_center = cy_init - y_coords
    dist_to_center = torch.sqrt(dx_to_center ** 2 + dy_to_center ** 2 + 1e-8)

    # Normalize direction vectors
    dx_norm = dx_to_center / dist_to_center
    dy_norm = dy_to_center / dist_to_center

    # Normalize gradient vectors
    grad_norm = grad_magnitude + 1e-8
    gx_norm = grad_x / grad_norm
    gy_norm = grad_y / grad_norm

    # Radial symmetry: the gradient should point toward the center.
    # Dot product between the gradient and the direction to the center.
    radial_alignment = gx_norm * dx_norm + gy_norm * dy_norm

    # Weight by gradient magnitude and darkness
    response = radial_alignment * grad_magnitude * dark_weight

    # Apply Gaussian smoothing for a robust response
    kernel_size = max(H, W) // 8
    if kernel_size % 2 == 0:
        kernel_size += 1
    kernel_size = max(kernel_size, 5)

    sigma = kernel_size / 6.0

    # Create a 1D Gaussian kernel
    x = torch.arange(kernel_size, device=device, dtype=dtype) - kernel_size // 2
    gaussian_1d = torch.exp(-x ** 2 / (2 * sigma ** 2))
    gaussian_1d = gaussian_1d / gaussian_1d.sum()

    # Separable 2D convolution
    gaussian_1d_h = gaussian_1d.view(1, 1, 1, kernel_size)
    gaussian_1d_v = gaussian_1d.view(1, 1, kernel_size, 1)

    pad_h = kernel_size // 2
    pad_v = kernel_size // 2

    response = F.pad(response, (pad_h, pad_h, 0, 0), mode='reflect')
    response = F.conv2d(response, gaussian_1d_h)
    response = F.pad(response, (0, 0, pad_v, pad_v), mode='reflect')
    response = F.conv2d(response, gaussian_1d_v)

    return response


def soft_argmax_2d(response: torch.Tensor, temperature: float = 0.1) -> tuple:
    """
    Compute a soft argmax to find the center coordinates.

    Args:
        response: Response map (B, 1, H, W)
        temperature: Softmax temperature (lower = sharper)

    Returns:
        Tuple of (cx, cy), each of shape (B,)
    """
    B, _, H, W = response.shape
    device = response.device
    dtype = response.dtype

    # Flatten spatial dimensions
    response_flat = response.view(B, -1)

    # Apply softmax with temperature
    weights = F.softmax(response_flat / temperature, dim=1)
    weights = weights.view(B, 1, H, W)

    # Create coordinate grids
    y_coords = torch.arange(H, device=device, dtype=dtype).view(1, 1, H, 1).expand(B, 1, H, W)
    x_coords = torch.arange(W, device=device, dtype=dtype).view(1, 1, 1, W).expand(B, 1, H, W)

    # Weighted sum of coordinates
    cx = (weights * x_coords).sum(dim=(2, 3)).squeeze(-1)  # (B,)
    cy = (weights * y_coords).sum(dim=(2, 3)).squeeze(-1)  # (B,)

    return cx, cy


def estimate_eye_center(
    images: torch.Tensor,
    softmax_temperature: float = 0.1,
) -> tuple:
    """
    Estimate the center of the eye region in each image.

    Args:
        images: RGB images of shape (B, 3, H, W)
        softmax_temperature: Temperature for the soft argmax (lower = sharper peak
            detection, higher = more averaging). Typical range: 0.01-1.0. The
            default of 0.1 works well for most fundus images; use higher values
            (0.3-0.5) for noisy images.

    Returns:
        Tuple of (cx, cy), each of shape (B,) in pixel coordinates
    """
    grayscale = rgb_to_grayscale(images)
    grad_x, grad_y, grad_magnitude = compute_gradients(grayscale)
    response = compute_radial_symmetry_response(grayscale, grad_x, grad_y, grad_magnitude)
    cx, cy = soft_argmax_2d(response, temperature=softmax_temperature)

    return cx, cy


# =============================================================================
# PHASE 2.3: Radius Estimation
# =============================================================================

def estimate_radius(
    images: torch.Tensor,
    cx: torch.Tensor,
    cy: torch.Tensor,
    num_radii: int = 100,
    num_angles: int = 36,
    min_radius_frac: float = 0.1,
    max_radius_frac: float = 0.5,
) -> torch.Tensor:
    """
    Estimate the radius of the eye region by analyzing radial intensity profiles.

    Args:
        images: RGB images (B, 3, H, W)
        cx, cy: Center coordinates (B,)
        num_radii: Number of radius samples
        num_angles: Number of angular samples
        min_radius_frac: Minimum radius as a fraction of image size
        max_radius_frac: Maximum radius as a fraction of image size

    Returns:
        Estimated radius for each image (B,)
    """
    B, _, H, W = images.shape
    device = images.device
    dtype = images.dtype

    grayscale = rgb_to_grayscale(images)  # (B, 1, H, W)

    min_dim = min(H, W)
    min_radius = int(min_radius_frac * min_dim)
    max_radius = int(max_radius_frac * min_dim)

    # Create radius and angle samples
    radii = torch.linspace(min_radius, max_radius, num_radii, device=device, dtype=dtype)
    angles = torch.linspace(0, 2 * math.pi, num_angles + 1, device=device, dtype=dtype)[:-1]

    # Create the sampling grid: (num_angles, num_radii)
    cos_angles = torch.cos(angles).view(-1, 1)  # (num_angles, 1)
    sin_angles = torch.sin(angles).view(-1, 1)  # (num_angles, 1)

    # Offset coordinates from the center
    dx = cos_angles * radii  # (num_angles, num_radii)
    dy = sin_angles * radii  # (num_angles, num_radii)

    # Compute absolute coordinates for each batch item
    # cx, cy: (B,) -> expand to (B, num_angles, num_radii)
    cx_expanded = cx.view(B, 1, 1).expand(B, num_angles, num_radii)
    cy_expanded = cy.view(B, 1, 1).expand(B, num_angles, num_radii)

    sample_x = cx_expanded + dx.unsqueeze(0)  # (B, num_angles, num_radii)
    sample_y = cy_expanded + dy.unsqueeze(0)  # (B, num_angles, num_radii)

    # Normalize to [-1, 1] for grid_sample
    sample_x_norm = 2.0 * sample_x / (W - 1) - 1.0
    sample_y_norm = 2.0 * sample_y / (H - 1) - 1.0

    # Create the sampling grid: (B, num_angles, num_radii, 2)
    grid = torch.stack([sample_x_norm, sample_y_norm], dim=-1)

    # Sample intensities
    sampled = F.grid_sample(
        grayscale, grid, mode='bilinear', padding_mode='border', align_corners=True
    )  # (B, 1, num_angles, num_radii)

    # Average over angles to get the radial profile
    radial_profile = sampled.mean(dim=2).squeeze(1)  # (B, num_radii)

    # Compute the gradient of the radial profile (looking for a strong
    # negative gradient at the iris edge)
    radial_gradient = radial_profile[:, 1:] - radial_profile[:, :-1]  # (B, num_radii-1)

    # Find the radius with the strongest negative gradient (edge of the iris).
    # Weight by radius to prefer larger circles (avoids the pupil boundary).
    radius_weights = torch.linspace(0.5, 1.5, num_radii - 1, device=device, dtype=dtype)
    weighted_gradient = radial_gradient * radius_weights.unsqueeze(0)

    # Find the minimum (strongest negative gradient)
    min_idx = weighted_gradient.argmin(dim=1)  # (B,)

    # Convert the index to a radius value
    estimated_radius = radii[min_idx + 1]  # +1 because the gradient has one less element

    # Clamp to the valid range
    estimated_radius = estimated_radius.clamp(min_radius, max_radius)

    return estimated_radius

# =============================================================================
# PHASE 3: Border-Minimized Square Crop
# =============================================================================

def compute_crop_box(
    cx: torch.Tensor,
    cy: torch.Tensor,
    radius: torch.Tensor,
    H: int,
    W: int,
    scale_factor: float = 1.1,
    allow_overflow: bool = False,
) -> tuple:
    """
    Compute the square bounding box for cropping.

    Args:
        cx, cy: Center coordinates (B,)
        radius: Estimated radius (B,)
        H, W: Image dimensions
        scale_factor: Multiply the radius by this factor for padding
        allow_overflow: If True, don't clamp the box to image bounds (for pre-cropped images)

    Returns:
        Tuple of (x1, y1, x2, y2), each of shape (B,)
    """
    # Compute the half side length
    half_side = radius * scale_factor

    # Initial box centered on the detected eye
    x1 = cx - half_side
    y1 = cy - half_side
    x2 = cx + half_side
    y2 = cy + half_side

    if allow_overflow:
        # Keep the box centered on the eye; don't clamp.
        # Out-of-bounds regions will be filled with black during cropping.
        return x1, y1, x2, y2

    # Clamp to image bounds while maintaining a square shape.
    # If the box exceeds the bounds, shift it.
    x1 = x1.clamp(min=0)
    y1 = y1.clamp(min=0)
    x2 = x2.clamp(max=W - 1)
    y2 = y2.clamp(max=H - 1)

    # Ensure a square by taking the minimum side
    side_x = x2 - x1
    side_y = y2 - y1
    side = torch.minimum(side_x, side_y)

    # Recenter the box
    cx_new = (x1 + x2) / 2
    cy_new = (y1 + y2) / 2

    x1 = (cx_new - side / 2).clamp(min=0)
    y1 = (cy_new - side / 2).clamp(min=0)
    x2 = x1 + side
    y2 = y1 + side

    # Final clamp
    x2 = x2.clamp(max=W - 1)
    y2 = y2.clamp(max=H - 1)

    return x1, y1, x2, y2


def batch_crop_and_resize(
    images: torch.Tensor,
    x1: torch.Tensor,
    y1: torch.Tensor,
    x2: torch.Tensor,
    y2: torch.Tensor,
    output_size: int,
    padding_mode: str = 'border',
) -> torch.Tensor:
    """
    Crop and resize images using grid_sample for GPU efficiency.

    Args:
        images: Input images (B, C, H, W)
        x1, y1, x2, y2: Crop coordinates (B,) - can extend beyond image bounds
        output_size: Output square size
        padding_mode: How to handle out-of-bounds sampling:
            - 'border': repeat edge pixels (default)
            - 'zeros': fill with black (useful for pre-cropped images)

    Returns:
        Cropped and resized images (B, C, output_size, output_size)
    """
    B, C, H, W = images.shape
    device = images.device
    dtype = images.dtype

    # Create output grid coordinates
    out_coords = torch.linspace(0, 1, output_size, device=device, dtype=dtype)
    out_y, out_x = torch.meshgrid(out_coords, out_coords, indexing='ij')
    out_grid = torch.stack([out_x, out_y], dim=-1)  # (output_size, output_size, 2)
    out_grid = out_grid.unsqueeze(0).expand(B, -1, -1, -1)  # (B, output_size, output_size, 2)

    # Scale the grid to crop coordinates:
    # out_grid is in [0, 1] and needs to map to [x1, x2] and [y1, y2]
    x1 = x1.view(B, 1, 1, 1)
    y1 = y1.view(B, 1, 1, 1)
    x2 = x2.view(B, 1, 1, 1)
    y2 = y2.view(B, 1, 1, 1)

    # Map [0, 1] to pixel coordinates
    sample_x = x1 + out_grid[..., 0:1] * (x2 - x1)
    sample_y = y1 + out_grid[..., 1:2] * (y2 - y1)

    # Normalize to [-1, 1] for grid_sample
    sample_x_norm = 2.0 * sample_x / (W - 1) - 1.0
    sample_y_norm = 2.0 * sample_y / (H - 1) - 1.0

    grid = torch.cat([sample_x_norm, sample_y_norm], dim=-1)  # (B, output_size, output_size, 2)

    # Sample with the specified padding mode
    cropped = F.grid_sample(
        images, grid, mode='bilinear', padding_mode=padding_mode, align_corners=True
    )

    return cropped

587
+ # =============================================================================
588
+ # PHASE 4: CLAHE (Torch-Native)
589
+ # =============================================================================
590
+
591
+ def _srgb_to_linear(rgb: torch.Tensor) -> torch.Tensor:
592
+ """Convert sRGB to linear RGB."""
593
+ threshold = 0.04045
594
+ linear = torch.where(
595
+ rgb <= threshold,
596
+ rgb / 12.92,
597
+ ((rgb + 0.055) / 1.055) ** 2.4
598
+ )
599
+ return linear
600
+
601
+
602
+ def _linear_to_srgb(linear: torch.Tensor) -> torch.Tensor:
603
+ """Convert linear RGB to sRGB."""
604
+ threshold = 0.0031308
605
+ srgb = torch.where(
606
+ linear <= threshold,
607
+ linear * 12.92,
608
+         1.055 * (linear ** (1.0 / 2.4)) - 0.055
+     )
+     return srgb
+
+
+ def rgb_to_lab(images: torch.Tensor) -> tuple:
+     """
+     Convert sRGB images to CIE LAB color space.
+
+     This is a proper LAB conversion that:
+     1. Converts sRGB to linear RGB
+     2. Converts linear RGB to XYZ
+     3. Converts XYZ to LAB
+
+     Args:
+         images: RGB images (B, C, H, W) in [0, 1] sRGB
+
+     Returns:
+         Tuple of (L, a, b) where:
+         - L: Luminance in [0, 1] (normalized from [0, 100])
+         - a, b: Chrominance (normalized to roughly [0, 1])
+     """
+     # Step 1: sRGB to linear RGB
+     linear_rgb = _srgb_to_linear(images)
+
+     # Step 2: Linear RGB to XYZ (D65 illuminant)
+     # RGB to XYZ matrix
+     r = linear_rgb[:, 0:1, :, :]
+     g = linear_rgb[:, 1:2, :, :]
+     b = linear_rgb[:, 2:3, :, :]
+
+     x = 0.4124564 * r + 0.3575761 * g + 0.1804375 * b
+     y = 0.2126729 * r + 0.7151522 * g + 0.0721750 * b
+     z = 0.0193339 * r + 0.1191920 * g + 0.9503041 * b
+
+     # D65 reference white
+     xn, yn, zn = 0.95047, 1.0, 1.08883
+
+     x = x / xn
+     y = y / yn
+     z = z / zn
+
+     # Step 3: XYZ to LAB
+     delta = 6.0 / 29.0
+     delta_cube = delta ** 3
+
+     def f(t):
+         return torch.where(
+             t > delta_cube,
+             t ** (1.0 / 3.0),
+             t / (3.0 * delta ** 2) + 4.0 / 29.0
+         )
+
+     fx = f(x)
+     fy = f(y)
+     fz = f(z)
+
+     L = 116.0 * fy - 16.0      # Range [0, 100]
+     a = 500.0 * (fx - fy)      # Range roughly [-128, 127]
+     b_ch = 200.0 * (fy - fz)   # Range roughly [-128, 127]
+
+     # Normalize to convenient ranges for processing
+     L = L / 100.0          # [0, 1]
+     a = a / 256.0 + 0.5    # Roughly [0, 1]
+     b_ch = b_ch / 256.0 + 0.5  # Roughly [0, 1]
+
+     return L, a, b_ch
+
+
+ def lab_to_rgb(L: torch.Tensor, a: torch.Tensor, b_ch: torch.Tensor) -> torch.Tensor:
+     """
+     Convert CIE LAB to sRGB.
+
+     Args:
+         L: Luminance in [0, 1] (normalized from [0, 100])
+         a, b_ch: Chrominance (normalized, roughly [0, 1])
+
+     Returns:
+         RGB images (B, 3, H, W) in [0, 1] sRGB
+     """
+     # Denormalize
+     L_lab = L * 100.0
+     a_lab = (a - 0.5) * 256.0
+     b_lab = (b_ch - 0.5) * 256.0
+
+     # LAB to XYZ
+     fy = (L_lab + 16.0) / 116.0
+     fx = a_lab / 500.0 + fy
+     fz = fy - b_lab / 200.0
+
+     delta = 6.0 / 29.0
+
+     def f_inv(t):
+         return torch.where(
+             t > delta,
+             t ** 3,
+             3.0 * (delta ** 2) * (t - 4.0 / 29.0)
+         )
+
+     # D65 reference white
+     xn, yn, zn = 0.95047, 1.0, 1.08883
+
+     x = xn * f_inv(fx)
+     y = yn * f_inv(fy)
+     z = zn * f_inv(fz)
+
+     # XYZ to linear RGB
+     r = 3.2404542 * x - 1.5371385 * y - 0.4985314 * z
+     g = -0.9692660 * x + 1.8760108 * y + 0.0415560 * z
+     b = 0.0556434 * x - 0.2040259 * y + 1.0572252 * z
+
+     linear_rgb = torch.cat([r, g, b], dim=1)
+
+     # Clamp before gamma correction to avoid NaN from negative values
+     linear_rgb = linear_rgb.clamp(0.0, 1.0)
+
+     # Linear RGB to sRGB
+     srgb = _linear_to_srgb(linear_rgb)
+
+     return srgb.clamp(0.0, 1.0)
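The branchy `f`/`f_inv` pair above is the standard CIE piecewise cube-root transfer function and its inverse. A minimal stdlib-only sketch (scalar floats instead of tensors; names chosen here for illustration) shows that the two branches round-trip exactly:

```python
# Scalar sketch of the XYZ<->LAB transfer pair used above (stdlib only).
delta = 6.0 / 29.0

def f(t):
    # forward: cube root above the knee, linear segment below it
    return t ** (1.0 / 3.0) if t > delta ** 3 else t / (3.0 * delta ** 2) + 4.0 / 29.0

def f_inv(u):
    # inverse: cube above delta, matching linear segment below
    return u ** 3 if u > delta else 3.0 * delta ** 2 * (u - 4.0 / 29.0)

# round-trip is the identity across both branches
for t in (0.0, 0.004, 0.18, 1.0):
    assert abs(f_inv(f(t)) - t) < 1e-12
```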
+
+
+ def compute_histogram(
+     tensor: torch.Tensor,
+     num_bins: int = 256,
+ ) -> torch.Tensor:
+     """
+     Compute histogram for a batch of single-channel images.
+
+     Args:
+         tensor: Input tensor (B, 1, H, W) with values in [0, 1]
+         num_bins: Number of histogram bins
+
+     Returns:
+         Histograms (B, num_bins)
+     """
+     B = tensor.shape[0]
+     device = tensor.device
+     dtype = tensor.dtype
+
+     # Flatten spatial dimensions
+     flat = tensor.view(B, -1)  # (B, H*W)
+
+     # Bin indices
+     bin_indices = (flat * (num_bins - 1)).long().clamp(0, num_bins - 1)
+
+     # Compute all per-image histograms in one scatter_add along the bin dimension
+     histograms = torch.zeros(B, num_bins, device=device, dtype=dtype)
+     ones = torch.ones_like(flat, dtype=dtype)
+     histograms.scatter_add_(1, bin_indices, ones)
+
+     return histograms
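The binning rule above maps a value v in [0, 1] to bin `floor(v * (num_bins - 1))`, clamped to the valid range. A quick stdlib-only check of that mapping on plain floats:

```python
# Bin-index rule from compute_histogram, applied to plain floats (stdlib only).
num_bins = 4
values = [0.0, 0.2, 0.5, 0.99, 1.0]
bins = [min(max(int(v * (num_bins - 1)), 0), num_bins - 1) for v in values]
print(bins)  # [0, 0, 1, 2, 3]
```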
+
+
+ def clahe_single_tile(
+     tile: torch.Tensor,
+     clip_limit: float,
+     num_bins: int = 256,
+ ) -> torch.Tensor:
+     """
+     Compute the CLAHE lookup table (clipped, normalized CDF) for a single tile.
+
+     Args:
+         tile: Input tile (B, 1, tile_h, tile_w)
+         clip_limit: Histogram clip limit
+         num_bins: Number of histogram bins
+
+     Returns:
+         CDF lookup table (B, num_bins)
+     """
+     B, _, tile_h, tile_w = tile.shape
+     num_pixels = tile_h * tile_w
+
+     # Compute histogram
+     hist = compute_histogram(tile, num_bins)  # (B, num_bins)
+
+     # Clip histogram
+     clip_value = clip_limit * num_pixels / num_bins
+     excess = (hist - clip_value).clamp(min=0).sum(dim=1, keepdim=True)  # (B, 1)
+     hist = hist.clamp(max=clip_value)
+
+     # Redistribute excess uniformly
+     redistribution = excess / num_bins
+     hist = hist + redistribution
+
+     # Compute CDF
+     cdf = hist.cumsum(dim=1)  # (B, num_bins)
+
+     # Normalize CDF to [0, 1]
+     cdf_min = cdf[:, 0:1]
+     cdf_max = cdf[:, -1:]
+     cdf = (cdf - cdf_min) / (cdf_max - cdf_min + 1e-8)
+
+     return cdf
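The clip-and-redistribute step above caps each bin at `clip_limit * num_pixels / num_bins` and spreads the clipped mass uniformly, which limits per-tile contrast amplification while preserving total histogram mass. A toy stdlib-only version of the same single-pass scheme:

```python
# CLAHE histogram clipping on a toy 4-bin histogram (stdlib only).
hist = [10.0, 50.0, 5.0, 15.0]                        # raw counts, total mass 80
clip_value = 20.0
excess = sum(max(h - clip_value, 0.0) for h in hist)  # mass above the cap: 30.0
hist = [min(h, clip_value) for h in hist]             # cap each bin
hist = [h + excess / len(hist) for h in hist]         # spread excess uniformly
print(hist)       # [17.5, 27.5, 12.5, 22.5]
print(sum(hist))  # 80.0 -- total mass preserved
```

As in the code above, this is a single clip pass: redistribution can push a bin slightly back over the cap, which CLAHE implementations commonly accept.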
+
+
+ def apply_clahe_vectorized(
+     images: torch.Tensor,
+     grid_size: int = 8,
+     clip_limit: float = 2.0,
+     num_bins: int = 256,
+ ) -> torch.Tensor:
+     """
+     Vectorized CLAHE implementation (more efficient for GPU).
+
+     Args:
+         images: Input images (B, C, H, W)
+         grid_size: Number of tiles in each dimension
+         clip_limit: Histogram clip limit
+         num_bins: Number of histogram bins
+
+     Returns:
+         CLAHE-enhanced images (B, C, H, W)
+     """
+     B, C, H, W = images.shape
+     device = images.device
+     dtype = images.dtype
+
+     # Work on luminance only
+     if C == 3:
+         L, a, b_ch = rgb_to_lab(images)
+     else:
+         L = images.clone()
+         a = b_ch = None
+
+     # Ensure divisibility by grid_size
+     pad_h = (grid_size - H % grid_size) % grid_size
+     pad_w = (grid_size - W % grid_size) % grid_size
+
+     if pad_h > 0 or pad_w > 0:
+         L_padded = F.pad(L, (0, pad_w, 0, pad_h), mode='reflect')
+     else:
+         L_padded = L
+
+     _, _, H_pad, W_pad = L_padded.shape
+     tile_h = H_pad // grid_size
+     tile_w = W_pad // grid_size
+
+     # Reshape into tiles: (B, 1, grid_size, tile_h, grid_size, tile_w)
+     L_tiles = L_padded.view(B, 1, grid_size, tile_h, grid_size, tile_w)
+     L_tiles = L_tiles.permute(0, 2, 4, 1, 3, 5)  # (B, grid_size, grid_size, 1, tile_h, tile_w)
+     L_tiles = L_tiles.reshape(B * grid_size * grid_size, 1, tile_h, tile_w)
+
+     # Compute histograms for all tiles at once
+     num_pixels = tile_h * tile_w
+     flat = L_tiles.view(B * grid_size * grid_size, -1)
+     bin_indices = (flat * (num_bins - 1)).long().clamp(0, num_bins - 1)
+
+     # Vectorized histogram computation
+     histograms = torch.zeros(B * grid_size * grid_size, num_bins, device=device, dtype=dtype)
+     histograms.scatter_add_(1, bin_indices, torch.ones_like(flat))
+
+     # Clip and redistribute
+     clip_value = clip_limit * num_pixels / num_bins
+     excess = (histograms - clip_value).clamp(min=0).sum(dim=1, keepdim=True)
+     histograms = histograms.clamp(max=clip_value)
+     histograms = histograms + excess / num_bins
+
+     # Compute CDFs
+     cdfs = histograms.cumsum(dim=1)
+     cdf_min = cdfs[:, 0:1]
+     cdf_max = cdfs[:, -1:]
+     cdfs = (cdfs - cdf_min) / (cdf_max - cdf_min + 1e-8)
+
+     # Reshape CDFs: (B, grid_size, grid_size, num_bins)
+     cdfs = cdfs.view(B, grid_size, grid_size, num_bins)
+
+     # Create coordinate grids for interpolation
+     y_coords = torch.arange(H_pad, device=device, dtype=dtype)
+     x_coords = torch.arange(W_pad, device=device, dtype=dtype)
+
+     # Map to tile coordinates (centered on tiles)
+     tile_y = (y_coords + 0.5) / tile_h - 0.5
+     tile_x = (x_coords + 0.5) / tile_w - 0.5
+
+     tile_y = tile_y.clamp(0, grid_size - 1.001)
+     tile_x = tile_x.clamp(0, grid_size - 1.001)
+
+     # Integer indices and weights
+     ty0 = tile_y.long().clamp(0, grid_size - 2)
+     tx0 = tile_x.long().clamp(0, grid_size - 2)
+     ty1 = (ty0 + 1).clamp(max=grid_size - 1)
+     tx1 = (tx0 + 1).clamp(max=grid_size - 1)
+
+     wy = (tile_y - ty0.float()).view(1, H_pad, 1)
+     wx = (tile_x - tx0.float()).view(1, 1, W_pad)
+
+     # Get bin indices for all pixels
+     bin_idx = (L_padded * (num_bins - 1)).long().clamp(0, num_bins - 1)  # (B, 1, H_pad, W_pad)
+     bin_idx = bin_idx.squeeze(1)  # (B, H_pad, W_pad)
+
+     # Gather CDF values for each corner:
+     # we need cdfs[b, ty, tx, bin_idx[b, y, x]] for all four surrounding tiles.
+
+     # Expand indices for gathering
+     b_idx = torch.arange(B, device=device).view(B, 1, 1).expand(B, H_pad, W_pad)
+     ty0_exp = ty0.view(1, H_pad, 1).expand(B, H_pad, W_pad)
+     ty1_exp = ty1.view(1, H_pad, 1).expand(B, H_pad, W_pad)
+     tx0_exp = tx0.view(1, 1, W_pad).expand(B, H_pad, W_pad)
+     tx1_exp = tx1.view(1, 1, W_pad).expand(B, H_pad, W_pad)
+
+     # Gather using advanced indexing
+     v00 = cdfs[b_idx, ty0_exp, tx0_exp, bin_idx]  # (B, H_pad, W_pad)
+     v01 = cdfs[b_idx, ty0_exp, tx1_exp, bin_idx]
+     v10 = cdfs[b_idx, ty1_exp, tx0_exp, bin_idx]
+     v11 = cdfs[b_idx, ty1_exp, tx1_exp, bin_idx]
+
+     # Bilinear interpolation between the four neighboring tile CDFs
+     L_out = (1 - wy) * (1 - wx) * v00 + (1 - wy) * wx * v01 + wy * (1 - wx) * v10 + wy * wx * v11
+     L_out = L_out.unsqueeze(1)  # (B, 1, H_pad, W_pad)
+
+     # Remove padding
+     if pad_h > 0 or pad_w > 0:
+         L_out = L_out[:, :, :H, :W]
+
+     # Convert back to RGB
+     if C == 3:
+         output = lab_to_rgb(L_out, a, b_ch)
+     else:
+         output = L_out
+
+     return output
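Each output pixel above is a bilinear blend of the four surrounding tiles' CDF values, which is what removes visible seams at tile boundaries. The weight algebra in the `L_out` expression, reduced to scalars (stdlib-only sketch with toy values):

```python
# Bilinear blend of four corner CDF values, mirroring the L_out expression (stdlib only).
v00, v01, v10, v11 = 0.0, 1.0, 2.0, 3.0  # toy corner values
wy, wx = 0.25, 0.5                       # fractional distances from the ty0/tx0 corner
out = (1 - wy) * (1 - wx) * v00 + (1 - wy) * wx * v01 + wy * (1 - wx) * v10 + wy * wx * v11
# the four weights always sum to 1, so the result stays inside [min, max] of the corners
print(out)  # 1.0
```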
940
+
941
+
942
+ # =============================================================================
943
+ # PHASE 5: Resize & Normalization
944
+ # =============================================================================
945
+
946
+ # ImageNet normalization constants
947
+ IMAGENET_MEAN = [0.485, 0.456, 0.406]
948
+ IMAGENET_STD = [0.229, 0.224, 0.225]
949
+
950
+
951
+ def resize_images(
952
+ images: torch.Tensor,
953
+ size: int,
954
+ mode: str = 'bilinear',
955
+ antialias: bool = True,
956
+ ) -> torch.Tensor:
957
+ """
958
+ Resize images to target size.
959
+
960
+ Args:
961
+ images: Input images (B, C, H, W)
962
+ size: Target size (square)
963
+ mode: Interpolation mode
964
+ antialias: Whether to use antialiasing
965
+
966
+ Returns:
967
+ Resized images (B, C, size, size)
968
+ """
969
+ return F.interpolate(
970
+ images,
971
+ size=(size, size),
972
+ mode=mode,
973
+ align_corners=False if mode in ['bilinear', 'bicubic'] else None,
974
+ antialias=antialias if mode in ['bilinear', 'bicubic'] else False,
975
+ )
976
+
977
+
978
+ def normalize_images(
979
+ images: torch.Tensor,
980
+ mean: Optional[List[float]] = None,
981
+ std: Optional[List[float]] = None,
982
+ mode: str = 'imagenet',
983
+ ) -> torch.Tensor:
984
+ """
985
+ Normalize images.
986
+
987
+ Args:
988
+ images: Input images (B, C, H, W) in [0, 1]
989
+ mean: Custom mean (per channel)
990
+ std: Custom std (per channel)
991
+ mode: 'imagenet', 'none', or 'custom'
992
+
993
+ Returns:
994
+ Normalized images
995
+ """
996
+ if mode == 'none':
997
+ return images
998
+
999
+ if mode == 'imagenet':
1000
+ mean = IMAGENET_MEAN
1001
+ std = IMAGENET_STD
1002
+ elif mode == 'custom':
1003
+ if mean is None or std is None:
1004
+ raise ValueError("Custom mode requires mean and std")
1005
+ else:
1006
+ raise ValueError(f"Unknown normalization mode: {mode}")
1007
+
1008
+ device = images.device
1009
+ dtype = images.dtype
1010
+
1011
+ mean_tensor = torch.tensor(mean, device=device, dtype=dtype).view(1, -1, 1, 1)
1012
+ std_tensor = torch.tensor(std, device=device, dtype=dtype).view(1, -1, 1, 1)
1013
+
1014
+ return (images - mean_tensor) / std_tensor
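Normalization is the usual per-channel `(x - mean) / std`. For a single mid-gray pixel under the ImageNet constants above (stdlib-only sketch):

```python
# Per-channel ImageNet normalization of one pixel (stdlib only).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
pixel = [0.5, 0.5, 0.5]
normed = [(p - m) / s for p, m, s in zip(pixel, IMAGENET_MEAN, IMAGENET_STD)]
print([round(v, 4) for v in normed])  # [0.0655, 0.1964, 0.4178]
```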
+
+
+ # =============================================================================
+ # PHASE 6: Hugging Face ImageProcessor Integration
+ # =============================================================================
+
+ class EyeCLAHEImageProcessor(BaseImageProcessor):
+     """
+     GPU-native image processor for Color Fundus Photography (CFP) images.
+
+     This processor:
+     1. Localizes the eye region using gradient-based radial symmetry
+     2. Crops to a border-minimized square centered on the eye
+     3. Applies CLAHE for contrast enhancement
+     4. Resizes and normalizes for vision model input
+
+     All operations are implemented in pure PyTorch and are CUDA-compatible.
+     """
+
+     model_input_names = ["pixel_values"]
+
+     def __init__(
+         self,
+         size: int = 224,
+         crop_scale_factor: float = 1.1,
+         clahe_grid_size: int = 8,
+         clahe_clip_limit: float = 2.0,
+         normalization_mode: str = "imagenet",
+         custom_mean: Optional[List[float]] = None,
+         custom_std: Optional[List[float]] = None,
+         do_clahe: bool = True,
+         do_crop: bool = True,
+         min_radius_frac: float = 0.1,
+         max_radius_frac: float = 0.5,
+         allow_overflow: bool = False,
+         softmax_temperature: float = 0.1,
+         **kwargs,
+     ):
+         """
+         Initialize the EyeCLAHEImageProcessor.
+
+         Args:
+             size: Output image size (square)
+             crop_scale_factor: Scale factor for crop box (relative to detected radius)
+             clahe_grid_size: Number of tiles for CLAHE
+             clahe_clip_limit: Histogram clip limit for CLAHE
+             normalization_mode: 'imagenet', 'none', or 'custom'
+             custom_mean: Custom normalization mean (if mode='custom')
+             custom_std: Custom normalization std (if mode='custom')
+             do_clahe: Whether to apply CLAHE
+             do_crop: Whether to perform eye-centered cropping
+             min_radius_frac: Minimum radius as fraction of image size
+             max_radius_frac: Maximum radius as fraction of image size
+             allow_overflow: If True, allow the crop box to extend beyond image bounds
+                 and fill missing regions with black. Useful for pre-cropped
+                 images where the fundus circle is partially cut off.
+             softmax_temperature: Temperature for soft argmax in eye center detection.
+                 Lower values (0.01-0.1) give sharper peak detection; higher values
+                 (0.3-0.5) provide more averaging for noisy images. Default: 0.1.
+         """
+         super().__init__(**kwargs)
+
+         self.size = size
+         self.crop_scale_factor = crop_scale_factor
+         self.clahe_grid_size = clahe_grid_size
+         self.clahe_clip_limit = clahe_clip_limit
+         self.normalization_mode = normalization_mode
+         self.custom_mean = custom_mean
+         self.custom_std = custom_std
+         self.do_clahe = do_clahe
+         self.do_crop = do_crop
+         self.min_radius_frac = min_radius_frac
+         self.max_radius_frac = max_radius_frac
+         self.allow_overflow = allow_overflow
+         self.softmax_temperature = softmax_temperature
+
+     def preprocess(
+         self,
+         images,
+         return_tensors: str = "pt",
+         device: Optional[Union[str, torch.device]] = None,
+         **kwargs,
+     ) -> BatchFeature:
+         """
+         Preprocess images for model input.
+
+         Args:
+             images: Input images in any of these formats:
+                 - torch.Tensor: (C,H,W), (B,C,H,W), or list of tensors
+                 - PIL.Image.Image: single image or list of images
+                 - numpy.ndarray: (H,W,C), (B,H,W,C), or list of arrays
+             return_tensors: Return type (only "pt" is supported)
+             device: Target device for processing (e.g., "cuda", "cpu")
+
+         Returns:
+             BatchFeature with a 'pixel_values' key containing a (B, C, size, size) tensor
+         """
+         if return_tensors != "pt":
+             raise ValueError("Only 'pt' (PyTorch) tensors are supported")
+
+         # Determine device
+         if device is not None:
+             device = torch.device(device)
+         elif isinstance(images, torch.Tensor):
+             device = images.device
+         elif isinstance(images, list) and len(images) > 0 and isinstance(images[0], torch.Tensor):
+             device = images[0].device
+         else:
+             # PIL images and numpy arrays default to CPU
+             device = torch.device('cpu')
+
+         # Standardize input
+         images = standardize_input(images, device)
+         B, C, H, W = images.shape
+
+         if self.do_crop:
+             # Estimate eye center
+             cx, cy = estimate_eye_center(images, softmax_temperature=self.softmax_temperature)
+
+             # Estimate radius
+             radius = estimate_radius(
+                 images, cx, cy,
+                 min_radius_frac=self.min_radius_frac,
+                 max_radius_frac=self.max_radius_frac,
+             )
+
+             # Compute crop box
+             x1, y1, x2, y2 = compute_crop_box(
+                 cx, cy, radius, H, W,
+                 scale_factor=self.crop_scale_factor,
+                 allow_overflow=self.allow_overflow,
+             )
+
+             # Crop and resize.
+             # Use 'zeros' padding when allow_overflow is True to fill out-of-bounds with black
+             padding_mode = 'zeros' if self.allow_overflow else 'border'
+             images = batch_crop_and_resize(images, x1, y1, x2, y2, self.size, padding_mode=padding_mode)
+         else:
+             # Just resize
+             images = resize_images(images, self.size)
+
+         # Apply CLAHE
+         if self.do_clahe:
+             images = apply_clahe_vectorized(
+                 images,
+                 grid_size=self.clahe_grid_size,
+                 clip_limit=self.clahe_clip_limit,
+             )
+
+         # Normalize
+         images = normalize_images(
+             images,
+             mean=self.custom_mean,
+             std=self.custom_std,
+             mode=self.normalization_mode,
+         )
+
+         return BatchFeature(data={"pixel_values": images}, tensor_type="pt")
+
+     def __call__(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor]],
+         **kwargs,
+     ) -> BatchFeature:
+         """
+         Process images (alias for preprocess).
+         """
+         return self.preprocess(images, **kwargs)
+
+
+ # For AutoImageProcessor registration
+ EyeGPUImageProcessor = EyeCLAHEImageProcessor
preprocessor_config.json ADDED
@@ -0,0 +1,20 @@
+ {
+   "image_processor_type": "EyeCLAHEImageProcessor",
+   "auto_map": {
+     "AutoImageProcessor": "image_processing_eye_gpu.EyeCLAHEImageProcessor"
+   },
+   "size": 512,
+   "crop_scale_factor": 1.1,
+   "clahe_grid_size": 8,
+   "clahe_clip_limit": 2.0,
+   "normalization_mode": "imagenet",
+   "custom_mean": null,
+   "custom_std": null,
+   "do_clahe": true,
+   "do_crop": true,
+   "min_radius_frac": 0.1,
+   "max_radius_frac": 1.2,
+   "allow_overflow": true,
+   "softmax_temperature": 0.1,
+   "processor_class": "EyeCLAHEImageProcessor"
+ }