Spaces:
Running
Running
File size: 5,580 Bytes
cfc8e23 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
import logging
import math
from fastapi import APIRouter, Depends, HTTPException
from app.api.schemas import (
AnswerGenerationRequest,
AnswerGenerationResponse,
DocumentSearchRequest,
DocumentSearchResponse,
RetrievedDocument,
)
from app.dependencies import get_rag_service
from app.services.rag_service import RAGService
# Router holding the RAG endpoints defined in this module; presumably included
# into the FastAPI application elsewhere (not visible here).
router = APIRouter()
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def sanitize_score(score: float | None) -> float | None:
"""Replace NaN and infinity values with 0.0"""
if score is None:
logger.debug("Score is None, returning None")
return None
if math.isnan(score):
logger.warning("NaN score detected, replacing with 0.0")
return 0.0
if math.isinf(score):
logger.warning(f"Infinite score detected ({score}), replacing with 0.0")
return 0.0
logger.debug(f"Score is valid: {score}")
return score
def sanitize_retrieved_documents(
    documents: list[RetrievedDocument],
) -> list[RetrievedDocument]:
    """Ensure all document scores are valid JSON-serializable values.

    Every document with a non-None ``score`` has that score passed through
    ``sanitize_score``, so NaN and infinite values become 0.0.

    The documents are modified in place: the same list object (holding the
    same document objects) is returned, so any other reference to
    ``documents`` also sees the updated scores.

    Args:
        documents: Retrieved documents whose ``score`` is a float or None.

    Returns:
        The ``documents`` list itself, after sanitization.
    """
    logger.debug("Sanitizing scores for %d documents", len(documents))
    sanitized_count = 0
    # Positions are 1-based in log messages.
    for position, doc in enumerate(documents, start=1):
        original_score = doc.score
        if original_score is None:
            # A missing score needs no sanitization.
            continue
        doc.score = sanitize_score(original_score)
        # NaN != 0.0 and inf != 0.0, so any replacement is detected here;
        # finite scores come back unchanged and compare equal.
        if doc.score != original_score:
            logger.debug(
                "Document %d: sanitized score %s -> %s",
                position,
                original_score,
                doc.score,
            )
            sanitized_count += 1
        else:
            logger.debug("Document %d: score %s unchanged", position, doc.score)
    if sanitized_count:
        logger.info("Sanitized %d document scores", sanitized_count)
    else:
        logger.debug("No score sanitization needed")
    return documents
@router.post("/search", response_model=DocumentSearchResponse)
async def search_documents(
    request: DocumentSearchRequest, rag_service: RAGService = Depends(get_rag_service)
) -> DocumentSearchResponse:
    """
    Search for relevant documents based on the query.

    Returns the top K most relevant documents for ``request.query``. Any NaN or
    infinite scores are replaced with 0.0 so the response is JSON-serializable.

    Raises:
        HTTPException: status 500 if the RAG service call, score sanitization,
            or response construction fails; ``detail`` carries the error text.
    """
    # %r escapes control characters (e.g. newlines) in the user-supplied query,
    # so a crafted query cannot forge additional log lines.
    logger.info(
        "Document search request received: query=%r, top_k=%s",
        request.query,
        request.top_k,
    )
    try:
        logger.debug("Calling RAG service for document search")
        # NOTE(review): this synchronous call runs directly inside an
        # ``async def`` endpoint, so it blocks the event loop while it executes.
        # If it does slow I/O or inference, consider a plain ``def`` endpoint or
        # offloading to a thread (after confirming RAGService is thread-safe).
        documents = rag_service.search_documents(
            query=request.query, top_k=request.top_k
        )
        logger.debug("RAG service returned %d documents", len(documents))
        logger.debug("Sanitizing document scores")
        sanitized_documents = sanitize_retrieved_documents(documents)
        response = DocumentSearchResponse(
            query=request.query,
            documents=sanitized_documents,
            total_results=len(sanitized_documents),
        )
        logger.info(
            "Document search completed successfully: %d results returned",
            len(sanitized_documents),
        )
        return response
    except Exception as e:
        # logger.exception logs at ERROR level and attaches the traceback.
        logger.exception(
            "Document search failed for query %r: %s", request.query, e
        )
        # NOTE(review): str(e) exposes internal error text to API clients;
        # consider returning a generic message instead.
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/answer", response_model=AnswerGenerationResponse)
async def generate_answer(
    request: AnswerGenerationRequest, rag_service: RAGService = Depends(get_rag_service)
) -> AnswerGenerationResponse:
    """
    Generate an answer based on the query using RAG.

    Optionally returns the source documents used. The service result is treated
    as a mapping and unpacked into ``AnswerGenerationResponse``. Before that,
    any NaN or infinite scores on ``result["sources"]`` are replaced with 0.0.
    The ``answer``, ``documents_used`` and ``retrieval_method`` entries are also
    read for logging.

    Raises:
        HTTPException: status 500 if answer generation, score sanitization, or
            response construction fails; ``detail`` carries the error text.
    """
    # %r escapes control characters in the user-supplied query (log injection).
    logger.info(
        "Answer generation request received: query=%r, top_k=%s, include_sources=%s",
        request.query,
        request.top_k,
        request.include_sources,
    )
    try:
        logger.debug("Calling RAG service for answer generation")
        # NOTE(review): synchronous call inside an ``async def`` endpoint; it
        # blocks the event loop while the answer is generated.
        result = rag_service.generate_answer(
            query=request.query,
            top_k=request.top_k,
            include_sources=request.include_sources,
        )
        # ``or ""`` guards against an explicit ``"answer": None``, which would
        # otherwise make len() raise TypeError.
        answer_length = len(result.get("answer") or "")
        logger.debug(
            "RAG service returned result with answer length: %d", answer_length
        )
        sources = result.get("sources")
        if sources:
            logger.debug("Sanitizing scores for %d source documents", len(sources))
            # NOTE(review): assumes each source exposes a ``.score`` attribute
            # (e.g. RetrievedDocument instances); plain dicts would fail here.
            result["sources"] = sanitize_retrieved_documents(sources)
        else:
            logger.debug("No sources to sanitize")
        response = AnswerGenerationResponse(**result)
        logger.info(
            "Answer generation completed successfully: %d chars generated, "
            "%s documents used, method=%s",
            answer_length,
            result.get("documents_used", 0),
            result.get("retrieval_method", "unknown"),
        )
        return response
    except Exception as e:
        # logger.exception logs at ERROR level and attaches the traceback.
        logger.exception(
            "Answer generation failed for query %r: %s", request.query, e
        )
        # NOTE(review): str(e) exposes internal error text to API clients.
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/health")
async def health_check():
    """Report service health.

    Always returns ``{"status": "healthy", "index_loaded": True}``.

    NOTE(review): no index or RAG-service state is actually inspected here, so
    ``index_loaded`` is a constant rather than a real check. If liveness should
    reflect index state, inject the service (e.g. ``Depends(get_rag_service)``)
    and query it. Building a constant dict cannot fail, which is why there is
    no error/"unhealthy" branch.
    """
    logger.debug("Health check request received")
    response = {"status": "healthy", "index_loaded": True}
    logger.debug("Health check completed successfully")
    return response
|