| """ |
| new pipeline: |
| 1.prepare data: from original batch to unpack,sort,rerank windows and reconstruct indexes |
| 2.compress_segment_xxx(): use different compression algorithm |
| 3.reconstruct_result: reconstruct the window from compressed results and idx |
| 4.main()--- use prepare_segments and compress_seg_xx to produce and pass them to consumers |
| 5.write_consumer: get from compressed data, reconstruct and write the result |
| """ |
| import torch |
| import torch.nn.functional as F |
| from torch.utils.data import IterableDataset, Dataset, DataLoader |
| import json |
| import numpy as np |
| from pathlib import Path |
| from typing import Iterator, List, Dict, Any, Callable, Tuple, Optional |
| import logging |
| import argparse |
| import base64 |
| import time |
| import math |
| import gc |
| from collections import defaultdict, Counter,deque |
| from m1_compression.utils import * |
| from m1_compression.compressor import ( |
| load_m1_model_and_tokenizer, |
| ALPHABET_SIZE, |
| ) |
| import multiprocessing as mp |
| from m1_compression.enumerative_coder_simple import SimpleAdaptiveRankCodec |
| from m1_compression.batched_arithmetic_coder import BatchedArithmeticEncoder |
| from m1_compression.hybrid_arithmetic_coder import HybridArithmeticEncoder |
| from m1_compression.compressor import ( |
| load_m1_model_and_tokenizer, |
| ALPHABET_SIZE, |
| ARITHMETIC_CODER_BASE, |
| ARITHMETIC_CODER_PRECISION, |
| ) |
| from offline_entropy_window_split import unpack_windows |
|
|
| logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') |
| logger = logging.getLogger() |
|
|
def pseudo_to_packed_bytes(lst: list[int]) -> bytes:
    """Pack a list of 9-bit values (0..511) into a little-endian bit stream.

    Each value contributes exactly 9 bits; any leftover bits at the end are
    flushed into one final byte.
    """
    packed = bytearray()
    accumulator = 0
    bit_count = 0
    for value in lst:
        accumulator |= (value & 0x1FF) << bit_count
        bit_count += 9
        while bit_count >= 8:
            packed.append(accumulator & 0xFF)
            accumulator >>= 8
            bit_count -= 8
    if bit_count:
        # Flush the remaining (< 8) bits.
        packed.append(accumulator)
    return bytes(packed)
|
|
def packed_bytes_to_pseudo(b: bytes) -> list[int]:
    """Inverse of pseudo_to_packed_bytes: unpack 9-bit values from bytes.

    Trailing filler bits (fewer than 9) are discarded.
    """
    values: list[int] = []
    accumulator = 0
    bit_count = 0
    for byte_val in b:
        accumulator |= byte_val << bit_count
        bit_count += 8
        while bit_count >= 9:
            values.append(accumulator & 0x1FF)
            accumulator >>= 9
            bit_count -= 9
    return values
|
|
def calculate_compression_ratio(original_bytes: List[bytes], compressed_segments: List[bytes]) -> float:
    """Return the compressed/original size ratio (values < 1.0 mean compression helped).

    Args:
        original_bytes: The uncompressed segments.
        compressed_segments: The corresponding compressed segments.

    Returns:
        Total compressed length divided by total original length. Returns a
        neutral 1.0 when either list is empty, or when the original segments
        have zero total length (previously this raised ZeroDivisionError for
        a non-empty list of all-empty segments).
    """
    if not compressed_segments or not original_bytes:
        return 1.0

    total_original_length = sum(len(orig_seg) for orig_seg in original_bytes)
    if total_original_length == 0:
        # All original segments are empty: the ratio is undefined, treat as neutral.
        return 1.0

    total_compressed_length = sum(len(compressed_seg) for compressed_seg in compressed_segments)
    ratio = total_compressed_length / total_original_length
    if ratio > 2.0:
        logger.warning(f"Unusual compression ratio: {ratio:.4f} (compressed larger than original)")

    return ratio
|
|
def collect_window_size_statistics(segmented_results: List[List[bytes]]) -> Dict[int, int]:
    """Build a histogram mapping window length -> number of windows of that length."""
    size_histogram = Counter(
        len(segment)
        for segments in segmented_results
        for segment in segments
    )
    return dict(size_histogram)
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
def pad_batch(batch: List[bytes]):
    """Right-pad a batch of byte sequences into a single int64 tensor.

    Returns:
        (padded, lengths): `padded` is a (batch, max_len) tensor padded with
        zeros; `lengths` holds each row's true (unpadded) length.
    """
    row_tensors = [torch.tensor(list(item), dtype=torch.int64) for item in batch]
    true_lengths = torch.tensor([len(item) for item in batch], dtype=torch.int64)

    padded = torch.nn.utils.rnn.pad_sequence(
        row_tensors,
        batch_first=True,
        padding_value=0,
    )
    return padded, true_lengths
