Spaces:
Build error
Build error
| import json | |
| import time | |
| import uuid | |
| from typing import Optional | |
| from sqlalchemy.orm import Session | |
| from open_webui.internal.db import Base, JSONField, get_db, get_db_context | |
| from open_webui.models.tags import TagModel, Tag, Tags | |
| from open_webui.models.users import Users, User, UserNameResponse | |
| from open_webui.models.channels import Channels, ChannelMember | |
| from pydantic import BaseModel, ConfigDict, field_validator | |
| from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON | |
| from sqlalchemy import or_, func, select, and_, text | |
| from sqlalchemy.sql import exists | |
| #################### | |
| # Message DB Schema | |
| #################### | |
class MessageReaction(Base):
    """ORM row recording one user's named reaction on one message."""

    __tablename__ = "message_reaction"

    # Declared order is kept as-is so the generated table DDL is unchanged.
    id = Column(Text, unique=True, primary_key=True)
    user_id = Column(Text)  # user who reacted
    message_id = Column(Text)  # message that was reacted to
    name = Column(Text)  # reaction name (presumably an emoji identifier — verify)
    created_at = Column(BigInteger)  # epoch timestamp, written with time.time_ns()
class MessageReactionModel(BaseModel):
    """Pydantic view of a ``MessageReaction`` row, buildable from ORM attributes."""

    model_config = ConfigDict(from_attributes=True)

    id: str
    user_id: str
    message_id: str
    name: str

    created_at: int  # timestamp in epoch
class Message(Base):
    """ORM row for a channel message (top-level post or thread reply)."""

    __tablename__ = "message"

    id = Column(Text, unique=True, primary_key=True)
    user_id = Column(Text)
    channel_id = Column(Text, nullable=True)

    # reply_to_id references the message being replied to;
    # parent_id references the thread's root message (NULL for top-level posts).
    reply_to_id = Column(Text, nullable=True)
    parent_id = Column(Text, nullable=True)

    # Pins
    is_pinned = Column(Boolean, default=False, nullable=False)
    pinned_at = Column(BigInteger, nullable=True)
    pinned_by = Column(Text, nullable=True)

    content = Column(Text)
    data = Column(JSON, nullable=True)
    # May carry {"webhook": {"id": ...}} for messages posted by a webhook.
    meta = Column(JSON, nullable=True)

    created_at = Column(BigInteger)  # time_ns
    updated_at = Column(BigInteger)  # time_ns
class MessageModel(BaseModel):
    """Pydantic mirror of the ``Message`` ORM row, buildable from ORM attributes."""

    model_config = ConfigDict(from_attributes=True)

    id: str
    user_id: str
    channel_id: str | None = None
    reply_to_id: str | None = None
    parent_id: str | None = None

    # Pins
    is_pinned: bool = False
    pinned_by: str | None = None
    pinned_at: int | None = None  # timestamp in epoch (time_ns)

    content: str
    data: dict | None = None
    meta: dict | None = None

    created_at: int  # timestamp in epoch (time_ns)
    updated_at: int  # timestamp in epoch (time_ns)
| #################### | |
| # Forms | |
| #################### | |
class MessageForm(BaseModel):
    """Request payload used both to create a message and to edit one."""

    temp_id: str | None = None  # not stored on the message row
    content: str
    reply_to_id: str | None = None
    parent_id: str | None = None
    data: dict | None = None
    meta: dict | None = None
class Reactions(BaseModel):
    """All reactions of a single ``name`` on a message, with who reacted."""

    name: str
    users: list[dict]  # each entry is built as {"id": ..., "name": ...}
    count: int
class MessageUserResponse(MessageModel):
    """A message together with its author's (or webhook's) display info."""

    user: UserNameResponse | None = None
class MessageUserSlimResponse(MessageUserResponse):
    """Compact message response whose ``data`` payload is collapsed to a flag.

    Used where only a preview of a message is needed (for example the
    ``reply_to_message`` of a reply), so the potentially large ``data`` dict
    is replaced by a boolean saying whether it held anything.
    """

    # True when the source message carried at least one non-empty data value.
    data: bool | None = None

    @field_validator("data", mode="before")
    @classmethod
    def convert_data_to_bool(cls, v):
        """Collapse the raw ``data`` value into a bool before type validation.

        Args:
            v: The incoming ``data`` value (a dict, None, or an already
                collapsed bool).

        Returns:
            ``v`` unchanged if it is already a bool (so re-validating a dumped
            slim response is idempotent); False if it is not a dict; otherwise
            True when any value in the dict is truthy.
        """
        # Already collapsed (e.g. re-validating a model_dump()) → keep it.
        if isinstance(v, bool):
            return v
        # No data or not a dict → False
        if not isinstance(v, dict):
            return False
        # True if ANY value in the dict is non-empty
        return any(bool(val) for val in v.values())
