# rm_code / http_rm.py
# hahayang012's picture
# Upload folder using huggingface_hub
# d8a76be verified
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import argparse
import time
import requests
import pandas as pd
from tqdm import tqdm
def build_text(prompt: str, answer: str, sep: str) -> str:
    """Join ``prompt`` and ``answer`` with ``sep`` and trim the result.

    ``None`` for either text is treated as an empty string, and any other
    value is converted with ``str()``. Leading and trailing whitespace of the
    joined string (not of the individual parts) is stripped.
    """
    head = str(prompt) if prompt is not None else ""
    tail = str(answer) if answer is not None else ""
    joined = "".join([head, sep, tail])
    return joined.strip()
def post_batch(server_url: str, queries, timeout=120, max_retries=3, sleep=1.0):
    """Score a batch of texts through the server's ``/get_reward`` endpoint.

    Args:
        server_url: Base URL of the reward server; a trailing ``/`` is tolerated.
        queries: List of texts to score. An empty batch returns ``[]``
            immediately without contacting the server.
        timeout: Per-request timeout in seconds, forwarded to ``requests.post``.
        max_retries: Maximum number of attempts before giving up.
        sleep: Seconds to wait between two consecutive attempts.

    Returns:
        The list found under the ``"rewards"`` key of the JSON response (or
        under ``"scores"`` when ``"rewards"`` is absent), with exactly one
        entry per query.

    Raises:
        RuntimeError: If no attempt produced a valid response. The last
            underlying error is attached as ``__cause__``.
    """
    # Nothing to score: skip the round-trip entirely.
    if not queries:
        return []

    url = server_url.rstrip("/") + "/get_reward"
    payload = {"query": queries, "prompts": None}
    last_err = None
    for attempt in range(max_retries):
        try:
            resp = requests.post(url, json=payload, timeout=timeout)
            resp.raise_for_status()
            data = resp.json()
            if not isinstance(data, dict):
                raise ValueError(f"Bad response: {data}")
            # Fall back to "scores" only when "rewards" is really missing;
            # a plain `or` would also discard falsy-but-present values.
            rewards = data.get("rewards")
            if rewards is None:
                rewards = data.get("scores")
            if not isinstance(rewards, list):
                raise ValueError(f"Bad response: {data}")
            if len(rewards) != len(queries):
                raise ValueError(f"Length mismatch: got {len(rewards)} for {len(queries)} queries")
            return rewards
        except Exception as e:
            # Deliberately broad: any failure (network, HTTP status, malformed
            # body) is retried and finally reported as RuntimeError.
            last_err = e
            # Waiting after the final attempt would only delay the error.
            if attempt < max_retries - 1:
                time.sleep(sleep)
    raise RuntimeError(f"Request failed after retries: {last_err}") from last_err
