| | import torch |
| | import torch.nn as nn |
| | import math |
| |
|
| | DEBUG = False |
| |
|
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and linearly embed each one.

    A single Conv2d whose kernel_size equals its stride is mathematically the
    same as cutting the image into patch_size x patch_size tiles and applying
    one shared linear projection per tile.
    """

    def __init__(self, in_channels: int = 3, embedding_dim: int = 768, patch_size: int = 16) -> None:
        super().__init__()
        # kernel_size == stride == patch_size -> one projection per patch.
        self.linear_projection = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embedding_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project a (B, C, H, W) image batch to (B, n_patches, embedding_dim) tokens."""
        if DEBUG:
            print(f'Patch embedding input shape: {x.shape} [batch_size, in_channels, image_height, image_width]')

        # Strided convolution: each output spatial location is one patch embedding.
        x = self.linear_projection(x)
        if DEBUG:
            print(f'Linearly projected input: {x.shape} [batch_size, embedding_dim, sqrt(n_patches), sqrt(n_patches)]')

        # Collapse the two spatial axes into a single patch axis.
        x = x.flatten(start_dim=2)
        if DEBUG:
            print(f'Flattening of last 2 dimensions of linear projection: {x.shape} [batch_size, embedding_dim, n_patches]')

        # Swap to sequence-first layout expected by transformer layers.
        x = x.transpose(1, 2)
        if DEBUG:
            print(f'Transpose last 2 dimensions: {x.shape} [batch_size, n_patches, embedding_dim]')

        return x
| | |
class Embedding(nn.Module):
    """Full ViT input embedding: patch tokens + prepended class token + positional embedding.

    Output shape is (batch_size, n_patches + 1, embedding_dim).
    """

    def __init__(self, image_size: int = 224, in_channels: int = 3, embedding_dim: int = 768, patch_size: int = 16) -> None:
        """Build the embedding layers.

        Raises:
            ValueError: if image_size is not divisible by patch_size.
        """
        super().__init__()

        # Validate with a real exception: `assert` is stripped under `python -O`,
        # so it must not be relied on for input validation.
        if image_size % patch_size != 0:
            raise ValueError(
                f'image_size ({image_size}) must be divisible by patch_size ({patch_size})'
            )

        self.n_patches = (image_size * image_size) // (patch_size * patch_size)
        if DEBUG: print(f'Total number of patches: {self.n_patches}, i.e. {int(math.sqrt(self.n_patches))} x {int(math.sqrt(self.n_patches))}')

        self.patch_embedding = PatchEmbedding(in_channels=in_channels, embedding_dim=embedding_dim, patch_size=patch_size)

        # Learnable [CLS] token; broadcast over the batch and prepended in forward().
        self.class_token = nn.Parameter(torch.randn(1, 1, embedding_dim))

        # Learnable positional embedding for n_patches + 1 tokens ([CLS] included).
        self.position_embedding = nn.Parameter(torch.randn(1, self.n_patches + 1, embedding_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a (B, C, H, W) image batch to (B, n_patches + 1, embedding_dim) embeddings."""
        if DEBUG: print(f'Embedding input shape: {x.shape}: [batch_size, in_channels, height, width]')

        x = self.patch_embedding(x)
        if DEBUG: print(f'Patch embedding output: {x.shape}: [batch_size, n_patches, embedding_dim]')

        # Expand the single class token across the batch and prepend it.
        x = torch.cat((self.class_token.expand(len(x), -1, -1), x), dim=1)
        if DEBUG: print(f'Class token prepended: {x.shape}: [batch_size, n_patches + 1, embedding_dim]')

        # Positional information is added, not concatenated (standard ViT).
        x = x + self.position_embedding
        if DEBUG: print(f'Positional embedding added: {x.shape}: [batch_size, n_patches + 1, embedding_dim]')

        return x
| | |
if __name__ == '__main__':
    # Smoke test: push one random batch through the embedding with verbose tracing.
    DEBUG = True
    image_batch = torch.rand(5, 3, 224, 224)
    model = Embedding()
    out = model(image_batch)
    print(out.shape)