# sentinel_quantization.py
"""
================================================================================
SENTINEL QUANTIZATION
================================================================================
Theory: The attracting fixed point C₁ ≈ −0.007994021805953 of the iteration
z_{k+1} = F(z_k) is a natural quantization center.
Key Innovation: Use Sentinel dynamical properties for model quantization:
- Attracting fixed point C₁ as quantization zero-point
- Basin boundary C₂ as precision threshold
- Gradient Axiom (1/e) as quantization scale
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict, Tuple
class SentinelQuantizer:
"""
Sentinel-aware quantization using dynamical constants.
Quantization formula:
q = round((w - C₁) / scale)
scale = max(|w|) · (1/e) # Sentinel scale from gradient axiom
where C₁ = −0.007994021805953 is the attracting fixed point.
"""
C1 = -0.007994021805953 # Attracting fixed point
INV_E = 1.0 / np.e # Gradient axiom limit
def __init__(self, bits: int = 8):
self.bits = bits
self.qmin = -(2 ** (bits - 1))
self.qmax = 2 ** (bits - 1) - 1
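        # For example (illustrative arithmetic): bits=8 gives qmin = -128 and qmax = 127,
        # i.e. the signed int8 range used by torch.clamp in quantize().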
def find_scale(self, tensor: torch.Tensor) -> float:
"""Find optimal quantization scale using Sentinel principle."""
# Scale = max(|w|) · (1/e)
# This ensures the quantized range maps to the "stable basin"
max_val = tensor.abs().max().item()
scale = max_val * self.INV_E
return max(scale, 1e-8)
def quantize(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
"""
Quantize tensor to int8 (or specified bits).
Returns quantized tensor and scale for dequantization.
"""
scale = self.find_scale(tensor)
# Shift by C₁ (attracting fixed point as zero-point)
shifted = tensor - self.C1
# Quantize
quantized = torch.round(shifted / scale)
quantized = torch.clamp(quantized, self.qmin, self.qmax)
return quantized, scale
def dequantize(self, quantized: torch.Tensor, scale: float) -> torch.Tensor:
"""Dequantize back to float."""
return quantized * scale + self.C1
def quantize_model(self, model: nn.Module) -> Dict[str, Tuple[torch.Tensor, float]]:
"""Quantize all parameters of a model."""
quantized_params = {}
for name, param in model.named_parameters():
if param.requires_grad:
q, scale = self.quantize(param.data)
quantized_params[name] = (q.to(torch.int8), scale)
return quantized_params
def dequantize_model(self, quantized_params: Dict) -> Dict[str, torch.Tensor]:
"""Dequantize all parameters."""
dequantized = {}
for name, (q, scale) in quantized_params.items():
dequantized[name] = self.dequantize(q.float(), scale)
return dequantized
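# Minimal usage sketch (added for illustration; _roundtrip_example is not part of
# the original file): quantize a single tensor with SentinelQuantizer and measure
# the round-trip reconstruction error.
def _roundtrip_example(shape=(64, 64), bits: int = 8) -> float:
    """Quantize and dequantize a random tensor, returning the mean absolute error."""
    quantizer = SentinelQuantizer(bits=bits)
    w = torch.randn(*shape)
    q, scale = quantizer.quantize(w)        # integer-valued tensor plus per-tensor scale
    w_hat = quantizer.dequantize(q, scale)  # reconstruct around the C₁ zero-point
    return (w - w_hat).abs().mean().item()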
class SentinelQuantizedLinear(nn.Module):
"""Linear layer with Sentinel-aware quantization."""
def __init__(self, in_features: int, out_features: int, bits: int = 8):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.bits = bits
self.weight = nn.Parameter(torch.randn(out_features, in_features))
self.bias = nn.Parameter(torch.zeros(out_features))
self.quantizer = SentinelQuantizer(bits)
self._register_quantization_params()
def _register_quantization_params(self):
"""Register quantization scale as buffer."""
self.register_buffer('weight_scale', torch.tensor(1.0))
self.register_buffer('quantized_weight', torch.zeros_like(self.weight, dtype=torch.int8))
def quantize(self):
"""Quantize weights in-place."""
q, scale = self.quantizer.quantize(self.weight.data)
        # Cast to int8 so the buffer keeps its integer storage dtype.
        self.quantized_weight.data = q.to(torch.int8)
self.weight_scale = torch.tensor(scale)
def dequantize(self):
"""Dequantize weights for computation."""
return self.quantizer.dequantize(self.quantized_weight.float(), self.weight_scale.item())
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass with dequantized weights."""
w = self.dequantize()
return F.linear(x, w, self.bias)
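# Usage sketch (illustrative; _quantized_linear_example is not part of the original
# file): build a SentinelQuantizedLinear layer, quantize its weights, and run a
# forward pass on dummy input.
def _quantized_linear_example() -> torch.Tensor:
    """Forward a dummy batch through a quantized SentinelQuantizedLinear layer."""
    layer = SentinelQuantizedLinear(in_features=16, out_features=4, bits=8)
    layer.quantize()                 # store int8 weights and the Sentinel scale in buffers
    x = torch.randn(2, 16)           # batch of 2 dummy inputs
    return layer(x)                  # forward() dequantizes the stored weights on the fly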
def demo_sentinel_quantization():
"""Demo Sentinel quantization on synthetic model."""
print("=" * 70)
print(" SENTINEL QUANTIZATION")
print("=" * 70)
# Synthetic model
model = nn.Sequential(
nn.Linear(784, 256),
nn.ReLU(),
nn.Linear(256, 10)
)
# Original model stats
original_params = sum(p.numel() for p in model.parameters())
original_size = original_params * 4 # float32 = 4 bytes
print(f"\n--- Original Model ---")
print(f" Parameters: {original_params:,}")
print(f" Size (FP32): {original_size / 1024:.1f} KB")
# Quantize
quantizer = SentinelQuantizer(bits=8)
quantized_params = quantizer.quantize_model(model)
# Quantized model stats
    quantized_size = sum(q.numel() * 1 + 4 for q, _ in quantized_params.values())  # 1 byte per int8 value + 4-byte float32 scale per tensor
print(f"\n--- Quantized Model (Sentinel-aware) ---")
print(f" Parameters: {sum(q.numel() for q, _ in quantized_params.values()):,}")
print(f" Size (INT8): {quantized_size / 1024:.1f} KB")
print(f" Compression ratio: {original_size / quantized_size:.2f}×")
# Verify dequantization quality
dequantized = quantizer.dequantize_model(quantized_params)
errors = []
for name, param in model.named_parameters():
if name in dequantized:
error = (param.data - dequantized[name]).abs().mean().item()
errors.append(error)
mean_error = np.mean(errors)
print(f"\n--- Dequantization Quality ---")
print(f" Mean absolute error: {mean_error:.6f}")
print(f" Attracting fixed point C₁: {SentinelQuantizer.C1:.12f}")
print(f" Sentinel scale factor (1/e): {SentinelQuantizer.INV_E:.6f}")
# Theoretical justification
print(f"\n--- Theoretical Justification ---")
print(f" C₁ = {SentinelQuantizer.C1:.12f} is the attracting fixed point")
print(f" All negative values converge to C₁ under F(z) iteration")
print(f" Using C₁ as zero-point: natural quantization center")
print(f" Scale = max(|w|)·(1/e): maps to stable basin")
print(f"\n{'='*70}")
print(f" SENTINEL QUANTIZATION: {original_size/quantized_size:.1f}× COMPRESSION")
print(f" WITH DYNAMICAL CONSTANTS AS QUANTIZATION PARAMETERS")
print(f"{'='*70}")
if __name__ == '__main__':
demo_sentinel_quantization()