gMAS / src /utils /memory.py
Артём Боярских
chore: initial commit
3193174
"""
Agent memory system with stratified levels, sharing, and hidden channels.
Supports:
- Memory stratification (working/long-term memory)
- TTL and size limits
- Compression and truncation
- Cross-subgraph sharing with access filters
- External storage backends
- Hidden channels (hidden_state, embeddings)
"""
import time
from abc import ABC, abstractmethod
from collections.abc import Callable
from enum import Enum
from typing import Any, Protocol, runtime_checkable
import torch
from pydantic import BaseModel, ConfigDict, Field
# Names exported by ``from <module> import *``; kept in alphabetical order.
# NOTE(review): TagBasedFilter, SubgraphFilter and RoleFamilyFilter are defined
# in this file but are not listed here — confirm that omitting them is intended.
__all__ = [
"AccessFilter",
"AgentMemory",
"AsyncMemoryStorage",
"CompressionStrategy",
"HiddenChannel",
"MemoryConfig",
"MemoryEntry",
"MemoryLevel",
"MemoryStorage",
"Message",
"MessageProtocol",
"SharedMemoryPool",
"SharingPolicy",
"SummaryCompressor",
"TruncateCompressor",
]
class MemoryLevel(str, Enum):
    """Level tag carried by a memory entry.

    Members are ``str`` subclasses, so they compare equal to their plain
    string values.
    """

    # Default level of a newly created entry.
    WORKING = "working"
    # Level assigned to entries that are promoted or summarized.
    LONG_TERM = "long_term"
    # Level stamped on copies created when an entry is shared.
    SHARED = "shared"
class MemoryEntry(BaseModel):
    """One stored item plus the bookkeeping used for expiry and ranking."""

    # Payload of the entry.
    content: dict[str, Any]
    level: MemoryLevel = MemoryLevel.WORKING
    # Both timestamps default to time.time() at construction.
    created_at: float = Field(default_factory=time.time)
    accessed_at: float = Field(default_factory=time.time)
    # Lifetime in seconds counted from created_at; None means "never expires".
    ttl: float | None = None
    priority: int = 0
    tags: set[str] = Field(default_factory=set)
    source_agent: str | None = None

    @property
    def age(self) -> float:
        """Seconds elapsed since the entry was created."""
        return time.time() - self.created_at

    @property
    def is_expired(self) -> bool:
        """Whether the entry has outlived its TTL (always False without a TTL)."""
        lifetime = self.ttl
        return lifetime is not None and self.age > lifetime

    def touch(self) -> None:
        """Record the current time as the moment of last access."""
        self.accessed_at = time.time()
class CompressionStrategy(ABC):
    """Interface for policies that reduce a list of MemoryEntry objects."""

    @abstractmethod
    def compress(self, entries: list[MemoryEntry], max_entries: int) -> list[MemoryEntry]:
        """Return the entries to retain for the given ``max_entries`` limit.

        Args:
            entries: Entries currently stored.
            max_entries: Upper bound on the number of entries to keep.

        Returns:
            The list of entries that should replace ``entries``.
        """
class TruncateCompressor(CompressionStrategy):
    """Compressor that keeps the best-ranked entries and drops the rest.

    Entries are ranked by priority (higher first when ``keep_high_priority``)
    and then by last-access time (most recent first when ``keep_recent``,
    oldest first otherwise).
    """

    def __init__(self, *, keep_recent: bool = True, keep_high_priority: bool = True):
        """Configure the ranking.

        Args:
            keep_recent: Prefer recently accessed entries; when False the
                oldest-accessed entries are preferred instead.
            keep_high_priority: Rank by priority before access time; when
                False priority is ignored.
        """
        self.keep_recent = keep_recent
        self.keep_high_priority = keep_high_priority

    def compress(
        self,
        entries: list[MemoryEntry],
        max_entries: int,
    ) -> list[MemoryEntry]:
        """Trim ``entries`` down to at most ``max_entries`` items.

        Args:
            entries: Entries to rank.
            max_entries: Maximum number of entries to keep.

        Returns:
            ``entries`` itself when it already fits the limit; otherwise a new
            list holding the top-ranked entries in ranking order (not in the
            original insertion order). A limit of zero or less yields ``[]``.
        """
        if len(entries) <= max_entries:
            return entries
        if max_entries <= 0:
            # A negative limit would slice as [:-k] and keep nearly everything.
            return []

        def sort_key(e: MemoryEntry) -> tuple[int, float]:
            # Negation turns the ascending sort into "largest first".
            priority = -e.priority if self.keep_high_priority else 0
            age = -e.accessed_at if self.keep_recent else e.accessed_at
            return (priority, age)

        return sorted(entries, key=sort_key)[:max_entries]
