XQce commited on
Commit
07bc10c
·
verified ·
1 Parent(s): 2d2b5ba

Upload 5 files

Browse files
Files changed (5) hide show
  1. modeling_projector.py +162 -0
  2. modeling_valley.py +556 -0
  3. modeling_vision_tower.py +165 -0
  4. processing_valley.py +313 -0
  5. utils.py +251 -0
modeling_projector.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+
5
def build_vision_projector(config, delay_load=False, **kwargs):
    """Instantiate the multimodal projector selected by ``config``.

    The projector maps vision-tower features (``config.mm_hidden_size``)
    into the language model embedding space (``config.hidden_size``).

    Args:
        config: model configuration; ``mm_projector_type`` picks the variant.
        delay_load: accepted for interface compatibility; not used here.

    Returns:
        An ``nn.Module`` projector instance.

    Raises:
        ValueError: if ``mm_projector_type`` names no known projector.
    """
    projector_type = getattr(config, 'mm_projector_type', 'linear')

    if projector_type == 'conv_adapter':
        return ConvAdapter(
            config.mm_hidden_size,
            config.hidden_size,
            getattr(config, "mlp_hidden_dim", None),
        )
    if projector_type == 'mlp_pixel_shuffle':
        return MlpPixelShuffle(
            config.mm_hidden_size,
            config.hidden_size,
            config.pixelshuffle_downsample_ratio,
            getattr(config, "mlp_hidden_dim", None),
        )
    if projector_type == 'ovis_conv_adapter':
        return OvisConvAdapter(
            config.mm_hidden_size,
            config.hidden_size,
            getattr(config, "mlp_hidden_dim", 32000),
            getattr(config, "tokenize_function", "softmax"),
        )
    raise ValueError(f'Unknown projector type: {projector_type}')
17
+
18
+
19
class ConvAdapter(nn.Module):
    """Project ViT features with an MLP, then halve the spatial grid with a conv.

    The hidden width of the MLP defaults to ``dim_out`` unless
    ``mlp_hidden_dim`` is given.
    """

    def __init__(self, dim_in, dim_out, mlp_hidden_dim=None):
        super().__init__()
        self.mm_projector_type = 'conv_adapter'
        hidden = dim_out if mlp_hidden_dim is None else mlp_hidden_dim
        self.mlp = nn.Sequential(
            nn.Linear(dim_in, hidden),
            nn.GELU(),
            nn.Linear(hidden, dim_out),
        )
        # 3x3 / stride-2 / pad-1 conv: each spatial side shrinks by ~2x.
        self.conv = nn.Conv2d(dim_out, dim_out, kernel_size=(3, 3), stride=(2, 2), padding=1)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): image features
                shape (F, v, D), cls token first
        Returns:
            shape (F, n, D) where n is token_num that has been reduced
        """
        x = self.mlp(x)

        frames, tokens, dim = x.shape
        side = int(math.sqrt(tokens - 1))
        x = x[:, 1:, :]  # remove cls_token
        grid = x.reshape(frames, side, side, dim).permute([0, 3, 1, 2])
        pooled = self.conv(grid)  # spatial downsample
        return pooled.permute([0, 2, 3, 1]).reshape(frames, -1, dim)
54
+
55
+
56
class MlpPixelShuffle(nn.Module):
    """Reduce visual tokens by pixel-shuffle folding, then project with an MLP.

    An s-fold shuffle merges each s×s patch of the token grid into one token
    whose channel dim grows by s², so the MLP input width is ``dim_in * s**2``.
    """

    def __init__(self, dim_in, dim_out, pixelshuffle_downsample_ratio, mlp_hidden_dim=None):
        super().__init__()
        self.mm_projector_type = 'mlp_pixel_shuffle'
        fused_dim = int(dim_in * (pixelshuffle_downsample_ratio ** 2))
        hidden = dim_out if mlp_hidden_dim is None else mlp_hidden_dim
        self.mlp = nn.Sequential(
            nn.Linear(fused_dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, dim_out),
        )
        self.scale_factor = pixelshuffle_downsample_ratio

    def pixel_shuffle(self, x, scale_factor=2):
        """Fold a (N, W, H, C) grid into (N, W/s, H/s, C*s*s)."""
        n, w, h, c = x.size()
        # (N, W, H, C) -> (N, W, H/s, C*s)
        x = x.view(n, w, int(h / scale_factor), int(c * scale_factor))
        # (N, W, H/s, C*s) -> (N, H/s, W, C*s)
        x = x.permute(0, 2, 1, 3).contiguous()
        # (N, H/s, W, C*s) -> (N, H/s, W/s, C*s*s)
        x = x.view(n, int(h / scale_factor), int(w / scale_factor),
                   int(c * (scale_factor * scale_factor)))
        return x.permute(0, 2, 1, 3).contiguous()

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): image features
                shape (F, v, D), cls token first
        Returns:
            shape (F, n, D) where n is token_num that has been reduced
        """
        x = x[:, 1:, :]  # remove cls_token
        side = int(x.shape[1] ** 0.5)
        grid = x.view(x.shape[0], side, side, -1)
        folded = self.pixel_shuffle(grid, self.scale_factor)
        projected = self.mlp(folded)
        return projected.view(projected.shape[0], -1, projected.shape[-1])
