Spaces:
Runtime error
Runtime error
| from langchain_openai import ChatOpenAI | |
| from langchain_core.messages import HumanMessage, SystemMessage | |
| from langchain.tools import tool | |
| from langchain_community.document_loaders import WikipediaLoader,ArxivLoader | |
| from tavily import TavilyClient | |
| from openai import OpenAI | |
| import base64 | |
| import re | |
| import os | |
| from typing import TypedDict, Annotated, Literal | |
| from langchain_core.messages import ( | |
| AnyMessage, HumanMessage, AIMessage, ToolMessage, SystemMessage | |
| ) | |
| from langgraph.graph.message import add_messages | |
| from langgraph.graph import StateGraph, END | |
# API credentials are read from the environment; an unset variable yields None.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")

# Shared service clients, constructed once when the module is imported.
tavily_client = TavilyClient(api_key=TAVILY_API_KEY)
openai_client = OpenAI(api_key=OPENAI_API_KEY)

# Upper bound on agent_node invocations for a single run.
MAX_STEPS = 15
def search_wikipedia(query: str, max_docs: int = 3) -> str:
    """Search Wikipedia for general knowledge and return page content.

    Each retrieved page is cut to its first 3000 characters; pages are
    separated by a blank line.

    Args:
        query: Topic to search (e.g., 'Artificial Intelligence', 'France history')
        max_docs: Maximum number of Wikipedia pages to retrieve
    """
    pages = WikipediaLoader(query=query, load_max_docs=max_docs).load()
    excerpts = [page.page_content[:3000] for page in pages]
    return "\n\n".join(excerpts)
def search_arxiv(query: str, max_docs: int = 3) -> str:
    """Search arXiv for scientific papers and return their content.

    Each retrieved paper is cut to its first 3000 characters; papers are
    separated by a blank line.

    Args:
        query: Research topic or keywords (e.g., 'transformer attention')
        max_docs: Maximum number of papers to retrieve
    """
    papers = ArxivLoader(query=query, load_max_docs=max_docs).load()
    excerpts = [paper.page_content[:3000] for paper in papers]
    return "\n\n".join(excerpts)
def search_web(query: str, max_results: int = 5) -> str:
    """Search the web for up-to-date information.

    Every hit is rendered as its title followed by its content on the next
    line; hits are separated by a blank line.

    Args:
        query: Search query (e.g., 'latest OpenAI model 2025')
        max_results: Number of results to return
    """
    hits = tavily_client.search(query=query, max_results=max_results)["results"]
    return "\n\n".join(f"{hit['title']}\n{hit['content']}" for hit in hits)
def transcribe_audio(file_path: str) -> str:
    """Transcribe an audio file (mp3, wav) into text.

    The file is opened in binary mode and sent to the "whisper-1"
    transcription model; the text of the returned transcript is the result.

    Args:
        file_path: Path to the audio file on disk
    """
    with open(file_path, "rb") as audio:
        result = openai_client.audio.transcriptions.create(
            model="whisper-1",
            file=audio,
        )
    return result.text
def read_image(file_path: str) -> str:
    """Read an image file and return a description via GPT-4o vision.

    The image is embedded in the request as a base64 data URL whose MIME type
    is chosen from the file extension (unknown extensions fall back to
    "image/png").

    Args:
        file_path: Path to the image file on disk
    """
    with open(file_path, "rb") as image:
        encoded = base64.b64encode(image.read()).decode("utf-8")

    # Map the (lower-cased) extension to the MIME type used in the data URL.
    mime_by_extension = {
        "jpg": "image/jpeg",
        "jpeg": "image/jpeg",
        "png": "image/png",
        "gif": "image/gif",
        "webp": "image/webp",
    }
    extension = file_path.rsplit(".", 1)[-1].lower()
    mime = mime_by_extension.get(extension, "image/png")

    image_part = {
        "type": "image_url",
        "image_url": {"url": f"data:{mime};base64,{encoded}"},
    }
    instruction_part = {
        "type": "text",
        "text": "Describe this image in detail. Extract any text, data, or key information visible.",
    }
    completion = openai_client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": [image_part, instruction_part]}],
        max_tokens=1024,
    )
    return completion.choices[0].message.content
def read_file(file_path: str) -> str:
    """Read a file and return its contents.

    The file is opened in text mode and decoded as UTF-8.
    """
    with open(file_path, encoding="utf-8") as handle:
        contents = handle.read()
    return contents
