# LangGraph_GAIA / agent.py
# BiGuan's picture
# Update agent.py
# 9f491dc verified
# Raw
# History Blame Contribute Delete
# 11 kB
import os
import re
import requests
from typing import TypedDict, Annotated, Literal
import operator
import traceback
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
from langgraph.graph import StateGraph, START, END
from langgraph.prebuilt import ToolNode
from langchain_core.tools import tool
from tools.wikipedia_search import wikipedia_search
from tools.visit_webpage import visit_webpage
from tools.read_file import read_file
from tools.python_repl import python_repl
from tools.youtube_transcript import youtube_transcript
from tools.image_caption import image_caption
from tools.audio_transcribe import audio_transcribe
from tools.read_excel import read_excel_sales
from tools.python_executor import python_executor
# Failure messages collected during one agent() run: log_failure() appends to
# this list and agent() clears it at the start of every call.
failure_logs = []
def log_failure(reason: str, details: str = ""):
    """Print a failure notice and record it in the module-level ``failure_logs``.

    Args:
        reason: Short label describing what went wrong.
        details: Optional extra text; when non-empty it is appended on a
            second line of the message.
    """
    detail_suffix = f"\n Details: {details}" if details else ""
    entry = f"❌ FAILURE: {reason}" + detail_suffix
    print(entry)
    failure_logs.append(entry)
# ==================== Tool wrappers ====================
@tool(description="Search Wikipedia and return full page content (up to 8000 chars).")
def wikipedia_search_tool(query: str) -> str:
    # LangChain tool adapter: passes the query unchanged to the project's
    # wikipedia_search helper and returns whatever text it produces.
    page_text = wikipedia_search(query)
    return page_text
@tool(description="Fetch a webpage and return its textual content.")
def visit_webpage_tool(url: str) -> str:
    # LangChain tool adapter around the project's visit_webpage helper.
    page_content = visit_webpage(url)
    return page_content
@tool(description="Read a text/CSV file from a URL.")
def read_file_tool(path: str) -> str:
    # LangChain tool adapter: `path` is forwarded as-is to the project's
    # read_file helper.
    file_text = read_file(path)
    return file_text
@tool(description="Execute Python code and return printed output.")
def python_repl_tool(code: str) -> str:
    # LangChain tool adapter around the project's python_repl helper.
    # NOTE(review): `code` is produced by the model; how python_repl isolates
    # its execution is not visible from this file.
    printed_output = python_repl(code)
    return printed_output
@tool(description="Get transcript or description from YouTube video.")
def youtube_transcript_tool(url: str) -> str:
    # LangChain tool adapter around the project's youtube_transcript helper.
    transcript_text = youtube_transcript(url)
    return transcript_text
@tool(description="Describe an image from a URL using AI.")
def image_caption_tool(url: str) -> str:
    # LangChain tool adapter around the project's image_caption helper.
    caption_text = image_caption(url)
    return caption_text
@tool(description="Transcribe an audio file (MP3) from a URL.")
def audio_transcribe_tool(url: str) -> str:
    # LangChain tool adapter around the project's audio_transcribe helper.
    transcription = audio_transcribe(url)
    return transcription
@tool(description="Read Excel file and compute total sales (food only).")
def read_excel_sales_tool(url: str) -> str:
    # LangChain tool adapter around the project's read_excel_sales helper.
    sales_summary = read_excel_sales(url)
    return sales_summary
@tool(description="Execute Python code from a file URL and return output.")
def python_executor_tool(url: str) -> str:
    # LangChain tool adapter around the project's python_executor helper.
    # NOTE(review): this runs code fetched from `url`; the helper's isolation
    # is not visible from this file.
    run_output = python_executor(url)
    return run_output
# Every tool the model may call. The same list is bound to the LLM
# (bind_tools) and handed to the ToolNode that executes the calls.
tools = [
    wikipedia_search_tool,
    visit_webpage_tool,
    read_file_tool,
    python_repl_tool,
    youtube_transcript_tool,
    image_caption_tool,
    audio_transcribe_tool,
    read_excel_sales_tool,
    python_executor_tool,
]
# ==================== LLM configuration ====================
# Endpoint and credentials are read from the environment; an unset variable
# yields None, which is passed straight through to ChatOpenAI.
LLM_BASE_URL = os.environ.get("AGICTO_BASE_URL")
LLM_API_KEY = os.environ.get("AGICTO_API_KEY")
LLM_MODEL_ID = "qwen3.5-35b-a3b"

# Deterministic sampling (temperature 0) with replies capped at 512 tokens.
llm = ChatOpenAI(
    temperature=0.0,
    max_tokens=512,
    model=LLM_MODEL_ID,
    base_url=LLM_BASE_URL,
    api_key=LLM_API_KEY,
)

