Spaces:
Runtime error
Runtime error
| import os | |
| from typing import Optional, TypedDict, Literal | |
| from langgraph.graph import MessagesState, StateGraph, START, END | |
| from langgraph.prebuilt import ToolNode | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langchain_openai import ChatOpenAI | |
| from langchain_core.messages import HumanMessage, SystemMessage | |
| from logging_config import logger | |
| from tools import ( | |
| python_tool, | |
| reverse_tool, | |
| excel_file_to_markdown, | |
| sum_numbers, | |
| web_search, | |
| get_wikipedia_info, | |
| ask_audio_model | |
| ) | |
| from chess_tool import chess_tool | |
# Which chat model backs the agent: "openai" (gpt-oss through the Hugging Face
# router) or "gemini". Override with the MODEL_PROVIDER environment variable
# instead of editing this file; defaults to "openai".
MODEL_PROVIDER = os.getenv("MODEL_PROVIDER", "openai").strip().lower()
# Iteration cap on tool-calling rounds, checked by route_tools.
MAX_ITERATIONS = 5
# GAIA answer-format instructions sent as the system message on every run.
SYSTEM_PROMPT = """You are a general AI assistant. This is a GAIA problem to solve, be succinct in your answer.
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless
specified otherwise.
If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits
in plain text unless specified otherwise.
If you need to access a file, use the provided task_id as a parameter to the corresponding tool, unless a url is provided.
If you are asked for a comma separated list, apply the above rules depending of whether the element to be put
in the list is a number or a string.
"""
def _build_gemini_llm() -> ChatGoogleGenerativeAI:
    """Create the Gemini chat client (reads its API key at construction)."""
    return ChatGoogleGenerativeAI(
        model="gemini-2.5-flash",
        include_thoughts=False,
        temperature=0,
        max_output_tokens=None,
        timeout=60,  # The maximum number of seconds to wait for a response.
        max_retries=2,
    )


def _build_openai_llm() -> ChatOpenAI:
    """Create the OpenAI-compatible client routed through the Hugging Face router."""
    return ChatOpenAI(
        model="openai/gpt-oss-120b:fireworks-ai",
        temperature=0,
        max_tokens=None,  # type: ignore
        timeout=60,
        max_retries=2,
        api_key=os.getenv("HF_TOKEN"),
        base_url="https://router.huggingface.co/v1",
    )


# Provider name -> factory. Only the selected client is constructed, so a
# missing credential for the *unused* provider can no longer break import.
_LLM_FACTORIES = {
    "gemini": _build_gemini_llm,
    "openai": _build_openai_llm,
}

if MODEL_PROVIDER not in _LLM_FACTORIES:
    raise ValueError(f"Unsupported MODEL_PROVIDER: {MODEL_PROVIDER}")
llm = _LLM_FACTORIES[MODEL_PROVIDER]()

# Kept for backward compatibility with code that imported these names; the
# provider that is not selected is no longer built and is therefore None.
llm_gemini = llm if MODEL_PROVIDER == "gemini" else None
llm_openai = llm if MODEL_PROVIDER == "openai" else None
# Every tool the model is allowed to call; this same list is handed to the
# ToolNode that executes the calls.
tools = [
    python_tool,
    reverse_tool,
    excel_file_to_markdown,
    sum_numbers,
    web_search,
    get_wikipedia_info,
    ask_audio_model,
    chess_tool,
]
# Chat model with the tool schemas attached so its replies may carry tool calls.
llm_with_tools = llm.bind_tools(tools)
class InputState(TypedDict, total=True):
    """Fields read by the ``input`` node: the task text and its task id."""

    question: str  # the task the agent must answer
    task_id: str  # id the system prompt asks the model to pass to file tools
# Full graph state: the chat history inherited from MessagesState plus the
# bookkeeping fields the graph's nodes read and write.
class AgentState(MessagesState):
    """State shared by every node of the agent graph."""

    system_message: str  # NOTE(review): not read or written by any node in the visible code
    question: str  # task text supplied as graph input
    task_id: str  # task identifier supplied as graph input
    final_answer: str  # written by final_output
    iterations: int  # tool-round counter: set to 0 by input, bumped by increment_iterations
    error: Optional[str]  # failure message written by agent / final_output
class OutputState(TypedDict, total=True):
    """Result written by the ``final_output`` node."""

    final_answer: str  # text of the model's last reply
    error: Optional[str]  # failure description, if any
