| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import time |
| from pathlib import Path |
| from statistics import mean |
|
|
| import psutil |
| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
| from rotorquant_weights import load_quantized_package, dequantize_to_state_dict |
| from runtime_rotor_fused import load_fused_model |
|
|
|
|
def rss_gb() -> float:
    """Return the resident set size of the current process.

    The byte count reported by ``psutil`` is divided by ``1024 ** 3``, so the
    value is expressed in binary gigabytes (GiB).
    """
    bytes_per_gib = 1024 ** 3
    resident_bytes = psutil.Process().memory_info().rss
    return resident_bytes / bytes_per_gib
|
|
|
|
def make_inputs(tokenizer, prompt: str):
    """Render *prompt* through the chat template and tokenize it.

    A fixed system turn is placed before the user turn, the conversation is
    rendered to text with the generation prompt appended, and that text is
    tokenized as a single-element batch with ``return_tensors="pt"``.

    Args:
        tokenizer: Object exposing ``apply_chat_template`` and ``__call__``.
        prompt: Content of the user turn.

    Returns:
        Whatever ``tokenizer([text], return_tensors="pt")`` returns.
    """
    system_turn = {
        "role": "system",
        "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
    }
    user_turn = {"role": "user", "content": prompt}
    rendered = tokenizer.apply_chat_template(
        [system_turn, user_turn],
        tokenize=False,
        add_generation_prompt=True,
    )
    return tokenizer([rendered], return_tensors="pt")
|
|
|
|
def token_match(a: torch.Tensor, b: torch.Tensor) -> float:
    """Return the fraction of positions at which two token sequences agree.

    Only the common prefix is compared: its length is the smaller of the two
    element counts. Both tensors are flattened first so that the prefix length
    (derived from ``numel()``) and the slice refer to the same axis; without
    this, multi-dimensional inputs would be sliced along their first dimension
    only and then either fail to broadcast or broadcast into a wrong ratio.

    Args:
        a: Token ids of the first sequence (any shape; read in flattened order).
        b: Token ids of the second sequence (any shape; read in flattened order).

    Returns:
        The match ratio in ``[0.0, 1.0]``. When either tensor is empty there is
        nothing to compare and ``1.0`` is returned.
    """
    n = min(a.numel(), b.numel())
    if n == 0:
        return 1.0
    flat_a = a.reshape(-1)[:n]
    flat_b = b.reshape(-1)[:n]
    return (flat_a == flat_b).float().mean().item()
|
|
|
|
def run_metrics(model, tokenizer, prompts, max_new_tokens, baseline_gens=None):
    """Time tokenization, a forward pass and generation for every prompt.

    For each prompt, under ``torch.no_grad()``, four phases are timed in order:
    building the inputs, one plain forward call, a one-token ``generate``
    call, and a ``generate`` call whose minimum and maximum new-token counts
    are both ``max_new_tokens``. Sampling is disabled in both ``generate``
    calls.

    Args:
        model: Callable on the tokenizer output and exposing ``generate``.
        tokenizer: Passed through to ``make_inputs``.
        prompts: Iterable of prompt strings.
        max_new_tokens: Number of new tokens requested from the long
            generation call.
        baseline_gens: Optional sequence of reference token tensors, indexed
            in the same order as ``prompts``. When given, each generation is
            compared against the matching entry with ``token_match``.

    Returns:
        A dict with the per-prompt means of every timing, the mean decode
        throughput, the mean token match (``1.0`` when no baseline was
        supplied) and, under ``"gens"``, the list of generated token tensors.
    """
    clock = time.perf_counter
    tokenize_times = []
    prefill_times = []
    first_token_times = []
    generate_times = []
    throughputs = []
    match_rates = []
    generations = []

    with torch.no_grad():
        for idx, prompt in enumerate(prompts):
            # Phase 1: chat-template rendering plus tokenization.
            started = clock()
            inputs = make_inputs(tokenizer, prompt)
            tokenize_times.append(clock() - started)

            # Phase 2: a single forward pass over the prompt.
            started = clock()
            _ = model(**inputs)
            prefill_times.append(clock() - started)

            # Phase 3: generation limited to exactly one new token.
            started = clock()
            _ = model.generate(**inputs, max_new_tokens=1, min_new_tokens=1, do_sample=False)
            first_token_times.append(clock() - started)

            # Phase 4: the full-length generation.
            started = clock()
            sequences = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                min_new_tokens=max_new_tokens,
                do_sample=False,
            )
            elapsed = clock() - started
            generate_times.append(elapsed)

            # Drop the prompt columns so only newly generated ids remain.
            prompt_len = inputs["input_ids"].shape[1]
            generated = sequences[:, prompt_len:].reshape(-1).cpu()
            generations.append(generated)
            # The floor on the divisor guards against a zero elapsed time.
            throughputs.append(generated.numel() / max(elapsed, 1e-9))

            if baseline_gens is not None:
                match_rates.append(token_match(generated, baseline_gens[idx]))

    summary = {
        "tokenize_s": mean(tokenize_times),
        "prefill_forward_s": mean(prefill_times),
        "first_token_latency_s": mean(first_token_times),
        "generate_s": mean(generate_times),
        "decode_tokens_per_s": mean(throughputs),
    }
    if match_rates:
        summary["token_match_vs_baseline"] = mean(match_rates)
    else:
        summary["token_match_vs_baseline"] = 1.0
    summary["gens"] = generations
    return summary
