# File size: 4,130 Bytes
# a34068e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import json
import logging
import time

from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse

from app.api.deps import dep_generator, dep_query_analyzer, dep_retriever
from app.core.generator import AnswerGenerator
from app.core.query_analyzer import QueryAnalyzer
from app.core.retriever import HybridRetriever
from app.models.schemas import (
    GeneratedAnswer,
    QueryRequest,
    SearchRequest,
    SearchResponse,
)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Router for the query endpoints; routes registered on it get the "/api"
# prefix and the "query" tag.
router = APIRouter(prefix="/api", tags=["query"])


def _resolve_filters(request_filters, analyzed_filters):
    """Pick the filter object to apply to a retrieval call.

    Candidates are tried in priority order: the filters sent explicitly with
    the request first, then the ones extracted by query analysis. A candidate
    is accepted only if it is truthy and its ``has_filters()`` reports values.

    Returns:
        The first accepted candidate, or ``None`` when neither qualifies.
    """
    for candidate in (request_filters, analyzed_filters):
        if candidate and candidate.has_filters():
            return candidate
    return None


@router.post("/search", response_model=SearchResponse)
async def search(
    request: SearchRequest,
    retriever: HybridRetriever = Depends(dep_retriever),
    analyzer: QueryAnalyzer = Depends(dep_query_analyzer),
):
    """Analyze the query, retrieve matching results and report the elapsed time.

    Args:
        request: Search payload; ``query``, ``top_k`` and ``filters`` are read.
        retriever: Injected retriever used to fetch results.
        analyzer: Injected analyzer that yields the cleaned query and any
            filters extracted from it.

    Returns:
        A ``SearchResponse`` echoing the original query, the results, their
        count, and the time spent on analysis plus retrieval in milliseconds.

    Raises:
        HTTPException: Re-raised unchanged if one is raised while handling the
            request; any other failure is logged and reported as a 500.
    """
    try:
        start = time.perf_counter()

        analyzed = analyzer.analyze(request.query)
        # Explicit request filters win over filters extracted from the query.
        filters = _resolve_filters(request.filters, analyzed.extracted_filters)

        results = retriever.retrieve(
            query=analyzed.clean_query,
            top_k=request.top_k,
            filters=filters,
        )

        elapsed = (time.perf_counter() - start) * 1000

        return SearchResponse(
            query=request.query,
            results=results,
            total_results=len(results),
            search_time_ms=elapsed,
        )
    except HTTPException:
        # A deliberate HTTP error must keep its own status code and detail
        # instead of being rewritten as a generic 500 by the handler below.
        raise
    except Exception as e:
        logger.exception("Search failed: %s", e)
        raise HTTPException(status_code=500, detail=f"Search failed: {e}") from e


@router.post("/ask")
async def ask(
    request: QueryRequest,
    retriever: HybridRetriever = Depends(dep_retriever),
    generator: AnswerGenerator = Depends(dep_generator),
    analyzer: QueryAnalyzer = Depends(dep_query_analyzer),
):
    """Answer a question from retrieved chunks, streamed or in one response.

    Args:
        request: Query payload; ``query``, ``top_k``, ``filters``, ``stream``
            and ``rerank_top_k`` are read.
        retriever: Injected retriever used to fetch context chunks.
        generator: Injected answer generator.
        analyzer: Injected analyzer that yields the cleaned query, extracted
            filters and the query intent.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream`` when
        ``request.stream`` is truthy, otherwise whatever
        ``generator.generate_answer`` returns.

    Raises:
        HTTPException: Re-raised unchanged if one is raised while handling the
            request; any other failure is logged and reported as a 500.
    """
    try:
        analyzed = analyzer.analyze(request.query)
        # Explicit request filters win over filters extracted from the query.
        filters = _resolve_filters(request.filters, analyzed.extracted_filters)

        chunks = retriever.retrieve(
            query=analyzed.clean_query,
            top_k=request.top_k,
            filters=filters,
        )

        if request.stream:
            # The generator body only runs once the response is iterated, so
            # failures during streaming are handled inside _stream_response,
            # not by the except clauses below.
            return StreamingResponse(
                _stream_response(request.query, chunks, generator, request.rerank_top_k, analyzed.intent),
                media_type="text/event-stream",
            )

        # Generation receives the original query text, not the cleaned one.
        answer = generator.generate_answer(
            query=request.query,
            chunks=chunks,
            rerank_top_k=request.rerank_top_k,
            intent=analyzed.intent,
        )
        return answer
    except HTTPException:
        # A deliberate HTTP error must keep its own status code and detail
        # instead of being rewritten as a generic 500 by the handler below.
        raise
    except Exception as e:
        logger.exception("Ask failed: %s", e)
        raise HTTPException(status_code=500, detail=f"Query failed: {e}") from e


async def _stream_response(
    query: str,
    chunks,
    generator: AnswerGenerator,
    rerank_top_k: int,
    intent: str,
):
    """Yield ``data: <json>`` frames for an answer produced incrementally.

    Items from ``generator.generate_answer_stream`` are handled by type:

    * ``str`` items become ``{"text": ...}`` frames.
    * ``GeneratedAnswer`` items become a ``{"done": true, ...}`` frame with the
      model name, generation time and a summary of each source (its text is
      cut to the first 200 characters).
    * Anything else is ignored.

    Any exception raised while streaming is logged and emitted as a single
    ``{"error": ..., "done": true}`` frame rather than propagated.
    """

    def _frame(payload):
        # One event: a "data:" line terminated by a blank line.
        return f"data: {json.dumps(payload)}\n\n"

    try:
        answer_stream = generator.generate_answer_stream(
            query=query,
            chunks=chunks,
            rerank_top_k=rerank_top_k,
            intent=intent,
        )
        async for piece in answer_stream:
            if isinstance(piece, str):
                yield _frame({"text": piece})
                continue
            if not isinstance(piece, GeneratedAnswer):
                continue

            cited = []
            for src in piece.sources:
                cited.append(
                    {
                        "chunk_id": src.chunk_id,
                        "text": src.text[:200],
                        "source": src.metadata.source,
                        "score": src.score,
                    }
                )
            yield _frame(
                {
                    "done": True,
                    "sources": cited,
                    "model": piece.model,
                    "time_ms": piece.generation_time_ms,
                }
            )
    except Exception as e:
        logger.error(f"Streaming failed: {e}", exc_info=True)
        yield _frame({"error": str(e), "done": True})