File size: 1,699 Bytes
383bfb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import time
from utils import *

def validate(args, val_loader, model, tgt_lang):
    """Evaluate ``model`` on ``val_loader`` and return the averaged accuracies.

    Args:
        args: Run configuration. This function reads ``args.eval_method``
            (one of ``"completion"``, ``"choice"`` or ``"top3"``) and
            ``args.nprocs``; ``args`` is also handed to ``ProgressMeter``.
        val_loader: Sized iterable yielding
            ``(diagrams, text_dict, var_dict, exp_dict)`` batches.
        model: Network to evaluate; it is switched to eval mode and called
            with ``is_train=False``.
        tgt_lang: Passed through unchanged to the ``compute_exp_result_*``
            helpers.

    Returns:
        A tuple ``(acc_ans.avg, acc_eq.avg)``: the running averages held by
        the ``'Ans_Acc'`` and ``'Eq_Acc'`` meters after the last batch.

    Raises:
        ValueError: If ``args.eval_method`` is not a supported method.
    """
    # Fail fast on a bad configuration. Without this check an unknown method
    # would only surface after the first forward pass, as an
    # UnboundLocalError on ``acc1``.
    eval_method = args.eval_method
    supported_methods = ("completion", "choice", "top3")
    if eval_method not in supported_methods:
        raise ValueError(
            f"Unsupported eval_method {eval_method!r}; "
            f"expected one of {supported_methods}"
        )

    batch_time = AverageMeter('Time', ':5.3f')
    acc_ans = AverageMeter('Ans_Acc', ':5.4f')
    acc_eq = AverageMeter('Eq_Acc', ':5.4f')
    # NOTE(review): the progress meter is constructed but never displayed in
    # this function; kept in case its construction has side effects.
    progress = ProgressMeter(len(val_loader), [batch_time, acc_ans, acc_eq], args, prefix='Test: ')
    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for diagrams, text_dict, var_dict, exp_dict in val_loader:
            # set cuda for input data
            diagrams = diagrams.cuda()
            # Return values are discarded, so set_cuda presumably moves the
            # dict contents in place -- TODO confirm against utils.
            set_cuda(text_dict)
            set_cuda(var_dict)
            set_cuda(exp_dict)
            # compute output
            output = model(diagrams, text_dict, var_dict, exp_dict, is_train=False)
            if eval_method == "completion":
                acc1, acc2 = compute_exp_result_comp(output, var_dict, exp_dict, tgt_lang)
            elif eval_method == "choice":
                acc1, acc2 = compute_exp_result_choice(output, var_dict, exp_dict, tgt_lang)
            else:  # "top3" -- guaranteed by the validation above
                acc1, acc2 = compute_exp_result_topk(output, var_dict, exp_dict, tgt_lang, k_num=3)

            # Wait for every process to finish this batch before reducing.
            torch.distributed.barrier()

            reduced_acc_ans = reduce_mean(torch.tensor([acc1]).cuda(), args.nprocs)
            reduced_acc_eq = reduce_mean(torch.tensor([acc2]).cuda(), args.nprocs)

            # Second argument is the local batch size (presumably the weight
            # of this sample in the running average -- see AverageMeter).
            acc_ans.update(reduced_acc_ans.item(), len(diagrams))
            acc_eq.update(reduced_acc_eq.item(), len(diagrams))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

    return acc_ans.avg, acc_eq.avg