105
+
106
+
107
class OvisConvAdapter(nn.Module):
    """Conv-downsample ViT features, map them to a visual vocabulary, and
    return embeddings of the resulting (soft) visual tokens.

    Args:
        dim_in: channel width of the incoming vision features.
        dim_out: embedding width expected by the language model.
        vocab_size: size of the visual vocabulary.
        tokenize_function: 'softmax', 'gumbel_argmax' or 'st_argmax'.
    """

    def __init__(self, dim_in, dim_out, vocab_size, tokenize_function="softmax"):
        super().__init__()
        self.mm_projector_type = 'ovis_conv_adapter'
        # 3x3 / stride-2 conv halves each spatial side of the feature grid.
        self.conv = nn.Conv2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), padding=1)
        # Projects conv features to logits over the visual vocabulary.
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(dim_in, vocab_size, bias=False),
            torch.nn.LayerNorm(vocab_size)
        )
        self.embedding = torch.nn.Embedding(vocab_size, dim_out)
        self.tokenize_function = tokenize_function

    def tokenize(self, logits):
        """Turn vocabulary logits into (soft) one-hot token weights.

        Raises:
            ValueError: if ``self.tokenize_function`` is not one of
                'softmax', 'gumbel_argmax' or 'st_argmax'.
        """
        def st_argmax(y_soft, dim):  # straight-through softmax
            index = y_soft.max(dim, keepdim=True)[1]
            y_hard = torch.zeros_like(y_soft, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
            ret = y_hard - y_soft.detach() + y_soft
            return ret

        if self.tokenize_function == 'softmax':
            tokens = torch.nn.functional.softmax(logits, dim=-1)
        elif self.tokenize_function == 'gumbel_argmax':
            # NOTE(review): `self.config` is never assigned on this module, so
            # this branch raises AttributeError at runtime — confirm where
            # `tau` should come from before enabling 'gumbel_argmax'.
            tokens = torch.nn.functional.gumbel_softmax(logits, tau=self.config.tau, hard=True)
        elif self.tokenize_function == 'st_argmax':
            tokens = st_argmax(logits, dim=-1)
        else:
            # Fixed: previously interpolated `self.config.tokenize_function`,
            # which does not exist and turned this ValueError into an
            # AttributeError.
            raise ValueError(
                'Invalid `max_type`, expected softmax or gumbel_argmax or st_argmax,'
                f' but got {self.tokenize_function}'
            )
        return tokens

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): image features
                shape (F, v, D), cls token first
        Returns:
            shape (F, n, dim_out) where n is token_num that has been reduced
        """
        # conv: drop cls token, restore the square grid, downsample spatially
        f, v, d = x.shape
        s = int(math.sqrt(v - 1))
        x = x[:, 1:, :]  # remove cls_token
        x = x.reshape(f, s, s, d).permute([0, 3, 1, 2])
        x = self.conv(x)
        x = x.permute([0, 2, 3, 1]).reshape(f, -1, d)

        # tokenize: vocabulary logits -> (soft) one-hot weights
        logits = self.mlp(x)
        visual_tokens = self.tokenize(logits)

        # get embeddings: weighted combination of vocabulary embeddings
        out = torch.matmul(visual_tokens, self.embedding.weight)

        return out
modeling_valley.py ADDED
@@ -0,0 +1,556 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import numpy as np
17
+ from torch import nn
18
+ from torch.nn import CrossEntropyLoss
19
+ from abc import ABC, abstractmethod
20
+ from typing import List, Optional, Tuple, Union, Dict, Any
21
+ from transformers.modeling_outputs import CausalLMOutputWithPast
22
+ from transformers import AutoConfig, AutoModelForCausalLM, Qwen2Config, Qwen2ForCausalLM, Qwen2Model
23
+
24
+ from .modeling_vision_tower import build_vision_tower
25
+ from .modeling_projector import build_vision_projector
26
+ from .utils import get_anyres_image_grid_shape, unpad_image, IGNORE_INDEX, IMAGE_TOKEN_INDEX
27
+
28
+
29
class ValleyConfig(Qwen2Config):
    """Configuration for the Valley multimodal model (Qwen2 backbone)."""
    # Model-type identifier used by the AutoConfig registration at file end.
    model_type = "valley"
31
+
32
class ValleyMetaModel:
    """Mixin that equips a language-model backbone with Valley's vision tower
    and multimodal projector, driven entirely by the config."""

    def __init__(self, config):
        super().__init__(config)
        # Build vision tower; eagle mode carries an extra qwen2-vl tower.
        if hasattr(config, "mm_vision_tower"):
            if getattr(config, "eagle_vision_tower", None) is not None:
                self.vision_tower, self.qwen2vl_vision_tower = build_vision_tower(config, delay_load=False)
            else:
                self.vision_tower = build_vision_tower(config, delay_load=False)
        # Build projector unless the config runs navit-only.
        if hasattr(config, "mm_projector_type") and not getattr(config, "only_navit", False):
            self.mm_projector = build_vision_projector(config)

    def get_vision_tower(self):
        """Return the vision tower, or a (siglip, qwen2vl) pair in eagle mode."""
        tower = getattr(self, "vision_tower", None)
        if getattr(self.config, "eagle_vision_tower", None) is None:
            return tower
        return tower, getattr(self, "qwen2vl_vision_tower", None)
52
+
53
class ValleyMetaForCausalLM(ABC):
    """Mixin providing multimodal input preparation for Valley causal LMs."""

    @abstractmethod
    def get_model(self):
        """Return the underlying backbone model (a ValleyMetaModel mixin)."""
        pass

    def get_vision_tower(self):
        # Delegates to the inner model's accessor; may return a single tower
        # or a (siglip, qwen2vl) pair depending on eagle mode.
        return self.get_model().get_vision_tower()
60
+
61
def split_by_instance(self, original_list, split_sizes):
    """Chunk `original_list` into consecutive sub-lists of the given sizes,
    moving every element onto this model's device.

    Args:
        original_list: flat sequence of tensors (one entry per image).
        split_sizes: number of entries belonging to each batch instance.
    Returns:
        list of per-instance lists of tensors.
    """
    device = self.device
    chunks = []
    offset = 0
    for count in split_sizes:
        piece = original_list[offset:offset + count]
        chunks.append([item.to(device) for item in piece])
        offset += count
    return chunks
70
+
71
def encode_images_qwen2vl(self, pixel_values=None, grid_thw=None, split_sizes=None):
    """Encode images/videos with the qwen2-vl tower and regroup the flat
    feature stream first per image, then per batch instance.

    Args:
        pixel_values: packed patches for the qwen2-vl tower.
        grid_thw: per-image (t, h, w) patch grids; h and w are halved by the
            tower's 2x2 spatial merge when counting output tokens.
        split_sizes: images per batch instance, for the final regrouping.
    """
    _, qwen2vl_tower = self.get_model().get_vision_tower()
    features = qwen2vl_tower(pixel_values, grid_thw)
    # Tokens each image contributes after the 2x2 spatial merge.
    tokens_per_image = torch.prod(grid_thw[:, 1:3] // 2, dim=1)
    per_image = torch.split(features, tokens_per_image.tolist(), dim=0)
    return self.split_by_instance(per_image, split_sizes)
78
+
79
def encode_images(self, images = None, split_sizes = None):
    """
    Encode images with the (siglip) vision tower and project them into the
    LM embedding space, optionally pooling to respect `max_vision_token`.

    images: (if not anyres) images.shape = [n,3,336,336] , n = number of images + (number of video) * 8
    images: (if anyres) images.shape = [n,3,336,336] , n = number of tiles * number of images
    split_sizes: per-instance image counts; when given, the result is split
        back into per-instance chunks before returning.
    """
    if getattr(self.config, "eagle_vision_tower", None) is not None:
        # Eagle mode: the tower accessor returns (siglip, qwen2vl); only the
        # siglip features go through the projector here.
        siglip_vision_tower, _ = self.get_model().get_vision_tower()
        image_features = siglip_vision_tower(images)
        image_features = self.get_model().mm_projector(image_features)
    else:
        image_features = self.get_model().get_vision_tower()(images)
        image_features = self.get_model().mm_projector(image_features)

    # Anyres + token budget: pool each instance's features down so the total
    # token count stays within max_vision_token.
    if getattr(self.config,'anyres', False) and getattr(self.config, 'max_vision_token', None) is not None:
        assert split_sizes is not None
        image_features = list(torch.split(image_features, split_sizes, dim=0))
        for i, image_feature in enumerate(image_features):
            hidden_dim = image_feature.shape[-1]
            image_tokens = image_feature.shape[0]*image_feature.shape[1]
            if getattr(self.config, "eagle_vision_tower", None) is not None:
                pass # the max_vision_token will be processed in the unpad image token part
            else:
                if image_tokens > self.config.max_vision_token:
                    intput_shape = int((image_feature.shape[1])**0.5)
                    output_shape = int((self.config.max_vision_token/image_feature.shape[0])**0.5)
                    image_feature = image_feature.view(image_feature.shape[0],intput_shape, intput_shape, -1).permute(0,3,1,2)
                    m = nn.AdaptiveAvgPool2d(output_shape) # different from roi pooling, but in square image, it seems the same
                    pooling_feature = m(image_feature).permute(0,2,3,1)
                    image_features[i] = pooling_feature.view(image_feature.shape[0], -1, hidden_dim)
        split_sizes = None # have already split, set the flag

    if getattr(self.config, 'mm_use_im_start_end', False):
        raise ValueError('mm_use_im_start is not support')
    # Split per instance unless the anyres branch already did.
    if split_sizes is not None:
        image_features = torch.split(image_features, split_sizes, dim=0)

    return image_features
116
+
117
def get_padding_method(self):
    """Return the padding side ('right' or 'left') for sequence batching.

    An explicit ``right_padding`` attribute, when set, takes precedence;
    otherwise training pads right and inference pads left.
    """
    right_padding = getattr(self, 'right_padding', None)
    # If the right_padding flag is set, ignore the training flag.
    # (Fixed: the flag-derived value was previously overwritten
    # unconditionally by the training-flag line, so the override never
    # took effect.)
    if right_padding is not None:
        return 'right' if right_padding else 'left'
    # Otherwise, the training flag determines the padding method.
    return 'right' if self.training else 'left'
126
+
127
def prepare_inputs_labels_for_multimodal(
    self, input_ids, position_ids, attention_mask, past_key_values, labels, images,
    image_sizes, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw):
    """Splice image embeddings into the text embedding sequence.

    Every IMAGE_TOKEN_INDEX placeholder in ``input_ids`` is replaced by the
    corresponding image's feature embeddings; labels (IGNORE_INDEX over image
    spans) and the attention mask are expanded to match, then the batch is
    re-padded. Returns the 6-tuple (input_ids, position_ids, attention_mask,
    past_key_values, inputs_embeds, labels); ``input_ids`` is None once
    ``inputs_embeds`` has been built.
    """

    vision_tower = self.get_vision_tower()
    # Fast path: no vision tower, text-only input, or single-token decode step.
    if vision_tower is None or images is None or input_ids.shape[1] == 1:
        if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
            # Cached generation step: grow the attention mask to cover the
            # cached sequence plus the new token.
            target_shape = past_key_values[-1][-1].shape[-2] + 1
            attention_mask = torch.cat((attention_mask, torch.ones(
                (attention_mask.shape[0], target_shape - attention_mask.shape[1]),
                dtype=attention_mask.dtype,
                device=attention_mask.device
            )), dim=1)
        return input_ids, position_ids, attention_mask, past_key_values, None, labels

    # Step1: Get image embeddings
    if type(images) is list or images.ndim == 5:
        # Without slicing the image
        if not getattr(self.config,'anyres', False):
            concat_images = torch.cat([image for image in images], dim=0) # to do batch compute
            split_sizes = [image.shape[0] for image in images]

            # Get vision tower feature, check whether only use navit firstly
            if getattr(self.config, 'eagle_vision_tower', None) is not None and getattr(self.config, 'only_navit', False):
                image_features = None
            else:
                image_features = self.encode_images(concat_images, split_sizes)
                image_features = [x.to(self.device) for x in image_features]

            # Get Eagle features
            if getattr(self.config, 'eagle_vision_tower', None) is not None:
                if pixel_values is not None:
                    qwen2vl_image_features = self.encode_images_qwen2vl(pixel_values, image_grid_thw, split_sizes)
                elif pixel_values_videos is not None:
                    qwen2vl_image_features = self.encode_images_qwen2vl(pixel_values_videos, video_grid_thw, split_sizes)
                else:
                    qwen2vl_image_features = None

        # Slicing the image, each image contains some sub_images:
        # images = [
        #     [image1_tiles(n1,3,336,336), image2_tiles(n2,3,336,336), ...],
        #     [image1_tiles(n1,3,336,336), image2_tiles(n2,3,336,336), ...], ...
        # ]
        else:
            split_sizes = [len(image) for image in images]
            # Get Eagle features
            if getattr(self.config, "eagle_vision_tower", None) is not None:
                if pixel_values is not None:
                    qwen2vl_image_features = self.encode_images_qwen2vl(pixel_values, image_grid_thw, split_sizes)
                elif pixel_values_videos is not None:
                    qwen2vl_image_features = self.encode_images_qwen2vl(pixel_values_videos, video_grid_thw, split_sizes)
                else:
                    qwen2vl_image_features = None

            # Get vision tower feature, check whether only use navit firstly
            if getattr(self.config, 'eagle_vision_tower', None) is not None and getattr(self.config, 'only_navit', False):
                image_features = None
            else:
                # Batch all tiles of all images through the tower in one call,
                # then regroup per instance.
                image_features = []
                all_concat_images = []
                all_split_sizes = []
                for batch_images in images:
                    concat_images = torch.cat([image for image in batch_images], dim=0) # to do batch compute
                    split_sizes = [image.shape[0] for image in batch_images]
                    all_concat_images.append(concat_images)
                    all_split_sizes.append(split_sizes)
                all_image_features = self.encode_images(images=torch.cat(all_concat_images, dim=0), split_sizes=sum(all_split_sizes, []))

                idx = 0
                for split_sizes in all_split_sizes:
                    batch_image_features = all_image_features[idx:idx+len(split_sizes)]
                    idx += len(split_sizes)
                    if type(batch_image_features[0]) is list:
                        batch_image_features = [torch.cat(x).to(self.device) for x in batch_image_features]
                    else:
                        batch_image_features = [x.view(-1,x.shape[-1]).to(self.device) for x in batch_image_features] # tiles feature need to flatten in token dimention, [n_tiles, T, d] -> [n_tiles * T, d]
                    image_features.append(batch_image_features)

            if getattr(self.config, "eagle_vision_tower", None) is not None and getattr(self.config, 'only_navit', False) == False:
                # unpad image tokens: separate the base tile from the anyres
                # tiles, restore the tile grid, trim padding, and pool down to
                # the max_vision_token budget.
                height = width = self.config.num_patches_per_side
                new_image_features = []
                for batch_image_features, batch_image_sizes in zip(image_features, image_sizes):
                    batch_image_features_list = []
                    for cur_image_feature, cur_image_size in zip(batch_image_features, batch_image_sizes):
                        base_image_feature = cur_image_feature[:width*height, :]
                        image_feature = cur_image_feature[width*height:, :]
                        if image_feature.shape[0] != 0:
                            num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                                cur_image_size,
                                self.config.grid_pinpoints,
                                self.config.vit_crop_size
                            )
                            image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1) # (num_patch_H, num_patch_W, H, W, C)
                            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() # (C, num_patch_H, H, num_patch_W, W)
                            image_feature = image_feature.flatten(1, 2).flatten(2, 3) # (C, num_token_H, num_token_W)
                            image_feature = unpad_image(image_feature, cur_image_size) # (C, num_token_H_unpad, num_token_W_unpad)
                            input_shape = (image_feature.shape[-2], image_feature.shape[-1])
                            subimage_tokens = np.prod(input_shape)

                            # adaptive avg 2d pool for reducing token num
                            max_subimage_tokens = self.config.max_vision_token-width*height
                            if subimage_tokens > max_subimage_tokens:
                                aspect_ratio = input_shape[0] / input_shape[1]
                                output_shape = (
                                    int((max_subimage_tokens/aspect_ratio)**0.5*aspect_ratio),
                                    int((max_subimage_tokens/aspect_ratio)**0.5)
                                )
                                m = nn.AdaptiveAvgPool2d(output_shape)
                                image_feature = m(image_feature)
                            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                            image_feature = torch.cat((base_image_feature, image_feature), dim=0)
                        else:
                            image_feature = cur_image_feature
                        batch_image_features_list.append(image_feature)
                    new_image_features.append(batch_image_features_list)

                image_features = new_image_features

    else:
        image_features = self.encode_images(images).to(self.device)


    # Step2: Iterate through each sample in the batch, insert image embeddings into input_embeds
    # and filling labels, attention mask at the same time. Finally, get `new_input_embed`,
    # `new_labels`, `new_attention_mask`.
    _labels = labels
    _position_ids = position_ids
    _attention_mask = attention_mask
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
    if position_ids is None:
        position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
    if labels is None:
        labels = torch.full_like(input_ids, IGNORE_INDEX)

    # Strip padded positions before splicing; everything is re-padded in Step4.
    input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask.bool())]
    labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask.bool())]
    attention_mask = [cur_attention_mask[cur_attention_mask.bool()] for cur_attention_mask in attention_mask]
    new_input_embeds = []
    new_labels = []
    new_attention_mask = []

    for batch_idx, cur_input_ids in enumerate(input_ids):
        cur_batch_image_idx = 0
        num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()

        # Step2-1: If this piece of data is pure text, then concat a dummy image to ensure the whole compute graph is same on all device
        if num_images == 0:
            if getattr(self.config, "eagle_vision_tower", None) is not None:
                if getattr(self.config, 'only_navit', False):
                    cur_image_features = qwen2vl_image_features[batch_idx][cur_batch_image_idx]
                else:
                    siglip_feat = image_features[batch_idx][cur_batch_image_idx]
                    try:
                        qwen2vl_feat = qwen2vl_image_features[batch_idx][cur_batch_image_idx]
                        cur_image_features = torch.cat((siglip_feat, qwen2vl_feat), dim=0)
                    except Exception as e:
                        # Best-effort: fall back to siglip features alone when
                        # the qwen2vl features are missing for this instance.
                        print(e)
                        print("only siglip feature:", siglip_feat.shape)
                        cur_image_features = siglip_feat
            else:
                cur_image_features = image_features[batch_idx][cur_batch_image_idx]
            cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
            # [0:0] keeps a zero-length slice so the dummy image contributes
            # no tokens but stays in the autograd graph.
            cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features.squeeze(0)[0:0]], dim=0)
            new_input_embeds.append(cur_input_embeds)
            new_labels.append(labels[batch_idx])
            new_attention_mask.append(attention_mask[batch_idx])
            cur_batch_image_idx += 1
            continue

        # Step2-2: Split input_ids, labels, attention_mask by IMAGE_TOKEN_INDEX
        cur_input_ids_noim, cur_labels_noim, cur_attention_mask_noim = [], [], []
        cur_labels = labels[batch_idx]
        cur_attention_mask = attention_mask[batch_idx]
        cur_img_attention_mask = [
            attention_mask[batch_idx][i].item()
            for i in torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()
        ]
        image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
        for i in range(len(image_token_indices) - 1):
            cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
            cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
            cur_attention_mask_noim.append(cur_attention_mask[image_token_indices[i]+1:image_token_indices[i+1]])
        split_sizes = [x.shape[0] for x in cur_labels_noim]
        cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
        cur_input_embeds_no_im = list(torch.split(cur_input_embeds, split_sizes, dim=0))# get text features

        # Step2-3: Insert image embeddings between the text segments
        cur_new_input_embeds, cur_new_labels, cur_new_attention_mask = [], [], []
        for i in range(num_images + 1): # to add multimodal feature internal the text feature
            cur_new_input_embeds.append(cur_input_embeds_no_im[i])
            cur_new_labels.append(cur_labels_noim[i])
            cur_new_attention_mask.append(cur_attention_mask_noim[i])
            if i < num_images:
                if getattr(self.config, "eagle_vision_tower", None) is not None:
                    if getattr(self.config, 'only_navit', False):
                        cur_image_features = qwen2vl_image_features[batch_idx][cur_batch_image_idx]
                    else:
                        siglip_feat = image_features[batch_idx][cur_batch_image_idx]
                        try:
                            qwen2vl_feat = qwen2vl_image_features[batch_idx][cur_batch_image_idx]
                            cur_image_features = torch.cat((siglip_feat, qwen2vl_feat), dim=0)
                        except Exception as e:
                            print(e)
                            print("only siglip feature:", siglip_feat.shape)
                            cur_image_features = siglip_feat
                else:
                    cur_image_features = image_features[batch_idx][cur_batch_image_idx]
                cur_batch_image_idx += 1
                cur_new_input_embeds.append(cur_image_features)
                # Image positions are masked out of the loss and attended to.
                cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
                cur_new_attention_mask.append(torch.full((cur_image_features.shape[0],), True, device=cur_attention_mask.device, dtype=cur_attention_mask.dtype))

        # Step2-4: Concat image embeddings and text embeddings
        cur_new_input_embeds = torch.cat(cur_new_input_embeds)
        cur_new_labels = torch.cat(cur_new_labels)
        cur_new_attention_mask = torch.cat(cur_new_attention_mask)
        new_input_embeds.append(cur_new_input_embeds)
        new_labels.append(cur_new_labels)
        new_attention_mask.append(cur_new_attention_mask)

    # Step3: Truncate sequences to max length as image embeddings can make the sequence longer
    tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
    if tokenizer_model_max_length is not None:
        new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
        new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
        new_attention_mask = [x[:tokenizer_model_max_length] for x in new_attention_mask]

    # Step4: Pad and stack input_embeds, labels, attention_mask
    max_len = max(x.shape[0] for x in new_input_embeds)
    batch_size = len(new_input_embeds)
    new_input_embeds_padded = []
    new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
    new_attention_mask_padded = torch.zeros((batch_size, max_len), dtype=new_attention_mask[0].dtype, device=new_attention_mask[0].device)
    position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)

    for i, (cur_new_embed, cur_new_labels, cur_attention_mask) in enumerate(zip(new_input_embeds, new_labels, new_attention_mask)):
        cur_len = cur_new_embed.shape[0]
        if self.get_padding_method() == 'left':
            # Left padding (inference): content is right-aligned.
            new_input_embeds_padded.append(torch.cat((
                torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
                cur_new_embed
            ), dim=0))
            if cur_len > 0:
                new_labels_padded[i, -cur_len:] = cur_new_labels
                new_attention_mask_padded[i, -cur_len:] = cur_attention_mask
                position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)

        else:
            # Right padding (training): content is left-aligned.
            new_input_embeds_padded.append(torch.cat((
                cur_new_embed,
                torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
            ), dim=0))
            if cur_len > 0:
                new_labels_padded[i, :cur_len] = cur_new_labels
                new_attention_mask_padded[i, :cur_len] = cur_attention_mask
                position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)

    new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
    # Only return labels / attention_mask / position_ids if the caller
    # supplied them in the first place.
    new_labels = new_labels_padded if _labels is not None else None
    new_attention_mask = new_attention_mask_padded if _attention_mask is not None else None
    if _position_ids is None:
        position_ids = None

    return None, position_ids, new_attention_mask, past_key_values, new_input_embeds, new_labels
393
+
394
+
395
class ValleyQwen2Model(ValleyMetaModel, Qwen2Model):
    """Qwen2 transformer backbone extended with Valley's vision modules."""

    config_class = ValleyConfig

    def __init__(self, config: Qwen2Config):
        # ValleyMetaModel.__init__ builds the vision tower/projector and
        # chains into Qwen2Model via the MRO.
        super().__init__(config)
