Spaces:
Running
Running
| """ | |
| Advanced ternary quantizer: GPTQ-Ternary + Per-Group Scales + Low-Rank Residual. | |
| Three techniques combined to make ternary work on small models: | |
| 1. GPTQ-style Hessian error compensation: | |
| When ternarizing column j, redistribute error to columns j+1..n | |
| using the Hessian inverse. This prevents error accumulation within a layer. | |
| 2. Per-group scales: | |
| Instead of 1 scale per row, use 1 per group of g columns. | |
| Adds ~0.12 bits/param at g=128 but captures within-row variance. | |
| 3. Low-rank residual correction: | |
| After ternarizing, compute R = W - W_ternary and approximate R | |
| with a rank-r SVD. The low-rank part captures fine structure | |
| that ternary fundamentally cannot. | |
| Final formula: | |
| W ≈ diag(alpha_groups) * T + U @ V^T | |
| where T is ternary, alpha_groups is per-group FP16 scales, | |
| and U@V^T is the rank-r residual correction. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| from dataclasses import dataclass | |
| from typing import Optional | |
@dataclass
class TernaryV2Parameter:
    """Advanced ternary representation with group scales + low-rank residual.

    Stores a weight matrix as ``W ≈ alpha_groups * T + U @ V`` where ``T`` is
    a ternary code matrix, ``alpha_groups`` holds one FP16 scale per group of
    ``group_size`` columns, and ``U @ V`` is an optional low-rank correction.

    Bug fix: the ``@dataclass`` decorator was missing, so the keyword-argument
    construction used by the quantizer raised ``TypeError``.
    """

    # int8 codes in {-1, 0, +1}, shape [out, in]
    ternary_codes: torch.Tensor
    # FP16 scales, shape [out, n_groups]
    group_scales: torch.Tensor
    # number of columns covered by each scale group
    group_size: int
    # optional FP16 low-rank residual factors: U is [out, rank], V is [rank, in]
    lr_U: Optional[torch.Tensor]
    lr_V: Optional[torch.Tensor]
    # original (out_features, in_features) shape and dtype of the weight
    original_shape: tuple
    original_dtype: torch.dtype

    def dequantize(self) -> torch.Tensor:
        """Reconstruct the approximation: group_alpha * T + U @ V (float32)."""
        out_f, in_f = self.original_shape
        T = self.ternary_codes.float()
        gs = self.group_size
        alpha_expanded = self.group_scales.float()  # [out, n_groups]
        # Repeat each group scale across its columns; the slice trims padding
        # when in_features is not an exact multiple of group_size.
        alpha_full = alpha_expanded.repeat_interleave(gs, dim=1)[:, :in_f]  # [out, in]
        W_hat = alpha_full * T
        # Add the low-rank correction when both factors are present
        if self.lr_U is not None and self.lr_V is not None:
            W_hat = W_hat + self.lr_U.float() @ self.lr_V.float()
        return W_hat

    def num_params(self) -> int:
        """Number of parameters in the original (unquantized) weight."""
        return self.original_shape[0] * self.original_shape[1]

    def effective_bits(self) -> float:
        """Average storage bits per original parameter (codes + scales + low-rank)."""
        out_f, in_f = self.original_shape
        n_groups = self.group_scales.shape[1]
        # Ternary codes: 2 bits each
        code_bits = 2 * out_f * in_f
        # Group scales: FP16 each
        scale_bits = 16 * out_f * n_groups
        # Low-rank: FP16 values
        lr_bits = 0
        if self.lr_U is not None:
            rank = self.lr_U.shape[1]
            lr_bits = 16 * (out_f * rank + rank * in_f)
        return (code_bits + scale_bits + lr_bits) / (out_f * in_f)
class GPTQTernaryQuantizer:
    """
    GPTQ-style ternary quantizer with three key innovations:
    1. Hessian-based error compensation (from GPTQ/OBQ)
    2. Per-group scale factors (finer than per-row)
    3. Low-rank residual correction (SVD of quantization error)
    """
    def __init__(
        self,
        group_size: int = 128,
        lr_rank: int = 32,
        block_size: int = 128,
        damp_percent: float = 0.01,
    ):
        """
        Args:
            group_size: Number of columns per scale group. Smaller = better quality
                but more overhead. 128 adds ~0.12 bits/param. 0 = per-row (no groups).
            lr_rank: Rank of low-rank residual correction. 0 = disable.
                32 adds ~0.5-1.0 bits/param but captures fine structure.
            block_size: GPTQ block size for lazy batch updates. 128 is standard.
            damp_percent: Hessian damping factor (fraction of mean diagonal).
        """
        self.group_size = group_size
        self.lr_rank = lr_rank
        self.block_size = block_size
        self.damp_percent = damp_percent

    def quantize(
        self,
        weight: torch.Tensor,
        activations: Optional[torch.Tensor] = None,
    ) -> TernaryV2Parameter:
        """
        Quantize a weight matrix using GPTQ-Ternary + group scales + low-rank.

        Args:
            weight: [out_features, in_features] weight matrix
            activations: [n_samples, in_features] calibration activations.
                Required for Hessian computation. When None, the Hessian
                falls back to identity (plain round-to-nearest behavior).

        Returns:
            TernaryV2Parameter
        """
        W = weight.float().clone()
        original_shape = weight.shape
        original_dtype = weight.dtype
        out_features, in_features = W.shape
        device = W.device
        # Determine group size (0 means a single group spanning the whole row)
        gs = self.group_size if self.group_size > 0 else in_features
        # Ceil division: the last group may cover fewer than gs columns
        n_groups = (in_features + gs - 1) // gs
        # Save original for residual computation — W is mutated in place below
        # by the error-compensation updates
        W_original = W.clone()
        # --- Step 1: Compute Hessian ---
        H = self._compute_hessian(W, activations)
        # --- Step 2: GPTQ-style column-wise quantization with error compensation ---
        T = torch.zeros_like(W, dtype=torch.int8, device=device)
        group_scales = torch.zeros(
            out_features,
            n_groups,
            dtype=torch.float32,
            device=device,
        )
        # Compute Hessian inverse via Cholesky (fast path for SPD matrices)
        try:
            H_inv = torch.linalg.cholesky(H)
            H_inv = torch.cholesky_inverse(H_inv)
        except RuntimeError:
            # Fallback: add more damping (10% of mean diagonal) and invert directly
            damp = 0.1 * torch.diag(H).mean()
            H_inv = torch.linalg.inv(H + damp * torch.eye(in_features, device=W.device))
        # Per-row accumulated quantization loss — diagnostic only, never returned
        Losses = torch.zeros(out_features, device=device)
        block_size = min(self.block_size, in_features)
        # Process in blocks: errors propagate eagerly to columns inside the
        # current block, then lazily (one matmul) to everything beyond it.
        for block_start in range(0, in_features, block_size):
            block_end = min(block_start + block_size, in_features)
            # Error accumulator for the lazy batch update at block end
            Err = torch.zeros(out_features, block_end - block_start, device=W.device)
            for j in range(block_start, block_end):
                w_col = W[:, j]  # [out_features]
                h_jj = H_inv[j, j]
                # Determine which group this column belongs to
                g_idx = j // gs
                # Compute group scale if this is the first column of a new group.
                # NOTE: scales are derived from the already error-compensated W,
                # so later groups adapt to errors pushed forward by earlier columns.
                if j % gs == 0:
                    g_end = min(j + gs, in_features)
                    group_cols = W[:, j:g_end]
                    # Scale = mean absolute value of the group (clamped so the
                    # division below never hits zero)
                    group_scales[:, g_idx] = group_cols.abs().mean(dim=1).clamp(min=1e-8)
                alpha = group_scales[:, g_idx]  # [out_features]
                # Ternarize this column: round w/alpha to nearest {-1, 0, +1}
                z = w_col / alpha
                t = torch.zeros_like(z, dtype=torch.int8)
                t[z > 0.5] = 1
                t[z < -0.5] = -1
                T[:, j] = t
                # Quantized value
                w_q = alpha * t.float()
                # Quantization error scaled by the inverse-Hessian diagonal
                # (the OBQ/GPTQ optimal update direction)
                delta = (w_col - w_q) / h_jj
                # Track loss
                Losses += (w_col - w_q) ** 2 / h_jj
                # Store error for batch update
                Err[:, j - block_start] = delta
                # Update remaining columns in this block (eager propagation)
                if j + 1 < block_end:
                    W[:, j + 1 : block_end] -= (
                        delta.unsqueeze(1)
                        * H_inv[j, j + 1 : block_end].unsqueeze(0)
                    )
            # Batch update: propagate block errors to all remaining columns
            if block_end < in_features:
                W[:, block_end:] -= (
                    Err @ H_inv[block_start:block_end, block_end:]
                )
        # --- Step 3: Low-rank residual correction ---
        lr_U = None
        lr_V = None
        if self.lr_rank > 0:
            # Compute residual: what ternary couldn't capture
            W_ternary = self._dequantize_ternary(T, group_scales, gs, in_features)
            residual = W_original - W_ternary
            # SVD of residual, keep top-r singular values
            rank = min(self.lr_rank, min(out_features, in_features))
            try:
                U, S, Vh = torch.linalg.svd(residual, full_matrices=False)
                # Fold singular values into U so only two factors are stored
                lr_U = (U[:, :rank] * S[:rank].unsqueeze(0)).to(torch.float16)
                lr_V = Vh[:rank, :].to(torch.float16)
            except RuntimeError:
                # SVD can fail on very ill-conditioned matrices; skip correction
                pass
        return TernaryV2Parameter(
            ternary_codes=T,
            group_scales=group_scales.to(torch.float16),
            group_size=gs,
            lr_U=lr_U,
            lr_V=lr_V,
            original_shape=original_shape,
            original_dtype=original_dtype,
        )

    def _compute_hessian(
        self,
        W: torch.Tensor,
        activations: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """
        Compute the Hessian H = X^T X / n_samples + damping.

        The Hessian captures input correlations: H[i,j] tells us how
        correlated input dimensions i and j are. This is crucial for
        knowing how to redistribute quantization error.
        """
        in_features = W.shape[1]
        device = W.device
        if activations is not None:
            X = activations.float().to(device)
            n = X.shape[0]
            H = (X.T @ X) / n
        else:
            # Fallback: identity (no correlation info, degrades to simple RTN)
            H = torch.eye(in_features, device=device)
        # Damping for numerical stability, proportional to the mean diagonal
        damp = self.damp_percent * torch.diag(H).mean()
        H += damp * torch.eye(in_features, device=device)
        return H

    def _dequantize_ternary(
        self,
        T: torch.Tensor,
        group_scales: torch.Tensor,
        gs: int,
        in_features: int,
    ) -> torch.Tensor:
        """Reconstruct weight from ternary codes and group scales (no low-rank)."""
        # Repeat each scale across its gs columns; the slice trims any padding
        # when in_features is not a multiple of gs
        alpha_expanded = group_scales.float().repeat_interleave(gs, dim=1)[:, :in_features]
        return alpha_expanded * T.float()
class TernaryV2Linear(nn.Module):
    """nn.Linear replacement backed by the V2 ternary representation.

    Forward pass: output = x @ (alpha_g * T + U @ V)^T + bias,
    where T is the ternary code matrix, alpha_g the per-group scales
    expanded over columns, and U @ V the optional low-rank residual.
    """

    def __init__(
        self,
        param: TernaryV2Parameter,
        bias: Optional[torch.Tensor] = None,
    ):
        super().__init__()
        out_features, in_features = param.original_shape
        self.out_features = out_features
        self.in_features = in_features
        self.group_size = param.group_size
        # Quantized state lives in buffers so .to()/state_dict work.
        self.register_buffer("ternary_codes", param.ternary_codes)
        self.register_buffer("group_scales", param.group_scales)
        if param.lr_U is None:
            # No low-rank correction: plain attributes, nothing to serialize.
            self.lr_U = None
            self.lr_V = None
        else:
            self.register_buffer("lr_U", param.lr_U)
            self.register_buffer("lr_V", param.lr_V)
        if bias is None:
            self.bias = None
        else:
            self.register_buffer("bias", bias.float())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Dequantize the weight on the fly and apply a standard linear map."""
        dtype = x.dtype
        gs = self.group_size
        codes = self.ternary_codes.to(dtype)
        # Expand [out, n_groups] scales to a full [out, in] matrix by
        # repeating each scale across its group's columns.
        scales = self.group_scales.to(dtype)
        scales_full = scales.repeat_interleave(gs, dim=1)[:, : self.in_features]
        weight = scales_full * codes
        if self.lr_U is not None:
            # Low-rank residual correction: W_hat = alpha*T + U @ V
            weight = weight + self.lr_U.to(dtype) @ self.lr_V.to(dtype)
        result = nn.functional.linear(x, weight)
        if self.bias is not None:
            result = result + self.bias.to(dtype)
        return result

    def extra_repr(self) -> str:
        suffix = "" if self.lr_U is None else f", lr_rank={self.lr_U.shape[1]}"
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"bias={self.bias is not None}, group_size={self.group_size}"
            f"{suffix}, bits~={self._effective_bits():.2f}"
        )

    def _effective_bits(self) -> float:
        """Average storage bits per weight element (codes + scales + low-rank)."""
        n_weights = self.out_features * self.in_features
        total_bits = 2 * n_weights  # 2 bits per ternary code
        total_bits += 16 * self.out_features * self.group_scales.shape[1]
        if self.lr_U is not None:
            r = self.lr_U.shape[1]
            total_bits += 16 * (self.out_features * r + r * self.in_features)
        return total_bits / n_weights
