Spaces:
Build error
Build error
| import torch | |
| from tqdm.auto import tqdm | |
| import time | |
| import numpy as np | |
| from src.losses import calc_prob_uncertinty | |
| tic, toc = (time.time, time.time) | |
def train(probe, device, train_loader, optimizer, epoch, loss_func,
          class_names=None, report=False, verbose_interval=5, layer_num=40,
          head=None, verbose=True, return_raw_outputs=False, one_hot=False,
          uncertainty=False, **kwargs):
    """
    Run one optimisation pass of ``probe`` over every batch in ``train_loader``.

    Each batch must be a mapping with an ``"age"`` entry (integer class labels)
    and a ``"hidden_states"`` entry (probe inputs). The first element of the
    probe's output (``probe(act)[0]``) is fed to ``loss_func`` and arg-maxed
    along dimension 1 to obtain the predictions.

    :param probe: pytorch model (class:torch.nn.Module) being trained
    :param device: device the batch tensors are moved to (e.g. "cuda" or
        torch.device("cpu")); ``None`` falls back to "cuda"
    :param train_loader: sized iterable of batches (e.g. torch.utils.data.DataLoader)
    :param optimizer: optimizer for the probe
    :param epoch: current epoch of training (accepted for API compatibility; not used here)
    :param loss_func: callable ``loss_func(output, target, **kwargs)``
    :param class_names: accepted for API compatibility; not used here
    :param report: accepted for API compatibility; not used here
    :param verbose_interval: must be ``None`` or a positive number; otherwise not used here
    :param layer_num: index selected along dimension 1 of ``hidden_states``;
        ``None`` feeds ``hidden_states`` to the probe unchanged
    :param head: accepted for API compatibility; not used here
    :param verbose: whether to print the average loss, elapsed time and accuracy
    :param return_raw_outputs: whether to also return the per-sample predictions and labels
    :param one_hot: if True the labels are one-hot encoded
        (``**kwargs`` goes to ``torch.nn.functional.one_hot``) and ``loss_func``
        is called without ``**kwargs``; otherwise ``**kwargs`` goes to ``loss_func``
    :param uncertainty: if True ``calc_prob_uncertinty`` is invoked on each batch's outputs
    :return: ``(average loss, train accuracy)`` and, when ``return_raw_outputs``
        is set, additionally ``(predictions, true labels)`` as 1-D numpy arrays
    :raises ZeroDivisionError: if ``train_loader`` yields no samples
    """
    assert (verbose_interval is None) or verbose_interval > 0, "invalid verbose_interval, verbose_interval(int) > 0"
    if device is None:
        # Historical behaviour: tensors were always moved to CUDA.
        device = "cuda"
    # Monotonic clock: immune to system clock adjustments while timing the epoch.
    start_time = time.perf_counter()
    # Set the model to the train mode: Essential for proper gradient descent
    probe.train()
    loss_sum = 0
    correct = 0
    tot = 0
    preds = []
    truths = []
    # Iterate through the train dataset
    for batch in train_loader:
        target = batch["age"].long().to(device)
        if one_hot:
            target = torch.nn.functional.one_hot(target, **kwargs).float()
        optimizer.zero_grad()
        # layer 0 is a valid selection, hence the explicit comparison with 0
        if layer_num or layer_num == 0:
            act = batch["hidden_states"][:, layer_num].to(device)
        else:
            act = batch["hidden_states"].to(device)
        logits = probe(act)[0]
        if one_hot:
            # kwargs were already consumed by the one-hot encoding above
            loss = loss_func(logits, target)
        else:
            loss = loss_func(logits, target, **kwargs)
        loss.backward()
        optimizer.step()
        loss_sum += loss.sum().item()
        if uncertainty:
            # NOTE(review): the helper's result is not consumed; the argmax
            # below decides the predictions. The result must NOT be bound to
            # `uncertainty`, or the flag is clobbered for the next batch.
            _ = calc_prob_uncertinty(logits.detach().cpu().numpy())
        pred = torch.argmax(logits, dim=1)
        # One-hot targets are collapsed back to class indices for scoring
        if len(target.shape) > 1:
            target = torch.argmax(target, dim=1)
        pred_np = pred.detach().cpu().numpy()
        target_np = target.detach().cpu().numpy()
        correct += np.sum(pred_np == target_np)
        tot += pred_np.shape[0]
        if return_raw_outputs:
            preds.append(pred_np)
            truths.append(target_np)
    train_acc = correct / tot
    loss_avg = loss_sum / len(train_loader)
    elapsed = time.perf_counter() - start_time
    if verbose:
        print('\nTrain set: Average loss: {:.4f} ({:.3f} sec) Accuracy: {:.3f}\n'.format(
            loss_avg, elapsed, train_acc))
    if return_raw_outputs:
        # Only concatenate when something was collected: np.concatenate
        # raises ValueError on an empty list.
        return loss_avg, train_acc, np.concatenate(preds), np.concatenate(truths)
    return loss_avg, train_acc
