Spaces:
Build error
Build error
File size: 6,920 Bytes
87a665c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 | import logging
import time
import uuid
from typing import Optional
from sqlalchemy import select, delete
from sqlalchemy.ext.asyncio import AsyncSession
from open_webui.internal.db import Base, JSONField, get_async_db_context
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, ForeignKey, Text, JSON
# Module-level logger, named after this module's import path.
log = logging.getLogger(name=__name__)
####################
# SharedChat DB Schema
####################
class SharedChat(Base):
    """
    ORM row for one link-share of a chat.

    Each row stores a copy of the original chat's title and JSON payload; the
    copy is taken when the share is created and overwritten whenever the share
    is re-snapshotted (see ``SharedChatsTable.update``).
    """
    __tablename__ = 'shared_chat'
    id = Column(Text, primary_key=True) # The share token (UUID string) — used in the /s/{id} URL
    # Original chat this share was taken from. ON DELETE CASCADE is declared on the FK;
    # NOTE(review): SQLite only enforces it when foreign keys are enabled — confirm per backend.
    chat_id = Column(Text, ForeignKey('chat.id', ondelete='CASCADE'), nullable=False)
    user_id = Column(Text, nullable=False) # Who created this share
    title = Column(Text) # Copy of the chat title at the last snapshot
    chat = Column(JSON) # Snapshot of the chat JSON, taken at share time and refreshed on re-snapshot
    created_at = Column(BigInteger) # Unix epoch seconds (int(time.time())) when the share was created
    updated_at = Column(BigInteger) # Unix epoch seconds of the last snapshot
# Pydantic view of a SharedChat row. from_attributes=True lets it be built
# straight from an ORM instance via SharedChatModel.model_validate(row).
class SharedChatModel(BaseModel):
    id: str  # share token used in the /s/{id} URL
    chat_id: str  # id of the original chat
    user_id: str  # user who created the share
    title: str  # chat title captured at the last snapshot
    chat: dict  # snapshot of the chat payload
    created_at: int  # unix epoch seconds
    updated_at: int  # unix epoch seconds

    # Read field values from object attributes (ORM rows), not only from dicts.
    model_config = ConfigDict(from_attributes=True)
# List item for a user's shared chats, as built by SharedChatsTable.get_by_user_id.
# NOTE: as populated there, `id` carries the ORIGINAL chat's id (same value as
# `chat_id`), NOT the share token; the share token is exposed as `share_id`.
# Presumably `id` keeps the chat id for backward compatibility with clients that
# key shared-chat listings by chat id — confirm before changing.
class SharedChatResponse(BaseModel):
    id: str  # original chat id (duplicate of chat_id, kept for backward compat)
    chat_id: str  # original chat id
    title: str
    share_id: Optional[str] = None  # share token used in the /s/{share_id} URL
    updated_at: int
    created_at: int
