| | |
| | |
| | |
| | |
| | |
| | from functools import partial |
| | from typing import Optional |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| | import torch.utils.checkpoint |
| | from peft import LoraConfig, get_peft_model |
| | from timm.models.layers import DropPath |
| | from torch import nn |
| | from transformers import GenerationConfig |
| | from transformers.modeling_utils import PreTrainedModel |
| | from transformers.utils import logging |
| |
|
| | from .configuration_internvl import InternVLConfig |
| | from .modeling_intern_vit import (InternVisionEmbeddings, InternVisionEncoder, |
| | InternVisionModel) |
| | from .modeling_qllama import LlamaForCausalLM, _expand_mask, _make_causal_mask |
| |
|
# FlashAttention is an optional acceleration; fall back gracefully when the
# local module (or its flash-attn dependency) cannot be imported.
try:
    from .flash_attention import FlashAttention
except ImportError:
    # Narrowed from a bare `except:` so real errors (KeyboardInterrupt,
    # SystemExit, typos inside the module) are not silently swallowed.
    print('FlashAttention is not installed.')
| |
|
| | logger = logging.get_logger(__name__) |
| |
|
| |
|
class InternVLPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = InternVLConfig
    base_model_prefix = 'internvl'
    supports_gradient_checkpointing = True
    # `position_ids` buffers are recomputed at load time, so missing checkpoint
    # keys matching this pattern are not reported.
    _keys_to_ignore_on_load_missing = [
        r'position_ids',
    ]
    # Modules that must stay on a single device under `device_map` sharding.
    _no_split_modules = ['InternAttention', 'LlamaDecoderLayer', 'LlamaForCausalLM']
    _skip_keys_device_placement = 'past_key_values'
    # Keep these submodules in fp32 even when the model is loaded in half precision.
    _keep_in_fp32_modules = ['wo']

    def _init_weights(self, module):
        """Initialize the weights"""
        factor = self.config.initializer_range
        if isinstance(module, nn.Conv2d) or isinstance(module, nn.Embedding) or isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=factor)
            if hasattr(module, 'bias') and module.bias is not None:
                module.bias.data.zero_()
        if isinstance(module, InternVisionEmbeddings):
            # Vision embeddings use the vision config's own initializer range when present.
            if hasattr(self.config, 'vision_config'):
                factor = self.config.vision_config.initializer_range
            nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
            nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Linear) and module.bias is not None:
            # NOTE(review): redundant — Linear biases were already zeroed by the
            # first `if` above; kept for byte-identical behavior.
            module.bias.data.zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        # Toggle activation checkpointing on both the vision wrapper and its encoder.
        if isinstance(module, InternVisionModel):
            module.gradient_checkpointing = value
        if isinstance(module, InternVisionEncoder):
            module.gradient_checkpointing = value
| |
|
| |
|
class CrossAttention(nn.Module):
    """Multi-head cross-attention: queries come from `x`, keys/values from
    separately supplied tensors.

    Biases for q/k/v are held as standalone parameters (only when
    `qkv_bias=True`) and applied via `F.linear`, matching the original
    checkpoint layout (`q_bias`, `k_bias`, `v_bias`).
    """

    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., attn_head_dim=None, out_dim=None):
        super().__init__()
        out_dim = dim if out_dim is None else out_dim
        self.num_heads = num_heads
        head_dim = attn_head_dim if attn_head_dim is not None else dim // num_heads
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5
        # The projections below require the concatenated head dim to equal `dim`.
        assert all_head_dim == dim

        self.q = nn.Linear(dim, all_head_dim, bias=False)
        self.k = nn.Linear(dim, all_head_dim, bias=False)
        self.v = nn.Linear(dim, all_head_dim, bias=False)

        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.k_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.k_bias = None
            self.v_bias = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, out_dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, k=None, v=None):
        """Attend from `x` over `k`/`v`.

        NOTE(review): despite the `None` defaults, `k` and `v` are required —
        `k.shape` is read unconditionally.
        """
        batch, q_len, _ = x.shape
        k_len = k.shape[1]
        v_len = v.shape[1]

        q_bias, k_bias, v_bias = None, None, None
        if self.q_bias is not None:
            q_bias, k_bias, v_bias = self.q_bias, self.k_bias, self.v_bias

        # Project and split heads: (B, S, C) -> (B, heads, S, head_dim).
        q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
        q = q.reshape(batch, q_len, self.num_heads, -1).permute(0, 2, 1, 3)

        k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
        k = k.reshape(batch, k_len, self.num_heads, -1).permute(0, 2, 1, 3)

        v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
        v = v.reshape(batch, v_len, self.num_heads, -1).permute(0, 2, 1, 3)

        # Scaled dot-product attention (scale applied to q before the matmul).
        q = q * self.scale
        scores = q @ k.transpose(-2, -1)
        weights = self.attn_drop(scores.softmax(dim=-1))

        # Merge heads back and project to the output dimension.
        out = (weights @ v).transpose(1, 2).reshape(batch, q_len, -1)
        out = self.proj(out)
        return self.proj_drop(out)
