# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
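"""Evaluate a Qwen3-Reranker sequence-classification checkpoint with vLLM's score API.

For each row of a preference dataset with columns ``chosen_prompt``, ``chosen``,
and ``reject``, score both candidate replies against the same query and count how
often the chosen reply outscores the rejected one. Per-sample scores, the running
accuracy, and all misclassified samples are logged to Weights & Biases.

Any vLLM EngineArgs flag can be passed on the command line, e.g. (hypothetical
invocation): python <this_script>.py --tensor-parallel-size 2
"""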

from argparse import Namespace

import pandas as pd
import wandb

from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser

# === Template fragments matching the model card ===
PREFIX = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
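# The suffix closes the user turn and opens the assistant turn with an empty
# <think> block, matching the Qwen3-Reranker chat template.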
SUFFIX = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
DATA_PATH = "/home/data/test_transformed_v1.parquet"  # path to the evaluation data
WANDB_PROJECT = "reranker_eval_wrong"
WANDB_RUN_NAME = "qwen3_seqcls_scoring"


def format_query(chosen_prompt: str) -> str:
    # Use the whole chosen_prompt as the Query verbatim (no extraction).
instruction = (
"Given a roleplay prompt and recent context, score candidate replies higher when they stay in character, continue the scene coherently, and feel vivid and engaging."
)
    return f"{PREFIX}<Instruct>: {instruction}\n<Query>: {chosen_prompt}\n"


def format_document(doc_text: str) -> str:
    # Wrap the candidate text as the <Document> and append the SUFFIX.
    return f"<Document>: {doc_text}{SUFFIX}"


def parse_args():
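    """Build a CLI parser from vLLM's EngineArgs, defaulting to the reranker model."""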
parser = FlexibleArgumentParser()
parser = EngineArgs.add_cli_args(parser)
parser.set_defaults(
model="deeppin/Qwen3-Reranker-8B-SequenceClassification",
task="score",
enforce_eager=True,
trust_remote_code=True,
)
return parser.parse_args()


def main(args: Namespace):
    # 1) Load the evaluation data.
df = pd.read_parquet(DATA_PATH)
wandb.init(project=WANDB_PROJECT, name=WANDB_RUN_NAME)
wandb.config.update({"model": args.model, "data_path": DATA_PATH})
    # 2) Initialize the model.
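    # task="score" (defaulted in parse_args) runs the checkpoint as a
    # cross-encoder, which is what makes llm.score() available.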
llm = LLM(**vars(args))
    # 3) Score each sample and track the running accuracy.
correct = 0
total = 0
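    # Misclassified samples are collected for the wandb table / CSV upload below.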
wrong_samples = []
for i, row in df.iterrows():
chosen_prompt = row["chosen_prompt"]
chosen = row["chosen"]
reject = row["reject"]
        # Skip the sample if any field is missing or empty.
if not isinstance(chosen_prompt, str) or not isinstance(chosen, str) or not isinstance(reject, str):
continue
if chosen.strip() == "" or reject.strip() == "":
continue
q = format_query(chosen_prompt)
d1 = format_document(chosen)
d2 = format_document(reject)
try:
            # Score the same query q paired with d1 and d2.
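            # llm.score pairs text_1[i] with text_2[i] and returns one output per pair.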
outs = llm.score([q, q], [d1, d2])
            # Each ScoringRequestOutput exposes its score at .outputs.score.
s1, s2 = (o.outputs.score for o in outs)
chosen_better = (s1 > s2)
total += 1
if chosen_better:
correct += 1
running_acc = correct / total if total > 0 else 0.0
            # Print this sample's scores and the running accuracy.
            print(
                {"chosen_score": s1, "reject_score": s2, "chosen_better": chosen_better},
                f"[RunningAcc] {correct}/{total} = {running_acc:.4f}",
            )
wandb.log({
"metric/running_acc": running_acc,
"score/chosen": float(s1),
"score/reject": float(s2),
"score/margin": float(s1 - s2),
}, step=total)
if not chosen_better:
wrong_samples.append({
"index": int(i),
"chosen_score": float(s1),
"reject_score": float(s2),
"margin": float(s1 - s2),
"chosen_prompt": chosen_prompt,
"chosen": chosen,
"reject": reject,
})
except Exception as e:
            # Don't stop the run on errors: log the failure and continue.
print(f"[Error] index={i}: {e}")
    # 4) Report the final accuracy.
final_acc = correct / total if total > 0 else 0.0
print(f"[FinalAcc] {correct}/{total} = {final_acc:.4f}")
wandb.summary["final/accuracy"] = final_acc
wandb.summary["final/total"] = total
wandb.summary["final/correct"] = correct
wandb.summary["final/wrong"] = len(wrong_samples)
    # Upload the misclassified samples as a wandb table.
if wrong_samples:
table = wandb.Table(columns=[
"index", "chosen_score", "reject_score", "margin",
"chosen_prompt", "chosen", "reject"
])
for r in wrong_samples:
table.add_data(
r["index"], r["chosen_score"], r["reject_score"], r["margin"],
r["chosen_prompt"], r["chosen"], r["reject"]
)
wandb.log({"errors/wrong_samples": table})
        # Also save the misclassified samples as a CSV artifact (optional).
try:
_df = pd.DataFrame(wrong_samples)
_df.to_csv("wrong_samples.csv", index=False)
art = wandb.Artifact("wrong_samples", type="dataset")
art.add_file("wrong_samples.csv")
wandb.log_artifact(art)
        except Exception as e:
            print(f"[Warn] Failed to upload the wrong_samples artifact: {e}")
wandb.finish()


if __name__ == "__main__":
args = parse_args()
main(args)