| | |
| | |
| | """ |
| | 训练数据评估脚本 |
| | 根据9.17_evaluate_data_top5_final.json的数据结构,将conversations分成source-target pairs, |
| | 使用LLM生成预测并评估工具调用和文本生成的质量 |
| | """ |
| |
|
| | import json |
| | import asyncio |
| | import re |
| | import sys |
| | import os |
| | import time |
| | import requests |
| | import argparse |
| | from typing import List, Dict, Tuple, Any, Optional |
| | from dataclasses import dataclass, asdict |
| | from loguru import logger |
| | from pathlib import Path |
| | from collections import defaultdict |
| | import aiohttp |
| | from concurrent.futures import ThreadPoolExecutor |
| | import signal |
| |
|
| | |
| | def _round_floats(obj: Any, ndigits: int = 3) -> Any: |
| | if isinstance(obj, float): |
| | return round(obj, ndigits) |
| | if isinstance(obj, list): |
| | return [_round_floats(x, ndigits) for x in obj] |
| | if isinstance(obj, dict): |
| | return {k: _round_floats(v, ndigits) for k, v in obj.items()} |
| | return obj |
| |
|
| | |
# --- Runtime configuration ---------------------------------------------------
# Secrets are read from the environment and must never be committed to source.
# NOTE(review): a literal Gemini key was previously hard-coded on this line;
# treat that key as leaked and rotate it.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")

# Base URL and optional bearer token of the vLLM OpenAI-compatible server.
VLLM_BASE_URL = os.getenv("VLLM_BASE_URL", "http://125.122.38.32:8021")
VLLM_API_KEY = os.getenv("VLLM_API_KEY", "")

# Value sent in the "model" field of every chat-completion request.
QWEN_MODEL_NAME = os.getenv("QWEN_MODEL_NAME", "my_lora")

# Chat-completions endpoint derived from the base URL (trailing "/" tolerated).
QWEN_API_URL = f"{VLLM_BASE_URL.rstrip('/')}/v1/chat/completions"

# Retrieval service used to compute recall@5 for the second tool call.
RETRIEVAL_ENDPOINT = os.getenv("RETRIEVAL_ENDPOINT", "http://125.122.38.32:8024/retrieval_tool")
RETRIEVAL_HEADERS = {
    "accept": "application/json",
    "Content-Type": "application/json",
}

# Set EVAL_DISABLE_RECALL to 1/true/yes to skip the retrieval call entirely.
DISABLE_RECALL = str(os.getenv("EVAL_DISABLE_RECALL", "0")).lower() in ("1", "true", "yes")

# Concurrency limits, each overridable through the environment.
MAX_CONCURRENT_CONVERSATIONS = int(os.getenv("MAX_CONCURRENT_CONVERSATIONS", "5"))
MAX_CONCURRENT_PAIRS = int(os.getenv("MAX_CONCURRENT_PAIRS", "10"))
MAX_CONCURRENT_API_CALLS = int(os.getenv("MAX_CONCURRENT_API_CALLS", "20"))
| |
|
@dataclass
class EvaluationPair:
    """One source/target example cut out of a conversation."""

    # 1-based position of this pair within its conversation.
    pair_id: int
    # Prompt text handed to the model (system prompt plus user turn or tool response).
    source: str
    # Reference output the prediction is compared against.
    target: str
    # Kind of pair, e.g. "tool_call" or "text_generation".
    pair_type: str
    # Identifier of the conversation the pair was extracted from.
    conversation_id: int
| |
|
@dataclass
class EvaluationResult:
    """Outcome of evaluating a single pair."""

    conversation_id: int
    pair_id: int
    pair_type: str
    source: str
    target: str
    # Raw model output for `source`.
    predict: str
    # Overall score for the pair.
    score: float
    # Tool-name match score (the tool-call evaluator reports 1.0 or 0.0).
    tool_name_score: float
    # Retrieval recall flag; None when recall was not computed for this pair.
    recall: Optional[int] = None
    recall_details: Optional[Dict[str, Any]] = None
    # Evaluator-specific breakdown; defaults to None, hence Optional.
    details: Optional[Dict[str, Any]] = None
| |
|
@dataclass
class RealTimeMetrics:
    """Running aggregate of the metrics reported while an evaluation is in progress.

    Every dict-valued field defaults to ``None`` and is replaced by a fresh
    zeroed dict in ``__post_init__``, so instances never share state.
    """

    total_conversations: int = 0
    total_pairs: int = 0

    # Tool-call metrics for pair 1.
    pair1: Dict[str, float] = None

    # Tool-call metrics for pair 2, over all results and over the recall==1 subset.
    pair2: Dict[str, float] = None
    pair2_consider_recall: Dict[str, float] = None

    # Text-generation metrics.
    pair3: Dict[str, float] = None

    # Retrieval recall counters.
    recall_metrics: Dict[str, Any] = None

    # Metrics over every result, without recall filtering.
    overall_current_logic: Dict[str, float] = None

    def __post_init__(self):
        """Fill each unset metric field with its zero-valued template."""
        tool_call_blank = {"total": 0, "accuracy": 0.0, "precision@1": 0.0}
        blanks = (
            ("pair1", tool_call_blank),
            ("pair2", tool_call_blank),
            ("pair2_consider_recall", tool_call_blank),
            ("pair3", {"total": 0, "answer_score": 0.0}),
            ("recall_metrics", {"total_pairs": 0, "recall@5_1": 0, "recall@5_0": 0, "recall_rate": 0.0}),
            ("overall_current_logic", {"total": 0, "accuracy": 0.0, "precision@1": 0.0, "answer_score": 0.0}),
        )
        for attr, blank in blanks:
            if getattr(self, attr) is None:
                # Copy so that no two fields or instances alias the same dict.
                setattr(self, attr, dict(blank))
