# banter-api / crud.py
# EbukaGaus · commit d5a3ec4 ("push")
from sqlalchemy.orm import Session, joinedload
import models, schemas
from password_utils import get_password_hash
from typing import List, Optional
# --- User CRUD ---
def get_user_by_email(db: Session, email: str):
    """Look up a user by email address.

    Returns the first ``models.User`` whose ``email`` equals ``email``, or
    ``None`` when there is no match.
    """
    matching_users = db.query(models.User).filter(models.User.email == email)
    return matching_users.first()
def create_user(db: Session, user: schemas.UserCreate):
    """Persist a new user built from a ``schemas.UserCreate`` payload.

    The plaintext ``user.password`` is passed through ``get_password_hash``
    and only the result is stored in ``hashed_password``. The new row is
    added, committed, refreshed and returned.
    """
    new_user = models.User(
        email=user.email,
        hashed_password=get_password_hash(user.password),
        display_name=user.display_name,
        favorite_team=user.favorite_team,
        banter_level=user.banter_level,
    )
    db.add(new_user)
    db.commit()
    # Re-read the committed row so server-populated values are available.
    db.refresh(new_user)
    return new_user
def update_user_settings(db: Session, user: models.User, settings: schemas.UserSettingsUpdate) -> models.User:
    """Copy the user-editable preferences from ``settings`` onto ``user``.

    ``sound_effects_enabled``, ``favorite_team`` and ``banter_level`` are
    overwritten unconditionally (whatever value ``settings`` carries,
    including ``None`` if the schema allows it). The change is committed and
    the same, refreshed ``user`` instance is returned.
    """
    for field_name in ("sound_effects_enabled", "favorite_team", "banter_level"):
        setattr(user, field_name, getattr(settings, field_name))
    db.commit()
    db.refresh(user)
    return user
# --- Conversation CRUD ---
def get_or_create_conversation(db: Session, user_id: int) -> models.Conversation:
    """Return the user's newest conversation, creating one if none exist.

    "Newest" means the highest ``created_at``. When ``user_id`` owns no
    conversations yet, a new ``models.Conversation`` is committed, refreshed
    and returned instead.
    """
    latest = (
        db.query(models.Conversation)
        .filter(models.Conversation.owner_id == user_id)
        .order_by(models.Conversation.created_at.desc())
        .first()
    )
    if latest:
        return latest

    fresh = models.Conversation(owner_id=user_id)
    db.add(fresh)
    db.commit()
    db.refresh(fresh)
    return fresh
def add_message_to_conversation(db: Session, conversation_id: int, role: str, content: str) -> models.Message:
    """Append a message with ``role`` and ``content`` to a conversation.

    The new ``models.Message`` is committed immediately and returned after a
    refresh.
    """
    new_message = models.Message(
        conversation_id=conversation_id,
        role=role,
        content=content,
    )
    db.add(new_message)
    db.commit()
    db.refresh(new_message)
    return new_message
def get_conversation_history(db: Session, conversation_id: int) -> list[models.Message]:
    """Retrieve all messages for a given conversation, oldest first.

    Messages are ordered by ``created_at`` ascending. An empty list is
    returned when the conversation has no messages.
    """
    return (
        db.query(models.Message)
        .filter(models.Message.conversation_id == conversation_id)
        .order_by(models.Message.created_at)
        .all()
    )
def get_conversations_for_user(
    db: Session,
    user_id: int,
    skip: int = 0,
    limit: int = 20,
) -> List[models.Conversation]:
    """Return one page of a user's conversations, newest first.

    Args:
        db: Active database session.
        user_id: Owner whose conversations are listed.
        skip: Number of conversations to skip (page offset). Negative values
            are treated as 0.
        limit: Maximum number of conversations to return (default 20).

    Returns:
        Conversations owned by ``user_id``, ordered by ``created_at``
        descending.
    """
    # A negative OFFSET is meaningless and is rejected outright by some
    # backends (e.g. PostgreSQL), so clamp it instead of erroring.
    skip = max(skip, 0)
    return (
        db.query(models.Conversation)
        .filter(models.Conversation.owner_id == user_id)
        .order_by(models.Conversation.created_at.desc())
        .offset(skip)
        .limit(limit)
        .all()
    )
