|
|
|
|
|
import torch
|
|
|
from torch import nn
|
|
|
|
|
|
|
|
|
|
|
|
class MiniViT(nn.Module):
    """A minimal Vision Transformer (ViT) for small square RGB images.

    Pipeline: split the image into non-overlapping square patches, linearly
    embed each flattened patch, prepend a learnable [CLS] token, add learnable
    positional embeddings, run a Transformer encoder, and classify from the
    [CLS] output position.

    Args:
        patch_size: side length of each square patch; must divide image_size.
        hidden_dim: transformer embedding dimension (d_model).
        num_heads: number of attention heads per encoder layer.
        num_layers: number of stacked encoder layers.
        num_classes: size of the output logit vector.
        image_size: side length of the (square) input images. Defaults to 32,
            matching the previously hard-coded CIFAR-sized input, so existing
            callers are unaffected.

    Raises:
        ValueError: if image_size is not divisible by patch_size.
    """

    def __init__(self, patch_size=4, hidden_dim=128, num_heads=4, num_layers=2,
                 num_classes=10, image_size=32):
        super().__init__()
        if image_size % patch_size != 0:
            raise ValueError("image_size must be divisible by patch_size")

        self.patch_size = patch_size

        # Each patch flattens to channels * patch_size * patch_size values.
        patch_dim = 3 * patch_size * patch_size
        num_patches = (image_size // patch_size) ** 2

        self.patch_embedding = nn.Linear(patch_dim, hidden_dim)

        # Learnable classification token; expanded over the batch in forward().
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim))

        # One positional embedding per patch, plus one for the [CLS] token.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, hidden_dim))

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=num_heads, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: tensor of shape (batch, 3, H, W); H and W must equal the
               image_size the model was constructed with, so the number of
               patches matches the positional embedding.

        Returns:
            Logits of shape (batch, num_classes).
        """
        p = self.patch_size

        # unfold height then width -> (B, C, H/p, W/p, p, p)
        patches = x.unfold(2, p, p).unfold(3, p, p)

        # BUG FIX: the original code called .view() directly on the unfolded
        # tensor, flattening in (C, H/p, W/p, p, p) memory order. That scrambles
        # pixels from different spatial patches and channels into each "patch"
        # vector. Permute so every patch groups all of its own pixels (across
        # channels) contiguously before flattening.
        patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous()
        patches = patches.view(x.size(0), -1, 3 * p * p)

        x = self.patch_embedding(patches)  # (B, num_patches, hidden_dim)

        # Prepend the [CLS] token to every sequence in the batch.
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        x = x + self.pos_embedding

        x = self.transformer_encoder(x)

        # Classify from the [CLS] position only.
        cls_output = x[:, 0]
        return self.classifier(cls_output)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Smoke test: instantiate the model and push one random image through it ---
net = MiniViT()

print("\n--- Model Architecture ---")
print(net)

# A single fake 32x32 RGB image, matching the model's default input size.
sample = torch.randn(1, 3, 32, 32)
logits = net(sample)

print("\n--- Dummy Prediction Test ---")
print(f"Output shape: {logits.shape}")