File size: 4,123 Bytes
6a3de9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from uuid import UUID, uuid4

from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession

from models.conversation import Conversation, ConversationCreate
from models.message import Message, MessageCreate, MessageRoleEnum


class ConversationManager:
    """
    Manager class for handling conversation-related operations.

    Every method works through the ``AsyncSession`` passed to the constructor.
    Methods that write data commit immediately.
    """

    # Conversations expire this long after creation (7-day retention policy).
    RETENTION_PERIOD = timedelta(days=7)

    def __init__(self, db_session: AsyncSession):
        self.db_session = db_session

    async def create_conversation(self, user_id: str) -> Conversation:
        """
        Create and persist a new conversation for the user.

        Args:
            user_id: Owner of the conversation, kept as a string as the model expects.

        Returns:
            The committed and refreshed ``Conversation`` row.
        """
        # Take a single timestamp so a fresh row has created_at == updated_at
        # and expires_at is exactly RETENTION_PERIOD after creation.
        # NOTE(review): naive UTC (utcnow) is kept to match the existing columns;
        # switching to aware datetimes would need a coordinated model change.
        now = datetime.utcnow()

        conversation = Conversation(
            user_id=user_id,
            expires_at=now + self.RETENTION_PERIOD,
            created_at=now,
            updated_at=now,
        )

        self.db_session.add(conversation)
        await self.db_session.commit()
        await self.db_session.refresh(conversation)

        return conversation

    async def get_conversation(self, conversation_id: UUID) -> Optional[Conversation]:
        """
        Get a specific conversation by ID.

        Args:
            conversation_id: Primary key of the conversation.

        Returns:
            The matching ``Conversation``, or ``None`` if no row matches.
        """
        statement = select(Conversation).where(Conversation.id == conversation_id)
        result = await self.db_session.exec(statement)
        return result.first()

    async def add_message(self, conversation_id: UUID, role: MessageRoleEnum, content: str) -> Message:
        """
        Append a message to an existing conversation and persist it.

        Args:
            conversation_id: Conversation the message belongs to.
            role: Message role. Enum members are stored by their ``.value``;
                any other value is stored as given.
            content: Message text.

        Returns:
            The committed and refreshed ``Message`` row.

        Raises:
            ValueError: If the conversation does not exist.
        """
        # The message inherits its user_id from the owning conversation.
        conversation = await self.get_conversation(conversation_id)
        if not conversation:
            raise ValueError(f"Conversation {conversation_id} not found")

        message = Message(
            conversation_id=conversation_id,
            user_id=conversation.user_id,
            role=role.value if hasattr(role, 'value') else role,
            content=content,
            created_at=datetime.utcnow(),
        )

        self.db_session.add(message)
        await self.db_session.commit()
        await self.db_session.refresh(message)

        return message

    async def update_conversation_timestamp(self, conversation_id: UUID) -> None:
        """
        Set the conversation's ``updated_at`` to the current UTC time and commit.

        Does nothing if the conversation does not exist.

        Args:
            conversation_id: Conversation to touch.
        """
        conversation = await self.get_conversation(conversation_id)
        if conversation:
            conversation.updated_at = datetime.utcnow()
            self.db_session.add(conversation)
            await self.db_session.commit()

    async def get_recent_conversations(
        self, user_id: str, limit: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """
        Get a user's conversations, most recently updated first.

        Args:
            user_id: Owner whose conversations are listed.
            limit: Maximum number of conversations to return. ``None`` (the
                default) returns all of them.

        Returns:
            One dict per conversation with keys ``id``, ``user_id``,
            ``created_at`` and ``updated_at``. Timestamps are ISO-8601 strings,
            or ``None`` when unset.
        """
        statement = (
            select(Conversation)
            .where(Conversation.user_id == user_id)
            .order_by(Conversation.updated_at.desc())
        )
        if limit is not None:
            statement = statement.limit(limit)

        result = await self.db_session.exec(statement)
        conversations = result.all()

        return [
            {
                "id": str(conv.id),
                "user_id": conv.user_id,
                "created_at": conv.created_at.isoformat() if conv.created_at else None,
                "updated_at": conv.updated_at.isoformat() if conv.updated_at else None,
            }
            for conv in conversations
        ]

    async def get_conversation_history(self, conversation_id: UUID) -> List[Dict[str, Any]]:
        """
        Get all messages in a conversation, oldest first.

        Args:
            conversation_id: Conversation whose messages are listed.

        Returns:
            One dict per message with keys ``id``, ``role``, ``content`` and
            ``created_at``. ``created_at`` is an ISO-8601 string, or ``None``
            when unset.
        """
        statement = (
            select(Message)
            .where(Message.conversation_id == conversation_id)
            .order_by(Message.created_at)
        )
        result = await self.db_session.exec(statement)
        messages = result.all()

        return [
            {
                "id": str(msg.id),
                "role": msg.role,
                "content": msg.content,
                "created_at": msg.created_at.isoformat() if msg.created_at else None,
            }
            for msg in messages
        ]