#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Model evaluation script.

Workflow:
1. Load conversation records from a JSON file.
2. Use each record's first ``human`` turn as the query sent to server:8020.
3. Parse the streamed SSE response into structured events.
4. Compare server tool calls/responses against the reference data and
   persist results, with batching, checkpointing and resume support.
"""

import json
import httpx
import asyncio
import time
import re
import os
from typing import Dict, List, Any, Optional
from utils.custom_logging import setup_logging
from utils.extraction import extract_json_from_string
from loguru import logger
from collections import Counter

setup_logging()


class ModelEvaluator:
    """Drives end-to-end evaluation of the MCP server against reference data."""

    def __init__(self, server_url: str = "http://localhost:8020/mcp_end2end/stream"):
        # Streaming endpoint of the server under evaluation.
        self.server_url = server_url
        # Accumulated per-query evaluation results (also restored from checkpoints).
        self.results = []
        self.client = None
        # Progress bookkeeping, initialized for real in evaluate().
        self.start_time = None
        self.error_count = 0
        self.success_count = 0

    def load_data(self, file_path: str) -> List[Dict]:
        """Load the evaluation dataset from a JSON file.

        Returns an empty list (and logs the error) on any failure so the
        caller can terminate gracefully.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            logger.info(f"成功加载数据文件,共{len(data)}条记录")
            return data
        except Exception as e:
            logger.error(f"加载数据文件失败: {e}")
            return []

    def extract_human_queries(self, data: List[Dict]) -> List[Dict]:
        """Extract the first ``human`` turn of each record as the query.

        Each returned item keeps the record index and the full original
        record so the reference function calls/observations can be
        extracted later for comparison.
        """
        queries = []
        for i, item in enumerate(data):
            if 'conversations' in item:
                for conv in item['conversations']:
                    if conv.get('from') == 'human':
                        query_data = {
                            'index': i,
                            'query': conv.get('value', ''),
                            'original_data': item
                        }
                        queries.append(query_data)
                        break  # Only the first human turn per record is used.
        logger.info(f"提取到{len(queries)}个查询")
        return queries

    def _emit_sse_event(self, current_event: Dict, filter_events: List[str],
                        events: List[Dict], label: str) -> Optional[bool]:
        """Finalize one buffered SSE event into ``events``.

        ``label`` is interpolated into the log messages so mid-stream and
        trailing events keep their original, distinct log wording.

        Returns:
            True if the event's data was parsed and appended, False if
            parsing failed, None if the event type was filtered out.
        """
        if current_event['event'] not in filter_events:
            return None
        # Parse the data payload (may be lenient JSON) via the shared helper.
        data_content = extract_json_from_string(current_event['data'])
        if data_content is not None:
            events.append({
                'id': current_event.get('id'),
                'event': current_event['event'],
                'data': data_content
            })
            logger.debug(f"✅ 成功解析{label}: {current_event['event']} (ID: {current_event.get('id', 'N/A')})")
            return True
        logger.warning(f"❌ 无法解析{label}数据: {current_event['event']} - {current_event['data'][:100]}...")
        return False

    def parse_sse_events(self, sse_content: str, filter_events: List[str] = None) -> List[Dict]:
        """Parse SSE-formatted text and extract events of the given types.

        Args:
            sse_content: Raw SSE text (may span many lines/events).
            filter_events: Event types to keep; defaults to
                ``['tool_call.created', 'tool_response.completed']``.

        Returns:
            List of successfully parsed events, each with ``id``, ``event``
            and a JSON-decoded ``data`` field.
        """
        events = []
        current_event = {}
        parsed_count = 0
        failed_count = 0

        if filter_events is None:
            filter_events = ['tool_call.created', 'tool_response.completed']

        for line in sse_content.split('\n'):
            line = line.strip()
            if not line:
                # A blank line terminates the buffered event.
                if current_event and 'event' in current_event and 'data' in current_event:
                    outcome = self._emit_sse_event(current_event, filter_events, events, '事件')
                    if outcome is True:
                        parsed_count += 1
                    elif outcome is False:
                        failed_count += 1
                current_event = {}
                continue

            # Accumulate SSE fields for the current event.
            if line.startswith('id: '):
                current_event['id'] = line[4:]
            elif line.startswith('event: '):
                current_event['event'] = line[7:]
            elif line.startswith('data: '):
                current_event['data'] = line[6:]
                # Stop at the stream terminator.
                if current_event['data'].strip() == '[DONE]':
                    logger.debug("收到结束标记 [DONE]")
                    break
            else:
                logger.debug(f"未知格式的行: {line}")

        # Flush the trailing event when the stream did not end with a blank line.
        if current_event and 'event' in current_event and 'data' in current_event:
            outcome = self._emit_sse_event(current_event, filter_events, events, '最后一个事件')
            if outcome is True:
                parsed_count += 1
            elif outcome is False:
                failed_count += 1

        logger.info(f"=== SSE解析结果统计 ===")
        logger.info(f"成功解析事件数: {parsed_count}")
        logger.info(f"解析失败事件数: {failed_count}")
        logger.info(f"总事件数: {len(events)}")
        if events:
            event_types = [event.get('event', 'unknown') for event in events]
            event_counts = Counter(event_types)
            logger.info(f"事件类型分布: {dict(event_counts)}")
        else:
            logger.warning("⚠️ 未解析到任何目标事件")
        return events

    async def call_server(self, query: str, max_retries: int = 3, retry_delay: float = 2.0) -> List[Dict]:
        """Call the streaming endpoint for one query, with retries.

        Uses exponential backoff between attempts. Raises ``Exception``
        with a category-specific message after all retries fail, which
        evaluate() pattern-matches on to decide whether to abort the run.
        """
        payload = {
            "user_id": "166",
            "role_code": 1,
            "query": query,
            "save_method": 0
        }
        for attempt in range(max_retries):
            try:
                async with httpx.AsyncClient(timeout=30.0) as client:
                    async with client.stream(
                        'POST',
                        self.server_url,
                        json=payload,
                        headers={'Accept': 'text/event-stream'}
                    ) as response:
                        response.raise_for_status()
                        # Collect the raw SSE text; parsing happens afterwards
                        # because aiter_text yields arbitrary chunk boundaries.
                        sse_content = ""
                        async for line in response.aiter_text():
                            logger.debug(f"Received data: {line}")
                            sse_content += line
                            if '[DONE]' in line:
                                logger.debug("收到结束标记 [DONE]")
                                break
                        events = self.parse_sse_events(
                            sse_content,
                            filter_events=['tool_call.created', 'tool_response.completed']
                        )
                        # Sanity-check that both key event types were seen.
                        has_tool_call = any(event.get('event') == 'tool_call.created' for event in events)
                        has_tool_response = any(event.get('event') == 'tool_response.completed' for event in events)
                        logger.info(f"包含工具调用事件: {'✅' if has_tool_call else '❌'}")
                        logger.info(f"包含工具响应事件: {'✅' if has_tool_response else '❌'}")
                        return events
            # NOTE: TimeoutException must be caught BEFORE RequestError —
            # httpx.TimeoutException subclasses httpx.RequestError, so the
            # reverse order made this branch unreachable and mislabeled
            # timeouts as generic connection failures.
            except httpx.TimeoutException as e:
                logger.warning(f"Server timeout (attempt {attempt + 1}/{max_retries}): {e}")
                if attempt < max_retries - 1:
                    logger.info(f"Retrying in {retry_delay} seconds...")
                    await asyncio.sleep(retry_delay)
                    retry_delay *= 2  # Exponential backoff.
                else:
                    logger.error(f"Timeout after all retry attempts for query: {query[:50]}...")
                    raise Exception(f"Server timeout after {max_retries} attempts: {e}")
            except httpx.RequestError as e:
                logger.warning(f"Call server failed (attempt {attempt + 1}/{max_retries}): {e}")
                if attempt < max_retries - 1:
                    logger.info(f"Retrying in {retry_delay} seconds...")
                    await asyncio.sleep(retry_delay)
                    retry_delay *= 2  # Exponential backoff.
                else:
                    logger.error(f"All retry attempts failed for query: {query[:50]}...")
                    raise Exception(f"Server connection failed after {max_retries} attempts: {e}")
            except Exception as e:
                logger.error(f"Unexpected error processing response (attempt {attempt + 1}/{max_retries}): {e}")
                if attempt < max_retries - 1:
                    logger.info(f"Retrying in {retry_delay} seconds...")
                    await asyncio.sleep(retry_delay)
                    retry_delay *= 2
                else:
                    logger.error(f"Unexpected error after all retry attempts for query: {query[:50]}...")
                    raise Exception(f"Unexpected error after {max_retries} attempts: {e}")
        raise Exception("All retry attempts exhausted")

    def extract_tool_calls_and_observations(self, events: List[Dict]) -> Dict[str, List]:
        """Extract tool_call.created and tool_response.completed content from events."""
        tool_calls = []
        tool_responses = []
        logger.debug(f"开始提取工具调用和响应,共 {len(events)} 个事件")
        for event in events:
            event_type = event.get('event')
            event_data = event.get('data', {})
            if event_type == 'tool_call.created':
                logger.debug(f"Extract tool_call.created content: {event}")
                tool_call_info = event_data.get('tool_call', {})
                if tool_call_info:
                    tool_calls.append(tool_call_info)  # Stored as JSON object, not string.
                    logger.debug(f"✅ 提取工具调用: {tool_call_info.get('name', 'unknown')}")
                else:
                    logger.warning(f"❌ tool_call.created 事件中缺少 tool_call 信息")
            elif event_type == 'tool_response.completed':
                logger.debug(f"Extract tool_response.completed content: {event}")
                if 'result_delta' in event_data:
                    tool_response = event_data['result_delta'].get('result', [])
                    tool_responses.append(tool_response)  # Stored as JSON object.
                    logger.debug(f"✅ 提取工具响应: {len(str(tool_response))} 字符")
                else:
                    # Keep positional alignment with tool_calls even when empty.
                    tool_response = []
                    tool_responses.append(tool_response)
        logger.info(f"Extract {len(tool_calls)} tool calls, {len(tool_responses)} tool responses")
        return {
            'tool_calls': tool_calls,
            'tool_responses': tool_responses
        }

    def extract_original_data(self, original_data: Dict) -> Dict[str, List]:
        """Extract function_call and observation content from original data.

        Values are stored as JSON strings in the dataset; they are decoded
        here, with per-item fallbacks ({} / []) on decode failure so list
        positions stay aligned.
        """
        function_calls = []
        observations = []
        if 'conversations' in original_data:
            for conv in original_data['conversations']:
                if conv.get('from') == 'function_call':
                    try:
                        function_call_obj = json.loads(conv.get('value', '{}'))
                        function_calls.append(function_call_obj)
                    except json.JSONDecodeError as e:
                        logger.warning(f"解析function_call JSON时出错: {e}")
                        function_calls.append({})
                elif conv.get('from') == 'observation':
                    try:
                        observation_obj = json.loads(conv.get('value', '[]'))
                        observations.append(observation_obj)
                    except json.JSONDecodeError as e:
                        logger.warning(f"解析observation JSON时出错: {e}")
                        observations.append([])
        return {
            'function_calls': function_calls,
            'observations': observations
        }

    def compare_tool_call(self, server_call: Dict, original_call: Dict) -> Dict:
        """Compare one tool call pair on exact ``name`` and ``arguments`` equality.

        Scores are binary (1.0 match / 0.0 mismatch). On structural errors
        a zero-score record with an ``error`` field is returned instead of
        raising.
        """
        try:
            name_match = server_call.get('name') == original_call.get('name')
            name_score = 1.0 if name_match else 0.0
            server_args = server_call.get('arguments', {})
            original_args = original_call.get('arguments', {})
            arguments_match = server_args == original_args
            arguments_score = 1.0 if arguments_match else 0.0
            return {
                'name_match': name_match,
                'name_score': name_score,
                'arguments_match': arguments_match,
                'arguments_score': arguments_score,
                'server_name': server_call.get('name', ''),
                'original_name': original_call.get('name', ''),
                'server_arguments': server_args,
                'original_arguments': original_args
            }
        except (KeyError, TypeError) as e:
            logger.warning(f"比较工具调用时出错: {e}")
            return {
                'name_match': False,
                'name_score': 0.0,
                'arguments_match': False,
                'arguments_score': 0.0,
                'server_name': '',
                'original_name': '',
                'server_arguments': {},
                'original_arguments': {},
                'error': str(e)
            }

    def compare_results(self, server_data: Dict[str, List], original_data: Dict[str, List]) -> Dict:
        """Build a detailed comparison between server output and reference data.

        Compares tool calls positionally (name + arguments, with a separate
        track that excludes ``retrieval_tool``), then tool responses by exact
        equality, and computes per-category averages.
        """
        comparison = {
            'tool_calls_comparison': {
                'server_count': len(server_data['tool_calls']),
                'original_count': len(original_data['function_calls']),
                'detailed_scores': [],
                'name_average_score': 0.0,
                'arguments_average_score': 0.0,
                'non_retrieval_name_average_score': 0.0,
                'non_retrieval_arguments_average_score': 0.0
            },
            'tool_responses_comparison': {
                'server_count': len(server_data['tool_responses']),
                'original_count': len(original_data['observations']),
                'detailed_scores': [],
                'average_score': 0.0
            },
            'overall_scores': {
                'tool_responses_avg': 0.0
            }
        }

        # 1. Compare tool calls positionally; a missing side scores 0.
        tool_call_name_scores = []
        tool_call_arguments_scores = []
        non_retrieval_name_scores = []
        non_retrieval_arguments_scores = []
        max_tool_calls = max(len(server_data['tool_calls']), len(original_data['function_calls']))
        for i in range(max_tool_calls):
            server_call = server_data['tool_calls'][i] if i < len(server_data['tool_calls']) else None
            original_call = original_data['function_calls'][i] if i < len(original_data['function_calls']) else None
            if server_call is None:
                # Server omitted this call.
                score_detail = {
                    'index': i,
                    'server_present': False,
                    'original_present': True,
                    'name_score': 0.0,
                    'arguments_score': 0.0,
                    'original_call': original_call
                }
            elif original_call is None:
                # Reference data lacks this call.
                score_detail = {
                    'index': i,
                    'server_present': True,
                    'original_present': False,
                    'name_score': 0.0,
                    'arguments_score': 0.0,
                    'server_call': server_call
                }
            else:
                # Both present: detailed field-level comparison.
                call_comparison = self.compare_tool_call(server_call, original_call)
                score_detail = {
                    'index': i,
                    'server_present': True,
                    'original_present': True,
                    'name_score': call_comparison['name_score'],
                    'arguments_score': call_comparison['arguments_score'],
                    'name_match': call_comparison['name_match'],
                    'arguments_match': call_comparison['arguments_match'],
                    'server_name': call_comparison['server_name'],
                    'original_name': call_comparison['original_name'],
                    'server_call': server_call,      # Stored as JSON object.
                    'original_call': original_call   # Stored as JSON object.
                }
                if 'error' in call_comparison:
                    score_detail['error'] = call_comparison['error']
            comparison['tool_calls_comparison']['detailed_scores'].append(score_detail)
            tool_call_name_scores.append(score_detail['name_score'])
            tool_call_arguments_scores.append(score_detail['arguments_score'])
            # Track non-retrieval scores only when both sides exist and
            # neither side is retrieval_tool.
            if server_call and original_call:
                server_name = server_call.get('name', '')
                original_name = original_call.get('name', '')
                if server_name != 'retrieval_tool' and original_name != 'retrieval_tool':
                    non_retrieval_name_scores.append(score_detail['name_score'])
                    non_retrieval_arguments_scores.append(score_detail['arguments_score'])

        comparison['tool_calls_comparison']['name_average_score'] = (
            sum(tool_call_name_scores) / len(tool_call_name_scores)
            if tool_call_name_scores else 0.0
        )
        comparison['tool_calls_comparison']['arguments_average_score'] = (
            sum(tool_call_arguments_scores) / len(tool_call_arguments_scores)
            if tool_call_arguments_scores else 0.0
        )
        comparison['tool_calls_comparison']['non_retrieval_name_average_score'] = (
            sum(non_retrieval_name_scores) / len(non_retrieval_name_scores)
            if non_retrieval_name_scores else 0.0
        )
        comparison['tool_calls_comparison']['non_retrieval_arguments_average_score'] = (
            sum(non_retrieval_arguments_scores) / len(non_retrieval_arguments_scores)
            if non_retrieval_arguments_scores else 0.0
        )

        # 2. Compare tool responses by exact (deep) equality.
        tool_response_scores = []
        max_tool_responses = max(len(server_data['tool_responses']), len(original_data['observations']))
        for i in range(max_tool_responses):
            server_response = server_data['tool_responses'][i] if i < len(server_data['tool_responses']) else None
            original_response = original_data['observations'][i] if i < len(original_data['observations']) else None
            if server_response is None:
                score_detail = {
                    'index': i,
                    'server_present': False,
                    'original_present': True,
                    'match_score': 0.0,
                    'original_response': original_response
                }
            elif original_response is None:
                score_detail = {
                    'index': i,
                    'server_present': True,
                    'original_present': False,
                    'match_score': 0.0,
                    'server_response': server_response
                }
            else:
                responses_match = server_response == original_response
                match_score = 1.0 if responses_match else 0.0
                score_detail = {
                    'index': i,
                    'server_present': True,
                    'original_present': True,
                    'match_score': match_score,
                    'responses_match': responses_match,
                    'server_response': server_response,      # Stored as JSON object.
                    'original_response': original_response   # Stored as JSON object.
                }
            comparison['tool_responses_comparison']['detailed_scores'].append(score_detail)
            tool_response_scores.append(score_detail['match_score'])

        comparison['tool_responses_comparison']['average_score'] = (
            sum(tool_response_scores) / len(tool_response_scores)
            if tool_response_scores else 0.0
        )

        # 3. Overall score + backward-compatible boolean flags.
        comparison['overall_scores']['tool_responses_avg'] = comparison['tool_responses_comparison']['average_score']
        comparison['tool_calls_match'] = (
            comparison['tool_calls_comparison']['name_average_score'] == 1.0 and
            comparison['tool_calls_comparison']['arguments_average_score'] == 1.0
        )
        comparison['tool_responses_match'] = comparison['overall_scores']['tool_responses_avg'] == 1.0
        return comparison

    def calculate_global_scores(self, results: List[Dict]) -> Dict:
        """Aggregate per-query comparison results into global averages."""
        if not results:
            return {
                'global_tool_responses_avg': 0.0,
                'global_tool_calls_name_avg': 0.0,
                'global_tool_calls_arguments_avg': 0.0,
                'global_non_retrieval_name_avg': 0.0,
                'global_non_retrieval_arguments_avg': 0.0,
                'total_queries': 0
            }
        all_tool_responses_scores = []
        all_tool_calls_name_scores = []
        all_tool_calls_arguments_scores = []
        all_non_retrieval_name_scores = []
        all_non_retrieval_arguments_scores = []
        for result in results:
            comparison = result.get('comparison', {})
            overall_scores = comparison.get('overall_scores', {})
            tool_calls_comparison = comparison.get('tool_calls_comparison', {})
            tool_responses_avg = overall_scores.get('tool_responses_avg', 0.0)
            # Flatten per-call scores across all queries (only pairs where
            # both sides were present contribute).
            detailed_scores = tool_calls_comparison.get('detailed_scores', [])
            for score_detail in detailed_scores:
                if score_detail.get('server_present') and score_detail.get('original_present'):
                    all_tool_calls_name_scores.append(score_detail.get('name_score', 0.0))
                    all_tool_calls_arguments_scores.append(score_detail.get('arguments_score', 0.0))
                    server_call = score_detail.get('server_call', {})
                    original_call = score_detail.get('original_call', {})
                    if server_call and original_call:
                        server_name = server_call.get('name', '')
                        original_name = original_call.get('name', '')
                        # Only count toward non-retrieval averages when
                        # neither side is retrieval_tool.
                        if server_name != 'retrieval_tool' and original_name != 'retrieval_tool':
                            all_non_retrieval_name_scores.append(score_detail.get('name_score', 0.0))
                            all_non_retrieval_arguments_scores.append(score_detail.get('arguments_score', 0.0))
            all_tool_responses_scores.append(tool_responses_avg)
        global_tool_responses_avg = sum(all_tool_responses_scores) / len(all_tool_responses_scores) if all_tool_responses_scores else 0.0
        global_tool_calls_name_avg = sum(all_tool_calls_name_scores) / len(all_tool_calls_name_scores) if all_tool_calls_name_scores else 0.0
        global_tool_calls_arguments_avg = sum(all_tool_calls_arguments_scores) / len(all_tool_calls_arguments_scores) if all_tool_calls_arguments_scores else 0.0
        global_non_retrieval_name_avg = sum(all_non_retrieval_name_scores) / len(all_non_retrieval_name_scores) if all_non_retrieval_name_scores else 0.0
        global_non_retrieval_arguments_avg = sum(all_non_retrieval_arguments_scores) / len(all_non_retrieval_arguments_scores) if all_non_retrieval_arguments_scores else 0.0
        return {
            'global_tool_responses_avg': global_tool_responses_avg,
            'global_tool_calls_name_avg': global_tool_calls_name_avg,
            'global_tool_calls_arguments_avg': global_tool_calls_arguments_avg,
            'global_non_retrieval_name_avg': global_non_retrieval_name_avg,
            'global_non_retrieval_arguments_avg': global_non_retrieval_arguments_avg,
            'total_queries': len(results)
        }

    def save_results(self, results: List[Dict], output_file: str):
        """Save evaluation results (with global scores) to a JSON file."""
        try:
            global_scores = self.calculate_global_scores(results)
            complete_results = {
                'global_scores': global_scores,
                'results': results
            }
            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump(complete_results, f, ensure_ascii=False, indent=2)
            logger.info(f"Results saved to: {output_file}")
        except Exception as e:
            logger.error(f"Save results failed: {e}")

    def save_checkpoint(self, results: List[Dict], checkpoint_file: str, processed_count: int, total_count: int):
        """Persist a resume checkpoint with the results gathered so far."""
        try:
            checkpoint_data = {
                'processed_count': processed_count,
                'total_count': total_count,
                'results': results,
                'timestamp': time.strftime('%Y-%m-%d %H:%M:%S')
            }
            with open(checkpoint_file, 'w', encoding='utf-8') as f:
                json.dump(checkpoint_data, f, ensure_ascii=False, indent=2)
            logger.info(f"Checkpoint saved: {processed_count}/{total_count} processed")
        except Exception as e:
            logger.error(f"Save checkpoint failed: {e}")

    def load_checkpoint(self, checkpoint_file: str) -> Optional[Dict]:
        """Load a checkpoint file; returns None when absent or unreadable."""
        try:
            if os.path.exists(checkpoint_file):
                with open(checkpoint_file, 'r', encoding='utf-8') as f:
                    checkpoint_data = json.load(f)
                logger.info(f"Checkpoint loaded: {checkpoint_data['processed_count']}/{checkpoint_data['total_count']} processed")
                return checkpoint_data
            else:
                logger.info("No checkpoint file found, starting from beginning")
                return None
        except Exception as e:
            logger.error(f"Load checkpoint failed: {e}")
            return None

    def print_progress(self, current: int, total: int, start_time: float):
        """Log progress, throughput and an ETA based on average query time."""
        if total == 0:
            return
        elapsed_time = time.time() - start_time
        progress_percent = (current / total) * 100
        if current > 0:
            avg_time_per_query = elapsed_time / current
            remaining_queries = total - current
            estimated_remaining_time = remaining_queries * avg_time_per_query
            logger.info(f"进度: {current}/{total} ({progress_percent:.1f}%) | "
                        f"成功: {self.success_count} | 错误: {self.error_count} | "
                        f"已用时间: {elapsed_time/60:.1f}分钟 | "
                        f"预计剩余: {estimated_remaining_time/60:.1f}分钟")
        else:
            logger.info(f"进度: {current}/{total} ({progress_percent:.1f}%) | "
                        f"成功: {self.success_count} | 错误: {self.error_count} | "
                        f"已用时间: {elapsed_time/60:.1f}分钟")

    def save_progress_report(self, output_file: str, current: int, total: int):
        """Write a machine-readable progress snapshot next to the output file."""
        try:
            progress_data = {
                'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
                'current_progress': current,
                'total_queries': total,
                'success_count': self.success_count,
                'error_count': self.error_count,
                'progress_percentage': (current / total * 100) if total > 0 else 0,
                'elapsed_time_minutes': (time.time() - self.start_time) / 60 if self.start_time else 0
            }
            progress_file = f"{output_file}.progress"
            with open(progress_file, 'w', encoding='utf-8') as f:
                json.dump(progress_data, f, ensure_ascii=False, indent=2)
        except Exception as e:
            logger.error(f"Save progress report failed: {e}")

    def generate_interruption_report(self, output_file: str, processed_count: int, total_queries: int, error_message: str):
        """Write an interruption report with resume instructions after a fatal server failure."""
        try:
            total_time = time.time() - self.start_time if self.start_time else 0
            interruption_report = {
                'interruption_type': 'server_connection_failure',
                'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
                'processed_count': processed_count,
                'total_queries': total_queries,
                'success_count': self.success_count,
                'error_count': self.error_count,
                'progress_percentage': (processed_count / total_queries * 100) if total_queries > 0 else 0,
                'elapsed_time_minutes': total_time / 60,
                'error_message': error_message,
                'resume_instructions': {
                    'checkpoint_file': f"{output_file}.checkpoint",
                    'command': f"await evaluator.evaluate(input_file='data/9.17_evaluate_data_top5_final.json', output_file='{output_file}', resume=True)",
                    'note': '使用 resume=True 参数从检查点恢复评估'
                }
            }
            interruption_file = f"{output_file}.interruption_report"
            with open(interruption_file, 'w', encoding='utf-8') as f:
                json.dump(interruption_report, f, ensure_ascii=False, indent=2)
            logger.info(f"中断报告已保存到: {interruption_file}")
        except Exception as e:
            logger.error(f"Generate interruption report failed: {e}")

    async def evaluate(self, input_file: str, output_file: str = "evaluation_results.json",
                       batch_size: int = 50, start_index: int = 0, max_queries: int = None,
                       checkpoint_file: str = None, resume: bool = True):
        """Execute complete evaluation process with batch processing and checkpoint support.

        Args:
            input_file: Path to the reference dataset JSON.
            output_file: Path for the final results JSON.
            batch_size: Queries per batch (intermediate results saved per batch).
            start_index: Index to start from (overridden by a loaded checkpoint).
            max_queries: Optional cap on the number of queries processed.
            checkpoint_file: Checkpoint path; defaults to ``<output_file>.checkpoint``.
            resume: Whether to restore state from an existing checkpoint.
        """
        logger.info("Start model evaluation...")
        # Reset run-level bookkeeping.
        self.start_time = time.time()
        self.error_count = 0
        self.success_count = 0

        if checkpoint_file is None:
            checkpoint_file = f"{output_file}.checkpoint"

        # Restore previous progress when resuming.
        checkpoint_data = None
        if resume:
            checkpoint_data = self.load_checkpoint(checkpoint_file)
            if checkpoint_data:
                self.results = checkpoint_data.get('results', [])
                start_index = checkpoint_data.get('processed_count', 0)
                # Assumes every checkpointed result was a success.
                self.success_count = len(self.results)
                logger.info(f"Resuming from checkpoint: {start_index} queries already processed")

        # 1. Load data
        data = self.load_data(input_file)
        if not data:
            logger.error("Cannot load data, evaluation terminated")
            return

        # 2. Extract queries
        queries = self.extract_human_queries(data)
        if not queries:
            logger.error("No valid queries found, evaluation terminated")
            return

        # 3. Apply limits and offsets (cap first, then skip already-processed).
        if max_queries:
            queries = queries[:max_queries]
        if start_index > 0:
            queries = queries[start_index:]
            logger.info(f"Starting from index {start_index}, processing {len(queries)} queries")

        # 4. Process queries in batches
        total_queries = len(queries)
        processed_count = len(self.results)  # Continue counting from restored results.
        for batch_start in range(0, total_queries, batch_size):
            batch_end = min(batch_start + batch_size, total_queries)
            batch_queries = queries[batch_start:batch_end]
            logger.info(f"Processing batch {batch_start//batch_size + 1}: queries {batch_start + 1}-{batch_end} of {total_queries}")

            for i, query_data in enumerate(batch_queries):
                global_index = batch_start + i
                logger.info(f"Process {global_index + 1}/{total_queries} query: {query_data['query'][:50]}...")
                try:
                    # Query the server and compare against the reference data.
                    events = await self.call_server(query_data['query'])
                    server_data = self.extract_tool_calls_and_observations(events)
                    original_data = self.extract_original_data(query_data['original_data'])
                    comparison = self.compare_results(server_data, original_data)
                    result = {
                        'index': query_data['index'],
                        'query': query_data['query'],
                        'server_events_count': len(events),
                        'server_tool_calls': server_data['tool_calls'],
                        'server_tool_responses': server_data['tool_responses'],
                        'original_function_calls': original_data['function_calls'],
                        'original_observations': original_data['observations'],
                        'comparison': comparison,
                        'timestamp': time.strftime('%Y-%m-%d %H:%M:%S')
                    }
                    self.results.append(result)
                    processed_count += 1
                    self.success_count += 1

                    # Checkpoint + progress report every 10 processed queries.
                    if processed_count % 10 == 0:
                        self.save_checkpoint(self.results, checkpoint_file, processed_count, total_queries)
                        self.save_progress_report(output_file, processed_count, total_queries)
                        self.print_progress(processed_count, total_queries, self.start_time)

                    # Small delay to avoid hammering the server.
                    await asyncio.sleep(1)
                except Exception as e:
                    logger.error(f"Error processing query {global_index + 1}: {e}")
                    self.error_count += 1
                    # A post-retry server failure is fatal: checkpoint and abort.
                    if "Server connection failed after" in str(e) or "Server timeout after" in str(e) or "Unexpected error after" in str(e):
                        logger.error(f"🚨 服务器连接失败,保存检查点并结束评估")
                        logger.error(f"失败查询: {query_data['query'][:50]}...")
                        self.save_checkpoint(self.results, checkpoint_file, processed_count, total_queries)
                        self.save_progress_report(output_file, processed_count, total_queries)
                        self.generate_interruption_report(output_file, processed_count, total_queries, str(e))
                        logger.error(f"评估因服务器连接失败而中断")
                        logger.error(f"已处理 {processed_count}/{total_queries} 个查询")
                        logger.error(f"检查点已保存到: {checkpoint_file}")
                        logger.error(f"可以稍后使用 resume=True 从检查点恢复")
                        return  # End the evaluation immediately.
                    else:
                        # Non-fatal error: continue with the next query.
                        logger.warning(f"查询处理失败,继续处理下一个查询: {e}")
                        continue

            # Save intermediate results after each batch.
            batch_output_file = f"{output_file}.batch_{batch_start//batch_size + 1}"
            self.save_results(self.results, batch_output_file)
            logger.info(f"Batch {batch_start//batch_size + 1} completed, results saved to {batch_output_file}")
            self.save_checkpoint(self.results, checkpoint_file, processed_count, total_queries)
            self.save_progress_report(output_file, processed_count, total_queries)
            self.print_progress(processed_count, total_queries, self.start_time)

        # 5. Save final results
        self.save_results(self.results, output_file)

        # 6. Remove the checkpoint after a successful full run.
        if os.path.exists(checkpoint_file):
            os.remove(checkpoint_file)
            logger.info("Checkpoint file removed after successful completion")

        # 7. Generate summary report
        self.generate_summary_report()

        # 8. Final progress summary. Guard against an empty (sliced) query
        # list, which previously raised ZeroDivisionError here.
        total_time = time.time() - self.start_time
        avg_time_per_query = total_time / total_queries if total_queries else 0.0
        logger.info(f"=== 评估完成 ===")
        logger.info(f"总查询数: {total_queries}")
        logger.info(f"成功处理: {self.success_count}")
        logger.info(f"处理失败: {self.error_count}")
        logger.info(f"总用时: {total_time/60:.1f}分钟")
        logger.info(f"平均每查询用时: {avg_time_per_query:.1f}秒")

    def generate_summary_report(self):
        """Print a detailed human-readable summary of all evaluation results."""
        if not self.results:
            return
        total_queries = len(self.results)
        global_scores = self.calculate_global_scores(self.results)

        # Collect per-query detail rows for the report body.
        query_details = []
        for i, result in enumerate(self.results):
            comparison = result['comparison']
            overall_scores = comparison.get('overall_scores', {})
            tool_calls_comparison = comparison.get('tool_calls_comparison', {})
            tool_responses_avg = overall_scores.get('tool_responses_avg', 0.0)
            tool_calls_name_avg = tool_calls_comparison.get('name_average_score', 0.0)
            tool_calls_arguments_avg = tool_calls_comparison.get('arguments_average_score', 0.0)
            non_retrieval_name_avg = tool_calls_comparison.get('non_retrieval_name_average_score', 0.0)
            non_retrieval_arguments_avg = tool_calls_comparison.get('non_retrieval_arguments_average_score', 0.0)
            query_details.append({
                'index': i,
                'query': result['query'][:50] + '...' if len(result['query']) > 50 else result['query'],
                'tool_calls_name_score': tool_calls_name_avg,
                'tool_calls_arguments_score': tool_calls_arguments_avg,
                'non_retrieval_name_score': non_retrieval_name_avg,
                'non_retrieval_arguments_score': non_retrieval_arguments_avg,
                'tool_responses_score': tool_responses_avg
            })

        # Backward-compatible exact-match counts.
        tool_calls_perfect_matches = sum(1 for r in self.results if r['comparison']['tool_calls_match'])
        tool_responses_perfect_matches = sum(1 for r in self.results if r['comparison']['tool_responses_match'])

        report = f"""
=== 模型评估详细摘要报告 ===

【整体统计】
总查询数: {total_queries}
工具调用完全匹配数: {tool_calls_perfect_matches} ({tool_calls_perfect_matches/total_queries*100:.1f}%)
工具响应完全匹配数: {tool_responses_perfect_matches} ({tool_responses_perfect_matches/total_queries*100:.1f}%)

【全局平均评分】
工具名称匹配平均分: {global_scores['global_tool_calls_name_avg']:.3f}
工具参数匹配平均分: {global_scores['global_tool_calls_arguments_avg']:.3f}
非retrieval工具名称匹配平均分: {global_scores['global_non_retrieval_name_avg']:.3f}
非retrieval工具参数匹配平均分: {global_scores['global_non_retrieval_arguments_avg']:.3f}
工具响应全局平均分: {global_scores['global_tool_responses_avg']:.3f}

【各查询详细评分】"""

        for detail in query_details:
            report += f"""
Query {detail['index']}: {detail['query']}
  - 工具名称评分: {detail['tool_calls_name_score']:.3f}
  - 工具参数评分: {detail['tool_calls_arguments_score']:.3f}
  - 非retrieval工具名称评分: {detail['non_retrieval_name_score']:.3f}
  - 非retrieval工具参数评分: {detail['non_retrieval_arguments_score']:.3f}
  - 工具响应评分: {detail['tool_responses_score']:.3f}"""

        report += f"""

【评分说明】
- 工具名称匹配分: 工具名称完全一致为1分,否则为0分
- 工具参数匹配分: 工具参数完全一致为1分,否则为0分
- 非retrieval工具名称匹配分: 排除retrieval_tool后,工具名称完全一致为1分,否则为0分
- 非retrieval工具参数匹配分: 排除retrieval_tool后,工具参数完全一致为1分,否则为0分
- 工具响应评分: 完全一致为1分,否则为0分

详细结果请查看 evaluation_results.json 文件
"""
        print(report)
        logger.info("详细摘要报告已生成")

    def test_sse_parsing(self):
        """Smoke-test the SSE parsing pipeline with canned event data."""
        test_data_tool_call = """id: 3
event: tool_call.created
data: {"conversation_id": "c_9c5b3617", "message_id": "m_1248", "sequence": 3, "role": "assistant", "timestamp": "2025-09-18T13:01:34.230464Z", "content": "", "tool_call": {"name": "intelligent_route_analysis", "arguments": {"access": "1"}}}

"""
        test_data_tool_response = """
id: 4
event: tool_response.completed
data: {"conversation_id": "c_9c5b3617", "message_id": "m_1248", "sequence": 4, "role": "tool", "timestamp": "2025-09-18T13:01:34.358678Z", "tool_call_id": "tool_2", "result_delta": {"chat_log_id": 1234, "content": "", "markdown": "智能路由分析结果\\n\\n访问链接: [上传派团单](https://testai.compassaihz.com/#/$&!upload \\"成功匹配到对应页面\\")\\n\\n", "result": {"success": true, "url": "https://testai.compassaihz.com/#/$&!upload", "message": "成功匹配到对应页面"}, "ambulance": "", "potential_tools": [{"api": "/intelligent_route_analysis", "api_cn": "页面跳转工具", "queryData": [{"type": "select", "key": "access", "label": "数字访问码", "value": {"options": [{"label": "1", "value": "1"}, {"label": "2", "value": "2"}, {"label": "3", "value": "3"}, {"label": "4", "value": "4"}, {"label": "5", "value": "5"}, {"label": "6", "value": "6"}, {"label": "7", "value": "7"}, {"label": "8", "value": "8"}, {"label": "9", "value": "9"}, {"label": "10", "value": "10"}]}, "default": "1", "multiple": false}], "description": "智能页面路由工具,通过输入1-10的数字快速跳转到对应的业务功能页面。业务功能包括:上传派团单,新增资源,新增产品,新增协议,新增协议模版,离线上传协议,离线上传价格政策,新增价格政策,新增报表,新增审批流"}], "tool_calling_chain": [{"role": "function", "tool_call": {"name": "intelligent_route_analysis", "arguments": {"access": "1"}}, "tool_response": {"success": true, "url": "https://testai.compassaihz.com/#/$&!upload", "message": "成功匹配到对应页面"}}], "api_Info": {"api": "/intelligent_route_analysis", "api_cn": "页面跳转工具", "queryData": [{"type": "select", "key": "access", "label": "数字访问码", "value": {"options": [{"label": "1", "value": "1"}, {"label": "2", "value": "2"}, {"label": "3", "value": "3"}, {"label": "4", "value": "4"}, {"label": "5", "value": "5"}, {"label": "6", "value": "6"}, {"label": "7", "value": "7"}, {"label": "8", "value": "8"}, {"label": "9", "value": "9"}, {"label": "10", "value": "10"}]}, "default": "1", "multiple": false}], "description": "智能页面路由工具,通过输入1-10的数字快速跳转到对应的业务功能页面。业务功能包括:上传派团单,新增资源,新增产品,新增协议,新增协议模版,离线上传协议,离线上传价格政策,新增价格政策,新增报表,新增审批流"}}, "success": true, "execution_time": 0.0}

"""
        logger.info("=== 开始测试SSE解析功能 ===")
        combined_test_data = test_data_tool_call + test_data_tool_response
        events = self.parse_sse_events(
            combined_test_data,
            filter_events=['tool_call.created', 'tool_response.completed']
        )
        logger.info(f"=== 测试解析结果总结 ===")
        logger.info(f"总共解析到 {len(events)} 个事件")
        for event in events:
            logger.info(f"事件: {event['event']}, ID: {event['id']}")
            logger.info(f"  数据摘要: {str(event['data'])[:100]}...")
        # Exercise the extraction stage on the parsed events too.
        extracted = self.extract_tool_calls_and_observations(events)
        logger.info(f"Extraction results from one tool calling: {extracted}")
        return events


async def main():
    """Main function."""
    evaluator = ModelEvaluator()

    # Optional: self-test of the SSE parser.
    # logger.info("Test SSE parsing function...")
    # evaluator.test_sse_parsing()

    input_file = "data/9.17_evaluate_data_top5_final.json"
    output_file = "eval_results/evaluation_results.json"

    # batch_size: queries per batch; max_queries can cap the run for testing;
    # resume=True continues from the checkpoint if one exists.
    await evaluator.evaluate(
        input_file=input_file,
        output_file=output_file,
        batch_size=50,
        max_queries=None,
        checkpoint_file="eval_results/evaluation_results.json.checkpoint",
        resume=True
    )


if __name__ == "__main__":
    asyncio.run(main())