Image Classification
File size: 4,002 Bytes
fa9bf01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
from model import MiniViT

# Standard transformation: convert PIL images to PyTorch tensors.
transform = transforms.Compose([transforms.ToTensor()])

# Download (if needed) and load the CIFAR-10 training split.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# The DataLoader takes care of batching and shuffling.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=4,
    shuffle=True,
)

# --- INSPECT ONE IMAGE ---
# Pull a single batch straight off the loader.
images, labels = next(iter(trainloader))

# Pick out the first sample and its label from that batch.
first_image, first_label = images[0], labels[0]

# Show the tensor shape and the numeric class label.
print("----Data Inspection---")
print(f"Image shape: {first_image.shape}")
print(f"Label : {first_label.item()}")

model = MiniViT()
# --- TRAINING SETUP ---

# 1. The Loss Function
# CrossEntropyLoss is a standard choice for classification problems.
criterion = nn.CrossEntropyLoss()

# 2. The Optimizer
# Adam is a popular and effective optimizer. We tell it which parameters
# to tune (model.parameters()) and the learning rate (lr).
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# --- THE TRAINING LOOP ---
print("\n--- Starting Training ---")
num_epochs = 20  # Train for 20 full passes through the data
# (the previous comment said 5 epochs, which contradicted the value above)

# Put the model in training mode explicitly (enables dropout/batchnorm
# updates). A freshly constructed model is already in train mode, so this
# is a no-op here, but it documents intent and mirrors model.eval() below.
model.train()

for epoch in range(num_epochs):

    running_loss = 0.0
    for i, data in enumerate(trainloader):
        # Get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # --- The 5 Core Steps of Training ---

        # 1. Zero the parameter gradients (important!)
        optimizer.zero_grad()

        # 2. Forward pass: get the model's predictions
        outputs = model(inputs)

        # 3. Calculate the loss (how wrong the model was)
        loss = criterion(outputs, labels)

        # 4. Backward pass: calculate the gradients
        loss.backward()

        # 5. Update the weights: the optimizer tunes the model
        optimizer.step()

        # Print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:  # Print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0

print('--- Finished Training ---')

# --- EVALUATION ---
print("\n--- Starting Evaluation ---")

# Load the held-out test split (train=False) with the same transform.
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,  # IMPORTANT: use the test set
    download=True,
    transform=transform,
)

# Shuffling is pointless for evaluation, so leave it off.
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=4,
    shuffle=False,
)

correct = 0
total = 0

# Switch to evaluation mode (disables dropout, etc.).
model.eval()

# Gradients aren't needed for inference; skipping them saves memory and time.
with torch.no_grad():
    for images, labels in testloader:
        # Run the batch through the model.
        outputs = model(images)

        # The predicted class is the index of the highest score.
        predicted = torch.argmax(outputs.data, dim=1)

        # Tally overall and correct predictions.
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f'Accuracy of the network on the 10000 test images: {accuracy:.2f} %')