Spaces:
Sleeping
Sleeping
File size: 2,963 Bytes
34e27fb | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 | from sqlmodel import Session, select
from typing import List, Optional
from datetime import datetime
import uuid
from ..models.conversation import Conversation, ConversationCreate
from ..models.message import Message, MessageCreate
class ConversationService:
    """Data-access helpers for conversations and their messages.

    Every method receives the caller's ``db_session`` explicitly. Methods that
    take a ``user_id`` scope their queries to that user, so one user cannot
    read or write another user's conversation. Methods that write call
    ``db_session.commit()`` themselves.
    """

    @staticmethod
    def create_conversation(user_id: uuid.UUID, db_session: Session) -> Conversation:
        """Create, persist and return a new conversation owned by ``user_id``.

        Args:
            user_id: Owner of the new conversation.
            db_session: Open SQLModel session; this method commits it.

        Returns:
            The new ``Conversation``, refreshed after commit so values the
            database filled in on insert are loaded.
        """
        conversation = Conversation(user_id=user_id)
        db_session.add(conversation)
        db_session.commit()
        # Reload the row so database-populated fields are visible to the caller.
        db_session.refresh(conversation)
        return conversation

    @staticmethod
    def get_conversation_by_id(conversation_id: int, user_id: uuid.UUID, db_session: Session) -> Optional[Conversation]:
        """Return the conversation with ``conversation_id`` if ``user_id`` owns it.

        Both the id and the owner must match, so a valid id belonging to a
        different user yields ``None`` (user isolation).

        Args:
            conversation_id: Primary key of the conversation.
            user_id: Caller; must be the conversation's owner.
            db_session: Open SQLModel session (read-only here).

        Returns:
            The matching ``Conversation``, or ``None`` if not found / not owned.
        """
        statement = select(Conversation).where(
            Conversation.id == conversation_id,
            Conversation.user_id == user_id,
        )
        return db_session.exec(statement).first()

    @staticmethod
    def get_messages(conversation_id: int, user_id: uuid.UUID, db_session: Session, limit: int = 20) -> List[Message]:
        """Return up to ``limit`` of the most recent messages in a conversation.

        Ownership is verified first. If the conversation does not exist or is
        not owned by ``user_id``, an empty list is returned instead of raising.

        Args:
            conversation_id: Conversation to read from.
            user_id: Caller; must own the conversation.
            db_session: Open SQLModel session (read-only here).
            limit: Maximum number of messages to return. Values ``<= 0``
                return an empty list; ``None`` is passed through unchanged.

        Returns:
            The newest ``limit`` messages, ordered oldest-first.
        """
        # Don't hand a negative LIMIT to SQL: its meaning is backend-specific
        # (SQLite treats it as "no limit", PostgreSQL rejects it). Treat any
        # non-positive limit like LIMIT 0.
        if limit is not None and limit <= 0:
            return []

        # Enforce user isolation before reading any messages.
        conversation = ConversationService.get_conversation_by_id(conversation_id, user_id, db_session)
        if not conversation:
            return []

        # Fetch the newest `limit` rows (descending), then flip them so the
        # caller receives chronological order.
        statement = (
            select(Message)
            .where(Message.conversation_id == conversation_id)
            .order_by(Message.created_at.desc())
            .limit(limit)
        )
        messages = db_session.exec(statement).all()
        return list(reversed(messages))

    @staticmethod
    def add_message(conversation_id: int, user_id: uuid.UUID, role: str, content: str, db_session: Session) -> Message:
        """Append a message to a conversation owned by ``user_id`` and commit it.

        Args:
            conversation_id: Conversation to append to.
            user_id: Caller; must own the conversation. Also stored on the message.
            role: Stored verbatim on the message (not validated here).
            content: Message body, stored verbatim.
            db_session: Open SQLModel session; this method commits it.

        Returns:
            The persisted ``Message``, refreshed after commit.

        Raises:
            ValueError: If the conversation does not exist or is not owned by
                ``user_id``.
        """
        # Enforce user isolation before writing.
        conversation = ConversationService.get_conversation_by_id(conversation_id, user_id, db_session)
        if not conversation:
            raise ValueError("Conversation not found or does not belong to user")
        message = Message(
            conversation_id=conversation_id,
            user_id=user_id,
            role=role,
            content=content,
        )
        db_session.add(message)
        db_session.commit()
        # Reload the row so database-populated fields are visible to the caller.
        db_session.refresh(message)
        return message

    @staticmethod
    def get_latest_conversation(user_id: uuid.UUID, db_session: Session) -> Optional[Conversation]:
        """Return the user's most recently created conversation, if any.

        Args:
            user_id: Owner whose conversations are searched.
            db_session: Open SQLModel session (read-only here).

        Returns:
            The conversation with the latest ``created_at``, or ``None`` if the
            user has none.
        """
        # Break created_at ties on the primary key so the result is
        # deterministic. This presumes ids increase with insertion order,
        # which holds for typical autoincrement keys; confirm against the model.
        statement = (
            select(Conversation)
            .where(Conversation.user_id == user_id)
            .order_by(Conversation.created_at.desc(), Conversation.id.desc())
            .limit(1)
        )
        return db_session.exec(statement).first()