| import os |
| from typing import List, Union, Tuple |
| from langchain_chroma import Chroma |
| from langchain_openai import OpenAIEmbeddings |
| from langchain_core.messages import HumanMessage, AIMessage, BaseMessage |
|
|
| from src.util import load_openai_keys, load_config |
|
|
|
|
class MemoryUpdater:
    """Persist and retrieve per-user conversation memory in a Chroma vector store.

    Each stored entry is one user/assistant turn, tagged with a ``user_id``
    metadata field so retrieval can be scoped to a single user.
    """

    def __init__(self) -> None:
        """Load config/API keys and open the local Chroma collection."""
        config = load_config()
        llm_config = config['LLM']

        # Populates the OPENAI_API_KEY environment variable (and any
        # related keys) before the embedding client reads it below.
        load_openai_keys()

        self.embedding_llm = OpenAIEmbeddings(
            model=llm_config["EmbeddingLLM"]["MODEL_NAME"],
            api_key=os.getenv("OPENAI_API_KEY"),
        )

        # Local on-disk store; the same directory/collection is reused
        # across runs, which is what makes the memory long-term.
        self.vector_database = Chroma(
            persist_directory="./memory_data",
            embedding_function=self.embedding_llm,
            collection_name="memory_collection",
        )

    @staticmethod
    def _normalize_history(
        history: Union[List[dict], List[Tuple[str, str]], None]
    ) -> List[BaseMessage]:
        """Convert history into LangChain messages.

        Accepts either OpenAI-style dicts (``{"role": ..., "content": ...}``)
        or ``(user, assistant)`` pairs; the format is sniffed from the first
        element. Roles other than "user"/"assistant" (e.g. "system") are
        silently dropped, as are empty halves of a pair.
        """
        if not history:
            return []
        msgs: List[BaseMessage] = []
        if isinstance(history[0], dict):
            for m in history:
                role, content = m.get("role", "user"), m.get("content", "")
                if role == "user":
                    msgs.append(HumanMessage(content=content))
                elif role == "assistant":
                    msgs.append(AIMessage(content=content))
        elif isinstance(history[0], (list, tuple)) and len(history[0]) == 2:
            for u, a in history:
                if u:
                    msgs.append(HumanMessage(content=u))
                if a:
                    msgs.append(AIMessage(content=a))
        return msgs

    def update_memory(self, user_id: str, message: str, response: str) -> None:
        """Store the latest user/assistant turn into memory DB (long-term storage)."""
        doc_text = f"[User {user_id}] {message}\n[Assistant] {response}"
        # user_id metadata enables per-user filtering at retrieval time.
        metadata = {"user_id": user_id}
        self.vector_database.add_texts([doc_text], metadatas=[metadata])

    def get_short_term_memory(self, user_id: str, k: int = 5) -> List[str]:
        """Retrieve up to k memory entries for this user via similarity search.

        NOTE(review): this is a similarity search, not a strict
        most-recent ordering — entries carry no timestamp metadata.
        """
        # BUGFIX: previously no metadata filter was applied, so the search
        # could return other users' memories. Scope it to this user_id.
        retriever = self.vector_database.as_retriever(
            search_kwargs={"k": k, "filter": {"user_id": user_id}}
        )
        docs = retriever.invoke(f"user {user_id} recent memory")
        return [d.page_content for d in docs]

    def get_long_term_memory(self, user_id: str) -> List[str]:
        """Retrieve ALL memory entries for this user."""
        results = self.vector_database.get(where={"user_id": user_id})
        # Chroma may return an explicit None under "documents"; guard with `or`.
        return results.get("documents") or []
|
|