| |
| |
| |
| |
| |
| |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from timm.models.registry import register_model |
| import numpy as np |
|
|
| import utils |
| from modeling_utils import BEiT3Wrapper, _get_base_config, _get_large_config |
|
|
|
|
class TwoLayerMLP(nn.Module):
    """Two-layer MLP head: (optional norm) -> linear -> norm -> GELU -> linear.

    Args:
        in_features: size of the input representation.
        hidden_features: size of the intermediate layer.
        out_features: size of the output (e.g. number of classes).
        norm_layer: normalization constructor (e.g. ``nn.LayerNorm``).
        norm_input: when False, skip normalization of the input.
    """

    def __init__(
        self,
        in_features,
        hidden_features,
        out_features,
        norm_layer,
        norm_input=True,
    ):
        super().__init__()
        # Attribute names are part of the public surface: callers reach into
        # ``dense1`` / ``dense2`` to rescale weights after init.
        self.norm1 = norm_layer(in_features) if norm_input else nn.Identity()
        self.dense1 = nn.Linear(in_features, hidden_features)
        self.norm2 = norm_layer(hidden_features)
        self.act = nn.GELU()
        self.dense2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        hidden = self.dense1(self.norm1(x))
        return self.dense2(self.act(self.norm2(hidden)))
|
|
|
|
class Pooler(nn.Module):
    """Pool a sequence by projecting its first (CLS) token through
    norm -> linear -> tanh, BERT-pooler style.

    Args:
        input_features: width of the incoming token representations.
        output_features: width of the pooled output.
        norm_layer: normalization constructor (e.g. ``nn.LayerNorm``).
    """

    def __init__(self, input_features, output_features, norm_layer):
        super().__init__()
        self.norm = norm_layer(input_features)
        self.dense = nn.Linear(input_features, output_features)
        self.activation = nn.Tanh()

    def forward(self, x):
        # Only the token at position 0 contributes to the pooled output.
        first_token = x[:, 0, :]
        return self.activation(self.dense(self.norm(first_token)))
|
|
|
|
class BEiT3ForVisualReasoning(BEiT3Wrapper):
    """BEiT-3 head for visual reasoning (e.g. NLVR2): two images plus one
    text description are encoded jointly and classified with a two-layer MLP.

    Args:
        args: BEiT-3 encoder configuration.
        num_classes: number of output classes.
        norm_layer: normalization constructor for the classifier head.
    """

    def __init__(
        self,
        args,
        num_classes,
        norm_layer=nn.LayerNorm,
        **kwargs
    ):
        super(BEiT3ForVisualReasoning, self).__init__(args=args)
        embed_dim = args.encoder_embed_dim
        # 4x: vision CLS + language CLS, concatenated for both images.
        self.head = TwoLayerMLP(
            in_features=embed_dim * 4,
            hidden_features=embed_dim * 2,
            out_features=num_classes,
            norm_layer=norm_layer,
        )
        self.head.apply(self._init_weights)
        # Shrink the freshly-initialized classifier weights so fine-tuning
        # starts from near-zero logits.
        init_scale = 0.001
        for linear in (self.head.dense1, self.head.dense2):
            if isinstance(linear, nn.Linear):
                linear.weight.data.mul_(init_scale)
                linear.bias.data.mul_(init_scale)

    def forward(self, image_a, image_b, text_description, padding_mask, **kwargs):
        bsz = text_description.size(0)

        # Encode both images in one pass by stacking them along the batch
        # dimension and duplicating the text / padding mask to match.
        outputs = self.beit3(
            textual_tokens=torch.cat((text_description, text_description), dim=0),
            visual_tokens=torch.cat((image_a, image_b), dim=0),
            text_padding_position=torch.cat((padding_mask, padding_mask), dim=0),
        )
        x = outputs["encoder_out"]
        split_pos = outputs["multiway_split_position"]

        # CLS of the vision segment (index 0) and of the language segment
        # (first index after the multiway split).
        cls_rep = torch.cat((x[:, 0, :], x[:, split_pos, :]), dim=-1)
        # Undo the batch stacking: first half is image_a, second is image_b.
        rep_a, rep_b = torch.split(cls_rep, split_size_or_sections=[bsz, bsz], dim=0)
        return self.head(torch.cat((rep_a, rep_b), dim=-1))
