"""Build a reply-only preference-pair dataset filtered by chosen/reject length gap.

Pipeline:
1. Load each parquet shard listed in ``paths``.
2. In ``chosen`` / ``reject``, strip chat-template special tokens and collapse
   whitespace; blank out ``prompt`` (the output pairs are reply-only).
3. Tokenize both replies and record their token lengths and the absolute
   length difference.
4. Print quantiles of the length difference, save a histogram, and keep only
   pairs whose difference is at or below the 0.95 quantile.
5. Write the result to a single parquet file.
"""
import os
import re

import matplotlib.pyplot as plt
import numpy as np
from datasets import concatenate_datasets, load_dataset
from transformers import AutoTokenizer

paths = [
    "/home/data/pk-2089-L6.parquet",
    "/home/data/pk-1820-L6.parquet",
    "/home/data/pk-2355-L6.parquet",
    "/home/data/pk-4088-L6.parquet",
    "/home/data/pk-3876-L6.parquet",
]

tok = AutoTokenizer.from_pretrained("/home/rm")

# Chat-template markers to strip from the reply texts.
# NOTE(review): the original set also contained several empty-string entries,
# presumably tokens (e.g. "<s>" / "</s>") that were lost when the file was
# pasted. An empty string only adds a zero-width alternative to the regex, so
# they are dropped here -- TODO restore the intended tokens.
special_tokens = {
    "<|im_start|>",
    "<|im_end|>",
    "<|eot_id|>",
    "|eot_id|",
    "<|end_of_text|>",
    "<|system|>",
    "<|user|>",
    "<|assistant|>",
    "<|start_header_id|>",
    "<|end_header_id|>",
    "[INST]",
    "[/INST]",
}

# Regex alternation takes the first alternative that matches, not the longest,
# and set iteration order can change between runs (str hash randomization).
# Sorting longest-first (ties broken alphabetically) yields a deterministic
# pattern that prefers the longest token at any position. Empty strings are
# skipped defensively.
pat = re.compile(
    "|".join(
        re.escape(t)
        for t in sorted(special_tokens, key=lambda t: (-len(t), t))
        if t
    )
)
# Compiled once instead of on every norm() call.
_ws = re.compile(r"\s+")


def clean_text(ex):
    """Normalize one example for reply-only training and return it.

    ``chosen`` and ``reject`` have special tokens removed, runs of whitespace
    collapsed to a single space, and surrounding whitespace stripped; a missing
    or non-string value becomes ``""``. ``prompt`` is always set to ``""``.
    The same mapping is mutated and returned (it is used with ``Dataset.map``).
    """

    def norm(s):
        if not isinstance(s, str):
            return ""
        s = pat.sub("", s.strip())
        return _ws.sub(" ", s).strip()

    ex["chosen"] = norm(ex.get("chosen", ""))
    ex["reject"] = norm(ex.get("reject", ""))
    ex["prompt"] = ""  # reply-only
    return ex


def add_lengths(batch):
    """Return token lengths of ``chosen``/``reject`` and their absolute difference.

    Used with a batched ``Dataset.map``: ``batch`` maps column names to lists.
    ``add_special_tokens=False`` so the counts cover only the reply text.
    """
    c_enc = tok(batch["chosen"], add_special_tokens=False)
    r_enc = tok(batch["reject"], add_special_tokens=False)
    len_c = [len(x) for x in c_enc["input_ids"]]
    len_r = [len(x) for x in r_enc["input_ids"]]
    return {
        "len_c": len_c,
        "len_r": len_r,
        "len_diff": [abs(a - b) for a, b in zip(len_c, len_r)],
    }


needed = ["prompt", "chosen", "reject", "len_c", "len_r", "len_diff"]

sets = []
for p in paths:
    ds = load_dataset("parquet", data_files=p, split="train")
    ds = ds.map(clean_text, num_proc=4)
    ds = ds.map(add_lengths, batched=True, batch_size=1024, num_proc=4)
    # Keep only the reply pair and length columns so all shards share the
    # same set of columns before concatenation.
    drop_cols = [c for c in ds.column_names if c not in needed]
    if drop_cols:
        ds = ds.remove_columns(drop_cols)
    sets.append(ds)

full = concatenate_datasets(sets)

# np.quantile raises an opaque IndexError on an empty array; fail clearly.
if len(full) == 0:
    raise SystemExit("no rows loaded from the input parquet files")

# Quantile statistics of |Δlen|
len_diffs = np.array(full["len_diff"])
for q in [0.50, 0.75, 0.90, 0.95, 0.99]:
    print(f"|Δlen| 分位数 q={q:.2f}: {np.quantile(len_diffs, q)}")

cut = np.quantile(len_diffs, 0.95)
print(f"长度差 0.95 分位数阈值: {cut}")

# ====== Plot and save the histogram ======
# NOTE(review): matplotlib's default sans-serif font typically has no CJK
# glyphs, so the Chinese title/legend text may render as empty boxes unless a
# CJK-capable font is configured -- confirm on the target machine.
plt.figure(figsize=(8, 5))
plt.hist(len_diffs, bins=50, color="skyblue", edgecolor="black")
plt.axvline(cut, color="red", linestyle="--", label=f"0.95分位: {cut}")
plt.title("|Δlen| 长度差分布(chosen vs reject)")
plt.xlabel("Token Length Difference")
plt.ylabel("Frequency")
plt.legend()

os.makedirs("./plots", exist_ok=True)
plot_path = "./plots/len_diff_distribution.png"
plt.savefig(plot_path, dpi=300)
plt.close()
print(f"✅ 已保存长度差分布图: {plot_path}")

# Filter and save the new data: keep pairs whose length gap is at or below the
# 0.95 quantile, then drop the helper length columns.
full = full.filter(lambda x: x["len_diff"] <= cut, num_proc=4)
full = full.remove_columns(["len_c", "len_r", "len_diff"])
out = "/home/data/reply_only_pairs.parquet"
full.to_parquet(out)
print("saved:", out, "rows:", len(full))