|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from transformers import PreTrainedModel |
|
|
from transformers.activations import ACT2FN |
|
|
from typing import Optional |
|
|
from torch import Tensor |
|
|
from torch.nn import functional as F |
|
|
|
|
|
from .configuration_projector import ProjectorConfig_OMG_LLaVA |
|
|
|
|
|
class Naive_Proj(nn.Module):
    """Projector mapping OMG-Seg object queries and CLIP pixel features into
    the LLM embedding space.

    Two parallel MLPs are built: ``self.model`` projects the object queries
    (after ``self.query_proj``) and ``self.model_feat`` projects the CLIP
    feature map. Optionally (``config.add_cross_attn_layer``) a cross
    attention + FFN block lets the projected queries attend to the projected
    CLIP features. A second, independently weighted set of layers
    (``visual_prompt_*``) handles visual-prompt embeddings when
    ``config.visual_prompt_proj`` is set.

    Args:
        config: Projector config providing ``query_channels``,
            ``feat_channels``, ``llm_hidden_size``, ``depth``, ``bias``,
            ``hidden_act``, ``add_cross_attn_layer``, ``visual_prompt_proj``.
        rm_prior_embedding: If True, zero out the last 512 channels of the
            incoming CLIP feature in :meth:`forward` (the prior-embedding
            slice) while keeping tensor shapes unchanged.
        rm_query: If True, exclude the query tokens from the returned
            sequence (zero-weighted terms keep their parameters in the
            autograd graph).
    """

    def __init__(self, config, rm_prior_embedding=False,
                 rm_query=False):
        super().__init__()
        query_channels = config.query_channels
        # Keep the raw config value; the local copies below may be unpacked.
        self.query_channels = query_channels
        feat_channels = config.feat_channels
        # Configs sometimes carry (channels,) tuples — unwrap to the scalar.
        if isinstance(query_channels, tuple):
            query_channels = query_channels[0]
        if isinstance(feat_channels, tuple):
            feat_channels = feat_channels[0]

        add_cross_attn_layer = config.add_cross_attn_layer
        self.add_cross_attn_layer = config.add_cross_attn_layer

        # NOTE(review): incoming query features are expected at twice
        # config.query_channels — presumably the query concatenated with a
        # prior embedding; confirm against the producer of `query_feat`.
        query_channels = query_channels * 2

        # Projects raw queries into the CLIP feature width before the MLP.
        self.query_proj = nn.Linear(query_channels, feat_channels)

        # MLP for object queries: Linear, then (depth-1) x [act, Linear].
        modules = [
            nn.Linear(
                feat_channels,
                config.llm_hidden_size,
                bias=config.bias)
        ]
        for _ in range(1, config.depth):
            modules.append(ACT2FN[config.hidden_act])
            modules.append(
                nn.Linear(
                    config.llm_hidden_size,
                    config.llm_hidden_size,
                    bias=config.bias))
        self.model = nn.Sequential(*modules)

        if add_cross_attn_layer:
            print("Using Cross Attention Layer at Projector !!!")
            # Queries attend to the projected CLIP features in forward().
            self.query_cross_attn = CrossAttentionLayer(
                d_model=config.llm_hidden_size,
                nhead=32,
            )
            self.query_ffn = FFNLayer(
                d_model=config.llm_hidden_size,
                dim_feedforward=4096,
            )
        else:
            self.query_cross_attn = None
            self.query_ffn = None

        # MLP for the CLIP feature map. Its input width is
        # feat_channels + query_channels: the incoming clip_feature carries
        # both pixel features and a prior query embedding per position.
        modules = [
            nn.Linear(
                feat_channels + query_channels,
                config.llm_hidden_size,
                bias=config.bias)
        ]
        for _ in range(1, config.depth):
            modules.append(ACT2FN[config.hidden_act])
            modules.append(
                nn.Linear(
                    config.llm_hidden_size,
                    config.llm_hidden_size,
                    bias=config.bias))
        self.model_feat = nn.Sequential(*modules)

        # Single learned token inserted between CLIP tokens and query tokens
        # in the output sequence (attribute name "seperate" kept as-is: it is
        # part of saved checkpoints).
        self.seperate_embed = nn.Embedding(1, config.llm_hidden_size)

        self.rm_prior_embedding = rm_prior_embedding
        self.rm_query = rm_query

        visual_prompt_proj = config.visual_prompt_proj
        self.visual_prompt_proj = visual_prompt_proj
        if not visual_prompt_proj:
            self.visual_prompt_query_proj = None
            self.visual_prompt_query_model = None
            self.visual_prompt_query_cross_attn = None
            self.visual_prompt_query_ffn = None
        else:
            print("Initialized all Layers for Visual Prompt in Projector !!!")
            # Mirror of the query path with separate weights; can be synced
            # from the main path via init_visual_prompt_weights().
            self.visual_prompt_query_proj = nn.Linear(query_channels, feat_channels)
            modules = [
                nn.Linear(
                    feat_channels,
                    config.llm_hidden_size,
                    bias=config.bias)
            ]
            for _ in range(1, config.depth):
                modules.append(ACT2FN[config.hidden_act])
                modules.append(
                    nn.Linear(
                        config.llm_hidden_size,
                        config.llm_hidden_size,
                        bias=config.bias))
            self.visual_prompt_query_model = nn.Sequential(*modules)

            if add_cross_attn_layer:
                self.visual_prompt_query_cross_attn = CrossAttentionLayer(
                    d_model=config.llm_hidden_size,
                    nhead=32,
                )
                self.visual_prompt_query_ffn = FFNLayer(
                    d_model=config.llm_hidden_size,
                    dim_feedforward=4096,
                )
            else:
                self.visual_prompt_query_cross_attn = None
                self.visual_prompt_query_ffn = None

    def forward(self, x):
        """Project one batch into LLM token embeddings.

        Args:
            x: Tuple ``(clip_feature, query_feat, attention_mask)``.
                ``clip_feature``: per-image CLIP features whose trailing 512
                channels hold a prior query embedding (established by the
                ``[..., -512:]`` slice below); assumed shape
                (bs, seq, feat_channels + query_channels) — TODO confirm.
                ``query_feat``: object queries, last dim = 2*query_channels.
                ``attention_mask``: per-query masks; a query is kept only
                when its mask row is NOT entirely ones (see ``valid_mask``).

        Returns:
            list[Tensor]: one per image — CLIP tokens, the separator token
            and that image's valid query tokens concatenated along dim 0
            (queries omitted when ``self.rm_query``).

        Side effects: caches ``self.valid_queries_embeddings`` and
        ``self.last_clip_feature`` (consumed later by
        :meth:`forward_visual_prompts_embeddings`) and sets
        ``self.visual_prompt_zero``.
        """
        clip_feature, query_feat, attention_mask = x
        # First image's first query, kept at its ORIGINAL width for the
        # visual-prompt branch below (query_feat is overwritten next).
        query_feat_copy = query_feat[0, :1]

        if self.rm_prior_embedding:
            # Zero the prior-embedding slice but keep the tensor layout, so
            # model_feat's input width is unchanged.
            clip_feature_feat = clip_feature[:, :, :-512]
            clip_feature_query = clip_feature[:, :, -512:] * 0.0
            clip_feature = torch.cat([clip_feature_feat, clip_feature_query], dim=-1)

        query_feat = self.query_proj(query_feat)

        # A query is valid unless its mask row is all ones (fully masked).
        valid_mask = attention_mask.sum(dim=-1) < attention_mask.shape[-1]

        # NOTE(review): bs and n_q are computed but never used.
        bs, n_q = query_feat.shape[:2]

        layer_outputs = self.model(query_feat)

        clip_feature_out = clip_feature
        clip_feature_out = self.model_feat(clip_feature_out)
        ret = []

        # Cache per-image valid query embeddings for external consumers.
        valid_queries_embeddings = []
        for layer_output, keep in zip(layer_outputs, valid_mask):
            valid_queries_embeddings.append(layer_output[keep])
        self.valid_queries_embeddings = valid_queries_embeddings

        # Cached for forward_visual_prompts_embeddings(); assumes that method
        # is only called after this forward pass.
        self.last_clip_feature = clip_feature_out

        for clip_feat, layer_output, keep in zip(clip_feature_out, layer_outputs, valid_mask):
            valid_layer_output = layer_output[keep]
            if self.add_cross_attn_layer:
                # (n_valid, 1, C) queries attend to (seq, 1, C) CLIP tokens.
                valid_layer_output = self.query_cross_attn(
                    valid_layer_output.unsqueeze(1), clip_feat.unsqueeze(1),
                )[:, 0]
                valid_layer_output = self.query_ffn(valid_layer_output)
            if self.rm_query:
                # Queries dropped, but the *0.0 terms keep seperate_embed and
                # the query path in the autograd graph — appears to be a
                # DDP unused-parameter workaround; confirm.
                ret.append(clip_feat + torch.mean(self.seperate_embed.weight) * 0.0 + torch.mean(valid_layer_output) * 0.0)
            else:
                # Sequence layout: [clip tokens | separator | query tokens].
                ret.append(torch.cat([clip_feat, self.seperate_embed.weight, valid_layer_output], dim=0))

        if self.visual_prompt_proj:
            # Run a dummy visual-prompt pass on one query so the
            # visual-prompt parameters participate in this forward.
            visual_prompt_embeddings = query_feat_copy.to(self.visual_prompt_query_proj.weight.dtype)
            visual_prompt_embeddings = self.visual_prompt_query_proj(visual_prompt_embeddings)
            visual_prompt_embeddings = self.visual_prompt_query_model(visual_prompt_embeddings)
            if self.add_cross_attn_layer:
                clip_feat = self.last_clip_feature[0]
                visual_prompt_embeddings = self.visual_prompt_query_cross_attn(
                    visual_prompt_embeddings.unsqueeze(1), clip_feat.unsqueeze(1),
                )[:, 0]
                visual_prompt_embeddings = self.visual_prompt_query_ffn(visual_prompt_embeddings)
            # Zero-valued scalar that still carries gradient dependencies.
            self.visual_prompt_zero = visual_prompt_embeddings.sum() * 0.0
        else:
            self.visual_prompt_zero = 0.0
        return ret

    def forward_visual_prompts_embeddings(self, visual_prompt_embeddings, batch_idxs):
        """Project visual-prompt embeddings into the LLM space.

        Must be called after :meth:`forward` — it reads the cached
        ``self.last_clip_feature``. Uses the dedicated ``visual_prompt_*``
        layers when available, otherwise falls back to the main query path.

        Args:
            visual_prompt_embeddings: prompt embeddings, last dim =
                2*query_channels (same width as ``query_feat``).
            batch_idxs: per-prompt image index into the cached CLIP features.

        Returns:
            Tensor: projected prompt embeddings of width llm_hidden_size.
        """
        if self.visual_prompt_proj:
            visual_prompt_embeddings = visual_prompt_embeddings.to(self.visual_prompt_query_proj.weight.dtype)
            visual_prompt_embeddings = self.visual_prompt_query_proj(visual_prompt_embeddings)
            visual_prompt_embeddings = self.visual_prompt_query_model(visual_prompt_embeddings)
            if self.add_cross_attn_layer:
                # Gather each prompt's image features; permute to the
                # (seq, batch, C) layout nn.MultiheadAttention expects.
                clip_feat = self.last_clip_feature[batch_idxs].permute(1, 0, 2)
                visual_prompt_embeddings = self.visual_prompt_query_cross_attn(
                    visual_prompt_embeddings.unsqueeze(0), clip_feat,
                )[0, :]
                visual_prompt_embeddings = self.visual_prompt_query_ffn(visual_prompt_embeddings)
        else:
            # Fallback: reuse the main query projection layers.
            visual_prompt_embeddings = visual_prompt_embeddings.to(self.query_proj.weight.dtype)
            visual_prompt_embeddings = self.query_proj(visual_prompt_embeddings)
            visual_prompt_embeddings = self.model(visual_prompt_embeddings)
            if self.add_cross_attn_layer:
                clip_feat = self.last_clip_feature[batch_idxs].permute(1, 0, 2)
                visual_prompt_embeddings = self.query_cross_attn(
                    visual_prompt_embeddings.unsqueeze(0), clip_feat,
                )[0, :]
                visual_prompt_embeddings = self.query_ffn(visual_prompt_embeddings)
        return visual_prompt_embeddings

    def init_visual_prompt_weights(self):
        """Copy the main query-path weights into the visual-prompt layers.

        Gives the visual-prompt branch the same starting point as the
        (presumably pretrained) query branch. No-op for layers that were not
        created.
        """
        if self.visual_prompt_query_proj is not None:
            self.visual_prompt_query_proj.load_state_dict(self.query_proj.state_dict())
        if self.visual_prompt_query_model is not None:
            self.visual_prompt_query_model.load_state_dict(self.model.state_dict())
        if self.visual_prompt_query_cross_attn is not None:
            self.visual_prompt_query_cross_attn.load_state_dict(
                self.query_cross_attn.state_dict()
            )
        if self.visual_prompt_query_ffn is not None:
            self.visual_prompt_query_ffn.load_state_dict(
                self.query_ffn.state_dict()
            )
        return
