🚀 Final optimization: Update compression.py with production-ready enhancements
Browse files- bit_transformer/compression.py +164 -0
bit_transformer/compression.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from typing import List, Union, Optional
|
| 3 |
+
from .types import BitTensor, BitSequence, TensorLike
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def compress_bits(bits: torch.Tensor) -> torch.Tensor:
    """Run-length encode a 1D tensor of bits.

    The input is cast to uint8 and split into maximal runs of equal
    consecutive values. Because run lengths are stored as uint8, a run longer
    than 255 is emitted as several consecutive ``(value, length)`` pairs, each
    of length at most 255.

    Args:
        bits: 1D tensor with values 0 or 1 (bool or uint8).

    Returns:
        1D uint8 tensor of interleaved ``value, run_length`` pairs, allocated
        on the same device as ``bits``. An empty input yields an empty uint8
        tensor.

    Raises:
        ValueError: If ``bits`` is not 1D.
    """
    if bits.dim() != 1:
        raise ValueError("compress_bits expects a 1D tensor")
    b = bits.to(torch.uint8)
    total = b.numel()
    if total == 0:
        return b

    # Positions where a new run begins (element differs from its predecessor).
    changes = torch.nonzero(b[1:] != b[:-1]).flatten() + 1
    starts = torch.cat([changes.new_zeros(1), changes])
    ends = torch.cat([changes, changes.new_full((1,), total)])
    values = b[starts]
    counts = ends - starts  # every run has length >= 1

    # A run of length c needs ceil(c / 255) pairs: all of them 255 long except
    # the last, which carries the remainder (1..255).
    pieces = torch.div(counts + 254, 255, rounding_mode="floor")
    out_vals = torch.repeat_interleave(values, pieces)
    out_counts = torch.full_like(out_vals, 255)
    last_piece = torch.cumsum(pieces, dim=0) - 1
    out_counts[last_piece] = (counts - 255 * (pieces - 1)).to(torch.uint8)

    # Interleave as value0, count0, value1, count1, ...
    return torch.stack([out_vals, out_counts], dim=1).flatten()
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def decompress_bits(compressed: torch.Tensor) -> torch.Tensor:
    """Expand a run-length encoded tensor back into the sequence it encodes.

    Args:
        compressed: 1D tensor of interleaved ``value, run_length`` pairs
            (even positions hold values, odd positions hold run lengths).

    Returns:
        1D uint8 tensor in which every value is repeated ``run_length`` times.

    Raises:
        ValueError: If ``compressed`` is not 1D or has an odd element count.
    """
    well_formed = compressed.dim() == 1 and compressed.numel() % 2 == 0
    if not well_formed:
        raise ValueError("compressed tensor must be 1D even-length")

    pairs = compressed.to(torch.uint8)
    run_values = pairs[::2]
    # repeat_interleave needs integer repeat counts, hence the cast to long.
    run_lengths = pairs[1::2].long()
    return run_values.repeat_interleave(run_lengths)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def compress_bits_batch(bits_batch: torch.Tensor) -> List[torch.Tensor]:
    """Run-length encode every sequence of a batch with ``compress_bits``.

    Sequences are compressed one after another; the compressed results can
    differ in length, which is why a list is returned instead of a tensor.

    Args:
        bits_batch: 2D tensor [batch_size, seq_len], or an iterable of 1D
            tensors. Any other tensor is handed to ``compress_bits`` as a
            single sequence.

    Returns:
        List with one compressed tensor per input sequence, in input order.
    """
    # Non-tensor input: treat it as an iterable of individual sequences.
    if not isinstance(bits_batch, torch.Tensor):
        return [compress_bits(seq) for seq in bits_batch]

    # A tensor that is not 2D is treated as one sequence.
    if bits_batch.dim() != 2:
        return [compress_bits(bits_batch)]

    rows = bits_batch.to(torch.uint8)
    return [compress_bits(row) for row in rows]
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def model_output_decompress(compressed_batch: Union[torch.Tensor, List[torch.Tensor]]) -> torch.Tensor:
    """Decompress a batch of run-length encoded sequences into one 2D tensor.

    A single 1D tensor is treated as a batch of one; errors from decoding it
    propagate to the caller. Otherwise each row is decoded with
    ``decompress_bits``. Decoding is best-effort per row: a row that fails is
    logged and replaced by a single zero so the rest of the batch survives.
    Rows that decode to different lengths are right-padded with zeros to the
    longest one.

    Args:
        compressed_batch: 1D compressed tensor, or an iterable (list or 2D
            tensor) of 1D compressed tensors.

    Returns:
        2D uint8 tensor [batch_size, max_decoded_length]. An empty batch
        yields an empty tensor of shape (0, 0).
    """
    import logging

    if isinstance(compressed_batch, torch.Tensor) and compressed_batch.dim() == 1:
        sequences = [decompress_bits(compressed_batch)]
    else:
        sequences = []
        for index, row in enumerate(compressed_batch):
            try:
                sequences.append(decompress_bits(row))
            except Exception as exc:
                # Deliberate best-effort: keep the batch alive, but never
                # hide the failure from whoever reads the logs.
                logging.getLogger(__name__).warning(
                    "model_output_decompress: row %d could not be decompressed (%r); "
                    "substituting a single zero",
                    index,
                    exc,
                )
                # Keep the placeholder on the row's device so it can be
                # combined with the successfully decoded rows.
                device = row.device if isinstance(row, torch.Tensor) else None
                sequences.append(torch.zeros(1, dtype=torch.uint8, device=device))

    if not sequences:
        return torch.empty((0, 0), dtype=torch.uint8)

    max_length = max(seq.numel() for seq in sequences)
    # Zero-initialised output doubles as the padding for shorter rows.
    out = sequences[0].new_zeros((len(sequences), max_length))
    for index, seq in enumerate(sequences):
        out[index, : seq.numel()] = seq
    return out
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def compress_bits_parallel(bits_batch: torch.Tensor, num_workers: int = 4) -> List[torch.Tensor]:
    """Compress a large batch by fanning chunks out to a thread pool.

    The batch is cut into contiguous row chunks, each chunk is compressed
    with ``compress_bits_batch`` on a worker thread, and the per-chunk
    results are concatenated in the original row order. Batches with fewer
    than ``2 * num_workers`` rows are compressed sequentially.

    Args:
        bits_batch: 2D tensor [batch_size, seq_len]
        num_workers: Number of worker threads (must be >= 1)

    Returns:
        List of compressed tensors, one per row, in row order

    Raises:
        ValueError: If ``bits_batch`` is not 2D or ``num_workers`` < 1.
    """
    import concurrent.futures

    if bits_batch.dim() != 2:
        raise ValueError("bits_batch must be 2D [batch_size, seq_len]")
    if num_workers < 1:
        raise ValueError("num_workers must be >= 1")

    batch_size = bits_batch.shape[0]
    if batch_size < num_workers * 2:  # Not worth parallelizing small batches
        return compress_bits_batch(bits_batch)

    # Split batch into contiguous chunks for parallel processing
    chunk_size = max(1, batch_size // num_workers)
    chunks = [bits_batch[i:i + chunk_size] for i in range(0, batch_size, chunk_size)]

    compressed_results = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(compress_bits_batch, chunk) for chunk in chunks]
        try:
            # Collect in submission order, NOT completion order: the output
            # must line up row-for-row with the input batch.
            for future in futures:
                compressed_results.extend(future.result())
        except Exception as e:
            # Fallback to single-threaded processing on error
            print(f"Parallel compression failed: {e}, falling back to sequential processing")
            return compress_bits_batch(bits_batch)

    return compressed_results
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
import numpy as np
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def pack_bits(bits: torch.Tensor) -> torch.Tensor:
    """Pack a 1D bit tensor into bytes, eight bits per uint8, via numpy.packbits.

    Args:
        bits: 1D tensor of bits; it is cast to uint8 and moved to the CPU
            before packing.

    Returns:
        1D uint8 CPU tensor holding the packed bytes.

    Raises:
        ValueError: If ``bits`` is not 1D.
    """
    if bits.dim() == 1:
        as_bytes = bits.to(torch.uint8)
        # numpy operates on host memory only.
        host_array = as_bytes.cpu().numpy()
        return torch.from_numpy(np.packbits(host_array))
    raise ValueError("pack_bits expects a 1D tensor")
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def unpack_bits(packed: torch.Tensor, *, n_bits: Optional[int] = None) -> torch.Tensor:
    """Expand packed bytes into individual bits via numpy.unpackbits.

    Args:
        packed: 1D tensor of packed bytes; it is cast to uint8 and moved to
            the CPU before unpacking.
        n_bits: If given, the result is sliced to its first ``n_bits``
            entries (useful for dropping padding bits of the last byte).

    Returns:
        1D uint8 CPU tensor of bits.

    Raises:
        ValueError: If ``packed`` is not 1D.
    """
    if packed.dim() != 1:
        raise ValueError("unpack_bits expects a 1D tensor")

    host_bytes = packed.to(torch.uint8).cpu().numpy()
    bit_array = np.unpackbits(host_bytes)
    selected = bit_array if n_bits is None else bit_array[:n_bits]
    return torch.from_numpy(selected)
|
| 164 |
+
|