AndyZijianZhang committed on
Commit
139253d
·
1 Parent(s): 996e0be

refactor: keep original files

Browse files
Files changed (4) hide show
  1. auto_model.py +0 -1261
  2. config.json +4 -4
  3. modeling_vila.py +1214 -144
  4. modeling_vila_hf.py +191 -0
auto_model.py DELETED
@@ -1,1261 +0,0 @@
1
- import copy
2
- import json
3
- import logging
4
- import math
5
- import os
6
- import os.path
7
- import os.path as osp
8
- import shutil
9
- import warnings
10
- from abc import ABC
11
- from collections import OrderedDict, defaultdict, deque
12
- from copy import deepcopy
13
- from itertools import chain
14
- from threading import Thread
15
- from typing import Any, Dict, List, Optional, Tuple, Union
16
-
17
- import torch
18
- import torch.distributed as dist
19
- import torch.nn as nn
20
- import torch.nn.functional as F
21
- import torchvision
22
- from einops import rearrange
23
- from PIL import Image
24
- from transformers import (
25
- AutoConfig,
26
- AutoModel,
27
- AutoProcessor,
28
- AutoTokenizer,
29
- GenerationConfig,
30
- LogitsProcessor,
31
- PretrainedConfig,
32
- PreTrainedModel,
33
- Qwen2Config,
34
- Qwen2ForCausalLM,
35
- Qwen2PreTrainedModel,
36
- TextIteratorStreamer,
37
- )
38
- from transformers.modeling_outputs import CausalLMOutputWithPast
39
- from transformers.modeling_utils import ContextManagers, no_init_weights
40
-
41
- from .auto_processor import VILAProcessor
42
- from .base_projector import MultimodalProjector, MultimodalProjectorConfig
43
- from .builder import build_llm_and_tokenizer
44
- from .configuration_vila import VILAConfig
45
- from .constants import *
46
- from .conversation import SeparatorStyle, default_conversation
47
- from .distributed import all_gather as vila_all_gather
48
- from .loss import soft_cross_entropy
49
- from .media import extract_media
50
- from .media_encoder import BasicImageEncoder, BasicVideoEncoder
51
- from .mm_utils import process_image, process_images
52
- from .model_utils_packing import set_seqlens_in_batch
53
- from .siglip_encoder import SiglipVisionTower, SiglipVisionTowerDynamicS2, SiglipVisionTowerS2
54
- from .tokenizer_utils import tokenize_conversation
55
- from .utils import get_model_config, load_tokenizer_then_handle_media_tokens_and_chat_template
56
-
57
- # from llava.constants import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, NUM_EXTRA_TOKENS
58
-
59
# ease debugging: alias the builtin ``input`` so it stays reachable even if a
# local variable named ``input`` shadows it (several loops below reuse that name)
python_input = input
61
-
62
-
63
# quick hack for remote code
def get_pg_manager():
    """Stub for the sequence-parallelism process-group manager.

    Remote-code (HF Hub) deployments run without sequence parallelism,
    so there is never a manager to hand back.
    """
    return None
66
-
67
-
68
def get_model_weights_dtype(model: nn.Module):
    """Placeholder kept for interface compatibility with the training code.

    Does nothing and yields None (identical to the original ``pass`` body).
    """
    return None
