"""LangGraph ReAct-style agent for the GAIA benchmark.

Wires a tool-calling OpenAI chat model (via LangChain) to search, file-reading,
audio-transcription, vision and Python-execution tools, with a hard step budget
and colorized console logging of every agent/tool turn.
"""

import ast
import base64
import contextlib
import io
import os
import re
import traceback
from typing import Annotated, Literal, TypedDict

from langchain.tools import tool
from langchain_community.document_loaders import ArxivLoader, WikipediaLoader
from langchain_core.messages import (
    AIMessage,
    AnyMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from openai import OpenAI
from tavily import TavilyClient

OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY")

# Module-level clients; constructed eagerly at import time, so missing keys
# surface on first actual call, not here.
tavily_client = TavilyClient(api_key=TAVILY_API_KEY)
openai_client = OpenAI(api_key=OPENAI_API_KEY)

# Hard cap on agent turns — enforced in should_continue() and advertised to the
# model inside SYSTEM_PROMPT so it budgets its tool calls.
MAX_STEPS = 15


@tool
def search_wikipedia(query: str, max_docs: int = 3) -> str:
    """Search Wikipedia for general knowledge and return summarized content.

    Args:
        query: Topic to search (e.g., 'Artificial Intelligence', 'France history')
        max_docs: Maximum number of Wikipedia pages to retrieve
    """
    loader = WikipediaLoader(query=query, load_max_docs=max_docs)
    docs = loader.load()
    # Truncate each page to keep the tool result within the model's context.
    return "\n\n".join(doc.page_content[:3000] for doc in docs)


@tool
def search_arxiv(query: str, max_docs: int = 3) -> str:
    """Search arXiv for scientific papers and return summaries.

    Args:
        query: Research topic or keywords (e.g., 'transformer attention')
        max_docs: Maximum number of papers to retrieve
    """
    loader = ArxivLoader(query=query, load_max_docs=max_docs)
    docs = loader.load()
    return "\n\n".join(doc.page_content[:3000] for doc in docs)


@tool
def search_web(query: str, max_results: int = 5) -> str:
    """Search the web for up-to-date information.

    Args:
        query: Search query (e.g., 'latest OpenAI model 2025')
        max_results: Number of results to return
    """
    response = tavily_client.search(query=query, max_results=max_results)
    results = [f"{r['title']}\n{r['content']}" for r in response["results"]]
    return "\n\n".join(results)


@tool
def transcribe_audio(file_path: str) -> str:
    """Transcribe an audio file (mp3, wav) into text.

    Args:
        file_path: Path to the audio file on disk
    """
    with open(file_path, "rb") as f:
        transcript = openai_client.audio.transcriptions.create(
            model="whisper-1",
            file=f,
        )
    return transcript.text


@tool
def read_image(file_path: str) -> str:
    """Read an image file and return a description via GPT-4o vision.

    Args:
        file_path: Path to the image file on disk
    """
    with open(file_path, "rb") as f:
        b64 = base64.b64encode(f.read()).decode("utf-8")
    # Guess the MIME type from the extension; default to PNG for unknown ones.
    ext = file_path.rsplit(".", 1)[-1].lower()
    mime = {
        "jpg": "image/jpeg",
        "jpeg": "image/jpeg",
        "png": "image/png",
        "gif": "image/gif",
        "webp": "image/webp",
    }.get(ext, "image/png")
    response = openai_client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": f"data:{mime};base64,{b64}"}},
                    {"type": "text", "text": "Describe this image in detail. Extract any text, data, or key information visible."},
                ],
            }
        ],
        max_tokens=1024,
    )
    return response.choices[0].message.content


@tool
def read_file(file_path: str) -> str:
    """Read a file and return its contents."""
    with open(file_path, "r", encoding="utf-8") as f:
        return f.read()


