| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| # | |
| import torch.nn as nn | |
class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding.

    Splits an image into non-overlapping square patches with a strided
    ``nn.Conv2d`` (kernel size == stride == ``patch_size``). It then projects
    each patch to ``embed_dim`` channels and returns the patches as a token
    sequence.

    Args:
        patch_size: Patch side length, used as both the conv kernel size
            and stride.
        in_chans: Number of input image channels.
        embed_dim: Number of output channels (embedding size) per token.
    """

    def __init__(
        self,
        patch_size=16,
        in_chans=3,
        embed_dim=768
    ):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """
        Embed a batch of images as a sequence of patch tokens.

        Args:
            x: Tensor of shape (B, C, H, W), with C == ``in_chans``.

        Returns:
            Tensor of shape (B, N, embed_dim). For an int ``patch_size`` p,
            N = (H // p) * (W // p), in row-major patch order. Rows/columns
            that do not fill a whole patch are dropped by the conv.

        Raises:
            ValueError: If ``x`` is not 4-D.
        """
        # Explicit rank check. The original code relied on a tuple unpack of
        # x.shape (which also raised ValueError) and left unused locals behind.
        if x.ndim != 4:
            raise ValueError(
                f"PatchEmbed expects a 4-D (B, C, H, W) input, got shape {tuple(x.shape)}"
            )
        # (B, embed_dim, H', W') -> (B, embed_dim, H'*W') -> (B, H'*W', embed_dim)
        return self.proj(x).flatten(2).transpose(1, 2)
class PatchEmbed3D(nn.Module):
    """
    Video (spatio-temporal) to Patch Embedding.

    Splits a clip into non-overlapping tubelets of size
    ``tubelet_size`` x ``patch_size`` x ``patch_size`` (time x height x width)
    with a strided ``nn.Conv3d`` whose kernel size equals its stride. It then
    projects each tubelet to ``embed_dim`` channels and returns the tubelets
    as a token sequence.

    Args:
        patch_size: Spatial side length of each tubelet.
        tubelet_size: Temporal extent (number of frames) of each tubelet.
        in_chans: Number of input channels.
        embed_dim: Number of output channels (embedding size) per token.
    """

    def __init__(
        self,
        patch_size=16,
        tubelet_size=2,
        in_chans=3,
        embed_dim=768,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.tubelet_size = tubelet_size
        self.proj = nn.Conv3d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=(tubelet_size, patch_size, patch_size),
            stride=(tubelet_size, patch_size, patch_size),
        )

    def forward(self, x, **kwargs):
        """
        Embed a batch of clips as a sequence of tubelet tokens.

        Args:
            x: Tensor of shape (B, C, T, H, W), with C == ``in_chans``.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            Tensor of shape (B, N, embed_dim), where
            N = (T // tubelet_size) * (H // patch_size) * (W // patch_size).
            Frames/rows/columns that do not fill a whole tubelet are dropped
            by the conv.

        Raises:
            ValueError: If ``x`` is not 5-D.
        """
        # Explicit rank check. The original code relied on a tuple unpack of
        # x.shape (which also raised ValueError) and left unused locals behind.
        if x.ndim != 5:
            raise ValueError(
                f"PatchEmbed3D expects a 5-D (B, C, T, H, W) input, got shape {tuple(x.shape)}"
            )
        # (B, embed_dim, T', H', W') -> (B, embed_dim, T'*H'*W') -> (B, T'*H'*W', embed_dim)
        return self.proj(x).flatten(2).transpose(1, 2)