"""Rotor-based 3-bit block quantization codec for PyTorch weight tensors.

Weights are grouped into 3-vectors, rotated by a deterministic per-tensor
rotation, normalised per block (center + scale), quantized against a shared
8-level codebook and bit-packed.  Optional low-rank and outlier side channels
can be stored next to the packed indices.
"""

from __future__ import annotations

import hashlib
import json
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, Tuple

import torch


def _deterministic_rotor_matrix(
    name: str,
    seed: int,
    device: torch.device,
    dtype: torch.dtype,
    angle_scale: float = 1.0,
) -> torch.Tensor:
    """Build a 3x3 rotation matrix derived from ``seed`` and ``name``.

    The SHA-256 digest of ``"{seed}:{name}"`` supplies four 32-bit words:
    three pick the rotation axis and the fourth picks the angle, which lies in
    ``[-pi/2, pi/2] * angle_scale``.  The matrix is assembled in float64 with
    Rodrigues' formula and cast to ``dtype`` at the end.

    Args:
        name: Tensor name mixed into the hash.
        seed: Integer seed mixed into the hash.
        device: Device of the returned matrix.
        dtype: Dtype of the returned matrix.
        angle_scale: Multiplier applied to the hashed angle.

    Returns:
        A ``(3, 3)`` tensor on ``device`` with dtype ``dtype``.
    """
    digest = hashlib.sha256(f"{seed}:{name}".encode("utf-8")).digest()
    u32_max = 2**32 - 1

    def _unit(i: int) -> float:
        # i-th little-endian 32-bit word of the digest, scaled to [0, 1].
        return int.from_bytes(digest[4 * i : 4 * (i + 1)], "little") / u32_max

    v = torch.tensor(
        [_unit(k) * 2.0 - 1.0 for k in range(3)],
        device=device,
        dtype=torch.float64,
    )
    norm = torch.linalg.norm(v)
    if norm < 1e-12:
        # Degenerate hash output: fall back to the x axis.
        v = torch.tensor([1.0, 0.0, 0.0], device=device, dtype=torch.float64)
        norm = torch.linalg.norm(v)
    axis = v / norm

    angle = (2.0 * _unit(3) - 1.0) * (math.pi / 2.0) * angle_scale

    ax, ay, az = axis.tolist()
    # Skew-symmetric cross-product matrix of the axis.
    K = torch.tensor(
        [
            [0.0, -az, ay],
            [az, 0.0, -ax],
            [-ay, ax, 0.0],
        ],
        device=device,
        dtype=torch.float64,
    )
    I = torch.eye(3, device=device, dtype=torch.float64)
    R = I + math.sin(angle) * K + (1.0 - math.cos(angle)) * (K @ K)
    return R.to(dtype=dtype)


def _lloyd_codebook(
    values: torch.Tensor,
    levels: int = 8,
    iters: int = 25,
    eps: float = 1e-6,
) -> torch.Tensor:
    """Fit a sorted 1-D codebook to ``values`` with Lloyd's algorithm.

    Args:
        values: Tensor of any shape; it is flattened.
        levels: Number of codebook entries.
        iters: Maximum number of Lloyd iterations.
        eps: Stop once no entry moves by ``eps`` or more.

    Returns:
        A sorted tensor of ``levels`` entries with the dtype/device of
        ``values``.  For empty input an even grid over ``[-1, 1]`` is returned.

    Note:
        Inputs with more than 250,000 elements are subsampled with
        ``torch.randperm``, which draws from the global RNG, so results for
        large tensors depend on the current torch random state.
    """
    flat = values.reshape(-1)
    if flat.numel() == 0:
        return torch.linspace(-1.0, 1.0, levels, device=values.device, dtype=values.dtype)

    if flat.numel() > 250_000:
        idx = torch.randperm(flat.numel(), device=flat.device)[:250_000]
        work = flat[idx]
    else:
        work = flat

    # Initialise the entries at evenly spaced quantiles of the data.
    probs = torch.linspace(0.0, 1.0, levels, device=work.device, dtype=torch.float32)
    codebook = torch.quantile(work.float(), probs).to(work.dtype)

    for _ in range(iters):
        assign = (work.unsqueeze(1) - codebook.unsqueeze(0)).abs().argmin(dim=1)
        updated = codebook.clone()
        for k in range(levels):
            members = work[assign == k]
            if members.numel() > 0:
                # Entries with no members keep their previous value.
                updated[k] = members.mean()
        shift = torch.max((updated - codebook).abs())
        codebook = updated
        if shift < eps:
            break

    codebook, _ = torch.sort(codebook)
    return codebook


