| from __future__ import annotations |
|
|
| from collections.abc import Sequence |
|
|
| from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage |
| from langchain.agents import create_agent |
|
|
| from memory_agent.config import AppConfig |
| from memory_agent.embeddings import MistralEmbedEmbeddings |
| from memory_agent.knowledge import FALLBACK_MESSAGE, UNDERSTOOD_MESSAGE, KnowledgeService |
| from memory_agent.llm import MistralLLMFactory |
| from memory_agent.prompt_loader import PromptLoader |
| from memory_agent.retrieval import HybridRetriever |
| from memory_agent.storage import FaissMemoryStore |
| from memory_agent.tools import KnowledgeToolFactory |
|
|
|
|
class UniversalMemoryAgent:
    """Agent facade that answers a user turn through per-namespace tool graphs.

    The constructor wires together the LLM, the formatted system prompt, the
    embeddings client, the memory store, the hybrid retriever, the knowledge
    service and the tool factory from a single ``AppConfig``. Agent graphs are
    built lazily, one per namespace, and cached on the instance.
    """

    def __init__(self, config: AppConfig) -> None:
        """Build every collaborator the agent needs from ``config``.

        Args:
            config: Application configuration. Supplies the embedding model
                name, API key, retry/back-off settings, dense/sparse retrieval
                weights and ``top_k``; it is also handed to the LLM factory
                and the memory store.
        """
        self._config = config
        self._llm = MistralLLMFactory(config=config).create()
        prompt_set = PromptLoader().load()
        # The system prompt template embeds the two canonical reply strings so
        # the model and `_resolve_response` agree on their exact wording.
        self._system_prompt = prompt_set.system_prompt.format(
            understood_message=UNDERSTOOD_MESSAGE,
            fallback_message=FALLBACK_MESSAGE,
        )
        self._embeddings = MistralEmbedEmbeddings(
            model_name=config.embedding_model,
            api_token=config.mistral_api_key,
            max_retries=config.mistral_max_retries,
            base_delay_seconds=config.rate_limit_base_delay_seconds,
            max_delay_seconds=config.rate_limit_max_delay_seconds,
        )
        self._store = FaissMemoryStore(config=config, embeddings=self._embeddings)
        self._retriever = HybridRetriever(
            store=self._store,
            dense_weight=config.dense_weight,
            sparse_weight=config.sparse_weight,
        )
        self._knowledge_service = KnowledgeService(
            store=self._store,
            retriever=self._retriever,
            llm=self._llm,
            top_k=config.top_k,
            user_prompt_template=prompt_set.user_prompt,
        )
        self._tool_factory = KnowledgeToolFactory(service=self._knowledge_service)
        # Cache of compiled agent graphs, keyed by namespace.
        self._graphs: dict[str, object] = {}

    def run(self, user_input: str, chat_history: Sequence[BaseMessage], namespace: str) -> str:
        """Execute one conversational turn and return the reply text.

        Args:
            user_input: Text of the new user message; appended to the history
                as a ``HumanMessage``.
            chat_history: Earlier messages of the conversation, passed to the
                graph ahead of the new user message.
            namespace: Key selecting which cached agent graph (and therefore
                which tool set) handles the turn.

        Returns:
            The reply chosen by ``_resolve_response`` from the graph's
            resulting ``messages`` list.
        """
        graph = self._get_or_create_graph(namespace=namespace)
        state = graph.invoke({"messages": [*chat_history, HumanMessage(content=user_input)]})
        return self._resolve_response(messages=state.get("messages", []))

    def _get_or_create_graph(self, namespace: str):
        """Return the agent graph for ``namespace``, building it on first use."""
        if namespace not in self._graphs:
            tools = self._tool_factory.build(namespace=namespace)
            self._graphs[namespace] = create_agent(
                model=self._llm,
                tools=tools,
                system_prompt=self._system_prompt,
            )
        return self._graphs[namespace]

    @staticmethod
    def _current_turn(messages: Sequence[BaseMessage]) -> Sequence[BaseMessage]:
        """Return the messages that follow the last ``HumanMessage``.

        If ``messages`` contains no ``HumanMessage`` the whole sequence is
        returned unchanged.
        """
        for index in range(len(messages) - 1, -1, -1):
            if isinstance(messages[index], HumanMessage):
                return messages[index + 1:]
        return messages

    @staticmethod
    def _resolve_response(messages: Sequence[BaseMessage]) -> str:
        """Pick the user-facing reply for the current turn.

        Only messages after the last ``HumanMessage`` are considered, so tool
        results or AI replies left in the history by earlier turns can never
        be returned as the answer to the current request.

        Resolution order:
            1. Content of the most recent ``answer_from_knowledge`` tool
               result (``FALLBACK_MESSAGE`` if that content is blank).
            2. ``UNDERSTOOD_MESSAGE`` if a ``store_knowledge`` tool result was
               seen.
            3. The most recent non-blank ``AIMessage`` content that is not
               itself ``UNDERSTOOD_MESSAGE``.
            4. ``FALLBACK_MESSAGE``.

        Args:
            messages: Full message list produced by the agent graph.

        Returns:
            The reply text selected by the rules above.
        """
        turn = UniversalMemoryAgent._current_turn(messages)

        # Map tool-call ids to tool names so a ToolMessage without a `name`
        # can still be attributed to the tool that produced it.
        tool_by_call_id: dict[str, str] = {}
        for message in turn:
            if isinstance(message, AIMessage):
                for tool_call in message.tool_calls:
                    tool_name = tool_call.get("name")
                    tool_call_id = tool_call.get("id")
                    if tool_name and tool_call_id:
                        tool_by_call_id[str(tool_call_id)] = str(tool_name)

        saw_store_call = False
        for message in reversed(turn):
            if isinstance(message, ToolMessage):
                tool_name = message.name or tool_by_call_id.get(str(message.tool_call_id))
                if tool_name == "answer_from_knowledge":
                    content = str(message.content).strip()
                    return content or FALLBACK_MESSAGE
                if tool_name == "store_knowledge":
                    # Keep scanning: an answer tool result takes priority
                    # over a store acknowledgement wherever it appears.
                    saw_store_call = True

        if saw_store_call:
            return UNDERSTOOD_MESSAGE

        # No relevant tool result in this turn: fall back to the model's own
        # latest non-empty text.
        for message in reversed(turn):
            if isinstance(message, AIMessage):
                content = str(message.content).strip()
                if content:
                    # No store_knowledge result was seen (we would have
                    # returned above), so a bare UNDERSTOOD_MESSAGE is not
                    # backed by a store call; skip it.
                    if content == UNDERSTOOD_MESSAGE:
                        continue
                    return content

        return FALLBACK_MESSAGE
|
|