| """ |
| RLE Compression Extension for BitTransformerLM |
| ============================================== |
| |
| Advanced Run-Length Encoding compression module with multiple encoding schemes, |
| adaptive compression, and training integration for BitTransformerLM. |
| |
| Key features: |
| - Multiple RLE encoding schemes (basic, delta, hierarchical) |
| - Adaptive compression with quality thresholds |
| - Training integration with compression-aware loss |
| - Batch processing and vectorized operations |
| - Compatible with BitTransformerLM's training infrastructure |
| """ |
|
|
| import torch |
| import torch.nn.functional as F |
| from typing import List, Tuple, Optional, Dict, Any, Union |
| import warnings |
| import math |
| from collections import defaultdict |
| import numpy as np |
|
|
|
|
class RLEEncoder:
    """
    Advanced Run-Length Encoder with multiple encoding schemes.

    Supports:
    - Basic RLE: a flat uint8 stream of complete (value, count) pairs
    - Delta RLE: basic RLE applied to consecutive differences
    - Hierarchical RLE: basic RLE applied recursively, with the number of
      levels actually applied stored in a one-byte header
    - Adaptive RLE: encodes with every scheme and keeps the smallest result

    NOTE: The stream format always uses complete (value, count) pairs.
    The previous format emitted short runs as bare literals, which made the
    stream ambiguous: the decoder had to guess whether a byte was a literal
    or a run count and silently corrupted data whenever a literal value
    happened to look like a count (routine for delta streams, whose symbols
    cluster around 128).
    """

    def __init__(
        self,
        scheme: str = "adaptive",
        min_run_length: int = 2,
        max_value: int = 255,
        delta_threshold: float = 0.7,
        hierarchical_levels: int = 2,
    ):
        """
        Args:
            scheme: Encoding scheme ('basic', 'delta', 'hierarchical', 'adaptive')
            min_run_length: Kept for backward compatibility. The pair-based
                stream no longer emits bare literals, so this is advisory only.
            max_value: Maximum representable symbol value (stream is uint8)
            delta_threshold: Compression ratio threshold for delta encoding
            hierarchical_levels: Maximum recursion depth for hierarchical encoding
        """
        self.scheme = scheme
        self.min_run_length = min_run_length
        self.max_value = max_value
        self.delta_threshold = delta_threshold
        self.hierarchical_levels = hierarchical_levels

        # Running totals reported by get_compression_stats().
        self.stats = {
            "total_compressions": 0,
            "total_original_size": 0,
            "total_compressed_size": 0,
            "scheme_usage": defaultdict(int),
        }

    @staticmethod
    def _fit_length(result: torch.Tensor, target_length: Optional[int]) -> torch.Tensor:
        """Truncate or zero-pad a 1-D tensor to target_length (no-op if None)."""
        if target_length is None or result.numel() == target_length:
            return result
        if result.numel() > target_length:
            return result[:target_length]
        pad = torch.zeros(target_length - result.numel(), dtype=result.dtype)
        return torch.cat([result, pad])

    def encode_basic_rle(self, data: torch.Tensor) -> torch.Tensor:
        """Basic run-length encoding as an unambiguous (value, count) pair stream.

        Every symbol is emitted as a pair (count may be 1) and runs longer
        than 255 are split across pairs, so the stream decodes without any
        literal-vs-count guessing.
        """
        if data.numel() == 0:
            return torch.tensor([], dtype=torch.uint8)

        values = data.flatten().tolist()
        encoded: List[int] = []

        run_value = values[0]
        run_length = 1
        for value in values[1:]:
            if value == run_value and run_length < 255:
                run_length += 1
            else:
                encoded.extend([run_value, run_length])
                run_value = value
                run_length = 1
        # Flush the final run.
        encoded.extend([run_value, run_length])

        return torch.tensor(encoded, dtype=torch.uint8)

    def decode_basic_rle(self, encoded: torch.Tensor, target_length: Optional[int] = None) -> torch.Tensor:
        """Decode a (value, count) pair stream produced by encode_basic_rle."""
        pairs = encoded.flatten().tolist()
        decoded: List[int] = []

        # Walk the stream two bytes at a time; a dangling trailing byte
        # (malformed input) is ignored rather than misread as a literal.
        for i in range(0, len(pairs) - 1, 2):
            decoded.extend([pairs[i]] * pairs[i + 1])

        result = torch.tensor(decoded, dtype=torch.long)
        return self._fit_length(result, target_length)

    def encode_delta_rle(self, data: torch.Tensor) -> torch.Tensor:
        """Delta RLE: first value verbatim, then basic RLE of shifted differences.

        Differences are shifted by +128 to fit in uint8; deltas outside
        [-128, 127] are clamped (lossless for bit data, lossy otherwise).
        """
        if data.numel() <= 1:
            return self.encode_basic_rle(data)

        data_flat = data.flatten()
        deltas = torch.diff(data_flat, prepend=data_flat[0:1])
        shifted = torch.clamp(deltas + 128, 0, 255)

        return torch.cat([data_flat[0:1].to(torch.uint8), self.encode_basic_rle(shifted)])

    def decode_delta_rle(self, encoded: torch.Tensor, target_length: Optional[int] = None) -> torch.Tensor:
        """Decode a delta-RLE stream (first value followed by RLE-coded deltas)."""
        if encoded.numel() <= 1:
            return self.decode_basic_rle(encoded, target_length)

        first_val = encoded[0].item()
        deltas = self.decode_basic_rle(encoded[1:]) - 128

        if deltas.numel() > 0:
            # The first delta slot is a placeholder (prepend made it zero);
            # seed it with the true start value so cumsum yields absolutes.
            deltas[0] = first_val
            result = torch.cumsum(deltas, dim=0)
        else:
            result = torch.tensor([first_val], dtype=torch.long)

        return self._fit_length(result, target_length)

    def encode_hierarchical_rle(self, data: torch.Tensor) -> torch.Tensor:
        """Hierarchical RLE: apply basic RLE repeatedly while it keeps shrinking.

        The number of levels actually applied is stored as a one-byte header
        so the decoder never has to guess. (The old format stopped early but
        the decoder always peeled a fixed number of levels, corrupting the
        reconstruction whenever the two disagreed.)
        """
        current = data.flatten().clone()
        applied = 0

        for _ in range(min(self.hierarchical_levels, 255)):
            candidate = self.encode_basic_rle(current)
            # Stop once another level would save less than 10%.
            if candidate.numel() >= current.numel() * 0.9:
                break
            current = candidate
            applied += 1

        header = torch.tensor([applied], dtype=torch.uint8)
        return torch.cat([header, current.to(torch.uint8)])

    def decode_hierarchical_rle(self, encoded: torch.Tensor, target_length: Optional[int] = None, levels: int = None) -> torch.Tensor:
        """Decode hierarchical RLE; `levels` overrides the stored level header."""
        if encoded.numel() == 0:
            return self._fit_length(torch.tensor([], dtype=torch.long), target_length)

        stored_levels = encoded[0].item()
        current = encoded[1:].to(torch.long)

        for _ in range(stored_levels if levels is None else levels):
            current = self.decode_basic_rle(current)

        return self._fit_length(current, target_length)

    def encode(self, data: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """
        Encode data using the configured scheme.

        Args:
            data: Input tensor to compress

        Returns:
            Tuple of (encoded_data, metadata); metadata records the scheme
            used, original shape/size, and the achieved compression ratio.
        """
        original_shape = data.shape
        original_size = data.numel()

        if self.scheme == "basic":
            encoded = self.encode_basic_rle(data)
            scheme_used = "basic"
        elif self.scheme == "delta":
            encoded = self.encode_delta_rle(data)
            scheme_used = "delta"
        elif self.scheme == "hierarchical":
            encoded = self.encode_hierarchical_rle(data)
            scheme_used = "hierarchical"
        elif self.scheme == "adaptive":
            # Try everything and keep the smallest output.
            candidates = {
                "basic": self.encode_basic_rle(data),
                "delta": self.encode_delta_rle(data),
                "hierarchical": self.encode_hierarchical_rle(data),
            }
            scheme_used = min(candidates.keys(), key=lambda k: candidates[k].numel())
            encoded = candidates[scheme_used]
        else:
            raise ValueError(f"Unknown encoding scheme: {self.scheme}")

        self.stats["total_compressions"] += 1
        self.stats["total_original_size"] += original_size
        self.stats["total_compressed_size"] += encoded.numel()
        self.stats["scheme_usage"][scheme_used] += 1

        metadata = {
            "scheme": scheme_used,
            "original_shape": original_shape,
            "original_size": original_size,
            "compressed_size": encoded.numel(),
            "compression_ratio": encoded.numel() / original_size if original_size > 0 else 1.0,
        }

        return encoded, metadata

    def decode(self, encoded: torch.Tensor, metadata: Dict[str, Any]) -> torch.Tensor:
        """
        Decode compressed data using metadata.

        Args:
            encoded: Compressed data
            metadata: Metadata from encoding

        Returns:
            Decoded tensor, reshaped to the original shape when possible
        """
        scheme = metadata["scheme"]
        original_shape = metadata["original_shape"]
        target_length = math.prod(original_shape) if original_shape else None

        if scheme == "basic":
            decoded = self.decode_basic_rle(encoded, target_length)
        elif scheme == "delta":
            decoded = self.decode_delta_rle(encoded, target_length)
        elif scheme == "hierarchical":
            decoded = self.decode_hierarchical_rle(encoded, target_length)
        else:
            raise ValueError(f"Unknown decoding scheme: {scheme}")

        if original_shape and decoded.numel() >= math.prod(original_shape):
            decoded = decoded[:math.prod(original_shape)].reshape(original_shape)

        return decoded

    def get_compression_stats(self) -> Dict[str, float]:
        """Get aggregate compression statistics across all encode() calls."""
        if self.stats["total_original_size"] == 0:
            return {"average_compression_ratio": 1.0, "total_savings": 0.0}

        avg_ratio = self.stats["total_compressed_size"] / self.stats["total_original_size"]
        total_savings = self.stats["total_original_size"] - self.stats["total_compressed_size"]

        return {
            "average_compression_ratio": avg_ratio,
            "total_savings": total_savings,
            "total_compressions": self.stats["total_compressions"],
            "scheme_usage": dict(self.stats["scheme_usage"]),
        }
|
|
|
|
class CompressedBitDataset(torch.utils.data.Dataset):
    """
    Dataset wrapper that applies RLE compression on-the-fly during training.

    This allows for memory-efficient storage of large bit sequences while
    maintaining fast access during training.
    """

    def __init__(
        self,
        data: torch.Tensor,
        encoder: RLEEncoder,
        compress_probability: float = 0.5,
        cache_size: int = 1000,
    ):
        """
        Args:
            data: Original bit sequence data
            encoder: RLE encoder instance
            compress_probability: Probability of returning compressed data
            cache_size: Number of compressed items to cache
        """
        self.data = data
        self.encoder = encoder
        self.compress_probability = compress_probability
        self.cache_size = cache_size
        # idx -> (compressed_tensor, metadata); least-hit entries evicted first.
        self.cache = {}
        self.access_count = defaultdict(int)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """
        Get item with optional compression.

        Returns:
            Tuple of (data, metadata); metadata["compressed"] tells whether
            the returned tensor is an RLE stream or the raw item.
        """
        original_item = self.data[idx]

        if torch.rand(1).item() >= self.compress_probability:
            # Uncompressed path: hand back the raw item.
            return original_item, {
                "scheme": "uncompressed",
                "original_shape": original_item.shape,
                "compressed": False,
                "from_cache": False,
            }

        if idx in self.cache:
            compressed, cached_meta = self.cache[idx]
            self.access_count[idx] += 1
            # Copy so callers cannot mutate the cached metadata in place.
            metadata = dict(cached_meta)
            metadata["from_cache"] = True
            return compressed, metadata

        compressed, metadata = self.encoder.encode(original_item)
        # BUG FIX: mark compressed items so compression-aware losses
        # (which check metadata["compressed"]) actually see them.
        metadata["compressed"] = True

        if len(self.cache) >= self.cache_size and self.cache:
            # Evict the least-hit entry (entries never hit since insertion
            # count as 0); previously nothing was evicted until the first
            # cache hit populated access_count.
            victim = min(self.cache, key=lambda k: self.access_count[k])
            del self.cache[victim]
            self.access_count.pop(victim, None)
        if len(self.cache) < self.cache_size:
            self.cache[idx] = (compressed, dict(metadata))

        metadata["from_cache"] = False
        return compressed, metadata
|
|
|
|
def create_compression_aware_loss(
    base_loss_fn,
    compression_penalty: float = 0.01,
    quality_threshold: float = 0.8,
) -> callable:
    """
    Build a loss function that adds a penalty for poor compression quality.

    Args:
        base_loss_fn: Base loss function (e.g., CrossEntropyLoss)
        compression_penalty: Penalty weight for compression artifacts
        quality_threshold: Minimum compression quality threshold

    Returns:
        Compression-aware loss function
    """
    def compression_aware_loss(
        logits: torch.Tensor,
        targets: torch.Tensor,
        metadata_batch: Optional[List[Dict[str, Any]]] = None,
    ) -> torch.Tensor:
        """
        Compute the base loss plus a mean penalty over compressed items
        whose compression ratio exceeds the quality threshold.

        Args:
            logits: Model output logits
            targets: Target labels
            metadata_batch: Batch of compression metadata

        Returns:
            Adjusted loss tensor
        """
        loss = base_loss_fn(logits, targets)

        if metadata_batch is None:
            return loss

        # Only items flagged as compressed participate in the penalty.
        compressed = [m for m in metadata_batch if m.get("compressed", False)]
        if not compressed:
            return loss

        # Squared excess over the threshold, averaged across all
        # compressed items (items at or below threshold contribute zero).
        total = sum(
            max(m.get("compression_ratio", 1.0) - quality_threshold, 0.0) ** 2
            for m in compressed
        )
        return loss + compression_penalty * (total / len(compressed))

    return compression_aware_loss
|
|
|
|
def integrate_rle_with_training(
    model,
    data: torch.Tensor,
    encoder_config: Optional[Dict[str, Any]] = None,
    compression_config: Optional[Dict[str, Any]] = None,
) -> Tuple[CompressedBitDataset, callable]:
    """
    Integrate RLE compression with BitTransformerLM training.

    Args:
        model: BitTransformerLM model (currently unused; kept for API parity)
        data: Training data tensor
        encoder_config: Overrides for the RLE encoder configuration;
            missing keys fall back to the defaults below.
        compression_config: Overrides for compression-aware training;
            missing keys fall back to the defaults below.

    Returns:
        Tuple of (compressed_dataset, compression_aware_loss_fn)
    """
    # BUG FIX: merge user-supplied options over the defaults instead of
    # replacing them wholesale, so a partial config no longer raises
    # KeyError when the required keys are looked up below.
    default_encoder_config = {
        "scheme": "adaptive",
        "min_run_length": 2,
        "delta_threshold": 0.7,
    }
    encoder_config = {**default_encoder_config, **(encoder_config or {})}

    default_compression_config = {
        "compress_probability": 0.3,
        "compression_penalty": 0.01,
        "quality_threshold": 0.8,
        "cache_size": 1000,
    }
    compression_config = {**default_compression_config, **(compression_config or {})}

    encoder = RLEEncoder(**encoder_config)
    dataset = CompressedBitDataset(
        data,
        encoder,
        compress_probability=compression_config["compress_probability"],
        cache_size=compression_config["cache_size"],
    )

    base_loss = torch.nn.CrossEntropyLoss()
    loss_fn = create_compression_aware_loss(
        base_loss,
        compression_penalty=compression_config["compression_penalty"],
        quality_threshold=compression_config["quality_threshold"],
    )

    return dataset, loss_fn
|
|
|
|
def benchmark_compression_schemes(
    test_data: torch.Tensor,
    schemes: Optional[List[str]] = None,
) -> Dict[str, Dict[str, float]]:
    """
    Benchmark different compression schemes on test data.

    Args:
        test_data: Test data tensor
        schemes: List of schemes to test; None means all four built-in schemes

    Returns:
        Dictionary with benchmark results for each scheme
    """
    # BUG FIX: avoid a mutable default argument; use a None sentinel instead.
    if schemes is None:
        schemes = ["basic", "delta", "hierarchical", "adaptive"]

    results: Dict[str, Dict[str, float]] = {}

    for scheme in schemes:
        encoder = RLEEncoder(scheme=scheme)

        try:
            compressed, metadata = encoder.encode(test_data)
            reconstructed = encoder.decode(compressed, metadata)

            compression_ratio = compressed.numel() / test_data.numel()
            reconstruction_error = torch.mean(
                (test_data.float() - reconstructed.float()) ** 2
            ).item()

            results[scheme] = {
                "compression_ratio": compression_ratio,
                "reconstruction_error": reconstruction_error,
                "compressed_size": compressed.numel(),
                "original_size": test_data.numel(),
                "success": True,
            }
        except Exception as e:
            # Deliberate best-effort: record the failure for this scheme
            # instead of aborting the whole benchmark.
            results[scheme] = {
                "compression_ratio": 1.0,
                "reconstruction_error": float("inf"),
                "compressed_size": test_data.numel(),
                "original_size": test_data.numel(),
                "success": False,
                "error": str(e),
            }

    return results
|
|
|
|
| |
def create_rle_training_config(
    scheme: str = "adaptive",
    compress_probability: float = 0.3,
    compression_penalty: float = 0.01,
    **kwargs
) -> Dict[str, Any]:
    """
    Build the configuration dictionary for RLE-enhanced training.

    Args:
        scheme: RLE encoding scheme
        compress_probability: Probability of compression during training
        compression_penalty: Loss penalty for compression artifacts
        **kwargs: Optional overrides (min_run_length, delta_threshold,
            hierarchical_levels, quality_threshold, cache_size)

    Returns:
        Dictionary with RLE training configuration
    """
    # Assemble the two sub-configs separately, then combine.
    encoder_options = {
        "scheme": scheme,
        "min_run_length": kwargs.get("min_run_length", 2),
        "delta_threshold": kwargs.get("delta_threshold", 0.7),
        "hierarchical_levels": kwargs.get("hierarchical_levels", 2),
    }
    training_options = {
        "compress_probability": compress_probability,
        "compression_penalty": compression_penalty,
        "quality_threshold": kwargs.get("quality_threshold", 0.8),
        "cache_size": kwargs.get("cache_size", 1000),
    }

    return {
        "compression_type": "rle",
        "encoder_config": encoder_options,
        "training_config": training_options,
    }
|
|
|
|
if __name__ == "__main__":
    # Smoke test: round-trip each scheme on bit data, then benchmark.
    print("Testing RLE Compression Module...")

    # Random bits with a few long runs injected so RLE has something to find.
    test_data = torch.randint(0, 2, (100,))
    test_data[20:30] = 1
    test_data[50:70] = 0
    test_data[80:90] = 1

    print(f"Original data shape: {test_data.shape}")
    print(f"Original data: {test_data[:20]}...")

    schemes = ["basic", "delta", "hierarchical", "adaptive"]

    for scheme in schemes:
        print(f"\nTesting {scheme} scheme:")
        encoder = RLEEncoder(scheme=scheme)

        try:
            compressed, metadata = encoder.encode(test_data)
            print(f"  Compressed size: {compressed.numel()}")
            print(f"  Compression ratio: {metadata['compression_ratio']:.3f}")

            reconstructed = encoder.decode(compressed, metadata)
            error = torch.mean((test_data.float() - reconstructed.float()) ** 2)
            print(f"  Reconstruction error: {error.item():.6f}")

            # BUG FIX: the original status messages contained mangled
            # multi-byte characters, one of which split a string literal
            # across two physical lines (a syntax error). Replaced with
            # plain ASCII markers.
            if error.item() < 1e-6:
                print("  [OK] Perfect reconstruction")
            else:
                print("  [FAIL] Reconstruction error detected")

        except Exception as e:
            print(f"  [ERROR] {e}")

    print("\nBenchmarking compression schemes...")
    benchmark_results = benchmark_compression_schemes(test_data)

    for scheme, results in benchmark_results.items():
        if results["success"]:
            print(f"{scheme:12}: ratio={results['compression_ratio']:.3f}, "
                  f"error={results['reconstruction_error']:.6f}")
        else:
            print(f"{scheme:12}: FAILED - {results.get('error', 'Unknown error')}")

    print("\nRLE Compression Module test completed!")