class SummaryCompressor(CompressionStrategy):
    """Compressor that collapses old entries into a summary using a summarizer callable."""

    def __init__(
        self,
        summarizer: Callable[[list[dict[str, Any]]], dict[str, Any]] | None = None,
        batch_size: int = 5,
    ):
        """Configure the compressor.

        Args:
            summarizer: Callable that receives the ``content`` dicts of the
                entries being collapsed and returns the summary content. When
                None, compression falls back to keeping the trailing entries.
            batch_size: Stored on the instance; ``compress`` does not read it.
        """
        self.summarizer = summarizer
        self.batch_size = batch_size

    def compress(
        self,
        entries: list[MemoryEntry],
        max_entries: int,
    ) -> list[MemoryEntry]:
        """Compress entries, keeping a summary of old ones plus the most recent.

        Args:
            entries: Entries in the order they should be considered; the tail
                of the list is treated as the most recent.
            max_entries: Maximum number of entries in the result.

        Returns:
            ``entries`` unchanged when it already fits; ``[]`` for a limit of
            zero or less; the last ``max_entries`` entries when no summarizer
            is configured; otherwise one summary entry followed by the last
            ``max_entries - 1`` original entries.
        """
        if len(entries) <= max_entries:
            return entries
        if max_entries <= 0:
            # entries[-0:] is the whole list, so a zero limit needs its own branch.
            return []
        if self.summarizer is None:
            return entries[-max_entries:]

        # Use an explicit split index: slicing with -(max_entries - 1) breaks
        # for max_entries == 1, where -0 selects everything instead of nothing.
        split = len(entries) - (max_entries - 1)
        to_summarize = entries[:split]  # always >= 2 entries at this point
        to_keep = entries[split:]

        summary_entry = MemoryEntry(
            content=self.summarizer([e.content for e in to_summarize]),
            # The summary is always tagged LONG_TERM and gets no TTL.
            level=MemoryLevel.LONG_TERM,
            priority=max(e.priority for e in to_summarize),
            tags=set().union(*(e.tags for e in to_summarize)),
        )
        return [summary_entry, *to_keep]
class MemoryConfig(BaseModel):
    """Settings for agent memory: capacities, TTLs, compression and housekeeping."""

    # Needed because CompressionStrategy is not a pydantic-native type.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Working level: size limit and the TTL (seconds) used when none is given.
    working_max_entries: int = 20
    working_default_ttl: float | None = 3600.0

    # Long-term level: size limit and default TTL (None = no expiry).
    long_term_max_entries: int = 100
    long_term_default_ttl: float | None = None

    # Strategy invoked when a level exceeds its limit and auto_compress is on.
    compression_strategy: CompressionStrategy = Field(default_factory=TruncateCompressor)
    auto_compress: bool = True

    # Number of reads after which a working entry is promoted.
    promote_after_accesses: int = 3
    # Seconds without access after which a long-term entry is demoted.
    demote_inactive_after: float = 7200.0
    # Minimum number of seconds between two cleanup passes.
    cleanup_interval: float = 300.0

    # Not read by the code visible in this file; presumably the expected size
    # of the hidden-state tensor — TODO confirm.
    hidden_state_dim: int | None = None
