File size: 7,017 Bytes
fd8c8b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 |
"""
Base-3 packing utilities for memory-efficient ternary weight storage.
Ternary weights ({-1, 0, +1}) can be represented in base-3, allowing
multiple ternary values to be packed into a single byte or integer.
This provides significant memory savings over storing each value as a float32.
Theoretical packing:
- 1 ternary value requires log2(3) ~= 1.58 bits
- 5 ternary values fit in 1 byte (3^5 = 243 < 256)
- Compression ratio: 32 bits (float32) -> 1.6 bits (packed, 5 values per byte) = 20x compression
"""
import torch
from typing import Tuple
def pack_ternary_base3(W_ternary: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, ...]]:
    """
    Pack ternary weights into base-3 representation for memory efficiency.

    Packs multiple ternary values ({-1, 0, +1}) into uint8 storage using base-3
    encoding. This achieves near-optimal compression for ternary data.

    Encoding scheme:
        -1 -> 0 (base 3)
         0 -> 1 (base 3)
        +1 -> 2 (base 3)
    Then pack 5 base-3 digits into one byte (least significant digit first):
        packed_byte = d0 + d1*3 + d2*9 + d3*27 + d4*81

    Args:
        W_ternary: Ternary weight tensor with values in {-1, 0, +1}, any shape
            (e.g. [out_features, in_features] or [k, out_features, in_features]).
            Elements are packed in row-major (flatten) order.

    Returns:
        packed: 1-D uint8 tensor holding ceil(numel / 5) bytes (5 values per byte)
        original_shape: Shape of original tensor for unpacking

    Raises:
        ValueError: If W_ternary contains any value other than -1, 0 or +1.

    Notes:
        - 5 ternary values per byte (3^5 = 243 < 256)
        - The flattened digits are padded with digit 0 up to a multiple of 5;
          the padding is dropped again on unpack using original_shape.
        - This is the primary memory optimization for ternary weights
    """
    original_shape = tuple(W_ternary.shape)

    # Anything other than exactly -1/0/+1 would be silently truncated or wrapped
    # by the uint8 cast below and decode to a different weight, so reject it.
    is_ternary = (W_ternary == -1) | (W_ternary == 0) | (W_ternary == 1)
    if not bool(is_ternary.all()):
        raise ValueError("W_ternary must contain only values in {-1, 0, +1}")

    # Map {-1, 0, 1} to base-3 digits {0, 1, 2}
    base3 = (W_ternary + 1).flatten().to(torch.uint8)

    # Pad to a multiple of 5 so every byte holds a full group of digits
    pad_size = (5 - base3.numel() % 5) % 5
    if pad_size > 0:
        base3 = torch.cat([base3, torch.zeros(pad_size, dtype=torch.uint8, device=base3.device)])

    # One row per output byte
    base3 = base3.view(-1, 5)

    # Pack each group: d0 + d1*3 + d2*9 + d3*27 + d4*81
    powers_of_3 = torch.tensor([1, 3, 9, 27, 81], dtype=torch.uint8, device=base3.device)
    # torch.sum promotes integer inputs to int64 unless a dtype is given, which
    # would spend 8 bytes per group. The largest packed value is
    # 2 * (1 + 3 + 9 + 27 + 81) = 242, so accumulating in uint8 is exact.
    packed = (base3 * powers_of_3).sum(dim=1, dtype=torch.uint8)
    return packed, original_shape
