import os
import datetime
import random
import gc
import torch
import torch.distributed as dist
from func_timeout import func_timeout
from utils.lr_scheduler import WarmupMultiStepLR
from utils.logger import create_logger  # assumed module path; create_logger is called in initialize_logger() below but was not imported here
from datasets.operators import result_compute, normalize_exp
def save_checkpoint(state, is_best, dump_path=None):
    if is_best:
        dump_path_best = os.path.join(dump_path, 'best_model.pth')
        torch.save(state, dump_path_best)
    else:
        dump_path_recent = os.path.join(dump_path, str(state['epoch']) + '.pth')
        torch.save(state, dump_path_recent)
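# Usage sketch (field names are illustrative, not from the original code): save a rolling
# per-epoch checkpoint, plus the best one when validation improves.
#   state = {'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}
#   save_checkpoint(state, is_best=False, dump_path=args.dump_path)
#   if val_acc > best_acc:
#       save_checkpoint(state, is_best=True, dump_path=args.dump_path)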
class AverageMeter(object):
    """
    Computes and stores the average and current value
    """
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()
    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
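# Usage sketch (illustrative values): track a running loss across a training loop.
#   losses = AverageMeter('Loss', ':.4e')
#   for batch in loader:
#       loss = criterion(model(batch['input']), batch['target'])
#       losses.update(loss.item(), n=batch['target'].size(0))
#   print(losses)  # e.g. "Loss 3.1250e-01 (4.0113e-01)": current value and running average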
class ProgressMeter(object):
    def __init__(self, num_batches, meters, args, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix
        self.args = args
    def display(self, batch, lr=None):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        if lr is not None:
            entries += ["lr: " + format(lr, '.6f')]
        self.args.logger.info('\t'.join(entries))
    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'
def adjust_learning_rate(optimizer, epoch, args):
    """
    Sets the learning rate to the initial LR decayed by 10 every 30 epochs
    """
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
def accuracy(output, target, topk=(1,)):
    """
    Computes the accuracy over the k top predictions for the specified values of k
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
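# Usage sketch (shapes are illustrative): top-1/top-5 accuracy for a batch of logits.
#   output = torch.randn(8, 10)           # B x num_classes logits
#   target = torch.randint(0, 10, (8,))   # B ground-truth labels
#   acc1, acc5 = accuracy(output, target, topk=(1, 5))  # each a 1-element tensor, in percent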
def get_scheduler(args, optimizer):
    if args.scheduler_type == "multistep":
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            args.scheduler_step,
            gamma=args.scheduler_factor,
        )
    elif args.scheduler_type == "cosine":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.max_epochs, eta_min=1e-6)
    elif args.scheduler_type == "warmup":
        scheduler = WarmupMultiStepLR(
            optimizer,
            args.scheduler_step,
            gamma=args.scheduler_factor,
            warmup_epochs=args.warm_epoch,
        )
    else:
        raise NotImplementedError("Unsupported LR Scheduler: {}".format(args.scheduler_type))
    return scheduler
def get_optimizer(args, model):
    if args.use_MLM_pretrain:
        # Split off the pretrained language-model parameters so they can use a separate learning rate.
        pretrain_params = list(map(id, model.mlm_pretrain.parameters()))
        other_params = filter(lambda p: id(p) not in pretrain_params, model.parameters())
    if args.optimizer_type == "SGD":
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov=True,
        )
    elif args.optimizer_type == "ADAM":
        if args.use_MLM_pretrain:
            optimizer = torch.optim.Adam(
                [{"params": model.mlm_pretrain.parameters()},
                 {"params": other_params, "lr": args.lr_LM}],
                lr=args.lr,
                betas=(0.9, 0.999),
                weight_decay=args.weight_decay,
            )
        else:
            optimizer = torch.optim.Adam(
                model.parameters(),
                lr=args.lr,
                betas=(0.9, 0.999),
                weight_decay=args.weight_decay,
            )
    elif args.optimizer_type == "ADAMW":
        if args.use_MLM_pretrain:
            optimizer = torch.optim.AdamW(
                [{"params": model.mlm_pretrain.parameters(), "lr": args.lr_LM},
                 {"params": other_params}],
                lr=args.lr,
                weight_decay=args.weight_decay,
            )
        else:
            optimizer = torch.optim.AdamW(
                model.parameters(),
                lr=args.lr,
                weight_decay=args.weight_decay,
            )
    else:
        raise NotImplementedError("Unsupported Optimizer Type : {}".format(args.optimizer_type))
    return optimizer
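# Usage sketch (argument values are illustrative; train_one_epoch is a hypothetical helper):
#   args.optimizer_type, args.scheduler_type = "ADAMW", "cosine"
#   optimizer = get_optimizer(args, model)
#   scheduler = get_scheduler(args, optimizer)
#   for epoch in range(args.max_epochs):
#       train_one_epoch(model, loader, optimizer)
#       scheduler.step()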
def reduce_mean(tensor, nprocs):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= nprocs
    return rt
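# Usage sketch: average a per-rank metric across all processes (assumes the default
# torch.distributed process group has already been initialized).
#   loss = reduce_mean(loss.detach(), dist.get_world_size())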
def set_cuda(data_dict):
    for key in data_dict:
        if torch.is_tensor(data_dict[key]):
            data_dict[key] = data_dict[key].cuda()
def initialize_logger(params):
    """
    Initialize the experiment:
        - dump parameters
        - create a logger
    """
    while True:
        exp_id = datetime.datetime.strftime(datetime.datetime.now(), '%Y-%m-%d-%H-%M-%S')
        if not os.path.exists(os.path.join(params.dump_path, exp_id)):
            break
    params.dump_path = os.path.join(params.dump_path, exp_id)
    if params.local_rank == 0:
        os.makedirs(params.dump_path)
    # create a logger
    logger = create_logger(os.path.join(params.dump_path, 'record.log'), params.local_rank)
    logger.info("============ Initialized logger ============")
    logger.info("\n" + "\n".join("\t\t\t\t%s: %s" % (k, str(v))
                                 for k, v in sorted(dict(vars(params)).items())))
    logger.info("The experiment results will be stored in %s" % params.dump_path)
    return logger
def aeq(*args):
    """
    Assert all arguments have the same value
    """
    arguments = (arg for arg in args)
    first = next(arguments)
    assert all(arg == first for arg in arguments), \
        "Not all arguments have the same value: " + str(args)
def sequence_mask(lengths, max_len=None):
    """
    Creates a boolean mask from sequence lengths.
    """
    batch_size = lengths.numel()
    max_len = max_len or lengths.max()
    return torch.arange(0, max_len, device=lengths.device) \
        .type_as(lengths) \
        .repeat(batch_size, 1) \
        .lt(lengths.unsqueeze(1))
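# Usage sketch: lengths [3, 1] with max_len=4 produce
#   mask = sequence_mask(torch.tensor([3, 1]), max_len=4)
#   # tensor([[ True,  True,  True, False],
#   #         [ True, False, False, False]])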
def copy_list(l):
    r = []
    if len(l) == 0:
        return r
    for i in l:
        if type(i) is list:
            r.append(copy_list(i))
        else:
            r.append(i)
    return r
def compute_exp_result_choice(test_preds, var_dict, exp_dict, tgt_lang):
    """
    Arguments
        test_preds: B x candi_size(beam_size) x token_list
        var_dict: {'pos', 'len', 'var_value', 'arg_value'}
        exp_dict: {'exp', 'len', 'answer', 'choices'}
        tgt_lang: vocab of target text
    Returns:
        ans_acc
        eq_acc
    """
    gc.collect()
    ans_num = eq_num = 0
    for k in range(len(test_preds)):  # batch id
        tgt = exp_dict['exp'][k][1:exp_dict['len'][k] - 1].tolist()  # Remove special symbols [SOS] and [EOS]
        var2arg_dict = {'N' + str(i + len(var_dict['var_value'][k])): item
                        for i, item in enumerate(var_dict['arg_value'][k])}
        tgt = tgt_lang.sentence_from_indexes(tgt, var2arg_dict)
        num_list = var_dict['var_value'][k]
        tgt_result = float(exp_dict['answer'][k])
        choices = exp_dict['choices'][k]
        is_find_ans = False
        for j in range(len(test_preds[k])):  # prediction candidate id
            try:
                pred = tgt_lang.sentence_from_indexes(test_preds[k][j], var2arg_dict)
                pred = normalize_exp(pred)
                pred_result = float(func_timeout(2.0, result_compute,
                                                 kwargs=dict(num_all_list=num_list, exp_tokens=pred)))
                if pred == tgt:
                    ans_num += 1
                    eq_num += 1
                    is_find_ans = True
                    break
                for item in choices:
                    if abs(pred_result - item) < 5e-2:
                        is_find_ans = True
                if is_find_ans and abs(pred_result - tgt_result) < 5e-3:
                    ans_num += 1
                    if len(pred) == len(tgt):
                        eq_num += 1
                if is_find_ans:
                    break
            except:  # skip candidates whose expression cannot be evaluated (or times out)
                pass
        if not is_find_ans:
            # No candidate matched any choice: fall back to a random choice.
            pred_result = random.choice(choices)
            if abs(pred_result - tgt_result) < 5e-2:
                ans_num += 1
    return ans_num / len(test_preds), eq_num / len(test_preds)
def compute_exp_result_topk(test_preds, var_dict, exp_dict, tgt_lang, k_num=3):
    """
    Arguments
        test_preds: B x candi_size(beam_size) x token_list
        var_dict: {'pos', 'len', 'var_value', 'arg_value'}
        exp_dict: {'exp', 'len', 'answer'}
        tgt_lang: vocab of target text
        k_num: number of top-ranked candidates to check
    Returns:
        ans_acc
        eq_acc
    """
    gc.collect()
    ans_num = eq_num = 0
    for k in range(len(test_preds)):  # batch id
        tgt = exp_dict['exp'][k][1:exp_dict['len'][k] - 1].tolist()  # Remove special symbols [SOS] and [EOS]
        var2arg_dict = {'N' + str(i + len(var_dict['var_value'][k])): item
                        for i, item in enumerate(var_dict['arg_value'][k])}
        tgt = tgt_lang.sentence_from_indexes(tgt, var2arg_dict)
        num_list = var_dict['var_value'][k]
        tgt_result = float(exp_dict['answer'][k])
        is_ans_same = is_eq_same = False
        for j in range(k_num):  # top-n candidates
            try:
                pred = tgt_lang.sentence_from_indexes(test_preds[k][j], var2arg_dict)
                pred = normalize_exp(pred)
                pred_result = float(func_timeout(2.0, result_compute,
                                                 kwargs=dict(num_all_list=num_list, exp_tokens=pred)))
                if pred == tgt:
                    is_ans_same = True
                    is_eq_same = True
                    break
                if abs(pred_result - tgt_result) < 5e-3:
                    is_ans_same = True
                    if len(pred) == len(tgt):
                        is_eq_same = True
                    break
            except:  # skip candidates whose expression cannot be evaluated (or times out)
                pass
        if is_ans_same:
            ans_num += 1
        if is_eq_same:
            eq_num += 1
    return ans_num / len(test_preds), eq_num / len(test_preds)
def compute_exp_result_comp(test_preds, var_dict, exp_dict, tgt_lang):
    """
    Arguments
        test_preds: B x candi_size(beam_size) x token_list
        var_dict: {'pos', 'len', 'var_value', 'arg_value'}
        exp_dict: {'exp', 'len', 'answer'}
        tgt_lang: vocab of target text
    Returns:
        ans_acc
        eq_acc
    """
    gc.collect()
    ans_num = eq_num = 0
    for k in range(len(test_preds)):  # batch id
        tgt = exp_dict['exp'][k][1:exp_dict['len'][k] - 1].tolist()  # Remove special symbols [SOS] and [EOS]
        var2arg_dict = {'N' + str(i + len(var_dict['var_value'][k])): item
                        for i, item in enumerate(var_dict['arg_value'][k])}
        tgt = tgt_lang.sentence_from_indexes(tgt, var2arg_dict)
        num_list = var_dict['var_value'][k]
        tgt_result = float(exp_dict['answer'][k])
        is_ans_same = is_eq_same = False
        for j in range(len(test_preds[k])):  # prediction candidate id
            try:
                pred = tgt_lang.sentence_from_indexes(test_preds[k][j], var2arg_dict)
                pred = normalize_exp(pred)
                pred_result = float(func_timeout(2.0, result_compute,
                                                 kwargs=dict(num_all_list=num_list, exp_tokens=pred)))
                if pred == tgt:
                    is_ans_same = True
                    is_eq_same = True
                    break
                if abs(pred_result - tgt_result) < 5e-3:
                    is_ans_same = True
                    if len(pred) == len(tgt):
                        is_eq_same = True
                    break
            except:  # skip candidates whose expression cannot be evaluated (or times out)
                pass
        if is_ans_same:
            ans_num += 1
        if is_eq_same:
            eq_num += 1
    return ans_num / len(test_preds), eq_num / len(test_preds)