Spaces:
Running
Running
| import logging | |
| import math | |
| from fastapi import APIRouter, Depends, HTTPException | |
| from app.api.schemas import ( | |
| AnswerGenerationRequest, | |
| AnswerGenerationResponse, | |
| DocumentSearchRequest, | |
| DocumentSearchResponse, | |
| RetrievedDocument, | |
| ) | |
| from app.dependencies import get_rag_service | |
| from app.services.rag_service import RAGService | |
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Router object exposed by this module.
# NOTE(review): no route decorators are visible on the handlers in the
# reviewed text -- confirm where they are attached to this router.
router = APIRouter()
| def sanitize_score(score: float | None) -> float | None: | |
| """Replace NaN and infinity values with 0.0""" | |
| if score is None: | |
| logger.debug("Score is None, returning None") | |
| return None | |
| if math.isnan(score): | |
| logger.warning("NaN score detected, replacing with 0.0") | |
| return 0.0 | |
| if math.isinf(score): | |
| logger.warning(f"Infinite score detected ({score}), replacing with 0.0") | |
| return 0.0 | |
| logger.debug(f"Score is valid: {score}") | |
| return score | |
def sanitize_retrieved_documents(
    documents: list[RetrievedDocument],
) -> list[RetrievedDocument]:
    """Ensure all document scores are valid JSON-serializable values.

    Every document whose ``score`` is not ``None`` has that score passed
    through ``sanitize_score`` and written back, so the objects in
    ``documents`` are modified in place. Documents with a ``None`` score are
    skipped.

    Args:
        documents: Retrieved documents whose ``score`` attribute is checked.

    Returns:
        The same list object that was passed in.
    """
    # Log calls use lazy %-style arguments so the per-document messages are
    # only formatted when the corresponding level is enabled.
    logger.debug("Sanitizing scores for %s documents", len(documents))

    sanitized_count = 0
    for position, doc in enumerate(documents, start=1):
        original_score = doc.score
        if original_score is None:
            continue

        doc.score = sanitize_score(original_score)

        # NaN compares unequal to everything, so a NaN original is always
        # counted as changed here.
        if original_score != doc.score:
            logger.debug(
                "Document %s: sanitized score %s -> %s",
                position,
                original_score,
                doc.score,
            )
            sanitized_count += 1
        else:
            logger.debug("Document %s: score %s unchanged", position, doc.score)

    if sanitized_count > 0:
        logger.info("Sanitized %s document scores", sanitized_count)
    else:
        logger.debug("No score sanitization needed")
    return documents
async def search_documents(
    request: DocumentSearchRequest, rag_service: RAGService = Depends(get_rag_service)
) -> DocumentSearchResponse:
    """Look up the documents most relevant to the request's query.

    Args:
        request: Search payload; its ``query`` and ``top_k`` fields are
            forwarded to the RAG service.
        rag_service: Service instance supplied through ``get_rag_service``.

    Returns:
        A ``DocumentSearchResponse`` echoing the query and carrying the
        score-sanitized documents together with their count.

    Raises:
        HTTPException: Status 500 when any step inside the ``try`` block
            raises; the exception text is passed through as ``detail``.
    """
    logger.info(
        f"Document search request received: query='{request.query}', top_k={request.top_k}"
    )

    try:
        # NOTE(review): the service call is synchronous (not awaited); if it
        # is slow it will presumably hold the event loop -- confirm whether
        # it should be offloaded.
        logger.debug("Calling RAG service for document search")
        documents = rag_service.search_documents(query=request.query, top_k=request.top_k)
        logger.debug(f"RAG service returned {len(documents)} documents")

        # Non-finite scores are replaced before the response model is built.
        logger.debug("Sanitizing document scores")
        sanitized_documents = sanitize_retrieved_documents(documents)

        payload = DocumentSearchResponse(
            query=request.query,
            documents=sanitized_documents,
            total_results=len(sanitized_documents),
        )
        logger.info(
            f"Document search completed successfully: {len(sanitized_documents)} results returned"
        )
        return payload
    except Exception as e:
        # logger.exception logs at ERROR level with the active traceback.
        logger.exception(f"Document search failed for query '{request.query}': {str(e)}")
        # NOTE(review): the raw exception text is returned to the client.
        raise HTTPException(status_code=500, detail=str(e)) from e
async def generate_answer(
    request: AnswerGenerationRequest, rag_service: RAGService = Depends(get_rag_service)
) -> AnswerGenerationResponse:
    """Produce a RAG answer for the request's query.

    The service result is treated as a mapping: it is read with ``.get`` and
    unpacked with ``**`` into ``AnswerGenerationResponse``. When it carries a
    non-empty ``"sources"`` entry, those documents have their scores
    sanitized first.

    Args:
        request: Payload whose ``query``, ``top_k`` and ``include_sources``
            fields are forwarded to the RAG service.
        rag_service: Service instance supplied through ``get_rag_service``.

    Returns:
        An ``AnswerGenerationResponse`` built from the service result.

    Raises:
        HTTPException: Status 500 when any step inside the ``try`` block
            raises; the exception text is passed through as ``detail``.
    """
    logger.info(
        f"Answer generation request received: query='{request.query}', top_k={request.top_k}, include_sources={request.include_sources}"
    )

    try:
        # NOTE(review): the service call is synchronous (not awaited); if it
        # is slow it will presumably hold the event loop -- confirm whether
        # it should be offloaded.
        logger.debug("Calling RAG service for answer generation")
        result = rag_service.generate_answer(
            query=request.query,
            top_k=request.top_k,
            include_sources=request.include_sources,
        )
        logger.debug(
            f"RAG service returned result with answer length: {len(result.get('answer', ''))}"
        )

        # A missing or empty "sources" entry means there is nothing to clean.
        if "sources" not in result or not result["sources"]:
            logger.debug("No sources to sanitize")
        else:
            logger.debug(
                f"Sanitizing scores for {len(result['sources'])} source documents"
            )
            result["sources"] = sanitize_retrieved_documents(result["sources"])

        payload = AnswerGenerationResponse(**result)

        # Summary figures for the completion log line, with fallbacks for
        # keys the result does not provide.
        answer_length = len(result.get("answer", ""))
        documents_used = result.get("documents_used", 0)
        retrieval_method = result.get("retrieval_method", "unknown")
        logger.info(
            f"Answer generation completed successfully: {answer_length} chars generated, {documents_used} documents used, method={retrieval_method}"
        )
        return payload
    except Exception as e:
        # logger.exception logs at ERROR level with the active traceback.
        logger.exception(f"Answer generation failed for query '{request.query}': {str(e)}")
        # NOTE(review): the raw exception text is returned to the client.
        raise HTTPException(status_code=500, detail=str(e)) from e
async def health_check():
    """Report that the API process is up and answering requests.

    The payload is static: this handler does not inspect the RAG service or
    its index, so ``index_loaded`` is always reported as ``True``.

    Returns:
        The dict ``{"status": "healthy", "index_loaded": True}``.
    """
    logger.debug("Health check request received")
    # NOTE(review): nothing here probes the index; wire in a real check if
    # "index_loaded" is meant to reflect actual state. The former try/except
    # guarded only this literal and two log calls, so its "unhealthy" branch
    # was effectively unreachable and has been dropped.
    response = {"status": "healthy", "index_loaded": True}
    logger.debug("Health check completed successfully")
    return response