File size: 2,934 Bytes
d8a76be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import matplotlib.pyplot as plt
from datasets import load_dataset, concatenate_datasets
from transformers import AutoTokenizer
import re
import numpy as np
import os

# Input preference-pair parquet shards (chosen/reject reply pairs) to merge.
paths = [
    "/home/data/pk-2089-L6.parquet",
    "/home/data/pk-1820-L6.parquet",
    "/home/data/pk-2355-L6.parquet",
    "/home/data/pk-4088-L6.parquet",
    "/home/data/pk-3876-L6.parquet",
]
# Tokenizer used only to measure reply token lengths; loaded from a local
# checkpoint dir — presumably the reward model these pairs feed (TODO confirm).
tok = AutoTokenizer.from_pretrained("/home/rm")

# Chat-template control markers to strip from replies so token-length
# statistics reflect only the natural-language reply text.
special_tokens = {
    "<|im_start|>", "<|im_end|>",
    "<|eot_id|>", "|eot_id|", "<|end_of_text|>",
    "<s>", "</s>",
    "<|system|>", "<|user|>", "<|assistant|>",
    "<bos>", "<eos>", "<pad>",
    "<|start_header_id|>", "<|end_header_id|>",
    "[INST]", "[/INST]",
}
# Sort longest-first before joining: set iteration order is nondeterministic
# across interpreter runs (string hash randomization), and regex alternation
# matches the first alternative that succeeds — sorting makes the compiled
# pattern reproducible and guarantees a marker that is a substring of another
# (e.g. "|eot_id|" inside "<|eot_id|>") can never shadow the longer match.
pat = re.compile(
    "|".join(re.escape(t) for t in sorted(special_tokens, key=len, reverse=True))
)

def clean_text(ex):
    """Normalize both replies of a preference pair in place and return it.

    Strips chat-template markers (via the module-level ``pat`` regex),
    collapses runs of whitespace to single spaces, blanks non-string values,
    and clears the prompt (reply-only pairs).
    """
    def norm(s):
        if not isinstance(s, str):
            return ""
        without_markers = pat.sub("", s.strip())
        return re.sub(r"\s+", " ", without_markers).strip()

    for field in ("chosen", "reject"):
        ex[field] = norm(ex.get(field, ""))
    ex["prompt"] = ""  # reply-only
    return ex

def add_lengths(batch):
    """Batched ``datasets.map`` callback adding token-length columns.

    Tokenizes the "chosen" and "reject" columns (no special tokens added)
    and returns three new columns: len_c, len_r, and the absolute
    per-example length gap len_diff.
    """
    chosen_ids = tok(batch["chosen"], add_special_tokens=False)["input_ids"]
    reject_ids = tok(batch["reject"], add_special_tokens=False)["input_ids"]
    chosen_lens = list(map(len, chosen_ids))
    reject_lens = list(map(len, reject_ids))
    return {
        "len_c": chosen_lens,
        "len_r": reject_lens,
        "len_diff": [abs(c - r) for c, r in zip(chosen_lens, reject_lens)],
    }

# Columns retained after preprocessing; anything else is dropped per shard.
needed = ["prompt", "chosen", "reject", "len_c", "len_r", "len_diff"]


def _prepare_shard(path):
    # One shard: load, normalize reply text, add length columns, drop extras.
    shard = load_dataset("parquet", data_files=path, split="train")
    shard = shard.map(clean_text, num_proc=4)
    shard = shard.map(add_lengths, batched=True, batch_size=1024, num_proc=4)
    extras = [col for col in shard.column_names if col not in needed]
    return shard.remove_columns(extras) if extras else shard


sets = [_prepare_shard(p) for p in paths]

full = concatenate_datasets(sets)

# Quantile statistics of the token-length gap between chosen and reject.
len_diffs = np.array(full["len_diff"])
for q in (0.50, 0.75, 0.90, 0.95, 0.99):
    print(f"|Δlen| 分位数 q={q:.2f}: {np.quantile(len_diffs, q)}")

# The 0.95 quantile becomes the filtering threshold further down.
cut = np.quantile(len_diffs, 0.95)
print(f"长度差 0.95 分位数阈值: {cut}")

# ====== Plot and save the |Δlen| histogram ======
fig, ax = plt.subplots(figsize=(8, 5))
ax.hist(len_diffs, bins=50, color="skyblue", edgecolor="black")
# Mark the filtering threshold on the distribution.
ax.axvline(cut, color="red", linestyle="--", label=f"0.95分位: {cut}")
ax.set_title("|Δlen| 长度差分布(chosen vs reject)")
ax.set_xlabel("Token Length Difference")
ax.set_ylabel("Frequency")
ax.legend()
os.makedirs("./plots", exist_ok=True)
plot_path = "./plots/len_diff_distribution.png"
fig.savefig(plot_path, dpi=300)
plt.close(fig)
print(f"✅ 已保存长度差分布图: {plot_path}")

# Drop outlier pairs (length gap above the 0.95-quantile threshold), strip
# the bookkeeping length columns, and write the final reply-only parquet.
full = (
    full.filter(lambda row: row["len_diff"] <= cut, num_proc=4)
        .remove_columns(["len_c", "len_r", "len_diff"])
)

out = "/home/data/reply_only_pairs.parquet"
full.to_parquet(out)
print("saved:", out, "rows:", len(full))