#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
模型评估脚本
功能:
1. 读取JSON文件中的对话数据
2. 提取human的value作为query调用server:8020
3. 处理流式返回结果
4. 对比和存储结果
"""
import json
import httpx
import asyncio
import time
import re
import os
from typing import Dict, List, Any
from utils.custom_logging import setup_logging
from utils.extraction import extract_json_from_string
from loguru import logger
from collections import Counter
setup_logging()
class ModelEvaluator:
def __init__(self, server_url: str = "http://localhost:8020/mcp_end2end/stream"):
self.server_url = server_url
self.results = []
self.client = None
self.start_time = None
self.error_count = 0
self.success_count = 0
def load_data(self, file_path: str) -> List[Dict]:
"""加载JSON数据文件"""
try:
with open(file_path, 'r', encoding='utf-8') as f:
data = json.load(f)
logger.info(f"成功加载数据文件,共{len(data)}条记录")
return data
except Exception as e:
logger.error(f"加载数据文件失败: {e}")
return []
def extract_human_queries(self, data: List[Dict]) -> List[Dict]:
"""提取所有human的value作为query"""
queries = []
for i, item in enumerate(data):
if 'conversations' in item:
for conv in item['conversations']:
if conv.get('from') == 'human':
query_data = {
'index': i,
'query': conv.get('value', ''),
'original_data': item
}
queries.append(query_data)
break # 只取第一个human的value
logger.info(f"提取到{len(queries)}个查询")
return queries
def parse_sse_events(self, sse_content: str, filter_events: List[str] = None) -> List[Dict]:
"""
解析SSE格式的内容,提取指定类型的事件
Args:
sse_content: SSE格式的文本内容(可以是多行)
filter_events: 需要过滤的事件类型列表,如果为None则解析所有事件
Returns:
解析成功的事件列表
"""
events = []
current_event = {}
parsed_count = 0
failed_count = 0
# 如果没有指定过滤事件,设置默认过滤
if filter_events is None:
filter_events = ['tool_call.created', 'tool_response.completed']
for line in sse_content.split('\n'):
line = line.strip()
if not line:
# 空行表示一个完整的事件结束
if current_event and 'event' in current_event and 'data' in current_event:
event_type = current_event['event']
# 只处理我们关心的事件类型
if event_type in filter_events:
# 使用extract_json_from_string解析data字段中的JSON
data_content = extract_json_from_string(current_event['data'])
if data_content is not None:
event_obj = {
'id': current_event.get('id'),
'event': current_event['event'],
'data': data_content
}
events.append(event_obj)
parsed_count += 1
logger.debug(f"✅ 成功解析事件: {current_event['event']} (ID: {current_event.get('id', 'N/A')})")
else:
failed_count += 1
logger.warning(f"❌ 无法解析事件数据: {current_event['event']} - {current_event['data'][:100]}...")
current_event = {}
continue
# 解析SSE格式的字段
if line.startswith('id: '):
current_event['id'] = line[4:]
elif line.startswith('event: '):
current_event['event'] = line[7:]
elif line.startswith('data: '):
current_event['data'] = line[6:]
# 检查是否是结束标记
if current_event['data'].strip() == '[DONE]':
logger.debug("收到结束标记 [DONE]")
break
else:
logger.debug(f"未知格式的行: {line}")
# 处理最后一个事件(如果没有空行结尾)
if current_event and 'event' in current_event and 'data' in current_event:
event_type = current_event['event']
# 只处理我们关心的事件类型
if event_type in filter_events:
data_content = extract_json_from_string(current_event['data'])
if data_content is not None:
event_obj = {
'id': current_event.get('id'),
'event': current_event['event'],
'data': data_content
}
events.append(event_obj)
parsed_count += 1
logger.debug(f"✅ 成功解析最后一个事件: {current_event['event']} (ID: {current_event.get('id', 'N/A')})")
else:
failed_count += 1
logger.warning(f"❌ 无法解析最后一个事件数据: {current_event['event']} - {current_event['data'][:100]}...")
# 统计和日志输出
logger.info(f"=== SSE解析结果统计 ===")
logger.info(f"成功解析事件数: {parsed_count}")
logger.info(f"解析失败事件数: {failed_count}")
logger.info(f"总事件数: {len(events)}")
if events:
event_types = [event.get('event', 'unknown') for event in events]
event_counts = Counter(event_types)
logger.info(f"事件类型分布: {dict(event_counts)}")
else:
logger.warning("⚠️ 未解析到任何目标事件")
return events
async def call_server(self, query: str, max_retries: int = 3, retry_delay: float = 2.0) -> List[Dict]:
"""异步调用server:8020端口,处理流式返回,支持重试机制"""
payload = {
"user_id": "166",
"role_code": 1,
"query": query,
"save_method": 0
}
for attempt in range(max_retries):
try:
async with httpx.AsyncClient(timeout=30.0) as client:
async with client.stream(
'POST',
self.server_url,
json=payload,
headers={'Accept': 'text/event-stream'}
) as response:
response.raise_for_status()
# 收集所有SSE文本内容
sse_content = ""
async for line in response.aiter_text():
logger.debug(f"Received data: {line}")
sse_content += line
# 检查是否收到结束标记
if '[DONE]' in line:
logger.debug("收到结束标记 [DONE]")
break
# 使用封装的方法解析SSE内容,只解析我们关心的事件
events = self.parse_sse_events(
sse_content,
filter_events=['tool_call.created', 'tool_response.completed']
)
# 验证关键事件类型
has_tool_call = any(event.get('event') == 'tool_call.created' for event in events)
has_tool_response = any(event.get('event') == 'tool_response.completed' for event in events)
logger.info(f"包含工具调用事件: {'✅' if has_tool_call else '❌'}")
logger.info(f"包含工具响应事件: {'✅' if has_tool_response else '❌'}")
return events
except httpx.RequestError as e:
logger.warning(f"Call server failed (attempt {attempt + 1}/{max_retries}): {e}")
if attempt < max_retries - 1:
logger.info(f"Retrying in {retry_delay} seconds...")
await asyncio.sleep(retry_delay)
retry_delay *= 2 # 指数退避
else:
logger.error(f"All retry attempts failed for query: {query[:50]}...")
raise Exception(f"Server connection failed after {max_retries} attempts: {e}")
except httpx.TimeoutException as e:
logger.warning(f"Server timeout (attempt {attempt + 1}/{max_retries}): {e}")
if attempt < max_retries - 1:
logger.info(f"Retrying in {retry_delay} seconds...")
await asyncio.sleep(retry_delay)
retry_delay *= 2
else:
logger.error(f"Timeout after all retry attempts for query: {query[:50]}...")
raise Exception(f"Server timeout after {max_retries} attempts: {e}")
except Exception as e:
logger.error(f"Unexpected error processing response (attempt {attempt + 1}/{max_retries}): {e}")
if attempt < max_retries - 1:
logger.info(f"Retrying in {retry_delay} seconds...")
await asyncio.sleep(retry_delay)
retry_delay *= 2
else:
logger.error(f"Unexpected error after all retry attempts for query: {query[:50]}...")
raise Exception(f"Unexpected error after {max_retries} attempts: {e}")
raise Exception("All retry attempts exhausted")
def extract_tool_calls_and_observations(self, events: List[Dict]) -> Dict[str, List]:
"""Extract tool_call.created and tool_response.completed content from events"""
tool_calls = []
tool_responses = []
logger.debug(f"开始提取工具调用和响应,共 {len(events)} 个事件")
for event in events:
event_type = event.get('event')
event_data = event.get('data', {})
if event_type == 'tool_call.created':
logger.debug(f"Extract tool_call.created content: {event}")
# Extract tool_call information
tool_call_info = event_data.get('tool_call', {})
if tool_call_info:
tool_calls.append(tool_call_info) # 直接存储JSON对象
logger.debug(f"✅ 提取工具调用: {tool_call_info.get('name', 'unknown')}")
else:
logger.warning(f"❌ tool_call.created 事件中缺少 tool_call 信息")
elif event_type == 'tool_response.completed':
logger.debug(f"Extract tool_response.completed content: {event}")
# Extract tool_response information
if 'result_delta' in event_data:
tool_response = event_data['result_delta'].get('result', [])
tool_responses.append(tool_response) # 直接存储JSON对象
logger.debug(f"✅ 提取工具响应: {len(str(tool_response))} 字符")
else:
tool_response = []
tool_responses.append(tool_response) # 直接存储JSON对象
logger.info(f"Extract {len(tool_calls)} tool calls, {len(tool_responses)} tool responses")
return {
'tool_calls': tool_calls,
'tool_responses': tool_responses
}
def extract_original_data(self, original_data: Dict) -> Dict[str, List]:
"""Extract function_call and observation content from original data"""
function_calls = []
observations = []
if 'conversations' in original_data:
for conv in original_data['conversations']:
if conv.get('from') == 'function_call':
# 解析JSON字符串为对象
try:
function_call_obj = json.loads(conv.get('value', '{}'))
function_calls.append(function_call_obj)
except json.JSONDecodeError as e:
logger.warning(f"解析function_call JSON时出错: {e}")
function_calls.append({})
elif conv.get('from') == 'observation':
# 解析JSON字符串为对象
try:
observation_obj = json.loads(conv.get('value', '[]'))
observations.append(observation_obj)
except json.JSONDecodeError as e:
logger.warning(f"解析observation JSON时出错: {e}")
observations.append([])
return {
'function_calls': function_calls,
'observations': observations
}
def compare_tool_call(self, server_call: Dict, original_call: Dict) -> Dict:
"""比较单个工具调用,检查name和arguments的匹配度"""
try:
# 检查name是否一致
name_match = server_call.get('name') == original_call.get('name')
name_score = 1.0 if name_match else 0.0
# 检查arguments是否一致
server_args = server_call.get('arguments', {})
original_args = original_call.get('arguments', {})
arguments_match = server_args == original_args
arguments_score = 1.0 if arguments_match else 0.0
return {
'name_match': name_match,
'name_score': name_score,
'arguments_match': arguments_match,
'arguments_score': arguments_score,
'server_name': server_call.get('name', ''),
'original_name': original_call.get('name', ''),
'server_arguments': server_args,
'original_arguments': original_args
}
except (KeyError, TypeError) as e:
logger.warning(f"比较工具调用时出错: {e}")
return {
'name_match': False,
'name_score': 0.0,
'arguments_match': False,
'arguments_score': 0.0,
'server_name': '',
'original_name': '',
'server_arguments': {},
'original_arguments': {},
'error': str(e)
}
def compare_results(self, server_data: Dict[str, List], original_data: Dict[str, List]) -> Dict:
"""详细比较服务器返回结果和原始数据"""
# 初始化比较结果结构
comparison = {
'tool_calls_comparison': {
'server_count': len(server_data['tool_calls']),
'original_count': len(original_data['function_calls']),
'detailed_scores': [],
'name_average_score': 0.0,
'arguments_average_score': 0.0,
'non_retrieval_name_average_score': 0.0,
'non_retrieval_arguments_average_score': 0.0
},
'tool_responses_comparison': {
'server_count': len(server_data['tool_responses']),
'original_count': len(original_data['observations']),
'detailed_scores': [],
'average_score': 0.0
},
'overall_scores': {
'tool_responses_avg': 0.0
}
}
# 1. 比较工具调用 (tool_calls)
tool_call_name_scores = []
tool_call_arguments_scores = []
non_retrieval_name_scores = []
non_retrieval_arguments_scores = []
max_tool_calls = max(len(server_data['tool_calls']), len(original_data['function_calls']))
for i in range(max_tool_calls):
server_call = server_data['tool_calls'][i] if i < len(server_data['tool_calls']) else None
original_call = original_data['function_calls'][i] if i < len(original_data['function_calls']) else None
if server_call is None:
# 服务器缺少该调用
score_detail = {
'index': i,
'server_present': False,
'original_present': True,
'name_score': 0.0,
'arguments_score': 0.0,
'original_call': original_call
}
elif original_call is None:
# 原始数据缺少该调用
score_detail = {
'index': i,
'server_present': True,
'original_present': False,
'name_score': 0.0,
'arguments_score': 0.0,
'server_call': server_call
}
else:
# 两者都存在,进行详细比较
# 直接使用JSON对象
call_comparison = self.compare_tool_call(server_call, original_call)
score_detail = {
'index': i,
'server_present': True,
'original_present': True,
'name_score': call_comparison['name_score'],
'arguments_score': call_comparison['arguments_score'],
'name_match': call_comparison['name_match'],
'arguments_match': call_comparison['arguments_match'],
'server_name': call_comparison['server_name'],
'original_name': call_comparison['original_name'],
'server_call': server_call, # 直接存储JSON对象
'original_call': original_call # 直接存储JSON对象
}
if 'error' in call_comparison:
score_detail['error'] = call_comparison['error']
comparison['tool_calls_comparison']['detailed_scores'].append(score_detail)
tool_call_name_scores.append(score_detail['name_score'])
tool_call_arguments_scores.append(score_detail['arguments_score'])
# 收集非retrieval_tool的评分
if server_call and original_call:
server_name = server_call.get('name', '')
original_name = original_call.get('name', '')
# 只有当两个都不是retrieval_tool时才计入非retrieval评分
if server_name != 'retrieval_tool' and original_name != 'retrieval_tool':
non_retrieval_name_scores.append(score_detail['name_score'])
non_retrieval_arguments_scores.append(score_detail['arguments_score'])
# 计算工具调用name和arguments分别的平均分
comparison['tool_calls_comparison']['name_average_score'] = (
sum(tool_call_name_scores) / len(tool_call_name_scores) if tool_call_name_scores else 0.0
)
comparison['tool_calls_comparison']['arguments_average_score'] = (
sum(tool_call_arguments_scores) / len(tool_call_arguments_scores) if tool_call_arguments_scores else 0.0
)
# 计算非retrieval_tool的name和arguments分别的平均分
comparison['tool_calls_comparison']['non_retrieval_name_average_score'] = (
sum(non_retrieval_name_scores) / len(non_retrieval_name_scores) if non_retrieval_name_scores else 0.0
)
comparison['tool_calls_comparison']['non_retrieval_arguments_average_score'] = (
sum(non_retrieval_arguments_scores) / len(non_retrieval_arguments_scores) if non_retrieval_arguments_scores else 0.0
)
# 2. 比较工具响应 (tool_responses)
tool_response_scores = []
max_tool_responses = max(len(server_data['tool_responses']), len(original_data['observations']))
for i in range(max_tool_responses):
server_response = server_data['tool_responses'][i] if i < len(server_data['tool_responses']) else None
original_response = original_data['observations'][i] if i < len(original_data['observations']) else None
if server_response is None:
# 服务器缺少该响应
score_detail = {
'index': i,
'server_present': False,
'original_present': True,
'match_score': 0.0,
'original_response': original_response
}
elif original_response is None:
# 原始数据缺少该响应
score_detail = {
'index': i,
'server_present': True,
'original_present': False,
'match_score': 0.0,
'server_response': server_response
}
else:
# 两者都存在,比较完全一致性
responses_match = server_response == original_response
match_score = 1.0 if responses_match else 0.0
score_detail = {
'index': i,
'server_present': True,
'original_present': True,
'match_score': match_score,
'responses_match': responses_match,
'server_response': server_response, # 直接存储JSON对象
'original_response': original_response # 直接存储JSON对象
}
comparison['tool_responses_comparison']['detailed_scores'].append(score_detail)
tool_response_scores.append(score_detail['match_score'])
# 计算工具响应平均分
comparison['tool_responses_comparison']['average_score'] = (
sum(tool_response_scores) / len(tool_response_scores) if tool_response_scores else 0.0
)
# 3. 计算总体评分
comparison['overall_scores']['tool_responses_avg'] = comparison['tool_responses_comparison']['average_score']
# 保持向后兼容性
comparison['tool_calls_match'] = (
comparison['tool_calls_comparison']['name_average_score'] == 1.0 and
comparison['tool_calls_comparison']['arguments_average_score'] == 1.0
)
comparison['tool_responses_match'] = comparison['overall_scores']['tool_responses_avg'] == 1.0
return comparison
def calculate_global_scores(self, results: List[Dict]) -> Dict:
"""计算多个结果的全局评分"""
if not results:
return {
'global_tool_responses_avg': 0.0,
'global_tool_calls_name_avg': 0.0,
'global_tool_calls_arguments_avg': 0.0,
'global_non_retrieval_name_avg': 0.0,
'global_non_retrieval_arguments_avg': 0.0,
'total_queries': 0
}
# 收集所有评分
all_tool_responses_scores = []
all_tool_calls_name_scores = []
all_tool_calls_arguments_scores = []
all_non_retrieval_name_scores = []
all_non_retrieval_arguments_scores = []
for result in results:
comparison = result.get('comparison', {})
overall_scores = comparison.get('overall_scores', {})
tool_calls_comparison = comparison.get('tool_calls_comparison', {})
tool_responses_avg = overall_scores.get('tool_responses_avg', 0.0)
# 收集每个工具调用的name和arguments分数
detailed_scores = tool_calls_comparison.get('detailed_scores', [])
for score_detail in detailed_scores:
if score_detail.get('server_present') and score_detail.get('original_present'):
all_tool_calls_name_scores.append(score_detail.get('name_score', 0.0))
all_tool_calls_arguments_scores.append(score_detail.get('arguments_score', 0.0))
# 收集非retrieval_tool的评分
server_call = score_detail.get('server_call', {})
original_call = score_detail.get('original_call', {})
if server_call and original_call:
server_name = server_call.get('name', '')
original_name = original_call.get('name', '')
# 只有当两个都不是retrieval_tool时才计入非retrieval评分
if server_name != 'retrieval_tool' and original_name != 'retrieval_tool':
all_non_retrieval_name_scores.append(score_detail.get('name_score', 0.0))
all_non_retrieval_arguments_scores.append(score_detail.get('arguments_score', 0.0))
all_tool_responses_scores.append(tool_responses_avg)
# 计算全局平均分
global_tool_responses_avg = sum(all_tool_responses_scores) / len(all_tool_responses_scores) if all_tool_responses_scores else 0.0
global_tool_calls_name_avg = sum(all_tool_calls_name_scores) / len(all_tool_calls_name_scores) if all_tool_calls_name_scores else 0.0
global_tool_calls_arguments_avg = sum(all_tool_calls_arguments_scores) / len(all_tool_calls_arguments_scores) if all_tool_calls_arguments_scores else 0.0
global_non_retrieval_name_avg = sum(all_non_retrieval_name_scores) / len(all_non_retrieval_name_scores) if all_non_retrieval_name_scores else 0.0
global_non_retrieval_arguments_avg = sum(all_non_retrieval_arguments_scores) / len(all_non_retrieval_arguments_scores) if all_non_retrieval_arguments_scores else 0.0
return {
'global_tool_responses_avg': global_tool_responses_avg,
'global_tool_calls_name_avg': global_tool_calls_name_avg,
'global_tool_calls_arguments_avg': global_tool_calls_arguments_avg,
'global_non_retrieval_name_avg': global_non_retrieval_name_avg,
'global_non_retrieval_arguments_avg': global_non_retrieval_arguments_avg,
'total_queries': len(results)
}
def save_results(self, results: List[Dict], output_file: str):
"""Save evaluation results to file"""
try:
# 计算全局评分
global_scores = self.calculate_global_scores(results)
# 创建包含全局评分的完整结果
complete_results = {
'global_scores': global_scores,
'results': results
}
with open(output_file, 'w', encoding='utf-8') as f:
json.dump(complete_results, f, ensure_ascii=False, indent=2)
logger.info(f"Results saved to: {output_file}")
except Exception as e:
logger.error(f"Save results failed: {e}")
def save_checkpoint(self, results: List[Dict], checkpoint_file: str, processed_count: int, total_count: int):
"""保存检查点文件"""
try:
checkpoint_data = {
'processed_count': processed_count,
'total_count': total_count,
'results': results,
'timestamp': time.strftime('%Y-%m-%d %H:%M:%S')
}
with open(checkpoint_file, 'w', encoding='utf-8') as f:
json.dump(checkpoint_data, f, ensure_ascii=False, indent=2)
logger.info(f"Checkpoint saved: {processed_count}/{total_count} processed")
except Exception as e:
logger.error(f"Save checkpoint failed: {e}")
def load_checkpoint(self, checkpoint_file: str) -> Dict:
"""加载检查点文件"""
try:
if os.path.exists(checkpoint_file):
with open(checkpoint_file, 'r', encoding='utf-8') as f:
checkpoint_data = json.load(f)
logger.info(f"Checkpoint loaded: {checkpoint_data['processed_count']}/{checkpoint_data['total_count']} processed")
return checkpoint_data
else:
logger.info("No checkpoint file found, starting from beginning")
return None
except Exception as e:
logger.error(f"Load checkpoint failed: {e}")
return None
def print_progress(self, current: int, total: int, start_time: float):
"""打印进度信息"""
if total == 0:
return
elapsed_time = time.time() - start_time
progress_percent = (current / total) * 100
if current > 0:
avg_time_per_query = elapsed_time / current
remaining_queries = total - current
estimated_remaining_time = remaining_queries * avg_time_per_query
logger.info(f"进度: {current}/{total} ({progress_percent:.1f}%) | "
f"成功: {self.success_count} | 错误: {self.error_count} | "
f"已用时间: {elapsed_time/60:.1f}分钟 | "
f"预计剩余: {estimated_remaining_time/60:.1f}分钟")
else:
logger.info(f"进度: {current}/{total} ({progress_percent:.1f}%) | "
f"成功: {self.success_count} | 错误: {self.error_count} | "
f"已用时间: {elapsed_time/60:.1f}分钟")
def save_progress_report(self, output_file: str, current: int, total: int):
"""保存进度报告"""
try:
progress_data = {
'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
'current_progress': current,
'total_queries': total,
'success_count': self.success_count,
'error_count': self.error_count,
'progress_percentage': (current / total * 100) if total > 0 else 0,
'elapsed_time_minutes': (time.time() - self.start_time) / 60 if self.start_time else 0
}
progress_file = f"{output_file}.progress"
with open(progress_file, 'w', encoding='utf-8') as f:
json.dump(progress_data, f, ensure_ascii=False, indent=2)
except Exception as e:
logger.error(f"Save progress report failed: {e}")
def generate_interruption_report(self, output_file: str, processed_count: int, total_queries: int, error_message: str):
"""生成中断报告"""
try:
total_time = time.time() - self.start_time if self.start_time else 0
interruption_report = {
'interruption_type': 'server_connection_failure',
'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
'processed_count': processed_count,
'total_queries': total_queries,
'success_count': self.success_count,
'error_count': self.error_count,
'progress_percentage': (processed_count / total_queries * 100) if total_queries > 0 else 0,
'elapsed_time_minutes': total_time / 60,
'error_message': error_message,
'resume_instructions': {
'checkpoint_file': f"{output_file}.checkpoint",
'command': f"await evaluator.evaluate(input_file='data/9.17_evaluate_data_top5_final.json', output_file='{output_file}', resume=True)",
'note': '使用 resume=True 参数从检查点恢复评估'
}
}
interruption_file = f"{output_file}.interruption_report"
with open(interruption_file, 'w', encoding='utf-8') as f:
json.dump(interruption_report, f, ensure_ascii=False, indent=2)
logger.info(f"中断报告已保存到: {interruption_file}")
except Exception as e:
logger.error(f"Generate interruption report failed: {e}")
async def evaluate(self, input_file: str, output_file: str = "evaluation_results.json",
batch_size: int = 50, start_index: int = 0, max_queries: int = None,
checkpoint_file: str = None, resume: bool = True):
"""Execute complete evaluation process with batch processing and checkpoint support"""
logger.info("Start model evaluation...")
# 初始化时间跟踪
self.start_time = time.time()
self.error_count = 0
self.success_count = 0
# 设置检查点文件
if checkpoint_file is None:
checkpoint_file = f"{output_file}.checkpoint"
# 尝试加载检查点
checkpoint_data = None
if resume:
checkpoint_data = self.load_checkpoint(checkpoint_file)
if checkpoint_data:
self.results = checkpoint_data.get('results', [])
start_index = checkpoint_data.get('processed_count', 0)
self.success_count = len(self.results) # 假设已处理的结果都是成功的
logger.info(f"Resuming from checkpoint: {start_index} queries already processed")
# 1. Load data
data = self.load_data(input_file)
if not data:
logger.error("Cannot load data, evaluation terminated")
return
# 2. Extract queries
queries = self.extract_human_queries(data)
if not queries:
logger.error("No valid queries found, evaluation terminated")
return
# 3. Apply limits and offsets
if max_queries:
queries = queries[:max_queries]
if start_index > 0:
queries = queries[start_index:]
logger.info(f"Starting from index {start_index}, processing {len(queries)} queries")
# 4. Process queries in batches
total_queries = len(queries)
processed_count = len(self.results) # 从已有结果开始计数
for batch_start in range(0, total_queries, batch_size):
batch_end = min(batch_start + batch_size, total_queries)
batch_queries = queries[batch_start:batch_end]
logger.info(f"Processing batch {batch_start//batch_size + 1}: queries {batch_start + 1}-{batch_end} of {total_queries}")
# Process each query in the current batch
for i, query_data in enumerate(batch_queries):
global_index = batch_start + i
logger.info(f"Process {global_index + 1}/{total_queries} query: {query_data['query'][:50]}...")
try:
# 异步调用服务器
events = await self.call_server(query_data['query'])
# Extract tool calls and responses from server
server_data = self.extract_tool_calls_and_observations(events)
# Extract function_call and observation from original data
original_data = self.extract_original_data(query_data['original_data'])
# Compare results
comparison = self.compare_results(server_data, original_data)
# Save results
result = {
'index': query_data['index'],
'query': query_data['query'],
'server_events_count': len(events),
'server_tool_calls': server_data['tool_calls'],
'server_tool_responses': server_data['tool_responses'],
'original_function_calls': original_data['function_calls'],
'original_observations': original_data['observations'],
'comparison': comparison,
'timestamp': time.strftime('%Y-%m-%d %H:%M:%S')
}
self.results.append(result)
processed_count += 1
self.success_count += 1
# 每处理10个查询保存一次检查点和进度报告
if processed_count % 10 == 0:
self.save_checkpoint(self.results, checkpoint_file, processed_count, total_queries)
self.save_progress_report(output_file, processed_count, total_queries)
self.print_progress(processed_count, total_queries, self.start_time)
# Add delay to avoid server pressure
await asyncio.sleep(1)
except Exception as e:
logger.error(f"Error processing query {global_index + 1}: {e}")
self.error_count += 1
# 检查是否是服务器连接失败(重试后仍然失败)
if "Server connection failed after" in str(e) or "Server timeout after" in str(e) or "Unexpected error after" in str(e):
logger.error(f"🚨 服务器连接失败,保存检查点并结束评估")
logger.error(f"失败查询: {query_data['query'][:50]}...")
# 保存当前进度到检查点
self.save_checkpoint(self.results, checkpoint_file, processed_count, total_queries)
self.save_progress_report(output_file, processed_count, total_queries)
# 生成中断报告
self.generate_interruption_report(output_file, processed_count, total_queries, str(e))
logger.error(f"评估因服务器连接失败而中断")
logger.error(f"已处理 {processed_count}/{total_queries} 个查询")
logger.error(f"检查点已保存到: {checkpoint_file}")
logger.error(f"可以稍后使用 resume=True 从检查点恢复")
return # 直接结束评估
else:
# 其他类型的错误,继续处理下一个查询
logger.warning(f"查询处理失败,继续处理下一个查询: {e}")
continue
# Save intermediate results after each batch
batch_output_file = f"{output_file}.batch_{batch_start//batch_size + 1}"
self.save_results(self.results, batch_output_file)
logger.info(f"Batch {batch_start//batch_size + 1} completed, results saved to {batch_output_file}")
# 保存检查点和进度报告
self.save_checkpoint(self.results, checkpoint_file, processed_count, total_queries)
self.save_progress_report(output_file, processed_count, total_queries)
self.print_progress(processed_count, total_queries, self.start_time)
# 5. Save final results
self.save_results(self.results, output_file)
# 6. 删除检查点文件(处理完成)
if os.path.exists(checkpoint_file):
os.remove(checkpoint_file)
logger.info("Checkpoint file removed after successful completion")
# 7. Generate summary report
self.generate_summary_report()
# 8. 最终进度报告
total_time = time.time() - self.start_time
logger.info(f"=== 评估完成 ===")
logger.info(f"总查询数: {total_queries}")
logger.info(f"成功处理: {self.success_count}")
logger.info(f"处理失败: {self.error_count}")
logger.info(f"总用时: {total_time/60:.1f}分钟")
logger.info(f"平均每查询用时: {total_time/total_queries:.1f}秒")
    def generate_summary_report(self):
        """Print a detailed human-readable evaluation summary to stdout.

        Aggregates scores from ``self.results``: global averages (via the
        sibling helper ``calculate_global_scores``), exact-match counts, and a
        per-query score breakdown. The report text itself is user-facing
        Chinese output. No-op when there are no results yet.
        """
        if not self.results:
            return
        total_queries = len(self.results)
        # Global average scores across all processed queries.
        global_scores = self.calculate_global_scores(self.results)
        # Collect the per-query score breakdown used in the report body.
        query_details = []
        for i, result in enumerate(self.results):
            comparison = result['comparison']
            overall_scores = comparison.get('overall_scores', {})
            tool_calls_comparison = comparison.get('tool_calls_comparison', {})
            tool_responses_avg = overall_scores.get('tool_responses_avg', 0.0)
            tool_calls_name_avg = tool_calls_comparison.get('name_average_score', 0.0)
            tool_calls_arguments_avg = tool_calls_comparison.get('arguments_average_score', 0.0)
            non_retrieval_name_avg = tool_calls_comparison.get('non_retrieval_name_average_score', 0.0)
            non_retrieval_arguments_avg = tool_calls_comparison.get('non_retrieval_arguments_average_score', 0.0)
            query_details.append({
                'index': i,
                # Truncate long queries to 50 chars so the report stays readable.
                'query': result['query'][:50] + '...' if len(result['query']) > 50 else result['query'],
                'tool_calls_name_score': tool_calls_name_avg,
                'tool_calls_arguments_score': tool_calls_arguments_avg,
                'non_retrieval_name_score': non_retrieval_name_avg,
                'non_retrieval_arguments_score': non_retrieval_arguments_avg,
                'tool_responses_score': tool_responses_avg
            })
        # Compatibility statistics: counts of exact (perfect) matches.
        # NOTE(review): these use direct indexing while the loop above uses
        # .get() — assumes every result has these keys; confirm upstream.
        tool_calls_perfect_matches = sum(1 for r in self.results if r['comparison']['tool_calls_match'])
        tool_responses_perfect_matches = sum(1 for r in self.results if r['comparison']['tool_responses_match'])
        # Assemble the report (division is safe: total_queries >= 1 here).
        report = f"""