class AgentMemory:
    """Agent memory manager with working/long-term levels and hidden channels.

    Entries are kept in two lists (working and long-term). Reading through
    ``get`` refreshes access timestamps and may promote frequently read
    working entries. Accessing ``working_memory`` periodically triggers a
    cleanup that drops expired entries and demotes inactive long-term ones.
    """

    def __init__(
        self,
        agent_id: str,
        config: MemoryConfig | None = None,
    ):
        """Create an empty memory.

        Args:
            agent_id: Identifier of the owning agent.
            config: Memory settings; a default ``MemoryConfig`` is created
                when omitted.
        """
        self.agent_id = agent_id
        self.config = config or MemoryConfig()
        self._working: list[MemoryEntry] = []
        self._long_term: list[MemoryEntry] = []
        # Read counters keyed by id(entry). Ids may be reused once an object
        # is freed, so counters are pruned whenever entries are dropped.
        self._access_counts: dict[int, int] = {}
        self._last_cleanup = time.time()
        self._hidden_state: torch.Tensor | None = None
        self._embedding: torch.Tensor | None = None

    @property
    def working_memory(self) -> list[MemoryEntry]:
        """Active working-memory entries (expired entries excluded).

        Reading this property may run the periodic cleanup first.
        """
        self._maybe_cleanup()
        return [e for e in self._working if not e.is_expired]

    @property
    def long_term_memory(self) -> list[MemoryEntry]:
        """Active long-term memory entries (expired entries excluded)."""
        return [e for e in self._long_term if not e.is_expired]

    @property
    def all_entries(self) -> list[MemoryEntry]:
        """All entries from both zones (with TTL filtering applied)."""
        return self.working_memory + self.long_term_memory

    @property
    def hidden_state(self) -> torch.Tensor | None:
        """Hidden-state tensor attached to this memory, if any."""
        return self._hidden_state

    @hidden_state.setter
    def hidden_state(self, value: torch.Tensor | None) -> None:
        self._hidden_state = value

    @property
    def embedding(self) -> torch.Tensor | None:
        """Embedding tensor attached to this memory, if any."""
        return self._embedding

    @embedding.setter
    def embedding(self, value: torch.Tensor | None) -> None:
        self._embedding = value

    def add(
        self,
        content: dict[str, Any],
        level: MemoryLevel = MemoryLevel.WORKING,
        ttl: float | None = None,
        priority: int = 0,
        tags: set[str] | None = None,
        source_agent: str | None = None,
    ) -> MemoryEntry:
        """Add an entry to the specified memory level.

        Args:
            content: Payload of the entry.
            level: ``WORKING`` stores the entry in the working list; any
                other level stores it in the long-term list.
            ttl: Lifetime in seconds. When None, the configured default for
                the target list is used.
            priority: Ranking hint used by compression strategies.
            tags: Tags attached to the entry.
            source_agent: Identifier of the agent the entry came from.

        Returns:
            The newly created entry. Note that auto-compression runs right
            after insertion and may already have dropped it from the list.
        """
        if ttl is None:
            ttl = (
                self.config.working_default_ttl
                if level == MemoryLevel.WORKING
                else self.config.long_term_default_ttl
            )
        entry = MemoryEntry(
            content=content,
            level=level,
            ttl=ttl,
            priority=priority,
            tags=tags or set(),
            source_agent=source_agent,
        )
        if level == MemoryLevel.WORKING:
            self._working.append(entry)
            self._maybe_compress_working()
        else:
            self._long_term.append(entry)
            self._maybe_compress_long_term()
        return entry

    def add_message(
        self,
        role: str,
        content: str,
        **metadata,
    ) -> MemoryEntry:
        """Convenience method for adding a chat message to working memory.

        Extra keyword arguments are merged into the entry's content dict.
        """
        return self.add(
            content={"role": role, "content": content, **metadata},
            level=MemoryLevel.WORKING,
        )

    def get(
        self,
        level: MemoryLevel | None = None,
        tags: set[str] | None = None,
        limit: int | None = None,
        *,
        include_expired: bool = False,
    ) -> list[MemoryEntry]:
        """Retrieve entries filtered by level, tags, and/or limit.

        Every returned entry is touched and has its read counter incremented;
        working-level entries that reach ``promote_after_accesses`` reads are
        promoted to long-term memory as a side effect.

        Args:
            level: ``WORKING`` or ``LONG_TERM`` restricts the lookup to that
                list; any other value (including None) reads both.
            tags: Keep only entries sharing at least one of these tags.
            limit: Keep only the last ``limit`` matches. None and 0 both mean
                "no limit".
            include_expired: Also return entries whose TTL has elapsed.

        Returns:
            A new list of matching entries (never the internal list itself).
        """
        # Always work on a snapshot: promotion below removes entries from
        # self._working, which must not happen to the list being iterated.
        if level == MemoryLevel.WORKING:
            entries = list(self._working)
        elif level == MemoryLevel.LONG_TERM:
            entries = list(self._long_term)
        else:
            entries = self._working + self._long_term
        if not include_expired:
            entries = [e for e in entries if not e.is_expired]
        if tags:
            entries = [e for e in entries if tags & e.tags]
        if limit:
            entries = entries[-limit:]
        for entry in entries:
            entry.touch()
            entry_id = id(entry)
            self._access_counts[entry_id] = self._access_counts.get(entry_id, 0) + 1
            if (
                entry.level == MemoryLevel.WORKING
                and self._access_counts[entry_id] >= self.config.promote_after_accesses
            ):
                self._promote(entry)
        return entries

    def get_messages(self, limit: int | None = None) -> list[dict[str, Any]]:
        """Return contents of entries that contain a 'role' field (chat messages).

        ``limit`` is applied to the entries before non-message entries are
        filtered out, so fewer than ``limit`` messages may be returned.
        """
        entries = self.get(limit=limit)
        return [e.content for e in entries if "role" in e.content]

    def clear(self, level: MemoryLevel | None = None) -> None:
        """Clear the specified memory level, or all levels if None.

        All read counters are reset regardless of which level is cleared.
        """
        if level is None or level == MemoryLevel.WORKING:
            self._working.clear()
        if level is None or level == MemoryLevel.LONG_TERM:
            self._long_term.clear()
        self._access_counts.clear()

    def remove_expired(self) -> int:
        """Remove expired entries and return the number removed."""
        before = len(self._working) + len(self._long_term)
        self._working = [e for e in self._working if not e.is_expired]
        self._long_term = [e for e in self._long_term if not e.is_expired]
        removed = before - (len(self._working) + len(self._long_term))
        if removed:
            self._prune_access_counts()
        return removed

    @staticmethod
    def _remove_by_identity(entries: list[MemoryEntry], entry: MemoryEntry) -> bool:
        """Delete ``entry`` itself from ``entries``; return whether it was found.

        ``list.remove`` compares with ``==``, which for pydantic models means
        field equality and could delete a different but equal entry.
        """
        for index, candidate in enumerate(entries):
            if candidate is entry:
                del entries[index]
                return True
        return False

    def _prune_access_counts(self) -> None:
        """Forget read counters of entries that are no longer stored."""
        live_ids = {id(e) for e in self._working}
        live_ids.update(id(e) for e in self._long_term)
        self._access_counts = {
            entry_id: count for entry_id, count in self._access_counts.items() if entry_id in live_ids
        }

    def _promote(self, entry: MemoryEntry) -> None:
        """Move ``entry`` from working to long-term memory, resetting its TTL."""
        if self._remove_by_identity(self._working, entry):
            entry.level = MemoryLevel.LONG_TERM
            entry.ttl = self.config.long_term_default_ttl
            self._long_term.append(entry)

    def _demote(self, entry: MemoryEntry) -> None:
        """Move ``entry`` from long-term to working memory, resetting its TTL.

        NOTE(review): expiry is measured from ``created_at``, which is not
        reset here, so a demoted entry older than the working TTL is expired
        immediately — confirm this is the intended way to age entries out.
        """
        if self._remove_by_identity(self._long_term, entry):
            entry.level = MemoryLevel.WORKING
            entry.ttl = self.config.working_default_ttl
            self._working.append(entry)

    def _maybe_compress_working(self) -> None:
        """Compress working memory when auto-compression is on and the limit is exceeded."""
        if not self.config.auto_compress:
            return
        limit = self.config.working_max_entries
        if len(self._working) > limit:
            self._working = self.config.compression_strategy.compress(self._working, limit)
            self._prune_access_counts()

    def _maybe_compress_long_term(self) -> None:
        """Compress long-term memory when auto-compression is on and the limit is exceeded."""
        if not self.config.auto_compress:
            return
        limit = self.config.long_term_max_entries
        if len(self._long_term) > limit:
            self._long_term = self.config.compression_strategy.compress(self._long_term, limit)
            self._prune_access_counts()

    def _maybe_cleanup(self) -> None:
        """Run housekeeping if at least ``cleanup_interval`` seconds have passed.

        Housekeeping removes expired entries and demotes long-term entries
        that have not been accessed for ``demote_inactive_after`` seconds.
        """
        now = time.time()
        if now - self._last_cleanup < self.config.cleanup_interval:
            return
        self._last_cleanup = now
        self.remove_expired()
        inactive_threshold = now - self.config.demote_inactive_after
        # Iterate over a copy: _demote mutates self._long_term.
        for entry in list(self._long_term):
            if entry.accessed_at < inactive_threshold:
                self._demote(entry)

    @staticmethod
    def _entry_to_dict(entry: MemoryEntry) -> dict[str, Any]:
        """Serialize the persistent fields of one entry."""
        return {
            "content": entry.content,
            "created_at": entry.created_at,
            "ttl": entry.ttl,
            "priority": entry.priority,
            "tags": list(entry.tags),
            "source_agent": entry.source_agent,
        }

    def to_dict(self) -> dict[str, Any]:
        """Serialize agent memory for persistence.

        Expired entries are included as-is. Access timestamps and read
        counters are not persisted; tensors are converted to nested lists.
        """
        return {
            "agent_id": self.agent_id,
            "working": [self._entry_to_dict(e) for e in self._working],
            "long_term": [self._entry_to_dict(e) for e in self._long_term],
            "hidden_state": (self._hidden_state.cpu().tolist() if self._hidden_state is not None else None),
            "embedding": (self._embedding.cpu().tolist() if self._embedding is not None else None),
        }

    @classmethod
    def from_dict(
        cls,
        data: dict[str, Any],
        config: MemoryConfig | None = None,
    ) -> "AgentMemory":
        """Restore an AgentMemory instance from a dictionary.

        Args:
            data: Mapping in the layout produced by ``to_dict``. Only
                ``agent_id`` and each entry's ``content`` are required.
            config: Settings for the restored memory.

        Returns:
            A new memory holding the stored entries. Entries are appended
            directly, without triggering compression.
        """
        memory = cls(data["agent_id"], config)
        for key, level, target in (
            ("working", MemoryLevel.WORKING, memory._working),
            ("long_term", MemoryLevel.LONG_TERM, memory._long_term),
        ):
            for e_data in data.get(key, []):
                target.append(
                    MemoryEntry(
                        content=e_data["content"],
                        level=level,
                        created_at=e_data.get("created_at", time.time()),
                        ttl=e_data.get("ttl"),
                        priority=e_data.get("priority", 0),
                        tags=set(e_data.get("tags", [])),
                        source_agent=e_data.get("source_agent"),
                    )
                )
        # Compare against None: an empty list or a scalar 0 is a valid
        # serialized tensor but is falsy.
        hidden_state = data.get("hidden_state")
        if hidden_state is not None:
            memory._hidden_state = torch.tensor(hidden_state)
        embedding = data.get("embedding")
        if embedding is not None:
            memory._embedding = torch.tensor(embedding)
        return memory
