| | import json |
| | import time |
| | import uuid |
| | from typing import Optional |
| |
|
| | from sqlalchemy.orm import Session |
| | from open_webui.internal.db import Base, JSONField, get_db, get_db_context |
| | from open_webui.models.tags import TagModel, Tag, Tags |
| | from open_webui.models.users import Users, User, UserNameResponse |
| | from open_webui.models.channels import Channels, ChannelMember |
| |
|
| |
|
| | from pydantic import BaseModel, ConfigDict, field_validator |
| | from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON |
| | from sqlalchemy import or_, func, select, and_, text |
| | from sqlalchemy.sql import exists |
| |
|
| | |
| | |
| | |
| |
|
| |
|
class MessageReaction(Base):
    # SQLAlchemy model for the "message_reaction" table: one row per reaction
    # a user adds to a message. There is no DB-level uniqueness on
    # (user_id, message_id, name); MessageTable.add_reaction_to_message looks
    # for an existing row before inserting a new one.
    __tablename__ = "message_reaction"
    id = Column(Text, primary_key=True, unique=True)
    user_id = Column(Text)
    message_id = Column(Text)
    name = Column(Text)  # reaction identifier; reactions are grouped by this value
    created_at = Column(BigInteger)  # set from time.time_ns() at insert
| |
|
| |
|
class MessageReactionModel(BaseModel):
    """Pydantic view of a single ``MessageReaction`` row.

    ``from_attributes`` lets it be built straight from the ORM object.
    """

    model_config = ConfigDict(from_attributes=True)

    id: str
    user_id: str
    message_id: str
    name: str
    # Nanosecond timestamp, as written by MessageTable.add_reaction_to_message.
    created_at: int
| |
|
| |
|
class Message(Base):
    # SQLAlchemy model for the "message" table. Top-level channel messages,
    # thread replies and replies-to-a-message are all stored here and are
    # told apart by the nullable link columns below.
    __tablename__ = "message"
    id = Column(Text, primary_key=True, unique=True)

    user_id = Column(Text)  # author id; webhook posts are flagged in meta["webhook"]
    channel_id = Column(Text, nullable=True)

    # Id of the message this one replies to (resolved as reply_to_message).
    reply_to_id = Column(Text, nullable=True)
    # Thread root id; top-level channel messages have parent_id = NULL.
    parent_id = Column(Text, nullable=True)

    # Pinning state, maintained by MessageTable.update_is_pinned_by_id.
    is_pinned = Column(Boolean, nullable=False, default=False)
    pinned_at = Column(BigInteger, nullable=True)  # time.time_ns() when pinned, else NULL
    pinned_by = Column(Text, nullable=True)  # id of whoever pinned it, else NULL

    content = Column(Text)
    data = Column(JSON, nullable=True)
    # Free-form metadata; may carry {"webhook": {"id": ...}} for webhook posts.
    meta = Column(JSON, nullable=True)

    # Nanosecond timestamps (time.time_ns()) set by MessageTable.
    created_at = Column(BigInteger)
    updated_at = Column(BigInteger)
| |
|
| |
|
class MessageModel(BaseModel):
    """Pydantic mirror of a ``Message`` row (buildable from the ORM object)."""

    model_config = ConfigDict(from_attributes=True)

    id: str
    user_id: str
    channel_id: str | None = None

    # Links to other messages: the replied-to message and the thread root.
    reply_to_id: str | None = None
    parent_id: str | None = None

    # Pinning state.
    is_pinned: bool = False
    pinned_by: str | None = None
    pinned_at: int | None = None

    # Body and free-form payloads.
    content: str
    data: dict | None = None
    meta: dict | None = None

    # Nanosecond timestamps.
    created_at: int
    updated_at: int
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
class MessageForm(BaseModel):
    """Client payload used to create or update a message."""

    # Not persisted by MessageTable.insert_new_message; presumably a
    # client-side correlation id (TODO confirm with callers).
    temp_id: str | None = None
    content: str
    reply_to_id: str | None = None
    parent_id: str | None = None
    data: dict | None = None
    meta: dict | None = None
| |
|
| |
|
class Reactions(BaseModel):
    """All reactions sharing one ``name`` on a message, grouped together."""

    name: str
    # One {"id": ..., "name": ...} entry per reacting user.
    users: list[dict]
    count: int
| |
|
| |
|
class MessageUserResponse(MessageModel):
    """A message plus the display info of its author (user or webhook)."""

    user: UserNameResponse | None = None
