"""FastAPI surface for SecureAgentRAG. Run with:: uv run uvicorn interfaces.api:app --host 0.0.0.0 --port 8080 Endpoints --------- - ``GET /healthz`` — liveness probe (no auth). - ``GET /readyz`` — readiness — pings Qdrant + Ollama. - ``POST /query`` — run the RAG pipeline; returns ``QueryResponse``. - ``POST /ingest`` — ingest a local file; requires ``user`` role. - ``GET /audit`` — read paginated audit entries; requires ``admin``. - ``POST /audit/verify``— verify the hash-chain; requires ``admin``. Auth uses a stateless bearer token. The token payload is a base64-encoded JSON ``UserContext`` so the API has no session store — caller provides identity on every request. Production deployments should swap this for Keycloak/Auth0 JWT verification (left as a hook in ``_resolve_user``). """ from __future__ import annotations import asyncio import base64 import contextlib import json from datetime import date from typing import Annotated from config.settings import settings from utils.auth import AuthError, issue_token, verify_token from utils.logging import get_logger logger = get_logger(__name__) try: from fastapi import Depends, FastAPI, Header, HTTPException, status from fastapi.responses import JSONResponse _FASTAPI_AVAILABLE = True except ImportError: # pragma: no cover _FASTAPI_AVAILABLE = False Depends = Header = FastAPI = HTTPException = JSONResponse = status = None # type: ignore[assignment] if _FASTAPI_AVAILABLE: from core.graph import run_rag_pipeline from core.schemas import ( IngestRequestModel, IngestResponseModel, QueryRequest, QueryResponse, ) from ingestion.metadata import IngestRequest, SensitivityLevel, UserContext from utils.audit import audit_logger from utils.health import run_health_checks from utils.rate_limiter import RateLimiter rate_limiter = RateLimiter() # uses default token-bucket config _AUTH_ERROR_STATUS: dict[str, int] = { "missing": status.HTTP_401_UNAUTHORIZED, "malformed": status.HTTP_401_UNAUTHORIZED, "expired": 
status.HTTP_401_UNAUTHORIZED, "bad_signature": status.HTTP_401_UNAUTHORIZED, "bad_claims": status.HTTP_403_FORBIDDEN, } def _resolve_user_full( authorization: Annotated[str | None, Header()] = None, ) -> tuple[UserContext, dict]: """Verify the bearer token and return (UserContext, claims). Delegates to :func:`utils.auth.verify_token`, which uses HS256 JWT when ``SAR_JWT_SECRET`` is set and falls back to the legacy unsigned base64 token otherwise (with a runtime warning). """ if not authorization or not authorization.lower().startswith("bearer "): raise HTTPException(status.HTTP_401_UNAUTHORIZED, "missing bearer token") token = authorization.split(" ", 1)[1] try: return verify_token(token) except AuthError as exc: code = _AUTH_ERROR_STATUS.get(exc.reason, status.HTTP_401_UNAUTHORIZED) raise HTTPException(code, f"auth_{exc.reason}: {exc}") from exc def _resolve_user(authorization: Annotated[str | None, Header()] = None) -> UserContext: """Backward-compatible dependency returning only the UserContext.""" ctx, _claims = _resolve_user_full(authorization=authorization) return ctx def _require_role(required: str): def _dep(user: Annotated[UserContext, Depends(_resolve_user)]) -> UserContext: if required not in user.roles and "admin" not in user.roles: raise HTTPException(status.HTTP_403_FORBIDDEN, f"role '{required}' required") return user return _dep from contextlib import asynccontextmanager @asynccontextmanager async def _lifespan(_app: FastAPI): """Start background jobs on boot, stop them cleanly on shutdown. - Periodic audit hash-chain verification (always; local-only). - BYOK session-collection purge (only when ``byok_mode`` is on). Both degrade gracefully when APScheduler is absent and never block startup on failure. 
""" schedulers: list = [] bg_tasks: list = [] # keep strong refs so created tasks aren't GC'd try: from utils.audit_verify import schedule_audit_verification s = schedule_audit_verification() if s is not None: schedulers.append(s) except Exception as exc: # pragma: no cover - defensive logger.error("audit_verify_schedule_failed", error=str(exc)) if settings.byok_mode: try: from retrieval.qdrant_client import QdrantManager from retrieval.session_purge import schedule_session_purge s = schedule_session_purge(QdrantManager().client) if s is not None: schedulers.append(s) except Exception as exc: # pragma: no cover - defensive logger.error("session_purge_schedule_failed", error=str(exc)) # Warm the local embedder off the request path. The first BYOK upload # would otherwise pay the ~770 MB sentence-transformers model load # (20-40 s on CPU Basic) inline, blowing the Vercel Edge 30 s proxy # budget and surfacing as a spurious "upload timed out" even for tiny # files. Loading it right after boot makes the first real upload fast. if settings.embedding_backend == "local": async def _warm_embedder() -> None: try: from retrieval.embeddings import _get_local_embedder await asyncio.to_thread(_get_local_embedder) logger.info("embedder_warmed_at_startup") except Exception as exc: # pragma: no cover - defensive logger.warning("embedder_warm_failed", error=str(exc)) with contextlib.suppress(Exception): bg_tasks.append(asyncio.create_task(_warm_embedder())) try: yield finally: for s in schedulers: with contextlib.suppress(Exception): s.shutdown(wait=False) app = FastAPI( title="SecureAgentRAG API", version="0.1.0", description="Privacy-first multi-agent RAG with RBAC, guardrails, and audit chain.", lifespan=_lifespan, ) # Initialize Phoenix tracing if configured. # When ``settings.byok_mode`` is on, ``setup_tracing`` short-circuits to # False regardless of phoenix_endpoint (see utils/observability.py). 
if _FASTAPI_AVAILABLE:
    # Guard re-stated so this section stays a no-op when FastAPI is absent
    # (``app`` only exists under the same condition).
    from utils.observability import setup_tracing

    _tracing_enabled = setup_tracing()
    if _tracing_enabled:
        logger.info("phoenix_tracing_active_in_api")

    # ── Prometheus metrics ───────────────────────────────────────────────
    # Aggregate counters/histograms only — no prompts, completions, keys, or
    # user text ever lands in a label, so this is safe to expose even under
    # BYOK (unlike Phoenix tracing, which is hard-disabled above). When the
    # ``[metrics]`` extra is installed we mount
    # prometheus-fastapi-instrumentator for HTTP-level metrics and let it
    # serve ``/metrics``; the custom RAG metrics in utils.metrics share the
    # same default registry so they appear in the same exposition. Without
    # the extra we fall back to a manual ``/metrics`` route that 501s until
    # prometheus_client is present.
    try:
        from prometheus_fastapi_instrumentator import Instrumentator

        _instrumented = Instrumentator(
            should_group_status_codes=True,
            excluded_handlers=["/metrics", "/healthz"],
        ).instrument(app)
        _instrumented.expose(app, endpoint="/metrics", include_in_schema=False)
        logger.info("prometheus_metrics_enabled", mode="instrumentator")
    except ImportError:
        from fastapi import Response

        from utils.metrics import render_latest

        @app.get("/metrics", include_in_schema=False, tags=["ops"])
        async def metrics() -> Response:
            """Serve the Prometheus exposition, or 501 when it cannot be rendered.

            ``render_latest`` signalling ``RuntimeError`` is answered with a
            plain-text 501 that tells the operator which extra to install.
            """
            try:
                payload, content_type = render_latest()
            except RuntimeError:
                return Response(
                    "prometheus_client not installed; install the [metrics] extra",
                    status_code=status.HTTP_501_NOT_IMPLEMENTED,
                    media_type="text/plain",
                )
            else:
                return Response(payload, media_type=content_type)

        logger.info("prometheus_metrics_enabled", mode="manual_fallback")

    # ── BYOK CORS middleware ─────────────────────────────────────────────
    # Only mount CORS when:
    #   1) BYOK mode is on (public demo path), AND
    #   2) an explicit allowlist is configured via SAR_CORS_ALLOW_ORIGINS.
    # Empty allowlist + BYOK = wildcard would be a footgun (CSRF surface).
    # Empty allowlist + dev = no CORS needed (local same-origin).
