DannyJun committed on
Commit
2a747c0
·
verified ·
1 Parent(s): d13c52b

Delete image_processing_molmoact.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. image_processing_molmoact.py +0 -951
image_processing_molmoact.py DELETED
@@ -1,951 +0,0 @@
1
- """Image processor class for MolmoAct"""
2
- from typing import TYPE_CHECKING, Tuple, List, Optional, Union, Dict, Any
3
- import numpy as np
4
- import einops
5
- import torch
6
- import torchvision.transforms
7
- from torchvision.transforms import InterpolationMode
8
- from torchvision.transforms.functional import convert_image_dtype
9
-
10
- from transformers.image_utils import (
11
- OPENAI_CLIP_MEAN,
12
- OPENAI_CLIP_STD,
13
- ChannelDimension,
14
- ImageInput,
15
- is_valid_image,
16
- valid_images,
17
- to_numpy_array,
18
- )
19
- from transformers.image_transforms import convert_to_rgb, to_channel_dimension_format
20
- from transformers.processing_utils import ImagesKwargs
21
- from transformers.image_processing_utils import BaseImageProcessor
22
- from transformers.utils import logging
23
- from transformers.feature_extraction_utils import BatchFeature
24
- from transformers.utils import TensorType, logging
25
-
26
-
27
- if TYPE_CHECKING:
28
- from transformers.utils import TensorType, logging
29
-
30
-
31
- logger = logging.get_logger(__name__)
32
-
33
-
34
def is_multi_image(image: "Union[ImageInput, List[ImageInput]]") -> bool:
    """Return True when `image` is a sequence of images rather than a single image."""
    return isinstance(image, (list, tuple))
36
-
37
-
38
def make_batched_images(images) -> "List[ImageInput]":
    """
    Accepts images in list or nested list format.

    Args:
        images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
            The input image.

    Returns:
        list: A list of images or a list of lists of images.

    Raises:
        ValueError: if `images` is empty or not a recognized image structure.
    """
    # Guard empty (and empty-nested) sequences up front: the checks below index
    # into `images`, which would otherwise raise an uninformative IndexError.
    if isinstance(images, (list, tuple)):
        if not images or (isinstance(images[0], (list, tuple)) and not images[0]):
            raise ValueError(f"Could not make batched images from {images}")

    if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]):
        # Already a batch of multi-image samples.
        return images

    elif isinstance(images, (list, tuple)) and is_valid_image(images[0]):
        # Already a flat batch of images.
        return images

    elif is_valid_image(images):
        # Single image: wrap into a batch of one.
        return [images]

    raise ValueError(f"Could not make batched images from {images}")
59
-
60
-
61
def normalize_image(image: np.ndarray, normalize_mode: str) -> np.ndarray:
    """Normalize a float HWC image according to `normalize_mode`.

    Args:
        image: float [h, w, 3] array with values in [0, 1].
        normalize_mode: one of "openai", "siglip", "dino".

    Returns:
        A new normalized array; the input is never modified in place
        (the previous `-=` / `/=` forms mutated the caller's array).

    Raises:
        NotImplementedError: for an unknown `normalize_mode`.
    """
    if normalize_mode == "openai":
        mean = np.array(OPENAI_CLIP_MEAN, dtype=np.float32)[None, None, :]
        std = np.array(OPENAI_CLIP_STD, dtype=np.float32)[None, None, :]
        image = (image - mean) / std
    elif normalize_mode == "siglip":
        # Map [0, 1] -> [-1, 1]
        image = np.asarray(-1.0, dtype=np.float32) + image * np.asarray(2.0, dtype=np.float32)
    elif normalize_mode == "dino":
        mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)[None, None, :]
        std = np.array([0.229, 0.224, 0.225], dtype=np.float32)[None, None, :]
        image = (image - mean) / std
    else:
        raise NotImplementedError(normalize_mode)
    return image
