import torch
from tqdm.auto import tqdm
import time
import numpy as np
from src.losses import calc_prob_uncertinty
tic, toc = (time.time, time.time)
def train(probe, device, train_loader, optimizer, epoch, loss_func,
          class_names=None, report=False, verbose_interval=5, layer_num=40,
          head=None, verbose=True, return_raw_outputs=False, one_hot=False,
          uncertainty=False, **kwargs):
    """
    Run one epoch of training for ``probe`` and report loss/accuracy.

    :param probe: pytorch model (class:torch.nn.Module) to train
    :param device: device the batch tensors are moved to
        (e.g. torch.device("cuda") for training on GPU)
    :param train_loader: iterable of batches; each batch is a dict with
        keys "age" (integer labels) and "hidden_states" (activations)
    :param optimizer: optimizer for the model
    :param epoch: current epoch of training (informational only)
    :param loss_func: loss function for the training; receives ``**kwargs``
        unless ``one_hot`` is set
    :param class_names: str names for the classification classes
        (unused here; kept for interface compatibility)
    :param report: whether to print a classification report of training
        (unused here; kept for interface compatibility)
    :param verbose_interval: must be None or a positive int (validated only)
    :param layer_num: if not None, select this layer slice from
        ``hidden_states``; otherwise feed the full tensor to the probe
    :param head: unused here; kept for interface compatibility
    :param verbose: print a one-line summary at the end of the epoch
    :param return_raw_outputs: also return per-sample predictions and labels
    :param one_hot: one-hot encode targets; in that case ``**kwargs`` go to
        ``one_hot`` instead of ``loss_func``
    :param uncertainty: additionally run ``calc_prob_uncertinty`` on the
        logits (result currently unused downstream)
    :return: (average loss, train accuracy) or, with ``return_raw_outputs``,
        (average loss, train accuracy, predictions, true labels)
    """
    assert (verbose_interval is None) or verbose_interval > 0, \
        "invalid verbose_interval, verbose_interval(int) > 0"
    start_time = time.time()
    # Set the model to the train mode: Essential for proper gradient descent
    probe.train()
    loss_sum = 0.0
    correct = 0
    tot = 0
    preds = []
    truths = []
    # Iterate through the train dataset
    for batch_idx, batch in enumerate(train_loader):
        # BUG FIX: honor the `device` argument instead of hard-coded .cuda()
        target = batch["age"].long().to(device)
        if one_hot:
            target = torch.nn.functional.one_hot(target, **kwargs).float()
        optimizer.zero_grad()
        # BUG FIX: `layer_num or layer_num == 0` was an awkward None-check
        if layer_num is not None:
            act = batch["hidden_states"][:, layer_num].to(device)
        else:
            act = batch["hidden_states"].to(device)
        output = probe(act)
        # kwargs are consumed by one_hot when one_hot=True, else by loss_func
        if not one_hot:
            loss = loss_func(output[0], target, **kwargs)
        else:
            loss = loss_func(output[0], target)
        loss.backward()
        optimizer.step()
        loss_sum += loss.sum().item()
        if uncertainty:
            # BUG FIX: the original unpacked into `pred, uncertainty`, which
            # clobbered the `uncertainty` flag (a numpy array is ambiguous in
            # a bool context on the next iteration) and was immediately
            # overwritten by the argmax below anyway.
            _prob, _unc = calc_prob_uncertinty(output[0].detach().cpu().numpy())
        pred = torch.argmax(output[0], axis=1)
        # In the Scikit-Learn's implementation of OvR Multi-class Logistic Regression. They linearly normalized the predicted probability and then call argmax
        # Below is an equivalent implementation of the scikit-learn's decision function. The only difference is we didn't do the linearly normalization
        # To save some computation time
        if len(target.shape) > 1:
            target = torch.argmax(target, axis=1)
        correct += np.sum(pred.detach().cpu().numpy() == target.detach().cpu().numpy())
        if return_raw_outputs:
            preds.append(pred.detach().cpu().numpy())
            truths.append(target.detach().cpu().numpy())
        tot += pred.shape[0]
    train_acc = correct / tot
    loss_avg = loss_sum / len(train_loader)
    if verbose:
        print('\nTrain set: Average loss: {:.4f} ({:.3f} sec) Accuracy: {:.3f}\n'.format(
            loss_avg, time.time() - start_time, train_acc))
    if return_raw_outputs:
        # BUG FIX: only concatenate when outputs were collected;
        # np.concatenate([]) raises ValueError, so the original crashed
        # whenever return_raw_outputs was False.
        return loss_avg, train_acc, np.concatenate(preds), np.concatenate(truths)
    return loss_avg, train_acc
def test(probe, device, test_loader, loss_func, return_raw_outputs=False, verbose=True,
layer_num=40, scheduler=None, one_hot=False, uncertainty=False, **kwargs):
"""
:param model: pytorch model (class:torch.nn.Module)
:param device: device used to train the model (e.g. torch.device("cuda") for training on GPU)
:param test_loader: torch.utils.data.DataLoader of test dataset
:param loss_func: loss function for the training
:param class_names: str Name for the classification classses. used in train report
:param test_report: whether to print a classification report of testing after each epoch
:param return_raw_outputs: whether return the raw outputs of model (before argmax). used for auc computation
:return: average test loss, test accuracy, true labels, predictions, (and raw outputs \
from model if return_raw_outputs)
"""
# Set the model to evaluation mode: Essential for testing model
probe.eval()
test_loss = 0
tot = 0
correct = 0
preds = []
truths = []
# Do not call gradient descent on the test set
# We don't adjust the weights of model on the test set
with torch.no_grad():
for batch_idx, batch in enumerate(test_loader):
batch_size = 1
target = batch["age"].long().cuda()
if one_hot:
target = torch.nn.functional.one_hot(target, **kwargs).float()
if layer_num or layer_num == 0:
act = batch["hidden_states"][:, layer_num,].to("cuda")
else:
act = batch["hidden_states"].to("cuda")
output = probe(act)
if uncertainty:
pred, uncertainty = calc_prob_uncertinty(output[0].detach().cpu().numpy())
pred = torch.argmax(output[0], axis=1)
if not one_hot:
loss = loss_func(output[0], target, **kwargs)
else:
loss = loss_func(output[0], target)
test_loss += loss.sum().item() # sum up batch loss
# In the Scikit-Learn's implementation of OvR Multi-class Logistic Regression. They linearly normalized the predicted probability and then call argmax
# Below is an equivalent implementation of the scikit-learn's decision function. The only difference is we didn't do the linearly normalization
# To save some computation time
if len(target.shape) > 1:
target = torch.argmax(target, axis=1)
pred = np.array(pred.detach().cpu().numpy())
target = np.array(target.detach().cpu().numpy())
correct += np.sum(pred == target)
tot += pred.shape[0]
if return_raw_outputs:
preds.append(pred)
truths.append(target)
test_loss /= len(test_loader)
if scheduler:
scheduler.step(test_loss)
test_acc = correct / tot
if verbose:
print('Test set: Average loss: {:.4f}, Accuracy: {:.3f}\n'.format(
test_loss,
test_acc))
preds = np.concatenate(preds)
truths = np.concatenate(truths)
# If return the raw outputs (before argmax) from the model
if return_raw_outputs:
return test_loss, test_acc, preds, truths
else:
return test_loss, test_acc
import torch
from tqdm.auto import tqdm
import time
import numpy as np
from .losses import calc_prob_uncertinty
tic, toc = (time.time, time.time)