AndyZijianZhang committed on
Commit 9e8ab54 · verified · 1 Parent(s): 254a429

Upload files with `vila-upload`.


Upload tokenizer_config.json
Upload config.json
Upload model-00007-of-00007.safetensors
Upload configuration_vila.py
Upload generation_config.json
Upload special_tokens_map.json
Upload model-00006-of-00007.safetensors
Upload added_tokens.json
Upload model.safetensors.index.json
Upload processing_vila.py
Upload processor_config.json
Upload modeling_vila.py
Upload chat_template.json

added_tokens.json CHANGED
@@ -2,6 +2,7 @@
   "</tool_call>": 151658,
   "<image>": 151666,
   "<tool_call>": 151657,
+  "<video>": 151670,
   "<vila/sentinel>": 151665,
   "<vila/video>": 151667,
   "<|box_end|>": 151649,
chat_template.json CHANGED
@@ -1,3 +1,3 @@
 {
-  "chat_template": "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<image>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
+  "chat_template": "{% for message in messages %}{% if loop.first and message['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{ '<|im_start|>' + message['role'] + '\n' }}{% if message['content'] is string %}{{ message['content'] + '<|im_end|>\n' }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{{ '<image>' }}{% elif content['type'] == 'video' or 'video' in content %}{{ '<video>' }}{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
 }
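
The template change drops the Qwen-style `<|vision_start|><|video_pad|><|vision_end|>` wrapping and the `Picture {n}:` / `Video {n}:` prefixes in favor of bare `<image>` / `<video>` placeholders that the processor expands later. A minimal sketch of rendering the new template directly with jinja2 (nothing from this repo needs to be loaded besides `chat_template.json`):

```python
import json

from jinja2 import Template

with open("chat_template.json") as f:
    template = Template(json.load(f)["chat_template"])

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Describe the image."},
        ],
    }
]

print(template.render(messages=messages, add_generation_prompt=True))
# <|im_start|>system
# You are a helpful assistant<|im_end|>
# <|im_start|>user
# <image>Describe the image.<|im_end|>
# <|im_start|>assistant
```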
config.json CHANGED
@@ -10,6 +10,7 @@
     "AutoModelForVision2Seq": "modeling_vila.VILAForConditionalGeneration"
   },
   "hidden_size": 5120,
+  "image_end_token_id": 198,
   "image_token_id": 151666,
   "mm_hidden_size": 1152,
   "mm_projector_type": "mlp_downsample_3x3_fix",
@@ -44,7 +45,8 @@
     "vocab_size": 151670
   },
   "torch_dtype": "bfloat16",
-  "transformers_version": "4.50.0",
+  "transformers_version": "4.51.3",
+  "video_token_id": 151670,
   "vision_config": {
     "architectures": [
       "SiglipVisionModel"
configuration_vila.py CHANGED
@@ -21,10 +21,12 @@ class VILAConfig(PretrainedConfig):
     # Model configuration.
     hidden_size: int
     image_token_id: int
+    image_end_token_id: int
     mm_hidden_size: int
     mm_projector_type: str
     mm_vision_select_feature: str
     mm_vision_select_layer: int
+    video_token_id: int

     def __init__(
         self,
@@ -33,10 +35,12 @@ class VILAConfig(PretrainedConfig):
         vision_config: Optional[Dict[str, Any]] = None,
         hidden_size: Optional[int] = None,
         image_token_id: Optional[int] = None,
+        image_end_token_id: Optional[int] = None,
         mm_hidden_size: Optional[int] = None,
         mm_projector_type: Optional[str] = None,
         mm_vision_select_feature: Optional[str] = None,
         mm_vision_select_layer: Optional[int] = None,
+        video_token_id: Optional[int] = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -47,9 +51,11 @@ class VILAConfig(PretrainedConfig):
         # By default, we use settings from NVILA-Lite.
         self.hidden_size = hidden_size if hidden_size is not None else 1536
         self.image_token_id = image_token_id if image_token_id is not None else 151649
+        self.image_end_token_id = image_end_token_id if image_end_token_id is not None else 198  # "\n"
         self.mm_hidden_size = mm_hidden_size if mm_hidden_size is not None else 1152
         self.mm_projector_type = mm_projector_type if mm_projector_type is not None else "mlp_downsample_3x3_fix"
         self.mm_vision_select_feature = (
             mm_vision_select_feature if mm_vision_select_feature is not None else "cls_patch"
         )
         self.mm_vision_select_layer = mm_vision_select_layer if mm_vision_select_layer is not None else -2
+        self.video_token_id = video_token_id if video_token_id is not None else 151650
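
A sketch of the new fallback defaults, assuming `configuration_vila.py` is importable from the working directory and that the constructor tolerates all-default arguments, as its signature suggests (not verified here). Note that the class defaults differ from this checkpoint's `config.json`, which overrides them:

```python
from configuration_vila import VILAConfig

config = VILAConfig()
print(config.image_end_token_id)  # 198 ("\n") by default
print(config.video_token_id)      # 151650 by default; config.json in this repo sets 151670
```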
generation_config.json CHANGED
@@ -3,5 +3,5 @@
   "bos_token_id": 151643,
   "eos_token_id": 151645,
   "pad_token_id": 151643,
-  "transformers_version": "4.50.0"
+  "transformers_version": "4.51.3"
 }
model-00006-of-00007.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:e67d849dc9b28e68c55ae917164354558a2d2502c1244b3298495bc3147b97e3
-size 4995856896
+oid sha256:303bd0b4eb3e02f493a45d09bde196ec08ee39816dd5de32de0a4f098277e7b3
+size 4995861768
model-00007-of-00007.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:5fb0c3d9f9175496b72378dd0b25cd153be73d2d1a17aec67acd522354bf9bac
-size 720921232
+oid sha256:dc3965ca8e51390f30c18931fb9df01af0be469cc7c6e9ec263127c99a240726
+size 720916360
model.safetensors.index.json CHANGED
@@ -917,8 +917,8 @@
     "vision_tower.vision_model.encoder.layers.26.self_attn.v_proj.weight": "model-00007-of-00007.safetensors",
     "vision_tower.vision_model.encoder.layers.3.layer_norm1.bias": "model-00006-of-00007.safetensors",
     "vision_tower.vision_model.encoder.layers.3.layer_norm1.weight": "model-00006-of-00007.safetensors",
-    "vision_tower.vision_model.encoder.layers.3.layer_norm2.bias": "model-00007-of-00007.safetensors",
-    "vision_tower.vision_model.encoder.layers.3.layer_norm2.weight": "model-00007-of-00007.safetensors",
+    "vision_tower.vision_model.encoder.layers.3.layer_norm2.bias": "model-00006-of-00007.safetensors",
+    "vision_tower.vision_model.encoder.layers.3.layer_norm2.weight": "model-00006-of-00007.safetensors",
     "vision_tower.vision_model.encoder.layers.3.mlp.fc1.bias": "model-00007-of-00007.safetensors",
     "vision_tower.vision_model.encoder.layers.3.mlp.fc1.weight": "model-00007-of-00007.safetensors",
     "vision_tower.vision_model.encoder.layers.3.mlp.fc2.bias": "model-00007-of-00007.safetensors",
modeling_vila.py CHANGED
@@ -18,10 +18,10 @@ class DownSampleBlock(nn.Module):
     def flat_square(x: Tensor) -> Tensor:
         n, w, h, c = x.size()
         if w % 2 == 1:
-            x = torch.concat([x, torch.zeros((n, 1, h, c), dtype=x.dtype).to(x.device)], dim=1).contiguous()
+            x = torch.concat([x, torch.zeros((n, 1, h, c), device=x.device, dtype=x.dtype)], dim=1).contiguous()
             n, w, h, c = x.size()
         if h % 2 == 1:
-            x = torch.concat([x, torch.zeros((n, w, 1, c), dtype=x.dtype).to(x.device)], dim=2).contiguous()
+            x = torch.concat([x, torch.zeros((n, w, 1, c), device=x.device, dtype=x.dtype)], dim=2).contiguous()
             n, w, h, c = x.size()
         x = x.contiguous()
         x = x.view(n, w, int(h / 2), int(c * 2))
@@ -118,6 +118,16 @@ class MultimodalProjector(nn.Module):
             case _:
                 raise NotImplementedError(f"mm_projector_type={config.mm_projector_type} not implemented.")

+        self.layers.to(dtype=config.torch_dtype)
+
+    @property
+    def device(self) -> torch.device:
+        return next(self.parameters()).device
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return next(self.parameters()).dtype
+
     def forward(self, x: Tensor) -> Tensor:
         return self.layers(x)

@@ -147,7 +157,7 @@ class VILAForConditionalGeneration(PreTrainedModel, GenerationMixin):
         super().__init__(config, *args, **kwargs)

         self.llm = Qwen2ForCausalLM(config.text_config, *args, **kwargs)
-        self.mm_projector = MultimodalProjector(config).to(dtype=self.dtype)
+        self.mm_projector = MultimodalProjector(config)
         self.vision_tower = SiglipVisionModel(config.vision_config, *args, **kwargs)

         self.post_init()
@@ -177,8 +187,17 @@ class VILAForConditionalGeneration(PreTrainedModel, GenerationMixin):
         assert pixel_values is None

         outputs = self.llm.__call__(
-            inputs_embeds=inputs_embeds,
-            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds.to(
+                device=self.llm.device,
+                dtype=self.llm.dtype,
+            ),
+            attention_mask=(
+                attention_mask.to(
+                    device=self.llm.device,
+                )
+                if attention_mask is not None
+                else None
+            ),
             **kwargs,
         )

@@ -202,6 +221,11 @@ class VILAForConditionalGeneration(PreTrainedModel, GenerationMixin):
             The embedding of the input ids and pixel values.
         """

+        # Video tokens should be removed during preprocessing, so there must not be any video
+        # tokens in the input ids.
+        if torch.any(input_ids == self.config.video_token_id):
+            raise ValueError("Video token ids should not be present in the input ids.")
+
         image_token_mask = input_ids == self.config.image_token_id

         text_embedding: Tensor = self.llm.get_input_embeddings().__call__(input_ids * ~image_token_mask)
@@ -210,7 +234,10 @@ class VILAForConditionalGeneration(PreTrainedModel, GenerationMixin):
             return text_embedding

         image_features: BaseModelOutputWithPooling = self.vision_tower.__call__(
-            pixel_values.to(dtype=self.vision_tower.dtype),
+            pixel_values.to(
+                device=self.vision_tower.device,
+                dtype=self.vision_tower.dtype,
+            ),
             output_hidden_states=True,
         )
         assert image_features.hidden_states is not None
@@ -227,13 +254,35 @@ class VILAForConditionalGeneration(PreTrainedModel, GenerationMixin):

         # TODO: Support dynamic_s2.

-        image_embedding: Tensor = self.mm_projector.__call__(selected_feature.to(dtype=self.dtype))
+        image_embedding: Tensor = self.mm_projector.__call__(
+            selected_feature.to(
+                device=self.mm_projector.device,
+                dtype=self.mm_projector.dtype,
+            )
+        )
+
+        # Append image end token to every image embedding.
+        image_end_token_embedding: Tensor = self.llm.get_input_embeddings().__call__(
+            torch.tensor(
+                self.config.image_end_token_id,
+                device=text_embedding.device,
+                dtype=torch.long,
+            ).view(1, -1)
+        )  # Shape: (1, 1, dim_feature)
+        image_end_token_embedding = image_end_token_embedding.expand(
+            image_embedding.shape[0], 1, -1
+        )  # Shape: (n_images, 1, dim_feature)
+        image_embedding = torch.concat(
+            [
+                image_embedding.to(device=text_embedding.device),
+                image_end_token_embedding,
+            ],
+            dim=1,
+        )

         n_images, n_feature, dim_feature = image_embedding.shape
         image_embedding = image_embedding.view(n_images * n_feature, dim_feature)

-        text_embedding[image_token_mask.to(device=text_embedding.device)] = image_embedding.to(
-            device=text_embedding.device
-        )
+        text_embedding[image_token_mask.to(device=text_embedding.device)] = image_embedding

         return text_embedding
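
The end-of-image concatenation above is why `image_pad_len` moves from 121 to 122: each crop now contributes its projected features plus one "\n" embedding. A standalone shape check, with random tensors standing in for the projector output and the embedding table (dimensions are illustrative):

```python
import torch

n_images, n_feature, dim = 2, 121, 5120              # projector output per image crop
image_embedding = torch.randn(n_images, n_feature, dim)

end_embedding = torch.randn(1, 1, dim)               # embedding of image_end_token_id (198)
end_embedding = end_embedding.expand(n_images, 1, -1)

image_embedding = torch.concat([image_embedding, end_embedding], dim=1)
print(image_embedding.shape)                         # torch.Size([2, 122, 5120])
# The 122 <image> placeholders per crop in the text are overwritten by exactly these rows.
```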
processing_vila.py CHANGED
@@ -1,9 +1,8 @@
 from typing import List, Optional, Tuple, cast

-import numpy as np
 import transformers.image_transforms as image_transforms
 import transformers.image_utils as image_utils
-from numpy.typing import NDArray
+import transformers.utils.logging
 from PIL.Image import Image
 from torch import Tensor
 from transformers.feature_extraction_utils import BatchFeature
@@ -12,19 +11,21 @@ from transformers.image_processing_utils_fast import BaseImageProcessorFast
 from transformers.image_utils import ImageInput, VideoInput
 from transformers.models.siglip.image_processing_siglip import SiglipImageProcessor
 from transformers.models.siglip.image_processing_siglip_fast import SiglipImageProcessorFast
-from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
+from transformers.processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
 from transformers.tokenization_utils import PreTrainedTokenizer
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase, TextInput

+logger = transformers.utils.logging.get_logger(__name__)

-class VILAProcessorKwargs(ProcessingKwargs, total=False):
+
+class VILAProcessorProcessingKwargs(ProcessingKwargs, total=False):
     _defaults = {}  # type: ignore


 class VILAProcessorOutput(BatchFeature):
-    input_ids: List[List[int]] | NDArray[np.int64] | Tensor
-    attention_mask: List[List[int]] | NDArray[np.int64] | Tensor
-    pixel_values: Optional[List[NDArray[np.float32]] | NDArray[np.float32] | Tensor]
+    input_ids: List[List[int]] | Tensor
+    attention_mask: List[List[int]] | Tensor
+    pixel_values: Optional[List[Tensor] | Tensor]


 class VILAProcessor(ProcessorMixin):
@@ -67,56 +68,68 @@ class VILAProcessor(ProcessorMixin):
             **kwargs,
         )

-        self.image_pad_len = image_pad_len if image_pad_len is not None else 121
+        self.image_pad_len = image_pad_len if image_pad_len is not None else 122
         self.max_tiles = max_tiles if max_tiles is not None else 12
         self.min_tiles = min_tiles if min_tiles is not None else 1

     def __call__(
         self,
+        text: TextInput | List[TextInput],
         images: Optional[ImageInput] = None,
-        text: Optional[TextInput | List[TextInput]] = None,
-        audio: None = None,
         videos: Optional[VideoInput] = None,
-        **kwargs: Unpack[VILAProcessorKwargs],
+        audio: None = None,
+        **kwargs: Unpack[VILAProcessorProcessingKwargs],
     ) -> VILAProcessorOutput:
-        # Validate arguments.
-        assert text is not None and text != [], "text must be provided"
-        assert not kwargs.get("is_split_into_words", False), "is_split_into_words=True is not supported"
+        """Preprocesses inputs for VILA.

-        output_kwargs = self._merge_kwargs(
-            VILAProcessorKwargs,  # type: ignore
+        Args:
+            text: The text to be processed.
+            images: The images to be processed.
+            videos: The videos to be processed.
+            audio: Not available.
+            **kwargs: Additional arguments for processing.
+
+        Returns:
+            The processed inputs that can be fed to the model.
+        """
+
+        merged_kwargs = self._merge_kwargs(
+            VILAProcessorProcessingKwargs,  # type: ignore
             tokenizer_init_kwargs=self.tokenizer.init_kwargs,
             **kwargs,
         )

-        # Process images.
-        if images is not None and images != []:
-            image_inputs, num_cropped_images = self._process_images(
-                images=images,
-                **output_kwargs["images_kwargs"],
-            )
-        else:
-            # If no images are provided, do not define pixel_values.
-            image_inputs = BatchFeature()
-            num_cropped_images = []
+        text, images, videos = self._prepare_inputs(
+            text=text,
+            images=images,
+            videos=videos,
+        )
+
+        # Process videos.
+        text, images, video_flags = self._treat_videos_as_image_seqs(
+            text=text,
+            images=images,
+            videos=videos,
+        )

-        # TODO: video processing.
+        # Process images.
+        image_inputs, num_cropped_images = self._process_images(
+            images=images,
+            **merged_kwargs["images_kwargs"],
+        )

         # Process text.
-        text = text if isinstance(text, list) else [text]
-
         text = self._pad_image_tokens_by_num_crops(
             text,
             num_cropped_images=num_cropped_images,
+            video_flags=video_flags,
         )

-        text = self._pad_image_tokens_by_num_embeddings(
-            text,
-        )
+        text = self._pad_image_tokens_by_num_embeddings(text)

         text_inputs = self.tokenizer.__call__(
             text,
-            **output_kwargs["text_kwargs"],
+            **merged_kwargs["text_kwargs"],
         )

         return VILAProcessorOutput(
@@ -140,7 +153,8 @@ class VILAProcessor(ProcessorMixin):
         """

         # TODO: Support more image processors.
-        assert isinstance(self.image_processor, (SiglipImageProcessor, SiglipImageProcessorFast))
+        if not isinstance(self.image_processor, (SiglipImageProcessor, SiglipImageProcessorFast)):
+            raise NotImplementedError

         assert self.image_processor.size["height"] == self.image_processor.size["width"]
         cropped_size = self.image_processor.size["height"]
@@ -156,61 +170,68 @@ class VILAProcessor(ProcessorMixin):

     def _pad_image_tokens_by_num_crops(
         self,
-        text: List[TextInput],
+        text: List[str],
         *,
         num_cropped_images: List[int],
-    ) -> List[TextInput]:
-        """Pads each <image> to num_cropped_images of "<image>\n\n".
+        video_flags: List[bool],
+    ) -> List[str]:
+        """Pads each \\<image> to num_cropped_images of "\\<image>\\n" for images and "\\<video>" for videos.

         Args:
             text: The text to be padded.
             num_cropped_images: The number of cropped images for each image token.
+            video_flags: A list of flags indicating whether the num_cropped_images item is a video.

         Returns:
             The padded text.
         """
+
+        assert len(num_cropped_images) == len(
+            video_flags
+        ), "num_cropped_images and video_flags must have the same length."
+
         image_token: str = cast(str, self.tokenizer.image_token)

-        # Validate arguments.
-        num_images = len(num_cropped_images)
-        num_image_tokens = sum([item.count(image_token) for item in text])
-        assert num_images == num_image_tokens, (
-            f"Number of image tokens ({num_image_tokens}) in text does not match "
-            f"the number of images ({num_images})."
-        )
+        return_text: List[str] = []

-        assert all(
-            image_pad_len > 0 for image_pad_len in num_cropped_images
-        ), "All image padding lengths should be positive integers."
+        for text_item in text:
+            return_text_item: str = ""

-        # Pad image tokens.
-        image_idx = 0
-        padded_text: List[TextInput] = []
+            # Repeatedly find image_token in the text.
+            while image_token in text_item:
+                image_pos = text_item.find(image_token)

-        for i in range(len(text)):
-            padded_text_item = ""
-            remaining_text = text[i]
+                if image_pos != -1 and len(num_cropped_images) > 0:
+                    num_crops = num_cropped_images.pop(0)
+                    video_flag = video_flags.pop(0)

-            while True:
-                token_pos = remaining_text.find(image_token)
-                if token_pos == -1:
-                    padded_text_item += remaining_text
+                    return_text_item += (
+                        text_item[:image_pos] + (image_token if video_flag else (image_token + "\n")) * num_crops
+                    )
+                    text_item = text_item[image_pos + len(image_token) :]
+
+                else:
                     break

-                padded_text_item += remaining_text[:token_pos] + ((image_token + "\n") * num_cropped_images[image_idx])
+            # Must place outside the while loop.
+            if image_token in text_item:
+                raise ValueError("Too many image tokens in the text.")
+
+            return_text_item += text_item
+            text_item = ""

-                image_idx += 1
-                remaining_text = remaining_text[token_pos + len(image_token) :]
+            return_text.append(return_text_item)

-            padded_text.append(padded_text_item)
+        if len(num_cropped_images) != 0:
+            raise ValueError("Too many images provided.")

-        return padded_text
+        return return_text

     def _pad_image_tokens_by_num_embeddings(
         self,
-        text: List[TextInput],
-    ) -> List[TextInput]:
-        """Pads each <image> to image_pad_len times of "<image>".
+        text: List[str],
+    ) -> List[str]:
+        """Pads each \\<image> to image_pad_len times of "\\<image>".

         Args:
             text: The text to be padded.
@@ -218,56 +239,151 @@ class VILAProcessor(ProcessorMixin):
         Returns:
             The padded text.
         """
-        image_token: str = cast(str, self.tokenizer.image_token)

-        padded_text: List[TextInput] = []
-
-        for i in range(len(text)):
-            padded_text_item = ""
-            remaining_text = text[i]
-
-            while True:
-                token_pos = remaining_text.find(image_token)
-                if token_pos == -1:
-                    padded_text_item += remaining_text
-                    break
-
-                padded_text_item += remaining_text[:token_pos] + (image_token * self.image_pad_len)
-
-                remaining_text = remaining_text[token_pos + len(image_token) :]
-
-            padded_text.append(padded_text_item)
-
-        return padded_text
+        return [
+            text_item.replace(
+                cast(str, self.tokenizer.image_token), cast(str, self.tokenizer.image_token) * self.image_pad_len
+            )
+            for text_item in text
+        ]
+
+    @staticmethod
+    def _prepare_inputs(
+        text: TextInput | List[TextInput],
+        images: Optional[ImageInput],
+        videos: Optional[VideoInput],
+    ) -> Tuple[List[str], List[Image], List[List[Image]]]:
+        # Prepare text.
+        text = text if isinstance(text, list) else [text]
+
+        # Prepare images.
+        if images is not None:
+            image_list = cast(List, image_utils.make_flat_list_of_images(images))
+            images = [image_transforms.to_pil_image(image) for image in image_list]
+        else:
+            images = cast(List[Image], [])
+
+        # Prepare videos.
+        if videos is not None:
+            video_list = cast(List[List], image_utils.make_batched_videos(videos))
+            videos = [[image_transforms.to_pil_image(image) for image in video] for video in video_list]
+        else:
+            videos = cast(List[List[Image]], [])
+
+        return text, images, videos

     def _process_images(
         self,
-        images: ImageInput,
-        **kwargs: Unpack[VILAProcessorKwargs],
+        images: List[Image],
+        **kwargs: Unpack[ImagesKwargs],
     ) -> Tuple[BatchFeature, List[int]]:
-        images_flatten = cast(
-            List[Image] | List[NDArray] | List[Tensor],
-            image_utils.make_flat_list_of_images(images),
-        )
-
         cropped_images: List[Image] = []
         num_cropped_images: List[int] = []
-        for image in images_flatten:
-            pil_image: Image = image_transforms.to_pil_image(image)
-            single_cropped_images = self._crop_image(pil_image)
+
+        for image in images:
+            single_cropped_images = self._crop_image(image)

             cropped_images.extend(single_cropped_images)
             num_cropped_images.append(len(single_cropped_images))

-        image_inputs = self.image_processor(
+        if len(cropped_images) == 0:
+            # The image processor may not properly handle empty image lists.
+            # This is a workaround to avoid errors.
+            return BatchFeature(), num_cropped_images
+
+        image_inputs = self.image_processor.__call__(
            cropped_images,
            **kwargs,
        )

         return image_inputs, num_cropped_images

+    def _treat_videos_as_image_seqs(
+        self, text: List[str], images: List[Image], videos: List[List[Image]]
+    ) -> Tuple[List[str], List[Image], List[bool]]:
+        """Treats videos as image sequences.
+
+        This method will replace all video tokens in the text with #frame image tokens,
+        and insert the corresponding images into the images list.
+
+        Args:
+            text: The text to be processed.
+            images: The images to be processed.
+            videos: The videos to be processed.
+
+        Returns:
+            The processed text and images, and a list of flags indicating whether the images are from videos.
+        """
+
+        image_token = cast(str, self.tokenizer.image_token)
+        video_token = cast(str, self.tokenizer.video_token)
+
+        return_text: List[str] = []
+        return_images: List[Image] = []
+        return_video_flags: List[bool] = []
+
+        for text_item in text:
+            return_text_item: str = ""
+
+            # Repeatedly find image_token or video_token in the text.
+            while image_token in text_item or video_token in text_item:
+                image_pos = text_item.find(image_token)
+                video_pos = text_item.find(video_token)
+
+                # If not found, set position to the end of the text.
+                if image_pos == -1:
+                    image_pos = len(text_item)
+                if video_pos == -1:
+                    video_pos = len(text_item)
+
+                if image_pos != len(text_item) and len(images) > 0 and image_pos < video_pos:
+                    # Take an image and keep the image token if:
+                    # - an image token is found, and
+                    # - there are images left, and
+                    # - the image token is before the first video token.
+
+                    image = images.pop(0)
+                    return_images.append(image)
+                    return_video_flags.append(False)
+
+                    return_text_item += text_item[: image_pos + len(image_token)]
+                    text_item = text_item[image_pos + len(image_token) :]
+
+                elif video_pos != len(text_item) and len(videos) > 0 and video_pos < image_pos:
+                    # Take a video and replace the video token with #frame image tokens if:
+                    # - a video token is found, and
+                    # - there are videos left, and
+                    # - the video token is before the first image token.
+
+                    video = videos.pop(0)
+                    return_images.extend(video)
+                    return_video_flags.extend([True] * len(video))
+
+                    return_text_item += text_item[:video_pos] + image_token * len(video)
+                    text_item = text_item[video_pos + len(video_token) :]
+                else:
+                    break
+
+            # Must place outside the while loop.
+            if image_token in text_item:
+                raise ValueError("Too many image tokens in the text.")
+            if video_token in text_item:
+                raise ValueError("Too many video tokens in the text.")
+
+            return_text_item += text_item
+            text_item = ""
+
+            return_text.append(return_text_item)
+
+        if len(images) != 0:
+            raise ValueError("Too many images provided.")
+        if len(videos) != 0:
+            raise ValueError("Too many videos provided.")
+
+        return return_text, return_images, return_video_flags
+

-def dynamic_preprocess(image, min_num=1, max_num=12, image_size=384, use_thumbnail=True):
+def dynamic_preprocess(image: Image, min_num: int, max_num: int, image_size: int, use_thumbnail=True) -> List[Image]:
     orig_width, orig_height = image.size
     aspect_ratio = orig_width / orig_height

@@ -309,7 +425,9 @@ def dynamic_preprocess(image, min_num=1, max_num=12, image_size=384, use_thumbnail=True):
     return processed_images


-def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
+def find_closest_aspect_ratio(
+    aspect_ratio: float, target_ratios: List[Tuple[int, int]], width: int, height: int, image_size: int
+) -> Tuple[int, int]:
     best_ratio_diff = float("inf")
     best_ratio = (1, 1)
     area = width * height
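
An end-to-end usage sketch of the updated processor, assuming the files in this commit are loadable with `trust_remote_code=True`; `MODEL_DIR` is a placeholder for this repository (or a local clone), and the exact output keys depend on `VILAProcessorOutput`:

```python
from PIL import Image
from transformers import AutoProcessor

MODEL_DIR = "path/to/this/repo"  # placeholder
processor = AutoProcessor.from_pretrained(MODEL_DIR, trust_remote_code=True)

image = Image.new("RGB", (640, 480))
# Videos are now accepted and treated as frame sequences: every <video> placeholder
# is rewritten into one <image> placeholder per frame before tokenization.
video = [Image.new("RGB", (640, 480)) for _ in range(4)]

inputs = processor(
    text="<image> Describe the photo. <video> Then summarize the clip.",
    images=[image],
    videos=[video],
    return_tensors="pt",
)
print({key: getattr(value, "shape", type(value)) for key, value in inputs.items()})
```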
processor_config.json CHANGED
@@ -2,7 +2,7 @@
   "auto_map": {
     "AutoProcessor": "processing_vila.VILAProcessor"
   },
-  "image_pad_len": 121,
+  "image_pad_len": 122,
   "max_tiles": 12,
   "min_tiles": 1,
   "processor_class": "VILAProcessor"
special_tokens_map.json CHANGED
@@ -38,5 +38,6 @@
     "normalized": false,
     "rstrip": false,
     "single_word": false
-  }
+  },
+  "video_token": "<video>"
 }
tokenizer_config.json CHANGED
@@ -217,6 +217,14 @@
       "rstrip": false,
       "single_word": false,
       "special": true
+    },
+    "151670": {
+      "content": "<video>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
     }
   },
   "additional_special_tokens": [
@@ -246,7 +254,8 @@
   "eos_token": "<|im_end|>",
   "errors": "replace",
   "extra_special_tokens": {
-    "image_token": "<image>"
+    "image_token": "<image>",
+    "video_token": "<video>"
   },
   "image_token": "<image>",
   "legacy": false,
@@ -256,5 +265,6 @@
   "processor_class": "VILAProcessor",
   "split_special_tokens": false,
   "tokenizer_class": "Qwen2Tokenizer",
-  "unk_token": null
+  "unk_token": null,
+  "video_token": "<video>"
 }
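
With the updated tokenizer files in place, the new special token is addressable the same way `<image>` already is (the processor relies on both attributes); a minimal check, with the repo path as a placeholder:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/this/repo")
print(tokenizer.image_token, tokenizer.convert_tokens_to_ids("<image>"))  # <image> 151666
print(tokenizer.video_token, tokenizer.convert_tokens_to_ids("<video>"))  # <video> 151670
```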