#!/usr/bin/env python # -*- coding: utf-8 -*- """ ================================================ @author: Jaron @time: 2024/02/20 16:21:56 @email: fjjth98@163.com @description: QFormer projector, convert image and video into fixed-length tokens ================================================ """ import math import torch import torch.nn as nn from torch.nn.functional import interpolate from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerModel, Blip2QFormerEncoder from .configuration_ccam_projector import CCAMConfig class SimpleQFormerOutput(nn.Module): # replace last residual MLP with normal MLP def __init__(self, config): super().__init__() self.dense = nn.Linear(config.intermediate_size, config.output_size) def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor = None) -> torch.Tensor: return self.dense(hidden_states) class SimpleQFormerIdentity(nn.Module): # just to replace the first attention module with identity, since it is useless def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: return hidden_states, class CCAMModel(Blip2QFormerModel): _auto_class = 'AutoModel' config_class = CCAMConfig base_model_prefix = 'model' supports_gradient_checkpointing = True def __init__(self, config: CCAMConfig): super(Blip2QFormerModel, self).__init__(config) self.gradient_checkpointing = False self.config = config self.num_query_tokens = config.num_query_tokens self.visual_attn_mask_type = config.visual_attn_mask_type self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.encoder = Blip2QFormerEncoder(config) self.encoder.layer[0].attention = SimpleQFormerIdentity() # replace the 1st attention module with identity self.encoder.layer[-1].output_query = SimpleQFormerOutput(config) # initialize query tokens self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.hidden_size)) # initialize pos embed self.spatial_pos_embed = self._create_pos_embed(*config.spatial_resolution, type=config.spatial_pos_embed_type) # (H, W, C) self.temporal_pos_embed = self._create_pos_embed(config.temporal_resolution, type=config.temporal_pos_embed_type) # (T, C) # initialize query attn mask if config.query_attn_mask_type == 'full': self.query_attn_mask = None elif config.query_attn_mask_type == 'causal': query_attn_mask = torch.ones(self.num_query_tokens, self.num_query_tokens) q = torch.arange(self.num_query_tokens) query_attn_mask.masked_fill_(q > q[:, None], 0) self.query_attn_mask = query_attn_mask[None] else: raise NotImplementedError(f'Do not support {self.query_attn_mask} query_attn_mask') self.post_init() def _create_pos_embed(self, *size: int, type: str = 'none') -> torch.Tensor: C = self.config.encoder_hidden_size if type == 'none': pos_embed = None elif type == 'learnable': pos_embed = nn.Parameter(.02 * torch.randn(*size, C)) elif type == 'cosine': total_len = 1 for i in size: total_len *= i raw = torch.outer(torch.arange(total_len), torch.exp(torch.arange(0, C, 2) * (-math.log(10000.) / C))) pos_embed = nn.Parameter(torch.stack((raw.sin(), raw.cos()), dim=-1).view(*size, C), requires_grad=False) else: raise NotImplementedError(f'Do not support {type} position embeddings') return pos_embed def get_attn_mask(self, embeddings: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """Get visual_attn_mask and query_attn_mask if needed embeddings (torch.Tensor): (B, T, L, C) """ B, T, L, _ = embeddings.size() device = embeddings.device # visual attn mask only work for videos if T > 1: if self.visual_attn_mask_type == 'ccam': base_attn_mask = torch.ones(T, T, device=device) t = torch.arange(T, device=device) base_attn_mask.masked_fill_(t > t[:, None], 0) visual_attn_mask = torch.cat(( torch.kron( base_attn_mask, torch.ones(self.num_query_tokens // T, L, device=device) ), torch.ones(self.num_query_tokens % T, T * L, device=device) ), dim=0)[None].expand(B, -1, -1) elif self.visual_attn_mask_type == 'full': visual_attn_mask = None else: raise NotImplementedError(f'Do not support {self.visual_attn_mask_type} attn_mask') else: visual_attn_mask = None if self.query_attn_mask is None: query_attn_mask = None else: query_attn_mask = self.query_attn_mask.expand(B, -1, -1) return visual_attn_mask, query_attn_mask def batch_forward_no_spatial(self, visual_embeds: torch.Tensor) -> torch.Tensor: """Batch forward without spatial mask position embeddings Args: visual_embeds (torch.Tensor): (B, T, L, C) Returns: torch.Tensor: (B, Q, C) """ B, T, _, C = visual_embeds.size() query_embeds = self.query_tokens.expand(B, -1, -1) visual_attn_mask, query_attn_mask = self.get_attn_mask(visual_embeds) # add temporal position embeddings if self.temporal_pos_embed is not None: if T == self.temporal_pos_embed.size(0): pos_embed = self.temporal_pos_embed elif T == 1: pos_embed = 0. * self.temporal_pos_embed[:1] # for deepspeed else: pos_embed = interpolate( self.temporal_pos_embed.T[None], # (1, C, t) size=(T,), mode='linear', align_corners=False )[0].T # (T, C) visual_embeds = visual_embeds + pos_embed.view(1, T, 1, C) visual_embeds = visual_embeds.flatten(1, 2) return super().forward( query_embeds=query_embeds, attention_mask=query_attn_mask, encoder_hidden_states=visual_embeds, encoder_attention_mask=visual_attn_mask )[0] def forward(self, visual_embeds: torch.Tensor, split_sizes: list[int], unmasked_ids: torch.LongTensor = None): """ visual_embeds (torch.Tensor): (T, L, C) split_sizes (list[int]): [t0, t1, ...] sum_i ti=T unmasked_ids (torch.LongTensor): If provided, should be in the shape of (T, L) whose value v 0<=v<=HW-1 output_attentions (_type_, optional): _description_. Defaults to None. output_hidden_states (_type_, optional): _description_. Defaults to None. return_dict (_type_, optional): _description_. Defaults to None. """ _, L, C = visual_embeds.size() # add spatial position embeddings if self.spatial_pos_embed is not None: pos_embed = self.spatial_pos_embed.view(-1, C) # (H*W, C) if unmasked_ids is None: pos_embed = pos_embed.view(1, L, C) # if not provided, L must equals to H*W else: pos_embed = pos_embed[unmasked_ids] # (T, L, C) visual_embeds = visual_embeds + pos_embed # all inputs in this batch has the same t if len(set(split_sizes)) == 1: visual_embeds = visual_embeds.view(len(split_sizes), split_sizes[0], L, C) output = self.batch_forward_no_spatial(visual_embeds) else: visual_embeds = visual_embeds.split(split_sizes, dim=0) # group visual_embeds accoding to the number of frames output, group_visual_embeds = [None] * len(split_sizes), {} for idx, (embed, t) in enumerate(zip(visual_embeds, split_sizes)): if t in group_visual_embeds: group_visual_embeds[t][0].append(idx) group_visual_embeds[t][1].append(embed) else: group_visual_embeds[t] = [[idx], [embed]] for idx, embeds in group_visual_embeds.values(): cur_output = self.batch_forward_no_spatial(torch.stack(embeds, dim=0)) for i, j in enumerate(idx): output[j] = cur_output[i] output = torch.stack(output, dim=0) return output