class MessageReplyToResponse(MessageUserResponse):
    """A message with its author plus a compact preview of the message it replies to."""

    reply_to_message: MessageUserSlimResponse | None = None
class MessageWithReactionsResponse(MessageUserSlimResponse):
    """Compact (slim) message response that also lists its reactions."""

    reactions: list[Reactions]
class MessageResponse(MessageReplyToResponse):
    """Full message view: author, reply-to preview, thread summary and reactions."""

    # Required field (no default); None when the message has no thread replies.
    latest_reply_at: int | None
    reply_count: int
    reactions: list[Reactions]
class MessageTable:
    """Data-access helpers for channel messages, thread replies and reactions.

    Every public method takes an optional ``db`` session and runs inside
    ``get_db_context(db)``; helpers receive the same session, letting a caller
    share one session across several calls.
    """

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _get_webhook_user_info(meta: Optional[dict], db: Session) -> Optional[dict]:
        """Return display info for a webhook-posted message, or None if not a webhook post.

        Webhook info in ``meta["webhook"]`` takes precedence over ``user_id``.
        The webhook's current name is looked up; if the webhook no longer
        exists, a "Deleted Webhook" placeholder is returned instead.
        """
        webhook_info = meta.get("webhook") if meta else None
        if not (webhook_info and webhook_info.get("id")):
            return None

        webhook_id = webhook_info.get("id")
        webhook = Channels.get_webhook_by_id(webhook_id, db=db)
        if webhook:
            return {"id": webhook.id, "name": webhook.name, "role": "webhook"}
        # Webhook was deleted, use placeholder
        return {"id": webhook_id, "name": "Deleted Webhook", "role": "webhook"}

    def _get_user_info(self, message: Message, db: Session) -> Optional[dict]:
        """Resolve the ``user`` payload of a message: webhook first, then the author."""
        webhook_user = self._get_webhook_user_info(message.meta, db)
        if webhook_user is not None:
            return webhook_user
        user = Users.get_user_by_id(message.user_id, db=db)
        return user.model_dump() if user else None

    def _get_reply_to_message(
        self, reply_to_id: Optional[str], db: Session
    ) -> Optional[dict]:
        """Return the preview payload of the message being replied to, or None.

        The payload is validated as a ``MessageUserSlimResponse``, which keeps
        only the message columns and its ``user``. Only the row and its author
        are loaded: nested replies, reactions and thread data would be dropped
        by that type anyway, and a shallow load avoids recursing along long
        reply chains.
        """
        if not reply_to_id:
            return None
        reply_to = db.get(Message, reply_to_id)
        if not reply_to:
            return None
        return {
            **MessageModel.model_validate(reply_to).model_dump(),
            "user": self._get_user_info(reply_to, db),
        }

    def _to_reply_to_response(
        self, message: Message, db: Session
    ) -> MessageReplyToResponse:
        """Build a listing entry: message columns, webhook author and reply-to preview.

        Non-webhook authors are left as ``user=None`` here (presumably filled
        in by the caller — verify); only webhook posts get ``user`` resolved.
        """
        return MessageReplyToResponse.model_validate(
            {
                **MessageModel.model_validate(message).model_dump(),
                "user": self._get_webhook_user_info(message.meta, db),
                "reply_to_message": self._get_reply_to_message(
                    message.reply_to_id, db
                ),
            }
        )

    @staticmethod
    def _get_thread_stats(parent_id: str, db: Session) -> tuple[int, Optional[int]]:
        """Return ``(reply_count, latest_reply_at)`` for a thread root in one aggregate query."""
        count, latest = (
            db.query(func.count(Message.id), func.max(Message.created_at))
            .filter(Message.parent_id == parent_id)
            .one()
        )
        return count or 0, latest

    @staticmethod
    def _escape_like(value: str, escape_char: str = "/") -> str:
        """Escape LIKE wildcards (``%``, ``_``) and the escape character itself."""
        return (
            value.replace(escape_char, escape_char * 2)
            .replace("%", escape_char + "%")
            .replace("_", escape_char + "_")
        )

    # ------------------------------------------------------------------ #
    # Messages
    # ------------------------------------------------------------------ #

    def insert_new_message(
        self,
        form_data: MessageForm,
        channel_id: str,
        user_id: str,
        db: Optional[Session] = None,
    ) -> Optional[MessageModel]:
        """Persist a new message posted by ``user_id`` in ``channel_id``.

        ``Channels.join_channel`` is called first for the author. Both
        ``created_at`` and ``updated_at`` are set to ``time.time_ns()``.

        Returns:
            The stored message as a ``MessageModel``.
        """
        with get_db_context(db) as db:
            # Called for its side effect (presumably ensures membership);
            # the return value is not needed.
            Channels.join_channel(channel_id, user_id)

            ts = int(time.time_ns())
            message = MessageModel(
                id=str(uuid.uuid4()),
                user_id=user_id,
                channel_id=channel_id,
                reply_to_id=form_data.reply_to_id,
                parent_id=form_data.parent_id,
                is_pinned=False,
                pinned_at=None,
                pinned_by=None,
                content=form_data.content,
                data=form_data.data,
                meta=form_data.meta,
                created_at=ts,
                updated_at=ts,
            )
            result = Message(**message.model_dump())
            db.add(result)
            db.commit()
            db.refresh(result)
            return MessageModel.model_validate(result)

    def get_message_by_id(
        self,
        id: str,
        include_thread_replies: Optional[bool] = True,
        db: Optional[Session] = None,
    ) -> Optional[MessageResponse]:
        """Return a fully populated message, or None if ``id`` does not exist.

        Args:
            id: Message id.
            include_thread_replies: When falsy, the thread summary is skipped
                (``reply_count`` is 0 and ``latest_reply_at`` is None).
            db: Optional session to reuse.

        Returns:
            A ``MessageResponse`` with author, reply-to preview, thread
            summary and reactions.
        """
        with get_db_context(db) as db:
            message = db.get(Message, id)
            if not message:
                return None

            reply_count, latest_reply_at = (
                self._get_thread_stats(id, db) if include_thread_replies else (0, None)
            )

            return MessageResponse.model_validate(
                {
                    **MessageModel.model_validate(message).model_dump(),
                    "user": self._get_user_info(message, db),
                    "reply_to_message": self._get_reply_to_message(
                        message.reply_to_id, db
                    ),
                    "latest_reply_at": latest_reply_at,
                    "reply_count": reply_count,
                    "reactions": self.get_reactions_by_message_id(id, db=db),
                }
            )

    def get_thread_replies_by_message_id(
        self, id: str, db: Optional[Session] = None
    ) -> list[MessageReplyToResponse]:
        """Return every reply in the thread rooted at ``id``, newest first."""
        with get_db_context(db) as db:
            replies = (
                db.query(Message)
                .filter_by(parent_id=id)
                .order_by(Message.created_at.desc())
                .all()
            )
            return [self._to_reply_to_response(reply, db) for reply in replies]

    def get_reply_user_ids_by_message_id(
        self, id: str, db: Optional[Session] = None
    ) -> list[str]:
        """Return the author id of each reply in thread ``id`` (one entry per reply)."""
        with get_db_context(db) as db:
            rows = db.query(Message.user_id).filter(Message.parent_id == id).all()
            return [user_id for (user_id,) in rows]

    def get_messages_by_channel_id(
        self,
        channel_id: str,
        skip: int = 0,
        limit: int = 50,
        db: Optional[Session] = None,
    ) -> list[MessageReplyToResponse]:
        """Return a page of top-level messages in ``channel_id``, newest first."""
        with get_db_context(db) as db:
            page = (
                db.query(Message)
                .filter_by(channel_id=channel_id, parent_id=None)
                .order_by(Message.created_at.desc())
                .offset(skip)
                .limit(limit)
                .all()
            )
            return [self._to_reply_to_response(message, db) for message in page]

    def get_messages_by_parent_id(
        self,
        channel_id: str,
        parent_id: str,
        skip: int = 0,
        limit: int = 50,
        db: Optional[Session] = None,
    ) -> list[MessageReplyToResponse]:
        """Return a page of thread replies (newest first), ending with the root when reached.

        Returns an empty list when ``parent_id`` does not exist.
        """
        with get_db_context(db) as db:
            parent = db.get(Message, parent_id)
            if not parent:
                return []

            page = (
                db.query(Message)
                .filter_by(channel_id=channel_id, parent_id=parent_id)
                .order_by(Message.created_at.desc())
                .offset(skip)
                .limit(limit)
                .all()
            )
            # A short page means the oldest replies were reached, so the root
            # (which presumably predates its replies) is appended last.
            if len(page) < limit:
                page.append(parent)

            return [self._to_reply_to_response(message, db) for message in page]

    def get_last_message_by_channel_id(
        self, channel_id: str, db: Optional[Session] = None
    ) -> Optional[MessageModel]:
        """Return the newest message (including thread replies) in ``channel_id``, or None."""
        with get_db_context(db) as db:
            message = (
                db.query(Message)
                .filter_by(channel_id=channel_id)
                .order_by(Message.created_at.desc())
                .first()
            )
            return MessageModel.model_validate(message) if message else None

    def get_pinned_messages_by_channel_id(
        self,
        channel_id: str,
        skip: int = 0,
        limit: int = 50,
        db: Optional[Session] = None,
    ) -> list[MessageModel]:
        """Return a page of pinned messages in ``channel_id``, most recently pinned first."""
        with get_db_context(db) as db:
            pinned = (
                db.query(Message)
                .filter_by(channel_id=channel_id, is_pinned=True)
                .order_by(Message.pinned_at.desc())
                .offset(skip)
                .limit(limit)
                .all()
            )
            return [MessageModel.model_validate(message) for message in pinned]

    def update_message_by_id(
        self, id: str, form_data: MessageForm, db: Optional[Session] = None
    ) -> Optional[MessageModel]:
        """Edit a message's content and shallow-merge new ``data``/``meta`` keys.

        Returns:
            The updated message, or None if ``id`` does not exist.
        """
        with get_db_context(db) as db:
            message = db.get(Message, id)
            if not message:
                return None

            message.content = form_data.content
            # Assign fresh dicts so SQLAlchemy detects the JSON column change.
            message.data = {**(message.data or {}), **(form_data.data or {})}
            message.meta = {**(message.meta or {}), **(form_data.meta or {})}
            message.updated_at = int(time.time_ns())
            db.commit()
            db.refresh(message)
            return MessageModel.model_validate(message)

    def update_is_pinned_by_id(
        self,
        id: str,
        is_pinned: bool,
        pinned_by: Optional[str] = None,
        db: Optional[Session] = None,
    ) -> Optional[MessageModel]:
        """Pin or unpin a message; unpinning clears ``pinned_at`` and ``pinned_by``.

        Returns:
            The updated message, or None if ``id`` does not exist.
        """
        with get_db_context(db) as db:
            message = db.get(Message, id)
            if not message:
                return None

            message.is_pinned = is_pinned
            message.pinned_at = int(time.time_ns()) if is_pinned else None
            message.pinned_by = pinned_by if is_pinned else None
            db.commit()
            db.refresh(message)
            return MessageModel.model_validate(message)

    def get_unread_message_count(
        self,
        channel_id: str,
        user_id: str,
        last_read_at: Optional[int] = None,
        db: Optional[Session] = None,
    ) -> int:
        """Count top-level messages in ``channel_id`` newer than ``last_read_at``.

        When ``user_id`` is given, that user's own messages are excluded.
        """
        with get_db_context(db) as db:
            query = db.query(Message).filter(
                Message.channel_id == channel_id,
                Message.parent_id.is_(None),  # only count top-level messages
                Message.created_at > (last_read_at or 0),
            )
            if user_id:
                query = query.filter(Message.user_id != user_id)
            return query.count()

    # ------------------------------------------------------------------ #
    # Reactions
    # ------------------------------------------------------------------ #

    def add_reaction_to_message(
        self, id: str, user_id: str, name: str, db: Optional[Session] = None
    ) -> Optional[MessageReactionModel]:
        """Add reaction ``name`` by ``user_id`` to message ``id``.

        If the same reaction already exists, it is returned unchanged instead
        of inserting a duplicate.
        """
        with get_db_context(db) as db:
            existing_reaction = (
                db.query(MessageReaction)
                .filter_by(message_id=id, user_id=user_id, name=name)
                .first()
            )
            if existing_reaction:
                return MessageReactionModel.model_validate(existing_reaction)

            reaction = MessageReactionModel(
                id=str(uuid.uuid4()),
                user_id=user_id,
                message_id=id,
                name=name,
                created_at=int(time.time_ns()),
            )
            result = MessageReaction(**reaction.model_dump())
            db.add(result)
            db.commit()
            db.refresh(result)
            return MessageReactionModel.model_validate(result)

    def get_reactions_by_message_id(
        self, id: str, db: Optional[Session] = None
    ) -> list[Reactions]:
        """Return the reactions on message ``id`` grouped by name, in first-seen order.

        Reactions whose user row no longer exists are dropped by the inner join.
        """
        with get_db_context(db) as db:
            # JOIN User so all user info is fetched in one query
            results = (
                db.query(MessageReaction, User)
                .join(User, MessageReaction.user_id == User.id)
                .filter(MessageReaction.message_id == id)
                .all()
            )

            grouped: dict[str, dict] = {}
            for reaction, user in results:
                entry = grouped.setdefault(
                    reaction.name, {"name": reaction.name, "users": [], "count": 0}
                )
                entry["users"].append({"id": user.id, "name": user.name})
                entry["count"] += 1

            return [Reactions(**entry) for entry in grouped.values()]

    def remove_reaction_by_id_and_user_id_and_name(
        self, id: str, user_id: str, name: str, db: Optional[Session] = None
    ) -> bool:
        """Delete reaction ``name`` by ``user_id`` on message ``id``; always returns True."""
        with get_db_context(db) as db:
            db.query(MessageReaction).filter_by(
                message_id=id, user_id=user_id, name=name
            ).delete()
            db.commit()
            return True

    def delete_reactions_by_id(self, id: str, db: Optional[Session] = None) -> bool:
        """Delete every reaction on message ``id``; always returns True."""
        with get_db_context(db) as db:
            db.query(MessageReaction).filter_by(message_id=id).delete()
            db.commit()
            return True

    def delete_replies_by_id(self, id: str, db: Optional[Session] = None) -> bool:
        """Delete every thread reply whose ``parent_id`` is ``id``; always returns True.

        Reactions on the deleted replies are not removed here.
        """
        with get_db_context(db) as db:
            db.query(Message).filter_by(parent_id=id).delete()
            db.commit()
            return True

    def delete_message_by_id(self, id: str, db: Optional[Session] = None) -> bool:
        """Delete message ``id`` and its reactions (not its replies); always returns True."""
        with get_db_context(db) as db:
            db.query(Message).filter_by(id=id).delete()
            # Delete all reactions to this message
            db.query(MessageReaction).filter_by(message_id=id).delete()
            db.commit()
            return True

    # ------------------------------------------------------------------ #
    # Search
    # ------------------------------------------------------------------ #

    def search_messages_by_channel_ids(
        self,
        channel_ids: list[str],
        query: str,
        start_timestamp: Optional[int] = None,
        end_timestamp: Optional[int] = None,
        limit: int = 10,
        db: Optional[Session] = None,
    ) -> list[MessageModel]:
        """Search messages in specified channels by content.

        ``query`` is matched case-insensitively as a literal substring; ``%``
        and ``_`` in it are escaped rather than treated as LIKE wildcards.
        Truthy ``start_timestamp``/``end_timestamp`` bound ``created_at``
        inclusively. Results are newest first, at most ``limit`` rows.
        """
        with get_db_context(db) as db:
            escape_char = "/"
            pattern = f"%{self._escape_like(query, escape_char)}%"
            query_builder = db.query(Message).filter(
                Message.channel_id.in_(channel_ids),
                Message.content.ilike(pattern, escape=escape_char),
            )
            if start_timestamp:
                query_builder = query_builder.filter(
                    Message.created_at >= start_timestamp
                )
            if end_timestamp:
                query_builder = query_builder.filter(
                    Message.created_at <= end_timestamp
                )

            messages = (
                query_builder.order_by(Message.created_at.desc()).limit(limit).all()
            )
            return [MessageModel.model_validate(msg) for msg in messages]
# Shared module-level MessageTable instance.
Messages: MessageTable = MessageTable()