| |
|
| |
|
class MessageUserSlimResponse(MessageUserResponse):
    """Lightweight message view in which ``data`` collapses to a has-data flag.

    The full ``data`` payload is replaced by ``True`` when it contains at
    least one truthy value, and ``False`` otherwise.
    """

    data: bool | None = None

    @field_validator("data", mode="before")
    def convert_data_to_bool(cls, v):
        """Reduce the raw ``data`` payload to a boolean.

        - A bool is kept as-is, so re-validating an already-slim payload
          (e.g. ``model_validate(slim.model_dump())``) does not turn ``True``
          into ``False``.
        - Any other non-dict value (including None) becomes ``False``.
        - A dict becomes ``True`` if any of its values is truthy.
        """
        if isinstance(v, bool):
            return v
        if not isinstance(v, dict):
            return False
        return any(v.values())
| |
|
| |
|
class MessageReplyToResponse(MessageUserResponse):
    """A message with its author and a slim copy of the message it replies to."""

    reply_to_message: MessageUserSlimResponse | None = None
| |
|
| |
|
class MessageWithReactionsResponse(MessageUserSlimResponse):
    """Slim message view (boolean ``data``) together with its grouped reactions."""

    reactions: list[Reactions]
| |
|
| |
|
class MessageResponse(MessageReplyToResponse):
    """Full single-message view: author, reply target, thread stats, reactions."""

    # Required but nullable: None when the message has no thread replies.
    latest_reply_at: int | None
    reply_count: int
    reactions: list[Reactions]
