|
|
import torch |
|
|
import pytorch_lightning |
|
|
import argparse |
|
|
from pathlib import Path |
|
|
|
|
|
# Prune a checkpoint: drop every state_dict entry whose key starts with
# "model.", cast the remaining floating-point tensors to float16, and write
# the result next to the input as "pruned.ckpt".


def prune_state_dict(state_dict, prefix="model."):
    """Prune and half-precision-cast a state dict in place.

    Args:
        state_dict: Mapping of parameter/buffer names to values, modified
            in place.
        prefix: Entries whose key starts with this string are deleted.
            Defaults to ``"model."``.

    Returns:
        The same ``state_dict`` object, for convenience.
    """
    # NOTE(review): this deletes every "model."-prefixed entry; confirm that
    # the remaining keys hold the weights you intend to keep.
    # Snapshot the matching keys first: deleting while iterating the dict
    # itself would raise RuntimeError.
    for key in [k for k in state_dict if k.startswith(prefix)]:
        del state_dict[key]

    # Only floating-point tensors are cast. Integer/bool buffers (counters,
    # index tensors) must keep their dtype: float16 cannot represent every
    # integer above 2048, and non-tensor values have no .half().
    # Reassigning existing keys during iteration is safe (size is unchanged).
    for key, value in state_dict.items():
        if torch.is_tensor(value) and value.is_floating_point():
            state_dict[key] = value.half()
    return state_dict


def main(argv=None):
    """Command-line entry point.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', type=str, required=True, help='path to model')
    args = parser.parse_args(argv)

    path = Path(args.path)

    # SECURITY NOTE: torch.load unpickles the file, which can execute
    # arbitrary code; only run this on checkpoints from a trusted source.
    model_dict = torch.load(path, map_location='cpu')

    prune_state_dict(model_dict["state_dict"])

    torch.save(model_dict, path.parent / "pruned.ckpt")


if __name__ == "__main__":
    main()