File size: 2,469 Bytes
18f4d80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from __future__ import annotations

import argparse

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from rotorquant_weights import load_quantized_package, dequantize_to_state_dict


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options for RotorQuant inference.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, exactly as the original no-argument call did; passing
            a list lets callers and tests drive the parser directly.

    Returns:
        Namespace with ``quantized``, ``prompt``, ``system``, ``max_new_tokens``
        and ``dtype`` attributes.

    Raises:
        SystemExit: on invalid or missing arguments (argparse's standard behavior).
    """
    p = argparse.ArgumentParser(description="Run inference with a RotorQuant package")
    p.add_argument(
        "--quantized",
        required=True,
        help="Quantized package to load (passed to load_quantized_package).",
    )
    p.add_argument(
        "--prompt",
        default="Give me a short introduction to large language models.",
        help="User message sent to the model.",
    )
    p.add_argument(
        "--system",
        default="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
        help="System message placed before the user message in the chat template.",
    )
    p.add_argument(
        "--max-new-tokens",
        type=int,
        default=80,
        help="Maximum number of tokens to generate.",
    )
    p.add_argument(
        "--dtype",
        choices=["float32", "float16", "bfloat16"],
        default="float32",
        help="Torch dtype used for the base model and the dequantized weights.",
    )
    return p.parse_args(argv)


def str_to_dtype(s: str) -> torch.dtype:
    """Translate a ``--dtype`` choice name into the matching ``torch.dtype``.

    Only ``"float32"``, ``"float16"`` and ``"bfloat16"`` are recognised; any
    other name raises ``KeyError``.
    """
    name_to_dtype = dict(
        float32=torch.float32,
        float16=torch.float16,
        bfloat16=torch.bfloat16,
    )
    return name_to_dtype[s]


def _load_model_and_tokenizer(quantized: str, dtype: torch.dtype):
    """Load the base model/tokenizer and overwrite its weights with the dequantized package.

    Because this work lives in its own function, the loaded package and the
    dequantized state dict are released when it returns. ``load_state_dict``
    copies their values into the model's own parameters, so keeping them alive
    during generation would only hold the weights in memory twice.

    Args:
        quantized: The ``--quantized`` argument, forwarded to ``load_quantized_package``.
        dtype: dtype used both to load the base model and to dequantize the package.

    Returns:
        ``(model, tokenizer)``, with the model switched to eval mode.

    Raises:
        RuntimeError: if the dequantized state dict has keys the model lacks, or
            lacks keys the model has.
    """
    pkg = load_quantized_package(quantized)
    model_id = pkg["model_id"]

    print(f"Loading base architecture from: {model_id}")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map=None,
        low_cpu_mem_usage=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    print("Dequantizing state dict...")
    state_dict = dequantize_to_state_dict(pkg, dtype=dtype, device="cpu")
    # strict=False so both key lists come back and can be reported together
    # in one error, instead of torch's own exception.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing or unexpected:
        raise RuntimeError(f"State dict mismatch. Missing={missing}, unexpected={unexpected}")
    model.eval()
    return model, tokenizer


def _build_chat_inputs(tokenizer, system: str, prompt: str):
    """Render a system+user conversation with the tokenizer's chat template and tokenize it.

    Returns the tokenizer output (PyTorch tensors) for a batch of one, ending
    with the generation prompt so the model continues as the assistant.
    """
    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": prompt},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    return tokenizer([text], return_tensors="pt")


def main() -> None:
    """Command-line entry point.

    Loads the quantized model, runs one greedy generation for the prompt,
    and prints both the prompt and the response.
    """
    args = parse_args()
    # --dtype is restricted to valid names by argparse, so this cannot fail.
    dtype = str_to_dtype(args.dtype)

    model, tokenizer = _load_model_and_tokenizer(args.quantized, dtype)
    inputs = _build_chat_inputs(tokenizer, args.system, args.prompt)

    with torch.no_grad():
        out_ids = model.generate(
            **inputs,
            max_new_tokens=args.max_new_tokens,
            do_sample=False,
        )

    # The output contains the prompt tokens followed by the continuation;
    # drop the prompt so only the model's reply is decoded.
    prompt_len = inputs["input_ids"].shape[1]
    new_tokens = out_ids[:, prompt_len:]
    response = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0]

    print("\n=== Prompt ===")
    print(args.prompt)
    print("\n=== Response ===")
    print(response)


if __name__ == "__main__":
    # Run inference only when executed as a script, not when this module is imported.
    main()