# Variant of the model that is allowed to emit calls to the tools above.
llm_with_tools = llm.bind_tools(tools)
# ==================== State ====================
class AgentState(TypedDict):
    # Conversation history. The operator.add annotation is the reducer used to
    # merge node outputs, so lists returned by nodes are concatenated onto it.
    messages: Annotated[list, operator.add]
    # Number of times the tools node has run in the current invocation.
    tool_call_count: int
    # Name taken from the most recent ToolMessage.
    last_tool_name: str
    # Short fingerprint of the most recent tool call, compared between steps
    # to detect the model repeating itself.
    last_tool_input: str
    # Consecutive tool results classified as failures; reset on a success.
    consecutive_failures: int
    # Counter of back-to-back repeats of the same tool with the same fingerprint.
    same_tool_call_count: int
# ==================== Nodes ====================
def agent_node(state: AgentState):
    """Run the tool-enabled LLM over the conversation and append its reply.

    Any exception raised by the LLM call is logged through ``log_failure`` and
    replaced by a plain ``"0"`` answer message, so the graph can still finish.

    Returns:
        A state update containing exactly one new message.
    """
    try:
        reply = llm_with_tools.invoke(state["messages"])
    except Exception as exc:
        log_failure("LLM error", str(exc))
        return {"messages": [AIMessage(content="0")]}
    return {"messages": [reply]}
raw_tool_node = ToolNode(tools)

# A tool result counts as a failure only when it starts with one of these
# explicit error prefixes.
_TOOL_ERROR_PREFIXES = (
    "Error:", "Failed to", "No transcript", "No description", "404",
    "not found", "Wikipedia search error",
)

# Loop guard: the same tool with the same input is tolerated twice; the third
# consecutive identical call aborts the run.
_MAX_IDENTICAL_CALLS = 3


def _tool_call_fingerprint(state, tool_message):
    """Return a short string for the arguments of the call behind ``tool_message``.

    The arguments are looked up on the last message in ``state`` (the AI message
    that requested the tool calls) by matching ``tool_call_id``. Returns None
    when no matching call can be found.
    """
    history = state.get("messages") or []
    requested = getattr(history[-1], "tool_calls", None) if history else None
    for call in requested or []:
        if call.get("id") == tool_message.tool_call_id:
            return str(call.get("args"))[:100]
    return None


def tool_node_wrapper(state: AgentState):
    """Execute the pending tool calls and update the bookkeeping counters.

    On top of the messages produced by ``raw_tool_node`` the returned update
    carries:

    * ``tool_call_count`` - incremented once per run of this node;
    * ``last_tool_name`` / ``last_tool_input`` - identity of the last call
      (the input is a fingerprint of the call arguments, falling back to the
      first 100 characters of the output if the call cannot be located);
    * ``consecutive_failures`` - incremented when the last tool output starts
      with a known error prefix, reset to 0 otherwise;
    * ``same_tool_call_count`` - number of back-to-back repeats of an identical
      call; on the third identical call ``consecutive_failures`` is forced to 2
      so that ``should_continue`` ends the run.

    Any exception is logged and turned into a ``"0"`` answer message.
    """
    try:
        result = raw_tool_node.invoke(state)
        result["tool_call_count"] = state.get("tool_call_count", 0) + 1
        produced = result["messages"]
        last_msg_result = produced[-1] if produced else None

        if not isinstance(last_msg_result, ToolMessage):
            # Nothing to classify: treat as a non-failing, non-repeating step.
            result["consecutive_failures"] = 0
            result["same_tool_call_count"] = 0
            return result

        content = last_msg_result.content
        if not isinstance(content, str):
            # Tool output may be a list of content blocks; str methods would raise.
            content = str(content)

        result["last_tool_name"] = last_msg_result.name
        fingerprint = _tool_call_fingerprint(state, last_msg_result)
        result["last_tool_input"] = content[:100] if fingerprint is None else fingerprint

        if content.startswith(_TOOL_ERROR_PREFIXES):
            result["consecutive_failures"] = state.get("consecutive_failures", 0) + 1
            log_failure("Tool returned error", f"Tool: {last_msg_result.name}, Output: {content[:200]}")
        else:
            result["consecutive_failures"] = 0

        is_repeat = (
            state.get("last_tool_name") == result.get("last_tool_name")
            and state.get("last_tool_input") == result.get("last_tool_input")
        )
        if is_repeat:
            repeats = state.get("same_tool_call_count", 0) + 1
            result["same_tool_call_count"] = repeats
            # The first call of a streak is never counted as a repeat, so the
            # number of identical calls seen so far is repeats + 1.
            identical_calls = repeats + 1
            if identical_calls >= _MAX_IDENTICAL_CALLS:
                log_failure("Tool loop detected", f"Repeated {result.get('last_tool_name')} with same input {identical_calls} times")
                # Force the failure threshold checked by should_continue.
                result["consecutive_failures"] = 2
        else:
            result["same_tool_call_count"] = 0
        return result
    except Exception:
        log_failure("Tool execution exception", traceback.format_exc())
        return {"messages": [AIMessage(content="0")], "tool_call_count": state.get("tool_call_count", 0) + 1,
                "consecutive_failures": 0, "same_tool_call_count": 0}
