File size: 2,566 Bytes
f45999f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8477204
f45999f
 
8477204
 
 
 
 
f45999f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dfd09d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f45999f
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from langgraph.func import task
from langchain_core.messages import (
    SystemMessage,
    ToolMessage,
    HumanMessage,
    BaseMessage,
)
from langchain_core.runnables import RunnableConfig

from .memory_client import memory_client
from .tools import search_memories, search_vectorstore
from .llms import main_model, output_formatter_model
from .prompts import MAKHFI_AI_PROMPT, OUTPUT_FORMATTER_PROMPT
from .schemas import OutputFormat

# Tools exposed to the main model, plus a name -> tool lookup used by `call_tool`.
tools = [search_memories, search_vectorstore]
tools_by_name = {registered.name: registered for registered in tools}

# Main model with the tool schemas bound to it (see `call_model`).
agent_with_tools = main_model.bind_tools(tools)
# Formatter model whose output is parsed into the `OutputFormat` schema.
output_formatter = output_formatter_model.with_structured_output(OutputFormat)

@task
def call_model(messages: list[BaseMessage], memories: str):
    """Invoke the tool-enabled main model on the conversation.

    A system message built from ``MAKHFI_AI_PROMPT`` (with ``memories``
    substituted for its ``{memories}`` placeholder) is placed in front of
    ``messages`` before the call.

    Args:
        messages: Conversation history sent after the system prompt.
        memories: Text inserted into the system prompt.

    Returns:
        Whatever ``agent_with_tools.invoke`` returns (presumably an AI message
        that may carry tool calls — confirm against the model wrapper).
    """
    system_prompt = MAKHFI_AI_PROMPT.format(memories=memories)
    prompt_messages = [SystemMessage(content=system_prompt)] + messages
    return agent_with_tools.invoke(prompt_messages)

@task
def get_structued_output(agent_response: str) -> OutputFormat:
    """Convert the agent's free-text reply into an ``OutputFormat`` object.

    The reply is sent to the structured-output formatter as a human message,
    preceded by ``OUTPUT_FORMATTER_PROMPT`` as the system message.

    The misspelled function name ("structued") is kept as-is so existing
    callers keep working.

    Args:
        agent_response: Text produced by the main agent.

    Returns:
        The parsed ``OutputFormat`` instance from ``output_formatter``.
    """
    formatter_input = [
        SystemMessage(content=OUTPUT_FORMATTER_PROMPT),
        HumanMessage(content=agent_response),
    ]
    structured: OutputFormat = output_formatter.invoke(formatter_input)
    return structured

@task
def manage_memories(user_message: str, config: RunnableConfig):
    """Store the user's latest message via ``memory_client.add``.

    Args:
        user_message: Raw text of the user's message.
        config: Runnable config; ``config["configurable"]["user_id"]`` selects
            which user the memories belong to.

    Returns:
        The result of ``memory_client.add`` (requested with
        ``version="v2"`` and ``output_format="v1.1"``).

    Raises:
        ValueError: If no ``user_id`` is present in ``config["configurable"]``.
    """
    # `or {}` also covers an explicit `configurable=None`, not just a missing key.
    user_id = (config.get("configurable") or {}).get("user_id")

    # Refuse to write without an owner: otherwise `memory_client.add` would be
    # called with user_id=None and the memories would not be tied to this user.
    # Same check and message as in `get_recent_memories`.
    if not user_id:
        raise ValueError("User Id not found in config")

    message = [{"role": "user", "content": user_message}]
    return memory_client.add(
        message, user_id=user_id, version="v2", output_format="v1.1"
    )

@task
def get_recent_memories(config: RunnableConfig):
    """Return the user's 10 most recently updated memories as a bullet list.

    Memories are fetched with ``memory_client.get_all`` filtered by the
    ``user_id`` from ``config["configurable"]``. They are ordered by
    ``updated_at`` (largest value first) and rendered one per line as
    ``"- <memory>"``.

    Args:
        config: Runnable config carrying ``configurable.user_id``.

    Returns:
        A newline-joined bullet list, or a placeholder sentence when no
        memories are stored for the user.

    Raises:
        ValueError: If ``user_id`` is missing or empty.
    """
    configurable = config.get("configurable", {})
    user_id = configurable.get("user_id")
    if not user_id:
        raise ValueError("User Id not found in config")

    all_memories = memory_client.get_all(version="v2", filters={"user_id": user_id})
    if not all_memories:
        return "(No information is available about the user yet)"

    # NOTE(review): assumes `updated_at` values compare chronologically
    # (e.g. uniformly formatted ISO-8601 strings) and are never None — confirm.
    newest_first = sorted(
        all_memories, key=lambda entry: entry["updated_at"], reverse=True
    )
    bullet_lines = [f"- {entry['memory']}" for entry in newest_first[:10]]
    return "\n".join(bullet_lines)

@task
def call_tool(tool_call, config: RunnableConfig):
    """Execute one tool call requested by the model.

    Args:
        tool_call: Mapping with ``"name"``, ``"args"`` and ``"id"`` keys
            (presumably a LangChain ``ToolCall`` taken from the model's reply
            — confirm at the call site).
        config: Runnable config of the current run. It is forwarded to the
            tool so that its ``configurable`` values and callbacks reach the
            tool invocation.

    Returns:
        A ``ToolMessage`` wrapping the tool's output. It is linked to the
        originating request via ``tool_call_id``.

    Raises:
        KeyError: If the requested tool name is not in ``tools_by_name``.
    """
    tool = tools_by_name[tool_call["name"]]
    # Forward `config` explicitly. It was previously accepted but ignored, so
    # the tool could not see the caller's configurable values (e.g. user_id)
    # except through any ambient context propagation.
    observation = tool.invoke(tool_call["args"], config)
    return ToolMessage(content=observation, tool_call_id=tool_call["id"])