gMAS / tests /test_memory.py
ะั€ั‚ั‘ะผ ะ‘ะพัั€ัะบะธั…
chore: initial commit
3193174
import time
from typing import Any
import pytest
import torch
from utils.memory import (
AgentMemory,
HiddenChannel,
MemoryConfig,
MemoryEntry,
MemoryLevel,
Message,
MessageProtocol,
RoleFamilyFilter,
SharedMemoryPool,
SharingPolicy,
SubgraphFilter,
SummaryCompressor,
TagBasedFilter,
TruncateCompressor,
)
class TestMemoryEntry:
    """Checks for MemoryEntry construction, expiry and timestamps."""

    def test_create_default(self):
        """An entry built from content alone exposes the expected defaults."""
        record = MemoryEntry(content={"text": "hello"})

        assert record.content == {"text": "hello"}
        assert record.level == MemoryLevel.WORKING
        assert record.priority == 0
        assert record.ttl is None
        assert record.source_agent is None

    def test_create_with_all_fields(self):
        """Explicitly supplied fields can be read back from the entry."""
        fields = {
            "content": {"role": "user", "content": "hi"},
            "level": MemoryLevel.LONG_TERM,
            "ttl": 3600.0,
            "priority": 5,
            "tags": {"shared", "very imp"},
            "source_agent": "agent_a",
        }
        record = MemoryEntry(**fields)

        assert record.level == MemoryLevel.LONG_TERM
        assert record.ttl == 3600.0
        assert record.priority == 5
        assert record.tags == {"shared", "very imp"}
        assert record.source_agent == "agent_a"

    def test_not_expired_without_ttl(self):
        """With ttl=None the entry does not report itself as expired."""
        record = MemoryEntry(content={"text": "hi"}, ttl=None)

        assert not record.is_expired

    def test_not_expired_with_long_ttl(self):
        """A one-hour TTL is not expired right after creation."""
        record = MemoryEntry(content={"text": "hi"}, ttl=3600.0)

        assert not record.is_expired

    def test_expired_with_short_ttl(self):
        """A 10 ms TTL is reported as expired after sleeping 20 ms."""
        record = MemoryEntry(content={"text": "hi"}, ttl=0.01)

        time.sleep(0.02)

        assert record.is_expired

    def test_touch_updates_accessed_at(self):
        """Calling touch() moves accessed_at forward."""
        record = MemoryEntry(content={"text": "hi"})
        before = record.accessed_at

        time.sleep(0.01)
        record.touch()

        assert record.accessed_at > before

    def test_age_property(self):
        """age grows with elapsed time and stays small for a fresh entry."""
        record = MemoryEntry(content={"text": "hi"})

        time.sleep(0.01)

        # Two separate reads on purpose: each assertion samples age anew.
        assert record.age >= 0.01
        assert record.age < 1.0
class TestTruncateCompressor:
    """Checks for the truncation-based compression strategy."""

    def _make_entry(self, priority: int = 0, accessed_at: float = 0.0) -> MemoryEntry:
        """Build an entry with fixed content and a forced accessed_at value."""
        made = MemoryEntry(content={"text": "data"}, priority=priority)
        # Overwrite the timestamp so recency ordering is deterministic.
        made.accessed_at = accessed_at
        return made

    def test_no_compression_needed(self):
        """Below the limit the very same list object is handed back."""
        compressor = TruncateCompressor()
        items = [MemoryEntry(content={"i": idx}) for idx in range(3)]

        kept = compressor.compress(items, max_entries=5)

        assert kept is items
        assert len(kept) == 3

    def test_exact_limit(self):
        """Exactly max_entries items are returned unchanged."""
        compressor = TruncateCompressor()
        items = [MemoryEntry(content={"i": idx}) for idx in range(5)]

        kept = compressor.compress(items, max_entries=5)

        assert kept is items
        assert len(kept) == 5

    def test_empty_list(self):
        """Compressing an empty list yields an empty list."""
        compressor = TruncateCompressor()

        kept = compressor.compress([], max_entries=3)

        assert kept == []

    def test_keeps_high_priority_and_recent(self):
        """With default settings both high-priority entries survive."""
        compressor = TruncateCompressor()
        low_stale = self._make_entry(priority=0, accessed_at=100.0)
        low_fresh = self._make_entry(priority=0, accessed_at=300.0)
        high_stale = self._make_entry(priority=10, accessed_at=100.0)
        high_fresh = self._make_entry(priority=10, accessed_at=300.0)
        items = [low_stale, low_fresh, high_stale, high_fresh]

        kept = compressor.compress(items, max_entries=2)

        assert len(kept) == 2
        assert high_fresh in kept
        assert high_stale in kept

    def test_priority_vs_recency(self):
        """When only one slot exists, priority wins over recency."""
        compressor = TruncateCompressor()
        low_fresh = self._make_entry(priority=0, accessed_at=999.0)
        high_stale = self._make_entry(priority=10, accessed_at=1.0)

        kept = compressor.compress([low_fresh, high_stale], max_entries=1)

        assert len(kept) == 1
        assert kept[0] is high_stale

    def test_keep_recent_false(self):
        """keep_recent=False: the older entry is the one that survives."""
        compressor = TruncateCompressor(keep_recent=False, keep_high_priority=False)
        stale = self._make_entry(priority=0, accessed_at=100.0)
        fresh = self._make_entry(priority=0, accessed_at=999.0)

        kept = compressor.compress([fresh, stale], max_entries=1)

        assert len(kept) == 1
        assert kept[0] is stale

    def test_keep_high_priority_false(self):
        """keep_high_priority=False: only recency decides which entry stays."""
        compressor = TruncateCompressor(keep_recent=True, keep_high_priority=False)
        high_stale = self._make_entry(priority=99, accessed_at=1.0)
        low_fresh = self._make_entry(priority=0, accessed_at=999.0)

        kept = compressor.compress([high_stale, low_fresh], max_entries=1)

        assert len(kept) == 1
        assert kept[0] is low_fresh
class TestSummaryCompressor:
    """Checks for the summarizer-based compression strategy."""

    @staticmethod
    def _fake_summarizer(contents: list[dict[str, Any]]) -> dict[str, Any]:
        """Return a dict whose text is 'Summary: ' plus the texts joined by ' | '."""
        joined = " | ".join(item.get("text", "") for item in contents)
        return {"text": "Summary: " + joined}

    def test_no_compression_within_limit(self):
        """Below the limit the input list object is returned as-is."""
        compressor = SummaryCompressor(summarizer=self._fake_summarizer)
        items = [MemoryEntry(content={"text": f"msg{idx}"}) for idx in range(3)]

        kept = compressor.compress(items, max_entries=5)

        assert kept is items
        assert len(kept) == 3

    def test_no_summarizer_truncates(self):
        """With summarizer=None only the last max_entries items remain."""
        compressor = SummaryCompressor(summarizer=None)
        items = [MemoryEntry(content={"text": f"msg{idx}"}) for idx in range(5)]

        kept = compressor.compress(items, max_entries=3)

        assert len(kept) == 3
        assert kept[0].content == {"text": "msg2"}
        assert kept[1].content == {"text": "msg3"}
        assert kept[2].content == {"text": "msg4"}

    def test_summarizes_old_entries(self):
        """Older items collapse into one summary; the newest stay untouched."""
        compressor = SummaryCompressor(summarizer=self._fake_summarizer)
        items = [MemoryEntry(content={"text": f"msg{idx}"}) for idx in range(5)]

        kept = compressor.compress(items, max_entries=3)

        assert len(kept) == 3
        summary_text = kept[0].content["text"]
        for fragment in ("Summary:", "msg0", "msg1", "msg2"):
            assert fragment in summary_text
        assert kept[1].content == {"text": "msg3"}
        assert kept[2].content == {"text": "msg4"}

    def test_summary_entry_level_is_long_term(self):
        """The generated summary entry is stored at LONG_TERM level."""
        compressor = SummaryCompressor(summarizer=self._fake_summarizer)
        items = [MemoryEntry(content={"text": f"msg{idx}"}) for idx in range(5)]

        kept = compressor.compress(items, max_entries=3)

        assert kept[0].level == MemoryLevel.LONG_TERM

    def test_summary_entry_max_priority(self):
        """The summary carries the highest priority among the merged entries."""
        compressor = SummaryCompressor(summarizer=self._fake_summarizer)
        items = [
            MemoryEntry(content={"text": text}, priority=priority)
            for text, priority in (("a", 2), ("b", 7), ("c", 3), ("d", 1))
        ]

        kept = compressor.compress(items, max_entries=2)

        assert kept[0].priority == 7

    def test_summary_entry_tags_merged(self):
        """The summary's tags are the union of the merged entries' tags."""
        compressor = SummaryCompressor(summarizer=self._fake_summarizer)
        first = MemoryEntry(content={"text": "a"}, tags={"alpha", "beta"})
        second = MemoryEntry(content={"text": "b"}, tags={"beta", "gamma"})
        third = MemoryEntry(content={"text": "c"}, tags={"delta"})
        fourth = MemoryEntry(content={"text": "d"}, tags={"epsilon"})

        kept = compressor.compress([first, second, third, fourth], max_entries=2)

        assert kept[0].tags == {"alpha", "beta", "gamma", "delta"}
