# Snapshot 18f4d80 (file size: 4,462 bytes).
from __future__ import annotations
import argparse
import json
import time
from pathlib import Path
import psutil
import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic
from torch.ao.quantization.qconfig import default_dynamic_qconfig
from transformers import AutoModelForCausalLM, AutoTokenizer
def rss_gb() -> float:
    """Return the current process's resident set size, expressed in GiB."""
    bytes_per_gib = 1024 ** 3
    resident_bytes = psutil.Process().memory_info().rss
    return resident_bytes / bytes_per_gib
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description="Build and run runtime INT8 quantized model (dynamic quantization)")
sub = p.add_subparsers(dest="cmd", required=True)
b = sub.add_parser("build", help="Build and save runtime INT8 model")
b.add_argument("--model-id", default="Qwen/Qwen2.5-0.5B-Instruct")
b.add_argument("--out", default="artifacts/qwen2.5-0.5b-dynamic-int8.pt")
b.add_argument("--meta", default="artifacts/qwen2.5-0.5b-dynamic-int8-meta.json")
b.add_argument(
"--include-name-contains",
action="append",
default=[],
help="Only quantize Linear modules whose full module name contains one of these fragments.",
)
r = sub.add_parser("run", help="Load runtime INT8 model and generate")
r.add_argument("--model", default="artifacts/qwen2.5-0.5b-dynamic-int8.pt")
r.add_argument("--prompt", default="Explain quantization in one paragraph.")
r.add_argument("--max-new-tokens", type=int, default=64)
return p.parse_args()
def _select_linear_qconfig(model: nn.Module, fragments: list[str]) -> dict:
    """Map each nn.Linear whose qualified name contains a fragment to the default dynamic qconfig."""
    return {
        name: default_dynamic_qconfig
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear) and any(fragment in name for fragment in fragments)
    }


def build_runtime_int8(model_id: str, out_path: str, meta_path: str, include_name_contains: list[str]) -> None:
    """Load a causal LM in float32, dynamically quantize it to INT8, and save it.

    The quantized module is written to ``out_path`` with ``torch.save`` and a
    JSON report (timings, RSS samples, file size, selection settings) is
    written to ``meta_path`` and echoed to stdout.

    Args:
        model_id: Identifier passed to ``AutoModelForCausalLM.from_pretrained``.
        out_path: Destination file for the quantized model.
        meta_path: Destination file for the JSON build report.
        include_name_contains: Name fragments restricting which ``nn.Linear``
            modules are quantized. When empty, every ``nn.Linear`` is quantized.

    Raises:
        ValueError: If fragments were given but no ``nn.Linear`` module name
            contains any of them, since nothing would be quantized.
    """
    load_started = time.perf_counter()
    rss0 = rss_gb()
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float32,
        device_map=None,
        low_cpu_mem_usage=True,
    ).eval()
    t_load = time.perf_counter() - load_started
    rss_load = rss_gb()

    quant_started = time.perf_counter()
    if include_name_contains:
        qspec = _select_linear_qconfig(model, include_name_contains)
        if not qspec:
            # An empty spec would make quantize_dynamic a silent no-op and the
            # float32 model would be saved under an INT8 label.
            raise ValueError(
                f"No nn.Linear module name contains any of {include_name_contains!r}; "
                "nothing would be quantized."
            )
        qmodel = quantize_dynamic(model, qspec, dtype=torch.qint8)
    else:
        qmodel = quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
    t_quant = time.perf_counter() - quant_started
    rss_quant = rss_gb()

    out = Path(out_path)
    out.parent.mkdir(parents=True, exist_ok=True)
    # NOTE: this pickles the whole module object, not just a state_dict, so
    # loading it later requires torch.load(..., weights_only=False).
    torch.save(qmodel, out)
    file_size_mb = out.stat().st_size / (1024 ** 2)

    meta = {
        "model_id": model_id,
        "path": str(out),
        "file_size_mb": file_size_mb,
        "build_load_s": t_load,
        "build_quantize_s": t_quant,
        "rss_start_gb": rss0,
        "rss_after_load_gb": rss_load,
        "rss_after_quantize_gb": rss_quant,
        "scheme": "torch_dynamic_int8_linear",
        "include_name_contains": include_name_contains,
    }
    report = json.dumps(meta, indent=2)

    meta_file = Path(meta_path)
    # The report may live in a different directory than the model file.
    meta_file.parent.mkdir(parents=True, exist_ok=True)
    meta_file.write_text(report, encoding="utf-8")

    print(f"Saved: {out}")
    print(f"Meta: {meta_path}")
    print(report)
def run_runtime_int8(model_path: str, prompt: str, max_new_tokens: int) -> None:
    """Load a saved quantized model, generate a reply to ``prompt``, and print it.

    Args:
        model_path: File previously written with ``torch.save``.
        prompt: User message placed after the fixed system message.
        max_new_tokens: Forwarded to ``generate`` as the new-token budget.
    """
    # NOTE(review): weights_only=False unpickles arbitrary objects; only load
    # model files from a trusted source.
    quantized = torch.load(model_path, map_location="cpu", weights_only=False).eval()

    # Use the model id recorded on the config when present, else the default id.
    recorded_id = getattr(getattr(quantized, "config", None), "_name_or_path", None)
    tokenizer = AutoTokenizer.from_pretrained(recorded_id or "Qwen/Qwen2.5-0.5B-Instruct")

    system_turn = {
        "role": "system",
        "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
    }
    user_turn = {"role": "user", "content": prompt}
    chat_text = tokenizer.apply_chat_template(
        [system_turn, user_turn],
        tokenize=False,
        add_generation_prompt=True,
    )

    encoded = tokenizer([chat_text], return_tensors="pt")
    prompt_len = encoded["input_ids"].shape[1]
    with torch.no_grad():
        generated = quantized.generate(**encoded, max_new_tokens=max_new_tokens, do_sample=False)

    # Slice off the prompt positions so only the continuation is decoded.
    completion_ids = generated[:, prompt_len:]
    reply = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)[0]

    print("\n=== Prompt ===")
    print(prompt)
    print("\n=== Response ===")
    print(reply)
def main() -> None:
    """Parse the command line and dispatch to the selected sub-command."""
    ns = parse_args()
    command = ns.cmd

    if command == "run":
        run_runtime_int8(ns.model, ns.prompt, ns.max_new_tokens)
        return

    if command == "build":
        build_runtime_int8(ns.model_id, ns.out, ns.meta, ns.include_name_contains)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|