# pgps-demo/utils/utils.py
import datetime
import gc
import os
import random

import torch
import torch.distributed as dist
from func_timeout import func_timeout

from datasets.operators import result_compute, normalize_exp
from utils.lr_scheduler import WarmupMultiStepLR
def save_checkpoint(state, is_best, dump_path=None):
    """Save `state` under dump_path as best_model.pth (if is_best) or <epoch>.pth."""
if is_best:
dump_path_best = os.path.join(dump_path, 'best_model.pth')
torch.save(state, dump_path_best)
else:
dump_path_recent = os.path.join(dump_path, str(state['epoch'])+'.pth')
torch.save(state, dump_path_recent)
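
# Illustrative usage sketch for save_checkpoint; only 'epoch' is required by the
# non-best branch, the other state fields below are assumptions for illustration:
#
#     state = {'epoch': epoch, 'state_dict': model.state_dict(),
#              'optimizer': optimizer.state_dict()}
#     save_checkpoint(state, is_best=False, dump_path=args.dump_path)  # -> <epoch>.pth
#     save_checkpoint(state, is_best=True, dump_path=args.dump_path)   # -> best_model.pth
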
class AverageMeter(object):
"""
Computes and stores the average and current value
"""
def __init__(self, name, fmt=':f'):
self.name = name
self.fmt = fmt
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def __str__(self):
fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
return fmtstr.format(**self.__dict__)
class ProgressMeter(object):
def __init__(self, num_batches, meters, args, prefix=""):
self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
self.meters = meters
self.prefix = prefix
self.args = args
def display(self, batch, lr=None):
entries = [self.prefix + self.batch_fmtstr.format(batch)]
entries += [str(meter) for meter in self.meters]
        if lr is not None:
            entries += ["lr: " + format(lr, '.6f')]
self.args.logger.info('\t'.join(entries))
def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches))
fmt = '{:' + str(num_digits) + 'd}'
return '[' + fmt + '/' + fmt.format(num_batches) + ']'
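
# Minimal usage sketch for AverageMeter/ProgressMeter inside a training loop
# (assumes `args.logger` is a configured logging.Logger, as display() requires;
# `train_step` is a hypothetical helper returning the batch loss):
#
#     losses = AverageMeter('Loss', ':.4f')
#     progress = ProgressMeter(len(train_loader), [losses], args, prefix='Epoch[1]')
#     for i, batch in enumerate(train_loader):
#         loss = train_step(batch)
#         losses.update(loss.item(), batch_size)
#         progress.display(i, lr=optimizer.param_groups[0]['lr'])
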
def adjust_learning_rate(optimizer, epoch, args):
"""
Sets the learning rate to the initial LR decayed by 10 every 30 epochs
"""
lr = args.lr * (0.1**(epoch // 30))
for param_group in optimizer.param_groups:
param_group['lr'] = lr
def accuracy(output, target, topk=(1, )):
"""
Computes the accuracy over the k top predictions for the specified values of k
"""
with torch.no_grad():
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
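
# Example of accuracy() on a dummy batch (shapes only; the values are arbitrary):
#
#     output = torch.randn(8, 10)          # B x num_classes logits
#     target = torch.randint(0, 10, (8,))  # B ground-truth class indices
#     top1, top5 = accuracy(output, target, topk=(1, 5))
#     # top1 / top5 are 1-element tensors holding percentages in [0, 100]
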
def get_scheduler(args, optimizer):
if args.scheduler_type == "multistep":
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
args.scheduler_step,
gamma=args.scheduler_factor,
)
elif args.scheduler_type == "cosine":
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=args.max_epochs, eta_min=1e-6)
elif args.scheduler_type == "warmup":
scheduler = WarmupMultiStepLR(
optimizer,
args.scheduler_step,
gamma=args.scheduler_factor,
warmup_epochs=args.warm_epoch,
)
else:
raise NotImplementedError("Unsupported LR Scheduler: {}".format(args.scheduler_type))
return scheduler
def get_optimizer(args, model):
    if args.use_MLM_pretrain:
        # Split parameters so the pre-trained language-model module can use its own learning rate.
        pretrain_param_ids = set(map(id, model.mlm_pretrain.parameters()))
        other_params = [p for p in model.parameters() if id(p) not in pretrain_param_ids]
if args.optimizer_type == "SGD":
optimizer = torch.optim.SGD(
model.parameters(),
lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay,
nesterov=True,
)
elif args.optimizer_type == "ADAM":
if args.use_MLM_pretrain:
optimizer = torch.optim.Adam(
[{"params":model.mlm_pretrain.parameters()},
{"params":other_params, "lr":args.lr_LM}],
lr=args.lr,
betas=(0.9, 0.999),
weight_decay=args.weight_decay,
)
else:
optimizer = torch.optim.Adam(
model.parameters(),
lr=args.lr,
betas=(0.9, 0.999),
weight_decay=args.weight_decay,
)
elif args.optimizer_type == "ADAMW":
if args.use_MLM_pretrain:
optimizer = torch.optim.AdamW(
[{"params":model.mlm_pretrain.parameters(), "lr":args.lr_LM},
{"params":other_params}],
lr=args.lr,
weight_decay=args.weight_decay,
)
else:
optimizer = torch.optim.AdamW(
model.parameters(),
lr=args.lr,
weight_decay=args.weight_decay,
)
else:
raise NotImplementedError("Unsupported Optimizer Type : {}".format(args.optimizer_type))
return optimizer
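
# Sketch of the args fields read by get_optimizer/get_scheduler; the concrete
# values below are illustrative assumptions, not project defaults:
#
#     from argparse import Namespace
#     args = Namespace(optimizer_type='ADAMW', lr=1e-4, lr_LM=1e-5,
#                      weight_decay=1e-2, momentum=0.9, use_MLM_pretrain=False,
#                      scheduler_type='warmup', scheduler_step=[40, 80],
#                      scheduler_factor=0.1, warm_epoch=5, max_epochs=100)
#     optimizer = get_optimizer(args, model)
#     scheduler = get_scheduler(args, optimizer)
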
def reduce_mean(tensor, nprocs):
rt = tensor.clone()
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
rt /= nprocs
return rt
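
# reduce_mean is meant for torch.distributed training: every rank passes its local
# value and receives the mean over all ranks. Illustrative call, assuming the
# process group has already been initialized:
#
#     loss_avg = reduce_mean(loss.detach(), dist.get_world_size())
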
def set_cuda(data_dict):
    """Move every tensor value in data_dict onto the current CUDA device, in place."""
for key in data_dict:
if torch.is_tensor(data_dict[key]):
data_dict[key] = data_dict[key].cuda()
def initialize_logger(params):
    """
    Initialize the experiment:
    - create a time-stamped dump directory
    - dump parameters
    - create a logger
    """
while True:
exp_id = datetime.datetime.strftime(datetime.datetime.now(),'%Y-%m-%d-%H-%M-%S')
if not os.path.exists(os.path.join(params.dump_path, exp_id)):
break
params.dump_path = os.path.join(params.dump_path, exp_id)
if params.local_rank == 0:
os.makedirs(params.dump_path)
# create a logger
logger = create_logger(os.path.join(params.dump_path,'record.log'), params.local_rank)
logger.info("============ Initialized logger ============")
logger.info("\n"+"\n".join("\t\t\t\t%s: %s" % (k, str(v))
for k, v in sorted(dict(vars(params)).items())))
logger.info("The experiment results will be stored in %s" % params.dump_path)
return logger
def aeq(*args):
"""
Assert all arguments have the same value
"""
    arguments = iter(args)
first = next(arguments)
assert all(arg == first for arg in arguments), \
"Not all arguments have the same value: " + str(args)
def sequence_mask(lengths, max_len=None):
"""
Creates a boolean mask from sequence lengths.
"""
batch_size = lengths.numel()
max_len = max_len or lengths.max()
return torch.arange(0, max_len, device=lengths.device) \
.type_as(lengths) \
.repeat(batch_size, 1) \
.lt(lengths.unsqueeze(1))
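
# Example: sequence_mask(torch.tensor([2, 4]), max_len=5) yields
#     [[True, True, False, False, False],
#      [True, True, True,  True,  False]]
# (.lt() returns a BoolTensor, so the mask is boolean even for integer lengths).
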
def copy_list(l):
    """Recursively copy a (possibly nested) list."""
    return [copy_list(i) if isinstance(i, list) else i for i in l]
def compute_exp_result_choice(test_preds, var_dict, exp_dict, tgt_lang):
"""
Arguments
test_preds: B x candi_size(beam_size) x token_list
var_dict: {'pos', 'len', 'var_value', 'arg_value'}
exp_dict: {'exp', 'len', 'answer'}
tgt_lang: vocab of target text
Returns:
ans_acc
eq_acc
"""
gc.collect()
ans_num = eq_num = 0
for k in range(len(test_preds)): # batch id
tgt = exp_dict['exp'][k][1:exp_dict['len'][k]-1].tolist() # Remove special symbols [SOS] and [EOS]
var2arg_dict = {'N'+str(i+len(var_dict['var_value'][k])):item \
for i, item in enumerate(var_dict['arg_value'][k])}
tgt = tgt_lang.sentence_from_indexes(tgt, var2arg_dict)
num_list = var_dict['var_value'][k]
tgt_result = float(exp_dict['answer'][k])
choices = exp_dict['choices'][k]
is_find_ans = False
for j in range(len(test_preds[k])): # pred candi id
try:
pred = tgt_lang.sentence_from_indexes(test_preds[k][j], var2arg_dict)
pred = normalize_exp(pred)
pred_result = float(func_timeout(2.0, result_compute, \
kwargs=dict(num_all_list=num_list, exp_tokens=pred)))
if pred == tgt:
ans_num += 1
eq_num += 1
is_find_ans = True
break
for item in choices:
if abs(pred_result-item)<5e-2:
is_find_ans = True
if is_find_ans and abs(pred_result-tgt_result)<5e-3:
ans_num +=1
if len(pred)==len(tgt):
eq_num += 1
if is_find_ans: break
            except:
                # Skip candidates that fail to decode, evaluate, or time out.
                pass
if not is_find_ans:
pred_result = random.choice(choices)
if abs(pred_result-tgt_result)<5e-2:
ans_num +=1
return ans_num/len(test_preds), eq_num/len(test_preds)
def compute_exp_result_topk(test_preds, var_dict, exp_dict, tgt_lang, k_num=3):
"""
Arguments
test_preds: B x candi_size(beam_size) x token_list
var_dict: {'pos', 'len', 'var_value', 'arg_value'}
exp_dict: {'exp', 'len', 'answer'}
tgt_lang: vocab of target text
Returns:
ans_acc
eq_acc
"""
gc.collect()
ans_num = eq_num = 0
for k in range(len(test_preds)): # batch id
tgt = exp_dict['exp'][k][1:exp_dict['len'][k]-1].tolist() # Remove special symbols [SOS] and [EOS]
var2arg_dict = {'N'+str(i+len(var_dict['var_value'][k])):item \
for i, item in enumerate(var_dict['arg_value'][k])}
tgt = tgt_lang.sentence_from_indexes(tgt, var2arg_dict)
num_list = var_dict['var_value'][k]
tgt_result = float(exp_dict['answer'][k])
is_ans_same = is_eq_same = False
for j in range(k_num): # top-n
try:
pred = tgt_lang.sentence_from_indexes(test_preds[k][j], var2arg_dict)
pred = normalize_exp(pred)
pred_result = float(func_timeout(2.0, result_compute, \
kwargs=dict(num_all_list=num_list, exp_tokens=pred)))
if pred == tgt:
is_ans_same = True
is_eq_same = True
break
if abs(pred_result-tgt_result)<5e-3:
is_ans_same = True
if len(pred)==len(tgt):
is_eq_same = True
break
            except:
                # Skip candidates that fail to decode, evaluate, or time out.
                pass
if is_ans_same: ans_num +=1
if is_eq_same: eq_num +=1
return ans_num/len(test_preds), eq_num/len(test_preds)
def compute_exp_result_comp(test_preds, var_dict, exp_dict, tgt_lang):
"""
Arguments
test_preds: B x candi_size(beam_size) x token_list
var_dict: {'pos', 'len', 'var_value', 'arg_value'}
exp_dict: {'exp', 'len', 'answer'}
tgt_lang: vocab of target text
Returns:
ans_acc
eq_acc
"""
gc.collect()
ans_num = eq_num = 0
for k in range(len(test_preds)): # batch id
tgt = exp_dict['exp'][k][1:exp_dict['len'][k]-1].tolist() # Remove special symbols [SOS] and [EOS]
var2arg_dict = {'N'+str(i+len(var_dict['var_value'][k])):item \
for i, item in enumerate(var_dict['arg_value'][k])}
tgt = tgt_lang.sentence_from_indexes(tgt, var2arg_dict)
num_list = var_dict['var_value'][k]
tgt_result = float(exp_dict['answer'][k])
is_ans_same = is_eq_same = False
for j in range(len(test_preds[k])): # pred candi id
try:
pred = tgt_lang.sentence_from_indexes(test_preds[k][j], var2arg_dict)
pred = normalize_exp(pred)
pred_result = float(func_timeout(2.0, result_compute, \
kwargs=dict(num_all_list=num_list, exp_tokens=pred)))
if pred == tgt:
is_ans_same = True
is_eq_same = True
break
if abs(pred_result-tgt_result)<5e-3:
is_ans_same = True
if len(pred)==len(tgt):
is_eq_same = True
break
            except:
                # Skip candidates that fail to decode, evaluate, or time out.
                pass
if is_ans_same: ans_num +=1
if is_eq_same: eq_num +=1
return ans_num/len(test_preds), eq_num/len(test_preds)
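
# Hedged sketch of the inputs the compute_exp_result_* functions expect; the
# shapes follow the docstrings above, and the value layout is an assumption:
#
#     test_preds: list of B items, each a list of beam candidates, each candidate
#                 a list of target-vocab token ids.
#     var_dict:   per-sample lists, e.g. var_dict['var_value'][k] = [3.0, 4.0]
#                 and var_dict['arg_value'][k] = [...].
#     exp_dict:   'exp' holds token-id tensors with [SOS]/[EOS], 'len' their lengths,
#                 'answer' the numeric answers; the *_choice variant also needs 'choices'.
#
#     ans_acc, eq_acc = compute_exp_result_topk(test_preds, var_dict, exp_dict,
#                                               tgt_lang, k_num=3)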