399
+
400
+
401
class ValleyQwen2ForCausalLM(Qwen2ForCausalLM, ValleyMetaForCausalLM):
    """Valley multimodal causal language model on a Qwen2 backbone."""
    config_class = ValleyConfig

    def __init__(self, config):
        # Skip Qwen2ForCausalLM.__init__ (which would build a plain Qwen2Model)
        # so the Valley backbone can be installed instead.
        super(Qwen2ForCausalLM, self).__init__(config)
        self.model = ValleyQwen2Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_model(self):
        # Accessor required by the ValleyMetaForCausalLM mixin.
        return self.model
412
+
413
def _update_model_kwargs_for_generation(
    self,
    outputs: CausalLMOutputWithPast,
    model_kwargs: Dict[str, Any],
    is_encoder_decoder: bool = False,
    num_new_tokens: int = 1,
) -> Dict[str, Any]:
    """Refresh the generation kwargs after a decode step, replacing the
    attention mask with the multimodal-expanded one carried on `outputs`."""
    new_model_kwargs = super()._update_model_kwargs_for_generation(
        outputs,
        model_kwargs,
        is_encoder_decoder,
        num_new_tokens
    )
    """
    Set model_kwargs["attention_mask"] to the expanded `attention_mask` in
    the `prepare_inputs_labels_for_multimodal` function to ensure the
    correctness of the generate behavior when `use_cache` is enabled.
    """
    if not is_encoder_decoder:
        if "attention_mask" in new_model_kwargs:
            # `outputs.attention_mask` is attached in forward(); extend it by
            # one position for the token just generated.
            attention_mask = outputs.attention_mask
            new_model_kwargs["attention_mask"] = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
            )
    return new_model_kwargs
