import torch
import torch.nn as nn
from custom_transformer.embedding import Embedding
from custom_transformer.encoder import TransformerEncoderBlock
class ViT(nn.Module):
    """Vision Transformer (ViT) classifier.

    Pipeline: patch + position embedding -> a stack of Transformer encoder
    blocks -> LayerNorm + Linear classification head applied to the class
    token (position 0 of the sequence).

    Args:
        image_size: Side length of the (square) input image in pixels.
        in_channels: Number of input image channels (3 for RGB).
        patch_size: Side length of each square patch; `image_size` should be
            divisible by `patch_size` (enforced, if at all, inside `Embedding`
            — TODO confirm).
        num_transformer_layers: Number of stacked encoder blocks.
        embedding_dim: Dimension of the patch/token embeddings.
        mlp_size: Hidden size of the MLP inside each encoder block.
        num_heads: Number of attention heads per encoder block.
        num_classes: Number of output classes for the classifier head.
    """

    def __init__(self,
                 image_size: int = 224,
                 in_channels: int = 3,
                 patch_size: int = 16,
                 num_transformer_layers: int = 12,
                 embedding_dim: int = 768,
                 mlp_size: int = 3072,
                 num_heads: int = 12,
                 num_classes: int = 3) -> None:
        # Modern zero-argument super() — equivalent to super(ViT, self) in Py3.
        super().__init__()

        # Turns (B, C, H, W) images into a token sequence that includes a
        # prepended class token (assumed — forward() reads index 0; verify
        # against the Embedding implementation).
        self.embedding = Embedding(image_size=image_size,
                                   in_channels=in_channels,
                                   embedding_dim=embedding_dim,
                                   patch_size=patch_size)

        # Stack of identical encoder blocks applied sequentially.
        self.transformer_encoders = nn.Sequential(
            *[TransformerEncoderBlock(embedding_dim=embedding_dim,
                                      hidden_size=mlp_size,
                                      num_heads=num_heads)
              for _ in range(num_transformer_layers)]
        )

        # Classification head: normalize, then project the class token
        # to per-class logits.
        self.classifier = nn.Sequential(
            nn.LayerNorm(normalized_shape=embedding_dim),
            nn.Linear(in_features=embedding_dim, out_features=num_classes)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape (batch, num_classes).

        Args:
            x: Image batch; presumably (batch, in_channels, image_size,
               image_size) — shape contract lives in `Embedding`.
        """
        x = self.embedding(x)
        x = self.transformer_encoders(x)
        # Classify from the class token only (sequence position 0).
        x = self.classifier(x[:, 0])
        return x
if __name__ == '__main__':
    # Smoke test: a batch of 5 random 500x500 RGB images through a ViT
    # configured for 500px inputs with 50px patches (10x10 = 100 patches).
    sample_image_batch = torch.rand(5, 3, 500, 500)
    vit = ViT(image_size=500, patch_size=50)
    # Expected output shape: (5, num_classes) i.e. torch.Size([5, 3]).
    print(vit(sample_image_batch).shape)