| |
| |
|
|
| """ |
| improved_fast_compare.py |
| 1)修改 ac 算法处理,让压缩前后的数据处理逻辑一致 |
| |
| 2) 扰动策略扩展: |
| --perturb_mode: |
| - prefix_delete 删除前缀(默认,等价你原来的 10% 前缀删除) |
| - suffix_delete 删除后缀 |
| - middle_delete 删除中间连续一段 |
| - random_span_delete 随机位置删除连续一段(可复现) |
| - random_char_delete 随机删除若干字符(非连续) |
| |
| --delete_ratio controls the fraction of characters deleted (default 0.2, as set by the argparse default) |
| --random_seed 控制随机扰动可复现(默认 1234;每个 worker 会加上 rank 做偏移) |
| |
| 3) Gzip:尽量固定 mtime=0,避免 gzip header 时间戳引入无关差异 |
| |
| |
| 运行示例: |
| python improved_fast_compare.py \ |
| --input_dir /mnt/hdfs/user/linzheng/data/ocpython_subsampled_50G_entropy90_splits_chunk512_ow20_iterative-true_forcepadding-true_merged_ac \ |
| --m1_model /mnt/bn/tiktok-mm-5/aiic/users/linzheng/artifacts/m1_checkpoints/m1_40M_lr1e-3_steps200k_bs8_seqlen2048_python/checkpoints/0000200000 \ |
| --first_prob_path /mnt/bn/tiktok-mm-5/aiic/users/linzheng/artifacts/ac_unigram_probs/python500k_unigram_prob.json \ |
| --max_lines 10000 \ |
| -o analysis_output_fast_opt \ |
| --max_files 8 \ |
| --perturb_mode random_char_delete \ |
| --delete_ratio 0.1 \ |
| --ac_chunk_size 512 |
| """ |
|
|
| import os |
| import sys |
| import json |
| import gzip |
| import math |
| import base64 |
| import argparse |
| from typing import List, Callable, Tuple, Optional |
| from concurrent.futures import ThreadPoolExecutor |
|
|
| import numpy as np |
| import torch |
| import torch.multiprocessing as mp |
|
|
| import Levenshtein |
|
|
| try: |
| import pandas as pd |
| except Exception: |
| pd = None |
|
|
| try: |
| import matplotlib.pyplot as plt |
| except Exception: |
| plt = None |
|
|
| try: |
| import seaborn as sns |
| except Exception: |
| sns = None |
|
|
|
|
| |
| |
| |
# Default TOKENIZERS_PARALLELISM to "false" unless the caller already set it
# (presumably to keep the tokenizer library quiet in worker processes).
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

# Allow imports of modules located in the working directory or next to this
# script by appending both directories to the module search path (once each).
current_dir = os.getcwd()
script_dir = os.path.dirname(os.path.abspath(__file__))
for _import_root in (current_dir, script_dir):
    if _import_root not in sys.path:
        sys.path.append(_import_root)
|
|
| |
| |
| |
| try: |
| from transformers import AutoTokenizer |
| except ImportError: |
| print("❌ Error: transformers not installed.") |
| sys.exit(1) |
|
|
| try: |
| from m1_compression import utils |
| from m1_compression.compressor import ( |
| load_m1_model_and_tokenizer, |
| ALPHABET_SIZE, |
| ARITHMETIC_CODER_BASE, |
| ARITHMETIC_CODER_PRECISION, |
| ) |
| from m1_compression.hybrid_arithmetic_coder import CPUArithmeticEncoder |
| from m1_compression.batched_arithmetic_coder import _pdf_to_cdf |
| except ImportError as e: |
| print(f"❌ Error: m1_compression not found. {e}") |
| sys.exit(1) |
|
|
|
|
| |
| |
| |
|
|
def token_ids_to_str(ids: List[int]) -> str:
    """Encode a token-id sequence as a string with one character per id.

    Each id is turned into the character with the same code point, so a
    string edit distance computed on the result counts token-level edits.
    Ids above the largest valid code point (0x10FFFF) are clamped to it.
    """
    max_code_point = 0x10FFFF
    clamped = (min(token_id, max_code_point) for token_id in ids)
    return "".join(map(chr, clamped))
|
|
|
|
def bytes_to_latin1_str(b: bytes) -> str:
    """Return a string whose i-th character has the code point of byte i.

    Latin-1 maps every byte value 0..255 to exactly one character, so the
    conversion is lossless and length-preserving.
    """
    return str(b, encoding="latin1")
