Mayank Chugh
Deploy DocuAudit AI to Hugging Face Space (no binaries)
d44b33d
"""Grounded Q&A and summarisation routes.
``POST /query/ask`` retrieves top-K chunks from Chroma, calls the configured LLM with
citations enforced in the prompt, persists an audit row, and returns answer + sources.
``POST /query/summarise`` uses a retrieval-oriented query then a summary-focused prompt.
``POST /query`` is a legacy alias for ``/query/ask``.
"""
import asyncio
import logging
import time
from datetime import datetime, timezone
from uuid import uuid4

from fastapi import APIRouter, HTTPException, status

from api.config import Settings, get_settings
from models.requests import QueryRequest, SummariseRequest
from models.responses import AskQueryResponse, SourceCitation, SummariseQueryResponse
from rag.embedder import create_embedding_function
from rag.retriever import (
    SUMMARY_RETRIEVAL_QUERY,
    RetrievedChunk,
    answer_with_grounding,
    retrieve_chunks,
    summarise_with_grounding,
)
from rag.vector_store import collection_document_count, get_vector_store
from storage.audit_store import persist_query_audit
# Primary router: routes registered on it are served under the "/query" prefix
# (e.g. "/query/ask", "/query/summarise") and grouped under the "query" OpenAPI tag.
router = APIRouter(prefix="/query", tags=["query"])
def _model_used_label(settings: Settings) -> str:
provider = settings.llm_provider.lower()
if provider == "openai":
return settings.openai_model
if provider == "ollama":
return settings.ollama_chat_model
if provider == "anthropic":
return settings.anthropic_model
if provider == "huggingface":
return settings.huggingface_model
return f"{provider}:unknown"
def _chunks_to_citations(chunks: list[RetrievedChunk]) -> list[SourceCitation]:
    """Convert retrieved chunks into the citation models returned by the API.

    Placeholders fill in missing data: page ``0`` when ``page`` is ``None``,
    relevance ``0.0`` when ``score`` is ``None``, and ``"unknown"`` when the
    source is empty or missing. Input order is preserved.
    """
    return [
        SourceCitation(
            document_name=item.source or "unknown",
            page_number=0 if item.page is None else item.page,
            chunk_text=item.text,
            relevance_score=0.0 if item.score is None else float(item.score),
        )
        for item in chunks
    ]
def _retrieve_and_answer(
    settings: Settings,
    question: str,
    collection_name: str,
    top_k: int,
):
    """Do the blocking part of ``/query/ask``: open the collection, retrieve, call the LLM.

    Returns a ``(chunks, answer, tokens_used)`` tuple. It is meant to be run in a
    worker thread by ``_run_ask`` so the event loop stays free.
    """
    embedding_function = create_embedding_function()
    vector_store = get_vector_store(
        persist_directory=settings.chroma_persist_directory,
        collection_name=collection_name,
        embedding_function=embedding_function,
    )
    chunks = retrieve_chunks(vector_store, question, top_k)
    answer, tokens_used = answer_with_grounding(settings, question, chunks)
    return chunks, answer, tokens_used


async def _run_ask(
    settings: Settings,
    payload: QueryRequest,
) -> AskQueryResponse:
    """Retrieve chunks, generate a grounded answer, audit it, and build the API response.

    Args:
        settings: Application settings (Chroma directory, audit DB path, LLM
            provider/model names, default ``top_k_results``).
        payload: The question plus ``top_k``, optional ``collection_name``
            (defaults to ``"default"``) and ``user_id``.

    Returns:
        The answer with its source citations, model label, token usage,
        elapsed retrieval+generation time in milliseconds, and a UTC timestamp.
    """
    collection_name = payload.collection_name or "default"
    # Fall back to the server-wide default when the request leaves top_k unset.
    top_k = payload.top_k if payload.top_k is not None else settings.top_k_results
    t0 = time.perf_counter()
    # Embedding, retrieval and the LLM call are synchronous; running them directly
    # in this coroutine would stall every other request until the model responds.
    chunks, answer, tokens_used = await asyncio.to_thread(
        _retrieve_and_answer, settings, payload.question, collection_name, top_k
    )
    elapsed_ms = int((time.perf_counter() - t0) * 1000)
    citations = _chunks_to_citations(chunks)
    query_id = str(uuid4())
    ts = datetime.now(timezone.utc)
    response = AskQueryResponse(
        query_id=query_id,
        question=payload.question,
        answer=answer,
        sources=citations,
        model_used=_model_used_label(settings),
        tokens_used=tokens_used,
        response_time_ms=elapsed_ms,
        timestamp=ts,
    )
    await persist_query_audit(
        settings.audit_db_path,
        query_id=query_id,
        action="query",
        user_id=payload.user_id,
        question=payload.question,
        collection_name=collection_name,
        answer=answer,
        sources=citations,
        model_used=response.model_used,
        tokens_used=tokens_used,
        response_time_ms=elapsed_ms,
        kind="ask",
    )
    return response
