"""Unit tests for the SharedMemory class."""

import pytest
from datetime import datetime
import threading
import time

from coda.core.memory import SharedMemory, MemoryEntry


class TestSharedMemory:
    """Tests for the SharedMemory class."""

    @pytest.fixture
    def memory(self):
        """Create a fresh SharedMemory instance."""
        return SharedMemory()

    def test_store_and_retrieve(self, memory):
        """Test basic store and retrieve operations."""
        memory.store(
            key="test_key",
            value={"data": "value"},
            agent_name="TestAgent"
        )
        result = memory.retrieve("test_key")
        assert result == {"data": "value"}

    def test_retrieve_nonexistent(self, memory):
        """Test retrieving a non-existent key."""
        result = memory.retrieve("nonexistent")
        assert result is None

    def test_retrieve_entry(self, memory):
        """Test retrieving the full entry with metadata."""
        memory.store(
            key="test_key",
            value="test_value",
            agent_name="TestAgent",
            metadata={"extra": "info"}
        )
        entry = memory.retrieve_entry("test_key")
        assert entry is not None
        assert entry.value == "test_value"
        assert entry.agent_name == "TestAgent"
        assert entry.metadata == {"extra": "info"}
        assert isinstance(entry.timestamp, datetime)

    def test_get_context(self, memory):
        """Test retrieving multiple keys as context.

        Keys that were never stored are expected to be omitted from the result.
        """
        memory.store("key1", "value1", "Agent1")
        memory.store("key2", "value2", "Agent2")
        memory.store("key3", "value3", "Agent3")
        context = memory.get_context(["key1", "key3", "nonexistent"])
        assert context == {"key1": "value1", "key3": "value3"}

    def test_get_all(self, memory):
        """Test retrieving all stored values."""
        memory.store("key1", "value1", "Agent")
        memory.store("key2", "value2", "Agent")
        all_data = memory.get_all()
        assert all_data == {"key1": "value1", "key2": "value2"}

    def test_overwrite_value(self, memory):
        """Test overwriting an existing value."""
        memory.store("key", "original", "Agent")
        memory.store("key", "updated", "Agent")
        assert memory.retrieve("key") == "updated"

    def test_history_tracking(self, memory):
        """Test that history is tracked for all operations.

        Overwrites are expected to append a new history entry rather than
        replace the earlier one.
        """
        memory.store("key1", "v1", "Agent1")
        memory.store("key2", "v2", "Agent2")
        memory.store("key1", "v1_updated", "Agent1")
        history = memory.get_history()
        assert len(history) == 3
        assert history[0].key == "key1"
        assert history[1].key == "key2"
        assert history[2].value == "v1_updated"

    def test_history_filter_by_agent(self, memory):
        """Test filtering history by agent name."""
        memory.store("k1", "v1", "Agent1")
        memory.store("k2", "v2", "Agent2")
        memory.store("k3", "v3", "Agent1")
        agent1_history = memory.get_history(agent_name="Agent1")
        assert len(agent1_history) == 2
        assert all(e.agent_name == "Agent1" for e in agent1_history)

    def test_has_key(self, memory):
        """Test key existence check."""
        memory.store("exists", "value", "Agent")
        assert memory.has_key("exists") is True
        assert memory.has_key("not_exists") is False

    def test_clear(self, memory):
        """Test clearing all data, including history."""
        memory.store("k1", "v1", "Agent")
        memory.store("k2", "v2", "Agent")
        memory.clear()
        assert memory.retrieve("k1") is None
        assert memory.retrieve("k2") is None
        assert len(memory.get_history()) == 0

    def test_keys(self, memory):
        """Test getting all keys."""
        memory.store("a", 1, "Agent")
        memory.store("b", 2, "Agent")
        memory.store("c", 3, "Agent")
        keys = memory.keys()
        assert set(keys) == {"a", "b", "c"}

    def test_thread_safety(self, memory):
        """Test that concurrent writers and a reader neither raise nor lose writes.

        Several writer threads each store a disjoint range of keys while one
        reader thread repeatedly snapshots the memory. A barrier releases all
        threads at once so their work genuinely overlaps instead of running
        back-to-back.
        """
        num_writers = 3
        writes_per_writer = 100
        errors = []
        # Writers + reader wait here so they start together. The timeout keeps a
        # thread that fails to start from hanging the suite: the waiting threads
        # get BrokenBarrierError, which is recorded in `errors`.
        start = threading.Barrier(num_writers + 1, timeout=10)

        def writer(n):
            try:
                start.wait()
                for i in range(writes_per_writer):
                    memory.store(f"key_{n}_{i}", i, f"Agent{n}")
            except Exception as e:  # captured and asserted on in the main thread
                errors.append(e)

        def reader():
            try:
                start.wait()
                for _ in range(100):
                    memory.get_all()
                    memory.keys()
            except Exception as e:  # captured and asserted on in the main thread
                errors.append(e)

        threads = [
            threading.Thread(target=writer, args=(n,))
            for n in range(num_writers)
        ]
        threads.append(threading.Thread(target=reader))

        for t in threads:
            t.start()
        for t in threads:
            t.join()

        # Comparing to [] makes pytest print the captured exceptions on failure.
        assert errors == []

        # Every write must be visible once all threads have finished.
        expected_keys = {
            f"key_{n}_{i}"
            for n in range(num_writers)
            for i in range(writes_per_writer)
        }
        assert set(memory.keys()) == expected_keys
        assert memory.retrieve(f"key_0_{writes_per_writer - 1}") == writes_per_writer - 1