Spaces:
Sleeping
Sleeping
| import os | |
| import re | |
| import requests | |
| from typing import TypedDict, Annotated, Literal | |
| import operator | |
| import traceback | |
| from langchain_openai import ChatOpenAI | |
| from langchain_core.messages import HumanMessage, AIMessage, ToolMessage | |
| from langgraph.graph import StateGraph, START, END | |
| from langgraph.prebuilt import ToolNode | |
| from langchain_core.tools import tool | |
| from tools.wikipedia_search import wikipedia_search | |
| from tools.visit_webpage import visit_webpage | |
| from tools.read_file import read_file | |
| from tools.python_repl import python_repl | |
| from tools.youtube_transcript import youtube_transcript | |
| from tools.image_caption import image_caption | |
| from tools.audio_transcribe import audio_transcribe | |
| from tools.read_excel import read_excel_sales | |
| from tools.python_executor import python_executor | |
# Failure messages recorded while answering the current question.
failure_logs = []


def log_failure(reason: str, details: str = ""):
    """Print a failure notice and append it to ``failure_logs``.

    Args:
        reason: Short label describing what went wrong.
        details: Optional extra text; when non-empty it is appended on a
            second line of the message.
    """
    pieces = [f"❌ FAILURE: {reason}"]
    if details:
        pieces.append(f"\n Details: {details}")
    entry = "".join(pieces)
    print(entry)
    failure_logs.append(entry)
# ==================== Tool wrappers ====================
def wikipedia_search_tool(query: str) -> str:
    """Search Wikipedia for `query` and return the result as text."""
    # The docstring doubles as the tool description given to the model;
    # langchain's function-to-tool conversion requires one.
    # NOTE(review): wording inferred from the helper's name - confirm
    # against tools/wikipedia_search.py.
    return wikipedia_search(query)
def visit_webpage_tool(url: str) -> str:
    """Visit the web page at `url` and return its content as text."""
    # The docstring doubles as the tool description given to the model;
    # langchain's function-to-tool conversion requires one.
    # NOTE(review): wording inferred from the helper's name - confirm
    # against tools/visit_webpage.py.
    return visit_webpage(url)
def read_file_tool(path: str) -> str:
    """Read the file at `path` and return its contents as text."""
    # The docstring doubles as the tool description given to the model;
    # langchain's function-to-tool conversion requires one.
    # NOTE(review): wording inferred from the helper's name - confirm
    # against tools/read_file.py.
    return read_file(path)
def python_repl_tool(code: str) -> str:
    """Run the Python source in `code` and return the output as text."""
    # The docstring doubles as the tool description given to the model;
    # langchain's function-to-tool conversion requires one.
    # NOTE(review): wording inferred from the helper's name - confirm
    # against tools/python_repl.py.
    return python_repl(code)
def youtube_transcript_tool(url: str) -> str:
    """Get the subtitles or description of the YouTube video at `url`."""
    # The docstring doubles as the tool description given to the model;
    # langchain's function-to-tool conversion requires one.
    return youtube_transcript(url)
def image_caption_tool(url: str) -> str:
    """Describe the content of the image at `url`."""
    # The docstring doubles as the tool description given to the model;
    # langchain's function-to-tool conversion requires one.
    return image_caption(url)
def audio_transcribe_tool(url: str) -> str:
    """Return the transcript of the audio file at `url`."""
    # The docstring doubles as the tool description given to the model;
    # langchain's function-to-tool conversion requires one.
    return audio_transcribe(url)
def read_excel_sales_tool(url: str) -> str:
    """Compute total sales from the Excel file at `url`."""
    # The docstring doubles as the tool description given to the model;
    # langchain's function-to-tool conversion requires one.
    return read_excel_sales(url)
def python_executor_tool(url: str) -> str:
    """Run the Python file at `url` and return its output."""
    # The docstring doubles as the tool description given to the model;
    # langchain's function-to-tool conversion requires one.
    return python_executor(url)
# Every callable offered to the model; the same list is bound to the LLM
# and handed to the ToolNode that executes the calls.
tools = [
    wikipedia_search_tool,
    visit_webpage_tool,
    read_file_tool,
    python_repl_tool,
    youtube_transcript_tool,
    image_caption_tool,
    audio_transcribe_tool,
    read_excel_sales_tool,
    python_executor_tool,
]
# ==================== LLM configuration ====================
# Endpoint and credentials are read from the environment; each is None
# when the variable is unset.
LLM_BASE_URL = os.environ.get("AGICTO_BASE_URL")
LLM_API_KEY = os.environ.get("AGICTO_API_KEY")
LLM_MODEL_ID = "qwen3.5-35b-a3b"

