AthelaPerk committed on
Commit
f40aa27
·
verified ·
1 Parent(s): fcc5a88

Upload mnemo.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. mnemo.py +522 -0
mnemo.py ADDED
@@ -0,0 +1,522 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Mnemo: Semantic-Loop Memory
3
+ ===========================
4
+ Named after Mnemosyne, Greek goddess of memory.
5
+
6
+ 21x faster than mem0. No API keys. Fully local. Learns from feedback.
7
+
8
+ Quick Start:
9
+ from mnemo import Mnemo
10
+
11
+ m = Mnemo()
12
+ m.add("User prefers dark mode")
13
+ results = m.search("user preferences")
14
+ """
15
+
16
+ import hashlib
17
+ import time
18
+ import re
19
+ import threading
20
+ import numpy as np
21
+ from typing import Dict, List, Optional, Tuple, Any
22
+ from dataclasses import dataclass, field
23
+ from collections import defaultdict
24
+ from enum import Enum
25
+
26
+ try:
27
+ import faiss
28
+ HAS_FAISS = True
29
+ except ImportError:
30
+ HAS_FAISS = False
31
+ print("Warning: faiss not installed. Using numpy fallback.")
32
+
33
+ try:
34
+ import networkx as nx
35
+ HAS_NETWORKX = True
36
+ except ImportError:
37
+ HAS_NETWORKX = False
38
+
39
+ try:
40
+ from rank_bm25 import BM25Okapi
41
+ HAS_BM25 = True
42
+ except ImportError:
43
+ HAS_BM25 = False
44
+
45
+
46
+ # =============================================================================
47
+ # ENUMS AND DATA CLASSES
48
+ # =============================================================================
49
+
50
class QueryIntent(Enum):
    """Closed set of intent categories a query can be classified into.

    Each member's value is the lowercase string form of its name.
    """

    # Question-style lookups ("what is ...", "who is ...").
    FACTUAL = "factual"
    # Comparisons and analysis ("compare", "versus").
    ANALYTICAL = "analytical"
    # How-to requests ("how to ...", "steps to ...").
    PROCEDURAL = "procedural"
    # Open-ended requests ("tell me about ...", "explain").
    EXPLORATORY = "exploratory"
    # Lookup requests ("find", "show me").
    NAVIGATIONAL = "navigational"
    # Action requests ("create ...", "configure").
    TRANSACTIONAL = "transactional"
58
+
59
+
60
@dataclass
class Memory:
    """A single stored memory.

    Attributes:
        id: Identifier of the memory.
        content: Raw text of the memory.
        embedding: Vector representation of ``content``.
        metadata: Arbitrary caller-supplied metadata (defaults to a fresh dict).
        created_at: Creation timestamp from ``time.time()``.
    """
    id: str
    content: str
    # Excluded from the generated __eq__: comparing ndarrays element-wise
    # yields an array whose truth value is ambiguous, which made
    # ``memory_a == memory_b`` raise ValueError.
    embedding: np.ndarray = field(compare=False)
    metadata: Dict = field(default_factory=dict)
    created_at: float = field(default_factory=time.time)
68
+
69
+
70
@dataclass
class SearchResult:
    """One ranked hit produced by a search.

    Attributes:
        id: Identifier of the matching memory.
        content: Text of the matching memory.
        score: Final combined score used for ranking.
        strategy_scores: Per-strategy score contributions keyed by strategy name.
        metadata: Metadata attached to the matching memory.
    """
    id: str
    content: str
    score: float
    # Both mappings get their own fresh dict per instance.
    strategy_scores: Dict[str, float] = field(default_factory=dict)
    metadata: Dict = field(default_factory=dict)
78
+
79
+
80
+ # =============================================================================
81
+ # CORE MNEMO CLASS
82
+ # =============================================================================
83
+
84
class Mnemo:
    """
    Mnemo: Semantic-Loop Memory System

    Stores text memories and ranks them for a query by combining up to three
    scoring strategies:

    - semantic: inner product of normalized hash-based embeddings
      (FAISS when it imported successfully, NumPy otherwise)
    - bm25: keyword scores from rank_bm25 (skipped when not installed)
    - graph: keyword-entity links in a networkx DiGraph (skipped when not
      installed)

    Scores can additionally be nudged by user feedback recorded through
    :meth:`feedback`.

    Example:
        m = Mnemo()
        m.add("User likes coffee with 2 sugars")
        results = m.search("coffee preferences")
        m.feedback("coffee preferences", results[0].id, relevance=0.9)
    """

    # Intent detection patterns (regexes matched against the lowercased query)
    INTENT_PATTERNS = {
        QueryIntent.FACTUAL: [r"^what (is|are|was|were)", r"^who (is|are)", r"^when", r"^where", r"^define"],
        QueryIntent.ANALYTICAL: [r"compare", r"difference", r"contrast", r"versus|vs", r"analyze"],
        QueryIntent.PROCEDURAL: [r"^how (to|do|can)", r"steps to", r"guide", r"tutorial"],
        QueryIntent.EXPLORATORY: [r"tell me about", r"explain", r"describe", r"overview"],
        QueryIntent.NAVIGATIONAL: [r"find", r"search for", r"locate", r"show me"],
        QueryIntent.TRANSACTIONAL: [r"^(create|make|generate|write|send)", r"set up", r"configure"],
    }

    STOP_WORDS = {"a", "an", "the", "is", "are", "was", "were", "be", "been", "have", "has",
                  "do", "does", "did", "will", "would", "could", "should", "may", "might",
                  "to", "of", "in", "for", "on", "with", "at", "by", "from", "as", "into",
                  "and", "but", "or", "not", "this", "that", "these", "those", "i", "me", "my"}

    # Word tokenizer shared by embeddings, BM25 and keyword extraction so that
    # punctuation ("deadline?") never prevents a match.
    _TOKEN_RE = re.compile(r"\w+")

    def __init__(self, embedding_dim: int = 384,
                 semantic_weight: float = 0.5,
                 bm25_weight: float = 0.3,
                 graph_weight: float = 0.2):
        """
        Initialize Mnemo.

        Args:
            embedding_dim: Dimension for embeddings (default 384)
            semantic_weight: Weight for semantic search (default 0.5)
            bm25_weight: Weight for BM25 keyword search (default 0.3)
            graph_weight: Weight for graph traversal (default 0.2)

        Raises:
            ValueError: If ``embedding_dim`` is not a positive integer.
        """
        if embedding_dim <= 0:
            raise ValueError("embedding_dim must be a positive integer")

        self.embedding_dim = embedding_dim
        self.semantic_weight = semantic_weight
        self.bm25_weight = bm25_weight
        self.graph_weight = graph_weight

        # Storage. _ids, _embeddings and _tokenized_docs are parallel lists:
        # position i in each refers to the same memory (and to row i in FAISS).
        self.memories: Dict[str, Memory] = {}
        self._embeddings: List[np.ndarray] = []
        self._ids: List[str] = []

        # FAISS index (inner product over normalized vectors)
        self.index = faiss.IndexFlatIP(embedding_dim) if HAS_FAISS else None

        # BM25
        self.bm25 = None
        self._tokenized_docs: List[List[str]] = []

        # Knowledge Graph
        self.graph = nx.DiGraph() if HAS_NETWORKX else None

        # Feedback learning
        self._doc_boosts: Dict[str, float] = defaultdict(float)
        self._query_doc_scores: Dict[str, Dict[str, float]] = defaultdict(dict)
        self._feedback_count = 0

        # Embedding cache
        self._cache: Dict[str, Any] = {}
        self._cache_lock = threading.Lock()

        # Stats
        self.stats = {
            "adds": 0,
            "searches": 0,
            "feedback": 0,
            "cache_hits": 0,
            "cache_misses": 0,
            "strategy_wins": defaultdict(int)
        }

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @classmethod
    def _tokenize(cls, text: str) -> List[str]:
        """Split text into lowercase word tokens, dropping punctuation."""
        return cls._TOKEN_RE.findall(text.lower())

    @staticmethod
    def _stable_hash(token: str) -> int:
        """Return a hash of ``token`` that is identical in every process.

        The builtin ``hash()`` is salted per interpreter run for ``str``, which
        would make stored embeddings incomparable across runs.
        """
        digest = hashlib.md5(token.encode("utf-8")).digest()
        return int.from_bytes(digest[:8], "big")

    @staticmethod
    def _query_key(query: str) -> str:
        """Build the key under which query-specific feedback is stored."""
        return " ".join(sorted(set(query.lower().split()))[:5])

    def _get_embedding(self, text: str) -> np.ndarray:
        """Generate embedding for text (hash-based, replace with real embeddings)"""
        # Check cache
        cache_key = f"emb:{hashlib.md5(text.encode('utf-8')).hexdigest()}"
        with self._cache_lock:
            cached = self._cache.get(cache_key)
            if cached is not None:
                self.stats["cache_hits"] += 1
                return cached
            self.stats["cache_misses"] += 1

        # Hash-based bag-of-words embedding: each token adds a position-decayed
        # weight to one bucket (replace with sentence-transformers for production).
        embedding = np.zeros(self.embedding_dim, dtype=np.float32)
        for position, token in enumerate(self._tokenize(text)):
            bucket = self._stable_hash(token) % self.embedding_dim
            embedding[bucket] += 1.0 / (position + 1)

        # Normalize so that the inner product equals cosine similarity
        norm = float(np.linalg.norm(embedding))
        if norm > 0:
            embedding /= norm

        with self._cache_lock:
            self._cache[cache_key] = embedding

        return embedding

    def _detect_intent(self, query: str) -> Tuple[QueryIntent, float]:
        """Detect query intent.

        Returns the first intent (in INTENT_PATTERNS order) with a matching
        pattern and confidence 0.85, or (EXPLORATORY, 0.5) when nothing matches.
        """
        query_lower = query.lower()

        for intent, patterns in self.INTENT_PATTERNS.items():
            if any(re.search(pattern, query_lower) for pattern in patterns):
                return intent, 0.85

        return QueryIntent.EXPLORATORY, 0.5

    def _extract_keywords(self, text: str) -> List[str]:
        """Extract keywords from text (tokens longer than 2 chars, minus stop words)"""
        return [w for w in self._tokenize(text)
                if w not in self.STOP_WORDS and len(w) > 2]

    def _rebuild_bm25(self):
        """Rebuild BM25 index from the tokenized documents"""
        if not HAS_BM25:
            return
        # BM25Okapi cannot be built from an empty corpus, and a corpus made
        # only of empty documents has an average length of zero.
        if any(self._tokenized_docs):
            self.bm25 = BM25Okapi(self._tokenized_docs)
        else:
            self.bm25 = None

    def _rebuild_faiss(self):
        """Rebuild the FAISS index from the stored embeddings"""
        if not HAS_FAISS:
            return
        self.index = faiss.IndexFlatIP(self.embedding_dim)
        if self._embeddings:
            self.index.add(np.vstack(self._embeddings).astype(np.float32))

    def _remove(self, memory_id: str, keep_feedback: bool = False) -> bool:
        """Remove a memory from every structure; return False if it is unknown."""
        if memory_id not in self.memories:
            return False
        del self.memories[memory_id]

        try:
            position = self._ids.index(memory_id)
        except ValueError:
            position = None
        if position is not None:
            # Keep the parallel lists (and therefore FAISS/BM25 rows) aligned.
            del self._ids[position]
            del self._embeddings[position]
            del self._tokenized_docs[position]
            self._rebuild_faiss()
            self._rebuild_bm25()

        if self.graph is not None and self.graph.has_node(memory_id):
            linked = list(self.graph.successors(memory_id))
            self.graph.remove_node(memory_id)
            # Drop keyword entities that no remaining memory points to.
            for node in linked:
                if (node not in self.memories
                        and self.graph.has_node(node)
                        and self.graph.in_degree(node) == 0):
                    self.graph.remove_node(node)

        if not keep_feedback:
            self._doc_boosts.pop(memory_id, None)
            for per_query in self._query_doc_scores.values():
                per_query.pop(memory_id, None)

        return True

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def add(self, content: str, metadata: Optional[Dict] = None,
            memory_id: Optional[str] = None) -> str:
        """
        Add a memory.

        Adding under an ID that already exists replaces the stored memory
        instead of leaving a duplicate row in the indexes; feedback recorded
        for that ID is kept.

        Args:
            content: Text content to store
            metadata: Optional metadata dict
            memory_id: Optional custom ID (derived from the content hash if
                not provided)

        Returns:
            Memory ID
        """
        # Generate ID
        if memory_id is None:
            memory_id = f"mem_{hashlib.md5(content.encode('utf-8')).hexdigest()[:8]}"

        # Replace instead of duplicating
        if memory_id in self.memories:
            self._remove(memory_id, keep_feedback=True)

        # Get embedding
        embedding = self._get_embedding(content)

        # Create memory
        memory = Memory(
            id=memory_id,
            content=content,
            embedding=embedding,
            metadata=metadata or {}
        )

        # Store
        self.memories[memory_id] = memory
        self._embeddings.append(embedding)
        self._ids.append(memory_id)

        # Update FAISS
        if HAS_FAISS and self.index is not None:
            self.index.add(embedding.reshape(1, -1))

        # Update BM25
        self._tokenized_docs.append(self._tokenize(content))
        self._rebuild_bm25()

        # Update graph
        if HAS_NETWORKX and self.graph is not None:
            # Attributes are written through the node dict rather than **kwargs
            # so metadata with a "content" key or non-string keys cannot raise.
            node_attrs = dict(memory.metadata)
            node_attrs["content"] = content
            self.graph.add_node(memory_id)
            self.graph.nodes[memory_id].update(node_attrs)
            # Extract and link entities (simplified): first 5 keywords
            for kw in self._extract_keywords(content)[:5]:
                entity_id = f"entity_{kw}"
                if not self.graph.has_node(entity_id):
                    self.graph.add_node(entity_id, type="keyword")
                self.graph.add_edge(memory_id, entity_id, relation="contains")

        self.stats["adds"] += 1
        return memory_id

    def search(self, query: str, top_k: int = 5) -> List[SearchResult]:
        """
        Search memories.

        Args:
            query: Search query
            top_k: Number of results to return

        Returns:
            List of SearchResult objects, best first. Empty when nothing is
            stored or ``top_k`` is not positive.
        """
        if not self.memories or top_k <= 0:
            return []

        self.stats["searches"] += 1

        # Get query embedding
        query_embedding = self._get_embedding(query)

        # Strategy 1: Semantic search
        semantic_scores: Dict[str, float] = {}
        candidates = min(top_k * 2, len(self._ids))
        if HAS_FAISS and self.index is not None and self.index.ntotal > 0:
            k = min(candidates, self.index.ntotal)
            scores, indices = self.index.search(query_embedding.reshape(1, -1), k)
            for score, idx in zip(scores[0], indices[0]):
                if 0 <= idx < len(self._ids):
                    semantic_scores[self._ids[idx]] = float(score)
        elif self._embeddings:
            # NumPy fallback when FAISS is unavailable: brute-force inner product.
            similarities = np.vstack(self._embeddings) @ query_embedding
            ranked = np.argsort(-similarities, kind="stable")[:candidates]
            for idx in ranked:
                semantic_scores[self._ids[idx]] = float(similarities[idx])

        # Strategy 2: BM25 keyword search
        bm25_scores: Dict[str, float] = {}
        if HAS_BM25 and self.bm25 is not None:
            raw_scores = self.bm25.get_scores(self._tokenize(query))
            best = float(raw_scores.max()) if len(raw_scores) else 0.0
            if best > 0:
                # Keep documents scoring above 10% of the best, scaled to (0, 1]
                for idx, score in enumerate(raw_scores):
                    if score > 0.1 * best:
                        bm25_scores[self._ids[idx]] = float(score / best)

        # Strategy 3: Graph search (simplified): 0.5 per shared keyword entity
        graph_scores: Dict[str, float] = {}
        if HAS_NETWORKX and self.graph is not None:
            for kw in self._extract_keywords(query):
                entity_id = f"entity_{kw}"
                if not self.graph.has_node(entity_id):
                    continue
                for neighbor in self.graph.predecessors(entity_id):
                    # Membership test instead of a "mem_" prefix check so that
                    # memories stored under custom IDs are found too.
                    if neighbor in self.memories:
                        graph_scores[neighbor] = graph_scores.get(neighbor, 0) + 0.5

        # Combine scores
        all_ids = set(semantic_scores) | set(bm25_scores) | set(graph_scores)

        results = []
        for mem_id in all_ids:
            memory = self.memories.get(mem_id)
            if memory is None:
                continue

            strategy_scores = {
                "semantic": semantic_scores.get(mem_id, 0),
                "bm25": bm25_scores.get(mem_id, 0),
                "graph": graph_scores.get(mem_id, 0)
            }

            # Weighted combination
            combined = (
                self.semantic_weight * strategy_scores["semantic"] +
                self.bm25_weight * strategy_scores["bm25"] +
                self.graph_weight * strategy_scores["graph"]
            )

            # Apply feedback boost
            combined += self._get_feedback_adjustment(query, mem_id) * 0.2

            results.append(SearchResult(
                id=mem_id,
                content=memory.content,
                score=combined,
                strategy_scores=strategy_scores,
                metadata=memory.metadata
            ))

        # Sort by score; ties are broken by ID so the order does not depend on
        # set iteration order.
        results.sort(key=lambda r: (-r.score, r.id))

        # Track winning strategy
        if results:
            top_scores = results[0].strategy_scores
            self.stats["strategy_wins"][max(top_scores, key=top_scores.get)] += 1

        return results[:top_k]

    def feedback(self, query: str, memory_id: str, relevance: float):
        """
        Record feedback to improve future searches.

        Args:
            query: The search query
            memory_id: ID of the memory
            relevance: Relevance score (-1 to 1, negative = irrelevant);
                values outside that range are clamped
        """
        relevance = max(-1, min(1, relevance))  # Clamp

        # Update global doc boost
        self._doc_boosts[memory_id] += 0.1 * relevance

        # Update query-specific score
        per_query = self._query_doc_scores[self._query_key(query)]
        per_query[memory_id] = per_query.get(memory_id, 0) + 0.1 * relevance

        self._feedback_count += 1
        self.stats["feedback"] += 1

    def _get_feedback_adjustment(self, query: str, memory_id: str) -> float:
        """Get feedback-based score adjustment (30% global boost, 70% query-specific)"""
        # .get() is used so lookups never create entries in the defaultdicts.
        global_boost = self._doc_boosts.get(memory_id, 0)
        query_boost = self._query_doc_scores.get(self._query_key(query), {}).get(memory_id, 0)

        return global_boost * 0.3 + query_boost * 0.7

    def get(self, memory_id: str) -> Optional[Memory]:
        """Get a specific memory by ID, or None if it does not exist"""
        return self.memories.get(memory_id)

    def delete(self, memory_id: str) -> bool:
        """Delete a memory from storage, all indexes, the graph and feedback.

        Returns:
            True if the memory existed and was removed, False otherwise.
        """
        return self._remove(memory_id)

    def get_stats(self) -> Dict:
        """Get system statistics"""
        hits = self.stats["cache_hits"]
        lookups = hits + self.stats["cache_misses"]
        return {
            "total_memories": len(self.memories),
            "adds": self.stats["adds"],
            "searches": self.stats["searches"],
            "feedback_count": self.stats["feedback"],
            "cache_hit_rate": f"{hits / max(1, lookups):.1%}",
            "strategy_wins": dict(self.stats["strategy_wins"]),
            "has_faiss": HAS_FAISS,
            "has_bm25": HAS_BM25,
            "has_graph": HAS_NETWORKX
        }

    def get_knowledge_graph(self):
        """Get the knowledge graph (None when networkx is unavailable)"""
        return self.graph

    def clear(self):
        """Clear all memories, indexes, cached embeddings and recorded feedback.

        Counters in ``stats`` are left untouched.
        """
        self.memories.clear()
        self._embeddings.clear()
        self._ids.clear()
        self._tokenized_docs.clear()
        self.bm25 = None

        # Feedback refers to memory IDs that no longer exist.
        self._doc_boosts.clear()
        self._query_doc_scores.clear()

        with self._cache_lock:
            self._cache.clear()

        if HAS_FAISS:
            self.index = faiss.IndexFlatIP(self.embedding_dim)

        if HAS_NETWORKX:
            self.graph = nx.DiGraph()

    def __len__(self):
        return len(self.memories)

    def __repr__(self):
        return f"Mnemo(memories={len(self.memories)}, embedding_dim={self.embedding_dim})"
453
+
454
+
455
+ # =============================================================================
456
+ # CONVENIENCE FUNCTIONS
457
+ # =============================================================================
458
+
459
def create_memory(embedding_dim: int = 384) -> Mnemo:
    """Build and return a fresh Mnemo store with the given embedding size."""
    store = Mnemo(embedding_dim=embedding_dim)
    return store
462
+
463
+
464
+ # =============================================================================
465
+ # DEMO
466
+ # =============================================================================
467
+
468
def demo():
    """Quick demo of Mnemo: add a few memories, search, record feedback, print stats."""
    print("=" * 50)
    print("MNEMO DEMO")
    print("=" * 50)

    m = Mnemo()

    # Add memories
    memories = [
        "User prefers dark mode and receives notifications in the morning",
        "Project deadline is March 15th for the API redesign",
        "Team standup meeting every Tuesday at 2pm in room 401",
        "Favorite coffee is cappuccino with oat milk, no sugar",
        "Working on machine learning model for customer churn prediction"
    ]

    print("\n📝 Adding memories...")
    for mem in memories:
        mem_id = m.add(mem)
        print(f" Added: {mem_id}")

    # Search
    queries = [
        "What are the user's notification preferences?",
        "When is the project deadline?",
        "Coffee order",
    ]

    print("\n🔍 Searching...")
    for query in queries:
        print(f"\n Query: '{query}'")
        results = m.search(query, top_k=2)
        for r in results:
            print(f" → [{r.id}] score={r.score:.3f}")
            print(f" {r.content[:60]}...")

    # Feedback: rate the top hit of a real search rather than a made-up ID,
    # so the recorded feedback refers to a memory that exists.
    print("\n👍 Recording feedback...")
    feedback_query = "notification preferences"
    feedback_hits = m.search(feedback_query, top_k=1)
    if feedback_hits:
        m.feedback(feedback_query, feedback_hits[0].id, relevance=0.9)
        print(" Feedback recorded")
    else:
        print(" No results to give feedback on")

    # Stats
    print("\n📊 Stats:")
    stats = m.get_stats()
    for k, v in stats.items():
        print(f" {k}: {v}")

    print("\n" + "=" * 50)
    print("✅ Demo complete!")
    print("=" * 50)


if __name__ == "__main__":
    demo()