Spaces:
No application file
No application file
| from fastapi import APIRouter, Depends, HTTPException | |
| from qdrant_client import QdrantClient | |
| from app.qdrant_client import get_qdrant_client | |
| from app.schemas.chat import ChatRequest, ChatResponse, ChatSelectionRequest | |
| from app.services.rag_service import RAGService | |
| from app.services.embeddings_service import EmbeddingsService, GeminiEmbeddingsService | |
| from app.services.openai_service import OpenAIService, GeminiService | |
| from app.config import settings | |
| import logging | |
| logger = logging.getLogger(__name__) | |
| router = APIRouter(prefix="/api", tags=["chat"]) | |
def get_rag_service(
    qdrant_client: QdrantClient = Depends(get_qdrant_client)
):
    """Build the RAGService handed to the chat endpoints.

    The embeddings/AI service pair is chosen from ``settings.AI_PROVIDER``:
    the value "gemini" (compared case-insensitively) selects the Gemini
    implementations; any other value selects ``EmbeddingsService`` and
    ``OpenAIService``.

    Args:
        qdrant_client: Qdrant client resolved through the
            ``get_qdrant_client`` dependency.

    Returns:
        A ``RAGService`` constructed from the client, an embeddings service
        instance and an AI service instance, in that order.
    """
    use_gemini = settings.AI_PROVIDER.lower() == "gemini"
    if use_gemini:
        embeddings_cls, ai_cls = GeminiEmbeddingsService, GeminiService
    else:
        embeddings_cls, ai_cls = EmbeddingsService, OpenAIService
    # Call arguments are evaluated left to right, so the embeddings service
    # is still instantiated before the AI service.
    return RAGService(qdrant_client, embeddings_cls(), ai_cls())
async def chat(
    request: ChatRequest,
    rag_service: RAGService = Depends(get_rag_service)
):
    """Answer a question using context retrieved through the RAG service.

    NOTE(review): no route decorator is visible on this function in the
    reviewed text -- confirm that it is registered on ``router``.

    Args:
        request: Request payload; only ``request.question`` is read here.
        rag_service: RAG service resolved through ``get_rag_service``.

    Returns:
        A ``ChatResponse`` holding the generated answer and one placeholder
        label ("Source 1", "Source 2", ...) per retrieved context item.

    Raises:
        HTTPException: Re-raised unchanged if one is raised while handling
            the request. Any other exception is logged with its traceback
            and reported as a 500 whose detail is the exception text.
    """
    try:
        # Retrieve context for the question (top_k=3).
        context = await rag_service.retrieve_context(request.question, top_k=3)
        # Generate the answer from the question and the retrieved context.
        answer = await rag_service.generate_response(request.question, context)
        # Sources are positional placeholders, one per context item.
        sources = [f"Source {i}" for i, _ in enumerate(context, start=1)]
        return ChatResponse(answer=answer, sources=sources)
    except HTTPException:
        # Keep the status code and detail chosen by lower layers instead of
        # masking them as a generic 500.
        raise
    except Exception as e:
        # logger.exception logs at ERROR level and attaches the traceback.
        logger.exception("Error in chat endpoint: %s", e)
        # NOTE(review): detail exposes the raw exception text to the client;
        # kept for backward compatibility -- consider a generic message.
        raise HTTPException(status_code=500, detail=str(e)) from e
async def chat_selection(
    request: ChatSelectionRequest,
    rag_service: RAGService = Depends(get_rag_service)
):
    """Answer a question using only the user-selected text as context.

    No vector-database retrieval happens here: the context passed to the
    generator is a single-element list containing ``request.selected_text``.

    NOTE(review): no route decorator is visible on this function in the
    reviewed text -- confirm that it is registered on ``router``.

    Args:
        request: Request payload; ``request.question`` and
            ``request.selected_text`` are read here.
        rag_service: RAG service resolved through ``get_rag_service``.

    Returns:
        A ``ChatResponse`` holding the generated answer and the fixed
        source list ``["Selected Text"]``.

    Raises:
        HTTPException: Re-raised unchanged if one is raised while handling
            the request. Any other exception is logged with its traceback
            and reported as a 500 whose detail is the exception text.
    """
    try:
        # The selected text is the sole context for generation.
        context = [request.selected_text]
        answer = await rag_service.generate_response(request.question, context)
        return ChatResponse(answer=answer, sources=["Selected Text"])
    except HTTPException:
        # Keep the status code and detail chosen by lower layers instead of
        # masking them as a generic 500.
        raise
    except Exception as e:
        # logger.exception logs at ERROR level and attaches the traceback.
        logger.exception("Error in chat_selection endpoint: %s", e)
        # NOTE(review): detail exposes the raw exception text to the client;
        # kept for backward compatibility -- consider a generic message.
        raise HTTPException(status_code=500, detail=str(e)) from e