Spaces:
Runtime error
Runtime error
from tqdm.auto import tqdm
import torch
from torch import nn

# Import by module name. ``from data.py import ...`` is parsed as submodule
# ``py`` of a package ``data`` and fails with ModuleNotFoundError.
from data import create_dataloaders

# Build the train/test dataloaders once at import time; the step functions
# below read these module-level names.
train_loader, test_loader = create_dataloaders()
def accuracy_fn(y_true, y_pred):
    """Return the percentage of positions where ``y_pred`` equals ``y_true``.

    Args:
        y_true: tensor of target labels.
        y_pred: tensor of predicted labels with the same shape as ``y_true``
            (or broadcastable to it).

    Returns:
        Accuracy in percent as a Python float. Returns 0.0 for empty input
        instead of raising ZeroDivisionError.
    """
    matches = torch.eq(y_true, y_pred)
    # Divide by the number of compared elements, not len() (first dimension
    # only), so multi-dimensional inputs cannot exceed 100%.
    total = matches.numel()
    if total == 0:
        return 0.0
    correct = matches.sum().item()
    return correct / total * 100
# Loss shared by the training and evaluation steps; CrossEntropyLoss takes
# raw (unnormalised) logits and class labels.
loss_fn = nn.CrossEntropyLoss()

# NOTE: no module-level optimizer. An optimizer needs a model's parameters,
# and no model exists at import time (the former
# ``torch.optim.Adam(model.parameters(), lr=1e-3)`` here raised NameError on
# import). ``train`` builds one per training run and hands it to
# ``train_step``.


def train_step(model, optimizer=None):
    """Train ``model`` for one pass over the module-level ``train_loader``.

    Args:
        model: the network to train; it is switched to train mode.
        optimizer: optimizer over ``model``'s parameters. If omitted, a new
            ``torch.optim.Adam`` with ``lr=1e-3`` is created for this call
            only, so its internal state is not kept between calls. Pass a
            persistent optimizer (as ``train`` does) for multi-epoch training.

    Returns:
        ``(train_loss, train_accuracy)``: loss and accuracy (percent),
        each averaged over batches with every batch weighted equally.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("train_loader is empty; cannot average over batches")
    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    train_loss, train_accuracy = 0.0, 0.0
    model.train()
    for x, y in train_loader:
        # Forward pass: logits for the loss, argmax class for accuracy.
        y_logits = model(x)
        y_pred = y_logits.argmax(dim=1)
        loss = loss_fn(y_logits, y)
        train_loss += loss.item()
        train_accuracy += accuracy_fn(y, y_pred)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Average the accumulated training loss and accuracy over batches.
    return train_loss / num_batches, train_accuracy / num_batches


def test_step(model):
    """Evaluate ``model`` on the module-level ``test_loader`` without training.

    Args:
        model: the network to evaluate; it is switched to eval mode.

    Returns:
        ``(test_loss, test_accuracy)``: loss and accuracy (percent),
        each averaged over batches with every batch weighted equally.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    num_batches = len(test_loader)
    if num_batches == 0:
        raise ValueError("test_loader is empty; cannot average over batches")

    test_loss, test_accuracy = 0.0, 0.0
    model.eval()
    # inference_mode disables autograd tracking for the evaluation pass.
    with torch.inference_mode():
        for x, y in test_loader:
            y_logits = model(x)
            y_pred = y_logits.argmax(dim=1)
            test_loss += loss_fn(y_logits, y).item()
            test_accuracy += accuracy_fn(y, y_pred)

    return test_loss / num_batches, test_accuracy / num_batches


def train(model, epochs, optimizer=None, lr=1e-3):
    """Train ``model`` for ``epochs`` epochs, evaluating after each one.

    Args:
        model: the network to train.
        epochs: number of passes over ``train_loader``.
        optimizer: optional optimizer over ``model``'s parameters. If omitted,
            ``torch.optim.Adam(model.parameters(), lr=lr)`` is created once
            and reused for every epoch.
        lr: learning rate for the default Adam optimizer; ignored when
            ``optimizer`` is given.

    Returns:
        ``(model, metrics)``: the trained model (the same object that was
        passed in) and a dict with keys ``"train_loss"``, ``"test_loss"``,
        ``"train_acc"`` and ``"test_acc"``, each a list holding one value
        per epoch.
    """
    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    metrics = {"train_loss": [], "test_loss": [],
               "train_acc": [], "test_acc": []}
    for _ in tqdm(range(epochs)):
        new_train_loss, new_train_acc = train_step(model, optimizer)
        metrics["train_loss"].append(new_train_loss)
        metrics["train_acc"].append(new_train_acc)

        new_test_loss, new_test_acc = test_step(model)
        metrics["test_loss"].append(new_test_loss)
        metrics["test_acc"].append(new_test_acc)

    return model, metrics