# Uploaded by Marik1337 — "Add application file" (commit 6059138)
from __future__ import annotations
from collections.abc import Sequence
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
from langchain.agents import create_agent
from memory_agent.config import AppConfig
from memory_agent.embeddings import MistralEmbedEmbeddings
from memory_agent.knowledge import FALLBACK_MESSAGE, UNDERSTOOD_MESSAGE, KnowledgeService
from memory_agent.llm import MistralLLMFactory
from memory_agent.prompt_loader import PromptLoader
from memory_agent.retrieval import HybridRetriever
from memory_agent.storage import FaissMemoryStore
from memory_agent.tools import KnowledgeToolFactory
class UniversalMemoryAgent:
    """Conversational agent with per-namespace long-term memory.

    Wires together the Mistral chat model, Mistral embeddings, a FAISS-backed
    store with hybrid (dense + sparse) retrieval, and the knowledge tools,
    then routes each turn through a lazily-built LangChain agent graph that is
    cached per namespace.
    """

    def __init__(self, config: AppConfig) -> None:
        """Build all collaborators from *config*; no I/O beyond what they do."""
        self._config = config
        self._llm = MistralLLMFactory(config=config).create()

        # System prompt is a template with placeholders for the two canonical
        # replies the resolver below recognises.
        prompts = PromptLoader().load()
        self._system_prompt = prompts.system_prompt.format(
            understood_message=UNDERSTOOD_MESSAGE,
            fallback_message=FALLBACK_MESSAGE,
        )

        self._embeddings = MistralEmbedEmbeddings(
            model_name=config.embedding_model,
            api_token=config.mistral_api_key,
            max_retries=config.mistral_max_retries,
            base_delay_seconds=config.rate_limit_base_delay_seconds,
            max_delay_seconds=config.rate_limit_max_delay_seconds,
        )
        self._store = FaissMemoryStore(config=config, embeddings=self._embeddings)
        self._retriever = HybridRetriever(
            store=self._store,
            dense_weight=config.dense_weight,
            sparse_weight=config.sparse_weight,
        )
        self._knowledge_service = KnowledgeService(
            store=self._store,
            retriever=self._retriever,
            llm=self._llm,
            top_k=config.top_k,
            user_prompt_template=prompts.user_prompt,
        )
        self._tool_factory = KnowledgeToolFactory(service=self._knowledge_service)
        # Agent graph cache, keyed by namespace (see _get_or_create_graph).
        self._graphs: dict[str, object] = {}

    def run(self, user_input: str, chat_history: Sequence[BaseMessage], namespace: str) -> str:
        """Run one conversational turn inside *namespace* and return the reply text."""
        agent_graph = self._get_or_create_graph(namespace=namespace)
        conversation = [*chat_history, HumanMessage(content=user_input)]
        result = agent_graph.invoke({"messages": conversation})
        return self._resolve_response(messages=result.get("messages", []))

    def _get_or_create_graph(self, namespace: str):
        """Return the cached agent graph for *namespace*, creating it on first use."""
        graph = self._graphs.get(namespace)
        if graph is None:
            graph = create_agent(
                model=self._llm,
                tools=self._tool_factory.build(namespace=namespace),
                system_prompt=self._system_prompt,
            )
            self._graphs[namespace] = graph
        return graph

    @staticmethod
    def _resolve_response(messages: Sequence[BaseMessage]) -> str:
        """Pick the user-facing reply out of the agent's message trace.

        Precedence: a knowledge-answer tool result wins, then the canonical
        "understood" acknowledgement if any store call happened, then the
        model's own direct text, then the fallback message.
        """
        # Map tool_call_id -> tool name so ToolMessages that lack a name
        # can still be attributed to the tool that produced them.
        call_names: dict[str, str] = {}
        for msg in messages:
            if not isinstance(msg, AIMessage):
                continue
            for call in msg.tool_calls:
                name = call.get("name")
                call_id = call.get("id")
                if name and call_id:
                    call_names[str(call_id)] = str(name)

        stored = False
        for msg in reversed(messages):
            if not isinstance(msg, ToolMessage):
                continue
            name = msg.name or call_names.get(str(msg.tool_call_id))
            if name == "answer_from_knowledge":
                text = str(msg.content).strip()
                return text if text else FALLBACK_MESSAGE
            if name == "store_knowledge":
                stored = True
        if stored:
            return UNDERSTOOD_MESSAGE

        # Allow direct assistant text for social/onboarding turns chosen by the model.
        for msg in reversed(messages):
            if isinstance(msg, AIMessage):
                text = str(msg.content).strip()
                # Guard contract: "understood" is valid only after store_knowledge tool call.
                if text and text != UNDERSTOOD_MESSAGE:
                    return text
        return FALLBACK_MESSAGE