import torch
from m1_compression import utils
import math
from typing import List, Tuple, Optional
import logging
from m1_compression.batched_arithmetic_coder import (
    BatchedArithmeticEncoder,
)

logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class CPUArithmeticEncoder(BatchedArithmeticEncoder):
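    """Arithmetic encoder whose results are finalized on the host (CPU).

    Supports incremental batched encoding with early stopping: a sequence stops
    as soon as its finalized compressed size exceeds a bit threshold, and the
    result kept for it is the previous step's state, i.e. the longest encoded
    prefix that still fits the budget.
    """
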
    def __init__(self, base: int, precision: int):
        super().__init__(base=base, precision=precision)

    def batched_encode(
            self, 
            gathered_cdfs: torch.Tensor, # [B, T, 2]
            symbols: torch.Tensor,
            lengths: Optional[torch.Tensor] = None,
            return_num_padded_bits: bool = False
    ) -> Tuple[List[bytes], List[int]]:
        raise NotImplementedError("CPUArithmeticEncoder does not support batched_encode")

    def incremental_batched_encode(
            self, 
            gathered_cdfs: torch.Tensor, # [B, T, 2]
            vocab_size: int,
            lengths: Optional[torch.Tensor] = None,
            bit_threshold: Optional[int] = None,
            force_padding_to_threshold: bool = False,
            return_num_padded_bits: bool = False
    ) -> Tuple[List[bytes], List[int]] | Tuple[List[bytes], List[int], List[int]]:
        """
        Incrementally encode symbols with early stopping when bit threshold is exceeded.
        
        Args:
            pdf: [B, T, V] probability distributions
            symbols: [B, T] symbols to encode
            lengths: [B] length of each sequence (optional)
            bit_threshold: Stop encoding when any sequence exceeds this many bits
            force_padding_to_threshold: Force padding to threshold even if bit threshold is not exceeded
            return_num_padded_bits: Whether to return padding information
            
        Returns:
            final_compressed_bytes: List[bytes] - final compressed result for each sequence
            stopped_at_step: List[int] - step where each sequence stopped (-1 if completed normally)
            final_num_padded_bits: List[int] - padding info (only if return_num_padded_bits=True)
        """
        B, T, _ = gathered_cdfs.shape
        device = gathered_cdfs.device
        
        if force_padding_to_threshold and bit_threshold is None:
            raise ValueError("force_padding_to_threshold requires bit_threshold to be set")

        if lengths is None:
            lengths = torch.full((B,), T, dtype=torch.int64, device=device)

        lengths = torch.clamp(lengths, min=0, max=T)
        
        # Initialize arithmetic coding state
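        # low/high bound the current coding interval as fixed-point integers with
        # `precision` base-`base` digits; num_carry_digits appears to count pending
        # straddle digits whose value is only known once a later carry resolves.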
        low = torch.zeros((B,), dtype=torch.int64, device=device)
        high = torch.full((B,), int(self._base**self._precision) - 1, dtype=torch.int64, device=device)
        num_carry_digits = torch.zeros((B,), dtype=torch.int32, device=device)
        
        # Initialize bit buffer
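        # Digit budget: room for the final `precision` digits plus roughly
        # 2 * ceil(log_base(vocab_size)) digits per symbol, presumably a safe
        # upper bound on what a single step can flush (including carry digits).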
        digits_sym = math.ceil(math.log(vocab_size, self._base))
        max_digits = self._precision + 2 * T * digits_sym
        bits_buffer = torch.empty(B * max_digits, dtype=torch.int32, device=device)
        buf_offsets = torch.arange(B, device=device, dtype=torch.int32) * max_digits

        # base_offsets stays fixed at each sequence's buffer start, while
        # buf_offsets advances as digits are written
        base_offsets = buf_offsets.clone()
        
        # Pre-allocate temporary buffers (avoid cloning at each step)
        temp_bits_buffer = torch.empty_like(bits_buffer)
        temp_buf_offsets = torch.empty_like(buf_offsets)
        temp_num_carry_digits = torch.empty_like(num_carry_digits)
        
        # Track final results for each sequence - save buffer states, not bytes
        final_buffer = torch.empty_like(bits_buffer)
        final_buffer_ends = torch.zeros(B, dtype=torch.int32, device=device)
        final_num_padded_bits = [None] * B
        stopped_at_step = [-1] * B  # stays -1 only for zero-length sequences
        
        # Track which sequences are still active
        active_sequences = torch.ones(B, dtype=torch.bool, device=device)
        
        # Snapshot of the previous step's finalized buffer: when a sequence first
        # exceeds bit_threshold at step t, the result saved for it is this state
        # (after step t - 1), i.e. the longest encoded prefix within the budget.
        prev_finalized_buffer = torch.empty_like(bits_buffer)
        prev_finalized_ends = torch.zeros_like(buf_offsets)
        
        for t in range(T):
            valid = (t < lengths) & active_sequences
            
            if not valid.any():
                break  # All sequences completed or stopped
            
            # Narrow the coding interval to the current symbol's sub-interval:
            # with width = high - low + 1,
            #   low'  = low + floor(cdf_lo * width)
            #   high' = low + floor(cdf_hi * width) - 1
            low_valid = low[valid]
            high_valid = high[valid]
            width_valid = high_valid - low_valid + 1

            old_low = low.clone()
            low[valid] = low_valid + (gathered_cdfs[valid, t, 0] * width_valid).to(torch.int64)
            high[valid] = low_valid + (gathered_cdfs[valid, t, 1] * width_valid).to(torch.int64) - 1
            
            # Renormalize: flush the leading digits on which low and high already
            # agree into the bit buffer, then resolve pending carry digits
            (low, high, bits_buffer, buf_offsets, num_carry_digits, _) = self.flush_matching_digits(
                low, high, old_low,
                encoding=True,
                bits_buffer=bits_buffer,
                buf_offsets=buf_offsets,
                num_carry_digits=num_carry_digits,
                current_code_in_int=None,
                _next_digit=None,
                valid=valid
            )
            
            (low, high, num_carry_digits, _) = self.flush_carry_digits(
                low, high,
                encoding=True,
                num_carry_digits=num_carry_digits,
                current_code_in_int=None,
                _next_digit=None,
                valid=valid
            )
            
            # Check if we need to compute results this step (if bit threshold checking or final step)
            need_check_threshold = bit_threshold is not None and active_sequences.any()
            some_seq_finished = ((t + 1 >= lengths) & active_sequences).any()
            
            if need_check_threshold or some_seq_finished:
                # Simulate finalization at this step using pre-allocated buffers
                temp_bits_buffer.copy_(bits_buffer, True)
                temp_buf_offsets.copy_(buf_offsets, True)
                temp_num_carry_digits.copy_(num_carry_digits, True)
                
                # Add final digit for all sequences (simulating termination)
                temp_bits_buffer[temp_buf_offsets] = (low // self._base_to_pm1).to(torch.int32)
                temp_buf_offsets += 1
                
                # Handle remaining carry digits for all sequences
                carry_sel = (temp_num_carry_digits > 0).nonzero(as_tuple=False).flatten()
                if carry_sel.numel():
                    carry_digit = self._base - 1
                    rep_cnt = temp_num_carry_digits[carry_sel]
                    repeats_max = rep_cnt.max()
                    grid = torch.arange(repeats_max, device=rep_cnt.device).expand(carry_sel.size(0), repeats_max)
                    mask_rep = grid < rep_cnt.unsqueeze(1)
                    
                    start_pos = temp_buf_offsets[carry_sel]
                    target_pos = (start_pos.unsqueeze(1) + grid)[mask_rep]
                    temp_bits_buffer[target_pos] = carry_digit
                    temp_buf_offsets.index_add_(0, carry_sel, rep_cnt)
                    temp_num_carry_digits[carry_sel] = 0
                
                # Check bit threshold and identify newly stopped sequences  
                if need_check_threshold:
                    current_bit_counts = self._get_bit_counts(temp_buf_offsets, base_offsets)
                    exceeds_threshold = (current_bit_counts > bit_threshold) & active_sequences
                    
                    if exceeds_threshold.any():
                        stopped_indices = exceeds_threshold.nonzero(as_tuple=False).flatten()
                        for idx in stopped_indices.cpu().tolist():  # Only move indices to CPU
                            active_sequences[idx] = False
                            stopped_at_step[idx] = t
                            # Save the result from PREVIOUS step (before exceeding threshold)
                            final_buffer_ends[idx] = prev_finalized_ends[idx]
                            offset_start = idx * max_digits
                            offset_end = prev_finalized_ends[idx]
                            final_buffer[offset_start:offset_end].copy_(prev_finalized_buffer[offset_start:offset_end])
                
                # If final step, all remaining active sequences need results
                is_final_step = (t + 1 >= lengths) & active_sequences
                if is_final_step.any():
                    final_step_indices = is_final_step.nonzero(as_tuple=False).flatten()
                    for idx in final_step_indices.cpu().tolist():
                        active_sequences[idx] = False
                        stopped_at_step[idx] = t + 1
                        # Save current step result for sequences that completed normally
                        final_buffer_ends[idx] = temp_buf_offsets[idx]
                        # Copy the finalized bits to main buffer for this sequence
                        offset_start = idx * max_digits
                        offset_end = temp_buf_offsets[idx]
                        final_buffer[offset_start:offset_end].copy_(temp_bits_buffer[offset_start:offset_end])
        
            # Update previous finalized buffer state for next iteration
            if need_check_threshold:
                prev_finalized_buffer.copy_(temp_bits_buffer)
                prev_finalized_ends.copy_(temp_buf_offsets)
        
        # Convert buffer states to compressed bytes at the very end
        final_compressed_bytes = []
        
        for idx in range(B):
            offset_start = idx * max_digits
            offset_end = final_buffer_ends[idx]
            bits_list = final_buffer[offset_start:offset_end].cpu().tolist()
            # NOTE: joining digits into a "bit" string assumes base == 2
            bitstr = "".join(map(str, bits_list))
            if force_padding_to_threshold:
                comp_bytes, num_padded = utils.bits_to_bytes_padding_to_threshold(bitstr, bit_threshold)
            else:
                comp_bytes, num_padded = utils.bits_to_bytes(bitstr)
            final_compressed_bytes.append(comp_bytes)
            if return_num_padded_bits:
                final_num_padded_bits[idx] = num_padded
    
        if return_num_padded_bits:
            return final_compressed_bytes, stopped_at_step, final_num_padded_bits
        else:
            return final_compressed_bytes, stopped_at_step
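

# ---------------------------------------------------------------------------
# Minimal usage sketch, illustrative only: it assumes the base/precision
# constructor wired above and that gathered_cdfs carries the fractional
# [lower, upper) CDF bounds of each encoded symbol, as implied by the interval
# update in incremental_batched_encode. Shapes, base=2, precision=32, and
# bit_threshold=512 are made-up demo values; running this needs the
# m1_compression package on the path.
if __name__ == "__main__":
    B, T, V = 2, 16, 256
    pdf = torch.softmax(torch.randn(B, T, V), dim=-1)  # [B, T, V] model probabilities
    symbols = torch.randint(0, V, (B, T))              # [B, T] symbols to encode

    upper = torch.cumsum(pdf, dim=-1)                  # per-symbol upper CDF bounds
    lower = upper - pdf                                # per-symbol lower CDF bounds
    idx = symbols.unsqueeze(-1)
    gathered_cdfs = torch.stack(
        [lower.gather(-1, idx).squeeze(-1), upper.gather(-1, idx).squeeze(-1)],
        dim=-1,
    )                                                  # [B, T, 2]

    encoder = CPUArithmeticEncoder(base=2, precision=32)
    compressed, stopped_at_step = encoder.incremental_batched_encode(
        gathered_cdfs, vocab_size=V, bit_threshold=512
    )
    for i, (blob, step) in enumerate(zip(compressed, stopped_at_step)):
        logger.info("seq %d: %d bytes, stopped_at_step=%d", i, len(blob), step)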