# talktuner-probe-training / src/train_test_utils.py
# Deploy TalkTuner probe training interface (jmjoseph, commit 5413412, verified)
import torch
from tqdm.auto import tqdm
import time
import numpy as np
from src.losses import calc_prob_uncertinty
tic, toc = (time.time, time.time)
def train(probe, device, train_loader, optimizer, epoch, loss_func,
          class_names=None, report=False, verbose_interval=5, layer_num=40,
          head=None, verbose=True, return_raw_outputs=False, one_hot=False, uncertainty=False, **kwargs,):
    """
    Train ``probe`` for one epoch on activations from ``train_loader``.

    :param probe: pytorch model (torch.nn.Module); its forward must return a
        sequence whose first element is the logits tensor
    :param device: device used to train the model (e.g. torch.device("cuda"))
    :param train_loader: torch.utils.data.DataLoader of train dataset; each
        batch is a dict with keys "age" (integer labels) and "hidden_states"
    :param optimizer: optimizer for the model
    :param epoch: current epoch of training (kept for interface parity; unused)
    :param loss_func: loss function for the training
    :param class_names: str names for the classes (kept for interface parity)
    :param report: whether to print a classification report (kept; unused here)
    :param verbose_interval: must be None or a positive int (validated only)
    :param layer_num: layer index to slice out of batch["hidden_states"];
        pass None to feed the full activation tensor to the probe
    :param head: kept for interface parity; unused
    :param verbose: print a one-line summary at the end of the epoch
    :param return_raw_outputs: also return per-sample predictions and labels
    :param one_hot: one-hot encode targets (kwargs forwarded to one_hot)
        before calling loss_func; otherwise kwargs are forwarded to loss_func
    :param uncertainty: additionally run calc_prob_uncertinty on the logits
        (result currently discarded — see NOTE below)
    :return: (avg loss, accuracy) or, with return_raw_outputs,
        (avg loss, accuracy, predictions ndarray, labels ndarray)
    """
    assert (verbose_interval is None) or verbose_interval > 0, "invalid verbose_interval, verbose_interval(int) > 0"
    starttime = time.time()
    # Set the model to the train mode: Essential for proper gradient descent
    probe.train()
    loss_sum = 0
    correct = 0
    tot = 0
    preds = []
    truths = []
    # Iterate through the train dataset
    for batch_idx, batch in enumerate(train_loader):
        # FIX: honor the `device` argument instead of hard-coding .cuda()
        target = batch["age"].long().to(device)
        if one_hot:
            target = torch.nn.functional.one_hot(target, **kwargs).float()
        optimizer.zero_grad()
        # FIX: `layer_num or layer_num == 0` was a roundabout None-check
        if layer_num is not None:
            act = batch["hidden_states"][:, layer_num,].to(device)
        else:
            act = batch["hidden_states"].to(device)
        output = probe(act)
        if not one_hot:
            loss = loss_func(output[0], target, **kwargs)
        else:
            loss = loss_func(output[0], target)
        loss.backward()
        optimizer.step()
        loss_sum += loss.sum().item()
        if uncertainty:
            # NOTE(review): the original rebound `uncertainty` to the returned
            # array — making the `if uncertainty:` truth-test ambiguous on the
            # next batch — and its prediction was overwritten by the argmax
            # below anyway. The result is computed and discarded, as before.
            _pred_unc = calc_prob_uncertinty(output[0].detach().cpu().numpy())
        pred = torch.argmax(output[0], axis=1)
        # Equivalent to scikit-learn's OvR decision function, minus the linear
        # normalization of probabilities (argmax is invariant to it).
        if len(target.shape) > 1:
            target = torch.argmax(target, axis=1)
        correct += np.sum(pred.detach().cpu().numpy() == target.detach().cpu().numpy())
        if return_raw_outputs:
            preds.append(pred.detach().cpu().numpy())
            truths.append(target.detach().cpu().numpy())
        tot += pred.shape[0]
    train_acc = correct / tot
    loss_avg = loss_sum / len(train_loader)
    endtime = time.time()
    if verbose:
        print('\nTrain set: Average loss: {:.4f} ({:.3f} sec) Accuracy: {:.3f}\n'.\
              format(loss_avg,
                     endtime-starttime,
                     train_acc))
    if return_raw_outputs:
        # FIX: only concatenate when outputs were actually collected — the
        # original called np.concatenate on empty lists and raised ValueError
        # whenever return_raw_outputs was False.
        return loss_avg, train_acc, np.concatenate(preds), np.concatenate(truths)
    else:
        return loss_avg, train_acc
def test(probe, device, test_loader, loss_func, return_raw_outputs=False, verbose=True,
layer_num=40, scheduler=None, one_hot=False, uncertainty=False, **kwargs):
"""
:param model: pytorch model (class:torch.nn.Module)
:param device: device used to train the model (e.g. torch.device("cuda") for training on GPU)
:param test_loader: torch.utils.data.DataLoader of test dataset
:param loss_func: loss function for the training
:param class_names: str Name for the classification classses. used in train report
:param test_report: whether to print a classification report of testing after each epoch
:param return_raw_outputs: whether return the raw outputs of model (before argmax). used for auc computation
:return: average test loss, test accuracy, true labels, predictions, (and raw outputs \
from model if return_raw_outputs)
"""
# Set the model to evaluation mode: Essential for testing model
probe.eval()
test_loss = 0
tot = 0
correct = 0
preds = []
truths = []
# Do not call gradient descent on the test set
# We don't adjust the weights of model on the test set
with torch.no_grad():
for batch_idx, batch in enumerate(test_loader):
batch_size = 1
target = batch["age"].long().cuda()
if one_hot:
target = torch.nn.functional.one_hot(target, **kwargs).float()
if layer_num or layer_num == 0:
act = batch["hidden_states"][:, layer_num,].to("cuda")
else:
act = batch["hidden_states"].to("cuda")
output = probe(act)
if uncertainty:
pred, uncertainty = calc_prob_uncertinty(output[0].detach().cpu().numpy())
pred = torch.argmax(output[0], axis=1)
if not one_hot:
loss = loss_func(output[0], target, **kwargs)
else:
loss = loss_func(output[0], target)
test_loss += loss.sum().item() # sum up batch loss
# In the Scikit-Learn's implementation of OvR Multi-class Logistic Regression. They linearly normalized the predicted probability and then call argmax
# Below is an equivalent implementation of the scikit-learn's decision function. The only difference is we didn't do the linearly normalization
# To save some computation time
if len(target.shape) > 1:
target = torch.argmax(target, axis=1)
pred = np.array(pred.detach().cpu().numpy())
target = np.array(target.detach().cpu().numpy())
correct += np.sum(pred == target)
tot += pred.shape[0]
if return_raw_outputs:
preds.append(pred)
truths.append(target)
test_loss /= len(test_loader)
if scheduler:
scheduler.step(test_loss)
test_acc = correct / tot
if verbose:
print('Test set: Average loss: {:.4f}, Accuracy: {:.3f}\n'.format(
test_loss,
test_acc))
preds = np.concatenate(preds)
truths = np.concatenate(truths)
# If return the raw outputs (before argmax) from the model
if return_raw_outputs:
return test_loss, test_acc, preds, truths
else:
return test_loss, test_acc
import torch
from tqdm.auto import tqdm
import time
import numpy as np
from .losses import calc_prob_uncertinty
tic, toc = (time.time, time.time)