"""Grounded Q&A and summarisation routes.

``POST /query/ask`` retrieves the top-K chunks from Chroma, calls the configured LLM with
citations enforced in the prompt, persists an audit row, and returns the answer plus sources.
``POST /query/summarise`` retrieves with the caller's focus (or a default overview query),
then summarises the results with a summary-focused prompt.
``POST /query`` is a legacy alias for ``/query/ask``.
"""
import asyncio
import logging
import time
from datetime import datetime, timezone
from uuid import uuid4

from fastapi import APIRouter, HTTPException, status

from api.config import Settings, get_settings
from models.requests import QueryRequest, SummariseRequest
from models.responses import AskQueryResponse, SourceCitation, SummariseQueryResponse
from rag.embedder import create_embedding_function
from rag.retriever import (
    SUMMARY_RETRIEVAL_QUERY,
    RetrievedChunk,
    answer_with_grounding,
    retrieve_chunks,
    summarise_with_grounding,
)
from rag.vector_store import collection_document_count, get_vector_store
from storage.audit_store import persist_query_audit
# Router for the prefixed ``/query/...`` endpoints (``/ask`` and ``/summarise``) registered below.
router = APIRouter(tags=["query"], prefix="/query")
def _model_used_label(settings: Settings) -> str:
    """Return the model identifier to report for the configured LLM provider.

    ``settings.llm_provider`` is lower-cased before matching, so the lookup is
    case-insensitive. An unrecognised provider yields
    ``"<lower-cased provider>:unknown"``. Only the settings attribute belonging
    to the matched provider is read.
    """
    provider_key = settings.llm_provider.lower()
    # Each supported provider keeps its model name in a dedicated settings field.
    model_attr_by_provider = {
        "openai": "openai_model",
        "ollama": "ollama_chat_model",
        "anthropic": "anthropic_model",
        "huggingface": "huggingface_model",
    }
    attr_name = model_attr_by_provider.get(provider_key)
    if attr_name is None:
        return f"{provider_key}:unknown"
    return getattr(settings, attr_name)
def _chunks_to_citations(chunks: list[RetrievedChunk]) -> list[SourceCitation]:
    """Map retrieved chunks onto the API's ``SourceCitation`` model, preserving order.

    Missing fields get placeholders: a ``None`` page becomes ``0``, a ``None``
    score becomes ``0.0`` (other scores are coerced with ``float``), and a falsy
    source becomes ``"unknown"``.
    """
    return [
        SourceCitation(
            document_name=chunk.source or "unknown",
            page_number=0 if chunk.page is None else chunk.page,
            chunk_text=chunk.text,
            relevance_score=0.0 if chunk.score is None else float(chunk.score),
        )
        for chunk in chunks
    ]
async def _run_ask(
    settings: Settings,
    payload: QueryRequest,
) -> AskQueryResponse:
    """Retrieve chunks, generate a grounded answer, audit it, and build the API response.

    The embedding, vector-store, retrieval and answer-generation helpers are
    synchronous (their results are used directly, without ``await``). They are
    run in a worker thread via ``asyncio.to_thread`` so a slow LLM call does not
    stall the event loop for every other in-flight request.

    Args:
        settings: Application settings (Chroma directory, audit DB path, LLM config).
        payload: The incoming question, optional collection name, ``top_k`` and user id.

    Returns:
        The answer with its source citations, model label, token count, timing
        and timestamp.

    Raises:
        Any exception raised by retrieval, generation, response construction or
        audit persistence propagates unchanged.
    """
    collection_name = payload.collection_name or "default"
    # Fall back to the configured default if the request carries no top_k.
    # NOTE(review): assumes QueryRequest.top_k may be None; for any non-None
    # value this is identical to passing payload.top_k straight through.
    top_k = payload.top_k if payload.top_k is not None else settings.top_k_results

    def _retrieve_and_answer():
        # Executed in a worker thread: every call below is synchronous and
        # presumably blocks on disk (Chroma), model loading or the LLM request.
        embedding_function = create_embedding_function()
        vector_store = get_vector_store(
            persist_directory=settings.chroma_persist_directory,
            collection_name=collection_name,
            embedding_function=embedding_function,
        )
        chunks = retrieve_chunks(vector_store, payload.question, top_k)
        answer, tokens_used = answer_with_grounding(settings, payload.question, chunks)
        return chunks, answer, tokens_used

    t0 = time.perf_counter()
    chunks, answer, tokens_used = await asyncio.to_thread(_retrieve_and_answer)
    # response_time_ms covers retrieval and generation only.
    elapsed_ms = int((time.perf_counter() - t0) * 1000)

    citations = _chunks_to_citations(chunks)
    query_id = str(uuid4())
    response = AskQueryResponse(
        query_id=query_id,
        question=payload.question,
        answer=answer,
        sources=citations,
        model_used=_model_used_label(settings),
        tokens_used=tokens_used,
        response_time_ms=elapsed_ms,
        timestamp=datetime.now(timezone.utc),
    )
    # The audit row is awaited before returning, so a persistence failure
    # surfaces to the caller rather than being silently dropped.
    await persist_query_audit(
        settings.audit_db_path,
        query_id=query_id,
        action="query",
        user_id=payload.user_id,
        question=payload.question,
        collection_name=collection_name,
        answer=answer,
        sources=citations,
        # Read back from the response so the audit records exactly what was returned.
        model_used=response.model_used,
        tokens_used=tokens_used,
        response_time_ms=elapsed_ms,
        kind="ask",
    )
    return response
