File size: 7,017 Bytes
fd8c8b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
"""

Base-3 packing utilities for memory-efficient ternary weight storage.



Ternary weights ({-1, 0, +1}) can be represented in base-3, allowing

multiple ternary values to be packed into a single byte or integer.

This provides significant memory savings over storing each value as a float32.



Theoretical packing:

    - 1 ternary value requires log2(3) β‰ˆ 1.58 bits

    - 5 ternary values fit in 1 byte (3^5 = 243 < 256)

    - Compression ratio: 32 bits (float) β†’ ~1.6 bits (packed) = 20x compression

"""

import torch
from typing import Tuple


def pack_ternary_base3(W_ternary: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, ...]]:
    """
    Pack ternary weights into base-3 representation for memory efficiency.

    Packs multiple ternary values ({-1, 0, +1}) into uint8 storage using base-3
    encoding. This achieves near-optimal compression for ternary data.

    Encoding scheme:
        -1 → 0 (base 3)
         0 → 1 (base 3)
        +1 → 2 (base 3)

    Then pack 5 base-3 digits into one byte:
        packed_byte = d0 + d1*3 + d2*9 + d3*27 + d4*81

    Args:
        W_ternary: Ternary weight tensor with values in {-1, 0, +1}
                   Shape: [out_features, in_features] or [k, out_features, in_features]

    Returns:
        packed: Packed weights as uint8 tensor (5 values per byte)
        original_shape: Shape of original tensor for unpacking

    Notes:
        - 5 ternary values per byte (3^5 = 243 < 256)
        - Padded with zero digits if numel is not divisible by 5; the padding
          is truncated during unpacking using original_shape
        - This is the primary memory optimization for ternary weights
    """
    original_shape = tuple(W_ternary.shape)

    # Map {-1, 0, 1} to {0, 1, 2}; digits fit comfortably in uint8.
    base3 = (W_ternary + 1).flatten().to(torch.uint8)

    # Pad to a multiple of 5 so each byte holds exactly 5 digits.
    numel = base3.numel()
    pad_size = (5 - numel % 5) % 5
    if pad_size > 0:
        base3 = torch.cat([base3, torch.zeros(pad_size, dtype=torch.uint8, device=base3.device)])

    # Reshape into groups of 5
    base3 = base3.view(-1, 5)

    # Pack each group: d0 + d1*3 + d2*9 + d3*27 + d4*81.
    # Per-term max is 2*81 = 162, so the uint8 multiply cannot overflow.
    powers_of_3 = torch.tensor([1, 3, 9, 27, 81], dtype=torch.uint8, device=base3.device)
    # BUG FIX: integer .sum() in PyTorch promotes to int64, which would store
    # 8 bytes per "packed byte" and silently break both the documented uint8
    # contract and the compression itself. Cast back to uint8; the maximum
    # group sum is 242 < 256, so the cast is lossless.
    packed = (base3 * powers_of_3).sum(dim=1).to(torch.uint8)

    return packed, original_shape


