#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
基于生产端 SSE 接口的对齐评估脚本
功能:
- 读取标注数据(包含 conversations 与 pair2 目标工具)
- 直接请求生产接口(默认为 http://125.122.38.32:8085/mcp_end2end/stream)获取检索Top5与最终工具调用
- 计算 recall@5 与 precision@1,并输出报告
使用示例:
python eval_via_prod_sse.py \
--input_file /home/ziqiang/LLaMA-Factory/data/dataset/10_27/10.22_evaluate_data.json \
--output_file /home/ziqiang/LLaMA-Factory/data/dataset/10_27/data_evaluation_prod.json \
--start_idx 0 --end_idx 50
"""
import asyncio
import aiohttp
import argparse
import json
import os
from typing import Any, Dict, List, Optional, Tuple
from datetime import datetime
# Endpoints used by the evaluation. Each one can be overridden through the
# environment variable of the same name; the literals are the fallbacks.

# SSE streaming endpoint that every evaluated query is POSTed to.
PROD_SSE_URL = os.environ.get(
    "PROD_SSE_URL",
    "http://125.122.38.32:8085/mcp_end2end/stream",
)

# JSON-RPC retrieval gateway, queried when the SSE stream yields no Top5.
RETRIEVAL_ENDPOINT = os.environ.get(
    "RETRIEVAL_ENDPOINT",
    "http://125.122.38.32:6227/v1/mcp/tools/call",
)
def _extract_user_query_from_conversation(item: Dict[str, Any]) -> str:
    """Return the original user question: the first ``human`` turn of a sample.

    Args:
        item: One annotated sample. Its ``conversations`` entry is read as a
            list of message dicts carrying ``from`` and ``value`` keys.

    Returns:
        The value of the first message whose ``from`` is ``"human"``,
        converted to ``str``. ``""`` is returned when ``conversations`` is
        missing or ``null``, when no human message exists, or when that
        message's value is ``null``.
    """
    # ``or []`` also covers an explicit ``"conversations": null``.
    for msg in item.get("conversations") or []:
        # Entries that are not dicts cannot be human turns; skip them.
        if not isinstance(msg, dict) or msg.get("from") != "human":
            continue
        value = msg.get("value")
        # str(None) would turn a null value into the literal query "None".
        return "" if value is None else str(value)
    return ""
def _extract_pair2_target_tool(item: Dict[str, Any]) -> Tuple[Optional[str], Optional[Dict[str, Any]]]:
    """Extract the annotated "pair2" target tool call from a sample.

    The target is the ``function_call`` message that immediately follows an
    ``observation`` message; the first such adjacent pair decides the result.
    The ``function_call`` value is parsed as a JSON object holding ``name``
    and, optionally, ``arguments``.

    Args:
        item: One annotated sample with a ``conversations`` list.

    Returns:
        ``(name, arguments)`` from the parsed object, where ``arguments``
        defaults to ``{}`` when the key is absent. ``(None, None)`` when there
        is no observation/function_call pair, or when the value of the first
        such pair is not a JSON object.
    """
    # ``or []`` also covers an explicit ``"conversations": null``.
    conversations = item.get("conversations") or []
    for current, following in zip(conversations, conversations[1:]):
        if not (isinstance(current, dict) and isinstance(following, dict)):
            continue
        if current.get("from") != "observation":
            continue
        if following.get("from") != "function_call":
            continue
        raw_value = following.get("value", "")
        try:
            target = json.loads(raw_value)
        except (TypeError, ValueError):
            # Not a JSON string: the annotation is unusable as a target.
            return None, None
        if not isinstance(target, dict):
            # Valid JSON but not an object, so there is no name/arguments.
            return None, None
        return target.get("name"), target.get("arguments", {})
    return None, None
def _sse_tool_names(items: Any, limit: int = 5) -> List[str]:
    """Return the string ``name`` of each dict among the first ``limit`` entries of a list."""
    if not isinstance(items, list):
        return []
    return [
        entry["name"]
        for entry in items[:limit]
        if isinstance(entry, dict) and isinstance(entry.get("name"), str)
    ]


def _sse_first_hop_names(chain: Any) -> List[str]:
    """Return tool names from the ``tool_response`` list of the first ``tool_calling_chain`` entry."""
    if not isinstance(chain, list) or not chain:
        return []
    first_hop = chain[0]
    if not isinstance(first_hop, dict):
        return []
    return _sse_tool_names(first_hop.get("tool_response"))


async def call_prod_sse_for_case(session: aiohttp.ClientSession, query: str, user_id: str = "1") -> Dict[str, Any]:
    """Send one query to the production SSE endpoint and summarize the stream.

    The response is read line by line as server-sent events. While reading,
    the function records:

    * ``retrieval_call_params``: the first ``tool_call.created`` event whose
      tool name is ``retrieval_tool``;
    * ``predicted_tool``: the most recent ``tool_call.created`` event for any
      other tool;
    * ``retrieved_top5``: the first non-empty list of candidate tool names,
      taken from ``tool_response.completed`` events or, failing that, from
      the ``tool_calling_chain`` of ``answer.completed`` /
      ``response.completed`` events.

    Args:
        session: aiohttp client session used for the POST request.
        query: User question to evaluate.
        user_id: Value sent as ``user_id`` in the request payload.

    Returns:
        A dict with ``retrieved_top5``, ``predicted_tool``,
        ``retrieval_call_params``, ``http_status`` and ``event_count``, plus
        ``error`` when the stream failed or timed out. For a non-200 response
        the dict carries ``error`` and ``response_preview`` and has no
        ``event_count``.
    """
    payload = {
        "query": query,
        "prompt_template": "standard",
        "user_id": user_id,
        "role_code": 1,  # required field
        "user_history": [],
        "save_method": 0,
        "is_vector": False,
        "is_probabilistic": False,
        "use_retrieval": True,
        "tool_category": None,
        "mcp_data": None,  # must be a string or None, not an object
        "front_data": {},
        "message_id": f"eval_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    }
    headers = {
        "Accept": "text/event-stream",
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "Content-Type": "application/json"
    }

    retrieved_top5: List[str] = []
    predicted_tool: Optional[Dict[str, Any]] = None
    retrieval_call_params: Optional[Dict[str, Any]] = None
    event_count = 0
    error_msg = None

    # Upper bound on stream reading, and minimum time spent on a case before
    # returning after "[DONE]".
    max_duration_sec = 30
    min_duration_sec = 3
    # The event-loop clock is monotonic, so system clock adjustments cannot
    # distort the elapsed-time checks.
    clock = asyncio.get_running_loop().time
    started = clock()

    async with session.post(PROD_SSE_URL, json=payload, headers=headers) as resp:
        http_status = resp.status
        if http_status != 200:
            try:
                text = await resp.text()
            except Exception:
                # The body is only a diagnostic preview; failing to read it
                # must not hide the status error.
                text = ""
            return {
                "retrieved_top5": [],
                "predicted_tool": None,
                "retrieval_call_params": None,
                "http_status": http_status,
                "error": f"non-200 response: {http_status}",
                "response_preview": text[:1000]
            }

        current_event = None
        try:
            async for raw in resp.content:
                line = raw.decode("utf-8", errors="ignore").strip()

                if line.startswith("event:"):
                    current_event = line.split("event:", 1)[1].strip()
                elif line.startswith("data:"):
                    data_str = line.split("data:", 1)[1].strip()
                    if data_str == "[DONE]":
                        # If the stream finished very quickly, wait out the
                        # minimum dwell time before leaving.
                        remaining = min_duration_sec - (clock() - started)
                        if remaining > 0:
                            await asyncio.sleep(remaining)
                        break

                    try:
                        data = json.loads(data_str)
                    except ValueError:
                        data = None
                    else:
                        event_count += 1

                    # Only JSON objects carry the fields read below; other
                    # payloads are counted but otherwise ignored. Event types
                    # not listed here (e.g. tool_call.create.delta) contribute
                    # nothing to the returned summary.
                    if isinstance(data, dict):
                        if current_event == "tool_call.created":
                            tool_call = data.get("tool_call")
                            if isinstance(tool_call, dict):
                                name = tool_call.get("name")
                                if name == "retrieval_tool":
                                    # Keep only the first retrieval call.
                                    if retrieval_call_params is None:
                                        retrieval_call_params = tool_call
                                elif name:
                                    predicted_tool = tool_call
                        elif current_event == "tool_response.completed":
                            result = data.get("result_delta")
                            names: List[str] = []
                            if isinstance(result, list):
                                names = _sse_tool_names(result)
                            elif isinstance(result, dict):
                                # Fall back to the tool_calling_chain when the
                                # "tools" list yields no names.
                                names = (
                                    _sse_tool_names(result.get("tools"))
                                    or _sse_first_hop_names(result.get("tool_calling_chain"))
                                )
                            # Keep the first non-empty candidate list so that
                            # responses of later tool calls cannot overwrite
                            # the first-hop retrieval result.
                            if names and not retrieved_top5:
                                retrieved_top5 = names
                        elif current_event in ("answer.completed", "response.completed"):
                            # Further fallback: the chain may sit at the top
                            # level of the event or under "round_data".
                            chain = data.get("tool_calling_chain")
                            if not chain:
                                round_data = data.get("round_data")
                                chain = round_data.get("tool_calling_chain") if isinstance(round_data, dict) else None
                            names = _sse_first_hop_names(chain)
                            if names and not retrieved_top5:
                                retrieved_top5 = names
                # Blank, "id:" and any other lines need no handling.

                # Timeout guard, evaluated for every received line (including
                # keep-alives). NOTE: it only runs when a line arrives; a
                # completely silent connection is bounded solely by the
                # session-level timeout.
                if clock() - started > max_duration_sec:
                    error_msg = f"timeout after {max_duration_sec}s"
                    break
        except Exception as e:
            # Best effort: report the stream failure together with whatever
            # was collected before it.
            error_msg = str(e)

    return {
        "retrieved_top5": retrieved_top5,
        "predicted_tool": predicted_tool,
        "retrieval_call_params": retrieval_call_params,
        "http_status": http_status,
        "event_count": event_count,
        **({"error": error_msg} if error_msg else {})
    }
def _extract_tool_names_from_response(resp_obj: Dict[str, Any], top_k: int = 5) -> List[str]:
    """Extract up to ``top_k`` tool names from a retrieval response.

    The candidate list is taken from the first matching shape:

    1. ``resp_obj["result"]`` is a list of tool dicts;
    2. ``resp_obj["result"]`` is a dict whose ``tools`` entry is such a list;
    3. ``resp_obj["data"]`` is such a list.

    When none of these yields a name, every ``"name": "..."`` pair found in
    the JSON serialization of the whole response is used instead.

    Args:
        resp_obj: Decoded response body.
        top_k: Maximum number of names to return.

    Returns:
        The extracted names, or ``[]`` when ``resp_obj`` is not a dict or
        contains no names.
    """
    # Kept as a function-local import, as in the original implementation.
    import re as _re

    if not isinstance(resp_obj, dict):
        return []

    result = resp_obj.get("result")
    if isinstance(result, list):
        candidates: Any = result
    elif isinstance(result, dict):
        candidates = result.get("tools") or []
    elif isinstance(resp_obj.get("data"), list):
        candidates = resp_obj["data"]
    else:
        candidates = []

    names: List[str] = []
    if isinstance(candidates, list):
        names = [
            entry["name"]
            for entry in candidates[:top_k]
            if isinstance(entry, dict) and isinstance(entry.get("name"), str)
        ]

    if not names:
        # Last resort: scan the serialized response for "name": "..." pairs.
        try:
            text = json.dumps(resp_obj, ensure_ascii=False)
        except (TypeError, ValueError):
            # The response is not JSON-serializable; nothing to scan.
            return []
        names = _re.findall(r'"name"\s*:\s*"([^"]+)"', text)

    return names[:top_k]
async def call_retrieval_endpoint(session: aiohttp.ClientSession, retrieval_call_params: Dict[str, Any]) -> Dict[str, Any]:
    """Replay a retrieval tool call against the retrieval gateway.

    Sends a JSON-RPC ``tools/call`` request for ``retrieval_tool`` to
    ``RETRIEVAL_ENDPOINT`` and extracts the tool names from the response.

    Args:
        session: aiohttp client session used for the POST request.
        retrieval_call_params: Either a tool-call dict with an ``arguments``
            entry, or the arguments dict itself.

    Returns:
        On a completed request: ``retrieval_http_status``,
        ``retrieval_response_preview`` (first 2000 characters of the decoded
        body) and ``retrieved_top5`` (empty unless the status is 200).
        On any exception: ``retrieval_http_status``, ``retrieval_error`` and
        an empty ``retrieved_top5``.
    """
    arguments = retrieval_call_params.get("arguments", retrieval_call_params) or {}
    if isinstance(arguments, str):
        # Arguments may arrive JSON-encoded; forward them as an object when
        # they decode to one, otherwise leave the string untouched.
        try:
            decoded = json.loads(arguments)
        except ValueError:
            decoded = None
        if isinstance(decoded, dict):
            arguments = decoded

    payload = {
        "jsonrpc": "2.0",
        "id": "eval_req_001",
        "method": "tools/call",
        "params": {
            "name": "retrieval_tool",
            "arguments": arguments
        }
    }
    headers = {"Content-Type": "application/json", "Accept": "application/json"}
    status = None
    try:
        async with session.post(RETRIEVAL_ENDPOINT, json=payload, headers=headers) as resp:
            status = resp.status
            body_text = await resp.text()
        # Decode the body ourselves: resp.json() refuses JSON that is served
        # with a non-JSON content type, which would hide the tool names.
        try:
            data = json.loads(body_text)
        except ValueError:
            data = {"raw": body_text[:2000]}
        names = _extract_tool_names_from_response(data, top_k=5) if status == 200 else []
        return {
            "retrieval_http_status": status,
            "retrieval_response_preview": json.dumps(data, ensure_ascii=False)[:2000],
            "retrieved_top5": names
        }
    except Exception as e:
        # Best effort: the caller records the failure and carries on.
        return {
            "retrieval_http_status": status,
            "retrieval_error": str(e),
            "retrieved_top5": []
        }
def calc_metrics(rows: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Aggregate per-case rows into the overall evaluation metrics.

    Args:
        rows: Per-case result dicts. The keys read are ``target_tool_name``,
            ``retrieved_top5``, ``predicted_tool`` (a dict with ``name``, or
            ``None``) and ``arg_match`` (``1``, ``0`` or ``None``).

    Returns:
        A dict with:

        * ``total``: number of rows;
        * ``recall@5``: share of rows whose target tool is in
          ``retrieved_top5``;
        * ``precision@1``: among those recall hits, the share whose predicted
          tool name equals the target (``0.0`` without recall hits);
        * ``precision_denominator``: number of recall hits;
        * ``arg_accuracy``: share of rows with ``arg_match == 1`` among rows
          whose ``arg_match`` is not ``None`` (``0.0`` if there are none);
        * ``arg_denominator``: number of rows whose ``arg_match`` is not
          ``None``.

        An empty ``rows`` list yields the same keys with zero values.
    """
    total = len(rows)
    if total == 0:
        # Same key set as the regular result so consumers need no special case.
        return {
            "total": 0,
            "recall@5": 0.0,
            "precision@1": 0.0,
            "precision_denominator": 0,
            "arg_accuracy": 0.0,
            "arg_denominator": 0,
        }

    recall_hits = 0
    precision_hits = 0
    arg_hits = 0
    arg_total = 0
    for row in rows:
        target_name = row.get("target_tool_name")
        retrieved = row.get("retrieved_top5") or []
        predicted_name = (row.get("predicted_tool") or {}).get("name")

        # precision@1 is only evaluated on cases where recall@5 succeeded.
        if target_name and target_name in retrieved:
            recall_hits += 1
            if predicted_name and predicted_name == target_name:
                precision_hits += 1

        # Rows whose arguments were not evaluated carry arg_match=None.
        arg_match = row.get("arg_match")
        if arg_match is not None:
            arg_total += 1
            if arg_match == 1:
                arg_hits += 1

    return {
        "total": total,
        "recall@5": recall_hits / total,
        "precision@1": (precision_hits / recall_hits) if recall_hits > 0 else 0.0,
        "precision_denominator": recall_hits,
        "arg_accuracy": (arg_hits / arg_total) if arg_total > 0 else 0.0,
        "arg_denominator": arg_total
    }
def _ensure_parent_dir(path: str) -> None:
    """Create the parent directory of ``path``; a bare filename needs none."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def _dump_json_file(path: str, payload: Any) -> None:
    """Write ``payload`` to ``path`` as indented UTF-8 JSON."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)


def _save_checkpoint(path: str, payload: Dict[str, Any]) -> None:
    """Write the checkpoint via a temporary file so an interrupted write cannot corrupt it."""
    _ensure_parent_dir(path)
    tmp_path = f"{path}.tmp"
    _dump_json_file(tmp_path, payload)
    os.replace(tmp_path, path)


def _derived_report_path(output_file: str, suffix: str) -> str:
    """Insert ``suffix`` before the extension of ``output_file`` (``.json`` if it has none)."""
    root, ext = os.path.splitext(output_file)
    return f"{root}{suffix}{ext or '.json'}"


def _as_argument_dict(value: Any) -> Dict[str, Any]:
    """Return tool-call arguments as a dict, decoding a JSON string; ``{}`` if that is impossible."""
    if isinstance(value, str):
        try:
            value = json.loads(value)
        except ValueError:
            return {}
    return value if isinstance(value, dict) else {}


async def main():
    """Command-line entry point of the evaluation.

    Steps:

    1. Parse the CLI arguments and load the annotated input file.
    2. Optionally resume from a checkpoint file.
    3. For every case in ``[start_idx, end_idx)``: query the SSE endpoint,
       fall back to the retrieval gateway when the stream gave no Top5,
       compute recall@5 and the argument match, and save the checkpoint.
    4. Write the full report, delete the checkpoint, and write separate
       reports for recall failures, precision@1 failures and argument
       mismatches (each next to the main report).
    """
    parser = argparse.ArgumentParser(description="基于生产SSE接口的评估")
    parser.add_argument("--input_file", "-i", type=str, required=True)
    parser.add_argument("--output_file", "-o", type=str, required=True)
    parser.add_argument("--start_idx", "-s", type=int, default=0)
    parser.add_argument("--end_idx", "-e", type=int, default=50)
    parser.add_argument("--user_id", type=str, default=os.getenv("EVAL_USER_ID", "1"))
    parser.add_argument("--checkpoint_file", "-c", type=str, default=None,
                        help="断点续跑文件路径;提供则每个case完成后都会写入断点进度")
    args = parser.parse_args()

    with open(args.input_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    start = max(0, args.start_idx)
    end = min(len(data), args.end_idx if args.end_idx is not None else len(data))
    if start >= end:
        print("无有效评估范围")
        return

    results: List[Dict[str, Any]] = []
    processed_indices: set = set()

    # Resume from the checkpoint file when one exists.
    if args.checkpoint_file and os.path.exists(args.checkpoint_file):
        try:
            with open(args.checkpoint_file, "r", encoding="utf-8") as f:
                ckpt = json.load(f)
            results = ckpt.get("results", [])
            processed_indices = set(ckpt.get("processed_indices", []))
            print(f"🔁 从断点恢复:已处理 {len(processed_indices)} 个cases;继续评估...")
        except Exception as e:
            # An unreadable checkpoint must not block the run; start over.
            print(f"⚠️ 读取断点失败,将从头开始: {e}")
            results = []
            processed_indices = set()

    timeout = aiohttp.ClientTimeout(total=600)
    connector = aiohttp.TCPConnector(limit=5, limit_per_host=5)
    total_cases = end - start
    print(f"\n开始评估,共 {total_cases} 个cases ({start} 到 {end-1})")
    print("=" * 60)

    async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
        for idx in range(start, end):
            current_num = idx - start + 1
            if idx in processed_indices:
                # Already evaluated in a previous (checkpointed) run.
                print(f"[{current_num}/{total_cases}] 跳过已完成的 case {idx}")
                continue

            item = data[idx]
            user_query = _extract_user_query_from_conversation(item)
            target_tool_name, target_args = _extract_pair2_target_tool(item)
            print(f"[{current_num}/{total_cases}] 处理 case {idx}...", end=" ", flush=True)

            try:
                r = await call_prod_sse_for_case(session, user_query, user_id=args.user_id)
            except Exception as e:
                r = {"error": str(e), "retrieved_top5": [], "predicted_tool": None, "retrieval_call_params": None}
                print(f"❌ SSE调用失败: {str(e)[:50]}")

            # If the SSE stream gave no Top5, replay the first-hop retrieval
            # call against the retrieval gateway to obtain it.
            if (not r.get("retrieved_top5")) and r.get("retrieval_call_params"):
                try:
                    retrieval_summary = await call_retrieval_endpoint(session, r["retrieval_call_params"])
                    r.update(retrieval_summary)
                except Exception as e:
                    r.update({"retrieval_error": str(e)})

            # recall@5 of this case: 1 on success, 0 on failure.
            retrieved_top5 = r.get("retrieved_top5", []) or []
            recall_value = 1 if (target_tool_name and target_tool_name in retrieved_top5) else 0

            # Argument match of the second-round function call against the
            # annotated target arguments. It is only meaningful when recall
            # succeeded and the predicted tool name equals the target.
            predicted_tool = r.get("predicted_tool") or {}
            pred_tool_name = predicted_tool.get("name")
            # Predicted arguments may be a JSON string or another non-dict
            # value; normalize so the key lookups below cannot crash the run.
            pred_args = _as_argument_dict(predicted_tool.get("arguments"))
            tgt_args = target_args or {}
            arg_details = {}
            if recall_value == 1 and pred_tool_name == target_tool_name and isinstance(tgt_args, dict) and tgt_args:
                for k, v in tgt_args.items():
                    pv = pred_args.get(k)
                    arg_details[k] = {"target": v, "predict": pv, "match": pv == v}
                all_match = all(detail["match"] for detail in arg_details.values())
                arg_match_value = 1 if all_match else 0
            else:
                # Skipped: tool not recalled, tool name mismatch, or no
                # target arguments.
                arg_match_value = None

            row = {
                "index": idx,
                "user_query": user_query,
                "target_tool_name": target_tool_name,
                "target_arguments": target_args,
                "recall@5": recall_value,  # 1 = success, 0 = failure
                "arg_match": arg_match_value,  # 1/0, or None when skipped
                "arg_match_details": arg_details,
                **r
            }
            results.append(row)
            processed_indices.add(idx)

            # Progress line for this case.
            retrieved_str = f"检索到{len(retrieved_top5)}个工具" if retrieved_top5 else "未检索到工具"
            recall_str = "✅ recall成功" if recall_value == 1 else "❌ recall失败"
            pred_name = pred_tool_name or "无"
            if arg_match_value is None:
                if recall_value == 0:
                    arg_str = "⏭️ 参数评估跳过(未召回)"
                elif pred_tool_name != target_tool_name:
                    arg_str = f"⏭️ 参数评估跳过(工具不匹配: {pred_name} ≠ {target_tool_name})"
                else:
                    arg_str = "⏭️ 参数评估跳过(无目标参数)"
            else:
                arg_str = "✅ 参数完全匹配" if arg_match_value == 1 else "❌ 参数不匹配"
            print(f"{retrieved_str} | {recall_str} | 预测工具: {pred_name} | {arg_str}")

            # Persist the checkpoint after every case.
            if args.checkpoint_file:
                try:
                    _save_checkpoint(args.checkpoint_file, {
                        "results": results,
                        "processed_indices": sorted(processed_indices),
                        "meta": {
                            "input_file": args.input_file,
                            "output_file": args.output_file,
                            "user_id": args.user_id,
                            "range": [start, end]
                        }
                    })
                except Exception as e:
                    # A failed checkpoint write must not abort the evaluation.
                    print(f"⚠️ 写入断点失败: {e}")

    print("=" * 60)
    print(f"评估完成,共处理 {len(results)} 个cases\n")

    metrics = calc_metrics(results)
    # Fields shared by the summary of every report written below.
    base_summary = {
        "api": PROD_SSE_URL,
        "user_id": args.user_id,
        "start_idx": start,
        "end_idx": end,
    }
    report = {
        "summary": {**base_summary, "metrics": metrics},
        "cases": results
    }
    _ensure_parent_dir(args.output_file)
    _dump_json_file(args.output_file, report)
    print(f"✅ 评估结果已保存: {args.output_file}")
    print(f"📊 总体指标: recall@5={metrics['recall@5']:.3f}, precision@1={metrics['precision@1']:.3f}\n")

    # The run completed, so the checkpoint is no longer needed.
    if args.checkpoint_file and os.path.exists(args.checkpoint_file):
        try:
            os.remove(args.checkpoint_file)
            print(f"🗑️ 已删除断点文件: {args.checkpoint_file}")
        except Exception as e:
            print(f"⚠️ 删除断点文件失败: {e}")

    # Cases with recall@5 == 0 go to their own report.
    recall_failed_cases = [case for case in results if case.get("recall@5", 0) == 0]
    if recall_failed_cases:
        failed_output_file = _derived_report_path(args.output_file, "_recall_failed")
        failed_report = {
            "summary": {
                **base_summary,
                "total_failed": len(recall_failed_cases),
                "total_cases": len(results),
                "failure_rate": len(recall_failed_cases) / len(results) if results else 0.0
            },
            "cases": recall_failed_cases
        }
        _dump_json_file(failed_output_file, failed_report)
        print(f"❌ Recall失败cases已保存: {failed_output_file} (共 {len(recall_failed_cases)} 条)")
        print(f" 失败率: {len(recall_failed_cases) / len(results) * 100:.1f}%")
    else:
        print("✅ 所有cases的recall@5都成功!")

    # precision@1 failures: recall@5 succeeded but the predicted tool name
    # differs from the target. The denominator is the number of recall hits.
    recall_ok_cases = [case for case in results if case.get("recall@5", 0) == 1]
    precision_eligible = len(recall_ok_cases)
    precision_failed_cases = [
        case for case in recall_ok_cases
        if case.get("target_tool_name")
        and (case.get("predicted_tool") or {}).get("name") != case.get("target_tool_name")
    ]
    if precision_failed_cases:
        precision_failed_output = _derived_report_path(args.output_file, "_precision_failed")
        precision_failed_report = {
            "summary": {
                **base_summary,
                "total_failed": len(precision_failed_cases),
                "total_cases": len(results),
                "precision_eligible": precision_eligible,
                "failure_rate": (len(precision_failed_cases) / precision_eligible) if precision_eligible > 0 else 0.0
            },
            "cases": precision_failed_cases
        }
        _dump_json_file(precision_failed_output, precision_failed_report)
        print(f"❌ Precision@1失败cases已保存: {precision_failed_output} (共 {len(precision_failed_cases)} 条)")
        if precision_eligible > 0:
            print(f" 基于recall@5成功的样本失败率: {len(precision_failed_cases) / precision_eligible * 100:.1f}% (可评估 {precision_eligible} 条)")
    else:
        if precision_eligible > 0:
            print("✅ 所有recall@5成功的cases在precision@1上均命中!")
        else:
            print("ℹ️ 本次无recall@5成功的样本,无法评估precision@1")

    # Argument mismatches: only evaluated cases (arg_match == 0) count.
    arg_failed_cases = [case for case in results if case.get("arg_match") == 0]
    arg_eligible = sum(1 for case in results if case.get("arg_match") is not None)
    if arg_failed_cases:
        arg_failed_output = _derived_report_path(args.output_file, "_arg_failed")
        arg_failed_report = {
            "summary": {
                **base_summary,
                "total_failed": len(arg_failed_cases),
                "total_eligible": arg_eligible,
                "eligible_failure_rate": (len(arg_failed_cases) / arg_eligible) if arg_eligible else 0.0
            },
            "cases": arg_failed_cases
        }
        _dump_json_file(arg_failed_output, arg_failed_report)
        print(f"❌ 参数匹配失败cases已保存: {arg_failed_output} (共 {len(arg_failed_cases)} 条)")
        if arg_eligible:
            print(f" 基于可评估样本失败率: {len(arg_failed_cases) / arg_eligible * 100:.1f}% (可评估 {arg_eligible} 条)")
    else:
        if arg_eligible:
            print("✅ 所有可评估的cases参数完全匹配!")
        else:
            print("ℹ️ 本次无可评估的参数匹配样本(arg_match 皆为 None)")
if __name__ == "__main__":
    # Run the asynchronous evaluation driver when executed as a script.
    asyncio.run(main())