# File size: 2,970 Bytes (commit c62c87b)
import weakref

from tqdm.auto import tqdm
import torch
from torch import nn
# Module names are imported without the ".py" suffix: `from data.py import ...`
# would look for a submodule `py` inside a package `data` and fail.
from data import create_dataloaders

# Get the train/test dataloaders from data.py once, at import time;
# train_step and test_step iterate these by default.
train_loader, test_loader = create_dataloaders()
# Accuracy metric shared by train_step and test_step.
def accuracy_fn(y_true, y_pred):
    """Return classification accuracy as a percentage.

    Args:
        y_true: tensor of target labels.
        y_pred: tensor of predicted labels, compared element-wise with ``y_true``.

    Returns:
        The number of matching elements divided by ``len(y_pred)``, times 100,
        as a float.
    """
    matches = torch.eq(y_true, y_pred)
    n_correct = matches.sum().item()
    return n_correct / len(y_pred) * 100
# Loss function shared by train_step and test_step.
loss_fn = nn.CrossEntropyLoss()

# Learning rate of the default Adam optimizer built by train_step.
LEARNING_RATE = 1e-3

# The optimizer needs a model's parameters, so it cannot be created at import
# time: the former module-level `torch.optim.Adam(model.parameters(), ...)`
# referenced a `model` that does not exist here and raised NameError on import.
# Instead one Adam instance is created per model on first use and cached, so
# its internal state carries over between epochs. Weak keys mean the cache
# does not keep a model alive on its own.
_optimizers = weakref.WeakKeyDictionary()


def _default_optimizer(model):
    """Return the Adam optimizer cached for *model*, creating it on first use."""
    optimizer = _optimizers.get(model)
    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
        _optimizers[model] = optimizer
    return optimizer


def train_step(model, optimizer=None, dataloader=None):
    """Run one training epoch of *model*.

    Args:
        model: network to train. It is switched to training mode and its
            parameters are updated by the optimizer.
        optimizer: optimizer to step after each batch. Defaults to an Adam
            optimizer (lr=LEARNING_RATE) created on the first call for this
            model and reused on later calls.
        dataloader: iterable of ``(x, y)`` batches. Defaults to the
            module-level ``train_loader``.

    Returns:
        ``(train_loss, train_accuracy)``: the per-batch loss and accuracy
        (percentage, from ``accuracy_fn``) summed and divided by
        ``len(dataloader)``.
    """
    if optimizer is None:
        optimizer = _default_optimizer(model)
    if dataloader is None:
        dataloader = train_loader

    train_loss, train_accuracy = 0, 0
    model.train()
    for x, y in dataloader:
        # Forward pass; the predicted class is the argmax over dim 1 of the logits.
        y_logits = model(x)
        y_pred = y_logits.argmax(dim=1)

        loss = loss_fn(y_logits, y)
        train_loss += loss.item()
        train_accuracy += accuracy_fn(y, y_pred)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Divide the accumulated train loss and accuracy by the number of batches.
    train_loss /= len(dataloader)
    train_accuracy /= len(dataloader)
    return train_loss, train_accuracy
# Evaluate the model on held-out data.
def test_step(model, dataloader=None):
    """Evaluate *model* on a dataloader without tracking gradients.

    Args:
        model: network to evaluate. It is switched to eval mode and left in it.
        dataloader: iterable of ``(x, y)`` batches. Defaults to the
            module-level ``test_loader``.

    Returns:
        ``(test_loss, test_accuracy)``: the per-batch loss and accuracy
        (percentage, from ``accuracy_fn``) summed and divided by
        ``len(dataloader)``.
    """
    if dataloader is None:
        dataloader = test_loader

    test_loss, test_accuracy = 0, 0
    model.eval()
    # inference_mode disables autograd tracking for the evaluation pass.
    with torch.inference_mode():
        for x, y in dataloader:
            y_logits = model(x)
            y_pred = y_logits.argmax(dim=1)
            loss = loss_fn(y_logits, y)
            test_loss += loss.item()
            test_accuracy += accuracy_fn(y, y_pred)

    # Divide the accumulated test loss and accuracy by the number of batches.
    test_loss /= len(dataloader)
    test_accuracy /= len(dataloader)
    return test_loss, test_accuracy
def train(model, epochs):
    """Train a model for a fixed number of epochs, evaluating after each one.

    Args:
        model: the network to train; ``train_step`` updates its parameters.
        epochs: how many train/evaluate rounds to run.

    Returns:
        A ``(model, metrics)`` tuple. ``metrics`` maps ``"train_loss"``,
        ``"test_loss"``, ``"train_acc"`` and ``"test_acc"`` to lists holding
        one value per epoch.
    """
    history = {"train_loss": [], "test_loss": [], "train_acc": [], "test_acc": []}
    for _ in tqdm(range(epochs)):
        # One pass over the training data, then one evaluation pass.
        epoch_train_loss, epoch_train_acc = train_step(model)
        epoch_test_loss, epoch_test_acc = test_step(model)

        history["train_loss"].append(epoch_train_loss)
        history["train_acc"].append(epoch_train_acc)
        history["test_loss"].append(epoch_test_loss)
        history["test_acc"].append(epoch_test_acc)
    return model, history