| | |
| |
|
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import os |
| | import argparse |
| | import glob |
| |
|
| | import yaml |
| | import numpy as np |
| | import torch |
| |
|
| |
|
def get_args():
    """Parse the command-line options of the checkpoint-averaging tool.

    Returns:
        argparse.Namespace with the attributes ``dst_model``, ``src_path``,
        ``val_best``, ``num``, ``min_epoch`` and ``max_epoch``.  The parsed
        namespace is also printed to stdout.
    """
    parser = argparse.ArgumentParser(description="average model")
    parser.add_argument(
        "--dst_model", required=True, help="path the averaged model is saved to"
    )
    parser.add_argument(
        "--src_path",
        required=True,
        help="directory containing the source checkpoints to average",
    )
    parser.add_argument(
        "--val_best",
        action="store_true",
        help="average the checkpoints with the lowest cv_loss "
        "instead of the most recently modified ones",
    )
    parser.add_argument(
        "--num", default=5, type=int, help="number of checkpoints to average"
    )
    parser.add_argument(
        "--min_epoch",
        default=0,
        type=int,
        help="min epoch used for averaging model",
    )
    parser.add_argument(
        "--max_epoch",
        default=65536,
        type=int,
        help="max epoch used for averaging model",
    )

    args = parser.parse_args()
    print(args)
    return args
| |
|
| |
|
def _select_best_by_val(src_path, num, min_epoch, max_epoch):
    """Return the checkpoint paths of the ``num`` epochs with the lowest ``cv_loss``.

    Every matching ``*.yaml`` file in ``src_path`` is read for its ``epoch``
    and ``cv_loss`` entries; only epochs within ``[min_epoch, max_epoch]``
    are considered.  The returned paths have the form ``<src_path>/<epoch>.pt``.

    Raises:
        ValueError: if no yaml file yields an epoch inside the requested range.
    """
    # NOTE: "[!train]" is a glob character class, not a word: it skips every
    # file whose name *starts with* one of the letters t, r, a, i, n
    # (presumably meant to skip train.yaml -- confirm before changing).
    yamls = glob.glob("{}/[!train]*.yaml".format(src_path))
    val_scores = []
    for y in yamls:
        with open(y, "r") as f:
            # NOTE(review): FullLoader is not the restricted safe loader; only
            # run this on yaml files from a trusted training directory, or
            # consider switching to yaml.safe_load.
            dic_yaml = yaml.load(f, Loader=yaml.FullLoader)
        loss = dic_yaml["cv_loss"]
        epoch = dic_yaml["epoch"]
        if min_epoch <= epoch <= max_epoch:
            val_scores.append([epoch, loss])
    if not val_scores:
        # Without this check the 2-D indexing below fails with an opaque
        # IndexError on an empty array.
        raise ValueError(
            "no validation results found in {} for epochs in [{}, {}]".format(
                src_path, min_epoch, max_epoch
            )
        )
    val_scores = np.array(val_scores)
    # Ascending sort on the loss column: best (lowest) loss first.
    sort_idx = np.argsort(val_scores[:, -1])
    sorted_val_scores = val_scores[sort_idx]
    print("best val scores = " + str(sorted_val_scores[:num, 1]))
    print("selected epochs = " + str(sorted_val_scores[:num, 0].astype(np.int64)))
    return [
        src_path + "/{}.pt".format(int(epoch))
        for epoch in sorted_val_scores[:num, 0]
    ]


def _select_latest(src_path, num):
    """Return the ``num`` most recently modified ``[0-9]*.pt`` files in ``src_path``."""
    path_list = glob.glob("{}/[0-9]*.pt".format(src_path))
    path_list = sorted(path_list, key=os.path.getmtime)
    return path_list[-num:]


def _average_checkpoints(path_list):
    """Load every checkpoint in ``path_list`` on CPU and return the entry-wise mean.

    The first loaded checkpoint is reused as the accumulator.  Entries whose
    value is ``None`` in the accumulator are left untouched.
    """
    avg = None
    for path in path_list:
        print("Processing {}".format(path))
        states = torch.load(path, map_location=torch.device("cpu"))
        if avg is None:
            avg = states
        else:
            for k in avg.keys():
                # Same None guard as the division below; previously a None
                # entry crashed here before the guard was ever reached.
                if avg[k] is not None:
                    avg[k] += states[k]

    num = len(path_list)
    for k in avg.keys():
        if avg[k] is not None:
            # true_divide always performs floating-point division, so
            # integer-typed entries are averaged rather than floor-divided.
            avg[k] = torch.true_divide(avg[k], num)
    return avg


def main():
    """Average ``--num`` checkpoints from ``--src_path`` and save to ``--dst_model``.

    With ``--val_best`` the checkpoints with the lowest ``cv_loss`` are used;
    otherwise the most recently modified ``[0-9]*.pt`` files are used.

    Raises:
        ValueError: if ``--num`` is smaller than 1, or if the number of
            selected checkpoints differs from ``--num``.
    """
    args = get_args()
    num = args.num
    if num < 1:
        raise ValueError("--num must be at least 1, got {}".format(num))

    if args.val_best:
        path_list = _select_best_by_val(
            args.src_path, num, args.min_epoch, args.max_epoch
        )
    else:
        path_list = _select_latest(args.src_path, num)
    print(path_list)

    # Explicit check instead of `assert`, which is stripped under `python -O`
    # and would let the average silently use the wrong divisor.
    if len(path_list) != num:
        raise ValueError(
            "expected {} checkpoints to average but found {} in {}".format(
                num, len(path_list), args.src_path
            )
        )

    avg = _average_checkpoints(path_list)
    print("Saving to {}".format(args.dst_model))
    torch.save(avg, args.dst_model)
| |
|
| |
|
# Script entry point: run the averaging tool only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
| |
|