Spaces:
Sleeping
Sleeping
| import sys | |
| import os | |
| sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) | |
| import asyncio | |
| import sqlite3 | |
| from app.services.memory import MemoryService | |
# Throwaway SQLite file used only by this script. It is wiped before the tests
# start so every run begins from an empty store. Note: this is a relative path,
# so it is resolved against the current working directory.
TEST_DB: str = "test_memory.db"
def _remove_test_db(*, warn_if_locked=False):
    """Delete the ``TEST_DB`` file, treating an already-missing file as success.

    Args:
        warn_if_locked: When True, a ``PermissionError`` (for example, the file
            is still held open by a connection) is reported as a warning
            instead of raised, so cleanup never masks the real test outcome.

    Raises:
        PermissionError: If the file cannot be removed and ``warn_if_locked``
            is False.
    """
    try:
        os.remove(TEST_DB)
    except FileNotFoundError:
        # Nothing to clean up; EAFP avoids the exists()/remove() race.
        pass
    except PermissionError:
        if not warn_if_locked:
            raise
        print("Warning: Could not remove test db file (locked).")


async def test_memory_service():
    """Smoke-test ``MemoryService`` against a fresh on-disk database.

    Covers:
      1. Short-term memory: interactions are stored in order and scoped per user.
      2. Long-term memory: facts are stored and scoped per user.
      3. Persistence: a new ``MemoryService`` on the same file sees earlier facts.

    The database file is removed before the run and, via ``finally``, after it,
    even when an assertion fails. Raises ``AssertionError`` on the first
    failed check.
    """
    print("Starting Memory Service Tests...")

    # Setup: start from an empty database so the exact counts below hold.
    _remove_test_db()
    memory = MemoryService(db_path=TEST_DB)
    memory_new = None
    user1 = "user_123"
    user2 = "user_456"

    try:
        # 1. Test Short-Term Memory (Interactions)
        print("Testing Short-Term Memory...")
        await memory.add_interaction(user1, "user", "Hello ORA")
        await memory.add_interaction(user1, "assistant", "Hello User")
        history = await memory.get_short_term_memory(user1)
        assert len(history) == 2, f"Expected 2 messages, got {len(history)}"
        assert history[0]["role"] == "user", f"Unexpected role: {history[0]['role']!r}"
        assert history[0]["content"] == "Hello ORA", f"Unexpected content: {history[0]['content']!r}"

        # Test User Scoping: user2's message must not leak into user1's history.
        print("Testing User Scoping...")
        await memory.add_interaction(user2, "user", "I am user 2")
        history_u1 = await memory.get_short_term_memory(user1)
        history_u2 = await memory.get_short_term_memory(user2)
        assert len(history_u1) == 2, f"Expected 2 messages for {user1}, got {len(history_u1)}"
        assert len(history_u2) == 1, f"Expected 1 message for {user2}, got {len(history_u2)}"
        assert history_u2[0]["content"] == "I am user 2"

        # 2. Test Long-Term Memory (Facts)
        print("Testing Long-Term Memory...")
        await memory.add_fact(user1, "User lives in New York")
        await memory.add_fact(user1, "User likes pizza")
        facts = await memory.get_long_term_memory(user1)
        assert len(facts) == 2, f"Expected 2 facts, got {len(facts)}"
        assert "User likes pizza" in facts
        assert "User lives in New York" in facts
        facts_u2 = await memory.get_long_term_memory(user2)
        assert len(facts_u2) == 0, f"Expected no facts for {user2}, got {len(facts_u2)}"

        # 3. Test Persistence (Re-instantiate service on the same file)
        print("Testing Persistence...")
        memory = None  # drop the first instance before reopening the file
        memory_new = MemoryService(db_path=TEST_DB)
        facts_reloaded = await memory_new.get_long_term_memory(user1)
        assert len(facts_reloaded) == 2, f"Expected 2 reloaded facts, got {len(facts_reloaded)}"
        # Check content too, not just the count.
        assert "User likes pizza" in facts_reloaded
        assert "User lives in New York" in facts_reloaded
    finally:
        # Drop service references so any connections they own can be released
        # before deleting the file. NOTE(review): whether MemoryService closes
        # its connection on garbage collection is not visible here; if it
        # exposes an explicit close(), calling it would be more reliable.
        memory = memory_new = None
        _remove_test_db(warn_if_locked=True)

    print("ALL TESTS PASSED!")
if __name__ == "__main__":
    # Script entry point: asyncio.run creates a fresh event loop, drives the
    # test coroutine to completion, and then closes the loop.
    asyncio.run(main=test_memory_service())