"""Training and evaluation loops for a classifier trained with cross-entropy."""

from tqdm.auto import tqdm
import torch
from torch import nn

from data import create_dataloaders

# Learning rate used whenever an Adam optimizer is built on the caller's behalf.
LEARNING_RATE = 1e-3

# Get the train/test dataloaders from data.py.
train_loader, test_loader = create_dataloaders()

# Loss shared by the training and evaluation steps.
loss_fn = nn.CrossEntropyLoss()


def _make_optimizer(model):
    """Build the default Adam optimizer over ``model``'s parameters."""
    return torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)


def accuracy_fn(y_true, y_pred):
    """Return the percentage (0-100) of entries of ``y_pred`` equal to ``y_true``.

    Args:
        y_true: Tensor of reference labels.
        y_pred: Tensor of predicted labels, comparable element-wise to ``y_true``.

    Returns:
        The share of matching entries as a percentage, or ``0.0`` when
        ``y_pred`` is empty (instead of dividing by zero).
    """
    total = len(y_pred)
    if total == 0:
        return 0.0
    correct = torch.eq(y_true, y_pred).sum().item()
    return (correct / total) * 100


def train_step(model, optimizer=None):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Module whose call returns logits; predictions are taken with
            ``argmax(dim=1)``.
        optimizer: Optimizer updating ``model``'s parameters. When omitted, a
            fresh Adam optimizer is created for this call only, so its moment
            estimates do not carry over between calls; pass one explicitly
            (as ``train`` does) for multi-epoch training.

    Returns:
        ``(loss, accuracy)`` averaged over the batches of ``train_loader``;
        ``(0.0, 0.0)`` if the loader yields no batches.
    """
    if optimizer is None:
        optimizer = _make_optimizer(model)

    train_loss, train_accuracy = 0.0, 0.0
    model.train()

    for x, y in train_loader:
        # Get predictions.
        y_logits = model(x)
        y_pred = y_logits.argmax(dim=1)

        # Calculate loss and accumulate the per-batch metrics.
        loss = loss_fn(y_logits, y)
        train_loss += loss.item()
        train_accuracy += accuracy_fn(y, y_pred)

        # Update the model.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Average the train loss and accuracy over the number of batches.
    num_batches = len(train_loader)
    if num_batches == 0:
        return 0.0, 0.0
    return train_loss / num_batches, train_accuracy / num_batches


def test_step(model):
    """Evaluate ``model`` on ``test_loader`` without tracking gradients.

    Args:
        model: Module whose call returns logits; predictions are taken with
            ``argmax(dim=1)``.

    Returns:
        ``(loss, accuracy)`` averaged over the batches of ``test_loader``;
        ``(0.0, 0.0)`` if the loader yields no batches.
    """
    test_loss, test_accuracy = 0.0, 0.0
    model.eval()

    with torch.inference_mode():
        for x, y in test_loader:
            y_logits = model(x)
            y_pred = y_logits.argmax(dim=1)
            test_loss += loss_fn(y_logits, y).item()
            test_accuracy += accuracy_fn(y, y_pred)

    # Average the test loss and accuracy over the number of batches.
    num_batches = len(test_loader)
    if num_batches == 0:
        return 0.0, 0.0
    return test_loss / num_batches, test_accuracy / num_batches


def train(model, epochs, optimizer=None):
    """Train ``model`` for ``epochs`` epochs, evaluating after each one.

    Args:
        model: Module to train; it is updated in place.
        epochs: Number of passes over ``train_loader``.
        optimizer: Optimizer to use. Defaults to Adam with ``LEARNING_RATE``,
            created once here so its state persists across epochs.

    Returns:
        ``(model, metrics)`` where ``metrics`` is a dict with the keys
        ``"train_loss"``, ``"test_loss"``, ``"train_acc"`` and ``"test_acc"``,
        each mapping to a list holding one value per epoch.
    """
    if optimizer is None:
        optimizer = _make_optimizer(model)

    # Per-epoch history of train/test metrics.
    train_loss, test_loss, train_acc, test_acc = [], [], [], []

    for _ in tqdm(range(epochs)):
        # Train step: save the loss and accuracy.
        new_train_loss, new_train_acc = train_step(model, optimizer)
        train_loss.append(new_train_loss)
        train_acc.append(new_train_acc)

        # Test step: save the loss and accuracy.
        new_test_loss, new_test_acc = test_step(model)
        test_loss.append(new_test_loss)
        test_acc.append(new_test_acc)

    # Put the metrics in a dictionary.
    metrics = {
        "train_loss": train_loss,
        "test_loss": test_loss,
        "train_acc": train_acc,
        "test_acc": test_acc,
    }
    return model, metrics