File size: 3,131 Bytes
cce011e
 
 
 
 
 
 
 
 
 
5feebb1
cce011e
 
 
 
 
 
 
 
 
 
 
 
5feebb1
cce011e
 
 
 
 
 
 
 
 
 
 
 
5feebb1
cce011e
5feebb1
cce011e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5feebb1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch
import torch.nn as nn
import math

# Module-level debug switch; the __main__ block flips it on for a verbose run.
DEBUG = False

class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and linearly embed each one.

    A Conv2d with kernel_size == stride == patch_size is mathematically
    equivalent to cutting the image into square patches and applying one
    shared linear projection to each flattened patch (the standard ViT trick).
    """

    def __init__(self, in_channels: int = 3, embedding_dim: int = 768, patch_size: int = 16) -> None:
        """
        Args:
            in_channels: Number of channels in the input image (e.g. 3 for RGB).
            embedding_dim: Dimensionality of each patch embedding vector.
            patch_size: Side length of each (square) patch; also the conv stride.
        """
        # Zero-argument super() is the modern Py3 idiom (same behavior as
        # the legacy super(PatchEmbedding, self) form).
        super().__init__()
        self.linear_projection = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embedding_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map [batch_size, in_channels, H, W] -> [batch_size, n_patches, embedding_dim]."""
        # Input: [batch_size, in_channels, H, W]
        if DEBUG: print(f'Patch embedding input shape: {x.shape} [batch_size, in_channels, image_height, image_width]')

        # Linear Projection: [batch_size, embedding_dim, sqrt(n_patches), sqrt(n_patches)]
        x = self.linear_projection(x)
        if DEBUG: print(f'Linearly projected input: {x.shape} [batch_size, embedding_dim, sqrt(n_patches), sqrt(n_patches)]')

        # Flattening: [batch_size, embedding_dim, n_patches]
        x = x.flatten(start_dim=2)
        if DEBUG: print(f'Flattening of last 2 dimensions of linear projection: {x.shape} [batch_size, embedding_dim, n_patches]')

        # Transpose last 2 dimensions: [batch_size, n_patches, embedding_dim]
        x = x.mT
        if DEBUG: print(f'Transpose last 2 dimensions: {x.shape} [batch_size, n_patches, embedding_dim]')

        return x
    
class Embedding(nn.Module):
    """Full ViT input embedding: patch embeddings + class token + positions.

    Produces a sequence of length n_patches + 1 where index 0 is a learnable
    [class] token shared across the batch and the remaining entries are patch
    embeddings; a learnable positional embedding is added to every position.
    """

    def __init__(self, image_size: int = 224, in_channels: int = 3, embedding_dim: int = 768, patch_size: int = 16) -> None:
        """
        Args:
            image_size: Side length of the (square) input image.
            in_channels: Number of channels in the input image.
            embedding_dim: Dimensionality of each token embedding.
            patch_size: Side length of each (square) patch.

        Raises:
            ValueError: If ``image_size`` is not divisible by ``patch_size``.
        """
        super().__init__()

        # `assert` is stripped when Python runs with -O, so validate
        # explicitly and raise a proper exception instead.
        if image_size % patch_size != 0:
            raise ValueError(
                f'image_size ({image_size}) must be divisible by patch_size ({patch_size})'
            )

        self.n_patches = (image_size * image_size) // (patch_size * patch_size)
        if DEBUG: print(f'Total number of patches: {self.n_patches}, i.e. {int(math.sqrt(self.n_patches))} x {int(math.sqrt(self.n_patches))}')

        # Patch embedding defined above
        self.patch_embedding = PatchEmbedding(in_channels=in_channels, embedding_dim=embedding_dim, patch_size=patch_size)

        # The class token x0, 1 for each embedding dim (expanded per batch in forward)
        self.class_token = nn.Parameter(torch.randn(1, 1, embedding_dim))

        # The positional embedding: one vector per position, including the class token slot
        self.position_embedding = nn.Parameter(torch.randn(1, self.n_patches + 1, embedding_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map [batch_size, in_channels, H, W] -> [batch_size, n_patches + 1, embedding_dim]."""
        if DEBUG: print(f'Embedding input shape: {x.shape}: [batch_size, in_channels, height, width]')

        x = self.patch_embedding(x)
        if DEBUG: print(f'Patch embedding output: {x.shape}: [batch_size, n_patches, embedding_dim]')

        # Prepend the class token; expand() broadcasts it over the batch
        # without copying memory.
        x = torch.cat((self.class_token.expand(x.size(0), -1, -1), x), dim=1)
        if DEBUG: print(f'Class token prepended: {x.shape}: [batch_size, n_patches + 1, embedding_dim]')

        # Positional embedding broadcasts over the batch dimension.
        x = x + self.position_embedding
        if DEBUG: print(f'Positional embedding added: {x.shape}: [batch_size, n_patches + 1, embedding_dim]')

        return x
    
if __name__ == '__main__':
    # Smoke test: push one random image batch through the full embedding
    # with verbose shape tracing enabled.
    DEBUG = True
    image_batch = torch.rand(5, 3, 224, 224)
    vit_embedding = Embedding()
    tokens = vit_embedding(image_batch)
    print(tokens.shape)