class TestMemoryConfig:
    """Checks for MemoryConfig defaults and overrides."""

    def test_default_values(self):
        """A config built without arguments exposes the expected defaults."""
        defaults = MemoryConfig()

        assert defaults.working_max_entries == 20
        assert defaults.working_default_ttl == 3600.0
        assert defaults.long_term_max_entries == 100
        assert defaults.long_term_default_ttl is None
        assert defaults.auto_compress is True
        assert defaults.promote_after_accesses == 3
        assert defaults.demote_inactive_after == 7200.0
        assert defaults.cleanup_interval == 300.0
        assert defaults.hidden_state_dim is None

    def test_default_compression_strategy(self):
        """The default compression strategy is a TruncateCompressor instance."""
        defaults = MemoryConfig()

        assert isinstance(defaults.compression_strategy, TruncateCompressor)

    def test_custom_values(self):
        """Every constructor argument overrides its default."""
        strategy = SummaryCompressor()
        overrides = {
            "working_max_entries": 50,
            "working_default_ttl": None,
            "long_term_max_entries": 200,
            "long_term_default_ttl": 86400.0,
            "compression_strategy": strategy,
            "auto_compress": False,
            "promote_after_accesses": 10,
            "demote_inactive_after": 3600.0,
            "cleanup_interval": 60.0,
            "hidden_state_dim": 128,
        }

        custom = MemoryConfig(**overrides)

        assert custom.working_max_entries == 50
        assert custom.working_default_ttl is None
        assert custom.long_term_max_entries == 200
        assert custom.long_term_default_ttl == 86400.0
        assert custom.compression_strategy is strategy
        assert custom.auto_compress is False
        assert custom.promote_after_accesses == 10
        assert custom.demote_inactive_after == 3600.0
        assert custom.cleanup_interval == 60.0
        assert custom.hidden_state_dim == 128
class TestAgentMemory:
    """Checks for AgentMemory: adding, querying, clearing and compression."""

    def test_init_default_config(self):
        """Without a config the memory starts empty with a MemoryConfig attached."""
        memory = AgentMemory("agent_1")

        assert memory.agent_id == "agent_1"
        assert isinstance(memory.config, MemoryConfig)
        assert memory._working == []
        assert memory._long_term == []

    def test_init_custom_config(self):
        """A config passed to the constructor is the one exposed by the memory."""
        custom = MemoryConfig(working_max_entries=5)

        memory = AgentMemory("agent_2", config=custom)

        assert memory.config.working_max_entries == 5

    def test_add_to_working(self):
        """add() without a level stores the entry in working memory."""
        memory = AgentMemory("a")

        stored = memory.add(content={"text": "hi"})

        assert stored in memory._working
        assert stored not in memory._long_term
        assert stored.level == MemoryLevel.WORKING

    def test_add_to_long_term(self):
        """add(level=LONG_TERM) stores the entry in long-term memory."""
        memory = AgentMemory("a")

        stored = memory.add(content={"text": "hi"}, level=MemoryLevel.LONG_TERM)

        assert stored in memory._long_term
        assert stored not in memory._working
        assert stored.level == MemoryLevel.LONG_TERM

    def test_add_default_ttl_working(self):
        """Working entries pick up working_default_ttl from the config."""
        custom = MemoryConfig(working_default_ttl=999.0)
        memory = AgentMemory("a", config=custom)

        stored = memory.add(content={"text": "hi"})

        assert stored.ttl == 999.0

    def test_add_default_ttl_long_term(self):
        """Long-term entries pick up long_term_default_ttl from the config."""
        custom = MemoryConfig(long_term_default_ttl=5000.0)
        memory = AgentMemory("a", config=custom)

        stored = memory.add(content={"text": "hi"}, level=MemoryLevel.LONG_TERM)

        assert stored.ttl == 5000.0

    def test_add_explicit_ttl(self):
        """A ttl passed to add() wins over the configured default."""
        memory = AgentMemory("a")

        stored = memory.add(content={"text": "hi"}, ttl=42.0)

        assert stored.ttl == 42.0

    def test_add_with_tags_and_priority(self):
        """priority, tags and source_agent given to add() land on the entry."""
        memory = AgentMemory("a")

        stored = memory.add(
            content={"text": "hi"},
            priority=5,
            tags={"urgent"},
            source_agent="agent_b",
        )

        assert stored.priority == 5
        assert stored.tags == {"urgent"}
        assert stored.source_agent == "agent_b"

    def test_add_message(self):
        """add_message() yields a working entry holding role and content."""
        memory = AgentMemory("a")

        stored = memory.add_message(role="user", content="hello")

        assert stored.content["role"] == "user"
        assert stored.content["content"] == "hello"
        assert stored.level == MemoryLevel.WORKING

    def test_add_message_with_metadata(self):
        """Extra keyword arguments to add_message() appear in the entry content."""
        memory = AgentMemory("a")

        stored = memory.add_message(role="assistant", content="reply", tool_call_id="123")

        assert stored.content["tool_call_id"] == "123"

    def test_get_all(self):
        """get() with no arguments returns entries from both levels."""
        memory = AgentMemory("a")
        memory.add(content={"i": 1})
        memory.add(content={"i": 2}, level=MemoryLevel.LONG_TERM)

        found = memory.get()

        assert len(found) == 2

    def test_get_by_level(self):
        """get(level=...) restricts the result to that level."""
        memory = AgentMemory("a")
        memory.add(content={"i": 1})
        memory.add(content={"i": 2}, level=MemoryLevel.LONG_TERM)

        working_hits = memory.get(level=MemoryLevel.WORKING)
        assert len(working_hits) == 1
        assert working_hits[0].content == {"i": 1}

        long_term_hits = memory.get(level=MemoryLevel.LONG_TERM)
        assert len(long_term_hits) == 1
        assert long_term_hits[0].content == {"i": 2}

    def test_get_filters_expired(self):
        """By default get() leaves out entries whose TTL has elapsed."""
        memory = AgentMemory("a")
        memory.add(content={"text": "old"}, ttl=0.01)
        memory.add(content={"text": "fresh"}, ttl=9999.0)
        time.sleep(0.02)

        found = memory.get()

        assert len(found) == 1
        assert found[0].content["text"] == "fresh"

    def test_get_include_expired(self):
        """get(include_expired=True) still returns an expired entry."""
        memory = AgentMemory("a")
        memory.add(content={"text": "old"}, ttl=0.01)
        time.sleep(0.02)

        found = memory.get(include_expired=True)

        assert len(found) == 1

    def test_get_by_tags(self):
        """get(tags=...) keeps only the entries carrying a requested tag."""
        memory = AgentMemory("a")
        memory.add(content={"i": 1}, tags={"alpha"})
        memory.add(content={"i": 2}, tags={"beta"})
        memory.add(content={"i": 3}, tags={"alpha", "beta"})

        found = memory.get(tags={"alpha"})

        assert len(found) == 2

    def test_get_with_limit(self):
        """get(limit=N) returns the N most recently added entries."""
        memory = AgentMemory("a")
        for idx in range(10):
            memory.add(content={"i": idx})

        found = memory.get(limit=3)

        assert len(found) == 3
        assert found[0].content == {"i": 7}
        assert found[2].content == {"i": 9}

    def test_promote_after_accesses(self):
        """The third get() moves the entry from working to long-term memory."""
        custom = MemoryConfig(promote_after_accesses=3)
        memory = AgentMemory("a", config=custom)
        memory.add(content={"text": "test"})

        # Two reads are not enough to trigger promotion.
        memory.get()
        memory.get()
        assert len(memory._working) == 1

        # The third read reaches the configured threshold.
        memory.get()
        assert len(memory._working) == 0
        assert len(memory._long_term) == 1
        assert memory._long_term[0].level == MemoryLevel.LONG_TERM

    def test_get_messages(self):
        """get_messages() skips entries whose content has no 'role' key."""
        memory = AgentMemory("a")
        memory.add_message(role="user", content="hi")
        memory.add(content={"text": "not a message"})

        chat = memory.get_messages()

        assert len(chat) == 1
        assert chat[0]["role"] == "user"

    def test_get_messages_with_limit(self):
        """get_messages(limit=N) caps the number of messages returned."""
        memory = AgentMemory("a")
        for idx in range(5):
            memory.add_message(role="user", content=f"msg{idx}")

        chat = memory.get_messages(limit=2)

        assert len(chat) == 2

    def test_clear_all(self):
        """clear() with no arguments empties both levels."""
        memory = AgentMemory("a")
        memory.add(content={"i": 1})
        memory.add(content={"i": 2}, level=MemoryLevel.LONG_TERM)

        memory.clear()

        assert memory._working == []
        assert memory._long_term == []

    def test_clear_working_only(self):
        """clear(level=WORKING) leaves long-term memory intact."""
        memory = AgentMemory("a")
        memory.add(content={"i": 1})
        memory.add(content={"i": 2}, level=MemoryLevel.LONG_TERM)

        memory.clear(level=MemoryLevel.WORKING)

        assert memory._working == []
        assert len(memory._long_term) == 1

    def test_clear_long_term_only(self):
        """clear(level=LONG_TERM) leaves working memory intact."""
        memory = AgentMemory("a")
        memory.add(content={"i": 1})
        memory.add(content={"i": 2}, level=MemoryLevel.LONG_TERM)

        memory.clear(level=MemoryLevel.LONG_TERM)

        assert len(memory._working) == 1
        assert memory._long_term == []

    def test_remove_expired(self):
        """remove_expired() drops expired entries and reports how many."""
        memory = AgentMemory("a")
        memory.add(content={"text": "old"}, ttl=0.01)
        memory.add(content={"text": "fresh"}, ttl=9999.0)
        time.sleep(0.02)

        dropped = memory.remove_expired()

        assert dropped == 1
        assert len(memory._working) == 1
        assert memory._working[0].content["text"] == "fresh"

    def test_auto_compress_working(self):
        """With auto_compress on, working memory never exceeds its limit."""
        custom = MemoryConfig(working_max_entries=3, auto_compress=True)
        memory = AgentMemory("a", config=custom)

        for idx in range(5):
            memory.add(content={"i": idx})

        assert len(memory._working) <= 3

    def test_auto_compress_disabled(self):
        """With auto_compress off, all added entries are retained."""
        custom = MemoryConfig(working_max_entries=3, auto_compress=False)
        memory = AgentMemory("a", config=custom)

        for idx in range(5):
            memory.add(content={"i": idx})

        assert len(memory._working) == 5

    def test_working_memory_property(self):
        """working_memory lists only the working entries that have not expired."""
        memory = AgentMemory("a")
        memory.add(content={"text": "alive"}, ttl=9999.0)
        memory.add(content={"text": "dead"}, ttl=0.01)
        time.sleep(0.02)

        live = memory.working_memory

        assert len(live) == 1
        assert live[0].content["text"] == "alive"

    def test_all_entries_property(self):
        """all_entries contains entries from both working and long-term levels."""
        memory = AgentMemory("a")
        memory.add(content={"i": 1})
        memory.add(content={"i": 2}, level=MemoryLevel.LONG_TERM)

        assert len(memory.all_entries) == 2

    def test_hidden_state_and_embedding(self):
        """hidden_state and embedding start as None and keep assigned tensors."""
        memory = AgentMemory("a")
        assert memory.hidden_state is None
        assert memory.embedding is None

        state_tensor = torch.randn(64)
        embedding_tensor = torch.randn(128)
        memory.hidden_state = state_tensor
        memory.embedding = embedding_tensor

        assert memory.hidden_state is not None
        assert torch.equal(memory.hidden_state, state_tensor)
        assert memory.embedding is not None
        assert torch.equal(memory.embedding, embedding_tensor)
