wtzhang-nlp committed
Commit 10a59d1 · verified · 1 Parent(s): d5ef13a

Upload SigLIP2 NaViT model with Google checkpoint

README.md ADDED
@@ -0,0 +1,60 @@
+ ---
+ library_name: transformers
+ license: apache-2.0
+ tags:
+ - vision
+ - image-feature-extraction
+ - siglip
+ - navit
+ - google
+ pipeline_tag: image-feature-extraction
+ ---
+
+ # SigLIP2 NaViT Vision Encoder (with Google Pretrained Weights)
+
+ This is a SigLIP2 NaViT (Native Resolution Vision Transformer) vision encoder initialized from Google's pretrained SigLIP2 checkpoint: the vision encoder weights come from Google's checkpoint, while the merger layer is randomly initialized.
+
+ ## Model Details
+
+ - **Model Type:** Vision Encoder
+ - **Architecture:** SigLIP2 with Native Resolution ViT
+ - **Base Checkpoint:** Google SigLIP2
+ - **Precision:** FP16 (float16) for reduced storage
+ - **Hidden Size:** 768
+ - **Number of Layers:** 12
+ - **Number of Attention Heads:** 12
+ - **Patch Size:** 16
+ - **Spatial Merge Size:** 2
+ - **Output Hidden Size:** 896
+
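+ As a quick sanity check of the geometry above (an illustrative calculation, not code shipped in this repo): with a patch size of 16 and a spatial merge size of 2, every 32×32 pixel block of a resized image becomes exactly one output token.
+
+ ```python
+ patch_size, merge_size = 16, 2
+ height, width = 384, 512                                     # multiples of 32
+ grid_h, grid_w = height // patch_size, width // patch_size   # 24 x 32
+ num_patches = grid_h * grid_w                                # 768 patches (dim 768)
+ num_tokens = num_patches // merge_size**2                    # 192 tokens (dim 896)
+ ```
+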
+ ## Initialization
+
+ - **Vision Encoder:** Initialized from Google's SigLIP2 pretrained checkpoint
+ - **Vision Merger:** Randomly initialized (ready for fine-tuning)
+
+ ## Usage
+
+ ```python
+ from transformers import AutoModel, AutoImageProcessor
+ from PIL import Image
+ import torch
+
+ # Load model and processor
+ model = AutoModel.from_pretrained("wtzhang-nlp/siglip2-navit-google", trust_remote_code=True)
+ processor = AutoImageProcessor.from_pretrained("wtzhang-nlp/siglip2-navit-google", trust_remote_code=True)
+
+ # Load and process image
+ image = Image.open("path/to/image.jpg")
+ inputs = processor(images=image, return_tensors="pt")  # pixel_values, grid_sizes, merge_sizes
+
+ # Forward pass: the model returns a plain tensor of merged patch features
+ with torch.no_grad():
+     outputs = model(**inputs)
+
+ print(f"Output shape: {outputs.shape}")
+ # Expected: [num_merged_patches, 896]; patches from all images in the batch are
+ # flattened into one sequence. Use grid_sizes and merge_sizes to split per image.
+ ```
+
+ ## License
+
+ Apache 2.0
config.json ADDED
@@ -0,0 +1,23 @@
+ {
+   "architectures": [
+     "Siglip2NaViTVisionModel"
+   ],
+   "attention_dropout": 0.0,
+   "auto_map": {
+     "AutoConfig": "configuration_siglip2_navit.Siglip2NaViTVisionConfig",
+     "AutoModel": "modeling_siglip2_navit.Siglip2NaViTVisionModel"
+   },
+   "dtype": "float16",
+   "hidden_act": "gelu_pytorch_tanh",
+   "hidden_size": 768,
+   "intermediate_size": 3072,
+   "layer_norm_eps": 1e-06,
+   "model_type": "siglip2_navit",
+   "num_attention_heads": 12,
+   "num_channels": 3,
+   "num_hidden_layers": 12,
+   "out_hidden_size": 896,
+   "patch_size": 16,
+   "spatial_merge_size": 2,
+   "transformers_version": "4.57.3"
+ }
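
The `auto_map` entries above wire the Auto classes to this repo's remote code. A minimal check (my own sketch; assumes network access, and `trust_remote_code=True` is required):

```python
from transformers import AutoConfig

# auto_map routes this to configuration_siglip2_navit.Siglip2NaViTVisionConfig
config = AutoConfig.from_pretrained("wtzhang-nlp/siglip2-navit-google", trust_remote_code=True)
print(config.model_type, config.hidden_size, config.out_hidden_size)
# siglip2_navit 768 896
```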
configuration_siglip2_navit.py ADDED
@@ -0,0 +1,53 @@
+ # Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/siglip/configuration_siglip.py.
+ # Below is the original copyright:
+ # coding=utf-8
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Siglip2NaViT vision encoder model configuration."""
+
+ from transformers import PretrainedConfig
+
+
+ class Siglip2NaViTVisionConfig(PretrainedConfig):
+     """Configuration for the Siglip2NaViT vision encoder (defaults match config.json)."""
+
+     model_type = "siglip2_navit"
+
+     def __init__(
+         self,
+         hidden_size=768,
+         intermediate_size=3072,
+         num_hidden_layers=12,
+         num_attention_heads=12,
+         num_channels=3,
+         patch_size=16,
+         hidden_act="gelu_pytorch_tanh",
+         layer_norm_eps=1e-6,
+         attention_dropout=0.0,
+         out_hidden_size=896,
+         spatial_merge_size=2,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.num_channels = num_channels
+         self.patch_size = patch_size
+         self.attention_dropout = attention_dropout
+         self.layer_norm_eps = layer_norm_eps
+         self.hidden_act = hidden_act
+         self.out_hidden_size = out_hidden_size
+         self.spatial_merge_size = spatial_merge_size
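
A sketch of how this config is used (illustrative, assuming the file is importable from a local clone): the defaults reproduce config.json, and `hidden_size` must stay divisible by `num_attention_heads` because the encoder derives its rotary dimension from the per-head size.

```python
from configuration_siglip2_navit import Siglip2NaViTVisionConfig

config = Siglip2NaViTVisionConfig()  # defaults match config.json
assert config.hidden_size % config.num_attention_heads == 0
head_dim = config.hidden_size // config.num_attention_heads
print(head_dim, config.out_hidden_size)  # 64 896
```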
image_processor_siglip2_navit.py ADDED
@@ -0,0 +1,506 @@
+ # Adapted from https://huggingface.co/DAMO-NLP-SG/SigLIP-NaViT/raw/main/image_processing_videollama3.py
+ # Below is the original copyright:
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+ #
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+ # and OPT implementations in this library. It has been modified from its
+ # original forms to accommodate minor architectural differences compared
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Image processor class for Siglip2NaViT."""
+
+ import math
+ from typing import List, Optional, Union
+
+ import numpy as np
+
+ import torch
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
+ from transformers.image_transforms import (
+     convert_to_rgb,
+     resize,
+     to_channel_dimension_format,
+ )
+ from transformers.image_utils import (
+     OPENAI_CLIP_MEAN,
+     OPENAI_CLIP_STD,
+     ChannelDimension,
+     ImageInput,
+     PILImageResampling,
+     get_image_size,
+     infer_channel_dimension_format,
+     is_scaled_image,
+     is_valid_image,
+     make_list_of_images,
+     to_numpy_array,
+ )
+
+ # VideoInput moved between modules across transformers versions.
+ try:
+     from transformers.image_utils import VideoInput
+ except ImportError:
+     from transformers.video_utils import VideoInput
+ from transformers.utils import TensorType, is_vision_available, logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ if is_vision_available():
+     from PIL import Image
+
+
+ def is_valid_video(video) -> bool:
+     if isinstance(video, (list, tuple)):
+         return all(is_valid_image(frame) for frame in video)
+     elif isinstance(video, np.ndarray):
+         return video.ndim == 4
+     elif isinstance(video, torch.Tensor):
+         return video.ndim == 4
+     return False
+
+
+ def make_batched_images(images) -> List[List[ImageInput]]:
+     """
+     Accepts images in list or nested list format, and makes a list of images for preprocessing.
+
+     Args:
+         images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
+             The input image(s).
+
+     Returns:
+         list: A list of images.
+     """
+     if isinstance(images, (list, tuple)):
+         # list of images/videos
+         if not all(is_valid_video(image) or is_valid_image(image) for image in images):
+             raise ValueError(f"Could not make batched images from {images}")
+         return images
+     elif is_valid_video(images) or is_valid_image(images):
+         # single image/video
+         return [images]
+
+     raise ValueError(f"Could not make batched images from {images}")
+
+
+ def simple_batched_resize(
+     images,
+     factor: int,
+     min_tokens: int = 32,
+     max_tokens: int = 16384,
+     input_data_format: str = None,
+ ):
+     min_pixels = min_tokens * factor * factor
+     max_pixels = max_tokens * factor * factor
+
+     num_images = 0
+     for image in images:
+         if is_valid_video(image):
+             num_images += len(image)
+         else:
+             num_images += 1
+
+     image_sizes = []
+     for image in images:
+         if is_valid_video(image):
+             image = image[0]
+         if isinstance(image, Image.Image):
+             width, height = image.size
+         else:
+             height, width = get_image_size(image, channel_dim=input_data_format)
+         image_sizes.append([height, width])
+
+     tmp_image_sizes = []
+     for height, width in image_sizes:
+         # Snap to the nearest multiple of factor, then rescale if the per-image
+         # share of the pixel budget (max_pixels split across images) is exceeded.
+         h_bar = round(height / factor) * factor
+         w_bar = round(width / factor) * factor
+         if h_bar * w_bar > (max_pixels // num_images):
+             beta = math.sqrt((height * width) / (max_pixels // num_images))
+             h_bar = math.floor(height / beta / factor) * factor
+             w_bar = math.floor(width / beta / factor) * factor
+         # per image min_pixels
+         if h_bar * w_bar < min_pixels:
+             beta = math.sqrt(min_pixels / (height * width))
+             h_bar = math.ceil(height * beta / factor) * factor
+             w_bar = math.ceil(width * beta / factor) * factor
+         tmp_image_sizes.append((h_bar, w_bar))
+     image_sizes = tmp_image_sizes
+     return image_sizes
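+
+ # Worked example (added for illustration): with factor = patch_size * merge_size
+ # = 32, min_tokens = 32 and max_tokens = 196, a single 1024x768 image snaps to
+ # 1024x768 (already multiples of 32), i.e. 768 tokens, which exceeds the budget
+ # of 196; beta = sqrt(768 / 196) ~= 1.98, and the image is rescaled to 512x384,
+ # giving (512 // 32) * (384 // 32) = 16 * 12 = 192 tokens.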
+
+
+ def batched_resize(
+     images,
+     factors: List[int],
+     min_tokens: int = 32,
+     max_tokens: int = 16384,
+     input_data_format: str = None,
+ ):
+     image_sizes = []
+     for image in images:
+         if is_valid_video(image):
+             num_frame = len(image)
+             image = image[0]
+         else:
+             num_frame = 1
+         if isinstance(image, Image.Image):
+             width, height = image.size
+         else:
+             height, width = get_image_size(image, channel_dim=input_data_format)
+         image_sizes.append([num_frame, height, width])
+
+     # global max_pixels
+     total_tokens = 0
+     for (num_frame, height, width), factor in zip(image_sizes, factors):
+         total_tokens += (
+             num_frame * math.ceil(height / factor) * math.ceil(width / factor)
+         )
+
+     # TODO: add min_pixels
+     if total_tokens > max_tokens:
+         beta = math.sqrt(total_tokens / max_tokens)
+         tmp_image_sizes = []
+         for (_, height, width), factor in zip(image_sizes, factors):
+             h_bar = math.floor(height / beta / factor) * factor
+             w_bar = math.floor(width / beta / factor) * factor
+             tmp_image_sizes.append((h_bar, w_bar))
+         image_sizes = tmp_image_sizes
+     else:
+         tmp_image_sizes = []
+         for (_, height, width), factor in zip(image_sizes, factors):
+             height = round(height / factor) * factor
+             width = round(width / factor) * factor
+             tmp_image_sizes.append((height, width))
+         image_sizes = tmp_image_sizes
+
+     return image_sizes
+
+
+ class SigLIP2NaViTImageProcessor(BaseImageProcessor):
+     r"""
+     Constructs a Siglip2NaViT image processor that dynamically resizes images based on their original resolution.
+
+     Args:
+         do_resize (`bool`, *optional*, defaults to `True`):
+             Whether to resize the image's (height, width) dimensions.
+         resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
+             Resampling filter to use when resizing the image.
+         do_rescale (`bool`, *optional*, defaults to `True`):
+             Whether to rescale the image by the specified scale `rescale_factor`.
+         rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+             Scale factor to use if rescaling the image.
+         do_normalize (`bool`, *optional*, defaults to `True`):
+             Whether to normalize the image.
+         image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
+             Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
+         image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
+             Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
+         do_convert_rgb (`bool`, *optional*, defaults to `True`):
+             Whether to convert the image to RGB.
+         min_tokens (`int`, *optional*, defaults to 32):
+             The minimum number of vision tokens an image is resized to produce.
+         max_tokens (`int`, *optional*, defaults to 196):
+             The maximum number of vision tokens the batch is resized to produce.
+         patch_size (`int`, *optional*, defaults to 16):
+             The spatial patch size of the vision encoder.
+         merge_size (`int`, *optional*, defaults to 2):
+             The spatial merge size applied after the vision encoder.
+     """
+
+     model_input_names = ["pixel_values", "grid_sizes", "merge_sizes"]
+
+     def __init__(
+         self,
+         do_resize: bool = True,
+         resample: PILImageResampling = PILImageResampling.BICUBIC,
+         do_rescale: bool = True,
+         rescale_factor: Union[int, float] = 1 / 255,
+         do_normalize: bool = True,
+         image_mean: Optional[Union[float, List[float]]] = None,
+         image_std: Optional[Union[float, List[float]]] = None,
+         do_convert_rgb: bool = True,
+         min_tokens: int = 32,
+         max_tokens: int = 196,
+         patch_size: int = 16,
+         merge_size: int = 2,
+         **kwargs,
+     ) -> None:
+         super().__init__(**kwargs)
+         self.do_resize = do_resize
+         self.resample = resample
+         self.do_rescale = do_rescale
+         self.rescale_factor = rescale_factor
+         self.do_normalize = do_normalize
+         self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
+         self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
+         self.min_tokens = min_tokens
+         self.max_tokens = max_tokens
+         self.patch_size = patch_size
+         self.do_convert_rgb = do_convert_rgb
+         self.merge_size = merge_size
+
+     def _preprocess(
+         self,
+         images: Union[ImageInput, VideoInput],
+         target_size: List[int],
+         do_resize: bool = None,
+         resample: PILImageResampling = None,
+         do_rescale: bool = None,
+         rescale_factor: float = None,
+         do_normalize: bool = None,
+         image_mean: Optional[Union[float, List[float]]] = None,
+         image_std: Optional[Union[float, List[float]]] = None,
+         do_convert_rgb: bool = None,
+         data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
+         input_data_format: Optional[Union[str, ChannelDimension]] = None,
+     ):
+         """
+         Preprocess an image or batch of images. Adapted from the `preprocess` method of `CLIPImageProcessor`.
+
+         Args:
+             images (`ImageInput`):
+                 Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
+             target_size (`List[int]`):
+                 The target size to resize the image to. Should be a list of two integers: [target_height, target_width].
+             do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+                 Whether to resize the image.
+             resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+                 Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
+             do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+                 Whether to rescale the image.
+             rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+                 Scale factor to use if rescaling the image.
+             do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+                 Whether to normalize the image.
+             image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+                 Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
+             image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+                 Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
+             do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+                 Whether to convert the image to RGB.
+             data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
+                 The channel dimension format for the output image. Can be one of:
+                 - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                 - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                 - Unset: Use the channel dimension format of the input image.
+             input_data_format (`ChannelDimension` or `str`, *optional*):
+                 The channel dimension format for the input image. Can be one of:
+                 - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                 - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                 - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+         """
+         images = make_list_of_images(images)
+
+         if do_convert_rgb:
+             images = [convert_to_rgb(image) for image in images]
+
+         # All transformations expect numpy arrays.
+         images = [to_numpy_array(image) for image in images]
+
+         if is_scaled_image(images[0]) and do_rescale:
+             logger.warning_once(
+                 "It looks like you are trying to rescale already rescaled images. If the input"
+                 " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+             )
+         if input_data_format is None:
+             # We assume that all images have the same channel dimension format.
+             input_data_format = infer_channel_dimension_format(images[0])
+
+         height, width = get_image_size(images[0], channel_dim=input_data_format)
+         resized_height, resized_width = height, width
+         processed_images = []
+         for image in images:
+             if do_resize:
+                 resized_height, resized_width = target_size
+                 image = resize(
+                     image,
+                     size=(resized_height, resized_width),
+                     resample=resample,
+                     input_data_format=input_data_format,
+                 )
+
+             if do_rescale:
+                 image = self.rescale(
+                     image, scale=rescale_factor, input_data_format=input_data_format
+                 )
+
+             if do_normalize:
+                 image = self.normalize(
+                     image=image,
+                     mean=image_mean,
+                     std=image_std,
+                     input_data_format=input_data_format,
+                 )
+
+             image = to_channel_dimension_format(
+                 image, data_format, input_channel_dim=input_data_format
+             )
+             processed_images.append(image)
+
+         patches = np.array(processed_images)
+         if data_format == ChannelDimension.LAST:
+             patches = patches.transpose(0, 3, 1, 2)
+         t = patches.shape[0]
+         channel = patches.shape[1]
+         grid_h, grid_w = (
+             resized_height // self.patch_size,
+             resized_width // self.patch_size,
+         )
+         # Split each frame into merge_size x merge_size blocks of patches, then
+         # reorder so the patches of one merge block are contiguous in the sequence.
+         # Note: this uses self.merge_size, not a per-image merge size.
+         patches = patches.reshape(
+             t,
+             channel,
+             grid_h // self.merge_size,
+             self.merge_size,
+             self.patch_size,
+             grid_w // self.merge_size,
+             self.merge_size,
+             self.patch_size,
+         )
+         patches = patches.transpose(0, 2, 5, 3, 6, 1, 4, 7)
+         flatten_patches = patches.reshape(
+             t * grid_h * grid_w, channel * self.patch_size * self.patch_size
+         )
+
+         return flatten_patches, (t, grid_h, grid_w)
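+
+     # Added note: rows of flatten_patches are ordered merge-block first, so the
+     # merge_size x merge_size patches that the vision merger later fuses into a
+     # single token are adjacent in the flattened sequence.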
+
+     def preprocess(
+         self,
+         images: ImageInput,
+         do_resize: bool = None,
+         resample: PILImageResampling = None,
+         do_rescale: bool = None,
+         rescale_factor: float = None,
+         do_normalize: bool = None,
+         image_mean: Optional[Union[float, List[float]]] = None,
+         image_std: Optional[Union[float, List[float]]] = None,
+         do_convert_rgb: bool = None,
+         merge_size: Optional[Union[int, List[int]]] = None,
+         return_tensors: Optional[Union[str, TensorType]] = None,
+         data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
+         input_data_format: Optional[Union[str, ChannelDimension]] = None,
+     ):
+         """
+         Args:
+             images (`ImageInput`):
+                 Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+                 passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+             do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+                 Whether to resize the image.
+             resample (`int`, *optional*, defaults to `self.resample`):
+                 Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
+                 has an effect if `do_resize` is set to `True`.
+             do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+                 Whether to rescale the image.
+             rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+                 Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+             do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+                 Whether to normalize the image.
+             image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+                 Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
+             image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+                 Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
+                 `True`.
+             do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+                 Whether to convert the image to RGB.
+             merge_size (`int` or `List[int]`, *optional*, defaults to `self.merge_size`):
+                 The spatial merge size applied after the vision encoder, either shared or one value per image.
+             return_tensors (`str` or `TensorType`, *optional*):
+                 The type of tensors to return. Can be one of:
+                 - Unset: Return a list of `np.ndarray`.
+                 - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+                 - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+                 - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+                 - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+             data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+                 The channel dimension format for the output image. Can be one of:
+                 - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                 - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                 - Unset: Use the channel dimension format of the input image.
+             input_data_format (`ChannelDimension` or `str`, *optional*):
+                 The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                 from the input image. Can be one of:
+                 - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                 - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                 - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+         """
+         do_resize = do_resize if do_resize is not None else self.do_resize
+         resample = resample if resample is not None else self.resample
+         do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+         rescale_factor = (
+             rescale_factor if rescale_factor is not None else self.rescale_factor
+         )
+         do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+         image_mean = image_mean if image_mean is not None else self.image_mean
+         image_std = image_std if image_std is not None else self.image_std
+         merge_size = merge_size if merge_size is not None else self.merge_size
+         do_convert_rgb = (
+             do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
+         )
+
+         images = make_batched_images(images)
+
+         if isinstance(merge_size, (list, tuple)):
+             assert len(merge_size) == len(images), "Merge size must be the same length as images."
+             merge_sizes = merge_size
+         else:
+             merge_sizes = [merge_size for _ in images]
+
+         if all(merge_size == merge_sizes[0] for merge_size in merge_sizes):
+             target_sizes = simple_batched_resize(
+                 images,
+                 factor=self.patch_size * merge_sizes[0],
+                 min_tokens=self.min_tokens,
+                 max_tokens=self.max_tokens,
+                 input_data_format=input_data_format,
+             )
+         else:
+             target_sizes = batched_resize(
+                 images,
+                 factors=[self.patch_size * merge_size for merge_size in merge_sizes],
+                 min_tokens=self.min_tokens,
+                 max_tokens=self.max_tokens,
+                 input_data_format=input_data_format,
+             )
+
+         pixel_values, grid_sizes = [], []
+         for image, merge_size, target_size in zip(images, merge_sizes, target_sizes):
+             patches, grid_size = self._preprocess(
+                 image,
+                 target_size=target_size,
+                 do_resize=do_resize,
+                 resample=resample,
+                 do_rescale=do_rescale,
+                 rescale_factor=rescale_factor,
+                 do_normalize=do_normalize,
+                 image_mean=image_mean,
+                 image_std=image_std,
+                 data_format=data_format,
+                 do_convert_rgb=do_convert_rgb,
+                 input_data_format=input_data_format,
+             )
+             pixel_values.append(patches)
+             grid_sizes.append(grid_size)
+
+         pixel_values = np.concatenate(pixel_values, axis=0)
+         grid_sizes = np.array(grid_sizes)
+         merge_sizes = np.array(merge_sizes)
+
+         data = {
+             "pixel_values": pixel_values,
+             "grid_sizes": grid_sizes,
+             "merge_sizes": merge_sizes,
+         }
+
+         return BatchFeature(data=data, tensor_type=return_tensors)
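
To make the processor's output contract concrete, here is a small smoke test (my own sketch against the code above, run from a local clone with this file on the Python path; shapes assume the defaults `max_tokens=196`, `patch_size=16`, `merge_size=2`):

```python
import numpy as np
from PIL import Image

from image_processor_siglip2_navit import SigLIP2NaViTImageProcessor

processor = SigLIP2NaViTImageProcessor()
# 1024x768 RGB noise image (PIL size is width x height)
image = Image.fromarray(np.random.randint(0, 256, (768, 1024, 3), dtype=np.uint8))

out = processor(images=image, return_tensors="pt")
print(out["pixel_values"].shape)  # torch.Size([768, 768]): 24*32 patches, 3*16*16 values each
print(out["grid_sizes"])          # tensor([[ 1, 24, 32]]): (frames, grid_h, grid_w)
print(out["merge_sizes"])         # tensor([2])
```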
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:be27e95b71ce82d51a8e294a628e0926cc2f625addf9fe7c9a8ec835573821ea
+ size 195704976
modeling_siglip2_navit.py ADDED
@@ -0,0 +1,575 @@
+ # Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py.
+ # Below is the original copyright:
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+ #
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+ # and OPT implementations in this library. It has been modified from its
+ # original forms to accommodate minor architectural differences compared
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """PyTorch Siglip2NaViT vision encoder model."""
+
+ import math
+ import warnings
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn.init import _calculate_fan_in_and_fan_out
+
+ from transformers.activations import ACT2FN
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.utils import is_flash_attn_2_available
+
+ if is_flash_attn_2_available():
+     from flash_attn import flash_attn_varlen_func
+ else:
+     flash_attn_varlen_func = None
+
+ from .configuration_siglip2_navit import (
+     Siglip2NaViTVisionConfig,
+ )
+
+
+ def _trunc_normal_(tensor, mean, std, a, b):
+     # Cut & paste from PyTorch official master until it's in a few official releases - RW
+     # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+     def norm_cdf(x):
+         # Computes standard normal cumulative distribution function
+         return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
+
+     if (mean < a - 2 * std) or (mean > b + 2 * std):
+         warnings.warn(
+             "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+             "The distribution of values may be incorrect.",
+             stacklevel=2,
+         )
+
+     # Values are generated by using a truncated uniform distribution and
+     # then using the inverse CDF for the normal distribution.
+     # Get upper and lower cdf values
+     l = norm_cdf((a - mean) / std)
+     u = norm_cdf((b - mean) / std)
+
+     # Uniformly fill tensor with values from [l, u], then translate to
+     # [2l-1, 2u-1].
+     tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+     # Use inverse cdf transform for normal distribution to get truncated
+     # standard normal
+     tensor.erfinv_()
+
+     # Transform to proper mean, std
+     tensor.mul_(std * math.sqrt(2.0))
+     tensor.add_(mean)
+
+     # Clamp to ensure it's in the proper range
+     tensor.clamp_(min=a, max=b)
+
+
+ def trunc_normal_tf_(
+     tensor: torch.Tensor,
+     mean: float = 0.0,
+     std: float = 1.0,
+     a: float = -2.0,
+     b: float = 2.0,
+ ) -> torch.Tensor:
+     r"""Fills the input Tensor with values drawn from a truncated
+     normal distribution. The values are effectively drawn from the
+     normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+     with values outside :math:`[a, b]` redrawn until they are within
+     the bounds. The method used for generating the random values works
+     best when :math:`a \leq \text{mean} \leq b`.
+
+     NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
+     bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
+     and the result is subsequently scaled and shifted by the mean and std args.
+
+     Args:
+         tensor: an n-dimensional `torch.Tensor`
+         mean: the mean of the normal distribution
+         std: the standard deviation of the normal distribution
+         a: the minimum cutoff value
+         b: the maximum cutoff value
+     """
+     with torch.no_grad():
+         _trunc_normal_(tensor, 0, 1.0, a, b)
+         tensor.mul_(std).add_(mean)
+
+
+ def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
+     fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
+     if mode == "fan_in":
+         denom = fan_in
+     elif mode == "fan_out":
+         denom = fan_out
+     elif mode == "fan_avg":
+         denom = (fan_in + fan_out) / 2
+     else:
+         raise ValueError(f"invalid mode {mode}")
+
+     variance = scale / denom
+
+     if distribution == "truncated_normal":
+         # constant is stddev of standard normal truncated to (-2, 2)
+         trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
+     elif distribution == "normal":
+         with torch.no_grad():
+             tensor.normal_(std=math.sqrt(variance))
+     elif distribution == "uniform":
+         bound = math.sqrt(3 * variance)
+         with torch.no_grad():
+             tensor.uniform_(-bound, bound)
+     else:
+         raise ValueError(f"invalid distribution {distribution}")
+
+
+ def lecun_normal_(tensor):
+     variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
+
+
+ def default_flax_embed_init(tensor):
+     variance_scaling_(tensor, mode="fan_in", distribution="normal")
+
+
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
+ def rotate_half(x):
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_rotary_pos_emb_vision(
+     tensor: torch.Tensor, freqs: torch.Tensor
+ ) -> torch.Tensor:
+     orig_dtype = tensor.dtype
+     tensor = tensor.float()
+     cos = freqs.cos()
+     sin = freqs.sin()
+     cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
+     sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
+     output = (tensor * cos) + (rotate_half(tensor) * sin)
+     output = output.to(orig_dtype)
+     return output
+
+
+ class VisionRotaryEmbedding(nn.Module):
+
+     def __init__(self, dim: int, theta: float = 10000.0) -> None:
+         super().__init__()
+         inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+         self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+     def forward(self, seqlen: int) -> torch.Tensor:
+         seq = torch.arange(
+             seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
+         )
+         freqs = torch.outer(seq, self.inv_freq)
+         return freqs
+
+
+ class Siglip2NaViTVisionEmbeddings(nn.Module):
+
+     def __init__(self, config: Siglip2NaViTVisionConfig):
+         super().__init__()
+         self.config = config
+         self.embed_dim = config.hidden_size
+         self.patch_size = config.patch_size
+
+         self.patch_embedding = nn.Conv2d(
+             in_channels=config.num_channels,
+             out_channels=self.embed_dim,
+             kernel_size=self.patch_size,
+             stride=self.patch_size,
+             padding="valid",
+         )
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         # pixel_values arrive flattened as (num_patches, C * P * P); fold each row
+         # back into a C x P x P patch so the conv acts as a per-patch projection.
+         hidden_states = hidden_states.view(
+             -1, self.config.num_channels, self.patch_size, self.patch_size
+         )
+         patch_embeds = self.patch_embedding(hidden_states)
+         embeddings = patch_embeds.view(-1, self.embed_dim)
+
+         return embeddings
+
+
+ class VisionAttention(nn.Module):
+     """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+     # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.embed_dim = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.head_dim = self.embed_dim // self.num_heads
+         if self.head_dim * self.num_heads != self.embed_dim:
+             raise ValueError(
+                 f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+                 f" {self.num_heads})."
+             )
+         self.scale = self.head_dim**-0.5
+         self.dropout = config.attention_dropout
+
+         self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
+         self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
+         self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
+         self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         cu_seqlens: torch.Tensor,
+         rotary_pos_emb: torch.Tensor = None,
+     ) -> torch.Tensor:
+         """Input shape: Time x Channel"""
+
+         q_len, _ = hidden_states.size()
+
+         query_states = self.q_proj(hidden_states)
+         key_states = self.k_proj(hidden_states)
+         value_states = self.v_proj(hidden_states)
+
+         query_states = query_states.view(q_len, self.num_heads, self.head_dim)
+         key_states = key_states.view(q_len, self.num_heads, self.head_dim)
+         value_states = value_states.view(q_len, self.num_heads, self.head_dim)
+
+         query_states = apply_rotary_pos_emb_vision(
+             query_states.unsqueeze(0), rotary_pos_emb
+         ).squeeze(0)
+         key_states = apply_rotary_pos_emb_vision(
+             key_states.unsqueeze(0), rotary_pos_emb
+         ).squeeze(0)
+
+         # Additive attention mask: 0 inside each image's block of tokens and a
+         # large negative value elsewhere, so attention never crosses image
+         # boundaries. (A boolean mask cannot simply be added to the logits.)
+         attention_mask = torch.full(
+             [1, q_len, q_len],
+             torch.finfo(query_states.dtype).min,
+             device=query_states.device,
+             dtype=query_states.dtype,
+         )
+         for i in range(1, len(cu_seqlens)):
+             attention_mask[
+                 ...,
+                 cu_seqlens[i - 1] : cu_seqlens[i],
+                 cu_seqlens[i - 1] : cu_seqlens[i],
+             ] = 0
+
+         query_states = query_states.transpose(0, 1)
+         key_states = key_states.transpose(0, 1)
+         value_states = value_states.transpose(0, 1)
+
+         attn_weights = torch.matmul(
+             query_states, key_states.transpose(1, 2)
+         ) / math.sqrt(self.head_dim)
+         attn_weights = attn_weights + attention_mask
+
+         # upcast attention to fp32
+         attn_weights = nn.functional.softmax(
+             attn_weights, dim=-1, dtype=torch.float32
+         ).to(query_states.dtype)
+         attn_weights = nn.functional.dropout(
+             attn_weights, p=self.dropout, training=self.training
+         )
+         attn_output = torch.matmul(attn_weights, value_states)
+
+         attn_output = attn_output.transpose(0, 1)
+         attn_output = attn_output.reshape(q_len, -1)
+         attn_output = self.out_proj(attn_output)
+
+         return attn_output
+
+
+ class VisionFlashAttention2(VisionAttention):
+
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+     # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         cu_seqlens: torch.Tensor,
+         rotary_pos_emb: torch.Tensor = None,
+     ) -> torch.Tensor:
+         q_len, _ = hidden_states.size()
+
+         query_states = self.q_proj(hidden_states)
+         key_states = self.k_proj(hidden_states)
+         value_states = self.v_proj(hidden_states)
+
+         # flash_attn_varlen_func consumes packed (total_tokens, num_heads, head_dim)
+         # tensors plus cu_seqlens boundaries, so no padding or masking is needed.
+         query_states = query_states.view(q_len, self.num_heads, self.head_dim)
+         key_states = key_states.view(q_len, self.num_heads, self.head_dim)
+         value_states = value_states.view(q_len, self.num_heads, self.head_dim)
+         query_states = apply_rotary_pos_emb_vision(
+             query_states.unsqueeze(0), rotary_pos_emb
+         ).squeeze(0)
+         key_states = apply_rotary_pos_emb_vision(
+             key_states.unsqueeze(0), rotary_pos_emb
+         ).squeeze(0)
+
+         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+         attn_output = flash_attn_varlen_func(
+             query_states,
+             key_states,
+             value_states,
+             cu_seqlens,
+             cu_seqlens,
+             max_seqlen,
+             max_seqlen,
+         ).reshape(q_len, -1)
+         attn_output = self.out_proj(attn_output)
+
+         return attn_output
+
+
+ class VisionSdpaAttention(VisionAttention):
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         cu_seqlens: torch.Tensor,
+         rotary_pos_emb: torch.Tensor = None,
+     ) -> torch.Tensor:
+         seq_length = hidden_states.shape[0]
+         query_states = self.q_proj(hidden_states)
+         key_states = self.k_proj(hidden_states)
+         value_states = self.v_proj(hidden_states)
+
+         query_states = query_states.view(seq_length, self.num_heads, self.head_dim)
+         key_states = key_states.view(seq_length, self.num_heads, self.head_dim)
+         value_states = value_states.view(seq_length, self.num_heads, self.head_dim)
+
+         query_states = apply_rotary_pos_emb_vision(
+             query_states.unsqueeze(0), rotary_pos_emb
+         ).squeeze(0)
+         key_states = apply_rotary_pos_emb_vision(
+             key_states.unsqueeze(0), rotary_pos_emb
+         ).squeeze(0)
+
+         # Boolean block-diagonal mask: True marks token pairs that belong to the
+         # same image and may attend to each other.
+         attention_mask = torch.zeros(
+             [1, seq_length, seq_length], device=query_states.device, dtype=torch.bool
+         )
+         for i in range(1, len(cu_seqlens)):
+             attention_mask[
+                 ...,
+                 cu_seqlens[i - 1] : cu_seqlens[i],
+                 cu_seqlens[i - 1] : cu_seqlens[i],
+             ] = True
+
+         query_states = query_states.transpose(0, 1)
+         key_states = key_states.transpose(0, 1)
+         value_states = value_states.transpose(0, 1)
+         attn_output = F.scaled_dot_product_attention(
+             query_states, key_states, value_states, attention_mask, dropout_p=0.0
+         )
+         attn_output = attn_output.transpose(0, 1)
+         attn_output = attn_output.reshape(seq_length, -1)
+         attn_output = self.out_proj(attn_output)
+         return attn_output
+
+
+ VISION_ATTENTION_CLASSES = {
+     "eager": VisionAttention,
+     "flash_attention_2": VisionFlashAttention2,
+     "sdpa": VisionSdpaAttention,
+ }
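+
+ # Added note: Siglip2NaViTVisionEncoderLayer below selects one of these
+ # implementations via config._attn_implementation, which transformers derives
+ # from the attn_implementation argument of from_pretrained (e.g. "sdpa").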
+
+
+ class Siglip2NaViTVisionMerger(nn.Module):
+     def __init__(
+         self, config: Siglip2NaViTVisionConfig, use_postshuffle_norm: bool = False
+     ):
+         super().__init__()
+         # Each output token is built from a merge_size x merge_size block of
+         # patch embeddings, hence the (spatial_merge_size ** 2) expansion here.
+         self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)
+         self.use_postshuffle_norm = use_postshuffle_norm
+         self.norm = nn.LayerNorm(
+             self.hidden_size if use_postshuffle_norm else config.hidden_size, eps=1e-6
+         )
+         self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size)
+         self.act_fn = nn.GELU()
+         self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.norm(
+             x.view(-1, self.hidden_size) if self.use_postshuffle_norm else x
+         ).view(-1, self.hidden_size)
+         x = self.linear_fc2(self.act_fn(self.linear_fc1(x)))
+         return x
+
+
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip2NaViT
+ class Siglip2NaViTVisionMLP(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.activation_fn = ACT2FN[config.hidden_act]
+         self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+         self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.fc1(hidden_states)
+         hidden_states = self.activation_fn(hidden_states)
+         hidden_states = self.fc2(hidden_states)
+         return hidden_states
+
+
+ class Siglip2NaViTVisionEncoderLayer(nn.Module):
+
+     def __init__(self, config: Siglip2NaViTVisionConfig):
+         super().__init__()
+         self.embed_dim = config.hidden_size
+         self.self_attn = VISION_ATTENTION_CLASSES[config._attn_implementation](
+             config=config
+         )
+         self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+         self.mlp = Siglip2NaViTVisionMLP(config)
+         self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+     # Ignore copy
+     def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
+         hidden_states = hidden_states + self.self_attn(
+             self.layer_norm1(hidden_states),
+             cu_seqlens=cu_seqlens,
+             rotary_pos_emb=rotary_pos_emb,
+         )
+         hidden_states = hidden_states + self.mlp(self.layer_norm2(hidden_states))
+         return hidden_states
+
+
+ class Siglip2NaViTVisionTransformerEncoder(nn.Module):
+
+     def __init__(self, config: Siglip2NaViTVisionConfig):
+         super().__init__()
+         self.config = config
+         head_dim = config.hidden_size // config.num_attention_heads
+         self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
+         self.layers = nn.ModuleList(
+             [
+                 Siglip2NaViTVisionEncoderLayer(config)
+                 for _ in range(config.num_hidden_layers)
+             ]
+         )
+
+     def rot_pos_emb(self, grid_sizes, merge_sizes):
+         pos_ids = []
+         for (t, h, w), merge_size in zip(grid_sizes, merge_sizes):
+             hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+             hpos_ids = hpos_ids.reshape(
+                 h // merge_size,
+                 merge_size,
+                 w // merge_size,
+                 merge_size,
+             )
+             hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+             hpos_ids = hpos_ids.flatten()
+
+             wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+             wpos_ids = wpos_ids.reshape(
+                 h // merge_size,
+                 merge_size,
+                 w // merge_size,
+                 merge_size,
+             )
+             wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+             wpos_ids = wpos_ids.flatten()
+             pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+
+         pos_ids = torch.cat(pos_ids, dim=0)
+         max_grid_size = grid_sizes[:, 1:].max()
+         rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+
+         return rotary_pos_emb
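+
+     # Added note: the merge-block permutation above mirrors the patch ordering
+     # produced by the image processor, so each token's (h, w) rotary position
+     # matches its place in the flattened pixel_values sequence.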
+
+     def forward(self, hidden_states, grid_sizes, merge_sizes) -> torch.Tensor:
+         rotary_pos_emb = self.rot_pos_emb(grid_sizes, merge_sizes)
+
+         # Cumulative sequence lengths (one entry per frame) mark the image
+         # boundaries that attention must not cross.
+         cu_seqlens = torch.repeat_interleave(
+             grid_sizes[:, 1] * grid_sizes[:, 2], grid_sizes[:, 0]
+         ).cumsum(dim=0, dtype=torch.int32)
+         cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+
+         for blk in self.layers:
+             hidden_states = blk(
+                 hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
+             )
+
+         return hidden_states
+
+
+ class Siglip2NaViTVisionModel(PreTrainedModel):
+
+     config_class = Siglip2NaViTVisionConfig
+     base_model_prefix = "siglip2_navit"
+     main_input_name = "pixel_values"
+     _no_split_modules = [
+         "Siglip2NaViTVisionEncoderLayer",
+         "Siglip2NaViTVisionEmbeddings",
+     ]
+     _supports_flash_attn_2 = True
+     _supports_sdpa = True
+
+     def __init__(self, config: Siglip2NaViTVisionConfig):
+         super().__init__(config=config)
+         embed_dim = config.hidden_size
+
+         self.embeddings = Siglip2NaViTVisionEmbeddings(config)
+         self.encoder = Siglip2NaViTVisionTransformerEncoder(config)
+         self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+         self.merger = Siglip2NaViTVisionMerger(config)
+
+         self.post_init()
+
+     def forward(self, pixel_values, grid_sizes, merge_sizes) -> torch.Tensor:
+         hidden_states = self.embeddings(pixel_values)
+         hidden_states = self.encoder(hidden_states, grid_sizes, merge_sizes)
+         hidden_states = self.post_layernorm(hidden_states)
+
+         ret = self.merger(hidden_states)  # (num_patches // merge_size**2, out_hidden_size)
+         return ret
+
+     def _init_weights(self, module):
+         """Initialize the weights"""
+         if isinstance(module, nn.Embedding):
+             default_flax_embed_init(module.weight)
+         elif isinstance(module, VisionAttention):
+             nn.init.xavier_uniform_(module.q_proj.weight)
+             nn.init.xavier_uniform_(module.k_proj.weight)
+             nn.init.xavier_uniform_(module.v_proj.weight)
+             nn.init.xavier_uniform_(module.out_proj.weight)
+             nn.init.zeros_(module.q_proj.bias)
+             nn.init.zeros_(module.k_proj.bias)
+             nn.init.zeros_(module.v_proj.bias)
+             nn.init.zeros_(module.out_proj.bias)
+         elif isinstance(module, Siglip2NaViTVisionMLP):
+             nn.init.xavier_uniform_(module.fc1.weight)
+             nn.init.xavier_uniform_(module.fc2.weight)
+             nn.init.normal_(module.fc1.bias, std=1e-6)
+             nn.init.normal_(module.fc2.bias, std=1e-6)
+         elif isinstance(module, Siglip2NaViTVisionMerger):
+             nn.init.xavier_uniform_(module.linear_fc1.weight)
+             nn.init.xavier_uniform_(module.linear_fc2.weight)
+             nn.init.normal_(module.linear_fc1.bias, std=1e-6)
+             nn.init.normal_(module.linear_fc2.bias, std=1e-6)
+         elif isinstance(module, (nn.Linear, nn.Conv2d)):
+             lecun_normal_(module.weight)
+             if module.bias is not None:
+                 nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.LayerNorm):
+             module.bias.data.zero_()
+             module.weight.data.fill_(1.0)
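
Finally, an end-to-end smoke test tying the pieces together (my own sketch, not shipped in the repo; it loads through the Auto classes as in the README, casts to float32 to keep CPU inference simple, and reuses the 1024x768 example from the image processor section):

```python
import numpy as np
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModel

repo = "wtzhang-nlp/siglip2-navit-google"
processor = AutoImageProcessor.from_pretrained(repo, trust_remote_code=True)
model = AutoModel.from_pretrained(repo, trust_remote_code=True, torch_dtype=torch.float32)
model.eval()

image = Image.fromarray(np.random.randint(0, 256, (768, 1024, 3), dtype=np.uint8))
inputs = processor(images=image, return_tensors="pt")

# 768x1024 -> resized 384x512 -> 24*32 = 768 patches -> 768 // 2**2 = 192 tokens
with torch.no_grad():
    features = model(**inputs)
print(features.shape)  # torch.Size([192, 896])
```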
preprocessor_config.json ADDED
@@ -0,0 +1,26 @@
+ {
+   "auto_map": {
+     "AutoImageProcessor": "image_processor_siglip2_navit.SigLIP2NaViTImageProcessor"
+   },
+   "do_convert_rgb": true,
+   "do_normalize": true,
+   "do_rescale": true,
+   "do_resize": true,
+   "image_mean": [
+     0.5,
+     0.5,
+     0.5
+   ],
+   "image_processor_type": "SigLIP2NaViTImageProcessor",
+   "image_std": [
+     0.5,
+     0.5,
+     0.5
+   ],
+   "max_tokens": 196,
+   "merge_size": 2,
+   "min_tokens": 32,
+   "patch_size": 16,
+   "resample": 3,
+   "rescale_factor": 0.00392156862745098
+ }
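
For reference (my reading of the values above, not part of the repo): `resample: 3` is PIL's `BICUBIC` filter, `rescale_factor` is exactly 1/255, and with `max_tokens: 196` and `patch_size * merge_size = 32` the pixel budget works out to 196 * 32 * 32 = 200704 pixels.

```python
from PIL import Image

assert int(Image.Resampling.BICUBIC) == 3
assert abs(0.00392156862745098 - 1 / 255) < 1e-12
print(196 * 32 * 32)  # 200704
```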