Spaces:
Running
Running
| """API route definitions for the document assistant.""" | |
| import asyncio | |
| import json | |
| import logging | |
| import os | |
| import queue | |
| import threading | |
| from typing import TYPE_CHECKING | |
| from fastapi import APIRouter, HTTPException | |
| from fastapi.responses import StreamingResponse | |
| from pydantic import BaseModel | |
| if TYPE_CHECKING: | |
| from src.agent.router import QueryRouter | |
| from src.agent.plan_and_execute import PlanAndExecuteRouter | |
| from src.agent.session_store import SessionStore | |
| from src.config import Settings | |
| from src.ingestion.pipeline import IngestionPipeline | |
| from src.retrieval.bm25_search import BM25Search | |
| from src.retrieval.embedder import Embedder | |
| from src.retrieval.vector_store import VectorStore | |
# Module-level logger, named after this module via __name__.
| logger = logging.getLogger(__name__) | |
# APIRouter instance for this module. NOTE(review): no route-registration
# decorators are visible in this view — confirm how handlers are attached.
| router = APIRouter() | |
| def _is_rate_limit_error(exc: str | Exception) -> bool: | |
| """Check whether an exception indicates a rate-limit / quota error. | |
| Walks the full cause chain so wrapped exceptions (e.g. LangGraph | |
| wrapping an upstream 429) are still detected. | |
| """ | |
| texts: list[str] = [] | |
| if isinstance(exc, Exception): | |
| current: BaseException | None = exc | |
| while current is not None: | |
| texts.append(str(current)) | |
| texts.append(type(current).__name__) | |
| current = current.__cause__ | |
| else: | |
| texts.append(exc) | |
| blob = " ".join(texts).lower() | |
| return ( | |
| "429" in blob | |
| or "resource_exhausted" in blob | |
| or "rate limit" in blob | |
| or "rate_limit" in blob | |
| or "too many requests" in blob | |
| ) | |
# Module-level dependency slots. Each starts as None and is assigned by
# set_dependencies(); the handlers in this file read them as globals.
# The annotations are strings because the referenced types are imported
# only under TYPE_CHECKING.
| _query_router: "QueryRouter | PlanAndExecuteRouter | None" = None | |
| _ingestion_pipeline: "IngestionPipeline | None" = None | |
| _embedder: "Embedder | None" = None | |
| _vector_store: "VectorStore | None" = None | |
| _bm25_search: "BM25Search | None" = None | |
| _settings: "Settings | None" = None | |
# Optional: stays None unless set_dependencies() receives a session_store.
| _session_store: "SessionStore | None" = None | |
def set_dependencies(
    query_router: "QueryRouter | PlanAndExecuteRouter",
    ingestion_pipeline: "IngestionPipeline",
    embedder: "Embedder",
    vector_store: "VectorStore",
    bm25_search: "BM25Search",
    settings: "Settings",
    session_store: "SessionStore | None" = None,
) -> None:
    """Store the application's collaborators in this module's globals.

    Called by the application factory; the route handlers in this module
    read the stored objects afterwards.

    Args:
        query_router: Configured QueryRouter instance.
        ingestion_pipeline: Configured IngestionPipeline instance.
        embedder: Embedder instance for generating embeddings.
        vector_store: VectorStore instance for dense indexing.
        bm25_search: BM25Search instance for sparse indexing.
        settings: Application settings.
        session_store: Optional SessionStore for per-user conversation memory.
    """
    global _query_router, _ingestion_pipeline, _embedder
    global _vector_store, _bm25_search
    global _settings, _session_store

    # Query / ingestion entry points.
    _query_router, _ingestion_pipeline = query_router, ingestion_pipeline
    # Retrieval components.
    _embedder, _vector_store, _bm25_search = embedder, vector_store, bm25_search
    # Configuration and (optional) conversation memory.
    _settings, _session_store = settings, session_store