=== 模型评估详细摘要报告 ===
【整体统计】
总查询数: {total_queries}
工具调用完全匹配数: {tool_calls_perfect_matches} ({tool_calls_perfect_matches/total_queries*100:.1f}%)
工具响应完全匹配数: {tool_responses_perfect_matches} ({tool_responses_perfect_matches/total_queries*100:.1f}%)
【全局平均评分】
工具名称匹配平均分: {global_scores['global_tool_calls_name_avg']:.3f}
工具参数匹配平均分: {global_scores['global_tool_calls_arguments_avg']:.3f}
非retrieval工具名称匹配平均分: {global_scores['global_non_retrieval_name_avg']:.3f}
非retrieval工具参数匹配平均分: {global_scores['global_non_retrieval_arguments_avg']:.3f}
工具响应全局平均分: {global_scores['global_tool_responses_avg']:.3f}
【各查询详细评分】"""
        # Append one block per query.
        for detail in query_details:
            report += f"""
Query {detail['index']}: {detail['query']}
- 工具名称评分: {detail['tool_calls_name_score']:.3f}
- 工具参数评分: {detail['tool_calls_arguments_score']:.3f}
- 非retrieval工具名称评分: {detail['non_retrieval_name_score']:.3f}
- 非retrieval工具参数评分: {detail['non_retrieval_arguments_score']:.3f}
- 工具响应评分: {detail['tool_responses_score']:.3f}"""
        # Trailing legend explaining how each score is computed.
        report += f"""