class TestAgentMemorySerialization:
    """Checks for AgentMemory.to_dict() and AgentMemory.from_dict()."""

    def test_to_dict_empty(self):
        """An empty memory serializes to empty lists and None tensors."""
        memory = AgentMemory("agent_1")

        snapshot = memory.to_dict()

        assert snapshot["agent_id"] == "agent_1"
        assert snapshot["working"] == []
        assert snapshot["long_term"] == []
        assert snapshot["hidden_state"] is None
        assert snapshot["embedding"] is None

    def test_to_dict_with_long_term_entry(self):
        """A long-term entry is emitted under the 'long_term' key."""
        memory = AgentMemory("a")
        memory.add(content={"text": "important"}, level=MemoryLevel.LONG_TERM, priority=9)

        snapshot = memory.to_dict()

        assert len(snapshot["working"]) == 0
        assert len(snapshot["long_term"]) == 1
        serialized = snapshot["long_term"][0]
        assert serialized["content"] == {"text": "important"}
        assert serialized["priority"] == 9

    def test_to_dict_with_hidden_state_and_embedding(self):
        """Tensor attributes come out of to_dict() as plain lists."""
        memory = AgentMemory("a")
        memory.hidden_state = torch.tensor([1.0, 2.0, 3.0])
        memory.embedding = torch.tensor([4.0, 5.0])

        snapshot = memory.to_dict()

        assert snapshot["hidden_state"] == [1.0, 2.0, 3.0]
        assert snapshot["embedding"] == [4.0, 5.0]

    def test_to_dict_tags_serialized_as_list(self):
        """The tag set is emitted as a list holding the same elements."""
        memory = AgentMemory("a")
        memory.add(content={"x": 1}, tags={"a", "b", "c"})

        snapshot = memory.to_dict()

        serialized_tags = snapshot["working"][0]["tags"]
        assert isinstance(serialized_tags, list)
        assert set(serialized_tags) == {"a", "b", "c"}

    def test_to_dict_does_not_include_source_agent(self):  # TODO: may need to revisit this behavior
        """source_agent is absent from the serialized entry (data is lost)."""
        memory = AgentMemory("a")
        memory.add(content={"x": 1}, source_agent="agent_b")

        snapshot = memory.to_dict()

        assert "source_agent" not in snapshot["working"][0]

    def test_to_dict_does_not_include_accessed_at(self):  # TODO: may need to revisit this behavior
        """accessed_at is absent from the serialized entry."""
        memory = AgentMemory("a")
        memory.add(content={"x": 1})

        snapshot = memory.to_dict()

        assert "accessed_at" not in snapshot["working"][0]

    def test_from_dict_basic(self):
        """from_dict() rebuilds agent_id and the per-entry fields."""
        working_item = {
            "content": {"text": "hi"},
            "priority": 2,
            "tags": ["a"],
            "ttl": 100.0,
            "created_at": 1000.0,
        }
        long_term_item = {
            "content": {"text": "old"},
            "priority": 5,
            "tags": ["b", "c"],
            "ttl": None,
            "created_at": 900.0,
        }
        payload = {
            "agent_id": "agent_x",
            "working": [working_item],
            "long_term": [long_term_item],
            "hidden_state": None,
            "embedding": None,
        }

        memory = AgentMemory.from_dict(payload)

        assert memory.agent_id == "agent_x"
        assert len(memory._working) == 1
        assert len(memory._long_term) == 1

        rebuilt_working = memory._working[0]
        assert rebuilt_working.content == {"text": "hi"}
        assert rebuilt_working.priority == 2
        assert rebuilt_working.tags == {"a"}
        assert rebuilt_working.ttl == 100.0
        assert rebuilt_working.created_at == 1000.0
        assert rebuilt_working.level == MemoryLevel.WORKING

        rebuilt_long_term = memory._long_term[0]
        assert rebuilt_long_term.content == {"text": "old"}
        assert rebuilt_long_term.level == MemoryLevel.LONG_TERM
        assert rebuilt_long_term.tags == {"b", "c"}

    def test_from_dict_with_custom_config(self):
        """A config handed to from_dict() is used by the rebuilt memory."""
        payload = {"agent_id": "a", "working": [], "long_term": []}
        custom = MemoryConfig(working_max_entries=7)

        memory = AgentMemory.from_dict(payload, config=custom)

        assert memory.config.working_max_entries == 7

    def test_from_dict_default_config(self):
        """Without a config, from_dict() attaches a default MemoryConfig."""
        payload = {"agent_id": "a", "working": [], "long_term": []}

        memory = AgentMemory.from_dict(payload)

        assert isinstance(memory.config, MemoryConfig)
        assert memory.config.working_max_entries == 20

    def test_from_dict_restores_hidden_state(self):
        """List-valued hidden_state and embedding are rebuilt as tensors."""
        payload = {
            "agent_id": "a",
            "working": [],
            "long_term": [],
            "hidden_state": [1.0, 2.0, 3.0],
            "embedding": [4.0, 5.0],
        }

        memory = AgentMemory.from_dict(payload)

        assert memory.hidden_state is not None
        assert torch.equal(memory.hidden_state, torch.tensor([1.0, 2.0, 3.0]))
        assert memory.embedding is not None
        assert torch.equal(memory.embedding, torch.tensor([4.0, 5.0]))

    def test_from_dict_missing_optional_fields(self):
        """An entry given only 'content' gets default values elsewhere."""
        payload = {
            "agent_id": "a",
            "working": [{"content": {"msg": "hi"}}],
            "long_term": [],
        }

        memory = AgentMemory.from_dict(payload)

        rebuilt = memory._working[0]
        assert rebuilt.content == {"msg": "hi"}
        assert rebuilt.priority == 0
        assert rebuilt.tags == set()
        assert rebuilt.ttl is None
        assert rebuilt.created_at > 0

    def test_roundtrip(self):
        """to_dict -> from_dict keeps content, priority, tags, ttl and tensors."""
        source = AgentMemory("rt_agent")
        source.add(content={"text": "w1"}, priority=1, tags={"x"}, ttl=500.0)
        source.add(content={"text": "w2"}, priority=2, tags={"y", "z"})
        source.add(content={"text": "lt1"}, level=MemoryLevel.LONG_TERM, priority=8)
        source.hidden_state = torch.tensor([1.0, 2.0])
        source.embedding = torch.tensor([3.0, 4.0, 5.0])

        snapshot = source.to_dict()
        clone = AgentMemory.from_dict(snapshot)

        assert clone.agent_id == "rt_agent"
        assert len(clone._working) == 2
        assert len(clone._long_term) == 1

        first_working = clone._working[0]
        assert first_working.content == {"text": "w1"}
        assert first_working.priority == 1
        assert first_working.tags == {"x"}
        assert first_working.ttl == 500.0

        second_working = clone._working[1]
        assert second_working.content == {"text": "w2"}
        assert second_working.tags == {"y", "z"}

        only_long_term = clone._long_term[0]
        assert only_long_term.content == {"text": "lt1"}
        assert only_long_term.priority == 8

        assert clone.hidden_state is not None
        assert torch.equal(clone.hidden_state, torch.tensor([1.0, 2.0]))
        assert clone.embedding is not None
        assert torch.equal(clone.embedding, torch.tensor([3.0, 4.0, 5.0]))

    def test_roundtrip_preserves_created_at(self):
        """created_at survives a round-trip to within one millisecond."""
        source = AgentMemory("a")
        stored = source.add(content={"x": 1})
        created_before = stored.created_at

        clone = AgentMemory.from_dict(source.to_dict())

        assert clone._working[0].created_at == pytest.approx(created_before, abs=0.001)

    def test_roundtrip_loses_source_agent(self):  # TODO: may need to revisit this behavior
        """source_agent comes back as None after a round-trip."""
        source = AgentMemory("a")
        source.add(content={"x": 1}, source_agent="agent_b")

        clone = AgentMemory.from_dict(source.to_dict())

        assert clone._working[0].source_agent is None