####################
# Table Operations
####################
class SharedChatsTable:
    """
    Async data-access helpers for the ``shared_chat`` table.

    Every method takes an optional ``db`` session and forwards it to
    ``get_async_db_context``; presumably that helper reuses a supplied session
    and opens a fresh one otherwise (confirm in open_webui.internal.db).
    """

    # Escape character used when matching user search text with LIKE/ILIKE.
    # '/' (rather than backslash) avoids dialect-specific string-literal escaping.
    _LIKE_ESCAPE = '/'

    @staticmethod
    def _escape_like(value, escape_char: str = '/') -> str:
        """Escape LIKE wildcards in *value* so it is matched literally (escape char first)."""
        text = str(value)
        return (
            text.replace(escape_char, escape_char * 2)
            .replace('%', escape_char + '%')
            .replace('_', escape_char + '_')
        )

    async def create(self, chat_id: str, user_id: str, db: Optional[AsyncSession] = None) -> Optional[SharedChatModel]:
        """
        Create a snapshot of the chat for link sharing.

        Args:
            chat_id: id of the chat to snapshot.
            user_id: id of the user creating the share (stored as-is; no
                ownership check is performed here).
            db: optional session forwarded to ``get_async_db_context``.

        Returns:
            The new SharedChatModel (its ``id`` is the share token), or None
            if no chat with ``chat_id`` exists.
        """
        async with get_async_db_context(db) as db:
            # Imported lazily, presumably to avoid a circular import with the chats model.
            from open_webui.models.chats import Chat

            chat = await db.get(Chat, chat_id)
            if not chat:
                return None

            # uuid4 is random, so share tokens cannot be derived from chat ids.
            share_id = str(uuid.uuid4())
            now = int(time.time())
            shared_chat = SharedChat(
                id=share_id,
                chat_id=chat_id,
                user_id=user_id,
                title=chat.title,
                chat=chat.chat,
                created_at=now,
                updated_at=now,
            )
            db.add(shared_chat)
            await db.commit()
            await db.refresh(shared_chat)
            return SharedChatModel.model_validate(shared_chat)

    async def update(self, share_id: str, db: Optional[AsyncSession] = None) -> Optional[SharedChatModel]:
        """
        Re-snapshot: update the shared chat with the current state of the original chat.

        Copies the original chat's current title and payload into the share and
        bumps ``updated_at``; the share token (``id``) is unchanged.

        Returns:
            The updated SharedChatModel, or None if the share or its original
            chat no longer exists.
        """
        async with get_async_db_context(db) as db:
            from open_webui.models.chats import Chat

            shared_chat = await db.get(SharedChat, share_id)
            if not shared_chat:
                return None
            chat = await db.get(Chat, shared_chat.chat_id)
            if not chat:
                return None

            shared_chat.title = chat.title
            shared_chat.chat = chat.chat
            shared_chat.updated_at = int(time.time())
            await db.commit()
            await db.refresh(shared_chat)
            return SharedChatModel.model_validate(shared_chat)

    async def get_by_id(self, share_id: str, db: Optional[AsyncSession] = None) -> Optional[SharedChatModel]:
        """Get a shared chat by its share token, or None if it does not exist."""
        async with get_async_db_context(db) as db:
            shared_chat = await db.get(SharedChat, share_id)
            if shared_chat:
                return SharedChatModel.model_validate(shared_chat)
            return None

    async def get_by_chat_id(self, chat_id: str, db: Optional[AsyncSession] = None) -> Optional[SharedChatModel]:
        """
        Get the share for a given original chat.

        A chat may have several shares; the one with the latest ``updated_at``
        is returned, or None if the chat has never been shared.
        """
        async with get_async_db_context(db) as db:
            result = await db.execute(
                select(SharedChat).filter_by(chat_id=chat_id).order_by(SharedChat.updated_at.desc()).limit(1)
            )
            shared_chat = result.scalars().first()
            if shared_chat:
                return SharedChatModel.model_validate(shared_chat)
            return None

    async def get_by_user_id(
        self,
        user_id: str,
        filter: Optional[dict] = None,
        skip: int = 0,
        limit: int = 50,
        db: Optional[AsyncSession] = None,
    ) -> list[SharedChatResponse]:
        """
        List the shared chats created by a user.

        Args:
            user_id: creator whose shares are listed.
            filter: optional dict with keys:
                - ``query``: case-insensitive substring matched against the
                  title; LIKE wildcards (``%``, ``_``) in it match literally.
                - ``order_by``: name of a ``shared_chat`` column to sort by.
                - ``direction``: ``'asc'`` or ``'desc'`` (case-insensitive).
                Explicit ordering applies only when both ``order_by`` and
                ``direction`` are set; otherwise rows are ordered by
                ``updated_at`` descending so pagination is stable.
            skip: number of rows to skip (0/None = none).
            limit: maximum number of rows (0/None = no limit).
            db: optional session forwarded to ``get_async_db_context``.

        Returns:
            SharedChatResponse items whose ``id``/``chat_id`` hold the original
            chat id and whose ``share_id`` holds the share token.

        Raises:
            ValueError: ``order_by`` is not a column of ``shared_chat``, or
                ``direction`` is neither 'asc' nor 'desc'.
        """
        filter = filter or {}
        async with get_async_db_context(db) as db:
            stmt = select(SharedChat).filter_by(user_id=user_id)

            query_key = filter.get('query')
            if query_key:
                escaped = self._escape_like(query_key, self._LIKE_ESCAPE)
                stmt = stmt.filter(SharedChat.title.ilike(f'%{escaped}%', escape=self._LIKE_ESCAPE))

            order_by = filter.get('order_by')
            direction = filter.get('direction')
            if order_by and direction:
                # Only real table columns may be used for ordering; getattr on the
                # class would otherwise expose arbitrary attributes (metadata, dunders).
                if order_by not in SharedChat.__table__.columns.keys():
                    raise ValueError('Invalid order_by field')
                col = getattr(SharedChat, order_by)
                if direction.lower() == 'asc':
                    stmt = stmt.order_by(col.asc())
                elif direction.lower() == 'desc':
                    stmt = stmt.order_by(col.desc())
                else:
                    raise ValueError('Invalid direction for ordering')
            else:
                # Deterministic default ordering, also when no filter is given,
                # so OFFSET/LIMIT pagination returns consistent pages.
                stmt = stmt.order_by(SharedChat.updated_at.desc())

            if skip:
                stmt = stmt.offset(skip)
            if limit:
                stmt = stmt.limit(limit)

            result = await db.execute(stmt)
            return [
                SharedChatResponse(
                    id=sc.chat_id,
                    chat_id=sc.chat_id,
                    title=sc.title,
                    share_id=sc.id,
                    updated_at=sc.updated_at,
                    created_at=sc.created_at,
                )
                for sc in result.scalars().all()
            ]

    async def delete_by_id(self, share_id: str, db: Optional[AsyncSession] = None) -> bool:
        """
        Delete a shared chat by its share token.

        Returns True once the DELETE is committed (also when no row matched),
        False if any error occurred; the error is logged, not raised.
        """
        try:
            async with get_async_db_context(db) as db:
                await db.execute(delete(SharedChat).filter_by(id=share_id))
                await db.commit()
                return True
        except Exception:
            # Best-effort contract kept (False instead of raising), but leave a trace.
            log.exception('Failed to delete shared chat %s', share_id)
            return False

    async def delete_by_chat_id(self, chat_id: str, db: Optional[AsyncSession] = None) -> bool:
        """
        Delete all shared chats for a given original chat.

        Returns True once the DELETE is committed (also when no row matched),
        False if any error occurred; the error is logged, not raised.
        """
        try:
            async with get_async_db_context(db) as db:
                await db.execute(delete(SharedChat).filter_by(chat_id=chat_id))
                await db.commit()
                return True
        except Exception:
            # Best-effort contract kept (False instead of raising), but leave a trace.
            log.exception('Failed to delete shared chats for chat %s', chat_id)
            return False


# Module-level singleton used by callers to access the shared_chat table.
SharedChats = SharedChatsTable()
|