|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from torchvision.transforms import ToTensor |
|
|
|
|
|
from torchvision import transforms |
|
|
from torch.utils.data import DataLoader |
|
|
from torch.utils.data import Dataset |
|
|
from torchvision import datasets |
|
|
import matplotlib.pyplot as plt |
|
|
from PIL import Image |
|
|
from time import time |
|
|
from torch import nn |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import torch, os |
|
|
from utils import ApplyEnhancementFilter |
|
|
|
|
|
|
|
|
# Pick the best available accelerator, falling back to CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device} device for training/inference.")

if device == "cuda":
    print(f"GPU being used: {torch.cuda.get_device_name(0)}")
|
|
|
|
|
|
|
|
# Training pipeline: geometric augmentation on the PIL image, then padding,
# tensor conversion and normalization.  Pad runs BEFORE Normalize so the
# fill=0 border is true black that gets normalized along with the real
# pixels; in the original order (Normalize then Pad) the border was 0 in
# normalized space, i.e. the dataset mean gray rather than black.
# Pad(2) grows 28x28 MNIST digits to the 32x32 input LeNet-5 expects.
train_transform = transforms.Compose([
    # NOTE(review): RandomAffine already rotates up to 35 degrees;
    # RandomRotation applies a second independent rotation on top of it.
    transforms.RandomAffine(degrees=35, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    transforms.RandomRotation(degrees=35),
    transforms.Pad(2, fill=0, padding_mode='constant'),
    transforms.ToTensor(),
    # Mean/std of the MNIST training set (single grayscale channel).
    transforms.Normalize((0.13066047430038452,), (0.30810782313346863,)),
])
|
|
|
|
|
|
|
|
# Evaluation pipeline: no augmentation.  Pad runs BEFORE Normalize so the
# fill=0 border is true black normalized with the rest of the image; in the
# original order (Normalize then Pad) the border ended up at the dataset
# mean instead of black.  Pad(2) grows 28x28 digits to LeNet-5's 32x32 input.
test_transform = transforms.Compose([
    transforms.Pad(2, fill=0, padding_mode='constant'),
    transforms.ToTensor(),
    # Mean/std of the MNIST training set (single grayscale channel).
    transforms.Normalize((0.13066047430038452,), (0.30810782313346863,)),
])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import idx2numpy |
|
|
class CustomImageDataset(Dataset):
    """MNIST-style dataset backed by a pair of IDX files.

    Subclasses torch.utils.data.Dataset and provides the required
    __init__, __len__ and __getitem__ methods.
    """

    def __init__(self, annotations_file, image_file, transform=None, target_transform=None):
        # Load the full label and image arrays into memory up front.
        self.img_labels = idx2numpy.convert_from_file(annotations_file)
        self.images = idx2numpy.convert_from_file(image_file)
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Number of samples in the dataset."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return the (image, label) pair at position ``idx``."""
        label = self.img_labels[idx]
        # Wrap the raw numpy array so PIL-based torchvision transforms work.
        img = Image.fromarray(self.images[idx])

        if self.transform:
            img = self.transform(img)
        if self.target_transform:
            label = self.target_transform(label)

        return img, label
|
|
|
|
|
|
|
|
|
|
|
class LeNet5Model(nn.Module):
    """Classic LeNet-5 convolutional network for 1x32x32 grayscale digits.

    Outputs raw logits over 10 classes.  The Sequential layout (and hence
    the state_dict key names) matches the original checkpoint format.
    """

    def __init__(self):
        super().__init__()

        # One shared activation module; Tanh is stateless so reuse is safe.
        self.tanh = nn.Tanh()

        # Convolutional feature extractor: (N, 1, 32, 32) -> (N, 120, 1, 1).
        self.le_stack = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1),
            self.tanh,
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1),
            self.tanh,
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5, stride=1),
            self.tanh
        )

        # Classifier head: 120 -> 84 -> 10 logits.
        self.fc_stack = nn.Sequential(
            nn.Linear(in_features=120, out_features=84),
            self.tanh,
            nn.Linear(in_features=84, out_features=10)
        )

    def forward(self, x):
        """Run a batch through the network and return class logits."""
        features = self.le_stack(x)
        # (N, 120, 1, 1) -> (N, 120) for the fully-connected head.
        flat = features.flatten(start_dim=1)
        return self.fc_stack(flat)
