WCNegentropy committed on
Commit
d1e4760
·
verified ·
1 Parent(s): 1e5dcf2

🚀 Refined BitTransformerLM: Organized codebase with best practices

Browse files
bit_transformer/BTLM_Extensions/rle_compression.py ADDED
@@ -0,0 +1,660 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ RLE Compression Extension for BitTransformerLM
3
+ ==============================================
4
+
5
+ Advanced Run-Length Encoding compression module with multiple encoding schemes,
6
+ adaptive compression, and training integration for BitTransformerLM.
7
+
8
+ Key features:
9
+ - Multiple RLE encoding schemes (basic, delta, hierarchical)
10
+ - Adaptive compression with quality thresholds
11
+ - Training integration with compression-aware loss
12
+ - Batch processing and vectorized operations
13
+ - Compatible with BitTransformerLM's training infrastructure
14
+ """
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from typing import List, Tuple, Optional, Dict, Any, Union
19
+ import warnings
20
+ import math
21
+ from collections import defaultdict
22
+ import numpy as np
23
+
24
+
25
class RLEEncoder:
    """
    Advanced Run-Length Encoder with multiple encoding schemes.

    Supports:
    - Basic RLE: (value, count) pairs
    - Delta RLE: Differences between consecutive runs
    - Hierarchical RLE: Multi-level compression
    - Adaptive RLE: Chooses best scheme based on data

    All encoders emit flat ``uint8`` tensors; decoders return ``long``
    tensors, optionally trimmed/padded to a target length.
    """

    def __init__(
        self,
        scheme: str = "adaptive",
        min_run_length: int = 2,
        max_value: int = 255,
        delta_threshold: float = 0.7,
        hierarchical_levels: int = 2,
    ):
        """
        Args:
            scheme: Encoding scheme ('basic', 'delta', 'hierarchical', 'adaptive')
            min_run_length: Minimum run length to compress
            max_value: Maximum value for encoding
            delta_threshold: Compression ratio threshold for delta encoding
            hierarchical_levels: Number of levels for hierarchical encoding
        """
        self.scheme = scheme
        self.min_run_length = min_run_length
        # NOTE(review): max_value and delta_threshold are stored but never
        # referenced by any encode/decode path in this class — confirm whether
        # they are vestigial or meant to gate encoding decisions.
        self.max_value = max_value
        self.delta_threshold = delta_threshold
        self.hierarchical_levels = hierarchical_levels

        # Running aggregates updated by encode(); reported by
        # get_compression_stats().
        self.stats = {
            "total_compressions": 0,
            "total_original_size": 0,
            "total_compressed_size": 0,
            "scheme_usage": defaultdict(int),
        }

    def encode_basic_rle(self, data: torch.Tensor) -> torch.Tensor:
        """Basic run-length encoding: (value, count) pairs.

        Runs shorter than ``min_run_length`` are emitted as bare literal
        values; runs are capped at 255 so the count fits in a uint8.

        NOTE(review): literals and (value, count) pairs share one uint8
        alphabet with no escape marker, so the stream is ambiguous on its
        own — decode_basic_rle disambiguates heuristically. Lossless for
        0/1 bit streams in practice, but not guaranteed for general values.
        """
        if data.numel() == 0:
            return torch.tensor([], dtype=torch.uint8)

        data_flat = data.flatten()
        encoded = []

        current_val = data_flat[0].item()
        current_count = 1

        for i in range(1, len(data_flat)):
            val = data_flat[i].item()
            # Extend the current run; cap at 255 to keep the count in uint8.
            if val == current_val and current_count < 255:
                current_count += 1
            else:
                if current_count >= self.min_run_length:
                    encoded.extend([current_val, current_count])
                else:
                    # Store individual values for short runs
                    for _ in range(current_count):
                        encoded.append(current_val)
                current_val = val
                current_count = 1

        # Handle last run
        if current_count >= self.min_run_length:
            encoded.extend([current_val, current_count])
        else:
            for _ in range(current_count):
                encoded.append(current_val)

        return torch.tensor(encoded, dtype=torch.uint8)

    def decode_basic_rle(self, encoded: torch.Tensor, target_length: Optional[int] = None) -> torch.Tensor:
        """Decode basic run-length encoded data.

        Args:
            encoded: uint8 stream produced by encode_basic_rle.
            target_length: If given, the result is trimmed or zero-padded to
                exactly this many elements.

        NOTE(review): a byte pair is treated as (value, count) whenever the
        second byte is in 2..255 — a bare literal followed by any value in
        that range is misread as a run. Safe-ish for 0/1 bit data (literals
        are 0 or 1), risky for arbitrary byte values.
        """
        if encoded.numel() == 0:
            return torch.tensor([], dtype=torch.long)

        decoded = []
        i = 0

        while i < len(encoded):
            if i + 1 < len(encoded):
                val = encoded[i].item()
                count = encoded[i + 1].item()

                # Check if this looks like a (value, count) pair
                if count > 1 and count <= 255:
                    decoded.extend([val] * count)
                    i += 2
                else:
                    # Individual value
                    decoded.append(val)
                    i += 1
            else:
                # Trailing lone byte: always a literal.
                decoded.append(encoded[i].item())
                i += 1

        result = torch.tensor(decoded, dtype=torch.long)

        # Trim or pad to target length if specified
        if target_length is not None:
            if len(result) > target_length:
                result = result[:target_length]
            elif len(result) < target_length:
                result = F.pad(result, (0, target_length - len(result)))

        return result

    def encode_delta_rle(self, data: torch.Tensor) -> torch.Tensor:
        """Delta run-length encoding: encode differences between values.

        Layout: [first value (uint8)] + basic-RLE of (delta + 128).

        NOTE(review): deltas are shifted by +128 and clamped to [0, 255],
        so any delta outside [-128, 127] is silently lost. For 0/1 bit data
        deltas are in {-1, 0, 1}, so no information is lost there.
        """
        if data.numel() <= 1:
            # Nothing to differentiate; fall back to basic RLE.
            return self.encode_basic_rle(data)

        data_flat = data.flatten()

        # Compute deltas (prepend makes deltas[0] == 0, same length as input).
        deltas = torch.diff(data_flat, prepend=data_flat[0:1])

        # Apply basic RLE to deltas (shifted to handle negatives)
        shifted_deltas = deltas + 128  # Shift to 0-255 range
        shifted_deltas = torch.clamp(shifted_deltas, 0, 255)

        delta_encoded = self.encode_basic_rle(shifted_deltas)

        # Prepend original first value
        result = torch.cat([data_flat[0:1].to(torch.uint8), delta_encoded])
        return result

    def decode_delta_rle(self, encoded: torch.Tensor, target_length: Optional[int] = None) -> torch.Tensor:
        """Decode delta run-length encoded data.

        Inverse of encode_delta_rle: the first byte is the original first
        value; the remainder is basic-RLE of shifted deltas, which are
        unshifted and cumulatively summed to reconstruct the sequence.
        """
        if encoded.numel() <= 1:
            return self.decode_basic_rle(encoded, target_length)

        # First value is the original value
        first_val = encoded[0].item()
        delta_encoded = encoded[1:]

        # Decode deltas
        deltas = self.decode_basic_rle(delta_encoded)

        # Unshift deltas
        deltas = deltas.float() - 128

        # Reconstruct original sequence
        if deltas.numel() > 0:
            # deltas[0] decoded as 0 (see encode); seed the cumsum with the
            # stored first value instead.
            deltas[0] = first_val  # Replace first delta with original value
            result = torch.cumsum(deltas, dim=0).long()
        else:
            result = torch.tensor([first_val], dtype=torch.long)

        # Trim or pad to target length
        if target_length is not None:
            if len(result) > target_length:
                result = result[:target_length]
            elif len(result) < target_length:
                result = F.pad(result, (0, target_length - len(result)))

        return result

    def encode_hierarchical_rle(self, data: torch.Tensor) -> torch.Tensor:
        """Hierarchical RLE: Apply RLE recursively for better compression.

        Stops early when a pass shrinks the data by less than 10%.

        NOTE(review): the number of levels actually applied is NOT recorded
        anywhere, so decode_hierarchical_rle cannot know how many inverse
        passes to run — round-trip fidelity relies on the final length
        adjustment. Verify with multi-level inputs.
        """
        current_data = data.clone()

        for level in range(self.hierarchical_levels):
            encoded = self.encode_basic_rle(current_data)

            # Check if compression is beneficial
            if encoded.numel() >= current_data.numel() * 0.9:
                # Compression not beneficial, return previous level
                break

            current_data = encoded

        return current_data

    def decode_hierarchical_rle(self, encoded: torch.Tensor, target_length: Optional[int] = None, levels: Optional[int] = None) -> torch.Tensor:
        """Decode hierarchical RLE data.

        Args:
            encoded: Output of encode_hierarchical_rle.
            target_length: Final flat length to trim/pad to.
            levels: Number of inverse passes; defaults to the encoder's
                configured ``hierarchical_levels``.
        """
        if levels is None:
            levels = self.hierarchical_levels

        current_data = encoded.clone()

        for level in range(levels):
            try:
                current_data = self.decode_basic_rle(current_data)
            except Exception:
                # If decoding fails, return current state
                # NOTE(review): decode_basic_rle has no raising path on
                # well-formed tensors, so this guard is effectively dead —
                # extra passes over already-decoded data go undetected.
                break

        # Final length adjustment
        if target_length is not None and current_data.numel() != target_length:
            if current_data.numel() > target_length:
                current_data = current_data[:target_length]
            else:
                current_data = F.pad(current_data, (0, target_length - current_data.numel()))

        return current_data

    def encode(self, data: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """
        Encode data using the configured scheme.

        Args:
            data: Input tensor to compress

        Returns:
            Tuple of (encoded_data, metadata). Metadata records the scheme
            actually used, original shape/size, compressed size, and the
            compression ratio (compressed/original; lower is better).

        NOTE(review): metadata does not include a "compressed" flag —
        consumers that check metadata["compressed"] must add it themselves.
        """
        original_shape = data.shape
        original_size = data.numel()

        if self.scheme == "basic":
            encoded = self.encode_basic_rle(data)
            scheme_used = "basic"
        elif self.scheme == "delta":
            encoded = self.encode_delta_rle(data)
            scheme_used = "delta"
        elif self.scheme == "hierarchical":
            encoded = self.encode_hierarchical_rle(data)
            scheme_used = "hierarchical"
        elif self.scheme == "adaptive":
            # Try all schemes and pick the best one
            basic_encoded = self.encode_basic_rle(data)
            delta_encoded = self.encode_delta_rle(data)
            hierarchical_encoded = self.encode_hierarchical_rle(data)

            candidates = {
                "basic": basic_encoded,
                "delta": delta_encoded,
                "hierarchical": hierarchical_encoded,
            }

            # Choose scheme with best compression ratio
            best_scheme = min(candidates.keys(), key=lambda k: candidates[k].numel())
            encoded = candidates[best_scheme]
            scheme_used = best_scheme
        else:
            raise ValueError(f"Unknown encoding scheme: {self.scheme}")

        # Update statistics
        self.stats["total_compressions"] += 1
        self.stats["total_original_size"] += original_size
        self.stats["total_compressed_size"] += encoded.numel()
        self.stats["scheme_usage"][scheme_used] += 1

        metadata = {
            "scheme": scheme_used,
            "original_shape": original_shape,
            "original_size": original_size,
            "compressed_size": encoded.numel(),
            "compression_ratio": encoded.numel() / original_size if original_size > 0 else 1.0,
        }

        return encoded, metadata

    def decode(self, encoded: torch.Tensor, metadata: Dict[str, Any]) -> torch.Tensor:
        """
        Decode compressed data using metadata.

        Args:
            encoded: Compressed data
            metadata: Metadata from encoding (must contain "scheme" and
                "original_shape")

        Returns:
            Decoded tensor, reshaped to the original shape when enough
            elements were recovered.

        Raises:
            ValueError: If metadata names an unknown scheme (including the
                dataset-side "uncompressed" marker).
        """
        scheme = metadata["scheme"]
        original_shape = metadata["original_shape"]
        target_length = math.prod(original_shape) if original_shape else None

        if scheme == "basic":
            decoded = self.decode_basic_rle(encoded, target_length)
        elif scheme == "delta":
            decoded = self.decode_delta_rle(encoded, target_length)
        elif scheme == "hierarchical":
            decoded = self.decode_hierarchical_rle(encoded, target_length)
        else:
            raise ValueError(f"Unknown decoding scheme: {scheme}")

        # Reshape to original shape
        if original_shape and decoded.numel() >= math.prod(original_shape):
            decoded = decoded[:math.prod(original_shape)].reshape(original_shape)

        return decoded

    def get_compression_stats(self) -> Dict[str, float]:
        """Get compression statistics accumulated across all encode() calls.

        Returns:
            Dict with the size-weighted average compression ratio, total
            element savings, call count, and per-scheme usage counts.
        """
        if self.stats["total_original_size"] == 0:
            # Nothing encoded yet.
            return {"average_compression_ratio": 1.0, "total_savings": 0.0}

        avg_ratio = self.stats["total_compressed_size"] / self.stats["total_original_size"]
        total_savings = self.stats["total_original_size"] - self.stats["total_compressed_size"]

        return {
            "average_compression_ratio": avg_ratio,
            "total_savings": total_savings,
            "total_compressions": self.stats["total_compressions"],
            "scheme_usage": dict(self.stats["scheme_usage"]),
        }
326
+
327
+
328
class CompressedBitDataset(torch.utils.data.Dataset):
    """
    Dataset wrapper that applies RLE compression on-the-fly during training.

    This allows for memory-efficient storage of large bit sequences while
    maintaining fast access during training. Each access randomly returns
    either the raw item or its RLE-compressed form plus metadata.
    """

    def __init__(
        self,
        data: torch.Tensor,
        encoder: RLEEncoder,
        compress_probability: float = 0.5,
        cache_size: int = 1000,
    ):
        """
        Args:
            data: Original bit sequence data
            encoder: RLE encoder instance
            compress_probability: Probability of returning compressed data
            cache_size: Number of compressed items to cache
        """
        self.data = data
        self.encoder = encoder
        self.compress_probability = compress_probability
        self.cache_size = cache_size
        self.cache = {}  # idx -> (compressed_tensor, metadata)
        self.access_count = defaultdict(int)  # cache-hit counts, drives eviction

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """
        Get item with optional compression.

        Returns:
            Tuple of (data, metadata) where metadata indicates if compressed
        """
        original_item = self.data[idx]

        # Randomly decide whether to compress this access.
        if torch.rand(1).item() >= self.compress_probability:
            # Return original data
            return original_item, {
                "scheme": "uncompressed",
                "original_shape": original_item.shape,
                "compressed": False,
                "from_cache": False,
            }

        # Check cache first
        if idx in self.cache:
            compressed, metadata = self.cache[idx]
            self.access_count[idx] += 1
            # Fix: return a copy so callers cannot mutate the cached dict
            # (previously the cached metadata was returned aliased and
            # mutated in place on every hit).
            meta = dict(metadata)
            meta["from_cache"] = True
            return compressed, meta

        # Compress item
        compressed, metadata = self.encoder.encode(original_item)
        # Fix: RLEEncoder.encode() does not set this flag, but downstream
        # compression-aware losses check metadata["compressed"] — without it
        # compressed items were invisible to the penalty term.
        metadata["compressed"] = True

        # Add to cache if there's room
        if len(self.cache) < self.cache_size:
            self.cache[idx] = (compressed, metadata)
        elif self.cache:
            # Fix: evict whenever the cache is full. Previously eviction was
            # gated on access_count being non-empty, so a full cache never
            # turned over until at least one cache hit had occurred.
            least_accessed = min(self.cache.keys(), key=lambda k: self.access_count[k])
            del self.cache[least_accessed]
            self.access_count.pop(least_accessed, None)
            self.cache[idx] = (compressed, metadata)

        meta = dict(metadata)
        meta["from_cache"] = False
        return compressed, meta
402
+
403
+
404
def create_compression_aware_loss(
    base_loss_fn,
    compression_penalty: float = 0.01,
    quality_threshold: float = 0.8,
) -> callable:
    """
    Create a loss function that penalizes poor compression quality.

    Args:
        base_loss_fn: Base loss function (e.g., CrossEntropyLoss)
        compression_penalty: Penalty weight for compression artifacts
        quality_threshold: Minimum compression quality threshold

    Returns:
        Compression-aware loss function
    """

    def compression_aware_loss(
        logits: torch.Tensor,
        targets: torch.Tensor,
        metadata_batch: Optional[List[Dict[str, Any]]] = None,
    ) -> torch.Tensor:
        """
        Compute loss with compression quality penalty.

        Args:
            logits: Model output logits
            targets: Target labels
            metadata_batch: Batch of compression metadata

        Returns:
            Adjusted loss tensor
        """
        base_loss = base_loss_fn(logits, targets)

        if metadata_batch is None:
            return base_loss

        # Compression ratios of every item flagged as compressed; items at
        # or below the threshold still count toward the average (penalty 0).
        ratios = [
            meta.get("compression_ratio", 1.0)
            for meta in metadata_batch
            if meta.get("compressed", False)
        ]
        if not ratios:
            return base_loss

        # Quadratic penalty for each ratio exceeding the quality threshold,
        # averaged over all compressed items.
        excess = sum(
            (ratio - quality_threshold) ** 2
            for ratio in ratios
            if ratio > quality_threshold
        )
        return base_loss + compression_penalty * (excess / len(ratios))

    return compression_aware_loss
464
+
465
+
466
def integrate_rle_with_training(
    model,
    data: torch.Tensor,
    encoder_config: Optional[Dict[str, Any]] = None,
    compression_config: Optional[Dict[str, Any]] = None,
) -> Tuple[CompressedBitDataset, callable]:
    """
    Integrate RLE compression with BitTransformerLM training.

    Args:
        model: BitTransformerLM model (not referenced here; kept for API
            symmetry with other training integrations)
        data: Training data tensor
        encoder_config: Configuration for RLE encoder
        compression_config: Configuration for compression-aware training

    Returns:
        Tuple of (compressed_dataset, compression_aware_loss_fn)
    """
    # Fall back to sensible defaults when explicit configs are omitted.
    if encoder_config is None:
        encoder_config = {
            "scheme": "adaptive",
            "min_run_length": 2,
            "delta_threshold": 0.7,
        }

    if compression_config is None:
        compression_config = {
            "compress_probability": 0.3,
            "compression_penalty": 0.01,
            "quality_threshold": 0.8,
            "cache_size": 1000,
        }

    # Wire encoder -> dataset -> compression-aware loss.
    rle_encoder = RLEEncoder(**encoder_config)
    compressed_dataset = CompressedBitDataset(
        data,
        rle_encoder,
        compress_probability=compression_config["compress_probability"],
        cache_size=compression_config["cache_size"],
    )

    loss_fn = create_compression_aware_loss(
        torch.nn.CrossEntropyLoss(),
        compression_penalty=compression_config["compression_penalty"],
        quality_threshold=compression_config["quality_threshold"],
    )

    return compressed_dataset, loss_fn
518
+
519
+
520
def benchmark_compression_schemes(
    test_data: torch.Tensor,
    schemes: Optional[List[str]] = None,
) -> Dict[str, Dict[str, float]]:
    """
    Benchmark different compression schemes on test data.

    Args:
        test_data: Test data tensor
        schemes: List of schemes to test; defaults to all built-in schemes
            ("basic", "delta", "hierarchical", "adaptive")

    Returns:
        Dictionary with benchmark results for each scheme; failed schemes
        get success=False plus the error string instead of aborting the run.
    """
    # Fix: use None as the sentinel instead of a mutable list default
    # (mutable defaults are shared across calls).
    if schemes is None:
        schemes = ["basic", "delta", "hierarchical", "adaptive"]

    results = {}

    for scheme in schemes:
        encoder = RLEEncoder(scheme=scheme)

        # Test compression/decompression
        try:
            # Full round-trip: compress then reconstruct.
            compressed, metadata = encoder.encode(test_data)
            reconstructed = encoder.decode(compressed, metadata)

            # Compute metrics: size ratio (lower is better) and MSE between
            # the original and reconstructed tensors.
            compression_ratio = compressed.numel() / test_data.numel()
            reconstruction_error = torch.mean(
                (test_data.float() - reconstructed.float()) ** 2
            ).item()

            results[scheme] = {
                "compression_ratio": compression_ratio,
                "reconstruction_error": reconstruction_error,
                "compressed_size": compressed.numel(),
                "original_size": test_data.numel(),
                "success": True,
            }
        except Exception as e:
            # Deliberately broad: record the failure for this scheme and
            # keep benchmarking the rest.
            results[scheme] = {
                "compression_ratio": 1.0,
                "reconstruction_error": float("inf"),
                "compressed_size": test_data.numel(),
                "original_size": test_data.numel(),
                "success": False,
                "error": str(e),
            }

    return results
566
+
567
+
568
+ # Example usage and utilities
569
def create_rle_training_config(
    scheme: str = "adaptive",
    compress_probability: float = 0.3,
    compression_penalty: float = 0.01,
    **kwargs
) -> Dict[str, Any]:
    """
    Create configuration for RLE-enhanced training.

    Args:
        scheme: RLE encoding scheme
        compress_probability: Probability of compression during training
        compression_penalty: Loss penalty for compression artifacts
        **kwargs: Additional configuration options (min_run_length,
            delta_threshold, hierarchical_levels, quality_threshold,
            cache_size)

    Returns:
        Dictionary with RLE training configuration
    """
    # Encoder-side settings, with per-key fallbacks from kwargs.
    encoder_config = {
        "scheme": scheme,
        "min_run_length": kwargs.get("min_run_length", 2),
        "delta_threshold": kwargs.get("delta_threshold", 0.7),
        "hierarchical_levels": kwargs.get("hierarchical_levels", 2),
    }

    # Training-loop settings for the compression-aware pipeline.
    training_config = {
        "compress_probability": compress_probability,
        "compression_penalty": compression_penalty,
        "quality_threshold": kwargs.get("quality_threshold", 0.8),
        "cache_size": kwargs.get("cache_size", 1000),
    }

    return {
        "compression_type": "rle",
        "encoder_config": encoder_config,
        "training_config": training_config,
    }
604
+
605
+
606
+ if __name__ == "__main__":
607
+ # Test the RLE compression module
608
+ print("Testing RLE Compression Module...")
609
+
610
+ # Create test data
611
+ test_data = torch.randint(0, 2, (100,))
612
+
613
+ # Add some runs for better compression
614
+ test_data[20:30] = 1
615
+ test_data[50:70] = 0
616
+ test_data[80:90] = 1
617
+
618
+ print(f"Original data shape: {test_data.shape}")
619
+ print(f"Original data: {test_data[:20]}...")
620
+
621
+ # Test different encoding schemes
622
+ schemes = ["basic", "delta", "hierarchical", "adaptive"]
623
+
624
+ for scheme in schemes:
625
+ print(f"\nTesting {scheme} scheme:")
626
+ encoder = RLEEncoder(scheme=scheme)
627
+
628
+ try:
629
+ # Encode
630
+ compressed, metadata = encoder.encode(test_data)
631
+ print(f" Compressed size: {compressed.numel()}")
632
+ print(f" Compression ratio: {metadata['compression_ratio']:.3f}")
633
+
634
+ # Decode
635
+ reconstructed = encoder.decode(compressed, metadata)
636
+
637
+ # Check reconstruction quality
638
+ error = torch.mean((test_data.float() - reconstructed.float()) ** 2)
639
+ print(f" Reconstruction error: {error.item():.6f}")
640
+
641
+ if error.item() < 1e-6:
642
+ print(" ✅ Perfect reconstruction")
643
+ else:
644
+ print(" ❌ Reconstruction error detected")
645
+
646
+ except Exception as e:
647
+ print(f" ❌ Error: {e}")
648
+
649
+ # Benchmark all schemes
650
+ print("\nBenchmarking compression schemes...")
651
+ benchmark_results = benchmark_compression_schemes(test_data)
652
+
653
+ for scheme, results in benchmark_results.items():
654
+ if results["success"]:
655
+ print(f"{scheme:12}: ratio={results['compression_ratio']:.3f}, "
656
+ f"error={results['reconstruction_error']:.6f}")
657
+ else:
658
+ print(f"{scheme:12}: FAILED - {results.get('error', 'Unknown error')}")
659
+
660
+ print("\nRLE Compression Module test completed!")