class AccessFilter(ABC):
    """Contract for objects that decide who may read an entry."""

    @abstractmethod
    def can_access(self, requester_id: str, owner_id: str, entry: MemoryEntry) -> bool:
        """Decide whether ``requester_id`` may read ``entry`` owned by ``owner_id``.

        Args:
            requester_id: Identifier of the agent asking for access.
            owner_id: Identifier of the agent the entry belongs to.
            entry: The entry being requested.

        Returns:
            True when access is granted.
        """
class SharingPolicy(str, Enum):
    """Names of the available sharing policies.

    Members are ``str`` subclasses, so they compare equal to their plain
    string values.
    """

    NONE = "none"
    SAME_SUBGRAPH = "same_subgraph"
    BY_TAGS = "by_tags"
    BY_ROLE_FAMILY = "by_role_family"
    FULL = "full"
class TagBasedFilter(AccessFilter):
    """Filter that exposes an entry when it carries at least one shared tag."""

    def __init__(self, shared_tags: set[str]):
        """Remember the set of tags that make an entry readable."""
        self.shared_tags = shared_tags

    def can_access(self, requester_id: str, owner_id: str, entry: MemoryEntry) -> bool:
        """Return True when ``entry.tags`` intersects ``shared_tags``.

        The requester and owner identifiers play no role in the decision.
        """
        del requester_id, owner_id  # Unused in this filter implementation
        overlap = entry.tags & self.shared_tags
        return len(overlap) > 0
