import math
from typing import Callable, List, Optional, Tuple

import torch

from m1_compression import utils


def _pdf_to_cdf(pdf: torch.Tensor) -> torch.Tensor:
    """Turn a PDF over the last dimension into a CDF with a leading zero.

    The cumulative sum is taken in float64 to limit rounding error before
    casting back to float32; after prepending a zero, symbol ``k`` owns the
    interval between ``cdf[..., k]`` and ``cdf[..., k + 1]``.
    """
    cdf = torch.cumsum(pdf.to(torch.float64), dim=-1).to(torch.float32)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], dim=-1)
    return cdf
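
# Worked example (illustrative values): a PDF of [0.5, 0.25, 0.25] becomes the
# CDF [0.0, 0.5, 0.75, 1.0], so symbol 1 owns the sub-interval [0.5, 0.75) of
# whatever range the coder currently spans.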


def _shift_left(x: torch.Tensor, base: int, base_to_pm1: int) -> torch.Tensor:
    """Shift `x` one digit left, dropping its most significant digit."""
    return (x % base_to_pm1) * base


def _shift_left_keeping_msd(x: torch.Tensor, base: int, base_to_pm1: int, base_to_pm2: int) -> torch.Tensor:
    """Shift `x` one digit left while its MSD stays in place (the second digit is dropped)."""
    return x - (x % base_to_pm1) + (x % base_to_pm2) * base
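
# Illustrative sanity check of the shift helpers in base 10 with precision 4
# (so base_to_pm1 = 1000 and base_to_pm2 = 100); the values are hypothetical:
#   _shift_left(1234, 10, 1000)                  -> 2340 (MSD dropped, 0 shifted in)
#   _shift_left_keeping_msd(1934, 10, 1000, 100) -> 1340 (MSD kept, 2nd digit dropped)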


class BatchedArithmeticEncoder:
    """Arithmetic coder that encodes/decodes a whole batch of sequences in parallel.

    The coder state is an integer register of `precision` digits in `base`;
    `_base_to_pm1` and `_base_to_pm2` isolate its first and second most
    significant digits.
    """

    def __init__(self, base: int, precision: int):
        self._base: int = base
        self._base_to_pm1: int = int(base ** (precision - 1))
        self._base_to_pm2: int = int(base ** (precision - 2))
        self._precision: int = precision
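
    # Worked example (illustrative): with base=10 and precision=4 the register
    # spans [0, 9999]; the MSD of a state x is x // self._base_to_pm1 and the
    # second digit, inspected by the underflow test, is (x // self._base_to_pm2) % 10.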

    def _get_bit_counts(self, buf_offsets: torch.Tensor, base_offsets: torch.Tensor) -> torch.Tensor:
        """Return the number of digits emitted so far for each sequence."""
        return buf_offsets - base_offsets

    def flush_matching_digits(
        self,
        low: torch.Tensor,
        high: torch.Tensor,
        old_low: torch.Tensor,
        encoding: bool = True,
        bits_buffer: Optional[torch.Tensor] = None,
        buf_offsets: Optional[torch.Tensor] = None,
        num_carry_digits: Optional[torch.Tensor] = None,
        current_code_in_int: Optional[torch.Tensor] = None,
        _next_digit: Optional[Callable[[int], int]] = None,
        valid: Optional[torch.Tensor] = None,
    ):
        """While low and high share their MSD, emit it (encoding) or consume a
        digit (decoding), then renormalize both bounds one digit to the left."""
        valid = valid if valid is not None else True
        while True:
            msd_low = low // self._base_to_pm1
            msd_high = high // self._base_to_pm1
            mask = (msd_low == msd_high) & valid
            if not torch.any(mask):
                break

            if encoding:
                msd_low_old = old_low // self._base_to_pm1

                # Emit the settled digit for every sequence whose bounds agree.
                sel = mask.nonzero(as_tuple=False).flatten()
                bits_buffer[buf_offsets[sel]] = msd_low[sel].to(torch.int32)
                buf_offsets.index_add_(
                    0,
                    sel,
                    torch.ones_like(sel, dtype=torch.int32),
                )

                # Resolve pending underflow digits: they become 0 if a carry
                # propagated into the MSD since the interval update, else base-1.
                carry_sel = sel[(num_carry_digits[sel] > 0)]
                if carry_sel.numel():
                    _digit_carry = msd_low[carry_sel]
                    _digit_carry_old = msd_low_old[carry_sel]
                    carry_digit = (
                        self._base - 1 + _digit_carry - _digit_carry_old
                    ) % self._base

                    rep_cnt = num_carry_digits[carry_sel]
                    repeats_max = rep_cnt.max().item()
                    grid = torch.arange(
                        repeats_max,
                        device=rep_cnt.device,
                    ).expand(carry_sel.size(0), repeats_max)
                    mask_rep = grid < rep_cnt.unsqueeze(1)

                    # Scatter each sequence's run of pending digits after its cursor.
                    start_pos = buf_offsets[carry_sel]
                    target_pos = (start_pos.unsqueeze(1) + grid)[mask_rep]
                    payload = carry_digit.to(torch.int32).unsqueeze(1).expand_as(grid)[mask_rep]

                    bits_buffer[target_pos] = payload
                    buf_offsets.index_add_(0, carry_sel, rep_cnt)
                    num_carry_digits[carry_sel] = 0
            else:
                # Pull the next digit of each active stream into the code register.
                new_digit = torch.tensor([
                    _next_digit(i) if m else 0
                    for i, m in enumerate(mask.tolist())
                ], dtype=torch.int64, device=mask.device)
                current_code_in_int = torch.where(
                    mask,
                    _shift_left(current_code_in_int, self._base, self._base_to_pm1) + new_digit,
                    current_code_in_int,
                )

            low = torch.where(
                mask,
                _shift_left(low, self._base, self._base_to_pm1),
                low,
            )
            high = torch.where(
                mask,
                _shift_left(high, self._base, self._base_to_pm1) + self._base - 1,
                high,
            )
        return low, high, bits_buffer, buf_offsets, num_carry_digits, current_code_in_int
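
    # Worked example (illustrative, base 10, precision 4): with low=2345 and
    # high=2789 both bounds share MSD 2, so a digit 2 is emitted and the bounds
    # renormalize to low=3450 and high=7899 via _shift_left.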

    def flush_carry_digits(
        self,
        low: torch.Tensor,
        high: torch.Tensor,
        encoding: bool = True,
        num_carry_digits: Optional[torch.Tensor] = None,
        current_code_in_int: Optional[torch.Tensor] = None,
        _next_digit: Optional[Callable[[int], int]] = None,
        valid: Optional[torch.Tensor] = None,
    ):
        """Handle underflow: when the bounds straddle a digit boundary
        (low = d(base-1)..., high = (d+1)0...), drop the second digit of both
        and count a pending digit to resolve at the next MSD emission."""
        valid = valid if valid is not None else True
        while True:
            second_msd_low = (low // self._base_to_pm2) % self._base
            second_msd_high = ((high - 1) // self._base_to_pm2) % self._base
            msd_low = low // self._base_to_pm1
            msd_high = (high - 1) // self._base_to_pm1
            mask = (msd_low + 1 == msd_high) & (second_msd_low == self._base - 1) & (second_msd_high == 0)
            mask = mask & valid
            if not torch.any(mask):
                break

            low = torch.where(
                mask,
                _shift_left_keeping_msd(low, self._base, self._base_to_pm1, self._base_to_pm2),
                low,
            )
            high = torch.where(
                mask,
                _shift_left_keeping_msd(high, self._base, self._base_to_pm1, self._base_to_pm2) + self._base - 1,
                high,
            )
            if encoding:
                num_carry_digits = torch.where(mask, num_carry_digits + 1, num_carry_digits)
            else:
                new_digit = torch.tensor([
                    _next_digit(i) if m else 0
                    for i, m in enumerate(mask.tolist())
                ], dtype=torch.int64, device=mask.device)
                current_code_in_int = torch.where(
                    mask,
                    _shift_left_keeping_msd(current_code_in_int, self._base, self._base_to_pm1, self._base_to_pm2) + new_digit,
                    current_code_in_int,
                )

        return low, high, num_carry_digits, current_code_in_int
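
    # Worked example (illustrative, base 10, precision 4): low=2950, high=3049
    # triggers the underflow test (leading digits 2|9 vs 3|0) and renormalizes
    # to low=2500, high=3499 with one pending carry digit recorded.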

    def _process(
        self,
        pdf: torch.Tensor,
        symbols: Optional[torch.Tensor] = None,
        encoding: bool = True,
        return_num_padded_bits: bool = False,
        encoded_bits: Optional[List[List[int]]] = None,
        lengths: Optional[torch.Tensor] = None,
    ) -> List[bytes] | Tuple[List[bytes], List[int]] | torch.Tensor:
        """Shared encode/decode loop; dispatch is controlled by `encoding`."""
        assert pdf is not None, "pdf must be provided"
        assert pdf.ndim == 3, "input must be [B, T, V]"
        B, T, V = pdf.shape
        device = pdf.device
        if lengths is None:
            lengths = torch.full((B,), T, dtype=torch.int64, device=device)
        lengths = torch.clamp(lengths, min=0, max=T)

        # The coder registers span the full integer range [0, base**precision - 1].
        low = torch.zeros((B,), dtype=torch.int64, device=device)
        high = torch.full((B,), int(self._base**self._precision) - 1, dtype=torch.int64, device=device)

        cdf = _pdf_to_cdf(pdf)

        if encoding:
            assert symbols is not None, "symbols must be provided for encoding"
            assert encoded_bits is None, "encoded_bits must be None for encoding"
            num_carry_digits = torch.zeros((B,), dtype=torch.int32, device=device)

            # Worst-case digits per sequence: the initial register plus at most
            # 2 * digits_sym renormalization digits per encoded symbol.
            digits_sym = math.ceil(math.log(V, self._base))
            max_digits = self._precision + 2 * T * digits_sym
            bits_buffer = torch.empty(
                B * max_digits,
                dtype=torch.int32,
                device=device,
            )
            buf_offsets = torch.arange(
                B,
                device=device,
                dtype=torch.int32,
            ) * max_digits
            _next_digit = None
            current_code_in_int = None
            decoded_symbols = None
        else:
            assert symbols is None, "symbols must be None for decoding"
            assert encoded_bits is not None, "encoded_bits must be provided for decoding"
            num_carry_digits = None
            bits_buffer = None
            buf_offsets = None

            cursor = [0] * B

            def _next_digit(idx: int) -> int:
                if cursor[idx] < len(encoded_bits[idx]):
                    d = encoded_bits[idx][cursor[idx]]
                    cursor[idx] += 1
                    return d
                # Stream exhausted: pad with base-1 digits.
                return self._base - 1

            # Prime the code register with the first `precision` digits.
            current_code_in_int = torch.zeros((B,), dtype=torch.int64, device=device)
            for _ in range(self._precision):
                digits = torch.tensor([_next_digit(i) for i in range(B)],
                                      dtype=torch.int64, device=device)
                current_code_in_int = current_code_in_int * self._base + digits

            decoded_symbols = torch.zeros((B, T), dtype=torch.int64, device=device)

        for t in range(T):
            valid = t < lengths
            if not valid.any():
                break

            cdf_t = cdf[valid, t]
            low_valid = low[valid]
            high_valid = high[valid]
            width_valid = high_valid - low_valid + 1

            # intervals[:, k] is the inclusive lower bound of symbol k's sub-interval.
            intervals = low_valid.unsqueeze(1) + (cdf_t * width_valid.unsqueeze(1)).type(torch.int64)

            if encoding:
                symbols_t = symbols[valid, t : t + 1]
            else:
                symbols_t = torch.searchsorted(intervals, current_code_in_int[valid].unsqueeze(1), right=True) - 1
                symbols_t = symbols_t.clamp(max=V - 1)
                decoded_symbols[valid, t] = symbols_t.squeeze(1)

            old_low = low.clone()
            low[valid] = intervals.gather(1, symbols_t).squeeze(1)
            high[valid] = intervals.gather(1, (symbols_t + 1)).squeeze(1) - 1

            (low, high, bits_buffer, buf_offsets, num_carry_digits, current_code_in_int) = self.flush_matching_digits(
                low, high, old_low,
                encoding=encoding,
                bits_buffer=bits_buffer,
                buf_offsets=buf_offsets,
                num_carry_digits=num_carry_digits,
                current_code_in_int=current_code_in_int if not encoding else None,
                _next_digit=_next_digit if not encoding else None,
                valid=valid,
            )

            (low, high, num_carry_digits, current_code_in_int) = self.flush_carry_digits(
                low, high,
                encoding=encoding,
                num_carry_digits=num_carry_digits,
                current_code_in_int=current_code_in_int if not encoding else None,
                _next_digit=_next_digit if not encoding else None,
                valid=valid,
            )

        if encoding:
            output_compressed_bytes = []
            output_num_padded_bits = []

            # Terminate every stream with the current MSD of `low`, then flush
            # any still-pending underflow digits.
            bits_buffer[buf_offsets] = (low // self._base_to_pm1).to(torch.int32)
            buf_offsets = buf_offsets + 1

            carry_sel = num_carry_digits.nonzero(as_tuple=False).flatten()
            if carry_sel.numel():
                carry_digit = self._base - 1

                rep_cnt = num_carry_digits[carry_sel]
                repeats_max = rep_cnt.max().item()
                grid = torch.arange(
                    repeats_max,
                    device=rep_cnt.device,
                ).expand(carry_sel.size(0), repeats_max)
                mask_rep = grid < rep_cnt.unsqueeze(1)

                start_pos = buf_offsets[carry_sel]
                target_pos = (start_pos.unsqueeze(1) + grid)[mask_rep]

                bits_buffer[target_pos] = carry_digit
                buf_offsets.index_add_(0, carry_sel, rep_cnt)
                num_carry_digits[carry_sel] = 0

            for idx in range(B):
                offset_start = idx * max_digits
                offset_end = buf_offsets[idx]
                bits_list = bits_buffer[offset_start:offset_end].cpu().tolist()
                bitstr = "".join(map(str, bits_list))
                compressed_bytes, num_padded_bits = utils.bits_to_bytes(bitstr)
                output_compressed_bytes.append(compressed_bytes)
                output_num_padded_bits.append(num_padded_bits)

            if return_num_padded_bits:
                return output_compressed_bytes, output_num_padded_bits
            else:
                return output_compressed_bytes
        else:
            return decoded_symbols
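
    # Worked interval update (illustrative, base 10, precision 4): from
    # low=0, high=9999 and CDF [0.0, 0.5, 0.75, 1.0], the boundaries are
    # [0, 5000, 7500, 10000]; encoding symbol 1 narrows the register to
    # low=5000, high=7499.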

    def batched_encode(
        self,
        pdf: torch.Tensor,
        symbols: torch.Tensor,
        lengths: Optional[torch.Tensor] = None,
        return_num_padded_bits: bool = False,
    ) -> List[bytes] | Tuple[List[bytes], List[int]]:
        """Encode a batch of symbol sequences under per-step distributions `pdf`."""
        B, T, V = pdf.shape
        device = pdf.device
        if lengths is None:
            lengths = torch.full((B,), T, dtype=torch.int64, device=device)
        return self._process(
            pdf,
            symbols=symbols,
            encoding=True,
            return_num_padded_bits=return_num_padded_bits,
            encoded_bits=None,
            lengths=lengths,
        )

    def batched_decode(
        self,
        pdf: torch.Tensor,
        compressed_bytes: List[bytes],
        num_padded_bits: List[int],
        lengths: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Decode a batch of byte streams back into symbols under the same `pdf`."""
        B, T, V = pdf.shape
        device = pdf.device
        if lengths is None:
            lengths = torch.full((B,), T, dtype=torch.int64, device=device)
        assert len(compressed_bytes) == B, "compressed_bytes length must equal batch size"
        assert len(num_padded_bits) == B, "num_padded_bits length must equal batch size"
        encoded_bits = [[] for _ in range(B)]
        for idx, (compressed_b, num_padded) in enumerate(zip(compressed_bytes, num_padded_bits)):
            bits = utils.bytes_to_bits(compressed_b, num_padded_bits=num_padded)
            encoded_bits[idx] = list(map(int, bits))

        return self._process(
            pdf,
            symbols=None,
            encoding=False,
            return_num_padded_bits=False,
            encoded_bits=encoded_bits,
            lengths=lengths,
        )
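
    # Round-trip usage sketch (illustrative shapes; `pdf` must already be a
    # valid per-step distribution):
    #   coder = BatchedArithmeticEncoder(base=2, precision=32)
    #   codes, pads = coder.batched_encode(pdf, symbols, return_num_padded_bits=True)
    #   decoded = coder.batched_decode(pdf, codes, pads)
    #   assert torch.equal(decoded, symbols)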

    def incremental_batched_encode(
        self,
        pdf: torch.Tensor,
        symbols: torch.Tensor,
        lengths: Optional[torch.Tensor] = None,
        bit_threshold: Optional[int] = None,
        return_num_padded_bits: bool = False,
    ) -> Tuple[List[bytes], List[int]] | Tuple[List[bytes], List[int], List[int]]:
        """
        Incrementally encode symbols, stopping a sequence early once its
        finalized bitstream exceeds `bit_threshold`.

        Args:
            pdf: [B, T, V] probability distributions
            symbols: [B, T] symbols to encode
            lengths: [B] length of each sequence (optional)
            bit_threshold: stop a sequence once it exceeds this many bits
            return_num_padded_bits: whether to return padding information

        Returns:
            final_compressed_bytes: List[bytes] - final compressed result per sequence
            stopped_at_step: List[int] - step at which each sequence stopped
                (a sequence that completes normally reports its full length;
                -1 is kept only if it was never finalized)
            final_num_padded_bits: List[int] - padding info (only if
                return_num_padded_bits=True)
        """
        assert pdf.ndim == 3, "input must be [B, T, V]"
        B, T, V = pdf.shape
        device = pdf.device

        if lengths is None:
            lengths = torch.full((B,), T, dtype=torch.int64, device=device)
        lengths = torch.clamp(lengths, min=0, max=T)

        low = torch.zeros((B,), dtype=torch.int64, device=device)
        high = torch.full((B,), int(self._base**self._precision) - 1, dtype=torch.int64, device=device)
        num_carry_digits = torch.zeros((B,), dtype=torch.int32, device=device)

        digits_sym = math.ceil(math.log(V, self._base))
        max_digits = self._precision + 2 * T * digits_sym
        bits_buffer = torch.empty(B * max_digits, dtype=torch.int32, device=device)
        buf_offsets = torch.arange(B, device=device, dtype=torch.int32) * max_digits
        base_offsets = torch.arange(B, device=device, dtype=torch.int32) * max_digits

        # Scratch copies used to finalize speculatively without disturbing the
        # live coder state.
        temp_bits_buffer = torch.empty_like(bits_buffer)
        temp_buf_offsets = torch.empty_like(buf_offsets)
        temp_num_carry_digits = torch.empty_like(num_carry_digits)

        cdf = _pdf_to_cdf(pdf)

        final_buffer = torch.empty_like(bits_buffer)
        final_buffer_ends = torch.zeros(B, dtype=torch.int32, device=device)
        final_num_padded_bits = [None] * B
        stopped_at_step = [-1] * B

        active_sequences = torch.ones(B, dtype=torch.bool, device=device)

        # Snapshot of the last finalized state under the threshold, used to
        # roll back a sequence that overshoots.
        prev_finalized_buffer = torch.empty_like(bits_buffer)
        prev_finalized_ends = torch.zeros_like(buf_offsets)

        for t in range(T):
            valid = (t < lengths) & active_sequences
            if not valid.any():
                break

            cdf_t = cdf[valid, t]
            low_valid = low[valid]
            high_valid = high[valid]
            width_valid = high_valid - low_valid + 1

            intervals = low_valid.unsqueeze(1) + (cdf_t * width_valid.unsqueeze(1)).type(torch.int64)
            symbols_t = symbols[valid, t : t + 1]

            old_low = low.clone()
            low[valid] = intervals.gather(1, symbols_t).squeeze(1)
            high[valid] = intervals.gather(1, (symbols_t + 1)).squeeze(1) - 1

            (low, high, bits_buffer, buf_offsets, num_carry_digits, _) = self.flush_matching_digits(
                low, high, old_low,
                encoding=True,
                bits_buffer=bits_buffer,
                buf_offsets=buf_offsets,
                num_carry_digits=num_carry_digits,
                current_code_in_int=None,
                _next_digit=None,
                valid=valid,
            )

            (low, high, num_carry_digits, _) = self.flush_carry_digits(
                low, high,
                encoding=True,
                num_carry_digits=num_carry_digits,
                current_code_in_int=None,
                _next_digit=None,
                valid=valid,
            )

            need_check_threshold = bit_threshold is not None and active_sequences.any()
            some_seq_finished = ((t + 1 >= lengths) & active_sequences).any()

            if need_check_threshold or some_seq_finished:
                # Speculatively finalize into the scratch buffers: write the
                # terminating MSD digit, then flush pending carry digits.
                temp_bits_buffer.copy_(bits_buffer, non_blocking=True)
                temp_buf_offsets.copy_(buf_offsets, non_blocking=True)
                temp_num_carry_digits.copy_(num_carry_digits, non_blocking=True)

                temp_bits_buffer[temp_buf_offsets] = (low // self._base_to_pm1).to(torch.int32)
                temp_buf_offsets += 1

                carry_sel = (temp_num_carry_digits > 0).nonzero(as_tuple=False).flatten()
                if carry_sel.numel():
                    carry_digit = self._base - 1
                    rep_cnt = temp_num_carry_digits[carry_sel]
                    repeats_max = rep_cnt.max().item()
                    grid = torch.arange(repeats_max, device=rep_cnt.device).expand(carry_sel.size(0), repeats_max)
                    mask_rep = grid < rep_cnt.unsqueeze(1)

                    start_pos = temp_buf_offsets[carry_sel]
                    target_pos = (start_pos.unsqueeze(1) + grid)[mask_rep]
                    temp_bits_buffer[target_pos] = carry_digit
                    temp_buf_offsets.index_add_(0, carry_sel, rep_cnt)
                    temp_num_carry_digits[carry_sel] = 0

                if need_check_threshold:
                    current_bit_counts = self._get_bit_counts(temp_buf_offsets, base_offsets)
                    exceeds_threshold = (current_bit_counts > bit_threshold) & active_sequences

                    if exceeds_threshold.any():
                        stopped_indices = exceeds_threshold.nonzero(as_tuple=False).flatten()
                        for idx in stopped_indices.cpu().tolist():
                            # Symbol t pushed this sequence over the budget:
                            # roll back to the snapshot from step t - 1.
                            active_sequences[idx] = False
                            stopped_at_step[idx] = t

                            final_buffer_ends[idx] = prev_finalized_ends[idx]
                            offset_start = idx * max_digits
                            offset_end = prev_finalized_ends[idx]
                            final_buffer[offset_start:offset_end].copy_(prev_finalized_buffer[offset_start:offset_end])

                # Sequences reaching their full length keep the current finalization.
                is_final_step = (t + 1 >= lengths) & active_sequences
                if is_final_step.any():
                    final_step_indices = is_final_step.nonzero(as_tuple=False).flatten()
                    for idx in final_step_indices.cpu().tolist():
                        active_sequences[idx] = False
                        stopped_at_step[idx] = t + 1

                        final_buffer_ends[idx] = temp_buf_offsets[idx]
                        offset_start = idx * max_digits
                        offset_end = temp_buf_offsets[idx]
                        final_buffer[offset_start:offset_end].copy_(temp_bits_buffer[offset_start:offset_end])

                if need_check_threshold:
                    prev_finalized_buffer.copy_(temp_bits_buffer)
                    prev_finalized_ends.copy_(temp_buf_offsets)

        final_compressed_bytes = []
        for idx in range(B):
            offset_start = idx * max_digits
            offset_end = final_buffer_ends[idx]
            bits_list = final_buffer[offset_start:offset_end].cpu().tolist()
            bitstr = "".join(map(str, bits_list))
            comp_bytes, num_padded = utils.bits_to_bytes(bitstr)
            final_compressed_bytes.append(comp_bytes)
            if return_num_padded_bits:
                final_num_padded_bits[idx] = num_padded

        if return_num_padded_bits:
            return final_compressed_bytes, stopped_at_step, final_num_padded_bits
        else:
            return final_compressed_bytes, stopped_at_step
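
    # Semantics sketch: if stopped_at_step[i] == t (cut by the threshold), the
    # returned bytes equal batched_encode run on the length-t prefix of
    # sequence i, which is exactly what the tests below verify.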


def test_incremental_encoding():
    """Test the incremental encoding functionality."""
    print("Testing incremental encoding...")

    batch_size = 4
    seq_len = 32
    vocab_size = 64
    base = 2
    precision = 32
    device = "cuda" if torch.cuda.is_available() else "cpu"

    symbols = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
    pdf = torch.rand(batch_size, seq_len, vocab_size, device=device).clamp(min=1e-6)
    pdf = pdf.softmax(dim=-1)
    pdf = utils.batched_normalize_pdf_for_arithmetic_coding(pdf)

    lengths = torch.tensor([seq_len, seq_len - 2, seq_len - 4, seq_len], device=device)

    AC = BatchedArithmeticEncoder(base, precision)

    # Encode with a small bit budget so some sequences stop early.
    bit_threshold = 20
    final_compressed_bytes, stopped_at_step, final_num_padded_bits = AC.incremental_batched_encode(
        pdf, symbols, lengths=lengths, bit_threshold=bit_threshold, return_num_padded_bits=True
    )

    print(f"Stopped at steps: {stopped_at_step}")
    print(f"Final compressed sizes (bytes): {[len(cb) if cb else 0 for cb in final_compressed_bytes]}")
    bit_counts = [len(cb) * 8 - pb if cb else 0 for cb, pb in zip(final_compressed_bytes, final_num_padded_bits)]
    print(f"Final bit counts: {bit_counts}")

    # Encode again without a threshold as a reference.
    final_compressed_bytes_full, stopped_at_step_full = AC.incremental_batched_encode(
        pdf, symbols, lengths=lengths, bit_threshold=None, return_num_padded_bits=False
    )

    print(f"Full encoding stopped at: {stopped_at_step_full}")
    print(f"Full encoding sizes (bytes): {[len(cb) if cb else 0 for cb in final_compressed_bytes_full]}")

    print("\n--- Consistency Test ---")

    # Build the prefix of each sequence that the incremental encoder actually kept.
    prefix_symbols = []
    prefix_pdf = []
    prefix_lengths = []

    for i in range(batch_size):
        if stopped_at_step[i] == -1:
            # Never finalized: treat the whole sequence as the prefix.
            stop_point = seq_len
        else:
            # A threshold cut at step t keeps symbols [0, t); a completed
            # sequence reports its full length here.
            stop_point = stopped_at_step[i]

        prefix_symbols.append(symbols[i, :stop_point])
        prefix_pdf.append(pdf[i, :stop_point])
        prefix_lengths.append(stop_point)

    # Re-batch the prefixes with right-padding.
    max_prefix_len = max(prefix_lengths)
    batch_prefix_symbols = torch.zeros((batch_size, max_prefix_len), dtype=symbols.dtype, device=device)
    batch_prefix_pdf = torch.zeros((batch_size, max_prefix_len, vocab_size), dtype=pdf.dtype, device=device)

    for i in range(batch_size):
        length = prefix_lengths[i]
        batch_prefix_symbols[i, :length] = prefix_symbols[i]
        batch_prefix_pdf[i, :length] = prefix_pdf[i]

    prefix_lengths_tensor = torch.tensor(prefix_lengths, dtype=torch.int64, device=device)

    prefix_compressed_bytes, prefix_num_padded_bits = AC.batched_encode(
        batch_prefix_pdf, batch_prefix_symbols,
        lengths=prefix_lengths_tensor,
        return_num_padded_bits=True,
    )

    print("Comparing incremental vs regular encoding on prefixes:")
    all_match = True
    for i in range(batch_size):
        incremental_bytes = final_compressed_bytes[i]
        regular_bytes = prefix_compressed_bytes[i]

        if incremental_bytes == regular_bytes:
            print(f"Sequence {i}: ✓ Match (stopped at step {stopped_at_step[i]})")
        else:
            print(f"Sequence {i}: ✗ Mismatch (stopped at step {stopped_at_step[i]})")
            print(f"  Incremental: {len(incremental_bytes) if incremental_bytes else 0} bytes")
            print(f"  Regular: {len(regular_bytes)} bytes")
            all_match = False

    if all_match:
        print("✓ All incremental encodings match regular encodings on prefixes!")
    assert all_match, "Some incremental encodings don't match - there may be a bug"

    print("✓ Incremental encoding test completed")


def test_consistency_various_thresholds():
    """Test consistency between incremental and regular encoding with various thresholds."""
    print("\n=== Testing Consistency with Various Thresholds ===")

    batch_size = 3
    seq_len = 128
    vocab_size = 4
    base = 2
    precision = 16
    device = "cuda" if torch.cuda.is_available() else "cpu"

    torch.manual_seed(42)
    symbols = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
    pdf = torch.rand(batch_size, seq_len, vocab_size, device=device).clamp(min=1e-6)
    pdf = pdf.softmax(dim=-1)
    lengths = torch.full((batch_size,), seq_len, device=device)

    AC = BatchedArithmeticEncoder(base, precision)

    thresholds = [10, 25, 50, 100, None]

    for threshold in thresholds:
        print(f"\n--- Testing with threshold: {threshold} ---")

        final_bytes, stop_steps = AC.incremental_batched_encode(
            pdf, symbols, lengths=lengths, bit_threshold=threshold
        )

        print(f"Stop steps: {stop_steps}")

        # Determine the prefix kept for each sequence.
        prefix_lengths = []
        for i in range(batch_size):
            if stop_steps[i] == -1:
                prefix_lengths.append(lengths[i].item())
            else:
                prefix_lengths.append(stop_steps[i])

        max_prefix_len = max(prefix_lengths)
        batch_prefix_symbols = torch.zeros((batch_size, max_prefix_len), dtype=symbols.dtype, device=device)
        batch_prefix_pdf = torch.zeros((batch_size, max_prefix_len, vocab_size), dtype=pdf.dtype, device=device)

        for i in range(batch_size):
            length = prefix_lengths[i]
            batch_prefix_symbols[i, :length] = symbols[i, :length]
            batch_prefix_pdf[i, :length] = pdf[i, :length]

        prefix_lengths_tensor = torch.tensor(prefix_lengths, dtype=torch.int64, device=device)

        regular_bytes = AC.batched_encode(
            batch_prefix_pdf, batch_prefix_symbols,
            lengths=prefix_lengths_tensor,
        )

        all_consistent = True
        for i in range(batch_size):
            if final_bytes[i] == regular_bytes[i]:
                print(f"  Seq {i}: ✓ (len={prefix_lengths[i]})")
            else:
                print(f"  Seq {i}: ✗ INCONSISTENT (len={prefix_lengths[i]})")
                print(f"    Incremental: {len(final_bytes[i]) if final_bytes[i] else 0} bytes")
                print(f"    Regular: {len(regular_bytes[i])} bytes")
                all_consistent = False

        if all_consistent:
            print(f"  ✓ All sequences consistent for threshold {threshold}")
        assert all_consistent, "Some incremental encodings don't match - there may be a bug"

    print("\n✓ Consistency testing completed")


if __name__ == "__main__":
    # Incremental-encoding tests.
    test_incremental_encoding()
    test_consistency_various_thresholds()

    batch_size = 32
    seq_len = 128
    vocab_size = 256
    base = 2
    precision = 32
    device = "cuda" if torch.cuda.is_available() else "cpu"

    lengths = torch.randint(1, seq_len, (batch_size,), device=device)
    symbols = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)

    # Deterministic setup with one-hot PDFs; note it is overridden by the
    # random setup below and only kept as an alternative.
    pdf = torch.zeros(batch_size, seq_len, vocab_size, device=device)
    for i in range(batch_size):
        s = torch.randint(0, vocab_size, (lengths[i],), device=device)
        symbols[i, :lengths[i]] = s
        for t in range(lengths[i]):
            pdf[i, t, s[t]] = 1.0
    pdf = pdf / pdf.sum(-1, keepdim=True)

    # Random setup actually used by the round-trip test below.
    pdf = torch.rand(batch_size, seq_len, vocab_size, device=device).clamp(min=1e-6)
    pdf = pdf.softmax(dim=-1)
    symbols = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
    lengths = torch.randint(1, seq_len, (batch_size,), device=device)

    AC = BatchedArithmeticEncoder(base, precision)

    print("Testing original batched encoding...")
    start_event = torch.cuda.Event(enable_timing=True) if device == "cuda" else None
    end_event = torch.cuda.Event(enable_timing=True) if device == "cuda" else None

    if device == "cuda":
        start_event.record()
    codes, padded_bits = AC.batched_encode(pdf, symbols, lengths=lengths, return_num_padded_bits=True)
    if device == "cuda":
        end_event.record()
        torch.cuda.synchronize()
        print(f"CUDA time: {start_event.elapsed_time(end_event):.2f} ms")
    print([len(c) for c in codes])

    decoded = AC.batched_decode(pdf, codes, padded_bits, lengths)
    print("[DEBUG]: decoded {} symbols {}".format(decoded[0], symbols[0]))
    print("✓ passed - compressed bytes per seq:", [len(s) for s in codes])

    # Verify every variable-length sequence round-trips exactly.
    for i in range(batch_size):
        l = lengths[i].item()
        print(l)
        print(decoded[i, :l])
        print(symbols[i, :l])
        assert torch.all(decoded[i, :l] == symbols[i, :l]), f"Sample {i} mismatch"
    print("All variable-length sequences verified successfully")