# Request model accepted by both query_documents() and query_stream().
| class QueryRequest(BaseModel): | |
| """Request body for the query endpoint.""" | |
# The user's question; forwarded to the router as the ``query`` keyword.
| question: str | |
# Forwarded to the router as ``top_k``.
| top_k: int = 5 | |
# NOTE(review): not read by either query handler visible in this file.
| strategy: str = "recursive" | |
# Empty string means "no session": the handlers only look up and persist
# conversation memory when this value is non-empty.
| session_id: str = "" | |
# One retrieval hit without its text. query_documents() builds these from
# ``r.to_dict(include_text=False)`` for each pipeline stage.
| class PipelineResultItem(BaseModel): | |
| """A single result item in pipeline details.""" | |
| document_id: str | |
| chunk_id: str | |
| score: float | |
| source: str | |
# Metadata values are restricted to str or int by the annotation.
| metadata: dict[str, str | int] = {} | |
# Mirrors ``response.pipeline_details`` from the router; query_documents()
# copies each field across. Every field has a default, so the model can be
# constructed with no arguments (QueryResponse does this for its default).
| class PipelineDetailsResponse(BaseModel): | |
| """Intermediate pipeline data for the query response.""" | |
| original_query: str = "" | |
| retrieval_query: str = "" | |
| detected_language: str = "" | |
| translated: bool = False | |
# One list of hits per pipeline stage.
| dense_results: list[PipelineResultItem] = [] | |
| sparse_results: list[PipelineResultItem] = [] | |
| fused_results: list[PipelineResultItem] = [] | |
| reranked_results: list[PipelineResultItem] = [] | |
# Copied from ``pipeline_details.plan_steps`` / ``.tool_calls``.
| plan_steps: list[str] = [] | |
| tool_calls: list[str] = [] | |
# One entry of QueryResponse.sources. query_documents() supplies these as
# the dicts returned by ``result.to_dict()`` for each response source.
| class SourceItem(BaseModel): | |
| """A single source item in the query response.""" | |
| chunk_id: str | |
| document_id: str | |
| score: float | |
| source: str | |
# Unlike PipelineResultItem, a source item also carries a ``text`` field.
| text: str = "" | |
| metadata: dict[str, str | int] = {} | |
# Return model of query_documents().
| class QueryResponse(BaseModel): | |
| """Response body for the query endpoint.""" | |
| answer: str | |
| sources: list[SourceItem] | |
# query_documents() fills this from ``response.intent.value``.
| intent: str | |
| confidence: float | |
# Defaults to an all-defaults PipelineDetailsResponse when omitted.
| pipeline_details: PipelineDetailsResponse = PipelineDetailsResponse() | |
# Request model accepted by ingest_document().
| class IngestRequest(BaseModel): | |
| """Request body for the document ingestion endpoint.""" | |
# Path checked with os.path.isfile() and then passed to the ingestion
# pipeline's ingest_pdf().
| file_path: str | |
# NOTE(review): not read by ingest_document() as visible in this file.
| strategy: str = "recursive" | |
# Return model of ingest_document().
| class IngestResponse(BaseModel): | |
| """Response body for the document ingestion endpoint.""" | |
# ingest_document() sets this to the basename of the ingested file path.
| document_id: str | |
# Number of chunks returned by the ingestion pipeline for this document.
| chunks_created: int | |
# Return model of health_check() and liveness(); built by
# _build_health_response().
| class HealthResponse(BaseModel): | |
| """Response body for the health check endpoint.""" | |
| status: str | |
| version: str | |
# Provider/model fields stay empty strings when no settings were injected.
| llm_provider: str = "" | |
| llm_model: str = "" | |
| embedding_provider: str = "" | |
| embedding_model: str = "" | |
# Return model of readiness(); only returned when every check passed.
| class ReadinessResponse(BaseModel): | |
| """Response body for the readiness probe.""" | |
| status: str | |
# Maps check name to outcome; readiness() uses the keys "vector_store",
# "bm25_index" and "router".
| checks: dict[str, bool] | |
def _build_health_response() -> HealthResponse:
    """Assemble the health payload, including provider/model details.

    When no settings have been injected, all provider and model fields are
    reported as empty strings.

    Returns:
        HealthResponse with status "ok", the service version, and the
        configured LLM / embedding provider information.
    """
    details = {
        "llm_provider": "",
        "llm_model": "",
        "embedding_provider": "",
        "embedding_model": "",
    }

    cfg = _settings
    if cfg is not None:
        provider = cfg.llm_provider
        details["llm_provider"] = provider
        details["embedding_provider"] = cfg.embedding_provider
        details["embedding_model"] = cfg.embedding_model
        # Each provider keeps its model name under a different setting.
        model_by_provider = {
            "ollama": cfg.ollama_model,
            "openai": cfg.openai_model,
            "azure_openai": cfg.azure_openai_deployment,
            "bedrock": cfg.aws_bedrock_model,
            "groq": cfg.groq_model,
            "anthropic": cfg.anthropic_model,
            "google_genai": cfg.google_model,
        }
        # Unknown providers fall back to the generic generation model.
        details["llm_model"] = model_by_provider.get(provider, cfg.generation_model)

    return HealthResponse(status="ok", version="0.1.0", **details)
