LeomordKaly committed on
Commit
a382192
·
verified ·
1 Parent(s): c385f4b

deploy: phase 3 BYOK backend (Dockerfile.hf, FastAPI on 7860)

Browse files
Dockerfile.hf CHANGED
@@ -94,6 +94,12 @@ ENV SAR_BYOK_MODE=true
94
  # would still be defended; visitors who exceed the cap are nudged to
95
  # paste their own BYOK key via the UI 429 banner.
96
  ENV SAR_BYOK_OWNER_KEY_QUOTA_PER_HOUR=10
 
 
 
 
 
 
97
  ENV SAR_SESSION_COLLECTION_TTL_HOURS=24
98
  ENV SAR_CORS_ALLOW_ORIGINS='["https://secureagentrag-web.vercel.app","https://secureagentrag.vercel.app"]'
99
 
 
94
  # would still be defended; visitors who exceed the cap are nudged to
95
  # paste their own BYOK key via the UI 429 banner.
96
  ENV SAR_BYOK_OWNER_KEY_QUOTA_PER_HOUR=10
97
+ # HF Spaces fronts the container with exactly one trusted reverse proxy that
98
+ # *appends* the peer it saw to X-Forwarded-For. Tell the throttle to read the
99
+ # IP one hop from the right (spoof-resistant) instead of the attacker-controlled
100
+ # leftmost token, so a visitor can't mint a fresh owner-key bucket per request
101
+ # by forging XFF. See interfaces/byok.py::client_ip_from_request.
102
+ ENV SAR_BYOK_XFF_TRUSTED_HOPS=1
103
  ENV SAR_SESSION_COLLECTION_TTL_HOURS=24
104
  ENV SAR_CORS_ALLOW_ORIGINS='["https://secureagentrag-web.vercel.app","https://secureagentrag.vercel.app"]'
105
 
inference/byok_context.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Per-request BYOK credentials carried through the pipeline via a ContextVar.
2
+
3
+ The graph nodes (router → … → synthesizer) do not thread credentials through
4
+ their signatures — they call ``call_llm_*`` which builds an ``InferenceRouter``
5
+ with no per-request key. To make a visitor's *own* LLM key actually power their
6
+ request (the whole point of "Bring Your Own Key"), we stash the credentials in a
7
+ ``contextvars.ContextVar`` at the top of ``run_rag_pipeline[_stream]``.
8
+
9
+ ``ContextVar`` propagates across ``asyncio`` task boundaries (``gather``,
10
+ ``astream``), so every node — and every parallel LLM call inside a node — sees
11
+ the same per-request creds without any signature plumbing. The token is reset in
12
+ a ``finally`` so the value never leaks between requests on a reused worker.
13
+
14
+ When no BYOK key is present the ContextVar holds ``None`` and the router falls
15
+ back to the owner's cached clients exactly as before.
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import contextvars
21
+ from dataclasses import dataclass
22
+
23
+ # Providers whose BYOK client is built from a bearer/API key.
24
+ _KEY_PROVIDERS = frozenset({"groq", "openai", "anthropic"})
25
+
26
+
27
+ @dataclass(frozen=True)
28
+ class ByokRuntime:
29
+ """Per-request BYOK credentials resolved from the visitor's request headers.
30
+
31
+ Attributes:
32
+ provider: Visitor's chosen provider ("groq" / "openai" / "anthropic" /
33
+ "ollama"), already allow-list validated. None = no BYOK.
34
+ user_key: Visitor's API key for a key-based provider. None for Ollama.
35
+ ollama_url: Visitor's Ollama instance URL (only for provider="ollama").
36
+ """
37
+
38
+ provider: str | None = None
39
+ user_key: str | None = None
40
+ ollama_url: str | None = None
41
+
42
+ def is_active(self) -> bool:
43
+ """True when these creds can actually drive a per-request LLM client.
44
+
45
+ A key-based provider needs a non-empty key; Ollama needs a URL. Anything
46
+ else (missing provider, key without a provider, ollama without a URL)
47
+ is *not* active — the router falls back to the owner's clients.
48
+ """
49
+ prov = (self.provider or "").lower()
50
+ if prov in _KEY_PROVIDERS:
51
+ return bool(self.user_key and self.user_key.strip())
52
+ if prov == "ollama":
53
+ return bool(self.ollama_url and self.ollama_url.strip())
54
+ return False
55
+
56
+
57
+ _byok_ctx: contextvars.ContextVar[ByokRuntime | None] = contextvars.ContextVar(
58
+ "byok_runtime", default=None
59
+ )
60
+
61
+
62
+ def set_byok_runtime(runtime: ByokRuntime | None) -> contextvars.Token:
63
+ """Bind ``runtime`` for the current async context. Returns a reset token."""
64
+ return _byok_ctx.set(runtime)
65
+
66
+
67
+ def get_byok_runtime() -> ByokRuntime | None:
68
+ """Return the BYOK creds bound to the current async context, or None."""
69
+ return _byok_ctx.get()
70
+
71
+
72
+ def reset_byok_runtime(token: contextvars.Token) -> None:
73
+ """Restore the previous ContextVar value (call in a ``finally``)."""
74
+ _byok_ctx.reset(token)
inference/router.py CHANGED
@@ -2,11 +2,13 @@
2
 
3
  from __future__ import annotations
4
 
 
5
  from typing import TYPE_CHECKING
6
 
7
  from pydantic import BaseModel
8
 
9
  from config.settings import settings
 
10
  from inference.llm_factory import LLMResponse, get_llm
11
  from ingestion.metadata import SensitivityLevel
12
  from utils.logging import get_logger
@@ -84,15 +86,25 @@ class InferenceRouter:
84
  if isinstance(sensitivity_level, str):
