"""Compute precision and recall from tool_calls_comparison evaluation data.

(计算tool_calls_comparison中precision和recall的脚本)
"""
|
|
import json
import os
from collections import defaultdict
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from loguru import logger

from utils.custom_logging import setup_logging
|
|
| setup_logging() |
|
|
| def load_evaluation_results(file_path: str) -> Dict[str, Any]: |
| """加载评估结果文件""" |
| with open(file_path, 'r', encoding='utf-8') as f: |
| return json.load(f) |
|
|
| def calculate_precision_recall(tp: int, fp: int, fn: int) -> Tuple[float, float]: |
| """计算precision和recall""" |
| precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0 |
| recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0 |
| return precision, recall |
|
|
| def analyze_tool_calls_comparison(results: List[Dict]) -> Dict[str, Any]: |
| """分析tool_calls_comparison数据""" |
| |
| |
| stats = { |
| 'total_queries': len(results), |
| 'retrieval_tool': { |
| 'name': {'tp': 0, 'fp': 0, 'fn': 0}, |
| 'arguments': {'tp': 0, 'fp': 0, 'fn': 0} |
| }, |
| 'non_retrieval_tool': { |
| 'name': {'tp': 0, 'fp': 0, 'fn': 0}, |
| 'arguments': {'tp': 0, 'fp': 0, 'fn': 0} |
| }, |
| 'all_tools': { |
| 'name': {'tp': 0, 'fp': 0, 'fn': 0}, |
| 'arguments': {'tp': 0, 'fp': 0, 'fn': 0} |
| } |
| } |
| |
| processed_queries = 0 |
| total_tool_calls = 0 |
| |
| for result in results: |
| if 'comparison' not in result or 'tool_calls_comparison' not in result['comparison']: |
| continue |
| |
| processed_queries += 1 |
| comparison = result['comparison']['tool_calls_comparison'] |
| detailed_scores = comparison.get('detailed_scores', []) |
| total_tool_calls += len(detailed_scores) |
| |
| |
| for score in detailed_scores: |
| server_name = score.get('server_name', '') |
| original_name = score.get('original_name', '') |
| name_match = score.get('name_match', False) |
| arguments_match = score.get('arguments_match', False) |
| server_present = score.get('server_present', False) |
| original_present = score.get('original_present', False) |
| |
| |
| is_retrieval = 'retrieval' in server_name.lower() or 'retrieval' in original_name.lower() |
| |
| |
| tool_category = 'retrieval_tool' if is_retrieval else 'non_retrieval_tool' |
| |
| |
| if name_match and server_present and original_present: |
| |
| stats[tool_category]['name']['tp'] += 1 |
| stats['all_tools']['name']['tp'] += 1 |
| elif server_present and original_present and not name_match: |
| |
| stats[tool_category]['name']['fp'] += 1 |
| stats['all_tools']['name']['fp'] += 1 |
| elif server_present and not original_present: |
| |
| stats[tool_category]['name']['fp'] += 1 |
| stats['all_tools']['name']['fp'] += 1 |
| elif not server_present and original_present: |
| |
| stats[tool_category]['name']['fn'] += 1 |
| stats['all_tools']['name']['fn'] += 1 |
| |
| |
| if arguments_match and server_present and original_present: |
| |
| stats[tool_category]['arguments']['tp'] += 1 |
| stats['all_tools']['arguments']['tp'] += 1 |
| elif server_present and original_present and not arguments_match: |
| |
| stats[tool_category]['arguments']['fp'] += 1 |
| stats['all_tools']['arguments']['fp'] += 1 |
| elif server_present and not original_present: |
| |
| stats[tool_category]['arguments']['fp'] += 1 |
| stats['all_tools']['arguments']['fp'] += 1 |
| elif not server_present and original_present: |
| |
| stats[tool_category]['arguments']['fn'] += 1 |
| stats['all_tools']['arguments']['fn'] += 1 |
| |
| stats['processed_queries'] = processed_queries |
| stats['total_tool_calls'] = total_tool_calls |
| return stats |
|
|
| def print_results(stats: Dict[str, Any], output_file: str = None): |
| """打印结果并保存到文件""" |
| |
| if output_file is None: |
| timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") |
| output_file = f"precision_recall_results_{timestamp}.txt" |
| |
| |
| output_lines = [] |
| |
| def log_and_save(message: str): |
| logger.info(message) |
| output_lines.append(message) |
| |
| log_and_save("=" * 80) |
| log_and_save("Tool Calls Comparison - Precision & Recall 分析结果") |
| log_and_save("=" * 80) |
| log_and_save(f"总查询数: {stats['total_queries']}") |
| log_and_save(f"处理查询数: {stats['processed_queries']}") |
| log_and_save(f"总Tool Calls数: {stats['total_tool_calls']}") |
| log_and_save("") |
| |
| categories = ['retrieval_tool', 'non_retrieval_tool', 'all_tools'] |
| category_names = { |
| 'retrieval_tool': '包含Retrieval Tool', |
| 'non_retrieval_tool': '非Retrieval Tool', |
| 'all_tools': '所有Tool' |
| } |
| |
| for category in categories: |
| log_and_save(f"【{category_names[category]}】") |
| log_and_save("-" * 40) |
| |
| for metric_type in ['name', 'arguments']: |
| metric_name = 'Tool Call Name' if metric_type == 'name' else 'Tool Call Arguments' |
| tp = stats[category][metric_type]['tp'] |
| fp = stats[category][metric_type]['fp'] |
| fn = stats[category][metric_type]['fn'] |
| |
| precision, recall = calculate_precision_recall(tp, fp, fn) |
| f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0 |
| |
| log_and_save(f"{metric_name}:") |
| log_and_save(f" TP: {tp}, FP: {fp}, FN: {fn}") |
| log_and_save(f" Precision: {precision:.4f}") |
| log_and_save(f" Recall: {recall:.4f}") |
| log_and_save(f" F1-Score: {f1:.4f}") |
| log_and_save("") |
| log_and_save("") |
| |
| |
| try: |
| with open(output_file, 'w', encoding='utf-8') as f: |
| f.write('\n'.join(output_lines)) |
| logger.info(f"结果已保存到文件: {output_file}") |
| except Exception as e: |
| logger.error(f"保存文件失败: {e}") |
|
|
| def main(): |
| """主函数""" |
| file_path = 'eval_results/evaluation_results.json' |
| |
| logger.info("正在加载评估结果...") |
| data = load_evaluation_results(file_path) |
| |
| logger.info("正在分析tool_calls_comparison数据...") |
| stats = analyze_tool_calls_comparison(data['results']) |
| |
| print_results(stats) |
|
|
| if __name__ == "__main__": |
| main() |
|
|