| |
|
| |
|
class AttentiveBlock(nn.Module):
    """Pre-norm cross-attention block: normalizes queries/keys/values
    separately (adding positional terms to q and k first), then applies
    `CrossAttention`.

    `bool_masked_pos` and `rel_pos_bias` are accepted for interface
    compatibility but unused.
    """

    def __init__(self, dim, num_heads, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, attn_head_dim=None, out_dim=None):
        super().__init__()

        self.norm1_q = norm_layer(dim)
        self.norm1_k = norm_layer(dim)
        self.norm1_v = norm_layer(dim)
        self.cross_attn = CrossAttention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
            proj_drop=drop, attn_head_dim=attn_head_dim, out_dim=out_dim)

        # NOTE(review): drop_path is constructed but never applied in forward().
        self.drop_path = nn.Identity() if drop_path <= 0. else DropPath(drop_path)

    def forward(self, x_q, x_kv, pos_q, pos_k, bool_masked_pos, rel_pos_bias=None):
        queries = self.norm1_q(x_q + pos_q)
        keys = self.norm1_k(x_kv + pos_k)
        values = self.norm1_v(x_kv)
        return self.cross_attn(queries, k=keys, v=values)
| |
|
| |
|
class AttentionPoolingBlock(AttentiveBlock):
    """Pools a token sequence (B, N, C) to a single vector (B, C) by
    cross-attending from the mean token as the sole query."""

    def forward(self, x):
        query = x.mean(1, keepdim=True)
        # No positional terms for pooling; masks/biases are unused upstream.
        pooled = super().forward(query, x, 0, 0, bool_masked_pos=None, rel_pos_bias=None)
        return pooled.squeeze(1)
