Spaces:
Build error
Build error
| import json | |
| import time | |
| import uuid | |
| from typing import Any, Optional | |
| from sqlalchemy import select, delete, func, cast, Integer | |
| from sqlalchemy.ext.asyncio import AsyncSession | |
| from open_webui.internal.db import Base, get_async_db_context | |
| from open_webui.utils.response import normalize_usage | |
| from pydantic import BaseModel, ConfigDict | |
| from sqlalchemy import ( | |
| BigInteger, | |
| Boolean, | |
| Column, | |
| ForeignKey, | |
| Text, | |
| JSON, | |
| Index, | |
| ) | |
| #################### | |
| # Helpers | |
| #################### | |
| def _normalize_timestamp(timestamp: int) -> float: | |
| """Normalize and validate timestamp. Returns current time if invalid.""" | |
| now = time.time() | |
| # Convert milliseconds to seconds if needed | |
| if timestamp > 10_000_000_000: | |
| timestamp = timestamp / 1000 | |
| # Validate: must be after 2020 and not in the future (with 1 day tolerance) | |
| min_valid = 1577836800 # 2020-01-01 00:00:00 UTC | |
| max_valid = now + 86400 # 1 day in the future (clock skew tolerance) | |
| if timestamp < min_valid or timestamp > max_valid: | |
| return now | |
| return timestamp | |
def get_usage(data: dict) -> Optional[dict]:
    """Return the normalized usage found in message data, or None if there is none.

    Looks at ``data['usage']`` first and falls back to ``data['info']['usage']``.
    """
    raw_usage = data.get('usage')
    if not raw_usage:
        info = data.get('info') or {}
        raw_usage = info.get('usage')

    if not raw_usage:
        return None
    return normalize_usage(raw_usage)
| #################### | |
| # ChatMessage DB Schema | |
| #################### | |
class ChatMessage(Base):
    """ORM mapping for the ``chat_message`` table."""

    __tablename__ = 'chat_message'

    # Composite indexes declared in addition to the per-column ones below
    __table_args__ = (
        Index('chat_message_chat_parent_idx', 'chat_id', 'parent_id'),
        Index('chat_message_model_created_idx', 'model_id', 'created_at'),
        Index('chat_message_user_created_idx', 'user_id', 'created_at'),
    )

    # --- Identity ---
    id = Column(Text, primary_key=True)
    chat_id = Column(
        Text,
        ForeignKey('chat.id', ondelete='CASCADE'),
        nullable=False,
        index=True,
    )
    user_id = Column(Text, index=True)

    # --- Structure ---
    role = Column(Text, nullable=False)  # user, assistant, system
    parent_id = Column(Text, nullable=True)

    # --- Content ---
    content = Column(JSON, nullable=True)  # either a str or a list of blocks
    output = Column(JSON, nullable=True)

    # --- Model (set for assistant messages) ---
    model_id = Column(Text, nullable=True, index=True)

    # --- Attachments ---
    files = Column(JSON, nullable=True)
    sources = Column(JSON, nullable=True)
    embeds = Column(JSON, nullable=True)

    # --- Status ---
    done = Column(Boolean, default=True)
    status_history = Column(JSON, nullable=True)
    error = Column(JSON, nullable=True)

    # --- Usage (tokens, timing, etc.) ---
    usage = Column(JSON, nullable=True)

    # --- Timestamps ---
    created_at = Column(BigInteger, index=True)
    updated_at = Column(BigInteger)
| #################### | |
| # Pydantic Models | |
| #################### | |
class ChatMessageModel(BaseModel):
    """Pydantic view of a ``ChatMessage`` row (built via ``from_attributes``)."""

    model_config = ConfigDict(from_attributes=True)

    # Identity
    id: str
    chat_id: str
    user_id: str

    # Structure
    role: str
    parent_id: str | None = None

    # Content
    content: Any | None = None  # str or list of blocks
    output: list | None = None

    # Model
    model_id: str | None = None

    # Attachments
    files: list | None = None
    sources: list | None = None
    embeds: list | None = None

    # Status
    done: bool = True
    status_history: list | None = None
    error: dict | str | None = None

    # Usage
    usage: dict | None = None

    # Timestamps
    created_at: int
    updated_at: int
