🚀 Refined BitTransformerLM: Organized codebase with best practices
Browse files
bit_transformer/BTLM_Extensions/rle_compression.py
ADDED
|
@@ -0,0 +1,660 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
RLE Compression Extension for BitTransformerLM
|
| 3 |
+
==============================================
|
| 4 |
+
|
| 5 |
+
Advanced Run-Length Encoding compression module with multiple encoding schemes,
|
| 6 |
+
adaptive compression, and training integration for BitTransformerLM.
|
| 7 |
+
|
| 8 |
+
Key features:
|
| 9 |
+
- Multiple RLE encoding schemes (basic, delta, hierarchical)
|
| 10 |
+
- Adaptive compression with quality thresholds
|
| 11 |
+
- Training integration with compression-aware loss
|
| 12 |
+
- Batch processing and vectorized operations
|
| 13 |
+
- Compatible with BitTransformerLM's training infrastructure
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
from typing import List, Tuple, Optional, Dict, Any, Union
|
| 19 |
+
import warnings
|
| 20 |
+
import math
|
| 21 |
+
from collections import defaultdict
|
| 22 |
+
import numpy as np
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class RLEEncoder:
    """
    Advanced Run-Length Encoder with multiple encoding schemes.

    Supports:
    - Basic RLE: (value, count) pairs
    - Delta RLE: basic RLE over successive differences
    - Hierarchical RLE: basic RLE applied recursively
    - Adaptive RLE: tries every scheme and keeps the smallest output

    NOTE(review): the encoded stream is not self-delimiting.  Short runs are
    stored as bare values interleaved with (value, count) pairs, and the
    decoder separates them with a heuristic (see ``decode_basic_rle``).  That
    heuristic is only reliable for small alphabets such as 0/1 bit data --
    verify before using this encoder on general byte data.
    """

    def __init__(
        self,
        scheme: str = "adaptive",
        min_run_length: int = 2,
        max_value: int = 255,
        delta_threshold: float = 0.7,
        hierarchical_levels: int = 2,
    ) -> None:
        """
        Args:
            scheme: Encoding scheme ('basic', 'delta', 'hierarchical', 'adaptive')
            min_run_length: Minimum run length to compress
            max_value: Maximum value for encoding (currently unused by the encoders)
            delta_threshold: Compression ratio threshold for delta encoding
                (currently unused by the encoders)
            hierarchical_levels: Maximum number of levels for hierarchical encoding
        """
        self.scheme = scheme
        self.min_run_length = min_run_length
        self.max_value = max_value
        self.delta_threshold = delta_threshold
        self.hierarchical_levels = hierarchical_levels

        # Number of RLE passes applied by the most recent hierarchical encode.
        # Recorded in encode() metadata so decode() can undo exactly that many
        # passes (bug fix: decode previously always assumed
        # ``hierarchical_levels`` passes even when encode stopped early).
        self._last_levels_applied = 0

        # Aggregate counters reported by get_compression_stats().
        self.stats = {
            "total_compressions": 0,
            "total_original_size": 0,
            "total_compressed_size": 0,
            "scheme_usage": defaultdict(int),
        }

    def encode_basic_rle(self, data: torch.Tensor) -> torch.Tensor:
        """Basic run-length encoding: (value, count) pairs.

        Runs shorter than ``min_run_length`` are emitted as bare values.
        Counts are capped at 255 so they fit the uint8 output; values outside
        0-255 would wrap when stored -- assumes small values (e.g. bits).
        """
        if data.numel() == 0:
            return torch.tensor([], dtype=torch.uint8)

        data_flat = data.flatten()
        encoded = []

        current_val = data_flat[0].item()
        current_count = 1

        for i in range(1, len(data_flat)):
            val = data_flat[i].item()
            if val == current_val and current_count < 255:
                current_count += 1
            else:
                if current_count >= self.min_run_length:
                    encoded.extend([current_val, current_count])
                else:
                    # Store individual values for short runs.
                    encoded.extend([current_val] * current_count)
                current_val = val
                current_count = 1

        # Flush the final run.
        if current_count >= self.min_run_length:
            encoded.extend([current_val, current_count])
        else:
            encoded.extend([current_val] * current_count)

        return torch.tensor(encoded, dtype=torch.uint8)

    def decode_basic_rle(self, encoded: torch.Tensor, target_length: Optional[int] = None) -> torch.Tensor:
        """Decode basic run-length encoded data.

        NOTE(review): because short runs are stored as bare values, a
        (value, count) pair is detected heuristically -- the element after a
        value is treated as a count when it is in 2..255.  For alphabets
        whose values can exceed 1 this is ambiguous and may mis-decode;
        reliable only for bit data.
        """
        if encoded.numel() == 0:
            return torch.tensor([], dtype=torch.long)

        decoded = []
        i = 0

        while i < len(encoded):
            if i + 1 < len(encoded):
                val = encoded[i].item()
                count = encoded[i + 1].item()

                # Heuristic: treat the next element as a run count.
                if count > 1 and count <= 255:
                    decoded.extend([val] * count)
                    i += 2
                else:
                    # Individual value.
                    decoded.append(val)
                    i += 1
            else:
                decoded.append(encoded[i].item())
                i += 1

        result = torch.tensor(decoded, dtype=torch.long)

        # Trim or zero-pad to the requested length if specified.
        if target_length is not None:
            if len(result) > target_length:
                result = result[:target_length]
            elif len(result) < target_length:
                result = F.pad(result, (0, target_length - len(result)))

        return result

    def encode_delta_rle(self, data: torch.Tensor) -> torch.Tensor:
        """Delta run-length encoding: RLE over successive differences.

        Stores the first element verbatim, followed by the basic-RLE stream of
        (delta + 128) values.  Deltas outside [-128, 127] are clamped and thus
        lossy -- fine for bit data where deltas are in {-1, 0, 1}.
        """
        if data.numel() <= 1:
            return self.encode_basic_rle(data)

        data_flat = data.flatten()

        # First delta is 0 by construction (we prepend the first element).
        deltas = torch.diff(data_flat, prepend=data_flat[0:1])

        # Shift into the 0-255 range so negatives survive uint8 storage.
        shifted_deltas = deltas + 128
        shifted_deltas = torch.clamp(shifted_deltas, 0, 255)

        delta_encoded = self.encode_basic_rle(shifted_deltas)

        # Prepend the original first value so decode can rebuild the series.
        return torch.cat([data_flat[0:1].to(torch.uint8), delta_encoded])

    def decode_delta_rle(self, encoded: torch.Tensor, target_length: Optional[int] = None) -> torch.Tensor:
        """Decode delta run-length encoded data (inverse of encode_delta_rle)."""
        if encoded.numel() <= 1:
            return self.decode_basic_rle(encoded, target_length)

        # First element is the original first value, the rest is RLE'd deltas.
        first_val = encoded[0].item()
        delta_encoded = encoded[1:]

        deltas = self.decode_basic_rle(delta_encoded)

        # Undo the +128 shift applied at encode time.
        deltas = deltas.float() - 128

        # Reconstruct the original sequence via a running sum.
        if deltas.numel() > 0:
            deltas[0] = first_val  # Replace the (zero) first delta with the seed value.
            result = torch.cumsum(deltas, dim=0).long()
        else:
            result = torch.tensor([first_val], dtype=torch.long)

        # Trim or zero-pad to the requested length.
        if target_length is not None:
            if len(result) > target_length:
                result = result[:target_length]
            elif len(result) < target_length:
                result = F.pad(result, (0, target_length - len(result)))

        return result

    def encode_hierarchical_rle(self, data: torch.Tensor) -> torch.Tensor:
        """Hierarchical RLE: apply basic RLE repeatedly while it pays off.

        Stops early as soon as a pass saves less than 10%.  The number of
        passes actually applied is stored in ``self._last_levels_applied`` and
        surfaced through encode() metadata so decoding can be symmetric.
        """
        current_data = data.clone()
        levels_applied = 0

        for _ in range(self.hierarchical_levels):
            encoded = self.encode_basic_rle(current_data)

            # Stop when this pass is no longer beneficial.
            if encoded.numel() >= current_data.numel() * 0.9:
                break

            current_data = encoded
            levels_applied += 1

        self._last_levels_applied = levels_applied
        return current_data

    def decode_hierarchical_rle(self, encoded: torch.Tensor, target_length: Optional[int] = None, levels: Optional[int] = None) -> torch.Tensor:
        """Decode hierarchical RLE data.

        ``levels`` must match the number of passes applied at encode time;
        when omitted it falls back to the configured maximum (the pre-fix
        behavior, which can over-decode when encode stopped early).
        """
        if levels is None:
            levels = self.hierarchical_levels

        current_data = encoded.clone()

        for _ in range(levels):
            try:
                current_data = self.decode_basic_rle(current_data)
            except Exception:
                # If a pass cannot be undone, return the best state we have.
                break

        # Final length adjustment.
        if target_length is not None and current_data.numel() != target_length:
            if current_data.numel() > target_length:
                current_data = current_data[:target_length]
            else:
                current_data = F.pad(current_data, (0, target_length - current_data.numel()))

        return current_data

    def encode(self, data: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """
        Encode data using the configured scheme.

        Args:
            data: Input tensor to compress

        Returns:
            Tuple of (encoded_data, metadata).  Metadata records the scheme
            used, original shape/size, compressed size, compression ratio,
            and -- for hierarchical encoding -- the number of levels applied.
        """
        original_shape = data.shape
        original_size = data.numel()

        if self.scheme == "basic":
            encoded = self.encode_basic_rle(data)
            scheme_used = "basic"
        elif self.scheme == "delta":
            encoded = self.encode_delta_rle(data)
            scheme_used = "delta"
        elif self.scheme == "hierarchical":
            encoded = self.encode_hierarchical_rle(data)
            scheme_used = "hierarchical"
        elif self.scheme == "adaptive":
            # Try all schemes and keep whichever compresses best.
            candidates = {
                "basic": self.encode_basic_rle(data),
                "delta": self.encode_delta_rle(data),
                "hierarchical": self.encode_hierarchical_rle(data),
            }

            best_scheme = min(candidates.keys(), key=lambda k: candidates[k].numel())
            encoded = candidates[best_scheme]
            scheme_used = best_scheme
        else:
            raise ValueError(f"Unknown encoding scheme: {self.scheme}")

        # Update running statistics.
        self.stats["total_compressions"] += 1
        self.stats["total_original_size"] += original_size
        self.stats["total_compressed_size"] += encoded.numel()
        self.stats["scheme_usage"][scheme_used] += 1

        metadata = {
            "scheme": scheme_used,
            "original_shape": original_shape,
            "original_size": original_size,
            "compressed_size": encoded.numel(),
            "compression_ratio": encoded.numel() / original_size if original_size > 0 else 1.0,
        }

        # Bug fix: record how many hierarchical passes were actually applied
        # so decode() undoes exactly that many.
        if scheme_used == "hierarchical":
            metadata["levels_applied"] = self._last_levels_applied

        return encoded, metadata

    def decode(self, encoded: torch.Tensor, metadata: Dict[str, Any]) -> torch.Tensor:
        """
        Decode compressed data using metadata.

        Args:
            encoded: Compressed data
            metadata: Metadata from encoding

        Returns:
            Decoded tensor, reshaped to the original shape when possible.

        Raises:
            ValueError: If the metadata names an unknown scheme.
        """
        scheme = metadata["scheme"]
        original_shape = metadata["original_shape"]
        target_length = math.prod(original_shape) if original_shape else None

        if scheme == "basic":
            decoded = self.decode_basic_rle(encoded, target_length)
        elif scheme == "delta":
            decoded = self.decode_delta_rle(encoded, target_length)
        elif scheme == "hierarchical":
            # Undo exactly the number of passes applied at encode time;
            # older metadata without the key falls back to the configured max.
            decoded = self.decode_hierarchical_rle(
                encoded,
                target_length,
                levels=metadata.get("levels_applied", self.hierarchical_levels),
            )
        else:
            raise ValueError(f"Unknown decoding scheme: {scheme}")

        # Reshape to the original shape (left flat if too few elements).
        if original_shape and decoded.numel() >= math.prod(original_shape):
            decoded = decoded[:math.prod(original_shape)].reshape(original_shape)

        return decoded

    def get_compression_stats(self) -> Dict[str, float]:
        """Get aggregate compression statistics across all encode() calls."""
        if self.stats["total_original_size"] == 0:
            return {"average_compression_ratio": 1.0, "total_savings": 0.0}

        avg_ratio = self.stats["total_compressed_size"] / self.stats["total_original_size"]
        total_savings = self.stats["total_original_size"] - self.stats["total_compressed_size"]

        return {
            "average_compression_ratio": avg_ratio,
            "total_savings": total_savings,
            "total_compressions": self.stats["total_compressions"],
            "scheme_usage": dict(self.stats["scheme_usage"]),
        }
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
class CompressedBitDataset(torch.utils.data.Dataset):
    """
    Dataset wrapper that applies RLE compression on-the-fly during training.

    This allows for memory-efficient storage of large bit sequences while
    maintaining fast access during training.
    """

    def __init__(
        self,
        data: torch.Tensor,
        encoder: "RLEEncoder",
        compress_probability: float = 0.5,
        cache_size: int = 1000,
    ) -> None:
        """
        Args:
            data: Original bit sequence data
            encoder: RLE encoder instance
            compress_probability: Probability of returning compressed data
            cache_size: Number of compressed items to cache
        """
        self.data = data
        self.encoder = encoder
        self.compress_probability = compress_probability
        self.cache_size = cache_size
        # idx -> (compressed_tensor, metadata) for already-compressed items.
        self.cache = {}
        # idx -> number of cache hits; drives least-accessed eviction.
        self.access_count = defaultdict(int)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """
        Get item with optional compression.

        Returns:
            Tuple of (data, metadata).  Metadata always carries a
            "compressed" flag and a "from_cache" flag.
        """
        original_item = self.data[idx]

        # Randomly decide whether to serve a compressed view of this item.
        if torch.rand(1).item() < self.compress_probability:
            # Check cache first.
            if idx in self.cache:
                compressed, cached_meta = self.cache[idx]
                self.access_count[idx] += 1
                # Copy so per-call flags never mutate the cached metadata.
                metadata = dict(cached_meta)
                metadata["from_cache"] = True
                return compressed, metadata

            # Compress item.
            compressed, metadata = self.encoder.encode(original_item)
            # Bug fix: mark the item as compressed so downstream consumers
            # (e.g. the compression-aware loss) can recognize it.
            metadata["compressed"] = True

            if len(self.cache) < self.cache_size:
                self.cache[idx] = (compressed, dict(metadata))
            elif self.cache:
                # Cache is full: evict the least-accessed cached item.
                # (Bug fix: previously nothing was cached at all until the
                # first cache hit had populated access_count.)
                least_accessed = min(self.cache.keys(), key=lambda k: self.access_count[k])
                del self.cache[least_accessed]
                self.access_count.pop(least_accessed, None)
                self.cache[idx] = (compressed, dict(metadata))

            metadata["from_cache"] = False
            return compressed, metadata
        else:
            # Return original, uncompressed data.
            metadata = {
                "scheme": "uncompressed",
                "original_shape": original_item.shape,
                "compressed": False,
                "from_cache": False,
            }
            return original_item, metadata
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
def create_compression_aware_loss(
    base_loss_fn,
    compression_penalty: float = 0.01,
    quality_threshold: float = 0.8,
) -> callable:
    """
    Create a loss function that penalizes poor compression quality.

    Args:
        base_loss_fn: Base loss function (e.g., CrossEntropyLoss)
        compression_penalty: Penalty weight for compression artifacts
        quality_threshold: Minimum compression quality threshold

    Returns:
        Compression-aware loss function
    """
    def compression_aware_loss(
        logits: torch.Tensor,
        targets: torch.Tensor,
        metadata_batch: Optional[List[Dict[str, Any]]] = None,
    ) -> torch.Tensor:
        """
        Compute loss with compression quality penalty.

        Args:
            logits: Model output logits
            targets: Target labels
            metadata_batch: Batch of compression metadata

        Returns:
            Adjusted loss tensor
        """
        base_loss = base_loss_fn(logits, targets)

        if metadata_batch is None:
            return base_loss

        # Compute compression quality penalty.
        penalty = 0.0
        compressed_items = 0

        for metadata in metadata_batch:
            # Bug fix: encoder metadata identifies compression via its
            # "scheme" key, not necessarily a "compressed" flag, so honor
            # either -- otherwise the penalty path is never taken.
            scheme = metadata.get("scheme", "uncompressed")
            if metadata.get("compressed", scheme != "uncompressed"):
                compressed_items += 1
                compression_ratio = metadata.get("compression_ratio", 1.0)

                # Quadratic penalty once the ratio exceeds the threshold
                # (ratio = compressed/original, so higher is worse).
                if compression_ratio > quality_threshold:
                    penalty += (compression_ratio - quality_threshold) ** 2

        if compressed_items > 0:
            # Average the penalty over compressed items before weighting.
            total_loss = base_loss + compression_penalty * (penalty / compressed_items)
        else:
            total_loss = base_loss

        return total_loss

    return compression_aware_loss
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
def integrate_rle_with_training(
    model,
    data: torch.Tensor,
    encoder_config: Optional[Dict[str, Any]] = None,
    compression_config: Optional[Dict[str, Any]] = None,
) -> Tuple[CompressedBitDataset, callable]:
    """
    Wire RLE compression into BitTransformerLM training.

    Args:
        model: BitTransformerLM model
        data: Training data tensor
        encoder_config: Configuration for the RLE encoder
        compression_config: Configuration for compression-aware training

    Returns:
        Tuple of (compressed_dataset, compression_aware_loss_fn)
    """
    # Fill in sensible defaults when no explicit configuration is given.
    if encoder_config is None:
        encoder_config = {
            "scheme": "adaptive",
            "min_run_length": 2,
            "delta_threshold": 0.7,
        }

    if compression_config is None:
        compression_config = {
            "compress_probability": 0.3,
            "compression_penalty": 0.01,
            "quality_threshold": 0.8,
            "cache_size": 1000,
        }

    # Build the encoder and wrap the data in a compressing dataset.
    rle_encoder = RLEEncoder(**encoder_config)
    compressed_dataset = CompressedBitDataset(
        data,
        rle_encoder,
        compress_probability=compression_config["compress_probability"],
        cache_size=compression_config["cache_size"],
    )

    # Pair it with a loss that penalizes poorly compressible batches.
    criterion = torch.nn.CrossEntropyLoss()
    aware_loss = create_compression_aware_loss(
        criterion,
        compression_penalty=compression_config["compression_penalty"],
        quality_threshold=compression_config["quality_threshold"],
    )

    return compressed_dataset, aware_loss
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
def benchmark_compression_schemes(
    test_data: torch.Tensor,
    schemes: Optional[List[str]] = None,
) -> Dict[str, Dict[str, float]]:
    """
    Benchmark different compression schemes on test data.

    Args:
        test_data: Test data tensor
        schemes: List of schemes to test; defaults to all four built-in
            schemes.  (Bug fix: the default was previously a mutable list
            literal in the signature.)

    Returns:
        Dictionary with benchmark results for each scheme; failed schemes
        report success=False with the error message instead of raising.
    """
    if schemes is None:
        schemes = ["basic", "delta", "hierarchical", "adaptive"]

    results = {}

    for scheme in schemes:
        encoder = RLEEncoder(scheme=scheme)

        # Round-trip each scheme; any failure is recorded, not raised.
        try:
            compressed, metadata = encoder.encode(test_data)
            reconstructed = encoder.decode(compressed, metadata)

            # Compute quality metrics (MSE against the original).
            compression_ratio = compressed.numel() / test_data.numel()
            reconstruction_error = torch.mean(
                (test_data.float() - reconstructed.float()) ** 2
            ).item()

            results[scheme] = {
                "compression_ratio": compression_ratio,
                "reconstruction_error": reconstruction_error,
                "compressed_size": compressed.numel(),
                "original_size": test_data.numel(),
                "success": True,
            }
        except Exception as e:
            # Deliberately broad: a benchmark should survive any scheme crash.
            results[scheme] = {
                "compression_ratio": 1.0,
                "reconstruction_error": float("inf"),
                "compressed_size": test_data.numel(),
                "original_size": test_data.numel(),
                "success": False,
                "error": str(e),
            }

    return results
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
# Example usage and utilities
|
| 569 |
+
def create_rle_training_config(
    scheme: str = "adaptive",
    compress_probability: float = 0.3,
    compression_penalty: float = 0.01,
    **kwargs
) -> Dict[str, Any]:
    """
    Build the configuration dict for RLE-enhanced training.

    Args:
        scheme: RLE encoding scheme
        compress_probability: Probability of compression during training
        compression_penalty: Loss penalty for compression artifacts
        **kwargs: Optional overrides (min_run_length, delta_threshold,
            hierarchical_levels, quality_threshold, cache_size)

    Returns:
        Dictionary with RLE training configuration
    """
    # Encoder-side settings, with kwargs overriding the defaults.
    encoder_options = {
        "scheme": scheme,
        "min_run_length": kwargs.get("min_run_length", 2),
        "delta_threshold": kwargs.get("delta_threshold", 0.7),
        "hierarchical_levels": kwargs.get("hierarchical_levels", 2),
    }

    # Training-loop settings.
    training_options = {
        "compress_probability": compress_probability,
        "compression_penalty": compression_penalty,
        "quality_threshold": kwargs.get("quality_threshold", 0.8),
        "cache_size": kwargs.get("cache_size", 1000),
    }

    return {
        "compression_type": "rle",
        "encoder_config": encoder_options,
        "training_config": training_options,
    }
|
| 604 |
+
|
| 605 |
+
|
| 606 |
+
if __name__ == "__main__":
    # Smoke-test the RLE compression module end to end.
    print("Testing RLE Compression Module...")

    # Create test data.
    # NOTE(review): randint is unseeded, so the printed sizes/ratios differ
    # between runs; only the "perfect reconstruction" checks are meaningful.
    test_data = torch.randint(0, 2, (100,))

    # Add some runs for better compression
    test_data[20:30] = 1
    test_data[50:70] = 0
    test_data[80:90] = 1

    print(f"Original data shape: {test_data.shape}")
    print(f"Original data: {test_data[:20]}...")

    # Round-trip every encoding scheme and check losslessness.
    schemes = ["basic", "delta", "hierarchical", "adaptive"]

    for scheme in schemes:
        print(f"\nTesting {scheme} scheme:")
        encoder = RLEEncoder(scheme=scheme)

        try:
            # Encode
            compressed, metadata = encoder.encode(test_data)
            print(f"  Compressed size: {compressed.numel()}")
            print(f"  Compression ratio: {metadata['compression_ratio']:.3f}")

            # Decode
            reconstructed = encoder.decode(compressed, metadata)

            # Check reconstruction quality (MSE; 0 means lossless).
            error = torch.mean((test_data.float() - reconstructed.float()) ** 2)
            print(f"  Reconstruction error: {error.item():.6f}")

            if error.item() < 1e-6:
                print("  ✅ Perfect reconstruction")
            else:
                print("  ❌ Reconstruction error detected")

        except Exception as e:
            print(f"  ❌ Error: {e}")

    # Benchmark all schemes on the same data.
    print("\nBenchmarking compression schemes...")
    benchmark_results = benchmark_compression_schemes(test_data)

    for scheme, results in benchmark_results.items():
        if results["success"]:
            print(f"{scheme:12}: ratio={results['compression_ratio']:.3f}, "
                  f"error={results['reconstruction_error']:.6f}")
        else:
            print(f"{scheme:12}: FAILED - {results.get('error', 'Unknown error')}")

    print("\nRLE Compression Module test completed!")
|