Image Classification
File size: 3,236 Bytes
fa9bf01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# Add this import to the top of your file
import torch
from torch import nn


# --- MODEL ARCHITECTURE ---
class MiniViT(nn.Module):
    """A minimal Vision Transformer (ViT) classifier for small square RGB images.

    The image is cut into non-overlapping ``patch_size x patch_size`` patches,
    each patch is flattened and linearly projected to ``hidden_dim``, a learnable
    CLS token is prepended, learnable positional embeddings are added, the
    sequence runs through a Transformer encoder, and the CLS output is classified.

    Args:
        patch_size: Side length of each square patch (must divide ``image_size``).
        hidden_dim: Embedding / Transformer model dimension.
        num_heads: Number of attention heads per encoder layer.
        num_layers: Number of stacked Transformer encoder layers.
        num_classes: Number of output classes.
        image_size: Side length of the (square) input image. Default 32
            (CIFAR-10 scale); kept last so existing positional callers work.
    """

    def __init__(self, patch_size=4, hidden_dim=128, num_heads=4, num_layers=2,
                 num_classes=10, image_size=32):
        super().__init__()

        if image_size % patch_size != 0:
            raise ValueError(
                f"image_size ({image_size}) must be divisible by patch_size ({patch_size})"
            )

        # --- 1. Patching and Embedding ---
        self.patch_size = patch_size
        self.image_size = image_size
        # Each patch holds patch_size * patch_size pixels x 3 color channels.
        patch_dim = 3 * patch_size * patch_size
        num_patches = (image_size // patch_size) ** 2

        # Projects each flattened patch into the hidden dimension.
        self.patch_embedding = nn.Linear(patch_dim, hidden_dim)

        # --- 2. CLS Token and Positional Embedding ---
        # A special learnable token whose final state is used for classification.
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim))

        # Learnable positional embedding (one slot per patch + one for CLS)
        # that gives the model spatial information.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, hidden_dim))

        # --- 3. Transformer Encoder ---
        # batch_first=True so tensors stay [batch, seq, dim] throughout.
        encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=num_heads, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # --- 4. Classifier Head ---
        # Maps the processed CLS token to class logits.
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Tensor of shape [batch_size, 3, image_size, image_size].

        Returns:
            Logits of shape [batch_size, num_classes].
        """
        batch_size = x.size(0)

        # 1. Patching: unfold height then width ->
        #    [batch, 3, H/p, W/p, p, p]
        patches = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
        # BUG FIX: move the channel dim next to the per-patch pixels BEFORE
        # flattening. The original view() flattened in (C, H/p, W/p, p, p)
        # order, so each "patch" vector mixed pixels from three different
        # spatial locations of one channel instead of one true RGB patch.
        patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous()
        patches = patches.view(batch_size, -1, 3 * self.patch_size * self.patch_size)
        # Patches now have shape [batch_size, num_patches, patch_dim]

        # 2. Embedding: project patches to the hidden dimension.
        x = self.patch_embedding(patches)  # [batch_size, num_patches, hidden_dim]

        # 3. Prepend CLS token (expanded across the batch) and add positions.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)  # [batch_size, num_patches + 1, hidden_dim]
        x = x + self.pos_embedding

        # 4. Pass through the Transformer encoder.
        x = self.transformer_encoder(x)  # [batch_size, num_patches + 1, hidden_dim]

        # 5. Classify from the CLS token's output (position 0).
        cls_output = x[:, 0]
        return self.classifier(cls_output)


# --- Instantiate the network and run a quick smoke test ---
model = MiniViT()
print("\n--- Model Architecture ---", model, sep="\n")

# Push one random 32x32 RGB image through to sanity-check the output shape.
dummy_image = torch.randn(1, 3, 32, 32)
prediction = model(dummy_image)
print("\n--- Dummy Prediction Test ---", f"Output shape: {prediction.shape}", sep="\n")  # expect [1, 10]