| #################### | |
| # Table Operations | |
| #################### | |
| class ChatMessageTable: | |
async def upsert_message(
    self,
    message_id: str,
    chat_id: str,
    user_id: str,
    data: dict,
    db: Optional[AsyncSession] = None,
) -> Optional[ChatMessageModel]:
    """Insert a chat message, or update it when the row already exists.

    Rows are keyed by the composite id ``{chat_id}-{message_id}``. On update,
    only fields whose keys are present in ``data`` are overwritten. On insert,
    absent fields are stored as ``None`` (``role`` defaults to ``'user'`` and
    ``done`` to ``True``).

    Args:
        message_id: Message id; combined with ``chat_id`` to form the primary key.
        chat_id: Id of the chat the message belongs to.
        user_id: Stored on newly inserted rows; not touched on update.
        data: Message payload. ``parent_id``/``parentId``, ``model_id``/``model``
            and ``status_history``/``statusHistory`` are accepted as aliases;
            usage is read from ``usage`` or ``info.usage``.
        db: Optional session handed to ``get_async_db_context``.

    Returns:
        The stored row validated as a ``ChatMessageModel``.
    """
    async with get_async_db_context(db) as db:
        now = int(time.time())
        # A missing *or* explicitly null timestamp falls back to "now", so
        # created_at is never written as NULL (ChatMessageModel requires an int).
        timestamp = data.get('timestamp') or now

        # Use composite ID: {chat_id}-{message_id}
        composite_id = f'{chat_id}-{message_id}'

        # Extract and normalize usage once; both branches need it
        usage = get_usage(data)

        existing = await db.get(ChatMessage, composite_id)
        if existing:
            # Plain fields: overwrite only when the key is present in the payload
            for field in ('role', 'content', 'output', 'files', 'sources', 'embeds', 'done', 'error'):
                if field in data:
                    setattr(existing, field, data[field])

            # Aliased fields: either spelling triggers the update, matching
            # what the insert branch accepts.
            if 'parent_id' in data or 'parentId' in data:
                existing.parent_id = data.get('parent_id') or data.get('parentId')
            if 'model_id' in data or 'model' in data:
                existing.model_id = data.get('model_id') or data.get('model')
            if 'status_history' in data or 'statusHistory' in data:
                existing.status_history = data.get('status_history') or data.get('statusHistory')

            if usage:
                # Shallow merge: top-level keys already stored but absent from
                # the new usage are preserved. This prevents background tasks
                # (follow-ups, title, tags) from accidentally clearing the
                # primary response's token counts.
                existing.usage = {**(existing.usage or {}), **usage}

            existing.updated_at = now
            await db.commit()
            await db.refresh(existing)
            return ChatMessageModel.model_validate(existing)

        # Insert new
        message = ChatMessage(
            id=composite_id,
            chat_id=chat_id,
            user_id=user_id,
            role=data.get('role', 'user'),
            parent_id=data.get('parent_id') or data.get('parentId'),
            content=data.get('content'),
            output=data.get('output'),
            model_id=data.get('model_id') or data.get('model'),
            files=data.get('files'),
            sources=data.get('sources'),
            embeds=data.get('embeds'),
            done=data.get('done', True),
            status_history=data.get('status_history') or data.get('statusHistory'),
            error=data.get('error'),
            usage=usage,
            created_at=timestamp,
            updated_at=now,
        )
        db.add(message)
        await db.commit()
        await db.refresh(message)
        return ChatMessageModel.model_validate(message)
async def get_message_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[ChatMessageModel]:
    """Fetch one message by primary key; returns None when no row matches."""
    async with get_async_db_context(db) as db:
        row = await db.get(ChatMessage, id)
        if not row:
            return None
        return ChatMessageModel.model_validate(row)
async def get_messages_by_chat_id(self, chat_id: str, db: Optional[AsyncSession] = None) -> list[ChatMessageModel]:
    """Return every message of a chat, oldest first (ordered by ``created_at``)."""
    async with get_async_db_context(db) as db:
        stmt = select(ChatMessage).filter(ChatMessage.chat_id == chat_id)
        stmt = stmt.order_by(ChatMessage.created_at.asc())
        rows = (await db.execute(stmt)).scalars().all()
        return [ChatMessageModel.model_validate(row) for row in rows]
