Image Classification
Files changed (2)
  1. main.py +120 -0
  2. model.py +78 -0
main.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
from model import MiniViT

# Standard transformation: convert PIL images to float tensors in [0, 1].
# CIFAR-10 images come out as 3x32x32 tensors.
transform = transforms.Compose([transforms.ToTensor()])


def get_dataloader(train, shuffle, root='./data', batch_size=4):
    """Download (if needed) one CIFAR-10 split and wrap it in a DataLoader.

    Args:
        train: True for the 50k training split, False for the 10k test split.
        shuffle: whether the DataLoader shuffles each epoch.
        root: directory where the dataset is stored/downloaded.
        batch_size: mini-batch size.

    Returns:
        A torch.utils.data.DataLoader over the requested split.
    """
    dataset = torchvision.datasets.CIFAR10(root=root,
                                           train=train,
                                           download=True,
                                           transform=transform)
    return torch.utils.data.DataLoader(dataset,
                                       batch_size=batch_size,
                                       shuffle=shuffle)


def inspect_first_sample(loader):
    """Print the shape and label of the first image of the first batch."""
    images, labels = next(iter(loader))
    print("----Data Inspection---")
    print(f"Image shape: {images[0].shape}")
    print(f"Label : {labels[0].item()}")


def train(model, trainloader, num_epochs=20, lr=0.001):
    """Train *model* in place with Adam + cross-entropy.

    Prints the running loss every 2000 mini-batches.

    Args:
        model: the nn.Module to optimize.
        trainloader: DataLoader yielding (inputs, labels) batches.
        num_epochs: number of full passes over the training data.
        lr: Adam learning rate.
    """
    # CrossEntropyLoss is the standard choice for multi-class classification.
    criterion = nn.CrossEntropyLoss()
    # Adam tunes model.parameters() with learning rate lr.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Make sure dropout/normalization layers are in training mode
    # (matters if evaluate() was called earlier on the same model).
    model.train()

    print("\n--- Starting Training ---")
    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(trainloader):
            # The 5 core steps of a training iteration:
            optimizer.zero_grad()               # 1. clear stale gradients
            outputs = model(inputs)             # 2. forward pass
            loss = criterion(outputs, labels)   # 3. measure the error
            loss.backward()                     # 4. compute gradients
            optimizer.step()                    # 5. update the weights

            running_loss += loss.item()
            if i % 2000 == 1999:  # report every 2000 mini-batches
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
                running_loss = 0.0
    print('--- Finished Training ---')


def evaluate(model, testloader):
    """Return the top-1 accuracy (%) of *model* over *testloader*."""
    print("\n--- Starting Evaluation ---")
    # Evaluation mode disables dropout etc.
    model.eval()

    correct = 0
    total = 0
    # No gradients are needed for evaluation: saves memory and compute.
    with torch.no_grad():
        for images, labels in testloader:
            outputs = model(images)
            # The predicted class is the index of the highest logit.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total


if __name__ == "__main__":
    # Guarding the script body keeps importing this module side-effect free.
    trainloader = get_dataloader(train=True, shuffle=True)
    inspect_first_sample(trainloader)

    model = MiniViT()
    # Train for 20 full cycles through the data.
    train(model, trainloader, num_epochs=20)

    # IMPORTANT: use the held-out test split, unshuffled, for evaluation.
    testloader = get_dataloader(train=False, shuffle=False)
    accuracy = evaluate(model, testloader)
    print(f'Accuracy of the network on the 10000 test images: {accuracy:.2f} %')
model.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Add this import to the top of your file
2
+ import torch
3
+ from torch import nn
4
+
5
+
6
+ # --- MODEL ARCHITECTURE ---
7
+ class MiniViT(nn.Module):
8
+ def __init__(self, patch_size=4, hidden_dim=128, num_heads=4, num_layers=2, num_classes=10):
9
+ super().__init__()
10
+
11
+ # --- 1. Patching and Embedding ---
12
+ self.patch_size = patch_size
13
+ # An image is 32x32 with 3 color channels.
14
+ # Patch dimension is 4 * 4 * 3 = 48
15
+ patch_dim = 3 * patch_size * patch_size
16
+ num_patches = (32 // patch_size) ** 2
17
+
18
+ # This layer projects the flattened patches into the hidden_dim
19
+ self.patch_embedding = nn.Linear(patch_dim, hidden_dim)
20
+
21
+ # --- 2. CLS Token and Positional Embedding ---
22
+ # A special token that will be used for classification
23
+ self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
24
+
25
+ # A learnable embedding to give the model spatial information
26
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, hidden_dim))
27
+
28
+ # --- 3. Transformer Encoder ---
29
+ # This is the main workhorse of the model
30
+ encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=num_heads, batch_first=True)
31
+ self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
32
+
33
+ # --- 4. Classifier Head ---
34
+ # This takes the processed CLS token and makes the final prediction
35
+ self.classifier = nn.Linear(hidden_dim, num_classes)
36
+
37
+ def forward(self, x):
38
+ # x has shape [batch_size, 3, 32, 32]
39
+
40
+ # 1. Patching
41
+ # Reshape the image into a sequence of flattened patches
42
+ patches = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
43
+ patches = patches.contiguous().view(x.size(0), -1, 3 * self.patch_size * self.patch_size)
44
+ # Patches now have shape [batch_size, num_patches, patch_dim]
45
+
46
+ # 2. Embedding
47
+ # Project patches to the hidden dimension
48
+ x = self.patch_embedding(patches) # [batch_size, num_patches, hidden_dim]
49
+
50
+ # 3. Prepend CLS token and add Positional Embedding
51
+ # Expand CLS token for the whole batch and add it to the front
52
+ cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
53
+ x = torch.cat((cls_tokens, x), dim=1) # [batch_size, num_patches + 1, hidden_dim]
54
+
55
+ # Add the positional information
56
+ x = x + self.pos_embedding
57
+
58
+ # 4. Pass through Transformer Encoder
59
+ x = self.transformer_encoder(x) # [batch_size, num_patches + 1, hidden_dim]
60
+
61
+ # 5. Get the CLS token output and classify
62
+ cls_output = x[:, 0] # Get the output of the first token (CLS)
63
+ output = self.classifier(cls_output)
64
+
65
+ return output
66
+
67
+
68
# --- Demo / smoke test ---
# Guarded with __main__ so that importing this module (e.g. main.py's
# `from model import MiniViT`) does NOT build a model, print, or run a
# forward pass as an import side effect.
if __name__ == "__main__":
    model = MiniViT()
    print("\n--- Model Architecture ---")
    print(model)

    # Sanity-check the forward pass with a single random image.
    dummy_image = torch.randn(1, 3, 32, 32)
    prediction = model(dummy_image)
    print("\n--- Dummy Prediction Test ---")
    print(f"Output shape: {prediction.shape}")  # Should be [1, 10]