File size: 14,649 Bytes
bb3ee41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
"""Unified memory manager providing access to all memory layers."""

from __future__ import annotations

import logging
from enum import Enum
from typing import Any

from pydantic import BaseModel, Field

from app.config import Settings
from app.memory.long_term import Document, LongTermMemory, SearchResult
from app.memory.shared import Message, SharedMemory
from app.memory.short_term import MemoryEntry, ShortTermMemory
from app.memory.working import WorkingMemory, WorkingMemoryItem

logger = logging.getLogger(__name__)


class MemoryType(str, Enum):
    """Identifiers for the memory layers addressable through MemoryManager.

    The class mixes in ``str``, so every member is also a string and compares
    equal to its plain string value.
    """

    SHORT_TERM = "short_term"  # routed to MemoryManager.short_term
    WORKING = "working"  # routed to MemoryManager.working
    LONG_TERM = "long_term"  # routed to MemoryManager.long_term
    SHARED = "shared"  # routed to MemoryManager.shared


class MemoryStats(BaseModel):
    """Container for the per-layer statistics built by MemoryManager.get_stats().

    Each field holds the dictionary returned by the matching layer's
    ``get_stats()`` call and defaults to a fresh empty dict when omitted.
    """

    # One free-form dictionary per layer; the key schema is whatever the
    # layer itself reports and is not constrained by this model.
    short_term: dict[str, Any] = Field(default_factory=dict)
    working: dict[str, Any] = Field(default_factory=dict)
    long_term: dict[str, Any] = Field(default_factory=dict)
    shared: dict[str, Any] = Field(default_factory=dict)


