"""Train a MiniViT classifier on CIFAR-10 and report test-set accuracy.

Pipeline: load CIFAR-10, inspect one sample, train with Adam +
cross-entropy for ``num_epochs`` epochs, then evaluate on the held-out
test split.
"""

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

from model import MiniViT

# This is a standard transformation to convert images to PyTorch Tensors
transform = transforms.Compose([transforms.ToTensor()])

# Download and load the CIFAR-10 training dataset
trainset = torchvision.datasets.CIFAR10(root='./data',
                                        train=True,
                                        download=True,
                                        transform=transform)

# Create a DataLoader to handle batching and shuffling
trainloader = torch.utils.data.DataLoader(trainset,
                                          batch_size=4,
                                          shuffle=True)

# --- INSPECT ONE IMAGE ---
# Get one batch of training images
dataiter = iter(trainloader)
images, labels = next(dataiter)

# Select the very first image and its label from the batch
first_image = images[0]
first_label = labels[0]

# Print the shape of the image tensor and its label
print("----Data Inspection---")
print(f"Image shape: {first_image.shape}")
print(f"Label : {first_label.item()}")

model = MiniViT()

# --- TRAINING SETUP ---
# 1. The Loss Function
# CrossEntropyLoss is a standard choice for classification problems.
criterion = nn.CrossEntropyLoss()

# 2. The Optimizer
# Adam is a popular and effective optimizer. We tell it which parameters
# to tune (model.parameters()) and the learning rate (lr).
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# --- THE TRAINING LOOP ---
print("\n--- Starting Training ---")

# Number of full passes over the training data.
# NOTE(review): the original comment claimed 5 epochs while the value was
# 20 — the value is kept, the comment is corrected.
num_epochs = 20

# Ensure dropout/batch-norm layers are in training mode (mirrors the
# explicit model.eval() used before evaluation below).
model.train()

for epoch in range(num_epochs):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # Get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # --- The 5 Core Steps of Training ---
        # 1. Zero the parameter gradients (important!)
        optimizer.zero_grad()
        # 2. Forward pass: get the model's predictions
        outputs = model(inputs)
        # 3. Calculate the loss (how wrong the model was)
        loss = criterion(outputs, labels)
        # 4. Backward pass: calculate the gradients
        loss.backward()
        # 5. Update the weights: the optimizer tunes the model
        optimizer.step()

        # Print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:  # Print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0

print('--- Finished Training ---')

# --- EVALUATION ---
print("\n--- Starting Evaluation ---")

# First, we need to load the test dataset
testset = torchvision.datasets.CIFAR10(root='./data',
                                       train=False,  # IMPORTANT: use the test set
                                       download=True,
                                       transform=transform)

testloader = torch.utils.data.DataLoader(testset,
                                         batch_size=4,
                                         shuffle=False)  # No need to shuffle for testing

correct = 0
total = 0

# Set the model to evaluation mode (disables dropout, etc.)
model.eval()

# We don't need to calculate gradients for evaluation, which saves memory and computations
with torch.no_grad():
    for data in testloader:
        images, labels = data
        # Get the model's predictions
        outputs = model(images)
        # Find the prediction with the highest score (the predicted class).
        # Use `outputs` directly rather than the deprecated `.data` attribute —
        # we are already inside torch.no_grad(), so no autograd tracking occurs.
        _, predicted = torch.max(outputs, 1)
        # Count the total and correct predictions
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f'Accuracy of the network on the 10000 test images: {accuracy:.2f} %')