Spaces:
Runtime error
Runtime error
| from sqlalchemy.orm import Session, joinedload | |
| import models, schemas | |
| from password_utils import get_password_hash | |
| from typing import List, Optional | |
| # --- User CRUD --- | |
def get_user_by_email(db: Session, email: str):
    """Look up a user by e-mail address.

    Returns the first ``models.User`` row whose ``email`` column equals
    *email*, or ``None`` when no row matches.
    """
    user_query = db.query(models.User)
    matching = user_query.filter(models.User.email == email)
    return matching.first()
def create_user(db: Session, user: schemas.UserCreate):
    """Persist a new user built from a ``schemas.UserCreate`` payload.

    The plaintext ``user.password`` is hashed with ``get_password_hash`` and
    only the hash is stored on the ``models.User`` row. The row is committed
    and then refreshed before being returned.
    """
    password_digest = get_password_hash(user.password)
    new_user = models.User(
        email=user.email,
        hashed_password=password_digest,
        display_name=user.display_name,
        favorite_team=user.favorite_team,
        banter_level=user.banter_level,
    )
    db.add(new_user)
    db.commit()
    # Reload the committed row so database-populated columns are visible
    # on the returned instance.
    db.refresh(new_user)
    return new_user
def update_user_settings(db: Session, user: models.User, settings: schemas.UserSettingsUpdate) -> models.User:
    """Overwrite a user's preference fields and persist them.

    Copies ``sound_effects_enabled``, ``favorite_team`` and ``banter_level``
    from *settings* onto *user*. Every field is assigned unconditionally, so
    a ``None`` value in *settings* is written as-is. The session is then
    committed, *user* is refreshed, and the same instance is returned.
    """
    for field_name in ("sound_effects_enabled", "favorite_team", "banter_level"):
        setattr(user, field_name, getattr(settings, field_name))
    db.commit()
    db.refresh(user)
    return user
| # --- Conversation CRUD --- | |
def get_or_create_conversation(db: Session, user_id: int) -> models.Conversation:
    """Return the user's newest conversation, creating one if none exists.

    "Newest" means the highest ``Conversation.created_at`` among rows whose
    ``owner_id`` equals *user_id*. A newly created conversation is committed
    and refreshed before it is returned.
    """
    latest = (
        db.query(models.Conversation)
        .filter(models.Conversation.owner_id == user_id)
        .order_by(models.Conversation.created_at.desc())
        .first()
    )
    if latest:
        return latest

    fresh = models.Conversation(owner_id=user_id)
    db.add(fresh)
    db.commit()
    db.refresh(fresh)
    return fresh
def add_message_to_conversation(db: Session, conversation_id: int, role: str, content: str) -> models.Message:
    """Append a message to a conversation and return the persisted row.

    Args:
        db: Active session; the new row is committed immediately.
        conversation_id: Value stored in ``Message.conversation_id``.
        role: Value stored in ``Message.role``.
        content: Value stored in ``Message.content``.

    Returns:
        The committed and refreshed ``models.Message`` instance.
    """
    new_message = models.Message(
        conversation_id=conversation_id,
        role=role,
        content=content,
    )
    db.add(new_message)
    db.commit()
    db.refresh(new_message)
    return new_message
def get_conversation_history(db: Session, conversation_id: int) -> List[models.Message]:
    """Retrieve all messages for a given conversation, oldest first.

    Messages whose ``conversation_id`` matches are ordered by
    ``Message.created_at`` ascending. An empty list is returned when the
    conversation has no messages.
    """
    # typing.List (already imported by this module) instead of builtin
    # ``list[...]``: consistent with the rest of the file and does not
    # require Python 3.9+ at definition time.
    return (
        db.query(models.Message)
        .filter(models.Message.conversation_id == conversation_id)
        .order_by(models.Message.created_at)
        .all()
    )
def get_conversations_for_user(
    db,
    user_id: int,
    skip: int = 0,
    limit: int = 20
) -> list:
    """Return one page of a user's conversations, newest first.

    Args:
        db: Active SQLAlchemy session.
        user_id: Matched against ``Conversation.owner_id``.
        skip: Number of rows to skip before the page starts (offset).
        limit: Maximum number of rows in the page (default 20).

    Returns:
        Conversations ordered by ``created_at`` descending.
    """
    owned = db.query(models.Conversation).filter(
        models.Conversation.owner_id == user_id
    )
    newest_first = owned.order_by(models.Conversation.created_at.desc())
    page = newest_first.offset(skip).limit(limit)
    return page.all()
