# Source: ternary-quant-demo / ternary_quant/quantizer_v2.py
# Bundled ternary_quant package directly (private repo fix) — commit 162f86a (verified)
"""
Advanced ternary quantizer: GPTQ-Ternary + Per-Group Scales + Low-Rank Residual.
Three techniques combined to make ternary work on small models:
1. GPTQ-style Hessian error compensation:
When ternarizing column j, redistribute error to columns j+1..n
using the Hessian inverse. This prevents error accumulation within a layer.
2. Per-group scales:
Instead of 1 scale per row, use 1 per group of g columns.
Adds ~0.12 bits/param at g=128 but captures within-row variance.
3. Low-rank residual correction:
After ternarizing, compute R = W - W_ternary and approximate R
with a rank-r SVD. The low-rank part captures fine structure
that ternary fundamentally cannot.
Final formula:
W ≈ diag(alpha_groups) * T + U @ V^T
where T is ternary, alpha_groups is per-group FP16 scales,
and U@V^T is the rank-r residual correction.
"""
import torch
import torch.nn as nn
from dataclasses import dataclass
from typing import Optional
@dataclass
class TernaryV2Parameter:
    """Advanced ternary representation with group scales + low-rank residual."""

    ternary_codes: torch.Tensor   # int8 {-1, 0, +1}, shape [out, in]
    group_scales: torch.Tensor    # FP16, shape [out, n_groups]
    group_size: int               # columns per group
    lr_U: Optional[torch.Tensor]  # FP16, shape [out, rank] (low-rank residual)
    lr_V: Optional[torch.Tensor]  # FP16, shape [rank, in] (low-rank residual)
    original_shape: tuple
    original_dtype: torch.dtype

    def dequantize(self) -> torch.Tensor:
        """Reconstruct the dense weight: group_alpha * T + U @ V."""
        _, n_cols = self.original_shape
        codes = self.ternary_codes.float()
        # Repeat each group scale across its columns to build a full
        # [out, in] alpha matrix; slice off the overhang when in_features
        # is not a multiple of the group size.
        scales = self.group_scales.float()  # [out, n_groups]
        per_column = scales.repeat_interleave(self.group_size, dim=1)[:, :n_cols]
        recon = per_column * codes
        # Apply the low-rank correction when both factors are present.
        if self.lr_U is not None and self.lr_V is not None:
            recon = recon + self.lr_U.float() @ self.lr_V.float()
        return recon

    @property
    def num_params(self) -> int:
        rows, cols = self.original_shape
        return rows * cols

    @property
    def effective_bits(self) -> float:
        rows, cols = self.original_shape
        n_groups = self.group_scales.shape[1]
        # Ternary codes cost 2 bits apiece; group scales are FP16.
        total_bits = 2 * rows * cols
        total_bits += 16 * rows * n_groups
        # Low-rank factors (when enabled) are stored as FP16 values.
        if self.lr_U is not None:
            rank = self.lr_U.shape[1]
            total_bits += 16 * (rows * rank + rank * cols)
        return total_bits / (rows * cols)
