| import torch |
| import torch.nn.functional as F |
| from torch.utils.data import IterableDataset, DataLoader |
| import json |
| import numpy as np |
| from pathlib import Path |
| from typing import Iterator, List, Dict, Any, Tuple, Optional, Union, Callable |
| import logging |
| import argparse |
| import base64 |
| import gc |
| from collections import defaultdict, Counter, deque |
| from m1_compression.batched_arithmetic_coder import ( |
| _pdf_to_cdf, |
| ) |
| from m1_compression.hybrid_arithmetic_coder import CPUArithmeticEncoder |
| from m1_compression import utils |
| from m1_compression.compressor import ( |
| load_m1_model_and_tokenizer, |
| load_m1_model_cpu, |
| ALPHABET_SIZE, |
| ARITHMETIC_CODER_BASE, |
| ARITHMETIC_CODER_PRECISION, |
| ) |
| import torch.multiprocessing as mp |
| from offline_utils import ( |
| unpack_windows, |
| pseudo_to_packed_bytes, |
| pad_batch, |
| find_next_batch_range, |
| packed_bytes_to_pseudo, |
| pseudo_to_packed_bytes, |
| pad_batch, |
| InterleavedJsonlDataset, |
| batched_m1_compress_predict_fn, |
| ) |
# Segments at or below this length are never worth arithmetic coding; they
# are passed through (or left as raw tails) instead.
MINIMUM_SEGMENT_SIZE = 3
# Added to every compressed byte so pseudo-bytes >= 256 are distinguishable
# from raw bytes 0..255 in the mixed output stream.
COMPRESSION_OFFSET = 256
# NOTE(review): appears unused in this file — possibly consumed elsewhere; confirm.
GC_FREQ = 10


logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger()
|
|
|
|
class SegmentCache:
    """Bounded, insert-only cache mapping a byte segment to a precomputed value.

    Values may be CDF tensors, compressed pseudo-byte lists, or tuples of
    debug artifacts. Entries are never evicted: once `cache_size` entries are
    stored, further inserts are silently dropped.
    """
    def __init__(self, cache_size: int = 819200, cache_desc: str = "Prediction"):
        # Maximum number of distinct segments to retain (<= 0 disables writes).
        self.cache_size = cache_size
        # Human-readable label used only in log messages.
        self.cache_desc = cache_desc
        self.cache: Dict[bytes, Union[torch.Tensor, List[int]]] = {}
        logger.info(f"Created Cache with size: {cache_size}, type: {cache_desc}")

    def get_batch(self, segments: List[bytes]) -> Tuple[List[Tuple[bytes, List[int]]], Dict[bytes, Any], List[int]]:
        """
        Look up a batch of (possibly duplicated) segments.

        Returns:
            - cache_misses: list of (segment, input-indices) pairs for the
              unique segments not found in the cache
            - cache_results: dict mapping each cached segment to its stored value
            - hit_indices: input-order indices of every segment that was a hit
        """
        # Group duplicate segments so each unique segment is looked up once.
        segment_to_indices = defaultdict(list)
        for idx, seg in enumerate(segments):
            segment_to_indices[seg].append(idx)
        unique_segments = list(segment_to_indices.keys())

        cache_results = {}
        cache_misses = []
        for seg in unique_segments:
            if seg in self.cache:
                cache_results[seg] = self.cache[seg]
            else:
                cache_misses.append((seg, segment_to_indices[seg]))

        # Expand hits back to every duplicate position in the input order.
        hit_indices = []
        for seg, indices in segment_to_indices.items():
            if seg in cache_results:
                for idx in indices:
                    hit_indices.append(idx)

        logger.info(f"{self.cache_desc} cache: {len(unique_segments)} unique segments, {len(cache_results)} hits, {len(cache_misses)} misses, {len(segments)} total segments")
        return cache_misses, cache_results, hit_indices

    def put_batch(self, segments: List[bytes], values: List[Union[torch.Tensor, List[int]]]):
        """Store segment -> value mappings (no-op once the cache is full)."""
        if self.cache_size <= 0:
            return
        for segment, value in zip(segments, values):
            if segment not in self.cache:
                if len(self.cache) < self.cache_size:
                    if isinstance(value, tuple):
                        # Debug tuples: (cdf_ends, pdf) or the 5-tuple of
                        # compression artifacts; clone tensors so the cache
                        # does not alias caller-owned buffers.
                        assert len(value) == 2 or len(value) == 5, "value must be a tuple of length 2 or 5"
                        cloned_value = tuple(v.clone() if isinstance(v, torch.Tensor) else v for v in value)
                        self.cache[segment] = cloned_value
                    elif isinstance(value, torch.Tensor):
                        self.cache[segment] = value.clone()
                    else:
                        self.cache[segment] = value