async def health_check() -> HealthResponse:
    """Report service health (kept for backwards compatibility).

    Returns:
        HealthResponse with service status, version and provider details.
    """
    payload = _build_health_response()
    return payload
async def liveness() -> HealthResponse:
    """Liveness probe: succeeds whenever the process can answer.

    Kubernetes uses this to decide whether to restart the container.
    No external dependency is contacted; the payload is the same one the
    health check returns.

    Returns:
        HealthResponse with service status and version.
    """
    payload = _build_health_response()
    return payload
async def readiness() -> ReadinessResponse:
    """Readiness probe. Returns 200 only when all dependencies are available.

    Kubernetes uses this to decide whether to route traffic to the pod.
    Checks: vector store reachable, BM25 index loaded, query router injected.

    Returns:
        ReadinessResponse with per-dependency check results.

    Raises:
        HTTPException: 503 if any dependency check fails; the detail carries
            the individual check results.
    """
    checks: dict[str, bool] = {}

    # Vector store: any exception raised by the probe counts as "not ready".
    try:
        if _vector_store is not None:
            # NOTE(review): this fetches the chunk list and discards it via
            # ``[:0]``; how cheap that is depends on get_all_chunks() — confirm.
            _vector_store.get_all_chunks()[:0]
            checks["vector_store"] = True
        else:
            checks["vector_store"] = False
    except Exception:
        # Include the traceback so the real cause is visible in the logs.
        logger.warning("Readiness check failed: vector store unreachable", exc_info=True)
        checks["vector_store"] = False

    # BM25 index: coerce to bool so the mapping really is dict[str, bool].
    checks["bm25_index"] = bool(_bm25_search is not None and _bm25_search.is_indexed)

    # Router must have been injected via set_dependencies().
    checks["router"] = _query_router is not None

    if not all(checks.values()):
        raise HTTPException(status_code=503, detail={"status": "unavailable", "checks": checks})
    return ReadinessResponse(status="ready", checks=checks)