def unpack_ternary_base3(
    packed: torch.Tensor,
    original_shape: Tuple[int, ...],
) -> torch.Tensor:
    """
    Decode base-3 packed weights back into a ternary tensor.

    Each element of ``packed`` carries five base-3 digits, least significant
    first (place values 1, 3, 9, 27, 81). The digits are laid out in that order
    element by element, and any trailing padding digits beyond
    prod(original_shape) are discarded. Digits {0, 1, 2} are then shifted back
    to weights {-1, 0, +1}.

    Args:
        packed: Packed tensor (5 values per element); expected to be 1-D, since
            digits are stacked along dim 1 before flattening
        original_shape: Shape of the ternary tensor before packing

    Returns:
        W_ternary: float32 tensor of shape ``original_shape`` with values in {-1, 0, +1}
    """
    # Column k holds the k-th base-3 digit of every packed element.
    columns = [packed % 3]
    for place in (3, 9, 27, 81):
        columns.append((packed // place) % 3)
    digits = torch.stack(columns, dim=1).flatten()

    # Number of real weights; everything past this index is padding.
    count = 1
    for extent in original_shape:
        count *= extent

    weights = digits[:count].to(torch.float32) - 1
    return weights.view(original_shape)
def compute_compression_ratio(
    original_size: int,
    packed_size: int,
) -> float:
    """
    Return how many times smaller the packed weights are than the originals.

    Args:
        original_size: Size in bytes of the original float32 weights
        packed_size: Size in bytes of the packed ternary weights

    Returns:
        original_size / packed_size (e.g. 20.0 means 20x compression), or 0.0
        when packed_size is not positive, so no division by zero can occur.

    Examples:
        >>> # 512 x 512 float32 weights = 512*512*4 bytes = 1,048,576 bytes
        >>> # Packed: 512*512 ternary values / 5 per byte ~ 52,429 bytes
        >>> ratio = compute_compression_ratio(1048576, 52429)
        >>> print(f"Compression: {ratio:.1f}x")
        Compression: 20.0x
    """
    if packed_size > 0:
        return original_size / packed_size
    # Degenerate packed size: report 0.0 rather than dividing by it.
    return 0.0
def estimate_memory_savings(
    in_features: int,
    out_features: int,
    num_layers: int = 1,
) -> dict:
    """
    Estimate the memory saved by storing layers as base-3 packed ternary weights.

    Each layer is counted as an in_features x out_features weight matrix packed
    on its own. Any padding of the last, partial group of 5 values is therefore
    charged once per layer.

    Args:
        in_features: Input dimension
        out_features: Output dimension
        num_layers: Number of identical layers (for cumulative savings)

    Returns:
        Dictionary with memory statistics:
            - float32_bytes: Memory for float32 weights (4 bytes per weight)
            - packed_bytes: Memory for packed ternary weights (ceil(weights / 5) per layer)
            - savings_bytes: float32_bytes - packed_bytes
            - compression_ratio: float32_bytes / packed_bytes (0.0 if nothing is packed)

    Examples:
        >>> stats = estimate_memory_savings(768, 3072, num_layers=12)
        >>> print(f"Total savings: {stats['savings_bytes'] / 1e6:.1f} MB")
        Total savings: 107.6 MB
    """
    n_weights = in_features * out_features

    # float32 stores each weight in 4 bytes; base-3 packing stores 5 weights
    # per byte, rounded up to whole bytes.
    dense_per_layer = 4 * n_weights
    packed_per_layer = (n_weights + 4) // 5

    float32_bytes = dense_per_layer * num_layers
    packed_bytes = packed_per_layer * num_layers

    return {
        'float32_bytes': float32_bytes,
        'packed_bytes': packed_bytes,
        'savings_bytes': float32_bytes - packed_bytes,
        'compression_ratio': compute_compression_ratio(float32_bytes, packed_bytes),
    }
# Alternative packing scheme: 2 bits per ternary value (simpler bit arithmetic, lower compression than base-3)
def pack_ternary_bitwise(W_ternary: torch.Tensor) -> torch.Tensor:
    """
    Alternative packing using 2 bits per ternary value.

    Simpler but less efficient than base-3 packing:
        -1 -> 0b00
         0 -> 0b01
        +1 -> 0b10
    Four codes are stored per byte, least significant pair first:
        packed_byte = c0 | (c1 << 2) | (c2 << 4) | (c3 << 6)

    This uses 2 bits per value (4 values per byte) instead of the optimal
    ~1.58 bits. That is 20% fewer values per byte than base-3 packing (4 vs 5).
    In exchange, the format can be decoded with shifts and masks instead of
    integer division.

    Args:
        W_ternary: Ternary weight tensor of any shape with values in {-1, 0, +1}.
            Elements are packed in row-major (flatten) order.

    Returns:
        packed: 1-D uint8 tensor of ceil(numel / 4) bytes. The original shape
            is not returned; the caller must keep W_ternary.shape to unpack.

    Raises:
        ValueError: If W_ternary contains any value other than -1, 0 or +1.
    """
    # Anything other than exactly -1/0/+1 would be truncated or wrapped by the
    # uint8 cast below and decode to a different weight, so reject it.
    is_ternary = (W_ternary == -1) | (W_ternary == 0) | (W_ternary == 1)
    if not bool(is_ternary.all()):
        raise ValueError("W_ternary must contain only values in {-1, 0, +1}")

    # Map {-1, 0, +1} to 2-bit codes {0, 1, 2}.
    codes = (W_ternary + 1).flatten().to(torch.uint8)

    # Pad to a multiple of 4 so every byte is full; padding is dropped on unpack.
    pad_size = (4 - codes.numel() % 4) % 4
    if pad_size > 0:
        codes = torch.cat([codes, torch.zeros(pad_size, dtype=torch.uint8, device=codes.device)])

    groups = codes.view(-1, 4)
    # The largest possible byte is 0b10101010 = 170, so uint8 shifts cannot overflow.
    packed = groups[:, 0] | (groups[:, 1] << 2) | (groups[:, 2] << 4) | (groups[:, 3] << 6)
    return packed
def unpack_ternary_bitwise(packed: torch.Tensor, original_shape: Tuple[int, ...]) -> torch.Tensor:
    """
    Unpack 2-bit encoded ternary weights.

    Each byte holds four 2-bit codes, least significant pair first
    (bits 0-1, 2-3, 4-5, 6-7). Codes map back as:
        0b00 -> -1
        0b01 ->  0
        0b10 -> +1
    Padding codes beyond prod(original_shape) are discarded.

    Args:
        packed: Packed tensor (4 values per byte); flattened and cast to uint8
            before decoding
        original_shape: Shape of the original ternary tensor

    Returns:
        W_ternary: float32 tensor of shape original_shape with values in {-1, 0, +1}

    Raises:
        ValueError: If packed holds fewer values than prod(original_shape), or
            if a decoded (non-padding) value uses the unused code 0b11.
    """
    codes = packed.flatten().to(torch.uint8)

    # Split every byte into its four 2-bit fields, least significant first.
    fields = torch.stack([(codes >> shift) & 0b11 for shift in (0, 2, 4, 6)], dim=1).flatten()

    # Number of real weights; everything past this index is padding.
    numel = 1
    for dim in original_shape:
        numel *= dim
    if fields.numel() < numel:
        raise ValueError(
            f"packed holds {fields.numel()} values but original_shape needs {numel}"
        )
    fields = fields[:numel]

    # 0b11 is never produced by the encoder, so it indicates corrupt input.
    if bool((fields > 2).any()):
        raise ValueError("packed contains the invalid 2-bit code 0b11")

    return (fields.to(torch.float32) - 1).view(original_shape)
|