from transformers import AutoTokenizer
from datasets import load_dataset, concatenate_datasets
import numpy as np

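# --- Configuration: tokenizer checkpoint, input parquet shards, output path ---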
tokenizer_path = "/home/rm3.4.1_9e-6"
parquet_paths = [
    "/home/data/raw/test/4201_2355_full_label_1000-8192.parquet",
]

# Note: output_path is identical to the input path above, so the merged
# result overwrites the source file in place.
output_path = "/home/data/raw/test/4201_2355_full_label_1000-8192.parquet"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

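
# Each row must carry string columns "chosen_prompt", "chosen", and "reject"
# (the column really is named "reject", not "rejected", in these files).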
def count_total_tokens(ex):
    """Add total_tokens / chosen_tokens / rejected_tokens fields to a sample."""
    prompt = ex["chosen_prompt"]
    chosen_ids = tokenizer(prompt + ex["chosen"], add_special_tokens=False)["input_ids"]
    rejected_ids = tokenizer(prompt + ex["reject"], add_special_tokens=False)["input_ids"]
    ex["total_tokens"] = len(chosen_ids) + len(rejected_ids)
    ex["chosen_tokens"] = len(chosen_ids)
    ex["rejected_tokens"] = len(rejected_ids)
    return ex
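

# A batched variant (sketch; not wired in below, and the function name is
# ours, not from the original script). With a fast tokenizer, tokenizing
# whole columns at once is typically much faster than per-example calls.
def count_total_tokens_batched(batch):
    chosen = [p + c for p, c in zip(batch["chosen_prompt"], batch["chosen"])]
    rejected = [p + r for p, r in zip(batch["chosen_prompt"], batch["reject"])]
    chosen_lens = [len(ids) for ids in tokenizer(chosen, add_special_tokens=False)["input_ids"]]
    rejected_lens = [len(ids) for ids in tokenizer(rejected, add_special_tokens=False)["input_ids"]]
    batch["chosen_tokens"] = chosen_lens
    batch["rejected_tokens"] = rejected_lens
    batch["total_tokens"] = [c + r for c, r in zip(chosen_lens, rejected_lens)]
    return batch
# Usage, if swapped in: ds_tmp = ds.map(count_total_tokens_batched, batched=True)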

def summary(arr):
    """Return (max, min, mean) as two ints and a float."""
    return int(arr.max()), int(arr.min()), float(arr.mean())


cleaned_sets = []
stats_before = {}
stats_after = {}

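# --- Per-file pass: tokenize, record stats, filter by token budget ---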
for path in parquet_paths:
    name = path.split("/")[-1]
    print(f"\n▶ Processing {name}")

    ds = load_dataset("parquet", data_files=path, split="train")
    print(len(ds))

    # Attach per-sample token counts, then record pre-filter statistics.
    ds_tmp = ds.map(count_total_tokens, desc=f"[{name}] counting tokens (pre-filter)", num_proc=4)
    stats_before[name] = summary(np.array(ds_tmp["total_tokens"]))

    # Keep only pairs whose combined chosen+rejected length lies in [1000, 8192] tokens.
    ds = ds_tmp.filter(
        lambda x: 1000 <= x["total_tokens"] <= 8192,
        desc=f"[{name}] filtering to range [1000, 8192]",
    )

    stats_after[name] = summary(np.array(ds["total_tokens"]))

    cleaned_sets.append(ds)

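# --- Before/after token statistics per input file ---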
print("\n================ Token statistics comparison ================")
print(f"{'Dataset':<22} | {'before max/min/mean':<25} | {'after max/min/mean':<25}")
print("-" * 80)
for path in parquet_paths:
    name = path.split("/")[-1]
    b_max, b_min, b_mean = stats_before[name]
    a_max, a_min, a_mean = stats_after[name]
    print(f"{name:<22} | {b_max:5d}/{b_min:5d}/{b_mean:7.1f} | {a_max:5d}/{a_min:5d}/{a_mean:7.1f}")

# Merge the filtered shards into one dataset and write a single parquet file.
merged = concatenate_datasets(cleaned_sets)
merged.to_parquet(output_path)
print("\n✅ Samples after merge:", len(merged))