| from __future__ import annotations |
|
|
| import argparse |
|
|
| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
| from rotorquant_weights import load_quantized_package, dequantize_to_state_dict |
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for a RotorQuant inference run.

    Options:
        --quantized       location of the quantized package (required).
        --prompt          user message placed in the chat template.
        --system          system message placed in the chat template.
        --max-new-tokens  generation budget, parsed as ``int`` (default 80).
        --dtype           one of ``float32`` / ``float16`` / ``bfloat16``
                          (default ``float32``).

    Returns:
        The parsed ``argparse.Namespace`` (read from ``sys.argv``).
    """
    cli = argparse.ArgumentParser(description="Run inference with a RotorQuant package")

    # (flag, add_argument keyword arguments), kept in the original order so
    # that --help output is unchanged.
    option_specs = [
        ("--quantized", dict(required=True)),
        ("--prompt", dict(default="Give me a short introduction to large language models.")),
        ("--system", dict(default="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.")),
        ("--max-new-tokens", dict(type=int, default=80)),
        ("--dtype", dict(choices=["float32", "float16", "bfloat16"], default="float32")),
    ]
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    return cli.parse_args()
|
|
|
|
def str_to_dtype(s: str) -> torch.dtype:
    """Map a dtype name from the CLI to the matching ``torch.dtype``.

    Supported names: ``float32``, ``float16``, ``bfloat16``.

    Raises:
        KeyError: if ``s`` is not one of the supported names.
    """
    dtype_by_name = dict(
        float32=torch.float32,
        float16=torch.float16,
        bfloat16=torch.bfloat16,
    )
    return dtype_by_name[s]
|
|
|
|
def main() -> None:
    """Run greedy chat inference with a RotorQuant-quantized model and print the result.

    Steps: parse CLI arguments, rebuild the base model with the dequantized
    weights, format the system/user messages with the tokenizer's chat
    template, generate greedily, then print the prompt and decoded response.

    Raises:
        RuntimeError: if the dequantized state dict has missing or unexpected
            keys relative to the model.
    """
    args = parse_args()
    dtype = str_to_dtype(args.dtype)

    model, tokenizer = _load_dequantized_model(args.quantized, dtype)
    inputs = _build_chat_inputs(tokenizer, args.system, args.prompt)
    response = _generate_response(model, tokenizer, inputs, args.max_new_tokens)

    print("\n=== Prompt ===")
    print(args.prompt)
    print("\n=== Response ===")
    print(response)


def _load_dequantized_model(quantized_path: str, dtype: torch.dtype) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Build the package's base architecture and load the dequantized weights into it.

    Args:
        quantized_path: value passed to ``load_quantized_package``.
        dtype: dtype used for both the model and the dequantized tensors.

    Returns:
        ``(model, tokenizer)``, with the model in eval mode on CPU.

    Raises:
        RuntimeError: on any missing or unexpected state-dict key.
    """
    pkg = load_quantized_package(quantized_path)
    model_id = pkg["model_id"]

    print(f"Loading base architecture from: {model_id}")
    # NOTE(review): from_pretrained also materializes the original checkpoint
    # weights, which are overwritten just below; building from AutoConfig could
    # avoid that cost. Confirm that every needed tensor is in the package first.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map=None,
        low_cpu_mem_usage=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    print("Dequantizing state dict...")
    state_dict = dequantize_to_state_dict(pkg, dtype=dtype, device="cpu")
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    # load_state_dict copies values into the model's own parameters, so the
    # dequantized tensors and the package are no longer needed. Releasing them
    # here keeps a second full copy of the weights from living through generation.
    del state_dict, pkg
    if missing or unexpected:
        raise RuntimeError(f"State dict mismatch. Missing={missing}, unexpected={unexpected}")
    model.eval()
    return model, tokenizer


def _build_chat_inputs(tokenizer, system: str, prompt: str):
    """Render a system + user conversation through the chat template and tokenize it as a batch of one."""
    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": prompt},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    return tokenizer([text], return_tensors="pt")


def _generate_response(model, tokenizer, inputs, max_new_tokens: int) -> str:
    """Greedily generate up to ``max_new_tokens`` tokens and decode only the newly generated part."""
    with torch.no_grad():
        out_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )

    # The output rows begin with the prompt tokens; slice them off so only the
    # continuation is decoded.
    prompt_len = inputs["input_ids"].shape[1]
    new_tokens = out_ids[:, prompt_len:]
    return tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0]
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
|
|