def count_conversations_for_user(db, user_id: int) -> int:
    """Count every conversation owned by ``user_id``.

    Intended as the total used alongside ``get_conversations_for_user`` to
    build pagination metadata.
    """
    owned = db.query(models.Conversation).filter(
        models.Conversation.owner_id == user_id
    )
    return owned.count()
def get_conversation_by_id(db, message_id: int):
    """Fetch a single ``models.Message`` whose ``id`` equals ``message_id``.

    Returns ``None`` when no such message exists.

    NOTE(review): despite its name, this function queries ``models.Message``
    (not ``models.Conversation``) and takes a message id. Confirm with the
    callers whether a conversation lookup was intended before changing it.
    """
    message_query = db.query(models.Message)
    return message_query.filter(models.Message.id == message_id).first()
# --- Global Message CRUD ---
def create_global_message(db: Session, author_id: int, content: str) -> models.GlobalMessage:
    """Post ``content`` to the global chat on behalf of ``author_id``.

    The new ``models.GlobalMessage`` is committed and refreshed before being
    returned.
    """
    posted = models.GlobalMessage(author_id=author_id, content=content)
    db.add(posted)
    db.commit()
    db.refresh(posted)
    return posted
def get_global_messages(db: Session, skip: int = 0, limit: int = 100) -> List[models.GlobalMessage]:
    """Return a page of global chat messages, newest first.

    The ``author`` relationship is eager-loaded via ``joinedload`` so each
    message's author is fetched in the same query rather than lazily.
    """
    newest_first = models.GlobalMessage.created_at.desc()
    query = (
        db.query(models.GlobalMessage)
        .options(joinedload(models.GlobalMessage.author))
        .order_by(newest_first)
    )
    return query.offset(skip).limit(limit).all()
# --- Reaction CRUD ---
def add_or_remove_reaction(db: Session, user_id: int, message_id: int, emoji: str) -> Optional[models.GlobalMessage]:
    """Toggle ``user_id``'s ``emoji`` reaction on a global message.

    If the user has already reacted to the message with that emoji, the
    reaction is deleted. Otherwise a new reaction is created. The change is
    committed and the message is returned with its state reloaded.

    Returns:
        The updated ``models.GlobalMessage``, or ``None`` (with nothing
        written) when no global message with ``message_id`` exists.
    """
    message = db.query(models.GlobalMessage).filter(models.GlobalMessage.id == message_id).first()
    if message is None:
        # Don't write a reaction for a message that doesn't exist. Depending
        # on whether the database enforces the (presumed) foreign key, that
        # would either raise or leave an orphaned reaction row.
        return None

    existing_reaction = db.query(models.MessageReaction).filter(
        models.MessageReaction.message_id == message_id,
        models.MessageReaction.user_id == user_id,
        models.MessageReaction.emoji == emoji,
    ).first()

    if existing_reaction:
        db.delete(existing_reaction)
    else:
        db.add(
            models.MessageReaction(
                message_id=message_id,
                user_id=user_id,
                emoji=emoji,
            )
        )
    db.commit()

    # Reload so the returned message reflects the reaction just added or
    # removed, even if the session is configured not to expire on commit.
    db.refresh(message)
    return message
def get_user_count(db):
    """Return how many ``models.User`` rows exist."""
    users = db.query(models.User)
    return users.count()
def get_conversation_count(db):
    """Return how many ``models.Conversation`` rows exist, across all users."""
    conversations = db.query(models.Conversation)
    return conversations.count()
def get_message_count(db):
    """Return how many ``models.Message`` rows exist (conversation messages,
    not ``models.GlobalMessage`` rows)."""
    messages = db.query(models.Message)
    return messages.count()
def get_average_messages_per_user(db):
    """Return the mean number of ``models.Message`` rows per ``models.User``.

    Computed as total messages divided by total users (every user counts,
    including those with no messages). Always returns a ``float``. When
    there are no users, it returns ``0.0`` instead of dividing by zero.
    """
    total_users = get_user_count(db)
    if not total_users:
        # Return 0.0 (not int 0) so the result type is the same on every path.
        return 0.0
    return get_message_count(db) / total_users