def should_continue(state: AgentState) -> Literal["tools", "end"]:
    """Route after the agent node: run tools again or finish.

    The run ends when two or more consecutive tool failures are recorded, when
    12 or more tool rounds have been used, or when the latest message carries
    no tool calls. Both limit-based exits are logged through ``log_failure``.
    """
    failures = state.get("consecutive_failures", 0)
    if failures >= 2:
        log_failure("Too many consecutive tool failures", f"Failures: {failures}")
        return "end"

    rounds_used = state.get("tool_call_count", 0)
    if rounds_used >= 12:
        log_failure("Max tool calls reached", f"Count: {rounds_used}")
        return "end"

    pending_calls = getattr(state["messages"][-1], "tool_calls", None)
    return "tools" if pending_calls else "end"
# ==================== Build the graph ====================
graph_builder = StateGraph(AgentState)

# Two nodes: the LLM step and the tool-execution step.
graph_builder.add_node("agent", agent_node)
graph_builder.add_node("tools", tool_node_wrapper)

# Wiring: START -> agent; tools -> agent; after the agent, should_continue
# picks either another tool round or END.
graph_builder.add_edge(START, "agent")
graph_builder.add_edge("tools", "agent")
graph_builder.add_conditional_edges(
    "agent",
    should_continue,
    {"tools": "tools", "end": END},
)

graph = graph_builder.compile()
# ==================== Generic answer cleaning ====================
# Phrases that mark a reply as a non-answer. They are matched as whole
# words/phrases so that legitimate answers which merely contain one of them
# ("banana" contains "nan", "Fernando" contains "nan") are not discarded.
_INVALID_ANSWER_MARKERS = (
    "unknown", "nan", "none", "i don't know", "i cannot", "can't",
    "please try again", "no transcript", "failed", "error", "could not",
    "no such file", "404", "not accessible", "image could not be loaded",
)
_INVALID_ANSWER_RE = re.compile(
    r"(?<!\w)(?:" + "|".join(re.escape(m) for m in _INVALID_ANSWER_MARKERS) + r")(?!\w)"
)


def clean_answer(text: str) -> str:
    """Reduce a raw model reply to a short final answer.

    Steps, applied to the first non-empty line of the reply:

    1. strip surrounding whitespace and a leading "Final answer:" / "Answer:" /
       "Result:" / "Output:" label (case-insensitive);
    2. strip trailing punctuation;
    3. a purely numeric line is returned with thousands separators removed;
    4. a line containing a non-answer phrase (see ``_INVALID_ANSWER_MARKERS``)
       becomes ``"0"``;
    5. a comma-separated list of at most 12 items, or a line of at most 10
       words, is returned unchanged;
    6. otherwise the first integer, then the first capitalised word, then the
       first ``a,b``-style letter list found is returned.

    Args:
        text: Raw reply; non-string values are converted with ``str``.

    Returns:
        The cleaned answer, or ``"0"`` when nothing usable is found.
    """
    if not text:
        return "0"

    # Strip before the anchored prefix removal so leading blank lines or
    # spaces do not hide the "Final answer:" label.
    text = str(text).strip()
    text = re.sub(r"(?i)^(final answer|answer|result|output)\s*:\s*", "", text)

    lines = text.strip().splitlines()
    if not lines:
        # Whitespace-only reply, or nothing left after the label.
        return "0"
    first_line = lines[0].strip().rstrip(".!?,;:").strip()
    if not first_line:
        return "0"

    # Pure numbers are checked first: a number is never an error phrase, even
    # when its digits happen to contain "404".
    if re.match(r'^[\d,]+(\.\d+)?$', first_line):
        return first_line.replace(",", "")

    if _INVALID_ANSWER_RE.search(first_line.lower()):
        return "0"

    if ',' in first_line and len(first_line.split(',')) <= 12:
        return first_line
    if len(first_line.split()) <= 10:
        return first_line

    # Long, chatty line: salvage the most answer-like token.
    numbers = re.findall(r'\b\d+\b', first_line)
    if numbers:
        return numbers[0]
    words = re.findall(r'\b[A-Z][a-z]+\b', first_line)
    if words:
        return words[0]
    list_pattern = re.search(r'\b([a-e](?:,[a-e])+)\b', text)
    if list_pattern:
        return list_pattern.group(1)
    return "0"
