import torch from typing import List, Tuple, Optional, Callable from m1_compression import utils import math def _pdf_to_cdf(pdf: torch.Tensor) -> torch.Tensor: # NOTE: we do cumsum in float64, as we found # cumsum in float32 leads to numerical errors # when batch size is different across runs cdf = torch.cumsum(pdf.to(torch.float64), dim=-1).to(torch.float32) cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], dim=-1) # prepend 0 return cdf # shape [..., vocab_size+1] def _shift_left(x: int, base: int, base_to_pm1: int) -> int: """Shift `x` one digit left.""" return (x % base_to_pm1) * base def _shift_left_keeping_msd(x: int, base: int, base_to_pm1: int, base_to_pm2: int) -> int: """Shift `x` except MSD, which remains in place, one digit left.""" return x - (x % base_to_pm1) + (x % base_to_pm2) * base class BatchedArithmeticEncoder: def __init__(self, base: int, precision: int): self._base: int = base self._base_to_pm1: int = int(base ** (precision - 1)) self._base_to_pm2: int = int(base ** (precision - 2)) self._precision: int = precision def _get_bit_counts(self, buf_offsets: torch.Tensor, base_offsets: torch.Tensor) -> torch.Tensor: """Get bit counts for all sequences.""" bit_counts = buf_offsets - base_offsets return bit_counts # Helper lambdas -------------------------------------------------------- def flush_matching_digits( self, low, high, old_low, encoding: bool = True, bits_buffer: Optional[torch.Tensor] = None, buf_offsets: Optional[torch.Tensor] = None, num_carry_digits: Optional[torch.Tensor] = None, current_code_in_int: Optional[torch.Tensor] = None, _next_digit: Optional[Callable[[int], int]] = None, valid: Optional[torch.Tensor] = None, # add mask ): valid = valid if valid is not None else True while True: msd_low = low // self._base_to_pm1 msd_high = high // self._base_to_pm1 mask = msd_low == msd_high ## get masked mask mask = mask & valid if not torch.any(mask): break if encoding: msd_low_old = old_low // self._base_to_pm1 # digit = msd_low[mask] # digit_old = msd_low_old[mask] # for idx, (d, d_old) in zip(mask.nonzero(as_tuple=False).flatten().tolist(), zip(digit.tolist(), digit_old.tolist())): # # 1) real digit # code[idx].append(int(d)) # # 2) any pending carries now resolved # num_carry_digit = num_carry_digits[idx].item() # if num_carry_digit: # carry_digit = ( # self._base - 1 + d - d_old # ) % self._base # code[idx].extend([carry_digit] * num_carry_digit) # num_carry_digits[idx] = 0 # ------------------------------------------------------------------ sel = mask.nonzero(as_tuple=False).flatten() bits_buffer[buf_offsets[sel]] = msd_low[sel].to(torch.int32) buf_offsets.index_add_( 0, sel, torch.ones_like(sel, dtype=torch.int32) ) carry_sel = sel[(num_carry_digits[sel] > 0)] if carry_sel.numel(): _digit_carry = msd_low[carry_sel] _digit_carry_old = msd_low_old[carry_sel] carry_digit = ( self._base - 1 + _digit_carry - _digit_carry_old ) % self._base rep_cnt = num_carry_digits[carry_sel] repeats_max = rep_cnt.max().item() grid = torch.arange( repeats_max, device=rep_cnt.device ).expand(carry_sel.size(0), repeats_max) # [K2, M] mask_rep = grid < rep_cnt.unsqueeze(1) # [K2, M] start_pos = buf_offsets[carry_sel] target_pos = (start_pos.unsqueeze(1) + grid)[mask_rep] payload = carry_digit.to(torch.int32).unsqueeze(1).expand_as(grid)[mask_rep] bits_buffer[target_pos] = payload buf_offsets.index_add_(0, carry_sel, rep_cnt) num_carry_digits[carry_sel] = 0 else: new_digit = torch.tensor([ _next_digit(i) if m else 0 for i, m in enumerate(mask.tolist()) ], dtype=torch.int64, device=mask.device) current_code_in_int = torch.where( mask, _shift_left(current_code_in_int, self._base, self._base_to_pm1) + new_digit, current_code_in_int ) # Shift left to remove matching digits low = torch.where( mask, _shift_left(low, self._base, self._base_to_pm1), low ) high = torch.where( mask, _shift_left(high, self._base, self._base_to_pm1) + self._base - 1, high, ) return low, high, bits_buffer, buf_offsets, num_carry_digits, current_code_in_int def flush_carry_digits( self, low, high, encoding: bool = True, num_carry_digits: Optional[torch.Tensor] = None, current_code_in_int: Optional[torch.Tensor] = None, _next_digit: Optional[Callable[[int], int]] = None, valid: Optional[torch.Tensor] = None, ): valid = valid if valid is not None else True while True: second_msd_low = (low // self._base_to_pm2) % self._base second_msd_high = ((high - 1) // self._base_to_pm2) % self._base msd_low = low // self._base_to_pm1 msd_high = (high - 1) // self._base_to_pm1 mask = (msd_low + 1 == msd_high) & (second_msd_low == self._base - 1) & (second_msd_high == 0) mask = mask & valid # valid mask if not torch.any(mask): break # For sequences in *mask*, we shift left *keeping MSD fixed* low = torch.where( mask, _shift_left_keeping_msd(low, self._base, self._base_to_pm1, self._base_to_pm2), low ) high = torch.where( mask, _shift_left_keeping_msd(high, self._base, self._base_to_pm1, self._base_to_pm2) + self._base - 1, high ) if encoding: num_carry_digits = torch.where(mask, num_carry_digits + 1, num_carry_digits) else: new_digit = torch.tensor([ _next_digit(i) if m else 0 for i, m in enumerate(mask.tolist()) ], dtype=torch.int64, device=mask.device) current_code_in_int = torch.where( mask, _shift_left_keeping_msd(current_code_in_int, self._base, self._base_to_pm1, self._base_to_pm2) + new_digit, current_code_in_int ) return low, high, num_carry_digits, current_code_in_int def _process( self, pdf: torch.Tensor, symbols: Optional[torch.Tensor] = None, encoding: bool = True, return_num_padded_bits: bool = False, encoded_bits: Optional[List[List[int]]] = None, lengths: Optional[torch.Tensor] = None, ) -> List[bytes] | Tuple[List[bytes], List[int]]: assert pdf is not None, "symbols or pdf must be provided" assert pdf.ndim == 3, "input must be [B, T, V]" B, T, V = pdf.shape device = pdf.device if lengths is None: lengths = torch.full((B,),T,dtype = torch.int64,device=device) # --- interval state ---------------------------------------------------- lengths = torch.clamp(lengths, min=0, max=T) low = torch.zeros((B,), dtype=torch.int64, device=device) # NOTE: We represent the AC interval [0, 1) as rational numbers: # [0, 1) # ~ [self._low / base ** precision, (self._high + 1) / base ** precision) # = [self._low / base ** precision, self._high / base ** precision], # where the we represent the upper bound *INCLUSIVE*. This is a subtle # detail required to make the integer arithmetic work correctly given that # all involved integers have `precision` digits in base `base`. high = torch.full((B,), int(self._base**self._precision) - 1, dtype=torch.int64, device=device) cdf = _pdf_to_cdf(pdf) if encoding: assert symbols is not None, "symbols must be provided for encoding" assert encoded_bits is None, "encoded_bits must be None for encoding" num_carry_digits = torch.zeros((B,), dtype=torch.int32, device=device) digits_sym = math.ceil(math.log(V, self._base)) max_digits = self._precision + 2 * T * digits_sym bits_buffer = torch.empty( B * max_digits, dtype=torch.int32, device=device ) buf_offsets = torch.arange( B, device=device, dtype=torch.int32 ) * max_digits _next_digit = None current_code_in_int = None decoded_symbols = None else: assert symbols is None, "symbols must be None for decoding" assert encoded_bits is not None, "encoded_bits must be provided for decoding" num_carry_digits = None bits_buffer = None buf_offsets = None # Stream read cursors ------------------------------------------------- cursor = [0] * B def _next_digit(idx: int) -> int: if cursor[idx] < len(encoded_bits[idx]): d = encoded_bits[idx][cursor[idx]] cursor[idx] += 1 return d # Add padding to ensure the AC state is well-defined when decoding the last # symbol. Note that what exactly we do here depends on how encoder # termination is implemented (see `Encoder.terminate`). return self._base - 1 current_code_in_int = torch.zeros((B,), dtype=torch.int64, device=device) for _ in range(self._precision): digits = torch.tensor([_next_digit(i) for i in range(B)], dtype=torch.int64, device=device) current_code_in_int = current_code_in_int * self._base + digits decoded_symbols = torch.zeros((B, T), dtype=torch.int64, device=device) # ---------------- main encoding --------------------------------------- for t in range(T): valid = t < lengths if not valid.any(): break #all the sample is completed cdf_t = cdf[valid, t] # [valid, V+1] low_valid = low[valid] high_valid = high[valid] width_valid = high_valid - low_valid + 1 intervals = low_valid.unsqueeze(1) + (cdf_t * width_valid.unsqueeze(1)).type(torch.int64) if encoding: symbols_t = symbols[valid, t : t+1] else: symbols_t = torch.searchsorted(intervals, current_code_in_int[valid].unsqueeze(1), right=True) - 1 # V 是词汇表大小,即 pdf.shape[-1]。 symbols_t = symbols_t.clamp(max=V-1) # ============================================================== decoded_symbols[valid, t] = symbols_t.squeeze(1) old_low = low.clone() ## there is some wrong,if there are no valid sequences to process in a batch low[valid] = intervals.gather(1, symbols_t).squeeze(1) high[valid] = intervals.gather(1, (symbols_t + 1)).squeeze(1) - 1 (low, high, bits_buffer, buf_offsets, num_carry_digits, current_code_in_int) = self.flush_matching_digits( low, high, old_low, encoding=encoding, bits_buffer=bits_buffer, buf_offsets=buf_offsets, num_carry_digits=num_carry_digits, current_code_in_int=current_code_in_int if not encoding else None, _next_digit=_next_digit if not encoding else None, valid=valid ) (low, high, num_carry_digits, current_code_in_int) = self.flush_carry_digits( low, high, encoding=encoding, num_carry_digits=num_carry_digits, current_code_in_int=current_code_in_int if not encoding else None, _next_digit=_next_digit if not encoding else None, valid=valid ) # for t in range(T): # cdf_t = cdf[:, t] # [B, V+1] # width = high - low + 1 # intervals = low.unsqueeze(1) + (cdf_t * width.unsqueeze(1)).type(torch.int64) # old_low = low # if encoding: # symbols_t = symbols[:, t:t+1] # [B, 1] # else: # symbols_t = torch.searchsorted(intervals, current_code_in_int.unsqueeze(1), right=True) - 1 # decoded_symbols[:, t] = symbols_t.squeeze(1) # low = intervals.gather(1, symbols_t).squeeze(1) # high = intervals.gather(1, (symbols_t + 1)).squeeze(1) - 1 # # Renormalise until interval large enough # ( # low, # high, # bits_buffer, # encoding variable # buf_offsets, # encoding variable # num_carry_digits, # encoding variable # current_code_in_int, # decoding variable # ) = self.flush_matching_digits( # low, # high, # old_low, # encoding, # bits_buffer, # encoding variable # buf_offsets, # encoding variable # num_carry_digits, # encoding variable # current_code_in_int, # decoding variable # _next_digit, # decoding variable # ) # ( # low, # high, # num_carry_digits, # encoding variable # current_code_in_int, # decoding variable # ) = self.flush_carry_digits( # low, # high, # encoding, # num_carry_digits, # encoding variable # current_code_in_int, # decoding variable # _next_digit, # decoding variable # ) if encoding: output_compressed_bytes = [] output_num_padded_bits = [] # -------------------- finalization ------------------------------------ bits_buffer[buf_offsets] = (low // self._base_to_pm1).to(torch.int32) buf_offsets = buf_offsets + 1 carry_sel = num_carry_digits.nonzero(as_tuple=False).flatten() if carry_sel.numel(): carry_digit = self._base - 1 rep_cnt = num_carry_digits[carry_sel] repeats_max = rep_cnt.max() grid = torch.arange( repeats_max, device=rep_cnt.device ).expand(carry_sel.size(0), repeats_max) # [K2, M] mask_rep = grid < rep_cnt.unsqueeze(1) # [K2, M] start_pos = buf_offsets[carry_sel] target_pos = (start_pos.unsqueeze(1) + grid)[mask_rep] bits_buffer[target_pos] = carry_digit buf_offsets.index_add_(0, carry_sel, rep_cnt) num_carry_digits[carry_sel] = 0 for idx in range(B): offset_start = idx * max_digits offset_end = buf_offsets[idx] bits_list = bits_buffer[offset_start:offset_end].cpu().tolist() bitstr = "".join(map(str, bits_list)) compressed_bytes, num_padded_bits = utils.bits_to_bytes(bitstr) output_compressed_bytes.append(compressed_bytes) output_num_padded_bits.append(num_padded_bits) if return_num_padded_bits: return output_compressed_bytes, output_num_padded_bits else: return output_compressed_bytes else: return decoded_symbols def batched_encode( self, pdf: torch.Tensor, symbols: torch.Tensor, lengths: Optional[torch.Tensor] = None, return_num_padded_bits: bool = False ) -> List[bytes] | Tuple[List[bytes], List[int]]: B, T, V = pdf.shape device = pdf.device if lengths is None: lengths = torch.full((B,),T,dtype = torch.int64,device=device) return self._process( pdf, symbols=symbols, encoding=True, return_num_padded_bits=return_num_padded_bits, encoded_bits=None, lengths = lengths, ## pass length to address ) def batched_decode( self, pdf: torch.Tensor, compressed_bytes: List[bytes], num_padded_bits: List[int], lengths: Optional[torch.Tensor] = None, ) -> torch.Tensor: B, T, V = pdf.shape device = pdf.device if lengths is None: lengths = torch.full((B,),T,dtype = torch.int64,device=device) assert len(compressed_bytes) == B, "encoded_bits length must be equal to batch size" assert len(num_padded_bits) == B, "num_padded_bits length must be equal to batch size" encoded_bits = [[] for _ in range(B)] for idx, (compressed_b, num_padded) in enumerate(zip(compressed_bytes, num_padded_bits)): bits = utils.bytes_to_bits(compressed_b, num_padded_bits=num_padded) encoded_bits[idx] = list(map(int, bits)) # print("[DEBUG]: encoded_bits[{}]: {}".format(idx, encoded_bits[idx])) return self._process( pdf, symbols=None, encoding=False, return_num_padded_bits=False, encoded_bits=encoded_bits, lengths=lengths, ) def incremental_batched_encode( self, pdf: torch.Tensor, symbols: torch.Tensor, lengths: Optional[torch.Tensor] = None, bit_threshold: Optional[int] = None, return_num_padded_bits: bool = False ) -> Tuple[List[bytes], List[int]] | Tuple[List[bytes], List[int], List[int]]: """ Incrementally encode symbols with early stopping when bit threshold is exceeded. Args: pdf: [B, T, V] probability distributions symbols: [B, T] symbols to encode lengths: [B] length of each sequence (optional) bit_threshold: Stop encoding when any sequence exceeds this many bits return_num_padded_bits: Whether to return padding information Returns: final_compressed_bytes: List[bytes] - final compressed result for each sequence stopped_at_step: List[int] - step where each sequence stopped (-1 if completed normally) final_num_padded_bits: List[int] - padding info (only if return_num_padded_bits=True) """ assert pdf.ndim == 3, "input must be [B, T, V]" B, T, V = pdf.shape device = pdf.device if lengths is None: lengths = torch.full((B,), T, dtype=torch.int64, device=device) lengths = torch.clamp(lengths, min=0, max=T) # Initialize arithmetic coding state low = torch.zeros((B,), dtype=torch.int64, device=device) high = torch.full((B,), int(self._base**self._precision) - 1, dtype=torch.int64, device=device) num_carry_digits = torch.zeros((B,), dtype=torch.int32, device=device) # Initialize bit buffer digits_sym = math.ceil(math.log(V, self._base)) max_digits = self._precision + 2 * T * digits_sym bits_buffer = torch.empty(B * max_digits, dtype=torch.int32, device=device) buf_offsets = torch.arange(B, device=device, dtype=torch.int32) * max_digits base_offsets = torch.arange(B, device=device, dtype=torch.int32) * max_digits # Pre-allocate temporary buffers (avoid cloning at each step) temp_bits_buffer = torch.empty_like(bits_buffer) temp_buf_offsets = torch.empty_like(buf_offsets) temp_num_carry_digits = torch.empty_like(num_carry_digits) cdf = _pdf_to_cdf(pdf) # Track final results for each sequence - save buffer states, not bytes final_buffer = torch.empty_like(bits_buffer) final_buffer_ends = torch.zeros(B, dtype=torch.int32, device=device) final_num_padded_bits = [None] * B stopped_at_step = [-1] * B # -1 means completed normally # Track which sequences are still active active_sequences = torch.ones(B, dtype=torch.bool, device=device) # Keep track of previous step's finalized buffer state for threshold logic prev_finalized_buffer = torch.empty_like(bits_buffer) prev_finalized_ends = torch.zeros_like(buf_offsets) for t in range(T): valid = (t < lengths) & active_sequences if not valid.any(): break # All sequences completed or stopped cdf_t = cdf[valid, t] # [valid, V+1] low_valid = low[valid] high_valid = high[valid] width_valid = high_valid - low_valid + 1 intervals = low_valid.unsqueeze(1) + (cdf_t * width_valid.unsqueeze(1)).type(torch.int64) symbols_t = symbols[valid, t:t+1] old_low = low.clone() low[valid] = intervals.gather(1, symbols_t).squeeze(1) high[valid] = intervals.gather(1, (symbols_t + 1)).squeeze(1) - 1 # Flush digits and update buffers (low, high, bits_buffer, buf_offsets, num_carry_digits, _) = self.flush_matching_digits( low, high, old_low, encoding=True, bits_buffer=bits_buffer, buf_offsets=buf_offsets, num_carry_digits=num_carry_digits, current_code_in_int=None, _next_digit=None, valid=valid ) (low, high, num_carry_digits, _) = self.flush_carry_digits( low, high, encoding=True, num_carry_digits=num_carry_digits, current_code_in_int=None, _next_digit=None, valid=valid ) # Check if we need to compute results this step (if bit threshold checking or final step) need_check_threshold = bit_threshold is not None and active_sequences.any() some_seq_finished = ((t + 1 >= lengths) & active_sequences).any() if need_check_threshold or some_seq_finished: # Simulate finalization at this step using pre-allocated buffers temp_bits_buffer.copy_(bits_buffer, True) temp_buf_offsets.copy_(buf_offsets, True) temp_num_carry_digits.copy_(num_carry_digits, True) # Add final digit for all sequences (simulating termination) temp_bits_buffer[temp_buf_offsets] = (low // self._base_to_pm1).to(torch.int32) temp_buf_offsets += 1 # Handle remaining carry digits for all sequences carry_sel = (temp_num_carry_digits > 0).nonzero(as_tuple=False).flatten() if carry_sel.numel(): carry_digit = self._base - 1 rep_cnt = temp_num_carry_digits[carry_sel] repeats_max = rep_cnt.max() grid = torch.arange(repeats_max, device=rep_cnt.device).expand(carry_sel.size(0), repeats_max) mask_rep = grid < rep_cnt.unsqueeze(1) start_pos = temp_buf_offsets[carry_sel] target_pos = (start_pos.unsqueeze(1) + grid)[mask_rep] temp_bits_buffer[target_pos] = carry_digit temp_buf_offsets.index_add_(0, carry_sel, rep_cnt) temp_num_carry_digits[carry_sel] = 0 # Check bit threshold and identify newly stopped sequences if need_check_threshold: current_bit_counts = self._get_bit_counts(temp_buf_offsets, base_offsets) exceeds_threshold = (current_bit_counts > bit_threshold) & active_sequences if exceeds_threshold.any(): stopped_indices = exceeds_threshold.nonzero(as_tuple=False).flatten() for idx in stopped_indices.cpu().tolist(): # Only move indices to CPU active_sequences[idx] = False stopped_at_step[idx] = t # Save the result from PREVIOUS step (before exceeding threshold) final_buffer_ends[idx] = prev_finalized_ends[idx] offset_start = idx * max_digits offset_end = prev_finalized_ends[idx] final_buffer[offset_start:offset_end].copy_(prev_finalized_buffer[offset_start:offset_end]) # If final step, all remaining active sequences need results is_final_step = (t + 1 >= lengths) & active_sequences if is_final_step.any(): final_step_indices = is_final_step.nonzero(as_tuple=False).flatten() for idx in final_step_indices.cpu().tolist(): active_sequences[idx] = False stopped_at_step[idx] = t + 1 # Save current step result for sequences that completed normally final_buffer_ends[idx] = temp_buf_offsets[idx] # Copy the finalized bits to main buffer for this sequence offset_start = idx * max_digits offset_end = temp_buf_offsets[idx] final_buffer[offset_start:offset_end].copy_(temp_bits_buffer[offset_start:offset_end]) # Update previous finalized buffer state for next iteration if need_check_threshold: prev_finalized_buffer.copy_(temp_bits_buffer) prev_finalized_ends.copy_(temp_buf_offsets) # Convert buffer states to compressed bytes at the very end final_compressed_bytes = [] for idx in range(B): offset_start = idx * max_digits offset_end = final_buffer_ends[idx] bits_list = final_buffer[offset_start:offset_end].cpu().tolist() bitstr = "".join(map(str, bits_list)) comp_bytes, num_padded = utils.bits_to_bytes(bitstr) final_compressed_bytes.append(comp_bytes) if return_num_padded_bits: final_num_padded_bits[idx] = num_padded if return_num_padded_bits: return final_compressed_bytes, stopped_at_step, final_num_padded_bits else: return final_compressed_bytes, stopped_at_step def test_incremental_encoding(): """Test the incremental encoding functionality""" print("Testing incremental encoding...") batch_size = 4 seq_len = 32 vocab_size = 64 base = 2 precision = 32 device = "cuda" if torch.cuda.is_available() else "cpu" # Create test data with different scenarios symbols = torch.randint(0, vocab_size, (batch_size, seq_len), device=device) pdf = torch.rand(batch_size, seq_len, vocab_size, device=device).clamp(min=1e-6) pdf = pdf.softmax(dim=-1) pdf = utils.batched_normalize_pdf_for_arithmetic_coding(pdf) # Test with variable lengths to ensure edge cases are covered lengths = torch.tensor([seq_len, seq_len-2, seq_len-4, seq_len], device=device) AC = BatchedArithmeticEncoder(base, precision) # Test incremental encoding with bit threshold bit_threshold = 20 # Stop when sequences exceed 20 bits final_compressed_bytes, stopped_at_step, final_num_padded_bits = AC.incremental_batched_encode( pdf, symbols, lengths=lengths, bit_threshold=bit_threshold, return_num_padded_bits=True ) print(f"Stopped at steps: {stopped_at_step}") print(f"Final compressed sizes (bytes): {[len(cb) if cb else 0 for cb in final_compressed_bytes]}") bit_counts = [len(cb) * 8 - pb if cb else 0 for cb, pb in zip(final_compressed_bytes, final_num_padded_bits)] print(f"Final bit counts: {bit_counts}") # Test without threshold final_compressed_bytes_full, stopped_at_step_full = AC.incremental_batched_encode( pdf, symbols, lengths=lengths, bit_threshold=None, return_num_padded_bits=False ) print(f"Full encoding stopped at: {stopped_at_step_full}") # Should all be -1 print(f"Full encoding sizes (bytes): {[len(cb) if cb else 0 for cb in final_compressed_bytes_full]}") # Consistency test: Verify that incremental encoding matches regular encoding on prefixes print("\n--- Consistency Test ---") # Extract prefixes based on where sequences stopped prefix_symbols = [] prefix_pdf = [] prefix_lengths = [] for i in range(batch_size): if stopped_at_step[i] == -1: # Sequence completed normally, use full sequence stop_point = seq_len else: # Sequence stopped early, use up to stop point stop_point = stopped_at_step[i] prefix_symbols.append(symbols[i, :stop_point]) prefix_pdf.append(pdf[i, :stop_point]) prefix_lengths.append(stop_point) # Create tensors for prefix encoding max_prefix_len = max(prefix_lengths) batch_prefix_symbols = torch.zeros((batch_size, max_prefix_len), dtype=symbols.dtype, device=device) batch_prefix_pdf = torch.zeros((batch_size, max_prefix_len, vocab_size), dtype=pdf.dtype, device=device) for i in range(batch_size): length = prefix_lengths[i] batch_prefix_symbols[i, :length] = prefix_symbols[i] batch_prefix_pdf[i, :length] = prefix_pdf[i] prefix_lengths_tensor = torch.tensor(prefix_lengths, dtype=torch.int64, device=device) # Encode prefixes using regular batched encoder prefix_compressed_bytes, prefix_num_padded_bits = AC.batched_encode( batch_prefix_pdf, batch_prefix_symbols, lengths=prefix_lengths_tensor, return_num_padded_bits=True ) # Compare results print("Comparing incremental vs regular encoding on prefixes:") all_match = True for i in range(batch_size): incremental_bytes = final_compressed_bytes[i] regular_bytes = prefix_compressed_bytes[i] if incremental_bytes == regular_bytes: print(f"Sequence {i}: ✓ Match (stopped at step {stopped_at_step[i]})") else: print(f"Sequence {i}: ✗ Mismatch (stopped at step {stopped_at_step[i]})") print(f" Incremental: {len(incremental_bytes) if incremental_bytes else 0} bytes") print(f" Regular: {len(regular_bytes)} bytes") all_match = False assert all_match, "Some incremental encodings don't match - there may be a bug" if all_match: print("✓ All incremental encodings match regular encodings on prefixes!") else: print("✗ Some incremental encodings don't match - there may be a bug") print("✓ Incremental encoding test completed") def test_consistency_various_thresholds(): """Test consistency between incremental and regular encoding with various thresholds""" print("\n=== Testing Consistency with Various Thresholds ===") batch_size = 3 seq_len = 128 vocab_size = 4 base = 2 precision = 16 device = "cuda" if torch.cuda.is_available() else "cpu" # Create fixed test data for reproducible results torch.manual_seed(42) symbols = torch.randint(0, vocab_size, (batch_size, seq_len), device=device) pdf = torch.rand(batch_size, seq_len, vocab_size, device=device).clamp(min=1e-6) pdf = pdf.softmax(dim=-1) lengths = torch.full((batch_size,), seq_len, device=device) AC = BatchedArithmeticEncoder(base, precision) # Test with multiple threshold values thresholds = [10, 25, 50, 100, None] # None means no threshold for threshold in thresholds: print(f"\n--- Testing with threshold: {threshold} ---") # Run incremental encoding if threshold is None: final_bytes, stop_steps = AC.incremental_batched_encode( pdf, symbols, lengths=lengths, bit_threshold=threshold ) else: final_bytes, stop_steps = AC.incremental_batched_encode( pdf, symbols, lengths=lengths, bit_threshold=threshold ) print(f"Stop steps: {stop_steps}") # Create prefix data based on stop points prefix_lengths = [] for i in range(batch_size): if stop_steps[i] == -1: prefix_lengths.append(lengths[i].item()) else: prefix_lengths.append(stop_steps[i]) max_prefix_len = max(prefix_lengths) batch_prefix_symbols = torch.zeros((batch_size, max_prefix_len), dtype=symbols.dtype, device=device) batch_prefix_pdf = torch.zeros((batch_size, max_prefix_len, vocab_size), dtype=pdf.dtype, device=device) for i in range(batch_size): length = prefix_lengths[i] batch_prefix_symbols[i, :length] = symbols[i, :length] batch_prefix_pdf[i, :length] = pdf[i, :length] prefix_lengths_tensor = torch.tensor(prefix_lengths, dtype=torch.int64, device=device) # Run regular encoding on prefixes regular_bytes = AC.batched_encode( batch_prefix_pdf, batch_prefix_symbols, lengths=prefix_lengths_tensor ) # Compare results all_consistent = True for i in range(batch_size): if final_bytes[i] == regular_bytes[i]: print(f" Seq {i}: ✓ (len={prefix_lengths[i]})") else: print(f" Seq {i}: ✗ INCONSISTENT (len={prefix_lengths[i]})") print(f" Incremental: {len(final_bytes[i]) if final_bytes[i] else 0} bytes") print(f" Regular: {len(regular_bytes[i])} bytes") all_consistent = False assert all_consistent, "Some incremental encodings don't match - there may be a bug" if all_consistent: print(f" ✓ All sequences consistent for threshold {threshold}") else: print(f" ✗ Inconsistencies found for threshold {threshold}") print("\n✓ Consistency testing completed") if __name__ == "__main__": # Test incremental encoding test_incremental_encoding() # Test consistency between incremental and regular encoding test_consistency_various_thresholds() batch_size = 32 seq_len = 128 vocab_size = 256 base = 2 precision = 32 device = "cuda" if torch.cuda.is_available() else "cpu" lengths = torch.randint(1,seq_len,(batch_size,),device=device) symbols = torch.randint(0, vocab_size, (batch_size, seq_len), device=device) # 构造 one-hot pdf pdf = torch.zeros(batch_size, seq_len, vocab_size, device=device) for i in range(batch_size): # get pdf in valid length so it can be normalized. s = torch.randint(0, vocab_size, (lengths[i],), device=device) symbols[i, :lengths[i]] = s for t in range(lengths[i]): pdf[i, t, s[t]] = 1.0 pdf = pdf / pdf.sum(-1, keepdim=True) pdf = torch.rand(batch_size, seq_len, vocab_size, device=device).clamp(min=1e-6) pdf = pdf.softmax(dim=-1) symbols = torch.randint(0, vocab_size, (batch_size, seq_len), device=device) lengths = torch.randint(1, seq_len, (batch_size,), device=device) # lengths = torch.full((batch_size,), seq_len, device=device) AC = BatchedArithmeticEncoder(base, precision) # Test original functionality print("Testing original batched encoding...") start_event = torch.cuda.Event(enable_timing=True) if device == "cuda" else None end_event = torch.cuda.Event(enable_timing=True) if device == "cuda" else None if device == "cuda": start_event.record() codes, padded_bits = AC.batched_encode(pdf, symbols,lengths=lengths,return_num_padded_bits=True) if device == "cuda": end_event.record() torch.cuda.synchronize() print(f"CUDA wall clock time: {start_event.elapsed_time(end_event):.2f} ms") print([len(c) for c in codes]) decoded = AC.batched_decode(pdf, codes, padded_bits,lengths) print("[DEBUG]: decoded {} symbols {}".format(decoded[0], symbols[0])) print("✓ passed - avg. digits per seq:", [len(s) for s in codes]) # Validate the length for i in range(batch_size): print(lengths[i].item()) print(decoded[i, :lengths[i].item()]) print(symbols[i, :lengths[i].item()]) l = lengths[i].item() assert torch.all(decoded[i, :l] == symbols[i, :l]), f"Sample {i} mismatch" print("All variable-length sequences verified successfully")