|
|
import torch
|
|
|
import torchvision
|
|
|
import torchvision.transforms as transforms
|
|
|
import torch.nn as nn
|
|
|
from model import MiniViT
|
|
|
|
|
|
|
|
|
|
|
|
# Preprocessing pipeline: the only step is converting each image to a tensor.
transform = transforms.Compose([transforms.ToTensor()])

# CIFAR-10 training split, kept under ./data (download is requested).
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of four samples each.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=4,
    shuffle=True,
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Pull a single batch from the training loader to sanity-check the data.
dataiter = iter(trainloader)
images, labels = next(dataiter)

# First sample of that batch together with its label.
first_image, first_label = images[0], labels[0]

print("----Data Inspection---")
print(f"Image shape: {first_image.shape}")
print(f"Label : {first_label.item()}")
|
|
|
|
|
|
# Network to train; MiniViT is imported from the project's `model` module
# and is built here with its default constructor arguments.
model = MiniViT()

# Cross-entropy loss between the model outputs and the target labels.
criterion = nn.CrossEntropyLoss()

# Adam over all model parameters with a fixed learning rate of 1e-3.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
|
|
|
|
|
|
|
|
|
print("\n--- Starting Training ---")

num_epochs = 20
# Number of mini-batches between two consecutive loss reports. The same
# value is used as the divisor when averaging the accumulated loss.
log_interval = 2000

# Select training mode explicitly so the loop does not rely on the
# module's default state.
model.train()

for epoch in range(num_epochs):
    # Loss accumulated since the last report; restarted for every epoch.
    running_loss = 0.0

    for i, (inputs, labels) in enumerate(trainloader):
        # Gradients accumulate across backward() calls, so clear them
        # before computing the gradients for this batch.
        optimizer.zero_grad()

        # Forward pass, loss, backward pass, parameter update.
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Report the mean loss once per `log_interval` batches.
        # NOTE: batches after the last full interval of an epoch are
        # accumulated but never reported (unchanged from the original).
        if (i + 1) % log_interval == 0:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / log_interval:.3f}')
            running_loss = 0.0

print('--- Finished Training ---')
|
|
|
|
|
|
|
|
|
print("\n--- Starting Evaluation ---")

# CIFAR-10 test split, using the same preprocessing as the training data.
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Evaluation order does not matter, so shuffling is off.
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=4,
    shuffle=False,
)

correct = 0  # number of samples whose predicted class equals the label
total = 0    # number of samples evaluated

# Switch the model to evaluation mode before measuring accuracy.
model.eval()

# No gradients are needed for evaluation.
with torch.no_grad():
    for images, labels in testloader:
        outputs = model(images)

        # Predicted class = index of the largest output along dim 1.
        # argmax replaces torch.max(outputs.data, 1): the legacy `.data`
        # access is unnecessary under no_grad and the max values were
        # never used.
        predicted = outputs.argmax(dim=1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total

# Report the number of images actually evaluated instead of a hard-coded
# count, so the message stays correct if the dataset changes.
print(f'Accuracy of the network on the {total} test images: {accuracy:.2f} %')