| |
| |
| |
| |
| |
| from typing import Optional, Tuple, Union |
| from functools import partial |
|
|
| import torch |
| import torch.nn.functional as F |
| import torch.utils.checkpoint |
|
|
| from einops import rearrange |
| from timm.models.layers import DropPath |
| from torch import nn |
| from transformers.activations import ACT2FN |
| from transformers.modeling_outputs import (BaseModelOutput, |
| BaseModelOutputWithPooling) |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.utils import logging |
|
|
| from .configuration_navil_vit import NaViLVisionConfig |
| from .modular_intern_vit import ( |
| InternVisionFlashAttention2, |
| InternVisionSdpaAttention, |
| InternMLP, |
| NORM2FN, |
| InternVisionRotaryEmbedding, |
| ) |
|
|
logger = logging.get_logger(__name__)

# flash-attn is optional: `has_flash_attn` selects between the FlashAttention-2 and SDPA
# attention implementations in NaViLVisionEncoderLayerAnyRes.
try:
    from flash_attn import flash_attn_varlen_func
    from flash_attn.layers.rotary import apply_rotary_emb
    has_flash_attn = True
except ImportError:
    # Only a missing/broken install should trigger the fallback; a bare `except:` would also
    # hide KeyboardInterrupt/SystemExit and unrelated errors raised while importing.
    logger.warning('FlashAttention is not installed.')
    has_flash_attn = False
|
|
|
|
class NaViLVisionEmbeddingsAnyRes(nn.Module):
    """Patch embedding for the any-resolution vision tower.

    `forward` applies a `patch_size`-stride Conv2d and flattens every dimension after the first,
    producing one row per input crop. The encoder consumes a 2-D `(seq_len, hidden_size)` tensor,
    so callers presumably pass pre-extracted patches of shape `(num_patches, 3, patch_size, patch_size)`
    — TODO confirm against the image processor.
    """

    def __init__(self, config: "NaViLVisionConfig"):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        # 1 / downsample_ratio; not used inside this module.
        self.merge_size = int(1.0 / config.downsample_ratio)

        self.patch_embedding = nn.Conv2d(
            in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
        )

        # Patch count of a fixed `image_size` square image (+1 position slot); not used by forward.
        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """Embed `pixel_values` (4-D, 3 channels) and flatten to `(pixel_values.shape[0], -1)`.

        The input is cast to the conv weight dtype first, so e.g. fp32 pixels can be fed to a
        bf16/fp16 model; Conv2d raises on mismatched input/weight dtypes.
        """
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
        return patch_embeds.flatten(1)
