# Author: richardschattner
# Commit: c62c87b ("Upload 4 files")
from tqdm.auto import tqdm
import torch
from torch import nn
# Import by module name: "data.py" would be parsed as submodule "py" of a
# package "data" and fail with ModuleNotFoundError.
from data import create_dataloaders

# Build the train/test dataloaders defined in data.py once, at import time.
train_loader, test_loader = create_dataloaders()
def accuracy_fn(y_true, y_pred):
    """Return the percentage of positions where ``y_pred`` equals ``y_true``.

    Args:
        y_true: tensor of target labels.
        y_pred: tensor of predicted labels with the same shape as ``y_true``.

    Returns:
        Accuracy in percent (0-100) as a Python float.

    Raises:
        ZeroDivisionError: if ``y_pred`` is empty.
    """
    correct = torch.eq(y_true, y_pred).sum().item()
    # numel() rather than len(): len() counts only the first dimension, which
    # over-reports accuracy (above 100%) for multi-dimensional label tensors.
    return (correct / y_pred.numel()) * 100
#instantiate loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
#create a function for a training step
def train_step(model):
train_loss, train_accuracy = 0, 0
model.train()
for batch, (x,y) in enumerate(train_loader):
#get predictions
y_logits = model(x)
y_pred = y_logits.argmax(dim = 1)
#calculate loss
loss = loss_fn(y_logits, y)
train_loss += loss.item()
train_accuracy += accuracy_fn(y, y_pred)
#update model
optimizer.zero_grad()
loss.backward()
optimizer.step()
#divide test loss and accuracy by length of dataloader
train_loss /= len(train_loader)
train_accuracy /= len(train_loader)
#return train loss and accuracy
return train_loss, train_accuracy
#create a function to test the model
def test_step(model):
test_loss, test_accuracy = 0, 0
model.eval()
with torch.inference_mode():
for batch, (x,y) in enumerate(test_loader):
y_logits = model(x)
y_pred = y_logits.argmax(dim = 1)
loss = loss_fn(y_logits, y)
test_loss += loss.item()
test_accuracy += accuracy_fn(y, y_pred)
#divide test loss and accuracy by length of dataloader
test_loss /= len(test_loader)
test_accuracy /= len(test_loader)
#return test loss and accuracy
return test_loss, test_accuracy
def train(model, epochs):
"""Trains a model for a given number of epochs
Args: model and epochs
Returns: The trained model and a dictionary of train/test loss and train/test accuracy for each epoch.
"""
#create an empty list of train/test metrics
train_loss, test_loss, train_acc, test_acc = [], [], [], []
for epoch in tqdm(range(epochs)):
#train step and save the loss and accuracy
new_train_loss, new_train_acc = train_step(model)
train_loss.append(new_train_loss)
train_acc.append(new_train_acc)
#test step and save the loss and accuracy
new_test_loss, new_test_acc = test_step(model)
test_loss.append(new_test_loss)
test_acc.append(new_test_acc)
#put the metrics in a dictionary
metrics = {"train_loss": train_loss, "test_loss" : test_loss,
"train_acc": train_acc, "test_acc": test_acc}
return model, metrics