| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import math |
| import time |
| from pathlib import Path |
| from typing import Dict, Any |
|
|
| import psutil |
| import torch |
| import torch.nn as nn |
| from accelerate import init_empty_weights |
| from accelerate.utils import set_module_tensor_to_device |
| from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer |
|
|
| from rotorquant_weights import load_quantized_package, _unpack_3bit, _deterministic_rotor_matrix |
|
|
|
|
def rss_gb() -> float:
    """Return the resident set size of the current process, in GiB."""
    bytes_per_gib = 1024 ** 3
    resident_bytes = psutil.Process().memory_info().rss
    return resident_bytes / bytes_per_gib
|
|
|
|
def _get_parent_module(root: nn.Module, module_path: str):
    """Resolve a dotted attribute path to ``(parent_object, leaf_name)``.

    Every component except the last is followed with ``getattr`` starting at
    ``root``; the last component is returned as a string and is NOT looked up.
    An empty (falsy) path yields ``(None, "")``.  A missing intermediate
    attribute propagates the ``AttributeError`` raised by ``getattr``.
    """
    if not module_path:
        return None, ""

    *ancestors, leaf = module_path.split(".")
    node = root
    for attr in ancestors:
        node = getattr(node, attr)
    return node, leaf
|
|
|
|
class FusedRotorLinear(nn.Module):
    """Linear layer whose weight is decoded from a quantized payload at run time.

    The constructor unpacks the packed indices into one ``uint8`` per value and
    stores them, together with the codebook, per-block centers/scales, the
    rotor matrix and an optional low-rank correction, as buffers.  The dense
    weight is rebuilt in row chunks of ``out_chunk_size`` output features,
    either once (and cached in float16) or on every forward pass.

    NOTE(review): the exact layout of ``qt_raw`` is defined by
    ``rotorquant_weights``; the descriptions below are limited to what this
    class reads from it.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: torch.Tensor | None,
        qt_raw: Dict[str, Any],
        seed: int,
        block_size: int,
        rotor_angle_scale: float,
        layer_name: str,
        out_chunk_size: int = 64,
        cache_weight: bool = True,
    ):
        """Unpack the quantized payload and register it as module buffers.

        Args:
            in_features: Number of input features; decoded rows are truncated
                to this width.
            out_features: Number of output features (weight rows to decode).
            bias: Optional bias tensor; stored as a frozen float32 parameter.
            qt_raw: Quantized payload. Must provide ``row_size``,
                ``row_rot_size``, ``row_padded_size``, ``n_rows``,
                ``packed_indices``, ``centers``, ``scales`` and ``codebook``;
                ``lowrank_A``/``lowrank_B`` are optional and only used when
                both are present.
            seed: Seed forwarded to ``_deterministic_rotor_matrix``.
            block_size: Number of values sharing one center/scale pair.
            rotor_angle_scale: Forwarded to ``_deterministic_rotor_matrix``.
            layer_name: Name forwarded to ``_deterministic_rotor_matrix``.
            out_chunk_size: Number of output rows decoded per chunk.
            cache_weight: If true, the full weight is decoded on the first
                forward pass and kept in float16 for later calls.

        Raises:
            ValueError: If ``out_chunk_size`` is not positive.
        """
        super().__init__()
        # A zero step makes range() raise on the first forward pass, and a
        # negative step yields no chunks at all, which would leave the
        # non-cached output buffer uninitialised.  Fail early and clearly.
        if out_chunk_size <= 0:
            raise ValueError(
                f"out_chunk_size must be a positive integer, got {out_chunk_size!r}"
            )

        self.in_features = in_features
        self.out_features = out_features
        self.block_size = block_size
        self.rotor_angle_scale = rotor_angle_scale
        self.layer_name = layer_name
        self.out_chunk_size = out_chunk_size
        self.cache_weight = cache_weight

        self.row_size = int(qt_raw["row_size"])
        self.row_rot_size = int(qt_raw["row_rot_size"])
        self.row_padded_size = int(qt_raw["row_padded_size"])
        self.n_rows = int(qt_raw["n_rows"])

        # Expand the packed indices to one uint8 per value, laid out as
        # (n_rows, row_padded_size).
        packed = qt_raw["packed_indices"].to(torch.uint8)
        n_values = self.n_rows * self.row_padded_size
        idx = _unpack_3bit(packed, n_values=n_values, device=torch.device("cpu"))
        idx = idx.reshape(self.n_rows, self.row_padded_size).to(torch.uint8)

        self.register_buffer("indices", idx)
        self.register_buffer("centers", qt_raw["centers"].to(torch.float16))
        self.register_buffer("scales", qt_raw["scales"].to(torch.float16))
        self.register_buffer("codebook", qt_raw["codebook"].to(torch.float16))

        # The low-rank correction is only active when both factors exist.
        lowrank_A = qt_raw.get("lowrank_A")
        lowrank_B = qt_raw.get("lowrank_B")
        if lowrank_A is not None and lowrank_B is not None:
            self.register_buffer("lowrank_A", lowrank_A.to(torch.float16))
            self.register_buffer("lowrank_B", lowrank_B.to(torch.float16))
        else:
            self.lowrank_A = None
            self.lowrank_B = None

        # Built in float32 and stored in float16.
        R = _deterministic_rotor_matrix(
            name=layer_name,
            seed=seed,
            device=torch.device("cpu"),
            dtype=torch.float32,
            angle_scale=rotor_angle_scale,
        )
        self.register_buffer("R", R.to(torch.float16))

        if bias is not None:
            self.bias = nn.Parameter(bias.to(torch.float32), requires_grad=False)
        else:
            self.register_parameter("bias", None)
        # Non-persistent: the decoded weight is derived data and is kept out
        # of the state dict.
        self.register_buffer("_cached_weight", None, persistent=False)

    def _chunk_bounds(self):
        """Yield ``(start, end)`` output-row ranges of at most ``out_chunk_size`` rows."""
        for start in range(0, self.out_features, self.out_chunk_size):
            yield start, min(self.out_features, start + self.out_chunk_size)

    def _decode_weight_chunk(self, s: int, e: int) -> torch.Tensor:
        """Decode weight rows ``[s, e)`` into a float32 ``(e - s, in_features)`` tensor."""
        rows = e - s
        idx = self.indices[s:e].long()
        vals = self.codebook[idx]

        # Each row is split into blocks of block_size values; every block has
        # its own center and scale, stored flat in row-major block order.
        n_blocks = self.row_padded_size // self.block_size
        vals_b = vals.view(rows, n_blocks, self.block_size)

        centers = self.centers[s * n_blocks : e * n_blocks].view(rows, n_blocks, 1)
        scales = self.scales[s * n_blocks : e * n_blocks].view(rows, n_blocks, 1)
        w_rot = (vals_b * scales + centers).view(rows, self.row_padded_size)
        # Drop the block padding before applying the rotor.
        w_rot = w_rot[:, : self.row_rot_size]

        # The rotor is applied to consecutive groups of three values.
        R = self.R.to(dtype=torch.float32)
        w = (w_rot.to(torch.float32).view(rows, -1, 3) @ R).view(rows, self.row_rot_size)
        w = w[:, : self.row_size]

        if self.lowrank_A is not None and self.lowrank_B is not None:
            A = self.lowrank_A[s:e].to(torch.float32)
            B = self.lowrank_B.to(torch.float32)
            w = w + (A @ B)

        return w[:, : self.in_features]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear map to ``x`` along its last dimension.

        The computation runs in float32 and the result is float32 with shape
        ``(*x.shape[:-1], out_features)``.
        """
        orig_shape = x.shape[:-1]
        x2 = x.reshape(-1, x.shape[-1]).to(torch.float32)

        # Lazily build the float16 weight cache on the first call.
        if self.cache_weight and self._cached_weight is None:
            parts = [self._decode_weight_chunk(s, e) for s, e in self._chunk_bounds()]
            self._cached_weight = torch.cat(parts, dim=0).to(torch.float16)

        if self._cached_weight is not None:
            w = self._cached_weight.to(device=x2.device, dtype=torch.float32)
            out = x2 @ w.T
        else:
            # Streaming path: decode one chunk of rows at a time so the full
            # dense weight never has to be materialised.
            out = torch.empty(x2.shape[0], self.out_features, dtype=torch.float32, device=x2.device)
            for s, e in self._chunk_bounds():
                w = self._decode_weight_chunk(s, e).to(x2.device)
                out[:, s:e] = x2 @ w.T

        if self.bias is not None:
            out = out + self.bias
        return out.reshape(*orig_shape, self.out_features)
