# cnmoro's picture
# Upload 29 files
# 18f4d80 verified
from __future__ import annotations
import argparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from rotorquant_weights import load_quantized_package, dequantize_to_state_dict
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description="Run inference with a RotorQuant package")
p.add_argument("--quantized", required=True)
p.add_argument("--prompt", default="Give me a short introduction to large language models.")
p.add_argument("--system", default="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.")
p.add_argument("--max-new-tokens", type=int, default=80)
p.add_argument("--dtype", choices=["float32", "float16", "bfloat16"], default="float32")
return p.parse_args()
def str_to_dtype(s: str) -> torch.dtype:
    """Translate a dtype name into the corresponding ``torch.dtype``.

    Args:
        s: One of ``"float32"``, ``"float16"`` or ``"bfloat16"``.

    Returns:
        The matching ``torch.dtype`` object.

    Raises:
        KeyError: If ``s`` is not one of the supported names.
    """
    supported = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    try:
        return supported[s]
    except KeyError:
        # Same exception type as a plain dict lookup, but the message names
        # the accepted values instead of only echoing the bad key.
        raise KeyError(
            f"unsupported dtype {s!r}; expected one of {sorted(supported)}"
        ) from None
def main() -> None:
    """Load a RotorQuant package, restore its weights and answer one prompt.

    Steps, in order:
      1. Parse the command line and load the quantized package.
      2. Instantiate the model and tokenizer named by the package's
         ``"model_id"`` entry.
      3. Dequantize the package into a state dict and load it into the model.
      4. Build a system + user chat prompt, generate a reply without
         sampling, and print the prompt and the reply.

    Raises:
        RuntimeError: If the dequantized state dict has keys the model does
            not expect, or lacks keys the model requires.
    """
    args = parse_args()
    pkg = load_quantized_package(args.quantized)
    model_id = pkg["model_id"]
    dtype = str_to_dtype(args.dtype)

    print(f"Loading base architecture from: {model_id}")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map=None,
        low_cpu_mem_usage=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    print("Dequantizing state dict...")
    state_dict = dequantize_to_state_dict(pkg, dtype=dtype, device="cpu")

    # strict=False makes load_state_dict report mismatches instead of raising,
    # so both key lists can be put into a single error message below.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing or unexpected:
        raise RuntimeError(f"State dict mismatch. Missing={missing}, unexpected={unexpected}")
    model.eval()

    messages = [
        {"role": "system", "content": args.system},
        {"role": "user", "content": args.prompt},
    ]
    # Render the conversation to a plain string first; it is tokenized
    # separately in the next step.
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer([text], return_tensors="pt")

    with torch.no_grad():
        out_ids = model.generate(
            **inputs,
            max_new_tokens=args.max_new_tokens,
            do_sample=False,
        )

    # Drop the leading prompt-length columns so that only the newly generated
    # tokens are decoded.
    new_tokens = out_ids[:, inputs["input_ids"].shape[1]:]
    response = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0]

    print("\n=== Prompt ===")
    print(args.prompt)
    print("\n=== Response ===")
    print(response)
# Run the inference pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()