|
|
def get_batch_size_for_length(window_len, max_batch_size):
    """Pick the per-forward batch size for windows of the given length.

    Longer windows get proportionally smaller batches to bound peak memory;
    windows longer than the largest tier run one at a time.
    """
    tiers = (
        (128, max_batch_size),
        (512, max(max_batch_size // 64, 1)),
        (1024, max(max_batch_size // 128, 1)),
        (2048, max(max_batch_size // 256, 1)),
    )
    return next(
        (size for limit, size in tiers if window_len <= limit),
        1,
    )
|
|
def segment_prediction_fn(
    batch: List[Dict[str, Any]],
    max_m1_batch_size,
    batched_predict_fn,
    first_byte_prob,
    debug,
    prediction_cache: Optional[SegmentCache] = None
):
    """
    Run the entropy model over every compressible segment of a batch and
    compute per-byte arithmetic-coding CDF endpoints.

    Segments are flattened across samples, split into "effective" (long enough
    and flagged compressible) and "ineffective" (passed through raw), sorted by
    length to minimize padding, optionally served from `prediction_cache`, and
    the remainder batch-predicted on the GPU.

    Args:
        batch: sample dicts carrying "text" and "windows_starts_lens_b64".
        max_m1_batch_size: upper bound for the model batch; scaled down for
            longer windows via get_batch_size_for_length.
        batched_predict_fn: model forward returning next-byte PDFs per position.
        first_byte_prob: (1, 1, ALPHABET_SIZE) prior used for position 0.
        debug: additionally keep full PDFs for downstream round-trip checks.
        prediction_cache: optional SegmentCache of previously computed CDFs.

    Returns:
        Tuple consumed by `segment_compression_fn`:
        (batch, sorted_segments, raw_segments, effective_segments_idx_map,
         ineffective_segments_idx_map, sample_idx_to_list_segment_idx,
         batched_cdf_ends, batched_pdfs).
    """
    # Flatten every sample's windows into one global segment list, recording
    # which segment indices belong to which sample.
    all_segments = []
    compressed_or_raw_segments = []
    sample_idx_to_list_segment_idx = defaultdict(list)
    segment_idx = 0
    for sample_idx, item in enumerate(batch):
        assert "windows_starts_lens_b64" in item, "windows_starts_lens_b64 must be in item"
        sample_bytes = item["text"].encode('utf-8')
        byte_windows = unpack_windows(sample_bytes, item["windows_starts_lens_b64"])
        for byte_window_indicator in byte_windows:
            all_segments.append(byte_window_indicator[0])
            compressed_or_raw_segments.append(byte_window_indicator[1])
            sample_idx_to_list_segment_idx[sample_idx].append(segment_idx)
            segment_idx += 1

    # A segment is worth model-based compression only if it is longer than
    # MINIMUM_SEGMENT_SIZE and its indicator is 1; the rest stay raw.
    effective_segments = []
    ineffective_segments = []
    for orig_idx, (segment, indicator) in enumerate(zip(all_segments, compressed_or_raw_segments)):
        if len(segment) > MINIMUM_SEGMENT_SIZE and indicator == 1:
            effective_segments.append((orig_idx, segment))
        else:
            ineffective_segments.append((orig_idx, segment))

    # Sort compressible segments by length so padded GPU batches waste less
    # compute.  Guard the unpacking: `a, b = zip(*[])` raises ValueError,
    # which would crash any batch whose segments are all raw (or all
    # compressible).
    if effective_segments:
        sorted_effective_segments = sorted(effective_segments, key=lambda x: len(x[1]))
        sorted_idx, sorted_segments = zip(*sorted_effective_segments)
        sorted_segments = list(sorted_segments)
    else:
        sorted_idx, sorted_segments = (), []
    effective_segments_idx_map = {
        orig_idx: new_idx
        for new_idx, orig_idx in enumerate(sorted_idx)
    }
    if ineffective_segments:
        raw_idx, raw_segments = zip(*ineffective_segments)
        raw_segments = list(raw_segments)
    else:
        raw_idx, raw_segments = (), []
    ineffective_segments_idx_map = {
        orig_idx: new_idx
        for new_idx, orig_idx in enumerate(raw_idx)
    }
    if debug:
        # Coarse CUDA timing; only the start event is recorded in this stage.
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        torch.cuda.synchronize()
        print("[Debug CUDA] time start", flush=True)

    assert first_byte_prob.shape == (1, 1, ALPHABET_SIZE), "first_byte_prob must be of shape (1, 1, ALPHABET_SIZE)"

    batched_windows_np = [np.frombuffer(bytes(data), dtype=np.uint8) for data in sorted_segments]

    M = len(batched_windows_np)
    batched_cdf_ends = [None] * M
    if debug:
        batched_pdfs = [None] * M
    else:
        batched_pdfs = None

    if prediction_cache is not None:
        cache_misses_tup, cache_results, hit_indices = prediction_cache.get_batch(sorted_segments)

        # Fill cache hits straight into the output slots.
        for hit_idx in hit_indices:
            segment = sorted_segments[hit_idx]
            value = cache_results[segment]
            if debug:
                # Debug cache entries are (cdf_ends, pdf) tuples.
                batched_cdf_ends[hit_idx] = value[0]
                batched_pdfs[hit_idx] = value[1]
            else:
                batched_cdf_ends[hit_idx] = value

        # Same empty-unpack guard: every segment may have been a cache hit.
        if cache_misses_tup:
            cache_misses, cache_miss_indices = zip(*cache_misses_tup)
        else:
            cache_misses, cache_miss_indices = (), ()
    else:
        cache_misses = sorted_segments
        cache_miss_indices = [[i] for i in range(M)]

    if cache_misses:
        cache_miss_cdf_ends = []
        cache_miss_pdfs = [] if debug else None

        start_idx = 0
        batched_windows_np = [np.frombuffer(bytes(data), dtype=np.uint8) for data in cache_misses]
        miss_count = len(batched_windows_np)

        while start_idx < miss_count:
            # Chunk the length-sorted misses so longer windows get smaller
            # GPU batches (see get_batch_size_for_length).
            start_idx, end_idx = find_next_batch_range(batched_windows_np, start_idx, max_m1_batch_size, get_batch_size_for_length)
            windows_np_chunked = batched_windows_np[start_idx:end_idx]
            padded_batched_windows, lengths = pad_batch(windows_np_chunked)
            padded_batched_windows, lengths = padded_batched_windows.cuda(), lengths.cuda()
            with torch.no_grad():
                prompt_probs = batched_predict_fn(padded_batched_windows)
                # Shift model outputs right by one position: byte i must be
                # coded under the prediction from bytes < i, with the fixed
                # prior covering byte 0.
                prompt_probs = torch.cat(
                    [
                        first_byte_prob.expand(prompt_probs.shape[0], -1, -1),
                        prompt_probs[:, :-1, ...]
                    ],
                    dim=1
                )
                prompt_probs = utils.batched_normalize_pdf_for_arithmetic_coding(prompt_probs)
                cdfs_gpu = _pdf_to_cdf(prompt_probs)
                # CDF endpoints [F(b), F(b+1)) for each actual byte b.
                cdf_low = cdfs_gpu.gather(2, padded_batched_windows.unsqueeze(-1)).squeeze(-1)
                cdf_high = cdfs_gpu.gather(2, (padded_batched_windows + 1).unsqueeze(-1)).squeeze(-1)
                cdf_ends = torch.stack([cdf_low, cdf_high], dim=-1)

            start_idx = end_idx
            if debug:
                cache_miss_pdfs.extend(prompt_probs.cpu())
            cache_miss_cdf_ends.extend(cdf_ends.cpu())

        # Scatter computed results back to every duplicate of each segment.
        for idx, miss_indices in enumerate(cache_miss_indices):
            for orig_idx in miss_indices:
                batched_cdf_ends[orig_idx] = cache_miss_cdf_ends[idx]
                if debug:
                    batched_pdfs[orig_idx] = cache_miss_pdfs[idx]

        if prediction_cache is not None:
            if debug:
                prediction_cache.put_batch(cache_misses, zip(cache_miss_cdf_ends, cache_miss_pdfs))
            else:
                prediction_cache.put_batch(cache_misses, cache_miss_cdf_ends)

    return (
        batch,
        sorted_segments,
        raw_segments,
        effective_segments_idx_map,
        ineffective_segments_idx_map,
        sample_idx_to_list_segment_idx,
        batched_cdf_ends,
        batched_pdfs,
    )
|
|
def segment_compression_fn(
    batch: List[Dict[str, Any]],
    sorted_segments: List[List[int]],
    raw_segments: List[List[int]],
    effective_segments_idx_map: Dict[int, int],
    ineffective_segments_idx_map: Dict[int, int],
    sample_idx_to_list_segment_idx: Dict[int, List[int]],
    batched_cdf_ends: List[torch.Tensor],
    batched_pdfs: List[torch.Tensor],
    output_window_size: int,
    escape_first_byte: bool,
    iterative_compress: bool,
    force_padding_to_threshold: bool,
    predict_fn: Callable,
    first_byte_prob: torch.Tensor,
    debug: bool = False,
    compression_cache: Optional[SegmentCache] = None
):
    """
    Arithmetic-code every predicted segment and assemble per-sample
    pseudo-byte streams.

    Consumes the outputs of `segment_prediction_fn` (per-segment CDF
    endpoints), encodes cache-missed segments with the CPU arithmetic coder,
    splices compressed and raw segments back into original per-sample order,
    and returns one JSON-serializable result dict per input sample.

    Args:
        batch: sample dicts; each must carry "text" and "windows_starts_lens_b64".
        sorted_segments: compressible segments, sorted by length.
        raw_segments: segments passed through uncompressed.
        effective_segments_idx_map / ineffective_segments_idx_map: original
            segment index -> position in sorted_segments / raw_segments.
        sample_idx_to_list_segment_idx: sample index -> ordered segment indices.
        batched_cdf_ends: per-segment (low, high) CDF tensors from prediction.
        batched_pdfs: per-segment PDFs (debug only; else None entries).
        output_window_size: bit threshold handed to the incremental encoder.
        escape_first_byte: store the first byte verbatim and code only the tail.
        iterative_compress: re-compress leftover raw tails via
            `iterative_compress_ac` instead of appending them raw.
        force_padding_to_threshold: forwarded to the encoder.
        predict_fn / first_byte_prob: entropy-model hooks for the iterative path.
        debug: enable expensive round-trip verification.
        compression_cache: optional SegmentCache of finished compressions.

    Returns:
        List of result dicts (one per sample) with the packed compressed
        payload stored base64-encoded under a parameter-derived key.
    """
    ENCODING_BATCH_SIZE = 512
    if iterative_compress:
        assert not escape_first_byte, "iterative_compress does not support escape_first_byte"
    M = len(batched_cdf_ends)
    processed_batched_compressed_bytes = [None] * M
    if debug:
        batched_stop_steps = [None] * M
        batched_num_padded_bits = [None] * M
        batched_prompt_probs = [None] * M
        batched_lengths = [None] * M

    # Created unconditionally: the debug round-trip below needs a coder even
    # when every segment is a cache hit (previously this lived inside the
    # `if cache_misses:` branch and the debug path could hit a NameError).
    encoder = CPUArithmeticEncoder(
        base=ARITHMETIC_CODER_BASE,
        precision=ARITHMETIC_CODER_PRECISION
    )

    if compression_cache is not None:
        cache_misses_tup, cache_results, hit_indices = compression_cache.get_batch(sorted_segments)

        # Fill cache hits straight into the output slots.
        for hit_idx in hit_indices:
            segment = sorted_segments[hit_idx]
            if debug:
                assert len(cache_results[segment]) == 5, "cache_results must be a tuple of length 5"
                if isinstance(cache_results[segment][0], tuple):
                    # Iterative-compress hits cache only the final bytes; the
                    # per-step debug artifacts are unavailable for them.
                    processed_batched_compressed_bytes[hit_idx] = cache_results[segment][0][0]
                    batched_stop_steps[hit_idx] = None
                    batched_num_padded_bits[hit_idx] = None
                    batched_prompt_probs[hit_idx] = None
                    batched_lengths[hit_idx] = None
                else:
                    processed_batched_compressed_bytes[hit_idx] = cache_results[segment][0]
                    batched_stop_steps[hit_idx] = cache_results[segment][1]
                    batched_num_padded_bits[hit_idx] = cache_results[segment][2]
                    batched_prompt_probs[hit_idx] = cache_results[segment][3]
                    batched_lengths[hit_idx] = cache_results[segment][4]
            else:
                processed_batched_compressed_bytes[hit_idx] = cache_results[segment]

        # `a, b = zip(*[])` raises ValueError when everything was a hit.
        if cache_misses_tup:
            cache_misses, cache_miss_indices = zip(*cache_misses_tup)
        else:
            cache_misses, cache_miss_indices = (), ()
    else:
        cache_misses = sorted_segments
        cache_miss_indices = [[i] for i in range(M)]

    if cache_misses:
        cache_miss_compressed_bytes = []
        cache_miss_stop_steps = []
        cache_miss_num_padded_bits = []
        cache_miss_prompt_probs = []
        cache_miss_lengths = []

        # One CDF tensor per unique miss; duplicates share the first index.
        miss_cdf_ends = [batched_cdf_ends[miss_indices[0]] for miss_indices in cache_miss_indices]
        if debug:
            miss_pdfs = [batched_pdfs[miss_indices[0]] for miss_indices in cache_miss_indices]
        else:
            miss_pdfs = None
        miss_count = len(cache_misses)

        cache_miss_compressed_results = []

        for chunk_idx in range(0, miss_count, ENCODING_BATCH_SIZE):
            chunk_start = chunk_idx
            chunk_end = min(chunk_idx + ENCODING_BATCH_SIZE, miss_count)
            chunk_size = chunk_end - chunk_start
            chunk_segments = cache_misses[chunk_start:chunk_end]
            chunk_cdf_ends = miss_cdf_ends[chunk_start:chunk_end]
            lengths = torch.tensor([len(segment) for segment in chunk_segments], dtype=torch.int64)
            # Right-pad each segment's CDF endpoints to the chunk max length.
            padded_chunk_cdf_ends = torch.zeros(
                (chunk_size, lengths.max().item(), 2),
                device="cpu"
            )
            for idx, (cdf_end, length) in enumerate(zip(chunk_cdf_ends, lengths)):
                padded_chunk_cdf_ends[idx, :length, :] = cdf_end[:length, :]

            if escape_first_byte:
                # The first byte is stored verbatim, so encode positions 1..L-1.
                chunked_compressed_bytes, chunked_stop_steps, chunked_num_padded_bits = encoder.incremental_batched_encode(
                    padded_chunk_cdf_ends[:, 1:, ...],
                    ALPHABET_SIZE,
                    lengths - 1,
                    bit_threshold=output_window_size,
                    force_padding_to_threshold=force_padding_to_threshold,
                    return_num_padded_bits=True
                )
                # Shift stop steps back into whole-segment coordinates.
                chunked_stop_steps = [step + 1 for step in chunked_stop_steps]
            else:
                chunked_compressed_bytes, chunked_stop_steps, chunked_num_padded_bits = encoder.incremental_batched_encode(
                    padded_chunk_cdf_ends,
                    ALPHABET_SIZE,
                    lengths,
                    bit_threshold=output_window_size,
                    force_padding_to_threshold=force_padding_to_threshold,
                    return_num_padded_bits=True
                )
            cache_miss_compressed_bytes.extend(chunked_compressed_bytes)
            cache_miss_stop_steps.extend(chunked_stop_steps)
            if debug:
                chunk_pdfs = miss_pdfs[chunk_start:chunk_end]
                padded_chunk_pdfs = torch.zeros(
                    (chunk_size, lengths.max().item(), ALPHABET_SIZE),
                    device="cpu"
                )
                for idx, (pdf, length) in enumerate(zip(chunk_pdfs, lengths)):
                    padded_chunk_pdfs[idx, :length, :] = pdf[:length, :]

                if escape_first_byte:
                    cache_miss_num_padded_bits.extend(chunked_num_padded_bits)
                    cache_miss_prompt_probs.extend(padded_chunk_pdfs[:, 1:, ...])
                    cache_miss_lengths.extend(lengths - 1)
                else:
                    cache_miss_num_padded_bits.extend(chunked_num_padded_bits)
                    cache_miss_prompt_probs.extend(padded_chunk_pdfs)
                    cache_miss_lengths.extend(lengths)

            # Assemble pseudo-bytes for each miss in this chunk.
            for i in range(chunk_start, chunk_end):
                window_bytes = cache_misses[i]
                stop_step = cache_miss_stop_steps[i]
                _compressed_bytes = list(cache_miss_compressed_bytes[i])
                # Offset marks compressed pseudo-bytes apart from raw 0..255.
                compressed_bytes = [COMPRESSION_OFFSET + b for b in _compressed_bytes]
                if escape_first_byte:
                    compressed_bytes = list(window_bytes[0:1]) + compressed_bytes
                if stop_step == -1 or stop_step >= len(window_bytes):
                    cache_miss_compressed_results.append(compressed_bytes)
                else:
                    remaining_raw_bytes = list(window_bytes[stop_step:])
                    if iterative_compress and len(remaining_raw_bytes) > MINIMUM_SEGMENT_SIZE:
                        # Tuple marks "needs another compression pass".
                        cache_miss_compressed_results.append((remaining_raw_bytes, compressed_bytes))
                    else:
                        compressed_bytes = compressed_bytes + remaining_raw_bytes
                        cache_miss_compressed_results.append(compressed_bytes)

        if iterative_compress:
            # Collect the tails the encoder did not finish and re-compress
            # them with repeated model passes.
            incomplete_window_ids = []
            incomplete_window_remaining_bytes = []
            incomplete_window_compressed_bytes = []
            for i, compressed_bytes in enumerate(cache_miss_compressed_results):
                if isinstance(compressed_bytes, tuple):
                    incomplete_window_ids.append(i)
                    incomplete_window_remaining_bytes.append(compressed_bytes[0])
                    incomplete_window_compressed_bytes.append(compressed_bytes[1])

            remaining_compressed_bytes = iterative_compress_ac(
                incomplete_window_remaining_bytes,
                predict_fn,
                first_byte_prob,
                output_window_size,
                force_padding_to_threshold,
                ENCODING_BATCH_SIZE,
                debug
            )
            for i, remaining_compressed_b in enumerate(remaining_compressed_bytes):
                id_in_cache = incomplete_window_ids[i]
                final_compressed_bytes = incomplete_window_compressed_bytes[i] + remaining_compressed_b
                if debug:
                    # Sentinel: iterative results have no per-step debug data.
                    cache_miss_compressed_results[id_in_cache] = (final_compressed_bytes, "skip_debug")
                else:
                    cache_miss_compressed_results[id_in_cache] = final_compressed_bytes
            logger.info(f"[DEBUG] total remaining windows: {len(incomplete_window_ids)}")

        # Scatter miss results back to every duplicate of each segment.
        for idx, miss_indices in enumerate(cache_miss_indices):
            for orig_idx in miss_indices:
                if debug:
                    if isinstance(cache_miss_compressed_results[idx], tuple):
                        assert cache_miss_compressed_results[idx][1] == "skip_debug"
                        processed_batched_compressed_bytes[orig_idx] = cache_miss_compressed_results[idx][0]
                        batched_stop_steps[orig_idx] = None
                        batched_num_padded_bits[orig_idx] = None
                        batched_prompt_probs[orig_idx] = None
                        batched_lengths[orig_idx] = None
                    else:
                        processed_batched_compressed_bytes[orig_idx] = cache_miss_compressed_results[idx]
                        batched_stop_steps[orig_idx] = cache_miss_stop_steps[idx]
                        batched_num_padded_bits[orig_idx] = cache_miss_num_padded_bits[idx]
                        batched_prompt_probs[orig_idx] = cache_miss_prompt_probs[idx]
                        batched_lengths[orig_idx] = cache_miss_lengths[idx]
                else:
                    processed_batched_compressed_bytes[orig_idx] = cache_miss_compressed_results[idx]

        if compression_cache is not None:
            if debug:
                compression_cache.put_batch(
                    cache_misses,
                    zip(
                        cache_miss_compressed_results,
                        cache_miss_stop_steps,
                        cache_miss_num_padded_bits,
                        cache_miss_prompt_probs,
                        cache_miss_lengths
                    )
                )
            else:
                compression_cache.put_batch(cache_misses, cache_miss_compressed_results)

    B = len(batch)
    # Per-sample lengths of each segment's pseudo-byte run (needed to unpack
    # the concatenated stream later).
    pseudo_lens_per_segment = [[] for _ in range(B)]

    compressed_bytes = [[] for _ in range(B)]
    original_bytes = [[] for _ in range(B)]

    # Re-assemble segments in original per-sample order: compressed output for
    # effective segments, untouched bytes for raw ones.
    for sample_idx, list_segment_idx in sample_idx_to_list_segment_idx.items():
        for segment_idx in list_segment_idx:
            if segment_idx in effective_segments_idx_map:
                compressed_idx = effective_segments_idx_map[segment_idx]
                compressed_byte = processed_batched_compressed_bytes[compressed_idx]
            else:
                raw_idx = ineffective_segments_idx_map[segment_idx]
                compressed_byte = raw_segments[raw_idx]

            pseudo_lens_per_segment[sample_idx].append(len(compressed_byte))
            compressed_bytes[sample_idx].extend(list(compressed_byte))
            if debug:
                # Per-segment encode/decode round-trip verification.
                if segment_idx in effective_segments_idx_map:
                    compressed_idx = effective_segments_idx_map[segment_idx]
                    original_byte = sorted_segments[compressed_idx]

                    _debug_prompt_probs = batched_prompt_probs[compressed_idx]
                    _debug_padded_bits = batched_num_padded_bits[compressed_idx]
                    _debug_lengths = batched_lengths[compressed_idx]
                    _debug_stop_step = batched_stop_steps[compressed_idx]

                    # Iterative/cached-iterative results carry no debug data.
                    if _debug_prompt_probs is None:
                        original_bytes[sample_idx].append(original_byte)
                        continue

                    processed_compressed_byte = processed_batched_compressed_bytes[compressed_idx]

                    if escape_first_byte:
                        _debug_escaped_compressed_byte = processed_compressed_byte[1:]
                    else:
                        _debug_escaped_compressed_byte = processed_compressed_byte
                    if _debug_stop_step == -1 or _debug_stop_step >= len(original_byte):
                        _debug_compressed_byte = _debug_escaped_compressed_byte
                        _debug_raw_remaining_bytes = None
                        raw_bytes_len = None
                    else:
                        raw_bytes_len = len(original_byte[_debug_stop_step:])
                        _debug_compressed_byte = _debug_escaped_compressed_byte[:-raw_bytes_len]
                        _debug_raw_remaining_bytes = _debug_escaped_compressed_byte[-raw_bytes_len:]
                    # Undo the pseudo-byte offset before decoding.
                    _debug_compressed_byte = [b - COMPRESSION_OFFSET for b in _debug_compressed_byte]

                    print(f"##### _debug_pdfs is {_debug_prompt_probs.shape}")
                    print(f"##### _debug_padded is {_debug_padded_bits}")
                    print(f"##### _debug_compressed is {_debug_compressed_byte}")
                    print(f"##### _debug_lengths is {_debug_lengths}")
                    print(f"##### _debug_stop_step is {_debug_stop_step}")
                    print(f"##### _debug_raw_remaining_bytes is {_debug_raw_remaining_bytes}")
                    print(f"##### raw_bytes_len is {raw_bytes_len}")
                    print(f"##### original_byte len is {len(original_byte)}")
                    decoded = encoder.batched_decode(
                        _debug_prompt_probs.unsqueeze(0),
                        [_debug_compressed_byte],
                        [_debug_padded_bits],
                        _debug_lengths.unsqueeze(0)
                    )[0, :_debug_lengths.item()].cpu().tolist()

                    print(f"##### AC decoded is {decoded}")
                    if escape_first_byte:
                        decoded = processed_compressed_byte[0:1] + decoded

                        if _debug_stop_step < (_debug_lengths.item() + 1):
                            decoded = decoded[:_debug_stop_step]
                    else:
                        if _debug_stop_step < _debug_lengths.item():
                            decoded = decoded[:_debug_stop_step]
                    print(f"##### escape_first_byte decoded is {decoded}")
                    if _debug_raw_remaining_bytes:
                        decoded = decoded + _debug_raw_remaining_bytes
                    print(f"##### decoded is {decoded}")
                    print(f"##### original_byte is {list(original_byte)}")
                    assert bytes(decoded) == original_byte, "roundtrip encoding/decoding failed \n{} and \n{}".format(bytes(decoded), original_byte)
                    original_bytes[sample_idx].append(original_byte)
                else:
                    raw_idx = ineffective_segments_idx_map[segment_idx]
                    original_byte = raw_segments[raw_idx]
                    original_bytes[sample_idx].append(original_byte)

    if debug:
        # Verify the per-segment length metadata exactly tiles the stream and
        # that hole (raw) segments were copied through untouched.
        logger.info("Running internal self-verification test...")
        for i in range(B):
            item = batch[i]

            original_segments = unpack_windows(item["text"].encode('utf-8'), item["windows_starts_lens_b64"])
            generated_lens = pseudo_lens_per_segment[i]
            generated_pseudo_list = compressed_bytes[i]

            assert len(original_segments) == len(generated_lens), \
                f"Metadata length mismatch for sample {i}: segments={len(original_segments)}, lens={len(generated_lens)}"

            test_ptr = 0
            for j in range(len(original_segments)):
                raw_chunk, indicator = original_segments[j]
                segment_len = generated_lens[j]
                pseudo_slice = generated_pseudo_list[test_ptr : test_ptr + segment_len]

                if indicator == 0:
                    assert list(raw_chunk) == pseudo_slice, \
                        f"Hole content mismatch for sample {i}, segment {j}"

                test_ptr += segment_len

            assert test_ptr == len(generated_pseudo_list), \
                f"Total length mismatch for sample {i}: ptr_sum={test_ptr}, total_len={len(generated_pseudo_list)}"
        logger.info("✓ Internal self-verification test passed for all samples in the batch!")

    if debug:
        assert len(compressed_bytes) == len(batch)
        for sample_idx in range(len(batch)):
            assert b"".join(original_bytes[sample_idx]) == batch[sample_idx]["text"].encode('utf-8'), (
                "Assembled original bytes does not match the original batch: \n{} and \n{}".format(
                    b"".join(original_bytes[sample_idx]), batch[sample_idx]["text"].encode('utf-8')
                )
            )

        logger.info(f"Example compressed bytes: {compressed_bytes[0]}")

    write_results = []
    ac_key = f"m1_ac_ow{output_window_size}_escapefb-{escape_first_byte}_iterative-{iterative_compress}_forcepadding-{force_padding_to_threshold}"
    # BUGFIX: the per-sample index must come from this loop itself.  The
    # previous version did `item = batch[i]` / `pseudo_lens_per_segment[i]`
    # with a stale `i` left over from an earlier loop, attaching the wrong
    # sample's metadata to every result row.
    for i, (item, compressed_bytes_item) in enumerate(zip(batch, compressed_bytes)):
        compressed = pseudo_to_packed_bytes(compressed_bytes_item)

        result = {
            **item,
            ac_key: base64.b64encode(compressed).decode("ascii"),
            "pseudo_lens_per_segment": pseudo_lens_per_segment[i]
        }
        if debug:
            unpacked = packed_bytes_to_pseudo(compressed)
            assert unpacked == compressed_bytes_item, "Unpacked does not match compressed bytes item: \n{} and \n{}".format(unpacked, compressed_bytes_item)
            logger.info("✓ pseudo-bytes-enc-dec round-trip passes")
        write_results.append(result)
    orig_total_bytes = sum([len(data["text"].encode('utf-8')) for data in batch])
    compressed_total_bytes = sum([len(data) for data in compressed_bytes])
    compression_ratio = orig_total_bytes / compressed_total_bytes if compressed_total_bytes > 0 else 0
    logger.info(f"[DEBUG] original total bytes: {orig_total_bytes}, compressed total bytes: {compressed_total_bytes}, compression rate : {compression_ratio:.3f}")
    return write_results
|
|
def iterative_compress_ac(
    batch_windows: List[List[int]],
    predict_fn: Callable,
    first_byte_prob: torch.Tensor,
    output_window_size: int,
    force_padding_to_threshold: bool,
    max_m1_batch_size: int = 4096,
    debug: bool = False,
) -> List[bytes]:
    """
    Buffer-based compression pipeline that reads max_window_size from each file,
    performs batched compression, advances positions based on stop_steps, and repeats.

    Each round runs the entropy model over every window's still-uncompressed
    tail, arithmetic-codes it until the encoder's bit threshold stops it, then
    advances that window's cursor by the returned stop_step.  Windows whose
    remaining tail is at or below MINIMUM_SEGMENT_SIZE are finished; their
    leftover bytes are appended raw at the end.  Returns one pseudo-byte list
    per input window (compressed bytes shifted by COMPRESSION_OFFSET, then any
    raw tail).
    """
    if debug:
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        torch.cuda.synchronize()
        print("[Debug CUDA] time start", flush=True)
        original_total_bytes = sum([len(data) for data in batch_windows])
        print(f"[Debug] BufferBased-> Original total bytes: {original_total_bytes}", flush=True)
        print(f"[Debug] BufferBased-> Batch size: {len(batch_windows)}", flush=True)

    B = len(batch_windows)
    # Per-window cursor: bytes before it are already encoded.
    window_positions = [0] * B
    output_compressed_bytes = [[] for _ in range(B)]
    windows_done = [False] * B

    if debug:
        # Per-window, per-round debug artifacts for the round-trip check below.
        output_padded_bits = [[] for _ in range(B)]
        output_prompt_probs = [[] for _ in range(B)]
        output_lengths = [[] for _ in range(B)]

    iter_step = 0

    while not all(windows_done):
        iter_step += 1

        # Gather the uncompressed tail of every still-active window.
        current_windows = []
        active_file_indices = []

        for i in range(B):
            if windows_done[i]:
                continue

            start_pos = window_positions[i]
            end_pos = len(batch_windows[i])

            # Tails at/below MINIMUM_SEGMENT_SIZE stay raw (appended later).
            if start_pos >= len(batch_windows[i]) - MINIMUM_SEGMENT_SIZE:
                windows_done[i] = True
                continue
            window_bytes = batch_windows[i][start_pos:end_pos]
            current_windows.append(window_bytes)
            active_file_indices.append(i)

        if not current_windows:
            break

        # --- GPU prediction pass over the active tails ---
        start_idx = 0
        batched_windows_np = [np.array(data, dtype=np.uint8) for data in current_windows]
        current_windows_count = len(batched_windows_np)
        encoder = CPUArithmeticEncoder(
            base=ARITHMETIC_CODER_BASE,
            precision=ARITHMETIC_CODER_PRECISION
        )
        batched_compressed_bytes = []
        batched_stop_steps = []
        if debug:
            batched_num_padded_bits = []
            batched_pdfs = []
        _temp_cdf_ends = []
        _temp_lengths = []
        while start_idx < current_windows_count:
            # Chunk so longer windows get smaller model batches.
            start_idx, end_idx = find_next_batch_range(batched_windows_np, start_idx, max_m1_batch_size, get_batch_size_for_length)
            windows_np_chunked = batched_windows_np[start_idx:end_idx]
            padded_batched_windows, lengths = pad_batch(windows_np_chunked)

            padded_batched_windows = padded_batched_windows.cuda()
            with torch.no_grad():
                prompt_probs = predict_fn(padded_batched_windows)
                # Shift right by one: byte i is coded under the prediction
                # from bytes < i, with the fixed prior covering byte 0.
                prompt_probs = torch.cat(
                    [
                        first_byte_prob.expand(prompt_probs.shape[0], -1, -1),
                        prompt_probs[:, :-1, ...]
                    ],
                    dim=1
                )
                prompt_probs = utils.batched_normalize_pdf_for_arithmetic_coding(prompt_probs)
                cdfs_gpu = _pdf_to_cdf(prompt_probs)
                # CDF endpoints [F(b), F(b+1)) for each actual byte b.
                cdf_low = cdfs_gpu.gather(2, padded_batched_windows.unsqueeze(-1)).squeeze(-1)
                cdf_high = cdfs_gpu.gather(2, (padded_batched_windows + 1).unsqueeze(-1)).squeeze(-1)
                cdf_ends = torch.stack([cdf_low, cdf_high], dim=-1)

            start_idx = end_idx

            # Stash CDFs on CPU; encoding happens after all chunks predict.
            _temp_cdf_ends.append(cdf_ends.cpu())
            _temp_lengths.append(lengths)
            if debug:
                batched_pdfs.extend(prompt_probs.cpu())

        # --- CPU arithmetic-encoding pass over the collected chunks ---
        for cdf_ends, lengths in zip(_temp_cdf_ends, _temp_lengths):
            chunked_compressed_bytes, chunked_stop_steps, chunked_num_padded_bits = encoder.incremental_batched_encode(

                cdf_ends,
                ALPHABET_SIZE,
                lengths,
                bit_threshold=output_window_size,
                force_padding_to_threshold=force_padding_to_threshold,
                return_num_padded_bits=True
            )
            batched_compressed_bytes.extend(chunked_compressed_bytes)
            batched_stop_steps.extend(chunked_stop_steps)
            if debug:
                batched_num_padded_bits.extend(chunked_num_padded_bits)

        # Advance each active window by how far its encoder actually got.
        for window_idx, file_idx in enumerate(active_file_indices):
            compressed_bytes = batched_compressed_bytes[window_idx]
            stop_step = batched_stop_steps[window_idx]

            output_compressed_bytes[file_idx].append(compressed_bytes)
            if debug:
                output_padded_bits[file_idx].append(batched_num_padded_bits[window_idx])
                output_prompt_probs[file_idx].append(batched_pdfs[window_idx])
                length = torch.tensor([stop_step], dtype=torch.long, device=batched_pdfs[window_idx].device)
                output_lengths[file_idx].append(length)

            window_positions[file_idx] += stop_step
            if window_positions[file_idx] >= len(batch_windows[file_idx]) - MINIMUM_SEGMENT_SIZE:
                windows_done[file_idx] = True

    # Concatenate per-round compressed pieces, apply the pseudo-byte offset,
    # and append any raw leftover tail.
    final_compressed = []
    for i in range(B):
        _original_byte_window = batch_windows[i]
        _stopped_position = window_positions[i]

        # NOTE(review): assumes each per-round piece is a bytes object so
        # b''.join works — confirm against CPUArithmeticEncoder's return type.
        _byte_array = b''.join(output_compressed_bytes[i])
        offset_compressed_bytes = [b + COMPRESSION_OFFSET for b in list(_byte_array)]
        if _stopped_position < len(_original_byte_window):
            raw_leftover_bytes = _original_byte_window[_stopped_position:]
            offset_compressed_bytes = offset_compressed_bytes + list(raw_leftover_bytes)

        final_compressed.append(offset_compressed_bytes)
    if debug:
        end_event.record()
        torch.cuda.synchronize()
        elapsed_time = start_event.elapsed_time(end_event)
        print(f"[Debug CUDA] Elapsed time: {elapsed_time:.3f}ms", flush=True)
        encoder = CPUArithmeticEncoder(
            base=ARITHMETIC_CODER_BASE,
            precision=ARITHMETIC_CODER_PRECISION
        )
        # Round-trip check: decode every per-round piece and compare against
        # the bytes actually consumed (up to the final stop position).
        # NOTE(review): the equality assumes batch_windows items are lists of
        # ints (a bytes slice would never equal a list) — confirm callers.
        for (
            output_compressed_bytes_item,
            output_padded_bits_item,
            output_prompt_probs_item,
            output_lengths_item,
            batch_windows_item,
            stopped_position
        ) in zip(output_compressed_bytes, output_padded_bits, output_prompt_probs, output_lengths, batch_windows, window_positions):
            original_bytes = batch_windows_item[:stopped_position]
            decoded_bytes = []
            for (
                _debug_compressed,
                _debug_padded,
                _debug_pdfs,
                _debug_lengths
            ) in zip(
                output_compressed_bytes_item,
                output_padded_bits_item,
                output_prompt_probs_item,
                output_lengths_item
            ):
                print(f"##### _debug_pdfs is {_debug_pdfs.shape}")
                print(f"##### _debug_padded is {_debug_padded}")
                print(f"##### _debug_compressed is {_debug_compressed}")
                print(f"##### _debug_lengths is {_debug_lengths}")
                print(f"##### original_bytes is {original_bytes}")
                decoded = encoder.batched_decode(_debug_pdfs.unsqueeze(0), [_debug_compressed], [_debug_padded], _debug_lengths)
                decoded_bytes += decoded[0, :_debug_lengths.item()].cpu().tolist()
                print(f"##### decoded is {bytes(decoded[0, :_debug_lengths.item()].cpu().tolist())}")
            assert decoded_bytes == original_bytes, "roundtrip encoding/decoding failed \n{} and \n{}".format(decoded_bytes, original_bytes)

    return final_compressed
|
|
def writer_consumer(
    write_queue,
    output_file,
    buffer_size=100,
    debug=False,
    output_window_size=16,
    escape_first_byte=False,
    compression_cache_size=819200,
    iterative_compress=False,
    force_padding_to_threshold=False,
    entropy_model_path=None,
    firstbyte_prob_path=None,
    num_workers=None,
):
    """
    Writer consumer: reads compressed results from write_queue and writes to file.
    Maintains its own buffer and writes when buffer is full or receives sentinel.

    Pulls `segment_prediction_fn` output tuples off `write_queue`, runs the CPU
    compression stage (`segment_compression_fn`), and appends the resulting
    JSON rows to `output_file`.  A `None` item on the queue signals shutdown;
    any buffered rows are flushed before returning.
    """
    if num_workers is not None:
        # Pin this writer process to one intra-op thread so multiple writer
        # processes do not oversubscribe the CPU.
        # NOTE(review): the previous thread count is saved but never restored.
        num_threads = torch.get_num_threads()

        new_num_threads = 1
        torch.set_num_threads(new_num_threads)
        logger.info(f"[Debug] Set num threads to {new_num_threads} for writer process {mp.current_process().name}")
    write_buf = []

    # Optional per-writer cache of finished compressions (size <= 0 disables).
    compression_cache = SegmentCache(cache_size=compression_cache_size, cache_desc="Compression") if compression_cache_size > 0 else None

    if iterative_compress:
        # The iterative path needs its own model handle to re-compress tails.
        model, _, _ = load_m1_model_and_tokenizer(entropy_model_path)
        predict_fn = batched_m1_compress_predict_fn(model)

        if firstbyte_prob_path is not None:
            # Empirical first-byte distribution from JSON -> shape (1, 1, K).
            with open(firstbyte_prob_path, 'r', encoding='utf-8') as f:
                first_byte_prob = json.load(f)
            print(first_byte_prob)
            first_byte_prob = torch.tensor(first_byte_prob, dtype=torch.float32, device="cuda").unsqueeze(0).unsqueeze(0)
        else:
            # Uniform prior over the alphabet when no file is given.
            first_byte_prob = torch.ones((1, 1, ALPHABET_SIZE), dtype=torch.float32, device="cuda") / ALPHABET_SIZE
    else:
        predict_fn = None
        first_byte_prob = None

    try:
        with open(output_file, 'w', encoding='utf-8') as f:
            while True:
                args = write_queue.get()
                # None is the shutdown sentinel posted by the producer.
                if args is None:
                    break
                (
                    batch,
                    sorted_segments,
                    raw_segments,
                    effective_segments_idx_map,
                    ineffective_segments_idx_map,
                    sample_idx_to_list_segment_idx,
                    batched_cdf_ends,
                    batched_pdfs,
                ) = args
                write_results = segment_compression_fn(
                    batch,
                    sorted_segments,
                    raw_segments,
                    effective_segments_idx_map,
                    ineffective_segments_idx_map,
                    sample_idx_to_list_segment_idx,
                    batched_cdf_ends,
                    batched_pdfs,
                    output_window_size,
                    escape_first_byte,
                    iterative_compress,
                    force_padding_to_threshold,
                    predict_fn,
                    first_byte_prob,
                    debug=debug,
                    compression_cache=compression_cache
                )
                write_buf.extend(write_results)

                # Flush to disk once enough rows have accumulated.
                if len(write_buf) >= buffer_size:
                    logger.info(f"Writer: Dumping buffer of {len(write_buf)} items to {output_file}")
                    for buffered_item in write_buf:
                        f.write(json.dumps(buffered_item) + '\n')
                    f.flush()
                    write_buf = []

                # Reclaim host and GPU memory between batches.
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            # Drain whatever is still buffered after the sentinel.
            if write_buf:
                logger.info(f"Writer: Dumping remaining {len(write_buf)} items to {output_file}")
                for buffered_item in write_buf:
                    f.write(json.dumps(buffered_item) + '\n')
                f.flush()

    except Exception as e:
        logger.error(f"Writer process error: {e}")
        raise
|
|
def merge_output_files(output_file, writer_output_files):
    """Merge all per-writer output files into a single file.

    Writer files are concatenated in the order given; each source file is
    deleted once its contents have been copied. Missing writer files are
    skipped with a warning — a writer may legitimately have produced no
    output, but a missing file can also indicate a crashed worker, so the
    skip should not be silent.

    Args:
        output_file: destination path for the merged JSONL output.
        writer_output_files: sequence of ``Path`` objects, one per writer.

    Returns:
        The ``output_file`` path that was written.
    """
    logger.info(f"Merging {len(writer_output_files)} writer files into {output_file}")

    with open(output_file, 'w', encoding='utf-8') as outf:
        for writer_output_file in writer_output_files:
            if writer_output_file.exists():
                with open(writer_output_file, 'r', encoding='utf-8') as inf:
                    for line in inf:
                        outf.write(line)

                writer_output_file.unlink()
                logger.info(f"Merged and removed writer file: {writer_output_file}")
            else:
                # Surface missing files instead of skipping silently.
                logger.warning(f"Writer file not found, skipping: {writer_output_file}")

    logger.info(f"Merged output written to: {output_file}")
    return output_file
|
|
def shutdown_writers(write_queue, writer_processes):
    """Signal every writer to stop, then block until each one has exited.

    One ``None`` sentinel is enqueued per writer process (each consumer
    exits after reading a single sentinel), after which every process is
    joined and its exit status logged.
    """
    total = len(writer_processes)
    for sent, _ in enumerate(writer_processes, start=1):
        write_queue.put(None)
        logger.info(f"Sent shutdown signal {sent}/{total}")

    for idx, proc in enumerate(writer_processes):
        proc.join()
        if proc.exitcode == 0:
            logger.info(f"Writer process {idx} completed successfully")
        else:
            logger.error(f"Writer process {idx} failed with exit code: {proc.exitcode}")
|
|
|
|
def main():
    """CLI entry point for offline M1 arithmetic compression.

    Pipeline: the main process loads the entropy model, streams batches from
    the input JSONL shard, runs prediction, and pushes the results onto a
    bounded multiprocessing queue. ``--num_workers`` writer processes consume
    the queue, run the compression step, and each write their own
    ``*_writer_<i>.jsonl`` file (optionally merged into one file when
    ``--merge_output`` is set).
    """
    parser = argparse.ArgumentParser(description='Process JSONL files using M1 arithmetic compression with buffer-based approach')
    parser.add_argument('--input_file', type=str, required=True,
                        help='Input JSONL file to process')
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to write compressed results')
    parser.add_argument('--entropy_model_path', type=str, required=True,
                        help='Path to the M1 entropy model checkpoint')
    parser.add_argument('--compression_model_path', type=str, required=True,
                        help='Path to the M1 compression model checkpoint')
    parser.add_argument('--data_batch_size', type=int, default=512,
                        help='Size of batches for processing (default: 512)')
    parser.add_argument('--output_window_size', type=int, default=16,
                        help='Size of window for compression (default: 16)')
    parser.add_argument('--escape_first_byte', action='store_true', default=False,
                        help='Escape the first byte of each window (default: False)')
    parser.add_argument('--max_window_size', type=int, default=1024,
                        help='Maximum window size for reading from each file (default: 1024)')
    parser.add_argument('--max_entropy_batch_size', type=int, default=4096,
                        help='Size of max batch for entropy prediction (default: 4096)')
    parser.add_argument('--max_compression_batch_size', type=int, default=4096,
                        help='Size of max batch for compression (default: 4096)')
    parser.add_argument('--chunk_size', type=int, default=512,
                        help='Size of chunk for compression (default: 512)')
    parser.add_argument('--base_global_quantile', type=float, default=0.9,
                        help='Base global quantile for compression (default: 0.9)')
    parser.add_argument('--base_monotonic_quantile', type=float, default=0.9,
                        help='Base monotonic quantile for compression (default: 0.9)')
    parser.add_argument('--debug', action='store_true', default=False,
                        help='Debug mode (default: False)')
    parser.add_argument('--firstbyte_prob_path', type=str, default=None,
                        help='Probability path for the first word of each window (default: None)')
    parser.add_argument('--num_workers', type=int, default=1,
                        help='Number of workers for CPU jobs (default: 1)')
    parser.add_argument('--process_id', type=int, default=0,
                        help='Process ID for distributed processing (default: 0)')
    parser.add_argument('--num_processes', type=int, default=1,
                        help='Number of processes for distributed processing (default: 1)')
    parser.add_argument('--merge_output', action='store_true', default=False,
                        help='Merge all writer output files into a single file (default: False)')
    parser.add_argument('--prediction_cache_size', type=int, default=81920,
                        help='Size of prediction cache per process (default: 81920)')
    parser.add_argument('--compression_cache_size', type=int, default=81920,
                        help='Size of compression cache per worker (default: 81920)')
    parser.add_argument('--disable_caching', action='store_true', default=False,
                        help='Disable both prediction and compression caching (default: False)')
    parser.add_argument('--iterative_compress', action='store_true', default=False,
                        help='Iterative compression (default: False)')
    parser.add_argument('--force_padding_to_threshold', action='store_true', default=False,
                        help='Force padding to threshold (default: False)')

    # NOTE(review): --compression_model_path, --max_window_size,
    # --max_entropy_batch_size, --chunk_size and the two quantile options are
    # parsed but never referenced in this function; kept for CLI
    # compatibility — confirm whether downstream tooling relies on them.
    args = parser.parse_args()

    # Cap main-process torch threads so writer processes get CPU time.
    new_num_threads = 2
    torch.set_num_threads(new_num_threads)
    logger.info(f"[Debug] Set num threads to {new_num_threads} for main process")

    # 'spawn' so each writer process starts with a clean (CUDA-safe) state.
    mp.set_start_method('spawn', force=True)
    dump_freq = 100

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    model, _, _ = load_m1_model_and_tokenizer(args.entropy_model_path)
    batched_predict_fn = batched_m1_compress_predict_fn(model)

    if args.firstbyte_prob_path is not None:
        with open(args.firstbyte_prob_path, 'r', encoding='utf-8') as f:
            first_byte_prob = json.load(f)
        print(first_byte_prob)
        first_byte_prob = torch.tensor(first_byte_prob, dtype=torch.float32, device="cuda").unsqueeze(0).unsqueeze(0)
    else:
        # Uniform prior over the alphabet when no table is provided.
        first_byte_prob = torch.ones((1, 1, ALPHABET_SIZE), dtype=torch.float32, device="cuda") / ALPHABET_SIZE

    # Each distributed process reads its own interleaved shard of the file.
    dataset = InterleavedJsonlDataset(
        file_path=args.input_file,
        rank=args.process_id,
        world_size=args.num_processes,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=args.data_batch_size,
        shuffle=False,
        collate_fn=lambda x: x
    )

    input_file = Path(args.input_file)
    logger.info(f"Processing file: {input_file}")

    output_file = output_dir / f"{input_file.stem}_out_{args.process_id}.jsonl"

    logger.info("Data loaded. Start processing...")

    prediction_cache = None
    if not args.disable_caching and args.prediction_cache_size > 0:
        prediction_cache = SegmentCache(cache_size=args.prediction_cache_size, cache_desc="Prediction")
        logger.info(f"Prediction cache enabled with size: {args.prediction_cache_size}")
    else:
        logger.info("Prediction cache disabled")

    compression_cache_size = 0 if args.disable_caching else args.compression_cache_size
    if compression_cache_size > 0:
        logger.info(f"Compression cache enabled with size: {compression_cache_size} per worker")
    else:
        logger.info("Compression cache disabled")

    # Bounded queue provides backpressure when writers fall behind.
    write_queue = mp.Queue(maxsize=200)
    writer_processes = []
    writer_output_files = []
    for i in range(args.num_workers):
        output_path = Path(output_file)
        writer_output_file = output_path.parent / f"{output_path.stem}_writer_{i}.jsonl"
        writer_output_files.append(writer_output_file)
        writer_process = mp.Process(
            target=writer_consumer,
            args=(
                write_queue,
                writer_output_file,
                dump_freq,
                args.debug,
                args.output_window_size,
                args.escape_first_byte,
                compression_cache_size,
                args.iterative_compress,
                args.force_padding_to_threshold,
                args.entropy_model_path,
                args.firstbyte_prob_path,
                args.num_workers,
            )
        )
        writer_processes.append(writer_process)
        writer_process.start()
        logger.info(f"Started writer process {i} for output file: {writer_output_file}")

    try:
        for batch_idx, batch in enumerate(dataloader):
            pred_results = segment_prediction_fn(
                batch,
                max_m1_batch_size=args.max_compression_batch_size,
                batched_predict_fn=batched_predict_fn,
                first_byte_prob=first_byte_prob,
                debug=args.debug,
                prediction_cache=prediction_cache
            )
            logger.info(f"Processed batch {batch_idx}")
            write_queue.put(pred_results)

            # Periodic memory reclamation to keep peak usage bounded.
            if batch_idx % GC_FREQ == 0:
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

        shutdown_writers(write_queue, writer_processes)

    except Exception as e:
        logger.error(f"Error during processing: {e}")
        # Best-effort shutdown so writer processes do not hang forever.
        try:
            shutdown_writers(write_queue, writer_processes)
        except Exception:
            pass
        raise

    if args.merge_output:
        final_output_file = merge_output_files(output_file, writer_output_files)
        logger.info(f"Completed processing successfully, merged output written to {final_output_file}")
    else:
        logger.info(f"Completed processing successfully, outputs written to {args.num_workers} separate files")
|
|
# Script entry point: run the full prediction + compression pipeline.
if __name__ == "__main__":
    main()
|
|