File size: 4,657 Bytes
a291087
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""
API endpoints for conversation management.
"""
from typing import List
import logging

from fastapi import APIRouter, Depends, HTTPException
from sqlmodel import Session, select, func, desc

from .dependencies import CurrentUserDep, get_session
from ..models.conversation import Conversation, ConversationWithPreview, ConversationRead
from ..models.message import Message, MessageRead

# Logger named after this module so its records can be filtered per module.
logger = logging.getLogger(name=__name__)

# Every endpoint below is mounted under /api/conversations and tagged
# "conversations" (FastAPI uses tags to group operations in the OpenAPI schema).
router = APIRouter(tags=["conversations"], prefix="/api/conversations")


@router.get("", response_model=List[ConversationWithPreview])
async def list_conversations(
    current_user: CurrentUserDep,
    session: Session = Depends(get_session),
):
    """List all conversations for the current user with preview.

    Conversations are ordered by ``updated_at``, newest first. Each entry
    carries the number of messages in the conversation and the content of its
    most recent message (by ``created_at``), or ``None`` when it has none.

    Raises:
        HTTPException: 500 if anything fails while querying or building the
            response.
    """
    # NOTE(review): this is an ``async def`` endpoint using the synchronous
    # sqlmodel ``Session``, so DB calls run on the event loop thread. Consider
    # a plain ``def`` endpoint (run in FastAPI's threadpool) — confirm.
    try:
        # Both preview values are computed as correlated scalar subqueries so
        # the whole listing is fetched in ONE round trip, instead of issuing
        # two extra queries per conversation (the N+1 pattern).
        # COUNT over an empty set yields 0, matching a conversation with no
        # messages.
        message_count = (
            select(func.count(Message.id))
            .where(Message.conversation_id == Conversation.id)
            .correlate(Conversation)
            .scalar_subquery()
        )
        # NULL (-> None) when the conversation has no messages.
        last_message = (
            select(Message.content)
            .where(Message.conversation_id == Conversation.id)
            .order_by(desc(Message.created_at))
            .limit(1)
            .correlate(Conversation)
            .scalar_subquery()
        )
        stmt = (
            select(Conversation, message_count, last_message)
            .where(Conversation.user_id == current_user.id)
            .order_by(desc(Conversation.updated_at))
        )
        rows = session.exec(stmt).all()

        return [
            ConversationWithPreview(
                id=conv.id,
                title=conv.title,
                created_at=conv.created_at,
                updated_at=conv.updated_at,
                message_count=count,
                last_message=preview,
            )
            for conv, count, preview in rows
        ]
    except Exception as e:
        logger.exception("Error listing conversations: %s", e)
        raise HTTPException(status_code=500, detail="Failed to list conversations") from e


@router.get("/{conversation_id}/messages", response_model=List[MessageRead])
async def get_conversation_messages(
    conversation_id: int,
    current_user: CurrentUserDep,
    session: Session = Depends(get_session),
):
    """Return every message of one of the caller's conversations, oldest first.

    Raises:
        HTTPException: 404 when no conversation has ``conversation_id``; 403 when
            it is owned by a different user; 500 for any other failure.
    """
    try:
        owned = session.get(Conversation, conversation_id)
        # Existence is checked before ownership, so each case gets its own code.
        if not owned:
            raise HTTPException(status_code=404, detail="Conversation not found")
        if owned.user_id != current_user.id:
            raise HTTPException(status_code=403, detail="Access denied")

        oldest_first = Message.created_at.asc()
        query = select(Message).where(Message.conversation_id == conversation_id)
        query = query.order_by(oldest_first)
        rows = session.exec(query).all()

        return [
            MessageRead(
                id=row.id,
                conversation_id=row.conversation_id,
                user_id=row.user_id,
                role=row.role,
                content=row.content,
                created_at=row.created_at,
            )
            for row in rows
        ]
    except HTTPException:
        # Deliberate 404/403 responses pass through untouched.
        raise
    except Exception as e:
        logger.error(f"Error getting messages: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to get messages")


@router.delete("/{conversation_id}")
async def delete_conversation(
    conversation_id: int,
    current_user: CurrentUserDep,
    session: Session = Depends(get_session),
):
    """Delete a conversation and all its messages.

    Returns:
        ``{"message": "Conversation deleted successfully"}`` on success.

    Raises:
        HTTPException: 404 if the conversation does not exist; 403 if it belongs
            to another user; 500 if the deletion fails (the session is rolled
            back first).
    """
    try:
        # Verify conversation belongs to user
        conversation = session.get(Conversation, conversation_id)
        if not conversation:
            raise HTTPException(status_code=404, detail="Conversation not found")
        if conversation.user_id != current_user.id:
            raise HTTPException(status_code=403, detail="Access denied")

        # NOTE(review): removing the messages relies on cascade configuration
        # (ORM relationship cascade or a DB-level ON DELETE CASCADE) defined in
        # the models, which is not visible here — confirm.
        session.delete(conversation)
        session.commit()
    except HTTPException:
        # Deliberate 404/403 responses pass through untouched.
        raise
    except Exception as e:
        logger.exception("Error deleting conversation: %s", e)
        # A failed flush/commit leaves the session in a failed transaction
        # state; roll back so the pending delete is discarded and the session
        # is usable again.
        session.rollback()
        raise HTTPException(status_code=500, detail="Failed to delete conversation") from e

    return {"message": "Conversation deleted successfully"}