|
|
|
|
|
|
|
|
def train_model(model, train_loader, test_loader, epochs=10, learning_rate=0.001, saved_model=None):
    """Train ``model`` and checkpoint the best-performing weights.

    Runs ``epochs`` passes over ``train_loader``, evaluating accuracy on
    ``test_loader`` after every epoch.  Whenever the accuracy beats the best
    value recorded in ``best_model.txt``, the weights are saved to
    ``lenet_mnist_model.pth`` and the record file is updated.

    Args:
        model: the nn.Module to train (assumed already moved to ``device``).
        train_loader: DataLoader yielding (image, label) training batches.
        test_loader: DataLoader yielding (image, label) evaluation batches.
        epochs: number of full passes over the training data.
        learning_rate: AdamW learning rate.
        saved_model: optional path to a state-dict checkpoint to resume from.
    """
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-6)

    # Resume the best-accuracy-so-far record; tolerate a missing or
    # corrupt record file instead of crashing on float().
    best_accuracy = 0.0
    if os.path.exists("best_model.txt"):
        with open("best_model.txt", "r") as file:
            try:
                best_accuracy = float(file.read())
            except ValueError:
                print("Warning: best_model.txt is unreadable; starting from 0.0")

    if saved_model is not None:
        # map_location keeps a GPU-saved checkpoint loadable on CPU/MPS.
        model.load_state_dict(torch.load(saved_model, map_location=device))

    for i in range(epochs):
        model.train()
        print("Epoch ", i)
        for batch, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)

            y_pred = model(x)
            loss = loss_fn(y_pred, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Periodic progress logging.
            if batch % 250 == 0:
                print(f"Epoch {i} batch {batch} loss: {loss.item()}")

        # Per-epoch evaluation: accuracy over the held-out set.
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for x, y in test_loader:
                x, y = x.to(device), y.to(device)

                y_pred = model(x)
                _, predicted = torch.max(y_pred, 1)
                total += y.size(0)
                correct += (predicted == y).sum().item()

        accuracy = correct / total
        print(f"Epoch {i} accuracy: {accuracy}")
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save(model.state_dict(), "lenet_mnist_model.pth")
            with open("best_model.txt", "w") as file:
                file.write(f"{best_accuracy}")

    print("Training complete.")
|
|
|
|
|
|
|
|
def init_weights(m):
    """Xavier-initialize Conv2d/Linear weights; fill biases with 0.01.

    Intended to be passed to ``nn.Module.apply``.  The original version
    guarded ``bias is not None`` only for Conv2d and would raise
    AttributeError on a Linear layer built with ``bias=False``; both layer
    types are now guarded uniformly.
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            m.bias.data.fill_(0.01)
|
|
|
|
|
if __name__ == "__main__":

    # Build evaluation and training datasets from the local IDX files.
    test_data = CustomImageDataset("mnist_dataset/t10k-labels.idx1-ubyte", "mnist_dataset/t10k-images.idx3-ubyte", transform=test_transform)
    print((test_data[0])[0].shape, "label value", test_data[0][1])
    train_data = CustomImageDataset("mnist_dataset/train-labels.idx1-ubyte", "mnist_dataset/train-images.idx3-ubyte", transform=train_transform)

    test_loader = DataLoader(test_data, batch_size=64, shuffle=True)
    train_loader = DataLoader(train_data, batch_size=64, shuffle=True)

    model = LeNet5Model().to(device)
    model.apply(init_weights)
    print(model)

    train_model(model, train_loader, test_loader, epochs=1000, learning_rate=0.001)

    # Save the final-epoch weights under a distinct name: train_model already
    # writes the BEST checkpoint to lenet_mnist_model.pth, and the original
    # unconditional save to that same path clobbered it with last-epoch weights.
    torch.save(model.state_dict(), "lenet_mnist_model_final.pth")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|