Spaces:
Sleeping
Sleeping
from __future__ import annotations

import json
import logging
import sqlite3
import threading
import uuid
from datetime import datetime
from typing import Any, Dict, List

from chatkit.store import NotFoundError, Store
from chatkit.types import Attachment, Page, Thread, ThreadItem, ThreadMetadata
class SQLiteStore(Store[dict[str, Any]]):
    """Persistent SQLite-backed store compatible with the ChatKit server interface.

    Threads live in the ``threads`` table as JSON-serialized ``ThreadMetadata``.
    Items live in ``thread_items`` as JSON, together with a per-thread
    ``sequence_num`` that defines their order. Attachments are not supported;
    every attachment method raises ``NotImplementedError``.

    A single ``sqlite3`` connection (opened with ``check_same_thread=False``)
    is shared by all callers. Every database access is serialized through the
    re-entrant ``self._lock``. The ``async`` methods run their SQLite calls
    synchronously on the calling thread.
    """

    _logger = logging.getLogger(__name__)

    # Placeholder id that upstream code assigns to items lacking a real id.
    _FAKE_ID = "__fake_id__"

    def __init__(self, db_path: str = "chatkit_threads.db") -> None:
        """Open (or create) the database at ``db_path`` and ensure the schema exists."""
        self.db_path = db_path
        # One connection shared across threads; all access goes through _lock.
        self.conn = sqlite3.connect(db_path, check_same_thread=False)
        self.conn.row_factory = sqlite3.Row
        # Performance tuning. page_size only applies to a database that has not
        # been written yet and cannot change once WAL is active, so it goes first.
        self.conn.execute("PRAGMA page_size=4096")
        self.conn.execute("PRAGMA journal_mode=WAL")  # readers don't block the writer
        self.conn.execute("PRAGMA synchronous=NORMAL")  # fewer fsyncs than FULL
        self.conn.execute("PRAGMA cache_size=10000")  # 10000 pages (~40 MB at 4 KiB/page)
        self.conn.execute("PRAGMA temp_store=MEMORY")  # temp tables/indices in RAM
        self.conn.execute("PRAGMA mmap_size=268435456")  # 256 MB memory-mapped I/O
        self._lock = threading.RLock()
        self._init_db()

    def _init_db(self) -> None:
        """Create tables and indexes if missing, then refresh query-planner statistics."""
        with self._lock:
            cursor = self.conn.cursor()
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS threads (
                    thread_id TEXT PRIMARY KEY,
                    metadata TEXT NOT NULL,
                    created_at TEXT NOT NULL
                )
                """
            )
            # sequence_num gives a stable per-thread order independent of clocks.
            # NOTE: PRAGMA foreign_keys is never enabled on this connection, so the
            # ON DELETE CASCADE below is not enforced; delete_thread removes items itself.
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS thread_items (
                    item_id TEXT PRIMARY KEY,
                    thread_id TEXT NOT NULL,
                    item_data TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    sequence_num INTEGER NOT NULL,
                    FOREIGN KEY (thread_id) REFERENCES threads(thread_id) ON DELETE CASCADE
                )
                """
            )
            cursor.execute(
                """
                CREATE INDEX IF NOT EXISTS idx_thread_items_thread_id
                ON thread_items(thread_id)
                """
            )
            # Serves "WHERE thread_id = ? ORDER BY sequence_num, created_at".
            cursor.execute(
                """
                CREATE INDEX IF NOT EXISTS idx_thread_items_thread_sequence
                ON thread_items(thread_id, sequence_num, created_at)
                """
            )
            cursor.execute(
                """
                CREATE INDEX IF NOT EXISTS idx_thread_items_created_at
                ON thread_items(created_at)
                """
            )
            # No separate index on item_id: the PRIMARY KEY already provides one.
            cursor.execute("ANALYZE")
            self.conn.commit()

    @staticmethod
    def _coerce_thread_metadata(thread: ThreadMetadata | Thread) -> ThreadMetadata:
        """Return a deep copy of the thread metadata without any embedded items.

        A full ``Thread``, or a model whose ``items`` field was explicitly set,
        is dumped and rebuilt as ``ThreadMetadata`` with ``items`` removed.
        Anything else is simply deep-copied.
        """
        has_items = isinstance(thread, Thread) or "items" in getattr(
            thread, "model_fields_set", set()
        )
        if not has_items:
            return thread.model_copy(deep=True)
        data = thread.model_dump()
        data.pop("items", None)
        return ThreadMetadata(**data).model_copy(deep=True)

    # -- Item (de)serialization ------------------------------------------

    @staticmethod
    def _item_classes() -> dict[str, type]:
        """Map each serialized item ``type`` to the chatkit class that rebuilds it."""
        # Imported lazily, as before; callers treat ImportError as a load failure.
        from chatkit.types import (
            AssistantMessageItem,
            ClientToolCallItem,
            HiddenContextItem,
            TaskItem,
            UserMessageItem,
            WidgetItem,
            WorkflowItem,
        )

        return {
            "user_message": UserMessageItem,
            "assistant_message": AssistantMessageItem,
            "client_tool_call": ClientToolCallItem,
            "workflow": WorkflowItem,
            "widget": WidgetItem,
            "task": TaskItem,
            "hidden_context_item": HiddenContextItem,
        }

    @classmethod
    def _deserialize_item(cls, item_dict: dict[str, Any]) -> ThreadItem | None:
        """Rebuild a ThreadItem from its decoded JSON dict.

        Returns None when ``item_dict["type"]`` is not a known item type. Any
        error raised while constructing the model propagates to the caller.
        """
        item_cls = cls._item_classes().get(item_dict.get("type"))
        if item_cls is None:
            return None
        return item_cls(**item_dict)

    def _rows_to_items(self, rows: list[sqlite3.Row]) -> list[ThreadItem]:
        """Deserialize ``(item_id, item_data)`` rows in order, skipping unusable rows.

        Best-effort by design: a row with an unknown type or a payload that
        fails to rebuild is logged and skipped, so one bad row does not hide
        the rest of the thread.
        """
        items: list[ThreadItem] = []
        for row in rows:
            try:
                item_dict = json.loads(row["item_data"])
                item = self._deserialize_item(item_dict)
            except Exception as exc:  # deliberate catch-all, see docstring
                self._logger.warning(
                    "Failed to reconstruct item %s: %s", row["item_id"], exc
                )
                continue
            if item is None:
                self._logger.warning(
                    "Skipping item %s with unknown item type: %s",
                    row["item_id"],
                    item_dict.get("type"),
                )
                continue
            items.append(item)
        return items

    # -- Thread metadata -------------------------------------------------

    async def load_thread(self, thread_id: str, context: dict[str, Any]) -> ThreadMetadata:
        """Return the stored metadata for ``thread_id``.

        Raises:
            NotFoundError: if no thread with that id exists.
        """
        with self._lock:
            cursor = self.conn.cursor()
            cursor.execute(
                "SELECT metadata FROM threads WHERE thread_id = ?",
                (thread_id,),
            )
            row = cursor.fetchone()
        if row is None:
            raise NotFoundError(f"Thread {thread_id} not found")
        return ThreadMetadata(**json.loads(row["metadata"]))

    async def save_thread(self, thread: ThreadMetadata, context: dict[str, Any]) -> None:
        """Insert or update a thread's metadata (embedded items are stripped).

        On first insert, ``created_at`` is taken from the metadata, falling back
        to the current UTC time. Later saves replace only the metadata JSON and
        keep the original ``created_at``.
        """
        with self._lock:
            metadata = self._coerce_thread_metadata(thread)
            metadata_json = metadata.model_dump_json()
            created_at = (metadata.created_at or datetime.utcnow()).isoformat()
            cursor = self.conn.cursor()
            cursor.execute(
                """
                INSERT INTO threads (thread_id, metadata, created_at)
                VALUES (?, ?, ?)
                ON CONFLICT(thread_id) DO UPDATE SET
                    metadata = excluded.metadata
                """,
                (thread.id, metadata_json, created_at),
            )
            self.conn.commit()

    async def load_threads(
        self,
        limit: int,
        after: str | None,
        order: str,
        context: dict[str, Any],
    ) -> Page[ThreadMetadata]:
        """Return one page of thread metadata ordered by creation time.

        ``order == "desc"`` yields newest first; any other value yields oldest
        first. ``after`` is the id of the last thread on the previous page. If
        that thread does not exist, paging restarts from the beginning. Ties on
        ``created_at`` are broken by ``thread_id``, so threads with equal
        timestamps are neither skipped nor repeated across pages.
        """
        descending = order == "desc"
        # Only these fixed keywords are interpolated into SQL; values are bound.
        direction = "DESC" if descending else "ASC"
        comparison = "<" if descending else ">"
        order_by = f"ORDER BY created_at {direction}, thread_id {direction}"

        with self._lock:
            cursor = self.conn.cursor()
            anchor = None
            if after:
                cursor.execute(
                    "SELECT created_at, thread_id FROM threads WHERE thread_id = ?",
                    (after,),
                )
                anchor = cursor.fetchone()
            if anchor is None:
                cursor.execute(
                    f"SELECT metadata FROM threads {order_by} LIMIT ?",
                    (limit + 1,),
                )
            else:
                cursor.execute(
                    f"""
                    SELECT metadata FROM threads
                    WHERE created_at {comparison} ?
                       OR (created_at = ? AND thread_id {comparison} ?)
                    {order_by}
                    LIMIT ?
                    """,
                    (
                        anchor["created_at"],
                        anchor["created_at"],
                        anchor["thread_id"],
                        limit + 1,
                    ),
                )
            rows = cursor.fetchall()

        # One extra row is fetched only to detect whether another page exists.
        has_more = len(rows) > limit
        threads = [ThreadMetadata(**json.loads(row["metadata"])) for row in rows[:limit]]
        next_after = threads[-1].id if has_more and threads else None
        return Page(data=threads, has_more=has_more, after=next_after)

    async def delete_thread(self, thread_id: str, context: dict[str, Any]) -> None:
        """Delete a thread and all of its items in a single commit."""
        with self._lock:
            cursor = self.conn.cursor()
            # Items are removed explicitly because foreign keys are not enforced.
            cursor.execute("DELETE FROM thread_items WHERE thread_id = ?", (thread_id,))
            cursor.execute("DELETE FROM threads WHERE thread_id = ?", (thread_id,))
            self.conn.commit()

    # -- Thread items ----------------------------------------------------

    def _get_next_sequence_num(self, thread_id: str) -> int:
        """Return the next sequence number for ``thread_id`` (1 for an empty thread).

        The caller must hold ``self._lock`` so the read-then-insert is not raced.
        """
        cursor = self.conn.cursor()
        cursor.execute(
            "SELECT MAX(sequence_num) AS max_seq FROM thread_items WHERE thread_id = ?",
            (thread_id,),
        )
        row = cursor.fetchone()
        max_seq = row["max_seq"] if row and row["max_seq"] is not None else 0
        return max_seq + 1

    async def load_thread_items(
        self,
        thread_id: str,
        after: str | None,
        limit: int,
        order: str,
        context: dict[str, Any],
    ) -> Page[ThreadItem]:
        """Return one page of a thread's items ordered by ``sequence_num``.

        ``after`` is the id of the last item on the previous page. It is looked
        up only within ``thread_id``; if absent, paging restarts from the
        beginning. ``has_more`` and the next cursor are computed from the rows
        fetched, so items skipped during deserialization do not break paging.
        """
        descending = order == "desc"
        # Only fixed keywords are interpolated into SQL; values are bound.
        direction = "DESC" if descending else "ASC"
        order_by = f"ORDER BY sequence_num {direction}, created_at {direction}"

        with self._lock:
            cursor = self.conn.cursor()
            after_seq = None
            if after:
                cursor.execute(
                    "SELECT sequence_num FROM thread_items WHERE thread_id = ? AND item_id = ?",
                    (thread_id, after),
                )
                after_row = cursor.fetchone()
                if after_row is not None:
                    after_seq = after_row["sequence_num"]
            if after_seq is None:
                cursor.execute(
                    f"""
                    SELECT item_id, item_data FROM thread_items
                    WHERE thread_id = ?
                    {order_by}
                    LIMIT ?
                    """,
                    (thread_id, limit + 1),
                )
            else:
                comparison = "<" if descending else ">"
                cursor.execute(
                    f"""
                    SELECT item_id, item_data FROM thread_items
                    WHERE thread_id = ? AND sequence_num {comparison} ?
                    {order_by}
                    LIMIT ?
                    """,
                    (thread_id, after_seq, limit + 1),
                )
            rows = cursor.fetchall()

        # Page on raw rows, not parsed items, so a skipped row cannot hide a next page.
        has_more = len(rows) > limit
        page_rows = rows[:limit]
        items = self._rows_to_items(page_rows)
        next_after = page_rows[-1]["item_id"] if has_more and page_rows else None
        return Page(data=items, has_more=has_more, after=next_after)

    @classmethod
    def _resolve_fake_id(cls, item: ThreadItem) -> ThreadItem:
        """Give a ``__fake_id__`` item a unique id if it carries non-blank output text.

        Items whose id is real, and fake-id placeholders with no output text,
        are returned unchanged.
        """
        if item.id != cls._FAKE_ID:
            return item
        parts = getattr(item, "content", None) or []
        has_content = any(
            getattr(part, "type", "") == "output_text"
            and (getattr(part, "text", None) or "").strip()
            for part in parts
        )
        if not has_content:
            cls._logger.debug("Keeping %s for empty/tool placeholder", cls._FAKE_ID)
            return item
        new_id = f"gen_{uuid.uuid4().hex[:12]}"
        cls._logger.debug("Resolved fake_id (with content) to unique ID: %s", new_id)
        return item.model_copy(update={"id": new_id})

    @staticmethod
    def _created_at_iso(item: ThreadItem) -> str:
        """Return the item's ``created_at`` as an ISO string, defaulting to UTC now."""
        value = getattr(item, "created_at", None)
        if not value:
            return datetime.utcnow().isoformat()
        return value if isinstance(value, str) else value.isoformat()

    async def add_thread_item(
        self, thread_id: str, item: ThreadItem, context: dict[str, Any]
    ) -> None:
        """Insert ``item`` into ``thread_id``, or update it if its id already exists.

        New items get the next per-thread ``sequence_num``. Updates keep the
        original ``sequence_num`` and replace ``item_data`` and ``created_at``.
        """
        with self._lock:
            # ChatKit/ChatCompletionsModel can emit "__fake_id__", which collides
            # on the item_id primary key; items with real content get a unique id.
            item = self._resolve_fake_id(item)
            item_json = item.model_dump_json()
            created_at = self._created_at_iso(item)

            cursor = self.conn.cursor()
            cursor.execute("SELECT 1 FROM thread_items WHERE item_id = ?", (item.id,))
            if cursor.fetchone() is not None:
                # NOTE(review): the lookup is by item_id alone, so a leftover
                # "__fake_id__" row from another thread is overwritten here while
                # keeping its original thread_id — confirm this is acceptable.
                cursor.execute(
                    """
                    UPDATE thread_items
                    SET item_data = ?, created_at = ?
                    WHERE item_id = ?
                    """,
                    (item_json, created_at, item.id),
                )
                self._logger.debug("Updated existing item %s", item.id)
            else:
                sequence_num = self._get_next_sequence_num(thread_id)
                cursor.execute(
                    """
                    INSERT INTO thread_items (item_id, thread_id, item_data, created_at, sequence_num)
                    VALUES (?, ?, ?, ?, ?)
                    """,
                    (item.id, thread_id, item_json, created_at, sequence_num),
                )
                self._logger.debug(
                    "Inserted new item %s with sequence %s", item.id, sequence_num
                )
            self.conn.commit()

    async def save_item(self, thread_id: str, item: ThreadItem, context: dict[str, Any]) -> None:
        """Persist ``item``; delegates to ``add_thread_item``, which upserts."""
        await self.add_thread_item(thread_id, item, context)

    async def load_item(self, thread_id: str, item_id: str, context: dict[str, Any]) -> ThreadItem:
        """Return a single item from ``thread_id``.

        Raises:
            NotFoundError: if the item does not exist in that thread, has an
                unknown type, or cannot be reconstructed.
        """
        with self._lock:
            cursor = self.conn.cursor()
            cursor.execute(
                "SELECT item_data FROM thread_items WHERE thread_id = ? AND item_id = ?",
                (thread_id, item_id),
            )
            row = cursor.fetchone()
        if row is None:
            raise NotFoundError(f"Item {item_id} not found")

        item_dict = json.loads(row["item_data"])
        item_type = item_dict.get("type")
        try:
            item = self._deserialize_item(item_dict)
        except (ImportError, TypeError, KeyError, ValueError) as exc:
            raise NotFoundError(f"Failed to load item {item_id}: {exc}") from exc
        if item is None:
            raise NotFoundError(f"Item {item_id} has unknown type: {item_type}")
        return item

    async def delete_thread_item(
        self, thread_id: str, item_id: str, context: dict[str, Any]
    ) -> None:
        """Delete one item from ``thread_id``; deleting a missing item is a no-op."""
        with self._lock:
            cursor = self.conn.cursor()
            cursor.execute(
                "DELETE FROM thread_items WHERE thread_id = ? AND item_id = ?",
                (thread_id, item_id),
            )
            self.conn.commit()

    # -- Files -----------------------------------------------------------

    async def save_attachment(
        self,
        attachment: Attachment,
        context: dict[str, Any],
    ) -> None:
        """Unsupported: always raises NotImplementedError."""
        raise NotImplementedError(
            "SQLiteStore does not persist attachments. Provide a Store implementation "
            "that enforces authentication and authorization before enabling uploads."
        )

    async def load_attachment(
        self,
        attachment_id: str,
        context: dict[str, Any],
    ) -> Attachment:
        """Unsupported: always raises NotImplementedError."""
        raise NotImplementedError(
            "SQLiteStore does not load attachments. Provide a Store implementation "
            "that enforces authentication and authorization before enabling uploads."
        )

    async def delete_attachment(self, attachment_id: str, context: dict[str, Any]) -> None:
        """Unsupported: always raises NotImplementedError."""
        raise NotImplementedError(
            "SQLiteStore does not delete attachments because they are never stored."
        )

    # Helper method to get items (used by zendesk integration)
    def _items(self, thread_id: str) -> List[ThreadItem]:
        """Synchronously return every item of ``thread_id`` in ascending sequence order.

        Rows that cannot be reconstructed are skipped (see ``_rows_to_items``).
        """
        with self._lock:
            cursor = self.conn.cursor()
            cursor.execute(
                """
                SELECT item_id, item_data FROM thread_items
                WHERE thread_id = ?
                ORDER BY sequence_num ASC, created_at ASC
                """,
                (thread_id,),
            )
            rows = cursor.fetchall()
        return self._rows_to_items(rows)

    def close(self) -> None:
        """Close the shared database connection."""
        with self._lock:
            self.conn.close()