|
|
|
|
|
class ProjectorModel_OMG_LLaVA(PreTrainedModel):
    """HuggingFace wrapper around :class:`Naive_Proj`.

    Exposes the projector as a ``PreTrainedModel`` so it can be saved/loaded
    via AutoModel and participate in gradient checkpointing.
    """
    _auto_class = 'AutoModel'
    config_class = ProjectorConfig_OMG_LLaVA
    base_model_prefix = 'model'
    supports_gradient_checkpointing = True

    def __init__(self, config: ProjectorConfig_OMG_LLaVA) -> None:
        super().__init__(config)
        self.gradient_checkpointing = False

        self.rm_prior_embedding = False
        self.rm_query = False
        # Fix: forward the rm_* flags instead of silently dropping them.
        # Both default to False, so the constructed module is unchanged,
        # but the attributes above are no longer dead state.
        self.model = Naive_Proj(
            config,
            rm_prior_embedding=self.rm_prior_embedding,
            rm_query=self.rm_query,
        )

    def enable_input_require_grads(self):
        """Make the projector's outputs require grad (e.g. for LoRA/PEFT
        training where the projector itself is frozen)."""

        def make_inputs_require_grad(module, input, output):
            if isinstance(output, torch.Tensor):
                output.requires_grad_(True)
            else:
                # Naive_Proj.forward returns a list of tensors.
                for item in output:
                    item.requires_grad_(True)

        self.model.register_forward_hook(make_inputs_require_grad)

    def _set_gradient_checkpointing(self, module, value=False):
        # Bug fix: the original tested isinstance(module,
        # ProjectorConfig_OMG_LLaVA). This hook is invoked with *modules*,
        # never with a config instance, so the flag could never be set and
        # gradient checkpointing was silently a no-op. Check for the wrapper
        # model itself, whose forward() reads `gradient_checkpointing`.
        if isinstance(module, ProjectorModel_OMG_LLaVA):
            module.gradient_checkpointing = value

    def forward(self, x):
        """Run the projector on ``x`` = (clip_feature, query_feat,
        attention_mask), optionally under activation checkpointing.

        Returns whatever :meth:`Naive_Proj.forward` returns (a list of
        per-image token-embedding tensors).
        """
        if self.gradient_checkpointing and self.training:
            layer_outputs = torch.utils.checkpoint.checkpoint(self.model, x)
        else:
            layer_outputs = self.model(x)
        return layer_outputs
