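# NOTE(review): the next three lines look like hosting-page residue rather
# than part of the script; kept as comments so the module still parses.
# Spaces:
# Sleeping
# Sleeping
"""
Test FAISS memory system with deduplication and semantic search.
Run with: python test_memory.py
"""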
| import os | |
| import sys | |
| import logging | |
| import tempfile | |
| import shutil | |
# Configure logging once for the whole script: every record carries a
# timestamp, the logger name, the level and the message.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger shared by every test function in this file.
logger = logging.getLogger(__name__)
def test_basic_memory():
    """Exercise add, stats and search on a fresh ``AgentMemory``.

    Stores four task/result pairs, verifies the reported item count, then
    runs a paraphrased query. When the search returns anything, the top
    hit's similarity must exceed 0.5; an empty result is tolerated.
    """
    from core.memory import AgentMemory
    from core.embeddings import EmbeddingModel

    logger.info("=== Test 1: Basic Memory Operations ===")
    memory = AgentMemory(embedder=EmbeddingModel())

    # Task -> expected result; dict order is the insertion order.
    samples = {
        "Calculate the sum of 2 and 2": "4",
        "What is 5 multiplied by 3?": "15",
        "Convert 100 Fahrenheit to Celsius": "37.78",
        "List the first 5 prime numbers": "2, 3, 5, 7, 11",
    }
    for prompt, answer in samples.items():
        stored = memory.add(prompt, answer)
        assert stored, f"Failed to add: {prompt}"

    # Every sample must have been stored.
    stats = memory.get_stats()
    logger.info(f"Memory stats: {stats}")
    expected_count = len(samples)
    assert stats["total_items"] == expected_count, f"Expected {expected_count}, got {stats['total_items']}"

    # Query with a paraphrase of the first task.
    hits = memory.search("What is 2+2?", k=2)
    logger.info(f"Search results for '2+2': {len(hits)} found")
    if hits:
        best = hits[0]
        logger.info(f"Top result: {best['task']} -> {best['result']} (similarity={best['similarity']:.3f})")
        # The "2 and 2" task is the expected match.
        assert best['similarity'] > 0.5, "Expected high similarity"

    logger.info("✓ Basic memory operations passed\n")
def test_deduplication():
    """Check that ``AgentMemory`` refuses to store a repeated task.

    The memory is built with ``dedup_threshold=0.95``. An exact repeat of a
    stored task must be rejected by ``add``. A reworded variant may or may
    not be stored, so only an upper bound on the item count is enforced.
    ``is_duplicate`` must also flag the original task text.
    """
    from core.memory import AgentMemory
    from core.embeddings import EmbeddingModel

    logger.info("=== Test 2: Deduplication ===")
    embedder = EmbeddingModel()
    memory = AgentMemory(embedder=embedder, dedup_threshold=0.95)

    # Add original task
    task1 = "Calculate the square root of 16"
    result1 = "4"
    added1 = memory.add(task1, result1)
    assert added1, "Failed to add original task"
    logger.info(f"Added original: {task1}")

    # Re-add the identical task text with a different result string.
    task2 = "Calculate the square root of 16"  # Exact duplicate
    result2 = "4.0"
    added2 = memory.add(task2, result2)
    logger.info(f"Duplicate detection for exact match: {'Blocked' if not added2 else 'Added'}")
    # Previously this outcome was only logged; an exact repeat must be
    # rejected, otherwise deduplication in ``add`` is not actually tested.
    assert not added2, "Exact duplicate should be blocked"

    # A reworded variant: either outcome is acceptable here.
    task3 = "What is the square root of 16?"
    result3 = "4"
    added3 = memory.add(task3, result3)
    logger.info(f"Similar but different: {'Blocked' if not added3 else 'Added'}")

    # At most the original and the reworded variant may be present.
    stats = memory.get_stats()
    logger.info(f"Total items after dedup test: {stats['total_items']}")
    assert stats['total_items'] <= 2, f"Deduplication failed, expected <=2, got {stats['total_items']}"

    # Explicit duplicate check
    is_dup = memory.is_duplicate("Calculate the square root of 16")
    logger.info(f"Explicit duplicate check: {is_dup}")
    assert is_dup, "Should detect duplicate"

    logger.info("✓ Deduplication passed\n")
def test_semantic_search():
    """Run topical queries against a memory holding mixed-topic tasks.

    Six tasks (capitals, baking, maths) are stored in a memory built with
    ``similarity_threshold=0.5``. Three queries are issued and their hits
    logged. Only the capital-city query is asserted on: if it returns
    anything, the top hit's task must mention "capital".
    """
    from core.memory import AgentMemory
    from core.embeddings import EmbeddingModel

    logger.info("=== Test 3: Semantic Search ===")
    memory = AgentMemory(embedder=EmbeddingModel(), similarity_threshold=0.5)

    def _log_hits(label, hits):
        # Log the hit count, then one line per hit (task cut to 50 chars).
        logger.info(f"Search '{label}' found {len(hits)} results:")
        for hit in hits:
            logger.info(f" - {hit['task'][:50]} (sim={hit['similarity']:.3f})")

    # Seed the memory with tasks from three unrelated topics.
    corpus = (
        ("What is the capital of France?", "Paris"),
        ("What is the capital of Germany?", "Berlin"),
        ("How do I bake a chocolate cake?", "Mix flour, sugar, eggs, cocoa..."),
        ("What's the recipe for cookies?", "Mix butter, sugar, flour..."),
        ("Solve the equation x + 5 = 10", "x = 5"),
        ("What is 15 divided by 3?", "5"),
    )
    for prompt, answer in corpus:
        memory.add(prompt, answer)

    # Capital-city query: the only one with an assertion attached.
    capital_hits = memory.search("What is the capital of Spain?", k=3)
    _log_hits("capital of Spain", capital_hits)
    if capital_hits:
        top_task = capital_hits[0]["task"]
        assert "capital" in top_task.lower(), "Should find capital-related tasks"

    # Maths and baking queries are logged for inspection only.
    _log_hits("solve equation", memory.search("Solve x + 10 = 20", k=3))
    _log_hits("make brownies", memory.search("How to make brownies?", k=3))

    logger.info("✓ Semantic search passed\n")
def test_persistence():
    """Round-trip a memory through ``save`` and ``load``.

    Three tasks are stored and saved under a temporary directory. The test
    then checks that the ``.index`` and ``.meta`` files exist, that a second
    memory loaded from the same path reports the same item count and
    dimension, and that searching it for "Task 1" returns a matching task.
    The temporary directory is removed whether or not the checks pass.
    """
    from core.memory import AgentMemory
    from core.embeddings import EmbeddingModel

    logger.info("=== Test 4: Persistence (Save/Load) ===")

    scratch_dir = tempfile.mkdtemp()
    try:
        save_path = os.path.join(scratch_dir, "test_memory")

        # Build the source memory.
        embedder = EmbeddingModel()
        source_memory = AgentMemory(embedder=embedder)
        for n in (1, 2, 3):
            source_memory.add(f"Task {n}", f"Result {n}", metadata={"source": "test"})

        # Write it out and confirm both artefacts were produced.
        source_memory.save(save_path)
        logger.info(f"Saved memory to {save_path}")
        expected_files = (
            (".index", "Index file not created"),
            (".meta", "Metadata file not created"),
        )
        for suffix, complaint in expected_files:
            assert os.path.exists(save_path + suffix), complaint

        # Read it back into a second, independent memory object.
        restored_memory = AgentMemory(embedder=embedder)
        restored_memory.load(save_path)
        logger.info(f"Loaded memory from {save_path}")

        # Both memories must agree on size and embedding dimension.
        source_stats = source_memory.get_stats()
        restored_stats = restored_memory.get_stats()
        assert source_stats["total_items"] == restored_stats["total_items"], "Item count mismatch"
        assert source_stats["dimension"] == restored_stats["dimension"], "Dimension mismatch"
        logger.info(f"Loaded {restored_stats['total_items']} items with dim={restored_stats['dimension']}")

        # The restored memory must still be searchable.
        hits = restored_memory.search("Task 1", k=1)
        assert len(hits) > 0, "Search in loaded memory failed"
        top_task = hits[0]["task"]
        assert "Task 1" in top_task, "Loaded data doesn't match"
        logger.info(f"Search in loaded memory: {top_task}")

        logger.info("✓ Persistence passed\n")
    finally:
        # Always remove the scratch directory, even after a failure.
        shutil.rmtree(scratch_dir)
        logger.info(f"Cleaned up {scratch_dir}")
def test_threshold_behavior():
    """Compare search results under a strict and a lenient threshold.

    The same single task is stored in two memories built with
    ``similarity_threshold`` 0.9 and 0.3, and the same query is run against
    each. The lenient memory must return at least as many hits as the
    strict one.
    """
    from core.memory import AgentMemory
    from core.embeddings import EmbeddingModel

    logger.info("=== Test 5: Threshold Behavior ===")
    embedder = EmbeddingModel()

    def _hits_at(threshold):
        # Fresh memory per threshold so the two runs cannot affect each other.
        mem = AgentMemory(embedder=embedder, similarity_threshold=threshold)
        mem.add("Python programming language", "A high-level language")
        return mem.search("Java programming", k=5)

    strict_hits = _hits_at(0.9)
    logger.info(f"Strict threshold (0.9): {len(strict_hits)} results")

    lenient_hits = _hits_at(0.3)
    logger.info(f"Lenient threshold (0.3): {len(lenient_hits)} results")

    # Lowering the threshold must never shrink the result set.
    assert len(lenient_hits) >= len(strict_hits), "Lenient threshold should find more results"

    logger.info("✓ Threshold behavior passed\n")
def test_metadata():
    """Verify that metadata passed to ``add`` comes back from ``search``.

    One task is stored with a three-key metadata dict. The top search hit
    must expose a ``"metadata"`` entry containing ``execution_time`` with
    the stored value 1.5.
    """
    from core.memory import AgentMemory
    from core.embeddings import EmbeddingModel

    logger.info("=== Test 6: Metadata ===")
    memory = AgentMemory(embedder=EmbeddingModel())

    # Store a single task together with its metadata.
    memory.add(
        task="Complex calculation",
        result="42",
        metadata={"execution_time": 1.5, "tokens": 100, "model": "test-model"},
    )

    # Fetch it back and inspect the attached metadata.
    hits = memory.search("calculation", k=1)
    assert len(hits) > 0, "Search failed"

    meta = hits[0]["metadata"]
    logger.info(f"Retrieved metadata: {meta}")
    assert "execution_time" in meta, "Metadata missing"
    assert meta["execution_time"] == 1.5, "Metadata value incorrect"

    logger.info("✓ Metadata passed\n")
def test_clear():
    """Verify that ``clear`` empties the memory.

    Five tasks are added and the item count is checked to be 5; after
    ``clear()`` the reported count must be 0.
    """
    from core.memory import AgentMemory
    from core.embeddings import EmbeddingModel

    logger.info("=== Test 7: Clear Memory ===")
    memory = AgentMemory(embedder=EmbeddingModel())

    def _item_count():
        # Current number of stored items as reported by the memory itself.
        return memory.get_stats()["total_items"]

    # Populate the memory.
    for idx in range(5):
        memory.add(f"Task {idx}", f"Result {idx}")

    count_before = _item_count()
    logger.info(f"Before clear: {count_before} items")
    assert count_before == 5

    # Wipe it and confirm nothing is left.
    memory.clear()
    count_after = _item_count()
    logger.info(f"After clear: {count_after} items")
    assert count_after == 0, "Memory not cleared"

    logger.info("✓ Clear passed\n")
def run_all_tests():
    """Run every memory test in order, stopping at the first failure.

    On success a banner is logged. Any exception raised by a test is logged
    with its traceback and the process exits with status 1.
    """
    logger.info("Starting FAISS Memory System Tests\n")
    try:
        # Execution order matches the test numbering in the log headers.
        suite = (
            test_basic_memory,
            test_deduplication,
            test_semantic_search,
            test_persistence,
            test_threshold_behavior,
            test_metadata,
            test_clear,
        )
        for case in suite:
            case()

        rule = "=" * 50
        logger.info(rule)
        logger.info("All tests passed! ✓")
        logger.info(rule)
    except Exception as e:
        logger.error(f"Test failed: {e}", exc_info=True)
        sys.exit(1)


if __name__ == "__main__":
    run_all_tests()