async def get_messages_by_user_id(
    self,
    user_id: str,
    skip: int = 0,
    limit: int = 50,
    db: Optional[AsyncSession] = None,
) -> list[ChatMessageModel]:
    """Return one page of a user's messages, newest first.

    Args:
        user_id: Owner of the messages.
        skip: Number of rows to skip (SQL OFFSET).
        limit: Maximum number of rows to return.
        db: Optional session handed to ``get_async_db_context``.
    """
    async with get_async_db_context(db) as db:
        stmt = select(ChatMessage).filter(ChatMessage.user_id == user_id)
        stmt = stmt.order_by(ChatMessage.created_at.desc())
        stmt = stmt.offset(skip).limit(limit)
        rows = (await db.execute(stmt)).scalars().all()
        return [ChatMessageModel.model_validate(row) for row in rows]
async def get_messages_by_model_id(
    self,
    model_id: str,
    start_date: Optional[int] = None,
    end_date: Optional[int] = None,
    skip: int = 0,
    limit: int = 100,
    db: Optional[AsyncSession] = None,
) -> list[ChatMessageModel]:
    """Return one page of messages for a model, newest first.

    ``start_date`` / ``end_date`` are inclusive bounds on ``created_at``; a
    falsy bound (``None`` or ``0``) is not applied.
    """
    async with get_async_db_context(db) as db:
        conditions = [ChatMessage.model_id == model_id]
        if start_date:
            conditions.append(ChatMessage.created_at >= start_date)
        if end_date:
            conditions.append(ChatMessage.created_at <= end_date)

        stmt = (
            select(ChatMessage)
            .filter(*conditions)
            .order_by(ChatMessage.created_at.desc())
            .offset(skip)
            .limit(limit)
        )
        rows = (await db.execute(stmt)).scalars().all()
        return [ChatMessageModel.model_validate(row) for row in rows]
async def get_chat_ids_by_model_id(
    self,
    model_id: str,
    start_date: Optional[int] = None,
    end_date: Optional[int] = None,
    skip: int = 0,
    limit: int = 50,
    db: Optional[AsyncSession] = None,
) -> list[str]:
    """Return one page of distinct chat ids whose messages used the given model.

    Chats are ordered by their most recent matching message (newest first),
    with ``chat_id`` as a tie-breaker so pagination is deterministic.
    """
    async with get_async_db_context(db) as db:
        conditions = [ChatMessage.model_id == model_id]
        if start_date:
            conditions.append(ChatMessage.created_at >= start_date)
        if end_date:
            conditions.append(ChatMessage.created_at <= end_date)

        stmt = (
            select(
                ChatMessage.chat_id,
                func.max(ChatMessage.created_at).label('last_message_at'),
            )
            .filter(*conditions)
            .group_by(ChatMessage.chat_id)
            .order_by(func.max(ChatMessage.created_at).desc(), ChatMessage.chat_id)
            .offset(skip)
            .limit(limit)
        )
        rows = (await db.execute(stmt)).all()
        # Each row is (chat_id, last_message_at); only the id is returned
        return [row[0] for row in rows]
async def delete_messages_by_chat_id(self, chat_id: str, db: Optional[AsyncSession] = None) -> bool:
    """Delete every message belonging to a chat and commit. Always returns True."""
    async with get_async_db_context(db) as db:
        stmt = delete(ChatMessage).filter_by(chat_id=chat_id)
        await db.execute(stmt)
        await db.commit()
        return True
| # Analytics methods | |
async def get_message_count_by_model(
    self,
    start_date: Optional[int] = None,
    end_date: Optional[int] = None,
    group_id: Optional[str] = None,
    db: Optional[AsyncSession] = None,
) -> dict[str, int]:
    """Count assistant messages per model.

    Only rows with a non-null ``model_id`` are counted. Optional inclusive
    ``created_at`` bounds and a group filter (messages whose user is a member
    of ``group_id``) narrow the set.

    Returns:
        Mapping of ``model_id`` to message count.
    """
    async with get_async_db_context(db) as db:
        from open_webui.models.groups import GroupMember

        conditions = [
            ChatMessage.role == 'assistant',
            ChatMessage.model_id.isnot(None),
        ]
        if start_date:
            conditions.append(ChatMessage.created_at >= start_date)
        if end_date:
            conditions.append(ChatMessage.created_at <= end_date)
        if group_id:
            member_ids = select(GroupMember.user_id).filter(GroupMember.group_id == group_id).scalar_subquery()
            conditions.append(ChatMessage.user_id.in_(member_ids))

        stmt = (
            select(ChatMessage.model_id, func.count(ChatMessage.id).label('count'))
            .filter(*conditions)
            .group_by(ChatMessage.model_id)
        )
        rows = (await db.execute(stmt)).all()
        return {model_id: total for model_id, total in rows}
