chatbot / src /api /routes /chat.py
jawadsaghir12's picture
new update
6b72b9c
"""Chat and conversation management routes."""
import logging
import json
import asyncio
from typing import Optional, AsyncGenerator
from fastapi import APIRouter, Depends, HTTPException, status, Query
from fastapi.responses import StreamingResponse
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlalchemy import select, func
from ...auth.dependencies import get_current_user
from ...db.database import get_session
from ...db.models import Message
from ...models import (
ChatRequest,
ChatResponse,
ConversationDetailResponse,
ConversationListResponse,
MessageResponse,
PaginatedConversationsResponse,
)
from ...services.conversation_service import ConversationService
from ...agents.agent import service as agent_service, service_streaming
# Module-level logger named after this module; used by every route below.
logger = logging.getLogger(__name__)
# Router for all chat endpoints: paths below are mounted under "/chat" and
# grouped under the "Chat" tag.
router = APIRouter(prefix="/chat", tags=["Chat"])
@router.post("/send", response_model=ChatResponse)
async def send_message(
    request: ChatRequest,
    user: dict = Depends(get_current_user),
    session: AsyncSession = Depends(get_session),
) -> ChatResponse:
    """
    Store a user message, ask the agent for a reply, and store the reply.

    When ``request.conversation_id`` is None a fresh conversation is created
    for the caller; otherwise the referenced conversation is looked up first
    so that missing or foreign conversations are rejected.

    Args:
        request: ChatRequest carrying the query and an optional conversation_id
        user: Authenticated user payload (its "user_id" entry is used)
        session: Database session

    Returns:
        ChatResponse holding the conversation_id and the agent's reply

    Raises:
        HTTPException 404: The conversation lookup returned None
        HTTPException 403: The conversation lookup raised PermissionError
        HTTPException 500: Any other failure while processing
    """
    try:
        uid = int(user["user_id"])
        requested_id = request.conversation_id

        if requested_id is not None:
            # Only the lookup is guarded: a PermissionError here means the
            # caller may not use this conversation.
            try:
                found = await ConversationService.get_conversation_with_messages(
                    conversation_id=requested_id,
                    user_id=uid,
                    session=session,
                )
                if found is None:
                    raise HTTPException(
                        status_code=status.HTTP_404_NOT_FOUND,
                        detail="Conversation not found"
                    )
                conv_id = requested_id
            except PermissionError:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="You do not have access to this conversation"
                )
        else:
            created = await ConversationService.create_conversation(
                user_id=uid,
                session=session,
            )
            conv_id = created.id
            logger.info(f"Created new conversation {conv_id} for user {uid}")

        # Persist the user's message before calling the agent.
        await ConversationService.save_message(
            conversation_id=conv_id,
            sender_id=uid,
            content=request.query,
            session=session,
        )
        logger.debug(f"Saved user message to conversation {conv_id}")

        reply = await agent_service(request.query, conversation_id=str(conv_id))

        # NOTE: the reply is stored under the requesting user's id as sender;
        # a dedicated system/assistant sender id could be used instead.
        await ConversationService.save_message(
            conversation_id=conv_id,
            sender_id=uid,
            content=reply,
            session=session,
        )
        logger.debug(f"Saved AI response to conversation {conv_id}")

        return ChatResponse(conversation_id=conv_id, response=reply)
    except HTTPException:
        # Deliberate 403/404 responses pass through untouched.
        raise
    except Exception as exc:
        logger.error(f"Error sending message: {exc}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to process message"
        )
