import os

import torch


def average_checkpoints(last):
    """Average the ``model.*`` weights of several saved checkpoints.

    Each checkpoint file is loaded with ``torch.load`` and must contain a
    ``"state_dict"`` mapping. Only entries whose key starts with ``"model."``
    are kept, and that prefix is stripped from the returned keys.

    Args:
        last: Iterable of checkpoint file paths to average.

    Returns:
        A dict mapping the stripped parameter names to the element-wise mean
        over all checkpoints. Floating-point tensors use true division;
        every other tensor uses floor division so its dtype is preserved.

    Raises:
        ValueError: If ``last`` contains no paths.
        KeyError: If a later checkpoint lacks a key present in the first one.
    """
    paths = list(last)
    if not paths:
        raise ValueError(
            "average_checkpoints() needs at least one checkpoint path"
        )

    prefix = "model."
    avg = None
    for path in paths:
        # NOTE(security): torch.load unpickles the file; only use it on
        # checkpoints from a trusted source.
        # The identity map_location leaves storages where deserialization
        # put them instead of moving them back to the device they were
        # saved from.
        checkpoint = torch.load(
            path, map_location=lambda storage, loc: storage
        )
        states = {
            k[len(prefix):]: v
            for k, v in checkpoint["state_dict"].items()
            if k.startswith(prefix)
        }
        if avg is None:
            # The first checkpoint's tensors double as the accumulators.
            avg = states
        else:
            for k in avg:
                avg[k] += states[k]

    # Turn the running sums into means, in place.
    count = len(paths)
    for k in avg:
        if avg[k] is not None:
            if avg[k].is_floating_point():
                avg[k] /= count
            else:
                # True division is not allowed in place on integer tensors.
                avg[k] //= count
    return avg


def ensemble(args, num_avg=10):
    """Average the checkpoints of the final epochs and save the result.

    The checkpoints read are ``epoch={n}.ckpt`` inside
    ``args.exp_dir/args.exp_name`` for ``n`` from
    ``args.max_epochs - num_avg`` up to ``args.max_epochs - 1``. No clamping
    is applied, so ``args.max_epochs`` smaller than ``num_avg`` produces
    negative epoch numbers in the file names.

    Args:
        args: Object exposing ``exp_dir``, ``exp_name`` and ``max_epochs``.
        num_avg: How many trailing epochs to average. The default of 10
            matches the previously hard-coded window and output file name.

    Returns:
        Path of the saved averaged weights,
        ``args.exp_dir/args.exp_name/model_avg_{num_avg}.pth``.
    """
    exp_path = os.path.join(args.exp_dir, args.exp_name)
    last = [
        os.path.join(exp_path, f"epoch={n}.ckpt")
        for n in range(args.max_epochs - num_avg, args.max_epochs)
    ]
    model_path = os.path.join(exp_path, f"model_avg_{num_avg}.pth")
    torch.save(average_checkpoints(last), model_path)
    return model_path