async def get_token_usage_by_model(
    self,
    start_date: Optional[int] = None,
    end_date: Optional[int] = None,
    group_id: Optional[str] = None,
    db: Optional[AsyncSession] = None,
) -> dict[str, dict]:
    """Sum token usage per model with a single database-side aggregation.

    Considers assistant messages that have a non-null ``model_id`` and
    ``usage``, optionally bounded by ``created_at`` and restricted to members
    of ``group_id``.

    Returns:
        Mapping of ``model_id`` to a dict with ``input_tokens``,
        ``output_tokens``, ``total_tokens`` and ``message_count``.

    Raises:
        NotImplementedError: If the database dialect is neither SQLite nor
            PostgreSQL.
    """
    async with get_async_db_context(db) as db:
        from open_webui.models.groups import GroupMember

        # JSON extraction syntax is backend-specific, so look up the dialect
        # of the session's connection first.
        connection = await db.connection()
        dialect = connection.dialect.name

        if dialect == 'sqlite':

            def usage_field(key):
                return func.json_extract(ChatMessage.usage, f'$.{key}')

        elif dialect == 'postgresql':

            def usage_field(key):
                return func.json_extract_path_text(ChatMessage.usage, key)

        else:
            raise NotImplementedError(f'Unsupported dialect: {dialect}')

        input_tokens = cast(usage_field('input_tokens'), Integer)
        output_tokens = cast(usage_field('output_tokens'), Integer)

        conditions = [
            ChatMessage.role == 'assistant',
            ChatMessage.model_id.isnot(None),
            ChatMessage.usage.isnot(None),
        ]
        if start_date:
            conditions.append(ChatMessage.created_at >= start_date)
        if end_date:
            conditions.append(ChatMessage.created_at <= end_date)
        if group_id:
            member_ids = select(GroupMember.user_id).filter(GroupMember.group_id == group_id).scalar_subquery()
            conditions.append(ChatMessage.user_id.in_(member_ids))

        stmt = (
            select(
                ChatMessage.model_id,
                func.coalesce(func.sum(input_tokens), 0).label('input_tokens'),
                func.coalesce(func.sum(output_tokens), 0).label('output_tokens'),
                func.count(ChatMessage.id).label('message_count'),
            )
            .filter(*conditions)
            .group_by(ChatMessage.model_id)
        )
        rows = (await db.execute(stmt)).all()

        usage_by_model = {}
        for model_id, tokens_in, tokens_out, message_count in rows:
            usage_by_model[model_id] = {
                'input_tokens': tokens_in,
                'output_tokens': tokens_out,
                'total_tokens': tokens_in + tokens_out,
                'message_count': message_count,
            }
        return usage_by_model