class SubgraphFilter(AccessFilter):
    """Filter that only lets agents read entries owned within their own subgraph."""

    def __init__(self, subgraph_members: dict[str, set[str]]):
        """Build the agent -> subgraph lookup.

        Args:
            subgraph_members: Mapping of subgraph id to its member agent ids.
                An agent listed under several subgraphs is recorded under the
                one that comes last in iteration order.
        """
        self.subgraph_members = subgraph_members
        self._agent_to_subgraph: dict[str, str] = {
            member: sg_id for sg_id, members in subgraph_members.items() for member in members
        }

    def can_access(self, requester_id: str, owner_id: str, entry: MemoryEntry) -> bool:
        """Return True when requester and owner map to the same subgraph.

        A requester that belongs to no known subgraph is always refused.
        """
        del entry  # Unused in this filter implementation
        lookup = self._agent_to_subgraph
        requester_group = lookup.get(requester_id)
        owner_group = lookup.get(owner_id)
        if requester_group is None:
            return False
        return requester_group == owner_group
class RoleFamilyFilter(AccessFilter):
    """Filter that only lets agents read entries owned within their own role family."""

    def __init__(self, role_families: dict[str, set[str]]):
        """Build the agent -> family lookup.

        Args:
            role_families: Mapping of family name to its member agent ids.
                An agent listed under several families is recorded under the
                one that comes last in iteration order.
        """
        self.role_families = role_families
        self._agent_to_family: dict[str, str] = {
            member: family for family, members in role_families.items() for member in members
        }

    def can_access(self, requester_id: str, owner_id: str, entry: MemoryEntry) -> bool:
        """Return True when requester and owner map to the same role family.

        A requester that belongs to no known family is always refused.
        """
        del entry  # Unused in this filter implementation
        lookup = self._agent_to_family
        requester_family = lookup.get(requester_id)
        owner_family = lookup.get(owner_id)
        if requester_family is None:
            return False
        return requester_family == owner_family
