File size: 4,610 Bytes
d8a76be | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 | # -*- coding: utf-8 -*-
"""
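Filtering rules:
- Keep: symmetric BERTScore-F1 ∈ [0.05, 0.35] and symmetric ROUGE-L F1 ∈ [0.05, 0.30]
  (inclusive bounds; the values come from BERT_LO/BERT_HI and ROUGE_LO/ROUGE_HI)
- Drop all other samples
- Export to parquet (keeping only the chosen_prompt, chosen and reject columns)
- Print a random sample of up to 5 kept rows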
"""
import os
import math
import numpy as np
import pandas as pd
from tqdm import tqdm
from bert_score import score as bertscore
from rouge_score import rouge_scorer
# ========= 配置 =========
# Input / output parquet locations.
DATA_PATH = "/home/data/train_10k_sys_3round.parquet"
OUTPUT_PATH = "/home/data/filtered_v1.parquet"

# Names of the three columns read from the input and written to the output.
CHOSEN_PROMPT_COL = "chosen_prompt"
CHOSEN_COL = "chosen"
REJECT_COL = "reject"

# BERTScore settings: language code and model name handed to bert_score.
LANG = "en"
BERTSCORE_MODEL = "roberta-large"

# BATCH_SIZE is the number of pairs per outer chunk; BERT_BATCH_CAP caps the
# batch size forwarded to bert_score for each chunk.
BATCH_SIZE = 256
BERT_BATCH_CAP = 64

# Inclusive keep-ranges applied to the symmetric BERTScore F1 and ROUGE-L F1.
BERT_LO = 0.05
BERT_HI = 0.35
ROUGE_LO = 0.05
ROUGE_HI = 0.30
# ========= 工具函数 =========
def norm_text(x):
    """Normalize a cell value to a stripped string, mapping missing values to "".

    Args:
        x: A scalar cell value (typically ``str`` or a missing-value marker).

    Returns:
        ``""`` when ``x`` is ``None``, ``pd.NA`` or a floating-point NaN
        (Python ``float`` or any NumPy float type); otherwise ``str(x)`` with
        leading and trailing whitespace removed.
    """
    # Identity checks first: ``pd.NA`` would otherwise be stringified and
    # survive as the literal text "<NA>".
    if x is None or x is pd.NA:
        return ""
    # NumPy float types such as ``np.float32`` are not subclasses of ``float``,
    # so ``np.floating`` is needed to catch their NaN values as well.
    if isinstance(x, (float, np.floating)) and math.isnan(x):
        return ""
    return str(x).strip()
def compute_bert_symmetric_f1(chosen_list, reject_list, lang, model_type, batch_size):
    """Compute a direction-averaged BERTScore F1 for each chosen/reject pair.

    The inputs are walked in consecutive chunks of ``batch_size`` pairs. Each
    chunk is scored twice -- chosen vs. reject, then reject vs. chosen -- with
    ``rescale_with_baseline=True``, and the two F1 results are averaged.

    Args:
        chosen_list: Sequence of chosen texts.
        reject_list: Sequence of rejected texts, same length as ``chosen_list``.
        lang: Language code forwarded to ``bertscore``.
        model_type: Model name forwarded to ``bertscore``.
        batch_size: Number of pairs per outer chunk; the batch size handed to
            ``bertscore`` is additionally capped at ``BERT_BATCH_CAP``.

    Returns:
        A ``float32`` NumPy array holding one averaged F1 per input pair.

    Raises:
        AssertionError: If the two lists differ in length.
    """
    assert len(chosen_list) == len(reject_list)
    total = len(chosen_list)
    averaged = np.zeros(total, dtype=np.float32)

    def _directional_f1(first, second):
        # One bertscore call in a single direction; only the F1 tensor is
        # used, and it is moved to host memory as a NumPy array.
        _, _, f1 = bertscore(
            first, second,
            lang=lang,
            model_type=model_type,
            rescale_with_baseline=True,
            verbose=False,
            batch_size=min(BERT_BATCH_CAP, batch_size),
        )
        return f1.cpu().numpy()

    # NOTE(review): bertscore is invoked twice per chunk; if it reloads the
    # model on every call, a reusable scorer object may be considerably
    # cheaper -- confirm against the bert_score documentation.
    write_pos = 0
    for lo in tqdm(range(0, total, batch_size), desc="BERTScore Symmetric"):
        hi = min(lo + batch_size, total)
        chosen_chunk = chosen_list[lo:hi]
        reject_chunk = reject_list[lo:hi]

        forward = _directional_f1(chosen_chunk, reject_chunk)
        backward = _directional_f1(reject_chunk, chosen_chunk)
        mean_f1 = (forward + backward) * 0.5

        # Write the chunk's scores right after the previously written ones.
        stop = write_pos + len(mean_f1)
        averaged[write_pos:stop] = mean_f1.astype(np.float32)
        write_pos = stop
    return averaged
def compute_rougeL_symmetric_f1(chosen_list, reject_list, use_stemmer=True):
    """Compute a direction-averaged ROUGE-L F-measure for each pair.

    Every chosen/reject pair is scored in both argument orders with a single
    ``RougeScorer`` configured for ``rougeL``; the two F-measures are averaged.

    Args:
        chosen_list: Sequence of chosen texts.
        reject_list: Sequence of rejected texts, same length as ``chosen_list``.
        use_stemmer: Forwarded to ``rouge_scorer.RougeScorer``.

    Returns:
        A ``float32`` NumPy array holding one averaged F-measure per pair.

    Raises:
        AssertionError: If the two lists differ in length.
    """
    assert len(chosen_list) == len(reject_list)
    total = len(chosen_list)
    rouge = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=use_stemmer)
    scores = np.zeros(total, dtype=np.float32)

    pairs = tqdm(
        zip(chosen_list, reject_list),
        total=total,
        desc="ROUGE-L Symmetric",
    )
    for pos, (chosen_text, reject_text) in enumerate(pairs):
        forward = rouge.score(chosen_text, reject_text)["rougeL"].fmeasure
        backward = rouge.score(reject_text, chosen_text)["rougeL"].fmeasure
        scores[pos] = (forward + backward) * 0.5
    return scores.astype(np.float32)
# ========= 主流程 =========
def main():
    """Run the filtering pipeline: load, clean, score, filter, save, preview.

    Steps:
        1. Read ``DATA_PATH`` and verify the three required columns exist.
        2. Normalize the three columns with ``norm_text`` and drop rows whose
           chosen or reject text is empty.
        3. Compute the symmetric BERTScore F1 and symmetric ROUGE-L F1.
        4. Keep rows whose scores fall inside the inclusive
           ``[BERT_LO, BERT_HI]`` and ``[ROUGE_LO, ROUGE_HI]`` ranges.
        5. Write the kept rows to ``OUTPUT_PATH`` and print a summary plus a
           random preview of up to 5 rows.

    Raises:
        ValueError: If a required column is missing, or if no row has both a
            non-empty chosen and a non-empty reject text.
    """
    df = pd.read_parquet(DATA_PATH)
    required_cols = [CHOSEN_PROMPT_COL, CHOSEN_COL, REJECT_COL]
    for col in required_cols:
        if col not in df.columns:
            raise ValueError(f"输入文件缺少列:{col}")

    # Keep only the three columns and normalize every cell to a stripped string.
    df = df[required_cols].copy()
    for col in required_cols:
        df[col] = df[col].map(norm_text)

    # A pair is only usable when both responses are non-empty.
    mask = (df[CHOSEN_COL].str.len() > 0) & (df[REJECT_COL].str.len() > 0)
    df = df[mask].reset_index(drop=True)
    if df.empty:
        raise ValueError("过滤后没有有效样本。")

    chosen_list = df[CHOSEN_COL].tolist()
    reject_list = df[REJECT_COL].tolist()

    bert_f1_sym = compute_bert_symmetric_f1(
        chosen_list, reject_list, lang=LANG,
        model_type=BERTSCORE_MODEL, batch_size=BATCH_SIZE
    )
    rougeL_f1_sym = compute_rougeL_symmetric_f1(
        chosen_list, reject_list, use_stemmer=True
    )

    keep = (
        (bert_f1_sym >= BERT_LO) & (bert_f1_sym <= BERT_HI) &
        (rougeL_f1_sym >= ROUGE_LO) & (rougeL_f1_sym <= ROUGE_HI)
    )
    kept_df = df[keep].reset_index(drop=True)

    # Ensure the destination directory exists so the result of the expensive
    # scoring above is not lost to a missing-directory error at write time.
    os.makedirs(os.path.dirname(os.path.abspath(OUTPUT_PATH)), exist_ok=True)
    kept_df.to_parquet(OUTPUT_PATH, index=False)

    print(f"[Info] 原始样本数: {len(df)}")
    print(f"[Info] 保留样本数: {len(kept_df)} (保留率 {len(kept_df)/len(df):.2%})")
    print(f"[Info] 已保存到: {os.path.abspath(OUTPUT_PATH)}")

    show_n = min(5, len(kept_df))
    if show_n > 0:
        # Report the number of rows actually shown, which may be fewer than 5.
        print(f"\n[Sample] 随机抽样 {show_n} 条:")
        print(
            kept_df.sample(show_n, random_state=42)
            .to_string(index=False, max_colwidth=80)
        )
    else:
        print("[Warn] 过滤后无样本,请调整阈值。")


if __name__ == "__main__":
    main()
|