# RotorQuant-ModelWeights-Runtime / benchmark_runtime_vs_rotor.py
# Uploaded by cnmoro ("Upload 29 files", commit 18f4d80).
from __future__ import annotations

import argparse
import gc
import json
import time
from pathlib import Path
from statistics import mean

import psutil
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from rotorquant_weights import load_quantized_package, dequantize_to_state_dict
from runtime_rotor_fused import load_fused_model
def rss_gb() -> float:
    """Return the current process's resident set size in GiB (1024**3 bytes)."""
    bytes_per_gib = 1024 ** 3
    resident_bytes = psutil.Process().memory_info().rss
    return resident_bytes / bytes_per_gib
def make_inputs(tokenizer, prompt: str):
    """Render ``prompt`` through the tokenizer's chat template and tokenize it.

    A fixed Qwen system message is prepended to the user turn, and the
    template is asked to append the assistant generation prompt.

    Returns:
        Whatever ``tokenizer([text], return_tensors="pt")`` returns: a batch of
        one sequence.
    """
    system_turn = {
        "role": "system",
        "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
    }
    user_turn = {"role": "user", "content": prompt}
    chat_text = tokenizer.apply_chat_template(
        [system_turn, user_turn],
        tokenize=False,
        add_generation_prompt=True,
    )
    batch = [chat_text]
    return tokenizer(batch, return_tensors="pt")
def token_match(a: torch.Tensor, b: torch.Tensor) -> float:
    """Return the fraction of token positions at which ``a`` and ``b`` agree.

    Both tensors are flattened. Positions present in only the longer sequence
    count as mismatches, so the score is::

        (equal positions in the common prefix) / max(len(a), len(b))

    The previous version divided by the common-prefix length instead. That
    reported a perfect 1.0 when one side was empty, and overstated agreement
    whenever one generation was shorter than the other.

    Returns:
        A float in [0, 1]. Two empty sequences are a perfect match (1.0).
    """
    a = a.reshape(-1)
    b = b.reshape(-1)
    longest = max(a.numel(), b.numel())
    if longest == 0:
        return 1.0
    common = min(a.numel(), b.numel())
    hits = (a[:common] == b[:common]).sum().item()
    return hits / longest
def run_metrics(model, tokenizer, prompts, max_new_tokens, baseline_gens=None):
    """Benchmark ``model`` over ``prompts`` with greedy decoding.

    For each prompt the following are timed, in this order:
      1. building inputs with ``make_inputs``;
      2. one plain forward pass over the prompt;
      3. a one-token greedy ``generate`` call;
      4. a greedy ``generate`` call with ``min_new_tokens == max_new_tokens``.

    Args:
        model: object supporting ``model(**inputs)`` and ``model.generate(...)``.
        tokenizer: forwarded to ``make_inputs``.
        prompts: sequence of user prompt strings.
        max_new_tokens: tokens requested per prompt in step 4.
        baseline_gens: optional sequence indexed like ``prompts``. When given,
            each new generation is scored against the entry at the same index
            with ``token_match``.

    Returns:
        A dict of per-prompt means. ``decode_tokens_per_s`` divides the number
        of new tokens by the whole step-4 wall time, so it includes that call's
        prefill. ``token_match_vs_baseline`` is 1.0 when no baseline was given.
        ``"gens"`` holds the flattened new-token tensors, moved to CPU.
        ``statistics.mean`` raises if ``prompts`` is empty.
    """
    timings = {"tok": [], "prefill": [], "first": [], "gen": []}
    throughput = []
    agreement = []
    generations = []
    clock = time.perf_counter

    with torch.no_grad():
        for idx, prompt in enumerate(prompts):
            start = clock()
            inputs = make_inputs(tokenizer, prompt)
            timings["tok"].append(clock() - start)

            # Forward pass over the prompt only; no tokens are generated.
            start = clock()
            _ = model(**inputs)
            timings["prefill"].append(clock() - start)

            # Time to first token: prompt processing plus a single greedy step.
            start = clock()
            _ = model.generate(**inputs, max_new_tokens=1, min_new_tokens=1, do_sample=False)
            timings["first"].append(clock() - start)

            # Full greedy generation. min_new_tokens pins the requested length.
            start = clock()
            output = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                min_new_tokens=max_new_tokens,
                do_sample=False,
            )
            elapsed = clock() - start
            timings["gen"].append(elapsed)

            prompt_len = inputs["input_ids"].shape[1]
            new_tokens = output[:, prompt_len:].reshape(-1).cpu()
            generations.append(new_tokens)
            # The 1e-9 floor guards against division by zero on a degenerate timer.
            throughput.append(new_tokens.numel() / max(elapsed, 1e-9))

            if baseline_gens is not None:
                agreement.append(token_match(new_tokens, baseline_gens[idx]))

    summary = {
        "tokenize_s": mean(timings["tok"]),
        "prefill_forward_s": mean(timings["prefill"]),
        "first_token_latency_s": mean(timings["first"]),
        "generate_s": mean(timings["gen"]),
        "decode_tokens_per_s": mean(throughput),
    }
    summary["token_match_vs_baseline"] = mean(agreement) if agreement else 1.0
    summary["gens"] = generations
    return summary
