| | import matplotlib.pyplot as plt |
| | from datasets import load_dataset, concatenate_datasets |
| | from transformers import AutoTokenizer |
| | import re |
| | import numpy as np |
| | import os |
| |
|
# Parquet shards of preference pairs to be merged into one dataset.
paths = [
    "/home/data/pk-2089-L6.parquet",
    "/home/data/pk-1820-L6.parquet",
    "/home/data/pk-2355-L6.parquet",
    "/home/data/pk-4088-L6.parquet",
    "/home/data/pk-3876-L6.parquet",
]

# Tokenizer loaded from the local reward-model directory; used only to
# measure reply lengths in tokens (no special tokens are added).
tok = AutoTokenizer.from_pretrained("/home/rm")
| |
|
# Chat-template control tokens to scrub from reply text before token
# lengths are measured.
special_tokens = {
    "<|im_start|>", "<|im_end|>",
    "<|eot_id|>", "|eot_id|", "<|end_of_text|>",
    "<s>", "</s>",
    "<|system|>", "<|user|>", "<|assistant|>",
    "<bos>", "<eos>", "<pad>",
    "<|start_header_id|>", "<|end_header_id|>",
    "[INST]", "[/INST]",
}
# Fix: joining a *set* gives a hash-seed-dependent alternation order, so
# the compiled pattern differs between runs; and in a regex alternation an
# earlier, shorter token that is a prefix/substring of a longer one would
# shadow the longer match. Sorting longest-first makes the pattern
# deterministic and always prefers the longest token at each position.
pat = re.compile(
    "|".join(re.escape(t) for t in sorted(special_tokens, key=len, reverse=True))
)
| |
|
def clean_text(ex):
    """Normalize one preference pair in place.

    Strips chat-template special tokens and collapses runs of whitespace
    in the ``chosen``/``reject`` fields, and blanks ``prompt`` (these are
    reply-only pairs). Returns the mutated example dict.
    """
    def _norm(text):
        # Non-string payloads (e.g. None coming out of parquet) map to "".
        if not isinstance(text, str):
            return ""
        without_specials = pat.sub("", text.strip())
        return re.sub(r"\s+", " ", without_specials).strip()

    ex["chosen"] = _norm(ex.get("chosen", ""))
    ex["reject"] = _norm(ex.get("reject", ""))
    ex["prompt"] = ""
    return ex
| |
|
def add_lengths(batch):
    """Batched map function: attach token-length columns.

    Tokenizes the ``chosen`` and ``reject`` columns without special
    tokens and returns per-example lengths ``len_c``/``len_r`` plus their
    absolute difference ``len_diff``.
    """
    out = {}
    for length_key, text_key in (("len_c", "chosen"), ("len_r", "reject")):
        ids = tok(batch[text_key], add_special_tokens=False)["input_ids"]
        out[length_key] = [len(seq) for seq in ids]
    out["len_diff"] = [
        abs(c - r) for c, r in zip(out["len_c"], out["len_r"])
    ]
    return out
| |
|
# Columns the downstream steps rely on; everything else is dropped.
needed = ["prompt", "chosen", "reject", "len_c", "len_r", "len_diff"]

sets = []
for path in paths:
    shard = load_dataset("parquet", data_files=path, split="train")
    shard = shard.map(clean_text, num_proc=4)
    shard = shard.map(add_lengths, batched=True, batch_size=1024, num_proc=4)
    extras = [col for col in shard.column_names if col not in needed]
    if extras:
        shard = shard.remove_columns(extras)
    sets.append(shard)

# All shards now share a schema, so they can be concatenated directly.
full = concatenate_datasets(sets)
| |
|
| | |
# Report the spread of the absolute token-length gap between the replies.
len_diffs = np.array(full["len_diff"])
quantile_levels = (0.50, 0.75, 0.90, 0.95, 0.99)
for q, value in zip(quantile_levels, np.quantile(len_diffs, list(quantile_levels))):
    print(f"|Δlen| 分位数 q={q:.2f}: {value}")

# Pairs above the 95th percentile of |Δlen| are treated as outliers below.
cut = np.quantile(len_diffs, 0.95)
print(f"长度差 0.95 分位数阈值: {cut}")
| |
|
| | |
# Histogram of |Δlen| with the chosen cutoff marked.
# NOTE(review): the CJK title/label text renders only if matplotlib has a
# CJK-capable font configured — confirm on the target machine.
fig, ax = plt.subplots(figsize=(8, 5))
ax.hist(len_diffs, bins=50, color="skyblue", edgecolor="black")
ax.axvline(cut, color="red", linestyle="--", label=f"0.95分位: {cut}")
ax.set_title("|Δlen| 长度差分布(chosen vs reject)")
ax.set_xlabel("Token Length Difference")
ax.set_ylabel("Frequency")
ax.legend()

os.makedirs("./plots", exist_ok=True)
plot_path = "./plots/len_diff_distribution.png"
fig.savefig(plot_path, dpi=300)
plt.close(fig)
print(f"✅ 已保存长度差分布图: {plot_path}")
| |
|
| | |
# Drop length-outlier pairs, then strip the helper columns before export.
full = full.filter(lambda row: row["len_diff"] <= cut, num_proc=4).remove_columns(
    ["len_c", "len_r", "len_diff"]
)

out = "/home/data/reply_only_pairs.parquet"
full.to_parquet(out)
print("saved:", out, "rows:", len(full))
| |
|