# -*- coding: utf-8 -*-
# rm_code/simi_score.py — uploaded by hahayang012 via huggingface_hub (commit d8a76be, verified).
"""
对称比较直方图(仅输出 PNG):
- 同时计算 chosen→reject 与 reject→chosen 的 BERTScore-F1 与 ROUGE-L F1;
- 在每个指标上做方向平均(对称分数);
- 将两种指标的直方图画在同一张 PNG 中保存;
- 直接运行脚本(无需命令行参数)。
"""
import os
import math
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib
matplotlib.use("Agg") # 适配无 GUI 环境
import matplotlib.pyplot as plt
from bert_score import score as bertscore
from rouge_score import rouge_scorer
# ========= Configuration (edit as needed) =========
DATA_PATH: str = "/home/data/prefiltered.parquet"  # input parquet file
CHOSEN_COL: str = "chosen"  # column holding the "chosen" texts
REJECT_COL: str = "reject"  # column holding the "reject" texts
LANG: str = "en"  # BERTScore language code (use "zh" for Chinese)
BERTSCORE_MODEL: str = "roberta-large"  # for Chinese e.g. "hfl/chinese-roberta-wwm-ext"
BATCH_SIZE: int = 256  # outer chunk size; only used by the BERTScore step
BERT_BATCH_CAP: int = 64  # upper bound on bert-score's per-forward batch (OOM guard)
PNG_PATH: str = "symmetric_metrics_hist.png"  # output histogram image
# ========= 工具函数 =========
def norm_text(x):
    """Normalize a raw cell value to a whitespace-stripped string.

    Missing values (``None``, Python/NumPy NaN, ``pd.NA``, ``pd.NaT``) map to
    ``""`` so that the caller's non-empty filter drops them; every other value
    is converted with ``str()`` and stripped.

    Args:
        x: A single cell value from the input DataFrame.

    Returns:
        The normalized string (possibly empty).
    """
    if x is None:
        return ""
    # pd.isna also recognises pd.NA, pd.NaT and NumPy NaN scalars such as
    # np.float32('nan') (math.isnan on `float` missed these, so they leaked
    # through as "<NA>"/"NaT"/"nan"). Restrict it to scalars: on list-like
    # cells pd.isna returns an array whose truth value is ambiguous.
    if pd.api.types.is_scalar(x) and pd.isna(x):
        return ""
    return str(x).strip()
def compute_bert_symmetric_f1(chosen_list, reject_list, lang, model_type, batch_size):
    """Compute the symmetric BERTScore-F1 for each (chosen, reject) pair.

        F1_sym = 0.5 * (F1(chosen->reject) + F1(reject->chosen))

    BERTScore F1 is the harmonic mean of precision and recall. Swapping the
    candidate and reference only swaps P and R, and the baseline rescaling of
    F uses the F baseline in both directions. Both directions therefore have
    the same F1, and F1_sym equals a single-direction F1. One scoring pass per
    chunk is enough; scoring both directions only doubled the cost.

    The model is loaded once through ``BERTScorer`` and reused for every chunk.
    The functional ``bert_score.score`` API reloads it on each call.

    Args:
        chosen_list: Sequence of "chosen" texts.
        reject_list: Sequence of "reject" texts, same length as ``chosen_list``.
        lang: Language code passed to bert-score; it also selects the rescale
            baseline.
        model_type: Model name passed to bert-score.
        batch_size: Outer chunk size. The per-forward batch is additionally
            capped at ``BERT_BATCH_CAP``.

    Returns:
        numpy.float32 array with one baseline-rescaled F1 per pair.

    Raises:
        ValueError: If the two input sequences differ in length.
    """
    if len(chosen_list) != len(reject_list):
        raise ValueError(
            f"chosen_list and reject_list differ in length: "
            f"{len(chosen_list)} != {len(reject_list)}"
        )
    n = len(chosen_list)
    out_f1 = np.zeros(n, dtype=np.float32)
    if n == 0:
        return out_f1  # nothing to score; avoid loading the model at all

    # Local import: the module-level import only exposes the functional API.
    from bert_score import BERTScorer

    inner_batch = min(BERT_BATCH_CAP, batch_size)
    scorer = BERTScorer(
        model_type=model_type,
        lang=lang,
        rescale_with_baseline=True,
        batch_size=inner_batch,
    )
    for start in tqdm(range(0, n, batch_size), desc="BERTScore Symmetric"):
        end = min(start + batch_size, n)
        _, _, f1 = scorer.score(
            chosen_list[start:end],
            reject_list[start:end],
            verbose=False,
            batch_size=inner_batch,
        )
        out_f1[start:end] = f1.cpu().numpy().astype(np.float32)
    return out_f1
