Spaces:
Sleeping
Sleeping
File size: 5,112 Bytes
9281fab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
"""
Unit tests for the SharedMemory class.
"""
import pytest
from datetime import datetime
import threading
import time
from coda.core.memory import SharedMemory, MemoryEntry
class TestSharedMemory:
    """Tests for the SharedMemory class.

    Each test receives its own ``SharedMemory`` from the function-scoped
    ``memory`` fixture, so no state is shared between tests.
    """

    @pytest.fixture
    def memory(self):
        """Create a fresh SharedMemory instance."""
        return SharedMemory()

    def test_store_and_retrieve(self, memory):
        """Test basic store and retrieve operations."""
        memory.store(
            key="test_key",
            value={"data": "value"},
            agent_name="TestAgent",
        )
        result = memory.retrieve("test_key")
        assert result == {"data": "value"}

    def test_retrieve_nonexistent(self, memory):
        """Test that retrieving a key that was never stored returns None."""
        result = memory.retrieve("nonexistent")
        assert result is None

    def test_retrieve_entry(self, memory):
        """Test retrieving the full entry with metadata."""
        memory.store(
            key="test_key",
            value="test_value",
            agent_name="TestAgent",
            metadata={"extra": "info"},
        )
        entry = memory.retrieve_entry("test_key")
        assert entry is not None
        assert entry.value == "test_value"
        assert entry.agent_name == "TestAgent"
        assert entry.metadata == {"extra": "info"}
        assert isinstance(entry.timestamp, datetime)

    def test_get_context(self, memory):
        """Test retrieving several keys at once; unknown keys are omitted."""
        memory.store("key1", "value1", "Agent1")
        memory.store("key2", "value2", "Agent2")
        memory.store("key3", "value3", "Agent3")
        context = memory.get_context(["key1", "key3", "nonexistent"])
        assert context == {"key1": "value1", "key3": "value3"}

    def test_get_all(self, memory):
        """Test retrieving all stored values as a key -> value dict."""
        memory.store("key1", "value1", "Agent")
        memory.store("key2", "value2", "Agent")
        all_data = memory.get_all()
        assert all_data == {"key1": "value1", "key2": "value2"}

    def test_overwrite_value(self, memory):
        """Test that storing to an existing key replaces its value."""
        memory.store("key", "original", "Agent")
        memory.store("key", "updated", "Agent")
        assert memory.retrieve("key") == "updated"

    def test_history_tracking(self, memory):
        """Test that every store is recorded in history, oldest first.

        Overwriting a key adds a new history entry rather than replacing one.
        """
        memory.store("key1", "v1", "Agent1")
        memory.store("key2", "v2", "Agent2")
        memory.store("key1", "v1_updated", "Agent1")
        history = memory.get_history()
        assert len(history) == 3
        assert history[0].key == "key1"
        assert history[1].key == "key2"
        assert history[2].value == "v1_updated"

    def test_history_filter_by_agent(self, memory):
        """Test filtering history by agent name."""
        memory.store("k1", "v1", "Agent1")
        memory.store("k2", "v2", "Agent2")
        memory.store("k3", "v3", "Agent1")
        agent1_history = memory.get_history(agent_name="Agent1")
        assert len(agent1_history) == 2
        assert all(e.agent_name == "Agent1" for e in agent1_history)

    def test_has_key(self, memory):
        """Test key existence check."""
        memory.store("exists", "value", "Agent")
        assert memory.has_key("exists") is True
        assert memory.has_key("not_exists") is False

    def test_clear(self, memory):
        """Test that clear() removes all values and all history."""
        memory.store("k1", "v1", "Agent")
        memory.store("k2", "v2", "Agent")
        memory.clear()
        assert memory.retrieve("k1") is None
        assert memory.retrieve("k2") is None
        assert memory.get_all() == {}
        assert len(memory.get_history()) == 0

    def test_keys(self, memory):
        """Test getting all keys."""
        memory.store("a", 1, "Agent")
        memory.store("b", 2, "Agent")
        memory.store("c", 3, "Agent")
        keys = memory.keys()
        assert set(keys) == {"a", "b", "c"}

    def test_thread_safety(self, memory):
        """Test that concurrent writers and a reader neither raise nor lose writes.

        Several writer threads each store distinct keys while a reader thread
        repeatedly snapshots the memory. Afterwards every written key/value
        must be present, so lost updates are detected as well as exceptions.
        """
        num_writers = 3
        writes_per_writer = 100
        errors = []
        # Release all threads at once so their work actually overlaps;
        # otherwise an early writer may finish before the others start.
        start_gate = threading.Barrier(num_writers + 1)

        def writer(n):
            try:
                start_gate.wait(timeout=5)
                for i in range(writes_per_writer):
                    memory.store(f"key_{n}_{i}", i, f"Agent{n}")
            except Exception as e:
                # An exception inside a thread does not fail the test by
                # itself, so collect it for the main thread to assert on.
                errors.append(e)

        def reader():
            try:
                start_gate.wait(timeout=5)
                for _ in range(100):
                    memory.get_all()
                    memory.keys()
            except Exception as e:
                errors.append(e)

        # Daemon threads so a deadlocked worker cannot block interpreter exit.
        threads = [
            threading.Thread(target=writer, args=(n,), daemon=True)
            for n in range(num_writers)
        ]
        threads.append(threading.Thread(target=reader, daemon=True))
        for t in threads:
            t.start()
        for t in threads:
            # Bounded join: a deadlock fails the test instead of hanging it.
            t.join(timeout=30)

        assert not any(t.is_alive() for t in threads), (
            "worker thread did not finish (possible deadlock)"
        )
        assert not errors, errors
        expected = {
            f"key_{n}_{i}": i
            for n in range(num_writers)
            for i in range(writes_per_writer)
        }
        assert memory.get_all() == expected
|