| |
| |
| """ |
| 基于生产端 SSE 接口的对齐评估脚本 |
| |
| 功能: |
| - 读取标注数据(包含 conversations 与 pair2 目标工具) |
| - 直接请求生产接口(默认为 http://125.122.38.32:8085/mcp_end2end/stream)获取检索Top5与最终工具调用 |
| - 计算 recall@5 与 precision@1,并输出报告 |
| |
| 使用示例: |
| python eval_via_prod_sse.py \ |
| --input_file /home/ziqiang/LLaMA-Factory/data/dataset/10_27/10.22_evaluate_data.json \ |
| --output_file /home/ziqiang/LLaMA-Factory/data/dataset/10_27/data_evaluation_prod.json \ |
| --start_idx 0 --end_idx 50 |
| """ |
|
|
| import asyncio |
| import aiohttp |
| import argparse |
| import json |
| import os |
| from typing import Any, Dict, List, Optional, Tuple |
| from datetime import datetime |
|
|
|
|
# Production end-to-end SSE endpoint; overridable via the PROD_SSE_URL env var.
PROD_SSE_URL = os.getenv("PROD_SSE_URL", "http://125.122.38.32:8085/mcp_end2end/stream")
# Direct retrieval-tool JSON-RPC endpoint used as a fallback when the SSE
# stream yields no top-5 list; overridable via the RETRIEVAL_ENDPOINT env var.
RETRIEVAL_ENDPOINT = os.getenv("RETRIEVAL_ENDPOINT", "http://125.122.38.32:6227/v1/mcp/tools/call")
|
|
|
|
| def _extract_user_query_from_conversation(item: Dict[str, Any]) -> str: |
| """从标注数据的一条对话中提取原始用户问题(第一条 human)。""" |
| conversations = item.get("conversations", []) |
| for msg in conversations: |
| if msg.get("from") == "human": |
| return str(msg.get("value", "")) |
| return "" |
|
|
|
|
| def _extract_pair2_target_tool(item: Dict[str, Any]) -> Tuple[Optional[str], Optional[Dict[str, Any]]]: |
| """从标注数据中提取 pair2 的目标工具名和参数(查找 human->observation 后、下一条为 function_call 的目标)。""" |
| conversations = item.get("conversations", []) |
| |
| for i, msg in enumerate(conversations): |
| if msg.get("from") == "observation": |
| if i + 1 < len(conversations) and conversations[i + 1].get("from") == "function_call": |
| target_raw = conversations[i + 1].get("value", "") |
| try: |
| target_obj = json.loads(target_raw) |
| return target_obj.get("name"), target_obj.get("arguments", {}) |
| except Exception: |
| |
| return None, None |
| return None, None |
|
|
|
|
async def call_prod_sse_for_case(session: aiohttp.ClientSession, query: str, user_id: str = "1") -> Dict[str, Any]:
    """Call the production SSE endpoint for one case and capture evaluation signals.

    Streams the server-sent events and extracts:
    - ``retrieval_call_params``: the first ``retrieval_tool`` tool_call (its
      arguments can later be replayed against the retrieval endpoint),
    - ``predicted_tool``: the first non-retrieval tool_call (the model's pick),
    - ``retrieved_top5``: up to five retrieved tool names, taken from the
      ``tool_response.completed`` payload or, as a fallback, from the first
      hop of the final answer's ``tool_calling_chain``.

    Returns a summary dict with keys ``retrieved_top5``, ``predicted_tool``,
    ``retrieval_call_params``, ``http_status``, ``event_count`` and, when the
    stream failed or timed out, ``error``. Stream-level errors are captured in
    the result rather than raised.
    """
    payload = {
        "query": query,
        "prompt_template": "standard",
        "user_id": user_id,
        "role_code": 1,
        "user_history": [],
        "save_method": 0,
        "is_vector": False,
        "is_probabilistic": False,
        "use_retrieval": True,
        "tool_category": None,
        "mcp_data": None,
        "front_data": {},
        "message_id": f"eval_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    }

    retrieved_top5: List[str] = []
    predicted_tool: Optional[Dict[str, Any]] = None
    retrieval_call_params: Optional[Dict[str, Any]] = None

    # Accumulated streamed tool-call fragments / last dict result_delta;
    # collected for debugging only, not returned.
    _delta_toolcall_buffer: List[str] = []
    _last_result_delta: Optional[Dict[str, Any]] = None

    headers = {
        "Accept": "text/event-stream",
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "Content-Type": "application/json"
    }

    event_count = 0
    http_status = None
    error_msg = None

    # Per-case limits: abort after max_duration_sec; pad very fast cases to
    # min_duration_sec so the server can flush trailing events before we
    # disconnect.
    max_duration_sec = 30
    min_duration_sec = 3
    start_ts = datetime.now().timestamp()

    async with session.post(PROD_SSE_URL, json=payload, headers=headers) as resp:
        http_status = resp.status
        if http_status != 200:
            try:
                text = await resp.text()
            except Exception:
                text = ""
            return {
                "retrieved_top5": [],
                "predicted_tool": None,
                "retrieval_call_params": None,
                "http_status": http_status,
                "error": f"non-200 response: {http_status}",
                "response_preview": text[:1000]
            }

        current_event = None
        try:
            async for raw in resp.content:
                try:
                    line = raw.decode("utf-8", errors="ignore").strip()
                except Exception:
                    continue
                if not line:
                    continue
                # SSE framing: remember the latest "event:" name, then parse
                # the subsequent "data:" payload under that event name.
                if line.startswith("id:"):
                    continue
                if line.startswith("event:"):
                    current_event = line.split("event:", 1)[1].strip()
                    continue
                if not line.startswith("data:"):
                    continue
                data_str = line.split("data:", 1)[1].strip()
                if data_str == "[DONE]":
                    # Keep the connection open at least min_duration_sec total.
                    if datetime.now().timestamp() - start_ts < min_duration_sec:
                        await asyncio.sleep(min_duration_sec - (datetime.now().timestamp() - start_ts))
                    break
                try:
                    data = json.loads(data_str)
                except Exception:
                    continue

                event_count += 1
                if current_event == "tool_call.create.delta":
                    content = data.get("content")
                    if isinstance(content, str) and content:
                        _delta_toolcall_buffer.append(content)

                elif current_event == "tool_call.created":
                    tool_call = data.get("tool_call")
                    if isinstance(tool_call, dict):
                        name = tool_call.get("name")
                        # First retrieval call is kept for replay; any other
                        # named tool becomes the model's prediction.
                        if name == "retrieval_tool" and retrieval_call_params is None:
                            retrieval_call_params = tool_call
                        elif name and name != "retrieval_tool":
                            predicted_tool = tool_call

                elif current_event == "tool_response.completed":
                    result = data.get("result_delta")
                    if isinstance(result, dict):
                        _last_result_delta = result
                    names: List[str] = []
                    if isinstance(result, list):
                        for entry in result[:5]:
                            if isinstance(entry, dict) and isinstance(entry.get("name"), str):
                                names.append(entry["name"])
                    elif isinstance(result, dict):
                        tools = result.get("tools") or []
                        for entry in tools[:5]:
                            if isinstance(entry, dict) and isinstance(entry.get("name"), str):
                                names.append(entry["name"])
                        # BUGFIX: the tool_calling_chain fallback is only valid
                        # when result is a dict; previously .get() could run on
                        # a list/None result, raising AttributeError and
                        # aborting the whole stream.
                        if not names:
                            chain = result.get("tool_calling_chain") or []
                            if isinstance(chain, list) and chain:
                                first = chain[0]
                                t_resp = first.get("tool_response") if isinstance(first, dict) else None
                                if isinstance(t_resp, list):
                                    for entry in t_resp[:5]:
                                        if isinstance(entry, dict) and isinstance(entry.get("name"), str):
                                            names.append(entry["name"])
                    if names:
                        retrieved_top5 = names

                elif current_event in ("answer.completed", "response.completed"):
                    # Final event: if the stream never yielded a top-5, try the
                    # first hop of the tool_calling_chain (top-level or nested
                    # under round_data).
                    chain = data.get("tool_calling_chain") or {}
                    if not chain:
                        round_data = data.get("round_data") or {}
                        chain = round_data.get("tool_calling_chain") if isinstance(round_data, dict) else {}
                    names = []
                    if isinstance(chain, list) and chain:
                        first = chain[0]
                        t_resp = first.get("tool_response") if isinstance(first, dict) else None
                        if isinstance(t_resp, list):
                            for entry in t_resp[:5]:
                                if isinstance(entry, dict) and isinstance(entry.get("name"), str):
                                    names.append(entry["name"])
                    if names and not retrieved_top5:
                        retrieved_top5 = names

                # Enforce the hard per-case timeout between events.
                if datetime.now().timestamp() - start_ts > max_duration_sec:
                    error_msg = f"timeout after {max_duration_sec}s"
                    break
        except Exception as e:
            error_msg = str(e)

    return {
        "retrieved_top5": retrieved_top5,
        "predicted_tool": predicted_tool,
        "retrieval_call_params": retrieval_call_params,
        "http_status": http_status,
        "event_count": event_count,
        **({"error": error_msg} if error_msg else {})
    }
|
|
|
|
| def _extract_tool_names_from_response(resp_obj: Dict[str, Any], top_k: int = 5) -> List[str]: |
| names: List[str] = [] |
| try: |
| if isinstance(resp_obj.get("result"), list): |
| for item in resp_obj["result"][:top_k]: |
| if isinstance(item, dict) and isinstance(item.get("name"), str): |
| names.append(item["name"]) |
| elif isinstance(resp_obj.get("result"), dict): |
| tools = resp_obj["result"].get("tools") or [] |
| for item in tools[:top_k]: |
| if isinstance(item, dict) and isinstance(item.get("name"), str): |
| names.append(item["name"]) |
| elif isinstance(resp_obj.get("data"), list): |
| for item in resp_obj["data"][:top_k]: |
| if isinstance(item, dict) and isinstance(item.get("name"), str): |
| names.append(item["name"]) |
| if not names: |
| |
| text = json.dumps(resp_obj, ensure_ascii=False) |
| import re as _re |
| names = _re.findall(r'"name"\s*:\s*"([^"]+)"', text)[:top_k] |
| except Exception: |
| pass |
| return names[:top_k] |
|
|
|
|
async def call_retrieval_endpoint(session: aiohttp.ClientSession, retrieval_call_params: Dict[str, Any]) -> Dict[str, Any]:
    """Replay a captured retrieval_tool call directly against the retrieval endpoint.

    Used as a fallback when the SSE stream did not yield a top-5 list.
    Returns a dict with ``retrieval_http_status``, ``retrieved_top5`` and
    either ``retrieval_response_preview`` or ``retrieval_error``; never raises.
    """
    arguments = retrieval_call_params.get("arguments", retrieval_call_params) or {}
    request_body = {
        "jsonrpc": "2.0",
        "id": "eval_req_001",
        "method": "tools/call",
        "params": {
            "name": "retrieval_tool",
            "arguments": arguments
        }
    }
    request_headers = {"Content-Type": "application/json", "Accept": "application/json"}
    status = None
    try:
        async with session.post(RETRIEVAL_ENDPOINT, json=request_body, headers=request_headers) as resp:
            status = resp.status
            try:
                body = await resp.json()
            except Exception:
                # Non-JSON reply — keep a truncated raw preview for the report.
                body = {"raw": (await resp.text())[:2000]}
            top5 = _extract_tool_names_from_response(body, top_k=5) if status == 200 else []
            return {
                "retrieval_http_status": status,
                "retrieval_response_preview": json.dumps(body, ensure_ascii=False)[:2000],
                "retrieved_top5": top5
            }
    except Exception as exc:
        return {
            "retrieval_http_status": status,
            "retrieval_error": str(exc),
            "retrieved_top5": []
        }
|
|
|
|
def calc_metrics(rows: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Aggregate metrics over evaluated rows.

    - recall@5: fraction of all rows whose target tool appears in retrieved_top5.
    - precision@1: among recall-successful rows, fraction where the predicted
      tool equals the target tool.
    - arg_accuracy: among rows with a non-None arg_match, fraction with
      arg_match == 1.

    Returns a dict with the metrics plus their denominators; all-zero metrics
    for an empty input.
    """
    total = len(rows)
    if total == 0:
        return {"total": 0, "recall@5": 0.0, "precision@1": 0.0, "arg_accuracy": 0.0, "arg_denominator": 0}
    recall_hits = 0
    precision_hits = 0
    precision_denominator = 0
    arg_hits = 0
    arg_total = 0
    for r in rows:
        tgt_name = r.get("target_tool_name")
        retrieved = r.get("retrieved_top5", []) or []
        pred = (r.get("predicted_tool") or {}).get("name")
        recall_success = bool(tgt_name) and tgt_name in retrieved
        if recall_success:
            recall_hits += 1
            precision_denominator += 1
            # BUGFIX: only count a precision hit for recall-successful rows.
            # Previously a pred==target match on a non-recalled row inflated
            # the numerator without entering the denominator, allowing
            # precision@1 > 1.0. This also matches the precision-failure
            # reporting in main(), which considers recall-successful rows only.
            if pred and tgt_name == pred:
                precision_hits += 1
        if r.get("arg_match") is not None:
            arg_total += 1
            if r.get("arg_match") == 1:
                arg_hits += 1
    return {
        "total": total,
        "recall@5": recall_hits / total,
        "precision@1": (precision_hits / precision_denominator) if precision_denominator > 0 else 0.0,
        "precision_denominator": precision_denominator,
        "arg_accuracy": (arg_hits / arg_total) if arg_total > 0 else 0.0,
        "arg_denominator": arg_total
    }
|
|
|
|
async def main():
    """Drive the full evaluation: load data, call the prod SSE endpoint per case,
    compute metrics, and write the report plus per-failure breakdown files.

    Supports checkpoint/resume via --checkpoint_file: progress is persisted
    after every case and the checkpoint is removed on successful completion.
    """
    parser = argparse.ArgumentParser(description="基于生产SSE接口的评估")
    parser.add_argument("--input_file", "-i", type=str, required=True)
    parser.add_argument("--output_file", "-o", type=str, required=True)
    parser.add_argument("--start_idx", "-s", type=int, default=0)
    parser.add_argument("--end_idx", "-e", type=int, default=50)
    parser.add_argument("--user_id", type=str, default=os.getenv("EVAL_USER_ID", "1"))
    parser.add_argument("--checkpoint_file", "-c", type=str, default=None,
                        help="断点续跑文件路径;提供则每个case完成后都会写入断点进度")
    args = parser.parse_args()

    # Load the annotated dataset (expected: a JSON list of conversation items).
    with open(args.input_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Clamp the evaluation window to the dataset bounds.
    start = max(0, args.start_idx)
    end = min(len(data), args.end_idx if args.end_idx is not None else len(data))
    if start >= end:
        print("无有效评估范围")
        return

    results: List[Dict[str, Any]] = []
    processed_indices: set = set()

    # Resume from a previous checkpoint if one exists; on any read error fall
    # back to a fresh run.
    if args.checkpoint_file and os.path.exists(args.checkpoint_file):
        try:
            with open(args.checkpoint_file, "r", encoding="utf-8") as f:
                ckpt = json.load(f)
            results = ckpt.get("results", [])
            processed_indices = set(ckpt.get("processed_indices", []))
            print(f"🔁 从断点恢复:已处理 {len(processed_indices)} 个cases;继续评估...")
        except Exception as e:
            print(f"⚠️ 读取断点失败,将从头开始: {e}")
            results = []
            processed_indices = set()

    # Shared HTTP session: generous total timeout for slow SSE streams, small
    # connection pool to avoid hammering the production endpoint.
    timeout = aiohttp.ClientTimeout(total=600)
    connector = aiohttp.TCPConnector(limit=5, limit_per_host=5)
    total_cases = end - start
    print(f"\n开始评估,共 {total_cases} 个cases ({start} 到 {end-1})")
    print("=" * 60)

    async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
        for idx in range(start, end):
            # Skip cases already completed in a previous (checkpointed) run.
            if idx in processed_indices:
                current_num = (idx - start + 1)
                total_cases = end - start
                print(f"[{current_num}/{total_cases}] 跳过已完成的 case {idx}")
                continue
            item = data[idx]
            user_query = _extract_user_query_from_conversation(item)
            target_tool_name, target_args = _extract_pair2_target_tool(item)

            current_num = idx - start + 1
            print(f"[{current_num}/{total_cases}] 处理 case {idx}...", end=" ", flush=True)

            # Call the production SSE endpoint; never let one case abort the run.
            try:
                r = await call_prod_sse_for_case(session, user_query, user_id=args.user_id)
            except Exception as e:
                r = {"error": str(e), "retrieved_top5": [], "predicted_tool": None, "retrieval_call_params": None}
                print(f"❌ SSE调用失败: {str(e)[:50]}")

            # Fallback: if the stream produced no top-5 but captured the
            # retrieval call, replay it directly against the retrieval endpoint.
            if (not r.get("retrieved_top5")) and r.get("retrieval_call_params"):
                try:
                    retrieval_summary = await call_retrieval_endpoint(session, r["retrieval_call_params"])
                    r.update(retrieval_summary)
                except Exception as e:
                    r.update({"retrieval_error": str(e)})

            # recall@5: target tool name appears among the retrieved names.
            retrieved_top5 = r.get("retrieved_top5", []) or []
            recall_value = 1 if (target_tool_name and target_tool_name in retrieved_top5) else 0

            # Argument matching is only evaluated when the case was recalled,
            # the predicted tool equals the target, and target args exist.
            pred_tool_name = (r.get("predicted_tool") or {}).get("name")
            pred_args = (r.get("predicted_tool") or {}).get("arguments") or {}
            tgt_args = target_args or {}
            arg_details = {}

            if recall_value == 1 and pred_tool_name == target_tool_name and isinstance(tgt_args, dict) and tgt_args:
                all_match = True
                for k, v in tgt_args.items():
                    pv = pred_args.get(k)
                    is_match = (pv == v)
                    arg_details[k] = {"target": v, "predict": pv, "match": is_match}
                    if not is_match:
                        all_match = False
                arg_match_value = 1 if all_match else 0
            else:
                # None marks "not evaluable" and is excluded from arg_accuracy.
                arg_match_value = None

            row = {
                "index": idx,
                "user_query": user_query,
                "target_tool_name": target_tool_name,
                "target_arguments": target_args,
                "recall@5": recall_value,
                "arg_match": arg_match_value,
                "arg_match_details": arg_details,
                **r
            }
            results.append(row)
            processed_indices.add(idx)

            # Per-case progress line: retrieval size, recall, prediction, args.
            retrieved_str = f"检索到{len(retrieved_top5)}个工具" if retrieved_top5 else "未检索到工具"
            recall_str = "✅ recall成功" if recall_value == 1 else "❌ recall失败"
            pred_name = (r.get("predicted_tool") or {}).get("name") or "无"
            if arg_match_value is None:
                if recall_value == 0:
                    arg_str = "⏭️ 参数评估跳过(未召回)"
                elif pred_tool_name != target_tool_name:
                    arg_str = f"⏭️ 参数评估跳过(工具不匹配: {pred_name} ≠ {target_tool_name})"
                else:
                    arg_str = "⏭️ 参数评估跳过(无目标参数)"
            else:
                arg_str = "✅ 参数完全匹配" if arg_match_value == 1 else "❌ 参数不匹配"
            print(f"{retrieved_str} | {recall_str} | 预测工具: {pred_name} | {arg_str}")

            # Persist checkpoint after every case so an interrupted run can
            # resume without re-calling the endpoint.
            if args.checkpoint_file:
                try:
                    ckpt_data = {
                        "results": results,
                        "processed_indices": sorted(list(processed_indices)),
                        "meta": {
                            "input_file": args.input_file,
                            "output_file": args.output_file,
                            "user_id": args.user_id,
                            "range": [start, end]
                        }
                    }
                    # NOTE(review): os.path.dirname() is "" for a bare filename,
                    # which would make makedirs raise — confirm checkpoint paths
                    # always include a directory component.
                    os.makedirs(os.path.dirname(args.checkpoint_file), exist_ok=True)
                    with open(args.checkpoint_file, "w", encoding="utf-8") as f:
                        json.dump(ckpt_data, f, ensure_ascii=False, indent=2)
                except Exception as e:
                    print(f"⚠️ 写入断点失败: {e}")

    print("=" * 60)
    print(f"评估完成,共处理 {len(results)} 个cases\n")

    # Aggregate metrics and write the main report.
    metrics = calc_metrics(results)
    report = {
        "summary": {
            "api": PROD_SSE_URL,
            "user_id": args.user_id,
            "start_idx": start,
            "end_idx": end,
            "metrics": metrics
        },
        "cases": results
    }

    # NOTE(review): same dirname("") caveat as the checkpoint write above.
    os.makedirs(os.path.dirname(args.output_file), exist_ok=True)
    with open(args.output_file, "w", encoding="utf-8") as f:
        json.dump(report, f, ensure_ascii=False, indent=2)
    print(f"✅ 评估结果已保存: {args.output_file}")
    print(f"📊 总体指标: recall@5={metrics['recall@5']:.3f}, precision@1={metrics['precision@1']:.3f}\n")

    # Run completed successfully — remove the checkpoint file.
    if args.checkpoint_file and os.path.exists(args.checkpoint_file):
        try:
            os.remove(args.checkpoint_file)
            print(f"🗑️ 已删除断点文件: {args.checkpoint_file}")
        except Exception as e:
            print(f"⚠️ 删除断点文件失败: {e}")

    # Side report 1: cases where recall@5 failed.
    recall_failed_cases = [case for case in results if case.get("recall@5", 0) == 0]
    if recall_failed_cases:
        failed_output_file = args.output_file.replace(".json", "_recall_failed.json")
        failed_report = {
            "summary": {
                "api": PROD_SSE_URL,
                "user_id": args.user_id,
                "start_idx": start,
                "end_idx": end,
                "total_failed": len(recall_failed_cases),
                "total_cases": len(results),
                "failure_rate": len(recall_failed_cases) / len(results) if results else 0.0
            },
            "cases": recall_failed_cases
        }
        with open(failed_output_file, "w", encoding="utf-8") as f:
            json.dump(failed_report, f, ensure_ascii=False, indent=2)
        print(f"❌ Recall失败cases已保存: {failed_output_file} (共 {len(recall_failed_cases)} 条)")
        print(f"   失败率: {len(recall_failed_cases) / len(results) * 100:.1f}%")
    else:
        print("✅ 所有cases的recall@5都成功!")

    # Side report 2: recall-successful cases where the predicted tool missed.
    precision_failed_cases = []
    precision_eligible = 0
    for case in results:
        recall_success = case.get("recall@5", 0) == 1
        if recall_success:
            precision_eligible += 1
        tgt = case.get("target_tool_name")
        pred = (case.get("predicted_tool") or {}).get("name")
        if recall_success and tgt and (pred != tgt):
            precision_failed_cases.append(case)
    if precision_failed_cases:
        precision_failed_output = args.output_file.replace(".json", "_precision_failed.json")
        precision_failed_report = {
            "summary": {
                "api": PROD_SSE_URL,
                "user_id": args.user_id,
                "start_idx": start,
                "end_idx": end,
                "total_failed": len(precision_failed_cases),
                "total_cases": len(results),
                "precision_eligible": precision_eligible,
                "failure_rate": (len(precision_failed_cases) / precision_eligible) if precision_eligible > 0 else 0.0
            },
            "cases": precision_failed_cases
        }
        with open(precision_failed_output, "w", encoding="utf-8") as f:
            json.dump(precision_failed_report, f, ensure_ascii=False, indent=2)
        print(f"❌ Precision@1失败cases已保存: {precision_failed_output} (共 {len(precision_failed_cases)} 条)")
        if precision_eligible > 0:
            print(f"   基于recall@5成功的样本失败率: {len(precision_failed_cases) / precision_eligible * 100:.1f}% (可评估 {precision_eligible} 条)")
    else:
        if precision_eligible > 0:
            print("✅ 所有recall@5成功的cases在precision@1上均命中!")
        else:
            print("ℹ️ 本次无recall@5成功的样本,无法评估precision@1")

    # Side report 3: evaluable cases whose arguments did not fully match.
    arg_failed_cases = [case for case in results if case.get("arg_match") == 0]
    arg_eligible = sum(1 for case in results if case.get("arg_match") is not None)
    if arg_failed_cases:
        arg_failed_output = args.output_file.replace(".json", "_arg_failed.json")
        arg_failed_report = {
            "summary": {
                "api": PROD_SSE_URL,
                "user_id": args.user_id,
                "start_idx": start,
                "end_idx": end,
                "total_failed": len(arg_failed_cases),
                "total_eligible": arg_eligible,
                "eligible_failure_rate": (len(arg_failed_cases) / arg_eligible) if arg_eligible else 0.0
            },
            "cases": arg_failed_cases
        }
        with open(arg_failed_output, "w", encoding="utf-8") as f:
            json.dump(arg_failed_report, f, ensure_ascii=False, indent=2)
        print(f"❌ 参数匹配失败cases已保存: {arg_failed_output} (共 {len(arg_failed_cases)} 条)")
        if arg_eligible:
            print(f"   基于可评估样本失败率: {len(arg_failed_cases) / arg_eligible * 100:.1f}% (可评估 {arg_eligible} 条)")
    else:
        if arg_eligible:
            print("✅ 所有可评估的cases参数完全匹配!")
        else:
            print("ℹ️ 本次无可评估的参数匹配样本(arg_match 皆为 None)")
|
|
|
|
# Script entry point: run the async evaluation driver.
if __name__ == "__main__":
    asyncio.run(main())
|
|
|
|
|
|