LeomordKaly committed on
Commit
f4ef3b8
·
verified ·
1 Parent(s): b2a2d08

deploy: phase 3 BYOK backend (Dockerfile.hf, FastAPI on 7860)

Browse files
Dockerfile.hf CHANGED
@@ -1,151 +1,151 @@
1
- # =============================================================================
2
- # Dockerfile.hf — SecureAgentRAG backend for Hugging Face Spaces (CPU Basic).
3
- # =============================================================================
4
- # Two-stage build keeps the runtime image lean. The HF Space free tier is
5
- # CPU-only with 16 GB RAM and ~50 GB ephemeral disk, so we target a tight
6
- # memory footprint:
7
- #
8
- # - Python 3.11-slim base (~150 MB)
9
- # - Only [api, embeddings-local, pii] extras (no OCR, no Phoenix, no Postgres,
10
- # no Redis, no MCP) -- those modules are present in the source but their
11
- # dependencies are not installed
12
- # - cross-encoder reranker downloaded on first request (auto-cached under
13
- # /home/user/.cache/huggingface). Skips the 2.3 GB fine-tuned checkpoint
14
- # for the initial deploy; phase 3.2 can swap to fine_tuned once the
15
- # reranker repo is published on HF Hub.
16
- #
17
- # The Space-side README.md is uploaded separately by scripts/deploy_hf_space.py
18
- # with a YAML frontmatter declaring sdk=docker + app_port=7860.
19
- # =============================================================================
20
-
21
- # --- builder ----------------------------------------------------------------
22
- FROM python:3.11-slim AS builder
23
-
24
- WORKDIR /app
25
-
26
- RUN pip install --no-cache-dir uv
27
-
28
- # pyproject.toml + a copy of the source are required for uv to build the
29
- # editable install. README.md is referenced as the long_description.
30
- COPY pyproject.toml ./
31
- COPY README.md ./
32
-
33
- # Touch the package directories that hatchling treats as the wheel root --
34
- # we only need the directory tree to exist at build time so hatchling can
35
- # scan for __init__.py files. The actual code lands in the runtime stage.
36
- RUN mkdir -p config core inference retrieval interfaces ingestion utils evaluation app \
37
- && touch config/__init__.py core/__init__.py inference/__init__.py \
38
- && touch retrieval/__init__.py interfaces/__init__.py ingestion/__init__.py \
39
- && touch utils/__init__.py evaluation/__init__.py app/__init__.py
40
-
41
- # Intentionally skip [pii] extras -- the regex patterns in utils/pii.py
42
- # already cover every BYOK key shape (Groq / OpenAI / Anthropic / HF / Vercel
43
- # / Qdrant JWT / Qdrant management). Adding Presidio would pull spaCy
44
- # en_core_web_lg (~770 MB) which auto-downloads at runtime and crashes the
45
- # container on the CPU Basic Space when the package installer is absent.
46
- RUN uv venv /app/.venv \
47
- && uv pip install --python /app/.venv/bin/python \
48
- -e ".[api,embeddings-local]"
49
-
50
- # --- runtime ----------------------------------------------------------------
51
- FROM python:3.11-slim AS runtime
52
-
53
- WORKDIR /app
54
-
55
- # HF Spaces convention: run as uid 1000 with a writeable /home/user.
56
- RUN useradd -m -u 1000 user
57
-
58
- # System deps for PDF / image processing only -- no OCR / paddle.
59
- # Debian 12+ (trixie) renamed libgl1-mesa-glx -> libgl1 and libxrender-dev
60
- # is no longer needed at runtime (runtime is libxrender1).
61
- RUN apt-get update \
62
- && apt-get install -y --no-install-recommends \
63
- libglib2.0-0 libsm6 libxext6 libxrender1 libgl1 curl \
64
- && rm -rf /var/lib/apt/lists/*
65
-
66
- # Bring the virtualenv from the builder stage.
67
- COPY --from=builder /app/.venv /app/.venv
68
- ENV PATH="/app/.venv/bin:$PATH"
69
-
70
- # Copy application source. Files that match .dockerignore are filtered out.
71
- COPY --chown=user:user . /app
72
-
73
- USER user
74
-
75
- # Pre-populate the HF cache so the cross-encoder lives on disk before the
76
- # first request. Defensive: never fails the build -- if HF Hub is unreachable
77
- # during build (offline mirrors etc.) the cache is populated on first query.
78
- RUN python -c "import os; \
79
- from huggingface_hub import snapshot_download; \
80
- import sys; \
81
- try: snapshot_download(repo_id='BAAI/bge-reranker-v2-m3', cache_dir='/home/user/.cache/huggingface/hub'); print('reranker cached') \
82
- except Exception as e: print(f'reranker cache skipped: {e!r}', file=sys.stderr)" \
83
- || echo "build-time reranker download failed -- will lazy-load on first request"
84
-
85
- # --- BYOK production env ---------------------------------------------------
86
- # Real secrets (Qdrant URL + API key, Groq key) are injected via HF Space
87
- # secrets panel -- they ride the same SAR_* env-var protocol but are NOT
88
- # baked into the image. Only mode flags and safe defaults live here.
89
- ENV SAR_BYOK_MODE=true
90
- ENV SAR_BYOK_OWNER_QUOTA=3
91
- ENV SAR_SESSION_TTL_HOURS=24
92
- ENV SAR_CORS_ALLOW_ORIGINS='["https://app.eilm.live","https://secureagentrag-web.vercel.app","https://secureagentrag.vercel.app"]'
93
-
94
- # Cloud LLM defaults -- Groq llama-3.1-8b-instant is the cheapest fast option
95
- # on the free tier. Visitor BYOK overrides this per request.
96
- ENV SAR_DEFAULT_PROVIDER=groq
97
- ENV SAR_CLOUD_PROVIDER=groq
98
- ENV SAR_LLM_MODEL=llama-3.1-8b-instant
99
-
100
- # Embedding stack -- local BGE-M3 via sentence-transformers (CPU). Avoids
101
- # Ollama entirely.
102
- ENV SAR_EMBEDDING_BACKEND=local
103
- ENV SAR_LOCAL_EMBEDDING_MODEL=BAAI/bge-m3
104
- ENV SAR_EMBEDDING_MODEL=bge-m3
105
- ENV SAR_EMBEDDING_DIM=1024
106
-
107
- # Cross-encoder reranker -- balances quality with build size. Swap to
108
- # fine_tuned + SAR_FINETUNED_RERANKER_PATH after phase 3.2 ships the
109
- # 2.3 GB checkpoint to LeomordKaly/secureagentrag-reranker-v1.
110
- ENV SAR_RERANKER_TYPE=cross_encoder
111
- ENV SAR_RERANKER_CHECKPOINT=BAAI/bge-reranker-v2-m3
112
-
113
- # Sparse retrieval -- BM25 keeps the cold path zero-dep; SPLADE adds an
114
- # extra ~600 MB model and is skipped on free CPU Basic.
115
- ENV SAR_SPARSE_BACKEND=bm25
116
-
117
- # Persistence paths -- /tmp is the only writable area on HF Spaces.
118
- ENV SAR_AUDIT_LOG_DIR=/tmp/secureagentrag/audit_logs
119
- ENV SAR_CONVERSATION_DIR=/tmp/secureagentrag/conversations
120
- ENV SAR_CHECKPOINT_DB_PATH=/tmp/secureagentrag/checkpoints.sqlite
121
- ENV SAR_BM25_INDEX_PATH=/tmp/secureagentrag/bm25_index.pkl
122
-
123
- # Multi-tenant collections route BYOK session -> documents_sess_<sid>.
124
- ENV SAR_MULTI_TENANT_COLLECTIONS=true
125
-
126
- # Pipeline safety
127
- ENV SAR_REQUEST_TIMEOUT_S=120
128
- ENV SAR_FAITHFULNESS_GATE_ENABLED=true
129
- ENV SAR_FAITHFULNESS_GATE_MODE=flag
130
- ENV SAR_FAITHFULNESS_THRESHOLD=0.7
131
-
132
- # Logging
133
- ENV SAR_LOG_LEVEL=INFO
134
-
135
- # HF cache lives under the user home which is the only persistent writable
136
- # tree across Space restarts on CPU Basic.
137
- ENV HF_HOME=/home/user/.cache/huggingface
138
- ENV TRANSFORMERS_CACHE=/home/user/.cache/huggingface/hub
139
-
140
- EXPOSE 7860
141
-
142
- HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
143
- CMD curl --fail --silent --show-error http://localhost:7860/healthz || exit 1
144
-
145
- # uvicorn with 1 worker -- on CPU Basic two workers thrash the memory.
146
- CMD ["uvicorn", "interfaces.api:app", \
147
- "--host", "0.0.0.0", \
148
- "--port", "7860", \
149
- "--workers", "1", \
150
- "--timeout-keep-alive", "30", \
151
- "--no-access-log"]
 
1
+ # =============================================================================
2
+ # Dockerfile.hf — SecureAgentRAG backend for Hugging Face Spaces (CPU Basic).
3
+ # =============================================================================
4
+ # Two-stage build keeps the runtime image lean. The HF Space free tier is
5
+ # CPU-only with 16 GB RAM and ~50 GB ephemeral disk, so we target a tight
6
+ # memory footprint:
7
+ #
8
+ # - Python 3.11-slim base (~150 MB)
9
+ # - Only [api, embeddings-local] extras (no PII/Presidio, no OCR, no Phoenix, no Postgres,
10
+ # no Redis, no MCP) -- those modules are present in the source but their
11
+ # dependencies are not installed
12
+ # - cross-encoder reranker downloaded on first request (auto-cached under
13
+ # /home/user/.cache/huggingface). Skips the 2.3 GB fine-tuned checkpoint
14
+ # for the initial deploy; phase 3.2 can swap to fine_tuned once the
15
+ # reranker repo is published on HF Hub.
16
+ #
17
+ # The Space-side README.md is uploaded separately by scripts/deploy_hf_space.py
18
+ # with a YAML frontmatter declaring sdk=docker + app_port=7860.
19
+ # =============================================================================
20
+
21
+ # --- builder ----------------------------------------------------------------
22
+ FROM python:3.11-slim AS builder
23
+
24
+ WORKDIR /app
25
+
26
+ RUN pip install --no-cache-dir uv
27
+
28
+ # pyproject.toml + a copy of the source are required for uv to build the
29
+ # editable install. README.md is referenced as the long_description.
30
+ COPY pyproject.toml ./
31
+ COPY README.md ./
32
+
33
+ # Touch the package directories that hatchling treats as the wheel root --
34
+ # we only need the directory tree to exist at build time so hatchling can
35
+ # scan for __init__.py files. The actual code lands in the runtime stage.
36
+ RUN mkdir -p config core inference retrieval interfaces ingestion utils evaluation app \
37
+ && touch config/__init__.py core/__init__.py inference/__init__.py \
38
+ && touch retrieval/__init__.py interfaces/__init__.py ingestion/__init__.py \
39
+ && touch utils/__init__.py evaluation/__init__.py app/__init__.py
40
+
41
+ # Intentionally skip [pii] extras -- the regex patterns in utils/pii.py
42
+ # already cover every BYOK key shape (Groq / OpenAI / Anthropic / HF / Vercel
43
+ # / Qdrant JWT / Qdrant management). Adding Presidio would pull spaCy
44
+ # en_core_web_lg (~770 MB) which auto-downloads at runtime and crashes the
45
+ # container on the CPU Basic Space when the package installer is absent.
46
+ RUN uv venv /app/.venv \
47
+ && uv pip install --python /app/.venv/bin/python \
48
+ -e ".[api,embeddings-local]"
49
+
50
+ # --- runtime ----------------------------------------------------------------
51
+ FROM python:3.11-slim AS runtime
52
+
53
+ WORKDIR /app
54
+
55
+ # HF Spaces convention: run as uid 1000 with a writeable /home/user.
56
+ RUN useradd -m -u 1000 user
57
+
58
+ # System deps for PDF / image processing only -- no OCR / paddle.
59
+ # Debian 13 (trixie) dropped libgl1-mesa-glx in favour of libgl1, and libxrender-dev
60
+ # is no longer needed at runtime (runtime is libxrender1).
61
+ RUN apt-get update \
62
+ && apt-get install -y --no-install-recommends \
63
+ libglib2.0-0 libsm6 libxext6 libxrender1 libgl1 curl \
64
+ && rm -rf /var/lib/apt/lists/*
65
+
66
+ # Bring the virtualenv from the builder stage.
67
+ COPY --from=builder /app/.venv /app/.venv
68
+ ENV PATH="/app/.venv/bin:$PATH"
69
+
70
+ # Copy application source. Files that match .dockerignore are filtered out.
71
+ COPY --chown=user:user . /app
72
+
73
+ USER user
74
+
75
+ # Pre-populate the HF cache so the cross-encoder lives on disk before the
76
+ # first request. Defensive: never fails the build -- if HF Hub is unreachable
77
+ # during build (offline mirrors etc.) the cache is populated on first query.
78
+ RUN python -c "import os; \
79
+ from huggingface_hub import snapshot_download; \
80
+ import sys; \
81
+ try: snapshot_download(repo_id='BAAI/bge-reranker-v2-m3', cache_dir='/home/user/.cache/huggingface/hub'); print('reranker cached') \
82
+ except Exception as e: print(f'reranker cache skipped: {e!r}', file=sys.stderr)" \
83
+ || echo "build-time reranker download failed -- will lazy-load on first request"
84
+
85
+ # --- BYOK production env ---------------------------------------------------
86
+ # Real secrets (Qdrant URL + API key, Groq key) are injected via HF Space
87
+ # secrets panel -- they ride the same SAR_* env-var protocol but are NOT
88
+ # baked into the image. Only mode flags and safe defaults live here.
89
+ ENV SAR_BYOK_MODE=true
90
+ ENV SAR_BYOK_OWNER_QUOTA=3
91
+ ENV SAR_SESSION_TTL_HOURS=24
92
+ ENV SAR_CORS_ALLOW_ORIGINS='["https://app.eilm.live","https://secureagentrag-web.vercel.app","https://secureagentrag.vercel.app"]'
93
+
94
+ # Cloud LLM defaults -- Groq llama-3.1-8b-instant is the cheapest fast option
95
+ # on the free tier. Visitor BYOK overrides this per request.
96
+ ENV SAR_DEFAULT_PROVIDER=groq
97
+ ENV SAR_CLOUD_PROVIDER=groq
98
+ ENV SAR_LLM_MODEL=llama-3.1-8b-instant
99
+
100
+ # Embedding stack -- local BGE-M3 via sentence-transformers (CPU). Avoids
101
+ # Ollama entirely.
102
+ ENV SAR_EMBEDDING_BACKEND=local
103
+ ENV SAR_LOCAL_EMBEDDING_MODEL=BAAI/bge-m3
104
+ ENV SAR_EMBEDDING_MODEL=bge-m3
105
+ ENV SAR_EMBEDDING_DIM=1024
106
+
107
+ # Cross-encoder reranker -- balances quality with build size. Swap to
108
+ # fine_tuned + SAR_FINETUNED_RERANKER_PATH after phase 3.2 ships the
109
+ # 2.3 GB checkpoint to LeomordKaly/secureagentrag-reranker-v1.
110
+ ENV SAR_RERANKER_TYPE=cross_encoder
111
+ ENV SAR_RERANKER_CHECKPOINT=BAAI/bge-reranker-v2-m3
112
+
113
+ # Sparse retrieval -- BM25 keeps the cold path zero-dep; SPLADE adds an
114
+ # extra ~600 MB model and is skipped on free CPU Basic.
115
+ ENV SAR_SPARSE_BACKEND=bm25
116
+
117
+ # Persistence paths -- kept under /tmp, which is writable on HF Spaces.
118
+ ENV SAR_AUDIT_LOG_DIR=/tmp/secureagentrag/audit_logs
119
+ ENV SAR_CONVERSATION_DIR=/tmp/secureagentrag/conversations
120
+ ENV SAR_CHECKPOINT_DB_PATH=/tmp/secureagentrag/checkpoints.sqlite
121
+ ENV SAR_BM25_INDEX_PATH=/tmp/secureagentrag/bm25_index.pkl
122
+
123
+ # Multi-tenant collections route BYOK session -> documents_sess_<sid>.
124
+ ENV SAR_MULTI_TENANT_COLLECTIONS=true
125
+
126
+ # Pipeline safety
127
+ ENV SAR_REQUEST_TIMEOUT_S=120
128
+ ENV SAR_FAITHFULNESS_GATE_ENABLED=true
129
+ ENV SAR_FAITHFULNESS_GATE_MODE=flag
130
+ ENV SAR_FAITHFULNESS_THRESHOLD=0.7
131
+
132
+ # Logging
133
+ ENV SAR_LOG_LEVEL=INFO
134
+
135
+ # HF cache lives under the user home, which the runtime user (uid 1000) can
136
+ # write to; the Space disk itself is ephemeral.
137
+ ENV HF_HOME=/home/user/.cache/huggingface
138
+ ENV TRANSFORMERS_CACHE=/home/user/.cache/huggingface/hub
139
+
140
+ EXPOSE 7860
141
+
142
+ HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
143
+ CMD curl --fail --silent --show-error http://localhost:7860/healthz || exit 1
144
+
145
+ # uvicorn with 1 worker -- on CPU Basic two workers thrash the memory.
146
+ CMD ["uvicorn", "interfaces.api:app", \
147
+ "--host", "0.0.0.0", \
148
+ "--port", "7860", \
149
+ "--workers", "1", \
150
+ "--timeout-keep-alive", "30", \
151
+ "--no-access-log"]
config/settings.py CHANGED
@@ -1,316 +1,316 @@
1
- """Application settings managed via pydantic-settings with environment variable support."""
2
-
3
- from __future__ import annotations
4
-
5
- import contextlib
6
- import json
7
- import os
8
- from pathlib import Path
9
-
10
- from pydantic_settings import BaseSettings, SettingsConfigDict
11
-
12
-
13
- class Settings(BaseSettings):
14
- """Central configuration for SecureAgentRAG.
15
-
16
- All settings can be overridden via environment variables prefixed with ``SAR_``.
17
- For example, ``SAR_DEBUG=true`` sets ``debug`` to True.
18
- """
19
-
20
- model_config = SettingsConfigDict(
21
- env_file=".env",
22
- env_prefix="SAR_",
23
- env_file_encoding="utf-8",
24
- case_sensitive=False,
25
- extra="ignore",
26
- )
27
-
28
- # ── Application ──────────────────────────────────────────────────────────────
29
- app_name: str = "SecureAgentRAG"
30
- debug: bool = False
31
- log_level: str = "INFO"
32
-
33
- # ── Qdrant Vector Store ─────────────────────────────────────────────────────
34
- qdrant_url: str = "http://localhost:6333"
35
- qdrant_collection: str = "documents"
36
- qdrant_api_key: str | None = None
37
-
38
- # ── Ollama / LLM ─────────────────────────────────────────────────────────────
39
- ollama_url: str = "http://localhost:11434"
40
- llm_model: str = "qwen3:8b"
41
- embedding_model: str = "bge-m3"
42
- embedding_dim: int = 1024
43
- embedding_backend: str = "ollama" # "ollama" or "local" (sentence-transformers)
44
- local_embedding_model: str = "BAAI/bge-m3"
45
- # How long Ollama keeps models resident in VRAM between requests.
46
- # On consumer hardware the LLM (qwen3:8b ~5.5GB) and embedding (bge-m3 ~1.2GB)
47
- # need to swap if VRAM is tight. Long keep-alive avoids ~5-10s reload per swap.
48
- ollama_keep_alive: str = "30m"
49
-
50
- # ── Chunking ─────────────────────────────────────────────────────────────────
51
- chunk_size: int = 1000
52
- chunk_overlap: int = 200
53
-
54
- # ── Retrieval ────────────────────────────────────────────────────────────────
55
- top_k: int = 10
56
- rerank_top_k: int = 5
57
- relevance_threshold: float = 0.7
58
- # RAG Fusion: generate N query reformulations, retrieve in parallel,
59
- # fuse the ranked lists via RRF. Boosts recall on under-specified
60
- # queries. Cost: N-1 extra LLM calls + N parallel Qdrant searches.
61
- # Set to 1 to disable.
62
- rag_fusion_n_queries: int = 3
63
- rag_fusion_enabled: bool = True
64
- # ── Reranker ─────────────────────────────────────────────────────────────────
65
- # Re-score retrieved documents for higher precision.
66
- # Options: "none" (disabled), "cross_encoder" (BGE-Reranker-v2-M3),
67
- # "colbert" (ColBERTv2 late-interaction, requires colbert-ai package).
68
- # The cross-encoder downloads ~600MB from HuggingFace on first use.
69
- # The ColBERT checkpoint is ~400MB. Disabled by default so the first
70
- # query does not silently hang on download. Pre-download explicitly.
71
- reranker_type: str = "none"
72
- reranker_checkpoint: str = "BAAI/bge-reranker-v2-m3"
73
- colbert_checkpoint: str = "colbert-ir/colbertv2.0"
74
- # Path to a locally fine-tuned cross-encoder checkpoint produced by
75
- # scripts/train_reranker.py. Used when reranker_type == "fine_tuned".
76
- finetuned_reranker_path: str = "data/checkpoints/reranker-domain-v1"
77
-
78
- # ── Inference Providers ──────────────────────────────────────────────────────
79
- default_provider: str = "ollama"
80
- cloud_provider: str | None = None
81
- groq_api_key: str | None = None
82
- openai_api_key: str | None = None
83
- anthropic_api_key: str | None = None
84
- groq_api_base: str = "https://api.groq.com/openai/v1"
85
- openai_api_base: str = "https://api.openai.com/v1"
86
- anthropic_api_base: str = "https://api.anthropic.com/v1"
87
-
88
- # ── RAG Pipeline Thresholds ───────────────────────────────────────────────────
89
- relevance_retry_threshold: float = 0.5
90
- confidence_threshold: float = 0.6
91
- max_retries: int = 2
92
-
93
- # ── JSON Citations ────────────────────────────────────────────────────────────
94
- # When enabled, the synthesizer requests structured JSON output from the LLM
95
- # with `answer` and `citations` fields instead of relying on regex extraction.
96
- json_citations_enabled: bool = False
97
-
98
- # ── Embedding Batch Size ──────────────────────────────────────────────────────
99
- embedding_batch_size: int = 32 # Max texts per embedding API call
100
- embedding_max_concurrent_batches: int = 4 # Max concurrent batch requests
101
-
102
- # ── RBAC ─────────────────────────────────────────────────────────────────────
103
- enable_rbac: bool = True
104
-
105
- # ── Observability (Phoenix) ──────────────────────────────────────────────────
106
- phoenix_endpoint: str | None = None
107
-
108
- # ── Sparse Vectors (Qdrant native, replaces rank_bm25 pickle) ───────────────
109
- sparse_backend: str = "bm25" # "bm25" | "splade"
110
- sparse_vector_name: str = "sparse"
111
- sparse_model: str = "naver/splade-cocondenser-ensembledistil"
112
-
113
- # ── Audit + Conversation Storage ──────────────────────────────────────────────
114
- audit_log_dir: str = "audit_logs"
115
- conversation_dir: str = "conversations"
116
- checkpoint_db_path: str = "data/checkpoints.sqlite"
117
- # Opt-in: enable persistent (SQLite/Postgres) LangGraph checkpointing.
118
- # Default off because pytest-asyncio creates per-test event loops which
119
- # collide with aiosqlite's loop-bound connection. For production single-
120
- # process Streamlit / FastAPI deployments, set SAR_USE_PERSISTENT_CHECKPOINTER=true.
121
- use_persistent_checkpointer: bool = False
122
-
123
- # ── PostgreSQL (for LangGraph checkpointing) ─────────────────────────────────
124
- postgres_url: str = "postgresql://sar_user:sar_password@localhost:5433/secureagentrag"
125
-
126
- # ── Pipeline SLO ─────────────────────────────────────────────────────────────
127
- # Hard wall-clock budget for a single RAG pipeline run (rewrite loop +
128
- # retrieval + grading + synthesis + evaluation). On timeout the caller
129
- # gets a graceful refusal + audit entry; nothing partial is rendered as
130
- # if the answer succeeded. 0 disables the deadline.
131
- request_timeout_s: float = 60.0
132
-
133
- # ── Authentication ───────────────────────────────────────────────────────────
134
- # When ``jwt_secret`` is set the FastAPI / MCP layers verify HS256-signed
135
- # JWTs and derive UserContext from validated claims. When unset, callers
136
- # fall back to the dev-mode base64(json(UserContext)) token shape so
137
- # existing tests and smoke scripts keep working — but a runtime warning is
138
- # emitted on every request. Production deployments MUST set this.
139
- #
140
- # ``jwt_issuer`` / ``jwt_audience`` are checked against ``iss`` / ``aud``
141
- # claims when present. Leave empty to disable that check (default).
142
- # ``jwt_ttl_seconds`` is the lifetime of tokens minted via the local
143
- # ``/token`` dev endpoint; real IdPs (Keycloak/Auth0) set their own.
144
- jwt_secret: str | None = None
145
- jwt_issuer: str = "secureagentrag"
146
- jwt_audience: str = "secureagentrag-api"
147
- jwt_ttl_seconds: int = 3600
148
- jwt_algorithm: str = "HS256"
149
- # JWKS endpoint for RS256 verification (e.g. Keycloak, Auth0).
150
- # When set and jwt_algorithm == "RS256", tokens are verified against
151
- # the cached JWKS instead of jwt_secret.
152
- jwks_url: str | None = None
153
- jwks_cache_ttl_seconds: int = 300
154
-
155
- # ── Citation Faithfulness Gate (NLI) ─────────────────────────────────────────
156
- # After synthesis, run a per-sentence NLI check: for each sentence that
157
- # carries an inline `[N]` citation, ask a yes/no entailment question
158
- # against the cited chunk's text. Sentences that fail are either marked
159
- # `[unsupported]` (soft mode) or dropped from the answer (strict mode).
160
- # The check uses the same local LLM as the rest of the graph — no extra
161
- # model download. Cost: one LLM call per cited sentence (parallel).
162
- faithfulness_gate_enabled: bool = False
163
- faithfulness_gate_mode: str = "flag" # "flag" | "drop"
164
- faithfulness_threshold: float = 0.7 # min entailment ratio to consider answer faithful
165
- faithfulness_max_concurrent: int = 4 # parallel NLI checks
166
-
167
- # ── Redis (for distributed rate limiting / caching) ──────────────────────────
168
- redis_url: str = "redis://localhost:6379/0"
169
- use_redis_rate_limiter: bool = False
170
-
171
- # ── PII Redaction ────────────────────────────────────────────────────────────
172
- # Scrub email, phone, SSN, credit-card, IBAN, IP address before persisting
173
- # to audit log / query cache. Defense against accidental PII leakage into
174
- # secondary stores. Regex-based by default; if Microsoft Presidio is
175
- # installed it is used automatically for higher recall.
176
- pii_redaction_enabled: bool = True
177
-
178
- # ── Prompt-Injection Guardrails ──────────────────────────────────────────────
179
- # Run a regex + heuristic check on the user query before retrieval. Blocks
180
- # obvious jailbreak / system-prompt-override attempts. Logged via the audit
181
- # logger as ``security_block`` events.
182
- guardrails_enabled: bool = True
183
- # Strict mode: after the fast regex gate, escalate ambiguous or all queries
184
- # to a local LLM-based classifier for a second opinion. Adds one LLM call
185
- # per query but catches adversarial inputs that evade regex patterns.
186
- guardrails_strict: bool = False
187
- # Escalation backend used in strict mode. Options:
188
- # "llm" — legacy SAFE/UNSAFE prompt on the synth-grade model
189
- # (core.agents.guardrails_llm). Default for backward
190
- # compatibility.
191
- # "llamaguard" — Meta's LlamaGuard 3 8B via Ollama. Use with
192
- # ``ollama pull llama-guard3:8b``. More accurate on
193
- # the standard S1-S14 taxonomy.
194
- guardrails_backend: str = "llm"
195
- llamaguard_model: str = "llama-guard3:8b"
196
-
197
- # ── Contextual Retrieval (Anthropic 2024 technique) ──────────────────────────
198
- # Prepend a short LLM-generated context summary to each chunk before
199
- # embedding. Adds 1 cheap LLM call per chunk at ingestion time but
200
- # measurably improves retrieval recall (Anthropic reported ~35-49%
201
- # failure reduction). Local Qwen3-8B is fine for the summary.
202
- contextual_retrieval_enabled: bool = False
203
-
204
- # ── VLM OCR (Primary OCR via vision-language model) ───────────────────────────
205
- # Use a VLM (Qwen2.5-VL / Qwen3-VL, LLaVA, etc.) via Ollama as the primary OCR path.
206
- # Superior to PaddleOCR on complex layouts, tables, and mixed-language
207
- # documents. Falls back to PaddleOCR when the VLM is unavailable.
208
- vlm_ocr_enabled: bool = False
209
- vlm_ocr_model: str = "qwen2.5-vl"
210
-
211
- # ── Multi-Tenancy ────────────────────────────────────────────────────────────
212
- # When true, each organization gets its own Qdrant collection
213
- # (documents_{org_id}). This provides stronger isolation than payload-level
214
- # RBAC filtering but requires creating collections per org on first use.
215
- # When false, all docs share a single collection with RBAC at payload level.
216
- multi_tenant_collections: bool = False
217
-
218
- # ── BYOK demo mode (P6 production launch, see launch-plan/03-backend-byok.md)
219
- # In BYOK mode the FastAPI surface accepts per-request LLM keys from visitor
220
- # headers, scopes Qdrant writes to per-session collections, and disables
221
- # Phoenix instrumentation. Off in dev/staging, on in the Hugging Face Space
222
- # production image (SAR_BYOK_MODE=true via Space secrets).
223
- byok_mode: bool = False
224
- # When BYOK is on and a visitor did NOT bring their own LLM key, the owner
225
- # key in .env is used but throttled to this many requests per IP per hour.
226
- # The cap is intentionally tight so the Groq free-tier 30 RPM / 14400 RPD
227
- # is never exhausted by a single visitor.
228
- byok_owner_key_quota_per_hour: int = 3
229
- # Per-session Qdrant collections (documents_sess_<session_id>) are auto
230
- # purged after this many hours by retrieval/session_purge.py.
231
- session_collection_ttl_hours: int = 24
232
- # CORS allowlist consulted by the FastAPI middleware when byok_mode=true.
233
- # Empty list = no CORS middleware mounted (dev default).
234
- cors_allow_origins: list[str] = []
235
-
236
- # ── Multi-Modal RAG ──────────────────────────────────────────────────────────
237
- # When ingesting images, also generate a rich text description using a VLM.
238
- # The description is embedded as a separate chunk, enabling retrieval for
239
- # queries like "what does the diagram show?" without requiring CLIP or
240
- # other multi-modal embedding models.
241
- multimodal_descriptions_enabled: bool = False
242
-
243
- # ── Self-Query Retrieval ─────────────────────────────────────────────────────
244
- # Extract structured metadata filters (source_file, date_range,
245
- # sensitivity_level, roles) from the natural language query using a small
246
- # local LLM prompt. The filters are merged with the RBAC filter and passed
247
- # to Qdrant, scoping retrieval before embedding search runs.
248
- self_query_enabled: bool = False
249
-
250
- # ── HyDE (Hypothetical Document Embeddings) ──────────────────────────────────
251
- # Generate a hypothetical answer to the query, embed *that* instead of the
252
- # raw query. Boosts recall when query vocabulary differs from doc
253
- # vocabulary (questions vs declarative sentences). Adds one LLM call per
254
- # query — skip for simple keyword lookups; enable for complex questions.
255
- hyde_enabled: bool = False
256
-
257
- # ── Pricing for cost dashboard (USD per 1M tokens) ───────────────────────────
258
- # Used by evaluation/cost.py to convert recorded usage into $/query.
259
- price_groq_input_per_1m: float = 0.59
260
- price_groq_output_per_1m: float = 0.79
261
- price_openai_input_per_1m: float = 2.50
262
- price_openai_output_per_1m: float = 10.00
263
- price_anthropic_input_per_1m: float = 3.00
264
- price_anthropic_output_per_1m: float = 15.00
265
- # Local inference: estimated electricity cost only (consumer hardware).
266
- # 200W GPU @ $0.15/kWh ≈ $0.03/hour ≈ $0.000008/sec
267
- price_local_per_second: float = 0.000008
268
-
269
-
270
- def _apply_calibration(settings_obj: Settings) -> None:
271
- """Override threshold defaults from ``evaluation/calibration.json`` when present.
272
-
273
- The calibration script (``scripts/calibrate_thresholds.py``) writes the
274
- chosen confidence + faithfulness cutoffs against a labelled gold set. Loading
275
- them here means deployments inherit the latest tuned values automatically,
276
- while an explicit ``SAR_CONFIDENCE_THRESHOLD`` / ``SAR_FAITHFULNESS_THRESHOLD``
277
- env var still wins so operators can override per environment.
278
-
279
- Silently no-ops when the file is missing, malformed, or the relevant keys
280
- are absent — never blocks startup.
281
- """
282
- calib_path = Path(__file__).resolve().parent.parent / "evaluation" / "calibration.json"
283
- if not calib_path.exists():
284
- return
285
- try:
286
- data = json.loads(calib_path.read_text(encoding="utf-8"))
287
- except (OSError, json.JSONDecodeError):
288
- return
289
-
290
- # Reject degenerate sweeps (no negatives or no positives -> the chosen
291
- # threshold has no statistical meaning). Keeping the original default in
292
- # that case is safer than letting a 0.0 cut-off escape into production.
293
- def _sane(block: dict) -> bool:
294
- try:
295
- return (
296
- int(block.get("n_pos", 0)) > 0
297
- and int(block.get("n_neg", 0)) > 0
298
- and float(block.get("chosen_threshold", 0.0)) > 0.0
299
- )
300
- except (TypeError, ValueError):
301
- return False
302
-
303
- conf_block = data.get("confidence", {})
304
- if _sane(conf_block) and os.environ.get("SAR_CONFIDENCE_THRESHOLD") is None:
305
- with contextlib.suppress(TypeError, ValueError):
306
- settings_obj.confidence_threshold = float(conf_block["chosen_threshold"])
307
-
308
- faith_block = data.get("faithfulness", {})
309
- if _sane(faith_block) and os.environ.get("SAR_FAITHFULNESS_THRESHOLD") is None:
310
- with contextlib.suppress(TypeError, ValueError):
311
- settings_obj.faithfulness_threshold = float(faith_block["chosen_threshold"])
312
-
313
-
314
- # Singleton instance — import this throughout the application
315
- settings = Settings()
316
- _apply_calibration(settings)
 
1
+ """Application settings managed via pydantic-settings with environment variable support."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import contextlib
6
+ import json
7
+ import os
8
+ from pathlib import Path
9
+
10
+ from pydantic_settings import BaseSettings, SettingsConfigDict
11
+
12
+
13
class Settings(BaseSettings):
    """Typed, environment-driven configuration for SecureAgentRAG.

    Each field may be overridden through an environment variable that carries
    the ``SAR_`` prefix (``SAR_DEBUG=true`` switches ``debug`` on, for
    instance). Values are additionally read from a local ``.env`` file,
    variable names are matched case-insensitively, and unknown keys are
    ignored.
    """

    model_config = SettingsConfigDict(
        extra="ignore",
        case_sensitive=False,
        env_prefix="SAR_",
        env_file=".env",
        env_file_encoding="utf-8",
    )

    # ── Application ──────────────────────────────────────────────────────────
    app_name: str = "SecureAgentRAG"
    debug: bool = False
    log_level: str = "INFO"

    # ── Qdrant vector store ──────────────────────────────────────────────────
    qdrant_url: str = "http://localhost:6333"
    qdrant_collection: str = "documents"
    qdrant_api_key: str | None = None

    # ── Ollama / LLM ─────────────────────────────────────────────────────────
    ollama_url: str = "http://localhost:11434"
    llm_model: str = "qwen3:8b"
    embedding_model: str = "bge-m3"
    embedding_dim: int = 1024
    # Either "ollama" or "local" (sentence-transformers).
    embedding_backend: str = "ollama"
    local_embedding_model: str = "BAAI/bge-m3"
    # Time Ollama keeps a model resident in VRAM between requests. When VRAM
    # is tight the LLM (qwen3:8b, ~5.5GB) and the embedder (bge-m3, ~1.2GB)
    # have to swap; a long keep-alive avoids a ~5-10s reload per swap.
    ollama_keep_alive: str = "30m"

    # ── Chunking ─────────────────────────────────────────────────────────────
    chunk_size: int = 1000
    chunk_overlap: int = 200

    # ── Retrieval ────────────────────────────────────────────────────────────
    top_k: int = 10
    rerank_top_k: int = 5
    relevance_threshold: float = 0.7
    # RAG Fusion: produce N reformulations of the query, retrieve for each in
    # parallel and merge the ranked lists with RRF. Helps recall on
    # under-specified queries at the price of N-1 extra LLM calls plus N
    # parallel Qdrant searches. A value of 1 disables it.
    rag_fusion_n_queries: int = 3
    rag_fusion_enabled: bool = True

    # ── Reranker ─────────────────────────────────────────────────────────────
    # Second-pass scoring of retrieved documents for higher precision.
    #   "none"          - disabled
    #   "cross_encoder" - BGE-Reranker-v2-M3 (~600MB HuggingFace download on
    #                     first use)
    #   "colbert"       - ColBERTv2 late interaction (~400MB checkpoint,
    #                     requires the colbert-ai package)
    # Off by default so the first query never hangs silently on a download;
    # pre-download the checkpoint explicitly instead.
    reranker_type: str = "none"
    reranker_checkpoint: str = "BAAI/bge-reranker-v2-m3"
    colbert_checkpoint: str = "colbert-ir/colbertv2.0"
    # Locally fine-tuned cross-encoder written by scripts/train_reranker.py;
    # consulted when reranker_type == "fine_tuned".
    finetuned_reranker_path: str = "data/checkpoints/reranker-domain-v1"

    # ── Inference providers ──────────────────────────────────────────────────
    default_provider: str = "ollama"
    cloud_provider: str | None = None
    groq_api_key: str | None = None
    openai_api_key: str | None = None
    anthropic_api_key: str | None = None
    groq_api_base: str = "https://api.groq.com/openai/v1"
    openai_api_base: str = "https://api.openai.com/v1"
    anthropic_api_base: str = "https://api.anthropic.com/v1"

    # ── RAG pipeline thresholds ──────────────────────────────────────────────
    relevance_retry_threshold: float = 0.5
    confidence_threshold: float = 0.6
    max_retries: int = 2

    # ── JSON citations ───────────────────────────────────────────────────────
    # When on, the synthesizer asks the LLM for structured JSON carrying
    # `answer` and `citations` fields rather than relying on regex extraction.
    json_citations_enabled: bool = False

    # ── Embedding batch size ─────────────────────────────────────────────────
    # Upper bound on texts sent in one embedding API call.
    embedding_batch_size: int = 32
    # Upper bound on batch requests in flight at once.
    embedding_max_concurrent_batches: int = 4

    # ── RBAC ─────────────────────────────────────────────────────────────────
    enable_rbac: bool = True

    # ── Observability (Phoenix) ──────────────────────────────────────────────
    phoenix_endpoint: str | None = None

    # ── Sparse vectors (Qdrant native, replaces the rank_bm25 pickle) ────────
    # One of "bm25" | "splade".
    sparse_backend: str = "bm25"
    sparse_vector_name: str = "sparse"
    sparse_model: str = "naver/splade-cocondenser-ensembledistil"

    # ── Audit + conversation storage ─────────────────────────────────────────
    audit_log_dir: str = "audit_logs"
    conversation_dir: str = "conversations"
    checkpoint_db_path: str = "data/checkpoints.sqlite"
    # Opt-in persistent (SQLite/Postgres) LangGraph checkpointing. Off by
    # default: pytest-asyncio creates one event loop per test, which collides
    # with aiosqlite's loop-bound connection. Single-process Streamlit /
    # FastAPI production deployments should set
    # SAR_USE_PERSISTENT_CHECKPOINTER=true.
    use_persistent_checkpointer: bool = False

    # ── PostgreSQL (LangGraph checkpointing) ─────────────────────────────────
    postgres_url: str = "postgresql://sar_user:sar_password@localhost:5433/secureagentrag"

    # ── Pipeline SLO ─────────────────────────────────────────────────────────
    # Hard wall-clock budget for one RAG pipeline run (rewrite loop, retrieval,
    # grading, synthesis, evaluation). On timeout the caller receives a
    # graceful refusal plus an audit entry; a partial result is never rendered
    # as though the answer had succeeded. 0 turns the deadline off.
    request_timeout_s: float = 60.0

    # ── Authentication ───────────────────────────────────────────────────────
    # With ``jwt_secret`` set, the FastAPI / MCP layers verify HS256-signed
    # JWTs and build UserContext from the validated claims. Without it,
    # callers fall back to the dev-mode base64(json(UserContext)) token shape
    # so existing tests and smoke scripts keep working, and a runtime warning
    # is emitted on every request. Production deployments MUST set it.
    #
    # ``jwt_issuer`` / ``jwt_audience`` are compared with the ``iss`` / ``aud``
    # claims when present; leave them empty to skip that check.
    # ``jwt_ttl_seconds`` is the lifetime of tokens minted by the local
    # ``/token`` dev endpoint; real IdPs (Keycloak/Auth0) choose their own.
    jwt_secret: str | None = None
    jwt_issuer: str = "secureagentrag"
    jwt_audience: str = "secureagentrag-api"
    jwt_ttl_seconds: int = 3600
    jwt_algorithm: str = "HS256"
    # JWKS endpoint for RS256 verification (e.g. Keycloak, Auth0). When it is
    # set and jwt_algorithm == "RS256", tokens are checked against the cached
    # JWKS rather than jwt_secret.
    jwks_url: str | None = None
    jwks_cache_ttl_seconds: int = 300

    # ── Citation faithfulness gate (NLI) ─────────────────────────────────────
    # Post-synthesis, every sentence carrying an inline `[N]` citation gets a
    # yes/no entailment question against the cited chunk. Failing sentences
    # are tagged `[unsupported]` (soft mode) or removed (strict mode). The
    # check reuses the graph's local LLM, so no extra model download; it costs
    # one LLM call per cited sentence, issued in parallel.
    faithfulness_gate_enabled: bool = False
    # One of "flag" | "drop".
    faithfulness_gate_mode: str = "flag"
    # Minimum entailment ratio for an answer to count as faithful.
    faithfulness_threshold: float = 0.7
    # Number of NLI checks run in parallel.
    faithfulness_max_concurrent: int = 4

    # ── Redis (distributed rate limiting / caching) ──────────────────────────
    redis_url: str = "redis://localhost:6379/0"
    use_redis_rate_limiter: bool = False

    # ── PII redaction ────────────────────────────────────────────────────────
    # Scrubs email, phone, SSN, credit-card, IBAN and IP address values before
    # anything is persisted to the audit log / query cache, guarding against
    # accidental PII leakage into secondary stores. Regex-based by default;
    # Microsoft Presidio is used automatically for higher recall if installed.
    pii_redaction_enabled: bool = True

    # ── Prompt-injection guardrails ──────────────────────────────────────────
    # Regex + heuristic screening of the user query ahead of retrieval. Blocks
    # obvious jailbreak / system-prompt-override attempts, recorded by the
    # audit logger as ``security_block`` events.
    guardrails_enabled: bool = True
    # Strict mode: once past the fast regex gate, ambiguous or all queries are
    # escalated to a local LLM-based classifier for a second opinion. One more
    # LLM call per query, but it catches adversarial inputs that slip past the
    # regex patterns.
    guardrails_strict: bool = False
    # Escalation backend for strict mode:
    #   "llm"        - legacy SAFE/UNSAFE prompt on the synth-grade model
    #                  (core.agents.guardrails_llm); the default, kept for
    #                  backward compatibility.
    #   "llamaguard" - Meta's LlamaGuard 3 8B through Ollama (run
    #                  ``ollama pull llama-guard3:8b`` first); more accurate
    #                  on the standard S1-S14 taxonomy.
    guardrails_backend: str = "llm"
    llamaguard_model: str = "llama-guard3:8b"

    # ── Contextual retrieval (Anthropic 2024 technique) ──────────────────────
    # Each chunk is prefixed with a short LLM-written context summary before
    # embedding. That is one cheap LLM call per chunk at ingestion time, in
    # exchange for measurably better retrieval recall (Anthropic reported a
    # ~35-49% failure reduction). Local Qwen3-8B suffices for the summary.
    contextual_retrieval_enabled: bool = False

    # ── VLM OCR (primary OCR via a vision-language model) ────────────────────
    # Runs a VLM (Qwen2.5-VL / Qwen3-VL, LLaVA, etc.) through Ollama as the
    # primary OCR path. Better than PaddleOCR on complex layouts, tables and
    # mixed-language documents; PaddleOCR remains the fallback when the VLM is
    # unavailable.
    vlm_ocr_enabled: bool = False
    vlm_ocr_model: str = "qwen2.5-vl"

    # ── Multi-tenancy ────────────────────────────────────────────────────────
    # True: every organization gets a dedicated Qdrant collection
    # (documents_{org_id}), which isolates more strongly than payload-level
    # RBAC filtering but means collections are created per org on first use.
    # False: all documents share one collection with RBAC at payload level.
    multi_tenant_collections: bool = False

    # ── BYOK demo mode (P6 production launch, launch-plan/03-backend-byok.md) ─
    # In BYOK mode the FastAPI surface accepts per-request LLM keys from
    # visitor headers, scopes Qdrant writes to per-session collections and
    # disables Phoenix instrumentation. Off in dev/staging; on in the Hugging
    # Face Space production image (SAR_BYOK_MODE=true via Space secrets).
    byok_mode: bool = False
    # With BYOK on and no visitor-supplied LLM key, the owner key from .env is
    # used, throttled to this many requests per IP per hour. The cap is kept
    # deliberately tight so one visitor can never exhaust the Groq free tier
    # (30 RPM / 14400 RPD).
    byok_owner_key_quota_per_hour: int = 3
    # Per-session Qdrant collections (documents_sess_{session_id}) are purged
    # automatically after this many hours by retrieval/session_purge.py.
    session_collection_ttl_hours: int = 24
    # CORS allowlist read by the FastAPI middleware when byok_mode=true. An
    # empty list means no CORS middleware is mounted (dev default).
    cors_allow_origins: list[str] = []

    # ── Multi-modal RAG ──────────────────────────────────────────────────────
    # While ingesting images, also have a VLM write a rich text description.
    # It is embedded as its own chunk, so queries such as "what does the
    # diagram show?" can be answered without CLIP or any other multi-modal
    # embedding model.
    multimodal_descriptions_enabled: bool = False

    # ── Self-query retrieval ─────────────────────────────────────────────────
    # A small local LLM prompt pulls structured metadata filters
    # (source_file, date_range, sensitivity_level, roles) out of the natural
    # language query. They are merged with the RBAC filter and handed to
    # Qdrant, narrowing retrieval before the embedding search runs.
    self_query_enabled: bool = False

    # ── HyDE (Hypothetical Document Embeddings) ──────────────────────────────
    # Generates a hypothetical answer and embeds *that* in place of the raw
    # query. Improves recall when query vocabulary differs from document
    # vocabulary (questions vs declarative sentences). One extra LLM call per
    # query: skip for simple keyword lookups, enable for complex questions.
    hyde_enabled: bool = False

    # ── Pricing for the cost dashboard (USD per 1M tokens) ───────────────────
    # Read by evaluation/cost.py to turn recorded usage into $/query.
    price_groq_input_per_1m: float = 0.59
    price_groq_output_per_1m: float = 0.79
    price_openai_input_per_1m: float = 2.50
    price_openai_output_per_1m: float = 10.00
    price_anthropic_input_per_1m: float = 3.00
    price_anthropic_output_per_1m: float = 15.00
    # Local inference: estimated electricity cost only (consumer hardware).
    # 200W GPU @ $0.15/kWh ≈ $0.03/hour ≈ $0.000008/sec
    price_local_per_second: float = 0.000008
268
+
269
+
270
def _apply_calibration(settings_obj: Settings, calib_path: Path | None = None) -> None:
    """Override threshold defaults from ``evaluation/calibration.json`` when present.

    The calibration script (``scripts/calibrate_thresholds.py``) writes the
    chosen confidence + faithfulness cutoffs against a labelled gold set.
    Loading them here means deployments inherit the latest tuned values
    automatically, while an explicitly configured threshold still wins so
    operators can override per environment.

    A threshold counts as explicitly configured when either the matching
    ``SAR_CONFIDENCE_THRESHOLD`` / ``SAR_FAITHFULNESS_THRESHOLD`` variable is
    present in ``os.environ`` or the field is listed in the settings object's
    ``model_fields_set`` (when that attribute exists).

    Silently no-ops when the file is missing, unreadable, not valid UTF-8
    JSON, not a JSON object, or the relevant sections are absent, malformed
    or degenerate. It never blocks startup.

    Args:
        settings_obj: Settings instance whose ``confidence_threshold`` and
            ``faithfulness_threshold`` attributes may be overwritten in place.
        calib_path: Calibration file to read. Defaults to
            ``evaluation/calibration.json`` next to this module's parent
            package directory.
    """
    if calib_path is None:
        calib_path = Path(__file__).resolve().parent.parent / "evaluation" / "calibration.json"

    try:
        data = json.loads(calib_path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # OSError: file missing or unreadable. ValueError is the common base
        # of json.JSONDecodeError and UnicodeDecodeError.
        return
    if not isinstance(data, dict):
        # Valid JSON, but not the object layout the lookups below expect.
        return

    # NOTE(review): relies on pydantic listing env/.env-sourced values in
    # ``model_fields_set``; objects without the attribute fall back to the
    # os.environ check alone.
    explicitly_set = getattr(settings_obj, "model_fields_set", ())

    def _usable_threshold(block: object) -> float | None:
        """Return the block's chosen threshold, or None for a degenerate sweep.

        A sweep with no positives or no negatives gives a threshold with no
        statistical meaning; keeping the existing default is safer than
        letting a 0.0 cut-off escape into production.
        """
        if not isinstance(block, dict):
            return None
        try:
            if int(block.get("n_pos", 0)) <= 0 or int(block.get("n_neg", 0)) <= 0:
                return None
            threshold = float(block.get("chosen_threshold", 0.0))
        except (TypeError, ValueError, OverflowError):
            # OverflowError: int() of a JSON ``Infinity`` count.
            return None
        return threshold if threshold > 0.0 else None

    targets = (
        ("confidence", "confidence_threshold", "SAR_CONFIDENCE_THRESHOLD"),
        ("faithfulness", "faithfulness_threshold", "SAR_FAITHFULNESS_THRESHOLD"),
    )
    for section, field_name, env_name in targets:
        if field_name in explicitly_set or os.environ.get(env_name) is not None:
            continue  # operator override wins over calibration
        threshold = _usable_threshold(data.get(section))
        if threshold is None:
            continue
        # Assignment may be validated by the settings model; a rejected value
        # must not abort startup.
        with contextlib.suppress(TypeError, ValueError):
            setattr(settings_obj, field_name, threshold)
312
+
313
+
314
+ # Singleton instance — import this throughout the application
315
+ settings = Settings()
316
+ _apply_calibration(settings)
inference/cloud_clients.py CHANGED
@@ -1,577 +1,577 @@
1
- """Cloud LLM provider clients (Groq, OpenAI, Anthropic Claude)."""
2
-
3
- from __future__ import annotations
4
-
5
- import json
6
- import time
7
- from abc import ABC, abstractmethod
8
- from enum import StrEnum
9
- from typing import TYPE_CHECKING, Any
10
-
11
- if TYPE_CHECKING:
12
- from collections.abc import AsyncGenerator
13
-
14
- import httpx
15
- from tenacity import (
16
- retry,
17
- retry_if_exception_type,
18
- stop_after_attempt,
19
- wait_exponential,
20
- )
21
-
22
- from config.settings import settings
23
- from inference.llm_factory import LLMResponse
24
- from utils.logging import get_logger
25
-
26
- logger = get_logger(__name__)
27
-
28
- # Retry decorator for transient connection failures only
29
- _retry_on_connection = retry(
30
- retry=retry_if_exception_type((httpx.ConnectError, httpx.TimeoutException)),
31
- stop=stop_after_attempt(3),
32
- wait=wait_exponential(multiplier=1, min=1, max=10),
33
- reraise=True,
34
- )
35
-
36
-
37
- class LLMProvider(StrEnum):
38
- """Supported LLM provider identifiers."""
39
-
40
- OLLAMA = "ollama"
41
- GROQ = "groq"
42
- OPENAI = "openai"
43
- ANTHROPIC = "anthropic"
44
-
45
-
46
class BaseCloudClient(ABC):
    """Common scaffolding for every hosted-LLM provider client.

    The base class stores the credentials and default model and creates one
    ``httpx.AsyncClient`` per instance. Concrete providers implement the four
    abstract coroutines below. Instances work with ``async with``, which
    closes the HTTP client on exit.

    Args:
        api_key: Provider API key used for authentication.
        model: Default model identifier.
        timeout: Request timeout in seconds.
    """

    def __init__(self, api_key: str, model: str, timeout: float = 60.0) -> None:
        self.api_key, self.model, self.timeout = api_key, model, timeout
        self._client = httpx.AsyncClient(timeout=httpx.Timeout(timeout))

    @abstractmethod
    async def generate(
        self,
        prompt: str,
        system_prompt: str = "",
        temperature: float = 0.7,
        max_tokens: int = 2048,
        json_mode: bool = False,
    ) -> LLMResponse:
        """Produce a single completion for ``prompt``.

        Args:
            prompt: The user prompt text.
            system_prompt: Optional system context.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.
            json_mode: When True, ask the provider for JSON-formatted output.

        Returns:
            LLMResponse carrying the generated text and metadata.
        """

    @abstractmethod
    async def chat(
        self,
        messages: list[dict],
        temperature: float = 0.7,
        max_tokens: int = 2048,
    ) -> LLMResponse:
        """Submit a multi-message conversation to the provider.

        Args:
            messages: Message dicts, each with 'role' and 'content' keys.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.

        Returns:
            LLMResponse carrying the generated text and metadata.
        """

    @abstractmethod
    async def generate_stream(
        self,
        prompt: str,
        system_prompt: str = "",
        temperature: float = 0.7,
        max_tokens: int = 2048,
    ) -> AsyncGenerator[str, None]:
        """Stream a completion, yielding tokens as the provider emits them.

        Args:
            prompt: The user prompt text.
            system_prompt: Optional system context.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.

        Yields:
            Token strings in generation order.
        """

    @abstractmethod
    async def health_check(self) -> bool:
        """Report whether the provider API can be reached.

        Returns:
            True if the API responds successfully.
        """

    async def close(self) -> None:
        """Release the underlying HTTP client."""
        await self._client.aclose()

    async def __aenter__(self) -> BaseCloudClient:
        """Return the client itself as the ``async with`` target."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Close the HTTP client when the ``async with`` block ends."""
        await self.close()
140
-
141
-
142
- def make_byok_cloud_client(
143
- *,
144
- provider: str,
145
- user_key: str,
146
- model: str | None = None,
147
- timeout: float = 60.0,
148
- ) -> BaseCloudClient:
149
- """Build a per-request cloud LLM client that uses the visitor's API key.
150
-
151
- Each call returns a **fresh client instance** holding the supplied key
152
- in its own ``self.api_key`` slot. The visitor's key never lands on any
153
- module-level singleton, never mixes into the owner-key client, and is
154
- discarded when the FastAPI request scope ends.
155
-
156
- Args:
157
- provider: One of ``"groq"`` / ``"openai"`` / ``"anthropic"``.
158
- user_key: The visitor-supplied API key from ``X-User-LLM-Key``.
159
- model: Override the provider's default model.
160
- timeout: Per-request HTTP timeout in seconds.
161
-
162
- Returns:
163
- A new ``BaseCloudClient`` subclass instance bound to the visitor key.
164
-
165
- Raises:
166
- ValueError: ``provider`` is not in the BYOK allowlist or ``user_key``
167
- is missing.
168
- """
169
- if not user_key or not user_key.strip():
170
- raise ValueError("make_byok_cloud_client called without a user key")
171
- prov = (provider or "").lower()
172
- if prov == "groq":
173
- return GroqClient(
174
- api_key=user_key.strip(), model=model or "llama-3.1-8b-instant", timeout=timeout
175
- )
176
- if prov == "openai":
177
- return OpenAIClient(api_key=user_key.strip(), model=model or "gpt-4o-mini", timeout=timeout)
178
- if prov == "anthropic":
179
- return AnthropicClient(
180
- api_key=user_key.strip(),
181
- model=model or "claude-sonnet-4-20250514",
182
- timeout=timeout,
183
- )
184
- raise ValueError(f"BYOK provider not supported: {provider!r}")
185
-
186
-
187
class OpenAICompatibleClient(BaseCloudClient):
    """Shared client for OpenAI Chat Completions-compatible APIs.

    Both Groq and OpenAI implement the same wire format
    (``POST /chat/completions`` + SSE streaming). Subclasses supply only
    the ``api_base`` URL and the ``provider_name`` tag; every method on
    ``BaseCloudClient`` is implemented once, here, and inherited.
    """

    #: Subclasses override these two class attrs.
    api_base: str = ""
    provider_name: str = ""

    def _headers(self) -> dict[str, str]:
        """Return the bearer-token and JSON content-type request headers."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    @staticmethod
    def _messages(prompt: str, system_prompt: str) -> list[dict[str, str]]:
        """Build the message list: optional system message, then the user prompt."""
        out: list[dict[str, str]] = []
        if system_prompt:
            out.append({"role": "system", "content": system_prompt})
        out.append({"role": "user", "content": prompt})
        return out

    @staticmethod
    def _first_choice(data: dict[str, Any]) -> dict[str, Any]:
        """Return the first entry of ``data["choices"]``, or ``{}`` if there is none.

        Handles a missing key, a null value and an empty list alike, so
        callers never index into an empty ``choices`` array.
        """
        choices = data.get("choices") or [{}]
        return choices[0] or {}

    @_retry_on_connection
    async def generate(
        self,
        prompt: str,
        system_prompt: str = "",
        temperature: float = 0.7,
        max_tokens: int = 2048,
        json_mode: bool = False,
    ) -> LLMResponse:
        """Generate a completion by delegating to :meth:`chat`.

        Args:
            prompt: The user prompt text.
            system_prompt: Optional system context, sent as a system message.
            temperature: Sampling temperature.
            max_tokens: Maximum tokens to generate.
            json_mode: When True, request a JSON object response.

        Returns:
            LLMResponse with generated text and metadata.
        """
        return await self.chat(
            messages=self._messages(prompt, system_prompt),
            temperature=temperature,
            max_tokens=max_tokens,
            json_mode=json_mode,
        )

    @_retry_on_connection
    async def chat(
        self,
        messages: list[dict],
        temperature: float = 0.7,
        max_tokens: int = 2048,
        json_mode: bool = False,
    ) -> LLMResponse:
        """POST a conversation to ``{api_base}/chat/completions``.

        Args:
            messages: List of message dicts with 'role' and 'content' keys.
            temperature: Sampling temperature.
            max_tokens: Maximum tokens to generate.
            json_mode: When True, adds ``response_format={"type": "json_object"}``.

        Returns:
            LLMResponse with the first choice's text (empty string when the
            response has no choices or null content), token usage and the
            request latency in milliseconds.

        Raises:
            httpx.HTTPStatusError: The API answered with a 4xx/5xx status.
        """
        payload: dict[str, Any] = {
            "model": self.model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        if json_mode:
            payload["response_format"] = {"type": "json_object"}

        start = time.perf_counter()
        response = await self._client.post(
            f"{self.api_base}/chat/completions",
            headers=self._headers(),
            json=payload,
        )
        # Latency covers the HTTP round trip only, not JSON decoding.
        elapsed_ms = (time.perf_counter() - start) * 1000
        response.raise_for_status()

        data = response.json()
        message = self._first_choice(data).get("message") or {}
        usage = data.get("usage") or {}

        return LLMResponse(
            # ``content`` may be present but null; normalise to "".
            text=message.get("content") or "",
            model=data.get("model", self.model),
            provider=self.provider_name,
            usage={
                "prompt_tokens": usage.get("prompt_tokens", 0),
                "completion_tokens": usage.get("completion_tokens", 0),
                "total_tokens": usage.get("total_tokens", 0),
            },
            latency_ms=elapsed_ms,
        )

    # NOTE(review): this is an async generator function, not a coroutine;
    # confirm the tenacity wrapper actually re-drives it on connection errors.
    @_retry_on_connection
    async def generate_stream(
        self,
        prompt: str,
        system_prompt: str = "",
        temperature: float = 0.7,
        max_tokens: int = 2048,
    ) -> AsyncGenerator[str, None]:
        """Stream a completion over SSE, yielding content tokens as they arrive.

        Args:
            prompt: The user prompt text.
            system_prompt: Optional system context.
            temperature: Sampling temperature.
            max_tokens: Maximum tokens to generate.

        Yields:
            Non-empty ``delta.content`` strings from each streamed chunk.

        Raises:
            httpx.HTTPStatusError: The API answered with a 4xx/5xx status.
        """
        payload: dict[str, Any] = {
            "model": self.model,
            "messages": self._messages(prompt, system_prompt),
            "temperature": temperature,
            "max_tokens": max_tokens,
            "stream": True,
        }
        async with self._client.stream(
            "POST",
            f"{self.api_base}/chat/completions",
            headers={**self._headers(), "Accept": "text/event-stream"},
            json=payload,
        ) as resp:
            resp.raise_for_status()
            async for line in resp.aiter_lines():
                line = line.strip()
                if not line.startswith("data: "):
                    continue
                data_str = line[6:]  # drop the "data: " prefix
                if data_str == "[DONE]":
                    break
                try:
                    data = json.loads(data_str)
                except json.JSONDecodeError:
                    # Skip malformed events rather than abort the stream.
                    continue
                # A chunk may carry an empty ``choices`` list or a null delta.
                delta = self._first_choice(data).get("delta") or {}
                token = delta.get("content")
                if token:
                    yield token

    # The retried exception types are caught inside the body, so the retry
    # wrapper never sees them and a failed probe returns False immediately.
    @_retry_on_connection
    async def health_check(self) -> bool:
        """Probe ``GET {api_base}/models``.

        Returns:
            True for a 200 or 401 status (401 still shows the API answered);
            False for any other status, a connection error or a timeout.
        """
        try:
            response = await self._client.get(f"{self.api_base}/models", headers=self._headers())
            return response.status_code in (200, 401)
        except (httpx.ConnectError, httpx.TimeoutException):
            return False
318
-
319
-
320
class GroqClient(OpenAICompatibleClient):
    """Groq-hosted models, served through Groq's OpenAI-compatible endpoint.

    Args:
        api_key: Groq API key.
        model: Model identifier. Defaults to "llama-3.3-70b-versatile".
        timeout: Request timeout in seconds.
    """

    provider_name = "groq"

    def __init__(
        self,
        api_key: str,
        model: str = "llama-3.3-70b-versatile",
        timeout: float = 60.0,
    ) -> None:
        super().__init__(api_key, model, timeout)
        # Instance attribute overrides the empty class-level ``api_base``.
        self.api_base = settings.groq_api_base
333
-
334
-
335
class OpenAIClient(OpenAICompatibleClient):
    """OpenAI-hosted models via the Chat Completions API.

    Args:
        api_key: OpenAI API key.
        model: Model identifier. Defaults to "gpt-4o-mini".
        timeout: Request timeout in seconds.
    """

    provider_name = "openai"

    def __init__(
        self,
        api_key: str,
        model: str = "gpt-4o-mini",
        timeout: float = 60.0,
    ) -> None:
        super().__init__(api_key, model, timeout)
        # Instance attribute overrides the empty class-level ``api_base``.
        self.api_base = settings.openai_api_base
348
-
349
-
350
class AnthropicClient(BaseCloudClient):
    """Client for Anthropic Claude models over the Messages API.

    Args:
        api_key: Anthropic API key.
        model: Model identifier. Defaults to "claude-sonnet-4-20250514".
        timeout: Request timeout in seconds.
    """

    def __init__(
        self,
        api_key: str,
        model: str = "claude-sonnet-4-20250514",
        timeout: float = 60.0,
    ) -> None:
        super().__init__(api_key, model, timeout)
        self._api_base = settings.anthropic_api_base

    def _headers(self) -> dict[str, str]:
        """Return the Anthropic authentication, version and content-type headers."""
        return {
            "x-api-key": self.api_key,
            "anthropic-version": "2023-06-01",
            "Content-Type": "application/json",
        }

    @_retry_on_connection
    async def generate(
        self,
        prompt: str,
        system_prompt: str = "",
        temperature: float = 0.7,
        max_tokens: int = 2048,
        json_mode: bool = False,
    ) -> LLMResponse:
        """Produce a completion for a single user prompt.

        Args:
            prompt: The user prompt text.
            system_prompt: Optional system context.
            temperature: Sampling temperature.
            max_tokens: Maximum tokens to generate.
            json_mode: Accepted for interface parity and ignored here.

        Returns:
            LLMResponse with generated text and metadata.
        """
        return await self._send_messages(
            messages=[{"role": "user", "content": prompt}],
            system_prompt=system_prompt,
            temperature=temperature,
            max_tokens=max_tokens,
        )

    @_retry_on_connection
    async def chat(
        self,
        messages: list[dict],
        temperature: float = 0.7,
        max_tokens: int = 2048,
    ) -> LLMResponse:
        """Send a conversation through the Messages API.

        The request carries the system prompt as a separate top-level
        ``system`` field, so any message with role ``system`` is pulled out
        of the list here. If several are present, the last one is used.

        Args:
            messages: List of message dicts with 'role' and 'content' keys.
            temperature: Sampling temperature.
            max_tokens: Maximum tokens to generate.

        Returns:
            LLMResponse with generated text and metadata.
        """
        system_prompt = ""
        conversation: list[dict[str, str]] = []
        for entry in messages:
            if entry.get("role") != "system":
                conversation.append(entry)
                continue
            system_prompt = entry.get("content", "")

        return await self._send_messages(
            messages=conversation,
            system_prompt=system_prompt,
            temperature=temperature,
            max_tokens=max_tokens,
        )

    async def _send_messages(
        self,
        messages: list[dict],
        system_prompt: str = "",
        temperature: float = 0.7,
        max_tokens: int = 2048,
    ) -> LLMResponse:
        """POST ``messages`` to ``{api_base}/messages`` and wrap the reply.

        Args:
            messages: Anthropic-formatted messages (no system role).
            system_prompt: System prompt, sent as the top-level ``system``
                field when non-empty.
            temperature: Sampling temperature.
            max_tokens: Maximum tokens to generate.

        Returns:
            LLMResponse whose text is the concatenation of every ``text``
            content block in the reply.
        """
        request_body: dict[str, Any] = {
            "model": self.model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        if system_prompt:
            request_body["system"] = system_prompt

        started = time.perf_counter()
        response = await self._client.post(
            f"{self._api_base}/messages",
            headers=self._headers(),
            json=request_body,
        )
        elapsed_ms = (time.perf_counter() - started) * 1000
        response.raise_for_status()

        data = response.json()
        # The reply body is a list of typed content blocks; keep text ones.
        text = "".join(
            block.get("text", "")
            for block in data.get("content", [])
            if block.get("type") == "text"
        )

        usage = data.get("usage", {})
        input_tokens = usage.get("input_tokens", 0)
        output_tokens = usage.get("output_tokens", 0)
        return LLMResponse(
            text=text,
            model=data.get("model", self.model),
            provider="anthropic",
            usage={
                "prompt_tokens": input_tokens,
                "completion_tokens": output_tokens,
                "total_tokens": input_tokens + output_tokens,
            },
            latency_ms=elapsed_ms,
        )

    async def generate_stream(
        self,
        prompt: str,
        system_prompt: str = "",
        temperature: float = 0.7,
        max_tokens: int = 2048,
    ) -> AsyncGenerator[str, None]:
        """Stream a completion over SSE from the Messages API.

        Text is taken from ``content_block_delta`` events; the stream ends on
        a ``message_stop`` event or a ``[DONE]`` sentinel.

        Args:
            prompt: The user prompt text.
            system_prompt: Optional system context.
            temperature: Sampling temperature.
            max_tokens: Maximum tokens to generate.

        Yields:
            Token strings as they are generated.
        """
        request_body: dict[str, Any] = {
            "model": self.model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": temperature,
            "max_tokens": max_tokens,
            "stream": True,
        }
        if system_prompt:
            request_body["system"] = system_prompt

        async with self._client.stream(
            "POST",
            f"{self._api_base}/messages",
            headers={**self._headers(), "Accept": "text/event-stream"},
            json=request_body,
        ) as resp:
            resp.raise_for_status()
            async for raw_line in resp.aiter_lines():
                line = raw_line.strip()
                if not line.startswith("data: "):
                    continue
                event_json = line[6:]  # drop the "data: " prefix
                if event_json == "[DONE]":
                    break
                try:
                    event = json.loads(event_json)
                except json.JSONDecodeError:
                    # Skip malformed events rather than abort the stream.
                    continue
                event_type = event.get("type", "")
                if event_type == "message_stop":
                    break
                if event_type != "content_block_delta":
                    continue
                token = event.get("delta", {}).get("text", "")
                if token:
                    yield token

    @_retry_on_connection
    async def health_check(self) -> bool:
        """Probe the API with a minimal one-token Messages request.

        Returns:
            True when the API answers with status 200, 400 or 401; False for
            any other status, a connection error or a timeout.
        """
        probe = {
            "model": self.model,
            "messages": [{"role": "user", "content": "hi"}],
            "max_tokens": 1,
        }
        try:
            # No dedicated health endpoint is used; a tiny real request
            # serves as the probe.
            response = await self._client.post(
                f"{self._api_base}/messages",
                headers=self._headers(),
                json=probe,
            )
        except (httpx.ConnectError, httpx.TimeoutException):
            return False
        # A 400/401 still shows that the service answered.
        return response.status_code in (200, 401, 400)
 
1
+ """Cloud LLM provider clients (Groq, OpenAI, Anthropic Claude)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import time
7
+ from abc import ABC, abstractmethod
8
+ from enum import StrEnum
9
+ from typing import TYPE_CHECKING, Any
10
+
11
+ if TYPE_CHECKING:
12
+ from collections.abc import AsyncGenerator
13
+
14
+ import httpx
15
+ from tenacity import (
16
+ retry,
17
+ retry_if_exception_type,
18
+ stop_after_attempt,
19
+ wait_exponential,
20
+ )
21
+
22
+ from config.settings import settings
23
+ from inference.llm_factory import LLMResponse
24
+ from utils.logging import get_logger
25
+
26
+ logger = get_logger(__name__)
27
+
28
+ # Retry decorator for transient connection failures only
29
+ _retry_on_connection = retry(
30
+ retry=retry_if_exception_type((httpx.ConnectError, httpx.TimeoutException)),
31
+ stop=stop_after_attempt(3),
32
+ wait=wait_exponential(multiplier=1, min=1, max=10),
33
+ reraise=True,
34
+ )
35
+
36
+
37
+ class LLMProvider(StrEnum):
38
+ """Supported LLM provider identifiers."""
39
+
40
+ OLLAMA = "ollama"
41
+ GROQ = "groq"
42
+ OPENAI = "openai"
43
+ ANTHROPIC = "anthropic"
44
+
45
+
46
+ class BaseCloudClient(ABC):
47
+ """Abstract base class for cloud LLM provider clients.
48
+
49
+ Args:
50
+ api_key: Provider API key for authentication.
51
+ model: Default model identifier.
52
+ timeout: Request timeout in seconds.
53
+ """
54
+
55
+ def __init__(self, api_key: str, model: str, timeout: float = 60.0) -> None:
56
+ self.api_key = api_key
57
+ self.model = model
58
+ self.timeout = timeout
59
+ self._client = httpx.AsyncClient(timeout=httpx.Timeout(timeout))
60
+
61
+ @abstractmethod
62
+ async def generate(
63
+ self,
64
+ prompt: str,
65
+ system_prompt: str = "",
66
+ temperature: float = 0.7,
67
+ max_tokens: int = 2048,
68
+ json_mode: bool = False,
69
+ ) -> LLMResponse:
70
+ """Generate a completion from the provider.
71
+
72
+ Args:
73
+ prompt: The user prompt text.
74
+ system_prompt: Optional system context.
75
+ temperature: Sampling temperature.
76
+ max_tokens: Maximum tokens to generate.
77
+ json_mode: When True, request JSON-formatted output.
78
+
79
+ Returns:
80
+ LLMResponse with generated text and metadata.
81
+ """
82
+
83
+ @abstractmethod
84
+ async def chat(
85
+ self,
86
+ messages: list[dict],
87
+ temperature: float = 0.7,
88
+ max_tokens: int = 2048,
89
+ ) -> LLMResponse:
90
+ """Send a chat conversation to the provider.
91
+
92
+ Args:
93
+ messages: List of message dicts with 'role' and 'content' keys.
94
+ temperature: Sampling temperature.
95
+ max_tokens: Maximum tokens to generate.
96
+
97
+ Returns:
98
+ LLMResponse with generated text and metadata.
99
+ """
100
+
101
+ @abstractmethod
102
+ async def generate_stream(
103
+ self,
104
+ prompt: str,
105
+ system_prompt: str = "",
106
+ temperature: float = 0.7,
107
+ max_tokens: int = 2048,
108
+ ) -> AsyncGenerator[str, None]:
109
+ """Stream a completion from the provider, yielding tokens as they arrive.
110
+
111
+ Args:
112
+ prompt: The user prompt text.
113
+ system_prompt: Optional system context.
114
+ temperature: Sampling temperature.
115
+ max_tokens: Maximum tokens to generate.
116
+
117
+ Yields:
118
+ Token strings as they are generated.
119
+ """
120
+
121
+ @abstractmethod
122
+ async def health_check(self) -> bool:
123
+ """Check if the provider API is reachable.
124
+
125
+ Returns:
126
+ True if the API responds successfully.
127
+ """
128
+
129
+ async def close(self) -> None:
130
+ """Close the underlying HTTP client."""
131
+ await self._client.aclose()
132
+
133
+ async def __aenter__(self) -> BaseCloudClient:
134
+ """Enter async context manager."""
135
+ return self
136
+
137
+ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
138
+ """Exit async context manager, closing the client."""
139
+ await self.close()
140
+
141
+
142
+ def make_byok_cloud_client(
143
+ *,
144
+ provider: str,
145
+ user_key: str,
146
+ model: str | None = None,
147
+ timeout: float = 60.0,
148
+ ) -> BaseCloudClient:
149
+ """Build a per-request cloud LLM client that uses the visitor's API key.
150
+
151
+ Each call returns a **fresh client instance** holding the supplied key
152
+ in its own ``self.api_key`` slot. The visitor's key never lands on any
153
+ module-level singleton, never mixes into the owner-key client, and is
154
+ discarded when the FastAPI request scope ends.
155
+
156
+ Args:
157
+ provider: One of ``"groq"`` / ``"openai"`` / ``"anthropic"``.
158
+ user_key: The visitor-supplied API key from ``X-User-LLM-Key``.
159
+ model: Override the provider's default model.
160
+ timeout: Per-request HTTP timeout in seconds.
161
+
162
+ Returns:
163
+ A new ``BaseCloudClient`` subclass instance bound to the visitor key.
164
+
165
+ Raises:
166
+ ValueError: ``provider`` is not in the BYOK allowlist or ``user_key``
167
+ is missing.
168
+ """
169
+ if not user_key or not user_key.strip():
170
+ raise ValueError("make_byok_cloud_client called without a user key")
171
+ prov = (provider or "").lower()
172
+ if prov == "groq":
173
+ return GroqClient(
174
+ api_key=user_key.strip(), model=model or "llama-3.1-8b-instant", timeout=timeout
175
+ )
176
+ if prov == "openai":
177
+ return OpenAIClient(api_key=user_key.strip(), model=model or "gpt-4o-mini", timeout=timeout)
178
+ if prov == "anthropic":
179
+ return AnthropicClient(
180
+ api_key=user_key.strip(),
181
+ model=model or "claude-sonnet-4-20250514",
182
+ timeout=timeout,
183
+ )
184
+ raise ValueError(f"BYOK provider not supported: {provider!r}")
185
+
186
+
187
+ class OpenAICompatibleClient(BaseCloudClient):
188
+ """Shared client for OpenAI Chat Completions-compatible APIs.
189
+
190
+ Both Groq and OpenAI implement the same wire format
191
+ (``POST /chat/completions`` + SSE streaming). Subclasses supply only
192
+ the ``api_base`` URL and the ``provider`` tag — every method on
193
+ ``BaseCloudClient`` is implemented once, here, and inherited.
194
+ """
195
+
196
+ #: Subclasses override these two class attrs.
197
+ api_base: str = ""
198
+ provider_name: str = ""
199
+
200
+ def _headers(self) -> dict[str, str]:
201
+ return {
202
+ "Authorization": f"Bearer {self.api_key}",
203
+ "Content-Type": "application/json",
204
+ }
205
+
206
+ @staticmethod
207
+ def _messages(prompt: str, system_prompt: str) -> list[dict[str, str]]:
208
+ out: list[dict[str, str]] = []
209
+ if system_prompt:
210
+ out.append({"role": "system", "content": system_prompt})
211
+ out.append({"role": "user", "content": prompt})
212
+ return out
213
+
214
+ @_retry_on_connection
215
+ async def generate(
216
+ self,
217
+ prompt: str,
218
+ system_prompt: str = "",
219
+ temperature: float = 0.7,
220
+ max_tokens: int = 2048,
221
+ json_mode: bool = False,
222
+ ) -> LLMResponse:
223
+ return await self.chat(
224
+ messages=self._messages(prompt, system_prompt),
225
+ temperature=temperature,
226
+ max_tokens=max_tokens,
227
+ json_mode=json_mode,
228
+ )
229
+
230
+ @_retry_on_connection
231
+ async def chat(
232
+ self,
233
+ messages: list[dict],
234
+ temperature: float = 0.7,
235
+ max_tokens: int = 2048,
236
+ json_mode: bool = False,
237
+ ) -> LLMResponse:
238
+ payload: dict[str, Any] = {
239
+ "model": self.model,
240
+ "messages": messages,
241
+ "temperature": temperature,
242
+ "max_tokens": max_tokens,
243
+ }
244
+ if json_mode:
245
+ payload["response_format"] = {"type": "json_object"}
246
+
247
+ start = time.perf_counter()
248
+ response = await self._client.post(
249
+ f"{self.api_base}/chat/completions",
250
+ headers=self._headers(),
251
+ json=payload,
252
+ )
253
+ elapsed_ms = (time.perf_counter() - start) * 1000
254
+ response.raise_for_status()
255
+
256
+ data = response.json()
257
+ choice = data.get("choices", [{}])[0]
258
+ message = choice.get("message", {})
259
+ usage = data.get("usage", {})
260
+
261
+ return LLMResponse(
262
+ text=message.get("content", ""),
263
+ model=data.get("model", self.model),
264
+ provider=self.provider_name,
265
+ usage={
266
+ "prompt_tokens": usage.get("prompt_tokens", 0),
267
+ "completion_tokens": usage.get("completion_tokens", 0),
268
+ "total_tokens": usage.get("total_tokens", 0),
269
+ },
270
+ latency_ms=elapsed_ms,
271
+ )
272
+
273
+ @_retry_on_connection
274
+ async def generate_stream(
275
+ self,
276
+ prompt: str,
277
+ system_prompt: str = "",
278
+ temperature: float = 0.7,
279
+ max_tokens: int = 2048,
280
+ ) -> AsyncGenerator[str, None]:
281
+ payload: dict[str, Any] = {
282
+ "model": self.model,
283
+ "messages": self._messages(prompt, system_prompt),
284
+ "temperature": temperature,
285
+ "max_tokens": max_tokens,
286
+ "stream": True,
287
+ }
288
+ async with self._client.stream(
289
+ "POST",
290
+ f"{self.api_base}/chat/completions",
291
+ headers={**self._headers(), "Accept": "text/event-stream"},
292
+ json=payload,
293
+ ) as resp:
294
+ resp.raise_for_status()
295
+ async for line in resp.aiter_lines():
296
+ line = line.strip()
297
+ if not line.startswith("data: "):
298
+ continue
299
+ data_str = line[6:]
300
+ if data_str == "[DONE]":
301
+ break
302
+ try:
303
+ data = json.loads(data_str)
304
+ except json.JSONDecodeError:
305
+ continue
306
+ choice = data.get("choices", [{}])[0]
307
+ token = choice.get("delta", {}).get("content", "")
308
+ if token:
309
+ yield token
310
+
311
+ @_retry_on_connection
312
+ async def health_check(self) -> bool:
313
+ try:
314
+ response = await self._client.get(f"{self.api_base}/models", headers=self._headers())
315
+ return response.status_code in (200, 401)
316
+ except (httpx.ConnectError, httpx.TimeoutException):
317
+ return False
318
+
319
+
320
+ class GroqClient(OpenAICompatibleClient):
321
+ """Groq cloud LLM client (OpenAI-compatible API at api.groq.com)."""
322
+
323
+ provider_name = "groq"
324
+
325
+ def __init__(
326
+ self,
327
+ api_key: str,
328
+ model: str = "llama-3.3-70b-versatile",
329
+ timeout: float = 60.0,
330
+ ) -> None:
331
+ super().__init__(api_key=api_key, model=model, timeout=timeout)
332
+ self.api_base = settings.groq_api_base
333
+
334
+
335
+ class OpenAIClient(OpenAICompatibleClient):
336
+ """OpenAI cloud LLM client (Chat Completions API at api.openai.com)."""
337
+
338
+ provider_name = "openai"
339
+
340
+ def __init__(
341
+ self,
342
+ api_key: str,
343
+ model: str = "gpt-4o-mini",
344
+ timeout: float = 60.0,
345
+ ) -> None:
346
+ super().__init__(api_key=api_key, model=model, timeout=timeout)
347
+ self.api_base = settings.openai_api_base
348
+
349
+
350
+ class AnthropicClient(BaseCloudClient):
351
+ """Anthropic Claude cloud LLM client using the Messages API.
352
+
353
+ Args:
354
+ api_key: Anthropic API key.
355
+ model: Model identifier. Defaults to "claude-sonnet-4-20250514".
356
+ timeout: Request timeout in seconds.
357
+ """
358
+
359
+ def __init__(
360
+ self,
361
+ api_key: str,
362
+ model: str = "claude-sonnet-4-20250514",
363
+ timeout: float = 60.0,
364
+ ) -> None:
365
+ super().__init__(api_key=api_key, model=model, timeout=timeout)
366
+ self._api_base = settings.anthropic_api_base
367
+
368
+ def _headers(self) -> dict[str, str]:
369
+ """Build request headers with Anthropic-specific authentication."""
370
+ return {
371
+ "x-api-key": self.api_key,
372
+ "anthropic-version": "2023-06-01",
373
+ "Content-Type": "application/json",
374
+ }
375
+
376
+ @_retry_on_connection
377
+ async def generate(
378
+ self,
379
+ prompt: str,
380
+ system_prompt: str = "",
381
+ temperature: float = 0.7,
382
+ max_tokens: int = 2048,
383
+ json_mode: bool = False,
384
+ ) -> LLMResponse:
385
+ """Generate a completion via Anthropic's Messages API.
386
+
387
+ Args:
388
+ prompt: The user prompt text.
389
+ system_prompt: Optional system context.
390
+ temperature: Sampling temperature.
391
+ max_tokens: Maximum tokens to generate.
392
+ json_mode: Anthropic does not support native JSON mode; ignored.
393
+
394
+ Returns:
395
+ LLMResponse with generated text and metadata.
396
+ """
397
+ messages: list[dict[str, str]] = [{"role": "user", "content": prompt}]
398
+ return await self._send_messages(
399
+ messages=messages,
400
+ system_prompt=system_prompt,
401
+ temperature=temperature,
402
+ max_tokens=max_tokens,
403
+ )
404
+
405
+ @_retry_on_connection
406
+ async def chat(
407
+ self,
408
+ messages: list[dict],
409
+ temperature: float = 0.7,
410
+ max_tokens: int = 2048,
411
+ ) -> LLMResponse:
412
+ """Send a chat request to Anthropic's Messages API.
413
+
414
+ Anthropic uses a separate 'system' parameter instead of a system message
415
+ in the messages list. This method extracts any system message and handles
416
+ the format conversion.
417
+
418
+ Args:
419
+ messages: List of message dicts with 'role' and 'content' keys.
420
+ temperature: Sampling temperature.
421
+ max_tokens: Maximum tokens to generate.
422
+
423
+ Returns:
424
+ LLMResponse with generated text and metadata.
425
+ """
426
+ # Extract system message if present
427
+ system_prompt = ""
428
+ anthropic_messages: list[dict[str, str]] = []
429
+ for msg in messages:
430
+ if msg.get("role") == "system":
431
+ system_prompt = msg.get("content", "")
432
+ else:
433
+ anthropic_messages.append(msg)
434
+
435
+ return await self._send_messages(
436
+ messages=anthropic_messages,
437
+ system_prompt=system_prompt,
438
+ temperature=temperature,
439
+ max_tokens=max_tokens,
440
+ )
441
+
442
+ async def _send_messages(
443
+ self,
444
+ messages: list[dict],
445
+ system_prompt: str = "",
446
+ temperature: float = 0.7,
447
+ max_tokens: int = 2048,
448
+ ) -> LLMResponse:
449
+ """Internal method to send messages to Anthropic's API.
450
+
451
+ Args:
452
+ messages: Anthropic-formatted messages (no system role).
453
+ system_prompt: System prompt passed as top-level parameter.
454
+ temperature: Sampling temperature.
455
+ max_tokens: Maximum tokens to generate.
456
+
457
+ Returns:
458
+ LLMResponse with generated text and metadata.
459
+ """
460
+ payload: dict[str, Any] = {
461
+ "model": self.model,
462
+ "messages": messages,
463
+ "temperature": temperature,
464
+ "max_tokens": max_tokens,
465
+ }
466
+ if system_prompt:
467
+ payload["system"] = system_prompt
468
+
469
+ start = time.perf_counter()
470
+ response = await self._client.post(
471
+ f"{self._api_base}/messages",
472
+ headers=self._headers(),
473
+ json=payload,
474
+ )
475
+ elapsed_ms = (time.perf_counter() - start) * 1000
476
+ response.raise_for_status()
477
+
478
+ data = response.json()
479
+ # Anthropic returns content as a list of content blocks
480
+ content_blocks = data.get("content", [])
481
+ text = ""
482
+ for block in content_blocks:
483
+ if block.get("type") == "text":
484
+ text += block.get("text", "")
485
+
486
+ usage = data.get("usage", {})
487
+ return LLMResponse(
488
+ text=text,
489
+ model=data.get("model", self.model),
490
+ provider="anthropic",
491
+ usage={
492
+ "prompt_tokens": usage.get("input_tokens", 0),
493
+ "completion_tokens": usage.get("output_tokens", 0),
494
+ "total_tokens": (usage.get("input_tokens", 0) + usage.get("output_tokens", 0)),
495
+ },
496
+ latency_ms=elapsed_ms,
497
+ )
498
+
499
+ async def generate_stream(
500
+ self,
501
+ prompt: str,
502
+ system_prompt: str = "",
503
+ temperature: float = 0.7,
504
+ max_tokens: int = 2048,
505
+ ) -> AsyncGenerator[str, None]:
506
+ """Stream a completion via Anthropic's Messages API.
507
+
508
+ Anthropic supports streaming via SSE. Yields text content blocks
509
+ as they arrive.
510
+
511
+ Args:
512
+ prompt: The user prompt text.
513
+ system_prompt: Optional system context.
514
+ temperature: Sampling temperature.
515
+ max_tokens: Maximum tokens to generate.
516
+
517
+ Yields:
518
+ Token strings as they are generated.
519
+ """
520
+ payload: dict[str, Any] = {
521
+ "model": self.model,
522
+ "messages": [{"role": "user", "content": prompt}],
523
+ "temperature": temperature,
524
+ "max_tokens": max_tokens,
525
+ "stream": True,
526
+ }
527
+ if system_prompt:
528
+ payload["system"] = system_prompt
529
+
530
+ async with self._client.stream(
531
+ "POST",
532
+ f"{self._api_base}/messages",
533
+ headers={**self._headers(), "Accept": "text/event-stream"},
534
+ json=payload,
535
+ ) as resp:
536
+ resp.raise_for_status()
537
+ async for line in resp.aiter_lines():
538
+ line = line.strip()
539
+ if line.startswith("data: "):
540
+ data_str = line[6:]
541
+ if data_str == "[DONE]":
542
+ break
543
+ try:
544
+ data = json.loads(data_str)
545
+ event_type = data.get("type", "")
546
+ if event_type == "content_block_delta":
547
+ delta = data.get("delta", {})
548
+ token = delta.get("text", "")
549
+ if token:
550
+ yield token
551
+ elif event_type == "message_stop":
552
+ break
553
+ except json.JSONDecodeError:
554
+ continue
555
+
556
+ @_retry_on_connection
557
+ async def health_check(self) -> bool:
558
+ """Check if the Anthropic API is reachable.
559
+
560
+ Returns:
561
+ True if the API responds.
562
+ """
563
+ try:
564
+ # Anthropic doesn't have a simple health endpoint; try a minimal request
565
+ response = await self._client.post(
566
+ f"{self._api_base}/messages",
567
+ headers=self._headers(),
568
+ json={
569
+ "model": self.model,
570
+ "messages": [{"role": "user", "content": "hi"}],
571
+ "max_tokens": 1,
572
+ },
573
+ )
574
+ # Any response (even 401) means the service is reachable
575
+ return response.status_code in (200, 401, 400)
576
+ except (httpx.ConnectError, httpx.TimeoutException):
577
+ return False
inference/ollama_client.py CHANGED
@@ -1,334 +1,334 @@
1
- """Async Ollama client wrapper with streaming support and health checks."""
2
-
3
- from __future__ import annotations
4
-
5
- import time
6
- from typing import TYPE_CHECKING, Any
7
-
8
- import httpx
9
-
10
- if TYPE_CHECKING:
11
- from collections.abc import AsyncGenerator
12
- from tenacity import (
13
- retry,
14
- retry_if_exception_type,
15
- stop_after_attempt,
16
- wait_exponential,
17
- )
18
-
19
- from config.settings import settings
20
- from inference.llm_factory import LLMResponse
21
- from utils.logging import get_logger
22
-
23
- logger = get_logger(__name__)
24
-
25
- # Retry decorator for transient connection failures only
26
- _retry_on_connection = retry(
27
- retry=retry_if_exception_type((httpx.ConnectError, httpx.TimeoutException)),
28
- stop=stop_after_attempt(3),
29
- wait=wait_exponential(multiplier=1, min=1, max=10),
30
- reraise=True,
31
- )
32
-
33
-
34
- def make_byok_ollama_client(
35
- *,
36
- base_url: str,
37
- model: str | None = None,
38
- timeout: float = 60.0,
39
- ) -> OllamaClient:
40
- """Build a per-request Ollama client bound to the visitor's instance URL.
41
-
42
- Visitors running their own local Ollama can paste the public URL of
43
- that instance into the frontend. Each call returns a **fresh client**
44
- so the visitor's URL never replaces the owner default at module scope.
45
-
46
- Args:
47
- base_url: URL of the visitor's Ollama server (HTTPS preferred).
48
- model: Override the default model. Falls back to the owner's
49
- configured ``SAR_LLM_MODEL`` if the visitor's Ollama does not
50
- advertise its own.
51
- timeout: Per-request HTTP timeout in seconds.
52
-
53
- Returns:
54
- A new ``OllamaClient`` bound to ``base_url``.
55
-
56
- Raises:
57
- ValueError: ``base_url`` is empty or whitespace.
58
- """
59
- if not base_url or not base_url.strip():
60
- raise ValueError("make_byok_ollama_client called without a base_url")
61
- return OllamaClient(base_url=base_url.strip(), model=model, timeout=timeout)
62
-
63
-
64
- class OllamaClient:
65
- """Async client for the Ollama local LLM inference server.
66
-
67
- Supports generate (completion), chat, streaming, health checks,
68
- and model listing via the Ollama HTTP API.
69
-
70
- Args:
71
- base_url: Ollama server base URL. Defaults to settings.ollama_url.
72
- model: Default model name. Defaults to settings.llm_model.
73
- timeout: Request timeout in seconds.
74
- """
75
-
76
- def __init__(
77
- self,
78
- base_url: str | None = None,
79
- model: str | None = None,
80
- timeout: float = 120.0,
81
- ) -> None:
82
- self.base_url = (base_url if base_url is not None else settings.ollama_url).rstrip("/")
83
- self.model = model if model is not None else settings.llm_model
84
- self.timeout = timeout
85
- self._client = httpx.AsyncClient(
86
- base_url=self.base_url,
87
- timeout=httpx.Timeout(timeout),
88
- )
89
-
90
- @_retry_on_connection
91
- async def generate(
92
- self,
93
- prompt: str,
94
- system_prompt: str = "",
95
- temperature: float = 0.7,
96
- max_tokens: int = 2048,
97
- json_mode: bool = False,
98
- ) -> LLMResponse:
99
- """Generate a completion from the Ollama API.
100
-
101
- Args:
102
- prompt: The user prompt text.
103
- system_prompt: Optional system context.
104
- temperature: Sampling temperature (0.0-1.0).
105
- max_tokens: Maximum tokens to generate.
106
- json_mode: When True, request JSON-formatted output.
107
-
108
- Returns:
109
- LLMResponse with generated text and metadata.
110
- """
111
- payload: dict[str, Any] = {
112
- "model": self.model,
113
- "prompt": prompt,
114
- "stream": False,
115
- "options": {
116
- "temperature": temperature,
117
- "num_predict": max_tokens,
118
- },
119
- "keep_alive": settings.ollama_keep_alive,
120
- }
121
- if system_prompt:
122
- payload["system"] = system_prompt
123
- if json_mode:
124
- payload["format"] = "json"
125
-
126
- start = time.perf_counter()
127
- response = await self._client.post("/api/generate", json=payload)
128
- elapsed_ms = (time.perf_counter() - start) * 1000
129
- response.raise_for_status()
130
-
131
- data = response.json()
132
- return LLMResponse(
133
- text=data.get("response", ""),
134
- model=data.get("model", self.model),
135
- provider="ollama",
136
- usage={
137
- "prompt_tokens": data.get("prompt_eval_count", 0),
138
- "completion_tokens": data.get("eval_count", 0),
139
- "total_tokens": (data.get("prompt_eval_count", 0) + data.get("eval_count", 0)),
140
- },
141
- latency_ms=elapsed_ms,
142
- metadata={
143
- "total_duration": data.get("total_duration"),
144
- "load_duration": data.get("load_duration"),
145
- },
146
- )
147
-
148
- @_retry_on_connection
149
- async def chat(
150
- self,
151
- messages: list[dict],
152
- temperature: float = 0.7,
153
- max_tokens: int = 2048,
154
- ) -> LLMResponse:
155
- """Send a chat conversation to the Ollama API.
156
-
157
- Args:
158
- messages: List of message dicts with 'role' and 'content' keys.
159
- Roles: "system", "user", "assistant".
160
- temperature: Sampling temperature (0.0-1.0).
161
- max_tokens: Maximum tokens to generate.
162
-
163
- Returns:
164
- LLMResponse with generated text and metadata.
165
- """
166
- payload: dict[str, Any] = {
167
- "model": self.model,
168
- "messages": messages,
169
- "stream": False,
170
- "options": {
171
- "temperature": temperature,
172
- "num_predict": max_tokens,
173
- },
174
- "keep_alive": settings.ollama_keep_alive,
175
- }
176
-
177
- start = time.perf_counter()
178
- response = await self._client.post("/api/chat", json=payload)
179
- elapsed_ms = (time.perf_counter() - start) * 1000
180
- response.raise_for_status()
181
-
182
- data = response.json()
183
- message = data.get("message", {})
184
- return LLMResponse(
185
- text=message.get("content", ""),
186
- model=data.get("model", self.model),
187
- provider="ollama",
188
- usage={
189
- "prompt_tokens": data.get("prompt_eval_count", 0),
190
- "completion_tokens": data.get("eval_count", 0),
191
- "total_tokens": (data.get("prompt_eval_count", 0) + data.get("eval_count", 0)),
192
- },
193
- latency_ms=elapsed_ms,
194
- metadata={
195
- "total_duration": data.get("total_duration"),
196
- "load_duration": data.get("load_duration"),
197
- },
198
- )
199
-
200
- async def generate_stream(
201
- self,
202
- prompt: str,
203
- system_prompt: str = "",
204
- temperature: float = 0.7,
205
- ) -> AsyncGenerator[str, None]:
206
- """Stream a completion from the Ollama API, yielding tokens as they arrive.
207
-
208
- Args:
209
- prompt: The user prompt text.
210
- system_prompt: Optional system context.
211
- temperature: Sampling temperature (0.0-1.0).
212
-
213
- Yields:
214
- Token strings as they are generated.
215
- """
216
- payload: dict[str, Any] = {
217
- "model": self.model,
218
- "prompt": prompt,
219
- "stream": True,
220
- "options": {
221
- "temperature": temperature,
222
- },
223
- "keep_alive": settings.ollama_keep_alive,
224
- }
225
- if system_prompt:
226
- payload["system"] = system_prompt
227
-
228
- async with self._client.stream("POST", "/api/generate", json=payload) as resp:
229
- resp.raise_for_status()
230
- async for line in resp.aiter_lines():
231
- if line:
232
- import json
233
-
234
- data = json.loads(line)
235
- token = data.get("response", "")
236
- if token:
237
- yield token
238
- if data.get("done", False):
239
- break
240
-
241
- async def chat_stream(
242
- self,
243
- messages: list[dict],
244
- temperature: float = 0.7,
245
- ) -> AsyncGenerator[str, None]:
246
- """Stream a chat completion from the Ollama API, yielding tokens as they arrive.
247
-
248
- Args:
249
- messages: List of message dicts with 'role' and 'content' keys.
250
- temperature: Sampling temperature (0.0-1.0).
251
-
252
- Yields:
253
- Token strings as they are generated.
254
- """
255
- payload: dict[str, Any] = {
256
- "model": self.model,
257
- "messages": messages,
258
- "stream": True,
259
- "options": {
260
- "temperature": temperature,
261
- },
262
- "keep_alive": settings.ollama_keep_alive,
263
- }
264
-
265
- async with self._client.stream("POST", "/api/chat", json=payload) as resp:
266
- resp.raise_for_status()
267
- async for line in resp.aiter_lines():
268
- if line:
269
- import json
270
-
271
- data = json.loads(line)
272
- message = data.get("message", {})
273
- token = message.get("content", "")
274
- if token:
275
- yield token
276
- if data.get("done", False):
277
- break
278
-
279
- @_retry_on_connection
280
- async def health_check(self) -> bool:
281
- """Check if the Ollama server is reachable and responding.
282
-
283
- Returns:
284
- True if the server responds with HTTP 200, False otherwise.
285
- """
286
- try:
287
- response = await self._client.get("/api/tags")
288
- return response.status_code == 200
289
- except (httpx.ConnectError, httpx.TimeoutException):
290
- return False
291
-
292
- @_retry_on_connection
293
- async def list_models(self) -> list[str]:
294
- """List all models available on the Ollama server.
295
-
296
- Returns:
297
- List of model name strings.
298
- """
299
- response = await self._client.get("/api/tags")
300
- response.raise_for_status()
301
- data = response.json()
302
- models = data.get("models", [])
303
- return [m.get("name", "") for m in models]
304
-
305
- @_retry_on_connection
306
- async def get_model_info(self, model: str | None = None) -> dict | None:
307
- """Get detailed information about a specific model.
308
-
309
- Args:
310
- model: Model name to query. Defaults to the client's configured model.
311
-
312
- Returns:
313
- Dict with model info, or None if model not found.
314
- """
315
- target_model = model or self.model
316
- try:
317
- response = await self._client.post("/api/show", json={"name": target_model})
318
- if response.status_code == 200:
319
- return response.json()
320
- return None
321
- except httpx.HTTPStatusError:
322
- return None
323
-
324
- async def close(self) -> None:
325
- """Close the underlying HTTP client."""
326
- await self._client.aclose()
327
-
328
- async def __aenter__(self) -> OllamaClient:
329
- """Enter async context manager."""
330
- return self
331
-
332
- async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
333
- """Exit async context manager, closing the client."""
334
- await self.close()
 
1
+ """Async Ollama client wrapper with streaming support and health checks."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import time
6
+ from typing import TYPE_CHECKING, Any
7
+
8
+ import httpx
9
+
10
+ if TYPE_CHECKING:
11
+ from collections.abc import AsyncGenerator
12
+ from tenacity import (
13
+ retry,
14
+ retry_if_exception_type,
15
+ stop_after_attempt,
16
+ wait_exponential,
17
+ )
18
+
19
+ from config.settings import settings
20
+ from inference.llm_factory import LLMResponse
21
+ from utils.logging import get_logger
22
+
23
+ logger = get_logger(__name__)
24
+
25
+ # Retry decorator for transient connection failures only
26
+ _retry_on_connection = retry(
27
+ retry=retry_if_exception_type((httpx.ConnectError, httpx.TimeoutException)),
28
+ stop=stop_after_attempt(3),
29
+ wait=wait_exponential(multiplier=1, min=1, max=10),
30
+ reraise=True,
31
+ )
32
+
33
+
34
+ def make_byok_ollama_client(
35
+ *,
36
+ base_url: str,
37
+ model: str | None = None,
38
+ timeout: float = 60.0,
39
+ ) -> OllamaClient:
40
+ """Build a per-request Ollama client bound to the visitor's instance URL.
41
+
42
+ Visitors running their own local Ollama can paste the public URL of
43
+ that instance into the frontend. Each call returns a **fresh client**
44
+ so the visitor's URL never replaces the owner default at module scope.
45
+
46
+ Args:
47
+ base_url: URL of the visitor's Ollama server (HTTPS preferred).
48
+ model: Override the default model. Falls back to the owner's
49
+ configured ``SAR_LLM_MODEL`` if the visitor's Ollama does not
50
+ advertise its own.
51
+ timeout: Per-request HTTP timeout in seconds.
52
+
53
+ Returns:
54
+ A new ``OllamaClient`` bound to ``base_url``.
55
+
56
+ Raises:
57
+ ValueError: ``base_url`` is empty or whitespace.
58
+ """
59
+ if not base_url or not base_url.strip():
60
+ raise ValueError("make_byok_ollama_client called without a base_url")
61
+ return OllamaClient(base_url=base_url.strip(), model=model, timeout=timeout)
62
+
63
+
64
+ class OllamaClient:
65
+ """Async client for the Ollama local LLM inference server.
66
+
67
+ Supports generate (completion), chat, streaming, health checks,
68
+ and model listing via the Ollama HTTP API.
69
+
70
+ Args:
71
+ base_url: Ollama server base URL. Defaults to settings.ollama_url.
72
+ model: Default model name. Defaults to settings.llm_model.
73
+ timeout: Request timeout in seconds.
74
+ """
75
+
76
+ def __init__(
77
+ self,
78
+ base_url: str | None = None,
79
+ model: str | None = None,
80
+ timeout: float = 120.0,
81
+ ) -> None:
82
+ self.base_url = (base_url if base_url is not None else settings.ollama_url).rstrip("/")
83
+ self.model = model if model is not None else settings.llm_model
84
+ self.timeout = timeout
85
+ self._client = httpx.AsyncClient(
86
+ base_url=self.base_url,
87
+ timeout=httpx.Timeout(timeout),
88
+ )
89
+
90
+ @_retry_on_connection
91
+ async def generate(
92
+ self,
93
+ prompt: str,
94
+ system_prompt: str = "",
95
+ temperature: float = 0.7,
96
+ max_tokens: int = 2048,
97
+ json_mode: bool = False,
98
+ ) -> LLMResponse:
99
+ """Generate a completion from the Ollama API.
100
+
101
+ Args:
102
+ prompt: The user prompt text.
103
+ system_prompt: Optional system context.
104
+ temperature: Sampling temperature (0.0-1.0).
105
+ max_tokens: Maximum tokens to generate.
106
+ json_mode: When True, request JSON-formatted output.
107
+
108
+ Returns:
109
+ LLMResponse with generated text and metadata.
110
+ """
111
+ payload: dict[str, Any] = {
112
+ "model": self.model,
113
+ "prompt": prompt,
114
+ "stream": False,
115
+ "options": {
116
+ "temperature": temperature,
117
+ "num_predict": max_tokens,
118
+ },
119
+ "keep_alive": settings.ollama_keep_alive,
120
+ }
121
+ if system_prompt:
122
+ payload["system"] = system_prompt
123
+ if json_mode:
124
+ payload["format"] = "json"
125
+
126
+ start = time.perf_counter()
127
+ response = await self._client.post("/api/generate", json=payload)
128
+ elapsed_ms = (time.perf_counter() - start) * 1000
129
+ response.raise_for_status()
130
+
131
+ data = response.json()
132
+ return LLMResponse(
133
+ text=data.get("response", ""),
134
+ model=data.get("model", self.model),
135
+ provider="ollama",
136
+ usage={
137
+ "prompt_tokens": data.get("prompt_eval_count", 0),
138
+ "completion_tokens": data.get("eval_count", 0),
139
+ "total_tokens": (data.get("prompt_eval_count", 0) + data.get("eval_count", 0)),
140
+ },
141
+ latency_ms=elapsed_ms,
142
+ metadata={
143
+ "total_duration": data.get("total_duration"),
144
+ "load_duration": data.get("load_duration"),
145
+ },
146
+ )
147
+
148
+ @_retry_on_connection
149
+ async def chat(
150
+ self,
151
+ messages: list[dict],
152
+ temperature: float = 0.7,
153
+ max_tokens: int = 2048,
154
+ ) -> LLMResponse:
155
+ """Send a chat conversation to the Ollama API.
156
+
157
+ Args:
158
+ messages: List of message dicts with 'role' and 'content' keys.
159
+ Roles: "system", "user", "assistant".
160
+ temperature: Sampling temperature (0.0-1.0).
161
+ max_tokens: Maximum tokens to generate.
162
+
163
+ Returns:
164
+ LLMResponse with generated text and metadata.
165
+ """
166
+ payload: dict[str, Any] = {
167
+ "model": self.model,
168
+ "messages": messages,
169
+ "stream": False,
170
+ "options": {
171
+ "temperature": temperature,
172
+ "num_predict": max_tokens,
173
+ },
174
+ "keep_alive": settings.ollama_keep_alive,
175
+ }
176
+
177
+ start = time.perf_counter()
178
+ response = await self._client.post("/api/chat", json=payload)
179
+ elapsed_ms = (time.perf_counter() - start) * 1000
180
+ response.raise_for_status()
181
+
182
+ data = response.json()
183
+ message = data.get("message", {})
184
+ return LLMResponse(
185
+ text=message.get("content", ""),
186
+ model=data.get("model", self.model),
187
+ provider="ollama",
188
+ usage={
189
+ "prompt_tokens": data.get("prompt_eval_count", 0),
190
+ "completion_tokens": data.get("eval_count", 0),
191
+ "total_tokens": (data.get("prompt_eval_count", 0) + data.get("eval_count", 0)),
192
+ },
193
+ latency_ms=elapsed_ms,
194
+ metadata={
195
+ "total_duration": data.get("total_duration"),
196
+ "load_duration": data.get("load_duration"),
197
+ },
198
+ )
199
+
200
+ async def generate_stream(
201
+ self,
202
+ prompt: str,
203
+ system_prompt: str = "",
204
+ temperature: float = 0.7,
205
+ ) -> AsyncGenerator[str, None]:
206
+ """Stream a completion from the Ollama API, yielding tokens as they arrive.
207
+
208
+ Args:
209
+ prompt: The user prompt text.
210
+ system_prompt: Optional system context.
211
+ temperature: Sampling temperature (0.0-1.0).
212
+
213
+ Yields:
214
+ Token strings as they are generated.
215
+ """
216
+ payload: dict[str, Any] = {
217
+ "model": self.model,
218
+ "prompt": prompt,
219
+ "stream": True,
220
+ "options": {
221
+ "temperature": temperature,
222
+ },
223
+ "keep_alive": settings.ollama_keep_alive,
224
+ }
225
+ if system_prompt:
226
+ payload["system"] = system_prompt
227
+
228
+ async with self._client.stream("POST", "/api/generate", json=payload) as resp:
229
+ resp.raise_for_status()
230
+ async for line in resp.aiter_lines():
231
+ if line:
232
+ import json
233
+
234
+ data = json.loads(line)
235
+ token = data.get("response", "")
236
+ if token:
237
+ yield token
238
+ if data.get("done", False):
239
+ break
240
+
241
+ async def chat_stream(
242
+ self,
243
+ messages: list[dict],
244
+ temperature: float = 0.7,
245
+ ) -> AsyncGenerator[str, None]:
246
+ """Stream a chat completion from the Ollama API, yielding tokens as they arrive.
247
+
248
+ Args:
249
+ messages: List of message dicts with 'role' and 'content' keys.
250
+ temperature: Sampling temperature (0.0-1.0).
251
+
252
+ Yields:
253
+ Token strings as they are generated.
254
+ """
255
+ payload: dict[str, Any] = {
256
+ "model": self.model,
257
+ "messages": messages,
258
+ "stream": True,
259
+ "options": {
260
+ "temperature": temperature,
261
+ },
262
+ "keep_alive": settings.ollama_keep_alive,
263
+ }
264
+
265
+ async with self._client.stream("POST", "/api/chat", json=payload) as resp:
266
+ resp.raise_for_status()
267
+ async for line in resp.aiter_lines():
268
+ if line:
269
+ import json
270
+
271
+ data = json.loads(line)
272
+ message = data.get("message", {})
273
+ token = message.get("content", "")
274
+ if token:
275
+ yield token
276
+ if data.get("done", False):
277
+ break
278
+
279
+ @_retry_on_connection
280
+ async def health_check(self) -> bool:
281
+ """Check if the Ollama server is reachable and responding.
282
+
283
+ Returns:
284
+ True if the server responds with HTTP 200, False otherwise.
285
+ """
286
+ try:
287
+ response = await self._client.get("/api/tags")
288
+ return response.status_code == 200
289
+ except (httpx.ConnectError, httpx.TimeoutException):
290
+ return False
291
+
292
+ @_retry_on_connection
293
+ async def list_models(self) -> list[str]:
294
+ """List all models available on the Ollama server.
295
+
296
+ Returns:
297
+ List of model name strings.
298
+ """
299
+ response = await self._client.get("/api/tags")
300
+ response.raise_for_status()
301
+ data = response.json()
302
+ models = data.get("models", [])
303
+ return [m.get("name", "") for m in models]
304
+
305
+ @_retry_on_connection
306
+ async def get_model_info(self, model: str | None = None) -> dict | None:
307
+ """Get detailed information about a specific model.
308
+
309
+ Args:
310
+ model: Model name to query. Defaults to the client's configured model.
311
+
312
+ Returns:
313
+ Dict with model info, or None if model not found.
314
+ """
315
+ target_model = model or self.model
316
+ try:
317
+ response = await self._client.post("/api/show", json={"name": target_model})
318
+ if response.status_code == 200:
319
+ return response.json()
320
+ return None
321
+ except httpx.HTTPStatusError:
322
+ return None
323
+
324
+ async def close(self) -> None:
325
+ """Close the underlying HTTP client."""
326
+ await self._client.aclose()
327
+
328
+ async def __aenter__(self) -> OllamaClient:
329
+ """Enter async context manager."""
330
+ return self
331
+
332
+ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
333
+ """Exit async context manager, closing the client."""
334
+ await self.close()
interfaces/api.py CHANGED
@@ -1,425 +1,432 @@
1
- """FastAPI surface for SecureAgentRAG.
2
-
3
- Run with::
4
-
5
- uv run uvicorn interfaces.api:app --host 0.0.0.0 --port 8080
6
-
7
- Endpoints
8
- ---------
9
- - ``GET /healthz`` — liveness probe (no auth).
10
- - ``GET /readyz`` — readiness — pings Qdrant + Ollama.
11
- - ``POST /query`` — run the RAG pipeline; returns ``QueryResponse``.
12
- - ``POST /ingest`` — ingest a local file; requires ``user`` role.
13
- - ``GET /audit`` — read paginated audit entries; requires ``admin``.
14
- - ``POST /audit/verify``— verify the hash-chain; requires ``admin``.
15
-
16
- Auth uses a stateless bearer token. The token payload is a base64-encoded JSON
17
- ``UserContext`` so the API has no session store — caller provides identity on
18
- every request. Production deployments should swap this for Keycloak/Auth0 JWT
19
- verification (left as a hook in ``_resolve_user``).
20
- """
21
-
22
- from __future__ import annotations
23
-
24
- import base64
25
- import json
26
- from datetime import date
27
- from typing import Annotated
28
-
29
- from config.settings import settings
30
- from utils.auth import AuthError, issue_token, verify_token
31
- from utils.logging import get_logger
32
-
33
- logger = get_logger(__name__)
34
-
35
- try:
36
- from fastapi import Depends, FastAPI, Header, HTTPException, status
37
- from fastapi.responses import JSONResponse
38
-
39
- _FASTAPI_AVAILABLE = True
40
- except ImportError: # pragma: no cover
41
- _FASTAPI_AVAILABLE = False
42
- Depends = Header = FastAPI = HTTPException = JSONResponse = status = None # type: ignore[assignment]
43
-
44
- if _FASTAPI_AVAILABLE:
45
- from core.graph import run_rag_pipeline
46
- from core.schemas import (
47
- IngestRequestModel,
48
- IngestResponseModel,
49
- QueryRequest,
50
- QueryResponse,
51
- )
52
- from ingestion.metadata import IngestRequest, SensitivityLevel, UserContext
53
- from utils.audit import audit_logger
54
- from utils.health import run_health_checks
55
- from utils.rate_limiter import RateLimiter
56
-
57
- rate_limiter = RateLimiter() # uses default token-bucket config
58
-
59
- _AUTH_ERROR_STATUS: dict[str, int] = {
60
- "missing": status.HTTP_401_UNAUTHORIZED,
61
- "malformed": status.HTTP_401_UNAUTHORIZED,
62
- "expired": status.HTTP_401_UNAUTHORIZED,
63
- "bad_signature": status.HTTP_401_UNAUTHORIZED,
64
- "bad_claims": status.HTTP_403_FORBIDDEN,
65
- }
66
-
67
- def _resolve_user_full(
68
- authorization: Annotated[str | None, Header()] = None,
69
- ) -> tuple[UserContext, dict]:
70
- """Verify the bearer token and return (UserContext, claims).
71
-
72
- Delegates to :func:`utils.auth.verify_token`, which uses HS256 JWT
73
- when ``SAR_JWT_SECRET`` is set and falls back to the legacy unsigned
74
- base64 token otherwise (with a runtime warning).
75
- """
76
- if not authorization or not authorization.lower().startswith("bearer "):
77
- raise HTTPException(status.HTTP_401_UNAUTHORIZED, "missing bearer token")
78
- token = authorization.split(" ", 1)[1]
79
- try:
80
- return verify_token(token)
81
- except AuthError as exc:
82
- code = _AUTH_ERROR_STATUS.get(exc.reason, status.HTTP_401_UNAUTHORIZED)
83
- raise HTTPException(code, f"auth_{exc.reason}: {exc}") from exc
84
-
85
- def _resolve_user(authorization: Annotated[str | None, Header()] = None) -> UserContext:
86
- """Backward-compatible dependency returning only the UserContext."""
87
- ctx, _claims = _resolve_user_full(authorization=authorization)
88
- return ctx
89
-
90
- def _require_role(required: str):
91
- def _dep(user: Annotated[UserContext, Depends(_resolve_user)]) -> UserContext:
92
- if required not in user.roles and "admin" not in user.roles:
93
- raise HTTPException(status.HTTP_403_FORBIDDEN, f"role '{required}' required")
94
- return user
95
-
96
- return _dep
97
-
98
- app = FastAPI(
99
- title="SecureAgentRAG API",
100
- version="0.1.0",
101
- description="Privacy-first multi-agent RAG with RBAC, guardrails, and audit chain.",
102
- )
103
-
104
- # Initialize Phoenix tracing if configured.
105
- # When ``settings.byok_mode`` is on, ``setup_tracing`` short-circuits to
106
- # False regardless of phoenix_endpoint (see utils/observability.py).
107
- from utils.observability import setup_tracing
108
-
109
- _tracing_enabled = setup_tracing()
110
- if _tracing_enabled:
111
- logger.info("phoenix_tracing_active_in_api")
112
-
113
- # ── BYOK CORS middleware ─────────────────────────────────────────────
114
- # Only mount CORS when:
115
- # 1) BYOK mode is on (public demo path), AND
116
- # 2) an explicit allowlist is configured via SAR_CORS_ALLOW_ORIGINS.
117
- # Empty allowlist + BYOK = wildcard would be a footgun (CSRF surface).
118
- # Empty allowlist + dev = no CORS needed (local same-origin).
119
- if settings.byok_mode and settings.cors_allow_origins:
120
- from fastapi.middleware.cors import CORSMiddleware
121
-
122
- app.add_middleware(
123
- CORSMiddleware,
124
- allow_origins=list(settings.cors_allow_origins),
125
- allow_credentials=False, # BYOK never uses cookies
126
- allow_methods=["GET", "POST", "OPTIONS"],
127
- allow_headers=["*"],
128
- )
129
- logger.info("byok_cors_enabled", origins=list(settings.cors_allow_origins))
130
-
131
- @app.get("/healthz", tags=["ops"])
132
- async def healthz() -> dict[str, str]:
133
- return {"status": "ok"}
134
-
135
- @app.get("/readyz", tags=["ops"])
136
- async def readyz() -> JSONResponse:
137
- report = await run_health_checks()
138
- code = 200 if report.overall_healthy else 503
139
- return JSONResponse(report.to_dict(), status_code=code)
140
-
141
- # ── BYOK demo endpoint ───────────────────────────────────────────────
142
- # Mounted only when ``settings.byok_mode`` is on. Bypasses JWT auth and
143
- # uses per-request BYOK credentials instead. Isolation is enforced via
144
- # session-scoped Qdrant collections, not JWT identity.
145
- if settings.byok_mode:
146
- from interfaces.byok import ByokCreds, extract_byok
147
- from utils.rate_limiter import get_owner_key_throttle
148
-
149
- _DEMO_PERSONAS: dict[str, dict] = {
150
- "engineer": {
151
- "org_id": "demo-engineering",
152
- "clearance_level": 2,
153
- "roles": ["engineering"],
154
- },
155
- "compliance": {
156
- "org_id": "demo-compliance",
157
- "clearance_level": 4,
158
- "roles": ["compliance", "legal"],
159
- },
160
- "executive": {
161
- "org_id": "demo-executive",
162
- "clearance_level": 5,
163
- "roles": ["executive", "compliance"],
164
- },
165
- }
166
-
167
- def _persona_to_user_ctx(creds: ByokCreds) -> UserContext:
168
- """Translate ``creds.demo_persona`` into a synthetic UserContext.
169
-
170
- Unknown / missing persona → minimal read-only profile so the demo
171
- still answers but cannot escalate beyond the lowest clearance.
172
- """
173
- preset = _DEMO_PERSONAS.get((creds.demo_persona or "").lower())
174
- if preset is None:
175
- preset = {"org_id": "demo-anon", "clearance_level": 1, "roles": ["viewer"]}
176
- return UserContext(
177
- user_id=f"demo-{creds.session_id}",
178
- org_id=preset["org_id"],
179
- clearance_level=preset["clearance_level"],
180
- roles=preset["roles"],
181
- )
182
-
183
- from pydantic import BaseModel as _ByokBaseModel
184
-
185
- class _ByokChatBody(_ByokBaseModel):
186
- """Public-demo chat payload — no auth fields, only the question text."""
187
-
188
- query: str
189
- prefer_cloud: bool = True
190
-
191
- # Runtime import — FastAPI dependency injection reads the annotation
192
- # at request time, so this must NOT be a TYPE_CHECKING-only import.
193
- from fastapi import Request as _FastApiRequest # noqa: TC002
194
-
195
- @app.post("/byok/chat", tags=["byok"])
196
- async def byok_chat_endpoint(
197
- request: _FastApiRequest,
198
- body: _ByokChatBody,
199
- creds: Annotated[ByokCreds, Depends(extract_byok)],
200
- ) -> dict:
201
- """Public-demo chat endpoint backed by BYOK credentials.
202
-
203
- Routing:
204
- - Visitor brought a key (``creds.has_user_key()``): pipeline uses
205
- the visitor's provider + key. No throttle.
206
- - Visitor did NOT bring a key: pipeline falls back to the owner's
207
- configured cloud provider key, gated by ``OwnerKeyHourThrottle``.
208
- When exhausted, returns 429 with copy nudging BYOK.
209
-
210
- Persona maps to a synthetic ``UserContext`` so the existing RBAC
211
- filter still runs end-to-end — same code path as authenticated
212
- queries, just with demo identities.
213
- """
214
- if not creds.has_user_key():
215
- throttle = get_owner_key_throttle()
216
- client_ip = (request.client.host if request.client else None) or "anon"
217
- ok, meta = throttle.allow(client_ip)
218
- if not ok:
219
- raise HTTPException(
220
- status.HTTP_429_TOO_MANY_REQUESTS,
221
- detail={
222
- "reason": meta["reason"],
223
- "retry_after_seconds": meta["retry_after"],
224
- "hint": (
225
- "Owner-key fallback exhausted for this IP. "
226
- "Paste your own LLM key to continue — your key "
227
- "is never stored server-side."
228
- ),
229
- },
230
- )
231
- user_ctx = _persona_to_user_ctx(creds)
232
- state = await run_rag_pipeline(
233
- query=body.query,
234
- user_context=user_ctx,
235
- thread_id=f"byok-{creds.session_id}",
236
- prefer_cloud=body.prefer_cloud,
237
- # Visitor's chosen provider when present; falls back to env.
238
- override_provider=creds.safe_provider(),
239
- )
240
- response = QueryResponse.from_state(state)
241
- return {
242
- "session_id": creds.session_id,
243
- "persona": creds.demo_persona or "anonymous",
244
- "byok_used": creds.has_user_key(),
245
- "response": response.model_dump(mode="json"),
246
- }
247
-
248
- @app.post("/query", response_model=QueryResponse, tags=["rag"])
249
- async def query_endpoint(
250
- body: QueryRequest,
251
- auth: Annotated[tuple[UserContext, dict], Depends(_resolve_user_full)],
252
- ) -> QueryResponse:
253
- user, claims = auth
254
- if not rate_limiter.is_allowed(f"{user.user_id}:query"):
255
- raise HTTPException(status.HTTP_429_TOO_MANY_REQUESTS, "rate limit exceeded")
256
- # Caller-supplied user_id must match the bearer-token identity.
257
- if body.user_id != user.user_id:
258
- raise HTTPException(status.HTTP_403_FORBIDDEN, "user_id mismatch")
259
- # Use the JWT id so the audit trail can correlate a query with the
260
- # exact token that authorised it; useful for revocation forensics.
261
- jti = claims.get("jti", "unsigned")
262
- state = await run_rag_pipeline(
263
- query=body.query,
264
- user_context=user,
265
- thread_id=f"api-{user.user_id}-{jti}",
266
- prefer_cloud=body.prefer_cloud,
267
- override_provider=body.override_provider,
268
- )
269
- return QueryResponse.from_state(state)
270
-
271
- @app.post("/ingest", response_model=IngestResponseModel, tags=["rag"])
272
- async def ingest_endpoint(
273
- body: IngestRequestModel,
274
- user: Annotated[UserContext, Depends(_require_role("user"))],
275
- ) -> IngestResponseModel:
276
- if body.user_id != user.user_id:
277
- raise HTTPException(status.HTTP_403_FORBIDDEN, "user_id mismatch")
278
- from core.agents.retriever import _get_hybrid_searcher
279
- from ingestion.pipeline import IngestionPipeline
280
-
281
- searcher = _get_hybrid_searcher()
282
- pipeline = IngestionPipeline(
283
- qdrant_manager=searcher._qdrant, # type: ignore[attr-defined]
284
- embedding_service=searcher._embeddings, # type: ignore[attr-defined]
285
- sparse_service=searcher._sparse, # type: ignore[attr-defined]
286
- )
287
- req = IngestRequest(
288
- file_path=body.file_path,
289
- user_id=body.user_id,
290
- org_id=body.org_id,
291
- sensitivity_level=SensitivityLevel(body.sensitivity_level),
292
- roles=body.roles,
293
- )
294
- result = await pipeline.ingest_document(req)
295
- return IngestResponseModel(
296
- file_path=result.file_path,
297
- status=result.status,
298
- num_chunks=result.num_chunks,
299
- point_ids=result.point_ids,
300
- errors=result.errors,
301
- processing_time_seconds=result.processing_time_seconds,
302
- )
303
-
304
- @app.get("/audit", tags=["audit"])
305
- async def audit_list(
306
- user: Annotated[UserContext, Depends(_require_role("admin"))],
307
- start: str | None = None,
308
- end: str | None = None,
309
- limit: int = 100,
310
- ) -> dict:
311
- today = date.today().isoformat()
312
- entries = audit_logger.get_entries(
313
- start_date=start or today,
314
- end_date=end or today,
315
- user_id=None,
316
- action=None,
317
- )
318
- return {
319
- "total": len(entries),
320
- "items": [e.model_dump(mode="json") for e in entries[:limit]],
321
- }
322
-
323
- @app.post("/audit/verify", tags=["audit"])
324
- async def audit_verify(
325
- user: Annotated[UserContext, Depends(_require_role("admin"))],
326
- start: str | None = None,
327
- end: str | None = None,
328
- ) -> dict:
329
- result = audit_logger.verify_chain(start_date=start, end_date=end)
330
- return result
331
-
332
- from pydantic import BaseModel as _PydBM
333
-
334
- class _TokenRequest(_PydBM):
335
- """Identity payload accepted by the dev ``/token`` endpoint."""
336
-
337
- user_id: str
338
- org_id: str = ""
339
- roles: list[str] = []
340
- clearance_level: int = 1
341
- ttl_seconds: int | None = None
342
-
343
- class _TokenResponse(_PydBM):
344
- access_token: str
345
- token_type: str = "bearer"
346
- expires_in: int
347
-
348
- @app.post("/token", response_model=_TokenResponse, tags=["auth"])
349
- async def issue_dev_token(body: _TokenRequest) -> _TokenResponse:
350
- """Mint a signed JWT for local testing.
351
-
352
- In production the IdP (Keycloak / Auth0 / Microsoft Entra) issues the
353
- token externally and this endpoint is removed via the
354
- ``SAR_DISABLE_DEV_TOKEN`` flag — kept here so the e2e smoke script
355
- and the Streamlit demo can mint a real token rather than the
356
- unsigned base64 fallback.
357
- """
358
- if settings.jwt_algorithm.upper() == "RS256":
359
- raise HTTPException(
360
- status.HTTP_404_NOT_FOUND,
361
- "Dev token endpoint disabled in RS256 mode — use the external IdP",
362
- )
363
- if not settings.jwt_secret:
364
- raise HTTPException(
365
- status.HTTP_503_SERVICE_UNAVAILABLE,
366
- "SAR_JWT_SECRET is not configured; token endpoint disabled",
367
- )
368
- try:
369
- token = issue_token(
370
- user_id=body.user_id,
371
- org_id=body.org_id,
372
- roles=body.roles,
373
- clearance_level=body.clearance_level,
374
- ttl_seconds=body.ttl_seconds,
375
- )
376
- except AuthError as exc:
377
- raise HTTPException(
378
- status.HTTP_500_INTERNAL_SERVER_ERROR, f"token_issue_{exc.reason}: {exc}"
379
- ) from exc
380
- return _TokenResponse(
381
- access_token=token,
382
- token_type="bearer",
383
- expires_in=body.ttl_seconds or settings.jwt_ttl_seconds,
384
- )
385
- try:
386
- token = issue_token(
387
- user_id=body.user_id,
388
- org_id=body.org_id,
389
- roles=body.roles,
390
- clearance_level=body.clearance_level,
391
- ttl_seconds=body.ttl_seconds,
392
- )
393
- except AuthError as exc:
394
- raise HTTPException(
395
- status.HTTP_500_INTERNAL_SERVER_ERROR, f"token_issue_{exc.reason}: {exc}"
396
- ) from exc
397
- return _TokenResponse(
398
- access_token=token,
399
- expires_in=body.ttl_seconds or settings.jwt_ttl_seconds,
400
- )
401
-
402
- else: # pragma: no cover
403
- app = None # type: ignore[assignment]
404
-
405
-
406
- def mint_dev_token(user: dict) -> str:
407
- """Convenience for local testing — build a bearer token for a UserContext dict.
408
-
409
- When ``SAR_JWT_SECRET`` is configured this mints a real signed JWT; with
410
- no secret it falls back to the legacy unsigned base64 shape so existing
411
- test fixtures keep working.
412
- """
413
- if settings.jwt_secret:
414
- try:
415
- return issue_token(
416
- user_id=user.get("user_id", ""),
417
- org_id=user.get("org_id", ""),
418
- roles=list(user.get("roles", [])),
419
- clearance_level=int(user.get("clearance_level", 1)),
420
- )
421
- except AuthError:
422
- # Fall through to legacy shape on issuer error.
423
- pass
424
- payload = json.dumps(user).encode("utf-8")
425
- return base64.b64encode(payload).decode("ascii")
 
 
 
 
 
 
 
 
1
+ """FastAPI surface for SecureAgentRAG.
2
+
3
+ Run with::
4
+
5
+ uv run uvicorn interfaces.api:app --host 0.0.0.0 --port 8080
6
+
7
+ Endpoints
8
+ ---------
9
+ - ``GET /healthz`` — liveness probe (no auth).
10
+ - ``GET /readyz`` — readiness — pings Qdrant + Ollama.
11
+ - ``POST /query`` — run the RAG pipeline; returns ``QueryResponse``.
12
+ - ``POST /ingest`` — ingest a local file; requires ``user`` role.
13
+ - ``GET /audit`` — read paginated audit entries; requires ``admin``.
14
+ - ``POST /audit/verify``— verify the hash-chain; requires ``admin``.
15
+
16
+ Auth uses a stateless bearer token: a signed JWT when ``SAR_JWT_SECRET`` is
17
+ set, else the legacy unsigned base64-encoded JSON ``UserContext``. The API has
18
+ no session store — caller provides identity on every request. Production
19
+ deployments should use Keycloak/Auth0 JWT verification (hook: ``_resolve_user``).
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ import base64
25
+ import json
26
+ from datetime import date
27
+ from typing import Annotated
28
+
29
+ from config.settings import settings
30
+ from utils.auth import AuthError, issue_token, verify_token
31
+ from utils.logging import get_logger
32
+
33
+ logger = get_logger(__name__)
34
+
35
+ try:
36
+ from fastapi import Depends, FastAPI, Header, HTTPException, status
37
+ from fastapi.responses import JSONResponse
38
+
39
+ _FASTAPI_AVAILABLE = True
40
+ except ImportError: # pragma: no cover
41
+ _FASTAPI_AVAILABLE = False
42
+ Depends = Header = FastAPI = HTTPException = JSONResponse = status = None # type: ignore[assignment]
43
+
44
+ if _FASTAPI_AVAILABLE:
45
+ from core.graph import run_rag_pipeline
46
+ from core.schemas import (
47
+ IngestRequestModel,
48
+ IngestResponseModel,
49
+ QueryRequest,
50
+ QueryResponse,
51
+ )
52
+ from ingestion.metadata import IngestRequest, SensitivityLevel, UserContext
53
+ from utils.audit import audit_logger
54
+ from utils.health import run_health_checks
55
+ from utils.rate_limiter import RateLimiter
56
+
57
+ rate_limiter = RateLimiter() # uses default token-bucket config
58
+
59
+ _AUTH_ERROR_STATUS: dict[str, int] = {
60
+ "missing": status.HTTP_401_UNAUTHORIZED,
61
+ "malformed": status.HTTP_401_UNAUTHORIZED,
62
+ "expired": status.HTTP_401_UNAUTHORIZED,
63
+ "bad_signature": status.HTTP_401_UNAUTHORIZED,
64
+ "bad_claims": status.HTTP_403_FORBIDDEN,
65
+ }
66
+
67
+ def _resolve_user_full(
68
+ authorization: Annotated[str | None, Header()] = None,
69
+ ) -> tuple[UserContext, dict]:
70
+ """Verify the bearer token and return (UserContext, claims).
71
+
72
+ Delegates to :func:`utils.auth.verify_token`, which uses HS256 JWT
73
+ when ``SAR_JWT_SECRET`` is set and falls back to the legacy unsigned
74
+ base64 token otherwise (with a runtime warning).
75
+ """
76
+ if not authorization or not authorization.lower().startswith("bearer "):
77
+ raise HTTPException(status.HTTP_401_UNAUTHORIZED, "missing bearer token")
78
+ token = authorization.split(" ", 1)[1]
79
+ try:
80
+ return verify_token(token)
81
+ except AuthError as exc:
82
+ code = _AUTH_ERROR_STATUS.get(exc.reason, status.HTTP_401_UNAUTHORIZED)
83
+ raise HTTPException(code, f"auth_{exc.reason}: {exc}") from exc
84
+
85
+ def _resolve_user(authorization: Annotated[str | None, Header()] = None) -> UserContext:
86
+ """Backward-compatible dependency returning only the UserContext."""
87
+ ctx, _claims = _resolve_user_full(authorization=authorization)
88
+ return ctx
89
+
90
+ def _require_role(required: str):
91
+ def _dep(user: Annotated[UserContext, Depends(_resolve_user)]) -> UserContext:
92
+ if required not in user.roles and "admin" not in user.roles:
93
+ raise HTTPException(status.HTTP_403_FORBIDDEN, f"role '{required}' required")
94
+ return user
95
+
96
+ return _dep
97
+
98
+ app = FastAPI(
99
+ title="SecureAgentRAG API",
100
+ version="0.1.0",
101
+ description="Privacy-first multi-agent RAG with RBAC, guardrails, and audit chain.",
102
+ )
103
+
104
+ # Initialize Phoenix tracing if configured.
105
+ # When ``settings.byok_mode`` is on, ``setup_tracing`` short-circuits to
106
+ # False regardless of phoenix_endpoint (see utils/observability.py).
107
+ from utils.observability import setup_tracing
108
+
109
+ _tracing_enabled = setup_tracing()
110
+ if _tracing_enabled:
111
+ logger.info("phoenix_tracing_active_in_api")
112
+
113
+ # ── BYOK CORS middleware ─────────────────────────────────────────────
114
+ # Only mount CORS when:
115
+ # 1) BYOK mode is on (public demo path), AND
116
+ # 2) an explicit allowlist is configured via SAR_CORS_ALLOW_ORIGINS.
117
+ # Empty allowlist + BYOK = wildcard would be a footgun (CSRF surface).
118
+ # Empty allowlist + dev = no CORS needed (local same-origin).
119
+ if settings.byok_mode and settings.cors_allow_origins:
120
+ from fastapi.middleware.cors import CORSMiddleware
121
+
122
+ app.add_middleware(
123
+ CORSMiddleware,
124
+ allow_origins=list(settings.cors_allow_origins),
125
+ allow_credentials=False, # BYOK never uses cookies
126
+ allow_methods=["GET", "POST", "OPTIONS"],
127
+ allow_headers=["*"],
128
+ )
129
+ logger.info("byok_cors_enabled", origins=list(settings.cors_allow_origins))
130
+
131
+ @app.get("/healthz", tags=["ops"])
132
+ async def healthz() -> dict[str, str]:
133
+ return {"status": "ok"}
134
+
135
+ @app.get("/readyz", tags=["ops"])
136
+ async def readyz() -> JSONResponse:
137
+ report = await run_health_checks()
138
+ code = 200 if report.overall_healthy else 503
139
+ return JSONResponse(report.to_dict(), status_code=code)
140
+
141
+ # ── BYOK demo endpoint ───────────────────────────────────────────────
142
+ # Mounted only when ``settings.byok_mode`` is on. Bypasses JWT auth and
143
+ # uses per-request BYOK credentials instead. Isolation is enforced via
144
+ # session-scoped Qdrant collections, not JWT identity.
145
+ if settings.byok_mode:
146
+ from interfaces.byok import ByokCreds, extract_byok
147
+ from utils.rate_limiter import get_owner_key_throttle
148
+
149
+ # All demo personas share ``org_id="demo"`` so they query the same
150
+ # ingested corpus. RBAC differentiation is enforced via clearance
151
+ # level + roles at the payload-filter layer -- exactly the production
152
+ # invariant we want to demonstrate.
153
+ _DEMO_ORG_ID = "demo"
154
+ # Sensitivity levels are LOW=1, MEDIUM=2, HIGH=3 (see
155
+ # ``ingestion/metadata.py::sensitivity_to_int``). Clearance levels must
156
+ # be in the same range so the Qdrant range filter passes the right
157
+ # chunks. Engineer < Compliance == Executive, but executive carries
158
+ # a wider role set (sees both engineering + compliance content).
159
+ _DEMO_PERSONAS: dict[str, dict] = {
160
+ "engineer": {
161
+ "clearance_level": 2,
162
+ "roles": ["engineering"],
163
+ },
164
+ "compliance": {
165
+ "clearance_level": 3,
166
+ "roles": ["compliance", "legal"],
167
+ },
168
+ "executive": {
169
+ "clearance_level": 3,
170
+ "roles": ["executive", "compliance", "engineering"],
171
+ },
172
+ }
173
+
174
+ def _persona_to_user_ctx(creds: ByokCreds) -> UserContext:
175
+ """Translate ``creds.demo_persona`` into a synthetic UserContext.
176
+
177
+ Unknown / missing persona → minimal read-only profile so the demo
178
+ still answers but cannot escalate beyond the lowest clearance.
179
+ """
180
+ preset = _DEMO_PERSONAS.get((creds.demo_persona or "").lower())
181
+ if preset is None:
182
+ preset = {"clearance_level": 1, "roles": ["viewer"]}
183
+ return UserContext(
184
+ user_id=f"demo-{creds.session_id}",
185
+ org_id=_DEMO_ORG_ID,
186
+ clearance_level=preset["clearance_level"],
187
+ roles=preset["roles"],
188
+ )
189
+
190
+ from pydantic import BaseModel as _ByokBaseModel
191
+
192
+ class _ByokChatBody(_ByokBaseModel):
193
+ """Public-demo chat payload — no auth fields, only the question text."""
194
+
195
+ query: str
196
+ prefer_cloud: bool = True
197
+
198
+ # Runtime import — FastAPI dependency injection reads the annotation
199
+ # at request time, so this must NOT be a TYPE_CHECKING-only import.
200
+ from fastapi import Request as _FastApiRequest # noqa: TC002
201
+
202
@app.post("/byok/chat", tags=["byok"])
async def byok_chat_endpoint(
    request: _FastApiRequest,
    body: _ByokChatBody,
    creds: Annotated[ByokCreds, Depends(extract_byok)],
) -> dict:
    """Answer a public-demo question using BYOK credentials.

    Routing:
    - ``creds.has_user_key()`` is true: the owner-key throttle is skipped.
    - Otherwise the request is gated by the owner-key throttle, keyed on
      the client host (``"anon"`` when unknown). A refused request gets a
      429 whose detail carries the reason, the retry delay and a hint to
      bring a key.

    The persona is mapped to a synthetic ``UserContext`` and the query is
    run through ``run_rag_pipeline`` under thread ``byok-<session_id>``.

    Returns:
        Dict with ``session_id``, ``persona`` (``"anonymous"`` when none
        was sent), ``byok_used`` and the JSON-dumped ``QueryResponse``.

    Raises:
        HTTPException: 429 when the owner-key throttle refuses the request.
    """
    # NOTE(review): only the provider name reaches the pipeline below;
    # ``creds.user_key`` is not passed from this function — confirm the
    # visitor's key is picked up elsewhere.
    if not creds.has_user_key():
        peer = request.client
        # NOTE(review): behind a reverse proxy this may be the proxy's
        # address, putting all visitors in one bucket — confirm.
        throttle_key = (peer.host if peer else None) or "anon"
        allowed, verdict = get_owner_key_throttle().allow(throttle_key)
        if not allowed:
            raise HTTPException(
                status.HTTP_429_TOO_MANY_REQUESTS,
                detail={
                    "reason": verdict["reason"],
                    "retry_after_seconds": verdict["retry_after"],
                    "hint": (
                        "Owner-key fallback exhausted for this IP. "
                        "Paste your own LLM key to continue — your key "
                        "is never stored server-side."
                    ),
                },
            )

    pipeline_state = await run_rag_pipeline(
        query=body.query,
        user_context=_persona_to_user_ctx(creds),
        thread_id=f"byok-{creds.session_id}",
        prefer_cloud=body.prefer_cloud,
        # Allow-listed visitor provider, or None when absent/unsupported.
        override_provider=creds.safe_provider(),
    )
    answer = QueryResponse.from_state(pipeline_state)
    return {
        "session_id": creds.session_id,
        "persona": creds.demo_persona or "anonymous",
        "byok_used": creds.has_user_key(),
        "response": answer.model_dump(mode="json"),
    }
254
+
255
@app.post("/query", response_model=QueryResponse, tags=["rag"])
async def query_endpoint(
    body: QueryRequest,
    auth: Annotated[tuple[UserContext, dict], Depends(_resolve_user_full)],
) -> QueryResponse:
    """Run an authenticated RAG query.

    Args:
        body: Query payload; its ``user_id`` must equal the token identity.
        auth: ``(user, claims)`` pair produced by ``_resolve_user_full``.

    Returns:
        The ``QueryResponse`` built from the final pipeline state.

    Raises:
        HTTPException: 429 when the per-user rate limit is hit, 403 when
            ``body.user_id`` differs from the authenticated user.
    """
    caller, token_claims = auth

    rate_key = f"{caller.user_id}:query"
    if not rate_limiter.is_allowed(rate_key):
        raise HTTPException(status.HTTP_429_TOO_MANY_REQUESTS, "rate limit exceeded")

    # The caller may only query as the identity carried by the bearer token.
    if body.user_id != caller.user_id:
        raise HTTPException(status.HTTP_403_FORBIDDEN, "user_id mismatch")

    # Embed the JWT id in the thread id so the audit trail can tie a query
    # to the exact token that authorised it (revocation forensics).
    token_id = token_claims.get("jti", "unsigned")
    return QueryResponse.from_state(
        await run_rag_pipeline(
            query=body.query,
            user_context=caller,
            thread_id=f"api-{caller.user_id}-{token_id}",
            prefer_cloud=body.prefer_cloud,
            override_provider=body.override_provider,
        )
    )
277
+
278
@app.post("/ingest", response_model=IngestResponseModel, tags=["rag"])
async def ingest_endpoint(
    body: IngestRequestModel,
    user: Annotated[UserContext, Depends(_require_role("user"))],
) -> IngestResponseModel:
    """Ingest one document on behalf of the authenticated user.

    Args:
        body: Ingest payload (file path, owner, org, sensitivity, roles).
        user: Authenticated caller holding at least the ``user`` role.

    Returns:
        Summary of the ingestion run copied from the pipeline result.

    Raises:
        HTTPException: 403 when ``body.user_id`` or ``body.org_id``
            contradicts the token identity, 400 when
            ``body.sensitivity_level`` is not a valid ``SensitivityLevel``.
    """
    if body.user_id != user.user_id:
        raise HTTPException(status.HTTP_403_FORBIDDEN, "user_id mismatch")
    # Same rule for the tenant: an explicit org_id that contradicts the
    # token's org would write into another organisation's data. Enforced
    # only when both sides carry a value, so tokens or payloads without an
    # org keep working.
    if body.org_id and user.org_id and body.org_id != user.org_id:
        raise HTTPException(status.HTTP_403_FORBIDDEN, "org_id mismatch")
    try:
        sensitivity = SensitivityLevel(body.sensitivity_level)
    except ValueError as exc:
        # A bad level is a client error, not an internal failure.
        raise HTTPException(
            status.HTTP_400_BAD_REQUEST, "invalid sensitivity_level"
        ) from exc

    from core.agents.retriever import _get_hybrid_searcher
    from ingestion.pipeline import IngestionPipeline

    # Reuse the retriever's already-initialised services for ingestion.
    searcher = _get_hybrid_searcher()
    pipeline = IngestionPipeline(
        qdrant_manager=searcher._qdrant,  # type: ignore[attr-defined]
        embedding_service=searcher._embeddings,  # type: ignore[attr-defined]
        sparse_service=searcher._sparse,  # type: ignore[attr-defined]
    )
    # NOTE(review): ``file_path`` is caller-supplied and passed through
    # unchecked — confirm the pipeline restricts which paths may be read.
    req = IngestRequest(
        file_path=body.file_path,
        user_id=body.user_id,
        org_id=body.org_id,
        sensitivity_level=sensitivity,
        roles=body.roles,
    )
    result = await pipeline.ingest_document(req)
    return IngestResponseModel(
        file_path=result.file_path,
        status=result.status,
        num_chunks=result.num_chunks,
        point_ids=result.point_ids,
        errors=result.errors,
        processing_time_seconds=result.processing_time_seconds,
    )
310
+
311
@app.get("/audit", tags=["audit"])
async def audit_list(
    user: Annotated[UserContext, Depends(_require_role("admin"))],
    start: str | None = None,
    end: str | None = None,
    limit: int = 100,
) -> dict:
    """List audit-log entries for a date range (admin only).

    Args:
        user: Authenticated admin; used only for the role check.
        start: First date to include; defaults to today.
        end: Last date to include; defaults to today.
        limit: Maximum number of entries returned. Negative values are
            treated as 0.

    Returns:
        ``{"total": <all matching entries>, "items": <first ``limit``
        entries, JSON-dumped>}``.
    """
    today = date.today().isoformat()
    entries = audit_logger.get_entries(
        start_date=start or today,
        end_date=end or today,
        user_id=None,
        action=None,
    )
    # A negative slice bound would mean "all but the last n" and silently
    # return nearly the whole log; clamp so ``limit`` is always a cap.
    page_size = max(0, limit)
    return {
        "total": len(entries),
        "items": [e.model_dump(mode="json") for e in entries[:page_size]],
    }
329
+
330
@app.post("/audit/verify", tags=["audit"])
async def audit_verify(
    user: Annotated[UserContext, Depends(_require_role("admin"))],
    start: str | None = None,
    end: str | None = None,
) -> dict:
    """Verify the audit-log chain for an optional date range (admin only).

    Returns whatever ``audit_logger.verify_chain`` reports, unchanged.
    """
    return audit_logger.verify_chain(start_date=start, end_date=end)
338
+
339
+ from pydantic import BaseModel as _PydBM
340
+
341
class _TokenRequest(_PydBM):
    """Identity payload accepted by the dev ``/token`` endpoint.

    ``issue_dev_token`` forwards every field unchanged to ``issue_token``.
    """

    user_id: str
    org_id: str = ""
    # NOTE(review): relies on pydantic copying field defaults per instance,
    # so the shared ``[]`` literal is presumably safe here — confirm.
    roles: list[str] = []
    clearance_level: int = 1
    # When falsy, the handler reports ``settings.jwt_ttl_seconds`` as ``expires_in``.
    ttl_seconds: int | None = None
349
+
350
class _TokenResponse(_PydBM):
    """Response body of the dev ``/token`` endpoint."""

    # The signed JWT returned by ``issue_token``.
    access_token: str
    token_type: str = "bearer"
    # Token lifetime in seconds as reported by ``issue_dev_token``.
    expires_in: int
354
+
355
@app.post("/token", response_model=_TokenResponse, tags=["auth"])
async def issue_dev_token(body: _TokenRequest) -> _TokenResponse:
    """Mint a signed JWT for local testing.

    In production the IdP (Keycloak / Auth0 / Microsoft Entra) issues the
    token externally. This endpoint is kept so the e2e smoke script and
    the Streamlit demo can mint a real token rather than the unsigned
    base64 fallback.

    NOTE(review): the original text says the endpoint is removed via the
    ``SAR_DISABLE_DEV_TOKEN`` flag; this function does not check such a
    flag itself — confirm it is enforced where the route is registered.

    Args:
        body: Identity claims and optional TTL for the token.

    Returns:
        The signed token, its type and its lifetime in seconds.

    Raises:
        HTTPException: 404 in RS256 mode, 503 when no JWT secret is
            configured, 500 when ``issue_token`` raises ``AuthError``.
    """
    if settings.jwt_algorithm.upper() == "RS256":
        raise HTTPException(
            status.HTTP_404_NOT_FOUND,
            "Dev token endpoint disabled in RS256 mode — use the external IdP",
        )
    if not settings.jwt_secret:
        raise HTTPException(
            status.HTTP_503_SERVICE_UNAVAILABLE,
            "SAR_JWT_SECRET is not configured; token endpoint disabled",
        )
    try:
        token = issue_token(
            user_id=body.user_id,
            org_id=body.org_id,
            roles=body.roles,
            clearance_level=body.clearance_level,
            ttl_seconds=body.ttl_seconds,
        )
    except AuthError as exc:
        raise HTTPException(
            status.HTTP_500_INTERNAL_SERVER_ERROR, f"token_issue_{exc.reason}: {exc}"
        ) from exc
    return _TokenResponse(
        access_token=token,
        token_type="bearer",
        # Report the requested TTL, or the configured default when omitted.
        expires_in=body.ttl_seconds or settings.jwt_ttl_seconds,
    )
408
+
409
+ else: # pragma: no cover
410
+ app = None # type: ignore[assignment]
411
+
412
+
413
def mint_dev_token(user: dict) -> str:
    """Build a bearer token for a UserContext-shaped dict (local testing).

    With ``settings.jwt_secret`` configured, a signed JWT is minted from
    the dict's ``user_id``, ``org_id``, ``roles`` and ``clearance_level``
    entries. Without a secret, or when the issuer raises ``AuthError``,
    the legacy unsigned shape is returned instead: the base64-encoded
    JSON dump of ``user``.
    """
    if settings.jwt_secret:
        claims = {
            "user_id": user.get("user_id", ""),
            "org_id": user.get("org_id", ""),
            "roles": list(user.get("roles", [])),
            "clearance_level": int(user.get("clearance_level", 1)),
        }
        try:
            return issue_token(**claims)
        except AuthError:
            # Deliberate best-effort: fall through to the legacy shape.
            pass
    encoded = base64.b64encode(json.dumps(user).encode("utf-8"))
    return encoded.decode("ascii")
interfaces/byok.py CHANGED
@@ -1,166 +1,166 @@
1
- """BYOK (Bring Your Own Key) request extraction for the public demo.
2
-
3
- Mounted on the FastAPI surface only when ``settings.byok_mode=True`` (production
4
- HF Space image). Extracts per-request LLM credentials and session identity from
5
- HTTP headers so the RAG pipeline can route to the visitor's own LLM provider
6
- and Qdrant collection.
7
-
8
- The extracted ``ByokCreds`` is **never persisted**:
9
-
10
- - API keys live only in the request scope (FastAPI dep dies after response)
11
- - ``utils.pii.redact`` strips key-shaped substrings from audit log entries
12
- - The frontend stores the key in ``localStorage`` and forwards it as a header;
13
- cookies are forbidden (CSRF surface).
14
-
15
- See ``launch-plan/03-backend-byok.md`` and ``launch-plan/11-security-checklist.md``.
16
- """
17
-
18
- from __future__ import annotations
19
-
20
- import hashlib
21
- import uuid
22
- from typing import TYPE_CHECKING
23
-
24
- from pydantic import BaseModel, ConfigDict, Field
25
-
26
- if TYPE_CHECKING:
27
- from fastapi import Request
28
-
29
-
30
# Header names the frontend sends. ``extract_byok`` declares the same headers
# through its parameter names rather than through these constants, so the two
# have to be kept in sync by hand.
HDR_USER_KEY = "X-User-LLM-Key"
HDR_USER_PROVIDER = "X-User-Provider"
HDR_USER_OLLAMA_URL = "X-User-Ollama-URL"
HDR_SESSION_ID = "X-Session-ID"
HDR_DEMO_PERSONA = "X-Demo-Persona"

# Supported provider literals carried in X-User-Provider;
# ``ByokCreds.safe_provider`` lower-cases the value before checking it here.
SUPPORTED_PROVIDERS: frozenset[str] = frozenset({"groq", "openai", "anthropic", "ollama"})
39
-
40
-
41
class ByokCreds(BaseModel):
    """Immutable bundle of the BYOK values attached to one request.

    Attributes:
        user_key: The visitor's own LLM provider API key; None selects
            the owner-key fallback (subject to ``OwnerKeyHourThrottle``).
        provider: Provider the ``user_key`` belongs to. Checked against
            ``SUPPORTED_PROVIDERS`` by :meth:`safe_provider`.
        ollama_url: The visitor's Ollama instance URL, meaningful only
            when the provider is ``"ollama"``.
        session_id: Per-visitor session identifier (1-128 characters).
        demo_persona: Optional preset RBAC profile for the public demo —
            ``engineer`` / ``compliance`` / ``executive``.
    """

    model_config = ConfigDict(frozen=True, str_strip_whitespace=True)

    user_key: str | None = None
    provider: str | None = None
    ollama_url: str | None = None
    session_id: str = Field(..., min_length=1, max_length=128)
    demo_persona: str | None = None

    def has_user_key(self) -> bool:
        """Report whether the visitor supplied a non-blank LLM key.

        False means the owner-key fallback applies, which callers must
        route through the per-IP throttle before spending owner quota.
        """
        key = self.user_key
        if key is None:
            return False
        return key.strip() != ""

    def safe_provider(self) -> str | None:
        """Return the lower-cased ``provider`` when allow-listed, else None."""
        candidate = (self.provider or "").lower()
        return candidate if candidate in SUPPORTED_PROVIDERS else None
82
-
83
-
84
- def _derive_session_id(client_host: str | None) -> str:
85
- """Generate a deterministic-but-non-identifying session ID.
86
-
87
- Falls back to a short hash of the client host + a random UUID. The hash
88
- keeps the same session sticky if the visitor reconnects within the same
89
- UVicorn worker; the random UUID ensures cross-worker / cross-restart
90
- isolation. The full UUID flavour stays server-side — we never expose
91
- raw IP addresses in the collection name.
92
- """
93
- host = (client_host or "anon").strip() or "anon"
94
- digest = hashlib.sha256(host.encode("utf-8")).hexdigest()[:8]
95
- random = uuid.uuid4().hex[:8]
96
- return f"{digest}-{random}"
97
-
98
-
99
def build_creds(
    *,
    user_key: str | None,
    provider: str | None,
    ollama_url: str | None,
    session_id: str | None,
    demo_persona: str | None,
    client_host: str | None,
) -> ByokCreds:
    """Pure factory — builds ``ByokCreds`` from raw header values.

    Separated from the FastAPI dependency so it is unit-testable without
    a Request object. Every input is whitespace-trimmed, and a value that
    is missing, empty or whitespace-only becomes None, so "no value" has
    a single representation downstream. ``session_id`` is generated
    server-side when the client did not send a usable one.

    Args:
        user_key: Raw ``X-User-LLM-Key`` header value.
        provider: Raw ``X-User-Provider`` header value.
        ollama_url: Raw ``X-User-Ollama-URL`` header value.
        session_id: Raw ``X-Session-ID`` header value.
        demo_persona: Raw ``X-Demo-Persona`` header value.
        client_host: Client address used to derive a session ID.

    Returns:
        The populated, frozen ``ByokCreds``.
    """

    def _clean(value: str | None) -> str | None:
        # Collapse None, "" and whitespace-only input to None.
        stripped = (value or "").strip()
        return stripped or None

    # NOTE(review): a client-supplied session ID longer than the model's
    # max_length still fails validation here — confirm how callers surface that.
    return ByokCreds(
        user_key=_clean(user_key),
        provider=_clean(provider),
        ollama_url=_clean(ollama_url),
        session_id=_clean(session_id) or _derive_session_id(client_host),
        demo_persona=_clean(demo_persona),
    )
121
-
122
-
123
- # ── FastAPI integration ──────────────────────────────────────────────────────
124
- # Header annotations live in this branch so the module can be imported in
125
- # environments where fastapi is not installed (e.g. lightweight unit tests).
126
-
127
# Import FastAPI if present and record the outcome in ``_FASTAPI_AVAILABLE``;
# the flag decides below whether ``extract_byok`` is defined at all.
try:
    # Runtime imports — FastAPI dependency injection reads annotations at
    # request time, so these must NOT live in a TYPE_CHECKING-only block.
    from fastapi import Header, Request  # noqa: TC002

    _FASTAPI_AVAILABLE = True
except ImportError:  # pragma: no cover
    _FASTAPI_AVAILABLE = False

    # Stand-in so the name ``Header`` still exists without FastAPI.
    def Header(*_a: object, **_kw: object) -> None:  # type: ignore[no-redef]  # noqa: N802 — keep FastAPI's name
        """No-op shim when FastAPI is not installed (lint-only env)."""
        return None
139
-
140
-
141
if _FASTAPI_AVAILABLE:
    from typing import Annotated

    def extract_byok(
        request: Request,
        x_user_llm_key: Annotated[str | None, Header()] = None,
        x_user_provider: Annotated[str | None, Header()] = None,
        x_user_ollama_url: Annotated[str | None, Header()] = None,
        x_session_id: Annotated[str | None, Header()] = None,
        x_demo_persona: Annotated[str | None, Header()] = None,
    ) -> ByokCreds:
        """FastAPI dependency that collects the per-request BYOK headers.

        This only gathers data: the five header values and the client
        host (None when the request has no client) are handed to
        ``build_creds``. Authentication, throttling and routing are left
        to the callers.
        """
        peer = request.client
        header_values = {
            "user_key": x_user_llm_key,
            "provider": x_user_provider,
            "ollama_url": x_user_ollama_url,
            "session_id": x_session_id,
            "demo_persona": x_demo_persona,
        }
        return build_creds(client_host=peer.host if peer else None, **header_values)
 
1
+ """BYOK (Bring Your Own Key) request extraction for the public demo.
2
+
3
+ Mounted on the FastAPI surface only when ``settings.byok_mode=True`` (production
4
+ HF Space image). Extracts per-request LLM credentials and session identity from
5
+ HTTP headers so the RAG pipeline can route to the visitor's own LLM provider
6
+ and Qdrant collection.
7
+
8
+ The extracted ``ByokCreds`` is **never persisted**:
9
+
10
+ - API keys live only in the request scope (FastAPI dep dies after response)
11
+ - ``utils.pii.redact`` strips key-shaped substrings from audit log entries
12
+ - The frontend stores the key in ``localStorage`` and forwards it as a header;
13
+ cookies are forbidden (CSRF surface).
14
+
15
+ See ``launch-plan/03-backend-byok.md`` and ``launch-plan/11-security-checklist.md``.
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import hashlib
21
+ import uuid
22
+ from typing import TYPE_CHECKING
23
+
24
+ from pydantic import BaseModel, ConfigDict, Field
25
+
26
+ if TYPE_CHECKING:
27
+ from fastapi import Request
28
+
29
+
30
# Header names the frontend sends. ``extract_byok`` declares the same headers
# through its parameter names rather than through these constants, so the two
# have to be kept in sync by hand.
HDR_USER_KEY = "X-User-LLM-Key"
HDR_USER_PROVIDER = "X-User-Provider"
HDR_USER_OLLAMA_URL = "X-User-Ollama-URL"
HDR_SESSION_ID = "X-Session-ID"
HDR_DEMO_PERSONA = "X-Demo-Persona"

# Supported provider literals carried in X-User-Provider;
# ``ByokCreds.safe_provider`` lower-cases the value before checking it here.
SUPPORTED_PROVIDERS: frozenset[str] = frozenset({"groq", "openai", "anthropic", "ollama"})
39
+
40
+
41
class ByokCreds(BaseModel):
    """Immutable bundle of the BYOK values attached to one request.

    Attributes:
        user_key: The visitor's own LLM provider API key; None selects
            the owner-key fallback (subject to ``OwnerKeyHourThrottle``).
        provider: Provider the ``user_key`` belongs to. Checked against
            ``SUPPORTED_PROVIDERS`` by :meth:`safe_provider`.
        ollama_url: The visitor's Ollama instance URL, meaningful only
            when the provider is ``"ollama"``.
        session_id: Per-visitor session identifier (1-128 characters).
        demo_persona: Optional preset RBAC profile for the public demo —
            ``engineer`` / ``compliance`` / ``executive``.
    """

    model_config = ConfigDict(frozen=True, str_strip_whitespace=True)

    user_key: str | None = None
    provider: str | None = None
    ollama_url: str | None = None
    session_id: str = Field(..., min_length=1, max_length=128)
    demo_persona: str | None = None

    def has_user_key(self) -> bool:
        """Report whether the visitor supplied a non-blank LLM key.

        False means the owner-key fallback applies, which callers must
        route through the per-IP throttle before spending owner quota.
        """
        key = self.user_key
        if key is None:
            return False
        return key.strip() != ""

    def safe_provider(self) -> str | None:
        """Return the lower-cased ``provider`` when allow-listed, else None."""
        candidate = (self.provider or "").lower()
        return candidate if candidate in SUPPORTED_PROVIDERS else None
82
+
83
+
84
+ def _derive_session_id(client_host: str | None) -> str:
85
+ """Generate a deterministic-but-non-identifying session ID.
86
+
87
+ Falls back to a short hash of the client host + a random UUID. The hash
88
+ keeps the same session sticky if the visitor reconnects within the same
89
+ UVicorn worker; the random UUID ensures cross-worker / cross-restart
90
+ isolation. The full UUID flavour stays server-side — we never expose
91
+ raw IP addresses in the collection name.
92
+ """
93
+ host = (client_host or "anon").strip() or "anon"
94
+ digest = hashlib.sha256(host.encode("utf-8")).hexdigest()[:8]
95
+ random = uuid.uuid4().hex[:8]
96
+ return f"{digest}-{random}"
97
+
98
+
99
def build_creds(
    *,
    user_key: str | None,
    provider: str | None,
    ollama_url: str | None,
    session_id: str | None,
    demo_persona: str | None,
    client_host: str | None,
) -> ByokCreds:
    """Pure factory — builds ``ByokCreds`` from raw header values.

    Separated from the FastAPI dependency so it is unit-testable without
    a Request object. Every input is whitespace-trimmed, and a value that
    is missing, empty or whitespace-only becomes None, so "no value" has
    a single representation downstream. ``session_id`` is generated
    server-side when the client did not send a usable one.

    Args:
        user_key: Raw ``X-User-LLM-Key`` header value.
        provider: Raw ``X-User-Provider`` header value.
        ollama_url: Raw ``X-User-Ollama-URL`` header value.
        session_id: Raw ``X-Session-ID`` header value.
        demo_persona: Raw ``X-Demo-Persona`` header value.
        client_host: Client address used to derive a session ID.

    Returns:
        The populated, frozen ``ByokCreds``.
    """

    def _clean(value: str | None) -> str | None:
        # Collapse None, "" and whitespace-only input to None.
        stripped = (value or "").strip()
        return stripped or None

    # NOTE(review): a client-supplied session ID longer than the model's
    # max_length still fails validation here — confirm how callers surface that.
    return ByokCreds(
        user_key=_clean(user_key),
        provider=_clean(provider),
        ollama_url=_clean(ollama_url),
        session_id=_clean(session_id) or _derive_session_id(client_host),
        demo_persona=_clean(demo_persona),
    )
121
+
122
+
123
+ # ── FastAPI integration ──────────────────────────────────────────────────────
124
+ # Header annotations live in this branch so the module can be imported in
125
+ # environments where fastapi is not installed (e.g. lightweight unit tests).
126
+
127
# Import FastAPI if present and record the outcome in ``_FASTAPI_AVAILABLE``;
# the flag decides below whether ``extract_byok`` is defined at all.
try:
    # Runtime imports — FastAPI dependency injection reads annotations at
    # request time, so these must NOT live in a TYPE_CHECKING-only block.
    from fastapi import Header, Request  # noqa: TC002

    _FASTAPI_AVAILABLE = True
except ImportError:  # pragma: no cover
    _FASTAPI_AVAILABLE = False

    # Stand-in so the name ``Header`` still exists without FastAPI.
    def Header(*_a: object, **_kw: object) -> None:  # type: ignore[no-redef]  # noqa: N802 — keep FastAPI's name
        """No-op shim when FastAPI is not installed (lint-only env)."""
        return None
139
+
140
+
141
if _FASTAPI_AVAILABLE:
    from typing import Annotated

    def extract_byok(
        request: Request,
        x_user_llm_key: Annotated[str | None, Header()] = None,
        x_user_provider: Annotated[str | None, Header()] = None,
        x_user_ollama_url: Annotated[str | None, Header()] = None,
        x_session_id: Annotated[str | None, Header()] = None,
        x_demo_persona: Annotated[str | None, Header()] = None,
    ) -> ByokCreds:
        """FastAPI dependency that collects the per-request BYOK headers.

        This only gathers data: the five header values and the client
        host (None when the request has no client) are handed to
        ``build_creds``. Authentication, throttling and routing are left
        to the callers.
        """
        peer = request.client
        header_values = {
            "user_key": x_user_llm_key,
            "provider": x_user_provider,
            "ollama_url": x_user_ollama_url,
            "session_id": x_session_id,
            "demo_persona": x_demo_persona,
        }
        return build_creds(client_host=peer.host if peer else None, **header_values)
retrieval/multitenancy.py CHANGED
@@ -1,43 +1,43 @@
1
- """Multi-tenancy utilities for Qdrant collection naming."""
2
-
3
- from __future__ import annotations
4
-
5
- from config.settings import settings
6
-
7
-
8
- def _sanitize(s: str) -> str:
9
- """Coerce ``s`` to a Qdrant-safe identifier (alnum + underscore only)."""
10
- return "".join(c if c.isalnum() else "_" for c in s)
11
-
12
-
13
def get_collection_name(
    org_id: str | None = None,
    *,
    session_id: str | None = None,
) -> str:
    """Resolve the Qdrant collection name for an org or a BYOK session.

    The first matching rule wins:

    1. ``settings.byok_mode`` is on and ``session_id`` is given →
       ``"{base}_sess_{sanitized_session}"``.
    2. ``settings.multi_tenant_collections`` is on and ``org_id`` is
       given → ``"{base}_{sanitized_org}"``.
    3. Otherwise → ``settings.qdrant_collection`` unchanged.

    Args:
        org_id: Organisation identifier (multi-tenant mode).
        session_id: Per-visitor session ID (BYOK mode); wins over
            ``org_id`` when BYOK is on and both are supplied.

    Returns:
        The collection name string.
    """
    base_name = settings.qdrant_collection
    # NOTE(review): an org_id beginning with "sess_" would produce a name
    # of the same shape as a session collection — confirm org ids cannot.
    if session_id and settings.byok_mode:
        session_part = _sanitize(session_id)
        return f"{base_name}_sess_{session_part}"
    if org_id and settings.multi_tenant_collections:
        org_part = _sanitize(org_id)
        return f"{base_name}_{org_part}"
    return base_name
 
1
+ """Multi-tenancy utilities for Qdrant collection naming."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from config.settings import settings
6
+
7
+
8
+ def _sanitize(s: str) -> str:
9
+ """Coerce ``s`` to a Qdrant-safe identifier (alnum + underscore only)."""
10
+ return "".join(c if c.isalnum() else "_" for c in s)
11
+
12
+
13
def get_collection_name(
    org_id: str | None = None,
    *,
    session_id: str | None = None,
) -> str:
    """Resolve the Qdrant collection name for an org or a BYOK session.

    The first matching rule wins:

    1. ``settings.byok_mode`` is on and ``session_id`` is given →
       ``"{base}_sess_{sanitized_session}"``.
    2. ``settings.multi_tenant_collections`` is on and ``org_id`` is
       given → ``"{base}_{sanitized_org}"``.
    3. Otherwise → ``settings.qdrant_collection`` unchanged.

    Args:
        org_id: Organisation identifier (multi-tenant mode).
        session_id: Per-visitor session ID (BYOK mode); wins over
            ``org_id`` when BYOK is on and both are supplied.

    Returns:
        The collection name string.
    """
    base_name = settings.qdrant_collection
    # NOTE(review): an org_id beginning with "sess_" would produce a name
    # of the same shape as a session collection — confirm org ids cannot.
    if session_id and settings.byok_mode:
        session_part = _sanitize(session_id)
        return f"{base_name}_sess_{session_part}"
    if org_id and settings.multi_tenant_collections:
        org_part = _sanitize(org_id)
        return f"{base_name}_{org_part}"
    return base_name
retrieval/session_purge.py CHANGED
@@ -1,185 +1,185 @@
1
- """Per-session Qdrant collection purge for BYOK mode.
2
-
3
- In BYOK mode each visitor's uploads land in a collection named
4
- ``documents_sess_<sanitized_session_id>``. Without a cleanup pass these
5
- collections accumulate until the 1 GB Qdrant Cloud free tier fills up.
6
-
7
- This module provides:
8
-
9
- - :func:`purge_expired_sessions` — synchronous, idempotent sweep that
10
- deletes collections whose creation timestamp is older than
11
- ``settings.session_collection_ttl_hours``.
12
- - :func:`schedule_session_purge` — APScheduler hook the FastAPI lifespan
13
- calls so the sweep runs every 6 hours inside the same process. No
14
- separate cron container required.
15
-
16
- The creation timestamp is read from Qdrant's
17
- ``CollectionInfo.config.params.metadata`` (set at create-time by the
18
- ingestion pipeline). Collections without a creation timestamp are treated
19
- as legacy and **skipped** — we never delete data we can't date.
20
-
21
- See ``launch-plan/03-backend-byok.md`` § Session purge cron.
22
- """
23
-
24
- from __future__ import annotations
25
-
26
- from datetime import UTC, datetime, timedelta
27
- from typing import TYPE_CHECKING, Any
28
-
29
- from config.settings import settings
30
- from utils.logging import get_logger
31
-
32
- if TYPE_CHECKING:
33
- from qdrant_client import QdrantClient
34
-
35
- logger = get_logger(__name__)
36
-
37
-
38
# Despite the name this is an infix: ``_session_collection_prefix`` appends it
# to ``settings.qdrant_collection`` to form the prefix that marks a BYOK
# session collection.
SESSION_COLLECTION_PREFIX = "_sess_"
"""Suffix introduced into the collection name by ``get_collection_name`` when
``byok_mode`` is on and a ``session_id`` is supplied. Used here to filter the
purge sweep to BYOK collections only — multi-tenant org collections are NOT
touched."""
43
-
44
-
45
def _session_collection_prefix() -> str:
    """Return the base collection name followed by ``SESSION_COLLECTION_PREFIX``."""
    return "{}{}".format(settings.qdrant_collection, SESSION_COLLECTION_PREFIX)
48
-
49
-
50
def _is_session_collection(name: str) -> bool:
    """Tell whether ``name`` starts with the session-collection prefix."""
    expected_head = _session_collection_prefix()
    return name.startswith(expected_head)
53
-
54
-
55
- def _parse_created_at(meta: dict[str, Any] | None) -> datetime | None:
56
- """Return the collection's recorded creation datetime, or None if missing.
57
-
58
- The ingestion pipeline writes ``created_at`` as an ISO-8601 UTC string into
59
- the collection's metadata payload when first creating a session
60
- collection. Older collections lack the field — those are intentionally
61
- skipped to avoid deleting data we cannot date.
62
- """
63
- if not meta:
64
- return None
65
- raw = meta.get("created_at")
66
- if not raw:
67
- return None
68
- try:
69
- # Accept both ``2026-05-26T13:00:00+00:00`` and trailing ``Z`` forms.
70
- return datetime.fromisoformat(str(raw).replace("Z", "+00:00"))
71
- except (TypeError, ValueError):
72
- logger.warning("session_purge_bad_timestamp", value=str(raw))
73
- return None
74
-
75
-
76
def purge_expired_sessions(
    client: QdrantClient,
    *,
    ttl_hours: int | None = None,
    now: datetime | None = None,
) -> dict[str, Any]:
    """Delete BYOK session collections older than the TTL.

    Only collections whose name passes ``_is_session_collection`` are
    inspected. A collection is deleted when its recorded creation time
    is earlier than ``now - ttl``; one without a usable timestamp is
    skipped, never deleted.

    Args:
        client: Live ``QdrantClient``.
        ttl_hours: Overrides ``settings.session_collection_ttl_hours``.
        now: Overrides the clock for deterministic tests. A naive value
            is interpreted as UTC.

    Returns:
        Summary dict with ``inspected``, ``deleted``, ``skipped``,
        ``errors``, ``deleted_names`` and ``ttl_hours``. The same keys
        are present when listing the collections fails.
    """
    ttl = ttl_hours if ttl_hours is not None else settings.session_collection_ttl_hours
    current = now if now is not None else datetime.now(UTC)
    if current.tzinfo is None:
        # Keep the cut-off timezone-aware so comparisons cannot raise.
        current = current.replace(tzinfo=UTC)
    horizon = current - timedelta(hours=ttl)
    inspected = deleted = skipped = errors = 0
    deleted_names: list[str] = []

    try:
        collections = client.get_collections().collections
    except Exception as exc:
        logger.error("session_purge_list_failed", error=str(exc))
        # Same shape as the normal summary so callers can index it blindly.
        return {
            "inspected": 0,
            "deleted": 0,
            "skipped": 0,
            "errors": 1,
            "deleted_names": [],
            "ttl_hours": ttl,
        }

    for col in collections:
        name = col.name
        if not _is_session_collection(name):
            continue
        inspected += 1
        # One failing collection must not abort the sweep of the others.
        try:
            info = client.get_collection(name)
            meta = getattr(info.config.params, "metadata", None) or {}
            created = _parse_created_at(meta)
            if created is None:
                # Undated -> skip; we don't delete what we can't time-stamp.
                skipped += 1
                continue
            if created.tzinfo is None:
                # A timestamp without an offset is taken to be UTC.
                created = created.replace(tzinfo=UTC)
            if created >= horizon:
                skipped += 1
                continue
            client.delete_collection(name)
            deleted += 1
            deleted_names.append(name)
            logger.info(
                "session_purge_deleted",
                collection=name,
                created_at=created.isoformat(),
                age_hours=round((current - created).total_seconds() / 3600.0, 1),
            )
        except Exception as exc:
            errors += 1
            logger.warning("session_purge_collection_failed", collection=name, error=str(exc))

    summary = {
        "inspected": inspected,
        "deleted": deleted,
        "skipped": skipped,
        "errors": errors,
        "deleted_names": deleted_names,
        "ttl_hours": ttl,
    }
    # The list of names is left out of the summary log line.
    logger.info(
        "session_purge_summary", **{k: v for k, v in summary.items() if k != "deleted_names"}
    )
    return summary
146
-
147
-
148
- # ── FastAPI lifespan hook ────────────────────────────────────────────────────
149
-
150
-
151
def schedule_session_purge(client: QdrantClient, *, interval_hours: int = 6) -> Any | None:
    """Schedule :func:`purge_expired_sessions` to run periodically.

    Nothing is scheduled when ``settings.byok_mode`` is off. When
    APScheduler cannot be imported, one purge runs immediately instead,
    so at least one sweep happens per call.

    Args:
        client: Qdrant client handed to every purge run.
        interval_hours: Hours between two scheduled sweeps.

    Returns:
        The started ``AsyncIOScheduler``, or None when nothing was
        scheduled (BYOK off or APScheduler missing).
    """
    if not settings.byok_mode:
        logger.debug("session_purge_not_scheduled", reason="byok_mode is off")
        return None

    try:
        from apscheduler.schedulers.asyncio import AsyncIOScheduler  # type: ignore[import-not-found]
    except ImportError:
        # Optional dependency absent: sweep once rather than never.
        logger.warning("apscheduler_missing", action="single-shot purge instead")
        purge_expired_sessions(client)
        return None

    job_options = {
        "hours": interval_hours,
        "args": [client],
        "id": "byok-session-purge",
        "replace_existing": True,
    }
    periodic = AsyncIOScheduler()
    periodic.add_job(purge_expired_sessions, "interval", **job_options)
    periodic.start()
    logger.info("session_purge_scheduled", every_hours=interval_hours)
    return periodic
 
1
+ """Per-session Qdrant collection purge for BYOK mode.
2
+
3
+ In BYOK mode each visitor's uploads land in a collection named
4
+ ``documents_sess_<sanitized_session_id>``. Without a cleanup pass these
5
+ collections accumulate until the 1 GB Qdrant Cloud free tier fills up.
6
+
7
+ This module provides:
8
+
9
+ - :func:`purge_expired_sessions` — synchronous, idempotent sweep that
10
+ deletes collections whose creation timestamp is older than
11
+ ``settings.session_collection_ttl_hours``.
12
+ - :func:`schedule_session_purge` — APScheduler hook the FastAPI lifespan
13
+ calls so the sweep runs every 6 hours inside the same process. No
14
+ separate cron container required.
15
+
16
+ The creation timestamp is read from Qdrant's
17
+ ``CollectionInfo.config.params.metadata`` (set at create-time by the
18
+ ingestion pipeline). Collections without a creation timestamp are treated
19
+ as legacy and **skipped** — we never delete data we can't date.
20
+
21
+ See ``launch-plan/03-backend-byok.md`` § Session purge cron.
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ from datetime import UTC, datetime, timedelta
27
+ from typing import TYPE_CHECKING, Any
28
+
29
+ from config.settings import settings
30
+ from utils.logging import get_logger
31
+
32
+ if TYPE_CHECKING:
33
+ from qdrant_client import QdrantClient
34
+
35
+ logger = get_logger(__name__)
36
+
37
+
38
+ SESSION_COLLECTION_PREFIX = "_sess_"
39
+ """Suffix introduced into the collection name by ``get_collection_name`` when
40
+ ``byok_mode`` is on and a ``session_id`` is supplied. Used here to filter the
41
+ purge sweep to BYOK collections only — multi-tenant org collections are NOT
42
+ touched."""
43
+
44
+
45
def _session_collection_prefix() -> str:
    """Build the name prefix shared by every BYOK session collection.

    The prefix is the configured base collection name
    (``settings.qdrant_collection``) followed by
    ``SESSION_COLLECTION_PREFIX``, e.g. ``documents_sess_``.
    """
    base_collection = settings.qdrant_collection
    return f"{base_collection}{SESSION_COLLECTION_PREFIX}"
48
+
49
+
50
def _is_session_collection(name: str) -> bool:
    """Tell whether ``name`` begins with the BYOK session-collection prefix."""
    expected_prefix = _session_collection_prefix()
    return name.startswith(expected_prefix)
53
+
54
+
55
+ def _parse_created_at(meta: dict[str, Any] | None) -> datetime | None:
56
+ """Return the collection's recorded creation datetime, or None if missing.
57
+
58
+ The ingestion pipeline writes ``created_at`` as an ISO-8601 UTC string into
59
+ the collection's metadata payload when first creating a session
60
+ collection. Older collections lack the field — those are intentionally
61
+ skipped to avoid deleting data we cannot date.
62
+ """
63
+ if not meta:
64
+ return None
65
+ raw = meta.get("created_at")
66
+ if not raw:
67
+ return None
68
+ try:
69
+ # Accept both ``2026-05-26T13:00:00+00:00`` and trailing ``Z`` forms.
70
+ return datetime.fromisoformat(str(raw).replace("Z", "+00:00"))
71
+ except (TypeError, ValueError):
72
+ logger.warning("session_purge_bad_timestamp", value=str(raw))
73
+ return None
74
+
75
+
76
def purge_expired_sessions(
    client: QdrantClient,
    *,
    ttl_hours: int | None = None,
    now: datetime | None = None,
) -> dict[str, Any]:
    """Delete BYOK session collections older than the TTL.

    Only collections whose name passes ``_is_session_collection`` are
    inspected. For each one the creation time is read from
    ``info.config.params.metadata``; collections without a parseable
    timestamp are skipped rather than deleted. A failure on one collection
    is counted and logged, and the sweep continues with the next.

    Args:
        client: Live ``QdrantClient`` (cloud or local).
        ttl_hours: Override ``settings.session_collection_ttl_hours``.
        now: Override the clock for deterministic tests. A naive value is
            interpreted as UTC.

    Returns:
        Summary dict with the keys ``inspected``, ``deleted``, ``skipped``,
        ``errors``, ``deleted_names`` and ``ttl_hours``. The same keys are
        present when listing collections fails (``errors`` is then 1).
    """
    ttl = ttl_hours if ttl_hours is not None else settings.session_collection_ttl_hours
    reference = now or datetime.now(UTC)
    if reference.tzinfo is None:
        # Keep the cutoff timezone-aware so it compares cleanly with
        # aware creation timestamps.
        reference = reference.replace(tzinfo=UTC)
    horizon = reference - timedelta(hours=ttl)

    summary: dict[str, Any] = {
        "inspected": 0,
        "deleted": 0,
        "skipped": 0,
        "errors": 0,
        "deleted_names": [],
        "ttl_hours": ttl,
    }

    try:
        collections = client.get_collections().collections
    except Exception as exc:
        logger.error("session_purge_list_failed", error=str(exc))
        summary["errors"] = 1
        return summary

    for col in collections:
        name = col.name
        if not _is_session_collection(name):
            continue
        summary["inspected"] += 1
        try:
            info = client.get_collection(name)
            meta = getattr(info.config.params, "metadata", None) or {}
            created = _parse_created_at(meta)
            if created is None:
                # Undated -> skip; we don't delete what we can't time-stamp.
                summary["skipped"] += 1
                continue
            if created.tzinfo is None:
                # Offset-less timestamps are assumed to be UTC; comparing a
                # naive value with the aware horizon would raise TypeError.
                created = created.replace(tzinfo=UTC)
            if created >= horizon:
                # Still inside the TTL window.
                summary["skipped"] += 1
                continue
            client.delete_collection(name)
            summary["deleted"] += 1
            summary["deleted_names"].append(name)
            logger.info(
                "session_purge_deleted",
                collection=name,
                created_at=created.isoformat(),
                age_hours=round((reference - created).total_seconds() / 3600.0, 1),
            )
        except Exception as exc:
            # Best-effort sweep: one bad collection must not abort the rest.
            summary["errors"] += 1
            logger.warning("session_purge_collection_failed", collection=name, error=str(exc))

    logger.info(
        "session_purge_summary", **{k: v for k, v in summary.items() if k != "deleted_names"}
    )
    return summary
146
+
147
+
148
+ # ── FastAPI lifespan hook ────────────────────────────────────────────────────
149
+
150
+
151
def schedule_session_purge(client: QdrantClient, *, interval_hours: int = 6) -> Any | None:
    """Register a recurring :func:`purge_expired_sessions` job with APScheduler.

    Nothing is scheduled when ``settings.byok_mode`` is off. When APScheduler
    cannot be imported, a single purge is run immediately instead of
    scheduling a recurring one.

    Args:
        client: Qdrant client handed to every purge run.
        interval_hours: Hours between two consecutive purge runs.

    Returns:
        The started ``AsyncIOScheduler``, or None when BYOK mode is off or
        APScheduler is not installed.
    """
    if not settings.byok_mode:
        logger.debug("session_purge_not_scheduled", reason="byok_mode is off")
        return None

    try:
        from apscheduler.schedulers.asyncio import (  # type: ignore[import-not-found]
            AsyncIOScheduler,
        )
    except ImportError:
        # Without the optional dependency, fall back to one sweep right now.
        logger.warning("apscheduler_missing", action="single-shot purge instead")
        purge_expired_sessions(client)
        return None

    job_options = {
        "hours": interval_hours,
        "args": [client],
        "id": "byok-session-purge",
        "replace_existing": True,
    }
    purge_scheduler = AsyncIOScheduler()
    purge_scheduler.add_job(purge_expired_sessions, "interval", **job_options)
    purge_scheduler.start()
    logger.info("session_purge_scheduled", every_hours=interval_hours)
    return purge_scheduler
utils/observability.py CHANGED
@@ -1,252 +1,252 @@
1
- """Observability setup using Arize Phoenix for LLM tracing.
2
-
3
- Provides OpenTelemetry-compatible distributed tracing for LLM calls,
4
- retrieval operations, and LangGraph execution. Gracefully degrades
5
- when Phoenix is not installed or configured.
6
-
7
- Usage:
8
- Call setup_tracing() once at application startup (e.g., in app/main.py).
9
- All trace_* functions will automatically emit spans when tracing is enabled.
10
- """
11
-
12
- from __future__ import annotations
13
-
14
- from config.settings import settings
15
- from utils.logging import get_logger
16
-
17
- _log = get_logger(__name__)
18
-
19
- # Module-level state
20
- _tracer = None
21
- _phoenix_configured = False
22
- _phoenix_project_name: str = settings.app_name
23
-
24
-
25
- def setup_tracing() -> bool:
26
- """Initialize Phoenix tracing if ``settings.phoenix_endpoint`` is set.
27
-
28
- This function is safe to call unconditionally at startup — it will
29
- log a message and return immediately if Phoenix is not configured.
30
- Tracing failures never crash the application.
31
-
32
- Returns:
33
- True if tracing was successfully enabled, False otherwise.
34
- """
35
- global _tracer, _phoenix_configured, _phoenix_project_name
36
-
37
- # BYOK mode mandates: no third-party telemetry sees a request. Phoenix
38
- # spans capture LLM prompts and completions, which would include the
39
- # visitor's keys-in-context and any private text they uploaded. Hard
40
- # disable in BYOK regardless of phoenix_endpoint configuration.
41
- if settings.byok_mode:
42
- _log.info("phoenix_tracing_disabled", reason="BYOK mode forbids external telemetry")
43
- return False
44
-
45
- if not settings.phoenix_endpoint:
46
- _log.info("phoenix_tracing_disabled", reason="No phoenix_endpoint configured")
47
- return False
48
-
49
- try:
50
- from phoenix.otel import register
51
-
52
- tracer_provider = register(
53
- project_name=settings.app_name,
54
- endpoint=settings.phoenix_endpoint,
55
- )
56
-
57
- # Attempt to instrument LLM and retrieval calls
58
- _instrument_providers()
59
-
60
- _phoenix_configured = True
61
- _phoenix_project_name = settings.app_name
62
- _log.info(
63
- "phoenix_tracing_enabled",
64
- endpoint=settings.phoenix_endpoint,
65
- project=settings.app_name,
66
- tracer_provider=str(tracer_provider),
67
- )
68
- return True
69
- except ImportError:
70
- _log.warning(
71
- "phoenix_import_failed",
72
- msg=(
73
- "arize-phoenix not installed; tracing unavailable. "
74
- "Install with: pip install 'arize-phoenix-otel'"
75
- ),
76
- )
77
- return False
78
- except Exception as exc:
79
- _log.error(
80
- "phoenix_tracing_init_error",
81
- error=str(exc),
82
- endpoint=settings.phoenix_endpoint,
83
- )
84
- return False
85
-
86
-
87
- def _instrument_providers() -> None:
88
- """Instrument LLM and retrieval providers with OpenTelemetry.
89
-
90
- Attempts to auto-instrument supported providers. Failures are
91
- logged but never raised — partial instrumentation is acceptable.
92
- """
93
- # Instrument LangChain/LangGraph if available
94
- try:
95
- from openinference.instrumentation.langchain import LangChainInstrumentor
96
-
97
- LangChainInstrumentor().instrument()
98
- _log.info("instrumented_langchain")
99
- except ImportError:
100
- _log.debug(
101
- "langchain_instrumentation_skipped",
102
- reason="openinference-instrumentation-langchain not installed",
103
- )
104
- except Exception as exc:
105
- _log.debug("langchain_instrumentation_error", reason=str(exc))
106
-
107
- # Instrument OpenAI-compatible calls if available
108
- try:
109
- from openinference.instrumentation.openai import OpenAIInstrumentor
110
-
111
- OpenAIInstrumentor().instrument()
112
- _log.info("instrumented_openai")
113
- except ImportError:
114
- _log.debug(
115
- "openai_instrumentation_skipped",
116
- reason="openinference-instrumentation-openai not installed",
117
- )
118
- except Exception as exc:
119
- _log.debug("openai_instrumentation_error", reason=str(exc))
120
-
121
-
122
- def trace_llm_call(
123
- provider: str,
124
- model: str,
125
- prompt: str,
126
- response: str,
127
- latency_ms: float,
128
- tokens: dict[str, int] | None = None,
129
- ) -> None:
130
- """Record a manual trace span for an LLM call.
131
-
132
- Can be used as an explicit trace point when auto-instrumentation
133
- is unavailable or for custom tracking.
134
-
135
- Args:
136
- provider: LLM provider name (e.g., "ollama", "groq").
137
- model: Model identifier used for generation.
138
- prompt: The input prompt text.
139
- response: The generated response text.
140
- latency_ms: Response latency in milliseconds.
141
- tokens: Optional token usage dict with keys like
142
- "prompt_tokens", "completion_tokens", "total_tokens".
143
- """
144
- if not _phoenix_configured:
145
- return
146
-
147
- try:
148
- from opentelemetry import trace
149
-
150
- tracer = trace.get_tracer("secureagentrag.llm")
151
- with tracer.start_as_current_span("llm_call") as span:
152
- span.set_attribute("llm.provider", provider)
153
- span.set_attribute("llm.model", model)
154
- span.set_attribute("llm.prompt_length", len(prompt))
155
- span.set_attribute("llm.response_length", len(response))
156
- span.set_attribute("llm.latency_ms", latency_ms)
157
- if tokens:
158
- for key, value in tokens.items():
159
- span.set_attribute(f"llm.tokens.{key}", value)
160
- except Exception as exc:
161
- _log.debug("trace_llm_call_failed", error=str(exc))
162
-
163
-
164
- def trace_retrieval(
165
- query: str,
166
- num_results: int,
167
- latency_ms: float,
168
- method: str = "hybrid",
169
- ) -> None:
170
- """Record a manual trace span for a retrieval operation.
171
-
172
- Args:
173
- query: The search query string.
174
- num_results: Number of results returned.
175
- latency_ms: Retrieval latency in milliseconds.
176
- method: Retrieval method used ("hybrid", "dense", "bm25").
177
- """
178
- if not _phoenix_configured:
179
- return
180
-
181
- try:
182
- from opentelemetry import trace
183
-
184
- tracer = trace.get_tracer("secureagentrag.retrieval")
185
- with tracer.start_as_current_span("retrieval") as span:
186
- span.set_attribute("retrieval.query_length", len(query))
187
- span.set_attribute("retrieval.num_results", num_results)
188
- span.set_attribute("retrieval.latency_ms", latency_ms)
189
- span.set_attribute("retrieval.method", method)
190
- except Exception as exc:
191
- _log.debug("trace_retrieval_failed", error=str(exc))
192
-
193
-
194
- def trace_graph_execution(
195
- query: str,
196
- nodes_executed: list[str],
197
- total_latency_ms: float,
198
- final_confidence: float,
199
- retries: int = 0,
200
- ) -> None:
201
- """Record a manual trace span for LangGraph pipeline execution.
202
-
203
- Args:
204
- query: The original user query.
205
- nodes_executed: List of graph node names that were executed.
206
- total_latency_ms: Total pipeline execution time in milliseconds.
207
- final_confidence: Final confidence score of the generated answer.
208
- retries: Number of corrective retrieval retries performed.
209
- """
210
- if not _phoenix_configured:
211
- return
212
-
213
- try:
214
- from opentelemetry import trace
215
-
216
- tracer = trace.get_tracer("secureagentrag.graph")
217
- with tracer.start_as_current_span("graph_execution") as span:
218
- span.set_attribute("graph.query_length", len(query))
219
- span.set_attribute("graph.nodes_executed", ",".join(nodes_executed))
220
- span.set_attribute("graph.total_latency_ms", total_latency_ms)
221
- span.set_attribute("graph.confidence", final_confidence)
222
- span.set_attribute("graph.retries", retries)
223
- except Exception as exc:
224
- _log.debug("trace_graph_execution_failed", error=str(exc))
225
-
226
-
227
- def get_trace_url() -> str | None:
228
- """Return the Phoenix dashboard URL if tracing is configured.
229
-
230
- Returns:
231
- Phoenix UI URL string, or None if Phoenix is not configured.
232
- """
233
- if not _phoenix_configured or not settings.phoenix_endpoint:
234
- return None
235
-
236
- # Phoenix UI typically runs on the same host
237
- endpoint = settings.phoenix_endpoint.rstrip("/")
238
- # Replace gRPC/collector port with UI port if needed
239
- if ":4317" in endpoint:
240
- return endpoint.replace(":4317", ":6006")
241
- if ":6006" in endpoint:
242
- return endpoint
243
- return endpoint
244
-
245
-
246
- def is_tracing_enabled() -> bool:
247
- """Check if Phoenix tracing is currently active.
248
-
249
- Returns:
250
- True if tracing was successfully configured, False otherwise.
251
- """
252
- return _phoenix_configured
 
1
+ """Observability setup using Arize Phoenix for LLM tracing.
2
+
3
+ Provides OpenTelemetry-compatible distributed tracing for LLM calls,
4
+ retrieval operations, and LangGraph execution. Gracefully degrades
5
+ when Phoenix is not installed or configured.
6
+
7
+ Usage:
8
+ Call setup_tracing() once at application startup (e.g., in app/main.py).
9
+ All trace_* functions will automatically emit spans when tracing is enabled.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ from config.settings import settings
15
+ from utils.logging import get_logger
16
+
17
+ _log = get_logger(__name__)
18
+
19
+ # Module-level state
20
+ _tracer = None
21
+ _phoenix_configured = False
22
+ _phoenix_project_name: str = settings.app_name
23
+
24
+
25
def setup_tracing() -> bool:
    """Enable Phoenix tracing when it is both permitted and configured.

    Tracing is skipped when ``settings.byok_mode`` is on or when
    ``settings.phoenix_endpoint`` is empty. An import failure or any other
    exception raised while registering is logged and reported as ``False``
    instead of propagating.

    Returns:
        True when Phoenix was registered and the module flagged as
        configured, False otherwise.
    """
    global _tracer, _phoenix_configured, _phoenix_project_name

    # BYOK is checked first: Phoenix spans capture LLM prompts and
    # completions, so external telemetry is hard-disabled in BYOK mode
    # regardless of the endpoint setting.
    skip_reason: str | None = None
    if settings.byok_mode:
        skip_reason = "BYOK mode forbids external telemetry"
    elif not settings.phoenix_endpoint:
        skip_reason = "No phoenix_endpoint configured"
    if skip_reason is not None:
        _log.info("phoenix_tracing_disabled", reason=skip_reason)
        return False

    try:
        from phoenix.otel import register

        provider = register(
            project_name=settings.app_name,
            endpoint=settings.phoenix_endpoint,
        )

        # Best-effort auto-instrumentation of the LLM / retrieval libraries.
        _instrument_providers()

        _phoenix_configured = True
        _phoenix_project_name = settings.app_name
        _log.info(
            "phoenix_tracing_enabled",
            endpoint=settings.phoenix_endpoint,
            project=settings.app_name,
            tracer_provider=str(provider),
        )
    except ImportError:
        _log.warning(
            "phoenix_import_failed",
            msg=(
                "arize-phoenix not installed; tracing unavailable. "
                "Install with: pip install 'arize-phoenix-otel'"
            ),
        )
        return False
    except Exception as exc:
        _log.error(
            "phoenix_tracing_init_error",
            error=str(exc),
            endpoint=settings.phoenix_endpoint,
        )
        return False
    return True
85
+
86
+
87
def _instrument_providers() -> None:
    """Auto-instrument LangChain and OpenAI clients when their hooks exist.

    Each instrumentor is attempted independently. A missing package or a
    failing ``instrument()`` call is logged at debug level and never
    raised, so one provider failing does not prevent the other.
    """
    # LangChain / LangGraph
    try:
        from openinference.instrumentation.langchain import LangChainInstrumentor

        langchain_hook = LangChainInstrumentor()
        langchain_hook.instrument()
        _log.info("instrumented_langchain")
    except ImportError:
        _log.debug(
            "langchain_instrumentation_skipped",
            reason="openinference-instrumentation-langchain not installed",
        )
    except Exception as langchain_error:
        _log.debug("langchain_instrumentation_error", reason=str(langchain_error))

    # OpenAI-compatible clients
    try:
        from openinference.instrumentation.openai import OpenAIInstrumentor

        openai_hook = OpenAIInstrumentor()
        openai_hook.instrument()
        _log.info("instrumented_openai")
    except ImportError:
        _log.debug(
            "openai_instrumentation_skipped",
            reason="openinference-instrumentation-openai not installed",
        )
    except Exception as openai_error:
        _log.debug("openai_instrumentation_error", reason=str(openai_error))
120
+
121
+
122
def trace_llm_call(
    provider: str,
    model: str,
    prompt: str,
    response: str,
    latency_ms: float,
    tokens: dict[str, int] | None = None,
) -> None:
    """Emit an ``llm_call`` span describing one LLM invocation.

    Only the lengths of ``prompt`` and ``response`` are attached to the
    span, not their text. The call is a no-op while tracing is not
    configured, and any failure while emitting is logged at debug level
    instead of being raised.

    Args:
        provider: LLM provider name (e.g., "ollama", "groq").
        model: Model identifier used for generation.
        prompt: The input prompt text.
        response: The generated response text.
        latency_ms: Response latency in milliseconds.
        tokens: Optional token usage mapping; each entry becomes an
            ``llm.tokens.<key>`` attribute.
    """
    if not _phoenix_configured:
        return

    try:
        from opentelemetry import trace

        llm_tracer = trace.get_tracer("secureagentrag.llm")
        with llm_tracer.start_as_current_span("llm_call") as span:
            put = span.set_attribute
            put("llm.provider", provider)
            put("llm.model", model)
            put("llm.prompt_length", len(prompt))
            put("llm.response_length", len(response))
            put("llm.latency_ms", latency_ms)
            for key, value in (tokens or {}).items():
                put(f"llm.tokens.{key}", value)
    except Exception as exc:
        _log.debug("trace_llm_call_failed", error=str(exc))
162
+
163
+
164
def trace_retrieval(
    query: str,
    num_results: int,
    latency_ms: float,
    method: str = "hybrid",
) -> None:
    """Emit a ``retrieval`` span describing one retrieval operation.

    Only the length of ``query`` is attached to the span, not its text.
    The call is a no-op while tracing is not configured, and any failure
    while emitting is logged at debug level instead of being raised.

    Args:
        query: The search query string.
        num_results: Number of results returned.
        latency_ms: Retrieval latency in milliseconds.
        method: Retrieval method used ("hybrid", "dense", "bm25").
    """
    if not _phoenix_configured:
        return

    try:
        from opentelemetry import trace

        retrieval_tracer = trace.get_tracer("secureagentrag.retrieval")
        with retrieval_tracer.start_as_current_span("retrieval") as span:
            put = span.set_attribute
            put("retrieval.query_length", len(query))
            put("retrieval.num_results", num_results)
            put("retrieval.latency_ms", latency_ms)
            put("retrieval.method", method)
    except Exception as exc:
        _log.debug("trace_retrieval_failed", error=str(exc))
192
+
193
+
194
def trace_graph_execution(
    query: str,
    nodes_executed: list[str],
    total_latency_ms: float,
    final_confidence: float,
    retries: int = 0,
) -> None:
    """Emit a ``graph_execution`` span summarising one pipeline run.

    Only the length of ``query`` is attached to the span, and the node
    names are joined into one comma-separated attribute. The call is a
    no-op while tracing is not configured, and any failure while emitting
    is logged at debug level instead of being raised.

    Args:
        query: The original user query.
        nodes_executed: Names of the graph nodes that were executed.
        total_latency_ms: Total pipeline execution time in milliseconds.
        final_confidence: Final confidence score of the generated answer.
        retries: Number of corrective retrieval retries performed.
    """
    if not _phoenix_configured:
        return

    try:
        from opentelemetry import trace

        graph_tracer = trace.get_tracer("secureagentrag.graph")
        with graph_tracer.start_as_current_span("graph_execution") as span:
            put = span.set_attribute
            put("graph.query_length", len(query))
            put("graph.nodes_executed", ",".join(nodes_executed))
            put("graph.total_latency_ms", total_latency_ms)
            put("graph.confidence", final_confidence)
            put("graph.retries", retries)
    except Exception as exc:
        _log.debug("trace_graph_execution_failed", error=str(exc))
225
+
226
+
227
def get_trace_url() -> str | None:
    """Return the Phoenix dashboard URL if tracing is configured.

    The URL is derived from ``settings.phoenix_endpoint`` with trailing
    slashes removed. When the endpoint contains the port ``:4317`` it is
    rewritten to ``:6006``; any other endpoint is returned as-is.

    Returns:
        The derived URL string, or None when tracing is not configured or
        no endpoint is set.
    """
    if not _phoenix_configured or not settings.phoenix_endpoint:
        return None

    endpoint = settings.phoenix_endpoint.rstrip("/")
    if ":4317" in endpoint:
        # Swap the collector port for the UI port.
        return endpoint.replace(":4317", ":6006")
    # Every other endpoint (including one already on :6006) is used unchanged.
    return endpoint
244
+
245
+
246
def is_tracing_enabled() -> bool:
    """Report the module-level flag that records whether Phoenix is configured.

    Returns:
        The current value of ``_phoenix_configured``.
    """
    tracing_active = _phoenix_configured
    return tracing_active
utils/pii.py CHANGED
@@ -1,146 +1,146 @@
1
- """PII redaction for secondary stores (audit log, query cache, conversation history).
2
-
3
- Two strategies:
4
- - Regex-based (always available) — covers email, phone, SSN, credit card,
5
- IBAN, IPv4, URL with credentials.
6
- - Microsoft Presidio (optional dependency) — invoked when installed; higher
7
- recall and language-aware NER for names, locations, organizations.
8
-
9
- This module never sees plaintext from the LLM — it operates only on text that
10
- is about to be persisted to disk. Live prompts and retrieved contexts remain
11
- unmodified so model quality is not affected.
12
- """
13
-
14
- from __future__ import annotations
15
-
16
- import re
17
- from typing import Any
18
-
19
- from config.settings import settings
20
- from utils.logging import get_logger
21
-
22
- logger = get_logger(__name__)
23
-
24
- # Order matters — most specific patterns first so they win against the
25
- # broader phone regex. Provider-specific API-key shapes (added 2026-05-26
26
- # for BYOK mode) live ABOVE the generic ``[API_KEY]`` rule because their
27
- # prefixes are not catchable by the legacy ``(sk|pk|api|key)`` alternation.
28
- _REGEX_PATTERNS: list[tuple[re.Pattern[str], str]] = [
29
- (re.compile(r"https?://[^\s/]+:[^\s/]+@[^\s]+"), "[URL_WITH_CREDS]"),
30
- (re.compile(r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}"), "[EMAIL]"),
31
- (re.compile(r"\b\d{3}-\d{2}-\d{4}\b"), "[SSN]"),
32
- (re.compile(r"\b(?:\d[ -]*?){13,19}\b"), "[CC]"), # Luhn-validated below
33
- (re.compile(r"\b[A-Z]{2}\d{2}[A-Z0-9]{10,30}\b"), "[IBAN]"),
34
- (re.compile(r"\b(?:\d{1,3}\.){3}\d{1,3}\b"), "[IP]"),
35
- # ── BYOK key shapes (P6 production launch) ──────────────────────────
36
- # Anthropic must come BEFORE OpenAI because ``sk-ant-...`` also matches
37
- # the generic ``sk-...`` rule below.
38
- (re.compile(r"\bsk-ant-[A-Za-z0-9_-]{20,}\b"), "[API_KEY]"),
39
- (re.compile(r"\bsk-(?:proj|svcacct)-[A-Za-z0-9_-]{20,}\b"), "[API_KEY]"),
40
- (re.compile(r"\bgsk_[A-Za-z0-9]{40,}\b"), "[API_KEY]"),
41
- (re.compile(r"\bhf_[A-Za-z0-9]{30,}\b"), "[API_KEY]"),
42
- (re.compile(r"\bvcp_[A-Za-z0-9]{20,}\b"), "[API_KEY]"),
43
- # JWT-format database API keys (Qdrant Cloud auth v2). Three dot-separated
44
- # base64url segments — the middle one is always ``eyJ...`` start.
45
- (re.compile(r"\beyJ[A-Za-z0-9_-]+\.eyJ[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+\b"), "[API_KEY]"),
46
- # Qdrant Cloud management keys: ``<uuid>|<token>``.
47
- (
48
- re.compile(
49
- r"\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\|[A-Za-z0-9_-]{20,}\b"
50
- ),
51
- "[API_KEY]",
52
- ),
53
- # Legacy generic — keeps catching ``sk-...`` and ``api_...`` shapes from
54
- # older docs and tests.
55
- (re.compile(r"\b(?:sk|pk|api|key)[-_][A-Za-z0-9_-]{16,}\b", re.IGNORECASE), "[API_KEY]"),
56
- (re.compile(r"\b(?:\+?\d{1,3}[-.\s]?)?\(?\d{2,4}\)?[-.\s]?\d{3,4}[-.\s]?\d{3,4}\b"), "[PHONE]"),
57
- ]
58
-
59
- # Try Presidio for richer detection (names, locations, etc.)
60
- try:
61
- from presidio_analyzer import AnalyzerEngine # type: ignore[import-not-found]
62
- from presidio_anonymizer import AnonymizerEngine # type: ignore[import-not-found]
63
-
64
- _PRESIDIO_AVAILABLE = True
65
- _analyzer: Any = AnalyzerEngine()
66
- _anonymizer: Any = AnonymizerEngine()
67
- except Exception:
68
- _PRESIDIO_AVAILABLE = False
69
- _analyzer = None
70
- _anonymizer = None
71
-
72
-
73
- def _luhn_valid(num: str) -> bool:
74
- """Luhn checksum to filter false-positive credit-card matches."""
75
- digits = [int(c) for c in num if c.isdigit()]
76
- if not (13 <= len(digits) <= 19):
77
- return False
78
- s = 0
79
- for i, d in enumerate(reversed(digits)):
80
- if i % 2 == 1:
81
- d *= 2
82
- if d > 9:
83
- d -= 9
84
- s += d
85
- return s % 10 == 0
86
-
87
-
88
- def redact(text: str) -> str:
89
- """Return ``text`` with PII tokens masked.
90
-
91
- Args:
92
- text: Arbitrary string that may contain PII.
93
-
94
- Returns:
95
- Redacted copy of the text. If redaction is disabled via settings
96
- the original string is returned unchanged.
97
- """
98
- if not settings.pii_redaction_enabled or not text:
99
- return text
100
-
101
- out = text
102
- for pattern, token in _REGEX_PATTERNS:
103
- if token == "[CC]":
104
- # Apply Luhn to avoid over-masking phone numbers / arbitrary digits.
105
- out = pattern.sub(lambda m: "[CC]" if _luhn_valid(m.group(0)) else m.group(0), out)
106
- else:
107
- out = pattern.sub(token, out)
108
-
109
- if _PRESIDIO_AVAILABLE and _analyzer is not None and _anonymizer is not None:
110
- try:
111
- results = _analyzer.analyze(text=out, language="en")
112
- if results:
113
- out = _anonymizer.anonymize(text=out, analyzer_results=results).text
114
- except Exception as exc:
115
- logger.debug("presidio_redact_failed", error=str(exc))
116
-
117
- return out
118
-
119
-
120
- def redact_dict(data: dict[str, Any], fields: tuple[str, ...] | None = None) -> dict[str, Any]:
121
- """Recursively redact string values in a dict.
122
-
123
- Args:
124
- data: Dict (possibly nested) to redact.
125
- fields: If given, only redact these top-level keys. Otherwise redact
126
- every string in the structure.
127
-
128
- Returns:
129
- Deep-redacted copy.
130
- """
131
- if not settings.pii_redaction_enabled:
132
- return data
133
-
134
- def _walk(value: Any, *, force: bool) -> Any:
135
- if isinstance(value, str):
136
- return redact(value) if force else value
137
- if isinstance(value, dict):
138
- return {
139
- k: _walk(v, force=force or (fields is not None and k in fields))
140
- for k, v in value.items()
141
- }
142
- if isinstance(value, list):
143
- return [_walk(v, force=force) for v in value]
144
- return value
145
-
146
- return _walk(data, force=fields is None)
 
1
+ """PII redaction for secondary stores (audit log, query cache, conversation history).
2
+
3
+ Two strategies:
4
+ - Regex-based (always available) — covers email, phone, SSN, credit card,
5
+ IBAN, IPv4, URL with credentials.
6
+ - Microsoft Presidio (optional dependency) — invoked when installed; higher
7
+ recall and language-aware NER for names, locations, organizations.
8
+
9
+ This module never sees plaintext from the LLM — it operates only on text that
10
+ is about to be persisted to disk. Live prompts and retrieved contexts remain
11
+ unmodified so model quality is not affected.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import re
17
+ from typing import Any
18
+
19
+ from config.settings import settings
20
+ from utils.logging import get_logger
21
+
22
+ logger = get_logger(__name__)
23
+
24
+ # Order matters — most specific patterns first so they win against the
25
+ # broader phone regex. Provider-specific API-key shapes (added 2026-05-26
26
+ # for BYOK mode) live ABOVE the generic ``[API_KEY]`` rule because their
27
+ # prefixes are not catchable by the legacy ``(sk|pk|api|key)`` alternation.
28
+ _REGEX_PATTERNS: list[tuple[re.Pattern[str], str]] = [
29
+ (re.compile(r"https?://[^\s/]+:[^\s/]+@[^\s]+"), "[URL_WITH_CREDS]"),
30
+ (re.compile(r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}"), "[EMAIL]"),
31
+ (re.compile(r"\b\d{3}-\d{2}-\d{4}\b"), "[SSN]"),
32
+ (re.compile(r"\b(?:\d[ -]*?){13,19}\b"), "[CC]"), # Luhn-validated below
33
+ (re.compile(r"\b[A-Z]{2}\d{2}[A-Z0-9]{10,30}\b"), "[IBAN]"),
34
+ (re.compile(r"\b(?:\d{1,3}\.){3}\d{1,3}\b"), "[IP]"),
35
+ # ── BYOK key shapes (P6 production launch) ──────────────────────────
36
+ # Anthropic must come BEFORE OpenAI because ``sk-ant-...`` also matches
37
+ # the generic ``sk-...`` rule below.
38
+ (re.compile(r"\bsk-ant-[A-Za-z0-9_-]{20,}\b"), "[API_KEY]"),
39
+ (re.compile(r"\bsk-(?:proj|svcacct)-[A-Za-z0-9_-]{20,}\b"), "[API_KEY]"),
40
+ (re.compile(r"\bgsk_[A-Za-z0-9]{40,}\b"), "[API_KEY]"),
41
+ (re.compile(r"\bhf_[A-Za-z0-9]{30,}\b"), "[API_KEY]"),
42
+ (re.compile(r"\bvcp_[A-Za-z0-9]{20,}\b"), "[API_KEY]"),
43
+ # JWT-format database API keys (Qdrant Cloud auth v2). Three dot-separated
44
+ # base64url segments — the first two (header and payload) both start with ``eyJ``.
45
+ (re.compile(r"\beyJ[A-Za-z0-9_-]+\.eyJ[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+\b"), "[API_KEY]"),
46
+ # Qdrant Cloud management keys: ``<uuid>|<token>``.
47
+ (
48
+ re.compile(
49
+ r"\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\|[A-Za-z0-9_-]{20,}\b"
50
+ ),
51
+ "[API_KEY]",
52
+ ),
53
+ # Legacy generic — keeps catching ``sk-...`` and ``api_...`` shapes from
54
+ # older docs and tests.
55
+ (re.compile(r"\b(?:sk|pk|api|key)[-_][A-Za-z0-9_-]{16,}\b", re.IGNORECASE), "[API_KEY]"),
56
+ (re.compile(r"\b(?:\+?\d{1,3}[-.\s]?)?\(?\d{2,4}\)?[-.\s]?\d{3,4}[-.\s]?\d{3,4}\b"), "[PHONE]"),
57
+ ]
58
+
59
+ # Try Presidio for richer detection (names, locations, etc.)
60
+ try:
61
+ from presidio_analyzer import AnalyzerEngine # type: ignore[import-not-found]
62
+ from presidio_anonymizer import AnonymizerEngine # type: ignore[import-not-found]
63
+
64
+ _PRESIDIO_AVAILABLE = True
65
+ _analyzer: Any = AnalyzerEngine()
66
+ _anonymizer: Any = AnonymizerEngine()
67
+ except Exception:
68
+ _PRESIDIO_AVAILABLE = False
69
+ _analyzer = None
70
+ _anonymizer = None
71
+
72
+
73
+ def _luhn_valid(num: str) -> bool:
74
+ """Luhn checksum to filter false-positive credit-card matches."""
75
+ digits = [int(c) for c in num if c.isdigit()]
76
+ if not (13 <= len(digits) <= 19):
77
+ return False
78
+ s = 0
79
+ for i, d in enumerate(reversed(digits)):
80
+ if i % 2 == 1:
81
+ d *= 2
82
+ if d > 9:
83
+ d -= 9
84
+ s += d
85
+ return s % 10 == 0
86
+
87
+
88
def redact(text: str) -> str:
    """Mask PII tokens in ``text`` and return the result.

    Every pattern in ``_REGEX_PATTERNS`` is applied in list order. Credit
    card candidates are masked only when they pass the Luhn check. When
    Presidio is available its analyzer runs afterwards on the
    regex-redacted text; a Presidio failure is logged at debug level and
    the regex-only result is kept.

    Args:
        text: Arbitrary string that may contain PII.

    Returns:
        The redacted text, or ``text`` unchanged when redaction is disabled
        in settings or ``text`` is empty.
    """
    if not settings.pii_redaction_enabled or not text:
        return text

    def _mask_card(match: re.Match[str]) -> str:
        # Luhn filter: avoids over-masking phone numbers / arbitrary digits.
        candidate = match.group(0)
        return "[CC]" if _luhn_valid(candidate) else candidate

    masked = text
    for pattern, token in _REGEX_PATTERNS:
        replacement = _mask_card if token == "[CC]" else token
        masked = pattern.sub(replacement, masked)

    presidio_ready = _PRESIDIO_AVAILABLE and _analyzer is not None and _anonymizer is not None
    if not presidio_ready:
        return masked

    try:
        findings = _analyzer.analyze(text=masked, language="en")
        if findings:
            masked = _anonymizer.anonymize(text=masked, analyzer_results=findings).text
    except Exception as exc:
        logger.debug("presidio_redact_failed", error=str(exc))

    return masked
118
+
119
+
120
def redact_dict(data: dict[str, Any], fields: tuple[str, ...] | None = None) -> dict[str, Any]:
    """Recursively redact string values in a dict.

    Nested dicts and lists are walked; values of any other type (including
    tuples) are passed through untouched.

    Args:
        data: Dict (possibly nested) to redact.
        fields: If given, only values stored under one of these keys are
            redacted. The key is matched at every nesting depth, and
            everything below a matching key is redacted. If None, every
            string in the structure is redacted.

    Returns:
        A new structure with the selected strings redacted. When redaction
        is disabled in settings, ``data`` itself is returned (not a copy).
    """
    if not settings.pii_redaction_enabled:
        return data

    redact_everything = fields is None

    def _walk(value: Any, *, force: bool) -> Any:
        if isinstance(value, str):
            return redact(value) if force else value
        if isinstance(value, dict):
            # Once a key matches, the whole subtree beneath it is forced.
            return {
                k: _walk(v, force=force or (fields is not None and k in fields))
                for k, v in value.items()
            }
        if isinstance(value, list):
            return [_walk(item, force=force) for item in value]
        return value

    return _walk(data, force=redact_everything)
utils/rate_limiter.py CHANGED
@@ -1,524 +1,524 @@
1
- """Token-bucket rate limiter for API request throttling.
2
-
3
- Provides per-user and per-endpoint rate limiting to prevent abuse and
4
- ensure fair resource allocation. Uses an in-memory token bucket algorithm
5
- with optional Redis backend for distributed deployments.
6
- """
7
-
8
- from __future__ import annotations
9
-
10
- import time
11
- from dataclasses import dataclass, field
12
- from typing import Any
13
-
14
- from utils.logging import get_logger
15
-
16
- logger = get_logger(__name__)
17
-
18
-
19
- @dataclass
20
- class RateLimitConfig:
21
- """Configuration for a rate limit bucket.
22
-
23
- Attributes:
24
- requests_per_minute: Maximum requests allowed per minute.
25
- burst_size: Maximum burst capacity (bucket size).
26
- cooldown_seconds: Seconds to wait after being rate limited.
27
- """
28
-
29
- requests_per_minute: int = 60
30
- burst_size: int = 10
31
- cooldown_seconds: float = 1.0
32
-
33
-
34
- @dataclass
35
- class TokenBucket:
36
- """In-memory token bucket for rate limiting.
37
-
38
- Attributes:
39
- tokens: Current available tokens.
40
- last_update: Timestamp of last token refill.
41
- config: Rate limit configuration.
42
- blocked_until: Timestamp when the bucket is unblocked.
43
- """
44
-
45
- tokens: float = field(default=0.0)
46
- last_update: float = field(default_factory=time.time)
47
- config: RateLimitConfig = field(default_factory=RateLimitConfig)
48
- blocked_until: float = field(default=0.0)
49
-
50
- def _refill(self) -> None:
51
- """Refill tokens based on elapsed time since last update."""
52
- now = time.time()
53
- elapsed = now - self.last_update
54
- # Refill rate: tokens per second
55
- refill_rate = self.config.requests_per_minute / 60.0
56
- self.tokens = min(self.config.burst_size, self.tokens + elapsed * refill_rate)
57
- self.last_update = now
58
-
59
- def consume(self, tokens: float = 1.0) -> tuple[bool, dict[str, Any]]:
60
- """Attempt to consume tokens from the bucket.
61
-
62
- Args:
63
- tokens: Number of tokens to consume (default 1 per request).
64
-
65
- Returns:
66
- Tuple of (allowed, metadata) where metadata contains
67
- remaining tokens, retry_after, etc.
68
- """
69
- now = time.time()
70
-
71
- # Check if currently blocked
72
- if now < self.blocked_until:
73
- retry_after = int(self.blocked_until - now) + 1
74
- return False, {
75
- "allowed": False,
76
- "remaining": 0,
77
- "retry_after": retry_after,
78
- "reason": "cooldown_active",
79
- }
80
-
81
- self._refill()
82
-
83
- if self.tokens >= tokens:
84
- self.tokens -= tokens
85
- remaining = int(self.tokens)
86
- return True, {
87
- "allowed": True,
88
- "remaining": remaining,
89
- "retry_after": 0,
90
- "reason": None,
91
- }
92
-
93
- # Rate limit exceeded — enter cooldown
94
- self.blocked_until = now + self.config.cooldown_seconds
95
- retry_after = int(self.config.cooldown_seconds) + 1
96
- return False, {
97
- "allowed": False,
98
- "remaining": 0,
99
- "retry_after": retry_after,
100
- "reason": "rate_limit_exceeded",
101
- }
102
-
103
-
104
- class RateLimiter:
105
- """Multi-key rate limiter with per-user and per-endpoint tracking.
106
-
107
- Uses in-memory token buckets. For distributed deployments, wrap
108
- with a Redis-backed implementation.
109
-
110
- Args:
111
- default_config: Default rate limit configuration.
112
- """
113
-
114
- def __init__(self, default_config: RateLimitConfig | None = None) -> None:
115
- """Initialize the rate limiter.
116
-
117
- Args:
118
- default_config: Default configuration for new buckets.
119
- """
120
- self._default_config = default_config or RateLimitConfig()
121
- self._buckets: dict[str, TokenBucket] = {}
122
-
123
- def _get_bucket(self, key: str, config: RateLimitConfig | None = None) -> TokenBucket:
124
- """Get or create a token bucket for the given key.
125
-
126
- Args:
127
- key: Unique identifier for the bucket (e.g., user_id + endpoint).
128
- config: Optional custom configuration.
129
-
130
- Returns:
131
- The token bucket for the key.
132
- """
133
- if key not in self._buckets:
134
- self._buckets[key] = TokenBucket(
135
- tokens=config.burst_size if config else self._default_config.burst_size,
136
- config=config or self._default_config,
137
- )
138
- return self._buckets[key]
139
-
140
- def check_rate_limit(
141
- self,
142
- key: str,
143
- tokens: float = 1.0,
144
- config: RateLimitConfig | None = None,
145
- ) -> tuple[bool, dict[str, Any]]:
146
- """Check if a request is within the rate limit.
147
-
148
- Args:
149
- key: Rate limit bucket key (e.g., "user_123:query").
150
- tokens: Tokens to consume.
151
- config: Optional custom config for this key.
152
-
153
- Returns:
154
- Tuple of (allowed, metadata).
155
- """
156
- bucket = self._get_bucket(key, config)
157
- allowed, metadata = bucket.consume(tokens)
158
-
159
- if not allowed:
160
- logger.warning(
161
- "rate_limit_exceeded",
162
- key=key,
163
- retry_after=metadata["retry_after"],
164
- reason=metadata["reason"],
165
- )
166
- else:
167
- logger.debug(
168
- "rate_limit_allowed",
169
- key=key,
170
- remaining=metadata["remaining"],
171
- )
172
-
173
- return allowed, metadata
174
-
175
- def is_allowed(self, key: str, tokens: float = 1.0) -> bool:
176
- """Simple check — returns True if request is allowed.
177
-
178
- Args:
179
- key: Rate limit bucket key.
180
- tokens: Tokens to consume.
181
-
182
- Returns:
183
- True if within rate limit, False otherwise.
184
- """
185
- allowed, _ = self.check_rate_limit(key, tokens)
186
- return allowed
187
-
188
- def get_status(self, key: str) -> dict[str, Any]:
189
- """Get current rate limit status for a key.
190
-
191
- Args:
192
- key: Rate limit bucket key.
193
-
194
- Returns:
195
- Dict with remaining tokens, reset time, etc.
196
- """
197
- bucket = self._buckets.get(key)
198
- if not bucket:
199
- return {
200
- "remaining": self._default_config.burst_size,
201
- "limit": self._default_config.requests_per_minute,
202
- "reset": 0,
203
- }
204
-
205
- bucket._refill()
206
- return {
207
- "remaining": int(bucket.tokens),
208
- "limit": bucket.config.requests_per_minute,
209
- "reset": int(max(0, bucket.blocked_until - time.time())),
210
- }
211
-
212
- def reset(self, key: str) -> None:
213
- """Reset a specific rate limit bucket.
214
-
215
- Args:
216
- key: Bucket key to reset.
217
- """
218
- if key in self._buckets:
219
- del self._buckets[key]
220
- logger.info("rate_limit_reset", key=key)
221
-
222
-
223
- class OwnerKeyHourThrottle:
224
- """Per-IP hourly throttle for the BYOK owner-key fallback.
225
-
226
- Distinct from the request-level :class:`RateLimiter` because the BYOK
227
- semantics are different:
228
-
229
- - Visitors who bring their own LLM key (``ByokCreds.has_user_key()``)
230
- bypass this throttle entirely — they are paying for their own tokens.
231
- - Visitors who do NOT bring a key fall back to the platform owner's
232
- Groq key. This throttle exists to stop a single recruiter or curious
233
- visitor from burning the free-tier 30 RPM / 14,400 RPD budget.
234
-
235
- Bucket window is rolling one hour from the first allowed request in
236
- the window. Sliding-window precision is not needed — three requests an
237
- hour is already conservative. We keep timestamps in a tiny list per IP
238
- and prune entries older than 3600 seconds on each check.
239
- """
240
-
241
- __slots__ = ("_buckets", "_quota_per_hour")
242
-
243
- def __init__(self, quota_per_hour: int) -> None:
244
- if quota_per_hour < 0:
245
- raise ValueError("quota_per_hour must be non-negative")
246
- self._quota_per_hour = quota_per_hour
247
- self._buckets: dict[str, list[float]] = {}
248
-
249
- def allow(self, ip: str, *, now: float | None = None) -> tuple[bool, dict[str, Any]]:
250
- """Return whether ``ip`` may consume one owner-key request.
251
-
252
- Args:
253
- ip: Client IP address (use ``"anon"`` when unavailable so the
254
- fallback path still throttles instead of leaking quota).
255
- now: Optional monotonic clock override for tests.
256
-
257
- Returns:
258
- ``(allowed, meta)`` where ``meta`` carries ``remaining`` and
259
- ``retry_after`` seconds, ready for an HTTP 429 response.
260
- """
261
- t = now if now is not None else time.monotonic()
262
- # Prune entries older than 1h, then count.
263
- bucket = [ts for ts in self._buckets.get(ip, []) if t - ts < 3600.0]
264
- if len(bucket) >= self._quota_per_hour:
265
- # ``retry_after`` defaults to a full window when quota_per_hour=0
266
- # (kill switch) — there is no "oldest entry" to expire.
267
- retry_after = max(1, int(3600.0 - (t - bucket[0])) + 1) if bucket else 3600
268
- self._buckets[ip] = bucket # write pruned list back
269
- return False, {
270
- "allowed": False,
271
- "remaining": 0,
272
- "retry_after": retry_after,
273
- "reason": "owner_key_hourly_quota_exhausted",
274
- }
275
- bucket.append(t)
276
- self._buckets[ip] = bucket
277
- return True, {
278
- "allowed": True,
279
- "remaining": self._quota_per_hour - len(bucket),
280
- "retry_after": 0,
281
- "reason": None,
282
- }
283
-
284
- def reset(self, ip: str) -> None:
285
- """Drop all timestamps for ``ip`` (test/cleanup helper)."""
286
- self._buckets.pop(ip, None)
287
-
288
- def reset_all(self) -> None:
289
- """Drop every bucket — used between test cases to avoid leakage."""
290
- self._buckets.clear()
291
-
292
-
293
- # Module-level singleton — lazy-initialised from settings on first use so
294
- # unit tests that monkey-patch SAR_BYOK_OWNER_QUOTA see the right value.
295
- _owner_key_throttle: OwnerKeyHourThrottle | None = None
296
-
297
-
298
- def get_owner_key_throttle() -> OwnerKeyHourThrottle:
299
- """Return the process-wide owner-key throttle, creating it lazily.
300
-
301
- Reads ``settings.byok_owner_key_quota_per_hour`` at first call. Tests
302
- that need a different quota value should call :func:`reset_owner_key_throttle`
303
- after the monkey-patch.
304
- """
305
- global _owner_key_throttle
306
- if _owner_key_throttle is None:
307
- from config.settings import settings # local import to avoid cycle
308
-
309
- _owner_key_throttle = OwnerKeyHourThrottle(
310
- quota_per_hour=settings.byok_owner_key_quota_per_hour,
311
- )
312
- return _owner_key_throttle
313
-
314
-
315
- def reset_owner_key_throttle() -> None:
316
- """Force the next :func:`get_owner_key_throttle` call to rebuild from settings.
317
-
318
- Test-only hook; production code never calls this.
319
- """
320
- global _owner_key_throttle
321
- _owner_key_throttle = None
322
-
323
-
324
- class RedisRateLimiter:
325
- """Distributed rate limiter backed by Redis.
326
-
327
- Uses Redis sorted sets with sliding window algorithm for accurate
328
- per-user rate limiting across multiple application instances.
329
-
330
- Args:
331
- redis_url: Redis connection URL.
332
- default_config: Default rate limit configuration.
333
- """
334
-
335
- def __init__(
336
- self,
337
- redis_url: str | None = None,
338
- default_config: RateLimitConfig | None = None,
339
- ) -> None:
340
- """Initialize the Redis rate limiter.
341
-
342
- Args:
343
- redis_url: Redis connection URL. Falls back to settings.
344
- default_config: Default configuration for new keys.
345
- """
346
- import redis
347
-
348
- from config.settings import settings
349
-
350
- self._redis = redis.from_url(redis_url or settings.redis_url)
351
- self._default_config = default_config or RateLimitConfig()
352
-
353
- def check_rate_limit(
354
- self,
355
- key: str,
356
- tokens: float = 1.0,
357
- config: RateLimitConfig | None = None,
358
- ) -> tuple[bool, dict[str, Any]]:
359
- """Check if a request is within the rate limit using Redis.
360
-
361
- Uses a sliding window algorithm based on Redis sorted sets.
362
-
363
- Args:
364
- key: Rate limit bucket key.
365
- tokens: Tokens to consume.
366
- config: Optional custom config.
367
-
368
- Returns:
369
- Tuple of (allowed, metadata).
370
- """
371
- cfg = config or self._default_config
372
- now = time.time()
373
- window_start = now - 60.0 # 1-minute window
374
- redis_key = f"ratelimit:{key}"
375
-
376
- # Remove old entries outside the window
377
- self._redis.zremrangebyscore(redis_key, 0, window_start)
378
-
379
- # Count current requests in window
380
- current_count = self._redis.zcard(redis_key)
381
-
382
- # Check burst limit
383
- if current_count >= cfg.burst_size:
384
- retry_after = int(cfg.cooldown_seconds) + 1
385
- return False, {
386
- "allowed": False,
387
- "remaining": 0,
388
- "retry_after": retry_after,
389
- "reason": "rate_limit_exceeded",
390
- }
391
-
392
- # Check per-minute rate
393
- rpm_limit = cfg.requests_per_minute
394
- if current_count >= rpm_limit:
395
- retry_after = int(60 - (now % 60)) + 1
396
- return False, {
397
- "allowed": False,
398
- "remaining": 0,
399
- "retry_after": retry_after,
400
- "reason": "rate_limit_exceeded",
401
- }
402
-
403
- # Record this request
404
- self._redis.zadd(redis_key, {str(now): now})
405
- # Set expiry on the key
406
- self._redis.expire(redis_key, 120)
407
-
408
- remaining = min(cfg.burst_size, rpm_limit) - current_count - 1
409
- return True, {
410
- "allowed": True,
411
- "remaining": max(0, remaining),
412
- "retry_after": 0,
413
- "reason": None,
414
- }
415
-
416
- def is_allowed(self, key: str, tokens: float = 1.0) -> bool:
417
- """Simple check — returns True if request is allowed.
418
-
419
- Args:
420
- key: Rate limit bucket key.
421
- tokens: Tokens to consume.
422
-
423
- Returns:
424
- True if within rate limit, False otherwise.
425
- """
426
- allowed, _ = self.check_rate_limit(key, tokens)
427
- return allowed
428
-
429
- def get_status(self, key: str) -> dict[str, Any]:
430
- """Get current rate limit status for a key.
431
-
432
- Args:
433
- key: Rate limit bucket key.
434
-
435
- Returns:
436
- Dict with remaining tokens, limit, and reset time.
437
- """
438
- redis_key = f"ratelimit:{key}"
439
- now = time.time()
440
- window_start = now - 60.0
441
- self._redis.zremrangebyscore(redis_key, 0, window_start)
442
- current_count = self._redis.zcard(redis_key)
443
- remaining = max(0, self._default_config.burst_size - current_count)
444
- return {
445
- "remaining": remaining,
446
- "limit": self._default_config.requests_per_minute,
447
- "reset": int(60 - (now % 60)),
448
- }
449
-
450
- def reset(self, key: str) -> None:
451
- """Reset a specific rate limit bucket.
452
-
453
- Args:
454
- key: Bucket key to reset.
455
- """
456
- self._redis.delete(f"ratelimit:{key}")
457
- logger.info("redis_rate_limit_reset", key=key)
458
-
459
-
460
- def _get_rate_limiter() -> RateLimiter | RedisRateLimiter:
461
- """Get the appropriate rate limiter based on configuration.
462
-
463
- Returns:
464
- RateLimiter (in-memory) or RedisRateLimiter (distributed).
465
- """
466
- from config.settings import settings
467
-
468
- if settings.use_redis_rate_limiter:
469
- try:
470
- return RedisRateLimiter()
471
- except Exception as exc:
472
- logger.warning("redis_rate_limiter_failed", error=str(exc), fallback="memory")
473
- return RateLimiter(default_config=RATE_LIMIT_PROFILES["default"])
474
-
475
-
476
- # Pre-configured rate limit profiles
477
- RATE_LIMIT_PROFILES: dict[str, RateLimitConfig] = {
478
- "default": RateLimitConfig(requests_per_minute=60, burst_size=10),
479
- "strict": RateLimitConfig(requests_per_minute=10, burst_size=3, cooldown_seconds=5.0),
480
- "generous": RateLimitConfig(requests_per_minute=300, burst_size=50),
481
- "upload": RateLimitConfig(requests_per_minute=5, burst_size=2, cooldown_seconds=10.0),
482
- "query": RateLimitConfig(requests_per_minute=30, burst_size=5, cooldown_seconds=2.0),
483
- }
484
-
485
- # Module-level singleton (lazy initialization)
486
- _rate_limiter_instance: RateLimiter | RedisRateLimiter | None = None
487
-
488
-
489
- def _get_limiter() -> RateLimiter | RedisRateLimiter:
490
- """Get the singleton rate limiter instance.
491
-
492
- Returns:
493
- The configured rate limiter.
494
- """
495
- global _rate_limiter_instance
496
- if _rate_limiter_instance is None:
497
- _rate_limiter_instance = _get_rate_limiter()
498
- return _rate_limiter_instance
499
-
500
-
501
- def check_query_rate_limit(user_id: str) -> tuple[bool, dict[str, Any]]:
502
- """Check rate limit for a user query.
503
-
504
- Args:
505
- user_id: The user making the query.
506
-
507
- Returns:
508
- Tuple of (allowed, metadata).
509
- """
510
- key = f"{user_id}:query"
511
- return _get_limiter().check_rate_limit(key, config=RATE_LIMIT_PROFILES["query"])
512
-
513
-
514
- def check_upload_rate_limit(user_id: str) -> tuple[bool, dict[str, Any]]:
515
- """Check rate limit for a document upload.
516
-
517
- Args:
518
- user_id: The user uploading.
519
-
520
- Returns:
521
- Tuple of (allowed, metadata).
522
- """
523
- key = f"{user_id}:upload"
524
- return _get_limiter().check_rate_limit(key, config=RATE_LIMIT_PROFILES["upload"])
 
1
+ """Token-bucket rate limiter for API request throttling.
2
+
3
+ Provides per-user and per-endpoint rate limiting to prevent abuse and
4
+ ensure fair resource allocation. Uses an in-memory token bucket algorithm
5
+ with optional Redis backend for distributed deployments.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import time
11
+ from dataclasses import dataclass, field
12
+ from typing import Any
13
+
14
+ from utils.logging import get_logger
15
+
16
+ logger = get_logger(__name__)
17
+
18
+
19
@dataclass
class RateLimitConfig:
    """Tunable parameters for a single rate-limit bucket.

    Attributes:
        requests_per_minute: Sustained request allowance per minute.
        burst_size: Upper bound on requests that may be served back to back
            (the bucket capacity).
        cooldown_seconds: How long a caller stays blocked after a request has
            been rejected.
    """

    requests_per_minute: int = 60
    burst_size: int = 10
    cooldown_seconds: float = 1.0
32
+
33
+
34
@dataclass
class TokenBucket:
    """In-memory token bucket for rate limiting.

    Tokens are replenished continuously at ``config.requests_per_minute / 60``
    tokens per second, capped at ``config.burst_size``. A rejected request puts
    the bucket into a cooldown during which every call is refused.

    Attributes:
        tokens: Current available tokens.
        last_update: Wall-clock timestamp (``time.time()``) of the last refill.
        config: Rate limit configuration.
        blocked_until: Wall-clock timestamp until which the bucket is blocked.
    """

    tokens: float = field(default=0.0)
    last_update: float = field(default_factory=time.time)
    config: RateLimitConfig = field(default_factory=RateLimitConfig)
    blocked_until: float = field(default=0.0)

    def _refill(self) -> None:
        """Add the tokens earned since ``last_update``, capped at the burst size."""
        now = time.time()
        # time.time() is wall-clock time and may step backwards (clock
        # adjustment). Clamp at zero so a backwards step never removes tokens.
        elapsed = max(0.0, now - self.last_update)
        refill_rate = self.config.requests_per_minute / 60.0  # tokens per second
        self.tokens = min(self.config.burst_size, self.tokens + elapsed * refill_rate)
        self.last_update = now

    def consume(self, tokens: float = 1.0) -> tuple[bool, dict[str, Any]]:
        """Attempt to consume tokens from the bucket.

        Args:
            tokens: Number of tokens to consume (default 1 per request).

        Returns:
            Tuple of ``(allowed, metadata)``. ``metadata`` has the keys
            ``allowed``, ``remaining`` (whole tokens left), ``retry_after``
            (seconds, ``0`` when allowed) and ``reason`` (``None`` when
            allowed, otherwise ``"cooldown_active"`` or
            ``"rate_limit_exceeded"``).
        """
        now = time.time()

        # While a cooldown is active the request is refused without refilling.
        if now < self.blocked_until:
            retry_after = int(self.blocked_until - now) + 1
            return False, {
                "allowed": False,
                "remaining": 0,
                "retry_after": retry_after,
                "reason": "cooldown_active",
            }

        self._refill()

        if self.tokens >= tokens:
            self.tokens -= tokens
            return True, {
                "allowed": True,
                "remaining": int(self.tokens),
                "retry_after": 0,
                "reason": None,
            }

        # Not enough tokens: start the cooldown so further calls are refused
        # until it has elapsed.
        self.blocked_until = now + self.config.cooldown_seconds
        retry_after = int(self.config.cooldown_seconds) + 1
        return False, {
            "allowed": False,
            "remaining": 0,
            "retry_after": retry_after,
            "reason": "rate_limit_exceeded",
        }
102
+
103
+
104
class RateLimiter:
    """In-memory rate limiter that keeps one token bucket per key.

    Keys are free-form strings (for example ``"user_123:query"``), which lets
    callers track limits per user and per endpoint. For distributed
    deployments, use a Redis-backed implementation instead.

    Args:
        default_config: Default rate limit configuration.
    """

    def __init__(self, default_config: RateLimitConfig | None = None) -> None:
        """Create a limiter with no buckets.

        Args:
            default_config: Configuration applied to buckets that are created
                without an explicit config.
        """
        self._default_config = default_config or RateLimitConfig()
        self._buckets: dict[str, TokenBucket] = {}

    def _get_bucket(self, key: str, config: RateLimitConfig | None = None) -> TokenBucket:
        """Return the bucket for ``key``, creating a full one on first use.

        Args:
            key: Unique identifier for the bucket (e.g., user_id + endpoint).
            config: Configuration for a newly created bucket. It has no effect
                when the bucket already exists.

        Returns:
            The token bucket for the key.
        """
        bucket = self._buckets.get(key)
        if bucket is None:
            effective = config or self._default_config
            # A new bucket starts full so the first burst is served at once.
            bucket = TokenBucket(tokens=effective.burst_size, config=effective)
            self._buckets[key] = bucket
        return bucket

    def check_rate_limit(
        self,
        key: str,
        tokens: float = 1.0,
        config: RateLimitConfig | None = None,
    ) -> tuple[bool, dict[str, Any]]:
        """Consume tokens for ``key`` and report whether the request may proceed.

        Args:
            key: Rate limit bucket key (e.g., "user_123:query").
            tokens: Tokens to consume.
            config: Optional custom config, used only if the bucket is created
                by this call.

        Returns:
            Tuple of (allowed, metadata).
        """
        allowed, metadata = self._get_bucket(key, config).consume(tokens)

        if allowed:
            logger.debug(
                "rate_limit_allowed",
                key=key,
                remaining=metadata["remaining"],
            )
        else:
            logger.warning(
                "rate_limit_exceeded",
                key=key,
                retry_after=metadata["retry_after"],
                reason=metadata["reason"],
            )

        return allowed, metadata

    def is_allowed(self, key: str, tokens: float = 1.0) -> bool:
        """Boolean-only variant of :meth:`check_rate_limit`.

        Args:
            key: Rate limit bucket key.
            tokens: Tokens to consume.

        Returns:
            True if within rate limit, False otherwise.
        """
        verdict, _metadata = self.check_rate_limit(key, tokens)
        return verdict

    def get_status(self, key: str) -> dict[str, Any]:
        """Report the current allowance for ``key`` without consuming tokens.

        Args:
            key: Rate limit bucket key.

        Returns:
            Dict with ``remaining`` (whole tokens), ``limit`` (requests per
            minute) and ``reset`` (seconds of cooldown left). Keys without a
            bucket report the default configuration.
        """
        bucket = self._buckets.get(key)
        if bucket:
            # Refill first so the reported allowance reflects elapsed time.
            bucket._refill()
            cooldown_left = max(0, bucket.blocked_until - time.time())
            return {
                "remaining": int(bucket.tokens),
                "limit": bucket.config.requests_per_minute,
                "reset": int(cooldown_left),
            }

        defaults = self._default_config
        return {
            "remaining": defaults.burst_size,
            "limit": defaults.requests_per_minute,
            "reset": 0,
        }

    def reset(self, key: str) -> None:
        """Forget the bucket for ``key``; unknown keys are ignored silently.

        Args:
            key: Bucket key to reset.
        """
        try:
            del self._buckets[key]
        except KeyError:
            return
        logger.info("rate_limit_reset", key=key)
221
+
222
+
223
class OwnerKeyHourThrottle:
    """Per-IP hourly throttle for the BYOK owner-key fallback.

    Distinct from the request-level :class:`RateLimiter` because the BYOK
    semantics are different:

    - Visitors who bring their own LLM key (``ByokCreds.has_user_key()``)
      bypass this throttle entirely — they are paying for their own tokens.
    - Visitors who do NOT bring a key fall back to the platform owner's
      Groq key. This throttle exists to stop a single recruiter or curious
      visitor from burning the free-tier 30 RPM / 14,400 RPD budget.

    Implementation: a timestamp log per IP. Each allowed request is recorded
    and stops counting 3600 seconds after it was made; expired entries are
    pruned whenever that IP is checked again. IPs with no live entries are not
    kept in the map.
    """

    __slots__ = ("_buckets", "_quota_per_hour")

    # Length of the throttling window in seconds.
    _WINDOW_SECONDS = 3600.0

    def __init__(self, quota_per_hour: int) -> None:
        """Create a throttle allowing ``quota_per_hour`` requests per IP per hour.

        Args:
            quota_per_hour: Requests allowed per IP within one window. ``0``
                acts as a kill switch: every request is denied.

        Raises:
            ValueError: If ``quota_per_hour`` is negative.
        """
        if quota_per_hour < 0:
            raise ValueError("quota_per_hour must be non-negative")
        self._quota_per_hour = quota_per_hour
        self._buckets: dict[str, list[float]] = {}

    def allow(self, ip: str, *, now: float | None = None) -> tuple[bool, dict[str, Any]]:
        """Return whether ``ip`` may consume one owner-key request.

        Args:
            ip: Client IP address (use ``"anon"`` when unavailable so the
                fallback path still throttles instead of leaking quota).
            now: Optional clock override for tests. Must be on the same
                timeline as ``time.monotonic()``, which is used when omitted.

        Returns:
            ``(allowed, meta)`` where ``meta`` carries ``allowed``,
            ``remaining``, ``retry_after`` (seconds) and ``reason``, ready for
            an HTTP 429 response.
        """
        t = now if now is not None else time.monotonic()
        window = self._WINDOW_SECONDS

        # Keep only the requests that still count against the quota.
        bucket = [ts for ts in self._buckets.get(ip, ()) if t - ts < window]

        if len(bucket) >= self._quota_per_hour:
            if bucket:
                self._buckets[ip] = bucket  # write the pruned list back
                # The next slot frees when the oldest recorded request expires.
                retry_after = max(1, int(window - (t - bucket[0])) + 1)
            else:
                # Kill switch (quota 0): there is no oldest entry to wait for,
                # so report a full window. Nothing is stored for this IP, so
                # denied callers do not accumulate empty entries.
                self._buckets.pop(ip, None)
                retry_after = int(window)
            return False, {
                "allowed": False,
                "remaining": 0,
                "retry_after": retry_after,
                "reason": "owner_key_hourly_quota_exhausted",
            }

        bucket.append(t)
        self._buckets[ip] = bucket
        return True, {
            "allowed": True,
            "remaining": self._quota_per_hour - len(bucket),
            "retry_after": 0,
            "reason": None,
        }

    def reset(self, ip: str) -> None:
        """Drop all timestamps for ``ip`` (test/cleanup helper)."""
        self._buckets.pop(ip, None)

    def reset_all(self) -> None:
        """Drop every bucket — used between test cases to avoid leakage."""
        self._buckets.clear()
291
+
292
+
293
# Process-wide throttle instance. It is built lazily from settings on first
# use so that unit tests which monkey-patch SAR_BYOK_OWNER_QUOTA see the
# right value.
_owner_key_throttle: OwnerKeyHourThrottle | None = None


def get_owner_key_throttle() -> OwnerKeyHourThrottle:
    """Return the shared owner-key throttle, building it on first use.

    The quota is read from ``settings.byok_owner_key_quota_per_hour`` only when
    the instance is created. Tests that change the quota should call
    :func:`reset_owner_key_throttle` after the monkey-patch so the next call
    rebuilds the instance.
    """
    global _owner_key_throttle
    if _owner_key_throttle is not None:
        return _owner_key_throttle

    from config.settings import settings  # local import to avoid cycle

    _owner_key_throttle = OwnerKeyHourThrottle(
        quota_per_hour=settings.byok_owner_key_quota_per_hour,
    )
    return _owner_key_throttle
313
+
314
+
315
def reset_owner_key_throttle() -> None:
    """Discard the cached owner-key throttle.

    The next :func:`get_owner_key_throttle` call then rebuilds it from
    settings. Test-only hook; production code never calls this.
    """
    global _owner_key_throttle
    _owner_key_throttle = None
322
+
323
+
324
class RedisRateLimiter:
    """Distributed rate limiter backed by Redis.

    Each key maps to a Redis sorted set (``ratelimit:<key>``) holding one
    entry per accepted request, scored by its timestamp. Entries older than
    the 60-second window are pruned on every check, which gives a sliding
    window shared by all application instances.

    NOTE: pruning, counting and recording are separate Redis calls, so
    concurrent callers can pass the check at the same moment and slightly
    exceed the limit.

    Args:
        redis_url: Redis connection URL.
        default_config: Default rate limit configuration.
    """

    # Length of the sliding window in seconds.
    _WINDOW_SECONDS = 60.0
    # Expiry set on the sorted set after each accepted request.
    _KEY_TTL_SECONDS = 120

    def __init__(
        self,
        redis_url: str | None = None,
        default_config: RateLimitConfig | None = None,
    ) -> None:
        """Initialize the Redis rate limiter.

        Args:
            redis_url: Redis connection URL. Falls back to settings.
            default_config: Default configuration for new keys.
        """
        import redis

        from config.settings import settings

        self._redis = redis.from_url(redis_url or settings.redis_url)
        self._default_config = default_config or RateLimitConfig()

    def _prune_and_count(self, redis_key: str, now: float) -> int:
        """Drop entries that left the window and return how many remain."""
        self._redis.zremrangebyscore(redis_key, 0, now - self._WINDOW_SECONDS)
        return self._redis.zcard(redis_key)

    def _seconds_until_slot_frees(self, redis_key: str, now: float) -> float:
        """Return the seconds until the oldest entry leaves the window (0 if empty)."""
        oldest = self._redis.zrange(redis_key, 0, 0, withscores=True)
        if not oldest:
            return 0.0
        _member, score = oldest[0]
        return max(0.0, score + self._WINDOW_SECONDS - now)

    def check_rate_limit(
        self,
        key: str,
        tokens: float = 1.0,
        config: RateLimitConfig | None = None,
    ) -> tuple[bool, dict[str, Any]]:
        """Check if a request is within the rate limit using Redis.

        Uses a sliding window algorithm based on Redis sorted sets.

        Args:
            key: Rate limit bucket key.
            tokens: Accepted for interface parity with the in-memory limiter;
                every accepted call is recorded as exactly one request.
            config: Optional custom config.

        Returns:
            Tuple of (allowed, metadata).
        """
        import uuid

        cfg = config or self._default_config
        now = time.time()
        redis_key = f"ratelimit:{key}"

        current_count = self._prune_and_count(redis_key, now)

        # Burst limit, evaluated over the same window as the per-minute limit.
        if current_count >= cfg.burst_size:
            retry_after = int(cfg.cooldown_seconds) + 1
            return False, {
                "allowed": False,
                "remaining": 0,
                "retry_after": retry_after,
                "reason": "rate_limit_exceeded",
            }

        # Per-minute limit. The window slides, so capacity returns when the
        # oldest recorded request ages out, not at a wall-clock minute boundary.
        rpm_limit = cfg.requests_per_minute
        if current_count >= rpm_limit:
            retry_after = int(self._seconds_until_slot_frees(redis_key, now)) + 1
            return False, {
                "allowed": False,
                "remaining": 0,
                "retry_after": retry_after,
                "reason": "rate_limit_exceeded",
            }

        # Sorted-set members are unique, so a member built from the timestamp
        # alone would merge two requests that share a timestamp into a single
        # entry and undercount. The random suffix keeps every request distinct.
        member = f"{now}:{uuid.uuid4().hex}"
        self._redis.zadd(redis_key, {member: now})
        # Let idle keys disappear on their own.
        self._redis.expire(redis_key, self._KEY_TTL_SECONDS)

        remaining = min(cfg.burst_size, rpm_limit) - current_count - 1
        return True, {
            "allowed": True,
            "remaining": max(0, remaining),
            "retry_after": 0,
            "reason": None,
        }

    def is_allowed(self, key: str, tokens: float = 1.0) -> bool:
        """Simple check — returns True if request is allowed.

        Args:
            key: Rate limit bucket key.
            tokens: Tokens to consume.

        Returns:
            True if within rate limit, False otherwise.
        """
        allowed, _ = self.check_rate_limit(key, tokens)
        return allowed

    def get_status(self, key: str) -> dict[str, Any]:
        """Get current rate limit status for a key.

        The figures are computed against the default configuration; a custom
        config passed to :meth:`check_rate_limit` is not known here.

        Args:
            key: Rate limit bucket key.

        Returns:
            Dict with ``remaining`` requests, ``limit`` (requests per minute)
            and ``reset`` (seconds until the oldest recorded request leaves
            the window, ``0`` when nothing is recorded).
        """
        redis_key = f"ratelimit:{key}"
        now = time.time()
        current_count = self._prune_and_count(redis_key, now)
        remaining = max(0, self._default_config.burst_size - current_count)
        return {
            "remaining": remaining,
            "limit": self._default_config.requests_per_minute,
            "reset": int(self._seconds_until_slot_frees(redis_key, now)),
        }

    def reset(self, key: str) -> None:
        """Reset a specific rate limit bucket.

        Args:
            key: Bucket key to reset.
        """
        self._redis.delete(f"ratelimit:{key}")
        logger.info("redis_rate_limit_reset", key=key)
458
+
459
+
460
def _get_rate_limiter() -> RateLimiter | RedisRateLimiter:
    """Build the rate limiter selected by configuration.

    Redis is used when ``settings.use_redis_rate_limiter`` is set. If the
    Redis limiter cannot be constructed, a warning is logged and the in-memory
    limiter is used instead.

    Returns:
        RateLimiter (in-memory) or RedisRateLimiter (distributed).
    """
    from config.settings import settings

    limiter: RateLimiter | RedisRateLimiter | None = None
    if settings.use_redis_rate_limiter:
        try:
            limiter = RedisRateLimiter()
        except Exception as exc:
            logger.warning("redis_rate_limiter_failed", error=str(exc), fallback="memory")

    if limiter is None:
        limiter = RateLimiter(default_config=RATE_LIMIT_PROFILES["default"])
    return limiter
474
+
475
+
476
# Named rate-limit presets, looked up by profile name. Entries that do not set
# cooldown_seconds use the RateLimitConfig default.
RATE_LIMIT_PROFILES: dict[str, RateLimitConfig] = {
    "default": RateLimitConfig(
        requests_per_minute=60,
        burst_size=10,
    ),
    "strict": RateLimitConfig(
        requests_per_minute=10,
        burst_size=3,
        cooldown_seconds=5.0,
    ),
    "generous": RateLimitConfig(
        requests_per_minute=300,
        burst_size=50,
    ),
    "upload": RateLimitConfig(
        requests_per_minute=5,
        burst_size=2,
        cooldown_seconds=10.0,
    ),
    "query": RateLimitConfig(
        requests_per_minute=30,
        burst_size=5,
        cooldown_seconds=2.0,
    ),
}
484
+
485
# Shared limiter instance, created on the first _get_limiter() call.
_rate_limiter_instance: RateLimiter | RedisRateLimiter | None = None


def _get_limiter() -> RateLimiter | RedisRateLimiter:
    """Return the shared rate limiter, building it on first use.

    Returns:
        The configured rate limiter.
    """
    global _rate_limiter_instance
    limiter = _rate_limiter_instance
    if limiter is None:
        limiter = _rate_limiter_instance = _get_rate_limiter()
    return limiter
499
+
500
+
501
def check_query_rate_limit(user_id: str) -> tuple[bool, dict[str, Any]]:
    """Apply the ``"query"`` rate-limit profile to a user's query.

    The bucket key is ``"<user_id>:query"``.

    Args:
        user_id: The user making the query.

    Returns:
        Tuple of (allowed, metadata).
    """
    bucket_key = f"{user_id}:query"
    limiter = _get_limiter()
    profile = RATE_LIMIT_PROFILES["query"]
    return limiter.check_rate_limit(bucket_key, config=profile)
512
+
513
+
514
def check_upload_rate_limit(user_id: str) -> tuple[bool, dict[str, Any]]:
    """Apply the ``"upload"`` rate-limit profile to a document upload.

    The bucket key is ``"<user_id>:upload"``.

    Args:
        user_id: The user uploading.

    Returns:
        Tuple of (allowed, metadata).
    """
    bucket_key = f"{user_id}:upload"
    limiter = _get_limiter()
    profile = RATE_LIMIT_PROFILES["upload"]
    return limiter.check_rate_limit(bucket_key, config=profile)