"""
RLE Compression Extension for BitTransformerLM
==============================================
Advanced Run-Length Encoding compression module with multiple encoding schemes,
adaptive compression, and training integration for BitTransformerLM.
Key features:
- Multiple RLE encoding schemes (basic, delta, hierarchical)
- Adaptive compression with quality thresholds
- Training integration with compression-aware loss
- Batch processing and vectorized operations
- Compatible with BitTransformerLM's training infrastructure
"""
import torch
import torch.nn.functional as F
from typing import List, Tuple, Optional, Dict, Any
import math
from collections import defaultdict
class RLEEncoder:
"""
Advanced Run-Length Encoder with multiple encoding schemes.
Supports:
- Basic RLE: (value, count) pairs
- Delta RLE: Differences between consecutive runs
- Hierarchical RLE: Multi-level compression
- Adaptive RLE: Chooses best scheme based on data
"""
def __init__(
self,
scheme: str = "adaptive",
min_run_length: int = 2,
max_value: int = 255,
delta_threshold: float = 0.7,
hierarchical_levels: int = 2,
):
"""
Args:
scheme: Encoding scheme ('basic', 'delta', 'hierarchical', 'adaptive')
min_run_length: Minimum run length to compress
max_value: Maximum value for encoding
delta_threshold: Compression ratio threshold for delta encoding
hierarchical_levels: Number of levels for hierarchical encoding
"""
self.scheme = scheme
self.min_run_length = min_run_length
self.max_value = max_value
self.delta_threshold = delta_threshold
self.hierarchical_levels = hierarchical_levels
self.stats = {
"total_compressions": 0,
"total_original_size": 0,
"total_compressed_size": 0,
"scheme_usage": defaultdict(int),
}
def encode_basic_rle(self, data: torch.Tensor) -> torch.Tensor:
"""Basic run-length encoding: (value, count) pairs."""
if data.numel() == 0:
return torch.tensor([], dtype=torch.uint8)
data_flat = data.flatten()
encoded = []
current_val = data_flat[0].item()
current_count = 1
for i in range(1, len(data_flat)):
val = data_flat[i].item()
if val == current_val and current_count < 255:
current_count += 1
else:
if current_count >= self.min_run_length:
encoded.extend([current_val, current_count])
else:
# Store individual values for short runs
for _ in range(current_count):
encoded.append(current_val)
current_val = val
current_count = 1
# Handle last run
if current_count >= self.min_run_length:
encoded.extend([current_val, current_count])
else:
for _ in range(current_count):
encoded.append(current_val)
return torch.tensor(encoded, dtype=torch.uint8)
def decode_basic_rle(self, encoded: torch.Tensor, target_length: Optional[int] = None) -> torch.Tensor:
"""Decode basic run-length encoded data."""
if encoded.numel() == 0:
return torch.tensor([], dtype=torch.long)
decoded = []
i = 0
while i < len(encoded):
if i + 1 < len(encoded):
val = encoded[i].item()
count = encoded[i + 1].item()
                # Heuristic pairing: for bit-valued data with min_run_length >= 2, only
                # run counts exceed 1, so a following token > 1 marks a (value, count) pair
                if count > 1 and count <= 255:
decoded.extend([val] * count)
i += 2
else:
# Individual value
decoded.append(val)
i += 1
else:
decoded.append(encoded[i].item())
i += 1
result = torch.tensor(decoded, dtype=torch.long)
# Trim or pad to target length if specified
if target_length is not None:
if len(result) > target_length:
result = result[:target_length]
elif len(result) < target_length:
result = F.pad(result, (0, target_length - len(result)))
return result
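    # Worked example of the basic format (illustrative, assuming bit-valued input and
    # min_run_length=2): [1, 1, 1, 0, 1, 1] encodes to [1, 3, 0, 1, 2], i.e. the runs
    # 1x3 and 1x2 become (value, count) pairs while the isolated 0 is stored as a single
    # value; decoding is unambiguous here because only counts exceed 1 for bit data.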
def encode_delta_rle(self, data: torch.Tensor) -> torch.Tensor:
"""Delta run-length encoding: encode differences between values."""
if data.numel() <= 1:
return self.encode_basic_rle(data)
data_flat = data.flatten()
# Compute deltas
deltas = torch.diff(data_flat, prepend=data_flat[0:1])
        # Apply basic RLE to deltas, shifted into the 0-255 range to handle negatives
        # (lossless for bit data, where deltas are in {-1, 0, 1}; larger deltas are clamped,
        # and decode_basic_rle's pair heuristic may not round-trip the shifted values exactly)
        shifted_deltas = deltas + 128
        shifted_deltas = torch.clamp(shifted_deltas, 0, 255)
delta_encoded = self.encode_basic_rle(shifted_deltas)
# Prepend original first value
result = torch.cat([data_flat[0:1].to(torch.uint8), delta_encoded])
return result
def decode_delta_rle(self, encoded: torch.Tensor, target_length: Optional[int] = None) -> torch.Tensor:
"""Decode delta run-length encoded data."""
if encoded.numel() <= 1:
return self.decode_basic_rle(encoded, target_length)
# First value is the original value
first_val = encoded[0].item()
delta_encoded = encoded[1:]
# Decode deltas
deltas = self.decode_basic_rle(delta_encoded)
# Unshift deltas
deltas = deltas.float() - 128
# Reconstruct original sequence
if deltas.numel() > 0:
deltas[0] = first_val # Replace first delta with original value
result = torch.cumsum(deltas, dim=0).long()
else:
result = torch.tensor([first_val], dtype=torch.long)
# Trim or pad to target length
if target_length is not None:
if len(result) > target_length:
result = result[:target_length]
elif len(result) < target_length:
result = F.pad(result, (0, target_length - len(result)))
return result
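    # Worked example of the delta scheme (illustrative): bits [1, 1, 1, 0] give deltas
    # [0, 0, 0, -1] (via torch.diff with the first value prepended), which shift to
    # [128, 128, 128, 127]; the original first value is stored up front so decoding can
    # cumulative-sum the unshifted deltas back to the input sequence.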
    def encode_hierarchical_rle(self, data: torch.Tensor) -> torch.Tensor:
        """Hierarchical RLE: apply basic RLE recursively for better compression."""
        current_data = data.clone()
        levels_applied = 0
        for _ in range(self.hierarchical_levels):
            encoded = self.encode_basic_rle(current_data)
            # Stop if this level does not shrink the data meaningfully
            if encoded.numel() >= current_data.numel() * 0.9:
                break
            current_data = encoded
            levels_applied += 1
        # Record how many levels actually ran so decoding can undo exactly that many
        self._last_hierarchical_levels = levels_applied
        return current_data
    def decode_hierarchical_rle(self, encoded: torch.Tensor, target_length: Optional[int] = None, levels: Optional[int] = None) -> torch.Tensor:
        """Decode hierarchical RLE data by undoing the recorded number of levels."""
        if levels is None:
            levels = getattr(self, "_last_hierarchical_levels", self.hierarchical_levels)
        current_data = encoded.clone()
        for _ in range(levels):
            try:
                current_data = self.decode_basic_rle(current_data)
            except Exception:
                # If a level cannot be decoded, keep the current state
                break
        # Final length adjustment
        if target_length is not None and current_data.numel() != target_length:
            if current_data.numel() > target_length:
                current_data = current_data[:target_length]
            else:
                current_data = F.pad(current_data, (0, target_length - current_data.numel()))
        return current_data
def encode(self, data: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, Any]]:
"""
Encode data using the configured scheme.
Args:
data: Input tensor to compress
Returns:
Tuple of (encoded_data, metadata)
"""
original_shape = data.shape
original_size = data.numel()
if self.scheme == "basic":
encoded = self.encode_basic_rle(data)
scheme_used = "basic"
elif self.scheme == "delta":
encoded = self.encode_delta_rle(data)
scheme_used = "delta"
elif self.scheme == "hierarchical":
encoded = self.encode_hierarchical_rle(data)
scheme_used = "hierarchical"
elif self.scheme == "adaptive":
# Try all schemes and pick the best one
basic_encoded = self.encode_basic_rle(data)
delta_encoded = self.encode_delta_rle(data)
hierarchical_encoded = self.encode_hierarchical_rle(data)
candidates = {
"basic": basic_encoded,
"delta": delta_encoded,
"hierarchical": hierarchical_encoded,
}
# Choose scheme with best compression ratio
best_scheme = min(candidates.keys(), key=lambda k: candidates[k].numel())
encoded = candidates[best_scheme]
scheme_used = best_scheme
else:
raise ValueError(f"Unknown encoding scheme: {self.scheme}")
# Update statistics
self.stats["total_compressions"] += 1
self.stats["total_original_size"] += original_size
self.stats["total_compressed_size"] += encoded.numel()
self.stats["scheme_usage"][scheme_used] += 1
        metadata = {
            "scheme": scheme_used,
            "original_shape": original_shape,
            "original_size": original_size,
            "compressed_size": encoded.numel(),
            "compression_ratio": encoded.numel() / original_size if original_size > 0 else 1.0,
            "compressed": True,
            # decode() needs this so hierarchical data is unwrapped the same number of times
            "hierarchical_levels": getattr(self, "_last_hierarchical_levels", self.hierarchical_levels),
        }
return encoded, metadata
def decode(self, encoded: torch.Tensor, metadata: Dict[str, Any]) -> torch.Tensor:
"""
Decode compressed data using metadata.
Args:
encoded: Compressed data
metadata: Metadata from encoding
Returns:
Decoded tensor
"""
scheme = metadata["scheme"]
original_shape = metadata["original_shape"]
target_length = math.prod(original_shape) if original_shape else None
if scheme == "basic":
decoded = self.decode_basic_rle(encoded, target_length)
elif scheme == "delta":
decoded = self.decode_delta_rle(encoded, target_length)
        elif scheme == "hierarchical":
            decoded = self.decode_hierarchical_rle(
                encoded, target_length, levels=metadata.get("hierarchical_levels")
            )
else:
raise ValueError(f"Unknown decoding scheme: {scheme}")
# Reshape to original shape
if original_shape and decoded.numel() >= math.prod(original_shape):
decoded = decoded[:math.prod(original_shape)].reshape(original_shape)
return decoded
def get_compression_stats(self) -> Dict[str, float]:
"""Get compression statistics."""
if self.stats["total_original_size"] == 0:
return {"average_compression_ratio": 1.0, "total_savings": 0.0}
avg_ratio = self.stats["total_compressed_size"] / self.stats["total_original_size"]
total_savings = self.stats["total_original_size"] - self.stats["total_compressed_size"]
return {
"average_compression_ratio": avg_ratio,
"total_savings": total_savings,
"total_compressions": self.stats["total_compressions"],
"scheme_usage": dict(self.stats["scheme_usage"]),
}
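# Minimal round-trip sketch (illustrative; mirrors the self-test in __main__ and is not
# called anywhere in this module):
def _rle_round_trip_example() -> None:
    bits = torch.randint(0, 2, (64,))
    encoder = RLEEncoder(scheme="basic")
    compressed, metadata = encoder.encode(bits)
    restored = encoder.decode(compressed, metadata)
    print(f"ratio={metadata['compression_ratio']:.3f}, match={torch.equal(bits, restored)}")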
class CompressedBitDataset(torch.utils.data.Dataset):
"""
Dataset wrapper that applies RLE compression on-the-fly during training.
This allows for memory-efficient storage of large bit sequences while
maintaining fast access during training.
"""
def __init__(
self,
data: torch.Tensor,
encoder: RLEEncoder,
compress_probability: float = 0.5,
cache_size: int = 1000,
):
"""
Args:
data: Original bit sequence data
encoder: RLE encoder instance
compress_probability: Probability of returning compressed data
cache_size: Number of compressed items to cache
"""
self.data = data
self.encoder = encoder
self.compress_probability = compress_probability
self.cache_size = cache_size
self.cache = {}
self.access_count = defaultdict(int)
def __len__(self):
return len(self.data)
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, Dict[str, Any]]:
"""
Get item with optional compression.
Returns:
Tuple of (data, metadata) where metadata indicates if compressed
"""
original_item = self.data[idx]
# Randomly decide whether to compress
if torch.rand(1).item() < self.compress_probability:
# Check cache first
if idx in self.cache:
compressed, metadata = self.cache[idx]
self.access_count[idx] += 1
metadata["from_cache"] = True
return compressed, metadata
# Compress item
compressed, metadata = self.encoder.encode(original_item)
# Add to cache if there's room
if len(self.cache) < self.cache_size:
self.cache[idx] = (compressed, metadata)
            elif self.cache:
                # Cache is full: evict the least-accessed entry to make room
least_accessed = min(self.cache.keys(), key=lambda k: self.access_count[k])
del self.cache[least_accessed]
del self.access_count[least_accessed]
self.cache[idx] = (compressed, metadata)
metadata["from_cache"] = False
return compressed, metadata
else:
# Return original data
metadata = {
"scheme": "uncompressed",
"original_shape": original_item.shape,
"compressed": False,
"from_cache": False,
}
return original_item, metadata
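# Compressed items vary in length, so PyTorch's default collation cannot stack them into a
# batch. A minimal collate sketch (illustrative; the name _rle_collate and the list-based
# batching are assumptions, not part of BitTransformerLM's API):
def _rle_collate(batch: List[Tuple[torch.Tensor, Dict[str, Any]]]) -> Tuple[List[torch.Tensor], List[Dict[str, Any]]]:
    tensors, metadata = zip(*batch)
    return list(tensors), list(metadata)
# Example usage (illustrative):
#     loader = torch.utils.data.DataLoader(dataset, batch_size=8, collate_fn=_rle_collate)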
def create_compression_aware_loss(
base_loss_fn,
compression_penalty: float = 0.01,
quality_threshold: float = 0.8,
) -> callable:
"""
Create a loss function that penalizes poor compression quality.
Args:
base_loss_fn: Base loss function (e.g., CrossEntropyLoss)
compression_penalty: Penalty weight for compression artifacts
quality_threshold: Minimum compression quality threshold
Returns:
Compression-aware loss function
"""
def compression_aware_loss(
logits: torch.Tensor,
targets: torch.Tensor,
metadata_batch: Optional[List[Dict[str, Any]]] = None,
) -> torch.Tensor:
"""
Compute loss with compression quality penalty.
Args:
logits: Model output logits
targets: Target labels
metadata_batch: Batch of compression metadata
Returns:
Adjusted loss tensor
"""
base_loss = base_loss_fn(logits, targets)
if metadata_batch is None:
return base_loss
# Compute compression quality penalty
penalty = 0.0
compressed_items = 0
for metadata in metadata_batch:
if metadata.get("compressed", False):
compressed_items += 1
compression_ratio = metadata.get("compression_ratio", 1.0)
# Penalty for poor compression
if compression_ratio > quality_threshold:
quality_penalty = (compression_ratio - quality_threshold) ** 2
penalty += quality_penalty
if compressed_items > 0:
penalty = penalty / compressed_items # Average penalty
total_loss = base_loss + compression_penalty * penalty
else:
total_loss = base_loss
return total_loss
return compression_aware_loss
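# Usage sketch for the compression-aware loss (illustrative; the logits/targets shapes are
# assumptions about how BitTransformerLM is trained, not guarantees):
#     loss_fn = create_compression_aware_loss(torch.nn.CrossEntropyLoss(), compression_penalty=0.01)
#     loss = loss_fn(logits.reshape(-1, logits.size(-1)), targets.reshape(-1), metadata_batch)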
def integrate_rle_with_training(
model,
data: torch.Tensor,
encoder_config: Optional[Dict[str, Any]] = None,
compression_config: Optional[Dict[str, Any]] = None,
) -> Tuple[CompressedBitDataset, callable]:
"""
Integrate RLE compression with BitTransformerLM training.
Args:
model: BitTransformerLM model
data: Training data tensor
encoder_config: Configuration for RLE encoder
compression_config: Configuration for compression-aware training
Returns:
Tuple of (compressed_dataset, compression_aware_loss_fn)
"""
# Default configurations
if encoder_config is None:
encoder_config = {
"scheme": "adaptive",
"min_run_length": 2,
"delta_threshold": 0.7,
}
if compression_config is None:
compression_config = {
"compress_probability": 0.3,
"compression_penalty": 0.01,
"quality_threshold": 0.8,
"cache_size": 1000,
}
# Create encoder and dataset
encoder = RLEEncoder(**encoder_config)
dataset = CompressedBitDataset(
data,
encoder,
compress_probability=compression_config["compress_probability"],
cache_size=compression_config["cache_size"],
)
# Create compression-aware loss
base_loss = torch.nn.CrossEntropyLoss()
loss_fn = create_compression_aware_loss(
base_loss,
compression_penalty=compression_config["compression_penalty"],
quality_threshold=compression_config["quality_threshold"],
)
return dataset, loss_fn
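# Integration sketch (illustrative; the training loop, model call, and target handling are
# assumptions outside this module):
#     dataset, loss_fn = integrate_rle_with_training(model, train_bits)
#     loader = torch.utils.data.DataLoader(dataset, batch_size=8, collate_fn=_rle_collate)
#     for items, metadata_batch in loader:
#         ...  # decode items as needed, run the model, then loss_fn(logits, targets, metadata_batch)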
def benchmark_compression_schemes(
test_data: torch.Tensor,
schemes: List[str] = ["basic", "delta", "hierarchical", "adaptive"],
) -> Dict[str, Dict[str, float]]:
"""
Benchmark different compression schemes on test data.
Args:
test_data: Test data tensor
schemes: List of schemes to test
Returns:
Dictionary with benchmark results for each scheme
"""
results = {}
for scheme in schemes:
encoder = RLEEncoder(scheme=scheme)
# Test compression/decompression
try:
compressed, metadata = encoder.encode(test_data)
reconstructed = encoder.decode(compressed, metadata)
# Compute metrics
compression_ratio = compressed.numel() / test_data.numel()
reconstruction_error = torch.mean((test_data.float() - reconstructed.float()) ** 2).item()
results[scheme] = {
"compression_ratio": compression_ratio,
"reconstruction_error": reconstruction_error,
"compressed_size": compressed.numel(),
"original_size": test_data.numel(),
"success": True,
}
except Exception as e:
results[scheme] = {
"compression_ratio": 1.0,
"reconstruction_error": float("inf"),
"compressed_size": test_data.numel(),
"original_size": test_data.numel(),
"success": False,
"error": str(e),
}
return results
# Example usage and utilities
def create_rle_training_config(
scheme: str = "adaptive",
compress_probability: float = 0.3,
compression_penalty: float = 0.01,
**kwargs
) -> Dict[str, Any]:
"""
Create configuration for RLE-enhanced training.
Args:
scheme: RLE encoding scheme
compress_probability: Probability of compression during training
compression_penalty: Loss penalty for compression artifacts
**kwargs: Additional configuration options
Returns:
Dictionary with RLE training configuration
"""
config = {
"compression_type": "rle",
"encoder_config": {
"scheme": scheme,
"min_run_length": kwargs.get("min_run_length", 2),
"delta_threshold": kwargs.get("delta_threshold", 0.7),
"hierarchical_levels": kwargs.get("hierarchical_levels", 2),
},
"training_config": {
"compress_probability": compress_probability,
"compression_penalty": compression_penalty,
"quality_threshold": kwargs.get("quality_threshold", 0.8),
"cache_size": kwargs.get("cache_size", 1000),
},
}
return config
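# Config usage sketch (illustrative): the nested dictionaries map directly onto the
# constructors and helpers defined above.
#     cfg = create_rle_training_config(scheme="delta", compress_probability=0.5)
#     encoder = RLEEncoder(**cfg["encoder_config"])
#     dataset = CompressedBitDataset(
#         data, encoder,
#         compress_probability=cfg["training_config"]["compress_probability"],
#         cache_size=cfg["training_config"]["cache_size"],
#     )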
if __name__ == "__main__":
# Test the RLE compression module
print("Testing RLE Compression Module...")
# Create test data
test_data = torch.randint(0, 2, (100,))
# Add some runs for better compression
test_data[20:30] = 1
test_data[50:70] = 0
test_data[80:90] = 1
print(f"Original data shape: {test_data.shape}")
print(f"Original data: {test_data[:20]}...")
# Test different encoding schemes
schemes = ["basic", "delta", "hierarchical", "adaptive"]
for scheme in schemes:
print(f"\nTesting {scheme} scheme:")
encoder = RLEEncoder(scheme=scheme)
try:
# Encode
compressed, metadata = encoder.encode(test_data)
print(f" Compressed size: {compressed.numel()}")
print(f" Compression ratio: {metadata['compression_ratio']:.3f}")
# Decode
reconstructed = encoder.decode(compressed, metadata)
# Check reconstruction quality
error = torch.mean((test_data.float() - reconstructed.float()) ** 2)
print(f" Reconstruction error: {error.item():.6f}")
if error.item() < 1e-6:
print(" βœ… Perfect reconstruction")
else:
print(" ❌ Reconstruction error detected")
except Exception as e:
print(f" ❌ Error: {e}")
# Benchmark all schemes
print("\nBenchmarking compression schemes...")
benchmark_results = benchmark_compression_schemes(test_data)
for scheme, results in benchmark_results.items():
if results["success"]:
print(f"{scheme:12}: ratio={results['compression_ratio']:.3f}, "
f"error={results['reconstruction_error']:.6f}")
else:
print(f"{scheme:12}: FAILED - {results.get('error', 'Unknown error')}")
print("\nRLE Compression Module test completed!")