| | import contextlib |
| | import math |
| |
|
| | import einops |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch import Tensor |
| | from transformers import Qwen2ForCausalLM, SiglipVisionModel |
| | from transformers.cache_utils import Cache |
| | from transformers.generation.utils import GenerationMixin |
| | from transformers.modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast |
| | from transformers.modeling_utils import PreTrainedModel |
| |
|
| | from .configuration_nvila import NVILAConfig |
| |
|
# Channel width of the vision features handed to the multimodal projector,
# before its 2x2 downsample quadruples it (the projector sizes its LayerNorm
# and first Linear input as MM_HIDDEN_SIZE * 4).
# NOTE(review): presumably 1152 vision-tower channels x 3 scales — confirm
# against the vision config before changing the scales or the vision tower.
MM_HIDDEN_SIZE = 3456
| |
|
| |
|
class NVILAMultiModalProjectorDownsampleBlock(nn.Module):
    """Fold every 2x2 neighbourhood of a square token grid into one token.

    The input sequence is interpreted as a row-major ``side x side`` grid with
    ``side = isqrt(sequence_length)``. Odd-sized grids get one zero row appended
    at the bottom and one zero column appended on the right so that the grid
    splits evenly. The four tokens of each 2x2 patch are concatenated along
    the channel axis.
    """

    def forward(self, x: Tensor) -> Tensor:
        """Downsample ``x`` of shape ``(batch, seq, dim)``.

        Args:
            x: Token features. ``seq`` must be a perfect square, otherwise the
                reshape into a square grid fails.

        Returns:
            Tensor of shape ``(batch, ceil(side / 2) ** 2, 4 * dim)``.
        """
        bsz, num_tokens, dim = x.shape
        side = math.isqrt(num_tokens)

        grid = x.reshape(bsz, side, side, dim)

        if side % 2:
            # F.pad consumes the pad tuple from the last axis backwards:
            # channels untouched, +1 column (width), +1 row (height).
            grid = F.pad(grid, (0, 0, 0, 1, 0, 1))
            side += 1

        half = side // 2
        # Axes (b, i, di, j, dj, c) -> (b, i, j, di, dj, c) so the four tokens
        # of each 2x2 patch become adjacent before being flattened together.
        patches = grid.reshape(bsz, half, 2, half, 2, dim).permute(0, 1, 3, 2, 4, 5)
        return patches.contiguous().reshape(bsz, -1, 4 * dim)
| |
|
| |
|
class NVILAMultiModalProjector(nn.Module):
    """Map vision features to the language model's hidden size.

    Stages, in order: 2x2 token downsample (4x channels) -> LayerNorm ->
    Linear -> GELU -> Linear. The output width is
    ``config.text_config.hidden_size``.
    """

    def __init__(self, config: NVILAConfig):
        super().__init__()

        # Channel width once the downsample block has concatenated four tokens.
        downsampled_width = MM_HIDDEN_SIZE * 4
        llm_width = config.text_config.hidden_size

        # The position of each stage inside the Sequential determines its
        # parameter names (``layers.1``, ``layers.2``, ``layers.4``), so the
        # order below must not change.
        stages = [
            NVILAMultiModalProjectorDownsampleBlock(),
            nn.LayerNorm(downsampled_width),
            nn.Linear(downsampled_width, llm_width),
            nn.GELU(),
            nn.Linear(llm_width, llm_width),
        ]
        self.layers = nn.Sequential(*stages)

    def forward(self, x: Tensor) -> Tensor:
        """Run ``x`` through every projector stage and return the result."""
        projected = self.layers(x)
        return projected