class GPTQTernaryQuantizer:
    """
    GPTQ-style ternary quantizer with three key innovations:
    1. Hessian-based error compensation (from GPTQ/OBQ)
    2. Per-group scale factors (finer than per-row)
    3. Low-rank residual correction (SVD of quantization error)
    """

    def __init__(
        self,
        group_size: int = 128,
        lr_rank: int = 32,
        block_size: int = 128,
        damp_percent: float = 0.01,
    ):
        """
        Args:
            group_size: Number of columns per scale group. Smaller = better quality
                but more overhead. 128 adds ~0.12 bits/param. 0 = per-row (no groups).
            lr_rank: Rank of low-rank residual correction. 0 = disable.
                32 adds ~0.5-1.0 bits/param but captures fine structure.
            block_size: GPTQ block size for lazy batch updates. 128 is standard.
            damp_percent: Hessian damping factor (fraction of mean diagonal).
        """
        self.group_size = group_size
        self.lr_rank = lr_rank
        self.block_size = block_size
        self.damp_percent = damp_percent

    def quantize(
        self,
        weight: torch.Tensor,
        activations: Optional[torch.Tensor] = None,
    ) -> TernaryV2Parameter:
        """
        Quantize a weight matrix using GPTQ-Ternary + group scales + low-rank.

        Args:
            weight: [out_features, in_features] weight matrix
            activations: [n_samples, in_features] calibration activations.
                Required for Hessian computation (falls back to identity
                Hessian — i.e. plain round-to-nearest — when omitted).

        Returns:
            TernaryV2Parameter
        """
        # Work on a float32 copy; `weight` itself is never mutated.
        W = weight.float().clone()
        original_shape = weight.shape
        original_dtype = weight.dtype
        out_features, in_features = W.shape
        device = W.device
        # Determine group size (group_size == 0 means one group per row)
        gs = self.group_size if self.group_size > 0 else in_features
        n_groups = (in_features + gs - 1) // gs  # ceil division
        # Save original for residual computation — W is destructively
        # updated by the error-compensation loop below.
        W_original = W.clone()
        # --- Step 1: Compute Hessian ---
        H = self._compute_hessian(W, activations)
        # --- Step 2: GPTQ-style column-wise quantization with error compensation ---
        T = torch.zeros_like(W, dtype=torch.int8, device=device)
        group_scales = torch.zeros(
            out_features,
            n_groups,
            dtype=torch.float32,
            device=device,
        )
        # Compute Hessian inverse via Cholesky
        try:
            H_inv = torch.linalg.cholesky(H)
            H_inv = torch.cholesky_inverse(H_inv)
        except RuntimeError:
            # Fallback: add more damping (10% of mean diagonal) and invert directly
            damp = 0.1 * torch.diag(H).mean()
            H_inv = torch.linalg.inv(H + damp * torch.eye(in_features, device=W.device))
        # NOTE(review): Losses is accumulated below but never read or returned —
        # diagnostic leftover; harmless.
        Losses = torch.zeros(out_features, device=device)
        block_size = min(self.block_size, in_features)
        # Process in blocks: within a block, errors propagate column-by-column;
        # across blocks they propagate in one batched matmul (lazy update).
        for block_start in range(0, in_features, block_size):
            block_end = min(block_start + block_size, in_features)
            # Error accumulator for batch update
            Err = torch.zeros(out_features, block_end - block_start, device=W.device)
            for j in range(block_start, block_end):
                w_col = W[:, j]  # [out_features]
                h_jj = H_inv[j, j]
                # Determine which group this column belongs to
                g_idx = j // gs
                # Compute group scale if this is the first column of a new group.
                # Scales are taken from the *updated* W, so compensation applied
                # to earlier columns is reflected in later groups' scales.
                if j % gs == 0:
                    g_end = min(j + gs, in_features)
                    group_cols = W[:, j:g_end]
                    # Scale = mean absolute value of the group (clamped to avoid
                    # division by zero for all-zero groups)
                    group_scales[:, g_idx] = group_cols.abs().mean(dim=1).clamp(min=1e-8)
                alpha = group_scales[:, g_idx]  # [out_features]
                # Ternarize this column: round w/alpha to nearest {-1, 0, +1}
                z = w_col / alpha
                t = torch.zeros_like(z, dtype=torch.int8)
                t[z > 0.5] = 1
                t[z < -0.5] = -1
                T[:, j] = t
                # Quantized value
                w_q = alpha * t.float()
                # Quantization error normalized by the inverse-Hessian diagonal
                delta = (w_col - w_q) / h_jj
                # Track loss
                Losses += (w_col - w_q) ** 2 / h_jj
                # Store error for batch update
                Err[:, j - block_start] = delta
                # Update remaining columns in this block
                if j + 1 < block_end:
                    W[:, j + 1 : block_end] -= (
                        delta.unsqueeze(1)
                        * H_inv[j, j + 1 : block_end].unsqueeze(0)
                    )
            # Batch update: propagate block errors to all remaining columns
            if block_end < in_features:
                W[:, block_end:] -= (
                    Err @ H_inv[block_start:block_end, block_end:]
                )
        # --- Step 3: Low-rank residual correction ---
        lr_U = None
        lr_V = None
        if self.lr_rank > 0:
            # Compute residual: what ternary couldn't capture
            W_ternary = self._dequantize_ternary(T, group_scales, gs, in_features)
            residual = W_original - W_ternary
            # SVD of residual, keep top-r singular values
            rank = min(self.lr_rank, min(out_features, in_features))
            try:
                U, S, Vh = torch.linalg.svd(residual, full_matrices=False)
                # Fold singular values into U so lr_U @ lr_V reconstructs directly
                lr_U = (U[:, :rank] * S[:rank].unsqueeze(0)).to(torch.float16)
                lr_V = Vh[:rank, :].to(torch.float16)
            except RuntimeError:
                # SVD can fail on very ill-conditioned matrices
                pass
        return TernaryV2Parameter(
            ternary_codes=T,
            group_scales=group_scales.to(torch.float16),
            group_size=gs,
            lr_U=lr_U,
            lr_V=lr_V,
            original_shape=original_shape,
            original_dtype=original_dtype,
        )

    def _compute_hessian(
        self,
        W: torch.Tensor,
        activations: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """
        Compute the Hessian H = X^T X / n_samples + damping.

        The Hessian captures input correlations: H[i,j] tells us how
        correlated input dimensions i and j are. This is crucial for
        knowing how to redistribute quantization error.
        """
        in_features = W.shape[1]
        device = W.device
        if activations is not None:
            X = activations.float().to(device)
            n = X.shape[0]
            H = (X.T @ X) / n
        else:
            # Fallback: identity (no correlation info, degrades to simple RTN)
            H = torch.eye(in_features, device=device)
        # Damping for numerical stability
        damp = self.damp_percent * torch.diag(H).mean()
        H += damp * torch.eye(in_features, device=device)
        return H

    def _dequantize_ternary(
        self,
        T: torch.Tensor,
        group_scales: torch.Tensor,
        gs: int,
        in_features: int,
    ) -> torch.Tensor:
        """Reconstruct weight from ternary codes and group scales (no low-rank)."""
        # Expand [out, n_groups] scales to [out, in] and apply to the codes;
        # the slice trims the overhang when in_features % gs != 0.
        alpha_expanded = group_scales.float().repeat_interleave(gs, dim=1)[:, :in_features]
        return alpha_expanded * T.float()
class TernaryV2Linear(nn.Module):
    """
    Drop-in replacement for nn.Linear using the V2 ternary representation.

    Forward: output = x @ (alpha_g * T + U @ V)^T + bias
           = alpha_g * (x @ T^T) + x @ V^T @ U^T + bias
    """

    def __init__(
        self,
        param: TernaryV2Parameter,
        bias: Optional[torch.Tensor] = None,
    ):
        super().__init__()
        n_out, n_in = param.original_shape
        # Quantized state lives in buffers so it moves with .to()/.cuda()
        # and is captured by state_dict.
        self.register_buffer("ternary_codes", param.ternary_codes)
        self.register_buffer("group_scales", param.group_scales)
        self.group_size = param.group_size
        if param.lr_U is None:
            self.lr_U = None
            self.lr_V = None
        else:
            self.register_buffer("lr_U", param.lr_U)
            self.register_buffer("lr_V", param.lr_V)
        if bias is None:
            self.bias = None
        else:
            self.register_buffer("bias", bias.float())
        self.out_features = n_out
        self.in_features = n_in

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        compute_dtype = x.dtype
        # Materialize W_hat = alpha_g * T + U @ V, then a single matmul:
        # output = x @ W_hat^T
        codes = self.ternary_codes.to(compute_dtype)
        # Broadcast group scales over their columns: [out, n_groups] -> [out, in]
        scales = self.group_scales.to(compute_dtype)
        alpha = scales.repeat_interleave(self.group_size, dim=1)[:, :self.in_features]
        W_hat = alpha * codes
        # Fold in the low-rank correction when present
        if self.lr_U is not None:
            W_hat = W_hat + self.lr_U.to(compute_dtype) @ self.lr_V.to(compute_dtype)
        result = nn.functional.linear(x, W_hat)
        if self.bias is not None:
            result = result + self.bias.to(compute_dtype)
        return result

    def extra_repr(self) -> str:
        rank_str = "" if self.lr_U is None else f", lr_rank={self.lr_U.shape[1]}"
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"bias={self.bias is not None}, group_size={self.group_size}"
            f"{rank_str}, bits~={self._effective_bits():.2f}"
        )

    def _effective_bits(self) -> float:
        """Average stored bits per original weight (codes + scales + low-rank)."""
        n_weights = self.out_features * self.in_features
        total = 2 * n_weights  # ternary codes: 2 bits each
        total += 16 * self.out_features * self.group_scales.shape[1]  # FP16 scales
        if self.lr_U is not None:
            rank = self.lr_U.shape[1]
            total += 16 * (self.out_features * rank + rank * self.in_features)
        return total / n_weights
def compute_v2_error(weight: torch.Tensor, param: TernaryV2Parameter) -> dict:
    """Compute quantization error metrics (MSE, relative error, sparsity) for V2 parameters."""
    reference = weight.float()
    approx = param.dequantize()
    diff = reference - approx
    mse = (diff ** 2).mean().item()
    rmse = mse ** 0.5
    # RMS of the original weights (epsilon guards against division by zero)
    rms_w = (reference.norm().item() / (reference.numel() ** 0.5) + 1e-8)
    codes = param.ternary_codes
    return {
        "mse": mse,
        "rmse": rmse,
        "relative_error": rmse / rms_w,
        "max_error": diff.abs().max().item(),
        "sparsity": (codes == 0).float().mean().item(),
        "effective_bits": param.effective_bits,
    }