"""LangGraph agent that answers GAIA-style questions using an LLM and tools.

Flow: ``input`` -> ``agent`` -> (``increase`` -> ``tools`` -> ``agent``)* ->
``final_output``. The loop ends when the model stops requesting tools, when
the LLM call fails, or when ``MAX_ITERATIONS`` tool rounds have been used.
"""

import os
from typing import Any, Literal, Optional, TypedDict

from langgraph.graph import MessagesState, StateGraph, START, END
from langgraph.prebuilt import ToolNode
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage

from logging_config import logger
from tools import (
    python_tool,
    reverse_tool,
    excel_file_to_markdown,
    sum_numbers,
    web_search,
    get_wikipedia_info,
    ask_audio_model,
)
from chess_tool import chess_tool

# Which chat model backs the agent: "gemini" or "openai".
MODEL_PROVIDER = "openai"

# Maximum number of tool rounds the agent may run for one question.
MAX_ITERATIONS = 5

# NOTE(review): sentence separators are reproduced as single spaces; confirm
# against the original layout if line breaks in the prompt matter.
SYSTEM_PROMPT = (
    "You are a general AI assistant. This is a GAIA problem to solve, "
    "be succinct in your answer. "
    "YOUR FINAL ANSWER should be a number OR as few words as possible OR "
    "a comma separated list of numbers and/or strings. "
    "If you are asked for a number, don't use comma to write your number "
    "neither use units such as $ or percent sign unless specified otherwise. "
    "If you are asked for a string, don't use articles, neither abbreviations "
    "(e.g. for cities), and write the digits in plain text unless specified "
    "otherwise. "
    "If you need to access a file, use the provided task_id as a parameter to "
    "the corresponding tool, unless a url is provided. "
    "If you are asked for a comma separated list, apply the above rules "
    "depending of whether the element to be put in the list is a number or "
    "a string. "
)

# NOTE(review): both clients are constructed at import time regardless of
# MODEL_PROVIDER, so credentials for both providers may be needed -- confirm.
llm_gemini = ChatGoogleGenerativeAI(
    model="gemini-2.5-flash",
    include_thoughts=False,
    temperature=0,
    max_output_tokens=None,
    timeout=60,  # The maximum number of seconds to wait for a response.
    max_retries=2,
)

llm_openai = ChatOpenAI(
    # Alternative route: model="openai/gpt-oss-120b:together"
    model="openai/gpt-oss-120b:fireworks-ai",
    temperature=0,
    max_tokens=None,  # type: ignore
    timeout=60,
    max_retries=2,
    api_key=os.getenv("HF_TOKEN"),
    base_url="https://router.huggingface.co/v1",
)

if MODEL_PROVIDER == "gemini":
    llm = llm_gemini
elif MODEL_PROVIDER == "openai":
    llm = llm_openai
else:
    raise ValueError(f"Unsupported MODEL_PROVIDER: {MODEL_PROVIDER}")

tools = [
    python_tool,
    reverse_tool,
    excel_file_to_markdown,
    sum_numbers,
    web_search,
    get_wikipedia_info,
    ask_audio_model,
    chess_tool,
]
llm_with_tools = llm.bind_tools(tools)


class InputState(TypedDict):
    """Fields read by the ``input`` node."""

    question: str
    task_id: str


class AgentState(MessagesState):
    """Working state shared by all nodes (adds fields to ``MessagesState``)."""

    system_message: str
    question: str
    task_id: str
    final_answer: str
    iterations: int  # number of tool rounds started so far
    error: Optional[str]  # set by ``agent`` when the LLM call fails


class OutputState(TypedDict):
    """Fields written by the ``final_output`` node."""

    final_answer: str
    error: Optional[str]