| |
|
| |
|
class MessageTable:
    """Data-access helpers for channel messages and their reactions.

    Every public method accepts an optional SQLAlchemy ``Session`` and runs
    inside ``get_db_context(db)``, which presumably reuses ``db`` when given
    and opens a fresh session otherwise (TODO confirm). All timestamps written
    here come from ``time.time_ns()`` (nanoseconds since the epoch).
    """

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _get_webhook_user_info(message: Message, db: Session) -> Optional[dict]:
        """Return author display info when ``message`` was posted by a webhook.

        A message counts as webhook-authored when ``message.meta`` contains
        ``{"webhook": {"id": ...}}`` with a truthy id. The webhook's current
        name is used if it still exists; otherwise a "Deleted Webhook"
        placeholder is returned. Returns None for regular messages.
        """
        webhook_info = message.meta.get("webhook") if message.meta else None
        if not webhook_info or not webhook_info.get("id"):
            return None

        webhook_id = webhook_info.get("id")
        webhook = Channels.get_webhook_by_id(webhook_id, db=db)
        if webhook:
            return {"id": webhook.id, "name": webhook.name, "role": "webhook"}
        # The webhook row is gone: keep its id but show a placeholder name.
        return {"id": webhook_id, "name": "Deleted Webhook", "role": "webhook"}

    def _get_author_info(self, message: Message, db: Session) -> Optional[dict]:
        """Return webhook info, or else the author's user record dump (or None)."""
        webhook_user = self._get_webhook_user_info(message, db)
        if webhook_user is not None:
            return webhook_user
        user = Users.get_user_by_id(message.user_id, db=db)
        return user.model_dump() if user else None

    def _get_reply_to_message(self, message: Message, db: Session) -> Optional[dict]:
        """Build the ``reply_to_message`` payload for ``message`` (or None).

        Only the replied-to message's own fields and its author are loaded.
        Its reply target, reactions and thread stats are skipped, because
        ``MessageUserSlimResponse`` has no fields for them. Skipping them also
        avoids recursing down an entire reply chain.
        """
        if not message.reply_to_id:
            return None
        target = db.get(Message, message.reply_to_id)
        if not target:
            return None
        return {
            **MessageModel.model_validate(target).model_dump(),
            "user": self._get_author_info(target, db),
        }

    def _to_reply_to_response(
        self, message: Message, db: Session
    ) -> MessageReplyToResponse:
        """Convert a ``Message`` row into a ``MessageReplyToResponse``.

        ``user`` is populated only for webhook-authored messages and is left
        None for regular ones. Unlike ``get_message_by_id``, no user lookup
        is done here.
        """
        reply_to_message = self._get_reply_to_message(message, db)
        return MessageReplyToResponse.model_validate(
            {
                **MessageModel.model_validate(message).model_dump(),
                "user": self._get_webhook_user_info(message, db),
                "reply_to_message": reply_to_message,
            }
        )

    @staticmethod
    def _get_thread_stats(parent_id: str, db: Session) -> tuple[int, Optional[int]]:
        """Return ``(reply_count, latest_reply_created_at)`` for a thread root.

        Uses one aggregate query instead of materializing every reply.
        """
        reply_count, latest_reply_at = (
            db.query(func.count(Message.id), func.max(Message.created_at))
            .filter(Message.parent_id == parent_id)
            .one()
        )
        return reply_count or 0, latest_reply_at

    # ------------------------------------------------------------------ #
    # Messages
    # ------------------------------------------------------------------ #

    def insert_new_message(
        self,
        form_data: MessageForm,
        channel_id: str,
        user_id: str,
        db: Optional[Session] = None,
    ) -> Optional[MessageModel]:
        """Create a message in ``channel_id`` authored by ``user_id``.

        The new message starts unpinned with ``created_at == updated_at``.
        ``form_data.temp_id`` is not stored.

        Returns:
            The persisted message as a ``MessageModel``.
        """
        with get_db_context(db) as db:
            # Presumably ensures the author is a channel member before posting.
            # NOTE(review): called without ``db``, so it may use its own
            # session/transaction — confirm that is intended.
            Channels.join_channel(channel_id, user_id)

            now = time.time_ns()
            message = MessageModel(
                id=str(uuid.uuid4()),
                user_id=user_id,
                channel_id=channel_id,
                reply_to_id=form_data.reply_to_id,
                parent_id=form_data.parent_id,
                is_pinned=False,
                pinned_at=None,
                pinned_by=None,
                content=form_data.content,
                data=form_data.data,
                meta=form_data.meta,
                created_at=now,
                updated_at=now,
            )
            row = Message(**message.model_dump())
            db.add(row)
            db.commit()
            db.refresh(row)
            return MessageModel.model_validate(row)

    def get_message_by_id(
        self,
        id: str,
        include_thread_replies: Optional[bool] = True,
        db: Optional[Session] = None,
    ) -> Optional[MessageResponse]:
        """Load one message with its author, reply target, reactions and thread stats.

        Args:
            id: Message id.
            include_thread_replies: When falsy, ``reply_count`` is 0 and
                ``latest_reply_at`` is None instead of being computed.
            db: Optional session to reuse.

        Returns:
            A ``MessageResponse``, or None when no message has this id.
        """
        with get_db_context(db) as db:
            message = db.get(Message, id)
            if not message:
                return None

            reply_to_message = self._get_reply_to_message(message, db)
            reactions = self.get_reactions_by_message_id(id, db=db)

            reply_count, latest_reply_at = 0, None
            if include_thread_replies:
                reply_count, latest_reply_at = self._get_thread_stats(id, db)

            return MessageResponse.model_validate(
                {
                    **MessageModel.model_validate(message).model_dump(),
                    "user": self._get_author_info(message, db),
                    "reply_to_message": reply_to_message,
                    "latest_reply_at": latest_reply_at,
                    "reply_count": reply_count,
                    "reactions": reactions,
                }
            )

    def get_thread_replies_by_message_id(
        self, id: str, db: Optional[Session] = None
    ) -> list[MessageReplyToResponse]:
        """Return every reply whose ``parent_id`` is ``id``, newest first."""
        with get_db_context(db) as db:
            replies = (
                db.query(Message)
                .filter_by(parent_id=id)
                .order_by(Message.created_at.desc())
                .all()
            )
            return [self._to_reply_to_response(reply, db) for reply in replies]

    def get_reply_user_ids_by_message_id(
        self, id: str, db: Optional[Session] = None
    ) -> list[str]:
        """Return the author id of every reply to ``id`` (one entry per reply)."""
        with get_db_context(db) as db:
            rows = db.query(Message.user_id).filter(Message.parent_id == id).all()
            return [row.user_id for row in rows]

    def get_messages_by_channel_id(
        self,
        channel_id: str,
        skip: int = 0,
        limit: int = 50,
        db: Optional[Session] = None,
    ) -> list[MessageReplyToResponse]:
        """Return a page of top-level (non-thread) messages, newest first."""
        with get_db_context(db) as db:
            rows = (
                db.query(Message)
                .filter_by(channel_id=channel_id, parent_id=None)
                .order_by(Message.created_at.desc())
                .offset(skip)
                .limit(limit)
                .all()
            )
            return [self._to_reply_to_response(row, db) for row in rows]

    def get_messages_by_parent_id(
        self,
        channel_id: str,
        parent_id: str,
        skip: int = 0,
        limit: int = 50,
        db: Optional[Session] = None,
    ) -> list[MessageReplyToResponse]:
        """Return a page of a thread's replies, newest first.

        On the last page (fewer than ``limit`` replies) the thread root itself
        is appended as the final item. Returns [] when the root does not
        exist or does not belong to ``channel_id``. Without that check, a root
        from another channel could be returned through this channel.
        """
        with get_db_context(db) as db:
            parent = db.get(Message, parent_id)
            if not parent or parent.channel_id != channel_id:
                return []

            thread = (
                db.query(Message)
                .filter_by(channel_id=channel_id, parent_id=parent_id)
                .order_by(Message.created_at.desc())
                .offset(skip)
                .limit(limit)
                .all()
            )

            # A short page means the oldest replies were reached, so the root
            # message closes the thread listing.
            if len(thread) < limit:
                thread.append(parent)

            return [self._to_reply_to_response(row, db) for row in thread]

    def get_last_message_by_channel_id(
        self, channel_id: str, db: Optional[Session] = None
    ) -> Optional[MessageModel]:
        """Return the newest message in the channel (thread replies included)."""
        with get_db_context(db) as db:
            message = (
                db.query(Message)
                .filter_by(channel_id=channel_id)
                .order_by(Message.created_at.desc())
                .first()
            )
            return MessageModel.model_validate(message) if message else None

    def get_pinned_messages_by_channel_id(
        self,
        channel_id: str,
        skip: int = 0,
        limit: int = 50,
        db: Optional[Session] = None,
    ) -> list[MessageModel]:
        """Return a page of pinned messages, most recently pinned first."""
        with get_db_context(db) as db:
            rows = (
                db.query(Message)
                .filter_by(channel_id=channel_id, is_pinned=True)
                .order_by(Message.pinned_at.desc())
                .offset(skip)
                .limit(limit)
                .all()
            )
            return [MessageModel.model_validate(row) for row in rows]

    def update_message_by_id(
        self, id: str, form_data: MessageForm, db: Optional[Session] = None
    ) -> Optional[MessageModel]:
        """Replace the content and shallow-merge ``data``/``meta`` into a message.

        Keys in ``form_data.data``/``form_data.meta`` override existing keys.
        Keys not mentioned are kept.

        Returns:
            The updated message, or None when no message has this id.
        """
        with get_db_context(db) as db:
            message = db.get(Message, id)
            if message is None:
                return None

            message.content = form_data.content
            # Assign fresh dicts so SQLAlchemy sees the JSON columns as changed.
            message.data = {**(message.data or {}), **(form_data.data or {})}
            message.meta = {**(message.meta or {}), **(form_data.meta or {})}
            message.updated_at = time.time_ns()
            db.commit()
            db.refresh(message)
            return MessageModel.model_validate(message)

    def update_is_pinned_by_id(
        self,
        id: str,
        is_pinned: bool,
        pinned_by: Optional[str] = None,
        db: Optional[Session] = None,
    ) -> Optional[MessageModel]:
        """Pin or unpin a message.

        Pinning records the current time and ``pinned_by``. Unpinning clears
        both fields.

        Returns:
            The updated message, or None when no message has this id.
        """
        with get_db_context(db) as db:
            message = db.get(Message, id)
            if message is None:
                return None

            message.is_pinned = is_pinned
            message.pinned_at = time.time_ns() if is_pinned else None
            message.pinned_by = pinned_by if is_pinned else None
            db.commit()
            db.refresh(message)
            return MessageModel.model_validate(message)

    def get_unread_message_count(
        self,
        channel_id: str,
        user_id: str,
        last_read_at: Optional[int] = None,
        db: Optional[Session] = None,
    ) -> int:
        """Count top-level messages newer than ``last_read_at``.

        A falsy ``last_read_at`` counts from 0. When ``user_id`` is truthy, the
        user's own messages are excluded.
        """
        with get_db_context(db) as db:
            query = db.query(Message).filter(
                Message.channel_id == channel_id,
                Message.parent_id.is_(None),
                Message.created_at > (last_read_at or 0),
            )
            if user_id:
                query = query.filter(Message.user_id != user_id)
            return query.count()

    # ------------------------------------------------------------------ #
    # Reactions
    # ------------------------------------------------------------------ #

    def add_reaction_to_message(
        self, id: str, user_id: str, name: str, db: Optional[Session] = None
    ) -> Optional[MessageReactionModel]:
        """Add reaction ``name`` by ``user_id`` to message ``id``.

        If the same user already has the same reaction on the message, the
        existing row is returned and nothing is inserted.
        """
        with get_db_context(db) as db:
            existing = (
                db.query(MessageReaction)
                .filter_by(message_id=id, user_id=user_id, name=name)
                .first()
            )
            if existing:
                return MessageReactionModel.model_validate(existing)

            reaction = MessageReactionModel(
                id=str(uuid.uuid4()),
                user_id=user_id,
                message_id=id,
                name=name,
                created_at=time.time_ns(),
            )
            row = MessageReaction(**reaction.model_dump())
            db.add(row)
            db.commit()
            db.refresh(row)
            return MessageReactionModel.model_validate(row)

    def get_reactions_by_message_id(
        self, id: str, db: Optional[Session] = None
    ) -> list[Reactions]:
        """Return the message's reactions grouped by name.

        Each group lists its users in query order. Reactions whose user row no
        longer exists are dropped by the inner join.
        """
        with get_db_context(db) as db:
            rows = (
                db.query(MessageReaction, User)
                .join(User, MessageReaction.user_id == User.id)
                .filter(MessageReaction.message_id == id)
                .all()
            )

            grouped: dict[str, dict] = {}
            for reaction, user in rows:
                entry = grouped.setdefault(
                    reaction.name, {"name": reaction.name, "users": [], "count": 0}
                )
                entry["users"].append({"id": user.id, "name": user.name})
                entry["count"] += 1

            return [Reactions(**entry) for entry in grouped.values()]

    def remove_reaction_by_id_and_user_id_and_name(
        self, id: str, user_id: str, name: str, db: Optional[Session] = None
    ) -> bool:
        """Delete reaction ``name`` by ``user_id`` on message ``id``; always True."""
        with get_db_context(db) as db:
            db.query(MessageReaction).filter_by(
                message_id=id, user_id=user_id, name=name
            ).delete()
            db.commit()
            return True

    def delete_reactions_by_id(self, id: str, db: Optional[Session] = None) -> bool:
        """Delete every reaction on message ``id``; always returns True."""
        with get_db_context(db) as db:
            db.query(MessageReaction).filter_by(message_id=id).delete()
            db.commit()
            return True

    # ------------------------------------------------------------------ #
    # Deletion and search
    # ------------------------------------------------------------------ #

    def delete_replies_by_id(self, id: str, db: Optional[Session] = None) -> bool:
        """Delete all replies in the thread rooted at ``id`` and their reactions.

        The root message itself is kept. Reactions on the deleted replies are
        removed as well, so they are not left orphaned. Always returns True.
        """
        with get_db_context(db) as db:
            reply_ids = [
                row.id
                for row in db.query(Message.id).filter(Message.parent_id == id).all()
            ]
            if reply_ids:
                db.query(MessageReaction).filter(
                    MessageReaction.message_id.in_(reply_ids)
                ).delete(synchronize_session=False)
            db.query(Message).filter_by(parent_id=id).delete()
            db.commit()
            return True

    def delete_message_by_id(self, id: str, db: Optional[Session] = None) -> bool:
        """Delete message ``id`` and its reactions (not its replies); always True."""
        with get_db_context(db) as db:
            db.query(Message).filter_by(id=id).delete()
            # Remove the deleted message's reactions in the same commit.
            db.query(MessageReaction).filter_by(message_id=id).delete()
            db.commit()
            return True

    def search_messages_by_channel_ids(
        self,
        channel_ids: list[str],
        query: str,
        start_timestamp: Optional[int] = None,
        end_timestamp: Optional[int] = None,
        limit: int = 10,
        db: Optional[Session] = None,
    ) -> list[MessageModel]:
        """Search messages in specified channels by content.

        The search is a case-insensitive substring match on ``content``. LIKE
        wildcards (``%``, ``_``) in ``query`` are escaped, so they match
        literally. ``start_timestamp``/``end_timestamp`` bound ``created_at``
        inclusively and are ignored when falsy. Results are newest first and
        capped at ``limit``.
        """
        with get_db_context(db) as db:
            # Escape the escape character first, then the LIKE wildcards.
            escaped = query.replace("/", "//").replace("%", "/%").replace("_", "/_")
            query_builder = db.query(Message).filter(
                Message.channel_id.in_(channel_ids),
                Message.content.ilike(f"%{escaped}%", escape="/"),
            )

            if start_timestamp:
                query_builder = query_builder.filter(
                    Message.created_at >= start_timestamp
                )
            if end_timestamp:
                query_builder = query_builder.filter(
                    Message.created_at <= end_timestamp
                )

            rows = query_builder.order_by(Message.created_at.desc()).limit(limit).all()
            return [MessageModel.model_validate(row) for row in rows]
| |
|
| |
|
# Shared module-level MessageTable instance.
Messages = MessageTable()
| |
|