# File size: 2,469 Bytes
# 18f4d80
from __future__ import annotations
import argparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from rotorquant_weights import load_quantized_package, dequantize_to_state_dict
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description="Run inference with a RotorQuant package")
p.add_argument("--quantized", required=True)
p.add_argument("--prompt", default="Give me a short introduction to large language models.")
p.add_argument("--system", default="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.")
p.add_argument("--max-new-tokens", type=int, default=80)
p.add_argument("--dtype", choices=["float32", "float16", "bfloat16"], default="float32")
return p.parse_args()
def str_to_dtype(s: str) -> torch.dtype:
    """Map a dtype name to the corresponding ``torch.dtype``.

    Args:
        s: One of ``"float32"``, ``"float16"`` or ``"bfloat16"``.

    Returns:
        The matching ``torch.dtype`` object.

    Raises:
        KeyError: If *s* is not a supported name; the message lists the
            accepted names.
    """
    supported = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    try:
        return supported[s]
    except KeyError:
        # Same exception type as a plain lookup, but with an actionable message.
        raise KeyError(
            f"Unsupported dtype {s!r}; expected one of {sorted(supported)}"
        ) from None
def main() -> None:
    """Load a RotorQuant package, rebuild the model, and print one reply.

    Steps: parse the CLI options, load the quantized package, instantiate the
    model named by the package's ``model_id`` entry, overwrite its weights
    with the dequantized state dict, then generate (sampling disabled) a
    response to a system + user chat prompt and print prompt and response.

    Raises:
        RuntimeError: If the dequantized state dict has missing or unexpected
            keys relative to the instantiated model.
    """
    args = parse_args()
    pkg = load_quantized_package(args.quantized)
    model_id = pkg["model_id"]
    dtype = str_to_dtype(args.dtype)

    print(f"Loading base architecture from: {model_id}")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map=None,
        low_cpu_mem_usage=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    print("Dequantizing state dict...")
    state_dict = dequantize_to_state_dict(pkg, dtype=dtype, device="cpu")
    # strict=False returns the mismatching keys instead of raising, so both
    # lists can be reported together in a single error.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing or unexpected:
        raise RuntimeError(f"State dict mismatch. Missing={missing}, unexpected={unexpected}")

    # load_state_dict copies the values into the model's own parameters, so
    # drop the now-redundant references to keep memory down during generation.
    del state_dict, pkg

    model.eval()

    messages = [
        {"role": "system", "content": args.system},
        {"role": "user", "content": args.prompt},
    ]
    # Render the chat to a plain string first; tokenization happens below.
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer([text], return_tensors="pt")

    with torch.no_grad():
        out_ids = model.generate(
            **inputs,
            max_new_tokens=args.max_new_tokens,
            do_sample=False,
        )

    # The generated ids start with the prompt ids; keep only what follows.
    prompt_length = inputs["input_ids"].shape[1]
    new_tokens = out_ids[:, prompt_length:]
    response = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0]

    print("\n=== Prompt ===")
    print(args.prompt)
    print("\n=== Response ===")
    print(response)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|