438
+
439
+
440
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    images: Optional[torch.FloatTensor] = None,
    return_dict: Optional[bool] = None,
    image_sizes: Optional[List[List[int]]] = None,
    pixel_values: Optional[torch.Tensor] = None,
    pixel_values_videos: Optional[torch.FloatTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    """Multimodal forward pass: splice image features into the embeddings,
    run the Qwen2 backbone, and compute an (optional) per-sample loss.

    Note: when `labels` is given, `loss` is a vector with one mean
    cross-entropy value per batch sample (not a scalar) — callers are
    expected to reduce it themselves.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Build the multimodal embeddings unless the caller already supplied them.
    if inputs_embeds is None:
        (
            input_ids,
            position_ids,
            attention_mask,
            past_key_values,
            inputs_embeds,
            labels
        ) = self.prepare_inputs_labels_for_multimodal(
            input_ids,
            position_ids,
            attention_mask,
            past_key_values,
            labels,
            images,
            image_sizes,
            pixel_values,
            pixel_values_videos,
            image_grid_thw,
            video_grid_thw,
        )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    hidden_states = outputs[0]
    logits = self.lm_head(hidden_states)

    loss = None
    if labels is not None:
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss_fct = CrossEntropyLoss(reduction='mean')
        bs = shift_labels.shape[0]
        shift_labels = shift_labels.to(shift_logits.device)
        # One mean loss per sample, stacked into a (bs,) vector.
        loss = torch.stack([loss_fct(shift_logits[i], shift_labels[i]) for i in range(bs)])

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    res = CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )

    # Expose the (multimodal-expanded) attention mask for
    # _update_model_kwargs_for_generation to pick up during generation.
    res.attention_mask = attention_mask
    return res
527
+
528
def prepare_inputs_for_generation(
    self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
    """Assemble model inputs for a single generation step, forwarding the
    multimodal tensors (images, pixel values, grid shapes) from kwargs."""
    # With a populated cache, only the newest token needs to be fed.
    if past_key_values:
        input_ids = input_ids[:, -1:]

    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
    if inputs_embeds is not None and past_key_values is None:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}

    model_inputs["past_key_values"] = past_key_values
    model_inputs["use_cache"] = kwargs.get("use_cache")
    model_inputs["attention_mask"] = attention_mask
    for key in (
        "images",
        "image_sizes",
        "pixel_values",
        "pixel_values_videos",
        "image_grid_thw",
        "video_grid_thw",
    ):
        model_inputs[key] = kwargs.get(key, None)
    return model_inputs
554
+
555
+ AutoConfig.register("valley", ValleyConfig)
556
+ AutoModelForCausalLM.register(ValleyConfig, ValleyQwen2ForCausalLM)
modeling_vision_tower.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel
4
+ from transformers import PretrainedConfig
5
+
6
# Static config for the SigLIP so400m-patch14-384 vision encoder.  The tower is
# always instantiated from this dict (weights are expected to come from the
# surrounding checkpoint load, not from the hub).
siglip_config = PretrainedConfig.from_dict(
    {
        "attention_dropout": 0.0,
        "hidden_act": "gelu_pytorch_tanh",
        "hidden_size": 1152,
        "image_size": 384,
        "intermediate_size": 4304,
        "layer_norm_eps": 1e-06,
        "model_type": "siglip_vision_model",
        "num_attention_heads": 16,
        "num_channels": 3,
        "num_hidden_layers": 27,
        "patch_size": 14,
    }
)

# Static config for the Qwen2-VL native-resolution (NaViT) ViT used as the
# second tower in the "Eagle" setup.  `_attn_implementation` may be overridden
# at build time via `vision_tower_cfg._vit_attn_implementation`.
qwen2vl_vit_config = PretrainedConfig.from_dict(
    {
        "depth": 32,
        "embed_dim": 1280,
        "hidden_act": "quick_gelu",
        "hidden_size": 3584,
        "in_channels": 3,
        "in_chans": 3,
        "mlp_ratio": 4,
        "model_type": "qwen2_vl",
        "num_heads": 16,
        "patch_size": 14,
        "spatial_merge_size": 2,
        "spatial_patch_size": 14,
        "temporal_patch_size": 2,
        "_attn_implementation": "flash_attention_2",
        "_attn_implementation_internal": "flash_attention_2"
    }
)
41
+
42
def build_vision_tower(vision_tower_cfg, **kwargs):
    """Instantiate the vision backbone(s) described by `vision_tower_cfg`.

    Returns a single `SigLipVisionTower` in the plain setup, or a
    `(siglip_tower_or_None, qwen2vl_navit_tower)` pair in the "Eagle" setup.
    Raises ValueError for unrecognised tower names.
    """
    tower_name = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None))
    if "siglip-so400m-patch14-384" not in tower_name:
        raise ValueError(f"Unknown vision tower: {tower_name}")

    # Plain (non-Eagle) setup: SigLIP only.
    if getattr(vision_tower_cfg, "eagle_vision_tower", None) is None:
        return SigLipVisionTower(tower_name, args=vision_tower_cfg, **kwargs)

    # Eagle setup: add a frozen Qwen2-VL NaViT tower alongside (or instead of) SigLIP.
    attn_impl = getattr(vision_tower_cfg, "_vit_attn_implementation", None)
    if attn_impl is not None:
        qwen2vl_vit_config._attn_implementation = attn_impl
        qwen2vl_vit_config._attn_implementation_internal = attn_impl

    navit_tower = Qwen2VisionTransformerPretrainedModel._from_config(qwen2vl_vit_config)

    merger_hidden_dim = getattr(vision_tower_cfg, "navit_merger_hidden_dim", None)
    if merger_hidden_dim is not None:
        # Swap the stock patch merger for a randomly initialised custom one.
        del navit_tower.merger
        navit_tower.merger = CustomPatchMerger(
            vision_tower_cfg.hidden_size,
            context_dim=1280,
            hidden_dim=merger_hidden_dim,
        )
    navit_tower.requires_grad_(False)

    # With `only_navit`, the SigLIP branch is dropped entirely.
    if getattr(vision_tower_cfg, "only_navit", False):
        siglip_tower = None
    else:
        siglip_tower = SigLipVisionTower(tower_name, args=vision_tower_cfg, **kwargs)

    return siglip_tower, navit_tower
75
+
76
class SigLipVisionTower(nn.Module):
    """Frozen SigLIP ViT wrapper exposing per-patch features.

    The model is always built from the module-level `siglip_config`
    (`vision_tower` is recorded but not used for loading); real weights are
    presumably loaded later via the surrounding checkpoint — TODO confirm.
    """
    def __init__(self, vision_tower, args, delay_load=False, cache_dir="./cache_dir"):
        super().__init__()
        self.is_loaded = False
        # Name is kept for bookkeeping only; load_model() ignores it.
        self.image_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer
        # NOTE(review): default here is "patch", but feature_select() asserts
        # "cls_patch" — callers apparently always configure "cls_patch"; confirm.
        self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
        self.cache_dir = cache_dir

        if not delay_load:
            self.load_model()
        else:
            from transformers import SiglipVisionModel
            self.cfg_only = siglip_config
            self.vision_tower = SiglipVisionModel._from_config(siglip_config) # dummy-load

    def load_model(self):
        """Build the (frozen) SigLIP encoder from the static config."""
        from transformers import SiglipVisionModel
        self.vision_tower = SiglipVisionModel._from_config(siglip_config)
        self.vision_tower.requires_grad_(False)
        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """Prepend the first token to the sequence ("cls_patch" selection).

        Input is the encoder's last_hidden_state (B, N, C); output is (B, N+1, C).
        """
        assert self.select_feature == "cls_patch"
        image_features = torch.cat([image_forward_outs[:, :1, :], image_forward_outs], dim=1)
        return image_features

    def forward(self, images):
        """Encode a batched tensor, or a list of single images one at a time.

        Returns a feature tensor for tensor input, or a list of per-image
        feature tensors (each with batch dim 1) for list input.
        """
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
                    output_hidden_states=True,
                    return_dict=True,
                )
                image_feature = self.feature_select(image_forward_out.last_hidden_state).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_tower(
                images.to(device=self.device, dtype=self.dtype),
                output_hidden_states=True,
                return_dict=True,
            )
            image_features = self.feature_select(image_forward_outs.last_hidden_state).to(images.dtype)

        return image_features

    @property
    def dummy_feature(self):
        # Zero feature vector matching the tower's dtype/device; used as a placeholder.
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        # Before load_model(), only the static config is meaningful.
        if self.is_loaded:
            return self.vision_tower.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        # Square grid of patch tokens: (384 // 14) ** 2 for the static config.
        return (self.config.image_size // self.config.patch_size) ** 2
150
+
151
+
152
class CustomPatchMerger(nn.Module):
    """Merge groups of `spatial_merge_size**2` ViT patch features into one token.

    Patch features (width `context_dim`) are layer-normed, each group is
    flattened into a vector of width `context_dim * spatial_merge_size**2`,
    and a two-layer MLP projects it to the output width `dim`.
    """

    def __init__(self, dim: int, context_dim: int, hidden_dim: int, spatial_merge_size: int = 2) -> None:
        super().__init__()
        # Width of one flattened group of patches fed into the MLP.
        self.input_dim = context_dim * (spatial_merge_size**2)
        self.ln_q = nn.LayerNorm(context_dim, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(self.input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normed = self.ln_q(x)
        grouped = normed.view(-1, self.input_dim)
        return self.mlp(grouped)
processing_valley.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import types
3
+ import io
4
+ import torch
5
+ from PIL import Image
6
+ from qwen_vl_utils import fetch_image
7
+
8
+ from transformers import (
9
+ ProcessorMixin,
10
+ SiglipImageProcessor,
11
+ BatchFeature,
12
+ Qwen2VLImageProcessor,
13
+ PreTrainedTokenizer
14
+ )
15
+
16
+ from .utils import (
17
+ process_anyres_image,
18
+ BLACK_IMG_ENV,
19
+ DEFAULT_IM_END_TOKEN,
20
+ DEFAULT_IM_START_TOKEN,
21
+ DEFAULT_IMAGE_TOKEN,
22
+ DEFAULT_VI_END_TOKEN,
23
+ DEFAULT_VI_START_TOKEN,
24
+ DEFAULT_VIDEO_TOKEN,
25
+ IMAGE_TOKEN_INDEX,
26
+ SEQ_MAX_LEN,
27
+ )
28
+
29
# Static SigLIP image-processor settings (384x384, mean/std 0.5 normalisation).
siglip_processor_config = {
    "do_normalize": True,
    "do_rescale": True,
    "do_resize": True,
    "image_mean": [
        0.5,
        0.5,
        0.5
    ],
    "image_processor_type": "SiglipImageProcessor",
    "image_std": [
        0.5,
        0.5,
        0.5
    ],
    "processor_class": "SiglipProcessor",
    "resample": 3,
    "rescale_factor": 0.00392156862745098,
    "size": {
        "height": 384,
        "width": 384
    }
}

# Static Qwen2-VL (NaViT) image-processor settings; min/max_pixels may be
# overridden per call in ValleyProcessor.__call__.
qwen2vl_processor_config = {
    "min_pixels": 3136,
    "max_pixels": 12845056,
    "patch_size": 14,
    "temporal_patch_size": 2,
    "merge_size": 2,
    "image_mean": [
        0.48145466,
        0.4578275,
        0.40821073
    ],
    "image_std": [
        0.26862954,
        0.26130258,
        0.27577711
    ],
    "image_processor_type": "Qwen2VLImageProcessor",
    "processor_class": "Qwen2VLProcessor"
}
72
+
73
class ValleyProcessor(ProcessorMixin):
    """Turn a Valley chat sample ({"conversations": ..., "images": ...}) into
    model-ready tensors.

    Images are preprocessed twice — once with a SigLIP processor (optionally
    anyres tiling) and once with the Qwen2-VL NaViT processor — and the text is
    tokenized with each `<image>` placeholder mapped to IMAGE_TOKEN_INDEX.
    """
    attributes = ["tokenizer"]
    # These become instance attributes via ProcessorMixin (e.g. self.max_pixels).
    optional_attributes = [
        "max_pixels",
        "min_pixels",
        "anyres",
        "only_crop_single_image",
        "grid_pinpoints",
        "use_special_start_end_token",
        "only_navit",
        "chat_template",
    ]
    tokenizer_class = "AutoTokenizer"

    def __init__(self, tokenizer=None, chat_template=None, **kwargs):
        super().__init__(tokenizer=tokenizer, chat_template=chat_template, **kwargs)
        # Fallback image (raw PNG bytes of a tiny black image) for text-only samples.
        self.black_img = BLACK_IMG_ENV
        self.siglip_image_processor = SiglipImageProcessor.from_dict(siglip_processor_config)
        self.qwen2vl_image_processor = Qwen2VLImageProcessor.from_dict(
            qwen2vl_processor_config,
        )

        self.anyres = kwargs.get("anyres", True)
        self.grid_pinpoints = kwargs.get("grid_pinpoints", "(1x1),...,(3x3)")
        self.only_crop_single_image = kwargs.get("only_crop_single_image", True)
        self.use_special_start_end_token = kwargs.get("use_special_start_end_token", True)
        self.only_navit = kwargs.get("only_navit", False)

    def preprocess_images_siglip(self, images) -> torch.FloatTensor:
        """SigLIP-preprocess a list of images (paths, PIL images, or raw bytes).

        Returns a stacked tensor when anyres is off, otherwise a list of
        per-image patch stacks (anyres tiling is skipped for multi-image
        samples when only_crop_single_image is set).
        """
        if isinstance(images[0], str):
            images_pil = [Image.open(img).convert("RGB") for img in images]
        elif isinstance(images[0], Image.Image):
            images_pil = [img.convert("RGB") for img in images]
        elif isinstance(images[0], bytes):
            images_pil = [Image.open(io.BytesIO(img)).convert("RGB") for img in images]
        else:
            raise ValueError("unsupported type")

        processed_images = []
        have_multi_images = len(images_pil) > 1
        for img in images_pil:
            if self.anyres:
                if not self.only_crop_single_image or not have_multi_images:
                    # Full anyres: base resize + grid of crops.
                    image = process_anyres_image(img, self.siglip_image_processor, self.grid_pinpoints)
                else:
                    # Multi-image sample: single 384x384 crop per image (wrapped
                    # in a list so the anyres return shape stays uniform).
                    image = [self.siglip_image_processor(img, return_tensors="pt")["pixel_values"][0]]
            else:
                image = self.siglip_image_processor(img, return_tensors="pt")["pixel_values"][0]

            processed_images.append(image)

        if not self.anyres:
            return torch.stack(processed_images, dim=0)
        else:
            return [torch.stack(img, dim=0) for img in processed_images]

    def preprocess_images_qwen2vl(self, images) -> dict:
        """Qwen2-VL-preprocess a list of images (paths, PIL images, or raw bytes).

        Returns the processor's dict (pixel_values, image_grid_thw, ...) plus
        an "image_sizes" entry holding the original PIL sizes.
        """
        if isinstance(images[0], str):
            images_pil = [Image.open(img).convert("RGB") for img in images]
        elif isinstance(images[0], Image.Image):
            images_pil = [img.convert("RGB") for img in images]
        elif isinstance(images[0], bytes):
            images_pil = [Image.open(io.BytesIO(img)).convert("RGB") for img in images]
        else:
            raise ValueError("unsupported type")

        image_sizes = [[x.size for x in images_pil]]
        # fetch_image applies Qwen2-VL's smart-resize constraints before processing.
        data_dict_qwen2vl = self.qwen2vl_image_processor(
            [fetch_image({"image": img}) for img in images_pil],
            return_tensors="pt"
        )

        data_dict_qwen2vl["image_sizes"] = image_sizes

        return data_dict_qwen2vl

    def preprocess_multimodal(self, conversations):
        """Optionally wrap each `<image>` placeholder in <im_start>/<im_end> tokens."""
        for sentence in conversations:
            if sentence["role"] == "system":
                continue
            segs = re.split(DEFAULT_IMAGE_TOKEN, sentence["content"])
            if self.use_special_start_end_token:
                sentence["content"] = (DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN).join(segs)
            else:
                sentence["content"] = DEFAULT_IMAGE_TOKEN.join(segs)

        return conversations

    def preprocess_qwen2(
        self,
        conversations,
        tokenizer: PreTrainedTokenizer,
        has_image: bool = False,
        inference: bool = False,
        only_mask_system: bool = False,
    ) -> dict:
        """Tokenize a qwen2-style chat and build loss labels.

        Labels are -100 (IGNORE_INDEX) for everything up to each assistant
        reply (or only for the system prompt when only_mask_system is True).
        Returns {"input_ids": 1-D tensor, "labels": 1-D tensor}.
        """
        conv = types.SimpleNamespace(
            system="You are a helpful assistant.",
            roles=("user", "assistant"),
            version="qwen2",
            offset=0,
            sep="<|im_start|>",
            sep2="<|im_end|>\n",
        )

        # Check system prompt
        assert conversations[0]["role"] == "system"
        # NOTE(review): `== None` should be `is None` (kept verbatim here).
        if conversations[0]["content"] == None:
            conversations[0]["content"] = conv.system # use default system prompt

        # Check conversation sequence (must alternate user/assistant).
        for j, sentence in enumerate(conversations[1:]):
            role = sentence["role"]
            assert role == conv.roles[j % 2], "The conversation sequence is incorrect."

        conversation_str = tokenizer.apply_chat_template(conversations, tokenize=False, add_generation_prompt=inference)

        # Mask targets: split on the <|im_end|>\n separator and process round by round.
        rounds = conversation_str.split(conv.sep2)
        input_ids_ = torch.tensor([], dtype=torch.int64)
        targets_ = torch.tensor([], dtype=torch.int64)
        for i, rou in enumerate(rounds):
            if rou == "":
                continue
            # Re-attach the separator except on the trailing generation prompt.
            if (not inference) or (i < (len(rounds) - 1)):
                rou += conv.sep2
            if has_image:
                cur_input_ids_ = self.tokenizer_image_token(rou, tokenizer, return_tensors='pt')
                input_ids_ = torch.cat([input_ids_, cur_input_ids_], dim=0)
                # mask_len = tokens up to (and including) the role header of this round.
                if only_mask_system:
                    mask_len = len(self.tokenizer_image_token(re.sub(rf'{conv.roles[0]}\n[\s\S]*', f'{conv.roles[0]}:', rou),
                                                             tokenizer))
                else:
                    mask_len = len(self.tokenizer_image_token(re.sub(rf'{conv.roles[1]}\n[\s\S]*', f'{conv.roles[1]}:', rou),
                                                             tokenizer))
                targets_ = torch.cat([targets_, torch.tensor([-100] * mask_len), cur_input_ids_[mask_len:]], dim=0)
            else:
                cur_input_ids_ = tokenizer(rou, return_tensors='pt')["input_ids"][0, :]
                input_ids_ = torch.cat([input_ids_, cur_input_ids_], dim=0)
                mask_len = len(tokenizer(re.sub(rf'{conv.roles[1]}\n[\s\S]*', rf'{conv.roles[1]}:', rou))["input_ids"][:])
                targets_ = torch.cat([targets_, torch.tensor([-100] * mask_len), cur_input_ids_[mask_len:]], dim=0)

        return {"input_ids": input_ids_, "labels": targets_}


    def tokenizer_image_token(
        self,
        prompt,
        tokenizer,
        image_token_index=IMAGE_TOKEN_INDEX,
        return_tensors=None,
    ):
        """Tokenize `prompt`, mapping each `<image>` to `image_token_index`.

        Returns a list of ids, or a long tensor when return_tensors="pt".
        Raises ValueError for prompts longer than SEQ_MAX_LEN characters.
        """
        def split_with_token(string, token):
            # Like str.split but keeps the token itself between the chunks.
            result = string.split(token)
            for i in range(len(result) - 1):
                result.insert(i * 2 + 1, token)
            return result

        if len(prompt) > SEQ_MAX_LEN:
            raise ValueError("sequence is too long !!!")

        prompt_chunks = split_with_token(prompt, DEFAULT_IMAGE_TOKEN)
        # Seed with BOS (if the tokenizer has one) and skip duplicate BOS from chunks.
        input_ids, offset = ([tokenizer.bos_token_id], 1) if getattr(tokenizer,'bos_token',None) else ([], 0)
        token2index = {DEFAULT_IMAGE_TOKEN: image_token_index}
        for chunk in prompt_chunks:
            if chunk in token2index:
                input_ids.append(token2index[chunk])
            else:
                chunk_ids = tokenizer(chunk).input_ids
                if chunk_ids[0] != getattr(tokenizer,'bos_token_id', None):
                    offset = 0
                input_ids.extend(chunk_ids[offset:])

        if return_tensors is not None:
            if return_tensors == "pt":
                return torch.tensor(input_ids, dtype=torch.long)
            raise ValueError(f"Unsupported tensor type: {return_tensors}")
        return input_ids


    def __call__(self, messages, inference=True, **kwargs) -> BatchFeature:
        """Build a batch (batch_size=1) from {"conversations": ..., "images": ...}.

        Returns a BatchFeature with input_ids, labels, SigLIP "images", and the
        Qwen2-VL pixel/grid tensors.
        """
        max_pixels=kwargs.get("max_pixels", self.max_pixels)
        min_pixels=kwargs.get("min_pixels", self.min_pixels)
        if max_pixels is not None:
            self.qwen2vl_image_processor.max_pixels = max_pixels
        if min_pixels is not None:
            self.qwen2vl_image_processor.min_pixels = min_pixels

        # Deal with images: fall back to the built-in black image when none given.
        if "images" not in messages or not messages["images"] or not messages["images"][0]:
            images = [self.black_img]
        elif type(messages["images"]) == str:
            images = [messages["images"]]
        else:
            images = messages["images"]

        # Deal with conversations
        conversations = messages["conversations"]
        if conversations[0]["role"] != "system":
            conversations = [{"role":"system", "content": None}] + conversations # dummy system prompt

        # Insert special token `<image>` at the head of the first user turn if absent.
        assert conversations[1]["role"] == "user"
        if images and "<image>" not in conversations[1]["content"]:
            image_token = " ".join(["<image>"] * len(images))
            conversations[1]["content"] = f"{image_token}\n{conversations[1]['content']}"

        # The last message should be assistant if inference=True
        if inference:
            assert conversations[-1]["role"] == "user", "the last message should be assistant if inference=True"

        # Image preprocess (SigLIP branch skipped in only_navit mode).
        # NOTE(review): "precessed" is a typo for "processed" (local name only).
        if self.only_navit:
            precessed_images_siglip = None
        else:
            precessed_images_siglip = self.preprocess_images_siglip(images)
        processed_data_dict_qwen2vl = self.preprocess_images_qwen2vl(images)
        source = self.preprocess_multimodal(conversations)
        data_dict = self.preprocess_qwen2(source, self.tokenizer, has_image=True, only_mask_system=False, inference=inference)

        # Construct batch data
        data_dict["input_ids"] = data_dict["input_ids"].unsqueeze(0) # batch_size = 1
        data_dict["labels"] = data_dict["labels"].unsqueeze(0)
        data_dict["images"] = [precessed_images_siglip]

        return BatchFeature(data={**data_dict, **processed_data_dict_qwen2vl})

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)


    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)
utils.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from io import BytesIO
3
+ import base64
4
+ import math
5
+ import ast
6
+ import re
7
+ import torch
8
+ from transformers import StoppingCriteria
9
+
10
IGNORE_INDEX = -100  # label value ignored by the cross-entropy loss
IMAGE_TOKEN_INDEX = -200  # placeholder id for <image>; replaced by vision features downstream
GANDALF_TOKEN_INDEX = -300  # placeholder id; semantics not visible in this file — TODO confirm
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
# NOTE(review): BOS is defined identical to EOS ("</s>") — looks suspicious; confirm intended.
DEFAULT_BOS_TOKEN = "</s>"
DEFAULT_UNK_TOKEN = "<unk>"
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
DEFAULT_VIDEO_TOKEN = "<video>"
DEFAULT_VIDEO_FRAME_TOKEN = "<vi_frame>"
DEFAULT_VI_START_TOKEN = "<vi_start>"
DEFAULT_VI_END_TOKEN = "<vi_end>"
DEFAULT_EOC_TOKEN = "<eoc>"
COR_START_TOKEN = "<cor>"
# NOTE(review): "\c" is not a valid escape (Python keeps the backslash and warns);
# "</cor>" was probably intended — confirm with checkpoint vocab before changing.
COR_END_TOKEN = "<\cor>"
SEQ_MAX_LEN = 50000  # hard character cap enforced in tokenizer_image_token
# Raw PNG bytes of a tiny all-black image; used as a stand-in when a sample has no images.
BLACK_IMG_ENV = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x03\x00\x00\x00\x03\x08\x02\x00\x00\x00\xd9J"\xe8\x00\x00\x00\x12IDAT\x08\x1dcd\x80\x01F\x06\x18`d\x80\x01\x00\x00Z\x00\x04we\x03N\x00\x00\x00\x00IEND\xaeB`\x82'
30
+
31
+
32
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.

    Args:
        image_size (tuple): The size of the input image in the format (width, height).
        grid_pinpoints (str): A string representation of a list of possible resolutions.
        patch_size (int): The size of each image patch.

    Returns:
        tuple: The shape of the image patch grid in the format (width, height).
    """
    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
        assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
        # A string like "(1x1),...,(3x3)" encodes a range of grid shapes:
        # expand it into every (rows, cols) pair between the first and last tuple.
        grids = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
        lo = tuple(map(int, grids[0]))
        hi = tuple(map(int, grids[-1]))
        # Scale each grid shape to pixel resolutions.
        grid_pinpoints = [
            [rows * patch_size, cols * patch_size]
            for rows in range(lo[0], hi[0] + 1)
            for cols in range(lo[1], hi[1] + 1)
        ]

    if type(grid_pinpoints) is list:
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)

    width, height = select_best_resolution(image_size, possible_resolutions)
    return width // patch_size, height // patch_size
