dreilly committed on
Commit
61dcebe
·
verified ·
1 Parent(s): 40ae0be

Upload folder using huggingface_hub

Browse files
config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "ViSCoP_VisionEncoderModel"
4
+ ],
5
+ "attention_dropout": 0.0,
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_viscop_vision_encoder.ViSCoP_VisionEncoderConfig",
8
+ "AutoModel": "modeling_viscop_vision_encoder.ViSCoP_VisionEncoderModel"
9
+ },
10
+ "hidden_act": "gelu_pytorch_tanh",
11
+ "hidden_size": 1152,
12
+ "interaction_module": "cross_attention",
13
+ "interaction_module_layers": null,
14
+ "intermediate_size": 4304,
15
+ "layer_norm_eps": 1e-06,
16
+ "model_type": "viscop_vision_encoder",
17
+ "num_attention_heads": 16,
18
+ "num_channels": 3,
19
+ "num_hidden_layers": 27,
20
+ "num_visual_probes": 16,
21
+ "patch_size": 14,
22
+ "torch_dtype": "bfloat16",
23
+ "transformers_version": "4.46.3"
24
+ }
configuration_viscop_vision_encoder.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adopted from https://github.com/DAMO-NLP-SG/VideoLLaMA3/blob/main/videollama3/model/videollama3_encoder/configuration_videollama3_encoder.py
2
+ # Adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/siglip/configuration_siglip.py.
3
+ # Below is the original copyright:
4
+ # coding=utf-8
5
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+ """ViSCoP vision encoder model configuration."""
19
+
20
+ from transformers import PretrainedConfig
21
+
22
+
23
class ViSCoP_VisionEncoderConfig(PretrainedConfig):
    """Configuration for the ViSCoP vision encoder (a SigLIP-style ViT backbone).

    Args:
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the MLP (feed-forward) layers.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of transformer layers.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads per layer.
        num_channels (`int`, *optional*, defaults to 3):
            Number of input image channels.
        patch_size (`int`, *optional*, defaults to 16):
            Spatial size of each image patch.
        hidden_act (`str`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            Activation function used in the MLP layers.
        layer_norm_eps (`float`, *optional*, defaults to 1e-6):
            Epsilon used by the layer-norm layers.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability applied to attention weights.
    """

    model_type = "viscop_vision_encoder"

    def __init__(
        self,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_channels=3,
        patch_size=16,
        hidden_act="gelu_pytorch_tanh",
        layer_norm_eps=1e-6,
        attention_dropout=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Transformer geometry.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        # Input / patchification.
        self.num_channels = num_channels
        self.patch_size = patch_size
        # Regularization and numerics.
        self.attention_dropout = attention_dropout
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act
image_processing_viscop.py ADDED
@@ -0,0 +1,488 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adopted from https://github.com/DAMO-NLP-SG/VideoLLaMA3/blob/main/videollama3/model/videollama3_encoder/image_processing_videollama3.py
2
+ # Adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py.
3
+ # Below is the original copyright:
4
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
5
+ #
6
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
7
+ # and OPT implementations in this library. It has been modified from its
8
+ # original forms to accommodate minor architectural differences compared
9
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+ """Image processor class for ViSCoP."""
23
+
24
+ import math
25
+ from typing import Dict, List, Optional, Union
26
+
27
+ import numpy as np
28
+
29
+ import torch
30
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
31
+ from transformers.image_utils import ImageInput
32
+ from transformers.image_transforms import (
33
+ convert_to_rgb,
34
+ resize,
35
+ to_channel_dimension_format,
36
+ )
37
+ from transformers.image_utils import (
38
+ OPENAI_CLIP_MEAN,
39
+ OPENAI_CLIP_STD,
40
+ ChannelDimension,
41
+ ImageInput,
42
+ PILImageResampling,
43
+ VideoInput,
44
+ get_image_size,
45
+ infer_channel_dimension_format,
46
+ is_scaled_image,
47
+ is_valid_image,
48
+ make_list_of_images,
49
+ to_numpy_array,
50
+ )
51
+ from transformers.utils import TensorType, is_vision_available, logging
52
+
53
+
54
+ logger = logging.get_logger(__name__)
55
+
56
+
57
+ if is_vision_available():
58
+ from PIL import Image
59
+
60
+
61
def is_valid_video(video) -> bool:
    """Return True if `video` looks like a video.

    A video is either a sequence (list/tuple) of valid frames, or a 4-D
    numpy array / torch tensor (frames stacked on the first axis).
    """
    if isinstance(video, (list, tuple)):
        return all(is_valid_image(frame) for frame in video)
    if isinstance(video, (np.ndarray, torch.Tensor)):
        return video.ndim == 4
    return False
69
+
70
+
71
def make_batched_images(images) -> List[List[ImageInput]]:
    """Normalize image/video input into a flat list for preprocessing.

    Accepts a single image or video, or a list of images/videos. A single
    item is wrapped into a one-element list; list inputs are validated
    item by item.

    Args:
        images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
            The input image(s)/video(s).

    Returns:
        list: A list of images/videos.

    Raises:
        ValueError: if the input is neither a valid image/video nor a list of them.
    """
    if isinstance(images, (list, tuple)):
        # A batch: every entry must itself be an image or a video.
        for item in images:
            if not (is_valid_video(item) or is_valid_image(item)):
                raise ValueError(f"Could not make batched images from {images}")
        return images
    if is_valid_video(images) or is_valid_image(images):
        # A single image/video becomes a batch of one.
        return [images]
    raise ValueError(f"Could not make batched images from {images}")
92
+
93
+
94
def simple_batched_resize(
    images, factor: int = 28, min_tokens: int = 4 * 4, max_tokens: int = 16384, input_data_format: str = None
):
    """Compute a target (height, width) per input, rounded to a multiple of `factor`.

    The global token budget (`max_tokens`) is split evenly across all frames
    (a video contributes one share per frame); each image is also kept above
    a per-image minimum of `min_tokens`.

    Returns a list of `(height, width)` tuples, one per entry in `images`.
    """
    min_pixels = min_tokens * factor * factor
    max_pixels = max_tokens * factor * factor

    # Total frame count: videos count once per frame, images once.
    num_images = sum(len(image) if is_valid_video(image) else 1 for image in images)

    image_sizes = []
    for image in images:
        if is_valid_video(image):
            # All frames of a video are assumed to share the first frame's size.
            image = image[0]
        if isinstance(image, Image.Image):
            width, height = image.size
        else:
            height, width = get_image_size(image, channel_dim=input_data_format)
        image_sizes.append([height, width])

    per_image_budget = max_pixels // num_images
    resized_sizes = []
    for height, width in image_sizes:
        h_bar = round(height / factor) * factor
        w_bar = round(width / factor) * factor
        if h_bar * w_bar > per_image_budget:
            # Too large: shrink so this image fits its share of the budget.
            beta = math.sqrt((height * width) / per_image_budget)
            h_bar = math.floor(height / beta / factor) * factor
            w_bar = math.floor(width / beta / factor) * factor
        # per image min_pixels
        if h_bar * w_bar < min_pixels:
            # Too small: grow up to the per-image minimum.
            beta = math.sqrt(min_pixels / (height * width))
            h_bar = math.ceil(height * beta / factor) * factor
            w_bar = math.ceil(width * beta / factor) * factor
        resized_sizes.append((h_bar, w_bar))
    return resized_sizes
133
+
134
+
135
def batched_resize(
    images, factors: List[int], min_tokens: int = 4 * 4, max_tokens: int = 16384, input_data_format: str = None
):
    """Compute per-input target sizes given a per-input rounding `factor`.

    If the total token count over all inputs exceeds `max_tokens`, every
    image is downscaled by a single common ratio; otherwise sizes are just
    rounded to the nearest multiple of their factor.

    Returns a list of `(height, width)` tuples, one per entry in `images`.
    """
    image_sizes = []
    for image in images:
        if is_valid_video(image):
            num_frame = len(image)
            # Frames are assumed to share the first frame's size.
            image = image[0]
        else:
            num_frame = 1
        if isinstance(image, Image.Image):
            width, height = image.size
        else:
            height, width = get_image_size(image, channel_dim=input_data_format)
        image_sizes.append([num_frame, height, width])

    # Global token count across all inputs and frames.
    total_tokens = sum(
        num_frame * math.ceil(height / factor) * math.ceil(width / factor)
        for (num_frame, height, width), factor in zip(image_sizes, factors)
    )

    # TODO: add min_pixels
    if total_tokens > max_tokens:
        # Over budget: scale everything down by one common ratio.
        beta = math.sqrt(total_tokens / max_tokens)
        image_sizes = [
            (
                math.floor(height / beta / factor) * factor,
                math.floor(width / beta / factor) * factor,
            )
            for (_, height, width), factor in zip(image_sizes, factors)
        ]
    else:
        # Within budget: only snap to multiples of each factor.
        image_sizes = [
            (round(height / factor) * factor, round(width / factor) * factor)
            for (_, height, width), factor in zip(image_sizes, factors)
        ]

    return image_sizes
175
+
176
+
177
class ViSCoP_ImageProcessor(BaseImageProcessor):
    r"""
    Constructs a DAMO-VL (https://huggingface.co/DAMO-NLP-SG/VL3-SigLIP-NaViT) image processor that dynamically resizes images based on the original images.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions.
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
            Resampling filter to use when resizing the image.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image.
        image_mean (`float` or `List[float]`, *optional*, defaults to `OPENAI_CLIP_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
        image_std (`float` or `List[float]`, *optional*, defaults to `OPENAI_CLIP_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
        do_convert_rgb (`bool`, *optional*, defaults to `True`):
            Whether to convert the image to RGB.
        min_tokens (`int`, *optional*, defaults to `4 * 4`):
            Minimum number of vision tokens per image after resizing.
        max_tokens (`int`, *optional*, defaults to `16384`):
            Maximum total number of vision tokens after resizing.
        patch_size (`int`, *optional*, defaults to 14):
            The spatial patch size of the vision encoder.
    """

    model_input_names = ["pixel_values", "grid_sizes", "merge_sizes"]

    def __init__(
        self,
        do_resize: bool = True,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = True,
        min_tokens: int = 4 * 4,
        max_tokens: int = 16384,
        patch_size: int = 14,
        **kwargs,
    ) -> None:
        # Bug fix: `preprocess` falls back to `self.merge_size` when no
        # `merge_size` argument is given, but this attribute was never
        # initialized, which raised AttributeError. Default to 1 (no merging),
        # while still honoring a `merge_size` passed through kwargs.
        self.merge_size = kwargs.pop("merge_size", 1)
        super().__init__(**kwargs)
        self.do_resize = do_resize
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
        self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
        self.min_tokens = min_tokens
        self.max_tokens = max_tokens
        self.patch_size = patch_size
        self.do_convert_rgb = do_convert_rgb

    def _preprocess(
        self,
        images: Union[ImageInput, VideoInput],
        target_size: List[int],
        merge_size: int = 1,
        do_resize: bool = None,
        resample: PILImageResampling = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """
        Preprocess a single image or video (list of frames) into flattened patches.

        Args:
            images (`ImageInput`):
                Image or batch of frames to preprocess. Expects pixel values ranging
                from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
            target_size (`List[int]`):
                The target size to resize the image to: `[target_height, target_width]`.
                Both must be divisible by `patch_size * merge_size`.
            merge_size (`int`, *optional*, defaults to `1`):
                The merge size after the vision encoder.
            do_resize (`bool`, *optional*):
                Whether to resize the image.
            resample (`PILImageResampling`, *optional*):
                Resampling filter to use if resizing the image.
            do_rescale (`bool`, *optional*):
                Whether to rescale the image.
            rescale_factor (`float`, *optional*):
                Scale factor to use if rescaling the image.
            do_normalize (`bool`, *optional*):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*):
                Mean to use if normalizing the image.
            image_std (`float` or `List[float]`, *optional*):
                Standard deviation to use if normalizing the image.
            do_convert_rgb (`bool`, *optional*):
                Whether to convert the image to RGB.
            data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image; inferred if unset.

        Returns:
            `(flatten_patches, (t, grid_h, grid_w))` where `flatten_patches` has
            shape `(t * grid_h * grid_w, channels * patch_size ** 2)`.
        """
        images = make_list_of_images(images)

        if do_convert_rgb:
            images = [convert_to_rgb(image) for image in images]

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if is_scaled_image(images[0]) and do_rescale:
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )
        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        height, width = get_image_size(images[0], channel_dim=input_data_format)
        resized_height, resized_width = height, width
        processed_images = []
        for image in images:
            if do_resize:
                resized_height, resized_width = target_size
                image = resize(
                    image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format
                )

            if do_rescale:
                image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)

            if do_normalize:
                image = self.normalize(
                    image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
                )

            image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
            processed_images.append(image)

        # Split each frame into a grid of patches; group `merge_size x merge_size`
        # neighborhoods together so downstream token merging sees contiguous patches.
        patches = np.array(processed_images)
        if data_format == ChannelDimension.LAST:
            patches = patches.transpose(0, 3, 1, 2)
        t = patches.shape[0]
        channel = patches.shape[1]
        grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
        patches = patches.reshape(
            t,
            channel,
            grid_h // merge_size,
            merge_size,
            self.patch_size,
            grid_w // merge_size,
            merge_size,
            self.patch_size,
        )
        patches = patches.transpose(0, 2, 5, 3, 6, 1, 4, 7)
        flatten_patches = patches.reshape(
            t * grid_h * grid_w, channel * self.patch_size * self.patch_size
        )

        return flatten_patches, (t, grid_h, grid_w)

    def preprocess(
        self,
        images: ImageInput,
        do_resize: bool = None,
        resample: PILImageResampling = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = None,
        merge_size: Optional[Union[int, List[int]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """
        Preprocess one or more images/videos into a `BatchFeature`.

        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel
                values ranging from 0 to 255. If passing in images with pixel values
                between 0 and 1, set `do_rescale=False`.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            resample (`int`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. Only has an effect
                if `do_resize` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to use for normalization.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use for normalization.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.
            merge_size (`int` or `List[int]`, *optional*, defaults to `self.merge_size`):
                Merge size after the vision encoder; a single int shared by all
                inputs, or one int per input.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return (`'tf'`, `'pt'`, `'np'`, `'jax'`, or
                unset for lists of `np.ndarray`).
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image; inferred if unset.

        Returns:
            `BatchFeature` with keys `pixel_values` (all patches concatenated),
            `grid_sizes`, `merge_sizes`, and `pixel_value_shapes`.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        resample = resample if resample is not None else self.resample
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        merge_size = merge_size if merge_size is not None else self.merge_size
        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb

        images = make_batched_images(images)

        if isinstance(merge_size, (list, tuple)):
            assert len(merge_size) == len(images), "Merge size must be the same length as images."
            merge_sizes = merge_size
        else:
            merge_sizes = [merge_size for _ in images]

        if all(merge_size == merge_sizes[0] for merge_size in merge_sizes):
            # Uniform merge size: split the token budget evenly across inputs.
            target_sizes = simple_batched_resize(
                images,
                factor=self.patch_size * merge_sizes[0],
                min_tokens=self.min_tokens,
                max_tokens=self.max_tokens,
                input_data_format=input_data_format,
            )
        else:
            # Mixed merge sizes: resize jointly under the global token budget.
            target_sizes = batched_resize(
                images,
                factors=[self.patch_size * merge_size for merge_size in merge_sizes],
                min_tokens=self.min_tokens,
                max_tokens=self.max_tokens,
                input_data_format=input_data_format,
            )

        pixel_values, grid_sizes = [], []
        for image, merge_size, target_size in zip(images, merge_sizes, target_sizes):
            patches, grid_size = self._preprocess(
                image,
                target_size=target_size,
                merge_size=merge_size,
                do_resize=do_resize,
                resample=resample,
                do_rescale=do_rescale,
                rescale_factor=rescale_factor,
                do_normalize=do_normalize,
                image_mean=image_mean,
                image_std=image_std,
                data_format=data_format,
                do_convert_rgb=do_convert_rgb,
                input_data_format=input_data_format,
            )
            pixel_values.append(patches)
            grid_sizes.append(grid_size)

        pixel_value_shapes = [x.shape for x in pixel_values]
        pixel_values = np.concatenate(pixel_values, axis=0)
        grid_sizes = np.array(grid_sizes)
        merge_sizes = np.array(merge_sizes)

        data = {
            "pixel_values": pixel_values,
            "grid_sizes": grid_sizes,
            "merge_sizes": merge_sizes,
            "pixel_value_shapes": pixel_value_shapes,
        }

        return BatchFeature(data=data, tensor_type=return_tensors)
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:45c1890c91ad9b2aea588035bd528c968702ce4f9ff1fcff3b70cd05e16ff7c9
3
+ size 824342816
modeling_viscop_vision_encoder.py ADDED
@@ -0,0 +1,661 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adopted from https://github.com/DAMO-NLP-SG/VideoLLaMA3/blob/main/videollama3/model/videollama3_encoder/modeling_videollama3_encoder.py
2
+ # Adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py.
3
+ # Below is the original copyright:
4
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
5
+ #
6
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
7
+ # and OPT implementations in this library. It has been modified from its
8
+ # original forms to accommodate minor architectural differences compared
9
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+ """PyTorch ViSCoP vision encoder model. This file contains the implementation of visual probes and interaction modules"""
23
+
24
+ import importlib.util
25
+ import os.path as osp
26
+ import math
27
+ import warnings
28
+
29
+ from einops import rearrange
30
+
31
+ import torch
32
+ import torch.nn as nn
33
+ import torch.nn.functional as F
34
+ import torch.utils.checkpoint
35
+ from torch.nn.init import _calculate_fan_in_and_fan_out
36
+
37
+ from transformers.activations import ACT2FN
38
+ from transformers.modeling_utils import PreTrainedModel
39
+ from transformers.utils import is_flash_attn_2_available
40
+
41
+ if is_flash_attn_2_available():
42
+ from flash_attn import flash_attn_varlen_func
43
+ else:
44
+ flash_attn_varlen_func = None
45
+
46
+ from .configuration_viscop_vision_encoder import ViSCoP_VisionEncoderConfig
47
+
48
+ # try:
49
+ # from .configuration_viscop_encoder import ViSCoP_VisionEncoderConfig
50
+ # except ImportError:
51
+ # spec = importlib.util.spec_from_file_location(
52
+ # "configuration_videollama3_encoder",
53
+ # osp.join(osp.dirname(__file__), "configuration_videollama3_encoder.py"),
54
+ # )
55
+ # configuration_videollama3_encoder = importlib.util.module_from_spec(spec)
56
+ # spec.loader.exec_module(configuration_videollama3_encoder)
57
+ # Videollama3VisionEncoderConfig = getattr(
58
+ # configuration_videollama3_encoder,
59
+ # "Videollama3VisionEncoderConfig",
60
+ # )
61
+
62
+
63
+ def _trunc_normal_(tensor, mean, std, a, b):
64
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
65
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
66
+ def norm_cdf(x):
67
+ # Computes standard normal cumulative distribution function
68
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
69
+
70
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
71
+ warnings.warn(
72
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
73
+ "The distribution of values may be incorrect.",
74
+ stacklevel=2,
75
+ )
76
+
77
+ # Values are generated by using a truncated uniform distribution and
78
+ # then using the inverse CDF for the normal distribution.
79
+ # Get upper and lower cdf values
80
+ l = norm_cdf((a - mean) / std)
81
+ u = norm_cdf((b - mean) / std)
82
+
83
+ # Uniformly fill tensor with values from [l, u], then translate to
84
+ # [2l-1, 2u-1].
85
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
86
+
87
+ # Use inverse cdf transform for normal distribution to get truncated
88
+ # standard normal
89
+ tensor.erfinv_()
90
+
91
+ # Transform to proper mean, std
92
+ tensor.mul_(std * math.sqrt(2.0))
93
+ tensor.add_(mean)
94
+
95
+ # Clamp to ensure it's in the proper range
96
+ tensor.clamp_(min=a, max=b)
97
+
98
+
99
def trunc_normal_tf_(
    tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
) -> torch.Tensor:
    """Fill `tensor` in place with truncated-normal values, TF/JAX flavour.

    Unlike `nn.init.trunc_normal_`, the bounds `[a, b]` are applied while
    sampling with mean 0 and std 1; the result is only afterwards scaled by
    `std` and shifted by `mean`. Values therefore end up in
    `[a * std + mean, b * std + mean]`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value (before scaling)
        b: the maximum cutoff value (before scaling)
    """
    with torch.no_grad():
        _trunc_normal_(tensor, 0, 1.0, a, b)
        tensor.mul_(std).add_(mean)
123
+
124
+
125
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
    """Initialize *tensor* in place with TF/JAX-style variance scaling.

    Args:
        tensor: n-dimensional `torch.Tensor` to initialize in place.
        scale: multiplicative factor applied to the variance.
        mode: one of ``"fan_in"``, ``"fan_out"``, ``"fan_avg"`` — which fan
            count normalizes the variance.
        distribution: ``"truncated_normal"``, ``"normal"`` or ``"uniform"``.

    Raises:
        ValueError: if ``mode`` or ``distribution`` is not one of the above.
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == "fan_in":
        denom = fan_in
    elif mode == "fan_out":
        denom = fan_out
    elif mode == "fan_avg":
        denom = (fan_in + fan_out) / 2
    else:
        # BUGFIX: previously an unknown mode fell through and crashed later
        # with an unbound-local NameError on `denom`; fail loudly instead.
        raise ValueError(f"invalid mode {mode}")

    variance = scale / denom

    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
    elif distribution == "normal":
        with torch.no_grad():
            tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        bound = math.sqrt(3 * variance)
        with torch.no_grad():
            tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")
+
149
+
150
def lecun_normal_(tensor):
    """LeCun normal init: fan-in variance scaling with a truncated normal."""
    variance_scaling_(tensor, 1.0, "fan_in", "truncated_normal")
+
153
+
154
def default_flax_embed_init(tensor):
    """Flax's default embedding init: fan-in variance scaling, normal distribution."""
    variance_scaling_(tensor, 1.0, "fan_in", "normal")
+
157
+
158
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    half = x.shape[-1] // 2
    first_half, second_half = x[..., :half], x[..., half:]
    return torch.cat((-second_half, first_half), dim=-1)
+
165
+
166
def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embedding to vision tokens.

    ``freqs`` holds per-position rotation angles; cos/sin tables are computed
    in float32 for numerical stability and the result is cast back to the
    input dtype.
    """
    input_dtype = tensor.dtype
    tensor_fp32 = tensor.float()
    # Each angle is duplicated so it covers both rotated halves of the last
    # dim; a leading axis is added to broadcast over the batch dimension.
    cos_table = freqs.cos().unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
    sin_table = freqs.sin().unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
    rotated = (tensor_fp32 * cos_table) + (rotate_half(tensor_fp32) * sin_table)
    return rotated.to(input_dtype)
+
177
+
178
class VisionRotaryEmbedding(nn.Module):
    """Produces per-position rotary frequencies for vision tokens."""

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        # Standard RoPE inverse-frequency schedule over even channel indices.
        exponents = torch.arange(0, dim, 2, dtype=torch.float) / dim
        inv_freq = 1.0 / (theta ** exponents)
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return a (seqlen, dim // 2) table of position * inverse-frequency angles."""
        positions = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        return torch.outer(positions, self.inv_freq)
+
190
+
191
class VisionEmbeddings(nn.Module):
    """Patchify flattened pixel values with a strided conv into token embeddings."""

    def __init__(self, config: ViSCoP_VisionEncoderConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.patch_size = config.patch_size

        # One patch_size x patch_size patch -> one embed_dim vector
        # (kernel == stride, "valid" padding => exactly one output per patch).
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Map flattened patches (N, C * P * P) to embeddings (N, embed_dim)."""
        patches = hidden_states.view(
            -1, self.config.num_channels, self.patch_size, self.patch_size
        )
        embedded = self.patch_embedding(patches)  # (N, embed_dim, 1, 1)
        return embedded.view(-1, self.embed_dim)
+
217
+
218
class VisionAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor = None,
    ) -> torch.Tensor:
        """Input shape: Time x Channel

        `cu_seqlens` holds cumulative sequence lengths delimiting independent
        segments; attention is restricted to within each segment via a
        block-diagonal additive mask, matching the flash-attn and SDPA
        variants of this class.
        """
        q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(q_len, self.num_heads, self.head_dim)
        key_states = key_states.view(q_len, self.num_heads, self.head_dim)
        value_states = value_states.view(q_len, self.num_heads, self.head_dim)

        query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
        key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)

        # BUGFIX: build an *additive* float mask (0 inside a segment, a large
        # negative value elsewhere). The previous boolean mask was added
        # directly to the logits, which merely shifted in-segment scores by +1
        # and did not actually block cross-segment attention.
        attention_mask = torch.full(
            [1, q_len, q_len],
            torch.finfo(query_states.dtype).min,
            device=query_states.device,
            dtype=query_states.dtype,
        )
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0

        query_states = query_states.transpose(0, 1)
        key_states = key_states.transpose(0, 1)
        value_states = value_states.transpose(0, 1)

        attn_weights = torch.matmul(query_states, key_states.transpose(1, 2)) / math.sqrt(self.head_dim)
        attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        attn_output = attn_output.transpose(0, 1)
        attn_output = attn_output.reshape(q_len, -1)
        attn_output = self.out_proj(attn_output)

        return attn_output
+
284
+
285
class VisionFlashAttention2(VisionAttention):
    """VisionAttention variant that runs the var-len FlashAttention-2 kernel."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor = None,
    ) -> torch.Tensor:
        seq_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # The var-len flash-attn kernel consumes the packed
        # (tokens, heads, head_dim) layout directly — no batch dim needed.
        query_states = query_states.view(seq_len, self.num_heads, self.head_dim)
        key_states = key_states.view(seq_len, self.num_heads, self.head_dim)
        value_states = value_states.view(seq_len, self.num_heads, self.head_dim)
        query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
        key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)

        # Segment boundaries come from cu_seqlens; the kernel needs the
        # longest segment length explicitly.
        longest = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        attn_output = flash_attn_varlen_func(
            query_states, key_states, value_states, cu_seqlens, cu_seqlens, longest, longest
        )
        attn_output = attn_output.reshape(seq_len, -1)
        return self.out_proj(attn_output)
+
320
class InteractionModule_CrossAttention_FA2(VisionAttention):
    """Cross-attention interaction module: probes (queries) attend to vision tokens (keys/values)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
    def forward(
        self,
        visual_probes: torch.Tensor,
        hidden_states: torch.Tensor,
        cu_seqlensq: torch.Tensor,
        cu_seqlenskv: torch.Tensor,
        rotary_pos_emb: torch.Tensor = None,
    ) -> torch.Tensor:
        num_queries = visual_probes.size(0)
        num_kv = hidden_states.size(0)

        # Queries come from the probes; keys/values from the vision features.
        query_states = self.q_proj(visual_probes).view(num_queries, self.num_heads, self.head_dim)
        key_states = self.k_proj(hidden_states).view(num_kv, self.num_heads, self.head_dim)
        value_states = self.v_proj(hidden_states).view(num_kv, self.num_heads, self.head_dim)
        # Rotary embeddings are not applied here (the original left them
        # commented out); `rotary_pos_emb` is kept for signature parity with
        # the self-attention variants.

        max_q = (cu_seqlensq[1:] - cu_seqlensq[:-1]).max().item()
        max_kv = (cu_seqlenskv[1:] - cu_seqlenskv[:-1]).max().item()

        attn_output = flash_attn_varlen_func(
            query_states, key_states, value_states, cu_seqlensq, cu_seqlenskv, max_q, max_kv
        )
        return self.out_proj(attn_output.reshape(num_queries, -1))
+
360
+
361
class VisionSdpaAttention(VisionAttention):
    """VisionAttention variant backed by torch.nn.functional.scaled_dot_product_attention."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor = None,
    ) -> torch.Tensor:
        num_tokens = hidden_states.shape[0]

        query_states = self.q_proj(hidden_states).view(num_tokens, self.num_heads, self.head_dim)
        key_states = self.k_proj(hidden_states).view(num_tokens, self.num_heads, self.head_dim)
        value_states = self.v_proj(hidden_states).view(num_tokens, self.num_heads, self.head_dim)

        query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
        key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)

        # Boolean block-diagonal mask; True = may attend (SDPA semantics).
        attention_mask = torch.zeros([1, num_tokens, num_tokens], device=query_states.device, dtype=torch.bool)
        for seg in range(1, len(cu_seqlens)):
            lo, hi = cu_seqlens[seg - 1], cu_seqlens[seg]
            attention_mask[..., lo:hi, lo:hi] = True

        # SDPA expects (heads, tokens, head_dim).
        query_states = query_states.transpose(0, 1)
        key_states = key_states.transpose(0, 1)
        value_states = value_states.transpose(0, 1)
        attn_output = F.scaled_dot_product_attention(
            query_states, key_states, value_states, attention_mask, dropout_p=0.0
        )
        attn_output = attn_output.transpose(0, 1).reshape(num_tokens, -1)
        return self.out_proj(attn_output)
