File size: 5,142 Bytes
d8a76be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import argparse
import time
import requests
import pandas as pd
from tqdm import tqdm

def build_text(prompt: str, answer: str, sep: str) -> str:
    """Concatenate *prompt*, *sep* and *answer*, then trim outer whitespace.

    A ``None`` prompt or answer is treated as the empty string; any other
    value is converted with ``str()`` before joining.
    """
    def _as_text(value):
        # None -> "", everything else -> its str() form.
        return "" if value is None else str(value)

    head, tail = _as_text(prompt), _as_text(answer)
    joined = head + sep + tail
    return joined.strip()

def post_batch(server_url: str, queries, timeout=120, max_retries=3, sleep=1.0):
    """Score a batch of texts through the server's ``/get_reward`` endpoint.

    Sends ``{"query": queries, "prompts": None}`` as JSON and reads the scores
    from the ``"rewards"`` key of the JSON reply, falling back to ``"scores"``
    when ``"rewards"`` is absent.

    Args:
        server_url: Base URL of the reward server; a trailing ``/`` is ignored.
        queries: Sequence of full texts (prompt + answer) to score.
        timeout: Per-request timeout in seconds, passed to ``requests.post``.
        max_retries: Total number of attempts; at least one attempt is made.
        sleep: Seconds to wait between consecutive failed attempts.

    Returns:
        A list of rewards with the same length and order as ``queries``.

    Raises:
        RuntimeError: If every attempt failed with a network/HTTP error or an
            invalid response; the last error is chained as ``__cause__``.
    """
    if len(queries) == 0:
        # Nothing to score: skip the round trip instead of sending an empty
        # batch the server may not handle.
        return []

    url = server_url.rstrip("/") + "/get_reward"
    payload = {"query": queries, "prompts": None}
    attempts = max(1, max_retries)
    last_err = None
    for attempt in range(attempts):
        try:
            resp = requests.post(url, json=payload, timeout=timeout)
            resp.raise_for_status()
            data = resp.json()
            if not isinstance(data, dict):
                raise ValueError(f"Bad response: {data}")
            # Explicit None check so that a (valid) empty "rewards" list is not
            # mistaken for a missing key, as `a or b` would do.
            rewards = data.get("rewards")
            if rewards is None:
                rewards = data.get("scores")
            if not isinstance(rewards, list):
                raise ValueError(f"Bad response: {data}")
            if len(rewards) != len(queries):
                raise ValueError(f"Length mismatch: got {len(rewards)} for {len(queries)} queries")
            return rewards
        except (requests.RequestException, ValueError) as e:
            # Retry only transport/HTTP failures and malformed replies (JSON
            # decode errors are ValueError subclasses); programming errors
            # propagate immediately.
            last_err = e
            if attempt + 1 < attempts:
                time.sleep(sleep)
    raise RuntimeError(f"Request failed after retries: {last_err}") from last_err