class SharedMemoryPool:
    """Shared memory pool with configurable access and propagation policies.

    The pool tracks registered agent memories and a list of pooled entries
    that any requester may query, subject to the optional access filter.
    """

    def __init__(
        self,
        access_filter: AccessFilter | None = None,
        default_policy: SharingPolicy = SharingPolicy.BY_TAGS,
    ):
        """Create an empty pool.

        Args:
            access_filter: Filter consulted when reading pooled or per-agent
                entries; None disables filtering.
            default_policy: Stored on the instance; no method of this class
                reads it.
        """
        self.access_filter = access_filter
        self.default_policy = default_policy
        self._memories: dict[str, AgentMemory] = {}
        self._shared_entries: list[MemoryEntry] = []

    def register(self, memory: AgentMemory) -> None:
        """Register ``memory`` under its ``agent_id``, replacing any previous one."""
        self._memories[memory.agent_id] = memory

    def unregister(self, agent_id: str) -> None:
        """Forget the memory of ``agent_id``; unknown ids are ignored."""
        self._memories.pop(agent_id, None)

    def share(
        self,
        from_agent: str,
        entry: MemoryEntry,
        to_agents: list[str] | None = None,
    ) -> None:
        """Share an entry with specific agents or place it in the shared pool.

        Args:
            from_agent: Identifier recorded as the source of the shared copy.
            entry: Entry to share; its content and tags are shallow-copied.
            to_agents: Recipients. None publishes the copy to the pool. A
                list delivers a working-memory entry (with the extra tag
                ``"shared"`` and the recipient's default TTL) to each
                registered recipient; unregistered ids are skipped, and an
                empty list therefore shares with nobody.
        """
        shared_entry = MemoryEntry(
            content=entry.content.copy(),
            level=MemoryLevel.SHARED,
            ttl=entry.ttl,
            priority=entry.priority,
            tags=entry.tags.copy(),
            source_agent=from_agent,
        )
        # Test against None rather than truthiness: an explicitly empty
        # recipient list must not fall through to the public pool.
        if to_agents is None:
            self._shared_entries.append(shared_entry)
            return
        for agent_id in to_agents:
            memory = self._memories.get(agent_id)
            if memory is None:
                continue
            memory.add(
                content=shared_entry.content,
                level=MemoryLevel.WORKING,
                priority=shared_entry.priority,
                tags=shared_entry.tags | {"shared"},
                source_agent=from_agent,
            )

    def get_shared(
        self,
        requester_id: str,
        tags: set[str] | None = None,
        limit: int | None = None,
    ) -> list[MemoryEntry]:
        """Retrieve entries from the shared pool, respecting the access filter.

        Args:
            requester_id: Agent asking for the entries.
            tags: Keep only entries sharing at least one of these tags.
            limit: Keep only the last ``limit`` matches. None and 0 both mean
                "no limit".

        Returns:
            Non-expired pooled entries visible to the requester. Entries
            without a ``source_agent`` bypass the access filter.
        """
        # Expired entries can never be returned again, so drop them here to
        # keep the pool from growing without bound.
        self._shared_entries = [e for e in self._shared_entries if not e.is_expired]
        visible = []
        for entry in self._shared_entries:
            if (
                self.access_filter
                and entry.source_agent
                and not self.access_filter.can_access(requester_id, entry.source_agent, entry)
            ):
                continue
            if tags and not (tags & entry.tags):
                continue
            visible.append(entry)
        if limit:
            visible = visible[-limit:]
        return visible

    def get_from_agent(
        self,
        requester_id: str,
        owner_id: str,
        level: MemoryLevel | None = None,
        tags: set[str] | None = None,
    ) -> list[MemoryEntry]:
        """Retrieve entries from a specific agent, subject to the access filter.

        Returns an empty list when ``owner_id`` is not registered.
        """
        if owner_id not in self._memories:
            return []
        owner_memory = self._memories[owner_id]
        # NOTE(review): AgentMemory.get touches entries and may promote them
        # before the filter below runs, so refused requests still leave side
        # effects on the owner's memory — confirm this is acceptable.
        entries = owner_memory.get(level=level, tags=tags)
        if self.access_filter:
            entries = [e for e in entries if self.access_filter.can_access(requester_id, owner_id, e)]
        return entries

    def broadcast(
        self,
        from_agent: str,
        content: dict[str, Any],
        tags: set[str] | None = None,
    ) -> None:
        """Broadcast an entry to all registered agents except the sender.

        Each recipient gets a working-memory entry tagged with ``tags`` plus
        ``"broadcast"``. The access filter is not consulted.
        """
        broadcast_tags = (tags or set()) | {"broadcast"}
        for agent_id, memory in self._memories.items():
            if agent_id == from_agent:
                continue
            memory.add(
                content=content,
                level=MemoryLevel.WORKING,
                tags=broadcast_tags,
                source_agent=from_agent,
            )
