from typing import Tuple, Union

import numpy as np
import torch
from torch import nn

from .model_utils import VisualTransformer
from .configuration_bert import BertConfig
from .modeling_bert import BertModel
from utils import _tokenizer
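# This module defines a CLIP-style two-tower model (ViT image encoder + BERT
# text encoder) and a learned-query cross-attention merger (Modality_Mergerv3)
# that fuses visual and textual features into embeddings of several sizes.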
|
|
class CLIP(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 vocab_size: int,
                 text_attention_probs_dropout_prob: float,
                 text_hidden_act: str,
                 text_hidden_dropout_prob: float,
                 text_hidden_size: int,
                 text_initializer_range: float,
                 text_intermediate_size: int,
                 text_max_position_embeddings: int,
                 text_num_attention_heads: int,
                 text_num_hidden_layers: int,
                 text_type_vocab_size: int,
                 tokenizer=_tokenizer,
                 ):
        super().__init__()
|
|
        # Vision tower: a ViT whose head count follows the usual width // 64 rule.
        vision_heads = vision_width // 64
        self.visual = VisualTransformer(
            input_resolution=image_resolution,
            patch_size=vision_patch_size,
            width=vision_width,
            layers=vision_layers,
            heads=vision_heads,
            output_dim=embed_dim
        )
|
|
        # Text tower: a BERT encoder configured from the text_* hyperparameters.
        self.bert_config = BertConfig(
            vocab_size_or_config_json_file=vocab_size,
            hidden_size=text_hidden_size,
            num_hidden_layers=text_num_hidden_layers,
            num_attention_heads=text_num_attention_heads,
            intermediate_size=text_intermediate_size,
            hidden_act=text_hidden_act,
            hidden_dropout_prob=text_hidden_dropout_prob,
            attention_probs_dropout_prob=text_attention_probs_dropout_prob,
            max_position_embeddings=text_max_position_embeddings,
            type_vocab_size=text_type_vocab_size,
            initializer_range=text_initializer_range,
            layer_norm_eps=1e-12,
        )
        self.bert = BertModel(self.bert_config)
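        # encode_text below takes index [0] of the BertModel output, which is
        # used as the per-token sequence output (batch, seq_len, hidden_size).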
|
|
        # Projection from the BERT hidden size into the shared embedding space,
        # and the learnable temperature, initialized to ln(1 / 0.07) as in CLIP
        # so that logit_scale.exp() starts at 1 / 0.07 ≈ 14.3.
        self.text_projection = nn.Parameter(
            torch.empty(text_hidden_size, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.tokenizer = tokenizer
|
|
        self.initialize_parameters()
|
|
    def initialize_parameters(self):
        # Re-create the temperature at its default and initialize the text
        # projection with std = hidden_size ** -0.5.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection,
                            std=self.bert_config.hidden_size ** -0.5)
|
|
    @property
    def dtype(self):
        # Follow the dtype of the visual stem so fp16/fp32 conversion of the
        # model propagates to the inputs of both towers.
        return self.visual.conv1.weight.dtype
|
|
    def encode_image(self, image):
        return self.visual(image.type(self.dtype))
|
|
    def encode_text(self, text):
        # Build the attention mask from the non-[PAD] positions of the token ids.
        pad_index = self.tokenizer.vocab['[PAD]']
        attn_mask = text.ne(pad_index).type(self.dtype)
        x = self.bert(text, attention_mask=attn_mask)[0].type(self.dtype)
        x = x @ self.text_projection
        # Return the projected [CLS] embedding (pooled) and the full projected sequence.
        return x[:, 0, :], x
|
|
    def forward(self, image, text):
        assert image is not None or text is not None, "text and image cannot both be None!"

        if image is None:
            return self.encode_text(text)
        elif text is None:
            return self.encode_image(image)
        image_features = self.encode_image(image)
        # Note: encode_text returns a (pooled, sequence) tuple.
        text_features = self.encode_text(text)

        features = {'image_embed': image_features,
                    'text_embed': text_features,
                    'logit_scale': self.logit_scale.exp()}

        # The contrastive head `cl_head` is not defined in this file; it is
        # assumed to be attached to the model elsewhere in the repo.
        ret = self.cl_head(features)

        return ret
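# Illustrative sketch only: since `cl_head` is not defined in this file, the
# helper below shows what a CLIP-style contrastive head typically computes from
# the `features` dict built in `forward` (normalized embeddings, temperature-
# scaled similarity logits). It is an assumption about the head's behaviour,
# not this repo's implementation, and is not used by the classes above.
def _contrastive_logits_sketch(features):
    image_embed = features['image_embed']          # (B, embed_dim)
    text_embed = features['text_embed'][0]         # pooled [CLS] embedding, (B, embed_dim)
    logit_scale = features['logit_scale']          # scalar, exp of the learned temperature
    image_embed = image_embed / image_embed.norm(dim=-1, keepdim=True)
    text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
    logits_per_image = logit_scale * image_embed @ text_embed.t()  # (B, B)
    logits_per_text = logits_per_image.t()
    return logits_per_image, logits_per_text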
|
|
|
|
class MergeLayer(nn.Module):
    def __init__(self, d_model, nhead):
        super().__init__()
        # Single cross-attention block: a query attends over concatenated
        # visual + textual tokens, followed by a feed-forward sub-layer, with
        # post-norm residual connections.
        self.cross_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=nhead, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(nn.Linear(d_model, d_model * 4), nn.GELU(), nn.Linear(d_model * 4, d_model))
        self.norm2 = nn.LayerNorm(d_model)
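    # Shape conventions assumed here (per the usage in Modality_Mergerv3 below):
    #   query:            (B, 1, d_model)   learned query token
    #   visual_feats:     (B, Nv, d_model)  visual tokens, [CLS]-like token at index 0
    #   text_feats:       (B, Nt, d_model)  concatenated textual tokens, [CLS] at index 0
    #   key_padding_mask: (B, Nv + Nt) bool, True marks padded key positions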
|
|
    def forward(self, query, visual_feats, text_feats, key_padding_mask=None):
        # Condition the query on the two [CLS] tokens, then cross-attend over
        # the full visual + textual token sequence.
        visual_cls = visual_feats[:, 0]
        text_cls = text_feats[:, 0]
        query = query + visual_cls.unsqueeze(1) + text_cls.unsqueeze(1)
        kv = torch.cat([visual_feats, text_feats], dim=1)

        attn_output, _ = self.cross_attn(query=query, key=kv, value=kv, key_padding_mask=key_padding_mask)
        query = self.norm1(query + attn_output)

        ffn_output = self.ffn(query)
        query = self.norm2(query + ffn_output)
        return query
|
|
|
|
class Modality_Mergerv3(nn.Module):
    def __init__(self, d_model: int = 512, d_output: int = 256, nhead: int = 8, layer_num: int = 3):
        super().__init__()
        # A single learned query token is refined by a stack of MergeLayers and
        # then projected to several embedding sizes. Note that `d_output` is not
        # used; the output sizes are fixed by the projection heads below.
        self.q = nn.Parameter(torch.randn(1, 1, d_model))
        self.layers = nn.ModuleList([
            MergeLayer(d_model, nhead) for _ in range(layer_num)
        ])
        self.proj_d512 = nn.Linear(d_model, 512)
        self.proj_d256 = nn.Linear(d_model, 256)
        self.proj_d128 = nn.Linear(d_model, 128)
        self.proj_d64 = nn.Linear(d_model, 64)
        self.proj_d32 = nn.Linear(d_model, 32)
|
|
    def forward(self, visual_feats,
                text_feats, text_mask,
                cate_feats, cate_mask,
                c2c_feats,
                ):
        bs = visual_feats.size(0)

        # Visual and c2c features get all-ones masks (treated as never padded);
        # the combined key padding mask is True where a key should be ignored.
        visual_mask = torch.ones(visual_feats.size()[:2], dtype=text_mask.dtype, device=text_mask.device)
        c2c_mask = torch.ones(c2c_feats.size()[:2], dtype=text_mask.dtype, device=text_mask.device)
        key_padding_mask = torch.cat([visual_mask, text_mask, cate_mask, c2c_mask], dim=1)
        key_padding_mask = key_padding_mask == 0

        # Concatenate the textual modalities in the same order as their masks.
        text_feats = torch.cat([text_feats, cate_feats, c2c_feats], dim=1)

        # Refine the learned query through the cross-attention stack.
        query = self.q.expand(bs, -1, -1)
        for layer in self.layers:
            query = layer(query, visual_feats, text_feats, key_padding_mask)

        query = query.squeeze(1)

        # Project the fused representation to multiple embedding sizes.
        mm_features_d512 = self.proj_d512(query)
        mm_features_d256 = self.proj_d256(query)
        mm_features_d128 = self.proj_d128(query)
        mm_features_d64 = self.proj_d64(query)
        mm_features_d32 = self.proj_d32(query)
        return mm_features_d512, mm_features_d256, mm_features_d128, mm_features_d64, mm_features_d32
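# Illustrative usage sketch (assumed shapes and sequence lengths, not part of
# the original code):
#
#   merger = Modality_Mergerv3(d_model=512, d_output=256, nhead=8, layer_num=3)
#   B, Nv, Nt, Nc, Nk = 4, 50, 32, 8, 8           # batch and per-modality lengths
#   visual_feats = torch.randn(B, Nv, 512)
#   text_feats, text_mask = torch.randn(B, Nt, 512), torch.ones(B, Nt, dtype=torch.long)
#   cate_feats, cate_mask = torch.randn(B, Nc, 512), torch.ones(B, Nc, dtype=torch.long)
#   c2c_feats = torch.randn(B, Nk, 512)
#   d512, d256, d128, d64, d32 = merger(visual_feats, text_feats, text_mask,
#                                       cate_feats, cate_mask, c2c_feats)
#   # d512: (B, 512), d256: (B, 256), ..., d32: (B, 32)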
|
|