# ==================== Agent entry point ====================
def _matches_extension(url: str, extensions) -> bool:
    """Return True when ``url`` or its path component ends with one of ``extensions``.

    The comparison ignores case, and also looks at the URL path alone so that a
    query string or fragment after the file name does not hide the extension.
    """
    from urllib.parse import urlparse

    lowered = url.lower()
    try:
        path = urlparse(lowered).path
    except ValueError:
        # Malformed URL: fall back to the plain suffix test.
        path = lowered
    return lowered.endswith(extensions) or path.endswith(extensions)


def _file_context(url: str) -> str:
    """Build the context text that tells the model how to handle the attached file.

    Known file kinds produce a bracketed hint naming the tool to use. Anything
    else is downloaded (5 s timeout): textual responses are inlined (first
    6000 characters), other responses are described by content type, and a
    download error is reported as text.
    """
    if _matches_extension(url, '.xlsx'):
        return f"[Excel file at {url}. Use read_excel_sales tool to compute total sales.]"
    if _matches_extension(url, '.mp3'):
        return f"[Audio file at {url}. Use audio_transcribe tool to get transcript.]"
    if _matches_extension(url, '.py'):
        return f"[Python file at {url}. Use python_executor tool to run it and get numeric output.]"
    if _matches_extension(url, ('.png', '.jpg', '.jpeg')):
        return f"[Image file at {url}. Use image_caption tool to describe the content.]"
    lowered = url.lower()
    if 'youtube.com' in lowered or 'youtu.be' in lowered:
        return f"[YouTube video at {url}. Use youtube_transcript tool to get subtitles or description.]"
    try:
        r = requests.get(url, timeout=5)
        if r.ok and 'text' in r.headers.get('Content-Type', ''):
            return r.text[:6000]
        return f"File at {url} (type: {r.headers.get('Content-Type','unknown')})"
    except Exception as e:
        # Best effort: the model is told about the failure instead of crashing.
        return f"Could not download file: {str(e)}"


def agent(question: str, files=None) -> str:
    """Answer one question with the tool-using graph and return a cleaned answer.

    Args:
        question: The question text. A question that starts with ``.rewsna``
            is assumed to be written backwards and is reversed first.
        files: Optional attachment(s): a sequence whose first element is the
            file URL/path, or a single string. Only the first entry is used.

    Returns:
        The answer after ``clean_answer`` post-processing; ``"0"`` when the
        graph raises or no usable answer is produced.
    """
    # Generic preprocessing: undo a fully reversed question.
    if question.strip().startswith('.rewsna'):
        question = question[::-1]  # reverse the whole string
        print(f"🔵 Reversed question: {question[:80]}...")
    print(f"\n🔵 AGENT CALLED: {question[:80]}...")

    # Start each run with an empty failure log (mutated in place, no rebinding).
    failure_logs.clear()

    if isinstance(files, str):
        # Accept a bare URL as well as a list; indexing a str would yield one char.
        files = [files]
    context = ""
    if files and files[0]:
        url = str(files[0])
        print(f"📎 Processing file: {url}")
        context = _file_context(url)

    system_prompt = (
        "You are a GAIA assistant. You have tools for Wikipedia, web browsing, file reading, Python, YouTube transcripts, "
        "image captioning, audio transcription, Excel sales, and Python code execution.\n"
        "You may call tools up to 12 times. If a tool fails twice consecutively, stop and output '0'.\n"
        "After gathering enough information, output the final answer concisely.\n"
        "Answer with only the required number, name, list (comma-separated), or short phrase.\n"
        "Do not include extra text. If you cannot answer, output '0'.\n"
        "Examples: '7', 'right', 'b,e', 'broccoli, celery', 'Claus', '562'."
    )
    user_content = f"{context}\n\nQuestion: {question}" if context else f"Question: {question}"
    messages = [HumanMessage(content=system_prompt), HumanMessage(content=user_content)]

    try:
        result = graph.invoke(
            {"messages": messages, "tool_call_count": 0, "last_tool_name": "",
             "last_tool_input": "", "consecutive_failures": 0, "same_tool_call_count": 0},
            {"recursion_limit": 50}
        )
    except Exception:
        log_failure("Graph invocation exception", traceback.format_exc())
        return "0"

    final_msg = result["messages"][-1]
    answer_text = final_msg.content if isinstance(final_msg, AIMessage) else str(final_msg)
    cleaned = clean_answer(answer_text)
    if cleaned == "0" and not failure_logs:
        # No failure was recorded, so keep the raw reply for later debugging.
        log_failure("Unknown reason for answer 0", f"Last message content: {str(answer_text)[:200]}")
    print(f"\n✅ FINAL ANSWER: {cleaned}\n")
    return cleaned