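"""Dynamic INT8 quantization (weights-only, CPU) for a causal LM.

`build` loads the FP32 model, swaps nn.Linear layers for dynamically
quantized INT8 equivalents via torch.ao.quantization.quantize_dynamic,
and pickles the whole module to disk alongside a JSON metadata file.
`run` loads the pickled module and generates with the matching tokenizer.

Example invocations (script filename assumed):
    python quantize_dynamic_int8.py build --model-id Qwen/Qwen2.5-0.5B-Instruct
    python quantize_dynamic_int8.py run --prompt "Explain quantization in one paragraph."
"""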
from __future__ import annotations
import argparse
import json
import time
from pathlib import Path
import psutil
import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic
from torch.ao.quantization.qconfig import default_dynamic_qconfig
from transformers import AutoModelForCausalLM, AutoTokenizer
def rss_gb() -> float:
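    """Return this process's resident set size (RSS) in GiB."""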
return psutil.Process().memory_info().rss / (1024 ** 3)
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description="Build and run a runtime INT8-quantized model (PyTorch dynamic quantization)")
sub = p.add_subparsers(dest="cmd", required=True)
b = sub.add_parser("build", help="Build and save a runtime INT8 model")
b.add_argument("--model-id", default="Qwen/Qwen2.5-0.5B-Instruct")
b.add_argument("--out", default="artifacts/qwen2.5-0.5b-dynamic-int8.pt")
b.add_argument("--meta", default="artifacts/qwen2.5-0.5b-dynamic-int8-meta.json")
b.add_argument(
"--include-name-contains",
action="append",
default=[],
help="Only quantize Linear modules whose full module name contains one of these fragments.",
)
r = sub.add_parser("run", help="Load a runtime INT8 model and generate")
r.add_argument("--model", default="artifacts/qwen2.5-0.5b-dynamic-int8.pt")
r.add_argument("--prompt", default="Explain quantization in one paragraph.")
r.add_argument("--max-new-tokens", type=int, default=64)
return p.parse_args()
def build_runtime_int8(model_id: str, out_path: str, meta_path: str, include_name_contains: list[str]) -> None:
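    """Quantize the nn.Linear layers of `model_id` to dynamic INT8 and save the result.

    Records wall-clock and RSS checkpoints so build cost can be compared
    against the serialized artifact size in the emitted metadata JSON.
    """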
t0 = time.perf_counter()
rss0 = rss_gb()
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float32,
device_map=None,
low_cpu_mem_usage=True,
).eval()
t_load = time.perf_counter() - t0
rss_load = rss_gb()
t1 = time.perf_counter()
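    # Dynamic quantization: weights are stored as INT8, activations are
    # quantized on the fly per batch at inference time (CPU kernels).
    # If name fragments were given, build an explicit name -> qconfig map
    # (an empty match set quantizes nothing); otherwise quantize every
    # nn.Linear in the model.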
if include_name_contains:
qspec = {}
for name, mod in model.named_modules():
if isinstance(mod, nn.Linear) and any(f in name for f in include_name_contains):
qspec[name] = default_dynamic_qconfig
qmodel = quantize_dynamic(model, qspec, dtype=torch.qint8)
else:
qmodel = quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
t_quant = time.perf_counter() - t1
rss_quant = rss_gb()
out = Path(out_path)
out.parent.mkdir(parents=True, exist_ok=True)
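    # torch.save pickles the entire module (code + quantized weights), so
    # loading later requires weights_only=False and a compatible environment.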
torch.save(qmodel, out)
file_size_mb = out.stat().st_size / (1024 ** 2)
meta = {
"model_id": model_id,
"path": str(out),
"file_size_mb": file_size_mb,
"build_load_s": t_load,
"build_quantize_s": t_quant,
"rss_start_gb": rss0,
"rss_after_load_gb": rss_load,
"rss_after_quantize_gb": rss_quant,
"scheme": "torch_dynamic_int8_linear",
"include_name_contains": include_name_contains,
}
Path(meta_path).write_text(json.dumps(meta, indent=2), encoding="utf-8")
print(f"Saved: {out}")
print(f"Meta: {meta_path}")
print(json.dumps(meta, indent=2))
def run_runtime_int8(model_path: str, prompt: str, max_new_tokens: int) -> None:
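    """Load a pickled dynamic-INT8 model and greedily generate a reply."""
    # weights_only=False is required to unpickle a full nn.Module; only load
    # artifacts you built yourself, since unpickling can execute arbitrary code.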
model = torch.load(model_path, map_location="cpu", weights_only=False).eval()
if hasattr(model, "config") and getattr(model.config, "_name_or_path", None):
model_id = model.config._name_or_path
else:
model_id = "Qwen/Qwen2.5-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
messages = [
{"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
{"role": "user", "content": prompt},
]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer([text], return_tensors="pt")
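    # Greedy decoding (do_sample=False); the dynamically quantized Linear
    # kernels run on CPU.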
with torch.no_grad():
out = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
new_tokens = out[:, inputs["input_ids"].shape[1]:]
ans = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0]
print("\n=== Prompt ===")
print(prompt)
print("\n=== Response ===")
print(ans)
def main() -> None:
args = parse_args()
if args.cmd == "build":
build_runtime_int8(args.model_id, args.out, args.meta, args.include_name_contains)
elif args.cmd == "run":
run_runtime_int8(args.model, args.prompt, args.max_new_tokens)
if __name__ == "__main__":
main()