# temperature 0.0 and a 512-token cap are passed straight to the client.
llm = ChatOpenAI(
    api_key=LLM_API_KEY,
    base_url=LLM_BASE_URL,
    model=LLM_MODEL_ID,
    max_tokens=512,
    temperature=0.0,
)

# Model handle that advertises every entry of `tools` to the LLM.
llm_with_tools = llm.bind_tools(tools)
# ==================== State ====================
class AgentState(TypedDict):
    """State dictionary passed between the graph's nodes."""

    # Conversation so far; annotated with operator.add as the reducer, so
    # node outputs are concatenated onto the existing list.
    messages: Annotated[list, operator.add]
    # Number of times the tools node has run for this question.
    tool_call_count: int
    # Name of the tool behind the most recent tool message.
    last_tool_name: str
    # Short fingerprint recorded for the most recent tool call.
    last_tool_input: str
    # Consecutive tool results classified as failures.
    consecutive_failures: int
    # How many times in a row the same tool/fingerprint pair has repeated.
    same_tool_call_count: int
# ==================== Nodes ====================
def agent_node(state: AgentState):
    """Ask the tool-enabled LLM for the next message.

    Args:
        state: Current graph state; only ``state["messages"]`` is read.

    Returns:
        A partial state update holding one new message: the model's reply,
        or an ``AIMessage`` with content ``"0"`` if the call raised.
    """
    try:
        reply = llm_with_tools.invoke(state["messages"])
    except Exception as exc:
        # Best effort: record the error and fall back to the sentinel answer.
        log_failure("LLM error", str(exc))
        reply = AIMessage(content="0")
    return {"messages": [reply]}
raw_tool_node = ToolNode(tools)

# Output prefixes that classify a tool result as a failure.
_TOOL_ERROR_PREFIXES = (
    "Error:", "Failed to", "No transcript", "No description", "404",
    "not found", "Wikipedia search error",
)

# The same tool with the same input is tolerated twice; the third identical
# call in a row trips the loop guard.
_MAX_IDENTICAL_CALLS = 3


def _tool_call_input(state, tool_msg) -> str:
    """Return up to 100 chars of the arguments of the call that produced `tool_msg`.

    Scans the conversation backwards for a tool call whose id equals
    ``tool_msg.tool_call_id``. Returns "" when no such call is found.
    """
    wanted_id = getattr(tool_msg, "tool_call_id", None)
    for msg in reversed(state.get("messages") or []):
        for call in getattr(msg, "tool_calls", None) or []:
            if isinstance(call, dict) and call.get("id") == wanted_id:
                return str(call.get("args"))[:100]
    return ""


def tool_node_wrapper(state: AgentState):
    """Run the pending tool calls and update the failure/loop bookkeeping.

    Args:
        state: Current graph state.

    Returns:
        The ToolNode result (new tool messages) extended with updated
        ``tool_call_count``, ``last_tool_name``, ``last_tool_input``,
        ``consecutive_failures`` and ``same_tool_call_count``. If anything
        raises, the traceback is logged and a fallback update carrying an
        ``AIMessage("0")`` is returned instead.
    """
    try:
        result = raw_tool_node.invoke(state)
        result["tool_call_count"] = state.get("tool_call_count", 0) + 1

        produced = result["messages"]
        last_msg_result = produced[-1] if produced else None

        is_failure = False
        if isinstance(last_msg_result, ToolMessage):
            raw_content = last_msg_result.content
            # Content may be a list of blocks; coerce so prefix checks and
            # slicing cannot raise and throw away the tool results.
            content = raw_content if isinstance(raw_content, str) else str(raw_content)
            # A result is a failure when it is flagged as an error or starts
            # with one of the explicit error prefixes.
            is_failure = (
                getattr(last_msg_result, "status", None) == "error"
                or content.startswith(_TOOL_ERROR_PREFIXES)
            )
            result["last_tool_name"] = last_msg_result.name
            # Fingerprint the call by its INPUT (arguments). Fall back to the
            # output prefix only when the originating call cannot be found.
            result["last_tool_input"] = (
                _tool_call_input(state, last_msg_result) or content[:100]
            )
            if is_failure:
                log_failure(
                    "Tool returned error",
                    f"Tool: {last_msg_result.name}, Output: {content[:200]}",
                )

        if is_failure:
            result["consecutive_failures"] = state.get("consecutive_failures", 0) + 1
        else:
            result["consecutive_failures"] = 0

        # Loop detection: same tool with the same input as the previous call.
        is_repeat = (
            state.get("last_tool_name") == result.get("last_tool_name")
            and state.get("last_tool_input") == result.get("last_tool_input")
        )
        if is_repeat:
            repeats = state.get("same_tool_call_count", 0) + 1
            result["same_tool_call_count"] = repeats
            # `repeats` counts calls after the first one, so the number of
            # identical calls seen so far is repeats + 1.
            identical_calls = repeats + 1
            if identical_calls >= _MAX_IDENTICAL_CALLS:
                log_failure(
                    "Tool loop detected",
                    f"Repeated {result.get('last_tool_name')} with same input "
                    f"{identical_calls} times",
                )
                # 2 is the threshold at which should_continue ends the run.
                result["consecutive_failures"] = 2
        else:
            result["same_tool_call_count"] = 0
        return result
    except Exception:
        log_failure("Tool execution exception", traceback.format_exc())
        return {
            "messages": [AIMessage(content="0")],
            "tool_call_count": state.get("tool_call_count", 0) + 1,
            "consecutive_failures": 0,
            "same_tool_call_count": 0,
        }
