| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import os |
| | import argparse |
| | import glob |
| |
|
| | import yaml |
| | import torch |
| |
|
| |
|
def get_args():
    """Parse the command-line options of the checkpoint-averaging script.

    Options:
        --dst_model: required; path the averaged model is saved to.
        --src_path:  required; directory holding the source checkpoints.
        --val_best:  flag; pick the checkpoints with the lowest validation
                     loss.
        --num:       int, default 5; how many checkpoints to average.

    Returns:
        argparse.Namespace with ``dst_model``, ``src_path``, ``val_best``
        and ``num``. The namespace is also printed to stdout.
    """
    parser = argparse.ArgumentParser(description='average model')
    parser.add_argument('--dst_model', required=True, help='averaged model')
    parser.add_argument('--src_path',
                        required=True,
                        help='src model path for average')
    # The help text used to be a copy of the --dst_model help and did not
    # describe what the flag actually does.
    parser.add_argument('--val_best',
                        action="store_true",
                        help='average the checkpoints with the lowest '
                             'validation loss')
    parser.add_argument('--num',
                        default=5,
                        type=int,
                        help='nums for averaged model')

    args = parser.parse_args()
    print(args)
    return args
| |
|
| |
|
def _best_checkpoints_by_val_loss(src_path, num):
    """Return paths of the ``num`` checkpoints with the lowest validation loss.

    Every ``*.yaml`` file in ``src_path`` whose basename does not start with
    ``train`` or ``init`` is read, and must provide ``loss_dict.loss``,
    ``epoch``, ``step`` and ``tag``. The entries are sorted by ascending loss
    and mapped to ``<src_path>/epoch_<epoch>_whole.pt``.
    """
    val_scores = []
    for yaml_path in glob.glob('{}/*.yaml'.format(src_path)):
        if os.path.basename(yaml_path).startswith(('train', 'init')):
            continue
        with open(yaml_path, 'r') as f:
            # BaseLoader only builds plain strings/lists/dicts, so this
            # yaml.load call cannot construct arbitrary Python objects.
            dic_yaml = yaml.load(f, Loader=yaml.BaseLoader)
        # BaseLoader yields strings, hence the explicit conversions.
        loss = float(dic_yaml['loss_dict']['loss'])
        epoch = int(dic_yaml['epoch'])
        step = int(dic_yaml['step'])
        tag = dic_yaml['tag']
        val_scores.append([epoch, step, loss, tag])

    best = sorted(val_scores, key=lambda x: x[2])[:num]
    print("best val (epoch, step, loss, tag) = " + str(best))
    return [src_path + '/epoch_{}_whole.pt'.format(score[0]) for score in best]


def _sum_checkpoints(path_list):
    """Load every checkpoint on the CPU and return the per-key sum."""
    total = {}
    for path in path_list:
        print('Processing {}'.format(path))
        states = torch.load(path, map_location=torch.device('cpu'))
        for k, v in states.items():
            if k in total:
                total[k] += v
            else:
                # Clone so the accumulation does not alias the loaded tensor.
                total[k] = v.clone()
    return total


def main():
    """Average the selected checkpoints and save the result.

    Reads the options from ``get_args()``, selects ``args.num`` checkpoints
    (currently only the ``--val_best`` strategy exists), sums their entries
    key by key, divides each sum by ``args.num`` and writes the resulting
    dict to ``args.dst_model`` with ``torch.save``.

    Raises:
        ValueError: if ``--num`` is not positive, if ``--val_best`` was not
            given (no other selection strategy is implemented), or if fewer
            than ``--num`` checkpoints could be selected.
    """
    args = get_args()
    num = args.num
    if num < 1:
        raise ValueError('--num must be a positive integer, got {}'.format(num))
    if not args.val_best:
        # Previously this fell through to an unbound ``path_list`` and died
        # with a NameError; fail with an actionable message instead.
        raise ValueError(
            'no checkpoint selection strategy given: pass --val_best')

    path_list = _best_checkpoints_by_val_loss(args.src_path, num)
    print(path_list)
    # Explicit check instead of ``assert``: under ``python -O`` an assert is
    # stripped and the sum of fewer checkpoints would be divided by ``num``.
    if len(path_list) != num:
        raise ValueError(
            'expected {} checkpoints to average but found {} in {}'.format(
                num, len(path_list), args.src_path))

    avg = _sum_checkpoints(path_list)
    for k in avg:
        if avg[k] is not None:
            # true_divide yields a true (non-floor) division, also for
            # integer tensors.
            avg[k] = torch.true_divide(avg[k], num)
    print('Saving to {}'.format(args.dst_model))
    torch.save(avg, args.dst_model)
| |
|
| |
|
# Run the averaging only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == '__main__':
    main()
| |
|