import logging
import math
from typing import List, Optional, Tuple

import torch

from m1_compression import utils
from m1_compression.batched_arithmetic_coder import (
    BatchedArithmeticEncoder,
)

logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger()


class CPUArithmeticEncoder(BatchedArithmeticEncoder):
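    """Arithmetic encoder that walks the time dimension one step at a time.

    Encoding step by step lets each sequence stop early once its compressed
    output exceeds a caller-supplied bit budget (see incremental_batched_encode).
    """
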
    def __init__(self, base: int, precision: int):
        super().__init__(base=base, precision=precision)

    def batched_encode(
        self,
        gathered_cdfs: torch.Tensor,  # [B, T, 2]
        symbols: torch.Tensor,
        lengths: Optional[torch.Tensor] = None,
        return_num_padded_bits: bool = False,
    ) -> Tuple[List[bytes], List[int]]:
        raise NotImplementedError("CPUArithmeticEncoder does not support batched_encode")

    def incremental_batched_encode(
        self,
        gathered_cdfs: torch.Tensor,  # [B, T, 2]
        vocab_size: int,
        lengths: Optional[torch.Tensor] = None,
        bit_threshold: Optional[int] = None,
        force_padding_to_threshold: bool = False,
        return_num_padded_bits: bool = False,
    ) -> Tuple[List[bytes], List[int]] | Tuple[List[bytes], List[int], List[int]]:
        """
        Incrementally encode symbols, stopping a sequence early once its
        compressed output exceeds a bit threshold.

        Args:
            gathered_cdfs: [B, T, 2] lower/upper CDF bounds of the symbol encoded at each step
            vocab_size: size of the symbol alphabet (bounds the digits emitted per step)
            lengths: [B] length of each sequence (optional; defaults to T)
            bit_threshold: stop encoding a sequence once it exceeds this many bits
            force_padding_to_threshold: pad outputs up to bit_threshold even for
                sequences that never exceeded it
            return_num_padded_bits: whether to also return padding information
        Returns:
            final_compressed_bytes: List[bytes] - final compressed result for each sequence
            stopped_at_step: List[int] - step at which each sequence hit the threshold,
                or its full length if it completed normally (-1 if it was never encoded)
            final_num_padded_bits: List[int] - padding info (only if return_num_padded_bits=True)
        """
        B, T, _ = gathered_cdfs.shape
        device = gathered_cdfs.device
        if lengths is None:
            lengths = torch.full((B,), T, dtype=torch.int64, device=device)
        lengths = torch.clamp(lengths, min=0, max=T)

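        # Per-sequence coder state: an integer interval [low, high] within
        # [0, base**precision); each encoded symbol narrows it further.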
        # Initialize arithmetic coding state
        low = torch.zeros((B,), dtype=torch.int64, device=device)
        high = torch.full((B,), int(self._base**self._precision) - 1, dtype=torch.int64, device=device)
        num_carry_digits = torch.zeros((B,), dtype=torch.int32, device=device)

        # Initialize bit buffer
        digits_sym = math.ceil(math.log(vocab_size, self._base))
        max_digits = self._precision + 2 * T * digits_sym
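        # max_digits is a worst-case bound on how many base-`base` digits one
        # sequence can emit: a final precision-wide flush plus (conservatively)
        # two symbols' worth of digits per step.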
        bits_buffer = torch.empty(B * max_digits, dtype=torch.int32, device=device)
        buf_offsets = torch.arange(B, device=device, dtype=torch.int32) * max_digits
        base_offsets = torch.arange(B, device=device, dtype=torch.int32) * max_digits
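        # buf_offsets advances as digits are written; base_offsets stays fixed
        # at each sequence's region start so emitted-digit counts can be derived.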
        # Pre-allocate temporary buffers (avoid cloning at each step)
        temp_bits_buffer = torch.empty_like(bits_buffer)
        temp_buf_offsets = torch.empty_like(buf_offsets)
        temp_num_carry_digits = torch.empty_like(num_carry_digits)

        # Track final results for each sequence - save buffer states, not bytes
        final_buffer = torch.empty_like(bits_buffer)
        final_buffer_ends = torch.zeros(B, dtype=torch.int32, device=device)
        final_num_padded_bits = [None] * B
        stopped_at_step = [-1] * B  # -1 means the sequence was never encoded

        # Track which sequences are still active
        active_sequences = torch.ones(B, dtype=torch.bool, device=device)

        # Keep track of previous step's finalized buffer state for threshold logic
        prev_finalized_buffer = torch.empty_like(bits_buffer)
        prev_finalized_ends = torch.zeros_like(buf_offsets)

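        # Main loop: encode one symbol per active sequence per step, flushing
        # settled digits as the interval narrows.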
        for t in range(T):
            valid = (t < lengths) & active_sequences
            if not valid.any():
                break  # All sequences completed or stopped

            low_valid = low[valid]
            high_valid = high[valid]
            width_valid = high_valid - low_valid + 1
            old_low = low.clone()
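            # Narrow each valid interval to its symbol's CDF slice: index 0 of
            # gathered_cdfs holds the lower CDF bound, index 1 the upper bound.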
            low[valid] = low_valid + (gathered_cdfs[valid, t, 0] * width_valid).to(torch.int64)
            high[valid] = low_valid + (gathered_cdfs[valid, t, 1] * width_valid).to(torch.int64) - 1

            # Flush digits and update buffers
            (low, high, bits_buffer, buf_offsets, num_carry_digits, _) = self.flush_matching_digits(
                low, high, old_low,
                encoding=True,
                bits_buffer=bits_buffer,
                buf_offsets=buf_offsets,
                num_carry_digits=num_carry_digits,
                current_code_in_int=None,
                _next_digit=None,
                valid=valid,
            )
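            # flush_carry_digits (presumably inherited from BatchedArithmeticEncoder)
            # resolves pending carry/underflow digits left by the flush above.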
            (low, high, num_carry_digits, _) = self.flush_carry_digits(
                low, high,
                encoding=True,
                num_carry_digits=num_carry_digits,
                current_code_in_int=None,
                _next_digit=None,
                valid=valid,
            )

            # Check if we need to compute results this step (bit-threshold checking or a sequence finishing)
            need_check_threshold = bit_threshold is not None and active_sequences.any()
            some_seq_finished = ((t + 1 >= lengths) & active_sequences).any()
            if need_check_threshold or some_seq_finished:
                # Simulate finalization at this step using pre-allocated buffers
                temp_bits_buffer.copy_(bits_buffer, non_blocking=True)
                temp_buf_offsets.copy_(buf_offsets, non_blocking=True)
                temp_num_carry_digits.copy_(num_carry_digits, non_blocking=True)

                # Add final digit for all sequences (simulating termination)
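                # (low // self._base_to_pm1 is low's most significant digit,
                # assuming _base_to_pm1 == base ** (precision - 1))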
                temp_bits_buffer[temp_buf_offsets] = (low // self._base_to_pm1).to(torch.int32)
                temp_buf_offsets += 1

                # Handle remaining carry digits for all sequences
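                # Each pending carry becomes a run of (base - 1) digits; the
                # grid/mask below scatter-writes all runs in one vectorized pass.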
                carry_sel = (temp_num_carry_digits > 0).nonzero(as_tuple=False).flatten()
                if carry_sel.numel():
                    carry_digit = self._base - 1
                    rep_cnt = temp_num_carry_digits[carry_sel]
                    repeats_max = rep_cnt.max()
                    grid = torch.arange(repeats_max, device=rep_cnt.device).expand(carry_sel.size(0), repeats_max)
                    mask_rep = grid < rep_cnt.unsqueeze(1)
                    start_pos = temp_buf_offsets[carry_sel]
                    target_pos = (start_pos.unsqueeze(1) + grid)[mask_rep]
                    temp_bits_buffer[target_pos] = carry_digit
                    temp_buf_offsets.index_add_(0, carry_sel, rep_cnt)
                    temp_num_carry_digits[carry_sel] = 0
                # Check bit threshold and identify newly stopped sequences
                if need_check_threshold:
                    current_bit_counts = self._get_bit_counts(temp_buf_offsets, base_offsets)
                    exceeds_threshold = (current_bit_counts > bit_threshold) & active_sequences
                    if exceeds_threshold.any():
                        stopped_indices = exceeds_threshold.nonzero(as_tuple=False).flatten()
                        for idx in stopped_indices.cpu().tolist():  # Only move indices to CPU
                            active_sequences[idx] = False
                            stopped_at_step[idx] = t
                            # Save the result from the PREVIOUS step (before exceeding the threshold)
                            final_buffer_ends[idx] = prev_finalized_ends[idx]
                            offset_start = idx * max_digits
                            offset_end = prev_finalized_ends[idx]
                            final_buffer[offset_start:offset_end].copy_(prev_finalized_buffer[offset_start:offset_end])

                # If this is the final step, all remaining active sequences need results
                is_final_step = (t + 1 >= lengths) & active_sequences
                if is_final_step.any():
                    final_step_indices = is_final_step.nonzero(as_tuple=False).flatten()
                    for idx in final_step_indices.cpu().tolist():
                        active_sequences[idx] = False
                        stopped_at_step[idx] = t + 1
                        # Save current step result for sequences that completed normally
                        final_buffer_ends[idx] = temp_buf_offsets[idx]
                        # Copy the finalized bits to the main buffer for this sequence
                        offset_start = idx * max_digits
                        offset_end = temp_buf_offsets[idx]
                        final_buffer[offset_start:offset_end].copy_(temp_bits_buffer[offset_start:offset_end])

                # Update previous finalized buffer state for the next iteration
                if need_check_threshold:
                    prev_finalized_buffer.copy_(temp_bits_buffer)
                    prev_finalized_ends.copy_(temp_buf_offsets)

        # Convert buffer states to compressed bytes at the very end
        final_compressed_bytes = []
        for idx in range(B):
            offset_start = idx * max_digits
            offset_end = final_buffer_ends[idx]
            bits_list = final_buffer[offset_start:offset_end].cpu().tolist()
            bitstr = "".join(map(str, bits_list))
            if force_padding_to_threshold:
                comp_bytes, num_padded = utils.bits_to_bytes_padding_to_threshold(bitstr, bit_threshold)
            else:
                comp_bytes, num_padded = utils.bits_to_bytes(bitstr)
            final_compressed_bytes.append(comp_bytes)
            if return_num_padded_bits:
                final_num_padded_bits[idx] = num_padded

        if return_num_padded_bits:
            return final_compressed_bytes, stopped_at_step, final_num_padded_bits
        else:
            return final_compressed_bytes, stopped_at_step
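

if __name__ == "__main__":
    # Hedged usage sketch (not part of the original file): exercises
    # incremental_batched_encode with synthetic CDF bounds. The constructor
    # values (base=2, precision=32), the float dtype of gathered_cdfs, and the
    # uniform CDF slices are illustrative assumptions, not documented behavior.
    B, T, vocab_size = 2, 8, 4
    encoder = CPUArithmeticEncoder(base=2, precision=32)

    # Each step supplies the [lower, upper) CDF slice of the symbol being
    # encoded; here every symbol is given the same 0.25-wide slice.
    lower = torch.full((B, T, 1), 0.25)
    upper = torch.full((B, T, 1), 0.50)
    gathered_cdfs = torch.cat([lower, upper], dim=-1)  # [B, T, 2]

    compressed, stopped_at = encoder.incremental_batched_encode(
        gathered_cdfs, vocab_size=vocab_size, bit_threshold=16
    )
    for i, (blob, step) in enumerate(zip(compressed, stopped_at)):
        logger.debug("seq %d: %d bytes, stopped at step %d", i, len(blob), step)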