Spaces:
Sleeping
Sleeping
| """ | |
| Unit tests for the SharedMemory class. | |
| """ | |
| import pytest | |
| from datetime import datetime | |
| import threading | |
| import time | |
| from coda.core.memory import SharedMemory, MemoryEntry | |
class TestSharedMemory:
    """Tests for the SharedMemory class."""

    @pytest.fixture
    def memory(self):
        """Provide a fresh, empty SharedMemory instance to each test.

        The ``@pytest.fixture`` decorator is required: without it pytest cannot
        inject this value into the ``memory`` parameter of the tests below and
        every test errors with "fixture 'memory' not found".
        """
        return SharedMemory()

    def test_store_and_retrieve(self, memory):
        """Test basic store and retrieve operations."""
        memory.store(
            key="test_key",
            value={"data": "value"},
            agent_name="TestAgent",
        )

        result = memory.retrieve("test_key")

        assert result == {"data": "value"}

    def test_retrieve_nonexistent(self, memory):
        """Test that retrieving a missing key returns None."""
        result = memory.retrieve("nonexistent")

        assert result is None

    def test_retrieve_entry(self, memory):
        """Test retrieving the full entry together with its metadata."""
        memory.store(
            key="test_key",
            value="test_value",
            agent_name="TestAgent",
            metadata={"extra": "info"},
        )

        entry = memory.retrieve_entry("test_key")

        assert entry is not None
        assert isinstance(entry, MemoryEntry)
        assert entry.value == "test_value"
        assert entry.agent_name == "TestAgent"
        assert entry.metadata == {"extra": "info"}
        assert isinstance(entry.timestamp, datetime)

    def test_get_context(self, memory):
        """Test retrieving several keys at once; missing keys are omitted."""
        memory.store("key1", "value1", "Agent1")
        memory.store("key2", "value2", "Agent2")
        memory.store("key3", "value3", "Agent3")

        context = memory.get_context(["key1", "key3", "nonexistent"])

        assert context == {"key1": "value1", "key3": "value3"}

    def test_get_all(self, memory):
        """Test retrieving all stored values as a key -> value mapping."""
        memory.store("key1", "value1", "Agent")
        memory.store("key2", "value2", "Agent")

        all_data = memory.get_all()

        assert all_data == {"key1": "value1", "key2": "value2"}

    def test_overwrite_value(self, memory):
        """Test that storing to an existing key replaces its value."""
        memory.store("key", "original", "Agent")
        memory.store("key", "updated", "Agent")

        assert memory.retrieve("key") == "updated"

    def test_history_tracking(self, memory):
        """Test that every store, including overwrites, is recorded in order."""
        memory.store("key1", "v1", "Agent1")
        memory.store("key2", "v2", "Agent2")
        memory.store("key1", "v1_updated", "Agent1")

        history = memory.get_history()

        assert len(history) == 3
        assert history[0].key == "key1"
        assert history[1].key == "key2"
        assert history[2].key == "key1"
        assert history[2].value == "v1_updated"

    def test_history_filter_by_agent(self, memory):
        """Test filtering history by agent name."""
        memory.store("k1", "v1", "Agent1")
        memory.store("k2", "v2", "Agent2")
        memory.store("k3", "v3", "Agent1")

        agent1_history = memory.get_history(agent_name="Agent1")

        assert len(agent1_history) == 2
        assert all(e.agent_name == "Agent1" for e in agent1_history)

    def test_has_key(self, memory):
        """Test the key existence check for present and absent keys."""
        memory.store("exists", "value", "Agent")

        assert memory.has_key("exists") is True
        assert memory.has_key("not_exists") is False

    def test_clear(self, memory):
        """Test that clear() removes all values, keys and history."""
        memory.store("k1", "v1", "Agent")
        memory.store("k2", "v2", "Agent")

        memory.clear()

        assert memory.retrieve("k1") is None
        assert memory.retrieve("k2") is None
        assert not list(memory.keys())
        assert len(memory.get_history()) == 0

    def test_keys(self, memory):
        """Test getting all stored keys."""
        memory.store("a", 1, "Agent")
        memory.store("b", 2, "Agent")
        memory.store("c", 3, "Agent")

        keys = memory.keys()

        assert set(keys) == {"a", "b", "c"}

    def test_thread_safety(self, memory):
        """Test that concurrent writers and a reader neither raise nor lose writes."""
        num_writers = 3
        writes_per_writer = 100
        errors = []
        # Release every thread at the same moment so their operations really
        # overlap; the timeout keeps the test from hanging if one never arrives.
        start = threading.Barrier(num_writers + 1, timeout=5)

        def writer(n):
            # Exceptions raised inside a thread do not propagate to join(),
            # so record them for the main thread to assert on.
            try:
                start.wait()
                for i in range(writes_per_writer):
                    memory.store(f"key_{n}_{i}", i, f"Agent{n}")
            except Exception as e:
                errors.append(e)

        def reader():
            try:
                start.wait()
                for _ in range(100):
                    memory.get_all()
                    memory.keys()
            except Exception as e:
                errors.append(e)

        threads = [
            threading.Thread(target=writer, args=(n,))
            for n in range(num_writers)
        ]
        threads.append(threading.Thread(target=reader))
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        assert not errors, errors
        # Every distinct key written concurrently must have been kept.
        expected_keys = {
            f"key_{n}_{i}"
            for n in range(num_writers)
            for i in range(writes_per_writer)
        }
        assert set(memory.keys()) == expected_keys