| |
|
|
class BEiT3ForImageClassification(BEiT3Wrapper):
    """BEiT-3 image classification head: mean-pool the patch tokens
    (excluding CLS), normalize, and apply a linear classifier.

    Args:
        args: BEiT-3 encoder configuration.
        num_classes: number of classes; 0 yields an identity head.
        norm_layer: normalization constructor for the pooled features.
    """

    def __init__(
        self,
        args,
        num_classes,
        norm_layer=nn.LayerNorm,
        **kwargs
    ):
        super(BEiT3ForImageClassification, self).__init__(args=args)
        embed_dim = args.encoder_embed_dim
        self.fc_norm = norm_layer(embed_dim)
        self.head = (
            nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        )

        self.fc_norm.apply(self._init_weights)
        self.head.apply(self._init_weights)
        # Scale down the classifier so fine-tuning starts near-zero logits.
        init_scale = 0.001
        if isinstance(self.head, nn.Linear):
            self.head.weight.data.mul_(init_scale)
            self.head.bias.data.mul_(init_scale)

    def forward(self, image, **kwargs):
        encoded = self.beit3(textual_tokens=None, visual_tokens=image)["encoder_out"]
        patch_tokens = encoded[:, 1:, :]  # drop the CLS token, keep patches
        pooled = self.fc_norm(patch_tokens.mean(1))
        return self.head(pooled)
|
|
|
|
class BEiT3ForCaptioning(BEiT3Wrapper):
    """BEiT-3 image captioning head.

    Adds a masked-language-modeling head over the vocabulary and builds a
    hybrid attention mask so that image tokens attend bidirectionally while
    text tokens attend causally (and to all image tokens).
    """

    def __init__(
            self,
            args,
            **kwargs
    ):
        super(BEiT3ForCaptioning, self).__init__(args=args)
        embed_dim = args.encoder_embed_dim
        self.mlm_head = nn.Linear(embed_dim, args.vocab_size)
        self.mlm_head.apply(self._init_weights)

    def forward(self, image, text_ids, padding_mask, language_masked_pos, text_len=None, incremental_state=None, **kwargs):
        # Sequence layout fed to the encoder: [image tokens | text tokens].
        text_len = text_len if text_len is not None else text_ids.size(1)
        image_len = self.beit3.vision_embed.num_position_embeddings()
        max_len = text_len + image_len
        # Build the allowed-attention matrix (1 = may attend), then invert it.
        uni_mask = torch.zeros((max_len, max_len), dtype=torch.long, device=text_ids.device)
        i_start, i_end = 0, image_len
        t_start, t_end = image_len, max_len
        # Text-to-text: causal (lower-triangular) attention.
        uni_mask[t_start:t_end, t_start:t_end] = torch.tril(torch.ones(text_len, text_len, dtype=torch.long, device=text_ids.device))
        # Text-to-image: text tokens may attend to every image token.
        uni_mask[t_start:t_end, i_start:i_end] = 1
        # Image-to-image: fully bidirectional. (Image-to-text stays blocked.)
        uni_mask[i_start:i_end, i_start:i_end] = 1
        # Invert: after this, 1 marks a *disallowed* position.
        # NOTE(review): assumes self.beit3 treats attn_mask==1 as masked-out —
        # confirm against the encoder's attention implementation.
        uni_mask = 1-uni_mask

        if incremental_state is not None:
            # Ensure every layer has a (possibly empty) cache slot for
            # incremental (step-by-step) decoding.
            for idx in range(self.get_num_layers()):
                if idx not in incremental_state:
                    incremental_state[idx] = {}

        # for incremental decoding
        positions = None
        if image is None:
            # Incremental step after the first: image keys are already cached,
            # so only the current text rows of the mask are needed.
            # NOTE(review): the fixed [-2:] slice presumably matches the
            # number of new tokens fed per decode step — verify with callers.
            uni_mask = uni_mask[-2:]
            padding_mask = None
            # Start position of the incoming text tokens (offset past the
            # already-decoded prefix of length text_len).
            positions = torch.arange(text_len, text_ids.size(1) + text_len, device=text_ids.device).long().unsqueeze(0)

        outputs = self.beit3(
            textual_tokens=text_ids,
            visual_tokens=image,
            text_padding_position=padding_mask,
            attn_mask=uni_mask,
            incremental_state=incremental_state,
            positions=positions,
        )
        if image is not None:
            # Keep only the text segment of the joint sequence.
            text_feats = outputs["encoder_out"][:, image_len:]
        else:
            text_feats = outputs["encoder_out"]

        if language_masked_pos is not None:
            # Score only the positions that were masked for prediction.
            text_feats = text_feats[language_masked_pos.bool()]

        return self.mlm_head(text_feats), incremental_state
