|
|
from tqdm.auto import tqdm
|
|
|
import torch
|
|
|
from torch import nn
|
|
|
|
|
|
# Import by module name, not file name: `from data.py import ...` makes Python
# look for a submodule `py` inside a package `data`, which fails when `data.py`
# is a plain module.
from data import create_dataloaders

# Build both loaders once at import time; train_step and test_step read these
# module-level names directly.
train_loader, test_loader = create_dataloaders()
|
|
|
|
|
|
|
|
|
def accuracy_fn(y_true, y_pred):
    """Return the percentage of predictions that equal their labels.

    Args:
        y_true: Tensor of ground-truth labels.
        y_pred: Tensor of predicted labels, compared element-wise with
            ``y_true`` via ``torch.eq``.

    Returns:
        The accuracy as a float percentage. An empty ``y_pred`` yields
        ``0.0`` rather than raising ``ZeroDivisionError``.
    """
    total = len(y_pred)
    if total == 0:
        # Nothing to score; avoid dividing by zero below.
        return 0.0

    correct = torch.eq(y_true, y_pred).sum().item()
    return (correct / total) * 100
|
|
|
|
|
|
|
|
|
# Loss criterion shared by train_step and test_step.
loss_fn = nn.CrossEntropyLoss()

# Adam optimizer that train_step uses to update the model's parameters.
# NOTE(review): `model` is not defined in the visible part of this file before
# this line; unless it is provided elsewhere, this raises NameError at import
# time -- confirm where `model` comes from.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
|
|
|
|
|
|
|
|
|
|
|
|
def train_step(model):
    """Run one optimisation pass over ``train_loader``.

    Uses the module-level ``loss_fn``, ``optimizer`` and ``accuracy_fn``.

    Args:
        model: The model to train; it is switched to training mode and
            updated once per batch.

    Returns:
        A ``(loss, accuracy)`` tuple: the per-batch loss and the per-batch
        accuracy (in percent), each averaged over the number of batches.
    """
    loss_total = 0
    acc_total = 0

    model.train()

    for inputs, targets in train_loader:
        # Forward pass and loss for this batch.
        logits = model(inputs)
        predictions = logits.argmax(dim=1)
        batch_loss = loss_fn(logits, targets)

        # Accumulate metrics before the weights change.
        loss_total += batch_loss.item()
        acc_total += accuracy_fn(targets, predictions)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

    n_batches = len(train_loader)
    return loss_total / n_batches, acc_total / n_batches
|
|
|
|
|
|
|
|
|
def test_step(model):
    """Evaluate ``model`` over ``test_loader`` without updating it.

    Uses the module-level ``loss_fn`` and ``accuracy_fn``.

    Args:
        model: The model to evaluate; it is switched to evaluation mode.

    Returns:
        A ``(loss, accuracy)`` tuple: the per-batch loss and the per-batch
        accuracy (in percent), each averaged over the number of batches.
    """
    loss_total = 0
    acc_total = 0

    model.eval()

    # Inference mode: no gradients are needed for evaluation.
    with torch.inference_mode():
        for inputs, targets in test_loader:
            logits = model(inputs)
            predictions = logits.argmax(dim=1)

            batch_loss = loss_fn(logits, targets)
            loss_total += batch_loss.item()
            acc_total += accuracy_fn(targets, predictions)

    n_batches = len(test_loader)
    return loss_total / n_batches, acc_total / n_batches
|
|
|
|
|
|
def train(model, epochs):
    """Train and evaluate a model for a given number of epochs.

    Each epoch runs ``train_step`` followed by ``test_step`` and records
    the values they return.

    Args:
        model: The model to train.
        epochs: Number of epochs to run.

    Returns:
        A ``(model, metrics)`` tuple, where ``model`` is the object that was
        passed in and ``metrics`` is a dict with the keys ``"train_loss"``,
        ``"test_loss"``, ``"train_acc"`` and ``"test_acc"``, each mapping to
        a list with one entry per epoch.
    """
    history = {
        "train_loss": [],
        "test_loss": [],
        "train_acc": [],
        "test_acc": [],
    }

    for _ in tqdm(range(epochs)):
        epoch_train_loss, epoch_train_acc = train_step(model)
        history["train_loss"].append(epoch_train_loss)
        history["train_acc"].append(epoch_train_acc)

        epoch_test_loss, epoch_test_acc = test_step(model)
        history["test_loss"].append(epoch_test_loss)
        history["test_acc"].append(epoch_test_acc)

    return model, history
|
|
|
|