import json
import base64
import argparse
import os
import sys
import gzip
import torch
import torch.multiprocessing as mp
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from typing import List, Callable, Tuple, Optional
import Levenshtein
from collections import defaultdict


# Make the CWD and the script's own directory importable so that sibling
# packages (e.g. m1_compression) resolve regardless of where this runs.
current_dir = os.getcwd()
if current_dir not in sys.path:
    sys.path.append(current_dir)
script_dir = os.path.dirname(os.path.abspath(__file__))
if script_dir not in sys.path:
    sys.path.append(script_dir)

print(f"🔧 System Path Fixed. CWD: {current_dir}")


try:
    from transformers import AutoTokenizer
except ImportError:
    print("❌ Error: 'transformers' not installed.")
    sys.exit(1)

try:
    from m1_compression import utils
    from m1_compression.compressor import (
        load_m1_model_and_tokenizer,
        ALPHABET_SIZE,
        ARITHMETIC_CODER_BASE,
        ARITHMETIC_CODER_PRECISION,
    )
    from m1_compression.hybrid_arithmetic_coder import CPUArithmeticEncoder
    from m1_compression.batched_arithmetic_coder import _pdf_to_cdf
    print("✅ Successfully imported m1_compression modules.")
except ImportError as e:
    print(f"❌ FATAL ERROR: {e}")
    sys.exit(1)
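

# Read one unsigned varint (LEB128-style: 7 payload bits per byte, high bit
# flags continuation) from `buf` starting at offset `i`.
# Returns (value, offset_after_varint).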
def vread(buf: bytes, i: int):
    shift = val = 0
    while True:
        b = buf[i]
        i += 1
        val |= (b & 0x7F) << shift
        if b < 0x80:
            return val, i
        shift += 7
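

# Decode the base64 varint stream of (gap, size) pairs back into byte
# windows over `input_bytes`. Each entry yields (segment, flag): flag 1 for
# a window listed in the stream, flag 0 for the gap bytes between/after
# windows.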
def unpack_windows(input_bytes: bytes, b64_stream: str) -> List[Tuple[bytes, int]]:
    try:
        if not b64_stream:
            return []
        buf = base64.b64decode(b64_stream)
        i, cursor, byte_windows = 0, 0, []
        while i < len(buf):
            gap, i = vread(buf, i)
            size, i = vread(buf, i)
            start = cursor + gap
            if gap > 0:
                byte_windows.append((input_bytes[cursor:start], 0))
            end = start + size
            byte_windows.append((input_bytes[start:end], 1))
            cursor = end
        if cursor < len(input_bytes):
            byte_windows.append((input_bytes[cursor:], 0))
        return byte_windows
    except (base64.binascii.Error, IndexError):
        return []
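

# Levenshtein.distance compares strings, so encode each integer symbol as a
# unicode character (clamped to the max code point) in order to measure
# edit distance over arbitrary symbol sequences.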
def list_to_comparable_str(int_list: List[int]) -> str:
    return "".join([chr(min(x, 0x10FFFF)) for x in int_list])
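

# Right-pad a batch of byte strings into a single (batch, max_len) int64
# tensor; also return each sequence's true length.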
def pad_batch(batch: List[bytes]):
    batch_tensors = [torch.tensor(list(data), dtype=torch.int64) for data in batch]
    lengths = torch.tensor([len(data) for data in batch], dtype=torch.int64)
    padded_batch = torch.nn.utils.rnn.pad_sequence(
        batch_tensors,
        batch_first=True,
        padding_value=0,
    )
    return padded_batch, lengths
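

# Wrap the M1 model as a batched next-byte predictor: logits are truncated
# to the 256-symbol byte alphabet and softmaxed into per-position PDFs.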
def batched_m1_compress_predict_fn(model):
    def predict_fn(input_tensor: torch.Tensor, **kwargs) -> torch.Tensor:
        if input_tensor.dim() == 1:
            input_tensor = input_tensor.unsqueeze(0)
        with torch.no_grad():
            logits = model(input_tensor, **kwargs)
            logits = logits[..., :256].float()
            probs = torch.softmax(logits, dim=-1)
            return probs
    return predict_fn


def compress_segments_ac_impl(
    sorted_segments: List[bytes],
    batched_predict_fn: Callable,
    first_byte_prob: torch.Tensor,
    device: torch.device,
) -> List[List[int]]:
    """
    Low-level batched routine: takes a flat list of segments, runs the
    model over them in GPU batches, then arithmetic-codes them on the CPU.
    """
    M = len(sorted_segments)
    if M == 0:
        return []

    # Deduplicate: identical segments are compressed once and fanned back
    # out to every occurrence afterwards.
    segment_to_indices = defaultdict(list)
    for i, seg in enumerate(sorted_segments):
        segment_to_indices[seg].append(i)

    unique_segments = [seg for seg in segment_to_indices.keys() if len(seg) > 0]
    segment_to_compressed = {}

    encoder = CPUArithmeticEncoder(base=ARITHMETIC_CODER_BASE, precision=ARITHMETIC_CODER_PRECISION)

    # Number of segments per forward pass.
    GPU_BATCH_SIZE = 256

    for i in range(0, len(unique_segments), GPU_BATCH_SIZE):
        batch_segments = unique_segments[i : i + GPU_BATCH_SIZE]

        try:
            padded_batch, lengths = pad_batch(batch_segments)
            padded_batch = padded_batch.to(device)

            with torch.no_grad():
                prompt_probs = batched_predict_fn(padded_batch)

            # Shift the predictions right by one position: byte t is coded
            # under the distribution conditioned on bytes < t, and byte 0
            # under the unigram first-byte prior.
            final_probs = torch.cat(
                [
                    first_byte_prob.expand(prompt_probs.shape[0], -1, -1),
                    prompt_probs[:, :-1, ...],
                ],
                dim=1,
            )

            final_probs = utils.batched_normalize_pdf_for_arithmetic_coding(final_probs)
            cdfs_gpu = _pdf_to_cdf(final_probs)

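            # Gather, for every position, the CDF values at the actual byte
            # and at byte + 1: the two endpoints of that symbol's coding
            # interval.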
            cdf_low = cdfs_gpu.gather(2, padded_batch.unsqueeze(-1)).squeeze(-1)
            cdf_high = cdfs_gpu.gather(2, (padded_batch + 1).unsqueeze(-1)).squeeze(-1)
            cdf_ends = torch.stack([cdf_low, cdf_high], dim=-1)

            chunked_compressed_bytes, _, _ = encoder.incremental_batched_encode(
                cdf_ends.cpu(),
                ALPHABET_SIZE,
                lengths,
                bit_threshold=16,
                force_padding_to_threshold=False,
                return_num_padded_bits=True,
            )

            for seg, code in zip(batch_segments, chunked_compressed_bytes):
                segment_to_compressed[seg] = list(code)

        except Exception as e:
            # Don't swallow failures silently; fall back to the raw bytes so
            # the distance computation still has something to compare.
            print(f"⚠️ AC encoding batch failed ({e}); falling back to raw bytes.")
            for seg in batch_segments:
                segment_to_compressed[seg] = list(seg)

    # Fan results back out to the original order; segments never encoded
    # (e.g. empty ones) fall through to their raw bytes.
    all_results = [None] * M
    for seg, indices in segment_to_indices.items():
        res = segment_to_compressed.get(seg, list(seg))
        for idx in indices:
            all_results[idx] = res

    return all_results
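

# Holds one M1 model on a single GPU together with the first-byte prior,
# and exposes batched compression of (text, window-metadata) pairs.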
class M1ACManager:
    def __init__(self, model_path, first_prob_path, device_id):
        self.device = torch.device(f"cuda:{device_id}")
        print(f"[GPU {device_id}] Loading M1 Model...")

        self.model, _, _ = load_m1_model_and_tokenizer(model_path)
        self.model.to(self.device)
        self.model.eval()
        self.predict_fn = batched_m1_compress_predict_fn(self.model)

        if first_prob_path and os.path.exists(first_prob_path):
            with open(first_prob_path, 'r') as f:
                prob_data = json.load(f)
            self.first_byte_prob = torch.tensor(prob_data, dtype=torch.float32, device=self.device)
            if self.first_byte_prob.dim() == 1:
                self.first_byte_prob = self.first_byte_prob.unsqueeze(0).unsqueeze(0)
        else:
            # Fall back to a uniform prior over the byte alphabet.
            self.first_byte_prob = torch.ones((1, 1, ALPHABET_SIZE), dtype=torch.float32, device=self.device) / ALPHABET_SIZE

    def compress_batch(self, inputs: List[Tuple[str, Optional[str]]]) -> List[List[int]]:
        """
        Batched compression interface.
        inputs: list of (text, windows_b64) pairs.
        Returns: one compressed int list per input, in order.
        """
        all_segments_flat = []
        # (start, end) index pairs mapping each sample to its slice of the
        # flat segment list.
        sample_segment_map = []
        current_idx = 0

        for text, windows_b64 in inputs:
            raw_bytes = text.encode('utf-8')
            sample_segs = []

            if windows_b64:
                # Use the precomputed windows when metadata is available.
                for seg, _flag in unpack_windows(raw_bytes, windows_b64):
                    if len(seg) > 0:
                        sample_segs.append(seg)
            else:
                # Otherwise fall back to fixed-size chunks.
                CHUNK = 512
                for i in range(0, len(raw_bytes), CHUNK):
                    sample_segs.append(raw_bytes[i : i + CHUNK])

            count = len(sample_segs)
            sample_segment_map.append((current_idx, current_idx + count))
            all_segments_flat.extend(sample_segs)
            current_idx += count

        if not all_segments_flat:
            return [[] for _ in inputs]

        compressed_chunks_flat = compress_segments_ac_impl(
            all_segments_flat, self.predict_fn, self.first_byte_prob, self.device
        )

        # Reassemble: concatenate each sample's compressed chunks into one
        # flat stream.
        results = []
        for start, end in sample_segment_map:
            sample_chunks = compressed_chunks_flat[start:end]
            full_stream = [x for chunk in sample_chunks for x in chunk]
            results.append(full_stream)

        return results
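

# Per-process worker: pins one GPU, loads a tokenizer and the M1 coder,
# then streams one JSONL shard, recording how far each compressed
# representation drifts when the input text is perturbed.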
def process_file_worker(rank, gpu_id, file_path, output_dir, model_path, prob_path, max_lines):
    try:
        torch.cuda.set_device(gpu_id)

        try:
            tokenizer = AutoTokenizer.from_pretrained("infly/OpenCoder-1.5B-Base", trust_remote_code=True)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
        except Exception:
            # Fall back to a generic tokenizer if the primary one is unavailable.
            tokenizer = AutoTokenizer.from_pretrained("gpt2")

        try:
            ac_manager = M1ACManager(model_path, prob_path, gpu_id)
        except Exception as e:
            print(f"❌ [GPU {gpu_id}] Init Failed: {e}")
            return

        results = {"Gzip": [], "Tokenizer": [], "AC_M1": []}
        filename = os.path.basename(file_path)
        print(f"[GPU {gpu_id}] Processing {filename}...")

        # Number of samples to accumulate before a batched flush.
        WORKER_BATCH_SIZE = 200

        batch_texts = []
        batch_pert_texts = []
        batch_metas = []

        processed_count = 0
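
        # For each (original, perturbed) pair, compute the edit distance
        # between the two compressed/tokenized representations, normalized
        # by the length of the original's representation.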
        def flush_batch():
            nonlocal batch_texts, batch_pert_texts, batch_metas
            if not batch_texts:
                return

            # Baseline 1: gzip byte streams.
            for t, tp in zip(batch_texts, batch_pert_texts):
                gz1 = list(gzip.compress(t.encode('utf-8')))
                gz2 = list(gzip.compress(tp.encode('utf-8')))
                if gz1:
                    d = Levenshtein.distance(list_to_comparable_str(gz1), list_to_comparable_str(gz2))
                    results["Gzip"].append(d / len(gz1))

            # Baseline 2: subword token IDs.
            for t, tp in zip(batch_texts, batch_pert_texts):
                tok1 = tokenizer.encode(t, add_special_tokens=False)
                tok2 = tokenizer.encode(tp, add_special_tokens=False)
                if tok1:
                    d = Levenshtein.distance(list_to_comparable_str(tok1), list_to_comparable_str(tok2))
                    results["Tokenizer"].append(d / len(tok1))

            # Candidate: M1 arithmetic coding. Window metadata only describes
            # the original text, so perturbed texts use the chunk fallback.
            orig_inputs = list(zip(batch_texts, batch_metas))
            pert_inputs = list(zip(batch_pert_texts, [None] * len(batch_pert_texts)))

            try:
                ac1_list = ac_manager.compress_batch(orig_inputs)
                ac2_list = ac_manager.compress_batch(pert_inputs)

                for ac1, ac2 in zip(ac1_list, ac2_list):
                    if ac1:
                        d = Levenshtein.distance(list_to_comparable_str(ac1), list_to_comparable_str(ac2))
                        results["AC_M1"].append(d / len(ac1))
            except Exception as e:
                print(f"[GPU {gpu_id}] AC Batch Error: {e}")

            batch_texts, batch_pert_texts, batch_metas = [], [], []

        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
            for i, line in enumerate(f):
                if max_lines > 0 and i >= max_lines:
                    break
                try:
                    data = json.loads(line)
                    text = data.get('text', '')
                    windows_b64 = data.get('windows_starts_lens_b64')
                    if not text or len(text) < 100:
                        continue

                    # Perturbation: drop the leading 10% of the characters.
                    cut_idx = max(1, int(len(text) * 0.1))
                    text_pert = text[cut_idx:]

                    batch_texts.append(text)
                    batch_pert_texts.append(text_pert)
                    batch_metas.append(windows_b64)

                    processed_count += 1

                    if len(batch_texts) >= WORKER_BATCH_SIZE:
                        flush_batch()
                        if processed_count % 500 == 0:
                            print(f"[GPU {gpu_id}] Processed {processed_count} lines...")

                except Exception:
                    continue

        # Flush whatever remains in the final partial batch.
        flush_batch()

        output_file = os.path.join(output_dir, f"partial_result_{rank}_{filename}.json")
        with open(output_file, 'w') as f:
            json.dump(results, f)
        print(f"✅ [GPU {gpu_id}] Done {filename}. Total: {processed_count}")

    except Exception as e:
        print(f"❌ [GPU {gpu_id}] Worker failed: {e}")
        import traceback
        traceback.print_exc()
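

# Entry point: launch one worker process per input file (at most one per
# GPU), merge the partial results, plot the distance distributions, and
# write summary statistics.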
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", type=str, required=True)
    parser.add_argument("--m1_model", type=str, required=True)
    parser.add_argument("--first_prob_path", type=str, required=True)
    parser.add_argument("-o", "--output_dir", type=str, default="analysis_output_parallel")
    parser.add_argument("--max_lines", type=int, default=10000)
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    files = [os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir) if f.endswith('.jsonl') and "writer" not in f]
    files.sort()

    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        print("❌ No CUDA devices available.")
        return

    # One file per process; files beyond the GPU count are skipped.
    if len(files) > num_gpus:
        files = files[:num_gpus]

    actual_procs = len(files)
    print(f"🚀 Launching {actual_procs} processes (Batch Mode)...")

    mp.set_start_method('spawn', force=True)
    processes = []

    for rank in range(actual_procs):
        p = mp.Process(
            target=process_file_worker,
            args=(rank, rank % num_gpus, files[rank], args.output_dir, args.m1_model, args.first_prob_path, args.max_lines),
        )
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

    print("✅ All workers finished. Merging results...")

    final_results = {"Gzip": [], "Tokenizer": [], "AC_M1": []}
    for filename in os.listdir(args.output_dir):
        if filename.startswith("partial_result_") and filename.endswith(".json"):
            try:
                with open(os.path.join(args.output_dir, filename), 'r') as f:
                    data = json.load(f)
                for k in final_results:
                    if k in data:
                        final_results[k].extend(data[k])
            except Exception:
                pass

    for k, v in final_results.items():
        print(f" -> {k}: {len(v)} samples collected.")

    # Drop extreme outliers (normalized distance >= 2.0) before plotting.
    plot_records = []
    for algo, vals in final_results.items():
        cleaned = [v for v in vals if v < 2.0]
        for v in cleaned:
            plot_records.append({"Algorithm": algo, "Normalized Edit Distance": v})

    if not plot_records:
        print("❌ No data collected.")
        return

    print("📊 Generating plot...")
    try:
        df = pd.DataFrame(plot_records)
        plt.figure(figsize=(12, 7))
        sns.set_style("whitegrid")
        sns.kdeplot(data=df, x="Normalized Edit Distance", hue="Algorithm", fill=True, common_norm=False, palette="tab10", alpha=0.5)
        plt.title("Compression Stability Analysis")
        plt.xlabel("Normalized Levenshtein Distance")
        plt.xlim(0, 1.2)
        plt.savefig(os.path.join(args.output_dir, "stability_parallel_batch.png"), dpi=300)
    except Exception as e:
        print(f"⚠️ Plotting failed: {e}")

    stats = {k: {"mean": float(np.mean(v)), "count": len(v)} for k, v in final_results.items() if v}
    with open(os.path.join(args.output_dir, "final_stats.json"), 'w') as f:
        json.dump(stats, f, indent=2)

    print("🎉 Done!")


if __name__ == "__main__":
    main()


"""
# There are 8 JSON files; test with a single file first:
python compare_three_compression_lv.py \
    --input_dir /mnt/hdfs/user/linzheng/data/ocpython_subsampled_50G_entropy90_splits_chunk512_ow20_iterative-true_forcepadding-true_merged_ac \
    --m1_model /mnt/bn/tiktok-mm-5/aiic/users/linzheng/artifacts/m1_checkpoints/m1_40M_lr1e-3_steps200k_bs8_seqlen2048_python/checkpoints/0000200000 \
    --first_prob_path /mnt/bn/tiktok-mm-5/aiic/users/linzheng/artifacts/ac_unigram_probs/python500k_unigram_prob.json

# Some modules may be missing; if so, install them, e.g.:
pip install xformers==0.0.23.post1 -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
"""