"""Universal memory agent: a per-namespace LangChain agent over a FAISS-backed knowledge store."""

from __future__ import annotations

from collections.abc import Sequence

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage

from langchain.agents import create_agent

from memory_agent.config import AppConfig
from memory_agent.embeddings import MistralEmbedEmbeddings
from memory_agent.knowledge import FALLBACK_MESSAGE, UNDERSTOOD_MESSAGE, KnowledgeService
from memory_agent.llm import MistralLLMFactory
from memory_agent.prompt_loader import PromptLoader
from memory_agent.retrieval import HybridRetriever
from memory_agent.storage import FaissMemoryStore
from memory_agent.tools import KnowledgeToolFactory


class UniversalMemoryAgent:
    """Conversational agent that stores and retrieves knowledge per namespace.

    Wires a Mistral LLM, embeddings, a FAISS store, a hybrid (dense + sparse)
    retriever, and a knowledge service into tool-calling agent graphs built
    with :func:`create_agent`. One graph is built lazily per namespace and
    cached for the lifetime of this instance.
    """

    def __init__(self, config: AppConfig) -> None:
        """Build the full pipeline from *config*.

        Args:
            config: Application configuration supplying the API key, model
                names, retry/rate-limit settings, retrieval weights, and
                ``top_k``.
        """
        self._config = config
        self._llm = MistralLLMFactory(config=config).create()
        prompt_set = PromptLoader().load()
        # The system prompt template embeds the two canonical reply strings so
        # the model and _resolve_response agree on the exact wording.
        self._system_prompt = prompt_set.system_prompt.format(
            understood_message=UNDERSTOOD_MESSAGE,
            fallback_message=FALLBACK_MESSAGE,
        )
        self._embeddings = MistralEmbedEmbeddings(
            model_name=config.embedding_model,
            api_token=config.mistral_api_key,
            max_retries=config.mistral_max_retries,
            base_delay_seconds=config.rate_limit_base_delay_seconds,
            max_delay_seconds=config.rate_limit_max_delay_seconds,
        )
        self._store = FaissMemoryStore(config=config, embeddings=self._embeddings)
        self._retriever = HybridRetriever(
            store=self._store,
            dense_weight=config.dense_weight,
            sparse_weight=config.sparse_weight,
        )
        self._knowledge_service = KnowledgeService(
            store=self._store,
            retriever=self._retriever,
            llm=self._llm,
            top_k=config.top_k,
            user_prompt_template=prompt_set.user_prompt,
        )
        self._tool_factory = KnowledgeToolFactory(service=self._knowledge_service)
        # Cache of compiled agent graphs keyed by namespace (see
        # _get_or_create_graph). Values are whatever create_agent returns.
        self._graphs: dict[str, object] = {}

    def run(self, user_input: str, chat_history: Sequence[BaseMessage], namespace: str) -> str:
        """Run one conversational turn and return the agent's reply text.

        Args:
            user_input: The new user utterance for this turn.
            chat_history: Prior turn messages; prepended verbatim before the
                new ``HumanMessage``.
            namespace: Key selecting (or lazily creating) the agent graph and
                its tool bindings.

        Returns:
            The resolved assistant reply (see ``_resolve_response``).
        """
        graph = self._get_or_create_graph(namespace=namespace)
        state = graph.invoke({"messages": [*chat_history, HumanMessage(content=user_input)]})
        # The graph is expected to return its transcript under "messages";
        # default to an empty list so resolution degrades to the fallback.
        return self._resolve_response(messages=state.get("messages", []))

    def _get_or_create_graph(self, namespace: str):
        """Return the cached agent graph for *namespace*, building it on first use."""
        if namespace not in self._graphs:
            # Tools are bound per namespace so stored knowledge stays isolated.
            tools = self._tool_factory.build(namespace=namespace)
            self._graphs[namespace] = create_agent(
                model=self._llm,
                tools=tools,
                system_prompt=self._system_prompt,
            )
        return self._graphs[namespace]

    @staticmethod
    def _resolve_response(messages: Sequence[BaseMessage]) -> str:
        """Derive the user-facing reply from a finished agent transcript.

        Resolution order:
          1. The newest ``answer_from_knowledge`` tool result (its stripped
             content, or ``FALLBACK_MESSAGE`` if empty).
          2. ``UNDERSTOOD_MESSAGE`` if any ``store_knowledge`` call occurred.
          3. The newest non-empty direct assistant message, unless it merely
             parrots ``UNDERSTOOD_MESSAGE`` without a store call.
          4. ``FALLBACK_MESSAGE``.

        Args:
            messages: Full message transcript returned by the graph.

        Returns:
            The reply string chosen by the rules above.
        """
        # Map tool_call id -> tool name from the AI messages, because a
        # ToolMessage's own `name` field may be unset by some integrations.
        tool_by_call_id: dict[str, str] = {}
        for message in messages:
            if isinstance(message, AIMessage):
                for tool_call in message.tool_calls:
                    tool_name = tool_call.get("name")
                    tool_call_id = tool_call.get("id")
                    if tool_name and tool_call_id:
                        tool_by_call_id[str(tool_call_id)] = str(tool_name)
        saw_store_call = False
        # NOTE(review): this scans the *entire* transcript newest-first, which
        # includes prior turns passed in via chat_history. If the latest turn
        # only stored knowledge but an earlier turn answered a question, the
        # earlier answer_from_knowledge result wins over UNDERSTOOD_MESSAGE —
        # confirm that is the intended precedence.
        for message in reversed(messages):
            if isinstance(message, ToolMessage):
                tool_name = message.name or tool_by_call_id.get(str(message.tool_call_id))
                if tool_name == "answer_from_knowledge":
                    content = str(message.content).strip()
                    return content or FALLBACK_MESSAGE
                if tool_name == "store_knowledge":
                    # Don't return yet: a newer answer tool result (if any)
                    # has already won above; keep scanning older messages.
                    saw_store_call = True
        if saw_store_call:
            return UNDERSTOOD_MESSAGE
        # Allow direct assistant text for social/onboarding turns chosen by the model.
        for message in reversed(messages):
            if isinstance(message, AIMessage):
                content = str(message.content).strip()
                if content:
                    # Guard contract: "understood" is valid only after store_knowledge tool call.
                    if content == UNDERSTOOD_MESSAGE:
                        continue
                    return content
        return FALLBACK_MESSAGE