| |
|
| |
|
class InternVLModel(InternVLPreTrainedModel):
    """
    InternVL: an InternViT vision backbone plus a QLLaMA decoder, tied together
    by learnable query tokens and two attention-pooling projectors that map
    both modalities into a shared CLIP-style embedding space.
    """

    config_class = InternVLConfig
    main_input_name = 'pixel_values'

    def __init__(self, config: InternVLConfig):
        super().__init__(config)

        text_hidden_size = config.qllama_config.hidden_size
        vision_hidden_size = config.vision_config.hidden_size
        clip_embed_dim = config.clip_embed_dim
        attn_pool_num_heads = config.attn_pool_num_heads
        # Propagate the query-token count into the language-model config.
        config.qllama_config.num_query_token = config.num_query_token
        self.num_query_token = config.num_query_token
        self.label_smoothing = config.label_smoothing

        self.vision_model = InternVisionModel(config.vision_config)
        self.qllama = LlamaForCausalLM(config.qllama_config)
        # Learnable query tokens prepended to text embeddings (BLIP-2 style).
        self.query_tokens = nn.Parameter(
            torch.zeros(1, config.num_query_token, text_hidden_size)
        )

        self.text_projection = nn.Parameter(torch.empty(text_hidden_size, clip_embed_dim))
        # CLIP-style learnable temperature, initialized to log(1/0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        # Pools raw vision-backbone tokens -> CLIP space (InternVL-C features).
        self.clip_projector = AttentionPoolingBlock(
            dim=vision_hidden_size, num_heads=attn_pool_num_heads, qkv_bias=True, qk_scale=None,
            drop=0., attn_drop=0., norm_layer=partial(nn.LayerNorm, eps=1e-5), out_dim=clip_embed_dim)
        # Pools QLLaMA query outputs -> CLIP space (InternVL-G features).
        self.clip_projector2 = AttentionPoolingBlock(
            dim=text_hidden_size, num_heads=attn_pool_num_heads, qkv_bias=True, qk_scale=None,
            drop=0., attn_drop=0., norm_layer=partial(nn.LayerNorm, eps=1e-5), out_dim=clip_embed_dim)
        # Image-text matching head (binary classification).
        self.itm_head = nn.Linear(text_hidden_size, 2)
        self.gradient_checkpointing = True

        # The `use_*_lora` config values double as the LoRA rank when truthy.
        if config.use_backbone_lora:
            self.wrap_backbone_lora(r=config.use_backbone_lora)
        if config.use_qllama_lora:
            self.wrap_qllama_lora(r=config.use_qllama_lora)
        if config.force_image_size:
            self.vision_model.resize_pos_embeddings(
                old_size=config.vision_config.image_size,
                new_size=config.force_image_size,
                patch_size=config.vision_config.patch_size
            )

    def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
        """Wrap the vision backbone with LoRA adapters on attention and MLP layers."""
        lora_config = LoraConfig(
            r=r,
            target_modules=['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2'],
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
        )
        self.vision_model = get_peft_model(self.vision_model, lora_config)
        self.vision_model.print_trainable_parameters()

    def wrap_qllama_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
        """Wrap QLLaMA with LoRA adapters on its attention and MLP projections."""
        lora_config = LoraConfig(
            r=r,
            target_modules=['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.o_proj',
                            'mlp.gate_proj', 'mlp.down_proj', 'mlp.up_proj'],
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
        )
        self.qllama = get_peft_model(self.qllama, lora_config)
        self.qllama.print_trainable_parameters()

    def get_input_embeddings(self):
        # Delegated to the language model (HF embedding-resize protocol).
        return self.qllama.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.qllama.set_input_embeddings(value)

    def set_output_embeddings(self, new_embeddings):
        self.qllama.set_output_embeddings(new_embeddings)

    def get_output_embeddings(self) -> nn.Module:
        return self.qllama.get_output_embeddings()

    @torch.no_grad()
    def generate(
            self,
            pixel_values: torch.FloatTensor,
            input_ids: torch.FloatTensor,
            attention_mask: torch.LongTensor,
            generation_config: Optional[GenerationConfig] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            **generate_kwargs,
    ) -> torch.LongTensor:
        """
        Generate text conditioned on images: query tokens are prepended to the
        text embeddings and the vision features are passed to QLLaMA via
        `vision_hidden_states`.
        """
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict)
        image_embeds = vision_outputs[0]

        batch_size = image_embeds.shape[0]
        input_embeds = self.get_input_embeddings()(input_ids)
        query_tokens = self.query_tokens.repeat(batch_size, 1, 1)
        input_embeds = torch.cat([query_tokens, input_embeds], dim=1)
        # Query positions always attend, so the mask is extended with ones.
        image_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
        attention_mask = torch.cat([image_attention_mask, attention_mask], dim=1)

        outputs = self.qllama.generate(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            vision_hidden_states=image_embeds,
            generation_config=generation_config,
            use_zero_attention_mask=True,
            **generate_kwargs,
        )

        return outputs

    def get_text_features(
            self,
            input_ids: torch.Tensor,
            attention_mask: torch.Tensor,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ):
        r"""
        Run QLLaMA over text only (no vision hidden states) under a combined
        padding + causal mask.

        Returns:
            torch.Tensor: the model's `last_hidden_state`.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        input_embeds = self.get_input_embeddings()(input_ids)
        # Additive masks: padding mask expanded to 4D, then causal mask added.
        attention_mask = _expand_mask(attention_mask, input_embeds.dtype).to(
            input_embeds.device)
        attention_mask += _make_causal_mask(
            (attention_mask.shape[0], attention_mask.shape[2]),
            input_embeds.dtype,
            device=input_embeds.device
        )
        # NOTE(review): presumably distinguishes the LoRA-wrapped case, where the
        # base CausalLM sits one attribute level deeper — confirm against
        # get_peft_model's wrapping.
        if type(self.qllama.model) == LlamaForCausalLM:
            outputs = self.qllama.model.model.forward_train(
                inputs_embeds=input_embeds,
                vision_hidden_states=None,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            ).last_hidden_state
        else:
            outputs = self.qllama.model.forward_train(
                inputs_embeds=input_embeds,
                vision_hidden_states=None,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            ).last_hidden_state
        return outputs

    def get_image_features(
            self,
            pixel_values: torch.FloatTensor,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ):
        """
        Run the vision backbone, then decode the learnable query tokens with
        QLLaMA attending to the image features.

        Returns:
            tuple: `(backbone_embeds, outputs)` — the raw vision-backbone
            tokens and QLLaMA's last hidden states for the query tokens.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict)
        image_embeds = vision_outputs[0]
        backbone_embeds = image_embeds

        batch_size = image_embeds.shape[0]
        input_embeds = self.query_tokens.repeat(batch_size, 1, 1)

        # Queries use a plain padding-style mask — no causal mask here.
        attention_mask = torch.ones(input_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
        attention_mask = _expand_mask(attention_mask, input_embeds.dtype).to(
            input_embeds.device)
        # NOTE(review): same LoRA-unwrapping check as in get_text_features.
        if type(self.qllama.model) == LlamaForCausalLM:
            outputs = self.qllama.model.model.forward_train(
                inputs_embeds=input_embeds,
                vision_hidden_states=image_embeds,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            ).last_hidden_state
        else:
            outputs = self.qllama.model.forward_train(
                inputs_embeds=input_embeds,
                vision_hidden_states=image_embeds,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            ).last_hidden_state
        return backbone_embeds, outputs

    def encode_image(self, image, mode):
        """
        Encode images into the shared embedding space.

        mode='InternVL-C': attention-pool the raw backbone tokens.
        mode='InternVL-G': also run QLLaMA queries over the image and sum the
        two L2-normalized projections.
        """
        if mode == 'InternVL-C':
            vision_outputs = self.vision_model(
                pixel_values=image,
                output_hidden_states=False,
                return_dict=True)
            image_embeds = vision_outputs[0]
            image_embeds = self.clip_projector(image_embeds)
        elif mode == 'InternVL-G':
            backbone_embeds, image_embeds = self.get_image_features(
                pixel_values=image,
                output_hidden_states=False,
                return_dict=True,
            )
            backbone_embeds = self.clip_projector(backbone_embeds)
            image_embeds = self.clip_projector2(image_embeds)
            # Ensemble: sum of the two unit-normalized projections.
            backbone_embeds = backbone_embeds / backbone_embeds.norm(dim=1, keepdim=True)
            image_embeds = image_embeds / image_embeds.norm(dim=1, keepdim=True)
            image_embeds = image_embeds + backbone_embeds
        else:
            raise NotImplementedError
        return image_embeds

    def encode_text(self, text):
        """Encode token ids into the shared space (token id 0 is padding)."""
        attention_mask = text > 0
        text_embeds = self.get_text_features(
            input_ids=text,
            attention_mask=attention_mask,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )
        # Use the hidden state at the last non-pad position of each sequence.
        text_embeds = text_embeds[torch.arange(text_embeds.shape[0]), attention_mask.sum(1) - 1]
        text_embeds = text_embeds @ self.text_projection
        return text_embeds

    def forward(self, image, text, mode='InternVL-C'):
        """Return CLIP-style (logits_per_image, logits_per_text) for a batch."""
        assert mode in ['InternVL-C', 'InternVL-G'], 'mode must be InternVL-C or InternVL-G'
        image_features = self.encode_image(image, mode)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)

        # cosine similarity as logits, scaled by the learnable temperature
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        return logits_per_image, logits_per_text
| |
|
| |
|
class InternVL_C(InternVLModel):
    """CLIP-style variant: contrastive features come straight from the
    attention-pooled vision backbone (no QLLaMA queries on the image side)."""

    def encode_image(self, image):
        # Index 0 of the backbone output is its last hidden state.
        outputs = self.vision_model(
            pixel_values=image,
            output_hidden_states=False,
            return_dict=True)
        return self.clip_projector(outputs[0])

    def encode_text(self, text):
        # Token id 0 is treated as padding; all other positions attend.
        attention_mask = text > 0
        hidden = self.get_text_features(
            input_ids=text,
            attention_mask=attention_mask,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )
        # Pick the hidden state at the last non-pad position of each sequence.
        last_token = attention_mask.sum(1) - 1
        hidden = hidden[torch.arange(hidden.shape[0]), last_token]
        return hidden @ self.text_projection

    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # L2-normalize both modalities before computing similarities.
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)

        # Temperature-scaled cosine similarity logits.
        scale = self.logit_scale.exp()
        logits_per_image = scale * image_features @ text_features.t()
        return logits_per_image, logits_per_image.t()
