# rm_code/token_20-80.py
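"""Merge one or more preference-pair parquet files, count the tokens of the
"chosen" and "reject" fields with the training tokenizer, keep only rows where
both counts fall inside [min_tokens, max_tokens], and write the result to a
single parquet file."""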
import os
from pathlib import Path
from typing import List, Union
from datasets import load_dataset, concatenate_datasets
from transformers import AutoTokenizer
# ====================== Configuration ======================
# Point data_dir at your directory of parquet files, or list the files
# explicitly in parquet_paths (parquet_paths takes precedence when non-empty).
data_dir: Union[str, Path] = "/path/to/parquet_dir"  # directory containing .parquet files
parquet_paths: List[str] = [
    "/home/data/train_10k_sys_3round.parquet",
]  # explicit list of files (used first if non-empty)
tokenizer_path = "/home/rm3.4.1_9e-6"  # tokenizer (must match the one used in training)
output_path = "/home/data/prefiltered.parquet"  # merged + filtered output
num_proc = max(1, (os.cpu_count() or 4) // 2)  # number of parallel workers, tune per machine
min_tokens, max_tokens = 20, 80  # filter thresholds (inclusive)
# ============================================================
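# Expected input schema (assumed from the tokenization step below): each row
# carries string columns "chosen" and "reject" holding a preference pair.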

def collect_parquet_files() -> List[str]:
    if parquet_paths:
        return [str(Path(p)) for p in parquet_paths]
    p = Path(data_dir)
    if not p.exists():
        raise FileNotFoundError(f"Directory does not exist: {p}")
    files = sorted(str(fp) for fp in p.glob("*.parquet"))
    if not files:
        raise FileNotFoundError(f"No .parquet files found in directory: {p}")
    return files

def main():
    files = collect_parquet_files()
    print(f"Found {len(files)} parquet file(s), merging:")
    for f in files:
        print(" -", f)

    # Load the tokenizer (must match the training tokenizer; no special tokens are added)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

    # Option A: load and merge everything at once (faster, requires identical schemas)
    dataset = load_dataset("parquet", data_files=files, split="train")
    # If your files do not share the exact same schema, load them one by one and concatenate:
    # parts = [load_dataset("parquet", data_files=f, split="train") for f in files]
    # dataset = concatenate_datasets(parts)

    total_before = len(dataset)
    print(f"\nSamples after merging: {total_before}")

    # === Count tokens per field (batched=True is much faster) ===
    def add_token_lengths(batch):
        chosen = batch["chosen"]
        reject = batch["reject"]
        # The tokenizer accepts a list of texts and returns one input_ids list per text
        chosen_ids = tokenizer(chosen, add_special_tokens=False)["input_ids"]
        reject_ids = tokenizer(reject, add_special_tokens=False)["input_ids"]
        return {
            "chosen_tokens": [len(x) for x in chosen_ids],
            "reject_tokens": [len(x) for x in reject_ids],
        }

    dataset = dataset.map(
        add_token_lengths,
        batched=True,
        num_proc=num_proc,
        desc="Counting tokens",
    )

    # === Filter: both fields must fall within [min_tokens, max_tokens] ===
    def in_range_filter(batch):
        ct = batch["chosen_tokens"]
        rt = batch["reject_tokens"]
        # With batched=True, the filter function must return a list of booleans
        return [
            (min_tokens <= c <= max_tokens) and (min_tokens <= r <= max_tokens)
            for c, r in zip(ct, rt)
        ]

    dataset = dataset.filter(
        in_range_filter,
        batched=True,
        num_proc=num_proc,
        desc=f"Filtering: keep {min_tokens}-{max_tokens} tokens",
    )
    kept = len(dataset)
    print(f"Filtering done: kept {kept} / {total_before} (retention rate {kept / total_before:.2%})")

    # === Remove the temporary columns and save ===
    # Only drop the helper columns if they are present, so removal never fails
    # and they do not leak into the output
    for col in ["chosen_tokens", "reject_tokens"]:
        if col in dataset.column_names:
            dataset = dataset.remove_columns(col)

    # Write the result as a single merged parquet file
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    dataset.to_parquet(output_path)
    print(f"Saved to: {output_path}")

if __name__ == "__main__":
    main()
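
# Example usage (the paths in the configuration section are placeholders for a
# specific environment; adjust them before running):
#   python token_20-80.py
# Quick sanity check of the output (assumes pandas + pyarrow are available):
#   python -c "import pandas as pd; print(pd.read_parquet('/home/data/prefiltered.parquet').shape)"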