@router.post("/send/stream")
async def send_message_stream(
    request: ChatRequest,
    user: dict = Depends(get_current_user),
    session: AsyncSession = Depends(get_session),
):
    """
    Send a message and stream the AI response in real-time.

    Creates a new conversation if conversation_id is None, otherwise checks
    the caller may access the given conversation. The user message is stored
    before streaming starts; the complete AI response is stored after the
    last chunk has been sent.

    Each Server-Sent Event is a ``data:`` line holding a JSON object:
        {"chunk": ..., "conversation_id": ...}     one per response chunk
        {"event": "done", "conversation_id": ...}  after the response is saved
        {"error": "..."}                           if streaming/saving fails

    Args:
        request: ChatRequest with query and optional conversation_id
        user: Current authenticated user
        session: Database session

    Returns:
        StreamingResponse with Server-Sent Events

    Raises:
        HTTPException 404: If conversation not found
        HTTPException 403: If user doesn't own conversation
        HTTPException 500: If setup fails before streaming starts
    """
    try:
        user_id = int(user["user_id"])

        # Get or create conversation
        if request.conversation_id is None:
            conversation = await ConversationService.create_conversation(
                user_id=user_id,
                session=session,
            )
            conversation_id = conversation.id
            logger.info(f"Created new conversation {conversation_id} for user {user_id}")
        else:
            # Verify user owns the conversation
            try:
                conv_result = await ConversationService.get_conversation_with_messages(
                    conversation_id=request.conversation_id,
                    user_id=user_id,
                    session=session,
                )
                if conv_result is None:
                    raise HTTPException(
                        status_code=status.HTTP_404_NOT_FOUND,
                        detail="Conversation not found"
                    )
                conversation_id = request.conversation_id
            except PermissionError:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="You do not have access to this conversation"
                )

        # Save user message
        await ConversationService.save_message(
            conversation_id=conversation_id,
            sender_id=user_id,
            content=request.query,
            session=session,
        )
        logger.debug(f"Saved user message to conversation {conversation_id}")

        async def event_generator():
            """Relay agent chunks as SSE events, then persist the full reply."""
            # NOTE(review): this generator runs after the endpoint has returned,
            # so it relies on `session` still being usable while the response
            # streams. Whether that holds depends on get_session and on the
            # FastAPI version's teardown order for yield dependencies - confirm.
            chunks = []
            try:
                async for chunk in service_streaming(request.query, conversation_id=str(conversation_id)):
                    chunks.append(chunk)
                    yield f"data: {json.dumps({'chunk': chunk, 'conversation_id': conversation_id})}\n\n"

                # Save the complete AI response (joined once, not via repeated +=).
                await ConversationService.save_message(
                    conversation_id=conversation_id,
                    sender_id=user_id,
                    content="".join(chunks),
                    session=session,
                )
                logger.debug(f"Saved AI response to conversation {conversation_id}")

                yield f"data: {json.dumps({'event': 'done', 'conversation_id': conversation_id})}\n\n"
            except Exception as e:
                logger.error(f"Error during streaming: {e}", exc_info=True)
                # Send a generic message, consistent with the non-streaming
                # endpoint: str(e) can expose internal details to the client.
                # The full exception is in the server log above.
                yield f"data: {json.dumps({'error': 'Failed to process streaming message'})}\n\n"

        return StreamingResponse(
            event_generator(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                # Ask nginx-style reverse proxies not to buffer the stream.
                "X-Accel-Buffering": "no",
            }
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in streaming endpoint: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to process streaming message"
        )
@router.get("/conversations", response_model=PaginatedConversationsResponse)
async def list_conversations(
    skip: int = Query(0, ge=0, description="Number of items to skip"),
    limit: int = Query(20, ge=1, le=100, description="Max items to return"),
    user: dict = Depends(get_current_user),
    session: AsyncSession = Depends(get_session),
) -> PaginatedConversationsResponse:
    """
    Return one page of the authenticated user's conversations.

    Ordering comes from ConversationService.get_user_conversations (intended
    to be most recent first). Each entry carries its message count.

    Args:
        skip: Pagination offset
        limit: Pagination limit
        user: Authenticated user payload (its "user_id" entry is used)
        session: Database session

    Returns:
        PaginatedConversationsResponse for the requested page

    Raises:
        HTTPException 500: If any step of the lookup fails
    """
    try:
        owner_id = int(user["user_id"])
        page, total = await ConversationService.get_user_conversations(
            user_id=owner_id,
            session=session,
            skip=skip,
            limit=limit,
        )

        counts_by_conv = {}
        page_ids = [c.id for c in page]
        if page_ids:
            # A single GROUP BY query yields the message count for every
            # conversation on this page. Conversations with no messages are
            # absent from the result and default to 0 below.
            count_query = (
                select(Message.conversation_id, func.count(Message.id).label('count'))
                .where(Message.conversation_id.in_(page_ids))
                .group_by(Message.conversation_id)
            )
            rows = (await session.execute(count_query)).all()
            counts_by_conv = {conv_key: n for conv_key, n in rows}

        items = [
            ConversationListResponse(
                id=c.id,
                title=c.title,
                created_at=c.created_at,
                last_message_at=c.last_message_at,
                message_count=counts_by_conv.get(c.id, 0),
            )
            for c in page
        ]

        logger.info(f"Retrieved {len(items)} conversations for user {owner_id}")
        return PaginatedConversationsResponse(
            conversations=items,
            total=total,
            skip=skip,
            limit=limit,
        )
    except Exception as exc:
        logger.error(f"Error listing conversations: {exc}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve conversations"
        )