def should_continue(state: AgentState) -> Literal["tools", "end"]:
    """Choose the edge to follow after the agent node.

    Returns ``"end"`` once two or more consecutive tool failures are
    recorded or twelve or more tool rounds have run (each is logged as a
    failure). Otherwise returns ``"tools"`` when the newest message carries
    tool calls, and ``"end"`` when it does not.
    """
    failures = state.get("consecutive_failures", 0)
    if failures >= 2:
        log_failure("Too many consecutive tool failures", f"Failures: {failures}")
        return "end"

    rounds = state.get("tool_call_count", 0)
    if rounds >= 12:
        log_failure("Max tool calls reached", f"Count: {rounds}")
        return "end"

    # Messages without a `tool_calls` attribute count as having none.
    pending_calls = getattr(state["messages"][-1], "tool_calls", None)
    return "tools" if pending_calls else "end"
# ==================== Build the graph ====================
# Flow: START -> agent; after each agent turn should_continue routes to
# the tools node or to END; the tools node always hands back to the agent.
graph_builder = StateGraph(AgentState)

graph_builder.add_node("agent", agent_node)
graph_builder.add_node("tools", tool_node_wrapper)

graph_builder.add_edge(START, "agent")
graph_builder.add_conditional_edges(
    "agent",
    should_continue,
    {"tools": "tools", "end": END},
)
graph_builder.add_edge("tools", "agent")

graph = graph_builder.compile()
# ==================== Generic answer cleaning ====================
# Leading labels such as "Final answer:" that are stripped from the text.
_ANSWER_PREFIX_RE = re.compile(r"(?i)^(final answer|answer|result|output)\s*:\s*")

# Words/phrases that mark the first line as a non-answer.
_INVALID_ANSWER_MARKERS = (
    "unknown", "nan", "none", "i don't know", "i cannot", "can't",
    "please try again", "no transcript", "failed", "error", "could not",
    "no such file", "404", "not accessible", "image could not be loaded",
)

# Markers must match as whole words/phrases so that legitimate answers that
# merely contain one ("banana" has "nan", "1404" has "404") are kept.
_INVALID_ANSWER_RE = re.compile(
    r"(?<!\w)(?:"
    + "|".join(re.escape(marker) for marker in _INVALID_ANSWER_MARKERS)
    + r")(?!\w)"
)


def clean_answer(text: str) -> str:
    """Reduce a raw model reply to a short final answer.

    Only the first line is considered, after removing a leading
    "final answer:/answer:/result:/output:" label and trailing punctuation.

    Args:
        text: Raw reply; any falsy value yields "0".

    Returns:
        "0" when the text is empty, whitespace-only, or its first line
        contains an invalid marker as a whole word/phrase. Otherwise, in
        order of precedence: a plain number with thousands separators
        removed; a comma-separated list of at most 12 items; a line of at
        most 10 words; the first integer on a longer line; the first
        capitalised word; a comma-joined list of letters a-e found anywhere
        in the text; and finally "0".
    """
    if not text:
        return "0"

    # Strip first so leading whitespace cannot hide the label from the
    # "^"-anchored prefix pattern.
    text = _ANSWER_PREFIX_RE.sub("", str(text).strip())

    lines = text.strip().splitlines()
    if not lines:
        # Nothing left (whitespace only, or the text was just a label).
        return "0"

    first_line = lines[0].strip().rstrip(".!?,;:")
    if not first_line or _INVALID_ANSWER_RE.search(first_line.lower()):
        return "0"

    if re.match(r'^[\d,]+(\.\d+)?$', first_line):
        return first_line.replace(",", "")
    if ',' in first_line and len(first_line.split(',')) <= 12:
        return first_line
    if len(first_line.split()) <= 10:
        return first_line

    # Long sentence: fall back to progressively weaker extractions.
    numbers = re.findall(r'\b\d+\b', first_line)
    if numbers:
        return numbers[0]
    words = re.findall(r'\b[A-Z][a-z]+\b', first_line)
    if words:
        return words[0]
    list_pattern = re.search(r'\b([a-e](?:,[a-e])+)\b', text)
    if list_pattern:
        return list_pattern.group(1)
    return "0"