class TestAccessFilters:
    """Checks for TagBasedFilter, SubgraphFilter and RoleFamilyFilter."""

    def test_tag_filter_allows_on_intersection(self):
        """A tag shared between entry and filter grants access."""
        tag_filter = TagBasedFilter(shared_tags={"shared", "public"})
        record = MemoryEntry(content={"x": 1}, tags={"shared", "internal"})

        verdict = tag_filter.can_access("agent_a", "agent_b", record)

        assert verdict is True

    def test_tag_filter_denies_no_intersection(self):
        """Disjoint tag sets deny access."""
        tag_filter = TagBasedFilter(shared_tags={"shared", "public"})
        record = MemoryEntry(content={"x": 1}, tags={"private", "internal"})

        verdict = tag_filter.can_access("agent_a", "agent_b", record)

        assert verdict is False

    def test_tag_filter_empty_entry_tags(self):
        """An entry with no tags is denied."""
        tag_filter = TagBasedFilter(shared_tags={"shared"})
        record = MemoryEntry(content={"x": 1}, tags=set())

        verdict = tag_filter.can_access("agent_a", "agent_b", record)

        assert verdict is False

    def test_tag_filter_empty_shared_tags(self):
        """A filter with no shared tags denies access."""
        tag_filter = TagBasedFilter(shared_tags=set())
        record = MemoryEntry(content={"x": 1}, tags={"shared"})

        verdict = tag_filter.can_access("agent_a", "agent_b", record)

        assert verdict is False

    def test_tag_filter_ignores_agent_ids(self):
        """Arbitrary requester/owner ids still pass when tags match."""
        tag_filter = TagBasedFilter(shared_tags={"ok"})
        record = MemoryEntry(content={"x": 1}, tags={"ok"})

        verdict = tag_filter.can_access("any", "who", record)

        assert verdict is True

    def test_subgraph_filter_same_subgraph_allows(self):
        """Requester and owner in one subgraph: access granted."""
        membership = {
            "sg1": {"agent_a", "agent_b"},
            "sg2": {"agent_c"},
        }
        sg_filter = SubgraphFilter(subgraph_members=membership)
        record = MemoryEntry(content={"x": 1})

        verdict = sg_filter.can_access("agent_a", "agent_b", record)

        assert verdict is True

    def test_subgraph_filter_different_subgraph_denies(self):
        """Requester and owner in separate subgraphs: access denied."""
        membership = {
            "sg1": {"agent_a"},
            "sg2": {"agent_b"},
        }
        sg_filter = SubgraphFilter(subgraph_members=membership)
        record = MemoryEntry(content={"x": 1})

        verdict = sg_filter.can_access("agent_a", "agent_b", record)

        assert verdict is False

    def test_subgraph_filter_unknown_requester_denies(self):
        """A requester missing from every subgraph is denied."""
        sg_filter = SubgraphFilter(subgraph_members={"sg1": {"agent_a"}})
        record = MemoryEntry(content={"x": 1})

        verdict = sg_filter.can_access("unknown", "agent_a", record)

        assert verdict is False

    def test_subgraph_filter_unknown_owner_denies(self):
        """An owner missing from every subgraph is denied."""
        sg_filter = SubgraphFilter(subgraph_members={"sg1": {"agent_a"}})
        record = MemoryEntry(content={"x": 1})

        verdict = sg_filter.can_access("agent_a", "unknown", record)

        assert verdict is False

    def test_subgraph_filter_ignores_entry(self):
        """A 'secret' tag on the entry does not affect the subgraph verdict."""
        sg_filter = SubgraphFilter(subgraph_members={"sg1": {"a", "b"}})
        record = MemoryEntry(content={"x": 1}, tags={"secret"})

        verdict = sg_filter.can_access("a", "b", record)

        assert verdict is True

    def test_subgraph_agent_to_subgraph_mapping(self):
        """_agent_to_subgraph maps each agent to its subgraph name."""
        membership = {
            "analysis": {"solver", "reviewer"},
            "output": {"writer"},
        }

        sg_filter = SubgraphFilter(subgraph_members=membership)

        reverse_index = sg_filter._agent_to_subgraph
        assert reverse_index["solver"] == "analysis"
        assert reverse_index["reviewer"] == "analysis"
        assert reverse_index["writer"] == "output"

    def test_role_family_same_family_allows(self):
        """Requester and owner in one role family: access granted."""
        families = {
            "analysts": {"data_analyst", "market_analyst"},
            "writers": {"copywriter", "editor"},
        }
        family_filter = RoleFamilyFilter(role_families=families)
        record = MemoryEntry(content={"x": 1})

        verdict = family_filter.can_access("data_analyst", "market_analyst", record)

        assert verdict is True

    def test_role_family_different_family_denies(self):
        """Requester and owner in separate role families: access denied."""
        families = {
            "analysts": {"data_analyst"},
            "writers": {"copywriter"},
        }
        family_filter = RoleFamilyFilter(role_families=families)
        record = MemoryEntry(content={"x": 1})

        verdict = family_filter.can_access("data_analyst", "copywriter", record)

        assert verdict is False

    def test_role_family_unknown_requester_denies(self):
        """A requester outside every family is denied."""
        family_filter = RoleFamilyFilter(role_families={"analysts": {"data_analyst"}})
        record = MemoryEntry(content={"x": 1})

        verdict = family_filter.can_access("outsider", "data_analyst", record)

        assert verdict is False

    def test_role_family_unknown_owner_denies(self):
        """An owner outside every family is denied."""
        family_filter = RoleFamilyFilter(role_families={"analysts": {"data_analyst"}})
        record = MemoryEntry(content={"x": 1})

        verdict = family_filter.can_access("data_analyst", "outsider", record)

        assert verdict is False

    def test_role_family_agent_to_family_mapping(self):
        """_agent_to_family maps each agent to its family name."""
        families = {
            "analysts": {"solver", "reviewer"},
            "writers": {"editor"},
        }

        family_filter = RoleFamilyFilter(role_families=families)

        reverse_index = family_filter._agent_to_family
        assert reverse_index["solver"] == "analysts"
        assert reverse_index["reviewer"] == "analysts"
        assert reverse_index["editor"] == "writers"

    def test_role_family_ignores_entry(self):
        """Entry tags and priority do not affect the role-family verdict."""
        family_filter = RoleFamilyFilter(role_families={"team": {"a", "b"}})
        record = MemoryEntry(content={"x": 1}, tags={"secret"}, priority=99)

        verdict = family_filter.can_access("a", "b", record)

        assert verdict is True
