from sqlmodel import Session, select
from typing import List, Optional
import uuid
from ..models.conversation import Conversation
from ..models.message import Message


class ConversationService:
    @staticmethod
    def create_conversation(user_id: uuid.UUID, db_session: Session) -> Conversation:
        """Create a new conversation for a user"""
        conversation = Conversation(user_id=user_id)
        db_session.add(conversation)
        db_session.commit()
        db_session.refresh(conversation)
        return conversation

    @staticmethod
    def get_conversation_by_id(
        conversation_id: int, user_id: uuid.UUID, db_session: Session
    ) -> Optional[Conversation]:
        """Return the conversation with this ID if it belongs to the user.

        Returns None otherwise, which enforces per-user isolation.
        """
        statement = select(Conversation).where(
            Conversation.id == conversation_id,
            Conversation.user_id == user_id
        )
        return db_session.exec(statement).first()

    @staticmethod
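    def get_messages(
        conversation_id: int, user_id: uuid.UUID, db_session: Session, limit: int = 20
    ) -> List[Message]:
        """Return up to `limit` of the most recent messages, oldest first.

        Returns an empty list if the conversation does not exist or does not
        belong to the user (user isolation is enforced).
        """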
        # First verify the conversation belongs to the user
        conversation = ConversationService.get_conversation_by_id(conversation_id, user_id, db_session)
        if not conversation:
            return []

        # Fetch the newest `limit` messages for this conversation (newest first)
        statement = select(Message).where(
            Message.conversation_id == conversation_id
        ).order_by(Message.created_at.desc()).limit(limit)

        messages = db_session.exec(statement).all()
        # Reverse to return in chronological order (oldest first)
        return list(reversed(messages))

    @staticmethod
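    def add_message(
        conversation_id: int,
        user_id: uuid.UUID,
        role: str,
        content: str,
        db_session: Session,
    ) -> Message:
        """Add a message to a conversation owned by the user.

        Raises ValueError if the conversation does not exist or does not belong
        to the user (user isolation is enforced).
        """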
        # Verify the conversation belongs to the user
        conversation = ConversationService.get_conversation_by_id(conversation_id, user_id, db_session)
        if not conversation:
            raise ValueError("Conversation not found or does not belong to user")

        message = Message(
            conversation_id=conversation_id,
            user_id=user_id,
            role=role,
            content=content
        )
        db_session.add(message)
        db_session.commit()
        db_session.refresh(message)
        return message

    @staticmethod
    def get_latest_conversation(user_id: uuid.UUID, db_session: Session) -> Optional[Conversation]:
        """Return the user's most recently created conversation, or None."""
        statement = select(Conversation).where(
            Conversation.user_id == user_id
        ).order_by(Conversation.created_at.desc()).limit(1)

        return db_session.exec(statement).first()
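

# Illustrative sketch only: a hypothetical, stdlib-only helper that is NOT
# called by ConversationService. It mirrors the windowing used in
# get_messages(): take the newest `limit` items (ORDER BY created_at DESC
# LIMIT n), then reverse them so callers receive oldest-first order.
# Assumes each item exposes a comparable `created_at` attribute.
def _latest_window_chronological(items, limit: int = 20):
    """Return the newest `limit` items from `items`, oldest first."""
    # Sort newest first and keep only the first `limit` entries.
    newest_first = sorted(items, key=lambda item: item.created_at, reverse=True)[:limit]
    # Reverse to chronological (oldest-first) order, as get_messages() does.
    return list(reversed(newest_first))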