# chatbot/app/routes/chat.py
# Author: Tahasaif3 — commit a0c847a ("code")
from fastapi import APIRouter, Depends, HTTPException
from qdrant_client import QdrantClient
from app.qdrant_client import get_qdrant_client
from app.schemas.chat import ChatRequest, ChatResponse, ChatSelectionRequest
from app.services.rag_service import RAGService
from app.services.embeddings_service import EmbeddingsService, GeminiEmbeddingsService
from app.services.openai_service import OpenAIService, GeminiService
from app.config import settings
import logging
# All endpoints in this module are mounted under /api and grouped under the "chat" tag.
router = APIRouter(
    prefix="/api",
    tags=["chat"],
)
# Logger named after this module.
logger = logging.getLogger(__name__)
def get_rag_service(
    qdrant_client: QdrantClient = Depends(get_qdrant_client),
) -> RAGService:
    """Build a RAGService wired to the provider named in ``settings.AI_PROVIDER``.

    A value equal to "gemini" (compared case-insensitively) selects the
    Gemini embeddings and generation services. Any other value falls back
    to the OpenAI-backed ``EmbeddingsService`` / ``OpenAIService`` pair.
    Fresh service instances are constructed on every call.
    """
    use_gemini = settings.AI_PROVIDER.lower() == "gemini"
    if use_gemini:
        return RAGService(qdrant_client, GeminiEmbeddingsService(), GeminiService())
    return RAGService(qdrant_client, EmbeddingsService(), OpenAIService())
@router.post("/chat", response_model=ChatResponse)
async def chat(
    request: ChatRequest,
    rag_service: RAGService = Depends(get_rag_service),
):
    """Answer a question using context retrieved from the vector store.

    Calls ``rag_service.retrieve_context`` for ``request.question`` with
    ``top_k=3``. Passes the result to ``rag_service.generate_response`` and
    returns the answer with one positional label ("Source 1", "Source 2",
    ...) per retrieved context entry.

    Raises:
        HTTPException: an ``HTTPException`` raised inside the handler is
            re-raised unchanged. Any other error is logged with its traceback
            and reported as a 500 with a generic detail, so internal error
            text is not sent to the client.
    """
    try:
        context = await rag_service.retrieve_context(request.question, top_k=3)
        answer = await rag_service.generate_response(request.question, context)
        # Labels are positional placeholders; no per-entry metadata is read from `context`.
        sources = [f"Source {i}" for i, _ in enumerate(context, start=1)]
        return ChatResponse(answer=answer, sources=sources)
    except HTTPException:
        # Already carries a deliberate status code; don't mask it as a 500.
        raise
    except Exception as e:
        # Full details (message + traceback) go to the server log only.
        logger.exception("Error in chat endpoint: %s", e)
        raise HTTPException(status_code=500, detail="Internal server error") from e
@router.post("/chat-selection", response_model=ChatResponse)
async def chat_selection(
    request: ChatSelectionRequest,
    rag_service: RAGService = Depends(get_rag_service),
):
    """Answer a question using only the text the user selected.

    No vector-store retrieval is performed. ``request.selected_text`` is
    passed to ``rag_service.generate_response`` as the sole context entry.
    The response always reports the single source "Selected Text".

    Raises:
        HTTPException: an ``HTTPException`` raised inside the handler is
            re-raised unchanged. Any other error is logged with its traceback
            and reported as a 500 with a generic detail, so internal error
            text is not sent to the client.
    """
    try:
        # The user's selection replaces retrieved context entirely.
        context = [request.selected_text]
        answer = await rag_service.generate_response(request.question, context)
        return ChatResponse(answer=answer, sources=["Selected Text"])
    except HTTPException:
        # Already carries a deliberate status code; don't mask it as a 500.
        raise
    except Exception as e:
        # Full details (message + traceback) go to the server log only.
        logger.exception("Error in chat_selection endpoint: %s", e)
        raise HTTPException(status_code=500, detail="Internal server error") from e