""" Base-3 packing utilities for memory-efficient ternary weight storage. Ternary weights ({-1, 0, +1}) can be represented in base-3, allowing multiple ternary values to be packed into a single byte or integer. This provides significant memory savings over storing each value as a float32. Theoretical packing: - 1 ternary value requires log2(3) ≈ 1.58 bits - 5 ternary values fit in 1 byte (3^5 = 243 < 256) - Compression ratio: 32 bits (float) → ~1.6 bits (packed) = 20x compression """ import torch from typing import Tuple def pack_ternary_base3(W_ternary: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, ...]]: """ Pack ternary weights into base-3 representation for memory efficiency. Packs multiple ternary values ({-1, 0, +1}) into uint8 storage using base-3 encoding. This achieves near-optimal compression for ternary data. Encoding scheme: -1 → 0 (base 3) 0 → 1 (base 3) +1 → 2 (base 3) Then pack 5 base-3 digits into one byte: packed_byte = d0 + d1*3 + d2*9 + d3*27 + d4*81 Args: W_ternary: Ternary weight tensor with values in {-1, 0, +1} Shape: [out_features, in_features] or [k, out_features, in_features] Returns: packed: Packed weights as uint8 tensor (5 values per byte) original_shape: Shape of original tensor for unpacking Notes: - 5 ternary values per byte (3^5 = 243 < 256) - Pad with zeros if dimensions not divisible by 5 - This is the primary memory optimization for ternary weights """ original_shape = tuple(W_ternary.shape) # Map {-1, 0, 1} to {0, 1, 2} base3 = (W_ternary + 1).flatten().to(torch.uint8) # Pad to multiple of 5 numel = base3.numel() pad_size = (5 - numel % 5) % 5 if pad_size > 0: base3 = torch.cat([base3, torch.zeros(pad_size, dtype=torch.uint8, device=base3.device)]) # Reshape into groups of 5 base3 = base3.view(-1, 5) # Pack each group: d0 + d1*3 + d2*9 + d3*27 + d4*81 powers_of_3 = torch.tensor([1, 3, 9, 27, 81], dtype=torch.uint8, device=base3.device) packed = (base3 * powers_of_3).sum(dim=1) return packed, original_shape def unpack_ternary_base3( packed: torch.Tensor, original_shape: Tuple[int, ...], ) -> torch.Tensor: """ Unpack base-3 encoded ternary weights back to full representation. Reverses the packing operation to recover ternary weights. Args: packed: Packed uint8 tensor (5 values per byte) original_shape: Original shape of the ternary tensor Returns: W_ternary: Ternary weight tensor with values in {-1, 0, +1} """ # Extract 5 base-3 digits from each byte d0 = packed % 3 d1 = (packed // 3) % 3 d2 = (packed // 9) % 3 d3 = (packed // 27) % 3 d4 = (packed // 81) % 3 # Stack digits base3 = torch.stack([d0, d1, d2, d3, d4], dim=1).flatten() # Compute original number of elements numel = 1 for dim in original_shape: numel *= dim # Truncate padding base3 = base3[:numel] # Map {0, 1, 2} back to {-1, 0, +1} W_ternary = base3.to(torch.float32) - 1 # Reshape to original shape W_ternary = W_ternary.view(original_shape) return W_ternary def compute_compression_ratio( original_size: int, packed_size: int, ) -> float: """ Compute compression ratio for packed ternary weights. 

def compute_compression_ratio(
    original_size: int,
    packed_size: int,
) -> float:
    """
    Compute the compression ratio for packed ternary weights.

    Args:
        original_size: Size in bytes of original float32 weights
        packed_size: Size in bytes of packed ternary weights

    Returns:
        Compression ratio (e.g., 20.0 means 20x compression)

    Examples:
        >>> # 512 x 512 float32 weights = 512*512*4 bytes = 1,048,576 bytes
        >>> # Packed: 512*512 ternary values / 5 per byte ≈ 52,429 bytes
        >>> ratio = compute_compression_ratio(1048576, 52429)
        >>> print(f"Compression: {ratio:.1f}x")
        Compression: 20.0x
    """
    return original_size / packed_size if packed_size > 0 else 0.0


def estimate_memory_savings(
    in_features: int,
    out_features: int,
    num_layers: int = 1,
) -> dict:
    """
    Estimate memory savings from ternary packing for a given layer configuration.

    Args:
        in_features: Input dimension
        out_features: Output dimension
        num_layers: Number of layers (for cumulative savings)

    Returns:
        Dictionary with memory statistics:
            - float32_bytes: Memory for float32 weights
            - packed_bytes: Memory for packed ternary weights
            - savings_bytes: Absolute memory saved
            - compression_ratio: Ratio of compression

    Examples:
        >>> stats = estimate_memory_savings(768, 3072, num_layers=12)
        >>> print(f"Total savings: {stats['savings_bytes'] / 1e6:.1f} MB")
    """
    # Calculate float32 weight size
    weights_per_layer = in_features * out_features
    float32_bytes_per_layer = weights_per_layer * 4  # 4 bytes per float32

    # Calculate packed size (5 ternary values per byte, ceiling division)
    packed_bytes_per_layer = (weights_per_layer + 4) // 5

    # Scale by number of layers
    float32_bytes = float32_bytes_per_layer * num_layers
    packed_bytes = packed_bytes_per_layer * num_layers

    # Calculate savings
    savings_bytes = float32_bytes - packed_bytes
    compression_ratio = compute_compression_ratio(float32_bytes, packed_bytes)

    return {
        'float32_bytes': float32_bytes,
        'packed_bytes': packed_bytes,
        'savings_bytes': savings_bytes,
        'compression_ratio': compression_ratio,
    }


# Alternative 2-bit packing scheme (simpler, but less compact than base-3)

def pack_ternary_bitwise(W_ternary: torch.Tensor) -> torch.Tensor:
    """
    Alternative packing using 2 bits per ternary value.

    Simpler but less efficient than base-3 packing:
        -1 → 00
         0 → 01
        +1 → 10

    This uses 2 bits per value (4 values per byte) instead of the optimal
    ~1.58 bits: easier to implement, but 20% less compact than base-3
    packing. Useful as a speed vs. compression trade-off.

    Args:
        W_ternary: Ternary weight tensor with values in {-1, 0, +1}

    Returns:
        packed: Packed weights as uint8 tensor (4 values per byte)
    """
    # Map {-1, 0, 1} to 2-bit codes {0, 1, 2}
    codes = (W_ternary + 1).flatten().to(torch.uint8)

    # Pad to a multiple of 4 (pad codes are discarded on unpacking)
    pad_size = (4 - codes.numel() % 4) % 4
    if pad_size > 0:
        codes = torch.cat([codes, torch.zeros(pad_size, dtype=torch.uint8, device=codes.device)])

    # Pack 4 two-bit codes into each byte via left shifts
    codes = codes.view(-1, 4)
    packed = codes[:, 0] | (codes[:, 1] << 2) | (codes[:, 2] << 4) | (codes[:, 3] << 6)
    return packed


def unpack_ternary_bitwise(packed: torch.Tensor, original_shape: Tuple[int, ...]) -> torch.Tensor:
    """
    Unpack 2-bit encoded ternary weights.

    Args:
        packed: Packed uint8 tensor (4 values per byte)
        original_shape: Original shape of the ternary tensor

    Returns:
        W_ternary: Ternary weight tensor with values in {-1, 0, +1}
    """
    # Extract the 4 two-bit codes from each byte
    codes = torch.stack([(packed >> shift) & 3 for shift in (0, 2, 4, 6)], dim=1).flatten()

    # Truncate padding
    numel = 1
    for dim in original_shape:
        numel *= dim
    codes = codes[:numel]

    # Map {0, 1, 2} back to {-1, 0, +1} and restore the original shape
    return (codes.to(torch.float32) - 1).view(original_shape)
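

# Hypothetical usage sketch (this __main__ block is illustrative, not part of
# the original module): run the round-trip check and estimate savings for a
# 12-layer stack of 768 x 3072 projections, matching the docstring example.
if __name__ == "__main__":
    _example_roundtrip()  # lossless round-trip sanity check

    stats = estimate_memory_savings(768, 3072, num_layers=12)
    print(f"float32:     {stats['float32_bytes'] / 1e6:.1f} MB")
    print(f"packed:      {stats['packed_bytes'] / 1e6:.2f} MB")
    print(f"savings:     {stats['savings_bytes'] / 1e6:.1f} MB")
    print(f"compression: {stats['compression_ratio']:.1f}x")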