# Forge-EMB-mmclip/utils/models.py
from typing import Tuple, Union
from torch import nn
import torch
from .model_utils import VisualTransformer
from .configuration_bert import BertConfig
from .modeling_bert import BertModel
import numpy as np
from utils import _tokenizer


class CLIP(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 vocab_size: int,
                 text_attention_probs_dropout_prob: float,
                 text_hidden_act: str,
                 text_hidden_dropout_prob: float,
                 text_hidden_size: int,
                 text_initializer_range: float,
                 text_intermediate_size: int,
                 text_max_position_embeddings: int,
                 text_num_attention_heads: int,
                 text_num_hidden_layers: int,
                 text_type_vocab_size: int,
                 tokenizer=_tokenizer,
                 ):
        super().__init__()

        # Vision tower: a ViT with one attention head per 64 channels of width.
        vision_heads = vision_width // 64
        self.visual = VisualTransformer(
            input_resolution=image_resolution,
            patch_size=vision_patch_size,
            width=vision_width,
            layers=vision_layers,
            heads=vision_heads,
            output_dim=embed_dim,
        )

        # Text tower: a BERT encoder built from the given hyperparameters.
        self.bert_config = BertConfig(
            vocab_size_or_config_json_file=vocab_size,
            hidden_size=text_hidden_size,
            num_hidden_layers=text_num_hidden_layers,
            num_attention_heads=text_num_attention_heads,
            intermediate_size=text_intermediate_size,
            hidden_act=text_hidden_act,
            hidden_dropout_prob=text_hidden_dropout_prob,
            attention_probs_dropout_prob=text_attention_probs_dropout_prob,
            max_position_embeddings=text_max_position_embeddings,
            type_vocab_size=text_type_vocab_size,
            initializer_range=text_initializer_range,
            layer_norm_eps=1e-12,
        )
        self.bert = BertModel(self.bert_config)

        # Projection from BERT hidden states into the shared image-text embedding space.
        self.text_projection = nn.Parameter(
            torch.empty(text_hidden_size, embed_dim))
        # Learnable temperature, initialized to 1/0.07 as in the original CLIP.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.tokenizer = tokenizer

        # loss
        # self.cl_head = CLIPLoss_withMask_withmultimodal()

        self.initialize_parameters()

    def initialize_parameters(self):
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        if self.text_projection is not None:
            nn.init.normal_(self.text_projection,
                            std=self.bert_config.hidden_size ** -0.5)

    @property
    def dtype(self):
        return self.visual.conv1.weight.dtype

    def encode_image(self, image):
        return self.visual(image.type(self.dtype))

    def encode_text(self, text):
        # Mask out [PAD] tokens so BERT attention ignores padding positions.
        pad_index = self.tokenizer.vocab['[PAD]']
        attn_mask = text.ne(pad_index).type(self.dtype)
        x = self.bert(text, attention_mask=attn_mask)[0].type(
            self.dtype)  # [batch_size, seq_length, hidden_size]
        x = x @ self.text_projection
        # Return the projected [CLS] embedding and the full projected sequence.
        return x[:, 0, :], x

    def forward(self, image, text):
        assert image is not None or text is not None, "text and image cannot both be None!"
        if image is None:
            return self.encode_text(text)
        elif text is None:
            return self.encode_image(image)
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)
        features = {'image_embed': image_features,
                    'text_embed': text_features,
                    'logit_scale': self.logit_scale.exp()}
        # The contrastive-loss head (self.cl_head) is commented out in __init__,
        # so return the raw features instead of calling a missing attribute.
        return features
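
# Illustrative usage of CLIP (a minimal sketch; the hyperparameters below are
# placeholder assumptions, not the released configuration):
#
#   model = CLIP(embed_dim=512,
#                image_resolution=224, vision_layers=12, vision_width=768,
#                vision_patch_size=16,
#                vocab_size=21128,
#                text_attention_probs_dropout_prob=0.1, text_hidden_act="gelu",
#                text_hidden_dropout_prob=0.1, text_hidden_size=768,
#                text_initializer_range=0.02, text_intermediate_size=3072,
#                text_max_position_embeddings=512, text_num_attention_heads=12,
#                text_num_hidden_layers=12, text_type_vocab_size=2)
#   image_embed = model.encode_image(images)           # [batch, embed_dim]
#   cls_embed, token_embeds = model.encode_text(ids)   # [batch, embed_dim], [batch, seq_len, embed_dim]
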

class MergeLayer(nn.Module):
    def __init__(self, d_model, nhead):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=nhead, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(nn.Linear(d_model, d_model * 4), nn.GELU(), nn.Linear(d_model * 4, d_model))
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, query, visual_feats, text_feats, key_padding_mask=None):
        # Condition the query on the [CLS] embeddings of both modalities.
        visual_cls = visual_feats[:, 0]
        text_cls = text_feats[:, 0]
        query = query + visual_cls.unsqueeze(1) + text_cls.unsqueeze(1)
        kv = torch.cat([visual_feats, text_feats], dim=1)
        # Cross-attention over the concatenated visual and text tokens.
        attn_output, _ = self.cross_attn(query=query, key=kv, value=kv, key_padding_mask=key_padding_mask)
        # Residual connection + layer normalization.
        query = self.norm1(query + attn_output)
        # Feed-forward network.
        ffn_output = self.ffn(query)
        query = self.norm2(query + ffn_output)
        return query

class Modality_Mergerv3(nn.Module):
    def __init__(self, d_model: int = 512, d_output: int = 256, nhead: int = 8, layer_num: int = 3):
        super().__init__()
        # Learnable query token that aggregates the multimodal features.
        self.q = nn.Parameter(torch.randn(1, 1, d_model))
        self.layers = nn.ModuleList([
            MergeLayer(d_model, nhead) for _ in range(layer_num)
        ])
        # Projection heads emitting the merged embedding at several output widths.
        self.proj_d512 = nn.Linear(d_model, 512)
        self.proj_d256 = nn.Linear(d_model, 256)
        self.proj_d128 = nn.Linear(d_model, 128)
        self.proj_d64 = nn.Linear(d_model, 64)
        self.proj_d32 = nn.Linear(d_model, 32)

    def forward(self, visual_feats,
                text_feats, text_mask,
                cate_feats, cate_mask,
                c2c_feats,
                ):
        bs = visual_feats.size(0)
        # All-ones masks for the visual and c2c features (assumed to contain no padding).
        visual_mask = torch.ones(visual_feats.size()[:2], dtype=text_mask.dtype, device=text_mask.device)
        c2c_mask = torch.ones(c2c_feats.size()[:2], dtype=text_mask.dtype, device=text_mask.device)
        key_padding_mask = torch.cat([visual_mask, text_mask, cate_mask, c2c_mask], dim=1)
        # nn.MultiheadAttention expects True at positions that should be ignored.
        key_padding_mask = key_padding_mask == 0
        # Text, category, and c2c features form one key/value sequence alongside the visual tokens.
        text_feats = torch.cat([text_feats, cate_feats, c2c_feats], dim=1)
        query = self.q.expand(bs, -1, -1)
        for layer in self.layers:
            query = layer(query, visual_feats, text_feats, key_padding_mask)
        query = query.squeeze(1)
        mm_features_d512 = self.proj_d512(query)
        mm_features_d256 = self.proj_d256(query)
        mm_features_d128 = self.proj_d128(query)
        mm_features_d64 = self.proj_d64(query)
        mm_features_d32 = self.proj_d32(query)
        return mm_features_d512, mm_features_d256, mm_features_d128, mm_features_d64, mm_features_d32
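

# Minimal shape-check sketch for the merger (illustrative only: the feature dims
# and sequence lengths below are assumptions, not the original training setup;
# run as `python -m utils.models` from the repo root so the imports above resolve).
if __name__ == "__main__":
    merger = Modality_Mergerv3(d_model=512, nhead=8, layer_num=3)
    bs = 2
    visual_feats = torch.randn(bs, 50, 512)           # e.g. ViT patch tokens incl. [CLS]
    text_feats = torch.randn(bs, 32, 512)
    text_mask = torch.ones(bs, 32, dtype=torch.long)  # 1 = real token, 0 = padding
    cate_feats = torch.randn(bs, 8, 512)
    cate_mask = torch.ones(bs, 8, dtype=torch.long)
    c2c_feats = torch.randn(bs, 4, 512)
    outs = merger(visual_feats, text_feats, text_mask, cate_feats, cate_mask, c2c_feats)
    print([tuple(o.shape) for o in outs])             # (2, 512), (2, 256), (2, 128), (2, 64), (2, 32)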