Spaces:
Sleeping
Sleeping
| """Chat and conversation management routes.""" | |
| import logging | |
| import json | |
| import asyncio | |
| from typing import Optional, AsyncGenerator | |
| from fastapi import APIRouter, Depends, HTTPException, status, Query | |
| from fastapi.responses import StreamingResponse | |
| from sqlmodel.ext.asyncio.session import AsyncSession | |
| from sqlalchemy import select, func | |
| from ...auth.dependencies import get_current_user | |
| from ...db.database import get_session | |
| from ...db.models import Message | |
| from ...models import ( | |
| ChatRequest, | |
| ChatResponse, | |
| ConversationDetailResponse, | |
| ConversationListResponse, | |
| MessageResponse, | |
| PaginatedConversationsResponse, | |
| ) | |
| from ...services.conversation_service import ConversationService | |
| from ...agents.agent import service as agent_service, service_streaming | |
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
# Router for this module: routes registered on it are served under the "/chat"
# prefix and grouped under the "Chat" tag.
# NOTE(review): no `@router.<method>(...)` decorators are visible on the
# endpoint functions in this file view — confirm the endpoints are actually
# registered on `router` (they may have been lost when this file was copied).
router = APIRouter(prefix="/chat", tags=["Chat"])
async def send_message(
    request: ChatRequest,
    user: dict = Depends(get_current_user),
    session: AsyncSession = Depends(get_session),
) -> ChatResponse:
    """
    Send a message and return the AI's full reply in one response.

    When ``request.conversation_id`` is None a new conversation is created for
    the caller; otherwise the caller's access to that conversation is verified.
    The user's message and the AI reply are both persisted via
    ``ConversationService.save_message``.

    Args:
        request: ChatRequest carrying the query and an optional conversation_id
        user: Current authenticated user (its "user_id" entry is read)
        session: Database session

    Returns:
        ChatResponse with the AI reply and the conversation_id used

    Raises:
        HTTPException 404: If the requested conversation is not found
        HTTPException 403: If the service raises PermissionError for the caller
        HTTPException 500: If anything else fails
    """
    try:
        uid = int(user["user_id"])
        requested_id = request.conversation_id

        if requested_id is not None:
            # Existing conversation: the lookup doubles as an access check.
            try:
                existing = await ConversationService.get_conversation_with_messages(
                    conversation_id=requested_id,
                    user_id=uid,
                    session=session,
                )
            except PermissionError:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="You do not have access to this conversation"
                )
            if existing is None:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail="Conversation not found"
                )
            conv_id = requested_id
        else:
            fresh = await ConversationService.create_conversation(
                user_id=uid,
                session=session,
            )
            conv_id = fresh.id
            logger.info(f"Created new conversation {conv_id} for user {uid}")

        # Persist the user's turn before asking the agent.
        await ConversationService.save_message(
            conversation_id=conv_id,
            sender_id=uid,
            content=request.query,
            session=session,
        )
        logger.debug(f"Saved user message to conversation {conv_id}")

        answer = await agent_service(request.query, conversation_id=str(conv_id))

        # The AI reply is stored with the caller's id as sender; no separate
        # system/bot user is used here.
        await ConversationService.save_message(
            conversation_id=conv_id,
            sender_id=uid,
            content=answer,
            session=session,
        )
        logger.debug(f"Saved AI response to conversation {conv_id}")

        return ChatResponse(conversation_id=conv_id, response=answer)
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"Error sending message: {exc}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to process message"
        )