|
|
|
|
class BEiT3ForVisualQuestionAnswering(BEiT3Wrapper):
    """BEiT-3 VQA head: encode (image, question) jointly, pool the CLS token,
    and classify over the answer vocabulary with a small MLP.

    Args:
        args: BEiT-3 encoder configuration.
        num_classes: size of the answer vocabulary.
        norm_layer: normalization constructor for pooler and head.
    """

    def __init__(
        self,
        args,
        num_classes,
        norm_layer=nn.LayerNorm,
        **kwargs
    ):
        super(BEiT3ForVisualQuestionAnswering, self).__init__(args=args)
        embed_dim = args.encoder_embed_dim
        self.pooler = Pooler(
            input_features=embed_dim,
            output_features=embed_dim,
            norm_layer=norm_layer,
        )
        self.pooler.apply(self._init_weights)
        self.head = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 2),
            norm_layer(embed_dim * 2),
            nn.GELU(),
            nn.Linear(embed_dim * 2, num_classes),
        )
        self.head.apply(self._init_weights)

    def forward(self, image, question, padding_mask, **kwargs):
        encoded = self.beit3(
            textual_tokens=question,
            visual_tokens=image,
            text_padding_position=padding_mask,
        )["encoder_out"]
        return self.head(self.pooler(encoded))
|
|
|
|
class BEiT3ForRetrieval(BEiT3Wrapper):
    """BEiT-3 image-text retrieval head: encode each modality separately,
    project its CLS token, L2-normalize, and train with a CLIP-style
    contrastive loss.

    Args:
        args: BEiT-3 encoder configuration.
    """

    def __init__(
        self,
        args,
        **kwargs
    ):
        super(BEiT3ForRetrieval, self).__init__(args=args)
        embed_dim = args.encoder_embed_dim
        self.language_head = nn.Linear(embed_dim, embed_dim, bias=False)
        self.vision_head = nn.Linear(embed_dim, embed_dim, bias=False)
        self.language_head.apply(self._init_weights)
        self.vision_head.apply(self._init_weights)
        self.criterion = utils.ClipLoss(
            rank=utils.get_rank(),
            world_size=utils.get_world_size(),
        )
        # Learnable temperature, initialized to 1/0.07 as in CLIP.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, image=None, text_description=None, padding_mask=None, only_infer=False, **kwargs):
        # Each modality is encoded on its own; either may be absent.
        vision_cls = None
        if image is not None:
            encoded = self.beit3(
                textual_tokens=None,
                visual_tokens=image,
                text_padding_position=None,
            )["encoder_out"]
            vision_cls = F.normalize(self.vision_head(encoded[:, 0, :]), dim=-1)

        language_cls = None
        if text_description is not None:
            encoded = self.beit3(
                textual_tokens=text_description,
                visual_tokens=None,
                text_padding_position=padding_mask,
            )["encoder_out"]
            language_cls = F.normalize(self.language_head(encoded[:, 0, :]), dim=-1)

        if only_infer:
            return vision_cls, language_cls
        # Training path: contrastive loss over the normalized embeddings.
        loss, _, _ = self.criterion(
            vision_cls, language_cls, self.logit_scale.exp())
        return loss, vision_cls, language_cls
|
|
|
|
@register_model
def beit3_base_patch16_224_imageclassification(pretrained=False, **kwargs):
    """BEiT-3 base, patch 16, 224px, ImageNet-1k classification head."""
    args = _get_base_config(**kwargs)
    args.normalize_output = False
    return BEiT3ForImageClassification(args, num_classes=1000, **kwargs)
|
|
|
|
@register_model
def beit3_large_patch16_224_imageclassification(pretrained=False, **kwargs):
    """BEiT-3 large, patch 16, 224px, ImageNet-1k classification head."""
    args = _get_large_config(**kwargs)
    args.normalize_output = False
    return BEiT3ForImageClassification(args, num_classes=1000, **kwargs)
