import argparse
import os
import sys

import torch
|
|
|
|
def ensure_llmpruner_on_path() -> None:
    """Make the vendored LLM-Pruner package importable.

    Prepends ``compare_model/LLM-Pruner`` (resolved relative to the
    repository root, i.e. the parent of this script's directory) to
    ``sys.path``, but only if that directory exists and is not already
    on the path.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    repo_root = os.path.dirname(here)
    candidate = os.path.join(repo_root, "compare_model", "LLM-Pruner")
    # Guard clauses: nothing to do if the checkout is absent or already importable.
    if not os.path.isdir(candidate):
        return
    if candidate in sys.path:
        return
    sys.path.insert(0, candidate)
|
|
|
|
def load_llmpruner_checkpoint(path: str):
    """Load an LLM-Pruner checkpoint and return its ``(model, tokenizer)`` pair.

    The file at *path* must be a pickled dict containing ``model`` and
    ``tokenizer`` entries; otherwise the process exits with an explanatory
    message via ``SystemExit``.

    NOTE(review): ``weights_only=False`` unpickles arbitrary Python objects,
    so this must only be used on checkpoints from a trusted source.
    """
    ensure_llmpruner_on_path()
    ckpt = torch.load(path, map_location="cpu", weights_only=False)
    well_formed = isinstance(ckpt, dict) and "model" in ckpt and "tokenizer" in ckpt
    if not well_formed:
        raise SystemExit(
            "Expected an LLM-Pruner checkpoint dict with `model` and `tokenizer` entries."
        )
    return ckpt["model"], ckpt["tokenizer"]
|
|
|
|
def main() -> None:
    """CLI entry point: load a pruned checkpoint and re-save it in HF layout.

    Reads ``--input`` (an LLM-Pruner ``pytorch_model.bin``) and writes a
    ``save_pretrained``-style directory at ``--output_dir``, printing the
    output directory path on success.
    """
    cli = argparse.ArgumentParser(
        description="Convert an LLM-Pruner .bin checkpoint to a Hugging Face save_pretrained directory."
    )
    cli.add_argument("--input", required=True, help="Path to LLM-Pruner pytorch_model.bin")
    cli.add_argument("--output_dir", required=True, help="Directory to write HF model artifacts")
    opts = cli.parse_args()

    hf_model, hf_tokenizer = load_llmpruner_checkpoint(opts.input)
    out_dir = opts.output_dir
    os.makedirs(out_dir, exist_ok=True)
    # save_pretrained emits the config/weights (model) and tokenizer files.
    hf_model.save_pretrained(out_dir)
    hf_tokenizer.save_pretrained(out_dir)
    print(out_dir)
|
|
|
|
# Run the converter only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|