File size: 6,866 Bytes
bb3ee41
 
 
 
 
 
bfe0e24
bb3ee41
 
 
 
 
 
 
bfe0e24
 
 
 
 
bb3ee41
 
 
 
 
bfe0e24
bb3ee41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bfe0e24
bb3ee41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
"""Short-term memory for episode-scoped data storage."""

from __future__ import annotations

import asyncio
from collections import OrderedDict
from datetime import datetime, timezone
from typing import Any, Generic, TypeVar

from pydantic import BaseModel, Field

T = TypeVar("T")


def _utc_now() -> datetime:
    """Return the current moment as a timezone-aware datetime in UTC."""
    return datetime.now(tz=timezone.utc)


class MemoryEntry(BaseModel, Generic[T]):
    """A single memory entry with metadata.

    Both timestamps default to the current timezone-aware UTC time, so a
    default-constructed entry never mixes naive and aware datetimes (which
    would make ``created_at`` and ``updated_at`` impossible to compare).
    """

    # Key the entry is stored under.
    key: str
    # Stored payload; any Python object is accepted.
    value: Any
    # Creation timestamp (timezone-aware UTC by default).
    created_at: datetime = Field(default_factory=_utc_now)
    # Last-modification timestamp. Uses the same aware-UTC factory as
    # ``created_at``; ``datetime.utcnow`` would yield a naive value and is
    # deprecated.
    updated_at: datetime = Field(default_factory=_utc_now)
    # Read counter; starts at zero.
    access_count: int = 0
    # Free-form labels; each instance gets its own fresh list.
    tags: list[str] = Field(default_factory=list)

    model_config = {"arbitrary_types_allowed": True}


