File size: 7,111 Bytes
5413412
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import torch
from tqdm.auto import tqdm
import time
import numpy as np
from src.losses import calc_prob_uncertinty
tic, toc = (time.time, time.time)


def train(probe, device, train_loader, optimizer, epoch, loss_func, 
          class_names=None, report=False, verbose_interval=5, layer_num=40,
          head=None, verbose=True, return_raw_outputs=False, one_hot=False, uncertainty=False, **kwargs,):
    """
    Run one training epoch of ``probe`` over ``train_loader``.

    :param probe: pytorch model (class:torch.nn.Module); its forward must return
        an indexable whose element 0 is the (batch, n_classes) logits tensor
    :param device: device used to train the model (e.g. torch.device("cuda") for training on GPU)
    :param train_loader: iterable of batch dicts with keys "age" (integer labels)
        and "hidden_states" (activations; sliced as [:, layer_num] when layer_num is set)
    :param optimizer: optimizer for the model
    :param epoch: current epoch of training (accepted for caller compatibility; unused here)
    :param loss_func: loss function for the training; receives **kwargs when one_hot is False
    :param class_names: str names for the classification classes (unused here)
    :param report: whether to print a classification report of training (unused here)
    :param verbose_interval: must be None or a positive int (validated but otherwise unused here)
    :param layer_num: layer index to select from "hidden_states"; None keeps all layers
    :param head: unused; kept for caller compatibility
    :param verbose: print an epoch summary when True
    :param return_raw_outputs: also return per-sample predictions and labels
    :param one_hot: one-hot encode targets before the loss (**kwargs forwarded to one_hot)
    :param uncertainty: call calc_prob_uncertinty on the logits
        (NOTE(review): its result is immediately overwritten below — confirm intent)
    :return: (average loss, train accuracy), plus (predictions, truths) as 1-D
        numpy arrays when return_raw_outputs is True
    """
    assert (verbose_interval is None) or verbose_interval > 0, "invalid verbose_interval, verbose_interval(int) > 0"
    starttime = time.time()
    # Set the model to the train mode: Essential for proper gradient descent
    probe.train()
    loss_sum = 0
    correct = 0
    tot = 0

    preds = []
    truths = []

    # Iterate through the train dataset
    for batch_idx, batch in enumerate(train_loader):
        # Fix: honor the `device` argument instead of hard-coding CUDA,
        # so the documented signature actually controls placement.
        target = batch["age"].long().to(device)
        if one_hot:
            target = torch.nn.functional.one_hot(target, **kwargs).float()
        optimizer.zero_grad()

        # Select a single layer's activations when a layer index is given
        # (explicit None-check so layer 0 is still a valid selection).
        if layer_num is not None:
            act = batch["hidden_states"][:, layer_num].to(device)
        else:
            act = batch["hidden_states"].to(device)
        output = probe(act)
        if not one_hot:
            loss = loss_func(output[0], target, **kwargs)
        else:
            loss = loss_func(output[0], target)
        loss.backward()
        optimizer.step()

        loss_sum += loss.sum().item()
        if uncertainty:
            # NOTE(review): `pred` computed here is overwritten just below;
            # call preserved as-is to keep behavior identical.
            pred, uncertainty = calc_prob_uncertinty(output[0].detach().cpu().numpy())
        pred = torch.argmax(output[0], axis=1)

        # One-hot targets: recover class indices before comparing for accuracy.
        if len(target.shape) > 1:
            target = torch.argmax(target, axis=1)
        correct += np.sum(pred.detach().cpu().numpy() == target.detach().cpu().numpy())
        if return_raw_outputs:
            preds.append(pred.detach().cpu().numpy())
            truths.append(target.detach().cpu().numpy())
        tot += pred.shape[0]

    train_acc = correct / tot
    loss_avg = loss_sum / len(train_loader)

    endtime = time.time()
    if verbose:
        print('\nTrain set: Average loss: {:.4f} ({:.3f} sec) Accuracy: {:.3f}\n'.\
              format(loss_avg,
                     endtime-starttime,
                     train_acc))

    # Fix: only concatenate when outputs were collected. The original called
    # np.concatenate on empty lists whenever return_raw_outputs was False,
    # which raises ValueError and made the default path crash.
    if return_raw_outputs:
        return loss_avg, train_acc, np.concatenate(preds), np.concatenate(truths)
    else:
        return loss_avg, train_acc
    

def test(probe, device, test_loader, loss_func, return_raw_outputs=False, verbose=True,
         layer_num=40, scheduler=None, one_hot=False, uncertainty=False, **kwargs):
    """
    :param model: pytorch model (class:torch.nn.Module)
    :param device: device used to train the model (e.g. torch.device("cuda") for training on GPU)
    :param test_loader: torch.utils.data.DataLoader of test dataset
    :param loss_func: loss function for the training
    :param class_names: str Name for the classification classses. used in train report
    :param test_report: whether to print a classification report of testing after each epoch
    :param return_raw_outputs: whether return the raw outputs of model (before argmax). used for auc computation
    :return: average test loss, test accuracy, true labels, predictions, (and raw outputs \ 
    from model if return_raw_outputs)
    """
    # Set the model to evaluation mode: Essential for testing model
    probe.eval()
    test_loss = 0
    tot = 0
    correct = 0
    preds = []
    truths = []
        
    # Do not call gradient descent on the test set
    # We don't adjust the weights of model on the test set
    with torch.no_grad():
        for batch_idx, batch in enumerate(test_loader):
            batch_size = 1
            target = batch["age"].long().cuda()
            if one_hot:
                target = torch.nn.functional.one_hot(target, **kwargs).float()
            if layer_num or layer_num == 0:
                act = batch["hidden_states"][:, layer_num,].to("cuda")
            else:
                act = batch["hidden_states"].to("cuda")
            output = probe(act)
            if uncertainty:
                pred, uncertainty = calc_prob_uncertinty(output[0].detach().cpu().numpy())
            pred = torch.argmax(output[0], axis=1)
            
            if not one_hot:
                loss = loss_func(output[0], target, **kwargs)
            else:
                loss = loss_func(output[0], target)
            test_loss += loss.sum().item()  # sum up batch loss

            # In the Scikit-Learn's implementation of OvR Multi-class Logistic Regression. They linearly normalized the predicted probability and then call argmax
            # Below is an equivalent implementation of the scikit-learn's decision function. The only difference is we didn't do the linearly normalization
            # To save some computation time
            if len(target.shape) > 1:
                target = torch.argmax(target, axis=1)
            
            
            pred = np.array(pred.detach().cpu().numpy())
            target = np.array(target.detach().cpu().numpy())
            correct += np.sum(pred == target)
            tot += pred.shape[0] 
            if return_raw_outputs:
                preds.append(pred)
                truths.append(target)
                
    test_loss /= len(test_loader)
    if scheduler:
        scheduler.step(test_loss)
    
    test_acc = correct / tot

    if verbose:
        print('Test set: Average loss: {:.4f},  Accuracy: {:.3f}\n'.format(
              test_loss,
              test_acc))
        
    preds = np.concatenate(preds)
    truths = np.concatenate(truths)
        
    # If return the raw outputs (before argmax) from the model
    if return_raw_outputs:
        return test_loss, test_acc, preds, truths
    else:
        return test_loss, test_acc

import torch
from tqdm.auto import tqdm
import time
import numpy as np
from .losses import calc_prob_uncertinty
tic, toc = (time.time, time.time)