# Image Classification — MiniViT / model.py
# Provenance: Hugging Face upload by nigamx ("Upload 2 files (#1)", commit fa9bf01, verified)
# Add this import to the top of your file
import torch
from torch import nn
# --- MODEL ARCHITECTURE ---
class MiniViT(nn.Module):
    """A minimal Vision Transformer (ViT) classifier for small square images.

    The image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches, each flattened and linearly projected to ``hidden_dim``. A
    learnable CLS token is prepended, learnable positional embeddings are
    added, the sequence runs through a standard Transformer encoder, and the
    CLS token's final state is mapped to class logits.

    Args:
        patch_size: Side length of each square patch (must divide image_size).
        hidden_dim: Transformer embedding dimension.
        num_heads: Number of attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        num_classes: Number of output classes.
        image_size: Side length of the (square) input images. Default 32
            (e.g. CIFAR-10), matching the original hard-coded value.
        in_channels: Number of input image channels. Default 3 (RGB),
            matching the original hard-coded value.

    Raises:
        ValueError: If ``image_size`` is not divisible by ``patch_size``.
    """

    def __init__(self, patch_size=4, hidden_dim=128, num_heads=4, num_layers=2,
                 num_classes=10, image_size=32, in_channels=3):
        super().__init__()
        if image_size % patch_size != 0:
            raise ValueError(
                f"image_size ({image_size}) must be divisible by "
                f"patch_size ({patch_size})"
            )

        # --- 1. Patching and Embedding ---
        self.patch_size = patch_size
        self.in_channels = in_channels
        # Each flattened patch has in_channels * patch_size**2 features
        # (default: 3 * 4 * 4 = 48).
        patch_dim = in_channels * patch_size * patch_size
        num_patches = (image_size // patch_size) ** 2
        # Projects flattened patches into the transformer's hidden dimension.
        self.patch_embedding = nn.Linear(patch_dim, hidden_dim)

        # --- 2. CLS Token and Positional Embedding ---
        # Learnable classification token, prepended to every patch sequence.
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
        # Learnable positional embedding for CLS token + all patches.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, hidden_dim))

        # --- 3. Transformer Encoder ---
        # batch_first=True so all tensors are [batch, seq, feature].
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=num_heads, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # --- 4. Classifier Head ---
        # Maps the processed CLS token to per-class logits.
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Tensor of shape [batch, in_channels, image_size, image_size].

        Returns:
            Logits tensor of shape [batch, num_classes].

        Raises:
            ValueError: If the spatial size of ``x`` does not produce the
                number of patches the positional embedding was built for.
        """
        # 1. Patching: unfold height then width into non-overlapping tiles,
        # giving [batch, C, H/p, W/p, p, p], then flatten each tile.
        patches = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
        patches = patches.contiguous().view(
            x.size(0), -1, self.in_channels * self.patch_size * self.patch_size
        )  # [batch, num_patches, patch_dim]

        # Guard against silently-broadcast positional embeddings when the
        # input's spatial size doesn't match the configured image_size.
        expected = self.pos_embedding.size(1) - 1
        if patches.size(1) != expected:
            raise ValueError(
                f"Input produced {patches.size(1)} patches but model was "
                f"configured for {expected}; check the input image size."
            )

        # 2. Embedding: project patches to the hidden dimension.
        x = self.patch_embedding(patches)  # [batch, num_patches, hidden_dim]

        # 3. Prepend CLS token (expanded across the batch) and add positions.
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)  # [batch, num_patches + 1, hidden_dim]
        x = x + self.pos_embedding

        # 4. Transformer encoder over the full token sequence.
        x = self.transformer_encoder(x)  # [batch, num_patches + 1, hidden_dim]

        # 5. Classify from the CLS token's final representation.
        cls_output = x[:, 0]
        return self.classifier(cls_output)
# --- Smoke-test the model when run as a script ---
# Instantiate with default hyperparameters and inspect the architecture.
model = MiniViT()
print("\n--- Model Architecture ---")
print(model)

# Run a single random image through the network to confirm the
# forward pass works end to end and the output has the right shape.
dummy_image = torch.randn(1, 3, 32, 32)
prediction = model(dummy_image)
print("\n--- Dummy Prediction Test ---")
print(f"Output shape: {prediction.shape}")  # Should be [1, 10]