85
  sensitivity_level = SensitivityLevel(sensitivity_level.lower())
86
 
87
- # 1. Admin override honoured for LOW/MEDIUM, but NEVER allowed to move
88
- # HIGH-sensitivity work off local inference on a self-hosted deploy.
89
- # Without this guard the override short-circuits the HIGH→local branch
90
- # below (order-of-checks footgun). The override is still respected when
91
- # the deploy explicitly opts into cloud-for-HIGH (the GPU-less public
92
- # demo, SAR_ALLOW_CLOUD_FOR_HIGH=true) or when it targets local Ollama.
93
- # NOTE: override_provider is currently not wired into the pipeline path
94
- # (call_llm_* / synthesizer call route() without it); this guard is
95
- # defence-in-depth so the privacy guarantee holds even if it ever is.
 
 
 
 
 
 
 
 
 
 
96
  if override_provider:
97
  high_must_stay_local = (
98
  sensitivity_level == SensitivityLevel.HIGH
@@ -222,8 +234,8 @@ class InferenceRouter:
222
  import time
223
 
224
  start = time.perf_counter()
 
225
  try:
226
- client = get_llm(provider=decision.provider, model=decision.model)
227
  response = await client.generate(prompt=prompt, system_prompt=system_prompt, **kwargs)
228
  elapsed_ms = (time.perf_counter() - start) * 1000
229
  response.latency_ms = elapsed_ms
@@ -262,6 +274,13 @@ class InferenceRouter:
262
  )
263
  response.latency_ms = (time.perf_counter() - start) * 1000
264
  return response, fallback_decision
 
 
 
 
 
 
 
265
 
266
  @staticmethod
267
  def _normalised_sensitivity(level: SensitivityLevel | str) -> SensitivityLevel:
@@ -300,7 +319,7 @@ class InferenceRouter:
300
  forced_local=decision.forced_local,
301
  )
302
 
303
- client = get_llm(provider=decision.provider, model=decision.model)
304
  try:
305
  import time
306
 
@@ -310,8 +329,11 @@ class InferenceRouter:
310
  response.latency_ms = elapsed_ms
311
  return response, decision
312
  finally:
313
- # Clients are cached — do NOT close per-request
314
- pass
 
 
 
315
 
316
  async def generate_stream_with_routing(
317
  self,
@@ -346,7 +368,7 @@ class InferenceRouter:
346
  forced_local=decision.forced_local,
347
  )
348
 
349
- client = get_llm(provider=decision.provider, model=decision.model)
350
  try:
351
  if hasattr(client, "generate_stream"):
352
  async for token in client.generate_stream(
@@ -360,8 +382,11 @@ class InferenceRouter:
360
  )
361
  yield response.text
362
  finally:
363
- # Clients are cached — do NOT close per-request
364
- pass
 
 
 
365
 
366
  def get_available_providers(self) -> list[str]:
367
  """Return a list of currently configured and available providers.
@@ -418,3 +443,35 @@ class InferenceRouter:
418
  "anthropic": settings.anthropic_model,
419
  }
420
  return model_defaults.get(provider, settings.llm_model)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
5
+ import contextlib
6
  from typing import TYPE_CHECKING
7
 
8
  from pydantic import BaseModel
9
 
10
  from config.settings import settings
11
+ from inference.byok_context import get_byok_runtime
12
  from inference.llm_factory import LLMResponse, get_llm
13
  from ingestion.metadata import SensitivityLevel
14
  from utils.logging import get_logger
 
86
  if isinstance(sensitivity_level, str):
87
  sensitivity_level = SensitivityLevel(sensitivity_level.lower())
88
 
89
+ # BYOK: when a visitor brought their own usable key/provider for THIS
90
+ # request (carried via the ContextVar set in run_rag_pipeline[_stream]),
91
+ # treat their provider like an explicit override so their key actually
92
+ # powers the answer. The same HIGH-sensitivity guard below still applies,
93
+ # so a visitor cloud key cannot move HIGH off local on a self-hosted
94
+ # deploy (SAR_ALLOW_CLOUD_FOR_HIGH=false). The matching per-request
95
+ # client is built in ``_client_for`` from the same ContextVar.
96
+ if override_provider is None:
97
+ _rt = get_byok_runtime()
98
+ if _rt is not None and _rt.is_active():
99
+ override_provider = (_rt.provider or "").lower()
100
+
101
+ # 1. Admin / BYOK override — honoured for LOW/MEDIUM, but NEVER allowed
102
+ # to move HIGH-sensitivity work off local inference on a self-hosted
103
+ # deploy. Without this guard the override short-circuits the HIGH→local
104
+ # branch below (order-of-checks footgun). The override is still respected
105
+ # when the deploy explicitly opts into cloud-for-HIGH (the GPU-less
106
+ # public demo, SAR_ALLOW_CLOUD_FOR_HIGH=true) or when it targets local
107
+ # Ollama.
108
  if override_provider:
109
  high_must_stay_local = (
110
  sensitivity_level == SensitivityLevel.HIGH
 
234
  import time
235
 
236
  start = time.perf_counter()
237
+ client, ephemeral = self._client_for(decision.provider, decision.model)
238
  try:
 
239
  response = await client.generate(prompt=prompt, system_prompt=system_prompt, **kwargs)
240
  elapsed_ms = (time.perf_counter() - start) * 1000
241
  response.latency_ms = elapsed_ms
 
274
  )
275
  response.latency_ms = (time.perf_counter() - start) * 1000
276
  return response, fallback_decision
277
+ finally:
278
+ # Per-request BYOK clients are fresh and unshared — close them so the
279
+ # visitor's httpx connection pool is released. Owner clients are
280
+ # cached and must never be closed here.
281
+ if ephemeral:
282
+ with contextlib.suppress(Exception):
283
+ await client.close()
284
 
285
  @staticmethod
286
  def _normalised_sensitivity(level: SensitivityLevel | str) -> SensitivityLevel:
 
319
  forced_local=decision.forced_local,
320
  )
321
 
322
+ client, ephemeral = self._client_for(decision.provider, decision.model)
323
  try:
324
  import time
325
 
 
329
  response.latency_ms = elapsed_ms
330
  return response, decision
331
  finally:
332
+ # Owner clients are cached — never closed here. Per-request BYOK
333
+ # clients are fresh and must be closed to release their pool.
334
+ if ephemeral:
335
+ with contextlib.suppress(Exception):
336
+ await client.close()
337
 
338
  async def generate_stream_with_routing(
339
  self,
 
368
  forced_local=decision.forced_local,
369
  )
370
 
371
+ client, ephemeral = self._client_for(decision.provider, decision.model)
372
  try:
373
  if hasattr(client, "generate_stream"):
374
  async for token in client.generate_stream(
 
382
  )
383
  yield response.text
384
  finally:
385
+ # Owner clients are cached — never closed. Per-request BYOK clients
386
+ # are fresh; close after the stream is exhausted.
387
+ if ephemeral:
388
+ with contextlib.suppress(Exception):
389
+ await client.close()
390
 
391
  def get_available_providers(self) -> list[str]:
392
  """Return a list of currently configured and available providers.
 