def count_conversations_for_user(db, user_id: int) -> int:
    """Count every conversation owned by *user_id*.

    This is the unpaginated total, intended for pagination info.
    """
    owned = db.query(models.Conversation).filter(models.Conversation.owner_id == user_id)
    return owned.count()
def get_conversation_by_id(db, message_id: int):
    """Return the ``models.Message`` whose ``id`` equals *message_id*, or ``None``.

    NOTE(review): despite the function name, this queries the ``Message``
    table, not ``Conversation``. The behaviour is kept as-is because callers
    may depend on it. Confirm intent against the call sites, then either
    rename this to a message lookup or make it query ``models.Conversation``.
    """
    lookup = db.query(models.Message).filter(models.Message.id == message_id)
    return lookup.first()
| # --- Global Message CRUD --- | |
def create_global_message(db: Session, author_id: int, content: str) -> models.GlobalMessage:
    """Post *content* to the global chat on behalf of *author_id*.

    The new ``models.GlobalMessage`` row is committed and refreshed before
    it is returned.
    """
    posted = models.GlobalMessage(author_id=author_id, content=content)
    db.add(posted)
    db.commit()
    db.refresh(posted)
    return posted
def get_global_messages(db: Session, skip: int = 0, limit: int = 100) -> List[models.GlobalMessage]:
    """Return a page of global-chat messages, newest first.

    Each message's ``author`` relationship is eager-loaded with
    ``joinedload``. Results are ordered by ``created_at`` descending, then
    *skip* rows are skipped and at most *limit* rows are returned.
    """
    newest_first = models.GlobalMessage.created_at.desc()
    eager_author = joinedload(models.GlobalMessage.author)
    query = db.query(models.GlobalMessage).options(eager_author).order_by(newest_first)
    return query.offset(skip).limit(limit).all()
| # --- Reaction CRUD --- | |
def add_or_remove_reaction(db: Session, user_id: int, message_id: int, emoji: str) -> models.GlobalMessage:
    """Toggle *user_id*'s *emoji* reaction on a global message.

    If a ``MessageReaction`` with the same (message_id, user_id, emoji)
    already exists, it is deleted; otherwise a new one is created. The
    change is committed. The ``models.GlobalMessage`` with id *message_id*
    is then re-queried and returned (``None`` if no such message exists).
    """
    reaction_criteria = (
        models.MessageReaction.message_id == message_id,
        models.MessageReaction.user_id == user_id,
        models.MessageReaction.emoji == emoji,
    )
    found = db.query(models.MessageReaction).filter(*reaction_criteria).first()

    if not found:
        # No matching reaction yet: toggling adds it.
        db.add(
            models.MessageReaction(
                message_id=message_id,
                user_id=user_id,
                emoji=emoji,
            )
        )
    else:
        # Reaction already present: toggling removes it.
        db.delete(found)
    db.commit()

    # Re-query so the caller receives the message as it stands after the commit.
    refreshed = db.query(models.GlobalMessage).filter(models.GlobalMessage.id == message_id)
    return refreshed.first()
def get_user_count(db):
    """Return how many ``models.User`` rows exist."""
    users = db.query(models.User)
    return users.count()
def get_conversation_count(db):
    """Return how many ``models.Conversation`` rows exist, across all users."""
    conversations = db.query(models.Conversation)
    return conversations.count()
def get_message_count(db):
    """Return how many ``models.Message`` rows exist, across all conversations."""
    messages = db.query(models.Message)
    return messages.count()
def get_average_messages_per_user(db) -> float:
    """Return the mean number of ``Message`` rows per ``User`` row.

    Computed as total messages divided by total users, so users without any
    messages still count in the denominator. Returns ``0.0`` when there are
    no users, which avoids a division by zero.
    """
    total_users = get_user_count(db)
    if not total_users:
        # Always return a float: the original returned int 0 here but a
        # float from the division below, giving callers a mixed return type.
        return 0.0
    total_messages = get_message_count(db)
    return total_messages / total_users