File size: 4,610 Bytes
d8a76be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# -*- coding: utf-8 -*-
"""
过滤规则:
- 保留:BERTScore-F1_sym ∈ [0.05, 0.40] 且 ROUGE-L_F1_sym ∈ [0.05, 0.35]
- 其余样本丢弃
- 导出 parquet(仅保留 chosen_prompt、chosen、reject 三列)
- 随机抽样打印 5 条样本
"""
import os
import math
import numpy as np
import pandas as pd
from tqdm import tqdm

from bert_score import score as bertscore
from rouge_score import rouge_scorer

# ========= Configuration =========
# Input dataset and filtered output (both parquet).
DATA_PATH = "/home/data/train_10k_sys_3round.parquet"
OUTPUT_PATH = "/home/data/filtered_v1.parquet"

# Required input columns; these three are also the only columns written out.
CHOSEN_PROMPT_COL = "chosen_prompt"
CHOSEN_COL = "chosen"
REJECT_COL = "reject"

# BERTScore settings.
LANG = "en"                        # passed to bert_score as `lang`
BERTSCORE_MODEL = "roberta-large"  # passed to bert_score as `model_type`
BATCH_SIZE = 256                   # pairs per outer scoring chunk
BERT_BATCH_CAP = 64                # upper bound on bert_score's own batch size

# Inclusive keep-windows applied to the per-pair similarity scores.
BERT_LO = 0.05
BERT_HI = 0.35
ROUGE_LO = 0.05
ROUGE_HI = 0.30


# ========= Helpers =========
def norm_text(x):
    """Normalize one raw cell value to a whitespace-stripped string.

    Missing scalars (``None``, float / NumPy NaN, ``pd.NA``, ``pd.NaT``) map to
    ``""`` so the caller's emptiness filter drops them. Any other value is
    converted with ``str()`` and stripped.

    Args:
        x: A single cell value taken from the input DataFrame.

    Returns:
        The normalized string, possibly empty.
    """
    # ``pd.isna`` also recognises pd.NA, NaT and np.float32 NaN, which a plain
    # ``isinstance(x, float)`` check misses (they would become "<NA>", "NaT",
    # "nan" and slip past the emptiness filter). ``is_scalar`` guards against
    # list/array cells, for which ``pd.isna`` returns an array whose truth
    # value is ambiguous.
    if pd.api.types.is_scalar(x) and pd.isna(x):
        return ""
    return str(x).strip()

def compute_bert_symmetric_f1(chosen_list, reject_list, lang, model_type, batch_size):
    """Compute the baseline-rescaled BERTScore F1 for each (chosen, reject) pair.

    BERTScore F1 is 2PR / (P + R). Swapping candidate and reference only swaps
    P and R, so without IDF weighting (not enabled here) F1(c, r) == F1(r, c).
    Averaging both directions would double the work for the same number, so
    one direction is scored and the result is already symmetric.

    Args:
        chosen_list: Candidate texts.
        reject_list: Reference texts, paired element-wise with ``chosen_list``.
        lang: Language code handed to bert_score (used for baseline rescaling).
        model_type: bert_score model name, e.g. "roberta-large".
        batch_size: Pairs scored per outer chunk (one progress-bar step).
            bert_score's internal batch size is ``min(BERT_BATCH_CAP, batch_size)``.

    Returns:
        float32 ``np.ndarray`` with one F1 value per pair.

    Raises:
        ValueError: if the two lists have different lengths.
    """
    if len(chosen_list) != len(reject_list):
        raise ValueError(
            f"length mismatch: {len(chosen_list)} chosen vs {len(reject_list)} reject"
        )
    n = len(chosen_list)
    out_f1 = np.zeros(n, dtype=np.float32)
    if n == 0:
        # Nothing to score: skip loading the model entirely.
        return out_f1

    # Build the scorer once. The functional ``bert_score.score`` reloads the
    # tokenizer, model and baseline table on every call, which previously
    # happened twice per chunk.
    from bert_score import BERTScorer

    inner_bs = min(BERT_BATCH_CAP, batch_size)
    scorer = BERTScorer(
        model_type=model_type,
        lang=lang,
        rescale_with_baseline=True,
        batch_size=inner_bs,
    )

    for start in tqdm(range(0, n, batch_size), desc="BERTScore Symmetric"):
        end = min(start + batch_size, n)
        _, _, f1 = scorer.score(
            chosen_list[start:end],
            reject_list[start:end],
            verbose=False,
            batch_size=inner_bs,
        )
        # Slice assignment writes the chunk in place and casts to float32.
        out_f1[start:end] = f1.cpu().numpy()

    return out_f1