# The name shadows the ``input`` builtin; it is kept so existing references to
# this module-level function keep working.
def input(state: InputState) -> dict[str, Any]:
    """Seed the conversation with the system prompt and the user's question.

    Returns a state update with the initial ``messages`` and ``iterations``
    reset to 0.
    """
    # NOTE(review): SYSTEM_PROMPT tells the model to use "the provided
    # task_id", but only the question text is placed in the messages here;
    # confirm the caller embeds the task_id in the question.
    question = state["question"]
    messages = [
        SystemMessage(content=SYSTEM_PROMPT),
        HumanMessage(content=question),
    ]
    return {"messages": messages, "iterations": 0}


def agent(state: AgentState) -> dict[str, Any]:
    """Call the tool-enabled LLM on the conversation so far.

    Returns ``{"messages": [reply]}`` on success, or ``{"error": <text>}`` if
    the invocation raises. Only the new reply is returned: the ``messages``
    channel of ``MessagesState`` merges updates into the existing history.
    """
    logger.info(f"LLM invoked: {state['question'][:50]=}{state['task_id']=}")
    try:
        result = llm_with_tools.invoke(state["messages"])
    except Exception as e:
        # Boundary with the remote model: record the failure in the state so
        # the graph can finish cleanly instead of crashing.
        logger.error(f"LLM invocation failed: {e}")
        return {"error": str(e)}
    logger.info(f"model metadata = {result.usage_metadata}")  # type: ignore
    logger.info(f"LLM answer: {result.content}")
    return {"messages": [result]}


def increment_iterations(state: AgentState) -> dict[str, Any]:
    """Bump the tool-round counter; runs just before every ``tools`` step."""
    return {"iterations": state.get("iterations", 0) + 1}


def route_tools(state: AgentState) -> Literal["tools", "final_output"]:
    """Decide whether to run tools again or finish.

    Returns ``"final_output"`` when the LLM call failed, when
    ``MAX_ITERATIONS`` tool rounds have already been started, or when the
    last message requests no tools; otherwise returns ``"tools"``.
    """
    if state.get("error"):
        return "final_output"
    # ``iterations`` counts rounds already started, so ``>=`` caps the run at
    # exactly MAX_ITERATIONS rounds (``>`` would allow one extra).
    if state.get("iterations", 0) >= MAX_ITERATIONS:
        return "final_output"
    messages = state.get("messages") or []
    if messages and getattr(messages[-1], "tool_calls", None):
        return "tools"
    return "final_output"


def _message_text(content: Any) -> str:
    """Flatten message content (a string or a list of parts) to plain text."""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = []
        for part in content:
            if isinstance(part, str):
                parts.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                parts.append(part["text"])
        return "".join(parts)
    return str(content)


def final_output(state: AgentState) -> OutputState:
    """Build the graph's result from the final state.

    If the LLM call failed, or there is no message to read, the answer is
    empty and ``error`` carries a string description. Otherwise
    ``final_answer`` is the text of the last message.
    """
    error = state.get("error")
    if error:
        # The last message is then the question or a tool result, not an
        # answer, so it must not be reported as one.
        return {"final_answer": "", "error": str(error)}
    messages = state.get("messages") or []
    if not messages:
        return {"final_answer": "", "error": "No messages to build an answer from"}
    return {"final_answer": _message_text(getattr(messages[-1], "content", ""))}  # type: ignore


builder = StateGraph(AgentState)
tool_node = ToolNode(tools=tools)

builder.add_node("input", input)
builder.add_node("agent", agent)
builder.add_node("increase", increment_iterations)
builder.add_node("tools", tool_node)
builder.add_node("final_output", final_output)

# Define edges for the standard flow
builder.add_edge(START, "input")
builder.add_edge("input", "agent")
builder.add_conditional_edges(
    "agent",
    route_tools,
    # The "tools" decision goes through "increase" first so every tool round
    # is counted.
    {"tools": "increase", "final_output": "final_output"},
)
builder.add_edge("increase", "tools")
builder.add_edge("tools", "agent")
builder.add_edge("final_output", END)

graph = builder.compile()