# NOTE: shadows the builtin ``input``; the name is kept because the graph
# builder registers this node by referencing it.
def input(state: InputState) -> AgentState:
    """Seed the conversation for a new task.

    Args:
        state: Graph input with ``question`` and (optionally) ``task_id``.

    Returns:
        Partial state with the system + human messages and ``iterations`` reset to 0.
    """
    question = state["question"]
    task_id = state.get("task_id")
    # SYSTEM_PROMPT tells the model to pass the task_id to file tools, so the
    # id must actually appear in the conversation; skip it if already present.
    if task_id and task_id not in question:
        question = f"{question}\n\ntask_id: {task_id}"
    messages = [
        SystemMessage(content=SYSTEM_PROMPT),
        HumanMessage(content=question),
    ]
    return {"messages": messages, "iterations": 0}  # type: ignore
def agent(state: AgentState) -> AgentState:
    """Run one model turn over the conversation so far.

    The model is bound to the agent's tools, so its reply may request tool calls.

    Args:
        state: Current graph state; reads ``messages``, ``question`` and ``task_id``.

    Returns:
        ``{"messages": [reply]}`` on success (MessagesState's reducer appends
        the reply to the history), or ``{"error": <message>}`` if the call raised.
    """
    logger.info(f"LLM invoked: {state['question'][:50]=}, {state['task_id']=}")
    try:
        result = llm_with_tools.invoke(state["messages"])
    except Exception as e:  # node boundary: report through state, don't crash the graph
        logger.exception(f"LLM invocation failed: {e}")
        return {"error": str(e)}  # type: ignore
    logger.info(f"model metadata = {result.usage_metadata}")  # type: ignore
    logger.info(f"LLM answer: {result.content}")
    # Return only the new message; the add_messages reducer appends it.
    return {"messages": [result]}  # type: ignore
def increment_iterations(state: AgentState) -> AgentState:
    """Advance the tool-round counter by one (an absent counter counts as 0)."""
    current = state.get("iterations", 0)
    return {"iterations": current + 1}  # type: ignore
def route_tools(state: AgentState) -> Literal["tools", "final_output"]:
    """Choose the step that follows an ``agent`` turn.

    Returns ``"final_output"`` when the LLM call failed (``error`` is set),
    when ``MAX_ITERATIONS`` tool rounds have already run, or when the latest
    message requests no tools; otherwise ``"tools"``.
    """
    # After a failed LLM call the last message is the human prompt or a tool
    # result rather than a model reply, so there is nothing to route on.
    if state.get("error"):
        return "final_output"
    # ``iterations`` is bumped once before every tool round, so ``>=`` caps the
    # run at MAX_ITERATIONS rounds (``>`` allowed one extra round).
    if state.get("iterations", 0) >= MAX_ITERATIONS:
        return "final_output"
    last_message = state["messages"][-1]
    if getattr(last_message, "tool_calls", None):
        return "tools"
    return "final_output"  # no tool calls: the model has answered
def _message_text(content) -> str:
    """Flatten message content (a string or a list of content parts) to plain text."""
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        pieces = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict) and part.get("type", "text") == "text":
                pieces.append(str(part.get("text", "")))
        return "".join(pieces)
    return str(content)


def final_output(state: AgentState) -> OutputState:
    """Extract the final answer from the model's last reply.

    Returns:
        ``{"final_answer": ...}`` on success. ``error`` is also set (and the
        answer left empty) when an earlier node failed, and it is set alongside
        the partial answer when the iteration cap stopped the run with tool
        calls still pending.
    """
    previous_error = state.get("error")
    if previous_error:
        # The last message is not a model answer after an LLM failure.
        return {"final_answer": "", "error": str(previous_error)}
    try:
        ai_message = state["messages"][-1]
        answer = _message_text(ai_message.content)
    except Exception as e:  # best-effort node: report instead of crashing
        logger.exception(f"Could not extract final answer: {e}")
        return {"final_answer": "", "error": str(e)}
    if getattr(ai_message, "tool_calls", None):
        # route_tools stopped on the iteration cap, not on a finished answer.
        return {
            "final_answer": answer,
            "error": f"Iteration limit (MAX_ITERATIONS={MAX_ITERATIONS}) reached with tool calls still pending",
        }
    return {"final_answer": answer}  # type: ignore
# Agent loop: input -> agent -> (increase -> tools -> agent)* -> final_output.
builder = StateGraph(AgentState)
tool_node = ToolNode(tools=tools)

builder.add_node("input", input)
builder.add_node("agent", agent)
builder.add_node("increase", increment_iterations)
builder.add_node("tools", tool_node)
builder.add_node("final_output", final_output)

builder.add_edge(START, "input")
builder.add_edge("input", "agent")
# route_tools decides whether the agent's reply needs tool execution.
builder.add_conditional_edges(
    "agent",
    route_tools,
    {"tools": "increase", "final_output": "final_output"},
)
# The counter is bumped before every tool round.
builder.add_edge("increase", "tools")
builder.add_edge("tools", "agent")
builder.add_edge("final_output", END)

# Compile once; the previous extra builder.compile() result was discarded.
graph = builder.compile()