File size: 3,709 Bytes
d8a76be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from transformers import AutoTokenizer
from datasets import load_dataset, concatenate_datasets
import numpy as np
from tqdm import tqdm

# ========= Configuration =========
tokenizer_path = "/home/rm3.4.1_9e-6"              # ← replace with your tokenizer path
# Input parquet files to process. When re-enabling a commented entry keep its
# trailing comma: adjacent string literals without a comma are silently
# concatenated by Python into a single, non-existent path.
parquet_paths = [
    # "/home/data/pk-2089-L6.parquet",
    # "/home/data/pk-1820-L6.parquet",
    # "/home/data/pk-2355-L6.parquet",
    # "/home/data/pk-4088-L6.parquet",
    # "/home/data/pk-3876-L6.parquet",
    # "/home/data/pk-2749-L6.parquet",
    # "/home/data/pk-2354-L5.parquet",
    # "/home/data/pk-3774-L6.parquet",
    # "/home/data/pk-1158-L5.parquet",
    # "/home/data/pk-4537-L0.parquet",
    # "/home/data/pk-1740-L4.parquet",
    # "/home/data/raw/test/1159-L6_format.parquet",
    # "/home/data/raw/test/4201_2355_full_label.parquet",
    "/home/data/raw/test/4201_2355_full_label_1000-8192.parquet",
]
# parquet_paths=["/home/data/prefiltered.parquet"]
output_path = "/home/data/raw/test/4201_2355_full_label_1000-8192.parquet"  # where the merged dataset is written
if output_path in parquet_paths:
    # Writing the merged result over one of the inputs replaces that source
    # file on disk. This may be deliberate (re-filtering in place), so warn
    # loudly instead of aborting.
    print(f"⚠ output_path is also an input file and will be overwritten: {output_path}")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

def count_total_tokens(ex):
    """Attach token-length fields to a single preference sample.

    Both responses are prefixed with the sample's ``chosen_prompt`` and
    tokenized with the module-level ``tokenizer`` (``add_special_tokens=False``).
    The following keys are written into ``ex``:

    * ``total_tokens``    -- chosen length + rejected length
    * ``chosen_tokens``   -- tokens in ``chosen_prompt + chosen``
    * ``rejected_tokens`` -- tokens in ``chosen_prompt + reject``

    Returns the same (mutated) mapping, as expected by ``Dataset.map``.
    """
    prompt_text = ex["chosen_prompt"]
    chosen_text = prompt_text + ex["chosen"]
    rejected_text = prompt_text + ex["reject"]

    n_chosen = len(tokenizer(chosen_text, add_special_tokens=False)["input_ids"])
    n_rejected = len(tokenizer(rejected_text, add_special_tokens=False)["input_ids"])

    # Key insertion order matches the column order produced by map().
    ex["total_tokens"] = n_chosen + n_rejected
    ex["chosen_tokens"] = n_chosen
    ex["rejected_tokens"] = n_rejected
    return ex

def summary(arr):
    """Return ``(max, min, mean)`` of a collection of token counts.

    ``max`` and ``min`` are cast to ``int`` and ``mean`` to ``float`` so the
    report can format them with ``d`` / ``f`` specifiers.

    Args:
        arr: NumPy array or any array-like (list, tuple, ...) of numbers.

    Returns:
        ``(int, int, float)``. For an empty input -- e.g. when the token-range
        filter removed every sample -- returns ``(0, 0, nan)`` instead of
        raising, so the remaining files are still processed and reported.
    """
    values = np.asarray(arr)
    if values.size == 0:
        # max()/min() have no identity element and raise ValueError on an
        # empty array; return a sentinel that still formats as int/int/float.
        return 0, 0, float("nan")
    return int(values.max()), int(values.min()), float(values.mean())

# ========= Main pipeline =========
cleaned_sets   = []        # filtered dataset for each input file
stats_before   = {}        # {file: (max, min, mean)} of total_tokens before filtering
stats_after    = {}        # {file: (max, min, mean)} of total_tokens after filtering

# Inclusive total_tokens (chosen + rejected) range a sample must fall in to be kept.
MIN_TOTAL_TOKENS = 1000
MAX_TOTAL_TOKENS = 8192
# Stats recorded when there is nothing to summarize: NumPy's max/min raise on
# empty arrays, and the report formats max/min with "d" and mean with "f".
EMPTY_STATS = (0, 0, float("nan"))

for path in parquet_paths:
    name = path.split("/")[-1]
    print(f"\n▶ 处理 {name}")

    # 1. Load the whole parquet file as a single split.
    ds = load_dataset("parquet", data_files=path, split="train")
    print(len(ds))
    if len(ds) == 0:
        # Nothing to tokenize, filter or merge for an empty file; still record
        # stats so the report below has an entry for it.
        stats_before[name] = stats_after[name] = EMPTY_STATS
        continue

    # 2. Tokenize every sample once and record pre-filter statistics.
    ds = ds.map(count_total_tokens, desc=f"[{name}] 计算 token (预统计)", num_proc=4)
    stats_before[name] = summary(np.array(ds["total_tokens"]))

    # 3. Keep samples inside the token range. Only the total_tokens column is
    #    passed to the predicate (input_columns), not the full text row.
    ds = ds.filter(
        lambda total: MIN_TOTAL_TOKENS <= total <= MAX_TOTAL_TOKENS,
        input_columns="total_tokens",
        desc=f"[{name}] 过滤区间 [{MIN_TOTAL_TOKENS}, {MAX_TOTAL_TOKENS}]",
    )

    # 4. Post-filter statistics; the filter may have removed every sample.
    stats_after[name] = summary(np.array(ds["total_tokens"])) if len(ds) else EMPTY_STATS

    # 5. (Optional) drop unrelated columns, keeping the three text columns plus
    #    the token fields, which makes merging files with differing extra
    #    columns easier.
    # keep = ["chosen", "chosen_prompt", "reject", "total_tokens", "chosen_tokens", "rejected_tokens"]
    # ds = ds.remove_columns([c for c in ds.column_names if c not in keep])

    cleaned_sets.append(ds)

# ========= Report before/after token statistics =========
print("\n================ Token 统计对比 ================ ")
print(f"{'数据集':<22} | {'过滤前 max/min/mean':<25} | {'过滤后 max/min/mean':<25}")
print("-" * 80)
# One row per configured input, in configuration order, keyed by file name.
for file_name in (p.split("/")[-1] for p in parquet_paths):
    before_max, before_min, before_mean = stats_before[file_name]
    after_max, after_min, after_mean = stats_after[file_name]
    before_cell = f"{before_max:5d}/{before_min:5d}/{before_mean:7.1f}"
    after_cell = f"{after_max:5d}/{after_min:5d}/{after_mean:7.1f}"
    print(f"{file_name:<22} | {before_cell} | {after_cell}")

# ========= (Optional) merge all filtered datasets and save =========
if cleaned_sets:
    # Rows are stacked in input order; the inputs are expected to share the
    # same columns.
    merged = concatenate_datasets(cleaned_sets)
    merged.to_parquet(output_path)
    print("\n✅ 合并后样本数:", len(merged))
else:
    # concatenate_datasets cannot build a dataset from an empty list, so
    # report and skip writing instead of crashing at the very end.
    print("\n⚠ No datasets to merge; nothing written to", output_path)