| | |
| | |
| |
|
| | import argparse |
| | import time |
| | import requests |
| | import pandas as pd |
| | from tqdm import tqdm |
| |
|
def build_text(prompt: str, answer: str, sep: str) -> str:
    """Join *prompt* and *answer* with *sep*, then trim outer whitespace.

    A ``None`` part is treated as the empty string; any other value is
    converted with ``str()`` before joining.
    """
    head, tail = ("" if part is None else str(part) for part in (prompt, answer))
    return "".join((head, sep, tail)).strip()
| |
|
def post_batch(server_url: str, queries, timeout=120, max_retries=3, sleep=1.0):
    """Score ``queries`` with the server's ``/get_reward`` endpoint.

    POSTs ``{"query": queries, "prompts": None}`` as JSON to
    ``<server_url>/get_reward`` and returns the list found under the
    ``"rewards"`` key of the JSON response (falling back to ``"scores"`` when
    ``"rewards"`` is absent or null). The list must contain exactly one entry
    per query.

    Args:
        server_url: Base URL of the service; a trailing "/" is tolerated.
        queries: Texts to score, sent as-is in the ``"query"`` field.
        timeout: Per-request timeout in seconds, passed to ``requests.post``.
        max_retries: Total number of attempts; values below 1 still make
            one attempt.
        sleep: Seconds to wait between consecutive failed attempts.

    Returns:
        The rewards list, with ``len(rewards) == len(queries)``.

    Raises:
        RuntimeError: if every attempt fails (HTTP/transport error, invalid
            JSON, or a malformed/mismatched payload). The last underlying
            error is chained as ``__cause__``.
    """
    url = server_url.rstrip("/") + "/get_reward"
    payload = {"query": queries, "prompts": None}
    attempts = max(1, max_retries)  # always try at least once
    last_err = None
    for attempt in range(attempts):
        try:
            resp = requests.post(url, json=payload, timeout=timeout)
            resp.raise_for_status()
            data = resp.json()  # invalid JSON raises a ValueError subclass
            if not isinstance(data, dict):
                raise ValueError(f"Bad response: {data}")
            # Check for None explicitly so a legitimately empty "rewards"
            # list is not replaced by the "scores" fallback.
            rewards = data.get("rewards")
            if rewards is None:
                rewards = data.get("scores")
            if not isinstance(rewards, list):
                raise ValueError(f"Bad response: {data}")
            if len(rewards) != len(queries):
                raise ValueError(f"Length mismatch: got {len(rewards)} for {len(queries)} queries")
            return rewards
        except (requests.RequestException, ValueError) as e:
            last_err = e
            # No point waiting after the final attempt.
            if attempt + 1 < attempts:
                time.sleep(sleep)
    raise RuntimeError(f"Request failed after retries: {last_err}") from last_err
