# Shannonstral-7B-1bit — core/quantization.py
# Uploaded by hunterbown via huggingface_hub (commit 5d98323, verified)
#!/usr/bin/env python3
"""
Shannon 1-Bit Quantization Algorithm
Mathematical foundation for extreme neural network compression.
Each weight is reduced to its sign, preserving information through scale factors.
Theory:
Given weight matrix W ∈ ℝ^(m×n), we decompose:
W ≈ B ⊙ S
where B ∈ {-1,+1}^(m×n) are binary weights
and S ∈ ℝ^m are per-channel scale factors
Storage cost: 1 bit per weight (the signs in B) plus one float scale per output channel (the m entries of S).
"""
import numpy as np
from typing import Tuple, Dict, Any
def quantize_to_binary(weights: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Reduce a weight matrix to its signs plus per-channel magnitudes.

    Args:
        weights: Weight matrix of shape [out_channels, in_channels].

    Returns:
        A pair (binary_weights, scales) where binary_weights contains
        only {-1, +1} values and scales holds one mean-absolute-value
        scale factor per output channel (shape [out_channels, 1]).
    """
    signs = np.sign(weights)
    # np.sign maps exact zeros to 0; force them to +1 so every entry
    # is a valid {-1, +1} bit.
    signs[weights == 0] = 1
    # One scale per row: the mean |w| preserves magnitude information
    # that the sign matrix alone discards.
    per_channel = np.mean(np.abs(weights), axis=1, keepdims=True)
    # All-zero rows would produce a zero scale; substitute a tiny epsilon
    # so dequantization never collapses a whole channel.
    per_channel = np.where(per_channel == 0, 1e-8, per_channel)
    return signs, per_channel
def pack_binary_weights(binary_weights: np.ndarray) -> np.ndarray:
    """
    Pack binary weights into a compact byte representation.

    Each weight in {-1, +1} is stored as a single bit; 8 weights are
    packed into each byte (big-endian bit order, the numpy default).

    Args:
        binary_weights: Matrix of {-1, +1} values

    Returns:
        packed: uint8 array of ceil(n/8) bytes, zero-padded at the end
    """
    # Map {-1, +1} -> {0, 1} bit values.
    bits = ((binary_weights + 1) / 2).astype(np.uint8).flatten()
    # np.packbits already zero-pads the final partial byte and packs
    # MSB-first, so the manual pad/reshape the original code did is
    # redundant — this single call produces identical output.
    return np.packbits(bits)
def unpack_binary_weights(packed: np.ndarray, shape: Tuple[int, ...]) -> np.ndarray:
    """
    Unpack binary weights from the compact byte representation.

    Inverse of pack_binary_weights for the leading prod(shape) bits;
    trailing padding bits in the final byte are discarded.

    Args:
        packed: Packed uint8 byte array (as produced by pack_binary_weights)
        shape: Original weight matrix shape

    Returns:
        binary_weights: float32 matrix of {-1, +1} values with the given shape
    """
    # int(...) guards the 0-d case: np.prod(()) returns a float 1.0,
    # which is not a valid count/slice index.
    total_elements = int(np.prod(shape))
    # count= both truncates to the exact number of bits needed and avoids
    # materializing padding bits only to slice them off afterwards.
    bits = np.unpackbits(packed, count=total_elements)
    # Map {0, 1} back to {-1, +1} and restore the original shape.
    return (2.0 * bits.astype(np.float32) - 1.0).reshape(shape)
def dequantize(binary_weights: np.ndarray, scales: np.ndarray) -> np.ndarray:
    """
    Rebuild an approximation of the original weights.

    Each {-1, +1} entry is rescaled by its channel's scale factor
    (elementwise product with broadcasting), inverting the W ≈ B ⊙ S
    decomposition up to quantization error.

    Args:
        binary_weights: Matrix of {-1, +1} values
        scales: Per-channel scale factors

    Returns:
        weights: Reconstructed weight matrix
    """
    return np.multiply(binary_weights, scales)
def quantize_model(model_weights: Dict[str, np.ndarray]) -> Dict[str, Any]:
    """
    Quantize an entire model to the 1-bit representation.

    Large matrices (ndim >= 2 and more than 1000 elements) are binarized
    and bit-packed; everything else (biases, norm parameters) is stored
    at full precision. A '_metadata' entry summarizes the compression.

    Args:
        model_weights: Dictionary mapping tensor names to weight arrays

    Returns:
        quantized: Dictionary of per-tensor entries plus '_metadata'
    """
    result: Dict[str, Any] = {}
    param_count = 0
    byte_count = 0
    for layer_name, tensor in model_weights.items():
        # Only quantize significant matrices; small tensors stay full precision.
        if tensor.ndim < 2 or tensor.size <= 1000:
            result[layer_name] = {
                'data': tensor,
                'shape': tensor.shape,
                'dtype': 'full'
            }
            byte_count += tensor.nbytes
            continue
        signs, channel_scales = quantize_to_binary(tensor)
        packed_bits = pack_binary_weights(signs)
        result[layer_name] = {
            'packed': packed_bits,
            'scales': channel_scales.flatten(),
            'shape': tensor.shape,
            'dtype': 'binary'
        }
        param_count += tensor.size
        # 1 bit/weight (packed bytes) plus the scales, counted as float32.
        # NOTE(review): the "* 4" assumes float32 scales — if the input
        # weights are float16/float64 the accounting is approximate.
        byte_count += len(packed_bits) + channel_scales.size * 4
    result['_metadata'] = {
        'total_parameters': param_count,
        'compressed_bytes': byte_count,
        # Ratio vs. a float32 baseline (4 bytes per quantized parameter).
        'compression_ratio': (param_count * 4) / byte_count if byte_count > 0 else 0
    }
    return result
def calculate_compression_metrics(original_size_mb: float, compressed_size_mb: float) -> Dict[str, float]:
    """
    Calculate compression metrics for reporting.

    Args:
        original_size_mb: Original model size in megabytes
        compressed_size_mb: Compressed model size in megabytes

    Returns:
        metrics: Dictionary of compression metrics

    Raises:
        ValueError: If either size is zero or negative (the ratios are
            undefined), instead of an opaque ZeroDivisionError.
    """
    if original_size_mb <= 0 or compressed_size_mb <= 0:
        raise ValueError(
            "original_size_mb and compressed_size_mb must both be positive"
        )
    compression_ratio = original_size_mb / compressed_size_mb
    space_saved_percent = (1 - compressed_size_mb / original_size_mb) * 100
    # Assuming the original weights were float32 (32 bits each), the MB
    # factors cancel and bits/weight reduces to 32 * compressed/original.
    # (The previous formula computed the same value via an obfuscated
    # (c * 8 * 2^20) / (o * 2^20 / 4) expression.)
    bits_per_weight = 32.0 * compressed_size_mb / original_size_mb
    return {
        'compression_ratio': compression_ratio,
        'space_saved_percent': space_saved_percent,
        'bits_per_weight': bits_per_weight,
        'original_size_mb': original_size_mb,
        'compressed_size_mb': compressed_size_mb
    }