| |
|
class DataProcessor:
    """Data-processing stage: cuts a conversation into source/target pairs."""

    def __init__(self):
        logger.info("初始化数据处理模块")

    @staticmethod
    def _strip_tools_section(system_prompt):
        """Empty out an existing <tools>...</tools> payload in the system prompt."""
        if not ('<tools>' in system_prompt and '</tools>' in system_prompt):
            return system_prompt
        try:
            return re.sub(r'<tools>\s*[\s\S]*?</tools>', '<tools>\n</tools>', system_prompt)
        except Exception:
            return system_prompt.replace('<tools>\n</tools>', '<tools>\n</tools>').replace('<tools></tools>', '<tools>\n</tools>')

    @staticmethod
    def _render_tools(tools):
        """Serialize the tool list for the prompt.

        When `tools` decodes to a non-empty list whose first element is a dict,
        only that FIRST tool is rendered, wrapped as
        {"type": "function", "function": ...}. Otherwise the raw/JSON-dumped
        `tools` value is used as-is.
        """
        try:
            fallback = tools if isinstance(tools, str) else json.dumps(tools, ensure_ascii=False)
        except Exception:
            fallback = str(tools)

        try:
            decoded = json.loads(tools) if isinstance(tools, str) else tools
        except Exception:
            decoded = tools

        try:
            if isinstance(decoded, list) and decoded and isinstance(decoded[0], dict):
                return json.dumps({"type": "function", "function": decoded[0]}, ensure_ascii=False)
        except Exception:
            pass
        return fallback

    def parse_conversations(self, conversation_data: Dict, conversation_id: int) -> List[EvaluationPair]:
        """Split one conversation record into evaluation pairs.

        Pairs produced (all with pair_type "tool_call"):
        - human -> function_call: source is the system prompt plus "User: <query>".
        - observation -> function_call: source is the system prompt plus a
          <tool_response> block holding the first human query and the observation.

        An observation followed by a "gpt" turn is consumed without producing a
        pair. Pair ids are assigned sequentially starting at 1.

        Args:
            conversation_data: record with "conversations", "system" and
                optionally "tools".
            conversation_id: id stored on every produced pair.

        Returns:
            The pairs in conversation order.

        Raises:
            KeyError: if "conversations"/"system" or a message's "from"/"value"
                key is missing.
        """
        conversations = conversation_data["conversations"]
        system_prompt = conversation_data["system"]
        tools = conversation_data.get("tools", "[]")

        # The first human message is reused inside every <tool_response> block.
        first_query = next((turn["value"] for turn in conversations if turn["from"] == "human"), "")

        tools_section = (
            "\n\n# Tools\n\n"
            "You may call one or more functions to assist with the user query.\n\n"
            "You are provided with function signatures within <tools></tools> XML tags:\n"
            "<tools>\n"
            f"{self._render_tools(tools)}\n"
            "</tools>\n\n"
            "For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n"
            "<tool_call>\n"
            "{\"name\": <function-name>, \"arguments\": <args-json-object>}\n"
            "</tool_call>"
        )
        base_system = f"{self._strip_tools_section(system_prompt)}{tools_section}"

        pairs: List[EvaluationPair] = []

        def emit(source: str, target: str) -> None:
            # pair_id is simply the 1-based position in the output list.
            pairs.append(EvaluationPair(
                pair_id=len(pairs) + 1,
                source=source,
                target=target,
                pair_type="tool_call",
                conversation_id=conversation_id,
            ))

        total = len(conversations)
        idx = 0
        while idx < total:
            current = conversations[idx]
            follower = conversations[idx + 1] if idx + 1 < total else None
            role = current["from"]

            if role == "human":
                if follower is not None and follower["from"] == "function_call":
                    emit(f"{base_system}\n\nUser: {current['value']}", follower["value"])
                    idx += 2
                    continue
            elif role == "observation" and follower is not None:
                follower_role = follower["from"]
                if follower_role == "function_call":
                    tool_resp_block = (
                        f"<tool_response>\n"
                        f"用户查询: {first_query}\n\n"
                        f"工具返回结果: {current['value']}\n"
                        f"</tool_response>"
                    )
                    emit(f"{base_system}\n\n{tool_resp_block}", follower["value"])
                    idx += 2
                    continue
                if follower_role == "gpt":
                    # Final-answer turns are skipped; no pair is generated for them.
                    idx += 2
                    continue
            idx += 1

        logger.info(f"成功解析出 {len(pairs)} 个评估对 (conversation_id: {conversation_id})")
        return pairs
| |
|
class LLMPredictor:
    """Prediction stage: turns a pair's source into a model output via the Qwen API."""

    # Patterns for the <tool_response> block; compiled once.
    _TOOL_RESPONSE_RE = re.compile(r'<tool_response>[\s\S]*?</tool_response>')
    # The query runs up to the blank line that precedes the tool-result label.
    _USER_QUERY_RE = re.compile(r'用户查询:\s*([\s\S]*?)\n\n工具返回结果:')
    # The tool result runs up to the closing tag, so multi-line results are
    # kept whole instead of being cut at their first newline.
    _TOOL_RESULT_RE = re.compile(r'工具返回结果:\s*([\s\S]*?)\s*</tool_response>')

    def __init__(self, model_type: str = "qwen3"):
        # NOTE: `model_type` is accepted for interface compatibility only; the
        # model actually requested is always QWEN_MODEL_NAME.
        self.model_type = QWEN_MODEL_NAME
        self.max_retries = 5
        self.retry_delay = 10  # kept for compatibility; backoff below is 2**attempt
        logger.info(f"初始化LLM预测模块,使用模型: {self.model_type}")

    async def call_qwen_api(self, session: aiohttp.ClientSession, prompt: List[Dict], temperature: float = 0.0, top_p: float = 1.0) -> str:
        """POST a chat-completion request and return the stripped reply text.

        Non-200 responses and network/timeout errors are retried up to
        ``self.max_retries`` times with exponential backoff (1s, 2s, 4s, ...).

        Args:
            session: aiohttp session used for the request.
            prompt: chat messages (list of {"role", "content"} dicts).
            temperature: sampling temperature.
            top_p: nucleus-sampling parameter.

        Returns:
            The reply with any <think>...</think> section removed.

        Raises:
            Exception: when the last attempt also fails.
        """
        headers = {"Content-Type": "application/json"}
        if VLLM_API_KEY:
            headers["Authorization"] = f"Bearer {VLLM_API_KEY}"

        data = {
            "model": self.model_type,
            "messages": prompt,
            "temperature": temperature,
            "top_p": top_p,
            "stream": False,
            "chat_template_kwargs": {
                "enable_thinking": False
            }
        }

        for attempt in range(self.max_retries):
            cause = None
            try:
                async with session.post(QWEN_API_URL, headers=headers, json=data, timeout=aiohttp.ClientTimeout(total=120)) as response:
                    if response.status == 200:
                        result = await response.json()
                        # `content` can be null in the response; treat that as empty.
                        content = result['choices'][0]['message']['content'] or ""
                        content = re.sub(r"<think>[\s\S]*?</think>", "", content, flags=re.IGNORECASE)
                        logger.debug(f"LLM 返回片段: {content[:400]}")
                        return content.strip()
                    error_msg = f"API调用失败,状态码: {response.status}, 响应: {await response.text()}"
            except (aiohttp.ClientError, asyncio.TimeoutError) as e:
                cause = e
                error_msg = f"网络请求异常: {str(e)}"

            # The response is released before backing off or giving up.
            if attempt >= self.max_retries - 1:
                raise Exception(error_msg) from cause
            logger.warning(f"第{attempt+1}次尝试失败,{error_msg},正在重试...")
            await asyncio.sleep(2 ** attempt)

        return ""

    @classmethod
    def _split_source(cls, source: str) -> Tuple[str, str]:
        """Split a pair source into (system_content, user_content)."""
        system_content = source
        user_content = None

        # Pair built from a human turn: everything after the first "User: " marker.
        if "\n\nUser: " in source:
            system_content, user_content = source.split("\n\nUser: ", 1)

        # Pair built from an observation: lift the <tool_response> block out.
        if user_content is None:
            block_match = cls._TOOL_RESPONSE_RE.search(source)
            if block_match:
                block = block_match.group(0)
                query_match = cls._USER_QUERY_RE.search(block)
                result_match = cls._TOOL_RESULT_RE.search(block)
                if query_match and result_match:
                    user_query = query_match.group(1).strip()
                    tool_result = result_match.group(1).strip()
                    user_content = f"用户问题:{user_query}\n\n工具返回结果:{tool_result}"
                else:
                    # Unrecognized layout: pass the raw block through.
                    user_content = block
                system_content = source.replace(block, "").strip()

        # Make sure no user turn is left behind in the system message.
        if "\n\nUser: " in system_content:
            system_content = system_content.split("\n\nUser: ", 1)[0].rstrip()

        return system_content, "" if user_content is None else user_content

    async def predict(self, session: aiohttp.ClientSession, source: str, pair_type: str) -> str:
        """Generate a prediction for `source`.

        The user query (or tool response) is moved into the user message and
        the system message keeps only instructions and tool definitions. For
        "tool_call" pairs an instruction to emit a single <tool_call> is added.

        Returns:
            The model output, or "" if anything fails (the error is logged).
        """
        try:
            system_content, user_content = self._split_source(source)

            if pair_type == "tool_call":
                if user_content.strip():
                    user_content = f"{user_content}\n\n只输出一个<tool_call>,不要输出解释性文本或答案。"
                else:
                    user_content = "只输出一个<tool_call>,不要输出解释性文本或答案。"
            elif not user_content.strip():
                user_content = "请根据工具返回的结果生成最终回答。"

            prompt = [
                {"role": "system", "content": system_content},
                {"role": "user", "content": user_content}
            ]

            mode = 'tool_call' if pair_type == 'tool_call' else 'text_generation'
            logger.info(f"LLM prompt: {prompt}, user指令: {mode} ")

            result = await self.call_qwen_api(session, prompt, temperature=0.0, top_p=1.0)
            logger.info(f"LLM 输出长度: {len(result)},预览: {result[:5000]}")
            return result
        except Exception as e:
            logger.error(f"LLM预测失败: {e}")
            return ""