async def get_token_usage_by_user(
    self,
    start_date: Optional[int] = None,
    end_date: Optional[int] = None,
    group_id: Optional[str] = None,
    db: Optional[AsyncSession] = None,
) -> dict[str, dict]:
    """Sum token usage per user with a single database-side aggregation.

    Considers assistant messages that have a non-null ``user_id`` and
    ``usage``, optionally bounded by ``created_at`` and restricted to members
    of ``group_id``.

    Returns:
        Mapping of ``user_id`` to a dict with ``input_tokens``,
        ``output_tokens``, ``total_tokens`` and ``message_count``.

    Raises:
        NotImplementedError: If the database dialect is neither SQLite nor
            PostgreSQL.
    """
    async with get_async_db_context(db) as db:
        from open_webui.models.groups import GroupMember

        # JSON extraction syntax is backend-specific, so look up the dialect
        # of the session's connection first.
        connection = await db.connection()
        dialect = connection.dialect.name

        if dialect == 'sqlite':

            def usage_field(key):
                return func.json_extract(ChatMessage.usage, f'$.{key}')

        elif dialect == 'postgresql':

            def usage_field(key):
                return func.json_extract_path_text(ChatMessage.usage, key)

        else:
            raise NotImplementedError(f'Unsupported dialect: {dialect}')

        input_tokens = cast(usage_field('input_tokens'), Integer)
        output_tokens = cast(usage_field('output_tokens'), Integer)

        conditions = [
            ChatMessage.role == 'assistant',
            ChatMessage.user_id.isnot(None),
            ChatMessage.usage.isnot(None),
        ]
        if start_date:
            conditions.append(ChatMessage.created_at >= start_date)
        if end_date:
            conditions.append(ChatMessage.created_at <= end_date)
        if group_id:
            member_ids = select(GroupMember.user_id).filter(GroupMember.group_id == group_id).scalar_subquery()
            conditions.append(ChatMessage.user_id.in_(member_ids))

        stmt = (
            select(
                ChatMessage.user_id,
                func.coalesce(func.sum(input_tokens), 0).label('input_tokens'),
                func.coalesce(func.sum(output_tokens), 0).label('output_tokens'),
                func.count(ChatMessage.id).label('message_count'),
            )
            .filter(*conditions)
            .group_by(ChatMessage.user_id)
        )
        rows = (await db.execute(stmt)).all()

        usage_by_user = {}
        for user_id, tokens_in, tokens_out, message_count in rows:
            usage_by_user[user_id] = {
                'input_tokens': tokens_in,
                'output_tokens': tokens_out,
                'total_tokens': tokens_in + tokens_out,
                'message_count': message_count,
            }
        return usage_by_user
async def get_message_count_by_user(
    self,
    start_date: Optional[int] = None,
    end_date: Optional[int] = None,
    group_id: Optional[str] = None,
    db: Optional[AsyncSession] = None,
) -> dict[str, int]:
    """Count assistant messages per user.

    Rows without a ``user_id`` are excluded so the result never contains a
    ``None`` key. Optional inclusive ``created_at`` bounds and a group filter
    (messages whose user is a member of ``group_id``) narrow the set.

    Returns:
        Mapping of ``user_id`` to message count.
    """
    async with get_async_db_context(db) as db:
        from open_webui.models.groups import GroupMember

        stmt = select(ChatMessage.user_id, func.count(ChatMessage.id).label('count')).filter(
            ChatMessage.role == 'assistant',
            # user_id is a nullable column; skip NULLs to honour dict[str, int]
            ChatMessage.user_id.isnot(None),
        )
        if start_date:
            stmt = stmt.filter(ChatMessage.created_at >= start_date)
        if end_date:
            stmt = stmt.filter(ChatMessage.created_at <= end_date)
        if group_id:
            group_users = select(GroupMember.user_id).filter(GroupMember.group_id == group_id).scalar_subquery()
            stmt = stmt.filter(ChatMessage.user_id.in_(group_users))
        stmt = stmt.group_by(ChatMessage.user_id)

        result = await db.execute(stmt)
        return {row.user_id: row.count for row in result.all()}
async def get_message_count_by_chat(
    self,
    start_date: Optional[int] = None,
    end_date: Optional[int] = None,
    group_id: Optional[str] = None,
    db: Optional[AsyncSession] = None,
) -> dict[str, int]:
    """Count assistant messages per chat.

    Optional inclusive ``created_at`` bounds and a group filter (messages
    whose user is a member of ``group_id``) narrow the set.

    Returns:
        Mapping of ``chat_id`` to message count.
    """
    async with get_async_db_context(db) as db:
        from open_webui.models.groups import GroupMember

        conditions = [ChatMessage.role == 'assistant']
        if start_date:
            conditions.append(ChatMessage.created_at >= start_date)
        if end_date:
            conditions.append(ChatMessage.created_at <= end_date)
        if group_id:
            member_ids = select(GroupMember.user_id).filter(GroupMember.group_id == group_id).scalar_subquery()
            conditions.append(ChatMessage.user_id.in_(member_ids))

        stmt = (
            select(ChatMessage.chat_id, func.count(ChatMessage.id).label('count'))
            .filter(*conditions)
            .group_by(ChatMessage.chat_id)
        )
        rows = (await db.execute(stmt)).all()
        return {chat_id: total for chat_id, total in rows}