def load_baseline(model_id: str):
    """Load the unquantized FP32 reference model on CPU and time the load.

    The timed region covers ``from_pretrained`` (which may include fetching
    or reading the checkpoint) and the switch to eval mode.

    Returns:
        ``(model, seconds)``.
    """
    started = time.perf_counter()
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float32,
        device_map=None,
        low_cpu_mem_usage=True,
    )
    model = model.eval()
    elapsed = time.perf_counter() - started
    return model, elapsed
def load_rotor(pkg_path: str):
    """Build an FP32 model from a RotorQuant package and time the whole process.

    Steps:
      1. Read the package.
      2. Instantiate the HF model named by ``package["model_id"]``.
      3. Overwrite the model's weights with the package's dequantized
         FP32 state dict.

    Returns:
        ``(model, seconds)``, with the model in eval mode.

    Raises:
        RuntimeError: if the dequantized state dict has missing or unexpected
            keys relative to the instantiated model.
    """
    started = time.perf_counter()
    package = load_quantized_package(pkg_path)
    base_id = package["model_id"]
    model = AutoModelForCausalLM.from_pretrained(
        base_id,
        torch_dtype=torch.float32,
        device_map=None,
        low_cpu_mem_usage=True,
    ).eval()
    state = dequantize_to_state_dict(package, dtype=torch.float32, device="cpu")
    # strict=False so both key lists can be reported together in one error.
    missing, unexpected = model.load_state_dict(state, strict=False)
    if missing or unexpected:
        raise RuntimeError(f"State mismatch: missing={missing}, unexpected={unexpected}")
    elapsed = time.perf_counter() - started
    return model, elapsed
def load_dynamic_int8(path: str):
    """Load a whole pickled model object from ``path`` onto CPU and time it.

    The file is presumably a dynamically INT8-quantized module saved with
    ``torch.save(model)``; confirm against whatever produced the artifact.

    Returns:
        ``(model, seconds)``, with the model switched to eval mode.
    """
    # NOTE(security): weights_only=False unpickles arbitrary objects, which can
    # execute code. Only point this at artifacts you produced or trust.
    started = time.perf_counter()
    loaded = torch.load(path, map_location="cpu", weights_only=False)
    model = loaded.eval()
    return model, time.perf_counter() - started
def scenario_result(name, load_s, metrics, rss_before, rss_after_load, rss_after_bench, baseline=None):
    """Assemble one JSON-serializable result row for a benchmark scenario.

    Copies the scalar timing and accuracy fields from ``metrics``. The
    ``"gens"`` tensors are deliberately left out. The three RSS readings are
    added alongside. When ``baseline`` (another row from this function) is
    given, ``delta_vs_baseline`` holds ``this - baseline`` for a fixed subset
    of fields.
    """
    copied_fields = (
        "tokenize_s",
        "prefill_forward_s",
        "first_token_latency_s",
        "generate_s",
        "decode_tokens_per_s",
        "token_match_vs_baseline",
    )
    row = {"scenario": name, "load_s": load_s}
    row.update({field: metrics[field] for field in copied_fields})
    row["rss_before_load_gb"] = rss_before
    row["rss_after_load_gb"] = rss_after_load
    row["rss_after_bench_gb"] = rss_after_bench

    if baseline is None:
        return row

    compared_fields = (
        "load_s",
        "prefill_forward_s",
        "first_token_latency_s",
        "generate_s",
        "decode_tokens_per_s",
        "rss_after_load_gb",
    )
    row["delta_vs_baseline"] = {field: row[field] - baseline[field] for field in compared_fields}
    return row