| |
|
| |
|
class InternVL_G(InternVLModel):
    """Generative-enhanced variant: image features ensemble the pooled vision
    backbone with QLLaMA query outputs over the image."""

    def encode_image(self, image):
        raw_tokens, query_states = self.get_image_features(
            pixel_values=image,
            output_hidden_states=False,
            return_dict=True,
        )
        pooled_backbone = self.clip_projector(raw_tokens)
        pooled_queries = self.clip_projector2(query_states)
        # Ensemble: sum of the two unit-normalized projections.
        pooled_backbone = pooled_backbone / pooled_backbone.norm(dim=1, keepdim=True)
        pooled_queries = pooled_queries / pooled_queries.norm(dim=1, keepdim=True)
        return pooled_queries + pooled_backbone

    def encode_text(self, text):
        # Token id 0 is treated as padding; all other positions attend.
        attention_mask = text > 0
        hidden = self.get_text_features(
            input_ids=text,
            attention_mask=attention_mask,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )
        # Pick the hidden state at the last non-pad position of each sequence.
        last_token = attention_mask.sum(1) - 1
        hidden = hidden[torch.arange(hidden.shape[0]), last_token]
        return hidden @ self.text_projection

    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # L2-normalize both modalities before computing similarities.
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)

        # Temperature-scaled cosine similarity logits.
        scale = self.logit_scale.exp()
        logits_per_image = scale * image_features @ text_features.t()
        return logits_per_image, logits_per_image.t()
| |
|