Spaces:
Sleeping
Sleeping
| import torch | |
| from torch import nn | |
| import numpy as np | |
| import argparse | |
def sigmoid(x):
    """Numerically stable element-wise logistic sigmoid.

    The naive form ``1 / (1 + np.exp(-x))`` overflows (RuntimeWarning,
    transient ``inf``) when ``x`` is a large negative number. Here ``exp``
    is only ever evaluated on a non-positive argument, so it cannot
    overflow; the two branches are algebraically identical to the naive
    form.

    Args:
        x (float or np.ndarray): input logit(s).

    Returns:
        np.ndarray: 1 / (1 + exp(-x)), element-wise (0-d array for
        scalar input).
    """
    x = np.asarray(x)
    z = np.exp(-np.abs(x))  # always <= 1, never overflows
    return np.where(x >= 0, 1.0 / (1.0 + z), z / (1.0 + z))
def parser():
    """Build the command-line parser and return the parsed arguments.

    Returns:
        argparse.Namespace: parsed command-line arguments (currently
        empty — no options are defined yet).
    """
    # TODO: create parser
    arg_parser = argparse.ArgumentParser()
    return arg_parser.parse_args()
def train(net, trainloader, criterion, batch_size, target_names,
          validationloader=None, optimizer=None,
          scheduler=None, epochs=50, logdir=None, metrics=None,
          verbose=True, tuner=False, checkpoint_dir=None):
    '''Training loop function for a simple supervised (multi-label)
    learning task.

    Runs ``epochs`` passes over ``trainloader``; after each epoch logs the
    training loss and a per-class sklearn classification report to
    TensorBoard, and, if given, evaluates on ``validationloader``.

    Args:
        net (torch.nn.Module): network to train
        trainloader (torch.utils.data.DataLoader):
            train data loader
        criterion (torch.nn.object): criterion with which
            to optimize the provided network
        batch_size (int): batch of trainloader and validationloader
            (used only to pre-size the label/prediction buffers)
        target_names (list of str): class names forwarded to
            sklearn.metrics.classification_report
        validationloader (torch.utils.data.DataLoader, optional):
            validation data loader
        optimizer (torch.optim.Optimizer, optional):
            optimizer function, defaults to torch.nn.optim.Adam w/ amsgrad
        scheduler (torch.optim.lr_scheduler, optional):
            learning rate scheduler object; NOTE(review): stepped once per
            *batch* below, not per epoch — confirm that is intended
        epochs (int, optional): number of epochs to train network,
            defaults to 50
        logdir (string, optional): path to tensorboard log directory,
            if None logging default to ./runs/ directory
        metrics (list of tuples, optional): metrics to be logged with
            name and metric being the first and second element of the
            each tuple respectively; NOTE(review): currently unused —
            all metric-logging code below is commented out
        verbose (bool, optional): whether or not to print information
            to console (uses a Keras Progbar per epoch)
        tuner (bool, optional): whether to employ ray tune;
            NOTE(review): only the imports run — tune reporting is
            commented out below
        checkpoint_dir (string, optional): path to a torch checkpoint dict
            with 'epoch', 'state_dict' and 'optimizer' keys to resume from
    '''
    # Heavy optional deps are imported lazily so importing this module
    # stays cheap.
    from torch.utils.tensorboard import SummaryWriter
    from sklearn.metrics import classification_report
    writer = SummaryWriter(log_dir=logdir)
    if (verbose):
        # Progbar is only needed (and only defined) when verbose is True;
        # every later use of pbar is guarded by the same flag.
        from tensorflow.keras.utils import Progbar
    if (optimizer is None):
        optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, amsgrad=True)
    start_epoch = 0
    if (checkpoint_dir is not None):
        # Resume training: restore epoch counter, model weights and
        # optimizer state from a single checkpoint file.
        # state, optim_state = torch.load(os.path.join(
        #     checkpoint_dir, "checkpoint"))
        state = torch.load(checkpoint_dir)
        start_epoch = state['epoch']
        net.load_state_dict(state['state_dict'])
        optimizer.load_state_dict(state['optimizer'])
    # NOTE(review): assert is stripped under `python -O`; consider raising
    # ValueError for real input validation.
    assert epochs > 0, "Assertion failed. epochs must be greater than 0!"
    steps = 0  # global batch counter used as the TensorBoard x-axis
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')  # get device
    net.train(True)
    # net.to(device)
    if (tuner):
        from ray import tune
        import os
    for i in range(start_epoch, start_epoch + epochs):
        num_batches = len(trainloader)
        # Buffer size assumes every batch is full; a short final batch
        # leaves trailing zero rows in y_true/y_pred.
        num_samples = num_batches * batch_size
        if (verbose):
            print("\nepoch {}/{}".format(i+1, start_epoch+epochs))
            pbar = Progbar(target=num_batches)
        # if (metrics is not None):
        #     train_metrics = [0 for metric in metrics]
        # NOTE(review): label dimension is hard-coded to 12 classes —
        # presumably should equal len(target_names); confirm.
        y_true = np.zeros((num_samples, 12))
        y_pred = np.zeros((num_samples, 12))
        idx = 0
        for j, data in enumerate(iter(trainloader)):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data[0].to(device), data[1].to(device)
            # inputs, labels = data[0].to(device), [data[1][0].to(device), data[1][1].to(device)]
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward + backward + optimize
            outputs = net(inputs)
            train_loss = criterion(outputs, labels)
            train_loss.backward()
            optimizer.step()
            # Accumulate raw logits and labels for the epoch-level report.
            y_true[idx:idx+outputs.shape[0], :] = labels.detach().cpu().numpy()
            y_pred[idx:idx+outputs.shape[0], :] = outputs.detach().cpu().numpy()
            idx += outputs.shape[0]
            if (scheduler is not None):
                # NOTE(review): stepped per batch, not per epoch.
                scheduler.step()
            if (verbose):
                pbar.update(j, values=[("loss",
                            train_loss.detach().cpu().numpy().item())])
            steps += 1
            writer.add_scalar('Loss/train',
                              train_loss.detach().cpu().numpy().item(), steps)
            # if (metrics is not None):
            #     for (j, metric) in enumerate(metrics):
            #         # train_metrics[j] += metric[1](outputs, labels).detach().cpu().numpy()
            #         train_metrics[j] += metric[1](outputs, labels)
        # Epoch-level multi-label report: logits -> sigmoid -> 0.5 threshold.
        rep = classification_report(y_true.astype('int'),
                                    (sigmoid(y_pred) > 0.5).astype('int'), target_names=target_names,
                                    output_dict=True)
        # Log every metric of every class/average as its own scalar, e.g.
        # "precision/<class>/train".
        for k in rep.keys():
            for j in rep[k].keys():
                writer.add_scalar(j + '/' + k + '/train',
                                  rep[k][j], steps)
        # if (metrics is not None):
        #     for (j, metric) in enumerate(metrics):
        #         # writer.add_scalar(metric[0] + '/train',
        #         #                   train_metrics[j] / num_samples, steps)
        #         writer.add_scalar(metric[0] + '/train',
        #                           train_metrics[j] / num_batches, steps)
        if (validationloader is not None):
            # Switch to eval mode for validation; restored to train mode
            # at the end of this branch.
            net.train(False)
            val_loss = 0
            # if (metrics is not None):
            #     val_metrics = [0 for metric in metrics]
            num_val_batches = len(validationloader)
            num_val_samples = num_val_batches * batch_size
            y_val_true = np.zeros((num_val_samples, 12))
            y_val_pred = np.zeros((num_val_samples, 12))
            idx = 0
            # NOTE(review): no torch.no_grad() here — gradients are
            # computed (but unused) during validation; confirm intended.
            for data in iter(validationloader):
                # get the inputs; data is a list of [inputs, labels]
                inputs, labels = data[0].to(device), data[1].to(device)
                # inputs, labels = data[0].to(device), [data[1][0].to(device), data[1][1].to(device)]
                outputs = net(inputs)
                val_loss += criterion(outputs, labels).detach().cpu().numpy()
                y_val_true[idx:idx+outputs.shape[0], :] = labels.detach().cpu().numpy()
                y_val_pred[idx:idx+outputs.shape[0], :] = outputs.detach().cpu().numpy()
                idx += outputs.shape[0]
                # if (metrics is not None):
                #     for (j, metric) in enumerate(metrics):
                #         # val_metrics[j] += metric[1](outputs, labels).detach().cpu().numpy()
                #         val_metrics[j] += metric[1](outputs, labels)
            val_loss /= (num_val_batches)  # assume all validation set used
            # scheduler.step(val_loss)
            rep = classification_report(y_val_true.astype('int'),
                                        (sigmoid(y_val_pred) > 0.5).astype('int'), target_names=target_names,
                                        output_dict=True)
            # Also print the human-readable report to the console.
            print(classification_report(y_val_true.astype('int'),
                                        (sigmoid(y_val_pred) > 0.5).astype('int'), target_names=target_names))
            # output_dict=False)
            # print(rep2)
            for k in rep.keys():
                for j in rep[k].keys():
                    writer.add_scalar(j + '/' + k + '/valid',
                                      rep[k][j], steps)
            writer.add_scalar('Loss/valid', val_loss, steps)
            # if (metrics is not None):
            #     for (j, metric) in enumerate(metrics):
            #         # writer.add_scalar(metric[0] + '/valid',
            #         #                   val_metrics[j] / num_val_samples, steps)
            #         writer.add_scalar(metric[0] + '/valid',
            #                           val_metrics[j] / num_val_batches, steps)
            # if (tuner):
            #     with tune.checkpoint_dir(i+1) as checkpoint_dir:
            #         path = os.path.join(checkpoint_dir, "checkpoint")
            #         torch.save((net.state_dict(), optimizer.state_dict()), path)
            #     tune.report(loss=val_loss, accuracy=val_metrics[0] / num_val_samples, iters=i+1)
            if (verbose):
                # val_loss is a numpy scalar after the division above,
                # so .item() is valid here.
                pbar.update(num_batches, values=[("val_loss",val_loss.item())])
            net.train(True)
        else:
            if (verbose):
                # Finalize the progress bar for this epoch.
                pbar.update(num_batches, values=None)
if __name__ == "__main__":
    # Script entry point: currently only parses the (empty) command line;
    # the train() call stays commented out until the parser defines the
    # required options.
    args = parser()  # get arguments
    # TODO: implement args such that we can train from the command line
    #train(args.net, args.trainloader, args.criterion, args.batch_size,
    #      args.validationloader, args.optimizer,
    #      args.scheduler, args.epochs, args.logdir, args.metrics,
    #      args.verbose, args.tuner, args.checkpoint_dir):