"""Grounded Q&A and summarisation routes. ``POST /query/ask`` retrieves top-K chunks from Chroma, calls the configured LLM with citations enforced in the prompt, persists an audit row, and returns answer + sources. ``POST /query/summarise`` uses a retrieval-oriented query then a summary-focused prompt. ``POST /query`` is a legacy alias for ``/query/ask``. """ import time from datetime import datetime, timezone from uuid import uuid4 from fastapi import APIRouter, HTTPException, status from api.config import Settings, get_settings from models.requests import QueryRequest, SummariseRequest from models.responses import AskQueryResponse, SourceCitation, SummariseQueryResponse from rag.embedder import create_embedding_function from rag.retriever import ( SUMMARY_RETRIEVAL_QUERY, RetrievedChunk, answer_with_grounding, retrieve_chunks, summarise_with_grounding, ) from rag.vector_store import collection_document_count, get_vector_store from storage.audit_store import persist_query_audit router = APIRouter(prefix="/query", tags=["query"]) def _model_used_label(settings: Settings) -> str: provider = settings.llm_provider.lower() if provider == "openai": return settings.openai_model if provider == "ollama": return settings.ollama_chat_model if provider == "anthropic": return settings.anthropic_model if provider == "huggingface": return settings.huggingface_model return f"{provider}:unknown" def _chunks_to_citations(chunks: list[RetrievedChunk]) -> list[SourceCitation]: citations: list[SourceCitation] = [] for chunk in chunks: page = chunk.page if chunk.page is not None else 0 score = float(chunk.score) if chunk.score is not None else 0.0 citations.append( SourceCitation( document_name=chunk.source or "unknown", page_number=page, chunk_text=chunk.text, relevance_score=score, ) ) return citations async def _run_ask( settings: Settings, payload: QueryRequest, ) -> AskQueryResponse: """Retrieve, generate grounded answer, audit, and build the API response.""" top_k = payload.top_k t0 = time.perf_counter() embedding_function = create_embedding_function() vector_store = get_vector_store( persist_directory=settings.chroma_persist_directory, collection_name=payload.collection_name or "default", embedding_function=embedding_function, ) chunks = retrieve_chunks(vector_store, payload.question, top_k) answer, tokens_used = answer_with_grounding(settings, payload.question, chunks) elapsed_ms = int((time.perf_counter() - t0) * 1000) citations = _chunks_to_citations(chunks) query_id = str(uuid4()) ts = datetime.now(timezone.utc) response = AskQueryResponse( query_id=query_id, question=payload.question, answer=answer, sources=citations, model_used=_model_used_label(settings), tokens_used=tokens_used, response_time_ms=elapsed_ms, timestamp=ts, ) await persist_query_audit( settings.audit_db_path, query_id=query_id, action="query", user_id=payload.user_id, question=payload.question, collection_name=payload.collection_name or "default", answer=answer, sources=citations, model_used=response.model_used, tokens_used=tokens_used, response_time_ms=elapsed_ms, kind="ask", ) return response async def _run_summarise( settings: Settings, payload: SummariseRequest, ) -> SummariseQueryResponse: """Retrieve with focus or default overview query, summarise, and audit.""" top_k = settings.top_k_results retrieval_query = (payload.focus or "").strip() or SUMMARY_RETRIEVAL_QUERY audit_question = payload.focus.strip() if payload.focus and payload.focus.strip() else "Summarise collection" t0 = time.perf_counter() embedding_function = create_embedding_function() vector_store = get_vector_store( persist_directory=settings.chroma_persist_directory, collection_name=payload.collection_name, embedding_function=embedding_function, ) chunks = retrieve_chunks(vector_store, retrieval_query, top_k) summary, tokens_used = summarise_with_grounding(settings, focus=payload.focus, chunks=chunks) elapsed_ms = int((time.perf_counter() - t0) * 1000) citations = _chunks_to_citations(chunks) doc_count = collection_document_count(settings.chroma_persist_directory, payload.collection_name) query_id = str(uuid4()) ts = datetime.now(timezone.utc) response = SummariseQueryResponse( query_id=query_id, summary=summary, document_count=doc_count, sources=citations, timestamp=ts, ) await persist_query_audit( settings.audit_db_path, query_id=query_id, action="summarise", user_id=payload.user_id, question=audit_question, collection_name=payload.collection_name, answer=summary, sources=citations, model_used=_model_used_label(settings), tokens_used=tokens_used, response_time_ms=elapsed_ms, kind="summarise", ) return response @router.post("/ask", response_model=AskQueryResponse) async def ask_endpoint(payload: QueryRequest) -> AskQueryResponse: """Grounded question answering against a Chroma collection.""" settings = get_settings() try: return await _run_ask(settings, payload) except Exception as exc: raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc @router.post("/summarise", response_model=SummariseQueryResponse) async def summarise_endpoint(payload: SummariseRequest) -> SummariseQueryResponse: """Collection-wide summary with optional focus for retrieval.""" settings = get_settings() try: return await _run_summarise(settings, payload) except Exception as exc: raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc legacy_query_router = APIRouter(tags=["query"]) @legacy_query_router.post("/query", response_model=AskQueryResponse) async def query_post_compat(payload: QueryRequest) -> AskQueryResponse: """Same behavior as POST /query/ask; kept for older clients and docs that used POST /query.""" return await ask_endpoint(payload)