from __future__ import annotations import argparse import torch from transformers import AutoModelForCausalLM, AutoTokenizer from rotorquant_weights import load_quantized_package, dequantize_to_state_dict def parse_args() -> argparse.Namespace: p = argparse.ArgumentParser(description="Run inference with a RotorQuant package") p.add_argument("--quantized", required=True) p.add_argument("--prompt", default="Give me a short introduction to large language models.") p.add_argument("--system", default="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.") p.add_argument("--max-new-tokens", type=int, default=80) p.add_argument("--dtype", choices=["float32", "float16", "bfloat16"], default="float32") return p.parse_args() def str_to_dtype(s: str) -> torch.dtype: return { "float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16, }[s] def main() -> None: args = parse_args() pkg = load_quantized_package(args.quantized) model_id = pkg["model_id"] dtype = str_to_dtype(args.dtype) print(f"Loading base architecture from: {model_id}") model = AutoModelForCausalLM.from_pretrained( model_id, torch_dtype=dtype, device_map=None, low_cpu_mem_usage=True, ) tokenizer = AutoTokenizer.from_pretrained(model_id) print("Dequantizing state dict...") state_dict = dequantize_to_state_dict(pkg, dtype=dtype, device="cpu") missing, unexpected = model.load_state_dict(state_dict, strict=False) if missing or unexpected: raise RuntimeError(f"State dict mismatch. Missing={missing}, unexpected={unexpected}") model.eval() messages = [ {"role": "system", "content": args.system}, {"role": "user", "content": args.prompt}, ] text = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, ) inputs = tokenizer([text], return_tensors="pt") with torch.no_grad(): out_ids = model.generate( **inputs, max_new_tokens=args.max_new_tokens, do_sample=False, ) new_tokens = out_ids[:, inputs["input_ids"].shape[1]:] response = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0] print("\n=== Prompt ===") print(args.prompt) print("\n=== Response ===") print(response) if __name__ == "__main__": main()