Bingsu committed
Commit e8c3ba1 · verified · 1 Parent(s): be19eec

Delete modeling_hyperclovax.py

Files changed (1)
  1. modeling_hyperclovax.py +0 -1810
modeling_hyperclovax.py DELETED
@@ -1,1810 +0,0 @@
1
- import ast
2
- import contextlib
3
- import gc
4
- import json
5
- import math
6
- import os
7
- from dataclasses import dataclass
8
- from functools import partial
9
- from itertools import chain
10
- from typing import Any, Dict, List, Optional, Tuple, Union
11
-
12
- import torch
13
- import torch.distributed as dist
14
- import torch.nn as nn
15
- from einops import rearrange
16
- from timm.layers import LayerNorm, LayerNorm2d
17
- from timm.models.regnet import RegStage
18
- from torch.nn import CrossEntropyLoss
19
- from transformers import (
20
- AutoConfig,
21
- AutoModel,
22
- AutoModelForCausalLM,
23
- AutoTokenizer,
24
- PreTrainedModel,
25
- )
26
- from transformers.generation.utils import GenerationMixin
27
- from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
28
- from transformers.modeling_utils import (
29
- is_fsdp_enabled,
30
- is_local_dist_rank_0,
31
- no_init_weights,
32
- )
33
- from transformers.models.auto import CONFIG_MAPPING
34
- from transformers.utils import ModelOutput
35
-
36
- from .configuration_hyperclovax import HCXVisionConfig
37
- from .preprocessor import select_best_resolution
38
-
39
- EOT = "<|endofturn|>"
40
- IMG_LOC = "<|dummy3|>"
41
-
42
-
43
- def get_rank():
44
- if dist.is_initialized():
45
- return dist.get_rank()
46
- return 0
47
-
48
-
49
- def get_world_size():
50
- if torch.distributed.is_initialized():
51
- world_size = torch.distributed.get_world_size()
52
- else:
53
- world_size = 1
54
- return world_size
55
-
56
-
57
- def unpad_image(tensor: torch.Tensor, original_size: Tuple[int, int]) -> torch.Tensor:
58
- """Unpads a PyTorch tensor of a padded and resized image.
59
-
60
- This function removes padding from a tensor image that was previously padded and resized.
61
- The padding is removed based on the aspect ratio difference between the original and current image dimensions.
62
-
63
- Args:
64
- tensor: The image tensor, assumed to be in CxHxW format.
65
- original_size: The original size of the image as (width, height).
66
-
67
- Returns:
68
- The unpadded image tensor.
69
-
70
- Examples:
71
- >>> import torch
72
- >>> # Example 1: Unpadding with height padding
73
- >>> padded_tensor = torch.randn(1, 64, 48) # Padded tensor (C=1, H=64, W=48)
74
- >>> original_size = (32, 32) # Original size (width=32, height=32)
75
- >>> unpadded_tensor = unpad_image(padded_tensor, original_size)
76
- >>> unpadded_tensor.shape
77
- torch.Size([1, 48, 48])
78
- >>> # Example 2: Unpadding with width padding
79
- >>> padded_tensor = torch.randn(1, 48, 64) # Padded tensor (C=1, H=48, W=64)
80
- >>> original_size = (32, 32) # Original size (width=32, height=32)
81
- >>> unpadded_tensor = unpad_image(padded_tensor, original_size)
82
- >>> unpadded_tensor.shape
83
- torch.Size([1, 48, 48])
84
- """
85
- original_width, original_height = original_size
86
- current_height, current_width = tensor.shape[1:]
87
-
88
- original_aspect_ratio = original_width / original_height
89
- current_aspect_ratio = current_width / current_height
90
-
91
- if original_aspect_ratio > current_aspect_ratio:
92
- scale_factor = current_width / original_width
93
- new_height = int(original_height * scale_factor)
94
- padding = (current_height - new_height) // 2
95
- unpadded_tensor = tensor[:, padding : current_height - padding, :]
96
- else:
97
- scale_factor = current_height / original_height
98
- new_width = int(original_width * scale_factor)
99
- padding = (current_width - new_width) // 2
100
- unpadded_tensor = tensor[:, :, padding : current_width - padding]
101
-
102
- return unpadded_tensor
103
-
104
-
105
- def get_anyres_image_grid_shape(
106
- image_size: Tuple[int, int],
107
- grid_pinpoints: Union[str, List[Tuple[int, int]]],
108
- patch_size: int,
109
- ) -> Tuple[int, int]:
110
- """Calculates the image patch grid shape after any-resolution preprocessing.
111
-
112
- Selects the optimal resolution from predefined grid pinpoints based on input image
113
- dimensions using `select_best_resolution`, then computes the grid layout by
114
- dividing the selected resolution by the patch size using integer division.
115
-
116
- Args:
117
- image_size (Tuple[int, int]): Original image dimensions in (width, height) format.
118
- grid_pinpoints (Union[str, List[Tuple[int, int]]]): Accepts either:
119
- - List of (height, width) resolution tuples
120
- - String representation of list (e.g., "[(224, 224), (336, 336)]")
121
- patch_size (int): Spatial dimension of square patches for grid division.
122
-
123
- Returns:
124
- Tuple[int, int]: Grid dimensions as (num_patches_width, num_patches_height).
125
-
126
- Examples:
127
- >>> # Basic case with list input
128
- >>> get_anyres_image_grid_shape((1000, 800), [(224, 224), (448, 448)], 112)
129
- (4, 4)
130
-
131
- >>> # Basic case with string input
132
- >>> get_anyres_image_grid_shape((600, 400), "[(336, 336), (672, 672)]", 112)
133
- (6, 6)
134
-
135
- >>> # Case where resolution is not perfectly divisible by patch_size
136
- >>> # select_best_resolution picks (224, 224). 224 // 100 = 2
137
- >>> get_anyres_image_grid_shape((500, 500), [(224, 224)], 100)
138
- (2, 2)
139
-
140
- >>> # Different patch size
141
- >>> # select_best_resolution picks (448, 448). 448 // 224 = 2
142
- >>> get_anyres_image_grid_shape((1200, 900), [(448, 448), (224, 224)], 224)
143
- (2, 2)
144
-
145
- Note:
146
- String-formatted grid_pinpoints are converted via ast.literal_eval. Invalid formats
147
- may raise syntax exceptions. The actual resolution selection depends on the
148
- implementation of `select_best_resolution`. The doctests assume
149
- `select_best_resolution` picks the resolution from `grid_pinpoints` that best fits the input image size.
150
- """
151
- possible_resolutions = grid_pinpoints if isinstance(grid_pinpoints, list) else ast.literal_eval(grid_pinpoints)
152
-
153
- original_width, original_height = image_size
154
- height, width = select_best_resolution((original_height, original_width), possible_resolutions)
155
- return width // patch_size, height // patch_size
156
-
157
-
158
- def reshape_and_unpad_image_features(
159
- image_feature: torch.Tensor,
160
- height: int,
161
- width: int,
162
- image_size: Tuple[int, int],
163
- possible_resolutions: List[Tuple[int, int]],
164
- grid_size: int,
165
- unpad: bool,
166
- image_newline: torch.Tensor,
167
- ) -> torch.Tensor:
168
- """Reshapes and processes image features with optional unpadding operation.
169
-
170
- Processes input image features by:
171
- 1. Separating base features from spatial features
172
- 2. Reshaping spatial features into a 5D tensor (num_patch_height, num_patch_width, height, width, channels)
173
- 3. Performing either unpadding operation or simple reshaping based on 'unpad' flag
174
- 4. Concatenating processed features with base features
175
-
176
- Args:
177
- image_feature: Input tensor containing image features with shape
178
- [1 + num_patches, feature_dim] where the first element is the base feature
179
- height: Original image height in pixels
180
- width: Original image width in pixels
181
- image_size: Target image size as (width, height) tuple
182
- possible_resolutions: List of possible [height, width] resolutions for multi-scale processing
183
- grid_size: Grid dimension for patch arrangement
184
- unpad: Flag to enable unpadding operation
185
- image_newline: Special token tensor used as separator when unpadding
186
-
187
- Returns:
188
- torch.Tensor: Processed image features tensor with shape [1 + num_processed_patches, feature_dim]
189
-
190
- Raises:
191
- AssertionError: If base feature dimension doesn't match height*width
192
- """
193
- base_image_feature = image_feature[0]
194
- image_feature = image_feature[1:]
195
-
196
- assert (
197
- height * width == base_image_feature.shape[0]
198
- ), f"height: {height}, width: {width}, base_image_feature.shape[0]: {base_image_feature.shape[0]}"
199
-
200
- num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_size, possible_resolutions, grid_size)
201
- image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
202
-
203
- if unpad:
204
- image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
205
- image_feature = image_feature.flatten(1, 2).flatten(2, 3)
206
- image_feature = unpad_image(image_feature, image_size)
207
- image_feature = torch.cat(
208
- (
209
- image_feature,
210
- image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device),
211
- ),
212
- dim=-1,
213
- )
214
- image_feature = image_feature.flatten(1, 2).transpose(0, 1)
215
- else:
216
- image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
217
- image_feature = image_feature.flatten(0, 3)
218
- image_feature = torch.cat((base_image_feature, image_feature), dim=0)
219
-
220
- return image_feature
221
-
222
-
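
A minimal, self-contained sketch of the reshape path above (the unpad=False branch of reshape_and_unpad_image_features), using made-up sizes: a 2x2 grid of 3x3 patch features with feature dimension 8. All names and numbers here are illustrative, not part of the original file.

import torch

# Hypothetical sizes: 2x2 grid cells, 3x3 tokens per cell, feature dim 8
num_patch_height, num_patch_width, height, width, dim = 2, 2, 3, 3, 8

# [1 + num_patches, height*width, dim]; the first entry is the base (global) feature
image_feature = torch.randn(1 + num_patch_height * num_patch_width, height * width, dim)
base_image_feature = image_feature[0]   # [9, 8]
grid_features = image_feature[1:]       # [4, 9, 8]

# Same operations as the unpad=False branch
grid_features = grid_features.view(num_patch_height, num_patch_width, height, width, dim)
grid_features = grid_features.permute(0, 2, 1, 3, 4).contiguous()  # [2, 3, 2, 3, 8]
grid_features = grid_features.flatten(0, 3)                        # [36, 8]

out = torch.cat((base_image_feature, grid_features), dim=0)
print(out.shape)  # torch.Size([45, 8])
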
223
- def anyres_postprocessing(
224
- image_forward_outs: torch.FloatTensor,
225
- split_sizes: List[int],
226
- image_sizes: List[List[int]],
227
- possible_resolutions: List[Tuple[int, int]],
228
- is_videos: List[bool],
229
- patch_size: int,
230
- grid_size: int,
231
- image_newline: torch.FloatTensor,
232
- num_queries_vis_abstractor: int = -1,
233
- unpad: bool = False,
234
- ) -> List[torch.FloatTensor]:
235
- """Processes 2D visual features into 1D sequences with post-processing steps.
236
-
237
- Performs AnyRes postprocessing by flattening 2D visual features from grid partitions into 1D sequences, adding
238
- newline embeddings at row boundaries for images, and optionally removing padding regions based on original image
239
- sizes. For video data, processes each frame's features separately into a single sequence per video and disables
240
- unpadding and newline insertion.
241
-
242
- Args:
243
- image_forward_outs (List[torch.FloatTensor]): List of input tensors with shape
244
- (number_of_images_in_grid, total_patches, feature_dim) containing visual features.
245
- split_sizes (List[int]): A list containing the number of patches for each sample in the batch. The sum of
246
- `split_sizes` should equal `image_forward_outs.shape[0]`.
247
- image_sizes (List[List[int]]): A list where each element is a list `[width, height]` representing the original
248
- dimensions of the corresponding image sample. Used for unpadding.
249
- possible_resolutions (List[Tuple[int, int]]): A list of supported resolution tuples `(height, width)` used by
250
- `reshape_and_unpad_image_features` for spatial reconstruction, especially during unpadding.
251
- is_videos (List[bool]): A list of boolean flags indicating whether each corresponding sample in the batch is a
252
- video [`True`] or an image [`False`].
253
- patch_size (int): The spatial dimension (height and width) of the square patches the image was divided into.
254
- grid_size (int): The spatial dimension (height and width) of the square grid onto which patches are mapped.
255
- `grid_size` should be divisible by `patch_size`.
256
- image_newline (torch.FloatTensor): A learnable tensor representing the newline embedding, typically with shape
257
- (1, feature_dim). Added after each row of image patches when not unpadding.
258
- num_queries_vis_abstractor (int, optional): If a visual abstractor with a fixed number of output queries is used
259
- instead of grid patching, this specifies the number of queries. Must be a perfect square if > 0.
260
- Defaults to -1 (indicating standard grid patching is used).
261
- unpad (bool, optional): If `True`, removes padding tokens from image features based on `image_sizes` and
262
- `possible_resolutions`. Does not apply to video features. Defaults to False.
263
-
264
- Returns:
265
- List[torch.FloatTensor]: A list of tensors, where each tensor represents the processed 1D sequence of visual
266
- features for a single sample from the input batch. The length of the sequence varies depending on processing
267
- (unpadding, newlines, video flattening).
268
-
269
- Raises:
270
- AssertionError: If `num_queries_vis_abstractor` is greater than 0 but not a perfect square.
271
- """
272
- height = width = grid_size // patch_size
273
-
274
- if num_queries_vis_abstractor > 0:
275
- assert (num_queries_vis_abstractor**0.5).is_integer(), "n_queries must be square number"
276
- height = width = int(num_queries_vis_abstractor**0.5)
277
-
278
- image_features = torch.split(image_forward_outs, split_sizes, dim=0)
279
-
280
- # post-processing (unpad, add newline)
281
- new_image_features = []
282
- for image_idx, (image_feature, is_video) in enumerate(zip(image_features, is_videos)):
283
- if image_feature.shape[0] > 1:
284
- if not is_video:
285
- image_feature = reshape_and_unpad_image_features(
286
- image_feature=image_feature,
287
- height=height,
288
- width=width,
289
- image_size=image_sizes[image_idx],
290
- possible_resolutions=possible_resolutions,
291
- grid_size=grid_size, # Pass grid info if needed by helper
292
- unpad=unpad,
293
- image_newline=image_newline,
294
- )
295
- else:
296
- image_feature = image_feature.flatten(0, 1)
297
- else:
298
- image_feature = image_feature[0]
299
- if unpad and not is_video:
300
- image_feature = torch.cat((image_feature, image_newline[None].to(image_feature.device)), dim=0)
301
- new_image_features.append(image_feature)
302
- image_features = new_image_features
303
- return image_features
304
-
305
-
306
- def adaptive_anyres_postprocessing(
307
- image_forward_outs: torch.FloatTensor,
308
- image_sizes: List[List[int]],
309
- possible_resolutions: List[Tuple[int, int]],
310
- is_videos: List[bool],
311
- group_ids: List[List[int]],
312
- num_queries_vis_abstractors: List[List[int]],
313
- grid_size: int,
314
- image_newline: torch.FloatTensor,
315
- unpad: bool = False,
316
- ) -> List[torch.FloatTensor]:
317
- """Adaptive AnyRes postprocessing for multi-group feature aggregation.
318
-
319
- Processes 2D visual features into 1D sequences with group-wise adaptive processing. Each image can belong to
320
- multiple processing groups with different query configurations. Features are processed per group and aggregated
321
- according to group_ids.
322
-
323
- Args:
324
- image_forward_outs (List[torch.FloatTensor]): List of input tensors with shape
325
- (number_of_images_in_grid, total_patches, feature_dim) containing visual features.
326
- image_sizes (List[List[int]]): Original image dimensions for each sample. [[width, height], ... ]
327
- possible_resolutions (List[Tuple[int, int]]): Supported resolutions. [[height, width], ... ]
328
- is_videos (List[bool]): Flags indicating video inputs
329
- group_ids (List[List[int]]): Group indices for feature aggregation. Each group means a single grid.
330
- num_queries_vis_abstractors (List[List[int]]): Query numbers per group
331
- grid_size (int): Total grid size for spatial processing
332
- image_newline (torch.FloatTensor): Sample-wise config. Newline embedding tensor
333
- unpad (bool, optional): Sample-wise config. Enable padding removal. Defaults to False.
334
-
335
- Returns:
336
- List[torch.FloatTensor]: Aggregated features per group
337
-
338
- Raises:
339
- AssertionError: If num_queries is not square number in any group
340
- """
341
- # post-processing (unpad, add newline)
342
- new_image_features = []
343
- for image_idx, (image_feature, is_video) in enumerate(zip(image_forward_outs, is_videos)):
344
- num_queries_vis_abstractor = num_queries_vis_abstractors[image_idx]
345
- assert (num_queries_vis_abstractor**0.5).is_integer(), "n_queries must be square number"
346
- height = width = int(num_queries_vis_abstractor**0.5)
347
-
348
- if image_feature.shape[0] > 1:
349
- if not is_video:
350
- image_feature = reshape_and_unpad_image_features(
351
- image_feature=image_feature,
352
- height=height,
353
- width=width,
354
- image_size=image_sizes[image_idx],
355
- possible_resolutions=possible_resolutions,
356
- grid_size=grid_size,
357
- unpad=unpad,
358
- image_newline=image_newline,
359
- )
360
- else:
361
- image_feature = image_feature.flatten(0, 1)
362
- else:
363
- image_feature = image_feature[0]
364
- if unpad and not is_video:
365
- image_feature = torch.cat((image_feature, image_newline[None].to(image_feature.device)), dim=0)
366
- new_image_features.append(image_feature)
367
-
368
- image_features = [
369
- torch.cat([new_image_features[group_id] for group_id in group_ids_list], dim=0) for group_ids_list in group_ids
370
- ]
371
- return image_features
372
-
373
-
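
A small sketch, with made-up shapes, of the final group_ids aggregation performed by adaptive_anyres_postprocessing above: per-grid sequences are concatenated into one sequence per group.

import torch

# Hypothetical per-grid features: three grids with 4, 4 and 9 query tokens, feature dim 8
new_image_features = [torch.randn(4, 8), torch.randn(4, 8), torch.randn(9, 8)]

# Grids 0 and 1 belong to one group (e.g. one image), grid 2 to another
group_ids = [[0, 1], [2]]

image_features = [
    torch.cat([new_image_features[g] for g in group], dim=0) for group in group_ids
]
print([f.shape for f in image_features])  # [torch.Size([8, 8]), torch.Size([9, 8])]
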
374
- @dataclass
375
- class HCXVisionOutput(ModelOutput):
376
- """Output class for vision models, containing various computation results.
377
-
378
- Args:
379
- loss (Optional[torch.FloatTensor], optional): Total cross-entropy loss calculated from logits and labels.
380
- loss_per_sample (Optional[torch.FloatTensor], optional): Per-sample loss values for advanced loss processing.
381
- logits (torch.FloatTensor): Prediction scores of the language modeling head (before SoftMax), of shape (batch_size, sequence_length, vocab_size).
382
- past_key_values (Optional[Tuple[Tuple[torch.FloatTensor]]], optional): Contains precomputed hidden-states
383
- that can be used (see `past_key_values` input) to speed up sequential decoding.
384
- hidden_states (Optional[Tuple[torch.FloatTensor]], optional):
385
- Tuple of torch.FloatTensor (one for the output of the embeddings + one for the output of each layer) of
386
- shape (batch_size, sequence_length, hidden_size).
387
- Hidden-states of the model at the output of each layer plus the initial embedding outputs.
388
- attentions (Optional[Tuple[torch.FloatTensor]], optional): Tuple of torch.FloatTensor (one for each layer)
389
- of shape (batch_size, num_heads, sequence_length, sequence_length). Attentions weights after the attention
390
- softmax, used to compute the weighted average in the self-attention heads.
391
- """
392
-
393
- loss: Optional[torch.FloatTensor] = None
394
- loss_per_sample: Optional[torch.FloatTensor] = None
395
- logits: torch.FloatTensor = None
396
- past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
397
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
398
- attentions: Optional[Tuple[torch.FloatTensor]] = None
399
-
400
-
401
- class HCXVisionForCausalLM(PreTrainedModel, GenerationMixin):
402
- """HCX Vision model for causal language modeling with vision-language capabilities.
403
-
404
- This class combines a vision model with a language model to create a multimodal model
405
- capable of processing images or videos and generating text based on the visual inputs.
406
-
407
- Attributes:
408
- config_class: Configuration class for the model.
409
- vision_model_name: Name of the vision model component.
410
- _no_split_modules: List of modules that should not be split during parallel processing.
411
- supports_gradient_checkpointing: Whether the model supports gradient checkpointing.
412
- _skip_keys_device_placement: Keys to skip during device placement.
413
- """
414
-
415
- config_class = HCXVisionConfig
416
- vision_model_name = "vision_model"
417
- _no_split_modules = ["CLIPAttention", "SiglipVisionModel"]
418
- supports_gradient_checkpointing = True
419
- _skip_keys_device_placement = "past_key_values"
420
-
421
- def __init__(
422
- self,
423
- config: HCXVisionConfig,
424
- **kwargs: Optional[Any],
425
- ) -> None:
426
- """Initialize the HCXVisionForCausalLM model.
427
-
428
- Args:
429
- config: Configuration object for the model containing parameters for both
430
- vision and language components.
431
- **kwargs: Additional keyword arguments:
432
- - use_liger: Whether to use liger kernel for hyperclovax models.
433
- - use_fused_ce: Whether to use fused cross-entropy loss.
434
- - use_sum_loss: Whether to use sum reduction for loss instead of mean.
435
- - is_safetensor_save: Whether to save model using safetensors format.
436
-
437
- Raises:
438
- ValueError: If vision_config is not defined or if language_config is not defined.
439
- """
440
- super().__init__(config)
441
-
442
- self.flag_changed_max_position_embeddings = False
443
-
444
- vision_model_type = config.vision_config["model_type"]
445
- if vision_model_type in CONFIG_MAPPING:
446
- vision_config = CONFIG_MAPPING[vision_model_type](**config.vision_config)
447
- vision_config.auto_map = {}
448
- else:
449
- if config.vision_model_name_or_path is not None:
450
- vision_config = AutoConfig.from_pretrained(config.vision_model_name_or_path, trust_remote_code=True)
451
- elif config.vision_config["_name_or_path"] is not None:
452
- vision_config = AutoConfig.from_pretrained(
453
- config.vision_config["_name_or_path"], trust_remote_code=True
454
- )
455
- else:
456
- raise ValueError("vision_config is not defined")
457
-
458
- self.use_liger = kwargs.pop("use_liger", False)
459
- self.use_fused_ce = kwargs.pop("use_fused_ce", False)
460
- self.reduction = "sum" if kwargs.pop("use_sum_loss", False) else "mean"
461
-
462
- self.vision_config = vision_config
463
- vision_config.anyres = config.anyres
464
- vision_config.max_num_grids = config.max_num_grids
465
-
466
- possible_resolutions = []
467
- if config.anyres:
468
- assert config.max_num_grids > 0
469
- for i in range(1, config.max_num_grids + 1):
470
- for j in range(1, config.max_num_grids + 1):
471
- if i == 1 and j == 1 and not config.use_1x1_grid:
472
- continue
473
- if i * j <= config.max_num_grids:
474
- possible_resolutions.append([i, j])
475
-
476
- possible_resolutions = [
477
- [ys * vision_config.image_size, xs * vision_config.image_size] for ys, xs in possible_resolutions
478
- ]
479
-
480
- self.possible_resolutions = possible_resolutions
481
-
482
- with no_init_weights():
483
- self.vision_model = AutoModel.from_config(
484
- vision_config, trust_remote_code=True
485
- ) # weight will be loaded in from_pretrained
486
-
487
- assert config.language_config["model_type"] == "llama"
488
- language_config = CONFIG_MAPPING["llama"](**config.language_config)
489
- language_config._attn_implementation = kwargs.get("attn_implementation", "sdpa")  # select attention implementation (flash attention if passed, otherwise sdpa)
490
- language_config.logits_scaling = 1.0
491
-
492
- self.language_config = language_config
493
- self.language_model = AutoModelForCausalLM.from_config(language_config)
494
-
495
- self.language_model.gradient_checkpointing_enable()
496
- self.num_queries_vis_abstractor = config.num_queries_vis_abstractor
497
-
498
- # mm_projector (== connector); vision_model_hidden_size -> LLM embedding size
499
- input_hidden_size = vision_config.hidden_size
500
- self.mm_projector = HCXVisionCAbstractor(
501
- num_queries=self.num_queries_vis_abstractor,
502
- num_input_tokens=(self.vision_config.image_size // self.vision_config.patch_size) ** 2,
503
- encoder_hidden_size=input_hidden_size,
504
- hidden_size=input_hidden_size,
505
- output_hidden_size=language_config.hidden_size,
506
- pos_emb=config.proj_pos_emb,
507
- prenorm=config.proj_prenorm,
508
- )
509
- self.use_nth_layer = config.use_nth_layer
510
- self.config.update({"vision_config": self.vision_model.config.to_dict()})
511
- self.config.update({"language_config": self.language_model.config.to_dict()})
512
- self.lm_head_vocab_size = (
513
- language_config.padded_vocab_size
514
- if hasattr(language_config, "padded_vocab_size")
515
- else language_config.vocab_size
516
- )
517
- self.language_model.lm_head = nn.Linear(language_config.hidden_size, self.lm_head_vocab_size, bias=False)
518
- self.model_parallel = False
519
- self.device_map = None
520
- self.use_no_grad = None
521
- self.decoder_max_length = config.decoder_max_length
522
-
523
- self.anyres = config.anyres
524
- self.unpad = config.unpad
525
- if self.anyres:
526
- self.image_newline = nn.Parameter(torch.empty(language_config.hidden_size, dtype=self.dtype))
527
-
528
- self.is_safetensor_save = kwargs.get("is_safetensor_save", True)
529
- self._backward_compatibility_gradient_checkpointing()
530
-
531
- def _init_weights(self, module):
532
- # copied from https://github.com/kakaobrain/honeybee/blob/main/honeybee/common_layers.py#L55
533
- if (
534
- isinstance(module, nn.Conv2d) # noqa: SIM101
535
- or isinstance(module, nn.Embedding)
536
- or isinstance(module, nn.Linear)
537
- ):
538
- module.weight.data.normal_(mean=0.0, std=0.02)
539
- if hasattr(module, "bias") and module.bias is not None:
540
- module.bias.data.zero_()
541
-
542
- elif isinstance(module, nn.LayerNorm):
543
- module.bias.data.zero_()
544
- module.weight.data.fill_(1.0)
545
- elif isinstance(module, nn.Parameter):
546
- embed_std = 1 / torch.sqrt(torch.tensor(module.size(0), dtype=torch.float)).to(module.dtype)
547
- module.data.normal_(mean=0.0, std=embed_std)
548
-
549
- def forward(
550
- self,
551
- input_ids: Optional[torch.LongTensor] = None,
552
- pixel_values: Optional[List[List[torch.FloatTensor]]] = None,
553
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
554
- attention_mask: Optional[torch.FloatTensor] = None,
555
- inputs_embeds: Optional[torch.FloatTensor] = None,
556
- labels: Optional[torch.LongTensor] = None,
557
- use_cache: Optional[bool] = None,
558
- output_attentions: Optional[bool] = None,
559
- output_hidden_states: Optional[bool] = None,
560
- return_dict: Optional[bool] = None,
561
- image_sizes: Optional[List[List[List[int]]]] = None,
562
- vision_query_lengths: Optional[List[List[int]]] = None,
563
- non_vision_query_lengths: Optional[List[int]] = None,
564
- img_start_ids_list: Optional[List[List[int]]] = None,
565
- num_queries_vis_abstractors: Optional[List[List[int]]] = None,
566
- num_queries_vis_abstractors_slow: Optional[List[List[int]]] = None,
567
- first_last_frames_slows: Optional[List[bool]] = None,
568
- is_video_list: Optional[List[bool]] = None,
569
- **kwargs,
570
- ) -> Union[Tuple, HCXVisionOutput]:
571
- """Forward pass of the model.
572
-
573
- This method processes the input tokens and images, combines them into a unified
574
- representation, and generates text output based on the inputs.
575
-
576
- Args:
577
- input_ids: Input token IDs. At positions where an image is provided, the token is the image placeholder "<|dummy3|>".
578
- pixel_values: List of lists of 4D tensors for images. Each outer list corresponds to a batch and contains
579
- inner lists of image tensors.
580
- past_key_values: Pre-computed key and value states of the attention layers for faster inference.
581
- attention_mask: Mask to avoid performing attention on padding token indices.
582
- inputs_embeds: Input embeddings. If provided, input_ids will not be used.
583
- labels: Labels for computing the language modeling loss.
584
- use_cache: Whether to use past key/values for faster inference.
585
- output_attentions: Whether to return attention weights of each layer.
586
- output_hidden_states: Whether to return hidden states of each layer.
587
- return_dict: Whether to return a ModelOutput instead of a tuple.
588
- image_sizes: List of lists representing image dimensions (width, height).
589
- vision_query_lengths: List of lists containing lengths when each image is converted into visual tokens.
590
- non_vision_query_lengths: List of lengths of text tokens (excluding visual tokens) for each sample.
591
- img_start_ids_list: List of lists containing indices of img_start_id tokens for each sample.
592
- num_queries_vis_abstractors: List of lists containing number of visual tokens for each image grid.
593
- For video frames, this is the number of visual tokens for the fast part.
594
- num_queries_vis_abstractors_slow: List of lists containing number of visual tokens for
595
- the slow part when applying the slowfast algorithm to video frames.
596
- first_last_frames_slows: List of booleans indicating whether the slowfast algorithm is
597
- applied to the first or last frames of the video.
598
- is_video_list: List of booleans indicating which inputs are videos.
599
- **kwargs: Additional keyword arguments.
600
-
601
- Returns:
602
- If return_dict=True, returns an HCXVisionOutput object containing:
603
- - loss: Language modeling loss if labels are provided, otherwise None.
604
- - loss_per_sample: Per-sample loss if labels are provided, otherwise None.
605
- - logits: Prediction scores of the language modeling head.
606
- - past_key_values: Past key/values for faster inference if use_cache=True.
607
- - hidden_states: Hidden states of all layers if output_hidden_states=True.
608
- - attentions: Attention weights of all layers if output_attentions=True.
609
- If return_dict=False, returns a tuple containing the above items except loss_per_sample.
610
- """
611
- output_attentions = (
612
- output_attentions if output_attentions is not None else self.config.vision_config["output_attentions"]
613
- )
614
- output_hidden_states = (
615
- output_hidden_states
616
- if output_hidden_states is not None
617
- else self.config.vision_config["output_hidden_states"]
618
- )
619
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
620
-
621
- if inputs_embeds is None and past_key_values is None:
622
- inputs_embeds = self.extract_inputs_embeds(
623
- input_ids=input_ids,
624
- pixel_values=pixel_values,
625
- past_key_values=past_key_values,
626
- image_sizes=image_sizes,
627
- vision_query_lengths=vision_query_lengths,
628
- non_vision_query_lengths=non_vision_query_lengths,
629
- img_start_ids_list=img_start_ids_list,
630
- num_queries_vis_abstractors=num_queries_vis_abstractors,
631
- num_queries_vis_abstractors_slow=num_queries_vis_abstractors_slow,
632
- first_last_frames_slows=first_last_frames_slows,
633
- is_videos=is_video_list,
634
- )
635
-
636
- if inputs_embeds is not None:
637
- input_ids = None
638
-
639
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
640
- outputs = self.language_model.base_model(
641
- input_ids=input_ids,
642
- inputs_embeds=inputs_embeds,
643
- attention_mask=attention_mask,
644
- past_key_values=past_key_values,
645
- use_cache=use_cache,
646
- output_attentions=output_attentions,
647
- output_hidden_states=output_hidden_states,
648
- return_dict=return_dict,
649
- )
650
-
651
- hidden_states = outputs[0]
652
- hidden_states = hidden_states * self.language_config.logits_scaling
653
-
654
- loss = None
655
- loss_per_sample = None
656
- logits = self.language_model.lm_head(hidden_states)
657
- if labels is not None:
658
- # Shift so that tokens < n predict n
659
- shift_logits = logits[..., :-1, :].contiguous()
660
- shift_labels = labels[..., 1:].contiguous()
661
- # Flatten the tokens
662
- loss_fct = CrossEntropyLoss(reduction="none") # ignore IGNORE_INDEX(-100)
663
- shift_logits = shift_logits.view(-1, self.lm_head_vocab_size)
664
- shift_labels = shift_labels.view(-1)
665
- # Enable model/pipeline parallelism
666
- shift_labels = shift_labels.to(shift_logits.device)
667
- loss = loss_fct(shift_logits, shift_labels)
668
- if get_rank() == 0:
669
- loss_per_sample = loss.view(logits.shape[0], -1).sum(axis=1) / (
670
- shift_labels.view(logits.shape[0], -1) != self.config.ignore_index
671
- ).sum(axis=1)
672
- loss = loss[shift_labels != self.config.ignore_index].mean()
673
- if not return_dict:
674
- output = (logits,) + outputs[1:]
675
- return (loss,) + output if loss is not None else output
676
-
677
- return HCXVisionOutput(
678
- loss=loss,
679
- loss_per_sample=loss_per_sample,
680
- logits=logits,
681
- past_key_values=outputs.past_key_values,
682
- hidden_states=outputs.hidden_states,
683
- attentions=outputs.attentions,
684
- )
685
-
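
For reference, a self-contained toy example (made-up shapes and labels) of the masked loss and per-sample loss computed in the labels branch of forward above; -100 is the ignore index, as in the original code.

import torch
from torch.nn import CrossEntropyLoss

# Toy shapes: batch 2, sequence 5, vocab 11; -100 marks ignored (prompt/padding) positions
logits = torch.randn(2, 5, 11)
labels = torch.tensor([[7, 3, -100, 5, 2],
                       [-100, 1, 4, -100, 9]])

shift_logits = logits[..., :-1, :].contiguous().view(-1, 11)
shift_labels = labels[..., 1:].contiguous().view(-1)

loss_fct = CrossEntropyLoss(reduction="none")   # ignored positions contribute 0
token_loss = loss_fct(shift_logits, shift_labels)

valid = shift_labels != -100
loss_per_sample = token_loss.view(2, -1).sum(dim=1) / valid.view(2, -1).sum(dim=1)
loss = token_loss[valid].mean()
print(loss_per_sample.shape, loss.shape)  # torch.Size([2]) torch.Size([])
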
686
- def determine_non_vision_query_lengths(
687
- self, input_ids: torch.LongTensor, pad_id: int, img_start_id: int
688
- ) -> List[int]:
689
- """Calculate the lengths of non-vision query parts in the input.
690
-
691
- This method calculates the length of text tokens (excluding visual tokens) for each sample.
692
- When input_ids are collated, they are padded with pad_id on the right, so this method finds
693
- these values by identifying pad tokens and img_start_id tokens.
694
-
695
- Args:
696
- input_ids: Input token IDs with img_start_id markers for image positions.
697
- pad_id: Token ID used for padding.
698
- img_start_id: Token ID marking the start of image data.
699
-
700
- Returns:
701
- List of lengths of non-vision query parts for each sample in the batch.
702
- """
703
- non_vision_query_lengths = []
704
- batch_size, len_seq = input_ids.size(0), input_ids.size(1)
705
-
706
- for i in range(batch_size):
707
- temp_idx = (input_ids[i] == pad_id).nonzero()
708
- eos_idx = temp_idx[0, 0].item() if len(temp_idx) > 0 else len_seq
709
- num_imgs = (input_ids[i] == img_start_id).sum().item()
710
- non_vision_query_lengths.append(eos_idx - num_imgs)
711
-
712
- if all([pad_id in input_id for input_id in input_ids.tolist()]):
713
- non_vision_query_lengths = [
714
- non_vision_query_length + 1 for non_vision_query_length in non_vision_query_lengths
715
- ]
716
-
717
- return non_vision_query_lengths
718
-
719
- def determine_vision_query_lengths(
720
- self, image_features: List[List[torch.Tensor]], image_cnts: List[int]
721
- ) -> List[List[int]]:
722
- """Calculate the lengths of vision query parts in the input.
723
-
724
- This method calculates the lengths of visual tokens for each image in each sample based on
725
- the shapes of image feature tensors. For samples without any images, a dummy image is included
726
- but then converted to an empty list.
727
-
728
- Args:
729
- image_features: List of lists of image features tensors.
730
- image_cnts: List of counts of images for each sample in the batch.
731
-
732
- Returns:
733
- List of lists of lengths of visual tokens for each image in each sample.
734
- """
735
- vision_query_lengths = [
736
- [image_feature.size(0) for image_feature in image_feature_list] for image_feature_list in image_features
737
- ]
738
-
739
- for i, image_cnt in enumerate(image_cnts):
740
- if image_cnt == 0:
741
- assert len(vision_query_lengths[i]) == 1  # currently a single black dummy image is included
743
- vision_query_lengths[i] = []  # convert to an empty list
743
-
744
- return vision_query_lengths
745
-
746
- # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_input_embeddings
747
- def get_input_embeddings(self):
748
- return self.language_model.get_input_embeddings()
749
-
750
- # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_input_embeddings
751
- def set_input_embeddings(self, value):
752
- self.language_model.set_input_embeddings(value)
753
-
754
- # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_output_embeddings
755
- def get_output_embeddings(self):
756
- return self.language_model.get_output_embeddings()
757
-
758
- # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_output_embeddings
759
- def set_output_embeddings(self, new_embeddings):
760
- self.language_model.set_output_embeddings(new_embeddings)
761
-
762
- # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_decoder
763
- def set_decoder(self, decoder):
764
- self.language_model.set_decoder(decoder)
765
-
766
- # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_decoder
767
- def get_decoder(self):
768
- return self.language_model.get_decoder()
769
-
770
- # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.tie_weights
771
- def tie_weights(self):
772
- return self.language_model.tie_weights()
773
-
774
- # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.resize_token_embeddings
775
- def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
776
- model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
777
- self.config.text_config.vocab_size = model_embeds.num_embeddings
778
- self.vocab_size = model_embeds.num_embeddings
779
- return model_embeds
780
-
781
- def extract_inputs_embeds(
782
- self,
783
- input_ids: Optional[torch.LongTensor] = None,
784
- pixel_values: Optional[List[List[torch.FloatTensor]]] = None, # list of list of 4D tensors
785
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
786
- image_sizes: Optional[List[List[List[int]]]] = None,
787
- vision_query_lengths: Optional[List[List[int]]] = None,
788
- non_vision_query_lengths: Optional[List[int]] = None,
789
- img_start_ids_list: Optional[List[List[int]]] = None,
790
- num_queries_vis_abstractors: Optional[List[List[int]]] = None,
791
- num_queries_vis_abstractors_slow: Optional[List[List[int]]] = None,
792
- first_last_frames_slows: Optional[List[bool]] = None,
793
- is_videos: Optional[List[bool]] = None,
794
- ):
795
- """Extract input embeddings by processing text tokens and visual features.
796
-
797
- This method processes the input tokens and image features, extracts the visual features
798
- using the vision model, and combines them with the text token embeddings to create
799
- a unified input representation for the language model.
800
-
801
- Args:
802
- input_ids: Input token IDs with img_start_id markers for image positions.
803
- pixel_values: List of lists of image tensors.
804
- past_key_values: Pre-computed key and value states for faster inference.
805
- image_sizes: List of lists of image dimensions (width, height).
806
- vision_query_lengths: List of lists of lengths when each image is converted to visual tokens.
807
- non_vision_query_lengths: List of lengths of text tokens (excluding visual tokens) for each sample.
808
- img_start_ids_list: List of lists containing indices of img_start_id tokens for each sample.
809
- num_queries_vis_abstractors: List of lists containing number of visual tokens for each image grid.
810
- num_queries_vis_abstractors_slow: List of lists containing number of visual tokens for
811
- the slow part when applying the slowfast algorithm to video frames.
812
- first_last_frames_slows: List of booleans indicating whether the slowfast algorithm is
813
- applied to the first or last frames of the video.
814
- is_videos: List of booleans indicating which inputs are videos.
815
-
816
- Returns:
817
- Combined embeddings of text tokens and visual features.
818
- """
819
- inputs_embeds = None
820
- if past_key_values:
821
- pass
822
- else:
823
- # Flatten CLIP and connector for feature encoding, then convert back to List of List format
824
- len_pixel_values = [len(pixel_value) for pixel_value in pixel_values]
825
- concat_pixel_values = torch.cat(list(chain(*pixel_values)), dim=0) # list of list of 4D Tensor
826
- visual_token_idx = 0 if "siglip" in self.vision_config.model_type else 1
827
- # Check if all parameters of the model require_grad=False
828
- if self.use_no_grad is None:
829
- self.use_no_grad = all(not p.requires_grad for p in self.vision_model.vision_model.encoder.parameters())
830
- context = torch.no_grad() if self.use_no_grad else contextlib.nullcontext()
831
- with context:
832
- if self.use_no_grad:
833
- # Chunked encoding was considered here, but no memory benefit was observed,
834
- # so the whole batch is processed in a single chunk.
835
- n_chunks = 1
836
- else:
837
- n_chunks = 1
838
- total_len = concat_pixel_values.size(0)
839
- # Calculate the size of each chunk from the total data length and n_chunks
840
- chunk_size = math.ceil(total_len / n_chunks) if total_len > 0 else 1
841
- image_forward_outs_chunks = []
842
-
843
- for i in range(n_chunks):
844
- start = i * chunk_size
845
- end = (i + 1) * chunk_size
846
- # Current chunk slice (could be an empty tensor if there's no data)
847
- chunk = concat_pixel_values[start:end].to(self.vision_model.dtype)
848
- # If the current chunk size is smaller than chunk_size, pad with dummy data
849
- if chunk.size(0) < chunk_size:
850
- # print(f"chunk.size(0): {chunk.size(0)}, chunk_size: {chunk_size}")
851
- pad_size = chunk_size - chunk.size(0)
852
- # Create dummy tensor based on concat_pixel_values shape
853
- dummy_shape = (pad_size,) + tuple(concat_pixel_values.shape[1:])
854
- dummy = torch.zeros(
855
- dummy_shape,
856
- dtype=concat_pixel_values.dtype,
857
- device=concat_pixel_values.device,
858
- )
859
- chunk = torch.cat([chunk, dummy], dim=0)
860
-
861
- # Pass the chunk through the vision model (processed according to use_nth_layer)
862
- if self.use_nth_layer == -1:
863
- # Replace post_layernorm of the last layer with Identity
864
- self.vision_model.vision_model.post_layernorm = nn.Identity()
865
- outs = self.vision_model(chunk)
866
- outs = outs.last_hidden_state[:, visual_token_idx:]
867
- else:
868
- outs = self.vision_model(chunk, output_hidden_states=True)
869
- outs = outs.hidden_states[self.use_nth_layer][:, visual_token_idx:]
870
- image_forward_outs_chunks.append(outs)
871
-
872
- # Concatenate results from all chunks
873
- image_forward_outs = torch.cat(image_forward_outs_chunks, dim=0).to(image_forward_outs_chunks[0].dtype)
874
-
875
- if num_queries_vis_abstractors is None:
876
- assert num_queries_vis_abstractors_slow is None
877
- image_sizes = list(chain(*image_sizes))
878
- if is_videos is not None:
879
- is_videos = list(chain(*is_videos))
880
- group_ids = None
881
- image_forward_outs = image_forward_outs.to(dtype=self.mm_projector.dtype)
882
- image_forward_outs = self.mm_projector(image_forward_outs)
883
- else:
884
- # adaptive anyres is only implemented in HCXVisionCAbstractor
885
- assert isinstance(self.mm_projector, HCXVisionCAbstractor)
886
-
887
- (
888
- num_queries_vis_abstractors,
889
- num_grids,
890
- image_sizes,
891
- is_videos,
892
- group_ids,
893
- ) = self.compute_adaptive_params(
894
- pixel_values,
895
- num_queries_vis_abstractors,
896
- num_queries_vis_abstractors_slow,
897
- image_sizes,
898
- is_videos,
899
- first_last_frames_slows,
900
- )
901
-
902
- image_forward_outs = image_forward_outs.to(dtype=self.mm_projector.dtype)
903
- image_forward_outs = self.mm_projector(
904
- image_forward_outs,
905
- num_queries_vis_abstractors=num_queries_vis_abstractors,
906
- num_grids=num_grids,
907
- )
908
-
909
- if self.anyres:
910
- split_sizes = [pixel_value.shape[0] for pixel_value in chain(*pixel_values)]
911
-
912
- if num_queries_vis_abstractors is None:
913
- image_features = anyres_postprocessing(
914
- image_forward_outs=image_forward_outs,
915
- split_sizes=split_sizes,
916
- image_sizes=image_sizes,
917
- num_queries_vis_abstractor=self.num_queries_vis_abstractor,
918
- unpad=self.unpad,
919
- is_videos=is_videos,
920
- patch_size=self.vision_model.config.patch_size,
921
- grid_size=self.vision_model.config.image_size,
922
- image_newline=self.image_newline,
923
- possible_resolutions=self.possible_resolutions,
924
- )
925
- else:
926
- image_features = adaptive_anyres_postprocessing(
927
- image_forward_outs=image_forward_outs,
928
- image_sizes=image_sizes,
929
- num_queries_vis_abstractors=num_queries_vis_abstractors,
930
- unpad=self.unpad,
931
- is_videos=is_videos,
932
- grid_size=self.vision_model.config.image_size,
933
- image_newline=self.image_newline,
934
- possible_resolutions=self.possible_resolutions,
935
- group_ids=group_ids,
936
- )
937
- else:
938
- if num_queries_vis_abstractors is None:
939
- image_features = [image_forward_out for image_forward_out in image_forward_outs]
940
- else:
941
- image_features = [image_forward_out.unsqueeze(0) for image_forward_out in image_forward_outs]
942
-
943
- # print(f"BEFORE GROUPING: len(image_features): {len(image_features)}")
944
- image_features = [
945
- image_features[sum(len_pixel_values[:i]) : sum(len_pixel_values[: i + 1])]
946
- for i in range(len(len_pixel_values))
947
- ]
948
-
949
- batch_size = input_ids.size(0)
950
- image_feature_dim = image_features[0][0].size(1)
951
- image_feature_dtype = image_features[0][0].dtype
952
-
953
- if img_start_ids_list is None:
954
- image_cnts = (input_ids == self.config.img_start_id).sum(dim=1).tolist()
955
- else:
956
- image_cnts = [len(img_start_ids) for img_start_ids in img_start_ids_list]
957
-
958
- if non_vision_query_lengths is None:
959
- non_vision_query_lengths = self.determine_non_vision_query_lengths(
960
- input_ids, self.tokenizer.pad_token_id, self.config.img_start_id
961
- )
962
-
963
- if vision_query_lengths is None:
964
- vision_query_lengths = self.determine_vision_query_lengths(image_features, image_cnts)
965
-
966
- # Slicing is faster than concatenation
967
- len_inputs_embeds = max(
968
- [
969
- sum(vision_query_length) + non_vision_query_length
970
- for non_vision_query_length, vision_query_length in zip(
971
- non_vision_query_lengths, vision_query_lengths
972
- )
973
- ]
974
- )
975
- len_inputs_embeds = min(self.decoder_max_length, len_inputs_embeds)
976
-
977
- inputs_embeds = torch.zeros(
978
- [batch_size, len_inputs_embeds, image_feature_dim],
979
- dtype=image_feature_dtype,
980
- device=self.device,
981
- requires_grad=True,
982
- ).clone()
983
- # temp_embeds : torch.bfloat16 : [batchsize, 174, 3072]
984
- temp_embeds = self.get_input_embeddings()(input_ids)
985
-
986
- # The complete format is <PROMPT><USER_PREFIX><VISION_QUERIES>Sentence
987
- for batch_idx, sample in enumerate(input_ids):
988
- # Concatenate with visual tokens and then slice
989
- non_vision_query_length = non_vision_query_lengths[batch_idx]
990
- # Safely concatenate with visual tokens and then slice
991
- sample = sample[: non_vision_query_length + image_cnts[batch_idx]]
992
-
993
- if image_cnts[batch_idx] == 0: # Text instruction data doesn't insert image features
994
- temp_idx = 0
995
- # Reference: https://github.com/haotian-liu/LLaVA/commit/44e0562f9497fb79f042427307472a87d266d90a#diff-4477387d506ccb1897a13972cba26c9da3fad4d3e1c32ec4b8bd8ff7acd3f292
996
- # https://github.com/intel/intel-extension-for-transformers/issues/1201#issuecomment-1915875119
997
- inputs_embeds[batch_idx, :non_vision_query_length] = temp_embeds[batch_idx][
998
- :non_vision_query_length
999
- ]
1000
- inputs_embeds[batch_idx, temp_idx:temp_idx] = image_features[batch_idx][0][
1001
- 0:0
1002
- ] # First image of batch_idx sample (dummy image)
1003
- else:
1004
- if img_start_ids_list is None:
1005
- img_start_ids = (sample == self.config.img_start_id).nonzero()
1006
- else:
1007
- img_start_ids = img_start_ids_list[batch_idx]
1008
- assert len(img_start_ids) == image_cnts[batch_idx] == len(image_features[batch_idx])
1009
- # Initialize starting points for input embeddings and temporary embeddings
1010
- input_start, temp_start = 0, 0
1011
-
1012
- # Iterate through each image starting point in the batch
1013
- for multi_img_idx, img_start_idx in enumerate(img_start_ids):
1014
- # Calculate token length up to the current image starting point
1015
- token_len = img_start_idx - temp_start
1016
-
1017
- # Copy tokens to inputs_embeds
1018
- inputs_embeds[batch_idx, input_start : input_start + token_len] = temp_embeds[
1019
- batch_idx, temp_start : temp_start + token_len
1020
- ]
1021
-
1022
- inputs_embeds[
1023
- batch_idx,
1024
- input_start
1025
- + token_len : input_start
1026
- + token_len
1027
- + vision_query_lengths[batch_idx][multi_img_idx],
1028
- ] = image_features[batch_idx][multi_img_idx]
1029
-
1030
- # Update starting points for next token processing
1031
- input_start += token_len + vision_query_lengths[batch_idx][multi_img_idx]
1032
- temp_start += token_len + 1 # Increase by 1 to skip the image start token
1033
-
1034
- # Process tokens after the last image end token
1035
- token_len = min(sample[temp_start:].size(0), inputs_embeds.size(1) - input_start)
1036
- inputs_embeds[batch_idx, input_start : input_start + token_len] = temp_embeds[
1037
- batch_idx, temp_start : temp_start + token_len
1038
- ]
1039
- return inputs_embeds
1040
-
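
A toy illustration (hypothetical ids and sizes, not the original code) of the splicing performed by extract_inputs_embeds above: visual tokens replace the image-start marker inside the text embedding sequence.

import torch

hidden = 4
input_ids = torch.tensor([[5, 6, 99, 7, 8]])      # 99 plays the role of img_start_id
token_embeds = torch.randn(1, 5, hidden)          # stands in for get_input_embeddings()(input_ids)
image_feature = torch.randn(3, hidden)            # 3 visual tokens for this image

non_vision_len = 4                                # 5 tokens minus the single image-start marker
out_len = non_vision_len + image_feature.size(0)  # 7
inputs_embeds = torch.zeros(1, out_len, hidden)

img_start = (input_ids[0] == 99).nonzero()[0].item()               # position 2
inputs_embeds[0, :img_start] = token_embeds[0, :img_start]          # text before the image
inputs_embeds[0, img_start:img_start + 3] = image_feature           # visual tokens replace the marker
inputs_embeds[0, img_start + 3:] = token_embeds[0, img_start + 1:]  # text after the image
print(inputs_embeds.shape)  # torch.Size([1, 7, 4])
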
1041
- @torch.no_grad()
1042
- def generate(
1043
- self,
1044
- input_ids: Optional[torch.LongTensor] = None,
1045
- pixel_values: Optional[List[List[torch.FloatTensor]]] = None,
1046
- image_sizes: Optional[List[List[List[int]]]] = None,
1047
- vision_query_lengths: Optional[List[List[int]]] = None,
1048
- non_vision_query_lengths: Optional[List[int]] = None,
1049
- num_queries_vis_abstractors: Optional[List[List[int]]] = None,
1050
- num_queries_vis_abstractors_slow: Optional[List[List[int]]] = None,
1051
- first_last_frames_slows: Optional[List[bool]] = None,
1052
- is_videos: Optional[List[bool]] = None,
1053
- img_start_ids_list: Optional[List[List[int]]] = None,
1054
- pad_token_id: Optional[int] = None,
1055
- eos_token_id: Optional[int] = None,
1056
- bad_words_ids: Optional[List[List[int]]] = None,
1057
- max_length: int = 196,
1058
- min_length: int = 2,
1059
- do_sample: bool = True,
1060
- num_beams: int = 1,
1061
- top_p: float = 0.6,
1062
- top_k: int = 0,
1063
- temperature: float = 0.5,
1064
- repetition_penalty: float = 1.0,
1065
- length_penalty: int = 1,
1066
- use_cache: bool = True,
1067
- **kwargs,
1068
- ) -> torch.LongTensor:
1069
- """Generate text based on input tokens and images.
1070
-
1071
- This method generates text based on the provided input tokens and images using
1072
- beam search and/or sampling strategies.
1073
-
1074
- Args:
1075
- input_ids: Input token IDs with img_start_id markers for image positions.
1076
- pixel_values: List of lists of image tensors.
1077
- image_sizes: List of lists of image dimensions (width, height).
1078
- vision_query_lengths: List of lists of lengths when each image is converted to visual tokens.
1079
- non_vision_query_lengths: List of lengths of text tokens (excluding visual tokens) for each sample.
1080
- num_queries_vis_abstractors: List of lists containing number of visual tokens for each image grid.
1081
- num_queries_vis_abstractors_slow: List of lists containing number of visual tokens for the slow part when
1082
- applying the slowfast algorithm to video frames.
1083
- first_last_frames_slows: List of booleans indicating whether the slowfast algorithm is applied to the first
1084
- or last frames of the video.
1085
- is_videos: List of booleans indicating which inputs are videos.
1086
- img_start_ids_list: List of lists containing indices of img_start_id tokens for each sample.
1087
- pad_token_id: Token ID used for padding.
1088
- eos_token_id: Token ID used to signal the end of a sequence.
1089
- bad_words_ids: List of token ID sequences that should not be generated.
1090
- max_length: Maximum number of new tokens to generate (passed to the underlying generate call as max_new_tokens).
1091
- min_length: Minimum length of the sequence to be generated (input length + min_new_tokens).
1092
- do_sample: Whether to use sampling for generation (otherwise uses greedy decoding).
1093
- num_beams: Number of beams for beam search. 1 means no beam search.
1094
- top_p: Nucleus sampling parameter. Tokens with cumulative probability > top_p are kept.
1095
- top_k: Number of highest probability tokens to keep for top-k-filtering.
1096
- temperature: Value used to modulate the next token probabilities.
1097
- repetition_penalty: Penalty applied to tokens that have already appeared in the sequence.
1098
- length_penalty: Exponential penalty applied to sequence length.
1099
- use_cache: Whether to use past key/values for faster inference.
1100
- **kwargs: Additional keyword arguments.
1101
-
1102
- Returns:
1103
- Generated token IDs.
1104
- """
1105
- # inputs_embeds: torch.bfloat16 : [batchsize, variable (includes visual tokens, text tokens, and the system prompt)]
1106
- if pad_token_id is None:
1107
- pad_token_id = self.tokenizer.pad_token_id
1108
- if eos_token_id is None:
1109
- eos_token_id = self.tokenizer.encode("<|endofturn|>")[0]
1110
- if bad_words_ids is None:
1111
- bad_words_ids = [
1112
- [
1113
- self.config.language_config["bos_token_id"],
1114
- ],
1115
- [
1116
- self.config.language_config["eos_token_id"],
1117
- ],
1118
- ]
1119
-
1120
- if pixel_values is None:
1121
- return self.language_model.generate(
1122
- input_ids, pad_token_id=pad_token_id, eos_token_id=eos_token_id, bad_words_ids=bad_words_ids, **kwargs
1123
- )
1124
- inputs_embeds = self.extract_inputs_embeds(
1125
- input_ids=input_ids,
1126
- pixel_values=self.to_vision_model_device(pixel_values),
1127
- image_sizes=image_sizes,
1128
- vision_query_lengths=vision_query_lengths,
1129
- non_vision_query_lengths=non_vision_query_lengths,
1130
- img_start_ids_list=img_start_ids_list,
1131
- num_queries_vis_abstractors=num_queries_vis_abstractors,
1132
- num_queries_vis_abstractors_slow=num_queries_vis_abstractors_slow,
1133
- first_last_frames_slows=first_last_frames_slows,
1134
- is_videos=is_videos,
1135
- )
1136
- inputs_embeds = (
1137
- inputs_embeds.to(self.base_model.device) if isinstance(inputs_embeds, torch.Tensor) else inputs_embeds
1138
- )
1139
-
1140
- # pred : torch.int64 : [batchsize, generated token_length]
1141
- pred = self.language_model.generate(
1142
- inputs_embeds=inputs_embeds,
1143
- pad_token_id=pad_token_id,
1144
- eos_token_id=eos_token_id,
1145
- bad_words_ids=bad_words_ids,
1146
- max_new_tokens=max_length,
1147
- min_length=min_length,
1148
- num_beams=num_beams,
1149
- do_sample=(False if temperature == 0.0 else do_sample), # set do_sample=False if invalid temperature
1150
- top_k=top_k,
1151
- top_p=top_p,
1152
- temperature=temperature,
1153
- repetition_penalty=repetition_penalty,
1154
- length_penalty=length_penalty,
1155
- early_stopping=(False if num_beams <= 1 else True), # set early_stopping=False when not beam_search
1156
- use_cache=use_cache,
1157
- )
1158
-
1159
- return pred
1160
-
1161
- def to_vision_model_device(self, input_tensor: Union[torch.Tensor, List]) -> Union[torch.Tensor, List]:
1162
- """Move input tensors to the vision model's device.
1163
- This method recursively moves input tensors or lists of tensors to the vision model's device.
1164
-
1165
- Args:
1166
- input_tensor: Input tensor or list of tensors to be moved to the vision model's device.
1167
-
1168
- Returns:
1169
- The input tensor or list of tensors moved to the vision model's device.
1170
-
1171
- Raises:
1172
- TypeError: If the input is neither a tensor nor a list.
1173
- """
1174
- if isinstance(input_tensor, list):
1175
- return [self.to_vision_model_device(item) for item in input_tensor]
1176
- elif isinstance(input_tensor, torch.Tensor):
1177
- return input_tensor.to(self.vision_model.device)
1178
- else:
1179
- raise TypeError("Unsupported data type. Only tensors and lists are allowed.")
1180
-
1181
- def prepare_inputs_for_generation(
1182
- self,
1183
- input_ids: torch.LongTensor,
1184
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1185
- attention_mask: Optional[torch.FloatTensor] = None,
1186
- inputs_embeds: Optional[torch.FloatTensor] = None,
1187
- **kwargs,
1188
- ) -> Dict[str, Any]:
1189
- """Prepare inputs for the generation algorithm.
1190
-
1191
- This method prepares the input for each generation step based on the model's needs.
1192
-
1193
- Args:
1194
- input_ids: Input token IDs.
1195
- past_key_values: Pre-computed key and value states for faster inference.
1196
- attention_mask: Mask to avoid performing attention on padding token indices.
1197
- inputs_embeds: Input embeddings. If provided, input_ids will not be used.
1198
- **kwargs: Additional keyword arguments.
1199
-
1200
- Returns:
1201
- Dictionary containing the prepared inputs for the model.
1202
- """
1203
- input_ids = kwargs.get("decoder_input_ids", input_ids)
1204
-
1205
- if past_key_values:
1206
- input_ids = input_ids[:, -1:]
1207
-
1208
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1209
- if inputs_embeds is not None and past_key_values is None:
1210
- model_inputs = {"inputs_embeds": inputs_embeds}
1211
- else:
1212
- model_inputs = {"input_ids": input_ids}
1213
-
1214
- model_inputs.update(
1215
- {
1216
- "past_key_values": past_key_values,
1217
- "use_cache": kwargs.get("use_cache"),
1218
- "attention_mask": attention_mask,
1219
- "pixel_values": kwargs.get("pixel_values", None),
1220
- }
1221
- )
1222
- return model_inputs
1223
-
1224
- @classmethod
1225
- def from_config(cls, config, vision_model_name_or_path):
1226
- return cls(config, vision_model_name_or_path)
1227
-
1228
- @classmethod
1229
- def from_pretrained(
1230
- cls,
1231
- pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
1232
- *model_args,
1233
- **kwargs,
1234
- ) -> "HCXVisionForCausalLM":
1235
- assert pretrained_model_name_or_path is not None
1236
-
1237
- save_only_vision = kwargs.pop("save_only_vision") if "save_only_vision" in kwargs else False
1238
- save_only_qformer = kwargs.pop("save_only_qformer") if "save_only_qformer" in kwargs else False
1239
- save_shard_size = kwargs.pop("save_shard_size") if "save_shard_size" in kwargs else "5GB"
1240
-
1241
- if pretrained_model_name_or_path is not None: # when evaluating or loading an instruction-tuned model
1242
- model: HCXVisionForCausalLM = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
1243
- model.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
1244
-
1245
- img_start_id = model.tokenizer.encode(IMG_LOC, add_special_tokens=False)
1246
- assert (
1247
- len(img_start_id) == 1
1248
- ), f'"<|dummy3|>" was not encoded into a single special token. Encoding result: {img_start_id}'
1249
- model.config.img_start_id = img_start_id[0]
1250
-
1251
- model.save_only_vision = save_only_vision
1252
- model.save_only_qformer = save_only_qformer
1253
- model.save_shard_size = save_shard_size
1254
-
1255
- return model
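A hypothetical loading sketch for the classmethod above; the checkpoint path is a placeholder, and the three save_* keyword arguments are optional (they fall back to False / "5GB" when omitted).

model = HCXVisionForCausalLM.from_pretrained(
    "/path/to/checkpoint",      # placeholder path, not a real repository id
    torch_dtype="auto",
    save_only_vision=False,     # optional: keep the full weights when saving
    save_only_qformer=False,
    save_shard_size="5GB",
)
print(model.config.img_start_id)  # id of the "<|dummy3|>" image placeholder token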
1256
-
1257
- def get_language_model(self):
1258
- return self.language_model.base_model
1259
-
1260
- def get_vision_model(self):
1261
- return self.vision_model
1262
-
1263
- def save_pretrained(
1264
- self,
1265
- save_directory: Union[str, os.PathLike],
1266
- *args,
1267
- **kwargs,
1268
- ):
1269
- state_dict = kwargs["state_dict"] if "state_dict" in kwargs else self.state_dict()
1270
- partial_state_dict = self.get_pretrained_state_dict(
1271
- state_dict,
1272
- save_directory,
1273
- )
1274
- kwargs["state_dict"] = partial_state_dict
1275
- kwargs["safe_serialization"] = self.is_safetensor_save
1276
- kwargs.setdefault("max_shard_size", self.save_shard_size)
1277
- super().save_pretrained(save_directory, *args, **kwargs)
1278
-
1279
- def get_pretrained_state_dict(self, state_dict, save_dir):
1280
- vision_key = "vision_model."
1281
- llm_keys = ["language_model."]
1282
- head_key = "lm_head."
1283
-
1284
- for key in list(state_dict.keys()):
1285
- if self.save_only_vision:
1286
- for llm_key in llm_keys:
1287
- if llm_key in key:
1288
- state_dict.pop(key)
1289
- if key.startswith(head_key):
1290
- state_dict.pop(key)
1291
-
1292
- elif self.save_only_qformer:
1293
- if f"{vision_key}" in key:
1294
- state_dict.pop(key)
1295
-
1296
- return state_dict
1297
-
1298
- def compute_adaptive_params(
1299
- self,
1300
- pixel_values: Optional[List[List[torch.FloatTensor]]] = None,
1301
- num_queries_vis_abstractors: Optional[List[List[int]]] = None,
1302
- num_queries_vis_abstractors_slow: Optional[List[List[int]]] = None,
1303
- image_sizes: Optional[List[List[List[int]]]] = None,
1304
- is_videos: Optional[List[List[bool]]] = None,
1305
- first_last_frames_slows: Optional[List[List[bool]]] = None,
1306
- ) -> Tuple[List[int], List[int], List[List[int]], List[bool], List[List[int]]]:
1307
- """Compute adaptive parameters for processing different image and video inputs.
1308
-
1309
- This method calculates parameters needed for adaptive processing, especially when handling
1310
- variable resolutions or applying the slowfast algorithm to video frames. It flattens
1311
- batch-level inputs (lists of lists) into single lists representing all images/frames
1312
- in the batch. Based on slowfast configuration, it may split video frames into 'slow'
1313
- and 'fast' components, adjusting query counts and grid indices accordingly.
1314
-
1315
- Args:
1316
- pixel_values: List of lists of image tensors (per sample). Used to determine the initial number of grids per
1317
- image/frame.
1318
- num_queries_vis_abstractors: List of lists (per sample) containing the base number of visual tokens
1319
- generated by the visual abstractor for each image grid
1320
- (e.g., 81 for a full grid, 9 for a subsampled/fast grid).
1321
- num_queries_vis_abstractors_slow: List of lists (per sample) containing the number of visual tokens for the
1322
- 'slow' path when applying slowfast. Non-zero values here trigger the slowfast processing logic.
1323
- image_sizes: List of lists (per sample) of original image dimensions ([width, height]).
1324
- is_videos: List of lists (per sample) of booleans indicating if each input item is part of a video sequence.
1325
- first_last_frames_slows: List (per sample) of booleans. If True, slowfast logic
1326
- (if active based on `num_queries_vis_abstractors_slow`) is applied only to the first or last frame(s)
1327
- within each video sequence.
1328
-
1329
- Returns:
1330
- Tuple containing:
1331
- - num_queries_vis_abstractors: Flattened list of final query counts per processed grid.
1332
- Values might be adjusted based on slow/fast splitting
1333
- (e.g., using values from `num_queries_vis_abstractors_slow` for slow frames).
1334
- Example: [81, 81, 81, 9, 81, 9, ...] (Image, Image, Vid_Slow, Vid_Fast, Vid_Slow, Vid_Fast...)
1335
- - num_grids: Flattened list representing cumulative grid counts, acting as end indices for slicing the
1336
- flattened `image_forward_outs`. Adjusted for slow/fast splits.
1337
- Example: [0, 1, 9, 10, 18, 19, 27, ...] (Indices after Grid0_Slow(1),
1338
- Grid1_Fast(8), Grid2_Slow(1), Grid3_Fast(8)...).
1339
- - image_sizes: Flattened list of image dimensions ([width, height]), potentially duplicated if slow/fast
1340
- splitting occurred.
1341
- - is_videos: Flattened list of booleans indicating video status, potentially duplicated for
1342
- slow/fast splits. Example: [False, False, True, True, True, True, ...]
1343
- (Image1, Image2, Vid_grid1_slow, Vid_grid1_fast, Vid_grid2_slow, Vid_grid2_fast...)
1344
- - group_ids: List of lists, grouping indices that correspond to the same original image or frame.
1345
- If a frame is split into slow/fast, its group will contain multiple indices.
1346
- Example: [[0], [1], [2, 3], [4, 5], ...]
1347
- (Group for Image1, Group for Image2, Group for Vid1_Slow+Fast, Group for Vid2_Slow+Fast...).
1348
-
1349
- Raises:
1350
- AssertionError: If input validation fails (e.g., negative query counts).
1351
- Exception: If an unexpected case is encountered during slowfast processing.
1352
- """
1353
-
1354
- # Check if all elements are integers greater than or equal to 0
1355
- assert all(
1356
- all(isinstance(value, int) and value >= 0 for value in sublist) for sublist in num_queries_vis_abstractors
1357
- ), "All values in num_queries_vis_abstractors must be integers >= 0."
1358
-
1359
- assert all(
1360
- all(isinstance(value, int) and value >= 0 for value in sublist)
1361
- for sublist in num_queries_vis_abstractors_slow
1362
- ), "All values in num_queries_vis_abstractors_slow must be integers >= 0."
1363
-
1364
- assert is_videos is not None
1365
-
1366
- # Is this the first or last frame? (used when applying slowfast to video processing)
1367
- is_first_images = []
1368
- is_last_images = []
1369
- for is_video in is_videos:
1370
- for idx, is_video_item in enumerate(is_video):
1371
- if idx == 0:
1372
- is_first_images.append(True)
1373
- else:
1374
- is_first_images.append(False)
1375
- if idx == len(is_video) - 1:
1376
- is_last_images.append(True)
1377
- else:
1378
- is_last_images.append(False)
1379
-
1380
- num_queries_vis_abstractors = list(chain(*num_queries_vis_abstractors))
1381
- num_queries_vis_abstractors_slow = list(chain(*num_queries_vis_abstractors_slow))
1382
- image_sizes = list(chain(*image_sizes))
1383
- is_videos = list(chain(*is_videos))
1384
- first_last_frames_slows = list(chain(*first_last_frames_slows))
1385
-
1386
- # Use slowfast mode if there's at least one visual token count greater than 0 in num_queries_vis_abstractors_slow
1387
- use_slowfast = any([num_query > 0 for num_query in num_queries_vis_abstractors_slow])
1388
- num_grids = [pixel_value.shape[0] for pixel_value in chain(*pixel_values)]
1389
- num_grids = [0] + num_grids
1390
- group_ids = []
1391
-
1392
- if use_slowfast:
1393
- new_num_grids = [num_grids[0]]
1394
- new_num_queries = []
1395
- new_image_sizes = []
1396
- new_is_videos = []
1397
-
1398
- # When using slowfast, split the grids more finely:
1399
- # the 0th local grid is the slow frame and the remaining local grids are the fast frames
1400
- for (
1401
- num_query,
1402
- num_query_slow,
1403
- num_grid,
1404
- image_size,
1405
- is_video,
1406
- first_last_frames_slow,
1407
- is_first_image,
1408
- is_last_image,
1409
- ) in zip(
1410
- num_queries_vis_abstractors,
1411
- num_queries_vis_abstractors_slow,
1412
- num_grids[1:],
1413
- image_sizes,
1414
- is_videos,
1415
- first_last_frames_slows,
1416
- is_first_images,
1417
- is_last_images,
1418
- ):
1419
-
1420
- if not first_last_frames_slow and num_query_slow > 0: # Process all frames in slowfast mode
1421
- assert is_video # slowfast mode is only applied to videos
1422
-
1423
- this_group_ids = [group_ids[-1][-1] + 1 if group_ids else 0]
1424
-
1425
- # slow frame (first grid)
1426
- new_num_grids.append(new_num_grids[-1] + 1)
1427
- new_num_queries.append(num_query_slow)
1428
- new_image_sizes.append(image_size)
1429
- new_is_videos.append(is_video)
1430
-
1431
- if num_grid >= 2:
1432
- # fast frames
1433
- new_num_grids.append(new_num_grids[-1] + num_grid - 1)
1434
- new_num_queries.append(num_query)
1435
- new_image_sizes.append(image_size)
1436
- new_is_videos.append(is_video)
1437
- this_group_ids.append(this_group_ids[-1] + 1)
1438
-
1439
- group_ids.append(this_group_ids)
1440
- elif (
1441
- first_last_frames_slow and num_query_slow > 0 and (is_first_image or is_last_image)
1442
- ): # Process only the first/last frame in slowfast mode
1443
- # Case for special treatment of first/last frames in slow mode
1444
- assert is_video # slowfast mode is only applied to videos
1445
-
1446
- this_group_ids = [group_ids[-1][-1] + 1 if group_ids else 0]
1447
-
1448
- if num_grid == 1:
1449
- # Only one grid, so process it entirely as a slow frame
1450
- new_num_grids.append(new_num_grids[-1] + 1)
1451
- new_num_queries.append(num_query_slow)
1452
- new_image_sizes.append(image_size)
1453
- new_is_videos.append(is_video)
1454
-
1455
- if num_grid >= 2:
1456
- # Special treatment for first or last grid depending on is_first_image or is_last_image
1457
-
1458
- if is_first_image: # also covers the case where the frame is both first and last
1459
- # slow frame (first grid)
1460
- new_num_grids.append(new_num_grids[-1] + 1)
1461
- new_num_queries.append(num_query_slow)
1462
- new_image_sizes.append(image_size)
1463
- new_is_videos.append(is_video)
1464
- # fast frames
1465
- new_num_grids.append(new_num_grids[-1] + num_grid - 1)
1466
- new_num_queries.append(num_query)
1467
- new_image_sizes.append(image_size)
1468
- new_is_videos.append(is_video)
1469
- this_group_ids.append(this_group_ids[-1] + 1)
1470
- elif is_last_image:
1471
- # fast frames
1472
- new_num_grids.append(new_num_grids[-1] + num_grid - 1)
1473
- new_num_queries.append(num_query)
1474
- new_image_sizes.append(image_size)
1475
- new_is_videos.append(is_video)
1476
- # slow frame (last grid)
1477
- new_num_grids.append(new_num_grids[-1] + 1)
1478
- new_num_queries.append(num_query_slow)
1479
- new_image_sizes.append(image_size)
1480
- new_is_videos.append(is_video)
1481
- this_group_ids.append(this_group_ids[-1] + 1)
1482
- else:
1483
- raise Exception("This case should not be reached.")
1484
- group_ids.append(this_group_ids)
1485
- else:
1486
- # Not in slowfast mode, so every grid is reduced to num_query (fast) tokens
1487
- new_num_grids.append(new_num_grids[-1] + num_grid)
1488
- new_num_queries.append(num_query)
1489
- new_image_sizes.append(image_size)
1490
- new_is_videos.append(is_video)
1491
-
1492
- start_group_id = group_ids[-1][-1] + 1 if group_ids else 0
1493
- group_ids.append([start_group_id])
1494
-
1495
- num_grids = new_num_grids
1496
- num_queries_vis_abstractors = new_num_queries
1497
- image_sizes = new_image_sizes
1498
- is_videos = new_is_videos
1499
- else:
1500
- num_grids = [sum(num_grids[:i]) for i in range(1, len(num_grids) + 1)]
1501
- group_ids = [[group_id] for group_id in range(len(is_videos))]
1502
-
1503
- return num_queries_vis_abstractors, num_grids, image_sizes, is_videos, group_ids
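The non-slowfast branch above is just a cumulative sum over per-image grid counts. A standalone sketch with made-up grid counts:

# Three inputs with 1, 8 and 1 grids respectively (values are illustrative).
per_image_grids = [1, 8, 1]
num_grids = [0] + per_image_grids
num_grids = [sum(num_grids[:i]) for i in range(1, len(num_grids) + 1)]
print(num_grids)  # [0, 1, 9, 10]

# Features of image k are image_forward_outs[num_grids[k] : num_grids[k + 1]],
# and without slowfast every image gets its own group id.
group_ids = [[group_id] for group_id in range(len(per_image_grids))]
print(group_ids)  # [[0], [1], [2]]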
1504
-
1505
-
1506
- def load_state_dict_into_model(model_to_load, state_dict, strict=True, start_prefix=""):
1507
- # from https://github.com/huggingface/transformers/blob/0a55d9f7376f72ad3ff296d4249840021b03bcc4/src/transformers/modeling_utils.py#L517
1508
- # Convert old format to new format if needed from a PyTorch state_dict
1509
- old_keys = []
1510
- new_keys = []
1511
- for key in state_dict.keys():
1512
- new_key = None
1513
- if "gamma" in key:
1514
- new_key = key.replace("gamma", "weight")
1515
- if "beta" in key:
1516
- new_key = key.replace("beta", "bias")
1517
- if new_key:
1518
- old_keys.append(key)
1519
- new_keys.append(new_key)
1520
- for old_key, new_key in zip(old_keys, new_keys):
1521
- state_dict[new_key] = state_dict.pop(old_key)
1522
-
1523
- # copy state_dict so _load_from_state_dict can modify it
1524
- metadata = getattr(state_dict, "_metadata", None)
1525
- state_dict = state_dict.copy()
1526
- if metadata is not None:
1527
- state_dict._metadata = metadata
1528
-
1529
- error_msgs = []
1530
-
1531
- # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
1532
- # so we need to apply the function recursively.
1533
- def load(module: nn.Module, state_dict, prefix=""):
1534
- local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
1535
- args = (state_dict, prefix, local_metadata, strict, [], [], error_msgs)
1536
- # Parameters of module and children will start with prefix. We can exit early if there are none in this
1537
- # state_dict
1538
- if len([key for key in state_dict if key.startswith(prefix)]) > 0:
1539
- if is_deepspeed_zero3_enabled():
1540
- import deepspeed
1541
-
1542
- # In sharded models, each shard has only part of the full state_dict, so only gather
1543
- # parameters that are in the current state_dict.
1544
- named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
1545
- params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
1546
- if len(params_to_gather) > 0:
1547
- # because zero3 puts placeholders in model params, this context
1548
- # manager gathers (unpartitions) the params of the current layer, then loads from
1549
- # the state dict and then re-partitions them again
1550
- with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
1551
- if torch.distributed.get_rank() == 0:
1552
- module._load_from_state_dict(*args)
1553
- else:
1554
- module._load_from_state_dict(*args)
1555
-
1556
- for name, child in module._modules.items():
1557
- if child is not None:
1558
- load(child, state_dict, prefix + name + ".")
1559
-
1560
- load(model_to_load, state_dict, prefix=start_prefix)
1561
- # Delete `state_dict` so it can be collected by the GC earlier. Note that `state_dict` is a copy of the argument, so
1562
- # it's safe to delete it.
1563
- del state_dict
1564
-
1565
- return error_msgs
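A minimal, self-contained sketch of the legacy-key handling above ("gamma"/"beta" are renamed to "weight"/"bias" before loading); it assumes this function is importable alongside the rest of the module and that DeepSpeed ZeRO-3 is not active.

import torch
import torch.nn as nn

layer = nn.LayerNorm(4)
legacy_state = {"gamma": torch.ones(4), "beta": torch.zeros(4)}  # old-style key names

errors = load_state_dict_into_model(layer, legacy_state, strict=False)
print(errors)      # expected: [] (keys were renamed and loaded without error)
print(layer.bias)  # expected: zeros, copied from the "beta" entry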
1566
-
1567
-
1568
- class HCXVisionCAbstractor(nn.Module):
1569
- """
1570
- This module is based on C-Abstractor, which is licensed under Apache-2.0.
1571
- You can check the original code at https://github.com/khanrc/honeybee/blob/main/honeybee/projectors/projectors.py
1572
- and we made necessary modifications.
1573
- """
1574
-
1575
- def __init__(
1576
- self,
1577
- num_queries: int,
1578
- num_input_tokens: int,
1579
- encoder_hidden_size: int,
1580
- hidden_size: int,
1581
- output_hidden_size: int,
1582
- pos_emb: bool = True,
1583
- prenorm: bool = False,
1584
- ):
1585
- super().__init__()
1586
- self.num_input_tokens = num_input_tokens
1587
- self.output_hidden_size = output_hidden_size
1588
-
1589
- # Positional embedding
1590
- if pos_emb:
1591
- self.pos_emb = torch.nn.Parameter(torch.zeros(1, num_input_tokens, encoder_hidden_size))
1592
- self.pos_emb.data.normal_(mean=0.0, std=0.02)
1593
- else:
1594
- self.pos_emb = None
1595
-
1596
- # (Optional) Pre-normalization layer
1597
- if prenorm:
1598
- self.prenorm = LayerNorm(encoder_hidden_size)
1599
- else:
1600
- self.prenorm = None
1601
-
1602
- self.build_net(num_queries, encoder_hidden_size, hidden_size, output_hidden_size)
1603
- self.dtype = next(self.parameters()).dtype
1604
-
1605
- def forward(
1606
- self,
1607
- x: torch.Tensor,
1608
- num_queries_vis_abstractors: Optional[List[List[int]]] = None,
1609
- num_grids: Optional[List[int]] = None,
1610
- ) -> torch.Tensor:
1611
- """
1612
- Args:
1613
- x: (B, L, encoder_hidden_size) tensor from the visual backbone (e.g. CLIP visual encoder), including cls token.
1614
- """
1615
- if self.prenorm is not None:
1616
- x = self.prenorm(x)
1617
-
1618
- if self.pos_emb is not None:
1619
- x = x + self.pos_emb
1620
-
1621
- x = self._forward(
1622
- x,
1623
- num_queries_vis_abstractors=num_queries_vis_abstractors,
1624
- num_grids=num_grids,
1625
- ) # (B, L, output_hidden_size)
1626
-
1627
- return x
1628
-
1629
- def _forward(
1630
- self,
1631
- x: torch.Tensor,
1632
- num_queries_vis_abstractors: Optional[List[List[int]]] = None,
1633
- num_grids: Optional[List[int]] = None,
1634
- ) -> torch.Tensor:
1635
- # x: [B, L, dim]
1636
- B, L, dim = x.shape
1637
- hw = int(L ** 0.5)
1638
- x = rearrange(x, "b (h w) d -> b d h w", h=hw, w=hw)
1639
-
1640
- if num_queries_vis_abstractors is not None:
1641
- assert num_grids is not None
1642
- return self._forward_adaptive_num_query(x, num_queries_vis_abstractors, num_grids)
1643
-
1644
- x = self.net(x)
1645
- x = rearrange(x, "b d h w -> b (h w) d")
1646
- x = self.readout(x)
1647
- return x
1648
-
1649
- def _forward_adaptive_num_query(
1650
- self,
1651
- x: torch.Tensor,
1652
- num_queries_vis_abstractors: Optional[List[List[int]]] = None,
1653
- num_grids: Optional[List[int]] = None,
1654
- ) -> List[torch.Tensor]:
1655
- # self.net consists of 3 stages (s1, sampler, s2)
1656
- assert len(self.net) == 3
1657
-
1658
- x = self.net[0](x) # s1
1659
- new_x = []
1660
- for i, num_queries in enumerate(num_queries_vis_abstractors):
1661
- hw = int(num_queries**0.5)
1662
- sampler = nn.AdaptiveAvgPool2d((hw, hw))
1663
- out = sampler(x[num_grids[i]:num_grids[i + 1], :])
1664
- out = self.net[2](out) # s2
1665
-
1666
- out = rearrange(out, "b d h w -> b (h w) d")
1667
- out = self.readout(out)
1668
-
1669
- new_x.append(out)
1670
- return new_x
1671
-
1672
- def build_net(
1673
- self,
1674
- n_queries: int,
1675
- encoder_hidden_size: int,
1676
- hidden_size: int,
1677
- output_hidden_size: int,
1678
- depth: int = 3,
1679
- mlp_depth: int = 2,
1680
- ):
1681
- assert (n_queries ** 0.5).is_integer(), f"n_queries must be a perfect square. n_queries: {n_queries}"
1682
- hw = int(n_queries ** 0.5)
1683
-
1684
- # RegBlock = ResBlock + SE
1685
- RegBlock = partial(
1686
- RegStage,
1687
- stride=1,
1688
- dilation=1,
1689
- act_layer=nn.SiLU,
1690
- norm_layer=LayerNorm2d,
1691
- )
1692
-
1693
- s1 = RegBlock(
1694
- depth,
1695
- encoder_hidden_size,
1696
- hidden_size,
1697
- )
1698
- sampler = nn.AdaptiveAvgPool2d((hw, hw))
1699
- s2 = RegBlock(
1700
- depth,
1701
- hidden_size,
1702
- hidden_size,
1703
- )
1704
-
1705
- self.net = nn.Sequential(s1, sampler, s2)
1706
- self.readout = self.build_mlp(mlp_depth, hidden_size, output_hidden_size)
1707
-
1708
- def build_mlp(
1709
- self,
1710
- depth: int,
1711
- hidden_size: int,
1712
- output_hidden_size: int,
1713
- ):
1714
- layers = [nn.Linear(hidden_size, output_hidden_size)]
1715
- for _ in range(1, depth):
1716
- layers.append(nn.SiLU())
1717
- layers.append(nn.Linear(output_hidden_size, output_hidden_size))
1718
- return nn.Sequential(*layers)
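A shape-level sketch of the projector above, covering both the default path and the adaptive slow/fast path; all sizes here are illustrative assumptions rather than the checkpoint's actual configuration, and the timm/einops dependencies used by this module must be installed.

import torch

abstractor = HCXVisionCAbstractor(
    num_queries=81,            # 9 x 9 visual tokens per grid after pooling
    num_input_tokens=576,      # 24 x 24 patch tokens from the vision backbone
    encoder_hidden_size=1024,  # illustrative backbone width
    hidden_size=1024,
    output_hidden_size=4096,   # illustrative LLM embedding width
)

# Default path: every grid is pooled down to num_queries tokens.
patch_features = torch.randn(2, 576, 1024)   # (batch of grids, tokens, encoder width)
print(abstractor(patch_features).shape)      # expected: torch.Size([2, 81, 4096])

# Adaptive path: grid 0 keeps 81 queries (slow), grids 1..8 are pooled to 9 queries (fast).
grids = torch.randn(9, 576, 1024)
outs = abstractor(grids, num_queries_vis_abstractors=[81, 9], num_grids=[0, 1, 9])
print([o.shape for o in outs])               # expected: [(1, 81, 4096), (8, 9, 4096)]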
1719
-
1720
- def load_sharded_checkpoint(
1721
- model, folder, pick_prefix="", replace_prefix_list=[], replace_prefix_dict={}, print_info=True
1722
- ):
1723
- if folder is None:
1724
- return {}
1725
-
1726
- files = os.listdir(folder)
1727
-
1728
- # find relevant files
1729
- pytorch_bin_files = [file for file in files if file.startswith("pytorch_model") and file.endswith(".bin")]
1730
- safetensor_files = [file for file in files if file.endswith(".safetensors")]
1731
- shard_index_file = [file for file in files if file.endswith(".index.json")]
1732
-
1733
- # check if sharded
1734
- index_present = len(shard_index_file) > 0
1735
- index_file = os.path.join(folder, shard_index_file[0]) if index_present else []
1736
-
1737
- # check if safetensor
1738
- is_safetensor = len(safetensor_files) > 0
1739
-
1740
- model_keys = model.state_dict().keys()
1741
-
1742
- if is_safetensor:
1743
- from safetensors.torch import load_file
1744
-
1745
- load_function = load_file
1746
- shard_files = safetensor_files
1747
- else:
1748
- load_function = partial(torch.load, map_location="cpu")
1749
- shard_files = pytorch_bin_files
1750
-
1751
- # sharded case
1752
- if index_present:
1753
- with open(index_file, "r", encoding="utf-8") as f:
1754
- index = json.load(f)
1755
- loaded_keys = index["weight_map"].keys()
1756
- if pick_prefix:
1757
- loaded_keys = [k[len(pick_prefix) :] for k in loaded_keys if k.startswith(pick_prefix)]
1758
- if replace_prefix_list:
1759
- for rep_prefix in replace_prefix_list:
1760
- loaded_keys = [k[len(rep_prefix) :] if k.startswith(rep_prefix) else k for k in loaded_keys]
1761
- if replace_prefix_dict:
1762
- for rep_prefix in replace_prefix_dict:
1763
- loaded_keys = [
1764
- k.replace(rep_prefix, replace_prefix_dict[rep_prefix]) if k.startswith(rep_prefix) else k
1765
- for k in loaded_keys
1766
- ]
1767
-
1768
- for i, shard_file in enumerate(shard_files):
1769
- state_dict = load_function(os.path.join(folder, shard_file))
1770
-
1771
- # if pick_prefix, use only pick
1772
- if pick_prefix:
1773
- state_dict = {k[len(pick_prefix) :]: v for k, v in state_dict.items() if k.startswith(pick_prefix)}
1774
-
1775
- for rep_prefix in replace_prefix_list:
1776
- state_dict = {k[len(rep_prefix) :] if k.startswith(rep_prefix) else k: v for k, v in state_dict.items()}
1777
-
1778
- for rep_prefix in replace_prefix_dict:
1779
- state_dict = {
1780
- k.replace(rep_prefix, replace_prefix_dict[rep_prefix]) if k.startswith(rep_prefix) else k: v
1781
- for k, v in state_dict.items()
1782
- }
1783
-
1784
- if is_deepspeed_zero3_enabled():
1785
- # torch.distributed.barrier()
1786
- rank = torch.distributed.get_rank()
1787
- print(f"# [info] ZeRo3 - load sharded no {i}, rank {rank}")
1788
- load_state_dict_into_model(model, state_dict, strict=False)
1789
- elif is_fsdp_enabled():
1790
- if is_local_dist_rank_0():
1791
- model.load_state_dict(state_dict, strict=False)
1792
- else:
1793
- model.load_state_dict(state_dict, strict=False)
1794
- # Make sure memory is freed before we load the next state dict.
1795
-
1796
- if not index_present:
1797
- loaded_keys = list(state_dict.keys())
1798
-
1799
- del state_dict
1800
- gc.collect()
1801
-
1802
- # missing keys
1803
- missing_keys = [key for key in model_keys if key not in loaded_keys]
1804
- unexpected_keys = [key for key in loaded_keys if key not in model_keys]
1805
-
1806
- if get_rank() == 0 and print_info:
1807
- print(f"[info] missing_keys: {missing_keys}")
1808
- print(f"[info] unexpected_keys: {unexpected_keys}")
1809
-
1810
- return {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys}