#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
================================================
@author: Jaron
@time: 2024/02/20 16:21:56
@email: fjjth98@163.com
@description: QFormer projector that converts images and videos into fixed-length query tokens
================================================
"""
import math
import torch
import torch.nn as nn
from torch.nn.functional import interpolate
from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerModel, Blip2QFormerEncoder
from .configuration_ccam_projector import CCAMConfig


class SimpleQFormerOutput(nn.Module):
    # replaces the final residual MLP with a plain linear projection to config.output_size
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.output_size)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor = None) -> torch.Tensor:
        return self.dense(hidden_states)


class SimpleQFormerIdentity(nn.Module):
    # replaces the first attention module with an identity mapping, since that attention is unnecessary here
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # return a 1-tuple to match the output interface of the attention module it replaces
        return hidden_states,


class CCAMModel(Blip2QFormerModel):
_auto_class = 'AutoModel'
config_class = CCAMConfig
base_model_prefix = 'model'
supports_gradient_checkpointing = True
def __init__(self, config: CCAMConfig):
super(Blip2QFormerModel, self).__init__(config)
self.gradient_checkpointing = False
self.config = config
self.num_query_tokens = config.num_query_tokens
self.visual_attn_mask_type = config.visual_attn_mask_type
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.encoder = Blip2QFormerEncoder(config)
self.encoder.layer[0].attention = SimpleQFormerIdentity() # replace the 1st attention module with identity
self.encoder.layer[-1].output_query = SimpleQFormerOutput(config)
# initialize query tokens
self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.hidden_size))
# initialize pos embed
self.spatial_pos_embed = self._create_pos_embed(*config.spatial_resolution, type=config.spatial_pos_embed_type) # (H, W, C)
self.temporal_pos_embed = self._create_pos_embed(config.temporal_resolution, type=config.temporal_pos_embed_type) # (T, C)
# initialize query attn mask
if config.query_attn_mask_type == 'full':
self.query_attn_mask = None
elif config.query_attn_mask_type == 'causal':
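            # build a lower-triangular (Q, Q) mask so that query i attends only to queries 0..i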
query_attn_mask = torch.ones(self.num_query_tokens, self.num_query_tokens)
q = torch.arange(self.num_query_tokens)
query_attn_mask.masked_fill_(q > q[:, None], 0)
self.query_attn_mask = query_attn_mask[None]
else:
            raise NotImplementedError(f'Do not support {config.query_attn_mask_type} query_attn_mask')
self.post_init()
def _create_pos_embed(self, *size: int, type: str = 'none') -> torch.Tensor:
C = self.config.encoder_hidden_size
if type == 'none':
pos_embed = None
elif type == 'learnable':
pos_embed = nn.Parameter(.02 * torch.randn(*size, C))
elif type == 'cosine':
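            # standard sinusoidal embedding: for position p, channel pair (2k, 2k+1) holds
            # sin(p / 10000^(2k/C)) and cos(p / 10000^(2k/C)), interleaved along the last dim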
total_len = 1
for i in size:
total_len *= i
raw = torch.outer(torch.arange(total_len), torch.exp(torch.arange(0, C, 2) * (-math.log(10000.) / C)))
pos_embed = nn.Parameter(torch.stack((raw.sin(), raw.cos()), dim=-1).view(*size, C), requires_grad=False)
else:
raise NotImplementedError(f'Do not support {type} position embeddings')
return pos_embed
def get_attn_mask(self, embeddings: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Get visual_attn_mask and query_attn_mask if needed
embeddings (torch.Tensor): (B, T, L, C)
"""
B, T, L, _ = embeddings.size()
device = embeddings.device
        # the visual attention mask is only needed for videos (T > 1)
if T > 1:
if self.visual_attn_mask_type == 'ccam':
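                # causal cross-attention mask (CCAM): queries are split into T groups of Q // T;
                # the group assigned to frame i attends only to frames 0..i, while the remaining
                # Q % T queries attend to all frames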
base_attn_mask = torch.ones(T, T, device=device)
t = torch.arange(T, device=device)
base_attn_mask.masked_fill_(t > t[:, None], 0)
visual_attn_mask = torch.cat((
torch.kron(
base_attn_mask,
torch.ones(self.num_query_tokens // T, L, device=device)
),
torch.ones(self.num_query_tokens % T, T * L, device=device)
), dim=0)[None].expand(B, -1, -1)
elif self.visual_attn_mask_type == 'full':
visual_attn_mask = None
else:
raise NotImplementedError(f'Do not support {self.visual_attn_mask_type} attn_mask')
else:
visual_attn_mask = None
if self.query_attn_mask is None:
query_attn_mask = None
else:
query_attn_mask = self.query_attn_mask.expand(B, -1, -1)
return visual_attn_mask, query_attn_mask
def batch_forward_no_spatial(self, visual_embeds: torch.Tensor) -> torch.Tensor:
"""Batch forward without spatial mask position embeddings
Args:
visual_embeds (torch.Tensor): (B, T, L, C)
Returns:
torch.Tensor: (B, Q, C)
"""
B, T, _, C = visual_embeds.size()
query_embeds = self.query_tokens.expand(B, -1, -1)
visual_attn_mask, query_attn_mask = self.get_attn_mask(visual_embeds)
# add temporal position embeddings
if self.temporal_pos_embed is not None:
if T == self.temporal_pos_embed.size(0):
pos_embed = self.temporal_pos_embed
elif T == 1:
                pos_embed = 0. * self.temporal_pos_embed[:1]  # zero contribution for single images, but keeps the parameter in the graph for DeepSpeed
else:
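                # frame count differs from the pretrained temporal resolution:
                # linearly interpolate the temporal position embeddings to length T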
pos_embed = interpolate(
self.temporal_pos_embed.T[None], # (1, C, t)
size=(T,),
mode='linear',
align_corners=False
)[0].T # (T, C)
visual_embeds = visual_embeds + pos_embed.view(1, T, 1, C)
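        # flatten frames and patches into a single sequence of T*L visual tokens for cross-attention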
visual_embeds = visual_embeds.flatten(1, 2)
return super().forward(
query_embeds=query_embeds,
attention_mask=query_attn_mask,
encoder_hidden_states=visual_embeds,
encoder_attention_mask=visual_attn_mask
)[0]
def forward(self, visual_embeds: torch.Tensor, split_sizes: list[int], unmasked_ids: torch.LongTensor = None):
"""
visual_embeds (torch.Tensor): (T, L, C)
split_sizes (list[int]): [t0, t1, ...] sum_i ti=T
unmasked_ids (torch.LongTensor): If provided, should be in the shape of (T, L) whose value v 0<=v<=HW-1
output_attentions (_type_, optional): _description_. Defaults to None.
output_hidden_states (_type_, optional): _description_. Defaults to None.
return_dict (_type_, optional): _description_. Defaults to None.
"""
_, L, C = visual_embeds.size()
# add spatial position embeddings
if self.spatial_pos_embed is not None:
pos_embed = self.spatial_pos_embed.view(-1, C) # (H*W, C)
if unmasked_ids is None:
                pos_embed = pos_embed.view(1, L, C)  # if not provided, L must equal H*W
else:
pos_embed = pos_embed[unmasked_ids] # (T, L, C)
visual_embeds = visual_embeds + pos_embed
        # all samples in this batch have the same number of frames
if len(set(split_sizes)) == 1:
visual_embeds = visual_embeds.view(len(split_sizes), split_sizes[0], L, C)
output = self.batch_forward_no_spatial(visual_embeds)
else:
visual_embeds = visual_embeds.split(split_sizes, dim=0)
            # group visual_embeds according to the number of frames so each group can be batched
output, group_visual_embeds = [None] * len(split_sizes), {}
for idx, (embed, t) in enumerate(zip(visual_embeds, split_sizes)):
if t in group_visual_embeds:
group_visual_embeds[t][0].append(idx)
group_visual_embeds[t][1].append(embed)
else:
group_visual_embeds[t] = [[idx], [embed]]
for idx, embeds in group_visual_embeds.values():
cur_output = self.batch_forward_no_spatial(torch.stack(embeds, dim=0))
for i, j in enumerate(idx):
output[j] = cur_output[i]
output = torch.stack(output, dim=0)
return output
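

# A minimal usage sketch (assumption: the shapes and default-constructed config below are illustrative
# and not part of the original file; spatial_resolution is assumed to match 256 patches per frame):
#
#   config = CCAMModel.config_class()                  # i.e. CCAMConfig() with its defaults
#   model = CCAMModel(config)
#   L, C = 256, config.encoder_hidden_size             # e.g. 256 patch tokens per frame
#   visual_embeds = torch.randn(8 + 4, L, C)           # two samples: an 8-frame and a 4-frame video
#   output = model(visual_embeds, split_sizes=[8, 4])
#   # output: (2, config.num_query_tokens, config.output_size)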