| |
| |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from vlmo.torchscale.architecture.encoder import Encoder |
| from vlmo.torchscale.component.embedding import ( |
| PositionalEmbedding, |
| TextEmbedding, |
| VisionEmbedding, |
| ) |
| from vlmo.torchscale.component.multiway_network import MutliwayEmbedding |
|
|
|
|
class BEiT3(nn.Module):
    """BEiT-3: a multiway Transformer encoder accepting vision, text, or
    joint vision-language input.

    Vision tokens are produced by a patch embedding (with CLS token and
    optional mask token), text tokens by a learned vocabulary embedding.
    A multiway positional embedding routes vision and text positions to
    separate tables, and the shared ``Encoder`` dispatches each modality
    segment to its own expert FFN via ``multiway_split_position``.
    """

    def __init__(self, args, **kwargs):
        super().__init__()
        self.args = args
        # BEiT-3 requires multiway (modality-expert) layers, a real text
        # vocabulary, and separate input/output embeddings.
        assert args.multiway
        assert args.vocab_size > 0
        assert not args.share_encoder_input_output_embed
        self.text_embed = TextEmbedding(args.vocab_size, args.encoder_embed_dim)
        self.vision_embed = VisionEmbedding(
            args.img_size,
            args.patch_size,
            args.in_chans,
            args.encoder_embed_dim,
            contain_mask_token=True,
            prepend_cls_token=True,
        )
        # +2 extra vision positions: one for the prepended CLS token and one
        # because index 0 of the positional table is reserved for padding.
        embed_positions = MutliwayEmbedding(
            modules=[
                PositionalEmbedding(self.vision_embed.num_position_embeddings() + 2, args.encoder_embed_dim),
                PositionalEmbedding(args.max_source_positions, args.encoder_embed_dim),
            ],
            dim=1,
        )
        self.encoder = Encoder(
            args,
            embed_tokens=None,          # token embeddings are computed here, not in the encoder
            embed_positions=embed_positions,
            output_projection=None,
            is_encoder_decoder=False,
        )

    def forward(
        self,
        textual_tokens=None,
        visual_tokens=None,
        text_padding_position=None,
        attn_mask=None,
        vision_masked_position=None,
        incremental_state=None,
        positions=None,
    ):
        """Encode one or both modalities and return the encoder output dict.

        Args:
            textual_tokens: (batch, text_len) token ids, or None for vision-only.
            visual_tokens: image tensor for the patch embedding, or None for
                text-only.
            text_padding_position: boolean/0-1 mask marking padded text
                positions; may be None when no padding is present.
            attn_mask: optional attention mask forwarded to the encoder.
            vision_masked_position: optional mask for masked-image modeling.
            incremental_state: optional decoding cache forwarded to the encoder.
            positions: optional explicit positions forwarded to the encoder.

        At least one of ``textual_tokens`` / ``visual_tokens`` must be given.
        """
        assert textual_tokens is not None or visual_tokens is not None

        if textual_tokens is None:
            # Vision-only: no padding mask; -1 disables the multiway split.
            x = self.vision_embed(visual_tokens, vision_masked_position)
            encoder_padding_mask = None
            multiway_split_position = -1
        elif visual_tokens is None:
            # Text-only: the whole sequence goes through the language expert.
            x = self.text_embed(textual_tokens)
            encoder_padding_mask = text_padding_position
            multiway_split_position = 0
        else:
            # Vision + text: vision tokens first, then text; split at the
            # vision/text boundary so each segment hits its own expert.
            x1 = self.vision_embed(visual_tokens, vision_masked_position)
            multiway_split_position = x1.size(1)
            x2 = self.text_embed(textual_tokens)
            # When several image views share one caption (vision batch is a
            # multiple of the text batch), repeat the text rows to match.
            diff = x1.shape[0] // x2.shape[0]
            if diff != 1:
                x2 = torch.repeat_interleave(x2, diff, dim=0)
                # BUG FIX: the original repeated `text_padding_position`
                # unconditionally, raising TypeError when it is None (a valid
                # value per the signature default and the branch below).
                if text_padding_position is not None:
                    text_padding_position = torch.repeat_interleave(text_padding_position, diff, dim=0)
            x = torch.cat([x1, x2], dim=1)
            if text_padding_position is not None:
                # Vision tokens are never padding (all-False prefix); text
                # keeps its caller-supplied mask.
                encoder_padding_mask = torch.cat(
                    [
                        torch.zeros(x1.shape[:-1], device=x1.device, dtype=torch.bool),
                        text_padding_position,
                    ],
                    dim=1,
                )
            else:
                encoder_padding_mask = None

        encoder_out = self.encoder(
            src_tokens=None,
            encoder_padding_mask=encoder_padding_mask,
            attn_mask=attn_mask,
            token_embeddings=x,
            multiway_split_position=multiway_split_position,
            incremental_state=incremental_state,
            positions=positions,
        )
        return encoder_out
|
|