from collections.abc import Sequence
from typing import Annotated, Literal

from langchain.chat_models import init_chat_model
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from pydantic import BaseModel, Field

from src.agents.models import FeasibilityCheck, FinalAnswer, FinalConclusion, NextStep
from src.agents.prompts import GAIAPrompts
from src.agents.tools import tools

# Initialize
model = init_chat_model("gemini-2.0-flash", model_provider="google_genai")
model_with_tools = model.bind_tools(tools)
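# model_with_tools binds the full toolset; the executor below re-binds a per-step subset of tools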

tools_by_name = {tool.name: tool for tool in tools}

prompts = GAIAPrompts()


# Graph state
class GraphState(BaseModel):
    """The state of the graph"""

    # History
    history: Annotated[Sequence[BaseMessage], add_messages] = Field(
        default_factory=list
    )  # Complete history with node info
    coordinator_messages: Annotated[Sequence[BaseMessage], add_messages] = Field(
        default_factory=list
    )  # Coordinator-specific messages
    executor_messages: Sequence[BaseMessage] = Field(default_factory=list)  # Executor-specific messages

    # Input
    question: str

    # Feasibility check
    feasibility: FeasibilityCheck | None = None

    # Coordinator state
    next_step: NextStep | None = None
    coordinator_conclusion: FinalConclusion | None = None
    coordinator_iterations: int
    coordinator_max_iterations: int

    # Executor state
    executor_conclusion: FinalConclusion | None = None
    executor_iterations: int
    executor_max_iterations: int

    # Final answer state
    final_answer: FinalAnswer | None = None

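    # Allow dict-style access (state["field"]) so node functions can read state uniformly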
    def __getitem__(self, item):
        return getattr(self, item)


# Nodes
def check_feasibility(state: GraphState, config: RunnableConfig):
    """Check if the question is feasible to answer with the available tools"""

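    # The extra `node=` attribute tags each message with the graph node that produced it,
    # so run_agent can label the printed history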
    question = state["question"]

    system_message = SystemMessage(content=prompts.get_feasibility_check_prompt(tools), node="feasibility")
    question_message = HumanMessage(content=question, node="feasibility")
    messages = [system_message, question_message]

    structured_model = model.with_structured_output(FeasibilityCheck)
    response = structured_model.invoke(messages, config)

    response_message = AIMessage(content=str(response), node="feasibility")
    messages += [response_message]
    return {
        "history": messages,
        "feasibility": response,
    }


def coordinator_node(state: GraphState, config: RunnableConfig):
    """Determine the next step in the plan and select appropriate tools"""

    coordinator_messages = list(state["coordinator_messages"])
    new_messages = []

    # On the first visit, seed the coordinator conversation with its system prompt and the question
    if not coordinator_messages:
        system_message = SystemMessage(content=prompts.get_coordinator_system_prompt(tools), node="coordinator")
        human_message = HumanMessage(
            content=prompts.get_coordinator_context_prompt(state["question"]), node="coordinator"
        )
        coordinator_messages = [system_message, human_message]
        new_messages = list(coordinator_messages)  # copy, so the appends below are not duplicated via aliasing

    # Feed the previous executor conclusion (if any) back to the coordinator
    if state["executor_conclusion"]:
        executor_message = AIMessage(
            content=f"Executor conclusion: {state['executor_conclusion'].conclusion}. Complete text: {str(state['executor_conclusion'])}",
            node="executor",
        )
        coordinator_messages.append(executor_message)
        new_messages.append(executor_message)

    # If the previous step was marked final, or the iteration budget is spent, wrap up with a conclusion
    if (state["next_step"] and state["next_step"].is_final) or (
        state["coordinator_iterations"] >= state["coordinator_max_iterations"]
    ):
        # Generate final conclusion instead of next step
        human_message = HumanMessage(
            content=prompts.get_coordinator_max_iterations_prompt(state["question"]), node="coordinator"
        )

        structured_model = model.with_structured_output(FinalConclusion)
        response = structured_model.invoke(coordinator_messages + [human_message], config)
        response_message = AIMessage(content=str(response), node="coordinator")

        new_messages += [human_message, response_message]
        return {
            "history": new_messages,
            "coordinator_messages": new_messages,
            "coordinator_conclusion": response,
            "coordinator_iterations": state["coordinator_iterations"] + 1,
        }

    structured_model = model.with_structured_output(NextStep)
    response = structured_model.invoke(coordinator_messages, config)

    response_message = AIMessage(content=str(response), node="coordinator")
    new_messages += [response_message]

    return {
        "history": new_messages,
        "coordinator_messages": new_messages,
        "coordinator_iterations": state["coordinator_iterations"] + 1,
        "next_step": response,
        "executor_messages": [],
        "executor_conclusion": None,
        "executor_iterations": 0,
    }


def executor_node(state: GraphState, config: RunnableConfig):
    """Plan the execution of the current step using ReAct pattern"""
    if not state["next_step"]:
        return {
            "executor_conclusion": FinalConclusion(conclusion="No next step", partial_results=""),
            "executor_iterations": state["executor_iterations"] + 1,
        }

    messages = state["executor_messages"]

    if not messages:
        system_message = SystemMessage(
            content=prompts.get_executor_system_prompt(state["next_step"].tools),
            node="executor",
        )
        human_message = HumanMessage(content=prompts.get_executor_task_prompt(state["next_step"].step), node="executor")
        messages = [system_message, human_message]

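    # When the step budget is exhausted, ask for a structured conclusion instead of another action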
    if state["executor_iterations"] >= state["executor_max_iterations"]:
        # Generate final conclusion and return to coordinator
        human_message = HumanMessage(
            content=prompts.get_executor_max_iterations_prompt(state["next_step"].step),
            node="executor",
        )

        messages += [human_message]

        structured_model = model.with_structured_output(FinalConclusion)
        response = structured_model.invoke(messages, config)

        response_message = AIMessage(
            content=f"Executor conclusion: {str(response)}",
            node="executor",
        )

        return {
            "history": [human_message, response_message],
            "executor_conclusion": response
            or FinalConclusion(conclusion="Failed to generate conclusion", partial_results=""),
            "executor_iterations": state["executor_iterations"] + 1,
        }

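    # Normal ReAct turn: bind only the tools the coordinator selected for this step and let the model act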
    selected_tools = [tool for tool in tools if tool.name in state["next_step"].tools]
    model_with_selected_tools = model.bind_tools(selected_tools)

    response_message = model_with_selected_tools.invoke(messages, config)
    response_message.node = "executor"

    return {
        "history": response_message,
        "executor_messages": messages + [response_message],
        "executor_iterations": state["executor_iterations"] + 1,
    }


def tool_node(state: GraphState):
    """Execute tools based on the last message's tool calls"""
    outputs = []
    messages = state["executor_messages"]
    last_message = state["executor_messages"][-1]

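    # Run every tool call from the last AI message; failures are reported as ToolMessages so the executor can react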
    for tool_call in last_message.tool_calls:
        try:
            tool_result = tools_by_name[tool_call["name"]].invoke(tool_call["args"])
            tool_message = ToolMessage(
                content=str(tool_result),
                name=tool_call["name"],
                tool_call_id=tool_call["id"],
                node="tools",
            )
            outputs.append(tool_message)
        except Exception as e:
            tool_message = ToolMessage(
                content=f"Error executing tool {tool_call['name']}: {str(e)}",
                name=tool_call["name"],
                tool_call_id=tool_call["id"],
                node="tools",
            )
            outputs.append(tool_message)

    return {
        "history": outputs,
        "executor_messages": messages + outputs,
    }


def finalise(state: GraphState, config: RunnableConfig):
    """Generate the final answer based on coordinator history"""
    system_message = SystemMessage(content=prompts.get_finalizer_prompt(), node="finalise")
    messages = [system_message] + state["coordinator_messages"]

    structured_model = model.with_structured_output(FinalAnswer)
    response = structured_model.invoke(messages, config)
    response_message = AIMessage(content=str(response), node="finalise")

    return {"history": response_message, "final_answer": response}


# Edges
def should_continue_after_feasibility(state: GraphState) -> Literal["coordinator", END]:
    """Decide whether to continue with coordination or end"""
    if state["feasibility"] and state["feasibility"].feasible:
        return "coordinator"
    return END


def should_continue_after_coordinator(state: GraphState) -> Literal["executor", "finalise"]:
    """Decide whether to continue with execution or go to final answer"""
    if state["coordinator_conclusion"] or (state["coordinator_iterations"] >= state["coordinator_max_iterations"]):
        return "finalise"
    return "executor"


def should_continue_after_executor(state: GraphState) -> Literal["tools", "coordinator", "executor"]:
    """Decide whether to run tools, return to the coordinator, or keep looping the executor"""
    messages = state["executor_messages"]
    last_message = messages[-1] if messages else None
    if last_message is not None and getattr(last_message, "tool_calls", None):
        return "tools"

    if state["executor_conclusion"]:
        return "coordinator"

    return "executor"


def should_continue_after_tools(state: GraphState) -> Literal["executor"]:
    """Tools always go back to executor"""
    return "executor"


# Graph
def build_graph():
    """Build the graph"""
    graph = StateGraph(GraphState)

    # Add nodes
    graph.add_node("check_feasibility", check_feasibility)
    graph.add_node("coordinator", coordinator_node)
    graph.add_node("executor", executor_node)
    graph.add_node("tools", tool_node)
    graph.add_node("finalise", finalise)

    # Set entry point
    graph.set_entry_point("check_feasibility")

    # Add edges
    graph.add_conditional_edges(
        "check_feasibility", should_continue_after_feasibility, {"coordinator": "coordinator", END: END}
    )
    graph.add_conditional_edges(
        "coordinator", should_continue_after_coordinator, {"executor": "executor", "finalise": "finalise"}
    )
    graph.add_conditional_edges(
        "executor",
        should_continue_after_executor,
        {"executor": "executor", "tools": "tools", "coordinator": "coordinator"},
    )
    graph.add_conditional_edges(
        "tools",
        should_continue_after_tools,
        {"executor": "executor"},
    )

    # Finalise node goes to END
    graph.add_edge("finalise", END)

    return graph.compile()


def run_agent(question: str, coordinator_max_iterations: int = 5, executor_max_iterations: int = 3):
    """Run the agent with a question"""
    graph = build_graph()

    initial_state = {
        "question": question,
        "history": [],
        "coordinator_messages": [],
        "executor_messages": [],
        "coordinator_iterations": 0,
        "executor_iterations": 0,
        "coordinator_max_iterations": coordinator_max_iterations,
        "executor_max_iterations": executor_max_iterations,
    }

    # Stream the execution
    print(f"Question: {question}")
    print("=" * 50)

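    # graph.stream() yields one update dict per executed node; each key below is a state field that node returned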
    for step in graph.stream(initial_state):
        for node, output in step.items():
            print(f"\n--- {node.upper()} ---")

            # Print history with node information
            if "history" in output and output["history"]:
                print("\nComplete History (with node info):")
                for msg in output["history"]:
                    node_info = getattr(msg, "node", "unknown") if hasattr(msg, "node") else "unknown"
                    content = getattr(msg, "content", str(msg)) if hasattr(msg, "content") else str(msg)
                    print(f"[{node_info}] {msg.__class__.__name__}: {content}")

            if "coordinator_messages" in output and output["coordinator_messages"]:
                print("\nCoordinator Messages:")
                for msg in output["coordinator_messages"]:
                    if hasattr(msg, "content"):
                        print(f"{msg.__class__.__name__}: {msg.content}")

            if "executor_messages" in output and output["executor_messages"]:
                print("\nExecutor Messages:")
                for msg in output["executor_messages"]:
                    if hasattr(msg, "content"):
                        print(f"{msg.__class__.__name__}: {msg.content}")

            if "executor_conclusion" in output and output["executor_conclusion"]:
                print("\n=== EXECUTOR CONCLUSION ===")
                print(f"Conclusion: {output['executor_conclusion'].conclusion}")
                print(f"Partial Results: {output['executor_conclusion'].partial_results}")
                print(f"Confidence: {output['executor_conclusion'].confidence}")

            if "final_answer" in output and output["final_answer"]:
                print("\n=== FINAL ANSWER ===")
                print(f"Answer: {output['final_answer'].answer}")
                print(f"Confidence: {output['final_answer'].confidence}")
                print(f"Reasoning: {output['final_answer'].reasoning}")