def python_repl(code: str) -> str:
    """Execute Python code and return stdout + the value of the last expression.

    Useful for arithmetic, data manipulation, and logic tasks.

    The code is run exactly once. If its final statement is an expression,
    the repr of that expression's value is appended to the captured stdout
    (a value of None is not reported). If the code raises, the captured
    stdout followed by the traceback text is returned instead.

    Args:
        code: Valid Python code string
    """
    import ast
    import contextlib
    import io
    import traceback

    # SECURITY: this executes arbitrary code in-process with the privileges of
    # the host interpreter. Only use it where the code source is trusted or
    # the process itself is sandboxed.
    captured = io.StringIO()
    # A single namespace serves as both globals and locals so that functions
    # defined in `code` can see other top-level names (imports, variables,
    # sibling functions) from the same snippet.
    namespace: dict = {}
    last_val = ""
    try:
        tree = ast.parse(code, mode="exec")
        # Split off a trailing expression so it is evaluated once for its
        # value rather than executed and then evaluated a second time.
        final_expr = None
        if tree.body and isinstance(tree.body[-1], ast.Expr):
            final_expr = ast.Expression(body=tree.body.pop().value)
        # redirect_stdout restores whatever stdout was active before the call.
        with contextlib.redirect_stdout(captured):
            exec(compile(tree, "<python_repl>", "exec"), namespace)
            if final_expr is not None:
                value = eval(compile(final_expr, "<python_repl>", "eval"), namespace)
                if value is not None:
                    last_val = repr(value)
    except (Exception, SystemExit):
        # SystemExit is caught too so that a snippet calling sys.exit() is
        # reported back instead of terminating the caller.
        return "\n".join(filter(None, [captured.getvalue(), traceback.format_exc()]))
    out = captured.getvalue()
    return "\n".join(filter(None, [out, last_val])) or "Code executed successfully (no output)."
# Registry of tools offered to the model.
#
# TOOL_MAP is keyed by each tool's `.name`, and tool_node calls `.invoke()`
# on the looked-up entry. Plain Python functions provide neither, so any entry
# that is still a bare callable is wrapped with LangChain's `tool` helper
# (imported at the top of the file). Objects that already expose `.name` and
# `.invoke` are kept as they are.
TOOLS = [
    candidate
    if hasattr(candidate, "name") and hasattr(candidate, "invoke")
    else tool(candidate)
    for candidate in (
        search_wikipedia,
        search_arxiv,
        search_web,
        transcribe_audio,
        read_image,
        read_file,
        python_repl,
    )
]

# Lookup table from tool name (as emitted in model tool calls) to tool object.
TOOL_MAP = {t.name: t for t in TOOLS}
# System prompt injected by agent_node ahead of the conversation.
# It is an f-string: {MAX_STEPS} is interpolated once, when the module loads.
# Mis-encoded punctuation in the prompt text has been restored to the em dash
# (in sentences) and the arrow (in the tool selection guide).
SYSTEM_PROMPT = f"""You are a highly capable AI assistant solving tasks from the GAIA benchmark.
## Core rules (MUST follow)
1. THINK before acting: decompose the question and plan which tool(s) you need.
2. NEVER call the same tool with the exact same arguments twice.
If the result was insufficient, use a DIFFERENT query or a DIFFERENT tool.
3. If search_wikipedia returns a biography page instead of a discography/list,
immediately switch to search_web with a more specific query.
4. For calculations / counting, always use python_repl — never guess numbers.
5. Once you have enough information, STOP calling tools and give the final answer.
6. You have at most {MAX_STEPS} tool-call rounds total. Budget them wisely.
## Tool selection guide
- General facts / biography → search_wikipedia (vary query if first try fails)
- Discographies, filmographies, lists → search_web (Wikipedia tool may miss these)
- Current events / live data → search_web
- Scientific papers → search_arxiv
- Arithmetic / logic → python_repl
- Provided image file → read_image
- Provided audio file → transcribe_audio
- Provided text/csv/json → read_file
## Answer format
End your FINAL response with exactly:
FINAL ANSWER: <your answer>
Keep it concise — no units unless asked, lists comma-separated.
"""
class AgentState(TypedDict):
    """State dictionary passed between the graph nodes."""

    # Conversation history. Annotated with add_messages, which LangGraph is
    # expected to use to merge the message lists returned by the nodes into
    # the existing history (nodes here return only their new messages).
    messages: Annotated[list[AnyMessage], add_messages]
    step_count: int  # counts agent_node invocations