class TestSharedMemoryPool:
"""Tests for SharedMemoryPool."""
def _make_pool_with_agents(self, *agent_ids: str) -> tuple[SharedMemoryPool, dict[str, AgentMemory]]:
pool = SharedMemoryPool()
memories = {}
for aid in agent_ids:
mem = AgentMemory(aid)
pool.register(mem)
memories[aid] = mem
return pool, memories
def test_init_defaults(self):
"""Pool is created with correct defaults."""
pool = SharedMemoryPool()
assert pool.access_filter is None
assert pool.default_policy == SharingPolicy.BY_TAGS
assert pool._memories == {}
assert pool._shared_entries == []
def test_register_and_unregister(self):
"""register/unregister adds and removes agent memory."""
pool = SharedMemoryPool()
mem = AgentMemory("a")
pool.register(mem)
assert "a" in pool._memories
pool.unregister("a")
assert "a" not in pool._memories
def test_unregister_nonexistent(self):
"""Unregistering an unknown agent does not raise."""
pool = SharedMemoryPool()
pool.unregister("ghost")
def test_share_to_specific_agents(self):
"""share() with to_agents places a copy in each recipient's working memory."""
pool, mems = self._make_pool_with_agents("sender", "recv_1", "recv_2")
entry = mems["sender"].add(content={"text": "hello"}, priority=3, tags={"info"})
pool.share("sender", entry, to_agents=["recv_1", "recv_2"])
r1_entries = mems["recv_1"].get()
assert len(r1_entries) == 1
assert r1_entries[0].content == {"text": "hello"}
assert r1_entries[0].priority == 3
assert "shared" in r1_entries[0].tags
assert r1_entries[0].source_agent == "sender"
assert r1_entries[0].level == MemoryLevel.WORKING
r2_entries = mems["recv_2"].get()
assert len(r2_entries) == 1
def test_share_to_unknown_agent_ignored(self):
"""Sharing to a non-registered agent does not raise."""
pool, mems = self._make_pool_with_agents("sender")
entry = mems["sender"].add(content={"text": "hi"})
pool.share("sender", entry, to_agents=["ghost"])
def test_share_to_pool(self):
"""share() without to_agents places the entry in the shared pool."""
pool, mems = self._make_pool_with_agents("sender")
entry = mems["sender"].add(content={"text": "public"})
pool.share("sender", entry)
assert len(pool._shared_entries) == 1
assert pool._shared_entries[0].source_agent == "sender"
assert pool._shared_entries[0].level == MemoryLevel.SHARED
def test_share_copies_content(self):
"""share() copies content rather than sharing the reference."""
pool, mems = self._make_pool_with_agents("sender")
original_content = {"text": "hello"}
entry = mems["sender"].add(content=original_content)
pool.share("sender", entry)
pool._shared_entries[0].content["text"] = "modified"
assert entry.content["text"] == "hello"
def test_get_shared_no_filter(self):
"""get_shared() without a filter returns all non-expired entries."""
pool, mems = self._make_pool_with_agents("a", "b")
entry = mems["a"].add(content={"text": "data"})
pool.share("a", entry)
result = pool.get_shared("b")
assert len(result) == 1
assert result[0].content == {"text": "data"}
def test_get_shared_filters_by_tags(self):
"""get_shared() with tags filters by tag intersection."""
pool, mems = self._make_pool_with_agents("a")
e1 = mems["a"].add(content={"i": 1}, tags={"alpha"})
e2 = mems["a"].add(content={"i": 2}, tags={"beta"})
pool.share("a", e1)
pool.share("a", e2)
result = pool.get_shared("anyone", tags={"alpha"})
assert len(result) == 1
assert result[0].content == {"i": 1}
def test_get_shared_filters_expired(self):
"""get_shared() does not return expired entries."""
pool, mems = self._make_pool_with_agents("a")
entry = mems["a"].add(content={"text": "old"}, ttl=0.01)
pool.share("a", entry)
time.sleep(0.02)
result = pool.get_shared("anyone")
assert len(result) == 0
def test_get_shared_with_limit(self):
    """get_shared(limit=N) yields the N most recently shared entries, in order."""
    pool, memories = self._make_pool_with_agents("a")
    owner = memories["a"]
    for index in range(5):
        pool.share("a", owner.add(content={"i": index}))

    tail = pool.get_shared("anyone", limit=2)

    # Of the five entries (0..4), only the last two are expected back.
    assert len(tail) == 2
    older, newer = tail
    assert older.content == {"i": 3}
    assert newer.content == {"i": 4}
def test_get_shared_with_access_filter(self):
    """get_shared() applies the pool's access_filter to the requesting agent."""
    membership = {
        "sg1": {"a", "b"},
        "sg2": {"c"},
    }
    subgraph_filter = SubgraphFilter(subgraph_members=membership)
    pool = SharedMemoryPool(access_filter=subgraph_filter)
    owner = AgentMemory("a")
    pool.register(owner)
    pool.share("a", owner.add(content={"text": "secret"}))

    # "b" is listed in the same subgraph as the owner "a"; "c" is in another one.
    assert len(pool.get_shared("b")) == 1
    assert len(pool.get_shared("c")) == 0
def test_get_from_agent_basic(self):
    """get_from_agent() lets one agent read the entries held by another."""
    pool, memories = self._make_pool_with_agents("a", "b")
    memories["a"].add(content={"text": "data_a"})

    # Arguments as used here: requesting agent first, owning agent second.
    fetched = pool.get_from_agent("b", "a")

    assert len(fetched) == 1
    assert fetched[0].content == {"text": "data_a"}
def test_get_from_agent_unknown_owner(self):
    """Querying an owner that was never registered yields an empty list."""
    empty_pool = SharedMemoryPool()
    fetched = empty_pool.get_from_agent("anyone", "ghost")
    assert fetched == []
def test_get_from_agent_with_filter(self):
    """get_from_agent() applies the pool's access_filter to the requester."""
    membership = {
        "sg1": {"a", "b"},
        "sg2": {"c"},
    }
    subgraph_filter = SubgraphFilter(subgraph_members=membership)
    pool = SharedMemoryPool(access_filter=subgraph_filter)
    owner = AgentMemory("a")
    pool.register(owner)
    owner.add(content={"text": "data"})

    # Same-subgraph requester "b" gets the entry; other-subgraph "c" gets none.
    assert len(pool.get_from_agent("b", "a")) == 1
    assert len(pool.get_from_agent("c", "a")) == 0
def test_get_from_agent_with_level_filter(self):
    """get_from_agent(level=...) returns only entries of the requested level."""
    pool, memories = self._make_pool_with_agents("a", "b")
    owner = memories["a"]
    owner.add(content={"i": 1})  # level left at its default
    owner.add(content={"i": 2}, level=MemoryLevel.LONG_TERM)

    working_only = pool.get_from_agent("b", "a", level=MemoryLevel.WORKING)

    assert len(working_only) == 1
    assert working_only[0].content == {"i": 1}
def test_get_from_agent_with_tags_filter(self):
    """get_from_agent(tags=...) returns only entries carrying the given tag."""
    pool, memories = self._make_pool_with_agents("a", "b")
    owner = memories["a"]
    owner.add(content={"i": 1}, tags={"alpha"})
    owner.add(content={"i": 2}, tags={"beta"})

    beta_only = pool.get_from_agent("b", "a", tags={"beta"})

    assert len(beta_only) == 1
    assert beta_only[0].content == {"i": 2}
def test_broadcast(self):
    """broadcast() delivers the entry to every agent except the sender."""
    pool, memories = self._make_pool_with_agents("sender", "recv_1", "recv_2")
    pool.broadcast("sender", content={"text": "hello all"}, tags={"news"})

    # The sender must not receive its own broadcast.
    assert len(memories["sender"].get()) == 0

    # First receiver: check payload, tags and attribution in full.
    first_inbox = memories["recv_1"].get()
    assert len(first_inbox) == 1
    delivered = first_inbox[0]
    assert delivered.content == {"text": "hello all"}
    assert "broadcast" in delivered.tags
    assert "news" in delivered.tags
    assert delivered.source_agent == "sender"

    # Second receiver: delivery count only.
    second_inbox = memories["recv_2"].get()
    assert len(second_inbox) == 1
def test_broadcast_no_tags(self):
    """A broadcast sent without tags is delivered tagged only with 'broadcast'."""
    pool, memories = self._make_pool_with_agents("a", "b")
    pool.broadcast("a", content={"x": 1})

    delivered = memories["b"].get()[0]

    assert delivered.tags == {"broadcast"}
