Spaces:
Running
Running
| """PostgreSQL Store implementation for ChatKit SDK. | |
| [Task]: T009 | |
| [From]: specs/010-chatkit-migration/data-model.md - ChatKit SDK Interface Requirements | |
| [From]: specs/010-chatkit-migration/contracts/backend.md - Store Interface Implementation | |
| This module implements the ChatKit Store interface using SQLModel and PostgreSQL. | |
| The Store interface is required by ChatKit's Python SDK for thread and message persistence. | |
| ChatKit Store Protocol Methods: | |
| - list_threads: List threads for a user with pagination | |
| - get_thread: Get a specific thread by ID | |
| - create_thread: Create a new thread | |
| - update_thread: Update thread metadata | |
| - delete_thread: Delete a thread | |
| - list_messages: List messages in a thread | |
| - get_message: Get a specific message | |
| - create_message: Create a new message | |
| - update_message: Update a message | |
| - delete_message: Delete a message | |
| """ | |
import uuid
from datetime import datetime, timezone
from typing import Any, Optional

from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import Session, select, col

from models.thread import Thread
from models.message import Message, MessageRole
class PostgresChatKitStore:
    """PostgreSQL implementation of ChatKit Store interface.

    [From]: specs/010-chatkit-migration/data-model.md - Store Interface

    This store provides thread and message persistence for ChatKit using
    the existing SQLModel models and PostgreSQL database.

    Note: The ChatKit SDK uses a Protocol-based interface. The actual
    protocol types (ThreadMetadata, MessageItem, etc.) would be imported
    from the openai_chatkit package. For this implementation, we use
    dictionary-based representations until the SDK is installed.

    All IDs (user, thread, message) are passed as UUID strings; a string
    that ``uuid.UUID`` cannot parse raises ``ValueError``.

    Every write method commits the session itself. If a commit fails, the
    session is rolled back before the exception is re-raised, so the same
    session stays usable for subsequent calls.

    Usage:
        from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
        from services.chatkit_store import PostgresChatKitStore

        engine = create_async_engine(database_url)
        async with AsyncSession(engine) as session:
            store = PostgresChatKitStore(session)
            thread = await store.create_thread(
                user_id="9b2e1c4a-7d3f-4e8a-b1c6-2f5d8e9a0b7c",
                title="My Conversation",
                metadata={"tag": "important"}
            )
    """

    def __init__(self, session: AsyncSession):
        """Initialize the store with a database session.

        Args:
            session: SQLAlchemy async session for database operations
        """
        self.session = session

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive datetime.

        Produces the same kind of value the deprecated ``datetime.utcnow()``
        did (naive, UTC wall clock), so timestamps written here stay
        comparable with those already stored.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    @staticmethod
    def _thread_to_dict(thread: Thread) -> dict:
        """Serialize a Thread row into the ChatKit thread-metadata dict."""
        return {
            "id": str(thread.id),
            "user_id": str(thread.user_id),
            "title": thread.title,
            "metadata": thread.thread_metadata or {},
            "created_at": thread.created_at.isoformat(),
            "updated_at": thread.updated_at.isoformat(),
        }

    @staticmethod
    def _message_to_dict(message: Message) -> dict:
        """Serialize a Message row into the ChatKit message-item dict."""
        return {
            "id": str(message.id),
            "type": "message",
            "role": message.role.value,
            "content": [{"type": "text", "text": message.content}],
            "tool_calls": message.tool_calls,
            "created_at": message.created_at.isoformat(),
        }

    @staticmethod
    def _first_text_block(content_array: list) -> Optional[dict]:
        """Return the first ``{"type": "text", ...}`` block, or None."""
        return next(
            (block for block in content_array if block.get("type") == "text"),
            None,
        )

    async def _get_thread_row(self, thread_id: str) -> Optional[Thread]:
        """Load a Thread by its UUID string, or None if it does not exist."""
        stmt = select(Thread).where(Thread.id == uuid.UUID(thread_id))
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def _get_message_row(self, message_id: str) -> Optional[Message]:
        """Load a Message by its UUID string, or None if it does not exist."""
        stmt = select(Message).where(Message.id == uuid.UUID(message_id))
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def _commit(self) -> None:
        """Commit the session, rolling back if the commit fails.

        A failed flush/commit leaves the session in an invalid state until
        it is rolled back; rolling back here keeps the session usable. The
        original exception is re-raised unchanged.
        """
        try:
            await self.session.commit()
        except Exception:
            await self.session.rollback()
            raise

    # ------------------------------------------------------------------
    # Threads
    # ------------------------------------------------------------------

    async def list_threads(
        self,
        user_id: str,
        limit: int = 50,
        offset: int = 0
    ) -> list[dict]:
        """List threads for a user with pagination.

        [From]: specs/010-chatkit-migration/data-model.md - Retrieve User's Conversations

        Args:
            user_id: User ID (UUID string) to filter threads
            limit: Maximum number of threads to return (default: 50)
            offset: Number of threads to skip (default: 0)

        Returns:
            List of thread metadata dictionaries, most recently updated first

        Raises:
            ValueError: If ``user_id`` is not a valid UUID string
        """
        stmt = (
            select(Thread)
            .where(Thread.user_id == uuid.UUID(user_id))
            .order_by(Thread.updated_at.desc())
            .limit(limit)
            .offset(offset)
        )
        result = await self.session.execute(stmt)
        return [self._thread_to_dict(thread) for thread in result.scalars().all()]

    async def get_thread(self, thread_id: str) -> Optional[dict]:
        """Get a specific thread by ID.

        Args:
            thread_id: Thread UUID as string

        Returns:
            Thread metadata dictionary or None if not found

        Raises:
            ValueError: If ``thread_id`` is not a valid UUID string
        """
        thread = await self._get_thread_row(thread_id)
        if thread is None:
            return None
        return self._thread_to_dict(thread)

    async def create_thread(
        self,
        user_id: str,
        title: Optional[str] = None,
        metadata: Optional[dict] = None
    ) -> dict:
        """Create a new thread.

        Args:
            user_id: User ID (UUID string) who owns the thread
            title: Optional thread title
            metadata: Optional thread metadata (stored as ``{}`` when omitted)

        Returns:
            Created thread metadata dictionary

        Raises:
            ValueError: If ``user_id`` is not a valid UUID string
        """
        thread = Thread(
            user_id=uuid.UUID(user_id),
            title=title,
            thread_metadata=metadata or {},
        )
        self.session.add(thread)
        await self._commit()
        # Reload server-generated columns (id, timestamps) after the commit.
        await self.session.refresh(thread)
        return self._thread_to_dict(thread)

    async def update_thread(
        self,
        thread_id: str,
        title: Optional[str] = None,
        metadata: Optional[dict] = None
    ) -> Optional[dict]:
        """Update a thread.

        Only arguments that are not None are applied; ``updated_at`` is
        bumped on every successful call.

        Args:
            thread_id: Thread UUID as string
            title: New title (optional)
            metadata: New metadata, replacing the old value wholesale (optional)

        Returns:
            Updated thread metadata dictionary or None if not found

        Raises:
            ValueError: If ``thread_id`` is not a valid UUID string
        """
        thread = await self._get_thread_row(thread_id)
        if thread is None:
            return None
        if title is not None:
            thread.title = title
        if metadata is not None:
            thread.thread_metadata = metadata
        thread.updated_at = self._utcnow()
        await self._commit()
        await self.session.refresh(thread)
        return self._thread_to_dict(thread)

    async def delete_thread(self, thread_id: str) -> bool:
        """Delete a thread.

        Args:
            thread_id: Thread UUID as string

        Returns:
            True if deleted, False if not found

        Raises:
            ValueError: If ``thread_id`` is not a valid UUID string
        """
        thread = await self._get_thread_row(thread_id)
        if thread is None:
            return False
        await self.session.delete(thread)
        await self._commit()
        return True

    # ------------------------------------------------------------------
    # Messages
    # ------------------------------------------------------------------

    async def list_messages(
        self,
        thread_id: str,
        limit: int = 50,
        offset: int = 0
    ) -> list[dict]:
        """List messages in a thread.

        [From]: specs/010-chatkit-migration/data-model.md - Retrieve Conversation Messages

        Args:
            thread_id: Thread UUID as string
            limit: Maximum number of messages to return
            offset: Number of messages to skip

        Returns:
            List of message item dictionaries, oldest first

        Raises:
            ValueError: If ``thread_id`` is not a valid UUID string
        """
        stmt = (
            select(Message)
            .where(Message.thread_id == uuid.UUID(thread_id))
            .order_by(Message.created_at.asc())
            .limit(limit)
            .offset(offset)
        )
        result = await self.session.execute(stmt)
        return [self._message_to_dict(msg) for msg in result.scalars().all()]

    async def get_message(self, message_id: str) -> Optional[dict]:
        """Get a specific message by ID.

        Args:
            message_id: Message UUID as string

        Returns:
            Message item dictionary or None if not found

        Raises:
            ValueError: If ``message_id`` is not a valid UUID string
        """
        message = await self._get_message_row(message_id)
        if message is None:
            return None
        return self._message_to_dict(message)

    async def create_message(
        self,
        thread_id: str,
        item: dict
    ) -> dict:
        """Create a new message in a thread.

        The message insert and the bump of the parent thread's ``updated_at``
        are committed together in a single transaction.

        Args:
            thread_id: Thread UUID as string
            item: Message item from ChatKit (UserMessageItem or ClientToolCallOutputItem).
                ChatKit uses: {"type": "message", "role": "user",
                "content": [{"type": "text", "text": "..."}]}. Only the first
                text block is stored; for "client_tool_call_output" items the
                "output" field is stored instead.

        Returns:
            Created message item dictionary

        Raises:
            ValueError: If item format is invalid (unknown role) or
                ``thread_id`` is not a valid UUID string
        """
        item_type = item.get("type", "message")
        role = item.get("role", "user")

        if item_type == "client_tool_call_output":
            text_content = item.get("output", "")
        else:
            # `or []` also tolerates an explicit "content": None.
            text_block = self._first_text_block(item.get("content") or [])
            text_content = text_block.get("text", "") if text_block is not None else ""

        # Build the row first so an invalid role fails before any DB work.
        message = Message(
            thread_id=uuid.UUID(thread_id),
            role=MessageRole(role),
            content=text_content,
            tool_calls=item.get("tool_calls"),
        )

        thread = await self._get_thread_row(thread_id)
        self.session.add(message)
        if thread is not None:
            thread.updated_at = self._utcnow()

        # One commit for both writes. A second commit after refreshing would
        # expire `message` again (expire_on_commit defaults to True), and
        # reading expired attributes on an AsyncSession triggers implicit IO,
        # which fails outside the greenlet context.
        await self._commit()
        await self.session.refresh(message)
        return self._message_to_dict(message)

    async def update_message(
        self,
        message_id: str,
        item: dict
    ) -> Optional[dict]:
        """Update a message.

        Only the first text block of ``item["content"]`` is applied (a block
        without "text" keeps the current content); ``tool_calls`` is replaced
        only when the key is present in ``item``.

        Args:
            message_id: Message UUID as string
            item: Updated message item

        Returns:
            Updated message item dictionary or None if not found

        Raises:
            ValueError: If ``message_id`` is not a valid UUID string
        """
        message = await self._get_message_row(message_id)
        if message is None:
            return None

        text_block = self._first_text_block(item.get("content") or [])
        if text_block is not None:
            message.content = text_block.get("text", message.content)

        if "tool_calls" in item:
            message.tool_calls = item["tool_calls"]

        await self._commit()
        await self.session.refresh(message)
        return self._message_to_dict(message)

    async def delete_message(self, message_id: str) -> bool:
        """Delete a message.

        Args:
            message_id: Message UUID as string

        Returns:
            True if deleted, False if not found

        Raises:
            ValueError: If ``message_id`` is not a valid UUID string
        """
        message = await self._get_message_row(message_id)
        if message is None:
            return False
        await self.session.delete(message)
        await self._commit()
        return True