def main():
    """Evaluate an HTTP-served reward model on a pairwise preference dataset.

    Steps:

    - Load a parquet file and build ``prompt + sep + answer`` texts for the
      chosen and rejected answers.
    - Score both sets through ``post_batch`` in batches of ``--batch_size``.
    - Count a pair as correct when the chosen score is strictly greater than
      the rejected score; ties count as wrong.
    - Print accuracy and mean scores.
    - Write per-row details to ``--save_csv``: the original columns plus
      ``chosen_score``, ``rejected_score``, ``delta`` and ``acc``.

    Raises:
        ValueError: If a configured column is missing from the parquet file.
        RuntimeError: Propagated from ``post_batch`` when requests keep failing.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--server_url", type=str, required=True, help="FastAPI 服务地址,如 http://localhost:5000")
    ap.add_argument("--data_path", type=str, required=True, help="parquet 文件路径")
    ap.add_argument("--prompt_key", type=str, default="chosen_prompt", help="prompt 字段名")
    ap.add_argument("--chosen_key", type=str, default="chosen", help="chosen 回答字段名")
    ap.add_argument("--rejected_key", type=str, default="reject", help="rejected 回答字段名")
    ap.add_argument("--batch_size", type=int, default=64, help="每次请求发送多少条")
    ap.add_argument("--sep", type=str, default="", help="prompt 与回答之间的连接符,如 '\\n' 或 空串")
    ap.add_argument("--save_csv", type=str, default="rm_http_eval.csv", help="保存明细的 CSV 路径")
    ap.add_argument("--print_each", action="store_true", help="逐样本打印 chosen/reject 分数与实时平均 acc")
    args = ap.parse_args()
    if args.batch_size <= 0:
        # range() rejects a zero step and a negative step yields no batches,
        # which would later fail with a confusing column-length mismatch.
        ap.error("--batch_size must be a positive integer")

    # Load the data and check that every configured column exists.
    df = pd.read_parquet(args.data_path)
    for col in [args.prompt_key, args.chosen_key, args.rejected_key]:
        if col not in df.columns:
            raise ValueError(f"列 {col} 不在 parquet 中,现有列:{list(df.columns)}")

    prompts = df[args.prompt_key].fillna("").astype(str).tolist()
    chosens = df[args.chosen_key].fillna("").astype(str).tolist()
    rejects = df[args.rejected_key].fillna("").astype(str).tolist()

    # Interpret escape sequences such as "\n" typed on the command line.
    # Encoding to latin-1 with backslashreplace first turns non-latin-1
    # characters into \uXXXX escapes, so unicode_escape restores them
    # unchanged. The previous UTF-8 round trip decoded every UTF-8 byte as a
    # separate latin-1 character and mangled non-ASCII separators.
    sep = args.sep.encode("latin-1", "backslashreplace").decode("unicode_escape")
    chosen_queries = [build_text(p, c, sep) for p, c in zip(prompts, chosens)]
    rejected_queries = [build_text(p, r, sep) for p, r in zip(prompts, rejects)]

    N = len(chosen_queries)
    chosen_scores, rejected_scores, accs = [], [], []

    seen, correct = 0, 0
    running_acc = 0.0  # bound up front so the progress postfix is always defined
    pbar = tqdm(range(0, N, args.batch_size), desc="HTTP Scoring")

    for i in pbar:
        j = min(i + args.batch_size, N)
        # Score chosen first, then rejected. The two calls could run
        # concurrently; they stay sequential here for simplicity.
        ch_scores = post_batch(args.server_url, chosen_queries[i:j])
        rj_scores = post_batch(args.server_url, rejected_queries[i:j])

        # start=i makes idx the global row index of each sample.
        for idx, (cs, rs) in enumerate(zip(ch_scores, rj_scores), start=i):
            delta = cs - rs
            acc = 1 if delta > 0 else 0  # ties count as incorrect
            chosen_scores.append(cs)
            rejected_scores.append(rs)
            accs.append(acc)

            seen += 1
            correct += acc
            running_acc = correct / seen

            if args.print_each:
                tqdm.write(f"[{idx}] acc={acc}, chosen={cs:.3f}, rejected={rs:.3f}, Δ={delta:.3f} | avg acc={running_acc:.3f}")

        pbar.set_postfix({"avg_acc": f"{running_acc:.3f}"})

    # Aggregate and save.
    out = df.copy()
    out["chosen_score"] = chosen_scores
    out["rejected_score"] = rejected_scores
    out["delta"] = out["chosen_score"] - out["rejected_score"]
    out["acc"] = accs

    # Guard against an empty dataset, where pandas would report NaN means.
    final_acc = float(out["acc"].mean()) if len(out) else 0.0
    mean_chosen = float(out["chosen_score"].mean()) if len(out) else 0.0
    mean_reject = float(out["rejected_score"].mean()) if len(out) else 0.0
    mean_delta = float(out["delta"].mean()) if len(out) else 0.0

    print("\n=========== RESULT (HTTP) ===========")
    print(f"✅ Accuracy        = {final_acc:.4f}  ({sum(accs)}/{len(accs)})")
    print(f"📊 Mean chosen     = {mean_chosen:.4f}")
    print(f"📉 Mean rejected   = {mean_reject:.4f}")
    print(f"🔼 Mean delta      = {mean_delta:.4f}")

    out.to_csv(args.save_csv, index=False)
    print(f"💾 Saved details to: {args.save_csv}")


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    main()