| |
|
| |
|
class NVILAForConditionalGeneration(PreTrainedModel, GenerationMixin):
    """Vision-language model built from a SigLIP vision tower, a projector and a Qwen2 LLM.

    Image and video placeholder tokens found in ``input_ids`` are replaced by
    projected vision features before the sequence is passed to the language
    model.
    """

    config_class = NVILAConfig
    base_model_prefix: str = "llm"
    _auto_class = "AutoModel"
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, config: NVILAConfig):
        super().__init__(config)

        self.config: NVILAConfig

        # Build the sub-modules with ``config.torch_dtype`` as the process-wide
        # default dtype, and restore the previous default no matter what.
        previous_dtype = torch.get_default_dtype()
        torch.set_default_dtype(config.torch_dtype)
        try:
            self.vision_tower = SiglipVisionModel(config.vision_config)
            self.mm_projector = NVILAMultiModalProjector(config)
            self.llm = Qwen2ForCausalLM(config.text_config)
        finally:
            torch.set_default_dtype(previous_dtype)

        self.post_init()

    def forward(
        self,
        *,
        block_sizes: list[tuple[int, int]] | None = None,
        input_ids: Tensor | None = None,
        inputs_embeds: Tensor | None = None,
        pixel_values: Tensor | None = None,
        pixel_values_videos: Tensor | None = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Run the language model, splicing in vision features when needed.

        Args:
            block_sizes: Optional per-image tile grid sizes, forwarded to the
                vision encoding step.
            input_ids: Token ids. Mutually exclusive with ``inputs_embeds``.
            inputs_embeds: Precomputed embeddings. When given, the pixel
                inputs are not used by this method.
            pixel_values: Image pixels matched to image placeholder tokens.
            pixel_values_videos: Video pixels matched to video placeholder
                tokens.
            **kwargs: Passed through unchanged to the language model.

        Returns:
            The language model's output.
        """
        assert (input_ids is None) != (
            inputs_embeds is None
        ), "Exactly one of `input_ids` or `inputs_embeds` must be specified."

        if input_ids is not None:
            media_token_ids = torch.tensor(
                [self.config.image_token_id, self.config.video_token_id],
                device=input_ids.device,
            )
            # Only take the (more expensive) embedding path when at least one
            # media placeholder is actually present in the prompt.
            if torch.isin(input_ids, media_token_ids).any():
                inputs_embeds = self._embed(
                    block_sizes=block_sizes,
                    input_ids=input_ids,
                    pixel_values=pixel_values,
                    pixel_values_videos=pixel_values_videos,
                )
                # The LLM must receive embeddings *or* ids, never both.
                input_ids = None

        return self.llm(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

    def _embed(
        self,
        *,
        block_sizes: list[tuple[int, int]] | None,
        input_ids: Tensor,
        pixel_values: Tensor | None,
        pixel_values_videos: Tensor | None,
    ) -> Tensor:
        """Embed ``input_ids`` and overwrite media placeholders with vision features.

        Images are handled first, then videos; a modality whose pixel tensor
        is ``None`` is skipped. The masked assignment expects one feature row
        per placeholder token of that modality.

        Returns:
            The token embeddings with placeholder positions replaced.
        """
        token_embeds: Tensor = self.llm.model.embed_tokens(input_ids)

        media_streams = (
            (pixel_values, self.config.image_token_id),
            (pixel_values_videos, self.config.video_token_id),
        )
        for media_pixels, placeholder_id in media_streams:
            if media_pixels is None:
                continue

            features = self._encode_vision(media_pixels, block_sizes=block_sizes)
            # Collapse (items, tokens) into a single row axis for assignment.
            flat_features = einops.rearrange(features, "n p d -> (n p) d")

            token_embeds[input_ids == placeholder_id] = flat_features

        return token_embeds

    def _encode_vision(
        self,
        pixel_values: Tensor,
        *,
        block_sizes: list[tuple[int, int]] | None = None,
    ) -> Tensor:
        """Encode pixels into projected per-image token features.

        Steps: run the vision tower and take its second-to-last hidden state,
        merge tiles into per-image multi-scale feature maps, re-split them
        into tiles, project the tiles, then merge the projected tiles back
        into one map per image and flatten each map to a token sequence.

        Args:
            pixel_values: Input tiles for the vision tower; moved to the
                tower's device and dtype before encoding.
            block_sizes: Per-image tile grids. When ``None``, every input tile
                is treated as its own image.

        Returns:
            The per-image token sequences stacked along a new leading axis
            (so every image must yield the same number of tokens).
        """
        tower = self.vision_tower
        tower_output: BaseModelOutputWithPooling = tower(
            pixel_values.to(device=tower.device, dtype=tower.dtype),
            output_hidden_states=True,
        )
        assert tower_output.hidden_states is not None

        tile_features: Tensor = tower_output.hidden_states[-2]

        if block_sizes is None:
            block_sizes = [None] * tile_features.shape[0]

        per_image_maps, block_sizes = merge_features_for_dynamic_s2(
            tile_features,
            block_sizes=block_sizes,
            resize_output_to_scale_idx=-1,
            scales=[448, 896, 1344],
        )

        # Cut each merged map back into tiles so the projector sees
        # fixed-size inputs.
        tiled_maps = [
            split_chessboard(feature_map, grid[0], grid[1])
            for feature_map, grid in zip(per_image_maps, block_sizes)
        ]

        tile_tokens = torch.cat([einops.rearrange(tiles, "b c h w -> b (h w) c") for tiles in tiled_maps])
        tile_tokens = self.mm_projector(tile_tokens.to(self.device, self.dtype))

        # Regroup the projected tiles by image and stitch each group together.
        tiles_per_image = [grid[0] * grid[1] for grid in block_sizes]
        projected_groups = tile_tokens.split(tiles_per_image, dim=0)
        stitched_maps = [
            merge_chessboard(group, grid[0], grid[1])
            for group, grid in zip(projected_groups, block_sizes)
        ]

        return torch.stack([einops.rearrange(m, "1 c h w -> (h w) c") for m in stitched_maps])
| |
|
| |
|
| | |
| |
|
| |
|
def merge_chessboard(x, num_split_h, num_split_w):
    """Stitch tiles stacked along the batch axis back into full feature maps.

    ``x`` is read as ``num_split_h * num_split_w`` equally sized groups laid
    out along dim 0 in row-major tile order: the group for grid cell (0, 0)
    first, then (0, 1), and so on. Tiles in the same grid row are joined along
    the last axis, and the resulting rows are joined along the second-to-last
    axis.

    Args:
        x: ``b * c * h * w`` tiles, or ``b * n * c`` token sequences. 3-D input
            is first rearranged to ``b * c * h * w`` with ``h == w == isqrt(n)``.
        num_split_h: Number of tile rows in the grid.
        num_split_w: Number of tile columns in the grid.

    Returns:
        Tensor of shape
        ``(b / (num_split_h * num_split_w)) * c * (h * num_split_h) * (w * num_split_w)``.
    """
    total = x.shape[0]
    if x.dim() == 3:
        side = math.isqrt(x.shape[1])
        x = einops.rearrange(x, "b (h w) c -> b c h w", h=side, w=side)

    assert total % (num_split_h * num_split_w) == 0
    group = total // (num_split_h * num_split_w)

    grid_rows = []
    for row in range(num_split_h):
        row_start = row * num_split_w * group
        row_tiles = [
            x[row_start + col * group : row_start + (col + 1) * group]
            for col in range(num_split_w)
        ]
        grid_rows.append(torch.cat(row_tiles, dim=-1))

    return torch.cat(grid_rows, dim=-2)
| |
|
| |
|
def merge_features_for_dynamic_s2(image_features, block_sizes, *, scales, resize_output_to_scale_idx):
    """Combine per-tile features into one multi-scale feature map per image.

    ``image_features`` is consumed front to back. For each entry of
    ``block_sizes``:

    * ``None``: one tile is taken, rearranged to ``1 * c * h * w`` (square
      grid) and repeated ``len(scales)`` times along the channel axis; the
      reported block size is ``(1, 1)``.
    * ``(rows, cols)``: for every scale except the last,
      ``(scale // scales[0]) ** 2`` tiles are merged into a square map; for
      the last scale ``rows * cols`` tiles are merged on a ``rows x cols``
      grid. All maps are area-interpolated (in float32, then cast back) to the
      spatial size of the map at ``resize_output_to_scale_idx`` and
      concatenated along the channel axis.

    Args:
        image_features: Tile features, indexed by tile along dim 0.
        block_sizes: One entry per image, ``None`` or a ``(rows, cols)`` pair.
        scales: Scale sizes; ``scales[0]`` is the base tile size.
        resize_output_to_scale_idx: Index of the scale whose spatial size the
            merged map takes.

    Returns:
        ``(features_per_image, new_block_sizes)``: the merged maps and the
        tile grid that corresponds to each merged map.

    Raises:
        AssertionError: If the tiles consumed do not add up to
            ``len(image_features)``.
    """
    merged_per_image = []
    out_block_sizes = []
    cursor = 0  # index of the next unconsumed tile in ``image_features``

    for grid in block_sizes:
        if grid is None:
            single = image_features[cursor : cursor + 1]
            single = einops.rearrange(single, "1 (h w) c -> 1 c h w", h=math.isqrt(single.shape[1]))
            merged_per_image.append(single.repeat(1, len(scales), 1, 1))
            out_block_sizes.append((1, 1))
            cursor += 1
            continue

        per_scale = []

        # Every scale but the last is a square grid of base-size tiles.
        for scale in scales[:-1]:
            splits = scale // scales[0]
            tile_count = splits**2
            per_scale.append(
                merge_chessboard(
                    image_features[cursor : cursor + tile_count],
                    num_split_h=splits,
                    num_split_w=splits,
                )
            )
            cursor += tile_count

        # The last scale uses the image-specific (rows, cols) grid.
        last_tile_count = grid[0] * grid[1]
        per_scale.append(
            merge_chessboard(
                image_features[cursor : cursor + last_tile_count],
                num_split_h=grid[0],
                num_split_w=grid[1],
            )
        )
        cursor += last_tile_count

        # Bring every scale to a common spatial size, then stack on channels.
        target_hw = per_scale[resize_output_to_scale_idx].shape[-2:]
        resized = [
            F.interpolate(feature_map.to(torch.float32), size=target_hw, mode="area").to(feature_map.dtype)
            for feature_map in per_scale
        ]
        merged_per_image.append(torch.cat(resized, dim=1))

        if resize_output_to_scale_idx in (len(scales) - 1, -1):
            # Output lives at the last scale, so the image's own grid applies.
            out_block_sizes.append(grid)
        else:
            square = scales[resize_output_to_scale_idx] // scales[0]
            out_block_sizes.append((square, square))

    assert cursor == len(
        image_features
    ), f"The number of blocks ({cursor}) does not match length of image_features ({len(image_features)})!"

    return merged_per_image, out_block_sizes
| |
|
| |
|
def split_chessboard(x, num_split_h, num_split_w):
    """Cut feature maps into a grid of tiles and stack the tiles on the batch axis.

    Tiles are emitted in row-major grid order; each tile keeps all ``b``
    samples together, so the output batch is laid out as
    ``[tile(0, 0) for all samples, tile(0, 1) for all samples, ...]``.

    Args:
        x: ``b * c * h * w`` tensor.
        num_split_h: Number of tile rows; must divide ``h``.
        num_split_w: Number of tile columns; must divide ``w``.

    Returns:
        Tensor of shape
        ``(b * num_split_h * num_split_w) * c * (h / num_split_h) * (w / num_split_w)``.
    """
    _, _, height, width = x.shape
    assert height % num_split_h == 0 and width % num_split_w == 0
    tile_h = height // num_split_h
    tile_w = width // num_split_w

    tiles = []
    for row in range(num_split_h):
        # Slice the horizontal band once, then cut it into columns.
        band = x[:, :, row * tile_h : (row + 1) * tile_h]
        for col in range(num_split_w):
            tiles.append(band[..., col * tile_w : (col + 1) * tile_w])

    return torch.cat(tiles, dim=0)
| |
|