【评分说明】
- 工具名称匹配分: 工具名称完全一致为1分,否则为0分
- 工具参数匹配分: 工具参数完全一致为1分,否则为0分
- 非retrieval工具名称匹配分: 排除retrieval_tool后,工具名称完全一致为1分,否则为0分
- 非retrieval工具参数匹配分: 排除retrieval_tool后,工具参数完全一致为1分,否则为0分
- 工具响应评分: 完全一致为1分,否则为0分
详细结果请查看 evaluation_results.json 文件
"""
        print(report)
        logger.info("详细摘要报告已生成")
    def test_sse_parsing(self):
        """Smoke-test the SSE parsing pipeline with canned event payloads.

        Feeds one hard-coded ``tool_call.created`` and one
        ``tool_response.completed`` SSE payload through ``parse_sse_events``
        and ``extract_tool_calls_and_observations``, logging what was parsed.

        Returns:
            The list of parsed SSE event dicts.
        """
        # Canned fixture: a single tool_call.created SSE event.
        test_data_tool_call = """id: 3
event: tool_call.created
data: {"conversation_id": "c_9c5b3617", "message_id": "m_1248", "sequence": 3, "role": "assistant", "timestamp": "2025-09-18T13:01:34.230464Z", "content": "", "tool_call": {"name": "intelligent_route_analysis", "arguments": {"access": "1"}}}
"""
        # Canned fixture: the matching tool_response.completed SSE event.
        test_data_tool_response = """
