import os
from pathlib import Path
from typing import List, Optional, Union

from datasets import load_dataset, concatenate_datasets
from transformers import AutoTokenizer

# ====================== Configuration ======================
# Either point data_dir at a directory of .parquet files, or list the files
# explicitly in parquet_paths (parquet_paths takes precedence when non-empty).
data_dir: Optional[Union[str, Path]] = None  # e.g. "/path/to/parquet_dir"
parquet_paths: List[str] = [
    "/home/data/train_10k_sys_3round.parquet",
]
tokenizer_path = "/home/rm3.4.1_9e-6"           # tokenizer (must match training)
output_path = "/home/data/prefiltered.parquet"  # merged + filtered output
num_proc = max(1, (os.cpu_count() or 4) // 2)   # worker processes; tune per machine
min_tokens, max_tokens = 20, 80                 # filter thresholds (inclusive)
# ===========================================================


def collect_parquet_files() -> List[str]:
    if parquet_paths:
        return [str(Path(p)) for p in parquet_paths]
    if data_dir is None:
        raise ValueError("Set parquet_paths or data_dir in the configuration section.")
    p = Path(data_dir)
    if not p.exists():
        raise FileNotFoundError(f"Directory does not exist: {p}")
    files = sorted(str(fp) for fp in p.glob("*.parquet"))
    if not files:
        raise FileNotFoundError(f"No .parquet files found in: {p}")
    return files


def main():
    files = collect_parquet_files()
    print(f"Found {len(files)} parquet file(s) to merge and process:")
    for f in files:
        print("  -", f)

    # Load the tokenizer (must match the one used in training; no special tokens added).
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

    # Option A: load all files in one pass (faster, but requires identical schemas).
    dataset = load_dataset("parquet", data_files=files, split="train")
    # If your files' schemas are not fully consistent, load them one by one
    # and concatenate instead:
    # parts = [load_dataset("parquet", data_files=f, split="train") for f in files]
    # dataset = concatenate_datasets(parts)

    total_before = len(dataset)
    print(f"\nSamples after merging: {total_before}")

    # === Count tokens (batched=True is faster) ===
    def add_token_lengths(batch):
        chosen = batch["chosen"]
        reject = batch["reject"]
        # The tokenizer accepts a list of texts and returns one input_ids list per text.
        chosen_ids = tokenizer(chosen, add_special_tokens=False)["input_ids"]
        reject_ids = tokenizer(reject, add_special_tokens=False)["input_ids"]
        return {
            "chosen_tokens": [len(x) for x in chosen_ids],
            "reject_tokens": [len(x) for x in reject_ids],
        }

    dataset = dataset.map(
        add_token_lengths,
        batched=True,
        num_proc=num_proc,
        desc="Counting tokens",
    )

    # === Filter: both fields must fall within [min_tokens, max_tokens] ===
    def in_range_filter(batch):
        ct = batch["chosen_tokens"]
        rt = batch["reject_tokens"]
        # With batched=True the filter must return a list of booleans.
        return [
            (min_tokens <= c <= max_tokens) and (min_tokens <= r <= max_tokens)
            for c, r in zip(ct, rt)
        ]

    dataset = dataset.filter(
        in_range_filter,
        batched=True,
        num_proc=num_proc,
        desc=f"Filtering: keeping {min_tokens}~{max_tokens} tokens",
    )

    kept = len(dataset)
    print(f"Filtering done: kept {kept} / {total_before} (retention {kept / total_before:.2%})")

    # === Drop temporary columns and save ===
    # Remove the helper columns only if they exist, so the original data stays clean.
    for col in ["chosen_tokens", "reject_tokens"]:
        if col in dataset.column_names:
            dataset = dataset.remove_columns(col)

    # Save the merged result as a single Parquet file.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    dataset.to_parquet(output_path)
    print(f"Saved to: {output_path}")


if __name__ == "__main__":
    main()
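
# ---------------------------------------------------------------------------
# Optional sanity check (a minimal sketch, not part of the pipeline above):
# after the script finishes, you can re-load the saved Parquet and spot-check
# that every kept row respects the token bounds. This assumes the output_path,
# tokenizer_path, min_tokens/max_tokens, and "chosen"/"reject" text columns
# defined above; uncomment and run it in a REPL or a separate verify script.
#
# from datasets import load_dataset
# from transformers import AutoTokenizer
#
# ds = load_dataset("parquet", data_files=output_path, split="train")
# tok = AutoTokenizer.from_pretrained(tokenizer_path)
# sample = ds.select(range(min(100, len(ds))))  # spot-check the first rows
# for row in sample:
#     for field in ("chosen", "reject"):
#         n = len(tok(row[field], add_special_tokens=False)["input_ids"])
#         assert min_tokens <= n <= max_tokens, f"{field} has {n} tokens"
# print("Spot check passed.")
# ---------------------------------------------------------------------------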