|
|
| |
def get_batch_size_for_length(window_len, max_batch_size):
    """
    Determines the batch size for a given window length.
    VERY AGGRESSIVE reduction for long sequences to prevent OOM.
    """
    # (upper length bound, divisor) tiers; fixed sizes handled below.
    for bound, divisor in ((128, 1), (256, 4), (512, 16), (1024, 64)):
        if window_len <= bound:
            return max(max_batch_size // divisor, 1)
    if window_len <= 2048:
        return 2
    # Anything longer is processed one window at a time.
    return 1
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
def find_next_batch_range(all_windows, start_idx, max_m1_batch_size):
    """Find the next [start, end) range of windows that share one batch-size tier.

    `all_windows` is expected to be sorted by length. Starting at `start_idx`,
    at most one tier's worth of windows is taken; a binary search then trims
    the range at the first window whose length falls into a different tier.
    Returns an empty range when start_idx is past the end, and always at
    least one window otherwise so the caller makes progress.
    """
    total = len(all_windows)
    if start_idx >= total:
        return start_idx, start_idx

    tier_size = get_batch_size_for_length(len(all_windows[start_idx]), max_m1_batch_size)
    limit = min(start_idx + tier_size, total)

    # Fast path: the last candidate still belongs to the same tier.
    if get_batch_size_for_length(len(all_windows[limit - 1]), max_m1_batch_size) == tier_size:
        return start_idx, limit

    # Bisect for the first index in [start_idx, limit) whose tier differs.
    lo, hi = start_idx, limit
    while lo < hi:
        mid = (lo + hi) // 2
        if get_batch_size_for_length(len(all_windows[mid]), max_m1_batch_size) == tier_size:
            lo = mid + 1
        else:
            hi = mid

    # Guarantee forward progress even in the degenerate case.
    return (start_idx, lo) if lo > start_idx else (start_idx, start_idx + 1)
|
|
class JsonlShardedDataset(Dataset):
    """Map-style dataset that loads a JSONL file and keeps one contiguous shard.

    The file is split into `total_procs` contiguous chunks of equal (ceiling)
    size; this instance retains only the chunk for `current_proc_rank`.
    """

    def __init__(
        self,
        file_path: str,
        current_proc_rank: int = 0,
        total_procs: int = 1,
    ) -> None:
        assert 0 <= current_proc_rank < total_procs, "rank must be in [0, world_size)"
        self.current_proc_rank = current_proc_rank
        self.total_procs = total_procs

        # Load the entire file eagerly; each line is one JSON record.
        with open(file_path, "r", encoding="utf-8") as handle:
            records: List[Dict[str, Any]] = [json.loads(row) for row in handle]

        # Contiguous sharding: ceil-sized chunk per process.
        shard_len = math.ceil(len(records) / total_procs)
        shard_start = current_proc_rank * shard_len
        shard_end = min(shard_start + shard_len, len(records))
        self.data = records[shard_start:shard_end]

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        return self.data[idx]
|
|
class InterleavedJsonlDataset(IterableDataset):
    """
    An iterable-style dataset for reading a large JSONL file using an
    interleaving/striding pattern, without yielding state information.

    This is designed for multi-process data loading. Each process reads the
    entire file but only processes lines that match its rank (offset).
    For `N` total processes (world_size), process `r` (rank) will read
    lines r, r+N, r+2N, ... (0-indexed).

    This method ensures an even distribution of lines across processes.

    Args:
        file_path (str): Path to the JSONL file.
        rank (int): The rank of the current process, used as the offset.
        world_size (int): The total number of processes, used as the block_size/stride.

    Raises:
        ValueError: If `rank` is not in [0, world_size).
    """
    def __init__(
        self,
        file_path: str,
        rank: int,
        world_size: int,
    ) -> None:
        super().__init__()

        if not (0 <= rank < world_size):
            raise ValueError(f"Rank must be in [0, {world_size-1}], but got {rank}")

        self.file_path = file_path
        self.offset = rank
        self.block_size = world_size

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        """
        The iterator method that yields the parsed JSON data for the assigned lines.

        Malformed JSON lines are skipped with a warning; any other error is
        logged and re-raised.
        """
        try:
            with open(self.file_path, "r", encoding="utf-8") as f:
                for line_number, line in enumerate(f):
                    # Stride pattern: this worker owns every block_size-th line.
                    if (line_number % self.block_size) == self.offset:
                        try:
                            yield json.loads(line)
                        except json.JSONDecodeError:
                            # Use the module logger instead of print() for
                            # consistency with the rest of the file.
                            logger.warning(f"Rank {self.offset} could not decode JSON on line ~{line_number+1}. Skipping.")
                            continue
        except Exception as e:
            logger.error(f"Error in worker {self.offset}: {e}")
            raise
|
|
def batched_m1_compress_predict_fn(model):
    """Wrap `model` into a predict_fn that returns next-byte probabilities.

    The returned callable promotes 1-D inputs to a batch of one, restricts
    the vocabulary to the 256 byte values, asserts the logits are finite,
    and softmaxes them under no_grad.
    """
    def predict_fn(input_tensor: torch.Tensor, **kwargs) -> torch.Tensor:
        batch_input = input_tensor.unsqueeze(0) if input_tensor.dim() == 1 else input_tensor
        with torch.no_grad():
            byte_logits = model(batch_input, **kwargs)[..., :256].float()
            assert torch.isfinite(byte_logits).all(), "Logits contain NaN or Inf values."
            probs = torch.softmax(byte_logits, dim=-1)
        return probs

    return predict_fn
|
|
class CachingCompressorWrapper:
    """FIFO-cached wrapper around a batch compression function.

    Deduplicates incoming segments, serves repeated segments from an
    in-memory cache, and forwards only the cache misses to the wrapped
    compressor. Results are fanned back out to every original position.
    """

    def __init__(
        self,
        base_compression_fn: Callable,
        cache_size: int = 819200,
        cache_policy: str = 'fifo'
    ):
        if cache_policy not in ['fifo']:
            raise ValueError(f"no caching policy: {cache_policy}.")
        self.base_compression_fn = base_compression_fn
        self.cache_size = cache_size
        self.cache_policy = cache_policy

        # segment bytes -> compressed pseudo-byte list
        self.cache: Dict[bytes, List[int]] = {}
        # insertion order for FIFO eviction
        self.fifo_queue: deque[bytes] = deque()
        logger.info(f"Create CachingCompressorWrapper '{self.base_compression_fn.__name__}',"
                    f"Cache size: {self.cache_size}, policy: {self.cache_policy}")

    def compress(
        self,
        sorted_segments: List[bytes],
        *args, **kwargs
    ) -> List[List[int]]:
        """
        compressors with cache
        """
        if not sorted_segments:
            return []

        # Group duplicate payloads so each distinct segment is handled once.
        positions_by_segment = defaultdict(list)
        for position, payload in enumerate(sorted_segments):
            positions_by_segment[payload].append(position)
        distinct_segments = list(positions_by_segment.keys())

        pending: List[bytes] = []
        outcome_by_segment: Dict[bytes, List[int]] = {}

        for payload in distinct_segments:
            if payload in self.cache:
                outcome_by_segment[payload] = self.cache[payload]
            else:
                pending.append(payload)
        hit_count = len(distinct_segments) - len(pending)
        logger.info(f"Cache checking: {len(distinct_segments)} segments, "
                    f"Get {hit_count}, No caching {len(pending)} ")

        if pending:
            freshly_compressed = self.base_compression_fn(
                pending, *args, **kwargs
            )
            for payload, packed in zip(pending, freshly_compressed):
                outcome_by_segment[payload] = packed

                if self.cache_size > 0 and payload not in self.cache:
                    # Evict the oldest entry when the cache is full.
                    if len(self.cache) >= self.cache_size:
                        if self.cache_policy == 'fifo':
                            evicted_key = self.fifo_queue.popleft()
                            del self.cache[evicted_key]
                    self.cache[payload] = packed
                    self.fifo_queue.append(payload)

        # Fan results back out to every original position.
        fan_out = [None] * len(sorted_segments)
        for payload, positions in positions_by_segment.items():
            packed = outcome_by_segment[payload]
            for position in positions:
                fan_out[position] = packed
        return fan_out

    def __call__(self, *args, **kwargs):
        return self.compress(*args, **kwargs)
|
|
def compress_segments_hybrid_arithmetic(
    sorted_segments: List[bytes],
    batched_predict_fn: Callable,
    first_byte_prob: torch.Tensor,
    max_m1_batch_size: int=4096,
    debug: bool = True
) -> List[List[int]]:
    """
    Compress segments with the HybridArithmeticEncoder, batch by batch.

    This function only processes the data it receives; caching and
    deduplication are expected to be handled by the outer
    CachingCompressorWrapper.

    NOTE(review): this definition is shadowed by a later function of the
    same name in this module — at import time the later definition wins,
    making this one effectively dead code. Confirm which variant is
    intended and remove the other.
    """
    M = len(sorted_segments)
    if M == 0:
        return []

    logger.info(f"Hybrid AC 核心: 正在处理 {M} 个不重复、未命中缓存的段。")
    # NOTE(review): assigned but never used in this variant — likely left
    # over from the deduplicating version defined later in the file.
    segment_to_compressed = {}
    ENCODING_BATCH_SIZE = 128
    encoder = HybridArithmeticEncoder(
        batched_predict_fn=batched_predict_fn,
        first_byte_prob=first_byte_prob
    )
    all_compressed_results = []
    for i in range(0, M, ENCODING_BATCH_SIZE):
        batch_start = i
        batch_end = min(i + ENCODING_BATCH_SIZE, M)
        batch_segments = sorted_segments[batch_start:batch_end]
        try:
            codes = encoder.batched_encode(batch_segments, return_num_padded_bits=False)
            # Keep whichever representation is smaller: compressed code or raw bytes.
            for seg, code in zip(batch_segments, codes):
                if len(code) < len(seg):
                    all_compressed_results.append(list(code))
                else:
                    all_compressed_results.append(list(seg))
        except Exception as e:
            # Best-effort fallback: on any failure, emit raw bytes for the whole batch.
            logger.warning(f"Hybrid AC 核心: 批次 {batch_start}-{batch_end} 编码失败: {e}. 该批次使用原始字节。")
            for seg in batch_segments:
                all_compressed_results.append(list(seg))
        # Release GPU memory between batches to keep peak usage bounded.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return all_compressed_results
|
|
def prepare_segments(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Unpack each sample's entropy windows and stage them for compression.

    Each window is routed to either the "compress" pool (indicator == 1 and
    length > 3) or the raw pass-through pool. The compressible pool is then
    sorted by length, and all index maps needed to rebuild each sample later
    are recorded in `reconstruction_info`.
    """
    every_segment: List[bytes] = []
    compressible_flags: List[bool] = []
    sample_idx_to_list_segment_idx = defaultdict(list)

    running_idx = 0
    for sample_idx, item in enumerate(batch):
        assert "windows_starts_lens_b64" in item, "windows_starts_lens_b64 must be in item"
        raw_sample = item["text"].encode('utf-8')
        # Unpack windows and decide compressibility in one pass.
        for window, indicator in unpack_windows(raw_sample, item["windows_starts_lens_b64"]):
            every_segment.append(window)
            compressible_flags.append(indicator == 1 and len(window) > 3)
            sample_idx_to_list_segment_idx[sample_idx].append(running_idx)
            running_idx += 1

    # Partition by compressibility, keyed by original segment index.
    effective_segments: Dict[int, bytes] = {}
    raw_segments_map: Dict[int, bytes] = {}
    for idx, (window, flag) in enumerate(zip(every_segment, compressible_flags)):
        (effective_segments if flag else raw_segments_map)[idx] = window

    # Length-ascending order lets downstream batching group similar sizes.
    ordered_original_indices = sorted(
        effective_segments,
        key=lambda original: len(effective_segments[original])
    )
    sorted_segments_to_compress = [effective_segments[original] for original in ordered_original_indices]

    # Map sorted position -> original segment index for later reconstruction.
    sorted_to_original_idx_map = dict(enumerate(ordered_original_indices))

    reconstruction_info = {
        "sample_idx_to_list_segment_idx": sample_idx_to_list_segment_idx,
        "sorted_to_original_idx_map": sorted_to_original_idx_map,
        "raw_segments_map": raw_segments_map,
        "total_segments": len(every_segment),
        "batch_meta": batch,
        "effective_segments_map": effective_segments
    }

    return {
        "sorted_segments_to_compress": sorted_segments_to_compress,
        "reconstruction_info": reconstruction_info,
    }
|
|
| |
def simple_rle_topk_compression(
    batch: List[bytes],
    predict_fn: Callable,
    first_byte_prob: torch.Tensor,
    max_m1_batch_size: int = 4096,
    debug: bool = True,
):
    """GPU stage for rank-based compression.

    For every window in `batch`, runs `predict_fn` to obtain per-position
    next-byte probability distributions (shifted right by one position, with
    `first_byte_prob` prepended for position 0), then extracts per byte:
    the probability the model assigned to the actual byte, and the rank of
    the actual byte in the descending-sorted distribution.

    Args:
        batch: Byte sequences; expected sorted by length so that
            find_next_batch_range can group them into same-tier GPU batches.
        predict_fn: Batched callable returning per-position probabilities.
        first_byte_prob: Prior of shape (1, 1, ALPHABET_SIZE) for position 0.
        max_m1_batch_size: Cap for the per-tier GPU batch size.
        debug: Also collect and return the per-position sorted byte indices.

    Returns:
        (repeat_probs, ranks, lengths) as nested Python lists, one entry per
        window, plus sorted_indices when debug is True. Rows still include
        padding; callers must trim with the returned lengths.

    NOTE(review): this function requires CUDA — inputs are moved with
    .cuda() and the debug path uses torch.cuda.Event. Also `utils.` is
    referenced although the module only does `from m1_compression.utils
    import *`; confirm a `utils` name is actually in scope at runtime,
    otherwise that line raises NameError.
    """
    if debug:
        # NOTE(review): the CUDA events are recorded but elapsed_time is
        # never read — this block only prints a start marker.
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        torch.cuda.synchronize()
        print("[Debug CUDA] time start", flush=True)

    assert first_byte_prob.shape == (1, 1, ALPHABET_SIZE), "first_byte_prob must be of shape (1, 1, ALPHABET_SIZE)"

    batched_windows_np = [np.frombuffer(bytes(data), dtype=np.uint8) for data in batch]

    M = len(batched_windows_np)

    batched_repeat_probs = []
    batched_ranks = []
    batched_lengths = []
    if debug:
        batched_sorted_indices = []

    start_idx = 0
    while start_idx < M:
        # Take the next run of windows that share a batch-size tier.
        start_idx, end_idx = find_next_batch_range(batched_windows_np, start_idx, max_m1_batch_size)
        windows_np_chunked = batched_windows_np[start_idx:end_idx]
        padded_batched_windows, lengths = pad_batch(windows_np_chunked)
        padded_batched_windows, lengths = padded_batched_windows.cuda(), lengths.cuda()

        prompt_probs = predict_fn(padded_batched_windows)
        # Shift predictions right by one: position i is predicted from bytes
        # before i, and position 0 uses the static first-byte prior.
        prompt_probs = torch.cat(
            [
                first_byte_prob.expand(prompt_probs.shape[0], -1, -1),
                prompt_probs[:, :-1, ...]
            ],
            dim=1
        )
        prompt_probs = utils.batched_normalize_pdf_for_arithmetic_coding(prompt_probs)

        # Probability the model assigned to the byte that actually occurred.
        next_token_probs = torch.gather(
            prompt_probs,
            dim=-1,
            index=padded_batched_windows.unsqueeze(-1)
        ).squeeze(-1)
        # Rank of the actual byte within the descending-sorted distribution:
        # argmax finds the first True in the equality bitvector.
        sorted_indices = torch.argsort(prompt_probs, dim=-1, descending=True)
        rank_bitvector = padded_batched_windows.unsqueeze(-1) == sorted_indices
        ranks = torch.argmax(rank_bitvector.float(), dim=-1)
        start_idx = end_idx
        batched_repeat_probs.extend(next_token_probs.cpu().numpy().tolist())
        batched_ranks.extend(ranks.cpu().numpy().tolist())
        batched_lengths.extend(lengths.cpu().numpy().tolist())
        if debug:
            batched_sorted_indices.extend(sorted_indices.cpu().numpy().tolist())

    if debug:
        return batched_repeat_probs, batched_ranks, batched_lengths, batched_sorted_indices
    else:
        return batched_repeat_probs, batched_ranks, batched_lengths
|
|
def compress_segments_rank_based(
    sorted_segments: List[bytes],
    batched_predict_fn: Callable,
    first_byte_prob: torch.Tensor,
    max_m1_batch_size: int=4096,
    debug: bool = True
) -> List[List[int]]:
    """
    Rank-based compression using SimpleAdaptiveRankCodec.

    The GPU stage (simple_rle_topk_compression) produces per-byte model
    probabilities and ranks; the CPU stage then encodes every window with
    the codec, keeping the raw bytes whenever the encoding is not smaller.
    On any unhandled error the entire batch falls back to raw bytes.
    In debug mode each window is decoded again and verified.
    """
    try:
        gpu_result = simple_rle_topk_compression(
            sorted_segments,
            batched_predict_fn,
            first_byte_prob,
            max_m1_batch_size=max_m1_batch_size,
            debug=debug,
        )
        if debug:
            batched_repeat_probs, batched_ranks, batched_lengths, batched_sorted_indices = gpu_result
        else:
            batched_repeat_probs, batched_ranks, batched_lengths = gpu_result
            batched_sorted_indices = None

        # Sanity check: the GPU stage must return one entry per input segment.
        if len(batched_lengths) != len(sorted_segments):
            logger.error(f"FATAL: Length mismatch after GPU stage. Expected {len(sorted_segments)}, got {len(batched_lengths)}. Falling back to raw data.")

            return [list(seg) for seg in sorted_segments]

        M = len(batched_lengths)
        batched_compressed_bytes = []

        for i in range(M):
            lengths = batched_lengths[i]
            window_bytes = sorted_segments[i]
            # Trim the padding that the batched GPU stage introduced.
            repeat_probs = batched_repeat_probs[i][:lengths]
            ranks = batched_ranks[i][:lengths]

            codec = SimpleAdaptiveRankCodec(top_k=4)
            encoding = codec.encode_window(list(window_bytes), repeat_probs, ranks)
            compressed_bytes = codec.encoding_to_pseudo_bytes(encoding)

            # Only keep the encoding when it actually shrinks the window.
            if len(compressed_bytes) >= len(window_bytes):

                batched_compressed_bytes.append(list(window_bytes))
            else:

                batched_compressed_bytes.append(compressed_bytes)

            if debug:
                if batched_sorted_indices is None or batched_sorted_indices[i] is None:
                    logger.warning(f"Debug mode is on but sorted_indices for segment {i} is None. Skipping decode check.")
                    continue

                # Round-trip check: decoding must reproduce the original window.
                sorted_indices = batched_sorted_indices[i][:lengths]
                decoded = codec.decode_window(encoding, lengths, sorted_indices)
                assert bytes(decoded) == window_bytes, "decoded does not match window_bytes: \n{} and \n{}".format(decoded, window_bytes)
                if i < 10:
                    logger.info(f"Example input window bytes: {window_bytes}")
                    logger.info(f"Example encoding : {encoding}")
                    logger.info(f"Example compressed bytes : {compressed_bytes}")

        return batched_compressed_bytes

    except Exception as e:
        logger.error(f"Unhandled exception in compress_segments_rank_based: {e}. Falling back to raw data for the entire batch.", exc_info=True)

        return [list(seg) for seg in sorted_segments]
|
|
|
|
def compress_segments_arithmetic(
    sorted_segments: List[bytes],
    batched_predict_fn: Callable,
    first_byte_prob: torch.Tensor,
    max_m1_batch_size: int = 4096,
    debug: bool = True
) -> List[List[int]]:
    """
    Final robust version for arithmetic compression.
    This version is inspired by successful production code and is designed to be stable.
    It compresses unique segments in small, manageable batches and handles failures gracefully.

    Args:
        sorted_segments: Byte segments (expected sorted by length).
        batched_predict_fn: Batched model callable returning per-position probabilities.
        first_byte_prob: (1, 1, ALPHABET_SIZE) prior for the first byte; its
            device determines where the padded batches are placed.
        max_m1_batch_size: unused here; kept for signature compatibility
            with the other compress_segments_* functions.
        debug: unused in this variant.

    Returns:
        One pseudo-byte list per input segment. Raw bytes are kept for
        segments of length <= 2, when encoding fails, or when the code is
        not smaller than the original.
    """
    device = first_byte_prob.device
    M = len(sorted_segments)

    if M == 0:
        return []

    logger.info(f"Step 1: Identifying unique segments to compress.")

    # Deduplicate so each distinct payload is encoded only once; remember
    # every original position for the final fan-out.
    segment_to_indices = defaultdict(list)
    for i, seg in enumerate(sorted_segments):
        segment_to_indices[seg].append(i)

    unique_segments = [seg for seg in segment_to_indices.keys() if len(seg) > 2]
    logger.info(f"Found {len(unique_segments)} unique segments (len>2) out of {M} total segments.")

    segment_to_compressed = {}

    ENCODING_BATCH_SIZE = 128
    encoder = BatchedArithmeticEncoder(base=ARITHMETIC_CODER_BASE, precision=ARITHMETIC_CODER_PRECISION)

    logger.info(f"Step 2: Encoding unique segments in batches of size {ENCODING_BATCH_SIZE}.")

    for i in range(0, len(unique_segments), ENCODING_BATCH_SIZE):
        batch_start = i
        batch_end = min(i + ENCODING_BATCH_SIZE, len(unique_segments))
        batch_unique_segments = unique_segments[batch_start:batch_end]

        try:
            batch_padded_segments, batch_lengths = pad_batch(batch_unique_segments)
            batch_padded_segments = batch_padded_segments.to(device)
            batch_lengths = batch_lengths.to(device)

            with torch.no_grad():
                # Clamp defensively so no index can fall outside the alphabet.
                safe_padded_segments = batch_padded_segments.clamp(0, ALPHABET_SIZE - 1)
                probs = batched_predict_fn(safe_padded_segments)

            # Shift right by one position and prepend the first-byte prior so
            # each position is predicted from the bytes before it.
            final_probs = torch.cat([first_byte_prob.expand(probs.shape[0], -1, -1), probs[:, :-1, ...]], dim=1)
            normalized_probs = utils.batched_normalize_pdf_for_arithmetic_coding(final_probs)

            if not torch.isfinite(normalized_probs).all():
                raise ValueError("NaN or Inf in normalized probabilities after normalization.")

            codes, _ = encoder.batched_encode(
                normalized_probs,
                batch_padded_segments,
                lengths=batch_lengths,
                return_num_padded_bits=True
            )

            for seg, code in zip(batch_unique_segments, codes):
                segment_to_compressed[seg] = list(code)

        except Exception as e:
            # Best-effort: on any failure fall back to raw bytes for this batch.
            logger.warning(f"Batch encoding failed for unique segments {batch_start}-{batch_end}: {e}. Using raw bytes for this batch.")
            for seg in batch_unique_segments:
                segment_to_compressed[seg] = list(seg)

        # Release GPU memory between batches to keep peak usage bounded.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    logger.info("Step 3: Reconstructing final list from unique compressed segments.")
    all_compressed_results = [None] * M

    for seg, indices in segment_to_indices.items():
        if len(seg) <= 2:
            # Tiny segments are never compressed.
            result = list(seg)
        else:
            # Keep raw bytes when compression failed or did not shrink the segment.
            compressed_data = segment_to_compressed.get(seg, list(seg))
            if len(compressed_data) >= len(seg):
                result = list(seg)
            else:
                result = compressed_data

        for original_index in indices:
            all_compressed_results[original_index] = result

    return all_compressed_results
|
|
|
|
|
|
def compress_segments_hybrid_arithmetic(
    sorted_segments: List[bytes],
    batched_predict_fn: Callable,
    first_byte_prob: torch.Tensor,
    max_m1_batch_size: int=4096,
    debug: bool = True
) -> List[List[int]]:
    """
    GPU and CPU hybrid version for arithmetic compression.

    Deduplicates the input, encodes each unique segment of length > 2 with
    HybridArithmeticEncoder in fixed-size batches, and falls back to raw
    bytes on failure or when compression does not shrink a segment. In
    debug mode every batch is decoded again and verified byte-for-byte.

    NOTE(review): this definition shadows the earlier function of the same
    name in this module; the earlier one is dead code at import time.
    """
    M = len(sorted_segments)
    if M == 0:
        return []
    logger.info("Step 1: Identifying unique segments to compress.")

    # NOTE(review): `device` is assigned but never used in this variant.
    device = first_byte_prob.device
    # Deduplicate; remember every original position for the final fan-out.
    segment_to_indices = defaultdict(list)
    for i, seg in enumerate(sorted_segments):
        segment_to_indices[seg].append(i)

    unique_segments = [seg for seg in segment_to_indices.keys() if len(seg) > 2]
    logger.info(f"Found {len(unique_segments)} unique segments (len>2) out of {M} total segments.")

    segment_to_compressed = {}

    ENCODING_BATCH_SIZE = 128

    encoder = HybridArithmeticEncoder(
        batched_predict_fn=batched_predict_fn,
        first_byte_prob=first_byte_prob
    )
    logger.info(f"Step 2: Encoding unique segments in batches of size {ENCODING_BATCH_SIZE}.")

    for i in range(0, len(unique_segments), ENCODING_BATCH_SIZE):
        batch_start = i
        batch_end = min(i + ENCODING_BATCH_SIZE, len(unique_segments))
        batch_unique_segments = unique_segments[batch_start:batch_end]
        try:
            if debug:
                # Round-trip verification: encode, decode, and compare bytes.
                codes, padded_bits = encoder.batched_encode(batch_unique_segments, return_num_padded_bits=True)
                decoded_tensor = encoder.batched_decode(codes, padded_bits, batch_unique_segments)
                for j, original_seg_bytes in enumerate(batch_unique_segments):
                    original_len = len(original_seg_bytes)
                    decoded_bytes = bytes(decoded_tensor[j, :original_len].cpu().tolist())
                    assert decoded_bytes == original_seg_bytes, f"Hybrid decode mismatch for segment!"
            else:
                codes = encoder.batched_encode(batch_unique_segments, return_num_padded_bits=False)

            for seg, code in zip(batch_unique_segments, codes):
                segment_to_compressed[seg] = list(code)

        except Exception as e:
            # Best-effort: on any failure fall back to raw bytes for this batch.
            logger.warning(f"Batch encoding failed for unique segments {batch_start}-{batch_end}: {e}. Using raw bytes for this batch.")
            for seg in batch_unique_segments:
                segment_to_compressed[seg] = list(seg)
        # Release GPU memory between batches to keep peak usage bounded.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    logger.info("Step 3: Reconstructing final list from unique compressed segments.")
    all_compressed_results = [None] * M

    for seg, indices in segment_to_indices.items():
        if len(seg) <= 2:
            # Tiny segments are never compressed.
            result = list(seg)
        else:
            # Keep raw bytes when compression failed or did not shrink the segment.
            compressed_data = segment_to_compressed.get(seg, list(seg))
            if len(compressed_data) >= len(seg):
                result = list(seg)
            else:
                result = compressed_data

        for original_index in indices:
            all_compressed_results[original_index] = result

    return all_compressed_results
|
|
| |
def reconstruct_results(
    compressed_map: Dict[int, List[int]],
    reconstruction_info: Dict[str, Any],
    debug: bool = True
) -> List[Dict[str, Any]]:
    """
    Reconstruct the per-sample output records from the compressed results.

    Args:
        compressed_map: Maps the SORTED position of each compressed segment
            (as produced by prepare_segments) to its pseudo-byte list.
        reconstruction_info: Index maps and metadata from prepare_segments.
        debug: When True, verify that the segments concatenate back to the
            original sample bytes and that pseudo-byte packing round-trips,
            and log aggregate compression statistics.

    Returns:
        One dict per sample: the original item fields plus
        "m1_compressed_data" (base64 of the packed pseudo-byte stream).
    """
    sample_idx_to_list_segment_idx = reconstruction_info["sample_idx_to_list_segment_idx"]
    raw_segments_map = reconstruction_info["raw_segments_map"]
    batch_meta = reconstruction_info["batch_meta"]

    sorted_to_original_idx_map = reconstruction_info["sorted_to_original_idx_map"]
    # compressed_map is keyed by sorted position; remap keys back to the
    # original segment indexes used by the per-sample index lists.
    original_idx_to_compressed_data = {
        v: compressed_map[k]
        for k, v in sorted_to_original_idx_map.items()
        if k in compressed_map
    }

    write_results = []
    # NOTE(review): unused — confirm whether this was meant to be the output
    # key name instead of "m1_compressed_data" below.
    ac_key = "m1_enumerative"

    total_original_bytes = 0
    total_compressed_pseudo_bytes = 0

    for sample_idx,item in enumerate(batch_meta):
        final_pseudo_bytes = []
        if debug:
            reconstructed_original_segments = []

        # Walk this sample's segments in original order, concatenating either
        # the compressed pseudo-bytes or the raw pass-through bytes.
        segment_indices_for_sample = sample_idx_to_list_segment_idx.get(sample_idx, [])
        for original_idx in segment_indices_for_sample:
            if original_idx in original_idx_to_compressed_data:
                compressed_data = original_idx_to_compressed_data[original_idx]
                final_pseudo_bytes.extend(compressed_data)
                if debug:
                    total_compressed_pseudo_bytes += len(compressed_data)
                    original_segment_bytes = reconstruction_info["effective_segments_map"][original_idx]
                    reconstructed_original_segments.append(original_segment_bytes)
                    total_original_bytes += len(original_segment_bytes)

            elif original_idx in raw_segments_map:
                # Non-compressible segment: passed through as raw bytes.
                raw_data = raw_segments_map[original_idx]
                final_pseudo_bytes.extend(list(raw_data))

                if debug:
                    total_compressed_pseudo_bytes += len(raw_data)
                    reconstructed_original_segments.append(raw_data)
                    total_original_bytes += len(raw_data)

            else:
                # Should be unreachable: every segment index must be in one
                # of the two maps. Attempt a best-effort recovery.
                logger.error(f"FATAL LOGIC ERROR: Segment with original_idx {original_idx} does not exist in effective_segments_map or raw_segments_map!")

                original_segment_bytes = reconstruction_info["effective_segments_map"].get(original_idx)
                if original_segment_bytes:
                    final_pseudo_bytes.extend(list(original_segment_bytes))

        # Pack the 9-bit pseudo-byte stream into real bytes for transport.
        packed_bytes = pseudo_to_packed_bytes(final_pseudo_bytes)
        result = {
            **item,
            "m1_compressed_data": base64.b64encode(packed_bytes).decode("ascii")
        }
        write_results.append(result)

        if debug and reconstructed_original_segments:
            # The segments must concatenate back to the exact original sample.
            original_sample_bytes = item["text"].encode('utf-8')
            reconstructed_sample_bytes = b"".join(reconstructed_original_segments)
            assert reconstructed_sample_bytes == original_sample_bytes, \
                f"Sample {sample_idx} reconstruction failed!"

            # The 9-bit packing must round-trip losslessly.
            unpacked_pseudo_bytes = packed_bytes_to_pseudo(packed_bytes)
            assert unpacked_pseudo_bytes == final_pseudo_bytes, \
                f"Pseudo-bytes packing/unpacking round-trip failed for sample {sample_idx}"

    if debug and total_original_bytes > 0:
        compression_ratio = total_compressed_pseudo_bytes / total_original_bytes
        logger.info(f"Batch compression stats: "
                    f"Original bytes: {total_original_bytes}, "
                    f"Compressed pseudo-bytes: {total_compressed_pseudo_bytes}, "
                    f"Ratio: {compression_ratio:.4f}")

    return write_results
|
|
|
|
def writer_consumer(write_queue, output_file, buffer_size=100, debug=True):
    """Consume compressed payloads from `write_queue` and write JSONL output.

    Each payload is reconstructed via reconstruct_results and buffered; the
    buffer is flushed to `output_file` once it reaches `buffer_size`. A
    `None` payload is the shutdown sentinel; any remaining buffered rows are
    flushed before returning. Errors are logged and re-raised.
    """
    pending = []
    try:
        with open(output_file, 'w', encoding='utf-8') as sink:

            def flush_pending():
                # Dump every buffered row as one JSON line, then sync.
                for row in pending:
                    sink.write(json.dumps(row) + '\n')
                sink.flush()

            while True:
                payload = write_queue.get()
                if payload is None:
                    break

                pending.extend(reconstruct_results(
                    payload["compressed_map"],
                    payload["reconstruction_info"],
                    debug=debug
                ))

                if len(pending) >= buffer_size:
                    logger.info(f"Writer: Dumping buffer of {len(pending)} items to {output_file}")
                    flush_pending()
                    pending.clear()

            if pending:
                logger.info(f"Writer: Dumping remaining {len(pending)} items to {output_file}")
                flush_pending()

    except Exception as e:
        logger.error(f"Writer process error: {e}")
        raise
|
|
def merge_output_files(output_file, writer_output_files):
    """Merge all writer output files into a single file"""
    logger.info(f"Merging {len(writer_output_files)} writer files into {output_file}")

    with open(output_file, 'w', encoding='utf-8') as merged:
        for part in writer_output_files:
            if not part.exists():
                continue
            with open(part, 'r', encoding='utf-8') as source:
                merged.writelines(source)
            # Each partial file is deleted once its contents are copied.
            part.unlink()
            logger.info(f"Merged and removed writer file: {part}")

    logger.info(f"Merged output written to: {output_file}")
    return output_file
|
|
def shutdown_writers(write_queue, writer_processes):
    """Send shutdown signals to shared queue and wait for all writers to complete"""
    total = len(writer_processes)
    # One None sentinel per writer so every consumer loop terminates.
    for i in range(total):
        write_queue.put(None)
        logger.info(f"Sent shutdown signal {i+1}/{total}")

    for i, proc in enumerate(writer_processes):
        proc.join()
        if proc.exitcode != 0:
            logger.error(f"Writer process {i} failed with exit code: {proc.exitcode}")
        else:
            logger.info(f"Writer process {i} completed successfully")
|
|
|
|
def main_processor_fn(
    batch: List[Dict[str, Any]],
    compression_fn: Callable,
    predict_fn: Callable,
    first_byte_prob: torch.Tensor,
    max_m1_batch_size: int,
    debug: bool = True
):
    """Prepare a batch's segments, compress the compressible ones, and bundle
    the compressed map with the reconstruction metadata for the writer.

    Returns a payload dict with "compressed_map" (sorted position ->
    pseudo-byte list; empty when nothing was compressible) and
    "reconstruction_info" from prepare_segments.
    """
    prepared = prepare_segments(batch)
    to_compress = prepared["sorted_segments_to_compress"]
    reconstruction_info = prepared["reconstruction_info"]

    compressed_map = {}
    if to_compress:
        tic = time.time()
        compressed_pseudo_bytes = compression_fn(
            to_compress,
            predict_fn,
            first_byte_prob,
            max_m1_batch_size,
            debug
        )
        duration = time.time() - tic
        logger.info(
            f"Compressed {len(to_compress)} segments "
            f"in {duration:.4f} seconds ({len(to_compress)/duration if duration > 0 else float('inf'):.2f} segments/sec)."
        )
        # Keyed by sorted position, matching sorted_to_original_idx_map.
        compressed_map = dict(enumerate(compressed_pseudo_bytes))

    return {
        "compressed_map": compressed_map,
        "reconstruction_info": reconstruction_info
    }
|
|
def main():
    """CLI entry point for buffer-based M1 arithmetic compression of a JSONL file.

    Pipeline: parse args -> pick (and optionally cache-wrap) a compression
    algorithm -> load the entropy model and first-byte prior -> stream batches
    from the input file through ``main_processor_fn`` -> hand payloads to
    writer consumer processes over a shared queue -> optionally merge the
    per-writer output files into a single JSONL.
    """
    parser = argparse.ArgumentParser(description='Process JSONL files using M1 arithmetic compression with buffer-based approach')
    parser.add_argument('--input_file', type=str, required=True,
                        help='Directory containing input JSONL files')
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to write compressed results')
    parser.add_argument('--entropy_model_path', type=str, required=True,
                        help='Path to the M1 model checkpoint')
    parser.add_argument('--compression_model_path', type=str, required=True,
                        help='Path to the M1 model checkpoint')
    parser.add_argument('--compressor', type=str, default='rank_based',
                        choices=['rank_based', 'arithmetic', 'hybrid_arithmetic'],
                        help='Choose the compression algorithm.')
    parser.add_argument('--data_batch_size', type=int, default=512,
                        help='Size of batches for processing (default: 512)')
    parser.add_argument('--output_window_size', type=int, default=16,
                        help='Size of window for compression (default: 16)')
    parser.add_argument('--max_window_size', type=int, default=1024,
                        help='Maximum window size for reading from each file (default: 1024)')
    parser.add_argument('--max_entropy_batch_size', type=int, default=4096,
                        help='Size of max batch for compression (default: 4096)')
    parser.add_argument('--max_compression_batch_size', type=int, default=4096,
                        help='Size of max batch for compression (default: 4096)')
    parser.add_argument('--chunk_size', type=int, default=512,
                        help='Size of chunk for compression (default: 512)')
    parser.add_argument('--base_global_quantile', type=float, default=0.9,
                        help='Base global quantile for compression (default: 0.9)')
    parser.add_argument('--base_monotonic_quantile', type=float, default=0.9,
                        help='Base monotonic quantile for compression (default: 0.9)')
    # BUGFIX: was `action='store_true', default=True`, which made --debug a
    # no-op (always True) and contradicted the documented default of False.
    parser.add_argument('--debug', action='store_true', default=False,
                        help='Debug mode (default: False)')
    parser.add_argument('--firstbyte_prob_path', type=str, default=None,
                        help='Probability path for the first word of each window (default : None)')
    parser.add_argument('--num_workers', type=int, default=1,
                        help='Number of workers for CPU jobs (default: 1)')
    parser.add_argument('--process_id', type=int, default=0,
                        help='Process ID for distributed processing (default: 0)')
    parser.add_argument('--num_processes', type=int, default=1,
                        help='Number of processes for distributed processing (default: 1)')
    parser.add_argument('--merge_output', action='store_true', default=False,
                        help='Merge all writer output files into a single file (default: False)')

    # BUGFIX: `store_true` with default=True could never be disabled.
    # BooleanOptionalAction keeps --use_global_cache working unchanged and
    # additionally accepts --no-use_global_cache to turn the cache off.
    parser.add_argument('--use_global_cache', action=argparse.BooleanOptionalAction, default=True,
                        help='Enable the global compression cache.')
    parser.add_argument('--cache_size', type=int, default=819200,
                        help='Size of the global compression cache.')

    args = parser.parse_args()

    # Select the compression backend (choices= above guarantees a valid value,
    # but keep the explicit error for defensive clarity).
    if args.compressor == 'rank_based':
        compression_algorithm = compress_segments_rank_based
    elif args.compressor == 'arithmetic':
        compression_algorithm = compress_segments_arithmetic
    elif args.compressor == 'hybrid_arithmetic':
        compression_algorithm = compress_segments_hybrid_arithmetic
    else:
        raise ValueError(f"Unknown compressor: {args.compressor}")
    logger.info(f"Using compression algorithm: {compression_algorithm.__name__}")

    # Optionally wrap the backend in a global LRU-style cache.
    if args.use_global_cache:
        caching_wrapper = CachingCompressorWrapper(
            base_compression_fn=compression_algorithm,
            cache_size=args.cache_size
        )
        compression_algorithm_to_use = caching_wrapper
        logger.info("Global cache start....")
    else:
        compression_algorithm_to_use = compression_algorithm
        logger.info("No Global cache ...")

    # 'spawn' is required for CUDA safety in the writer subprocesses.
    mp.set_start_method('spawn', force=True)
    gc_freq = 100   # collect host/GPU garbage every N batches
    dump_freq = 25  # writer flush frequency (passed to writer_consumer)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    model, _, _ = load_m1_model_and_tokenizer(args.entropy_model_path)
    batched_predict_fn = batched_m1_compress_predict_fn(model)

    # First-byte prior: load from file if given, otherwise uniform over the
    # alphabet. Shape (1, 1, ALPHABET_SIZE) to broadcast over batch/sequence.
    if args.firstbyte_prob_path is not None:
        with open(args.firstbyte_prob_path, 'r', encoding='utf-8') as f:
            first_byte_prob = json.load(f)
        logger.debug("Loaded first-byte probabilities: %s", first_byte_prob)
        first_byte_prob = torch.tensor(first_byte_prob, dtype=torch.float32, device="cuda").unsqueeze(0).unsqueeze(0)
    else:
        first_byte_prob = torch.ones((1, 1, ALPHABET_SIZE), dtype=torch.float32, device="cuda") / ALPHABET_SIZE

    dataset = InterleavedJsonlDataset(
        file_path=args.input_file,
        rank=args.process_id,
        world_size=args.num_processes,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=args.data_batch_size,
        shuffle=False,
        collate_fn=lambda x: x  # keep raw dicts; no tensor collation
    )

    input_file = Path(args.input_file)
    logger.info(f"Processing file: {input_file}")

    output_file = output_dir / f"{input_file.stem}_out_{args.process_id}.jsonl"

    logger.info("Data loaded. Start processing...")

    # Bounded queue applies back-pressure if the writers fall behind.
    write_queue = mp.Queue(maxsize=200)
    writer_processes = []
    writer_output_files = []
    for i in range(args.num_workers):
        output_path = Path(output_file)
        writer_output_file = output_path.parent / f"{output_path.stem}_writer_{i}.jsonl"
        writer_output_files.append(writer_output_file)
        writer_process = mp.Process(
            target=writer_consumer,
            args=(write_queue, writer_output_file, dump_freq, args.debug)
        )
        writer_processes.append(writer_process)
        writer_process.start()
        logger.info(f"Started writer process {i} for output file: {writer_output_file}")

    try:
        for batch_idx, batch in enumerate(dataloader):
            payload_for_writer = main_processor_fn(
                batch,
                compression_algorithm_to_use,
                batched_predict_fn,
                first_byte_prob,
                args.max_compression_batch_size,
                args.debug,
            )
            logger.info(f"Processed batch {batch_idx}")
            write_queue.put(payload_for_writer)

            # Periodically reclaim host and GPU memory.
            if batch_idx % gc_freq == 0:
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

        shutdown_writers(write_queue, writer_processes)

    except Exception as e:
        logger.error(f"Error during processing: {e}")
        # Best-effort shutdown so writer processes do not hang forever.
        # BUGFIX: was a bare `except: pass`, which silently swallowed every
        # failure (including KeyboardInterrupt) during cleanup.
        try:
            shutdown_writers(write_queue, writer_processes)
        except Exception as shutdown_err:
            logger.warning(f"Writer shutdown during error handling failed: {shutdown_err}")
        raise

    if args.merge_output:
        final_output_file = merge_output_files(output_file, writer_output_files)
        logger.info(f"Completed processing successfully, merged output written to {final_output_file}")
    else:
        logger.info(f"Completed processing successfully, outputs written to {args.num_workers} separate files")
|
|
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|