AndyZijianZhang committed (verified)
Commit 1bc10e1 · Parent(s): 4b68f80

Upload files with `vila-upload`.

Upload tokenizer_config.json
Upload config.json
Upload configuration_vila.py
Upload generation_config.json
Upload chat_template.jinja
Upload processing_vila.py
Upload processor_config.json
Upload modeling_vila.py

chat_template.jinja ADDED
@@ -0,0 +1 @@
+ {% for message in messages %}{% if loop.first and message['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{ '<|im_start|>' + message['role'] + '\n' }}{% if message['content'] is string %}{{ message['content'] + '<|im_end|>\n' }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{{ '<image>' }}{% elif content['type'] == 'video' or 'video' in content %}{{ '<video>' }}{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}
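
For reference, a minimal sketch of what this template produces for a single-image user turn. Assumptions: "./" is a placeholder for this repo's local path or hub id, and the installed transformers release picks up `chat_template.jinja` automatically.

```python
# Sketch only: "./" is a placeholder for this repo's local path or hub id.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./")

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Describe the image."},
        ],
    },
]

prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)
# Per the template above, this renders as:
# <|im_start|>system
# You are a helpful assistant<|im_end|>
# <|im_start|>user
# <image>Describe the image.<|im_end|>
# <|im_start|>assistant
```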
config.json CHANGED
@@ -10,7 +10,6 @@
   "AutoModelForVision2Seq": "modeling_vila.VILAForConditionalGeneration"
   },
   "hidden_size": 3584,
- "image_end_token_id": 198,
   "image_token_id": 151648,
   "mm_hidden_size": 1152,
   "mm_projector_type": "mlp_downsample_3x3_fix",
@@ -47,7 +46,7 @@
   "vocab_size": 151648
   },
   "torch_dtype": "bfloat16",
- "transformers_version": "4.51.3",
+ "transformers_version": "4.52.3",
   "video_token_id": 151649,
   "vision_config": {
   "architectures": [
configuration_vila.py CHANGED
@@ -21,7 +21,6 @@ class VILAConfig(PretrainedConfig):
      # Model configuration.
      hidden_size: int
      image_token_id: int
-     image_end_token_id: int
      mm_hidden_size: int
      mm_projector_type: str
      mm_vision_select_feature: str
@@ -30,17 +29,16 @@ class VILAConfig(PretrainedConfig):
 
      def __init__(
          self,
-         *,
          text_config: Optional[Dict[str, Any]] = None,
          vision_config: Optional[Dict[str, Any]] = None,
-         hidden_size: Optional[int] = None,
-         image_token_id: Optional[int] = None,
-         image_end_token_id: Optional[int] = None,
-         mm_hidden_size: Optional[int] = None,
-         mm_projector_type: Optional[str] = None,
-         mm_vision_select_feature: Optional[str] = None,
-         mm_vision_select_layer: Optional[int] = None,
-         video_token_id: Optional[int] = None,
+         *,
+         hidden_size: int = 1536,
+         image_token_id: int = 151649,
+         mm_hidden_size: int = 1152,
+         mm_projector_type: str = "mlp_downsample_3x3_fix",
+         mm_vision_select_feature: str = "cls_patch",
+         mm_vision_select_layer: int = -2,
+         video_token_id: int = 151650,
          **kwargs,
      ):
          super().__init__(**kwargs)
@@ -48,14 +46,10 @@ class VILAConfig(PretrainedConfig):
          self.text_config = Qwen2Config(**text_config) if text_config else Qwen2Config()
          self.vision_config = SiglipVisionConfig(**vision_config) if vision_config else SiglipVisionConfig()
 
-         # By default, we use settings from NVILA-Lite.
-         self.hidden_size = hidden_size if hidden_size is not None else 1536
-         self.image_token_id = image_token_id if image_token_id is not None else 151649
-         self.image_end_token_id = image_end_token_id if image_end_token_id is not None else 198  # "\n"
-         self.mm_hidden_size = mm_hidden_size if mm_hidden_size is not None else 1152
-         self.mm_projector_type = mm_projector_type if mm_projector_type is not None else "mlp_downsample_3x3_fix"
-         self.mm_vision_select_feature = (
-             mm_vision_select_feature if mm_vision_select_feature is not None else "cls_patch"
-         )
-         self.mm_vision_select_layer = mm_vision_select_layer if mm_vision_select_layer is not None else -2
-         self.video_token_id = video_token_id if video_token_id is not None else 151650
+         self.hidden_size = hidden_size
+         self.image_token_id = image_token_id
+         self.mm_hidden_size = mm_hidden_size
+         self.mm_projector_type = mm_projector_type
+         self.mm_vision_select_feature = mm_vision_select_feature
+         self.mm_vision_select_layer = mm_vision_select_layer
+         self.video_token_id = video_token_id
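
In practice, the new signature makes every multimodal field keyword-only with an NVILA-Lite default, so the config can be built with no arguments. A minimal sketch, assuming configuration_vila.py and its transformers dependencies are importable from a local checkout:

```python
# Sketch only: assumes configuration_vila.py is on the import path.
from configuration_vila import VILAConfig

# No arguments: all multimodal fields fall back to the NVILA-Lite defaults.
config = VILAConfig()
assert config.mm_projector_type == "mlp_downsample_3x3_fix"
assert config.mm_vision_select_layer == -2

# Overrides must now be passed by keyword; text_config / vision_config stay positional.
config = VILAConfig(hidden_size=3584, image_token_id=151648)
```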
 
 
 
 
generation_config.json CHANGED
@@ -3,5 +3,5 @@
   "bos_token_id": 151643,
   "eos_token_id": 151645,
   "pad_token_id": 151643,
- "transformers_version": "4.51.3"
+ "transformers_version": "4.52.3"
   }
modeling_vila.py CHANGED
@@ -2,6 +2,7 @@ from typing import List, Optional, Type
 
   import torch
   import torch.nn as nn
+ import torch.nn.functional as F
   from torch import Tensor
   from transformers.configuration_utils import PretrainedConfig
   from transformers.generation.utils import GenerationMixin
@@ -13,68 +14,34 @@ from transformers.models.siglip.modeling_siglip import SiglipVisionModel
   from .configuration_vila import VILAConfig
 
 
- class DownSampleBlock(nn.Module):
-     @staticmethod
-     def flat_square(x: Tensor) -> Tensor:
-         n, w, h, c = x.size()
-         if w % 2 == 1:
-             x = torch.concat([x, torch.zeros((n, 1, h, c), device=x.device, dtype=x.dtype)], dim=1).contiguous()
-             n, w, h, c = x.size()
-         if h % 2 == 1:
-             x = torch.concat([x, torch.zeros((n, w, 1, c), device=x.device, dtype=x.dtype)], dim=2).contiguous()
-             n, w, h, c = x.size()
-         x = x.contiguous()
-         x = x.view(n, w, int(h / 2), int(c * 2))
-         x = x.permute(0, 2, 1, 3).contiguous()
-         x = x.view(n, int(h / 2), int(w / 2), int(c * 4))
-         x = x.permute(0, 2, 1, 3).contiguous()
-         return x
-
+ class DownSample3x3BlockFix(nn.Module):
      def forward(self, x: Tensor) -> Tensor:
-         vit_embeds = x
-         h = w = int(vit_embeds.shape[1] ** 0.5)
-         vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
-         vit_embeds = self.flat_square(vit_embeds)
-         vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
-         return vit_embeds
-
-
- class DownSample3x3BlockFix(nn.Module):
-     @staticmethod
-     def flat_square_3x3(x: Tensor) -> Tensor:
-         n, w, h, c = x.size()
-         if w % 3 != 0:
-             x = torch.concat(
-                 [
-                     x,
-                     torch.zeros((n, 3 - (w % 3), h, c), device=x.device, dtype=x.dtype),
-                 ],
-                 dim=1,
-             ).contiguous()
-             n, w, h, c = x.size()
-         x = x.contiguous()
-         if h % 3 != 0:
-             x = torch.concat(
-                 [
-                     x,
-                     torch.zeros((n, w, 3 - (h % 3), c), device=x.device, dtype=x.dtype),
-                 ],
-                 dim=2,
-             ).contiguous()
-             n, w, h, c = x.size()
-         x = x.view(n, w, int(h / 3), int(c * 3))
-         x = x.permute(0, 2, 1, 3).contiguous()
-         x = x.view(n, int(h / 3), int(w / 3), int(c * 9))
-         x = x.permute(0, 2, 1, 3).contiguous()
-         return x
-
-     def forward(self, x: Tensor) -> Tensor:
-         vit_embeds = x
-         h = w = int(vit_embeds.shape[1] ** 0.5)
-         vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
-         vit_embeds = self.flat_square_3x3(vit_embeds)
-         vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
-         return vit_embeds
+         """
+         Args:
+             x: The input tensor of shape (batch_size, sequence_length, mm_hidden_size).
+
+         Returns:
+             The output tensor of shape (batch_size, image_pad_len, mm_hidden_size * 9).
+         """
+
+         batch_size, sequence_length, hidden_size = x.shape
+
+         feat_size = int(sequence_length**0.5)
+         if feat_size**2 != sequence_length:
+             raise ValueError(f"Cannot take square root: sequence_length {sequence_length} is not a perfect square")
+
+         features = x.reshape(batch_size, feat_size, feat_size, hidden_size)
+
+         pad_after = (3 - feat_size % 3) % 3
+         if pad_after > 0:
+             features = F.pad(features, (0, 0, 0, pad_after, 0, pad_after))
+             feat_size = feat_size + pad_after
+
+         features = features.reshape(batch_size, feat_size // 3, 3, feat_size // 3, 3, hidden_size)
+         features = features.permute(0, 1, 3, 2, 4, 5).contiguous()
+         features = features.reshape(batch_size, -1, 9 * hidden_size)
+
+         return features
 
 
   class MultimodalProjector(nn.Module):
@@ -89,18 +56,6 @@ class MultimodalProjector(nn.Module):
          super().__init__(*args, **kwargs)
 
          match config.mm_projector_type:
-             case "linear":
-                 self.layers = nn.Sequential(
-                     nn.Linear(config.vision_config.hidden_size, config.hidden_size),
-                 )
-             case "mlp_downsample":
-                 self.layers = nn.Sequential(
-                     DownSampleBlock(),
-                     nn.LayerNorm(config.mm_hidden_size * 4),
-                     nn.Linear(config.mm_hidden_size * 4, config.hidden_size),
-                     nn.GELU(),
-                     nn.Linear(config.hidden_size, config.hidden_size),
-                 )
              case "mlp_downsample_3x3_fix":
                  self.layers = nn.Sequential(
                      DownSample3x3BlockFix(),
@@ -116,9 +71,9 @@ class MultimodalProjector(nn.Module):
                      nn.Linear(config.hidden_size, config.hidden_size),
                  )
              case _:
-                 raise NotImplementedError(f"mm_projector_type={config.mm_projector_type} not implemented.")
+                 raise NotImplementedError(f"Unsupported mm_projector_type: {config.mm_projector_type}")
 
-         self.layers.to(dtype=config.torch_dtype)
+         self.layers.type(config.torch_dtype)
 
      @property
      def device(self) -> torch.device:
@@ -129,7 +84,15 @@
          return next(self.parameters()).dtype
 
      def forward(self, x: Tensor) -> Tensor:
-         return self.layers(x)
+         """
+         Args:
+             x: The input tensor of shape (batch_size, sequence_length, mm_hidden_size).
+
+         Returns:
+             The output tensor of shape (batch_size, image_pad_len, hidden_size).
+         """
+
+         return self.layers(x.to(device=self.device, dtype=self.dtype))
 
 
  class VILAForConditionalGeneration(PreTrainedModel, GenerationMixin):
@@ -156,9 +119,9 @@ class VILAForConditionalGeneration(PreTrainedModel, GenerationMixin):
      ):
          super().__init__(config, *args, **kwargs)
 
-         self.llm = Qwen2ForCausalLM(config.text_config, *args, **kwargs)
+         self.llm = Qwen2ForCausalLM._from_config(config.text_config, *args, **kwargs)
          self.mm_projector = MultimodalProjector(config)
-         self.vision_tower = SiglipVisionModel(config.vision_config, *args, **kwargs)
+         self.vision_tower = SiglipVisionModel._from_config(config.vision_config, *args, **kwargs)
 
          self.post_init()
 
@@ -175,29 +138,15 @@ class VILAForConditionalGeneration(PreTrainedModel, GenerationMixin):
          if kwargs.get("past_key_values", None) is not None:
              pixel_values = None
 
-         inputs_embeds = inputs_embeds.to(dtype=self.dtype) if inputs_embeds is not None else None
-         pixel_values = pixel_values.to(dtype=self.dtype) if pixel_values is not None else None
-
          if inputs_embeds is None:
-             assert input_ids is not None
+             if input_ids is None:
+                 raise ValueError("input_ids is required when inputs_embeds is None")
 
              inputs_embeds = self._embed(input_ids, pixel_values)
-         else:
-             assert input_ids is None
-             assert pixel_values is None
 
          outputs = self.llm.__call__(
-             inputs_embeds=inputs_embeds.to(
-                 device=self.llm.device,
-                 dtype=self.llm.dtype,
-             ),
-             attention_mask=(
-                 attention_mask.to(
-                     device=self.llm.device,
-                 )
-                 if attention_mask is not None
-                 else None
-             ),
+             inputs_embeds=inputs_embeds.to(device=self.llm.device, dtype=self.llm.dtype),
+             attention_mask=(attention_mask.to(device=self.llm.device) if attention_mask is not None else None),
              **kwargs,
          )
 
@@ -221,8 +170,6 @@ class VILAForConditionalGeneration(PreTrainedModel, GenerationMixin):
              The embedding of the input ids and pixel values.
          """
 
-         # Video tokens should be removed during preprocessing, so there must not be any video
-         # tokens in the input ids.
          if torch.any(input_ids == self.config.video_token_id):
              raise ValueError("Video token ids should not be present in the input ids.")
 
@@ -233,56 +180,38 @@ class VILAForConditionalGeneration(PreTrainedModel, GenerationMixin):
          if pixel_values is None:
              return text_embedding
 
-         image_features: BaseModelOutputWithPooling = self.vision_tower.__call__(
-             pixel_values.to(
-                 device=self.vision_tower.device,
-                 dtype=self.vision_tower.dtype,
-             ),
+         vision_tower_output: BaseModelOutputWithPooling = self.vision_tower.__call__(
+             pixel_values.to(device=self.vision_tower.device, dtype=self.vision_tower.dtype),
              output_hidden_states=True,
          )
-         assert image_features.hidden_states is not None
 
-         # Select image feature.
-         selected_layer_output = image_features.hidden_states[self.config.mm_vision_select_layer]
-         match self.config.mm_vision_select_feature:
-             case "cls_patch":
-                 selected_feature = selected_layer_output
-             case _:
-                 raise NotImplementedError(
-                     f"mm_vision_select_feature={self.config.mm_vision_select_feature} not implemented."
-                 )
-
-         # TODO: Support dynamic_s2.
+         mm_projector_input = self._vision_tower_output_to_mm_projector_input(vision_tower_output)
 
          image_embedding: Tensor = self.mm_projector.__call__(
-             selected_feature.to(
-                 device=self.mm_projector.device,
-                 dtype=self.mm_projector.dtype,
-             )
+             mm_projector_input.to(device=self.mm_projector.device, dtype=self.mm_projector.dtype)
          )
 
-         # Append image end token to every image embedding.
-         image_end_token_embedding: Tensor = self.llm.get_input_embeddings().__call__(
-             torch.tensor(
-                 self.config.image_end_token_id,
-                 device=text_embedding.device,
-                 dtype=torch.long,
-             ).view(1, -1)
-         )  # Shape: (1, 1, dim_feature)
-         image_end_token_embedding = image_end_token_embedding.expand(
-             image_embedding.shape[0], 1, -1
-         )  # Shape: (n_images, 1, dim_feature)
-         image_embedding = torch.concat(
-             [
-                 image_embedding.to(device=text_embedding.device),
-                 image_end_token_embedding,
-             ],
-             dim=1,
+         image_embedding = image_embedding.reshape(-1, image_embedding.shape[-1])
+
+         text_embedding.masked_scatter_(
+             image_token_mask.to(device=text_embedding.device, dtype=torch.bool).unsqueeze(-1),
+             image_embedding.to(device=text_embedding.device, dtype=text_embedding.dtype).flatten(),
          )
 
-         n_images, n_feature, dim_feature = image_embedding.shape
-         image_embedding = image_embedding.view(n_images * n_feature, dim_feature)
+         return text_embedding
 
-         text_embedding[image_token_mask.to(device=text_embedding.device)] = image_embedding
+     def _vision_tower_output_to_mm_projector_input(
+         self,
+         vision_tower_output: BaseModelOutputWithPooling,
+     ) -> Tensor:
+         assert vision_tower_output.hidden_states is not None
 
-         return text_embedding
+         selected_layer_hidden_states = vision_tower_output.hidden_states[self.config.mm_vision_select_layer]
+
+         match self.config.mm_vision_select_feature:
+             case "cls_patch":
+                 return selected_layer_hidden_states
+             case _:
+                 raise NotImplementedError(
+                     f"Unsupported mm_vision_select_feature: {self.config.mm_vision_select_feature}"
+                 )
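
For a quick sanity check of the rewritten downsampling path, here is a self-contained sketch that mirrors the new DownSample3x3BlockFix.forward above (re-implemented locally so it runs without importing the repo module; the 1152-dim features and 27x27 token grid match this repo's SigLIP config, and the helper name is ours):

```python
# Standalone sketch mirroring DownSample3x3BlockFix.forward above.
import torch
import torch.nn.functional as F


def downsample_3x3(x: torch.Tensor) -> torch.Tensor:
    batch_size, sequence_length, hidden_size = x.shape
    feat_size = int(sequence_length**0.5)
    assert feat_size**2 == sequence_length  # expects a square token grid

    features = x.reshape(batch_size, feat_size, feat_size, hidden_size)
    pad_after = (3 - feat_size % 3) % 3
    if pad_after > 0:
        # Zero-pad the grid on the right/bottom so both sides are divisible by 3.
        features = F.pad(features, (0, 0, 0, pad_after, 0, pad_after))
        feat_size += pad_after

    # Merge each 3x3 patch of tokens into one token with 9x the channels.
    features = features.reshape(batch_size, feat_size // 3, 3, feat_size // 3, 3, hidden_size)
    features = features.permute(0, 1, 3, 2, 4, 5).contiguous()
    return features.reshape(batch_size, -1, 9 * hidden_size)


# 729 tokens = a 27x27 grid -> 81 tokens with 9x the channels.
assert downsample_3x3(torch.randn(2, 729, 1152)).shape == (2, 81, 10368)
# 25 tokens = a 5x5 grid -> padded to 6x6 -> 4 tokens.
assert downsample_3x3(torch.randn(2, 25, 1152)).shape == (2, 4, 10368)
```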
processing_vila.py CHANGED
@@ -3,17 +3,19 @@ from typing import List, Optional, Tuple, cast
  import transformers.image_transforms as image_transforms
  import transformers.image_utils as image_utils
  import transformers.utils.logging
+ import transformers.video_utils as video_utils
  from PIL.Image import Image
  from torch import Tensor
  from transformers.feature_extraction_utils import BatchFeature
  from transformers.image_processing_utils import BaseImageProcessor
  from transformers.image_processing_utils_fast import BaseImageProcessorFast
- from transformers.image_utils import ImageInput, VideoInput
+ from transformers.image_utils import ImageInput
  from transformers.models.siglip.image_processing_siglip import SiglipImageProcessor
  from transformers.models.siglip.image_processing_siglip_fast import SiglipImageProcessorFast
- from transformers.processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
+ from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
  from transformers.tokenization_utils import PreTrainedTokenizer
  from transformers.tokenization_utils_base import PreTrainedTokenizerBase, TextInput
+ from transformers.video_utils import VideoInput
 
  logger = transformers.utils.logging.get_logger(__name__)
 
@@ -41,6 +43,7 @@ class VILAProcessor(ProcessorMixin):
          "image_pad_len",
          "max_tiles",
          "min_tiles",
+         "video_max_tiles",
      ]
 
      # Attributes.
@@ -51,15 +54,17 @@ class VILAProcessor(ProcessorMixin):
      image_pad_len: int
      max_tiles: int
      min_tiles: int
+     video_max_tiles: int
 
      def __init__(
          self,
          image_processor: BaseImageProcessor,
          tokenizer: PreTrainedTokenizer,
          *,
-         image_pad_len: Optional[int] = None,
-         max_tiles: Optional[int] = None,
-         min_tiles: Optional[int] = None,
+         image_pad_len: int = 121,
+         max_tiles: int = 12,
+         min_tiles: int = 1,
+         video_max_tiles: int = 1,
          **kwargs,
      ):
          super().__init__(
@@ -68,9 +73,10 @@ class VILAProcessor(ProcessorMixin):
              **kwargs,
          )
 
-         self.image_pad_len = image_pad_len if image_pad_len is not None else 122
-         self.max_tiles = max_tiles if max_tiles is not None else 12
-         self.min_tiles = min_tiles if min_tiles is not None else 1
+         self.image_pad_len = image_pad_len
+         self.max_tiles = max_tiles
+         self.min_tiles = min_tiles
+         self.video_max_tiles = video_max_tiles
 
      def __call__(
          self,
@@ -78,7 +84,7 @@ class VILAProcessor(ProcessorMixin):
          images: Optional[ImageInput] = None,
          videos: Optional[VideoInput] = None,
          audio: None = None,
-         **kwargs: Unpack[VILAProcessorProcessingKwargs],
+         **kwargs: Unpack[ProcessingKwargs],
      ) -> VILAProcessorOutput:
          """Preprocesses inputs for VILA.
 
@@ -99,39 +105,59 @@ class VILAProcessor(ProcessorMixin):
              **kwargs,
          )
 
-         text, images, videos = self._prepare_inputs(
+         prepared_text, prepared_images, prepared_videos = self._prepare_inputs(
              text=text,
             images=images,
             videos=videos,
         )
 
          # Process videos.
-         text, images, video_flags = self._treat_videos_as_image_seqs(
-             text=text,
-             images=images,
-             videos=videos,
+         prepared_text, prepared_images, video_flags = self._treat_videos_as_image_seqs(
+             text=prepared_text,
+             images=prepared_images,
+             videos=prepared_videos,
          )
 
          # Process images.
          image_inputs, num_cropped_images = self._process_images(
-             images=images,
+             images=prepared_images,
+             video_flags=video_flags,
              **merged_kwargs["images_kwargs"],
          )
 
          # Process text.
-         text = self._pad_image_tokens_by_num_crops(
-             text,
+         prepared_text = self._pad_image_tokens_by_num_crops(
+             prepared_text,
              num_cropped_images=num_cropped_images,
              video_flags=video_flags,
          )
 
-         text = self._pad_image_tokens_by_num_embeddings(text)
+         prepared_text = self._pad_image_tokens_by_num_embeddings(prepared_text)
 
          text_inputs = self.tokenizer.__call__(
-             text,
+             prepared_text,
              **merged_kwargs["text_kwargs"],
          )
 
+         # Find the last image token of each image tile and replace to "\n".
+         lf_token_id = self.tokenizer.encode("\n")[0]
+         image_token_id = self.tokenizer.image_token_id
+
+         for i in range(len(text_inputs.input_ids)):
+             input_ids = text_inputs.input_ids[i]
+
+             idx = 0
+             while idx < len(input_ids):
+                 if input_ids[idx] != image_token_id:
+                     idx += 1
+                     continue
+
+                 if idx + self.image_pad_len < len(input_ids):
+                     input_ids[idx + self.image_pad_len] = lf_token_id
+                     idx += self.image_pad_len + 1
+                 else:
+                     break
+
          return VILAProcessorOutput(
              data={
                  **text_inputs,
@@ -142,6 +168,8 @@ class VILAProcessor(ProcessorMixin):
      def _crop_image(
          self,
          image: Image,
+         *,
+         is_video_frame: bool,
      ) -> List[Image]:
          """Crops the image into multiple tiles.
 
@@ -162,7 +190,7 @@ class VILAProcessor(ProcessorMixin):
          cropped_images: List[Image] = dynamic_preprocess(
              image,
              min_num=self.min_tiles,
-             max_num=self.max_tiles,
+             max_num=self.max_tiles if not is_video_frame else self.video_max_tiles,
              image_size=cropped_size,
          )
 
@@ -240,12 +268,9 @@ class VILAProcessor(ProcessorMixin):
              The padded text.
          """
 
-         return [
-             text_item.replace(
-                 cast(str, self.tokenizer.image_token), cast(str, self.tokenizer.image_token) * self.image_pad_len
-             )
-             for text_item in text
-         ]
+         image_token = cast(str, self.tokenizer.image_token)
+
+         return [text_item.replace(image_token, image_token * (self.image_pad_len + 1)) for text_item in text]
 
      @staticmethod
      def _prepare_inputs(
@@ -253,35 +278,36 @@ class VILAProcessor(ProcessorMixin):
          images: Optional[ImageInput],
          videos: Optional[VideoInput],
      ) -> Tuple[List[str], List[Image], List[List[Image]]]:
-         # Prepare text.
-         text = text if isinstance(text, list) else [text]
+         prepared_text = text if isinstance(text, list) else [text]
 
-         # Prepare images.
          if images is not None:
              image_list = cast(List, image_utils.make_flat_list_of_images(images))
-             images = [image_transforms.to_pil_image(image) for image in image_list]
+             prepared_images = [cast(Image, image_transforms.to_pil_image(image)) for image in image_list]
          else:
-             images = cast(List[Image], [])
+             prepared_images = []
 
-         # Prepare videos.
          if videos is not None:
-             video_list = cast(List[List], image_utils.make_batched_videos(videos))
-             videos = [[image_transforms.to_pil_image(image) for image in video] for video in video_list]
+             video_list = cast(List[List], video_utils.make_batched_videos(videos))
+             prepared_videos = [
+                 [cast(Image, image_transforms.to_pil_image(image)) for image in video] for video in video_list
+             ]
          else:
-             videos = cast(List[List[Image]], [])
+             prepared_videos = []
 
-         return text, images, videos
+         return prepared_text, prepared_images, prepared_videos
 
      def _process_images(
          self,
          images: List[Image],
-         **kwargs: Unpack[ImagesKwargs],
+         *,
+         video_flags: List[bool],
+         **kwargs,
      ) -> Tuple[BatchFeature, List[int]]:
          cropped_images: List[Image] = []
          num_cropped_images: List[int] = []
 
-         for image in images:
-             single_cropped_images = self._crop_image(image)
+         for image, video_flag in zip(images, video_flags):
+             single_cropped_images = self._crop_image(image, is_video_frame=video_flag)
 
              cropped_images.extend(single_cropped_images)
              num_cropped_images.append(len(single_cropped_images))
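
Taken together with the processor_config.json change below (image_pad_len 122 -> 121 plus the new video_max_tiles), each image placeholder is expanded to image_pad_len + 1 = 122 image tokens per tile and the last token of every tile is then rewritten to "\n" by the `__call__` shown above. A minimal end-to-end sketch, with assumptions: "./" is a placeholder for this repo's path, the tokenizer's image_token is the literal "<image>" as in the chat template, and a transformers release new enough to ship transformers.video_utils (the 4.52 line pinned in config.json) is installed:

```python
# Sketch only: "./" is a placeholder for this repo's local path or hub id.
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor

processor = AutoProcessor.from_pretrained("./", trust_remote_code=True)
model = AutoModelForVision2Seq.from_pretrained("./", trust_remote_code=True, torch_dtype=torch.bfloat16)

image = Image.new("RGB", (448, 448))
inputs = processor(text="<image>\nDescribe the image.", images=image, return_tensors="pt")

# Each <image> placeholder becomes 122 tokens per tile, with the final token of
# every tile rewritten to "\n", leaving image_pad_len = 121 image-token slots.
output_ids = model.generate(**inputs, max_new_tokens=32)
print(processor.tokenizer.decode(output_ids[0], skip_special_tokens=True))
```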
processor_config.json CHANGED
@@ -2,8 +2,9 @@
   "auto_map": {
   "AutoProcessor": "processing_vila.VILAProcessor"
   },
- "image_pad_len": 122,
+ "image_pad_len": 121,
   "max_tiles": 12,
   "min_tiles": 1,
- "processor_class": "VILAProcessor"
+ "processor_class": "VILAProcessor",
+ "video_max_tiles": 1
   }
tokenizer_config.json CHANGED
@@ -66,7 +66,6 @@
   "AutoProcessor": "processing_vila.VILAProcessor"
   },
   "bos_token": "[BOS]",
- "chat_template": null,
   "clean_up_tokenization_spaces": false,
   "eos_token": "<|im_end|>",
   "errors": "replace",