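# NOTE: the three lines below appear to be web-page scrape residue rather than
# part of the module; they are kept verbatim as comments so the file parses.
#   i4ata's picture
#   smol update
#   5feebb1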
import torch
import torch.nn as nn
from custom_transformer.embedding import Embedding
from custom_transformer.encoder import TransformerEncoderBlock
class ViT(nn.Module):
    """Image classifier built from an embedding stage, encoder blocks and a head.

    The model chains three stages: an ``Embedding`` module, a sequential stack
    of ``TransformerEncoderBlock`` layers, and a ``LayerNorm`` + ``Linear``
    head that is applied to the slice at index 0 along dimension 1 of the
    encoder output.

    Args:
        image_size: Passed to ``Embedding`` as ``image_size``.
        in_channels: Passed to ``Embedding`` as ``in_channels``.
        patch_size: Passed to ``Embedding`` as ``patch_size``.
        num_transformer_layers: Number of ``TransformerEncoderBlock``
            instances placed in the encoder stack.
        embedding_dim: Passed to ``Embedding`` and to every encoder block as
            ``embedding_dim``; also used as the ``LayerNorm`` shape and the
            ``Linear`` input width of the head.
        mlp_size: Passed to every encoder block as ``hidden_size``.
        num_heads: Passed to every encoder block as ``num_heads``.
        num_classes: Output width of the head's ``Linear`` layer.
    """

    def __init__(self,
                 image_size: int = 224,
                 in_channels: int = 3,
                 patch_size: int = 16,
                 num_transformer_layers: int = 12,
                 embedding_dim: int = 768,
                 mlp_size: int = 3072,
                 num_heads: int = 12,
                 num_classes: int = 3) -> None:
        super().__init__()

        # Stage 1: image -> embedded sequence (details live in Embedding).
        self.embedding = Embedding(
            image_size=image_size,
            in_channels=in_channels,
            embedding_dim=embedding_dim,
            patch_size=patch_size,
        )

        # Stage 2: identical encoder blocks applied one after another.
        encoder_stack = []
        for _ in range(num_transformer_layers):
            encoder_stack.append(
                TransformerEncoderBlock(
                    embedding_dim=embedding_dim,
                    hidden_size=mlp_size,
                    num_heads=num_heads,
                )
            )
        self.transformer_encoders = nn.Sequential(*encoder_stack)

        # Stage 3: normalisation followed by the projection to class scores.
        head_norm = nn.LayerNorm(normalized_shape=embedding_dim)
        head_proj = nn.Linear(in_features=embedding_dim, out_features=num_classes)
        self.classifier = nn.Sequential(head_norm, head_proj)

    def forward(self, x):
        """Run ``x`` through embedding, encoder stack and classification head.

        Args:
            x: Input accepted by ``Embedding`` (presumably an image batch of
                shape (batch, channels, height, width) -- confirm against
                ``Embedding``).

        Returns:
            The head's output for the index-0 slice along dimension 1 of the
            encoder output.
        """
        embedded = self.embedding(x)
        encoded = self.transformer_encoders(embedded)
        # Only position 0 along dim 1 feeds the head -- presumably the class
        # token prepended by Embedding; TODO confirm.
        first_position = encoded[:, 0]
        return self.classifier(first_position)
if __name__ == '__main__':
    # Smoke test: push a random batch of five 3x500x500 inputs through a model
    # configured for 500-pixel images with 50-pixel patches, then print the
    # shape of the result.
    batch = torch.rand(5, 3, 500, 500)
    model = ViT(image_size=500, patch_size=50)
    output = model(batch)
    print(output.shape)