async def _run_summarise(
    settings: Settings,
    payload: SummariseRequest,
) -> SummariseQueryResponse:
    """Summarise a collection, optionally steered by ``payload.focus``, and audit it.

    The focus is stripped once and treated as absent when blank. When present, it
    drives retrieval, is passed to the summary prompt and is recorded as the audit
    question. When absent, ``SUMMARY_RETRIEVAL_QUERY`` drives retrieval,
    ``focus=None`` is passed to the summariser, and the audit question is
    ``"Summarise collection"``. ``settings.top_k_results`` chunks are requested.

    The synchronous embedding, retrieval, summarisation and document-count helpers
    run via ``asyncio.to_thread`` so they do not block the event loop.

    Args:
        settings: Application settings (Chroma directory, top-k, audit DB path, LLM config).
        payload: The collection name, optional focus and user id.

    Returns:
        The summary, the collection's document count, citations and a timestamp.

    Raises:
        Any exception raised by retrieval, summarisation, counting, response
        construction or audit persistence propagates unchanged.
    """
    # Normalise once so retrieval, the prompt and the audit row always agree.
    # Previously a whitespace-only focus reached the summariser unstripped even
    # though retrieval and audit treated it as "no focus".
    focus = (payload.focus or "").strip() or None
    retrieval_query = focus or SUMMARY_RETRIEVAL_QUERY
    audit_question = focus or "Summarise collection"
    top_k = settings.top_k_results

    def _retrieve_and_summarise():
        # Executed in a worker thread: all calls below are synchronous.
        embedding_function = create_embedding_function()
        vector_store = get_vector_store(
            persist_directory=settings.chroma_persist_directory,
            collection_name=payload.collection_name,
            embedding_function=embedding_function,
        )
        chunks = retrieve_chunks(vector_store, retrieval_query, top_k)
        summary, tokens_used = summarise_with_grounding(settings, focus=focus, chunks=chunks)
        return chunks, summary, tokens_used

    t0 = time.perf_counter()
    chunks, summary, tokens_used = await asyncio.to_thread(_retrieve_and_summarise)
    # response_time_ms covers retrieval and summarisation, not the document count below.
    elapsed_ms = int((time.perf_counter() - t0) * 1000)

    citations = _chunks_to_citations(chunks)
    doc_count = await asyncio.to_thread(
        collection_document_count,
        settings.chroma_persist_directory,
        payload.collection_name,
    )
    query_id = str(uuid4())
    response = SummariseQueryResponse(
        query_id=query_id,
        summary=summary,
        document_count=doc_count,
        sources=citations,
        timestamp=datetime.now(timezone.utc),
    )
    await persist_query_audit(
        settings.audit_db_path,
        query_id=query_id,
        action="summarise",
        user_id=payload.user_id,
        question=audit_question,
        collection_name=payload.collection_name,
        answer=summary,
        sources=citations,
        model_used=_model_used_label(settings),
        tokens_used=tokens_used,
        response_time_ms=elapsed_ms,
        kind="summarise",
    )
    return response
@router.post("/ask", response_model=AskQueryResponse)
async def ask_endpoint(payload: QueryRequest) -> AskQueryResponse:
    """Grounded question answering against a Chroma collection."""
    # Implementation notes live in comments: FastAPI uses the docstring above
    # as the public OpenAPI description of this route.
    settings = get_settings()
    try:
        return await _run_ask(settings, payload)
    except HTTPException:
        # Already carries a deliberate status code; re-wrapping it would turn
        # e.g. a 4xx into a generic 500.
        raise
    except Exception as exc:
        # Keep the full traceback server-side; the client only gets the message.
        logging.getLogger(__name__).exception("Grounded ask query failed")
        # NOTE(review): detail=str(exc) exposes internal error text to clients.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(exc),
        ) from exc
@router.post("/summarise", response_model=SummariseQueryResponse)
async def summarise_endpoint(payload: SummariseRequest) -> SummariseQueryResponse:
    """Collection-wide summary with optional focus for retrieval."""
    # Implementation notes live in comments: FastAPI uses the docstring above
    # as the public OpenAPI description of this route.
    settings = get_settings()
    try:
        return await _run_summarise(settings, payload)
    except HTTPException:
        # Preserve intentional status codes instead of masking them as a 500.
        raise
    except Exception as exc:
        # Keep the full traceback server-side; the client only gets the message.
        logging.getLogger(__name__).exception("Collection summarise query failed")
        # NOTE(review): detail=str(exc) exposes internal error text to clients.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(exc),
        ) from exc
# A second router with no prefix, so the legacy handler below is registered at
# ``/query`` itself rather than under ``/query/...``.
legacy_query_router = APIRouter(tags=["query"])


@legacy_query_router.post("/query", response_model=AskQueryResponse)
async def query_post_compat(payload: QueryRequest) -> AskQueryResponse:
    """Same behavior as POST /query/ask; kept for older clients and docs that used POST /query."""
    # Delegate to the primary handler so both paths share one implementation,
    # including its error translation into HTTP 500 responses.
    result = await ask_endpoint(payload)
    return result