Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from datetime import datetime | |
| from typing import Any, Dict, List | |
| from chatkit.store import NotFoundError, Store | |
| from chatkit.types import Attachment, Page, Thread, ThreadItem, ThreadMetadata | |
@dataclass
class _ThreadState:
    """Per-thread record held by ``MemoryStore``: the thread's metadata and its items.

    ``MemoryStore`` builds this with keyword arguments
    (``_ThreadState(thread=..., items=[])``), so it must be a dataclass; a plain
    class with bare annotations has no such constructor and raises ``TypeError``.
    """

    thread: ThreadMetadata  # metadata as stored by the owning MemoryStore
    items: List[ThreadItem]  # live item list; MemoryStore appends/replaces in place
class MemoryStore(Store[dict[str, Any]]):
    """Simple in-memory store compatible with the ChatKit server interface.

    Threads and their items live in a plain dict on the instance, so nothing
    survives once the instance is discarded. Models are deep-copied on the way
    in and on the way out, so callers never hold references into stored state.
    Attachments are not supported: the attachment methods always raise
    ``NotImplementedError``.
    """

    def __init__(self) -> None:
        # thread id -> stored metadata plus that thread's items
        self._threads: Dict[str, _ThreadState] = {}
        # Attachments intentionally unsupported; use a real store that enforces auth.

    @staticmethod
    def _coerce_thread_metadata(thread: ThreadMetadata | Thread) -> ThreadMetadata:
        """Return thread metadata without any embedded items (openai-chatkit>=1.0).

        Always returns a fresh deep copy. This must be a staticmethod: it is
        invoked as ``self._coerce_thread_metadata(thread)``, and without the
        decorator ``self`` would be bound to ``thread`` and the call would fail
        with ``TypeError``.
        """
        # A full Thread (or metadata whose ``items`` field was explicitly set)
        # carries item data that must not be persisted alongside the metadata.
        has_items = isinstance(thread, Thread) or "items" in getattr(
            thread, "model_fields_set", set()
        )
        if not has_items:
            return thread.model_copy(deep=True)
        data = thread.model_dump()
        data.pop("items", None)
        return ThreadMetadata(**data).model_copy(deep=True)

    @staticmethod
    def _paginate(records: List[Any], after: str | None, limit: int) -> Page[Any]:
        """Cut one page of at most ``limit`` records out of an already-ordered list.

        Args:
            records: Records in the requested order; each must expose ``.id``.
            after: Cursor id. The page starts right after the record with this
                id. A falsy or unknown cursor starts from the beginning.
            limit: Maximum number of records in the page.

        Returns:
            A ``Page`` whose ``after`` is the id of the last returned record
            when more records follow, else ``None``.
        """
        if after:
            # Built as a dict so that, with duplicate ids, the last occurrence wins.
            index_map = {record.id: idx for idx, record in enumerate(records)}
            start = index_map.get(after, -1) + 1
        else:
            start = 0
        # Fetch one extra record to learn whether another page exists.
        window = records[start : start + limit + 1]
        has_more = len(window) > limit
        window = window[:limit]
        next_after = window[-1].id if has_more and window else None
        return Page(data=window, has_more=has_more, after=next_after)

    # -- Thread metadata -------------------------------------------------

    async def load_thread(self, thread_id: str, context: dict[str, Any]) -> ThreadMetadata:
        """Return a copy of the stored metadata for ``thread_id``.

        Raises:
            NotFoundError: if no thread with that id exists.
        """
        state = self._threads.get(thread_id)
        if state is None:
            raise NotFoundError(f"Thread {thread_id} not found")
        return self._coerce_thread_metadata(state.thread)

    async def save_thread(self, thread: ThreadMetadata, context: dict[str, Any]) -> None:
        """Insert or replace the metadata for ``thread.id``; existing items are kept."""
        metadata = self._coerce_thread_metadata(thread)
        state = self._threads.get(thread.id)
        if state is not None:
            state.thread = metadata
        else:
            self._threads[thread.id] = _ThreadState(thread=metadata, items=[])

    async def load_threads(
        self,
        limit: int,
        after: str | None,
        order: str,
        context: dict[str, Any],
    ) -> Page[ThreadMetadata]:
        """Return one page of thread metadata sorted by ``created_at``.

        ``order == "desc"`` sorts newest first; any other value sorts ascending.
        Threads with a falsy ``created_at`` sort as ``datetime.min``.
        """
        threads = sorted(
            (self._coerce_thread_metadata(state.thread) for state in self._threads.values()),
            key=lambda t: t.created_at or datetime.min,
            reverse=(order == "desc"),
        )
        return self._paginate(threads, after, limit)

    async def delete_thread(self, thread_id: str, context: dict[str, Any]) -> None:
        """Remove the thread and all of its items; unknown ids are ignored."""
        self._threads.pop(thread_id, None)

    # -- Thread items ----------------------------------------------------

    def _items(self, thread_id: str) -> List[ThreadItem]:
        """Return the live item list for ``thread_id``, creating an empty thread if needed.

        Only the write paths (``add_thread_item``, ``save_item``) use this, so
        adding an item to an unknown thread still works. Read paths use
        ``_existing_items`` so that a lookup never creates a thread.
        """
        state = self._threads.get(thread_id)
        if state is None:
            # NOTE(review): utcnow() yields a naive datetime; confirm this matches
            # the created_at values ChatKit itself supplies before mixing them.
            state = _ThreadState(
                thread=ThreadMetadata(id=thread_id, created_at=datetime.utcnow()),
                items=[],
            )
            self._threads[thread_id] = state
        return state.items

    def _existing_items(self, thread_id: str) -> List[ThreadItem]:
        """Return the live item list for ``thread_id``, or a fresh empty list if unknown.

        Never creates a thread, so reading an unknown id does not make a
        phantom thread appear in ``load_threads``.
        """
        state = self._threads.get(thread_id)
        return state.items if state is not None else []

    async def load_thread_items(
        self,
        thread_id: str,
        after: str | None,
        limit: int,
        order: str,
        context: dict[str, Any],
    ) -> Page[ThreadItem]:
        """Return one page of copies of the thread's items, sorted by ``created_at``.

        ``order == "desc"`` sorts newest first. An unknown thread yields an
        empty page and is not created.
        """
        items = [item.model_copy(deep=True) for item in self._existing_items(thread_id)]
        # Evaluate the fallback once: items without created_at then share one key
        # and the stable sort keeps them in insertion order.
        fallback = datetime.utcnow()
        items.sort(
            key=lambda item: getattr(item, "created_at", fallback),
            reverse=(order == "desc"),
        )
        return self._paginate(items, after, limit)

    async def add_thread_item(
        self, thread_id: str, item: ThreadItem, context: dict[str, Any]
    ) -> None:
        """Append a copy of ``item`` to the thread, creating the thread if unknown."""
        self._items(thread_id).append(item.model_copy(deep=True))

    async def save_item(self, thread_id: str, item: ThreadItem, context: dict[str, Any]) -> None:
        """Replace the item with the same id in place, or append it if none matches."""
        items = self._items(thread_id)
        for idx, existing in enumerate(items):
            if existing.id == item.id:
                items[idx] = item.model_copy(deep=True)
                return
        items.append(item.model_copy(deep=True))

    async def load_item(self, thread_id: str, item_id: str, context: dict[str, Any]) -> ThreadItem:
        """Return a copy of the item ``item_id`` in the thread.

        Raises:
            NotFoundError: if the thread or the item does not exist.
        """
        for item in self._existing_items(thread_id):
            if item.id == item_id:
                return item.model_copy(deep=True)
        raise NotFoundError(f"Item {item_id} not found")

    async def delete_thread_item(
        self, thread_id: str, item_id: str, context: dict[str, Any]
    ) -> None:
        """Remove every item with id ``item_id``; unknown threads or items are ignored."""
        state = self._threads.get(thread_id)
        if state is None:
            return
        state.items = [item for item in state.items if item.id != item_id]

    # -- Files -----------------------------------------------------------
    # These methods are not currently used but required to be compatible with the Store interface.

    async def save_attachment(
        self,
        attachment: Attachment,
        context: dict[str, Any],
    ) -> None:
        """Always raise: this store deliberately does not persist attachments."""
        raise NotImplementedError(
            "MemoryStore does not persist attachments. Provide a Store implementation "
            "that enforces authentication and authorization before enabling uploads."
        )

    async def load_attachment(
        self,
        attachment_id: str,
        context: dict[str, Any],
    ) -> Attachment:
        """Always raise: no attachments are ever stored here."""
        raise NotImplementedError(
            "MemoryStore does not load attachments. Provide a Store implementation "
            "that enforces authentication and authorization before enabling uploads."
        )

    async def delete_attachment(self, attachment_id: str, context: dict[str, Any]) -> None:
        """Always raise: no attachments are ever stored here."""
        raise NotImplementedError(
            "MemoryStore does not delete attachments because they are never stored."
        )