# pgps-demo/core/train.py
# Single-epoch training loop for the PGPS demo (distributed, CUDA).
import time

import torch
import torch.distributed

from utils import *
def train(args, epoch, train_loader, model, criterion, optimizer):
    """Run one distributed training epoch and return the average loss.

    Args:
        args: namespace; must provide `nprocs` (world size for loss
            reduction) and `print_freq` (logging interval in batches).
        epoch: current epoch index, used only in the progress prefix.
        train_loader: yields `(diagrams, text_dict, var_dict, exp_dict)`
            batches; the dicts are moved to GPU in place via `set_cuda`.
        model: network, invoked with `is_train=True`.
        criterion: loss taking `(output, target_exp, target_len)`.
        optimizer: optimizer stepped once per batch; its current lr is
            read from `state_dict()` for display.

    Returns:
        Average of the process-mean loss over the epoch (`losses.avg`).
    """
    batch_time = AverageMeter('Time', ':5.3f')
    data_time = AverageMeter('Data', ':5.3f')
    losses = AverageMeter('Loss', ':.4e')
    progress = ProgressMeter(len(train_loader), [batch_time, data_time, losses],
                             args, prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i, (diagrams, text_dict, var_dict, exp_dict) in enumerate(train_loader):
        # Batch layout (keys only — values per key not visible here):
        #   text_dict = {'token', 'sect_tag', 'class_tag', 'len'}
        #   var_dict  = {'pos', 'len', 'var_value', 'arg_value'}
        #   exp_dict  = {'exp', 'len', 'answer'}
        # measure data loading time
        data_time.update(time.time() - end)
        # move inputs to GPU; set_cuda mutates each dict in place
        diagrams = diagrams.cuda()
        set_cuda(text_dict)
        set_cuda(var_dict)
        set_cuda(exp_dict)
        # forward pass; target drops the leading special symbol [SOS],
        # so the target length shrinks by one as well
        output = model(diagrams, text_dict, var_dict, exp_dict, is_train=True)
        loss = criterion(output, exp_dict['exp'][:, 1:].clone(), exp_dict['len'] - 1)
        # synchronize, then average the loss across processes for logging
        torch.distributed.barrier()
        reduced_loss = reduce_mean(loss, args.nprocs)
        losses.update(reduced_loss.item(), len(diagrams))
        # compute gradient and do SGD step (backward on the local,
        # un-reduced loss)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            progress.display(i, lr=optimizer.state_dict()['param_groups'][0]['lr'])
    return losses.avg