| |
|
class RetrievalToolCaller:
    """Retrieval stage: queries the retrieval service and scores recall@5."""

    # Keys probed, in this order, for a tool name inside a retrieval hit.
    _NAME_KEYS = ("name", "tool_name", "title", "id", "label", "api_name")

    def __init__(self):
        self.max_retries = 3
        self.retry_delay = 2
        logger.info("初始化检索工具调用模块")

    @staticmethod
    def _build_params(query: Any) -> Dict[str, Any]:
        """Wrap a query in the fixed argument set sent to the retrieval tool."""
        return {
            "query": query,
            "source_filter": "toollist",
            "user_id": 136451106,
            "top_k": 5
        }

    def extract_query_params(self, pair1_source: str) -> Dict[str, Any]:
        """Build retrieval parameters from the user query in pair 1's source.

        The query is everything after the first "User: " marker (preferring the
        "\\n\\nUser: " form used when the source was built), so a query that
        itself contains "User: " is kept whole. Returns {} on failure.
        """
        try:
            user_query = ""
            for marker in ("\n\nUser: ", "User: "):
                if marker in pair1_source:
                    user_query = pair1_source.split(marker, 1)[1].strip()
                    break
            return self._build_params(user_query)
        except Exception as e:
            logger.error(f"提取查询参数失败: {e}")
            return {}

    def _extract_tool_call_from_text(self, text: str) -> Dict[str, Any]:
        """Parse a tool call from bare JSON or <tool_call>{...}</tool_call>; {} on failure."""
        try:
            text = text.strip()
            if text.startswith('{') and text.endswith('}'):
                return json.loads(text)
            match = re.search(r'<tool_call>\s*({[\s\S]*?})\s*</tool_call>', text)
            if match:
                return json.loads(match.group(1))
            return json.loads(text)
        except Exception:
            return {}

    def extract_query_params_from_pair1_predict(self, pair1_predict: str) -> Dict[str, Any]:
        """Build retrieval parameters from `arguments.query` of pair 1's predicted call.

        A missing query yields parameters with an empty "query"; {} is returned
        only when extraction raises.
        """
        try:
            call_obj = self._extract_tool_call_from_text(pair1_predict)
            arguments = call_obj.get("arguments", {}) if isinstance(call_obj, dict) else {}
            return self._build_params(arguments.get("query", ""))
        except Exception as e:
            logger.error(f"从pair1预测中提取检索参数失败: {e}")
            return {}

    async def call_retrieval_tool(self, session: "aiohttp.ClientSession", params: Dict[str, Any]) -> Tuple[int, Dict[str, Any]]:
        """Call the retrieval service with a JSON-RPC "tools/call" payload.

        Returns:
            (http_status, body). A body that is not JSON is returned as
            {"raw": <text>}. After ``self.max_retries`` failed attempts the
            result is (0, {"error": <message>}).
        """
        payload = {
            "jsonrpc": "2.0",
            "id": "req_001",
            "method": "tools/call",
            "params": {
                "name": "retrieval_tool",
                "arguments": params,
            },
        }

        for attempt in range(self.max_retries):
            try:
                async with session.post(RETRIEVAL_ENDPOINT, headers=RETRIEVAL_HEADERS, json=payload, timeout=aiohttp.ClientTimeout(total=20)) as resp:
                    code = resp.status
                    try:
                        data = await resp.json()
                    except Exception:
                        data = {"raw": await resp.text()}
                    return code, data
            except Exception as e:
                if attempt < self.max_retries - 1:
                    logger.warning(f"检索工具调用失败,第{attempt+1}次尝试: {e}")
                    await asyncio.sleep(self.retry_delay)
                else:
                    logger.error(f"检索工具调用失败,已尝试{self.max_retries}次: {e}")
                    return 0, {"error": str(e)}
        # Only reachable when max_retries <= 0.
        return 0, {"error": "no retrieval attempt was made"}

    def _first_name(self, item: Dict[str, Any]) -> Optional[str]:
        """Return the first string value found under one of the known name keys."""
        for key in self._NAME_KEYS:
            if key in item and isinstance(item[key], str):
                return item[key]
        return None

    def extract_retrieved_tools(self, response_obj: Dict[str, Any], top_k: int = 5) -> List[str]:
        """Extract up to `top_k` tool names from a retrieval response.

        Looks at a list under "result" first, otherwise a list under "data".
        If nothing is found, falls back to every `"name": "..."` occurrence in
        the serialized response.
        """
        tools: List[str] = []

        try:
            if "result" in response_obj and isinstance(response_obj["result"], list):
                for item in response_obj["result"][:top_k]:
                    if not isinstance(item, dict):
                        continue
                    name = self._first_name(item)
                    if name is not None:
                        tools.append(name)
                    elif not any(key in item for key in self._NAME_KEYS):
                        # No known key at all: take the first quoted string of the item.
                        quoted = re.findall(r'"([^"]+)"', json.dumps(item, ensure_ascii=False))
                        if quoted:
                            tools.append(quoted[0])
            elif "data" in response_obj and isinstance(response_obj["data"], list):
                for item in response_obj["data"][:top_k]:
                    if isinstance(item, dict):
                        name = self._first_name(item)
                        if name is not None:
                            tools.append(name)

            if not tools:
                text = json.dumps(response_obj, ensure_ascii=False)
                tools = re.findall(r'"name":\s*"([^"]+)"', text)[:top_k]

        except Exception as e:
            logger.error(f"提取检索工具时出错: {e}")

        return tools[:top_k]

    def _score_recall(self, params: Dict[str, Any], status_code: int, response: Dict[str, Any], pair2_target: str) -> Tuple[int, Dict[str, Any]]:
        """Compare the retrieved tool names with the tool named in pair 2's target."""
        retrieved_tools = self.extract_retrieved_tools(response, top_k=5)

        # Same tolerant parser as for predictions, so a target wrapped in
        # <tool_call> tags is understood as well as bare JSON.
        target_call = self._extract_tool_call_from_text(pair2_target)
        target_tool = target_call.get("name", "") if isinstance(target_call, dict) else ""

        recall = 1 if target_tool in retrieved_tools else 0
        return recall, {
            "target_tool": target_tool,
            "retrieved_tools": retrieved_tools,
            "recall": recall,
            "query_params": params,
            "response_status": status_code
        }

    def compute_recall(self, pair1_source: str, pair2_target: str) -> Tuple[int, Dict[str, Any]]:
        """Synchronous recall check driven by the original query in pair 1's source.

        Opens a temporary aiohttp session and runs the async retrieval call to
        completion. It therefore cannot be used from inside a running event
        loop; use ``compute_recall_from_pair1_predict`` there.

        Returns:
            (recall, details); on any failure (0, {"error": <message>}).
        """
        try:
            params = self.extract_query_params(pair1_source)
            if not params:
                return 0, {"error": "无法提取查询参数"}

            try:
                asyncio.get_running_loop()
            except RuntimeError:
                pass  # no loop is running, so asyncio.run() is safe
            else:
                return 0, {"error": "compute_recall cannot run inside a running event loop"}

            async def _fetch():
                async with aiohttp.ClientSession() as session:
                    return await self.call_retrieval_tool(session, params)

            status_code, response = asyncio.run(_fetch())
            if status_code != 200:
                return 0, {"error": f"检索工具调用失败,状态码: {status_code}"}

            return self._score_recall(params, status_code, response, pair2_target)

        except Exception as e:
            logger.error(f"计算recall失败: {e}")
            return 0, {"error": str(e)}

    async def compute_recall_from_pair1_predict(self, session: "aiohttp.ClientSession", pair1_predict: str, pair2_target: str) -> Tuple[int, Dict[str, Any]]:
        """Recall check driven by the `query` argument of pair 1's predicted call.

        Returns:
            (recall, details); on any failure (0, {"error": <message>}).
        """
        try:
            params = self.extract_query_params_from_pair1_predict(pair1_predict)
            if not params:
                return 0, {"error": "无法从pair1预测中提取检索参数"}

            logger.info(f"调用检索工具 - 查询参数: {params.get('query', '')[:100]}")

            status_code, response = await self.call_retrieval_tool(session, params)
            if status_code != 200:
                logger.warning(f"检索工具调用失败,状态码: {status_code}")
                return 0, {"error": f"检索工具调用失败,状态码: {status_code}"}

            recall, recall_details = self._score_recall(params, status_code, response, pair2_target)
            retrieved_tools = recall_details["retrieved_tools"]
            logger.info(f"检索工具返回 - 获取到 {len(retrieved_tools)} 个工具: {retrieved_tools}")
            return recall, recall_details
        except Exception as e:
            logger.error(f"计算recall失败(基于pair1预测): {e}")
            return 0, {"error": str(e)}