class MemoryManager:
    """
    Unified interface to all memory layers.

    The MemoryManager provides a single entry point for interacting with
    different types of memory (short-term, working, long-term, shared).
    It handles initialization, coordination, and lifecycle management.

    Attributes:
        short_term: Episode-scoped dictionary memory.
        working: LRU-based reasoning scratch space.
        long_term: Persistent vector storage.
        shared: Multi-agent shared state.
    """

    def __init__(self, settings: Settings) -> None:
        """
        Build the four memory layers from application settings.

        Only constructs the layer objects; ``initialize()`` still has to be
        awaited before ``is_initialized`` reports True.

        Args:
            settings: Application settings providing the layer sizes and the
                Chroma collection name / persistence directory.
        """
        self._settings = settings
        self._initialized = False

        # Short-term layer, bounded by the configured size.
        short_term_size = settings.short_term_memory_size
        self.short_term = ShortTermMemory(max_size=short_term_size)

        # Working layer, bounded by the configured capacity.
        working_capacity = settings.working_memory_size
        self.working = WorkingMemory(capacity=working_capacity)

        # Long-term layer, configured entirely from the Chroma-related settings.
        long_term_options = {
            "collection_name": settings.chroma_collection_name,
            "persist_directory": settings.chroma_persist_directory,
            "top_k": settings.long_term_memory_top_k,
        }
        self.long_term = LongTermMemory(**long_term_options)

        # Shared layer takes no configuration.
        self.shared = SharedMemory()

    async def initialize(self) -> None:
        """
        Initialize the layers that need asynchronous setup.

        Only the long-term layer is initialized here; the other layers are
        ready as soon as they are constructed. Calling this again after a
        successful run is a no-op. This should be called during application
        startup.

        Raises:
            Exception: Whatever ``long_term.initialize()`` raises. The error
                is logged together with its traceback and re-raised
                unchanged; the manager stays uninitialized.
        """
        if self._initialized:
            return

        try:
            # Initialize long-term memory (ChromaDB)
            await self.long_term.initialize()
        except Exception as e:
            # logger.exception logs at ERROR level and attaches the traceback,
            # which a plain logger.error call would drop.
            logger.exception("Failed to initialize memory manager: %s", e)
            raise

        self._initialized = True
        logger.info("Memory manager initialized successfully")

    async def shutdown(self) -> None:
        """
        Shut down the long-term layer and empty working memory.

        The short-term and shared layers are not touched here. The steps run
        in order, so if shutting down long-term memory fails, working memory
        is not cleared and the manager still reports itself as initialized.
        This should be called during application shutdown.

        Raises:
            Exception: Whatever ``long_term.shutdown()`` or
                ``working.clear()`` raises. The error is logged together with
                its traceback and re-raised unchanged.
        """
        try:
            # Persist long-term memory
            await self.long_term.shutdown()

            # Clear working memory
            await self.working.clear()
        except Exception as e:
            # logger.exception logs at ERROR level and attaches the traceback,
            # which a plain logger.error call would drop.
            logger.exception("Error during memory manager shutdown: %s", e)
            raise

        self._initialized = False
        logger.info("Memory manager shutdown complete")

    @property
    def is_initialized(self) -> bool:
        """Report the flag set by ``initialize()`` and reset by ``shutdown()``."""
        initialized = self._initialized
        return initialized

    # =========================================================================
    # Unified Store Interface
    # =========================================================================

    async def store(
        self,
        key: str,
        value: Any,
        memory_type: MemoryType = MemoryType.SHORT_TERM,
        **kwargs: Any,
    ) -> Any:
        """
        Write ``value`` under ``key`` into one memory layer.

        Recognised keyword arguments depend on the layer; anything else in
        ``kwargs`` is ignored:

        * SHORT_TERM: ``tags``
        * WORKING: ``priority`` (defaults to 0.0), ``metadata``
        * LONG_TERM: ``metadata``, ``embedding``; a non-string ``value`` is
          converted with ``str()`` before it is stored
        * SHARED: none

        Args:
            key: Key or identifier for the stored value.
            value: Value to store.
            memory_type: Layer that receives the value.
            **kwargs: Layer-specific options listed above.

        Returns:
            Whatever the layer's own write call returns; for SHARED, the
            ``value`` that was passed in.

        Raises:
            ValueError: If ``memory_type`` matches none of the layers.
        """
        if memory_type == MemoryType.SHORT_TERM:
            return await self.short_term.set(key, value, tags=kwargs.get("tags"))

        if memory_type == MemoryType.WORKING:
            return await self.working.push(
                content=value,
                item_id=key,
                priority=kwargs.get("priority", 0.0),
                metadata=kwargs.get("metadata"),
            )

        if memory_type == MemoryType.LONG_TERM:
            # The long-term layer is always handed text content.
            text = value if isinstance(value, str) else str(value)
            return await self.long_term.store(
                content=text,
                document_id=key,
                metadata=kwargs.get("metadata"),
                embedding=kwargs.get("embedding"),
            )

        if memory_type == MemoryType.SHARED:
            await self.shared.set_state(key, value)
            return value

        raise ValueError(f"Invalid memory type: {memory_type}")

    # =========================================================================
    # Unified Retrieve Interface
    # =========================================================================

    async def retrieve(
        self,
        key: str,
        memory_type: MemoryType = MemoryType.SHORT_TERM,
        default: Any = None,
    ) -> Any:
        """
        Look up ``key`` in one memory layer.

        For WORKING and LONG_TERM the layer hands back an item/document
        object and only its ``content`` attribute is returned; a falsy lookup
        result yields ``default``. SHORT_TERM and SHARED forward ``default``
        to the layer itself.

        Args:
            key: Key or identifier to look up.
            memory_type: Layer to query.
            default: Value returned when nothing is found.

        Returns:
            The stored value, or ``default``.

        Raises:
            ValueError: If ``memory_type`` matches none of the layers.
        """
        if memory_type == MemoryType.SHORT_TERM:
            return await self.short_term.get(key, default=default)

        if memory_type == MemoryType.WORKING:
            found_item = await self.working.peek_by_id(key)
            if not found_item:
                return default
            return found_item.content

        if memory_type == MemoryType.LONG_TERM:
            found_doc = await self.long_term.get(key)
            if not found_doc:
                return default
            return found_doc.content

        if memory_type == MemoryType.SHARED:
            return await self.shared.get_state(key, default=default)

        raise ValueError(f"Invalid memory type: {memory_type}")

    # =========================================================================
    # Unified Search Interface
    # =========================================================================

    async def search(
        self,
        query: str,
        memory_type: MemoryType = MemoryType.LONG_TERM,
        top_k: int = 10,
        **kwargs: Any,
    ) -> list[Any]:
        """
        Search for values in the specified memory layer.

        Args:
            query: Search query.
            memory_type: Which memory layer to search.
            top_k: Maximum number of results.
            **kwargs: Additional arguments for specific layers.

        Returns:
            List of matching entries/documents.

        Raises:
            ValueError: If memory_type is invalid or doesn't support search.
        """
        match memory_type:
            case MemoryType.SHORT_TERM:
                # Search by tag or return all keys containing query
                tag = kwargs.get("tag")
                if tag:
                    return list((await self.short_term.get_by_tag(tag)).items())[:top_k]
                keys = await self.short_term.list_keys()
                matching = [k for k in keys if query.lower() in k.lower()]
                results = []
                for key in matching[:top_k]:
                    value = await self.short_term.get(key)
                    results.append((key, value))
                return results

            case MemoryType.WORKING:
                # Search working memory items
                def matches(item: WorkingMemoryItem) -> bool:
                    content_str = str(item.content).lower()
                    return query.lower() in content_str

                items = await self.working.search(matches)
                return items[:top_k]

            case MemoryType.LONG_TERM:
                where = kwargs.get("where")
                query_embedding = kwargs.get("query_embedding")
                return await self.long_term.search(
                    query=query,
                    top_k=top_k,
                    where=where,
                    query_embedding=query_embedding,
                )

            case MemoryType.SHARED:
                # Search state keys
                all_state = await self.shared.get_all_state()
                matching = [
                    (k, v)
                    for k, v in all_state.items()
                    if query.lower() in k.lower()
                    or query.lower() in str(v).lower()
                ]
                return matching[:top_k]

            case _:
                raise ValueError(f"Invalid memory type: {memory_type}")

    # =========================================================================
    # Unified Clear Interface
    # =========================================================================

    async def clear(
        self,
        memory_type: MemoryType | None = None,
    ) -> dict[str, int]:
        """
        Empty one memory layer, or every layer when no type is given.

        Layers are cleared in the fixed order short-term, working, long-term,
        shared.

        Args:
            memory_type: Layer to clear; ``None`` clears all of them.

        Returns:
            Counts keyed by ``"short_term"``, ``"working"``, ``"long_term"``,
            ``"shared_channels"`` and ``"shared_state"``; only the keys of
            layers that were actually cleared are present.
        """

        def selected(layer: MemoryType) -> bool:
            # No explicit type means every layer is selected.
            return memory_type is None or memory_type == layer

        cleared: dict[str, int] = {}

        if selected(MemoryType.SHORT_TERM):
            cleared["short_term"] = await self.short_term.clear()

        if selected(MemoryType.WORKING):
            cleared["working"] = await self.working.clear()

        if selected(MemoryType.LONG_TERM):
            cleared["long_term"] = await self.long_term.clear()

        if selected(MemoryType.SHARED):
            # The shared layer reports two separate counts.
            shared_counts = await self.shared.clear()
            cleared["shared_channels"] = shared_counts["channels"]
            cleared["shared_state"] = shared_counts["state_keys"]

        return cleared

    # =========================================================================
    # Episode Management
    # =========================================================================

    async def start_episode(self, episode_id: str) -> None:
        """
        Begin a new episode.

        Passes the episode id to short-term memory via ``set_episode`` and
        then empties working memory. Whether ``set_episode`` itself discards
        earlier short-term entries is decided by that layer.

        Args:
            episode_id: Unique identifier for the episode.
        """
        await self.short_term.set_episode(episode_id)

        # Working memory always starts the episode empty.
        await self.working.clear()

        message = f"Started episode: {episode_id}"
        logger.debug(message)

    async def end_episode(self) -> dict[str, int]:
        """
        Finish the current episode by emptying short-term and working memory.

        Long-term and shared memory are left untouched.

        Returns:
            The values returned by each layer's ``clear()``, keyed by
            ``"short_term"`` and ``"working"``.
        """
        short_term_cleared = await self.short_term.clear()
        working_cleared = await self.working.clear()

        results = {
            "short_term": short_term_cleared,
            "working": working_cleared,
        }
        logger.debug(f"Ended episode: {results}")
        return results

    # =========================================================================
    # Statistics
    # =========================================================================

    async def get_stats(self) -> MemoryStats:
        """
        Collect the statistics of every memory layer.

        The layers are queried one after another in the order short-term,
        working, long-term, shared.

        Returns:
            A MemoryStats holding each layer's ``get_stats()`` result.
        """
        short_term_stats = await self.short_term.get_stats()
        working_stats = await self.working.get_stats()
        long_term_stats = await self.long_term.get_stats()
        shared_stats = await self.shared.get_stats()

        return MemoryStats(
            short_term=short_term_stats,
            working=working_stats,
            long_term=long_term_stats,
            shared=shared_stats,
        )

    # =========================================================================
    # Convenience Methods
    # =========================================================================

    async def remember(
        self,
        content: str,
        metadata: dict[str, Any] | None = None,
    ) -> Document:
        """
        Save text to long-term memory.

        Shortcut for ``long_term.store`` without an explicit document id or
        embedding.

        Args:
            content: Text content to remember.
            metadata: Optional metadata stored alongside the content.

        Returns:
            The document returned by the long-term layer.
        """
        document = await self.long_term.store(content=content, metadata=metadata)
        return document

    async def recall(
        self,
        query: str,
        top_k: int = 5,
    ) -> list[SearchResult]:
        """
        Query long-term memory.

        Shortcut for ``long_term.search`` with no ``where`` filter and no
        precomputed query embedding.

        Args:
            query: Search query.
            top_k: Number of results to request.

        Returns:
            The search results returned by the long-term layer.
        """
        hits = await self.long_term.search(query=query, top_k=top_k)
        return hits

    async def think(
        self,
        thought: str,
        priority: float = 0.0,
    ) -> WorkingMemoryItem:
        """
        Push a thought onto working memory.

        Shortcut for ``working.push`` without an explicit item id or
        metadata.

        Args:
            thought: The thought content.
            priority: Priority score forwarded to the working layer.

        Returns:
            The item returned by the working layer.
        """
        pushed = await self.working.push(content=thought, priority=priority)
        return pushed

    async def broadcast(
        self,
        channel: str,
        message: Any,
        sender: str | None = None,
    ) -> Message:
        """
        Publish a payload on a shared-memory channel.

        Shortcut for ``shared.publish``; ``message`` is forwarded as the
        ``payload`` argument.

        Args:
            channel: Channel name.
            message: Message payload.
            sender: Optional sender identifier.

        Returns:
            The message object returned by the shared layer.
        """
        published = await self.shared.publish(
            channel=channel,
            payload=message,
            sender=sender,
        )
        return published