AndyZijianZhang committed on
Commit 3d29bb3 · verified · 1 Parent(s): 9e8ab54

Upload files with `vila-upload`.


Upload config.json
Upload processing_vila.py
Upload processor_config.json
Upload configuration_vila.py
Upload tokenizer_config.json
Upload generation_config.json
Upload chat_template.jinja
Upload modeling_vila.py

chat_template.jinja ADDED
@@ -0,0 +1 @@
+{% for message in messages %}{% if loop.first and message['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{ '<|im_start|>' + message['role'] + '\n' }}{% if message['content'] is string %}{{ message['content'] + '<|im_end|>\n' }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{{ '<image>' }}{% elif content['type'] == 'video' or 'video' in content %}{{ '<video>' }}{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}
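The template follows the ChatML layout used by Qwen2: it injects a default system turn when none is given, wraps each turn in <|im_start|>/<|im_end|>, and renders structured content items as <image>/<video> placeholders. A minimal rendering sketch, assuming a placeholder repo id (not part of this commit):

from transformers import AutoTokenizer

# Placeholder repo id; any checkpoint shipping this chat_template.jinja behaves the same.
tokenizer = AutoTokenizer.from_pretrained("<this-repo-id>", trust_remote_code=True)

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Describe this image."},
        ],
    },
]

# Expected output: a default system turn, the user turn with an <image> placeholder,
# then the opening of the assistant turn.
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
print(prompt)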
config.json CHANGED
@@ -10,7 +10,6 @@
     "AutoModelForVision2Seq": "modeling_vila.VILAForConditionalGeneration"
   },
   "hidden_size": 5120,
-  "image_end_token_id": 198,
   "image_token_id": 151666,
   "mm_hidden_size": 1152,
   "mm_projector_type": "mlp_downsample_3x3_fix",
@@ -45,7 +44,7 @@
     "vocab_size": 151670
   },
   "torch_dtype": "bfloat16",
-  "transformers_version": "4.51.3",
+  "transformers_version": "4.52.3",
   "video_token_id": 151670,
   "vision_config": {
     "architectures": [
configuration_vila.py CHANGED
@@ -21,7 +21,6 @@ class VILAConfig(PretrainedConfig):
     # Model configuration.
     hidden_size: int
     image_token_id: int
-    image_end_token_id: int
     mm_hidden_size: int
     mm_projector_type: str
     mm_vision_select_feature: str
@@ -30,17 +29,16 @@ class VILAConfig(PretrainedConfig):
 
     def __init__(
         self,
-        *,
         text_config: Optional[Dict[str, Any]] = None,
         vision_config: Optional[Dict[str, Any]] = None,
-        hidden_size: Optional[int] = None,
-        image_token_id: Optional[int] = None,
-        image_end_token_id: Optional[int] = None,
-        mm_hidden_size: Optional[int] = None,
-        mm_projector_type: Optional[str] = None,
-        mm_vision_select_feature: Optional[str] = None,
-        mm_vision_select_layer: Optional[int] = None,
-        video_token_id: Optional[int] = None,
+        *,
+        hidden_size: int = 1536,
+        image_token_id: int = 151649,
+        mm_hidden_size: int = 1152,
+        mm_projector_type: str = "mlp_downsample_3x3_fix",
+        mm_vision_select_feature: str = "cls_patch",
+        mm_vision_select_layer: int = -2,
+        video_token_id: int = 151650,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -48,14 +46,10 @@ class VILAConfig(PretrainedConfig):
         self.text_config = Qwen2Config(**text_config) if text_config else Qwen2Config()
         self.vision_config = SiglipVisionConfig(**vision_config) if vision_config else SiglipVisionConfig()
 
-        # By default, we use settings from NVILA-Lite.
-        self.hidden_size = hidden_size if hidden_size is not None else 1536
-        self.image_token_id = image_token_id if image_token_id is not None else 151649
-        self.image_end_token_id = image_end_token_id if image_end_token_id is not None else 198  # "\n"
-        self.mm_hidden_size = mm_hidden_size if mm_hidden_size is not None else 1152
-        self.mm_projector_type = mm_projector_type if mm_projector_type is not None else "mlp_downsample_3x3_fix"
-        self.mm_vision_select_feature = (
-            mm_vision_select_feature if mm_vision_select_feature is not None else "cls_patch"
-        )
-        self.mm_vision_select_layer = mm_vision_select_layer if mm_vision_select_layer is not None else -2
-        self.video_token_id = video_token_id if video_token_id is not None else 151650
+        self.hidden_size = hidden_size
+        self.image_token_id = image_token_id
+        self.mm_hidden_size = mm_hidden_size
+        self.mm_projector_type = mm_projector_type
+        self.mm_vision_select_feature = mm_vision_select_feature
+        self.mm_vision_select_layer = mm_vision_select_layer
+        self.video_token_id = video_token_id
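A quick sketch of the resulting constructor behaviour, assuming the file is importable as `configuration_vila`: the NVILA-Lite defaults now live in the keyword-only signature instead of `... if ... is not None else ...` fallbacks, and `image_end_token_id` is gone.

from configuration_vila import VILAConfig

config = VILAConfig()  # model options are keyword-only, with NVILA-Lite defaults
assert config.hidden_size == 1536
assert config.mm_projector_type == "mlp_downsample_3x3_fix"
assert config.mm_vision_select_layer == -2

# text_config / vision_config stay positional; everything else must be passed by keyword.
config = VILAConfig(hidden_size=5120, image_token_id=151666, video_token_id=151670)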
generation_config.json CHANGED
@@ -3,5 +3,5 @@
   "bos_token_id": 151643,
   "eos_token_id": 151645,
   "pad_token_id": 151643,
-  "transformers_version": "4.51.3"
+  "transformers_version": "4.52.3"
 }
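For orientation, a loading sketch that goes through the `auto_map` entries declared in config.json; the repo id is a placeholder, and transformers 4.52.3 (the version now recorded in both configs) is assumed:

import torch
from transformers import AutoConfig, AutoModelForVision2Seq

# trust_remote_code pulls configuration_vila.py / modeling_vila.py from this repo.
config = AutoConfig.from_pretrained("<this-repo-id>", trust_remote_code=True)
print(config.image_token_id, config.video_token_id)  # 151666, 151670; image_end_token_id is gone

model = AutoModelForVision2Seq.from_pretrained(
    "<this-repo-id>",
    torch_dtype=torch.bfloat16,  # matches "torch_dtype": "bfloat16" in config.json
    trust_remote_code=True,
)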
modeling_vila.py CHANGED
@@ -1,8 +1,10 @@
-from typing import List, Optional, Type
+from typing import List, Optional, Type, Union
 
 import torch
 import torch.nn as nn
-from torch import Tensor
+import torch.nn.functional as F
+from torch import LongTensor, Tensor
+from transformers.cache_utils import Cache
 from transformers.configuration_utils import PretrainedConfig
 from transformers.generation.utils import GenerationMixin
 from transformers.modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast
@@ -13,68 +15,34 @@ from transformers.models.siglip.modeling_siglip import SiglipVisionModel
 from .configuration_vila import VILAConfig
 
 
-class DownSampleBlock(nn.Module):
-    @staticmethod
-    def flat_square(x: Tensor) -> Tensor:
-        n, w, h, c = x.size()
-        if w % 2 == 1:
-            x = torch.concat([x, torch.zeros((n, 1, h, c), device=x.device, dtype=x.dtype)], dim=1).contiguous()
-            n, w, h, c = x.size()
-        if h % 2 == 1:
-            x = torch.concat([x, torch.zeros((n, w, 1, c), device=x.device, dtype=x.dtype)], dim=2).contiguous()
-            n, w, h, c = x.size()
-        x = x.contiguous()
-        x = x.view(n, w, int(h / 2), int(c * 2))
-        x = x.permute(0, 2, 1, 3).contiguous()
-        x = x.view(n, int(h / 2), int(w / 2), int(c * 4))
-        x = x.permute(0, 2, 1, 3).contiguous()
-        return x
-
-    def forward(self, x: Tensor) -> Tensor:
-        vit_embeds = x
-        h = w = int(vit_embeds.shape[1] ** 0.5)
-        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
-        vit_embeds = self.flat_square(vit_embeds)
-        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
-        return vit_embeds
-
-
 class DownSample3x3BlockFix(nn.Module):
-    @staticmethod
-    def flat_square_3x3(x: Tensor) -> Tensor:
-        n, w, h, c = x.size()
-        if w % 3 != 0:
-            x = torch.concat(
-                [
-                    x,
-                    torch.zeros((n, 3 - (w % 3), h, c), device=x.device, dtype=x.dtype),
-                ],
-                dim=1,
-            ).contiguous()
-            n, w, h, c = x.size()
-        x = x.contiguous()
-        if h % 3 != 0:
-            x = torch.concat(
-                [
-                    x,
-                    torch.zeros((n, w, 3 - (h % 3), c), device=x.device, dtype=x.dtype),
-                ],
-                dim=2,
-            ).contiguous()
-            n, w, h, c = x.size()
-        x = x.view(n, w, int(h / 3), int(c * 3))
-        x = x.permute(0, 2, 1, 3).contiguous()
-        x = x.view(n, int(h / 3), int(w / 3), int(c * 9))
-        x = x.permute(0, 2, 1, 3).contiguous()
-        return x
-
     def forward(self, x: Tensor) -> Tensor:
-        vit_embeds = x
-        h = w = int(vit_embeds.shape[1] ** 0.5)
-        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
-        vit_embeds = self.flat_square_3x3(vit_embeds)
-        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
-        return vit_embeds
+        """
+        Args:
+            x: The input tensor of shape (batch_size, sequence_length, mm_hidden_size).
+
+        Returns:
+            The output tensor of shape (batch_size, image_pad_len, mm_hidden_size * 9).
+        """
+
+        batch_size, sequence_length, hidden_size = x.shape
+
+        feat_size = int(sequence_length**0.5)
+        if feat_size**2 != sequence_length:
+            raise ValueError(f"Cannot take square root: sequence_length {sequence_length} is not a perfect square")
+
+        features = x.reshape(batch_size, feat_size, feat_size, hidden_size)
+
+        pad_after = (3 - feat_size % 3) % 3
+        if pad_after > 0:
+            features = F.pad(features, (0, 0, 0, pad_after, 0, pad_after))
+            feat_size = feat_size + pad_after
+
+        features = features.reshape(batch_size, feat_size // 3, 3, feat_size // 3, 3, hidden_size)
+        features = features.permute(0, 1, 3, 2, 4, 5).contiguous()
+        features = features.reshape(batch_size, -1, 9 * hidden_size)
+
+        return features
 
 
 class MultimodalProjector(nn.Module):
@@ -88,37 +56,24 @@ class MultimodalProjector(nn.Module):
     ):
         super().__init__(*args, **kwargs)
 
-        match config.mm_projector_type:
-            case "linear":
-                self.layers = nn.Sequential(
-                    nn.Linear(config.vision_config.hidden_size, config.hidden_size),
-                )
-            case "mlp_downsample":
-                self.layers = nn.Sequential(
-                    DownSampleBlock(),
-                    nn.LayerNorm(config.mm_hidden_size * 4),
-                    nn.Linear(config.mm_hidden_size * 4, config.hidden_size),
-                    nn.GELU(),
-                    nn.Linear(config.hidden_size, config.hidden_size),
-                )
-            case "mlp_downsample_3x3_fix":
-                self.layers = nn.Sequential(
-                    DownSample3x3BlockFix(),
-                    nn.LayerNorm(config.mm_hidden_size * 9),
-                    nn.Linear(
-                        config.mm_hidden_size * 9,
-                        config.mm_hidden_size * 3,
-                    ),
-                    nn.GELU(),
-                    nn.LayerNorm(config.vision_config.hidden_size * 3),
-                    nn.Linear(config.vision_config.hidden_size * 3, config.hidden_size),
-                    nn.GELU(),
-                    nn.Linear(config.hidden_size, config.hidden_size),
-                )
-            case _:
-                raise NotImplementedError(f"mm_projector_type={config.mm_projector_type} not implemented.")
-
-        self.layers.to(dtype=config.torch_dtype)
+        if config.mm_projector_type == "mlp_downsample_3x3_fix":
+            self.layers = nn.Sequential(
+                DownSample3x3BlockFix(),
+                nn.LayerNorm(config.mm_hidden_size * 9),
+                nn.Linear(
+                    config.mm_hidden_size * 9,
+                    config.mm_hidden_size * 3,
+                ),
+                nn.GELU(),
+                nn.LayerNorm(config.vision_config.hidden_size * 3),
+                nn.Linear(config.vision_config.hidden_size * 3, config.hidden_size),
+                nn.GELU(),
+                nn.Linear(config.hidden_size, config.hidden_size),
+            )
+        else:
+            raise NotImplementedError(f"Unsupported mm_projector_type: {config.mm_projector_type}")
+
+        self.layers.type(config.torch_dtype)
 
     @property
     def device(self) -> torch.device:
@@ -129,7 +84,15 @@ class MultimodalProjector(nn.Module):
         return next(self.parameters()).dtype
 
     def forward(self, x: Tensor) -> Tensor:
-        return self.layers(x)
+        """
+        Args:
+            x: The input tensor of shape (batch_size, sequence_length, mm_hidden_size).
+
+        Returns:
+            The output tensor of shape (batch_size, image_pad_len, hidden_size).
+        """
+
+        return self.layers(x.to(device=self.device, dtype=self.dtype))
 
 
 class VILAForConditionalGeneration(PreTrainedModel, GenerationMixin):
@@ -156,9 +119,9 @@ class VILAForConditionalGeneration(PreTrainedModel, GenerationMixin):
     ):
         super().__init__(config, *args, **kwargs)
 
-        self.llm = Qwen2ForCausalLM(config.text_config, *args, **kwargs)
+        self.llm = Qwen2ForCausalLM._from_config(config.text_config, *args, **kwargs)
         self.mm_projector = MultimodalProjector(config)
-        self.vision_tower = SiglipVisionModel(config.vision_config, *args, **kwargs)
+        self.vision_tower = SiglipVisionModel._from_config(config.vision_config, *args, **kwargs)
 
         self.post_init()
 
@@ -168,36 +131,29 @@ class VILAForConditionalGeneration(PreTrainedModel, GenerationMixin):
         attention_mask: Optional[Tensor] = None,
         input_ids: Optional[Tensor] = None,
         inputs_embeds: Optional[Tensor] = None,
+        past_key_values: Optional[Cache] = None,
         pixel_values: Optional[Tensor] = None,
+        position_ids: Optional[LongTensor] = None,
+        logits_to_keep: Union[int, Tensor] = 0,
         **kwargs,
     ) -> CausalLMOutputWithPast:
-        # Vision info is only used for prefilling.
-        if kwargs.get("past_key_values", None) is not None:
-            pixel_values = None
-
-        inputs_embeds = inputs_embeds.to(dtype=self.dtype) if inputs_embeds is not None else None
-        pixel_values = pixel_values.to(dtype=self.dtype) if pixel_values is not None else None
-
-        if inputs_embeds is None:
-            assert input_ids is not None
-
-            inputs_embeds = self._embed(input_ids, pixel_values)
-        else:
-            assert input_ids is None
-            assert pixel_values is None
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds.")
+
+        if past_key_values is None:  # Prefill
+            if input_ids is not None:
+                inputs_embeds = self._embed(input_ids, pixel_values)
+                input_ids = None
 
         outputs = self.llm.__call__(
-            inputs_embeds=inputs_embeds.to(
-                device=self.llm.device,
-                dtype=self.llm.dtype,
-            ),
-            attention_mask=(
-                attention_mask.to(
-                    device=self.llm.device,
-                )
-                if attention_mask is not None
-                else None
+            attention_mask=(attention_mask.to(device=self.llm.device) if attention_mask is not None else None),
+            input_ids=(input_ids.to(device=self.llm.device) if input_ids is not None else None),
+            inputs_embeds=(
+                inputs_embeds.to(device=self.llm.device, dtype=self.llm.dtype) if inputs_embeds is not None else None
             ),
+            past_key_values=past_key_values,
+            position_ids=(position_ids.to(device=self.llm.device) if position_ids is not None else None),
+            logits_to_keep=logits_to_keep,
             **kwargs,
         )
 
@@ -221,8 +177,6 @@ class VILAForConditionalGeneration(PreTrainedModel, GenerationMixin):
             The embedding of the input ids and pixel values.
         """
 
-        # Video tokens should be removed during preprocessing, so there must not be any video
-        # tokens in the input ids.
         if torch.any(input_ids == self.config.video_token_id):
             raise ValueError("Video token ids should not be present in the input ids.")
 
@@ -233,56 +187,35 @@ class VILAForConditionalGeneration(PreTrainedModel, GenerationMixin):
         if pixel_values is None:
             return text_embedding
 
-        image_features: BaseModelOutputWithPooling = self.vision_tower.__call__(
-            pixel_values.to(
-                device=self.vision_tower.device,
-                dtype=self.vision_tower.dtype,
-            ),
+        vision_tower_output: BaseModelOutputWithPooling = self.vision_tower.__call__(
+            pixel_values.to(device=self.vision_tower.device, dtype=self.vision_tower.dtype),
             output_hidden_states=True,
         )
-        assert image_features.hidden_states is not None
-
-        # Select image feature.
-        selected_layer_output = image_features.hidden_states[self.config.mm_vision_select_layer]
-        match self.config.mm_vision_select_feature:
-            case "cls_patch":
-                selected_feature = selected_layer_output
-            case _:
-                raise NotImplementedError(
-                    f"mm_vision_select_feature={self.config.mm_vision_select_feature} not implemented."
-                )
 
-        # TODO: Support dynamic_s2.
+        mm_projector_input = self._vision_tower_output_to_mm_projector_input(vision_tower_output)
 
         image_embedding: Tensor = self.mm_projector.__call__(
-            selected_feature.to(
-                device=self.mm_projector.device,
-                dtype=self.mm_projector.dtype,
-            )
+            mm_projector_input.to(device=self.mm_projector.device, dtype=self.mm_projector.dtype)
         )
 
-        # Append image end token to every image embedding.
-        image_end_token_embedding: Tensor = self.llm.get_input_embeddings().__call__(
-            torch.tensor(
-                self.config.image_end_token_id,
-                device=text_embedding.device,
-                dtype=torch.long,
-            ).view(1, -1)
-        )  # Shape: (1, 1, dim_feature)
-        image_end_token_embedding = image_end_token_embedding.expand(
-            image_embedding.shape[0], 1, -1
-        )  # Shape: (n_images, 1, dim_feature)
-        image_embedding = torch.concat(
-            [
-                image_embedding.to(device=text_embedding.device),
-                image_end_token_embedding,
-            ],
-            dim=1,
+        image_embedding = image_embedding.reshape(-1, image_embedding.shape[-1])
+
+        text_embedding.masked_scatter_(
+            image_token_mask.to(device=text_embedding.device, dtype=torch.bool).unsqueeze(-1),
+            image_embedding.to(device=text_embedding.device, dtype=text_embedding.dtype).flatten(),
         )
 
-        n_images, n_feature, dim_feature = image_embedding.shape
-        image_embedding = image_embedding.view(n_images * n_feature, dim_feature)
-
-        text_embedding[image_token_mask.to(device=text_embedding.device)] = image_embedding
-
         return text_embedding
+
+    def _vision_tower_output_to_mm_projector_input(
+        self,
+        vision_tower_output: BaseModelOutputWithPooling,
+    ) -> Tensor:
+        assert vision_tower_output.hidden_states is not None
+
+        selected_layer_hidden_states = vision_tower_output.hidden_states[self.config.mm_vision_select_layer]
+
+        if self.config.mm_vision_select_feature == "cls_patch":
+            return selected_layer_hidden_states
+        else:
+            raise NotImplementedError(f"Unsupported mm_vision_select_feature: {self.config.mm_vision_select_feature}")
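A standalone sketch (not part of the commit) of the 3x3 spatial merge that `DownSample3x3BlockFix.forward` now performs, using this checkpoint's dimensions: a 32x32 SigLIP feature map (1024 tokens, mm_hidden_size 1152) is padded to 33x33 and merged into 11x11 = 121 tokens of size 9 * 1152, which matches the processor's `image_pad_len` of 121.

import torch
import torch.nn.functional as F

x = torch.randn(1, 1024, 1152)     # (batch, seq, mm_hidden_size) from the vision tower
b, s, c = x.shape
side = int(s**0.5)                 # 32; forward() rejects non-square sequence lengths
feats = x.reshape(b, side, side, c)

pad_after = (3 - side % 3) % 3     # 1: pad the grid from 32x32 to 33x33
if pad_after > 0:
    feats = F.pad(feats, (0, 0, 0, pad_after, 0, pad_after))
    side += pad_after

feats = feats.reshape(b, side // 3, 3, side // 3, 3, c)
feats = feats.permute(0, 1, 3, 2, 4, 5).contiguous()
feats = feats.reshape(b, -1, 9 * c)

print(feats.shape)                 # torch.Size([1, 121, 10368])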
processing_vila.py CHANGED
@@ -1,19 +1,22 @@
+import uuid
 from typing import List, Optional, Tuple, cast
 
 import transformers.image_transforms as image_transforms
 import transformers.image_utils as image_utils
 import transformers.utils.logging
+import transformers.video_utils as video_utils
 from PIL.Image import Image
 from torch import Tensor
 from transformers.feature_extraction_utils import BatchFeature
 from transformers.image_processing_utils import BaseImageProcessor
 from transformers.image_processing_utils_fast import BaseImageProcessorFast
-from transformers.image_utils import ImageInput, VideoInput
+from transformers.image_utils import ImageInput
 from transformers.models.siglip.image_processing_siglip import SiglipImageProcessor
 from transformers.models.siglip.image_processing_siglip_fast import SiglipImageProcessorFast
-from transformers.processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
+from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
 from transformers.tokenization_utils import PreTrainedTokenizer
-from transformers.tokenization_utils_base import PreTrainedTokenizerBase, TextInput
+from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase, TextInput
+from transformers.video_utils import VideoInput
 
 logger = transformers.utils.logging.get_logger(__name__)
 
@@ -41,6 +44,7 @@ class VILAProcessor(ProcessorMixin):
         "image_pad_len",
         "max_tiles",
         "min_tiles",
+        "video_max_tiles",
     ]
 
     # Attributes.
@@ -51,15 +55,17 @@ class VILAProcessor(ProcessorMixin):
     image_pad_len: int
     max_tiles: int
     min_tiles: int
+    video_max_tiles: int
 
     def __init__(
         self,
         image_processor: BaseImageProcessor,
         tokenizer: PreTrainedTokenizer,
         *,
-        image_pad_len: Optional[int] = None,
-        max_tiles: Optional[int] = None,
-        min_tiles: Optional[int] = None,
+        image_pad_len: int = 121,
+        max_tiles: int = 12,
+        min_tiles: int = 1,
+        video_max_tiles: int = 1,
         **kwargs,
     ):
         super().__init__(
@@ -68,17 +74,17 @@ class VILAProcessor(ProcessorMixin):
             **kwargs,
         )
 
-        self.image_pad_len = image_pad_len if image_pad_len is not None else 122
-        self.max_tiles = max_tiles if max_tiles is not None else 12
-        self.min_tiles = min_tiles if min_tiles is not None else 1
+        self.image_pad_len = image_pad_len
+        self.max_tiles = max_tiles
+        self.min_tiles = min_tiles
+        self.video_max_tiles = video_max_tiles
 
     def __call__(
         self,
         text: TextInput | List[TextInput],
         images: Optional[ImageInput] = None,
         videos: Optional[VideoInput] = None,
-        audio: None = None,
-        **kwargs: Unpack[VILAProcessorProcessingKwargs],
+        **kwargs: Unpack[ProcessingKwargs],
     ) -> VILAProcessorOutput:
         """Preprocesses inputs for VILA.
 
@@ -86,7 +92,6 @@ class VILAProcessor(ProcessorMixin):
             text: The text to be processed.
             images: The images to be processed.
             videos: The videos to be processed.
-            audio: Not available.
             **kwargs: Additional arguments for processing.
 
         Returns:
@@ -99,39 +104,33 @@ class VILAProcessor(ProcessorMixin):
             **kwargs,
         )
 
-        text, images, videos = self._prepare_inputs(
+        normalized_text, normalized_images, normalized_videos = self._normalize_inputs(
             text=text,
             images=images,
             videos=videos,
         )
 
-        # Process videos.
-        text, images, video_flags = self._treat_videos_as_image_seqs(
-            text=text,
-            images=images,
-            videos=videos,
-        )
-
-        # Process images.
-        image_inputs, num_cropped_images = self._process_images(
-            images=images,
-            **merged_kwargs["images_kwargs"],
-        )
-
-        # Process text.
-        text = self._pad_image_tokens_by_num_crops(
-            text,
-            num_cropped_images=num_cropped_images,
-            video_flags=video_flags,
-        )
-
-        text = self._pad_image_tokens_by_num_embeddings(text)
+        preprocessed_text, preprocessed_media_tiles = self._preprocess_inputs(
+            text=normalized_text,
+            images=normalized_images,
+            videos=normalized_videos,
+        )
 
         text_inputs = self.tokenizer.__call__(
-            text,
+            preprocessed_text,
             **merged_kwargs["text_kwargs"],
         )
 
+        if len(preprocessed_media_tiles) > 0:
+            image_inputs = self.image_processor.__call__(
+                preprocessed_media_tiles,
+                **merged_kwargs["images_kwargs"],
+            )
+        else:
+            image_inputs = BatchFeature()
+
+        text_inputs = self._replace_image_tile_suffix(text_inputs)
+
         return VILAProcessorOutput(
             data={
                 **text_inputs,
@@ -139,99 +138,144 @@ class VILAProcessor(ProcessorMixin):
             }
         )
 
-    def _crop_image(
-        self,
-        image: Image,
-    ) -> List[Image]:
-        """Crops the image into multiple tiles.
-
-        Args:
-            image: The image to be cropped.
-
-        Returns:
-            The cropped images.
-        """
-
-        # TODO: Support more image processors.
-        if not isinstance(self.image_processor, (SiglipImageProcessor, SiglipImageProcessorFast)):
-            raise NotImplementedError
-
-        assert self.image_processor.size["height"] == self.image_processor.size["width"]
-        cropped_size = self.image_processor.size["height"]
-
-        cropped_images: List[Image] = dynamic_preprocess(
-            image,
-            min_num=self.min_tiles,
-            max_num=self.max_tiles,
-            image_size=cropped_size,
-        )
-
-        return cropped_images
-
-    def _pad_image_tokens_by_num_crops(
-        self,
-        text: List[str],
-        *,
-        num_cropped_images: List[int],
-        video_flags: List[bool],
-    ) -> List[str]:
-        """Pads each \\<image> to num_cropped_images of "\\<image>\\n" for images and "\\<video>" for videos.
-
-        Args:
-            text: The text to be padded.
-            num_cropped_images: The number of cropped images for each image token.
-            video_flags: A list of flags indicating whether the num_cropped_images item is a video.
-
-        Returns:
-            The padded text.
-        """
-
-        assert len(num_cropped_images) == len(
-            video_flags
-        ), "num_cropped_images and video_flags must have the same length."
-
-        image_token: str = cast(str, self.tokenizer.image_token)
-
-        return_text: List[str] = []
-
-        for text_item in text:
-            return_text_item: str = ""
-
-            # Repeatedly find image_token in the text.
-            while image_token in text_item:
-                image_pos = text_item.find(image_token)
-
-                if image_pos != -1 and len(num_cropped_images) > 0:
-                    num_crops = num_cropped_images.pop(0)
-                    video_flag = video_flags.pop(0)
-
-                    return_text_item += (
-                        text_item[:image_pos] + (image_token if video_flag else (image_token + "\n")) * num_crops
-                    )
-                    text_item = text_item[image_pos + len(image_token) :]
-
-                else:
-                    break
-
-            # Must place outside the while loop.
-            if image_token in text_item:
-                raise ValueError("Too many image tokens in the text.")
-
-            return_text_item += text_item
-            text_item = ""
-
-            return_text.append(return_text_item)
-
-        if len(num_cropped_images) != 0:
-            raise ValueError("Too many images provided.")
-
-        return return_text
+    def _find_media_token_order(self, text: List[str]) -> List[str]:
+        """Finds the order of media tokens in the text.
+
+        Args:
+            text: The text to be processed.
+
+        Returns:
+            The order of media tokens in the text. Each item is either an image token or a video
+            token.
+        """
+
+        image_token = cast(str, self.tokenizer.image_token)
+        video_token = cast(str, self.tokenizer.video_token)
+
+        return_order: List[str] = []
+
+        for text_item in text:
+            while image_token in text_item or video_token in text_item:
+                image_pos = text_item.find(image_token)
+                video_pos = text_item.find(video_token)
+
+                if image_pos == -1 and video_pos == -1:
+                    # If no media token found, move to the next text item.
+                    break
+
+                elif image_pos == -1:
+                    # If only video token found, add it to the return order.
+                    return_order.append(video_token)
+                    text_item = text_item[video_pos + len(video_token) :]
+
+                elif video_pos == -1:
+                    # If only image token found, add it to the return order.
+                    return_order.append(image_token)
+                    text_item = text_item[image_pos + len(image_token) :]
+
+                else:
+                    # If both tokens found, choose the one that appears first.
+                    if image_pos < video_pos:
+                        return_order.append(image_token)
+                        text_item = text_item[image_pos + len(image_token) :]
+                    else:
+                        return_order.append(video_token)
+                        text_item = text_item[video_pos + len(video_token) :]
+
+        return return_order
+
+    def _generate_image_token_placeholder(self, text: List[str]) -> str:
+        while True:
+            placeholder = f"<|image_placeholder_{str(uuid.uuid4())}|>"
+            if all(placeholder not in text_item for text_item in text):
+                return placeholder
+
+    def _merge_media_tiles(
+        self,
+        image_tiles: List[List[Image]],
+        video_tiles: List[List[List[Image]]],
+        media_token_order: List[str],
+    ) -> List[Image]:
+        """Merges the media tiles by the media token order.
+
+        Args:
+            image_tiles: The image tiles.
+            video_tiles: The video tiles.
+            media_token_order: The order of media tokens in the text.
+
+        Returns:
+            The merged media tiles.
+        """
+
+        image_token = cast(str, self.tokenizer.image_token)
+        video_token = cast(str, self.tokenizer.video_token)
+
+        image_tiles_idx = 0
+        video_tiles_idx = 0
+
+        return_tiles: List[Image] = []
+
+        for media_token in media_token_order:
+            if media_token == image_token:
+                return_tiles.extend(image_tiles[image_tiles_idx])
+                image_tiles_idx += 1
+            elif media_token == video_token:
+                for video_tile in video_tiles[video_tiles_idx]:
+                    return_tiles.extend(video_tile)
+                video_tiles_idx += 1
+            else:
+                raise ValueError(f"Invalid media token: {media_token}")
+
+        return return_tiles
+
+    def _normalize_inputs(
+        self,
+        text: TextInput | List[TextInput],
+        images: Optional[ImageInput],
+        videos: Optional[VideoInput],
+    ) -> Tuple[List[str], List[Image], List[List[Image]]]:
+        """Normalizes text, image, and video inputs for processing.
+
+        This method converts various input formats into standardized lists of PIL images
+        and text strings that can be processed by the model.
+
+        Args:
+            text: The original input text.
+            images: The original input images.
+            videos: The original input videos.
+
+        Returns:
+            The text as a list of strings.
+            The images as a list of PIL images.
+            The videos as a list of lists of PIL images.
+        """
+
+        prepared_text = text if isinstance(text, list) else [text]
+
+        if images is not None:
+            image_list = cast(List, image_utils.make_flat_list_of_images(images))
+            prepared_images = [cast(Image, image_transforms.to_pil_image(image)) for image in image_list]
+        else:
+            prepared_images = []
+
+        if videos is not None:
+            video_list = cast(List[List], video_utils.make_batched_videos(videos))
+            prepared_videos = [
+                [cast(Image, image_transforms.to_pil_image(image)) for image in video] for video in video_list
+            ]
+        else:
+            prepared_videos = []
+
+        return prepared_text, prepared_images, prepared_videos
 
-    def _pad_image_tokens_by_num_embeddings(
+    def _pad_image_tiles(
         self,
         text: List[str],
     ) -> List[str]:
-        """Pads each \\<image> to image_pad_len times of "\\<image>".
+        """Pads each media tile.
+
+        This will pad each <image> to (self.image_pad_len + 1) times. The additional one padding is
+        for the \\n token suffix.
 
         Args:
             text: The text to be padded.
@@ -240,147 +284,189 @@ class VILAProcessor(ProcessorMixin):
             The padded text.
         """
 
-        return [
-            text_item.replace(
-                cast(str, self.tokenizer.image_token), cast(str, self.tokenizer.image_token) * self.image_pad_len
-            )
-            for text_item in text
-        ]
-
-    @staticmethod
-    def _prepare_inputs(
-        text: TextInput | List[TextInput],
-        images: Optional[ImageInput],
-        videos: Optional[VideoInput],
-    ) -> Tuple[List[str], List[Image], List[List[Image]]]:
-        # Prepare text.
-        text = text if isinstance(text, list) else [text]
-
-        # Prepare images.
-        if images is not None:
-            image_list = cast(List, image_utils.make_flat_list_of_images(images))
-            images = [image_transforms.to_pil_image(image) for image in image_list]
-        else:
-            images = cast(List[Image], [])
-
-        # Prepare videos.
-        if videos is not None:
-            video_list = cast(List[List], image_utils.make_batched_videos(videos))
-            videos = [[image_transforms.to_pil_image(image) for image in video] for video in video_list]
-        else:
-            videos = cast(List[List[Image]], [])
-
-        return text, images, videos
-
-    def _process_images(
-        self,
-        images: List[Image],
-        **kwargs: Unpack[ImagesKwargs],
-    ) -> Tuple[BatchFeature, List[int]]:
-        cropped_images: List[Image] = []
-        num_cropped_images: List[int] = []
-
-        for image in images:
-            single_cropped_images = self._crop_image(image)
-
-            cropped_images.extend(single_cropped_images)
-            num_cropped_images.append(len(single_cropped_images))
-
-        if len(cropped_images) == 0:
-            # The image processor may not properly handle empty image lists.
-            # This is a workaround to avoid errors.
-            return BatchFeature(), num_cropped_images
-
-        image_inputs = self.image_processor.__call__(
-            cropped_images,
-            **kwargs,
-        )
-
-        return image_inputs, num_cropped_images
-
-    def _treat_videos_as_image_seqs(
-        self, text: List[str], images: List[Image], videos: List[List[Image]]
-    ) -> Tuple[List[str], List[Image], List[bool]]:
-        """Treats videos as image sequences.
-
-        This method will replace all video tokens in the text with #frame image tokens,
-        and insert the corresponding images into the images list.
-
-        Args:
-            text: The text to be processed.
-            images: The images to be processed.
-            videos: The videos to be processed.
-
-        Returns:
-            The processed text and images, and a list of flags indicating whether the images are from videos.
-        """
-
-        image_token = cast(str, self.tokenizer.image_token)
-        video_token = cast(str, self.tokenizer.video_token)
-
-        return_text: List[str] = []
-        return_images: List[Image] = []
-        return_video_flags: List[bool] = []
-
-        for text_item in text:
-            return_text_item: str = ""
-
-            # Repeatedly find image_token or video_token in the text.
-            while image_token in text_item or video_token in text_item:
-                image_pos = text_item.find(image_token)
-                video_pos = text_item.find(video_token)
-
-                # If not found, set position to the end of the text.
-                if image_pos == -1:
-                    image_pos = len(text_item)
-                if video_pos == -1:
-                    video_pos = len(text_item)
-
-                if image_pos != len(text_item) and len(images) > 0 and image_pos < video_pos:
-                    # Take an image and keep the image token if:
-                    # - an image token is found, and
-                    # - there are images left, and
-                    # - the image token is before the first video token.
-
-                    image = images.pop(0)
-                    return_images.append(image)
-                    return_video_flags.append(False)
-
-                    return_text_item += text_item[: image_pos + len(image_token)]
-                    text_item = text_item[image_pos + len(image_token) :]
-
-                elif video_pos != len(text_item) and len(videos) > 0 and video_pos < image_pos:
-                    # Take a video and replace the video token with #frame image tokens if:
-                    # - a video token is found, and
-                    # - there are videos left, and
-                    # - the video token is before the first image token.
-
-                    video = videos.pop(0)
-                    return_images.extend(video)
-                    return_video_flags.extend([True] * len(video))
-
-                    return_text_item += text_item[:video_pos] + image_token * len(video)
-                    text_item = text_item[video_pos + len(video_token) :]
-                else:
-                    break
-
-            # Must place outside the while loop.
-            if image_token in text_item:
-                raise ValueError("Too many image tokens in the text.")
-            if video_token in text_item:
-                raise ValueError("Too many video tokens in the text.")
-
-            return_text_item += text_item
-            text_item = ""
-
-            return_text.append(return_text_item)
-
-        if len(images) != 0:
-            raise ValueError("Too many images provided.")
-        if len(videos) != 0:
-            raise ValueError("Too many videos provided.")
-
-        return return_text, return_images, return_video_flags
+        image_token = cast(str, self.tokenizer.image_token)
+
+        return [text_item.replace(image_token, image_token * (self.image_pad_len + 1)) for text_item in text]
+
+    def _preprocess_inputs(
+        self,
+        text: List[str],
+        images: List[Image],
+        videos: List[List[Image]],
+    ) -> Tuple[List[str], List[Image]]:
+        """Preprocesses the input data for the VILA model.
+
+        This method takes a list of texts, images, and videos, and prepares them for the model.
+        It handles the interleaving of text and media, and returns the processed text and a
+        list of media tiles (images or video frames).
+
+        Args:
+            text: The input text.
+            images: The input images.
+            videos: The input videos.
+
+        Returns:
+            The text ready to be tokenized.
+            The media tiles ready to be processed.
+        """
+
+        media_token_order = self._find_media_token_order(text)
+
+        image_token_placeholder = self._generate_image_token_placeholder(text)
+
+        preprocessed_text = text
+        preprocessed_text, preprocessed_image_tiles = self._preprocess_images(
+            preprocessed_text,
+            images,
+            image_token_placeholder=image_token_placeholder,
+        )
+        preprocessed_text, preprocessed_video_tiles = self._preprocess_videos(
+            preprocessed_text,
+            videos,
+            image_token_placeholder=image_token_placeholder,
+        )
+
+        # Convert back to the original image token.
+        image_token = cast(str, self.tokenizer.image_token)
+        preprocessed_text = [text_item.replace(image_token_placeholder, image_token) for text_item in preprocessed_text]
+
+        preprocessed_text = self._pad_image_tiles(preprocessed_text)
+
+        preprocessed_media_tiles = self._merge_media_tiles(
+            preprocessed_image_tiles,
+            preprocessed_video_tiles,
+            media_token_order,
+        )
+
+        return preprocessed_text, preprocessed_media_tiles
+
+    def _preprocess_images(
+        self,
+        text: List[str],
+        images: List[Image],
+        *,
+        image_token_placeholder: str,
+    ) -> Tuple[List[str], List[List[Image]]]:
+        single_image_token_placeholder = self._generate_image_token_placeholder(text)
+
+        preprocessed_text = text
+        preprocessed_image_tiles: List[List[Image]] = []
+
+        for image in images:
+            preprocessed_text, preprocessed_single_image_tiles = self._preprocess_single_image(
+                text,
+                image,
+                image_token_placeholder=single_image_token_placeholder,
+                is_video_frame=False,
+                use_dynamic_preprocess=(len(images) == 1),
+            )
+
+            preprocessed_text = [
+                text_item.replace(
+                    single_image_token_placeholder,
+                    (image_token_placeholder + "\n") if len(images) == 1 else image_token_placeholder,
+                )
+                for text_item in preprocessed_text
+            ]
+
+            preprocessed_image_tiles.append(preprocessed_single_image_tiles)
+
+        return preprocessed_text, preprocessed_image_tiles
+
+    def _preprocess_single_image(
+        self,
+        text: List[str],
+        image: Image,
+        *,
+        image_token_placeholder: str,
+        is_video_frame: bool,
+        use_dynamic_preprocess: bool,
+    ) -> Tuple[List[str], List[Image]]:
+        assert isinstance(self.image_processor, (SiglipImageProcessor, SiglipImageProcessorFast))
+        assert self.image_processor.size["height"] == self.image_processor.size["width"]
+        cropped_size = self.image_processor.size["height"]
+
+        if use_dynamic_preprocess:
+            if is_video_frame:
+                max_num = self.video_max_tiles
+            else:
+                max_num = self.max_tiles
+        else:
+            max_num = 1
+
+        image = image.convert("RGB")
+
+        cropped_images: List[Image] = dynamic_preprocess(
+            image,
+            min_num=self.min_tiles,
+            max_num=max_num,
+            image_size=cropped_size,
+        )
+
+        image_token = cast(str, self.tokenizer.image_token)
+
+        for i in range(len(text)):
+            if image_token in text[i]:
+                text[i] = text[i].replace(image_token, image_token_placeholder * len(cropped_images))
+                break
+
+        return text, cropped_images
+
+    def _preprocess_videos(
+        self,
+        text: List[str],
+        videos: List[List[Image]],
+        *,
+        image_token_placeholder: str,
+    ) -> Tuple[List[str], List[List[List[Image]]]]:
+        image_token = cast(str, self.tokenizer.image_token)
+        video_token = cast(str, self.tokenizer.video_token)
+
+        processed_text = text
+        processed_video_tiles: List[List[List[Image]]] = []
+
+        for video in videos:
+            # Replace the first video token with #frame image tokens.
+            for i in range(len(processed_text)):
+                if video_token in processed_text[i]:
+                    processed_text[i] = processed_text[i].replace(video_token, image_token * len(video))
+                    break
+
+            processed_frame_tiles: List[List[Image]] = []
+            for frame in video:
+                processed_text, processed_single_frame_tiles = self._preprocess_single_image(
+                    processed_text,
+                    frame,
+                    image_token_placeholder=image_token_placeholder,
+                    is_video_frame=True,
+                    use_dynamic_preprocess=(self.video_max_tiles > 1),
+                )
+                processed_frame_tiles.append(processed_single_frame_tiles)
+
+            processed_video_tiles.append(processed_frame_tiles)
+
+        return processed_text, processed_video_tiles
+
+    def _replace_image_tile_suffix(self, text_inputs: BatchEncoding) -> BatchEncoding:
+        lf_token_id = cast(int, self.tokenizer.encode("\n")[0])
+        image_token_id = cast(int, self.tokenizer.image_token_id)
+
+        for i in range(len(text_inputs.input_ids)):
+            input_ids = text_inputs.input_ids[i]
+
+            idx = 0
+            while idx < len(input_ids):
+                if input_ids[idx] != image_token_id:
+                    idx += 1
+                    continue
+
+                if idx + self.image_pad_len < len(input_ids):
+                    input_ids[idx + self.image_pad_len] = lf_token_id
+                    idx += self.image_pad_len + 1
+                else:
+                    break
+
+        return text_inputs
 
 
 def dynamic_preprocess(image: Image, min_num: int, max_num: int, image_size: int, use_thumbnail=True) -> List[Image]:
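An end-to-end usage sketch of the reworked processor (placeholder repo id; exact sizes depend on how many tiles `dynamic_preprocess` produces for the input image): each `<image>` in the text is expanded to `image_pad_len + 1` image tokens per tile, and `_replace_image_tile_suffix` then rewrites the last token of each run into the "\n" token that the model no longer appends on its own.

from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("<this-repo-id>", trust_remote_code=True)

image = Image.new("RGB", (640, 480))
inputs = processor(
    text="<image>\nDescribe the scene.",
    images=[image],
    return_tensors="pt",
)

# input_ids: 121 <image> tokens plus one "\n" token per tile;
# pixel_values: one vision-tower-sized crop per tile.
print(inputs["input_ids"].shape, inputs["pixel_values"].shape)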
processor_config.json CHANGED
@@ -2,8 +2,9 @@
   "auto_map": {
     "AutoProcessor": "processing_vila.VILAProcessor"
   },
-  "image_pad_len": 122,
+  "image_pad_len": 121,
   "max_tiles": 12,
   "min_tiles": 1,
-  "processor_class": "VILAProcessor"
+  "processor_class": "VILAProcessor",
+  "video_max_tiles": 1
 }
tokenizer_config.json CHANGED
@@ -249,7 +249,6 @@
     "AutoProcessor": "processing_vila.VILAProcessor"
   },
   "bos_token": "[BOS]",
-  "chat_template": null,
   "clean_up_tokenization_spaces": false,
   "eos_token": "<|im_end|>",
   "errors": "replace",