File size: 5,100 Bytes
d8a76be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from argparse import Namespace
import pandas as pd
from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser
import wandb

# === Prompt-template fragments matching the model card ===
PREFIX = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
SUFFIX = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"

DATA_PATH = "/home/data/test_transformed_v1.parquet"  # path to the evaluation parquet file
WANDB_PROJECT = "reranker_eval_wrong"
WANDB_RUN_NAME = "qwen3_seqcls_scoring"

def format_query(chosen_prompt: str) -> str:
    """Build the reranker query prompt from *chosen_prompt*.

    The whole chosen_prompt is embedded verbatim as the <Query> section
    (no extraction is performed), preceded by the fixed scoring instruction.
    """
    instruction = "Given a roleplay prompt and recent context, score candidate replies higher when they stay in character, continue the scene coherently, and feel vivid and engaging."
    return PREFIX + "<Instruct>: " + instruction + "\n<Query>:" + chosen_prompt + "\n"

def format_document(doc_text: str) -> str:
    """Render a candidate reply as the <Document> section, terminated by SUFFIX."""
    return "<Document>: " + doc_text + SUFFIX

def parse_args() -> Namespace:
    """Parse vLLM engine CLI arguments with reranker-scoring defaults preset."""
    arg_parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    arg_parser.set_defaults(
        model="deeppin/Qwen3-Reranker-8B-SequenceClassification",
        task="score",
        enforce_eager=True,
        trust_remote_code=True,
    )
    return arg_parser.parse_args()

def main(args: Namespace):
    """Evaluate a pairwise reranker on (chosen, reject) reply pairs.

    For each row, the chosen reply should score higher than the rejected one.
    Logs per-sample scores, running accuracy, final summary metrics, and all
    mis-ranked samples to Weights & Biases.

    Args:
        args: Parsed vLLM engine arguments (see ``parse_args``).
    """
    # 1) Load the evaluation data.
    df = pd.read_parquet(DATA_PATH)

    wandb.init(project=WANDB_PROJECT, name=WANDB_RUN_NAME)
    wandb.config.update({"model": args.model, "data_path": DATA_PATH})

    try:
        # 2) Initialize the model.
        llm = LLM(**vars(args))

        # 3) Score each pair while tracking running accuracy.
        correct = 0
        total = 0
        wrong_samples = []
        for i, row in df.iterrows():
            chosen_prompt = row["chosen_prompt"]
            chosen = row["chosen"]
            reject = row["reject"]

            # Skip samples with missing or empty fields.
            if not all(isinstance(v, str) for v in (chosen_prompt, chosen, reject)):
                continue
            if not chosen.strip() or not reject.strip():
                continue

            q = format_query(chosen_prompt)
            d1 = format_document(chosen)
            d2 = format_document(reject)

            # Keep the try body minimal: only the scoring call can legitimately
            # fail per-sample; a failure is logged and the sample is skipped so
            # one bad row never aborts the whole run.
            try:
                # Pair the same query q with each of the two candidates.
                outs = llm.score([q, q], [d1, d2])
                s1, s2 = (o.outputs.score for o in outs)
            except Exception as e:
                print(f"[Error] index={i}: {e}")
                continue

            chosen_better = s1 > s2
            total += 1
            if chosen_better:
                correct += 1
            running_acc = correct / total  # total >= 1 here
            margin = float(s1 - s2)

            # Print each sample's scores and whether the ranking was correct.
            print(
                {"chosen_score": s1, "reject_score": s2, "chosen_better": chosen_better},
                f"[RunningAcc] {correct}/{total} = {running_acc:.4f}",
            )
            wandb.log({
                "metric/running_acc": running_acc,
                "score/chosen": float(s1),
                "score/reject": float(s2),
                "score/margin": margin,
            }, step=total)
            if not chosen_better:
                wrong_samples.append({
                    "index": int(i),
                    "chosen_score": float(s1),
                    "reject_score": float(s2),
                    "margin": margin,
                    "chosen_prompt": chosen_prompt,
                    "chosen": chosen,
                    "reject": reject,
                })

        # 4) Report the final accuracy.
        final_acc = correct / total if total > 0 else 0.0
        print(f"[FinalAcc] {correct}/{total} = {final_acc:.4f}")
        wandb.summary["final/accuracy"] = final_acc
        wandb.summary["final/total"] = total
        wandb.summary["final/correct"] = correct
        wandb.summary["final/wrong"] = len(wrong_samples)
        if wrong_samples:
            _upload_wrong_samples(wrong_samples)
    finally:
        # Always close the W&B run, even if model init or scoring crashed.
        wandb.finish()


def _upload_wrong_samples(wrong_samples: list) -> None:
    """Upload mis-ranked samples as a W&B table plus a best-effort CSV artifact."""
    columns = [
        "index", "chosen_score", "reject_score", "margin",
        "chosen_prompt", "chosen", "reject",
    ]
    table = wandb.Table(columns=columns)
    for r in wrong_samples:
        table.add_data(*(r[c] for c in columns))
    wandb.log({"errors/wrong_samples": table})
    # Also persist the samples as a CSV artifact; never let this fail the run,
    # but surface the failure instead of swallowing it silently.
    try:
        pd.DataFrame(wrong_samples).to_csv("wrong_samples.csv", index=False)
        art = wandb.Artifact("wrong_samples", type="dataset")
        art.add_file("wrong_samples.csv")
        wandb.log_artifact(art)
    except Exception as e:
        print(f"[Warn] wrong_samples artifact upload failed: {e}")

if __name__ == "__main__":
    args = parse_args()
    main(args)