"""
================================================================================
SENTINEL QUANTIZATION
================================================================================
Theory: The attracting fixed point C₁ ≈ −0.007994021805953 of the iteration
z_{k+1} = F(z_k) is a natural quantization center.
Key Innovation: Use Sentinel dynamical properties for model quantization:
- Attracting fixed point C₁ as quantization zero-point
- Basin boundary C₂ as precision threshold
- Gradient Axiom (1/e) as quantization scale
"""
from typing import Dict, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class SentinelQuantizer:
    """
    Sentinel-aware quantization using dynamical constants.

    Quantization formula:
        q = round((w - C₁) / scale)
        scale = max(|w|) · (1/e)   # Sentinel scale from gradient axiom

    where C₁ = −0.007994021805953 is the attracting fixed point.
    """

    C1 = -0.007994021805953  # Attracting fixed point (quantization zero-point)
    INV_E = 1.0 / np.e       # Gradient axiom limit (quantization scale factor)

    def __init__(self, bits: int = 8):
        """
        Args:
            bits: Bit width of the signed symmetric quantization grid.
        """
        self.bits = bits
        # Signed two's-complement range, e.g. [-128, 127] for 8 bits.
        self.qmin = -(2 ** (bits - 1))
        self.qmax = 2 ** (bits - 1) - 1

    def _storage_dtype(self) -> torch.dtype:
        """Smallest torch integer dtype that can hold [qmin, qmax]."""
        if self.bits <= 8:
            return torch.int8
        if self.bits <= 16:
            return torch.int16
        return torch.int32

    def find_scale(self, tensor: torch.Tensor) -> float:
        """Find optimal quantization scale using Sentinel principle.

        Scale = max(|w|) · (1/e): this maps the quantized range to the
        "stable basin". Clamped below by 1e-8 so an all-zero tensor never
        produces a zero scale (which would divide by zero in quantize).
        """
        max_val = tensor.abs().max().item()
        scale = max_val * self.INV_E
        return max(scale, 1e-8)

    def quantize(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
        """
        Quantize a tensor onto the configured integer grid.

        Returns:
            (quantized, scale): a float tensor holding integer values in
            [qmin, qmax], and the scale needed for dequantization.
        """
        scale = self.find_scale(tensor)
        # Shift by C₁ (attracting fixed point as zero-point).
        shifted = tensor - self.C1
        quantized = torch.round(shifted / scale)
        quantized = torch.clamp(quantized, self.qmin, self.qmax)
        return quantized, scale

    def dequantize(self, quantized: torch.Tensor, scale: float) -> torch.Tensor:
        """Dequantize back to float: w ≈ q · scale + C₁."""
        return quantized * scale + self.C1

    def quantize_model(self, model: nn.Module) -> Dict[str, Tuple[torch.Tensor, float]]:
        """Quantize all trainable parameters of a model.

        Returns a mapping name -> (integer tensor, scale).

        Fix: the storage dtype is now derived from ``self.bits`` (int8 /
        int16 / int32) instead of being hard-coded to int8, which silently
        truncated values whenever the quantizer was built with bits > 8.
        """
        dtype = self._storage_dtype()
        quantized_params = {}
        for name, param in model.named_parameters():
            if param.requires_grad:
                q, scale = self.quantize(param.data)
                quantized_params[name] = (q.to(dtype), scale)
        return quantized_params

    def dequantize_model(self, quantized_params: Dict) -> Dict[str, torch.Tensor]:
        """Dequantize all parameters back to float tensors."""
        return {
            name: self.dequantize(q.float(), scale)
            for name, (q, scale) in quantized_params.items()
        }
class SentinelQuantizedLinear(nn.Module):
    """Linear layer with Sentinel-aware quantization."""

    def __init__(self, in_features: int, out_features: int, bits: int = 8):
        """
        Args:
            in_features: Input feature dimension.
            out_features: Output feature dimension.
            bits: Bit width handed to the SentinelQuantizer.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.bits = bits
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features))
        self.quantizer = SentinelQuantizer(bits)
        self._register_quantization_params()

    def _register_quantization_params(self):
        """Register quantization scale and integer weights as buffers."""
        self.register_buffer('weight_scale', torch.tensor(1.0))
        self.register_buffer('quantized_weight', torch.zeros_like(self.weight, dtype=torch.int8))

    def quantize(self):
        """Quantize weights in-place.

        Fix: the quantizer returns a *float* tensor of integer values;
        assigning it directly to ``quantized_weight.data`` silently turned
        the int8 buffer into float32, defeating the int8 storage. Cast to
        the buffer's dtype instead, and fill the scale buffer in place so
        it remains the registered buffer object.
        """
        q, scale = self.quantizer.quantize(self.weight.data)
        self.quantized_weight.data = q.to(self.quantized_weight.dtype)
        self.weight_scale.fill_(scale)

    def dequantize(self):
        """Dequantize weights for computation (returns a float tensor)."""
        return self.quantizer.dequantize(self.quantized_weight.float(), self.weight_scale.item())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass with dequantized weights: y = x · W_deqᵀ + b."""
        w = self.dequantize()
        return F.linear(x, w, self.bias)
import torch.nn.functional as F
def demo_sentinel_quantization():
    """Demo Sentinel quantization on synthetic model."""
    separator = "=" * 70
    print(separator)
    print(" SENTINEL QUANTIZATION")
    print(separator)

    # Small synthetic MLP used as the quantization target.
    model = nn.Sequential(
        nn.Linear(784, 256),
        nn.ReLU(),
        nn.Linear(256, 10),
    )

    # FP32 baseline footprint: 4 bytes per parameter.
    n_params = sum(p.numel() for p in model.parameters())
    fp32_bytes = n_params * 4
    print(f"\n--- Original Model ---")
    print(f" Parameters: {n_params:,}")
    print(f" Size (FP32): {fp32_bytes / 1024:.1f} KB")

    # Quantize every trainable parameter with the 8-bit Sentinel scheme.
    quantizer = SentinelQuantizer(bits=8)
    packed = quantizer.quantize_model(model)

    # INT8 footprint: one byte per value plus a 4-byte scale per tensor.
    int8_bytes = sum(q.numel() * 1 + 4 for q, _ in packed.values())
    print(f"\n--- Quantized Model (Sentinel-aware) ---")
    print(f" Parameters: {sum(q.numel() for q, _ in packed.values()):,}")
    print(f" Size (INT8): {int8_bytes / 1024:.1f} KB")
    print(f" Compression ratio: {fp32_bytes / int8_bytes:.2f}×")

    # Round-trip each parameter and report the mean absolute error.
    restored = quantizer.dequantize_model(packed)
    errors = [
        (param.data - restored[name]).abs().mean().item()
        for name, param in model.named_parameters()
        if name in restored
    ]
    mean_error = np.mean(errors)
    print(f"\n--- Dequantization Quality ---")
    print(f" Mean absolute error: {mean_error:.6f}")
    print(f" Attracting fixed point C₁: {SentinelQuantizer.C1:.12f}")
    print(f" Sentinel scale factor (1/e): {SentinelQuantizer.INV_E:.6f}")

    print(f"\n--- Theoretical Justification ---")
    print(f" C₁ = {SentinelQuantizer.C1:.12f} is the attracting fixed point")
    print(f" All negative values converge to C₁ under F(z) iteration")
    print(f" Using C₁ as zero-point: natural quantization center")
    print(f" Scale = max(|w|)·(1/e): maps to stable basin")

    print(f"\n{'='*70}")
    print(f" SENTINEL QUANTIZATION: {fp32_bytes/int8_bytes:.1f}× COMPRESSION")
    print(f" WITH DYNAMICAL CONSTANTS AS QUANTIZATION PARAMETERS")
    print(f"{'='*70}")
# Script entry point: run the quantization demo when executed directly.
if __name__ == '__main__':
    demo_sentinel_quantization()