#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Alignment evaluation script driven by the production SSE endpoint.

What it does:
- Reads annotated data (``conversations`` plus the pair2 target tool).
- Calls the production endpoint directly (default
  http://125.122.38.32:8085/mcp_end2end/stream) to obtain the first-hop
  retrieval Top5 and the final tool call.
- Computes recall@5, precision@1 and argument accuracy, then writes a report
  plus separate files with the failing cases of each metric.

Example:
    python eval_via_prod_sse.py \\
        --input_file /home/ziqiang/LLaMA-Factory/data/dataset/10_27/10.22_evaluate_data.json \\
        --output_file /home/ziqiang/LLaMA-Factory/data/dataset/10_27/data_evaluation_prod.json \\
        --start_idx 0 --end_idx 50
"""
import argparse
import asyncio
import json
import os
import re
import time
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional, Set, Tuple

import aiohttp

PROD_SSE_URL = os.getenv("PROD_SSE_URL", "http://125.122.38.32:8085/mcp_end2end/stream")
RETRIEVAL_ENDPOINT = os.getenv("RETRIEVAL_ENDPOINT", "http://125.122.38.32:6227/v1/mcp/tools/call")

# Number of retrieved candidates considered for recall@K.
TOP_K = 5

# Last-resort extractor for `"name": "..."` pairs inside an arbitrary JSON dump.
_NAME_FIELD_RE = re.compile(r'"name"\s*:\s*"([^"]+)"')


# ---------------------------------------------------------------------------
# Annotation parsing
# ---------------------------------------------------------------------------

def _extract_user_query_from_conversation(item: Dict[str, Any]) -> str:
    """Return the first ``human`` turn of an annotated conversation (the raw user query).

    Returns an empty string when the conversation has no human turn.
    """
    for msg in item.get("conversations", []):
        if msg.get("from") == "human":
            return str(msg.get("value", ""))
    return ""


def _extract_pair2_target_tool(item: Dict[str, Any]) -> Tuple[Optional[str], Optional[Any]]:
    """Extract the pair2 target tool (name, arguments) from an annotated conversation.

    The target is the first ``function_call`` turn that immediately follows an
    ``observation`` turn; its ``value`` is expected to be a JSON object with
    ``name`` and ``arguments`` keys.

    Returns:
        ``(name, arguments)`` of that call, or ``(None, None)`` when no such
        turn exists or its value is not a JSON object.
    """
    conversations = item.get("conversations", [])
    for i, msg in enumerate(conversations):
        if msg.get("from") != "observation":
            continue
        if i + 1 < len(conversations) and conversations[i + 1].get("from") == "function_call":
            target_raw = conversations[i + 1].get("value", "")
            try:
                target_obj = json.loads(target_raw)
                return target_obj.get("name"), target_obj.get("arguments", {})
            except (ValueError, TypeError, AttributeError):
                # Not a JSON object: the target cannot be determined.
                return None, None
    return None, None


def _coerce_arguments(value: Any) -> Dict[str, Any]:
    """Normalize tool-call ``arguments`` to a dict.

    Accepts a dict (returned as-is) or a JSON-encoded string that decodes to a
    dict; anything else yields ``{}``. Guarding this prevents ``.get`` on a
    string/list from crashing the argument comparison.
    """
    if isinstance(value, dict):
        return value
    if isinstance(value, str) and value.strip():
        try:
            decoded = json.loads(value)
        except ValueError:
            return {}
        if isinstance(decoded, dict):
            return decoded
    return {}


# ---------------------------------------------------------------------------
# Tool-name extraction helpers
# ---------------------------------------------------------------------------

def _names_from_items(items: Any, top_k: int = TOP_K) -> List[str]:
    """Collect string ``name`` fields from the first *top_k* dict entries of a list.

    Non-list input yields ``[]``.
    """
    if not isinstance(items, list):
        return []
    return [
        entry["name"]
        for entry in items[:top_k]
        if isinstance(entry, dict) and isinstance(entry.get("name"), str)
    ]


def _names_from_chain(chain: Any, top_k: int = TOP_K) -> List[str]:
    """Extract candidate names from the first hop of a ``tool_calling_chain`` list."""
    if not isinstance(chain, list) or not chain:
        return []
    first = chain[0]
    t_resp = first.get("tool_response") if isinstance(first, dict) else None
    return _names_from_items(t_resp, top_k)


def _names_from_tool_response(result: Any) -> List[str]:
    """Extract candidate names from a ``tool_response.completed`` ``result_delta``.

    Handles a plain list of tools, a dict with a ``tools`` list, and (as a
    fallback) a dict carrying a ``tool_calling_chain``.
    """
    if isinstance(result, list):
        return _names_from_items(result)
    if isinstance(result, dict):
        names = _names_from_items(result.get("tools") or [])
        if not names:
            names = _names_from_chain(result.get("tool_calling_chain") or [])
        return names
    return []


def _names_from_completion(data: Dict[str, Any]) -> List[str]:
    """Extract candidate names from an ``answer.completed``/``response.completed`` payload.

    Looks for ``tool_calling_chain`` at the top level, then under ``round_data``.
    """
    chain = data.get("tool_calling_chain")
    if not chain:
        round_data = data.get("round_data")
        chain = round_data.get("tool_calling_chain") if isinstance(round_data, dict) else None
    return _names_from_chain(chain)


# ---------------------------------------------------------------------------
# Remote calls
# ---------------------------------------------------------------------------

async def call_prod_sse_for_case(
    session: aiohttp.ClientSession,
    query: str,
    user_id: str = "1",
    max_duration_sec: float = 30,
) -> Dict[str, Any]:
    """Call the production SSE endpoint for one query.

    Captures the first-hop retrieval Top5 and the last non-retrieval tool call.

    Args:
        session: Shared aiohttp session.
        query: Raw user query.
        user_id: User id sent to the endpoint.
        max_duration_sec: Wall-clock budget for reading the stream. It also
            bounds each individual read, so a stalled stream cannot hang.

    Returns:
        A dict with ``retrieved_top5``, ``predicted_tool``,
        ``retrieval_call_params`` and ``http_status``. It also holds
        ``event_count`` on a 200 response. ``error`` is present when the
        stream timed out or failed. A non-200 response returns ``error`` plus
        ``response_preview`` instead.
    """
    payload = {
        "query": query,
        "prompt_template": "standard",
        "user_id": user_id,
        "role_code": 1,  # required field
        "user_history": [],
        "save_method": 0,
        "is_vector": False,
        "is_probabilistic": False,
        "use_retrieval": True,
        "tool_category": None,
        "mcp_data": None,  # must be a string or None, not an object
        "front_data": {},
        # A uuid suffix keeps ids unique even for cases finishing within one second.
        "message_id": f"eval_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}",
    }
    headers = {
        "Accept": "text/event-stream",
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "Content-Type": "application/json",
    }

    retrieved_top5: List[str] = []
    predicted_tool: Optional[Dict[str, Any]] = None
    retrieval_call_params: Optional[Dict[str, Any]] = None
    event_count = 0
    error_msg: Optional[str] = None
    timeout_msg = f"timeout after {max_duration_sec}s"
    # Monotonic clock: immune to wall-clock adjustments during the run.
    deadline = time.monotonic() + max_duration_sec

    async with session.post(PROD_SSE_URL, json=payload, headers=headers) as resp:
        http_status = resp.status
        if http_status != 200:
            try:
                text = await resp.text()
            except Exception:
                text = ""
            return {
                "retrieved_top5": [],
                "predicted_tool": None,
                "retrieval_call_params": None,
                "http_status": http_status,
                "error": f"non-200 response: {http_status}",
                "response_preview": text[:1000],
            }

        current_event: Optional[str] = None
        try:
            while True:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    error_msg = timeout_msg
                    break
                try:
                    raw = await asyncio.wait_for(resp.content.readline(), timeout=remaining)
                except asyncio.TimeoutError:
                    error_msg = timeout_msg
                    break
                if not raw:  # EOF: the server closed the stream
                    break

                line = raw.decode("utf-8", errors="ignore").strip()
                # Skip blank lines and SSE `id:` lines.
                if not line or line.startswith("id:"):
                    continue
                if line.startswith("event:"):
                    current_event = line.split("event:", 1)[1].strip()
                    continue
                if not line.startswith("data:"):
                    continue

                data_str = line.split("data:", 1)[1].strip()
                if data_str == "[DONE]":
                    break
                try:
                    data = json.loads(data_str)
                except ValueError:
                    continue
                event_count += 1
                if not isinstance(data, dict):
                    continue

                if current_event == "tool_call.created":
                    tool_call = data.get("tool_call")
                    if isinstance(tool_call, dict):
                        name = tool_call.get("name")
                        if name == "retrieval_tool":
                            # Keep only the first retrieval call (the first hop).
                            if retrieval_call_params is None:
                                retrieval_call_params = tool_call
                        elif name:
                            # The last non-retrieval tool call is the final prediction.
                            predicted_tool = tool_call
                elif current_event == "tool_response.completed":
                    names = _names_from_tool_response(data.get("result_delta"))
                    # First non-empty list wins, so later tool responses (e.g. the
                    # final tool's own output) cannot overwrite the first-hop Top5.
                    if names and not retrieved_top5:
                        retrieved_top5 = names
                elif current_event in ("answer.completed", "response.completed"):
                    # Further fallback: some implementations put the full chain
                    # at the top level or under round_data.
                    names = _names_from_completion(data)
                    if names and not retrieved_top5:
                        retrieved_top5 = names
        except Exception as e:
            error_msg = str(e)

    return {
        "retrieved_top5": retrieved_top5,
        "predicted_tool": predicted_tool,
        "retrieval_call_params": retrieval_call_params,
        "http_status": http_status,
        "event_count": event_count,
        **({"error": error_msg} if error_msg else {}),
    }


def _extract_tool_names_from_response(resp_obj: Dict[str, Any], top_k: int = 5) -> List[str]:
    """Extract up to *top_k* tool names from a retrieval-gateway JSON response.

    Tries ``result`` (a list, or a dict with ``tools``), then ``data`` (a
    list). If nothing is found, it falls back to regex-scanning the JSON dump
    for ``"name": "..."`` pairs. Non-dict input yields ``[]``.
    """
    if not isinstance(resp_obj, dict):
        return []
    result = resp_obj.get("result")
    if isinstance(result, list):
        names = _names_from_items(result, top_k)
    elif isinstance(result, dict):
        names = _names_from_items(result.get("tools") or [], top_k)
    else:
        names = _names_from_items(resp_obj.get("data"), top_k)
    if not names:
        text = json.dumps(resp_obj, ensure_ascii=False)
        names = _NAME_FIELD_RE.findall(text)[:top_k]
    return names[:top_k]


async def call_retrieval_endpoint(session: aiohttp.ClientSession, retrieval_call_params: Dict[str, Any]) -> Dict[str, Any]:
    """Replay the first-hop ``retrieval_tool`` call against the gateway to obtain Top5.

    The call's ``arguments`` (or the whole call dict when it has none) are sent
    as a JSON-RPC ``tools/call`` request. String arguments that decode to a
    JSON object are sent as that object.

    Returns:
        ``retrieval_http_status`` and ``retrieved_top5``. On success it also
        holds ``retrieval_response_preview``; on a request failure it holds
        ``retrieval_error`` instead.
    """
    arguments = retrieval_call_params.get("arguments", retrieval_call_params) or {}
    if isinstance(arguments, str):
        try:
            decoded = json.loads(arguments)
        except ValueError:
            decoded = None
        if isinstance(decoded, dict):
            arguments = decoded
    payload = {
        "jsonrpc": "2.0",
        "id": "eval_req_001",
        "method": "tools/call",
        "params": {"name": "retrieval_tool", "arguments": arguments},
    }
    headers = {"Content-Type": "application/json", "Accept": "application/json"}
    status = None
    try:
        async with session.post(RETRIEVAL_ENDPOINT, json=payload, headers=headers) as resp:
            status = resp.status
            try:
                data = await resp.json()
            except Exception:
                data = {"raw": (await resp.text())[:2000]}
            names = _extract_tool_names_from_response(data, top_k=TOP_K) if status == 200 else []
            return {
                "retrieval_http_status": status,
                "retrieval_response_preview": json.dumps(data, ensure_ascii=False)[:2000],
                "retrieved_top5": names,
            }
    except Exception as e:
        return {
            "retrieval_http_status": status,
            "retrieval_error": str(e),
            "retrieved_top5": [],
        }


# ---------------------------------------------------------------------------
# Metrics
# ---------------------------------------------------------------------------

def calc_metrics(rows: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Aggregate recall@5, precision@1 and argument accuracy over evaluated rows.

    - recall@5: fraction of all rows whose target tool is in ``retrieved_top5``.
    - precision@1: among recall hits only, fraction whose predicted tool name
      equals the target (the denominator is the number of recall hits).
    - arg_accuracy: among rows whose ``arg_match`` is not None, fraction with
      ``arg_match == 1``.
    """
    total = len(rows)
    if total == 0:
        return {
            "total": 0,
            "recall@5": 0.0,
            "precision@1": 0.0,
            "precision_denominator": 0,
            "arg_accuracy": 0.0,
            "arg_denominator": 0,
        }

    recall_hits = 0
    precision_hits = 0
    arg_hits = 0
    arg_total = 0
    for r in rows:
        tgt_name = r.get("target_tool_name")
        retrieved = r.get("retrieved_top5") or []
        pred = (r.get("predicted_tool") or {}).get("name")
        if tgt_name and tgt_name in retrieved:
            recall_hits += 1
            if pred and tgt_name == pred:
                precision_hits += 1
        # Only rows with a valid (non-None) arg_match count toward arg accuracy.
        if r.get("arg_match") is not None:
            arg_total += 1
            if r.get("arg_match") == 1:
                arg_hits += 1

    precision_denominator = recall_hits
    return {
        "total": total,
        "recall@5": recall_hits / total,
        "precision@1": (precision_hits / precision_denominator) if precision_denominator > 0 else 0.0,
        "precision_denominator": precision_denominator,
        "arg_accuracy": (arg_hits / arg_total) if arg_total > 0 else 0.0,
        "arg_denominator": arg_total,
    }


# ---------------------------------------------------------------------------
# File helpers
# ---------------------------------------------------------------------------

def _ensure_parent_dir(path: str) -> None:
    """Create *path*'s parent directory if it has one (a bare filename needs none)."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def _derive_output_path(output_file: str, suffix: str) -> str:
    """Insert *suffix* before the extension: ``out.json`` -> ``out<suffix>.json``.

    Unlike ``str.replace(".json", ...)``, this never returns the input path
    unchanged (which would overwrite the main report) and ignores ``.json``
    occurrences in directory names.
    """
    root, ext = os.path.splitext(output_file)
    return f"{root}{suffix}{ext or '.json'}"


def _write_json(path: str, obj: Any) -> None:
    """Write *obj* to *path* as pretty-printed UTF-8 JSON."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False, indent=2)


def _load_checkpoint(path: Optional[str]) -> Tuple[List[Dict[str, Any]], Set[int]]:
    """Load ``(results, processed_indices)`` from a checkpoint file.

    Returns empty state when *path* is unset, missing or unreadable.
    """
    if not path or not os.path.exists(path):
        return [], set()
    try:
        with open(path, "r", encoding="utf-8") as f:
            ckpt = json.load(f)
        results = ckpt.get("results", [])
        processed_indices = set(ckpt.get("processed_indices", []))
        print(f"🔁 从断点恢复:已处理 {len(processed_indices)} 个cases;继续评估...")
        return results, processed_indices
    except Exception as e:
        print(f"⚠️ 读取断点失败,将从头开始: {e}")
        return [], set()


def _save_checkpoint(path: str, results: List[Dict[str, Any]], processed_indices: Set[int], meta: Dict[str, Any]) -> None:
    """Persist progress atomically: write a temp file, then rename it over *path*.

    The rename means an interrupt mid-write cannot corrupt the resume file.
    Failures are reported and otherwise ignored (best-effort).
    """
    try:
        ckpt_data = {
            "results": results,
            "processed_indices": sorted(processed_indices),
            "meta": meta,
        }
        _ensure_parent_dir(path)
        tmp_path = f"{path}.tmp"
        _write_json(tmp_path, ckpt_data)
        os.replace(tmp_path, path)
    except Exception as e:
        print(f"⚠️ 写入断点失败: {e}")


# ---------------------------------------------------------------------------
# Per-case evaluation
# ---------------------------------------------------------------------------

def _compare_arguments(target_args: Dict[str, Any], pred_args: Dict[str, Any]) -> Tuple[int, Dict[str, Any]]:
    """Compare each target argument with the prediction by equality.

    Returns ``(1 if every key matches else 0, per-key details)``.
    """
    details: Dict[str, Any] = {}
    for key, target_value in target_args.items():
        predicted_value = pred_args.get(key)
        details[key] = {"target": target_value, "predict": predicted_value, "match": predicted_value == target_value}
    all_match = all(d["match"] for d in details.values())
    return (1 if all_match else 0), details


async def _evaluate_case(session: aiohttp.ClientSession, idx: int, item: Dict[str, Any], user_id: str) -> Dict[str, Any]:
    """Run one annotated case through the production pipeline and score it.

    Returns the result row: annotation fields, ``recall@5``, ``arg_match``,
    ``arg_match_details`` and every field returned by the remote calls.
    """
    user_query = _extract_user_query_from_conversation(item)
    target_tool_name, target_args = _extract_pair2_target_tool(item)

    try:
        r = await call_prod_sse_for_case(session, user_query, user_id=user_id)
    except Exception as e:
        r = {"error": str(e), "retrieved_top5": [], "predicted_tool": None, "retrieval_call_params": None}
        print(f"❌ SSE调用失败: {str(e)[:50]}")

    # If the SSE stream did not expose Top5, replay the first-hop retrieval on the gateway.
    if not r.get("retrieved_top5") and r.get("retrieval_call_params"):
        try:
            r.update(await call_retrieval_endpoint(session, r["retrieval_call_params"]))
        except Exception as e:
            r.update({"retrieval_error": str(e)})

    retrieved_top5 = r.get("retrieved_top5") or []
    recall_value = 1 if (target_tool_name and target_tool_name in retrieved_top5) else 0

    predicted = r.get("predicted_tool") or {}
    pred_tool_name = predicted.get("name")
    pred_args = _coerce_arguments(predicted.get("arguments"))
    tgt_args = _coerce_arguments(target_args)

    # Argument matching only makes sense when the tool was recalled, the
    # predicted name equals the target, and the target has arguments.
    arg_details: Dict[str, Any] = {}
    arg_match_value: Optional[int] = None
    if recall_value == 1 and pred_tool_name == target_tool_name and tgt_args:
        arg_match_value, arg_details = _compare_arguments(tgt_args, pred_args)

    return {
        "index": idx,
        "user_query": user_query,
        "target_tool_name": target_tool_name,
        "target_arguments": target_args,
        "recall@5": recall_value,  # 1 = recalled in Top5, 0 = not recalled
        "arg_match": arg_match_value,  # 1/0, or None when argument evaluation was skipped
        "arg_match_details": arg_details,
        **r,
    }


def _format_progress(row: Dict[str, Any]) -> str:
    """Build the one-line progress summary printed after each case."""
    retrieved_top5 = row.get("retrieved_top5") or []
    recall_value = row.get("recall@5")
    arg_match_value = row.get("arg_match")
    target_tool_name = row.get("target_tool_name")
    pred_tool_name = (row.get("predicted_tool") or {}).get("name")
    pred_name = pred_tool_name or "无"

    retrieved_str = f"检索到{len(retrieved_top5)}个工具" if retrieved_top5 else "未检索到工具"
    recall_str = "✅ recall成功" if recall_value == 1 else "❌ recall失败"
    if arg_match_value is None:
        if recall_value == 0:
            arg_str = "⏭️ 参数评估跳过(未召回)"
        elif pred_tool_name != target_tool_name:
            arg_str = f"⏭️ 参数评估跳过(工具不匹配: {pred_name} ≠ {target_tool_name})"
        else:
            arg_str = "⏭️ 参数评估跳过(无目标参数)"
    else:
        arg_str = "✅ 参数完全匹配" if arg_match_value == 1 else "❌ 参数不匹配"
    return f"{retrieved_str} | {recall_str} | 预测工具: {pred_name} | {arg_str}"


# ---------------------------------------------------------------------------
# Failure reports
# ---------------------------------------------------------------------------

def _report_recall_failures(results: List[Dict[str, Any]], output_file: str, base_summary: Dict[str, Any]) -> None:
    """Write and print the cases whose recall@5 is 0."""
    failed = [case for case in results if case.get("recall@5", 0) == 0]
    if not failed:
        print("✅ 所有cases的recall@5都成功!")
        return
    failed_output_file = _derive_output_path(output_file, "_recall_failed")
    _write_json(failed_output_file, {
        "summary": {
            **base_summary,
            "total_failed": len(failed),
            "total_cases": len(results),
            "failure_rate": len(failed) / len(results) if results else 0.0,
        },
        "cases": failed,
    })
    print(f"❌ Recall失败cases已保存: {failed_output_file} (共 {len(failed)} 条)")
    print(f" 失败率: {len(failed) / len(results) * 100:.1f}%")


def _report_precision_failures(results: List[Dict[str, Any]], output_file: str, base_summary: Dict[str, Any]) -> None:
    """Write and print recall@5 hits whose predicted tool name differs from the target."""
    eligible = [case for case in results if case.get("recall@5", 0) == 1]
    precision_eligible = len(eligible)  # denominator of precision@1
    failed = [
        case for case in eligible
        if case.get("target_tool_name")
        and (case.get("predicted_tool") or {}).get("name") != case.get("target_tool_name")
    ]
    if failed:
        precision_failed_output = _derive_output_path(output_file, "_precision_failed")
        _write_json(precision_failed_output, {
            "summary": {
                **base_summary,
                "total_failed": len(failed),
                "total_cases": len(results),
                "precision_eligible": precision_eligible,
                "failure_rate": (len(failed) / precision_eligible) if precision_eligible > 0 else 0.0,
            },
            "cases": failed,
        })
        print(f"❌ Precision@1失败cases已保存: {precision_failed_output} (共 {len(failed)} 条)")
        if precision_eligible > 0:
            print(f" 基于recall@5成功的样本失败率: {len(failed) / precision_eligible * 100:.1f}% (可评估 {precision_eligible} 条)")
    elif precision_eligible > 0:
        print("✅ 所有recall@5成功的cases在precision@1上均命中!")
    else:
        print("ℹ️ 本次无recall@5成功的样本,无法评估precision@1")


def _report_arg_failures(results: List[Dict[str, Any]], output_file: str, base_summary: Dict[str, Any]) -> None:
    """Write and print the evaluated cases whose arguments did not fully match (arg_match == 0)."""
    failed = [case for case in results if case.get("arg_match") == 0]
    arg_eligible = sum(1 for case in results if case.get("arg_match") is not None)
    if failed:
        arg_failed_output = _derive_output_path(output_file, "_arg_failed")
        _write_json(arg_failed_output, {
            "summary": {
                **base_summary,
                "total_failed": len(failed),
                "total_eligible": arg_eligible,
                "eligible_failure_rate": (len(failed) / arg_eligible) if arg_eligible else 0.0,
            },
            "cases": failed,
        })
        print(f"❌ 参数匹配失败cases已保存: {arg_failed_output} (共 {len(failed)} 条)")
        if arg_eligible:
            print(f" 基于可评估样本失败率: {len(failed) / arg_eligible * 100:.1f}% (可评估 {arg_eligible} 条)")
    elif arg_eligible:
        print("✅ 所有可评估的cases参数完全匹配!")
    else:
        print("ℹ️ 本次无可评估的参数匹配样本(arg_match 皆为 None)")


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

async def main():
    """CLI entry point: evaluate cases [start_idx, end_idx) and write the reports."""
    parser = argparse.ArgumentParser(description="基于生产SSE接口的评估")
    parser.add_argument("--input_file", "-i", type=str, required=True)
    parser.add_argument("--output_file", "-o", type=str, required=True)
    parser.add_argument("--start_idx", "-s", type=int, default=0)
    parser.add_argument("--end_idx", "-e", type=int, default=50)
    parser.add_argument("--user_id", type=str, default=os.getenv("EVAL_USER_ID", "1"))
    parser.add_argument("--checkpoint_file", "-c", type=str, default=None,
                        help="断点续跑文件路径;提供则每个case完成后都会写入断点进度")
    args = parser.parse_args()

    with open(args.input_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    start = max(0, args.start_idx)
    end = min(len(data), args.end_idx if args.end_idx is not None else len(data))
    if start >= end:
        print("无有效评估范围")
        return

    results, processed_indices = _load_checkpoint(args.checkpoint_file)
    checkpoint_meta = {
        "input_file": args.input_file,
        "output_file": args.output_file,
        "user_id": args.user_id,
        "range": [start, end],
    }

    timeout = aiohttp.ClientTimeout(total=600)
    connector = aiohttp.TCPConnector(limit=5, limit_per_host=5)
    total_cases = end - start
    print(f"\n开始评估,共 {total_cases} 个cases ({start} 到 {end-1})")
    print("=" * 60)

    async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
        for idx in range(start, end):
            current_num = idx - start + 1
            if idx in processed_indices:
                print(f"[{current_num}/{total_cases}] 跳过已完成的 case {idx}")
                continue

            print(f"[{current_num}/{total_cases}] 处理 case {idx}...", end=" ", flush=True)
            row = await _evaluate_case(session, idx, data[idx], args.user_id)
            results.append(row)
            processed_indices.add(idx)
            print(_format_progress(row))

            # Persist progress after every case so an interrupted run can resume.
            if args.checkpoint_file:
                _save_checkpoint(args.checkpoint_file, results, processed_indices, checkpoint_meta)

    print("=" * 60)
    print(f"评估完成,共处理 {len(results)} 个cases\n")

    metrics = calc_metrics(results)
    base_summary = {
        "api": PROD_SSE_URL,
        "user_id": args.user_id,
        "start_idx": start,
        "end_idx": end,
    }
    _ensure_parent_dir(args.output_file)
    _write_json(args.output_file, {"summary": {**base_summary, "metrics": metrics}, "cases": results})
    print(f"✅ 评估结果已保存: {args.output_file}")
    print(f"📊 总体指标: recall@5={metrics['recall@5']:.3f}, precision@1={metrics['precision@1']:.3f}\n")

    # The full report is on disk, so the checkpoint is no longer needed.
    if args.checkpoint_file and os.path.exists(args.checkpoint_file):
        try:
            os.remove(args.checkpoint_file)
            print(f"🗑️ 已删除断点文件: {args.checkpoint_file}")
        except Exception as e:
            print(f"⚠️ 删除断点文件失败: {e}")

    _report_recall_failures(results, args.output_file, base_summary)
    _report_precision_failures(results, args.output_file, base_summary)
    _report_arg_failures(results, args.output_file, base_summary)


if __name__ == "__main__":
    asyncio.run(main())