# rm_code/data_pro.py
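"""Filter pairwise reward-model data by total token count and merge the results.

For each input parquet, the script tokenizes prompt+chosen and prompt+reject,
keeps samples whose combined length falls in [1000, 8192] tokens, prints
before/after token statistics, and writes the concatenated result to parquet.
"""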
from transformers import AutoTokenizer
from datasets import load_dataset, concatenate_datasets
import numpy as np
from tqdm import tqdm
# ========= Config =========
tokenizer_path = "/home/rm3.4.1_9e-6"  # ← replace with your tokenizer path
parquet_paths = [
    # "/home/data/pk-2089-L6.parquet",
    # "/home/data/pk-1820-L6.parquet",
    # "/home/data/pk-2355-L6.parquet",
    # "/home/data/pk-4088-L6.parquet",
    # "/home/data/pk-3876-L6.parquet",
    # "/home/data/pk-2749-L6.parquet",
    # "/home/data/pk-2354-L5.parquet",
    # "/home/data/pk-3774-L6.parquet",
    # "/home/data/pk-1158-L5.parquet",
    # "/home/data/pk-4537-L0.parquet",
    # "/home/data/pk-1740-L4.parquet",
    # "/home/data/raw/test/1159-L6_format.parquet",
    # "/home/data/raw/test/4201_2355_full_label.parquet",
    "/home/data/raw/test/4201_2355_full_label_1000-8192.parquet",
]
# parquet_paths=["/home/data/prefiltered.parquet"]
output_path = "/home/data/raw/test/4201_2355_full_label_1000-8192.parquet"  # output path for the merged data (note: currently the same file as the input above, so running the script overwrites it)
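# Expected input schema (as used below): each row has the string columns
# "chosen_prompt", "chosen", and "reject"; any extra columns are carried through.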
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
def count_total_tokens(ex):
    """Add total_tokens / chosen_tokens / rejected_tokens fields to a sample."""
    prompt = ex["chosen_prompt"]
    chosen_ids = tokenizer(prompt + ex["chosen"], add_special_tokens=False)["input_ids"]
    rejected_ids = tokenizer(prompt + ex["reject"], add_special_tokens=False)["input_ids"]
    ex["total_tokens"] = len(chosen_ids) + len(rejected_ids)
    ex["chosen_tokens"] = len(chosen_ids)
    ex["rejected_tokens"] = len(rejected_ids)
    return ex
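# Example (hypothetical sample, for illustration only):
#   count_total_tokens({"chosen_prompt": "Q: ...", "chosen": "good answer", "reject": "worse answer"})
#   -> same dict with total_tokens / chosen_tokens / rejected_tokens added.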
def summary(arr):
    """Return (max, min, mean) of a token-count array; (0, 0, 0.0) if empty."""
    if arr.size == 0:  # guard against a dataset that is empty after filtering
        return 0, 0, 0.0
    return int(arr.max()), int(arr.min()), float(arr.mean())
# ========= Main pipeline =========
cleaned_sets = []   # filtered datasets
stats_before = {}   # {file: (max, min, mean)}
stats_after = {}    # {file: (max, min, mean)}
for path in parquet_paths:
    name = path.split("/")[-1]
    print(f"\n▶ Processing {name}")

    # 1. Load
    ds = load_dataset("parquet", data_files=path, split="train")
    print(len(ds))
    # 2. Pre-filter statistics: compute the token-count columns over the whole dataset
    ds_tmp = ds.map(count_total_tokens, desc=f"[{name}] counting tokens (pre-filter)", num_proc=4)
    stats_before[name] = summary(np.array(ds_tmp["total_tokens"]))
    # 3. Filter to the token-count interval [1000, 8192]
    ds = ds_tmp.filter(
        lambda x: 1000 <= x["total_tokens"] <= 8192,
        desc=f"[{name}] filtering to [1000, 8192]",
    )
    # 4. Post-filter statistics
    stats_after[name] = summary(np.array(ds["total_tokens"]))

    # 5. (Optional) drop unrelated columns, keeping only the pair columns and token fields (simplifies merging)
    # keep = ["chosen", "chosen_prompt", "reject", "total_tokens", "chosen_tokens", "rejected_tokens"]
    # ds = ds.remove_columns([c for c in ds.column_names if c not in keep])
    cleaned_sets.append(ds)
# ========= Print statistics =========
print("\n================ Token statistics (before vs. after filtering) ================")
print(f"{'dataset':<22} | {'before max/min/mean':<25} | {'after max/min/mean':<25}")
print("-" * 80)
for path in parquet_paths:
    n = path.split("/")[-1]
    b_max, b_min, b_mean = stats_before[n]
    a_max, a_min, a_mean = stats_after[n]
    print(f"{n:<22} | {b_max:5d}/{b_min:5d}/{b_mean:7.1f}      | {a_max:5d}/{a_min:5d}/{a_mean:7.1f}")
# ========= (Optional) merge and save =========
merged = concatenate_datasets(cleaned_sets)
merged.to_parquet(output_path)
print("\n✅ Merged sample count:", len(merged))