File size: 4,278 Bytes
6059138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from __future__ import annotations

from collections.abc import Sequence

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
from langchain.agents import create_agent

from memory_agent.config import AppConfig
from memory_agent.embeddings import MistralEmbedEmbeddings
from memory_agent.knowledge import FALLBACK_MESSAGE, UNDERSTOOD_MESSAGE, KnowledgeService
from memory_agent.llm import MistralLLMFactory
from memory_agent.prompt_loader import PromptLoader
from memory_agent.retrieval import HybridRetriever
from memory_agent.storage import FaissMemoryStore
from memory_agent.tools import KnowledgeToolFactory


class UniversalMemoryAgent:
    """Tool-calling memory agent with one lazily built graph per namespace.

    Wires the Mistral chat model, embedding backend, FAISS store, hybrid
    retriever, and knowledge tools together, then routes each turn through
    a LangChain agent graph scoped to the caller's namespace.
    """

    def __init__(self, config: AppConfig) -> None:
        self._config = config
        self._llm = MistralLLMFactory(config=config).create()
        prompts = PromptLoader().load()
        # Bake the contract messages ("understood"/fallback) into the
        # system prompt once, up front.
        self._system_prompt = prompts.system_prompt.format(
            understood_message=UNDERSTOOD_MESSAGE,
            fallback_message=FALLBACK_MESSAGE,
        )
        self._embeddings = MistralEmbedEmbeddings(
            model_name=config.embedding_model,
            api_token=config.mistral_api_key,
            max_retries=config.mistral_max_retries,
            base_delay_seconds=config.rate_limit_base_delay_seconds,
            max_delay_seconds=config.rate_limit_max_delay_seconds,
        )
        self._store = FaissMemoryStore(config=config, embeddings=self._embeddings)
        self._retriever = HybridRetriever(
            store=self._store,
            dense_weight=config.dense_weight,
            sparse_weight=config.sparse_weight,
        )
        self._knowledge_service = KnowledgeService(
            store=self._store,
            retriever=self._retriever,
            llm=self._llm,
            top_k=config.top_k,
            user_prompt_template=prompts.user_prompt,
        )
        self._tool_factory = KnowledgeToolFactory(service=self._knowledge_service)
        # Agent graphs are created on first use, keyed by namespace.
        self._graphs: dict[str, object] = {}

    def run(self, user_input: str, chat_history: Sequence[BaseMessage], namespace: str) -> str:
        """Execute one conversational turn and return the assistant's reply."""
        conversation = list(chat_history)
        conversation.append(HumanMessage(content=user_input))
        graph = self._get_or_create_graph(namespace=namespace)
        result = graph.invoke({"messages": conversation})
        return self._resolve_response(messages=result.get("messages", []))

    def _get_or_create_graph(self, namespace: str):
        """Return the cached agent graph for *namespace*, building it on first use."""
        graph = self._graphs.get(namespace)
        if graph is None:
            graph = create_agent(
                model=self._llm,
                tools=self._tool_factory.build(namespace=namespace),
                system_prompt=self._system_prompt,
            )
            self._graphs[namespace] = graph
        return graph

    @staticmethod
    def _resolve_response(messages: Sequence[BaseMessage]) -> str:
        """Derive the user-facing reply from the agent's message trace.

        Priority: the newest answer_from_knowledge tool result wins; failing
        that, any store_knowledge call yields the canonical "understood"
        reply; failing that, the newest non-empty assistant text is used,
        with the fallback message covering the remaining gap.
        """
        # Map tool-call ids to tool names so ToolMessages lacking a name
        # can still be classified.
        call_names: dict[str, str] = {}
        for msg in messages:
            if not isinstance(msg, AIMessage):
                continue
            for call in msg.tool_calls:
                name = call.get("name")
                call_id = call.get("id")
                if name and call_id:
                    call_names[str(call_id)] = str(name)

        stored = False
        for msg in reversed(messages):
            if not isinstance(msg, ToolMessage):
                continue
            name = msg.name or call_names.get(str(msg.tool_call_id))
            if name == "answer_from_knowledge":
                text = str(msg.content).strip()
                return text if text else FALLBACK_MESSAGE
            if name == "store_knowledge":
                stored = True

        if stored:
            return UNDERSTOOD_MESSAGE

        # Allow direct assistant text for social/onboarding turns chosen by the model.
        for msg in reversed(messages):
            if isinstance(msg, AIMessage):
                text = str(msg.content).strip()
                # Guard contract: "understood" is valid only after a
                # store_knowledge tool call, so skip a bare echo of it.
                if text and text != UNDERSTOOD_MESSAGE:
                    return text

        return FALLBACK_MESSAGE