Spaces:
Build error
Build error
| import logging | |
| import time | |
| import uuid | |
| from typing import Optional | |
| from sqlalchemy import select, delete | |
| from sqlalchemy.ext.asyncio import AsyncSession | |
| from open_webui.internal.db import Base, JSONField, get_async_db_context | |
| from pydantic import BaseModel, ConfigDict | |
| from sqlalchemy import BigInteger, Column, ForeignKey, Text, JSON | |
log = logging.getLogger(__name__)

####################
# SharedChat DB Schema
####################


class SharedChat(Base):
    """ORM row holding one shared-link snapshot of a chat.

    Rows are written by ``SharedChatsTable.create`` and re-written by
    ``SharedChatsTable.update``; both copy the original chat's ``title`` and
    ``chat`` at that moment, so later edits to the original chat are not
    visible here until the share is re-snapshotted.
    """

    __tablename__ = 'shared_chat'

    id = Column(Text, primary_key=True)  # The share token (UUID) — used in /s/{id} URL
    # Original chat this share was taken from. ON DELETE CASCADE drops the share
    # together with its chat — NOTE(review): only where the database actually
    # enforces foreign keys (e.g. SQLite needs foreign keys enabled); confirm.
    chat_id = Column(Text, ForeignKey('chat.id', ondelete='CASCADE'), nullable=False)
    user_id = Column(Text, nullable=False)  # Who created this share
    title = Column(Text)  # Copy of the original chat's title at snapshot time (nullable column)
    chat = Column(JSON)  # Snapshot of chat JSON at share time
    # Unix epoch seconds; SharedChatsTable sets these from int(time.time()).
    created_at = Column(BigInteger)
    updated_at = Column(BigInteger)  # Bumped on every re-snapshot
class SharedChatModel(BaseModel):
    """Pydantic view of a full ``SharedChat`` row, including the chat snapshot.

    Built with ``SharedChatModel.model_validate(<SharedChat instance>)``.
    """

    # from_attributes lets model_validate read fields straight off ORM objects.
    model_config = ConfigDict(from_attributes=True)

    id: str  # share token (SharedChat.id)
    chat_id: str  # id of the original chat
    user_id: str  # user who created the share
    # NOTE(review): title, chat and the timestamps are nullable columns on
    # SharedChat but required here, so a NULL value would fail validation —
    # confirm the writers always populate them.
    title: str
    chat: dict  # snapshot of the original chat's JSON
    created_at: int  # Unix epoch seconds
    updated_at: int  # Unix epoch seconds
class SharedChatResponse(BaseModel):
    """Lightweight list item describing one share, without the chat snapshot.

    As filled by ``SharedChatsTable.get_by_user_id``, ``id`` and ``chat_id``
    both hold the ORIGINAL chat's id, while ``share_id`` holds the share token
    (``SharedChat.id``). NOTE(review): ``id`` presumably keeps the original
    chat id for backward compatibility with existing clients; confirm before
    changing it.
    """

    id: str  # original chat id (NOT the share token)
    chat_id: str  # original chat id
    title: str
    share_id: Optional[str] = None  # share token, i.e. SharedChat.id
    updated_at: int  # Unix epoch seconds
    created_at: int  # Unix epoch seconds
####################
# Table Operations
####################