async def send_message_stream(
    request: ChatRequest,
    user: dict = Depends(get_current_user),
    session: AsyncSession = Depends(get_session),
):
    """
    Send a message and stream the AI response as Server-Sent Events (SSE).

    Every chunk produced by ``service_streaming`` is emitted as a ``data:`` line
    holding ``{"chunk": ..., "conversation_id": ...}``. When the stream ends,
    the concatenated response is saved and a ``{"event": "done",
    "conversation_id": ...}`` event is sent. Once streaming has started the
    HTTP status can no longer change, so a failure is reported in-band as
    ``{"error": ..., "conversation_id": ...}`` with a generic message.

    Args:
        request: ChatRequest with query and optional conversation_id
        user: Current authenticated user (its "user_id" entry is read)
        session: Database session

    Returns:
        StreamingResponse with media type ``text/event-stream``

    Raises:
        HTTPException 404: If the requested conversation is not found
        HTTPException 403: If the service raises PermissionError for the caller
        HTTPException 500: If setup fails before streaming starts
    """
    try:
        user_id = int(user["user_id"])

        # Get or create conversation
        if request.conversation_id is None:
            conversation = await ConversationService.create_conversation(
                user_id=user_id,
                session=session,
            )
            conversation_id = conversation.id
            logger.info(f"Created new conversation {conversation_id} for user {user_id}")
        else:
            # Verify user owns the conversation
            try:
                conv_result = await ConversationService.get_conversation_with_messages(
                    conversation_id=request.conversation_id,
                    user_id=user_id,
                    session=session,
                )
                if conv_result is None:
                    raise HTTPException(
                        status_code=status.HTTP_404_NOT_FOUND,
                        detail="Conversation not found"
                    )
                conversation_id = request.conversation_id
            except PermissionError:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="You do not have access to this conversation"
                )

        # Save user message before streaming starts.
        await ConversationService.save_message(
            conversation_id=conversation_id,
            sender_id=user_id,
            content=request.query,
            session=session,
        )
        logger.debug(f"Saved user message to conversation {conversation_id}")

        async def event_generator():
            # NOTE(review): this generator runs after the endpoint function has
            # returned. Whether `session` (from get_session) is still open here
            # depends on get_session's implementation and on how the installed
            # FastAPI version tears down yield-dependencies around
            # StreamingResponse — confirm.
            parts = []
            try:
                async for chunk in service_streaming(request.query, conversation_id=str(conversation_id)):
                    parts.append(chunk)
                    yield f"data: {json.dumps({'chunk': chunk, 'conversation_id': conversation_id})}\n\n"

                # Join once instead of repeated string concatenation per chunk.
                full_response = "".join(parts)
                await ConversationService.save_message(
                    conversation_id=conversation_id,
                    sender_id=user_id,
                    content=full_response,
                    session=session,
                )
                logger.debug(f"Saved AI response to conversation {conversation_id}")

                yield f"data: {json.dumps({'event': 'done', 'conversation_id': conversation_id})}\n\n"
            except Exception as e:
                logger.error(f"Error during streaming: {e}", exc_info=True)
                # Do not echo str(e): it can expose internal details (database
                # errors, upstream API responses) to the client. The full error
                # is in the server log above.
                error_payload = {
                    'error': 'Failed to process streaming message',
                    'conversation_id': conversation_id,
                }
                yield f"data: {json.dumps(error_payload)}\n\n"

        return StreamingResponse(
            event_generator(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                # Ask nginx-style reverse proxies not to buffer the stream.
                "X-Accel-Buffering": "no",
            }
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in streaming endpoint: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to process streaming message"
        )
async def list_conversations(
    skip: int = Query(0, ge=0, description="Number of items to skip"),
    limit: int = Query(20, ge=1, le=100, description="Max items to return"),
    user: dict = Depends(get_current_user),
    session: AsyncSession = Depends(get_session),
) -> PaginatedConversationsResponse:
    """
    Return one page of the authenticated user's conversations.

    The page and the total come from ``ConversationService.get_user_conversations``;
    each entry is enriched with its message count, fetched for the whole page
    in a single grouped query.

    Args:
        skip: Pagination offset
        limit: Pagination limit
        user: Current authenticated user (its "user_id" entry is read)
        session: Database session

    Returns:
        PaginatedConversationsResponse for the requested page

    Raises:
        HTTPException 500: If anything fails while building the page
    """
    try:
        owner_id = int(user["user_id"])
        conversations, total = await ConversationService.get_user_conversations(
            user_id=owner_id,
            session=session,
            skip=skip,
            limit=limit,
        )

        page_ids = [c.id for c in conversations]
        counts_by_id = {}
        if page_ids:
            # One GROUP BY query for the page instead of one COUNT per row.
            count_stmt = (
                select(Message.conversation_id, func.count(Message.id).label('count'))
                .where(Message.conversation_id.in_(page_ids))
                .group_by(Message.conversation_id)
            )
            rows = (await session.execute(count_stmt)).all()
            counts_by_id = {cid: n for cid, n in rows}

        # Conversations with no messages have no row in the grouped result.
        summaries = [
            ConversationListResponse(
                id=c.id,
                title=c.title,
                created_at=c.created_at,
                last_message_at=c.last_message_at,
                message_count=counts_by_id.get(c.id, 0),
            )
            for c in conversations
        ]
        logger.info(f"Retrieved {len(summaries)} conversations for user {owner_id}")

        return PaginatedConversationsResponse(
            conversations=summaries,
            total=total,
            skip=skip,
            limit=limit,
        )
    except Exception as exc:
        logger.error(f"Error listing conversations: {exc}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve conversations"
        )