def main(argv=None):
    """Evaluate a reward model served over HTTP on chosen/rejected pairs.

    Reads the prompt, chosen and rejected columns from a parquet file, joins
    each answer onto its prompt with ``build_text``, scores both variants in
    batches with ``post_batch``, and counts a pair as correct when the chosen
    score is strictly greater than the rejected score. A summary is printed
    and the per-row scores are written to a CSV file.

    Args:
        argv: Optional list of command-line arguments. When ``None``,
            argparse falls back to ``sys.argv[1:]``.

    Raises:
        ValueError: If one of the configured columns is missing from the
            parquet file.
        SystemExit: Raised by argparse on invalid arguments, including a
            non-positive ``--batch_size``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--server_url", type=str, required=True, help="FastAPI 服务地址,如 http://localhost:5000")
    ap.add_argument("--data_path", type=str, required=True, help="parquet 文件路径")
    ap.add_argument("--prompt_key", type=str, default="chosen_prompt", help="prompt 字段名")
    ap.add_argument("--chosen_key", type=str, default="chosen", help="chosen 回答字段名")
    ap.add_argument("--rejected_key", type=str, default="reject", help="rejected 回答字段名")
    ap.add_argument("--batch_size", type=int, default=64, help="每次请求发送多少条")
    ap.add_argument("--sep", type=str, default="", help="prompt 与回答之间的连接符,如 '\\n' 或 空串")
    ap.add_argument("--save_csv", type=str, default="rm_http_eval.csv", help="保存明细的 CSV 路径")
    ap.add_argument("--print_each", action="store_true", help="逐样本打印 chosen/reject 分数与实时平均 acc")
    args = ap.parse_args(argv)

    # A zero step crashes range(); a negative one silently scores nothing and
    # fails much later on a column-length mismatch. Reject both up front.
    if args.batch_size <= 0:
        ap.error("--batch_size must be a positive integer")

    # Load the data and make sure every requested column exists.
    df = pd.read_parquet(args.data_path)
    for col in [args.prompt_key, args.chosen_key, args.rejected_key]:
        if col not in df.columns:
            raise ValueError(f"列 {col} 不在 parquet 中,现有列:{list(df.columns)}")

    prompts = df[args.prompt_key].fillna("").astype(str).tolist()
    chosens = df[args.chosen_key].fillna("").astype(str).tolist()
    rejects = df[args.rejected_key].fillna("").astype(str).tolist()

    # Build the query texts. Escape sequences such as "\n" typed on the
    # command line are interpreted here. Encoding via latin-1 with
    # backslashreplace keeps non-ASCII separators intact; a UTF-8 round-trip
    # through "unicode_escape" would turn them into mojibake.
    sep = args.sep.encode("latin-1", "backslashreplace").decode("unicode_escape")
    chosen_queries = [build_text(p, c, sep) for p, c in zip(prompts, chosens)]
    rejected_queries = [build_text(p, r, sep) for p, r in zip(prompts, rejects)]

    N = len(chosen_queries)
    chosen_scores, rejected_scores, accs = [], [], []
    seen, correct = 0, 0

    # The context manager closes the progress bar even if a request fails.
    with tqdm(range(0, N, args.batch_size), desc="HTTP Scoring") as pbar:
        for i in pbar:
            j = min(i + args.batch_size, N)

            # Score chosen first, then rejected; sequential for simplicity
            # (the two requests could also be issued concurrently).
            ch_scores = post_batch(args.server_url, chosen_queries[i:j])
            rj_scores = post_batch(args.server_url, rejected_queries[i:j])

            for k, (cs, rs) in enumerate(zip(ch_scores, rj_scores)):
                delta = cs - rs
                # Ties count as incorrect: chosen must score strictly higher.
                acc = 1 if delta > 0 else 0

                chosen_scores.append(cs)
                rejected_scores.append(rs)
                accs.append(acc)

                seen += 1
                correct += acc
                running_acc = correct / seen

                if args.print_each:
                    # Global index of the sample within the dataset.
                    idx = i + k
                    tqdm.write(f"[{idx}] acc={acc}, chosen={cs:.3f}, rejected={rs:.3f}, Δ={delta:.3f} | avg acc={running_acc:.3f}")

            pbar.set_postfix({"avg_acc": f"{running_acc:.3f}"})

    # Aggregate and save.
    out = df.copy()
    out["chosen_score"] = chosen_scores
    out["rejected_score"] = rejected_scores
    out["delta"] = out["chosen_score"] - out["rejected_score"]
    out["acc"] = accs

    # Guard the means so an empty dataset reports 0.0 instead of NaN.
    final_acc = float(out["acc"].mean()) if len(out) else 0.0
    mean_chosen = float(out["chosen_score"].mean()) if len(out) else 0.0
    mean_reject = float(out["rejected_score"].mean()) if len(out) else 0.0
    mean_delta = float(out["delta"].mean()) if len(out) else 0.0

    print("\n=========== RESULT (HTTP) ===========")
    print(f"✅ Accuracy = {final_acc:.4f} ({sum(accs)}/{len(accs)})")
    print(f"📊 Mean chosen = {mean_chosen:.4f}")
    print(f"📉 Mean rejected = {mean_reject:.4f}")
    print(f"🔼 Mean delta = {mean_delta:.4f}")

    out.to_csv(args.save_csv, index=False)
    print(f"💾 Saved details to: {args.save_csv}")
# Script entry point: run the evaluation only when executed directly.
if __name__ == "__main__":
    raise SystemExit(main())