import os
from pathlib import Path
from typing import List

from datasets import load_dataset
from transformers import AutoTokenizer


# --- Configuration ---
parquet_paths: List[str] = [
    "/home/data/train_10k_sys_3round.parquet",
]
data_dir = "/home/data"  # assumed fallback: scanned for *.parquet when parquet_paths is empty
tokenizer_path = "/home/rm3.4.1_9e-6"
output_path = "/home/data/prefiltered.parquet"
num_proc = max(1, (os.cpu_count() or 4) // 2)  # use half the available cores, at least 1
min_tokens, max_tokens = 20, 80  # keep only responses within this token range
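# Example (hypothetical paths): to scan a directory instead of listing files
# explicitly, leave parquet_paths empty and point data_dir at the folder:
#   parquet_paths = []
#   data_dir = "/home/data/some_parquet_dir"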

def collect_parquet_files() -> List[str]:
    """Return the explicitly configured parquet paths, or scan data_dir."""
    if parquet_paths:
        return [str(Path(p)) for p in parquet_paths]
    p = Path(data_dir)
    if not p.exists():
        raise FileNotFoundError(f"Directory does not exist: {p}")
    files = sorted(str(fp) for fp in p.glob("*.parquet"))
    if not files:
        raise FileNotFoundError(f"No .parquet files found in directory: {p}")
    return files

def main():
    files = collect_parquet_files()
    print(f"Found {len(files)} parquet file(s) to merge and process:")
    for f in files:
        print(" -", f)

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

    # Passing the full file list to load_dataset merges everything into one split.
    dataset = load_dataset("parquet", data_files=files, split="train")

    total_before = len(dataset)
    print(f"\nSamples after merging: {total_before}")

    def add_token_lengths(batch):
        # Tokenize without special tokens so the counts reflect the raw text only.
        chosen = batch["chosen"]
        reject = batch["reject"]

        chosen_ids = tokenizer(chosen, add_special_tokens=False)["input_ids"]
        reject_ids = tokenizer(reject, add_special_tokens=False)["input_ids"]

        return {
            "chosen_tokens": [len(x) for x in chosen_ids],
            "reject_tokens": [len(x) for x in reject_ids],
        }

    dataset = dataset.map(
        add_token_lengths,
        batched=True,
        num_proc=num_proc,
        desc="Counting tokens",
    )
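    # Optional sanity check (a sketch, not part of the original script): peek at
    # the token-length range on a small sample to confirm the min/max window is
    # sensible for this data before filtering.
    if len(dataset) > 0:
        sample = dataset.select(range(min(1000, len(dataset))))
        print("chosen tokens in sample: "
              f"{min(sample['chosen_tokens'])}..{max(sample['chosen_tokens'])}")
        print("reject tokens in sample: "
              f"{min(sample['reject_tokens'])}..{max(sample['reject_tokens'])}")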

    def in_range_filter(batch):
        # Keep a pair only if BOTH responses fall inside the token window.
        ct = batch["chosen_tokens"]
        rt = batch["reject_tokens"]
        return [
            (min_tokens <= c <= max_tokens) and (min_tokens <= r <= max_tokens)
            for c, r in zip(ct, rt)
        ]

    dataset = dataset.filter(
        in_range_filter,
        batched=True,
        num_proc=num_proc,
        desc=f"Filtering: keeping {min_tokens}~{max_tokens} tokens",
    )
    kept = len(dataset)
    rate = kept / total_before if total_before else 0.0
    print(f"Filtering done: kept {kept} / {total_before} (retention rate {rate:.2%})")

    # Drop the helper columns so the output schema matches the input.
    for col in ["chosen_tokens", "reject_tokens"]:
        if col in dataset.column_names:
            dataset = dataset.remove_columns(col)
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    dataset.to_parquet(output_path)
    print(f"Saved to: {output_path}")


if __name__ == "__main__":
    main()
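# Usage sketch (the script filename below is hypothetical; adjust the paths in
# the configuration section for your environment):
#   python prefilter_by_token_length.py
#
# To spot-check the output afterwards:
#   from datasets import load_dataset
#   ds = load_dataset("parquet", data_files=output_path, split="train")
#   print(len(ds), ds.column_names)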