Spaces:
Running
Running
deploy: phase 3 BYOK backend (Dockerfile.hf, FastAPI on 7860)
Browse files- Dockerfile.hf +151 -151
- config/settings.py +316 -316
- inference/cloud_clients.py +577 -577
- inference/ollama_client.py +334 -334
- interfaces/api.py +432 -425
- interfaces/byok.py +166 -166
- retrieval/multitenancy.py +43 -43
- retrieval/session_purge.py +185 -185
- utils/observability.py +252 -252
- utils/pii.py +146 -146
- utils/rate_limiter.py +524 -524
Dockerfile.hf
CHANGED
|
@@ -1,151 +1,151 @@
|
|
| 1 |
-
# =============================================================================
|
| 2 |
-
# Dockerfile.hf — SecureAgentRAG backend for Hugging Face Spaces (CPU Basic).
|
| 3 |
-
# =============================================================================
|
| 4 |
-
# Two-stage build keeps the runtime image lean. The HF Space free tier is
|
| 5 |
-
# CPU-only with 16 GB RAM and ~50 GB ephemeral disk, so we target a tight
|
| 6 |
-
# memory footprint:
|
| 7 |
-
#
|
| 8 |
-
# - Python 3.11-slim base (~150 MB)
|
| 9 |
-
# - Only [api, embeddings-local, pii] extras (no OCR, no Phoenix, no Postgres,
|
| 10 |
-
# no Redis, no MCP) -- those modules are present in the source but their
|
| 11 |
-
# dependencies are not installed
|
| 12 |
-
# - cross-encoder reranker downloaded on first request (auto-cached under
|
| 13 |
-
# /home/user/.cache/huggingface). Skips the 2.3 GB fine-tuned checkpoint
|
| 14 |
-
# for the initial deploy; phase 3.2 can swap to fine_tuned once the
|
| 15 |
-
# reranker repo is published on HF Hub.
|
| 16 |
-
#
|
| 17 |
-
# The Space-side README.md is uploaded separately by scripts/deploy_hf_space.py
|
| 18 |
-
# with a YAML frontmatter declaring sdk=docker + app_port=7860.
|
| 19 |
-
# =============================================================================
|
| 20 |
-
|
| 21 |
-
# --- builder ----------------------------------------------------------------
|
| 22 |
-
FROM python:3.11-slim AS builder
|
| 23 |
-
|
| 24 |
-
WORKDIR /app
|
| 25 |
-
|
| 26 |
-
RUN pip install --no-cache-dir uv
|
| 27 |
-
|
| 28 |
-
# pyproject.toml + a copy of the source are required for uv to build the
|
| 29 |
-
# editable install. README.md is referenced as the long_description.
|
| 30 |
-
COPY pyproject.toml ./
|
| 31 |
-
COPY README.md ./
|
| 32 |
-
|
| 33 |
-
# Touch the package directories that hatchling treats as the wheel root --
|
| 34 |
-
# we only need the directory tree to exist at build time so hatchling can
|
| 35 |
-
# scan for __init__.py files. The actual code lands in the runtime stage.
|
| 36 |
-
RUN mkdir -p config core inference retrieval interfaces ingestion utils evaluation app \
|
| 37 |
-
&& touch config/__init__.py core/__init__.py inference/__init__.py \
|
| 38 |
-
&& touch retrieval/__init__.py interfaces/__init__.py ingestion/__init__.py \
|
| 39 |
-
&& touch utils/__init__.py evaluation/__init__.py app/__init__.py
|
| 40 |
-
|
| 41 |
-
# Intentionally skip [pii] extras -- the regex patterns in utils/pii.py
|
| 42 |
-
# already cover every BYOK key shape (Groq / OpenAI / Anthropic / HF / Vercel
|
| 43 |
-
# / Qdrant JWT / Qdrant management). Adding Presidio would pull spaCy
|
| 44 |
-
# en_core_web_lg (~770 MB) which auto-downloads at runtime and crashes the
|
| 45 |
-
# container on the CPU Basic Space when the package installer is absent.
|
| 46 |
-
RUN uv venv /app/.venv \
|
| 47 |
-
&& uv pip install --python /app/.venv/bin/python \
|
| 48 |
-
-e ".[api,embeddings-local]"
|
| 49 |
-
|
| 50 |
-
# --- runtime ----------------------------------------------------------------
|
| 51 |
-
FROM python:3.11-slim AS runtime
|
| 52 |
-
|
| 53 |
-
WORKDIR /app
|
| 54 |
-
|
| 55 |
-
# HF Spaces convention: run as uid 1000 with a writeable /home/user.
|
| 56 |
-
RUN useradd -m -u 1000 user
|
| 57 |
-
|
| 58 |
-
# System deps for PDF / image processing only -- no OCR / paddle.
|
| 59 |
-
# Debian 12+ (trixie) renamed libgl1-mesa-glx -> libgl1 and libxrender-dev
|
| 60 |
-
# is no longer needed at runtime (runtime is libxrender1).
|
| 61 |
-
RUN apt-get update \
|
| 62 |
-
&& apt-get install -y --no-install-recommends \
|
| 63 |
-
libglib2.0-0 libsm6 libxext6 libxrender1 libgl1 curl \
|
| 64 |
-
&& rm -rf /var/lib/apt/lists/*
|
| 65 |
-
|
| 66 |
-
# Bring the virtualenv from the builder stage.
|
| 67 |
-
COPY --from=builder /app/.venv /app/.venv
|
| 68 |
-
ENV PATH="/app/.venv/bin:$PATH"
|
| 69 |
-
|
| 70 |
-
# Copy application source. Files that match .dockerignore are filtered out.
|
| 71 |
-
COPY --chown=user:user . /app
|
| 72 |
-
|
| 73 |
-
USER user
|
| 74 |
-
|
| 75 |
-
# Pre-populate the HF cache so the cross-encoder lives on disk before the
|
| 76 |
-
# first request. Defensive: never fails the build -- if HF Hub is unreachable
|
| 77 |
-
# during build (offline mirrors etc.) the cache is populated on first query.
|
| 78 |
-
RUN python -c "import os; \
|
| 79 |
-
from huggingface_hub import snapshot_download; \
|
| 80 |
-
import sys; \
|
| 81 |
-
try: snapshot_download(repo_id='BAAI/bge-reranker-v2-m3', cache_dir='/home/user/.cache/huggingface/hub'); print('reranker cached') \
|
| 82 |
-
except Exception as e: print(f'reranker cache skipped: {e!r}', file=sys.stderr)" \
|
| 83 |
-
|| echo "build-time reranker download failed -- will lazy-load on first request"
|
| 84 |
-
|
| 85 |
-
# --- BYOK production env ---------------------------------------------------
|
| 86 |
-
# Real secrets (Qdrant URL + API key, Groq key) are injected via HF Space
|
| 87 |
-
# secrets panel -- they ride the same SAR_* env-var protocol but are NOT
|
| 88 |
-
# baked into the image. Only mode flags and safe defaults live here.
|
| 89 |
-
ENV SAR_BYOK_MODE=true
|
| 90 |
-
ENV SAR_BYOK_OWNER_QUOTA=3
|
| 91 |
-
ENV SAR_SESSION_TTL_HOURS=24
|
| 92 |
-
ENV SAR_CORS_ALLOW_ORIGINS='["https://app.eilm.live","https://secureagentrag-web.vercel.app","https://secureagentrag.vercel.app"]'
|
| 93 |
-
|
| 94 |
-
# Cloud LLM defaults -- Groq llama-3.1-8b-instant is the cheapest fast option
|
| 95 |
-
# on the free tier. Visitor BYOK overrides this per request.
|
| 96 |
-
ENV SAR_DEFAULT_PROVIDER=groq
|
| 97 |
-
ENV SAR_CLOUD_PROVIDER=groq
|
| 98 |
-
ENV SAR_LLM_MODEL=llama-3.1-8b-instant
|
| 99 |
-
|
| 100 |
-
# Embedding stack -- local BGE-M3 via sentence-transformers (CPU). Avoids
|
| 101 |
-
# Ollama entirely.
|
| 102 |
-
ENV SAR_EMBEDDING_BACKEND=local
|
| 103 |
-
ENV SAR_LOCAL_EMBEDDING_MODEL=BAAI/bge-m3
|
| 104 |
-
ENV SAR_EMBEDDING_MODEL=bge-m3
|
| 105 |
-
ENV SAR_EMBEDDING_DIM=1024
|
| 106 |
-
|
| 107 |
-
# Cross-encoder reranker -- balances quality with build size. Swap to
|
| 108 |
-
# fine_tuned + SAR_FINETUNED_RERANKER_PATH after phase 3.2 ships the
|
| 109 |
-
# 2.3 GB checkpoint to LeomordKaly/secureagentrag-reranker-v1.
|
| 110 |
-
ENV SAR_RERANKER_TYPE=cross_encoder
|
| 111 |
-
ENV SAR_RERANKER_CHECKPOINT=BAAI/bge-reranker-v2-m3
|
| 112 |
-
|
| 113 |
-
# Sparse retrieval -- BM25 keeps the cold path zero-dep; SPLADE adds an
|
| 114 |
-
# extra ~600 MB model and is skipped on free CPU Basic.
|
| 115 |
-
ENV SAR_SPARSE_BACKEND=bm25
|
| 116 |
-
|
| 117 |
-
# Persistence paths -- /tmp is the only writable area on HF Spaces.
|
| 118 |
-
ENV SAR_AUDIT_LOG_DIR=/tmp/secureagentrag/audit_logs
|
| 119 |
-
ENV SAR_CONVERSATION_DIR=/tmp/secureagentrag/conversations
|
| 120 |
-
ENV SAR_CHECKPOINT_DB_PATH=/tmp/secureagentrag/checkpoints.sqlite
|
| 121 |
-
ENV SAR_BM25_INDEX_PATH=/tmp/secureagentrag/bm25_index.pkl
|
| 122 |
-
|
| 123 |
-
# Multi-tenant collections route BYOK session -> documents_sess_<sid>.
|
| 124 |
-
ENV SAR_MULTI_TENANT_COLLECTIONS=true
|
| 125 |
-
|
| 126 |
-
# Pipeline safety
|
| 127 |
-
ENV SAR_REQUEST_TIMEOUT_S=120
|
| 128 |
-
ENV SAR_FAITHFULNESS_GATE_ENABLED=true
|
| 129 |
-
ENV SAR_FAITHFULNESS_GATE_MODE=flag
|
| 130 |
-
ENV SAR_FAITHFULNESS_THRESHOLD=0.7
|
| 131 |
-
|
| 132 |
-
# Logging
|
| 133 |
-
ENV SAR_LOG_LEVEL=INFO
|
| 134 |
-
|
| 135 |
-
# HF cache lives under the user home which is the only persistent writable
|
| 136 |
-
# tree across Space restarts on CPU Basic.
|
| 137 |
-
ENV HF_HOME=/home/user/.cache/huggingface
|
| 138 |
-
ENV TRANSFORMERS_CACHE=/home/user/.cache/huggingface/hub
|
| 139 |
-
|
| 140 |
-
EXPOSE 7860
|
| 141 |
-
|
| 142 |
-
HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
|
| 143 |
-
CMD curl --fail --silent --show-error http://localhost:7860/healthz || exit 1
|
| 144 |
-
|
| 145 |
-
# uvicorn with 1 worker -- on CPU Basic two workers thrash the memory.
|
| 146 |
-
CMD ["uvicorn", "interfaces.api:app", \
|
| 147 |
-
"--host", "0.0.0.0", \
|
| 148 |
-
"--port", "7860", \
|
| 149 |
-
"--workers", "1", \
|
| 150 |
-
"--timeout-keep-alive", "30", \
|
| 151 |
-
"--no-access-log"]
|
|
|
|
| 1 |
+
# =============================================================================
|
| 2 |
+
# Dockerfile.hf — SecureAgentRAG backend for Hugging Face Spaces (CPU Basic).
|
| 3 |
+
# =============================================================================
|
| 4 |
+
# Two-stage build keeps the runtime image lean. The HF Space free tier is
|
| 5 |
+
# CPU-only with 16 GB RAM and ~50 GB ephemeral disk, so we target a tight
|
| 6 |
+
# memory footprint:
|
| 7 |
+
#
|
| 8 |
+
# - Python 3.11-slim base (~150 MB)
|
| 9 |
+
# - Only [api, embeddings-local] extras (no PII/Presidio, no OCR, no Phoenix, no Postgres,
|
| 10 |
+
# no Redis, no MCP) -- those modules are present in the source but their
|
| 11 |
+
# dependencies are not installed
|
| 12 |
+
# - cross-encoder reranker pre-fetched at build time, else on first request (cached under
|
| 13 |
+
# /home/user/.cache/huggingface). Skips the 2.3 GB fine-tuned checkpoint
|
| 14 |
+
# for the initial deploy; phase 3.2 can swap to fine_tuned once the
|
| 15 |
+
# reranker repo is published on HF Hub.
|
| 16 |
+
#
|
| 17 |
+
# The Space-side README.md is uploaded separately by scripts/deploy_hf_space.py
|
| 18 |
+
# with a YAML frontmatter declaring sdk=docker + app_port=7860.
|
| 19 |
+
# =============================================================================
|
| 20 |
+
|
| 21 |
+
# --- builder ----------------------------------------------------------------
|
| 22 |
+
FROM python:3.11-slim AS builder
|
| 23 |
+
|
| 24 |
+
WORKDIR /app
|
| 25 |
+
|
| 26 |
+
RUN pip install --no-cache-dir uv
|
| 27 |
+
|
| 28 |
+
# pyproject.toml + a copy of the source are required for uv to build the
|
| 29 |
+
# editable install. README.md is referenced as the long_description.
|
| 30 |
+
COPY pyproject.toml ./
|
| 31 |
+
COPY README.md ./
|
| 32 |
+
|
| 33 |
+
# Touch the package directories that hatchling treats as the wheel root --
|
| 34 |
+
# we only need the directory tree to exist at build time so hatchling can
|
| 35 |
+
# scan for __init__.py files. The actual code lands in the runtime stage.
|
| 36 |
+
RUN mkdir -p config core inference retrieval interfaces ingestion utils evaluation app \
|
| 37 |
+
&& touch config/__init__.py core/__init__.py inference/__init__.py \
|
| 38 |
+
&& touch retrieval/__init__.py interfaces/__init__.py ingestion/__init__.py \
|
| 39 |
+
&& touch utils/__init__.py evaluation/__init__.py app/__init__.py
|
| 40 |
+
|
| 41 |
+
# Intentionally skip [pii] extras -- the regex patterns in utils/pii.py
|
| 42 |
+
# already cover every BYOK key shape (Groq / OpenAI / Anthropic / HF / Vercel
|
| 43 |
+
# / Qdrant JWT / Qdrant management). Adding Presidio would pull spaCy
|
| 44 |
+
# en_core_web_lg (~770 MB) which auto-downloads at runtime and crashes the
|
| 45 |
+
# container on the CPU Basic Space when the package installer is absent.
|
| 46 |
+
RUN uv venv /app/.venv \
|
| 47 |
+
&& uv pip install --python /app/.venv/bin/python \
|
| 48 |
+
-e ".[api,embeddings-local]"
|
| 49 |
+
|
| 50 |
+
# --- runtime ----------------------------------------------------------------
|
| 51 |
+
FROM python:3.11-slim AS runtime
|
| 52 |
+
|
| 53 |
+
WORKDIR /app
|
| 54 |
+
|
| 55 |
+
# HF Spaces convention: run as uid 1000 with a writeable /home/user.
|
| 56 |
+
RUN useradd -m -u 1000 user
|
| 57 |
+
|
| 58 |
+
# System deps for PDF / image processing only -- no OCR / paddle.
|
| 59 |
+
# Debian 13 (trixie) renamed libgl1-mesa-glx -> libgl1 and libxrender-dev
|
| 60 |
+
# is no longer needed at runtime (runtime is libxrender1).
|
| 61 |
+
RUN apt-get update \
|
| 62 |
+
&& apt-get install -y --no-install-recommends \
|
| 63 |
+
libglib2.0-0 libsm6 libxext6 libxrender1 libgl1 curl \
|
| 64 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 65 |
+
|
| 66 |
+
# Bring the virtualenv from the builder stage.
|
| 67 |
+
COPY --from=builder /app/.venv /app/.venv
|
| 68 |
+
ENV PATH="/app/.venv/bin:$PATH"
|
| 69 |
+
|
| 70 |
+
# Copy application source. Files that match .dockerignore are filtered out.
|
| 71 |
+
COPY --chown=user:user . /app
|
| 72 |
+
|
| 73 |
+
USER user
|
| 74 |
+
|
| 75 |
+
# Pre-populate the HF cache so the cross-encoder lives on disk before the
|
| 76 |
+
# first request. Defensive: never fails the build -- if HF Hub is unreachable
|
| 77 |
+
# during build (offline mirrors etc.) the cache is populated on first query.
|
| 78 |
+
RUN python -c "import os; \
|
| 79 |
+
from huggingface_hub import snapshot_download; \
|
| 80 |
+
import sys; \
|
| 81 |
+
try: snapshot_download(repo_id='BAAI/bge-reranker-v2-m3', cache_dir='/home/user/.cache/huggingface/hub'); print('reranker cached') \
|
| 82 |
+
except Exception as e: print(f'reranker cache skipped: {e!r}', file=sys.stderr)" \
|
| 83 |
+
|| echo "build-time reranker download failed -- will lazy-load on first request"
|
| 84 |
+
|
| 85 |
+
# --- BYOK production env ---------------------------------------------------
|
| 86 |
+
# Real secrets (Qdrant URL + API key, Groq key) are injected via HF Space
|
| 87 |
+
# secrets panel -- they ride the same SAR_* env-var protocol but are NOT
|
| 88 |
+
# baked into the image. Only mode flags and safe defaults live here.
|
| 89 |
+
ENV SAR_BYOK_MODE=true
|
| 90 |
+
# Settings (env_prefix SAR_) reads the field byok_owner_key_quota_per_hour, i.e.
# SAR_BYOK_OWNER_KEY_QUOTA_PER_HOUR; the short name alone is ignored by
# pydantic-settings (extra="ignore"). Keep it for anything reading it directly.
ENV SAR_BYOK_OWNER_QUOTA=3
ENV SAR_BYOK_OWNER_KEY_QUOTA_PER_HOUR=3
|
| 91 |
+
# Settings (env_prefix SAR_) reads the field session_collection_ttl_hours, i.e.
# SAR_SESSION_COLLECTION_TTL_HOURS; the short name alone is ignored by
# pydantic-settings (extra="ignore"). Keep it for anything reading it directly.
ENV SAR_SESSION_TTL_HOURS=24
ENV SAR_SESSION_COLLECTION_TTL_HOURS=24
|
| 92 |
+
ENV SAR_CORS_ALLOW_ORIGINS='["https://app.eilm.live","https://secureagentrag-web.vercel.app","https://secureagentrag.vercel.app"]'
|
| 93 |
+
|
| 94 |
+
# Cloud LLM defaults -- Groq llama-3.1-8b-instant is the cheapest fast option
|
| 95 |
+
# on the free tier. Visitor BYOK overrides this per request.
|
| 96 |
+
ENV SAR_DEFAULT_PROVIDER=groq
|
| 97 |
+
ENV SAR_CLOUD_PROVIDER=groq
|
| 98 |
+
ENV SAR_LLM_MODEL=llama-3.1-8b-instant
|
| 99 |
+
|
| 100 |
+
# Embedding stack -- local BGE-M3 via sentence-transformers (CPU). Avoids
|
| 101 |
+
# Ollama entirely.
|
| 102 |
+
ENV SAR_EMBEDDING_BACKEND=local
|
| 103 |
+
ENV SAR_LOCAL_EMBEDDING_MODEL=BAAI/bge-m3
|
| 104 |
+
ENV SAR_EMBEDDING_MODEL=bge-m3
|
| 105 |
+
ENV SAR_EMBEDDING_DIM=1024
|
| 106 |
+
|
| 107 |
+
# Cross-encoder reranker -- balances quality with build size. Swap to
|
| 108 |
+
# fine_tuned + SAR_FINETUNED_RERANKER_PATH after phase 3.2 ships the
|
| 109 |
+
# 2.3 GB checkpoint to LeomordKaly/secureagentrag-reranker-v1.
|
| 110 |
+
ENV SAR_RERANKER_TYPE=cross_encoder
|
| 111 |
+
ENV SAR_RERANKER_CHECKPOINT=BAAI/bge-reranker-v2-m3
|
| 112 |
+
|
| 113 |
+
# Sparse retrieval -- BM25 keeps the cold path zero-dep; SPLADE adds an
|
| 114 |
+
# extra ~600 MB model and is skipped on free CPU Basic.
|
| 115 |
+
ENV SAR_SPARSE_BACKEND=bm25
|
| 116 |
+
|
| 117 |
+
# Persistence paths -- /tmp is writable by the non-root user on HF Spaces.
|
| 118 |
+
ENV SAR_AUDIT_LOG_DIR=/tmp/secureagentrag/audit_logs
|
| 119 |
+
ENV SAR_CONVERSATION_DIR=/tmp/secureagentrag/conversations
|
| 120 |
+
ENV SAR_CHECKPOINT_DB_PATH=/tmp/secureagentrag/checkpoints.sqlite
|
| 121 |
+
ENV SAR_BM25_INDEX_PATH=/tmp/secureagentrag/bm25_index.pkl
|
| 122 |
+
|
| 123 |
+
# Multi-tenant collections route BYOK session -> documents_sess_<sid>.
|
| 124 |
+
ENV SAR_MULTI_TENANT_COLLECTIONS=true
|
| 125 |
+
|
| 126 |
+
# Pipeline safety
|
| 127 |
+
ENV SAR_REQUEST_TIMEOUT_S=120
|
| 128 |
+
ENV SAR_FAITHFULNESS_GATE_ENABLED=true
|
| 129 |
+
ENV SAR_FAITHFULNESS_GATE_MODE=flag
|
| 130 |
+
ENV SAR_FAITHFULNESS_THRESHOLD=0.7
|
| 131 |
+
|
| 132 |
+
# Logging
|
| 133 |
+
ENV SAR_LOG_LEVEL=INFO
|
| 134 |
+
|
| 135 |
+
# HF cache lives under the user home which is the only persistent writable
|
| 136 |
+
# tree across Space restarts on CPU Basic.
|
| 137 |
+
ENV HF_HOME=/home/user/.cache/huggingface
|
| 138 |
+
ENV TRANSFORMERS_CACHE=/home/user/.cache/huggingface/hub
|
| 139 |
+
|
| 140 |
+
EXPOSE 7860
|
| 141 |
+
|
| 142 |
+
HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
|
| 143 |
+
CMD curl --fail --silent --show-error http://localhost:7860/healthz || exit 1
|
| 144 |
+
|
| 145 |
+
# uvicorn with 1 worker -- on CPU Basic two workers thrash the memory.
|
| 146 |
+
CMD ["uvicorn", "interfaces.api:app", \
|
| 147 |
+
"--host", "0.0.0.0", \
|
| 148 |
+
"--port", "7860", \
|
| 149 |
+
"--workers", "1", \
|
| 150 |
+
"--timeout-keep-alive", "30", \
|
| 151 |
+
"--no-access-log"]
|
config/settings.py
CHANGED
|
@@ -1,316 +1,316 @@
|
|
| 1 |
-
"""Application settings managed via pydantic-settings with environment variable support."""
|
| 2 |
-
|
| 3 |
-
from __future__ import annotations
|
| 4 |
-
|
| 5 |
-
import contextlib
|
| 6 |
-
import json
|
| 7 |
-
import os
|
| 8 |
-
from pathlib import Path
|
| 9 |
-
|
| 10 |
-
from pydantic_settings import BaseSettings, SettingsConfigDict
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
class Settings(BaseSettings):
|
| 14 |
-
"""Central configuration for SecureAgentRAG.
|
| 15 |
-
|
| 16 |
-
All settings can be overridden via environment variables prefixed with ``SAR_``.
|
| 17 |
-
For example, ``SAR_DEBUG=true`` sets ``debug`` to True.
|
| 18 |
-
"""
|
| 19 |
-
|
| 20 |
-
model_config = SettingsConfigDict(
|
| 21 |
-
env_file=".env",
|
| 22 |
-
env_prefix="SAR_",
|
| 23 |
-
env_file_encoding="utf-8",
|
| 24 |
-
case_sensitive=False,
|
| 25 |
-
extra="ignore",
|
| 26 |
-
)
|
| 27 |
-
|
| 28 |
-
# ── Application ──────────────────────────────────────────────────────────────
|
| 29 |
-
app_name: str = "SecureAgentRAG"
|
| 30 |
-
debug: bool = False
|
| 31 |
-
log_level: str = "INFO"
|
| 32 |
-
|
| 33 |
-
# ── Qdrant Vector Store ─────────────────────────────────────────────────────
|
| 34 |
-
qdrant_url: str = "http://localhost:6333"
|
| 35 |
-
qdrant_collection: str = "documents"
|
| 36 |
-
qdrant_api_key: str | None = None
|
| 37 |
-
|
| 38 |
-
# ── Ollama / LLM ─────────────────────────────────────────────────────────────
|
| 39 |
-
ollama_url: str = "http://localhost:11434"
|
| 40 |
-
llm_model: str = "qwen3:8b"
|
| 41 |
-
embedding_model: str = "bge-m3"
|
| 42 |
-
embedding_dim: int = 1024
|
| 43 |
-
embedding_backend: str = "ollama" # "ollama" or "local" (sentence-transformers)
|
| 44 |
-
local_embedding_model: str = "BAAI/bge-m3"
|
| 45 |
-
# How long Ollama keeps models resident in VRAM between requests.
|
| 46 |
-
# On consumer hardware the LLM (qwen3:8b ~5.5GB) and embedding (bge-m3 ~1.2GB)
|
| 47 |
-
# need to swap if VRAM is tight. Long keep-alive avoids ~5-10s reload per swap.
|
| 48 |
-
ollama_keep_alive: str = "30m"
|
| 49 |
-
|
| 50 |
-
# ── Chunking ─────────────────────────────────────────────────────────────────
|
| 51 |
-
chunk_size: int = 1000
|
| 52 |
-
chunk_overlap: int = 200
|
| 53 |
-
|
| 54 |
-
# ── Retrieval ────────────────────────────────────────────────────────────────
|
| 55 |
-
top_k: int = 10
|
| 56 |
-
rerank_top_k: int = 5
|
| 57 |
-
relevance_threshold: float = 0.7
|
| 58 |
-
# RAG Fusion: generate N query reformulations, retrieve in parallel,
|
| 59 |
-
# fuse the ranked lists via RRF. Boosts recall on under-specified
|
| 60 |
-
# queries. Cost: N-1 extra LLM calls + N parallel Qdrant searches.
|
| 61 |
-
# Set to 1 to disable.
|
| 62 |
-
rag_fusion_n_queries: int = 3
|
| 63 |
-
rag_fusion_enabled: bool = True
|
| 64 |
-
# ── Reranker ─────────────────────────────────────────────────────────────────
|
| 65 |
-
# Re-score retrieved documents for higher precision.
|
| 66 |
-
# Options: "none" (disabled), "fine_tuned" (local checkpoint), "cross_encoder" (BGE-Reranker-v2-M3),
|
| 67 |
-
# "colbert" (ColBERTv2 late-interaction, requires colbert-ai package).
|
| 68 |
-
# The cross-encoder downloads ~600MB from HuggingFace on first use.
|
| 69 |
-
# The ColBERT checkpoint is ~400MB. Disabled by default so the first
|
| 70 |
-
# query does not silently hang on download. Pre-download explicitly.
|
| 71 |
-
reranker_type: str = "none"
|
| 72 |
-
reranker_checkpoint: str = "BAAI/bge-reranker-v2-m3"
|
| 73 |
-
colbert_checkpoint: str = "colbert-ir/colbertv2.0"
|
| 74 |
-
# Path to a locally fine-tuned cross-encoder checkpoint produced by
|
| 75 |
-
# scripts/train_reranker.py. Used when reranker_type == "fine_tuned".
|
| 76 |
-
finetuned_reranker_path: str = "data/checkpoints/reranker-domain-v1"
|
| 77 |
-
|
| 78 |
-
# ── Inference Providers ──────────────────────────────────────────────────────
|
| 79 |
-
default_provider: str = "ollama"
|
| 80 |
-
cloud_provider: str | None = None
|
| 81 |
-
groq_api_key: str | None = None
|
| 82 |
-
openai_api_key: str | None = None
|
| 83 |
-
anthropic_api_key: str | None = None
|
| 84 |
-
groq_api_base: str = "https://api.groq.com/openai/v1"
|
| 85 |
-
openai_api_base: str = "https://api.openai.com/v1"
|
| 86 |
-
anthropic_api_base: str = "https://api.anthropic.com/v1"
|
| 87 |
-
|
| 88 |
-
# ── RAG Pipeline Thresholds ───────────────────────────────────────────────────
|
| 89 |
-
relevance_retry_threshold: float = 0.5
|
| 90 |
-
confidence_threshold: float = 0.6
|
| 91 |
-
max_retries: int = 2
|
| 92 |
-
|
| 93 |
-
# ── JSON Citations ───────────────────────────────────────────────────────
|
| 94 |
-
# When enabled, the synthesizer requests structured JSON output from the LLM
|
| 95 |
-
# with `answer` and `citations` fields instead of relying on regex extraction.
|
| 96 |
-
json_citations_enabled: bool = False
|
| 97 |
-
|
| 98 |
-
# ── Embedding Batch Size ──────────────────────────────────────────────────────
|
| 99 |
-
embedding_batch_size: int = 32 # Max texts per embedding API call
|
| 100 |
-
embedding_max_concurrent_batches: int = 4 # Max concurrent batch requests
|
| 101 |
-
|
| 102 |
-
# ── RBAC ─────────────────────────────────────────────────────────────────────
|
| 103 |
-
enable_rbac: bool = True
|
| 104 |
-
|
| 105 |
-
# ── Observability (Phoenix) ──────────────────────────────────────────────────
|
| 106 |
-
phoenix_endpoint: str | None = None
|
| 107 |
-
|
| 108 |
-
# ── Sparse Vectors (Qdrant native, replaces rank_bm25 pickle) ───────────────
|
| 109 |
-
sparse_backend: str = "bm25" # "bm25" | "splade"
|
| 110 |
-
sparse_vector_name: str = "sparse"
|
| 111 |
-
sparse_model: str = "naver/splade-cocondenser-ensembledistil"
|
| 112 |
-
|
| 113 |
-
# ── Audit + Conversation Storage ──────────────────────────────────────────────
|
| 114 |
-
audit_log_dir: str = "audit_logs"
|
| 115 |
-
conversation_dir: str = "conversations"
|
| 116 |
-
checkpoint_db_path: str = "data/checkpoints.sqlite"
|
| 117 |
-
# Opt-in: enable persistent (SQLite/Postgres) LangGraph checkpointing.
|
| 118 |
-
# Default off because pytest-asyncio creates per-test event loops which
|
| 119 |
-
# collide with aiosqlite's loop-bound connection. For production single-
|
| 120 |
-
# process Streamlit / FastAPI deployments, set SAR_USE_PERSISTENT_CHECKPOINTER=true.
|
| 121 |
-
use_persistent_checkpointer: bool = False
|
| 122 |
-
|
| 123 |
-
# ── PostgreSQL (for LangGraph checkpointing) ─────────────────────────────────
|
| 124 |
-
postgres_url: str = "postgresql://sar_user:sar_password@localhost:5433/secureagentrag"
|
| 125 |
-
|
| 126 |
-
# ── Pipeline SLO ─────────────────────────────────────────────────────────────
|
| 127 |
-
# Hard wall-clock budget for a single RAG pipeline run (rewrite loop +
|
| 128 |
-
# retrieval + grading + synthesis + evaluation). On timeout the caller
|
| 129 |
-
# gets a graceful refusal + audit entry; nothing partial is rendered as
|
| 130 |
-
# if the answer succeeded. 0 disables the deadline.
|
| 131 |
-
request_timeout_s: float = 60.0
|
| 132 |
-
|
| 133 |
-
# ── Authentication ───────────────────────────────────────────────────────────
|
| 134 |
-
# When ``jwt_secret`` is set the FastAPI / MCP layers verify HS256-signed
|
| 135 |
-
# JWTs and derive UserContext from validated claims. When unset, callers
|
| 136 |
-
# fall back to the dev-mode base64(json(UserContext)) token shape so
|
| 137 |
-
# existing tests and smoke scripts keep working — but a runtime warning is
|
| 138 |
-
# emitted on every request. Production deployments MUST set this.
|
| 139 |
-
#
|
| 140 |
-
# ``jwt_issuer`` / ``jwt_audience`` are checked against ``iss`` / ``aud``
|
| 141 |
-
# claims when present. Leave empty to disable that check (default).
|
| 142 |
-
# ``jwt_ttl_seconds`` is the lifetime of tokens minted via the local
|
| 143 |
-
# ``/token`` dev endpoint; real IdPs (Keycloak/Auth0) set their own.
|
| 144 |
-
jwt_secret: str | None = None
|
| 145 |
-
jwt_issuer: str = "secureagentrag"
|
| 146 |
-
jwt_audience: str = "secureagentrag-api"
|
| 147 |
-
jwt_ttl_seconds: int = 3600
|
| 148 |
-
jwt_algorithm: str = "HS256"
|
| 149 |
-
# JWKS endpoint for RS256 verification (e.g. Keycloak, Auth0).
|
| 150 |
-
# When set and jwt_algorithm == "RS256", tokens are verified against
|
| 151 |
-
# the cached JWKS instead of jwt_secret.
|
| 152 |
-
jwks_url: str | None = None
|
| 153 |
-
jwks_cache_ttl_seconds: int = 300
|
| 154 |
-
|
| 155 |
-
# ── Citation Faithfulness Gate (NLI) ─────────────────────────────────────────
|
| 156 |
-
# After synthesis, run a per-sentence NLI check: for each sentence that
|
| 157 |
-
# carries an inline `[N]` citation, ask a yes/no entailment question
|
| 158 |
-
# against the cited chunk's text. Sentences that fail are either marked
|
| 159 |
-
# `[unsupported]` ("flag" mode) or dropped from the answer ("drop" mode).
|
| 160 |
-
# The check uses the same local LLM as the rest of the graph — no extra
|
| 161 |
-
# model download. Cost: one LLM call per cited sentence (parallel).
|
| 162 |
-
faithfulness_gate_enabled: bool = False
|
| 163 |
-
faithfulness_gate_mode: str = "flag" # "flag" | "drop"
|
| 164 |
-
faithfulness_threshold: float = 0.7 # min entailment ratio to consider answer faithful
|
| 165 |
-
faithfulness_max_concurrent: int = 4 # parallel NLI checks
|
| 166 |
-
|
| 167 |
-
# ── Redis (for distributed rate limiting / caching) ──────────────────────────
|
| 168 |
-
redis_url: str = "redis://localhost:6379/0"
|
| 169 |
-
use_redis_rate_limiter: bool = False
|
| 170 |
-
|
| 171 |
-
# ── PII Redaction ────────────────────────────────────────────────────────────
|
| 172 |
-
# Scrub email, phone, SSN, credit-card, IBAN, IP address before persisting
|
| 173 |
-
# to audit log / query cache. Defense against accidental PII leakage into
|
| 174 |
-
# secondary stores. Regex-based by default; if Microsoft Presidio is
|
| 175 |
-
# installed it is used automatically for higher recall.
|
| 176 |
-
pii_redaction_enabled: bool = True
|
| 177 |
-
|
| 178 |
-
# ── Prompt-Injection Guardrails ──────────────────────────────────────────────
|
| 179 |
-
# Run a regex + heuristic check on the user query before retrieval. Blocks
|
| 180 |
-
# obvious jailbreak / system-prompt-override attempts. Logged via the audit
|
| 181 |
-
# logger as ``security_block`` events.
|
| 182 |
-
guardrails_enabled: bool = True
|
| 183 |
-
# Strict mode: after the fast regex gate, escalate ambiguous or all queries
|
| 184 |
-
# to a local LLM-based classifier for a second opinion. Adds one LLM call
|
| 185 |
-
# per query but catches adversarial inputs that evade regex patterns.
|
| 186 |
-
guardrails_strict: bool = False
|
| 187 |
-
# Escalation backend used in strict mode. Options:
|
| 188 |
-
# "llm" — legacy SAFE/UNSAFE prompt on the synth-grade model
|
| 189 |
-
# (core.agents.guardrails_llm). Default for backward
|
| 190 |
-
# compatibility.
|
| 191 |
-
# "llamaguard" — Meta's LlamaGuard 3 8B via Ollama. Use with
|
| 192 |
-
# ``ollama pull llama-guard3:8b``. More accurate on
|
| 193 |
-
# the standard S1-S14 taxonomy.
|
| 194 |
-
guardrails_backend: str = "llm"
|
| 195 |
-
llamaguard_model: str = "llama-guard3:8b"
|
| 196 |
-
|
| 197 |
-
# ── Contextual Retrieval (Anthropic 2024 technique) ──────────────────────────
|
| 198 |
-
# Prepend a short LLM-generated context summary to each chunk before
|
| 199 |
-
# embedding. Adds 1 cheap LLM call per chunk at ingestion time but
|
| 200 |
-
# measurably improves retrieval recall (Anthropic reported ~35-49%
|
| 201 |
-
# failure reduction). Local Qwen3-8B is fine for the summary.
|
| 202 |
-
contextual_retrieval_enabled: bool = False
|
| 203 |
-
|
| 204 |
-
# ── VLM OCR (Primary OCR via vision-language model) ───────────────────────────
|
| 205 |
-
# Use a VLM (Qwen2.5-VL / Qwen3-VL, LLaVA, etc.) via Ollama as the primary OCR path.
|
| 206 |
-
# Superior to PaddleOCR on complex layouts, tables, and mixed-language
|
| 207 |
-
# documents. Falls back to PaddleOCR when the VLM is unavailable.
|
| 208 |
-
vlm_ocr_enabled: bool = False
|
| 209 |
-
vlm_ocr_model: str = "qwen2.5-vl"
|
| 210 |
-
|
| 211 |
-
# ── Multi-Tenancy ────────────────────────────────────────────────────────────
|
| 212 |
-
# When true, each organization gets its own Qdrant collection
|
| 213 |
-
# (documents_{org_id}). This provides stronger isolation than payload-level
|
| 214 |
-
# RBAC filtering but requires creating collections per org on first use.
|
| 215 |
-
# When false, all docs share a single collection with RBAC at payload level.
|
| 216 |
-
multi_tenant_collections: bool = False
|
| 217 |
-
|
| 218 |
-
# ── BYOK demo mode (P6 production launch, see launch-plan/03-backend-byok.md)
|
| 219 |
-
# In BYOK mode the FastAPI surface accepts per-request LLM keys from visitor
|
| 220 |
-
# headers, scopes Qdrant writes to per-session collections, and disables
|
| 221 |
-
# Phoenix instrumentation. Off in dev/staging, on in the Hugging Face Space
|
| 222 |
-
# production image (SAR_BYOK_MODE=true, set via ENV in Dockerfile.hf).
|
| 223 |
-
byok_mode: bool = False
|
| 224 |
-
# When BYOK is on and a visitor did NOT bring their own LLM key, the owner
|
| 225 |
-
# key in .env is used but throttled to this many requests per IP per hour.
|
| 226 |
-
# The cap is intentionally tight so the Groq free-tier 30 RPM / 14400 RPD
|
| 227 |
-
# is never exhausted by a single visitor.
|
| 228 |
-
byok_owner_key_quota_per_hour: int = 3
|
| 229 |
-
# Per-session Qdrant collections (documents_sess_<session_id>) are auto
|
| 230 |
-
# purged after this many hours by retrieval/session_purge.py.
|
| 231 |
-
session_collection_ttl_hours: int = 24
|
| 232 |
-
# CORS allowlist consulted by the FastAPI middleware when byok_mode=true.
|
| 233 |
-
# Empty list = no CORS middleware mounted (dev default).
|
| 234 |
-
cors_allow_origins: list[str] = []
|
| 235 |
-
|
| 236 |
-
# ── Multi-Modal RAG ──────────────────────────────────────────────────────────
|
| 237 |
-
# When ingesting images, also generate a rich text description using a VLM.
|
| 238 |
-
# The description is embedded as a separate chunk, enabling retrieval for
|
| 239 |
-
# queries like "what does the diagram show?" without requiring CLIP or
|
| 240 |
-
# other multi-modal embedding models.
|
| 241 |
-
multimodal_descriptions_enabled: bool = False
|
| 242 |
-
|
| 243 |
-
# ── Self-Query Retrieval ─────────────────────────────────────────────────────
|
| 244 |
-
# Extract structured metadata filters (source_file, date_range,
|
| 245 |
-
# sensitivity_level, roles) from the natural language query using a small
|
| 246 |
-
# local LLM prompt. The filters are merged with the RBAC filter and passed
|
| 247 |
-
# to Qdrant, scoping retrieval before embedding search runs.
|
| 248 |
-
self_query_enabled: bool = False
|
| 249 |
-
|
| 250 |
-
# ── HyDE (Hypothetical Document Embeddings) ──────────────────────────────────
|
| 251 |
-
# Generate a hypothetical answer to the query, embed *that* instead of the
|
| 252 |
-
# raw query. Boosts recall when query vocabulary differs from doc
|
| 253 |
-
# vocabulary (questions vs declarative sentences). Adds one LLM call per
|
| 254 |
-
# query — skip for simple keyword lookups; enable for complex questions.
|
| 255 |
-
hyde_enabled: bool = False
|
| 256 |
-
|
| 257 |
-
# ── Pricing for cost dashboard (USD per 1M tokens) ───────────────────────────
|
| 258 |
-
# Used by evaluation/cost.py to convert recorded usage into $/query.
|
| 259 |
-
price_groq_input_per_1m: float = 0.59
|
| 260 |
-
price_groq_output_per_1m: float = 0.79
|
| 261 |
-
price_openai_input_per_1m: float = 2.50
|
| 262 |
-
price_openai_output_per_1m: float = 10.00
|
| 263 |
-
price_anthropic_input_per_1m: float = 3.00
|
| 264 |
-
price_anthropic_output_per_1m: float = 15.00
|
| 265 |
-
# Local inference: estimated electricity cost only (consumer hardware).
|
| 266 |
-
# 200W GPU @ $0.15/kWh ≈ $0.03/hour ≈ $0.000008/sec
|
| 267 |
-
price_local_per_second: float = 0.000008
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
def _apply_calibration(settings_obj: Settings) -> None:
    """Override threshold defaults from ``evaluation/calibration.json`` when present.

    The calibration script (``scripts/calibrate_thresholds.py``) writes the
    chosen confidence + faithfulness cutoffs against a labelled gold set. Loading
    them here means deployments inherit the latest tuned values automatically,
    while an explicit ``SAR_CONFIDENCE_THRESHOLD`` / ``SAR_FAITHFULNESS_THRESHOLD``
    env var still wins so operators can override per environment.

    Silently no-ops when the file is missing, unreadable, not valid UTF-8 JSON,
    not a JSON object, or when the relevant blocks/keys are absent or of the
    wrong type -- never blocks startup.

    Args:
        settings_obj: The settings instance whose ``confidence_threshold`` and
            ``faithfulness_threshold`` attributes may be overwritten in place.
    """
    calib_path = Path(__file__).resolve().parent.parent / "evaluation" / "calibration.json"
    try:
        data = json.loads(calib_path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # OSError covers a missing or unreadable file; ValueError covers both
        # json.JSONDecodeError and UnicodeDecodeError (bad encoding).
        return
    if not isinstance(data, dict):
        # A syntactically valid file whose root is a list/number/null is still
        # "malformed" for our purposes.
        return

    # Reject degenerate sweeps (no negatives or no positives -> the chosen
    # threshold has no statistical meaning). Keeping the original default in
    # that case is safer than letting a 0.0 cut-off escape into production.
    def _sane(block: object) -> bool:
        if not isinstance(block, dict):
            return False
        try:
            return (
                int(block.get("n_pos", 0)) > 0
                and int(block.get("n_neg", 0)) > 0
                and float(block.get("chosen_threshold", 0.0)) > 0.0
            )
        except (TypeError, ValueError, OverflowError):
            # OverflowError: json accepts ``Infinity``, and int(inf) overflows.
            return False

    # NOTE(review): only the process environment is consulted here; a value
    # supplied solely through the ``.env`` file is not visible in os.environ
    # and would be overridden by calibration -- confirm that is intended.
    for block_key, env_var, attr_name in (
        ("confidence", "SAR_CONFIDENCE_THRESHOLD", "confidence_threshold"),
        ("faithfulness", "SAR_FAITHFULNESS_THRESHOLD", "faithfulness_threshold"),
    ):
        block = data.get(block_key, {})
        if _sane(block) and os.environ.get(env_var) is None:
            with contextlib.suppress(TypeError, ValueError):
                setattr(settings_obj, attr_name, float(block["chosen_threshold"]))
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
# Singleton instance — import this throughout the application.
# Built once at import time; the calibration overlay then runs on the same
# object and may overwrite its two threshold attributes in place.
settings = Settings()
_apply_calibration(settings)
|
|
|
|
| 1 |
+
"""Application settings managed via pydantic-settings with environment variable support."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import contextlib
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class Settings(BaseSettings):
    """Central configuration for SecureAgentRAG.

    All settings can be overridden via environment variables prefixed with ``SAR_``.
    For example, ``SAR_DEBUG=true`` sets ``debug`` to True.

    Values are also read from a UTF-8 ``.env`` file; variable names are matched
    case-insensitively and unknown keys are ignored (see ``model_config``).
    """

    model_config = SettingsConfigDict(
        env_file=".env",
        env_prefix="SAR_",
        env_file_encoding="utf-8",
        case_sensitive=False,
        extra="ignore",
    )

    # ── Application ──────────────────────────────────────────────────────────────
    app_name: str = "SecureAgentRAG"
    debug: bool = False
    log_level: str = "INFO"

    # ── Qdrant Vector Store ──────────────────────────────────────────────────────
    qdrant_url: str = "http://localhost:6333"
    qdrant_collection: str = "documents"
    qdrant_api_key: str | None = None

    # ── Ollama / LLM ─────────────────────────────────────────────────────────────
    ollama_url: str = "http://localhost:11434"
    llm_model: str = "qwen3:8b"
    embedding_model: str = "bge-m3"
    embedding_dim: int = 1024
    embedding_backend: str = "ollama"  # "ollama" or "local" (sentence-transformers)
    local_embedding_model: str = "BAAI/bge-m3"
    # How long Ollama keeps models resident in VRAM between requests.
    # On consumer hardware the LLM (qwen3:8b ~5.5GB) and embedding (bge-m3 ~1.2GB)
    # need to swap if VRAM is tight. Long keep-alive avoids ~5-10s reload per swap.
    ollama_keep_alive: str = "30m"

    # ── Chunking ─────────────────────────────────────────────────────────────────
    chunk_size: int = 1000
    chunk_overlap: int = 200

    # ── Retrieval ────────────────────────────────────────────────────────────────
    top_k: int = 10
    rerank_top_k: int = 5
    relevance_threshold: float = 0.7
    # RAG Fusion: generate N query reformulations, retrieve in parallel,
    # fuse the ranked lists via RRF. Boosts recall on under-specified
    # queries. Cost: N-1 extra LLM calls + N parallel Qdrant searches.
    # Set to 1 to disable.
    rag_fusion_n_queries: int = 3
    rag_fusion_enabled: bool = True
    # ── Reranker ─────────────────────────────────────────────────────────────────
    # Re-score retrieved documents for higher precision.
    # Options: "none" (disabled), "cross_encoder" (BGE-Reranker-v2-M3),
    # "colbert" (ColBERTv2 late-interaction, requires colbert-ai package).
    # The cross-encoder downloads ~600MB from HuggingFace on first use.
    # The ColBERT checkpoint is ~400MB. Disabled by default so the first
    # query does not silently hang on download. Pre-download explicitly.
    reranker_type: str = "none"
    reranker_checkpoint: str = "BAAI/bge-reranker-v2-m3"
    colbert_checkpoint: str = "colbert-ir/colbertv2.0"
    # Path to a locally fine-tuned cross-encoder checkpoint produced by
    # scripts/train_reranker.py. Used when reranker_type == "fine_tuned".
    finetuned_reranker_path: str = "data/checkpoints/reranker-domain-v1"

    # ── Inference Providers ──────────────────────────────────────────────────────
    default_provider: str = "ollama"
    cloud_provider: str | None = None
    groq_api_key: str | None = None
    openai_api_key: str | None = None
    anthropic_api_key: str | None = None
    groq_api_base: str = "https://api.groq.com/openai/v1"
    openai_api_base: str = "https://api.openai.com/v1"
    anthropic_api_base: str = "https://api.anthropic.com/v1"

    # ── RAG Pipeline Thresholds ───────────────────────────────────────────────────
    # NOTE: ``confidence_threshold`` may be overwritten after construction by
    # ``_apply_calibration`` when ``evaluation/calibration.json`` is present.
    relevance_retry_threshold: float = 0.5
    confidence_threshold: float = 0.6
    max_retries: int = 2

    # ── JSON Citations ────────────────────────────────────────────────────────────
    # When enabled, the synthesizer requests structured JSON output from the LLM
    # with `answer` and `citations` fields instead of relying on regex extraction.
    json_citations_enabled: bool = False

    # ── Embedding Batch Size ──────────────────────────────────────────────────────
    embedding_batch_size: int = 32  # Max texts per embedding API call
    embedding_max_concurrent_batches: int = 4  # Max concurrent batch requests

    # ── RBAC ─────────────────────────────────────────────────────────────────────
    enable_rbac: bool = True

    # ── Observability (Phoenix) ──────────────────────────────────────────────────
    phoenix_endpoint: str | None = None

    # ── Sparse Vectors (Qdrant native, replaces rank_bm25 pickle) ────────────────
    sparse_backend: str = "bm25"  # "bm25" | "splade"
    sparse_vector_name: str = "sparse"
    sparse_model: str = "naver/splade-cocondenser-ensembledistil"

    # ── Audit + Conversation Storage ──────────────────────────────────────────────
    audit_log_dir: str = "audit_logs"
    conversation_dir: str = "conversations"
    checkpoint_db_path: str = "data/checkpoints.sqlite"
    # Opt-in: enable persistent (SQLite/Postgres) LangGraph checkpointing.
    # Default off because pytest-asyncio creates per-test event loops which
    # collide with aiosqlite's loop-bound connection. For production single-
    # process Streamlit / FastAPI deployments, set SAR_USE_PERSISTENT_CHECKPOINTER=true.
    use_persistent_checkpointer: bool = False

    # ── PostgreSQL (for LangGraph checkpointing) ─────────────────────────────────
    # NOTE(review): the default URL embeds a username and password; treat it as
    # a local-development value and override via SAR_POSTGRES_URL elsewhere.
    postgres_url: str = "postgresql://sar_user:sar_password@localhost:5433/secureagentrag"

    # ── Pipeline SLO ─────────────────────────────────────────────────────────────
    # Hard wall-clock budget for a single RAG pipeline run (rewrite loop +
    # retrieval + grading + synthesis + evaluation). On timeout the caller
    # gets a graceful refusal + audit entry; nothing partial is rendered as
    # if the answer succeeded. 0 disables the deadline.
    request_timeout_s: float = 60.0

    # ── Authentication ───────────────────────────────────────────────────────────
    # When ``jwt_secret`` is set the FastAPI / MCP layers verify HS256-signed
    # JWTs and derive UserContext from validated claims. When unset, callers
    # fall back to the dev-mode base64(json(UserContext)) token shape so
    # existing tests and smoke scripts keep working — but a runtime warning is
    # emitted on every request. Production deployments MUST set this.
    #
    # ``jwt_issuer`` / ``jwt_audience`` are checked against ``iss`` / ``aud``
    # claims when present. Set them to an empty string to disable that check.
    # NOTE(review): the defaults below are non-empty, so the check appears to
    # be enabled by default — confirm against the token verifier.
    # ``jwt_ttl_seconds`` is the lifetime of tokens minted via the local
    # ``/token`` dev endpoint; real IdPs (Keycloak/Auth0) set their own.
    jwt_secret: str | None = None
    jwt_issuer: str = "secureagentrag"
    jwt_audience: str = "secureagentrag-api"
    jwt_ttl_seconds: int = 3600
    jwt_algorithm: str = "HS256"
    # JWKS endpoint for RS256 verification (e.g. Keycloak, Auth0).
    # When set and jwt_algorithm == "RS256", tokens are verified against
    # the cached JWKS instead of jwt_secret.
    jwks_url: str | None = None
    jwks_cache_ttl_seconds: int = 300

    # ── Citation Faithfulness Gate (NLI) ─────────────────────────────────────────
    # After synthesis, run a per-sentence NLI check: for each sentence that
    # carries an inline `[N]` citation, ask a yes/no entailment question
    # against the cited chunk's text. Sentences that fail are either marked
    # `[unsupported]` (soft mode) or dropped from the answer (strict mode).
    # The check uses the same local LLM as the rest of the graph — no extra
    # model download. Cost: one LLM call per cited sentence (parallel).
    # NOTE: ``faithfulness_threshold`` may be overwritten after construction by
    # ``_apply_calibration`` when ``evaluation/calibration.json`` is present.
    faithfulness_gate_enabled: bool = False
    faithfulness_gate_mode: str = "flag"  # "flag" | "drop"
    faithfulness_threshold: float = 0.7  # min entailment ratio to consider answer faithful
    faithfulness_max_concurrent: int = 4  # parallel NLI checks

    # ── Redis (for distributed rate limiting / caching) ──────────────────────────
    redis_url: str = "redis://localhost:6379/0"
    use_redis_rate_limiter: bool = False

    # ── PII Redaction ────────────────────────────────────────────────────────────
    # Scrub email, phone, SSN, credit-card, IBAN, IP address before persisting
    # to audit log / query cache. Defense against accidental PII leakage into
    # secondary stores. Regex-based by default; if Microsoft Presidio is
    # installed it is used automatically for higher recall.
    pii_redaction_enabled: bool = True

    # ── Prompt-Injection Guardrails ──────────────────────────────────────────────
    # Run a regex + heuristic check on the user query before retrieval. Blocks
    # obvious jailbreak / system-prompt-override attempts. Logged via the audit
    # logger as ``security_block`` events.
    guardrails_enabled: bool = True
    # Strict mode: after the fast regex gate, escalate ambiguous or all queries
    # to a local LLM-based classifier for a second opinion. Adds one LLM call
    # per query but catches adversarial inputs that evade regex patterns.
    guardrails_strict: bool = False
    # Escalation backend used in strict mode. Options:
    #   "llm"        — legacy SAFE/UNSAFE prompt on the synth-grade model
    #                  (core.agents.guardrails_llm). Default for backward
    #                  compatibility.
    #   "llamaguard" — Meta's LlamaGuard 3 8B via Ollama. Use with
    #                  ``ollama pull llama-guard3:8b``. More accurate on
    #                  the standard S1-S14 taxonomy.
    guardrails_backend: str = "llm"
    llamaguard_model: str = "llama-guard3:8b"

    # ── Contextual Retrieval (Anthropic 2024 technique) ──────────────────────────
    # Prepend a short LLM-generated context summary to each chunk before
    # embedding. Adds 1 cheap LLM call per chunk at ingestion time but
    # measurably improves retrieval recall (Anthropic reported ~35-49%
    # failure reduction). Local Qwen3-8B is fine for the summary.
    contextual_retrieval_enabled: bool = False

    # ── VLM OCR (Primary OCR via vision-language model) ───────────────────────────
    # Use a VLM (Qwen2.5-VL / Qwen3-VL, LLaVA, etc.) via Ollama as the primary OCR path.
    # Superior to PaddleOCR on complex layouts, tables, and mixed-language
    # documents. Falls back to PaddleOCR when the VLM is unavailable.
    vlm_ocr_enabled: bool = False
    vlm_ocr_model: str = "qwen2.5-vl"

    # ── Multi-Tenancy ────────────────────────────────────────────────────────────
    # When true, each organization gets its own Qdrant collection
    # (documents_{org_id}). This provides stronger isolation than payload-level
    # RBAC filtering but requires creating collections per org on first use.
    # When false, all docs share a single collection with RBAC at payload level.
    multi_tenant_collections: bool = False

    # ── BYOK demo mode (P6 production launch, see launch-plan/03-backend-byok.md)
    # In BYOK mode the FastAPI surface accepts per-request LLM keys from visitor
    # headers, scopes Qdrant writes to per-session collections, and disables
    # Phoenix instrumentation. Off in dev/staging, on in the Hugging Face Space
    # production image (SAR_BYOK_MODE=true via Space secrets).
    byok_mode: bool = False
    # When BYOK is on and a visitor did NOT bring their own LLM key, the owner
    # key in .env is used but throttled to this many requests per IP per hour.
    # The cap is intentionally tight so the Groq free-tier 30 RPM / 14400 RPD
    # is never exhausted by a single visitor.
    byok_owner_key_quota_per_hour: int = 3
    # Per-session Qdrant collections (documents_sess_<session_id>) are auto
    # purged after this many hours by retrieval/session_purge.py.
    session_collection_ttl_hours: int = 24
    # CORS allowlist consulted by the FastAPI middleware when byok_mode=true.
    # Empty list = no CORS middleware mounted (dev default).
    cors_allow_origins: list[str] = []

    # ── Multi-Modal RAG ──────────────────────────────────────────────────────────
    # When ingesting images, also generate a rich text description using a VLM.
    # The description is embedded as a separate chunk, enabling retrieval for
    # queries like "what does the diagram show?" without requiring CLIP or
    # other multi-modal embedding models.
    multimodal_descriptions_enabled: bool = False

    # ── Self-Query Retrieval ─────────────────────────────────────────────────────
    # Extract structured metadata filters (source_file, date_range,
    # sensitivity_level, roles) from the natural language query using a small
    # local LLM prompt. The filters are merged with the RBAC filter and passed
    # to Qdrant, scoping retrieval before embedding search runs.
    self_query_enabled: bool = False

    # ── HyDE (Hypothetical Document Embeddings) ──────────────────────────────────
    # Generate a hypothetical answer to the query, embed *that* instead of the
    # raw query. Boosts recall when query vocabulary differs from doc
    # vocabulary (questions vs declarative sentences). Adds one LLM call per
    # query — skip for simple keyword lookups; enable for complex questions.
    hyde_enabled: bool = False

    # ── Pricing for cost dashboard (USD per 1M tokens) ───────────────────────────
    # Used by evaluation/cost.py to convert recorded usage into $/query.
    price_groq_input_per_1m: float = 0.59
    price_groq_output_per_1m: float = 0.79
    price_openai_input_per_1m: float = 2.50
    price_openai_output_per_1m: float = 10.00
    price_anthropic_input_per_1m: float = 3.00
    price_anthropic_output_per_1m: float = 15.00
    # Local inference: estimated electricity cost only (consumer hardware).
    # 200W GPU @ $0.15/kWh ≈ $0.03/hour ≈ $0.000008/sec
    price_local_per_second: float = 0.000008
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def _apply_calibration(settings_obj: Settings) -> None:
    """Override threshold defaults from ``evaluation/calibration.json`` when present.

    The calibration script (``scripts/calibrate_thresholds.py``) writes the
    chosen confidence + faithfulness cutoffs against a labelled gold set. Loading
    them here means deployments inherit the latest tuned values automatically,
    while an explicit ``SAR_CONFIDENCE_THRESHOLD`` / ``SAR_FAITHFULNESS_THRESHOLD``
    env var still wins so operators can override per environment.

    Silently no-ops when the file is missing, unreadable, not valid UTF-8 JSON,
    not a JSON object, or when the relevant blocks/keys are absent or of the
    wrong type -- never blocks startup.

    Args:
        settings_obj: The settings instance whose ``confidence_threshold`` and
            ``faithfulness_threshold`` attributes may be overwritten in place.
    """
    calib_path = Path(__file__).resolve().parent.parent / "evaluation" / "calibration.json"
    try:
        data = json.loads(calib_path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # OSError covers a missing or unreadable file; ValueError covers both
        # json.JSONDecodeError and UnicodeDecodeError (bad encoding).
        return
    if not isinstance(data, dict):
        # A syntactically valid file whose root is a list/number/null is still
        # "malformed" for our purposes.
        return

    # Reject degenerate sweeps (no negatives or no positives -> the chosen
    # threshold has no statistical meaning). Keeping the original default in
    # that case is safer than letting a 0.0 cut-off escape into production.
    def _sane(block: object) -> bool:
        if not isinstance(block, dict):
            return False
        try:
            return (
                int(block.get("n_pos", 0)) > 0
                and int(block.get("n_neg", 0)) > 0
                and float(block.get("chosen_threshold", 0.0)) > 0.0
            )
        except (TypeError, ValueError, OverflowError):
            # OverflowError: json accepts ``Infinity``, and int(inf) overflows.
            return False

    # NOTE(review): only the process environment is consulted here; a value
    # supplied solely through the ``.env`` file is not visible in os.environ
    # and would be overridden by calibration -- confirm that is intended.
    for block_key, env_var, attr_name in (
        ("confidence", "SAR_CONFIDENCE_THRESHOLD", "confidence_threshold"),
        ("faithfulness", "SAR_FAITHFULNESS_THRESHOLD", "faithfulness_threshold"),
    ):
        block = data.get(block_key, {})
        if _sane(block) and os.environ.get(env_var) is None:
            with contextlib.suppress(TypeError, ValueError):
                setattr(settings_obj, attr_name, float(block["chosen_threshold"]))
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
# Singleton instance — import this throughout the application.
# Built once at import time; the calibration overlay then runs on the same
# object and may overwrite its two threshold attributes in place.
settings = Settings()
_apply_calibration(settings)
|
inference/cloud_clients.py
CHANGED
|
@@ -1,577 +1,577 @@
|
|
| 1 |
-
"""Cloud LLM provider clients (Groq, OpenAI, Anthropic Claude)."""
|
| 2 |
-
|
| 3 |
-
from __future__ import annotations
|
| 4 |
-
|
| 5 |
-
import json
|
| 6 |
-
import time
|
| 7 |
-
from abc import ABC, abstractmethod
|
| 8 |
-
from enum import StrEnum
|
| 9 |
-
from typing import TYPE_CHECKING, Any
|
| 10 |
-
|
| 11 |
-
if TYPE_CHECKING:
|
| 12 |
-
from collections.abc import AsyncGenerator
|
| 13 |
-
|
| 14 |
-
import httpx
|
| 15 |
-
from tenacity import (
|
| 16 |
-
retry,
|
| 17 |
-
retry_if_exception_type,
|
| 18 |
-
stop_after_attempt,
|
| 19 |
-
wait_exponential,
|
| 20 |
-
)
|
| 21 |
-
|
| 22 |
-
from config.settings import settings
|
| 23 |
-
from inference.llm_factory import LLMResponse
|
| 24 |
-
from utils.logging import get_logger
|
| 25 |
-
|
| 26 |
-
# Module-level logger named after this module.
logger = get_logger(__name__)

# Retry decorator for transient connection failures only.
# Retries solely on httpx.ConnectError / httpx.TimeoutException, makes at most
# three attempts with exponential backoff bounded to 1-10 seconds, and
# re-raises the last underlying exception instead of tenacity's RetryError.
_retry_on_connection = retry(
    retry=retry_if_exception_type((httpx.ConnectError, httpx.TimeoutException)),
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=1, max=10),
    reraise=True,
)
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
class LLMProvider(StrEnum):
    """Supported LLM provider identifiers.

    Members are ``StrEnum`` values, so each compares equal to its lowercase
    string form (e.g. ``LLMProvider.GROQ == "groq"``).
    """

    OLLAMA = "ollama"
    GROQ = "groq"
    OPENAI = "openai"
    ANTHROPIC = "anthropic"
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
class BaseCloudClient(ABC):
    """Abstract base class for cloud LLM provider clients.

    Owns one ``httpx.AsyncClient`` per instance and exposes it to subclasses
    as ``self._client``. Instances are usable as async context managers; the
    HTTP client is closed on exit.

    Args:
        api_key: Provider API key for authentication.
        model: Default model identifier.
        timeout: Request timeout in seconds.
    """

    def __init__(self, api_key: str, model: str, timeout: float = 60.0) -> None:
        self.api_key = api_key
        self.model = model
        self.timeout = timeout
        self._client = httpx.AsyncClient(timeout=httpx.Timeout(timeout))

    @abstractmethod
    async def generate(
        self,
        prompt: str,
        system_prompt: str = "",
        temperature: float = 0.7,
        max_tokens: int = 2048,
        json_mode: bool = False,
    ) -> LLMResponse:
        """Generate a completion from the provider.

        Args:
            prompt: The user prompt text.
            system_prompt: Optional system context.
            temperature: Sampling temperature.
            max_tokens: Maximum tokens to generate.
            json_mode: When True, request JSON-formatted output.

        Returns:
            LLMResponse with generated text and metadata.
        """

    @abstractmethod
    async def chat(
        self,
        messages: list[dict],
        temperature: float = 0.7,
        max_tokens: int = 2048,
        json_mode: bool = False,
    ) -> LLMResponse:
        """Send a chat conversation to the provider.

        Args:
            messages: List of message dicts with 'role' and 'content' keys.
            temperature: Sampling temperature.
            max_tokens: Maximum tokens to generate.
            json_mode: When True, request JSON-formatted output. Declared here
                so the abstract contract matches ``generate`` and the
                OpenAI-compatible implementation, which forwards this flag.

        Returns:
            LLMResponse with generated text and metadata.
        """

    @abstractmethod
    async def generate_stream(
        self,
        prompt: str,
        system_prompt: str = "",
        temperature: float = 0.7,
        max_tokens: int = 2048,
    ) -> AsyncGenerator[str, None]:
        """Stream a completion from the provider, yielding tokens as they arrive.

        Args:
            prompt: The user prompt text.
            system_prompt: Optional system context.
            temperature: Sampling temperature.
            max_tokens: Maximum tokens to generate.

        Yields:
            Token strings as they are generated.
        """

    @abstractmethod
    async def health_check(self) -> bool:
        """Check if the provider API is reachable.

        Returns:
            True if the API responds successfully.
        """

    async def close(self) -> None:
        """Close the underlying HTTP client."""
        await self._client.aclose()

    async def __aenter__(self) -> BaseCloudClient:
        """Enter async context manager."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: object,
    ) -> None:
        """Exit async context manager, closing the client.

        Exceptions raised inside the ``async with`` body are not suppressed.
        """
        await self.close()
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
def make_byok_cloud_client(
|
| 143 |
-
*,
|
| 144 |
-
provider: str,
|
| 145 |
-
user_key: str,
|
| 146 |
-
model: str | None = None,
|
| 147 |
-
timeout: float = 60.0,
|
| 148 |
-
) -> BaseCloudClient:
|
| 149 |
-
"""Build a per-request cloud LLM client that uses the visitor's API key.
|
| 150 |
-
|
| 151 |
-
Each call returns a **fresh client instance** holding the supplied key
|
| 152 |
-
in its own ``self.api_key`` slot. The visitor's key never lands on any
|
| 153 |
-
module-level singleton, never mixes into the owner-key client, and is
|
| 154 |
-
discarded when the FastAPI request scope ends.
|
| 155 |
-
|
| 156 |
-
Args:
|
| 157 |
-
provider: One of ``"groq"`` / ``"openai"`` / ``"anthropic"``.
|
| 158 |
-
user_key: The visitor-supplied API key from ``X-User-LLM-Key``.
|
| 159 |
-
model: Override the provider's default model.
|
| 160 |
-
timeout: Per-request HTTP timeout in seconds.
|
| 161 |
-
|
| 162 |
-
Returns:
|
| 163 |
-
A new ``BaseCloudClient`` subclass instance bound to the visitor key.
|
| 164 |
-
|
| 165 |
-
Raises:
|
| 166 |
-
ValueError: ``provider`` is not in the BYOK allowlist or ``user_key``
|
| 167 |
-
is missing.
|
| 168 |
-
"""
|
| 169 |
-
if not user_key or not user_key.strip():
|
| 170 |
-
raise ValueError("make_byok_cloud_client called without a user key")
|
| 171 |
-
prov = (provider or "").lower()
|
| 172 |
-
if prov == "groq":
|
| 173 |
-
return GroqClient(
|
| 174 |
-
api_key=user_key.strip(), model=model or "llama-3.1-8b-instant", timeout=timeout
|
| 175 |
-
)
|
| 176 |
-
if prov == "openai":
|
| 177 |
-
return OpenAIClient(api_key=user_key.strip(), model=model or "gpt-4o-mini", timeout=timeout)
|
| 178 |
-
if prov == "anthropic":
|
| 179 |
-
return AnthropicClient(
|
| 180 |
-
api_key=user_key.strip(),
|
| 181 |
-
model=model or "claude-sonnet-4-20250514",
|
| 182 |
-
timeout=timeout,
|
| 183 |
-
)
|
| 184 |
-
raise ValueError(f"BYOK provider not supported: {provider!r}")
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
class OpenAICompatibleClient(BaseCloudClient):
    """Shared client for OpenAI Chat Completions-compatible APIs.

    Both Groq and OpenAI implement the same wire format
    (``POST /chat/completions`` + SSE streaming). Subclasses supply only
    the ``api_base`` URL and the ``provider`` tag — every method on
    ``BaseCloudClient`` is implemented once, here, and inherited.
    """

    #: Subclasses override these two class attrs.
    api_base: str = ""
    provider_name: str = ""

    def _headers(self) -> dict[str, str]:
        """Return the bearer-auth and JSON content-type headers for a request."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    @staticmethod
    def _messages(prompt: str, system_prompt: str) -> list[dict[str, str]]:
        """Build the chat message list: optional system message, then the user prompt."""
        out: list[dict[str, str]] = []
        if system_prompt:
            out.append({"role": "system", "content": system_prompt})
        out.append({"role": "user", "content": prompt})
        return out

    @staticmethod
    def _first_choice(data: dict[str, Any]) -> dict[str, Any]:
        """Return the first ``choices`` entry, or ``{}`` when it is missing, empty or null."""
        choices = data.get("choices") or [{}]
        return choices[0] or {}

    # Deliberately NOT wrapped in ``_retry_on_connection``: ``chat`` already
    # retries, and stacking both decorators multiplied the attempts (3 x 3).
    async def generate(
        self,
        prompt: str,
        system_prompt: str = "",
        temperature: float = 0.7,
        max_tokens: int = 2048,
        json_mode: bool = False,
    ) -> LLMResponse:
        """Generate a completion for a single prompt by delegating to ``chat``.

        Args:
            prompt: The user prompt text.
            system_prompt: Optional system context; omitted from the request when empty.
            temperature: Sampling temperature.
            max_tokens: Maximum tokens to generate.
            json_mode: When True, request a JSON object response.

        Returns:
            The ``LLMResponse`` produced by ``chat``.
        """
        return await self.chat(
            messages=self._messages(prompt, system_prompt),
            temperature=temperature,
            max_tokens=max_tokens,
            json_mode=json_mode,
        )

    @_retry_on_connection
    async def chat(
        self,
        messages: list[dict],
        temperature: float = 0.7,
        max_tokens: int = 2048,
        json_mode: bool = False,
    ) -> LLMResponse:
        """POST a conversation to ``{api_base}/chat/completions``.

        Args:
            messages: List of message dicts with 'role' and 'content' keys.
            temperature: Sampling temperature.
            max_tokens: Maximum tokens to generate.
            json_mode: When True, adds ``response_format={"type": "json_object"}``.

        Returns:
            LLMResponse carrying the first choice's text, token usage and the
            measured request latency in milliseconds.

        Raises:
            httpx.HTTPStatusError: If the provider answers with a 4xx/5xx status.
        """
        payload: dict[str, Any] = {
            "model": self.model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        if json_mode:
            payload["response_format"] = {"type": "json_object"}

        start = time.perf_counter()
        response = await self._client.post(
            f"{self.api_base}/chat/completions",
            headers=self._headers(),
            json=payload,
        )
        elapsed_ms = (time.perf_counter() - start) * 1000
        response.raise_for_status()

        data = response.json()
        # ``or {}`` guards: providers may send explicit nulls for these keys.
        message = self._first_choice(data).get("message") or {}
        usage = data.get("usage") or {}

        return LLMResponse(
            # ``content`` can be null (e.g. tool-call replies); keep ``text`` a str.
            text=message.get("content") or "",
            model=data.get("model", self.model),
            provider=self.provider_name,
            usage={
                "prompt_tokens": usage.get("prompt_tokens", 0),
                "completion_tokens": usage.get("completion_tokens", 0),
                "total_tokens": usage.get("total_tokens", 0),
            },
            latency_ms=elapsed_ms,
        )

    # NOTE(review): this is an async generator, so connection errors surface
    # while iterating rather than at call time; the retry wrapper likely never
    # fires here — confirm before relying on it.
    @_retry_on_connection
    async def generate_stream(
        self,
        prompt: str,
        system_prompt: str = "",
        temperature: float = 0.7,
        max_tokens: int = 2048,
    ) -> AsyncGenerator[str, None]:
        """Stream a completion over SSE, yielding non-empty content tokens.

        Args:
            prompt: The user prompt text.
            system_prompt: Optional system context.
            temperature: Sampling temperature.
            max_tokens: Maximum tokens to generate.

        Yields:
            Token strings taken from each chunk's ``delta.content``.

        Raises:
            httpx.HTTPStatusError: If the provider answers with a 4xx/5xx status.
        """
        payload: dict[str, Any] = {
            "model": self.model,
            "messages": self._messages(prompt, system_prompt),
            "temperature": temperature,
            "max_tokens": max_tokens,
            "stream": True,
        }
        async with self._client.stream(
            "POST",
            f"{self.api_base}/chat/completions",
            headers={**self._headers(), "Accept": "text/event-stream"},
            json=payload,
        ) as resp:
            resp.raise_for_status()
            async for line in resp.aiter_lines():
                line = line.strip()
                # SSE allows both "data: x" and "data:x"; skip comments,
                # keep-alives and other field types.
                if not line.startswith("data:"):
                    continue
                data_str = line[5:].lstrip()
                if data_str == "[DONE]":
                    break
                try:
                    data = json.loads(data_str)
                except json.JSONDecodeError:
                    continue
                if not isinstance(data, dict):
                    continue
                # Chunks with an empty ``choices`` list (e.g. usage-only
                # chunks) must be skipped, not crash the stream.
                delta = self._first_choice(data).get("delta") or {}
                token = delta.get("content", "")
                if token:
                    yield token

    @_retry_on_connection
    async def health_check(self) -> bool:
        """Probe ``{api_base}/models`` to see whether the API is reachable.

        Returns:
            True when the endpoint answers 200 or 401 (reachable even if the
            key is rejected); False on any httpx transport or protocol error.
        """
        try:
            response = await self._client.get(f"{self.api_base}/models", headers=self._headers())
            return response.status_code in (200, 401)
        except httpx.HTTPError:
            # Superset of ConnectError/TimeoutException: a health probe should
            # report "unhealthy" rather than raise on read/protocol errors too.
            return False
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
class GroqClient(OpenAICompatibleClient):
|
| 321 |
-
"""Groq cloud LLM client (OpenAI-compatible API at api.groq.com)."""
|
| 322 |
-
|
| 323 |
-
provider_name = "groq"
|
| 324 |
-
|
| 325 |
-
def __init__(
|
| 326 |
-
self,
|
| 327 |
-
api_key: str,
|
| 328 |
-
model: str = "llama-3.3-70b-versatile",
|
| 329 |
-
timeout: float = 60.0,
|
| 330 |
-
) -> None:
|
| 331 |
-
super().__init__(api_key=api_key, model=model, timeout=timeout)
|
| 332 |
-
self.api_base = settings.groq_api_base
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
class OpenAIClient(OpenAICompatibleClient):
|
| 336 |
-
"""OpenAI cloud LLM client (Chat Completions API at api.openai.com)."""
|
| 337 |
-
|
| 338 |
-
provider_name = "openai"
|
| 339 |
-
|
| 340 |
-
def __init__(
|
| 341 |
-
self,
|
| 342 |
-
api_key: str,
|
| 343 |
-
model: str = "gpt-4o-mini",
|
| 344 |
-
timeout: float = 60.0,
|
| 345 |
-
) -> None:
|
| 346 |
-
super().__init__(api_key=api_key, model=model, timeout=timeout)
|
| 347 |
-
self.api_base = settings.openai_api_base
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
class AnthropicClient(BaseCloudClient):
|
| 351 |
-
"""Anthropic Claude cloud LLM client using the Messages API.
|
| 352 |
-
|
| 353 |
-
Args:
|
| 354 |
-
api_key: Anthropic API key.
|
| 355 |
-
model: Model identifier. Defaults to "claude-sonnet-4-20250514".
|
| 356 |
-
timeout: Request timeout in seconds.
|
| 357 |
-
"""
|
| 358 |
-
|
| 359 |
-
def __init__(
|
| 360 |
-
self,
|
| 361 |
-
api_key: str,
|
| 362 |
-
model: str = "claude-sonnet-4-20250514",
|
| 363 |
-
timeout: float = 60.0,
|
| 364 |
-
) -> None:
|
| 365 |
-
super().__init__(api_key=api_key, model=model, timeout=timeout)
|
| 366 |
-
self._api_base = settings.anthropic_api_base
|
| 367 |
-
|
| 368 |
-
def _headers(self) -> dict[str, str]:
|
| 369 |
-
"""Build request headers with Anthropic-specific authentication."""
|
| 370 |
-
return {
|
| 371 |
-
"x-api-key": self.api_key,
|
| 372 |
-
"anthropic-version": "2023-06-01",
|
| 373 |
-
"Content-Type": "application/json",
|
| 374 |
-
}
|
| 375 |
-
|
| 376 |
-
@_retry_on_connection
|
| 377 |
-
async def generate(
|
| 378 |
-
self,
|
| 379 |
-
prompt: str,
|
| 380 |
-
system_prompt: str = "",
|
| 381 |
-
temperature: float = 0.7,
|
| 382 |
-
max_tokens: int = 2048,
|
| 383 |
-
json_mode: bool = False,
|
| 384 |
-
) -> LLMResponse:
|
| 385 |
-
"""Generate a completion via Anthropic's Messages API.
|
| 386 |
-
|
| 387 |
-
Args:
|
| 388 |
-
prompt: The user prompt text.
|
| 389 |
-
system_prompt: Optional system context.
|
| 390 |
-
temperature: Sampling temperature.
|
| 391 |
-
max_tokens: Maximum tokens to generate.
|
| 392 |
-
json_mode: Anthropic does not support native JSON mode; ignored.
|
| 393 |
-
|
| 394 |
-
Returns:
|
| 395 |
-
LLMResponse with generated text and metadata.
|
| 396 |
-
"""
|
| 397 |
-
messages: list[dict[str, str]] = [{"role": "user", "content": prompt}]
|
| 398 |
-
return await self._send_messages(
|
| 399 |
-
messages=messages,
|
| 400 |
-
system_prompt=system_prompt,
|
| 401 |
-
temperature=temperature,
|
| 402 |
-
max_tokens=max_tokens,
|
| 403 |
-
)
|
| 404 |
-
|
| 405 |
-
@_retry_on_connection
|
| 406 |
-
async def chat(
|
| 407 |
-
self,
|
| 408 |
-
messages: list[dict],
|
| 409 |
-
temperature: float = 0.7,
|
| 410 |
-
max_tokens: int = 2048,
|
| 411 |
-
) -> LLMResponse:
|
| 412 |
-
"""Send a chat request to Anthropic's Messages API.
|
| 413 |
-
|
| 414 |
-
Anthropic uses a separate 'system' parameter instead of a system message
|
| 415 |
-
in the messages list. This method extracts any system message and handles
|
| 416 |
-
the format conversion.
|
| 417 |
-
|
| 418 |
-
Args:
|
| 419 |
-
messages: List of message dicts with 'role' and 'content' keys.
|
| 420 |
-
temperature: Sampling temperature.
|
| 421 |
-
max_tokens: Maximum tokens to generate.
|
| 422 |
-
|
| 423 |
-
Returns:
|
| 424 |
-
LLMResponse with generated text and metadata.
|
| 425 |
-
"""
|
| 426 |
-
# Extract system message if present
|
| 427 |
-
system_prompt = ""
|
| 428 |
-
anthropic_messages: list[dict[str, str]] = []
|
| 429 |
-
for msg in messages:
|
| 430 |
-
if msg.get("role") == "system":
|
| 431 |
-
system_prompt = msg.get("content", "")
|
| 432 |
-
else:
|
| 433 |
-
anthropic_messages.append(msg)
|
| 434 |
-
|
| 435 |
-
return await self._send_messages(
|
| 436 |
-
messages=anthropic_messages,
|
| 437 |
-
system_prompt=system_prompt,
|
| 438 |
-
temperature=temperature,
|
| 439 |
-
max_tokens=max_tokens,
|
| 440 |
-
)
|
| 441 |
-
|
| 442 |
-
async def _send_messages(
|
| 443 |
-
self,
|
| 444 |
-
messages: list[dict],
|
| 445 |
-
system_prompt: str = "",
|
| 446 |
-
temperature: float = 0.7,
|
| 447 |
-
max_tokens: int = 2048,
|
| 448 |
-
) -> LLMResponse:
|
| 449 |
-
"""Internal method to send messages to Anthropic's API.
|
| 450 |
-
|
| 451 |
-
Args:
|
| 452 |
-
messages: Anthropic-formatted messages (no system role).
|
| 453 |
-
system_prompt: System prompt passed as top-level parameter.
|
| 454 |
-
temperature: Sampling temperature.
|
| 455 |
-
max_tokens: Maximum tokens to generate.
|
| 456 |
-
|
| 457 |
-
Returns:
|
| 458 |
-
LLMResponse with generated text and metadata.
|
| 459 |
-
"""
|
| 460 |
-
payload: dict[str, Any] = {
|
| 461 |
-
"model": self.model,
|
| 462 |
-
"messages": messages,
|
| 463 |
-
"temperature": temperature,
|
| 464 |
-
"max_tokens": max_tokens,
|
| 465 |
-
}
|
| 466 |
-
if system_prompt:
|
| 467 |
-
payload["system"] = system_prompt
|
| 468 |
-
|
| 469 |
-
start = time.perf_counter()
|
| 470 |
-
response = await self._client.post(
|
| 471 |
-
f"{self._api_base}/messages",
|
| 472 |
-
headers=self._headers(),
|
| 473 |
-
json=payload,
|
| 474 |
-
)
|
| 475 |
-
elapsed_ms = (time.perf_counter() - start) * 1000
|
| 476 |
-
response.raise_for_status()
|
| 477 |
-
|
| 478 |
-
data = response.json()
|
| 479 |
-
# Anthropic returns content as a list of content blocks
|
| 480 |
-
content_blocks = data.get("content", [])
|
| 481 |
-
text = ""
|
| 482 |
-
for block in content_blocks:
|
| 483 |
-
if block.get("type") == "text":
|
| 484 |
-
text += block.get("text", "")
|
| 485 |
-
|
| 486 |
-
usage = data.get("usage", {})
|
| 487 |
-
return LLMResponse(
|
| 488 |
-
text=text,
|
| 489 |
-
model=data.get("model", self.model),
|
| 490 |
-
provider="anthropic",
|
| 491 |
-
usage={
|
| 492 |
-
"prompt_tokens": usage.get("input_tokens", 0),
|
| 493 |
-
"completion_tokens": usage.get("output_tokens", 0),
|
| 494 |
-
"total_tokens": (usage.get("input_tokens", 0) + usage.get("output_tokens", 0)),
|
| 495 |
-
},
|
| 496 |
-
latency_ms=elapsed_ms,
|
| 497 |
-
)
|
| 498 |
-
|
| 499 |
-
async def generate_stream(
|
| 500 |
-
self,
|
| 501 |
-
prompt: str,
|
| 502 |
-
system_prompt: str = "",
|
| 503 |
-
temperature: float = 0.7,
|
| 504 |
-
max_tokens: int = 2048,
|
| 505 |
-
) -> AsyncGenerator[str, None]:
|
| 506 |
-
"""Stream a completion via Anthropic's Messages API.
|
| 507 |
-
|
| 508 |
-
Anthropic supports streaming via SSE. Yields text content blocks
|
| 509 |
-
as they arrive.
|
| 510 |
-
|
| 511 |
-
Args:
|
| 512 |
-
prompt: The user prompt text.
|
| 513 |
-
system_prompt: Optional system context.
|
| 514 |
-
temperature: Sampling temperature.
|
| 515 |
-
max_tokens: Maximum tokens to generate.
|
| 516 |
-
|
| 517 |
-
Yields:
|
| 518 |
-
Token strings as they are generated.
|
| 519 |
-
"""
|
| 520 |
-
payload: dict[str, Any] = {
|
| 521 |
-
"model": self.model,
|
| 522 |
-
"messages": [{"role": "user", "content": prompt}],
|
| 523 |
-
"temperature": temperature,
|
| 524 |
-
"max_tokens": max_tokens,
|
| 525 |
-
"stream": True,
|
| 526 |
-
}
|
| 527 |
-
if system_prompt:
|
| 528 |
-
payload["system"] = system_prompt
|
| 529 |
-
|
| 530 |
-
async with self._client.stream(
|
| 531 |
-
"POST",
|
| 532 |
-
f"{self._api_base}/messages",
|
| 533 |
-
headers={**self._headers(), "Accept": "text/event-stream"},
|
| 534 |
-
json=payload,
|
| 535 |
-
) as resp:
|
| 536 |
-
resp.raise_for_status()
|
| 537 |
-
async for line in resp.aiter_lines():
|
| 538 |
-
line = line.strip()
|
| 539 |
-
if line.startswith("data: "):
|
| 540 |
-
data_str = line[6:]
|
| 541 |
-
if data_str == "[DONE]":
|
| 542 |
-
break
|
| 543 |
-
try:
|
| 544 |
-
data = json.loads(data_str)
|
| 545 |
-
event_type = data.get("type", "")
|
| 546 |
-
if event_type == "content_block_delta":
|
| 547 |
-
delta = data.get("delta", {})
|
| 548 |
-
token = delta.get("text", "")
|
| 549 |
-
if token:
|
| 550 |
-
yield token
|
| 551 |
-
elif event_type == "message_stop":
|
| 552 |
-
break
|
| 553 |
-
except json.JSONDecodeError:
|
| 554 |
-
continue
|
| 555 |
-
|
| 556 |
-
@_retry_on_connection
|
| 557 |
-
async def health_check(self) -> bool:
|
| 558 |
-
"""Check if the Anthropic API is reachable.
|
| 559 |
-
|
| 560 |
-
Returns:
|
| 561 |
-
True if the API responds.
|
| 562 |
-
"""
|
| 563 |
-
try:
|
| 564 |
-
# Anthropic doesn't have a simple health endpoint; try a minimal request
|
| 565 |
-
response = await self._client.post(
|
| 566 |
-
f"{self._api_base}/messages",
|
| 567 |
-
headers=self._headers(),
|
| 568 |
-
json={
|
| 569 |
-
"model": self.model,
|
| 570 |
-
"messages": [{"role": "user", "content": "hi"}],
|
| 571 |
-
"max_tokens": 1,
|
| 572 |
-
},
|
| 573 |
-
)
|
| 574 |
-
# Any response (even 401) means the service is reachable
|
| 575 |
-
return response.status_code in (200, 401, 400)
|
| 576 |
-
except (httpx.ConnectError, httpx.TimeoutException):
|
| 577 |
-
return False
|
|
|
|
| 1 |
+
"""Cloud LLM provider clients (Groq, OpenAI, Anthropic Claude)."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import time
|
| 7 |
+
from abc import ABC, abstractmethod
|
| 8 |
+
from enum import StrEnum
|
| 9 |
+
from typing import TYPE_CHECKING, Any
|
| 10 |
+
|
| 11 |
+
if TYPE_CHECKING:
|
| 12 |
+
from collections.abc import AsyncGenerator
|
| 13 |
+
|
| 14 |
+
import httpx
|
| 15 |
+
from tenacity import (
|
| 16 |
+
retry,
|
| 17 |
+
retry_if_exception_type,
|
| 18 |
+
stop_after_attempt,
|
| 19 |
+
wait_exponential,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
from config.settings import settings
|
| 23 |
+
from inference.llm_factory import LLMResponse
|
| 24 |
+
from utils.logging import get_logger
|
| 25 |
+
|
| 26 |
+
logger = get_logger(__name__)
|
| 27 |
+
|
| 28 |
+
# Shared tenacity policy for this module: only connection errors and timeouts
# are retried, at most three attempts are made, the pause between attempts
# grows exponentially within 1-10 seconds, and the last exception is re-raised
# unchanged once the attempts are used up.
_retry_on_connection = retry(
    reraise=True,
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=1, max=10),
    retry=retry_if_exception_type((httpx.ConnectError, httpx.TimeoutException)),
)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class LLMProvider(StrEnum):
|
| 38 |
+
"""Supported LLM provider identifiers."""
|
| 39 |
+
|
| 40 |
+
OLLAMA = "ollama"
|
| 41 |
+
GROQ = "groq"
|
| 42 |
+
OPENAI = "openai"
|
| 43 |
+
ANTHROPIC = "anthropic"
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class BaseCloudClient(ABC):
    """Common base for the hosted-LLM clients defined in this module.

    The constructor stores the credentials and creates one
    ``httpx.AsyncClient`` per instance. Subclasses implement the four
    abstract coroutines declared below. An instance may be used with
    ``async with``; leaving the block closes the HTTP client.

    Args:
        api_key: Provider API key used for authentication.
        model: Default model identifier.
        timeout: Request timeout in seconds.
    """

    def __init__(self, api_key: str, model: str, timeout: float = 60.0) -> None:
        self.api_key = api_key
        self.model = model
        self.timeout = timeout
        # The single ``timeout`` value is handed to httpx.Timeout as-is.
        http_timeout = httpx.Timeout(timeout)
        self._client = httpx.AsyncClient(timeout=http_timeout)

    @abstractmethod
    async def generate(
        self,
        prompt: str,
        system_prompt: str = "",
        temperature: float = 0.7,
        max_tokens: int = 2048,
        json_mode: bool = False,
    ) -> LLMResponse:
        """Produce a single completion for ``prompt``.

        Args:
            prompt: The user prompt text.
            system_prompt: Optional system context.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.
            json_mode: When True, ask the provider for JSON-formatted output.

        Returns:
            LLMResponse carrying the generated text and metadata.
        """

    @abstractmethod
    async def chat(
        self,
        messages: list[dict],
        temperature: float = 0.7,
        max_tokens: int = 2048,
    ) -> LLMResponse:
        """Send a multi-message conversation to the provider.

        Args:
            messages: Message dicts, each with 'role' and 'content' keys.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.

        Returns:
            LLMResponse carrying the generated text and metadata.
        """

    @abstractmethod
    async def generate_stream(
        self,
        prompt: str,
        system_prompt: str = "",
        temperature: float = 0.7,
        max_tokens: int = 2048,
    ) -> AsyncGenerator[str, None]:
        """Stream a completion, yielding tokens as they arrive.

        Args:
            prompt: The user prompt text.
            system_prompt: Optional system context.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.

        Yields:
            Token strings in the order they are generated.
        """

    @abstractmethod
    async def health_check(self) -> bool:
        """Report whether the provider API can be reached.

        Returns:
            True if the API responds successfully.
        """

    async def close(self) -> None:
        """Close the underlying ``httpx.AsyncClient``."""
        await self._client.aclose()

    async def __aenter__(self) -> BaseCloudClient:
        """Return the client itself for use inside ``async with``."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Close the client when the ``async with`` block is left."""
        await self.close()
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def make_byok_cloud_client(
|
| 143 |
+
*,
|
| 144 |
+
provider: str,
|
| 145 |
+
user_key: str,
|
| 146 |
+
model: str | None = None,
|
| 147 |
+
timeout: float = 60.0,
|
| 148 |
+
) -> BaseCloudClient:
|
| 149 |
+
"""Build a per-request cloud LLM client that uses the visitor's API key.
|
| 150 |
+
|
| 151 |
+
Each call returns a **fresh client instance** holding the supplied key
|
| 152 |
+
in its own ``self.api_key`` slot. The visitor's key never lands on any
|
| 153 |
+
module-level singleton, never mixes into the owner-key client, and is
|
| 154 |
+
discarded when the FastAPI request scope ends.
|
| 155 |
+
|
| 156 |
+
Args:
|
| 157 |
+
provider: One of ``"groq"`` / ``"openai"`` / ``"anthropic"``.
|
| 158 |
+
user_key: The visitor-supplied API key from ``X-User-LLM-Key``.
|
| 159 |
+
model: Override the provider's default model.
|
| 160 |
+
timeout: Per-request HTTP timeout in seconds.
|
| 161 |
+
|
| 162 |
+
Returns:
|
| 163 |
+
A new ``BaseCloudClient`` subclass instance bound to the visitor key.
|
| 164 |
+
|
| 165 |
+
Raises:
|
| 166 |
+
ValueError: ``provider`` is not in the BYOK allowlist or ``user_key``
|
| 167 |
+
is missing.
|
| 168 |
+
"""
|
| 169 |
+
if not user_key or not user_key.strip():
|
| 170 |
+
raise ValueError("make_byok_cloud_client called without a user key")
|
| 171 |
+
prov = (provider or "").lower()
|
| 172 |
+
if prov == "groq":
|
| 173 |
+
return GroqClient(
|
| 174 |
+
api_key=user_key.strip(), model=model or "llama-3.1-8b-instant", timeout=timeout
|
| 175 |
+
)
|
| 176 |
+
if prov == "openai":
|
| 177 |
+
return OpenAIClient(api_key=user_key.strip(), model=model or "gpt-4o-mini", timeout=timeout)
|
| 178 |
+
if prov == "anthropic":
|
| 179 |
+
return AnthropicClient(
|
| 180 |
+
api_key=user_key.strip(),
|
| 181 |
+
model=model or "claude-sonnet-4-20250514",
|
| 182 |
+
timeout=timeout,
|
| 183 |
+
)
|
| 184 |
+
raise ValueError(f"BYOK provider not supported: {provider!r}")
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class OpenAICompatibleClient(BaseCloudClient):
    """Shared client for OpenAI Chat Completions-compatible APIs.

    Requests go to ``POST {api_base}/chat/completions`` (plain JSON or SSE
    streaming) and the health probe uses ``GET {api_base}/models``.
    Concrete subclasses only provide ``api_base`` and ``provider_name``;
    every abstract method of ``BaseCloudClient`` is implemented here.
    """

    #: Base URL the endpoint paths are appended to; set by subclasses.
    api_base: str = ""
    #: Tag copied into ``LLMResponse.provider``; set by subclasses.
    provider_name: str = ""

    def _headers(self) -> dict[str, str]:
        """Return bearer-token authorization and JSON content-type headers."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    @staticmethod
    def _messages(prompt: str, system_prompt: str) -> list[dict[str, str]]:
        """Build the message list: optional system message, then the user prompt."""
        out: list[dict[str, str]] = []
        if system_prompt:
            out.append({"role": "system", "content": system_prompt})
        out.append({"role": "user", "content": prompt})
        return out

    @staticmethod
    def _first_choice(data: dict[str, Any]) -> dict[str, Any]:
        """Return the first entry of ``data["choices"]``, or ``{}`` if there is none.

        A missing key, a ``null`` value and an empty list are all treated the
        same way, so callers never index into an empty list.
        """
        choices = data.get("choices") or [{}]
        return choices[0] or {}

    async def generate(
        self,
        prompt: str,
        system_prompt: str = "",
        temperature: float = 0.7,
        max_tokens: int = 2048,
        json_mode: bool = False,
    ) -> LLMResponse:
        """Generate a completion for a single prompt by delegating to ``chat``.

        This method is intentionally not wrapped in ``_retry_on_connection``:
        ``chat`` already carries that policy, and decorating both would nest
        the retries (every outer attempt re-running all inner attempts).

        Args:
            prompt: The user prompt text.
            system_prompt: Optional system context, sent as a system message.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.
            json_mode: When True, request ``{"type": "json_object"}`` output.

        Returns:
            LLMResponse with the generated text and metadata.
        """
        return await self.chat(
            messages=self._messages(prompt, system_prompt),
            temperature=temperature,
            max_tokens=max_tokens,
            json_mode=json_mode,
        )

    @_retry_on_connection
    async def chat(
        self,
        messages: list[dict],
        temperature: float = 0.7,
        max_tokens: int = 2048,
        json_mode: bool = False,
    ) -> LLMResponse:
        """Send a conversation to ``/chat/completions`` and return the reply.

        Args:
            messages: Message dicts with 'role' and 'content' keys.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.
            json_mode: When True, add ``response_format={"type": "json_object"}``.

        Returns:
            LLMResponse with the first choice's text (``""`` when the reply
            has no choices or its content is null), token usage and latency.

        Raises:
            httpx.HTTPStatusError: The API answered with a 4xx/5xx status.
        """
        payload: dict[str, Any] = {
            "model": self.model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        if json_mode:
            payload["response_format"] = {"type": "json_object"}

        start = time.perf_counter()
        response = await self._client.post(
            f"{self.api_base}/chat/completions",
            headers=self._headers(),
            json=payload,
        )
        # Latency covers the HTTP round trip only, not the JSON parsing below.
        elapsed_ms = (time.perf_counter() - start) * 1000
        response.raise_for_status()

        data = response.json()
        message = self._first_choice(data).get("message") or {}
        usage = data.get("usage") or {}

        return LLMResponse(
            # ``content`` may be present but null; normalise it to "".
            text=message.get("content") or "",
            model=data.get("model", self.model),
            provider=self.provider_name,
            usage={
                "prompt_tokens": usage.get("prompt_tokens", 0),
                "completion_tokens": usage.get("completion_tokens", 0),
                "total_tokens": usage.get("total_tokens", 0),
            },
            latency_ms=elapsed_ms,
        )

    # NOTE(review): this is an async generator, so the decorator wraps only the
    # call that creates the generator; errors raised while iterating do not
    # appear to be retried. Confirm against tenacity's documentation.
    @_retry_on_connection
    async def generate_stream(
        self,
        prompt: str,
        system_prompt: str = "",
        temperature: float = 0.7,
        max_tokens: int = 2048,
    ) -> AsyncGenerator[str, None]:
        """Stream a completion from ``/chat/completions`` over SSE.

        Only lines starting with ``"data: "`` are considered. Lines whose
        payload is not valid JSON, and chunks without a choice or without
        delta content, are skipped. The ``[DONE]`` sentinel ends the stream.

        Args:
            prompt: The user prompt text.
            system_prompt: Optional system context, sent as a system message.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.

        Yields:
            Non-empty token strings in arrival order.

        Raises:
            httpx.HTTPStatusError: The API answered with a 4xx/5xx status.
        """
        payload: dict[str, Any] = {
            "model": self.model,
            "messages": self._messages(prompt, system_prompt),
            "temperature": temperature,
            "max_tokens": max_tokens,
            "stream": True,
        }
        async with self._client.stream(
            "POST",
            f"{self.api_base}/chat/completions",
            headers={**self._headers(), "Accept": "text/event-stream"},
            json=payload,
        ) as resp:
            resp.raise_for_status()
            async for raw_line in resp.aiter_lines():
                line = raw_line.strip()
                if not line.startswith("data: "):
                    continue
                data_str = line[6:]
                if data_str == "[DONE]":
                    break
                try:
                    data = json.loads(data_str)
                except json.JSONDecodeError:
                    continue
                # Chunks may arrive with an empty ``choices`` list; skip them
                # instead of failing on ``[][0]``.
                delta = self._first_choice(data).get("delta") or {}
                token = delta.get("content") or ""
                if token:
                    yield token

    @_retry_on_connection
    async def health_check(self) -> bool:
        """Probe ``GET {api_base}/models`` to see whether the API is reachable.

        Returns:
            True for a 200 or 401 response (a 401 still shows the service is
            up), False for any other status or when the request fails with a
            connection error or timeout.
        """
        try:
            response = await self._client.get(f"{self.api_base}/models", headers=self._headers())
            return response.status_code in (200, 401)
        except (httpx.ConnectError, httpx.TimeoutException):
            return False
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
class GroqClient(OpenAICompatibleClient):
    """Groq client; reuses the OpenAI-compatible implementation unchanged.

    Args:
        api_key: Groq API key.
        model: Model identifier. Defaults to "llama-3.3-70b-versatile".
        timeout: Request timeout in seconds.
    """

    provider_name = "groq"

    def __init__(
        self,
        api_key: str,
        model: str = "llama-3.3-70b-versatile",
        timeout: float = 60.0,
    ) -> None:
        super().__init__(api_key=api_key, model=model, timeout=timeout)
        # The endpoint root is read from settings when the client is built.
        self.api_base = settings.groq_api_base
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
class OpenAIClient(OpenAICompatibleClient):
    """OpenAI client; reuses the OpenAI-compatible implementation unchanged.

    Args:
        api_key: OpenAI API key.
        model: Model identifier. Defaults to "gpt-4o-mini".
        timeout: Request timeout in seconds.
    """

    provider_name = "openai"

    def __init__(
        self,
        api_key: str,
        model: str = "gpt-4o-mini",
        timeout: float = 60.0,
    ) -> None:
        super().__init__(api_key=api_key, model=model, timeout=timeout)
        # The endpoint root is read from settings when the client is built.
        self.api_base = settings.openai_api_base
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
class AnthropicClient(BaseCloudClient):
    """Anthropic Claude cloud LLM client using the Messages API.

    All requests are sent to ``POST {anthropic_api_base}/messages`` with
    ``x-api-key`` authentication and a pinned ``anthropic-version`` header.

    Args:
        api_key: Anthropic API key.
        model: Model identifier. Defaults to "claude-sonnet-4-20250514".
        timeout: Request timeout in seconds.
    """

    def __init__(
        self,
        api_key: str,
        model: str = "claude-sonnet-4-20250514",
        timeout: float = 60.0,
    ) -> None:
        super().__init__(api_key=api_key, model=model, timeout=timeout)
        self._api_base = settings.anthropic_api_base

    def _headers(self) -> dict[str, str]:
        """Build request headers with Anthropic-specific authentication."""
        return {
            "x-api-key": self.api_key,
            "anthropic-version": "2023-06-01",
            "Content-Type": "application/json",
        }

    @_retry_on_connection
    async def generate(
        self,
        prompt: str,
        system_prompt: str = "",
        temperature: float = 0.7,
        max_tokens: int = 2048,
        json_mode: bool = False,
    ) -> LLMResponse:
        """Generate a completion for a single user prompt.

        Args:
            prompt: The user prompt text.
            system_prompt: Optional system context.
            temperature: Sampling temperature.
            max_tokens: Maximum tokens to generate.
            json_mode: Accepted for interface compatibility and ignored; no
                JSON-mode field is added to the request.

        Returns:
            LLMResponse with generated text and metadata.
        """
        return await self._send_messages(
            messages=[{"role": "user", "content": prompt}],
            system_prompt=system_prompt,
            temperature=temperature,
            max_tokens=max_tokens,
        )

    @_retry_on_connection
    async def chat(
        self,
        messages: list[dict],
        temperature: float = 0.7,
        max_tokens: int = 2048,
    ) -> LLMResponse:
        """Send a chat request to Anthropic's Messages API.

        The request carries the system prompt in a separate top-level
        ``system`` field, so messages with role ``"system"`` are removed
        from the list and passed there instead. A single system message is
        forwarded as-is. Several system messages are joined with a blank line
        between them (empty ones are skipped), so none is silently dropped.

        Args:
            messages: List of message dicts with 'role' and 'content' keys.
            temperature: Sampling temperature.
            max_tokens: Maximum tokens to generate.

        Returns:
            LLMResponse with generated text and metadata.
        """
        system_parts = [m.get("content", "") for m in messages if m.get("role") == "system"]
        conversation = [m for m in messages if m.get("role") != "system"]

        if len(system_parts) > 1:
            system_prompt = "\n\n".join(str(part) for part in system_parts if part)
        else:
            system_prompt = system_parts[0] if system_parts else ""

        return await self._send_messages(
            messages=conversation,
            system_prompt=system_prompt,
            temperature=temperature,
            max_tokens=max_tokens,
        )

    async def _send_messages(
        self,
        messages: list[dict],
        system_prompt: str = "",
        temperature: float = 0.7,
        max_tokens: int = 2048,
    ) -> LLMResponse:
        """Post ``messages`` to ``/messages`` and convert the reply.

        Args:
            messages: Anthropic-formatted messages (no system role).
            system_prompt: System prompt, sent as the top-level ``system``
                field when non-empty.
            temperature: Sampling temperature.
            max_tokens: Maximum tokens to generate.

        Returns:
            LLMResponse whose text is the concatenation of all ``text``
            content blocks, with usage mapped from Anthropic's
            ``input_tokens`` / ``output_tokens`` counters.

        Raises:
            httpx.HTTPStatusError: The API answered with a 4xx/5xx status.
        """
        payload: dict[str, Any] = {
            "model": self.model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        if system_prompt:
            payload["system"] = system_prompt

        start = time.perf_counter()
        response = await self._client.post(
            f"{self._api_base}/messages",
            headers=self._headers(),
            json=payload,
        )
        # Latency covers the HTTP round trip only, not the parsing below.
        elapsed_ms = (time.perf_counter() - start) * 1000
        response.raise_for_status()

        data = response.json()
        # The reply content is a list of typed blocks; keep only the text ones.
        text = "".join(
            block.get("text", "")
            for block in data.get("content", [])
            if block.get("type") == "text"
        )

        usage = data.get("usage", {})
        input_tokens = usage.get("input_tokens", 0)
        output_tokens = usage.get("output_tokens", 0)
        return LLMResponse(
            text=text,
            model=data.get("model", self.model),
            provider="anthropic",
            usage={
                "prompt_tokens": input_tokens,
                "completion_tokens": output_tokens,
                # The total is computed here from the two counters.
                "total_tokens": input_tokens + output_tokens,
            },
            latency_ms=elapsed_ms,
        )

    async def generate_stream(
        self,
        prompt: str,
        system_prompt: str = "",
        temperature: float = 0.7,
        max_tokens: int = 2048,
    ) -> AsyncGenerator[str, None]:
        """Stream a completion via Anthropic's Messages API over SSE.

        Only lines starting with ``"data: "`` are parsed. Text is taken from
        ``content_block_delta`` events; a ``message_stop`` event or a
        ``[DONE]`` payload ends the stream. Payloads that are not valid JSON
        and all other event types are skipped.

        Args:
            prompt: The user prompt text.
            system_prompt: Optional system context.
            temperature: Sampling temperature.
            max_tokens: Maximum tokens to generate.

        Yields:
            Non-empty token strings as they are generated.

        Raises:
            httpx.HTTPStatusError: The API answered with a 4xx/5xx status.
        """
        payload: dict[str, Any] = {
            "model": self.model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": temperature,
            "max_tokens": max_tokens,
            "stream": True,
        }
        if system_prompt:
            payload["system"] = system_prompt

        async with self._client.stream(
            "POST",
            f"{self._api_base}/messages",
            headers={**self._headers(), "Accept": "text/event-stream"},
            json=payload,
        ) as resp:
            resp.raise_for_status()
            async for raw_line in resp.aiter_lines():
                line = raw_line.strip()
                if not line.startswith("data: "):
                    continue
                data_str = line[6:]
                if data_str == "[DONE]":
                    break
                try:
                    event = json.loads(data_str)
                except json.JSONDecodeError:
                    continue
                event_type = event.get("type", "")
                if event_type == "message_stop":
                    break
                if event_type != "content_block_delta":
                    continue
                token = event.get("delta", {}).get("text", "")
                if token:
                    yield token

    @_retry_on_connection
    async def health_check(self) -> bool:
        """Check if the Anthropic API is reachable.

        No dedicated health endpoint is used; the probe posts a minimal
        one-token message request instead.
        NOTE(review): this is a real completion request and is presumably
        billed like any other call. Confirm this is acceptable for the
        polling frequency in use.

        Returns:
            True when the API returns any status below 500, including auth,
            permission, validation and rate-limit errors, since those still
            show the service is up. False on a 5xx status, a connection
            error or a timeout.
        """
        try:
            response = await self._client.post(
                f"{self._api_base}/messages",
                headers=self._headers(),
                json={
                    "model": self.model,
                    "messages": [{"role": "user", "content": "hi"}],
                    "max_tokens": 1,
                },
            )
            # Any non-server-error answer (even 401/403/404/429) means the
            # service itself is reachable.
            return response.status_code < 500
        except (httpx.ConnectError, httpx.TimeoutException):
            return False
|
inference/ollama_client.py
CHANGED
|
@@ -1,334 +1,334 @@
|
|
| 1 |
-
"""Async Ollama client wrapper with streaming support and health checks."""
|
| 2 |
-
|
| 3 |
-
from __future__ import annotations
|
| 4 |
-
|
| 5 |
-
import time
|
| 6 |
-
from typing import TYPE_CHECKING, Any
|
| 7 |
-
|
| 8 |
-
import httpx
|
| 9 |
-
|
| 10 |
-
if TYPE_CHECKING:
|
| 11 |
-
from collections.abc import AsyncGenerator
|
| 12 |
-
from tenacity import (
|
| 13 |
-
retry,
|
| 14 |
-
retry_if_exception_type,
|
| 15 |
-
stop_after_attempt,
|
| 16 |
-
wait_exponential,
|
| 17 |
-
)
|
| 18 |
-
|
| 19 |
-
from config.settings import settings
|
| 20 |
-
from inference.llm_factory import LLMResponse
|
| 21 |
-
from utils.logging import get_logger
|
| 22 |
-
|
| 23 |
-
logger = get_logger(__name__)
|
| 24 |
-
|
| 25 |
-
# Retry decorator for transient connection failures only
|
| 26 |
-
_retry_on_connection = retry(
|
| 27 |
-
retry=retry_if_exception_type((httpx.ConnectError, httpx.TimeoutException)),
|
| 28 |
-
stop=stop_after_attempt(3),
|
| 29 |
-
wait=wait_exponential(multiplier=1, min=1, max=10),
|
| 30 |
-
reraise=True,
|
| 31 |
-
)
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
def make_byok_ollama_client(
    *,
    base_url: str,
    model: str | None = None,
    timeout: float = 60.0,
) -> OllamaClient:
    """Build a per-request Ollama client bound to the visitor's instance URL.

    Visitors running their own Ollama can paste the public URL of that
    instance into the frontend. Each call returns a **fresh client**, so the
    visitor's URL never replaces the owner default at module scope.

    Args:
        base_url: Absolute ``http``/``https`` URL of the visitor's Ollama
            server (HTTPS preferred). Surrounding whitespace is ignored.
        model: Model to use. When ``None``, ``OllamaClient`` falls back to
            its configured default.
        timeout: Per-request HTTP timeout in seconds.

    Returns:
        A new ``OllamaClient`` bound to ``base_url``.

    Raises:
        ValueError: ``base_url`` is empty or whitespace, or is not an
            absolute ``http``/``https`` URL.
    """
    from urllib.parse import urlsplit

    cleaned = base_url.strip() if base_url else ""
    if not cleaned:
        raise ValueError("make_byok_ollama_client called without a base_url")

    # Fail fast on visitor input that could never work (no scheme, a
    # non-HTTP scheme, no host) instead of erroring on the first request.
    # The URL is not echoed in the message because it may embed credentials.
    parts = urlsplit(cleaned)
    if parts.scheme not in ("http", "https") or not parts.netloc:
        raise ValueError("make_byok_ollama_client requires an absolute http(s) base_url")

    return OllamaClient(base_url=cleaned, model=model, timeout=timeout)
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
class OllamaClient:
    """Async client for the Ollama local LLM inference server.

    Supports generate (completion), chat, streaming, health checks,
    and model listing via the Ollama HTTP API.

    Args:
        base_url: Ollama server base URL. Defaults to settings.ollama_url.
        model: Default model name. Defaults to settings.llm_model.
        timeout: Request timeout in seconds.
    """

    def __init__(
        self,
        base_url: str | None = None,
        model: str | None = None,
        timeout: float = 120.0,
    ) -> None:
        # ``is not None`` rather than truthiness: only an omitted argument
        # falls back to the configured default.
        self.base_url = (base_url if base_url is not None else settings.ollama_url).rstrip("/")
        self.model = model if model is not None else settings.llm_model
        self.timeout = timeout
        self._client = httpx.AsyncClient(
            base_url=self.base_url,
            timeout=httpx.Timeout(timeout),
        )

    def _to_llm_response(
        self,
        data: dict[str, Any],
        text: str,
        elapsed_ms: float,
    ) -> LLMResponse:
        """Build an ``LLMResponse`` from a decoded non-streaming reply.

        Args:
            data: Decoded JSON body of the reply.
            text: Generated text already extracted from ``data``.
            elapsed_ms: Wall-clock duration of the HTTP call in milliseconds.

        Returns:
            The response object; token counts missing from ``data`` count as 0.
        """
        prompt_tokens = data.get("prompt_eval_count", 0)
        completion_tokens = data.get("eval_count", 0)
        return LLMResponse(
            text=text,
            model=data.get("model", self.model),
            provider="ollama",
            usage={
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
            latency_ms=elapsed_ms,
            metadata={
                "total_duration": data.get("total_duration"),
                "load_duration": data.get("load_duration"),
            },
        )

    @_retry_on_connection
    async def generate(
        self,
        prompt: str,
        system_prompt: str = "",
        temperature: float = 0.7,
        max_tokens: int = 2048,
        json_mode: bool = False,
    ) -> LLMResponse:
        """Generate a completion from the Ollama API.

        Args:
            prompt: The user prompt text.
            system_prompt: Optional system context; omitted from the request when empty.
            temperature: Sampling temperature (0.0-1.0).
            max_tokens: Maximum tokens to generate (sent as ``num_predict``).
            json_mode: When True, request JSON-formatted output.

        Returns:
            LLMResponse with generated text and metadata.

        Raises:
            httpx.HTTPStatusError: The server answered with a 4xx/5xx status.
        """
        payload: dict[str, Any] = {
            "model": self.model,
            "prompt": prompt,
            "stream": False,
            "options": {
                "temperature": temperature,
                "num_predict": max_tokens,
            },
            "keep_alive": settings.ollama_keep_alive,
        }
        if system_prompt:
            payload["system"] = system_prompt
        if json_mode:
            payload["format"] = "json"

        start = time.perf_counter()
        response = await self._client.post("/api/generate", json=payload)
        # Latency is taken before the status check so it reflects the HTTP call only.
        elapsed_ms = (time.perf_counter() - start) * 1000
        response.raise_for_status()

        data = response.json()
        return self._to_llm_response(data, data.get("response", ""), elapsed_ms)

    @_retry_on_connection
    async def chat(
        self,
        messages: list[dict],
        temperature: float = 0.7,
        max_tokens: int = 2048,
    ) -> LLMResponse:
        """Send a chat conversation to the Ollama API.

        Args:
            messages: List of message dicts with 'role' and 'content' keys.
                Roles: "system", "user", "assistant".
            temperature: Sampling temperature (0.0-1.0).
            max_tokens: Maximum tokens to generate (sent as ``num_predict``).

        Returns:
            LLMResponse with generated text and metadata.

        Raises:
            httpx.HTTPStatusError: The server answered with a 4xx/5xx status.
        """
        payload: dict[str, Any] = {
            "model": self.model,
            "messages": messages,
            "stream": False,
            "options": {
                "temperature": temperature,
                "num_predict": max_tokens,
            },
            "keep_alive": settings.ollama_keep_alive,
        }

        start = time.perf_counter()
        response = await self._client.post("/api/chat", json=payload)
        elapsed_ms = (time.perf_counter() - start) * 1000
        response.raise_for_status()

        data = response.json()
        # ``or {}`` also covers an explicit ``"message": null`` in the reply.
        message = data.get("message") or {}
        return self._to_llm_response(data, message.get("content", ""), elapsed_ms)

    async def generate_stream(
        self,
        prompt: str,
        system_prompt: str = "",
        temperature: float = 0.7,
    ) -> AsyncGenerator[str, None]:
        """Stream a completion from the Ollama API, yielding tokens as they arrive.

        Each non-empty response line is decoded as one JSON object; streaming
        stops at the first object whose ``done`` field is true.

        Args:
            prompt: The user prompt text.
            system_prompt: Optional system context; omitted from the request when empty.
            temperature: Sampling temperature (0.0-1.0).

        Yields:
            Token strings as they are generated.

        Raises:
            httpx.HTTPStatusError: The server answered with a 4xx/5xx status.
        """
        # Imported once per call instead of once per streamed line; kept at
        # function scope because the module has no top-level json import.
        import json

        payload: dict[str, Any] = {
            "model": self.model,
            "prompt": prompt,
            "stream": True,
            "options": {
                "temperature": temperature,
            },
            "keep_alive": settings.ollama_keep_alive,
        }
        if system_prompt:
            payload["system"] = system_prompt

        async with self._client.stream("POST", "/api/generate", json=payload) as resp:
            resp.raise_for_status()
            async for line in resp.aiter_lines():
                if not line:
                    continue
                data = json.loads(line)
                token = data.get("response", "")
                if token:
                    yield token
                if data.get("done", False):
                    break

    async def chat_stream(
        self,
        messages: list[dict],
        temperature: float = 0.7,
    ) -> AsyncGenerator[str, None]:
        """Stream a chat completion from the Ollama API, yielding tokens as they arrive.

        Each non-empty response line is decoded as one JSON object; streaming
        stops at the first object whose ``done`` field is true.

        Args:
            messages: List of message dicts with 'role' and 'content' keys.
            temperature: Sampling temperature (0.0-1.0).

        Yields:
            Token strings as they are generated.

        Raises:
            httpx.HTTPStatusError: The server answered with a 4xx/5xx status.
        """
        # Imported once per call instead of once per streamed line; kept at
        # function scope because the module has no top-level json import.
        import json

        payload: dict[str, Any] = {
            "model": self.model,
            "messages": messages,
            "stream": True,
            "options": {
                "temperature": temperature,
            },
            "keep_alive": settings.ollama_keep_alive,
        }

        async with self._client.stream("POST", "/api/chat", json=payload) as resp:
            resp.raise_for_status()
            async for line in resp.aiter_lines():
                if not line:
                    continue
                data = json.loads(line)
                # ``or {}`` also covers an explicit ``"message": null`` chunk.
                message = data.get("message") or {}
                token = message.get("content", "")
                if token:
                    yield token
                if data.get("done", False):
                    break

    async def health_check(self) -> bool:
        """Check if the Ollama server is reachable and responding.

        Not wrapped in the retry decorator: connection errors and timeouts are
        caught here and reported as ``False``, so a retry on those exceptions
        could never fire.

        Returns:
            True if the server responds with HTTP 200, False otherwise.
        """
        try:
            response = await self._client.get("/api/tags")
        except (httpx.ConnectError, httpx.TimeoutException):
            return False
        return response.status_code == 200

    @_retry_on_connection
    async def list_models(self) -> list[str]:
        """List all models available on the Ollama server.

        Returns:
            List of model name strings; an entry without a name yields "".

        Raises:
            httpx.HTTPStatusError: The server answered with a 4xx/5xx status.
        """
        response = await self._client.get("/api/tags")
        response.raise_for_status()
        models = response.json().get("models", [])
        return [entry.get("name", "") for entry in models]

    @_retry_on_connection
    async def get_model_info(self, model: str | None = None) -> dict | None:
        """Get detailed information about a specific model.

        Args:
            model: Model name to query. Defaults to the client's configured model.

        Returns:
            Dict with model info, or None when the server does not answer 200.
        """
        target_model = model or self.model
        response = await self._client.post("/api/show", json={"name": target_model})
        # No ``raise_for_status`` call is made, so there is no HTTPStatusError
        # to catch: the status code alone decides the outcome.
        if response.status_code != 200:
            return None
        return response.json()

    async def close(self) -> None:
        """Close the underlying HTTP client."""
        await self._client.aclose()

    async def __aenter__(self) -> OllamaClient:
        """Enter async context manager."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Exit async context manager, closing the client."""
        await self.close()
|
|
|
|
| 1 |
+
"""Async Ollama client wrapper with streaming support and health checks."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import time
|
| 6 |
+
from typing import TYPE_CHECKING, Any
|
| 7 |
+
|
| 8 |
+
import httpx
|
| 9 |
+
|
| 10 |
+
if TYPE_CHECKING:
|
| 11 |
+
from collections.abc import AsyncGenerator
|
| 12 |
+
from tenacity import (
|
| 13 |
+
retry,
|
| 14 |
+
retry_if_exception_type,
|
| 15 |
+
stop_after_attempt,
|
| 16 |
+
wait_exponential,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
from config.settings import settings
|
| 20 |
+
from inference.llm_factory import LLMResponse
|
| 21 |
+
from utils.logging import get_logger
|
| 22 |
+
|
| 23 |
+
logger = get_logger(__name__)
|
| 24 |
+
|
| 25 |
+
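# Retry only transient connection failures (httpx connect errors and timeouts): at most 3 attempts, exponential backoff bounded to 1-10 s, re-raising the final error.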
|
| 26 |
+
_retry_on_connection = retry(
|
| 27 |
+
retry=retry_if_exception_type((httpx.ConnectError, httpx.TimeoutException)),
|
| 28 |
+
stop=stop_after_attempt(3),
|
| 29 |
+
wait=wait_exponential(multiplier=1, min=1, max=10),
|
| 30 |
+
reraise=True,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def make_byok_ollama_client(
    *,
    base_url: str,
    model: str | None = None,
    timeout: float = 60.0,
) -> OllamaClient:
    """Build a per-request Ollama client bound to the visitor's instance URL.

    Visitors running their own Ollama can paste the public URL of that
    instance into the frontend. Each call returns a **fresh client**, so the
    visitor's URL never replaces the owner default at module scope.

    Args:
        base_url: Absolute ``http``/``https`` URL of the visitor's Ollama
            server (HTTPS preferred). Surrounding whitespace is ignored.
        model: Model to use. When ``None``, ``OllamaClient`` falls back to
            its configured default.
        timeout: Per-request HTTP timeout in seconds.

    Returns:
        A new ``OllamaClient`` bound to ``base_url``.

    Raises:
        ValueError: ``base_url`` is empty or whitespace, or is not an
            absolute ``http``/``https`` URL.
    """
    from urllib.parse import urlsplit

    cleaned = base_url.strip() if base_url else ""
    if not cleaned:
        raise ValueError("make_byok_ollama_client called without a base_url")

    # Fail fast on visitor input that could never work (no scheme, a
    # non-HTTP scheme, no host) instead of erroring on the first request.
    # The URL is not echoed in the message because it may embed credentials.
    parts = urlsplit(cleaned)
    if parts.scheme not in ("http", "https") or not parts.netloc:
        raise ValueError("make_byok_ollama_client requires an absolute http(s) base_url")

    return OllamaClient(base_url=cleaned, model=model, timeout=timeout)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class OllamaClient:
    """Async client for the Ollama local LLM inference server.

    Supports generate (completion), chat, streaming, health checks,
    and model listing via the Ollama HTTP API.

    Args:
        base_url: Ollama server base URL. Defaults to settings.ollama_url.
        model: Default model name. Defaults to settings.llm_model.
        timeout: Request timeout in seconds.
    """

    def __init__(
        self,
        base_url: str | None = None,
        model: str | None = None,
        timeout: float = 120.0,
    ) -> None:
        # ``is not None`` rather than truthiness: only an omitted argument
        # falls back to the configured default.
        self.base_url = (base_url if base_url is not None else settings.ollama_url).rstrip("/")
        self.model = model if model is not None else settings.llm_model
        self.timeout = timeout
        self._client = httpx.AsyncClient(
            base_url=self.base_url,
            timeout=httpx.Timeout(timeout),
        )

    def _to_llm_response(
        self,
        data: dict[str, Any],
        text: str,
        elapsed_ms: float,
    ) -> LLMResponse:
        """Build an ``LLMResponse`` from a decoded non-streaming reply.

        Args:
            data: Decoded JSON body of the reply.
            text: Generated text already extracted from ``data``.
            elapsed_ms: Wall-clock duration of the HTTP call in milliseconds.

        Returns:
            The response object; token counts missing from ``data`` count as 0.
        """
        prompt_tokens = data.get("prompt_eval_count", 0)
        completion_tokens = data.get("eval_count", 0)
        return LLMResponse(
            text=text,
            model=data.get("model", self.model),
            provider="ollama",
            usage={
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
            latency_ms=elapsed_ms,
            metadata={
                "total_duration": data.get("total_duration"),
                "load_duration": data.get("load_duration"),
            },
        )

    @_retry_on_connection
    async def generate(
        self,
        prompt: str,
        system_prompt: str = "",
        temperature: float = 0.7,
        max_tokens: int = 2048,
        json_mode: bool = False,
    ) -> LLMResponse:
        """Generate a completion from the Ollama API.

        Args:
            prompt: The user prompt text.
            system_prompt: Optional system context; omitted from the request when empty.
            temperature: Sampling temperature (0.0-1.0).
            max_tokens: Maximum tokens to generate (sent as ``num_predict``).
            json_mode: When True, request JSON-formatted output.

        Returns:
            LLMResponse with generated text and metadata.

        Raises:
            httpx.HTTPStatusError: The server answered with a 4xx/5xx status.
        """
        payload: dict[str, Any] = {
            "model": self.model,
            "prompt": prompt,
            "stream": False,
            "options": {
                "temperature": temperature,
                "num_predict": max_tokens,
            },
            "keep_alive": settings.ollama_keep_alive,
        }
        if system_prompt:
            payload["system"] = system_prompt
        if json_mode:
            payload["format"] = "json"

        start = time.perf_counter()
        response = await self._client.post("/api/generate", json=payload)
        # Latency is taken before the status check so it reflects the HTTP call only.
        elapsed_ms = (time.perf_counter() - start) * 1000
        response.raise_for_status()

        data = response.json()
        return self._to_llm_response(data, data.get("response", ""), elapsed_ms)

    @_retry_on_connection
    async def chat(
        self,
        messages: list[dict],
        temperature: float = 0.7,
        max_tokens: int = 2048,
    ) -> LLMResponse:
        """Send a chat conversation to the Ollama API.

        Args:
            messages: List of message dicts with 'role' and 'content' keys.
                Roles: "system", "user", "assistant".
            temperature: Sampling temperature (0.0-1.0).
            max_tokens: Maximum tokens to generate (sent as ``num_predict``).

        Returns:
            LLMResponse with generated text and metadata.

        Raises:
            httpx.HTTPStatusError: The server answered with a 4xx/5xx status.
        """
        payload: dict[str, Any] = {
            "model": self.model,
            "messages": messages,
            "stream": False,
            "options": {
                "temperature": temperature,
                "num_predict": max_tokens,
            },
            "keep_alive": settings.ollama_keep_alive,
        }

        start = time.perf_counter()
        response = await self._client.post("/api/chat", json=payload)
        elapsed_ms = (time.perf_counter() - start) * 1000
        response.raise_for_status()

        data = response.json()
        # ``or {}`` also covers an explicit ``"message": null`` in the reply.
        message = data.get("message") or {}
        return self._to_llm_response(data, message.get("content", ""), elapsed_ms)

    async def generate_stream(
        self,
        prompt: str,
        system_prompt: str = "",
        temperature: float = 0.7,
    ) -> AsyncGenerator[str, None]:
        """Stream a completion from the Ollama API, yielding tokens as they arrive.

        Each non-empty response line is decoded as one JSON object; streaming
        stops at the first object whose ``done`` field is true.

        Args:
            prompt: The user prompt text.
            system_prompt: Optional system context; omitted from the request when empty.
            temperature: Sampling temperature (0.0-1.0).

        Yields:
            Token strings as they are generated.

        Raises:
            httpx.HTTPStatusError: The server answered with a 4xx/5xx status.
        """
        # Imported once per call instead of once per streamed line; kept at
        # function scope because the module has no top-level json import.
        import json

        payload: dict[str, Any] = {
            "model": self.model,
            "prompt": prompt,
            "stream": True,
            "options": {
                "temperature": temperature,
            },
            "keep_alive": settings.ollama_keep_alive,
        }
        if system_prompt:
            payload["system"] = system_prompt

        async with self._client.stream("POST", "/api/generate", json=payload) as resp:
            resp.raise_for_status()
            async for line in resp.aiter_lines():
                if not line:
                    continue
                data = json.loads(line)
                token = data.get("response", "")
                if token:
                    yield token
                if data.get("done", False):
                    break

    async def chat_stream(
        self,
        messages: list[dict],
        temperature: float = 0.7,
    ) -> AsyncGenerator[str, None]:
        """Stream a chat completion from the Ollama API, yielding tokens as they arrive.

        Each non-empty response line is decoded as one JSON object; streaming
        stops at the first object whose ``done`` field is true.

        Args:
            messages: List of message dicts with 'role' and 'content' keys.
            temperature: Sampling temperature (0.0-1.0).

        Yields:
            Token strings as they are generated.

        Raises:
            httpx.HTTPStatusError: The server answered with a 4xx/5xx status.
        """
        # Imported once per call instead of once per streamed line; kept at
        # function scope because the module has no top-level json import.
        import json

        payload: dict[str, Any] = {
            "model": self.model,
            "messages": messages,
            "stream": True,
            "options": {
                "temperature": temperature,
            },
            "keep_alive": settings.ollama_keep_alive,
        }

        async with self._client.stream("POST", "/api/chat", json=payload) as resp:
            resp.raise_for_status()
            async for line in resp.aiter_lines():
                if not line:
                    continue
                data = json.loads(line)
                # ``or {}`` also covers an explicit ``"message": null`` chunk.
                message = data.get("message") or {}
                token = message.get("content", "")
                if token:
                    yield token
                if data.get("done", False):
                    break

    async def health_check(self) -> bool:
        """Check if the Ollama server is reachable and responding.

        Not wrapped in the retry decorator: connection errors and timeouts are
        caught here and reported as ``False``, so a retry on those exceptions
        could never fire.

        Returns:
            True if the server responds with HTTP 200, False otherwise.
        """
        try:
            response = await self._client.get("/api/tags")
        except (httpx.ConnectError, httpx.TimeoutException):
            return False
        return response.status_code == 200

    @_retry_on_connection
    async def list_models(self) -> list[str]:
        """List all models available on the Ollama server.

        Returns:
            List of model name strings; an entry without a name yields "".

        Raises:
            httpx.HTTPStatusError: The server answered with a 4xx/5xx status.
        """
        response = await self._client.get("/api/tags")
        response.raise_for_status()
        models = response.json().get("models", [])
        return [entry.get("name", "") for entry in models]

    @_retry_on_connection
    async def get_model_info(self, model: str | None = None) -> dict | None:
        """Get detailed information about a specific model.

        Args:
            model: Model name to query. Defaults to the client's configured model.

        Returns:
            Dict with model info, or None when the server does not answer 200.
        """
        target_model = model or self.model
        response = await self._client.post("/api/show", json={"name": target_model})
        # No ``raise_for_status`` call is made, so there is no HTTPStatusError
        # to catch: the status code alone decides the outcome.
        if response.status_code != 200:
            return None
        return response.json()

    async def close(self) -> None:
        """Close the underlying HTTP client."""
        await self._client.aclose()

    async def __aenter__(self) -> OllamaClient:
        """Enter async context manager."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Exit async context manager, closing the client."""
        await self.close()
|
interfaces/api.py
CHANGED
|
@@ -1,425 +1,432 @@
|
|
| 1 |
-
"""FastAPI surface for SecureAgentRAG.
|
| 2 |
-
|
| 3 |
-
Run with::
|
| 4 |
-
|
| 5 |
-
uv run uvicorn interfaces.api:app --host 0.0.0.0 --port 8080
|
| 6 |
-
|
| 7 |
-
Endpoints
|
| 8 |
-
---------
|
| 9 |
-
- ``GET /healthz`` — liveness probe (no auth).
|
| 10 |
-
- ``GET /readyz`` — readiness — pings Qdrant + Ollama.
|
| 11 |
-
- ``POST /query`` — run the RAG pipeline; returns ``QueryResponse``.
|
| 12 |
-
- ``POST /ingest`` — ingest a local file; requires ``user`` role.
|
| 13 |
-
- ``GET /audit`` — read paginated audit entries; requires ``admin``.
|
| 14 |
-
- ``POST /audit/verify``— verify the hash-chain; requires ``admin``.
|
| 15 |
-
|
| 16 |
-
Auth uses a stateless bearer token. The token payload is a base64-encoded JSON
|
| 17 |
-
``UserContext`` so the API has no session store — caller provides identity on
|
| 18 |
-
every request. Production deployments should swap this for Keycloak/Auth0 JWT
|
| 19 |
-
verification (left as a hook in ``_resolve_user``).
|
| 20 |
-
"""
|
| 21 |
-
|
| 22 |
-
from __future__ import annotations
|
| 23 |
-
|
| 24 |
-
import base64
|
| 25 |
-
import json
|
| 26 |
-
from datetime import date
|
| 27 |
-
from typing import Annotated
|
| 28 |
-
|
| 29 |
-
from config.settings import settings
|
| 30 |
-
from utils.auth import AuthError, issue_token, verify_token
|
| 31 |
-
from utils.logging import get_logger
|
| 32 |
-
|
| 33 |
-
logger = get_logger(__name__)
|
| 34 |
-
|
| 35 |
-
try:
|
| 36 |
-
from fastapi import Depends, FastAPI, Header, HTTPException, status
|
| 37 |
-
from fastapi.responses import JSONResponse
|
| 38 |
-
|
| 39 |
-
_FASTAPI_AVAILABLE = True
|
| 40 |
-
except ImportError: # pragma: no cover
|
| 41 |
-
_FASTAPI_AVAILABLE = False
|
| 42 |
-
Depends = Header = FastAPI = HTTPException = JSONResponse = status = None # type: ignore[assignment]
|
| 43 |
-
|
| 44 |
-
if _FASTAPI_AVAILABLE:
|
| 45 |
-
from core.graph import run_rag_pipeline
|
| 46 |
-
from core.schemas import (
|
| 47 |
-
IngestRequestModel,
|
| 48 |
-
IngestResponseModel,
|
| 49 |
-
QueryRequest,
|
| 50 |
-
QueryResponse,
|
| 51 |
-
)
|
| 52 |
-
from ingestion.metadata import IngestRequest, SensitivityLevel, UserContext
|
| 53 |
-
from utils.audit import audit_logger
|
| 54 |
-
from utils.health import run_health_checks
|
| 55 |
-
from utils.rate_limiter import RateLimiter
|
| 56 |
-
|
| 57 |
-
rate_limiter = RateLimiter() # uses default token-bucket config
|
| 58 |
-
|
| 59 |
-
_AUTH_ERROR_STATUS: dict[str, int] = {
|
| 60 |
-
"missing": status.HTTP_401_UNAUTHORIZED,
|
| 61 |
-
"malformed": status.HTTP_401_UNAUTHORIZED,
|
| 62 |
-
"expired": status.HTTP_401_UNAUTHORIZED,
|
| 63 |
-
"bad_signature": status.HTTP_401_UNAUTHORIZED,
|
| 64 |
-
"bad_claims": status.HTTP_403_FORBIDDEN,
|
| 65 |
-
}
|
| 66 |
-
|
| 67 |
-
def _resolve_user_full(
    authorization: Annotated[str | None, Header()] = None,
) -> tuple[UserContext, dict]:
    """Verify the bearer token and return (UserContext, claims).

    Delegates to :func:`utils.auth.verify_token`, which uses HS256 JWT
    when ``SAR_JWT_SECRET`` is set and falls back to the legacy unsigned
    base64 token otherwise (with a runtime warning).

    Args:
        authorization: Raw ``Authorization`` header value, if any.

    Returns:
        The ``(UserContext, claims)`` pair produced by ``verify_token``.

    Raises:
        HTTPException: 401 when the header is absent, is not a ``Bearer``
            credential or carries no token. When verification fails, the
            status mapped from the ``AuthError`` reason in
            ``_AUTH_ERROR_STATUS`` (401 for unmapped reasons).
    """
    scheme, _, credentials = (authorization or "").partition(" ")
    # RFC 6750 permits one or more spaces between the scheme and the token,
    # so padding is stripped rather than handed to the verifier as part of
    # the token.
    token = credentials.strip()
    if scheme.lower() != "bearer" or not token:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "missing bearer token")
    try:
        return verify_token(token)
    except AuthError as exc:
        code = _AUTH_ERROR_STATUS.get(exc.reason, status.HTTP_401_UNAUTHORIZED)
        raise HTTPException(code, f"auth_{exc.reason}: {exc}") from exc
|
| 84 |
-
|
| 85 |
-
def _resolve_user(authorization: Annotated[str | None, Header()] = None) -> UserContext:
    """Resolve the caller and return only the ``UserContext``, dropping the claims."""
    user_ctx, _ = _resolve_user_full(authorization=authorization)
    return user_ctx
|
| 89 |
-
|
| 90 |
-
def _require_role(required: str):
|
| 91 |
-
def _dep(user: Annotated[UserContext, Depends(_resolve_user)]) -> UserContext:
|
| 92 |
-
if required not in user.roles and "admin" not in user.roles:
|
| 93 |
-
raise HTTPException(status.HTTP_403_FORBIDDEN, f"role '{required}' required")
|
| 94 |
-
return user
|
| 95 |
-
|
| 96 |
-
return _dep
|
| 97 |
-
|
| 98 |
-
app = FastAPI(
|
| 99 |
-
title="SecureAgentRAG API",
|
| 100 |
-
version="0.1.0",
|
| 101 |
-
description="Privacy-first multi-agent RAG with RBAC, guardrails, and audit chain.",
|
| 102 |
-
)
|
| 103 |
-
|
| 104 |
-
# Initialize Phoenix tracing if configured.
|
| 105 |
-
# When ``settings.byok_mode`` is on, ``setup_tracing`` short-circuits to
|
| 106 |
-
# False regardless of phoenix_endpoint (see utils/observability.py).
|
| 107 |
-
from utils.observability import setup_tracing
|
| 108 |
-
|
| 109 |
-
_tracing_enabled = setup_tracing()
|
| 110 |
-
if _tracing_enabled:
|
| 111 |
-
logger.info("phoenix_tracing_active_in_api")
|
| 112 |
-
|
| 113 |
-
# ── BYOK CORS middleware ─────────────────────────────────────────────
|
| 114 |
-
# Only mount CORS when:
|
| 115 |
-
# 1) BYOK mode is on (public demo path), AND
|
| 116 |
-
# 2) an explicit allowlist is configured via SAR_CORS_ALLOW_ORIGINS.
|
| 117 |
-
# Empty allowlist + BYOK = wildcard would be a footgun (CSRF surface).
|
| 118 |
-
# Empty allowlist + dev = no CORS needed (local same-origin).
|
| 119 |
-
if settings.byok_mode and settings.cors_allow_origins:
|
| 120 |
-
from fastapi.middleware.cors import CORSMiddleware
|
| 121 |
-
|
| 122 |
-
app.add_middleware(
|
| 123 |
-
CORSMiddleware,
|
| 124 |
-
allow_origins=list(settings.cors_allow_origins),
|
| 125 |
-
allow_credentials=False, # BYOK never uses cookies
|
| 126 |
-
allow_methods=["GET", "POST", "OPTIONS"],
|
| 127 |
-
allow_headers=["*"],
|
| 128 |
-
)
|
| 129 |
-
logger.info("byok_cors_enabled", origins=list(settings.cors_allow_origins))
|
| 130 |
-
|
| 131 |
-
@app.get("/healthz", tags=["ops"])
async def healthz() -> dict[str, str]:
    """Liveness probe: answer with a static ``ok`` payload."""
    payload = {"status": "ok"}
    return payload
|
| 134 |
-
|
| 135 |
-
@app.get("/readyz", tags=["ops"])
async def readyz() -> JSONResponse:
    """Readiness probe: serialise the health report, 200 when healthy overall, else 503."""
    health_report = await run_health_checks()
    if health_report.overall_healthy:
        http_status = 200
    else:
        http_status = 503
    return JSONResponse(health_report.to_dict(), status_code=http_status)
|
| 140 |
-
|
| 141 |
-
# ── BYOK demo endpoint ───────────────────────────────────────────────
|
| 142 |
-
# Mounted only when ``settings.byok_mode`` is on. Bypasses JWT auth and
|
| 143 |
-
# uses per-request BYOK credentials instead. Isolation is enforced via
|
| 144 |
-
# session-scoped Qdrant collections, not JWT identity.
|
| 145 |
-
if settings.byok_mode:
|
| 146 |
-
from interfaces.byok import ByokCreds, extract_byok
|
| 147 |
-
from utils.rate_limiter import get_owner_key_throttle
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
"
|
| 161 |
-
"
|
| 162 |
-
"
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
""
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
"
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
)
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
)
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
"""
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI surface for SecureAgentRAG.
|
| 2 |
+
|
| 3 |
+
Run with::
|
| 4 |
+
|
| 5 |
+
uv run uvicorn interfaces.api:app --host 0.0.0.0 --port 8080
|
| 6 |
+
|
| 7 |
+
Endpoints
|
| 8 |
+
---------
|
| 9 |
+
- ``GET /healthz`` — liveness probe (no auth).
|
| 10 |
+
- ``GET /readyz`` — readiness — pings Qdrant + Ollama.
|
| 11 |
+
- ``POST /query`` — run the RAG pipeline; returns ``QueryResponse``.
|
| 12 |
+
- ``POST /ingest`` — ingest a local file; requires ``user`` role.
|
| 13 |
+
- ``GET /audit`` — read paginated audit entries; requires ``admin``.
|
| 14 |
+
- ``POST /audit/verify``— verify the hash-chain; requires ``admin``.
|
| 15 |
+
|
| 16 |
+
Auth uses a stateless bearer token. The token payload is a base64-encoded JSON
|
| 17 |
+
``UserContext`` so the API has no session store — caller provides identity on
|
| 18 |
+
every request. Production deployments should swap this for Keycloak/Auth0 JWT
|
| 19 |
+
verification (left as a hook in ``_resolve_user``).
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
from __future__ import annotations
|
| 23 |
+
|
| 24 |
+
import base64
|
| 25 |
+
import json
|
| 26 |
+
from datetime import date
|
| 27 |
+
from typing import Annotated
|
| 28 |
+
|
| 29 |
+
from config.settings import settings
|
| 30 |
+
from utils.auth import AuthError, issue_token, verify_token
|
| 31 |
+
from utils.logging import get_logger
|
| 32 |
+
|
| 33 |
+
logger = get_logger(__name__)
|
| 34 |
+
|
| 35 |
+
try:
|
| 36 |
+
from fastapi import Depends, FastAPI, Header, HTTPException, status
|
| 37 |
+
from fastapi.responses import JSONResponse
|
| 38 |
+
|
| 39 |
+
_FASTAPI_AVAILABLE = True
|
| 40 |
+
except ImportError: # pragma: no cover
|
| 41 |
+
_FASTAPI_AVAILABLE = False
|
| 42 |
+
Depends = Header = FastAPI = HTTPException = JSONResponse = status = None # type: ignore[assignment]
|
| 43 |
+
|
| 44 |
+
if _FASTAPI_AVAILABLE:
|
| 45 |
+
from core.graph import run_rag_pipeline
|
| 46 |
+
from core.schemas import (
|
| 47 |
+
IngestRequestModel,
|
| 48 |
+
IngestResponseModel,
|
| 49 |
+
QueryRequest,
|
| 50 |
+
QueryResponse,
|
| 51 |
+
)
|
| 52 |
+
from ingestion.metadata import IngestRequest, SensitivityLevel, UserContext
|
| 53 |
+
from utils.audit import audit_logger
|
| 54 |
+
from utils.health import run_health_checks
|
| 55 |
+
from utils.rate_limiter import RateLimiter
|
| 56 |
+
|
| 57 |
+
rate_limiter = RateLimiter() # uses default token-bucket config
|
| 58 |
+
|
| 59 |
+
_AUTH_ERROR_STATUS: dict[str, int] = {
|
| 60 |
+
"missing": status.HTTP_401_UNAUTHORIZED,
|
| 61 |
+
"malformed": status.HTTP_401_UNAUTHORIZED,
|
| 62 |
+
"expired": status.HTTP_401_UNAUTHORIZED,
|
| 63 |
+
"bad_signature": status.HTTP_401_UNAUTHORIZED,
|
| 64 |
+
"bad_claims": status.HTTP_403_FORBIDDEN,
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
def _resolve_user_full(
    authorization: Annotated[str | None, Header()] = None,
) -> tuple[UserContext, dict]:
    """Authenticate the ``Authorization`` header and return ``(UserContext, claims)``.

    The header must use the ``Bearer`` scheme (matched case-insensitively).
    Everything after the first space is passed to ``verify_token``.

    Raises:
        HTTPException: 401 when the header is absent or not a bearer token;
            otherwise the status mapped from the ``AuthError`` reason via
            ``_AUTH_ERROR_STATUS`` (401 for reasons missing from the map).
    """
    has_bearer = bool(authorization) and authorization.lower().startswith("bearer ")
    if not has_bearer:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "missing bearer token")

    # The prefix check guarantees at least one space, so this yields two parts.
    _scheme, bearer_value = authorization.split(" ", 1)
    try:
        return verify_token(bearer_value)
    except AuthError as exc:
        http_status = _AUTH_ERROR_STATUS.get(exc.reason, status.HTTP_401_UNAUTHORIZED)
        raise HTTPException(http_status, f"auth_{exc.reason}: {exc}") from exc
|
| 84 |
+
|
| 85 |
+
def _resolve_user(authorization: Annotated[str | None, Header()] = None) -> UserContext:
    """Dependency variant of ``_resolve_user_full`` that drops the claims."""
    resolved = _resolve_user_full(authorization=authorization)
    user_ctx, _ignored_claims = resolved
    return user_ctx
|
| 89 |
+
|
| 90 |
+
def _require_role(required: str):
    """Build a dependency that admits only callers holding ``required`` or ``admin``.

    The returned dependency resolves the caller through ``_resolve_user`` and
    raises a 403 ``HTTPException`` when neither role is present.
    """

    def _dep(user: Annotated[UserContext, Depends(_resolve_user)]) -> UserContext:
        # ``admin`` is accepted in place of any specific role.
        permitted = required in user.roles or "admin" in user.roles
        if not permitted:
            raise HTTPException(status.HTTP_403_FORBIDDEN, f"role '{required}' required")
        return user

    return _dep
|
| 97 |
+
|
| 98 |
+
app = FastAPI(
|
| 99 |
+
title="SecureAgentRAG API",
|
| 100 |
+
version="0.1.0",
|
| 101 |
+
description="Privacy-first multi-agent RAG with RBAC, guardrails, and audit chain.",
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
# Initialize Phoenix tracing if configured.
|
| 105 |
+
# When ``settings.byok_mode`` is on, ``setup_tracing`` short-circuits to
|
| 106 |
+
# False regardless of phoenix_endpoint (see utils/observability.py).
|
| 107 |
+
from utils.observability import setup_tracing
|
| 108 |
+
|
| 109 |
+
_tracing_enabled = setup_tracing()
|
| 110 |
+
if _tracing_enabled:
|
| 111 |
+
logger.info("phoenix_tracing_active_in_api")
|
| 112 |
+
|
| 113 |
+
# ── BYOK CORS middleware ─────────────────────────────────────────────
|
| 114 |
+
# Only mount CORS when:
|
| 115 |
+
# 1) BYOK mode is on (public demo path), AND
|
| 116 |
+
# 2) an explicit allowlist is configured via SAR_CORS_ALLOW_ORIGINS.
|
| 117 |
+
# Empty allowlist + BYOK = wildcard would be a footgun (CSRF surface).
|
| 118 |
+
# Empty allowlist + dev = no CORS needed (local same-origin).
|
| 119 |
+
if settings.byok_mode and settings.cors_allow_origins:
|
| 120 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 121 |
+
|
| 122 |
+
app.add_middleware(
|
| 123 |
+
CORSMiddleware,
|
| 124 |
+
allow_origins=list(settings.cors_allow_origins),
|
| 125 |
+
allow_credentials=False, # BYOK never uses cookies
|
| 126 |
+
allow_methods=["GET", "POST", "OPTIONS"],
|
| 127 |
+
allow_headers=["*"],
|
| 128 |
+
)
|
| 129 |
+
logger.info("byok_cors_enabled", origins=list(settings.cors_allow_origins))
|
| 130 |
+
|
| 131 |
+
@app.get("/healthz", tags=["ops"])
async def healthz() -> dict[str, str]:
    """Liveness probe: returns a constant ``ok`` payload."""
    payload = {"status": "ok"}
    return payload
|
| 134 |
+
|
| 135 |
+
@app.get("/readyz", tags=["ops"])
async def readyz() -> JSONResponse:
    """Readiness probe: 200 when the health report is healthy, otherwise 503."""
    report = await run_health_checks()
    if report.overall_healthy:
        http_status = 200
    else:
        http_status = 503
    return JSONResponse(report.to_dict(), status_code=http_status)
|
| 140 |
+
|
| 141 |
+
# ── BYOK demo endpoint ───────────────────────────────────────────────
|
| 142 |
+
# Mounted only when ``settings.byok_mode`` is on. Bypasses JWT auth and
|
| 143 |
+
# uses per-request BYOK credentials instead. Isolation is enforced via
|
| 144 |
+
# session-scoped Qdrant collections, not JWT identity.
|
| 145 |
+
if settings.byok_mode:
|
| 146 |
+
from interfaces.byok import ByokCreds, extract_byok
|
| 147 |
+
from utils.rate_limiter import get_owner_key_throttle
|
| 148 |
+
|
| 149 |
+
# All demo personas share ``org_id="demo"`` so they query the same
|
| 150 |
+
# ingested corpus. RBAC differentiation is enforced via clearance
|
| 151 |
+
# level + roles at the payload-filter layer -- exactly the production
|
| 152 |
+
# invariant we want to demonstrate.
|
| 153 |
+
_DEMO_ORG_ID = "demo"
|
| 154 |
+
# Sensitivity levels are LOW=1, MEDIUM=2, HIGH=3 (see
|
| 155 |
+
# ``ingestion/metadata.py::sensitivity_to_int``). Clearance levels must
|
| 156 |
+
# be in the same range so the Qdrant range filter passes the right
|
| 157 |
+
# chunks. Engineer < Compliance == Executive, but executive carries
|
| 158 |
+
# a wider role set (sees both engineering + compliance content).
|
| 159 |
+
_DEMO_PERSONAS: dict[str, dict] = {
|
| 160 |
+
"engineer": {
|
| 161 |
+
"clearance_level": 2,
|
| 162 |
+
"roles": ["engineering"],
|
| 163 |
+
},
|
| 164 |
+
"compliance": {
|
| 165 |
+
"clearance_level": 3,
|
| 166 |
+
"roles": ["compliance", "legal"],
|
| 167 |
+
},
|
| 168 |
+
"executive": {
|
| 169 |
+
"clearance_level": 3,
|
| 170 |
+
"roles": ["executive", "compliance", "engineering"],
|
| 171 |
+
},
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
def _persona_to_user_ctx(creds: ByokCreds) -> UserContext:
    """Build the synthetic ``UserContext`` for the persona named in ``creds``.

    The persona name is lower-cased and looked up in ``_DEMO_PERSONAS``.
    A missing or unrecognised persona gets clearance level 1 and the single
    ``viewer`` role. Every demo identity shares ``_DEMO_ORG_ID`` and is
    named ``demo-<session_id>``.
    """
    persona_key = (creds.demo_persona or "").lower()
    profile = _DEMO_PERSONAS.get(persona_key)
    if profile is None:
        # Lowest-privilege fallback profile.
        profile = {"clearance_level": 1, "roles": ["viewer"]}

    demo_user_id = f"demo-{creds.session_id}"
    return UserContext(
        user_id=demo_user_id,
        org_id=_DEMO_ORG_ID,
        clearance_level=profile["clearance_level"],
        roles=profile["roles"],
    )
|
| 189 |
+
|
| 190 |
+
from pydantic import BaseModel as _ByokBaseModel
|
| 191 |
+
|
| 192 |
+
class _ByokChatBody(_ByokBaseModel):
    """Request body for ``POST /byok/chat``: the question plus a routing flag."""

    # The visitor's question, handed to the RAG pipeline as ``query``.
    query: str
    # Handed to the RAG pipeline as ``prefer_cloud``; on unless the client opts out.
    prefer_cloud: bool = True
|
| 197 |
+
|
| 198 |
+
# Runtime import — FastAPI dependency injection reads the annotation
|
| 199 |
+
# at request time, so this must NOT be a TYPE_CHECKING-only import.
|
| 200 |
+
from fastapi import Request as _FastApiRequest # noqa: TC002
|
| 201 |
+
|
| 202 |
+
@app.post("/byok/chat", tags=["byok"])
async def byok_chat_endpoint(
    request: _FastApiRequest,
    body: _ByokChatBody,
    creds: Annotated[ByokCreds, Depends(extract_byok)],
) -> dict:
    """Answer a public-demo question using BYOK credentials.

    Flow:
      1. Requests without a visitor key are checked against the owner-key
         throttle, keyed by the peer address (``"anon"`` when unavailable).
         A refusal becomes HTTP 429 with a hint to bring a key.
      2. The persona in ``creds`` is mapped to a synthetic ``UserContext``.
      3. The RAG pipeline runs on thread ``byok-<session_id>`` with the
         visitor's allowlisted provider (``None`` when absent or unsupported).

    Returns:
        A dict with ``session_id``, ``persona``, ``byok_used`` and the
        JSON-dumped ``QueryResponse`` under ``response``.

    Raises:
        HTTPException: 429 when the owner-key throttle refuses the request.
    """
    visitor_has_key = creds.has_user_key()

    if not visitor_has_key:
        throttle = get_owner_key_throttle()
        # NOTE(review): behind a reverse proxy the peer address may be the
        # proxy's, which would put all visitors in one bucket — confirm.
        peer = request.client
        client_ip = (peer.host if peer else None) or "anon"
        allowed, meta = throttle.allow(client_ip)
        if not allowed:
            hint = (
                "Owner-key fallback exhausted for this IP. "
                "Paste your own LLM key to continue — your key "
                "is never stored server-side."
            )
            raise HTTPException(
                status.HTTP_429_TOO_MANY_REQUESTS,
                detail={
                    "reason": meta["reason"],
                    "retry_after_seconds": meta["retry_after"],
                    "hint": hint,
                },
            )

    demo_user = _persona_to_user_ctx(creds)
    # NOTE(review): ``creds.user_key`` and ``creds.ollama_url`` are not passed
    # to the pipeline here. Confirm the pipeline obtains the visitor key some
    # other way; otherwise any non-empty key header skips the throttle above
    # without the visitor's key being used.
    pipeline_state = await run_rag_pipeline(
        query=body.query,
        user_context=demo_user,
        thread_id=f"byok-{creds.session_id}",
        prefer_cloud=body.prefer_cloud,
        # ``None`` when the visitor sent no provider or an unsupported one.
        override_provider=creds.safe_provider(),
    )
    answer = QueryResponse.from_state(pipeline_state)
    return {
        "session_id": creds.session_id,
        "persona": creds.demo_persona or "anonymous",
        "byok_used": visitor_has_key,
        "response": answer.model_dump(mode="json"),
    }
|
| 254 |
+
|
| 255 |
+
@app.post("/query", response_model=QueryResponse, tags=["rag"])
async def query_endpoint(
    body: QueryRequest,
    auth: Annotated[tuple[UserContext, dict], Depends(_resolve_user_full)],
) -> QueryResponse:
    """Run the RAG pipeline for an authenticated caller.

    Checks, in order: the per-user ``query`` rate limit (429 on refusal) and
    that ``body.user_id`` equals the token identity (403 on mismatch).
    """
    caller, token_claims = auth

    rate_key = f"{caller.user_id}:query"
    if not rate_limiter.is_allowed(rate_key):
        raise HTTPException(status.HTTP_429_TOO_MANY_REQUESTS, "rate limit exceeded")

    # The body may not claim a different identity than the bearer token.
    if body.user_id != caller.user_id:
        raise HTTPException(status.HTTP_403_FORBIDDEN, "user_id mismatch")

    # The token id is embedded in the thread id so a query can be correlated
    # with the token that authorised it; "unsigned" when the claim is absent.
    token_id = token_claims.get("jti", "unsigned")
    pipeline_state = await run_rag_pipeline(
        query=body.query,
        user_context=caller,
        thread_id=f"api-{caller.user_id}-{token_id}",
        prefer_cloud=body.prefer_cloud,
        override_provider=body.override_provider,
    )
    return QueryResponse.from_state(pipeline_state)
|
| 277 |
+
|
| 278 |
+
@app.post("/ingest", response_model=IngestResponseModel, tags=["rag"])
async def ingest_endpoint(
    body: IngestRequestModel,
    user: Annotated[UserContext, Depends(_require_role("user"))],
) -> IngestResponseModel:
    """Ingest the file at ``body.file_path`` on behalf of the authenticated caller.

    Requires the ``user`` role (or ``admin``). Responds 403 when
    ``body.user_id`` differs from the token identity.
    """
    if body.user_id != user.user_id:
        raise HTTPException(status.HTTP_403_FORBIDDEN, "user_id mismatch")

    # Imported only after the identity check has passed.
    from core.agents.retriever import _get_hybrid_searcher
    from ingestion.pipeline import IngestionPipeline

    # The ingestion pipeline reuses the retriever's own service instances.
    hybrid = _get_hybrid_searcher()
    ingestion = IngestionPipeline(
        qdrant_manager=hybrid._qdrant,  # type: ignore[attr-defined]
        embedding_service=hybrid._embeddings,  # type: ignore[attr-defined]
        sparse_service=hybrid._sparse,  # type: ignore[attr-defined]
    )

    # NOTE(review): org_id, roles, sensitivity_level and file_path are taken
    # from the request body without being compared against ``user`` — confirm
    # this is intended.
    ingest_request = IngestRequest(
        file_path=body.file_path,
        user_id=body.user_id,
        org_id=body.org_id,
        sensitivity_level=SensitivityLevel(body.sensitivity_level),
        roles=body.roles,
    )
    outcome = await ingestion.ingest_document(ingest_request)

    return IngestResponseModel(
        file_path=outcome.file_path,
        status=outcome.status,
        num_chunks=outcome.num_chunks,
        point_ids=outcome.point_ids,
        errors=outcome.errors,
        processing_time_seconds=outcome.processing_time_seconds,
    )
|
| 310 |
+
|
| 311 |
+
@app.get("/audit", tags=["audit"])
async def audit_list(
    user: Annotated[UserContext, Depends(_require_role("admin"))],
    start: str | None = None,
    end: str | None = None,
    limit: int = 100,
) -> dict:
    """Return audit entries for a date range; requires the ``admin`` role.

    Args:
        user: Present only to enforce the ``admin`` role dependency.
        start: Start date passed to the audit logger; defaults to today's
            ISO date when omitted or empty.
        end: End date passed to the audit logger; defaults to today's ISO
            date when omitted or empty.
        limit: Maximum number of entries returned under ``items``. Negative
            values are treated as 0.

    Returns:
        ``{"total": <number of matching entries>, "items": [...]}`` where
        ``items`` holds at most ``limit`` JSON-dumped entries.
    """
    today = date.today().isoformat()
    entries = audit_logger.get_entries(
        start_date=start or today,
        end_date=end or today,
        user_id=None,
        action=None,
    )
    # A negative slice bound would drop entries from the END of the list
    # instead of capping the page, so clamp before slicing.
    page_size = max(limit, 0)
    return {
        "total": len(entries),
        "items": [e.model_dump(mode="json") for e in entries[:page_size]],
    }
|
| 329 |
+
|
| 330 |
+
@app.post("/audit/verify", tags=["audit"])
async def audit_verify(
    user: Annotated[UserContext, Depends(_require_role("admin"))],
    start: str | None = None,
    end: str | None = None,
) -> dict:
    """Verify the audit chain over an optional date range; requires ``admin``.

    ``start`` and ``end`` are passed to ``audit_logger.verify_chain``
    unchanged, and its result is returned as-is.
    """
    return audit_logger.verify_chain(start_date=start, end_date=end)
|
| 338 |
+
|
| 339 |
+
from pydantic import BaseModel as _PydBM
|
| 340 |
+
|
| 341 |
+
class _TokenRequest(_PydBM):
    """Request body for the dev ``/token`` endpoint: the identity to embed."""

    # Only required field; the rest fall back to the defaults below.
    user_id: str
    org_id: str = ""
    roles: list[str] = []
    clearance_level: int = 1
    # ``None`` makes the endpoint report ``settings.jwt_ttl_seconds`` as expiry.
    ttl_seconds: int | None = None
|
| 349 |
+
|
| 350 |
+
class _TokenResponse(_PydBM):
    """Response body of the dev ``/token`` endpoint."""

    # The token string produced by ``issue_token``.
    access_token: str
    token_type: str = "bearer"
    # Lifetime in seconds, as reported by the endpoint.
    expires_in: int
|
| 354 |
+
|
| 355 |
+
@app.post("/token", response_model=_TokenResponse, tags=["auth"])
async def issue_dev_token(body: _TokenRequest) -> _TokenResponse:
    """Mint a signed JWT for local testing.

    In production the IdP (Keycloak / Auth0 / Microsoft Entra) issues the
    token externally; this endpoint exists so the e2e smoke script and the
    Streamlit demo can mint a real token rather than the unsigned base64
    fallback.

    NOTE(review): earlier notes say this endpoint is removed via a
    ``SAR_DISABLE_DEV_TOKEN`` flag; no such flag is checked in this
    function — confirm where it is enforced.

    Returns:
        The token plus ``expires_in``, which is ``body.ttl_seconds`` when
        truthy and ``settings.jwt_ttl_seconds`` otherwise.

    Raises:
        HTTPException: 404 when the configured algorithm is RS256, 503 when
            no JWT secret is configured, 500 when ``issue_token`` raises
            ``AuthError``.
    """
    if settings.jwt_algorithm.upper() == "RS256":
        raise HTTPException(
            status.HTTP_404_NOT_FOUND,
            "Dev token endpoint disabled in RS256 mode — use the external IdP",
        )
    if not settings.jwt_secret:
        raise HTTPException(
            status.HTTP_503_SERVICE_UNAVAILABLE,
            "SAR_JWT_SECRET is not configured; token endpoint disabled",
        )
    try:
        token = issue_token(
            user_id=body.user_id,
            org_id=body.org_id,
            roles=body.roles,
            clearance_level=body.clearance_level,
            ttl_seconds=body.ttl_seconds,
        )
    except AuthError as exc:
        raise HTTPException(
            status.HTTP_500_INTERNAL_SERVER_ERROR, f"token_issue_{exc.reason}: {exc}"
        ) from exc
    # A duplicated, unreachable copy of the issue/return sequence used to
    # follow this return; it has been removed.
    return _TokenResponse(
        access_token=token,
        token_type="bearer",
        expires_in=body.ttl_seconds or settings.jwt_ttl_seconds,
    )
|
| 408 |
+
|
| 409 |
+
else: # pragma: no cover
|
| 410 |
+
app = None # type: ignore[assignment]
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
def mint_dev_token(user: dict) -> str:
    """Build a bearer token for a UserContext-shaped dict (local testing helper).

    When ``settings.jwt_secret`` is set, a token is requested from
    ``issue_token`` using the dict's ``user_id``, ``org_id``, ``roles`` and
    ``clearance_level`` (defaults: ``""``, ``""``, ``[]``, ``1``). With no
    secret, or when ``issue_token`` raises ``AuthError``, the result is the
    base64-encoded JSON dump of ``user``.
    """
    if settings.jwt_secret:
        identity = {
            "user_id": user.get("user_id", ""),
            "org_id": user.get("org_id", ""),
            "roles": list(user.get("roles", [])),
            "clearance_level": int(user.get("clearance_level", 1)),
        }
        try:
            return issue_token(**identity)
        except AuthError:
            # Deliberate best-effort: fall back to the unsigned shape below.
            pass

    raw_json = json.dumps(user)
    encoded = base64.b64encode(raw_json.encode("utf-8"))
    return encoded.decode("ascii")
|
interfaces/byok.py
CHANGED
|
@@ -1,166 +1,166 @@
|
|
| 1 |
-
"""BYOK (Bring Your Own Key) request extraction for the public demo.
|
| 2 |
-
|
| 3 |
-
Mounted on the FastAPI surface only when ``settings.byok_mode=True`` (production
|
| 4 |
-
HF Space image). Extracts per-request LLM credentials and session identity from
|
| 5 |
-
HTTP headers so the RAG pipeline can route to the visitor's own LLM provider
|
| 6 |
-
and Qdrant collection.
|
| 7 |
-
|
| 8 |
-
The extracted ``ByokCreds`` is **never persisted**:
|
| 9 |
-
|
| 10 |
-
- API keys live only in the request scope (FastAPI dep dies after response)
|
| 11 |
-
- ``utils.pii.redact`` strips key-shaped substrings from audit log entries
|
| 12 |
-
- The frontend stores the key in ``localStorage`` and forwards it as a header;
|
| 13 |
-
cookies are forbidden (CSRF surface).
|
| 14 |
-
|
| 15 |
-
See ``launch-plan/03-backend-byok.md`` and ``launch-plan/11-security-checklist.md``.
|
| 16 |
-
"""
|
| 17 |
-
|
| 18 |
-
from __future__ import annotations
|
| 19 |
-
|
| 20 |
-
import hashlib
|
| 21 |
-
import uuid
|
| 22 |
-
from typing import TYPE_CHECKING
|
| 23 |
-
|
| 24 |
-
from pydantic import BaseModel, ConfigDict, Field
|
| 25 |
-
|
| 26 |
-
if TYPE_CHECKING:
|
| 27 |
-
from fastapi import Request
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
# Header names the frontend sends.
|
| 31 |
-
HDR_USER_KEY = "X-User-LLM-Key"
|
| 32 |
-
HDR_USER_PROVIDER = "X-User-Provider"
|
| 33 |
-
HDR_USER_OLLAMA_URL = "X-User-Ollama-URL"
|
| 34 |
-
HDR_SESSION_ID = "X-Session-ID"
|
| 35 |
-
HDR_DEMO_PERSONA = "X-Demo-Persona"
|
| 36 |
-
|
| 37 |
-
# Supported provider literals carried in X-User-Provider.
|
| 38 |
-
SUPPORTED_PROVIDERS: frozenset[str] = frozenset({"groq", "openai", "anthropic", "ollama"})
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
class ByokCreds(BaseModel):
    """Immutable per-request BYOK credentials plus session identity.

    Attributes:
        user_key: The visitor's own LLM API key, or None when none was sent.
        provider: Provider name supplied by the visitor; use
            ``safe_provider()`` to obtain an allowlisted value.
        ollama_url: Visitor-supplied Ollama URL, if any.
        session_id: Required session identifier, 1 to 128 characters.
        demo_persona: Optional demo persona name.
    """

    # Frozen: instances cannot be mutated. String fields are whitespace-stripped.
    model_config = ConfigDict(frozen=True, str_strip_whitespace=True)

    user_key: str | None = None
    provider: str | None = None
    ollama_url: str | None = None
    session_id: str = Field(..., min_length=1, max_length=128)
    demo_persona: str | None = None

    def has_user_key(self) -> bool:
        """Return True when ``user_key`` holds something other than whitespace."""
        key = self.user_key
        if not key:
            return False
        return bool(key.strip())

    def safe_provider(self) -> str | None:
        """Return the lower-cased provider when it is in ``SUPPORTED_PROVIDERS``, else None."""
        if not self.provider:
            return None
        normalized = self.provider.lower()
        return normalized if normalized in SUPPORTED_PROVIDERS else None
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
def _derive_session_id(client_host: str | None) -> str:
|
| 85 |
-
"""Generate a deterministic-but-non-identifying session ID.
|
| 86 |
-
|
| 87 |
-
Falls back to a short hash of the client host + a random UUID. The hash
|
| 88 |
-
keeps the same session sticky if the visitor reconnects within the same
|
| 89 |
-
UVicorn worker; the random UUID ensures cross-worker / cross-restart
|
| 90 |
-
isolation. The full UUID flavour stays server-side — we never expose
|
| 91 |
-
raw IP addresses in the collection name.
|
| 92 |
-
"""
|
| 93 |
-
host = (client_host or "anon").strip() or "anon"
|
| 94 |
-
digest = hashlib.sha256(host.encode("utf-8")).hexdigest()[:8]
|
| 95 |
-
random = uuid.uuid4().hex[:8]
|
| 96 |
-
return f"{digest}-{random}"
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
def build_creds(
    *,
    user_key: str | None,
    provider: str | None,
    ollama_url: str | None,
    session_id: str | None,
    demo_persona: str | None,
    client_host: str | None,
) -> ByokCreds:
    """Assemble ``ByokCreds`` from raw header values (no FastAPI objects needed).

    Empty strings are normalised to None for the optional fields. When
    ``session_id`` is missing or blank, one is generated from
    ``client_host`` via ``_derive_session_id``.
    """
    effective_session = (session_id or "").strip()
    if not effective_session:
        effective_session = _derive_session_id(client_host)

    return ByokCreds(
        user_key=user_key or None,
        provider=provider or None,
        ollama_url=ollama_url or None,
        session_id=effective_session,
        demo_persona=demo_persona or None,
    )
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
# ── FastAPI integration ──────────────────────────────────────────────────────
|
| 124 |
-
# Header annotations live in this branch so the module can be imported in
|
| 125 |
-
# environments where fastapi is not installed (e.g. lightweight unit tests).
|
| 126 |
-
|
| 127 |
-
try:
|
| 128 |
-
# Runtime imports — FastAPI dependency injection reads annotations at
|
| 129 |
-
# request time, so these must NOT live in a TYPE_CHECKING-only block.
|
| 130 |
-
from fastapi import Header, Request # noqa: TC002
|
| 131 |
-
|
| 132 |
-
_FASTAPI_AVAILABLE = True
|
| 133 |
-
except ImportError: # pragma: no cover
|
| 134 |
-
_FASTAPI_AVAILABLE = False
|
| 135 |
-
|
| 136 |
-
def Header(*_a: object, **_kw: object) -> None: # type: ignore[no-redef] # noqa: N802 — keep FastAPI's name
|
| 137 |
-
"""No-op shim when FastAPI is not installed (lint-only env)."""
|
| 138 |
-
return None
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
if _FASTAPI_AVAILABLE:
|
| 142 |
-
from typing import Annotated
|
| 143 |
-
|
| 144 |
-
def extract_byok(
    request: Request,
    x_user_llm_key: Annotated[str | None, Header()] = None,
    x_user_provider: Annotated[str | None, Header()] = None,
    x_user_ollama_url: Annotated[str | None, Header()] = None,
    x_session_id: Annotated[str | None, Header()] = None,
    x_demo_persona: Annotated[str | None, Header()] = None,
) -> ByokCreds:
    """FastAPI dependency that gathers the BYOK headers into ``ByokCreds``.

    This only collects the header values and the peer address (None when
    the request has no client info) and passes them to ``build_creds``.
    """
    peer = request.client
    peer_host = peer.host if peer else None
    return build_creds(
        user_key=x_user_llm_key,
        provider=x_user_provider,
        ollama_url=x_user_ollama_url,
        session_id=x_session_id,
        demo_persona=x_demo_persona,
        client_host=peer_host,
    )
|
|
|
|
| 1 |
+
"""BYOK (Bring Your Own Key) request extraction for the public demo.
|
| 2 |
+
|
| 3 |
+
Mounted on the FastAPI surface only when ``settings.byok_mode=True`` (production
|
| 4 |
+
HF Space image). Extracts per-request LLM credentials and session identity from
|
| 5 |
+
HTTP headers so the RAG pipeline can route to the visitor's own LLM provider
|
| 6 |
+
and Qdrant collection.
|
| 7 |
+
|
| 8 |
+
The extracted ``ByokCreds`` is **never persisted**:
|
| 9 |
+
|
| 10 |
+
- API keys live only in the request scope (FastAPI dep dies after response)
|
| 11 |
+
- ``utils.pii.redact`` strips key-shaped substrings from audit log entries
|
| 12 |
+
- The frontend stores the key in ``localStorage`` and forwards it as a header;
|
| 13 |
+
cookies are forbidden (CSRF surface).
|
| 14 |
+
|
| 15 |
+
See ``launch-plan/03-backend-byok.md`` and ``launch-plan/11-security-checklist.md``.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
import hashlib
|
| 21 |
+
import uuid
|
| 22 |
+
from typing import TYPE_CHECKING
|
| 23 |
+
|
| 24 |
+
from pydantic import BaseModel, ConfigDict, Field
|
| 25 |
+
|
| 26 |
+
if TYPE_CHECKING:
|
| 27 |
+
from fastapi import Request
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# Header names the frontend sends.
|
| 31 |
+
HDR_USER_KEY = "X-User-LLM-Key"
|
| 32 |
+
HDR_USER_PROVIDER = "X-User-Provider"
|
| 33 |
+
HDR_USER_OLLAMA_URL = "X-User-Ollama-URL"
|
| 34 |
+
HDR_SESSION_ID = "X-Session-ID"
|
| 35 |
+
HDR_DEMO_PERSONA = "X-Demo-Persona"
|
| 36 |
+
|
| 37 |
+
# Supported provider literals carried in X-User-Provider.
|
| 38 |
+
SUPPORTED_PROVIDERS: frozenset[str] = frozenset({"groq", "openai", "anthropic", "ollama"})
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class ByokCreds(BaseModel):
    """Immutable bundle of per-request BYOK credentials and session identity.

    The model is frozen and strips surrounding whitespace from every string
    field on construction (see ``model_config``).

    Attributes:
        user_key: The visitor's own LLM provider API key, or None when the
            request should fall back to the owner key (subject to
            ``OwnerKeyHourThrottle``).
        provider: Provider the ``user_key`` belongs to. Stored as received;
            use :meth:`safe_provider` to get the value checked against
            ``SUPPORTED_PROVIDERS``. None defaults to the platform owner's
            configured ``cloud_provider``.
        ollama_url: Visitor's Ollama instance URL when provider == "ollama".
            Ignored otherwise.
        session_id: Per-visitor session identifier (1-128 characters). Drives
            the per-session Qdrant collection name. Generated server-side
            when the visitor does not provide one.
        demo_persona: Optional preset RBAC profile for the public demo —
            ``engineer`` / ``compliance`` / ``executive``. Translated to
            ``UserContext`` downstream.
    """

    model_config = ConfigDict(frozen=True, str_strip_whitespace=True)

    user_key: str | None = None
    provider: str | None = None
    ollama_url: str | None = None
    session_id: str = Field(..., min_length=1, max_length=128)
    demo_persona: str | None = None

    def has_user_key(self) -> bool:
        """Report whether the visitor supplied a non-blank LLM key.

        Owner-key fallback (False) goes through the per-IP throttle; visitor
        BYOK (True) bypasses it. Callers MUST consult this before deciding to
        consume the owner-key quota.
        """
        key = self.user_key
        if not key:
            return False
        return key.strip() != ""

    def safe_provider(self) -> str | None:
        """Return the lower-cased ``provider`` if allowlisted, else None."""
        if not self.provider:
            return None
        normalised = self.provider.lower()
        return normalised if normalised in SUPPORTED_PROVIDERS else None
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def _derive_session_id(client_host: str | None) -> str:
|
| 85 |
+
"""Generate a deterministic-but-non-identifying session ID.
|
| 86 |
+
|
| 87 |
+
Falls back to a short hash of the client host + a random UUID. The hash
|
| 88 |
+
keeps the same session sticky if the visitor reconnects within the same
|
| 89 |
+
UVicorn worker; the random UUID ensures cross-worker / cross-restart
|
| 90 |
+
isolation. The full UUID flavour stays server-side — we never expose
|
| 91 |
+
raw IP addresses in the collection name.
|
| 92 |
+
"""
|
| 93 |
+
host = (client_host or "anon").strip() or "anon"
|
| 94 |
+
digest = hashlib.sha256(host.encode("utf-8")).hexdigest()[:8]
|
| 95 |
+
random = uuid.uuid4().hex[:8]
|
| 96 |
+
return f"{digest}-{random}"
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def build_creds(
    *,
    user_key: str | None,
    provider: str | None,
    ollama_url: str | None,
    session_id: str | None,
    demo_persona: str | None,
    client_host: str | None,
) -> ByokCreds:
    """Pure factory — builds ``ByokCreds`` from raw header values.

    Separated from the FastAPI dependency so it is unit-testable without
    spinning up a Request object. Every input is whitespace-trimmed here, and
    missing, empty, or whitespace-only values all become ``None`` so that
    downstream ``is None`` checks treat a blank header as an absent one.

    Args:
        user_key: Raw ``X-User-LLM-Key`` value.
        provider: Raw ``X-User-Provider`` value.
        ollama_url: Raw ``X-User-Ollama-URL`` value.
        session_id: Raw ``X-Session-ID`` value; when blank, an ID is
            generated server-side from ``client_host``.
        demo_persona: Raw ``X-Demo-Persona`` value.
        client_host: Remote address, used only to derive a session ID.

    Returns:
        A frozen ``ByokCreds`` instance.

    Raises:
        pydantic.ValidationError: If the supplied ``session_id`` exceeds the
            128-character limit declared on ``ByokCreds.session_id``.
    """

    def _clean(raw: str | None) -> str | None:
        """Trim ``raw``; map missing or blank values to None."""
        return (raw or "").strip() or None

    return ByokCreds(
        user_key=_clean(user_key),
        provider=_clean(provider),
        ollama_url=_clean(ollama_url),
        # Only a non-blank client value is trusted; otherwise mint a new ID.
        session_id=_clean(session_id) or _derive_session_id(client_host),
        demo_persona=_clean(demo_persona),
    )
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
# ── FastAPI integration ──────────────────────────────────────────────────────
|
| 124 |
+
# Header annotations live in this branch so the module can be imported in
|
| 125 |
+
# environments where fastapi is not installed (e.g. lightweight unit tests).
|
| 126 |
+
|
| 127 |
+
try:
    # Runtime imports — FastAPI dependency injection reads annotations at
    # request time, so these must NOT live in a TYPE_CHECKING-only block.
    from fastapi import Header, Request  # noqa: TC002

    # Flag consulted further down to decide whether the FastAPI dependency
    # function is defined at all.
    _FASTAPI_AVAILABLE = True
except ImportError:  # pragma: no cover
    _FASTAPI_AVAILABLE = False

    def Header(*_a: object, **_kw: object) -> None:  # type: ignore[no-redef]  # noqa: N802 — keep FastAPI's name
        """No-op shim when FastAPI is not installed (lint-only env).

        Accepts and ignores any positional or keyword arguments and returns
        ``None``, so a ``Header(...)`` call still evaluates without FastAPI.
        """
        return None
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
if _FASTAPI_AVAILABLE:
    from typing import Annotated

    def extract_byok(
        request: Request,
        x_user_llm_key: Annotated[str | None, Header()] = None,
        x_user_provider: Annotated[str | None, Header()] = None,
        x_user_ollama_url: Annotated[str | None, Header()] = None,
        x_session_id: Annotated[str | None, Header()] = None,
        x_demo_persona: Annotated[str | None, Header()] = None,
    ) -> ByokCreds:
        """FastAPI dependency that collects the BYOK headers of one request.

        This only gathers data: each ``Header()`` parameter receives the
        matching request header (or None), the peer address is read from
        ``request.client``, and everything is handed to ``build_creds``.
        Authentication, throttling, and routing decisions happen downstream
        so they can be unit-tested independently of FastAPI's request
        lifecycle.

        Returns:
            The ``ByokCreds`` produced by ``build_creds``.
        """
        peer = request.client
        header_values = {
            "user_key": x_user_llm_key,
            "provider": x_user_provider,
            "ollama_url": x_user_ollama_url,
            "session_id": x_session_id,
            "demo_persona": x_demo_persona,
        }
        # ``request.client`` may be absent; pass None in that case.
        return build_creds(client_host=peer.host if peer else None, **header_values)
|
retrieval/multitenancy.py
CHANGED
|
@@ -1,43 +1,43 @@
|
|
| 1 |
-
"""Multi-tenancy utilities for Qdrant collection naming."""
|
| 2 |
-
|
| 3 |
-
from __future__ import annotations
|
| 4 |
-
|
| 5 |
-
from config.settings import settings
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
def _sanitize(s: str) -> str:
    """Replace every non-alphanumeric character of ``s`` with an underscore.

    "Alphanumeric" follows ``str.isalnum``, which is Unicode-aware, so
    non-ASCII letters and digits are kept as-is. The result has the same
    length as the input.
    """

    def _safe_char(ch: str) -> str:
        if ch.isalnum():
            return ch
        return "_"

    return "".join(map(_safe_char, s))
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
def get_collection_name(
    org_id: str | None = None,
    *,
    session_id: str | None = None,
) -> str:
    """Resolve the Qdrant collection name for an org or a BYOK session.

    The first matching rule wins:

    1. **BYOK mode** (``settings.byok_mode=True``) and a ``session_id`` →
       ``"{base}_sess_{sanitized_session}"``. Session-scoped collections
       isolate each visitor's uploads.
    2. **Multi-tenant** (``settings.multi_tenant_collections=True``) and an
       ``org_id`` → ``"{base}_{sanitized_org}"``.
    3. Otherwise → ``settings.qdrant_collection`` unchanged.

    Args:
        org_id: Organisation identifier (multi-tenant mode).
        session_id: Per-visitor session identifier (BYOK mode). Wins over
            ``org_id`` when both are set and BYOK is on, because BYOK is the
            stricter isolation boundary.

    Returns:
        Collection name string suitable for QdrantManager.
    """
    root = settings.qdrant_collection
    if settings.byok_mode and session_id:
        suffix = f"sess_{_sanitize(session_id)}"
    elif settings.multi_tenant_collections and org_id:
        suffix = _sanitize(org_id)
    else:
        # Single-tenant default: no suffix at all.
        return root
    return f"{root}_{suffix}"
|
|
|
|
| 1 |
+
"""Multi-tenancy utilities for Qdrant collection naming."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from config.settings import settings
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def _sanitize(s: str) -> str:
    """Replace every non-alphanumeric character of ``s`` with an underscore.

    "Alphanumeric" follows ``str.isalnum``, which is Unicode-aware, so
    non-ASCII letters and digits are kept as-is. The result has the same
    length as the input.
    """

    def _safe_char(ch: str) -> str:
        if ch.isalnum():
            return ch
        return "_"

    return "".join(map(_safe_char, s))
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def get_collection_name(
    org_id: str | None = None,
    *,
    session_id: str | None = None,
) -> str:
    """Resolve the Qdrant collection name for an org or a BYOK session.

    The first matching rule wins:

    1. **BYOK mode** (``settings.byok_mode=True``) and a ``session_id`` →
       ``"{base}_sess_{sanitized_session}"``. Session-scoped collections
       isolate each visitor's uploads.
    2. **Multi-tenant** (``settings.multi_tenant_collections=True``) and an
       ``org_id`` → ``"{base}_{sanitized_org}"``.
    3. Otherwise → ``settings.qdrant_collection`` unchanged.

    Args:
        org_id: Organisation identifier (multi-tenant mode).
        session_id: Per-visitor session identifier (BYOK mode). Wins over
            ``org_id`` when both are set and BYOK is on, because BYOK is the
            stricter isolation boundary.

    Returns:
        Collection name string suitable for QdrantManager.
    """
    root = settings.qdrant_collection
    if settings.byok_mode and session_id:
        suffix = f"sess_{_sanitize(session_id)}"
    elif settings.multi_tenant_collections and org_id:
        suffix = _sanitize(org_id)
    else:
        # Single-tenant default: no suffix at all.
        return root
    return f"{root}_{suffix}"
|
retrieval/session_purge.py
CHANGED
|
@@ -1,185 +1,185 @@
|
|
| 1 |
-
"""Per-session Qdrant collection purge for BYOK mode.
|
| 2 |
-
|
| 3 |
-
In BYOK mode each visitor's uploads land in a collection named
|
| 4 |
-
``documents_sess_<sanitized_session_id>``. Without a cleanup pass these
|
| 5 |
-
collections accumulate until the 1 GB Qdrant Cloud free tier fills up.
|
| 6 |
-
|
| 7 |
-
This module provides:
|
| 8 |
-
|
| 9 |
-
- :func:`purge_expired_sessions` — synchronous, idempotent sweep that
|
| 10 |
-
deletes collections whose creation timestamp is older than
|
| 11 |
-
``settings.session_collection_ttl_hours``.
|
| 12 |
-
- :func:`schedule_session_purge` — APScheduler hook the FastAPI lifespan
|
| 13 |
-
calls so the sweep runs every 6 hours inside the same process. No
|
| 14 |
-
separate cron container required.
|
| 15 |
-
|
| 16 |
-
The creation timestamp is read from Qdrant's
|
| 17 |
-
``CollectionInfo.config.params.metadata`` (set at create-time by the
|
| 18 |
-
ingestion pipeline). Collections without a creation timestamp are treated
|
| 19 |
-
as legacy and **skipped** — we never delete data we can't date.
|
| 20 |
-
|
| 21 |
-
See ``launch-plan/03-backend-byok.md`` § Session purge cron.
|
| 22 |
-
"""
|
| 23 |
-
|
| 24 |
-
from __future__ import annotations
|
| 25 |
-
|
| 26 |
-
from datetime import UTC, datetime, timedelta
|
| 27 |
-
from typing import TYPE_CHECKING, Any
|
| 28 |
-
|
| 29 |
-
from config.settings import settings
|
| 30 |
-
from utils.logging import get_logger
|
| 31 |
-
|
| 32 |
-
if TYPE_CHECKING:
|
| 33 |
-
from qdrant_client import QdrantClient
|
| 34 |
-
|
| 35 |
-
logger = get_logger(__name__)
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
SESSION_COLLECTION_PREFIX = "_sess_"
|
| 39 |
-
"""Suffix introduced into the collection name by ``get_collection_name`` when
|
| 40 |
-
``byok_mode`` is on and a ``session_id`` is supplied. Used here to filter the
|
| 41 |
-
purge sweep to BYOK collections only — multi-tenant org collections are NOT
|
| 42 |
-
touched."""
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
def _session_collection_prefix() -> str:
    """Return the base collection name followed by ``SESSION_COLLECTION_PREFIX``.

    For a base collection called ``documents`` this yields ``documents_sess_``.
    The setting is read on every call rather than cached.
    """
    return "{}{}".format(settings.qdrant_collection, SESSION_COLLECTION_PREFIX)
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
def _is_session_collection(name: str) -> bool:
    """Tell whether ``name`` begins with the current session-collection prefix.

    That prefix is what ``get_collection_name`` emits when it is given a
    session_id in BYOK mode.
    """
    expected_prefix = _session_collection_prefix()
    return name.startswith(expected_prefix)
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
def _parse_created_at(meta: dict[str, Any] | None) -> datetime | None:
|
| 56 |
-
"""Return the collection's recorded creation datetime, or None if missing.
|
| 57 |
-
|
| 58 |
-
The ingestion pipeline writes ``created_at`` as an ISO-8601 UTC string into
|
| 59 |
-
the collection's metadata payload when first creating a session
|
| 60 |
-
collection. Older collections lack the field — those are intentionally
|
| 61 |
-
skipped to avoid deleting data we cannot date.
|
| 62 |
-
"""
|
| 63 |
-
if not meta:
|
| 64 |
-
return None
|
| 65 |
-
raw = meta.get("created_at")
|
| 66 |
-
if not raw:
|
| 67 |
-
return None
|
| 68 |
-
try:
|
| 69 |
-
# Accept both ``2026-05-26T13:00:00+00:00`` and trailing ``Z`` forms.
|
| 70 |
-
return datetime.fromisoformat(str(raw).replace("Z", "+00:00"))
|
| 71 |
-
except (TypeError, ValueError):
|
| 72 |
-
logger.warning("session_purge_bad_timestamp", value=str(raw))
|
| 73 |
-
return None
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
def purge_expired_sessions(
    client: QdrantClient,
    *,
    ttl_hours: int | None = None,
    now: datetime | None = None,
) -> dict[str, Any]:
    """Delete BYOK session collections older than the TTL.

    Only collections accepted by ``_is_session_collection`` are inspected.
    A collection is deleted when its recorded ``created_at`` is strictly
    older than ``now - ttl``; collections without a usable timestamp are
    skipped, never deleted. A failure on one collection is counted and
    logged, and the sweep continues with the next one.

    Args:
        client: Live ``QdrantClient`` (cloud or local).
        ttl_hours: Override ``settings.session_collection_ttl_hours``. Tests
            pass small values.
        now: Override the clock for deterministic tests. A naive value is
            treated as UTC.

    Returns:
        Summary dict with keys ``inspected``, ``deleted``, ``skipped``,
        ``errors``, ``deleted_names`` and ``ttl_hours``. The same keys are
        returned when listing collections fails (then ``errors`` is 1).
    """
    ttl = ttl_hours if ttl_hours is not None else settings.session_collection_ttl_hours
    current = now or datetime.now(UTC)
    if current.tzinfo is None:
        # Keep every comparison below aware-vs-aware.
        current = current.replace(tzinfo=UTC)
    horizon = current - timedelta(hours=ttl)
    inspected = deleted = skipped = errors = 0
    deleted_names: list[str] = []

    try:
        collections = client.get_collections().collections
    except Exception as exc:
        logger.error("session_purge_list_failed", error=str(exc))
        # Same shape as the normal summary so callers can index it blindly.
        return {
            "inspected": 0,
            "deleted": 0,
            "skipped": 0,
            "errors": 1,
            "deleted_names": [],
            "ttl_hours": ttl,
        }

    for col in collections:
        name = col.name
        if not _is_session_collection(name):
            continue
        inspected += 1
        try:
            info = client.get_collection(name)
            config = info.config
            # NOTE(review): some qdrant-client versions appear to expose
            # collection metadata on ``config.metadata`` rather than
            # ``config.params.metadata`` — confirm. Both are tried.
            meta = (
                getattr(config.params, "metadata", None)
                or getattr(config, "metadata", None)
                or {}
            )
            created = _parse_created_at(meta)
            if created is None:
                # Undated -> skip; we don't delete what we can't time-stamp.
                skipped += 1
                continue
            if created.tzinfo is None:
                # Offset-less timestamps are taken as UTC.
                created = created.replace(tzinfo=UTC)
            if created >= horizon:
                # Still within the TTL.
                skipped += 1
                continue
            client.delete_collection(name)
            deleted += 1
            deleted_names.append(name)
            logger.info(
                "session_purge_deleted",
                collection=name,
                created_at=created.isoformat(),
                age_hours=round((current - created).total_seconds() / 3600.0, 1),
            )
        except Exception as exc:
            # Best effort: one bad collection must not abort the sweep.
            errors += 1
            logger.warning("session_purge_collection_failed", collection=name, error=str(exc))

    summary = {
        "inspected": inspected,
        "deleted": deleted,
        "skipped": skipped,
        "errors": errors,
        "deleted_names": deleted_names,
        "ttl_hours": ttl,
    }
    # The name list can be long; log only the counters.
    logger.info(
        "session_purge_summary", **{k: v for k, v in summary.items() if k != "deleted_names"}
    )
    return summary
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
# ── FastAPI lifespan hook ────────────────────────────────────────────────────
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
def schedule_session_purge(client: QdrantClient, *, interval_hours: int = 6) -> Any | None:
    """Register :func:`purge_expired_sessions` as a recurring APScheduler job.

    Meant to be called from the FastAPI ``lifespan`` context manager.

    Behaviour by case:

    - ``settings.byok_mode`` is off: nothing is scheduled, returns None.
    - APScheduler is not importable: one purge runs immediately (so at least
      one sweep happens per restart) and None is returned.
    - Otherwise: an ``AsyncIOScheduler`` is created, the interval job is
      added, the scheduler is started and returned.

    Args:
        client: Qdrant client handed to every purge run.
        interval_hours: Hours between two sweeps.

    Returns:
        The started scheduler, or None when no scheduler was created.
    """
    if not settings.byok_mode:
        logger.debug("session_purge_not_scheduled", reason="byok_mode is off")
        return None

    try:
        from apscheduler.schedulers.asyncio import AsyncIOScheduler  # type: ignore[import-not-found]
    except ImportError:
        # Optional dep absent: at least sweep once so the Space does not
        # accumulate indefinitely on long uptimes.
        logger.warning("apscheduler_missing", action="single-shot purge instead")
        purge_expired_sessions(client)
        return None

    job_options = {
        "hours": interval_hours,
        "args": [client],
        "id": "byok-session-purge",
        "replace_existing": True,
    }
    purge_scheduler = AsyncIOScheduler()
    purge_scheduler.add_job(purge_expired_sessions, "interval", **job_options)
    purge_scheduler.start()
    logger.info("session_purge_scheduled", every_hours=interval_hours)
    return purge_scheduler
|
|
|
|
| 1 |
+
"""Per-session Qdrant collection purge for BYOK mode.
|
| 2 |
+
|
| 3 |
+
In BYOK mode each visitor's uploads land in a collection named
|
| 4 |
+
``documents_sess_<sanitized_session_id>``. Without a cleanup pass these
|
| 5 |
+
collections accumulate until the 1 GB Qdrant Cloud free tier fills up.
|
| 6 |
+
|
| 7 |
+
This module provides:
|
| 8 |
+
|
| 9 |
+
- :func:`purge_expired_sessions` — synchronous, idempotent sweep that
|
| 10 |
+
deletes collections whose creation timestamp is older than
|
| 11 |
+
``settings.session_collection_ttl_hours``.
|
| 12 |
+
- :func:`schedule_session_purge` — APScheduler hook the FastAPI lifespan
|
| 13 |
+
calls so the sweep runs every 6 hours inside the same process. No
|
| 14 |
+
separate cron container required.
|
| 15 |
+
|
| 16 |
+
The creation timestamp is read from Qdrant's
|
| 17 |
+
``CollectionInfo.config.params.metadata`` (set at create-time by the
|
| 18 |
+
ingestion pipeline). Collections without a creation timestamp are treated
|
| 19 |
+
as legacy and **skipped** — we never delete data we can't date.
|
| 20 |
+
|
| 21 |
+
See ``launch-plan/03-backend-byok.md`` § Session purge cron.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
from __future__ import annotations
|
| 25 |
+
|
| 26 |
+
from datetime import UTC, datetime, timedelta
|
| 27 |
+
from typing import TYPE_CHECKING, Any
|
| 28 |
+
|
| 29 |
+
from config.settings import settings
|
| 30 |
+
from utils.logging import get_logger
|
| 31 |
+
|
| 32 |
+
if TYPE_CHECKING:
|
| 33 |
+
from qdrant_client import QdrantClient
|
| 34 |
+
|
| 35 |
+
logger = get_logger(__name__)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
SESSION_COLLECTION_PREFIX = "_sess_"
|
| 39 |
+
"""Suffix introduced into the collection name by ``get_collection_name`` when
|
| 40 |
+
``byok_mode`` is on and a ``session_id`` is supplied. Used here to filter the
|
| 41 |
+
purge sweep to BYOK collections only — multi-tenant org collections are NOT
|
| 42 |
+
touched."""
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _session_collection_prefix() -> str:
    """Return the base collection name followed by ``SESSION_COLLECTION_PREFIX``.

    For a base collection called ``documents`` this yields ``documents_sess_``.
    The setting is read on every call rather than cached.
    """
    return "{}{}".format(settings.qdrant_collection, SESSION_COLLECTION_PREFIX)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _is_session_collection(name: str) -> bool:
    """Tell whether ``name`` begins with the current session-collection prefix.

    That prefix is what ``get_collection_name`` emits when it is given a
    session_id in BYOK mode.
    """
    expected_prefix = _session_collection_prefix()
    return name.startswith(expected_prefix)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _parse_created_at(meta: dict[str, Any] | None) -> datetime | None:
|
| 56 |
+
"""Return the collection's recorded creation datetime, or None if missing.
|
| 57 |
+
|
| 58 |
+
The ingestion pipeline writes ``created_at`` as an ISO-8601 UTC string into
|
| 59 |
+
the collection's metadata payload when first creating a session
|
| 60 |
+
collection. Older collections lack the field — those are intentionally
|
| 61 |
+
skipped to avoid deleting data we cannot date.
|
| 62 |
+
"""
|
| 63 |
+
if not meta:
|
| 64 |
+
return None
|
| 65 |
+
raw = meta.get("created_at")
|
| 66 |
+
if not raw:
|
| 67 |
+
return None
|
| 68 |
+
try:
|
| 69 |
+
# Accept both ``2026-05-26T13:00:00+00:00`` and trailing ``Z`` forms.
|
| 70 |
+
return datetime.fromisoformat(str(raw).replace("Z", "+00:00"))
|
| 71 |
+
except (TypeError, ValueError):
|
| 72 |
+
logger.warning("session_purge_bad_timestamp", value=str(raw))
|
| 73 |
+
return None
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def purge_expired_sessions(
    client: QdrantClient,
    *,
    ttl_hours: int | None = None,
    now: datetime | None = None,
) -> dict[str, Any]:
    """Delete BYOK session collections older than the TTL.

    Only collections accepted by ``_is_session_collection`` are inspected.
    A collection is deleted when its recorded ``created_at`` is strictly
    older than ``now - ttl``; collections without a usable timestamp are
    skipped, never deleted. A failure on one collection is counted and
    logged, and the sweep continues with the next one.

    Args:
        client: Live ``QdrantClient`` (cloud or local).
        ttl_hours: Override ``settings.session_collection_ttl_hours``. Tests
            pass small values.
        now: Override the clock for deterministic tests. A naive value is
            treated as UTC.

    Returns:
        Summary dict with keys ``inspected``, ``deleted``, ``skipped``,
        ``errors``, ``deleted_names`` and ``ttl_hours``. The same keys are
        returned when listing collections fails (then ``errors`` is 1).
    """
    ttl = ttl_hours if ttl_hours is not None else settings.session_collection_ttl_hours
    current = now or datetime.now(UTC)
    if current.tzinfo is None:
        # Keep every comparison below aware-vs-aware.
        current = current.replace(tzinfo=UTC)
    horizon = current - timedelta(hours=ttl)
    inspected = deleted = skipped = errors = 0
    deleted_names: list[str] = []

    try:
        collections = client.get_collections().collections
    except Exception as exc:
        logger.error("session_purge_list_failed", error=str(exc))
        # Same shape as the normal summary so callers can index it blindly.
        return {
            "inspected": 0,
            "deleted": 0,
            "skipped": 0,
            "errors": 1,
            "deleted_names": [],
            "ttl_hours": ttl,
        }

    for col in collections:
        name = col.name
        if not _is_session_collection(name):
            continue
        inspected += 1
        try:
            info = client.get_collection(name)
            config = info.config
            # NOTE(review): some qdrant-client versions appear to expose
            # collection metadata on ``config.metadata`` rather than
            # ``config.params.metadata`` — confirm. Both are tried.
            meta = (
                getattr(config.params, "metadata", None)
                or getattr(config, "metadata", None)
                or {}
            )
            created = _parse_created_at(meta)
            if created is None:
                # Undated -> skip; we don't delete what we can't time-stamp.
                skipped += 1
                continue
            if created.tzinfo is None:
                # Offset-less timestamps are taken as UTC.
                created = created.replace(tzinfo=UTC)
            if created >= horizon:
                # Still within the TTL.
                skipped += 1
                continue
            client.delete_collection(name)
            deleted += 1
            deleted_names.append(name)
            logger.info(
                "session_purge_deleted",
                collection=name,
                created_at=created.isoformat(),
                age_hours=round((current - created).total_seconds() / 3600.0, 1),
            )
        except Exception as exc:
            # Best effort: one bad collection must not abort the sweep.
            errors += 1
            logger.warning("session_purge_collection_failed", collection=name, error=str(exc))

    summary = {
        "inspected": inspected,
        "deleted": deleted,
        "skipped": skipped,
        "errors": errors,
        "deleted_names": deleted_names,
        "ttl_hours": ttl,
    }
    # The name list can be long; log only the counters.
    logger.info(
        "session_purge_summary", **{k: v for k, v in summary.items() if k != "deleted_names"}
    )
    return summary
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
# ── FastAPI lifespan hook ────────────────────────────────────────────────────
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def schedule_session_purge(client: QdrantClient, *, interval_hours: int = 6) -> Any | None:
    """Register :func:`purge_expired_sessions` as a recurring APScheduler job.

    Meant to be called from the FastAPI ``lifespan`` context manager.

    Behaviour by case:

    - ``settings.byok_mode`` is off: nothing is scheduled, returns None.
    - APScheduler is not importable: one purge runs immediately (so at least
      one sweep happens per restart) and None is returned.
    - Otherwise: an ``AsyncIOScheduler`` is created, the interval job is
      added, the scheduler is started and returned.

    Args:
        client: Qdrant client handed to every purge run.
        interval_hours: Hours between two sweeps.

    Returns:
        The started scheduler, or None when no scheduler was created.
    """
    if not settings.byok_mode:
        logger.debug("session_purge_not_scheduled", reason="byok_mode is off")
        return None

    try:
        from apscheduler.schedulers.asyncio import AsyncIOScheduler  # type: ignore[import-not-found]
    except ImportError:
        # Optional dep absent: at least sweep once so the Space does not
        # accumulate indefinitely on long uptimes.
        logger.warning("apscheduler_missing", action="single-shot purge instead")
        purge_expired_sessions(client)
        return None

    job_options = {
        "hours": interval_hours,
        "args": [client],
        "id": "byok-session-purge",
        "replace_existing": True,
    }
    purge_scheduler = AsyncIOScheduler()
    purge_scheduler.add_job(purge_expired_sessions, "interval", **job_options)
    purge_scheduler.start()
    logger.info("session_purge_scheduled", every_hours=interval_hours)
    return purge_scheduler
|
utils/observability.py
CHANGED
|
@@ -1,252 +1,252 @@
|
|
| 1 |
-
"""Observability setup using Arize Phoenix for LLM tracing.
|
| 2 |
-
|
| 3 |
-
Provides OpenTelemetry-compatible distributed tracing for LLM calls,
|
| 4 |
-
retrieval operations, and LangGraph execution. Gracefully degrades
|
| 5 |
-
when Phoenix is not installed or configured.
|
| 6 |
-
|
| 7 |
-
Usage:
|
| 8 |
-
Call setup_tracing() once at application startup (e.g., in app/main.py).
|
| 9 |
-
All trace_* functions will automatically emit spans when tracing is enabled.
|
| 10 |
-
"""
|
| 11 |
-
|
| 12 |
-
from __future__ import annotations
|
| 13 |
-
|
| 14 |
-
from config.settings import settings
|
| 15 |
-
from utils.logging import get_logger
|
| 16 |
-
|
| 17 |
-
_log = get_logger(__name__)
|
| 18 |
-
|
| 19 |
-
# Module-level state
|
| 20 |
-
_tracer = None
|
| 21 |
-
_phoenix_configured = False
|
| 22 |
-
_phoenix_project_name: str = settings.app_name
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
def setup_tracing() -> bool:
    """Turn on Phoenix tracing when the configuration allows it.

    Safe to call unconditionally at startup: when tracing is disabled or
    cannot be initialised, the reason is logged and False is returned. No
    exception raised during initialisation escapes this function.

    Returns:
        True if tracing was successfully enabled, False otherwise.
    """
    global _tracer, _phoenix_configured, _phoenix_project_name

    # BYOK mode mandates: no third-party telemetry sees a request. Phoenix
    # spans capture LLM prompts and completions, which would include the
    # visitor's keys-in-context and any private text they uploaded. Hard
    # disable in BYOK regardless of phoenix_endpoint configuration.
    if settings.byok_mode:
        disabled_reason = "BYOK mode forbids external telemetry"
    elif not settings.phoenix_endpoint:
        disabled_reason = "No phoenix_endpoint configured"
    else:
        disabled_reason = None

    if disabled_reason is not None:
        _log.info("phoenix_tracing_disabled", reason=disabled_reason)
        return False

    try:
        from phoenix.otel import register

        provider_handle = register(
            project_name=settings.app_name,
            endpoint=settings.phoenix_endpoint,
        )

        # Best-effort auto-instrumentation of LLM and retrieval calls.
        _instrument_providers()

        _phoenix_configured = True
        _phoenix_project_name = settings.app_name
        _log.info(
            "phoenix_tracing_enabled",
            endpoint=settings.phoenix_endpoint,
            project=settings.app_name,
            tracer_provider=str(provider_handle),
        )
        return True
    except ImportError:
        _log.warning(
            "phoenix_import_failed",
            msg=(
                "arize-phoenix not installed; tracing unavailable. "
                "Install with: pip install 'arize-phoenix-otel'"
            ),
        )
        return False
    except Exception as init_error:
        _log.error(
            "phoenix_tracing_init_error",
            error=str(init_error),
            endpoint=settings.phoenix_endpoint,
        )
        return False
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
def _instrument_providers() -> None:
    """Attach OpenTelemetry instrumentation to the supported providers.

    LangChain/LangGraph is tried first, then OpenAI-compatible clients. Each
    attempt is independent: a missing package or any other failure is logged
    at debug level and never raised — partial instrumentation is acceptable.
    """

    def _attach_langchain() -> None:
        from openinference.instrumentation.langchain import LangChainInstrumentor

        LangChainInstrumentor().instrument()
        _log.info("instrumented_langchain")

    def _attach_openai() -> None:
        from openinference.instrumentation.openai import OpenAIInstrumentor

        OpenAIInstrumentor().instrument()
        _log.info("instrumented_openai")

    # (attach callable, event when package missing, its reason, event on error)
    attempts = (
        (
            _attach_langchain,
            "langchain_instrumentation_skipped",
            "openinference-instrumentation-langchain not installed",
            "langchain_instrumentation_error",
        ),
        (
            _attach_openai,
            "openai_instrumentation_skipped",
            "openinference-instrumentation-openai not installed",
            "openai_instrumentation_error",
        ),
    )
    for attach, skipped_event, skipped_reason, error_event in attempts:
        try:
            attach()
        except ImportError:
            _log.debug(skipped_event, reason=skipped_reason)
        except Exception as failure:
            _log.debug(error_event, reason=str(failure))
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
def trace_llm_call(
    provider: str,
    model: str,
    prompt: str,
    response: str,
    latency_ms: float,
    tokens: dict[str, int] | None = None,
) -> None:
    """Emit a manual ``llm_call`` span describing one LLM invocation.

    Does nothing unless Phoenix tracing has been configured. Only the
    lengths of ``prompt`` and ``response`` are recorded, not their text.
    Any failure while tracing is logged at debug level and swallowed.

    Args:
        provider: LLM provider name (e.g., "ollama", "groq").
        model: Model identifier used for generation.
        prompt: The input prompt text.
        response: The generated response text.
        latency_ms: Response latency in milliseconds.
        tokens: Optional token usage dict with keys like
            "prompt_tokens", "completion_tokens", "total_tokens"; each entry
            becomes an ``llm.tokens.<key>`` attribute.
    """
    if _phoenix_configured:
        try:
            from opentelemetry import trace

            llm_tracer = trace.get_tracer("secureagentrag.llm")
            with llm_tracer.start_as_current_span("llm_call") as span:
                record = span.set_attribute
                record("llm.provider", provider)
                record("llm.model", model)
                record("llm.prompt_length", len(prompt))
                record("llm.response_length", len(response))
                record("llm.latency_ms", latency_ms)
                for usage_kind, count in (tokens or {}).items():
                    record(f"llm.tokens.{usage_kind}", count)
        except Exception as failure:
            _log.debug("trace_llm_call_failed", error=str(failure))
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
def trace_retrieval(
|
| 165 |
-
query: str,
|
| 166 |
-
num_results: int,
|
| 167 |
-
latency_ms: float,
|
| 168 |
-
method: str = "hybrid",
|
| 169 |
-
) -> None:
|
| 170 |
-
"""Record a manual trace span for a retrieval operation.
|
| 171 |
-
|
| 172 |
-
Args:
|
| 173 |
-
query: The search query string.
|
| 174 |
-
num_results: Number of results returned.
|
| 175 |
-
latency_ms: Retrieval latency in milliseconds.
|
| 176 |
-
method: Retrieval method used ("hybrid", "dense", "bm25").
|
| 177 |
-
"""
|
| 178 |
-
if not _phoenix_configured:
|
| 179 |
-
return
|
| 180 |
-
|
| 181 |
-
try:
|
| 182 |
-
from opentelemetry import trace
|
| 183 |
-
|
| 184 |
-
tracer = trace.get_tracer("secureagentrag.retrieval")
|
| 185 |
-
with tracer.start_as_current_span("retrieval") as span:
|
| 186 |
-
span.set_attribute("retrieval.query_length", len(query))
|
| 187 |
-
span.set_attribute("retrieval.num_results", num_results)
|
| 188 |
-
span.set_attribute("retrieval.latency_ms", latency_ms)
|
| 189 |
-
span.set_attribute("retrieval.method", method)
|
| 190 |
-
except Exception as exc:
|
| 191 |
-
_log.debug("trace_retrieval_failed", error=str(exc))
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
def trace_graph_execution(
|
| 195 |
-
query: str,
|
| 196 |
-
nodes_executed: list[str],
|
| 197 |
-
total_latency_ms: float,
|
| 198 |
-
final_confidence: float,
|
| 199 |
-
retries: int = 0,
|
| 200 |
-
) -> None:
|
| 201 |
-
"""Record a manual trace span for LangGraph pipeline execution.
|
| 202 |
-
|
| 203 |
-
Args:
|
| 204 |
-
query: The original user query.
|
| 205 |
-
nodes_executed: List of graph node names that were executed.
|
| 206 |
-
total_latency_ms: Total pipeline execution time in milliseconds.
|
| 207 |
-
final_confidence: Final confidence score of the generated answer.
|
| 208 |
-
retries: Number of corrective retrieval retries performed.
|
| 209 |
-
"""
|
| 210 |
-
if not _phoenix_configured:
|
| 211 |
-
return
|
| 212 |
-
|
| 213 |
-
try:
|
| 214 |
-
from opentelemetry import trace
|
| 215 |
-
|
| 216 |
-
tracer = trace.get_tracer("secureagentrag.graph")
|
| 217 |
-
with tracer.start_as_current_span("graph_execution") as span:
|
| 218 |
-
span.set_attribute("graph.query_length", len(query))
|
| 219 |
-
span.set_attribute("graph.nodes_executed", ",".join(nodes_executed))
|
| 220 |
-
span.set_attribute("graph.total_latency_ms", total_latency_ms)
|
| 221 |
-
span.set_attribute("graph.confidence", final_confidence)
|
| 222 |
-
span.set_attribute("graph.retries", retries)
|
| 223 |
-
except Exception as exc:
|
| 224 |
-
_log.debug("trace_graph_execution_failed", error=str(exc))
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
def get_trace_url() -> str | None:
|
| 228 |
-
"""Return the Phoenix dashboard URL if tracing is configured.
|
| 229 |
-
|
| 230 |
-
Returns:
|
| 231 |
-
Phoenix UI URL string, or None if Phoenix is not configured.
|
| 232 |
-
"""
|
| 233 |
-
if not _phoenix_configured or not settings.phoenix_endpoint:
|
| 234 |
-
return None
|
| 235 |
-
|
| 236 |
-
# Phoenix UI typically runs on the same host
|
| 237 |
-
endpoint = settings.phoenix_endpoint.rstrip("/")
|
| 238 |
-
# Replace gRPC/collector port with UI port if needed
|
| 239 |
-
if ":4317" in endpoint:
|
| 240 |
-
return endpoint.replace(":4317", ":6006")
|
| 241 |
-
if ":6006" in endpoint:
|
| 242 |
-
return endpoint
|
| 243 |
-
return endpoint
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
def is_tracing_enabled() -> bool:
|
| 247 |
-
"""Check if Phoenix tracing is currently active.
|
| 248 |
-
|
| 249 |
-
Returns:
|
| 250 |
-
True if tracing was successfully configured, False otherwise.
|
| 251 |
-
"""
|
| 252 |
-
return _phoenix_configured
|
|
|
|
| 1 |
+
"""Observability setup using Arize Phoenix for LLM tracing.
|
| 2 |
+
|
| 3 |
+
Provides OpenTelemetry-compatible distributed tracing for LLM calls,
|
| 4 |
+
retrieval operations, and LangGraph execution. Gracefully degrades
|
| 5 |
+
when Phoenix is not installed or configured.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
Call setup_tracing() once at application startup (e.g., in app/main.py).
|
| 9 |
+
All trace_* functions will automatically emit spans when tracing is enabled.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
from config.settings import settings
|
| 15 |
+
from utils.logging import get_logger
|
| 16 |
+
|
| 17 |
+
_log = get_logger(__name__)
|
| 18 |
+
|
| 19 |
+
# Module-level state
|
| 20 |
+
_tracer = None
|
| 21 |
+
_phoenix_configured = False
|
| 22 |
+
_phoenix_project_name: str = settings.app_name
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def setup_tracing() -> bool:
    """Turn on Phoenix tracing when the current settings permit it.

    Meant to be called unconditionally during startup: every path that leaves
    tracing off logs the reason and returns ``False`` instead of raising.

    Tracing stays off when ``settings.byok_mode`` is set (checked first, so it
    wins over any configured endpoint), when ``settings.phoenix_endpoint`` is
    empty, when ``phoenix.otel`` cannot be imported, or when registration
    raises.

    Returns:
        True once the tracer provider is registered and the module-level
        ``_phoenix_configured`` flag has been set; False otherwise.
    """
    global _tracer, _phoenix_configured, _phoenix_project_name

    # BYOK mode mandates that no third-party telemetry sees a request. Phoenix
    # spans capture LLM prompts and completions, which would include the
    # visitor's keys-in-context and any private text they uploaded, so the
    # endpoint setting is deliberately ignored here.
    if settings.byok_mode:
        _log.info("phoenix_tracing_disabled", reason="BYOK mode forbids external telemetry")
        return False

    collector_endpoint = settings.phoenix_endpoint
    if not collector_endpoint:
        _log.info("phoenix_tracing_disabled", reason="No phoenix_endpoint configured")
        return False

    project = settings.app_name
    try:
        from phoenix.otel import register

        provider = register(project_name=project, endpoint=collector_endpoint)

        # Auto-instrumentation is best-effort; it logs its own failures.
        _instrument_providers()

        _phoenix_configured = True
        _phoenix_project_name = project
        _log.info(
            "phoenix_tracing_enabled",
            endpoint=collector_endpoint,
            project=project,
            tracer_provider=str(provider),
        )
        return True
    except ImportError:
        _log.warning(
            "phoenix_import_failed",
            msg=(
                "arize-phoenix not installed; tracing unavailable. "
                "Install with: pip install 'arize-phoenix-otel'"
            ),
        )
        return False
    except Exception as exc:
        _log.error(
            "phoenix_tracing_init_error",
            error=str(exc),
            endpoint=collector_endpoint,
        )
        return False
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def _instrument_providers() -> None:
    """Hook the optional OpenInference instrumentors into LangChain and OpenAI.

    The two integrations are attempted independently. A missing package or a
    failing ``instrument()`` call is logged at debug level and the function
    moves on, so partial instrumentation is an accepted outcome.
    """
    # LangChain / LangGraph integration.
    try:
        from openinference.instrumentation.langchain import LangChainInstrumentor

        langchain_hook = LangChainInstrumentor()
        langchain_hook.instrument()
        _log.info("instrumented_langchain")
    except ImportError:
        _log.debug(
            "langchain_instrumentation_skipped",
            reason="openinference-instrumentation-langchain not installed",
        )
    except Exception as err:
        _log.debug("langchain_instrumentation_error", reason=str(err))

    # OpenAI-compatible client integration.
    try:
        from openinference.instrumentation.openai import OpenAIInstrumentor

        openai_hook = OpenAIInstrumentor()
        openai_hook.instrument()
        _log.info("instrumented_openai")
    except ImportError:
        _log.debug(
            "openai_instrumentation_skipped",
            reason="openinference-instrumentation-openai not installed",
        )
    except Exception as err:
        _log.debug("openai_instrumentation_error", reason=str(err))
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def trace_llm_call(
    provider: str,
    model: str,
    prompt: str,
    response: str,
    latency_ms: float,
    tokens: dict[str, int] | None = None,
) -> None:
    """Emit a manual ``llm_call`` span describing one LLM request.

    Does nothing unless the module-level ``_phoenix_configured`` flag is set.
    Only the *lengths* of ``prompt`` and ``response`` are attached to the
    span, not their text. Any exception raised while tracing is logged at
    debug level and swallowed.

    Args:
        provider: LLM provider name (e.g., "ollama", "groq").
        model: Model identifier used for generation.
        prompt: The input prompt text (only its length is recorded).
        response: The generated response text (only its length is recorded).
        latency_ms: Response latency in milliseconds.
        tokens: Optional token-usage mapping; each entry is recorded as
            ``llm.tokens.<key>``.
    """
    if not _phoenix_configured:
        return

    try:
        from opentelemetry import trace

        llm_tracer = trace.get_tracer("secureagentrag.llm")
        with llm_tracer.start_as_current_span("llm_call") as span:
            record = span.set_attribute
            record("llm.provider", provider)
            record("llm.model", model)
            record("llm.prompt_length", len(prompt))
            record("llm.response_length", len(response))
            record("llm.latency_ms", latency_ms)
            # A missing or empty usage mapping contributes no attributes.
            for usage_key, usage_value in (tokens or {}).items():
                record(f"llm.tokens.{usage_key}", usage_value)
    except Exception as exc:
        _log.debug("trace_llm_call_failed", error=str(exc))
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def trace_retrieval(
    query: str,
    num_results: int,
    latency_ms: float,
    method: str = "hybrid",
) -> None:
    """Emit a manual ``retrieval`` span describing one retrieval operation.

    Does nothing unless the module-level ``_phoenix_configured`` flag is set.
    The query text itself is not attached to the span, only its length. Any
    exception raised while tracing is logged at debug level and swallowed.

    Args:
        query: The search query string (only its length is recorded).
        num_results: Number of results returned.
        latency_ms: Retrieval latency in milliseconds.
        method: Retrieval method used ("hybrid", "dense", "bm25").
    """
    if not _phoenix_configured:
        return

    try:
        from opentelemetry import trace

        retrieval_tracer = trace.get_tracer("secureagentrag.retrieval")
        with retrieval_tracer.start_as_current_span("retrieval") as span:
            record = span.set_attribute
            record("retrieval.query_length", len(query))
            record("retrieval.num_results", num_results)
            record("retrieval.latency_ms", latency_ms)
            record("retrieval.method", method)
    except Exception as exc:
        _log.debug("trace_retrieval_failed", error=str(exc))
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def trace_graph_execution(
    query: str,
    nodes_executed: list[str],
    total_latency_ms: float,
    final_confidence: float,
    retries: int = 0,
) -> None:
    """Emit a manual ``graph_execution`` span for one LangGraph pipeline run.

    Does nothing unless the module-level ``_phoenix_configured`` flag is set.
    The query text is not attached to the span, only its length; the executed
    node names are stored as one comma-joined string. Any exception raised
    while tracing is logged at debug level and swallowed.

    Args:
        query: The original user query (only its length is recorded).
        nodes_executed: Names of the graph nodes that were executed.
        total_latency_ms: Total pipeline execution time in milliseconds.
        final_confidence: Final confidence score of the generated answer.
        retries: Number of corrective retrieval retries performed.
    """
    if not _phoenix_configured:
        return

    try:
        from opentelemetry import trace

        graph_tracer = trace.get_tracer("secureagentrag.graph")
        with graph_tracer.start_as_current_span("graph_execution") as span:
            record = span.set_attribute
            record("graph.query_length", len(query))
            record("graph.nodes_executed", ",".join(nodes_executed))
            record("graph.total_latency_ms", total_latency_ms)
            record("graph.confidence", final_confidence)
            record("graph.retries", retries)
    except Exception as exc:
        _log.debug("trace_graph_execution_failed", error=str(exc))
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def get_trace_url() -> str | None:
    """Return the Phoenix dashboard URL if tracing is configured.

    The URL is derived from ``settings.phoenix_endpoint``: trailing slashes
    are stripped and a collector port of ``4317`` is rewritten to the UI port
    ``6006``. Endpoints using any other port are returned as-is (minus the
    trailing slash).

    Returns:
        Phoenix UI URL string, or None if tracing is not configured or no
        endpoint is set.
    """
    if not _phoenix_configured or not settings.phoenix_endpoint:
        return None

    import re

    # Phoenix UI typically runs on the same host as the collector.
    endpoint = settings.phoenix_endpoint.rstrip("/")

    # Swap the port only when ":4317" is a complete port token, i.e. followed
    # by a path/query/fragment delimiter or the end of the string. A plain
    # substring replace would also corrupt values such as ":43170".
    return re.sub(r":4317(?=[/?#]|$)", ":6006", endpoint)
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def is_tracing_enabled() -> bool:
    """Report whether Phoenix tracing is active.

    Returns:
        The current value of the module-level ``_phoenix_configured`` flag.
    """
    return _phoenix_configured
|
utils/pii.py
CHANGED
|
@@ -1,146 +1,146 @@
|
|
| 1 |
-
"""PII redaction for secondary stores (audit log, query cache, conversation history).
|
| 2 |
-
|
| 3 |
-
Two strategies:
|
| 4 |
-
- Regex-based (always available) — covers email, phone, SSN, credit card,
|
| 5 |
-
IBAN, IPv4, URL with credentials.
|
| 6 |
-
- Microsoft Presidio (optional dependency) — invoked when installed; higher
|
| 7 |
-
recall and language-aware NER for names, locations, organizations.
|
| 8 |
-
|
| 9 |
-
This module never sees plaintext from the LLM — it operates only on text that
|
| 10 |
-
is about to be persisted to disk. Live prompts and retrieved contexts remain
|
| 11 |
-
unmodified so model quality is not affected.
|
| 12 |
-
"""
|
| 13 |
-
|
| 14 |
-
from __future__ import annotations
|
| 15 |
-
|
| 16 |
-
import re
|
| 17 |
-
from typing import Any
|
| 18 |
-
|
| 19 |
-
from config.settings import settings
|
| 20 |
-
from utils.logging import get_logger
|
| 21 |
-
|
| 22 |
-
logger = get_logger(__name__)
|
| 23 |
-
|
| 24 |
-
# Order matters — most specific patterns first so they win against the
|
| 25 |
-
# broader phone regex. Provider-specific API-key shapes (added 2026-05-26
|
| 26 |
-
# for BYOK mode) live ABOVE the generic ``[API_KEY]`` rule because their
|
| 27 |
-
# prefixes are not catchable by the legacy ``(sk|pk|api|key)`` alternation.
|
| 28 |
-
_REGEX_PATTERNS: list[tuple[re.Pattern[str], str]] = [
|
| 29 |
-
(re.compile(r"https?://[^\s/]+:[^\s/]+@[^\s]+"), "[URL_WITH_CREDS]"),
|
| 30 |
-
(re.compile(r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}"), "[EMAIL]"),
|
| 31 |
-
(re.compile(r"\b\d{3}-\d{2}-\d{4}\b"), "[SSN]"),
|
| 32 |
-
(re.compile(r"\b(?:\d[ -]*?){13,19}\b"), "[CC]"), # Luhn-validated below
|
| 33 |
-
(re.compile(r"\b[A-Z]{2}\d{2}[A-Z0-9]{10,30}\b"), "[IBAN]"),
|
| 34 |
-
(re.compile(r"\b(?:\d{1,3}\.){3}\d{1,3}\b"), "[IP]"),
|
| 35 |
-
# ── BYOK key shapes (P6 production launch) ──────────────────────────
|
| 36 |
-
# Anthropic must come BEFORE OpenAI because ``sk-ant-...`` also matches
|
| 37 |
-
# the generic ``sk-...`` rule below.
|
| 38 |
-
(re.compile(r"\bsk-ant-[A-Za-z0-9_-]{20,}\b"), "[API_KEY]"),
|
| 39 |
-
(re.compile(r"\bsk-(?:proj|svcacct)-[A-Za-z0-9_-]{20,}\b"), "[API_KEY]"),
|
| 40 |
-
(re.compile(r"\bgsk_[A-Za-z0-9]{40,}\b"), "[API_KEY]"),
|
| 41 |
-
(re.compile(r"\bhf_[A-Za-z0-9]{30,}\b"), "[API_KEY]"),
|
| 42 |
-
(re.compile(r"\bvcp_[A-Za-z0-9]{20,}\b"), "[API_KEY]"),
|
| 43 |
-
# JWT-format database API keys (Qdrant Cloud auth v2). Three dot-separated
|
| 44 |
-
# base64url segments — the middle one is always ``eyJ...`` start.
|
| 45 |
-
(re.compile(r"\beyJ[A-Za-z0-9_-]+\.eyJ[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+\b"), "[API_KEY]"),
|
| 46 |
-
# Qdrant Cloud management keys: ``<uuid>|<token>``.
|
| 47 |
-
(
|
| 48 |
-
re.compile(
|
| 49 |
-
r"\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\|[A-Za-z0-9_-]{20,}\b"
|
| 50 |
-
),
|
| 51 |
-
"[API_KEY]",
|
| 52 |
-
),
|
| 53 |
-
# Legacy generic — keeps catching ``sk-...`` and ``api_...`` shapes from
|
| 54 |
-
# older docs and tests.
|
| 55 |
-
(re.compile(r"\b(?:sk|pk|api|key)[-_][A-Za-z0-9_-]{16,}\b", re.IGNORECASE), "[API_KEY]"),
|
| 56 |
-
(re.compile(r"\b(?:\+?\d{1,3}[-.\s]?)?\(?\d{2,4}\)?[-.\s]?\d{3,4}[-.\s]?\d{3,4}\b"), "[PHONE]"),
|
| 57 |
-
]
|
| 58 |
-
|
| 59 |
-
# Try Presidio for richer detection (names, locations, etc.)
|
| 60 |
-
try:
|
| 61 |
-
from presidio_analyzer import AnalyzerEngine # type: ignore[import-not-found]
|
| 62 |
-
from presidio_anonymizer import AnonymizerEngine # type: ignore[import-not-found]
|
| 63 |
-
|
| 64 |
-
_PRESIDIO_AVAILABLE = True
|
| 65 |
-
_analyzer: Any = AnalyzerEngine()
|
| 66 |
-
_anonymizer: Any = AnonymizerEngine()
|
| 67 |
-
except Exception:
|
| 68 |
-
_PRESIDIO_AVAILABLE = False
|
| 69 |
-
_analyzer = None
|
| 70 |
-
_anonymizer = None
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
def _luhn_valid(num: str) -> bool:
|
| 74 |
-
"""Luhn checksum to filter false-positive credit-card matches."""
|
| 75 |
-
digits = [int(c) for c in num if c.isdigit()]
|
| 76 |
-
if not (13 <= len(digits) <= 19):
|
| 77 |
-
return False
|
| 78 |
-
s = 0
|
| 79 |
-
for i, d in enumerate(reversed(digits)):
|
| 80 |
-
if i % 2 == 1:
|
| 81 |
-
d *= 2
|
| 82 |
-
if d > 9:
|
| 83 |
-
d -= 9
|
| 84 |
-
s += d
|
| 85 |
-
return s % 10 == 0
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
def redact(text: str) -> str:
|
| 89 |
-
"""Return ``text`` with PII tokens masked.
|
| 90 |
-
|
| 91 |
-
Args:
|
| 92 |
-
text: Arbitrary string that may contain PII.
|
| 93 |
-
|
| 94 |
-
Returns:
|
| 95 |
-
Redacted copy of the text. If redaction is disabled via settings
|
| 96 |
-
the original string is returned unchanged.
|
| 97 |
-
"""
|
| 98 |
-
if not settings.pii_redaction_enabled or not text:
|
| 99 |
-
return text
|
| 100 |
-
|
| 101 |
-
out = text
|
| 102 |
-
for pattern, token in _REGEX_PATTERNS:
|
| 103 |
-
if token == "[CC]":
|
| 104 |
-
# Apply Luhn to avoid over-masking phone numbers / arbitrary digits.
|
| 105 |
-
out = pattern.sub(lambda m: "[CC]" if _luhn_valid(m.group(0)) else m.group(0), out)
|
| 106 |
-
else:
|
| 107 |
-
out = pattern.sub(token, out)
|
| 108 |
-
|
| 109 |
-
if _PRESIDIO_AVAILABLE and _analyzer is not None and _anonymizer is not None:
|
| 110 |
-
try:
|
| 111 |
-
results = _analyzer.analyze(text=out, language="en")
|
| 112 |
-
if results:
|
| 113 |
-
out = _anonymizer.anonymize(text=out, analyzer_results=results).text
|
| 114 |
-
except Exception as exc:
|
| 115 |
-
logger.debug("presidio_redact_failed", error=str(exc))
|
| 116 |
-
|
| 117 |
-
return out
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
def redact_dict(data: dict[str, Any], fields: tuple[str, ...] | None = None) -> dict[str, Any]:
|
| 121 |
-
"""Recursively redact string values in a dict.
|
| 122 |
-
|
| 123 |
-
Args:
|
| 124 |
-
data: Dict (possibly nested) to redact.
|
| 125 |
-
fields: If given, only redact these top-level keys. Otherwise redact
|
| 126 |
-
every string in the structure.
|
| 127 |
-
|
| 128 |
-
Returns:
|
| 129 |
-
Deep-redacted copy.
|
| 130 |
-
"""
|
| 131 |
-
if not settings.pii_redaction_enabled:
|
| 132 |
-
return data
|
| 133 |
-
|
| 134 |
-
def _walk(value: Any, *, force: bool) -> Any:
|
| 135 |
-
if isinstance(value, str):
|
| 136 |
-
return redact(value) if force else value
|
| 137 |
-
if isinstance(value, dict):
|
| 138 |
-
return {
|
| 139 |
-
k: _walk(v, force=force or (fields is not None and k in fields))
|
| 140 |
-
for k, v in value.items()
|
| 141 |
-
}
|
| 142 |
-
if isinstance(value, list):
|
| 143 |
-
return [_walk(v, force=force) for v in value]
|
| 144 |
-
return value
|
| 145 |
-
|
| 146 |
-
return _walk(data, force=fields is None)
|
|
|
|
| 1 |
+
"""PII redaction for secondary stores (audit log, query cache, conversation history).
|
| 2 |
+
|
| 3 |
+
Two strategies:
|
| 4 |
+
- Regex-based (always available) — covers email, phone, SSN, credit card,
|
| 5 |
+
IBAN, IPv4, URL with credentials.
|
| 6 |
+
- Microsoft Presidio (optional dependency) — invoked when installed; higher
|
| 7 |
+
recall and language-aware NER for names, locations, organizations.
|
| 8 |
+
|
| 9 |
+
This module never sees plaintext from the LLM — it operates only on text that
|
| 10 |
+
is about to be persisted to disk. Live prompts and retrieved contexts remain
|
| 11 |
+
unmodified so model quality is not affected.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import re
|
| 17 |
+
from typing import Any
|
| 18 |
+
|
| 19 |
+
from config.settings import settings
|
| 20 |
+
from utils.logging import get_logger
|
| 21 |
+
|
| 22 |
+
logger = get_logger(__name__)
|
| 23 |
+
|
| 24 |
+
# (pattern, placeholder) pairs, applied top to bottom by ``redact``.
#
# Order matters: the narrower shapes are listed first so they are masked
# before the very broad phone pattern at the bottom gets a chance to match.
# The provider-specific API-key shapes (added 2026-05-26 for BYOK mode) sit
# ahead of the legacy generic ``(sk|pk|api|key)`` rule because several of
# their prefixes (``gsk_``, ``hf_``, ``vcp_``, JWTs, ``<uuid>|<token>``) are
# not covered by that alternation.
_REGEX_PATTERNS: list[tuple[re.Pattern[str], str]] = [
    (re.compile(expression, flags), placeholder)
    for expression, flags, placeholder in (
        (r"https?://[^\s/]+:[^\s/]+@[^\s]+", 0, "[URL_WITH_CREDS]"),
        (r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}", 0, "[EMAIL]"),
        (r"\b\d{3}-\d{2}-\d{4}\b", 0, "[SSN]"),
        # Candidate card numbers; ``redact`` Luhn-checks each match.
        (r"\b(?:\d[ -]*?){13,19}\b", 0, "[CC]"),
        (r"\b[A-Z]{2}\d{2}[A-Z0-9]{10,30}\b", 0, "[IBAN]"),
        (r"\b(?:\d{1,3}\.){3}\d{1,3}\b", 0, "[IP]"),
        # ── BYOK key shapes (P6 production launch) ─────────────────────
        # ``sk-ant-...`` keys are listed ahead of the other ``sk-`` shapes.
        (r"\bsk-ant-[A-Za-z0-9_-]{20,}\b", 0, "[API_KEY]"),
        (r"\bsk-(?:proj|svcacct)-[A-Za-z0-9_-]{20,}\b", 0, "[API_KEY]"),
        (r"\bgsk_[A-Za-z0-9]{40,}\b", 0, "[API_KEY]"),
        (r"\bhf_[A-Za-z0-9]{30,}\b", 0, "[API_KEY]"),
        (r"\bvcp_[A-Za-z0-9]{20,}\b", 0, "[API_KEY]"),
        # JWT-format database API keys (Qdrant Cloud auth v2): three
        # dot-separated base64url segments, the first two starting ``eyJ``.
        (r"\beyJ[A-Za-z0-9_-]+\.eyJ[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+\b", 0, "[API_KEY]"),
        # Qdrant Cloud management keys: ``<uuid>|<token>``.
        (
            r"\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\|[A-Za-z0-9_-]{20,}\b",
            0,
            "[API_KEY]",
        ),
        # Legacy generic rule: keeps catching ``sk-...`` and ``api_...``
        # shapes from older docs and tests (case-insensitive).
        (r"\b(?:sk|pk|api|key)[-_][A-Za-z0-9_-]{16,}\b", re.IGNORECASE, "[API_KEY]"),
        (
            r"\b(?:\+?\d{1,3}[-.\s]?)?\(?\d{2,4}\)?[-.\s]?\d{3,4}[-.\s]?\d{3,4}\b",
            0,
            "[PHONE]",
        ),
    )
]
|
| 58 |
+
|
| 59 |
+
# Optional Presidio engines for richer detection (names, locations, etc.).
# ``_PRESIDIO_AVAILABLE`` is True only when both engines were constructed;
# otherwise both engine handles are None.
try:
    from presidio_analyzer import AnalyzerEngine  # type: ignore[import-not-found]
    from presidio_anonymizer import AnonymizerEngine  # type: ignore[import-not-found]

    _analyzer: Any = AnalyzerEngine()
    _anonymizer: Any = AnonymizerEngine()
    _PRESIDIO_AVAILABLE = True
except ImportError:
    # Presidio is an optional dependency: its absence is an expected
    # configuration, so stay quiet and leave the Presidio pass disabled.
    _PRESIDIO_AVAILABLE = False
    _analyzer = None
    _anonymizer = None
except Exception as exc:
    # Presidio is installed but an engine failed to initialise. Keep the
    # Presidio pass disabled, but surface the reason instead of hiding it:
    # otherwise the loss of coverage would be invisible.
    logger.warning("presidio_init_failed", error=str(exc))
    _PRESIDIO_AVAILABLE = False
    _analyzer = None
    _anonymizer = None
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _luhn_valid(num: str) -> bool:
    """Luhn checksum to filter false-positive credit-card matches.

    Non-digit characters (spaces, hyphens, anything else) are ignored.

    Args:
        num: Candidate card number, possibly containing separators.

    Returns:
        True if ``num`` contains 13 to 19 decimal digits whose Luhn checksum
        is a multiple of 10, False otherwise.
    """
    # ``isdecimal`` rather than ``isdigit``: the latter also accepts characters
    # such as superscript digits ("²") that ``int()`` rejects with ValueError.
    digits = [int(ch) for ch in num if ch.isdecimal()]
    if not (13 <= len(digits) <= 19):
        return False

    checksum = 0
    # Walk from the rightmost digit; every second digit is doubled and, when
    # the result exceeds 9, reduced by 9 (equivalent to summing its digits).
    for position, digit in enumerate(reversed(digits)):
        if position % 2 == 1:
            digit *= 2
            if digit > 9:
                digit -= 9
        checksum += digit
    return checksum % 10 == 0
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def redact(text: str) -> str:
    """Mask PII in ``text`` and return the result.

    Every entry of ``_REGEX_PATTERNS`` is applied in order; credit-card
    candidates are replaced only when they pass the Luhn check. If the
    Presidio engines are available they are then run over the regex-masked
    text; a Presidio failure is logged at debug level and the regex-only
    result is kept.

    Args:
        text: Arbitrary string that may contain PII.

    Returns:
        The masked string. When ``settings.pii_redaction_enabled`` is false or
        ``text`` is empty, ``text`` is returned unchanged.
    """
    if not (settings.pii_redaction_enabled and text):
        return text

    def _mask_card(match: re.Match[str]) -> str:
        # Luhn filter avoids over-masking phone numbers / arbitrary digits.
        candidate = match.group(0)
        return "[CC]" if _luhn_valid(candidate) else candidate

    masked = text
    for regex, placeholder in _REGEX_PATTERNS:
        replacement = _mask_card if placeholder == "[CC]" else placeholder
        masked = regex.sub(replacement, masked)

    presidio_ready = (
        _PRESIDIO_AVAILABLE and _analyzer is not None and _anonymizer is not None
    )
    if not presidio_ready:
        return masked

    try:
        findings = _analyzer.analyze(text=masked, language="en")
        if findings:
            masked = _anonymizer.anonymize(text=masked, analyzer_results=findings).text
    except Exception as exc:
        logger.debug("presidio_redact_failed", error=str(exc))

    return masked
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def redact_dict(data: dict[str, Any], fields: tuple[str, ...] | None = None) -> dict[str, Any]:
    """Redact string values inside a (possibly nested) dict.

    Dicts and lists are rebuilt recursively; values of any other type are
    passed through untouched.

    Args:
        data: Dict (possibly nested) to redact.
        fields: Key names whose values should be redacted. A key is matched at
            whatever nesting depth it appears, and once matched every string
            beneath it is redacted. When None, every string in the structure
            is redacted.

    Returns:
        A rebuilt, redacted structure. When ``settings.pii_redaction_enabled``
        is false, the ``data`` object itself is returned, not a copy.
    """
    if not settings.pii_redaction_enabled:
        return data

    def _scrub(node: Any, *, active: bool) -> Any:
        if isinstance(node, dict):
            scrubbed = {}
            for key, child in node.items():
                # Redaction switches on at a selected key and stays on below it.
                child_active = active or (fields is not None and key in fields)
                scrubbed[key] = _scrub(child, active=child_active)
            return scrubbed
        if isinstance(node, list):
            return [_scrub(item, active=active) for item in node]
        if isinstance(node, str) and active:
            return redact(node)
        return node

    return _scrub(data, active=fields is None)
|
utils/rate_limiter.py
CHANGED
|
@@ -1,524 +1,524 @@
|
|
| 1 |
-
"""Token-bucket rate limiter for API request throttling.
|
| 2 |
-
|
| 3 |
-
Provides per-user and per-endpoint rate limiting to prevent abuse and
|
| 4 |
-
ensure fair resource allocation. Uses an in-memory token bucket algorithm
|
| 5 |
-
with optional Redis backend for distributed deployments.
|
| 6 |
-
"""
|
| 7 |
-
|
| 8 |
-
from __future__ import annotations
|
| 9 |
-
|
| 10 |
-
import time
|
| 11 |
-
from dataclasses import dataclass, field
|
| 12 |
-
from typing import Any
|
| 13 |
-
|
| 14 |
-
from utils.logging import get_logger
|
| 15 |
-
|
| 16 |
-
logger = get_logger(__name__)
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
@dataclass
class RateLimitConfig:
    """Configuration for a rate limit bucket.

    Attributes:
        requests_per_minute: Sustained refill rate, in tokens per minute.
        burst_size: Maximum burst capacity (bucket size).
        cooldown_seconds: How long a bucket stays blocked after a rejection.
    """

    requests_per_minute: int = 60
    burst_size: int = 10
    cooldown_seconds: float = 1.0


@dataclass
class TokenBucket:
    """In-memory token bucket for rate limiting.

    Timestamps are wall-clock values from ``time.time()``.

    Attributes:
        tokens: Current available tokens.
        last_update: Timestamp of last token refill.
        config: Rate limit configuration.
        blocked_until: Timestamp until which the bucket rejects all requests.
    """

    tokens: float = field(default=0.0)
    last_update: float = field(default_factory=time.time)
    config: RateLimitConfig = field(default_factory=RateLimitConfig)
    blocked_until: float = field(default=0.0)

    def _refill(self) -> None:
        """Refill tokens based on elapsed time since last update."""
        now = time.time()
        # The wall clock can step backwards (NTP correction, manual change).
        # Clamp so a negative interval cannot drain the bucket, or push it
        # below zero, instead of refilling it.
        elapsed = max(0.0, now - self.last_update)
        # Refill rate: tokens per second.
        refill_rate = self.config.requests_per_minute / 60.0
        self.tokens = min(self.config.burst_size, self.tokens + elapsed * refill_rate)
        self.last_update = now

    def consume(self, tokens: float = 1.0) -> tuple[bool, dict[str, Any]]:
        """Attempt to consume tokens from the bucket.

        A rejection for lack of tokens starts a cooldown of
        ``config.cooldown_seconds``; calls made during the cooldown are
        rejected without refilling or touching the token count.

        Args:
            tokens: Number of tokens to consume (default 1 per request).

        Returns:
            Tuple of (allowed, metadata). ``metadata`` has the keys
            ``allowed``, ``remaining`` (whole tokens left), ``retry_after``
            (whole seconds, 0 when allowed) and ``reason`` (None,
            "cooldown_active" or "rate_limit_exceeded").
        """
        now = time.time()

        # Still inside a cooldown window from an earlier rejection.
        if now < self.blocked_until:
            retry_after = int(self.blocked_until - now) + 1
            return False, {
                "allowed": False,
                "remaining": 0,
                "retry_after": retry_after,
                "reason": "cooldown_active",
            }

        self._refill()

        if self.tokens >= tokens:
            self.tokens -= tokens
            remaining = int(self.tokens)
            return True, {
                "allowed": True,
                "remaining": remaining,
                "retry_after": 0,
                "reason": None,
            }

        # Rate limit exceeded: enter cooldown.
        self.blocked_until = now + self.config.cooldown_seconds
        retry_after = int(self.config.cooldown_seconds) + 1
        return False, {
            "allowed": False,
            "remaining": 0,
            "retry_after": retry_after,
            "reason": "rate_limit_exceeded",
        }
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
class RateLimiter:
    """In-process rate limiter that keeps one token bucket per key.

    Buckets live in a plain dict on the instance, so limits are enforced
    per process only.

    Args:
        default_config: Configuration used for keys that are checked without
            an explicit config.
    """

    def __init__(self, default_config: RateLimitConfig | None = None) -> None:
        """Create an empty limiter.

        Args:
            default_config: Fallback configuration for new buckets; a default
                ``RateLimitConfig()`` is used when omitted.
        """
        self._default_config = default_config or RateLimitConfig()
        self._buckets: dict[str, TokenBucket] = {}

    def _get_bucket(self, key: str, config: RateLimitConfig | None = None) -> TokenBucket:
        """Return the bucket stored under ``key``, creating it on first use.

        Args:
            key: Unique identifier for the bucket (e.g., user_id + endpoint).
            config: Configuration for a newly created bucket. It is ignored
                when a bucket for ``key`` already exists.

        Returns:
            The token bucket for the key.
        """
        existing = self._buckets.get(key)
        if existing is not None:
            return existing

        # A new bucket starts full: its token count equals the burst size.
        effective = config or self._default_config
        created = TokenBucket(tokens=effective.burst_size, config=effective)
        self._buckets[key] = created
        return created

    def check_rate_limit(
        self,
        key: str,
        tokens: float = 1.0,
        config: RateLimitConfig | None = None,
    ) -> tuple[bool, dict[str, Any]]:
        """Try to consume ``tokens`` from the bucket for ``key`` and log the outcome.

        Args:
            key: Rate limit bucket key (e.g., "user_123:query").
            tokens: Tokens to consume.
            config: Optional custom config, applied only if the bucket is new.

        Returns:
            Tuple of (allowed, metadata) exactly as returned by the bucket.
        """
        allowed, metadata = self._get_bucket(key, config).consume(tokens)

        if allowed:
            logger.debug(
                "rate_limit_allowed",
                key=key,
                remaining=metadata["remaining"],
            )
        else:
            logger.warning(
                "rate_limit_exceeded",
                key=key,
                retry_after=metadata["retry_after"],
                reason=metadata["reason"],
            )

        return allowed, metadata

    def is_allowed(self, key: str, tokens: float = 1.0) -> bool:
        """Boolean-only variant of :meth:`check_rate_limit`.

        Args:
            key: Rate limit bucket key.
            tokens: Tokens to consume.

        Returns:
            True if within rate limit, False otherwise.
        """
        return self.check_rate_limit(key, tokens)[0]

    def get_status(self, key: str) -> dict[str, Any]:
        """Report the current state of a key without consuming tokens.

        Args:
            key: Rate limit bucket key.

        Returns:
            Dict with ``remaining`` tokens, the per-minute ``limit`` and
            ``reset`` (seconds of cooldown left). Unknown keys report the
            default configuration with ``reset`` 0.
        """
        bucket = self._buckets.get(key)
        if bucket is None:
            defaults = self._default_config
            return {
                "remaining": defaults.burst_size,
                "limit": defaults.requests_per_minute,
                "reset": 0,
            }

        # Bring the token count up to date before reporting it.
        bucket._refill()
        cooldown_left = max(0, bucket.blocked_until - time.time())
        return {
            "remaining": int(bucket.tokens),
            "limit": bucket.config.requests_per_minute,
            "reset": int(cooldown_left),
        }

    def reset(self, key: str) -> None:
        """Forget the bucket stored under ``key``, if there is one.

        Args:
            key: Bucket key to reset.
        """
        self._buckets.pop(key, None)
        logger.info("rate_limit_reset", key=key)
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
class OwnerKeyHourThrottle:
    """Per-IP hourly throttle for the BYOK owner-key fallback.

    Distinct from the request-level :class:`RateLimiter` because the BYOK
    semantics are different:

    - Visitors who bring their own LLM key (``ByokCreds.has_user_key()``)
      bypass this throttle entirely — they are paying for their own tokens.
    - Visitors who do NOT bring a key fall back to the platform owner's
      Groq key. This throttle exists to stop a single recruiter or curious
      visitor from burning the free-tier 30 RPM / 14,400 RPD budget.

    The timestamps of allowed requests are kept in a small list per IP.
    Entries older than 3600 seconds are pruned on each check, so the quota
    applies to the trailing hour. An IP whose list becomes empty is dropped
    from the table instead of being stored as an empty list.
    """

    __slots__ = ("_buckets", "_quota_per_hour")

    # Length of the throttling window in seconds.
    _WINDOW_SECONDS = 3600.0

    def __init__(self, quota_per_hour: int) -> None:
        """Create a throttle allowing ``quota_per_hour`` requests per IP.

        Args:
            quota_per_hour: Allowed requests per IP per hour. ``0`` denies
                every request (kill switch).

        Raises:
            ValueError: If ``quota_per_hour`` is negative.
        """
        if quota_per_hour < 0:
            raise ValueError("quota_per_hour must be non-negative")
        self._quota_per_hour = quota_per_hour
        self._buckets: dict[str, list[float]] = {}

    def allow(self, ip: str, *, now: float | None = None) -> tuple[bool, dict[str, Any]]:
        """Return whether ``ip`` may consume one owner-key request.

        Args:
            ip: Client IP address (use ``"anon"`` when unavailable so the
                fallback path still throttles instead of leaking quota).
            now: Optional monotonic clock override for tests.

        Returns:
            ``(allowed, meta)`` where ``meta`` carries ``remaining`` and
            ``retry_after`` seconds, ready for an HTTP 429 response.
        """
        t = now if now is not None else time.monotonic()
        window = self._WINDOW_SECONDS

        # Prune entries older than the window, then count what is left.
        bucket = [ts for ts in self._buckets.get(ip, ()) if t - ts < window]

        if len(bucket) >= self._quota_per_hour:
            if bucket:
                # The next slot frees up when the oldest entry leaves the window.
                self._buckets[ip] = bucket  # write pruned list back
                retry_after = max(1, int(window - (t - bucket[0])) + 1)
            else:
                # quota_per_hour == 0 (kill switch): there is no "oldest
                # entry" to expire, so report a full window. Do not keep an
                # empty list for this IP; that would grow the table by one
                # entry per distinct client.
                self._buckets.pop(ip, None)
                retry_after = int(window)
            return False, {
                "allowed": False,
                "remaining": 0,
                "retry_after": retry_after,
                "reason": "owner_key_hourly_quota_exhausted",
            }

        bucket.append(t)
        self._buckets[ip] = bucket
        return True, {
            "allowed": True,
            "remaining": self._quota_per_hour - len(bucket),
            "retry_after": 0,
            "reason": None,
        }

    def reset(self, ip: str) -> None:
        """Drop all timestamps for ``ip`` (test/cleanup helper)."""
        self._buckets.pop(ip, None)

    def reset_all(self) -> None:
        """Drop every bucket — used between test cases to avoid leakage."""
        self._buckets.clear()
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
# Process-wide throttle instance. It stays ``None`` until the first call to
# get_owner_key_throttle(), so unit tests that monkey-patch
# SAR_BYOK_OWNER_QUOTA beforehand see the right value.
_owner_key_throttle: OwnerKeyHourThrottle | None = None


def get_owner_key_throttle() -> OwnerKeyHourThrottle:
    """Return the shared owner-key throttle, building it on first use.

    The quota is read from ``settings.byok_owner_key_quota_per_hour`` when
    the instance is created. Tests that need a different quota should call
    :func:`reset_owner_key_throttle` after patching the setting.
    """
    global _owner_key_throttle
    if _owner_key_throttle is not None:
        return _owner_key_throttle

    # Imported here rather than at module top to avoid an import cycle.
    from config.settings import settings

    quota = settings.byok_owner_key_quota_per_hour
    _owner_key_throttle = OwnerKeyHourThrottle(quota_per_hour=quota)
    return _owner_key_throttle
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
def reset_owner_key_throttle() -> None:
    """Discard the cached owner-key throttle.

    The next :func:`get_owner_key_throttle` call then rebuilds the instance
    from settings. Test-only hook; production code never calls this.
    """
    global _owner_key_throttle
    _owner_key_throttle = None
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
class RedisRateLimiter:
    """Distributed rate limiter backed by Redis.

    Each key maps to a sorted set ``ratelimit:<key>`` whose scores are the
    ``time.time()`` timestamps of accepted requests. Entries older than 60
    seconds are trimmed on every check, which gives a sliding one-minute
    window.

    NOTE(review): the trim, count and insert steps are separate Redis
    commands, so concurrent callers may briefly exceed a limit — confirm
    whether that is acceptable for the deployment.

    Args:
        redis_url: Redis connection URL.
        default_config: Default rate limit configuration.
    """

    # Length of the sliding window in seconds.
    _WINDOW_SECONDS = 60.0

    def __init__(
        self,
        redis_url: str | None = None,
        default_config: RateLimitConfig | None = None,
    ) -> None:
        """Initialize the Redis rate limiter.

        Args:
            redis_url: Redis connection URL. Falls back to settings.
            default_config: Default configuration for new keys.
        """
        import redis

        from config.settings import settings

        self._redis = redis.from_url(redis_url or settings.redis_url)
        self._default_config = default_config or RateLimitConfig()

    @staticmethod
    def _denied(retry_after: int) -> dict[str, Any]:
        """Build the metadata dict returned for a rejected request."""
        return {
            "allowed": False,
            "remaining": 0,
            "retry_after": retry_after,
            "reason": "rate_limit_exceeded",
        }

    def _seconds_until_oldest_expires(self, redis_key: str, now: float) -> int:
        """Return whole seconds until the oldest entry leaves the window (0 if none)."""
        oldest = self._redis.zrange(redis_key, 0, 0, withscores=True)
        if not oldest:
            return 0
        wait = float(oldest[0][1]) + self._WINDOW_SECONDS - now
        return int(wait) + 1 if wait > 0 else 0

    def check_rate_limit(
        self,
        key: str,
        tokens: float = 1.0,
        config: RateLimitConfig | None = None,
    ) -> tuple[bool, dict[str, Any]]:
        """Check if a request is within the rate limit using Redis.

        Uses a sliding window algorithm based on Redis sorted sets.

        Args:
            key: Rate limit bucket key.
            tokens: Accepted for signature parity with ``RateLimiter``; this
                implementation records exactly one entry per call whatever
                its value.
            config: Optional custom config.

        Returns:
            Tuple of (allowed, metadata).
        """
        import uuid

        cfg = config or self._default_config
        now = time.time()
        redis_key = f"ratelimit:{key}"

        # Remove old entries outside the window, then count what remains.
        self._redis.zremrangebyscore(redis_key, 0, now - self._WINDOW_SECONDS)
        current_count = self._redis.zcard(redis_key)

        # Burst limit: report the configured cooldown.
        if current_count >= cfg.burst_size:
            return False, self._denied(int(cfg.cooldown_seconds) + 1)

        # Per-minute rate: a slot frees up when the oldest entry leaves the
        # sliding window, not at the next wall-clock minute.
        rpm_limit = cfg.requests_per_minute
        if current_count >= rpm_limit:
            wait = self._seconds_until_oldest_expires(redis_key, now)
            return False, self._denied(max(1, wait))

        # Record this request. The member needs a unique suffix: sorted-set
        # members are unique, so two requests sharing a timestamp would
        # otherwise be counted as one.
        member = f"{now!r}:{uuid.uuid4().hex}"
        self._redis.zadd(redis_key, {member: now})
        # Let idle keys disappear from Redis.
        self._redis.expire(redis_key, 120)

        remaining = min(cfg.burst_size, rpm_limit) - current_count - 1
        return True, {
            "allowed": True,
            "remaining": max(0, remaining),
            "retry_after": 0,
            "reason": None,
        }

    def is_allowed(self, key: str, tokens: float = 1.0) -> bool:
        """Simple check — returns True if request is allowed.

        Args:
            key: Rate limit bucket key.
            tokens: Tokens to consume.

        Returns:
            True if within rate limit, False otherwise.
        """
        allowed, _ = self.check_rate_limit(key, tokens)
        return allowed

    def get_status(self, key: str) -> dict[str, Any]:
        """Get current rate limit status for a key.

        The default configuration is always used here, because a per-call
        config passed to :meth:`check_rate_limit` is not stored.

        Args:
            key: Rate limit bucket key.

        Returns:
            Dict with ``remaining`` requests, the per-minute ``limit`` and
            ``reset`` (seconds until the oldest entry expires, 0 if none).
        """
        cfg = self._default_config
        redis_key = f"ratelimit:{key}"
        now = time.time()

        self._redis.zremrangebyscore(redis_key, 0, now - self._WINDOW_SECONDS)
        current_count = self._redis.zcard(redis_key)

        # Same capacity rule as check_rate_limit, so both agree on "remaining".
        capacity = min(cfg.burst_size, cfg.requests_per_minute)
        return {
            "remaining": max(0, capacity - current_count),
            "limit": cfg.requests_per_minute,
            "reset": self._seconds_until_oldest_expires(redis_key, now),
        }

    def reset(self, key: str) -> None:
        """Reset a specific rate limit bucket.

        Args:
            key: Bucket key to reset.
        """
        self._redis.delete(f"ratelimit:{key}")
        logger.info("redis_rate_limit_reset", key=key)
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
def _get_rate_limiter() -> RateLimiter | RedisRateLimiter:
    """Get the appropriate rate limiter based on configuration.

    When ``settings.use_redis_rate_limiter`` is set, a Redis-backed limiter
    is built and the server is pinged. Any failure is logged and the
    in-memory limiter is used instead.

    Returns:
        RateLimiter (in-memory) or RedisRateLimiter (distributed).
    """
    from config.settings import settings

    if settings.use_redis_rate_limiter:
        try:
            limiter = RedisRateLimiter()
            # redis.from_url() does not open a connection by itself, so an
            # unreachable server would only fail on the first request.
            # Ping now so the fallback below can actually take effect.
            limiter._redis.ping()
        except Exception as exc:  # fallback boundary: log and degrade
            logger.warning("redis_rate_limiter_failed", error=str(exc), fallback="memory")
        else:
            return limiter
    return RateLimiter(default_config=RATE_LIMIT_PROFILES["default"])
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
# Named rate-limit presets, looked up by profile name.
RATE_LIMIT_PROFILES: dict[str, RateLimitConfig] = dict(
    default=RateLimitConfig(requests_per_minute=60, burst_size=10),
    strict=RateLimitConfig(requests_per_minute=10, burst_size=3, cooldown_seconds=5.0),
    generous=RateLimitConfig(requests_per_minute=300, burst_size=50),
    upload=RateLimitConfig(requests_per_minute=5, burst_size=2, cooldown_seconds=10.0),
    query=RateLimitConfig(requests_per_minute=30, burst_size=5, cooldown_seconds=2.0),
)

# Shared limiter instance; stays None until it is first requested.
_rate_limiter_instance: RateLimiter | RedisRateLimiter | None = None
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
def _get_limiter() -> RateLimiter | RedisRateLimiter:
    """Return the module-wide rate limiter, creating it on the first call.

    Returns:
        The configured rate limiter.
    """
    global _rate_limiter_instance
    if _rate_limiter_instance is not None:
        return _rate_limiter_instance

    _rate_limiter_instance = _get_rate_limiter()
    return _rate_limiter_instance
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
def check_query_rate_limit(user_id: str) -> tuple[bool, dict[str, Any]]:
    """Apply the ``"query"`` profile to a user's query request.

    Args:
        user_id: The user making the query.

    Returns:
        Tuple of (allowed, metadata).
    """
    bucket_key = f"{user_id}:query"
    limiter = _get_limiter()
    profile = RATE_LIMIT_PROFILES["query"]
    return limiter.check_rate_limit(bucket_key, config=profile)
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
def check_upload_rate_limit(user_id: str) -> tuple[bool, dict[str, Any]]:
    """Apply the ``"upload"`` profile to a user's document upload.

    Args:
        user_id: The user uploading.

    Returns:
        Tuple of (allowed, metadata).
    """
    bucket_key = f"{user_id}:upload"
    limiter = _get_limiter()
    profile = RATE_LIMIT_PROFILES["upload"]
    return limiter.check_rate_limit(bucket_key, config=profile)
|
|
|
|
| 1 |
+
"""Token-bucket rate limiter for API request throttling.
|
| 2 |
+
|
| 3 |
+
Provides per-user and per-endpoint rate limiting to prevent abuse and
|
| 4 |
+
ensure fair resource allocation. Uses an in-memory token bucket algorithm
|
| 5 |
+
with optional Redis backend for distributed deployments.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import time
|
| 11 |
+
from dataclasses import dataclass, field
|
| 12 |
+
from typing import Any
|
| 13 |
+
|
| 14 |
+
from utils.logging import get_logger
|
| 15 |
+
|
| 16 |
+
logger = get_logger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
class RateLimitConfig:
    """Limits applied to a single rate-limit bucket.

    Attributes:
        requests_per_minute: Maximum requests allowed per minute.
        burst_size: Maximum burst capacity (bucket size).
        cooldown_seconds: Seconds to wait after being rate limited.
    """

    # Sustained request rate, per minute.
    requests_per_minute: int = 60
    # Bucket capacity.
    burst_size: int = 10
    # Block duration applied once the limit is exceeded.
    cooldown_seconds: float = 1.0
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@dataclass
class TokenBucket:
    """In-memory token bucket for rate limiting.

    Tokens refill continuously at ``config.requests_per_minute / 60`` per
    second, capped at ``config.burst_size``. A request that finds too few
    tokens puts the bucket into a cooldown of ``config.cooldown_seconds``,
    during which every request is rejected.

    Attributes:
        tokens: Current available tokens.
        last_update: Timestamp of last token refill.
        config: Rate limit configuration.
        blocked_until: Timestamp when the bucket is unblocked.
    """

    tokens: float = field(default=0.0)
    last_update: float = field(default_factory=time.time)
    config: RateLimitConfig = field(default_factory=RateLimitConfig)
    blocked_until: float = field(default=0.0)

    def _refill(self) -> None:
        """Refill tokens based on elapsed time since last update."""
        now = time.time()
        # time.time() is wall-clock time and may step backwards. Clamp the
        # interval so a backwards step cannot subtract tokens.
        elapsed = max(0.0, now - self.last_update)
        # Refill rate: tokens per second
        refill_rate = self.config.requests_per_minute / 60.0
        self.tokens = min(self.config.burst_size, self.tokens + elapsed * refill_rate)
        self.last_update = now

    def consume(self, tokens: float = 1.0) -> tuple[bool, dict[str, Any]]:
        """Attempt to consume tokens from the bucket.

        Args:
            tokens: Number of tokens to consume (default 1 per request).

        Returns:
            Tuple of (allowed, metadata) where metadata contains
            ``allowed``, ``remaining``, ``retry_after`` and ``reason``.

        Raises:
            ValueError: If ``tokens`` is negative; a negative amount would
                add tokens to the bucket instead of spending them.
        """
        if tokens < 0:
            raise ValueError("tokens must be non-negative")

        now = time.time()

        # Check if currently blocked
        if now < self.blocked_until:
            retry_after = int(self.blocked_until - now) + 1
            return False, {
                "allowed": False,
                "remaining": 0,
                "retry_after": retry_after,
                "reason": "cooldown_active",
            }

        self._refill()

        if self.tokens >= tokens:
            self.tokens -= tokens
            remaining = int(self.tokens)
            return True, {
                "allowed": True,
                "remaining": remaining,
                "retry_after": 0,
                "reason": None,
            }

        # Rate limit exceeded — enter cooldown
        self.blocked_until = now + self.config.cooldown_seconds
        retry_after = int(self.config.cooldown_seconds) + 1
        return False, {
            "allowed": False,
            "remaining": 0,
            "retry_after": retry_after,
            "reason": "rate_limit_exceeded",
        }
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class RateLimiter:
    """In-process rate limiter that keeps one token bucket per key.

    Buckets live in a plain dict on the instance, so limits are enforced
    per process only.

    Args:
        default_config: Configuration used for keys that are checked without
            an explicit config.
    """

    def __init__(self, default_config: RateLimitConfig | None = None) -> None:
        """Create an empty limiter.

        Args:
            default_config: Fallback configuration for new buckets; a default
                ``RateLimitConfig()`` is used when omitted.
        """
        self._default_config = default_config or RateLimitConfig()
        self._buckets: dict[str, TokenBucket] = {}

    def _get_bucket(self, key: str, config: RateLimitConfig | None = None) -> TokenBucket:
        """Return the bucket stored under ``key``, creating it on first use.

        Args:
            key: Unique identifier for the bucket (e.g., user_id + endpoint).
            config: Configuration for a newly created bucket. It is ignored
                when a bucket for ``key`` already exists.

        Returns:
            The token bucket for the key.
        """
        existing = self._buckets.get(key)
        if existing is not None:
            return existing

        # A new bucket starts full: its token count equals the burst size.
        effective = config or self._default_config
        created = TokenBucket(tokens=effective.burst_size, config=effective)
        self._buckets[key] = created
        return created

    def check_rate_limit(
        self,
        key: str,
        tokens: float = 1.0,
        config: RateLimitConfig | None = None,
    ) -> tuple[bool, dict[str, Any]]:
        """Try to consume ``tokens`` from the bucket for ``key`` and log the outcome.

        Args:
            key: Rate limit bucket key (e.g., "user_123:query").
            tokens: Tokens to consume.
            config: Optional custom config, applied only if the bucket is new.

        Returns:
            Tuple of (allowed, metadata) exactly as returned by the bucket.
        """
        allowed, metadata = self._get_bucket(key, config).consume(tokens)

        if allowed:
            logger.debug(
                "rate_limit_allowed",
                key=key,
                remaining=metadata["remaining"],
            )
        else:
            logger.warning(
                "rate_limit_exceeded",
                key=key,
                retry_after=metadata["retry_after"],
                reason=metadata["reason"],
            )

        return allowed, metadata

    def is_allowed(self, key: str, tokens: float = 1.0) -> bool:
        """Boolean-only variant of :meth:`check_rate_limit`.

        Args:
            key: Rate limit bucket key.
            tokens: Tokens to consume.

        Returns:
            True if within rate limit, False otherwise.
        """
        return self.check_rate_limit(key, tokens)[0]

    def get_status(self, key: str) -> dict[str, Any]:
        """Report the current state of a key without consuming tokens.

        Args:
            key: Rate limit bucket key.

        Returns:
            Dict with ``remaining`` tokens, the per-minute ``limit`` and
            ``reset`` (seconds of cooldown left). Unknown keys report the
            default configuration with ``reset`` 0.
        """
        bucket = self._buckets.get(key)
        if bucket is None:
            defaults = self._default_config
            return {
                "remaining": defaults.burst_size,
                "limit": defaults.requests_per_minute,
                "reset": 0,
            }

        # Bring the token count up to date before reporting it.
        bucket._refill()
        cooldown_left = max(0, bucket.blocked_until - time.time())
        return {
            "remaining": int(bucket.tokens),
            "limit": bucket.config.requests_per_minute,
            "reset": int(cooldown_left),
        }

    def reset(self, key: str) -> None:
        """Forget the bucket stored under ``key``, if there is one.

        Args:
            key: Bucket key to reset.
        """
        self._buckets.pop(key, None)
        logger.info("rate_limit_reset", key=key)
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
class OwnerKeyHourThrottle:
    """Per-IP hourly throttle for the BYOK owner-key fallback.

    Distinct from the request-level :class:`RateLimiter` because the BYOK
    semantics are different:

    - Visitors who bring their own LLM key (``ByokCreds.has_user_key()``)
      bypass this throttle entirely — they are paying for their own tokens.
    - Visitors who do NOT bring a key fall back to the platform owner's
      Groq key. This throttle exists to stop a single recruiter or curious
      visitor from burning the free-tier 30 RPM / 14,400 RPD budget.

    The timestamps of allowed requests are kept in a small list per IP.
    Entries older than 3600 seconds are pruned on each check, so the quota
    applies to the trailing hour. An IP whose list becomes empty is dropped
    from the table instead of being stored as an empty list.
    """

    __slots__ = ("_buckets", "_quota_per_hour")

    # Length of the throttling window in seconds.
    _WINDOW_SECONDS = 3600.0

    def __init__(self, quota_per_hour: int) -> None:
        """Create a throttle allowing ``quota_per_hour`` requests per IP.

        Args:
            quota_per_hour: Allowed requests per IP per hour. ``0`` denies
                every request (kill switch).

        Raises:
            ValueError: If ``quota_per_hour`` is negative.
        """
        if quota_per_hour < 0:
            raise ValueError("quota_per_hour must be non-negative")
        self._quota_per_hour = quota_per_hour
        self._buckets: dict[str, list[float]] = {}

    def allow(self, ip: str, *, now: float | None = None) -> tuple[bool, dict[str, Any]]:
        """Return whether ``ip`` may consume one owner-key request.

        Args:
            ip: Client IP address (use ``"anon"`` when unavailable so the
                fallback path still throttles instead of leaking quota).
            now: Optional monotonic clock override for tests.

        Returns:
            ``(allowed, meta)`` where ``meta`` carries ``remaining`` and
            ``retry_after`` seconds, ready for an HTTP 429 response.
        """
        t = now if now is not None else time.monotonic()
        window = self._WINDOW_SECONDS

        # Prune entries older than the window, then count what is left.
        bucket = [ts for ts in self._buckets.get(ip, ()) if t - ts < window]

        if len(bucket) >= self._quota_per_hour:
            if bucket:
                # The next slot frees up when the oldest entry leaves the window.
                self._buckets[ip] = bucket  # write pruned list back
                retry_after = max(1, int(window - (t - bucket[0])) + 1)
            else:
                # quota_per_hour == 0 (kill switch): there is no "oldest
                # entry" to expire, so report a full window. Do not keep an
                # empty list for this IP; that would grow the table by one
                # entry per distinct client.
                self._buckets.pop(ip, None)
                retry_after = int(window)
            return False, {
                "allowed": False,
                "remaining": 0,
                "retry_after": retry_after,
                "reason": "owner_key_hourly_quota_exhausted",
            }

        bucket.append(t)
        self._buckets[ip] = bucket
        return True, {
            "allowed": True,
            "remaining": self._quota_per_hour - len(bucket),
            "retry_after": 0,
            "reason": None,
        }

    def reset(self, ip: str) -> None:
        """Drop all timestamps for ``ip`` (test/cleanup helper)."""
        self._buckets.pop(ip, None)

    def reset_all(self) -> None:
        """Drop every bucket — used between test cases to avoid leakage."""
        self._buckets.clear()
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
# Process-wide throttle instance. It stays ``None`` until the first call to
# get_owner_key_throttle(), so unit tests that monkey-patch
# SAR_BYOK_OWNER_QUOTA beforehand see the right value.
_owner_key_throttle: OwnerKeyHourThrottle | None = None


def get_owner_key_throttle() -> OwnerKeyHourThrottle:
    """Return the shared owner-key throttle, building it on first use.

    The quota is read from ``settings.byok_owner_key_quota_per_hour`` when
    the instance is created. Tests that need a different quota should call
    :func:`reset_owner_key_throttle` after patching the setting.
    """
    global _owner_key_throttle
    if _owner_key_throttle is not None:
        return _owner_key_throttle

    # Imported here rather than at module top to avoid an import cycle.
    from config.settings import settings

    quota = settings.byok_owner_key_quota_per_hour
    _owner_key_throttle = OwnerKeyHourThrottle(quota_per_hour=quota)
    return _owner_key_throttle
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def reset_owner_key_throttle() -> None:
    """Discard the cached owner-key throttle.

    The next :func:`get_owner_key_throttle` call then rebuilds the instance
    from settings. Test-only hook; production code never calls this.
    """
    global _owner_key_throttle
    _owner_key_throttle = None
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
class RedisRateLimiter:
|
| 325 |
+
"""Distributed rate limiter backed by Redis.
|
| 326 |
+
|
| 327 |
+
Uses Redis sorted sets with sliding window algorithm for accurate
|
| 328 |
+
per-user rate limiting across multiple application instances.
|
| 329 |
+
|
| 330 |
+
Args:
|
| 331 |
+
redis_url: Redis connection URL.
|
| 332 |
+
default_config: Default rate limit configuration.
|
| 333 |
+
"""
|
| 334 |
+
|
| 335 |
+
def __init__(
|
| 336 |
+
self,
|
| 337 |
+
redis_url: str | None = None,
|
| 338 |
+
default_config: RateLimitConfig | None = None,
|
| 339 |
+
) -> None:
|
| 340 |
+
"""Initialize the Redis rate limiter.
|
| 341 |
+
|
| 342 |
+
Args:
|
| 343 |
+
redis_url: Redis connection URL. Falls back to settings.
|
| 344 |
+
default_config: Default configuration for new keys.
|
| 345 |
+
"""
|
| 346 |
+
import redis
|
| 347 |
+
|
| 348 |
+
from config.settings import settings
|
| 349 |
+
|
| 350 |
+
self._redis = redis.from_url(redis_url or settings.redis_url)
|
| 351 |
+
self._default_config = default_config or RateLimitConfig()
|
| 352 |
+
|
| 353 |
+
def check_rate_limit(
|
| 354 |
+
self,
|
| 355 |
+
key: str,
|
| 356 |
+
tokens: float = 1.0,
|
| 357 |
+
config: RateLimitConfig | None = None,
|
| 358 |
+
) -> tuple[bool, dict[str, Any]]:
|
| 359 |
+
"""Check if a request is within the rate limit using Redis.
|
| 360 |
+
|
| 361 |
+
Uses a sliding window algorithm based on Redis sorted sets.
|
| 362 |
+
|
| 363 |
+
Args:
|
| 364 |
+
key: Rate limit bucket key.
|
| 365 |
+
tokens: Tokens to consume.
|
| 366 |
+
config: Optional custom config.
|
| 367 |
+
|
| 368 |
+
Returns:
|
| 369 |
+
Tuple of (allowed, metadata).
|
| 370 |
+
"""
|
| 371 |
+
cfg = config or self._default_config
|
| 372 |
+
now = time.time()
|
| 373 |
+
window_start = now - 60.0 # 1-minute window
|
| 374 |
+
redis_key = f"ratelimit:{key}"
|
| 375 |
+
|
| 376 |
+
# Remove old entries outside the window
|
| 377 |
+
self._redis.zremrangebyscore(redis_key, 0, window_start)
|
| 378 |
+
|
| 379 |
+
# Count current requests in window
|
| 380 |
+
current_count = self._redis.zcard(redis_key)
|
| 381 |
+
|
| 382 |
+
# Check burst limit
|
| 383 |
+
if current_count >= cfg.burst_size:
|
| 384 |
+
retry_after = int(cfg.cooldown_seconds) + 1
|
| 385 |
+
return False, {
|
| 386 |
+
"allowed": False,
|
| 387 |
+
"remaining": 0,
|
| 388 |
+
"retry_after": retry_after,
|
| 389 |
+
"reason": "rate_limit_exceeded",
|
| 390 |
+
}
|
| 391 |
+
|
| 392 |
+
# Check per-minute rate
|
| 393 |
+
rpm_limit = cfg.requests_per_minute
|
| 394 |
+
if current_count >= rpm_limit:
|
| 395 |
+
retry_after = int(60 - (now % 60)) + 1
|
| 396 |
+
return False, {
|
| 397 |
+
"allowed": False,
|
| 398 |
+
"remaining": 0,
|
| 399 |
+
"retry_after": retry_after,
|
| 400 |
+
"reason": "rate_limit_exceeded",
|
| 401 |
+
}
|
| 402 |
+
|
| 403 |
+
# Record this request
|
| 404 |
+
self._redis.zadd(redis_key, {str(now): now})
|
| 405 |
+
# Set expiry on the key
|
| 406 |
+
self._redis.expire(redis_key, 120)
|
| 407 |
+
|
| 408 |
+
remaining = min(cfg.burst_size, rpm_limit) - current_count - 1
|
| 409 |
+
return True, {
|
| 410 |
+
"allowed": True,
|
| 411 |
+
"remaining": max(0, remaining),
|
| 412 |
+
"retry_after": 0,
|
| 413 |
+
"reason": None,
|
| 414 |
+
}
|
| 415 |
+
|
| 416 |
+
def is_allowed(self, key: str, tokens: float = 1.0) -> bool:
|
| 417 |
+
"""Simple check — returns True if request is allowed.
|
| 418 |
+
|
| 419 |
+
Args:
|
| 420 |
+
key: Rate limit bucket key.
|
| 421 |
+
tokens: Tokens to consume.
|
| 422 |
+
|
| 423 |
+
Returns:
|
| 424 |
+
True if within rate limit, False otherwise.
|
| 425 |
+
"""
|
| 426 |
+
allowed, _ = self.check_rate_limit(key, tokens)
|
| 427 |
+
return allowed
|
| 428 |
+
|
| 429 |
+
def get_status(self, key: str) -> dict[str, Any]:
|
| 430 |
+
"""Get current rate limit status for a key.
|
| 431 |
+
|
| 432 |
+
Args:
|
| 433 |
+
key: Rate limit bucket key.
|
| 434 |
+
|
| 435 |
+
Returns:
|
| 436 |
+
Dict with remaining tokens, limit, and reset time.
|
| 437 |
+
"""
|
| 438 |
+
redis_key = f"ratelimit:{key}"
|
| 439 |
+
now = time.time()
|
| 440 |
+
window_start = now - 60.0
|
| 441 |
+
self._redis.zremrangebyscore(redis_key, 0, window_start)
|
| 442 |
+
current_count = self._redis.zcard(redis_key)
|
| 443 |
+
remaining = max(0, self._default_config.burst_size - current_count)
|
| 444 |
+
return {
|
| 445 |
+
"remaining": remaining,
|
| 446 |
+
"limit": self._default_config.requests_per_minute,
|
| 447 |
+
"reset": int(60 - (now % 60)),
|
| 448 |
+
}
|
| 449 |
+
|
| 450 |
+
def reset(self, key: str) -> None:
    """Delete a rate limit bucket from Redis and log the reset.

    Args:
        key: Bucket key to reset (stored under ``ratelimit:<key>``).
    """
    bucket = f"ratelimit:{key}"
    self._redis.delete(bucket)
    # The log records the caller-facing key, not the prefixed Redis key.
    logger.info("redis_rate_limit_reset", key=key)
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
def _get_rate_limiter() -> RateLimiter | RedisRateLimiter:
    """Build the rate limiter selected by configuration.

    When ``settings.use_redis_rate_limiter`` is truthy a ``RedisRateLimiter``
    is attempted first; if constructing it raises, a warning is logged and
    the in-memory limiter is used instead.

    Returns:
        RedisRateLimiter (distributed) when enabled and constructible,
        otherwise RateLimiter (in-memory) with the "default" profile.
    """
    # Imported lazily, inside the function, exactly as the original does.
    from config.settings import settings

    limiter: RateLimiter | RedisRateLimiter | None = None
    if settings.use_redis_rate_limiter:
        try:
            limiter = RedisRateLimiter()
        except Exception as err:
            # Best-effort: any Redis setup failure degrades to the in-memory limiter.
            logger.warning("redis_rate_limiter_failed", error=str(err), fallback="memory")

    if limiter is None:
        limiter = RateLimiter(default_config=RATE_LIMIT_PROFILES["default"])
    return limiter
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
# Named rate limit profiles, looked up by key by the helpers in this module.
RATE_LIMIT_PROFILES: dict[str, RateLimitConfig] = {
    "default": RateLimitConfig(
        requests_per_minute=60,
        burst_size=10,
    ),
    "strict": RateLimitConfig(
        requests_per_minute=10,
        burst_size=3,
        cooldown_seconds=5.0,
    ),
    "generous": RateLimitConfig(
        requests_per_minute=300,
        burst_size=50,
    ),
    "upload": RateLimitConfig(
        requests_per_minute=5,
        burst_size=2,
        cooldown_seconds=10.0,
    ),
    "query": RateLimitConfig(
        requests_per_minute=30,
        burst_size=5,
        cooldown_seconds=2.0,
    ),
}

# Process-wide limiter singleton; stays None until it is first requested.
_rate_limiter_instance: RateLimiter | RedisRateLimiter | None = None
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
def _get_limiter() -> RateLimiter | RedisRateLimiter:
    """Return the shared rate limiter, creating it on first use.

    Returns:
        The module-level limiter instance; it is built once via
        ``_get_rate_limiter()`` and reused on every later call.
    """
    global _rate_limiter_instance
    if _rate_limiter_instance is not None:
        return _rate_limiter_instance

    # First call: build the limiter and cache it in the module-level slot.
    _rate_limiter_instance = _get_rate_limiter()
    return _rate_limiter_instance
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
def check_query_rate_limit(user_id: str) -> tuple[bool, dict[str, Any]]:
    """Apply the "query" rate limit profile to a user's query.

    Args:
        user_id: The user making the query; the bucket key is
            ``<user_id>:query``.

    Returns:
        Tuple of (allowed, metadata) as produced by the limiter's
        ``check_rate_limit``.
    """
    bucket = f"{user_id}:query"
    limiter = _get_limiter()
    return limiter.check_rate_limit(bucket, config=RATE_LIMIT_PROFILES["query"])
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
def check_upload_rate_limit(user_id: str) -> tuple[bool, dict[str, Any]]:
    """Apply the "upload" rate limit profile to a document upload.

    Args:
        user_id: The user uploading; the bucket key is ``<user_id>:upload``.

    Returns:
        Tuple of (allowed, metadata) as produced by the limiter's
        ``check_rate_limit``.
    """
    bucket = f"{user_id}:upload"
    limiter = _get_limiter()
    return limiter.check_rate_limit(bucket, config=RATE_LIMIT_PROFILES["upload"])
|