| |
|
class ToolCallEvaluator:
    """Tool-call evaluation: compares tool choice and arguments of target vs. prediction."""

    def __init__(self):
        logger.info("初始化工具调用评估模块")

    def extract_tool_call(self, text: str) -> Dict[str, Any]:
        """Parse a tool call from bare JSON or a <tool_call>{...}</tool_call> wrapper.

        Returns:
            The decoded call as a dict. Anything that cannot be parsed, or that
            parses to a non-object JSON value (null, list, number, string),
            yields {} so that callers can always use ``.get``.
        """
        try:
            text = text.strip()
            if text.startswith('{') and text.endswith('}'):
                parsed = json.loads(text)
            else:
                match = re.search(r'<tool_call>\s*({.*?})\s*</tool_call>', text, re.DOTALL)
                parsed = json.loads(match.group(1) if match else text)
        except Exception:
            return {}
        return parsed if isinstance(parsed, dict) else {}

    @staticmethod
    def _coerce_arguments(arguments: Any) -> Dict[str, Any]:
        """Return call arguments as a dict.

        Arguments given as a JSON-encoded string are decoded; any other
        non-dict value is treated as "no arguments".
        """
        if isinstance(arguments, str):
            try:
                arguments = json.loads(arguments)
            except ValueError:
                return {}
        return arguments if isinstance(arguments, dict) else {}

    def evaluate_tool_call(self, target: str, predict: str) -> Tuple[float, float, Dict[str, Any]]:
        """Score a predicted tool call against the target call.

        Scoring: 0.5 for an exact, non-empty tool-name match, plus 0.5 times
        the fraction of target arguments whose predicted value is equal. The
        argument part only counts when both calls carry arguments.

        Returns:
            (score, tool_name_score, details), where tool_name_score is 1.0 on
            a name match and 0.0 otherwise.
        """
        target_call = self.extract_tool_call(target)
        predict_call = self.extract_tool_call(predict)
        if not predict_call:
            logger.debug(f"predict 非结构化输出,无法解析为工具调用。predict预览: {predict[:300]}")

        details = {
            "target_call": target_call,
            "predict_call": predict_call,
            "tool_name_match": False,
            "arguments_match": False,
            "argument_details": {}
        }

        score = 0.0
        tool_name_score = 0.0

        target_name = target_call.get("name", "")
        predict_name = predict_call.get("name", "")
        if target_name and target_name == predict_name:
            details["tool_name_match"] = True
            score += 0.5
            tool_name_score = 1.0

        target_args = self._coerce_arguments(target_call.get("arguments", {}))
        predict_args = self._coerce_arguments(predict_call.get("arguments", {}))

        if target_args and predict_args:
            matching_args = 0
            for key, target_value in target_args.items():
                predict_value = predict_args.get(key)
                match = (predict_value == target_value)
                details["argument_details"][key] = {
                    "target": target_value,
                    "predict": predict_value,
                    "match": match
                }
                if match:
                    matching_args += 1

            # target_args is non-empty here, so the division is safe.
            arg_score = matching_args / len(target_args)
            details["arguments_match"] = (arg_score == 1.0)
            score += 0.5 * arg_score

        return score, tool_name_score, details
| |
|
class TextGenerationEvaluator:
    """Text-generation evaluation: asks the served model to judge prediction vs. target."""

    def __init__(self, model_type: str = "qwen3"):
        # NOTE: `model_type` is accepted for interface compatibility only; the
        # model actually requested is always QWEN_MODEL_NAME.
        self.model_type = QWEN_MODEL_NAME
        self.max_retries = 5
        self.retry_delay = 10
        logger.info(f"初始化文本生成评估模块,使用模型: {self.model_type}")

    def call_gemini_api(self, prompt: str, temperature: float = 0.3, top_p: float = 0.95, top_k: int = 40) -> str:
        """Blocking call to the Gemini generateContent endpoint.

        Retries ``self.max_retries`` times, sleeping ``self.retry_delay``
        seconds between attempts. Returns the first candidate's text, or ""
        when the response has no usable text or every attempt fails.

        NOTE(review): the URL uses the "v2beta" path and ``self.model_type``
        (the LoRA model name) as the Gemini model; both look unintended —
        confirm before relying on this method. The key is sent in the query
        string, so it may appear in request logs.
        """
        url = f"https://generativelanguage.googleapis.com/v2beta/models/{self.model_type}:generateContent?key={GEMINI_API_KEY}"
        headers = {"Content-Type": "application/json"}
        payload = {
            "contents": [
                {
                    "role": "user",
                    "parts": [{"text": prompt}]
                }
            ],
            "generationConfig": {
                "temperature": float(temperature),
                "topP": float(top_p),
                "topK": int(top_k),
                "maxOutputTokens": 8192
            }
        }

        for attempt in range(self.max_retries):
            try:
                response = requests.post(url, headers=headers, json=payload, timeout=60)
                response.raise_for_status()
                raw = response.json()
                try:
                    return raw["candidates"][0]["content"]["parts"][0]["text"]
                except Exception:
                    # Well-formed response without text (e.g. blocked output).
                    return ""
            except Exception as e:
                if attempt < self.max_retries - 1:
                    time.sleep(self.retry_delay)
                else:
                    logger.error(f"API调用失败 (尝试 {attempt+1}/{self.max_retries}): {e}")
        return ""

    async def call_qwen_api(self, session: "aiohttp.ClientSession", prompt: List[Dict], temperature: float = 0.3, top_p: float = 0.95) -> str:
        """POST a chat-completion request and return the stripped reply text.

        Non-200 responses and network/timeout errors are retried up to
        ``self.max_retries`` times with exponential backoff (1s, 2s, 4s, ...).

        Returns:
            The reply with any <think>...</think> section removed.

        Raises:
            Exception: when the last attempt also fails.
        """
        headers = {"Content-Type": "application/json"}
        if VLLM_API_KEY:
            headers["Authorization"] = f"Bearer {VLLM_API_KEY}"

        data = {
            "model": self.model_type,
            "messages": prompt,
            "temperature": temperature,
            "top_p": top_p,
            "stream": False,
            "chat_template_kwargs": {
                "enable_thinking": False
            }
        }

        try:
            logger.debug(f"LLM 调用完整messages: {json.dumps(data.get('messages', []), ensure_ascii=False) }")
        except Exception:
            pass  # logging must never block the request

        for attempt in range(self.max_retries):
            cause = None
            try:
                async with session.post(QWEN_API_URL, headers=headers, json=data, timeout=aiohttp.ClientTimeout(total=120)) as response:
                    if response.status == 200:
                        result = await response.json()
                        content = result['choices'][0]['message']['content'] or ""
                        content = re.sub(r"<think>[\s\S]*?</think>", "", content, flags=re.IGNORECASE)
                        # Return the successful reply instead of falling through
                        # to another attempt and finally an empty string.
                        return content.strip()
                    error_msg = f"API调用失败,状态码: {response.status}, 响应: {await response.text()}"
            except (aiohttp.ClientError, asyncio.TimeoutError) as e:
                cause = e
                error_msg = f"网络请求异常: {str(e)}"

            if attempt >= self.max_retries - 1:
                raise Exception(error_msg) from cause
            logger.warning(f"第{attempt+1}次尝试失败,{error_msg},正在重试...")
            await asyncio.sleep(2 ** attempt)

        return ""

    async def evaluate_text_generation(self, session: "aiohttp.ClientSession", target: str, predict: str) -> Tuple[float, Dict[str, Any]]:
        """Score `predict` against `target` with the served model as judge.

        The judge is asked for a JSON verdict; its "overall_score" (0-10) is
        scaled to 0-1 and clamped to that range. If no JSON can be found, or
        anything fails, a word-overlap similarity is used instead.

        Returns:
            (score, details), where details is the judge's JSON or a fallback
            description.
        """
        judge_prompt = f"""
请评估以下两个文本的相似度和质量,从以下几个维度进行评分(每个维度0-10分):

1. 内容准确性:预测文本是否准确传达了目标文本的主要信息
2. 完整性:预测文本是否包含了目标文本的关键要素
3. 表达质量:预测文本的语言表达是否清晰、流畅
4. 格式一致性:预测文本的格式是否与目标文本相似

目标文本:
{target}

预测文本:
{predict}

请按以下JSON格式返回评估结果:
{{
"content_accuracy": <0-10分>,
"completeness": <0-10分>,
"expression_quality": <0-10分>,
"format_consistency": <0-10分>,
"overall_score": <0-10分>,
"reasoning": "详细说明评分理由"
}}
"""

        try:
            prompt = [
                {"role": "system", "content": "你是一个专业的文本质量评估专家,能够客观地评估文本的相似度和质量。"},
                {"role": "user", "content": judge_prompt}
            ]

            result = await self.call_qwen_api(session, prompt, temperature=0.3, top_p=0.95)

            json_match = re.search(r'\{.*\}', result, re.DOTALL)
            if json_match:
                eval_result = json.loads(json_match.group())
                # The judge may answer "8" or 12; normalize to a float in [0, 1].
                overall_score = float(eval_result.get("overall_score", 0)) / 10.0
                overall_score = max(0.0, min(1.0, overall_score))
                return overall_score, eval_result

            logger.warning("无法解析JSON评估结果,使用简单文本匹配评分")
            simple_score = self._simple_text_similarity_score(target, predict)
            return simple_score, {"overall_score": simple_score * 10, "method": "simple_similarity"}

        except Exception as e:
            logger.error(f"文本生成评估失败: {e}")
            simple_score = self._simple_text_similarity_score(target, predict)
            return simple_score, {"error": str(e), "fallback_score": simple_score * 10}

    def _simple_text_similarity_score(self, target: str, predict: str) -> float:
        """Fallback similarity: 0.7 * word-overlap ratio + 0.3 * length ratio.

        Words are whitespace-separated, lower-cased tokens. Returns 0.0 for a
        target without words and 0.5 if the computation itself fails.
        """
        try:
            target_words = set(target.lower().split())
            predict_words = set(predict.lower().split())
            if not target_words:
                return 0.0

            overlap_ratio = len(target_words & predict_words) / len(target_words)

            longest = max(len(predict), len(target))
            length_ratio = min(len(predict), len(target)) / longest if longest > 0 else 0

            return min(overlap_ratio * 0.7 + length_ratio * 0.3, 1.0)
        except Exception:
            return 0.5
