File size: 3,438 Bytes
18f4d80 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 | from __future__ import annotations
import argparse
from pathlib import Path
import torch
from transformers import AutoModelForCausalLM
from rotorquant_weights import (
quantize_state_dict,
save_quantized_package,
save_report,
estimate_bits_per_weight,
)
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options for the quantization script.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, exactly as before; pass an explicit list to call
            this from tests or from other scripts.

    Returns:
        The parsed ``argparse.Namespace``. ``skip_name`` and
        ``include_name_contains`` are always lists (empty when the flag is not
        given) because they use ``action="append"`` with ``default=[]``.

    Raises:
        SystemExit: via ``ArgumentParser.error`` (exit status 2) when an option
            parses but is outside its meaningful range: ``--bits`` or
            ``--block-size`` below 1, negative ``--lowrank-rank``, or
            ``--outlier-frac`` outside [0, 1].
    """
    p = argparse.ArgumentParser(description="Quantize a HF model with RotorQuant-weight codec")
    p.add_argument("--model-id", default="Qwen/Qwen2.5-0.5B-Instruct")
    p.add_argument("--output", default="artifacts/qwen2.5-0.5b-rotorq3.pt")
    p.add_argument("--report", default="artifacts/qwen2.5-0.5b-rotorq3-report.json")
    p.add_argument("--bits", type=int, default=3)
    p.add_argument("--block-size", type=int, default=128)
    p.add_argument("--seed", type=int, default=1337)
    p.add_argument("--dtype", choices=["float32", "float16", "bfloat16"], default="float32")
    p.add_argument("--min-ndim", type=int, default=2)
    p.add_argument(
        "--skip-name",
        action="append",
        default=[],
        help="Exact tensor names to keep unquantized (repeatable).",
    )
    p.add_argument("--lowrank-rank", type=int, default=0, help="Optional residual low-rank correction rank.")
    p.add_argument("--rotor-angle-scale", type=float, default=1.0, help="Scale for rotor angle; 0.0 disables rotation.")
    p.add_argument("--rowwise", action="store_true", help="Quantize per-row (higher overhead, sometimes higher fidelity).")
    p.add_argument("--outlier-frac", type=float, default=0.0, help="Store top-k residual outliers per row in fp16.")
    p.add_argument(
        "--include-name-contains",
        action="append",
        default=[],
        help="Only quantize tensors whose name contains at least one provided fragment (repeatable).",
    )
    args = p.parse_args(argv)

    # Reject nonsensical values right away with a normal CLI error, so the
    # run fails before main() spends time loading the model.
    if args.bits < 1:
        p.error(f"--bits must be >= 1 (got {args.bits})")
    if args.block_size < 1:
        p.error(f"--block-size must be >= 1 (got {args.block_size})")
    if args.lowrank_rank < 0:
        p.error(f"--lowrank-rank must be >= 0 (got {args.lowrank_rank})")
    # Negated range check so that NaN is rejected as well.
    if not 0.0 <= args.outlier_frac <= 1.0:
        p.error(f"--outlier-frac must be in [0, 1] (got {args.outlier_frac})")
    return args
def str_to_dtype(s: str) -> torch.dtype:
    """Map a ``--dtype`` CLI choice to the matching ``torch.dtype``.

    Only ``"float32"``, ``"float16"`` and ``"bfloat16"`` are recognised; any
    other name raises ``KeyError``.
    """
    known = dict(
        float32=torch.float32,
        float16=torch.float16,
        bfloat16=torch.bfloat16,
    )
    resolved = known[s]
    return resolved
def main() -> None:
    """Load a Hugging Face causal LM, quantize its state dict, and save the results.

    Steps:
      1. Parse CLI options and resolve the requested load dtype.
      2. Make sure the output/report directories exist.
      3. Load the model with ``device_map=None`` and put it in eval mode.
      4. Run ``quantize_state_dict`` over the detached state-dict tensors.
      5. Save the package and report, then print a summary, including the
         bits/weight figure from ``estimate_bits_per_weight``.
    """
    args = parse_args()
    dtype = str_to_dtype(args.dtype)

    output_path = Path(args.output)
    report_path = Path(args.report)
    # The default destinations live under artifacts/, which may not exist.
    # Create the parent directories before the (slow) load + quantize work,
    # in case the save helpers do not create them. This is idempotent.
    for path in (output_path, report_path):
        path.parent.mkdir(parents=True, exist_ok=True)

    print(f"Loading model: {args.model_id} (dtype={dtype})")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_id,
        torch_dtype=dtype,
        device_map=None,
        low_cpu_mem_usage=True,
    )
    model.eval()
    # detach() drops autograd tracking; cpu() guarantees host tensors.
    # NOTE(review): if the model ties lm_head to the input embeddings,
    # state_dict() lists the shared tensor under both names, so it may be
    # quantized (and counted) twice -- confirm how quantize_state_dict
    # handles that.
    state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
    print(f"State dict tensors: {len(state)}")

    pkg = quantize_state_dict(
        state,
        bits=args.bits,
        block_size=args.block_size,
        seed=args.seed,
        min_ndim=args.min_ndim,
        verbose=True,
        skip_names=args.skip_name,
        lowrank_rank=args.lowrank_rank,
        rotor_angle_scale=args.rotor_angle_scale,
        rowwise=args.rowwise,
        include_if_name_contains=args.include_name_contains,
        outlier_frac=args.outlier_frac,
    )
    # Record provenance alongside the quantized data.
    pkg["model_id"] = args.model_id
    pkg["source_dtype"] = args.dtype

    save_quantized_package(pkg, output_path)
    save_report(pkg, report_path)

    bpw = estimate_bits_per_weight(pkg)
    print(f"Saved quantized package: {output_path}")
    print(f"Saved report: {report_path}")
    print(f"Estimated effective bits/weight: {bpw:.4f}")
    print(f"Quantized tensors: {len(pkg['quantized'])}, passthrough: {len(pkg['passthrough'])}")
if __name__ == "__main__":
    # Run the CLI only when the file is executed as a script; importing the
    # module (e.g. to reuse parse_args/str_to_dtype) has no side effects.
    main()
|