|
|
|
|
def load_fused_model(pkg_path: str, out_chunk_size: int = 64):
    """Load a quantized package into a CPU model with fused quantized linears.

    Steps:
      1. Read the package and instantiate the architecture named by its
         ``model_id`` with empty weights.
      2. Replace every ``nn.Linear`` that has a quantized ``<module>.weight``
         entry carrying ``n_rows`` with a ``FusedRotorLinear`` (weight caching
         enabled), handing it the matching passthrough bias if one exists.
      3. Materialise all remaining passthrough tensors on the CPU.
      4. Dequantize every quantized tensor that was not fused and materialise
         it on the CPU.
      5. Cast the model to float32.

    Args:
        pkg_path: Path handed to ``load_quantized_package``.
        out_chunk_size: Rows decoded per chunk by each ``FusedRotorLinear``.

    Returns:
        ``(model, model_id, load_s)`` where ``load_s`` is the wall-clock load
        time in seconds.
    """
    t0 = time.perf_counter()
    pkg = load_quantized_package(pkg_path)
    model_id = pkg["model_id"]
    seed = int(pkg.get("seed", 1337))
    block_size = int(pkg.get("block_size", 128))
    rotor_angle_scale = float(pkg.get("rotor_angle_scale", 1.0))

    config = AutoConfig.from_pretrained(model_id)
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config)
    model.eval()

    passthrough = pkg["passthrough"]
    quantized = pkg["quantized"]
    weight_suffix = ".weight"

    # Names of tensors that end up owned by a FusedRotorLinear.
    consumed = set()
    for w_name, qt_raw in quantized.items():
        if not w_name.endswith(weight_suffix):
            continue
        mod_name = w_name[: -len(weight_suffix)]
        parent, child = _get_parent_module(model, mod_name)
        if parent is None or not hasattr(parent, child):
            continue
        old = getattr(parent, child)
        if not isinstance(old, nn.Linear):
            continue
        # Entries without "n_rows" are left to the dequantize fallback below.
        if "n_rows" not in qt_raw:
            continue

        bias_name = f"{mod_name}.bias"
        bias = passthrough.get(bias_name)

        fused = FusedRotorLinear(
            in_features=old.in_features,
            out_features=old.out_features,
            bias=bias,
            qt_raw=qt_raw,
            seed=seed,
            block_size=block_size,
            rotor_angle_scale=rotor_angle_scale,
            layer_name=w_name,
            out_chunk_size=out_chunk_size,
            cache_weight=True,
        )
        setattr(parent, child, fused)
        consumed.add(w_name)
        if bias is not None:
            consumed.add(bias_name)

    for name, t in passthrough.items():
        if name in consumed:
            continue
        set_module_tensor_to_device(model, name, "cpu", value=t)

    for name, qt_raw in quantized.items():
        if name in consumed:
            continue
        # Imported lazily: only needed when some quantized tensor was not fused.
        from rotorquant_weights import dequantize_to_state_dict

        # Entries without "n_rows" are decoded with rowwise disabled; the rest
        # follow the package-level setting.
        rowwise = pkg.get("rowwise", True) if "n_rows" in qt_raw else False
        # Reuse the values resolved above so a package that omits "seed" or
        # "block_size" gets the same defaults here as in the fused layers
        # instead of raising KeyError.
        sd = dequantize_to_state_dict({
            "bits": pkg["bits"],
            "block_size": block_size,
            "seed": seed,
            "lowrank_rank": pkg.get("lowrank_rank", 0),
            "rotor_angle_scale": rotor_angle_scale,
            "rowwise": rowwise,
            "quantized": {name: qt_raw},
            "passthrough": {},
        }, dtype=torch.float32, device="cpu")
        set_module_tensor_to_device(model, name, "cpu", value=sd[name])

    model = model.to(torch.float32)
    load_s = time.perf_counter() - t0
    return model, model_id, load_s
