Spaces:
Sleeping
Sleeping
File size: 2,337 Bytes
5e0532d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import asyncio
import sqlite3
from app.services.memory import MemoryService
# Scratch SQLite database file used by test_memory_service(). This is a
# relative path, so it resolves against the current working directory.
TEST_DB: str = "test_memory.db"
async def test_memory_service():
    """Exercise MemoryService end to end against a scratch SQLite file.

    Checks, in order:
      1. short-term interaction history for one user,
      2. that interactions are scoped per user,
      3. long-term fact storage and per-user scoping,
      4. that facts survive re-instantiating the service on the same file.

    The database file at ``TEST_DB`` is removed before the run. It is also
    removed (best effort) afterwards, including when an assertion fails.

    NOTE(review): this is an ``async def`` named ``test_*``. Plain pytest does
    not run coroutine test functions without an async plugin (e.g.
    pytest-asyncio). The ``__main__`` guard runs it via ``asyncio.run``.
    """
    print("Starting Memory Service Tests...")

    # Setup: start from a fresh database file. A missing file is fine. Any
    # other error (e.g. a locked stale file) should fail the run loudly.
    try:
        os.remove(TEST_DB)
    except FileNotFoundError:
        pass

    user1 = "user_123"
    user2 = "user_456"

    try:
        memory = MemoryService(db_path=TEST_DB)

        # 1. Short-term memory (interactions)
        print("Testing Short-Term Memory...")
        await memory.add_interaction(user1, "user", "Hello ORA")
        await memory.add_interaction(user1, "assistant", "Hello User")
        history = await memory.get_short_term_memory(user1)
        assert len(history) == 2, f"Expected 2 messages, got {len(history)}"
        # Assumes history is returned oldest-first -- TODO confirm against
        # MemoryService.get_short_term_memory.
        assert history[0]["role"] == "user", history
        assert history[0]["content"] == "Hello ORA", history
        assert history[1]["role"] == "assistant", history

        # Interactions must be scoped per user.
        print("Testing User Scoping...")
        await memory.add_interaction(user2, "user", "I am user 2")
        history_u1 = await memory.get_short_term_memory(user1)
        history_u2 = await memory.get_short_term_memory(user2)
        assert len(history_u1) == 2, history_u1
        assert len(history_u2) == 1, history_u2
        assert history_u2[0]["content"] == "I am user 2", history_u2

        # 2. Long-term memory (facts)
        print("Testing Long-Term Memory...")
        await memory.add_fact(user1, "User lives in New York")
        await memory.add_fact(user1, "User likes pizza")
        facts = await memory.get_long_term_memory(user1)
        assert len(facts) == 2, facts
        assert "User lives in New York" in facts, facts
        assert "User likes pizza" in facts, facts
        facts_u2 = await memory.get_long_term_memory(user2)
        assert len(facts_u2) == 0, facts_u2

        # 3. Persistence: a fresh instance on the same file must see the facts.
        print("Testing Persistence...")
        # Drop our reference to the first instance before reopening the file.
        del memory
        memory_new = MemoryService(db_path=TEST_DB)
        facts_reloaded = await memory_new.get_long_term_memory(user1)
        assert len(facts_reloaded) == 2, facts_reloaded
        assert "User likes pizza" in facts_reloaded, facts_reloaded
        # Drop the last reference so teardown has the best chance of deleting
        # the file. Presumably this matters only if MemoryService keeps a
        # connection open (e.g. file locking on Windows) -- not verifiable here.
        del memory_new
    finally:
        # Runs on both success and failure so no stale DB file is left behind.
        _remove_test_db()

    print("ALL TESTS PASSED!")


def _remove_test_db():
    """Best-effort removal of the scratch database file at ``TEST_DB``.

    A missing file is not an error. If removal fails with ``PermissionError``
    (the file is still in use/locked), a warning is printed instead of
    masking the test outcome.
    """
    try:
        os.remove(TEST_DB)
    except FileNotFoundError:
        pass
    except PermissionError:
        print("Warning: Could not remove test db file (locked).")
if __name__ == "__main__":
    # Script entry point: build the coroutine for the async test and drive it
    # to completion on a fresh event loop.
    _test_coro = test_memory_service()
    asyncio.run(_test_coro)
|