# pilgrim-65's picture
# README modified with clarifications
# a0669d3
import os
from typing import Optional, TypedDict, Literal
from langgraph.graph import MessagesState, StateGraph, START, END
from langgraph.prebuilt import ToolNode
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage
from logging_config import logger
from tools import (
python_tool,
reverse_tool,
excel_file_to_markdown,
sum_numbers,
web_search,
get_wikipedia_info,
ask_audio_model
)
from chess_tool import chess_tool
# Which LLM backend to use; consumed by the selection logic below.
# MODEL_PROVIDER = "gemini"
MODEL_PROVIDER = "openai"
# Maximum number of agent/tool round-trips before the graph is forced to stop.
MAX_ITERATIONS = 5
# System prompt prepended to every run: GAIA answer-formatting rules.
SYSTEM_PROMPT = \
"""You are a general AI assistant. This is a GAIA problem to solve, be succinct in your answer.
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless
specified otherwise.
If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits
in plain text unless specified otherwise.
If you need to access a file, use the provided task_id as a parameter to the corresponding tool, unless a url is provided.
If you are asked for a comma separated list, apply the above rules depending of whether the element to be put
in the list is a number or a string.
"""
# Gemini backend, used when MODEL_PROVIDER == "gemini".
llm_gemini = ChatGoogleGenerativeAI(
    model="gemini-2.5-flash",
    include_thoughts=False,  # do not return the model's internal reasoning
    temperature=0,  # deterministic output for reproducible answers
    max_output_tokens=None,  # no explicit output-length cap
    timeout=60,  # The maximum number of seconds to wait for a response.
    max_retries=2,
)
# OpenAI-compatible backend routed through the Hugging Face inference router,
# used when MODEL_PROVIDER == "openai".
llm_openai = ChatOpenAI(
    # model="openai/gpt-oss-120b:together",
    model="openai/gpt-oss-120b:fireworks-ai",
    temperature=0,  # deterministic output for reproducible answers
    max_tokens=None,  # type: ignore
    timeout=60,
    max_retries=2,
    # NOTE(review): getenv returns None when HF_TOKEN is unset — requests
    # would then fail at call time; confirm the deployment always sets it.
    api_key=os.getenv("HF_TOKEN"),
    base_url="https://router.huggingface.co/v1",
)
# Resolve the configured provider name to a concrete chat-model instance.
_provider_to_llm = {"gemini": llm_gemini, "openai": llm_openai}
llm = _provider_to_llm.get(MODEL_PROVIDER)
if llm is None:
    raise ValueError(f"Unsupported MODEL_PROVIDER: {MODEL_PROVIDER}")
# Every tool exposed to the model; ToolNode executes whichever one the LLM calls.
tools = [python_tool,
         reverse_tool,
         excel_file_to_markdown,
         sum_numbers,
         web_search,
         get_wikipedia_info,
         ask_audio_model,
         chess_tool]
# Model wrapper that lets the LLM emit tool calls for the tools above.
llm_with_tools = llm.bind_tools(tools)
# Keys the caller must supply when a graph run starts.
InputState = TypedDict("InputState", {"question": str, "task_id": str})
# Full internal graph state. Inherits the `messages` key (with message-append
# reducer semantics) from langgraph's MessagesState.
class AgentState(MessagesState):
    system_message: str  # system prompt text — not read by any node in this file; TODO confirm it is needed
    question: str  # original user question, supplied via InputState
    task_id: str  # GAIA task identifier; the prompt tells the model to pass it to file tools
    final_answer: str  # answer extracted by final_output
    iterations: int  # number of agent→tools round-trips taken so far
    error: Optional[str]  # error message when an LLM invocation fails
# Keys handed back to the caller when the graph finishes.
OutputState = TypedDict("OutputState", {"final_answer": str, "error": Optional[str]})
def input(state: InputState) -> AgentState:
    """Seed the conversation and reset the iteration counter.

    Builds the initial message list (system prompt + user question).
    NOTE: the name shadows the builtin `input`; it is kept because the
    graph registers this function as the "input" node.
    """
    seed_messages = [
        SystemMessage(content=SYSTEM_PROMPT),
        HumanMessage(content=state["question"]),
    ]
    return {"messages": seed_messages,  # type: ignore
            "iterations": 0}
def agent(state: AgentState) -> AgentState:
    """Invoke the tool-bound LLM on the conversation so far.

    Returns a partial state update: on success, `messages` with the new
    AI message appended; on failure, `error` holding the exception text.
    (Fix: removed the unused local `question`; narrowed the try body to
    the call that can actually raise.)
    """
    logger.info(f"LLM invoked: {state['question'][:50]=}{state['task_id']=}")
    try:
        result = llm_with_tools.invoke(state["messages"])
    except Exception as e:
        # Surface the failure through the state instead of crashing the run.
        logger.error(f"LLM invocation failed: {e}")
        return {"error": str(e)}  # type: ignore
    logger.info(f"model metadata = {result.usage_metadata}")  # type: ignore
    logger.info(f"LLM answer: {result.content}")
    # Append the new AI message to the running conversation.
    return {"messages": state["messages"] + [result]}  # type: ignore
def increment_iterations(state: AgentState) -> AgentState:
    """Bump the tool round-trip counter (a missing counter counts as 0)."""
    return {"iterations": state.get("iterations", 0) + 1}  # type: ignore
def route_tools(state: AgentState) -> Literal["tools", "final_output"]:
    """Routing predicate: run the requested tools next, or finish.

    Finishes when the iteration budget is exhausted or the latest AI
    message carries no tool calls.
    """
    if state["iterations"] > MAX_ITERATIONS:
        # Budget exhausted: stop even if the model asked for more tools.
        return "final_output"
    last_message = state["messages"][-1]
    pending_calls = getattr(last_message, "tool_calls", [])
    return "tools" if pending_calls else "final_output"
def final_output(state: AgentState) -> OutputState:
    """Extract the last AI message's content as the final answer.

    Returns {"final_answer": ...} on success, or {"error": <message>}
    if the message list is missing or empty.
    (Fix: the error is now stored as a string — OutputState declares
    `error: Optional[str]`, and `agent` reports failures the same way;
    the original stored the raw exception object.)
    """
    try:
        last_message = state["messages"][-1]
        return {"final_answer": last_message.content}  # type: ignore
    except Exception as e:
        return {"error": str(e)}  # type: ignore
# --- Graph assembly ---------------------------------------------------------
builder = StateGraph(AgentState)
tool_node = ToolNode(tools=tools)
builder.add_node("input", input)
builder.add_node("agent", agent)
builder.add_node("increase", increment_iterations)
builder.add_node("tools", tool_node)
builder.add_node("final_output", final_output)
# Standard flow: seed the messages, let the agent decide, and loop
# agent -> increase -> tools -> agent until route_tools says stop.
builder.add_edge(START, "input")
builder.add_edge("input", "agent")
builder.add_conditional_edges(
    "agent",
    route_tools,
    {"tools": "increase", "final_output": "final_output"},
)
builder.add_edge("increase", "tools")
builder.add_edge("tools", "agent")
builder.add_edge("final_output", END)
# Fix: compile exactly once — the original called builder.compile() twice
# and discarded the first result.
graph = builder.compile()