import torch
import torch.nn as nn
import torch.nn.functional as F


class VisionLanguageEmbedding(nn.Module):
    """Wraps a text embedding module and a vision embedding module and
    concatenates their outputs along the sequence dimension."""

    def __init__(self, text_embed, vision_embed):
        super().__init__()
        self.text_embed = text_embed
        self.vision_embed = vision_embed

    def forward(self, textual_tokens, visual_tokens, **kwargs):
        # Single-modality shortcut: embed whichever input is present.
        if textual_tokens is None:
            return self.vision_embed(visual_tokens)

        if visual_tokens is None:
            return self.text_embed(textual_tokens)

        # Both modalities present: vision tokens come first, text tokens second.
        x1 = self.vision_embed(visual_tokens)
        x2 = self.text_embed(textual_tokens)

        return torch.cat([x1, x2], dim=1)
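
# A minimal usage sketch (illustrative, not part of the original module; the
# vocabulary size and batch shapes below are assumptions). With the defaults
# of VisionEmbedding (224x224 images, 16x16 patches), an image contributes
# (224 // 16) ** 2 = 196 vision tokens.
#
#   text_embed = TextEmbedding(num_embeddings=64000, embedding_dim=768)
#   vision_embed = VisionEmbedding(img_size=224, patch_size=16, embed_dim=768)
#   vl_embed = VisionLanguageEmbedding(text_embed, vision_embed)
#   tokens = torch.randint(0, 64000, (2, 16))  # (batch, text_len)
#   images = torch.randn(2, 3, 224, 224)       # (batch, channels, H, W)
#   fused = vl_embed(tokens, images)           # shape: (2, 196 + 16, 768)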


class VisionEmbedding(nn.Module):
    """Image to Patch Embedding: splits an image into non-overlapping patches
    and linearly projects each patch to embed_dim."""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        contain_mask_token=False,
        prepend_cls_token=False,
    ):
        super().__init__()
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        # A conv with kernel_size == stride == patch_size extracts each patch
        # and applies the linear projection in a single step.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

        # Optional learnable mask token, used to replace patch embeddings at
        # masked positions (e.g., for masked image modeling).
        if contain_mask_token:
            self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        else:
            self.mask_token = None

        # Optional learnable [CLS] token prepended to the patch sequence.
        if prepend_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        else:
            self.cls_token = None

    def num_position_embeddings(self):
        # One extra position is needed when a [CLS] token is prepended.
        if self.cls_token is None:
            return self.num_patches
        else:
            return self.num_patches + 1

    def forward(self, x, masked_position=None, **kwargs):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], (
            f"Input image size ({H}*{W}) doesn't match model "
            f"({self.img_size[0]}*{self.img_size[1]})."
        )
        # (B, C, H, W) -> (B, embed_dim, H/p, W/p) -> (B, num_patches, embed_dim)
        x = self.proj(x).flatten(2).transpose(1, 2)

        batch_size, seq_len, _ = x.size()

        # Replace embeddings at masked positions with the learnable mask token.
        if masked_position is not None:
            assert self.mask_token is not None
            mask_token = self.mask_token.expand(batch_size, seq_len, -1)
            w = masked_position.unsqueeze(-1).type_as(mask_token)
            x = x * (1 - w) + mask_token * w

        # Prepend the [CLS] token to every sequence in the batch.
        if self.cls_token is not None:
            cls_tokens = self.cls_token.expand(batch_size, -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)

        return x
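
# A minimal usage sketch (illustrative; the shapes and mask below are
# assumptions): with contain_mask_token=True, masked_position is a
# (batch, num_patches) tensor of 0/1 flags selecting which patch embeddings
# get replaced by the mask token.
#
#   embed = VisionEmbedding(contain_mask_token=True, prepend_cls_token=True)
#   images = torch.randn(2, 3, 224, 224)
#   masked = torch.zeros(2, embed.num_patches, dtype=torch.long)
#   masked[:, :75] = 1                       # mask the first 75 patches
#   out = embed(images, masked_position=masked)
#   out.shape                                # (2, 1 + 196, 768): [CLS] + patches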


class TextEmbedding(nn.Embedding):
    def reset_parameters(self):
        # Initialize with std = embedding_dim ** -0.5 so each row has an
        # expected L2 norm of about 1, independent of the embedding width;
        # keep the padding row at zero.
        nn.init.normal_(self.weight, mean=0, std=self.embedding_dim**-0.5)
        self._fill_padding_idx_with_zero()
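
# Usage note (illustrative; the sizes are assumptions): TextEmbedding is a
# drop-in nn.Embedding that only changes initialization.
#
#   emb = TextEmbedding(num_embeddings=64000, embedding_dim=768, padding_idx=1)
#   emb.weight.std().item()   # roughly 768 ** -0.5 ≈ 0.036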


class PositionalEmbedding(nn.Embedding):
    def forward(
        self,
        x,
        positions=None,
        **kwargs,
    ):
        if positions is None:
            # Positions start at 2: indices 0 and 1 are reserved, matching the
            # Fairseq convention where padding_idx is 1.
            positions = torch.arange(2, x.size(1) + 2, device=x.device).long().unsqueeze(0)
        # Look up position embeddings with the same options (padding, norm,
        # sparsity) this module was constructed with.
        return F.embedding(
            positions,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
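
# A minimal usage sketch (illustrative; sizes are assumptions): because the
# default positions run from 2 to seq_len + 1, num_embeddings must be at
# least max_seq_len + 2.
#
#   pos_embed = PositionalEmbedding(num_embeddings=512 + 2, embedding_dim=768)
#   x = torch.zeros(2, 512, 768)   # only x.size(1) and x.device are used
#   p = pos_embed(x)               # (1, 512, 768), broadcastable over the batch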