File size: 2,814 Bytes
778116a 8073bab a4b0424 8073bab a4b0424 8073bab 778116a 8073bab 0c44617 8073bab 778116a 8073bab 0c44617 8073bab 778116a 5813885 8073bab 5813885 8073bab 778116a 0f45d0b 778116a a4b0424 778116a 8073bab 5813885 8073bab 778116a 0f45d0b 778116a 8073bab 778116a 8073bab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
from typing import Optional
from langchain_core.messages import HumanMessage
from langgraph.graph import START, StateGraph, END
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import ToolNode
from langgraph.prebuilt import tools_condition
from core.messages import Attachment
from core.state import State
from nodes.nodes import assistant, optimize_memory, response_processing, pre_processor, agent_tools
class GaiaAgent:
    """Question-answering agent backed by a compiled LangGraph ``StateGraph``.

    Control flow of the compiled graph::

        START -> pre_processor -> assistant
        assistant --(tools_condition == "tools")--> tools -> optimize_memory -> assistant
        assistant --(tools_condition == "__end__")--> response_processing -> END

    Calling an instance runs the graph on a question (plus an optional
    attachment) and returns the ``content`` of the last message in the final
    state.
    """

    # Passed to LangGraph as the ``recursion_limit`` run config. Shared by the
    # blocking and streaming entry points so both runs are bounded the same way.
    recursion_limit: int = 30

    react_graph: CompiledStateGraph

    def __init__(self):
        builder = StateGraph(State)

        # Nodes: these do the work.
        builder.add_node("pre_processor", pre_processor)
        builder.add_node("assistant", assistant)
        builder.add_node("tools", ToolNode(agent_tools))
        builder.add_node("optimize_memory", optimize_memory)
        builder.add_node("response_processing", response_processing)

        # Edges: these determine how control flows between nodes.
        builder.add_edge(START, "pre_processor")
        builder.add_edge("pre_processor", "assistant")
        # tools_condition yields "tools" when the assistant's latest message is
        # a tool call and "__end__" otherwise; the "__end__" outcome is routed
        # to response_processing instead of terminating the graph directly.
        builder.add_conditional_edges(
            "assistant",
            tools_condition,
            {"tools": "tools", "__end__": "response_processing"},
        )
        builder.add_edge("tools", "optimize_memory")
        builder.add_edge("optimize_memory", "assistant")
        builder.add_edge("response_processing", END)

        self.react_graph = builder.compile()

    def _build_initial_state(self, question: str, attachment: Optional[Attachment]) -> dict:
        """Return the graph input state for *question* and optional *attachment*.

        The state seeds ``messages`` with the question as a ``HumanMessage``,
        stores the raw ``question``, and, when an attachment is given, records
        its ``file_path`` under ``file_reference``.
        """
        initial_state = {"messages": [HumanMessage(content=question)], "question": question}
        if attachment is not None:
            initial_state["file_reference"] = attachment.file_path
        return initial_state

    def _run_config(self) -> dict:
        """Return the LangGraph run config used by every entry point."""
        return {"recursion_limit": self.recursion_limit}

    def __call__(self, question: str, attachment: Optional[Attachment] = None) -> str:
        """Run the graph to completion and return the final message's content.

        Args:
            question: The user question, sent as the first ``HumanMessage``.
            attachment: Optional attachment whose ``file_path`` is exposed to
                the graph as ``file_reference``.

        Returns:
            ``content`` of the last message in the final graph state
            (presumably a string answer — depends on response_processing).
        """
        final_state = self.react_graph.invoke(
            self._build_initial_state(question, attachment),
            config=self._run_config(),
        )
        return final_state["messages"][-1].content

    def __streamed_call__(self, question: str, attachment: Optional[Attachment] = None) -> str:
        """Run the graph while printing the latest message of every state snapshot.

        Args:
            question: The user question, sent as the first ``HumanMessage``.
            attachment: Optional attachment whose ``file_path`` is exposed to
                the graph as ``file_reference``.

        Returns:
            The content of the last message in the final streamed snapshot.

        Raises:
            RuntimeError: If the stream produced no snapshots at all.
        """
        last_message = None
        # stream_mode="values" yields the full state after each step; only its
        # most recent message is shown.
        for snapshot in self.react_graph.stream(
            self._build_initial_state(question, attachment),
            config=self._run_config(),
            stream_mode="values",
        ):
            last_message = snapshot["messages"][-1]
            if isinstance(last_message, tuple):
                print(last_message)
            else:
                last_message.pretty_print()

        if last_message is None:
            raise RuntimeError("Graph stream produced no state snapshots")
        if isinstance(last_message, tuple):
            # NOTE(review): assumes the LangChain (role, content) tuple form,
            # so the content is the last element — confirm against State.
            return last_message[-1]
        return last_message.content
|