|
|
|
|
@register_model
def beit3_base_patch16_224_nlvr2(pretrained=False, **kwargs):
    """BEiT-3 base, patch 16, 224px, NLVR2 visual reasoning (2 classes)."""
    return BEiT3ForVisualReasoning(_get_base_config(**kwargs), num_classes=2, **kwargs)
|
|
|
|
@register_model
def beit3_large_patch16_224_nlvr2(pretrained=False, **kwargs):
    """BEiT-3 large, patch 16, 224px, NLVR2 visual reasoning (2 classes)."""
    return BEiT3ForVisualReasoning(_get_large_config(**kwargs), num_classes=2, **kwargs)
|
|
|
|
@register_model
def beit3_base_patch16_384_vqav2(pretrained=False, **kwargs):
    """BEiT-3 base, patch 16, 384px, VQAv2 (3129 answer classes)."""
    args = _get_base_config(img_size=384, **kwargs)
    args.normalize_output = False
    return BEiT3ForVisualQuestionAnswering(args, num_classes=3129, **kwargs)
|
|
|
|
@register_model
def beit3_base_patch16_480_vqav2(pretrained=False, **kwargs):
    """BEiT-3 base, patch 16, 480px, VQAv2 (3129 answer classes)."""
    args = _get_base_config(img_size=480, **kwargs)
    args.normalize_output = False
    return BEiT3ForVisualQuestionAnswering(args, num_classes=3129, **kwargs)
|
|
|
|
@register_model
def beit3_large_patch16_384_vqav2(pretrained=False, **kwargs):
    """BEiT-3 large, patch 16, 384px, VQAv2 (3129 answer classes)."""
    args = _get_large_config(img_size=384, **kwargs)
    args.normalize_output = False
    return BEiT3ForVisualQuestionAnswering(args, num_classes=3129, **kwargs)
|
|
|
|
@register_model
def beit3_large_patch16_480_vqav2(pretrained=False, **kwargs):
    """BEiT-3 large, patch 16, 480px, VQAv2 (3129 answer classes)."""
    args = _get_large_config(img_size=480, **kwargs)
    args.normalize_output = False
    return BEiT3ForVisualQuestionAnswering(args, num_classes=3129, **kwargs)
|
|
|
|
@register_model
def beit3_large_patch16_768_vqav2(pretrained=False, **kwargs):
    """BEiT-3 large, patch 16, 768px, VQAv2 (3129 answer classes)."""
    args = _get_large_config(img_size=768, **kwargs)
    args.normalize_output = False
    return BEiT3ForVisualQuestionAnswering(args, num_classes=3129, **kwargs)
|
|
|
|
@register_model
def beit3_base_patch16_224_captioning(pretrained=False, **kwargs):
    """BEiT-3 base, patch 16, 224px, image captioning head."""
    return BEiT3ForCaptioning(_get_base_config(**kwargs), **kwargs)
|
|
|
|
@register_model
def beit3_base_patch16_480_captioning(pretrained=False, **kwargs):
    """BEiT-3 base, patch 16, 480px, image captioning head."""
    return BEiT3ForCaptioning(_get_base_config(img_size=480, **kwargs), **kwargs)
|
|
|
|
@register_model
def beit3_large_patch16_480_captioning(pretrained=False, **kwargs):
    """BEiT-3 large, patch 16, 480px, image captioning head."""
    return BEiT3ForCaptioning(_get_large_config(img_size=480, **kwargs), **kwargs)
|
|
|
|
@register_model
def beit3_base_patch16_224_retrieval(pretrained=False, **kwargs):
    """BEiT-3 base, patch 16, 224px, image-text retrieval head."""
    return BEiT3ForRetrieval(_get_base_config(**kwargs), **kwargs)
|
|
|
|
@register_model
def beit3_base_patch16_384_retrieval(pretrained=False, **kwargs):
    """BEiT-3 base, patch 16, 384px, image-text retrieval head."""
    return BEiT3ForRetrieval(_get_base_config(img_size=384, **kwargs), **kwargs)
|
|
|
|
@register_model
def beit3_large_patch16_384_retrieval(pretrained=False, **kwargs):
    """BEiT-3 large, patch 16, 384px, image-text retrieval head."""
    return BEiT3ForRetrieval(_get_large_config(img_size=384, **kwargs), **kwargs)
|
|