def unpack_ternary_base3(
    packed: torch.Tensor,
    original_shape: Tuple[int, ...],
) -> torch.Tensor:
    """
    Unpack base-3 encoded ternary weights back to full representation.

    Reverses pack_ternary_base3: extracts five base-3 digits from each byte,
    drops any padding digits, and maps digits back to ternary values.

    Args:
        packed: Packed uint8 tensor (5 values per byte)
        original_shape: Original shape of the ternary tensor

    Returns:
        W_ternary: float32 ternary weight tensor with values in {-1, 0, +1}
    """
    # Pull the five base-3 digits out of every byte, least significant first.
    digits = [(packed // divisor) % 3 for divisor in (1, 3, 9, 27, 81)]
    flat = torch.stack(digits, dim=1).flatten()

    # Truncate the zero-padding added during packing.
    element_count = torch.Size(original_shape).numel()
    flat = flat[:element_count]

    # Map digits {0, 1, 2} back to {-1, 0, +1} and restore the caller's layout.
    return (flat.to(torch.float32) - 1).view(original_shape)


def compute_compression_ratio(
    original_size: int,
    packed_size: int,
) -> float:
    """
    Compute compression ratio for packed ternary weights.

    Args:
        original_size: Size in bytes of original float32 weights
        packed_size: Size in bytes of packed ternary weights

    Returns:
        Compression ratio (e.g., 20.0 means 20x compression), or 0.0 when
        packed_size is not positive (avoids division by zero).

    Examples:
        >>> # 512 x 512 float32 weights = 512*512*4 bytes = 1,048,576 bytes
        >>> # Packed: 512*512 ternary values / 5 per byte ≈ 52,429 bytes
        >>> ratio = compute_compression_ratio(1048576, 52429)
        >>> print(f"Compression: {ratio:.1f}x")
        Compression: 20.0x
    """
    # Guard clause: a non-positive packed size has no meaningful ratio.
    if packed_size <= 0:
        return 0.0
    return original_size / packed_size


def estimate_memory_savings(
    in_features: int,
    out_features: int,
    num_layers: int = 1,
) -> dict:
    """
    Estimate memory savings from ternary packing for a given layer configuration.

    Args:
        in_features: Input dimension
        out_features: Output dimension
        num_layers: Number of layers (for cumulative savings)

    Returns:
        Dictionary with memory statistics:
            - float32_bytes: Memory for float32 weights
            - packed_bytes: Memory for packed ternary weights
            - savings_bytes: Absolute memory saved
            - compression_ratio: Ratio of compression

    Examples:
        >>> stats = estimate_memory_savings(768, 3072, num_layers=12)
        >>> print(f"Total savings: {stats['savings_bytes'] / 1e6:.1f} MB")
    """
    n_weights = in_features * out_features

    # float32 costs 4 bytes per weight; base-3 packing fits 5 ternary values
    # per byte, hence ceil(n_weights / 5) packed bytes per layer.
    dense_total = n_weights * 4 * num_layers
    packed_total = ((n_weights + 4) // 5) * num_layers

    ratio = dense_total / packed_total if packed_total > 0 else 0.0

    return {
        'float32_bytes': dense_total,
        'packed_bytes': packed_total,
        'savings_bytes': dense_total - packed_total,
        'compression_ratio': ratio,
    }


# Advanced packing schemes (alternative 2-bit encoding)

def pack_ternary_bitwise(W_ternary: torch.Tensor) -> torch.Tensor:
    """
    Alternative packing using 2 bits per ternary value.

    Simpler but less efficient than base-3 packing:
        -1 → 00
         0 → 01
        +1 → 10

    This uses 2 bits per value (4 values per byte) instead of optimal 1.58 bits,
    so it is ~20% less efficient than base-3 packing, but the pack/unpack
    arithmetic is plain shifts and masks.

    Args:
        W_ternary: Ternary weight tensor with values in {-1, 0, +1}

    Returns:
        Packed uint8 tensor (4 values per byte). The caller must retain
        W_ternary.shape to unpack (see unpack_ternary_bitwise).
    """
    # Map {-1, 0, 1} to the 2-bit codes {0, 1, 2}.
    codes = (W_ternary + 1).flatten().to(torch.uint8)

    # Pad to a multiple of 4 so each byte holds exactly 4 codes.
    pad = (4 - codes.numel() % 4) % 4
    if pad > 0:
        codes = torch.cat([codes, torch.zeros(pad, dtype=torch.uint8, device=codes.device)])
    codes = codes.view(-1, 4)

    # byte = c0 | c1<<2 | c2<<4 | c3<<6, expressed as multiply-by-{1,4,16,64}.
    # Integer .sum() promotes to int64, so cast back to uint8 (max value 170).
    bit_weights = torch.tensor([1, 4, 16, 64], dtype=torch.uint8, device=codes.device)
    return (codes * bit_weights).sum(dim=1).to(torch.uint8)


def unpack_ternary_bitwise(packed: torch.Tensor, original_shape: Tuple[int, ...]) -> torch.Tensor:
    """
    Unpack 2-bit encoded ternary weights (inverse of pack_ternary_bitwise).

    Args:
        packed: uint8 tensor holding four 2-bit codes per byte
        original_shape: Original shape of the ternary tensor

    Returns:
        float32 ternary tensor with values in {-1, 0, +1}
    """
    # Extract the four 2-bit codes per byte, least significant first.
    codes = [(packed // divisor) % 4 for divisor in (1, 4, 16, 64)]
    flat = torch.stack(codes, dim=1).flatten()

    # Truncate padding, map codes {0, 1, 2} back to {-1, 0, +1}, restore shape.
    numel = torch.Size(original_shape).numel()
    return (flat[:numel].to(torch.float32) - 1).view(original_shape)