from __future__ import annotations

import argparse
import json
import time
from pathlib import Path

import psutil
import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic
from torch.ao.quantization.qconfig import default_dynamic_qconfig
from transformers import AutoModelForCausalLM, AutoTokenizer


def rss_gb() -> float:
    """Resident set size of this process in GiB."""
    return psutil.Process().memory_info().rss / (1024 ** 3)


def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(description="Build and run a runtime INT8 quantized model (dynamic quantization)")
    sub = p.add_subparsers(dest="cmd", required=True)

    b = sub.add_parser("build", help="Build and save a runtime INT8 model")
    b.add_argument("--model-id", default="Qwen/Qwen2.5-0.5B-Instruct")
    b.add_argument("--out", default="artifacts/qwen2.5-0.5b-dynamic-int8.pt")
    b.add_argument("--meta", default="artifacts/qwen2.5-0.5b-dynamic-int8-meta.json")
    b.add_argument(
        "--include-name-contains",
        action="append",
        default=[],
        help="Only quantize Linear modules whose full module name contains one of these fragments.",
    )

    r = sub.add_parser("run", help="Load the runtime INT8 model and generate")
    r.add_argument("--model", default="artifacts/qwen2.5-0.5b-dynamic-int8.pt")
    r.add_argument("--prompt", default="Explain quantization in one paragraph.")
    r.add_argument("--max-new-tokens", type=int, default=64)

    return p.parse_args()


def build_runtime_int8(model_id: str, out_path: str, meta_path: str, include_name_contains: list[str]) -> None:
    t0 = time.perf_counter()
    rss0 = rss_gb()

    # Load in float32 on CPU: dynamic quantization converts fp32 Linear weights to INT8.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float32,
        device_map=None,
        low_cpu_mem_usage=True,
    ).eval()
    t_load = time.perf_counter() - t0
    rss_load = rss_gb()

    t1 = time.perf_counter()
    if include_name_contains:
        # Name-filtered build: map each matching Linear's full module name to the
        # default dynamic qconfig; modules not listed stay in fp32.
        qspec = {}
        for name, mod in model.named_modules():
            if isinstance(mod, nn.Linear) and any(f in name for f in include_name_contains):
                qspec[name] = default_dynamic_qconfig
        qmodel = quantize_dynamic(model, qspec, dtype=torch.qint8)
    else:
        # Default build: quantize every nn.Linear in the model.
        qmodel = quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
    t_quant = time.perf_counter() - t1
    rss_quant = rss_gb()

    out = Path(out_path)
    out.parent.mkdir(parents=True, exist_ok=True)
    # Pickle the whole module so the quantized layer types are restored on load.
    torch.save(qmodel, out)
    file_size_mb = out.stat().st_size / (1024 ** 2)

    meta = {
        "model_id": model_id,
        "path": str(out),
        "file_size_mb": file_size_mb,
        "build_load_s": t_load,
        "build_quantize_s": t_quant,
        "rss_start_gb": rss0,
        "rss_after_load_gb": rss_load,
        "rss_after_quantize_gb": rss_quant,
        "scheme": "torch_dynamic_int8_linear",
        "include_name_contains": include_name_contains,
    }
    Path(meta_path).write_text(json.dumps(meta, indent=2), encoding="utf-8")

    print(f"Saved: {out}")
    print(f"Meta: {meta_path}")
    print(json.dumps(meta, indent=2))
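

# A small verification helper (a sketch; the function name is ours and it is
# not wired into the CLI). quantize_dynamic swaps each selected nn.Linear for
# torch.ao.nn.quantized.dynamic.Linear, so a nonzero count on a loaded `build`
# artifact confirms it really contains dynamic INT8 layers.
def count_dynamic_int8_linears(model: nn.Module) -> int:
    import torch.ao.nn.quantized.dynamic as nnqd

    return sum(isinstance(mod, nnqd.Linear) for mod in model.modules())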


def run_runtime_int8(model_path: str, prompt: str, max_new_tokens: int) -> None:
    # weights_only=False is required: the artifact is a pickled module, not a state dict.
    model = torch.load(model_path, map_location="cpu", weights_only=False).eval()
    # Recover the tokenizer from the config baked into the saved model, if present.
    if hasattr(model, "config") and getattr(model.config, "_name_or_path", None):
        model_id = model.config._name_or_path
    else:
        model_id = "Qwen/Qwen2.5-0.5B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    messages = [
        {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer([text], return_tensors="pt")

    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    # Keep only the newly generated tokens, not the echoed prompt.
    new_tokens = out[:, inputs["input_ids"].shape[1]:]
    ans = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0]

    print("\n=== Prompt ===")
    print(prompt)
    print("\n=== Response ===")
    print(ans)


def main() -> None:
    args = parse_args()
    if args.cmd == "build":
        build_runtime_int8(args.model_id, args.out, args.meta, args.include_name_contains)
    elif args.cmd == "run":
        run_runtime_int8(args.model, args.prompt, args.max_new_tokens)


if __name__ == "__main__":
    main()
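
# Example invocations (a sketch; the script filename and the Qwen2-style module
# name fragments below are assumptions, not defined by the CLI itself):
#   python dynamic_int8.py build
#   python dynamic_int8.py build --include-name-contains mlp --include-name-contains o_proj
#   python dynamic_int8.py run --prompt "What is dynamic quantization?" --max-new-tokens 128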