import torch
from typing import List, Tuple, Optional, Callable
from m1_compression import utils
import math
def _pdf_to_cdf(pdf: torch.Tensor) -> torch.Tensor:
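    """Convert a batched PDF over the last dim into a CDF with a leading zero.

    Illustrative example (not part of the original code):
        pdf = [0.5, 0.25, 0.25]  ->  cdf = [0.0, 0.5, 0.75, 1.0]
    Symbol s then owns the sub-interval [cdf[s], cdf[s+1]).
    """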
    # NOTE: we do the cumsum in float64; we found that cumsum in float32
    # introduces numerical errors that make results depend on the batch size
    # used for a run
cdf = torch.cumsum(pdf.to(torch.float64), dim=-1).to(torch.float32)
cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], dim=-1) # prepend 0
return cdf # shape [..., vocab_size+1]
def _shift_left(x: torch.Tensor, base: int, base_to_pm1: int) -> torch.Tensor:
    """Shift `x` one digit to the left in base `base`, dropping the most significant digit."""
    return (x % base_to_pm1) * base
def _shift_left_keeping_msd(x: torch.Tensor, base: int, base_to_pm1: int, base_to_pm2: int) -> torch.Tensor:
    """Shift every digit of `x` except the MSD, which stays in place, one position to the left."""
    return x - (x % base_to_pm1) + (x % base_to_pm2) * base
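# Worked example of the two helpers (illustrative, base=10, precision=4, so
# base_to_pm1=1000 and base_to_pm2=100):
#   _shift_left(1234, 10, 1000)                  -> (1234 % 1000) * 10       = 2340
#   _shift_left_keeping_msd(1234, 10, 1000, 100) -> 1234 - 234 + 34 * 10     = 1340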
class BatchedArithmeticEncoder:
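    """Batch-parallel integer arithmetic coder.

    The coding interval is kept as integer bounds with `precision` digits in base
    `base`; every sequence in a batch is encoded/decoded simultaneously with tensor ops.
    """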
def __init__(self, base: int, precision: int):
self._base: int = base
self._base_to_pm1: int = int(base ** (precision - 1))
self._base_to_pm2: int = int(base ** (precision - 2))
self._precision: int = precision
def _get_bit_counts(self, buf_offsets: torch.Tensor, base_offsets: torch.Tensor) -> torch.Tensor:
"""Get bit counts for all sequences."""
bit_counts = buf_offsets - base_offsets
return bit_counts
    # Renormalisation helpers ------------------------------------------------
def flush_matching_digits(
self,
low,
high,
old_low,
encoding: bool = True,
bits_buffer: Optional[torch.Tensor] = None,
buf_offsets: Optional[torch.Tensor] = None,
num_carry_digits: Optional[torch.Tensor] = None,
current_code_in_int: Optional[torch.Tensor] = None,
_next_digit: Optional[Callable[[int], int]] = None,
        valid: Optional[torch.Tensor] = None,  # mask of sequences that are still active
):
valid = valid if valid is not None else True
while True:
msd_low = low // self._base_to_pm1
msd_high = high // self._base_to_pm1
mask = msd_low == msd_high
            # restrict the update to sequences that are still active
mask = mask & valid
if not torch.any(mask):
break
if encoding:
msd_low_old = old_low // self._base_to_pm1
                # Vectorised digit flush: write the matched MSD for every selected
                # sequence, then resolve any pending carry digits (each written as
                # (base - 1 + digit - old_digit) % base).
                # ------------------------------------------------------------------
sel = mask.nonzero(as_tuple=False).flatten()
bits_buffer[buf_offsets[sel]] = msd_low[sel].to(torch.int32)
buf_offsets.index_add_(
0,
sel,
torch.ones_like(sel, dtype=torch.int32)
)
carry_sel = sel[(num_carry_digits[sel] > 0)]
if carry_sel.numel():
_digit_carry = msd_low[carry_sel]
_digit_carry_old = msd_low_old[carry_sel]
carry_digit = (
self._base - 1 + _digit_carry - _digit_carry_old
) % self._base
rep_cnt = num_carry_digits[carry_sel]
repeats_max = rep_cnt.max().item()
grid = torch.arange(
repeats_max,
device=rep_cnt.device
).expand(carry_sel.size(0), repeats_max) # [K2, M]
mask_rep = grid < rep_cnt.unsqueeze(1) # [K2, M]
start_pos = buf_offsets[carry_sel]
target_pos = (start_pos.unsqueeze(1) + grid)[mask_rep]
payload = carry_digit.to(torch.int32).unsqueeze(1).expand_as(grid)[mask_rep]
bits_buffer[target_pos] = payload
buf_offsets.index_add_(0, carry_sel, rep_cnt)
num_carry_digits[carry_sel] = 0
else:
new_digit = torch.tensor([
_next_digit(i) if m else 0
for i, m in enumerate(mask.tolist())
], dtype=torch.int64, device=mask.device)
current_code_in_int = torch.where(
mask,
_shift_left(current_code_in_int, self._base, self._base_to_pm1) + new_digit,
current_code_in_int
)
# Shift left to remove matching digits
low = torch.where(
mask,
_shift_left(low, self._base, self._base_to_pm1),
low
)
high = torch.where(
mask,
_shift_left(high, self._base, self._base_to_pm1) + self._base - 1,
high,
)
return low, high, bits_buffer, buf_offsets, num_carry_digits, current_code_in_int
def flush_carry_digits(
self,
low,
high,
encoding: bool = True,
num_carry_digits: Optional[torch.Tensor] = None,
current_code_in_int: Optional[torch.Tensor] = None,
_next_digit: Optional[Callable[[int], int]] = None,
valid: Optional[torch.Tensor] = None,
):
valid = valid if valid is not None else True
while True:
second_msd_low = (low // self._base_to_pm2) % self._base
second_msd_high = ((high - 1) // self._base_to_pm2) % self._base
msd_low = low // self._base_to_pm1
msd_high = (high - 1) // self._base_to_pm1
mask = (msd_low + 1 == msd_high) & (second_msd_low == self._base - 1) & (second_msd_high == 0)
mask = mask & valid # valid mask
if not torch.any(mask):
break
# For sequences in *mask*, we shift left *keeping MSD fixed*
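            # Illustrative example (base=10, precision=4): low=4990, high=5009 triggers the
            # condition above (MSDs 4 and 5, second digits 9 and 0). Shifting while keeping
            # the MSD widens the state to low=4900, high=5099; when encoding, one pending
            # carry digit is recorded below.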
low = torch.where(
mask,
_shift_left_keeping_msd(low, self._base, self._base_to_pm1, self._base_to_pm2),
low
)
high = torch.where(
mask,
_shift_left_keeping_msd(high, self._base, self._base_to_pm1, self._base_to_pm2) + self._base - 1,
high
)
if encoding:
num_carry_digits = torch.where(mask, num_carry_digits + 1, num_carry_digits)
else:
new_digit = torch.tensor([
_next_digit(i) if m else 0
for i, m in enumerate(mask.tolist())
], dtype=torch.int64, device=mask.device)
current_code_in_int = torch.where(
mask,
_shift_left_keeping_msd(current_code_in_int, self._base, self._base_to_pm1, self._base_to_pm2) + new_digit,
current_code_in_int
)
return low, high, num_carry_digits, current_code_in_int
def _process(
self,
pdf: torch.Tensor,
symbols: Optional[torch.Tensor] = None,
encoding: bool = True,
return_num_padded_bits: bool = False,
encoded_bits: Optional[List[List[int]]] = None,
lengths: Optional[torch.Tensor] = None,
    ) -> List[bytes] | Tuple[List[bytes], List[int]] | torch.Tensor:
        assert pdf is not None, "pdf must be provided"
assert pdf.ndim == 3, "input must be [B, T, V]"
B, T, V = pdf.shape
device = pdf.device
        if lengths is None:
            lengths = torch.full((B,), T, dtype=torch.int64, device=device)
        lengths = torch.clamp(lengths, min=0, max=T)
        # --- interval state ----------------------------------------------------
low = torch.zeros((B,), dtype=torch.int64, device=device)
        # NOTE: We represent the AC interval [0, 1) as rational numbers:
        #   [0, 1)
        #   ~ [self._low / base ** precision, (self._high + 1) / base ** precision)
        #   = [self._low / base ** precision, self._high / base ** precision],
        # where we represent the upper bound as *INCLUSIVE*. This is a subtle
        # detail required to make the integer arithmetic work correctly given that
        # all involved integers have `precision` digits in base `base`.
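        # Illustrative example (base=2, precision=4): low=0, high=15 represents
        # [0/16, 16/16) = [0, 1); picking a symbol whose CDF range is [0.25, 0.5)
        # updates the state to low=4, high=7, i.e. the sub-interval [4/16, 8/16).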
high = torch.full((B,), int(self._base**self._precision) - 1, dtype=torch.int64, device=device)
cdf = _pdf_to_cdf(pdf)
if encoding:
assert symbols is not None, "symbols must be provided for encoding"
assert encoded_bits is None, "encoded_bits must be None for encoding"
num_carry_digits = torch.zeros((B,), dtype=torch.int32, device=device)
digits_sym = math.ceil(math.log(V, self._base))
max_digits = self._precision + 2 * T * digits_sym
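            # `max_digits` is assumed to be a generous upper bound on the digits a single
            # sequence can emit: roughly `digits_sym` digits (plus pending carries) per
            # symbol, plus `precision` digits for the final flush.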
bits_buffer = torch.empty(
B * max_digits,
dtype=torch.int32,
device=device
)
buf_offsets = torch.arange(
B,
device=device,
dtype=torch.int32
) * max_digits
_next_digit = None
current_code_in_int = None
decoded_symbols = None
else:
assert symbols is None, "symbols must be None for decoding"
assert encoded_bits is not None, "encoded_bits must be provided for decoding"
num_carry_digits = None
bits_buffer = None
buf_offsets = None
# Stream read cursors -------------------------------------------------
cursor = [0] * B
def _next_digit(idx: int) -> int:
if cursor[idx] < len(encoded_bits[idx]):
d = encoded_bits[idx][cursor[idx]]
cursor[idx] += 1
return d
# Add padding to ensure the AC state is well-defined when decoding the last
# symbol. Note that what exactly we do here depends on how encoder
                # termination is implemented (see the finalization step at the end of `_process`).
return self._base - 1
current_code_in_int = torch.zeros((B,), dtype=torch.int64, device=device)
for _ in range(self._precision):
digits = torch.tensor([_next_digit(i) for i in range(B)],
dtype=torch.int64, device=device)
current_code_in_int = current_code_in_int * self._base + digits
decoded_symbols = torch.zeros((B, T), dtype=torch.int64, device=device)
# ---------------- main encoding ---------------------------------------
for t in range(T):
valid = t < lengths
if not valid.any():
                break  # all sequences are finished at this step
cdf_t = cdf[valid, t] # [valid, V+1]
low_valid = low[valid]
high_valid = high[valid]
width_valid = high_valid - low_valid + 1
intervals = low_valid.unsqueeze(1) + (cdf_t * width_valid.unsqueeze(1)).type(torch.int64)
if encoding:
symbols_t = symbols[valid, t : t+1]
else:
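                # searchsorted finds the symbol whose sub-interval
                # [intervals[:, s], intervals[:, s+1]) contains the current code word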
symbols_t = torch.searchsorted(intervals, current_code_in_int[valid].unsqueeze(1), right=True) - 1
                # V is the vocabulary size, i.e. pdf.shape[-1]
                symbols_t = symbols_t.clamp(max=V - 1)
decoded_symbols[valid, t] = symbols_t.squeeze(1)
old_low = low.clone()
            # NOTE: if no sequence were valid at this step the masked updates below would
            # operate on empty selections; the `valid.any()` check above breaks out first.
low[valid] = intervals.gather(1, symbols_t).squeeze(1)
high[valid] = intervals.gather(1, (symbols_t + 1)).squeeze(1) - 1
(low, high, bits_buffer, buf_offsets, num_carry_digits, current_code_in_int) = self.flush_matching_digits(
low, high, old_low,
encoding=encoding,
bits_buffer=bits_buffer,
buf_offsets=buf_offsets,
num_carry_digits=num_carry_digits,
current_code_in_int=current_code_in_int if not encoding else None,
_next_digit=_next_digit if not encoding else None,
valid=valid
)
(low, high, num_carry_digits, current_code_in_int) = self.flush_carry_digits(
low, high,
encoding=encoding,
num_carry_digits=num_carry_digits,
current_code_in_int=current_code_in_int if not encoding else None,
_next_digit=_next_digit if not encoding else None,
valid=valid
)
if encoding:
output_compressed_bytes = []
output_num_padded_bits = []
# -------------------- finalization ------------------------------------
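            # Emit the most significant digit of `low` as the terminating digit for every
            # sequence, then flush any still-pending carry digits (written as base - 1).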
bits_buffer[buf_offsets] = (low // self._base_to_pm1).to(torch.int32)
buf_offsets = buf_offsets + 1
carry_sel = num_carry_digits.nonzero(as_tuple=False).flatten()
if carry_sel.numel():
carry_digit = self._base - 1
rep_cnt = num_carry_digits[carry_sel]
                repeats_max = rep_cnt.max().item()
grid = torch.arange(
repeats_max,
device=rep_cnt.device
).expand(carry_sel.size(0), repeats_max) # [K2, M]
mask_rep = grid < rep_cnt.unsqueeze(1) # [K2, M]
start_pos = buf_offsets[carry_sel]
target_pos = (start_pos.unsqueeze(1) + grid)[mask_rep]
bits_buffer[target_pos] = carry_digit
buf_offsets.index_add_(0, carry_sel, rep_cnt)
num_carry_digits[carry_sel] = 0
for idx in range(B):
offset_start = idx * max_digits
offset_end = buf_offsets[idx]
bits_list = bits_buffer[offset_start:offset_end].cpu().tolist()
bitstr = "".join(map(str, bits_list))
compressed_bytes, num_padded_bits = utils.bits_to_bytes(bitstr)
output_compressed_bytes.append(compressed_bytes)
output_num_padded_bits.append(num_padded_bits)
if return_num_padded_bits:
return output_compressed_bytes, output_num_padded_bits
else:
return output_compressed_bytes
else:
return decoded_symbols
def batched_encode(
self,
pdf: torch.Tensor,
symbols: torch.Tensor,
lengths: Optional[torch.Tensor] = None,
return_num_padded_bits: bool = False
) -> List[bytes] | Tuple[List[bytes], List[int]]:
B, T, V = pdf.shape
device = pdf.device
if lengths is None:
            lengths = torch.full((B,), T, dtype=torch.int64, device=device)
return self._process(
pdf,
symbols=symbols,
encoding=True,
return_num_padded_bits=return_num_padded_bits,
encoded_bits=None,
            lengths=lengths,  # pass per-sequence lengths so padding is not encoded
)
def batched_decode(
self,
pdf: torch.Tensor,
compressed_bytes: List[bytes],
num_padded_bits: List[int],
lengths: Optional[torch.Tensor] = None,
) -> torch.Tensor:
B, T, V = pdf.shape
device = pdf.device
if lengths is None:
            lengths = torch.full((B,), T, dtype=torch.int64, device=device)
        assert len(compressed_bytes) == B, "compressed_bytes length must equal the batch size"
assert len(num_padded_bits) == B, "num_padded_bits length must be equal to batch size"
encoded_bits = [[] for _ in range(B)]
for idx, (compressed_b, num_padded) in enumerate(zip(compressed_bytes, num_padded_bits)):
bits = utils.bytes_to_bits(compressed_b, num_padded_bits=num_padded)
encoded_bits[idx] = list(map(int, bits))
# print("[DEBUG]: encoded_bits[{}]: {}".format(idx, encoded_bits[idx]))
return self._process(
pdf,
symbols=None,
encoding=False,
return_num_padded_bits=False,
encoded_bits=encoded_bits,
lengths=lengths,
)
def incremental_batched_encode(
self,
pdf: torch.Tensor,
symbols: torch.Tensor,
lengths: Optional[torch.Tensor] = None,
bit_threshold: Optional[int] = None,
return_num_padded_bits: bool = False
) -> Tuple[List[bytes], List[int]] | Tuple[List[bytes], List[int], List[int]]:
"""
Incrementally encode symbols with early stopping when bit threshold is exceeded.
Args:
pdf: [B, T, V] probability distributions
symbols: [B, T] symbols to encode
lengths: [B] length of each sequence (optional)
bit_threshold: Stop encoding when any sequence exceeds this many bits
return_num_padded_bits: Whether to return padding information
Returns:
final_compressed_bytes: List[bytes] - final compressed result for each sequence
            stopped_at_step: List[int] - step at which each sequence was finalized
                (equals the sequence length if it completed without exceeding the threshold)
final_num_padded_bits: List[int] - padding info (only if return_num_padded_bits=True)
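        Example (illustrative usage; the threshold value is a placeholder):
            coder = BatchedArithmeticEncoder(base=2, precision=32)
            compressed, stopped_at = coder.incremental_batched_encode(
                pdf, symbols, lengths=lengths, bit_threshold=256
            )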
"""
assert pdf.ndim == 3, "input must be [B, T, V]"
B, T, V = pdf.shape
device = pdf.device
if lengths is None:
lengths = torch.full((B,), T, dtype=torch.int64, device=device)
lengths = torch.clamp(lengths, min=0, max=T)
# Initialize arithmetic coding state
low = torch.zeros((B,), dtype=torch.int64, device=device)
high = torch.full((B,), int(self._base**self._precision) - 1, dtype=torch.int64, device=device)
num_carry_digits = torch.zeros((B,), dtype=torch.int32, device=device)
# Initialize bit buffer
digits_sym = math.ceil(math.log(V, self._base))
max_digits = self._precision + 2 * T * digits_sym
bits_buffer = torch.empty(B * max_digits, dtype=torch.int32, device=device)
buf_offsets = torch.arange(B, device=device, dtype=torch.int32) * max_digits
base_offsets = torch.arange(B, device=device, dtype=torch.int32) * max_digits
# Pre-allocate temporary buffers (avoid cloning at each step)
temp_bits_buffer = torch.empty_like(bits_buffer)
temp_buf_offsets = torch.empty_like(buf_offsets)
temp_num_carry_digits = torch.empty_like(num_carry_digits)
cdf = _pdf_to_cdf(pdf)
# Track final results for each sequence - save buffer states, not bytes
final_buffer = torch.empty_like(bits_buffer)
final_buffer_ends = torch.zeros(B, dtype=torch.int32, device=device)
final_num_padded_bits = [None] * B
        stopped_at_step = [-1] * B  # -1 until the sequence has been finalized
# Track which sequences are still active
active_sequences = torch.ones(B, dtype=torch.bool, device=device)
# Keep track of previous step's finalized buffer state for threshold logic
prev_finalized_buffer = torch.empty_like(bits_buffer)
prev_finalized_ends = torch.zeros_like(buf_offsets)
for t in range(T):
valid = (t < lengths) & active_sequences
if not valid.any():
break # All sequences completed or stopped
cdf_t = cdf[valid, t] # [valid, V+1]
low_valid = low[valid]
high_valid = high[valid]
width_valid = high_valid - low_valid + 1
intervals = low_valid.unsqueeze(1) + (cdf_t * width_valid.unsqueeze(1)).type(torch.int64)
symbols_t = symbols[valid, t:t+1]
old_low = low.clone()
low[valid] = intervals.gather(1, symbols_t).squeeze(1)
high[valid] = intervals.gather(1, (symbols_t + 1)).squeeze(1) - 1
# Flush digits and update buffers
(low, high, bits_buffer, buf_offsets, num_carry_digits, _) = self.flush_matching_digits(
low, high, old_low,
encoding=True,
bits_buffer=bits_buffer,
buf_offsets=buf_offsets,
num_carry_digits=num_carry_digits,
current_code_in_int=None,
_next_digit=None,
valid=valid
)
(low, high, num_carry_digits, _) = self.flush_carry_digits(
low, high,
encoding=True,
num_carry_digits=num_carry_digits,
current_code_in_int=None,
_next_digit=None,
valid=valid
)
# Check if we need to compute results this step (if bit threshold checking or final step)
need_check_threshold = bit_threshold is not None and active_sequences.any()
some_seq_finished = ((t + 1 >= lengths) & active_sequences).any()
if need_check_threshold or some_seq_finished:
# Simulate finalization at this step using pre-allocated buffers
temp_bits_buffer.copy_(bits_buffer, True)
temp_buf_offsets.copy_(buf_offsets, True)
temp_num_carry_digits.copy_(num_carry_digits, True)
# Add final digit for all sequences (simulating termination)
temp_bits_buffer[temp_buf_offsets] = (low // self._base_to_pm1).to(torch.int32)
temp_buf_offsets += 1
# Handle remaining carry digits for all sequences
carry_sel = (temp_num_carry_digits > 0).nonzero(as_tuple=False).flatten()
if carry_sel.numel():
carry_digit = self._base - 1
rep_cnt = temp_num_carry_digits[carry_sel]
                    repeats_max = rep_cnt.max().item()
grid = torch.arange(repeats_max, device=rep_cnt.device).expand(carry_sel.size(0), repeats_max)
mask_rep = grid < rep_cnt.unsqueeze(1)
start_pos = temp_buf_offsets[carry_sel]
target_pos = (start_pos.unsqueeze(1) + grid)[mask_rep]
temp_bits_buffer[target_pos] = carry_digit
temp_buf_offsets.index_add_(0, carry_sel, rep_cnt)
temp_num_carry_digits[carry_sel] = 0
# Check bit threshold and identify newly stopped sequences
if need_check_threshold:
current_bit_counts = self._get_bit_counts(temp_buf_offsets, base_offsets)
exceeds_threshold = (current_bit_counts > bit_threshold) & active_sequences
if exceeds_threshold.any():
stopped_indices = exceeds_threshold.nonzero(as_tuple=False).flatten()
for idx in stopped_indices.cpu().tolist(): # Only move indices to CPU
active_sequences[idx] = False
stopped_at_step[idx] = t
# Save the result from PREVIOUS step (before exceeding threshold)
final_buffer_ends[idx] = prev_finalized_ends[idx]
offset_start = idx * max_digits
offset_end = prev_finalized_ends[idx]
final_buffer[offset_start:offset_end].copy_(prev_finalized_buffer[offset_start:offset_end])
# If final step, all remaining active sequences need results
is_final_step = (t + 1 >= lengths) & active_sequences
if is_final_step.any():
final_step_indices = is_final_step.nonzero(as_tuple=False).flatten()
for idx in final_step_indices.cpu().tolist():
active_sequences[idx] = False
stopped_at_step[idx] = t + 1
# Save current step result for sequences that completed normally
final_buffer_ends[idx] = temp_buf_offsets[idx]
# Copy the finalized bits to main buffer for this sequence
offset_start = idx * max_digits
offset_end = temp_buf_offsets[idx]
final_buffer[offset_start:offset_end].copy_(temp_bits_buffer[offset_start:offset_end])
# Update previous finalized buffer state for next iteration
if need_check_threshold:
prev_finalized_buffer.copy_(temp_bits_buffer)
prev_finalized_ends.copy_(temp_buf_offsets)
# Convert buffer states to compressed bytes at the very end
final_compressed_bytes = []
for idx in range(B):
offset_start = idx * max_digits
offset_end = final_buffer_ends[idx]
bits_list = final_buffer[offset_start:offset_end].cpu().tolist()
bitstr = "".join(map(str, bits_list))
comp_bytes, num_padded = utils.bits_to_bytes(bitstr)
final_compressed_bytes.append(comp_bytes)
if return_num_padded_bits:
final_num_padded_bits[idx] = num_padded
if return_num_padded_bits:
return final_compressed_bytes, stopped_at_step, final_num_padded_bits
else:
return final_compressed_bytes, stopped_at_step
def test_incremental_encoding():
"""Test the incremental encoding functionality"""
print("Testing incremental encoding...")
batch_size = 4
seq_len = 32
vocab_size = 64
base = 2
precision = 32
device = "cuda" if torch.cuda.is_available() else "cpu"
# Create test data with different scenarios
symbols = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
pdf = torch.rand(batch_size, seq_len, vocab_size, device=device).clamp(min=1e-6)
pdf = pdf.softmax(dim=-1)
pdf = utils.batched_normalize_pdf_for_arithmetic_coding(pdf)
# Test with variable lengths to ensure edge cases are covered
lengths = torch.tensor([seq_len, seq_len-2, seq_len-4, seq_len], device=device)
AC = BatchedArithmeticEncoder(base, precision)
# Test incremental encoding with bit threshold
bit_threshold = 20 # Stop when sequences exceed 20 bits
final_compressed_bytes, stopped_at_step, final_num_padded_bits = AC.incremental_batched_encode(
pdf, symbols, lengths=lengths, bit_threshold=bit_threshold, return_num_padded_bits=True
)
print(f"Stopped at steps: {stopped_at_step}")
print(f"Final compressed sizes (bytes): {[len(cb) if cb else 0 for cb in final_compressed_bytes]}")
bit_counts = [len(cb) * 8 - pb if cb else 0 for cb, pb in zip(final_compressed_bytes, final_num_padded_bits)]
print(f"Final bit counts: {bit_counts}")
# Test without threshold
final_compressed_bytes_full, stopped_at_step_full = AC.incremental_batched_encode(
pdf, symbols, lengths=lengths, bit_threshold=None, return_num_padded_bits=False
)
print(f"Full encoding stopped at: {stopped_at_step_full}") # Should all be -1
print(f"Full encoding sizes (bytes): {[len(cb) if cb else 0 for cb in final_compressed_bytes_full]}")
# Consistency test: Verify that incremental encoding matches regular encoding on prefixes
print("\n--- Consistency Test ---")
# Extract prefixes based on where sequences stopped
prefix_symbols = []
prefix_pdf = []
prefix_lengths = []
for i in range(batch_size):
if stopped_at_step[i] == -1:
# Sequence completed normally, use full sequence
stop_point = seq_len
else:
# Sequence stopped early, use up to stop point
stop_point = stopped_at_step[i]
prefix_symbols.append(symbols[i, :stop_point])
prefix_pdf.append(pdf[i, :stop_point])
prefix_lengths.append(stop_point)
# Create tensors for prefix encoding
max_prefix_len = max(prefix_lengths)
batch_prefix_symbols = torch.zeros((batch_size, max_prefix_len), dtype=symbols.dtype, device=device)
batch_prefix_pdf = torch.zeros((batch_size, max_prefix_len, vocab_size), dtype=pdf.dtype, device=device)
for i in range(batch_size):
length = prefix_lengths[i]
batch_prefix_symbols[i, :length] = prefix_symbols[i]
batch_prefix_pdf[i, :length] = prefix_pdf[i]
prefix_lengths_tensor = torch.tensor(prefix_lengths, dtype=torch.int64, device=device)
# Encode prefixes using regular batched encoder
prefix_compressed_bytes, prefix_num_padded_bits = AC.batched_encode(
batch_prefix_pdf, batch_prefix_symbols,
lengths=prefix_lengths_tensor,
return_num_padded_bits=True
)
# Compare results
print("Comparing incremental vs regular encoding on prefixes:")
all_match = True
for i in range(batch_size):
incremental_bytes = final_compressed_bytes[i]
regular_bytes = prefix_compressed_bytes[i]
if incremental_bytes == regular_bytes:
print(f"Sequence {i}: ✓ Match (stopped at step {stopped_at_step[i]})")
else:
print(f"Sequence {i}: ✗ Mismatch (stopped at step {stopped_at_step[i]})")
print(f" Incremental: {len(incremental_bytes) if incremental_bytes else 0} bytes")
print(f" Regular: {len(regular_bytes)} bytes")
all_match = False
    if all_match:
        print("✓ All incremental encodings match regular encodings on prefixes!")
    else:
        print("✗ Some incremental encodings don't match - there may be a bug")
    assert all_match, "Some incremental encodings don't match - there may be a bug"
print("✓ Incremental encoding test completed")
def test_consistency_various_thresholds():
"""Test consistency between incremental and regular encoding with various thresholds"""
print("\n=== Testing Consistency with Various Thresholds ===")
batch_size = 3
seq_len = 128
vocab_size = 4
base = 2
precision = 16
device = "cuda" if torch.cuda.is_available() else "cpu"
# Create fixed test data for reproducible results
torch.manual_seed(42)
symbols = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
pdf = torch.rand(batch_size, seq_len, vocab_size, device=device).clamp(min=1e-6)
pdf = pdf.softmax(dim=-1)
lengths = torch.full((batch_size,), seq_len, device=device)
AC = BatchedArithmeticEncoder(base, precision)
# Test with multiple threshold values
thresholds = [10, 25, 50, 100, None] # None means no threshold
for threshold in thresholds:
print(f"\n--- Testing with threshold: {threshold} ---")
        # Run incremental encoding (bit_threshold=None disables early stopping)
        final_bytes, stop_steps = AC.incremental_batched_encode(
            pdf, symbols, lengths=lengths, bit_threshold=threshold
        )
print(f"Stop steps: {stop_steps}")
# Create prefix data based on stop points
prefix_lengths = []
for i in range(batch_size):
if stop_steps[i] == -1:
prefix_lengths.append(lengths[i].item())
else:
prefix_lengths.append(stop_steps[i])
max_prefix_len = max(prefix_lengths)
batch_prefix_symbols = torch.zeros((batch_size, max_prefix_len), dtype=symbols.dtype, device=device)
batch_prefix_pdf = torch.zeros((batch_size, max_prefix_len, vocab_size), dtype=pdf.dtype, device=device)
for i in range(batch_size):
length = prefix_lengths[i]
batch_prefix_symbols[i, :length] = symbols[i, :length]
batch_prefix_pdf[i, :length] = pdf[i, :length]
prefix_lengths_tensor = torch.tensor(prefix_lengths, dtype=torch.int64, device=device)
# Run regular encoding on prefixes
regular_bytes = AC.batched_encode(
batch_prefix_pdf, batch_prefix_symbols,
lengths=prefix_lengths_tensor
)
# Compare results
all_consistent = True
for i in range(batch_size):
if final_bytes[i] == regular_bytes[i]:
print(f" Seq {i}: ✓ (len={prefix_lengths[i]})")
else:
print(f" Seq {i}: ✗ INCONSISTENT (len={prefix_lengths[i]})")
print(f" Incremental: {len(final_bytes[i]) if final_bytes[i] else 0} bytes")
print(f" Regular: {len(regular_bytes[i])} bytes")
all_consistent = False
        if all_consistent:
            print(f" ✓ All sequences consistent for threshold {threshold}")
        else:
            print(f" ✗ Inconsistencies found for threshold {threshold}")
        assert all_consistent, "Some incremental encodings don't match - there may be a bug"
print("\n✓ Consistency testing completed")
if __name__ == "__main__":
# Test incremental encoding
test_incremental_encoding()
# Test consistency between incremental and regular encoding
test_consistency_various_thresholds()
batch_size = 32
seq_len = 128
vocab_size = 256
base = 2
precision = 32
device = "cuda" if torch.cuda.is_available() else "cpu"
    lengths = torch.randint(1, seq_len, (batch_size,), device=device)
symbols = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
    # build a one-hot pdf (note: this setup is overridden by the random pdf a few lines below)
pdf = torch.zeros(batch_size, seq_len, vocab_size, device=device)
for i in range(batch_size):
        # only fill positions within the valid length so the pdf can be normalized
s = torch.randint(0, vocab_size, (lengths[i],), device=device)
symbols[i, :lengths[i]] = s
for t in range(lengths[i]):
pdf[i, t, s[t]] = 1.0
pdf = pdf / pdf.sum(-1, keepdim=True)
pdf = torch.rand(batch_size, seq_len, vocab_size, device=device).clamp(min=1e-6)
pdf = pdf.softmax(dim=-1)
symbols = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
lengths = torch.randint(1, seq_len, (batch_size,), device=device)
# lengths = torch.full((batch_size,), seq_len, device=device)
AC = BatchedArithmeticEncoder(base, precision)
# Test original functionality
print("Testing original batched encoding...")
start_event = torch.cuda.Event(enable_timing=True) if device == "cuda" else None
end_event = torch.cuda.Event(enable_timing=True) if device == "cuda" else None
if device == "cuda":
start_event.record()
    codes, padded_bits = AC.batched_encode(pdf, symbols, lengths=lengths, return_num_padded_bits=True)
if device == "cuda":
end_event.record()
torch.cuda.synchronize()
print(f"CUDA wall clock time: {start_event.elapsed_time(end_event):.2f} ms")
print([len(c) for c in codes])
    decoded = AC.batched_decode(pdf, codes, padded_bits, lengths)
    print("[DEBUG]: decoded[0]={} symbols[0]={}".format(decoded[0], symbols[0]))
    print("✓ passed - compressed bytes per seq:", [len(s) for s in codes])
# Validate the length
for i in range(batch_size):
print(lengths[i].item())
print(decoded[i, :lengths[i].item()])
print(symbols[i, :lengths[i].item()])
l = lengths[i].item()
assert torch.all(decoded[i, :l] == symbols[i, :l]), f"Sample {i} mismatch"
print("All variable-length sequences verified successfully")