def compute_rougeL_symmetric_f1(chosen_list, reject_list, use_stemmer=True):
    """Compute the ROUGE-L F-measure for each (chosen, reject) pair.

    rouge_score's ROUGE-L uses P = LCS / len(prediction), R = LCS / len(target)
    and F = 2PR / (P + R). Swapping the two texts only swaps P and R, so
    F(c, r) == F(r, c) exactly and a single ``score`` call per pair already
    yields the symmetric value. Scoring both directions and averaging would
    give the same number at twice the cost.

    Args:
        chosen_list: First text of each pair.
        reject_list: Second text of each pair, aligned with ``chosen_list``.
        use_stemmer: Forwarded to ``rouge_scorer.RougeScorer``.

    Returns:
        float32 ``np.ndarray`` with one ROUGE-L F-measure per pair.

    Raises:
        ValueError: if the two lists have different lengths.
    """
    if len(chosen_list) != len(reject_list):
        raise ValueError(
            f"length mismatch: {len(chosen_list)} chosen vs {len(reject_list)} reject"
        )
    n = len(chosen_list)
    out = np.zeros(n, dtype=np.float32)
    if n == 0:
        return out

    scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=use_stemmer)
    pairs = tqdm(zip(chosen_list, reject_list), total=n, desc="ROUGE-L Symmetric")
    for i, (c, r) in enumerate(pairs):
        out[i] = scorer.score(c, r)["rougeL"].fmeasure

    return out


# ========= Main pipeline =========
def _load_clean_df(path):
    """Load *path* and return ``(n_raw, df)`` with empty chosen/reject rows removed.

    ``n_raw`` is the row count as read from disk. ``df`` keeps only the
    prompt/chosen/reject columns, each cell normalized by ``norm_text``, and
    has a fresh 0..N-1 index.

    Raises:
        ValueError: if a required column is missing, or if no row has both a
            non-empty chosen and a non-empty reject text.
    """
    cols = [CHOSEN_PROMPT_COL, CHOSEN_COL, REJECT_COL]
    df = pd.read_parquet(path)
    for col in cols:
        if col not in df.columns:
            raise ValueError(f"输入文件缺少列:{col}")
    n_raw = len(df)

    # Keep only the three exported columns and normalize every cell.
    df = df[cols].copy()
    for col in cols:
        df[col] = df[col].map(norm_text)

    # After norm_text every cell is a str, so != "" is the same as len > 0.
    mask = (df[CHOSEN_COL] != "") & (df[REJECT_COL] != "")
    df = df[mask].reset_index(drop=True)

    if len(df) == 0:
        raise ValueError("过滤后没有有效样本。")
    return n_raw, df


def main():
    """Filter chosen/reject pairs by BERTScore and ROUGE-L similarity.

    Loads DATA_PATH and drops rows with empty chosen/reject text. Each pair is
    then scored and kept only if BERT_LO <= BERTScore-F1 <= BERT_HI and
    ROUGE_LO <= ROUGE-L-F1 <= ROUGE_HI. Kept rows (three columns only) are
    written to OUTPUT_PATH, followed by counts and up to 5 random kept rows.
    """
    n_raw, df = _load_clean_df(DATA_PATH)

    chosen_list = df[CHOSEN_COL].tolist()
    reject_list = df[REJECT_COL].tolist()

    bert_f1_sym = compute_bert_symmetric_f1(
        chosen_list, reject_list, lang=LANG,
        model_type=BERTSCORE_MODEL, batch_size=BATCH_SIZE
    )
    rougeL_f1_sym = compute_rougeL_symmetric_f1(
        chosen_list, reject_list, use_stemmer=True
    )

    # Inclusive windows on both scores; a row must satisfy all four bounds.
    keep = (
        (bert_f1_sym >= BERT_LO) & (bert_f1_sym <= BERT_HI) &
        (rougeL_f1_sym >= ROUGE_LO) & (rougeL_f1_sym <= ROUGE_HI)
    )
    kept_df = df[keep].reset_index(drop=True)

    # Make sure the destination directory exists before writing.
    out_path = os.path.abspath(OUTPUT_PATH)
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    kept_df.to_parquet(OUTPUT_PATH, index=False)

    # n_raw is the on-disk row count; len(df) is the number actually scored.
    print(f"[Info] 原始样本数: {n_raw}")
    print(f"[Info] 有效样本数(非空): {len(df)}")
    print(f"[Info] 保留样本数: {len(kept_df)} (保留率 {len(kept_df)/len(df):.2%})")
    print(f"[Info] 已保存到: {out_path}")

    show_n = min(5, len(kept_df))
    if show_n > 0:
        # The header reports the real number of rows shown (may be < 5).
        print(f"\n[Sample] 随机抽样 {show_n} 条:")
        print(
            kept_df.sample(show_n, random_state=42)
                   .to_string(index=False, max_colwidth=80)
        )
    else:
        print("[Warn] 过滤后无样本,请调整阈值。")


if __name__ == "__main__":
    main()