|
|
|
|
|
|
|
|
class SelfAttentionLayer(nn.Module):
    """Transformer self-attention sub-layer with a residual connection.

    Supports both orderings: post-norm (attend, add residual, then
    LayerNorm — the default) and pre-norm (LayerNorm first, attend, add
    residual) selected by ``normalize_before``. Inputs follow
    ``nn.MultiheadAttention``'s (seq, batch, d_model) layout.
    """

    def __init__(self, d_model, nhead, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before
        self._reset_parameters()

    def _reset_parameters(self):
        # Xavier-initialize every weight matrix; biases (1-D) are untouched.
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        # Add positional encodings when provided.
        if pos is None:
            return tensor
        return tensor + pos

    def forward_post(self, tgt,
                     tgt_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        # Post-norm order: attend -> dropout -> residual add -> LayerNorm.
        # Positions go into queries/keys only; values stay position-free.
        qk = self.with_pos_embed(tgt, query_pos)
        attn_out, _ = self.self_attn(qk, qk, value=tgt, attn_mask=tgt_mask,
                                     key_padding_mask=tgt_key_padding_mask)
        return self.norm(tgt + self.dropout(attn_out))

    def forward_pre(self, tgt,
                    tgt_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        # Pre-norm order: LayerNorm -> attend -> dropout -> residual add.
        normed = self.norm(tgt)
        qk = self.with_pos_embed(normed, query_pos)
        attn_out, _ = self.self_attn(qk, qk, value=normed, attn_mask=tgt_mask,
                                     key_padding_mask=tgt_key_padding_mask)
        return tgt + self.dropout(attn_out)

    def forward(self, tgt,
                tgt_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        branch = self.forward_pre if self.normalize_before else self.forward_post
        return branch(tgt, tgt_mask, tgt_key_padding_mask, query_pos)
|
|
|
|
|
|
|
|
class CrossAttentionLayer(nn.Module):
    """Transformer cross-attention sub-layer with a residual connection.

    ``tgt`` attends to ``memory``; post-norm (default) or pre-norm ordering
    is selected by ``normalize_before``. Inputs follow
    ``nn.MultiheadAttention``'s (seq, batch, d_model) layout.
    """

    def __init__(self, d_model, nhead, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before
        self._reset_parameters()

    def _reset_parameters(self):
        # Xavier-initialize every weight matrix; biases (1-D) are untouched.
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        # Add positional encodings when provided.
        if pos is None:
            return tensor
        return tensor + pos

    def forward_post(self, tgt, memory,
                     memory_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        # Post-norm order: attend -> dropout -> residual add -> LayerNorm.
        # query_pos positions the queries; pos positions the memory keys.
        attn_out, _ = self.multihead_attn(
            query=self.with_pos_embed(tgt, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory, attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask)
        return self.norm(tgt + self.dropout(attn_out))

    def forward_pre(self, tgt, memory,
                    memory_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        # Pre-norm order: LayerNorm on tgt only -> attend -> residual add.
        normed = self.norm(tgt)
        attn_out, _ = self.multihead_attn(
            query=self.with_pos_embed(normed, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory, attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask)
        return tgt + self.dropout(attn_out)

    def forward(self, tgt, memory,
                memory_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        branch = self.forward_pre if self.normalize_before else self.forward_post
        return branch(tgt, memory, memory_mask,
                      memory_key_padding_mask, pos, query_pos)
|
|
|
|
|
|
|
|
class FFNLayer(nn.Module):
    """Transformer feed-forward sub-layer (Linear -> act -> Linear) with a
    residual connection; post-norm (default) or pre-norm ordering is
    selected by ``normalize_before``.
    """

    def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before
        self._reset_parameters()

    def _reset_parameters(self):
        # Xavier-initialize every weight matrix; biases (1-D) are untouched.
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        # Add positional encodings when provided.
        if pos is None:
            return tensor
        return tensor + pos

    def _mlp(self, x):
        # Expand -> activation -> dropout -> contract back to d_model.
        return self.linear2(self.dropout(self.activation(self.linear1(x))))

    def forward_post(self, tgt):
        # Post-norm order: MLP -> dropout -> residual add -> LayerNorm.
        return self.norm(tgt + self.dropout(self._mlp(tgt)))

    def forward_pre(self, tgt):
        # Pre-norm order: LayerNorm -> MLP -> dropout -> residual add.
        return tgt + self.dropout(self._mlp(self.norm(tgt)))

    def forward(self, tgt):
        if self.normalize_before:
            return self.forward_pre(tgt)
        return self.forward_post(tgt)
|
|
|
|
|
def _get_activation_fn(activation): |
|
|
"""Return an activation function given a string""" |
|
|
if activation == "relu": |
|
|
return F.relu |
|
|
if activation == "gelu": |
|
|
return F.gelu |
|
|
if activation == "glu": |
|
|
return F.glu |
|
|
raise RuntimeError(F"activation should be relu/gelu, not {activation}.") |