File size: 5,580 Bytes
cfc8e23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import asyncio
import logging
import math

from fastapi import APIRouter, Depends, HTTPException

from app.api.schemas import (
    AnswerGenerationRequest,
    AnswerGenerationResponse,
    DocumentSearchRequest,
    DocumentSearchResponse,
    RetrievedDocument,
)
from app.dependencies import get_rag_service
from app.services.rag_service import RAGService

router = APIRouter()
logger = logging.getLogger(__name__)


def sanitize_score(score: float | None) -> float | None:
    """Return *score* unchanged unless it is NaN or +/-infinity.

    Non-finite floats cannot be encoded as JSON numbers, so they are mapped
    to 0.0 (with a warning). ``None`` passes straight through as ``None``.
    """
    if score is None:
        logger.debug("Score is None, returning None")
        return None

    if not math.isnan(score):
        if not math.isinf(score):
            logger.debug(f"Score is valid: {score}")
            return score
        logger.warning(f"Infinite score detected ({score}), replacing with 0.0")
        return 0.0

    logger.warning("NaN score detected, replacing with 0.0")
    return 0.0


def sanitize_retrieved_documents(
    documents: list[RetrievedDocument],
) -> list[RetrievedDocument]:
    """Make every document score JSON-serializable, mutating *documents* in place.

    Each non-``None`` ``score`` is passed through :func:`sanitize_score`, which
    maps NaN and +/-infinity to 0.0. Documents whose score is ``None`` are left
    untouched.

    Args:
        documents: Retrieved documents; their ``score`` attributes may be
            reassigned.

    Returns:
        The same list object that was passed in.
    """
    # Lazy %-style arguments: this runs once per document, so avoid paying for
    # string formatting when DEBUG logging is disabled.
    logger.debug("Sanitizing scores for %d documents", len(documents))

    sanitized_count = 0
    for position, doc in enumerate(documents, start=1):
        if doc.score is None:
            continue

        original_score = doc.score
        doc.score = sanitize_score(original_score)

        # NaN != 0.0 evaluates True, so NaN and infinity replacements are
        # both counted here.
        if original_score != doc.score:
            logger.debug(
                "Document %d: sanitized score %s -> %s",
                position,
                original_score,
                doc.score,
            )
            sanitized_count += 1
        else:
            logger.debug("Document %d: score %s unchanged", position, doc.score)

    if sanitized_count > 0:
        logger.info("Sanitized %d document scores", sanitized_count)
    else:
        logger.debug("No score sanitization needed")

    return documents


@router.post("/search", response_model=DocumentSearchResponse)
async def search_documents(
    request: DocumentSearchRequest, rag_service: RAGService = Depends(get_rag_service)
) -> DocumentSearchResponse:
    """
    Search for relevant documents based on the query.
    Returns top K most relevant documents.

    ``RAGService.search_documents`` is synchronous (its return value is used
    directly, without ``await``). It is therefore run in a worker thread so
    that a slow search does not stall the event loop for other requests.

    Raises:
        HTTPException: re-raised unchanged if the service raises one; any
            other exception is logged and converted to a 500 whose detail is
            the exception text.
    """
    logger.info(
        f"Document search request received: query='{request.query}', top_k={request.top_k}"
    )

    try:
        logger.debug("Calling RAG service for document search")
        # NOTE(review): requests can now call RAGService concurrently from
        # worker threads; confirm the service and its index are thread-safe.
        documents = await asyncio.to_thread(
            rag_service.search_documents,
            query=request.query,
            top_k=request.top_k,
        )
        logger.debug(f"RAG service returned {len(documents)} documents")

        logger.debug("Sanitizing document scores")
        sanitized_documents = sanitize_retrieved_documents(documents)

        response = DocumentSearchResponse(
            query=request.query,
            documents=sanitized_documents,
            total_results=len(sanitized_documents),
        )

        logger.info(
            f"Document search completed successfully: {len(sanitized_documents)} results returned"
        )
        return response

    except HTTPException:
        # Deliberate HTTP errors keep their original status code and detail.
        raise
    except Exception as e:
        logger.error(
            f"Document search failed for query '{request.query}': {str(e)}",
            exc_info=True,
        )
        # NOTE(review): str(e) exposes internal error text to API clients.
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.post("/answer", response_model=AnswerGenerationResponse)
async def generate_answer(
    request: AnswerGenerationRequest, rag_service: RAGService = Depends(get_rag_service)
) -> AnswerGenerationResponse:
    """
    Generate an answer based on the query using RAG.
    Optionally returns the source documents used.

    ``RAGService.generate_answer`` is synchronous (its return value is used
    directly as a dict, without ``await``). It is therefore run in a worker
    thread so that answer generation does not stall the event loop for other
    requests.

    The result dict is unpacked into ``AnswerGenerationResponse``. When it
    has a non-empty ``sources`` entry, those documents' scores are sanitized
    first.

    Raises:
        HTTPException: re-raised unchanged if the service raises one; any
            other exception is logged and converted to a 500 whose detail is
            the exception text.
    """
    logger.info(
        f"Answer generation request received: query='{request.query}', top_k={request.top_k}, include_sources={request.include_sources}"
    )

    try:
        logger.debug("Calling RAG service for answer generation")
        # NOTE(review): requests can now call RAGService concurrently from
        # worker threads; confirm the service and its model/index are
        # thread-safe.
        result = await asyncio.to_thread(
            rag_service.generate_answer,
            query=request.query,
            top_k=request.top_k,
            include_sources=request.include_sources,
        )

        # `or ""` keeps an explicit None answer from crashing len(); validating
        # the value itself is left to AnswerGenerationResponse below.
        answer_length = len(result.get("answer") or "")
        logger.debug(
            f"RAG service returned result with answer length: {answer_length}"
        )

        if result.get("sources"):
            logger.debug(
                f"Sanitizing scores for {len(result['sources'])} source documents"
            )
            result["sources"] = sanitize_retrieved_documents(result["sources"])
        else:
            logger.debug("No sources to sanitize")

        response = AnswerGenerationResponse(**result)

        documents_used = result.get("documents_used", 0)
        retrieval_method = result.get("retrieval_method", "unknown")

        logger.info(
            f"Answer generation completed successfully: {answer_length} chars generated, {documents_used} documents used, method={retrieval_method}"
        )
        return response

    except HTTPException:
        # Deliberate HTTP errors keep their original status code and detail.
        raise
    except Exception as e:
        logger.error(
            f"Answer generation failed for query '{request.query}': {str(e)}",
            exc_info=True,
        )
        # NOTE(review): str(e) exposes internal error text to API clients.
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.get("/health")
async def health_check():
    """Return a static liveness payload: ``{"status": "healthy", "index_loaded": True}``.

    TODO(review): ``index_loaded`` is hard-coded. Nothing here inspects the
    index or the RAG service, so this endpoint only shows that the app is
    serving requests. Wire in a real readiness check if that is the intent.
    """
    logger.debug("Health check request received")

    # Building a literal dict cannot fail, so no error branch is needed.
    response = {"status": "healthy", "index_loaded": True}
    logger.debug("Health check completed successfully")
    return response