File size: 6,603 Bytes
fe36046
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
"""Memory Layer Implementation for LangGraph Agent System"""
import os
import time
import hashlib
import sqlite3
from typing import Optional, List, Dict, Any, Tuple
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_huggingface import HuggingFaceEmbeddings
from supabase.client import Client, create_client
from langgraph.checkpoint.sqlite import SqliteSaver
from langchain_core.messages import BaseMessage, HumanMessage


# Constants for memory management
TTL: int = 300  # seconds – how long we keep similarity-search results
SIMILARITY_THRESHOLD: float = 0.85  # cosine score above which we assume we already know the answer


class MemoryManager:
    """Manages short-term, long-term memory and checkpointing for the agent system.

    Responsibilities:
    - Long-term memory: a Supabase vector store holding Q/A payloads.
    - Short-term memory: a SQLite-backed LangGraph checkpointer.
    - Session caches: TTL-bounded similarity-search results, processed
      task ids, and payload hashes used for ingest de-duplication.

    All external failures (missing credentials, store/search errors) degrade
    gracefully: the affected feature is disabled and a warning is printed.
    """

    def __init__(self):
        # Embedding model used for both ingestion and similarity search.
        self.embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
        self.vector_store = None        # set by _initialize_vector_store() when credentials exist
        self.checkpointer = None        # set by _initialize_checkpointer()
        self._sqlite_connection = None  # connection backing the checkpointer; owned by this object

        # In-memory caches (session-scoped; reset by clear_session_cache()).
        self.query_cache: Dict[str, Tuple[float, List]] = {}  # sha256(query) -> (timestamp, results)
        self.processed_tasks: set[str] = set()                # task ids already handled
        self.seen_hashes: set[str] = set()                    # payload hashes already ingested

        self._initialize_vector_store()
        self._initialize_checkpointer()

    def _initialize_vector_store(self) -> None:
        """Initialize the Supabase vector store for long-term memory.

        Leaves ``self.vector_store`` as None (long-term memory disabled) when
        credentials are missing or client creation fails.
        """
        try:
            supabase_url = os.environ.get("SUPABASE_URL")
            supabase_key = os.environ.get("SUPABASE_SERVICE_KEY")

            if not supabase_url or not supabase_key:
                print("Warning: Supabase credentials not found, vector store will be disabled")
                return

            supabase: Client = create_client(supabase_url, supabase_key)
            self.vector_store = SupabaseVectorStore(
                client=supabase,
                embedding=self.embeddings,
                table_name="documents",
                query_name="match_documents_langchain",
            )
            print("Vector store initialized successfully")
        except Exception as e:
            print(f"Warning: Could not initialize Supabase vector store: {e}")

    def _initialize_checkpointer(self) -> None:
        """Initialize the SQLite checkpointer for short-term memory.

        Uses an in-memory database; ``check_same_thread=False`` because the
        checkpointer may be driven from worker threads.
        """
        try:
            # Create a direct SQLite connection
            self._sqlite_connection = sqlite3.connect(":memory:", check_same_thread=False)
            self.checkpointer = SqliteSaver(self._sqlite_connection)
            print("Checkpointer initialized successfully")
        except Exception as e:
            print(f"Warning: Could not initialize checkpointer: {e}")

    def get_checkpointer(self) -> Optional[SqliteSaver]:
        """Return the checkpointer instance, or None if initialization failed."""
        return self.checkpointer

    def close_checkpointer(self) -> None:
        """Close the checkpointer's SQLite connection. Safe to call repeatedly."""
        if self._sqlite_connection:
            try:
                self._sqlite_connection.close()
                print("SQLite connection closed")
            except Exception as e:
                print(f"Warning: Error closing SQLite connection: {e}")
            finally:
                # Drop the reference so repeated calls are no-ops (fix: the
                # previous version kept the closed connection around).
                self._sqlite_connection = None

    def similarity_search(self, query: str, k: int = 2) -> List[Any]:
        """Search long-term memory for similar questions, with TTL caching.

        Args:
            query: Natural-language query to embed and search with.
            k: Maximum number of results to return.

        Returns:
            A list of (document, relevance_score) tuples, or [] when the
            vector store is unavailable or the search fails.
        """
        if not self.vector_store:
            return []

        # Cache key is a digest of the raw query text.
        q_hash = hashlib.sha256(query.encode()).hexdigest()
        now = time.time()

        cached = self.query_cache.get(q_hash)
        if cached is not None:
            cached_at, results = cached
            if now - cached_at < TTL:
                print("Memory: Cache hit for similarity search")
                return results
            # Expired entry – evict it so the cache cannot grow without bound
            # (fix: stale keys previously lingered forever).
            del self.query_cache[q_hash]

        try:
            print("Memory: Searching vector store for similar questions...")
            similar_questions = self.vector_store.similarity_search_with_relevance_scores(query, k=k)
            self.query_cache[q_hash] = (now, similar_questions)
            return similar_questions
        except Exception as e:
            print(f"Memory: Vector store search error – {e}")
            return []

    def should_ingest(self, query: str) -> bool:
        """Return True when this query/answer should be ingested to long-term memory.

        Ingest only when no stored question scores at or above
        SIMILARITY_THRESHOLD (i.e. we do not already know the answer).
        """
        if not self.vector_store:
            return False

        similar_questions = self.similarity_search(query, k=1)
        top_score = similar_questions[0][1] if similar_questions else 0.0
        return top_score < SIMILARITY_THRESHOLD

    def ingest_qa_pair(self, question: str, answer: str, attachments: str = "") -> None:
        """Store a Q/A pair (plus optional attachment text) in long-term memory.

        De-duplicates within the session via a sha256 digest of the payload.
        """
        if not self.vector_store:
            print("Memory: Vector store not available for ingestion")
            return

        try:
            payload = f"Question:\n{question}\n\nAnswer:\n{answer}"
            if attachments:
                payload += f"\n\n{attachments}"

            hash_id = hashlib.sha256(payload.encode()).hexdigest()
            if hash_id in self.seen_hashes:
                print("Memory: Duplicate payload within session – skip")
                return

            self.seen_hashes.add(hash_id)
            self.vector_store.add_texts(
                [payload], 
                metadatas=[{"hash_id": hash_id, "timestamp": time.time()}]
            )
            print("Memory: Stored new Q/A pair in vector store")
        except Exception as e:
            print(f"Memory: Error while upserting – {e}")

    def get_similar_qa(self, query: str) -> Optional[str]:
        """Return the text of the most similar stored Q/A, or None if none found."""
        similar_questions = self.similarity_search(query, k=1)
        if not similar_questions:
            return None

        # Results are normally (document, score) tuples; tolerate bare documents.
        example_doc = similar_questions[0][0] if isinstance(similar_questions[0], tuple) else similar_questions[0]
        return example_doc.page_content

    def add_processed_task(self, task_id: str) -> None:
        """Mark a task as processed to avoid re-downloading attachments."""
        self.processed_tasks.add(task_id)

    def is_task_processed(self, task_id: str) -> bool:
        """Return True if the task id was previously marked as processed."""
        return task_id in self.processed_tasks

    def clear_session_cache(self) -> None:
        """Clear all session-specific caches (search cache, tasks, hashes)."""
        self.query_cache.clear()
        self.processed_tasks.clear()
        self.seen_hashes.clear()
        print("Memory: Session cache cleared")


# Global memory manager instance
# NOTE: constructed at import time, so importing this module loads the
# embedding model and attempts the Supabase / SQLite initializations as a
# side effect of the import.
memory_manager = MemoryManager()