pgps-demo / core /test.py
asdfasdfdsafdsa's picture
Initial upload of PGPS demo with all dependencies
383bfb8 verified
import time
from utils import *
def validate(args, val_loader, model, tgt_lang):
    """Evaluate ``model`` on ``val_loader`` and return the mean accuracies.

    Each batch is moved to the GPU, run through the model with
    ``is_train=False`` and scored with the routine selected by
    ``args.eval_method``. The two per-batch scores are passed through
    ``reduce_mean`` and accumulated in running meters weighted by
    ``len(diagrams)``.

    Args:
        args: Run configuration. This function reads ``args.eval_method``
            (``"completion"``, ``"choice"`` or ``"top3"``) and
            ``args.nprocs``; the object is also handed to ``ProgressMeter``.
        val_loader: Sized iterable yielding
            ``(diagrams, text_dict, var_dict, exp_dict)`` batches.
        model: Network to evaluate; called as
            ``model(diagrams, text_dict, var_dict, exp_dict, is_train=False)``.
        tgt_lang: Forwarded unchanged to the scoring routine.

    Returns:
        Tuple ``(acc_ans.avg, acc_eq.avg)``: the running averages of the
        'Ans_Acc' and 'Eq_Acc' meters.

    Raises:
        ValueError: If ``args.eval_method`` is not a supported method name.
    """
    # Fail fast with a clear message. Without this check an unknown method
    # only surfaced as an UnboundLocalError on acc1/acc2 after a full
    # forward pass on the first batch.
    supported_methods = ("completion", "choice", "top3")
    eval_method = args.eval_method
    if eval_method not in supported_methods:
        raise ValueError(
            f"Unsupported eval_method {eval_method!r}; "
            f"expected one of {supported_methods}"
        )

    batch_time = AverageMeter('Time', ':5.3f')
    acc_ans = AverageMeter('Ans_Acc', ':5.4f')
    acc_eq = AverageMeter('Eq_Acc', ':5.4f')
    # NOTE(review): `progress` is never displayed in this function; it is kept
    # in case the ProgressMeter constructor has side effects -- confirm.
    progress = ProgressMeter(len(val_loader), [batch_time, acc_ans, acc_eq], args, prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for diagrams, text_dict, var_dict, exp_dict in val_loader:
            # set cuda for input data
            diagrams = diagrams.cuda()
            # Return values are discarded, so set_cuda presumably moves the
            # dict contents to the GPU in place -- confirm against utils.
            set_cuda(text_dict)
            set_cuda(var_dict)
            set_cuda(exp_dict)

            # compute output
            output = model(diagrams, text_dict, var_dict, exp_dict, is_train=False)

            if eval_method == "completion":
                acc1, acc2 = compute_exp_result_comp(output, var_dict, exp_dict, tgt_lang)
            elif eval_method == "choice":
                acc1, acc2 = compute_exp_result_choice(output, var_dict, exp_dict, tgt_lang)
            else:  # "top3" -- the only remaining value after validation above
                acc1, acc2 = compute_exp_result_topk(output, var_dict, exp_dict, tgt_lang, k_num=3)

            # Wait for every process before combining the per-process scores.
            # NOTE(review): this needs an initialized torch.distributed
            # process group, even for single-process runs -- confirm callers.
            torch.distributed.barrier()
            # reduce_mean is defined elsewhere; presumably it averages the
            # value over args.nprocs processes -- confirm.
            reduced_acc_ans = reduce_mean(torch.tensor([acc1]).cuda(), args.nprocs)
            reduced_acc_eq = reduce_mean(torch.tensor([acc2]).cuda(), args.nprocs)

            # Weight each batch by its size so the final average is per-sample.
            batch_size = len(diagrams)
            acc_ans.update(reduced_acc_ans.item(), batch_size)
            acc_eq.update(reduced_acc_eq.item(), batch_size)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

    return acc_ans.avg, acc_eq.avg