import matplotlib.pyplot as plt
from datasets import load_dataset, concatenate_datasets
from transformers import AutoTokenizer
import re
import numpy as np
import os
# Input parquet shards, listed in the order they are loaded and concatenated.
paths = [
    f"/home/data/pk-{shard_id}-L6.parquet"
    for shard_id in (2089, 1820, 2355, 4088, 3876)
]
# Tokenizer loaded from "/home/rm"; add_lengths() uses it to count tokens in
# the chosen/reject replies.
# NOTE(review): "rm" presumably stands for the reward model checkpoint — confirm.
tok = AutoTokenizer.from_pretrained("/home/rm")
# Chat-template / control markers that are stripped from reply text.
# NOTE(review): the original literal also listed several empty-string entries,
# presumably markers whose text was lost when this file passed through some
# markup stripping — confirm and restore the real tokens if so.
special_tokens = {
    "<|im_start|>", "<|im_end|>",
    "<|eot_id|>", "|eot_id|", "<|end_of_text|>",
    "<|system|>", "<|user|>", "<|assistant|>",
    "<|start_header_id|>", "<|end_header_id|>",
    "[INST]", "[/INST]",
}
# Regex alternation tries branches left to right, and set iteration order for
# strings varies between interpreter runs. An empty branch placed before a
# real marker would match first and leave that marker in the text, so empty
# entries are filtered out and the branches are ordered deterministically,
# longest first (so "<|eot_id|>" is preferred over its substring "|eot_id|").
pat = re.compile(
    "|".join(
        re.escape(token)
        for token in sorted(
            filter(None, special_tokens),
            key=lambda token: (-len(token), token),
        )
    )
)
def clean_text(ex):
    """Normalise the reply fields of one example and blank out its prompt.

    ``chosen`` and ``reject`` have the markers matched by the module-level
    ``pat`` removed and every whitespace run collapsed to a single space; a
    missing or non-string value becomes ``""``. ``prompt`` is always set to
    ``""``. The example is modified in place and returned.
    """
    for field in ("chosen", "reject"):
        value = ex.get(field, "")
        if isinstance(value, str):
            without_markers = pat.sub("", value.strip())
            # Removing markers can leave doubled or edge whitespace behind.
            value = re.sub(r"\s+", " ", without_markers).strip()
        else:
            value = ""
        ex[field] = value
    ex["prompt"] = ""  # reply-only
    return ex
def add_lengths(batch):
    """Compute token counts for a batch of chosen/reject replies.

    Returns a dict of three parallel lists: ``len_c`` and ``len_r`` hold the
    number of input ids the module-level ``tok`` produces for each chosen and
    reject text (with ``add_special_tokens=False``), and ``len_diff`` holds
    their absolute difference.
    """
    def token_counts(texts):
        encoded = tok(texts, add_special_tokens=False)
        return [len(ids) for ids in encoded["input_ids"]]

    chosen_lens = token_counts(batch["chosen"])
    reject_lens = token_counts(batch["reject"])
    gaps = [abs(c - r) for c, r in zip(chosen_lens, reject_lens)]
    return {"len_c": chosen_lens, "len_r": reject_lens, "len_diff": gaps}
# Columns kept in the merged dataset; any other column in a shard is dropped.
needed = ["prompt", "chosen", "reject", "len_c", "len_r", "len_diff"]
sets = []
for p in paths:
    # Load one shard, normalise its text, then attach the token-length columns.
    ds = (
        load_dataset("parquet", data_files=p, split="train")
        .map(clean_text, num_proc=4)
        .map(add_lengths, batched=True, batch_size=1024, num_proc=4)
    )
    extra_cols = [name for name in ds.column_names if name not in needed]
    ds = ds.remove_columns(extra_cols) if extra_cols else ds
    sets.append(ds)
# Merge every processed shard into a single dataset.
full = concatenate_datasets(sets)
# Quantile summary of the absolute chosen/reject token-length difference.
len_diffs = np.array(full["len_diff"])
report_qs = [0.50, 0.75, 0.90, 0.95, 0.99]
for q, value in zip(report_qs, np.quantile(len_diffs, report_qs)):
    print(f"|Δlen| 分位数 q={q:.2f}: {value}")
# The 0.95 quantile is the cut-off used by the plot and the filter below.
cut = np.quantile(len_diffs, 0.95)
print(f"长度差 0.95 分位数阈值: {cut}")
# ====== Draw and save the histogram ======
fig, ax = plt.subplots(figsize=(8, 5))
ax.hist(len_diffs, bins=50, color="skyblue", edgecolor="black")
# Dashed vertical line marking the 0.95-quantile cut-off.
ax.axvline(cut, color="red", linestyle="--", label=f"0.95分位: {cut}")
ax.set_title("|Δlen| 长度差分布(chosen vs reject)")
ax.set_xlabel("Token Length Difference")
ax.set_ylabel("Frequency")
ax.legend()
plot_path = "./plots/len_diff_distribution.png"
os.makedirs("./plots", exist_ok=True)
fig.savefig(plot_path, dpi=300)
plt.close(fig)
print(f"✅ 已保存长度差分布图: {plot_path}")
# Keep only pairs whose length difference is within the cut-off, drop the
# helper length columns, and write the result to parquet.
out = "/home/data/reply_only_pairs.parquet"
full = full.filter(
    lambda row: row["len_diff"] <= cut,
    num_proc=4,
).remove_columns(["len_c", "len_r", "len_diff"])
full.to_parquet(out)
print("saved:", out, "rows:", len(full))