""" RLE Compression Extension for BitTransformerLM ============================================== Advanced Run-Length Encoding compression module with multiple encoding schemes, adaptive compression, and training integration for BitTransformerLM. Key features: - Multiple RLE encoding schemes (basic, delta, hierarchical) - Adaptive compression with quality thresholds - Training integration with compression-aware loss - Batch processing and vectorized operations - Compatible with BitTransformerLM's training infrastructure """ import torch import torch.nn.functional as F from typing import List, Tuple, Optional, Dict, Any, Union import warnings import math from collections import defaultdict import numpy as np class RLEEncoder: """ Advanced Run-Length Encoder with multiple encoding schemes. Supports: - Basic RLE: (value, count) pairs - Delta RLE: Differences between consecutive runs - Hierarchical RLE: Multi-level compression - Adaptive RLE: Chooses best scheme based on data """ def __init__( self, scheme: str = "adaptive", min_run_length: int = 2, max_value: int = 255, delta_threshold: float = 0.7, hierarchical_levels: int = 2, ): """ Args: scheme: Encoding scheme ('basic', 'delta', 'hierarchical', 'adaptive') min_run_length: Minimum run length to compress max_value: Maximum value for encoding delta_threshold: Compression ratio threshold for delta encoding hierarchical_levels: Number of levels for hierarchical encoding """ self.scheme = scheme self.min_run_length = min_run_length self.max_value = max_value self.delta_threshold = delta_threshold self.hierarchical_levels = hierarchical_levels self.stats = { "total_compressions": 0, "total_original_size": 0, "total_compressed_size": 0, "scheme_usage": defaultdict(int), } def encode_basic_rle(self, data: torch.Tensor) -> torch.Tensor: """Basic run-length encoding: (value, count) pairs.""" if data.numel() == 0: return torch.tensor([], dtype=torch.uint8) data_flat = data.flatten() encoded = [] current_val = data_flat[0].item() current_count = 1 for i in range(1, len(data_flat)): val = data_flat[i].item() if val == current_val and current_count < 255: current_count += 1 else: if current_count >= self.min_run_length: encoded.extend([current_val, current_count]) else: # Store individual values for short runs for _ in range(current_count): encoded.append(current_val) current_val = val current_count = 1 # Handle last run if current_count >= self.min_run_length: encoded.extend([current_val, current_count]) else: for _ in range(current_count): encoded.append(current_val) return torch.tensor(encoded, dtype=torch.uint8) def decode_basic_rle(self, encoded: torch.Tensor, target_length: Optional[int] = None) -> torch.Tensor: """Decode basic run-length encoded data.""" if encoded.numel() == 0: return torch.tensor([], dtype=torch.long) decoded = [] i = 0 while i < len(encoded): if i + 1 < len(encoded): val = encoded[i].item() count = encoded[i + 1].item() # Check if this looks like a (value, count) pair if count > 1 and count <= 255: decoded.extend([val] * count) i += 2 else: # Individual value decoded.append(val) i += 1 else: decoded.append(encoded[i].item()) i += 1 result = torch.tensor(decoded, dtype=torch.long) # Trim or pad to target length if specified if target_length is not None: if len(result) > target_length: result = result[:target_length] elif len(result) < target_length: result = F.pad(result, (0, target_length - len(result))) return result def encode_delta_rle(self, data: torch.Tensor) -> torch.Tensor: """Delta run-length encoding: encode differences between values.""" if data.numel() <= 1: return self.encode_basic_rle(data) data_flat = data.flatten() # Compute deltas deltas = torch.diff(data_flat, prepend=data_flat[0:1]) # Apply basic RLE to deltas (shifted to handle negatives) shifted_deltas = deltas + 128 # Shift to 0-255 range shifted_deltas = torch.clamp(shifted_deltas, 0, 255) delta_encoded = self.encode_basic_rle(shifted_deltas) # Prepend original first value result = torch.cat([data_flat[0:1].to(torch.uint8), delta_encoded]) return result def decode_delta_rle(self, encoded: torch.Tensor, target_length: Optional[int] = None) -> torch.Tensor: """Decode delta run-length encoded data.""" if encoded.numel() <= 1: return self.decode_basic_rle(encoded, target_length) # First value is the original value first_val = encoded[0].item() delta_encoded = encoded[1:] # Decode deltas deltas = self.decode_basic_rle(delta_encoded) # Unshift deltas deltas = deltas.float() - 128 # Reconstruct original sequence if deltas.numel() > 0: deltas[0] = first_val # Replace first delta with original value result = torch.cumsum(deltas, dim=0).long() else: result = torch.tensor([first_val], dtype=torch.long) # Trim or pad to target length if target_length is not None: if len(result) > target_length: result = result[:target_length] elif len(result) < target_length: result = F.pad(result, (0, target_length - len(result))) return result def encode_hierarchical_rle(self, data: torch.Tensor) -> torch.Tensor: """Hierarchical RLE: Apply RLE recursively for better compression.""" current_data = data.clone() for level in range(self.hierarchical_levels): encoded = self.encode_basic_rle(current_data) # Check if compression is beneficial if encoded.numel() >= current_data.numel() * 0.9: # Compression not beneficial, return previous level break current_data = encoded return current_data def decode_hierarchical_rle(self, encoded: torch.Tensor, target_length: Optional[int] = None, levels: int = None) -> torch.Tensor: """Decode hierarchical RLE data.""" if levels is None: levels = self.hierarchical_levels current_data = encoded.clone() for level in range(levels): try: current_data = self.decode_basic_rle(current_data) except Exception: # If decoding fails, return current state break # Final length adjustment if target_length is not None and current_data.numel() != target_length: if current_data.numel() > target_length: current_data = current_data[:target_length] else: current_data = F.pad(current_data, (0, target_length - current_data.numel())) return current_data def encode(self, data: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, Any]]: """ Encode data using the configured scheme. Args: data: Input tensor to compress Returns: Tuple of (encoded_data, metadata) """ original_shape = data.shape original_size = data.numel() if self.scheme == "basic": encoded = self.encode_basic_rle(data) scheme_used = "basic" elif self.scheme == "delta": encoded = self.encode_delta_rle(data) scheme_used = "delta" elif self.scheme == "hierarchical": encoded = self.encode_hierarchical_rle(data) scheme_used = "hierarchical" elif self.scheme == "adaptive": # Try all schemes and pick the best one basic_encoded = self.encode_basic_rle(data) delta_encoded = self.encode_delta_rle(data) hierarchical_encoded = self.encode_hierarchical_rle(data) candidates = { "basic": basic_encoded, "delta": delta_encoded, "hierarchical": hierarchical_encoded, } # Choose scheme with best compression ratio best_scheme = min(candidates.keys(), key=lambda k: candidates[k].numel()) encoded = candidates[best_scheme] scheme_used = best_scheme else: raise ValueError(f"Unknown encoding scheme: {self.scheme}") # Update statistics self.stats["total_compressions"] += 1 self.stats["total_original_size"] += original_size self.stats["total_compressed_size"] += encoded.numel() self.stats["scheme_usage"][scheme_used] += 1 metadata = { "scheme": scheme_used, "original_shape": original_shape, "original_size": original_size, "compressed_size": encoded.numel(), "compression_ratio": encoded.numel() / original_size if original_size > 0 else 1.0, } return encoded, metadata def decode(self, encoded: torch.Tensor, metadata: Dict[str, Any]) -> torch.Tensor: """ Decode compressed data using metadata. Args: encoded: Compressed data metadata: Metadata from encoding Returns: Decoded tensor """ scheme = metadata["scheme"] original_shape = metadata["original_shape"] target_length = math.prod(original_shape) if original_shape else None if scheme == "basic": decoded = self.decode_basic_rle(encoded, target_length) elif scheme == "delta": decoded = self.decode_delta_rle(encoded, target_length) elif scheme == "hierarchical": decoded = self.decode_hierarchical_rle(encoded, target_length) else: raise ValueError(f"Unknown decoding scheme: {scheme}") # Reshape to original shape if original_shape and decoded.numel() >= math.prod(original_shape): decoded = decoded[:math.prod(original_shape)].reshape(original_shape) return decoded def get_compression_stats(self) -> Dict[str, float]: """Get compression statistics.""" if self.stats["total_original_size"] == 0: return {"average_compression_ratio": 1.0, "total_savings": 0.0} avg_ratio = self.stats["total_compressed_size"] / self.stats["total_original_size"] total_savings = self.stats["total_original_size"] - self.stats["total_compressed_size"] return { "average_compression_ratio": avg_ratio, "total_savings": total_savings, "total_compressions": self.stats["total_compressions"], "scheme_usage": dict(self.stats["scheme_usage"]), } class CompressedBitDataset(torch.utils.data.Dataset): """ Dataset wrapper that applies RLE compression on-the-fly during training. This allows for memory-efficient storage of large bit sequences while maintaining fast access during training. """ def __init__( self, data: torch.Tensor, encoder: RLEEncoder, compress_probability: float = 0.5, cache_size: int = 1000, ): """ Args: data: Original bit sequence data encoder: RLE encoder instance compress_probability: Probability of returning compressed data cache_size: Number of compressed items to cache """ self.data = data self.encoder = encoder self.compress_probability = compress_probability self.cache_size = cache_size self.cache = {} self.access_count = defaultdict(int) def __len__(self): return len(self.data) def __getitem__(self, idx: int) -> Tuple[torch.Tensor, Dict[str, Any]]: """ Get item with optional compression. Returns: Tuple of (data, metadata) where metadata indicates if compressed """ original_item = self.data[idx] # Randomly decide whether to compress if torch.rand(1).item() < self.compress_probability: # Check cache first if idx in self.cache: compressed, metadata = self.cache[idx] self.access_count[idx] += 1 metadata["from_cache"] = True return compressed, metadata # Compress item compressed, metadata = self.encoder.encode(original_item) # Add to cache if there's room if len(self.cache) < self.cache_size: self.cache[idx] = (compressed, metadata) elif self.access_count: # Replace least accessed item least_accessed = min(self.cache.keys(), key=lambda k: self.access_count[k]) del self.cache[least_accessed] del self.access_count[least_accessed] self.cache[idx] = (compressed, metadata) metadata["from_cache"] = False return compressed, metadata else: # Return original data metadata = { "scheme": "uncompressed", "original_shape": original_item.shape, "compressed": False, "from_cache": False, } return original_item, metadata def create_compression_aware_loss( base_loss_fn, compression_penalty: float = 0.01, quality_threshold: float = 0.8, ) -> callable: """ Create a loss function that penalizes poor compression quality. Args: base_loss_fn: Base loss function (e.g., CrossEntropyLoss) compression_penalty: Penalty weight for compression artifacts quality_threshold: Minimum compression quality threshold Returns: Compression-aware loss function """ def compression_aware_loss( logits: torch.Tensor, targets: torch.Tensor, metadata_batch: Optional[List[Dict[str, Any]]] = None, ) -> torch.Tensor: """ Compute loss with compression quality penalty. Args: logits: Model output logits targets: Target labels metadata_batch: Batch of compression metadata Returns: Adjusted loss tensor """ base_loss = base_loss_fn(logits, targets) if metadata_batch is None: return base_loss # Compute compression quality penalty penalty = 0.0 compressed_items = 0 for metadata in metadata_batch: if metadata.get("compressed", False): compressed_items += 1 compression_ratio = metadata.get("compression_ratio", 1.0) # Penalty for poor compression if compression_ratio > quality_threshold: quality_penalty = (compression_ratio - quality_threshold) ** 2 penalty += quality_penalty if compressed_items > 0: penalty = penalty / compressed_items # Average penalty total_loss = base_loss + compression_penalty * penalty else: total_loss = base_loss return total_loss return compression_aware_loss def integrate_rle_with_training( model, data: torch.Tensor, encoder_config: Optional[Dict[str, Any]] = None, compression_config: Optional[Dict[str, Any]] = None, ) -> Tuple[CompressedBitDataset, callable]: """ Integrate RLE compression with BitTransformerLM training. Args: model: BitTransformerLM model data: Training data tensor encoder_config: Configuration for RLE encoder compression_config: Configuration for compression-aware training Returns: Tuple of (compressed_dataset, compression_aware_loss_fn) """ # Default configurations if encoder_config is None: encoder_config = { "scheme": "adaptive", "min_run_length": 2, "delta_threshold": 0.7, } if compression_config is None: compression_config = { "compress_probability": 0.3, "compression_penalty": 0.01, "quality_threshold": 0.8, "cache_size": 1000, } # Create encoder and dataset encoder = RLEEncoder(**encoder_config) dataset = CompressedBitDataset( data, encoder, compress_probability=compression_config["compress_probability"], cache_size=compression_config["cache_size"], ) # Create compression-aware loss base_loss = torch.nn.CrossEntropyLoss() loss_fn = create_compression_aware_loss( base_loss, compression_penalty=compression_config["compression_penalty"], quality_threshold=compression_config["quality_threshold"], ) return dataset, loss_fn def benchmark_compression_schemes( test_data: torch.Tensor, schemes: List[str] = ["basic", "delta", "hierarchical", "adaptive"], ) -> Dict[str, Dict[str, float]]: """ Benchmark different compression schemes on test data. Args: test_data: Test data tensor schemes: List of schemes to test Returns: Dictionary with benchmark results for each scheme """ results = {} for scheme in schemes: encoder = RLEEncoder(scheme=scheme) # Test compression/decompression try: compressed, metadata = encoder.encode(test_data) reconstructed = encoder.decode(compressed, metadata) # Compute metrics compression_ratio = compressed.numel() / test_data.numel() reconstruction_error = torch.mean((test_data.float() - reconstructed.float()) ** 2).item() results[scheme] = { "compression_ratio": compression_ratio, "reconstruction_error": reconstruction_error, "compressed_size": compressed.numel(), "original_size": test_data.numel(), "success": True, } except Exception as e: results[scheme] = { "compression_ratio": 1.0, "reconstruction_error": float("inf"), "compressed_size": test_data.numel(), "original_size": test_data.numel(), "success": False, "error": str(e), } return results # Example usage and utilities def create_rle_training_config( scheme: str = "adaptive", compress_probability: float = 0.3, compression_penalty: float = 0.01, **kwargs ) -> Dict[str, Any]: """ Create configuration for RLE-enhanced training. Args: scheme: RLE encoding scheme compress_probability: Probability of compression during training compression_penalty: Loss penalty for compression artifacts **kwargs: Additional configuration options Returns: Dictionary with RLE training configuration """ config = { "compression_type": "rle", "encoder_config": { "scheme": scheme, "min_run_length": kwargs.get("min_run_length", 2), "delta_threshold": kwargs.get("delta_threshold", 0.7), "hierarchical_levels": kwargs.get("hierarchical_levels", 2), }, "training_config": { "compress_probability": compress_probability, "compression_penalty": compression_penalty, "quality_threshold": kwargs.get("quality_threshold", 0.8), "cache_size": kwargs.get("cache_size", 1000), }, } return config if __name__ == "__main__": # Test the RLE compression module print("Testing RLE Compression Module...") # Create test data test_data = torch.randint(0, 2, (100,)) # Add some runs for better compression test_data[20:30] = 1 test_data[50:70] = 0 test_data[80:90] = 1 print(f"Original data shape: {test_data.shape}") print(f"Original data: {test_data[:20]}...") # Test different encoding schemes schemes = ["basic", "delta", "hierarchical", "adaptive"] for scheme in schemes: print(f"\nTesting {scheme} scheme:") encoder = RLEEncoder(scheme=scheme) try: # Encode compressed, metadata = encoder.encode(test_data) print(f" Compressed size: {compressed.numel()}") print(f" Compression ratio: {metadata['compression_ratio']:.3f}") # Decode reconstructed = encoder.decode(compressed, metadata) # Check reconstruction quality error = torch.mean((test_data.float() - reconstructed.float()) ** 2) print(f" Reconstruction error: {error.item():.6f}") if error.item() < 1e-6: print(" ✅ Perfect reconstruction") else: print(" ❌ Reconstruction error detected") except Exception as e: print(f" ❌ Error: {e}") # Benchmark all schemes print("\nBenchmarking compression schemes...") benchmark_results = benchmark_compression_schemes(test_data) for scheme, results in benchmark_results.items(): if results["success"]: print(f"{scheme:12}: ratio={results['compression_ratio']:.3f}, " f"error={results['reconstruction_error']:.6f}") else: print(f"{scheme:12}: FAILED - {results.get('error', 'Unknown error')}") print("\nRLE Compression Module test completed!")