|
|
|
|
def load_baseline(model_id: str):
    """Load the reference model in float32 and time the load.

    Args:
        model_id: Identifier passed to ``AutoModelForCausalLM.from_pretrained``.

    Returns:
        A ``(model, seconds)`` tuple: the model switched to eval mode and the
        wall-clock time spent loading it.
    """
    started = time.perf_counter()
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float32,
        device_map=None,
        low_cpu_mem_usage=True,
    )
    model = model.eval()
    elapsed = time.perf_counter() - started
    return model, elapsed
|
|
|
|
def load_rotor(pkg_path: str):
    """Load a quantized package and install its dequantized weights.

    The package names its own base model under ``"model_id"``. That model is
    loaded in float32, the package is dequantized into a state dict on CPU,
    and the state dict is loaded into the model.

    Args:
        pkg_path: Path handed to ``load_quantized_package``.

    Returns:
        A ``(model, seconds)`` tuple: the model in eval mode and the
        wall-clock time of the whole procedure.

    Raises:
        RuntimeError: If loading the state dict reports any missing or
            unexpected keys.
    """
    started = time.perf_counter()
    package = load_quantized_package(pkg_path)
    model = AutoModelForCausalLM.from_pretrained(
        package["model_id"],
        torch_dtype=torch.float32,
        device_map=None,
        low_cpu_mem_usage=True,
    ).eval()
    state = dequantize_to_state_dict(package, dtype=torch.float32, device="cpu")
    # strict=False so the mismatch is reported through our own error below.
    miss, unexp = model.load_state_dict(state, strict=False)
    if not miss and not unexp:
        return model, time.perf_counter() - started
    raise RuntimeError(f"State mismatch: missing={miss}, unexpected={unexp}")
|
|
|
|
def load_dynamic_int8(path: str):
    """Load a fully serialized model object from *path* onto the CPU.

    NOTE: ``weights_only=False`` makes ``torch.load`` unpickle arbitrary
    Python objects, so *path* must point to a trusted file.

    Args:
        path: Location of the serialized model.

    Returns:
        A ``(model, seconds)`` tuple: the loaded object after ``.eval()`` and
        the wall-clock time spent loading it.
    """
    started = time.perf_counter()
    loaded = torch.load(path, map_location="cpu", weights_only=False)
    model = loaded.eval()
    elapsed = time.perf_counter() - started
    return model, elapsed
|
|
|
|
def scenario_result(name, load_s, metrics, rss_before, rss_after_load, rss_after_bench, baseline=None):
    """Assemble the report row for one benchmark scenario.

    Args:
        name: Scenario label, stored under ``"scenario"``.
        load_s: Model load time in seconds.
        metrics: Mapping holding the timing, throughput and token-match keys
            copied into the row; any other keys are ignored.
        rss_before: RSS reading taken before loading the model.
        rss_after_load: RSS reading taken after loading the model.
        rss_after_bench: RSS reading taken after the benchmark run.
        baseline: Optional row of the reference scenario. When given, a
            ``"delta_vs_baseline"`` sub-dict holds this row minus the
            baseline for a fixed set of keys.

    Returns:
        The assembled dict.
    """
    copied_metric_keys = (
        "tokenize_s",
        "prefill_forward_s",
        "first_token_latency_s",
        "generate_s",
        "decode_tokens_per_s",
        "token_match_vs_baseline",
    )
    delta_keys = (
        "load_s",
        "prefill_forward_s",
        "first_token_latency_s",
        "generate_s",
        "decode_tokens_per_s",
        "rss_after_load_gb",
    )

    row = {"scenario": name, "load_s": load_s}
    for key in copied_metric_keys:
        row[key] = metrics[key]
    row["rss_before_load_gb"] = rss_before
    row["rss_after_load_gb"] = rss_after_load
    row["rss_after_bench_gb"] = rss_after_bench

    if baseline is None:
        return row

    row["delta_vs_baseline"] = {key: row[key] - baseline[key] for key in delta_keys}
    return row