64
+
65
def select_best_resolution(original_size, possible_resolutions):
    """Pick the candidate resolution that best fits the original image.

    The winner maximises the effective (non-upscaled) area after an
    aspect-preserving fit; ties are broken by the least wasted (padded) area.

    Args:
        original_size (tuple): (width, height) of the source image.
        possible_resolutions (list): candidate (width, height) pairs.

    Returns:
        tuple: the best (width, height) candidate.
    """
    src_w, src_h = original_size
    best_fit = None
    best_effective = 0
    best_wasted = float("inf")

    for cand_w, cand_h in possible_resolutions:
        # Aspect-preserving scale that fits the source inside the candidate.
        scale = min(cand_w / src_w, cand_h / src_h)
        fit_w, fit_h = int(src_w * scale), int(src_h * scale)

        # Effective area is capped at the source area (upscaling adds nothing).
        effective = min(fit_w * fit_h, src_w * src_h)
        wasted = (cand_w * cand_h) - effective

        if effective > best_effective or (
            effective == best_effective and wasted < best_wasted
        ):
            best_effective = effective
            best_wasted = wasted
            best_fit = (cand_w, cand_h)

    return best_fit
98
+
99
+
100
def unpad_image(tensor, original_size):
    """Crop away the letterbox padding from a resized CxHxW image tensor.

    Args:
        tensor (torch.Tensor): the padded image tensor, CxHxW.
        original_size (tuple): (width, height) of the image before padding.

    Returns:
        torch.Tensor: the tensor with the padded rows/columns sliced off.
    """
    orig_w, orig_h = original_size
    cur_h, cur_w = tensor.shape[1:]

    orig_ratio = orig_w / orig_h
    cur_ratio = cur_w / cur_h

    if orig_ratio > cur_ratio:
        # Source is wider than the canvas: padding was added top/bottom.
        scale = cur_w / orig_w
        content_h = int(orig_h * scale)
        pad = (cur_h - content_h) // 2
        return tensor[:, pad: cur_h - pad, :]

    # Otherwise padding was added left/right.
    scale = cur_h / orig_h
    content_w = int(orig_w * scale)
    pad = (cur_w - content_w) // 2
    return tensor[:, :, pad: cur_w - pad]