def _pack_3bit(indices: torch.Tensor) -> torch.Tensor:
    """Pack 3-bit indices into a CPU ``uint8`` byte stream.

    Every group of 8 indices becomes 3 little-endian bytes (index ``j`` of a
    group occupies bits ``3j .. 3j+2``).  A trailing group of fewer than 8
    indices is packed the same way into ``ceil(3 * r / 8)`` bytes.

    Args:
        indices: Tensor of any shape; it is flattened and cast to ``uint8``.
            Only the low 3 bits of each value are stored.

    Returns:
        1-D ``uint8`` CPU tensor with ``ceil(3 * n / 8)`` bytes.
    """
    x = indices.reshape(-1).to(torch.uint8).cpu()
    n = x.numel()
    if n == 0:
        return torch.empty(0, dtype=torch.uint8)

    n_full = (n // 8) * 8
    parts = []

    # Guarded because reshaping an empty slice to (-1, 8) is ambiguous and
    # raises, which used to make inputs of 1..7 indices fail.
    if n_full > 0:
        # Mask to 3 bits so an out-of-range value cannot corrupt its neighbours
        # (the remainder path below masks the same way).
        full = x[:n_full].reshape(n_full // 8, 8).to(torch.int32) & 0x7
        packed24 = full[:, 0]
        for j in range(1, 8):
            packed24 = packed24 | (full[:, j] << (3 * j))
        parts.append(
            torch.stack(
                [
                    packed24 & 0xFF,
                    (packed24 >> 8) & 0xFF,
                    (packed24 >> 16) & 0xFF,
                ],
                dim=1,
            )
            .reshape(-1)
            .to(torch.uint8)
        )

    rem = x[n_full:].tolist()
    if rem:
        acc = 0
        for i, v in enumerate(rem):
            acc |= (int(v) & 0x7) << (3 * i)
        rem_nbytes = (len(rem) * 3 + 7) // 8
        parts.append(
            torch.tensor(
                [(acc >> (8 * i)) & 0xFF for i in range(rem_nbytes)],
                dtype=torch.uint8,
            )
        )

    return torch.cat(parts, dim=0)


def _unpack_3bit(packed: torch.Tensor, n_values: int, device: torch.device) -> torch.Tensor:
    """Inverse of :func:`_pack_3bit`.

    Args:
        packed: Byte stream produced by ``_pack_3bit``.
        n_values: Number of 3-bit indices to decode.
        device: Device of the returned tensor (decoding itself runs on CPU).

    Returns:
        1-D ``uint8`` tensor with ``n_values`` entries in ``[0, 7]``.

    Raises:
        ValueError: If ``packed`` holds fewer bytes than ``n_values`` needs.
    """
    if n_values == 0:
        return torch.empty(0, dtype=torch.uint8, device=device)

    p = packed.reshape(-1).to(torch.uint8).cpu()
    needed = (n_values * 3 + 7) // 8
    if p.numel() < needed:
        raise ValueError(
            f"Packed buffer has {p.numel()} bytes but {needed} are required "
            f"to decode {n_values} 3-bit values."
        )

    n_full_groups = n_values // 8
    n_full = n_full_groups * 8
    n_full_bytes = n_full_groups * 3
    out_parts = []

    if n_full_groups > 0:
        b = p[:n_full_bytes].reshape(n_full_groups, 3).to(torch.int32)
        packed24 = b[:, 0] | (b[:, 1] << 8) | (b[:, 2] << 16)
        vals = torch.stack(
            [(packed24 >> (3 * j)) & 0x7 for j in range(8)],
            dim=1,
        )
        out_parts.append(vals.reshape(-1).to(torch.uint8))

    rem_n = n_values - n_full
    if rem_n > 0:
        acc = 0
        for i, v in enumerate(p[n_full_bytes:needed].tolist()):
            acc |= int(v) << (8 * i)
        out_parts.append(
            torch.tensor(
                [(acc >> (3 * i)) & 0x7 for i in range(rem_n)],
                dtype=torch.uint8,
            )
        )

    return torch.cat(out_parts, dim=0).to(device=device)


@dataclass
class QuantizedTensor:
    """Container for one quantized tensor and its optional side channels."""

    # Shape of the original tensor.
    shape: Tuple[int, ...]
    # Number of rows the tensor was flattened into (1 unless row-wise).
    n_rows: int
    # Elements per row before any padding.
    row_size: int
    # Row length after padding to a multiple of 3 (rotation groups).
    row_rot_size: int
    # Row length after additionally padding to a multiple of the block size.
    row_padded_size: int
    # Bit-packed 3-bit codebook indices (uint8 byte stream).
    packed_indices: torch.Tensor
    # Per-block mean, flattened over (row, block).
    centers: torch.Tensor
    # Per-block max-abs deviation from the mean, flattened over (row, block).
    scales: torch.Tensor
    # Sorted codebook of normalised values.
    codebook: torch.Tensor
    # Optional low-rank correction factors; the correction is lowrank_A @ lowrank_B.
    lowrank_A: torch.Tensor | None = None
    lowrank_B: torch.Tensor | None = None
    # Optional per-row outlier column positions and their original values.
    outlier_pos: torch.Tensor | None = None
    outlier_vals: torch.Tensor | None = None


class RotorQuantWeightCodec:
    """Quantize / dequantize tensors with rotation + block-wise 3-bit coding."""

    def __init__(
        self,
        bits: int = 3,
        block_size: int = 128,
        seed: int = 1337,
        eps: float = 1e-8,
        lowrank_rank: int = 0,
        rotor_angle_scale: float = 1.0,
        rowwise: bool = False,
        outlier_frac: float = 0.0,
    ):
        """Configure the codec.

        Args:
            bits: Bits per index; only 3 is implemented.
            block_size: Number of values sharing one center/scale pair.
            seed: Seed mixed into the per-tensor rotation.
            eps: Lower bound applied to block scales.
            lowrank_rank: Rank of the optional low-rank residual correction
                (0 disables it).
            rotor_angle_scale: Multiplier for the rotation angle.
            rowwise: If true, tensors with ``ndim >= 2`` are processed per row
                of their last dimension instead of as one flat row.
            outlier_frac: Fraction of each row stored verbatim as outliers
                (0 disables it).

        Raises:
            ValueError: If ``bits != 3`` or ``block_size`` is not positive.
        """
        if bits != 3:
            raise ValueError("Current prototype only implements 3-bit packing.")
        if block_size < 1:
            raise ValueError("block_size must be a positive integer.")
        self.bits = bits
        self.block_size = block_size
        self.seed = seed
        self.eps = eps
        self.lowrank_rank = lowrank_rank
        self.rotor_angle_scale = rotor_angle_scale
        self.rowwise = rowwise
        self.outlier_frac = outlier_frac

    def _block_stats(self, blk: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return (center, scale, normalised values) along the last dimension."""
        center = blk.mean(dim=-1)
        deviation = blk - center.unsqueeze(-1)
        scale = deviation.abs().amax(dim=-1).clamp(min=self.eps)
        return center, scale, deviation / scale.unsqueeze(-1)

    def quantize_tensor(self, name: str, w: torch.Tensor) -> QuantizedTensor:
        """Quantize ``w`` into a :class:`QuantizedTensor`.

        Args:
            name: Tensor name; it selects the deterministic rotation and must
                be passed unchanged to :meth:`dequantize_tensor`.
            w: Tensor to quantize; it is processed in float32.

        Returns:
            The quantized representation.  Packed indices and all side-channel
            tensors are stored on the CPU; centers, scales, codebook, low-rank
            factors and outlier values are stored as float16.

        Raises:
            ValueError: If ``w`` has no elements.
        """
        x = w.detach().float()
        if x.numel() == 0:
            raise ValueError(f"Cannot quantize empty tensor {name!r}.")

        if self.rowwise and x.ndim >= 2:
            n_rows = int(math.prod(x.shape[:-1]))
            row_size = int(x.shape[-1])
        else:
            n_rows = 1
            row_size = x.numel()
        rows_orig = x.reshape(n_rows, row_size)

        # Pad each row to a multiple of 3 so it splits into 3-vectors.
        rows = rows_orig
        pad3 = (-row_size) % 3
        if pad3:
            rows = torch.cat([rows, rows.new_zeros(n_rows, pad3)], dim=1)
        row_rot_size = rows.shape[1]

        R = _deterministic_rotor_matrix(
            name=name,
            seed=self.seed,
            device=rows.device,
            dtype=rows.dtype,
            angle_scale=self.rotor_angle_scale,
        )
        rot = (rows.reshape(n_rows, row_rot_size // 3, 3) @ R.T).reshape(n_rows, row_rot_size)

        # Pad each rotated row to a whole number of blocks.
        pad_block = (-row_rot_size) % self.block_size
        if pad_block:
            rot = torch.cat([rot, rot.new_zeros(n_rows, pad_block)], dim=1)
        row_padded_size = rot.shape[1]
        n_blocks = row_padded_size // self.block_size
        blocks = rot.reshape(n_rows, n_blocks, self.block_size)

        centers = blocks.new_zeros(n_rows, n_blocks)
        scales = blocks.new_zeros(n_rows, n_blocks)
        normed = torch.zeros_like(blocks)

        full_blocks = row_rot_size // self.block_size
        tail = row_rot_size % self.block_size
        if full_blocks > 0:
            c, s, z = self._block_stats(blocks[:, :full_blocks, :])
            centers[:, :full_blocks] = c
            scales[:, :full_blocks] = s
            normed[:, :full_blocks, :] = z
        if tail > 0:
            # Statistics of the partial block ignore the block padding; the
            # padded slots keep a normalised value of 0.
            c, s, z = self._block_stats(blocks[:, full_blocks, :tail])
            centers[:, full_blocks] = c
            scales[:, full_blocks] = s
            normed[:, full_blocks, :tail] = z

        centers = centers.reshape(-1)
        scales = scales.reshape(-1)
        normed = normed.reshape(n_rows * n_blocks, self.block_size)

        codebook = _lloyd_codebook(normed, levels=2**self.bits)

        # Nearest-entry assignment, chunked to bound the size of the
        # (blocks, block_size, levels) distance tensor.
        chunk_blocks = 4096
        entries = codebook.view(1, 1, -1)
        idx = torch.cat(
            [
                (normed[i : i + chunk_blocks].unsqueeze(-1) - entries)
                .abs()
                .argmin(dim=-1)
                .to(torch.uint8)
                for i in range(0, normed.shape[0], chunk_blocks)
            ],
            dim=0,
        )
        packed = _pack_3bit(idx.reshape(-1).cpu())

        qt = QuantizedTensor(
            shape=tuple(w.shape),
            n_rows=n_rows,
            row_size=row_size,
            row_rot_size=row_rot_size,
            row_padded_size=row_padded_size,
            packed_indices=packed,
            centers=centers.cpu().to(torch.float16),
            scales=scales.cpu().to(torch.float16),
            codebook=codebook.cpu().to(torch.float16),
        )

        if self.lowrank_rank > 0 and n_rows > 1 and row_size > 1:
            # Fit the correction to what the decoder reconstructs from the
            # stored (float16) centers, scales and codebook.
            base = self.dequantize_tensor(
                name, qt, device=rows_orig.device, dtype=torch.float32
            ).reshape(n_rows, row_size)
            residual = rows_orig - base
            rank = min(self.lowrank_rank, residual.shape[0], residual.shape[1])
            if rank > 0:
                U, S, V = torch.pca_lowrank(residual, q=rank, center=False, niter=2)
                qt.lowrank_A = (U[:, :rank] * S[:rank]).to(torch.float16).cpu()
                qt.lowrank_B = V[:, :rank].T.to(torch.float16).cpu()

        if self.outlier_frac > 0 and row_size > 0:
            # Dequantize on the input's device so the subtraction below does
            # not mix devices.
            deq_rows = self.dequantize_tensor(
                name, qt, device=rows_orig.device, dtype=torch.float32
            ).reshape(n_rows, row_size)
            err = (rows_orig - deq_rows).abs()
            k = min(max(1, int(row_size * self.outlier_frac)), row_size)
            _, pos = torch.topk(err, k=k, dim=1, largest=True, sorted=False)
            # int16 cannot index rows longer than 32768 elements; widen the
            # storage dtype instead of letting positions wrap around.
            fits_int16 = row_size - 1 <= torch.iinfo(torch.int16).max
            pos_dtype = torch.int16 if fits_int16 else torch.int32
            qt.outlier_vals = torch.gather(rows_orig, dim=1, index=pos).to(torch.float16).cpu()
            qt.outlier_pos = pos.to(pos_dtype).cpu()

        return qt

    def dequantize_tensor(
        self,
        name: str,
        qt: QuantizedTensor,
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Reconstruct a tensor from its quantized form.

        Args:
            name: Tensor name used when quantizing (selects the rotation).
            qt: Quantized representation.
            device: Device on which to reconstruct.
            dtype: Dtype of the returned tensor.

        Returns:
            Tensor of shape ``qt.shape``.  The low-rank correction is added and
            outlier values are written back when those fields are present.
        """
        n_blocks = qt.n_rows * (qt.row_padded_size // self.block_size)
        n_values = n_blocks * self.block_size
        idx = _unpack_3bit(qt.packed_indices, n_values=n_values, device=device).long()

        codebook = qt.codebook.to(device=device, dtype=torch.float32)
        centers = qt.centers.to(device=device, dtype=torch.float32)
        scales = qt.scales.to(device=device, dtype=torch.float32)

        blocks = codebook[idx].reshape(n_blocks, self.block_size)
        deq_rows = (blocks * scales.unsqueeze(1) + centers.unsqueeze(1)).reshape(
            qt.n_rows, qt.row_padded_size
        )
        deq_rows = deq_rows[:, : qt.row_rot_size]

        R = _deterministic_rotor_matrix(
            name=name,
            seed=self.seed,
            device=device,
            dtype=torch.float32,
            angle_scale=self.rotor_angle_scale,
        )
        # Quantization multiplied by R.T; multiplying by R undoes it.
        x_rows = (deq_rows.reshape(qt.n_rows, qt.row_rot_size // 3, 3) @ R).reshape(
            qt.n_rows, qt.row_rot_size
        )
        x_rows = x_rows[:, : qt.row_size]

        if qt.lowrank_A is not None and qt.lowrank_B is not None:
            A = qt.lowrank_A.to(device=device, dtype=torch.float32)
            B = qt.lowrank_B.to(device=device, dtype=torch.float32)
            x_rows = x_rows + (A @ B)

        if qt.outlier_pos is not None and qt.outlier_vals is not None:
            pos = qt.outlier_pos.to(device=device, dtype=torch.long)
            outliers = qt.outlier_vals.to(device=device, dtype=torch.float32)
            x_rows.scatter_(dim=1, index=pos, src=outliers)

        return x_rows.reshape(qt.shape).to(dtype=dtype)


def quantize_state_dict(
    state_dict: Dict[str, torch.Tensor],
    bits: int = 3,
    block_size: int = 128,
    seed: int = 1337,
    min_ndim: int = 2,
    verbose: bool = False,
    skip_names: Iterable[str] | None = None,
    lowrank_rank: int = 0,
    rotor_angle_scale: float = 1.0,
    rowwise: bool = False,
    include_if_name_contains: Iterable[str] | None = None,
    outlier_frac: float = 0.0,
) -> Dict[str, Any]:
    """Quantize every eligible tensor of a state dict into a package dict.

    A tensor is stored unmodified under ``"passthrough"`` when it is not
    floating point, has fewer than ``min_ndim`` dimensions, has no elements,
    is listed in ``skip_names``, or (when ``include_if_name_contains`` is
    non-empty) its name contains none of the given fragments.  Everything else
    is quantized and stored under ``"quantized"``.

    Args:
        state_dict: Mapping from tensor name to tensor.
        bits, block_size, seed, lowrank_rank, rotor_angle_scale, rowwise,
            outlier_frac: Forwarded to :class:`RotorQuantWeightCodec` and
            recorded in the package.
        min_ndim: Minimum number of dimensions for a tensor to be quantized.
        verbose: Print one progress line per tensor.
        skip_names: Exact names to pass through.
        include_if_name_contains: If non-empty, only names containing one of
            these fragments are quantized.

    Returns:
        Package dict with the codec settings plus ``"quantized"`` and
        ``"passthrough"`` sub-dicts.
    """
    codec = RotorQuantWeightCodec(
        bits=bits,
        block_size=block_size,
        seed=seed,
        lowrank_rank=lowrank_rank,
        rotor_angle_scale=rotor_angle_scale,
        rowwise=rowwise,
        outlier_frac=outlier_frac,
    )
    out: Dict[str, Any] = {
        "format": "rotorquant_v1",
        "bits": bits,
        "block_size": block_size,
        "seed": seed,
        "lowrank_rank": lowrank_rank,
        "rotor_angle_scale": rotor_angle_scale,
        "rowwise": rowwise,
        "outlier_frac": outlier_frac,
        "quantized": {},
        "passthrough": {},
    }
    skipped = set(skip_names or [])
    include_fragments = list(include_if_name_contains or [])
    total = len(state_dict)

    for i, (name, t) in enumerate(state_dict.items(), start=1):
        if verbose:
            print(f"[quantize] {i}/{total} {name} shape={tuple(t.shape)}")

        excluded = bool(include_fragments) and not any(
            frag in name for frag in include_fragments
        )
        if (
            excluded
            or not torch.is_floating_point(t)
            or t.ndim < min_ndim
            or t.numel() == 0  # nothing to quantize; the codec rejects empty tensors
            or name in skipped
        ):
            out["passthrough"][name] = t.detach().cpu()
            continue

        qt = codec.quantize_tensor(name, t.detach().cpu())
        out["quantized"][name] = {
            "shape": qt.shape,
            "n_rows": qt.n_rows,
            "row_size": qt.row_size,
            "row_rot_size": qt.row_rot_size,
            "row_padded_size": qt.row_padded_size,
            "packed_indices": qt.packed_indices,
            "centers": qt.centers,
            "scales": qt.scales,
            "codebook": qt.codebook,
            "lowrank_A": qt.lowrank_A,
            "lowrank_B": qt.lowrank_B,
            "outlier_pos": qt.outlier_pos,
            "outlier_vals": qt.outlier_vals,
        }
    return out


def dequantize_to_state_dict(
    pkg: Dict[str, Any],
    dtype: torch.dtype = torch.float32,
    device: str = "cpu",
) -> Dict[str, torch.Tensor]:
    """Rebuild a state dict from a package made by :func:`quantize_state_dict`.

    Entries without an ``"n_rows"`` key are read through the ``"numel"`` /
    ``"padded_numel"`` keys instead and treated as a single row; missing
    ``"centers"`` default to zeros.

    Args:
        pkg: Package dict.
        dtype: Dtype for reconstructed and floating-point passthrough tensors.
        device: Target device string.

    Returns:
        Mapping from tensor name to tensor on ``device``.
    """
    codec = RotorQuantWeightCodec(
        bits=pkg["bits"],
        block_size=pkg["block_size"],
        seed=pkg["seed"],
        lowrank_rank=int(pkg.get("lowrank_rank", 0)),
        rotor_angle_scale=float(pkg.get("rotor_angle_scale", 1.0)),
        rowwise=bool(pkg.get("rowwise", False)),
        outlier_frac=float(pkg.get("outlier_frac", 0.0)),
    )
    out: Dict[str, torch.Tensor] = {}
    dev = torch.device(device)

    for name, qt_raw in pkg["quantized"].items():
        if "n_rows" in qt_raw:
            n_rows = int(qt_raw["n_rows"])
            row_size = int(qt_raw["row_size"])
            row_rot_size = int(qt_raw["row_rot_size"])
            row_padded_size = int(qt_raw["row_padded_size"])
        else:
            numel = int(qt_raw["numel"])
            n_rows = 1
            row_size = numel
            row_rot_size = ((numel + 2) // 3) * 3
            row_padded_size = int(qt_raw["padded_numel"])

        centers = qt_raw.get("centers")
        if centers is None:
            n_blocks = n_rows * (row_padded_size // pkg["block_size"])
            centers = torch.zeros(n_blocks, dtype=torch.float16)

        qt = QuantizedTensor(
            shape=tuple(qt_raw["shape"]),
            n_rows=n_rows,
            row_size=row_size,
            row_rot_size=row_rot_size,
            row_padded_size=row_padded_size,
            packed_indices=qt_raw["packed_indices"],
            centers=centers,
            scales=qt_raw["scales"],
            codebook=qt_raw["codebook"],
            lowrank_A=qt_raw.get("lowrank_A"),
            lowrank_B=qt_raw.get("lowrank_B"),
            outlier_pos=qt_raw.get("outlier_pos"),
            outlier_vals=qt_raw.get("outlier_vals"),
        )
        out[name] = codec.dequantize_tensor(name, qt, device=dev, dtype=dtype)

    for name, t in pkg["passthrough"].items():
        # Only floating-point passthrough tensors are cast to ``dtype``.
        target_dtype = dtype if torch.is_floating_point(t) else t.dtype
        out[name] = t.to(device=dev, dtype=target_dtype)
    return out


def _tensor_bits(t: torch.Tensor) -> int:
    """Return the storage size of ``t`` in bits, based on its actual dtype."""
    return int(t.numel()) * int(t.element_size()) * 8


def estimate_bits_per_weight(pkg: Dict[str, Any]) -> float:
    """Estimate stored bits per original weight over all quantized tensors.

    Counts packed indices, centers, scales, codebook and, when both halves of
    the pair are present, the low-rank factors and the outlier tensors.  Sizes
    come from each tensor's dtype, so widened outlier positions are counted
    correctly.  Passthrough tensors are not included.

    Returns:
        Total bits divided by total original element count, or ``0.0`` if the
        package has no quantized elements.
    """
    total_numel = 0
    total_bits = 0
    for qt in pkg["quantized"].values():
        total_numel += int(math.prod(qt["shape"]))
        total_bits += _tensor_bits(qt["packed_indices"])
        if qt.get("centers") is not None:
            total_bits += _tensor_bits(qt["centers"])
        total_bits += _tensor_bits(qt["scales"])
        total_bits += _tensor_bits(qt["codebook"])
        if qt.get("lowrank_A") is not None and qt.get("lowrank_B") is not None:
            total_bits += _tensor_bits(qt["lowrank_A"])
            total_bits += _tensor_bits(qt["lowrank_B"])
        if qt.get("outlier_pos") is not None and qt.get("outlier_vals") is not None:
            total_bits += _tensor_bits(qt["outlier_pos"])
            total_bits += _tensor_bits(qt["outlier_vals"])
    if total_numel == 0:
        return 0.0
    return total_bits / total_numel


def save_quantized_package(pkg: Dict[str, Any], output_path: str | Path) -> None:
    """Write ``pkg`` to ``output_path`` with ``torch.save``, creating parent dirs."""
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(pkg, output_path)


def load_quantized_package(path: str | Path) -> Dict[str, Any]:
    """Load a package written by :func:`save_quantized_package` onto the CPU."""
    # SECURITY: torch.load is pickle-based; depending on the installed torch
    # version it may execute arbitrary code from the file.  Only load packages
    # from trusted sources.
    return torch.load(path, map_location="cpu")


def save_report(pkg: Dict[str, Any], report_path: str | Path) -> None:
    """Write a JSON summary of the package settings and size to ``report_path``."""
    report = {
        "format": pkg["format"],
        "bits": int(pkg["bits"]),
        "block_size": int(pkg["block_size"]),
        "seed": int(pkg["seed"]),
        "lowrank_rank": int(pkg.get("lowrank_rank", 0)),
        "rotor_angle_scale": float(pkg.get("rotor_angle_scale", 1.0)),
        "rowwise": bool(pkg.get("rowwise", False)),
        "outlier_frac": float(pkg.get("outlier_frac", 0.0)),
        "num_quantized_tensors": len(pkg["quantized"]),
        "num_passthrough_tensors": len(pkg["passthrough"]),
        "estimated_bits_per_weight": estimate_bits_per_weight(pkg),
    }
    Path(report_path).write_text(json.dumps(report, indent=2), encoding="utf-8")