| """Grounded Q&A and summarisation routes. |
| |
| ``POST /query/ask`` retrieves top-K chunks from Chroma, calls the configured LLM with |
| citations enforced in the prompt, persists an audit row, and returns answer + sources. |
| ``POST /query/summarise`` uses a retrieval-oriented query then a summary-focused prompt. |
| ``POST /query`` is a legacy alias for ``/query/ask``. |
| """ |
|
|
import asyncio
import logging
import time
from datetime import datetime, timezone
from uuid import uuid4
|
|
| from fastapi import APIRouter, HTTPException, status |
|
|
| from api.config import Settings, get_settings |
| from models.requests import QueryRequest, SummariseRequest |
| from models.responses import AskQueryResponse, SourceCitation, SummariseQueryResponse |
| from rag.embedder import create_embedding_function |
| from rag.retriever import ( |
| SUMMARY_RETRIEVAL_QUERY, |
| RetrievedChunk, |
| answer_with_grounding, |
| retrieve_chunks, |
| summarise_with_grounding, |
| ) |
| from rag.vector_store import collection_document_count, get_vector_store |
| from storage.audit_store import persist_query_audit |
|
|
# All grounded Q&A / summarisation endpoints are mounted under the ``/query`` prefix.
router = APIRouter(tags=["query"], prefix="/query")
|
|
|
|
def _model_used_label(settings: Settings) -> str:
    """Return the chat-model name configured for the active LLM provider.

    ``settings.llm_provider`` is matched case-insensitively. For a provider that
    is not one of the known four, the label ``"<provider>:unknown"`` is returned,
    using the lower-cased provider string.
    """
    provider = settings.llm_provider.lower()
    # Map each known provider to the Settings attribute holding its model name;
    # only the matching attribute is read.
    model_attribute = {
        "openai": "openai_model",
        "ollama": "ollama_chat_model",
        "anthropic": "anthropic_model",
        "huggingface": "huggingface_model",
    }.get(provider)
    if model_attribute is None:
        return f"{provider}:unknown"
    return getattr(settings, model_attribute)
|
|
|
|
def _chunks_to_citations(chunks: list[RetrievedChunk]) -> list[SourceCitation]:
    """Convert retrieved chunks into ``SourceCitation`` models, preserving order.

    Missing metadata is replaced with placeholders:
    - a ``None`` page becomes ``0``;
    - a ``None`` score becomes ``0.0`` (other scores are coerced with ``float``);
    - a falsy source becomes ``"unknown"``.
    """
    return [
        SourceCitation(
            document_name=item.source or "unknown",
            page_number=0 if item.page is None else item.page,
            chunk_text=item.text,
            relevance_score=0.0 if item.score is None else float(item.score),
        )
        for item in chunks
    ]
|
|
|
|
def _retrieve_and_answer(
    settings: Settings,
    question: str,
    collection_name: str,
    top_k: int,
) -> tuple:
    """Run the blocking part of the ask flow: open the collection, retrieve, call the LLM.

    Returns:
        ``(chunks, answer, tokens_used)`` as produced by ``retrieve_chunks`` and
        ``answer_with_grounding``.
    """
    embedding_function = create_embedding_function()
    vector_store = get_vector_store(
        persist_directory=settings.chroma_persist_directory,
        collection_name=collection_name,
        embedding_function=embedding_function,
    )
    chunks = retrieve_chunks(vector_store, question, top_k)
    answer, tokens_used = answer_with_grounding(settings, question, chunks)
    return chunks, answer, tokens_used


async def _run_ask(
    settings: Settings,
    payload: QueryRequest,
) -> AskQueryResponse:
    """Retrieve context, generate a grounded answer, audit it, and build the API response.

    Args:
        settings: Application settings (Chroma directory, LLM provider, audit DB path, ...).
        payload: The incoming ask request. An empty/missing ``collection_name``
            means ``"default"``. A ``None`` ``top_k`` falls back to
            ``settings.top_k_results``.

    Returns:
        The populated ``AskQueryResponse``. The same data is persisted first via
        ``persist_query_audit`` with ``action="query"`` and ``kind="ask"``.
    """
    collection_name = payload.collection_name or "default"
    top_k = payload.top_k if payload.top_k is not None else settings.top_k_results
    t0 = time.perf_counter()
    # The embedding / vector-store / LLM helpers are synchronous: their results
    # are used directly, without ``await``. Calling them inline here would stall
    # the event loop, and every other request, for the whole round-trip, so they
    # run in a worker thread.
    chunks, answer, tokens_used = await asyncio.to_thread(
        _retrieve_and_answer, settings, payload.question, collection_name, top_k
    )
    elapsed_ms = int((time.perf_counter() - t0) * 1000)
    citations = _chunks_to_citations(chunks)
    query_id = str(uuid4())
    response = AskQueryResponse(
        query_id=query_id,
        question=payload.question,
        answer=answer,
        sources=citations,
        model_used=_model_used_label(settings),
        tokens_used=tokens_used,
        response_time_ms=elapsed_ms,
        timestamp=datetime.now(timezone.utc),
    )
    await persist_query_audit(
        settings.audit_db_path,
        query_id=query_id,
        action="query",
        user_id=payload.user_id,
        question=payload.question,
        collection_name=collection_name,
        answer=answer,
        sources=citations,
        model_used=response.model_used,
        tokens_used=tokens_used,
        response_time_ms=elapsed_ms,
        kind="ask",
    )
    return response