class TestHiddenChannel:
    """Unit tests for HiddenChannel construction and dict (de)serialization."""

    def test_create_defaults(self):
        """A channel built with no arguments has no tensors and empty metadata."""
        channel = HiddenChannel()
        assert channel.hidden_state is None
        assert channel.embedding is None
        assert channel.metadata == {}

    def test_create_with_tensors(self):
        """Tensors and metadata passed to the constructor are stored as given."""
        state = torch.tensor([1.0, 2.0, 3.0])
        vector = torch.tensor([4.0, 5.0])
        channel = HiddenChannel(
            hidden_state=state,
            embedding=vector,
            metadata={"key": "val"},
        )
        # The `is not None` checks precede torch.equal, as in the original test.
        assert channel.hidden_state is not None
        assert torch.equal(channel.hidden_state, state)
        assert channel.embedding is not None
        assert torch.equal(channel.embedding, vector)
        assert channel.metadata == {"key": "val"}

    def test_to_dict_with_tensors(self):
        """to_dict() renders tensor fields as plain Python lists."""
        channel = HiddenChannel(
            hidden_state=torch.tensor([1.0, 2.0]),
            embedding=torch.tensor([3.0]),
            metadata={"a": 1},
        )
        serialized = channel.to_dict()
        assert serialized["hidden_state"] == [1.0, 2.0]
        assert serialized["embedding"] == [3.0]
        assert serialized["metadata"] == {"a": 1}

    def test_to_dict_none_tensors(self):
        """to_dict() keeps absent tensors as None and metadata as an empty dict."""
        serialized = HiddenChannel().to_dict()
        assert serialized["hidden_state"] is None
        assert serialized["embedding"] is None
        assert serialized["metadata"] == {}

    def test_from_dict_with_tensors(self):
        """from_dict() rebuilds tensors from list values."""
        payload = {
            "hidden_state": [1.0, 2.0],
            "embedding": [3.0, 4.0],
            "metadata": {"key": "val"},
        }
        channel = HiddenChannel.from_dict(payload)
        assert channel.hidden_state is not None
        assert torch.equal(channel.hidden_state, torch.tensor([1.0, 2.0]))
        assert channel.embedding is not None
        assert torch.equal(channel.embedding, torch.tensor([3.0, 4.0]))
        assert channel.metadata == {"key": "val"}

    def test_from_dict_none_tensors(self):
        """from_dict() leaves tensor fields as None when the dict holds None."""
        payload = {"hidden_state": None, "embedding": None, "metadata": {}}
        channel = HiddenChannel.from_dict(payload)
        assert channel.hidden_state is None
        assert channel.embedding is None

    def test_from_dict_missing_metadata(self):
        """from_dict() falls back to empty metadata when the key is absent."""
        payload = {"hidden_state": None, "embedding": None}
        channel = HiddenChannel.from_dict(payload)
        assert channel.metadata == {}
class TestMessage:
    """Unit tests for Message: defaults, has_hidden, and dict projections."""

    def test_create_minimal(self):
        """Only sender, receiver and content given: remaining fields use defaults."""
        message = Message(sender_id="a", receiver_id=None, content="hello")
        # Explicitly supplied fields.
        assert message.sender_id == "a"
        assert message.receiver_id is None
        assert message.content == "hello"
        # Defaults.
        assert message.role == "assistant"
        assert message.message_type == "response"
        assert message.priority == 0
        assert message.tags == set()
        assert message.hidden is None
        assert message.timestamp > 0

    def test_create_full(self):
        """Every constructor argument supplied is reflected on the instance."""
        channel = HiddenChannel(hidden_state=torch.tensor([1.0]))
        message = Message(
            sender_id="a",
            receiver_id="b",
            content="hi",
            role="user",
            hidden=channel,
            message_type="query",
            priority=5,
            tags={"urgent"},
        )
        assert message.receiver_id == "b"
        assert message.role == "user"
        # Identity, not equality: the very same channel object is kept.
        assert message.hidden is channel
        assert message.message_type == "query"
        assert message.priority == 5
        assert message.tags == {"urgent"}

    def test_has_hidden_true_with_state(self):
        """has_hidden is True when the channel carries a hidden_state tensor."""
        channel = HiddenChannel(hidden_state=torch.tensor([1.0]))
        message = Message(sender_id="a", receiver_id=None, content="x", hidden=channel)
        assert message.has_hidden is True

    def test_has_hidden_true_with_embedding(self):
        """has_hidden is True when the channel carries an embedding tensor."""
        channel = HiddenChannel(embedding=torch.tensor([1.0]))
        message = Message(sender_id="a", receiver_id=None, content="x", hidden=channel)
        assert message.has_hidden is True

    def test_has_hidden_false_no_hidden(self):
        """has_hidden is False when the message has no hidden channel at all."""
        message = Message(sender_id="a", receiver_id=None, content="x")
        assert message.has_hidden is False

    def test_has_hidden_false_empty_channel(self):
        """has_hidden is False when the attached channel holds no tensors."""
        empty_channel = HiddenChannel()
        message = Message(
            sender_id="a", receiver_id=None, content="x", hidden=empty_channel
        )
        assert message.has_hidden is False

    def test_to_visible_dict(self):
        """to_visible_dict() exposes exactly role, content and sender."""
        message = Message(sender_id="a", receiver_id="b", content="hello", priority=5)
        visible = message.to_visible_dict()
        assert visible == {"role": "assistant", "content": "hello", "sender": "a"}
        assert "priority" not in visible
        assert "receiver_id" not in visible

    def test_to_full_dict_without_hidden(self):
        """to_full_dict() omits the 'hidden' key when no channel is attached."""
        message = Message(sender_id="a", receiver_id=None, content="hi", tags={"x"})
        full = message.to_full_dict()
        assert full["role"] == "assistant"
        assert full["content"] == "hi"
        assert full["sender"] == "a"
        assert full["message_type"] == "response"
        assert full["priority"] == 0
        # Tags are compared as a set, so any iterable container is accepted.
        assert set(full["tags"]) == {"x"}
        assert "hidden" not in full

    def test_to_full_dict_with_hidden(self):
        """to_full_dict() embeds the serialized channel under 'hidden'."""
        channel = HiddenChannel(hidden_state=torch.tensor([1.0, 2.0]))
        message = Message(sender_id="a", receiver_id=None, content="hi", hidden=channel)
        full = message.to_full_dict()
        assert "hidden" in full
        assert full["hidden"]["hidden_state"] == [1.0, 2.0]
class TestMessageProtocol:
    """Unit tests for MessageProtocol: creation, hidden-state handling, formatting."""

    def test_init_defaults(self):
        """Default construction enables hidden data and leaves hidden_dim unset."""
        protocol = MessageProtocol()
        assert protocol.enable_hidden is True
        assert protocol.hidden_dim is None

    def test_init_custom(self):
        """Constructor arguments are stored on the instance unchanged."""

        def take_first(tensors):
            return tensors[0]

        protocol = MessageProtocol(
            enable_hidden=False,
            hidden_dim=64,
            combine_hidden=take_first,
        )
        assert protocol.enable_hidden is False
        assert protocol.hidden_dim == 64
        assert protocol.combine_hidden is take_first

    def test_create_message_without_hidden(self):
        """create_message() with no hidden data yields a message without a channel."""
        protocol = MessageProtocol()
        message = protocol.create_message("a", "hello")
        assert message.sender_id == "a"
        assert message.content == "hello"
        assert message.hidden is None

    def test_create_message_with_hidden(self):
        """create_message(hidden_state=...) attaches the tensor to the message."""
        protocol = MessageProtocol()
        state = torch.tensor([1.0, 2.0])
        message = protocol.create_message("a", "hi", hidden_state=state)
        assert message.hidden is not None
        assert message.hidden.hidden_state is not None
        assert torch.equal(message.hidden.hidden_state, state)

    def test_create_message_hidden_disabled(self):
        """With enable_hidden=False, a supplied hidden_state is ignored."""
        protocol = MessageProtocol(enable_hidden=False)
        state = torch.tensor([1.0])
        message = protocol.create_message("a", "hi", hidden_state=state)
        assert message.hidden is None

    def test_create_message_with_kwargs(self):
        """Extra keyword arguments to create_message() end up on the Message."""
        protocol = MessageProtocol()
        message = protocol.create_message("a", "hi", role="user", priority=3)
        assert message.role == "user"
        assert message.priority == 3

    def test_extract_hidden_states(self):
        """extract_hidden_states() skips messages that carry no hidden state."""
        protocol = MessageProtocol()
        first_state = torch.tensor([1.0])
        second_state = torch.tensor([2.0])
        batch = [
            protocol.create_message("a", "1", hidden_state=first_state),
            protocol.create_message("b", "2"),  # no hidden state supplied
            protocol.create_message("c", "3", hidden_state=second_state),
        ]
        extracted = protocol.extract_hidden_states(batch)
        assert len(extracted) == 2
        assert torch.equal(extracted[0], first_state)
        assert torch.equal(extracted[1], second_state)

    def test_extract_hidden_states_empty(self):
        """extract_hidden_states() yields [] when no message has a hidden state."""
        protocol = MessageProtocol()
        batch = [protocol.create_message("a", "hi")]
        assert protocol.extract_hidden_states(batch) == []

    def test_combine_incoming_hidden(self):
        """The default combiner produces the element-wise mean of the states."""
        protocol = MessageProtocol()
        batch = [
            protocol.create_message("a", "1", hidden_state=torch.tensor([2.0, 4.0])),
            protocol.create_message("b", "2", hidden_state=torch.tensor([6.0, 8.0])),
        ]
        combined = protocol.combine_incoming_hidden(batch)
        assert combined is not None
        # mean([2, 6]) == 4 and mean([4, 8]) == 6
        assert torch.equal(combined, torch.tensor([4.0, 6.0]))

    def test_combine_incoming_hidden_none(self):
        """combine_incoming_hidden() returns None when there is nothing to combine."""
        protocol = MessageProtocol()
        batch = [protocol.create_message("a", "hi")]
        assert protocol.combine_incoming_hidden(batch) is None

    def test_combine_incoming_hidden_custom(self):
        """A user-supplied combine_hidden callable replaces the default combiner."""

        def sum_states(ts):
            return torch.sum(torch.stack(ts), dim=0)

        protocol = MessageProtocol(combine_hidden=sum_states)
        batch = [
            protocol.create_message("a", "1", hidden_state=torch.tensor([1.0, 2.0])),
            protocol.create_message("b", "2", hidden_state=torch.tensor([3.0, 4.0])),
        ]
        combined = protocol.combine_incoming_hidden(batch)
        assert combined is not None
        # Element-wise sum: [1 + 3, 2 + 4]
        assert torch.equal(combined, torch.tensor([4.0, 6.0]))

    def test_default_combine_raises_on_empty(self):
        """_default_combine([]) raises ValueError with the expected message."""
        protocol = MessageProtocol()
        with pytest.raises(ValueError, match="No tensors to combine"):
            protocol._default_combine([])

    def test_format_visible(self):
        """format_visible() includes a '[sender]:' label and the text of each message."""
        protocol = MessageProtocol()
        batch = [
            protocol.create_message("agent_a", "Hello"),
            protocol.create_message("agent_b", "Reply"),
        ]
        rendered = protocol.format_visible(batch)
        for fragment in ("[agent_a]:", "Hello", "[agent_b]:", "Reply"):
            assert fragment in rendered

    def test_format_visible_with_names(self):
        """format_visible(agent_names=...) labels messages with the mapped name."""
        protocol = MessageProtocol()
        batch = [protocol.create_message("a", "hi")]
        rendered = protocol.format_visible(batch, agent_names={"a": "Voice"})
        assert "[Voice]:" in rendered
        assert "hi" in rendered

    def test_format_visible_empty(self):
        """format_visible([]) yields the empty string."""
        protocol = MessageProtocol()
        assert protocol.format_visible([]) == ""
