File size: 6,546 Bytes
d44b33d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
"""Grounded Q&A and summarisation routes.

``POST /query/ask`` retrieves top-K chunks from Chroma, calls the configured LLM with
citations enforced in the prompt, persists an audit row, and returns answer + sources.
``POST /query/summarise`` uses a retrieval-oriented query then a summary-focused prompt.
``POST /query`` is a legacy alias for ``/query/ask``.
"""

import time
from datetime import datetime, timezone
from uuid import uuid4

from fastapi import APIRouter, HTTPException, status

from api.config import Settings, get_settings
from models.requests import QueryRequest, SummariseRequest
from models.responses import AskQueryResponse, SourceCitation, SummariseQueryResponse
from rag.embedder import create_embedding_function
from rag.retriever import (
    SUMMARY_RETRIEVAL_QUERY,
    RetrievedChunk,
    answer_with_grounding,
    retrieve_chunks,
    summarise_with_grounding,
)
from rag.vector_store import collection_document_count, get_vector_store
from storage.audit_store import persist_query_audit

# Router for the grounded Q&A and summarisation routes; its paths are prefixed with /query.
router = APIRouter(prefix="/query", tags=["query"])


def _model_used_label(settings: Settings) -> str:
    provider = settings.llm_provider.lower()
    if provider == "openai":
        return settings.openai_model
    if provider == "ollama":
        return settings.ollama_chat_model
    if provider == "anthropic":
        return settings.anthropic_model
    if provider == "huggingface":
        return settings.huggingface_model
    return f"{provider}:unknown"


def _chunks_to_citations(chunks: list[RetrievedChunk]) -> list[SourceCitation]:
    """Convert retrieved chunks into ``SourceCitation`` models, one per chunk, in order.

    Missing metadata is replaced with placeholders: a ``None`` page becomes ``0``,
    a ``None`` score becomes ``0.0`` and a falsy source becomes ``"unknown"``.
    """
    return [
        SourceCitation(
            document_name=item.source or "unknown",
            page_number=0 if item.page is None else item.page,
            chunk_text=item.text,
            relevance_score=0.0 if item.score is None else float(item.score),
        )
        for item in chunks
    ]


async def _run_ask(
    settings: Settings,
    payload: QueryRequest,
) -> AskQueryResponse:
    """Answer a question from retrieved chunks, write an audit row, and return the response.

    The collection name falls back to ``"default"`` when the payload does not supply
    one. The reported response time covers embedding-function creation, vector-store
    lookup, chunk retrieval and answer generation; the audit write is not included.
    """
    collection = payload.collection_name or "default"

    started = time.perf_counter()
    store = get_vector_store(
        persist_directory=settings.chroma_persist_directory,
        collection_name=collection,
        embedding_function=create_embedding_function(),
    )
    retrieved = retrieve_chunks(store, payload.question, payload.top_k)
    answer_text, token_count = answer_with_grounding(settings, payload.question, retrieved)
    duration_ms = int((time.perf_counter() - started) * 1000)

    cited_sources = _chunks_to_citations(retrieved)
    new_query_id = str(uuid4())
    result = AskQueryResponse(
        query_id=new_query_id,
        question=payload.question,
        answer=answer_text,
        sources=cited_sources,
        model_used=_model_used_label(settings),
        tokens_used=token_count,
        response_time_ms=duration_ms,
        timestamp=datetime.now(timezone.utc),
    )

    # The response is built first; if its construction raises, no audit row is written.
    await persist_query_audit(
        settings.audit_db_path,
        query_id=new_query_id,
        action="query",
        user_id=payload.user_id,
        question=payload.question,
        collection_name=collection,
        answer=answer_text,
        sources=cited_sources,
        model_used=result.model_used,
        tokens_used=token_count,
        response_time_ms=duration_ms,
        kind="ask",
    )
    return result


