# RotorQuant-ModelWeights-Runtime / rotorquant_weights.py
from __future__ import annotations
import hashlib
import json
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Tuple, Any, Iterable
import torch
def _deterministic_rotor_matrix(
name: str,
seed: int,
device: torch.device,
dtype: torch.dtype,
angle_scale: float = 1.0,
) -> torch.Tensor:
key = f"{seed}:{name}".encode("utf-8")
digest = hashlib.sha256(key).digest()
def _u32(i: int) -> int:
return int.from_bytes(digest[4 * i : 4 * (i + 1)], "little")
v = torch.tensor([
(_u32(0) / (2**32 - 1)) * 2.0 - 1.0,
(_u32(1) / (2**32 - 1)) * 2.0 - 1.0,
(_u32(2) / (2**32 - 1)) * 2.0 - 1.0,
], device=device, dtype=torch.float64)
if torch.linalg.norm(v) < 1e-12:
v = torch.tensor([1.0, 0.0, 0.0], device=device, dtype=torch.float64)
axis = v / torch.linalg.norm(v)
angle_u = _u32(3) / (2**32 - 1)
angle = (2.0 * angle_u - 1.0) * (math.pi / 2.0) * angle_scale
K = torch.tensor(
[
[0.0, -axis[2].item(), axis[1].item()],
[axis[2].item(), 0.0, -axis[0].item()],
[-axis[1].item(), axis[0].item(), 0.0],
],
device=device,
dtype=torch.float64,
)
I = torch.eye(3, device=device, dtype=torch.float64)
R = I + math.sin(angle) * K + (1.0 - math.cos(angle)) * (K @ K)
return R.to(dtype=dtype)
def _lloyd_codebook(values: torch.Tensor, levels: int = 8, iters: int = 25, eps: float = 1e-6) -> torch.Tensor:
flat = values.reshape(-1)
if flat.numel() == 0:
return torch.linspace(-1.0, 1.0, levels, device=values.device, dtype=values.dtype)
if flat.numel() > 250_000:
idx = torch.randperm(flat.numel(), device=flat.device)[:250_000]
work = flat[idx]
else:
work = flat
probs = torch.linspace(0.0, 1.0, levels, device=work.device, dtype=torch.float32)
init = torch.quantile(work.float(), probs).to(work.dtype)
codebook = init
for _ in range(iters):
d = (work.unsqueeze(1) - codebook.unsqueeze(0)).abs()
assign = d.argmin(dim=1)
new_codebook = codebook.clone()
for k in range(levels):
sel = work[assign == k]
if sel.numel() > 0:
new_codebook[k] = sel.mean()
if torch.max((new_codebook - codebook).abs()) < eps:
codebook = new_codebook
break
codebook = new_codebook
codebook, _ = torch.sort(codebook)
return codebook
def _pack_3bit(indices: torch.Tensor) -> torch.Tensor:
x = indices.reshape(-1).to(torch.uint8).cpu()
n = x.numel()
if n == 0:
return torch.empty(0, dtype=torch.uint8)
n_full = (n // 8) * 8
full = x[:n_full].view(-1, 8).to(torch.int32)
packed24 = (
full[:, 0]
| (full[:, 1] << 3)
| (full[:, 2] << 6)
| (full[:, 3] << 9)
| (full[:, 4] << 12)
| (full[:, 5] << 15)
| (full[:, 6] << 18)
| (full[:, 7] << 21)
)
bytes_full = torch.stack(
[
(packed24 & 0xFF),
((packed24 >> 8) & 0xFF),
((packed24 >> 16) & 0xFF),
],
dim=1,
).reshape(-1).to(torch.uint8)
rem = x[n_full:]
if rem.numel() == 0:
return bytes_full
rem_acc = torch.tensor(0, dtype=torch.int32)
for i, v in enumerate(rem.tolist()):
rem_acc |= (int(v) & 0x7) << (3 * i)
rem_nbytes = (rem.numel() * 3 + 7) // 8
rem_bytes = torch.tensor([(rem_acc >> (8 * i)) & 0xFF for i in range(rem_nbytes)], dtype=torch.uint8)
return torch.cat([bytes_full, rem_bytes], dim=0)
def _unpack_3bit(packed: torch.Tensor, n_values: int, device: torch.device) -> torch.Tensor:
if n_values == 0:
return torch.empty(0, dtype=torch.uint8, device=device)
p = packed.to(torch.uint8).cpu()
n_full = (n_values // 8) * 8
n_full_groups = n_full // 8
n_full_bytes = n_full_groups * 3
out_parts = []
if n_full_groups > 0:
b = p[:n_full_bytes].view(-1, 3).to(torch.int32)
packed24 = b[:, 0] | (b[:, 1] << 8) | (b[:, 2] << 16)
vals = torch.stack(
[
(packed24 >> 0) & 0x7,
(packed24 >> 3) & 0x7,
(packed24 >> 6) & 0x7,
(packed24 >> 9) & 0x7,
(packed24 >> 12) & 0x7,
(packed24 >> 15) & 0x7,
(packed24 >> 18) & 0x7,
(packed24 >> 21) & 0x7,
],
dim=1,
).reshape(-1).to(torch.uint8)
out_parts.append(vals)
rem_n = n_values - n_full
if rem_n > 0:
rem_bytes = p[n_full_bytes:]
acc = 0
for i, v in enumerate(rem_bytes.tolist()):
acc |= int(v) << (8 * i)
rem_vals = torch.tensor([(acc >> (3 * i)) & 0x7 for i in range(rem_n)], dtype=torch.uint8)
out_parts.append(rem_vals)
return torch.cat(out_parts, dim=0).to(device=device)
@dataclass
class QuantizedTensor:
    """Container for one tensor's RotorQuant-compressed representation."""

    shape: Tuple[int, ...]  # original tensor shape, restored on dequantize
    n_rows: int  # number of independently quantized rows (1 unless rowwise)
    row_size: int  # original elements per row (last dim, or whole numel)
    row_rot_size: int  # row_size padded up to a multiple of 3 for the 3x3 rotor
    row_padded_size: int  # row_rot_size padded up to a multiple of block_size
    packed_indices: torch.Tensor  # uint8 stream of 3-bit codebook indices
    centers: torch.Tensor  # per-block means (float16), flattened across rows
    scales: torch.Tensor  # per-block max-abs scales (float16), flattened
    codebook: torch.Tensor  # shared code values (float16), sorted ascending
    lowrank_A: torch.Tensor | None = None  # optional residual factor, (n_rows, rank)
    lowrank_B: torch.Tensor | None = None  # optional residual factor, (rank, row_size)
    outlier_pos: torch.Tensor | None = None  # per-row indices of largest-error entries
    outlier_vals: torch.Tensor | None = None  # original values at outlier_pos (float16)
class RotorQuantWeightCodec:
    """Quantizes weight tensors to 3-bit codes after a deterministic rotation.

    Pipeline per tensor: pad rows to a multiple of 3, mix consecutive triples
    of elements with a name/seed-derived 3x3 rotation, split into blocks,
    normalize each block by its mean and max-abs, then map values onto an
    8-level Lloyd codebook. Optional extras: a low-rank correction of the
    reconstruction residual, and per-row storage of the worst-error
    "outlier" entries at full precision. Decoding reverses the pipeline
    (the rotation is orthogonal, so its transpose is its inverse).
    """

    def __init__(
        self,
        bits: int = 3,
        block_size: int = 128,
        seed: int = 1337,
        eps: float = 1e-8,
        lowrank_rank: int = 0,
        rotor_angle_scale: float = 1.0,
        rowwise: bool = False,
        outlier_frac: float = 0.0,
    ):
        if bits != 3:
            raise ValueError("Current prototype only implements 3-bit packing.")
        self.bits = bits
        self.block_size = block_size
        self.seed = seed
        self.eps = eps  # lower bound for block scales, avoids divide-by-zero
        self.lowrank_rank = lowrank_rank
        self.rotor_angle_scale = rotor_angle_scale
        self.rowwise = rowwise
        self.outlier_frac = outlier_frac

    def quantize_tensor(self, name: str, w: torch.Tensor) -> QuantizedTensor:
        """Quantize ``w`` and return its packed representation.

        ``name`` seeds the deterministic rotation, so the same name must be
        passed to :meth:`dequantize_tensor`.
        """
        x = w.float()
        # Choose the quantization rows: per-row for >=2-D tensors in rowwise
        # mode, otherwise the whole tensor flattened as a single row.
        if self.rowwise and x.ndim >= 2:
            n_rows = int(math.prod(x.shape[:-1]))
            row_size = x.shape[-1]
            rows_orig = x.reshape(n_rows, row_size)
        else:
            n_rows = 1
            row_size = x.numel()
            rows_orig = x.reshape(1, row_size)
        rows = rows_orig
        # Pad each row to a multiple of 3 so it can be viewed as 3-vectors.
        pad3 = (-row_size) % 3
        if pad3:
            rows = torch.cat([rows, torch.zeros(n_rows, pad3, device=rows.device, dtype=rows.dtype)], dim=1)
        row_rot_size = rows.shape[1]
        R = _deterministic_rotor_matrix(
            name=name,
            seed=self.seed,
            device=rows.device,
            dtype=rows.dtype,
            angle_scale=self.rotor_angle_scale,
        )
        # Rotate every consecutive triple of row elements.
        rot = (rows.reshape(n_rows, -1, 3) @ R.T).reshape(n_rows, row_rot_size)
        # Pad again so each row splits evenly into quantization blocks.
        pad_block = (-row_rot_size) % self.block_size
        if pad_block:
            rot = torch.cat([rot, torch.zeros(n_rows, pad_block, device=rot.device, dtype=rot.dtype)], dim=1)
        row_padded_size = rot.shape[1]
        n_blocks = row_padded_size // self.block_size
        blocks = rot.view(n_rows, n_blocks, self.block_size)
        centers = torch.zeros(n_rows, n_blocks, device=blocks.device, dtype=blocks.dtype)
        scales = torch.zeros(n_rows, n_blocks, device=blocks.device, dtype=blocks.dtype)
        normed = torch.zeros_like(blocks)
        # Fully-populated blocks are normalized vectorized; the final partial
        # block (if any) excludes the padded zeros from its statistics.
        full_blocks = row_rot_size // self.block_size
        tail = row_rot_size % self.block_size
        if full_blocks > 0:
            blk = blocks[:, :full_blocks, :]
            c = blk.mean(dim=-1)
            z = blk - c.unsqueeze(-1)
            s = z.abs().amax(dim=-1).clamp(min=self.eps)
            centers[:, :full_blocks] = c
            scales[:, :full_blocks] = s
            normed[:, :full_blocks, :] = z / s.unsqueeze(-1)
        if tail > 0:
            blk = blocks[:, full_blocks, :tail]
            c = blk.mean(dim=-1)
            z = blk - c.unsqueeze(-1)
            s = z.abs().amax(dim=-1).clamp(min=self.eps)
            centers[:, full_blocks] = c
            scales[:, full_blocks] = s
            normed[:, full_blocks, :tail] = z / s.unsqueeze(-1)
        centers = centers.reshape(-1)
        scales = scales.reshape(-1)
        normed = normed.reshape(-1, self.block_size)
        codebook = _lloyd_codebook(normed, levels=2**self.bits)
        # Nearest-codeword assignment, chunked to bound peak memory.
        idx_chunks = []
        chunk_blocks = 4096
        for i in range(0, normed.shape[0], chunk_blocks):
            b = normed[i : i + chunk_blocks]
            diffs = (b.unsqueeze(-1) - codebook.view(1, 1, -1)).abs()
            idx_chunks.append(diffs.argmin(dim=-1).to(torch.uint8))
        idx = torch.cat(idx_chunks, dim=0)
        packed = _pack_3bit(idx.reshape(-1).cpu())
        qt = QuantizedTensor(
            shape=tuple(w.shape),
            n_rows=n_rows,
            row_size=row_size,
            row_rot_size=row_rot_size,
            row_padded_size=row_padded_size,
            packed_indices=packed,
            centers=centers.cpu().to(torch.float16),
            scales=scales.cpu().to(torch.float16),
            codebook=codebook.cpu().to(torch.float16),
        )
        if self.lowrank_rank > 0 and n_rows > 1 and row_size > 1:
            # Dequantize (without the float16 storage round-trip) to get the
            # residual, then keep a rank-limited PCA factorization of it.
            deq_rows = (idx.to(torch.long).to(codebook.device).reshape(-1))
            deq_vals = codebook[deq_rows]
            deq_blocks = deq_vals.reshape(-1, self.block_size)
            deq_q = (deq_blocks * scales.unsqueeze(1) + centers.unsqueeze(1)).reshape(n_rows, row_padded_size)
            deq_q = deq_q[:, :row_rot_size]
            # R is orthogonal (R^-1 = R^T); encoding used @ R.T, so @ R
            # inverts it. Reuse the matrix computed above instead of
            # rebuilding the identical one.
            x_hat_rows = (deq_q.reshape(n_rows, -1, 3) @ R).reshape(n_rows, row_rot_size)[:, :row_size]
            residual = rows_orig - x_hat_rows
            rank = min(self.lowrank_rank, residual.shape[0], residual.shape[1])
            if rank > 0:
                U, S, V = torch.pca_lowrank(residual, q=rank, center=False, niter=2)
                A = (U[:, :rank] * S[:rank]).to(torch.float16).cpu()
                B = V[:, :rank].T.to(torch.float16).cpu()
                qt.lowrank_A = A
                qt.lowrank_B = B
        if self.outlier_frac > 0 and row_size > 0:
            # Keep the top-|error| entries of each row at full precision.
            deq_rows = self.dequantize_tensor(name, qt, device=torch.device("cpu"), dtype=torch.float32).reshape(n_rows, row_size)
            residual = (rows_orig - deq_rows).abs()
            k = max(1, int(row_size * self.outlier_frac))
            k = min(k, row_size)
            vals, pos = torch.topk(residual, k=k, dim=1, largest=True, sorted=False)
            out_vals = torch.gather(rows_orig, dim=1, index=pos)
            # int32, not int16: row_size can exceed 32767 for wide layers,
            # which would silently overflow the stored positions.
            qt.outlier_pos = pos.to(torch.int32).cpu()
            qt.outlier_vals = out_vals.to(torch.float16).cpu()
        return qt

    def dequantize_tensor(self, name: str, qt: QuantizedTensor, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        """Reconstruct the (lossy) tensor for ``qt`` on ``device`` as ``dtype``."""
        n_blocks = qt.n_rows * (qt.row_padded_size // self.block_size)
        n_values = n_blocks * self.block_size
        idx = _unpack_3bit(qt.packed_indices, n_values=n_values, device=device).long()
        codebook = qt.codebook.to(device=device, dtype=torch.float32)
        centers = qt.centers.to(device=device, dtype=torch.float32)
        scales = qt.scales.to(device=device, dtype=torch.float32)
        vals = codebook[idx]
        blocks = vals.reshape(-1, self.block_size)
        # Undo the per-block normalization, then drop block padding.
        deq_rows = (blocks * scales.unsqueeze(1) + centers.unsqueeze(1)).reshape(qt.n_rows, qt.row_padded_size)
        deq_rows = deq_rows[:, : qt.row_rot_size]
        R = _deterministic_rotor_matrix(
            name=name,
            seed=self.seed,
            device=device,
            dtype=torch.float32,
            angle_scale=self.rotor_angle_scale,
        )
        # Inverse rotation (encode used @ R.T), then drop the pad-to-3 tail.
        x_rows = (deq_rows.reshape(qt.n_rows, -1, 3) @ R).reshape(qt.n_rows, qt.row_rot_size)
        x_rows = x_rows[:, : qt.row_size]
        if qt.lowrank_A is not None and qt.lowrank_B is not None:
            A = qt.lowrank_A.to(device=device, dtype=torch.float32)
            B = qt.lowrank_B.to(device=device, dtype=torch.float32)
            x_rows = x_rows + (A @ B)
        if qt.outlier_pos is not None and qt.outlier_vals is not None:
            # Overwrite the stored outlier positions with their exact values.
            pos = qt.outlier_pos.to(device=device, dtype=torch.long)
            vals = qt.outlier_vals.to(device=device, dtype=torch.float32)
            x_rows.scatter_(dim=1, index=pos, src=vals)
        return x_rows.reshape(qt.shape).to(dtype=dtype)
def quantize_state_dict(
    state_dict: Dict[str, torch.Tensor],
    bits: int = 3,
    block_size: int = 128,
    seed: int = 1337,
    min_ndim: int = 2,
    verbose: bool = False,
    skip_names: Iterable[str] | None = None,
    lowrank_rank: int = 0,
    rotor_angle_scale: float = 1.0,
    rowwise: bool = False,
    include_if_name_contains: Iterable[str] | None = None,
    outlier_frac: float = 0.0,
) -> Dict[str, Any]:
    """Quantize an entire state dict into a serializable package dict.

    Tensors that are non-floating, have fewer than ``min_ndim`` dims, are
    listed in ``skip_names``, or (when ``include_if_name_contains`` is given)
    match none of those name fragments are stored uncompressed under
    ``"passthrough"``; everything else is RotorQuant-encoded under
    ``"quantized"``.
    """
    codec = RotorQuantWeightCodec(
        bits=bits,
        block_size=block_size,
        seed=seed,
        lowrank_rank=lowrank_rank,
        rotor_angle_scale=rotor_angle_scale,
        rowwise=rowwise,
        outlier_frac=outlier_frac,
    )
    pkg: Dict[str, Any] = {
        "format": "rotorquant_v1",
        "bits": bits,
        "block_size": block_size,
        "seed": seed,
        "lowrank_rank": lowrank_rank,
        "rotor_angle_scale": rotor_angle_scale,
        "rowwise": rowwise,
        "outlier_frac": outlier_frac,
        "quantized": {},
        "passthrough": {},
    }
    skipped = set(skip_names or [])
    fragments = list(include_if_name_contains or [])
    total = len(state_dict)
    for pos, (name, tensor) in enumerate(state_dict.items(), start=1):
        if verbose:
            print(f"[quantize] {pos}/{total} {name} shape={tuple(tensor.shape)}")
        # A tensor is quantized only if it passes both the name filter and
        # the dtype/shape/skip-list eligibility check.
        selected = not fragments or any(frag in name for frag in fragments)
        eligible = torch.is_floating_point(tensor) and tensor.ndim >= min_ndim and name not in skipped
        if not (selected and eligible):
            pkg["passthrough"][name] = tensor.cpu()
            continue
        qt = codec.quantize_tensor(name, tensor.detach().cpu())
        pkg["quantized"][name] = {
            "shape": qt.shape,
            "n_rows": qt.n_rows,
            "row_size": qt.row_size,
            "row_rot_size": qt.row_rot_size,
            "row_padded_size": qt.row_padded_size,
            "packed_indices": qt.packed_indices,
            "centers": qt.centers,
            "scales": qt.scales,
            "codebook": qt.codebook,
            "lowrank_A": qt.lowrank_A,
            "lowrank_B": qt.lowrank_B,
            "outlier_pos": qt.outlier_pos,
            "outlier_vals": qt.outlier_vals,
        }
    return pkg
def dequantize_to_state_dict(pkg: Dict[str, Any], dtype: torch.dtype = torch.float32, device: str = "cpu") -> Dict[str, torch.Tensor]:
    """Rebuild a plain state dict from a package made by ``quantize_state_dict``.

    Quantized entries are decoded through the codec; passthrough entries are
    moved to ``device`` (floating-point tensors are also cast to ``dtype``).
    Also accepts a legacy single-row layout that stored ``numel`` /
    ``padded_numel`` instead of the per-row size fields.
    """
    codec = RotorQuantWeightCodec(
        bits=pkg["bits"],
        block_size=pkg["block_size"],
        seed=pkg["seed"],
        lowrank_rank=int(pkg.get("lowrank_rank", 0)),
        rotor_angle_scale=float(pkg.get("rotor_angle_scale", 1.0)),
        rowwise=bool(pkg.get("rowwise", False)),
        outlier_frac=float(pkg.get("outlier_frac", 0.0)),
    )
    target = torch.device(device)
    result: Dict[str, torch.Tensor] = {}
    for name, entry in pkg["quantized"].items():
        if "n_rows" in entry:
            n_rows = int(entry["n_rows"])
            row_size = int(entry["row_size"])
            row_rot_size = int(entry["row_rot_size"])
            row_padded_size = int(entry["row_padded_size"])
        else:
            # Legacy layout: a single flat row, sizes derived from numel.
            numel = int(entry["numel"])
            n_rows = 1
            row_size = numel
            row_rot_size = ((numel + 2) // 3) * 3
            row_padded_size = int(entry["padded_numel"])
        centers = entry.get("centers")
        if centers is None:
            # Older packages stored no per-block centers; treat them as zero.
            n_blocks = n_rows * (row_padded_size // pkg["block_size"])
            centers = torch.zeros(n_blocks, dtype=torch.float16)
        qt = QuantizedTensor(
            shape=tuple(entry["shape"]),
            n_rows=n_rows,
            row_size=row_size,
            row_rot_size=row_rot_size,
            row_padded_size=row_padded_size,
            packed_indices=entry["packed_indices"],
            centers=centers,
            scales=entry["scales"],
            codebook=entry["codebook"],
            lowrank_A=entry.get("lowrank_A"),
            lowrank_B=entry.get("lowrank_B"),
            outlier_pos=entry.get("outlier_pos"),
            outlier_vals=entry.get("outlier_vals"),
        )
        result[name] = codec.dequantize_tensor(name, qt, device=target, dtype=dtype)
    for name, tensor in pkg["passthrough"].items():
        result[name] = tensor.to(device=target, dtype=(dtype if torch.is_floating_point(tensor) else tensor.dtype))
    return result
def estimate_bits_per_weight(pkg: Dict[str, Any]) -> float:
    """Average storage bits per original weight across all quantized tensors.

    Counts packed indices at 8 bits per byte and the float16 side tensors
    (centers, scales, codebook, low-rank factors, outlier tables) at 16 bits
    per element; passthrough tensors are excluded. Returns 0.0 when nothing
    was quantized.
    """
    weight_count = 0
    bit_count = 0
    for entry in pkg["quantized"].values():
        weight_count += int(math.prod(entry["shape"]))
        bit_count += int(entry["packed_indices"].numel()) * 8
        if entry.get("centers") is not None:
            bit_count += int(entry["centers"].numel()) * 16
        bit_count += int(entry["scales"].numel()) * 16
        bit_count += int(entry["codebook"].numel()) * 16
        if entry.get("lowrank_A") is not None and entry.get("lowrank_B") is not None:
            bit_count += int(entry["lowrank_A"].numel()) * 16
            bit_count += int(entry["lowrank_B"].numel()) * 16
        if entry.get("outlier_pos") is not None and entry.get("outlier_vals") is not None:
            bit_count += int(entry["outlier_pos"].numel()) * 16
            bit_count += int(entry["outlier_vals"].numel()) * 16
    if weight_count == 0:
        return 0.0
    return bit_count / weight_count
def save_quantized_package(pkg: Dict[str, Any], output_path: str | Path) -> None:
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(pkg, output_path)
def load_quantized_package(path: str | Path) -> Dict[str, Any]:
return torch.load(path, map_location="cpu")
def save_report(pkg: Dict[str, Any], report_path: str | Path) -> None:
    """Write a small JSON summary of a quantized package to ``report_path``."""
    summary = {
        "format": pkg["format"],
        "bits": int(pkg["bits"]),
        "block_size": int(pkg["block_size"]),
        "seed": int(pkg["seed"]),
        "lowrank_rank": int(pkg.get("lowrank_rank", 0)),
        "rotor_angle_scale": float(pkg.get("rotor_angle_scale", 1.0)),
        "rowwise": bool(pkg.get("rowwise", False)),
        "outlier_frac": float(pkg.get("outlier_frac", 0.0)),
        "num_quantized_tensors": len(pkg["quantized"]),
        "num_passthrough_tensors": len(pkg["passthrough"]),
        "estimated_bits_per_weight": estimate_bits_per_weight(pkg),
    }
    Path(report_path).write_text(json.dumps(summary, indent=2), encoding="utf-8")