|
|
"""Memory Layer Implementation for LangGraph Agent System""" |
|
|
import os |
|
|
import time |
|
|
import hashlib |
|
|
import sqlite3 |
|
|
from typing import Optional, List, Dict, Any, Tuple |
|
|
from langchain_community.vectorstores import SupabaseVectorStore |
|
|
from langchain_huggingface import HuggingFaceEmbeddings |
|
|
from supabase.client import Client, create_client |
|
|
from langgraph.checkpoint.sqlite import SqliteSaver |
|
|
from langchain_core.messages import BaseMessage, HumanMessage |
|
|
|
|
|
|
|
|
|
|
|
# Lifetime (seconds) of a cached similarity-search result; after this the
# vector store is queried again (see MemoryManager.similarity_search).
TTL = 300

# Relevance score at or above which a query is treated as a near-duplicate of
# an already-stored question and is NOT ingested (see MemoryManager.should_ingest).
SIMILARITY_THRESHOLD = 0.85
|
|
|
|
|
|
|
|
class MemoryManager:
    """Manages short-term, long-term memory and checkpointing for the agent system.

    Long-term memory: a Supabase vector store of Question/Answer payloads,
    queried by embedding similarity (results cached for ``TTL`` seconds).
    Short-term memory: a LangGraph ``SqliteSaver`` checkpointer backed by an
    in-memory SQLite database.
    Session state: sets tracking processed task ids and already-ingested
    payload hashes, plus the similarity-search result cache.

    All failures during initialization or vector-store access are reported
    via ``print`` and degrade gracefully (features disabled, empty results)
    rather than raising.
    """

    def __init__(self):
        # Embedding model used for both search and ingestion; must match the
        # embeddings already stored in the Supabase "documents" table.
        self.embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
        self.vector_store = None
        self.checkpointer = None
        self._sqlite_connection = None

        # cache key -> (timestamp, results); entries older than TTL are stale.
        self.query_cache: Dict[str, Tuple[float, List]] = {}
        self.processed_tasks: set[str] = set()
        self.seen_hashes: set[str] = set()

        self._initialize_vector_store()
        self._initialize_checkpointer()

    def _initialize_vector_store(self) -> None:
        """Initialize Supabase vector store for long-term memory.

        Leaves ``self.vector_store`` as ``None`` (disabling long-term memory)
        when credentials are missing or the client cannot be created.
        """
        try:
            supabase_url = os.environ.get("SUPABASE_URL")
            supabase_key = os.environ.get("SUPABASE_SERVICE_KEY")

            if not supabase_url or not supabase_key:
                print("Warning: Supabase credentials not found, vector store will be disabled")
                return

            supabase: Client = create_client(supabase_url, supabase_key)
            self.vector_store = SupabaseVectorStore(
                client=supabase,
                embedding=self.embeddings,
                table_name="documents",
                query_name="match_documents_langchain",
            )
            print("Vector store initialized successfully")
        except Exception as e:
            # Best-effort: the agent still runs without long-term memory.
            print(f"Warning: Could not initialize Supabase vector store: {e}")

    def _initialize_checkpointer(self) -> None:
        """Initialize SQLite checkpointer for short-term memory."""
        try:
            # check_same_thread=False because the checkpointer may be driven
            # from threads other than the one that opened the connection.
            self._sqlite_connection = sqlite3.connect(":memory:", check_same_thread=False)
            self.checkpointer = SqliteSaver(self._sqlite_connection)
            print("Checkpointer initialized successfully")
        except Exception as e:
            print(f"Warning: Could not initialize checkpointer: {e}")

    def get_checkpointer(self) -> Optional[SqliteSaver]:
        """Return the checkpointer instance, or ``None`` if unavailable or closed."""
        return self.checkpointer

    def close_checkpointer(self) -> None:
        """Close the checkpointer's SQLite connection and drop stale references.

        FIX: previously the closed connection and its ``SqliteSaver`` were kept
        on the instance, so ``get_checkpointer()`` could hand out a saver bound
        to a closed database. Both references are now cleared unconditionally.
        """
        if self._sqlite_connection:
            try:
                self._sqlite_connection.close()
                print("SQLite connection closed")
            except Exception as e:
                print(f"Warning: Error closing SQLite connection: {e}")
            finally:
                self._sqlite_connection = None
                self.checkpointer = None

    def similarity_search(self, query: str, k: int = 2) -> List[Any]:
        """Search long-term memory for questions similar to *query*.

        Results are cached per ``(k, query)`` for ``TTL`` seconds. Returns the
        vector store's ``(document, relevance_score)`` results, or ``[]`` when
        the store is unavailable or the search fails.

        FIX: the cache key now includes ``k`` — previously a cached ``k=2``
        result would be returned for a later ``k=1`` call (and vice versa).
        FIX: expired entries are evicted on lookup instead of accumulating
        for the whole session.
        """
        if not self.vector_store:
            return []

        q_hash = hashlib.sha256(f"{k}:{query}".encode()).hexdigest()
        now = time.time()

        cached = self.query_cache.get(q_hash)
        if cached is not None:
            if now - cached[0] < TTL:
                print("Memory: Cache hit for similarity search")
                return cached[1]
            # Stale entry: drop it so the cache cannot grow without bound.
            del self.query_cache[q_hash]

        try:
            print("Memory: Searching vector store for similar questions...")
            similar_questions = self.vector_store.similarity_search_with_relevance_scores(query, k=k)
            self.query_cache[q_hash] = (now, similar_questions)
            return similar_questions
        except Exception as e:
            print(f"Memory: Vector store search error – {e}")
            return []

    def should_ingest(self, query: str) -> bool:
        """Return True when *query* is novel enough to store in long-term memory.

        Novel means the best relevance score among stored questions falls
        below ``SIMILARITY_THRESHOLD`` (or nothing similar was found).
        Always False when the vector store is unavailable.
        """
        if not self.vector_store:
            return False

        similar_questions = self.similarity_search(query, k=1)
        top_score = similar_questions[0][1] if similar_questions else 0.0
        return top_score < SIMILARITY_THRESHOLD

    def ingest_qa_pair(self, question: str, answer: str, attachments: str = "") -> None:
        """Store a Question/Answer payload (plus optional attachments) in long-term memory.

        Payloads whose SHA-256 hash was already ingested in this session are
        skipped. Ingestion errors are reported and swallowed (best-effort).
        """
        if not self.vector_store:
            print("Memory: Vector store not available for ingestion")
            return

        try:
            payload = f"Question:\n{question}\n\nAnswer:\n{answer}"
            if attachments:
                payload += f"\n\n{attachments}"

            hash_id = hashlib.sha256(payload.encode()).hexdigest()
            if hash_id in self.seen_hashes:
                print("Memory: Duplicate payload within session – skip")
                return

            self.seen_hashes.add(hash_id)
            self.vector_store.add_texts(
                [payload],
                metadatas=[{"hash_id": hash_id, "timestamp": time.time()}]
            )
            print("Memory: Stored new Q/A pair in vector store")
        except Exception as e:
            print(f"Memory: Error while upserting – {e}")

    def get_similar_qa(self, query: str) -> Optional[str]:
        """Return the stored payload most similar to *query*, or ``None``.

        Handles both ``(document, score)`` tuples and bare documents from the
        underlying search.
        """
        similar_questions = self.similarity_search(query, k=1)
        if not similar_questions:
            return None

        example_doc = similar_questions[0][0] if isinstance(similar_questions[0], tuple) else similar_questions[0]
        return example_doc.page_content

    def add_processed_task(self, task_id: str) -> None:
        """Mark a task as processed to avoid re-downloading attachments."""
        self.processed_tasks.add(task_id)

    def is_task_processed(self, task_id: str) -> bool:
        """Check if a task has already been processed this session."""
        return task_id in self.processed_tasks

    def clear_session_cache(self) -> None:
        """Clear session-specific caches (search cache, task ids, payload hashes)."""
        self.query_cache.clear()
        self.processed_tasks.clear()
        self.seen_hashes.clear()
        print("Memory: Session cache cleared")
|
|
|
|
|
|
|
|
|
|
|
# Module-level singleton: constructed at import time, which loads the
# embedding model and attempts Supabase / SQLite initialization as a side
# effect of importing this module.
memory_manager = MemoryManager()