def _retrieve_and_summarise(
    settings: Settings,
    collection_name: str,
    retrieval_query: str,
    focus: "str | None",
    top_k: int,
):
    """Do the blocking part of ``/query/summarise``: open the collection, retrieve, summarise.

    Returns a ``(chunks, summary, tokens_used)`` tuple. It is meant to be run in
    a worker thread by ``_run_summarise`` so the event loop stays free.
    """
    embedding_function = create_embedding_function()
    vector_store = get_vector_store(
        persist_directory=settings.chroma_persist_directory,
        collection_name=collection_name,
        embedding_function=embedding_function,
    )
    chunks = retrieve_chunks(vector_store, retrieval_query, top_k)
    summary, tokens_used = summarise_with_grounding(settings, focus=focus, chunks=chunks)
    return chunks, summary, tokens_used


async def _run_summarise(
    settings: Settings,
    payload: SummariseRequest,
) -> SummariseQueryResponse:
    """Summarise a collection, optionally around a focus, then audit the result.

    A non-blank ``payload.focus`` is used as the retrieval query. Otherwise
    ``SUMMARY_RETRIEVAL_QUERY`` is used and the audit row records
    ``"Summarise collection"`` as the question.

    Args:
        settings: Application settings (Chroma directory, audit DB path,
            ``top_k_results`` used as the retrieval depth).
        payload: Target ``collection_name``, optional ``focus`` and ``user_id``.

    Returns:
        The summary, the collection's document count, source citations and a
        UTC timestamp.
    """
    # Normalise once so retrieval, prompt and audit all agree on whether a focus
    # was given (a whitespace-only focus counts as none).
    focus = (payload.focus or "").strip()
    retrieval_query = focus or SUMMARY_RETRIEVAL_QUERY
    audit_question = focus or "Summarise collection"
    top_k = settings.top_k_results
    t0 = time.perf_counter()
    # The embedding/retrieval/LLM calls are synchronous; keep them off the event loop.
    chunks, summary, tokens_used = await asyncio.to_thread(
        _retrieve_and_summarise,
        settings,
        payload.collection_name,
        retrieval_query,
        focus or None,
        top_k,
    )
    elapsed_ms = int((time.perf_counter() - t0) * 1000)
    citations = _chunks_to_citations(chunks)
    # Counted outside the timed section, as before; still blocking, so threaded.
    doc_count = await asyncio.to_thread(
        collection_document_count,
        settings.chroma_persist_directory,
        payload.collection_name,
    )
    query_id = str(uuid4())
    ts = datetime.now(timezone.utc)
    response = SummariseQueryResponse(
        query_id=query_id,
        summary=summary,
        document_count=doc_count,
        sources=citations,
        timestamp=ts,
    )
    await persist_query_audit(
        settings.audit_db_path,
        query_id=query_id,
        action="summarise",
        user_id=payload.user_id,
        question=audit_question,
        collection_name=payload.collection_name,
        answer=summary,
        sources=citations,
        model_used=_model_used_label(settings),
        tokens_used=tokens_used,
        response_time_ms=elapsed_ms,
        kind="summarise",
    )
    return response
@router.post("/ask", response_model=AskQueryResponse)
async def ask_endpoint(payload: QueryRequest) -> AskQueryResponse:
    """Grounded question answering against a Chroma collection.
    \f
    Everything after the form feed above is hidden from the OpenAPI description.

    Raises:
        HTTPException: re-raised unchanged when raised downstream; any other
            exception is logged with its traceback and converted to a 500 whose
            ``detail`` is the exception message.
    """
    settings = get_settings()
    try:
        return await _run_ask(settings, payload)
    except HTTPException:
        # Preserve intentional status codes (e.g. 4xx) instead of masking them as 500.
        raise
    except Exception as exc:
        logging.getLogger(__name__).exception("POST /query/ask failed")
        # NOTE(review): str(exc) is returned to the client and may expose internals.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)
        ) from exc
@router.post("/summarise", response_model=SummariseQueryResponse)
async def summarise_endpoint(payload: SummariseRequest) -> SummariseQueryResponse:
    """Collection-wide summary with optional focus for retrieval.
    \f
    Everything after the form feed above is hidden from the OpenAPI description.

    Raises:
        HTTPException: re-raised unchanged when raised downstream; any other
            exception is logged with its traceback and converted to a 500 whose
            ``detail`` is the exception message.
    """
    settings = get_settings()
    try:
        return await _run_summarise(settings, payload)
    except HTTPException:
        # Preserve intentional status codes (e.g. 4xx) instead of masking them as 500.
        raise
    except Exception as exc:
        logging.getLogger(__name__).exception("POST /query/summarise failed")
        # NOTE(review): str(exc) is returned to the client and may expose internals.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)
        ) from exc
# Router with no prefix, so the historical POST /query path can still be served.
legacy_query_router = APIRouter(tags=["query"])


@legacy_query_router.post("/query", response_model=AskQueryResponse)
async def query_post_compat(payload: QueryRequest) -> AskQueryResponse:
    """Same behavior as POST /query/ask; kept for older clients and docs that used POST /query."""
    # Delegate to the primary endpoint so both paths share one implementation
    # (including its error handling).
    result = await ask_endpoint(payload)
    return result