class SharedChatsTable:
    """Async data-access helpers for ``shared_chat`` snapshot rows.

    Every method takes an optional ``db`` session and runs inside
    ``get_async_db_context(db)``; write methods commit before returning.
    """

    # Escape character used to make user search text match literally inside
    # LIKE patterns (the same character SQLAlchemy's ``autoescape`` uses).
    _LIKE_ESCAPE = '/'

    @classmethod
    def _escape_like(cls, value: str) -> str:
        """Escape the LIKE wildcards ``%``/``_`` (and the escape char) in *value*."""
        esc = cls._LIKE_ESCAPE
        # The escape character itself must be doubled first, before it is
        # introduced in front of the wildcards.
        return value.replace(esc, esc + esc).replace('%', esc + '%').replace('_', esc + '_')

    async def create(
        self, chat_id: str, user_id: str, db: Optional[AsyncSession] = None
    ) -> Optional[SharedChatModel]:
        """
        Create a snapshot of the chat for link sharing.

        Copies the original chat's current ``title`` and ``chat`` into a new
        ``shared_chat`` row keyed by a freshly generated UUID4 share token.

        :param chat_id: id of the original chat to snapshot.
        :param user_id: id of the user creating the share.
        :param db: optional session forwarded to ``get_async_db_context``.
        :return: the persisted SharedChatModel (its ``id`` is the share token),
            or None when no chat with ``chat_id`` exists.
        """
        async with get_async_db_context(db) as db:
            # Local import, presumably to avoid a circular import with the chats module.
            from open_webui.models.chats import Chat

            chat = await db.get(Chat, chat_id)
            if chat is None:
                return None

            share_id = str(uuid.uuid4())
            now = int(time.time())
            shared_chat = SharedChat(
                id=share_id,
                chat_id=chat_id,
                user_id=user_id,
                title=chat.title,
                chat=chat.chat,
                created_at=now,
                updated_at=now,
            )
            db.add(shared_chat)
            await db.commit()
            await db.refresh(shared_chat)
            return SharedChatModel.model_validate(shared_chat)

    async def update(self, share_id: str, db: Optional[AsyncSession] = None) -> Optional[SharedChatModel]:
        """
        Re-snapshot: update the shared chat with the current state of the original chat.

        :param share_id: share token of the existing shared chat.
        :param db: optional session forwarded to ``get_async_db_context``.
        :return: the refreshed SharedChatModel, or None when either the share
            or its original chat no longer exists.
        """
        async with get_async_db_context(db) as db:
            # Local import, presumably to avoid a circular import with the chats module.
            from open_webui.models.chats import Chat

            shared_chat = await db.get(SharedChat, share_id)
            if shared_chat is None:
                return None

            chat = await db.get(Chat, shared_chat.chat_id)
            if chat is None:
                return None

            shared_chat.title = chat.title
            shared_chat.chat = chat.chat
            shared_chat.updated_at = int(time.time())
            await db.commit()
            await db.refresh(shared_chat)
            return SharedChatModel.model_validate(shared_chat)

    async def get_by_id(self, share_id: str, db: Optional[AsyncSession] = None) -> Optional[SharedChatModel]:
        """Return the shared chat for a share token, or None if it does not exist."""
        async with get_async_db_context(db) as db:
            shared_chat = await db.get(SharedChat, share_id)
            if shared_chat is None:
                return None
            return SharedChatModel.model_validate(shared_chat)

    async def get_by_chat_id(self, chat_id: str, db: Optional[AsyncSession] = None) -> Optional[SharedChatModel]:
        """Return the most recently updated share of an original chat, or None if it has none."""
        async with get_async_db_context(db) as db:
            result = await db.execute(
                select(SharedChat).filter_by(chat_id=chat_id).order_by(SharedChat.updated_at.desc()).limit(1)
            )
            shared_chat = result.scalars().first()
            if shared_chat is None:
                return None
            return SharedChatModel.model_validate(shared_chat)

    async def get_by_user_id(
        self,
        user_id: str,
        filter: Optional[dict] = None,
        skip: int = 0,
        limit: int = 50,
        db: Optional[AsyncSession] = None,
    ) -> list[SharedChatResponse]:
        """
        List the shared chats created by a user, one page at a time.

        :param user_id: creator whose shares are listed.
        :param filter: optional dict with keys
            ``query`` (case-insensitive literal substring of the title),
            ``order_by`` (a ``shared_chat`` column name) and
            ``direction`` (``'asc'`` / ``'desc'``). Ordering is applied only
            when both ``order_by`` and ``direction`` are given; otherwise rows
            come back newest-updated first.
        :param skip: rows to skip (falsy = no offset).
        :param limit: maximum rows to return (falsy = no limit).
        :param db: optional session forwarded to ``get_async_db_context``.
        :return: SharedChatResponse items whose ``id``/``chat_id`` are the
            original chat id and whose ``share_id`` is the share token.
        :raises ValueError: for an unknown ``order_by`` column or a direction
            other than asc/desc.
        """
        async with get_async_db_context(db) as db:
            stmt = select(SharedChat).filter_by(user_id=user_id)
            opts = filter or {}

            query_key = opts.get('query')
            if query_key:
                # Escape LIKE wildcards so the user's text is matched literally.
                pattern = f'%{self._escape_like(str(query_key))}%'
                stmt = stmt.filter(SharedChat.title.ilike(pattern, escape=self._LIKE_ESCAPE))

            order_by = opts.get('order_by')
            direction = opts.get('direction')
            if order_by and direction:
                # Only real table columns are orderable; arbitrary class attributes
                # (e.g. ``metadata`` or ``__tablename__``) would otherwise reach
                # ``.asc()``/``.desc()`` and fail with an AttributeError.
                if order_by not in SharedChat.__table__.columns.keys():
                    raise ValueError('Invalid order_by field')
                col = getattr(SharedChat, order_by)

                normalized = str(direction).lower()
                if normalized == 'asc':
                    stmt = stmt.order_by(col.asc())
                elif normalized == 'desc':
                    stmt = stmt.order_by(col.desc())
                else:
                    raise ValueError('Invalid direction for ordering')
            else:
                # Always order when paginating: without ORDER BY the database may
                # return rows in any order and skip/limit pages could overlap.
                stmt = stmt.order_by(SharedChat.updated_at.desc())

            # Unique tiebreaker: timestamps have one-second resolution, so equal
            # sort keys are common and would otherwise page nondeterministically.
            stmt = stmt.order_by(SharedChat.id)

            if skip:
                stmt = stmt.offset(skip)
            if limit:
                stmt = stmt.limit(limit)

            result = await db.execute(stmt)
            return [
                SharedChatResponse(
                    id=sc.chat_id,
                    chat_id=sc.chat_id,
                    title=sc.title,
                    share_id=sc.id,
                    updated_at=sc.updated_at,
                    created_at=sc.created_at,
                )
                for sc in result.scalars().all()
            ]

    async def delete_by_id(self, share_id: str, db: Optional[AsyncSession] = None) -> bool:
        """
        Delete a shared chat by its share token.

        Best-effort: any failure is logged and reported as ``False`` instead of
        raised. Deleting a token that does not exist still returns ``True``.
        """
        try:
            async with get_async_db_context(db) as db:
                await db.execute(delete(SharedChat).filter_by(id=share_id))
                await db.commit()
            return True
        except Exception:
            # The share token is deliberately not logged: it grants read access
            # to the shared chat.
            log.exception('Failed to delete shared chat by share id')
            return False

    async def delete_by_chat_id(self, chat_id: str, db: Optional[AsyncSession] = None) -> bool:
        """
        Delete every shared chat taken from the given original chat.

        Best-effort: any failure is logged and reported as ``False`` instead of
        raised. A chat with no shares still returns ``True``.
        """
        try:
            async with get_async_db_context(db) as db:
                await db.execute(delete(SharedChat).filter_by(chat_id=chat_id))
                await db.commit()
            return True
        except Exception:
            log.exception('Failed to delete shared chats for chat %s', chat_id)
            return False


# Module-level singleton used by callers.
SharedChats = SharedChatsTable()