import torch
from m1_compression import utils
import math
import numpy as np
from typing import List, Tuple, Callable, Any, Dict, Optional
import logging
from m1_compression.batched_arithmetic_coder import (
    BatchedArithmeticEncoder,
)

# NOTE(review): configuring the root logger at import time is a side effect on
# every consumer of this module — consider moving basicConfig to the application
# entry point.
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class CPUArithmeticEncoder(BatchedArithmeticEncoder):
    """Arithmetic encoder that supports incremental, batched encoding with
    early stopping at a bit budget.

    The one-shot `batched_encode` entry point of the base class is
    intentionally unsupported; use `incremental_batched_encode` instead.
    """

    def __init__(self, base: int, precision: int):
        super().__init__(base=base, precision=precision)

    def batched_encode(
        self,
        gathered_cdfs: torch.Tensor,  # [B, T, 2]
        symbols: torch.Tensor,
        lengths: Optional[torch.Tensor] = None,
        return_num_padded_bits: bool = False
    ) -> Tuple[List[bytes], List[int]]:
        raise NotImplementedError("CPUArithmeticEncoder does not support batched_encode")

    def incremental_batched_encode(
        self,
        gathered_cdfs: torch.Tensor,  # [B, T, 2]
        vocab_size: int,
        lengths: Optional[torch.Tensor] = None,
        bit_threshold: Optional[int] = None,
        force_padding_to_threshold: bool = False,
        return_num_padded_bits: bool = False
    ) -> Tuple[List[bytes], List[int]] | Tuple[List[bytes], List[int], List[int]]:
        """
        Incrementally encode symbols with early stopping when a bit threshold
        is exceeded.

        Args:
            gathered_cdfs: [B, T, 2] per-step (cdf_low, cdf_high) interval of
                the symbol chosen at each step, already gathered from the full
                CDF over the vocabulary.
            vocab_size: size of the symbol alphabet; only used to bound the
                per-sequence digit buffer.
            lengths: [B] valid length of each sequence (optional; defaults to
                T for every sequence, and is clamped to [0, T]).
            bit_threshold: if set, a sequence stops as soon as its finalized
                bitstream would exceed this many bits; the saved result is the
                finalized stream from the step BEFORE the threshold was
                crossed.
            force_padding_to_threshold: pad every bitstream up to exactly
                ``bit_threshold`` bits. Requires ``bit_threshold`` to be set.
            return_num_padded_bits: whether to also return per-sequence
                padding counts.

        Returns:
            final_compressed_bytes: List[bytes] - final compressed result per
                sequence.
            stopped_at_step: List[int] - step t at which the sequence hit the
                bit threshold; for sequences that completed normally this is
                t + 1 of the last encoded step (i.e. the sequence length).
                Stays -1 only for sequences that encoded nothing (length 0).
            final_num_padded_bits: List[int] - padding info (only if
                return_num_padded_bits=True).

        Raises:
            ValueError: if force_padding_to_threshold is True but
                bit_threshold is None.
        """
        if force_padding_to_threshold and bit_threshold is None:
            raise ValueError(
                "force_padding_to_threshold=True requires bit_threshold to be set"
            )

        B, T, _ = gathered_cdfs.shape
        device = gathered_cdfs.device
        if lengths is None:
            lengths = torch.full((B,), T, dtype=torch.int64, device=device)
        lengths = torch.clamp(lengths, min=0, max=T)

        # Initialize arithmetic coding state: [low, high] interval per sequence.
        low = torch.zeros((B,), dtype=torch.int64, device=device)
        high = torch.full((B,), int(self._base**self._precision) - 1, dtype=torch.int64, device=device)
        num_carry_digits = torch.zeros((B,), dtype=torch.int32, device=device)

        # Flat digit buffer: each sequence owns a contiguous slab of
        # max_digits entries starting at idx * max_digits.
        digits_sym = math.ceil(math.log(vocab_size, self._base))
        max_digits = self._precision + 2 * T * digits_sym
        bits_buffer = torch.empty(B * max_digits, dtype=torch.int32, device=device)
        buf_offsets = torch.arange(B, device=device, dtype=torch.int32) * max_digits
        # Immutable copy of the slab start positions (buf_offsets advances as
        # digits are written; base_offsets never changes).
        base_offsets = buf_offsets.clone()

        # Pre-allocate temporary buffers (avoid cloning at each step).
        temp_bits_buffer = torch.empty_like(bits_buffer)
        temp_buf_offsets = torch.empty_like(buf_offsets)
        temp_num_carry_digits = torch.empty_like(num_carry_digits)

        # Track final results for each sequence - save buffer states, not bytes.
        final_buffer = torch.empty_like(bits_buffer)
        final_buffer_ends = torch.zeros(B, dtype=torch.int32, device=device)
        final_num_padded_bits = [None] * B
        stopped_at_step = [-1] * B  # -1 only survives for zero-length sequences

        # Track which sequences are still active.
        active_sequences = torch.ones(B, dtype=torch.bool, device=device)

        # Keep track of previous step's finalized buffer state for threshold logic.
        prev_finalized_buffer = torch.empty_like(bits_buffer)
        prev_finalized_ends = torch.zeros_like(buf_offsets)

        for t in range(T):
            valid = (t < lengths) & active_sequences
            if not valid.any():
                break  # All sequences completed or stopped

            # Narrow the interval of each valid sequence to the symbol's CDF slice.
            low_valid = low[valid]
            high_valid = high[valid]
            width_valid = high_valid - low_valid + 1
            old_low = low.clone()
            low[valid] = low_valid + (gathered_cdfs[valid, t, 0] * width_valid).to(torch.int64)
            high[valid] = low_valid + (gathered_cdfs[valid, t, 1] * width_valid).to(torch.int64) - 1

            # Flush digits and update buffers.
            (low, high, bits_buffer, buf_offsets, num_carry_digits, _) = self.flush_matching_digits(
                low, high, old_low, encoding=True,
                bits_buffer=bits_buffer, buf_offsets=buf_offsets,
                num_carry_digits=num_carry_digits,
                current_code_in_int=None, _next_digit=None, valid=valid
            )
            (low, high, num_carry_digits, _) = self.flush_carry_digits(
                low, high, encoding=True, num_carry_digits=num_carry_digits,
                current_code_in_int=None, _next_digit=None, valid=valid
            )

            # Check if we need to compute results this step (threshold checking
            # or some sequence reaching its final symbol).
            need_check_threshold = bit_threshold is not None and active_sequences.any()
            some_seq_finished = ((t + 1 >= lengths) & active_sequences).any()
            if need_check_threshold or some_seq_finished:
                # Simulate finalization at this step using pre-allocated buffers.
                temp_bits_buffer.copy_(bits_buffer, True)
                temp_buf_offsets.copy_(buf_offsets, True)
                temp_num_carry_digits.copy_(num_carry_digits, True)

                # Add final digit for all sequences (simulating termination).
                temp_bits_buffer[temp_buf_offsets] = (low // self._base_to_pm1).to(torch.int32)
                temp_buf_offsets += 1

                # Handle remaining carry digits for all sequences: append
                # rep_cnt copies of (base - 1) after each affected sequence's end.
                carry_sel = (temp_num_carry_digits > 0).nonzero(as_tuple=False).flatten()
                if carry_sel.numel():
                    carry_digit = self._base - 1
                    rep_cnt = temp_num_carry_digits[carry_sel]
                    repeats_max = rep_cnt.max()
                    grid = torch.arange(repeats_max, device=rep_cnt.device).expand(carry_sel.size(0), repeats_max)
                    mask_rep = grid < rep_cnt.unsqueeze(1)
                    start_pos = temp_buf_offsets[carry_sel]
                    target_pos = (start_pos.unsqueeze(1) + grid)[mask_rep]
                    temp_bits_buffer[target_pos] = carry_digit
                    temp_buf_offsets.index_add_(0, carry_sel, rep_cnt)
                    temp_num_carry_digits[carry_sel] = 0

                # Check bit threshold and identify newly stopped sequences.
                if need_check_threshold:
                    current_bit_counts = self._get_bit_counts(temp_buf_offsets, base_offsets)
                    exceeds_threshold = (current_bit_counts > bit_threshold) & active_sequences
                    if exceeds_threshold.any():
                        stopped_indices = exceeds_threshold.nonzero(as_tuple=False).flatten()
                        for idx in stopped_indices.cpu().tolist():  # Only move indices to CPU
                            active_sequences[idx] = False
                            stopped_at_step[idx] = t
                            # Save the result from PREVIOUS step (before exceeding threshold).
                            final_buffer_ends[idx] = prev_finalized_ends[idx]
                            offset_start = idx * max_digits
                            offset_end = prev_finalized_ends[idx]
                            final_buffer[offset_start:offset_end].copy_(prev_finalized_buffer[offset_start:offset_end])

                # If final step, all remaining active sequences need results.
                is_final_step = (t + 1 >= lengths) & active_sequences
                if is_final_step.any():
                    final_step_indices = is_final_step.nonzero(as_tuple=False).flatten()
                    for idx in final_step_indices.cpu().tolist():
                        active_sequences[idx] = False
                        stopped_at_step[idx] = t + 1
                        # Save current step result for sequences that completed normally.
                        final_buffer_ends[idx] = temp_buf_offsets[idx]
                        # Copy the finalized bits to main buffer for this sequence.
                        offset_start = idx * max_digits
                        offset_end = temp_buf_offsets[idx]
                        final_buffer[offset_start:offset_end].copy_(temp_bits_buffer[offset_start:offset_end])

                # Update previous finalized buffer state for next iteration.
                if need_check_threshold:
                    prev_finalized_buffer.copy_(temp_bits_buffer)
                    prev_finalized_ends.copy_(temp_buf_offsets)

        # Convert buffer states to compressed bytes at the very end.
        # One device->host transfer for the whole batch instead of B separate
        # .cpu() calls (each of which forces a sync).
        final_buffer_cpu = final_buffer.cpu()
        final_ends_cpu = final_buffer_ends.cpu().tolist()
        final_compressed_bytes = []
        for idx in range(B):
            offset_start = idx * max_digits
            offset_end = final_ends_cpu[idx]
            bits_list = final_buffer_cpu[offset_start:offset_end].tolist()
            bitstr = "".join(map(str, bits_list))
            if force_padding_to_threshold:
                comp_bytes, num_padded = utils.bits_to_bytes_padding_to_threshold(bitstr, bit_threshold)
            else:
                comp_bytes, num_padded = utils.bits_to_bytes(bitstr)
            final_compressed_bytes.append(comp_bytes)
            if return_num_padded_bits:
                final_num_padded_bits[idx] = num_padded

        if return_num_padded_bits:
            return final_compressed_bytes, stopped_at_step, final_num_padded_bits
        else:
            return final_compressed_bytes, stopped_at_step