DannyJun committed on
Commit
3e59a9f
·
verified ·
1 Parent(s): a3c624a

Upload image_processing_sprvla.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. image_processing_sprvla.py +951 -0
image_processing_sprvla.py ADDED
@@ -0,0 +1,951 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Image processor class for SPRVLA"""
2
+ from typing import TYPE_CHECKING, Tuple, List, Optional, Union, Dict, Any
3
+ import numpy as np
4
+ import einops
5
+ import torch
6
+ import torchvision.transforms
7
+ from torchvision.transforms import InterpolationMode
8
+ from torchvision.transforms.functional import convert_image_dtype
9
+
10
+ from transformers.image_utils import (
11
+ OPENAI_CLIP_MEAN,
12
+ OPENAI_CLIP_STD,
13
+ ChannelDimension,
14
+ ImageInput,
15
+ is_valid_image,
16
+ valid_images,
17
+ to_numpy_array,
18
+ )
19
+ from transformers.image_transforms import convert_to_rgb, to_channel_dimension_format
20
+ from transformers.processing_utils import ImagesKwargs
21
+ from transformers.image_processing_utils import BaseImageProcessor
22
+ from transformers.utils import logging
23
+ from transformers.feature_extraction_utils import BatchFeature
24
+ from transformers.utils import TensorType, logging
25
+
26
+
27
+ if TYPE_CHECKING:
28
+ from transformers.utils import TensorType, logging
29
+
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+
34
def is_multi_image(image: Union[ImageInput, List[ImageInput]]) -> bool:
    """Tell whether `image` is a group of images rather than a single image.

    Returns:
        True when `image` is a list or a tuple, False for anything else.
    """
    if isinstance(image, list):
        return True
    return isinstance(image, tuple)
36
+
37
+
38
def make_batched_images(images) -> List[ImageInput]:
    """
    Accepts images in list or nested list format.

    Args:
        images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
            The input image.

    Returns:
        list: A list of images or a list of lists of images. A list (or nested list)
        input is returned unchanged; a single valid image is wrapped in a one-element list.

    Raises:
        ValueError: If `images` is empty, contains an empty leading group, or is not
            recognised as an image / list of images.
    """
    if isinstance(images, (list, tuple)):
        # Empty containers used to crash with IndexError on `images[0]`; report them
        # with the same ValueError as every other unusable input.
        if not images:
            raise ValueError(f"Could not make batched images from {images}")
        first = images[0]
        if isinstance(first, (list, tuple)):
            if not first:
                raise ValueError(f"Could not make batched images from {images}")
            if is_valid_image(first[0]):
                return images
        if is_valid_image(first):
            return images

    if is_valid_image(images):
        return [images]

    raise ValueError(f"Could not make batched images from {images}")
59
+
60
+
61
def normalize_image(image: np.ndarray, normalize_mode: str) -> np.ndarray:
    """Normalize an image with the statistics selected by `normalize_mode`.

    The input array is never modified; a new array is always returned. (Previously the
    "openai" and "dino" modes normalized the caller's array in place while "siglip"
    returned a fresh array.)

    Args:
        image: Array whose last axis holds the 3 colour channels (the per-channel
            statistics are broadcast as `[None, None, :]`).
        normalize_mode: One of "openai" (CLIP mean/std), "siglip" (maps x to 2*x - 1)
            or "dino" (mean [0.485, 0.456, 0.406], std [0.229, 0.224, 0.225]).

    Returns:
        The normalized image.

    Raises:
        NotImplementedError: If `normalize_mode` is not one of the supported modes.
    """
    if normalize_mode == "siglip":
        return np.asarray(-1.0, dtype=np.float32) + image * np.asarray(2.0, dtype=np.float32)

    if normalize_mode == "openai":
        mean, std = OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
    elif normalize_mode == "dino":
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    else:
        raise NotImplementedError(normalize_mode)

    # Subtracting out-of-place allocates the result; the division can then reuse it.
    normalized = image - np.array(mean, dtype=np.float32)[None, None, :]
    normalized /= np.array(std, dtype=np.float32)[None, None, :]
    return normalized
73
+
74
+
75
+ # Helper to ensure output_size is a 2-tuple of built-in Python ints
76
+ def _ensure_pyint_size2(size):
77
+ """
78
+ Ensure `size` is a 2-tuple of built-in Python ints.
79
+ Accepts int, list/tuple, or numpy array of length 1 or 2.
80
+ """
81
+ import numpy as np
82
+ # If it's an array-like, normalize to length-2 tuple
83
+ if isinstance(size, (list, tuple, np.ndarray)):
84
+ if len(size) == 2:
85
+ return (int(size[0]), int(size[1]))
86
+ elif len(size) == 1:
87
+ s = int(size[0])
88
+ return (s, s)
89
+ else:
90
+ # Fallback: try to interpret as square size using first element
91
+ s = int(size[0])
92
+ return (s, s)
93
+ # Scalar → square size
94
+ s = int(size)
95
+ return (s, s)
96
+
97
+
98
def resize_and_pad(
    image,
    desired_output_size,
    resize_method="torch-bilinear",
    pad_value=0,
):
    """Resize an image while padding to preserve its aspect ratio.

    The image is scaled by the largest factor that still fits inside
    `desired_output_size`, then padded evenly on both sides with `pad_value`.

    Returns:
        A `(padded_image, image_mask)` pair; the boolean mask is True on resized
        image pixels and False on padding.

    Raises:
        NotImplementedError: If `resize_method` is not "torch-bilinear".
    """
    target_h, target_w = _ensure_pyint_size2(desired_output_size)
    src_h, src_w = image.shape[:2]

    # The scale is computed in float32 because the training code did so, and it
    # (very rarely) affects the results after rounding.
    src_h_f32 = np.array(src_h, np.float32)
    src_w_f32 = np.array(src_w, np.float32)
    scale_y = np.array(target_h, np.float32) / src_h_f32
    scale_x = np.array(target_w, np.float32) / src_w_f32
    scale = min(scale_x, scale_y)
    new_h = int(src_h_f32 * scale)
    new_w = int(src_w_f32 * scale)

    if resize_method not in ["torch-bilinear"]:
        raise NotImplementedError(resize_method)

    chw = torch.permute(torch.from_numpy(image), [2, 0, 1])
    chw = convert_image_dtype(chw)  # resize in float32 to match the training code
    resizer = torchvision.transforms.Resize(
        [new_h, new_w], InterpolationMode.BILINEAR, antialias=True
    )
    chw = torch.clip(resizer(chw), 0.0, 1.0)
    resized = torch.permute(chw, [1, 2, 0]).numpy()

    pad_top = (target_h - new_h) // 2
    pad_left = (target_w - new_w) // 2
    spatial_padding = [
        [pad_top, target_h - new_h - pad_top],
        [pad_left, target_w - new_w - pad_left],
    ]
    image_mask = np.pad(np.ones_like(resized[:, :, 0], dtype=bool), spatial_padding)
    padded = np.pad(resized, spatial_padding + [[0, 0]], constant_values=pad_value)
    return padded, image_mask
137
+
138
+
139
def metaclip_resize(image, desired_output_size):
    """Bicubic-resize an image to exactly `desired_output_size`, without padding.

    Float inputs are clipped to [0, 1]; uint8 inputs are clipped to [0, 255] and
    rescaled to [0, 1].

    Returns:
        A `(resized, image_mask)` pair; the mask is all True since nothing is padded.
    """
    target_size = _ensure_pyint_size2(desired_output_size)
    chw = torch.permute(torch.from_numpy(image), [2, 0, 1])
    is_float = torch.is_floating_point(chw)
    if not is_float:
        assert chw.dtype == torch.uint8, "Expected float images or uint8 images, but got {}".format(chw.dtype)

    bicubic = torchvision.transforms.Resize(
        target_size, InterpolationMode.BICUBIC, antialias=True)
    chw = bicubic(chw)
    if is_float:
        chw = torch.clip(chw, 0.0, 1.0)
    else:
        chw = torch.clip(chw.to(torch.float32), 0, 255)
        chw = chw / 255.0

    resized = torch.permute(chw, [1, 2, 0]).numpy()
    return resized, np.ones_like(resized[:, :, 0], dtype=np.bool_)
156
+
157
+
158
def siglip_resize_and_pad(
    image: np.ndarray,
    desired_output_size: Tuple[int, int],
) -> Tuple[np.ndarray, np.ndarray]:
    """Bilinear-resize (no antialiasing) a single image to `desired_output_size`.

    Despite the name, no padding is added. The result is clipped to the input's
    value range, cast back to the input dtype, then converted to float32 in [0, 1].

    Returns:
        A `(resized, image_mask)` pair; the mask is all True.
    """
    target_size = _ensure_pyint_size2(desired_output_size)
    chw = torch.permute(torch.from_numpy(image), [2, 0, 1])
    source_dtype = chw.dtype
    is_float = torch.is_floating_point(chw)
    if not is_float:
        assert chw.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(chw.dtype)

    bilinear = torchvision.transforms.Resize(
        target_size,
        InterpolationMode.BILINEAR,
        antialias=False,
    )
    out = bilinear(chw)
    if is_float:
        range_min, range_max = 0.0, 1.0
        out = torch.clip(out, 0.0, 1.0).to(source_dtype)
    else:
        range_min, range_max = 0.0, 255.0
        out = torch.clip(out, 0, 255).to(source_dtype)

    # Rescale from the input's value range to [0, 1] in float32.
    out = out.to(torch.float32)
    out = (out - range_min) / (range_max - range_min)

    resized = torch.permute(out, [1, 2, 0]).numpy()
    return resized, np.ones_like(resized[:, :, 0], dtype=np.bool_)
193
+
194
+
195
def dino_resize_and_pad(
    image: np.ndarray,
    desired_output_size: Tuple[int, int],
) -> Tuple[np.ndarray, np.ndarray]:
    """Bicubic-resize (antialiased) an image to `desired_output_size`.

    Despite the name, no padding is added. Float inputs are clipped to [0, 1];
    uint8 inputs are clipped to [0, 255] and divided by 255. Output is float32.

    Returns:
        A `(resized, image_mask)` pair; the mask is all True.
    """
    target_size = _ensure_pyint_size2(desired_output_size)
    chw = torch.permute(torch.from_numpy(image), [2, 0, 1])
    is_float = torch.is_floating_point(chw)
    if not is_float:
        assert chw.dtype == torch.uint8, "DINOv2 expects float images or uint8 images, but got {}".format(chw.dtype)

    bicubic = torchvision.transforms.Resize(
        target_size,
        InterpolationMode.BICUBIC,
        antialias=True,
    )
    out = bicubic(chw)
    if is_float:
        out = torch.clip(out, 0.0, 1.0).to(torch.float32)
    else:
        out = torch.clip(out, 0, 255).to(torch.float32)
        out = out / 255.0

    resized = torch.permute(out, [1, 2, 0]).numpy()
    return resized, np.ones_like(resized[:, :, 0], dtype=np.bool_)
223
+
224
+
225
def resize_image(
    image: np.ndarray,
    resize_mode: str,
    output_size: Tuple[int, int],
    pad_value: float,
) -> Tuple[np.ndarray, np.ndarray]:
    """Resize `image` to `output_size` using the strategy named by `resize_mode`.

    "siglip", "dino" and "metaclip" dispatch to their dedicated resizers (which
    ignore `pad_value`). Any other mode goes through `resize_and_pad`, with
    "default" translated to the "torch-bilinear" method.

    Returns:
        A `(resized_image, image_mask)` pair.
    """
    dedicated_resizers = {
        "siglip": siglip_resize_and_pad,
        "dino": dino_resize_and_pad,
        "metaclip": metaclip_resize,
    }
    resizer = dedicated_resizers.get(resize_mode)
    if resizer is not None:
        return resizer(image, output_size)

    method = resize_mode
    if resize_mode == "default":
        method = "torch-bilinear"
    return resize_and_pad(image, output_size, resize_method=method, pad_value=pad_value)
242
+
243
+
244
def select_tiling(h, w, patch_size, max_num_crops):
    """Choose how to tile an image of size [h, w] with up to `max_num_crops` crops.

    Every tiling `(rows, cols)` with `rows * cols <= max_num_crops` is a candidate,
    covering `(rows * patch_size, cols * patch_size)` pixels.

    Args:
        h: Image height in pixels (may be 0, see below).
        w: Image width in pixels (may be 0, see below).
        patch_size: Side length in pixels covered by one crop.
        max_num_crops: Maximum number of crops; must be at least 1.

    Returns:
        An int32 array `[rows, cols]`. If every candidate would require downscaling,
        the one needing the least downscaling is chosen; otherwise the one needing
        the least upscaling. Ties go to the smaller tiling.

    Raises:
        ValueError: If `max_num_crops` is smaller than 1.
    """
    if max_num_crops < 1:
        raise ValueError(f"max_num_crops must be at least 1, got {max_num_crops}")

    tilings = [
        (rows, cols)
        for rows in range(1, max_num_crops + 1)
        for cols in range(1, max_num_crops + 1)
        if rows * cols <= max_num_crops
    ]
    # sort so argmin and argmax favour smaller tilings in the event of a tie
    tilings.sort(key=lambda t: (t[0] * t[1], t[0]))
    candidate_tilings = np.array(tilings, dtype=np.int32)  # [n_resolutions, 2]
    candidate_resolutions = candidate_tilings * patch_size  # [n_resolutions, 2]

    # How much we would need to scale the image to fit exactly in each tiling
    original_size = np.array([h, w], dtype=np.float32)  # [2]

    # The original size can be zero in rare cases if the image is smaller than the margin.
    # In those cases letting the scale become infinite means the tiling is based on the
    # other side, or falls back to the smallest tiling.
    with np.errstate(divide="ignore"):
        required_scale_d = candidate_resolutions.astype(np.float32) / original_size
    required_scale = np.min(required_scale_d, axis=-1, keepdims=True)  # [n_resolutions, 1]

    if np.all(required_scale < 1):
        # We are forced to downscale, so try to minimize the amount of downscaling
        ix = np.argmax(required_scale)
    else:
        # Pick the resolution that requires the least upscaling so that it most closely
        # fits the image; tilings that would downscale are pushed out of contention.
        required_scale = np.where(required_scale < 1.0, 10e9, required_scale)
        ix = np.argmin(required_scale)
    return candidate_tilings[ix]
275
+
276
+
277
def build_resized_image(
    image: np.ndarray,
    resize_mode: str,
    normalized_mode: str,
    base_image_input_size: List[int],
    pad_value: float,
    image_patch_size: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Resize and normalize the whole image into a single crop.

    Returns:
        A `(crops, masks, patch_idx)` triple: the normalized image and its mask,
        each with a leading crop axis added when the resized image is 3-D, and a
        `[patches_h, patches_w]` grid numbering the crop's patches in row-major order.
    """
    crop, crop_mask = resize_image(image, resize_mode, base_image_input_size, pad_value)
    crop = normalize_image(crop, normalized_mode)
    if len(crop.shape) == 3:
        # Add the leading "crop" axis so the result batches like multi-crop output.
        crop = np.expand_dims(crop, 0)
        crop_mask = np.expand_dims(crop_mask, 0)

    patches_h = base_image_input_size[0] // image_patch_size
    patches_w = base_image_input_size[1] // image_patch_size
    patch_grid = np.arange(patches_w * patches_h).reshape([patches_h, patches_w])
    return crop, crop_mask, patch_grid
