File size: 6,920 Bytes
87a665c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import logging
import time
import uuid
from typing import Optional

from sqlalchemy import select, delete
from sqlalchemy.ext.asyncio import AsyncSession
from open_webui.internal.db import Base, JSONField, get_async_db_context

from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, ForeignKey, Text, JSON

# Module-level logger named after this module.
log = logging.getLogger(__name__)

####################
# SharedChat DB Schema
####################


class SharedChat(Base):
    """ORM model for a link-shared snapshot of a chat.

    Each row stores a copy of a chat's title and JSON content, taken when the
    share was created (or last re-snapshotted), keyed by a share token.
    """

    __tablename__ = 'shared_chat'

    id = Column(Text, primary_key=True)  # The share token (UUID) — used in /s/{id} URL
    # The original chat. The FK is declared ON DELETE CASCADE, so the database can
    # drop share rows together with the chat (NOTE(review): only where FK
    # enforcement is enabled, e.g. SQLite needs PRAGMA foreign_keys=ON — confirm).
    chat_id = Column(Text, ForeignKey('chat.id', ondelete='CASCADE'), nullable=False)
    user_id = Column(Text, nullable=False)  # Who created this share

    title = Column(Text)  # Copy of the chat title at snapshot time
    chat = Column(JSON)  # Snapshot of chat JSON at share time

    # Unix timestamps in whole seconds (SharedChatsTable sets them via int(time.time())).
    created_at = Column(BigInteger)
    updated_at = Column(BigInteger)


class SharedChatModel(BaseModel):
    """Pydantic representation of a full SharedChat row, including the snapshot."""

    # Lets model_validate() read attributes directly from SharedChat ORM instances.
    model_config = ConfigDict(from_attributes=True)

    id: str  # The share token
    chat_id: str  # Id of the original chat
    user_id: str  # Who created the share

    # NOTE(review): the matching SharedChat columns are nullable, but these fields
    # are not Optional, so validating a row containing NULLs would raise — confirm
    # rows are always fully populated (SharedChatsTable.create sets all of them).
    title: str
    chat: dict

    created_at: int
    updated_at: int


class SharedChatResponse(BaseModel):
    """Lightweight summary of a shared chat (no snapshot content), used for listings."""

    # NOTE(review): SharedChatsTable.get_by_user_id fills `id` with the original
    # chat's id, not the share token — confirm callers rely on this.
    id: str
    chat_id: str
    title: str
    # Originally described as an "alias for id, for backward compat"; in
    # get_by_user_id it carries the share token while `id` carries the chat id.
    share_id: Optional[str] = None
    updated_at: int
    created_at: int


####################
# Table Operations
####################


