"""In-memory ChatKit ``Store`` implementation intended for demos and tests."""

from __future__ import annotations

import logging
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List

from chatkit.store import NotFoundError, Store
from chatkit.types import Attachment, Page, Thread, ThreadItem, ThreadMetadata

logger = logging.getLogger(__name__)


@dataclass
class _ThreadState:
    """Everything the store keeps for a single thread."""

    # Thread metadata, stored without any embedded items.
    thread: ThreadMetadata
    # Items belonging to the thread, in insertion order.
    items: List[ThreadItem]


class MemoryStore(Store[dict[str, Any]]):
    """Simple in-memory store compatible with the ChatKit server interface."""

    def __init__(self) -> None:
        self._threads: Dict[str, _ThreadState] = {}
        # Attachments intentionally unsupported; use a real store that enforces auth.

    @staticmethod
    def _coerce_thread_metadata(thread: ThreadMetadata | Thread) -> ThreadMetadata:
        """Return thread metadata without any embedded items (openai-chatkit>=1.0).

        The result is always a deep copy, so callers can never mutate the
        object held inside the store (and vice versa).
        """
        has_items = isinstance(thread, Thread) or "items" in getattr(
            thread, "model_fields_set", set()
        )
        if not has_items:
            return thread.model_copy(deep=True)

        data = thread.model_dump()
        data.pop("items", None)
        return ThreadMetadata(**data).model_copy(deep=True)

    @staticmethod
    def _paginate(records: list, after: str | None, limit: int) -> Page:
        """Build a ``Page`` from already-sorted ``records``.

        Args:
            records: Sorted objects exposing an ``id`` attribute.
            after: Id of the last record of the previous page, or ``None`` to
                start from the beginning. An unknown id also starts from the
                beginning.
            limit: Maximum number of records in the returned page.

        Returns:
            A ``Page`` whose ``after`` cursor is set only when more records
            remain beyond this page.
        """
        if after:
            index_map = {record.id: idx for idx, record in enumerate(records)}
            start = index_map.get(after, -1) + 1
        else:
            start = 0

        # Fetch one extra record to learn whether another page exists.
        window = records[start : start + limit + 1]
        has_more = len(window) > limit
        window = window[:limit]
        next_after = window[-1].id if has_more and window else None
        return Page(data=window, has_more=has_more, after=next_after)

    # -- Thread metadata -------------------------------------------------
    async def load_thread(self, thread_id: str, context: dict[str, Any]) -> ThreadMetadata:
        """Return a copy of the stored metadata for ``thread_id``.

        Raises:
            NotFoundError: If no thread with that id is stored.
        """
        state = self._threads.get(thread_id)
        if state is None:
            raise NotFoundError(f"Thread {thread_id} not found")
        return self._coerce_thread_metadata(state.thread)

    async def save_thread(self, thread: ThreadMetadata, context: dict[str, Any]) -> None:
        """Insert or replace the metadata of ``thread``, keeping existing items."""
        metadata = self._coerce_thread_metadata(thread)
        state = self._threads.get(thread.id)
        if state is None:
            self._threads[thread.id] = _ThreadState(thread=metadata, items=[])
        else:
            state.thread = metadata

    async def load_threads(
        self,
        limit: int,
        after: str | None,
        order: str,
        context: dict[str, Any],
    ) -> Page[ThreadMetadata]:
        """Return one page of thread metadata sorted by ``created_at``.

        ``order == "desc"`` sorts newest first; any other value sorts
        ascending. Threads without a ``created_at`` sort as ``datetime.min``.
        """
        threads = sorted(
            (self._coerce_thread_metadata(state.thread) for state in self._threads.values()),
            key=lambda t: t.created_at or datetime.min,
            reverse=(order == "desc"),
        )
        return self._paginate(threads, after, limit)

    async def delete_thread(self, thread_id: str, context: dict[str, Any]) -> None:
        """Remove the thread and its items; unknown ids are ignored."""
        self._threads.pop(thread_id, None)

    # -- Thread items ----------------------------------------------------
    def _items(self, thread_id: str) -> List[ThreadItem]:
        """Return the mutable item list for ``thread_id``, creating the thread if absent.

        Only write paths should call this; read paths use ``_peek_items`` so
        that a lookup never registers a new thread as a side effect.
        """
        state = self._threads.get(thread_id)
        if state is None:
            # Naive timestamp on purpose: load_threads falls back to the naive
            # ``datetime.min``, so the placeholder stays comparable with it.
            state = _ThreadState(
                thread=ThreadMetadata(id=thread_id, created_at=datetime.utcnow()),
                items=[],
            )
            self._threads[thread_id] = state
        return state.items

    def _peek_items(self, thread_id: str) -> List[ThreadItem]:
        """Return the items of ``thread_id`` without creating the thread.

        An unknown thread yields a fresh empty list.
        """
        state = self._threads.get(thread_id)
        return [] if state is None else state.items

    async def load_thread_items(
        self,
        thread_id: str,
        after: str | None,
        limit: int,
        order: str,
        context: dict[str, Any],
    ) -> Page[ThreadItem]:
        """Return one page of copies of the thread's items sorted by ``created_at``.

        An unknown ``thread_id`` yields an empty page and leaves the store
        untouched.
        """
        items = [item.model_copy(deep=True) for item in self._peek_items(thread_id)]
        # Single fallback timestamp for items lacking ``created_at``; the sort
        # is stable, so such items keep their insertion order among themselves.
        fallback = datetime.utcnow()
        items.sort(
            key=lambda item: getattr(item, "created_at", fallback),
            reverse=(order == "desc"),
        )
        return self._paginate(items, after, limit)

    async def add_thread_item(
        self, thread_id: str, item: ThreadItem, context: dict[str, Any]
    ) -> None:
        """Append a copy of ``item`` to the thread, creating the thread if needed."""
        # Log identifiers only; item content may hold user data.
        logger.debug("Adding item %s to thread %s", item.id, thread_id)
        self._items(thread_id).append(item.model_copy(deep=True))

    async def save_item(self, thread_id: str, item: ThreadItem, context: dict[str, Any]) -> None:
        """Replace the stored item with the same id, or append it when not present."""
        items = self._items(thread_id)
        for idx, existing in enumerate(items):
            if existing.id == item.id:
                items[idx] = item.model_copy(deep=True)
                return
        items.append(item.model_copy(deep=True))

    async def load_item(self, thread_id: str, item_id: str, context: dict[str, Any]) -> ThreadItem:
        """Return a copy of the item ``item_id`` from thread ``thread_id``.

        Raises:
            NotFoundError: If the thread has no item with that id (including
                when the thread itself does not exist).
        """
        for item in self._peek_items(thread_id):
            if item.id == item_id:
                return item.model_copy(deep=True)
        raise NotFoundError(f"Item {item_id} not found")

    async def delete_thread_item(
        self, thread_id: str, item_id: str, context: dict[str, Any]
    ) -> None:
        """Remove the item ``item_id`` from the thread; unknown ids are ignored."""
        state = self._threads.get(thread_id)
        if state is None:
            return
        state.items = [item for item in state.items if item.id != item_id]

    # -- Files -----------------------------------------------------------
    # These methods are not currently used but required to be compatible with the Store interface.
    async def save_attachment(
        self,
        attachment: Attachment,
        context: dict[str, Any],
    ) -> None:
        """Always raise ``NotImplementedError``; attachments are unsupported."""
        raise NotImplementedError(
            "MemoryStore does not persist attachments. Provide a Store implementation "
            "that enforces authentication and authorization before enabling uploads."
        )

    async def load_attachment(
        self,
        attachment_id: str,
        context: dict[str, Any],
    ) -> Attachment:
        """Always raise ``NotImplementedError``; attachments are unsupported."""
        raise NotImplementedError(
            "MemoryStore does not load attachments. Provide a Store implementation "
            "that enforces authentication and authorization before enabling uploads."
        )

    async def delete_attachment(self, attachment_id: str, context: dict[str, Any]) -> None:
        """Always raise ``NotImplementedError``; attachments are never stored."""
        raise NotImplementedError(
            "MemoryStore does not delete attachments because they are never stored."
        )