| from __future__ import annotations |
|
|
| import hashlib |
| import json |
| import math |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Dict, Tuple, Any, Iterable |
|
|
| import torch |
|
|
|
|
| def _deterministic_rotor_matrix( |
| name: str, |
| seed: int, |
| device: torch.device, |
| dtype: torch.dtype, |
| angle_scale: float = 1.0, |
| ) -> torch.Tensor: |
| key = f"{seed}:{name}".encode("utf-8") |
| digest = hashlib.sha256(key).digest() |
|
|
| def _u32(i: int) -> int: |
| return int.from_bytes(digest[4 * i : 4 * (i + 1)], "little") |
|
|
| v = torch.tensor([ |
| (_u32(0) / (2**32 - 1)) * 2.0 - 1.0, |
| (_u32(1) / (2**32 - 1)) * 2.0 - 1.0, |
| (_u32(2) / (2**32 - 1)) * 2.0 - 1.0, |
| ], device=device, dtype=torch.float64) |
| if torch.linalg.norm(v) < 1e-12: |
| v = torch.tensor([1.0, 0.0, 0.0], device=device, dtype=torch.float64) |
| axis = v / torch.linalg.norm(v) |
|
|
| angle_u = _u32(3) / (2**32 - 1) |
| angle = (2.0 * angle_u - 1.0) * (math.pi / 2.0) * angle_scale |
|
|
| K = torch.tensor( |
| [ |
| [0.0, -axis[2].item(), axis[1].item()], |
| [axis[2].item(), 0.0, -axis[0].item()], |
| [-axis[1].item(), axis[0].item(), 0.0], |
| ], |
| device=device, |
| dtype=torch.float64, |
| ) |
| I = torch.eye(3, device=device, dtype=torch.float64) |
| R = I + math.sin(angle) * K + (1.0 - math.cos(angle)) * (K @ K) |
| return R.to(dtype=dtype) |
|
|
|
|
| def _lloyd_codebook(values: torch.Tensor, levels: int = 8, iters: int = 25, eps: float = 1e-6) -> torch.Tensor: |
| flat = values.reshape(-1) |
| if flat.numel() == 0: |
| return torch.linspace(-1.0, 1.0, levels, device=values.device, dtype=values.dtype) |
|
|
| if flat.numel() > 250_000: |
| idx = torch.randperm(flat.numel(), device=flat.device)[:250_000] |
| work = flat[idx] |
| else: |
| work = flat |
|
|
| probs = torch.linspace(0.0, 1.0, levels, device=work.device, dtype=torch.float32) |
| init = torch.quantile(work.float(), probs).to(work.dtype) |
| codebook = init |
|
|
| for _ in range(iters): |
| d = (work.unsqueeze(1) - codebook.unsqueeze(0)).abs() |
| assign = d.argmin(dim=1) |
| new_codebook = codebook.clone() |
| for k in range(levels): |
| sel = work[assign == k] |
| if sel.numel() > 0: |
| new_codebook[k] = sel.mean() |
| if torch.max((new_codebook - codebook).abs()) < eps: |
| codebook = new_codebook |
| break |
| codebook = new_codebook |
|
|
| codebook, _ = torch.sort(codebook) |
| return codebook |
|
|
|
|
| def _pack_3bit(indices: torch.Tensor) -> torch.Tensor: |
| x = indices.reshape(-1).to(torch.uint8).cpu() |
| n = x.numel() |
| if n == 0: |
| return torch.empty(0, dtype=torch.uint8) |
|
|
| n_full = (n // 8) * 8 |
| full = x[:n_full].view(-1, 8).to(torch.int32) |
| packed24 = ( |
| full[:, 0] |
| | (full[:, 1] << 3) |
| | (full[:, 2] << 6) |
| | (full[:, 3] << 9) |
| | (full[:, 4] << 12) |
| | (full[:, 5] << 15) |
| | (full[:, 6] << 18) |
| | (full[:, 7] << 21) |
| ) |
| bytes_full = torch.stack( |
| [ |
| (packed24 & 0xFF), |
| ((packed24 >> 8) & 0xFF), |
| ((packed24 >> 16) & 0xFF), |
| ], |
| dim=1, |
| ).reshape(-1).to(torch.uint8) |
|
|
| rem = x[n_full:] |
| if rem.numel() == 0: |
| return bytes_full |
|
|
| rem_acc = torch.tensor(0, dtype=torch.int32) |
| for i, v in enumerate(rem.tolist()): |
| rem_acc |= (int(v) & 0x7) << (3 * i) |
| rem_nbytes = (rem.numel() * 3 + 7) // 8 |
| rem_bytes = torch.tensor([(rem_acc >> (8 * i)) & 0xFF for i in range(rem_nbytes)], dtype=torch.uint8) |
| return torch.cat([bytes_full, rem_bytes], dim=0) |
|
|
|
|
| def _unpack_3bit(packed: torch.Tensor, n_values: int, device: torch.device) -> torch.Tensor: |
| if n_values == 0: |
| return torch.empty(0, dtype=torch.uint8, device=device) |
|
|
| p = packed.to(torch.uint8).cpu() |
| n_full = (n_values // 8) * 8 |
| n_full_groups = n_full // 8 |
| n_full_bytes = n_full_groups * 3 |
|
|
| out_parts = [] |
| if n_full_groups > 0: |
| b = p[:n_full_bytes].view(-1, 3).to(torch.int32) |
| packed24 = b[:, 0] | (b[:, 1] << 8) | (b[:, 2] << 16) |
| vals = torch.stack( |
| [ |
| (packed24 >> 0) & 0x7, |
| (packed24 >> 3) & 0x7, |
| (packed24 >> 6) & 0x7, |
| (packed24 >> 9) & 0x7, |
| (packed24 >> 12) & 0x7, |
| (packed24 >> 15) & 0x7, |
| (packed24 >> 18) & 0x7, |
| (packed24 >> 21) & 0x7, |
| ], |
| dim=1, |
| ).reshape(-1).to(torch.uint8) |
| out_parts.append(vals) |
|
|
| rem_n = n_values - n_full |
| if rem_n > 0: |
| rem_bytes = p[n_full_bytes:] |
| acc = 0 |
| for i, v in enumerate(rem_bytes.tolist()): |
| acc |= int(v) << (8 * i) |
| rem_vals = torch.tensor([(acc >> (3 * i)) & 0x7 for i in range(rem_n)], dtype=torch.uint8) |
| out_parts.append(rem_vals) |
|
|
| return torch.cat(out_parts, dim=0).to(device=device) |
|
|
|
|
@dataclass
class QuantizedTensor:
    """Encoded form of one tensor, as produced by ``RotorQuantWeightCodec.quantize_tensor``.

    Row layout:
        - The original elements are viewed as ``n_rows`` rows of ``row_size`` values.
        - Each row is zero-padded to ``row_rot_size`` (a multiple of 3) before the
          rotor is applied.
        - Each row is then zero-padded to ``row_padded_size`` (a multiple of the block
          size) before block quantization.
    """

    # Shape of the original tensor; dequantization reshapes back to it.
    shape: tuple[int, ...]
    # Row layout (see the class docstring).
    n_rows: int
    row_size: int
    row_rot_size: int
    row_padded_size: int
    # Codebook indices packed 3 bits each (8 indices per 3 bytes).
    packed_indices: torch.Tensor
    # Per-block offset (block mean) and divisor (max-abs deviation from the mean).
    centers: torch.Tensor
    scales: torch.Tensor
    # Sorted scalar codebook shared by every block of this tensor.
    codebook: torch.Tensor
    # Optional low-rank correction; reconstruction adds lowrank_A @ lowrank_B.
    lowrank_A: torch.Tensor | None = None
    lowrank_B: torch.Tensor | None = None
    # Optional per-row outliers: column positions and the values written back there.
    outlier_pos: torch.Tensor | None = None
    outlier_vals: torch.Tensor | None = None
|
|
|
|
class RotorQuantWeightCodec:
    """Block-wise scalar weight codec with a deterministic 3-D rotor pre-rotation.

    Only 3-bit codes (8 codebook levels) are implemented. Encoding one tensor:

    1. View it as ``(n_rows, row_size)``. There is one row per leading index when
       ``rowwise`` is set and ``ndim >= 2``; otherwise a single row holds every element.
    2. Zero-pad each row to a multiple of 3 and rotate every consecutive triple with
       the 3x3 rotation matrix derived from ``(seed, name)``.
    3. Zero-pad each rotated row to a multiple of ``block_size``. In every block,
       subtract the mean ("center") of its real (non-padding) values and divide by
       their max-abs deviation ("scale").
    4. Fit a scalar codebook to the normalized values and store, for every value, the
       index of its nearest level, packed 3 bits each.
    5. Optional extras:
       - ``lowrank_rank > 0``: only for multi-row (rowwise) tensors, store a
         low-rank factorization of the reconstruction residual.
       - ``outlier_frac > 0``: for the ``outlier_frac`` share of each row with the
         largest error (at least one value), store the original value.

    Side information (centers, scales, codebook, low-rank factors, outlier values) is
    stored as float16 on the CPU.

    Args:
        bits: code width; must be 3.
        block_size: number of rotated values sharing one center/scale pair; must be > 0.
        seed: seed of the rotor matrix and of the randomized low-rank factorization.
        eps: lower bound on block scales, avoiding division by zero for flat blocks.
        lowrank_rank: rank of the residual correction (0 disables it).
        rotor_angle_scale: multiplier on the rotor angle (0 yields the identity rotation).
        rowwise: quantize per row rather than treating the tensor as one long row.
        outlier_frac: fraction of each row kept exactly (0 disables it).

    Raises:
        ValueError: if ``bits != 3`` or ``block_size <= 0``.
    """

    def __init__(
        self,
        bits: int = 3,
        block_size: int = 128,
        seed: int = 1337,
        eps: float = 1e-8,
        lowrank_rank: int = 0,
        rotor_angle_scale: float = 1.0,
        rowwise: bool = False,
        outlier_frac: float = 0.0,
    ):
        if bits != 3:
            raise ValueError("Current prototype only implements 3-bit packing.")
        if block_size <= 0:
            raise ValueError(f"block_size must be positive, got {block_size}")
        self.bits = bits
        self.block_size = block_size
        self.seed = seed
        self.eps = eps
        self.lowrank_rank = lowrank_rank
        self.rotor_angle_scale = rotor_angle_scale
        self.rowwise = rowwise
        self.outlier_frac = outlier_frac

    def _rotor(self, name: str, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        """Rotation matrix for ``name`` under this codec's seed and angle scale."""
        return _deterministic_rotor_matrix(
            name=name,
            seed=self.seed,
            device=device,
            dtype=dtype,
            angle_scale=self.rotor_angle_scale,
        )

    @staticmethod
    def _index_dtype(row_size: int) -> torch.dtype:
        """Narrowest signed integer dtype able to hold column indices ``0 .. row_size-1``."""
        max_index = max(row_size - 1, 0)
        for dt in (torch.int16, torch.int32):
            if max_index <= torch.iinfo(dt).max:
                return dt
        return torch.int64

    def _normalize_blocks(
        self, blocks: torch.Tensor, valid_len: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Center and scale every block, using only the first ``valid_len`` positions of each row.

        Args:
            blocks: ``(n_rows, n_blocks, block_size)`` tensor.
            valid_len: number of real (non-padding) positions at the start of each row.
                Positions past it are excluded from the statistics and stay 0 in the
                normalized output.

        Returns:
            ``(centers, scales, normed)`` where centers and scales have shape
            ``(n_rows * n_blocks,)`` and normed has shape
            ``(n_rows * n_blocks, block_size)``.
        """
        n_rows, n_blocks, _ = blocks.shape
        centers = blocks.new_zeros(n_rows, n_blocks)
        scales = blocks.new_zeros(n_rows, n_blocks)
        normed = torch.zeros_like(blocks)
        full_blocks, tail = divmod(valid_len, self.block_size)

        def _stats(blk: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            c = blk.mean(dim=-1)
            z = blk - c.unsqueeze(-1)
            s = z.abs().amax(dim=-1).clamp(min=self.eps)
            return c, s, z / s.unsqueeze(-1)

        if full_blocks > 0:
            c, s, z = _stats(blocks[:, :full_blocks, :])
            centers[:, :full_blocks] = c
            scales[:, :full_blocks] = s
            normed[:, :full_blocks, :] = z
        if tail > 0:
            # The last block is partially padding; its stats cover only the real values.
            c, s, z = _stats(blocks[:, full_blocks, :tail])
            centers[:, full_blocks] = c
            scales[:, full_blocks] = s
            normed[:, full_blocks, :tail] = z

        return centers.reshape(-1), scales.reshape(-1), normed.reshape(-1, self.block_size)

    @staticmethod
    def _nearest_codes(normed: torch.Tensor, codebook: torch.Tensor, chunk_blocks: int = 4096) -> torch.Tensor:
        """Index (uint8) of the nearest codebook level for every entry of ``normed``.

        Rows of ``normed`` are processed in chunks to bound the size of the
        ``(rows, block_size, levels)`` distance tensor.
        """
        chunks = [
            (normed[i : i + chunk_blocks].unsqueeze(-1) - codebook.view(1, 1, -1))
            .abs()
            .argmin(dim=-1)
            .to(torch.uint8)
            for i in range(0, normed.shape[0], chunk_blocks)
        ]
        if not chunks:
            # Empty tensor: torch.cat would reject an empty list.
            return torch.empty(0, normed.shape[1], dtype=torch.uint8, device=normed.device)
        return torch.cat(chunks, dim=0)

    def _fit_lowrank(self, name: str, qt: QuantizedTensor, rows_orig: torch.Tensor) -> None:
        """Attach a rank-``lowrank_rank`` factorization ``A @ B`` of the decode residual to ``qt``."""
        n_rows, row_size = rows_orig.shape
        rank = min(self.lowrank_rank, n_rows, row_size)
        if rank <= 0:
            return
        # Residual against the real decoder output, which uses the fp16-stored
        # centers/scales/codebook. The outlier step uses the same reference.
        base = self.dequantize_tensor(name, qt, device=rows_orig.device, dtype=torch.float32)
        residual = rows_orig - base.reshape(n_rows, row_size)
        # pca_lowrank draws random projections. Fork the CPU RNG and seed it so
        # repeated runs give identical factors without touching the caller's RNG state.
        # NOTE: tensors on other devices draw from that device's generator, which is not seeded here.
        with torch.random.fork_rng(devices=[]):
            torch.default_generator.manual_seed(self.seed)
            U, S, V = torch.pca_lowrank(residual, q=rank, center=False, niter=2)
        qt.lowrank_A = (U[:, :rank] * S[:rank]).to(torch.float16).cpu()
        qt.lowrank_B = V[:, :rank].T.to(torch.float16).cpu()

    def _fit_outliers(self, name: str, qt: QuantizedTensor, rows_orig: torch.Tensor) -> None:
        """Store the original values at the largest-error positions of every row."""
        n_rows, row_size = rows_orig.shape
        recon = self.dequantize_tensor(name, qt, device=rows_orig.device, dtype=torch.float32)
        err = (rows_orig - recon.reshape(n_rows, row_size)).abs()
        k = min(max(1, int(row_size * self.outlier_frac)), row_size)
        _, pos = torch.topk(err, k=k, dim=1, largest=True, sorted=False)
        # int16 overflows for rows longer than 32768 (e.g. any non-rowwise weight);
        # pick a dtype wide enough for the largest column index.
        qt.outlier_pos = pos.to(self._index_dtype(row_size)).cpu()
        qt.outlier_vals = torch.gather(rows_orig, dim=1, index=pos).to(torch.float16).cpu()

    def quantize_tensor(self, name: str, w: torch.Tensor) -> QuantizedTensor:
        """Encode ``w``. The same ``name`` (and codec settings) must be used to decode it.

        Args:
            name: identifier that selects the rotor matrix.
            w: floating-point tensor of any shape, including empty tensors.

        Returns:
            A QuantizedTensor whose data tensors are all on the CPU.
        """
        x = w.detach().float()
        if self.rowwise and x.ndim >= 2:
            n_rows, row_size = int(math.prod(x.shape[:-1])), x.shape[-1]
        else:
            n_rows, row_size = 1, x.numel()
        rows_orig = x.reshape(n_rows, row_size)

        # Stage 1: pad rows to whole triples and rotate each triple.
        rows = rows_orig
        pad3 = (-row_size) % 3
        if pad3:
            rows = torch.cat([rows, rows.new_zeros(n_rows, pad3)], dim=1)
        row_rot_size = rows.shape[1]
        R = self._rotor(name, rows.device, rows.dtype)
        # Explicit triple count: reshape(n_rows, -1, 3) raises for zero-element rows.
        rot = (rows.reshape(n_rows, row_rot_size // 3, 3) @ R.T).reshape(n_rows, row_rot_size)

        # Stage 2: pad rows to whole blocks and normalize each block.
        pad_block = (-row_rot_size) % self.block_size
        if pad_block:
            rot = torch.cat([rot, rot.new_zeros(n_rows, pad_block)], dim=1)
        row_padded_size = rot.shape[1]
        n_blocks = row_padded_size // self.block_size
        blocks = rot.reshape(n_rows, n_blocks, self.block_size)
        centers, scales, normed = self._normalize_blocks(blocks, row_rot_size)

        # Stage 3: fit the codebook, assign codes and pack them.
        codebook = _lloyd_codebook(normed, levels=2**self.bits)
        # Assign against the fp16-rounded levels the decoder will actually use.
        codebook = codebook.to(torch.float16).to(normed.dtype)
        idx = self._nearest_codes(normed, codebook)

        qt = QuantizedTensor(
            shape=tuple(w.shape),
            n_rows=n_rows,
            row_size=row_size,
            row_rot_size=row_rot_size,
            row_padded_size=row_padded_size,
            packed_indices=_pack_3bit(idx.reshape(-1).cpu()),
            centers=centers.cpu().to(torch.float16),
            scales=scales.cpu().to(torch.float16),
            codebook=codebook.cpu().to(torch.float16),
        )
        if self.lowrank_rank > 0 and n_rows > 1 and row_size > 1:
            self._fit_lowrank(name, qt, rows_orig)
        if self.outlier_frac > 0 and n_rows > 0 and row_size > 0:
            self._fit_outliers(name, qt, rows_orig)
        return qt

    def dequantize_tensor(self, name: str, qt: QuantizedTensor, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        """Decode ``qt`` back to a tensor of shape ``qt.shape``.

        Args:
            name: the name used at encode time (selects the same rotor).
            qt: encoded tensor produced with the same codec settings.
            device: device of the result.
            dtype: dtype of the result. Decoding itself runs in float32.

        Returns:
            The reconstruction, including the low-rank correction and outliers when present.
        """
        n_blocks = qt.n_rows * (qt.row_padded_size // self.block_size)
        idx = _unpack_3bit(qt.packed_indices, n_values=n_blocks * self.block_size, device=device).long()
        codebook = qt.codebook.to(device=device, dtype=torch.float32)
        centers = qt.centers.to(device=device, dtype=torch.float32)
        scales = qt.scales.to(device=device, dtype=torch.float32)

        blocks = codebook[idx].reshape(n_blocks, self.block_size)
        deq_rows = (blocks * scales.unsqueeze(1) + centers.unsqueeze(1)).reshape(qt.n_rows, qt.row_padded_size)
        deq_rows = deq_rows[:, : qt.row_rot_size]

        R = self._rotor(name, device, torch.float32)
        # R is orthogonal, so right-multiplying by R undoes the encoder's ``@ R.T``.
        triples = deq_rows.reshape(qt.n_rows, qt.row_rot_size // 3, 3)
        x_rows = (triples @ R).reshape(qt.n_rows, qt.row_rot_size)[:, : qt.row_size]

        if qt.lowrank_A is not None and qt.lowrank_B is not None:
            A = qt.lowrank_A.to(device=device, dtype=torch.float32)
            B = qt.lowrank_B.to(device=device, dtype=torch.float32)
            x_rows = x_rows + (A @ B)
        if qt.outlier_pos is not None and qt.outlier_vals is not None:
            pos = qt.outlier_pos.to(device=device, dtype=torch.long)
            vals = qt.outlier_vals.to(device=device, dtype=torch.float32)
            x_rows.scatter_(dim=1, index=pos, src=vals)
        return x_rows.reshape(qt.shape).to(dtype=dtype)
|
|
|
|
def quantize_state_dict(
    state_dict: Dict[str, torch.Tensor],
    bits: int = 3,
    block_size: int = 128,
    seed: int = 1337,
    min_ndim: int = 2,
    verbose: bool = False,
    skip_names: Iterable[str] | None = None,
    lowrank_rank: int = 0,
    rotor_angle_scale: float = 1.0,
    rowwise: bool = False,
    include_if_name_contains: Iterable[str] | None = None,
    outlier_frac: float = 0.0,
) -> Dict[str, Any]:
    """Quantize the eligible tensors of ``state_dict`` into a serializable package dict.

    A tensor is stored as-is (detached, on the CPU) under ``"passthrough"`` when any of
    the following holds:
        - ``include_if_name_contains`` is non-empty and no fragment occurs in its name;
        - it is not floating point;
        - ``ndim < min_ndim``;
        - it has no elements;
        - its name is in ``skip_names``.

    Every other tensor is encoded with ``RotorQuantWeightCodec`` and stored under
    ``"quantized"`` as a dict of the QuantizedTensor fields.

    Returns:
        Dict with the format tag, the codec settings needed to decode, and the
        ``"quantized"`` / ``"passthrough"`` maps.
    """
    codec = RotorQuantWeightCodec(
        bits=bits,
        block_size=block_size,
        seed=seed,
        lowrank_rank=lowrank_rank,
        rotor_angle_scale=rotor_angle_scale,
        rowwise=rowwise,
        outlier_frac=outlier_frac,
    )
    out: Dict[str, Any] = {
        "format": "rotorquant_v1",
        "bits": bits,
        "block_size": block_size,
        "seed": seed,
        "lowrank_rank": lowrank_rank,
        "rotor_angle_scale": rotor_angle_scale,
        "rowwise": rowwise,
        "outlier_frac": outlier_frac,
        "quantized": {},
        "passthrough": {},
    }

    skip = set(skip_names or [])
    include_fragments = list(include_if_name_contains or [])
    total = len(state_dict)
    for i, (name, t) in enumerate(state_dict.items(), start=1):
        if verbose:
            print(f"[quantize] {i}/{total} {name} shape={tuple(t.shape)}")
        excluded_by_filter = bool(include_fragments) and not any(frag in name for frag in include_fragments)
        # Empty tensors carry no data worth encoding; passing them through also avoids
        # codec edge cases with zero-sized rows.
        if (
            excluded_by_filter
            or not torch.is_floating_point(t)
            or t.ndim < min_ndim
            or t.numel() == 0
            or name in skip
        ):
            out["passthrough"][name] = t.detach().cpu()
            continue

        qt = codec.quantize_tensor(name, t.detach().cpu())
        out["quantized"][name] = {
            "shape": qt.shape,
            "n_rows": qt.n_rows,
            "row_size": qt.row_size,
            "row_rot_size": qt.row_rot_size,
            "row_padded_size": qt.row_padded_size,
            "packed_indices": qt.packed_indices,
            "centers": qt.centers,
            "scales": qt.scales,
            "codebook": qt.codebook,
            "lowrank_A": qt.lowrank_A,
            "lowrank_B": qt.lowrank_B,
            "outlier_pos": qt.outlier_pos,
            "outlier_vals": qt.outlier_vals,
        }
    return out
|
|
|
|
def _layout_from_raw(qt_raw: Dict[str, Any]) -> Tuple[int, int, int, int]:
    """Return ``(n_rows, row_size, row_rot_size, row_padded_size)`` for one stored entry.

    Entries without ``"n_rows"`` are read as a single row described by ``"numel"`` and
    ``"padded_numel"`` (presumably an earlier package layout).
    """
    if "n_rows" not in qt_raw:
        numel = int(qt_raw["numel"])
        return 1, numel, ((numel + 2) // 3) * 3, int(qt_raw["padded_numel"])
    return (
        int(qt_raw["n_rows"]),
        int(qt_raw["row_size"]),
        int(qt_raw["row_rot_size"]),
        int(qt_raw["row_padded_size"]),
    )


def dequantize_to_state_dict(pkg: Dict[str, Any], dtype: torch.dtype = torch.float32, device: str = "cpu") -> Dict[str, torch.Tensor]:
    """Rebuild a state dict from a package produced by ``quantize_state_dict``.

    - Quantized entries are decoded to ``dtype`` on ``device``.
    - Passthrough tensors are moved to ``device``; floating-point ones are also cast
      to ``dtype``.
    - Entries with no stored centers decode with zero block offsets.
    """
    codec = RotorQuantWeightCodec(
        bits=pkg["bits"],
        block_size=pkg["block_size"],
        seed=pkg["seed"],
        lowrank_rank=int(pkg.get("lowrank_rank", 0)),
        rotor_angle_scale=float(pkg.get("rotor_angle_scale", 1.0)),
        rowwise=bool(pkg.get("rowwise", False)),
        outlier_frac=float(pkg.get("outlier_frac", 0.0)),
    )
    target = torch.device(device)
    result: Dict[str, torch.Tensor] = {}

    for name, raw in pkg["quantized"].items():
        n_rows, row_size, row_rot_size, row_padded_size = _layout_from_raw(raw)

        centers = raw.get("centers")
        if centers is None:
            block_count = n_rows * (row_padded_size // pkg["block_size"])
            centers = torch.zeros(block_count, dtype=torch.float16)

        encoded = QuantizedTensor(
            shape=tuple(raw["shape"]),
            n_rows=n_rows,
            row_size=row_size,
            row_rot_size=row_rot_size,
            row_padded_size=row_padded_size,
            packed_indices=raw["packed_indices"],
            centers=centers,
            scales=raw["scales"],
            codebook=raw["codebook"],
            lowrank_A=raw.get("lowrank_A"),
            lowrank_B=raw.get("lowrank_B"),
            outlier_pos=raw.get("outlier_pos"),
            outlier_vals=raw.get("outlier_vals"),
        )
        result[name] = codec.dequantize_tensor(name, encoded, device=target, dtype=dtype)

    for name, tensor in pkg["passthrough"].items():
        wanted = dtype if torch.is_floating_point(tensor) else tensor.dtype
        result[name] = tensor.to(device=target, dtype=wanted)

    return result
|
|
|
|
def estimate_bits_per_weight(pkg: Dict[str, Any]) -> float:
    """Estimate storage bits per original element over all quantized tensors.

    Counts the packed indices plus every stored side tensor at its actual element size:
    centers, scales, codebook, low-rank factors, and outlier positions/values.
    Outlier positions may be int16, int32 or int64, so sizes are not assumed to be
    16 bits. The total is divided by the number of original elements of the quantized
    tensors. Passthrough tensors are not counted.

    Returns:
        Average bits per quantized weight, or 0.0 when nothing was quantized.
    """

    def storage_bits(t: torch.Tensor) -> int:
        return int(t.numel()) * t.element_size() * 8

    total_numel = 0
    total_bits = 0
    for qt in pkg["quantized"].values():
        total_numel += int(math.prod(qt["shape"]))
        total_bits += storage_bits(qt["packed_indices"])
        if qt.get("centers") is not None:
            total_bits += storage_bits(qt["centers"])
        total_bits += storage_bits(qt["scales"]) + storage_bits(qt["codebook"])
        # Optional side information only counts when both halves of the pair are stored.
        for key_a, key_b in (("lowrank_A", "lowrank_B"), ("outlier_pos", "outlier_vals")):
            if qt.get(key_a) is not None and qt.get(key_b) is not None:
                total_bits += storage_bits(qt[key_a]) + storage_bits(qt[key_b])

    if total_numel == 0:
        return 0.0
    return total_bits / total_numel
|
|
|
|
def save_quantized_package(pkg: Dict[str, Any], output_path: str | Path) -> None:
    """Serialize ``pkg`` to ``output_path`` with ``torch.save``, creating missing parent directories."""
    destination = Path(output_path)
    parent_dir = destination.parent
    parent_dir.mkdir(exist_ok=True, parents=True)
    torch.save(pkg, destination)
|
|
|
|
def load_quantized_package(path: str | Path) -> Dict[str, Any]:
    """Load a package written by ``save_quantized_package``, mapping all tensors to the CPU.

    Packages contain only tensors, numbers, strings, tuples, dicts and ``None``, so
    ``weights_only=True`` is passed explicitly. This restricts unpickling to those types
    and prevents arbitrary code execution from a tampered file, including on torch
    versions whose default is ``weights_only=False``.
    """
    return torch.load(path, map_location="cpu", weights_only=True)
|
|
|
|
def save_report(pkg: Dict[str, Any], report_path: str | Path) -> None:
    """Write a JSON summary of a quantized package to ``report_path``.

    The report records:
        - the codec settings stored in the package, with defaults for optional keys;
        - the number of quantized and passthrough tensors;
        - the estimated bits per quantized weight.

    Missing parent directories are created first.
    """
    report = {
        "format": pkg["format"],
        "bits": int(pkg["bits"]),
        "block_size": int(pkg["block_size"]),
        "seed": int(pkg["seed"]),
        "lowrank_rank": int(pkg.get("lowrank_rank", 0)),
        "rotor_angle_scale": float(pkg.get("rotor_angle_scale", 1.0)),
        "rowwise": bool(pkg.get("rowwise", False)),
        "outlier_frac": float(pkg.get("outlier_frac", 0.0)),
        "num_quantized_tensors": len(pkg["quantized"]),
        "num_passthrough_tensors": len(pkg["passthrough"]),
        "estimated_bits_per_weight": estimate_bits_per_weight(pkg),
    }
    path = Path(report_path)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(report, indent=2), encoding="utf-8")
|
|