73
-
74
-
75
- # Helper to ensure output_size is a 2-tuple of built-in Python ints
76
- def _ensure_pyint_size2(size):
77
- """
78
- Ensure `size` is a 2-tuple of built-in Python ints.
79
- Accepts int, list/tuple, or numpy array of length 1 or 2.
80
- """
81
- import numpy as np
82
- # If it's an array-like, normalize to length-2 tuple
83
- if isinstance(size, (list, tuple, np.ndarray)):
84
- if len(size) == 2:
85
- return (int(size[0]), int(size[1]))
86
- elif len(size) == 1:
87
- s = int(size[0])
88
- return (s, s)
89
- else:
90
- # Fallback: try to interpret as square size using first element
91
- s = int(size[0])
92
- return (s, s)
93
- # Scalar → square size
94
- s = int(size)
95
- return (s, s)
96
-
97
-
98
def resize_and_pad(
    image,
    desired_output_size,
    resize_method="torch-bilinear",
    pad_value=0,
):
    """Resize an image while padding to preserve its aspect ratio.

    The image is scaled so that it fits inside `desired_output_size`, then
    centered on a constant-valued canvas of exactly that size.

    Args:
        image: [h, w, 3] numpy image (HWC); uint8 or float — converted to float
            for resizing via `convert_image_dtype`.
        desired_output_size: target (height, width); scalars become squares.
        resize_method: only "torch-bilinear" is implemented.
        pad_value: constant fill for the padded border.

    Returns:
        (padded_image, mask): the [H, W, 3] padded image and a [H, W] boolean
        mask that is True over the real (non-padded) pixels.
    """
    desired_output_size = _ensure_pyint_size2(desired_output_size)
    desired_height, desired_width = desired_output_size
    height, width = image.shape[:2]

    # Cast into float32 since the training code did this in float32 and it (very rarely) affects
    # the results after rounding.
    image_scale_y = np.array(desired_height, np.float32) / np.array(height, np.float32)
    image_scale_x = np.array(desired_width, np.float32) / np.array(width, np.float32)
    # Uniform scale: the limiting dimension decides, preserving aspect ratio.
    image_scale = min(image_scale_x, image_scale_y)
    scaled_height = int(np.array(height, np.float32) * image_scale)
    scaled_width = int(np.array(width, np.float32) * image_scale)

    if resize_method in ["torch-bilinear"]:
        # HWC numpy -> CHW tensor for torchvision.
        image = torch.permute(torch.from_numpy(image), [2, 0, 1])
        image = convert_image_dtype(image)  # resize in float32 to match the training code
        mode = InterpolationMode.BILINEAR
        image = torchvision.transforms.Resize([scaled_height, scaled_width], mode, antialias=True)(image)
        image = torch.clip(image, 0.0, 1.0)
        image = torch.permute(image, [1, 2, 0]).numpy()
    else:
        raise NotImplementedError(resize_method)

    # Center the scaled image inside the target canvas.
    top_pad = (desired_height - scaled_height) // 2
    left_pad = (desired_width - scaled_width) // 2
    padding = [
        [top_pad, desired_height - scaled_height - top_pad],
        [left_pad, desired_width - scaled_width - left_pad],
        [0, 0]
    ]
    # Mask of real pixels: ones where the scaled image lands, False in the border.
    image_mask = np.pad(np.ones_like(image[:, :, 0], dtype=bool), padding[:2])
    image = np.pad(image, padding, constant_values=pad_value)
    return image, image_mask
137
-
138
-
139
def metaclip_resize(image, desired_output_size):
    """Bicubic-resize an HWC numpy image to `desired_output_size` (MetaCLIP mode).

    Returns the resized float image with values in [0, 1] and an all-ones
    validity mask — this resize mode never pads.
    """
    target = _ensure_pyint_size2(desired_output_size)
    tensor = torch.from_numpy(image).permute(2, 0, 1)  # HWC -> CHW
    resize = torchvision.transforms.Resize(target, InterpolationMode.BICUBIC, antialias=True)
    if torch.is_floating_point(tensor):
        tensor = torch.clip(resize(tensor), 0.0, 1.0)
    else:
        assert tensor.dtype == torch.uint8, "Expected float images or uint8 images, but got {}".format(tensor.dtype)
        tensor = resize(tensor).to(torch.float32)
        tensor = torch.clip(tensor, 0, 255) / 255.0
    resized = tensor.permute(1, 2, 0).numpy()  # CHW -> HWC
    image_mask = np.ones_like(resized[:, :, 0], dtype=np.bool_)
    return resized, image_mask
156
-
157
-
158
def siglip_resize_and_pad(
    image: np.ndarray,
    desired_output_size: Tuple[int, int],
) -> Tuple[np.ndarray, np.ndarray]:
    """Bilinear-resize an HWC numpy image to `desired_output_size` (SigLIP mode),
    rescaling pixel values to [0, 1].

    Despite the name, this mode never pads, so the returned mask is all ones.
    """
    target = _ensure_pyint_size2(desired_output_size)
    # by default, image is a single image
    chw = torch.from_numpy(image).permute(2, 0, 1)
    dtype = chw.dtype
    resize = torchvision.transforms.Resize(target, InterpolationMode.BILINEAR, antialias=False)
    if torch.is_floating_point(chw):
        in_min, in_max = 0.0, 1.0
        out = torch.clip(resize(chw), 0.0, 1.0).to(dtype)
    else:
        assert chw.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(chw.dtype)
        in_min, in_max = 0.0, 255.0
        out = torch.clip(resize(chw), 0, 255).to(dtype)

    # Rescale from the input range to [0, 1] in float32.
    out = (out.to(torch.float32) - in_min) / (in_max - in_min)

    resized = out.permute(1, 2, 0).numpy()
    image_mask = np.ones_like(resized[:, :, 0], dtype=np.bool_)

    return resized, image_mask
193
-
194
-
195
def dino_resize_and_pad(
    image: np.ndarray,
    desired_output_size: Tuple[int, int],
) -> Tuple[np.ndarray, np.ndarray]:
    """Bicubic-resize an HWC numpy image to `desired_output_size` (DINOv2 mode),
    rescaling uint8 inputs to [0, 1].

    Despite the name, this mode never pads, so the returned mask is all ones.
    """
    target = _ensure_pyint_size2(desired_output_size)
    chw = torch.from_numpy(image).permute(2, 0, 1)
    resize = torchvision.transforms.Resize(target, InterpolationMode.BICUBIC, antialias=True)
    if torch.is_floating_point(chw):
        out = torch.clip(resize(chw), 0.0, 1.0).to(torch.float32)
    else:
        assert chw.dtype == torch.uint8, "DINOv2 expects float images or uint8 images, but got {}".format(chw.dtype)
        out = torch.clip(resize(chw), 0, 255).to(torch.float32)
        out = out / 255.0

    resized = out.permute(1, 2, 0).numpy()
    image_mask = np.ones_like(resized[:, :, 0], dtype=np.bool_)

    return resized, image_mask