@tool
def python_repl(code: str) -> str:
    """Execute Python code and return stdout + the value of the last expression. Useful for arithmetic, data manipulation, and logic tasks.

    Args:
        code: Valid Python code string
    """
    # SECURITY NOTE: exec() on model-generated code is inherently unsafe; this
    # is acceptable only inside a trusted benchmark sandbox.
    stdout_capture = io.StringIO()
    namespace: dict = {}
    last_val = ""
    try:
        # Parse once so a trailing expression can be evaluated separately.
        # (The previous implementation re-ran the raw last line after exec'ing
        # the whole snippet, duplicating any side effects such as prints.)
        tree = ast.parse(code)
        trailing: ast.Expression | None = None
        if tree.body and isinstance(tree.body[-1], ast.Expr):
            trailing = ast.Expression(tree.body[-1].value)
            tree.body = tree.body[:-1]
        with contextlib.redirect_stdout(stdout_capture):
            # Single dict for globals AND locals: with separate dicts,
            # functions defined by the snippet cannot see sibling definitions.
            exec(compile(tree, "<python_repl>", "exec"), namespace)
            if trailing is not None:
                last_val = repr(eval(compile(trailing, "<python_repl>", "eval"), namespace))
    except Exception:
        # Return the full traceback so the model can self-correct its code.
        return traceback.format_exc()
    out = stdout_capture.getvalue()
    return "\n".join(filter(None, [out, last_val])) or "Code executed successfully (no output)."


TOOLS = [
    search_wikipedia,
    search_arxiv,
    search_web,
    transcribe_audio,
    read_image,
    read_file,
    python_repl,
]
TOOL_MAP = {t.name: t for t in TOOLS}

SYSTEM_PROMPT = f"""You are a highly capable AI assistant solving tasks from the GAIA benchmark.

## Core rules (MUST follow)
1. THINK before acting: decompose the question and plan which tool(s) you need.
2. NEVER call the same tool with the exact same arguments twice. If the result was insufficient, use a DIFFERENT query or a DIFFERENT tool.
3. If search_wikipedia returns a biography page instead of a discography/list, immediately switch to search_web with a more specific query.
4. For calculations / counting, always use python_repl — never guess numbers.
5. Once you have enough information, STOP calling tools and give the final answer.
6. You have at most {MAX_STEPS} tool-call rounds total. Budget them wisely.

## Tool selection guide
- General facts / biography → search_wikipedia (vary query if first try fails)
- Discographies, filmographies, lists → search_web (Wikipedia tool may miss these)
- Current events / live data → search_web
- Scientific papers → search_arxiv
- Arithmetic / logic → python_repl
- Provided image file → read_image
- Provided audio file → transcribe_audio
- Provided text/csv/json → read_file

## Answer format
End your FINAL response with exactly:
FINAL ANSWER:
Keep it concise — no units unless asked, lists comma-separated.
"""


class AgentState(TypedDict):
    """LangGraph state: accumulated conversation plus a step counter."""

    messages: Annotated[list[AnyMessage], add_messages]
    step_count: int  # counts agent_node invocations


def make_llm(model: str = "gpt-5.4-mini") -> ChatOpenAI:
    """Build a deterministic chat model with all tools bound.

    Args:
        model: OpenAI model name to use.
    """
    return ChatOpenAI(
        model=model,
        temperature=0,  # deterministic answers for benchmark reproducibility
        api_key=OPENAI_API_KEY,
    ).bind_tools(TOOLS)


llm_with_tools = make_llm()

_step = 0  # console display counter (distinct from AgentState.step_count)

# ANSI escape codes for console logging.
CYAN = "\033[96m"
GREEN = "\033[92m"
YELLOW = "\033[93m"
RED = "\033[91m"
BOLD = "\033[1m"
RESET = "\033[0m"


def _log(label: str, text: str, color: str = RESET) -> None:
    """Print a labeled, colorized log section for the current display step."""
    print(f"{color}{'─'*60}{RESET}")
    print(f"{color}[Step {_step}] {label}{RESET}")
    if text.strip():
        print(f"{color}{text.strip()}{RESET}")


def agent_node(state: AgentState) -> AgentState:
    """LLM turn: inject the system prompt, warn near the step limit, invoke the model.

    Returns a partial state update: the model's response message (merged via
    add_messages) and the incremented step counter.
    """
    global _step
    _step += 1
    step_count = state.get("step_count", 0) + 1
    messages = state["messages"]
    # Inject system prompt on first turn.  Because the SystemMessage is never
    # persisted into state, this re-prepends it on every invocation.
    if not any(isinstance(m, SystemMessage) for m in messages):
        messages = [SystemMessage(content=SYSTEM_PROMPT)] + messages
    # Warn model to wrap up when approaching the limit.
    if step_count >= MAX_STEPS - 2:
        messages = list(messages) + [HumanMessage(
            content=f"⚠️ You have used {step_count}/{MAX_STEPS} steps. "
            "Do NOT call any more tools. Synthesise what you have and give FINAL ANSWER now."
        )]
    _log("🤖 AGENT THINKING …", "", CYAN)
    response = llm_with_tools.invoke(messages)
    if response.content:
        _log("🤖 AGENT RESPONSE", str(response.content)[:600], CYAN)
    if response.tool_calls:
        calls_summary = "\n".join(
            f"  • {tc['name']}({', '.join(f'{k}={repr(v)}' for k, v in tc['args'].items())})"
            for tc in response.tool_calls
        )
        _log("🔧 TOOL CALLS PLANNED", calls_summary, YELLOW)
    else:
        _log("✅ AGENT FINISHED (no more tool calls)", "", GREEN)
    return {"messages": [response], "step_count": step_count}


