| | import torch |
| | from typing import List, Union, Optional |
| | from .types import BitTensor, BitSequence, TensorLike |
| |
|
| |
|
def compress_bits(bits: torch.Tensor) -> torch.Tensor:
    """Run-length encode a 1D tensor of bits.

    Consecutive equal values are collapsed into ``(value, run_length)`` pairs.
    Run lengths are stored as ``uint8``, so a run longer than 255 is emitted
    as several consecutive pairs that carry the same value.

    Args:
        bits: 1D tensor with values 0 or 1 (bool or uint8).

    Returns:
        1D uint8 tensor laid out as ``[value, count, value, count, ...]``,
        on the same device as ``bits``. Empty input yields an empty uint8
        tensor.

    Raises:
        ValueError: If ``bits`` is not 1D.
    """
    if bits.dim() != 1:
        raise ValueError("compress_bits expects a 1D tensor")
    b = bits.to(torch.uint8)
    total = b.numel()
    if total == 0:
        return b

    # A run starts at index 0 and at every index whose value differs from
    # its predecessor; each run ends where the next one starts.
    changes = torch.nonzero(b[1:] != b[:-1]).flatten() + 1
    starts = torch.cat([changes.new_zeros(1), changes])
    ends = torch.cat([changes, changes.new_full((1,), total)])
    values = b[starts]
    counts = ends - starts

    # Split every run into ceil(count / 255) pieces: all pieces are full
    # (255) except the last one, which carries the remainder.
    max_run = 255
    pieces = (counts + (max_run - 1)) // max_run
    out_vals = torch.repeat_interleave(values, pieces)
    out_counts = torch.full_like(out_vals, max_run, dtype=torch.long)
    last_piece = torch.cumsum(pieces, dim=0) - 1
    out_counts[last_piece] = counts - (pieces - 1) * max_run

    # Stacking as columns and flattening row-major interleaves the pairs.
    return torch.stack([out_vals, out_counts.to(torch.uint8)], dim=1).flatten()
| |
|
| |
|
def decompress_bits(compressed: torch.Tensor) -> torch.Tensor:
    """Expand interleaved ``(value, count)`` pairs back into a flat tensor.

    Args:
        compressed: 1D tensor of even length laid out as
            ``[value, count, value, count, ...]``.

    Returns:
        1D uint8 tensor in which each value is repeated ``count`` times.

    Raises:
        ValueError: If ``compressed`` is not 1D or has an odd length.
    """
    is_flat = compressed.dim() == 1
    if not (is_flat and compressed.numel() % 2 == 0):
        raise ValueError("compressed tensor must be 1D even-length")
    pairs = compressed.to(torch.uint8)
    # Even positions hold the values, odd positions hold the run lengths.
    run_values = pairs[::2]
    run_lengths = pairs[1::2].long()
    return run_values.repeat_interleave(run_lengths)
| |
|
| |
|
def compress_bits_batch(
    bits_batch: Union[torch.Tensor, List[torch.Tensor]],
) -> List[torch.Tensor]:
    """Run-length encode every sequence in a batch with ``compress_bits``.

    Args:
        bits_batch: Either a 2D tensor ``[batch_size, seq_len]`` whose rows
            are the sequences, an iterable (e.g. list) of 1D tensors, or a
            single tensor that is passed to ``compress_bits`` as-is.

    Returns:
        List with one compressed tensor per input sequence, in input order.
        A tensor that is not 2D produces a single-element list.
    """
    if not isinstance(bits_batch, torch.Tensor):
        return [compress_bits(seq) for seq in bits_batch]
    if bits_batch.dim() != 2:
        # Not a batch: hand the tensor over unchanged so compress_bits
        # applies its own shape validation.
        return [compress_bits(bits_batch)]
    # Convert once up front instead of once per row.
    rows = bits_batch.to(torch.uint8)
    return [compress_bits(row) for row in rows]
