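"""Pre-filter a pairwise preference dataset by response length.

The script merges one or more Parquet files, tokenizes the "chosen" and
"reject" text columns with the training tokenizer, keeps only the samples
whose two responses both fall within [min_tokens, max_tokens] tokens, and
writes the result to a single Parquet file. Paths and thresholds are set in
the configuration block below.
"""
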
import os
from pathlib import Path
from typing import List, Union

from datasets import load_dataset, concatenate_datasets
from transformers import AutoTokenizer


# ==================== Configuration ====================
# Either point data_dir at a directory of .parquet files, or list files
# explicitly in parquet_paths (the explicit list takes precedence).
# data_dir: Union[str, Path] = "/path/to/parquet_dir"   # directory containing .parquet files
parquet_paths: List[str] = [
    "/home/data/train_10k_sys_3round.parquet",
]                                                    # explicit file list (used first if non-empty)
tokenizer_path = "/home/rm3.4.1_9e-6"                # tokenizer (must match the one used for training)
output_path = "/home/data/prefiltered.parquet"       # merged and filtered output file
num_proc = max(1, (os.cpu_count() or 4) // 2)        # number of worker processes; adjust for your machine
min_tokens, max_tokens = 20, 80                      # filter thresholds (inclusive)
# ========================================================


def collect_parquet_files() -> List[str]:
    if parquet_paths:
        return [str(Path(p)) for p in parquet_paths]
    p = Path(data_dir)  # fall back to data_dir (must be set above) when parquet_paths is empty
    if not p.exists():
        raise FileNotFoundError(f"Directory does not exist: {p}")
    files = sorted([str(fp) for fp in p.glob("*.parquet")])
    if not files:
        raise FileNotFoundError(f"No .parquet files found in directory: {p}")
    return files


def main():
    files = collect_parquet_files()
    print(f"Found {len(files)} parquet file(s); they will be merged and processed:")
    for f in files:
        print("  -", f)

    # Load the tokenizer (must match the one used for training; no special tokens are added)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

    # Option A: load all files as a single dataset (faster, provided the schemas match)
    dataset = load_dataset("parquet", data_files=files, split="train")

    # If the schemas are not fully consistent, load the files one by one and concatenate instead:
    # parts = [load_dataset("parquet", data_files=f, split="train") for f in files]
    # dataset = concatenate_datasets(parts)

    total_before = len(dataset)
    print(f"\nSamples after merging: {total_before}")

    # === Count tokens per sample (batched=True is faster) ===
    def add_token_lengths(batch):
        chosen = batch["chosen"]
        reject = batch["reject"]

        # The tokenizer accepts a list of texts and returns one input_ids list per text
        chosen_ids = tokenizer(chosen, add_special_tokens=False)["input_ids"]
        reject_ids = tokenizer(reject, add_special_tokens=False)["input_ids"]

        return {
            "chosen_tokens": [len(x) for x in chosen_ids],
            "reject_tokens": [len(x) for x in reject_ids],
        }

    dataset = dataset.map(
        add_token_lengths,
        batched=True,
        num_proc=num_proc,
        desc="Counting tokens",
    )

    # === Filter: both fields must lie within [min_tokens, max_tokens] ===
    def in_range_filter(batch):
        ct = batch["chosen_tokens"]
        rt = batch["reject_tokens"]
        # With batched=True the function must return a list of booleans
        return [
            (min_tokens <= c <= max_tokens) and (min_tokens <= r <= max_tokens)
            for c, r in zip(ct, rt)
        ]

    dataset = dataset.filter(
        in_range_filter,
        batched=True,
        num_proc=num_proc,
        desc=f"Filtering: keep {min_tokens}-{max_tokens} tokens",
    )

    kept = len(dataset)
    ratio = kept / total_before if total_before else 0.0
    print(f"Filtering done: kept {kept} / {total_before} (retention rate {ratio:.2%})")

    # === Drop the temporary columns and save ===
    # Remove the helper columns only if they exist, so the saved file keeps the original schema
    for col in ["chosen_tokens", "reject_tokens"]:
        if col in dataset.column_names:
            dataset = dataset.remove_columns(col)

    # Save the merged, filtered result as a single Parquet file
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    dataset.to_parquet(output_path)
    print(f"Saved to: {output_path}")


if __name__ == "__main__":
    main()
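

# Optional sanity check (a sketch, not part of the pipeline): after the script
# finishes, the filtered file can be loaded back to verify row count and schema.
#
#   from datasets import load_dataset
#   ds = load_dataset("parquet", data_files=output_path, split="train")  # output_path as configured above
#   print(len(ds), ds.column_names)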