|
|
|
|
def run_prompt(model, tokenizer, prompt: str, max_new_tokens: int) -> str:
    """Generate a reply to ``prompt`` and return only the newly generated text.

    The prompt is wrapped in a fixed system + user chat via the tokenizer's
    chat template, generation runs with sampling disabled and without
    gradient tracking, and the prompt tokens are stripped from the output
    before decoding.
    """
    system_turn = {
        "role": "system",
        "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
    }
    user_turn = {"role": "user", "content": prompt}

    chat_text = tokenizer.apply_chat_template(
        [system_turn, user_turn],
        tokenize=False,
        add_generation_prompt=True,
    )
    encoded = tokenizer([chat_text], return_tensors="pt")

    with torch.no_grad():
        generated = model.generate(**encoded, max_new_tokens=max_new_tokens, do_sample=False)

    # Keep only the tokens that follow the prompt.
    prompt_len = encoded["input_ids"].shape[1]
    completion_ids = generated[:, prompt_len:]
    decoded = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
    return decoded[0]
|
|
|
|
def parse_args():
    """Parse the command line for the ``run`` and ``bench`` subcommands.

    Returns the ``argparse.Namespace``; the chosen subcommand is in ``cmd``.
    """
    default_pkg = "artifacts/qwen2.5-0.5b-rotorq3-mlp-only.pt"

    parser = argparse.ArgumentParser(description="Run fused RotorQuant runtime")
    commands = parser.add_subparsers(dest="cmd", required=True)

    # run: generate a reply for a single prompt.
    run_cmd = commands.add_parser("run")
    run_cmd.add_argument("--pkg", default=default_pkg)
    run_cmd.add_argument("--prompt", default="Explain quantization in one paragraph.")
    run_cmd.add_argument("--max-new-tokens", type=int, default=64)
    run_cmd.add_argument("--out-chunk-size", type=int, default=64)

    # bench: measure load time / memory and write a JSON report.
    bench_cmd = commands.add_parser("bench")
    bench_cmd.add_argument("--pkg", default=default_pkg)
    bench_cmd.add_argument("--out", default="artifacts/fused_runtime_meta.json")
    bench_cmd.add_argument("--out-chunk-size", type=int, default=64)

    return parser.parse_args()
