# Source: NovelaiAI / clean.py (Hugging Face file view)
# Uploaded by tlsdlftn79 in commit 8482701 ("Upload 5 files"); original size 651 bytes.
import torch
import pytorch_lightning
import argparse
from pathlib import Path
def _parse_args(argv=None):
    """Parse command-line options.

    ``--path`` names the checkpoint to prune (required). ``--output`` names the
    file to write; when omitted it defaults to ``pruned.ckpt`` in the same
    directory as ``--path``, which is where this script always wrote before the
    option existed.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description="Drop 'model.*' weights from a checkpoint and cast the "
                    "remaining floating-point tensors to float16."
    )
    parser.add_argument('--path', type=str, required=True, help='path to model')
    parser.add_argument(
        '--output',
        type=str,
        default=None,
        help='where to write the pruned checkpoint (default: pruned.ckpt next to --path)',
    )
    return parser.parse_args(argv)


def _prune_state_dict(state_dict):
    """Delete every ``model.``-prefixed entry and cast floating tensors to fp16, in place.

    Keys with other prefixes (e.g. ``model_ema.``, ``first_stage_model.``) are
    kept, because ``"model_ema.x".startswith("model.")`` is False.

    Only floating-point tensors are converted. ``.half()`` on an integer tensor
    would change its dtype and lose precision above 2048. A non-tensor value has
    no ``.half()`` at all. Both kinds are left untouched.

    The mapping is mutated rather than rebuilt so that any attributes attached
    to the original dict object (e.g. the ``_metadata`` that
    ``nn.Module.load_state_dict`` reads) survive the round trip.
    """
    # Snapshot the keys first: deleting from a dict while iterating it raises.
    for key in [k for k in state_dict if k.startswith("model.")]:
        del state_dict[key]
    for key, value in state_dict.items():
        # Rebinding an existing key does not resize the dict, so this is safe
        # during iteration.
        if torch.is_tensor(value) and value.is_floating_point():
            state_dict[key] = value.half()


def main(argv=None):
    """Load the checkpoint at ``--path``, prune its ``state_dict``, and save it.

    Every other top-level entry of the checkpoint is written back unchanged.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Raises:
        SystemExit: if the loaded object has no ``"state_dict"`` entry.
    """
    args = _parse_args(argv)
    path = Path(args.path)
    output = Path(args.output) if args.output is not None else path.parent / "pruned.ckpt"

    # NOTE(review): torch.load unpickles the file and can execute arbitrary
    # code; only run this on checkpoints from a trusted source.
    model_dict = torch.load(path, map_location='cpu')
    if not isinstance(model_dict, dict) or "state_dict" not in model_dict:
        raise SystemExit(f"{path}: checkpoint has no 'state_dict' entry")

    _prune_state_dict(model_dict["state_dict"])
    torch.save(model_dict, output)


if __name__ == "__main__":
    main()