def make_llm(model: str = "gpt-5.4-mini") -> ChatOpenAI:
    """Create the chat model used by the agent, with TOOLS bound to it.

    Args:
        model: Name of the OpenAI chat model to use.

    Returns:
        The result of ``ChatOpenAI(...).bind_tools(TOOLS)``, configured with
        temperature 0 and the module-level OPENAI_API_KEY.
    """
    chat_model = ChatOpenAI(model=model, temperature=0, api_key=OPENAI_API_KEY)
    return chat_model.bind_tools(TOOLS)
# Tool-bound model instance invoked by agent_node; built once at import time.
llm_with_tools = make_llm()
# Running counter shown in the "[Step N]" console prefix. It is incremented by
# agent_node (once per call) and tool_node (once per tool call) and reset to 0
# by run_agent.
_step = 0 # console display counter
# ANSI escape sequences used to colour and style console output.
CYAN = "\033[96m"
GREEN = "\033[92m"
YELLOW = "\033[93m"
RED = "\033[91m"
BOLD = "\033[1m"
RESET = "\033[0m"
def _log(label: str, text: str, color: str = RESET) -> None:
    """Print one coloured console entry for the current step.

    Emits a separator line, a "[Step N] label" header using the module-level
    ``_step`` counter, and then the stripped ``text`` if it is not blank.

    Args:
        label: Header text shown after the step number.
        text: Optional body; skipped when empty or whitespace only.
        color: ANSI colour prefix applied to every printed line.
    """
    separator = f"{color}{'β'*60}{RESET}"
    header = f"{color}[Step {_step}] {label}{RESET}"
    print(separator)
    print(header)
    body = text.strip()
    if body:
        print(f"{color}{body}{RESET}")
def agent_node(state: AgentState) -> AgentState:
    """Run one model turn and return the state update it produces.

    The system prompt is prepended when the history contains no SystemMessage,
    and a wrap-up instruction is appended once the step count reaches
    ``MAX_STEPS - 2``. Neither addition is written back to the state: only the
    model's response and the incremented step count are returned.

    Args:
        state: Current graph state (message history and step count).

    Returns:
        A partial state update with the new AI message and the step count.
    """
    global _step
    _step += 1
    step_count = state.get("step_count", 0) + 1

    conversation = state["messages"]

    # Inject the system prompt when the history does not carry one.
    has_system_prompt = any(isinstance(msg, SystemMessage) for msg in conversation)
    if not has_system_prompt:
        conversation = [SystemMessage(content=SYSTEM_PROMPT), *conversation]

    # Close to the step budget: tell the model to stop using tools and answer.
    if step_count >= MAX_STEPS - 2:
        wrap_up = HumanMessage(
            content=f"β οΈ You have used {step_count}/{MAX_STEPS} steps. "
            "Do NOT call any more tools. Synthesise what you have and give FINAL ANSWER now."
        )
        conversation = [*conversation, wrap_up]

    _log("π€ AGENT THINKING β¦", "", CYAN)
    response = llm_with_tools.invoke(conversation)

    if response.content:
        # Only the first 600 characters are echoed to the console.
        _log("π€ AGENT RESPONSE", str(response.content)[:600], CYAN)

    if not response.tool_calls:
        _log("β AGENT FINISHED (no more tool calls)", "", GREEN)
    else:
        summary_lines = []
        for call in response.tool_calls:
            rendered_args = ", ".join(
                f"{key}={repr(value)}" for key, value in call["args"].items()
            )
            summary_lines.append(f" β’ {call['name']}({rendered_args})")
        _log("π§ TOOL CALLS PLANNED", "\n".join(summary_lines), YELLOW)

    return {"messages": [response], "step_count": step_count}
def tool_node(state: AgentState) -> AgentState:
    """Execute every tool call requested by the latest AI message.

    Each call is looked up in TOOL_MAP and invoked with its arguments. Unknown
    tools and exceptions raised by a tool are turned into "ERROR ..." strings
    rather than propagated, so every call yields exactly one ToolMessage.

    Args:
        state: Current graph state; its last message carries the tool calls.

    Returns:
        A partial state update containing one ToolMessage per tool call.
    """
    global _step
    ai_message: AIMessage = state["messages"][-1]
    outputs: list[ToolMessage] = []

    for call in ai_message.tool_calls:
        _step += 1
        name = call["name"]
        handler = TOOL_MAP.get(name)

        arg_lines = "\n".join(f" {k}: {repr(v)}" for k, v in call["args"].items())
        _log(f"βοΈ RUNNING: {name}", arg_lines, YELLOW)

        if handler is not None:
            try:
                result = handler.invoke(call["args"])
                rendered = str(result)
                # Console preview is capped at 500 characters.
                suffix = "β¦" if len(rendered) > 500 else ""
                _log(f"π₯ RESULT: {name}", rendered[:500] + suffix, GREEN)
            except Exception as exc:
                result = f"ERROR calling {name}: {exc}"
                _log(f"β TOOL ERROR: {name}", result, RED)
        else:
            result = f"ERROR: unknown tool '{name}'"
            _log("β TOOL ERROR", result, RED)

        outputs.append(ToolMessage(content=str(result), tool_call_id=call["id"]))

    return {"messages": outputs}
