"""Prune a PyTorch checkpoint's ``state_dict`` and store it in half precision.

Loads the checkpoint given by ``--path`` on the CPU. Every ``state_dict`` entry
whose key starts with ``--prefix`` (default ``"model."``) is removed. The
remaining floating-point tensors are cast to float16. The result is written to
``pruned.ckpt`` next to the input file, or to ``--output`` if given. All
top-level checkpoint entries other than ``state_dict`` are saved unchanged.
"""
import argparse
from pathlib import Path

import torch


def _to_half(value):
    """Return *value* cast to float16 if it is a floating-point tensor, else unchanged."""
    # Only floating-point tensors are cast. Calling .half() on an integer or
    # bool tensor would silently turn it into a lossy float16 tensor, and
    # non-tensor values have no .half() at all.
    if torch.is_tensor(value) and value.is_floating_point():
        return value.half()
    return value


def prune_state_dict(state_dict, prefix="model."):
    """Drop entries whose key starts with *prefix*; cast float tensors to float16.

    Args:
        state_dict: Mapping of entry names to tensors (or other values).
        prefix: Entries whose key starts with this string are removed.

    Returns:
        A new dict containing the surviving entries in their original order.
        Floating-point tensors are converted to float16; all other values are
        passed through unchanged.
    """
    return {
        key: _to_half(value)
        for key, value in state_dict.items()
        if not key.startswith(prefix)
    }


def main(argv=None):
    """Parse CLI arguments, prune the checkpoint, and save it.

    Args:
        argv: Argument list to parse. ``None`` means ``sys.argv[1:]``.
    """
    # Imported only for its side effects, before torch.load runs. Presumably
    # this lets objects pickled inside Lightning checkpoints be resolved
    # during unpickling; confirm whether it is still required.
    import pytorch_lightning  # noqa: F401

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--path', type=str, required=True, help='path to model')
    parser.add_argument(
        '--output',
        type=str,
        default=None,
        help='where to write the pruned checkpoint '
             '(default: pruned.ckpt next to --path)',
    )
    parser.add_argument(
        '--prefix',
        type=str,
        default='model.',
        help='state_dict keys starting with this prefix are removed',
    )
    args = parser.parse_args(argv)

    path = Path(args.path)
    output = Path(args.output) if args.output else path.parent / "pruned.ckpt"

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only run this on trusted checkpoints. Newer PyTorch
    # releases default to weights_only=True, which may reject checkpoints
    # that contain pickled non-tensor objects.
    checkpoint = torch.load(path, map_location='cpu')
    checkpoint["state_dict"] = prune_state_dict(checkpoint["state_dict"], args.prefix)
    torch.save(checkpoint, output)


if __name__ == "__main__":
    main()