# Add this import to the top of your file
import torch
from torch import nn


# --- MODEL ARCHITECTURE ---
class MiniViT(nn.Module):
    """A minimal Vision Transformer for 32x32 RGB images (e.g. CIFAR-10).

    Pipeline: split the image into non-overlapping patches, linearly embed
    them, prepend a learnable CLS token, add learnable positional embeddings,
    run a Transformer encoder, and classify from the CLS token's output.
    """

    def __init__(self, patch_size=4, hidden_dim=128, num_heads=4, num_layers=2, num_classes=10):
        super().__init__()
        # --- 1. Patching and Embedding ---
        self.patch_size = patch_size
        # An image is 32x32 with 3 color channels.
        # Patch dimension is 4 * 4 * 3 = 48 (for the default patch_size=4).
        patch_dim = 3 * patch_size * patch_size
        num_patches = (32 // patch_size) ** 2
        # This layer projects the flattened patches into the hidden_dim
        self.patch_embedding = nn.Linear(patch_dim, hidden_dim)

        # --- 2. CLS Token and Positional Embedding ---
        # A special token that will be used for classification
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
        # A learnable embedding to give the model spatial information
        # (num_patches + 1 positions: one extra slot for the CLS token)
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, hidden_dim))

        # --- 3. Transformer Encoder ---
        # This is the main workhorse of the model
        encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=num_heads, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # --- 4. Classifier Head ---
        # This takes the processed CLS token and makes the final prediction
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def _extract_patches(self, x):
        """Split a batch of images into flattened non-overlapping patches.

        Args:
            x: tensor of shape [batch_size, 3, H, W] with H, W divisible
               by self.patch_size.

        Returns:
            Tensor of shape [batch_size, num_patches, 3 * p * p] where each
            row holds ALL 3 channels of one spatial patch.

        BUG FIX: the original code called view() directly on the unfolded
        tensor, whose memory layout is (C, H', W', p, p). Flattened in that
        order, each 48-element "patch" was a slice of a single channel
        straddling several spatial patches. We must permute the spatial
        patch grid in front of the channel dim before flattening.
        """
        p = self.patch_size
        # [B, 3, H/p, W/p, p, p]: carve out the p x p patch grid
        patches = x.unfold(2, p, p).unfold(3, p, p)
        # Reorder to [B, H/p, W/p, 3, p, p] so flattening the last three
        # dims yields one spatial patch with all of its channels together.
        patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous()
        return patches.view(x.size(0), -1, 3 * p * p)

    def forward(self, x):
        # x has shape [batch_size, 3, 32, 32]

        # 1. Patching: reshape the image into a sequence of flattened patches
        patches = self._extract_patches(x)
        # Patches now have shape [batch_size, num_patches, patch_dim]

        # 2. Embedding: project patches to the hidden dimension
        x = self.patch_embedding(patches)  # [batch_size, num_patches, hidden_dim]

        # 3. Prepend CLS token and add Positional Embedding
        # Expand CLS token for the whole batch and add it to the front
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)  # [batch_size, num_patches + 1, hidden_dim]
        # Add the positional information
        x = x + self.pos_embedding

        # 4. Pass through Transformer Encoder
        x = self.transformer_encoder(x)  # [batch_size, num_patches + 1, hidden_dim]

        # 5. Get the CLS token output and classify
        cls_output = x[:, 0]  # Get the output of the first token (CLS)
        output = self.classifier(cls_output)

        return output


# --- Create an instance of the model ---
# Add this line at the end of your script
model = MiniViT()
print("\n--- Model Architecture ---")
print(model)

# You can also test it with a dummy image
dummy_image = torch.randn(1, 3, 32, 32)  # A single random image
prediction = model(dummy_image)
print("\n--- Dummy Prediction Test ---")
print(f"Output shape: {prediction.shape}")  # Should be [1, 10]