File size: 5,395 Bytes
655b11f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a0669d3
 
3f2b048
655b11f
3f2b048
655b11f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a0669d3
655b11f
 
 
 
 
a0669d3
 
655b11f
cd6d75f
655b11f
 
 
 
 
 
3f2b048
cd6d75f
3f2b048
 
 
 
655b11f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import os
from typing import Optional, TypedDict, Literal
from langgraph.graph import MessagesState, StateGraph, START, END
from langgraph.prebuilt import ToolNode
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage
from logging_config import logger
from tools import (
    python_tool,
    reverse_tool,
    excel_file_to_markdown,
    sum_numbers,
    web_search,
    get_wikipedia_info,
    ask_audio_model
    )
from chess_tool import chess_tool

# Active LLM backend; must be "gemini" or "openai" (validated in the selection below).
# MODEL_PROVIDER = "gemini"
MODEL_PROVIDER = "openai" 

# Upper bound on agent/tool loop rounds before the graph forces a final answer.
MAX_ITERATIONS = 5

# GAIA-style answer-formatting instructions sent as the system message.
SYSTEM_PROMPT = \
"""You are a general AI assistant. This is a GAIA problem to solve, be succinct in your answer.
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. 
If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless 
specified otherwise. 
If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits 
in plain text unless specified otherwise. 
If you need to access a file, use the provided task_id as a parameter to the corresponding tool, unless a url is provided.
If you are asked for a comma separated list, apply the above rules depending of whether the element to be put 
in the list is a number or a string.
"""

# Gemini-backed chat model (used when MODEL_PROVIDER == "gemini").
llm_gemini = ChatGoogleGenerativeAI(
    model="gemini-2.5-flash",
    include_thoughts=False,
    temperature=0,  # deterministic output for benchmark-style answers
    max_output_tokens=None,
    timeout=60,  # The maximum number of seconds to wait for a response.
    max_retries=2,
)

# OpenAI-compatible model routed through the Hugging Face inference router
# (used when MODEL_PROVIDER == "openai"); authenticates with the HF_TOKEN env var.
llm_openai = ChatOpenAI(
    # model="openai/gpt-oss-120b:together",
    model="openai/gpt-oss-120b:fireworks-ai",
    temperature=0,
    max_tokens=None, # type: ignore
    timeout=60,
    max_retries=2,
    api_key=os.getenv("HF_TOKEN"),
    base_url="https://router.huggingface.co/v1",
)

# Pick the active model from MODEL_PROVIDER; fail fast on an unknown value.
if MODEL_PROVIDER == "gemini":
    llm = llm_gemini
elif MODEL_PROVIDER == "openai":
    llm = llm_openai
else:
    raise ValueError(f"Unsupported MODEL_PROVIDER: {MODEL_PROVIDER}")

# Tools exposed to the agent during the tool-use loop.
tools = [python_tool,
    reverse_tool,
    excel_file_to_markdown,
    sum_numbers,
    web_search,
    get_wikipedia_info,
    ask_audio_model,
    chess_tool]

# Model handle with tool schemas bound so the LLM can emit tool calls.
llm_with_tools = llm.bind_tools(tools)

class InputState(TypedDict):
    """Graph input channel: the user question plus the GAIA task identifier."""
    question: str  # question text handed to the agent
    task_id: str   # task id; the system prompt tells the model to pass it to file tools

# Define the state type with annotations
class AgentState(MessagesState):
    """Full graph state; MessagesState supplies the `messages` channel."""
    system_message: str    # system prompt text (NOTE(review): no visible node reads this — confirm it is used elsewhere)
    question: str          # the question being solved (read by agent for logging)
    task_id: str           # GAIA task identifier (read by agent for logging)
    final_answer: str      # set by final_output from the last message's content
    iterations: int        # agent/tool loop counter, compared against MAX_ITERATIONS
    error: Optional[str]   # error text when LLM invocation or extraction fails

class OutputState(TypedDict):
    """Graph output channel: the final answer, or an error description."""
    final_answer: str      # the model's final reply content
    error: Optional[str]   # populated instead of final_answer on failure

def input(state: InputState) -> AgentState:
    """Seed the conversation with the system prompt and the user's question,
    and start the iteration counter at zero."""
    seeded = [SystemMessage(content=SYSTEM_PROMPT)]
    seeded.append(HumanMessage(content=state["question"]))
    return {"iterations": 0,  # type: ignore
            "messages": seeded}

def agent(state: AgentState) -> AgentState:
    """Invoke the tool-bound LLM on the accumulated messages.

    Returns a partial state update: the messages list extended with the
    model's reply on success, or an ``error`` string on failure.
    """
    logger.info(f"LLM invoked: {state['question'][:50]=}{state['task_id']=}")
    try:
        result = llm_with_tools.invoke(state["messages"])
        logger.info(f"model metadata = {result.usage_metadata}") # type: ignore
        logger.info(f"LLM answer: {result.content}")
        # Append the new message to the messages list
        messages = state["messages"] + [result]
        return {"messages": messages} # type: ignore
    except Exception as e:
        logger.error(f"LLM invocation failed: {e}")
        # Stringified so it matches OutputState's Optional[str] error field.
        return {"error": str(e)} # type: ignore

def increment_iterations(state: AgentState) -> AgentState:
    """Bump the loop counter by one; a missing counter counts as zero."""
    return {"iterations": 1 + state.get("iterations", 0)}  # type: ignore

def route_tools(state: AgentState) -> Literal["tools", "final_output"]:
    """Choose the next node after the agent.

    Routes to "tools" only while the iteration budget is not exhausted and
    the last AI message actually requested tool calls; otherwise stops at
    "final_output".
    """
    last_message = state["messages"][-1]

    # Hard stop once the loop has run past the configured maximum.
    if state["iterations"] > MAX_ITERATIONS:
        return "final_output"

    # Continue only when the model emitted at least one tool call.
    if hasattr(last_message, "tool_calls") and len(last_message.tool_calls) > 0:  # type: ignore
        return "tools"
    return "final_output"

def final_output(state: AgentState) -> OutputState:
    """Extract the final answer from the last message in the state.

    Returns ``final_answer`` with the last message's content, or an
    ``error`` string if the messages channel is missing or empty.
    """
    try:
        messages = state["messages"]
        ai_message = messages[-1]
        return {"final_answer": ai_message.content} # type: ignore
    except Exception as e:
        # Stringify the exception: OutputState declares error as Optional[str]
        # and agent() already returns str(e) — the raw exception object here
        # was inconsistent with both.
        return {"error": str(e)} # type: ignore

# Wire the graph: input -> agent -> (tool loop via "increase"/"tools") or final_output.
builder = StateGraph(AgentState)
tool_node = ToolNode(tools=tools)
builder.add_node("input", input)
builder.add_node("agent", agent)
builder.add_node("increase", increment_iterations)
builder.add_node("tools", tool_node)
builder.add_node("final_output", final_output)
# Define edges for the standard flow
builder.add_edge(START, "input")
builder.add_edge("input", "agent")
builder.add_conditional_edges("agent",
                              route_tools,
                              {"tools": "increase",
                               "final_output": "final_output"}
                            )
# Each pass through the tool branch counts one iteration before executing tools.
builder.add_edge("increase", "tools")
builder.add_edge("tools", "agent")
builder.add_edge("final_output", END)
# Compile once; a previous duplicate compile whose result was discarded has been removed.
graph = builder.compile()