"""Build and run a runtime INT8 model via PyTorch dynamic quantization (CPU-only)."""
from __future__ import annotations
|
|
| import argparse |
| import json |
| import time |
| from pathlib import Path |
|
|
| import psutil |
| import torch |
| import torch.nn as nn |
| from torch.ao.quantization import quantize_dynamic |
| from torch.ao.quantization.qconfig import default_dynamic_qconfig |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
|
def rss_gb() -> float:
    """Return this process's resident set size (RSS) in GiB."""
    return psutil.Process().memory_info().rss / (1024 ** 3)
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse CLI arguments for the `build` and `run` subcommands."""
    p = argparse.ArgumentParser(
        description="Build and run a runtime INT8 model (PyTorch dynamic quantization of nn.Linear layers)"
    )
| sub = p.add_subparsers(dest="cmd", required=True) |
|
|
| b = sub.add_parser("build", help="Build and save runtime INT8 model") |
| b.add_argument("--model-id", default="Qwen/Qwen2.5-0.5B-Instruct") |
| b.add_argument("--out", default="artifacts/qwen2.5-0.5b-dynamic-int8.pt") |
| b.add_argument("--meta", default="artifacts/qwen2.5-0.5b-dynamic-int8-meta.json") |
    b.add_argument(
        "--include-name-contains",
        action="append",
        default=[],
        help="Only quantize nn.Linear modules whose fully qualified name contains one of these fragments (repeatable).",
    )
|
|
| r = sub.add_parser("run", help="Load runtime INT8 model and generate") |
| r.add_argument("--model", default="artifacts/qwen2.5-0.5b-dynamic-int8.pt") |
| r.add_argument("--prompt", default="Explain quantization in one paragraph.") |
| r.add_argument("--max-new-tokens", type=int, default=64) |
|
|
| return p.parse_args() |
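
# Example invocations (the script filename here is illustrative; adjust to yours):
#   python runtime_int8.py build --include-name-contains mlp --include-name-contains self_attn
#   python runtime_int8.py run --prompt "What is dynamic quantization?"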
|
|
|
|
def build_runtime_int8(model_id: str, out_path: str, meta_path: str, include_name_contains: list[str]) -> None:
    """Load `model_id` in FP32, apply dynamic INT8 quantization, and save model + metadata."""
| t0 = time.perf_counter() |
| rss0 = rss_gb() |
|
|
    # Dynamic quantization operates on FP32 weights, and the quantized
    # Linear kernels run on CPU only, so load in float32 with no device map.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float32,
        device_map=None,
        low_cpu_mem_usage=True,
    ).eval()
|
|
| t_load = time.perf_counter() - t0 |
| rss_load = rss_gb() |
|
|
| t1 = time.perf_counter() |
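    # quantize_dynamic replaces nn.Linear modules with dynamically quantized
    # versions: weights are stored as INT8 and activations are quantized on
    # the fly at inference time. When qconfig_spec is a dict, each entry's
    # qconfig determines the quantized dtype; the dtype= argument is only
    # consulted for the set-based form.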
    if include_name_contains:
        # Map each matching fully qualified module name to the default
        # dynamic qconfig; unmatched Linear layers stay in FP32.
        qspec = {}
        for name, mod in model.named_modules():
            if isinstance(mod, nn.Linear) and any(f in name for f in include_name_contains):
                qspec[name] = default_dynamic_qconfig
        if not qspec:
            print("Warning: --include-name-contains matched no nn.Linear modules; nothing will be quantized.")
        qmodel = quantize_dynamic(model, qspec, dtype=torch.qint8)
    else:
        # No filter given: quantize every nn.Linear in the model.
        qmodel = quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
| t_quant = time.perf_counter() - t1 |
| rss_quant = rss_gb() |
|
|
| out = Path(out_path) |
| out.parent.mkdir(parents=True, exist_ok=True) |
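    # NOTE: torch.save pickles the entire module object (classes + weights),
    # so loading it back requires weights_only=False and compatible
    # torch/transformers versions. A state_dict would be more portable but
    # would need rebuild logic at load time.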
| torch.save(qmodel, out) |
| file_size_mb = out.stat().st_size / (1024 ** 2) |
|
|
| meta = { |
| "model_id": model_id, |
| "path": str(out), |
| "file_size_mb": file_size_mb, |
| "build_load_s": t_load, |
| "build_quantize_s": t_quant, |
| "rss_start_gb": rss0, |
| "rss_after_load_gb": rss_load, |
| "rss_after_quantize_gb": rss_quant, |
| "scheme": "torch_dynamic_int8_linear", |
| "include_name_contains": include_name_contains, |
| } |
| Path(meta_path).write_text(json.dumps(meta, indent=2), encoding="utf-8") |
|
|
| print(f"Saved: {out}") |
| print(f"Meta: {meta_path}") |
| print(json.dumps(meta, indent=2)) |
|
|
|
|
def run_runtime_int8(model_path: str, prompt: str, max_new_tokens: int) -> None:
    """Load a pickled dynamic-INT8 model and generate a reply on CPU."""
    # weights_only=False deserializes arbitrary pickled objects; only load
    # artifacts you built yourself with the `build` subcommand.
    model = torch.load(model_path, map_location="cpu", weights_only=False).eval()
|
|
    # Recover the tokenizer id from the saved config, falling back to the
    # build default if the attribute is absent.
    if hasattr(model, "config") and getattr(model.config, "_name_or_path", None):
        model_id = model.config._name_or_path
    else:
        model_id = "Qwen/Qwen2.5-0.5B-Instruct"
|
|
| tokenizer = AutoTokenizer.from_pretrained(model_id) |
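    # Format the prompt with the model's chat template so generation matches
    # what the instruct model was trained to expect.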
| messages = [ |
| {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."}, |
| {"role": "user", "content": prompt}, |
| ] |
| text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) |
| inputs = tokenizer([text], return_tensors="pt") |
|
|
| with torch.no_grad(): |
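        # Greedy decoding (do_sample=False) makes the output deterministic,
        # which is convenient for comparing FP32 vs. INT8 quality.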
| out = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False) |
|
|
    # Keep only the newly generated tokens (drop the echoed prompt).
    new_tokens = out[:, inputs["input_ids"].shape[1]:]
| ans = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0] |
|
|
| print("\n=== Prompt ===") |
| print(prompt) |
| print("\n=== Response ===") |
| print(ans) |
|
|
|
|
| def main() -> None: |
| args = parse_args() |
| if args.cmd == "build": |
| build_runtime_int8(args.model_id, args.out, args.meta, args.include_name_contains) |
| elif args.cmd == "run": |
| run_runtime_int8(args.model, args.prompt, args.max_new_tokens) |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|