# ==================== Agent entry point ====================
def _attachment_context(url: str) -> str:
    """Build the context text that tells the model how to use the attachment at `url`.

    The type is recognised from the extension, compared case-insensitively
    against both the full URL and its path component, so a trailing query
    string or fragment does not hide it. YouTube links are recognised by
    host substring. Anything else is downloaded: textual responses are
    inlined (first 6000 characters), other responses are described by
    their Content-Type, and download errors are reported in the text.
    """
    from urllib.parse import urlparse

    lowered = url.lower()
    try:
        path = urlparse(url).path.lower()
    except ValueError:
        # Malformed URL: fall back to matching on the raw string only.
        path = lowered

    def has_ext(*extensions: str) -> bool:
        return lowered.endswith(extensions) or path.endswith(extensions)

    if has_ext('.xlsx'):
        return f"[Excel file at {url}. Use read_excel_sales tool to compute total sales.]"
    if has_ext('.mp3'):
        return f"[Audio file at {url}. Use audio_transcribe tool to get transcript.]"
    if has_ext('.py'):
        return f"[Python file at {url}. Use python_executor tool to run it and get numeric output.]"
    if has_ext('.png', '.jpg', '.jpeg'):
        return f"[Image file at {url}. Use image_caption tool to describe the content.]"
    if 'youtube.com' in url or 'youtu.be' in url:
        return f"[YouTube video at {url}. Use youtube_transcript tool to get subtitles or description.]"

    try:
        r = requests.get(url, timeout=5)
        if r.ok and 'text' in r.headers.get('Content-Type', ''):
            return r.text[:6000]
        return f"File at {url} (type: {r.headers.get('Content-Type','unknown')})"
    except Exception as e:
        return f"Could not download file: {str(e)}"


def agent(question: str, files=None) -> str:
    """Answer `question` with the tool-using graph and return a cleaned answer.

    Args:
        question: The question text. If it starts with '.rewsna' (after
            stripping), the whole string is reversed first.
        files: Optional sequence; only ``files[0]`` is used, as the
            URL/path of an attachment.

    Returns:
        The final answer after ``clean_answer``; "0" when the graph raises
        or no usable answer is produced.
    """
    # Generic preprocessing: un-reverse questions written backwards.
    if question.strip().startswith('.rewsna'):
        question = question[::-1]
        print(f"🔵 Reversed question: {question[:80]}...")
    print(f"\n🔵 AGENT CALLED: {question[:80]}...")

    # Start every question with an empty failure log (mutated in place).
    failure_logs.clear()

    context = ""
    if files and files[0]:
        url = files[0]
        print(f"📎 Processing file: {url}")
        context = _attachment_context(url)

    system_prompt = (
        "You are a GAIA assistant. You have tools for Wikipedia, web browsing, file reading, Python, YouTube transcripts, "
        "image captioning, audio transcription, Excel sales, and Python code execution.\n"
        "You may call tools up to 12 times. If a tool fails twice consecutively, stop and output '0'.\n"
        "After gathering enough information, output the final answer concisely.\n"
        "Answer with only the required number, name, list (comma-separated), or short phrase.\n"
        "Do not include extra text. If you cannot answer, output '0'.\n"
        "Examples: '7', 'right', 'b,e', 'broccoli, celery', 'Claus', '562'."
    )
    user_content = f"{context}\n\nQuestion: {question}" if context else f"Question: {question}"
    messages = [HumanMessage(content=system_prompt), HumanMessage(content=user_content)]

    try:
        result = graph.invoke(
            {"messages": messages, "tool_call_count": 0, "last_tool_name": "",
             "last_tool_input": "", "consecutive_failures": 0, "same_tool_call_count": 0},
            {"recursion_limit": 50}
        )
    except Exception:
        log_failure("Graph invocation exception", traceback.format_exc())
        return "0"

    final_msg = result["messages"][-1]
    answer_text = final_msg.content if isinstance(final_msg, AIMessage) else str(final_msg)
    cleaned = clean_answer(answer_text)
    # A "0" with no recorded failure is logged so it can be traced afterwards.
    if cleaned == "0" and not failure_logs:
        log_failure("Unknown reason for answer 0", f"Last message content: {answer_text[:200]}")
    print(f"\n✅ FINAL ANSWER: {cleaned}\n")
    return cleaned