File size: 8,893 Bytes
59cece6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dd50ca1
 
 
59cece6
 
 
 
dd50ca1
 
 
 
 
59cece6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
from typing import Dict, List
import os
import asyncio
from .types import MemoryEntry, RetrievedMemory
from .extractors import extract
from .store_sqlite import SQLiteMemoryStore
from .vector_index import VectorIndex
from .utils import exp_decay, Timer
from .rerank import rerank
from .inject import format_injection

class MemorySystem:
    """Conversational long-term memory engine.

    Extracts structured memories from user turns, persists them in SQLite
    (the source of truth), mirrors them into an in-process vector index for
    semantic search, and retrieves decayed / conflict-resolved memories
    formatted for prompt injection.
    """

    def __init__(self, config):
        """Build the store and vector index, then restore state from disk.

        config: dict expected to contain config["vector"]["embedding_model"]
        and the config["memory"] retrieval settings; config["storage"]["path"]
        is optional and defaults to "artifacts/memory.sqlite".
        """
        print(f"πŸ” MemorySystem: Initializing with config: {list(config.keys())}")
        self.cfg = config
        # Legacy attributes kept for backward compatibility with older
        # callers; live state is held in self.store / self.vindex /
        # self._memory_cache instead.
        self.embedding_dim = 384
        self.memories = []
        self.ids = []

        # Persistent store — SQLite is the durable source of truth.
        os.makedirs("artifacts", exist_ok=True)
        db_path = self.cfg.get("storage", {}).get("path", "artifacts/memory.sqlite")
        print(f"πŸ” MemorySystem: Initializing SQLiteMemoryStore at {db_path}")
        self.store = SQLiteMemoryStore(path=db_path)

        print("πŸ” MemorySystem: Initializing VectorIndex...")
        self.vindex = VectorIndex(self.cfg["vector"]["embedding_model"])

        self.turn = 0            # highest turn seen; restored by _rebuild_index
        self._memory_cache = {}  # memory_id -> MemoryEntry, hot lookup path

        # RESTORE STATE: refill the cache/turn counter and re-embed
        # every persisted memory into the vector index.
        print("πŸ” MemorySystem: Rebuilding index...")
        self._rebuild_index()
        # Report the real count (the old banner always printed 0 because it
        # read self.memories before anything was loaded).
        print(f"βœ… MemorySystem: Initialization complete "
              f"({len(self._memory_cache)} memories loaded).")

    def _rebuild_index(self):
        """Rebuild the in-memory cache and vector index from SQLite.

        Re-embedding every stored memory on startup is slow for large stores
        (the index is not serialized to disk), but it guarantees no data loss;
        correctness is preferred over startup speed here.
        """
        all_mems = self.store.all()
        if not all_mems:
            return

        # Restore cache and resume the turn counter from the newest memory.
        for m in all_mems:
            self._memory_cache[m.memory_id] = m
            if m.source_turn > self.turn:
                self.turn = m.source_turn

        print(f"πŸ”„ Rebuilding Index from {len(all_mems)} memories...")
        self.vindex.add_or_update(all_mems)  # embeds each entry
        print("βœ… Index Rebuilt.")

    async def process_turn(self, user_text):
        """Extract memories from one user turn and persist them.

        Returns {"turn", "extracted", "extract_ms"}. Persistence (SQLite
        writes + embedding) is offloaded to a worker thread so the event
        loop is never blocked during injection.
        """
        self.turn += 1
        t_extract = Timer.start()

        # Extraction is network-bound and already async.
        extracted = await extract(user_text, self.turn)
        extract_ms = t_extract.ms()

        if extracted:
            # Blocking I/O and CPU work must not run on the event loop.
            await asyncio.to_thread(self._persist_memories, extracted)

        return {"turn": self.turn, "extracted": extracted, "extract_ms": extract_ms}

    def _persist_memories(self, extracted):
        """Write new memories to the cache, SQLite, and the vector index.

        Runs in a worker thread (see process_turn).
        """
        for m in extracted:
            self._memory_cache[m.memory_id] = m
        self.store.upsert_many(extracted)
        self.vindex.add_or_update(extracted)

    def retrieve(self, query):
        """Retrieve, score, conflict-resolve, and rank memories for *query*.

        Pipeline: vector search -> similarity threshold -> age filter ->
        confidence * exponential decay scoring -> per-key conflict
        resolution -> optional rerank -> usage-stat update -> injection
        formatting. Returns {"turn", "retrieved", "retrieve_ms",
        "injected_context"}.
        """
        cfgm = self.cfg["memory"]
        cfgv = self.cfg.get("vector", {})
        min_similarity = cfgv.get("similarity_threshold", 0.4)

        t = Timer.start()
        # Over-fetch (3x top_k, at least 10) so filtering still leaves
        # enough candidates.
        hits = self.vindex.search(query, top_k=max(10, cfgm["top_k"]*3))
        candidates = []
        for mid, base_score in hits:
            # Drop semantically irrelevant matches below the cosine
            # similarity threshold.
            if base_score < min_similarity:
                continue

            m = self._memory_cache.get(mid)
            if not m:
                continue
            age = self.turn - m.source_turn
            if age > cfgm["max_memory_age_turns"]:
                continue
            # Final score = similarity * extraction confidence * time decay.
            decay = exp_decay(age, cfgm["decay_lambda"])
            score = float(base_score) * float(m.confidence) * decay
            text = f"{m.type.value}|{m.key}={m.value}"
            candidates.append((mid, text, score))

        # CONFLICT RESOLUTION: keep one memory per TYPE:KEY.
        # Policy: prefer a clearly more confident memory (>0.1 diff);
        # on similar confidence prefer the more recent one (with a small
        # score boost so "current truth" ranks higher). Explicit negations
        # (value "DELETE"/"NULL") are TBD and currently just overwrite.
        key_cache = {}
        for mid, text, score in candidates:
            m = self._memory_cache[mid]
            key = f"{m.type.value}:{m.key}"  # e.g. preference:language

            if key not in key_cache:
                key_cache[key] = (mid, text, score)
            else:
                curr_mid, _, curr_score = key_cache[key]
                curr_m = self._memory_cache[curr_mid]

                conf_diff = m.confidence - curr_m.confidence

                if conf_diff > 0.1:
                    # New one is much more confident -> replace.
                    key_cache[key] = (mid, text, score)
                elif conf_diff < -0.1:
                    # Incumbent is much more confident -> keep it.
                    pass
                else:
                    # Similar confidence: recency wins.
                    if m.source_turn > curr_m.source_turn:
                        boosted_score = score * 1.1
                        key_cache[key] = (mid, text, boosted_score)
                    else:
                        # Incumbent is newer (unlikely given loop order,
                        # but handled for safety).
                        pass
        resolved_candidates = list(key_cache.values())

        # Ranking: multi-signal rerank when enabled, else pure score sort.
        if cfgm.get("rerank", True) and resolved_candidates:
            rr = rerank(query, resolved_candidates)
            ranked = rr[: cfgm["top_k"]]
            ranker_name = "multi_signal_rerank"
            score_map = {mid: s for mid, s in ranked}
            ordered_ids = [mid for mid, _ in ranked]
        else:
            resolved_candidates.sort(key=lambda x: x[2], reverse=True)
            top = resolved_candidates[: cfgm["top_k"]]
            ranker_name = "semantic_only"
            score_map = {mid: s for mid, _, s in top}
            ordered_ids = [mid for mid, _, _ in top]

        retrieved = []
        for mid in ordered_ids:
            m = self._memory_cache.get(mid)
            if not m:
                continue
            retrieved.append(RetrievedMemory(memory=m, score=score_map[mid], ranker=ranker_name))

        retrieve_ms = t.ms()
        injected = format_injection([r.memory for r in retrieved], max_tokens=cfgm["max_injected_tokens"])

        # Usage stats: bump use_count / last_used_turn and persist so
        # future ranking can leverage them.
        to_update = []
        for r in retrieved:
            r.memory.use_count += 1
            r.memory.last_used_turn = self.turn
            to_update.append(r.memory)
        if to_update:
            self.store.upsert_many(to_update)

        return {"turn": self.turn, "retrieved": retrieved, "retrieve_ms": retrieve_ms, "injected_context": injected}

    def close(self):
        """Release the underlying SQLite store."""
        self.store.close()