+
394
+
395
# Dispatch table: config._attn_implementation -> attention backend class.
VISION_ATTENTION_CLASSES = {
    "eager": VisionAttention,
    "flash_attention_2": VisionFlashAttention2,
    "sdpa": VisionSdpaAttention,
}
+
401
class SigLIPVisionMLP(nn.Module):
    """Two-layer feed-forward block (hidden -> intermediate -> hidden) with a configurable activation."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project up, apply the activation, project back down."""
        projected_up = self.fc1(hidden_states)
        activated = self.activation_fn(projected_up)
        return self.fc2(activated)
+
416
+
417
class SigLIPVisionEncoderLayer(nn.Module):
    """Pre-norm transformer layer: self-attention then MLP, each with a residual connection."""

    def __init__(self, config: ViSCoP_VisionEncoderConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        # Backend (eager / flash-attn / sdpa) is selected from the config.
        self.self_attn = VISION_ATTENTION_CLASSES[config._attn_implementation](config=config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = SigLIPVisionMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    # Ignore copy
    def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
        # Attention args are passed positionally so every backend's
        # signature is accepted.
        attn_out = self.self_attn(self.layer_norm1(hidden_states), cu_seqlens, rotary_pos_emb)
        hidden_states = hidden_states + attn_out
        hidden_states = hidden_states + self.mlp(self.layer_norm2(hidden_states))
        return hidden_states
+
438
# Alternative where each interaction module is a copy of the entire vision transformer layer.
# Empirically we found that simple cross-attention works better.
class InteractionModule_CrossAttentionEncoderLayer(nn.Module):
    """Transformer-style interaction module: cross-attention plus MLP, each with a residual."""

    def __init__(self, config):
        super().__init__()
        embed_dim = config.hidden_size
        self.cross_attn = InteractionModule_CrossAttention_FA2(config)

        # Separate norms for queries (probes) and visual features (keys/values)
        self.q_ln1 = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.kv_ln1 = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        self.q_ln2 = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.mlp = SigLIPVisionMLP(config)

    def forward(self, visual_probes, hidden_states, query_seqlens, ca_kv_seqlens, rotary_pos_emb):
        attended = self.cross_attn(
            self.q_ln1(visual_probes), self.kv_ln1(hidden_states), query_seqlens, ca_kv_seqlens, rotary_pos_emb
        )
        visual_probes = visual_probes + attended
        return visual_probes + self.mlp(self.q_ln2(visual_probes))
+
461
+
462
# Dispatch table: config.interaction_module -> interaction-module class.
INTERACTION_MODULES = {
    'cross_attention': InteractionModule_CrossAttention_FA2,
    'cross_attention_transformer': InteractionModule_CrossAttentionEncoderLayer
}
+
467
+
468
class ViSCoP_VisionTransformerEncoder(nn.Module):
    """Vision transformer stack with interleaved probe "interaction modules".

    Runs `num_hidden_layers` SigLIP-style self-attention layers over packed
    vision tokens. After each layer listed in `interaction_module_layers`
    (all layers when that config field is None), a set of learnable visual
    probes cross-attends to the current hidden states. `forward` returns
    both the final token features and the probe features.
    """

    def __init__(self, config: ViSCoP_VisionEncoderConfig):
        super().__init__()
        self.config = config
        head_dim = config.hidden_size // config.num_attention_heads
        # Rotary table covers half the head dim (angles are duplicated over
        # both rotated halves when applied).
        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
        self.layers = nn.ModuleList([SigLIPVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

        assert config.interaction_module in INTERACTION_MODULES, f'interaction module must be in: {INTERACTION_MODULES.keys()}'
        # None => one interaction module after every encoder layer.
        num_interaction_modules = config.num_hidden_layers if config.interaction_module_layers is None else len(config.interaction_module_layers)
        self.interaction_modules = nn.ModuleList([INTERACTION_MODULES[config.interaction_module](config) for _ in range(num_interaction_modules)]) # > Create interaction modules
        self.interaction_module_layers = [_ for _ in range(config.num_hidden_layers)] if config.interaction_module_layers is None else config.interaction_module_layers
        # Maps encoder-layer index -> index into self.interaction_modules.
        self.interaction_module_mapping = {k: v for v, k in enumerate(self.interaction_module_layers)}

        self.num_visual_probes = config.num_visual_probes
        self.visual_probes = nn.Parameter(torch.zeros(1, self.num_visual_probes, self.config.hidden_size)) # > Create the visual probes

    def rot_pos_emb(self, grid_sizes, merge_sizes):
        """Build packed rotary embeddings for every (t, h, w) grid.

        Positions are reordered per merge_size x merge_size block so they
        align with the later spatial merging of tokens.
        """
        pos_ids = []
        for (t, h, w), merge_size in zip(grid_sizes, merge_sizes):
            # Row indices, grouped into merge_size x merge_size blocks.
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // merge_size,
                merge_size,
                w // merge_size,
                merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            # Column indices, same block ordering.
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // merge_size,
                merge_size,
                w // merge_size,
                merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            # Same spatial layout is repeated for each of the t frames.
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))

        pos_ids = torch.cat(pos_ids, dim=0)
        # One shared frequency table sized by the largest spatial extent.
        max_grid_size = grid_sizes[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)

        return rotary_pos_emb

    def forward(self, hidden_states, grid_sizes, merge_sizes) -> torch.Tensor:
        """Encode packed tokens; returns (hidden_states, visual_probes).

        Args:
            hidden_states: packed token embeddings, one row per patch.
            grid_sizes: per-sample (t, h, w) grid dimensions (tensor, indexed
                with [:, k]).
            merge_sizes: per-sample spatial merge factor.
        """
        rotary_pos_emb = self.rot_pos_emb(grid_sizes, merge_sizes)

        # mask for self-attention (per-grid attention): each frame (h*w
        # tokens, repeated t times) is an independent attention segment.
        cu_seqlens = torch.repeat_interleave(grid_sizes[:, 1] * grid_sizes[:, 2], grid_sizes[:, 0]).cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        # > repeat probes: one probe set per sample, flattened to packed form.
        batch_size = grid_sizes.shape[0]
        visual_probes = self.visual_probes.repeat(batch_size, 1, 1)
        visual_probes = visual_probes.view(batch_size*self.num_visual_probes, self.config.hidden_size)

        # > mask for interaction module (cross-attend probes to all grids)
        visual_probe_seqlens = torch.tensor([self.num_visual_probes]*batch_size).cumsum(dim=0, dtype=torch.int32)
        visual_probe_seqlens = F.pad(visual_probe_seqlens, (1, 0), value=0).to(cu_seqlens.device)

        # Key/value segments span a sample's full t*h*w tokens, so probes see
        # every frame of their own sample.
        ca_kv_seqlens = (grid_sizes[:, 0] * (grid_sizes[:, 1] * grid_sizes[:, 2])).cumsum(dim=0, dtype=torch.int32)
        ca_kv_seqlens = F.pad(ca_kv_seqlens, (1, 0), value=0)

        for i, blk in enumerate(self.layers):
            if self.gradient_checkpointing and self.training:
                # _gradient_checkpointing_func is installed by the enclosing
                # PreTrainedModel when checkpointing is enabled.
                hidden_states = self._gradient_checkpointing_func(
                    blk.__call__,
                    hidden_states,
                    cu_seqlens,
                    rotary_pos_emb
                )

                if i in self.interaction_module_layers:
                    # Probes read the *post*-layer hidden states.
                    visual_probes = self._gradient_checkpointing_func(
                        self.interaction_modules[self.interaction_module_mapping[i]].__call__,
                        visual_probes,
                        hidden_states,
                        visual_probe_seqlens,
                        ca_kv_seqlens,
                        rotary_pos_emb
                    )
            else:
                hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)

                if i in self.interaction_module_layers:
                    visual_probes = self.interaction_modules[self.interaction_module_mapping[i]](
                        visual_probes,
                        hidden_states,
                        visual_probe_seqlens,
                        ca_kv_seqlens,
                        rotary_pos_emb
                    )

        return hidden_states, visual_probes
+
569
+
570
class ViSCoP_VisionEncoderModel(PreTrainedModel):
    """HF-compatible wrapper: patch embedding -> ViSCoP encoder -> post-norm and per-grid downsampling."""

    config_class = ViSCoP_VisionEncoderConfig
    base_model_prefix = "viscop"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _no_split_modules = [
        "SigLIPVisionEncoderLayer",
        "VisionEmbeddings",
    ]
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, config: ViSCoP_VisionEncoderConfig, **kwargs):
        super().__init__(config=config)
        embed_dim = config.hidden_size
        # NOTE(review): requires kwargs['model_cfg'] (KeyError otherwise);
        # probe/interaction settings are copied from it onto the config before
        # submodules are built.
        config.num_visual_probes = kwargs['model_cfg'].num_visual_probes
        config.interaction_module_layers = kwargs['model_cfg'].interaction_module_layers
        config.interaction_module = kwargs['model_cfg'].interaction_module

        self.embeddings = VisionEmbeddings(config)
        self.encoder = ViSCoP_VisionTransformerEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        # if config.interaction_module != 'cross_attention':
        #     self.post_layernorm_ca = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        self.post_init()

    def forward(self, pixel_values, grid_sizes, merge_sizes=None) -> torch.Tensor:
        """Encode pixels; returns (downsampled token features, visual_probes).

        Args:
            pixel_values: packed flattened patches fed to VisionEmbeddings.
            grid_sizes: per-sample (t, h, w) grid dimensions.
            merge_sizes: per-sample spatial downsampling factor.
                NOTE(review): despite the None default, the zip below
                iterates merge_sizes, so None would raise — callers appear
                expected to always pass it.
        """
        hidden_states = self.embeddings(pixel_values)
        hidden_states, visual_probes = self.encoder(hidden_states, grid_sizes, merge_sizes)
        hidden_states = self.post_layernorm(hidden_states)
        # if self.config.interaction_module != 'cross_attention':
        #     visual_probes = self.post_layernorm_ca(visual_probes)

        # Split the packed tokens back into one chunk per sample (t*h*w rows).
        hidden_states_chunks = hidden_states.split(grid_sizes.prod(dim=1).tolist(), dim=0)
        outputs = []

        # NOTE: the loop variable intentionally shadows `hidden_states`.
        for hidden_states, grid_size, merge_size in zip(hidden_states_chunks, grid_sizes, merge_sizes):
            # NOTE: previous implementation, which supports downsampling with any factor
            c = hidden_states.shape[-1]
            # Undo the merge-block token ordering to recover the (t, h, w)
            # spatial layout, then channels-first for interpolate.
            hidden_states = hidden_states.view(
                grid_size[0], grid_size[1] // merge_size, grid_size[2] // merge_size, merge_size, merge_size, c
            ).permute(0, 1, 3, 2, 4, 5)
            hidden_states = hidden_states.reshape(
                grid_size[0], grid_size[1], grid_size[2], c
            ).permute(0, 3, 1, 2)
            # Spatial downsampling by merge_size via bilinear interpolation.
            hidden_states = torch.nn.functional.interpolate(
                hidden_states,
                size=(grid_size[1] // merge_size, grid_size[2] // merge_size),
                mode='bilinear'
            )
            hidden_states = hidden_states.permute(0, 2, 3, 1).view(-1, c)

            # NOTE: simplified implementation, which only supports downsampling with integer factor
            # NOTE: this implementation is mathematically equivalent to the previous one when merge_size is 1 or 2 but may cause slightly different results
            # hidden_states = hidden_states.view(-1, merge_size * merge_size, hidden_states.size(-1))
            # hidden_states = hidden_states.mean(dim=1)

            outputs.append(hidden_states)

        return torch.cat(outputs, dim=0), visual_probes

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Embedding):
            default_flax_embed_init(module.weight)
        elif isinstance(module, VisionAttention):
            # Interaction modules keep their default init; re-initializing
            # them here breaks deepspeed-zero3 (per the original comment).
            if isinstance(module, InteractionModule_CrossAttention_FA2): # for deepspeed-zero3
                return
            nn.init.xavier_uniform_(module.q_proj.weight)
            nn.init.xavier_uniform_(module.k_proj.weight)
            nn.init.xavier_uniform_(module.v_proj.weight)
            nn.init.xavier_uniform_(module.out_proj.weight)
            nn.init.zeros_(module.q_proj.bias)
            nn.init.zeros_(module.k_proj.bias)
            nn.init.zeros_(module.v_proj.bias)
            nn.init.zeros_(module.out_proj.bias)
        elif isinstance(module, SigLIPVisionMLP):
            nn.init.xavier_uniform_(module.fc1.weight)
            nn.init.xavier_uniform_(module.fc2.weight)
            nn.init.normal_(module.fc1.bias, std=1e-6)
            nn.init.normal_(module.fc2.bias, std=1e-6)
        elif isinstance(module, (nn.Linear, nn.Conv2d)):
            lecun_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, ViSCoP_VisionTransformerEncoder):
            # Learnable visual probes get a small random init.
            nn.init.normal_(module.visual_probes, std=.02)
preprocessor_config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoImageProcessor": "image_processing_viscop.ViSCoP_ImageProcessor"
4
+ },
5
+ "do_convert_rgb": null,
6
+ "do_normalize": true,
7
+ "do_rescale": true,
8
+ "do_resize": true,
9
+ "image_mean": [
10
+ 0.5,
11
+ 0.5,
12
+ 0.5
13
+ ],
14
+ "image_processor_type": "ViSCoP_ImageProcessor",
15
+ "image_std": [
16
+ 0.5,
17
+ 0.5,
18
+ 0.5
19
+ ],
20
+ "max_tokens": 16384,
21
+ "min_tokens": 16,
22
+ "patch_size": 14,
23
+ "resample": 3,
24
+ "rescale_factor": 0.00392156862745098
25
+ }