async def get_conversation(
    conversation_id: int,
    user: dict = Depends(get_current_user),
    session: AsyncSession = Depends(get_session),
) -> ConversationDetailResponse:
    """
    Retrieve a specific conversation with all its messages.

    Args:
        conversation_id: ID of the conversation to retrieve
        user: Current authenticated user (its "user_id" entry is read)
        session: Database session

    Returns:
        ConversationDetailResponse with the conversation and all its messages

    Raises:
        HTTPException 404: If the service returns None (conversation not found)
        HTTPException 403: If the service raises PermissionError for the caller
        HTTPException 500: If anything else fails
    """
    try:
        user_id = int(user["user_id"])
        result = await ConversationService.get_conversation_with_messages(
            conversation_id=conversation_id,
            user_id=user_id,
            session=session,
        )
        if result is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Conversation not found"
            )

        conversation, messages = result

        # Convert messages to response format
        message_responses = [
            MessageResponse(
                id=msg.id,
                sender_id=msg.sender_id,
                content=msg.content,
                created_at=msg.created_at,
            )
            for msg in messages
        ]

        logger.info(f"Retrieved conversation {conversation_id} with {len(messages)} messages")
        return ConversationDetailResponse(
            id=conversation.id,
            title=conversation.title,
            created_at=conversation.created_at,
            last_message_at=conversation.last_message_at,
            messages=message_responses,
        )
    except HTTPException:
        # Let deliberate HTTP errors (the 404 above) through unchanged; without
        # this clause the generic handler below turns them into 500s.
        raise
    except PermissionError:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="You do not have access to this conversation"
        )
    except Exception as e:
        logger.error(f"Error retrieving conversation: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve conversation"
        )
async def search_conversations(
    q: str = Query(..., min_length=1, max_length=100, description="Search keyword"),
    skip: int = Query(0, ge=0, description="Number of items to skip"),
    limit: int = Query(20, ge=1, le=100, description="Max items to return"),
    user: dict = Depends(get_current_user),
    session: AsyncSession = Depends(get_session),
) -> PaginatedConversationsResponse:
    """
    Return one page of the user's conversations that match a keyword.

    Matching (including any case-insensitivity) is done by
    ``ConversationService.search_conversations``. Each match is returned with
    its message count, fetched for the whole page in one grouped query.

    Args:
        q: Search keyword
        skip: Pagination offset
        limit: Pagination limit
        user: Current authenticated user (its "user_id" entry is read)
        session: Database session

    Returns:
        PaginatedConversationsResponse with the matching conversations

    Raises:
        HTTPException 500: If anything fails while searching
    """
    try:
        searcher_id = int(user["user_id"])
        matches, total = await ConversationService.search_conversations(
            user_id=searcher_id,
            keyword=q,
            session=session,
            skip=skip,
            limit=limit,
        )

        match_ids = [m.id for m in matches]
        counts = {}
        if match_ids:
            # Count messages for every matched conversation in a single query.
            grouped = (
                select(Message.conversation_id, func.count(Message.id).label('count'))
                .where(Message.conversation_id.in_(match_ids))
                .group_by(Message.conversation_id)
            )
            grouped_rows = await session.execute(grouped)
            counts = {cid: n for cid, n in grouped_rows.all()}

        hits = []
        for m in matches:
            hits.append(
                ConversationListResponse(
                    id=m.id,
                    title=m.title,
                    created_at=m.created_at,
                    last_message_at=m.last_message_at,
                    # Conversations without messages are absent from `counts`.
                    message_count=counts.get(m.id, 0),
                )
            )

        logger.info(f"Search for '{q}' returned {len(hits)} conversations for user {searcher_id}")
        return PaginatedConversationsResponse(
            conversations=hits,
            total=total,
            skip=skip,
            limit=limit,
        )
    except Exception as exc:
        logger.error(f"Error searching conversations: {exc}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to search conversations"
        )
async def delete_conversation(
    conversation_id: int,
    user: dict = Depends(get_current_user),
    session: AsyncSession = Depends(get_session),
):
    """
    Delete one of the caller's conversations.

    The deletion itself is delegated to ``ConversationService.delete_conversation``
    (per the original contract, a soft delete that preserves data — confirm
    against the service). Service errors are mapped to HTTP statuses.

    Args:
        conversation_id: ID of the conversation to delete
        user: Current authenticated user (its "user_id" entry is read)
        session: Database session

    Returns:
        A dict with a success message

    Raises:
        HTTPException 404: If the service raises ValueError
        HTTPException 403: If the service raises PermissionError
        HTTPException 500: If anything else fails
    """
    try:
        requester = int(user["user_id"])

        # Only the service call is guarded here, so a ValueError from parsing
        # the user id above still becomes a 500 rather than a 404.
        try:
            await ConversationService.delete_conversation(
                conversation_id=conversation_id,
                user_id=requester,
                session=session,
            )
        except ValueError:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Conversation not found"
            )
        except PermissionError:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="You do not have access to this conversation"
            )

        logger.info(f"User {requester} deleted conversation {conversation_id}")
        return {"message": "Conversation deleted successfully"}
    except HTTPException:
        raise
    except Exception as err:
        logger.error(f"Error deleting conversation: {err}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to delete conversation"
        )