async def get_daily_message_counts_by_model(
    self,
    start_date: Optional[int] = None,
    end_date: Optional[int] = None,
    group_id: Optional[str] = None,
    db: Optional[AsyncSession] = None,
) -> dict[str, dict[str, int]]:
    """Get assistant message counts grouped by day and model.

    Days are formatted as ``YYYY-MM-DD`` using ``datetime.fromtimestamp``
    without a timezone, i.e. in the process's local time. When both bounds
    are given, every calendar day from the start day through the end day is
    present in the result, with an empty dict for days without messages.

    Args:
        start_date: Inclusive lower bound on ``created_at``; skipped if falsy.
        end_date: Inclusive upper bound on ``created_at``; skipped if falsy.
        group_id: When set, only messages whose user is a member of the group.
        db: Optional session handed to ``get_async_db_context``.

    Returns:
        Mapping of day string to ``{model_id: count}``.
    """
    async with get_async_db_context(db) as db:
        from datetime import datetime, timedelta

        from open_webui.models.groups import GroupMember

        stmt = select(ChatMessage.created_at, ChatMessage.model_id).filter(
            ChatMessage.role == 'assistant',
            ChatMessage.model_id.isnot(None),
        )
        if start_date:
            stmt = stmt.filter(ChatMessage.created_at >= start_date)
        if end_date:
            stmt = stmt.filter(ChatMessage.created_at <= end_date)
        if group_id:
            group_users = select(GroupMember.user_id).filter(GroupMember.group_id == group_id).scalar_subquery()
            stmt = stmt.filter(ChatMessage.user_id.in_(group_users))

        result = await db.execute(stmt)

        # Group by date -> model -> count
        daily_counts: dict[str, dict[str, int]] = {}
        for timestamp, model_id in result.all():
            date_str = datetime.fromtimestamp(_normalize_timestamp(timestamp)).strftime('%Y-%m-%d')
            per_model = daily_counts.setdefault(date_str, {})
            per_model[model_id] = per_model.get(model_id, 0) + 1

        # Fill in missing days. Iterate over calendar dates rather than
        # datetimes: stepping a datetime that still carries the start's
        # time of day would skip the final day whenever the end's time of
        # day is earlier than the start's.
        if start_date and end_date:
            day = datetime.fromtimestamp(_normalize_timestamp(start_date)).date()
            last_day = datetime.fromtimestamp(_normalize_timestamp(end_date)).date()
            while day <= last_day:
                daily_counts.setdefault(day.strftime('%Y-%m-%d'), {})
                day += timedelta(days=1)

        return daily_counts
async def get_hourly_message_counts_by_model(
    self,
    start_date: Optional[int] = None,
    end_date: Optional[int] = None,
    db: Optional[AsyncSession] = None,
) -> dict[str, dict[str, int]]:
    """Get assistant message counts grouped by hour and model.

    Hours are formatted as ``YYYY-MM-DD HH:00``. When both bounds are given,
    hours in the range that have no messages are added with an empty dict.

    Returns:
        Mapping of hour string to ``{model_id: count}``.
    """
    async with get_async_db_context(db) as db:
        from datetime import datetime, timedelta

        hour_format = '%Y-%m-%d %H:00'

        conditions = [
            ChatMessage.role == 'assistant',
            ChatMessage.model_id.isnot(None),
        ]
        if start_date:
            conditions.append(ChatMessage.created_at >= start_date)
        if end_date:
            conditions.append(ChatMessage.created_at <= end_date)

        stmt = select(ChatMessage.created_at, ChatMessage.model_id).filter(*conditions)
        rows = (await db.execute(stmt)).all()

        # Group by hour -> model -> count
        hourly_counts: dict[str, dict[str, int]] = {}
        for created_at, model_id in rows:
            label = datetime.fromtimestamp(_normalize_timestamp(created_at)).strftime(hour_format)
            per_model = hourly_counts.setdefault(label, {})
            per_model[model_id] = per_model.get(model_id, 0) + 1

        # Fill in missing hours, walking from the start's hour boundary
        if start_date and end_date:
            first = datetime.fromtimestamp(_normalize_timestamp(start_date))
            cursor = first.replace(minute=0, second=0, microsecond=0)
            stop = datetime.fromtimestamp(_normalize_timestamp(end_date))
            while cursor <= stop:
                hourly_counts.setdefault(cursor.strftime(hour_format), {})
                cursor += timedelta(hours=1)

        return hourly_counts
# Module-level ChatMessageTable instance, created once at import time.
ChatMessages = ChatMessageTable()