File size: 5,288 Bytes
5d98323
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
#!/usr/bin/env python3
"""
Shannon 1-Bit Quantization Algorithm

Mathematical foundation for extreme neural network compression.
Each weight is reduced to its sign, preserving information through scale factors.

Theory:
    Given weight matrix W ∈ ℝ^(m×n), we decompose:
    W ≈ B ⊙ S
    where B ∈ {-1,+1}^(m×n) are binary weights
    and S ∈ ℝ^m are per-channel scale factors

    Information retained: O(log₂(n)) bits per n weights
"""

import numpy as np
from typing import Tuple, Dict, Any


def quantize_to_binary(weights: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Reduce a weight matrix to its signs plus per-channel scale factors.

    Args:
        weights: Input weight matrix of shape [out_channels, in_channels]

    Returns:
        binary_weights: Matrix of {-1, +1} values (one sign bit per weight)
        scales: Per-output-channel scale factors, shape [out_channels, 1]
    """
    # Per-channel mean absolute value preserves magnitude information
    # that the sign alone discards.
    channel_scales = np.abs(weights).mean(axis=1, keepdims=True)
    # An all-zero channel would produce a zero scale; clamp to a tiny
    # positive value so dequantization never multiplies by exactly 0.
    channel_scales[channel_scales == 0] = 1e-8

    # Sign is the single retained bit; map exact zeros to +1 so every
    # entry is strictly in {-1, +1}.
    sign_matrix = np.sign(weights)
    sign_matrix[weights == 0] = 1

    return sign_matrix, channel_scales


def pack_binary_weights(binary_weights: np.ndarray) -> np.ndarray:
    """
    Pack binary weights into a compact byte array, one bit per weight.

    +1 maps to bit 1 and -1 maps to bit 0; eight weights share each byte.
    The final byte is zero-padded when the element count is not a
    multiple of 8.

    Args:
        binary_weights: Matrix of {-1, +1} values

    Returns:
        packed: uint8 array of ceil(size / 8) packed bytes
    """
    # Positive entries become bit 1, negative become bit 0.
    bit_stream = (binary_weights.reshape(-1) > 0).astype(np.uint8)
    # np.packbits zero-pads the tail to a full byte automatically.
    return np.packbits(bit_stream)


def unpack_binary_weights(packed: np.ndarray, shape: Tuple[int, ...]) -> np.ndarray:
    """
    Reverse :func:`pack_binary_weights`, restoring the {-1, +1} matrix.

    Args:
        packed: Packed uint8 byte array
        shape: Original weight matrix shape

    Returns:
        binary_weights: float32 matrix of {-1, +1} values with the given shape
    """
    element_count = int(np.prod(shape))
    # Drop the zero-padding bits appended during packing.
    bit_stream = np.unpackbits(packed)[:element_count]
    # Map bit 0 -> -1.0 and bit 1 -> +1.0.
    signs = bit_stream.astype(np.float32) * 2.0 - 1.0
    return signs.reshape(shape)


def dequantize(binary_weights: np.ndarray, scales: np.ndarray) -> np.ndarray:
    """
    Reconstruct approximate weights from binary representation.

    Accepts scales either in the keepdims shape produced by
    ``quantize_to_binary`` ([out_channels, 1]) or as the flattened 1-D
    vector stored by ``quantize_model``. A flat vector is reshaped so it
    broadcasts per output channel (row-wise); without this, NumPy would
    silently broadcast a length-n vector across columns instead.

    Args:
        binary_weights: Matrix of {-1, +1} values
        scales: Per-channel scale factors, shape [out_channels, 1] or
            [out_channels]

    Returns:
        weights: Reconstructed weight matrix
    """
    scales = np.asarray(scales)
    if scales.ndim == 1 and binary_weights.ndim > 1:
        # Align a flat per-channel vector with the leading (channel) axis.
        scales = scales.reshape(-1, *([1] * (binary_weights.ndim - 1)))
    return binary_weights * scales


def quantize_model(model_weights: Dict[str, np.ndarray]) -> Dict[str, Any]:
    """
    Quantize entire model to 1-bit representation.

    Tensors with at least 2 dimensions and more than 1000 elements are
    binarized and bit-packed; smaller tensors (biases, norm parameters)
    are stored unmodified.

    Args:
        model_weights: Dictionary of weight tensors

    Returns:
        quantized: Dictionary containing packed weights and metadata under
            the reserved key ``'_metadata'``. The compression ratio compares
            the actual original byte size of *all* tensors against the
            compressed byte size — the previous formula counted only the
            quantized parameters (at an assumed 4 bytes each) in the
            numerator while the denominator included every tensor, which
            understated the ratio inconsistently.
    """
    quantized = {}
    total_params = 0
    original_bytes = 0
    total_bytes = 0

    for name, weights in model_weights.items():
        total_params += weights.size
        original_bytes += weights.nbytes

        if weights.ndim >= 2 and weights.size > 1000:  # Only quantize significant matrices
            # Quantize to binary sign matrix + per-channel scales
            binary, scales = quantize_to_binary(weights)

            # Pack for storage (1 bit per weight)
            packed = pack_binary_weights(binary)

            quantized[name] = {
                'packed': packed,
                'scales': scales.flatten(),
                'shape': weights.shape,
                'dtype': 'binary'
            }

            total_bytes += len(packed) + scales.size * 4  # scales as float32
        else:
            # Keep small tensors as-is (biases, norms)
            quantized[name] = {
                'data': weights,
                'shape': weights.shape,
                'dtype': 'full'
            }
            total_bytes += weights.nbytes

    quantized['_metadata'] = {
        'total_parameters': total_params,
        'compressed_bytes': total_bytes,
        'compression_ratio': original_bytes / total_bytes if total_bytes > 0 else 0
    }

    return quantized


def calculate_compression_metrics(original_size_mb: float, compressed_size_mb: float) -> Dict[str, float]:
    """
    Calculate compression metrics for reporting.

    Args:
        original_size_mb: Original model size in megabytes (must be > 0)
        compressed_size_mb: Compressed model size in megabytes (must be > 0)

    Returns:
        metrics: Dictionary of compression metrics

    Raises:
        ValueError: If either size is not strictly positive (previously a
            bare ZeroDivisionError escaped for a zero compressed size).
    """
    if original_size_mb <= 0 or compressed_size_mb <= 0:
        raise ValueError(
            f"Sizes must be positive, got original={original_size_mb}, "
            f"compressed={compressed_size_mb}"
        )

    compression_ratio = original_size_mb / compressed_size_mb
    space_saved_percent = (1 - compressed_size_mb / original_size_mb) * 100
    # Original weights are assumed float32 (32 bits each), so:
    #   bits_per_weight = compressed_bits / num_weights
    #                   = (compressed * 8 * 2^20) / (original * 2^20 / 4)
    #                   = 32 * compressed / original
    bits_per_weight = 32.0 * compressed_size_mb / original_size_mb

    return {
        'compression_ratio': compression_ratio,
        'space_saved_percent': space_saved_percent,
        'bits_per_weight': bits_per_weight,
        'original_size_mb': original_size_mb,
        'compressed_size_mb': compressed_size_mb
    }