from transformers import AutoTokenizer
from datasets import load_dataset, concatenate_datasets
import numpy as np

# ========= Configuration =========
tokenizer_path = "/home/rm3.4.1_9e-6"  # ← replace with your tokenizer path
parquet_paths = [
    # "/home/data/pk-2089-L6.parquet",
    # "/home/data/pk-1820-L6.parquet",
    # "/home/data/pk-2355-L6.parquet",
    # "/home/data/pk-4088-L6.parquet",
    # "/home/data/pk-3876-L6.parquet",
    # "/home/data/pk-2749-L6.parquet",
    # "/home/data/pk-2354-L5.parquet",
    # "/home/data/pk-3774-L6.parquet",
    # "/home/data/pk-1158-L5.parquet",
    # "/home/data/pk-4537-L0.parquet",
    # "/home/data/pk-1740-L4.parquet",
    # "/home/data/raw/test/1159-L6_format.parquet",
    # "/home/data/raw/test/4201_2355_full_label.parquet",
    "/home/data/raw/test/4201_2355_full_label_1000-8192.parquet",
]
# parquet_paths = ["/home/data/prefiltered.parquet"]
output_path = "/home/data/raw/test/4201_2355_full_label_1000-8192.parquet"  # output path for the merged dataset

tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)


def count_total_tokens(ex):
    """Add total_tokens / chosen_tokens / rejected_tokens fields to a sample."""
    prompt = ex["chosen_prompt"]
    chosen_ids = tokenizer(prompt + ex["chosen"], add_special_tokens=False)["input_ids"]
    rejected_ids = tokenizer(prompt + ex["reject"], add_special_tokens=False)["input_ids"]
    ex["total_tokens"] = len(chosen_ids) + len(rejected_ids)
    ex["chosen_tokens"] = len(chosen_ids)
    ex["rejected_tokens"] = len(rejected_ids)
    return ex


def summary(arr):
    """Return (max, min, mean) as (int, int, float); guard against an empty array."""
    if arr.size == 0:
        return 0, 0, 0.0
    return int(arr.max()), int(arr.min()), float(arr.mean())


# ========= Main pipeline =========
cleaned_sets = []   # filtered datasets
stats_before = {}   # {file: (max, min, mean)}
stats_after = {}    # {file: (max, min, mean)}

for path in parquet_paths:
    name = path.split("/")[-1]
    print(f"\n▶ Processing {name}")

    # 1. Load
    ds = load_dataset("parquet", data_files=path, split="train")
    print(len(ds))

    # 2. Compute token counts for every row, then record pre-filter statistics.
    #    The statistics must cover the whole column, so the token fields are
    #    computed first via map().
    ds_tmp = ds.map(count_total_tokens, desc=f"[{name}] counting tokens (pre-filter)", num_proc=4)
    stats_before[name] = summary(np.array(ds_tmp["total_tokens"]))

    # 3. Keep only samples whose total token count lies in [1000, 8192]
    ds = ds_tmp.filter(
        lambda x: 1000 <= x["total_tokens"] <= 8192,
        desc=f"[{name}] filtering to [1000, 8192]",
    )

    # 4. Post-filter statistics
    stats_after[name] = summary(np.array(ds["total_tokens"]))

    # 5. (Optional) drop unrelated columns, keeping only the three text columns
    #    and the token-count fields, to simplify the later concatenation
    # keep = ["chosen", "chosen_prompt", "reject", "total_tokens", "chosen_tokens", "rejected_tokens"]
    # ds = ds.remove_columns([c for c in ds.column_names if c not in keep])

    cleaned_sets.append(ds)

# ========= Print statistics =========
print("\n================ Token statistics (before vs. after filtering) ================")
print(f"{'dataset':<22} | {'before max/min/mean':<25} | {'after max/min/mean':<25}")
print("-" * 80)
for path in parquet_paths:
    n = path.split("/")[-1]
    b_max, b_min, b_mean = stats_before[n]
    a_max, a_min, a_mean = stats_after[n]
    print(f"{n:<22} | {b_max:5d}/{b_min:5d}/{b_mean:7.1f} | {a_max:5d}/{a_min:5d}/{a_mean:7.1f}")

# ========= (Optional) merge and save =========
merged = concatenate_datasets(cleaned_sets)
merged.to_parquet(output_path)
print("\n✅ Merged sample count:", len(merged))