# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from argparse import Namespace
import pandas as pd
from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser
import wandb
# === Prompt-template fragments matching the model card ===
PREFIX = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
SUFFIX = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
DATA_PATH = "/home/data/test_transformed_v1.parquet" # path to the evaluation parquet file
WANDB_PROJECT = "reranker_eval_wrong"
WANDB_RUN_NAME = "qwen3_seqcls_scoring"
def format_query(chosen_prompt: str) -> str:
    """Wrap *chosen_prompt* as the <Query> section of the reranker prompt.

    The whole chosen_prompt is used verbatim as the query (nothing is
    extracted from it).
    """
    task_instruction = "Given a roleplay prompt and recent context, score candidate replies higher when they stay in character, continue the scene coherently, and feel vivid and engaging."
    parts = (PREFIX, "<Instruct>: ", task_instruction, "\n<Query>:", chosen_prompt, "\n")
    return "".join(parts)
def format_document(doc_text: str) -> str:
    """Wrap a candidate reply as the <Document> section, followed by SUFFIX."""
    return "<Document>: " + doc_text + SUFFIX
def parse_args():
    """Build the vLLM engine CLI parser and return the parsed arguments.

    Defaults are preset for scoring with the Qwen3 reranker checkpoint;
    any of them can still be overridden on the command line.
    """
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    defaults = {
        "model": "deeppin/Qwen3-Reranker-8B-SequenceClassification",
        "task": "score",
        "enforce_eager": True,
        "trust_remote_code": True,
    }
    parser.set_defaults(**defaults)
    return parser.parse_args()
def main(args: Namespace):
    """Score (chosen, reject) pairs with a reranker and log accuracy to W&B.

    For every row, the same formatted query is scored against the chosen and
    the rejected candidate; a sample counts as correct when the chosen
    candidate scores strictly higher. Running accuracy, per-sample scores,
    and all misranked samples are logged to Weights & Biases.
    """
    # 1) Load the evaluation data.
    df = pd.read_parquet(DATA_PATH)
    wandb.init(project=WANDB_PROJECT, name=WANDB_RUN_NAME)
    wandb.config.update({"model": args.model, "data_path": DATA_PATH})
    # 2) Initialize the model.
    llm = LLM(**vars(args))
    try:
        # 3) Score every sample, tracking a running accuracy.
        correct, total, wrong_samples = _score_all(llm, df)
        # 4) Report the final accuracy.
        final_acc = correct / total if total > 0 else 0.0
        print(f"[FinalAcc] {correct}/{total} = {final_acc:.4f}")
        wandb.summary["final/accuracy"] = final_acc
        wandb.summary["final/total"] = total
        wandb.summary["final/correct"] = correct
        wandb.summary["final/wrong"] = len(wrong_samples)
        _log_wrong_samples(wrong_samples)
    finally:
        # Always close the W&B run, even if scoring or logging crashes midway;
        # the original code left the run unfinished on any exception.
        wandb.finish()


def _score_all(llm, df):
    """Score each row of *df*; return (correct, total, wrong_samples).

    Rows with missing or blank fields are skipped. Per-row scoring errors are
    printed and skipped (best-effort), never propagated.
    """
    correct = 0
    total = 0
    wrong_samples = []
    for i, row in df.iterrows():
        chosen_prompt = row["chosen_prompt"]
        chosen = row["chosen"]
        reject = row["reject"]
        # Skip samples with missing (non-string, e.g. NaN) or empty fields.
        if not all(isinstance(v, str) for v in (chosen_prompt, chosen, reject)):
            continue
        if not chosen.strip() or not reject.strip():
            continue
        q = format_query(chosen_prompt)
        d1 = format_document(chosen)
        d2 = format_document(reject)
        try:
            # Score the same query against both candidates in one call.
            outs = llm.score([q, q], [d1, d2])
            s1, s2 = (o.outputs.score for o in outs)
            chosen_better = s1 > s2
            total += 1
            if chosen_better:
                correct += 1
            running_acc = correct / total
            # Print per-sample scores together with the running accuracy.
            print(
                {"chosen_score": s1, "reject_score": s2, "chosen_better": chosen_better},
                f"[RunningAcc] {correct}/{total} = {running_acc:.4f}",
            )
            wandb.log({
                "metric/running_acc": running_acc,
                "score/chosen": float(s1),
                "score/reject": float(s2),
                "score/margin": float(s1 - s2),
            }, step=total)
            if not chosen_better:
                wrong_samples.append({
                    "index": int(i),
                    "chosen_score": float(s1),
                    "reject_score": float(s2),
                    "margin": float(s1 - s2),
                    "chosen_prompt": chosen_prompt,
                    "chosen": chosen,
                    "reject": reject,
                })
        except Exception as e:
            # Best-effort: record the failure and keep scoring.
            print(f"[Error] index={i}: {e}")
    return correct, total, wrong_samples


def _log_wrong_samples(wrong_samples):
    """Upload misranked samples as a W&B table plus a best-effort CSV artifact."""
    if not wrong_samples:
        return
    columns = [
        "index", "chosen_score", "reject_score", "margin",
        "chosen_prompt", "chosen", "reject",
    ]
    table = wandb.Table(columns=columns)
    for r in wrong_samples:
        table.add_data(*(r[c] for c in columns))
    wandb.log({"errors/wrong_samples": table})
    # Also persist the same rows as a CSV artifact (optional; failures ignored).
    try:
        pd.DataFrame(wrong_samples).to_csv("wrong_samples.csv", index=False)
        art = wandb.Artifact("wrong_samples", type="dataset")
        art.add_file("wrong_samples.csv")
        wandb.log_artifact(art)
    except Exception:
        pass
if __name__ == "__main__":
    main(parse_args())