133
+
134
+
135
def process_anyres_image(image, processor, grid_pinpoints):
    """
    Process an image with variable resolutions.

    Args:
        image (PIL.Image.Image): The input image to be processed.
        processor: The image processor object.
        grid_pinpoints (str): A string representation of a list of possible resolutions.

    Returns:
        torch.Tensor: A tensor containing the processed image patches.
    """
    # Convert grid_pinpoints from string to list
    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
        try:
            patch_size = processor.size["height"]
        except Exception:
            patch_size = processor.size["shortest_edge"]
        assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
        # Use regex to extract the range from the input string
        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
        range_start = tuple(map(int, matches[0]))
        range_end = tuple(map(int, matches[-1]))
        # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
        grid_pinpoints = [
            (i, j)
            for i in range(range_start[0], range_end[0] + 1)
            for j in range(range_start[1], range_end[1] + 1)
        ]
        # Multiply all elements by patch_size
        grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]

    if type(grid_pinpoints) is list:
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    # Letterbox the image onto the best-fitting candidate canvas, then tile it.
    best_resolution = select_best_resolution(image.size, possible_resolutions)
    image_padded = resize_and_pad_image(image, best_resolution)

    patches = divide_to_patches(image_padded, processor.size["height"])

    # FIXME: this seems to be a bug that it resizes instead of pad.
    # but to keep it consistent with previous, i will keep it as it is
    # TODO: uncomment below to ablate with the padding
    if isinstance(processor.size, dict):
        shortest_edge = processor.size["height"]
    else:
        shortest_edge = min(processor.size)
    # Global view: the whole image squashed to one processor-sized square.
    image_original_resize = image.resize((shortest_edge, shortest_edge))
    # image_padded_square = expand2square(image, tuple(int(x*255) for x in processor.image_mean))

    # Global view first, then the grid crops.
    image_patches = [image_original_resize] + patches
    image_patches = [
        processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0]
        for image_patch in image_patches
    ]
    # return torch.stack(image_patches, dim=0)
    return image_patches