class TestMemoryLevelShared:
    """Tests pinning how AgentMemory treats entries with MemoryLevel.SHARED."""

    def test_add_shared_level_goes_to_long_term_storage(self):
        """A SHARED entry keeps its level but is stored in the _long_term list."""
        memory = AgentMemory("a")
        shared_entry = memory.add(
            content={"text": "shared_data"}, level=MemoryLevel.SHARED
        )
        assert shared_entry.level == MemoryLevel.SHARED
        assert shared_entry in memory._long_term
        assert shared_entry not in memory._working

    def test_add_shared_level_applies_long_term_ttl(self):
        """A SHARED entry added without a TTL receives long_term_default_ttl."""
        config = MemoryConfig(long_term_default_ttl=1234.0)
        memory = AgentMemory("a", config=config)
        shared_entry = memory.add(content={"text": "x"}, level=MemoryLevel.SHARED)
        assert shared_entry.ttl == 1234.0

    def test_get_shared_level_returns_all_entries(self):
        """get(level=SHARED) returns entries of every level.

        Per the original author's note, SHARED is handled by the fallback
        branch of get() rather than by a dedicated case.
        """
        memory = AgentMemory("a")
        memory.add(content={"i": 1})
        memory.add(content={"i": 2}, level=MemoryLevel.LONG_TERM)
        # SHARED is neither WORKING nor LONG_TERM -> fallback branch -> all entries
        everything = memory.get(level=MemoryLevel.SHARED)
        assert len(everything) == 2

    def test_clear_shared_level_only_resets_access_counts(self):
        """clear(SHARED) leaves WORKING and LONG_TERM entries in place.

        The test name refers to access counts being reset (behaviour noted by
        the original author); only the retention of entries is asserted here.
        """
        memory = AgentMemory("a")
        memory.add(content={"i": 1})
        memory.add(content={"i": 2}, level=MemoryLevel.LONG_TERM)
        memory.clear(level=MemoryLevel.SHARED)
        # Both underlying stores must still hold their single entry.
        assert len(memory._working) == 1
        assert len(memory._long_term) == 1

    def test_shared_memory_pool_share_creates_shared_level_entry(self):
        """Pool-wide share() (no to_agents) stores an entry whose level is SHARED."""
        pool = SharedMemoryPool()
        sender_memory = AgentMemory("sender")
        pool.register(sender_memory)
        pool.share("sender", sender_memory.add(content={"text": "pool_data"}))
        assert len(pool._shared_entries) == 1
        assert pool._shared_entries[0].level == MemoryLevel.SHARED
class TestIntegrationWorkingToLongTerm:
    """Entry lifecycle: working -> promote -> demote."""

    def test_promote_after_n_accesses(self):
        """An entry read often enough through get() ends up in long-term memory."""
        config = MemoryConfig(promote_after_accesses=2)
        memory = AgentMemory("a", config)
        memory.add(content={"text": "important"})

        # NOTE: the number and order of get() calls below is part of the test.
        assert len(memory.get(level=MemoryLevel.WORKING)) == 1
        assert len(memory.get(level=MemoryLevel.LONG_TERM)) == 0
        memory.get()
        assert len(memory.get(level=MemoryLevel.LONG_TERM)) >= 1

        promoted = memory.get(level=MemoryLevel.LONG_TERM)
        assert promoted[0].content == {"text": "important"}
        assert promoted[0].level == MemoryLevel.LONG_TERM

    def test_demote_inactive_entry(self):
        """A long-term entry left untouched past the limit returns to working."""
        config = MemoryConfig(
            promote_after_accesses=1,
            demote_inactive_after=0.1,
            cleanup_interval=0.0,
        )
        memory = AgentMemory("a", config)
        memory.add(content={"text": "data"}, level=MemoryLevel.LONG_TERM)
        assert len(memory._long_term) == 1

        # Stay idle for longer than demote_inactive_after.
        time.sleep(0.2)
        # Reading working_memory is the trigger the test relies on for demotion.
        _ = memory.working_memory

        assert len(memory._long_term) == 0
        assert any(item.content == {"text": "data"} for item in memory._working)

    def test_full_cycle_promote_then_demote(self):
        """Working -> promote -> long_term -> demote -> working."""
        config = MemoryConfig(
            promote_after_accesses=2,
            demote_inactive_after=0.1,
            cleanup_interval=0.0,
        )
        memory = AgentMemory("a", config)
        memory.add(content={"msg": "cycle"})

        # Two reads reach the promotion threshold.
        memory.get()
        memory.get()
        assert any(item.content == {"msg": "cycle"} for item in memory._long_term)

        # Idle past the demotion limit, then access working_memory.
        time.sleep(0.2)
        _ = memory.working_memory
        assert any(item.content == {"msg": "cycle"} for item in memory._working)
        assert all(item.content != {"msg": "cycle"} for item in memory._long_term)
class TestIntegrationCompression:
    """Behaviour of working memory when more entries are added than it may hold."""

    def test_truncate_on_overflow(self):
        """Ten adds against a cap of five leave exactly five working entries."""
        config = MemoryConfig(working_max_entries=5, auto_compress=True)
        memory = AgentMemory("a", config)
        for index in range(10):
            memory.add(content={"idx": index})
        assert len(memory._working) == 5

    def test_high_priority_survives_compression(self):
        """A priority-10 entry is still present after lower-priority overflow."""
        config = MemoryConfig(working_max_entries=3, auto_compress=True)
        memory = AgentMemory("a", config)
        memory.add(content={"text": "important"}, priority=10)
        for index in range(5):
            memory.add(content={"text": f"regular_{index}"}, priority=0)
        surviving_texts = [item.content["text"] for item in memory._working]
        assert "important" in surviving_texts
