# carolinacon — Modified the Gradio interface (commit 0c44617)
from typing import Optional
from langchain_core.messages import HumanMessage
from langgraph.graph import START, StateGraph, END
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import ToolNode
from langgraph.prebuilt import tools_condition
from core.messages import Attachment
from core.state import State
from nodes.nodes import assistant, optimize_memory, response_processing, pre_processor, agent_tools
class GaiaAgent:
    """ReAct-style agent driven by a compiled LangGraph state machine.

    Graph flow: START -> pre_processor -> assistant; from the assistant,
    ``tools_condition`` routes tool-call messages to tools
    (tools -> optimize_memory -> assistant loop) and plain answers to
    response_processing -> END.
    """

    # Compiled LangGraph state machine; built once in __init__.
    react_graph: CompiledStateGraph

    def __init__(self):
        """Assemble and compile the agent's state graph."""
        builder = StateGraph(State)

        # Define nodes: these do the work.
        builder.add_node("pre_processor", pre_processor)
        builder.add_node("assistant", assistant)
        builder.add_node("tools", ToolNode(agent_tools))
        builder.add_node("optimize_memory", optimize_memory)
        builder.add_node("response_processing", response_processing)

        # Define edges: these determine how the control flow moves.
        builder.add_edge(START, "pre_processor")
        builder.add_edge("pre_processor", "assistant")
        builder.add_conditional_edges(
            "assistant",
            # If the latest assistant message is a tool call, tools_condition
            # routes to "tools"; otherwise it routes to response_processing.
            tools_condition,
            {"tools": "tools", "__end__": "response_processing"},
        )
        builder.add_edge("tools", "optimize_memory")
        builder.add_edge("optimize_memory", "assistant")
        builder.add_edge("response_processing", END)

        self.react_graph = builder.compile()

    def _initial_state(self, question: str, attachment: Optional[Attachment]) -> dict:
        """Build the graph's initial state, adding a file reference when an attachment is given."""
        state = {"messages": [HumanMessage(content=question)], "question": question}
        if attachment:
            state["file_reference"] = attachment.file_path
        return state

    def __call__(self, question: str, attachment: Optional[Attachment] = None) -> str:
        """Run the graph to completion and return the final answer text.

        Args:
            question: The user's question, sent as the first HumanMessage.
            attachment: Optional file attachment; its path is placed in the
                state under "file_reference".

        Returns:
            The content of the last message produced by the graph.
        """
        initial_state = self._initial_state(question, attachment)
        # recursion_limit bounds the assistant<->tools loop so a misbehaving
        # tool chain cannot iterate forever.
        result = self.react_graph.invoke(initial_state, {"recursion_limit": 30})
        return result["messages"][-1].content

    def __streamed_call__(self, question: str, attachment: Optional[Attachment] = None) -> str:
        """Stream intermediate graph messages to stdout and return the final answer text.

        Raises:
            RuntimeError: If the stream yields no messages at all (previously
                this surfaced as an UnboundLocalError).
        """
        initial_state = self._initial_state(question, attachment)
        message = None
        # Stream the agent's response, printing each intermediate message.
        for step in self.react_graph.stream(initial_state, stream_mode="values"):
            message = step["messages"][-1]
            if isinstance(message, tuple):
                print(message)
            else:
                message.pretty_print()
        if message is None:
            # Fix: the original would raise UnboundLocalError here.
            raise RuntimeError("Graph stream produced no messages")
        return message.content