| |
|
class MetricsCalculator:
    """Aggregates per-pair evaluation results into summary metrics."""

    # Metric types restricted to results whose retrieval recall succeeded.
    _RECALL_FILTERED = ("real_tool", "recall_subset")

    def __init__(self):
        logger.info("初始化指标计算模块")

    @staticmethod
    def _mean(values: List[float]) -> float:
        """Arithmetic mean, 0.0 for an empty list."""
        return sum(values) / len(values) if values else 0.0

    def calculate_pair_metrics(self, results: List["EvaluationResult"], pair_id: int, metric_type: str) -> Dict[str, float]:
        """Tool-call metrics for the results with the given `pair_id`.

        For metric types "real_tool" and "recall_subset" only pair 2 results
        with recall == 1 are counted (other pair ids yield the empty result);
        any other metric type uses every result of the pair.

        Returns:
            {"total", "accuracy", "precision@1"}, with the same keys and zero
            values when there is nothing to aggregate.
        """
        pair_results = [r for r in results if r.pair_id == pair_id]

        if metric_type in self._RECALL_FILTERED:
            selected = [r for r in pair_results if r.recall == 1] if pair_id == 2 else []
        else:
            selected = pair_results

        if not selected:
            return {"total": 0, "accuracy": 0.0, "precision@1": 0.0}

        return {
            "total": len(selected),
            "accuracy": self._mean([r.score for r in selected]),
            "precision@1": self._mean([r.tool_name_score for r in selected])
        }

    def calculate_text_generation_metrics(self, results: List["EvaluationResult"]) -> Dict[str, float]:
        """Average score over results whose pair_type is "text_generation"."""
        text_results = [r for r in results if r.pair_type == "text_generation"]
        if not text_results:
            return {"total": 0, "answer_score": 0.0}

        return {
            "total": len(text_results),
            "answer_score": self._mean([r.score for r in text_results])
        }

    def calculate_recall_metrics(self, results: List["EvaluationResult"]) -> Dict[str, Any]:
        """Recall counters over pair 2 results for which recall was computed."""
        pair2_results = [r for r in results if r.pair_id == 2 and r.recall is not None]
        if not pair2_results:
            return {"total_pairs": 0, "recall@5_1": 0, "recall@5_0": 0, "recall_rate": 0.0}

        total_pairs = len(pair2_results)
        hits = sum(1 for r in pair2_results if r.recall == 1)

        return {
            "total_pairs": total_pairs,
            "recall@5_1": hits,
            "recall@5_0": total_pairs - hits,
            "recall_rate": hits / total_pairs
        }

    def calculate_overall_metrics(self, results: List["EvaluationResult"], metric_type: str) -> Dict[str, float]:
        """Metrics across all pairs.

        For "real_tool" and "recall_subset", pair 2 results are kept only when
        recall == 1; results of other pairs are always kept. "accuracy" and
        "precision@1" average the tool-call results, "answer_score" the
        text-generation results.
        """
        if metric_type in self._RECALL_FILTERED:
            filtered = [r for r in results if r.pair_id != 2 or r.recall == 1]
        else:
            filtered = results

        if not filtered:
            return {"total": 0, "accuracy": 0.0, "precision@1": 0.0, "answer_score": 0.0}

        tool_call_results = [r for r in filtered if r.pair_type == "tool_call"]
        text_gen_results = [r for r in filtered if r.pair_type == "text_generation"]

        return {
            "total": len(filtered),
            "accuracy": self._mean([r.score for r in tool_call_results]),
            "precision@1": self._mean([r.tool_name_score for r in tool_call_results]),
            "answer_score": self._mean([r.score for r in text_gen_results])
        }

    def update_realtime_metrics(self, metrics: "RealTimeMetrics", results: List["EvaluationResult"]) -> "RealTimeMetrics":
        """Recompute every field of `metrics` from `results` and return it."""
        metrics.total_conversations = len(set(r.conversation_id for r in results))
        metrics.total_pairs = len(results)

        metrics.pair1 = self.calculate_pair_metrics(results, 1, "current_logic")

        metrics.pair2 = self.calculate_pair_metrics(results, 2, "current_logic")
        metrics.pair2_consider_recall = self.calculate_pair_metrics(results, 2, "real_tool")

        metrics.pair3 = self.calculate_text_generation_metrics(results)
        metrics.recall_metrics = self.calculate_recall_metrics(results)
        metrics.overall_current_logic = self.calculate_overall_metrics(results, "current_logic")

        return metrics
