"""API route definitions for the document assistant.""" import asyncio import json import logging import os import queue import threading from typing import TYPE_CHECKING from fastapi import APIRouter, HTTPException from fastapi.responses import StreamingResponse from pydantic import BaseModel if TYPE_CHECKING: from src.agent.router import QueryRouter from src.agent.plan_and_execute import PlanAndExecuteRouter from src.agent.session_store import SessionStore from src.config import Settings from src.ingestion.pipeline import IngestionPipeline from src.retrieval.bm25_search import BM25Search from src.retrieval.embedder import Embedder from src.retrieval.vector_store import VectorStore logger = logging.getLogger(__name__) router = APIRouter() def _is_rate_limit_error(exc: str | Exception) -> bool: """Check whether an exception indicates a rate-limit / quota error. Walks the full cause chain so wrapped exceptions (e.g. LangGraph wrapping an upstream 429) are still detected. """ texts: list[str] = [] if isinstance(exc, Exception): current: BaseException | None = exc while current is not None: texts.append(str(current)) texts.append(type(current).__name__) current = current.__cause__ else: texts.append(exc) blob = " ".join(texts).lower() return ( "429" in blob or "resource_exhausted" in blob or "rate limit" in blob or "rate_limit" in blob or "too many requests" in blob ) _query_router: "QueryRouter | PlanAndExecuteRouter | None" = None _ingestion_pipeline: "IngestionPipeline | None" = None _embedder: "Embedder | None" = None _vector_store: "VectorStore | None" = None _bm25_search: "BM25Search | None" = None _settings: "Settings | None" = None _session_store: "SessionStore | None" = None def set_dependencies( query_router: "QueryRouter | PlanAndExecuteRouter", ingestion_pipeline: "IngestionPipeline", embedder: "Embedder", vector_store: "VectorStore", bm25_search: "BM25Search", settings: "Settings", session_store: "SessionStore | None" = None, ) -> None: """Inject dependencies from the application factory. Args: query_router: Configured QueryRouter instance. ingestion_pipeline: Configured IngestionPipeline instance. embedder: Embedder instance for generating embeddings. vector_store: VectorStore instance for dense indexing. bm25_search: BM25Search instance for sparse indexing. settings: Application settings. session_store: Optional SessionStore for per-user conversation memory. """ global _query_router, _ingestion_pipeline, _embedder, _vector_store, _bm25_search, _settings, _session_store _query_router = query_router _ingestion_pipeline = ingestion_pipeline _embedder = embedder _vector_store = vector_store _bm25_search = bm25_search _settings = settings _session_store = session_store class QueryRequest(BaseModel): """Request body for the query endpoint.""" question: str top_k: int = 5 strategy: str = "recursive" session_id: str = "" class PipelineResultItem(BaseModel): """A single result item in pipeline details.""" document_id: str chunk_id: str score: float source: str metadata: dict[str, str | int] = {} class PipelineDetailsResponse(BaseModel): """Intermediate pipeline data for the query response.""" original_query: str = "" retrieval_query: str = "" detected_language: str = "" translated: bool = False dense_results: list[PipelineResultItem] = [] sparse_results: list[PipelineResultItem] = [] fused_results: list[PipelineResultItem] = [] reranked_results: list[PipelineResultItem] = [] plan_steps: list[str] = [] tool_calls: list[str] = [] class SourceItem(BaseModel): """A single source item in the query response.""" chunk_id: str document_id: str score: float source: str text: str = "" metadata: dict[str, str | int] = {} class QueryResponse(BaseModel): """Response body for the query endpoint.""" answer: str sources: list[SourceItem] intent: str confidence: float pipeline_details: PipelineDetailsResponse = PipelineDetailsResponse() class IngestRequest(BaseModel): """Request body for the document ingestion endpoint.""" file_path: str strategy: str = "recursive" class IngestResponse(BaseModel): """Response body for the document ingestion endpoint.""" document_id: str chunks_created: int class HealthResponse(BaseModel): """Response body for the health check endpoint.""" status: str version: str llm_provider: str = "" llm_model: str = "" embedding_provider: str = "" embedding_model: str = "" class ReadinessResponse(BaseModel): """Response body for the readiness probe.""" status: str checks: dict[str, bool] def _build_health_response() -> HealthResponse: """Build the full health response with provider details.""" llm_provider = "" llm_model = "" embedding_provider = "" embedding_model = "" if _settings is not None: llm_provider = _settings.llm_provider embedding_provider = _settings.embedding_provider embedding_model = _settings.embedding_model model_map = { "ollama": _settings.ollama_model, "openai": _settings.openai_model, "azure_openai": _settings.azure_openai_deployment, "bedrock": _settings.aws_bedrock_model, "groq": _settings.groq_model, "anthropic": _settings.anthropic_model, "google_genai": _settings.google_model, } llm_model = model_map.get(llm_provider, _settings.generation_model) return HealthResponse( status="ok", version="0.1.0", llm_provider=llm_provider, llm_model=llm_model, embedding_provider=embedding_provider, embedding_model=embedding_model, ) @router.get("/health", response_model=HealthResponse) async def health_check() -> HealthResponse: """Health check endpoint (backwards compatible). Returns: HealthResponse with service status and version. """ return _build_health_response() @router.get("/health/live", response_model=HealthResponse) async def liveness() -> HealthResponse: """Liveness probe. Returns 200 if the process is running. Kubernetes uses this to decide whether to restart the container. Does not check external dependencies. Returns: HealthResponse with service status and version. """ return _build_health_response() @router.get("/health/ready", response_model=ReadinessResponse) async def readiness() -> ReadinessResponse: """Readiness probe. Returns 200 only when all dependencies are available. Kubernetes uses this to decide whether to route traffic to the pod. Checks: vector store reachable, BM25 index loaded. Returns: ReadinessResponse with per-dependency check results. Raises: HTTPException: 503 if any dependency check fails. """ checks: dict[str, bool] = {} # Check vector store connectivity try: if _vector_store is not None: _vector_store.get_all_chunks()[:0] # lightweight probe checks["vector_store"] = True else: checks["vector_store"] = False except Exception: logger.warning("Readiness check failed: vector store unreachable") checks["vector_store"] = False # Check BM25 index is loaded checks["bm25_index"] = _bm25_search is not None and _bm25_search.is_indexed # Check router is wired up checks["router"] = _query_router is not None all_ready = all(checks.values()) if not all_ready: raise HTTPException(status_code=503, detail={"status": "unavailable", "checks": checks}) return ReadinessResponse(status="ready", checks=checks) @router.post("/query", response_model=QueryResponse) async def query_documents(request: QueryRequest) -> QueryResponse: """Query the document knowledge base. Args: request: Query parameters including question and retrieval settings. Returns: QueryResponse with generated answer and source documents. """ logger.info("Received query: %s (session=%s)", request.question, request.session_id[:8] if request.session_id else "none") # Resolve per-session memory (only used by PlanAndExecuteRouter) session_memory = None if request.session_id and _session_store is not None: session_memory = _session_store.get_memory(request.session_id) try: kwargs: dict = {"query": request.question, "top_k": request.top_k} if session_memory is not None and hasattr(_query_router, "_memory"): kwargs["memory"] = session_memory response = _query_router.route(**kwargs) except Exception as exc: exc_str = str(exc) if _is_rate_limit_error(exc): logger.warning("Rate limit / quota exhausted: %s", exc_str) raise HTTPException( status_code=429, detail="API quota temporarily exhausted. Please wait a moment and try again.", ) from exc raise # Persist the turn to SQLite (in-memory already updated by the router) if request.session_id and _session_store is not None: _session_store.persist_turn( request.session_id, request.question, response.answer, response.sources, ) sources = [result.to_dict() for result in response.sources] pd = response.pipeline_details pipeline_details = PipelineDetailsResponse( original_query=pd.original_query, retrieval_query=pd.retrieval_query, detected_language=pd.detected_language, translated=pd.translated, dense_results=[PipelineResultItem(**r.to_dict(include_text=False)) for r in pd.dense_results], sparse_results=[PipelineResultItem(**r.to_dict(include_text=False)) for r in pd.sparse_results], fused_results=[PipelineResultItem(**r.to_dict(include_text=False)) for r in pd.fused_results], reranked_results=[PipelineResultItem(**r.to_dict(include_text=False)) for r in pd.reranked_results], plan_steps=pd.plan_steps, tool_calls=pd.tool_calls, ) return QueryResponse( answer=response.answer, sources=sources, intent=response.intent.value, confidence=response.confidence, pipeline_details=pipeline_details, ) @router.post("/query/stream") async def query_stream(request: QueryRequest) -> StreamingResponse: """Stream pipeline progress events using Server-Sent Events (SSE). Each event is a JSON object with a ``step`` field naming the completed pipeline node, plus node-specific fields. The final event has ``step='done'`` and carries the full query result under ``result``. Args: request: Query parameters including question and retrieval settings. Returns: StreamingResponse with ``text/event-stream`` content type. """ event_queue: queue.Queue = queue.Queue() class _RateLimitLogHandler(logging.Handler): """Temporary handler that detects SDK-internal 429 retries via logs.""" _PATTERNS = ("429", "retrying request", "too many requests", "rate limit") def emit(self, record: logging.LogRecord) -> None: msg = record.getMessage().lower() if any(p in msg for p in self._PATTERNS): retry_sec = "" # Extract wait time from "Retrying request … in 5.000000 seconds" if "retrying" in msg and "seconds" in msg: for part in msg.split(): try: retry_sec = f" ({float(part):.0f}s)" break except ValueError: continue event_queue.put({ "step": "rate_limit", "message": f"API rate limit — retrying{retry_sec}", }) # Resolve per-session memory for streaming session_memory = None if request.session_id and _session_store is not None: session_memory = _session_store.get_memory(request.session_id) def _run() -> None: handler = _RateLimitLogHandler() handler.setLevel(logging.INFO) # Attach to root logger to catch openai/httpx/httpcore messages root_logger = logging.getLogger() root_logger.addHandler(handler) try: stream_kwargs: dict = {"query": request.question, "top_k": request.top_k} if session_memory is not None and hasattr(_query_router, "_memory"): stream_kwargs["memory"] = session_memory for event in _query_router.route_stream(**stream_kwargs): event_queue.put(event) # Persist turn to SQLite when streaming completes. # The router has already added the turn (with sources) to the # in-memory ConversationMemory before yielding `done`, so we # read sources back from there to keep the SQLite copy # consistent with the in-memory cache across restarts. if ( event.get("step") == "done" and request.session_id and _session_store is not None ): result = event.get("result", {}) persisted_sources = ( session_memory.last_sources() if session_memory else [] ) _session_store.persist_turn( request.session_id, request.question, result.get("answer", ""), persisted_sources, ) except Exception as exc: logger.error("Stream query failed: %s", exc, exc_info=True) exc_str = str(exc) if _is_rate_limit_error(exc): event_queue.put({"step": "error", "code": 429, "message": exc_str}) else: event_queue.put({"step": "error", "code": 500, "message": exc_str}) finally: root_logger.removeHandler(handler) event_queue.put(None) # sentinel threading.Thread(target=_run, daemon=True).start() async def _generate(): loop = asyncio.get_running_loop() while True: event = await loop.run_in_executor(None, event_queue.get) if event is None: break yield f"data: {json.dumps(event)}\n\n" return StreamingResponse(_generate(), media_type="text/event-stream") @router.post("/ingest", response_model=IngestResponse) async def ingest_document(request: IngestRequest) -> IngestResponse: """Ingest a new document into the knowledge base. Args: request: Ingestion parameters including file path and strategy. Returns: IngestResponse with document ID and number of chunks created. """ if not os.path.isfile(request.file_path): raise HTTPException(status_code=404, detail=f"File not found: {request.file_path}") logger.info("Ingesting document: %s", request.file_path) try: chunks = _ingestion_pipeline.ingest_pdf(request.file_path) if chunks: embeddings = _embedder.embed_batch([chunk.text for chunk in chunks]) _vector_store.add_chunks(chunks, embeddings) all_chunks = _vector_store.get_all_chunks() _bm25_search.index(all_chunks) except ValueError as exc: raise HTTPException(status_code=400, detail=str(exc)) from exc except Exception as exc: logger.error("Ingestion failed: %s", exc) raise HTTPException(status_code=500, detail="Document ingestion failed") from exc document_id = os.path.basename(request.file_path) logger.info("Ingested %d chunks for document %s", len(chunks), document_id) return IngestResponse( document_id=document_id, chunks_created=len(chunks), )