|
|
from typing import Optional |
|
|
|
|
|
from langchain_core.messages import HumanMessage |
|
|
from langgraph.graph import START, StateGraph, END |
|
|
from langgraph.graph.state import CompiledStateGraph |
|
|
from langgraph.prebuilt import ToolNode |
|
|
from langgraph.prebuilt import tools_condition |
|
|
|
|
|
from core.messages import Attachment |
|
|
from core.state import State |
|
|
from nodes.nodes import assistant, optimize_memory, response_processing, pre_processor, agent_tools |
|
|
|
|
|
|
|
|
class GaiaAgent:
    """ReAct-style agent built on a compiled LangGraph state graph.

    Graph topology:
        START -> pre_processor -> assistant
        assistant --(tool calls pending)--> tools -> optimize_memory -> assistant
        assistant --(no tool calls)--> response_processing -> END

    Call the instance like a function (``agent(question)``) for a blocking
    run, or use ``__streamed_call__`` to print intermediate messages as the
    graph executes.
    """

    # Compiled graph shared by both entry points; built once in __init__.
    react_graph: CompiledStateGraph

    # Depth limit applied to every run. Kept in one place so the blocking
    # and streaming entry points cannot drift apart.
    _RECURSION_LIMIT = 30

    def __init__(self) -> None:
        """Assemble and compile the agent's state graph."""
        builder = StateGraph(State)

        # Nodes.
        builder.add_node("pre_processor", pre_processor)
        builder.add_node("assistant", assistant)
        builder.add_node("tools", ToolNode(agent_tools))
        builder.add_node("optimize_memory", optimize_memory)
        builder.add_node("response_processing", response_processing)

        # Linear entry into the assistant.
        builder.add_edge(START, "pre_processor")
        builder.add_edge("pre_processor", "assistant")

        # tools_condition routes to "tools" when the assistant's last message
        # contains tool calls and to "__end__" otherwise; "__end__" is remapped
        # so the final answer passes through response_processing before END.
        builder.add_conditional_edges(
            "assistant",
            tools_condition,
            {"tools": "tools", "__end__": "response_processing"},
        )

        # Tool results are memory-optimized before looping back to the assistant.
        builder.add_edge("tools", "optimize_memory")
        builder.add_edge("optimize_memory", "assistant")
        builder.add_edge("response_processing", END)

        self.react_graph = builder.compile()

    def _initial_state(self, question: str, attachment: Optional[Attachment]) -> dict:
        """Build the initial graph state for *question*.

        When *attachment* is provided, its file path is exposed to the graph
        under the ``file_reference`` key.
        """
        state = {"messages": [HumanMessage(content=question)], "question": question}
        if attachment:
            state["file_reference"] = attachment.file_path
        return state

    def __call__(self, question: str, attachment: Optional[Attachment] = None) -> str:
        """Run the graph to completion and return the final answer text.

        Args:
            question: The user's question.
            attachment: Optional file attachment made available to the graph.

        Returns:
            The content of the last message produced by the graph.
        """
        final_state = self.react_graph.invoke(
            self._initial_state(question, attachment),
            {"recursion_limit": self._RECURSION_LIMIT},
        )
        return final_state["messages"][-1].content

    def __streamed_call__(self, question: str, attachment: Optional[Attachment] = None) -> str:
        """Run the graph, printing each intermediate message, and return the answer.

        Behaves like ``__call__`` but streams state snapshots and pretty-prints
        the latest message of each one as it arrives.

        Raises:
            RuntimeError: If the stream produces no messages.
        """
        message = None
        for snapshot in self.react_graph.stream(
            self._initial_state(question, attachment),
            # Same depth limit as __call__; previously the streaming path ran
            # with the library default, so the two entry points disagreed.
            {"recursion_limit": self._RECURSION_LIMIT},
            stream_mode="values",
        ):
            message = snapshot["messages"][-1]
            if isinstance(message, tuple):
                print(message)
            else:
                message.pretty_print()

        if message is None:
            # Previously this surfaced as an opaque NameError on `message`.
            raise RuntimeError("Graph stream produced no messages")
        return message.content
|
|
|