class ShortTermMemory:
    """
    Episode-scoped memory using ordered, dictionary-based storage.

    This memory layer is designed for transient data that should persist
    only within a single episode. Switching to a different episode via
    ``set_episode`` discards everything stored so far.

    Entries are kept in insertion/update order: creating or updating a key
    moves it to the end, while reads leave the order untouched. When a new
    key would push the store past ``max_size``, entries are evicted from the
    front of that order.

    Every coroutine method acquires the same ``asyncio.Lock`` before touching
    the store.

    Attributes:
        max_size: Maximum number of entries allowed.
        _store: Internal ordered storage mapping keys to entries.
        _episode_id: Current episode identifier, or None until
            ``set_episode`` has been called.
        _lock: Lock guarding ``_store`` and ``_episode_id``.
    """

    def __init__(self, max_size: int = 100) -> None:
        """
        Initialize short-term memory.

        Args:
            max_size: Maximum number of entries to store. Defaults to 100.
                A value of zero or less means nothing is retained.
        """
        self.max_size = max_size
        self._store: OrderedDict[str, MemoryEntry] = OrderedDict()
        self._episode_id: str | None = None
        self._lock = asyncio.Lock()

    @property
    def episode_id(self) -> str | None:
        """Get the current episode ID."""
        return self._episode_id

    @property
    def size(self) -> int:
        """Get the current number of entries."""
        return len(self._store)

    async def set_episode(self, episode_id: str) -> None:
        """
        Switch to the given episode, clearing memory if the episode changed.

        Calling this with the ID that is already current is a no-op, so the
        stored entries survive repeated calls for the same episode.

        Args:
            episode_id: Unique identifier for the new episode.
        """
        async with self._lock:
            if self._episode_id != episode_id:
                self._store.clear()
                self._episode_id = episode_id

    async def set(
        self,
        key: str,
        value: Any,
        tags: list[str] | None = None,
    ) -> MemoryEntry:
        """
        Store a value in short-term memory.

        An existing key is updated in place (its tags are replaced only when
        ``tags`` is given) and becomes the most recent entry. A new key is
        appended; if the store is at capacity, the entries at the front of
        the order are evicted first to make room.

        Args:
            key: Unique key for the entry.
            value: Value to store.
            tags: Optional tags for categorization.

        Returns:
            The created or updated memory entry. When ``max_size`` is zero
            or less the entry is returned but not retained.
        """
        async with self._lock:
            now = datetime.now(timezone.utc)

            existing = self._store.get(key)
            if existing is not None:
                existing.value = value
                existing.updated_at = now
                if tags is not None:
                    # Copy so later mutation of the caller's list cannot
                    # silently change the stored tags.
                    existing.tags = list(tags)
                # Move to end (most recent)
                self._store.move_to_end(key)
                return existing

            entry = MemoryEntry(
                key=key,
                value=value,
                created_at=now,
                updated_at=now,
                tags=list(tags) if tags else [],
            )

            if self.max_size <= 0:
                # Zero capacity: there is nothing to evict and no room to
                # keep the entry (previously this path raised KeyError from
                # popitem on an empty store).
                return entry

            # Evict from the front until there is room; a loop also trims
            # the store if max_size was lowered after entries were added.
            while len(self._store) >= self.max_size:
                self._store.popitem(last=False)

            self._store[key] = entry
            return entry

    async def get(self, key: str, default: Any = None) -> Any:
        """
        Retrieve a value from short-term memory.

        A successful lookup increments the entry's ``access_count``.

        Args:
            key: Key to look up.
            default: Default value if key not found.

        Returns:
            The stored value or default.
        """
        async with self._lock:
            entry = self._store.get(key)
            if entry is None:
                return default
            entry.access_count += 1
            return entry.value

    async def get_entry(self, key: str) -> MemoryEntry | None:
        """
        Retrieve a full memory entry with metadata.

        A successful lookup increments the entry's ``access_count``.

        Args:
            key: Key to look up.

        Returns:
            The memory entry or None if not found.
        """
        async with self._lock:
            entry = self._store.get(key)
            if entry is not None:
                entry.access_count += 1
            return entry

    async def delete(self, key: str) -> bool:
        """
        Delete an entry from memory.

        Args:
            key: Key to delete.

        Returns:
            True if the key was found and deleted, False otherwise.
        """
        async with self._lock:
            if key in self._store:
                del self._store[key]
                return True
            return False

    async def clear(self) -> int:
        """
        Clear all entries from memory.

        The current episode ID is left unchanged.

        Returns:
            Number of entries that were cleared.
        """
        async with self._lock:
            count = len(self._store)
            self._store.clear()
            return count

    async def list_keys(self, tag: str | None = None) -> list[str]:
        """
        List all keys in memory, optionally filtered by tag.

        Keys are returned in store order (least recently set first).

        Args:
            tag: Optional tag to filter by.

        Returns:
            List of matching keys.
        """
        async with self._lock:
            if tag is None:
                return list(self._store)
            return [k for k, v in self._store.items() if tag in v.tags]

    async def get_by_tag(self, tag: str) -> dict[str, Any]:
        """
        Retrieve all entries with a specific tag.

        Unlike ``get``, this does not increment any access counters.

        Args:
            tag: Tag to filter by.

        Returns:
            Dictionary of key-value pairs matching the tag.
        """
        async with self._lock:
            return {
                k: v.value for k, v in self._store.items() if tag in v.tags
            }

    async def exists(self, key: str) -> bool:
        """
        Check if a key exists in memory.

        Args:
            key: Key to check.

        Returns:
            True if key exists, False otherwise.
        """
        async with self._lock:
            return key in self._store

    async def get_stats(self) -> dict[str, Any]:
        """
        Get statistics about the memory store.

        Returns:
            Dictionary with the keys ``size``, ``max_size``, ``episode_id``,
            ``keys`` and ``utilization`` (size divided by max_size, or 0.0
            when max_size is not positive).
        """
        async with self._lock:
            size = len(self._store)
            return {
                "size": size,
                "max_size": self.max_size,
                "episode_id": self._episode_id,
                "keys": list(self._store),
                # Guard the division so a zero/negative capacity reports 0.0
                # instead of raising ZeroDivisionError.
                "utilization": size / self.max_size if self.max_size > 0 else 0.0,
            }