# Copyright (c) OpenMMLab. All rights reserved.
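"""Projector for OMG-LLaVA.

`Naive_Proj` maps decoder queries (`query_feat`) and CLIP pixel features
(`clip_feature`) into the LLM embedding space; `ProjectorModel_OMG_LLaVA`
wraps it as a HuggingFace `PreTrainedModel`. The attention/FFN layers at the
bottom of the file are DETR/Mask2Former-style transformer blocks.
"""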
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.activations import ACT2FN
from typing import Optional
from torch import Tensor
from torch.nn import functional as F
from .configuration_projector import ProjectorConfig_OMG_LLaVA


class Naive_Proj(nn.Module):
def __init__(self, config, rm_prior_embedding=False,
rm_query=False):
super().__init__()
query_channels = config.query_channels
self.query_channels = query_channels
feat_channels = config.feat_channels
if isinstance(query_channels, tuple):
query_channels = query_channels[0]
if isinstance(feat_channels, tuple):
feat_channels = feat_channels[0]
add_cross_attn_layer = config.add_cross_attn_layer
        self.add_cross_attn_layer = add_cross_attn_layer
query_channels = query_channels * 2 # feat + embed
self.query_proj = nn.Linear(query_channels, feat_channels)
modules = [
nn.Linear(
feat_channels,
config.llm_hidden_size,
bias=config.bias)
]
for _ in range(1, config.depth):
modules.append(ACT2FN[config.hidden_act])
modules.append(
nn.Linear(
config.llm_hidden_size,
config.llm_hidden_size,
bias=config.bias))
self.model = nn.Sequential(*modules)
if add_cross_attn_layer:
            print("Using cross-attention layer in the projector.")
self.query_cross_attn = CrossAttentionLayer(
d_model=config.llm_hidden_size,
nhead=32,
)
self.query_ffn = FFNLayer(
d_model=config.llm_hidden_size,
dim_feedforward=4096,
)
else:
self.query_cross_attn = None
self.query_ffn = None
modules = [
nn.Linear(
feat_channels + query_channels,
config.llm_hidden_size,
bias=config.bias)
]
for _ in range(1, config.depth):
modules.append(ACT2FN[config.hidden_act])
modules.append(
nn.Linear(
config.llm_hidden_size,
config.llm_hidden_size,
bias=config.bias))
self.model_feat = nn.Sequential(*modules)
self.seperate_embed = nn.Embedding(1, config.llm_hidden_size)
self.rm_prior_embedding = rm_prior_embedding
self.rm_query = rm_query
visual_prompt_proj = config.visual_prompt_proj
self.visual_prompt_proj = visual_prompt_proj
if not visual_prompt_proj:
self.visual_prompt_query_proj = None
self.visual_prompt_query_model = None
self.visual_prompt_query_cross_attn = None
self.visual_prompt_query_ffn = None
else:
            print("Initialized all visual-prompt layers in the projector.")
self.visual_prompt_query_proj = nn.Linear(query_channels, feat_channels)
modules = [
nn.Linear(
feat_channels,
config.llm_hidden_size,
bias=config.bias)
]
for _ in range(1, config.depth):
modules.append(ACT2FN[config.hidden_act])
modules.append(
nn.Linear(
config.llm_hidden_size,
config.llm_hidden_size,
bias=config.bias))
self.visual_prompt_query_model = nn.Sequential(*modules)
if add_cross_attn_layer:
self.visual_prompt_query_cross_attn = CrossAttentionLayer(
d_model=config.llm_hidden_size,
nhead=32,
)
self.visual_prompt_query_ffn = FFNLayer(
d_model=config.llm_hidden_size,
dim_feedforward=4096,
)
else:
self.visual_prompt_query_cross_attn = None
self.visual_prompt_query_ffn = None
def forward(self, x):
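        """Project (clip_feature, query_feat, attention_mask) into the LLM
        embedding space.

        Returns a list with one tensor per image: projected pixel features,
        a learned separator embedding, and the projected valid query
        embeddings (pixel features only when `rm_query` is set), i.e. shape
        (hw + 1 + n_valid_queries, llm_hidden_size).
        """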
clip_feature, query_feat, attention_mask = x
        query_feat_copy = query_feat[0, :1]  # (1, c)
# clip feature (bs, hw, c + 2 * q_c)
# query_feat (bs, q, c)
# attention_mask (bs, q, hw)
        if self.rm_prior_embedding:
            # zero the trailing 512 prior-embedding channels, keeping the shape
            clip_feature_feat = clip_feature[:, :, :-512]
            clip_feature_query = clip_feature[:, :, -512:] * 0.0
            clip_feature = torch.cat([clip_feature_feat, clip_feature_query], dim=-1)
query_feat = self.query_proj(query_feat)
valid_mask = attention_mask.sum(dim=-1) < attention_mask.shape[-1] # (bs, q)
# valid_mask # (bs, q)
# query_feat (bs, q, c)
# clip_feature (bs, hw, c)
# attn_map (bs, q, hw)
bs, n_q = query_feat.shape[:2]
layer_outputs = self.model(query_feat)
        # project pixel features into the LLM embedding space
        clip_feature_out = self.model_feat(clip_feature)
ret = []
valid_queries_embeddings = []
for layer_output, keep in zip(layer_outputs, valid_mask):
valid_queries_embeddings.append(layer_output[keep])
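        # cache intermediates; last_clip_feature is reused by
        # forward_visual_prompts_embeddings() below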
self.valid_queries_embeddings = valid_queries_embeddings
self.last_clip_feature = clip_feature_out
for clip_feat, layer_output, keep in zip(clip_feature_out, layer_outputs, valid_mask):
valid_layer_output = layer_output[keep]
if self.add_cross_attn_layer:
valid_layer_output = self.query_cross_attn(
valid_layer_output.unsqueeze(1), clip_feat.unsqueeze(1),
)[:, 0]
valid_layer_output = self.query_ffn(valid_layer_output)
            if self.rm_query:
                # no query tokens are emitted; the zero-valued terms keep the
                # separator embedding and query path in the autograd graph
                ret.append(clip_feat
                           + torch.mean(self.seperate_embed.weight) * 0.0
                           + torch.mean(valid_layer_output) * 0.0)
else:
ret.append(torch.cat([clip_feat, self.seperate_embed.weight, valid_layer_output], dim=0))
        # when enabled, run the visual-prompt branch on a dummy query so its
        # parameters always receive (zero-valued) gradients
if self.visual_prompt_proj:
visual_prompt_embeddings = query_feat_copy.to(self.visual_prompt_query_proj.weight.dtype)
visual_prompt_embeddings = self.visual_prompt_query_proj(visual_prompt_embeddings)
visual_prompt_embeddings = self.visual_prompt_query_model(visual_prompt_embeddings) # (B, C)
if self.add_cross_attn_layer:
                clip_feat = self.last_clip_feature[0]  # (hw, c)
visual_prompt_embeddings = self.visual_prompt_query_cross_attn(
visual_prompt_embeddings.unsqueeze(1), clip_feat.unsqueeze(1),
)[:, 0]
visual_prompt_embeddings = self.visual_prompt_query_ffn(visual_prompt_embeddings)
self.visual_prompt_zero = visual_prompt_embeddings.sum() * 0.0
else:
self.visual_prompt_zero = 0.0
return ret
def forward_visual_prompts_embeddings(self, visual_prompt_embeddings, batch_idxs):
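        """Project visual-prompt embeddings for the images in `batch_idxs`,
        cross-attending to the CLIP features cached by the last forward()
        when enabled. Falls back to the main query projector when the
        dedicated visual-prompt branch is disabled.
        """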
if self.visual_prompt_proj:
visual_prompt_embeddings = visual_prompt_embeddings.to(self.visual_prompt_query_proj.weight.dtype)
visual_prompt_embeddings = self.visual_prompt_query_proj(visual_prompt_embeddings)
visual_prompt_embeddings = self.visual_prompt_query_model(visual_prompt_embeddings) # (B, C)
if self.add_cross_attn_layer:
                clip_feat = self.last_clip_feature[batch_idxs].permute(1, 0, 2)  # (hw, n_prompts, c)
visual_prompt_embeddings = self.visual_prompt_query_cross_attn(
visual_prompt_embeddings.unsqueeze(0), clip_feat,
)[0, :]
visual_prompt_embeddings = self.visual_prompt_query_ffn(visual_prompt_embeddings)
else:
visual_prompt_embeddings = visual_prompt_embeddings.to(self.query_proj.weight.dtype)
visual_prompt_embeddings = self.query_proj(visual_prompt_embeddings)
visual_prompt_embeddings = self.model(visual_prompt_embeddings) # (B, C)
if self.add_cross_attn_layer:
                clip_feat = self.last_clip_feature[batch_idxs].permute(1, 0, 2)  # (hw, n_prompts, c)
visual_prompt_embeddings = self.query_cross_attn(
visual_prompt_embeddings.unsqueeze(0), clip_feat,
)[0, :]
visual_prompt_embeddings = self.query_ffn(visual_prompt_embeddings)
return visual_prompt_embeddings
def init_visual_prompt_weights(self):
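        """Initialize the visual-prompt branch from the main projector weights."""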
if self.visual_prompt_query_proj is not None:
self.visual_prompt_query_proj.load_state_dict(self.query_proj.state_dict())
if self.visual_prompt_query_model is not None:
self.visual_prompt_query_model.load_state_dict(self.model.state_dict())
if self.visual_prompt_query_cross_attn is not None:
self.visual_prompt_query_cross_attn.load_state_dict(
self.query_cross_attn.state_dict()
)
if self.visual_prompt_query_ffn is not None:
self.visual_prompt_query_ffn.load_state_dict(
self.query_ffn.state_dict()
)
return
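

# HuggingFace wrapper so the projector can be registered with AutoModel and
# saved/loaded as part of the released checkpoint.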
class ProjectorModel_OMG_LLaVA(PreTrainedModel):
_auto_class = 'AutoModel'
config_class = ProjectorConfig_OMG_LLaVA
base_model_prefix = 'model'
supports_gradient_checkpointing = True
def __init__(self, config: ProjectorConfig_OMG_LLaVA) -> None:
super().__init__(config)
self.gradient_checkpointing = False
self.rm_prior_embedding = False
self.rm_query = False
        self.model = Naive_Proj(config)
def enable_input_require_grads(self):
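        """Force projector outputs to require grad, e.g. so gradient
        checkpointing can work when upstream weights are frozen."""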
def make_inputs_require_grad(module, input, output):
if isinstance(output, torch.Tensor):
output.requires_grad_(True)
else:
for item in output:
item.requires_grad_(True)
self.model.register_forward_hook(make_inputs_require_grad)
def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, ProjectorModel_OMG_LLaVA):
module.gradient_checkpointing = value
def forward(self, x):
if self.gradient_checkpointing and self.training:
layer_outputs = torch.utils.checkpoint.checkpoint(self.model, x)
else:
layer_outputs = self.model(x)
return layer_outputs
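

# DETR/Mask2Former-style transformer blocks (pre-/post-norm variants).
# SelfAttentionLayer is defined here but not used elsewhere in this file.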
class SelfAttentionLayer(nn.Module):
def __init__(self, d_model, nhead, dropout=0.0,
activation="relu", normalize_before=False):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.norm = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
self._reset_parameters()
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(self, tgt,
tgt_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
q = k = self.with_pos_embed(tgt, query_pos)
tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask)[0]
tgt = tgt + self.dropout(tgt2)
tgt = self.norm(tgt)
return tgt
def forward_pre(self, tgt,
tgt_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
tgt2 = self.norm(tgt)
q = k = self.with_pos_embed(tgt2, query_pos)
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask)[0]
tgt = tgt + self.dropout(tgt2)
return tgt
def forward(self, tgt,
tgt_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
if self.normalize_before:
return self.forward_pre(tgt, tgt_mask,
tgt_key_padding_mask, query_pos)
return self.forward_post(tgt, tgt_mask,
tgt_key_padding_mask, query_pos)


class CrossAttentionLayer(nn.Module):
def __init__(self, d_model, nhead, dropout=0.0,
activation="relu", normalize_before=False):
super().__init__()
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.norm = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
self._reset_parameters()
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(self, tgt, memory,
memory_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory, attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask)[0]
tgt = tgt + self.dropout(tgt2)
tgt = self.norm(tgt)
return tgt
def forward_pre(self, tgt, memory,
memory_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
tgt2 = self.norm(tgt)
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory, attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask)[0]
tgt = tgt + self.dropout(tgt2)
return tgt
def forward(self, tgt, memory,
memory_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
if self.normalize_before:
return self.forward_pre(tgt, memory, memory_mask,
memory_key_padding_mask, pos, query_pos)
return self.forward_post(tgt, memory, memory_mask,
memory_key_padding_mask, pos, query_pos)


class FFNLayer(nn.Module):
def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,
activation="relu", normalize_before=False):
super().__init__()
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm = nn.LayerNorm(d_model)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
self._reset_parameters()
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(self, tgt):
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout(tgt2)
tgt = self.norm(tgt)
return tgt
def forward_pre(self, tgt):
tgt2 = self.norm(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout(tgt2)
return tgt
def forward(self, tgt):
if self.normalize_before:
return self.forward_pre(tgt)
return self.forward_post(tgt)


def _get_activation_fn(activation):
    """Return an activation function given a string."""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
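

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the checkpoint). It
# assumes ProjectorConfig_OMG_LLaVA accepts the keyword fields read in
# Naive_Proj.__init__ above; the concrete values below are made up.
if __name__ == "__main__":
    config = ProjectorConfig_OMG_LLaVA(
        query_channels=256,        # doubled internally (feat + embed)
        feat_channels=1536,
        llm_hidden_size=4096,
        depth=2,
        hidden_act="gelu",
        bias=True,
        add_cross_attn_layer=True,
        visual_prompt_proj=False,
    )
    projector = ProjectorModel_OMG_LLaVA(config)

    bs, hw, n_q = 2, 576, 100
    clip_feature = torch.randn(bs, hw, 1536 + 2 * 256)  # pixel feats + prior embedding
    query_feat = torch.randn(bs, n_q, 2 * 256)          # per-query feat + embed channels
    attention_mask = torch.zeros(bs, n_q, hw, dtype=torch.bool)  # all queries valid

    outputs = projector((clip_feature, query_feat, attention_mask))
    print([tuple(o.shape) for o in outputs])  # (hw + 1 + n_q, llm_hidden_size) each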