223
-
224
-
225
def resize_image(
    image: np.ndarray,
    resize_mode: str,
    output_size: Tuple[int, int],
    pad_value: float,
) -> Tuple[np.ndarray, np.ndarray]:
    """Dispatch to the resize routine selected by `resize_mode`.

    "siglip"/"dino"/"metaclip" use their dedicated resizers (which ignore
    `pad_value`); anything else goes through `resize_and_pad`, with "default"
    mapped to the "torch-bilinear" method.
    """
    dedicated = {
        "siglip": siglip_resize_and_pad,
        "dino": dino_resize_and_pad,
        "metaclip": metaclip_resize,
    }
    handler = dedicated.get(resize_mode)
    if handler is not None:
        return handler(image, output_size)
    method = "torch-bilinear" if resize_mode == "default" else resize_mode
    return resize_and_pad(image, output_size, resize_method=method, pad_value=pad_value)
242
-
243
-
244
def select_tiling(h, w, patch_size, max_num_crops):
    """Choose how to tile an image of size [h, w] into up to `max_num_crops`
    crops of size `patch_size` x `patch_size`.

    Returns a length-2 int32 array (rows, cols) describing the chosen tiling.

    Fixes vs the previous version: a stray trailing comma made
    `required_scale_d` a 1-tuple of arrays (it only worked via accidental
    broadcasting), and an initial `original_size = np.stack([h, w])` was dead
    code immediately overwritten below.
    """
    tilings = []
    for i in range(1, max_num_crops + 1):
        for j in range(1, max_num_crops + 1):
            if i * j <= max_num_crops:
                tilings.append((i, j))
    # sort so argmin and argmax favour smaller tilings in the event of a tie
    tilings.sort(key=lambda x: (x[0] * x[1], x[0]))
    candidate_tilings = np.array(tilings, dtype=np.int32)  # [n_resolutions, 2]
    candidate_resolutions = candidate_tilings * patch_size  # [n_resolutions, 2]

    # How much we would need to scale the image to fit exactly in each tiling
    # (np.array instead of np.stack(..., dtype=...), which needs numpy>=1.24)
    original_size = np.array([h, w], dtype=np.float32)  # [2]

    # The original size can be zero in rare cases if the image is smaller than the margin
    # In those cases letting the scale become infinite means the tiling is based on the
    # other side, or falls back to the smallest tiling
    with np.errstate(divide='ignore'):
        required_scale_d = candidate_resolutions.astype(np.float32) / original_size
    required_scale = np.min(required_scale_d, axis=-1, keepdims=True)  # [n_resolutions, 1]
    if np.all(required_scale < 1):
        # We are forced to downscale, so try to minimize the amount of downscaling
        ix = np.argmax(required_scale)
    else:
        # Pick the resolution that required the least upscaling so that it most closely fits the image
        required_scale = np.where(required_scale < 1.0, 10e9, required_scale)
        ix = np.argmin(required_scale)
    return candidate_tilings[ix]