296
+
297
+
298
def build_overlapping_crops(
    image: np.ndarray,
    resize_mode: str,
    normalize_mode: str,
    max_crops: int,
    overlap_margins: List[int],
    base_image_input_size: List[int],
    pad_value: float,
    image_patch_size: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Decompose an image into a set of overlapping crops

    Args:
        image: Input image; its first two axes are height and width.
        resize_mode: Resize strategy forwarded to `resize_image`.
        normalize_mode: Normalization forwarded to `normalize_image`.
        max_crops: Maximum number of crops to produce.
        overlap_margins: `(left, right)` overlap margins, in patches.
        base_image_input_size: `(height, width)` of one crop in pixels; must be square.
        pad_value: Padding value forwarded to `resize_image`.
        image_patch_size: Side length of one patch in pixels.

    :return crop_arr: [n_crops, h, w, 3] The crops
    :return mask_arr: [n_crops, h, w] The padding masks
    :return patch_idx: [overlap_patch_h, overlap_patch_w] For each patch in the resized image
                       the crops were extracted from, what patch in `crop_arr` it corresponds to

    Raises:
        ValueError: If `base_image_input_size` is not square.
    """
    if base_image_input_size[0] != base_image_input_size[1]:
        raise ValueError(
            f"base_image_input_size must be square, got {base_image_input_size}"
        )
    crop_size = base_image_input_size[0]
    original_image_h, original_image_w = image.shape[:2]

    left_margin, right_margin = overlap_margins
    total_margin_pixels = image_patch_size * (right_margin + left_margin)  # pixels removed per dim
    crop_patches = crop_size // image_patch_size  # patches per crop dim
    crop_window_patches = crop_patches - (right_margin + left_margin)  # usable patches
    crop_window_size = crop_window_patches * image_patch_size
    crop_patch_h = crop_patch_w = crop_patches  # crops are square

    # Decide how to tile the image, to account for the overlap margins we compute the tiling
    # as if we had an image without the margins and were using a crop size without the margins
    tiling = select_tiling(
        original_image_h - total_margin_pixels,
        original_image_w - total_margin_pixels,
        crop_window_size,
        max_crops,
    )

    src, img_mask = resize_image(
        image,
        resize_mode,
        [tiling[0]*crop_window_size+total_margin_pixels, tiling[1]*crop_window_size+total_margin_pixels],
        pad_value,
    )
    src = normalize_image(src, normalize_mode)

    # Now we have to split the image into crops, and track what patches came from
    # where in `patch_idx_arr`
    n_crops = tiling[0] * tiling[1]
    patches_per_crop = crop_patch_h * crop_patch_w
    crop_arr = np.zeros([n_crops, crop_size, crop_size, 3], dtype=src.dtype)
    mask_arr = np.zeros([n_crops, crop_size, crop_size], dtype=img_mask.dtype)
    patch_idx_arr = np.zeros([n_crops, crop_patch_h, crop_patch_w], dtype=np.int32)
    local_patch_idx = np.arange(patches_per_crop).reshape(crop_patch_h, crop_patch_w)
    on_crop = 0
    for i in range(tiling[0]):
        # Slide over `src` by `crop_window_size` steps, but extract crops of size `crop_size`
        # which results in overlapping crop windows
        y0 = i*crop_window_size
        for j in range(tiling[1]):
            x0 = j*crop_window_size
            crop_arr[on_crop] = src[y0:y0+crop_size, x0:x0+crop_size]
            mask_arr[on_crop] = img_mask[y0:y0+crop_size, x0:x0+crop_size]
            patch_idx = local_patch_idx + on_crop * patches_per_crop

            # Mask out idx that are in the overlap region
            if i != 0:
                patch_idx[:left_margin, :] = -1
            if j != 0:
                patch_idx[:, :left_margin] = -1
            # `[-0:]` selects the whole axis, so a zero right margin must be skipped
            # explicitly or every patch of the crop would be masked out.
            if right_margin > 0:
                if i != tiling[0]-1:
                    patch_idx[-right_margin:, :] = -1
                if j != tiling[1]-1:
                    patch_idx[:, -right_margin:] = -1
            patch_idx_arr[on_crop] = patch_idx
            on_crop += 1

    # `patch_idx_arr` is ordered crop-by-crop, here we transpose `patch_idx_arr`
    # so it is ordered left-to-right order
    patch_idx_arr = np.reshape(
        patch_idx_arr,
        [tiling[0], tiling[1], crop_patch_h, crop_patch_w]
    )
    patch_idx_arr = np.transpose(patch_idx_arr, [0, 2, 1, 3])
    patch_idx_arr = np.reshape(patch_idx_arr, [-1])

    # Now get the parts not in the overlap region, so it should map each patch in `src`
    # to the correct patch it should come from in `crop_arr`
    patch_idx_arr = patch_idx_arr[patch_idx_arr >= 0].reshape(
        src.shape[0]//image_patch_size,
        src.shape[1]//image_patch_size,
    )
    return crop_arr, mask_arr, patch_idx_arr
393
+
394
+
395
def batch_pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray:
    """Reshape images of [n_images, h, w, 3] -> [n_images, n_patches, pixels_per_patch]

    3-D input `[n_images, h, w]` (e.g. masks) is handled the same way, as if it had a
    single channel. Patches are ordered row-major over the patch grid.
    """
    if len(array.shape) == 3:
        # Treat channel-less input as single-channel so one code path serves both.
        array = array[..., None]

    n_crops, h, w, c = array.shape
    rows = h // patch_size
    cols = w // patch_size
    tiles = np.reshape(array, [n_crops, rows, patch_size, cols, patch_size, c])
    # Bring the two patch-grid axes together, ahead of the within-patch pixel axes.
    tiles = np.transpose(tiles, [0, 1, 3, 2, 4, 5])
    return np.reshape(tiles, [n_crops, rows * cols, patch_size * patch_size * c])
413
+
414
+
415
def arange_for_pooling(
    idx_arr: np.ndarray,
    pool_h: int,
    pool_w: int,
) -> np.ndarray:
    """Group a 2-D patch-index grid into `pool_h x pool_w` pooling windows.

    The grid is first padded with -1 (split between both sides, the extra row or
    column going to the end) so each dimension is a multiple of the pooling size.

    Returns:
        An array of shape `[h, w, pool_h * pool_w]` holding, for every pooling
        window, the indices it covers in row-major order (-1 marks padding).
    """
    n_rows, n_cols = idx_arr.shape[0], idx_arr.shape[1]
    h_pad = pool_h * ((n_rows + pool_h - 1) // pool_h) - n_rows
    w_pad = pool_w * ((n_cols + pool_w - 1) // pool_w) - n_cols
    padded = np.pad(
        idx_arr,
        [[h_pad // 2, (h_pad + 1) // 2], [w_pad // 2, (w_pad + 1) // 2]],
        mode='constant',
        constant_values=-1,
    )

    # Equivalent to the pattern "(h dh) (w dw) -> h w (dh dw)".
    out_h = padded.shape[0] // pool_h
    out_w = padded.shape[1] // pool_w
    windows = padded.reshape(out_h, pool_h, out_w, pool_w)
    return windows.transpose(0, 2, 1, 3).reshape(out_h, out_w, pool_h * pool_w)
426
+
427
+
428
def image_to_patches_and_grids(
    image: ImageInput,
    crop_mode: str,
    resize_mode: str,
    normalize_mode: str,
    max_crops: int,
    overlap_margins: List[int],
    base_image_input_size: List[int],
    pad_value: float,
    image_patch_size: int,
    image_pooling_w: int,
    image_pooling_h: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Convert an image into ViT patches plus the pooled grid shape(s).

    Supported `crop_mode` values are "resize" (one resized image),
    "overlap-and-resize" (overlapping crops only) and "overlap-and-resize-c2"
    (a resized global image followed by the overlapping crops).

    :return image_grids, the shape of each (low-res, high-res) image after pooling
    :return crops, the image crops to processes with the ViT
    :return mask, the padding mask for each crop
    :return pooled_patch_idx, for each patch_id tokens in `image_tokens`, the indices of the
                              patches in `crops` to pool for that token, masked with -1
    """
    if isinstance(base_image_input_size, int):
        base_image_input_size = (base_image_input_size, base_image_input_size)

    patches_per_crop = (
        (base_image_input_size[0] // image_patch_size)
        * (base_image_input_size[1] // image_patch_size)
    )

    def _pool(idx):
        # Returns the pooled grid height/width and the flattened pooling indices.
        pooled = arange_for_pooling(idx, image_pooling_h, image_pooling_w)
        grid_h, grid_w = pooled.shape[:2]
        return grid_h, grid_w, pooled.reshape([-1, image_pooling_h * image_pooling_w])

    if crop_mode == "resize":
        resized, resized_mask, resize_idx = build_resized_image(
            image,
            resize_mode,
            normalize_mode,
            base_image_input_size,
            pad_value,
            image_patch_size
        )
        grid_h, grid_w, pooling_idx = _pool(resize_idx)
        return (
            np.stack([np.array([grid_h, grid_w])], 0),
            batch_pixels_to_patches(resized, image_patch_size),
            batch_pixels_to_patches(resized_mask, image_patch_size).mean(-1),
            pooling_idx,
        )

    if crop_mode not in ["overlap-and-resize-c2", "overlap-and-resize"]:
        raise NotImplementedError(crop_mode)

    crop_arr, mask_arr, patch_idx_arr = build_overlapping_crops(
        image,
        resize_mode,
        normalize_mode,
        max_crops,
        overlap_margins,
        base_image_input_size,
        pad_value,
        image_patch_size,
    )
    grid_h, grid_w, pooling_idx = _pool(patch_idx_arr)
    crop_grids = [np.array([grid_h, grid_w])]

    if crop_mode == "overlap-and-resize":
        crop_arr = batch_pixels_to_patches(crop_arr, image_patch_size)
        mask_arr = batch_pixels_to_patches(mask_arr, image_patch_size).astype(np.float32).mean(axis=-1)
        return np.stack(crop_grids, 0), crop_arr, mask_arr, pooling_idx

    # Finally do the same for the global image
    resized, resized_mask, resize_idx = build_resized_image(
        image,
        resize_mode,
        normalize_mode,
        base_image_input_size,
        pad_value,
        image_patch_size
    )
    crop_arr = np.concatenate([resized, crop_arr], 0)
    mask_arr = np.concatenate([resized_mask, mask_arr], 0)
    grid_h, grid_w, resize_idx = _pool(resize_idx)

    # Global image goes first, so the order of patches in previous crops gets increased
    pooling_idx = np.where(pooling_idx >= 0, pooling_idx + patches_per_crop, -1)
    pooling_idx = np.concatenate([resize_idx, pooling_idx])
    all_grids = [np.array([grid_h, grid_w])] + crop_grids

    mask_arr = batch_pixels_to_patches(mask_arr, image_patch_size).astype(np.float32).mean(axis=-1)
    return (
        np.stack(all_grids, 0),
        batch_pixels_to_patches(crop_arr, image_patch_size),
        mask_arr,
        pooling_idx,
    )
535
+
536
+
537
def image_to_patches_and_tokens(
    image: ImageInput,
    crop_mode: str,
    use_col_tokens: bool,
    resize_mode: str,
    normalize_mode: str,
    max_crops: int,
    overlap_margins: List[int],
    base_image_input_size: List[int],
    pad_value: float,
    image_patch_size: int,
    image_pooling_w: int,
    image_pooling_h: int,
    image_patch_token_id: int,
    image_col_token_id: int,
    image_start_token_id: int,
    image_end_token_id: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Convert an image into ViT patches plus the token ids that represent it.

    Supported `crop_mode` values are "resize", "overlap-and-resize" and
    "overlap-and-resize-c2" (global resized image first, then the overlapping crops).
    Each image section is emitted as: start token, one patch token per pooled
    patch (rows optionally terminated by a column token), end token.

    :return image_tokens, the token IDS for this image, including special tokens
    :return crops, the image crops to processes with the ViT
    :return mask, the padding mask for each crop
    :return pooled_patch_idx, for each patch_id tokens in `image_tokens`, the indices of the
                              patches in `crops` to pool for that token, masked with -1
    """
    if isinstance(base_image_input_size, int):
        base_image_input_size = (base_image_input_size, base_image_input_size)

    patches_per_crop = (
        (base_image_input_size[0] // image_patch_size)
        * (base_image_input_size[1] // image_patch_size)
    )

    def _pool(idx):
        # Returns the pooled grid height/width and the flattened pooling indices.
        pooled = arange_for_pooling(idx, image_pooling_h, image_pooling_w)
        grid_h, grid_w = pooled.shape[:2]
        return grid_h, grid_w, pooled.reshape([-1, image_pooling_h * image_pooling_w])

    def _section_tokens(grid_h, grid_w):
        # [start], grid_h rows of grid_w patch tokens (+ optional column token), [end].
        row = np.full((grid_w,), image_patch_token_id, dtype=np.int32)
        if use_col_tokens:
            row = np.concatenate([row, [image_col_token_id]], 0)
        return [
            [image_start_token_id],
            np.tile(row, [grid_h]),
            [image_end_token_id],
        ]

    if crop_mode == "resize":
        resized, resized_mask, resize_idx = build_resized_image(
            image,
            resize_mode,
            normalize_mode,
            base_image_input_size,
            pad_value,
            image_patch_size
        )
        grid_h, grid_w, pooling_idx = _pool(resize_idx)
        tokens = _section_tokens(grid_h, grid_w)
        return (
            np.concatenate(tokens, 0),
            batch_pixels_to_patches(resized, image_patch_size),
            batch_pixels_to_patches(resized_mask, image_patch_size).mean(-1),
            pooling_idx,
        )

    if crop_mode not in ["overlap-and-resize-c2", "overlap-and-resize"]:
        raise NotImplementedError(crop_mode)

    crop_arr, mask_arr, patch_idx_arr = build_overlapping_crops(
        image,
        resize_mode,
        normalize_mode,
        max_crops,
        overlap_margins,
        base_image_input_size,
        pad_value,
        image_patch_size,
    )
    grid_h, grid_w, pooling_idx = _pool(patch_idx_arr)

    # Now build the output tokens
    crop_tokens = _section_tokens(grid_h, grid_w)

    if crop_mode == "overlap-and-resize":
        crop_arr = batch_pixels_to_patches(crop_arr, image_patch_size)
        mask_arr = batch_pixels_to_patches(mask_arr, image_patch_size).astype(np.float32).mean(axis=-1)
        return np.concatenate(crop_tokens, 0), crop_arr, mask_arr, pooling_idx

    # Finally do the same for the global image
    resized, resized_mask, resize_idx = build_resized_image(
        image,
        resize_mode,
        normalize_mode,
        base_image_input_size,
        pad_value,
        image_patch_size
    )
    crop_arr = np.concatenate([resized, crop_arr], 0)
    mask_arr = np.concatenate([resized_mask, mask_arr], 0)
    grid_h, grid_w, resize_idx = _pool(resize_idx)

    # Global image goes first, so the order of patches in previous crops gets increased
    pooling_idx = np.where(pooling_idx >= 0, pooling_idx + patches_per_crop, -1)
    pooling_idx = np.concatenate([resize_idx, pooling_idx])

    all_tokens = _section_tokens(grid_h, grid_w) + crop_tokens
    mask_arr = batch_pixels_to_patches(mask_arr, image_patch_size).astype(np.float32).mean(axis=-1)
    return (
        np.concatenate(all_tokens, 0),
        batch_pixels_to_patches(crop_arr, image_patch_size),
        mask_arr,
        pooling_idx,
    )
686
+
687
+
688
class SPRVLAImagesKwargs(ImagesKwargs, total=False):
    """Optional image keyword arguments for SPRVLA, extending `ImagesKwargs`.

    Declared with `total=False`, so every key may be omitted. The field names match
    the same-named keyword arguments of `SPRVLAImageProcessor.__init__`.
    """

    # Cropping strategy; the helpers in this file handle "resize",
    # "overlap-and-resize" and "overlap-and-resize-c2".
    crop_mode: Optional[str]
    # Resize strategy; `resize_image` special-cases "siglip", "dino", "metaclip"
    # and "default".
    resize_mode: Optional[str]
    # Normalization; `normalize_image` supports "openai", "siglip" and "dino".
    normalize_mode: Optional[str]
    # Maximum number of crops per image.
    max_crops: Optional[int]
    # NOTE(review): presumably the crop limit when several images are given at once;
    # its use is not visible in this part of the file — confirm.
    max_multi_image_crops: Optional[int]
    # (left, right) overlap margins between neighbouring crops, in patches.
    overlap_margins: Optional[List[int]]
    # (height, width) of a single crop in pixels.
    base_image_input_size: Optional[List[int]]
    # Value used to fill padded pixels.
    pad_value: Optional[float]
    # Side length of one patch in pixels.
    image_patch_size: Optional[int]
    # Pooling window width, in patches.
    image_pooling_w: Optional[int]
    # Pooling window height, in patches.
    image_pooling_h: Optional[int]
700
+
701
+
702
+ class SPRVLAImageProcessor(BaseImageProcessor):
703
+
704
+ model_input_names = ["images", "pooled_patches_idx", "image_masks"]
705
+
706
def __init__(
    self,
    crop_mode: str = "overlap-and-resize-c2",
    resize_mode: str = "siglip",
    normalize_mode: str = "siglip",
    max_crops: int = 8,
    max_multi_image_crops: int = 4,
    overlap_margins: List[int] = [4, 4],
    base_image_input_size: List[int] = (378, 378),
    pad_value: float = 0.0,
    image_patch_size: int = 14,
    image_pooling_w: int = 2,
    image_pooling_h: int = 2,
    do_convert_rgb: bool = True,
    do_pad: Optional[bool] = True,
    **kwargs,
) -> None:
    """Store the default preprocessing settings.

    Args:
        crop_mode: Cropping strategy ("resize", "overlap-and-resize" or
            "overlap-and-resize-c2").
        resize_mode: Resize strategy name.
        normalize_mode: Normalization name.
        max_crops: Maximum number of crops per image.
        max_multi_image_crops: Stored as given; where it is consumed is not
            visible in this part of the file.
        overlap_margins: `(left, right)` crop overlap margins, in patches.
        base_image_input_size: `(height, width)` of one crop in pixels.
        pad_value: Value used to fill padded pixels.
        image_patch_size: Side length of one patch in pixels.
        image_pooling_w: Pooling window width, in patches.
        image_pooling_h: Pooling window height, in patches.
        do_convert_rgb: Stored flag for converting images to RGB.
        do_pad: Stored flag for padding outputs.
        **kwargs: Forwarded to the base class constructor.
    """
    super().__init__(**kwargs)
    self.crop_mode = crop_mode
    self.resize_mode = resize_mode
    self.normalize_mode = normalize_mode
    self.max_crops = max_crops
    self.max_multi_image_crops = max_multi_image_crops
    # Store a copy: the signature default is a single list object shared by every
    # call, so keeping a reference would let one instance's mutation leak into others.
    self.overlap_margins = list(overlap_margins) if overlap_margins is not None else None
    self.base_image_input_size = base_image_input_size
    self.pad_value = pad_value
    self.image_patch_size = image_patch_size
    self.image_pooling_w = image_pooling_w
    self.image_pooling_h = image_pooling_h
    self.do_convert_rgb = do_convert_rgb
    self.do_pad = do_pad
738
+
739
def to_channel_dimension_last(
    self,
    images: List[ImageInput],
) -> List[ImageInput]:
    """
    Convert images to channel dimension last.

    Entries that are themselves lists/tuples of images are converted element-wise
    and returned as lists.
    """
    def _channels_last(entry):
        if is_multi_image(entry):
            return [to_channel_dimension_format(frame, ChannelDimension.LAST) for frame in entry]
        return to_channel_dimension_format(entry, ChannelDimension.LAST)

    return [_channels_last(entry) for entry in images]
753
+
754
def to_numpy_array(
    self,
    images: List[ImageInput],
) -> List[np.ndarray]:
    """
    Convert images to numpy array.

    Entries that are themselves lists/tuples of images are converted element-wise
    and returned as lists.
    """
    def _as_numpy(entry):
        if is_multi_image(entry):
            return [to_numpy_array(frame) for frame in entry]
        return to_numpy_array(entry)

    return [_as_numpy(entry) for entry in images]
768
+
769
def to_rgb(
    self,
    images: List[ImageInput],
) -> List[ImageInput]:
    """
    Convert images to RGB.

    Entries that are themselves lists/tuples of images are converted element-wise
    and returned as lists.
    """
    def _rgb(entry):
        if is_multi_image(entry):
            return [convert_to_rgb(frame) for frame in entry]
        return convert_to_rgb(entry)

    return [_rgb(entry) for entry in images]
783
+
784
def pad_arrays(self, arrays: List[np.ndarray], pad_value: float = -1) -> np.ndarray:
    """Stack arrays of differing first-axis length into one padded batch.

    The result has shape `[len(arrays), longest, *arrays[0].shape[1:]]` and the dtype
    of the first array; positions past an array's own length hold `pad_value`.
    """
    longest = max(arr.shape[0] for arr in arrays)
    template = arrays[0]
    batch = np.full(
        [len(arrays), longest, *template.shape[1:]], pad_value, dtype=template.dtype
    )
    # Each `slot` is a view into `batch`, so writing to it fills the batch in place.
    for slot, arr in zip(batch, arrays):
        slot[:len(arr)] = arr[:longest]
    return batch
792
+
793
def pad_for_batching(self, data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Pad the data for batching.

    Applies `pad_arrays` (with its default pad value) to the "images",
    "pooled_patches_idx", "image_masks" and "image_grids" entries of `data` and
    returns a new dict holding only those four keys.
    """
    batched_keys = ("images", "pooled_patches_idx", "image_masks", "image_grids")
    return {key: self.pad_arrays(data[key]) for key in batched_keys}
808
+
809
def preprocess(
    self,
    images: Union[ImageInput, List[ImageInput]],
    crop_mode: Optional[str] = None,
    resize_mode: Optional[str] = None,
    normalize_mode: Optional[str] = None,
    max_crops: Optional[int] = None,
    max_multi_image_crops: Optional[int] = None,
    overlap_margins: Optional[List[int]] = None,
    base_image_input_size: Optional[List[int]] = None,
    pad_value: Optional[float] = None,
    image_patch_size: Optional[int] = None,
    image_pooling_w: Optional[int] = None,
    image_pooling_h: Optional[int] = None,
    do_convert_rgb: Optional[bool] = None,
    do_pad: Optional[bool] = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    **kwargs,
) -> BatchFeature:
    """
    Preprocess a batch of images (or multi-image samples) for the model.

    Every option left as `None` falls back to the attribute of the same name on
    the processor. Explicit falsy values (`False`, `0`, `0.0`) are honored and
    are NOT replaced by the default.

    Args:
        images: The image(s) to preprocess. The input is first normalized with
            `make_batched_images`; each resulting batch entry is either a single
            image or a multi-image sample (as reported by `is_multi_image`).
        crop_mode: The crop mode to use. If None, use the default crop mode.
        resize_mode: The resize mode to use. If None, use the default resize mode.
        normalize_mode: The normalization mode to use. If None, use the default normalization mode.
        max_crops: The maximum number of crops for single-image entries. If None, use the default value.
        max_multi_image_crops: The maximum number of crops per image for multi-image entries.
            If None, use the default value.
        overlap_margins: The overlap margins to use. If None, use the default values.
        base_image_input_size: The base image input size to use. If None, use the default size.
        pad_value: The padding value to use. If None, use the default value.
        image_patch_size: The size of the image patches. If None, use the default size.
        image_pooling_w: The width of the image pooling. If None, use the default width.
        image_pooling_h: The height of the image pooling. If None, use the default height.
        do_convert_rgb: Whether to convert the images to RGB. If None, use the default value.
        do_pad: Whether to pad the per-sample features into batch arrays with
            `pad_for_batching`. If None, use the default value.
        return_tensors: Forwarded to `BatchFeature` as `tensor_type`.
        **kwargs: Accepted for interface compatibility; not used by this method.

    Returns:
        A `BatchFeature` with the keys:
            - `images`: The preprocessed crops
            - `pooled_patches_idx`: The pooling indices
            - `image_masks`: The padding masks
            - `image_grids`: The image grids
        Each value holds one entry per batch sample (a list of arrays), or a
        single padded array per key when `do_pad` is true.

    Raises:
        ValueError: If `valid_images` rejects the batched input.
    """
    images = make_batched_images(images)

    if not valid_images(images):
        raise ValueError("Invalid image input")

    def pick(override, default):
        # `override or default` would discard legitimate falsy overrides such as
        # `do_pad=False` or `pad_value=0`; only `None` means "use the default".
        return default if override is None else override

    crop_mode = pick(crop_mode, self.crop_mode)
    normalize_mode = pick(normalize_mode, self.normalize_mode)
    resize_mode = pick(resize_mode, self.resize_mode)
    max_crops = pick(max_crops, self.max_crops)
    max_multi_image_crops = pick(max_multi_image_crops, self.max_multi_image_crops)
    overlap_margins = pick(overlap_margins, self.overlap_margins)
    base_image_input_size = pick(base_image_input_size, self.base_image_input_size)
    pad_value = pick(pad_value, self.pad_value)
    image_patch_size = pick(image_patch_size, self.image_patch_size)
    image_pooling_w = pick(image_pooling_w, self.image_pooling_w)
    image_pooling_h = pick(image_pooling_h, self.image_pooling_h)
    do_convert_rgb = pick(do_convert_rgb, self.do_convert_rgb)
    do_pad = pick(do_pad, self.do_pad)

    if do_convert_rgb:
        images = self.to_rgb(images)

    # All transformations expect numpy arrays.
    images = self.to_numpy_array(images)

    # All transformations expect channel dimension last.
    images = self.to_channel_dimension_last(images)

    def patchify(img, crop_budget):
        # Single call site for the patch extraction; only the image and the
        # crop budget differ between the single- and multi-image paths.
        return image_to_patches_and_grids(
            img,
            crop_mode,
            resize_mode,
            normalize_mode,
            crop_budget,
            overlap_margins,
            base_image_input_size,
            pad_value,
            image_patch_size,
            image_pooling_w,
            image_pooling_h,
        )

    batch_image_grids = []
    batch_crops = []
    batch_crop_masks = []
    batch_pooled_patches_idx = []

    for image in images:
        if not is_multi_image(image):
            image_grid, crops, img_mask, pooled_idx = patchify(image, max_crops)
            batch_image_grids.append(image_grid)
            batch_crops.append(crops)
            batch_crop_masks.append(img_mask)
            batch_pooled_patches_idx.append(pooled_idx)
            continue

        # Multi-image sample: process each image separately, then concatenate
        # the per-image results along the first axis.
        all_image_grids = []
        all_crops = []
        all_crop_masks = []
        pooled_patches_idx = []
        # Running total of prod(crops.shape[:2]) over the images already
        # processed in this sample; shifts the pooling indices of later images
        # so they address the concatenated crops.
        # NOTE(review): the offset is added to every entry of `pooled_idx`,
        # including any negative sentinel entries it may contain — confirm
        # against `image_to_patches_and_grids`.
        patch_offset = 0
        for img in image:
            image_grid, crops, img_mask, pooled_idx = patchify(img, max_multi_image_crops)
            pooled_patches_idx.append(pooled_idx + patch_offset)
            patch_offset += np.prod(crops.shape[:2])
            all_crops.append(crops)
            all_crop_masks.append(img_mask)
            all_image_grids.append(image_grid)

        batch_image_grids.append(np.concatenate(all_image_grids, 0))
        batch_crops.append(np.concatenate(all_crops, 0))
        batch_crop_masks.append(np.concatenate(all_crop_masks, 0))
        batch_pooled_patches_idx.append(np.concatenate(pooled_patches_idx, 0))

    data = dict(
        images=batch_crops,
        pooled_patches_idx=batch_pooled_patches_idx,
        image_masks=batch_crop_masks,
        image_grids=batch_image_grids,
    )

    if do_pad:
        data = self.pad_for_batching(data)

    return BatchFeature(data, tensor_type=return_tensors)
949
+
950
+
951
# Module-level call: runs once when this module is executed/imported.
# NOTE(review): `register_for_auto_class` is not defined in this chunk; presumably
# it is inherited from the transformers image-processor base class and registers
# this custom processor for Auto* loading — confirm against the class definition.
+ SPRVLAImageProcessor.register_for_auto_class()