# Byte-lingua-code / offline_entropy_window_compress_v2.py
# 2ira's picture
# offline_compression_graph_code
# 72c0672 verified
"""
New pipeline:
1. prepare_segments(): unpack each sample's windows, sort/re-rank the compressible
   ones by length, and record the indices needed to reconstruct the output.
2. compress_segments_xxx(): compress the sorted segments with one of several
   compression algorithms.
3. reconstruct_results(): rebuild each sample's output from the compressed
   results and the recorded indices.
4. main(): use prepare_segments and a compress_segments_xxx function to produce
   payloads and pass them to the consumers.
5. writer_consumer(): take compressed payloads from the queue, reconstruct them,
   and write the results.
"""
import torch
import torch.nn.functional as F
from torch.utils.data import IterableDataset, Dataset, DataLoader
import json
import numpy as np
from pathlib import Path
from typing import Iterator, List, Dict, Any, Callable, Tuple, Optional
import logging
import argparse
import base64
import time
import math
import gc
from collections import defaultdict, Counter,deque
from m1_compression.utils import *
from m1_compression.compressor import (
load_m1_model_and_tokenizer,
ALPHABET_SIZE,
)
import multiprocessing as mp
from m1_compression.enumerative_coder_simple import SimpleAdaptiveRankCodec
from m1_compression.batched_arithmetic_coder import BatchedArithmeticEncoder
from m1_compression.hybrid_arithmetic_coder import HybridArithmeticEncoder
from m1_compression.compressor import (
load_m1_model_and_tokenizer,
ALPHABET_SIZE,
ARITHMETIC_CODER_BASE,
ARITHMETIC_CODER_PRECISION,
)
from offline_entropy_window_split import unpack_windows
# Configure logging once at import time: INFO and above, with timestamp,
# logger name and level on every line.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# getLogger() with no name returns the root logger; every function in this
# module logs through it.
logger = logging.getLogger()
def pseudo_to_packed_bytes(lst: list[int]) -> bytes:
    """
    Pack 9-bit "pseudo-byte" values into a compact little-endian bit stream.

    Each value contributes its low 9 bits, least-significant bit first; any
    higher bits are silently dropped. A trailing partial byte is zero-padded,
    so ``n`` values occupy ``ceil(9 * n / 8)`` bytes.

    The inverse operation is ``packed_bytes_to_pseudo``.
    """
    packed = bytearray()
    pending_bits = 0   # bits not yet emitted; the oldest bit is the LSB
    pending_count = 0  # how many bits of ``pending_bits`` are valid
    for value in lst:
        pending_bits |= (value & 0x1FF) << pending_count
        pending_count += 9
        # Emit every complete byte that is now available.
        while pending_count >= 8:
            packed.append(pending_bits & 0xFF)
            pending_bits >>= 8
            pending_count -= 8
    if pending_count > 0:
        # Fewer than 8 bits remain; the upper bits of this byte are zero.
        packed.append(pending_bits)
    return bytes(packed)
def packed_bytes_to_pseudo(b: bytes) -> list[int]:
    """
    Unpack a little-endian 9-bit bit stream back into pseudo-byte values.

    This is the inverse of ``pseudo_to_packed_bytes``. Trailing bits that do
    not form a complete 9-bit value (the zero padding of the last byte) are
    discarded.
    """
    values: list[int] = []
    reservoir = 0  # unread bits, oldest bit in the LSB
    available = 0  # number of valid bits in ``reservoir``
    for octet in b:
        reservoir |= octet << available
        available += 8
        # The reservoir always holds < 9 bits before this byte is added, so at
        # most one 9-bit value can become complete per input byte.
        if available >= 9:
            values.append(reservoir & 0x1FF)
            reservoir >>= 9
            available -= 9
    return values
def calculate_compression_ratio(original_bytes: List[bytes], compressed_segments: List[bytes]) -> float:
    """
    Return total compressed length divided by total original length.

    Lower is better; 1.0 means no gain. Lengths are counted in elements, so
    ``compressed_segments`` may also hold lists of pseudo-byte ints.

    Returns 1.0 when there is nothing meaningful to compare: no compressed
    segments, no original segments, or original segments whose total length
    is zero. A warning is logged when the ratio exceeds 2.0.
    """
    if not compressed_segments or len(original_bytes) == 0:
        return 1.0
    total_original_length = sum(len(orig_seg) for orig_seg in original_bytes)
    if total_original_length == 0:
        # All original segments are empty: avoid ZeroDivisionError.
        return 1.0
    total_compressed_length = sum(len(compressed_seg) for compressed_seg in compressed_segments)
    ratio = total_compressed_length / total_original_length
    if ratio > 2.0:
        logger.warning(f"Unusual compression ratio: {ratio:.4f} (compressed larger than original)")
    return ratio
def collect_window_size_statistics(segmented_results: List[List[bytes]]) -> Dict[int, int]:
    """
    Histogram of segment lengths across all samples.

    Args:
        segmented_results: one list of segments per sample.

    Returns:
        Plain dict mapping segment length -> number of segments of that length.
    """
    lengths = (len(segment) for segments in segmented_results for segment in segments)
    return dict(Counter(lengths))
# def pad_batch(batch: List[bytes]):
# batch_tensors = [torch.tensor(data, dtype=torch.int64) for data in batch]
# lengths = torch.tensor([len(data) for data in batch], dtype=torch.int64)
# padded_batch = torch.nn.utils.rnn.pad_sequence(
# batch_tensors,
# batch_first=True,
# padding_value=0,
# padding_side="right"
# )
# return padded_batch, lengths
def pad_batch(batch: List[bytes]):
    """
    Right-pad a batch of byte sequences into one int64 tensor.

    Each element is converted with ``list(...)`` first, so ``bytes`` and other
    iterables of small ints (such as uint8 numpy arrays) become integer
    tensors. Shorter rows are padded on the right with 0. ``pad_sequence``
    already pads on the right, so no ``padding_side`` argument is passed
    (some torch versions do not accept it).

    Returns:
        ``(padded, lengths)``: ``padded`` has shape ``[B, max_len]`` and
        ``lengths[i]`` is the unpadded length of ``batch[i]``.
    """
    rows = []
    row_lengths = []
    for sequence in batch:
        rows.append(torch.tensor(list(sequence), dtype=torch.int64))
        row_lengths.append(len(sequence))
    lengths = torch.tensor(row_lengths, dtype=torch.int64)
    padded = torch.nn.utils.rnn.pad_sequence(rows, batch_first=True, padding_value=0)
    return padded, lengths
