# SCoDA / tests/test_memory.py
# Commit 9281fab by vanishingradient: "Added init files"
"""
Unit tests for the SharedMemory class.
"""
import pytest
from datetime import datetime
import threading
import time
from coda.core.memory import SharedMemory, MemoryEntry
class TestSharedMemory:
    """Tests for the SharedMemory class."""

    @pytest.fixture
    def memory(self):
        """Create a fresh SharedMemory instance."""
        return SharedMemory()

    def test_store_and_retrieve(self, memory):
        """Test basic store and retrieve operations."""
        memory.store(
            key="test_key",
            value={"data": "value"},
            agent_name="TestAgent",
        )
        result = memory.retrieve("test_key")
        assert result == {"data": "value"}

    def test_retrieve_nonexistent(self, memory):
        """Test retrieving a non-existent key."""
        result = memory.retrieve("nonexistent")
        assert result is None

    def test_retrieve_entry(self, memory):
        """Test retrieving the full entry with metadata."""
        memory.store(
            key="test_key",
            value="test_value",
            agent_name="TestAgent",
            metadata={"extra": "info"},
        )
        entry = memory.retrieve_entry("test_key")
        assert entry is not None
        assert entry.value == "test_value"
        assert entry.agent_name == "TestAgent"
        assert entry.metadata == {"extra": "info"}
        assert isinstance(entry.timestamp, datetime)

    def test_get_context(self, memory):
        """Test retrieving multiple keys as context."""
        memory.store("key1", "value1", "Agent1")
        memory.store("key2", "value2", "Agent2")
        memory.store("key3", "value3", "Agent3")
        context = memory.get_context(["key1", "key3", "nonexistent"])
        # Missing keys are expected to be omitted, not mapped to None.
        assert context == {"key1": "value1", "key3": "value3"}

    def test_get_all(self, memory):
        """Test retrieving all stored values."""
        memory.store("key1", "value1", "Agent")
        memory.store("key2", "value2", "Agent")
        all_data = memory.get_all()
        assert all_data == {"key1": "value1", "key2": "value2"}

    def test_overwrite_value(self, memory):
        """Test overwriting an existing value."""
        memory.store("key", "original", "Agent")
        memory.store("key", "updated", "Agent")
        assert memory.retrieve("key") == "updated"

    def test_history_tracking(self, memory):
        """Test that history is tracked for all operations."""
        memory.store("key1", "v1", "Agent1")
        memory.store("key2", "v2", "Agent2")
        memory.store("key1", "v1_updated", "Agent1")
        history = memory.get_history()
        # Overwrites append a new history entry rather than replacing one.
        assert len(history) == 3
        assert history[0].key == "key1"
        assert history[1].key == "key2"
        assert history[2].value == "v1_updated"

    def test_history_filter_by_agent(self, memory):
        """Test filtering history by agent name."""
        memory.store("k1", "v1", "Agent1")
        memory.store("k2", "v2", "Agent2")
        memory.store("k3", "v3", "Agent1")
        agent1_history = memory.get_history(agent_name="Agent1")
        assert len(agent1_history) == 2
        assert all(e.agent_name == "Agent1" for e in agent1_history)

    def test_has_key(self, memory):
        """Test key existence check."""
        memory.store("exists", "value", "Agent")
        assert memory.has_key("exists") is True
        assert memory.has_key("not_exists") is False

    def test_clear(self, memory):
        """Test clearing all data."""
        memory.store("k1", "v1", "Agent")
        memory.store("k2", "v2", "Agent")
        memory.clear()
        assert memory.retrieve("k1") is None
        assert memory.retrieve("k2") is None
        # clear() is expected to wipe the history as well as the values.
        assert len(memory.get_history()) == 0

    def test_keys(self, memory):
        """Test getting all keys."""
        memory.store("a", 1, "Agent")
        memory.store("b", 2, "Agent")
        memory.store("c", 3, "Agent")
        keys = memory.keys()
        assert set(keys) == {"a", "b", "c"}

    def test_thread_safety(self, memory):
        """Test that concurrent writers and a reader neither raise nor lose writes.

        Several writer threads each store a disjoint set of keys while one
        reader thread repeatedly snapshots the memory. Afterwards no thread
        may have raised, no thread may still be running, and every write
        must be visible with the value that was stored.
        """
        num_writers = 3
        writes_per_writer = 100
        errors = []
        # Every worker waits on this barrier so they all begin together;
        # otherwise an early thread may finish before the others start and
        # the test exercises almost no real contention. The timeout turns a
        # never-arriving party into BrokenBarrierError instead of a hang.
        start = threading.Barrier(num_writers + 1, timeout=10)

        def writer(n):
            try:
                start.wait()
                for i in range(writes_per_writer):
                    memory.store(f"key_{n}_{i}", i, f"Agent{n}")
            except Exception as e:  # collected here, asserted on below
                errors.append(e)

        def reader():
            try:
                start.wait()
                for _ in range(100):
                    memory.get_all()
                    memory.keys()
            except Exception as e:  # collected here, asserted on below
                errors.append(e)

        # Daemon threads so a deadlocked worker cannot keep the interpreter
        # alive after the test has already failed.
        threads = [
            threading.Thread(target=writer, args=(n,), daemon=True)
            for n in range(num_writers)
        ]
        threads.append(threading.Thread(target=reader, daemon=True))
        for t in threads:
            t.start()
        for t in threads:
            # Bounded join: a deadlock fails the test instead of hanging it.
            t.join(timeout=10)

        assert not any(t.is_alive() for t in threads), "worker thread did not finish"
        assert not errors, f"errors raised in worker threads: {errors!r}"
        # Every key is distinct, so nothing may be lost or overwritten.
        expected = {
            f"key_{n}_{i}": i
            for n in range(num_writers)
            for i in range(writes_per_writer)
        }
        assert memory.get_all() == expected