File size: 1,718 Bytes
383bfb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import time
from utils import *

def train(args, epoch, train_loader, model, criterion, optimizer):
    """Train `model` for one epoch over `train_loader`.

    Returns the epoch average of the loss after it has been mean-reduced
    across all `args.nprocs` distributed processes.
    """
    batch_time = AverageMeter('Time', ':5.3f')
    data_time = AverageMeter('Data', ':5.3f')
    losses = AverageMeter('Loss', ':.4e')
    progress = ProgressMeter(len(train_loader), [batch_time, data_time, losses],
                             args, prefix="Epoch: [{}]".format(epoch))

    # Enable training-mode behavior (dropout, batch-norm statistics).
    model.train()

    tic = time.time()
    for step, (diagrams, text_dict, var_dict, exp_dict) in enumerate(train_loader):
        # Expected dict contents (per the loader's contract):
        #   text_dict: {'token', 'sect_tag', 'class_tag', 'len'}
        #   var_dict:  {'pos', 'len', 'var_value', 'arg_value'}
        #   exp_dict:  {'exp', 'len', 'answer'}

        # Time spent waiting on the data loader.
        data_time.update(time.time() - tic)

        # Move all inputs onto the GPU.
        diagrams = diagrams.cuda()
        for batch_dict in (text_dict, var_dict, exp_dict):
            set_cuda(batch_dict)

        # Forward pass; the target drops the leading [SOS] special symbol,
        # so the corresponding lengths shrink by one as well.
        output = model(diagrams, text_dict, var_dict, exp_dict, is_train=True)
        loss = criterion(output, exp_dict['exp'][:, 1:].clone(), exp_dict['len'] - 1)

        # Synchronize ranks, then average the loss across processes for logging.
        torch.distributed.barrier()
        reduced_loss = reduce_mean(loss, args.nprocs)
        losses.update(reduced_loss.item(), len(diagrams))

        # Backward pass and SGD parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Wall-clock time for the whole iteration.
        batch_time.update(time.time() - tic)
        tic = time.time()

        if step % args.print_freq == 0:
            progress.display(step, lr=optimizer.state_dict()['param_groups'][0]['lr'])

    return losses.avg