|
|
|
|
def _retrieve_and_summarise(
    settings: Settings,
    collection_name: str,
    retrieval_query: str,
    focus: "str | None",
    top_k: int,
) -> tuple:
    """Run the blocking part of the summarise flow: open the collection, retrieve, call the LLM.

    Returns:
        ``(chunks, summary, tokens_used)`` as produced by ``retrieve_chunks`` and
        ``summarise_with_grounding``.
    """
    embedding_function = create_embedding_function()
    vector_store = get_vector_store(
        persist_directory=settings.chroma_persist_directory,
        collection_name=collection_name,
        embedding_function=embedding_function,
    )
    chunks = retrieve_chunks(vector_store, retrieval_query, top_k)
    summary, tokens_used = summarise_with_grounding(settings, focus=focus, chunks=chunks)
    return chunks, summary, tokens_used


async def _run_summarise(
    settings: Settings,
    payload: SummariseRequest,
) -> SummariseQueryResponse:
    """Retrieve with the focus (or a default overview query), summarise, and audit.

    Args:
        settings: Application settings. ``settings.top_k_results`` controls how
            many chunks are retrieved.
        payload: The incoming summarise request. A missing or whitespace-only
            ``focus`` is treated as "no focus" everywhere:
            - retrieval uses ``SUMMARY_RETRIEVAL_QUERY``;
            - the audit question is ``"Summarise collection"``;
            - the summariser receives ``None``.

    Returns:
        The populated ``SummariseQueryResponse``. The same data is persisted first
        via ``persist_query_audit`` with ``action="summarise"``.
    """
    top_k = settings.top_k_results
    focus = (payload.focus or "").strip()
    retrieval_query = focus or SUMMARY_RETRIEVAL_QUERY
    audit_question = focus or "Summarise collection"
    t0 = time.perf_counter()
    # The embedding / vector-store / LLM helpers are synchronous. Run them in a
    # worker thread so the event loop keeps serving other requests meanwhile.
    chunks, summary, tokens_used = await asyncio.to_thread(
        _retrieve_and_summarise,
        settings,
        payload.collection_name,
        retrieval_query,
        # Pass the normalised focus so a whitespace-only value is not treated
        # as a real focus by the summariser when retrieval already ignored it.
        focus or None,
        top_k,
    )
    elapsed_ms = int((time.perf_counter() - t0) * 1000)
    citations = _chunks_to_citations(chunks)
    # Also a blocking Chroma call; it stays outside the timed section, as before.
    doc_count = await asyncio.to_thread(
        collection_document_count, settings.chroma_persist_directory, payload.collection_name
    )
    query_id = str(uuid4())
    response = SummariseQueryResponse(
        query_id=query_id,
        summary=summary,
        document_count=doc_count,
        sources=citations,
        timestamp=datetime.now(timezone.utc),
    )
    await persist_query_audit(
        settings.audit_db_path,
        query_id=query_id,
        action="summarise",
        user_id=payload.user_id,
        question=audit_question,
        collection_name=payload.collection_name,
        answer=summary,
        sources=citations,
        model_used=_model_used_label(settings),
        tokens_used=tokens_used,
        response_time_ms=elapsed_ms,
        kind="summarise",
    )
    return response
|
|
|
|
@router.post("/ask", response_model=AskQueryResponse)
async def ask_endpoint(payload: QueryRequest) -> AskQueryResponse:
    """Answer a question grounded in a Chroma collection.

    Raises:
        HTTPException: Any ``HTTPException`` raised downstream is re-raised
            unchanged. Any other exception is logged with its traceback and
            converted to a 500 whose ``detail`` is the exception message.
    """
    settings = get_settings()
    try:
        return await _run_ask(settings, payload)
    except HTTPException:
        # Keep deliberate HTTP errors (status code + detail) intact instead of
        # masking them as generic 500s.
        raise
    except Exception as exc:
        # Once wrapped in an HTTPException, only the message reaches the client,
        # so record the full traceback server-side here.
        logging.getLogger(__name__).exception("Grounded ask request failed")
        # NOTE(review): str(exc) is sent to the client and may expose internal
        # details (paths, provider errors); consider a generic message.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)
        ) from exc
|
|
|
|
@router.post("/summarise", response_model=SummariseQueryResponse)
async def summarise_endpoint(payload: SummariseRequest) -> SummariseQueryResponse:
    """Summarise a collection, optionally steering retrieval with a focus.

    Raises:
        HTTPException: Any ``HTTPException`` raised downstream is re-raised
            unchanged. Any other exception is logged with its traceback and
            converted to a 500 whose ``detail`` is the exception message.
    """
    settings = get_settings()
    try:
        return await _run_summarise(settings, payload)
    except HTTPException:
        # Keep deliberate HTTP errors (status code + detail) intact instead of
        # masking them as generic 500s.
        raise
    except Exception as exc:
        # Once wrapped in an HTTPException, only the message reaches the client,
        # so record the full traceback server-side here.
        logging.getLogger(__name__).exception("Collection summarise request failed")
        # NOTE(review): str(exc) is sent to the client and may expose internal
        # details (paths, provider errors); consider a generic message.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)
        ) from exc
|
|
|
|
# Un-prefixed router so older clients can keep calling ``POST /query``.
legacy_query_router = APIRouter(tags=["query"])


@legacy_query_router.post("/query", response_model=AskQueryResponse)
async def query_post_compat(payload: QueryRequest) -> AskQueryResponse:
    """Legacy alias for ``POST /query/ask``, kept for older clients and docs.

    Delegates to ``ask_endpoint``, so responses and error handling are identical.
    """
    response = await ask_endpoint(payload)
    return response
|
|