File size: 2,970 Bytes
c62c87b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from tqdm.auto import tqdm
import torch
from torch import nn

# BUG FIX: the original line was `from data.py import create_dataloaders`.
# Module imports never include the `.py` file extension — that spelling asks
# Python for a submodule named `py` inside a package `data` and raises
# ModuleNotFoundError. The module name is simply `data`.
from data import create_dataloaders

# Get the train/test dataloaders from data.py.
train_loader, test_loader = create_dataloaders()

#define an accuracy function
def accuracy_fn(y_true, y_pred):
    """Return the classification accuracy, in percent, of `y_pred` against `y_true`."""
    n_correct = (y_true == y_pred).sum().item()
    return 100.0 * n_correct / len(y_pred)

#instantiate loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
# NOTE(review): `model` is not defined anywhere in this file — it must be
# created or imported before this line executes, otherwise this raises
# NameError. Confirm where `model` is supposed to come from.
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)


#create a function for a training step
def train_step(model):
    train_loss, train_accuracy = 0, 0
    model.train()

    for batch, (x,y) in enumerate(train_loader):    
        #get predictions
        y_logits = model(x)
        y_pred = y_logits.argmax(dim = 1)
    
        #calculate loss
        loss = loss_fn(y_logits, y)
        train_loss += loss.item()
        train_accuracy += accuracy_fn(y, y_pred)
        
        #update model
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    #divide test loss and accuracy by length of dataloader
    train_loss /= len(train_loader)
    train_accuracy /= len(train_loader)
    
    #return train loss and accuracy
    return train_loss, train_accuracy
    
#create a function to test the model
def test_step(model):
    test_loss, test_accuracy = 0, 0
    
    model.eval()
    with torch.inference_mode():
        for batch, (x,y) in enumerate(test_loader):
            y_logits = model(x)
            y_pred = y_logits.argmax(dim = 1)
            
            loss = loss_fn(y_logits, y)
            test_loss += loss.item()         
            test_accuracy += accuracy_fn(y, y_pred)
        
        #divide test loss and accuracy by length of dataloader
        test_loss /= len(test_loader)
        test_accuracy /= len(test_loader)
    
    #return test loss and accuracy
    return test_loss, test_accuracy

def train(model, epochs):
    """Train and evaluate `model` for a given number of epochs.

    Args:
        model: The model to optimize (updated in place via the module-level
            optimizer).
        epochs: Number of full train/test passes to run.

    Returns:
        Tuple of (model, metrics) where metrics maps "train_loss",
        "test_loss", "train_acc", and "test_acc" to per-epoch lists.
    """
    metrics = {"train_loss": [], "test_loss": [],
               "train_acc": [], "test_acc": []}

    for _ in tqdm(range(epochs)):
        # Train step: record this epoch's loss and accuracy.
        tr_loss, tr_acc = train_step(model)
        metrics["train_loss"].append(tr_loss)
        metrics["train_acc"].append(tr_acc)

        # Test step: record this epoch's loss and accuracy.
        te_loss, te_acc = test_step(model)
        metrics["test_loss"].append(te_loss)
        metrics["test_acc"].append(te_acc)

    return model, metrics