443
  "anthropic": settings.anthropic_model,
444
  }
445
  return model_defaults.get(provider, settings.llm_model)
446
+
447
+ @staticmethod
448
+ def _client_for(provider: str, model: str):
449
+ """Resolve the LLM client for ``provider``, honouring per-request BYOK.
450
+
451
+ When the current request carries active BYOK creds (ContextVar) for the
452
+ *same* provider the routing decision selected, build a **fresh
453
+ per-request client** bound to the visitor's key/URL — so the visitor's
454
+ own key pays for and powers the call. The fresh client is ephemeral and
455
+ the caller MUST close it after use.
456
+
457
+ Otherwise return the owner's cached client (shared, never closed
458
+ per-request).
459
+
460
+ Returns:
461
+ ``(client, ephemeral)`` — ``ephemeral`` True means the caller owns
462
+ the client and must ``await client.close()`` when done.
463
+ """
464
+ rt = get_byok_runtime()
465
+ if rt is not None and rt.is_active() and (rt.provider or "").lower() == provider.lower():
466
+ prov = provider.lower()
467
+ if prov == "ollama":
468
+ from inference.ollama_client import make_byok_ollama_client
469
+
470
+ return make_byok_ollama_client(base_url=rt.ollama_url or "", model=model), True
471
+ from inference.cloud_clients import make_byok_cloud_client
472
+
473
+ return (
474
+ make_byok_cloud_client(provider=prov, user_key=rt.user_key or "", model=model),
475
+ True,
476
+ )
477
+ return get_llm(provider=provider, model=model), False
interfaces/api.py CHANGED
@@ -296,9 +296,29 @@ if _FASTAPI_AVAILABLE:
296
  # uses per-request BYOK credentials instead. Isolation is enforced via
297
  # session-scoped Qdrant collections, not JWT identity.
298
  if settings.byok_mode:
 
 
 
 
 
299
  from interfaces.byok import ByokCreds, client_ip_from_request, extract_byok
300
  from utils.rate_limiter import get_owner_key_throttle
301
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
302
  # All demo personas share ``org_id="demo"`` so they query the same
303
  # ingested corpus. RBAC differentiation is enforced via clearance
304
  # level + roles at the payload-filter layer -- exactly the production
@@ -495,7 +515,11 @@ if _FASTAPI_AVAILABLE:
495
  filter still runs end-to-end — same code path as authenticated
496
  queries, just with demo identities.
497
  """
498
- if not creds.has_user_key():
 
 
 
 
499
  throttle = get_owner_key_throttle()
500
  client_ip = client_ip_from_request(request)
501
  ok, meta = throttle.allow(client_ip)
@@ -516,16 +540,24 @@ if _FASTAPI_AVAILABLE:
516
  import time as _t
517
 
518
  _t0 = _t.perf_counter()
519
- state = await run_rag_pipeline(
520
- query=body.query,
521
- user_context=user_ctx,
522
- thread_id=f"byok-{creds.session_id}",
523
- prefer_cloud=body.prefer_cloud,
524
- # Visitor's chosen provider when present; falls back to env.
525
- override_provider=creds.safe_provider(),
526
- persona_style=_persona_style(creds),
527
- byok_session_id=creds.session_id,
528
- )
 
 
 
 
 
 
 
 
529
  elapsed_ms = (_t.perf_counter() - _t0) * 1000
530
  response = QueryResponse.from_state(state)
531
  # Persist a single audit-log row so /byok/audit can surface the
@@ -591,7 +623,7 @@ if _FASTAPI_AVAILABLE:
591
 
592
  CORS is already mounted on the app when ``byok_mode`` is on.
593
  """
594
- if not creds.has_user_key():
595
  throttle = get_owner_key_throttle()
596
  client_ip = client_ip_from_request(request)
597
  ok, meta = throttle.allow(client_ip)
@@ -614,6 +646,9 @@ if _FASTAPI_AVAILABLE:
614
  import time as _t
615
 
616
  _t0 = _t.perf_counter()
 
 
 
617
  # Replay the session_id up front so the client can stitch
618
  # token deltas to a known turn without waiting for `final`.
