| import torch |
| import torch.nn.functional as F |
| from torch.utils.data import IterableDataset, Dataset, DataLoader |
| import json |
| import numpy as np |
| from pathlib import Path |
| from typing import Iterator, List, Dict, Any, Callable, Tuple, Optional |
| import logging |
| import argparse |
| import base64 |
| import time |
| import math |
| import gc |
| from collections import defaultdict, Counter |
| from m1_compression import utils |
| from m1_compression.compressor import ( |
| load_m1_model_and_tokenizer, |
| ALPHABET_SIZE, |
| ) |
| import multiprocessing as mp |
| from m1_compression.enumerative_coder_simple import SimpleAdaptiveRankCodec |
| from offline_entropy_window_split import unpack_windows |
|
|
| logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') |
| logger = logging.getLogger(__name__) |
|
|
| def pseudo_to_packed_bytes(lst: list[int]) -> bytes: |
| out = bytearray() |
| acc = bits = 0 |
| for v in lst: |
| acc |= (v & 0x1FF) << bits |
| bits += 9 |
| while bits >= 8: |
| out.append(acc & 0xFF) |
| acc >>= 8 |
| bits -= 8 |
| if bits: |
| out.append(acc) |
| return bytes(out) |
|
|
| def packed_bytes_to_pseudo(b: bytes) -> list[int]: |
| out, acc, bits = [], 0, 0 |
| for byte in b: |
| acc |= byte << bits |
| bits += 8 |
| while bits >= 9: |
| out.append(acc & 0x1FF) |
| acc >>= 9 |
| bits -= 9 |
| return out |
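| # Worked example of the 9-bit packing above (hypothetical values, for illustration only): |
| # three pseudo-bytes occupy 27 bits, which pack into 4 real bytes, and unpacking drops the |
| # leftover padding bits (< 9), so the original values come back exactly: |
| # |
| # >>> vals = [5, 300, 511]  # each value fits in 9 bits (0..511) |
| # >>> packed = pseudo_to_packed_bytes(vals) |
| # >>> len(packed) |
| # 4 |
| # >>> packed_bytes_to_pseudo(packed) == vals |
| # True |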
|
|
|
|
| def calculate_compression_ratio(original_bytes: List[bytes], compressed_segments: List[bytes]) -> float: |
| if not compressed_segments or len(original_bytes) == 0: |
| return 1.0 |
| |
| total_compressed_length = sum(len(compressed_seg) for compressed_seg in compressed_segments) |
| ratio = total_compressed_length / sum(len(orig_seg) for orig_seg in original_bytes) |
| if ratio > 1.0: |
| logger.warning(f"Unusual compression ratio: {ratio:.4f} (compressed output is larger than the original)") |
| |
| return ratio |
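| # Example: originals totalling 6 bytes compressed into segments totalling 3 bytes give a |
| # ratio of 0.5 (lower is better); any ratio above 1.0 means the compressed output grew. |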
|
|
| def collect_window_size_statistics(segmented_results: List[List[bytes]]) -> Dict[int, int]: |
| window_size_counts = Counter() |
| |
| for segments in segmented_results: |
| for segment in segments: |
| window_size = len(segment) |
| window_size_counts[window_size] += 1 |
| |
| return dict(window_size_counts) |
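| # Example: segmented_results [[b"ab", b"cd"], [b"xyz"]] yields {2: 2, 3: 1} -- two windows |
| # of length 2 and one window of length 3. |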
|
|
| def pad_batch(batch: List[np.ndarray]) -> Tuple[torch.Tensor, torch.Tensor]: |
| batch_tensors = [torch.tensor(data, dtype=torch.int64) for data in batch] |
| lengths = torch.tensor([len(data) for data in batch], dtype=torch.int64) |
| padded_batch = torch.nn.utils.rnn.pad_sequence( |
| batch_tensors, |
| batch_first=True, |
| padding_value=0, |
| padding_side="right" |
| ) |
| return padded_batch, lengths |
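| # Shape sketch (hypothetical inputs): windows of length 2 and 4 pad to a (2, 4) int64 |
| # tensor, zero-padded on the right, with the true lengths reported separately: |
| # |
| # >>> a = np.frombuffer(b"ab", dtype=np.uint8) |
| # >>> b = np.frombuffer(b"abcd", dtype=np.uint8) |
| # >>> padded, lengths = pad_batch([a, b]) |
| # >>> padded.shape, lengths.tolist() |
| # (torch.Size([2, 4]), [2, 4]) |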
|
|
| def get_batch_size_for_length(window_len, max_batch_size): |
| """Determines the batch size for a given window length.""" |
| BATCH_SIZE_TIERS = { |
| 128: max_batch_size, |
| 512: max(max_batch_size // 64, 1), |
| 1024: max(max_batch_size // 128, 1), |
| 2048: max(max_batch_size // 256, 1), |
| } |
| for max_len, batch_size in BATCH_SIZE_TIERS.items(): |
| if window_len <= max_len: |
| return batch_size |
| return 1 |
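| # With max_batch_size=4096 the tiers resolve to: length <= 128 -> 4096, <= 512 -> 64, |
| # <= 1024 -> 32, <= 2048 -> 16, and anything longer falls through to a batch size of 1. |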
|
|
| def find_next_batch_range(all_windows, start_idx, max_m1_batch_size): |
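| """Return (start_idx, end_idx) for the next chunk of windows sharing one batch-size tier. |
| 
| The caller passes windows sorted by length (ascending), so the tier returned by |
| get_batch_size_for_length is monotone along the list and a binary search can locate the |
| first index whose tier differs from that of all_windows[start_idx]. |
| """ |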
| M = len(all_windows) |
| if start_idx >= M: |
| return start_idx, start_idx |
|
|
| first_window_len = len(all_windows[start_idx]) |
| base_batch_size = get_batch_size_for_length(first_window_len, max_m1_batch_size) |
|
|
| low = start_idx |
| high = min(start_idx + base_batch_size, M) |
| high_batch_size = get_batch_size_for_length(len(all_windows[high - 1]), max_m1_batch_size) |
| if high_batch_size == base_batch_size: |
| return start_idx, high |
|
|
| search_low = low |
| search_high = high |
| while search_low < search_high: |
| mid = search_low + (search_high - search_low) // 2 |
| mid_window_len = len(all_windows[mid]) |
| if get_batch_size_for_length(mid_window_len, max_m1_batch_size) == base_batch_size: |
| |
| |
| search_low = mid + 1 |
| else: |
| |
| |
| |
| search_high = mid |
| end_idx = search_low |
| if end_idx == start_idx: |
| return start_idx, start_idx + 1 |
| else: |
| return start_idx, end_idx |
|
|
|
|
| def simple_rle_topk_compression( |
| batch: List[bytes], |
| predict_fn: Callable, |
| first_byte_prob: torch.Tensor, |
| max_m1_batch_size: int = 4096, |
| debug: bool = False, |
| ): |
| """use language model to compress, return compressed bytes and padded bits |
| |
| Args: |
| sliding_windows: List of byte sequences to compress |
| predict_fn: Function that predicts next token probabilities |
| return_num_padded_bits: Whether to return number of padded bits |
| profile: Whether to print timing information for each major step |
| """ |
| if debug: |
| start_event = torch.cuda.Event(enable_timing=True) |
| end_event = torch.cuda.Event(enable_timing=True) |
| torch.cuda.synchronize() |
| start_event.record() |
| print("[Debug CUDA] timing started", flush=True) |
|
|
| assert first_byte_prob.shape == (1, 1, ALPHABET_SIZE), "first_byte_prob must be of shape (1, 1, ALPHABET_SIZE)" |
|
|
| |
| |
| batched_windows_np = [np.frombuffer(bytes(data), dtype=np.uint8) for data in batch] |
|
|
| M = len(batched_windows_np) |
|
|
| batched_repeat_probs = [] |
| batched_ranks = [] |
| batched_lengths = [] |
| if debug: |
| batched_sorted_indices = [] |
|
|
| start_idx = 0 |
| while start_idx < M: |
| |
| start_idx, end_idx = find_next_batch_range(batched_windows_np, start_idx, max_m1_batch_size) |
| windows_np_chunked = batched_windows_np[start_idx:end_idx] |
| padded_batched_windows, lengths = pad_batch(windows_np_chunked) |
| padded_batched_windows, lengths = padded_batched_windows.cuda(), lengths.cuda() |
|
|
| prompt_probs = predict_fn(padded_batched_windows) |
| prompt_probs = torch.cat( |
| [ |
| first_byte_prob.expand(prompt_probs.shape[0], -1, -1), |
| prompt_probs[:, :-1, ...] |
| ], |
| dim=1 |
| ) |
| prompt_probs = utils.batched_normalize_pdf_for_arithmetic_coding(prompt_probs) |
| |
| |
| |
| |
| next_token_probs = torch.gather( |
| prompt_probs, |
| dim=-1, |
| index=padded_batched_windows.unsqueeze(-1) |
| ).squeeze(-1) |
| sorted_indices = torch.argsort(prompt_probs, dim=-1, descending=True) |
| rank_bitvector = padded_batched_windows.unsqueeze(-1) == sorted_indices |
| ranks = torch.argmax(rank_bitvector.float(), dim=-1) |
| start_idx = end_idx |
| batched_repeat_probs.extend(next_token_probs.cpu().numpy().tolist()) |
| batched_ranks.extend(ranks.cpu().numpy().tolist()) |
| batched_lengths.extend(lengths.cpu().numpy().tolist()) |
| if debug: |
| batched_sorted_indices.extend(sorted_indices.cpu().numpy().tolist()) |
|
|
| if debug: |
| end_event.record() |
| torch.cuda.synchronize() |
| print(f"[Debug CUDA] elapsed {start_event.elapsed_time(end_event):.1f} ms", flush=True) |
| return batched_repeat_probs, batched_ranks, batched_lengths, batched_sorted_indices |
| else: |
| return batched_repeat_probs, batched_ranks, batched_lengths |
|
|
| class JsonlShardedDataset(Dataset): |
| def __init__( |
| self, |
| file_path: str, |
| current_proc_rank: int = 0, |
| total_procs: int = 1, |
| ) -> None: |
|
|
| assert 0 <= current_proc_rank < total_procs, "current_proc_rank must be in [0, total_procs)" |
| self.current_proc_rank = current_proc_rank |
| self.total_procs = total_procs |
|
|
| |
| with open(file_path, "r", encoding="utf-8") as f: |
| full_data: List[Dict[str, Any]] = [json.loads(line) for line in f] |
|
|
| |
| total = len(full_data) |
| per_proc = math.ceil(total / total_procs) |
| start = current_proc_rank * per_proc |
| end = min(start + per_proc, total) |
| self.data = full_data[start:end] |
|
|
| def __len__(self) -> int: |
| return len(self.data) |
|
|
| def __getitem__(self, idx: int) -> Dict[str, Any]: |
| return self.data[idx] |
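| # Sharding sketch: with 10 lines and total_procs=3, per_proc = ceil(10/3) = 4, so rank 0 |
| # receives lines [0:4], rank 1 receives [4:8], and rank 2 receives the remaining [8:10]. |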
|
|
|
|
| class InterleavedJsonlDataset(IterableDataset): |
| """ |
| An iterable-style dataset for reading a large JSONL file using an |
| interleaving/striding pattern, without yielding state information. |
| |
| This is designed for multi-process data loading. Each process reads the |
| entire file but only processes lines that match its rank (offset). |
| For `N` total processes (world_size), process `r` (rank) will read |
| lines r, r+N, r+2N, ... (0-indexed). |
| |
| This method ensures an even distribution of lines across processes. |
| |
| Args: |
| file_path (str): Path to the JSONL file. |
| rank (int): The rank of the current process, used as the offset. |
| world_size (int): The total number of processes, used as the block_size/stride. |
| """ |
| def __init__( |
| self, |
| file_path: str, |
| rank: int, |
| world_size: int, |
| ) -> None: |
| super().__init__() |
| |
| if not (0 <= rank < world_size): |
| raise ValueError(f"Rank must be in [0, {world_size-1}], but got {rank}") |
|
|
| self.file_path = file_path |
| self.offset = rank |
| self.block_size = world_size |
|
|
| def __iter__(self) -> Iterator[Dict[str, Any]]: |
| """ |
| The iterator method that yields the parsed JSON data for the assigned lines. |
| """ |
| try: |
| with open(self.file_path, "r", encoding="utf-8") as f: |
| |
| |
| for line_number, line in enumerate(f): |
| |
| if (line_number % self.block_size) == self.offset: |
| try: |
| |
| yield json.loads(line) |
| except json.JSONDecodeError: |
| |
| |
| print(f"Warning: Rank {self.offset} could not decode JSON on line ~{line_number+1}. Skipping.") |
| continue |
| except Exception as e: |
| print(f"Error in worker {self.offset}: {e}") |
| raise |
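| # Usage sketch (hypothetical file name): with world_size=2, rank 0 yields lines 0, 2, 4, ... |
| # and rank 1 yields lines 1, 3, 5, ..., so every line is consumed exactly once: |
| # |
| # ds = InterleavedJsonlDataset("data.jsonl", rank=0, world_size=2) |
| # for record in ds: |
| #     ...  # parsed dict from lines 0, 2, 4, ... |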
|
|
| def batched_m1_compress_predict_fn(model): |
| def predict_fn(input_tensor: torch.Tensor, **kwargs) -> torch.Tensor: |
| if input_tensor.dim() == 1: |
| input_tensor = input_tensor.unsqueeze(0) |
| with torch.no_grad(): |
| |
| logits = model(input_tensor, **kwargs) |
| logits = logits[..., :256] |
| logits = logits.float() |
| assert torch.isfinite(logits).all(), "Logits contain NaN or Inf values." |
| probs = torch.softmax(logits, dim=-1) |
| return probs |
| |
| return predict_fn |
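| # The returned predict_fn maps an int64 tensor of byte ids with shape (B, T) to a |
| # probability tensor of shape (B, T, 256): logits are truncated to the 256-entry byte |
| # vocabulary and softmaxed per position. |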
|
|
| def segment_prediction_fn( |
| batch: List[Dict[str, Any]], |
| max_m1_batch_size, |
| batched_predict_fn, |
| first_byte_prob, |
| debug |
| ): |
| """ |
| Consumer: reads from task_queue, compresses, puts result in result_queue. |
| """ |
| all_segments = [] |
| compressed_or_raw_segments = [] |
| sample_idx_to_list_segment_idx = defaultdict(list) |
| segment_idx = 0 |
| for sample_idx, item in enumerate(batch): |
| assert "windows_starts_lens_b64" in item, "windows_starts_lens_b64 must be in item" |
| sample_bytes = item["text"].encode('utf-8') |
| byte_windows = unpack_windows(sample_bytes, item["windows_starts_lens_b64"]) |
| for byte_window_indicator in byte_windows: |
| all_segments.append(byte_window_indicator[0]) |
| compressed_or_raw_segments.append(byte_window_indicator[1]) |
| sample_idx_to_list_segment_idx[sample_idx].append(segment_idx) |
| segment_idx += 1 |
|
|
| effective_segments = [] |
| ineffective_segments = [] |
| for orig_idx, (segment, indicator) in enumerate(zip(all_segments, compressed_or_raw_segments)): |
| if len(segment) > 3 and indicator == 1: |
| effective_segments.append((orig_idx, segment)) |
| else: |
| ineffective_segments.append((orig_idx, segment)) |
| |
| |
| # Sort compressible segments by length so same-sized windows are batched together. |
| sorted_effective_segments = sorted(effective_segments, key=lambda x: len(x[1])) |
| # Guard the unpacking: zip(*[]) cannot be unpacked when a batch has only one kind of segment. |
| sorted_idx, sorted_segments = zip(*sorted_effective_segments) if sorted_effective_segments else ((), ()) |
| sorted_segments = list(sorted_segments) |
| effective_segments_idx_map = { |
| orig_idx: new_idx |
| for new_idx, orig_idx in enumerate(sorted_idx) |
| } |
| raw_idx, raw_segments = zip(*ineffective_segments) if ineffective_segments else ((), ()) |
| raw_segments = list(raw_segments) |
| ineffective_segments_idx_map = { |
| orig_idx: new_idx |
| for new_idx, orig_idx in enumerate(raw_idx) |
| } |
| |
| batch_ret = simple_rle_topk_compression( |
| sorted_segments, |
| batched_predict_fn, |
| first_byte_prob, |
| max_m1_batch_size=max_m1_batch_size, |
| debug=debug, |
| ) |
| if debug: |
| batched_repeat_probs, batched_ranks, batched_lengths, batched_sorted_indices = batch_ret |
| else: |
| batched_repeat_probs, batched_ranks, batched_lengths = batch_ret |
| batched_sorted_indices = None |
| return ( |
| batch, |
| sorted_segments, |
| raw_segments, |
| effective_segments_idx_map, |
| ineffective_segments_idx_map, |
| sample_idx_to_list_segment_idx, |
| batched_repeat_probs, |
| batched_ranks, |
| batched_lengths, |
| batched_sorted_indices, |
| debug |
| ) |
|
|
| def segment_compression_fn( |
| batch: List[Dict[str, Any]], |
| sorted_segments: List[List[int]], |
| raw_segments: List[List[int]], |
| effective_segments_idx_map: Dict[int, int], |
| ineffective_segments_idx_map: Dict[int, int], |
| sample_idx_to_list_segment_idx: Dict[int, List[int]], |
| batched_repeat_probs: List[List[float]], |
| batched_ranks: List[List[int]], |
| batched_lengths: List[int], |
| batched_sorted_indices: Optional[List[List[int]]] = None, |
| debug: bool = False, |
| ): |
| M = len(batched_lengths) |
| batched_compressed_bytes = [] |
|
|
| for i in range(M): |
| lengths = batched_lengths[i] |
| window_bytes = sorted_segments[i] |
| repeat_probs = batched_repeat_probs[i][:lengths] |
| ranks = batched_ranks[i][:lengths] |
| codec = SimpleAdaptiveRankCodec(top_k=4) |
| encoding = codec.encode_window(list(window_bytes), repeat_probs, ranks) |
| compressed_bytes = codec.encoding_to_pseudo_bytes(encoding) |
| if debug: |
| sorted_indices = batched_sorted_indices[i][:lengths] |
| decoded = codec.decode_window(encoding, lengths, sorted_indices) |
| assert bytes(decoded) == window_bytes, "decoded does not match window_bytes: \n{} and \n{}".format(decoded, window_bytes) |
| if i < 10: |
| logger.info(f"Example input window bytes: {window_bytes}") |
| logger.info(f"Example encoding : {encoding}") |
| logger.info(f"Example compressed bytes : {compressed_bytes}") |
| batched_compressed_bytes.append(compressed_bytes) |
|
|
|
|
| |
| compressed_bytes = [[] for _ in range(len(batch))] |
| original_bytes = [[] for _ in range(len(batch))] |
|
|
| for sample_idx, list_segment_idx in sample_idx_to_list_segment_idx.items(): |
| for segment_idx in list_segment_idx: |
| if segment_idx in effective_segments_idx_map: |
| compressed_idx = effective_segments_idx_map[segment_idx] |
| compressed_byte = batched_compressed_bytes[compressed_idx] |
| else: |
| raw_idx = ineffective_segments_idx_map[segment_idx] |
| compressed_byte = raw_segments[raw_idx] |
| |
| if debug: |
| if segment_idx in effective_segments_idx_map: |
| compressed_idx = effective_segments_idx_map[segment_idx] |
| original_byte = sorted_segments[compressed_idx] |
| else: |
| raw_idx = ineffective_segments_idx_map[segment_idx] |
| original_byte = raw_segments[raw_idx] |
| original_bytes[sample_idx].append(original_byte) |
| |
| compressed_bytes[sample_idx].extend(list(compressed_byte)) |
|
|
| batched_compressed_bytes = [] |
| if debug: |
| assert len(compressed_bytes) == len(batch) |
| for sample_idx in range(len(batch)): |
| assert b"".join(original_bytes[sample_idx]) == batch[sample_idx]["text"].encode('utf-8'), ( |
| "Assembled original bytes does not match the original batch: \n{} and \n{}".format( |
| b"".join(original_bytes[sample_idx]), batch[sample_idx] |
| ) |
| ) |
| |
| |
| |
| |
| logger.info(f"Example compressed bytes: {compressed_bytes[0]}") |
|
|
| write_results = [] |
| ac_key = "m1_enumerative" |
| for item, compressed_bytes_item in zip(batch, compressed_bytes): |
| compressed = pseudo_to_packed_bytes(compressed_bytes_item) |
|
|
| result = { |
| **item, |
| ac_key: base64.b64encode(compressed).decode("ascii") |
| } |
| if debug: |
| unpacked = packed_bytes_to_pseudo(compressed) |
| assert unpacked == compressed_bytes_item, "Unpacked does not match compressed bytes item: \n{} and \n{}".format(unpacked, compressed_bytes_item) |
| logger.info("✓ pseudo-bytes-enc-dec round-trip passes") |
| write_results.append(result) |
| orig_total_bytes = sum(len(data["text"].encode('utf-8')) for data in batch) |
| # compressed_bytes holds 9-bit pseudo-byte symbols, so this counts symbols rather than packed bytes. |
| compressed_total_symbols = sum(len(data) for data in compressed_bytes) |
| compression_ratio = orig_total_bytes / compressed_total_symbols if compressed_total_symbols > 0 else 0 |
| logger.info(f"[DEBUG] original total bytes: {orig_total_bytes}, compressed pseudo-byte symbols: {compressed_total_symbols}, original/compressed ratio: {compression_ratio:.3f}") |
| return write_results |
|
|
| def writer_consumer(write_queue, output_file, buffer_size=100): |
| """ |
| Writer consumer: reads compressed results from write_queue and writes to file. |
| Maintains its own buffer and writes when buffer is full or receives sentinel. |
| """ |
| write_buf = [] |
| |
| try: |
| with open(output_file, 'w', encoding='utf-8') as f: |
| while True: |
| args = write_queue.get() |
| if args is None: |
| break |
| ( |
| batch, |
| sorted_segments, |
| raw_segments, |
| effective_segments_idx_map, |
| ineffective_segments_idx_map, |
| sample_idx_to_list_segment_idx, |
| batched_repeat_probs, |
| batched_ranks, |
| batched_lengths, |
| batched_sorted_indices, |
| debug |
| ) = args |
| write_results = segment_compression_fn( |
| batch, |
| sorted_segments, |
| raw_segments, |
| effective_segments_idx_map, |
| ineffective_segments_idx_map, |
| sample_idx_to_list_segment_idx, |
| batched_repeat_probs, |
| batched_ranks, |
| batched_lengths, |
| batched_sorted_indices=batched_sorted_indices, |
| debug=debug |
| ) |
| write_buf.extend(write_results) |
| |
| |
| if len(write_buf) >= buffer_size: |
| logger.info(f"Writer: Dumping buffer of {len(write_buf)} items to {output_file}") |
| for buffered_item in write_buf: |
| f.write(json.dumps(buffered_item) + '\n') |
| f.flush() |
| write_buf = [] |
|
|
| |
| if write_buf: |
| logger.info(f"Writer: Dumping remaining {len(write_buf)} items to {output_file}") |
| for buffered_item in write_buf: |
| f.write(json.dumps(buffered_item) + '\n') |
| f.flush() |
| |
| except Exception as e: |
| logger.error(f"Writer process error: {e}") |
| raise |
|
|
| def merge_output_files(output_file, writer_output_files): |
| """Merge all writer output files into a single file""" |
| logger.info(f"Merging {len(writer_output_files)} writer files into {output_file}") |
| |
| with open(output_file, 'w', encoding='utf-8') as outf: |
| for writer_output_file in writer_output_files: |
| if writer_output_file.exists(): |
| with open(writer_output_file, 'r', encoding='utf-8') as inf: |
| for line in inf: |
| outf.write(line) |
| |
| writer_output_file.unlink() |
| logger.info(f"Merged and removed writer file: {writer_output_file}") |
| |
| logger.info(f"Merged output written to: {output_file}") |
| return output_file |
|
|
| def shutdown_writers(write_queue, writer_processes): |
| """Send shutdown signals to shared queue and wait for all writers to complete""" |
| |
| for i in range(len(writer_processes)): |
| write_queue.put(None) |
| logger.info(f"Sent shutdown signal {i+1}/{len(writer_processes)}") |
| |
| |
| for i, writer_process in enumerate(writer_processes): |
| writer_process.join() |
| if writer_process.exitcode != 0: |
| logger.error(f"Writer process {i} failed with exit code: {writer_process.exitcode}") |
| else: |
| logger.info(f"Writer process {i} completed successfully") |
|
|
|
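| # Pipeline overview: the main process acts as the GPU producer (it batches windows and runs |
| # segment_prediction_fn), prediction results go onto write_queue, and each of the |
| # --num_workers writer processes runs writer_consumer to entropy-code the segments and write |
| # its own <output>_writer_<i>.jsonl file; --merge_output concatenates them at the end. |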
|
| def main(): |
| |
| parser = argparse.ArgumentParser(description='Process JSONL files using M1 arithmetic compression with buffer-based approach') |
| parser.add_argument('--input_file', type=str, required=True, |
| help='Path to the input JSONL file') |
| parser.add_argument('--output_dir', type=str, required=True, |
| help='Directory to write compressed results') |
| parser.add_argument('--entropy_model_path', type=str, required=True, |
| help='Path to the M1 entropy model checkpoint') |
| parser.add_argument('--compression_model_path', type=str, required=True, |
| help='Path to the M1 compression model checkpoint') |
| parser.add_argument('--data_batch_size', type=int, default=512, |
| help='Size of batches for processing (default: 512)') |
| parser.add_argument('--output_window_size', type=int, default=16, |
| help='Size of window for compression (default: 16)') |
| parser.add_argument('--max_window_size', type=int, default=1024, |
| help='Maximum window size for reading from each file (default: 1024)') |
| parser.add_argument('--max_entropy_batch_size', type=int, default=4096, |
| help='Maximum model batch size for the entropy model (default: 4096)') |
| parser.add_argument('--max_compression_batch_size', type=int, default=4096, |
| help='Maximum model batch size for compression (default: 4096)') |
| parser.add_argument('--chunk_size', type=int, default=512, |
| help='Size of chunk for compression (default: 512)') |
| parser.add_argument('--base_global_quantile', type=float, default=0.9, |
| help='Base global quantile for compression (default: 0.9)') |
| parser.add_argument('--base_monotonic_quantile', type=float, default=0.9, |
| help='Base monotonic quantile for compression (default: 0.9)') |
| parser.add_argument('--debug', action='store_true', default=False, |
| help='Debug mode (default: False)') |
| parser.add_argument('--firstbyte_prob_path', type=str, default=None, |
| help='Path to a JSON file with the probability distribution for the first byte of each window (default: None)') |
| parser.add_argument('--num_workers', type=int, default=1, |
| help='Number of CPU writer processes (default: 1)') |
| parser.add_argument('--process_id', type=int, default=0, |
| help='Process ID for distributed processing (default: 0)') |
| parser.add_argument('--num_processes', type=int, default=1, |
| help='Number of processes for distributed processing (default: 1)') |
| parser.add_argument('--merge_output', action='store_true', default=False, |
| help='Merge all writer output files into a single file (default: False)') |
|
|
| args = parser.parse_args() |
|
|
| mp.set_start_method('spawn', force=True) |
| gc_freq = 100 |
| dump_freq = 25 |
|
|
| |
| output_dir = Path(args.output_dir) |
| output_dir.mkdir(parents=True, exist_ok=True) |
|
|
| |
| model, _, _ = load_m1_model_and_tokenizer(args.entropy_model_path) |
| batched_predict_fn = batched_m1_compress_predict_fn(model) |
|
|
| if args.firstbyte_prob_path is not None: |
| with open(args.firstbyte_prob_path, 'r', encoding='utf-8') as f: |
| first_byte_prob = json.load(f) |
| logger.info(f"Loaded first-byte probability table from {args.firstbyte_prob_path}") |
| first_byte_prob = torch.tensor(first_byte_prob, dtype=torch.float32, device="cuda").unsqueeze(0).unsqueeze(0) |
| else: |
| first_byte_prob = torch.ones((1, 1, ALPHABET_SIZE), dtype=torch.float32, device="cuda") / ALPHABET_SIZE |
|
|
| |
| dataset = InterleavedJsonlDataset( |
| file_path=args.input_file, |
| rank=args.process_id, |
| world_size=args.num_processes, |
| ) |
| dataloader = DataLoader( |
| dataset, |
| batch_size=args.data_batch_size, |
| shuffle=False, |
| collate_fn=lambda x: x |
| ) |
|
|
| input_file = Path(args.input_file) |
| logger.info(f"Processing file: {input_file}") |
| |
| output_file = output_dir / f"{input_file.stem}_out_{args.process_id}.jsonl" |
| |
| logger.info("Data loaded. Start processing...") |
|
|
| write_queue = mp.Queue(maxsize=200) |
| writer_processes = [] |
| writer_output_files = [] |
| for i in range(args.num_workers): |
| |
| output_path = Path(output_file) |
| writer_output_file = output_path.parent / f"{output_path.stem}_writer_{i}.jsonl" |
| writer_output_files.append(writer_output_file) |
| writer_process = mp.Process( |
| target=writer_consumer, |
| args=(write_queue, writer_output_file, dump_freq) |
| ) |
| |
| writer_processes.append(writer_process) |
| writer_process.start() |
| logger.info(f"Started writer process {i} for output file: {writer_output_file}") |
| |
| try: |
| |
| for batch_idx, batch in enumerate(dataloader): |
| pred_results = segment_prediction_fn( |
| batch, |
| max_m1_batch_size=args.max_compression_batch_size, |
| batched_predict_fn=batched_predict_fn, |
| first_byte_prob=first_byte_prob, |
| debug=args.debug, |
| ) |
| logger.info(f"Processed batch {batch_idx}") |
| write_queue.put(pred_results) |
|
|
| if batch_idx % gc_freq == 0: |
| |
| gc.collect() |
| if torch.cuda.is_available(): |
| torch.cuda.empty_cache() |
| |
| |
| shutdown_writers(write_queue, writer_processes) |
| |
| except Exception as e: |
| logger.error(f"Error during processing: {e}") |
| |
| try: |
| shutdown_writers(write_queue, writer_processes) |
| except Exception: |
| pass |
| raise |
|
|
| if args.merge_output: |
| final_output_file = merge_output_files(output_file, writer_output_files) |
| logger.info(f"Completed processing successfully, merged output written to {final_output_file}") |
| else: |
| logger.info(f"Completed processing successfully, outputs written to {args.num_workers} separate files") |
|
|
| if __name__ == "__main__": |
| main() |
|
|