import torch
from torch import nn
import numpy as np
import argparse


def sigmoid(x):
    # numpy sigmoid, used below to map raw network outputs to probabilities
    return 1 / (1 + np.exp(-x))


def parser():
    # TODO: create parser
    ap = argparse.ArgumentParser()
    return ap.parse_args()
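
# A fleshed-out parser for the TODO above might register flags like the
# following (flag names are hypothetical, not part of the original script):
#   ap.add_argument('--epochs', type=int, default=50)
#   ap.add_argument('--batch-size', type=int, default=32)
#   ap.add_argument('--logdir', default=None)
#   ap.add_argument('--checkpoint-dir', default=None)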


def train(net, trainloader, criterion, batch_size, target_names,
          validationloader=None, optimizer=None,
          scheduler=None, epochs=50, logdir=None, metrics=None,
          verbose=True, tuner=False, checkpoint_dir=None):
    '''Training loop for a simple supervised (multilabel) learning task.

    Args:
        net (torch.nn.Module): network to train
        trainloader (torch.utils.data.DataLoader): train data loader
        criterion (callable): loss function with which to optimize
            the provided network
        batch_size (int): batch size of trainloader and validationloader
        target_names (list of str): class names passed to sklearn's
            classification_report
        validationloader (torch.utils.data.DataLoader, optional):
            validation data loader
        optimizer (torch.optim.Optimizer, optional): optimizer,
            defaults to torch.optim.Adam with amsgrad enabled
        scheduler (torch.optim.lr_scheduler, optional):
            learning rate scheduler object
        epochs (int, optional): number of epochs to train the network,
            defaults to 50
        logdir (string, optional): path to the tensorboard log directory;
            if None, logging defaults to the ./runs/ directory
        metrics (list of tuples, optional): metrics to log, with the name
            and metric callable being the first and second element of each
            tuple respectively (currently unused; see commented code below)
        verbose (bool, optional): whether or not to print progress
            to the console
        tuner (bool, optional): whether to employ ray tune
        checkpoint_dir (string, optional): path to a checkpoint file from
            which to resume training
    '''
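    # Minimal usage sketch (MyNet, make_loader, and class_names are
    # hypothetical placeholders, and BCEWithLogitsLoss is just one loss
    # consistent with the sigmoid-thresholded reporting below):
    #   net = MyNet()
    #   train(net, make_loader('train'), nn.BCEWithLogitsLoss(), 32,
    #         class_names, validationloader=make_loader('val'), epochs=10)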
    from torch.utils.tensorboard import SummaryWriter
    from sklearn.metrics import classification_report

    writer = SummaryWriter(log_dir=logdir)
    if verbose:
        from tensorflow.keras.utils import Progbar
    if optimizer is None:
        optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, amsgrad=True)
    # optionally resume training from a saved checkpoint
    start_epoch = 0
    if checkpoint_dir is not None:
        # state, optim_state = torch.load(os.path.join(
        #     checkpoint_dir, "checkpoint"))
        state = torch.load(checkpoint_dir)
        start_epoch = state['epoch']
        net.load_state_dict(state['state_dict'])
        optimizer.load_state_dict(state['optimizer'])
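    # A checkpoint compatible with the loading code above could be produced
    # elsewhere with something like (a sketch; epoch_idx and checkpoint_path
    # are hypothetical):
    #   torch.save({'epoch': epoch_idx, 'state_dict': net.state_dict(),
    #               'optimizer': optimizer.state_dict()}, checkpoint_path)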
    assert epochs > 0, "epochs must be greater than 0!"
    steps = 0
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    net.to(device)  # inputs are moved to this device below, so the net must be too
    net.train(True)
    if tuner:
        from ray import tune
        import os
    num_classes = len(target_names)  # infer label width rather than hard-coding it
    for i in range(start_epoch, start_epoch + epochs):
        num_batches = len(trainloader)
        num_samples = num_batches * batch_size
        if verbose:
            print("\nepoch {}/{}".format(i + 1, start_epoch + epochs))
            pbar = Progbar(target=num_batches)
        # if (metrics is not None):
        #     train_metrics = [0 for metric in metrics]
        # per-epoch buffers holding every label and raw output for the report
        y_true = np.zeros((num_samples, num_classes))
        y_pred = np.zeros((num_samples, num_classes))
        idx = 0
        for j, data in enumerate(trainloader):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data[0].to(device), data[1].to(device)
            # inputs, labels = data[0].to(device), [data[1][0].to(device), data[1][1].to(device)]
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward + backward + optimize
            outputs = net(inputs)
            train_loss = criterion(outputs, labels)
            train_loss.backward()
            optimizer.step()
            y_true[idx:idx + outputs.shape[0], :] = labels.detach().cpu().numpy()
            y_pred[idx:idx + outputs.shape[0], :] = outputs.detach().cpu().numpy()
            idx += outputs.shape[0]
            if scheduler is not None:
                scheduler.step()
            if verbose:
                pbar.update(j, values=[("loss", train_loss.item())])
            steps += 1
            writer.add_scalar('Loss/train', train_loss.item(), steps)
        # if (metrics is not None):
        #     for (j, metric) in enumerate(metrics):
        #         train_metrics[j] += metric[1](outputs, labels)
        # multilabel report: threshold the sigmoid of the raw outputs at 0.5
        rep = classification_report(y_true.astype('int'),
                                    (sigmoid(y_pred) > 0.5).astype('int'),
                                    target_names=target_names,
                                    output_dict=True)
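        # with multilabel indicator targets, classification_report returns
        # per-class rows plus micro/macro/weighted/samples averages, so each
        # rep[k] below is itself a dict of floats that can be logged directly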
        for k in rep.keys():
            for metric_name in rep[k].keys():
                writer.add_scalar(metric_name + '/' + k + '/train',
                                  rep[k][metric_name], steps)
        # if (metrics is not None):
        #     for (j, metric) in enumerate(metrics):
        #         writer.add_scalar(metric[0] + '/train',
        #                           train_metrics[j] / num_batches, steps)
        if validationloader is not None:
            net.train(False)  # eval mode for validation
            val_loss = 0
            # if (metrics is not None):
            #     val_metrics = [0 for metric in metrics]
            num_val_batches = len(validationloader)
            num_val_samples = num_val_batches * batch_size
            y_val_true = np.zeros((num_val_samples, num_classes))
            y_val_pred = np.zeros((num_val_samples, num_classes))
            idx = 0
            with torch.no_grad():  # no gradients needed during validation
                for data in validationloader:
                    # get the inputs; data is a list of [inputs, labels]
                    inputs, labels = data[0].to(device), data[1].to(device)
                    # inputs, labels = data[0].to(device), [data[1][0].to(device), data[1][1].to(device)]
                    outputs = net(inputs)
                    val_loss += criterion(outputs, labels).item()
                    y_val_true[idx:idx + outputs.shape[0], :] = labels.cpu().numpy()
                    y_val_pred[idx:idx + outputs.shape[0], :] = outputs.cpu().numpy()
                    idx += outputs.shape[0]
                    # if (metrics is not None):
                    #     for (j, metric) in enumerate(metrics):
                    #         val_metrics[j] += metric[1](outputs, labels)
            val_loss /= num_val_batches  # assume all validation set used
            # scheduler.step(val_loss)
            rep = classification_report(y_val_true.astype('int'),
                                        (sigmoid(y_val_pred) > 0.5).astype('int'),
                                        target_names=target_names,
                                        output_dict=True)
            print(classification_report(y_val_true.astype('int'),
                                        (sigmoid(y_val_pred) > 0.5).astype('int'),
                                        target_names=target_names))
            for k in rep.keys():
                for metric_name in rep[k].keys():
                    writer.add_scalar(metric_name + '/' + k + '/valid',
                                      rep[k][metric_name], steps)
            writer.add_scalar('Loss/valid', val_loss, steps)
            # if (metrics is not None):
            #     for (j, metric) in enumerate(metrics):
            #         writer.add_scalar(metric[0] + '/valid',
            #                           val_metrics[j] / num_val_batches, steps)
            # if (tuner):
            #     with tune.checkpoint_dir(i+1) as checkpoint_dir:
            #         path = os.path.join(checkpoint_dir, "checkpoint")
            #         torch.save((net.state_dict(), optimizer.state_dict()), path)
            #     tune.report(loss=val_loss, accuracy=val_metrics[0] / num_val_samples, iters=i+1)
            if verbose:
                pbar.update(num_batches, values=[("val_loss", val_loss)])
            net.train(True)  # restore training mode for the next epoch
        else:
            if verbose:
                pbar.update(num_batches, values=None)


if __name__ == "__main__":
    args = parser()  # get arguments
    # TODO: implement args such that we can train from the command line
    # train(args.net, args.trainloader, args.criterion, args.batch_size,
    #       args.target_names, args.validationloader, args.optimizer,
    #       args.scheduler, args.epochs, args.logdir, args.metrics,
    #       args.verbose, args.tuner, args.checkpoint_dir)
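    # A concrete wiring for the TODO above might look like the following
    # (MultiLabelNet, load_dataset, and CLASS_NAMES are hypothetical
    # placeholders, not defined in this script):
    #   net = MultiLabelNet(num_classes=len(CLASS_NAMES))
    #   trainloader = torch.utils.data.DataLoader(load_dataset('train'),
    #                                             batch_size=args.batch_size)
    #   validationloader = torch.utils.data.DataLoader(load_dataset('val'),
    #                                                  batch_size=args.batch_size)
    #   train(net, trainloader, nn.BCEWithLogitsLoss(), args.batch_size,
    #         CLASS_NAMES, validationloader=validationloader,
    #         epochs=args.epochs, logdir=args.logdir,
    #         checkpoint_dir=args.checkpoint_dir)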