619
  yield (
@@ -674,6 +709,10 @@ if _FASTAPI_AVAILABLE:
674
  except Exception as exc: # pragma: no cover -- defensive
675
  logger.exception("byok_stream_failed", error=str(exc))
676
  yield (f"event: error\ndata: {json.dumps({'message': 'stream_failed'})}\n\n")
 
 
 
 
677
  # Persist audit row at the end of the stream so /byok/audit
678
  # surfaces the session's history even when the visitor
679
  # disconnects before the final frame.
@@ -1174,22 +1213,6 @@ if _FASTAPI_AVAILABLE:
1174
  token_type="bearer",
1175
  expires_in=body.ttl_seconds or settings.jwt_ttl_seconds,
1176
  )
1177
- try:
1178
- token = issue_token(
1179
- user_id=body.user_id,
1180
- org_id=body.org_id,
1181
- roles=body.roles,
1182
- clearance_level=body.clearance_level,
1183
- ttl_seconds=body.ttl_seconds,
1184
- )
1185
- except AuthError as exc:
1186
- raise HTTPException(
1187
- status.HTTP_500_INTERNAL_SERVER_ERROR, f"token_issue_{exc.reason}: {exc}"
1188
- ) from exc
1189
- return _TokenResponse(
1190
- access_token=token,
1191
- expires_in=body.ttl_seconds or settings.jwt_ttl_seconds,
1192
- )
1193
 
1194
  else: # pragma: no cover
1195
  app = None # type: ignore[assignment]
 
296
  # uses per-request BYOK credentials instead. Isolation is enforced via
297
  # session-scoped Qdrant collections, not JWT identity.
298
  if settings.byok_mode:
299
+ from inference.byok_context import (
300
+ ByokRuntime,
301
+ reset_byok_runtime,
302
+ set_byok_runtime,
303
+ )
304
  from interfaces.byok import ByokCreds, client_ip_from_request, extract_byok
305
  from utils.rate_limiter import get_owner_key_throttle
306
 
307
+ def _byok_runtime_for(creds: ByokCreds) -> ByokRuntime | None:
308
+ """Build the per-request BYOK runtime from creds, or None.
309
+
310
+ Only returns a runtime when the visitor brought usable creds — so
311
+ the visitor's own key powers the call. Otherwise None and the
312
+ pipeline routes through the owner's cached clients (throttled).
313
+ """
314
+ if not creds.byok_active():
315
+ return None
316
+ return ByokRuntime(
317
+ provider=creds.safe_provider(),
318
+ user_key=creds.user_key,
319
+ ollama_url=creds.ollama_url,
320
+ )
321
+
322
  # All demo personas share ``org_id="demo"`` so they query the same
323
  # ingested corpus. RBAC differentiation is enforced via clearance
324
  # level + roles at the payload-filter layer -- exactly the production
 
515
  filter still runs end-to-end — same code path as authenticated
516
  queries, just with demo identities.
517
  """
518
+ # Only a visitor with *usable* BYOK creds bypasses the throttle —
519
+ # and that same key now actually powers the call (see the BYOK
520
+ # runtime below). A bare/junk key with no usable provider no longer
521
+ # skips the throttle while spending the owner key.
522
+ if not creds.byok_active():
523
  throttle = get_owner_key_throttle()
524
  client_ip = client_ip_from_request(request)
525
  ok, meta = throttle.allow(client_ip)
 
540
  import time as _t
541
 
542
  _t0 = _t.perf_counter()
543
+ # Bind the visitor's key/provider for THIS request so the inference
544
+ # router builds a per-request client from it. The ContextVar
545
+ # propagates into run_rag_pipeline and every LangGraph node/LLM call;
546
+ # reset in finally so it never leaks to the next request.
547
+ _byok_tok = set_byok_runtime(_byok_runtime_for(creds))
548
+ try:
549
+ state = await run_rag_pipeline(
550
+ query=body.query,
551
+ user_context=user_ctx,
552
+ thread_id=f"byok-{creds.session_id}",
553
+ prefer_cloud=body.prefer_cloud,
554
+ # Visitor's chosen provider when present; falls back to env.
555
+ override_provider=creds.safe_provider(),
556
+ persona_style=_persona_style(creds),
557
+ byok_session_id=creds.session_id,
558
+ )
559
+ finally:
560
+ reset_byok_runtime(_byok_tok)
561
  elapsed_ms = (_t.perf_counter() - _t0) * 1000
562
  response = QueryResponse.from_state(state)
563
  # Persist a single audit-log row so /byok/audit can surface the
 
623
 
624
  CORS is already mounted on the app when ``byok_mode`` is on.
625
  """
626
+ if not creds.byok_active():
627
  throttle = get_owner_key_throttle()
628
  client_ip = client_ip_from_request(request)
629
  ok, meta = throttle.allow(client_ip)
 
646
  import time as _t
647
 
648
  _t0 = _t.perf_counter()
649
+ # Bind the visitor's key/provider for the lifetime of this
650
+ # stream so the synthesizer's streaming LLM call uses it.
651
+ _byok_tok = set_byok_runtime(_byok_runtime_for(creds))
652
  # Replay the session_id up front so the client can stitch
653
  # token deltas to a known turn without waiting for `final`.