|
|
|
|
def parse_args():
    """Parse the benchmark's command-line options.

    Returns:
        The ``argparse.Namespace`` with ``model_id``, ``rotor_pkg``,
        ``fused_pkg``, ``int8_model``, ``max_new_tokens`` and ``out``.
    """
    parser = argparse.ArgumentParser(description="Benchmark baseline vs RotorQuant vs runtime INT8")
    # Options are registered in this order, which is also the --help order.
    option_specs = (
        ("--model-id", {"default": "Qwen/Qwen2.5-0.5B-Instruct"}),
        ("--rotor-pkg", {"default": "artifacts/qwen2.5-0.5b-rotorq3-mlp-only.pt"}),
        ("--fused-pkg", {"default": "artifacts/qwen2.5-0.5b-rotorq3-rowwise-skipemb.pt"}),
        ("--int8-model", {"default": "artifacts/qwen2.5-0.5b-dynamic-int8.pt"}),
        ("--max-new-tokens", {"type": int, "default": 64}),
        ("--out", {"default": "artifacts/runtime_benchmark.json"}),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()
|
|
|
|
def main():
    """Benchmark the four scenarios in sequence and write a JSON report.

    The scenarios run in a fixed order: FP32 baseline, RotorQuant package,
    fused RotorQuant runtime, dynamic INT8. For each one, RSS is sampled
    before the load, after the load and after the benchmark, and the model
    reference is dropped before the next scenario starts. The non-baseline
    scenarios are compared against the baseline's generations and metrics.
    The report is written to ``--out`` and a one-line summary per scenario is
    printed.
    """
    args = parse_args()
    prompts = [
        "Explain quantization in one paragraph.",
        "Write a Python function for binary search.",
        "Summarize why weight quantization helps deployment.",
        "Give 3 practical tips for reducing LLM latency.",
    ]

    tokenizer = AutoTokenizer.from_pretrained(args.model_id)
    # Result is discarded; presumably this absorbs first-call tokenizer
    # overhead so it does not land in the first timed prompt.
    _ = make_inputs(tokenizer, "warmup")

    results = {}

    # Scenario 1: FP32 baseline, which also provides the reference outputs.
    base_rss_start = rss_gb()
    baseline, baseline_load_s = load_baseline(args.model_id)
    base_rss_loaded = rss_gb()
    baseline_metrics = run_metrics(baseline, tokenizer, prompts, args.max_new_tokens)
    base_rss_done = rss_gb()
    base_ref = scenario_result(
        "baseline_fp32",
        baseline_load_s,
        baseline_metrics,
        base_rss_start,
        base_rss_loaded,
        base_rss_done,
    )
    results["baseline_fp32"] = base_ref
    base_gens = baseline_metrics["gens"]
    del baseline

    # Scenario 2: RotorQuant package dequantized into a regular model.
    rotor_rss_start = rss_gb()
    rotor, rotor_load_s = load_rotor(args.rotor_pkg)
    rotor_rss_loaded = rss_gb()
    rotor_metrics = run_metrics(
        rotor, tokenizer, prompts, args.max_new_tokens, baseline_gens=base_gens
    )
    rotor_rss_done = rss_gb()
    results["rotorquant_pkg"] = scenario_result(
        "rotorquant_pkg",
        rotor_load_s,
        rotor_metrics,
        rotor_rss_start,
        rotor_rss_loaded,
        rotor_rss_done,
        baseline=base_ref,
    )
    del rotor

    # Scenario 3: fused RotorQuant runtime; the loader reports its own load
    # time as the third element of its return value.
    fused_rss_start = rss_gb()
    fused, _, fused_load_s = load_fused_model(args.fused_pkg, out_chunk_size=64)
    fused_rss_loaded = rss_gb()
    fused_metrics = run_metrics(
        fused, tokenizer, prompts, args.max_new_tokens, baseline_gens=base_gens
    )
    fused_rss_done = rss_gb()
    results["rotorquant_fused_runtime"] = scenario_result(
        "rotorquant_fused_runtime",
        fused_load_s,
        fused_metrics,
        fused_rss_start,
        fused_rss_loaded,
        fused_rss_done,
        baseline=base_ref,
    )
    del fused

    # Scenario 4: serialized dynamic-INT8 model.
    int8_rss_start = rss_gb()
    int8m, int8_load_s = load_dynamic_int8(args.int8_model)
    int8_rss_loaded = rss_gb()
    int8_metrics = run_metrics(
        int8m, tokenizer, prompts, args.max_new_tokens, baseline_gens=base_gens
    )
    int8_rss_done = rss_gb()
    results["runtime_dynamic_int8"] = scenario_result(
        "runtime_dynamic_int8",
        int8_load_s,
        int8_metrics,
        int8_rss_start,
        int8_rss_loaded,
        int8_rss_done,
        baseline=base_ref,
    )
    del int8m

    # Persist the report, creating the output directory when needed.
    out = Path(args.out)
    out.parent.mkdir(parents=True, exist_ok=True)
    payload = json.dumps(results, indent=2)
    out.write_text(payload, encoding="utf-8")

    print(f"Saved: {out}")
    for k, v in results.items():
        line = (
            f"- {k}: load={v['load_s']:.3f}s, first={v['first_token_latency_s']:.3f}s, "
            f"gen={v['generate_s']:.3f}s, tok/s={v['decode_tokens_per_s']:.2f}, "
            f"rss_load={v['rss_after_load_gb']:.2f}GB, match={v['token_match_vs_baseline']:.4f}"
        )
        print(line)
|
|
|
|
# Run the benchmark only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|