""" Agent 核心模块 - GAIA LangGraph ReAct Agent 包含:AgentState, System Prompt, Graph 构建, 答案提取 """ import re from typing import Sequence, Literal, Annotated, Optional from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, AIMessage from langchain_openai import ChatOpenAI from langgraph.graph import StateGraph, END from langgraph.graph.message import add_messages from langgraph.prebuilt import ToolNode try: from typing import TypedDict except ImportError: from typing_extensions import TypedDict from config import ( OPENAI_BASE_URL, OPENAI_API_KEY, MODEL, TEMPERATURE, MAX_ITERATIONS, DEBUG, LLM_TIMEOUT, RATE_LIMIT_RETRY_MAX, RATE_LIMIT_RETRY_BASE_DELAY, ) # 导入工具 from tools import BASE_TOOLS # 尝试导入扩展工具 try: from extension_tools import EXTENSION_TOOLS ALL_TOOLS = BASE_TOOLS + EXTENSION_TOOLS except ImportError as e: print(f"⚠️ 扩展工具加载失败: {e}") print(" 提示: 请确保安装了 pandas 和 openpyxl (pip install pandas openpyxl)") EXTENSION_TOOLS = [] ALL_TOOLS = BASE_TOOLS # 尝试导入 RAG 工具 try: from rag import RAG_TOOLS ALL_TOOLS = ALL_TOOLS + RAG_TOOLS except ImportError: RAG_TOOLS = [] # RAG 短路辅助(可选导入,不影响工具加载) try: from rag import rag_lookup_answer except ImportError: rag_lookup_answer = None # 打印已加载的工具列表(调试用) _tool_names = [t.name for t in ALL_TOOLS] if DEBUG: print(f"✓ 已加载 {len(ALL_TOOLS)} 个工具: {_tool_names}") if 'parse_excel' not in _tool_names: print("⚠️ 警告: parse_excel 工具未加载,Excel 文件处理将不可用!") # ======================================== # System Prompt 设计 # ======================================== SYSTEM_PROMPT = """你是一个专业的问答助手,专门解答GAIA基准测试中的各类问题。你需要准确、简洁地回答问题。 ## 你的能力 你可以使用以下工具来获取信息和处理任务: ### 知识库工具(RAG) - `rag_query(question)`: 查询知识库中的相似问题,获取解题策略建议。返回推荐的工具和解题步骤。**遇到复杂问题时优先使用!** - `rag_retrieve(question)`: 仅检索相似问题,不生成建议。返回原始的相似问题和解法。 - `rag_stats()`: 查看知识库状态(文档数量等)。 ### 信息获取工具 - `web_search(query)`: 使用DuckDuckGo搜索网络信息。适用于查找人物、事件、地点、组织等外部知识。 - `wikipedia_search(query)`: 在维基百科中搜索,返回简短摘要(3句话)。适用于快速确认人物/事件的基本信息。 - `wikipedia_page(title, section)`: 获取维基百科页面的完整内容。**需要详细数据(如专辑列表、获奖记录、作品年表)时必须用此工具!** - `tavily_search(query)`: 使用Tavily进行高质量网络搜索,返回最多3条结果。需要API Key。 - `arxiv_search(query)`: 在arXiv上搜索学术论文,返回最多3条结果。适用于查找科学研究和学术文献。 ### 文件处理工具 - `fetch_task_files(task_id)`: 从评分服务器下载任务附件。当问题涉及附件时必须先调用此工具。 - `read_file(file_path)`: 读取本地文件内容,支持txt/csv/json/zip等格式。**注意:不支持Excel和PDF!** - `parse_pdf(file_path)`: 解析PDF文件,提取文本内容。**PDF文件必须用此工具!** - `parse_excel(file_path)`: 解析Excel文件(.xlsx/.xls),返回表格内容。**Excel文件必须用此工具!** - `image_ocr(file_path)`: 对图片进行OCR文字识别。 - `transcribe_audio(file_path)`: 将音频文件转写为文字。 - `analyze_image(file_path, question)`: 使用AI分析图片内容。 ### 计算和代码工具 - `calc(expression)`: 执行安全的数学计算,如 "2+3*4" 或 "sqrt(16)"。适用于简单算术。 - `run_python(code)`: 在沙箱中执行Python代码。支持 import math/re/json/datetime/collections/random/string/itertools/functools 模块。适用于复杂数据处理、排序、过滤、日期计算等操作。 ## 工具使用策略 ### 优先级顺序 0. **先查知识库**【最高优先级】: - 首先调用 `rag_query(question)` 查询知识库 - 如果返回"知识库匹配成功",**直接使用该答案作为最终回答**,不需要再调用其他工具 - 如果返回"知识库参考",参考答案和步骤选择后续工具 - 如果无匹配,按后续优先级使用其他工具 1. **有附件的问题**【重要】: - 第一步:用 `fetch_task_files(task_id)` 下载文件 - 第二步:根据文件扩展名选择正确的读取工具: * `.xlsx` / `.xls` → 必须用 `parse_excel(file_path)` * `.pdf` → 必须用 `parse_pdf(file_path)` * `.txt` / `.csv` / `.json` / `.md` → 用 `read_file(file_path)` * `.png` / `.jpg` / `.jpeg` → 用 `image_ocr(file_path)` 或 `analyze_image(file_path, question)` * `.mp3` / `.wav` → 用 `transcribe_audio(file_path)` - 第三步:分析文件内容,进行必要的计算或处理 - **禁止**:下载文件后不要用 web_search 搜索,文件内容已经本地可用! 2. **需要外部信息**: - **百科知识查询流程**【重要】: * 第一步:用 `wikipedia_search(query)` 确认页面标题 * 第二步:如果需要详细数据(专辑列表、作品年表、获奖记录等),必须用 `wikipedia_page(title, section)` 获取完整内容 * 示例:查 Mercedes Sosa 专辑数 → `wikipedia_search("Mercedes Sosa")` → `wikipedia_page("Mercedes Sosa", "Discography")` - 通用搜索: 使用 `web_search` 搜索其他网络信息 - 学术论文: 使用 `arxiv_search` 查找研究文献 - 高质量结果: 使用 `tavily_search` (如果配置了API Key) 3. **需要计算**: 简单算术用 `calc`,复杂处理用 `run_python` 4. **数据处理**: 使用 `run_python` 进行排序、过滤、统计等操作 ### 工具使用原则 - **只有问题明确提到"attached file"或"附件"时才调用 `fetch_task_files`**,否则不要调用 - 每次只调用一个必要的工具,分析结果后再决定下一步 - 如果工具返回错误,尝试调整参数或换用其他工具 - 搜索时使用精确的关键词,避免过于宽泛 - 读取大文件时注意内容可能被截断,关注关键信息 - **如果 `wikipedia_search` 返回的摘要不足以回答问题,立即使用 `wikipedia_page` 获取完整内容** ## 思考过程 在回答问题前,请按以下步骤思考: 1. **理解问题**: 问题在问什么?需要什么类型的信息? 2. **咨询知识库**: 如果问题复杂或不确定解法,用 `rag_query` 查看相似问题的解题策略 3. **判断工具**: 根据问题类型和 RAG 建议,选择合适的工具 4. **执行获取**: 调用工具获取信息 5. **分析整合**: 分析工具返回的信息,提取关键答案 6. **格式化输出**: 按要求格式输出最终答案 ## 答案格式要求【非常重要】 最终答案必须遵循以下格式: - **数字答案**: 直接输出数字,如 `42` 而不是 "答案是42" - **人名/地名**: 直接输出名称,如 `Albert Einstein` 而不是 "答案是Albert Einstein" - **日期答案**: 使用标准格式 `YYYY-MM-DD` 或按问题要求的格式 - **列表答案**: 用逗号分隔,如 `A, B, C` - **是/否答案**: 输出 `Yes` 或 `No` ⚠️ 最终回答时,只输出答案本身,不要包含: - 不要说"答案是..."、"The answer is..." - 不要添加解释或推理过程 - 不要使用"最终答案:"等前缀 ## 错误恢复 如果遇到问题: - 工具调用失败: 检查参数,尝试简化或换用其他工具 - 搜索无结果: 尝试不同的关键词组合 - 文件读取失败: 确认文件路径正确,检查文件格式 - 计算错误: 检查表达式语法,考虑使用Python代码 ## 示例 问题: "Who was the first person to walk on the moon?" 正确答案: Neil Armstrong 错误答案: The answer is Neil Armstrong. 问题: "What is 15% of 200?" 正确答案: 30 错误答案: 15% of 200 is 30. ### 文件处理示例【重要】 问题: "[Task ID: abc123] The attached Excel file contains sales data. What is the total revenue?" ✅ 正确流程: 1. fetch_task_files("abc123") → 下载文件到本地路径 2. parse_excel("/path/to/file.xlsx") → 读取Excel内容,得到表格数据 3. calc("100+200+300") 或 run_python("...") → 计算总收入 4. 输出最终答案 ❌ 错误流程: 1. fetch_task_files("abc123") → 下载文件 2. web_search("sales data total revenue") → 错!文件内容在本地,不需要搜索网络! ### RAG 辅助示例 问题: "How many studio albums did Mercedes Sosa release between 2000 and 2009?" ✅ 推荐流程: 1. rag_query("How many studio albums did Mercedes Sosa release between 2000 and 2009?") → 获取建议:使用 wikipedia_page 查 Discography 2. wikipedia_search("Mercedes Sosa") → 确认页面存在 3. wikipedia_page("Mercedes Sosa", "Discography") → 获取完整专辑列表 4. run_python("...") → 筛选 2000-2009 年的专辑并计数 5. 输出最终答案 RAG 的价值:直接告诉你该用 wikipedia_page 而不是 web_search,节省试错时间。 现在请回答用户的问题。""" # ======================================== # Agent State 定义 # ======================================== class AgentState(TypedDict): """Agent 状态定义""" # 核心字段 messages: Annotated[Sequence[BaseMessage], add_messages] # 消息历史 # 迭代控制 iteration_count: int # 当前迭代次数,防止无限循环 # ======================================== # LLM 初始化 # ======================================== # 全局 LLM 实例(避免每次迭代重复创建) _llm_instance = None _llm_with_tools = None def get_llm(): """获取 LLM 单例""" global _llm_instance if _llm_instance is None: _llm_instance = ChatOpenAI( model=MODEL, temperature=TEMPERATURE, base_url=OPENAI_BASE_URL, api_key=OPENAI_API_KEY, timeout=LLM_TIMEOUT, max_retries=2, ) return _llm_instance def get_llm_with_tools(): """获取绑定工具的 LLM 单例""" global _llm_with_tools if _llm_with_tools is None: _llm_with_tools = get_llm().bind_tools(ALL_TOOLS) return _llm_with_tools def invoke_llm_with_retry(llm, messages, max_retries=None, base_delay=None): """ 带重试逻辑的 LLM 调用(处理 429 速率限制错误) Args: llm: LLM 实例 messages: 消息列表 max_retries: 最大重试次数,默认使用配置值 base_delay: 基础延迟秒数,默认使用配置值 Returns: LLM 响应 Raises: 原始异常(如果重试耗尽) """ import time from openai import RateLimitError if max_retries is None: max_retries = RATE_LIMIT_RETRY_MAX if base_delay is None: base_delay = RATE_LIMIT_RETRY_BASE_DELAY last_error = None for attempt in range(max_retries + 1): try: return llm.invoke(messages) except RateLimitError as e: last_error = e if attempt < max_retries: # 指数退避:base_delay * 2^attempt delay = base_delay * (2 ** attempt) print(f"[Rate Limit] 429 错误,第 {attempt + 1}/{max_retries + 1} 次尝试,等待 {delay:.1f} 秒后重试...") time.sleep(delay) else: print(f"[Rate Limit] 重试次数已耗尽 ({max_retries + 1} 次),抛出异常") raise except Exception as e: # 其他错误直接抛出 raise # 不应该到这里,但以防万一 if last_error: raise last_error def create_llm(): """创建 LLM 实例(保留兼容性)""" return get_llm() # ======================================== # Graph 节点定义 # ======================================== def assistant(state: AgentState) -> dict: """ LLM 推理节点 职责: 1. 接收当前状态 2. 构建完整消息(包含 System Prompt) 3. 调用 LLM 生成响应 4. 更新迭代计数 """ messages = state["messages"] iteration = state.get("iteration_count", 0) + 1 # 构建完整消息列表 full_messages = [SystemMessage(content=SYSTEM_PROMPT)] + list(messages) # 接近迭代上限时添加强制结束警告 if iteration >= MAX_ITERATIONS - 1: print(f"[Iteration {iteration}] FORCING FINAL ANSWER (no tools)") warning = f""" ⚠️ 【最后机会】已进行 {iteration} 次迭代,达到上限 {MAX_ITERATIONS}。 你必须立即给出最终答案!不要再调用任何工具! 直接根据已有信息输出答案。如果信息不足,给出最佳估计。 """ full_messages.append(SystemMessage(content=warning)) # 不绑定工具,强制 LLM 只输出文本 llm = get_llm() try: response = invoke_llm_with_retry(llm, full_messages) except Exception as e: print(f"[ERROR] LLM 调用失败: {type(e).__name__}: {str(e)}") raise elif iteration >= MAX_ITERATIONS - 2: warning = f"\n\n⚠️ 警告:已进行 {iteration} 次迭代,接近上限 {MAX_ITERATIONS},请尽快给出最终答案,不要再搜索。" full_messages.append(SystemMessage(content=warning)) # 使用单例 LLM(避免重复创建) llm_with_tools = get_llm_with_tools() try: response = invoke_llm_with_retry(llm_with_tools, full_messages) except Exception as e: print(f"[ERROR] LLM 调用失败: {type(e).__name__}: {str(e)}") raise else: # 使用单例 LLM(避免重复创建) llm_with_tools = get_llm_with_tools() try: response = invoke_llm_with_retry(llm_with_tools, full_messages) except Exception as e: print(f"[ERROR] LLM 调用失败: {type(e).__name__}: {str(e)}") raise # 始终打印迭代信息(便于调试) print(f"[Iteration {iteration}] LLM Response: {response.content[:200] if response.content else '(empty)'}...") if hasattr(response, 'tool_calls') and response.tool_calls: print(f"[Iteration {iteration}] Tool calls: {[tc['name'] for tc in response.tool_calls]}") return { "messages": [response], "iteration_count": iteration } def should_continue(state: AgentState) -> Literal["tools", "end"]: """ 路由判断:决定继续使用工具还是结束 判断逻辑: 1. 达到迭代上限 → 强制结束 2. 有工具调用 → 继续执行工具 3. 无工具调用 → 返回答案,结束 """ last_message = state["messages"][-1] iteration = state.get("iteration_count", 0) # 达到迭代上限,强制结束 if iteration >= MAX_ITERATIONS: print(f"[Router] Reached max iterations ({MAX_ITERATIONS}), forcing end") return "end" # 检查是否有工具调用 if hasattr(last_message, "tool_calls") and last_message.tool_calls: print(f"[Router] Has tool calls, continuing to tools") return "tools" # 无工具调用,返回答案 print(f"[Router] No tool calls, ending") return "end" # ======================================== # Graph 构建 # ======================================== def build_agent_graph(): """ 构建 Agent Graph 流程: START → assistant → [should_continue] → tools → assistant → ... → END """ graph = StateGraph(AgentState) # 添加节点 graph.add_node("assistant", assistant) graph.add_node("tools", ToolNode(ALL_TOOLS)) # 设置入口点 graph.set_entry_point("assistant") # 添加条件边 graph.add_conditional_edges( "assistant", should_continue, {"tools": "tools", "end": END} ) # 工具执行后返回 assistant graph.add_edge("tools", "assistant") return graph.compile() # ======================================== # 答案提取 # ======================================== def extract_final_answer(result: dict) -> str: """ 从 Agent 结果中提取最终答案 处理步骤: 1. 获取最后一条消息 2. 移除常见前缀 3. 移除尾部解释 4. 提取 JSON 格式答案 5. 清理格式 """ messages = result.get("messages", []) if not messages: print("[extract_final_answer] No messages in result") return "无法获取答案" # 优先选择"无 tool_calls 的 AIMessage" content = None # 第一优先:无 tool_calls 的 AIMessage(真正的最终答案) for msg in reversed(messages): if isinstance(msg, AIMessage) and msg.content and str(msg.content).strip(): if not (hasattr(msg, "tool_calls") and msg.tool_calls): content = msg.content break # 第二优先:有 tool_calls 的 AIMessage if content is None: for msg in reversed(messages): if isinstance(msg, AIMessage) and msg.content and str(msg.content).strip(): content = msg.content break # 第三优先:任何有内容的消息(可能是 ToolMessage) if content is None: for msg in reversed(messages): if hasattr(msg, "content") and msg.content and str(msg.content).strip(): content = msg.content break print(f"[extract_final_answer] Raw content: {content[:500] if content else '(empty)'}...") if not content: print("[extract_final_answer] Empty content in all messages") return "无法获取答案" answer = content.strip() # Step 1: 移除常见前缀 prefix_patterns = [ # 英文前缀 r'^(?:the\s+)?(?:final\s+)?answer\s*(?:is|:)\s*', r'^(?:the\s+)?result\s*(?:is|:)\s*', r'^(?:therefore|thus|so|hence)[,:]?\s*', r'^based\s+on\s+(?:the|my)\s+(?:analysis|research|calculations?)[,:]?\s*', r'^after\s+(?:analyzing|reviewing|checking)[^,]*[,:]?\s*', r'^according\s+to\s+[^,]*[,:]?\s*', # 中文前缀 r'^(?:最终)?答案[是为::]\s*', r'^(?:结果|结论)[是为::]\s*', r'^(?:因此|所以|综上)[,,::]?\s*', r'^根据(?:以上)?(?:分析|信息|计算)[,,::]?\s*', r'^经过(?:分析|计算|查询)[,,::]?\s*', ] for pattern in prefix_patterns: answer = re.sub(pattern, '', answer, flags=re.IGNORECASE) # Step 2: 移除尾部解释 suffix_patterns = [ r'\s*(?:This|That|The|It)\s+(?:is|was|represents|refers\s+to).*$', r'\s*[(\(].*[)\)]$', r'\s*[。\.]$', r'\s*\n\n.*$', # 移除额外段落 ] for pattern in suffix_patterns: answer = re.sub(pattern, '', answer, flags=re.IGNORECASE | re.DOTALL) # Step 3: 提取 JSON 格式答案 json_patterns = [ r'\{["\']?(?:final_?)?answer["\']?\s*:\s*["\']?([^"\'}\n]+)["\']?\}', r'"answer"\s*:\s*"([^"]+)"', ] for pattern in json_patterns: json_match = re.search(pattern, answer, re.IGNORECASE) if json_match: answer = json_match.group(1) break # Step 4: 清理 answer = answer.strip() answer = re.sub(r'\s+', ' ', answer) # 合并空白 answer = answer.strip('"\'') # 移除引号 # Step 5: 数字格式处理 if re.match(r'^[\d,\.]+$', answer): answer = answer.replace(',', '') return answer def post_process_answer(answer: str, expected_type: str = None) -> str: """ 根据预期类型后处理答案 Args: answer: 原始答案 expected_type: 预期类型 (number, date, boolean, list) Returns: 处理后的答案 """ if expected_type == "number": match = re.search(r'-?\d+\.?\d*', answer.replace(',', '')) if match: return match.group() elif expected_type == "date": # 尝试标准化日期格式 date_patterns = [ (r'(\d{4})-(\d{1,2})-(\d{1,2})', lambda m: f"{m.group(1)}-{int(m.group(2)):02d}-{int(m.group(3)):02d}"), (r'(\d{1,2})/(\d{1,2})/(\d{4})', lambda m: f"{m.group(3)}-{int(m.group(1)):02d}-{int(m.group(2)):02d}"), ] for pattern, formatter in date_patterns: match = re.search(pattern, answer) if match: return formatter(match) elif expected_type == "boolean": lower = answer.lower().strip() if lower in ['yes', 'true', '是', '对', 'correct']: return "Yes" elif lower in ['no', 'false', '否', '不', '错', 'incorrect']: return "No" elif expected_type == "list": answer = re.sub(r'\s*[;;、]\s*', ', ', answer) return answer # ======================================== # GaiaAgent 入口类 # ======================================== class GaiaAgent: """ GAIA Agent 入口类 使用方法: agent = GaiaAgent() answer = agent("Who founded Microsoft?") """ def __init__(self): """初始化 Agent""" self.graph = build_agent_graph() def _needs_reformatting(self, answer: str) -> bool: """检查答案是否需要重新格式化""" if not answer or answer == "无法获取答案": return False indicators = [ answer.startswith('http'), 'URL:' in answer, len(answer) > 300, answer.count('\n') > 3, answer.startswith('1.') and '2.' in answer, answer.startswith('- '), '...' in answer and len(answer) > 100, ] return any(indicators) def _force_format_answer(self, result: dict) -> str: """强制格式化答案""" messages = result.get("messages", []) format_prompt = ( "根据上述对话收集的信息,输出最终答案。\n\n" "【强制要求】只输出答案本身,不要解释、不要前缀。\n" "- 数字:直接输出(如 42)\n" "- 人名/地名:直接输出(如 Albert Einstein)\n" "- 日期:YYYY-MM-DD\n" "- 是/否:Yes 或 No\n\n" "最终答案:" ) full_messages = [SystemMessage(content=SYSTEM_PROMPT)] + list(messages) full_messages.append(HumanMessage(content=format_prompt)) llm = get_llm() try: print("[Reformat] Forcing answer formatting...") response = invoke_llm_with_retry(llm, full_messages) formatted = extract_final_answer({"messages": [response]}) print(f"[Reformat] Result: {formatted[:100]}...") return formatted except Exception as e: print(f"[Reformat] Error: {e}") return "无法获取答案" def __call__(self, question: str, task_id: str = None) -> str: """ 执行问答 Args: question: 用户问题 task_id: 任务 ID(可选,用于下载附件) Returns: 最终答案 """ # 如果有 task_id,注入到问题中 if task_id: question_with_id = f"[Task ID: {task_id}]\n\n{question}" else: question_with_id = question # ===== RAG 前置短路:高置信度匹配直接返回 ===== try: if rag_lookup_answer is not None: hit = rag_lookup_answer(question, min_similarity=0.85) if hit and hit.get("answer"): print(f"[GaiaAgent] RAG short-circuit hit: similarity={hit.get('similarity', 0):.2f}") if DEBUG: print(f"[Final Answer] {hit['answer']}") return str(hit["answer"]).strip() except Exception as e: if DEBUG: print(f"[GaiaAgent] RAG short-circuit failed: {type(e).__name__}: {e}") # ===== RAG 短路检查结束 ===== # 初始状态 initial_state = { "messages": [HumanMessage(content=question_with_id)], "iteration_count": 0 } try: # 执行 Agent result = self.graph.invoke(initial_state) # 提取答案 answer = extract_final_answer(result) # 检查答案是否需要格式化 if self._needs_reformatting(answer): print(f"[GaiaAgent] Answer needs reformatting: {answer[:50]}...") answer = self._force_format_answer(result) if DEBUG: print(f"[Final Answer] {answer}") return answer if answer else "无法获取答案" except Exception as e: import traceback error_msg = f"Agent 执行出错: {type(e).__name__}: {str(e)}" print(f"[ERROR] {error_msg}") print(traceback.format_exc()) return error_msg def run_with_history(self, messages: list) -> dict: """ 带历史消息执行 Args: messages: 消息历史列表 Returns: 完整结果字典 """ initial_state = { "messages": messages, "iteration_count": 0 } return self.graph.invoke(initial_state) # ======================================== # 便捷函数 # ======================================== def run_agent(question: str, task_id: str = None) -> str: """ 运行 Agent 的便捷函数 Args: question: 用户问题 task_id: 任务 ID(可选) Returns: 最终答案 """ agent = GaiaAgent() return agent(question, task_id) # ======================================== # 测试 # ======================================== if __name__ == "__main__": # 简单测试 agent = GaiaAgent() # 测试计算 print("Test 1: Calculation") answer = agent("What is 15% of 200?") print(f"Answer: {answer}\n") # 测试搜索 print("Test 2: Search") answer = agent("Who founded Microsoft?") print(f"Answer: {answer}\n")