if _FASTAPI_AVAILABLE:
    # Guard re-stated so this section stays a no-op when FastAPI is absent
    # (``app`` only exists under the same condition).
    if settings.byok_mode and settings.cors_allow_origins:
        from fastapi.middleware.cors import CORSMiddleware

        # NOTE(review): only GET/POST/OPTIONS are allowed cross-origin —
        # confirm no browser-facing BYOK route needs another verb
        # (e.g. DELETE for uploads).
        app.add_middleware(
            CORSMiddleware,
            allow_origins=list(settings.cors_allow_origins),
            allow_credentials=False,  # BYOK never uses cookies
            allow_methods=["GET", "POST", "OPTIONS"],
            allow_headers=["*"],
        )
        logger.info("byok_cors_enabled", origins=list(settings.cors_allow_origins))

    @app.get("/healthz", tags=["ops"])
    async def healthz() -> dict[str, str]:
        """Liveness probe: always answers ``{"status": "ok"}``."""
        return {"status": "ok"}

    async def _probe_groq():
        """Ping Groq ``/models`` with the owner key; return a ``HealthStatus``.

        The entry is always marked ``optional=True``. It is healthy only on
        an HTTP 200; any other status, or any exception raised while making
        the request (including ``httpx`` being unavailable), yields an
        unhealthy entry whose message describes the failure.
        """
        import time

        from utils.health import HealthStatus

        started = time.perf_counter()
        try:
            import httpx

            async with httpx.AsyncClient(timeout=5.0) as client:
                resp = await client.get(
                    f"{settings.groq_api_base}/models",
                    headers={"Authorization": f"Bearer {settings.groq_api_key}"},
                )
            latency_ms = (time.perf_counter() - started) * 1000
            healthy = resp.status_code == 200
            message = "Owner key reachable." if healthy else f"HTTP {resp.status_code}"
        except Exception as exc:
            latency_ms = (time.perf_counter() - started) * 1000
            healthy = False
            message = f"Connection failed: {exc!s}"

        return HealthStatus(
            name="groq",
            healthy=healthy,
            latency_ms=latency_ms,
            message=message,
            optional=True,
        )

    @app.get("/readyz", tags=["ops"])
    async def readyz() -> JSONResponse:
        """Readiness probe: 200 when the health report is healthy, else 503.

        The response body is ``report.to_dict()`` in both cases.
        """
        # In BYOK production mode the deploy has no local Ollama -- pinging
        # it would always fail and the cron keepalive would mistake the
        # Space for down. Skip Ollama; ping Groq /models with the owner key
        # instead (when configured). Postgres + Redis remain optional.
        if settings.byok_mode:
            report = await run_health_checks(include_ollama=False)
            # Optionally surface Groq reachability when an owner key is set.
            if settings.groq_api_key:
                report.services.append(await _probe_groq())
        else:
            report = await run_health_checks()

        code = 200 if report.overall_healthy else 503
        return JSONResponse(report.to_dict(), status_code=code)

    # ── BYOK demo endpoint ───────────────────────────────────────────────
    # Mounted only when ``settings.byok_mode`` is on. Bypasses JWT auth and
    # uses per-request BYOK credentials instead. Isolation is enforced via
    # session-scoped Qdrant collections, not JWT identity.
    if settings.byok_mode:
        from inference.byok_context import (
            ByokRuntime,
            reset_byok_runtime,
            set_byok_runtime,
        )
        from interfaces.byok import ByokCreds, client_ip_from_request, extract_byok
        from utils.rate_limiter import get_owner_key_throttle

        def _byok_runtime_for(creds: ByokCreds) -> ByokRuntime | None:
            """Build the per-request BYOK runtime from creds, or None.

            Only returns a runtime when the visitor brought usable creds — so
            the visitor's own key powers the call. Otherwise None and the
            pipeline routes through the owner's cached clients (throttled).
            """
            if not creds.byok_active():
                return None
            return ByokRuntime(
                provider=creds.safe_provider(),
                user_key=creds.user_key,
                ollama_url=creds.ollama_url,
            )

        # All demo personas share ``org_id="demo"`` so they query the same
        # ingested corpus. RBAC differentiation is enforced via clearance
        # level + roles at the payload-filter layer -- exactly the production
        # invariant we want to demonstrate.
        _DEMO_ORG_ID = "demo"

        # Sensitivity levels are LOW=1, MEDIUM=2, HIGH=3 (see
        # ``ingestion/metadata.py::sensitivity_to_int``). Clearance levels must
        # be in the same range so the Qdrant range filter passes the right
        # chunks. Engineer < Compliance == Executive, but executive carries
        # a wider role set (sees both engineering + compliance content).
        # ``style`` is a short tone hint injected into the synthesizer's
        # system prompt so the three demo personas produce visibly distinct
        # answers from the same retrieved chunks.
        _DEMO_PERSONAS: dict[str, dict] = {
            "engineer": {
                "clearance_level": 2,
                "roles": ["engineering"],
                "style": (
                    "Write for a senior engineer. Use precise technical language, "
                    "name the underlying mechanism, and prefer concrete code "
                    "snippets or commands over abstract description. Skip the "
                    "executive summary -- this reader wants the wire-level detail."
                ),
            },
            "compliance": {
                "clearance_level": 3,
                "roles": ["compliance", "legal"],
                "style": (
                    "Write for a compliance / legal reviewer. Foreground the "
                    "regulatory citations, control IDs, and risk vocabulary. "
                    "Highlight gaps and required attestations. Hedge claims that "
                    "are not directly supported by an authoritative source."
                ),
            },
            "executive": {
                "clearance_level": 3,
                "roles": ["executive", "compliance", "engineering"],
                "style": (
                    "Write for a busy executive. Lead with the bottom line in "
                    "two sentences. Quantify impact in dollars / risk percentage / "
                    "headcount where the source supports it. Skip implementation "
                    "detail unless it changes the decision."
                ),
            },
        }

        def _persona_preset(creds: ByokCreds) -> dict | None:
            """Look up the preset for ``creds.demo_persona`` (case-insensitive)."""
            return _DEMO_PERSONAS.get((creds.demo_persona or "").lower())

        def _persona_to_user_ctx(creds: ByokCreds) -> UserContext:
            """Translate ``creds.demo_persona`` into a synthetic UserContext.

            Unknown / missing persona → minimal read-only profile so the demo
            still answers but cannot escalate beyond the lowest clearance.
            """
            preset = _persona_preset(creds)
            if preset is None:
                preset = {"clearance_level": 1, "roles": ["viewer"], "style": ""}
            return UserContext(
                user_id=f"demo-{creds.session_id}",
                org_id=_DEMO_ORG_ID,
                clearance_level=preset["clearance_level"],
                # Copy: the preset list is module-level state shared by every
                # request and must not be reachable through a UserContext.
                roles=list(preset["roles"]),
            )

        def _persona_style(creds: ByokCreds) -> str:
            """Return the persona's synthesizer tone hint, or ``""`` if none."""
            preset = _persona_preset(creds)
            return preset.get("style", "") if preset else ""

        # ── Public demo metadata endpoints (no auth, no BYOK key needed) ──
        # Lightweight read-only endpoints that power the public corpus +
        # personas + status pages in the frontend. They expose only
        # metadata that is already implied by the demo (filename, roles,
        # sensitivity, chunk counts) -- nothing that could leak content.
