from __future__ import annotations import argparse import json import time from pathlib import Path from statistics import mean import psutil import torch from transformers import AutoModelForCausalLM, AutoTokenizer from rotorquant_weights import load_quantized_package, dequantize_to_state_dict from runtime_rotor_fused import load_fused_model def rss_gb() -> float: return psutil.Process().memory_info().rss / (1024 ** 3) def make_inputs(tokenizer, prompt: str): messages = [ {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."}, {"role": "user", "content": prompt}, ] text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) return tokenizer([text], return_tensors="pt") def token_match(a: torch.Tensor, b: torch.Tensor) -> float: n = min(a.numel(), b.numel()) if n == 0: return 1.0 return (a[:n] == b[:n]).float().mean().item() def run_metrics(model, tokenizer, prompts, max_new_tokens, baseline_gens=None): tok_t, pre_t, first_t, gen_t, tps, matches = [], [], [], [], [], [] gens = [] with torch.no_grad(): for i, p in enumerate(prompts): t0 = time.perf_counter() inp = make_inputs(tokenizer, p) tok_t.append(time.perf_counter() - t0) t1 = time.perf_counter() _ = model(**inp) pre_t.append(time.perf_counter() - t1) t2 = time.perf_counter() _ = model.generate(**inp, max_new_tokens=1, min_new_tokens=1, do_sample=False) first_t.append(time.perf_counter() - t2) t3 = time.perf_counter() out = model.generate( **inp, max_new_tokens=max_new_tokens, min_new_tokens=max_new_tokens, do_sample=False, ) dt = time.perf_counter() - t3 gen_t.append(dt) new_toks = out[:, inp["input_ids"].shape[1]:].reshape(-1).cpu() gens.append(new_toks) tps.append(new_toks.numel() / max(dt, 1e-9)) if baseline_gens is not None: matches.append(token_match(new_toks, baseline_gens[i])) return { "tokenize_s": mean(tok_t), "prefill_forward_s": mean(pre_t), "first_token_latency_s": mean(first_t), "generate_s": mean(gen_t), "decode_tokens_per_s": mean(tps), "token_match_vs_baseline": mean(matches) if matches else 1.0, "gens": gens, } def load_baseline(model_id: str): t0 = time.perf_counter() m = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32, device_map=None, low_cpu_mem_usage=True).eval() return m, time.perf_counter() - t0 def load_rotor(pkg_path: str): t0 = time.perf_counter() pkg = load_quantized_package(pkg_path) model_id = pkg["model_id"] m = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32, device_map=None, low_cpu_mem_usage=True).eval() sd = dequantize_to_state_dict(pkg, dtype=torch.float32, device="cpu") miss, unexp = m.load_state_dict(sd, strict=False) if miss or unexp: raise RuntimeError(f"State mismatch: missing={miss}, unexpected={unexp}") return m, time.perf_counter() - t0 def load_dynamic_int8(path: str): t0 = time.perf_counter() m = torch.load(path, map_location="cpu", weights_only=False).eval() return m, time.perf_counter() - t0 def scenario_result(name, load_s, metrics, rss_before, rss_after_load, rss_after_bench, baseline=None): out = { "scenario": name, "load_s": load_s, "tokenize_s": metrics["tokenize_s"], "prefill_forward_s": metrics["prefill_forward_s"], "first_token_latency_s": metrics["first_token_latency_s"], "generate_s": metrics["generate_s"], "decode_tokens_per_s": metrics["decode_tokens_per_s"], "token_match_vs_baseline": metrics["token_match_vs_baseline"], "rss_before_load_gb": rss_before, "rss_after_load_gb": rss_after_load, "rss_after_bench_gb": rss_after_bench, } if baseline is not None: out["delta_vs_baseline"] = { "load_s": out["load_s"] - baseline["load_s"], "prefill_forward_s": out["prefill_forward_s"] - baseline["prefill_forward_s"], "first_token_latency_s": out["first_token_latency_s"] - baseline["first_token_latency_s"], "generate_s": out["generate_s"] - baseline["generate_s"], "decode_tokens_per_s": out["decode_tokens_per_s"] - baseline["decode_tokens_per_s"], "rss_after_load_gb": out["rss_after_load_gb"] - baseline["rss_after_load_gb"], } return out def parse_args(): p = argparse.ArgumentParser(description="Benchmark baseline vs RotorQuant vs runtime INT8") p.add_argument("--model-id", default="Qwen/Qwen2.5-0.5B-Instruct") p.add_argument("--rotor-pkg", default="artifacts/qwen2.5-0.5b-rotorq3-mlp-only.pt") p.add_argument("--fused-pkg", default="artifacts/qwen2.5-0.5b-rotorq3-rowwise-skipemb.pt") p.add_argument("--int8-model", default="artifacts/qwen2.5-0.5b-dynamic-int8.pt") p.add_argument("--max-new-tokens", type=int, default=64) p.add_argument("--out", default="artifacts/runtime_benchmark.json") return p.parse_args() def main(): args = parse_args() prompts = [ "Explain quantization in one paragraph.", "Write a Python function for binary search.", "Summarize why weight quantization helps deployment.", "Give 3 practical tips for reducing LLM latency.", ] tokenizer = AutoTokenizer.from_pretrained(args.model_id) _ = make_inputs(tokenizer, "warmup") results = {} rb = rss_gb() baseline, load_b = load_baseline(args.model_id) ral = rss_gb() met_b = run_metrics(baseline, tokenizer, prompts, args.max_new_tokens) rab = rss_gb() results["baseline_fp32"] = scenario_result("baseline_fp32", load_b, met_b, rb, ral, rab) base_ref = results["baseline_fp32"] base_gens = met_b["gens"] del baseline rr0 = rss_gb() rotor, load_r = load_rotor(args.rotor_pkg) rr1 = rss_gb() met_r = run_metrics(rotor, tokenizer, prompts, args.max_new_tokens, baseline_gens=base_gens) rr2 = rss_gb() results["rotorquant_pkg"] = scenario_result("rotorquant_pkg", load_r, met_r, rr0, rr1, rr2, baseline=base_ref) del rotor rf0 = rss_gb() fused, _, load_f = load_fused_model(args.fused_pkg, out_chunk_size=64) rf1 = rss_gb() met_f = run_metrics(fused, tokenizer, prompts, args.max_new_tokens, baseline_gens=base_gens) rf2 = rss_gb() results["rotorquant_fused_runtime"] = scenario_result("rotorquant_fused_runtime", load_f, met_f, rf0, rf1, rf2, baseline=base_ref) del fused ri0 = rss_gb() int8m, load_i = load_dynamic_int8(args.int8_model) ri1 = rss_gb() met_i = run_metrics(int8m, tokenizer, prompts, args.max_new_tokens, baseline_gens=base_gens) ri2 = rss_gb() results["runtime_dynamic_int8"] = scenario_result("runtime_dynamic_int8", load_i, met_i, ri0, ri1, ri2, baseline=base_ref) del int8m out = Path(args.out) out.parent.mkdir(parents=True, exist_ok=True) out.write_text(json.dumps(results, indent=2), encoding="utf-8") print(f"Saved: {out}") for k, v in results.items(): print( f"- {k}: load={v['load_s']:.3f}s, first={v['first_token_latency_s']:.3f}s, " f"gen={v['generate_s']:.3f}s, tok/s={v['decode_tokens_per_s']:.2f}, " f"rss_load={v['rss_after_load_gb']:.2f}GB, match={v['token_match_vs_baseline']:.4f}" ) if __name__ == "__main__": main()