@router.get("/conversations/{conversation_id}", response_model=ConversationDetailResponse)
async def get_conversation(
    conversation_id: int,
    user: dict = Depends(get_current_user),
    session: AsyncSession = Depends(get_session),
) -> ConversationDetailResponse:
    """
    Retrieve a specific conversation with all its messages.

    Args:
        conversation_id: ID of the conversation to retrieve
        user: Current authenticated user (its "user_id" entry is used)
        session: Database session

    Returns:
        ConversationDetailResponse with conversation and all messages

    Raises:
        HTTPException 404: If the conversation lookup returns None
        HTTPException 403: If the lookup raises PermissionError
        HTTPException 500: If anything else fails
    """
    try:
        user_id = int(user["user_id"])
        result = await ConversationService.get_conversation_with_messages(
            conversation_id=conversation_id,
            user_id=user_id,
            session=session,
        )
        if result is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Conversation not found"
            )
        conversation, messages = result

        # Convert messages to response format
        message_responses = [
            MessageResponse(
                id=msg.id,
                sender_id=msg.sender_id,
                content=msg.content,
                created_at=msg.created_at,
            )
            for msg in messages
        ]
        logger.info(f"Retrieved conversation {conversation_id} with {len(messages)} messages")
        return ConversationDetailResponse(
            id=conversation.id,
            title=conversation.title,
            created_at=conversation.created_at,
            last_message_at=conversation.last_message_at,
            messages=message_responses,
        )
    except HTTPException:
        # Let deliberate HTTP errors (the 404 above) through unchanged.
        # Without this clause the generic handler below would turn a
        # missing conversation into a 500.
        raise
    except PermissionError:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="You do not have access to this conversation"
        )
    except Exception as e:
        logger.error(f"Error retrieving conversation: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve conversation"
        )
@router.get("/search", response_model=PaginatedConversationsResponse)
async def search_conversations(
    q: str = Query(..., min_length=1, max_length=100, description="Search keyword"),
    skip: int = Query(0, ge=0, description="Number of items to skip"),
    limit: int = Query(20, ge=1, le=100, description="Max items to return"),
    user: dict = Depends(get_current_user),
    session: AsyncSession = Depends(get_session),
) -> PaginatedConversationsResponse:
    """
    Search the authenticated user's conversations by keyword in message content.

    Matching (described as case-insensitive) is performed by
    ConversationService.search_conversations.

    Args:
        q: Search keyword (1-100 characters, user-controlled)
        skip: Pagination offset
        limit: Pagination limit
        user: Current authenticated user (its "user_id" entry is used)
        session: Database session

    Returns:
        PaginatedConversationsResponse with matching conversations

    Raises:
        HTTPException 500: If the search or the message-count query fails
    """
    try:
        user_id = int(user["user_id"])
        conversations, total = await ConversationService.search_conversations(
            user_id=user_id,
            keyword=q,
            session=session,
            skip=skip,
            limit=limit,
        )

        # Message counts for all matched conversations in a single GROUP BY
        # query; conversations with no messages default to 0 below.
        conv_ids = [conv.id for conv in conversations]
        message_counts = {}
        if conv_ids:
            stmt = (
                select(Message.conversation_id, func.count(Message.id).label("count"))
                .where(Message.conversation_id.in_(conv_ids))
                .group_by(Message.conversation_id)
            )
            result = await session.execute(stmt)
            message_counts = {cid: count for cid, count in result.all()}

        conv_responses = [
            ConversationListResponse(
                id=conv.id,
                title=conv.title,
                created_at=conv.created_at,
                last_message_at=conv.last_message_at,
                message_count=message_counts.get(conv.id, 0),
            )
            for conv in conversations
        ]

        # `q` comes straight from the client: log it with %r so embedded
        # newlines/control characters are escaped and cannot forge log lines.
        logger.info(
            "Search for %r returned %d conversations for user %s",
            q,
            len(conv_responses),
            user_id,
        )
        return PaginatedConversationsResponse(
            conversations=conv_responses,
            total=total,
            skip=skip,
            limit=limit,
        )
    except Exception as e:
        logger.error(f"Error searching conversations: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to search conversations"
        )
@router.delete("/conversations/{conversation_id}")
async def delete_conversation(
    conversation_id: int,
    user: dict = Depends(get_current_user),
    session: AsyncSession = Depends(get_session),
):
    """
    Delete one of the authenticated user's conversations.

    Deletion itself is delegated to ConversationService.delete_conversation
    (described as a soft delete that keeps the data; that behavior lives in
    the service).

    Args:
        conversation_id: ID of the conversation to delete
        user: Authenticated user payload (its "user_id" entry is used)
        session: Database session

    Returns:
        A dict with a confirmation "message"

    Raises:
        HTTPException 404: The service raised ValueError
        HTTPException 403: The service raised PermissionError
        HTTPException 500: Any other failure
    """
    try:
        owner_id = int(user["user_id"])
        try:
            await ConversationService.delete_conversation(
                conversation_id=conversation_id,
                user_id=owner_id,
                session=session,
            )
        except ValueError:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Conversation not found"
            )
        except PermissionError:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="You do not have access to this conversation"
            )
        else:
            # Reached only when the service call succeeded.
            logger.info(f"User {owner_id} deleted conversation {conversation_id}")
            return {"message": "Conversation deleted successfully"}
    except HTTPException:
        raise
    except Exception as err:
        logger.error(f"Error deleting conversation: {err}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to delete conversation"
        )