File size: 1,452 Bytes
cce011e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5feebb1
cce011e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
import torch.nn as nn

from custom_transformer.embedding import Embedding
from custom_transformer.encoder import TransformerEncoderBlock

class ViT(nn.Module):
    """Vision Transformer image classifier.

    Pipeline: ``Embedding`` -> a stack of ``num_transformer_layers``
    ``TransformerEncoderBlock`` modules -> LayerNorm + Linear head applied to
    the token at index 0 of dimension 1 of the encoder output.

    Args:
        image_size: Image size forwarded to ``Embedding``.
        in_channels: Number of input channels, forwarded to ``Embedding``.
        patch_size: Patch size forwarded to ``Embedding``.
        num_transformer_layers: How many encoder blocks to stack.
        embedding_dim: Token width shared by the embedding, every encoder
            block and the classification head.
        mlp_size: Passed to each encoder block as ``hidden_size``.
        num_heads: Passed to each encoder block as ``num_heads``.
        num_classes: Output width of the final Linear layer.
    """

    def __init__(self,
                 image_size: int = 224,
                 in_channels: int = 3,
                 patch_size: int = 16,
                 num_transformer_layers: int = 12,
                 embedding_dim: int = 768,
                 mlp_size: int = 3072,
                 num_heads: int = 12,
                 num_classes: int = 3) -> None:
        super().__init__()

        # The attribute names below (embedding / transformer_encoders /
        # classifier) become state_dict key prefixes; renaming them would
        # break loading of existing checkpoints.
        self.embedding = Embedding(
            image_size=image_size,
            in_channels=in_channels,
            embedding_dim=embedding_dim,
            patch_size=patch_size,
        )

        encoder_blocks = [
            TransformerEncoderBlock(
                embedding_dim=embedding_dim,
                hidden_size=mlp_size,
                num_heads=num_heads,
            )
            for _ in range(num_transformer_layers)
        ]
        self.transformer_encoders = nn.Sequential(*encoder_blocks)

        # Head: normalise the selected token, then project to class scores.
        head_norm = nn.LayerNorm(normalized_shape=embedding_dim)
        head_proj = nn.Linear(in_features=embedding_dim, out_features=num_classes)
        self.classifier = nn.Sequential(head_norm, head_proj)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the image batch through embedding, encoders and head.

        Args:
            x: Input handed straight to ``self.embedding``; presumably
                (batch, in_channels, height, width) -- TODO confirm against
                ``Embedding``.

        Returns:
            Output of the Linear head, whose last dimension is ``num_classes``.
        """
        tokens = self.embedding(x)
        encoded = self.transformer_encoders(tokens)
        # Only the first token along dim 1 is classified; presumably this is
        # a class token prepended by Embedding -- TODO confirm.
        first_token = encoded[:, 0]
        return self.classifier(first_token)
    
if __name__ == '__main__':
    # Smoke test: push a random batch through a model configured with
    # image_size=500 and patch_size=50, then print the output shape.
    sample_image_batch = torch.rand(5, 3, 500, 500)
    vit = ViT(image_size=500, patch_size=50)
    # Only the output shape is inspected, so switch any train-only layer
    # behaviour off and skip building an autograd graph through all the
    # encoder layers (saves memory and time; the printed shape is unchanged).
    vit.eval()
    with torch.no_grad():
        output = vit(sample_image_batch)
    print(output.shape)