| |
|
def _parse_args():
    """Build the command-line parser and return the validated arguments."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--server_url", type=str, required=True, help="FastAPI 服务地址,如 http://localhost:5000")
    ap.add_argument("--data_path", type=str, required=True, help="parquet 文件路径")
    ap.add_argument("--prompt_key", type=str, default="chosen_prompt", help="prompt 字段名")
    ap.add_argument("--chosen_key", type=str, default="chosen", help="chosen 回答字段名")
    ap.add_argument("--rejected_key", type=str, default="reject", help="rejected 回答字段名")
    ap.add_argument("--batch_size", type=int, default=64, help="每次请求发送多少条")
    ap.add_argument("--sep", type=str, default="", help="prompt 与回答之间的连接符,如 '\\n' 或 空串")
    ap.add_argument("--save_csv", type=str, default="rm_http_eval.csv", help="保存明细的 CSV 路径")
    ap.add_argument("--print_each", action="store_true", help="逐样本打印 chosen/reject 分数与实时平均 acc")
    args = ap.parse_args()
    # range(0, N, step) raises on step <= 0; report it as a CLI usage error.
    if args.batch_size < 1:
        ap.error("--batch_size must be >= 1")
    return args


def _decode_sep(raw):
    """Interpret backslash escapes (such as ``\\n`` or ``\\t``) in a CLI separator.

    Feeding UTF-8 bytes to the ``unicode_escape`` codec would re-read them as
    Latin-1 and turn non-ASCII separators into mojibake. Encoding as Latin-1
    with ``backslashreplace`` instead turns every non-Latin-1 character into a
    ``\\uXXXX``/``\\UXXXXXXXX`` escape, which ``unicode_escape`` decodes back
    to the same character.
    """
    return raw.encode("latin-1", "backslashreplace").decode("unicode_escape")


def main():
    """Evaluate a reward-model HTTP service on a chosen/rejected parquet dataset.

    For each row, the prompt is joined with the chosen answer and with the
    rejected answer (via ``build_text``), and both texts are scored in batches
    through ``post_batch``. A row counts as correct when the chosen score is
    strictly greater than the rejected score (ties count as wrong). Summary
    statistics are printed, and the original columns plus ``chosen_score``,
    ``rejected_score``, ``delta`` and ``acc`` are written to ``--save_csv``.

    Raises:
        ValueError: if a requested column is missing from the parquet file.
    """
    args = _parse_args()

    df = pd.read_parquet(args.data_path)
    for col in [args.prompt_key, args.chosen_key, args.rejected_key]:
        if col not in df.columns:
            raise ValueError(f"列 {col} 不在 parquet 中,现有列:{list(df.columns)}")

    prompts = df[args.prompt_key].fillna("").astype(str).tolist()
    chosens = df[args.chosen_key].fillna("").astype(str).tolist()
    rejects = df[args.rejected_key].fillna("").astype(str).tolist()

    sep = _decode_sep(args.sep)
    chosen_queries = [build_text(p, c, sep) for p, c in zip(prompts, chosens)]
    rejected_queries = [build_text(p, r, sep) for p, r in zip(prompts, rejects)]

    n_rows = len(chosen_queries)
    chosen_scores, rejected_scores, accs = [], [], []
    correct = 0

    pbar = tqdm(range(0, n_rows, args.batch_size), desc="HTTP Scoring")
    for start in pbar:
        stop = start + args.batch_size  # slicing clamps at n_rows
        # post_batch guarantees one score per query, so the two lists align.
        ch_scores = post_batch(args.server_url, chosen_queries[start:stop])
        rj_scores = post_batch(args.server_url, rejected_queries[start:stop])

        for offset, (cs, rs) in enumerate(zip(ch_scores, rj_scores)):
            delta = cs - rs
            acc = 1 if delta > 0 else 0
            chosen_scores.append(cs)
            rejected_scores.append(rs)
            accs.append(acc)

            correct += acc
            running_acc = correct / len(accs)

            if args.print_each:
                idx = start + offset
                tqdm.write(f"[{idx}] acc={acc}, chosen={cs:.3f}, rejected={rs:.3f}, Δ={delta:.3f} | avg acc={running_acc:.3f}")

        pbar.set_postfix({"avg_acc": f"{running_acc:.3f}"})

    out = df.copy()
    out["chosen_score"] = chosen_scores
    out["rejected_score"] = rejected_scores
    out["delta"] = out["chosen_score"] - out["rejected_score"]
    out["acc"] = accs

    # Guard the means so an empty dataset reports 0.0 instead of NaN.
    final_acc = float(out["acc"].mean()) if len(out) else 0.0
    mean_chosen = float(out["chosen_score"].mean()) if len(out) else 0.0
    mean_reject = float(out["rejected_score"].mean()) if len(out) else 0.0
    mean_delta = float(out["delta"].mean()) if len(out) else 0.0

    print("\n=========== RESULT (HTTP) ===========")
    print(f"✅ Accuracy = {final_acc:.4f} ({sum(accs)}/{len(accs)})")
    print(f"📊 Mean chosen = {mean_chosen:.4f}")
    print(f"📉 Mean rejected = {mean_reject:.4f}")
    print(f"🔼 Mean delta = {mean_delta:.4f}")

    out.to_csv(args.save_csv, index=False)
    print(f"💾 Saved details to: {args.save_csv}")
| |
|
if __name__ == "__main__":
    # Run the evaluation only when this file is executed as a script.
    main()
| |
|