aivre / app /api /routes.py
Vedang Barhate
chore: copied from assist repo
cfc8e23
raw
history blame
5.58 kB
import logging
import math
from fastapi import APIRouter, Depends, HTTPException
from app.api.schemas import (
AnswerGenerationRequest,
AnswerGenerationResponse,
DocumentSearchRequest,
DocumentSearchResponse,
RetrievedDocument,
)
from app.dependencies import get_rag_service
from app.services.rag_service import RAGService
router = APIRouter()
logger = logging.getLogger(__name__)
def sanitize_score(score: float | None) -> float | None:
"""Replace NaN and infinity values with 0.0"""
if score is None:
logger.debug("Score is None, returning None")
return None
if math.isnan(score):
logger.warning("NaN score detected, replacing with 0.0")
return 0.0
if math.isinf(score):
logger.warning(f"Infinite score detected ({score}), replacing with 0.0")
return 0.0
logger.debug(f"Score is valid: {score}")
return score
def sanitize_retrieved_documents(
    documents: list[RetrievedDocument],
) -> list[RetrievedDocument]:
    """Clean the ``score`` of every document so it is JSON-serializable.

    Each non-``None`` score is run through ``sanitize_score`` and written back
    onto the document in place.  The same list object that was passed in is
    returned; documents without a score are skipped.
    """
    logger.debug(f"Sanitizing scores for {len(documents)} documents")

    changed = 0
    for position, document in enumerate(documents, start=1):
        before = document.score
        if before is None:
            # Nothing to sanitize for documents that carry no score.
            continue

        document.score = sanitize_score(before)
        if before != document.score:
            logger.debug(
                f"Document {position}: sanitized score {before} -> {document.score}"
            )
            changed += 1
            continue

        logger.debug(f"Document {position}: score {document.score} unchanged")

    if changed > 0:
        logger.info(f"Sanitized {changed} document scores")
    else:
        logger.debug("No score sanitization needed")
    return documents
@router.post("/search", response_model=DocumentSearchResponse)
async def search_documents(
    request: DocumentSearchRequest, rag_service: RAGService = Depends(get_rag_service)
) -> DocumentSearchResponse:
    """Search for documents relevant to the query and return the top K matches.

    Args:
        request: Search payload; its ``query`` and ``top_k`` are forwarded to
            the RAG service.
        rag_service: Retrieval service, injected through ``get_rag_service``.

    Returns:
        A ``DocumentSearchResponse`` echoing the query together with the
        retrieved documents (scores sanitized) and their count.

    Raises:
        HTTPException: Re-raised unchanged when raised while handling the
            request.  Any other exception is logged with its traceback and
            reported as a 500 whose ``detail`` is the exception message.
    """
    logger.info(
        "Document search request received: query='%s', top_k=%s",
        request.query,
        request.top_k,
    )
    try:
        logger.debug("Calling RAG service for document search")
        # NOTE(review): this is a plain (non-awaited) call inside an
        # ``async def`` endpoint, so it occupies the event loop while it
        # runs -- confirm this is acceptable for the service's latency.
        documents = rag_service.search_documents(
            query=request.query, top_k=request.top_k
        )
        logger.debug("RAG service returned %s documents", len(documents))

        logger.debug("Sanitizing document scores")
        sanitized_documents = sanitize_retrieved_documents(documents)

        response = DocumentSearchResponse(
            query=request.query,
            documents=sanitized_documents,
            total_results=len(sanitized_documents),
        )
        logger.info(
            "Document search completed successfully: %s results returned",
            len(sanitized_documents),
        )
        return response
    except HTTPException:
        # A deliberate HTTP error must keep its own status code and detail
        # instead of being rewrapped as a generic 500 by the handler below.
        raise
    except Exception as e:
        logger.error(
            "Document search failed for query '%s': %s",
            request.query,
            e,
            exc_info=True,
        )
        # NOTE(review): ``detail`` exposes the raw exception text to the
        # client; kept for backward compatibility -- consider a generic
        # message if internal details must not leak.
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/answer", response_model=AnswerGenerationResponse)
async def generate_answer(
    request: AnswerGenerationRequest, rag_service: RAGService = Depends(get_rag_service)
) -> AnswerGenerationResponse:
    """Generate an answer for the query using RAG, optionally with sources.

    Args:
        request: Answer payload; its ``query``, ``top_k`` and
            ``include_sources`` are forwarded to the RAG service.
        rag_service: RAG service, injected through ``get_rag_service``.

    Returns:
        An ``AnswerGenerationResponse`` built from the mapping returned by the
        service, with any ``sources`` scores sanitized first.

    Raises:
        HTTPException: Re-raised unchanged when raised while handling the
            request.  Any other exception is logged with its traceback and
            reported as a 500 whose ``detail`` is the exception message.
    """
    logger.info(
        "Answer generation request received: query='%s', top_k=%s, include_sources=%s",
        request.query,
        request.top_k,
        request.include_sources,
    )
    try:
        logger.debug("Calling RAG service for answer generation")
        # NOTE(review): this is a plain (non-awaited) call inside an
        # ``async def`` endpoint, so it occupies the event loop while it
        # runs -- confirm this is acceptable for the service's latency.
        result = rag_service.generate_answer(
            query=request.query,
            top_k=request.top_k,
            include_sources=request.include_sources,
        )

        # Computed once and reused for both log lines; ``or ""`` keeps the
        # length calculation from raising when "answer" is present but None.
        answer_length = len(result.get("answer") or "")
        logger.debug(
            "RAG service returned result with answer length: %s", answer_length
        )

        sources = result.get("sources")
        if sources:
            logger.debug(
                "Sanitizing scores for %s source documents", len(sources)
            )
            result["sources"] = sanitize_retrieved_documents(sources)
        else:
            logger.debug("No sources to sanitize")

        response = AnswerGenerationResponse(**result)
        logger.info(
            "Answer generation completed successfully: %s chars generated, "
            "%s documents used, method=%s",
            answer_length,
            result.get("documents_used", 0),
            result.get("retrieval_method", "unknown"),
        )
        return response
    except HTTPException:
        # A deliberate HTTP error must keep its own status code and detail
        # instead of being rewrapped as a generic 500 by the handler below.
        raise
    except Exception as e:
        logger.error(
            "Answer generation failed for query '%s': %s",
            request.query,
            e,
            exc_info=True,
        )
        # NOTE(review): ``detail`` exposes the raw exception text to the
        # client; kept for backward compatibility -- consider a generic
        # message if internal details must not leak.
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/health")
async def health_check():
    """Report service health as a small status dictionary.

    The success payload is a fixed ``{"status": "healthy", "index_loaded":
    True}``; nothing in this handler inspects the index itself.  If building
    the payload raises, the error is logged and an "unhealthy" payload is
    returned instead of propagating the exception.
    """
    logger.debug("Health check request received")
    try:
        payload = {"status": "healthy", "index_loaded": True}
        logger.debug("Health check completed successfully")
    except Exception as e:
        logger.error(f"Health check failed: {str(e)}", exc_info=True)
        payload = {"status": "unhealthy", "index_loaded": False}
    return payload