Spaces:
Sleeping
Sleeping
File size: 4,750 Bytes
75bea1c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 | from __future__ import annotations
"""Context memory management for the agent."""
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
@dataclass
class MemoryEntry:
    """One record held in the agent's short-term memory.

    Positional field order is: key, value, timestamp, source, relevance.
    Only ``key`` and ``value`` are required; the rest have defaults.
    """

    # Lookup key under which the record is stored.
    key: str
    # Arbitrary payload associated with the key.
    value: Any
    # Creation time; ``datetime.now`` called with no tz gives a naive local time.
    timestamp: datetime = field(default_factory=datetime.now)
    # Free-form label describing where the value came from.
    source: str = field(default="unknown")
    # Relevance score attached to the record; defaults to 1.0.
    relevance: float = field(default=1.0)
class ContextMemory:
    """Manages context and working memory for the agent.

    Three independent stores are kept:

    * short-term memory -- ``key -> MemoryEntry``, capped at ``max_entries``
      and evicted least-recently-stored first;
    * working memory -- a plain ``key -> value`` dict with no size cap;
    * conversation history -- a list of turns, each a dict with ``role``,
      ``content`` and an ISO-format ``timestamp`` string.
    """

    def __init__(self, max_entries: int = 100):
        """Initialize empty memory stores.

        Args:
            max_entries: Maximum number of short-term entries to keep. It is
                re-read on every ``store`` call, so changing it later takes
                effect on the next store.
        """
        self.max_entries = max_entries
        self._short_term: dict[str, MemoryEntry] = {}
        self._working: dict[str, Any] = {}
        self._conversation: list[dict[str, str]] = []

    def store(self, key: str, value: Any, source: str = "agent") -> None:
        """Store a value in short-term memory.

        Storing an existing key replaces its entry and makes it the most
        recently stored one. If the store then holds more than
        ``max_entries`` entries, the oldest ones are evicted.

        Args:
            key: Memory key
            value: Value to store
            source: Source of the information
        """
        # Drop any previous entry first so the key is re-inserted at the end
        # of the dict: iteration order then equals storage recency, which is
        # exactly the order ``_trim_oldest`` evicts in.
        self._short_term.pop(key, None)
        self._short_term[key] = MemoryEntry(key=key, value=value, source=source)
        if len(self._short_term) > self.max_entries:
            self._trim_oldest()

    def retrieve(self, key: str) -> Any | None:
        """Retrieve a value from short-term memory.

        Args:
            key: Memory key
        Returns:
            The stored value, or None if the key is absent. A stored ``None``
            value is indistinguishable from a missing key.
        """
        entry = self._short_term.get(key)
        return entry.value if entry is not None else None

    def update_working(self, key: str, value: Any) -> None:
        """Set (or overwrite) a value in working memory.

        Args:
            key: Memory key
            value: Value to store
        """
        self._working[key] = value

    def get_working(self, key: str, default: Any = None) -> Any:
        """Get a value from working memory.

        Args:
            key: Memory key
            default: Value returned when the key is absent
        Returns:
            Stored value or ``default``
        """
        return self._working.get(key, default)

    def add_conversation_turn(self, role: str, content: str) -> None:
        """Append a turn to the conversation history.

        The turn is stamped with the current local time in ISO format.

        Args:
            role: Message role (user/assistant)
            content: Message content
        """
        self._conversation.append({
            "role": role,
            "content": content,
            "timestamp": datetime.now().isoformat(),
        })

    def get_conversation_history(self, limit: int = 10) -> list[dict[str, str]]:
        """Get the most recent conversation turns.

        Args:
            limit: Maximum number of turns to return; ``limit <= 0`` returns
                an empty list.
        Returns:
            Up to ``limit`` most recent turns, oldest first. The list is new,
            but the turn dicts inside it are the stored objects.
        """
        # ``seq[-0:]`` is the whole sequence, so a zero (or negative) limit
        # must be handled explicitly instead of being sliced.
        if limit <= 0:
            return []
        return self._conversation[-limit:]

    def get_context_summary(self) -> dict[str, Any]:
        """Get a summary of the current context.

        Returns:
            Dictionary with the short-term keys, the working-memory keys,
            and the number of conversation turns.
        """
        return {
            "short_term_keys": list(self._short_term.keys()),
            "working_memory_keys": list(self._working.keys()),
            "conversation_length": len(self._conversation),
        }

    def clear_working(self) -> None:
        """Clear working memory only."""
        self._working.clear()

    def clear_all(self) -> None:
        """Clear short-term memory, working memory and conversation history."""
        self._short_term.clear()
        self._working.clear()
        self._conversation.clear()

    def _trim_oldest(self) -> None:
        """Evict the least recently stored entries so the store fits capacity.

        ``store`` re-inserts keys at the end of the dict, so iteration order
        is storage recency (oldest first). Evicting in that order avoids a
        sort, and it does not depend on wall-clock timestamps, which can tie
        or move backwards.
        """
        size = len(self._short_term)
        if not size:
            return
        # Evict a 10% batch (at least one) so trimming is not needed on every
        # store, but always enough to get back to ``max_entries`` even if the
        # limit was lowered after entries were added.
        batch = max(1, size // 10)
        excess = size - self.max_entries
        to_remove = min(size, max(batch, excess))
        for key in list(self._short_term)[:to_remove]:
            del self._short_term[key]

    def search(self, query: str) -> list[MemoryEntry]:
        """Search short-term memory by case-insensitive substring match.

        An entry matches when ``query`` occurs in its key or in ``str()`` of
        its value. An empty query therefore matches every entry.

        Args:
            query: Search query
        Returns:
            Matching entries, newest ``timestamp`` first.
        """
        needle = query.lower()
        # str(value) is evaluated before the key test, as in simple keyword
        # matching over the rendered value.
        results = [
            entry
            for entry in self._short_term.values()
            if needle in str(entry.value).lower() or needle in entry.key.lower()
        ]
        # Relevance is not scored yet; order by recency instead.
        results.sort(key=lambda e: e.timestamp, reverse=True)
        return results
|