193
+
194
def resize_and_pad_image(image, target_resolution):
    """Fit `image` into `target_resolution` and centre it on a black canvas.

    Aspect ratio is preserved; the leftover canvas area is black.

    Args:
        image (PIL.Image.Image): the input image.
        target_resolution (tuple): (width, height) of the output canvas.

    Returns:
        PIL.Image.Image: the letterboxed image.
    """
    src_w, src_h = image.size
    dst_w, dst_h = target_resolution

    scale_w = dst_w / src_w
    scale_h = dst_h / src_h

    if scale_w < scale_h:
        # Width is the binding constraint: it fills the canvas completely.
        fit_w = dst_w
        fit_h = min(math.ceil(src_h * scale_w), dst_h)
    else:
        # Height is the binding constraint.
        fit_h = dst_h
        fit_w = min(math.ceil(src_w * scale_h), dst_w)

    resized = image.resize((fit_w, fit_h))

    canvas = Image.new("RGB", (dst_w, dst_h), (0, 0, 0))
    offset = ((dst_w - fit_w) // 2, (dst_h - fit_h) // 2)
    canvas.paste(resized, offset)

    return canvas
231
+
232
def divide_to_patches(image, patch_size):
    """Cut `image` into a row-major list of patch_size x patch_size tiles.

    Args:
        image (PIL.Image.Image): the input image (dimensions presumably
            multiples of patch_size — otherwise edge crops extend past the
            border; TODO confirm with callers).
        patch_size (int): side length of each square tile.

    Returns:
        list: PIL.Image.Image tiles, top-to-bottom then left-to-right.
    """
    width, height = image.size
    return [
        image.crop((left, top, left + patch_size, top + patch_size))
        for top in range(0, height, patch_size)
        for left in range(0, width, patch_size)
    ]