#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Evaluate a reward-model HTTP service on a pairwise preference dataset.

For every row of a parquet file, the prompt is joined with the "chosen" and
the "rejected" answer, both texts are scored through the server's
``/get_reward`` endpoint, and the row counts as correct when the chosen text
scores strictly higher. Per-row scores are written to a CSV and summary
statistics are printed.
"""
import argparse
import time

import pandas as pd


def build_text(prompt: str, answer: str, sep: str) -> str:
    """Join ``prompt`` and ``answer`` with ``sep`` and strip outer whitespace.

    ``None`` values are treated as empty strings; other values are converted
    with ``str``.
    """
    prompt = "" if prompt is None else str(prompt)
    answer = "" if answer is None else str(answer)
    return (prompt + sep + answer).strip()


def unescape_sep(raw: str) -> str:
    """Interpret backslash escapes such as ``\\n`` or ``\\t`` in a CLI string.

    ``raw.encode("utf-8").decode("unicode_escape")`` corrupts non-ASCII text,
    because every UTF-8 byte is decoded as a separate Latin-1 character.
    Encoding with ``latin-1`` + ``backslashreplace`` instead turns characters
    outside Latin-1 into ``\\uXXXX`` escapes, which ``unicode_escape`` then
    decodes back to the original character. Latin-1 characters map to single
    bytes and round-trip unchanged.
    """
    return raw.encode("latin-1", "backslashreplace").decode("unicode_escape")


def parse_rewards(data, expected_len: int) -> list:
    """Extract ``expected_len`` float rewards from a ``/get_reward`` JSON body.

    Accepts ``{"rewards": [...]}`` or, if that key is missing or null,
    ``{"scores": [...]}``.

    Raises:
        ValueError: If the body has any other shape, the wrong length, or a
            reward that cannot be converted to ``float``.
    """
    if not isinstance(data, dict):
        raise ValueError(f"Bad response: {data}")
    rewards = data.get("rewards")
    if rewards is None:  # `is None`, so an empty "rewards" list is still honoured
        rewards = data.get("scores")
    if not isinstance(rewards, list):
        raise ValueError(f"Bad response: {data}")
    if len(rewards) != expected_len:
        raise ValueError(f"Length mismatch: got {len(rewards)} for {expected_len} queries")
    try:
        # Fail here, inside the retry loop, instead of later at `cs - rs`.
        return [float(r) for r in rewards]
    except (TypeError, ValueError) as e:
        raise ValueError(f"Non-numeric reward in response: {rewards}") from e


def post_batch(server_url: str, queries, timeout=120, max_retries=3, sleep=1.0, session=None):
    """Score ``queries`` via the server's ``/get_reward`` endpoint.

    Args:
        server_url: Base URL of the service, e.g. ``http://localhost:5000``.
        queries: List of texts to score.
        timeout: Per-request timeout in seconds.
        max_retries: Total number of attempts; at least one is always made.
        sleep: Seconds to wait between failed attempts (not after the last).
        session: Optional ``requests.Session`` for connection reuse; a
            one-off ``requests.post`` is used when omitted.

    Returns:
        A list of float rewards, one per query, in input order.

    Raises:
        RuntimeError: If every attempt fails; the last error is chained.
    """
    # Third-party; imported here so the pure helpers import without it.
    import requests

    url = server_url.rstrip("/") + "/get_reward"
    # NOTE(review): assumes the server reads the texts from "query" and
    # tolerates "prompts": null — confirm against the server implementation.
    payload = {"query": queries, "prompts": None}
    post = session.post if session is not None else requests.post
    attempts = max(1, int(max_retries))
    last_err = None
    for attempt in range(attempts):
        try:
            resp = post(url, json=payload, timeout=timeout)
            resp.raise_for_status()
            return parse_rewards(resp.json(), len(queries))
        except (requests.RequestException, ValueError) as e:
            # Network/HTTP errors, undecodable JSON and malformed bodies.
            last_err = e
            if attempt + 1 < attempts:
                time.sleep(sleep)
    raise RuntimeError(f"Request failed after retries: {last_err}") from last_err


def main():
    """Command-line entry point: score every pair over HTTP and report accuracy."""
    # Third-party packages imported locally so the helpers above can be
    # imported (e.g. in tests) without them installed.
    import requests
    from tqdm import tqdm

    ap = argparse.ArgumentParser()
    ap.add_argument("--server_url", type=str, required=True, help="FastAPI 服务地址,如 http://localhost:5000")
    ap.add_argument("--data_path", type=str, required=True, help="parquet 文件路径")
    ap.add_argument("--prompt_key", type=str, default="chosen_prompt", help="prompt 字段名")
    ap.add_argument("--chosen_key", type=str, default="chosen", help="chosen 回答字段名")
    ap.add_argument("--rejected_key", type=str, default="reject", help="rejected 回答字段名")
    ap.add_argument("--batch_size", type=int, default=64, help="每次请求发送多少条")
    ap.add_argument("--sep", type=str, default="", help="prompt 与回答之间的连接符,如 '\\n' 或 空串")
    ap.add_argument("--save_csv", type=str, default="rm_http_eval.csv", help="保存明细的 CSV 路径")
    ap.add_argument("--print_each", action="store_true", help="逐样本打印 chosen/reject 分数与实时平均 acc")
    args = ap.parse_args()
    if args.batch_size <= 0:
        # range() with a step of 0 would otherwise fail with an opaque error.
        ap.error("--batch_size must be a positive integer")

    # Load the data and check that every required column exists.
    df = pd.read_parquet(args.data_path)
    for col in [args.prompt_key, args.chosen_key, args.rejected_key]:
        if col not in df.columns:
            raise ValueError(f"列 {col} 不在 parquet 中,现有列:{list(df.columns)}")

    prompts = df[args.prompt_key].fillna("").astype(str).tolist()
    chosens = df[args.chosen_key].fillna("").astype(str).tolist()
    rejects = df[args.rejected_key].fillna("").astype(str).tolist()

    # Build the texts to score; --sep may contain escapes such as "\n".
    sep = unescape_sep(args.sep)
    chosen_queries = [build_text(p, c, sep) for p, c in zip(prompts, chosens)]
    rejected_queries = [build_text(p, r, sep) for p, r in zip(prompts, rejects)]

    N = len(chosen_queries)
    chosen_scores, rejected_scores, accs = [], [], []
    seen, correct = 0, 0

    # A single Session reuses the HTTP connection across all batches.
    with requests.Session() as session:
        pbar = tqdm(range(0, N, args.batch_size), desc="HTTP Scoring")
        for i in pbar:
            j = min(i + args.batch_size, N)
            # Score chosen first, then rejected; kept serial for simplicity.
            ch_scores = post_batch(args.server_url, chosen_queries[i:j], session=session)
            rj_scores = post_batch(args.server_url, rejected_queries[i:j], session=session)

            # `idx` is the global row index of each sample.
            for idx, (cs, rs) in enumerate(zip(ch_scores, rj_scores), start=i):
                delta = cs - rs
                # Ties (delta == 0) count as incorrect.
                acc = 1 if delta > 0 else 0
                chosen_scores.append(cs)
                rejected_scores.append(rs)
                accs.append(acc)
                seen += 1
                correct += acc
                running_acc = correct / seen
                if args.print_each:
                    tqdm.write(f"[{idx}] acc={acc}, chosen={cs:.3f}, rejected={rs:.3f}, Δ={delta:.3f} | avg acc={running_acc:.3f}")
            pbar.set_postfix({"avg_acc": f"{running_acc:.3f}"})

    # Aggregate, report and save the per-row details.
    out = df.copy()
    out["chosen_score"] = chosen_scores
    out["rejected_score"] = rejected_scores
    out["delta"] = out["chosen_score"] - out["rejected_score"]
    out["acc"] = accs

    final_acc = float(out["acc"].mean()) if len(out) else 0.0
    mean_chosen = float(out["chosen_score"].mean()) if len(out) else 0.0
    mean_reject = float(out["rejected_score"].mean()) if len(out) else 0.0
    mean_delta = float(out["delta"].mean()) if len(out) else 0.0

    print("\n=========== RESULT (HTTP) ===========")
    print(f"✅ Accuracy = {final_acc:.4f} ({sum(accs)}/{len(accs)})")
    print(f"📊 Mean chosen = {mean_chosen:.4f}")
    print(f"📉 Mean rejected = {mean_reject:.4f}")
    print(f"🔼 Mean delta = {mean_delta:.4f}")

    out.to_csv(args.save_csv, index=False)
    print(f"💾 Saved details to: {args.save_csv}")


if __name__ == "__main__":
    main()