def compute_rougeL_symmetric_f1(chosen_list, reject_list, use_stemmer=True):
    """Compute the symmetric ROUGE-L F1 for each (chosen, reject) pair.

        F1_sym = 0.5 * (F1(chosen->reject) + F1(reject->chosen))

    ROUGE-L F-measure is exactly symmetric. The LCS is the same in both
    directions, and swapping target and prediction only swaps precision and
    recall inside ``2*p*r / (p+r)``. A single ``scorer.score`` call therefore
    yields F1_sym bit-for-bit, so the second direction is not computed.

    NOTE(review): rouge_score's default tokenizer keeps only [a-z0-9] tokens,
    so non-Latin text (e.g. Chinese with LANG="zh") likely scores 0. Confirm
    before relying on this metric for such data.

    Args:
        chosen_list: Sequence of "chosen" texts.
        reject_list: Sequence of "reject" texts, same length as ``chosen_list``.
        use_stemmer: Forwarded to ``rouge_scorer.RougeScorer``.

    Returns:
        numpy.float32 array with one ROUGE-L F1 per pair.

    Raises:
        ValueError: If the two input sequences differ in length.
    """
    if len(chosen_list) != len(reject_list):
        raise ValueError(
            f"chosen_list and reject_list differ in length: "
            f"{len(chosen_list)} != {len(reject_list)}"
        )
    out = np.zeros(len(chosen_list), dtype=np.float32)
    if len(chosen_list) == 0:
        return out

    scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=use_stemmer)
    pairs = zip(chosen_list, reject_list)
    for i, (c, r) in enumerate(
        tqdm(pairs, total=len(chosen_list), desc="ROUGE-L Symmetric")
    ):
        # rouge_scorer.score(target, prediction); fmeasure is direction-independent.
        out[i] = scorer.score(c, r)["rougeL"].fmeasure
    return out
# ========= 主流程 =========
def _load_pairs():
    """Read DATA_PATH and return (chosen_list, reject_list) with empty rows dropped."""
    df = pd.read_parquet(DATA_PATH)
    missing = [col for col in (CHOSEN_COL, REJECT_COL) if col not in df.columns]
    if missing:
        # Report only the columns that are actually missing, with a separator
        # (the old message glued both names together unconditionally).
        missing_str = "、".join(missing)
        raise ValueError(f"输入文件缺少列:{missing_str}")
    chosen = df[CHOSEN_COL].map(norm_text)
    reject = df[REJECT_COL].map(norm_text)
    # Keep only rows where both sides are non-empty after normalization.
    keep = (chosen.str.len() > 0) & (reject.str.len() > 0)
    chosen_list = chosen[keep].tolist()
    reject_list = reject[keep].tolist()
    if not chosen_list:
        raise ValueError("过滤后没有有效样本。请检查输入列内容。")
    return chosen_list, reject_list


def _plot_histograms(bert_scores, rouge_scores, png_path, n_bins=30):
    """Draw the BERTScore and ROUGE-L histograms side by side and save to png_path."""
    fig, (ax_bert, ax_rouge) = plt.subplots(1, 2, figsize=(12, 5))
    panels = (
        (ax_bert, bert_scores, "blue",
         "Distribution of F1 BERT Scores", "F1 BERT Score"),
        (ax_rouge, rouge_scores, "green",
         "Distribution of F1 ROUGE-L Scores", "F1 ROUGE-L Score"),
    )
    for ax, scores, color, title, xlabel in panels:
        # Drop NaN/inf so numpy can infer a finite bin range.
        finite = scores[np.isfinite(scores)]
        # An integer `bins` gives n_bins equal-width bins over [min, max] and
        # widens a degenerate all-equal range. The previous
        # np.linspace(min, max, 30) produced only 29 bins, and zero-width bins
        # when min == max.
        ax.hist(finite, bins=n_bins, color=color, alpha=0.7, edgecolor="black")
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel("Frequency")
    fig.tight_layout()
    fig.savefig(png_path, dpi=300)
    plt.close(fig)  # release the figure's memory


def main():
    """Run the pipeline: load pairs, score both symmetric metrics, save the PNG."""
    chosen_list, reject_list = _load_pairs()

    # 1) Symmetric BERTScore-F1
    berts_f1_sym = compute_bert_symmetric_f1(
        chosen_list, reject_list,
        lang=LANG,
        model_type=BERTSCORE_MODEL,
        batch_size=BATCH_SIZE,
    )
    # 2) Symmetric ROUGE-L F1
    rougeL_f1_sym = compute_rougeL_symmetric_f1(
        chosen_list, reject_list, use_stemmer=True
    )
    # 3) Both histograms in one PNG
    _plot_histograms(berts_f1_sym, rougeL_f1_sym, PNG_PATH)
    print(f"[Info] 直方图已保存:{os.path.abspath(PNG_PATH)}")
if __name__ == "__main__":
    # Script entry point: runs the whole pipeline; no command-line arguments are read.
    main()