async def query_documents(request: QueryRequest) -> QueryResponse:
    """Query the document knowledge base.

    Args:
        request: Query parameters including question and retrieval settings.

    Returns:
        QueryResponse with generated answer and source documents.

    Raises:
        HTTPException: 503 if no query router has been injected; 429 if the
            router fails with a rate-limit / quota error. Any other router
            exception propagates unchanged.
    """
    logger.info(
        "Received query: %s (session=%s)",
        request.question,
        request.session_id[:8] if request.session_id else "none",
    )

    # Without an injected router the call below would fail with an opaque
    # AttributeError; report the service as unavailable instead.
    if _query_router is None:
        raise HTTPException(status_code=503, detail="Query router is not initialised")

    # Resolve per-session memory (only used by PlanAndExecuteRouter)
    session_memory = None
    if request.session_id and _session_store is not None:
        session_memory = _session_store.get_memory(request.session_id)

    try:
        kwargs: dict = {"query": request.question, "top_k": request.top_k}
        # Only routers exposing a ``_memory`` attribute receive ``memory``.
        if session_memory is not None and hasattr(_query_router, "_memory"):
            kwargs["memory"] = session_memory
        # NOTE(review): route() is called synchronously inside this coroutine.
        response = _query_router.route(**kwargs)
    except Exception as exc:
        if _is_rate_limit_error(exc):
            logger.warning("Rate limit / quota exhausted: %s", str(exc))
            raise HTTPException(
                status_code=429,
                detail="API quota temporarily exhausted. Please wait a moment and try again.",
            ) from exc
        raise

    # Persist the turn to SQLite (in-memory already updated by the router)
    if request.session_id and _session_store is not None:
        _session_store.persist_turn(
            request.session_id,
            request.question,
            response.answer,
            response.sources,
        )

    def _stage_items(results) -> list[PipelineResultItem]:
        """Convert one pipeline stage's results into text-free items."""
        return [PipelineResultItem(**r.to_dict(include_text=False)) for r in results]

    sources = [result.to_dict() for result in response.sources]
    pd = response.pipeline_details
    pipeline_details = PipelineDetailsResponse(
        original_query=pd.original_query,
        retrieval_query=pd.retrieval_query,
        detected_language=pd.detected_language,
        translated=pd.translated,
        dense_results=_stage_items(pd.dense_results),
        sparse_results=_stage_items(pd.sparse_results),
        fused_results=_stage_items(pd.fused_results),
        reranked_results=_stage_items(pd.reranked_results),
        plan_steps=pd.plan_steps,
        tool_calls=pd.tool_calls,
    )
    return QueryResponse(
        answer=response.answer,
        sources=sources,
        intent=response.intent.value,
        confidence=response.confidence,
        pipeline_details=pipeline_details,
    )
async def query_stream(request: QueryRequest) -> StreamingResponse:
    """Stream pipeline progress events using Server-Sent Events (SSE).

    Each event is a JSON object with a ``step`` field naming the completed
    pipeline node, plus node-specific fields. The final event has
    ``step='done'`` and carries the full query result under ``result``.
    Failures are reported in-band as a ``step='error'`` event with a ``code``
    of 429 (rate limit) or 500 (anything else).

    The router runs in a background thread that pushes events onto a queue;
    the async generator drains that queue until it sees a ``None`` sentinel.

    Args:
        request: Query parameters including question and retrieval settings.

    Returns:
        StreamingResponse with ``text/event-stream`` content type.
    """
    event_queue: queue.Queue = queue.Queue()

    class _RateLimitLogHandler(logging.Handler):
        """Temporary handler that detects SDK-internal 429 retries via logs."""

        _PATTERNS = ("429", "retrying request", "too many requests", "rate limit")

        @staticmethod
        def _retry_suffix(msg: str) -> str:
            """Return `` (Ns)`` for the number directly before "seconds".

            Anchoring on the token before "seconds" avoids mistaking other
            numeric tokens in the message (such as the status code "429")
            for the wait time.
            """
            tokens = msg.split()
            for previous, token in zip(tokens, tokens[1:]):
                if not token.startswith("seconds"):
                    continue
                try:
                    return f" ({float(previous):.0f}s)"
                except ValueError:
                    continue
            return ""

        def emit(self, record: logging.LogRecord) -> None:
            # This handler sits on the root logger, so an exception here
            # would surface inside whichever library is logging; defer to
            # the standard handleError() hook instead.
            try:
                msg = record.getMessage().lower()
                if not any(p in msg for p in self._PATTERNS):
                    return
                retry_sec = ""
                # Extract wait time from "Retrying request … in 5.000000 seconds"
                if "retrying" in msg and "seconds" in msg:
                    retry_sec = self._retry_suffix(msg)
                event_queue.put({
                    "step": "rate_limit",
                    "message": f"API rate limit — retrying{retry_sec}",
                })
            except Exception:
                self.handleError(record)

    # Resolve per-session memory for streaming
    session_memory = None
    if request.session_id and _session_store is not None:
        session_memory = _session_store.get_memory(request.session_id)

    def _run() -> None:
        """Drive the router in a worker thread and feed the event queue."""
        handler = _RateLimitLogHandler()
        handler.setLevel(logging.INFO)
        # Attach to root logger to catch openai/httpx/httpcore messages
        root_logger = logging.getLogger()
        root_logger.addHandler(handler)
        try:
            stream_kwargs: dict = {"query": request.question, "top_k": request.top_k}
            if session_memory is not None and hasattr(_query_router, "_memory"):
                stream_kwargs["memory"] = session_memory
            for event in _query_router.route_stream(**stream_kwargs):
                event_queue.put(event)
                # Persist turn to SQLite when streaming completes.
                # The router has already added the turn (with sources) to the
                # in-memory ConversationMemory before yielding `done`, so we
                # read sources back from there to keep the SQLite copy
                # consistent with the in-memory cache across restarts.
                if (
                    event.get("step") == "done"
                    and request.session_id
                    and _session_store is not None
                ):
                    result = event.get("result", {})
                    persisted_sources = (
                        session_memory.last_sources() if session_memory else []
                    )
                    _session_store.persist_turn(
                        request.session_id,
                        request.question,
                        result.get("answer", ""),
                        persisted_sources,
                    )
        except Exception as exc:
            logger.error("Stream query failed: %s", exc, exc_info=True)
            exc_str = str(exc)
            if _is_rate_limit_error(exc):
                event_queue.put({"step": "error", "code": 429, "message": exc_str})
            else:
                event_queue.put({"step": "error", "code": 500, "message": exc_str})
        finally:
            # Always detach the handler and unblock the consumer.
            root_logger.removeHandler(handler)
            event_queue.put(None)  # sentinel

    threading.Thread(target=_run, daemon=True).start()

    async def _generate():
        """Yield queued events as SSE ``data:`` frames until the sentinel."""
        loop = asyncio.get_running_loop()
        while True:
            # Queue.get blocks, so wait for it in the default executor.
            event = await loop.run_in_executor(None, event_queue.get)
            if event is None:
                break
            yield f"data: {json.dumps(event)}\n\n"

    return StreamingResponse(_generate(), media_type="text/event-stream")
