File size: 1,484 Bytes
2c44909
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
#!/usr/bin/env python3
import argparse
import os
import sys

import torch


def ensure_llmpruner_on_path() -> None:
    """Make the vendored LLM-Pruner package importable.

    Prepends ``<repo_root>/compare_model/LLM-Pruner`` to ``sys.path`` when
    that directory exists and is not already listed; otherwise a no-op.
    """
    here = os.path.abspath(__file__)
    repo_root = os.path.dirname(os.path.dirname(here))
    pruner_dir = os.path.join(repo_root, "compare_model", "LLM-Pruner")
    if not os.path.isdir(pruner_dir):
        return
    if pruner_dir in sys.path:
        return
    # Prepend so the vendored copy wins over any installed package.
    sys.path.insert(0, pruner_dir)


def load_llmpruner_checkpoint(path: str):
    """Load an LLM-Pruner checkpoint and return its ``(model, tokenizer)`` pair.

    The checkpoint is expected to be a dict containing at least the keys
    ``model`` and ``tokenizer``; anything else aborts with ``SystemExit``.

    NOTE(review): ``weights_only=False`` performs full unpickling, which can
    execute arbitrary code — only load checkpoints from trusted sources. It is
    required here because the file stores whole model/tokenizer objects, not
    just tensors.
    """
    ensure_llmpruner_on_path()
    ckpt = torch.load(path, map_location="cpu", weights_only=False)
    is_valid = (
        isinstance(ckpt, dict) and "model" in ckpt and "tokenizer" in ckpt
    )
    if not is_valid:
        raise SystemExit(
            "Expected an LLM-Pruner checkpoint dict with `model` and `tokenizer` entries."
        )
    return ckpt["model"], ckpt["tokenizer"]


def main() -> None:
    """CLI entry point: unpack an LLM-Pruner checkpoint into an HF directory."""
    parser = argparse.ArgumentParser(
        description="Convert an LLM-Pruner .bin checkpoint to a Hugging Face save_pretrained directory."
    )
    parser.add_argument("--input", required=True, help="Path to LLM-Pruner pytorch_model.bin")
    parser.add_argument("--output_dir", required=True, help="Directory to write HF model artifacts")
    opts = parser.parse_args()

    model, tokenizer = load_llmpruner_checkpoint(opts.input)
    os.makedirs(opts.output_dir, exist_ok=True)
    # Model and tokenizer land in the same directory so the result can be
    # loaded with a single from_pretrained() call.
    model.save_pretrained(opts.output_dir)
    tokenizer.save_pretrained(opts.output_dir)
    # Echo the output path so callers/scripts can capture it.
    print(opts.output_dir)


# Run the converter only when executed as a script (not on import).
if __name__ == "__main__":
    main()