class TestIntegrationSharedMemoryPool:
    """Integration tests for sharing entries between agents through the pool."""

    def test_share_to_specific_agents(self):
        """An entry shared by A to B and C appears in both, tagged 'shared'."""
        pool = SharedMemoryPool()
        # All three memories are created first, then registered in the same order.
        mem_a, mem_b, mem_c = (AgentMemory(name) for name in ("a", "b", "c"))
        for memory in (mem_a, mem_b, mem_c):
            pool.register(memory)

        answer = mem_a.add(content={"role": "assistant", "content": "answer from A"})
        pool.share("a", answer, to_agents=["b", "c"])

        inbox_b = mem_b.get()
        inbox_c = mem_c.get()
        assert any(item.content["content"] == "answer from A" for item in inbox_b)
        assert any(item.content["content"] == "answer from A" for item in inbox_c)
        assert any("shared" in item.tags for item in inbox_b)

        # The sender's own memory must not gain a 'shared'-tagged entry.
        own_entries = mem_a.get()
        tagged_shared = [item for item in own_entries if "shared" in item.tags]
        assert len(tagged_shared) == 0

    def test_share_to_pool_with_access_filter(self):
        """With a SubgraphFilter, an agent from another subgraph sees nothing."""
        subgraph_filter = SubgraphFilter({"sg1": {"a", "b"}, "sg2": {"c"}})
        pool = SharedMemoryPool(access_filter=subgraph_filter)
        mem_a, mem_b, mem_c = (AgentMemory(name) for name in ("a", "b", "c"))
        for memory in (mem_a, mem_b, mem_c):
            pool.register(memory)

        secret = mem_a.add(content={"text": "secret"})
        pool.share("a", secret)

        # "b" is in the owner's subgraph sg1 ...
        seen_by_b = pool.get_shared("b")
        assert len(seen_by_b) == 1
        # ... while "c" belongs to sg2.
        seen_by_c = pool.get_shared("c")
        assert len(seen_by_c) == 0
class TestIntegrationBroadcast:
    """Integration tests for SharedMemoryPool.broadcast()."""

    def test_broadcast_reaches_all_except_sender(self):
        """A broadcast from A is delivered to B, C and D, and not back to A."""
        pool = SharedMemoryPool()
        memories = {}
        for agent_id in ["a", "b", "c", "d"]:
            memory = AgentMemory(agent_id)
            pool.register(memory)
            memories[agent_id] = memory

        pool.broadcast(
            from_agent="a",
            content={"role": "system", "content": "New rules"},
            tags={"instruction"},
        )

        sender_view = memories["a"].get(tags={"broadcast"})
        assert len(sender_view) == 0

        for agent_id in ["b", "c", "d"]:
            inbox = memories[agent_id].get(tags={"broadcast"})
            assert len(inbox) == 1
            delivered = inbox[0]
            assert delivered.content["content"] == "New rules"
            assert "broadcast" in delivered.tags
            assert "instruction" in delivered.tags

    def test_broadcast_goes_to_working_level(self):
        """The broadcast entry is found at WORKING level and not at LONG_TERM."""
        pool = SharedMemoryPool()
        mem_a = AgentMemory("a")
        mem_b = AgentMemory("b")
        pool.register(mem_a)
        pool.register(mem_b)

        pool.broadcast("a", content={"text": "update"})

        working_entries = mem_b.get(level=MemoryLevel.WORKING)
        assert any(item.content["text"] == "update" for item in working_entries)
        long_term_entries = mem_b.get(level=MemoryLevel.LONG_TERM)
        assert all(item.content.get("text") != "update" for item in long_term_entries)

    def test_broadcast_with_no_tags(self):
        """A broadcast sent without tags carries exactly the 'broadcast' tag."""
        pool = SharedMemoryPool()
        mem_a = AgentMemory("a")
        mem_b = AgentMemory("b")
        pool.register(mem_a)
        pool.register(mem_b)

        pool.broadcast("a", content={"msg": "hello"})

        inbox = mem_b.get()
        broadcasts = [item for item in inbox if "broadcast" in item.tags]
        assert len(broadcasts) == 1
        assert broadcasts[0].tags == {"broadcast"}

    def test_broadcast_source_agent_set(self):
        """The delivered entry records the broadcasting agent as source_agent."""
        pool = SharedMemoryPool()
        mem_a = AgentMemory("a")
        mem_b = AgentMemory("b")
        pool.register(mem_a)
        pool.register(mem_b)

        pool.broadcast("a", content={"msg": "hi"})

        inbox = mem_b.get()
        broadcasts = [item for item in inbox if "broadcast" in item.tags]
        assert len(broadcasts) == 1
        assert broadcasts[0].source_agent == "a"

    def test_multiple_broadcasts_accumulate(self):
        """Broadcasts from two different senders both accumulate at a third agent."""
        pool = SharedMemoryPool()
        memories = {}
        for agent_id in ["a", "b", "c"]:
            memory = AgentMemory(agent_id)
            pool.register(memory)
            memories[agent_id] = memory

        pool.broadcast("a", content={"msg": "from_a"})
        pool.broadcast("b", content={"msg": "from_b"})

        # "c" sent nothing, so it holds both broadcasts.
        inbox_c = memories["c"].get(tags={"broadcast"})
        assert len(inbox_c) == 2
        received_msgs = {item.content["msg"] for item in inbox_c}
        assert received_msgs == {"from_a", "from_b"}

        # "a" only receives the broadcast it did not send itself.
        inbox_a = memories["a"].get(tags={"broadcast"})
        assert len(inbox_a) == 1
        assert inbox_a[0].content["msg"] == "from_b"
class TestIntegrationTTLCleanup:
    """Integration tests for TTL expiry and explicit cleanup."""

    def test_expired_entries_not_returned(self):
        """get() omits an entry whose TTL elapsed and keeps the long-lived one."""
        memory = AgentMemory("a")
        memory.add(content={"text": "short-lived"}, ttl=0.1)
        memory.add(content={"text": "long-lived"}, ttl=9999)

        # Sleep past the 0.1 s TTL of the first entry.
        time.sleep(0.2)

        returned = memory.get()
        returned_texts = [item.content["text"] for item in returned]
        assert "short-lived" not in returned_texts
        assert "long-lived" in returned_texts

    def test_remove_expired_cleans_up(self):
        """remove_expired() drops the expired entry and reports how many it removed."""
        memory = AgentMemory("a")
        memory.add(content={"text": "temp"}, ttl=0.1)
        memory.add(content={"text": "perm"}, ttl=None)

        time.sleep(0.2)
        removed_count = memory.remove_expired()

        assert removed_count == 1
        assert len(memory._working) == 1
        assert memory._working[0].content["text"] == "perm"
class TestIntegrationFullPipeline:
    """End-to-end scenario: add -> compress -> promote -> share."""

    def test_end_to_end_memory_pipeline(self):
        """A high-priority entry survives overflow, is promoted, then shared to B."""
        config = MemoryConfig(
            working_max_entries=5,
            promote_after_accesses=2,
            auto_compress=True,
        )
        pool = SharedMemoryPool()
        mem_a = AgentMemory("a", config)
        mem_b = AgentMemory("b")
        pool.register(mem_a)
        pool.register(mem_b)

        # Step 1: one important entry followed by enough filler to exceed the cap.
        mem_a.add(content={"text": "key_entry"}, priority=10)
        for index in range(8):
            mem_a.add(content={"text": f"filler_{index}"}, priority=0)
        working_texts = [item.content["text"] for item in mem_a._working]
        assert "key_entry" in working_texts

        # Step 2: two reads reach promote_after_accesses.
        mem_a.get()
        mem_a.get()
        long_term = mem_a.get(level=MemoryLevel.LONG_TERM)
        long_term_texts = [item.content["text"] for item in long_term]
        assert "key_entry" in long_term_texts

        # Step 3: share the promoted entry directly with agent B.
        promoted = next(
            item for item in long_term if item.content["text"] == "key_entry"
        )
        pool.share("a", promoted, to_agents=["b"])
        inbox_b = mem_b.get()
        assert any(item.content["text"] == "key_entry" for item in inbox_b)
class TestLongTermCompression:
    """Tests for long-term compression (_maybe_compress_long_term)."""

    def test_long_term_auto_compress(self):
        """Exceeding long_term_max_entries with auto_compress on triggers compression."""
        config = MemoryConfig(
            long_term_max_entries=3,
            auto_compress=True,
            compression_strategy=TruncateCompressor(),
        )
        memory = AgentMemory(agent_id="test", config=config)

        # Five long-term entries against a cap of three.
        for index in range(5):
            memory.add(
                content={"text": f"entry {index}"},
                level=MemoryLevel.LONG_TERM,
            )

        # What remains must fit within the configured cap.
        retained = memory.get(level=MemoryLevel.LONG_TERM)
        assert len(retained) <= config.long_term_max_entries

    def test_long_term_auto_compress_disabled(self):
        """With auto_compress=False, all entries above the cap are kept."""
        config = MemoryConfig(
            long_term_max_entries=3,
            auto_compress=False,
        )
        memory = AgentMemory(agent_id="test", config=config)

        for index in range(5):
            memory.add(
                content={"text": f"entry {index}"},
                level=MemoryLevel.LONG_TERM,
            )

        retained = memory.get(level=MemoryLevel.LONG_TERM)
        assert len(retained) == 5  # nothing was compressed away
if __name__ == "__main__":
    # pytest.main() returns the exit code of the test run. Propagate it as the
    # process exit status so that running this file directly reports failures
    # to the shell / CI instead of always exiting with status 0.
    raise SystemExit(pytest.main([__file__, "-v"]))