# control long seq with smaller batch
def get_batch_size_for_length(window_len, max_batch_size):
    """
    Determine how many windows of length ``window_len`` may share one model batch.

    The batch size shrinks aggressively for long windows to avoid GPU
    out-of-memory errors:

    * ``<= 128``:  ``max_batch_size``
    * ``<= 256``:  ``max(max_batch_size // 4, 1)``
    * ``<= 512``:  ``max(max_batch_size // 16, 1)``
    * ``<= 1024``: ``max(max_batch_size // 64, 1)``
    * ``<= 2048``: at most 2, and never more than the ``<= 1024`` tier
    * longer:      1

    The result never increases as ``window_len`` grows.
    """
    # Only short windows get the full batch size.
    if window_len <= 128:
        return max_batch_size
    if window_len <= 256:
        return max(max_batch_size // 4, 1)
    if window_len <= 512:
        return max(max_batch_size // 16, 1)
    if window_len <= 1024:
        return max(max_batch_size // 64, 1)
    if window_len <= 2048:
        # At most 2 windows of 1k-2k bytes. Cap by the 1024 tier so a small
        # max_batch_size is respected and longer windows never get a larger
        # batch than shorter ones (a hard-coded 2 violated both).
        return min(2, max(max_batch_size // 64, 1))
    # Windows longer than 2048 are processed one at a time.
    return 1
# def get_batch_size_for_length(window_len, max_batch_size):
# """Determines the batch size for a given window length."""
# BATCH_SIZE_TIERS = {
# 128: max_batch_size,
# 512: max(max_batch_size // 64, 1),
# 1024: max(max_batch_size // 128, 1),
# 2048: max(max_batch_size // 256, 1),
# }
# for max_len, batch_size in BATCH_SIZE_TIERS.items():
# if window_len <= max_len:
# return batch_size
# return 1
def find_next_batch_range(all_windows, start_idx, max_m1_batch_size):
    """
    Return ``(start_idx, end_idx)`` for the next model sub-batch.

    The target size is ``get_batch_size_for_length`` of the window at
    ``start_idx``. The range extends up to that many windows, but stops before
    the first window whose own size tier differs. That boundary is located by
    binary search, which assumes ``all_windows`` is sorted by length so the
    windows sharing the first window's tier form a contiguous prefix.

    When ``start_idx`` is past the end, ``(start_idx, start_idx)`` is returned.
    Otherwise, given a positive target size, the range holds at least one
    window.
    """
    total = len(all_windows)
    if start_idx >= total:
        return start_idx, start_idx

    def tier_of(position):
        # Batch size the window at ``position`` would be allowed on its own.
        return get_batch_size_for_length(len(all_windows[position]), max_m1_batch_size)

    target = tier_of(start_idx)
    stop = min(start_idx + target, total)

    # Fast path: the last candidate shares the tier, so (with sorted input)
    # every window in between does too.
    if tier_of(stop - 1) == target:
        return start_idx, stop

    # Find the first position in [start_idx, stop) whose tier differs.
    left, right = start_idx, stop
    while left < right:
        middle = (left + right) // 2
        if tier_of(middle) == target:
            left = middle + 1   # boundary lies to the right of ``middle``
        else:
            right = middle      # ``middle`` may itself be the boundary
    # Always make progress by at least one window.
    return start_idx, max(left, start_idx + 1)
class JsonlShardedDataset(Dataset):
    """
    Map-style dataset holding one contiguous shard of a JSONL file.

    The whole file is parsed once (fine for files up to a few GB). Blank or
    whitespace-only lines are skipped. With ``total`` records and
    ``per_proc = ceil(total / total_procs)``, rank ``r`` keeps records
    ``[r * per_proc, min((r + 1) * per_proc, total))``. Shards can therefore
    be uneven, and trailing ranks may be empty.

    Args:
        file_path: path to the JSONL file (UTF-8).
        current_proc_rank: index of this process, in ``[0, total_procs)``.
        total_procs: number of processes sharing the file.

    Raises:
        ValueError: if ``current_proc_rank`` is outside ``[0, total_procs)``.
    """

    def __init__(
        self,
        file_path: str,
        current_proc_rank: int = 0,
        total_procs: int = 1,
    ) -> None:
        # Raise (not assert) so the check survives `python -O`; this matches
        # InterleavedJsonlDataset.
        if not (0 <= current_proc_rank < total_procs):
            raise ValueError(
                f"rank must be in [0, world_size), got rank={current_proc_rank}, "
                f"world_size={total_procs}"
            )
        self.current_proc_rank = current_proc_rank
        self.total_procs = total_procs

        # -- load the whole file once, ignoring blank lines ------------------
        with open(file_path, "r", encoding="utf-8") as f:
            full_data: List[Dict[str, Any]] = [json.loads(line) for line in f if line.strip()]

        # -- pick the contiguous slice that belongs to *this* process -------
        total = len(full_data)
        per_proc = (total + total_procs - 1) // total_procs  # integer ceil
        start = current_proc_rank * per_proc
        end = min(start + per_proc, total)
        self.data = full_data[start:end]

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        return self.data[idx]
class InterleavedJsonlDataset(IterableDataset):
    """
    Stream a JSONL file, yielding only the lines owned by one rank.

    Line ``n`` (0-indexed) belongs to rank ``n % world_size``, so rank ``r``
    sees lines ``r, r + world_size, r + 2 * world_size, ...``. Every rank
    scans the whole file but parses only its own lines. Lines that are not
    valid JSON are reported on stdout and skipped.

    NOTE(review): ``__iter__`` does not consult torch worker info. If this is
    wrapped in a DataLoader with ``num_workers > 1``, each worker presumably
    yields the same lines; confirm usage.

    Args:
        file_path: path to the JSONL file (UTF-8).
        rank: index of this process; stored as ``offset``.
        world_size: number of processes; stored as ``block_size`` (the stride).

    Raises:
        ValueError: if ``rank`` is outside ``[0, world_size)``.
    """

    def __init__(
        self,
        file_path: str,
        rank: int,
        world_size: int,
    ) -> None:
        super().__init__()
        if rank < 0 or rank >= world_size:
            raise ValueError(f"Rank must be in [0, {world_size-1}], but got {rank}")
        self.file_path = file_path
        self.offset = rank
        self.block_size = world_size

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        """Yield the parsed JSON object of every line assigned to this rank."""
        try:
            with open(self.file_path, "r", encoding="utf-8") as handle:
                for line_no, raw_line in enumerate(handle):
                    if line_no % self.block_size != self.offset:
                        continue  # line belongs to another rank
                    try:
                        record = json.loads(raw_line)
                    except json.JSONDecodeError:
                        # Best effort: report the malformed line and move on.
                        print(f"Warning: Rank {self.offset} could not decode JSON on line ~{line_no+1}. Skipping.")
                        continue
                    yield record
        except Exception as e:
            print(f"Error in worker {self.offset}: {e}")
            raise
def batched_m1_compress_predict_fn(model):
    """
    Wrap ``model`` as a function returning next-byte probabilities.

    The returned ``predict_fn(input_tensor, **kwargs)``:

    * adds a batch dimension to 1-D input;
    * runs ``model(input_tensor, **kwargs)`` under ``torch.no_grad()``;
    * keeps only the first 256 logits of the last dimension, casts them to
      float32, and asserts they are all finite;
    * returns ``softmax`` over that last dimension.
    """
    def predict_fn(input_tensor: torch.Tensor, **kwargs) -> torch.Tensor:
        batched_input = input_tensor.unsqueeze(0) if input_tensor.dim() == 1 else input_tensor
        with torch.no_grad():
            raw_logits = model(batched_input, **kwargs)
        # Restrict to the 256 byte values and compute in float32.
        byte_logits = raw_logits[..., :256].float()
        assert torch.isfinite(byte_logits).all(), "Logits contain NaN or Inf values."
        return torch.softmax(byte_logits, dim=-1)

    return predict_fn
class CachingCompressorWrapper:
    """
    Add deduplication and a bounded FIFO result cache in front of a batch compressor.

    ``base_compression_fn`` is called as
    ``base_compression_fn(segments, *args, **kwargs)`` and must return one
    result per input segment, in the same order. Each incoming batch is
    deduplicated. Segments already in the cache are served from it, and only
    the remaining distinct segments are forwarded, in first-occurrence order
    (so length-sorted input stays length-sorted).

    NOTE: the cache key is the raw segment bytes only. ``*args``/``**kwargs``
    (model, probabilities, ...) are not part of the key, so use one wrapper
    per model/configuration.
    """

    def __init__(
        self,
        base_compression_fn: Callable,  # wrapped compressor; caching stays a separate layer
        cache_size: int = 819200,  # max cached segments; <= 0 disables caching
        cache_policy: str = 'fifo'  # eviction policy; only 'fifo' is supported
    ):
        if cache_policy not in ['fifo']:
            raise ValueError(f"no caching policy: {cache_policy}.")
        self.base_compression_fn = base_compression_fn
        self.cache_size = cache_size
        self.cache_policy = cache_policy
        # raw segment bytes -> compressed pseudo-bytes
        self.cache: Dict[bytes, List[int]] = {}
        # cache keys in insertion order, oldest first (FIFO eviction order)
        self.fifo_queue: deque[bytes] = deque()
        # functools.partial objects and many callable instances have no
        # __name__; fall back to repr() instead of raising AttributeError.
        fn_name = getattr(base_compression_fn, "__name__", repr(base_compression_fn))
        logger.info(f"Create CachingCompressorWrapper '{fn_name}',"
                    f"Cache size: {self.cache_size}, policy: {self.cache_policy}")

    def compress(
        self,
        sorted_segments: List[bytes],
        *args, **kwargs
    ) -> List[List[int]]:
        """
        Compress ``sorted_segments``, reusing cached results where possible.

        Extra arguments are forwarded unchanged to ``base_compression_fn``.

        Returns:
            One result per input segment, in input order. Identical segments
            share the same result object.
        """
        if not sorted_segments:
            return []
        M = len(sorted_segments)

        # 1. group input positions by segment content (first-occurrence order)
        segment_to_indices = defaultdict(list)
        for i, seg in enumerate(sorted_segments):
            segment_to_indices[seg].append(i)

        # 2. split distinct segments into cache hits and misses
        misses_data: List[bytes] = []
        results_for_uniques: Dict[bytes, List[int]] = {}
        for segment in segment_to_indices:
            if segment in self.cache:
                results_for_uniques[segment] = self.cache[segment]
            else:
                misses_data.append(segment)
        hit_count = len(segment_to_indices) - len(misses_data)
        logger.info(f"Cache checking: {len(segment_to_indices)} segments, "
                    f"Get {hit_count}, No caching {len(misses_data)} ")

        # 3. compress only the misses, then 4. record results and refresh the cache
        if misses_data:
            newly_compressed = self.base_compression_fn(misses_data, *args, **kwargs)
            for i, raw_segment in enumerate(misses_data):
                compressed_result = newly_compressed[i]
                results_for_uniques[raw_segment] = compressed_result
                if self.cache_size > 0 and raw_segment not in self.cache:
                    if len(self.cache) >= self.cache_size and self.cache_policy == 'fifo':
                        # Evict the oldest entry to stay within cache_size.
                        oldest_key = self.fifo_queue.popleft()
                        del self.cache[oldest_key]
                    self.cache[raw_segment] = compressed_result
                    self.fifo_queue.append(raw_segment)

        # 5. expand results back to one entry per input position
        all_compressed_results = [None] * M
        for seg, indices in segment_to_indices.items():
            result = results_for_uniques[seg]
            for original_index in indices:
                all_compressed_results[original_index] = result
        return all_compressed_results

    def __call__(self, *args, **kwargs):
        return self.compress(*args, **kwargs)
def compress_segments_hybrid_arithmetic(
    sorted_segments: List[bytes],
    batched_predict_fn: Callable,
    first_byte_prob: torch.Tensor,
    max_m1_batch_size: int=4096,
    debug: bool = True
) -> List[List[int]]:
    """
    Encode every given segment with ``HybridArithmeticEncoder``.

    NOTE: this module defines ``compress_segments_hybrid_arithmetic`` a second
    time further down. At import time that later definition replaces this
    one, so this version is only reachable if the later one is renamed or
    removed.

    This variant does no caching or deduplication itself; those are expected
    to be handled by an outer ``CachingCompressorWrapper``. Segments are
    encoded in batches of at most 128 (never more than ``max_m1_batch_size``).
    Per segment, the code is kept only if it is shorter than the segment;
    otherwise the raw bytes are used. If a batch fails, or returns the wrong
    number of codes, all its segments fall back to raw bytes.

    Returns:
        One list of ints per input segment, in input order.
    """
    M = len(sorted_segments)
    if M == 0:
        return []
    # The input here is already deduplicated and consists only of cache misses.
    logger.info(f"Hybrid AC 核心: 正在处理 {M} 个不重复、未命中缓存的段。")

    # Respect the caller's batch cap as well as the fixed encoding batch size.
    encoding_batch_size = max(1, min(128, max_m1_batch_size))
    encoder = HybridArithmeticEncoder(
        batched_predict_fn=batched_predict_fn,
        first_byte_prob=first_byte_prob
    )
    all_compressed_results: List[List[int]] = []
    for batch_start in range(0, M, encoding_batch_size):
        batch_end = min(batch_start + encoding_batch_size, M)
        batch_segments = sorted_segments[batch_start:batch_end]
        try:
            codes = encoder.batched_encode(batch_segments, return_num_padded_bits=False)
            if len(codes) != len(batch_segments):
                raise ValueError(
                    f"encoder returned {len(codes)} codes for {len(batch_segments)} segments"
                )
            # Build the batch locally so a failure part-way through can never
            # leave a partial batch in the output (which would shift every
            # later result by the number of entries already appended).
            batch_results = [
                list(code) if len(code) < len(seg) else list(seg)  # keep raw bytes when coding does not help
                for seg, code in zip(batch_segments, codes)
            ]
        except Exception as e:
            logger.warning(f"Hybrid AC 核心: 批次 {batch_start}-{batch_end} 编码失败: {e}. 该批次使用原始字节。")
            batch_results = [list(seg) for seg in batch_segments]
        all_compressed_results.extend(batch_results)
        # Release memory between batches.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return all_compressed_results
def prepare_segments(batch: List[Dict[str, Any]])->Dict[str,Any]:
    """
    Split a batch of samples into windows and order the compressible ones.

    For each sample, ``item["text"]`` is UTF-8 encoded and split with
    ``unpack_windows`` using ``item["windows_starts_lens_b64"]``. Each window
    comes with an indicator. A window is *compressible* when its indicator
    equals 1 and it is longer than 3 bytes; all other windows stay raw.
    Windows get a running "original index" across the whole batch.
    Compressible windows are then sorted by length, which reduces padding in
    model batches. The sort is stable, so ties keep their original order.

    Returns:
        dict with
        ``sorted_segments_to_compress``: compressible windows, shortest first;
        ``reconstruction_info``: data needed to map compressor outputs back:
            ``sample_idx_to_list_segment_idx`` (sample index -> its windows'
            original indices, in order), ``sorted_to_original_idx_map``
            (sorted position -> original index), ``raw_segments_map`` and
            ``effective_segments_map`` (original index -> window bytes),
            ``total_segments`` and ``batch_meta`` (the input batch).

    Raises:
        AssertionError: if a sample lacks ``windows_starts_lens_b64``.
    """
    sample_to_segments = defaultdict(list)
    effective_segments = {}  # original index -> compressible window
    raw_segments_map = {}    # original index -> window kept as-is
    running_idx = 0

    for sample_idx, item in enumerate(batch):
        assert "windows_starts_lens_b64" in item, "windows_starts_lens_b64 must be in item"
        text_bytes = item["text"].encode('utf-8')
        windows = unpack_windows(text_bytes, item["windows_starts_lens_b64"])
        for window, flag in windows:
            target = effective_segments if (flag == 1 and len(window) > 3) else raw_segments_map
            target[running_idx] = window
            sample_to_segments[sample_idx].append(running_idx)
            running_idx += 1

    # Length order reduces padding in the model batches.
    compress_order = sorted(effective_segments, key=lambda idx: len(effective_segments[idx]))

    return {
        "sorted_segments_to_compress": [effective_segments[idx] for idx in compress_order],
        "reconstruction_info": {
            "sample_idx_to_list_segment_idx": sample_to_segments,
            "sorted_to_original_idx_map": dict(enumerate(compress_order)),
            "raw_segments_map": raw_segments_map,
            "total_segments": running_idx,
            "batch_meta": batch,
            "effective_segments_map": effective_segments,
        },
    }
# wrap it as the first processing function
def simple_rle_topk_compression(
    batch: List[bytes],
    predict_fn: Callable,
    first_byte_prob: torch.Tensor,
    max_m1_batch_size: int = 4096,
    debug: bool = True,
):
    """Score each window with the model: per-position byte probabilities and ranks.

    This is the GPU stage used by ``compress_segments_rank_based``. It does
    not produce compressed bytes itself.

    Windows are processed in sub-batches chosen by ``find_next_batch_range``,
    which uses smaller sub-batches for longer windows. For each sub-batch:

    * windows are right-padded with 0 by ``pad_batch`` and moved to CUDA;
    * the distribution for position ``t`` is the model output at ``t - 1``,
      and position 0 uses ``first_byte_prob``;
    * distributions are passed through
      ``utils.batched_normalize_pdf_for_arithmetic_coding``;
    * ``next_token_probs[b, t]`` is the probability of the actual byte at ``t``;
    * ``ranks[b, t]`` is the position of the actual byte in the
      descending-probability order (0 = most likely byte).

    Args:
        batch: byte windows. Sub-batches are grouped by neighbouring window
            lengths, so callers presumably pass length-sorted windows — TODO confirm.
        predict_fn: maps a padded ``[B, L]`` int64 tensor to next-byte
            probabilities; assumed ``[B, L, ALPHABET_SIZE]`` — TODO confirm.
        first_byte_prob: distribution for the first byte; asserted to have
            shape ``(1, 1, ALPHABET_SIZE)``.
        max_m1_batch_size: batch-size cap passed to ``find_next_batch_range``.
        debug: also return the full descending-probability ordering per
            position, and print a CUDA timing marker.

    Returns:
        ``(batched_repeat_probs, batched_ranks, batched_lengths)``, plus
        ``batched_sorted_indices`` as a fourth element when ``debug`` is True.
        Each is a Python list with one entry per window, in input order.
        Rows are padded to the longest window of their sub-batch, so entries
        past ``batched_lengths[i]`` are padding.
    """
    if debug:
        # NOTE(review): end_event is created but never recorded or read, so no
        # elapsed time is ever reported.
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        torch.cuda.synchronize()  # wait for all queued CUDA work to finish
        print("[Debug CUDA] time start", flush=True)
    assert first_byte_prob.shape == (1, 1, ALPHABET_SIZE), "first_byte_prob must be of shape (1, 1, ALPHABET_SIZE)"
    # View every window as a uint8 numpy array.
    batched_windows_np = [np.frombuffer(bytes(data), dtype=np.uint8) for data in batch]
    M = len(batched_windows_np)
    batched_repeat_probs = []
    batched_ranks = []
    batched_lengths = []
    if debug:
        batched_sorted_indices = []
    start_idx = 0
    while start_idx < M:
        # Largest safe sub-batch starting at start_idx (smaller for long windows).
        start_idx, end_idx = find_next_batch_range(batched_windows_np, start_idx, max_m1_batch_size)
        windows_np_chunked = batched_windows_np[start_idx:end_idx]
        padded_batched_windows, lengths = pad_batch(windows_np_chunked)
        padded_batched_windows, lengths = padded_batched_windows.cuda(), lengths.cuda()
        prompt_probs = predict_fn(padded_batched_windows)
        # Shift right by one: position t is predicted from bytes < t, and
        # position 0 uses the fixed first-byte distribution.
        prompt_probs = torch.cat(
            [
                first_byte_prob.expand(prompt_probs.shape[0], -1, -1),
                prompt_probs[:, :-1, ...]
            ],
            dim=1
        )
        # NOTE(review): `utils` must be provided by `from m1_compression.utils import *`;
        # otherwise this raises NameError — confirm.
        prompt_probs = utils.batched_normalize_pdf_for_arithmetic_coding(prompt_probs)
        # Two per-position quantities are extracted:
        # 1. the probability of the actual byte
        next_token_probs = torch.gather(
            prompt_probs,
            dim=-1,
            index=padded_batched_windows.unsqueeze(-1)
        ).squeeze(-1) # [B, L]
        # 2. the rank of the actual byte among all bytes sorted by probability
        sorted_indices = torch.argsort(prompt_probs, dim=-1, descending=True)
        rank_bitvector = padded_batched_windows.unsqueeze(-1) == sorted_indices
        ranks = torch.argmax(rank_bitvector.float(), dim=-1) # [B, L]
        start_idx = end_idx
        batched_repeat_probs.extend(next_token_probs.cpu().numpy().tolist())
        batched_ranks.extend(ranks.cpu().numpy().tolist())
        batched_lengths.extend(lengths.cpu().numpy().tolist())
        if debug:
            batched_sorted_indices.extend(sorted_indices.cpu().numpy().tolist())
    if debug:
        return batched_repeat_probs, batched_ranks, batched_lengths, batched_sorted_indices
    else:
        return batched_repeat_probs, batched_ranks, batched_lengths
def compress_segments_rank_based(
    sorted_segments: List[bytes],
    batched_predict_fn: Callable,
    first_byte_prob: torch.Tensor,
    max_m1_batch_size: int=4096,
    debug: bool = True
) -> List[List[int]]:
    """
    Compress segments with ``SimpleAdaptiveRankCodec`` driven by model ranks.

    Two stages:

    * GPU: ``simple_rle_topk_compression`` yields, per segment, the model
      probability and the rank of each actual byte.
    * CPU: each segment is encoded with a fresh
      ``SimpleAdaptiveRankCodec(top_k=4)``. The raw bytes are kept whenever
      the pseudo-byte encoding is not shorter than the segment.

    In debug mode each encoding is decoded again and asserted equal to its
    segment, and the first 10 segments are logged. If the GPU stage returns
    the wrong number of segments, the raw bytes are returned. Any exception,
    including a failed debug round-trip assertion, is logged and makes the
    whole batch fall back to raw bytes.

    Returns:
        One list per input segment, in input order.
    """
    try:
        # --- GPU stage: probabilities and ranks per byte ---
        stage_output = simple_rle_topk_compression(
            sorted_segments,
            batched_predict_fn,
            first_byte_prob,
            max_m1_batch_size=max_m1_batch_size,
            debug=debug,
        )
        if debug:
            probs_per_seg, ranks_per_seg, lengths_per_seg, order_per_seg = stage_output
        else:
            probs_per_seg, ranks_per_seg, lengths_per_seg = stage_output
            order_per_seg = None

        if len(lengths_per_seg) != len(sorted_segments):
            # The GPU stage lost or duplicated segments: return raw data.
            logger.error(f"FATAL: Length mismatch after GPU stage. Expected {len(sorted_segments)}, got {len(lengths_per_seg)}. Falling back to raw data.")
            return [list(seg) for seg in sorted_segments]

        # --- CPU stage: encode segment by segment ---
        outputs = []
        for i in range(len(lengths_per_seg)):
            valid_len = lengths_per_seg[i]
            window_bytes = sorted_segments[i]
            codec = SimpleAdaptiveRankCodec(top_k=4)
            encoding = codec.encode_window(
                list(window_bytes),
                probs_per_seg[i][:valid_len],   # drop padding positions
                ranks_per_seg[i][:valid_len],
            )
            compressed_bytes = codec.encoding_to_pseudo_bytes(encoding)
            # Keep the raw bytes whenever the encoding does not save space.
            outputs.append(list(window_bytes) if len(compressed_bytes) >= len(window_bytes) else compressed_bytes)

            if not debug:
                continue
            if order_per_seg is None or order_per_seg[i] is None:
                logger.warning(f"Debug mode is on but sorted_indices for segment {i} is None. Skipping decode check.")
                continue
            # Verify the encoding itself (not the pseudo-byte form) round-trips.
            decoded = codec.decode_window(encoding, valid_len, order_per_seg[i][:valid_len])
            assert bytes(decoded) == window_bytes, "decoded does not match window_bytes: \n{} and \n{}".format(decoded, window_bytes)
            if i < 10:
                logger.info(f"Example input window bytes: {window_bytes}")
                logger.info(f"Example encoding : {encoding}")
                logger.info(f"Example compressed bytes : {compressed_bytes}")
        return outputs
    except Exception as e:
        logger.error(f"Unhandled exception in compress_segments_rank_based: {e}. Falling back to raw data for the entire batch.", exc_info=True)
        return [list(seg) for seg in sorted_segments]
def compress_segments_arithmetic(
    sorted_segments: List[bytes],
    batched_predict_fn: Callable,
    first_byte_prob: torch.Tensor,
    max_m1_batch_size: int = 4096,
    debug: bool = True
) -> List[List[int]]:
    """
    Arithmetic-code each distinct segment using model-predicted byte distributions.

    Steps:
    1. Deduplicate ``sorted_segments``; only distinct segments longer than
       2 bytes are encoded.
    2. Encode them with ``BatchedArithmeticEncoder`` in batches of at most 128
       (and never more than ``max_m1_batch_size``). Position ``t`` is coded
       with the model output at ``t - 1``, and position 0 uses
       ``first_byte_prob``. If a batch raises, including on non-finite
       probabilities, all its segments fall back to raw bytes.
    3. Expand back to one result per input position. A code that is not
       shorter than its segment is replaced by the raw bytes.

    Args:
        sorted_segments: segments to compress.
        batched_predict_fn: maps a padded ``[B, L]`` tensor to next-byte
            probabilities (see ``batched_m1_compress_predict_fn``).
        first_byte_prob: distribution for the first byte; its ``device`` is
            also where the inputs are placed.
        max_m1_batch_size: upper bound on the encoding batch size.
        debug: unused here; kept for signature compatibility with the other
            ``compress_segments_*`` functions.

    Returns:
        One list of ints per input segment, in input order. Duplicate
        positions share the same list object.
    """
    M = len(sorted_segments)
    if M == 0:
        return []
    device = first_byte_prob.device

    # --- 1. deduplicate ---
    logger.info("Step 1: Identifying unique segments to compress.")
    segment_to_indices = defaultdict(list)  # segment -> every input position holding it
    for i, seg in enumerate(sorted_segments):
        segment_to_indices[seg].append(i)
    # Segments of <= 2 bytes are never encoded; step 3 emits them raw.
    unique_segments = [seg for seg in segment_to_indices if len(seg) > 2]
    logger.info(f"Found {len(unique_segments)} unique segments (len>2) out of {M} total segments.")
    segment_to_compressed: Dict[bytes, List[int]] = {}

    # --- 2. encode in small batches ---
    # Honour the caller's cap, which was previously accepted but ignored.
    encoding_batch_size = max(1, min(128, max_m1_batch_size))
    encoder = BatchedArithmeticEncoder(base=ARITHMETIC_CODER_BASE, precision=ARITHMETIC_CODER_PRECISION)
    logger.info(f"Step 2: Encoding unique segments in batches of size {encoding_batch_size}.")
    for batch_start in range(0, len(unique_segments), encoding_batch_size):
        batch_end = min(batch_start + encoding_batch_size, len(unique_segments))
        batch_unique_segments = unique_segments[batch_start:batch_end]
        try:
            batch_padded_segments, batch_lengths = pad_batch(batch_unique_segments)
            batch_padded_segments = batch_padded_segments.to(device)
            batch_lengths = batch_lengths.to(device)
            with torch.no_grad():
                # Clamp the model input to the alphabet so bad ids cannot reach the model.
                safe_padded_segments = batch_padded_segments.clamp(0, ALPHABET_SIZE - 1)
                probs = batched_predict_fn(safe_padded_segments)
                final_probs = torch.cat([first_byte_prob.expand(probs.shape[0], -1, -1), probs[:, :-1, ...]], dim=1)
                # NOTE(review): `utils` must come from `from m1_compression.utils import *`;
                # a NameError here would make every batch fall back to raw bytes — confirm.
                normalized_probs = utils.batched_normalize_pdf_for_arithmetic_coding(final_probs)
            if not torch.isfinite(normalized_probs).all():
                raise ValueError("NaN or Inf in normalized probabilities after normalization.")
            codes, _ = encoder.batched_encode(
                normalized_probs,
                batch_padded_segments,
                lengths=batch_lengths,
                return_num_padded_bits=True
            )
            for seg, code in zip(batch_unique_segments, codes):
                segment_to_compressed[seg] = list(code)
        except Exception as e:
            # Best effort: keep this batch uncompressed and continue.
            logger.warning(f"Batch encoding failed for unique segments {batch_start}-{batch_end}: {e}. Using raw bytes for this batch.")
            for seg in batch_unique_segments:
                segment_to_compressed[seg] = list(seg)
        # Release memory between batches.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # --- 3. expand back to every input position ---
    logger.info("Step 3: Reconstructing final list from unique compressed segments.")
    all_compressed_results = [None] * M
    for seg, indices in segment_to_indices.items():
        if len(seg) <= 2:
            result = list(seg)
        else:
            # A missing entry (e.g. the encoder returned too few codes) means raw bytes.
            compressed_data = segment_to_compressed.get(seg, list(seg))
            result = list(seg) if len(compressed_data) >= len(seg) else compressed_data
        for original_index in indices:
            all_compressed_results[original_index] = result
    return all_compressed_results
def compress_segments_hybrid_arithmetic(
    sorted_segments: List[bytes],
    batched_predict_fn: Callable,
    first_byte_prob: torch.Tensor,
    max_m1_batch_size: int=4096,
    debug: bool = True
) -> List[List[int]]:
    """
    Arithmetic-code each distinct segment with ``HybridArithmeticEncoder`` (GPU/CPU hybrid).

    This is the later of two definitions of this name in the module, so it is
    the one bound at import time.

    Steps:
    1. Deduplicate; only distinct segments longer than 2 bytes are encoded.
    2. Encode them in batches of at most 128 (and never more than
       ``max_m1_batch_size``). With ``debug``, each batch is decoded again and
       compared with its input. If a batch raises, including on a failed
       round-trip check, the error is logged and all its segments fall back
       to raw bytes.
    3. Expand back to one result per input position. A code that is not
       shorter than its segment is replaced by the raw bytes.

    Returns:
        One list of ints per input segment, in input order. Duplicate
        positions share the same list object.
    """
    M = len(sorted_segments)
    if M == 0:
        return []

    # --- 1. deduplicate ---
    logger.info("Step 1: Identifying unique segments to compress.")
    segment_to_indices = defaultdict(list)  # segment -> every input position holding it
    for i, seg in enumerate(sorted_segments):
        segment_to_indices[seg].append(i)
    unique_segments = [seg for seg in segment_to_indices if len(seg) > 2]
    logger.info(f"Found {len(unique_segments)} unique segments (len>2) out of {M} total segments.")
    segment_to_compressed: Dict[bytes, List[int]] = {}

    # --- 2. encode in small batches ---
    # Honour the caller's cap, which was previously accepted but ignored.
    encoding_batch_size = max(1, min(128, max_m1_batch_size))
    encoder = HybridArithmeticEncoder(
        batched_predict_fn=batched_predict_fn,
        first_byte_prob=first_byte_prob
    )
    logger.info(f"Step 2: Encoding unique segments in batches of size {encoding_batch_size}.")
    for batch_start in range(0, len(unique_segments), encoding_batch_size):
        batch_end = min(batch_start + encoding_batch_size, len(unique_segments))
        batch_unique_segments = unique_segments[batch_start:batch_end]
        try:
            if debug:
                # Debug: also get the padding bit counts so the batch can be decoded and checked.
                codes, padded_bits = encoder.batched_encode(batch_unique_segments, return_num_padded_bits=True)
                decoded_tensor = encoder.batched_decode(codes, padded_bits, batch_unique_segments)
                for j, original_seg_bytes in enumerate(batch_unique_segments):
                    original_len = len(original_seg_bytes)
                    decoded_bytes = bytes(decoded_tensor[j, :original_len].cpu().tolist())
                    assert decoded_bytes == original_seg_bytes, f"Hybrid decode mismatch for segment!"
            else:
                codes = encoder.batched_encode(batch_unique_segments, return_num_padded_bits=False)
            for seg, code in zip(batch_unique_segments, codes):
                segment_to_compressed[seg] = list(code)
        except Exception as e:
            # Best effort: keep this batch uncompressed and continue.
            logger.warning(f"Batch encoding failed for unique segments {batch_start}-{batch_end}: {e}. Using raw bytes for this batch.")
            for seg in batch_unique_segments:
                segment_to_compressed[seg] = list(seg)
        # Release memory between batches.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # --- 3. expand back to every input position ---
    logger.info("Step 3: Reconstructing final list from unique compressed segments.")
    all_compressed_results = [None] * M
    for seg, indices in segment_to_indices.items():
        if len(seg) <= 2:
            result = list(seg)
        else:
            # A missing entry (e.g. the encoder returned too few codes) means raw bytes.
            compressed_data = segment_to_compressed.get(seg, list(seg))
            result = list(seg) if len(compressed_data) >= len(seg) else compressed_data
        for original_index in indices:
            all_compressed_results[original_index] = result
    return all_compressed_results
def reconstruct_results(
    compressed_map: Dict[int, List[int]],
    reconstruction_info: Dict[str, Any],
    debug: bool = True
) -> List[Dict[str, Any]]:
    """
    Build one output record per sample from its compressed and raw windows.

    Each sample's windows are visited in their original order:

    * a compressible window contributes its compressed pseudo-bytes;
    * a raw window contributes its bytes unchanged;
    * a compressible window with no entry in ``compressed_map`` (e.g. the
      compression function returned fewer results than requested)
      contributes its raw bytes, and an error is logged.

    The concatenated pseudo-byte stream is packed at 9 bits per value with
    ``pseudo_to_packed_bytes``, base64-encoded, and stored under
    ``"m1_compressed_data"`` in a shallow copy of the input item.

    Args:
        compressed_map: position in the length-sorted compression input ->
            compressed pseudo-bytes.
        reconstruction_info: the ``reconstruction_info`` dict produced by
            ``prepare_segments``.
        debug: if True, assert that each sample's windows concatenate back to
            its UTF-8 text (this checks the windowing, not decoding) and that
            the 9-bit pack/unpack round-trips, then log the batch ratio of
            written pseudo-bytes to original bytes.

    Returns:
        Output records, one per sample, in batch order.
    """
    sample_idx_to_list_segment_idx = reconstruction_info["sample_idx_to_list_segment_idx"]
    raw_segments_map = reconstruction_info["raw_segments_map"]
    effective_segments_map = reconstruction_info.get("effective_segments_map", {})
    batch_meta = reconstruction_info["batch_meta"]
    sorted_to_original_idx_map = reconstruction_info["sorted_to_original_idx_map"]

    # compressed_map is keyed by sorted position; re-key it by original index.
    original_idx_to_compressed_data = {
        original_idx: compressed_map[sorted_idx]
        for sorted_idx, original_idx in sorted_to_original_idx_map.items()
        if sorted_idx in compressed_map
    }

    write_results = []
    # Batch-wide totals for the debug compression-ratio log.
    total_original_bytes = 0
    total_compressed_pseudo_bytes = 0

    for sample_idx, item in enumerate(batch_meta):
        final_pseudo_bytes: List[int] = []
        # Original bytes of every window written for this sample (debug only).
        reconstructed_original_segments: List[bytes] = []
        for original_idx in sample_idx_to_list_segment_idx.get(sample_idx, []):
            if original_idx in original_idx_to_compressed_data:
                # Compressible window with a compressed result.
                compressed_data = original_idx_to_compressed_data[original_idx]
                final_pseudo_bytes.extend(compressed_data)
                if debug:
                    original_segment_bytes = effective_segments_map[original_idx]
                    total_compressed_pseudo_bytes += len(compressed_data)
                    reconstructed_original_segments.append(original_segment_bytes)
                    total_original_bytes += len(original_segment_bytes)
            elif original_idx in raw_segments_map:
                # Window that was never sent to the compressor.
                raw_data = raw_segments_map[original_idx]
                final_pseudo_bytes.extend(raw_data)
                if debug:
                    total_compressed_pseudo_bytes += len(raw_data)
                    reconstructed_original_segments.append(raw_data)
                    total_original_bytes += len(raw_data)
            elif original_idx in effective_segments_map:
                # Compressible window whose compressed result is missing: write
                # its raw bytes so the sample stays complete.
                fallback_data = effective_segments_map[original_idx]
                logger.error(f"Segment with original_idx {original_idx} has no compressed result; writing its raw bytes instead.")
                final_pseudo_bytes.extend(fallback_data)
                if debug:
                    total_compressed_pseudo_bytes += len(fallback_data)
                    reconstructed_original_segments.append(fallback_data)
                    total_original_bytes += len(fallback_data)
            else:
                logger.error(f"FATAL LOGIC ERROR: Segment with original_idx {original_idx} does not exist in effective_segments_map or raw_segments_map!")

        packed_bytes = pseudo_to_packed_bytes(final_pseudo_bytes)
        write_results.append({
            **item,
            "m1_compressed_data": base64.b64encode(packed_bytes).decode("ascii")
        })

        if debug and reconstructed_original_segments:
            original_sample_bytes = item["text"].encode('utf-8')
            reconstructed_sample_bytes = b"".join(reconstructed_original_segments)
            assert reconstructed_sample_bytes == original_sample_bytes, \
                f"Sample {sample_idx} reconstruction failed!"
            # 9-bit pack/unpack must be lossless (values above 511 would be truncated).
            unpacked_pseudo_bytes = packed_bytes_to_pseudo(packed_bytes)
            assert unpacked_pseudo_bytes == final_pseudo_bytes, \
                f"Pseudo-bytes packing/unpacking round-trip failed for sample {sample_idx}"

    if debug and total_original_bytes > 0:
        compression_ratio = total_compressed_pseudo_bytes / total_original_bytes
        logger.info(f"Batch compression stats: "
                    f"Original bytes: {total_original_bytes}, "
                    f"Compressed pseudo-bytes: {total_compressed_pseudo_bytes}, "
                    f"Ratio: {compression_ratio:.4f}")
    return write_results
def writer_consumer(write_queue, output_file, buffer_size=100,debug=True):
    """
    Drain ``write_queue`` and write reconstructed records to ``output_file`` as JSONL.

    Each queue item is a payload dict with ``"compressed_map"`` and
    ``"reconstruction_info"``; ``None`` is the stop sentinel. Records from
    ``reconstruct_results`` are buffered and written, then flushed, whenever
    at least ``buffer_size`` have accumulated, and once more at the end.
    Errors are logged and re-raised.
    """
    pending = []

    def dump(records, handle):
        # One JSON object per line, then push it to disk.
        handle.writelines(json.dumps(record) + '\n' for record in records)
        handle.flush()

    try:
        with open(output_file, 'w', encoding='utf-8') as out:
            # iter(get, None) keeps calling get() until the None sentinel arrives.
            for payload in iter(write_queue.get, None):
                pending.extend(
                    reconstruct_results(
                        payload["compressed_map"],
                        payload["reconstruction_info"],
                        debug=debug
                    )
                )
                if len(pending) >= buffer_size:
                    logger.info(f"Writer: Dumping buffer of {len(pending)} items to {output_file}")
                    dump(pending, out)
                    pending = []
            if pending:
                logger.info(f"Writer: Dumping remaining {len(pending)} items to {output_file}")
                dump(pending, out)
    except Exception as e:
        logger.error(f"Writer process error: {e}")
        raise
def merge_output_files(output_file, writer_output_files):
    """
    Concatenate per-writer files into ``output_file`` and delete them.

    Files are appended in the given order; entries that do not exist are
    skipped. Each entry may be a ``pathlib.Path`` or any ``str``/path-like.
    Plain strings used to fail because ``.exists()``/``.unlink()`` were
    called on them directly.

    Returns:
        ``output_file``, unchanged.
    """
    logger.info(f"Merging {len(writer_output_files)} writer files into {output_file}")
    with open(output_file, 'w', encoding='utf-8') as outf:
        for entry in writer_output_files:
            writer_output_file = Path(entry)
            if not writer_output_file.exists():
                continue
            with open(writer_output_file, 'r', encoding='utf-8') as inf:
                outf.writelines(inf)  # streams line by line, no full read into memory
            # The per-writer file is no longer needed once merged.
            writer_output_file.unlink()
            logger.info(f"Merged and removed writer file: {writer_output_file}")
    logger.info(f"Merged output written to: {output_file}")
    return output_file
def shutdown_writers(write_queue, writer_processes):
    """
    Stop all writer processes that share ``write_queue``.

    One ``None`` sentinel is queued per writer so every writer receives a
    stop signal. Each process is then joined, and its exit code is logged as
    success (0) or failure (anything else).
    """
    writer_count = len(writer_processes)
    for signal_no in range(1, writer_count + 1):
        write_queue.put(None)
        logger.info(f"Sent shutdown signal {signal_no}/{writer_count}")

    for idx, proc in enumerate(writer_processes):
        proc.join()
        code = proc.exitcode
        if code == 0:
            logger.info(f"Writer process {idx} completed successfully")
        else:
            logger.error(f"Writer process {idx} failed with exit code: {code}")
def main_processor_fn(
    batch: List[Dict[str, Any]],
    compression_fn: Callable,  # compression algorithm is injected as a parameter (may be cache-wrapped)
    predict_fn: Callable,
    first_byte_prob: torch.Tensor,
    max_m1_batch_size: int,
    debug: bool = True
) -> Dict[str, Any]:
    """Prepare, compress and package one data batch for a writer process.

    Steps:
      1. ``prepare_segments(batch)`` provides the segments to compress (in
         sorted order) plus the ``reconstruction_info`` needed downstream.
      2. ``compression_fn`` is called positionally as
         ``compression_fn(sorted_segments, predict_fn, first_byte_prob,
         max_m1_batch_size, debug)``, and its throughput is logged.
      3. The outputs are keyed by their position in the sorted order.

    Args:
        batch: Records for one batch (``main`` collates them as a plain list).
        compression_fn: A ``compress_segments_*`` function or a cache wrapper
            around one.
        predict_fn: Batched model prediction function, forwarded unchanged.
        first_byte_prob: First-byte probability tensor, forwarded unchanged.
        max_m1_batch_size: Maximum model batch size, forwarded unchanged.
        debug: Forwarded unchanged to ``compression_fn``.

    Returns:
        ``{"compressed_map": {sorted_index: compressed_output},
        "reconstruction_info": <from prepare_segments>}``; ``compressed_map``
        is empty when the batch yields no segments.
    """
    # 1. Prepare data: segments to compress (sorted) plus reconstruction info.
    prep_data = prepare_segments(batch)
    sorted_segments = prep_data["sorted_segments_to_compress"]
    reconstruction_info = prep_data["reconstruction_info"]

    # 2. Compress.
    compressed_map = {}
    if sorted_segments:
        # perf_counter is monotonic; time.time() can jump with wall-clock changes.
        start_time = time.perf_counter()
        compressed_pseudo_bytes = compression_fn(
            sorted_segments,
            predict_fn,
            first_byte_prob,
            max_m1_batch_size,
            debug,
        )
        duration = time.perf_counter() - start_time
        logger.info(
            f"Compressed {len(sorted_segments)} segments "
            f"in {duration:.4f} seconds ({len(sorted_segments)/duration if duration > 0 else float('inf'):.2f} segments/sec)."
        )
        # Keys are positions in the *sorted* order, not original indices;
        # reconstruction_info travels alongside in the payload so the consumer
        # can presumably map them back.
        compressed_map = dict(enumerate(compressed_pseudo_bytes))
        if len(compressed_map) != len(sorted_segments):
            logger.warning(
                f"Compressor returned {len(compressed_map)} results for "
                f"{len(sorted_segments)} segments; reconstruction may be incomplete."
            )

    # 3. Package the result for the writer process.
    return {
        "compressed_map": compressed_map,
        "reconstruction_info": reconstruction_info,
    }
def main():
    """Command-line entry point: compress one JSONL input file with the M1 model.

    Pipeline (one producer, ``--num_workers`` writer processes):
      1. Parse CLI arguments and pick the compression algorithm, optionally
         wrapped in ``CachingCompressorWrapper``.
      2. Load the M1 model from ``--entropy_model_path``, build the batched
         prediction function, and build the first-byte probability tensor
         (from ``--firstbyte_prob_path`` or uniform over ``ALPHABET_SIZE``).
      3. Iterate this process's shard of ``--input_file``
         (``rank=process_id``, ``world_size=num_processes``) in batches of
         ``--data_batch_size``.
      4. Compress each batch in this process (``main_processor_fn``) and pass
         the payload to the writer processes through a bounded shared queue.
      5. Shut the writers down and, with ``--merge_output``, concatenate the
         per-writer files into a single output file.
    """
    parser = argparse.ArgumentParser(description='Process JSONL files using M1 arithmetic compression with buffer-based approach')
    parser.add_argument('--input_file', type=str, required=True,
                        help='Path to the input JSONL file')
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to write compressed results')
    parser.add_argument('--entropy_model_path', type=str, required=True,
                        help='Path to the M1 model checkpoint')
    # NOTE(review): this argument is parsed but never loaded below.
    parser.add_argument('--compression_model_path', type=str, required=True,
                        help='Path to the M1 compression model checkpoint (currently unused; predictions come from --entropy_model_path)')
    parser.add_argument('--compressor', type=str, default='rank_based',
                        choices=['rank_based', 'arithmetic', 'hybrid_arithmetic'],
                        help='Choose the compression algorithm.')
    parser.add_argument('--data_batch_size', type=int, default=512,
                        help='Size of batches for processing (default: 512)')
    parser.add_argument('--output_window_size', type=int, default=16,
                        help='Size of window for compression (default: 16)')
    parser.add_argument('--max_window_size', type=int, default=1024,
                        help='Maximum window size for reading from each file (default: 1024)')
    parser.add_argument('--max_entropy_batch_size', type=int, default=4096,
                        help='Size of max batch for compression (default: 4096)')
    parser.add_argument('--max_compression_batch_size', type=int, default=4096,
                        help='Size of max batch for compression (default: 4096)')
    parser.add_argument('--chunk_size', type=int, default=512,
                        help='Size of chunk for compression (default: 512)')
    parser.add_argument('--base_global_quantile', type=float, default=0.9,
                        help='Base global quantile for compression (default: 0.9)')
    parser.add_argument('--base_monotonic_quantile', type=float, default=0.9,
                        help='Base monotonic quantile for compression (default: 0.9)')
    # store_true with default=True could never be switched off. BooleanOptionalAction
    # keeps `--debug` working and adds `--no-debug`.
    parser.add_argument('--debug', action=argparse.BooleanOptionalAction, default=True,
                        help='Debug mode; disable with --no-debug')
    parser.add_argument('--firstbyte_prob_path', type=str, default=None,
                        help='Probability path for the first word of each window (default : None)')
    parser.add_argument('--num_workers', type=int, default=1,
                        help='Number of workers for CPU jobs (default: 1)')
    parser.add_argument('--process_id', type=int, default=0,
                        help='Process ID for distributed processing (default: 0)')
    parser.add_argument('--num_processes', type=int, default=1,
                        help='Number of processes for distributed processing (default: 1)')
    parser.add_argument('--merge_output', action='store_true', default=False,
                        help='Merge all writer output files into a single file (default: False)')
    # Cache parameters (same store_true/default=True problem as --debug).
    parser.add_argument('--use_global_cache', action=argparse.BooleanOptionalAction, default=True,
                        help='Enable the global compression cache; disable with --no-use_global_cache')
    parser.add_argument('--cache_size', type=int, default=819200,
                        help='Size of the global compression cache.')
    args = parser.parse_args()

    # Choose the compression algorithm.
    compressors = {
        'rank_based': compress_segments_rank_based,
        'arithmetic': compress_segments_arithmetic,
        'hybrid_arithmetic': compress_segments_hybrid_arithmetic,
    }
    try:
        compression_algorithm = compressors[args.compressor]
    except KeyError:
        raise ValueError(f"Unknown compressor: {args.compressor}") from None
    logger.info(f"Using compression algorithm: {compression_algorithm.__name__}")

    # Optionally wrap the algorithm in the global cache.
    if args.use_global_cache:
        compression_algorithm_to_use = CachingCompressorWrapper(
            base_compression_fn=compression_algorithm,
            cache_size=args.cache_size
        )
        logger.info("Global cache start....")
    else:
        compression_algorithm_to_use = compression_algorithm
        logger.info("No Global cache ...")

    mp.set_start_method('spawn', force=True)
    gc_freq = 100   # run gc / empty the CUDA cache every gc_freq batches
    dump_freq = 25  # buffer size handed to each writer process

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load the model used for prediction.
    model, _, _ = load_m1_model_and_tokenizer(args.entropy_model_path)
    batched_predict_fn = batched_m1_compress_predict_fn(model)

    if args.firstbyte_prob_path is not None:
        with open(args.firstbyte_prob_path, 'r', encoding='utf-8') as f:
            first_byte_prob = json.load(f)
        if args.debug:
            logger.info(f"First-byte probabilities: {first_byte_prob}")
        first_byte_prob = torch.tensor(first_byte_prob, dtype=torch.float32, device="cuda").unsqueeze(0).unsqueeze(0)
    else:
        # Uniform distribution of shape (1, 1, ALPHABET_SIZE).
        first_byte_prob = torch.ones((1, 1, ALPHABET_SIZE), dtype=torch.float32, device="cuda") / ALPHABET_SIZE

    # Dataset sharded by process; batches are kept as plain lists of records.
    dataset = InterleavedJsonlDataset(
        file_path=args.input_file,
        rank=args.process_id,
        world_size=args.num_processes,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=args.data_batch_size,
        shuffle=False,
        collate_fn=lambda x: x
    )

    input_file = Path(args.input_file)
    logger.info(f"Processing file: {input_file}")
    output_file = output_dir / f"{input_file.stem}_out_{args.process_id}.jsonl"
    logger.info("Data loaded. Start processing...")

    # Bounded queue: the producer blocks when writers fall behind.
    write_queue = mp.Queue(maxsize=200)
    writer_processes = []
    writer_output_files = []
    for i in range(args.num_workers):
        # Each writer gets its own output file.
        writer_output_file = output_file.parent / f"{output_file.stem}_writer_{i}.jsonl"
        writer_output_files.append(writer_output_file)
        writer_process = mp.Process(
            target=writer_consumer,
            args=(write_queue, writer_output_file, dump_freq, args.debug)
        )
        writer_processes.append(writer_process)
        writer_process.start()
        logger.info(f"Started writer process {i} for output file: {writer_output_file}")

    try:
        for batch_idx, batch in enumerate(dataloader):
            payload_for_writer = main_processor_fn(
                batch,
                compression_algorithm_to_use,
                batched_predict_fn,
                first_byte_prob,
                args.max_compression_batch_size,
                args.debug,
            )
            logger.info(f"Processed batch {batch_idx}")
            write_queue.put(payload_for_writer)
            if batch_idx % gc_freq == 0:
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
        # Signal completion to all writer processes.
        shutdown_writers(write_queue, writer_processes)
    except Exception as e:
        logger.exception(f"Error during processing: {e}")
        # Best effort to stop the writers; a failure here must not mask the
        # original error, but it should not be silently swallowed either.
        try:
            shutdown_writers(write_queue, writer_processes)
        except Exception as shutdown_error:
            logger.error(f"Failed to shut down writer processes cleanly: {shutdown_error}")
        raise

    if args.merge_output:
        final_output_file = merge_output_files(output_file, writer_output_files)
        logger.info(f"Completed processing successfully, merged output written to {final_output_file}")
    else:
        logger.info(f"Completed processing successfully, outputs written to {args.num_workers} separate files")


if __name__ == "__main__":
    # Required: main() switches multiprocessing to the 'spawn' start method,
    # which re-imports this module in every child process.
    main()