654
  yield (
 
709
  except Exception as exc: # pragma: no cover -- defensive
710
  logger.exception("byok_stream_failed", error=str(exc))
711
  yield (f"event: error\ndata: {json.dumps({'message': 'stream_failed'})}\n\n")
712
+ finally:
713
+ # Always clear the per-request BYOK runtime so it never
714
+ # leaks into the next request handled by this worker.
715
+ reset_byok_runtime(_byok_tok)
716
  # Persist audit row at the end of the stream so /byok/audit
717
  # surfaces the session's history even when the visitor
718
  # disconnects before the final frame.
 
1213
  token_type="bearer",
1214
  expires_in=body.ttl_seconds or settings.jwt_ttl_seconds,
1215
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1216
 
1217
  else: # pragma: no cover
1218
  app = None # type: ignore[assignment]
interfaces/byok.py CHANGED
@@ -80,6 +80,23 @@ class ByokCreds(BaseModel):
80
  return self.provider.lower()
81
  return None
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  def client_ip_from_request(request: Request) -> str:
85
  """Resolve the visitor IP for throttling, honouring ``X-Forwarded-For``.
@@ -135,8 +152,10 @@ def _derive_session_id(client_host: str | None) -> str:
135
  """
136
  host = (client_host or "anon").strip() or "anon"
137
  digest = hashlib.sha256(host.encode("utf-8")).hexdigest()[:8]
138
- random = uuid.uuid4().hex[:8]
139
- return f"{digest}-{random}"
 
 
140
 
141
 
142
  def build_creds(
 
80
  return self.provider.lower()
81
  return None
82
 
83
+ def byok_active(self) -> bool:
84
+ """True when the visitor's creds can actually power a per-request LLM call.
85
+
86
+ Stricter than :meth:`has_user_key`: a key-based provider (groq / openai /
87
+ anthropic) needs a non-empty key AND a valid provider; an Ollama BYOK
88
+ needs a non-empty URL (reachability is not checked). This is the gate the chat endpoints use both to
89
+ (a) bypass the owner-key throttle and (b) bind the per-request client —
90
+ so a bare ``X-User-LLM-Key`` with no usable provider can no longer skip
91
+ the throttle while still spending the owner key.
92
+ """
93
+ prov = self.safe_provider()
94
+ if prov in ("groq", "openai", "anthropic"):
95
+ return self.has_user_key()
96
+ if prov == "ollama":
97
+ return bool(self.ollama_url and self.ollama_url.strip())
98
+ return False
99
+
100
 
101
  def client_ip_from_request(request: Request) -> str:
102
  """Resolve the visitor IP for throttling, honouring ``X-Forwarded-For``.
 
152
  """
153
  host = (client_host or "anon").strip() or "anon"
154
  digest = hashlib.sha256(host.encode("utf-8")).hexdigest()[:8]
155
+ # Full UUID4 (122 bits) for the random component — the session id guards one
156
+ # visitor's session-scoped uploads / audit from another, so it must be hard
157
+ # to guess. The host digest only adds reconnect stickiness within a worker.
158
+ return f"{digest}-{uuid.uuid4().hex}"
159
 
160
 
161
  def build_creds(
utils/query_cache.py CHANGED
@@ -1,277 +1,289 @@
1
- """Query result caching with Redis fallback to in-memory.
2
-
3
- Caches RAG pipeline results to avoid redundant LLM calls and retrieval
4
- for identical queries from the same user. Uses Redis when available for
5
- distributed caching across multiple app instances.
6
- """
7
-
8
- from __future__ import annotations
9
-
10
- import hashlib
11
- import json
12
- import time
13
- from typing import Any
14
-
15
- from config.settings import settings
16
- from utils.logging import get_logger
17
- from utils.pii import redact_dict
18
-
19
- logger = get_logger(__name__)
20
-
21
- # In-memory fallback cache
22
- _memory_cache: dict[str, tuple[dict[str, Any], float]] = {}
23
- _memory_cache_ttl_seconds: float = 300.0 # 5 minutes default
24
-
25
- # Redis singleton
26
- _redis_client = None
27
-
28
- # Cache metrics counters
29
- _cache_hits: int = 0
30
- _cache_misses: int = 0
31
-
32
-
33
- def get_cache_metrics() -> dict[str, int]:
34
- """Return cache hit/miss counters."""
35
- total = _cache_hits + _cache_misses
36
- return {
37
- "hits": _cache_hits,
38
- "misses": _cache_misses,
39
- "total": total,
40
- "hit_rate": round(_cache_hits / total, 4) if total > 0 else 0.0,
41
- }
42
-
43
-
44
- def reset_cache_metrics() -> None:
45
- """Reset cache hit/miss counters."""
46
- global _cache_hits, _cache_misses
47
- _cache_hits = 0
48
- _cache_misses = 0
49
-
50
-
51
- def _get_redis_client():
52
- """Lazy-initialize Redis client for query caching.
53
-
54
- Returns:
55
- Redis client instance or None if unavailable.
56
- """
57
- global _redis_client
58
- if _redis_client is not None:
59
- return _redis_client
60
-
61
- if not settings.redis_url:
62
- return None
63
-
64
- try:
65
- import redis
66
-
67
- _redis_client = redis.from_url(settings.redis_url, decode_responses=True)
68
- _redis_client.ping()
69
- logger.info("query_cache_redis_connected")
70
- return _redis_client
71
- except ImportError:
72
- logger.debug("redis_not_installed_for_query_cache")
73
- except Exception as exc:
74
- logger.warning("query_cache_redis_connection_failed", error=str(exc))
75
-
76
- _redis_client = False # Mark as unavailable
77
- return None
78
-
79
-
80
- def _build_cache_key(user_id: str, query: str, context_hash: str = "") -> str:
81
- """Build a deterministic cache key from user + query.
82
-
83
- Args:
84
- user_id: The user's identifier.
85
- query: The query text.
86
- context_hash: Optional hash of additional context (model, filters, etc.).
87
-
88
- Returns:
89
- A hash string suitable for use as a cache key.
90
- """
91
- key_data = f"{user_id}:{query.lower().strip()}:{context_hash}"
92
- return hashlib.sha256(key_data.encode()).hexdigest()[:32]
93
-
94
-
95
- def get_cached_result(
96
- user_id: str,
97
- query: str,
98
- context_hash: str = "",
99
- ttl_seconds: float | None = None,
100
- ) -> dict[str, Any] | None:
101
- """Retrieve a cached query result if available and not expired.
102
-
103
- Args:
104
- user_id: The user's identifier.
105
- query: The query text.
106
- context_hash: Optional hash of additional context.
107
- ttl_seconds: Cache TTL. Defaults to settings or 300s.
108
-
109
- Returns:
110
- Cached result dict, or None if not found or expired.
111
- """
112
- cache_key = _build_cache_key(user_id, query, context_hash)
113
- _ = ttl_seconds or _memory_cache_ttl_seconds
114
-
115
- global _cache_hits, _cache_misses
116
-
117
- # Try Redis first
118
- redis_client = _get_redis_client()
119
- if redis_client:
120
- try:
121
- cached = redis_client.get(f"rag:query:{cache_key}")
122
- if cached:
123
- result = json.loads(cached)
124
- _cache_hits += 1
125
- logger.info("query_cache_hit", source="redis", user_id=user_id)
126
- return result
127
- except Exception as exc:
128
- logger.debug("query_cache_redis_read_failed", error=str(exc))
129
-
130
- # Fallback to in-memory
131
- if cache_key in _memory_cache:
132
- result, expiry = _memory_cache[cache_key]
133
- if time.time() < expiry:
134
- _cache_hits += 1
135
- logger.info("query_cache_hit", source="memory", user_id=user_id)
136
- return result
137
- # Expired — clean up
138
- del _memory_cache[cache_key]
139
-
140
- _cache_misses += 1
141
- return None
142
-
143
-
144
- def set_cached_result(
145
- user_id: str,
146
- query: str,
147
- result: dict[str, Any],
148
- context_hash: str = "",
149
- ttl_seconds: float | None = None,
150
- ) -> None:
151
- """Store a query result in the cache.
152
-
153
- Args:
154
- user_id: The user's identifier.
155
- query: The query text.
156
- result: The result dict to cache.
157
- context_hash: Optional hash of additional context.
158
- ttl_seconds: Cache TTL. Defaults to settings or 300s.
159
- """
160
- cache_key = _build_cache_key(user_id, query, context_hash)
161
- ttl = ttl_seconds or _memory_cache_ttl_seconds
162
-
163
- # Serialize result (exclude non-serializable fields) + redact PII before
164
- # persistence so disk/Redis never sees emails, phones, card numbers, etc.
165
- serializable_result = redact_dict(_make_serializable(result))
166
-
167
- # Try Redis first
168
- redis_client = _get_redis_client()
169
- if redis_client:
170
- try:
171
- redis_client.setex(
172
- f"rag:query:{cache_key}",
173
- int(ttl),
174
- json.dumps(serializable_result),
175
- )
176
- logger.info("query_cache_stored", source="redis", user_id=user_id)
177
- return
178
- except Exception as exc:
179
- logger.debug("query_cache_redis_write_failed", error=str(exc))
180
-
181
- # Fallback to in-memory
182
- _memory_cache[cache_key] = (serializable_result, time.time() + ttl)
183
- logger.info("query_cache_stored", source="memory", user_id=user_id)
184
-
185
- # Prune memory cache if too large
186
- if len(_memory_cache) > 1000:
187
- _prune_memory_cache()
188
-
189
-
190
- def _make_serializable(obj: Any) -> Any:
191
- """Convert an object to a JSON-serializable form.
192
-
193
- Args:
194
- obj: Object to serialize.
195
-
196
- Returns:
197
- JSON-serializable representation.
198
- """
199
- if isinstance(obj, dict):
200
- return {k: _make_serializable(v) for k, v in obj.items()}
201
- if isinstance(obj, list):
202
- return [_make_serializable(v) for v in obj]
203
- if isinstance(obj, (str, int, float, bool, type(None))):
204
- return obj
205
- return str(obj)
206
-
207
-
208
- def _prune_memory_cache() -> None:
209
- """Remove expired entries from the in-memory cache."""
210
- now = time.time()
211
- expired_keys = [k for k, (_, expiry) in _memory_cache.items() if expiry < now]
212
- for k in expired_keys:
213
- del _memory_cache[k]
214
-
215
- # If still too large, remove oldest
216
- if len(_memory_cache) > 1000:
217
- sorted_items = sorted(_memory_cache.items(), key=lambda x: x[1][1])
218
- for k, _ in sorted_items[:100]:
219
- del _memory_cache[k]
220
-
221
-
222
- def invalidate_user_cache(user_id: str) -> int:
223
- """Invalidate all cached queries for a specific user.
224
-
225
- Args:
226
- user_id: The user's identifier.
227
-
228
- Returns:
229
- Number of entries invalidated.
230
- """
231
- count = 0
232
-
233
- # In-memory
234
- prefix = hashlib.sha256(f"{user_id}:".encode()).hexdigest()[:16]
235
- keys_to_remove = [k for k in _memory_cache if k.startswith(prefix)]
236
- for k in keys_to_remove:
237
- del _memory_cache[k]
238
- count += 1
239
-
240
- # Redis scan for user-specific keys
241
- redis_client = _get_redis_client()
242
- if redis_client:
243
- try:
244
- pattern = "rag:query:*"
245
- for key in redis_client.scan_iter(match=pattern, count=100):
246
- # Best-effort: we can't easily decode the key back to user_id
247
- # So we just clear all query cache entries
248
- redis_client.delete(key)
249
- count += 1
250
- except Exception as exc:
251
- logger.debug("query_cache_redis_invalidate_failed", error=str(exc))
252
-
253
- logger.info("query_cache_invalidated", user_id=user_id, count=count)
254
- return count
255
-
256
-
257
- def clear_all_cache() -> int:
258
- """Clear all query caches (memory + Redis).
259
-
260
- Returns:
261
- Number of entries cleared.
262
- """
263
- count = len(_memory_cache)
264
- _memory_cache.clear()
265
-
266
- redis_client = _get_redis_client()
267
- if redis_client:
268
- try:
269
- pattern = "rag:query:*"
270
- for key in redis_client.scan_iter(match=pattern, count=100):
271
- redis_client.delete(key)
272
- count += 1
273
- except Exception as exc:
274
- logger.debug("query_cache_redis_clear_failed", error=str(exc))
275
-
276
- logger.info("query_cache_cleared_all", count=count)
277
- return count
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Query result caching with Redis fallback to in-memory.
2
+
3
+ Caches RAG pipeline results to avoid redundant LLM calls and retrieval
4
+ for identical queries from the same user. Uses Redis when available for
5
+ distributed caching across multiple app instances.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import hashlib
11
+ import json
12
+ import time
13
+ from typing import Any
14
+
15
+ from config.settings import settings
16
+ from utils.logging import get_logger
17
+ from utils.pii import redact_dict
18
+
19
+ logger = get_logger(__name__)
20
+
21
+ # In-memory fallback cache
22
+ _memory_cache: dict[str, tuple[dict[str, Any], float]] = {}
23
+ _memory_cache_ttl_seconds: float = 300.0 # 5 minutes default
24
+
25
+ # Redis singleton
26
+ _redis_client = None
27
+
28
+ # Cache metrics counters
29
+ _cache_hits: int = 0
30
+ _cache_misses: int = 0
31
+
32
+
33
def get_cache_metrics() -> dict[str, int]:
    """Snapshot the module-level hit/miss counters.

    Returns:
        A dict with ``hits``, ``misses``, their sum as ``total`` and
        ``hit_rate`` (hits / total rounded to 4 decimals, ``0.0`` when no
        lookups have been recorded yet).
    """
    hits, misses = _cache_hits, _cache_misses
    lookups = hits + misses

    # Guard the division: with no lookups recorded the rate is defined as 0.0.
    if lookups > 0:
        rate = round(hits / lookups, 4)
    else:
        rate = 0.0

    return {"hits": hits, "misses": misses, "total": lookups, "hit_rate": rate}
42
+
43
+
44
def reset_cache_metrics() -> None:
    """Zero both the hit and the miss counter."""
    global _cache_hits, _cache_misses
    _cache_hits = _cache_misses = 0
49
+
50
+
51
def _get_redis_client():
    """Lazy-initialize the Redis client used for query caching.

    The module-level ``_redis_client`` holds one of three states:
    ``None`` (not attempted yet), ``False`` (attempted and failed, do not
    retry) or a connected client.

    Returns:
        Redis client instance, or ``None`` if Redis is not configured, the
        ``redis`` package is missing, or the connection attempt failed.
    """
    global _redis_client

    if _redis_client is not None:
        # ``False`` is the internal "already tried and failed" sentinel; never
        # leak it to callers, who are promised a client or ``None``.
        return None if _redis_client is False else _redis_client

    if not settings.redis_url:
        return None

    try:
        import redis

        client = redis.from_url(settings.redis_url, decode_responses=True)
        # Fail fast here so an unreachable server is detected once, up front.
        client.ping()
    except ImportError:
        logger.debug("redis_not_installed_for_query_cache")
    except Exception as exc:
        logger.warning("query_cache_redis_connection_failed", error=str(exc))
    else:
        _redis_client = client
        logger.info("query_cache_redis_connected")
        return client

    _redis_client = False  # Mark as unavailable so later calls skip the retry
    return None
78
+
79
+
80
def _user_prefix(user_id: str) -> str:
    """Return the first 12 hex characters of the SHA-256 digest of ``user_id``.

    The value is deterministic per user, which is what lets
    ``invalidate_user_cache`` select one user's keys with ``startswith``.
    """
    digest = hashlib.sha256(user_id.encode())
    return digest.hexdigest()[:12]
83
+
84
+
85
def _build_cache_key(user_id: str, query: str, context_hash: str = "") -> str:
    """Derive the deterministic cache key for one user's query.

    The key has the shape ``<user_prefix><body_hash>``: every entry that
    belongs to the same user starts with the same 12-character prefix, and
    the 20-character body hash distinguishes the (normalized) query and its
    context.

    Args:
        user_id: The user's identifier.
        query: The query text; lower-cased and stripped before hashing.
        context_hash: Optional hash of additional context (model, filters, etc.).

    Returns:
        A 32-character hex string suitable for use as a cache key.
    """
    normalized_query = query.lower().strip()
    body_digest = hashlib.sha256(f"{normalized_query}:{context_hash}".encode())
    return _user_prefix(user_id) + body_digest.hexdigest()[:20]
104
+
105
+
106
def get_cached_result(
    user_id: str,
    query: str,
    context_hash: str = "",
    ttl_seconds: float | None = None,
) -> dict[str, Any] | None:
    """Look up a previously cached result for this user/query pair.

    Redis is consulted first when a client is available; the in-memory dict
    is the fallback. Hit and miss counters are updated as a side effect.

    Args:
        user_id: The user's identifier.
        query: The query text.
        context_hash: Optional hash of additional context.
        ttl_seconds: Accepted for signature symmetry with
            ``set_cached_result``; it does not influence the lookup, since
            expiry is decided by what was stored.

    Returns:
        Cached result dict, or None if not found or expired.
    """
    global _cache_hits, _cache_misses

    cache_key = _build_cache_key(user_id, query, context_hash)

    # Tier 1: Redis
    client = _get_redis_client()
    if client:
        try:
            payload = client.get(f"rag:query:{cache_key}")
            if payload:
                decoded = json.loads(payload)
                _cache_hits += 1
                logger.info("query_cache_hit", source="redis", user_id=user_id)
                return decoded
        except Exception as exc:
            # Best effort: a Redis read problem degrades to the memory tier.
            logger.debug("query_cache_redis_read_failed", error=str(exc))

    # Tier 2: in-memory fallback
    entry = _memory_cache.get(cache_key)
    if entry is not None:
        value, expires_at = entry
        if time.time() < expires_at:
            _cache_hits += 1
            logger.info("query_cache_hit", source="memory", user_id=user_id)
            return value
        # Entry outlived its TTL: evict it and report a miss.
        del _memory_cache[cache_key]

    _cache_misses += 1
    return None
153
+
154
+
155
def set_cached_result(
    user_id: str,
    query: str,
    result: dict[str, Any],
    context_hash: str = "",
    ttl_seconds: float | None = None,
) -> None:
    """Persist a query result, preferring Redis over the in-memory dict.

    Args:
        user_id: The user's identifier.
        query: The query text.
        result: The result dict to cache.
        context_hash: Optional hash of additional context.
        ttl_seconds: Cache TTL; a falsy value selects the module default.
    """
    cache_key = _build_cache_key(user_id, query, context_hash)
    lifetime = ttl_seconds or _memory_cache_ttl_seconds

    # Make the payload JSON-friendly, then redact it, so that whatever gets
    # persisted has already passed through ``redact_dict``.
    payload = redact_dict(_make_serializable(result))

    client = _get_redis_client()
    if client:
        redis_key = f"rag:query:{cache_key}"
        try:
            client.setex(redis_key, int(lifetime), json.dumps(payload))
            logger.info("query_cache_stored", source="redis", user_id=user_id)
            return
        except Exception as exc:
            # Best effort: fall through and keep the entry in memory instead.
            logger.debug("query_cache_redis_write_failed", error=str(exc))

    expires_at = time.time() + lifetime
    _memory_cache[cache_key] = (payload, expires_at)
    logger.info("query_cache_stored", source="memory", user_id=user_id)

    # Keep the fallback cache bounded.
    if len(_memory_cache) > 1000:
        _prune_memory_cache()
199
+
200
+
201
def _make_serializable(obj: Any) -> Any:
    """Convert an object to a JSON-serializable form.

    Dicts and sequences are converted recursively. Tuples become lists, the
    same shape ``json`` gives them, so a cached value keeps its structure
    instead of collapsing into a ``repr`` string. Dict keys that JSON cannot
    encode are stringified so ``json.dumps`` does not raise on them.

    Args:
        obj: Object to serialize.

    Returns:
        JSON-serializable representation; anything that is not a dict, list,
        tuple or JSON primitive is replaced by ``str(obj)``.
    """
    primitives = (str, int, float, bool, type(None))

    if isinstance(obj, primitives):
        return obj
    if isinstance(obj, dict):
        # json.dumps accepts only primitive keys; coerce the rest (tuples,
        # UUIDs, ...) rather than letting the later dump fail.
        return {
            (k if isinstance(k, primitives) else str(k)): _make_serializable(v)
            for k, v in obj.items()
        }
    if isinstance(obj, (list, tuple)):
        return [_make_serializable(v) for v in obj]
    return str(obj)
217
+
218
+
219
def _prune_memory_cache() -> None:
    """Shrink the in-memory cache.

    First drops every entry whose expiry lies in the past. If more than
    1000 entries remain afterwards, additionally drops the 100 entries with
    the earliest expiry.
    """
    cutoff = time.time()
    for stale_key in [k for k, (_, expiry) in _memory_cache.items() if expiry < cutoff]:
        del _memory_cache[stale_key]

    if len(_memory_cache) <= 1000:
        return

    # Still over the cap: evict the entries that would expire soonest.
    soonest_first = sorted(_memory_cache.items(), key=lambda item: item[1][1])
    for victim_key, _ in soonest_first[:100]:
        del _memory_cache[victim_key]
231
+
232
+
233
def invalidate_user_cache(user_id: str) -> int:
    """Invalidate all cached queries for a specific user.

    Cache keys have the shape ``<user_prefix><body_hash>`` (see
    ``_build_cache_key``), so one user's entries can be selected by prefix in
    both tiers without touching anyone else's.

    Args:
        user_id: The user's identifier.

    Returns:
        Number of entries invalidated (memory + Redis).
    """
    prefix = _user_prefix(user_id)

    # In-memory tier: collect first, then delete, so the dict is not mutated
    # while it is being iterated.
    stale_keys = [k for k in _memory_cache if k.startswith(prefix)]
    for k in stale_keys:
        del _memory_cache[k]
    count = len(stale_keys)

    # Redis tier: scan only this user's namespace. The prefix is plain hex,
    # so it contains no glob metacharacters that would need escaping.
    redis_client = _get_redis_client()
    if redis_client:
        try:
            pattern = f"rag:query:{prefix}*"
            for key in redis_client.scan_iter(match=pattern, count=100):
                redis_client.delete(key)
                count += 1
        except Exception as exc:
            # Best effort: entries still expire on their own TTL.
            logger.debug("query_cache_redis_invalidate_failed", error=str(exc))

    logger.info("query_cache_invalidated", user_id=user_id, count=count)
    return count
267
+
268
+
269
def clear_all_cache() -> int:
    """Drop every cached query result from both tiers.

    Returns:
        Number of entries cleared (in-memory entries plus Redis keys deleted).
    """
    removed = len(_memory_cache)
    _memory_cache.clear()

    client = _get_redis_client()
    if client:
        try:
            for redis_key in client.scan_iter(match="rag:query:*", count=100):
                client.delete(redis_key)
                removed += 1
        except Exception as exc:
            # Best effort: a Redis failure must not break the caller.
            logger.debug("query_cache_redis_clear_failed", error=str(exc))

    logger.info("query_cache_cleared_all", count=removed)
    return removed