|
|
|
|
def main():
    """Entry point: dispatch to the ``run`` or ``bench`` subcommand.

    ``run`` loads the model, answers one prompt and prints the load time,
    the prompt and the response.  ``bench`` records resident memory before
    loading, after loading and after one 32-token warmup generation, then
    writes the measurements as JSON to ``--out`` and prints them.
    """
    args = parse_args()

    if args.cmd == "run":
        model, model_id, load_s = load_fused_model(args.pkg, out_chunk_size=args.out_chunk_size)
        tok = AutoTokenizer.from_pretrained(model_id)
        ans = run_prompt(model, tok, args.prompt, args.max_new_tokens)
        print(f"load_s={load_s:.3f}")
        print("\n=== Prompt ===")
        print(args.prompt)
        print("\n=== Response ===")
        print(ans)

    elif args.cmd == "bench":
        rss0 = rss_gb()
        model, model_id, load_s = load_fused_model(args.pkg, out_chunk_size=args.out_chunk_size)
        rss1 = rss_gb()
        tok = AutoTokenizer.from_pretrained(model_id)
        t0 = time.perf_counter()
        _ = run_prompt(model, tok, "warmup", 32)
        t1 = time.perf_counter() - t0
        rss2 = rss_gb()

        out = {
            "pkg": args.pkg,
            "load_s": load_s,
            "warmup_generate_s": t1,
            "rss_before_gb": rss0,
            "rss_after_load_gb": rss1,
            "rss_after_warmup_gb": rss2,
            "out_chunk_size": args.out_chunk_size,
        }
        report = json.dumps(out, indent=2)
        out_path = Path(args.out)
        # Create the target directory first so a missing folder does not
        # discard the measurements after the (slow) load and warmup.
        out_path.parent.mkdir(parents=True, exist_ok=True)
        out_path.write_text(report, encoding="utf-8")
        print(report)
|
|
|
|
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|