def parse_args():
    """Parse CLI options: model id, the three artifact paths, generation length, output path."""
    parser = argparse.ArgumentParser(description="Benchmark baseline vs RotorQuant vs runtime INT8")
    option_table = (
        ("--model-id", {"default": "Qwen/Qwen2.5-0.5B-Instruct"}),
        ("--rotor-pkg", {"default": "artifacts/qwen2.5-0.5b-rotorq3-mlp-only.pt"}),
        ("--fused-pkg", {"default": "artifacts/qwen2.5-0.5b-rotorq3-rowwise-skipemb.pt"}),
        ("--int8-model", {"default": "artifacts/qwen2.5-0.5b-dynamic-int8.pt"}),
        ("--max-new-tokens", {"type": int, "default": 64}),
        ("--out", {"default": "artifacts/runtime_benchmark.json"}),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()
def _bench_scenario(name, loader, tokenizer, prompts, max_new_tokens, *, baseline_gens=None, baseline=None):
    """Load one model, benchmark it, then release it before the next scenario.

    Args:
        name: scenario label stored in the result row.
        loader: zero-argument callable returning ``(model, load_seconds)``.
        tokenizer, prompts, max_new_tokens: forwarded to ``run_metrics``.
        baseline_gens: reference generations for token agreement, or None.
        baseline: reference result row for ``delta_vs_baseline``, or None.

    Returns:
        ``(result_row, metrics)``. ``metrics["gens"]`` holds this model's
        generations so the caller can reuse them as a baseline.
    """
    rss_before = rss_gb()
    model, load_s = loader()
    rss_after_load = rss_gb()
    metrics = run_metrics(model, tokenizer, prompts, max_new_tokens, baseline_gens=baseline_gens)
    rss_after_bench = rss_gb()
    result = scenario_result(
        name, load_s, metrics, rss_before, rss_after_load, rss_after_bench, baseline=baseline
    )
    # Drop this function's reference to the model and collect reference
    # cycles, so memory still held by this model is less likely to show up in
    # the next scenario's rss_before reading.
    del model
    gc.collect()
    return result, metrics


def main():
    """Run every scenario in sequence, save the results as JSON, and print a summary.

    Scenarios run in this order: FP32 baseline, RotorQuant package,
    RotorQuant fused runtime, and dynamic INT8. Each non-baseline scenario
    is scored against the baseline's generations and timings.
    """
    args = parse_args()
    prompts = [
        "Explain quantization in one paragraph.",
        "Write a Python function for binary search.",
        "Summarize why weight quantization helps deployment.",
        "Give 3 practical tips for reducing LLM latency.",
    ]
    tokenizer = AutoTokenizer.from_pretrained(args.model_id)
    # Exercise the chat template once up front, presumably so one-time setup
    # cost is not counted in the first scenario's tokenize_s.
    _ = make_inputs(tokenizer, "warmup")
    common = (tokenizer, prompts, args.max_new_tokens)

    results = {}
    base_ref, met_b = _bench_scenario(
        "baseline_fp32", lambda: load_baseline(args.model_id), *common
    )
    results["baseline_fp32"] = base_ref
    base_gens = met_b["gens"]

    def load_fused():
        # load_fused_model returns a 3-tuple. The middle element is dropped
        # here instead of staying bound at main() scope for the remaining
        # scenarios, as it previously did.
        model, _unused, load_s = load_fused_model(args.fused_pkg, out_chunk_size=64)
        return model, load_s

    candidates = (
        ("rotorquant_pkg", lambda: load_rotor(args.rotor_pkg)),
        ("rotorquant_fused_runtime", load_fused),
        ("runtime_dynamic_int8", lambda: load_dynamic_int8(args.int8_model)),
    )
    for name, loader in candidates:
        row, _metrics = _bench_scenario(
            name, loader, *common, baseline_gens=base_gens, baseline=base_ref
        )
        results[name] = row

    out = Path(args.out)
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(json.dumps(results, indent=2), encoding="utf-8")
    print(f"Saved: {out}")
    for k, v in results.items():
        print(
            f"- {k}: load={v['load_s']:.3f}s, first={v['first_token_latency_s']:.3f}s, "
            f"gen={v['generate_s']:.3f}s, tok/s={v['decode_tokens_per_s']:.2f}, "
            f"rss_load={v['rss_after_load_gb']:.2f}GB, match={v['token_match_vs_baseline']:.4f}"
        )


if __name__ == "__main__":
    main()