def test(probe, device, test_loader, loss_func, return_raw_outputs=False, verbose=True,
         layer_num=40, scheduler=None, one_hot=False, uncertainty=False, **kwargs):
    """
    Evaluate ``probe`` on every batch in ``test_loader`` without updating its weights.

    Each batch must be a mapping with an ``"age"`` entry (integer class labels)
    and a ``"hidden_states"`` entry (probe inputs). The first element of the
    probe's output (``probe(act)[0]``) is fed to ``loss_func`` and arg-maxed
    along dimension 1 to obtain the predictions.

    :param probe: pytorch model (class:torch.nn.Module) being evaluated
    :param device: device the batch tensors are moved to (e.g. "cuda" or
        torch.device("cpu")); ``None`` falls back to "cuda"
    :param test_loader: sized iterable of batches (e.g. torch.utils.data.DataLoader)
    :param loss_func: callable ``loss_func(output, target, **kwargs)``
    :param return_raw_outputs: whether to also return the per-sample predictions
        (class indices after argmax) and labels
    :param verbose: whether to print the average loss and accuracy
    :param layer_num: index selected along dimension 1 of ``hidden_states``;
        ``None`` feeds ``hidden_states`` to the probe unchanged
    :param scheduler: optional object whose ``step(test_loss)`` is called with
        the average test loss
    :param one_hot: if True the labels are one-hot encoded
        (``**kwargs`` goes to ``torch.nn.functional.one_hot``) and ``loss_func``
        is called without ``**kwargs``; otherwise ``**kwargs`` goes to ``loss_func``
    :param uncertainty: if True ``calc_prob_uncertinty`` is invoked on each batch's outputs
    :return: ``(average test loss, test accuracy)`` and, when
        ``return_raw_outputs`` is set, additionally ``(predictions, true labels)``
        as 1-D numpy arrays
    :raises ZeroDivisionError: if ``test_loader`` yields no samples
    """
    if device is None:
        # Historical behaviour: tensors were always moved to CUDA.
        device = "cuda"
    # Set the model to evaluation mode: Essential for testing model
    probe.eval()
    test_loss = 0
    tot = 0
    correct = 0
    preds = []
    truths = []
    # Do not call gradient descent on the test set
    # We don't adjust the weights of model on the test set
    with torch.no_grad():
        for batch in test_loader:
            target = batch["age"].long().to(device)
            if one_hot:
                target = torch.nn.functional.one_hot(target, **kwargs).float()
            # layer 0 is a valid selection, hence the explicit comparison with 0
            if layer_num or layer_num == 0:
                act = batch["hidden_states"][:, layer_num].to(device)
            else:
                act = batch["hidden_states"].to(device)
            logits = probe(act)[0]
            if uncertainty:
                # NOTE(review): the helper's result is not consumed; the argmax
                # below decides the predictions. The result must NOT be bound to
                # `uncertainty`, or the flag is clobbered for the next batch.
                _ = calc_prob_uncertinty(logits.detach().cpu().numpy())
            pred = torch.argmax(logits, dim=1)
            if one_hot:
                # kwargs were already consumed by the one-hot encoding above
                loss = loss_func(logits, target)
            else:
                loss = loss_func(logits, target, **kwargs)
            test_loss += loss.sum().item()  # sum up batch loss
            # One-hot targets are collapsed back to class indices for scoring
            if len(target.shape) > 1:
                target = torch.argmax(target, dim=1)
            pred_np = pred.detach().cpu().numpy()
            target_np = target.detach().cpu().numpy()
            correct += np.sum(pred_np == target_np)
            tot += pred_np.shape[0]
            if return_raw_outputs:
                preds.append(pred_np)
                truths.append(target_np)
    test_loss /= len(test_loader)
    if scheduler:
        scheduler.step(test_loss)
    test_acc = correct / tot
    if verbose:
        print('Test set: Average loss: {:.4f}, Accuracy: {:.3f}\n'.format(
            test_loss,
            test_acc))
    if return_raw_outputs:
        # Only concatenate when something was collected: np.concatenate
        # raises ValueError on an empty list.
        return test_loss, test_acc, np.concatenate(preds), np.concatenate(truths)
    return test_loss, test_acc


# The name matches pytest's default discovery prefix; mark it so that importing
# this evaluation helper into a test module does not get it collected as a test.
test.__test__ = False
| import torch | |
| from tqdm.auto import tqdm | |
| import time | |
| import numpy as np | |
| from .losses import calc_prob_uncertinty | |
| tic, toc = (time.time, time.time) | |