@app.get("/byok/personas", tags=["byok"]) async def byok_personas() -> dict: """Return the three preset RBAC personas + their synth styles. Single source of truth for the frontend ``/personas`` page so the UI never drifts from the actual server-side dispatch in ``_persona_to_user_ctx``. """ return { "items": [ { "key": key, "label": key.capitalize(), "clearance_level": preset["clearance_level"], "roles": list(preset["roles"]), "style": preset["style"], } for key, preset in _DEMO_PERSONAS.items() ], "default": "engineer", "org_id": _DEMO_ORG_ID, } @app.get("/byok/corpus", tags=["byok"]) async def byok_corpus() -> dict: """Summarise the base demo corpus -- source files + metadata. Scrolls the root tenant Qdrant collection (the hand-curated demo docs — English RBAC + Arabic Egypt) and groups points by ``source_file``. Returns one row per file with the chunk count, sensitivity label, and roles -- never the chunk text. Visitor uploads under ``documents_sess_`` are NOT included (those live in the session collection and are surfaced via ``/byok/uploads``). """ from core.agents.retriever import _get_hybrid_searcher files: dict[str, dict] = {} # Bound before the try so the response can always report the real # store even if the scroll fails early. collection = settings.qdrant_collection try: searcher = _get_hybrid_searcher() # The root manager (``for_org`` of the demo org) points at # the base demo collection. Multi-tenant mode wraps the # name as ``documents_demo``. qdrant = searcher._qdrant.for_org(_DEMO_ORG_ID) # type: ignore[attr-defined] client = qdrant.client collection = qdrant.collection_name next_offset = None pages = 0 # Scroll the whole collection in pages of 256. The Arabic # flagship corpus keeps growing, so a fixed page cap would # silently truncate the corpus page. Loop until the cursor is # exhausted; the 64-page ceiling (16k chunks) is a runaway # guard only. 
while pages < 64: points, next_offset = client.scroll( collection_name=collection, limit=256, with_payload=True, with_vectors=False, offset=next_offset, ) for p in points: payload = p.payload or {} src = payload.get("source_file") or "unknown" base_name = src.split("/")[-1].split("\\")[-1] roles = payload.get("roles") or [] sensitivity = payload.get("sensitivity_level") or "low" item = files.setdefault( src, { "source_file": base_name, "chunks": 0, "sensitivity_level": sensitivity, "roles": list(roles), }, ) item["chunks"] += 1 # Roles can vary across chunks of the same file in # principle; union them so the visitor sees the # widest access required to retrieve the file. for r in roles: if r not in item["roles"]: item["roles"].append(r) pages += 1 if not next_offset: break except Exception as exc: logger.warning("byok_corpus_list_failed", error=str(exc)) sorted_items = sorted(files.values(), key=lambda f: f["source_file"].lower()) return { # Report the real collection name (``documents`` in the default # demo, ``documents_demo`` under multi-tenant mode) instead of a # hardcoded literal so the corpus page never misreports the store. "collection": collection, "count": len(sorted_items), "total_chunks": sum(f["chunks"] for f in sorted_items), "items": sorted_items, } @app.get("/byok/stats", tags=["byok"]) async def byok_stats() -> dict: """Public, no-auth aggregate stats for the landing page. Two kinds of number, surfaced honestly: - **eval** — the rolling Ragas baseline shipped in ``evaluation/baseline.json`` (durable, committed in the repo). This is "proof, not claims": the demo's measured groundedness. - **live activity** — queries answered + documents grounded read from the audit log present on *this* instance. The HF Space writes audit JSONL to ephemeral ``/tmp``, so these reset when the Space sleeps/restarts; the frontend labels them "since the demo last woke" rather than implying an all-time total. No PII, no chunk text, no keys — safe to expose anonymously. 
""" eval_block: dict = {} try: import json as _json from pathlib import Path as _Path base = _Path(__file__).resolve().parent.parent / "evaluation" / "baseline.json" if base.exists(): data = _json.loads(base.read_text(encoding="utf-8")) eval_block = { "faithfulness": data.get("faithfulness"), "context_precision": data.get("context_precision"), "answer_relevancy": data.get("answer_relevancy"), "calibrated_at": data.get("_calibrated_at"), } except Exception as exc: logger.warning("byok_stats_eval_read_failed", error=str(exc)) queries_answered = 0 docs_grounded = 0 try: from datetime import timedelta end = date.today() start = end - timedelta(days=30) entries = audit_logger.get_entries( start_date=start.isoformat(), end_date=end.isoformat(), action="query", ) for e in entries: md = e.metadata or {} # Uploads also write action="query" rows; exclude them so # the counter reflects answered questions only. if md.get("action_hint") == "upload": continue if e.status == "success": queries_answered += 1 docs_grounded += int(md.get("documents_used", 0) or 0) except Exception as exc: logger.warning("byok_stats_audit_read_failed", error=str(exc)) return { "queries_answered": queries_answered, "docs_grounded": docs_grounded, "eval": eval_block, } from pydantic import BaseModel as _ByokBaseModel class _ByokChatBody(_ByokBaseModel): """Public-demo chat payload — no auth fields, only the question text.""" query: str prefer_cloud: bool = True # Runtime import — FastAPI dependency injection reads the annotation # at request time, so this must NOT be a TYPE_CHECKING-only import. from fastapi import Request as _FastApiRequest @app.post("/byok/chat", tags=["byok"]) async def byok_chat_endpoint( request: _FastApiRequest, body: _ByokChatBody, creds: Annotated[ByokCreds, Depends(extract_byok)], ) -> dict: """Public-demo chat endpoint backed by BYOK credentials. Routing: - Visitor brought a key (``creds.has_user_key()``): pipeline uses the visitor's provider + key. No throttle. 
- Visitor did NOT bring a key: pipeline falls back to the owner's configured cloud provider key, gated by ``OwnerKeyHourThrottle``. When exhausted, returns 429 with copy nudging BYOK. Persona maps to a synthetic ``UserContext`` so the existing RBAC filter still runs end-to-end — same code path as authenticated queries, just with demo identities. """ # Only a visitor with *usable* BYOK creds bypasses the throttle — # and that same key now actually powers the call (see the BYOK # runtime below). A bare/junk key with no usable provider no longer # skips the throttle while spending the owner key. if not creds.byok_active(): throttle = get_owner_key_throttle() client_ip = client_ip_from_request(request) ok, meta = throttle.allow(client_ip) if not ok: raise HTTPException( status.HTTP_429_TOO_MANY_REQUESTS, detail={ "reason": meta["reason"], "retry_after_seconds": meta["retry_after"], "hint": ( "Owner-key fallback exhausted for this IP. " "Paste your own LLM key to continue — your key " "is never stored server-side." ), }, ) user_ctx = _persona_to_user_ctx(creds) import time as _t _t0 = _t.perf_counter() # Bind the visitor's key/provider for THIS request so the inference # router builds a per-request client from it. The ContextVar # propagates into run_rag_pipeline and every LangGraph node/LLM call; # reset in finally so it never leaks to the next request. _byok_tok = set_byok_runtime(_byok_runtime_for(creds)) try: state = await run_rag_pipeline( query=body.query, user_context=user_ctx, thread_id=f"byok-{creds.session_id}", prefer_cloud=body.prefer_cloud, # Visitor's chosen provider when present; falls back to env. override_provider=creds.safe_provider(), persona_style=_persona_style(creds), byok_session_id=creds.session_id, ) finally: reset_byok_runtime(_byok_tok) elapsed_ms = (_t.perf_counter() - _t0) * 1000 response = QueryResponse.from_state(state) # Persist a single audit-log row so /byok/audit can surface the # session's history. 
utils.pii.redact strips key shapes from the # query/answer before write. try: audit_logger.log_query( user_id=user_ctx.user_id, org_id=user_ctx.org_id, query=body.query, response_summary=(response.answer or "")[:200], sensitivity=state.get("query_sensitivity", "low"), status="blocked" if response.blocked else "success", latency_ms=elapsed_ms, persona=creds.demo_persona or "anonymous", byok_used=creds.has_user_key(), synth_provider=response.provenance.provider, synth_model=response.provenance.model, faithfulness_ratio=state.get("faithfulness_ratio", 1.0), documents_used=len(state.get("relevant_documents", [])), ) except Exception as exc: # pragma: no cover -- defensive logger.exception("byok_audit_persist_failed", error=str(exc)) return { "session_id": creds.session_id, "persona": creds.demo_persona or "anonymous", "byok_used": creds.has_user_key(), "response": response.model_dump(mode="json"), # Extra payload surfaced for the new frontend UX -- raw audit # trail of nodes executed (used for the "trace pills" strip) # and the rewriter's output if the rewriter fired. "trace": [ {k: v for k, v in entry.items() if k in {"node", "action"}} for entry in state.get("audit_trail", []) ], "rewritten_query": state.get("rewritten_query", ""), "query_sensitivity": state.get("query_sensitivity", "low"), "faithfulness_ratio": float(state.get("faithfulness_ratio", 1.0)), "documents_seen_total": len(state.get("documents", [])), "documents_used_total": len(state.get("relevant_documents", [])), } # ── BYOK streaming endpoint (SSE) ──────────────────────────────── from fastapi.responses import StreamingResponse as _StreamingResponse from core.graph import run_rag_pipeline_stream @app.post("/byok/chat/stream", tags=["byok"]) async def byok_chat_stream_endpoint( request: _FastApiRequest, body: _ByokChatBody, creds: Annotated[ByokCreds, Depends(extract_byok)], ): """Server-Sent Events variant of ``/byok/chat``. 
Emits ``event: \\ndata: \\n\\n`` frames mirroring the event dicts produced by ``run_rag_pipeline_stream``: - ``phase`` — graph node fired (with merged state snapshot) - ``token`` — synthesizer streaming token - ``blocked`` — guardrails / security / timeout refusal - ``final`` — last state + total latency CORS is already mounted on the app when ``byok_mode`` is on. """ if not creds.byok_active(): throttle = get_owner_key_throttle() client_ip = client_ip_from_request(request) ok, meta = throttle.allow(client_ip) if not ok: raise HTTPException( status.HTTP_429_TOO_MANY_REQUESTS, detail={ "reason": meta["reason"], "retry_after_seconds": meta["retry_after"], "hint": ( "Owner-key fallback exhausted for this IP. " "Paste your own LLM key to continue — your key " "is never stored server-side." ), }, ) user_ctx = _persona_to_user_ctx(creds) async def _gen(): import time as _t _t0 = _t.perf_counter() # Bind the visitor's key/provider for the lifetime of this # stream so the synthesizer's streaming LLM call uses it. _byok_tok = set_byok_runtime(_byok_runtime_for(creds)) # Replay the session_id up front so the client can stitch # token deltas to a known turn without waiting for `final`. yield ( "event: open\n" f"data: {json.dumps({'session_id': creds.session_id, 'persona': creds.demo_persona or 'anonymous', 'byok_used': creds.has_user_key()})}\n\n" ) final_state: dict | None = None try: async for evt in run_rag_pipeline_stream( query=body.query, user_context=user_ctx, thread_id=f"byok-{creds.session_id}", prefer_cloud=body.prefer_cloud, override_provider=creds.safe_provider(), persona_style=_persona_style(creds), byok_session_id=creds.session_id, ): etype = evt.get("type", "unknown") # Reshape the heavy phase/final payload -- raw GraphState # is large; the frontend only needs the public-facing # bits. token frames pass through verbatim. 
if etype == "token": payload = {"text": evt.get("text", "")} elif etype in ("phase", "blocked", "final"): st = evt.get("state", {}) or {} if etype == "final": final_state = st payload = { "name": evt.get("name", ""), "message": evt.get("message", ""), "latency_ms": evt.get("latency_ms", 0.0), "rewritten_query": st.get("rewritten_query", ""), "query_sensitivity": st.get("query_sensitivity", "low"), "guardrails_passed": st.get("guardrails_passed", True), "security_passed": st.get("security_passed", True), "documents_seen_total": len(st.get("documents", [])), "documents_used_total": len(st.get("relevant_documents", [])), "faithfulness_ratio": float(st.get("faithfulness_ratio", 1.0)), "confidence_score": float(st.get("confidence_score", 0.0)), "synth_provider": st.get("synth_provider", ""), "synth_model": st.get("synth_model", ""), "synth_latency_ms": float(st.get("synth_latency_ms", 0.0)), "trace": [ {k: v for k, v in e.items() if k in {"node", "action"}} for e in st.get("audit_trail", []) ], } if etype == "final": # Include the full QueryResponse shape so the # frontend reaches parity with the non-stream # endpoint (citations + provenance + blocked). payload["response"] = QueryResponse.from_state(st).model_dump( mode="json" ) else: payload = evt yield f"event: {etype}\ndata: {json.dumps(payload)}\n\n" except Exception as exc: # pragma: no cover -- defensive logger.exception("byok_stream_failed", error=str(exc)) yield (f"event: error\ndata: {json.dumps({'message': 'stream_failed'})}\n\n") finally: # Always clear the per-request BYOK runtime so it never # leaks into the next request handled by this worker. reset_byok_runtime(_byok_tok) # Persist audit row at the end of the stream so /byok/audit # surfaces the session's history even when the visitor # disconnects before the final frame. 
if final_state is not None: elapsed_ms = (_t.perf_counter() - _t0) * 1000 resp = QueryResponse.from_state(final_state) try: audit_logger.log_query( user_id=user_ctx.user_id, org_id=user_ctx.org_id, query=body.query, response_summary=(resp.answer or "")[:200], sensitivity=final_state.get("query_sensitivity", "low"), status="blocked" if resp.blocked else "success", latency_ms=elapsed_ms, persona=creds.demo_persona or "anonymous", byok_used=creds.has_user_key(), synth_provider=resp.provenance.provider, synth_model=resp.provenance.model, faithfulness_ratio=final_state.get("faithfulness_ratio", 1.0), documents_used=len(final_state.get("relevant_documents", [])), stream=True, ) except Exception as exc: # pragma: no cover logger.exception("byok_stream_audit_persist_failed", error=str(exc)) return _StreamingResponse( _gen(), media_type="text/event-stream", headers={ "Cache-Control": "no-cache, no-transform", "X-Accel-Buffering": "no", "Connection": "keep-alive", }, ) # ── BYOK public audit export ───────────────────────────────────── @app.get("/byok/audit", tags=["byok"]) async def byok_audit_endpoint( request: _FastApiRequest, creds: Annotated[ByokCreds, Depends(extract_byok)], ) -> dict: """Return the last N PII-redacted audit entries for this demo. Public, session-scoped, capped by ``settings.byok_audit_max_entries``. The audit log file is shared across sessions inside the HF Space -- we filter by the visitor's demo user_id (``demo-``) so visitors only see their own turns. """ limit = max(0, int(settings.byok_audit_max_entries)) if limit == 0: raise HTTPException(status.HTTP_404_NOT_FOUND, detail="audit export disabled") today = date.today().isoformat() entries = audit_logger.get_entries( start_date=today, end_date=today, user_id=f"demo-{creds.session_id}", ) # Newest first, then cap. 
entries = list(reversed(entries))[:limit] return { "session_id": creds.session_id, "count": len(entries), "items": [e.model_dump(mode="json") for e in entries], } # ── BYOK answer feedback (👍/👎 → hash-chained audit row) ───────── class _ByokFeedbackBody(_ByokBaseModel): """Public-demo feedback payload — a rating on the last answer.""" rating: str query: str = "" answer_summary: str = "" @app.post("/byok/feedback", tags=["byok"]) async def byok_feedback_endpoint( request: _FastApiRequest, body: _ByokFeedbackBody, creds: Annotated[ByokCreds, Depends(extract_byok)], ) -> dict: """Record a thumbs-up/down on an answer as a session-scoped audit row. The rating lands on the same SHA-256 hash chain as every query, so it is itself tamper-evident and shows up in ``/byok/audit``. No LLM call, no throttle — it is a cheap write. """ rating = (body.rating or "").strip().lower() if rating not in ("up", "down"): raise HTTPException( status.HTTP_400_BAD_REQUEST, detail={"reason": "rating must be 'up' or 'down'"}, ) try: audit_logger.log_feedback( user_id=f"demo-{creds.session_id}", org_id=_DEMO_ORG_ID, rating=rating, query=body.query, answer_summary=(body.answer_summary or "")[:200], persona=creds.demo_persona or "anonymous", ) except Exception as exc: # pragma: no cover -- defensive logger.warning("byok_feedback_persist_failed", error=str(exc)) return {"session_id": creds.session_id, "rating": rating, "recorded": True} # ── BYOK upload endpoints ──────────────────────────────────────── from fastapi import File, UploadFile from qdrant_client.models import FieldCondition, Filter, MatchValue from core.agents.retriever import _get_hybrid_searcher from ingestion.metadata import IngestRequest, SensitivityLevel from ingestion.pipeline import IngestionPipeline def _session_qdrant_for_creds(creds: ByokCreds): """Return a QdrantManager bound to the visitor's session collection.""" searcher = _get_hybrid_searcher() return searcher._qdrant.for_session(creds.session_id) # type: 
ignore[attr-defined] def _list_session_uploads(creds: ByokCreds) -> list[dict]: """Group session-collection points by `source_file` -> upload rows.""" qdrant = _session_qdrant_for_creds(creds) client = qdrant.client collection = qdrant.collection_name files: dict[str, dict] = {} try: # Single scroll; the cap is 5 files * O(low chunks) so a single # page covers it. Limit at 1000 points is defensive only. points, _next = client.scroll( collection_name=collection, limit=1000, with_payload=True, with_vectors=False, ) for p in points: payload = p.payload or {} # Skip the purge sentinel point — it carries no source_file # and must never appear as a phantom upload row. if payload.get("__sentinel__"): continue src = payload.get("source_file") or "" # Reduce absolute paths down to a basename + sha so the # visitor sees the filename they uploaded, not the # container's tmp path. base_name = src.split("/")[-1].split("\\")[-1] item = files.setdefault( src, { "file_id": payload.get("source_file_id", base_name), "filename": base_name, "source_file": src, "chunks": 0, "first_ingested": payload.get("ingested_at"), }, ) item["chunks"] += 1 except Exception as exc: logger.warning("byok_uploads_list_failed", error=str(exc)) return list(files.values()) @app.get("/byok/uploads", tags=["byok"]) async def byok_uploads_list( request: _FastApiRequest, creds: Annotated[ByokCreds, Depends(extract_byok)], ) -> dict: uploads = _list_session_uploads(creds) return { "session_id": creds.session_id, "count": len(uploads), "max_files": settings.byok_upload_max_files, "max_bytes": settings.byok_upload_max_bytes, "max_chunks_per_file": settings.byok_upload_max_chunks_per_file, "allowed_extensions": list(settings.byok_upload_allowed_extensions), "items": uploads, } @app.post("/byok/uploads", tags=["byok"]) async def byok_uploads_ingest( request: _FastApiRequest, file: Annotated[UploadFile, File(...)], creds: Annotated[ByokCreds, Depends(extract_byok)], ) -> dict: """Accept a multipart upload from the 
BYOK visitor. The file is parsed by the existing ``ingestion.pipeline``, chunked, embedded, and upserted into the visitor's session-scoped Qdrant collection (``documents_sess_``). Visitor uploads are tagged ``org_id="demo"`` + roles=["viewer", ...all personas] + sensitivity_level=LOW so every demo persona can see them; the visitor's session collection is the isolation boundary. """ # ── 1. Validate ext + size ────────────────────────────────── filename = file.filename or "upload" ext = ("." + filename.rsplit(".", 1)[-1].lower()) if "." in filename else "" allowed = {e.lower() for e in settings.byok_upload_allowed_extensions} if ext not in allowed: raise HTTPException( status.HTTP_400_BAD_REQUEST, detail={ "reason": "unsupported_extension", "extension": ext, "allowed": sorted(allowed), }, ) # Read with a hard cap; abort on first byte over the limit. max_bytes = int(settings.byok_upload_max_bytes) buf = bytearray() while True: chunk = await file.read(64 * 1024) if not chunk: break buf.extend(chunk) if len(buf) > max_bytes: raise HTTPException( status.HTTP_413_CONTENT_TOO_LARGE, detail={ "reason": "file_too_large", "limit_bytes": max_bytes, }, ) if len(buf) == 0: raise HTTPException(status.HTTP_400_BAD_REQUEST, detail={"reason": "empty_file"}) # ── 2. Enforce per-session file-count cap ─────────────────── existing = _list_session_uploads(creds) if len(existing) >= int(settings.byok_upload_max_files): raise HTTPException( status.HTTP_409_CONFLICT, detail={ "reason": "upload_quota_exceeded", "max_files": settings.byok_upload_max_files, "hint": "Delete an existing upload first.", }, ) # ── 3. Spool to temp file -- the ingestion pipeline reads from disk. 
import os as _os import tempfile as _tempfile import uuid as _uuid from datetime import UTC as _UTC from datetime import datetime as _datetime file_id = _uuid.uuid4().hex safe_name = ( "".join(c if (c.isalnum() or c in "._-") else "_" for c in filename) or "upload" ) tmp_dir = _tempfile.mkdtemp(prefix=f"byok_{creds.session_id}_") tmp_path = _os.path.join(tmp_dir, safe_name) try: with open(tmp_path, "wb") as fh: fh.write(bytes(buf)) # ── 4. Build session-scoped pipeline + ingest ─────────── searcher = _get_hybrid_searcher() sess_qdrant = searcher._qdrant.for_session(creds.session_id) # type: ignore[attr-defined] pipeline = IngestionPipeline( qdrant_manager=sess_qdrant, embedding_service=searcher._embedder, # type: ignore[attr-defined] sparse_service=searcher._sparse, # type: ignore[attr-defined] ) req = IngestRequest( file_path=tmp_path, user_id=f"demo-{creds.session_id}", org_id=_DEMO_ORG_ID, sensitivity_level=SensitivityLevel.LOW, # Visible to every demo persona inside the visitor's session. roles=[ "viewer", "engineering", "compliance", "legal", "executive", ], ) result = await pipeline.ingest_document(req) # Reject documents that chunk too aggressively. Without this # cap a 100-page PDF dominates the dual-collection retrieval # and the grader/faithfulness loop times out under the 180 s # SLO. Delete what we just upserted and return 413 so the # visitor knows exactly what failed. 
max_chunks = int(settings.byok_upload_max_chunks_per_file) if max_chunks > 0 and result.num_chunks > max_chunks: try: if result.point_ids: sess_qdrant.client.delete( collection_name=sess_qdrant.collection_name, points_selector=list(result.point_ids), ) except Exception as exc: logger.warning( "byok_upload_oversize_cleanup_failed", error=str(exc), chunks=result.num_chunks, ) raise HTTPException( status.HTTP_413_CONTENT_TOO_LARGE, detail={ "reason": "too_many_chunks", "chunk_count": result.num_chunks, "max_chunks": max_chunks, "hint": ( f"This file chunks to {result.num_chunks} pieces " f"but the BYOK demo caps single uploads at {max_chunks}. " "Try a shorter document or a section instead." ), }, ) # Ingestion produced nothing usable (loader/parse failure, or a # file with no extractable text). Returning 200 with # status="failed" looked like success to the client and the file # silently never appeared. Surface it as a 422 so the UI shows a # real, immediate error instead of waiting on a phantom upload. if result.status == "failed" or not result.point_ids: raise HTTPException( status.HTTP_422_UNPROCESSABLE_ENTITY, detail={ "reason": "ingestion_failed", "errors": result.errors or ["No extractable text in the file."], "hint": ( "Could not extract text from this file. Scanned PDFs " "without a text layer and empty files won't ingest — " "try a text-based .txt, .md, or .pdf." ), }, ) # Tag every newly-upserted chunk with the file_id + ingested_at # so the list and delete endpoints can group by file. if result.point_ids: try: sess_qdrant.client.set_payload( collection_name=sess_qdrant.collection_name, payload={ "source_file_id": file_id, "ingested_at": _datetime.now(_UTC).isoformat(), "original_filename": safe_name, }, points=list(result.point_ids), ) except Exception as exc: logger.warning("byok_upload_set_payload_failed", error=str(exc)) # ── 5. 
Audit ──────────────────────────────────────────── try: audit_logger.log_query( user_id=f"demo-{creds.session_id}", org_id=_DEMO_ORG_ID, query=f"[upload] {safe_name}", response_summary=( f"ingested {result.num_chunks} chunks from {len(buf)} bytes" ), sensitivity="low", status=result.status, latency_ms=result.processing_time_seconds * 1000, action_hint="upload", file_id=file_id, filename=safe_name, chunks=result.num_chunks, ) except Exception as exc: logger.warning("byok_upload_audit_failed", error=str(exc)) return { "session_id": creds.session_id, "file_id": file_id, "filename": safe_name, "status": result.status, "chunks": result.num_chunks, "errors": result.errors, "processing_time_seconds": result.processing_time_seconds, } finally: # Always purge the temp file -- visitor content never lingers on disk. try: _os.remove(tmp_path) _os.rmdir(tmp_dir) except OSError: pass @app.delete("/byok/uploads/{file_id}", tags=["byok"]) async def byok_uploads_delete( request: _FastApiRequest, file_id: str, creds: Annotated[ByokCreds, Depends(extract_byok)], ) -> dict: qdrant = _session_qdrant_for_creds(creds) client = qdrant.client collection = qdrant.collection_name flt = Filter( must=[FieldCondition(key="source_file_id", match=MatchValue(value=file_id))] ) try: # Count before delete so we can report what was dropped. count_resp = client.count(collection_name=collection, count_filter=flt, exact=True) deleted = int(getattr(count_resp, "count", 0)) client.delete(collection_name=collection, points_selector=flt) except Exception as exc: raise HTTPException( status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"delete_failed: {exc!s}", ) from exc return { "session_id": creds.session_id, "file_id": file_id, "deleted_chunks": deleted, } # ── BYOK extraction mode (doc -> structured JSON) ──────────────── # The platform's second face next to RAG Q&A: upload a document + a # field schema, get back one validated JSON object. 
No retrieval, no # vector DB — parse -> one json_mode LLM call -> validate. Reuses the # inference router (visitor's BYOK key powers it; sensitivity routing # applies) and lands on the same audit chain. See ADR-041 / Tier X. from fastapi import Form @app.post("/byok/extract", tags=["byok"]) async def byok_extract( request: _FastApiRequest, file: Annotated[UploadFile, File(...)], fields: Annotated[str, Form(...)], creds: Annotated[ByokCreds, Depends(extract_byok)], ) -> dict: """Extract a caller-defined field schema from an uploaded document. ``fields`` is a JSON string: ``[{"name","type","description"}, ...]``. Returns ``{fields: {...}, model, provider, latency_ms}``. Same throttle / BYOK-runtime contract as ``/byok/chat``. """ from core.extraction import extract_fields, normalise_fields # Parse + validate the schema first (cheap, fail fast). try: raw_fields = json.loads(fields) if not isinstance(raw_fields, list): raise ValueError("fields must be a JSON array") schema = normalise_fields(raw_fields) except (json.JSONDecodeError, ValueError) as exc: raise HTTPException( status.HTTP_400_BAD_REQUEST, detail={"reason": "bad_schema", "error": str(exc)}, ) from exc # Throttle owner-key fallback exactly like chat. if not creds.byok_active(): throttle = get_owner_key_throttle() ok, meta = throttle.allow(client_ip_from_request(request)) if not ok: raise HTTPException( status.HTTP_429_TOO_MANY_REQUESTS, detail={ "reason": meta["reason"], "retry_after_seconds": meta["retry_after"], "hint": ( "Owner-key fallback exhausted for this IP. Paste " "your own LLM key to continue — never stored." ), }, ) # Validate ext + size (mirror the upload caps). filename = file.filename or "upload" ext = ("." + filename.rsplit(".", 1)[-1].lower()) if "." 
in filename else "" allowed = {e.lower() for e in settings.byok_upload_allowed_extensions} if ext not in allowed: raise HTTPException( status.HTTP_400_BAD_REQUEST, detail={"reason": "unsupported_extension", "allowed": sorted(allowed)}, ) max_bytes = int(settings.byok_upload_max_bytes) buf = bytearray() while True: chunk = await file.read(64 * 1024) if not chunk: break buf.extend(chunk) if len(buf) > max_bytes: raise HTTPException( status.HTTP_413_CONTENT_TOO_LARGE, detail={"reason": "file_too_large", "limit_bytes": max_bytes}, ) if not buf: raise HTTPException(status.HTTP_400_BAD_REQUEST, detail={"reason": "empty_file"}) # Spool + parse to text via the existing loaders. import os as _os import tempfile as _tempfile from ingestion.loaders import load_document safe_name = ( "".join(c if (c.isalnum() or c in "._-") else "_" for c in filename) or "upload" ) tmp_dir = _tempfile.mkdtemp(prefix=f"byok_extract_{creds.session_id}_") tmp_path = _os.path.join(tmp_dir, safe_name) _t0 = __import__("time").perf_counter() try: with open(tmp_path, "wb") as fh: fh.write(bytes(buf)) try: docs = await asyncio.to_thread(load_document, tmp_path) except Exception as exc: raise HTTPException( status.HTTP_422_UNPROCESSABLE_ENTITY, detail={"reason": "parse_failed", "error": str(exc)}, ) from exc text = "\n\n".join(d.text for d in docs if d.text).strip() if not text: raise HTTPException( status.HTTP_422_UNPROCESSABLE_ENTITY, detail={ "reason": "no_text", "hint": "No extractable text (scanned image PDFs need OCR).", }, ) _byok_tok = set_byok_runtime(_byok_runtime_for(creds)) try: result = await extract_fields( text, schema, prefer_cloud=True, sensitivity_level="low" ) finally: reset_byok_runtime(_byok_tok) finally: try: _os.remove(tmp_path) _os.rmdir(tmp_dir) except OSError: pass elapsed_ms = (__import__("time").perf_counter() - _t0) * 1000 try: audit_logger.log_query( user_id=f"demo-{creds.session_id}", org_id=_DEMO_ORG_ID, query=f"[extract] {safe_name} ({len(schema)} fields)", 
response_summary=f"extracted {len(result['fields'])} fields", sensitivity="low", status="success", latency_ms=elapsed_ms, action_hint="extract", byok_used=creds.has_user_key(), synth_provider=result["provider"], synth_model=result["model"], ) except Exception as exc: # pragma: no cover - defensive logger.warning("byok_extract_audit_failed", error=str(exc)) return { "session_id": creds.session_id, "filename": safe_name, "byok_used": creds.has_user_key(), "fields": result["fields"], "provider": result["provider"], "model": result["model"], "latency_ms": elapsed_ms, } @app.post("/query", response_model=QueryResponse, tags=["rag"]) async def query_endpoint( body: QueryRequest, auth: Annotated[tuple[UserContext, dict], Depends(_resolve_user_full)], ) -> QueryResponse: user, claims = auth if not rate_limiter.is_allowed(f"{user.user_id}:query"): raise HTTPException(status.HTTP_429_TOO_MANY_REQUESTS, "rate limit exceeded") # Caller-supplied user_id must match the bearer-token identity. if body.user_id != user.user_id: raise HTTPException(status.HTTP_403_FORBIDDEN, "user_id mismatch") # Use the JWT id so the audit trail can correlate a query with the # exact token that authorised it; useful for revocation forensics. 
jti = claims.get("jti", "unsigned") state = await run_rag_pipeline( query=body.query, user_context=user, thread_id=f"api-{user.user_id}-{jti}", prefer_cloud=body.prefer_cloud, override_provider=body.override_provider, ) return QueryResponse.from_state(state) @app.post("/ingest", response_model=IngestResponseModel, tags=["rag"]) async def ingest_endpoint( body: IngestRequestModel, user: Annotated[UserContext, Depends(_require_role("user"))], ) -> IngestResponseModel: if body.user_id != user.user_id: raise HTTPException(status.HTTP_403_FORBIDDEN, "user_id mismatch") from core.agents.retriever import _get_hybrid_searcher from ingestion.pipeline import IngestionPipeline searcher = _get_hybrid_searcher() pipeline = IngestionPipeline( qdrant_manager=searcher._qdrant, # type: ignore[attr-defined] embedding_service=searcher._embedder, # type: ignore[attr-defined] sparse_service=searcher._sparse, # type: ignore[attr-defined] ) req = IngestRequest( file_path=body.file_path, user_id=body.user_id, org_id=body.org_id, sensitivity_level=SensitivityLevel(body.sensitivity_level), roles=body.roles, ) result = await pipeline.ingest_document(req) return IngestResponseModel( file_path=result.file_path, status=result.status, num_chunks=result.num_chunks, point_ids=result.point_ids, errors=result.errors, processing_time_seconds=result.processing_time_seconds, ) @app.get("/audit", tags=["audit"]) async def audit_list( user: Annotated[UserContext, Depends(_require_role("admin"))], start: str | None = None, end: str | None = None, limit: int = 100, ) -> dict: today = date.today().isoformat() entries = audit_logger.get_entries( start_date=start or today, end_date=end or today, user_id=None, action=None, ) return { "total": len(entries), "items": [e.model_dump(mode="json") for e in entries[:limit]], } @app.post("/audit/verify", tags=["audit"]) async def audit_verify( user: Annotated[UserContext, Depends(_require_role("admin"))], start: str | None = None, end: str | None = None, ) -> dict: 
result = audit_logger.verify_chain(start_date=start, end_date=end) return result from pydantic import BaseModel as _PydBM class _TokenRequest(_PydBM): """Identity payload accepted by the dev ``/token`` endpoint.""" user_id: str org_id: str = "" roles: list[str] = [] clearance_level: int = 1 ttl_seconds: int | None = None class _TokenResponse(_PydBM): access_token: str token_type: str = "bearer" expires_in: int @app.post("/token", response_model=_TokenResponse, tags=["auth"]) async def issue_dev_token(body: _TokenRequest) -> _TokenResponse: """Mint a signed JWT for local testing. In production the IdP (Keycloak / Auth0 / Microsoft Entra) issues the token externally and this endpoint is removed via the ``SAR_DISABLE_DEV_TOKEN`` flag — kept here so the e2e smoke script and the Streamlit demo can mint a real token rather than the unsigned base64 fallback. """ if settings.disable_dev_token: raise HTTPException( status.HTTP_404_NOT_FOUND, "Dev token endpoint disabled (SAR_DISABLE_DEV_TOKEN=true)", ) if settings.jwt_algorithm.upper() == "RS256": raise HTTPException( status.HTTP_404_NOT_FOUND, "Dev token endpoint disabled in RS256 mode — use the external IdP", ) if not settings.jwt_secret: raise HTTPException( status.HTTP_503_SERVICE_UNAVAILABLE, "SAR_JWT_SECRET is not configured; token endpoint disabled", ) try: token = issue_token( user_id=body.user_id, org_id=body.org_id, roles=body.roles, clearance_level=body.clearance_level, ttl_seconds=body.ttl_seconds, ) except AuthError as exc: raise HTTPException( status.HTTP_500_INTERNAL_SERVER_ERROR, f"token_issue_{exc.reason}: {exc}" ) from exc return _TokenResponse( access_token=token, token_type="bearer", expires_in=body.ttl_seconds or settings.jwt_ttl_seconds, ) else: # pragma: no cover app = None # type: ignore[assignment] def mint_dev_token(user: dict) -> str: """Convenience for local testing — build a bearer token for a UserContext dict. When ``SAR_JWT_SECRET`` is configured this mints a real signed JWT. 
With no secret it emits the legacy unsigned base64 shape *only* when ``SAR_ALLOW_UNSIGNED_TOKENS`` is on (matching the verifier's fail-closed policy); otherwise it raises so callers are forced to configure auth. """ if settings.jwt_secret: try: return issue_token( user_id=user.get("user_id", ""), org_id=user.get("org_id", ""), roles=list(user.get("roles", [])), clearance_level=int(user.get("clearance_level", 1)), ) except AuthError: # Fall through to legacy shape on issuer error (only if allowed). pass if not settings.allow_unsigned_tokens: raise AuthError( "missing", "cannot mint a token: set SAR_JWT_SECRET, or SAR_ALLOW_UNSIGNED_TOKENS=true for dev", ) payload = json.dumps(user).encode("utf-8") return base64.b64encode(payload).decode("ascii")