def compute_v2_error(weight: torch.Tensor, param: "TernaryV2Parameter") -> dict:
    """Compute quantization error metrics for V2 parameters.

    Args:
        weight: Original full-precision weight matrix.
        param: Quantized representation to compare against; must provide
            ``dequantize()``, ``effective_bits()``, and ``ternary_codes``.

    Returns:
        Dict with ``mse``, ``rmse``, ``relative_error`` (RMSE divided by the
        RMS of the original weights), ``max_error``, ``sparsity`` (fraction of
        zero ternary codes), and ``effective_bits`` (float).
    """
    W = weight.float()
    W_hat = param.dequantize()
    mse = ((W - W_hat) ** 2).mean().item()
    rmse = mse ** 0.5
    # RMS of the original weights; epsilon guards against an all-zero W.
    rms_w = W.norm().item() / (W.numel() ** 0.5) + 1e-8
    rel_error = rmse / rms_w
    max_error = (W - W_hat).abs().max().item()
    T = param.ternary_codes
    sparsity = (T == 0).float().mean().item()
    return {
        "mse": mse,
        "rmse": rmse,
        "relative_error": rel_error,
        "max_error": max_error,
        "sparsity": sparsity,
        # Bug fix: call the method — previously the bound method object itself
        # was stored instead of the float it computes.
        "effective_bits": param.effective_bits(),
    }