| |
|
| |
|
def model_output_decompress(compressed_batch: Union[torch.Tensor, List[torch.Tensor]]) -> torch.Tensor:
    """Decompress a batch of run-length encoded sequences into one 2D tensor.

    Args:
        compressed_batch: A single 1D compressed tensor, or an iterable of
            compressed tensors (a list, or a 2D tensor iterated row by row).

    Returns:
        2D tensor with one decoded sequence per row. Rows shorter than the
        longest sequence are right-padded with zeros. When given an iterable,
        a sequence that fails to decode is replaced by a single zero and a
        ``RuntimeWarning`` is issued. An empty batch yields a uint8 tensor of
        shape ``(0, 0)``.

    Raises:
        ValueError: If a single 1D ``compressed_batch`` is rejected by
            ``decompress_bits`` (errors are not suppressed on that path).
    """
    import warnings

    if isinstance(compressed_batch, torch.Tensor) and compressed_batch.dim() == 1:
        return torch.stack([decompress_bits(compressed_batch)])

    sequences = []
    for index, row in enumerate(compressed_batch):
        try:
            sequences.append(decompress_bits(row))
        except Exception as exc:
            # Best-effort by design: keep the output aligned with the input
            # batch by substituting a placeholder, but report the failure
            # instead of hiding it.
            warnings.warn(
                f"Failed to decompress sequence {index}: {exc}; substituting zeros",
                RuntimeWarning,
                stacklevel=2,
            )
            # Match the row's device so the final stack does not mix devices.
            device = row.device if isinstance(row, torch.Tensor) else None
            sequences.append(torch.zeros(1, dtype=torch.uint8, device=device))

    if not sequences:
        # max() and torch.stack both reject empty input; return an empty batch.
        device = compressed_batch.device if isinstance(compressed_batch, torch.Tensor) else None
        return torch.empty((0, 0), dtype=torch.uint8, device=device)

    max_length = max(seq.numel() for seq in sequences)
    padded_sequences = [
        seq
        if seq.numel() == max_length
        else torch.cat([seq, seq.new_zeros(max_length - seq.numel())])
        for seq in sequences
    ]
    return torch.stack(padded_sequences)
| |
|
| |
|
def compress_bits_parallel(bits_batch: torch.Tensor, num_workers: int = 4) -> List[torch.Tensor]:
    """Compress a large 2D batch by fanning row chunks out to a thread pool.

    Batches with fewer than ``2 * num_workers`` rows are compressed
    sequentially with ``compress_bits_batch``. If any worker fails, the
    whole batch is recompressed sequentially.

    Args:
        bits_batch: 2D tensor [batch_size, seq_len]
        num_workers: Number of worker threads; must be at least 1.

    Returns:
        List of compressed tensors, one per row, in the same order as the
        rows of ``bits_batch``.

    Raises:
        ValueError: If ``bits_batch`` is not 2D or ``num_workers`` < 1.
    """
    import concurrent.futures

    if bits_batch.dim() != 2:
        raise ValueError("bits_batch must be 2D [batch_size, seq_len]")
    if num_workers < 1:
        raise ValueError("num_workers must be at least 1")

    batch_size = bits_batch.shape[0]
    if batch_size < num_workers * 2:
        return compress_bits_batch(bits_batch)

    # Ceiling division so the batch splits into at most num_workers chunks.
    chunk_size = -(-batch_size // num_workers)
    chunks = [bits_batch[i:i + chunk_size] for i in range(0, batch_size, chunk_size)]

    compressed_results: List[torch.Tensor] = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(compress_bits_batch, chunk) for chunk in chunks]
        try:
            # Collect in submission order (not completion order) so the
            # output rows line up with the input rows.
            for future in futures:
                compressed_results.extend(future.result())
        except Exception as e:
            print(f"Parallel compression failed: {e}, falling back to sequential processing")
            return compress_bits_batch(bits_batch)

    return compressed_results
| |
|
| |
|
| | import numpy as np |
| |
|
| |
|
def pack_bits(bits: torch.Tensor) -> torch.Tensor:
    """Pack a 1D bit tensor into bytes, eight bits per ``uint8`` element.

    The input is converted to ``uint8``, moved to the CPU and handed to
    ``numpy.packbits``; bit order and padding follow that function's
    defaults.

    Args:
        bits: 1D tensor of bit values.

    Returns:
        1D uint8 CPU tensor holding the packed bytes.

    Raises:
        ValueError: If ``bits`` is not 1D.
    """
    if bits.dim() != 1:
        raise ValueError("pack_bits expects a 1D tensor")
    as_bytes = bits.to(torch.uint8)
    host_array = as_bytes.cpu().numpy()
    return torch.from_numpy(np.packbits(host_array))
| |
|
| |
|
def unpack_bits(packed: torch.Tensor, *, n_bits: Optional[int] = None) -> torch.Tensor:
    """Unpack uint8 values back into a bit tensor.

    The input is converted to ``uint8``, moved to the CPU and expanded with
    ``numpy.unpackbits`` (eight bits per input byte).

    Args:
        packed: 1D tensor of packed bytes.
        n_bits: Optional number of leading bits to keep, typically used to
            drop padding bits. If it exceeds the number of unpacked bits,
            all unpacked bits are returned.

    Returns:
        1D uint8 CPU tensor of bits, truncated to ``n_bits`` when given.

    Raises:
        ValueError: If ``packed`` is not 1D or ``n_bits`` is negative.
    """
    if packed.dim() != 1:
        raise ValueError("unpack_bits expects a 1D tensor")
    if n_bits is not None and n_bits < 0:
        # A negative value would otherwise be treated as a Python slice
        # offset and silently drop bits from the end.
        raise ValueError("n_bits must be non-negative")
    arr = np.unpackbits(packed.to(torch.uint8).cpu().numpy())
    if n_bits is not None:
        arr = arr[:n_bits]
    return torch.from_numpy(arr)
| |
|
| |
|