async def ingest_document(request: IngestRequest) -> IngestResponse:
    """Ingest a new document into the knowledge base.

    The file is chunked by the ingestion pipeline, the chunks are embedded
    and added to the vector store, and the BM25 index is rebuilt from all
    chunks in the vector store.

    Args:
        request: Ingestion parameters including file path and strategy.

    Returns:
        IngestResponse with document ID and number of chunks created.

    Raises:
        HTTPException: 404 if the path is not an existing file; 400 if
            ingestion raises ValueError; 429 if it fails with a rate-limit /
            quota error; 500 for any other failure.
    """
    # NOTE(review): file_path comes straight from the request body and is not
    # restricted to any directory here — confirm callers are trusted.
    if not os.path.isfile(request.file_path):
        raise HTTPException(status_code=404, detail=f"File not found: {request.file_path}")

    logger.info("Ingesting document: %s", request.file_path)
    try:
        chunks = _ingestion_pipeline.ingest_pdf(request.file_path)
        if chunks:
            embeddings = _embedder.embed_batch([chunk.text for chunk in chunks])
            _vector_store.add_chunks(chunks, embeddings)
            # The sparse index is rebuilt from the full chunk set.
            all_chunks = _vector_store.get_all_chunks()
            _bm25_search.index(all_chunks)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        # Report quota problems the same way the query endpoint does.
        if _is_rate_limit_error(exc):
            logger.warning("Rate limit / quota exhausted: %s", exc)
            raise HTTPException(
                status_code=429,
                detail="API quota temporarily exhausted. Please wait a moment and try again.",
            ) from exc
        # logger.exception records the traceback along with the message.
        logger.exception("Ingestion failed: %s", exc)
        raise HTTPException(status_code=500, detail="Document ingestion failed") from exc

    document_id = os.path.basename(request.file_path)
    logger.info("Ingested %d chunks for document %s", len(chunks), document_id)
    return IngestResponse(
        document_id=document_id,
        chunks_created=len(chunks),
    )