Spaces:
Runtime error
Runtime error
| """ | |
| Simple in-memory store compatible with the ChatKit Store interface. | |
| A production app would implement this using a persistent database. | |
| """ | |
| from __future__ import annotations | |
| from collections import defaultdict | |
| from chatkit.store import NotFoundError, Store | |
| from chatkit.types import Attachment, Page, ThreadItem, ThreadMetadata | |
class MemoryStore(Store[dict]):
    """In-memory implementation of the ChatKit ``Store`` interface.

    Thread metadata is kept in a dict keyed by thread id, and each thread's
    items are kept in a per-thread list in insertion order. Nothing is
    persisted: all data lives only as long as this object. Attachments are
    not supported (the attachment methods raise ``NotImplementedError``).
    """

    def __init__(self):
        # thread id -> thread metadata
        self.threads: dict[str, ThreadMetadata] = {}
        # thread id -> that thread's items, in insertion order
        self.items: dict[str, list[ThreadItem]] = defaultdict(list)

    async def load_thread(self, thread_id: str, context: dict) -> ThreadMetadata:
        """Return the stored metadata for ``thread_id``.

        Raises:
            NotFoundError: if no thread with that id has been saved.
        """
        try:
            return self.threads[thread_id]
        except KeyError:
            raise NotFoundError(f"Thread {thread_id} not found") from None

    async def save_thread(self, thread: ThreadMetadata, context: dict) -> None:
        """Insert or replace the metadata stored under ``thread.id``."""
        self.threads[thread.id] = thread

    async def load_threads(
        self, limit: int, after: str | None, order: str, context: dict
    ) -> Page[ThreadMetadata]:
        """Return one page of threads ordered by ``created_at``.

        ``after`` is the id of the last thread of the previous page (or None
        for the first page); ``order`` is ``"desc"`` for newest first.
        """
        return self._paginate(
            list(self.threads.values()),
            after,
            limit,
            order,
            sort_key=lambda t: t.created_at,
            cursor_key=lambda t: t.id,
        )

    async def load_thread_items(
        self, thread_id: str, after: str | None, limit: int, order: str, context: dict
    ) -> Page[ThreadItem]:
        """Return one page of ``thread_id``'s items ordered by ``created_at``.

        An unknown thread yields an empty page rather than an error.
        """
        # .get() (not []) so reading never creates an entry in the defaultdict.
        return self._paginate(
            self.items.get(thread_id, []),
            after,
            limit,
            order,
            sort_key=lambda i: i.created_at,
            cursor_key=lambda i: i.id,
        )

    async def add_thread_item(
        self, thread_id: str, item: ThreadItem, context: dict
    ) -> None:
        """Append ``item`` to ``thread_id``'s items (no duplicate-id check)."""
        self.items[thread_id].append(item)

    async def save_item(self, thread_id: str, item: ThreadItem, context: dict) -> None:
        """Replace the first item in ``thread_id`` with the same id, else append."""
        items = self.items[thread_id]
        for idx, existing in enumerate(items):
            if existing.id == item.id:
                items[idx] = item
                return
        items.append(item)

    async def load_item(
        self, thread_id: str, item_id: str, context: dict
    ) -> ThreadItem:
        """Return the first item in ``thread_id`` whose id is ``item_id``.

        Raises:
            NotFoundError: if the thread has no such item (or does not exist).
        """
        for item in self.items.get(thread_id, []):
            if item.id == item_id:
                return item
        raise NotFoundError(f"Item {item_id} not found in thread {thread_id}")

    async def delete_thread(self, thread_id: str, context: dict) -> None:
        """Remove a thread and all of its items; unknown ids are ignored."""
        self.threads.pop(thread_id, None)
        self.items.pop(thread_id, None)

    async def delete_thread_item(
        self, thread_id: str, item_id: str, context: dict
    ) -> None:
        """Remove every item with id ``item_id`` from ``thread_id``.

        Unknown threads are ignored, and no empty entry is created for them.
        """
        items = self.items.get(thread_id)
        if items is None:
            return
        self.items[thread_id] = [item for item in items if item.id != item_id]

    def _paginate(
        self,
        rows: list,
        after: str | None,
        limit: int,
        order: str,
        sort_key,
        cursor_key,
    ):
        """Sort ``rows`` and return the page that follows cursor ``after``.

        Args:
            rows: Records to paginate; the list itself is not modified.
            after: Cursor value (as produced by ``cursor_key``) of the last
                record of the previous page, or None/empty for the first page.
            limit: Maximum number of records to return; negative values are
                treated as 0.
            order: ``"desc"`` sorts by descending ``sort_key``; any other value
                sorts ascending.
            sort_key: Key function used to order the records.
            cursor_key: Function returning a record's cursor value.

        Returns:
            A ``Page`` holding the selected records, whether more records
            follow, and (when they do) the cursor of the last returned record.
        """
        # Sort ascending and then reverse, so "desc" is the exact mirror of
        # "asc" even for records with equal sort keys. sorted(reverse=True)
        # keeps ties in insertion order, which would make a newest-first page
        # list same-timestamp records oldest-inserted first.
        sorted_rows = sorted(rows, key=sort_key)
        if order == "desc":
            sorted_rows.reverse()

        # A negative limit would otherwise turn into a from-the-end slice.
        limit = max(limit, 0)

        start = 0
        if after:
            for idx, row in enumerate(sorted_rows):
                if cursor_key(row) == after:
                    start = idx + 1
                    break
            # NOTE: a cursor that matches no record (e.g. it was deleted)
            # falls back to the first page.

        end = start + limit
        data = sorted_rows[start:end]
        has_more = end < len(sorted_rows)
        next_after = cursor_key(data[-1]) if has_more and data else None
        return Page(data=data, has_more=has_more, after=next_after)

    # Attachments are not implemented in the quickstart store
    async def save_attachment(self, attachment: Attachment, context: dict) -> None:
        """Not supported by this store; always raises NotImplementedError."""
        raise NotImplementedError()

    async def load_attachment(self, attachment_id: str, context: dict) -> Attachment:
        """Not supported by this store; always raises NotImplementedError."""
        raise NotImplementedError()

    async def delete_attachment(self, attachment_id: str, context: dict) -> None:
        """Not supported by this store; always raises NotImplementedError."""
        raise NotImplementedError()