id: 4
event: tool_response.completed
data: {"conversation_id": "c_9c5b3617", "message_id": "m_1248", "sequence": 4, "role": "tool", "timestamp": "2025-09-18T13:01:34.358678Z", "tool_call_id": "tool_2", "result_delta": {"chat_log_id": 1234, "content": "", "markdown": "智能路由分析结果\\n\\n访问链接: [上传派团单](https://testai.compassaihz.com/#/$&!upload \\"成功匹配到对应页面\\")\\n\\n", "result": {"success": true, "url": "https://testai.compassaihz.com/#/$&!upload", "message": "成功匹配到对应页面"}, "ambulance": "", "potential_tools": [{"api": "/intelligent_route_analysis", "api_cn": "页面跳转工具", "queryData": [{"type": "select", "key": "access", "label": "数字访问码", "value": {"options": [{"label": "1", "value": "1"}, {"label": "2", "value": "2"}, {"label": "3", "value": "3"}, {"label": "4", "value": "4"}, {"label": "5", "value": "5"}, {"label": "6", "value": "6"}, {"label": "7", "value": "7"}, {"label": "8", "value": "8"}, {"label": "9", "value": "9"}, {"label": "10", "value": "10"}]}, "default": "1", "multiple": false}], "description": "智能页面路由工具,通过输入1-10的数字快速跳转到对应的业务功能页面。业务功能包括:上传派团单,新增资源,新增产品,新增协议,新增协议模版,离线上传协议,离线上传价格政策,新增价格政策,新增报表,新增审批流"}], "tool_calling_chain": [{"role": "function", "tool_call": {"name": "intelligent_route_analysis", "arguments": {"access": "1"}}, "tool_response": {"success": true, "url": "https://testai.compassaihz.com/#/$&!upload", "message": "成功匹配到对应页面"}}], "api_Info": {"api": "/intelligent_route_analysis", "api_cn": "页面跳转工具", "queryData": [{"type": "select", "key": "access", "label": "数字访问码", "value": {"options": [{"label": "1", "value": "1"}, {"label": "2", "value": "2"}, {"label": "3", "value": "3"}, {"label": "4", "value": "4"}, {"label": "5", "value": "5"}, {"label": "6", "value": "6"}, {"label": "7", "value": "7"}, {"label": "8", "value": "8"}, {"label": "9", "value": "9"}, {"label": "10", "value": "10"}]}, "default": "1", "multiple": false}], "description": "智能页面路由工具,通过输入1-10的数字快速跳转到对应的业务功能页面。业务功能包括:上传派团单,新增资源,新增产品,新增协议,新增协议模版,离线上传协议,离线上传价格政策,新增价格政策,新增报表,新增审批流"}}, "success": true, "execution_time": 0.0}
"""
        logger.info("=== 开始测试SSE解析功能 ===")
        # Concatenate the two payloads into one SSE stream.
        combined_test_data = test_data_tool_call + test_data_tool_response
        # Parse via the shared helper, keeping only the two event types of interest.
        events = self.parse_sse_events(
            combined_test_data,
            filter_events=['tool_call.created', 'tool_response.completed']
        )
        logger.info(f"=== 测试解析结果总结 ===")
        logger.info(f"总共解析到 {len(events)} 个事件")
        for event in events:
            logger.info(f"事件: {event['event']}, ID: {event['id']}")
            logger.info(f" 数据摘要: {str(event['data'])[:100]}...")
        # Also exercise the extraction step on the parsed events.
        extracted = self.extract_tool_calls_and_observations(events)
        logger.info(f"Extraction results from one tool calling: {extracted}")
        return events
async def main():
    """Entry point: run the full evaluation over the prepared query set.

    Configures input/output/checkpoint paths and launches
    ``ModelEvaluator.evaluate`` with checkpoint-based resume enabled.
    """
    input_file = "data/9.17_evaluate_data_top5_final.json"
    output_file = "eval_results/evaluation_results.json"
    checkpoint_file = "eval_results/evaluation_results.json.checkpoint"

    evaluator = ModelEvaluator()
    # batch_size: queries processed per batch before intermediate saves.
    # max_queries: optional cap on processed queries (set small for smoke tests).
    # resume: continue from the checkpoint file when one exists.
    await evaluator.evaluate(
        input_file=input_file,
        output_file=output_file,
        batch_size=50,
        max_queries=None,
        checkpoint_file=checkpoint_file,
        resume=True,
    )
if __name__ == "__main__":
    # Script entry: drive the async main() on a fresh event loop.
    asyncio.run(main())
|