def tool_node(state: AgentState) -> AgentState:
    """Execute every tool call from the last AI message and return ToolMessages.

    Tool failures are captured as ERROR strings (never raised) so the agent
    can read the error and recover with a different tool or query.
    """
    global _step
    last_msg: AIMessage = state["messages"][-1]
    tool_results: list[ToolMessage] = []
    for tc in last_msg.tool_calls:
        _step += 1
        tool_fn = TOOL_MAP.get(tc["name"])
        _log(f"⚙️ RUNNING: {tc['name']}",
             "\n".join(f"  {k}: {repr(v)}" for k, v in tc["args"].items()),
             YELLOW)
        if tool_fn is None:
            result = f"ERROR: unknown tool '{tc['name']}'"
            _log("❌ TOOL ERROR", result, RED)
        else:
            try:
                result = tool_fn.invoke(tc["args"])
                preview = str(result)[:500] + ("…" if len(str(result)) > 500 else "")
                _log(f"📥 RESULT: {tc['name']}", preview, GREEN)
            except Exception as exc:
                result = f"ERROR calling {tc['name']}: {exc}"
                _log(f"❌ TOOL ERROR: {tc['name']}", result, RED)
        tool_results.append(
            ToolMessage(content=str(result), tool_call_id=tc["id"])
        )
    return {"messages": tool_results}


def should_continue(state: AgentState) -> Literal["tools", "end"]:
    """Route to the tool node while tool calls are pending and the budget allows."""
    step_count = state.get("step_count", 0)
    if step_count >= MAX_STEPS:
        # Force termination even if the last message still has pending tool calls.
        print(f"{RED}{'─'*60}")
        print(f"⛔ MAX_STEPS ({MAX_STEPS}) reached — forcing end.{RESET}")
        return "end"
    last = state["messages"][-1]
    if isinstance(last, AIMessage) and last.tool_calls:
        return "tools"
    return "end"


def build_graph():
    """Assemble and compile the agent ↔ tools loop.

    Returns the compiled LangGraph (agent → tools → agent until the agent
    stops calling tools or the step budget runs out).
    """
    g = StateGraph(AgentState)
    g.add_node("agent", agent_node)
    g.add_node("tools", tool_node)
    g.set_entry_point("agent")
    g.add_conditional_edges("agent", should_continue, {"tools": "tools", "end": END})
    g.add_edge("tools", "agent")  # always return to agent after tool use
    return g.compile()


graph = build_graph()


def run_agent(question: str, file_path: str | None = None) -> str:
    """Run the agent on a GAIA question and return the extracted final answer."""
    global _step
    _step = 0
    print(f"\n{BOLD}{'═'*60}{RESET}")
    print(f"{BOLD}❓ QUESTION: {question}{RESET}")
    if file_path:
        print(f"{BOLD}📎 FILE: {file_path}{RESET}")
    print(f"{BOLD}{'═'*60}{RESET}\n")
    content = question
    if file_path:
        # Tell the model where the attachment lives so it can pick the right
        # file tool (read_file / read_image / transcribe_audio).
        content += f"\n\n[Attached file available at: {file_path}]"
    result = graph.invoke({
        "messages": [HumanMessage(content=content)],
        "step_count": 0,
    })
    last_msg = result["messages"][-1]
    text = last_msg.content if isinstance(last_msg, AIMessage) else str(last_msg)
    # DOTALL so multi-line answers after the marker are captured whole.
    match = re.search(r"FINAL ANSWER:\s*(.+)", text, re.IGNORECASE | re.DOTALL)
    answer = match.group(1).strip() if match else text.strip()
    print(f"\n{BOLD}{GREEN}{'═'*60}{RESET}")
    print(f"{BOLD}{GREEN}🏁 FINAL ANSWER: {answer}{RESET}")
    print(f"{BOLD}{GREEN}{'═'*60}{RESET}\n")
    return answer