# ClassLens / chatkit/backend/app/memory_store.py
# Commit b0fbfe1 by lukasgross: "Simplify ChatKit backend (#104)"
"""
Simple in-memory store compatible with the ChatKit Store interface.
A production app would implement this using a persistent database.
"""
from __future__ import annotations

from collections import defaultdict
from collections.abc import Callable, Iterable
from typing import Any

from chatkit.store import NotFoundError, Store
from chatkit.types import Attachment, Page, ThreadItem, ThreadMetadata
class MemoryStore(Store[dict]):
    """In-memory implementation of the ChatKit ``Store`` interface.

    Threads are kept in a dict keyed by thread id, and each thread's items in
    a list keyed by the same id. All state lives on this instance; nothing is
    persisted. Attachments are not supported.
    """

    def __init__(self) -> None:
        # thread id -> thread metadata
        self.threads: dict[str, ThreadMetadata] = {}
        # thread id -> items in insertion order. A defaultdict lets writers
        # append to a thread's list without initialising it first.
        self.items: dict[str, list[ThreadItem]] = defaultdict(list)

    async def load_thread(self, thread_id: str, context: dict) -> ThreadMetadata:
        """Return the stored metadata for ``thread_id``.

        Raises:
            NotFoundError: if no thread with that id has been saved.
        """
        try:
            return self.threads[thread_id]
        except KeyError:
            # Suppress the KeyError context; callers only care about NotFoundError.
            raise NotFoundError(f"Thread {thread_id} not found") from None

    async def save_thread(self, thread: ThreadMetadata, context: dict) -> None:
        """Insert ``thread``, or overwrite the existing entry with the same ``id``."""
        self.threads[thread.id] = thread

    async def load_threads(
        self, limit: int, after: str | None, order: str, context: dict
    ) -> Page[ThreadMetadata]:
        """Return one page of threads ordered by ``created_at``.

        ``limit``, ``after`` and ``order`` are interpreted as in ``_paginate``.
        """
        # sorted() inside _paginate already copies, so the dict view is passed as-is.
        return self._paginate(
            self.threads.values(),
            after,
            limit,
            order,
            sort_key=lambda t: t.created_at,
            cursor_key=lambda t: t.id,
        )

    async def load_thread_items(
        self, thread_id: str, after: str | None, limit: int, order: str, context: dict
    ) -> Page[ThreadItem]:
        """Return one page of the items in ``thread_id`` ordered by ``created_at``.

        An unknown thread yields an empty page rather than an error.
        """
        # .get() rather than [] so that reading an unknown thread does not
        # insert an empty entry into the defaultdict.
        return self._paginate(
            self.items.get(thread_id, []),
            after,
            limit,
            order,
            sort_key=lambda i: i.created_at,
            cursor_key=lambda i: i.id,
        )

    async def add_thread_item(
        self, thread_id: str, item: ThreadItem, context: dict
    ) -> None:
        """Append ``item`` to ``thread_id``'s item list.

        The thread is not required to exist in ``self.threads``.
        """
        self.items[thread_id].append(item)

    async def save_item(self, thread_id: str, item: ThreadItem, context: dict) -> None:
        """Replace the item in ``thread_id`` that has ``item.id``, or append ``item``."""
        items = self.items[thread_id]
        for idx, existing in enumerate(items):
            if existing.id == item.id:
                items[idx] = item
                return
        items.append(item)

    async def load_item(
        self, thread_id: str, item_id: str, context: dict
    ) -> ThreadItem:
        """Return the item ``item_id`` stored in ``thread_id``.

        Raises:
            NotFoundError: if the thread has no item with that id.
        """
        for item in self.items.get(thread_id, []):
            if item.id == item_id:
                return item
        raise NotFoundError(f"Item {item_id} not found in thread {thread_id}")

    async def delete_thread(self, thread_id: str, context: dict) -> None:
        """Remove ``thread_id`` and all of its items; unknown ids are ignored."""
        self.threads.pop(thread_id, None)
        self.items.pop(thread_id, None)

    async def delete_thread_item(
        self, thread_id: str, item_id: str, context: dict
    ) -> None:
        """Remove item ``item_id`` from ``thread_id``; missing ids are ignored."""
        items = self.items.get(thread_id)
        if items is None:
            # Nothing to delete. Do not create an empty entry for a thread
            # that has no items (e.g. one already removed by delete_thread).
            return
        self.items[thread_id] = [item for item in items if item.id != item_id]

    def _paginate(
        self,
        rows: Iterable[Any],
        after: str | None,
        limit: int,
        order: str,
        sort_key: Callable[[Any], Any],
        cursor_key: Callable[[Any], str],
    ) -> Page[Any]:
        """Return a cursor-paginated ``Page`` over ``rows``.

        Args:
            rows: rows to page through. They are copied by ``sorted`` and
                never mutated.
            after: cursor of the last row of the previous page, or ``None``
                (or empty) to start from the first row.
            limit: maximum number of rows in the returned page.
            order: ``"desc"`` sorts by ``sort_key`` descending; any other
                value sorts ascending.
            sort_key: extracts the sort value from a row.
            cursor_key: extracts the cursor string (the row id) from a row.

        Returns:
            A ``Page``. Its ``after`` is the cursor of its last row when more
            rows follow, and ``None`` otherwise.
        """
        ordered = sorted(rows, key=sort_key, reverse=order == "desc")
        start = 0
        if after:
            # NOTE(review): a cursor that matches no row silently restarts at
            # the first row; confirm that is the intended stale-cursor behaviour.
            for idx, row in enumerate(ordered):
                if cursor_key(row) == after:
                    start = idx + 1
                    break
        data = ordered[start : start + limit]
        has_more = start + limit < len(ordered)
        # `and data` guards data[-1] when the page is empty (e.g. limit == 0).
        next_after = cursor_key(data[-1]) if has_more and data else None
        return Page(data=data, has_more=has_more, after=next_after)

    # Attachments are not implemented in the quickstart store.
    async def save_attachment(self, attachment: Attachment, context: dict) -> None:
        raise NotImplementedError()

    async def load_attachment(self, attachment_id: str, context: dict) -> Attachment:
        raise NotImplementedError()

    async def delete_attachment(self, attachment_id: str, context: dict) -> None:
        raise NotImplementedError()