File size: 4,462 Bytes
18f4d80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
from __future__ import annotations

import argparse
import json
import time
from pathlib import Path

import psutil
import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic
from torch.ao.quantization.qconfig import default_dynamic_qconfig
from transformers import AutoModelForCausalLM, AutoTokenizer


def rss_gb() -> float:
    """Return the resident set size of the current process, in GiB (1024**3 bytes)."""
    bytes_per_gib = 1024 ** 3
    mem_info = psutil.Process().memory_info()
    return mem_info.rss / bytes_per_gib


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line arguments for the ``build`` and ``run`` subcommands.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, as before; passing
            an explicit list allows programmatic use and testing.

    Returns:
        The parsed namespace. ``cmd`` is ``"build"`` or ``"run"``; the other
        attributes are those defined by the selected subcommand.

    Raises:
        SystemExit: On invalid arguments (argparse's standard behavior),
            including a non-positive ``--max-new-tokens`` for ``run``.
    """
    p = argparse.ArgumentParser(description="Build and run runtime INT8 quantized model (dynamic quantization)")
    sub = p.add_subparsers(dest="cmd", required=True)

    b = sub.add_parser("build", help="Build and save runtime INT8 model")
    b.add_argument("--model-id", default="Qwen/Qwen2.5-0.5B-Instruct")
    b.add_argument("--out", default="artifacts/qwen2.5-0.5b-dynamic-int8.pt")
    b.add_argument("--meta", default="artifacts/qwen2.5-0.5b-dynamic-int8-meta.json")
    b.add_argument(
        "--include-name-contains",
        action="append",
        default=[],
        help="Only quantize Linear modules whose full module name contains one of these fragments.",
    )

    r = sub.add_parser("run", help="Load runtime INT8 model and generate")
    r.add_argument("--model", default="artifacts/qwen2.5-0.5b-dynamic-int8.pt")
    r.add_argument("--prompt", default="Explain quantization in one paragraph.")
    r.add_argument("--max-new-tokens", type=int, default=64)

    args = p.parse_args(argv)
    # Reject a non-positive token budget here with a usage error rather than
    # passing it through to model.generate() after the model has been loaded.
    if args.cmd == "run" and args.max_new_tokens < 1:
        r.error(f"--max-new-tokens must be >= 1, got {args.max_new_tokens}")
    return args


def build_runtime_int8(model_id: str, out_path: str, meta_path: str, include_name_contains: list[str]) -> None:
    """Load a causal LM, dynamically quantize its Linear layers to INT8, and save it.

    The whole quantized module object is pickled with ``torch.save`` (not a
    state dict). A JSON metadata file with load/quantize timings, process RSS
    snapshots, and the saved file size is written to ``meta_path`` and also
    printed to stdout.

    Args:
        model_id: Model id or local path passed to
            ``AutoModelForCausalLM.from_pretrained`` (loaded as float32).
        out_path: Destination for the pickled quantized model. Missing parent
            directories are created.
        meta_path: Destination for the JSON metadata. Missing parent
            directories are created.
        include_name_contains: If non-empty, only ``nn.Linear`` modules whose
            qualified name (as yielded by ``named_modules()``) contains at
            least one of these substrings are quantized. If empty, every
            ``nn.Linear`` module is quantized.

    Raises:
        ValueError: If ``include_name_contains`` is non-empty but matches no
            ``nn.Linear`` module. Without this check, an unquantized float32
            model would be saved and labeled as dynamic INT8.
    """
    t0 = time.perf_counter()
    rss0 = rss_gb()

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float32,
        device_map=None,
        low_cpu_mem_usage=True,
    ).eval()

    t_load = time.perf_counter() - t0
    rss_load = rss_gb()

    t1 = time.perf_counter()
    if include_name_contains:
        # Per-module-name qconfig: quantize only the selected Linear layers.
        qconfig_spec = {
            name: default_dynamic_qconfig
            for name, mod in model.named_modules()
            if isinstance(mod, nn.Linear) and any(f in name for f in include_name_contains)
        }
        if not qconfig_spec:
            raise ValueError(
                f"--include-name-contains {include_name_contains!r} matched no nn.Linear "
                f"modules in {model_id!r}; nothing would be quantized"
            )
    else:
        qconfig_spec = {nn.Linear}
    # inplace=True: the float model is not used after this point. By default
    # quantize_dynamic deep-copies the model first, which keeps a full float32
    # copy alive alongside the quantized one and inflates peak memory.
    qmodel = quantize_dynamic(model, qconfig_spec, dtype=torch.qint8, inplace=True)
    t_quant = time.perf_counter() - t1
    rss_quant = rss_gb()

    out = Path(out_path)
    out.parent.mkdir(parents=True, exist_ok=True)
    torch.save(qmodel, out)
    file_size_mb = out.stat().st_size / (1024 ** 2)

    meta = {
        "model_id": model_id,
        "path": str(out),
        "file_size_mb": file_size_mb,
        "build_load_s": t_load,
        "build_quantize_s": t_quant,
        "rss_start_gb": rss0,
        "rss_after_load_gb": rss_load,
        "rss_after_quantize_gb": rss_quant,
        "scheme": "torch_dynamic_int8_linear",
        "include_name_contains": include_name_contains,
    }
    meta_json = json.dumps(meta, indent=2)
    meta_file = Path(meta_path)
    # --meta may point somewhere other than --out's directory, so create its
    # parent as well.
    meta_file.parent.mkdir(parents=True, exist_ok=True)
    meta_file.write_text(meta_json, encoding="utf-8")

    print(f"Saved: {out}")
    print(f"Meta: {meta_path}")
    print(meta_json)


def run_runtime_int8(model_path: str, prompt: str, max_new_tokens: int) -> None:
    """Load a pickled quantized model, greedily generate a chat reply, and print it.

    The tokenizer is loaded from the model's ``config._name_or_path`` when that
    attribute exists and is truthy; otherwise it falls back to
    ``"Qwen/Qwen2.5-0.5B-Instruct"``. The prompt is wrapped in a chat template
    with a fixed Qwen system message, and only the newly generated tokens are
    decoded.

    Args:
        model_path: Path to a model object previously written with ``torch.save``.
        prompt: User message to send to the model.
        max_new_tokens: Maximum number of tokens to generate (``do_sample=False``).
    """
    # SECURITY NOTE: weights_only=False fully unpickles the file, which can run
    # arbitrary code embedded in it. Only load artifacts you built or trust.
    model = torch.load(model_path, map_location="cpu", weights_only=False)
    model = model.eval()

    name_or_path = getattr(getattr(model, "config", None), "_name_or_path", None)
    model_id = name_or_path if name_or_path else "Qwen/Qwen2.5-0.5B-Instruct"

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    chat = [
        {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]
    rendered = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    encoded = tokenizer([rendered], return_tensors="pt")
    # Number of prompt tokens; everything after this index was generated.
    prompt_len = encoded["input_ids"].shape[1]

    with torch.no_grad():
        generated = model.generate(**encoded, max_new_tokens=max_new_tokens, do_sample=False)

    completion = tokenizer.batch_decode(generated[:, prompt_len:], skip_special_tokens=True)[0]

    for header, body in (("\n=== Prompt ===", prompt), ("\n=== Response ===", completion)):
        print(header)
        print(body)


def main() -> None:
    """Parse CLI arguments and dispatch to the selected subcommand."""
    ns = parse_args()
    command = ns.cmd
    if command == "run":
        run_runtime_int8(ns.model, ns.prompt, ns.max_new_tokens)
        return
    if command == "build":
        build_runtime_int8(ns.model_id, ns.out, ns.meta, ns.include_name_contains)


if __name__ == "__main__":
    main()