def should_continue(state: AgentState) -> Literal["tools", "end"]:
    """Choose the next edge after the agent node has run.

    Returns "end" once the step count has reached MAX_STEPS (printing a
    notice), "tools" when the latest message is an AIMessage that requests
    tool calls, and "end" otherwise.
    """
    if state.get("step_count", 0) >= MAX_STEPS:
        print(f"{RED}{'β'*60}")
        print(f"β MAX_STEPS ({MAX_STEPS}) reached β forcing end.{RESET}")
        return "end"

    latest = state["messages"][-1]
    wants_tools = isinstance(latest, AIMessage) and bool(latest.tool_calls)
    return "tools" if wants_tools else "end"
def build_graph() -> StateGraph:
    """Assemble and compile the agent/tools loop.

    The graph starts at "agent". After each agent turn, should_continue
    routes either to "tools" or to END; the "tools" node always hands control
    back to "agent".

    Returns:
        The object produced by ``StateGraph.compile()``.
    """
    workflow = StateGraph(AgentState)

    workflow.add_node("agent", agent_node)
    workflow.add_node("tools", tool_node)
    workflow.set_entry_point("agent")

    routes = {"tools": "tools", "end": END}
    workflow.add_conditional_edges("agent", should_continue, routes)
    # Tool output is always fed back to the agent for the next turn.
    workflow.add_edge("tools", "agent")

    return workflow.compile()


# Compiled graph shared by every run_agent call.
graph = build_graph()
def run_agent(question: str, file_path: str | None = None) -> str:
    """Run the agent on a GAIA question and return the extracted final answer.

    Args:
        question: The task text given to the agent.
        file_path: Optional path of an attached file; when given, a note
            pointing to it is appended to the question.

    Returns:
        The text following the last "FINAL ANSWER:" marker in the final
        message, or the whole final message text when no marker is present.
    """
    global _step
    _step = 0

    print(f"\n{BOLD}{'β'*60}{RESET}")
    print(f"{BOLD}β QUESTION: {question}{RESET}")
    if file_path:
        print(f"{BOLD}π FILE: {file_path}{RESET}")
    print(f"{BOLD}{'β'*60}{RESET}\n")

    content = question
    if file_path:
        content += f"\n\n[Attached file available at: {file_path}]"

    result = graph.invoke(
        {
            "messages": [HumanMessage(content=content)],
            "step_count": 0,
        },
        # Every agent round costs two graph super-steps (agent + tools).
        # LangGraph's default recursion limit of 25 is lower than what
        # MAX_STEPS rounds need, so without this the run would abort with a
        # recursion error before the MAX_STEPS guard in should_continue fires.
        config={"recursion_limit": 2 * MAX_STEPS + 5},
    )

    last_msg = result["messages"][-1]
    if isinstance(last_msg, AIMessage):
        raw = last_msg.content
        if isinstance(raw, str):
            text = raw
        else:
            # Message content may also be a list of content blocks (plain
            # strings or dicts); keep the textual parts only.
            text = "".join(
                part if isinstance(part, str) else str(part.get("text", ""))
                for part in raw
                if isinstance(part, (str, dict))
            )
    else:
        text = str(last_msg)

    # Use the LAST marker: the model may mention "FINAL ANSWER:" earlier in
    # its reasoning, and the prompt asks for the answer at the very end.
    markers = list(re.finditer(r"FINAL ANSWER:\s*", text, re.IGNORECASE))
    answer = text[markers[-1].end():].strip() if markers else text.strip()

    print(f"\n{BOLD}{GREEN}{'β'*60}{RESET}")
    print(f"{BOLD}{GREEN}π FINAL ANSWER: {answer}{RESET}")
    print(f"{BOLD}{GREEN}{'β'*60}{RESET}\n")
    return answer