# NOTE(review): the lines below were non-Python residue captured from a web
# scrape of the Hugging Face Spaces file view (runtime status, file size,
# commit hashes, and the line-number gutter). They broke the module at import
# time, so they are preserved here as comments.
#   Spaces: Runtime error / Runtime error
#   File size: 5,395 Bytes
#   Commits: 655b11f a0669d3 3f2b048 cd6d75f
import os
from typing import Optional, TypedDict, Literal
from langgraph.graph import MessagesState, StateGraph, START, END
from langgraph.prebuilt import ToolNode
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage
from logging_config import logger
from tools import (
python_tool,
reverse_tool,
excel_file_to_markdown,
sum_numbers,
web_search,
get_wikipedia_info,
ask_audio_model
)
from chess_tool import chess_tool
# Which chat backend to build the agent on: "gemini" or "openai"
# (anything else raises ValueError at import time, see below).
# MODEL_PROVIDER = "gemini"
MODEL_PROVIDER = "openai"
# Hard cap on agent <-> tools round trips; checked in route_tools before
# another tool execution is allowed.
MAX_ITERATIONS = 5
# System prompt prepended to every run; encodes the GAIA answer-format rules
# (bare numbers, minimal strings, comma-separated lists).
SYSTEM_PROMPT = \
"""You are a general AI assistant. This is a GAIA problem to solve, be succinct in your answer.
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless
specified otherwise.
If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits
in plain text unless specified otherwise.
If you need to access a file, use the provided task_id as a parameter to the corresponding tool, unless a url is provided.
If you are asked for a comma separated list, apply the above rules depending of whether the element to be put
in the list is a number or a string.
"""
# Gemini backend: deterministic (temperature=0), thoughts suppressed.
llm_gemini = ChatGoogleGenerativeAI(
    model="gemini-2.5-flash",
    include_thoughts=False,
    temperature=0,
    max_output_tokens=None,
    timeout=60,  # The maximum number of seconds to wait for a response.
    max_retries=2,
)
# OpenAI-compatible backend routed through the Hugging Face inference router.
llm_openai = ChatOpenAI(
    # model="openai/gpt-oss-120b:together",
    model="openai/gpt-oss-120b:fireworks-ai",
    temperature=0,
    max_tokens=None,  # type: ignore
    timeout=60,
    max_retries=2,
    # NOTE(review): os.getenv returns None when HF_TOKEN is unset, which would
    # only surface as an auth error at request time — confirm the env var is set.
    api_key=os.getenv("HF_TOKEN"),
    base_url="https://router.huggingface.co/v1",
)
# Select the active model at import time; fail fast on a bad provider name.
if MODEL_PROVIDER == "gemini":
    llm = llm_gemini
elif MODEL_PROVIDER == "openai":
    llm = llm_openai
else:
    raise ValueError(f"Unsupported MODEL_PROVIDER: {MODEL_PROVIDER}")
# Tools the agent may request; the same list feeds both bind_tools (so the
# model can emit tool calls) and the ToolNode that executes them.
tools = [python_tool,
         reverse_tool,
         excel_file_to_markdown,
         sum_numbers,
         web_search,
         get_wikipedia_info,
         ask_audio_model,
         chess_tool]
# Model handle with the tool schemas attached.
llm_with_tools = llm.bind_tools(tools)
class InputState(TypedDict):
    """Graph entry payload: the question text and the GAIA task id."""
    question: str
    # Used by file-access tools to locate the task's attachment (per SYSTEM_PROMPT).
    task_id: str
# Define the state type with annotations
class AgentState(MessagesState):
    """Full working state of the agent graph.

    Inherits the accumulating ``messages`` list from MessagesState.
    """
    # NOTE(review): never read or written in this file — possibly dead; confirm
    # against other modules before removing.
    system_message: str
    question: str
    task_id: str
    final_answer: str
    # Tool-loop counter, bumped by increment_iterations and capped by route_tools.
    iterations: int
    error: Optional[str]
class OutputState(TypedDict):
    """Graph exit payload: the model's final answer, or an error description."""
    final_answer: str
    error: Optional[str]
def input(state: InputState) -> AgentState:
    """Seed the conversation with the system prompt and the user's question,
    and reset the tool-loop counter.

    Note: the node name shadows the ``input`` builtin; kept because the graph
    registers it by this name.
    """
    seed = [
        SystemMessage(content=SYSTEM_PROMPT),
        HumanMessage(content=state["question"]),
    ]
    return {"iterations": 0,  # type: ignore
            "messages": seed}
def agent(state: AgentState) -> AgentState:
    """Run one LLM turn over the accumulated conversation.

    On success, appends the model's reply (which may carry tool calls) to
    ``messages``; on failure, records the error text instead of raising so the
    graph can still route to a terminal node.
    """
    logger.info(f"LLM invoked: {state['question'][:50]=}{state['task_id']=}")
    try:
        result = llm_with_tools.invoke(state["messages"])
        logger.info(f"model metadata = {result.usage_metadata}")  # type: ignore
        logger.info(f"LLM answer: {result.content}")
        # Append the new message to the messages list
        messages = state["messages"] + [result]
        return {"messages": messages}  # type: ignore
    except Exception as e:
        logger.error(f"LLM invocation failed: {e}")
        return {"error": str(e)}  # type: ignore
def increment_iterations(state: AgentState) -> AgentState:
    """Bump the tool-loop counter by one; a missing counter counts as zero."""
    return {"iterations": state.get("iterations", 0) + 1}  # type: ignore
def route_tools(state: AgentState) -> Literal["tools", "final_output"]:
    """Route after an LLM turn: run requested tools, or stop.

    Stops when the iteration cap is exceeded, or when the last message
    carries no tool calls.
    """
    if state["iterations"] > MAX_ITERATIONS:
        # Hard stop: too many agent <-> tools round trips.
        return "final_output"
    last_message = state["messages"][-1]
    pending_calls = getattr(last_message, "tool_calls", None)  # type: ignore
    return "tools" if pending_calls else "final_output"
def final_output(state: AgentState) -> OutputState:
    """Extract the final answer from the last message in the conversation.

    Returns ``{"final_answer": ...}`` on success. On failure (e.g. empty
    ``messages``) returns ``{"error": ...}`` as a string — ``OutputState.error``
    is declared ``Optional[str]``, and this matches how ``agent`` records
    errors; the original returned the raw exception object.
    """
    try:
        ai_message = state["messages"][-1]
        return {"final_answer": ai_message.content}  # type: ignore
    except Exception as e:
        return {"error": str(e)}  # type: ignore
# ---- Graph assembly --------------------------------------------------------
builder = StateGraph(AgentState)
tool_node = ToolNode(tools=tools)
# Nodes: seed messages -> LLM turn -> (count + run tools)* -> final answer.
builder.add_node("input", input)
builder.add_node("agent", agent)
builder.add_node("increase", increment_iterations)
builder.add_node("tools", tool_node)
builder.add_node("final_output", final_output)
# Define edges for the standard flow
builder.add_edge(START, "input")
builder.add_edge("input", "agent")
# After each LLM turn, either execute the requested tools (through the
# iteration counter) or emit the final answer.
builder.add_conditional_edges("agent",
                              route_tools,
                              {"tools": "increase",
                               "final_output": "final_output"}
                              )
builder.add_edge("increase", "tools")
builder.add_edge("tools", "agent")
builder.add_edge("final_output", END)
# Compile exactly once. The original called builder.compile() twice, discarding
# the first result, and the line carried a trailing "|" scrape artifact that
# made the module a syntax error — both removed.
graph = builder.compile()