| |
|
| | class TrainingDataEvaluator: |
| | """主评估类""" |
| | |
def __init__(self, model_type: str = "qwen3"):
    """Create the pipeline components used by the evaluator.

    `model_type` is forwarded to the LLM predictor and the text-generation
    evaluator; the remaining components take no arguments. Components are
    built in a fixed order, each logging its own initialization message.
    """
    # Splitting conversations into pairs.
    self.data_processor = DataProcessor()
    # Producing predictions for each pair.
    self.llm_predictor = LLMPredictor(model_type)
    # Scoring tool-call pairs and text-generation pairs respectively.
    self.tool_evaluator = ToolCallEvaluator()
    self.text_evaluator = TextGenerationEvaluator(model_type)
    # Retrieval-based recall and metric aggregation.
    self.retrieval_caller = RetrievalToolCaller()
    self.metrics_calculator = MetricsCalculator()
    logger.info("训练数据评估器初始化完成")
| | |
| | async def evaluate_single_pair(self, session: aiohttp.ClientSession, pair: EvaluationPair, pair_predict_by_id: Dict[int, str], pair_toolname_score_by_id: Dict[int, float]) -> EvaluationResult: |
| | """异步评估单个pair""" |
| | logger.info(f"评估 Pair {pair.pair_id} (类型: {pair.pair_type})") |
| | |
| | try: |
| | logger.debug(f"Pair {pair.pair_id} source长度: {len(pair.source)},预览: {pair.source[:400]}") |
| | logger.debug(f"Pair {pair.pair_id} target长度: {len(pair.target)},预览: {pair.target[:200]}") |
| | except Exception: |
| | pass |
| | |
| | |
| | predict = await self.llm_predictor.predict(session, pair.source, pair.pair_type) |
| | |
| | pair_predict_by_id[pair.pair_id] = predict |
| | |
| | |
| | if pair.pair_type == "tool_call": |
| | score, tool_name_score, details = self.tool_evaluator.evaluate_tool_call(pair.target, predict) |
| | |
| | pair_toolname_score_by_id[pair.pair_id] = tool_name_score |
| | |
| | |
| | recall = None |
| | recall_details = None |
| | |
| | |
| | if pair.pair_id == 2 and not DISABLE_RECALL: |
| | pair1_predict = pair_predict_by_id.get(1) |
| | pair1_toolname_score = pair_toolname_score_by_id.get(1) |
| | if pair1_predict and pair1_toolname_score == 1.0: |
| | recall, recall_details = await self.retrieval_caller.compute_recall_from_pair1_predict(session, pair1_predict, pair.target) |
| | elif pair.pair_id == 2 and DISABLE_RECALL: |
| | recall, recall_details = None, None |
| | else: |
| | |
| | score, details = await self.text_evaluator.evaluate_text_generation(session, pair.target, predict) |
| | tool_name_score = 0.0 |
| | recall = None |
| | recall_details = None |
| | |
| | result = EvaluationResult( |
| | conversation_id=pair.conversation_id, |
| | pair_id=pair.pair_id, |
| | pair_type=pair.pair_type, |
| | source=pair.source, |
| | target=pair.target, |
| | predict=predict, |
| | score=score, |
| | tool_name_score=tool_name_score, |
| | recall=recall, |
| | recall_details=recall_details, |
| | details=details |
| | ) |
| | |
| | |
| | if pair.pair_type == "tool_call": |
| | if recall is not None: |
| | |
| | retrieved_tools = recall_details.get("retrieved_tools", []) if recall_details else [] |
| | target_tool = recall_details.get("target_tool", "") if recall_details else "" |
| | logger.info(f"Pair {pair.pair_id} 评估完成,accuracy: {score:.3f}, precision@1: {tool_name_score:.3f}, recall@5: {recall}") |
| | logger.info(f"Pair {pair.pair_id} 检索详情 - 目标工具: {target_tool}, 检索到的工具: {retrieved_tools}") |
| | else: |
| | logger.info(f"Pair {pair.pair_id} 评估完成,accuracy: {score:.3f}, precision@1: {tool_name_score:.3f}") |
| | else: |
| | logger.info(f"Pair {pair.pair_id} 评估完成,answer_score: {score:.3f}") |
| | |
| | return result |
| | |
| | async def evaluate_file(self, file_path: str, checkpoint_file: str = None, start_idx: int = 0, end_idx: Optional[int] = None) -> List[EvaluationResult]: |
| | """异步并发评估整个文件,支持断点续传和实时指标更新 |
| | |
| | Args: |
| | file_path: 要评估的JSON文件路径 |
| | checkpoint_file: 断点文件路径(可选) |
| | start_idx: 开始评估的对话索引(从0开始) |
| | end_idx: 结束评估的对话索引(不包含,如果为None则评估到最后) |
| | """ |
| | logger.info(f"开始异步并发评估文件: {file_path}") |
| | logger.info(f"并发配置: 最大对话并发数={MAX_CONCURRENT_CONVERSATIONS}, 最大Pair并发数={MAX_CONCURRENT_PAIRS}, 最大API并发数={MAX_CONCURRENT_API_CALLS}") |
| | |
| | if start_idx > 0 or end_idx is not None: |
| | logger.info(f"评估范围: 对话 {start_idx} 到 {end_idx if end_idx else '最后'}") |
| | |
| | with open(file_path, 'r', encoding='utf-8') as f: |
| | data = json.load(f) |
| | |
| | |
| | total_conversations = len(data) |
| | if end_idx is None: |
| | end_idx = total_conversations |
| | else: |
| | end_idx = min(end_idx, total_conversations) |
| | |
| | |
| | if start_idx >= total_conversations: |
| | logger.error(f"起始索引 {start_idx} 超出数据范围 (总共 {total_conversations} 个对话)") |
| | return [] |
| | |
| | if start_idx >= end_idx: |
| | logger.error(f"起始索引 {start_idx} 不能大于等于结束索引 {end_idx}") |
| | return [] |
| | |
| | logger.info(f"实际评估范围: 对话 {start_idx} 到 {end_idx-1} (共 {end_idx - start_idx} 个对话)") |
| | |
| | |
| | all_results = [] |
| | processed_pairs = set() |
| | conversation_id = 1 |
| | |
| | if checkpoint_file and os.path.exists(checkpoint_file): |
| | try: |
| | with open(checkpoint_file, 'r', encoding='utf-8') as f: |
| | checkpoint_data = json.load(f) |
| | all_results = [EvaluationResult(**r) for r in checkpoint_data.get("results", [])] |
| | processed_pairs = set(tuple(p) for p in checkpoint_data.get("processed_pairs", [])) |
| | conversation_id = checkpoint_data.get("next_conversation_id", 1) |
| | start_idx = len(set(r.conversation_id for r in all_results)) |
| | logger.info(f"从断点恢复,已处理 {len(all_results)} 个评估对,conversation_id: {conversation_id}") |
| | except Exception as e: |
| | logger.error(f"读取断点文件失败: {e},将从头开始评估") |
| | all_results = [] |
| | start_idx = 0 |
| | processed_pairs = set() |
| | conversation_id = 1 |
| | |
| | |
| | realtime_metrics = RealTimeMetrics() |
| | |
| | |
| | connector = aiohttp.TCPConnector(limit=MAX_CONCURRENT_API_CALLS, limit_per_host=MAX_CONCURRENT_API_CALLS) |
| | timeout = aiohttp.ClientTimeout(total=300) |
| | |
| | async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session: |
| | |
| | conversation_semaphore = asyncio.Semaphore(MAX_CONCURRENT_CONVERSATIONS) |
| | pair_semaphore = asyncio.Semaphore(MAX_CONCURRENT_PAIRS) |
| | |
| | |
| | conversation_tasks = [] |
| | for idx, conversation_data in enumerate(data[start_idx:end_idx], start=start_idx): |
| | task = self._evaluate_conversation_async( |
| | session, conversation_semaphore, pair_semaphore, |
| | conversation_data, idx, conversation_id, processed_pairs |
| | ) |
| | conversation_tasks.append(task) |
| | conversation_id += 1 |
| | |
| | |
| | logger.info(f"开始并发评估 {len(conversation_tasks)} 个对话") |
| | conversation_results = await asyncio.gather(*conversation_tasks, return_exceptions=True) |
| | |
| | |
| | for idx, result in enumerate(conversation_results): |
| | if isinstance(result, Exception): |
| | logger.error(f"对话 {start_idx + idx} 评估失败: {result}") |
| | else: |
| | all_results.extend(result) |
| | |
| | |
| | realtime_metrics = self.metrics_calculator.update_realtime_metrics(realtime_metrics, all_results) |
| | self._save_realtime_metrics(realtime_metrics) |
| | |
| | |
| | if checkpoint_file: |
| | self._save_checkpoint(checkpoint_file, all_results, processed_pairs, start_idx + idx + 1) |
| | |
| | logger.info(f"异步并发评估完成,总共处理了 {len(all_results)} 个评估对") |
| | return all_results |
| | |
| | async def _evaluate_conversation_async(self, session: aiohttp.ClientSession, conversation_semaphore: asyncio.Semaphore, |
| | pair_semaphore: asyncio.Semaphore, conversation_data: Dict, idx: int, |
| | conversation_id: int, processed_pairs: set) -> List[EvaluationResult]: |
| | """异步评估单个对话""" |
| | async with conversation_semaphore: |
| | logger.info(f"评估对话 {idx + 1} (conversation_id: {conversation_id})") |
| | |
| | |
| | pairs = self.data_processor.parse_conversations(conversation_data, conversation_id) |
| | |
| | |
| | unprocessed_pairs = [] |
| | for pair in pairs: |
| | pair_key = (conversation_id, pair.pair_id) |
| | if pair_key not in processed_pairs: |
| | unprocessed_pairs.append(pair) |
| | else: |
| | logger.info(f"跳过已处理的 Pair {pair.pair_id}") |
| | |
| | if not unprocessed_pairs: |
| | logger.info(f"对话 {conversation_id} 的所有pairs都已处理过") |
| | return [] |
| | |
| | |
| | pair_predict_by_id = {} |
| | pair_toolname_score_by_id = {} |
| | |
| | |
| | |
| | sorted_pairs = sorted(unprocessed_pairs, key=lambda p: p.pair_id) |
| | |
| | results = [] |
| | text_gen_pairs = [] |
| | |
| | |
| | for pair in sorted_pairs: |
| | if pair.pair_type == "tool_call": |
| | |
| | result = await self._evaluate_single_pair_async( |
| | session, pair_semaphore, pair, pair_predict_by_id, pair_toolname_score_by_id |
| | ) |
| | if isinstance(result, Exception): |
| | logger.error(f"Pair {pair.pair_id} 评估失败: {result}") |
| | else: |
| | results.append(result) |
| | pair_key = (conversation_id, pair.pair_id) |
| | processed_pairs.add(pair_key) |
| | else: |
| | |
| | text_gen_pairs.append(pair) |
| | |
| | |
| | if text_gen_pairs: |
| | text_gen_tasks = [] |
| | for pair in text_gen_pairs: |
| | task = self._evaluate_single_pair_async( |
| | session, pair_semaphore, pair, pair_predict_by_id, pair_toolname_score_by_id |
| | ) |
| | text_gen_tasks.append(task) |
| | |
| | text_gen_results = await asyncio.gather(*text_gen_tasks, return_exceptions=True) |
| | |
| | |
| | for pair, result in zip(text_gen_pairs, text_gen_results): |
| | pair_key = (conversation_id, pair.pair_id) |
| | if isinstance(result, Exception): |
| | logger.error(f"Pair {pair.pair_id} 评估失败: {result}") |
| | else: |
| | results.append(result) |
| | processed_pairs.add(pair_key) |
| | |
| | return results |
| | |
| | async def _evaluate_single_pair_async(self, session: aiohttp.ClientSession, pair_semaphore: asyncio.Semaphore, |
| | pair: EvaluationPair, pair_predict_by_id: Dict[int, str], |
| | pair_toolname_score_by_id: Dict[int, float]) -> EvaluationResult: |
| | """异步评估单个pair(带信号量控制)""" |
| | async with pair_semaphore: |
| | return await self.evaluate_single_pair(session, pair, pair_predict_by_id, pair_toolname_score_by_id) |
| | |
| | def _save_checkpoint(self, checkpoint_file: str, all_results: List[EvaluationResult], |
| | processed_pairs: set, next_conversation_id: int): |
| | """保存断点文件""" |
| | try: |
| | |
| | cleaned_results = [] |
| | for r in all_results: |
| | result_dict = asdict(r) |
| | |
| | if r.pair_id in [1, 3]: |
| | result_dict.pop('recall', None) |
| | result_dict.pop('recall_details', None) |
| | cleaned_results.append(result_dict) |
| | |
| | checkpoint_data = { |
| | "results": cleaned_results, |
| | "processed_pairs": [list(p) for p in processed_pairs], |
| | "next_conversation_id": next_conversation_id |
| | } |
| | with open(checkpoint_file, 'w', encoding='utf-8') as f: |
| | json.dump(checkpoint_data, f, ensure_ascii=False) |
| | except Exception as e: |
| | logger.error(f"保存断点文件失败: {e}") |
| | |
| | def _save_realtime_metrics(self, metrics: RealTimeMetrics): |
| | """保存实时指标到文件""" |
| | try: |
| | realtime_file = "metrics/realtime_metrics.json" |
| | data = asdict(metrics) |
| | |
| | if "overall_current_logic" in data: |
| | data["overall"] = data.pop("overall_current_logic") |
| | |
| | data = _round_floats(data, 3) |
| | with open(realtime_file, 'w', encoding='utf-8') as f: |
| | json.dump(data, f, ensure_ascii=False, indent=2) |
| | except Exception as e: |
| | logger.error(f"保存实时指标失败: {e}") |
| | |
| | def generate_report(self, results: List[EvaluationResult]) -> Dict[str, Any]: |
| | """生成评估报告,按pair_id分组""" |
| | |
| | grouped_results = defaultdict(list) |
| | for result in results: |
| | grouped_results[result.pair_id].append(result) |
| | |
| | |
| | metrics_calc = MetricsCalculator() |
| | |
| | |
| | pair_metrics = {} |
| | for pair_id in [1, 2, 3]: |
| | pair_results = grouped_results.get(pair_id, []) |
| | if pair_results: |
| | if pair_id == 1: |
| | |
| | pair_metrics["pair1"] = metrics_calc.calculate_pair_metrics(pair_results, pair_id, "current_logic") |
| | elif pair_id == 2: |
| | |
| | pair_metrics["pair2"] = metrics_calc.calculate_pair_metrics(pair_results, pair_id, "current_logic") |
| | pair_metrics["pair2_consider_recall"] = metrics_calc.calculate_pair_metrics(pair_results, pair_id, "real_tool") |
| | pair_metrics["pair2_recall_subset"] = metrics_calc.calculate_pair_metrics(pair_results, pair_id, "recall_subset") |
| | else: |
| | |
| | pair_metrics["pair3"] = metrics_calc.calculate_text_generation_metrics(pair_results) |
| | |
| | |
| | recall_metrics = metrics_calc.calculate_recall_metrics(results) |
| | |
| | |
| | overall_metrics = metrics_calc.calculate_overall_metrics(results, "current_logic") |
| | |
| | |
| | report = { |
| | "summary": { |
| | "total_conversations": len(set(r.conversation_id for r in results)), |
| | "total_pairs": len(results), |
| | "pair_metrics": pair_metrics, |
| | "recall_metrics": recall_metrics, |
| | "overall_metrics": overall_metrics, |
| | "model": self.llm_predictor.model_type |
| | }, |
| | "detailed_results": { |
| | f"pair{pair_id}": [ |
| | { |
| | "conversation_id": r.conversation_id, |
| | "pair_id": r.pair_id, |
| | "pair_type": r.pair_type, |
| | "score": r.score, |
| | "tool_name_score": r.tool_name_score if r.pair_type == "tool_call" else None, |
| | **({"recall": r.recall, "recall_details": r.recall_details} if pair_id == 2 and r.recall is not None else {}), |
| | "source": r.source, |
| | "target": r.target, |
| | "predict": r.predict, |
| | "target_preview": r.target[:100] + "..." if len(r.target) > 100 else r.target, |
| | "predict_preview": r.predict[:100] + "..." if len(r.predict) > 100 else r.predict, |
| | "details": r.details |
| | } |
| | for r in pair_results |
| | ] |
| | for pair_id, pair_results in grouped_results.items() |
| | } |
| | } |
| | |
| | return report |
| |
|
def parse_args(argv: Optional[List[str]] = None):
    """Parse the command-line arguments of the evaluation script.

    Help texts interpolate the real default through argparse's
    ``%(default)s`` placeholder, so they cannot drift from the values below.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="训练数据评估脚本")
    parser.add_argument("--input_file", "-i", type=str,
                        default="/home/ziqiang/LLaMA-Factory/data/dataset/10_22/10.22_fuzzy_data.json",
                        help="输入JSON文件路径 (默认: %(default)s)")
    parser.add_argument("--output_file", "-o", type=str,
                        default="/home/ziqiang/LLaMA-Factory/data/dataset/10_22/data_evaluation.json",
                        help="输出结果文件路径 (默认: %(default)s)")
    parser.add_argument("--checkpoint_file", "-c", type=str,
                        default="/home/ziqiang/LLaMA-Factory/data/dataset/10_22/evaluation_checkpoint.json",
                        help="断点文件路径 (默认: %(default)s)")
    parser.add_argument("--start_idx", "-s", type=int, default=0,
                        help="开始评估的对话索引(从0开始,默认: %(default)s)")
    parser.add_argument("--end_idx", "-e", type=int, default=2000,
                        help="结束评估的对话索引(不包含,默认: %(default)s)")
    parser.add_argument("--log_file", "-l", type=str,
                        default="/home/ziqiang/LLaMA-Factory/data/dataset/10_22/data_evaluation.log",
                        help="日志文件路径 (默认: %(default)s)")
    parser.add_argument("--models", type=str, default="",
                        help="以逗号分隔的一组模型名(例如: /data/models/Qwen3-8B,my_lora)。提供多个时开启多模型评估模式")
    parser.add_argument("--multi_output_dir", type=str, default="evaluation/multi",
                        help="多模型评估输出目录(默认: %(default)s)")
    parser.add_argument("--aggregate_output", type=str, default="evaluation/multi_aggregate_0929_v2.json",
                        help="多模型聚合报告输出文件(默认: %(default)s)")

    # Concurrency limits; main() copies these into the module-level settings.
    parser.add_argument("--max_concurrent_conversations", type=int, default=1,
                        help="最大并发对话数(默认: %(default)s)")
    parser.add_argument("--max_concurrent_pairs", type=int, default=1,
                        help="最大并发pair数(默认: %(default)s)")
    parser.add_argument("--max_concurrent_api_calls", type=int, default=1,
                        help="最大并发API调用数(默认: %(default)s)")

    return parser.parse_args(argv)
| |
|
def _write_json_report(path: str, payload: Any) -> None:
    """Write ``payload`` as pretty JSON (floats rounded to 3 digits) to ``path``."""
    parent_dir = os.path.dirname(path)
    if parent_dir:
        # The target directory may not exist for user-supplied paths.
        os.makedirs(parent_dir, exist_ok=True)
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(_round_floats(payload, 3), f, ensure_ascii=False, indent=2)


def _remove_checkpoint_file(path: str) -> None:
    """Delete the checkpoint at ``path`` if it exists; log instead of raising."""
    if not os.path.exists(path):
        return
    try:
        os.remove(path)
        logger.info(f"已删除断点文件: {path}")
    except Exception as e:
        logger.error(f"删除断点文件失败: {e}")


async def main():
    """Script entry point: run a single-model or a multi-model evaluation.

    When ``--models`` lists more than one model, every model is evaluated in
    turn with its own result, checkpoint and log file inside
    ``--multi_output_dir`` and an aggregate comparison report is written.
    Otherwise a single evaluation is run with the paths from the arguments.
    """
    global MAX_CONCURRENT_CONVERSATIONS, MAX_CONCURRENT_PAIRS, MAX_CONCURRENT_API_CALLS

    import gc

    args = parse_args()

    # Command-line values replace the module-level concurrency settings.
    MAX_CONCURRENT_CONVERSATIONS = args.max_concurrent_conversations
    MAX_CONCURRENT_PAIRS = args.max_concurrent_pairs
    MAX_CONCURRENT_API_CALLS = args.max_concurrent_api_calls

    logger.add(args.log_file, rotation="100 MB", level="DEBUG")

    os.makedirs("metrics", exist_ok=True)

    logger.info("开始增强版异步并发训练数据评估")
    logger.info(f"输入文件: {args.input_file}")
    logger.info(f"输出文件: {args.output_file}")
    logger.info(f"断点文件: {args.checkpoint_file}")
    logger.info(f"评估范围: 对话 {args.start_idx} 到 {args.end_idx if args.end_idx is not None else '最后'}")
    logger.info(f"并发配置: 对话={MAX_CONCURRENT_CONVERSATIONS}, Pairs={MAX_CONCURRENT_PAIRS}, API={MAX_CONCURRENT_API_CALLS}")

    models_list = [m.strip() for m in (args.models or "").split(',') if m.strip()]

    if len(models_list) > 1:
        os.makedirs(args.multi_output_dir, exist_ok=True)
        aggregate = {
            "input_file": args.input_file,
            "models": models_list,
            "runs": {}
        }
        for model_name in models_list:
            # Make the model name safe to embed in file names.
            model_safe = re.sub(r"[^A-Za-z0-9_.-]", "_", model_name)
            output_file = os.path.join(args.multi_output_dir, f"result_{model_safe}.json")
            checkpoint_file = os.path.join(args.multi_output_dir, f"checkpoint_{model_safe}.json")
            log_file = os.path.join(args.multi_output_dir, f"eval_{model_safe}.log")

            sink_id = None
            try:
                sink_id = logger.add(log_file, rotation="100 MB", level="DEBUG")
            except Exception as e:
                logger.warning(f"无法创建模型日志文件 {log_file}: {e}")

            try:
                evaluator = TrainingDataEvaluator(model_type=model_name)
                results = await evaluator.evaluate_file(
                    args.input_file,
                    checkpoint_file,
                    args.start_idx,
                    args.end_idx
                )
                report = evaluator.generate_report(results)
                _write_json_report(output_file, report)
                aggregate["runs"][model_name] = {
                    "output_file": output_file,
                    "summary": report.get("summary", {}),
                }
                _remove_checkpoint_file(checkpoint_file)
            finally:
                # Detach this model's sink so the next model's messages do
                # not end up in this model's log file.
                if sink_id is not None:
                    logger.remove(sink_id)

        comparison = {}
        for model_name, run in aggregate["runs"].items():
            summary = run.get("summary", {})
            per_pair = summary.get("pair_metrics", {})
            comparison[model_name] = {
                "overall_metrics": summary.get("overall_metrics", {}),
                "pair1": per_pair.get("pair1", {}),
                "pair2": per_pair.get("pair2", {}),
                "pair3": per_pair.get("pair3", {}),
            }
        aggregate["comparison"] = _round_floats(comparison, 3)

        _write_json_report(args.aggregate_output, aggregate)
        logger.info(f"多模型评估完成,聚合报告: {args.aggregate_output}")

        gc.collect()
        return

    # Single-model mode: fall back to the configured model when none is given.
    evaluator = TrainingDataEvaluator(
        model_type=QWEN_MODEL_NAME if not models_list else models_list[0]
    )
    results = await evaluator.evaluate_file(
        args.input_file,
        args.checkpoint_file,
        args.start_idx,
        args.end_idx
    )
    report = evaluator.generate_report(results)
    _write_json_report(args.output_file, report)
    logger.info(f"评估完成,结果已保存到: {args.output_file}")
    _remove_checkpoint_file(args.checkpoint_file)

    gc.collect()


if __name__ == "__main__":
    asyncio.run(main())