import os
from typing import List, Tuple, Union

from langchain_chroma import Chroma
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_openai import OpenAIEmbeddings

from src.util import load_config, load_openai_keys


class MemoryUpdater:
    """Persist and retrieve per-user conversation memory in a Chroma vector store.

    Memory is kept in its own persist directory and collection so it never
    mixes with the reference/RAG database, while using the same embedding
    model so vectors remain comparable across both stores.
    """

    def __init__(self) -> None:
        config = load_config()
        llm_config = config["LLM"]
        load_openai_keys()

        # Use the same embedding model as RAG so memory and reference
        # vectors live in a comparable embedding space.
        self.embedding_llm = OpenAIEmbeddings(
            model=llm_config["EmbeddingLLM"]["MODEL_NAME"],
            api_key=os.getenv("OPENAI_API_KEY"),
        )

        # Keep memory separate from the reference DB: dedicated directory
        # and collection name.
        self.vector_database = Chroma(
            persist_directory="./memory_data",
            embedding_function=self.embedding_llm,
            collection_name="memory_collection",
        )

    @staticmethod
    def _normalize_history(
        history: Union[List[dict], List[Tuple[str, str]], None]
    ) -> List[BaseMessage]:
        """Convert chat history into a list of LangChain messages.

        Accepts ``{"role": ..., "content": ...}`` dicts and/or
        ``(user, assistant)`` two-element pairs; the two shapes may be
        mixed within one list. Unknown roles and empty turns are skipped.
        Returns ``[]`` for ``None`` or empty input.
        """
        msgs: List[BaseMessage] = []
        for item in history or []:
            if isinstance(item, dict):
                role = item.get("role", "user")
                content = item.get("content", "")
                if role == "user":
                    msgs.append(HumanMessage(content=content))
                elif role == "assistant":
                    msgs.append(AIMessage(content=content))
            elif isinstance(item, (list, tuple)) and len(item) == 2:
                user_turn, assistant_turn = item
                # Skip empty halves of a turn so we never emit blank messages.
                if user_turn:
                    msgs.append(HumanMessage(content=user_turn))
                if assistant_turn:
                    msgs.append(AIMessage(content=assistant_turn))
        return msgs

    def update_memory(self, user_id: str, message: str, response: str):
        """Store the latest user/assistant turn into memory DB (long-term storage).

        The ``user_id`` is attached as metadata so retrieval can be
        scoped to a single user.
        """
        doc_text = f"[User {user_id}] {message}\n[Assistant] {response}"
        metadata = {"user_id": user_id}
        self.vector_database.add_texts([doc_text], metadatas=[metadata])

    def get_short_term_memory(self, user_id: str, k: int = 5) -> List[str]:
        """Retrieve up to ``k`` memory entries for this user.

        Results are restricted to this user's entries via a metadata
        filter, so one user's memory cannot leak into another's results.

        NOTE(review): ordering is by embedding similarity to the query,
        not by recency — no timestamp is stored, so "most recent" cannot
        be guaranteed. Add a timestamp to the metadata if true recency
        ordering is required.
        """
        retriever = self.vector_database.as_retriever(
            # Filter by user_id so similarity search never returns
            # another user's memory entries.
            search_kwargs={"k": k, "filter": {"user_id": user_id}}
        )
        docs = retriever.invoke(f"user {user_id} recent memory")
        return [doc.page_content for doc in docs]

    def get_long_term_memory(self, user_id: str) -> List[str]:
        """Retrieve ALL memory entries for this user."""
        results = self.vector_database.get(where={"user_id": user_id})
        # Chroma may report the field as None rather than omitting it;
        # normalize to an empty list either way.
        return results.get("documents") or []