275
-
276
-
277
def build_resized_image(
    image: np.ndarray,
    resize_mode: str,
    normalized_mode: str,
    base_image_input_size: List[int],
    pad_value: float,
    image_patch_size: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Resize and normalize one image to the base input size.

    Returns the [1, h, w, 3] image, its [1, h, w] padding mask, and an
    [h_patches, w_patches] grid of sequential patch indices.
    """
    resized, resized_mask = resize_image(
        image, resize_mode, base_image_input_size, pad_value,
    )
    resized = normalize_image(resized, normalized_mode)
    # Add a leading crop axis so the output lines up with the multi-crop case.
    if resized.ndim == 3:
        resized = resized[None]
        resized_mask = resized_mask[None]
    patches_h = base_image_input_size[0] // image_patch_size
    patches_w = base_image_input_size[1] // image_patch_size
    resize_idx = np.arange(patches_w * patches_h).reshape([patches_h, patches_w])
    return resized, resized_mask, resize_idx
296
-
297
-
298
def build_overlapping_crops(
    image: np.ndarray,
    resize_mode: str,
    normalize_mode: str,
    max_crops: int,
    overlap_margins: List[int],
    base_image_input_size: List[int],
    pad_value: float,
    image_patch_size: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Decompose an image into a set of overlapping crops

    :return crop_arr: [n_crops, h, w, 3] The crops
    :return mask_arr: [n_crops, h, w] The padding masks
    :return patch_idx: [overlap_patch_h, overlap_patch_w] For each patch in the resized image
        the crops were extracted from, what patch in `crop_arr` it corresponds to
    """
    # (The previous version computed `original_image_h/w` and `crop_size`
    # twice; the dead duplicates were removed.)
    original_image_h, original_image_w = image.shape[:2]
    crop_size = base_image_input_size[0]
    assert base_image_input_size[0] == base_image_input_size[1]

    left_margin, right_margin = overlap_margins
    total_margin_pixels = image_patch_size * (right_margin + left_margin)  # pixels removed per dim
    crop_patches = base_image_input_size[0] // image_patch_size  # patches per crop dim
    crop_window_patches = crop_patches - (right_margin + left_margin)  # usable patches
    crop_window_size = crop_window_patches * image_patch_size
    crop_patch_w = base_image_input_size[1] // image_patch_size
    crop_patch_h = base_image_input_size[0] // image_patch_size

    # Decide how to tile the image, to account for the overlap margins we compute the tiling
    # as if we had an image without the margins and were using a crop size without the margins
    tiling = select_tiling(
        original_image_h - total_margin_pixels,
        original_image_w - total_margin_pixels,
        crop_window_size,
        max_crops,
    )

    src, img_mask = resize_image(
        image,
        resize_mode,
        [tiling[0]*crop_window_size+total_margin_pixels, tiling[1]*crop_window_size+total_margin_pixels],
        pad_value,
    )
    src = normalize_image(src, normalize_mode)

    # Now we have to split the image into crops, and track what patches came from
    # where in `patch_idx_arr`
    n_crops = tiling[0] * tiling[1]
    crop_arr = np.zeros([n_crops, crop_size, crop_size, 3], dtype=src.dtype)
    mask_arr = np.zeros([n_crops, crop_size, crop_size], dtype=img_mask.dtype)
    patch_idx_arr = np.zeros([n_crops, crop_patch_h, crop_patch_w], dtype=np.int32)
    on_crop = 0
    for i in range(tiling[0]):
        # Slide over `src` by `crop_window_size` steps, but extract crops of size `crop_size`
        # which results in overlapping crop windows
        y0 = i*crop_window_size
        for j in range(tiling[1]):
            x0 = j*crop_window_size
            crop_arr[on_crop] = src[y0:y0+crop_size, x0:x0+crop_size]
            mask_arr[on_crop] = img_mask[y0:y0+crop_size, x0:x0+crop_size]
            patch_idx = np.arange(crop_patch_w*crop_patch_h).reshape(crop_patch_h, crop_patch_w)
            patch_idx += on_crop * crop_patch_h * crop_patch_w

            # Mask out idx that are in the overlap region
            if i != 0:
                patch_idx[:left_margin, :] = -1
            if j != 0:
                patch_idx[:, :left_margin] = -1
            if i != tiling[0]-1:
                patch_idx[-right_margin:, :] = -1
            if j != tiling[1]-1:
                patch_idx[:, -right_margin:] = -1
            patch_idx_arr[on_crop] = patch_idx
            on_crop += 1

    # `patch_idx_arr` is ordered crop-by-crop, here we transpose `patch_idx_arr`
    # so it is ordered left-to-right order
    patch_idx_arr = np.reshape(
        patch_idx_arr,
        [tiling[0], tiling[1], crop_patch_h, crop_patch_w]
    )
    patch_idx_arr = np.transpose(patch_idx_arr, [0, 2, 1, 3])
    patch_idx_arr = np.reshape(patch_idx_arr, [-1])

    # Now get the parts not in the overlap region, so it should map each patch in `src`
    # to the correct patch it should come from in `crop_arr`
    patch_idx_arr = patch_idx_arr[patch_idx_arr >= 0].reshape(
        src.shape[0]//image_patch_size,
        src.shape[1]//image_patch_size,
    )
    return crop_arr, mask_arr, patch_idx_arr
393
-
394
-
395
def batch_pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray:
    """Reshape images of [n_images, h, w, 3] -> [n_images, n_patches, pixels_per_patch].

    Also handles channel-less [n_images, h, w] arrays (e.g. masks), producing
    [n_images, n_patches, patch_size * patch_size].
    """
    if array.ndim == 3:
        n_crops, h, w = array.shape
        rows, cols = h // patch_size, w // patch_size
        tiled = array.reshape(n_crops, rows, patch_size, cols, patch_size)
        tiled = tiled.transpose(0, 1, 3, 2, 4)
        return tiled.reshape(n_crops, rows * cols, patch_size * patch_size)
    n_crops, h, w, c = array.shape
    rows, cols = h // patch_size, w // patch_size
    tiled = array.reshape(n_crops, rows, patch_size, cols, patch_size, c)
    tiled = tiled.transpose(0, 1, 3, 2, 4, 5)
    return tiled.reshape(n_crops, rows * cols, patch_size * patch_size * c)
413
-
414
-
415
def arange_for_pooling(
    idx_arr: np.ndarray,
    pool_h: int,
    pool_w: int,
) -> np.ndarray:
    """Group a [h, w] patch-index grid into pooling windows.

    Pads the grid with -1 ("no patch") so both dims are multiples of the pool
    size, splitting the padding evenly (extra row/col goes to the bottom/right),
    then returns [h/pool_h, w/pool_w, pool_h*pool_w] index groups.

    The einops `rearrange(idx, "(h dh) (w dw) -> h w (dh dw)")` call was
    replaced with the equivalent reshape/transpose, dropping the only use of
    the third-party einops dependency in this file.
    """
    h_pad = pool_h * ((idx_arr.shape[0] + pool_h - 1) // pool_h) - idx_arr.shape[0]
    w_pad = pool_w * ((idx_arr.shape[1] + pool_w - 1) // pool_w) - idx_arr.shape[1]
    idx_arr = np.pad(
        idx_arr,
        [[h_pad // 2, (h_pad + 1) // 2], [w_pad // 2, (w_pad + 1) // 2]],
        mode='constant', constant_values=-1,
    )
    out_h = idx_arr.shape[0] // pool_h
    out_w = idx_arr.shape[1] // pool_w
    idx_arr = idx_arr.reshape(out_h, pool_h, out_w, pool_w)
    idx_arr = idx_arr.transpose(0, 2, 1, 3)
    return idx_arr.reshape(out_h, out_w, pool_h * pool_w)
426
-
427
-
428
def image_to_patches_and_grids(
    image: ImageInput,
    crop_mode: str,
    resize_mode: str,
    normalize_mode: str,
    max_crops: int,
    overlap_margins: List[int],
    base_image_input_size: List[int],
    pad_value: float,
    image_patch_size: int,
    image_pooling_w: int,
    image_pooling_h: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Convert one image into ViT crops plus pooling metadata (no text tokens).

    :return image_grids, the shape of each (low-res, high-res) image after pooling
    :return crops, the image crops to process with the ViT
    :return mask, the padding mask for each crop
    :return pooled_patch_idx, for each patch_id tokens in `image_tokens`, the indices of the
        patches in `crops` to pool for that token, masked with -1
    """
    # A scalar size means a square input.
    if isinstance(base_image_input_size, int):
        base_image_input_size = (base_image_input_size, base_image_input_size)

    base_image_input_d = image_patch_size
    pooling_w = image_pooling_w
    pooling_h = image_pooling_h
    # Patches per crop along each axis.
    crop_patch_w = base_image_input_size[1] // base_image_input_d
    crop_patch_h = base_image_input_size[0] // base_image_input_d

    if crop_mode == "resize":
        # Single resized image, no tiling.
        resized, resized_mask, resize_idx = build_resized_image(
            image,
            resize_mode,
            normalize_mode,
            base_image_input_size,
            pad_value,
            image_patch_size
        )
        pooling_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
        h, w = pooling_idx.shape[:2]
        pooling_idx = pooling_idx.reshape([-1, pooling_h*pooling_w])
        image_grid = [np.array([h, w])]
        return (
            np.stack(image_grid, 0),
            batch_pixels_to_patches(resized, image_patch_size),
            batch_pixels_to_patches(resized_mask, image_patch_size).mean(-1),
            pooling_idx,
        )

    if crop_mode in ["overlap-and-resize-c2", "overlap-and-resize"]:
        # High-resolution overlapping crops.
        crop_arr, mask_arr, patch_idx_arr = build_overlapping_crops(
            image,
            resize_mode,
            normalize_mode,
            max_crops,
            overlap_margins,
            base_image_input_size,
            pad_value,
            image_patch_size,
        )
        pooling_idx = arange_for_pooling(patch_idx_arr, pooling_h, pooling_w)
        h, w = pooling_idx.shape[:2]
        pooling_idx = pooling_idx.reshape([-1, pooling_h*pooling_w])
        image_grid = [np.array([h, w])]

        if crop_mode == "overlap-and-resize":
            # This mode has no additional global low-res image.
            crop_arr = batch_pixels_to_patches(crop_arr, image_patch_size)
            mask_arr = batch_pixels_to_patches(mask_arr, image_patch_size).astype(np.float32).mean(axis=-1)
            return np.stack(image_grid, 0), crop_arr, mask_arr, pooling_idx

        # Finally do the same for the global image
        resized, resized_mask, resize_idx = build_resized_image(
            image,
            resize_mode,
            normalize_mode,
            base_image_input_size,
            pad_value,
            image_patch_size
        )
        crop_arr = np.concatenate([resized, crop_arr], 0)

        mask_arr = np.concatenate([resized_mask, mask_arr], 0)

        resize_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
        h, w = resize_idx.shape[:2]
        resize_idx = resize_idx.reshape([-1, pooling_h*pooling_w])

        # Global image goes first, so the order of patches in previous crops gets increased
        # (-1 "no patch" entries are preserved by the np.where guard).
        pooling_idx = np.where(
            pooling_idx >= 0,
            pooling_idx + crop_patch_h*crop_patch_w,
            -1
        )
        pooling_idx = np.concatenate([resize_idx, pooling_idx])
        image_grid = [
            np.array([h, w]),
        ] + image_grid

        mask_arr = batch_pixels_to_patches(mask_arr, image_patch_size).astype(np.float32).mean(axis=-1)
        return (
            np.stack(image_grid, 0),
            batch_pixels_to_patches(crop_arr, image_patch_size),
            mask_arr,
            pooling_idx
        )
    else:
        raise NotImplementedError(crop_mode)
535
-
536
-
537
def image_to_patches_and_tokens(
    image: ImageInput,
    crop_mode: str,
    use_col_tokens: bool,
    resize_mode: str,
    normalize_mode: str,
    max_crops: int,
    overlap_margins: List[int],
    base_image_input_size: List[int],
    pad_value: float,
    image_patch_size: int,
    image_pooling_w: int,
    image_pooling_h: int,
    image_patch_token_id: int,
    image_col_token_id: int,
    image_start_token_id: int,
    image_end_token_id: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Convert one image into crops, plus the special-token sequence for it.

    :return image_tokens, the token IDS for this image, including special tokens
    :return crops, the image crops to process with the ViT
    :return mask, the padding mask for each crop
    :return pooled_patch_idx, for each patch_id tokens in `image_tokens`, the indices of the
        patches in `crops` to pool for that token, masked with -1
    """

    # A scalar size means a square input.
    if isinstance(base_image_input_size, int):
        base_image_input_size = (base_image_input_size, base_image_input_size)

    base_image_input_d = image_patch_size
    pooling_w = image_pooling_w
    pooling_h = image_pooling_h
    # Short aliases for the special token ids.
    patch_id = image_patch_token_id
    col_id = image_col_token_id
    start_id = image_start_token_id
    end_id = image_end_token_id
    # Patches per crop along each axis.
    crop_patch_w = base_image_input_size[1] // base_image_input_d
    crop_patch_h = base_image_input_size[0] // base_image_input_d

    if crop_mode == "resize":
        # Single resized image, no tiling.
        resized, resized_mask, resize_idx = build_resized_image(
            image,
            resize_mode,
            normalize_mode,
            base_image_input_size,
            pad_value,
            image_patch_size
        )
        pooling_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
        h, w = pooling_idx.shape[:2]
        pooling_idx = pooling_idx.reshape([-1, pooling_h*pooling_w])
        # One patch token per pooled patch, with an optional column token per row.
        per_row = np.full(
            (w,),
            patch_id,
            dtype=np.int32
        )
        if use_col_tokens:
            per_row = np.concatenate([per_row, [col_id]], 0)
        extra_tokens = np.tile(per_row, [h])
        joint = [
            [start_id],
            extra_tokens,
            [end_id],
        ]
        return (
            np.concatenate(joint, 0),
            batch_pixels_to_patches(resized, image_patch_size),
            batch_pixels_to_patches(resized_mask, image_patch_size).mean(-1),
            pooling_idx,
        )

    if crop_mode in ["overlap-and-resize-c2", "overlap-and-resize"]:
        # High-resolution overlapping crops.
        crop_arr, mask_arr, patch_idx_arr = build_overlapping_crops(
            image,
            resize_mode,
            normalize_mode,
            max_crops,
            overlap_margins,
            base_image_input_size,
            pad_value,
            image_patch_size,
        )
        pooling_idx = arange_for_pooling(patch_idx_arr, pooling_h, pooling_w)
        h, w = pooling_idx.shape[:2]
        pooling_idx = pooling_idx.reshape([-1, pooling_h*pooling_w])

        # Now build the output tokens
        per_row = np.full(w, patch_id, dtype=np.int32)
        if use_col_tokens:
            per_row = np.concatenate([per_row, [col_id]], 0)
        joint = np.tile(per_row, [h])
        joint = [
            [start_id],
            joint,
            [end_id]
        ]

        if crop_mode == "overlap-and-resize":
            # This mode has no additional global low-res image.
            crop_arr = batch_pixels_to_patches(crop_arr, image_patch_size)
            mask_arr = batch_pixels_to_patches(mask_arr, image_patch_size).astype(np.float32).mean(axis=-1)
            return np.concatenate(joint, 0), crop_arr, mask_arr, pooling_idx

        # Finally do the same for the global image
        resized, resized_mask, resize_idx = build_resized_image(
            image,
            resize_mode,
            normalize_mode,
            base_image_input_size,
            pad_value,
            image_patch_size
        )
        crop_arr = np.concatenate([resized, crop_arr], 0)

        mask_arr = np.concatenate([resized_mask, mask_arr], 0)

        resize_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
        h, w = resize_idx.shape[:2]
        resize_idx = resize_idx.reshape([-1, pooling_h*pooling_w])

        # Global image goes first, so the order of patches in previous crops gets increased
        # (-1 "no patch" entries are preserved by the np.where guard).
        pooling_idx = np.where(
            pooling_idx >= 0,
            pooling_idx + crop_patch_h*crop_patch_w,
            -1
        )
        pooling_idx = np.concatenate([resize_idx, pooling_idx])

        # Tokens for the global image are prepended before the crop tokens.
        per_row = np.full(
            (w,),
            patch_id,
            dtype=np.int32
        )
        if use_col_tokens:
            per_row = np.concatenate([per_row, [col_id]], 0)
        extra_tokens = np.tile(per_row, [h])
        joint = [
            [start_id],
            extra_tokens,
            [end_id],
        ] + joint
        mask_arr = batch_pixels_to_patches(mask_arr, image_patch_size).astype(np.float32).mean(axis=-1)
        return (
            np.concatenate(joint, 0),
            batch_pixels_to_patches(crop_arr, image_patch_size),
            mask_arr,
            pooling_idx
        )
    else:
        raise NotImplementedError(crop_mode)
686
-
687
-
688
class MolmoActImagesKwargs(ImagesKwargs, total=False):
    """Optional per-call image kwargs for the MolmoAct processor; each entry
    overrides the corresponding `MolmoActImageProcessor` default when given."""
    crop_mode: Optional[str]
    resize_mode: Optional[str]
    normalize_mode: Optional[str]
    max_crops: Optional[int]
    max_multi_image_crops: Optional[int]
    overlap_margins: Optional[List[int]]
    base_image_input_size: Optional[List[int]]
    pad_value: Optional[float]
    image_patch_size: Optional[int]
    image_pooling_w: Optional[int]
    image_pooling_h: Optional[int]
700
-
701
-
702
class MolmoActImageProcessor(BaseImageProcessor):
    """Image processor for MolmoAct.

    Turns one image (or a list of images per sample) into ViT-ready crops,
    per-crop padding masks, pooled-patch index maps, and image grid shapes.
    """

    model_input_names = ["images", "pooled_patches_idx", "image_masks"]

    def __init__(
        self,
        crop_mode: str = "overlap-and-resize-c2",
        resize_mode: str = "siglip",
        normalize_mode: str = "siglip",
        max_crops: int = 8,
        max_multi_image_crops: int = 4,
        overlap_margins: List[int] = (4, 4),
        base_image_input_size: List[int] = (378, 378),
        pad_value: float = 0.0,
        image_patch_size: int = 14,
        image_pooling_w: int = 2,
        image_pooling_h: int = 2,
        do_convert_rgb: bool = True,
        do_pad: Optional[bool] = True,
        **kwargs,
    ) -> None:
        # NOTE: `overlap_margins` defaults to a tuple rather than the previous
        # mutable list `[4, 4]`, so instances can never share (and accidentally
        # mutate) one default object.  A duplicate `self.overlap_margins`
        # assignment was also removed.
        super().__init__(**kwargs)
        self.crop_mode = crop_mode
        self.resize_mode = resize_mode
        self.normalize_mode = normalize_mode
        self.max_crops = max_crops
        self.max_multi_image_crops = max_multi_image_crops
        self.overlap_margins = overlap_margins
        self.base_image_input_size = base_image_input_size
        self.pad_value = pad_value
        self.image_patch_size = image_patch_size
        self.image_pooling_w = image_pooling_w
        self.image_pooling_h = image_pooling_h
        self.do_convert_rgb = do_convert_rgb
        self.do_pad = do_pad

    def to_channel_dimension_last(
        self,
        images: List[ImageInput],
    ) -> List[ImageInput]:
        """
        Convert images (or per-sample lists of images) to channel-last layout.
        """
        new_images = []
        for image in images:
            if is_multi_image(image):
                new_images.append([to_channel_dimension_format(img, ChannelDimension.LAST) for img in image])
            else:
                new_images.append(to_channel_dimension_format(image, ChannelDimension.LAST))
        return new_images

    def to_numpy_array(
        self,
        images: List[ImageInput],
    ) -> List[np.ndarray]:
        """
        Convert images (or per-sample lists of images) to numpy arrays.
        """
        new_images = []
        for image in images:
            if is_multi_image(image):
                new_images.append([to_numpy_array(img) for img in image])
            else:
                new_images.append(to_numpy_array(image))
        return new_images

    def to_rgb(
        self,
        images: List[ImageInput],
    ) -> List[ImageInput]:
        """
        Convert images (or per-sample lists of images) to RGB.
        """
        new_images = []
        for image in images:
            if is_multi_image(image):
                new_images.append([convert_to_rgb(img) for img in image])
            else:
                new_images.append(convert_to_rgb(image))
        return new_images

    def pad_arrays(self, arrays: List[np.ndarray], pad_value: float = -1) -> np.ndarray:
        """Stack variable-length arrays on a new batch axis, padding the first
        dimension to the longest length with `pad_value`."""
        max_len = max(arr.shape[0] for arr in arrays)
        padded_arr = np.full(
            [len(arrays), max_len] + list(arrays[0].shape[1:]), pad_value, dtype=arrays[0].dtype
        )
        for ix, arr in enumerate(arrays):
            padded_arr[ix, :len(arr)] = arr
        return padded_arr

    def pad_for_batching(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Pad each per-example feature list so the features stack into batch arrays.
        """
        return dict(
            images=self.pad_arrays(data["images"]),
            pooled_patches_idx=self.pad_arrays(data["pooled_patches_idx"]),
            image_masks=self.pad_arrays(data["image_masks"]),
            image_grids=self.pad_arrays(data["image_grids"]),
        )

    def preprocess(
        self,
        images: Union[ImageInput, List[ImageInput]],
        crop_mode: Optional[str] = None,
        resize_mode: Optional[str] = None,
        normalize_mode: Optional[str] = None,
        max_crops: Optional[int] = None,
        max_multi_image_crops: Optional[int] = None,
        overlap_margins: Optional[List[int]] = None,
        base_image_input_size: Optional[List[int]] = None,
        pad_value: Optional[float] = None,
        image_patch_size: Optional[int] = None,
        image_pooling_w: Optional[int] = None,
        image_pooling_h: Optional[int] = None,
        do_convert_rgb: Optional[bool] = None,
        do_pad: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Preprocess an image (or list of images per sample) for the model.
        Args:
            images: The image(s) to preprocess.
            crop_mode: The crop mode to use. If None, use the default crop mode.
            resize_mode: The resize mode to use. If None, use the default resize mode.
            normalize_mode: The normalization mode to use. If None, use the default normalization mode.
            max_crops: The maximum number of crops to use. If None, use the default value.
            max_multi_image_crops: The maximum number of crops to use for multi-image inputs.
            overlap_margins: The overlap margins to use. If None, use the default values.
            base_image_input_size: The base image input size to use. If None, use the default size.
            pad_value: The padding value to use. If None, use the default value.
            image_patch_size: The size of the image patches. If None, use the default size.
            image_pooling_h: The height of the image pooling. If None, use the default height.
            image_pooling_w: The width of the image pooling. If None, use the default width.
            do_convert_rgb: Whether to convert the image to RGB. If None, use the default value.
            do_pad: Whether to pad image features. If None, use the default value.

        Returns:
            A `BatchFeature` with keys "images", "pooled_patches_idx",
            "image_masks" and "image_grids".
        """
        images = make_batched_images(images)

        if not valid_images(images):
            raise ValueError("Invalid image input")

        # Fall back to the processor defaults only when an argument is omitted.
        # `is None` checks (rather than `or`) keep explicit falsy overrides such
        # as pad_value=0.0, do_pad=False, or do_convert_rgb=False intact.
        crop_mode = self.crop_mode if crop_mode is None else crop_mode
        normalize_mode = self.normalize_mode if normalize_mode is None else normalize_mode
        resize_mode = self.resize_mode if resize_mode is None else resize_mode
        max_crops = self.max_crops if max_crops is None else max_crops
        max_multi_image_crops = (
            self.max_multi_image_crops if max_multi_image_crops is None else max_multi_image_crops
        )
        overlap_margins = self.overlap_margins if overlap_margins is None else overlap_margins
        base_image_input_size = (
            self.base_image_input_size if base_image_input_size is None else base_image_input_size
        )
        pad_value = self.pad_value if pad_value is None else pad_value
        image_patch_size = self.image_patch_size if image_patch_size is None else image_patch_size
        image_pooling_w = self.image_pooling_w if image_pooling_w is None else image_pooling_w
        image_pooling_h = self.image_pooling_h if image_pooling_h is None else image_pooling_h
        do_convert_rgb = self.do_convert_rgb if do_convert_rgb is None else do_convert_rgb
        do_pad = self.do_pad if do_pad is None else do_pad

        if do_convert_rgb:
            images = self.to_rgb(images)

        # All transformations expect numpy arrays with the channel dimension last.
        images = self.to_numpy_array(images)
        images = self.to_channel_dimension_last(images)

        batch_image_grids = []
        batch_crops = []
        batch_crop_masks = []
        batch_pooled_patches_idx = []

        for image in images:
            if is_multi_image(image):
                all_image_grids = []
                all_crops = []
                all_crop_masks = []
                pooled_patches_idx = []
                for img in image:
                    image_grid, crops, img_mask, pooled_idx = image_to_patches_and_grids(
                        img,
                        crop_mode,
                        resize_mode,
                        normalize_mode,
                        max_multi_image_crops,
                        overlap_margins,
                        base_image_input_size,
                        pad_value,
                        image_patch_size,
                        image_pooling_w,
                        image_pooling_h,
                    )
                    # Offset this image's patch indices past the patches of the
                    # crops gathered so far, but keep -1 ("no patch") entries
                    # untouched — the previous unconditional `pooled_idx + offset`
                    # shifted -1 markers onto real patches, unlike the np.where
                    # guard used in image_to_patches_and_grids.
                    offset = sum(int(np.prod(x.shape[:2])) for x in all_crops)
                    pooled_patches_idx.append(np.where(pooled_idx >= 0, pooled_idx + offset, -1))
                    all_crops.append(crops)
                    all_crop_masks.append(img_mask)
                    all_image_grids.append(image_grid)

                batch_image_grids.append(np.concatenate(all_image_grids, 0))
                batch_crops.append(np.concatenate(all_crops, 0))
                batch_crop_masks.append(np.concatenate(all_crop_masks, 0))
                batch_pooled_patches_idx.append(np.concatenate(pooled_patches_idx, 0))
            else:
                image_grid, crops, img_mask, pooled_idx = image_to_patches_and_grids(
                    image,
                    crop_mode,
                    resize_mode,
                    normalize_mode,
                    max_crops,
                    overlap_margins,
                    base_image_input_size,
                    pad_value,
                    image_patch_size,
                    image_pooling_w,
                    image_pooling_h,
                )
                batch_image_grids.append(image_grid)
                batch_crops.append(crops)
                batch_crop_masks.append(img_mask)
                batch_pooled_patches_idx.append(pooled_idx)

        data = dict(
            images=batch_crops,
            pooled_patches_idx=batch_pooled_patches_idx,
            image_masks=batch_crop_masks,
            image_grids=batch_image_grids,
        )

        if do_pad:
            data = self.pad_for_batching(data)

        return BatchFeature(data, tensor_type=return_tensors)
949
-
950
-
951
- MolmoActImageProcessor.register_for_auto_class()