|
|
|
|
class NaViLVisionEncoderLayerAnyRes(nn.Module):
    """Pre-norm transformer block with LayerScale and optional stochastic depth.

        hidden = hidden + drop_path1(attn(norm1(hidden)) * ls1)
        hidden = hidden + drop_path2(mlp(norm2(hidden)) * ls2)

    Attention is `InternVisionFlashAttention2` when flash-attn imported successfully,
    otherwise `InternVisionSdpaAttention`.
    """

    def __init__(self, config: "NaViLVisionConfig", drop_path_rate: float):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.norm_type = config.norm_type

        if has_flash_attn:
            self.attn = InternVisionFlashAttention2(config)
        else:
            self.attn = InternVisionSdpaAttention(config)
        self.mlp = InternMLP(config)
        self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
        self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)

        # LayerScale: learnable per-channel scales on each residual branch,
        # initialised to config.initializer_factor.
        self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
        self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
        # Stochastic depth only when a positive rate is configured.
        self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
        self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.Tensor`): packed token features; NaViLVisionEncoderAnyRes passes
                a 2-D `(seq_len, embed_dim)` tensor.
            cu_seqlens (`torch.Tensor`): cumulative sequence boundaries, forwarded unchanged to the
                attention module to delimit which tokens attend to each other.
            rotary_pos_emb (`torch.Tensor`): rotary position embeddings, forwarded unchanged to the
                attention module.

        Returns:
            `torch.Tensor`: the updated hidden states, same shape as `hidden_states`.
        """
        attn_out = self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
        )
        hidden_states = hidden_states + self.drop_path1(attn_out * self.ls1)
        hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states)) * self.ls2)
        return hidden_states
|
|
|
|
class NaViLVisionEncoderAnyRes(nn.Module):
    """
    Transformer encoder made of `config.num_hidden_layers` [`NaViLVisionEncoderLayerAnyRes`] layers
    operating on a packed (variable-resolution) token sequence.

    Layers whose index is in `config.fullatt_block_indexes` attend over each whole `(h, w)` frame;
    all other layers attend only inside the local windows produced by `get_window_index`. When
    `config.fullatt_block_indexes` is None, every layer uses full attention.

    Args:
        config (`NaViLVisionConfig`):
            The corresponding vision configuration.
    """

    def __init__(self, config: "NaViLVisionConfig"):
        super().__init__()
        self.config = config

        # Stochastic depth: drop-path rate grows linearly from 0 to config.drop_path_rate.
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
        self.layers = nn.ModuleList([
            NaViLVisionEncoderLayerAnyRes(config, dpr[idx]) for idx in range(config.num_hidden_layers)])
        # On by default; forward only uses it while self.training is True.
        self.gradient_checkpointing = True

        head_dim = config.hidden_size // config.num_attention_heads
        # rot_pos_emb concatenates a row-position and a column-position embedding per token,
        # hence the table is built with head_dim // 2.
        self.rotary_pos_emb = InternVisionRotaryEmbedding(head_dim // 2)

        # Side length of the square patch block that is kept contiguous (and later merged).
        self.merge_size = int(1.0 / config.downsample_ratio)
        self.merge_unit = self.merge_size * self.merge_size
        self.patch_size = config.patch_size
        self.fullatt_block_indexes = config.fullatt_block_indexes
        self.window_size = config.window_size

    def rot_pos_emb(self, grid_thw):
        """Build rotary embeddings for every patch of every `(t, h, w)` grid in `grid_thw`.

        Patch coordinates are emitted in merge-block order: each `merge_size x merge_size` block is
        contiguous, blocks are raster-ordered, and the per-frame layout is repeated `t` times.
        Returns one row per patch: the row-position embedding followed by the column-position one.
        """
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.merge_size,
                self.merge_size,
                w // self.merge_size,
                self.merge_size,
            )
            # (block_row, block_col, in_row, in_col): keeps each merge block contiguous.
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.merge_size,
                self.merge_size,
                w // self.merge_size,
                self.merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        # One table large enough for the biggest h or w, then gather (row, col) entries.
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def get_window_index(self, grid_thw):
        """Compute the window permutation and window boundaries for windowed attention.

        Works in merged-token units (one unit = `merge_unit` patches); a window spans
        `window_size // merge_size` units per side.

        Returns:
            window_index (`torch.Tensor`): permutation of merged-token indices (offset across all
                grids in `grid_thw`) that groups tokens window by window.
            cu_window_seqlens (`list[int]`): cumulative patch counts per window, starting at 0.
                Empty padded windows produce repeated entries; forward removes them with
                `torch.unique_consecutive`.
        """
        window_index: list = []
        cu_window_seqlens: list = [0]
        window_index_id = 0
        vit_merger_window_size = self.window_size // self.merge_size
        assert vit_merger_window_size > 0

        for grid_t, grid_h, grid_w in grid_thw:
            llm_grid_h, llm_grid_w = (
                grid_h // self.merge_size,
                grid_w // self.merge_size,
            )
            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
            # Pad up to a multiple of the window size. When already divisible this still adds one
            # full window of padding; those windows are empty (seqlen 0) and are deduped later.
            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
            index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
            index_padded = index_padded.reshape(
                grid_t,
                num_windows_h,
                vit_merger_window_size,
                num_windows_w,
                vit_merger_window_size,
            )
            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
                grid_t,
                num_windows_h * num_windows_w,
                vit_merger_window_size,
                vit_merger_window_size,
            )
            # -100 marks padding: count real tokens per window, then drop the padding.
            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
            index_padded = index_padded.reshape(-1)
            index_new = index_padded[index_padded != -100]
            window_index.append(index_new + window_index_id)
            # Boundaries are in patches, hence the * merge_unit.
            cu_seqlens_tmp = seqlens.cumsum(0) * self.merge_unit + cu_window_seqlens[-1]
            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
        window_index = torch.cat(window_index, dim=0)

        return window_index, cu_window_seqlens

    def forward(
        self,
        inputs_embeds,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        grid_thw: Optional[torch.Tensor] = None,
    ) -> "Union[Tuple, BaseModelOutput]":
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(sequence_length, hidden_size)`):
                Packed patch embeddings of all items. `sequence_length` must equal the sum of
                `t * h * w` over `grid_thw` and be divisible by `merge_size ** 2`.
            output_hidden_states (`bool`, *optional*):
                Whether to also return the input of every layer plus the final output.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            grid_thw (`torch.Tensor` of shape `(num_items, 3)`):
                Per-item `(t, h, w)` patch grid sizes; `h` and `w` must be divisible by `merge_size`.

        Note:
            Tokens are permuted into window order (in groups of `merge_size ** 2` patches, by
            `window_index`) before the first layer, and no inverse permutation is applied here, so
            the returned states are in window order. NOTE(review): confirm the caller undoes this.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        hidden_states = inputs_embeds

        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=hidden_states.device,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

        # Reorder tokens (and their rotary embeddings identically) into window order,
        # moving whole merge blocks so each block stays contiguous.
        seq_len, _ = hidden_states.size()
        hidden_states = hidden_states.reshape(seq_len // self.merge_unit, self.merge_unit, -1)
        hidden_states = hidden_states[window_index, :, :]
        hidden_states = hidden_states.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.merge_unit, self.merge_unit, -1)
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)

        # Full-attention boundaries: one sequence per (h, w) frame, repeated t times per item.
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        # Put it on the same device as cu_window_seqlens: layers receive either one
        # interchangeably, and grid_thw may live on a different device than the activations.
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to(hidden_states.device)

        for idx, encoder_layer in enumerate(self.layers):
            if (self.fullatt_block_indexes is None) or (idx in self.fullatt_block_indexes):
                cu_seqlens_now = cu_seqlens
            else:
                cu_seqlens_now = cu_window_seqlens
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                layer_outputs = torch.utils.checkpoint.checkpoint(
                    partial(encoder_layer, cu_seqlens=cu_seqlens_now, rotary_pos_emb=rotary_pos_emb),
                    hidden_states)
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    cu_seqlens=cu_seqlens_now,
                    rotary_pos_emb=rotary_pos_emb,
                )
            hidden_states = layer_outputs

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states
        )
|
|
|
|
class NaViLVisionModelAnyRes(PreTrainedModel):
    """Any-resolution NaViL vision tower: patch embedding followed by NaViLVisionEncoderAnyRes.

    The final hidden states are regrouped into blocks of shape
    `(num_blocks, merge_size, merge_size, hidden_size)` with `merge_size = int(1 / config.downsample_ratio)`.
    """
    main_input_name = 'pixel_values'
    config_class = NaViLVisionConfig
    _no_split_modules = ['NaViLVisionEncoderLayerAnyRes']

    def __init__(self, config: NaViLVisionConfig):
        super().__init__(config)
        self.config = config

        self.merge_size = int(1.0 / config.downsample_ratio)
        self.embeddings = NaViLVisionEmbeddingsAnyRes(config)
        self.encoder = NaViLVisionEncoderAnyRes(config)

    def get_input_embeddings(self):
        return self.embeddings

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_embeds: Optional[torch.FloatTensor] = None,
        grid_thw: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """Encode images/patches into merge-block-shaped hidden states.

        Args:
            pixel_values: 4-D pixel tensor passed to `self.embeddings`; ignored when `pixel_embeds` is given.
            output_hidden_states: whether to also return every layer's hidden states
                (defaults to `config.output_hidden_states`).
            return_dict: whether to return a `BaseModelOutputWithPooling` instead of a tuple
                (defaults to `config.use_return_dict`).
            pixel_embeds: precomputed patch embeddings that bypass `self.embeddings`.
            grid_thw: per-item `(t, h, w)` patch grid sizes, required by the encoder.

        Returns:
            `last_hidden_state` reshaped to `(-1, merge_size, merge_size, hidden_size)`, plus the
            encoder hidden states when requested. `pooler_output` is always None.

        Raises:
            ValueError: if neither `pixel_values` nor `pixel_embeds` is given, if `pixel_values`
                is not 4-D, or if `grid_thw` is missing.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None and pixel_embeds is None:
            raise ValueError('You have to specify pixel_values or pixel_embeds')

        if pixel_embeds is not None:
            hidden_states = pixel_embeds
        elif pixel_values.dim() == 4:
            hidden_states = self.embeddings(pixel_values)
        else:
            raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')

        # The encoder cannot build positions or attention windows without the grids;
        # fail early with a clear message instead of a TypeError deep inside it.
        if grid_thw is None:
            raise ValueError('You have to specify grid_thw')

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            grid_thw=grid_thw
        )
        # Positional indexing works for both return types: with return_dict=False the encoder
        # returns a plain tuple, which has no `.last_hidden_state` attribute.
        last_hidden_state = encoder_outputs[0]

        # Group consecutive merge_size * merge_size tokens into one square block.
        last_hidden_state = last_hidden_state.reshape(
            -1, self.merge_size, self.merge_size, last_hidden_state.shape[-1]
        )

        if not return_dict:
            return (last_hidden_state, ) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=None,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
|
|