70
-
71
-
72
def build_mm_projector(model_type_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
    """Create the multimodal projector, or return None when no type/path is given.

    When ``config.resume_path`` is set, weights are loaded from disk;
    otherwise a freshly initialized projector is built from the type string.
    """
    if model_type_or_path is None:
        return None
    if config.resume_path:
        # resuming: the projector checkpoint must already exist on disk
        assert os.path.exists(model_type_or_path), f"Resume mm projector path {model_type_or_path} does not exist!"
        return MultimodalProjector.from_pretrained(model_type_or_path, config)
    # build from scratch
    projector_config = MultimodalProjectorConfig(model_type_or_path)
    return MultimodalProjector(projector_config, config)
84
-
85
-
86
def check_dot_in_model_path(model_path: str):
    """Check if the model path contains dot, which will affect the remote code loading."""
    # Local checkouts are judged by their absolute path; remote repo ids verbatim.
    candidate = osp.abspath(model_path) if osp.isdir(model_path) else model_path
    return "." in candidate
95
-
96
-
97
def get_vila_version(model_path: str) -> Optional[str]:
    """Infer the VILA model family from a model path or repo name.

    Performs a case-insensitive substring match against the known version
    tags and returns the first one found.

    Args:
        model_path: local path or HF repo id of the model.

    Returns:
        The matching version tag, or None when the path matches no known
        family (the original annotation claimed ``str`` but None was
        already a possible return value).
    """
    VERSIONS = ["vila1.5", "vila-u", "longvila", "nvila", "vila-m3"]
    lowered = model_path.lower()  # lowercase once instead of per candidate
    for version in VERSIONS:
        if version in lowered:
            return version
    return None
103
-
104
-
105
def generate_jinja_template(conv_mode: str) -> str:
    """Return a Jinja chat-template string for a legacy conversation mode.

    Used when converting vila1.5-era checkpoints, which stored the chat
    format as a conversation-mode name instead of a Jinja template.

    Args:
        conv_mode: one of "vicuna_v1", "llama_3" or "hermes_2".

    Raises:
        NotImplementedError: for any other conversation mode.
    """
    if conv_mode == "vicuna_v1":
        return """{% set system_prompt = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. " %}
{% set roles = ["user", "assistant"] %}
{% set sep = " " %}

{{ system_prompt }}

{% for message in messages %}
{% if message['role'] == roles[0] %}
{{ "USER: " }}{{ sep }}{{ message['content'] }}{{ sep }}
{% else %}
{{ "ASSISTANT: " }}{{ sep }}{{ message['content'] }}{{ sep }}
{% endif %}
{% endfor %}
{% if messages[-1]['role'] == 'user' %}
{{ "ASSISTANT:" }}
{% endif %}
"""
    elif conv_mode == "llama_3":
        return """{% set system_prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n\\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.<|eot_id|>" %}
{% set roles = ["<|start_header_id|>user<|end_header_id|>\\n\\n", "<|start_header_id|>assistant<|end_header_id|>\\n\\n"]%}
{% set sep = "<|eot_id|>" %}

{{ system_prompt }}
{% for message in messages %}
{% if message['role'] == 'user' %}
{{ roles[0] }}{{ message['content'] }}{{ sep }}
{% else %}
{{ roles[1] }}{{ message['content'] }}{{ sep }}
{% endif %}
{% endfor %}
{% if messages[-1]['role'] == 'user' %}
{{ roles[1] }}
{% endif %}
"""
    elif conv_mode == "hermes_2":
        return """{% set system_prompt = "<|im_start|>system\nAnswer the questions." %}
{% set roles = ["<|im_start|>user\n", "<|im_start|>assistant\n"] %}
{% set sep = "<|im_end|>" %}

{{ system_prompt }}{{ sep }}

{% for message in messages %}
{% if message['role'] == 'user' %}
{{ roles[0] }}{{ message['content'] }}{{ sep }}
{% else %}
{{ roles[1] }}{{ message['content'] }}{{ sep }}
{% endif %}
{% endfor %}"""
    else:
        raise NotImplementedError(f"Jinja template generation is not implemented for {conv_mode}.")
157
-
158
-
159
def build_vision_tower(model_name_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
    """Build the (SigLIP) vision tower and record its hidden size on *config*.

    Returns None when no tower path is given; raises NotImplementedError
    for non-SigLIP architectures.
    """
    ## skip vision tower instantiation
    if model_name_or_path is None:
        return None

    # When resuming (except for radio towers), resolve the architecture
    # name from the checkpoint's saved config instead of the path string.
    arch_name = None
    if config.resume_path and "radio" not in model_name_or_path:
        assert os.path.exists(model_name_or_path), f"Resume vision tower path {model_name_or_path} does not exist!"
        saved_cfg = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
        arch_name = saved_cfg.architectures[0].lower()
    tower_name = model_name_or_path if arch_name is None else arch_name

    s2_enabled = getattr(config, "s2", False)
    dynamic_s2_enabled = getattr(config, "dynamic_s2", False)

    if "siglip" not in tower_name:
        raise NotImplementedError(f"Unknown vision tower: {model_name_or_path}")

    if dynamic_s2_enabled:
        vision_tower = SiglipVisionTowerDynamicS2(model_name_or_path, config)
    elif s2_enabled:
        vision_tower = SiglipVisionTowerS2(model_name_or_path, config)
    else:
        vision_tower = SiglipVisionTower(model_name_or_path, config)

    # S2 variants expose their (concatenated) hidden size directly on the tower.
    if s2_enabled or dynamic_s2_enabled:
        config.mm_hidden_size = vision_tower.hidden_size
    else:
        config.mm_hidden_size = vision_tower.config.hidden_size
    return vision_tower
188
-
189
-
190
class VILAPretrainedModel(PreTrainedModel):
    """Base class wiring together the three VILA components (LLM, vision
    tower, multimodal projector) plus checkpoint load/save and the
    dev-checkpoint -> HF-remote-code conversion utilities."""

    config_class = VILAConfig
    main_input_name = "input_embeds"
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True
    _no_split_modules = ["Qwen2DecoderLayer", "SiglipEncoderLayer"]

    def __init__(self, config: VILAConfig, *args, **kwargs):
        """Build all sub-modules from *config* and move them to device.

        Raises:
            ValueError: when the config does not carry exactly the three
                sub-configs (llm / vision tower / mm projector).
        """
        super().__init__(config)
        self.config = config
        cfgs = get_model_config(config)
        if len(cfgs) == 3:
            llm_cfg, vision_tower_cfg, mm_projector_cfg = cfgs
        else:
            raise ValueError("`llm_cfg` `mm_projector_cfg` `vision_tower_cfg` not found in the config.")

        # loading on auto by default
        device_map = kwargs.get("device_map", "auto")
        self.mm_projector = build_mm_projector(mm_projector_cfg, config)
        self.vision_tower = build_vision_tower(vision_tower_cfg, config)
        if device_map in ["auto", "cuda"]:
            self.mm_projector = self.mm_projector.cuda()
            self.vision_tower = self.vision_tower.cuda()
        # set device_map auto can automatically shard llm to different devices
        self.llm, self.tokenizer = self.init_llm(llm_cfg, config, device_map=device_map)

        # NOTE(ligeng): hard code to set padding_side to left
        self.tokenizer.padding_side = "left"
        # TODO(ligeng): need to add other decoders from config
        self.encoders = {"image": BasicImageEncoder(self), "video": BasicVideoEncoder(self)}

        self.post_config()
        self.is_loaded = True

        assert (
            self.llm is not None or self.vision_tower is not None or self.mm_projector is not None
        ), "At least one of the components must be instantiated."

    @classmethod
    def convert_vila_dev_ckpt_to_remote(
        self,
        model_path: str,
        output_dir: str = None,
        vila_version: str | None = None,
        conv_mode: str | None = None,
        copy: bool = False,
        copy_weights: bool = True,
        copy_code: bool = True,
        *model_args,
        **kwargs,
    ):
        """Convert a VILA dev checkpoint into an HF remote-code checkpoint.

        Copies (or symlinks) weights and the remote-code .py files into
        *output_dir* and rewrites config.json (architectures, auto_map,
        model_type) so the result loads via `trust_remote_code=True`.
        For vila1.5/vila-m3 a Jinja chat template is also generated from
        *conv_mode*.

        NOTE(review): despite the ``@classmethod`` decorator the first
        parameter is named ``self`` — it is bound to the class object.

        Raises:
            ValueError: when a path contains a dot (breaks remote-code
                module names) or a required conv_mode is missing.
        """
        # assert type(self) == VILAForCasualLM, "This method is only available for VILAForCasualLM."
        assert model_path != output_dir, "model_path and output_dir cannot be the same"
        if os.path.isdir(model_path):
            model_path = model_path
        else:
            # not a local directory: fetch the snapshot from the HF Hub
            from huggingface_hub import HfApi, snapshot_download

            model_path = snapshot_download(model_path)
            print("downloading HF model to", model_path)

        if check_dot_in_model_path(model_path) and output_dir is None:
            raise ValueError(
                f"Model path {model_path} contains a dot, which will affect the remote code loading. Please specify the output directory without dot in the path to fix this issue."
            )
        if output_dir is not None and "." in output_dir:
            raise ValueError(
                f"Output directory {output_dir} contains a dot, which will affect the remote code loading. Please specify a valid output directory without dots."
            )

        if copy:
            print("copy is set to True, copying weights and code to output_dir")
            copy_weights = copy_code = True
        # copy weights and code to output_dir
        self.copy_or_symlink_directory(model_path, output_dir, copy=copy_weights)
        self.copy_remote_py_files(output_dir, copy=copy_code)

        if vila_version is None:
            vila_version = get_vila_version(output_dir)

        cfg_path = os.path.join(output_dir, "config.json")
        config = json.load(open(cfg_path))
        config["version"] = "2.0"  # nvila tag
        config["architectures"] = ["VILAForCausalLM"]
        config["auto_map"] = {
            "AutoProcessor": "auto_processor.VILAProcessor",
            "AutoConfig": "modeling_vila.VILAConfig",
            "AutoModel": "modeling_vila.VILAForCausalLM",
            "AutoModelForCausalLM": "modeling_vila.VILAForCausalLM",
        }
        # vila1.5 legacy support
        config["model_type"] = "vila"
        if vila_version in ["vila1.5", "vila-m3"]:
            if conv_mode is None:
                raise ValueError(f"Please specify the conversation mode for {output_dir}.")
            config["chat_template"] = conv_mode
            jinja_template = generate_jinja_template(conv_mode)
            jinja_path = os.path.join(output_dir, f"{conv_mode}.jinja")
            with open(jinja_path, "w") as f:
                f.write(jinja_template)
        json.dump(config, open(cfg_path, "w"), indent=2)

        ##########################################################################################
        # Re-load the rewritten config and normalize the tokenizer (media
        # tokens + chat template), saving it under <output_dir>/llm.
        config = AutoConfig.from_pretrained(output_dir, trust_remote_code=True)
        tokenizer = load_tokenizer_then_handle_media_tokens_and_chat_template(output_dir, config)
        tokenizer.save_pretrained(osp.join(output_dir, "llm"))
        ##########################################################################################

    @classmethod
    def copy_or_symlink_directory(cls, model_path, output_dir, copy=True):
        """Mirror *model_path* into *output_dir*, copying when *copy* is
        True and symlinking otherwise; pre-existing destinations are removed."""
        # Create output directory if it doesn't exist
        os.makedirs(output_dir, exist_ok=True)
        # Create symlinks for all files in model_path to output_dir
        for item in os.listdir(model_path):
            src_path = os.path.join(model_path, item)
            dst_path = os.path.join(output_dir, item)

            # Remove existing file/directory at destination if it exists
            if os.path.exists(dst_path):
                if os.path.islink(dst_path):
                    os.unlink(dst_path)
                elif os.path.isdir(dst_path):
                    shutil.rmtree(dst_path)
                else:
                    os.remove(dst_path)

            # Create symlink (or a full copy when requested)
            if copy:
                if os.path.isdir(src_path):
                    shutil.copytree(src_path, dst_path)
                else:
                    shutil.copy2(src_path, dst_path)
                print(f"Copied {src_path} to {dst_path}")
            else:
                os.symlink(src_path, dst_path)
                print(f"Created symlink from {src_path} to {dst_path}")

    @classmethod
    def copy_remote_py_files(cls, output_dir, copy=True):
        """Copy (or symlink) this package's .py/.jinja files into
        *output_dir* so the checkpoint is loadable as remote code; an
        INSTRUCTIONS.md is prepended to the destination README.md."""
        ## copy .py and README for next loading remote code
        current_file_path = os.path.abspath(__file__)
        current_folder = os.path.dirname(current_file_path)
        for file_name in os.listdir(current_folder):
            if file_name == "INSTRUCTIONS.md":
                src_fname = os.path.join(current_folder, file_name)
                dst_fname = os.path.join(output_dir, "README.md")
                # keep any existing README content after the instructions
                if os.path.exists(dst_fname):
                    old_reamde = open(dst_fname).read()
                else:
                    old_reamde = ""
                with open(src_fname) as src, open(dst_fname, "w") as dst:
                    dst.write(src.read())
                    dst.write(old_reamde)
                print("[HF remote code] REAMDE ", src_fname, "to", dst_fname)
            if file_name.endswith(".py") or file_name.endswith(".jinja"):
                full_file_name = os.path.join(current_folder, file_name)
                if os.path.isfile(full_file_name):
                    if copy:
                        shutil.copy(full_file_name, output_dir)
                        print("[HF remote code] copying", full_file_name, "to", output_dir)
                    else:
                        # symlink to ease development
                        if os.path.exists(os.path.join(output_dir, file_name)):
                            os.remove(os.path.join(output_dir, file_name))
                        os.symlink(full_file_name, os.path.join(output_dir, file_name))
                        print("[HF remote code] linking", full_file_name, "to", output_dir)

    def save_pretrained(self, output_dir, state_dict=None, **kwargs):
        """Save each sub-module (llm / vision_tower / mm_projector) into its
        own subdirectory, plus the top-level config and remote-code files.

        The per-module state dicts are sliced out of the full *state_dict*
        by key-prefix matching.
        """
        if state_dict is None:
            # otherwise fetch from deepspeed
            # state_dict = accelerator.get_state_dict(is_deepspeed_enabled)
            state_dict = self.state_dict()

        if getattr(self, "tokenizer", None):
            self.tokenizer.save_pretrained(osp.join(output_dir, "llm"))

        if self.get_llm():
            print(f"saving llm to {osp.join(output_dir, 'llm')}")
            self.llm.config._name_or_path = osp.join(output_dir, "llm")
            llm_state_dict = OrderedDict({k.split("llm.")[-1]: v for k, v in state_dict.items() if "llm" in k})
            self.llm.save_pretrained(os.path.join(output_dir, "llm"), state_dict=llm_state_dict)
            self.config.llm_cfg = self.llm.config

        if self.get_vision_tower():
            print(f"saving vision_tower to {osp.join(output_dir, 'vision_tower')}")
            self.vision_tower.config._name_or_path = osp.join(output_dir, "vision_tower")
            vision_tower_state_dict = OrderedDict(
                {k.split("vision_tower.vision_tower.")[-1]: v for k, v in state_dict.items() if "vision_tower" in k}
            )
            self.vision_tower.vision_tower.save_pretrained(
                os.path.join(output_dir, "vision_tower"),
                state_dict=vision_tower_state_dict,
            )
            self.vision_tower.image_processor.save_pretrained(os.path.join(output_dir, "vision_tower"))
            self.config.vision_tower_cfg = self.vision_tower.config
            # drop auto_map so the saved tower loads without remote code
            # (radio towers are the exception and keep theirs)
            if hasattr(self.config.vision_tower_cfg, "auto_map"):
                if "radio" not in self.get_vision_tower().__class__.__name__.lower():
                    delattr(self.config.vision_tower_cfg, "auto_map")

        if self.get_mm_projector():
            print(f"saving mm_projector to {osp.join(output_dir, 'mm_projector')}")
            self.mm_projector.config._name_or_path = osp.join(output_dir, "mm_projector")
            mm_projector_state_dict = OrderedDict(
                {k.split("mm_projector.")[-1]: v for k, v in state_dict.items() if "mm_projector" in k}
            )
            self.mm_projector.save_pretrained(
                os.path.join(output_dir, "mm_projector"),
                state_dict=mm_projector_state_dict,
            )
            self.config.mm_projector_cfg = self.mm_projector.config

        ## update and save top-level config
        self.config._name_or_path = output_dir
        self.config.architectures = [self.__class__.__name__]
        self.config.save_pretrained(output_dir)

        ## copy .py and README for next loading remote code
        self.copy_remote_py_files(output_dir)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[str] = None,
        *model_args,
        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        ignore_mismatched_sizes: bool = False,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        use_safetensors: Optional[bool] = None,
        weights_only: bool = True,
        **kwargs,
    ):
        """Load by reading the remote config and delegating to `_from_config`.

        NOTE(review): the explicit keyword parameters mirror the
        transformers signature but are currently ignored — only the path
        and **kwargs are used.
        """
        # print("DEBUG2", kwargs); input()
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
        return cls._from_config(config, **kwargs)

    def init_llm(self, llm_config, config, *args, **kwargs):
        """Build the language model + tokenizer and derive vocab bookkeeping.

        Also records the pad-like token ids used to skip padding during
        embedding fusion, and resizes the LLM embeddings to the tokenizer.
        """
        self.llm, self.tokenizer = build_llm_and_tokenizer(llm_config, config, *args, **kwargs)
        # hard coded for NVILA
        # variables for XGrammar
        # print("DEBUG", len(self.tokenizer.added_tokens_encoder.keys()), self.tokenizer.added_tokens_encoder.keys())
        NUM_EXTRA_TOKENS = len(self.tokenizer.added_tokens_encoder.keys())

        # token ids treated as padding when fusing text/media embeddings
        self.pad_token_list = (
            self.tokenizer.pad_token_id,
            self.tokenizer.eos_token_id,
            self.tokenizer.tokenize("<|endoftext|>")[0],  # for qwen
        )

        # TODO: SENTINEL_TOKEN is not added, need to check with Zhijian
        self.vocab_size = self.tokenizer.vocab_size + NUM_EXTRA_TOKENS
        # XGrammar tokenizer and grammar compiler
        # lazy init only when specified json output during inference
        self.grammar_compiler = None
        self.llm.resize_token_embeddings(len(self.tokenizer))
        return self.llm, self.tokenizer

    def post_config(self):
        """Cast sub-modules to fp16, sync train/eval mode with the LLM and
        backfill any missing sub-configs on the top-level config."""
        ######################################################################
        # TODO: need to check dtype with jason
        self.llm = self.llm.to(torch.float16)
        self.mm_projector = self.mm_projector.to(torch.float16)
        self.vision_tower = self.vision_tower.to(torch.float16)
        ######################################################################
        self.training = self.llm.training
        if self.training:
            self.train()
        else:
            self.eval()
        ## configuration
        if getattr(self.config, "llm_cfg", None) is None:
            self.config.llm_cfg = self.llm.config
        if getattr(self.config, "vision_tower_cfg", None) is None:
            self.config.vision_tower_cfg = self.vision_tower.config
        if getattr(self.config, "mm_projector_cfg", None) is None:
            self.config.mm_projector_cfg = self.mm_projector.config

    def get_llm(self):
        """Return the language model (unwrapping a legacy one-element list)."""
        llm = getattr(self, "llm", None)
        if type(llm) is list:
            llm = llm[0]
        return llm

    def get_lm_head(self):
        """Return the LLM's lm_head module, or None when absent."""
        lm_head = getattr(self.get_llm(), "lm_head", None)
        return lm_head

    def get_vision_tower(self):
        """Return the vision tower (unwrapping a legacy one-element list)."""
        vision_tower = getattr(self, "vision_tower", None)
        if type(vision_tower) is list:
            vision_tower = vision_tower[0]
        return vision_tower

    def get_mm_projector(self):
        """Return the multimodal projector (unwrapping a legacy one-element list)."""
        mm_projector = getattr(self, "mm_projector", None)
        if type(mm_projector) is list:
            mm_projector = mm_projector[0]
        return mm_projector

    def freezed_module_patch(self):
        """
        Huggingface will call model.train() at each training_step. To ensure the expected behaviors for modules like dropout, batchnorm, etc., we need to call model.eval() for the freezed modules.
        """
        if self.training:
            if self.get_llm() and not getattr(self.config, "tune_language_model", False):
                pass
                # logging.warning("Caution: Your LLM is currently in training mode, ensuring accurate gradient computation. Please be vigilant, particularly regarding BatchNorm and Dropout operations.")
            if self.get_vision_tower() and not getattr(self.config, "tune_vision_tower", False):
                self.get_vision_tower().eval()
            if self.get_mm_projector() and not getattr(self.config, "tune_mm_projector", False):
                self.get_mm_projector().eval()
504
-
505
-
506
- class VILAForCausalLM(VILAPretrainedModel):
507
    def __init__(self, config: VILAConfig, *args, **kwargs):
        """Construct the causal-LM wrapper; all wiring happens in the base class."""
        super().__init__(config, *args, **kwargs)
509
-
510
    def merge_features_for_dynamic_s2(self, image_features, block_sizes):
        """Re-assemble per-block dynamic-S2 vision features into per-image maps.

        *image_features* holds one entry per crop/block across all images,
        in order; *block_sizes* holds per-image (h, w) block grids, or None
        for images encoded as a single block.  Blocks are consumed
        sequentially via ``block_cnt`` and merged back into spatial maps,
        then the multi-scale maps are resized to a common size and
        concatenated on the channel dimension.

        Returns:
            (list of per-image feature maps, list of effective block sizes).
        """
        scales = self.get_vision_tower().scales
        resize_output_to_scale_idx = self.get_vision_tower().resize_output_to_scale_idx

        image_features_each_image = []
        new_block_sizes = []
        block_cnt = 0  # running index into the flat block list
        for block_size_each_image in block_sizes:
            if block_size_each_image is None:
                # single-block image: reshape tokens to a square map and
                # replicate it once per scale so channel counts line up
                cur_features = image_features[block_cnt : block_cnt + 1]
                cur_features = rearrange(cur_features, "1 (h w) c -> 1 c h w", h=int(cur_features.shape[1] ** 0.5))
                cur_features = cur_features.repeat(1, len(scales), 1, 1)
                image_features_each_image.append(cur_features)
                new_block_sizes.append((1, 1))
                block_cnt += 1
            else:
                # multi-block image: merge each scale's blocks into one map
                cur_features_each_scale = []
                for scale in scales[:-1]:
                    num_blocks_this_scale = (scale // scales[0]) ** 2
                    cur_features_each_scale.append(
                        self.merge_chessboard(
                            image_features[block_cnt : block_cnt + num_blocks_this_scale],
                            num_split_h=scale // scales[0],
                            num_split_w=scale // scales[0],
                        )
                    )  # 1 * C * H * W
                    block_cnt += num_blocks_this_scale
                # the last scale uses the image's own (possibly non-square) grid
                num_blocks_last_scale = block_size_each_image[0] * block_size_each_image[1]
                cur_features_each_scale.append(
                    self.merge_chessboard(
                        image_features[block_cnt : block_cnt + num_blocks_last_scale],
                        num_split_h=block_size_each_image[0],
                        num_split_w=block_size_each_image[1],
                    )
                )  # 1 * C * H * W
                block_cnt += num_blocks_last_scale

                # resize and concat features from different scales
                output_size = cur_features_each_scale[resize_output_to_scale_idx].shape[-2:]
                cur_features = torch.cat(
                    [
                        F.interpolate(cur_features_each_scale[i].to(torch.float32), size=output_size, mode="area").to(
                            cur_features_each_scale[i].dtype
                        )
                        for i in range(len(cur_features_each_scale))
                    ],
                    dim=1,
                )
                # cur_features = rearrange(cur_features, "1 c h w -> (h w) c")

                image_features_each_image.append(cur_features)

                if resize_output_to_scale_idx == len(scales) - 1 or resize_output_to_scale_idx == -1:
                    new_block_sizes.append(block_size_each_image)
                else:
                    new_block_sizes.append(
                        (
                            scales[resize_output_to_scale_idx] // scales[0],
                            scales[resize_output_to_scale_idx] // scales[0],
                        )
                    )

        # every flat block must have been consumed exactly once
        assert block_cnt == len(image_features)

        return image_features_each_image, new_block_sizes
575
-
576
    def encode_images(self, images, block_sizes: Optional[Tuple[int, ...]] = None):
        """Encode images through the vision tower and project them to LLM space.

        In dynamic-S2 mode the per-block features are merged, split into a
        chessboard for the projector's downsampling, projected, and merged
        back; otherwise it is a straight tower -> projector pipeline.

        Args:
            images: batch of image tensors for the vision tower.
            block_sizes: per-image block grids for dynamic S2 (None entries
                mean single-block).  Annotation collapsed from the
                original's redundant ``Optional[Optional[...]]``.

        Returns:
            Projected features — a stacked tensor when all images yield the
            same token count, otherwise a list of per-image tensors.
        """
        if block_sizes is None:
            block_sizes = [None] * len(images)
        if getattr(self.config, "dynamic_s2", False):
            image_features = self.get_vision_tower()(images)
            image_features, new_block_sizes = self.merge_features_for_dynamic_s2(image_features, block_sizes)

            image_features = [
                self.split_chessboard(x, block_size[0], block_size[1])
                for x, block_size in zip(image_features, new_block_sizes)
            ]  # list of B * C * H * W tensors
            image_features = torch.cat(
                [rearrange(x, "b c h w -> b (h w) c") for x in image_features], dim=0
            )  # B * N * C
            image_features = self.get_mm_projector()(image_features)
            image_features = list(
                image_features.split([block_size[0] * block_size[1] for block_size in new_block_sizes], dim=0)
            )
            image_features = [
                self.merge_chessboard(x, block_size[0], block_size[1])
                for x, block_size in zip(image_features, new_block_sizes)
            ]  # list of 1 * C * H * W tensors
            image_features = [rearrange(x, "1 c h w -> (h w) c") for x in image_features]  # list of N * C tensors
            # stack only when every image produced the same number of tokens
            if all([feature.shape[0] == image_features[0].shape[0] for feature in image_features]):
                image_features = torch.stack(image_features, dim=0)
        else:
            image_features = self.get_vision_tower()(images)
            image_features = self.get_mm_projector()(image_features)
        return image_features
605
-
606
    def train(self, mode: bool = True):
        """Standard ``nn.Module.train`` that returns ``self`` for chaining."""
        super().train(mode)
        return self
609
-
610
    def _embed(
        self,
        input_ids: torch.Tensor,
        media: Dict[str, List[torch.Tensor]],
        media_config: Dict[str, Dict[str, Any]],
        labels: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Fuse text token embeddings with media embeddings.

        Replaces each media placeholder token in *input_ids* with the
        corresponding encoder output (consumed in order), drops padding
        tokens, truncates to the tokenizer's max length, and re-pads the
        batch.

        Returns:
            (inputs_embeds, labels, attention_mask) batched tensors.

        Raises:
            ValueError: when media embeddings are left unconsumed (i.e. the
                prompt held fewer media tokens than media items supplied).
        """
        # NOTE(ligeng): deep copy to avoid modifying the original media and media_config
        media = copy.deepcopy(media)
        media_config = copy.deepcopy(media_config)

        # default: everything ignored for loss / everything attended
        labels = labels if labels is not None else torch.full_like(input_ids, IGNORE_INDEX)
        attention_mask = attention_mask.to(dtype=torch.bool) if attention_mask is not None else torch.ones_like(input_ids, dtype=torch.bool)

        PROCESS_GROUP_MANAGER = get_pg_manager()
        if PROCESS_GROUP_MANAGER is not None:
            for name in media:
                self.encoders[name].end_tokens = None

        # Extract text and media embeddings
        text_embeds = self.llm.model.embed_tokens(input_ids)
        if media is not None:
            media_embeds = self.__embed_media_tokens(media, media_config)
        else:
            # no media was provided, so we just return an empty dict
            media_embeds = {}

        # This is a workaround to make sure the dummy embeddings are consumed
        # (zero-weighted so they do not change the result, only the graph)
        while media_embeds.get("dummy"):
            dummy_embed = media_embeds["dummy"].popleft()
            text_embeds += torch.sum(dummy_embed) * 0

        # Remove padding
        batch_size = labels.shape[0]
        text_embeds = [text_embeds[k][attention_mask[k]] for k in range(batch_size)]
        labels = [labels[k][attention_mask[k]] for k in range(batch_size)]
        # zijzhang: also apply to input_ids
        input_ids = [input_ids[k][attention_mask[k]] for k in range(batch_size)]

        # Build inverse mapping from token ID to media name
        media_tokens = {}
        for name, token_id in self.tokenizer.media_token_ids.items():
            media_tokens[token_id] = name

        # Fuse text and media embeddings
        inputs_m, labels_m = [], []
        for k in range(batch_size):
            inputs_mk, labels_mk = [], []
            pos = 0
            while pos < len(labels[k]):
                if input_ids[k][pos].item() in media_tokens:
                    # media placeholder: splice in the next queued embedding,
                    # masking its positions out of the loss
                    end = pos + 1
                    name = media_tokens[input_ids[k][pos].item()]
                    input = media_embeds[name].popleft()
                    label = torch.full([input.shape[0]], IGNORE_INDEX, device=labels[k].device, dtype=labels[k].dtype)
                elif input_ids[k][pos].item() in self.pad_token_list:
                    # skip pad tokens
                    end = pos + 1
                    pos = end
                    continue
                else:
                    # plain text run: extend until the next media token
                    end = pos
                    while end < len(labels[k]) and input_ids[k][end].item() not in media_tokens:
                        end += 1
                    input = text_embeds[k][pos:end]
                    label = labels[k][pos:end]

                inputs_mk.append(input)
                labels_mk.append(label)
                pos = end
            inputs_m.append(torch.cat(inputs_mk, dim=0))
            labels_m.append(torch.cat(labels_mk, dim=0))
        inputs, labels = inputs_m, labels_m

        # Check if all media embeddings are consumed
        for name in media_embeds:
            if media_embeds[name]:
                raise ValueError(f"Not all {name} embeddings are consumed! Still {len(media_embeds[name])} left.")

        # Truncate sequences to `model_max_length` as media embeddings are inserted
        inputs, labels = self.__truncate_sequence(inputs, labels)

        # Pad sequences to the longest one in the batch
        return self.__batchify_sequence(inputs, labels)
695
-
696
    def __embed_media_tokens(
        self,
        media: Dict[str, List[torch.Tensor]],
        media_config: Dict[str, Dict[str, Any]],
    ) -> Dict[str, List[torch.Tensor]]:
        """Run each media type through its encoder and queue the results.

        Returns a dict of deques keyed by media name; during training a
        "dummy" queue may be added so ranks with no media of some type
        still call the encoder (otherwise collective ops would hang).
        """
        embeds = defaultdict(deque)
        for name in media:
            if self.training:
                # Gather metainfo of media objects from all ranks
                info = [{"shape": tensor.shape, "dtype": tensor.dtype} for tensor in media.get(name, [])]
                infos = list(chain(vila_all_gather(info)))

                # The entire batch does not contain any media objects of this type.
                if not infos:
                    continue

                # Create a dummy tensor to ensure the encoder is called, otherwise the training will hang.
                if media.get(name) is None or len(media[name]) == 0:
                    dummy = torch.zeros(infos[0]["shape"], dtype=infos[0]["dtype"], device=self.device)
                    embeds["dummy"].extend(self.encoders[name]([dummy], media_config[name]))
                    continue
            embeds[name] = deque(self.encoders[name](media[name], media_config[name]))
        return embeds
719
-
720
- def __truncate_sequence(
721
- self, inputs: List[torch.Tensor], labels: List[torch.Tensor]
722
- ) -> Tuple[torch.Tensor, torch.Tensor]:
723
- if self.training and any(len(input) > self.tokenizer.model_max_length for input in inputs):
724
- warnings.warn(f"Truncating sequences to `model_max_length` ({self.tokenizer.model_max_length}).")
725
- inputs = [input[: self.tokenizer.model_max_length] for input in inputs]
726
- labels = [label[: self.tokenizer.model_max_length] for label in labels]
727
- return inputs, labels
728
-
729
    def __batchify_sequence(
        self, inputs: List[torch.Tensor], labels: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Pad fused sequences to a common length and stack them.

        Padding side follows ``self.tokenizer.padding_side``; padded
        positions get zero embeddings, IGNORE_INDEX labels, and False in
        the attention mask.

        Returns:
            (inputs, labels, attention_mask) with shapes
            (B, T, H), (B, T) and (B, T) respectively.
        """
        batch_size = len(inputs)
        device = inputs[0].device
        hidden_size = inputs[0].shape[1]
        max_length = max(inputs[k].shape[0] for k in range(batch_size))
        attention_mask = torch.ones((batch_size, max_length), dtype=torch.bool, device=device)

        inputs_p, labels_p = [], []
        for k in range(batch_size):
            # size_pk: amount of padding needed for this sequence
            size_pk = max_length - inputs[k].shape[0]
            inputs_pk = torch.zeros((size_pk, hidden_size), dtype=inputs[k].dtype, device=device)
            labels_pk = torch.full((size_pk,), IGNORE_INDEX, dtype=labels[k].dtype, device=device)
            if self.tokenizer.padding_side == "right":
                attention_mask[k, inputs[k].shape[0] :] = False
                inputs_pk = torch.cat([inputs[k], inputs_pk], dim=0)
                labels_pk = torch.cat([labels[k], labels_pk], dim=0)
            else:
                # left padding: mask out the first `size_pk` positions
                attention_mask[k, : -inputs[k].shape[0]] = False
                inputs_pk = torch.cat([inputs_pk, inputs[k]], dim=0)
                labels_pk = torch.cat([labels_pk, labels[k]], dim=0)
            inputs_p.append(inputs_pk)
            labels_p.append(labels_pk)

        inputs = torch.stack(inputs_p, dim=0)
        labels = torch.stack(labels_p, dim=0)
        return inputs, labels, attention_mask
757
-
758
- def repack_multimodal_data(self, inputs_embeds, attention_mask, position_ids, labels):
759
- # Handle sequence parallelism
760
- PROCESS_GROUP_MANAGER = get_pg_manager()
761
-
762
- # We do re-sharding instead of packing here to ensure the sequence length is the same across all ranks.
763
- if PROCESS_GROUP_MANAGER is not None:
764
- sp_degree = PROCESS_GROUP_MANAGER.sp_degree
765
- sp_rank = PROCESS_GROUP_MANAGER.sp_rank
766
- sp_group = PROCESS_GROUP_MANAGER.sp_pg
767
- ring_degree = PROCESS_GROUP_MANAGER.ring_degree
768
- ring_rank = PROCESS_GROUP_MANAGER.ring_rank
769
- ring_type = PROCESS_GROUP_MANAGER.ring_type
770
- ulysses_degree = PROCESS_GROUP_MANAGER.ulysses_degree
771
- ulysses_rank = PROCESS_GROUP_MANAGER.ulysses_rank
772
-
773
- bs, shard_seqlen = position_ids.shape
774
- sp_seq_len = [torch.zeros(1, dtype=torch.int64, device=position_ids.device) for _ in range(sp_degree)]
775
- dist.all_gather(sp_seq_len, torch.tensor(shard_seqlen, device=position_ids.device), group=sp_group)
776
- sp_seq_len_cat = torch.cat(sp_seq_len, dim=0)
777
-
778
- if sp_rank == 0:
779
- original_start_id = 0
780
- else:
781
- original_start_id = torch.sum(sp_seq_len_cat[:sp_rank]).item()
782
- original_end_id = torch.sum(sp_seq_len_cat[: sp_rank + 1]).item()
783
-
784
- # Gather attention_mask, position_ids, labels and input_embeds
785
- all_inputs_embeds = torch.zeros(
786
- bs,
787
- torch.sum(sp_seq_len_cat),
788
- inputs_embeds.shape[-1],
789
- dtype=inputs_embeds.dtype,
790
- device=inputs_embeds.device,
791
- ).contiguous()
792
- all_inputs_embeds[:, original_start_id:original_end_id, :] += inputs_embeds
793
- dist.barrier(group=sp_group)
794
- dist.all_reduce(all_inputs_embeds, group=sp_group)
795
- dist.barrier(group=sp_group)
796
-
797
- attention_mask_list = [
798
- torch.zeros((bs, sp_seq_len[i]), dtype=attention_mask.dtype, device=attention_mask.device)
799
- for i in range(sp_degree)
800
- ]
801
- position_ids_list = [
802
- torch.zeros((bs, sp_seq_len[i]), dtype=position_ids.dtype, device=position_ids.device)
803
- for i in range(sp_degree)
804
- ]
805
- labels_list = [
806
- torch.zeros((bs, sp_seq_len[i]), dtype=labels.dtype, device=labels.device) for i in range(sp_degree)
807
- ]
808
-
809
- dist.all_gather(attention_mask_list, attention_mask, group=sp_group)
810
- dist.all_gather(position_ids_list, position_ids, group=sp_group)
811
- dist.all_gather(labels_list, labels, group=sp_group)
812
-
813
- effective_seqlen_list = [attention_mask_list[i].sum(dim=-1) for i in range(sp_degree)]
814
- effective_seqlen = torch.stack(effective_seqlen_list, dim=-1)
815
- effective_seqlen_batch_list = torch.unbind(effective_seqlen, dim=0)
816
-
817
- global_attention_mask_list = []
818
- global_position_ids_list = []
819
- global_labels_list = []
820
- global_inputs_embeds_list = []
821
- for i in range(bs):
822
- global_attention_mask_batch_list = []
823
- global_position_ids_batch_list = []
824
- global_labels_batch_list = []
825
- global_inputs_embeds_batch_list = []
826
- for j in range(sp_degree):
827
- eff_len = effective_seqlen_batch_list[i][j]
828
- prev_len = torch.sum(sp_seq_len_cat[:j]).item() if j > 0 else 0
829
-
830
- global_attention_mask_batch_list.append(attention_mask_list[j][i, :eff_len])
831
- global_position_ids_batch_list.append(position_ids_list[j][i, :eff_len])
832
- global_labels_batch_list.append(labels_list[j][i, :eff_len])
833
- global_inputs_embeds_batch_list.append(all_inputs_embeds[i, prev_len : prev_len + eff_len, :])
834
- global_attention_mask_list.append(torch.cat(global_attention_mask_batch_list, dim=0))
835
- global_position_ids_list.append(torch.cat(global_position_ids_batch_list, dim=0))
836
- global_labels_list.append(torch.cat(global_labels_batch_list, dim=0))
837
- global_inputs_embeds_list.append(torch.cat(global_inputs_embeds_batch_list, dim=0))
838
-
839
- global_attention_mask = torch.nn.utils.rnn.pad_sequence(
840
- global_attention_mask_list, batch_first=True, padding_value=False
841
- )
842
- global_position_ids = torch.nn.utils.rnn.pad_sequence(
843
- global_position_ids_list, batch_first=True, padding_value=-1
844
- )
845
- global_labels = torch.nn.utils.rnn.pad_sequence(
846
- global_labels_list, batch_first=True, padding_value=IGNORE_INDEX
847
- )
848
- global_inputs_embeds = torch.nn.utils.rnn.pad_sequence(
849
- global_inputs_embeds_list, batch_first=True, padding_value=0
850
- )
851
-
852
- # Re-shard the inputs
853
- if ring_degree > 1:
854
- total_effective_seqlen = torch.sum(effective_seqlen, dim=1)
855
- new_seqlen_per_rank = total_effective_seqlen // sp_degree
856
- assert torch.all(
857
- total_effective_seqlen % sp_degree == 0
858
- ), "total_effective_seqlen must be divisible by sp_degree"
859
-
860
- max_new_seqlen = torch.max(new_seqlen_per_rank).item()
861
-
862
- new_attention_mask = torch.zeros(
863
- (bs, max_new_seqlen), dtype=global_attention_mask.dtype, device=global_attention_mask.device
864
- )
865
- new_position_ids = torch.zeros(
866
- (bs, max_new_seqlen), dtype=global_position_ids.dtype, device=global_position_ids.device
867
- )
868
- new_labels = torch.full(
869
- (bs, max_new_seqlen), IGNORE_INDEX, dtype=global_labels.dtype, device=global_labels.device
870
- )
871
- new_inputs_embeds = torch.zeros(
872
- (bs, max_new_seqlen, global_inputs_embeds.shape[-1]),
873
- dtype=global_inputs_embeds.dtype,
874
- device=global_inputs_embeds.device,
875
- )
876
-
877
- if ring_type == "ring_varlen":
878
- for i in range(bs):
879
- start_idx = new_seqlen_per_rank[i] * sp_rank
880
- end_idx = start_idx + new_seqlen_per_rank[i]
881
- new_attention_mask[i, : new_seqlen_per_rank[i]] = global_attention_mask[i, start_idx:end_idx]
882
- new_position_ids[i, : new_seqlen_per_rank[i]] = global_position_ids[i, start_idx:end_idx]
883
- new_labels[i, : new_seqlen_per_rank[i]] = global_labels[i, start_idx:end_idx]
884
- new_inputs_embeds[i, : new_seqlen_per_rank[i], :] = global_inputs_embeds[
885
- i, start_idx:end_idx, :
886
- ]
887
- elif ring_type == "zigzag_ring_varlen":
888
- chunk_size = total_effective_seqlen // (2 * sp_degree)
889
- for i in range(bs):
890
- # Zigzag pattern indices
891
- if sp_degree == ring_degree:
892
- forward_rank_idx = sp_rank
893
- backward_rank_idx = 2 * sp_degree - sp_rank - 1
894
- else:
895
- ulysses_offset = ulysses_rank * ring_degree * 2
896
- forward_rank_idx = ring_rank + ulysses_offset
897
- backward_rank_idx = sp_degree - ring_rank - 1 + ulysses_offset
898
-
899
- # Calculate start and end indices for the forward and backward zigzag
900
- start_idx_fwd = forward_rank_idx * chunk_size[i]
901
- end_idx_fwd = start_idx_fwd + chunk_size[i]
902
-
903
- start_idx_bwd = backward_rank_idx * chunk_size[i]
904
- end_idx_bwd = start_idx_bwd + chunk_size[i]
905
-
906
- # Fill new tensors with zigzag data
907
- new_attention_mask[i, : chunk_size[i]] = global_attention_mask[i, start_idx_fwd:end_idx_fwd]
908
- new_attention_mask[i, chunk_size[i] : 2 * chunk_size[i]] = global_attention_mask[
909
- i, start_idx_bwd:end_idx_bwd
910
- ]
911
-
912
- new_position_ids[i, : chunk_size[i]] = global_position_ids[i, start_idx_fwd:end_idx_fwd]
913
- new_position_ids[i, chunk_size[i] : 2 * chunk_size[i]] = global_position_ids[
914
- i, start_idx_bwd:end_idx_bwd
915
- ]
916
-
917
- new_labels[i, : chunk_size[i]] = global_labels[i, start_idx_fwd:end_idx_fwd]
918
- new_labels[i, chunk_size[i] : 2 * chunk_size[i]] = global_labels[i, start_idx_bwd:end_idx_bwd]
919
-
920
- new_inputs_embeds[i, : chunk_size[i], :] = global_inputs_embeds[i, start_idx_fwd:end_idx_fwd, :]
921
- new_inputs_embeds[i, chunk_size[i] : 2 * chunk_size[i], :] = global_inputs_embeds[
922
- i, start_idx_bwd:end_idx_bwd, :
923
- ]
924
- else:
925
- raise ValueError(f"Invalid ring_type: {ring_type}")
926
- else:
927
- global_seq_len = global_attention_mask.shape[-1]
928
- seq_len_sharded = global_seq_len // sp_degree
929
- start_idx_reshard = seq_len_sharded * sp_rank
930
- end_idx_reshard = start_idx_reshard + seq_len_sharded if sp_rank < sp_degree - 1 else global_seq_len
931
-
932
- new_attention_mask = torch.narrow(
933
- global_attention_mask, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard
934
- )
935
- new_position_ids = torch.narrow(
936
- global_position_ids, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard
937
- )
938
- new_labels = torch.narrow(global_labels, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard)
939
- new_inputs_embeds = torch.narrow(
940
- global_inputs_embeds, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard
941
- )
942
-
943
- return new_inputs_embeds, new_attention_mask, new_position_ids, new_labels
944
-
945
- device = inputs_embeds.device
946
- batch_size = inputs_embeds.shape[0]
947
- seqlens = [attention_mask[k].sum().item() for k in range(batch_size)]
948
-
949
- # Pack all sequences together
950
- inputs_embeds_p = [inputs_embeds[k][attention_mask[k]] for k in range(batch_size)]
951
- attention_mask_p = [torch.ones(seqlens[k], dtype=torch.int, device=device) for k in range(batch_size)]
952
- position_ids_p = [torch.arange(seqlens[k], dtype=torch.int, device=device) for k in range(batch_size)]
953
- labels_p = [labels[k][attention_mask[k]] for k in range(batch_size)]
954
-
955
- # Add one dummy token at the end of the packed sequence to ensure that `_get_unpacked_data` will be called
956
- inputs_embeds_p.append(torch.zeros(1, inputs_embeds.shape[-1], dtype=inputs_embeds.dtype, device=device))
957
- attention_mask_p.append(torch.tensor([0], dtype=torch.int, device=device))
958
- position_ids_p.append(torch.tensor([0], dtype=torch.int, device=device))
959
- labels_p.append(torch.tensor([IGNORE_INDEX], dtype=torch.int, device=device))
960
-
961
- # Mask the first token of each sequence to avoid contamination
962
- for label in labels_p:
963
- label[0] = IGNORE_INDEX
964
-
965
- # Batch the data
966
- inputs_embeds_p = torch.cat(inputs_embeds_p, dim=0).unsqueeze(0)
967
- attention_mask_p = torch.cat(attention_mask_p, dim=0).unsqueeze(0)
968
- position_ids_p = torch.cat(position_ids_p, dim=0).unsqueeze(0)
969
- labels_p = torch.cat(labels_p, dim=0).unsqueeze(0)
970
-
971
- if hasattr(
972
- self, "pad_to_multiple_of"
973
- ): # related to quantization, please refer to ModelArguments for more information.
974
- assert len(labels_p.shape) == 2
975
- batch_size, max_length, cur_length = labels_p.shape[0], labels_p.shape[1], labels_p.shape[1]
976
- hidden_size = inputs_embeds_p.shape[-1]
977
-
978
- if max_length % self.pad_to_multiple_of != 0:
979
- max_length = ((max_length // self.pad_to_multiple_of) + 1) * self.pad_to_multiple_of
980
- difference = max_length - cur_length
981
-
982
- inputs_embeds_p = torch.cat(
983
- (
984
- inputs_embeds_p,
985
- torch.full((batch_size, difference, hidden_size), self.llm.pad_token_id).to(inputs_embeds_p),
986
- ),
987
- dim=1,
988
- )
989
- labels_p = torch.cat((labels_p, torch.full((batch_size, difference), IGNORE_INDEX).to(labels_p)), dim=1)
990
- attention_mask_p = torch.cat(
991
- (
992
- attention_mask_p,
993
- torch.zeros((batch_size, difference), dtype=torch.bool).to(attention_mask_p),
994
- ),
995
- dim=1,
996
- )
997
- position_ids_p = torch.cat(
998
- (position_ids_p, torch.full((batch_size, difference), -1).to(position_ids_p)), dim=1
999
- )
1000
-
1001
- return inputs_embeds_p, attention_mask_p, position_ids_p, labels_p
1002
-
1003
- def get_xgr_logits_processor(self, response_format) -> List[LogitsProcessor]:
1004
- raise NotImplementedError("This method is not implemented for VILA model.")
1005
- # Convert response format to logits processor
1006
- import xgrammar as xgr
1007
-
1008
- logging.info("[XGrammar] Compiling grammar for contrained output")
1009
-
1010
- if self.grammar_compiler is None:
1011
- # logging.info(f"[XGrammar] {self.tokenizer}, {self.tokenizer.vocab_size}, {self.vocab_size}")
1012
- self.grammar_compiler = xgr.GrammarCompiler(
1013
- xgr.TokenizerInfo.from_huggingface(self.tokenizer, vocab_size=self.vocab_size)
1014
- )
1015
-
1016
- if response_format.type == "json_schema":
1017
- compiled_grammar = self.grammar_compiler.compile_json_schema(
1018
- response_format.json_schema.schema_,
1019
- indent=2,
1020
- )
1021
- else:
1022
- compiled_grammar = self.grammar_compiler.compile_builtin_json_grammar()
1023
-
1024
- return [xgr.contrib.hf.LogitsProcessor(compiled_grammar)]
1025
-
1026
- def forward(
1027
- self,
1028
- input_ids: torch.LongTensor = None,
1029
- media: Optional[Dict[str, List[torch.Tensor]]] = None,
1030
- images: Optional[torch.FloatTensor] = None,
1031
- media_config: Optional[List] = None,
1032
- pixel_values: Optional[torch.FloatTensor] = None,
1033
- attention_mask: Optional[torch.Tensor] = None,
1034
- position_ids: Optional[torch.LongTensor] = None,
1035
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1036
- inputs_embeds: Optional[torch.FloatTensor] = None,
1037
- labels: Optional[torch.LongTensor] = None,
1038
- packing: bool = True,
1039
- force_packing: bool = False,
1040
- seqlens_in_batch: Optional[torch.LongTensor] = None,
1041
- dpo_forward: bool = False,
1042
- **kwargs,
1043
- ) -> Union[Tuple, CausalLMOutputWithPast]:
1044
- self.freezed_module_patch()
1045
-
1046
- if images is not None:
1047
- if media is not None:
1048
- raise ValueError("Both 'media' and 'images' are provided. Please provide only one.")
1049
- print("The 'images' argument is deprecated. Please use 'media' instead.")
1050
- media = {"image": images}
1051
-
1052
- if media_config is None:
1053
- media_config = defaultdict(dict)
1054
-
1055
- if inputs_embeds is None:
1056
- inputs_embeds, labels, attention_mask = self._embed(input_ids, media, media_config, labels, attention_mask)
1057
-
1058
- if force_packing or (packing and self.training and not dpo_forward):
1059
- if seqlens_in_batch is None:
1060
- seqlens_in_batch = torch.sum(attention_mask, dim=1)
1061
- set_seqlens_in_batch(seqlens_in_batch)
1062
-
1063
- (inputs_embeds, attention_mask, position_ids, labels) = self.repack_multimodal_data(
1064
- inputs_embeds, attention_mask, position_ids, labels
1065
- )
1066
-
1067
- outputs = self.llm(
1068
- inputs_embeds=inputs_embeds,
1069
- attention_mask=attention_mask,
1070
- position_ids=position_ids,
1071
- past_key_values=past_key_values,
1072
- labels=labels,
1073
- **kwargs,
1074
- )
1075
-
1076
- if self.training and getattr(self.config, "time_token_ids", []):
1077
- outputs.loss = soft_cross_entropy(
1078
- outputs.logits,
1079
- labels,
1080
- soft_tokens=self.config.time_token_ids,
1081
- std=self.config.soft_ce_std,
1082
- )
1083
-
1084
- if dpo_forward:
1085
- return outputs.logits, labels
1086
-
1087
- return outputs
1088
-
1089
- # TODO(ligeng): check how qwen implements this function
1090
- # @torch.inference_mode()
1091
- def generate(
1092
- self,
1093
- input_ids: Optional[torch.FloatTensor] = None,
1094
- media: Optional[Dict[str, List[torch.Tensor]]] = None,
1095
- media_config: Dict[str, Dict[str, Any]] = None,
1096
- attention_mask: Optional[torch.LongTensor] = None,
1097
- return_output_ids_only: bool = False,
1098
- **generation_kwargs,
1099
- ) -> torch.LongTensor:
1100
- """
1101
- input_tokens: <image> describe the image
1102
- media: [Tensor(1, 3, 384, 384), ]
1103
- ----------->
1104
- input_tokens: 36000 001 002 003 004
1105
- input_emds: <media emd> 001 002 003 004
1106
- """
1107
- # NOTE: hard code to move to GPU
1108
- # input_ids = input_ids.cuda()
1109
- # media = {k: [v.cuda() if v is not None for v in media[k]] for k in media}
1110
- # if attention_mask is not None:
1111
- # attention_mask = attention_mask.cuda()
1112
- inputs_embeds, _, attention_mask = self._embed(input_ids, media, media_config, None, attention_mask)
1113
- output_ids = self.llm.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **generation_kwargs)
1114
-
1115
- if return_output_ids_only:
1116
- return_value = output_ids
1117
- else:
1118
- # by default, return the input_ids and output_ids concatenated to keep consistency with the community VLMs like qwen
1119
- generation_config = generation_kwargs.get("generation_config", None)
1120
- if generation_config is not None:
1121
- num_generations = generation_config.num_return_sequences
1122
- repeat_input_ids = input_ids.repeat_interleave(num_generations, dim=0)
1123
- return_value = torch.cat([repeat_input_ids, output_ids], dim=-1)
1124
- else:
1125
- return_value = torch.cat([input_ids, output_ids], dim=-1)
1126
-
1127
- return return_value
1128
-
1129
- @torch.inference_mode()
1130
- def generate_content(
1131
- self,
1132
- prompt: Union[str, List],
1133
- generation_config: Optional[GenerationConfig] = None,
1134
- response_format=None,
1135
- ) -> str:
1136
- # TODO(zhijianl): Support directly taking conversation as input
1137
- conversation = [{"from": "human", "value": prompt}]
1138
-
1139
- # Convert response format to logits processor
1140
- xgr_logits_processor = None
1141
-
1142
- # Extract media from the conversation
1143
-
1144
- # TODO (extract and preprocess should be done together, as the preprocess of image and video can be different, i.e. when dynamic res is used)
1145
- media = extract_media(conversation, self.config)
1146
-
1147
- # Process media
1148
- media_config = defaultdict(dict)
1149
- for name in media:
1150
- if name == "image":
1151
- if len(media["image"]) == 1 and self.config.image_aspect_ratio in ["dynamic", "dynamic_s2"]:
1152
- self.config.image_processor = self.vision_tower.image_processor
1153
- if self.config.image_aspect_ratio == "dynamic":
1154
- images = process_image(media["image"][0], self.config, None, enable_dynamic_res=True).half()
1155
- conversation[0]["value"] = conversation[0]["value"].replace(
1156
- DEFAULT_IMAGE_TOKEN, f"{DEFAULT_IMAGE_TOKEN}\n" * images.shape[0]
1157
- )
1158
- else:
1159
- if type(self.config.s2_scales) is str:
1160
- self.config.s2_scales = list(map(int, self.config.s2_scales.split(",")))
1161
- images, block_sizes = process_image(
1162
- media["image"][0], self.config, None, enable_dynamic_s2=True
1163
- )
1164
- images = images.half()
1165
- media_config[name]["block_sizes"] = [block_sizes]
1166
- else:
1167
- images = process_images(media["image"], self.vision_tower.image_processor, self.config).half()
1168
- media[name] = [image for image in images]
1169
- elif name == "video":
1170
- if self.config.image_aspect_ratio == "dynamic" and self.config.video_max_tiles > 1:
1171
- media[name] = [
1172
- process_images(
1173
- images,
1174
- self.vision_tower.image_processor,
1175
- self.config,
1176
- enable_dynamic_res=True,
1177
- max_tiles=self.config.video_max_tiles,
1178
- ).half()
1179
- for images in media[name]
1180
- ]
1181
- elif self.config.image_aspect_ratio == "dynamic_s2" and self.config.video_max_tiles > 1:
1182
- self.config.image_processor = self.vision_tower.image_processor
1183
- if type(self.config.s2_scales) is str:
1184
- self.config.s2_scales = list(map(int, self.config.s2_scales.split(",")))
1185
- media[name] = [
1186
- torch.cat(
1187
- [
1188
- process_image(
1189
- image,
1190
- self.config,
1191
- None,
1192
- enable_dynamic_s2=True,
1193
- max_tiles=self.config.video_max_tiles,
1194
- )[0].half()
1195
- for image in images
1196
- ]
1197
- )
1198
- for images in media[name]
1199
- ]
1200
- else:
1201
- media[name] = [
1202
- process_images(images, self.vision_tower.image_processor, self.config).half()
1203
- for images in media[name]
1204
- ]
1205
- else:
1206
- raise ValueError(f"Unsupported media type: {name}")
1207
-
1208
- # Tokenize the conversation
1209
- input_ids = tokenize_conversation(conversation, self.tokenizer, add_generation_prompt=True).unsqueeze(0).cuda()
1210
-
1211
- # Set up the generation config
1212
- generation_config = generation_config or self.default_generation_config
1213
-
1214
- # print("input_ids", input_ids.shape)
1215
- # print(input_ids)
1216
- # print(self.tokenizer.batch_decode(input_ids))
1217
- # print("media", {k: len(v) for k, v in media.items()})
1218
- # print("media_config", media_config)
1219
- # print("generation_config", generation_config)
1220
- # input("wait for debug")
1221
- # Generate the response
1222
- try:
1223
- output_ids = self.generate(
1224
- input_ids=input_ids,
1225
- media=media,
1226
- media_config=media_config,
1227
- generation_config=generation_config,
1228
- logits_processor=xgr_logits_processor, # structured generation
1229
- )
1230
- except ValueError:
1231
- if not generation_config.do_sample:
1232
- raise
1233
- # FIXME(zhijianl): This is a temporary workaround for the sampling issue
1234
- logging.warning("Generation failed with sampling, retrying with greedy decoding.")
1235
- generation_config.do_sample = False
1236
- output_ids = self.generate(
1237
- input_ids=input_ids,
1238
- media=media,
1239
- media_config=media_config,
1240
- generation_config=generation_config,
1241
- logits_processor=xgr_logits_processor,
1242
- )
1243
-
1244
- # Decode the response
1245
- response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
1246
- return response
1247
-
1248
- @property
1249
- def default_generation_config(self) -> GenerationConfig:
1250
- generation_config = copy.deepcopy(self.generation_config or GenerationConfig())
1251
- if self.tokenizer.eos_token_id is None:
1252
- raise ValueError("Tokenizer must have an EOS token")
1253
- if generation_config.max_length == GenerationConfig().max_length:
1254
- generation_config.max_length = self.tokenizer.model_max_length
1255
- if generation_config.pad_token_id is None:
1256
- generation_config.pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
1257
- if generation_config.bos_token_id is None:
1258
- generation_config.bos_token_id = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id
1259
- if generation_config.eos_token_id is None:
1260
- generation_config.eos_token_id = self.tokenizer.eos_token_id
1261
- return generation_config
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
config.json CHANGED
@@ -6,10 +6,10 @@
6
  ],
7
  "auto_map": {
8
  "AutoConfig": "configuration_vila.VILAConfig",
9
- "AutoModel": "modeling_vila.VILAForCausalLM",
10
- "AutoModelForCausalLM": "modeling_vila.VILAForCausalLM",
11
- "AutoModelForImageTextToText": "modeling_vila.VILAForConditionalGeneration",
12
- "AutoModelForVision2Seq": "modeling_vila.VILAForConditionalGeneration"
13
  },
14
  "chat_template": null,
15
  "drop_path_rate": 0.0,
 
6
  ],
7
  "auto_map": {
8
  "AutoConfig": "configuration_vila.VILAConfig",
9
+ "AutoModel": "modeling_vila_hf.VILAForConditionalGeneration",
10
+ "AutoModelForCausalLM": "modeling_vila_hf.VILAForConditionalGeneration",
11
+ "AutoModelForImageTextToText": "modeling_vila_hf.VILAForConditionalGeneration",
12
+ "AutoModelForVision2Seq": "modeling_vila_hf.VILAForConditionalGeneration"
13
  },
14
  "chat_template": null,
15
  "drop_path_rate": 0.0,
modeling_vila.py CHANGED
@@ -1,70 +1,415 @@
 
 
 
 
1
  import os
2
- from typing import Dict, Optional, Tuple, Type, Union, cast, override
 
 
 
 
 
 
 
 
 
3
 
4
  import torch
5
- import transformers.modeling_utils as modeling_utils
6
- from torch import Tensor
7
- from transformers.configuration_utils import PretrainedConfig
8
- from transformers.generation.utils import GenerationMixin
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  from transformers.modeling_outputs import CausalLMOutputWithPast
10
- from transformers.modeling_utils import PreTrainedModel
11
 
12
- from .auto_model import VILAForCausalLM
 
 
13
  from .configuration_vila import VILAConfig
 
 
 
 
 
 
 
 
 
 
 
14
 
 
15
 
16
- class VILAForConditionalGeneration(PreTrainedModel, GenerationMixin):
17
- config_class: Type[PretrainedConfig] = VILAConfig
18
- base_model_prefix: str = "vila"
19
- is_parallelizable: bool = True
20
- main_input_name: str = "input_ids"
21
 
22
- config: PretrainedConfig
23
 
24
- def __init__(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  self,
26
- config: PretrainedConfig,
27
- model: VILAForCausalLM,
28
- *args,
 
 
 
 
 
29
  **kwargs,
30
  ):
31
- super().__init__(config, *args, **kwargs)
 
 
 
 
 
32
 
33
- self.model = model
 
34
 
35
- def forward(
36
- self,
37
- *,
38
- attention_mask: Optional[Tensor] = None,
39
- input_ids: Optional[Tensor] = None,
40
- inputs_embeds: Optional[Tensor] = None,
41
- pixel_values: Optional[Tensor] = None,
42
- **kwargs,
43
- ) -> CausalLMOutputWithPast:
44
- if inputs_embeds is None:
45
- assert input_ids is not None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
- inputs_embeds, attention_mask = self._embed(
48
- input_ids, pixel_values, attention_mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  )
50
- else:
51
- assert input_ids is None
52
- assert pixel_values is None
 
 
 
 
 
 
53
 
54
- outputs = self.model.llm.forward(
55
- inputs_embeds=inputs_embeds,
56
- attention_mask=attention_mask,
57
- **kwargs,
58
- )
 
 
 
 
 
 
59
 
60
- return outputs
 
 
 
 
 
 
61
 
62
- @override
63
  @classmethod
64
- @modeling_utils.restore_default_torch_dtype
65
  def from_pretrained(
66
- cls: Type[modeling_utils.SpecificPreTrainedModelType],
67
- pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
68
  *model_args,
69
  config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
70
  cache_dir: Optional[Union[str, os.PathLike]] = None,
@@ -76,116 +421,841 @@ class VILAForConditionalGeneration(PreTrainedModel, GenerationMixin):
76
  use_safetensors: Optional[bool] = None,
77
  weights_only: bool = True,
78
  **kwargs,
79
- ) -> modeling_utils.SpecificPreTrainedModelType:
80
- state_dict = kwargs.pop("state_dict", None)
81
-
82
- if pretrained_model_name_or_path is not None:
83
- config = VILAConfig.from_pretrained(
84
- pretrained_model_name_or_path,
85
- cache_dir=cache_dir,
86
- force_download=force_download,
87
- local_files_only=local_files_only,
88
- revision=revision,
89
- use_safetensors=use_safetensors,
90
- **kwargs,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  )
 
 
 
 
 
 
 
92
  else:
93
- assert (
94
- config is not None and state_dict is not None
95
- ), "Both config and state_dict must be provided if pretrained_model_name_or_path is None."
96
-
97
- inner_model = VILAForCausalLM.from_pretrained(
98
- pretrained_model_name_or_path, # type: ignore
99
- *model_args,
100
- config=config,
101
- cache_dir=cache_dir,
102
- ignore_mismatched_sizes=ignore_mismatched_sizes,
103
- force_download=force_download,
104
- local_files_only=local_files_only,
105
- token=token,
106
- revision=revision,
107
- use_safetensors=use_safetensors,
108
- weights_only=weights_only,
109
- **kwargs,
110
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
- state_dict = inner_model.state_dict()
113
-
114
- # Prefix keys with "model.".
115
- state_dict = {f"model.{k}": v for k, v in state_dict.items()}
116
-
117
- return super().from_pretrained(
118
- None,
119
- inner_model,
120
- *model_args,
121
- config=config,
122
- cache_dir=cache_dir,
123
- ignore_mismatched_sizes=ignore_mismatched_sizes,
124
- force_download=force_download,
125
- local_files_only=local_files_only,
126
- token=token,
127
- revision=revision,
128
- state_dict=state_dict,
129
- use_safetensors=use_safetensors,
130
- weights_only=weights_only,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  **kwargs,
132
  )
133
 
134
- def _embed(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  self,
136
- input_ids: Tensor,
137
- pixel_values: Optional[Tensor],
138
- attention_mask: Optional[Tensor],
139
- ) -> Tuple[Tensor, Tensor]:
140
- """Gets the embedding of the input ids and pixel values.
141
-
142
- Args:
143
- input_ids: The input ids.
144
- pixel_values: The pixel values.
145
- attention_mask: The attention mask.
146
-
147
- Returns:
148
- A tuple of the embedding of the input ids and attention mask.
149
  """
 
 
 
 
 
 
 
150
 
151
- image_token_ids_map = cast(Dict[str, int], self.model.tokenizer.media_token_ids)
152
- image_token_ids = list(image_token_ids_map.values())
153
- image_token_idx = torch.isin(
154
- input_ids,
155
- torch.tensor(image_token_ids).to(input_ids.device),
156
- )
157
- image_token_count = image_token_idx.sum()
 
 
 
 
158
 
159
- images = list(pixel_values) if pixel_values is not None else []
160
 
161
- if image_token_count < len(images):
162
- images = images[:image_token_count]
 
 
 
 
 
 
 
163
 
164
- media = (
165
- {
166
- "image": images,
167
- }
168
- if image_token_count > 0
169
- else {}
170
- )
171
- media_config = (
172
- {
173
- "image": {},
174
- }
175
- if image_token_count > 0
176
- else {}
177
- )
178
 
179
- outputs = self.model._embed(
180
- input_ids,
181
- media,
182
- media_config,
183
- labels=None,
184
- attention_mask=(
185
- attention_mask[:, -input_ids.shape[1] :].to(dtype=torch.bool)
186
- if attention_mask is not None
187
- else None
188
- ),
189
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
 
191
- return outputs[0], outputs[2]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import json
3
+ import logging
4
+ import math
5
  import os
6
+ import os.path
7
+ import os.path as osp
8
+ import shutil
9
+ import warnings
10
+ from abc import ABC
11
+ from collections import OrderedDict, defaultdict, deque
12
+ from copy import deepcopy
13
+ from itertools import chain
14
+ from threading import Thread
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
 
17
  import torch
18
+ import torch.distributed as dist
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ import torchvision
22
+ from einops import rearrange
23
+ from PIL import Image
24
+ from transformers import (
25
+ AutoConfig,
26
+ AutoModel,
27
+ AutoProcessor,
28
+ AutoTokenizer,
29
+ GenerationConfig,
30
+ LogitsProcessor,
31
+ PretrainedConfig,
32
+ PreTrainedModel,
33
+ Qwen2Config,
34
+ Qwen2ForCausalLM,
35
+ Qwen2PreTrainedModel,
36
+ TextIteratorStreamer,
37
+ )
38
  from transformers.modeling_outputs import CausalLMOutputWithPast
39
+ from transformers.modeling_utils import ContextManagers, no_init_weights
40
 
41
+ from .auto_processor import VILAProcessor
42
+ from .base_projector import MultimodalProjector, MultimodalProjectorConfig
43
+ from .builder import build_llm_and_tokenizer
44
  from .configuration_vila import VILAConfig
45
+ from .constants import *
46
+ from .conversation import SeparatorStyle, default_conversation
47
+ from .distributed import all_gather as vila_all_gather
48
+ from .loss import soft_cross_entropy
49
+ from .media import extract_media
50
+ from .media_encoder import BasicImageEncoder, BasicVideoEncoder
51
+ from .mm_utils import process_image, process_images
52
+ from .model_utils_packing import set_seqlens_in_batch
53
+ from .siglip_encoder import SiglipVisionTower, SiglipVisionTowerDynamicS2, SiglipVisionTowerS2
54
+ from .tokenizer_utils import tokenize_conversation
55
+ from .utils import get_model_config, load_tokenizer_then_handle_media_tokens_and_chat_template
56
 
57
+ # from llava.constants import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, NUM_EXTRA_TOKENS
58
 
59
+ # ease debugging
60
+ python_input = input
 
 
 
61
 
 
62
 
63
+ # quick hack for remote code
64
+ def get_pg_manager():
65
+ return None
66
+
67
+
68
+ def get_model_weights_dtype(model: nn.Module):
69
+ pass
70
+
71
+
72
+ def build_mm_projector(model_type_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
73
+ if model_type_or_path is None:
74
+ return None
75
+ ## load from pretrained model
76
+ if config.resume_path:
77
+ assert os.path.exists(model_type_or_path), f"Resume mm projector path {model_type_or_path} does not exist!"
78
+ return MultimodalProjector.from_pretrained(model_type_or_path, config)
79
+ ## build from scratch
80
+ else:
81
+ mm_projector_cfg = MultimodalProjectorConfig(model_type_or_path)
82
+ mm_projector = MultimodalProjector(mm_projector_cfg, config)
83
+ return mm_projector
84
+
85
+
86
+ def check_dot_in_model_path(model_path: str):
87
+ """Check if the model path contains dot, which will affect the remote code loading."""
88
+ if osp.isdir(model_path): # local model
89
+ if "." in osp.abspath(model_path):
90
+ return True
91
+ else: # remote model
92
+ if "." in model_path:
93
+ return True
94
+ return False
95
+
96
+
97
+ def get_vila_version(model_path: str) -> str:
98
+ VERSIONS = ["vila1.5", "vila-u", "longvila", "nvila", "vila-m3"]
99
+ for version in VERSIONS:
100
+ if version in model_path.lower():
101
+ return version
102
+ return None
103
+
104
+
105
+ def generate_jinja_template(conv_mode: str) -> str:
106
+ if conv_mode == "vicuna_v1":
107
+ return """{% set system_prompt = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. " %}
108
+ {% set roles = ["user", "assistant"] %}
109
+ {% set sep = " " %}
110
+
111
+ {{ system_prompt }}
112
+
113
+ {% for message in messages %}
114
+ {% if message['role'] == roles[0] %}
115
+ {{ "USER: " }}{{ sep }}{{ message['content'] }}{{ sep }}
116
+ {% else %}
117
+ {{ "ASSISTANT: " }}{{ sep }}{{ message['content'] }}{{ sep }}
118
+ {% endif %}
119
+ {% endfor %}
120
+ {% if messages[-1]['role'] == 'user' %}
121
+ {{ "ASSISTANT:" }}
122
+ {% endif %}
123
+ """
124
+ elif conv_mode == "llama_3":
125
+ return """{% set system_prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n\\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.<|eot_id|>" %}
126
+ {% set roles = ["<|start_header_id|>user<|end_header_id|>\\n\\n", "<|start_header_id|>assistant<|end_header_id|>\\n\\n"]%}
127
+ {% set sep = "<|eot_id|>" %}
128
+
129
+ {{ system_prompt }}
130
+ {% for message in messages %}
131
+ {% if message['role'] == 'user' %}
132
+ {{ roles[0] }}{{ message['content'] }}{{ sep }}
133
+ {% else %}
134
+ {{ roles[1] }}{{ message['content'] }}{{ sep }}
135
+ {% endif %}
136
+ {% endfor %}
137
+ {% if messages[-1]['role'] == 'user' %}
138
+ {{ roles[1] }}
139
+ {% endif %}
140
+ """
141
+ elif conv_mode == "hermes_2":
142
+ return """{% set system_prompt = "<|im_start|>system\nAnswer the questions." %}
143
+ {% set roles = ["<|im_start|>user\n", "<|im_start|>assistant\n"] %}
144
+ {% set sep = "<|im_end|>" %}
145
+
146
+ {{ system_prompt }}{{ sep }}
147
+
148
+ {% for message in messages %}
149
+ {% if message['role'] == 'user' %}
150
+ {{ roles[0] }}{{ message['content'] }}{{ sep }}
151
+ {% else %}
152
+ {{ roles[1] }}{{ message['content'] }}{{ sep }}
153
+ {% endif %}
154
+ {% endfor %}"""
155
+ else:
156
+ raise NotImplementedError(f"Jinja template generation is not implemented for {conv_mode}.")
157
+
158
+
159
+ def build_vision_tower(model_name_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
160
+ ## skip vision tower instantiation
161
+ if model_name_or_path is None:
162
+ return None
163
+
164
+ vision_tower_arch = None
165
+ if config.resume_path and "radio" not in model_name_or_path:
166
+ assert os.path.exists(model_name_or_path), f"Resume vision tower path {model_name_or_path} does not exist!"
167
+ vision_tower_cfg = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
168
+ vision_tower_arch = vision_tower_cfg.architectures[0].lower()
169
+ vision_tower_name = vision_tower_arch if vision_tower_arch is not None else model_name_or_path
170
+
171
+ use_s2 = getattr(config, "s2", False)
172
+ use_dynamic_s2 = getattr(config, "dynamic_s2", False)
173
+
174
+ if "siglip" in vision_tower_name:
175
+ if use_dynamic_s2:
176
+ vision_tower = SiglipVisionTowerDynamicS2(model_name_or_path, config)
177
+ elif use_s2:
178
+ vision_tower = SiglipVisionTowerS2(model_name_or_path, config)
179
+ else:
180
+ vision_tower = SiglipVisionTower(model_name_or_path, config)
181
+ else:
182
+ raise NotImplementedError(f"Unknown vision tower: {model_name_or_path}")
183
+
184
+ config.mm_hidden_size = (
185
+ vision_tower.config.hidden_size if not (use_s2 or use_dynamic_s2) else vision_tower.hidden_size
186
+ )
187
+ return vision_tower
188
+
189
+
190
+ class VILAPretrainedModel(PreTrainedModel):
191
+ config_class = VILAConfig
192
+ main_input_name = "input_embeds"
193
+ supports_gradient_checkpointing = True
194
+ _supports_flash_attn_2 = True
195
+ _no_split_modules = ["Qwen2DecoderLayer", "SiglipEncoderLayer"]
196
+
197
+ def __init__(self, config: VILAConfig, *args, **kwargs):
198
+ super().__init__(config)
199
+ self.config = config
200
+ cfgs = get_model_config(config)
201
+ if len(cfgs) == 3:
202
+ llm_cfg, vision_tower_cfg, mm_projector_cfg = cfgs
203
+ else:
204
+ raise ValueError("`llm_cfg` `mm_projector_cfg` `vision_tower_cfg` not found in the config.")
205
+
206
+ # loading on auto by default
207
+ device_map = kwargs.get("device_map", "auto")
208
+ self.mm_projector = build_mm_projector(mm_projector_cfg, config)
209
+ self.vision_tower = build_vision_tower(vision_tower_cfg, config)
210
+ if device_map in ["auto", "cuda"]:
211
+ self.mm_projector = self.mm_projector.cuda()
212
+ self.vision_tower = self.vision_tower.cuda()
213
+ # set device_map auto can autoamtically shard llm to different devices
214
+ self.llm, self.tokenizer = self.init_llm(llm_cfg, config, device_map=device_map)
215
+
216
+ # NOTE(ligeng): hard code to set padding_side to left
217
+ self.tokenizer.padding_side = "left"
218
+ # TODO(ligeng): need to add other decoders from config
219
+ self.encoders = {"image": BasicImageEncoder(self), "video": BasicVideoEncoder(self)}
220
+
221
+ self.post_config()
222
+ self.is_loaded = True
223
+
224
+ assert (
225
+ self.llm is not None or self.vision_tower is not None or self.mm_projector is not None
226
+ ), "At least one of the components must be instantiated."
227
+
228
+ @classmethod
229
+ def convert_vila_dev_ckpt_to_remote(
230
  self,
231
+ model_path: str,
232
+ output_dir: str = None,
233
+ vila_version: str | None = None,
234
+ conv_mode: str | None = None,
235
+ copy: bool = False,
236
+ copy_weights: bool = True,
237
+ copy_code: bool = True,
238
+ *model_args,
239
  **kwargs,
240
  ):
241
+ # assert type(self) == VILAForCasualLM, "This method is only available for VILAForCasualLM."
242
+ assert model_path != output_dir, "model_path and output_dir cannot be the same"
243
+ if os.path.isdir(model_path):
244
+ model_path = model_path
245
+ else:
246
+ from huggingface_hub import HfApi, snapshot_download
247
 
248
+ model_path = snapshot_download(model_path)
249
+ print("downloading HF model to", model_path)
250
 
251
+ if check_dot_in_model_path(model_path) and output_dir is None:
252
+ raise ValueError(
253
+ f"Model path {model_path} contains a dot, which will affect the remote code loading. Please specify the output directory without dot in the path to fix this issue."
254
+ )
255
+ if output_dir is not None and "." in output_dir:
256
+ raise ValueError(
257
+ f"Output directory {output_dir} contains a dot, which will affect the remote code loading. Please specify a valid output directory without dots."
258
+ )
259
+
260
+ if copy:
261
+ print("copy is set to True, copying weights and code to output_dir")
262
+ copy_weights = copy_code = True
263
+ # copy weights and code to output_dir
264
+ self.copy_or_symlink_directory(model_path, output_dir, copy=copy_weights)
265
+ self.copy_remote_py_files(output_dir, copy=copy_code)
266
+
267
+ if vila_version is None:
268
+ vila_version = get_vila_version(output_dir)
269
+
270
+ cfg_path = os.path.join(output_dir, "config.json")
271
+ config = json.load(open(cfg_path))
272
+ config["version"] = "2.0" # nvila tag
273
+ config["architectures"] = ["VILAForCausalLM"]
274
+ config["auto_map"] = {
275
+ "AutoProcessor": "auto_processor.VILAProcessor",
276
+ "AutoConfig": "modeling_vila.VILAConfig",
277
+ "AutoModel": "modeling_vila.VILAForCausalLM",
278
+ "AutoModelForCausalLM": "modeling_vila.VILAForCausalLM",
279
+ }
280
+ # vila1.5 legacy support
281
+ config["model_type"] = "vila"
282
+ if vila_version in ["vila1.5", "vila-m3"]:
283
+ if conv_mode is None:
284
+ raise ValueError(f"Please specify the conversation mode for {output_dir}.")
285
+ config["chat_template"] = conv_mode
286
+ jinja_template = generate_jinja_template(conv_mode)
287
+ jinja_path = os.path.join(output_dir, f"{conv_mode}.jinja")
288
+ with open(jinja_path, "w") as f:
289
+ f.write(jinja_template)
290
+ json.dump(config, open(cfg_path, "w"), indent=2)
291
+
292
+ ##########################################################################################
293
+ config = AutoConfig.from_pretrained(output_dir, trust_remote_code=True)
294
+ tokenizer = load_tokenizer_then_handle_media_tokens_and_chat_template(output_dir, config)
295
+ tokenizer.save_pretrained(osp.join(output_dir, "llm"))
296
+ ##########################################################################################
297
+
298
+ @classmethod
299
+ def copy_or_symlink_directory(cls, model_path, output_dir, copy=True):
300
+ # Create output directory if it doesn't exist
301
+ os.makedirs(output_dir, exist_ok=True)
302
+ # Create symlinks for all files in model_path to output_dir
303
+ for item in os.listdir(model_path):
304
+ src_path = os.path.join(model_path, item)
305
+ dst_path = os.path.join(output_dir, item)
306
+
307
+ # Remove existing file/directory at destination if it exists
308
+ if os.path.exists(dst_path):
309
+ if os.path.islink(dst_path):
310
+ os.unlink(dst_path)
311
+ elif os.path.isdir(dst_path):
312
+ shutil.rmtree(dst_path)
313
+ else:
314
+ os.remove(dst_path)
315
 
316
+ # Create symlink
317
+ if copy:
318
+ if os.path.isdir(src_path):
319
+ shutil.copytree(src_path, dst_path)
320
+ else:
321
+ shutil.copy2(src_path, dst_path)
322
+ print(f"Copied {src_path} to {dst_path}")
323
+ else:
324
+ os.symlink(src_path, dst_path)
325
+ print(f"Created symlink from {src_path} to {dst_path}")
326
+
327
+ @classmethod
328
+ def copy_remote_py_files(cls, output_dir, copy=True):
329
+ ## copy .py and REAMDE for next loading remote code
330
+ current_file_path = os.path.abspath(__file__)
331
+ current_folder = os.path.dirname(current_file_path)
332
+ for file_name in os.listdir(current_folder):
333
+ if file_name == "INSTRUCTIONS.md":
334
+ src_fname = os.path.join(current_folder, file_name)
335
+ dst_fname = os.path.join(output_dir, "README.md")
336
+ if os.path.exists(dst_fname):
337
+ old_reamde = open(dst_fname).read()
338
+ else:
339
+ old_reamde = ""
340
+ with open(src_fname) as src, open(dst_fname, "w") as dst:
341
+ dst.write(src.read())
342
+ dst.write(old_reamde)
343
+ print("[HF remote code] REAMDE ", src_fname, "to", dst_fname)
344
+ if file_name.endswith(".py") or file_name.endswith(".jinja"):
345
+ full_file_name = os.path.join(current_folder, file_name)
346
+ if os.path.isfile(full_file_name):
347
+ if copy:
348
+ shutil.copy(full_file_name, output_dir)
349
+ print("[HF remote code] copying", full_file_name, "to", output_dir)
350
+ else:
351
+ # symlink to ease development
352
+ if os.path.exists(os.path.join(output_dir, file_name)):
353
+ os.remove(os.path.join(output_dir, file_name))
354
+ os.symlink(full_file_name, os.path.join(output_dir, file_name))
355
+ print("[HF remote code] linking", full_file_name, "to", output_dir)
356
+
357
+ def save_pretrained(self, output_dir, state_dict=None, **kwargs):
358
+ if state_dict is None:
359
+ # other wise fetch from deepspeed
360
+ # state_dict = accelerator.get_state_dict(is_deepspeed_enabled)
361
+ state_dict = self.state_dict()
362
+
363
+ if getattr(self, "tokenizer", None):
364
+ self.tokenizer.save_pretrained(osp.join(output_dir, "llm"))
365
+
366
+ if self.get_llm():
367
+ print(f"saving llm to {osp.join(output_dir, 'llm')}")
368
+ self.llm.config._name_or_path = osp.join(output_dir, "llm")
369
+ llm_state_dict = OrderedDict({k.split("llm.")[-1]: v for k, v in state_dict.items() if "llm" in k})
370
+ self.llm.save_pretrained(os.path.join(output_dir, "llm"), state_dict=llm_state_dict)
371
+ self.config.llm_cfg = self.llm.config
372
+
373
+ if self.get_vision_tower():
374
+ print(f"saving vision_tower to {osp.join(output_dir, 'vision_tower')}")
375
+ self.vision_tower.config._name_or_path = osp.join(output_dir, "vision_tower")
376
+ vision_tower_state_dict = OrderedDict(
377
+ {k.split("vision_tower.vision_tower.")[-1]: v for k, v in state_dict.items() if "vision_tower" in k}
378
  )
379
+ self.vision_tower.vision_tower.save_pretrained(
380
+ os.path.join(output_dir, "vision_tower"),
381
+ state_dict=vision_tower_state_dict,
382
+ )
383
+ self.vision_tower.image_processor.save_pretrained(os.path.join(output_dir, "vision_tower"))
384
+ self.config.vision_tower_cfg = self.vision_tower.config
385
+ if hasattr(self.config.vision_tower_cfg, "auto_map"):
386
+ if "radio" not in self.get_vision_tower().__class__.__name__.lower():
387
+ delattr(self.config.vision_tower_cfg, "auto_map")
388
 
389
+ if self.get_mm_projector():
390
+ print(f"saving mm_projector to {osp.join(output_dir, 'mm_projector')}")
391
+ self.mm_projector.config._name_or_path = osp.join(output_dir, "mm_projector")
392
+ mm_projector_state_dict = OrderedDict(
393
+ {k.split("mm_projector.")[-1]: v for k, v in state_dict.items() if "mm_projector" in k}
394
+ )
395
+ self.mm_projector.save_pretrained(
396
+ os.path.join(output_dir, "mm_projector"),
397
+ state_dict=mm_projector_state_dict,
398
+ )
399
+ self.config.mm_projector_cfg = self.mm_projector.config
400
 
401
+ ## update and save top-level config
402
+ self.config._name_or_path = output_dir
403
+ self.config.architectures = [self.__class__.__name__]
404
+ self.config.save_pretrained(output_dir)
405
+
406
+ ## copy .py and REAMDE for next loading remote code
407
+ self.copy_remote_py_files(output_dir)
408
 
 
409
  @classmethod
 
410
  def from_pretrained(
411
+ cls,
412
+ pretrained_model_name_or_path: Optional[str] = None,
413
  *model_args,
414
  config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
415
  cache_dir: Optional[Union[str, os.PathLike]] = None,
 
421
  use_safetensors: Optional[bool] = None,
422
  weights_only: bool = True,
423
  **kwargs,
424
+ ):
425
+ # print("DEBUG2", kwargs); input()
426
+ config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
427
+ return cls._from_config(config, **kwargs)
428
+
429
+ def init_llm(self, llm_config, config, *args, **kwargs):
430
+ self.llm, self.tokenizer = build_llm_and_tokenizer(llm_config, config, *args, **kwargs)
431
+ # hard coded for NVILA
432
+ # variables for XGrammar
433
+ # print("DEBUG", len(self.tokenizer.added_tokens_encoder.keys()), self.tokenizer.added_tokens_encoder.keys())
434
+ NUM_EXTRA_TOKENS = len(self.tokenizer.added_tokens_encoder.keys())
435
+
436
+ self.pad_token_list = (
437
+ self.tokenizer.pad_token_id,
438
+ self.tokenizer.eos_token_id,
439
+ self.tokenizer.tokenize("<|endoftext|>")[0], # for qwen
440
+ )
441
+
442
+ # TODO: SENTINEL_TOKEN is not added, need to check with Zhijian
443
+ self.vocab_size = self.tokenizer.vocab_size + NUM_EXTRA_TOKENS
444
+ # XGrammar tokenizer and grammar compiler
445
+ # lazy init only when specified json output during inference
446
+ self.grammar_compiler = None
447
+ self.llm.resize_token_embeddings(len(self.tokenizer))
448
+ return self.llm, self.tokenizer
449
+
450
+ def post_config(self):
451
+ ######################################################################
452
+ # TODO: need to check dtype with jason
453
+ self.llm = self.llm.to(torch.float16)
454
+ self.mm_projector = self.mm_projector.to(torch.float16)
455
+ self.vision_tower = self.vision_tower.to(torch.float16)
456
+ ######################################################################
457
+ self.training = self.llm.training
458
+ if self.training:
459
+ self.train()
460
+ else:
461
+ self.eval()
462
+ ## configuration
463
+ if getattr(self.config, "llm_cfg", None) is None:
464
+ self.config.llm_cfg = self.llm.config
465
+ if getattr(self.config, "vision_tower_cfg", None) is None:
466
+ self.config.vision_tower_cfg = self.vision_tower.config
467
+ if getattr(self.config, "mm_projector_cfg", None) is None:
468
+ self.config.mm_projector_cfg = self.mm_projector.config
469
+
470
+ def get_llm(self):
471
+ llm = getattr(self, "llm", None)
472
+ if type(llm) is list:
473
+ llm = llm[0]
474
+ return llm
475
+
476
+ def get_lm_head(self):
477
+ lm_head = getattr(self.get_llm(), "lm_head", None)
478
+ return lm_head
479
+
480
+ def get_vision_tower(self):
481
+ vision_tower = getattr(self, "vision_tower", None)
482
+ if type(vision_tower) is list:
483
+ vision_tower = vision_tower[0]
484
+ return vision_tower
485
+
486
+ def get_mm_projector(self):
487
+ mm_projector = getattr(self, "mm_projector", None)
488
+ if type(mm_projector) is list:
489
+ mm_projector = mm_projector[0]
490
+ return mm_projector
491
+
492
+ def freezed_module_patch(self):
493
+ """
494
+ Huggingface will call model.train() at each training_step. To ensure the expected behaviors for modules like dropout, batchnorm, etc., we need to call model.eval() for the freezed modules.
495
+ """
496
+ if self.training:
497
+ if self.get_llm() and not getattr(self.config, "tune_language_model", False):
498
+ pass
499
+ # logging.warning("Caution: Your LLM is currently in training mode, ensuring accurate gradient computation. Please be vigilant, particularly regarding BatchNorm and Dropout operations.")
500
+ if self.get_vision_tower() and not getattr(self.config, "tune_vision_tower", False):
501
+ self.get_vision_tower().eval()
502
+ if self.get_mm_projector() and not getattr(self.config, "tune_mm_projector", False):
503
+ self.get_mm_projector().eval()
504
+
505
+
506
+ class VILAForCausalLM(VILAPretrainedModel):
507
+ def __init__(self, config: VILAConfig, *args, **kwargs):
508
+ super().__init__(config, *args, **kwargs)
509
+
510
+ def merge_features_for_dynamic_s2(self, image_features, block_sizes):
511
+ scales = self.get_vision_tower().scales
512
+ resize_output_to_scale_idx = self.get_vision_tower().resize_output_to_scale_idx
513
+
514
+ image_features_each_image = []
515
+ new_block_sizes = []
516
+ block_cnt = 0
517
+ for block_size_each_image in block_sizes:
518
+ if block_size_each_image is None:
519
+ cur_features = image_features[block_cnt : block_cnt + 1]
520
+ cur_features = rearrange(cur_features, "1 (h w) c -> 1 c h w", h=int(cur_features.shape[1] ** 0.5))
521
+ cur_features = cur_features.repeat(1, len(scales), 1, 1)
522
+ image_features_each_image.append(cur_features)
523
+ new_block_sizes.append((1, 1))
524
+ block_cnt += 1
525
+ else:
526
+ cur_features_each_scale = []
527
+ for scale in scales[:-1]:
528
+ num_blocks_this_scale = (scale // scales[0]) ** 2
529
+ cur_features_each_scale.append(
530
+ self.merge_chessboard(
531
+ image_features[block_cnt : block_cnt + num_blocks_this_scale],
532
+ num_split_h=scale // scales[0],
533
+ num_split_w=scale // scales[0],
534
+ )
535
+ ) # 1 * C * H * W
536
+ block_cnt += num_blocks_this_scale
537
+ num_blocks_last_scale = block_size_each_image[0] * block_size_each_image[1]
538
+ cur_features_each_scale.append(
539
+ self.merge_chessboard(
540
+ image_features[block_cnt : block_cnt + num_blocks_last_scale],
541
+ num_split_h=block_size_each_image[0],
542
+ num_split_w=block_size_each_image[1],
543
+ )
544
+ ) # 1 * C * H * W
545
+ block_cnt += num_blocks_last_scale
546
+
547
+ # resize and concat features from different scales
548
+ output_size = cur_features_each_scale[resize_output_to_scale_idx].shape[-2:]
549
+ cur_features = torch.cat(
550
+ [
551
+ F.interpolate(cur_features_each_scale[i].to(torch.float32), size=output_size, mode="area").to(
552
+ cur_features_each_scale[i].dtype
553
+ )
554
+ for i in range(len(cur_features_each_scale))
555
+ ],
556
+ dim=1,
557
+ )
558
+ # cur_features = rearrange(cur_features, "1 c h w -> (h w) c")
559
+
560
+ image_features_each_image.append(cur_features)
561
+
562
+ if resize_output_to_scale_idx == len(scales) - 1 or resize_output_to_scale_idx == -1:
563
+ new_block_sizes.append(block_size_each_image)
564
+ else:
565
+ new_block_sizes.append(
566
+ (
567
+ scales[resize_output_to_scale_idx] // scales[0],
568
+ scales[resize_output_to_scale_idx] // scales[0],
569
+ )
570
+ )
571
+
572
+ assert block_cnt == len(image_features)
573
+
574
+ return image_features_each_image, new_block_sizes
575
+
576
+ def encode_images(self, images, block_sizes: Optional[Optional[Tuple[int, ...]]] = None):
577
+ if block_sizes is None:
578
+ block_sizes = [None] * len(images)
579
+ if getattr(self.config, "dynamic_s2", False):
580
+ image_features = self.get_vision_tower()(images)
581
+ image_features, new_block_sizes = self.merge_features_for_dynamic_s2(image_features, block_sizes)
582
+
583
+ image_features = [
584
+ self.split_chessboard(x, block_size[0], block_size[1])
585
+ for x, block_size in zip(image_features, new_block_sizes)
586
+ ] # list of B * C * H * W tensors
587
+ image_features = torch.cat(
588
+ [rearrange(x, "b c h w -> b (h w) c") for x in image_features], dim=0
589
+ ) # B * N * C
590
+ image_features = self.get_mm_projector()(image_features)
591
+ image_features = list(
592
+ image_features.split([block_size[0] * block_size[1] for block_size in new_block_sizes], dim=0)
593
  )
594
+ image_features = [
595
+ self.merge_chessboard(x, block_size[0], block_size[1])
596
+ for x, block_size in zip(image_features, new_block_sizes)
597
+ ] # list of 1 * C * H * W tensors
598
+ image_features = [rearrange(x, "1 c h w -> (h w) c") for x in image_features] # list of N * C tensors
599
+ if all([feature.shape[0] == image_features[0].shape[0] for feature in image_features]):
600
+ image_features = torch.stack(image_features, dim=0)
601
  else:
602
+ image_features = self.get_vision_tower()(images)
603
+ image_features = self.get_mm_projector()(image_features)
604
+ return image_features
605
+
606
+ def train(self, mode: bool = True):
607
+ super().train(mode)
608
+ return self
609
+
610
+ def _embed(
611
+ self,
612
+ input_ids: torch.Tensor,
613
+ media: Dict[str, List[torch.Tensor]],
614
+ media_config: Dict[str, Dict[str, Any]],
615
+ labels: Optional[torch.Tensor],
616
+ attention_mask: Optional[torch.Tensor],
617
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
618
+ # NOTE(ligeng): deep copy to avoid modifying the original media and media_config
619
+ media = copy.deepcopy(media)
620
+ media_config = copy.deepcopy(media_config)
621
+
622
+ labels = labels if labels is not None else torch.full_like(input_ids, IGNORE_INDEX)
623
+ attention_mask = attention_mask.to(dtype=torch.bool) if attention_mask is not None else torch.ones_like(input_ids, dtype=torch.bool)
624
+
625
+ PROCESS_GROUP_MANAGER = get_pg_manager()
626
+ if PROCESS_GROUP_MANAGER is not None:
627
+ for name in media:
628
+ self.encoders[name].end_tokens = None
629
+
630
+ # Extract text and media embeddings
631
+ text_embeds = self.llm.model.embed_tokens(input_ids)
632
+ if media is not None:
633
+ media_embeds = self.__embed_media_tokens(media, media_config)
634
+ else:
635
+ # no media was provided, so we just return an empty dict
636
+ media_embeds = {}
637
+
638
+ # This is a workaround to make sure the dummy embeddings are consumed
639
+ while media_embeds.get("dummy"):
640
+ dummy_embed = media_embeds["dummy"].popleft()
641
+ text_embeds += torch.sum(dummy_embed) * 0
642
+
643
+ # Remove padding
644
+ batch_size = labels.shape[0]
645
+ text_embeds = [text_embeds[k][attention_mask[k]] for k in range(batch_size)]
646
+ labels = [labels[k][attention_mask[k]] for k in range(batch_size)]
647
+ # zijzhang: also apply to input_ids
648
+ input_ids = [input_ids[k][attention_mask[k]] for k in range(batch_size)]
649
+
650
+ # Build inverse mapping from token ID to media name
651
+ media_tokens = {}
652
+ for name, token_id in self.tokenizer.media_token_ids.items():
653
+ media_tokens[token_id] = name
654
+
655
+ # Fuse text and media embeddings
656
+ inputs_m, labels_m = [], []
657
+ for k in range(batch_size):
658
+ inputs_mk, labels_mk = [], []
659
+ pos = 0
660
+ while pos < len(labels[k]):
661
+ if input_ids[k][pos].item() in media_tokens:
662
+ end = pos + 1
663
+ name = media_tokens[input_ids[k][pos].item()]
664
+ input = media_embeds[name].popleft()
665
+ label = torch.full([input.shape[0]], IGNORE_INDEX, device=labels[k].device, dtype=labels[k].dtype)
666
+ elif input_ids[k][pos].item() in self.pad_token_list:
667
+ # skip pad tokens
668
+ end = pos + 1
669
+ pos = end
670
+ continue
671
+ else:
672
+ end = pos
673
+ while end < len(labels[k]) and input_ids[k][end].item() not in media_tokens:
674
+ end += 1
675
+ input = text_embeds[k][pos:end]
676
+ label = labels[k][pos:end]
677
+
678
+ inputs_mk.append(input)
679
+ labels_mk.append(label)
680
+ pos = end
681
+ inputs_m.append(torch.cat(inputs_mk, dim=0))
682
+ labels_m.append(torch.cat(labels_mk, dim=0))
683
+ inputs, labels = inputs_m, labels_m
684
+
685
+ # Check if all media embeddings are consumed
686
+ for name in media_embeds:
687
+ if media_embeds[name]:
688
+ raise ValueError(f"Not all {name} embeddings are consumed! Still {len(media_embeds[name])} left.")
689
+
690
+ # Truncate sequences to `model_max_length` as media embeddings are inserted
691
+ inputs, labels = self.__truncate_sequence(inputs, labels)
692
+
693
+ # Pad sequences to the longest one in the batch
694
+ return self.__batchify_sequence(inputs, labels)
695
+
696
+ def __embed_media_tokens(
697
+ self,
698
+ media: Dict[str, List[torch.Tensor]],
699
+ media_config: Dict[str, Dict[str, Any]],
700
+ ) -> Dict[str, List[torch.Tensor]]:
701
+ embeds = defaultdict(deque)
702
+ for name in media:
703
+ if self.training:
704
+ # Gather metainfo of media objects from all ranks
705
+ info = [{"shape": tensor.shape, "dtype": tensor.dtype} for tensor in media.get(name, [])]
706
+ infos = list(chain(vila_all_gather(info)))
707
+
708
+ # The entire batch does not contain any media objects of this type.
709
+ if not infos:
710
+ continue
711
+
712
+ # Create a dummy tensor to ensure the encoder is called, otherwise the training will hang.
713
+ if media.get(name) is None or len(media[name]) == 0:
714
+ dummy = torch.zeros(infos[0]["shape"], dtype=infos[0]["dtype"], device=self.device)
715
+ embeds["dummy"].extend(self.encoders[name]([dummy], media_config[name]))
716
+ continue
717
+ embeds[name] = deque(self.encoders[name](media[name], media_config[name]))
718
+ return embeds
719
+
720
+ def __truncate_sequence(
721
+ self, inputs: List[torch.Tensor], labels: List[torch.Tensor]
722
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
723
+ if self.training and any(len(input) > self.tokenizer.model_max_length for input in inputs):
724
+ warnings.warn(f"Truncating sequences to `model_max_length` ({self.tokenizer.model_max_length}).")
725
+ inputs = [input[: self.tokenizer.model_max_length] for input in inputs]
726
+ labels = [label[: self.tokenizer.model_max_length] for label in labels]
727
+ return inputs, labels
728
+
729
+ def __batchify_sequence(
730
+ self, inputs: List[torch.Tensor], labels: List[torch.Tensor]
731
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
732
+ batch_size = len(inputs)
733
+ device = inputs[0].device
734
+ hidden_size = inputs[0].shape[1]
735
+ max_length = max(inputs[k].shape[0] for k in range(batch_size))
736
+ attention_mask = torch.ones((batch_size, max_length), dtype=torch.bool, device=device)
737
+
738
+ inputs_p, labels_p = [], []
739
+ for k in range(batch_size):
740
+ size_pk = max_length - inputs[k].shape[0]
741
+ inputs_pk = torch.zeros((size_pk, hidden_size), dtype=inputs[k].dtype, device=device)
742
+ labels_pk = torch.full((size_pk,), IGNORE_INDEX, dtype=labels[k].dtype, device=device)
743
+ if self.tokenizer.padding_side == "right":
744
+ attention_mask[k, inputs[k].shape[0] :] = False
745
+ inputs_pk = torch.cat([inputs[k], inputs_pk], dim=0)
746
+ labels_pk = torch.cat([labels[k], labels_pk], dim=0)
747
+ else:
748
+ attention_mask[k, : -inputs[k].shape[0]] = False
749
+ inputs_pk = torch.cat([inputs_pk, inputs[k]], dim=0)
750
+ labels_pk = torch.cat([labels_pk, labels[k]], dim=0)
751
+ inputs_p.append(inputs_pk)
752
+ labels_p.append(labels_pk)
753
+
754
+ inputs = torch.stack(inputs_p, dim=0)
755
+ labels = torch.stack(labels_p, dim=0)
756
+ return inputs, labels, attention_mask
757
+
758
    def repack_multimodal_data(self, inputs_embeds, attention_mask, position_ids, labels):
        """Reshard (under sequence parallelism) or pack the embedded batch.

        With a process-group manager present, all ranks gather the full
        sequences and re-shard them so every rank holds an equally sized
        slice. Without one, all sequences in the batch are concatenated into a
        single packed row (batch size 1) for packed attention.
        """
        # Handle sequence parallelism
        PROCESS_GROUP_MANAGER = get_pg_manager()

        # We do re-sharding instead of packing here to ensure the sequence length is the same across all ranks.
        if PROCESS_GROUP_MANAGER is not None:
            sp_degree = PROCESS_GROUP_MANAGER.sp_degree
            sp_rank = PROCESS_GROUP_MANAGER.sp_rank
            sp_group = PROCESS_GROUP_MANAGER.sp_pg
            ring_degree = PROCESS_GROUP_MANAGER.ring_degree
            ring_rank = PROCESS_GROUP_MANAGER.ring_rank
            ring_type = PROCESS_GROUP_MANAGER.ring_type
            ulysses_degree = PROCESS_GROUP_MANAGER.ulysses_degree
            ulysses_rank = PROCESS_GROUP_MANAGER.ulysses_rank

            # Each rank may hold a different shard length; gather them all.
            bs, shard_seqlen = position_ids.shape
            sp_seq_len = [torch.zeros(1, dtype=torch.int64, device=position_ids.device) for _ in range(sp_degree)]
            dist.all_gather(sp_seq_len, torch.tensor(shard_seqlen, device=position_ids.device), group=sp_group)
            sp_seq_len_cat = torch.cat(sp_seq_len, dim=0)

            # This rank's slice [original_start_id, original_end_id) within the
            # globally concatenated sequence.
            if sp_rank == 0:
                original_start_id = 0
            else:
                original_start_id = torch.sum(sp_seq_len_cat[:sp_rank]).item()
            original_end_id = torch.sum(sp_seq_len_cat[: sp_rank + 1]).item()

            # Gather attention_mask, position_ids, labels and input_embeds
            # (embeddings are summed via all_reduce since each rank writes a
            # disjoint slice of the zero-initialized buffer).
            all_inputs_embeds = torch.zeros(
                bs,
                torch.sum(sp_seq_len_cat),
                inputs_embeds.shape[-1],
                dtype=inputs_embeds.dtype,
                device=inputs_embeds.device,
            ).contiguous()
            all_inputs_embeds[:, original_start_id:original_end_id, :] += inputs_embeds
            dist.barrier(group=sp_group)
            dist.all_reduce(all_inputs_embeds, group=sp_group)
            dist.barrier(group=sp_group)

            attention_mask_list = [
                torch.zeros((bs, sp_seq_len[i]), dtype=attention_mask.dtype, device=attention_mask.device)
                for i in range(sp_degree)
            ]
            position_ids_list = [
                torch.zeros((bs, sp_seq_len[i]), dtype=position_ids.dtype, device=position_ids.device)
                for i in range(sp_degree)
            ]
            labels_list = [
                torch.zeros((bs, sp_seq_len[i]), dtype=labels.dtype, device=labels.device) for i in range(sp_degree)
            ]

            dist.all_gather(attention_mask_list, attention_mask, group=sp_group)
            dist.all_gather(position_ids_list, position_ids, group=sp_group)
            dist.all_gather(labels_list, labels, group=sp_group)

            # Effective (non-padded) length per rank per batch element.
            effective_seqlen_list = [attention_mask_list[i].sum(dim=-1) for i in range(sp_degree)]
            effective_seqlen = torch.stack(effective_seqlen_list, dim=-1)
            effective_seqlen_batch_list = torch.unbind(effective_seqlen, dim=0)

            # Stitch each batch element's shards back together, dropping padding.
            global_attention_mask_list = []
            global_position_ids_list = []
            global_labels_list = []
            global_inputs_embeds_list = []
            for i in range(bs):
                global_attention_mask_batch_list = []
                global_position_ids_batch_list = []
                global_labels_batch_list = []
                global_inputs_embeds_batch_list = []
                for j in range(sp_degree):
                    eff_len = effective_seqlen_batch_list[i][j]
                    prev_len = torch.sum(sp_seq_len_cat[:j]).item() if j > 0 else 0

                    global_attention_mask_batch_list.append(attention_mask_list[j][i, :eff_len])
                    global_position_ids_batch_list.append(position_ids_list[j][i, :eff_len])
                    global_labels_batch_list.append(labels_list[j][i, :eff_len])
                    global_inputs_embeds_batch_list.append(all_inputs_embeds[i, prev_len : prev_len + eff_len, :])
                global_attention_mask_list.append(torch.cat(global_attention_mask_batch_list, dim=0))
                global_position_ids_list.append(torch.cat(global_position_ids_batch_list, dim=0))
                global_labels_list.append(torch.cat(global_labels_batch_list, dim=0))
                global_inputs_embeds_list.append(torch.cat(global_inputs_embeds_batch_list, dim=0))

            global_attention_mask = torch.nn.utils.rnn.pad_sequence(
                global_attention_mask_list, batch_first=True, padding_value=False
            )
            global_position_ids = torch.nn.utils.rnn.pad_sequence(
                global_position_ids_list, batch_first=True, padding_value=-1
            )
            global_labels = torch.nn.utils.rnn.pad_sequence(
                global_labels_list, batch_first=True, padding_value=IGNORE_INDEX
            )
            global_inputs_embeds = torch.nn.utils.rnn.pad_sequence(
                global_inputs_embeds_list, batch_first=True, padding_value=0
            )

            # Re-shard the inputs
            if ring_degree > 1:
                total_effective_seqlen = torch.sum(effective_seqlen, dim=1)
                new_seqlen_per_rank = total_effective_seqlen // sp_degree
                assert torch.all(
                    total_effective_seqlen % sp_degree == 0
                ), "total_effective_seqlen must be divisible by sp_degree"

                max_new_seqlen = torch.max(new_seqlen_per_rank).item()

                new_attention_mask = torch.zeros(
                    (bs, max_new_seqlen), dtype=global_attention_mask.dtype, device=global_attention_mask.device
                )
                new_position_ids = torch.zeros(
                    (bs, max_new_seqlen), dtype=global_position_ids.dtype, device=global_position_ids.device
                )
                new_labels = torch.full(
                    (bs, max_new_seqlen), IGNORE_INDEX, dtype=global_labels.dtype, device=global_labels.device
                )
                new_inputs_embeds = torch.zeros(
                    (bs, max_new_seqlen, global_inputs_embeds.shape[-1]),
                    dtype=global_inputs_embeds.dtype,
                    device=global_inputs_embeds.device,
                )

                if ring_type == "ring_varlen":
                    # Contiguous equal slices: rank r takes [r*L, (r+1)*L).
                    for i in range(bs):
                        start_idx = new_seqlen_per_rank[i] * sp_rank
                        end_idx = start_idx + new_seqlen_per_rank[i]
                        new_attention_mask[i, : new_seqlen_per_rank[i]] = global_attention_mask[i, start_idx:end_idx]
                        new_position_ids[i, : new_seqlen_per_rank[i]] = global_position_ids[i, start_idx:end_idx]
                        new_labels[i, : new_seqlen_per_rank[i]] = global_labels[i, start_idx:end_idx]
                        new_inputs_embeds[i, : new_seqlen_per_rank[i], :] = global_inputs_embeds[
                            i, start_idx:end_idx, :
                        ]
                elif ring_type == "zigzag_ring_varlen":
                    # Each rank takes one chunk from the front half and its
                    # mirror chunk from the back half to balance causal-attention load.
                    chunk_size = total_effective_seqlen // (2 * sp_degree)
                    for i in range(bs):
                        # Zigzag pattern indices
                        if sp_degree == ring_degree:
                            forward_rank_idx = sp_rank
                            backward_rank_idx = 2 * sp_degree - sp_rank - 1
                        else:
                            ulysses_offset = ulysses_rank * ring_degree * 2
                            forward_rank_idx = ring_rank + ulysses_offset
                            backward_rank_idx = sp_degree - ring_rank - 1 + ulysses_offset

                        # Calculate start and end indices for the forward and backward zigzag
                        start_idx_fwd = forward_rank_idx * chunk_size[i]
                        end_idx_fwd = start_idx_fwd + chunk_size[i]

                        start_idx_bwd = backward_rank_idx * chunk_size[i]
                        end_idx_bwd = start_idx_bwd + chunk_size[i]

                        # Fill new tensors with zigzag data
                        new_attention_mask[i, : chunk_size[i]] = global_attention_mask[i, start_idx_fwd:end_idx_fwd]
                        new_attention_mask[i, chunk_size[i] : 2 * chunk_size[i]] = global_attention_mask[
                            i, start_idx_bwd:end_idx_bwd
                        ]

                        new_position_ids[i, : chunk_size[i]] = global_position_ids[i, start_idx_fwd:end_idx_fwd]
                        new_position_ids[i, chunk_size[i] : 2 * chunk_size[i]] = global_position_ids[
                            i, start_idx_bwd:end_idx_bwd
                        ]

                        new_labels[i, : chunk_size[i]] = global_labels[i, start_idx_fwd:end_idx_fwd]
                        new_labels[i, chunk_size[i] : 2 * chunk_size[i]] = global_labels[i, start_idx_bwd:end_idx_bwd]

                        new_inputs_embeds[i, : chunk_size[i], :] = global_inputs_embeds[i, start_idx_fwd:end_idx_fwd, :]
                        new_inputs_embeds[i, chunk_size[i] : 2 * chunk_size[i], :] = global_inputs_embeds[
                            i, start_idx_bwd:end_idx_bwd, :
                        ]
                else:
                    raise ValueError(f"Invalid ring_type: {ring_type}")
            else:
                # Ulysses-only path: plain equal split of the padded global
                # sequence; the last rank absorbs the remainder.
                global_seq_len = global_attention_mask.shape[-1]
                seq_len_sharded = global_seq_len // sp_degree
                start_idx_reshard = seq_len_sharded * sp_rank
                end_idx_reshard = start_idx_reshard + seq_len_sharded if sp_rank < sp_degree - 1 else global_seq_len

                new_attention_mask = torch.narrow(
                    global_attention_mask, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard
                )
                new_position_ids = torch.narrow(
                    global_position_ids, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard
                )
                new_labels = torch.narrow(global_labels, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard)
                new_inputs_embeds = torch.narrow(
                    global_inputs_embeds, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard
                )

            return new_inputs_embeds, new_attention_mask, new_position_ids, new_labels

        # --- No sequence parallelism: pack the whole batch into one row. ---
        device = inputs_embeds.device
        batch_size = inputs_embeds.shape[0]
        seqlens = [attention_mask[k].sum().item() for k in range(batch_size)]

        # Pack all sequences together
        inputs_embeds_p = [inputs_embeds[k][attention_mask[k]] for k in range(batch_size)]
        attention_mask_p = [torch.ones(seqlens[k], dtype=torch.int, device=device) for k in range(batch_size)]
        position_ids_p = [torch.arange(seqlens[k], dtype=torch.int, device=device) for k in range(batch_size)]
        labels_p = [labels[k][attention_mask[k]] for k in range(batch_size)]

        # Add one dummy token at the end of the packed sequence to ensure that `_get_unpacked_data` will be called
        inputs_embeds_p.append(torch.zeros(1, inputs_embeds.shape[-1], dtype=inputs_embeds.dtype, device=device))
        attention_mask_p.append(torch.tensor([0], dtype=torch.int, device=device))
        position_ids_p.append(torch.tensor([0], dtype=torch.int, device=device))
        labels_p.append(torch.tensor([IGNORE_INDEX], dtype=torch.int, device=device))

        # Mask the first token of each sequence to avoid contamination
        for label in labels_p:
            label[0] = IGNORE_INDEX

        # Batch the data
        inputs_embeds_p = torch.cat(inputs_embeds_p, dim=0).unsqueeze(0)
        attention_mask_p = torch.cat(attention_mask_p, dim=0).unsqueeze(0)
        position_ids_p = torch.cat(position_ids_p, dim=0).unsqueeze(0)
        labels_p = torch.cat(labels_p, dim=0).unsqueeze(0)

        if hasattr(
            self, "pad_to_multiple_of"
        ):  # related to quantization, please refer to ModelArguments for more information.
            assert len(labels_p.shape) == 2
            batch_size, max_length, cur_length = labels_p.shape[0], labels_p.shape[1], labels_p.shape[1]
            hidden_size = inputs_embeds_p.shape[-1]

            # Round the packed length up to the next multiple and pad.
            if max_length % self.pad_to_multiple_of != 0:
                max_length = ((max_length // self.pad_to_multiple_of) + 1) * self.pad_to_multiple_of
                difference = max_length - cur_length

                inputs_embeds_p = torch.cat(
                    (
                        inputs_embeds_p,
                        torch.full((batch_size, difference, hidden_size), self.llm.pad_token_id).to(inputs_embeds_p),
                    ),
                    dim=1,
                )
                labels_p = torch.cat((labels_p, torch.full((batch_size, difference), IGNORE_INDEX).to(labels_p)), dim=1)
                attention_mask_p = torch.cat(
                    (
                        attention_mask_p,
                        torch.zeros((batch_size, difference), dtype=torch.bool).to(attention_mask_p),
                    ),
                    dim=1,
                )
                position_ids_p = torch.cat(
                    (position_ids_p, torch.full((batch_size, difference), -1).to(position_ids_p)), dim=1
                )

        return inputs_embeds_p, attention_mask_p, position_ids_p, labels_p
1002
+
1003
    def get_xgr_logits_processor(self, response_format) -> List[LogitsProcessor]:
        """Build xgrammar logits processors for constrained (structured) output.

        Currently disabled: the method unconditionally raises
        NotImplementedError for the VILA model.
        """
        raise NotImplementedError("This method is not implemented for VILA model.")
        # NOTE(review): everything below this `raise` is unreachable dead code,
        # kept as reference for re-enabling xgrammar-based structured output.
        # Convert response format to logits processor
        import xgrammar as xgr

        logging.info("[XGrammar] Compiling grammar for contrained output")

        # Lazily build (and cache) the grammar compiler on first use.
        if self.grammar_compiler is None:
            # logging.info(f"[XGrammar] {self.tokenizer}, {self.tokenizer.vocab_size}, {self.vocab_size}")
            self.grammar_compiler = xgr.GrammarCompiler(
                xgr.TokenizerInfo.from_huggingface(self.tokenizer, vocab_size=self.vocab_size)
            )

        # Either a user-supplied JSON schema or the builtin generic JSON grammar.
        if response_format.type == "json_schema":
            compiled_grammar = self.grammar_compiler.compile_json_schema(
                response_format.json_schema.schema_,
                indent=2,
            )
        else:
            compiled_grammar = self.grammar_compiler.compile_builtin_json_grammar()

        return [xgr.contrib.hf.LogitsProcessor(compiled_grammar)]
1025
+
1026
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        media: Optional[Dict[str, List[torch.Tensor]]] = None,
        images: Optional[torch.FloatTensor] = None,
        media_config: Optional[List] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        packing: bool = True,
        force_packing: bool = False,
        seqlens_in_batch: Optional[torch.LongTensor] = None,
        dpo_forward: bool = False,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the multimodal forward pass: embed media+text, then call the LLM.

        `media` maps media type ("image"/"video") to per-sample tensors;
        `images` is a deprecated alias for `media["image"]`. In training mode
        (unless `dpo_forward`) the batch is packed/resharded before the LLM
        call. Returns the LLM output, or `(logits, labels)` if `dpo_forward`.
        """
        # Re-freeze any modules that training wrappers may have unfrozen.
        self.freezed_module_patch()

        if images is not None:
            if media is not None:
                raise ValueError("Both 'media' and 'images' are provided. Please provide only one.")
            print("The 'images' argument is deprecated. Please use 'media' instead.")
            media = {"image": images}

        if media_config is None:
            media_config = defaultdict(dict)

        # Fuse media embeddings into the token embeddings unless the caller
        # already supplied `inputs_embeds`.
        if inputs_embeds is None:
            inputs_embeds, labels, attention_mask = self._embed(input_ids, media, media_config, labels, attention_mask)

        if force_packing or (packing and self.training and not dpo_forward):
            if seqlens_in_batch is None:
                seqlens_in_batch = torch.sum(attention_mask, dim=1)
            # Publish per-sequence lengths so the packed-attention kernels can
            # recover sequence boundaries.
            set_seqlens_in_batch(seqlens_in_batch)

            (inputs_embeds, attention_mask, position_ids, labels) = self.repack_multimodal_data(
                inputs_embeds, attention_mask, position_ids, labels
            )

        outputs = self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            labels=labels,
            **kwargs,
        )

        # Optionally replace the loss with a soft cross-entropy over the
        # configured "time token" ids (training-time only).
        if self.training and getattr(self.config, "time_token_ids", []):
            outputs.loss = soft_cross_entropy(
                outputs.logits,
                labels,
                soft_tokens=self.config.time_token_ids,
                std=self.config.soft_ce_std,
            )

        if dpo_forward:
            return outputs.logits, labels

        return outputs
1088
+
1089
+ # TODO(ligeng): check how qwen implements this function
1090
+ # @torch.inference_mode()
1091
+ def generate(
1092
  self,
1093
+ input_ids: Optional[torch.FloatTensor] = None,
1094
+ media: Optional[Dict[str, List[torch.Tensor]]] = None,
1095
+ media_config: Dict[str, Dict[str, Any]] = None,
1096
+ attention_mask: Optional[torch.LongTensor] = None,
1097
+ return_output_ids_only: bool = False,
1098
+ **generation_kwargs,
1099
+ ) -> torch.LongTensor:
1100
+ """
1101
+ input_tokens: <image> describe the image
1102
+ media: [Tensor(1, 3, 384, 384), ]
1103
+ ----------->
1104
+ input_tokens: 36000 001 002 003 004
1105
+ input_emds: <media emd> 001 002 003 004
1106
  """
1107
+ # NOTE: hard code to move to GPU
1108
+ # input_ids = input_ids.cuda()
1109
+ # media = {k: [v.cuda() if v is not None for v in media[k]] for k in media}
1110
+ # if attention_mask is not None:
1111
+ # attention_mask = attention_mask.cuda()
1112
+ inputs_embeds, _, attention_mask = self._embed(input_ids, media, media_config, None, attention_mask)
1113
+ output_ids = self.llm.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **generation_kwargs)
1114
 
1115
+ if return_output_ids_only:
1116
+ return_value = output_ids
1117
+ else:
1118
+ # by default, return the input_ids and output_ids concatenated to keep consistency with the community VLMs like qwen
1119
+ generation_config = generation_kwargs.get("generation_config", None)
1120
+ if generation_config is not None:
1121
+ num_generations = generation_config.num_return_sequences
1122
+ repeat_input_ids = input_ids.repeat_interleave(num_generations, dim=0)
1123
+ return_value = torch.cat([repeat_input_ids, output_ids], dim=-1)
1124
+ else:
1125
+ return_value = torch.cat([input_ids, output_ids], dim=-1)
1126
 
1127
+ return return_value
1128
 
1129
    @torch.inference_mode()
    def generate_content(
        self,
        prompt: Union[str, List],
        generation_config: Optional[GenerationConfig] = None,
        response_format=None,
    ) -> str:
        """Generate a text response for a (possibly multimodal) prompt.

        Wraps the prompt into a single-turn conversation, extracts and
        preprocesses its media, tokenizes, generates, and decodes.

        NOTE(review): `response_format` is currently ignored — the xgrammar
        logits processor is always None here; confirm whether structured
        output is still supported on this path.
        """
        # TODO(zhijianl): Support directly taking conversation as input
        conversation = [{"from": "human", "value": prompt}]

        # Convert response format to logits processor
        xgr_logits_processor = None

        # Extract media from the conversation

        # TODO (extract and preprocess should be done together, as the preprocess of image and video can be different, i.e. when dynamic res is used)
        media = extract_media(conversation, self.config)

        # Process media
        # NOTE(review): this mutates `self.config` (image_processor, s2_scales)
        # as a side effect — presumably intentional caching; verify.
        media_config = defaultdict(dict)
        for name in media:
            if name == "image":
                if len(media["image"]) == 1 and self.config.image_aspect_ratio in ["dynamic", "dynamic_s2"]:
                    self.config.image_processor = self.vision_tower.image_processor
                    if self.config.image_aspect_ratio == "dynamic":
                        # Dynamic resolution: one image becomes several tiles,
                        # so duplicate the image token accordingly.
                        images = process_image(media["image"][0], self.config, None, enable_dynamic_res=True).half()
                        conversation[0]["value"] = conversation[0]["value"].replace(
                            DEFAULT_IMAGE_TOKEN, f"{DEFAULT_IMAGE_TOKEN}\n" * images.shape[0]
                        )
                    else:
                        # Dynamic S2: multi-scale tiling; record block sizes for the encoder.
                        if type(self.config.s2_scales) is str:
                            self.config.s2_scales = list(map(int, self.config.s2_scales.split(",")))
                        images, block_sizes = process_image(
                            media["image"][0], self.config, None, enable_dynamic_s2=True
                        )
                        images = images.half()
                        media_config[name]["block_sizes"] = [block_sizes]
                else:
                    images = process_images(media["image"], self.vision_tower.image_processor, self.config).half()
                media[name] = [image for image in images]
            elif name == "video":
                if self.config.image_aspect_ratio == "dynamic" and self.config.video_max_tiles > 1:
                    # Dynamic-res video: process each frame stack with a tile cap.
                    media[name] = [
                        process_images(
                            images,
                            self.vision_tower.image_processor,
                            self.config,
                            enable_dynamic_res=True,
                            max_tiles=self.config.video_max_tiles,
                        ).half()
                        for images in media[name]
                    ]
                elif self.config.image_aspect_ratio == "dynamic_s2" and self.config.video_max_tiles > 1:
                    self.config.image_processor = self.vision_tower.image_processor
                    if type(self.config.s2_scales) is str:
                        self.config.s2_scales = list(map(int, self.config.s2_scales.split(",")))
                    # S2 video: process frames individually, then concatenate tiles.
                    media[name] = [
                        torch.cat(
                            [
                                process_image(
                                    image,
                                    self.config,
                                    None,
                                    enable_dynamic_s2=True,
                                    max_tiles=self.config.video_max_tiles,
                                )[0].half()
                                for image in images
                            ]
                        )
                        for images in media[name]
                    ]
                else:
                    media[name] = [
                        process_images(images, self.vision_tower.image_processor, self.config).half()
                        for images in media[name]
                    ]
            else:
                raise ValueError(f"Unsupported media type: {name}")

        # Tokenize the conversation
        input_ids = tokenize_conversation(conversation, self.tokenizer, add_generation_prompt=True).unsqueeze(0).cuda()

        # Set up the generation config
        generation_config = generation_config or self.default_generation_config

        # print("input_ids", input_ids.shape)
        # print(input_ids)
        # print(self.tokenizer.batch_decode(input_ids))
        # print("media", {k: len(v) for k, v in media.items()})
        # print("media_config", media_config)
        # print("generation_config", generation_config)
        # input("wait for debug")
        # Generate the response
        try:
            output_ids = self.generate(
                input_ids=input_ids,
                media=media,
                media_config=media_config,
                generation_config=generation_config,
                logits_processor=xgr_logits_processor,  # structured generation
            )
        except ValueError:
            if not generation_config.do_sample:
                raise
            # FIXME(zhijianl): This is a temporary workaround for the sampling issue
            logging.warning("Generation failed with sampling, retrying with greedy decoding.")
            generation_config.do_sample = False
            output_ids = self.generate(
                input_ids=input_ids,
                media=media,
                media_config=media_config,
                generation_config=generation_config,
                logits_processor=xgr_logits_processor,
            )

        # Decode the response
        response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
        return response
1247
 
1248
+ @property
1249
+ def default_generation_config(self) -> GenerationConfig:
1250
+ generation_config = copy.deepcopy(self.generation_config or GenerationConfig())
1251
+ if self.tokenizer.eos_token_id is None:
1252
+ raise ValueError("Tokenizer must have an EOS token")
1253
+ if generation_config.max_length == GenerationConfig().max_length:
1254
+ generation_config.max_length = self.tokenizer.model_max_length
1255
+ if generation_config.pad_token_id is None:
1256
+ generation_config.pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
1257
+ if generation_config.bos_token_id is None:
1258
+ generation_config.bos_token_id = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id
1259
+ if generation_config.eos_token_id is None:
1260
+ generation_config.eos_token_id = self.tokenizer.eos_token_id
1261
+ return generation_config
modeling_vila_hf.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Dict, Optional, Tuple, Type, Union, cast, override
3
+
4
+ import torch
5
+ import transformers.modeling_utils as modeling_utils
6
+ from torch import Tensor
7
+ from transformers.configuration_utils import PretrainedConfig
8
+ from transformers.generation.utils import GenerationMixin
9
+ from transformers.modeling_outputs import CausalLMOutputWithPast
10
+ from transformers.modeling_utils import PreTrainedModel
11
+
12
+ from .configuration_vila import VILAConfig
13
+ from .modeling_vila import VILAForCausalLM
14
+
15
+
16
class VILAForConditionalGeneration(PreTrainedModel, GenerationMixin):
    """HF-style conditional-generation wrapper around an inner `VILAForCausalLM`.

    Presents the standard `input_ids`/`pixel_values` interface and delegates
    embedding and language modeling to the wrapped model.
    """

    config_class: Type[PretrainedConfig] = VILAConfig
    base_model_prefix: str = "vila"
    is_parallelizable: bool = True
    main_input_name: str = "input_ids"

    config: PretrainedConfig

    def __init__(
        self,
        config: PretrainedConfig,
        model: VILAForCausalLM,
        *args,
        **kwargs,
    ):
        super().__init__(config, *args, **kwargs)

        # The fully constructed inner VILA model; all weights live under
        # `self.model` (see the "model." key prefixing in from_pretrained).
        self.model = model

    def forward(
        self,
        *,
        attention_mask: Optional[Tensor] = None,
        input_ids: Optional[Tensor] = None,
        inputs_embeds: Optional[Tensor] = None,
        pixel_values: Optional[Tensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Embed `input_ids` + `pixel_values` (unless `inputs_embeds` is given)
        and run the inner LLM. Exactly one of `input_ids`/`inputs_embeds` must
        be provided; `pixel_values` is only valid with `input_ids`."""
        if inputs_embeds is None:
            assert input_ids is not None

            inputs_embeds, attention_mask = self._embed(
                input_ids, pixel_values, attention_mask
            )
        else:
            assert input_ids is None
            assert pixel_values is None

        outputs = self.model.llm.forward(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            **kwargs,
        )

        return outputs

    @override
    @classmethod
    @modeling_utils.restore_default_torch_dtype
    def from_pretrained(
        cls: Type[modeling_utils.SpecificPreTrainedModelType],
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        ignore_mismatched_sizes: bool = False,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        use_safetensors: Optional[bool] = None,
        weights_only: bool = True,
        **kwargs,
    ) -> modeling_utils.SpecificPreTrainedModelType:
        """Load the inner `VILAForCausalLM` first, then construct this wrapper
        around it via the regular HF loading path.

        The inner model's state dict is re-keyed under "model." so that
        `super().from_pretrained` maps weights onto `self.model`. The inner
        model instance itself is passed positionally (via `*model_args`) so it
        reaches `__init__`'s `model` parameter.
        """
        state_dict = kwargs.pop("state_dict", None)

        if pretrained_model_name_or_path is not None:
            config = VILAConfig.from_pretrained(
                pretrained_model_name_or_path,
                cache_dir=cache_dir,
                force_download=force_download,
                local_files_only=local_files_only,
                revision=revision,
                use_safetensors=use_safetensors,
                **kwargs,
            )
        else:
            assert (
                config is not None and state_dict is not None
            ), "Both config and state_dict must be provided if pretrained_model_name_or_path is None."

        inner_model = VILAForCausalLM.from_pretrained(
            pretrained_model_name_or_path,  # type: ignore
            *model_args,
            config=config,
            cache_dir=cache_dir,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            use_safetensors=use_safetensors,
            weights_only=weights_only,
            **kwargs,
        )

        state_dict = inner_model.state_dict()

        # Prefix keys with "model.".
        state_dict = {f"model.{k}": v for k, v in state_dict.items()}

        # NOTE(review): path is None here on purpose — weights come from the
        # re-keyed `state_dict`; `inner_model` flows through *model_args into
        # __init__. Any caller-provided `state_dict` above is discarded.
        return super().from_pretrained(
            None,
            inner_model,
            *model_args,
            config=config,
            cache_dir=cache_dir,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            state_dict=state_dict,
            use_safetensors=use_safetensors,
            weights_only=weights_only,
            **kwargs,
        )

    def _embed(
        self,
        input_ids: Tensor,
        pixel_values: Optional[Tensor],
        attention_mask: Optional[Tensor],
    ) -> Tuple[Tensor, Tensor]:
        """Gets the embedding of the input ids and pixel values.

        Args:
            input_ids: The input ids.
            pixel_values: The pixel values.
            attention_mask: The attention mask.

        Returns:
            A tuple of the embedding of the input ids and attention mask.
        """

        # Count how many media tokens actually appear in the prompt.
        image_token_ids_map = cast(Dict[str, int], self.model.tokenizer.media_token_ids)
        image_token_ids = list(image_token_ids_map.values())
        image_token_idx = torch.isin(
            input_ids,
            torch.tensor(image_token_ids).to(input_ids.device),
        )
        image_token_count = image_token_idx.sum()

        images = list(pixel_values) if pixel_values is not None else []

        # Drop surplus images so media count matches the media tokens present.
        if image_token_count < len(images):
            images = images[:image_token_count]

        media = (
            {
                "image": images,
            }
            if image_token_count > 0
            else {}
        )
        media_config = (
            {
                "image": {},
            }
            if image_token_count > 0
            else {}
        )

        outputs = self.model._embed(
            input_ids,
            media,
            media_config,
            labels=None,
            attention_mask=(
                # Trim the mask to the current window (generation may pass a
                # mask longer than input_ids) and make it boolean.
                attention_mask[:, -input_ids.shape[1] :].to(dtype=torch.bool)
                if attention_mask is not None
                else None
            ),
        )

        # `_embed` returns (inputs_embeds, labels, attention_mask); labels are unused here.
        return outputs[0], outputs[2]