File size: 2,566 Bytes
f45999f 8477204 f45999f 8477204 f45999f dfd09d4 f45999f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
from langgraph.func import task
from langchain_core.messages import (
SystemMessage,
ToolMessage,
HumanMessage,
BaseMessage,
)
from langchain_core.runnables import RunnableConfig
from .memory_client import memory_client
from .tools import search_memories, search_vectorstore
from .llms import main_model, output_formatter_model
from .prompts import MAKHFI_AI_PROMPT, OUTPUT_FORMATTER_PROMPT
from .schemas import OutputFormat
# Tools the agent model is allowed to call; `call_tool` resolves each
# requested call through `tools_by_name` using the tool's `name`.
tools = [search_memories, search_vectorstore]
tools_by_name = {registered.name: registered for registered in tools}

# Main chat model with the tool schemas above bound to it.
agent_with_tools = main_model.bind_tools(tools)
# Formatter model whose replies are parsed into an `OutputFormat` instance.
output_formatter = output_formatter_model.with_structured_output(OutputFormat)
@task
def call_model(messages: list[BaseMessage], memories: str):
    """Run the tool-enabled agent model over the conversation.

    A system message rendered from ``MAKHFI_AI_PROMPT`` (with ``memories``
    substituted into its ``{memories}`` placeholder) is prepended to
    ``messages`` before the model is invoked.

    Args:
        messages: Conversation history sent after the system prompt.
        memories: Text inserted into the prompt's ``memories`` field.

    Returns:
        The result of ``agent_with_tools.invoke`` for the combined prompt.
    """
    system_text = MAKHFI_AI_PROMPT.format(memories=memories)
    prompt = [SystemMessage(content=system_text)] + messages
    return agent_with_tools.invoke(prompt)
@task
def get_structued_output(agent_response: str) -> OutputFormat:
    """Convert the agent's free-text reply into an ``OutputFormat`` object.

    The reply is sent to ``output_formatter`` (the formatter model wrapped
    with ``with_structured_output(OutputFormat)``) as a human message,
    preceded by ``OUTPUT_FORMATTER_PROMPT`` as the system message.

    NOTE: the function name is misspelled ("structued"); it is presumably
    referenced by that name elsewhere, so it is left unchanged.

    Args:
        agent_response: The agent's textual answer to be structured.

    Returns:
        The parsed ``OutputFormat`` produced by the formatter model.
    """
    formatter_input = [
        SystemMessage(content=OUTPUT_FORMATTER_PROMPT),
        HumanMessage(content=agent_response),
    ]
    structured: OutputFormat = output_formatter.invoke(formatter_input)
    return structured
@task
def manage_memories(user_message: str, config: RunnableConfig):
    """Send the user's message to the memory service for storage.

    Args:
        user_message: Raw text of the user's message.
        config: Runnable config; ``config["configurable"]["user_id"]`` is
            passed to the memory client as the owner of the memories.

    Returns:
        The response of ``memory_client.add``.

    Raises:
        ValueError: If no ``user_id`` is present in ``config["configurable"]``.
    """
    user_id = config.get("configurable", {}).get("user_id")
    # Fail fast (same check and message as `get_recent_memories`) instead of
    # calling `memory_client.add` with `user_id=None`.
    if not user_id:
        raise ValueError("User Id not found in config")
    message = [{"role": "user", "content": user_message}]
    return memory_client.add(
        message, user_id=user_id, version="v2", output_format="v1.1"
    )
@task
def get_recent_memories(config: RunnableConfig, limit: int = 10):
    """Return the user's most recently updated memories as a bullet list.

    Args:
        config: Runnable config; ``config["configurable"]["user_id"]``
            selects whose memories are fetched.
        limit: Maximum number of memories to include (default 10).
            Negative values are treated as 0.

    Returns:
        One ``"- <memory>"`` line per memory, ordered by ``updated_at``
        descending, or a placeholder sentence when the user has no memories.

    Raises:
        ValueError: If no ``user_id`` is present in ``config["configurable"]``.
    """
    user_id = config.get("configurable", {}).get("user_id")
    if not user_id:
        raise ValueError("User Id not found in config")
    memories = memory_client.get_all(version="v2", filters={"user_id": user_id})
    if not memories:
        return "(No information is available about the user yet)"
    # Newest first. Assumes `updated_at` values compare chronologically
    # (e.g. ISO-8601 strings in one consistent format) — TODO confirm.
    sorted_memories = sorted(memories, key=lambda m: m["updated_at"], reverse=True)
    # Clamp so a negative limit cannot turn the slice into "all but the last".
    recent_memories = sorted_memories[: max(limit, 0)]
    return "\n".join(f"- {mem['memory']}" for mem in recent_memories)
@task
def call_tool(tool_call, config: RunnableConfig):
    """Execute one tool call requested by the model.

    Args:
        tool_call: Mapping with ``"name"``, ``"args"`` and ``"id"`` keys,
            as emitted in the model's tool calls.
        config: Runnable config forwarded to the tool invocation.

    Returns:
        A ``ToolMessage`` with the tool's output, linked to the call by id.

    Raises:
        KeyError: If the requested tool name is not in ``tools_by_name``.
    """
    tool = tools_by_name[tool_call["name"]]
    # Forward `config` so tools that declare a `RunnableConfig` parameter
    # receive the caller's configuration (e.g. its `configurable` values).
    observation = tool.invoke(tool_call["args"], config=config)
    return ToolMessage(content=observation, tool_call_id=tool_call["id"])
|