@runtime_checkable
class MemoryStorage(Protocol):
    """Structural interface of a synchronous storage backend for AgentMemory.

    Because the protocol is runtime-checkable, ``isinstance`` only verifies
    that the four methods exist, not their signatures.
    """

    def save_memory(self, agent_id: str, memory: AgentMemory) -> None:
        """Store ``memory`` for ``agent_id``."""

    def load_memory(self, agent_id: str) -> AgentMemory | None:
        """Return the memory stored for ``agent_id``.

        None presumably signals that nothing is stored — confirm with the
        concrete backend.
        """

    def delete_memory(self, agent_id: str) -> None:
        """Delete whatever is stored for ``agent_id``."""

    def list_agents(self) -> list[str]:
        """Return the agent ids known to the backend."""
@runtime_checkable
class AsyncMemoryStorage(Protocol):
    """Structural interface of an asynchronous storage backend for AgentMemory.

    Because the protocol is runtime-checkable, ``isinstance`` only verifies
    that the four methods exist, not that they are coroutines.
    """

    async def save_memory(self, agent_id: str, memory: AgentMemory) -> None:
        """Store ``memory`` for ``agent_id``."""

    async def load_memory(self, agent_id: str) -> AgentMemory | None:
        """Return the memory stored for ``agent_id``.

        None presumably signals that nothing is stored — confirm with the
        concrete backend.
        """

    async def delete_memory(self, agent_id: str) -> None:
        """Delete whatever is stored for ``agent_id``."""

    async def list_agents(self) -> list[str]:
        """Return the agent ids known to the backend."""