async def _run_summarise(
    settings: Settings,
    payload: SummariseRequest,
) -> SummariseQueryResponse:
    """Summarise a collection from retrieved chunks, write an audit row, and return the response.

    A non-blank ``payload.focus`` (whitespace-stripped) is used both as the retrieval
    query and as the audited question; otherwise retrieval uses
    ``SUMMARY_RETRIEVAL_QUERY`` and the audit records ``"Summarise collection"``.
    The summariser itself receives ``payload.focus`` exactly as sent.
    """
    focus_text = (payload.focus or "").strip()
    search_query = focus_text or SUMMARY_RETRIEVAL_QUERY
    question_for_audit = focus_text or "Summarise collection"

    started = time.perf_counter()
    store = get_vector_store(
        persist_directory=settings.chroma_persist_directory,
        collection_name=payload.collection_name,
        embedding_function=create_embedding_function(),
    )
    retrieved = retrieve_chunks(store, search_query, settings.top_k_results)
    summary_text, token_count = summarise_with_grounding(
        settings,
        focus=payload.focus,
        chunks=retrieved,
    )
    # Timing stops here, so the document count and audit write below are not included.
    duration_ms = int((time.perf_counter() - started) * 1000)

    cited_sources = _chunks_to_citations(retrieved)
    total_documents = collection_document_count(
        settings.chroma_persist_directory,
        payload.collection_name,
    )
    new_query_id = str(uuid4())
    result = SummariseQueryResponse(
        query_id=new_query_id,
        summary=summary_text,
        document_count=total_documents,
        sources=cited_sources,
        timestamp=datetime.now(timezone.utc),
    )

    # The response is built first; if its construction raises, no audit row is written.
    await persist_query_audit(
        settings.audit_db_path,
        query_id=new_query_id,
        action="summarise",
        user_id=payload.user_id,
        question=question_for_audit,
        collection_name=payload.collection_name,
        answer=summary_text,
        sources=cited_sources,
        model_used=_model_used_label(settings),
        tokens_used=token_count,
        response_time_ms=duration_ms,
        kind="summarise",
    )
    return result


@router.post("/ask", response_model=AskQueryResponse)
async def ask_endpoint(payload: QueryRequest) -> AskQueryResponse:
    """Grounded question answering against a Chroma collection.

    Args:
        payload: Request body describing the question to answer.

    Returns:
        The response produced by ``_run_ask``.

    Raises:
        HTTPException: Re-raised unchanged when the pipeline itself raised one;
            for any other failure, a 500 whose ``detail`` is the string form of
            the original exception.
    """
    settings = get_settings()
    try:
        return await _run_ask(settings, payload)
    except HTTPException:
        # Keep a deliberate status code from lower layers instead of masking it as a 500.
        raise
    except Exception as exc:
        # NOTE(review): ``detail`` exposes the raw exception text to the client; kept
        # as-is for compatibility, but consider logging it and returning a generic message.
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc


@router.post("/summarise", response_model=SummariseQueryResponse)
async def summarise_endpoint(payload: SummariseRequest) -> SummariseQueryResponse:
    """Collection-wide summary with optional focus for retrieval.

    Args:
        payload: Request body naming the collection and an optional focus.

    Returns:
        The response produced by ``_run_summarise``.

    Raises:
        HTTPException: Re-raised unchanged when the pipeline itself raised one;
            for any other failure, a 500 whose ``detail`` is the string form of
            the original exception.
    """
    settings = get_settings()
    try:
        return await _run_summarise(settings, payload)
    except HTTPException:
        # Keep a deliberate status code from lower layers instead of masking it as a 500.
        raise
    except Exception as exc:
        # NOTE(review): ``detail`` exposes the raw exception text to the client; kept
        # as-is for compatibility, but consider logging it and returning a generic message.
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc


# Prefix-less router, so the legacy route can be registered at exactly POST /query.
legacy_query_router = APIRouter(tags=["query"])


@legacy_query_router.post("/query", response_model=AskQueryResponse)
async def query_post_compat(payload: QueryRequest) -> AskQueryResponse:
    """Legacy alias for POST /query/ask, retained for older clients and docs that call POST /query."""
    # Delegate to the current handler so both paths share one implementation.
    delegated_response = await ask_endpoint(payload)
    return delegated_response