#!/usr/bin/env python3
"""Convert an LLM-Pruner ``.bin`` checkpoint to a Hugging Face directory.

The checkpoint is expected to be a pickled dict holding ``model`` and
``tokenizer`` entries; both objects are written out with their own
``save_pretrained`` methods.
"""

import argparse
import os
import sys
from typing import Any, Optional, Sequence, Tuple

import torch


def ensure_llmpruner_on_path() -> None:
    """Put the vendored LLM-Pruner checkout at the front of ``sys.path``.

    The checkout is looked up at ``<repo_root>/compare_model/LLM-Pruner``,
    where ``<repo_root>`` is the parent of the directory containing this
    file. Nothing happens if that directory does not exist or is already
    listed in ``sys.path``, so repeated calls are harmless.
    """
    repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    llmpruner_root = os.path.join(repo_root, "compare_model", "LLM-Pruner")
    if os.path.isdir(llmpruner_root) and llmpruner_root not in sys.path:
        sys.path.insert(0, llmpruner_root)


def load_llmpruner_checkpoint(path: str) -> Tuple[Any, Any]:
    """Load an LLM-Pruner checkpoint and return ``(model, tokenizer)``.

    Args:
        path: Location of the checkpoint file passed to ``torch.load``.

    Returns:
        The objects stored under the ``model`` and ``tokenizer`` keys.

    Raises:
        SystemExit: If the loaded object is not a dict containing both
            ``model`` and ``tokenizer``.
    """
    # Presumably required so unpickling can import LLM-Pruner's own model
    # classes -- TODO confirm against the checkpoint producer.
    ensure_llmpruner_on_path()

    # SECURITY: weights_only=False performs full unpickling, which can run
    # arbitrary code embedded in the file. Only load checkpoints you trust.
    checkpoint = torch.load(path, map_location="cpu", weights_only=False)

    if not isinstance(checkpoint, dict) or "model" not in checkpoint or "tokenizer" not in checkpoint:
        raise SystemExit(
            "Expected an LLM-Pruner checkpoint dict with `model` and `tokenizer` entries."
        )
    return checkpoint["model"], checkpoint["tokenizer"]


def main(argv: Optional[Sequence[str]] = None) -> None:
    """Command-line entry point.

    Args:
        argv: Argument list to parse; ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Prints the output directory on success. Exits with status 2 via
    ``parser.error`` when ``--input`` does not name an existing file.
    """
    parser = argparse.ArgumentParser(
        description="Convert an LLM-Pruner .bin checkpoint to a Hugging Face save_pretrained directory."
    )
    parser.add_argument("--input", required=True, help="Path to LLM-Pruner pytorch_model.bin")
    parser.add_argument("--output_dir", required=True, help="Directory to write HF model artifacts")
    args = parser.parse_args(argv)

    # Fail early with a usage-style message instead of a torch traceback, and
    # before any output directory is created.
    if not os.path.isfile(args.input):
        parser.error(f"--input is not an existing file: {args.input}")

    model, tokenizer = load_llmpruner_checkpoint(args.input)

    os.makedirs(args.output_dir, exist_ok=True)
    model.save_pretrained(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    print(args.output_dir)


if __name__ == "__main__":
    main()