class SharedChatsTable:
    """Async data-access helpers for the ``shared_chat`` table.

    Every method accepts an optional ``AsyncSession`` which is handed to
    ``get_async_db_context``; the session it yields is used for the call
    (presumably the given one, or a fresh one when ``db`` is None).
    """

    # Columns callers may sort by in get_by_user_id. The JSON ``chat`` column is
    # deliberately excluded: it has no meaningful ordering (PostgreSQL's json
    # type has no ordering operator at all).
    _SORTABLE_COLUMNS = frozenset({'id', 'chat_id', 'user_id', 'title', 'created_at', 'updated_at'})

    # Escape character used so user search text is matched literally by ILIKE.
    _LIKE_ESCAPE = '\\'

    @classmethod
    def _escape_like(cls, value: str) -> str:
        """Escape LIKE wildcards (``%``, ``_``) and the escape character itself in *value*."""
        esc = cls._LIKE_ESCAPE
        # The escape character must be doubled first, or the escapes added below would be re-escaped.
        return value.replace(esc, esc * 2).replace('%', f'{esc}%').replace('_', f'{esc}_')

    async def create(self, chat_id: str, user_id: str, db: Optional[AsyncSession] = None) -> Optional[SharedChatModel]:
        """
        Create a snapshot of the chat for link sharing.

        Copies the chat's current ``title`` and ``chat`` JSON into a new row keyed
        by a freshly generated UUID4 share token.

        Returns the SharedChatModel with the share token as its id, or None if no
        chat with ``chat_id`` exists.
        """
        async with get_async_db_context(db) as db:
            # Imported lazily, presumably to avoid a circular import with the chats model.
            from open_webui.models.chats import Chat

            chat = await db.get(Chat, chat_id)
            if not chat:
                return None

            share_id = str(uuid.uuid4())
            now = int(time.time())

            shared_chat = SharedChat(
                id=share_id,
                chat_id=chat_id,
                user_id=user_id,
                title=chat.title,
                chat=chat.chat,
                created_at=now,
                updated_at=now,
            )
            db.add(shared_chat)
            await db.commit()
            await db.refresh(shared_chat)

            return SharedChatModel.model_validate(shared_chat)

    async def update(self, share_id: str, db: Optional[AsyncSession] = None) -> Optional[SharedChatModel]:
        """
        Re-snapshot: update the shared chat with the current state of the original chat.

        Overwrites the stored title and chat JSON and bumps ``updated_at``; the share
        token is unchanged. Returns None if either the share or its original chat
        no longer exists.
        """
        async with get_async_db_context(db) as db:
            from open_webui.models.chats import Chat

            shared_chat = await db.get(SharedChat, share_id)
            if not shared_chat:
                return None

            chat = await db.get(Chat, shared_chat.chat_id)
            if not chat:
                return None

            shared_chat.title = chat.title
            shared_chat.chat = chat.chat
            shared_chat.updated_at = int(time.time())

            await db.commit()
            await db.refresh(shared_chat)
            return SharedChatModel.model_validate(shared_chat)

    async def get_by_id(self, share_id: str, db: Optional[AsyncSession] = None) -> Optional[SharedChatModel]:
        """Get a shared chat by its share token, or None if it does not exist."""
        async with get_async_db_context(db) as db:
            shared_chat = await db.get(SharedChat, share_id)
            if shared_chat:
                return SharedChatModel.model_validate(shared_chat)
            return None

    async def get_by_chat_id(self, chat_id: str, db: Optional[AsyncSession] = None) -> Optional[SharedChatModel]:
        """Get the shared chat for a given original chat. Returns the most recently updated one, or None."""
        async with get_async_db_context(db) as db:
            result = await db.execute(
                select(SharedChat).filter_by(chat_id=chat_id).order_by(SharedChat.updated_at.desc()).limit(1)
            )
            shared_chat = result.scalars().first()
            if shared_chat:
                return SharedChatModel.model_validate(shared_chat)
            return None

    async def get_by_user_id(
        self,
        user_id: str,
        filter: Optional[dict] = None,
        skip: int = 0,
        limit: int = 50,
        db: Optional[AsyncSession] = None,
    ) -> list[SharedChatResponse]:
        """
        List all shared chats created by a user.

        Args:
            user_id: Owner whose shares are listed.
            filter: Optional dict with keys:
                - ``query``: case-insensitive substring matched against the title;
                  ``%`` and ``_`` in it are matched literally.
                - ``order_by`` / ``direction``: a column from ``_SORTABLE_COLUMNS``
                  and ``'asc'`` or ``'desc'``; both must be present to apply.
            skip: Number of rows to skip (falsy = none).
            limit: Maximum number of rows (falsy = unlimited).
            db: Optional session.

        Returns:
            SharedChatResponse rows, where ``id`` holds the original chat id and
            ``share_id`` holds the share token. Rows are ordered by ``updated_at``
            descending unless an explicit sort is given.

        Raises:
            ValueError: if ``order_by`` is not a sortable column, or ``direction`` is
                neither ``'asc'`` nor ``'desc'``.
        """
        async with get_async_db_context(db) as db:
            stmt = select(SharedChat).filter_by(user_id=user_id)
            order_clause = None

            if filter:
                query_key = filter.get('query')
                if query_key:
                    pattern = f'%{self._escape_like(str(query_key))}%'
                    stmt = stmt.filter(SharedChat.title.ilike(pattern, escape=self._LIKE_ESCAPE))

                order_by = filter.get('order_by')
                direction = filter.get('direction')

                if order_by and direction:
                    # Validate against an allowlist instead of a bare getattr(), which would
                    # accept non-column attributes such as '__tablename__' or 'metadata'.
                    if not isinstance(order_by, str) or order_by not in self._SORTABLE_COLUMNS:
                        raise ValueError('Invalid order_by field')
                    col = getattr(SharedChat, order_by)
                    normalized = direction.lower()
                    if normalized == 'asc':
                        order_clause = col.asc()
                    elif normalized == 'desc':
                        order_clause = col.desc()
                    else:
                        raise ValueError('Invalid direction for ordering')

            # Always apply an ORDER BY: OFFSET/LIMIT paging over an unordered result
            # set is not deterministic.
            stmt = stmt.order_by(order_clause if order_clause is not None else SharedChat.updated_at.desc())

            if skip:
                stmt = stmt.offset(skip)
            if limit:
                stmt = stmt.limit(limit)

            result = await db.execute(stmt)
            return [
                SharedChatResponse(
                    id=sc.chat_id,
                    chat_id=sc.chat_id,
                    title=sc.title,
                    share_id=sc.id,
                    updated_at=sc.updated_at,
                    created_at=sc.created_at,
                )
                for sc in result.scalars().all()
            ]

    async def delete_by_id(self, share_id: str, db: Optional[AsyncSession] = None) -> bool:
        """
        Delete a shared chat by its share token.

        Returns True once the DELETE is committed (even if no row matched) and False
        on any error. Errors are logged rather than raised.
        """
        try:
            async with get_async_db_context(db) as db:
                await db.execute(delete(SharedChat).filter_by(id=share_id))
                await db.commit()
                return True
        except Exception:
            log.exception('Failed to delete shared chat %s', share_id)
            return False

    async def delete_by_chat_id(self, chat_id: str, db: Optional[AsyncSession] = None) -> bool:
        """
        Delete all shared chats for a given original chat.

        Returns True once the DELETE is committed (even if no row matched) and False
        on any error. Errors are logged rather than raised.
        """
        try:
            async with get_async_db_context(db) as db:
                await db.execute(delete(SharedChat).filter_by(chat_id=chat_id))
                await db.commit()
                return True
        except Exception:
            log.exception('Failed to delete shared chats for chat %s', chat_id)
            return False


# Module-level instance of the table helpers, shared by importers of this module.
SharedChats = SharedChatsTable()