#!/usr/bin/env python3
import argparse
import os
import sys
import torch
def ensure_llmpruner_on_path() -> None:
    """Make the vendored LLM-Pruner package importable.

    Prepends ``compare_model/LLM-Pruner`` (resolved relative to the repo
    root, i.e. the parent of this script's directory) to ``sys.path`` when
    that directory exists and is not already on the path. No-op otherwise.
    """
    here = os.path.abspath(__file__)
    repo_root = os.path.dirname(os.path.dirname(here))
    candidate = os.path.join(repo_root, "compare_model", "LLM-Pruner")
    if not os.path.isdir(candidate):
        return
    if candidate in sys.path:
        return
    sys.path.insert(0, candidate)
def load_llmpruner_checkpoint(path: str):
    """Load an LLM-Pruner checkpoint and return its ``(model, tokenizer)`` pair.

    Parameters
    ----------
    path : str
        Path to an LLM-Pruner ``pytorch_model.bin`` file — a pickled dict
        holding whole ``model`` and ``tokenizer`` objects.

    Raises
    ------
    SystemExit
        If *path* is not an existing file, or the pickle is not a dict
        carrying both ``model`` and ``tokenizer`` entries.
    """
    # Fail with a clean CLI error instead of a raw FileNotFoundError traceback.
    if not os.path.isfile(path):
        raise SystemExit(f"Checkpoint file not found: {path}")
    # LLM-Pruner classes must be importable before unpickling the checkpoint.
    ensure_llmpruner_on_path()
    # NOTE(security): weights_only=False unpickles arbitrary objects. This is
    # required because the checkpoint stores full model/tokenizer instances,
    # but it means the file can execute code on load — only use trusted files.
    checkpoint = torch.load(path, map_location="cpu", weights_only=False)
    if not isinstance(checkpoint, dict) or "model" not in checkpoint or "tokenizer" not in checkpoint:
        raise SystemExit(
            "Expected an LLM-Pruner checkpoint dict with `model` and `tokenizer` entries."
        )
    return checkpoint["model"], checkpoint["tokenizer"]
def main() -> None:
    """CLI entry point: convert an LLM-Pruner checkpoint into a Hugging
    Face ``save_pretrained`` directory named by ``--output_dir``."""
    cli = argparse.ArgumentParser(
        description="Convert an LLM-Pruner .bin checkpoint to a Hugging Face save_pretrained directory."
    )
    cli.add_argument("--input", required=True, help="Path to LLM-Pruner pytorch_model.bin")
    cli.add_argument("--output_dir", required=True, help="Directory to write HF model artifacts")
    opts = cli.parse_args()

    model, tokenizer = load_llmpruner_checkpoint(opts.input)

    # Both artifacts land in the same directory; create it if needed.
    os.makedirs(opts.output_dir, exist_ok=True)
    for artifact in (model, tokenizer):
        artifact.save_pretrained(opts.output_dir)

    # Echo the output path so callers can capture it from stdout.
    print(opts.output_dir)


if __name__ == "__main__":
    main()
|