| import torch | |
| import torch.nn as nn | |
| from custom_transformer.embedding import Embedding | |
| from custom_transformer.encoder import TransformerEncoderBlock | |
class ViT(nn.Module):
    """Vision Transformer image classifier.

    Data flows through three stages: ``Embedding`` turns the input into a
    token sequence, a stack of ``TransformerEncoderBlock`` layers processes
    it, and a ``LayerNorm`` + ``Linear`` head maps the first token of the
    encoded sequence to ``num_classes`` outputs.

    Args:
        image_size: Input image side length, forwarded to ``Embedding``.
        in_channels: Number of input image channels, forwarded to ``Embedding``.
        patch_size: Patch side length, forwarded to ``Embedding``.
        num_transformer_layers: How many encoder blocks are stacked.
        embedding_dim: Token width shared by the embedding, every encoder
            block, and the classification head.
        mlp_size: Passed to each encoder block as its ``hidden_size``.
        num_heads: Attention head count passed to each encoder block.
        num_classes: Output width of the final linear layer.
    """

    def __init__(
        self,
        image_size: int = 224,
        in_channels: int = 3,
        patch_size: int = 16,
        num_transformer_layers: int = 12,
        embedding_dim: int = 768,
        mlp_size: int = 3072,
        num_heads: int = 12,
        num_classes: int = 3,
    ) -> None:
        super().__init__()

        # NOTE: the attribute names below and the order in which submodules
        # are created/registered determine state_dict keys, parameters()
        # order and the sequence of RNG draws during init -- keep them stable.
        self.embedding = Embedding(
            image_size=image_size,
            in_channels=in_channels,
            embedding_dim=embedding_dim,
            patch_size=patch_size,
        )

        encoder_blocks = [
            TransformerEncoderBlock(
                embedding_dim=embedding_dim,
                hidden_size=mlp_size,
                num_heads=num_heads,
            )
            for _layer_idx in range(num_transformer_layers)
        ]
        self.transformer_encoders = nn.Sequential(*encoder_blocks)

        head_norm = nn.LayerNorm(normalized_shape=embedding_dim)
        head_proj = nn.Linear(in_features=embedding_dim, out_features=num_classes)
        self.classifier = nn.Sequential(head_norm, head_proj)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run a batch of images through the model.

        Args:
            x: Image batch handed directly to ``self.embedding``; presumably
                shaped (batch, in_channels, image_size, image_size) -- TODO
                confirm against ``Embedding``.

        Returns:
            The output of ``self.classifier`` applied to position 0 along
            dim 1 of the encoder output.
        """
        tokens = self.embedding(x)
        encoded = self.transformer_encoders(tokens)
        # Only the first token is classified; presumably this is a class
        # token prepended by Embedding -- verify in custom_transformer.embedding.
        first_token = encoded[:, 0]
        return self.classifier(first_token)
if __name__ == "__main__":
    # Smoke test: push a random image batch through a ViT configured for
    # 500x500 inputs with 50x50 patches and print the output shape.
    sample_image_batch = torch.rand(5, 3, 500, 500)
    vit = ViT(image_size=500, patch_size=50)
    # Inference only: switch to eval mode and disable autograd so the forward
    # pass does not retain activations for a backward pass that never happens.
    vit.eval()
    with torch.no_grad():
        print(vit(sample_image_batch).shape)