# Mental-Toughness-LLM — src/memory_update.py
# (Header reconstructed: original lines were Hugging Face hub-page residue,
#  commit a288c28, uploaded by SmileyFriend.)
import os
import time
from typing import List, Tuple, Union

from langchain_chroma import Chroma
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_openai import OpenAIEmbeddings

from src.util import load_config, load_openai_keys
class MemoryUpdater:
    """Persist chat turns in a Chroma vector store and retrieve them as memory.

    Uses the same OpenAI embedding model as the RAG pipeline, but a separate
    persist directory and collection so conversational memory never mixes
    with the reference document DB.
    """

    def __init__(self) -> None:
        config = load_config()
        llm_config = config['LLM']
        load_openai_keys()
        # Use same embedding model as RAG so vectors are comparable
        self.embedding_llm = OpenAIEmbeddings(
            model=llm_config["EmbeddingLLM"]["MODEL_NAME"],
            api_key=os.getenv("OPENAI_API_KEY"),
        )
        # Keep memory separate from reference DB
        self.vector_database = Chroma(
            persist_directory="./memory_data",  # dedicated dir for memory
            embedding_function=self.embedding_llm,
            collection_name="memory_collection",  # separate collection
        )

    @staticmethod
    def _normalize_history(
        history: Union[List[dict], List[Tuple[str, str]], None]
    ) -> List[BaseMessage]:
        """Convert history into LangChain messages.

        Accepts either OpenAI-style role dicts ({"role", "content"}) or
        (user, assistant) pairs; returns [] for empty/None input. Entries
        with unknown roles (e.g. "system") are silently skipped, matching
        the original behavior.
        """
        if not history:
            return []
        msgs: List[BaseMessage] = []
        if isinstance(history[0], dict):
            for m in history:
                role, content = m.get("role", "user"), m.get("content", "")
                if role == "user":
                    msgs.append(HumanMessage(content=content))
                elif role == "assistant":
                    msgs.append(AIMessage(content=content))
        elif isinstance(history[0], (list, tuple)) and len(history[0]) == 2:
            for u, a in history:
                if u:
                    msgs.append(HumanMessage(content=u))
                if a:
                    msgs.append(AIMessage(content=a))
        return msgs

    def update_memory(self, user_id: str, message: str, response: str) -> None:
        """Store the latest user/assistant turn into the memory DB.

        A unix timestamp is recorded in the metadata so recency queries
        (get_short_term_memory) can order entries chronologically.
        """
        doc_text = f"[User {user_id}] {message}\n[Assistant] {response}"
        # timestamp enables true recency ordering; user_id enables filtering
        metadata = {"user_id": user_id, "timestamp": time.time()}
        self.vector_database.add_texts([doc_text], metadatas=[metadata])

    def get_short_term_memory(self, user_id: str, k: int = 5) -> List[str]:
        """Return the most recent *k* memory entries for this user.

        FIX: the previous implementation ran an unfiltered similarity search
        against the whole collection, so it could leak other users' memories
        and ignored recency entirely. We now fetch only this user's entries
        and sort newest-first by the stored timestamp (entries written before
        timestamps existed sort last).
        """
        results = self.vector_database.get(where={"user_id": user_id})
        docs = results.get("documents") or []
        metas = results.get("metadatas") or [{}] * len(docs)
        # newest first, then keep the top k
        ordered = sorted(
            zip(docs, metas),
            key=lambda pair: (pair[1] or {}).get("timestamp", 0.0),
            reverse=True,
        )
        return [doc for doc, _meta in ordered[:k]]

    def get_long_term_memory(self, user_id: str) -> List[str]:
        """Return ALL memory entries for this user (empty list if none)."""
        results = self.vector_database.get(where={"user_id": user_id})
        # Chroma may return None (not []) for "documents" on an empty result
        return results.get("documents") or []