# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from argparse import Namespace

import pandas as pd
import wandb

from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser

# === Template fragments matching the model card ===
PREFIX = (
    '<|im_start|>system\nJudge whether the Document meets the requirements '
    'based on the Query and the Instruct provided. Note that the answer can '
    'only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
)
SUFFIX = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"

DATA_PATH = "/home/data/test_transformed_v1.parquet"  # evaluation data path
WANDB_PROJECT = "reranker_eval_wrong"
WANDB_RUN_NAME = "qwen3_seqcls_scoring"


def format_query(chosen_prompt: str) -> str:
    # Use the whole chosen_prompt verbatim as the Query (no extraction).
    instruction = (
        "Given a roleplay prompt and recent context, score candidate replies "
        "higher when they stay in character, continue the scene coherently, "
        "and feel vivid and engaging."
    )
    return f"{PREFIX}<Instruct>: {instruction}\n<Query>: {chosen_prompt}\n"


def format_document(doc_text: str) -> str:
    # Wrap the candidate text as the <Document> and append the SUFFIX.
    return f"<Document>: {doc_text}{SUFFIX}"


def parse_args():
    parser = FlexibleArgumentParser()
    parser = EngineArgs.add_cli_args(parser)
    parser.set_defaults(
        model="deeppin/Qwen3-Reranker-8B-SequenceClassification",
        task="score",
        enforce_eager=True,
        trust_remote_code=True,
    )
    return parser.parse_args()


def main(args: Namespace):
    # 1) Load the data.
    df = pd.read_parquet(DATA_PATH)

    wandb.init(project=WANDB_PROJECT, name=WANDB_RUN_NAME)
    wandb.config.update({"model": args.model, "data_path": DATA_PATH})

    # 2) Initialize the model.
    llm = LLM(**vars(args))

    # 3) Score each sample and keep a running accuracy.
    correct = 0
    total = 0
    wrong_samples = []

    for i, row in df.iterrows():
        chosen_prompt = row["chosen_prompt"]
        chosen = row["chosen"]
        reject = row["reject"]

        # Skip the sample if any field is missing or empty.
        if (
            not isinstance(chosen_prompt, str)
            or not isinstance(chosen, str)
            or not isinstance(reject, str)
        ):
            continue
        if chosen.strip() == "" or reject.strip() == "":
            continue

        q = format_query(chosen_prompt)
        d1 = format_document(chosen)
        d2 = format_document(reject)

        try:
            # Pair the same query q with d1 and d2 for scoring.
            outs = llm.score([q, q], [d1, d2])
            # Read the score from each output.
            s1, s2 = (o.outputs.score for o in outs)

            chosen_better = s1 > s2
            total += 1
            if chosen_better:
                correct += 1
            running_acc = correct / total if total > 0 else 0.0

            # Print the scores and correctness for each sample.
            print(
                {
                    "chosen_score": s1,
                    "reject_score": s2,
                    "chosen_better": chosen_better,
                },
                f"[RunningAcc] {correct}/{total} = {running_acc:.4f}",
            )

            wandb.log(
                {
                    "metric/running_acc": running_acc,
                    "score/chosen": float(s1),
                    "score/reject": float(s2),
                    "score/margin": float(s1 - s2),
                },
                step=total,
            )

            if not chosen_better:
                wrong_samples.append(
                    {
                        "index": int(i),
                        "chosen_score": float(s1),
                        "reject_score": float(s2),
                        "margin": float(s1 - s2),
                        "chosen_prompt": chosen_prompt,
                        "chosen": chosen,
                        "reject": reject,
                    }
                )
        except Exception as e:
            # Don't let one bad sample stop the run: log the error and continue.
            print(f"[Error] index={i}: {e}")

    # 4) Report the final accuracy.
    final_acc = correct / total if total > 0 else 0.0
    print(f"[FinalAcc] {correct}/{total} = {final_acc:.4f}")

    wandb.summary["final/accuracy"] = final_acc
    wandb.summary["final/total"] = total
    wandb.summary["final/correct"] = correct
    wandb.summary["final/wrong"] = len(wrong_samples)

    # Upload the misjudged samples as a wandb table.
    if wrong_samples:
        table = wandb.Table(
            columns=[
                "index",
                "chosen_score",
                "reject_score",
                "margin",
                "chosen_prompt",
                "chosen",
                "reject",
            ]
        )
        for r in wrong_samples:
            table.add_data(
                r["index"],
                r["chosen_score"],
                r["reject_score"],
                r["margin"],
                r["chosen_prompt"],
                r["chosen"],
                r["reject"],
            )
        wandb.log({"errors/wrong_samples": table})

        # Also save the wrong samples as a CSV artifact (optional).
        try:
            _df = pd.DataFrame(wrong_samples)
_df.to_csv("wrong_samples.csv", index=False) art = wandb.Artifact("wrong_samples", type="dataset") art.add_file("wrong_samples.csv") wandb.log_artifact(art) except Exception: pass wandb.finish() if __name__ == "__main__": args = parse_args() main(args)