|
|
|
|
def pad_batch_fast(batch: List[bytes]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Right-pad byte strings with zeros into a single int64 matrix.

    Args:
        batch: Byte strings to pack, one per row.

    Returns:
        ``(padded, lengths)`` where ``padded`` is an int64 tensor of shape
        ``(len(batch), max_len)`` holding the byte values followed by zero
        padding, and ``lengths`` is an int64 tensor with each row's true
        length. An empty batch yields tensors of shape ``(0, 0)`` and ``(0,)``.
    """
    if len(batch) == 0:
        empty_matrix = torch.empty((0, 0), dtype=torch.long)
        empty_lengths = torch.empty((0,), dtype=torch.long)
        return empty_matrix, empty_lengths

    sizes = [len(segment) for segment in batch]
    width = max(sizes)

    # Fill a uint8 buffer row by row; np.frombuffer views the bytes without an
    # intermediate Python list of ints.
    buffer = np.zeros((len(batch), width), dtype=np.uint8)
    for row, segment in zip(buffer, batch):
        if segment:
            row[: len(segment)] = np.frombuffer(segment, dtype=np.uint8)

    padded = torch.from_numpy(buffer).long()
    lengths = torch.tensor(sizes, dtype=torch.long)
    return padded, lengths
|
|
|
|
def iter_jsonl_shard_bytes(file_path: str, shard_rank: int, shard_world: int):
    """Yield the raw lines of a JSONL file that belong to one byte-range shard.

    The file is split into ``shard_world`` contiguous byte ranges. A shard
    owns every line whose first byte lies inside its range, so across all
    shards each line is yielded exactly once. This lets several workers
    process a single large file in parallel.

    Args:
        file_path: Path of the JSONL file (opened in binary mode).
        shard_rank: Index of this shard, ``0 <= shard_rank < shard_world``.
        shard_world: Total number of shards; must be positive.

    Yields:
        Each owned line as ``bytes``, including its trailing newline if any.

    Raises:
        ValueError: If ``shard_world`` is not positive.
    """
    if shard_world <= 0:
        raise ValueError(f"shard_world must be positive, got {shard_world}")

    file_size = os.path.getsize(file_path)
    start = (file_size * shard_rank) // shard_world
    end = (file_size * (shard_rank + 1)) // shard_world

    with open(file_path, "rb") as f:
        if start > 0:
            # Step back one byte before skipping to the next line start. If
            # the byte at start-1 is a newline, readline() consumes only that
            # newline and we land exactly on `start`, keeping the line that
            # begins there. Seeking to `start` directly would discard that
            # line, and the previous shard (which stops at tell() == end)
            # never reads it either.
            f.seek(start - 1)
            f.readline()

        while f.tell() < end:
            line = f.readline()
            if not line:
                break
            yield line
|
|
|
|
| |
| |
| |
|
|
def perturb_text(text: str, mode: str, delete_ratio: float, rng: np.random.Generator) -> str:
    """Return ``text`` with a fraction of its characters deleted.

    The number of deleted characters is ``floor(len(text) * delete_ratio)``,
    clamped so that at least one character is deleted and at least one is
    kept.

    Args:
        text: Input text. Non-strings and strings of length <= 1 are returned
            unchanged.
        mode: Deletion strategy:
            ``prefix_delete`` (drop the head), ``suffix_delete`` (drop the
            tail), ``middle_delete`` (drop a centred span),
            ``random_span_delete`` (drop one contiguous span at a random
            position) or ``random_char_delete`` (drop randomly chosen,
            non-contiguous characters). Any other value falls back to
            ``prefix_delete``.
        delete_ratio: Fraction to delete. Values <= 0 or NaN leave the text
            unchanged; values >= 1 (including infinity) delete all but one
            character.
        rng: Random generator used by the two random modes.

    Returns:
        The perturbed text.
    """
    if not isinstance(text, str) or len(text) <= 1:
        return text

    n = len(text)
    r = float(delete_ratio)
    # `not r > 0` is also true for NaN, which would otherwise make
    # math.floor() raise.
    if not r > 0:
        return text
    # Clamp before multiplying so infinite / huge ratios cannot overflow in
    # math.floor(); any ratio >= 1 ends up deleting n - 1 characters anyway.
    r = min(r, 1.0)

    k = int(math.floor(n * r))
    k = min(max(1, k), n - 1)

    if mode == "suffix_delete":
        return text[: n - k]

    if mode == "middle_delete":
        start = (n - k) // 2
        return text[:start] + text[start + k :]

    if mode == "random_span_delete":
        # Upper bound is exclusive, so the span always fits inside the text.
        start = int(rng.integers(0, n - k + 1))
        return text[:start] + text[start + k :]

    if mode == "random_char_delete":
        dropped = rng.choice(np.arange(n), size=k, replace=False)
        keep = np.ones(n, dtype=bool)
        keep[dropped] = False
        return "".join(ch for ch, kept in zip(text, keep) if kept)

    # "prefix_delete" and, as a best-effort fallback, any unrecognised mode.
    # The fallback honours delete_ratio instead of a hard-coded fraction so
    # the requested perturbation strength is never silently ignored.
    return text[k:]
|
|
|
|
| |
| |
| |
|
|
def batched_m1_compress_predict_fn(model):
    """Wrap ``model`` in a function that returns next-byte probabilities.

    The returned callable accepts a 1-D or batched tensor (a 1-D input gets a
    leading batch dimension), runs the model under ``torch.inference_mode``,
    keeps the first 256 entries of the last logits dimension, and applies a
    float32 softmax over them.
    """

    def predict_fn(input_tensor: torch.Tensor, **kwargs) -> torch.Tensor:
        batch = input_tensor.unsqueeze(0) if input_tensor.dim() == 1 else input_tensor
        with torch.inference_mode():
            byte_logits = model(batch, **kwargs)[..., :256].float()
            return byte_logits.softmax(dim=-1)

    return predict_fn
|
|
|
|
def compress_segments_smart_batch_bytes(
    all_segments: List[bytes],
    batched_predict_fn: Callable,
    first_byte_prob: torch.Tensor,
    device: torch.device,
    encoder: CPUArithmeticEncoder,
    gpu_batch_size: int = 256,
    bit_threshold: int = 64,
) -> List[bytes]:
    """Arithmetic-code every segment and return one ``bytes`` per segment.

    Segments are sorted by length (stable) so each mini-batch wastes little
    padding; model inference runs on ``device`` and the arithmetic encoder
    runs on CPU. Results are returned in the original segment order.

    If anything fails for a mini-batch, a warning is printed and the raw,
    uncompressed bytes of that mini-batch's segments are returned in their
    place (best-effort behaviour).

    Args:
        all_segments: Byte segments to compress.
        batched_predict_fn: Callable mapping a padded int64 batch to
            per-position probabilities over byte values.
        first_byte_prob: Prior used for the first position of every segment;
            expanded along the batch dimension (assumed broadcastable to
            ``(batch, 1, alphabet)`` — confirm against the caller).
        device: Device on which inference runs.
        encoder: Arithmetic encoder exposing ``incremental_batched_encode``.
        gpu_batch_size: Number of segments per inference mini-batch.
        bit_threshold: Forwarded to the encoder.

    Returns:
        A list with one ``bytes`` object per input segment.

    Raises:
        ValueError: If ``gpu_batch_size`` is not positive.
    """
    M = len(all_segments)
    if M == 0:
        return []
    if gpu_batch_size <= 0:
        # A negative step would make the batch loop empty and silently return
        # empty outputs for every segment.
        raise ValueError(f"gpu_batch_size must be positive, got {gpu_batch_size}")

    lengths = np.fromiter((len(s) for s in all_segments), dtype=np.int32, count=M)
    sorted_indices = np.argsort(lengths, kind="stable")
    sorted_segments = [all_segments[i] for i in sorted_indices]

    out: List[Optional[bytes]] = [None] * M

    for i in range(0, M, gpu_batch_size):
        batch_slice = sorted_segments[i : i + gpu_batch_size]
        batch_orig_indices = sorted_indices[i : i + gpu_batch_size]

        try:
            padded_batch_cpu, lengths_cpu = pad_batch_fast(batch_slice)

            # Pinned host memory allows the non-blocking host-to-device copy.
            padded_batch = padded_batch_cpu.pin_memory().to(device, non_blocking=True)

            prompt_probs = batched_predict_fn(padded_batch)

            # Position t is coded with the distribution predicted from the
            # bytes before it: shift the model output right by one and put
            # the prior in front.
            final_probs = torch.cat(
                [
                    first_byte_prob.expand(prompt_probs.shape[0], -1, -1),
                    prompt_probs[:, :-1, ...],
                ],
                dim=1,
            )

            final_probs = utils.batched_normalize_pdf_for_arithmetic_coding(final_probs)
            cdfs_gpu = _pdf_to_cdf(final_probs)

            # Only the CDF interval [cdf[sym], cdf[sym + 1]] of the actual
            # byte at each position is handed to the encoder.
            cdf_low = cdfs_gpu.gather(2, padded_batch.unsqueeze(-1)).squeeze(-1)
            cdf_high = cdfs_gpu.gather(2, (padded_batch + 1).unsqueeze(-1)).squeeze(-1)
            cdf_ends = torch.stack([cdf_low, cdf_high], dim=-1)

            enc_out = encoder.incremental_batched_encode(
                cdf_ends.cpu(),
                ALPHABET_SIZE,
                lengths_cpu,
                bit_threshold=bit_threshold,
                force_padding_to_threshold=False,
                return_num_padded_bits=False,
            )

            # The encoder may return either the codes or a tuple whose first
            # element is the codes.
            if isinstance(enc_out, tuple):
                chunked_compressed_bytes = enc_out[0]
            else:
                chunked_compressed_bytes = enc_out

            for idx, code in zip(batch_orig_indices, chunked_compressed_bytes):
                out[int(idx)] = bytes(code)

        except Exception as e:
            # Best-effort fallback: keep going with the raw bytes, but say so,
            # because raw bytes in place of AC output distort downstream
            # distance statistics.
            print(
                f"⚠️ AC compression failed for {len(batch_slice)} segments "
                f"(batch offset {i}); using raw bytes instead: {e}"
            )
            for idx, seg in zip(batch_orig_indices, batch_slice):
                out[int(idx)] = seg

    return [x if x is not None else b"" for x in out]
|
|
|
|
class M1ACManager:
    """Arithmetic-coding compressor shared by original and perturbed texts.

    A batch of texts is UTF-8 encoded, split into fixed-size byte chunks with
    the same rule for every text, each chunk is compressed independently, and
    the compressed chunks of a text are concatenated into one byte stream.
    """

    def __init__(
        self,
        model_path: str,
        first_prob_path: str,
        device_id: int,
        gpu_batch_size: int = 256,
        bit_threshold: int = 64,
        chunk_size: int = 512,
    ):
        """Load the model onto ``cuda:{device_id}`` and prepare the encoder.

        Args:
            model_path: Checkpoint location passed to
                ``load_m1_model_and_tokenizer``.
            first_prob_path: JSON file holding the first-byte prior. When the
                path is empty or missing a uniform prior is used instead.
            device_id: CUDA device index.
            gpu_batch_size: Segments per inference mini-batch.
            bit_threshold: Forwarded to the arithmetic encoder.
            chunk_size: Chunk length in bytes; values <= 0 disable chunking.
        """
        self.device = torch.device(f"cuda:{device_id}")
        self.gpu_batch_size = gpu_batch_size
        self.bit_threshold = bit_threshold
        self.chunk_size = int(chunk_size)

        # The loader returns a 3-tuple; only its first element is used here.
        self.model, _, _ = load_m1_model_and_tokenizer(model_path)
        self.model.to(self.device)
        self.model.eval()
        self.predict_fn = batched_m1_compress_predict_fn(self.model)

        have_prior_file = bool(first_prob_path) and os.path.exists(first_prob_path)
        if have_prior_file:
            with open(first_prob_path, "r") as f:
                prob_data = json.load(f)
            prior = torch.tensor(prob_data, dtype=torch.float32, device=self.device)
            if prior.dim() == 1:
                # Add leading batch and position dimensions.
                prior = prior.unsqueeze(0).unsqueeze(0)
        else:
            prior = torch.ones((1, 1, ALPHABET_SIZE), dtype=torch.float32, device=self.device) / ALPHABET_SIZE
        self.first_byte_prob = prior

        self.encoder = CPUArithmeticEncoder(base=ARITHMETIC_CODER_BASE, precision=ARITHMETIC_CODER_PRECISION)

    def _segment_bytes(self, raw_bytes: bytes) -> List[bytes]:
        """Split ``raw_bytes`` into consecutive chunks of ``chunk_size`` bytes.

        Empty input yields a single empty chunk. Input no longer than one
        chunk, or a non-positive chunk size, yields the input as one chunk.
        """
        if not raw_bytes:
            return [b""]

        step = self.chunk_size
        total = len(raw_bytes)
        if step > 0 and total > step:
            return [raw_bytes[offset : offset + step] for offset in range(0, total, step)]
        return [raw_bytes]

    def compress_batch_smart_bytes(self, texts: List[str]) -> List[bytes]:
        """Compress each text and return one concatenated byte stream per text.

        Args:
            texts: Input texts; falsy entries are treated as the empty string.

        Returns:
            A list parallel to ``texts`` where each element is the
            concatenation of that text's compressed chunks.
        """
        if not texts:
            return []

        # Chunk every text first so all chunks can go through the compressor
        # in a single flattened call.
        chunks_per_text: List[List[bytes]] = [
            self._segment_bytes((text or "").encode("utf-8", errors="ignore"))
            for text in texts
        ]
        flat_chunks: List[bytes] = [chunk for chunks in chunks_per_text for chunk in chunks]

        if not flat_chunks:
            return [b"" for _ in texts]

        compressed_flat = compress_segments_smart_batch_bytes(
            flat_chunks,
            self.predict_fn,
            self.first_byte_prob,
            self.device,
            self.encoder,
            gpu_batch_size=self.gpu_batch_size,
            bit_threshold=self.bit_threshold,
        )

        # Walk the flat output with a cursor to regroup chunks per text.
        streams: List[bytes] = []
        cursor = 0
        for chunks in chunks_per_text:
            upto = cursor + len(chunks)
            streams.append(b"".join(compressed_flat[cursor:upto]))
            cursor = upto
        return streams
|
|
|
|
| |
| |
| |
|
|
def gzip_compress_stable(data: bytes) -> bytes:
    """Gzip-compress ``data`` with the header timestamp fixed to zero.

    A zero mtime makes the output a pure function of the input, so header
    timestamps cannot add noise when comparing compressed streams.

    ``gzip.compress`` only accepts ``mtime`` on newer Python versions; when it
    rejects the argument, the same result is produced through
    ``gzip.GzipFile``, which accepts ``mtime`` as well, so the output stays
    timestamp-free on every version.
    """
    try:
        return gzip.compress(data, mtime=0)
    except TypeError:
        # Older interpreter: gzip.compress() has no mtime parameter. Falling
        # back to plain gzip.compress(data) would embed the current time and
        # defeat the purpose of this function.
        import io

        buffer = io.BytesIO()
        with gzip.GzipFile(fileobj=buffer, mode="wb", mtime=0) as gz:
            gz.write(data)
        return buffer.getvalue()
|
|
|
|
def run_gzip_task(text_pair: Tuple[str, str]) -> float:
    """Return the normalised edit distance between two gzip streams.

    Both texts of the pair are UTF-8 encoded (falsy values count as empty)
    and gzip-compressed; the Levenshtein distance between the two compressed
    streams is divided by the length of the first stream.
    """
    original_text, perturbed_text = text_pair
    gz_original, gz_perturbed = (
        gzip_compress_stable((text or "").encode("utf-8", errors="ignore"))
        for text in (original_text, perturbed_text)
    )
    if len(gz_original) == 0:
        return 0.0

    # Latin-1 strings give the distance function one character per byte.
    edit_distance = Levenshtein.distance(
        bytes_to_latin1_str(gz_original),
        bytes_to_latin1_str(gz_perturbed),
    )
    return edit_distance / len(gz_original)
|
|
|
|
def process_one_file(
    gpu_id: int,
    file_path: str,
    tokenizer: AutoTokenizer,
    ac_manager: M1ACManager,
    max_lines: int,
    worker_batch_size: int,
    gzip_threads: int,
    shard_rank: int,
    shard_world: int,
    perturb_mode: str,
    delete_ratio: float,
    rng_seed: int,
) -> dict:
    """Measure compressed-stream stability for one JSONL file (or one shard).

    For every usable record the text and a perturbed copy are compared under
    three compressors (gzip, tokenizer ids, neural arithmetic coding) using a
    Levenshtein distance normalised by the length of the original stream.

    Args:
        gpu_id: Only used in log messages.
        file_path: JSONL file whose records carry a ``"text"`` field.
        tokenizer: Callable tokenizer returning ``{"input_ids": ...}``.
        ac_manager: Provides ``compress_batch_smart_bytes``.
        max_lines: Upper bound on lines read from the file; with several
            shards each shard reads ``ceil(max_lines / shard_world)`` lines.
            Values <= 0 mean no limit.
        worker_batch_size: Number of buffered records that triggers a flush.
        gzip_threads: Thread-pool size used for the gzip comparison.
        shard_rank: Index of the byte-range shard to read.
        shard_world: Number of byte-range shards.
        perturb_mode: Deletion strategy passed to ``perturb_text``.
        delete_ratio: Deletion ratio passed to ``perturb_text``.
        rng_seed: Seed for the perturbation random generator.

    Returns:
        ``{"Gzip": [...], "Tokenizer": [...], "Neural": [...]}`` with one
        normalised distance per successfully compared record.
    """
    results = {"Gzip": [], "Tokenizer": [], "Neural": []}

    # The line budget is split evenly across shards.
    if shard_world > 1 and max_lines > 0:
        shard_max_lines = int(math.ceil(max_lines / shard_world))
    else:
        shard_max_lines = max_lines

    raw_texts: List[str] = []
    pert_texts: List[str] = []
    processed_total = 0

    rng = np.random.default_rng(int(rng_seed))

    # The context manager guarantees the pool is shut down even if an
    # unexpected error escapes the loop.
    with ThreadPoolExecutor(max_workers=gzip_threads) as thread_pool:

        def flush():
            """Compare the buffered pairs under all three compressors."""
            nonlocal raw_texts, pert_texts
            if not raw_texts:
                return

            # Take ownership of the buffers and reset them first, so a failure
            # in any stage can never cause the same batch to be re-processed.
            batch_raw, batch_pert = raw_texts, pert_texts
            raw_texts, pert_texts = [], []
            curr_batch_size = len(batch_raw)

            # Gzip stage (thread pool).
            try:
                gz_vals = list(thread_pool.map(run_gzip_task, zip(batch_raw, batch_pert)))
                results["Gzip"].extend(gz_vals)
            except Exception as e:
                print(f"[GPU {gpu_id}] Gzip Batch Error: {e}")

            # Tokenizer stage.
            try:
                tok1 = tokenizer(batch_raw, add_special_tokens=False)["input_ids"]
                tok2 = tokenizer(batch_pert, add_special_tokens=False)["input_ids"]
                for a, b in zip(tok1, tok2):
                    if a:
                        d = Levenshtein.distance(token_ids_to_str(a), token_ids_to_str(b))
                        results["Tokenizer"].append(d / len(a))
            except Exception as e:
                print(f"[GPU {gpu_id}] Tokenizer Batch Error: {e}")

            # Neural AC stage: originals and perturbed texts go through the
            # compressor in one call and are split apart afterwards.
            try:
                both_streams = ac_manager.compress_batch_smart_bytes(batch_raw + batch_pert)
                ac1_list = both_streams[:curr_batch_size]
                ac2_list = both_streams[curr_batch_size:]
                for a1, a2 in zip(ac1_list, ac2_list):
                    if a1:
                        d = Levenshtein.distance(bytes_to_latin1_str(a1), bytes_to_latin1_str(a2))
                        results["Neural"].append(d / len(a1))
            except Exception as e:
                print(f"[GPU {gpu_id}] AC Batch Error: {e}")

        line_iter = iter_jsonl_shard_bytes(file_path, shard_rank, shard_world)

        for i, line in enumerate(line_iter):
            if shard_max_lines > 0 and i >= shard_max_lines:
                break
            if not line or len(line) < 100:
                continue

            # Best-effort parsing: malformed or unusable records are skipped.
            # The guard covers only per-record work, not flush().
            try:
                data = json.loads(line)
                text = data.get("text", "")
                if not isinstance(text, str) or len(text) < 50:
                    continue
                text_p = perturb_text(text, perturb_mode, delete_ratio, rng)
            except Exception:
                continue

            raw_texts.append(text)
            pert_texts.append(text_p)
            processed_total += 1

            if len(raw_texts) >= worker_batch_size:
                flush()

            if processed_total % 2000 == 0:
                print(
                    f"[GPU {gpu_id}] processed {processed_total} lines "
                    f"(file={os.path.basename(file_path)}, shard={shard_rank}/{shard_world}, perturb={perturb_mode}, ratio={delete_ratio})"
                )

        # Handle the final partial batch.
        flush()

    print(f"[GPU {gpu_id}] done file={os.path.basename(file_path)} shard={shard_rank}/{shard_world} total={processed_total}")
    return results
|
|
|
|
def process_files_worker(
    rank: int,
    gpu_id: int,
    file_paths: List[str],
    output_dir: str,
    model_path: str,
    prob_path: str,
    max_lines: int,
    worker_batch_size: int,
    gzip_threads: int,
    shard_mode: bool,
    gpu_batch_size: int,
    bit_threshold: int,
    ac_chunk_size: int,
    perturb_mode: str,
    delete_ratio: float,
    random_seed: int,
):
    """Entry point of one GPU worker process.

    Loads the tokenizer and the AC model once, then processes the assigned
    files sequentially (or, in shard mode, this worker's byte-range shard of
    the single file) and writes one JSON result file per input file into
    ``output_dir``.

    Any exception is printed with its traceback and ends this worker; it is
    not re-raised.

    Args:
        rank: Worker index; also the shard index in shard mode and part of
            the random-seed offset.
        gpu_id: CUDA device used by this worker.
        file_paths: Files assigned to this worker.
        output_dir: Directory receiving the ``res_*.json`` files.
        model_path: Model checkpoint path for ``M1ACManager``.
        prob_path: First-byte prior path for ``M1ACManager``.
        max_lines: Line budget forwarded to ``process_one_file``.
        worker_batch_size: Flush batch size forwarded to ``process_one_file``.
        gzip_threads: Gzip thread count forwarded to ``process_one_file``.
        shard_mode: When true, the file is split into one shard per visible
            CUDA device and this worker handles shard ``rank``.
        gpu_batch_size: AC inference mini-batch size.
        bit_threshold: Arithmetic-coder bit threshold.
        ac_chunk_size: AC chunk size in bytes.
        perturb_mode: Perturbation strategy.
        delete_ratio: Perturbation deletion ratio.
        random_seed: Base seed for the perturbation generator.
    """
    import traceback
    import zlib

    try:
        torch.cuda.set_device(gpu_id)

        try:
            tokenizer = AutoTokenizer.from_pretrained(
                "infly/OpenCoder-1.5B-Base",
                trust_remote_code=True,
                use_fast=True,
            )
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
        except Exception as e:
            # Make the substitution visible: results from a different
            # tokenizer are not comparable with the intended one.
            print(f"⚠️ [GPU {gpu_id}] Could not load infly/OpenCoder-1.5B-Base tokenizer ({e}); falling back to gpt2")
            tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)

        ac_manager = M1ACManager(
            model_path=model_path,
            first_prob_path=prob_path,
            device_id=gpu_id,
            gpu_batch_size=gpu_batch_size,
            bit_threshold=bit_threshold,
            chunk_size=ac_chunk_size,
        )

        # Offset the seed per worker so workers do not share a random stream.
        base_seed = int(random_seed) + int(rank) * 1000003

        for fp in file_paths:
            base = os.path.basename(fp)

            if shard_mode:
                shard_rank = rank
                shard_world = torch.cuda.device_count()
            else:
                shard_rank = 0
                shard_world = 1

            print(
                f"[GPU {gpu_id}] start file={base} shard={shard_rank}/{shard_world} "
                f"perturb={perturb_mode} ratio={delete_ratio} ac_chunk={ac_chunk_size}"
            )

            # Per-file seed offset. The built-in hash() of a str is randomised
            # per interpreter process, which would make the "reproducible"
            # perturbation differ between runs; CRC32 of the file name is
            # stable across processes and runs.
            file_offset = zlib.crc32(base.encode("utf-8")) % 100000

            res = process_one_file(
                gpu_id=gpu_id,
                file_path=fp,
                tokenizer=tokenizer,
                ac_manager=ac_manager,
                max_lines=max_lines,
                worker_batch_size=worker_batch_size,
                gzip_threads=gzip_threads,
                shard_rank=shard_rank,
                shard_world=shard_world,
                perturb_mode=perturb_mode,
                delete_ratio=delete_ratio,
                rng_seed=base_seed + file_offset,
            )

            out_name = f"res_gpu{gpu_id}_rank{rank}_shard{shard_rank}of{shard_world}_{base}.json"
            out_path = os.path.join(output_dir, out_name)
            with open(out_path, "w") as f:
                json.dump(res, f)

    except Exception as e:
        print(f"❌ [GPU {gpu_id}] Worker Error: {e}")
        traceback.print_exc()
|
|
|
|
| |
| |
| |
|
|
def main():
    """Run the benchmark: parse options, launch GPU workers, merge, report.

    Steps:
      1. Collect the ``*.jsonl`` inputs (file names containing "writer" are
         skipped) and assign them round-robin to one worker per GPU; a single
         input file is instead split into one byte-range shard per GPU.
      2. Start one spawned process per non-empty assignment and wait for all.
      3. Merge the result files written by this run, save summary statistics
         to ``final_stats.json`` and, when matplotlib is available, a
         distribution plot.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", type=str, required=True)
    parser.add_argument("--m1_model", type=str, required=True)
    parser.add_argument("--first_prob_path", type=str, required=True)
    parser.add_argument("-o", "--output_dir", type=str, default="analysis_output_fast_opt_v2")
    parser.add_argument("--max_lines", type=int, default=10000)

    parser.add_argument("--max_files", type=int, default=8, help="只取前 N 个 jsonl 文件;0 表示不限制")
    parser.add_argument("--worker_batch_size", type=int, default=500, help="flush 的行数 batch")
    parser.add_argument("--gzip_threads", type=int, default=8, help="每个 GPU 进程内用于 gzip 的线程数")
    parser.add_argument("--ac_gpu_batch_size", type=int, default=256, help="AC 推理的 GPU mini-batch size")
    parser.add_argument("--ac_bit_threshold", type=int, default=64, help="Arithmetic coder bit_threshold(16->64/128 往往更快)")
    parser.add_argument("--ac_chunk_size", type=int, default=512, help="AC 输入字节分块大小(orig/pert 统一策略)")

    parser.add_argument(
        "--perturb_mode",
        type=str,
        default="prefix_delete",
        choices=[
            "prefix_delete",
            "suffix_delete",
            "middle_delete",
            "random_span_delete",
            "random_char_delete",
        ],
        help="扰动(删除)策略",
    )
    parser.add_argument("--delete_ratio", type=float, default=0.2, help="删除比例(0~1 推荐)")
    parser.add_argument("--random_seed", type=int, default=1234, help="随机扰动种子(可复现)")

    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    files = sorted(
        os.path.join(args.input_dir, f)
        for f in os.listdir(args.input_dir)
        if f.endswith(".jsonl") and "writer" not in f
    )

    if args.max_files and args.max_files > 0:
        files = files[: args.max_files]

    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        print("❌ No GPU detected.")
        return
    if not files:
        print("❌ No jsonl files found.")
        return

    # A single input file is split across all GPUs instead of occupying one.
    shard_mode = (len(files) == 1 and num_gpus > 1)

    assignments: List[List[str]] = [[] for _ in range(num_gpus)]
    if shard_mode:
        for r in range(num_gpus):
            assignments[r] = [files[0]]
        print(f"🚀 Single-file shard mode enabled: {files[0]} -> {num_gpus} shards")
    else:
        for idx, fp in enumerate(files):
            assignments[idx % num_gpus].append(fp)
        non_empty = sum(1 for a in assignments if a)
        print(f"🚀 Multi-file mode: {len(files)} files assigned across {non_empty}/{num_gpus} GPU workers")

    print(
        " worker_batch_size={}, gzip_threads={}, ac_gpu_batch_size={}, ac_bit_threshold={}, ac_chunk_size={}, "
        "perturb_mode={}, delete_ratio={}, random_seed={}".format(
            args.worker_batch_size,
            args.gzip_threads,
            args.ac_gpu_batch_size,
            args.ac_bit_threshold,
            args.ac_chunk_size,
            args.perturb_mode,
            args.delete_ratio,
            args.random_seed,
        )
    )

    # Result files this run is expected to produce, following the naming used
    # by process_files_worker. Merging only these keeps res_*.json files left
    # over from earlier runs in the same directory out of the statistics.
    expected_result_paths: List[str] = []
    for rank, assigned in enumerate(assignments):
        gpu_id = rank % num_gpus
        shard_rank, shard_world = (rank, num_gpus) if shard_mode else (0, 1)
        for fp in assigned:
            out_name = f"res_gpu{gpu_id}_rank{rank}_shard{shard_rank}of{shard_world}_{os.path.basename(fp)}.json"
            expected_result_paths.append(os.path.join(args.output_dir, out_name))

    # Remove previous versions of exactly those files (they would be
    # overwritten anyway) so a worker that fails cannot leave stale data
    # behind under the same name.
    for path in expected_result_paths:
        if os.path.exists(path):
            os.remove(path)

    mp.set_start_method("spawn", force=True)
    procs = []
    for rank in range(num_gpus):
        if not assignments[rank]:
            continue
        gpu_id = rank % num_gpus
        p = mp.Process(
            target=process_files_worker,
            args=(
                rank,
                gpu_id,
                assignments[rank],
                args.output_dir,
                args.m1_model,
                args.first_prob_path,
                args.max_lines,
                args.worker_batch_size,
                args.gzip_threads,
                shard_mode,
                args.ac_gpu_batch_size,
                args.ac_bit_threshold,
                args.ac_chunk_size,
                args.perturb_mode,
                args.delete_ratio,
                args.random_seed,
            ),
        )
        p.start()
        procs.append(p)

    for p in procs:
        p.join()
        if p.exitcode != 0:
            print(f"⚠️ Worker process {p.pid} exited with code {p.exitcode}")

    print("✅ Merging results...")
    final_results = {"Gzip": [], "Tokenizer": [], "Neural": []}
    for path in expected_result_paths:
        try:
            with open(path, "r") as f:
                d = json.load(f)
        except (OSError, ValueError) as e:
            # Missing or unreadable file: the worker most likely failed.
            print(f"⚠️ Skipping result file {path}: {e}")
            continue
        if not isinstance(d, dict):
            print(f"⚠️ Skipping result file {path}: unexpected content")
            continue
        for k in final_results:
            final_results[k].extend(d.get(k, []))

    for k, v in final_results.items():
        print(f" {k}: {len(v)} samples")

    stats = {}
    for k, v in final_results.items():
        if v:
            stats[k] = {
                "count": int(len(v)),
                "mean": float(np.mean(v)),
                "p50": float(np.median(v)),
                "p90": float(np.quantile(v, 0.9)),
            }

    stats_path = os.path.join(args.output_dir, "final_stats.json")
    with open(stats_path, "w") as f:
        json.dump(stats, f, indent=2, ensure_ascii=False)
    print(f"📄 Saved stats -> {stats_path}")

    if all(k in stats for k in ["Tokenizer", "Neural", "Gzip"]):
        m_tok = stats["Tokenizer"]["mean"]
        m_ac = stats["Neural"]["mean"]
        m_gz = stats["Gzip"]["mean"]
        print(f"🔎 mean NED: Tokenizer={m_tok:.4f}, Neural={m_ac:.4f}, Gzip={m_gz:.4f}")
        if (m_tok < m_ac) and (m_ac < m_gz):
            print("✅ Observed ordering matches expectation: tokenizer < ac < gzip")
        else:
            print("⚠️ Ordering NOT matched this run. 可能需要调整 delete_ratio / perturb_mode / 数据集 / 分词器 / 模型。")

    # Only values in [0, 2) are plotted; anything outside is treated as an
    # outlier for visualisation purposes.
    plot_data = []
    for algo, vals in final_results.items():
        for val in vals:
            if 0 <= val < 2.0:
                plot_data.append({"Proxy compressor": algo, "NED": val})

    if plot_data and plt is not None:
        out_img = os.path.join(args.output_dir, "stability_fast_opt_v2.png")
        try:
            if (pd is not None) and (sns is not None):
                df = pd.DataFrame(plot_data)
                plt.figure(figsize=(10, 6))
                sns.kdeplot(data=df, x="NED", hue="Proxy compressor", fill=True, common_norm=False)
                plt.xlabel("Normalized Levenshtein Distance")
                plt.xlim(0, 1.5)
                plt.savefig(out_img, dpi=200)
                print(f"📊 Saved plot -> {out_img}")
            else:
                # Histogram fallback when pandas / seaborn are unavailable.
                plt.figure(figsize=(10, 6))
                for algo in ["Tokenizer", "Neural", "Gzip"]:
                    # The records are keyed by "Proxy compressor"; looking up
                    # a different key here raised KeyError and always aborted
                    # the fallback plot.
                    vals = [x["NED"] for x in plot_data if x["Proxy compressor"] == algo]
                    if vals:
                        plt.hist(vals, bins=80, density=True, alpha=0.4, label=algo)
                plt.xlabel("Normalized Levenshtein Distance")
                plt.xlim(0, 1.5)
                plt.legend()
                plt.savefig(out_img, dpi=200)
                print(f"📊 Saved plot -> {out_img}")
        except Exception as e:
            print(f"⚠️ Plot failed: {e}")

    print("🎉 Done.")


if __name__ == "__main__":
    main()
|
|