class HiddenChannel(BaseModel):
    """Container for hidden state and embedding tensors with metadata."""

    # torch.Tensor is not a pydantic-native type.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    hidden_state: torch.Tensor | None = None
    embedding: torch.Tensor | None = None
    metadata: dict[str, Any] = Field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialize the hidden channel to a dictionary.

        Tensors are moved to the CPU and converted to (nested) Python lists;
        absent tensors are stored as None.
        """
        return {
            "hidden_state": (self.hidden_state.cpu().tolist() if self.hidden_state is not None else None),
            "embedding": (self.embedding.cpu().tolist() if self.embedding is not None else None),
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "HiddenChannel":
        """Restore a HiddenChannel from a dictionary.

        Args:
            data: Mapping in the layout produced by ``to_dict``; all keys are
                optional.

        Returns:
            A channel whose tensors are rebuilt with ``torch.tensor``.
        """
        # Compare against None: an empty list or a scalar 0 is a valid
        # serialized tensor but is falsy.
        hidden_state = data.get("hidden_state")
        embedding = data.get("embedding")
        return cls(
            hidden_state=(torch.tensor(hidden_state) if hidden_state is not None else None),
            embedding=(torch.tensor(embedding) if embedding is not None else None),
            metadata=data.get("metadata", {}),
        )
class Message(BaseModel):
    """A message made of visible text and, optionally, a hidden channel."""

    sender_id: str
    # Declared without a default value.
    receiver_id: str | None
    content: str
    role: str = "assistant"
    timestamp: float = Field(default_factory=time.time)
    hidden: HiddenChannel | None = None
    message_type: str = "response"
    priority: int = 0
    tags: set[str] = Field(default_factory=set)

    @property
    def has_hidden(self) -> bool:
        """Whether a hidden channel with at least one tensor is attached."""
        channel = self.hidden
        if channel is None:
            return False
        return channel.hidden_state is not None or channel.embedding is not None

    def to_visible_dict(self) -> dict[str, Any]:
        """Build the visible view of the message: role, content and sender."""
        return dict(role=self.role, content=self.content, sender=self.sender_id)

    def to_full_dict(self) -> dict[str, Any]:
        """Build the visible view extended with metadata and the hidden channel.

        The ``hidden`` key is only present when a hidden channel is attached.
        """
        result = {
            **self.to_visible_dict(),
            "timestamp": self.timestamp,
            "message_type": self.message_type,
            "priority": self.priority,
            "tags": list(self.tags),
        }
        if self.hidden:
            result["hidden"] = self.hidden.to_dict()
        return result
class MessageProtocol:
    """Factory and helper for messages that may carry hidden tensors."""

    def __init__(
        self,
        *,
        enable_hidden: bool = True,
        hidden_dim: int | None = None,
        combine_hidden: Callable[[list[torch.Tensor]], torch.Tensor] | None = None,
    ):
        """Configure the protocol.

        Args:
            enable_hidden: When False, ``create_message`` never attaches a
                hidden channel.
            hidden_dim: Stored on the instance; not read by any method here.
            combine_hidden: Callable that merges several hidden states into
                one; element-wise averaging is used when omitted.
        """
        self.enable_hidden = enable_hidden
        self.hidden_dim = hidden_dim
        if not combine_hidden:
            combine_hidden = self._default_combine
        self.combine_hidden = combine_hidden

    @staticmethod
    def _default_combine(tensors: list[torch.Tensor]) -> torch.Tensor:
        """Stack the tensors and average them along the new leading axis.

        Raises:
            ValueError: If ``tensors`` is empty.
        """
        if not tensors:
            msg = "No tensors to combine"
            raise ValueError(msg)
        return torch.stack(tensors).mean(dim=0)

    def create_message(
        self,
        sender_id: str,
        content: str,
        receiver_id: str | None = None,
        hidden_state: torch.Tensor | None = None,
        embedding: torch.Tensor | None = None,
        **kwargs,
    ) -> Message:
        """Build a Message, attaching hidden data when enabled and provided.

        Extra keyword arguments are forwarded to the ``Message`` constructor.
        """
        carries_tensor = hidden_state is not None or embedding is not None
        hidden = (
            HiddenChannel(hidden_state=hidden_state, embedding=embedding)
            if self.enable_hidden and carries_tensor
            else None
        )
        return Message(
            sender_id=sender_id,
            receiver_id=receiver_id,
            content=content,
            hidden=hidden,
            **kwargs,
        )

    def extract_hidden_states(
        self,
        messages: list[Message],
    ) -> list[torch.Tensor]:
        """Collect, in order, the hidden states of messages that carry one."""
        states = []
        for msg in messages:
            channel = msg.hidden
            if channel and channel.hidden_state is not None:
                states.append(channel.hidden_state)
        return states

    def combine_incoming_hidden(
        self,
        messages: list[Message],
    ) -> torch.Tensor | None:
        """Merge the hidden states of ``messages``; None when there are none."""
        states = self.extract_hidden_states(messages)
        return self.combine_hidden(states) if states else None

    def format_visible(
        self,
        messages: list[Message],
        agent_names: dict[str, str] | None = None,
    ) -> str:
        """Render the visible text of ``messages`` as one prompt string.

        Each message becomes ``[name]:`` followed by its content on the next
        line; messages are separated by a blank line. ``agent_names`` maps
        sender ids to display names, with the id itself as the fallback.
        """
        display_names = agent_names or {}
        return "\n\n".join(
            f"[{display_names.get(msg.sender_id, msg.sender_id)}]:\n{msg.content}" for msg in messages
        )