from typing import List, Dict, Any, Optional
from sqlmodel.ext.asyncio.session import AsyncSession
from models.conversation import Conversation, ConversationCreate
from models.message import Message, MessageCreate, MessageRoleEnum
from sqlmodel import select
from uuid import UUID, uuid4
from datetime import datetime, timedelta, timezone

# How long a newly created conversation is retained before it expires
# (7-day retention as specified).
CONVERSATION_RETENTION = timedelta(days=7)


def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Same value as the deprecated ``datetime.utcnow()``: UTC wall-clock time
    with no tzinfo attached, matching the timestamps this module stores.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


class ConversationManager:
    """
    Manager class for handling conversation-related operations.

    All methods use the ``AsyncSession`` given to the constructor; methods
    that write data commit immediately.
    """

    def __init__(self, db_session: AsyncSession):
        self.db_session = db_session

    async def _commit(self) -> None:
        """Commit the session, rolling back and re-raising if the commit fails.

        SQLAlchemy requires an explicit rollback after a failed flush/commit
        before the session can be used again, so do it here rather than leave
        the shared session in a failed state.
        """
        try:
            await self.db_session.commit()
        except Exception:
            await self.db_session.rollback()
            raise

    async def create_conversation(self, user_id: str) -> Conversation:
        """
        Create and persist a new conversation for the user.

        Args:
            user_id: Owner of the conversation (kept as a string, as the
                model expects).

        Returns:
            The committed ``Conversation``, refreshed from the database.
        """
        # Take a single timestamp so created_at == updated_at and expires_at
        # is exactly CONVERSATION_RETENTION after creation.
        now = _utcnow()
        conversation = Conversation(
            user_id=user_id,  # Keep as string as expected by model
            expires_at=now + CONVERSATION_RETENTION,
            created_at=now,
            updated_at=now,
        )
        self.db_session.add(conversation)
        await self._commit()
        await self.db_session.refresh(conversation)
        return conversation

    async def get_conversation(self, conversation_id: UUID) -> Optional[Conversation]:
        """
        Get a specific conversation by ID.

        Returns:
            The matching ``Conversation``, or ``None`` if no row has that ID.
        """
        statement = select(Conversation).where(Conversation.id == conversation_id)
        result = await self.db_session.exec(statement)
        return result.first()

    async def add_message(self, conversation_id: UUID, role: MessageRoleEnum, content: str) -> Message:
        """
        Add a message to a conversation and commit it.

        The message inherits ``user_id`` from its conversation. This method
        does not touch the conversation's ``updated_at``; call
        ``update_conversation_timestamp`` for that.

        Args:
            conversation_id: ID of the conversation to append to.
            role: Message role; an enum member is stored by its ``.value``,
                anything else is stored as given.
            content: Message text.

        Returns:
            The committed ``Message``, refreshed from the database.

        Raises:
            ValueError: If the conversation does not exist.
        """
        # Get the user_id from the conversation to associate with the message
        conversation = await self.get_conversation(conversation_id)
        if not conversation:
            raise ValueError(f"Conversation {conversation_id} not found")

        message = Message(
            conversation_id=conversation_id,
            user_id=conversation.user_id,
            # Accept both enum members and raw role values.
            role=role.value if hasattr(role, 'value') else role,
            content=content,
            created_at=_utcnow(),
        )
        self.db_session.add(message)
        await self._commit()
        await self.db_session.refresh(message)
        return message

    async def update_conversation_timestamp(self, conversation_id: UUID) -> None:
        """
        Set the conversation's ``updated_at`` to the current UTC time and commit.

        Does nothing if the conversation does not exist.
        """
        conversation = await self.get_conversation(conversation_id)
        if conversation is None:
            return
        conversation.updated_at = _utcnow()
        self.db_session.add(conversation)
        await self._commit()

    async def get_recent_conversations(
        self, user_id: str, limit: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """
        Get a user's conversations, most recently updated first.

        Args:
            user_id: Owner whose conversations are listed.
            limit: Maximum number of conversations to return; ``None``
                (the default) returns all of them.

        Returns:
            One dict per conversation with ``id`` (str), ``user_id``, and
            ISO-8601 ``created_at`` / ``updated_at`` strings (or ``None``).
        """
        statement = (
            select(Conversation)
            .where(Conversation.user_id == user_id)
            .order_by(Conversation.updated_at.desc())
        )
        if limit is not None:
            statement = statement.limit(limit)
        result = await self.db_session.exec(statement)
        conversations = result.all()

        return [
            {
                "id": str(conv.id),
                "user_id": conv.user_id,
                "created_at": conv.created_at.isoformat() if conv.created_at else None,
                "updated_at": conv.updated_at.isoformat() if conv.updated_at else None
            }
            for conv in conversations
        ]

    async def get_conversation_history(self, conversation_id: UUID) -> List[Dict[str, Any]]:
        """
        Get every message in a conversation, oldest first.

        Returns:
            One dict per message with ``id`` (str), ``role``, ``content`` and
            an ISO-8601 ``created_at`` string (or ``None``).
        """
        statement = (
            select(Message)
            .where(Message.conversation_id == conversation_id)
            .order_by(Message.created_at)
        )
        result = await self.db_session.exec(statement)
        messages = result.all()

        return [
            {
                "id": str(msg.id),
                "role": msg.role,
                "content": msg.content,
                "created_at": msg.created_at.isoformat() if msg.created_at else None
            }
            for msg in messages
        ]