Harden CORS defaults (empty by default, explicit whitelist)
Browse files- .env.example +5 -0
- sage/api/app.py +58 -15
- sage/api/context.py +21 -0
- sage/api/middleware.py +14 -2
- sage/api/routes.py +131 -57
- sage/config/__init__.py +9 -0
- sage/config/logging.py +23 -0
- sage/core/__init__.py +0 -2
- sage/core/verification.py +4 -2
- sage/services/faithfulness.py +10 -17
- sage/utils.py +16 -0
- tests/test_production.py +293 -0
.env.example
CHANGED
|
@@ -16,6 +16,11 @@ ANTHROPIC_API_KEY=your_anthropic_api_key
|
|
| 16 |
# QDRANT_URL=https://your-cluster.cloud.qdrant.io
|
| 17 |
# QDRANT_API_KEY=your_qdrant_api_key
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
# =============================================================================
|
| 20 |
# Optional
|
| 21 |
# =============================================================================
|
|
|
|
| 16 |
# QDRANT_URL=https://your-cluster.cloud.qdrant.io
|
| 17 |
# QDRANT_API_KEY=your_qdrant_api_key
|
| 18 |
|
| 19 |
+
# =============================================================================
|
| 20 |
+
# Security
|
| 21 |
+
# =============================================================================
|
| 22 |
+
# CORS_ORIGINS=https://your-domain.com,http://localhost:3000 # Comma-separated
|
| 23 |
+
|
| 24 |
# =============================================================================
|
| 25 |
# Optional
|
| 26 |
# =============================================================================
|
sage/api/app.py
CHANGED
|
@@ -26,7 +26,13 @@ from sage.api.middleware import (
|
|
| 26 |
from sage.api.routes import router
|
| 27 |
from sage.config import get_logger
|
| 28 |
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
# Graceful shutdown timeout (seconds to wait for active requests)
|
| 32 |
SHUTDOWN_TIMEOUT = float(os.getenv("SHUTDOWN_TIMEOUT", "30.0"))
|
|
@@ -49,17 +55,41 @@ async def _lifespan(app: FastAPI):
|
|
| 49 |
reset_shutdown_coordinator()
|
| 50 |
coordinator = get_shutdown_coordinator()
|
| 51 |
|
| 52 |
-
# Validate LLM credentials early
|
| 53 |
from sage.config import ANTHROPIC_API_KEY, LLM_PROVIDER, OPENAI_API_KEY
|
| 54 |
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
)
|
| 59 |
-
elif LLM_PROVIDER == "anthropic" and not ANTHROPIC_API_KEY:
|
| 60 |
-
logger.warning("LLM_PROVIDER=anthropic but ANTHROPIC_API_KEY is not set")
|
| 61 |
-
elif LLM_PROVIDER == "openai" and not OPENAI_API_KEY:
|
| 62 |
-
logger.warning("LLM_PROVIDER=openai but OPENAI_API_KEY is not set")
|
| 63 |
|
| 64 |
# Embedder (loads E5-small model) -- required for all requests
|
| 65 |
from sage.adapters.embeddings import get_embedder
|
|
@@ -134,11 +164,24 @@ def create_app() -> FastAPI:
|
|
| 134 |
lifespan=_lifespan,
|
| 135 |
)
|
| 136 |
app.add_middleware(LatencyMiddleware)
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
app.include_router(router)
|
| 144 |
return app
|
|
|
|
| 26 |
from sage.api.routes import router
|
| 27 |
from sage.config import get_logger
|
| 28 |
|
| 29 |
+
# CORS configuration - explicit origins required for security.
|
| 30 |
+
# Default to empty (no CORS) rather than "*" (all origins).
|
| 31 |
+
# Set CORS_ORIGINS="https://your-domain.com,http://localhost:3000" in production.
|
| 32 |
+
_cors_env = os.getenv("CORS_ORIGINS", "")
# Blank/unset variable yields an empty whitelist (CORS middleware not installed).
CORS_ORIGINS: list[str] = [
    origin.strip() for origin in _cors_env.split(",") if origin.strip()
]
|
| 36 |
|
| 37 |
# Graceful shutdown timeout (seconds to wait for active requests)
|
| 38 |
SHUTDOWN_TIMEOUT = float(os.getenv("SHUTDOWN_TIMEOUT", "30.0"))
|
|
|
|
| 55 |
reset_shutdown_coordinator()
|
| 56 |
coordinator = get_shutdown_coordinator()
|
| 57 |
|
| 58 |
+
# Validate LLM credentials early - fail fast if invalid
|
| 59 |
from sage.config import ANTHROPIC_API_KEY, LLM_PROVIDER, OPENAI_API_KEY
|
| 60 |
|
| 61 |
+
def _validate_api_key(key: str | None, provider: str) -> bool:
|
| 62 |
+
"""Validate API key format. Returns True if valid."""
|
| 63 |
+
if not key:
|
| 64 |
+
return False
|
| 65 |
+
if provider == "anthropic":
|
| 66 |
+
# Anthropic keys start with "sk-ant-" and are 100+ chars
|
| 67 |
+
return key.startswith("sk-ant-") and len(key) > 50
|
| 68 |
+
if provider == "openai":
|
| 69 |
+
# OpenAI keys start with "sk-" and are 40+ chars
|
| 70 |
+
return key.startswith("sk-") and len(key) > 20
|
| 71 |
+
return bool(key) # Unknown provider - just check non-empty
|
| 72 |
+
|
| 73 |
+
if LLM_PROVIDER == "anthropic":
|
| 74 |
+
if not ANTHROPIC_API_KEY:
|
| 75 |
+
logger.error("LLM_PROVIDER=anthropic but ANTHROPIC_API_KEY is not set")
|
| 76 |
+
raise ValueError("ANTHROPIC_API_KEY required when LLM_PROVIDER=anthropic")
|
| 77 |
+
if not _validate_api_key(ANTHROPIC_API_KEY, "anthropic"):
|
| 78 |
+
logger.error("ANTHROPIC_API_KEY has invalid format")
|
| 79 |
+
raise ValueError(
|
| 80 |
+
"ANTHROPIC_API_KEY has invalid format (expected sk-ant-...)"
|
| 81 |
+
)
|
| 82 |
+
elif LLM_PROVIDER == "openai":
|
| 83 |
+
if not OPENAI_API_KEY:
|
| 84 |
+
logger.error("LLM_PROVIDER=openai but OPENAI_API_KEY is not set")
|
| 85 |
+
raise ValueError("OPENAI_API_KEY required when LLM_PROVIDER=openai")
|
| 86 |
+
if not _validate_api_key(OPENAI_API_KEY, "openai"):
|
| 87 |
+
logger.error("OPENAI_API_KEY has invalid format")
|
| 88 |
+
raise ValueError("OPENAI_API_KEY has invalid format (expected sk-...)")
|
| 89 |
+
else:
|
| 90 |
+
logger.warning(
|
| 91 |
+
"Unknown LLM_PROVIDER=%s, skipping credential validation", LLM_PROVIDER
|
| 92 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
|
| 94 |
# Embedder (loads E5-small model) -- required for all requests
|
| 95 |
from sage.adapters.embeddings import get_embedder
|
|
|
|
| 164 |
lifespan=_lifespan,
|
| 165 |
)
|
| 166 |
app.add_middleware(LatencyMiddleware)
|
| 167 |
+
|
| 168 |
+
# CORS middleware with security hardening
|
| 169 |
+
if CORS_ORIGINS:
|
| 170 |
+
if "*" in CORS_ORIGINS:
|
| 171 |
+
logger.warning(
|
| 172 |
+
"CORS_ORIGINS contains '*' - this allows requests from any origin. "
|
| 173 |
+
"Set explicit origins in production."
|
| 174 |
+
)
|
| 175 |
+
app.add_middleware(
|
| 176 |
+
CORSMiddleware,
|
| 177 |
+
allow_origins=CORS_ORIGINS,
|
| 178 |
+
allow_methods=["GET", "POST"],
|
| 179 |
+
allow_headers=["Content-Type", "Accept", "Authorization"],
|
| 180 |
+
allow_credentials=False,
|
| 181 |
+
max_age=3600, # Cache preflight for 1 hour
|
| 182 |
+
)
|
| 183 |
+
else:
|
| 184 |
+
logger.info("CORS disabled (no CORS_ORIGINS configured)")
|
| 185 |
+
|
| 186 |
app.include_router(router)
|
| 187 |
return app
|
sage/api/context.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
Request context management using contextvars.

Provides thread-safe request context propagation for logging and tracing.
Request ID set in middleware is accessible throughout the request lifecycle.
"""

from contextvars import ContextVar

# Correlation ID for the current request; "-" means "no request context".
request_id_var: ContextVar[str] = ContextVar("request_id", default="-")


def get_request_id() -> str:
    """Return the request ID bound to the current context ("-" if unset)."""
    return request_id_var.get()


def set_request_id(request_id: str) -> None:
    """Bind *request_id* to the current context for downstream log records."""
    # NOTE(review): the ContextVar token is discarded, so the value is never
    # explicitly reset; assumes each request runs in its own task - confirm.
    request_id_var.set(request_id)
|
sage/api/middleware.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
| 1 |
"""
|
| 2 |
-
Request latency middleware and graceful shutdown coordinator.
|
| 3 |
|
| 4 |
Logs method/path/status/elapsed_ms for every request and records
|
| 5 |
-
Prometheus histogram observations. Adds
|
|
|
|
| 6 |
|
| 7 |
Uses a pure ASGI middleware (not BaseHTTPMiddleware) to avoid buffering
|
| 8 |
SSE streams.
|
|
@@ -24,6 +25,7 @@ from dataclasses import dataclass, field
|
|
| 24 |
from starlette.responses import JSONResponse
|
| 25 |
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
| 26 |
|
|
|
|
| 27 |
from sage.api.metrics import observe_duration, record_request
|
| 28 |
from sage.config import get_logger
|
| 29 |
|
|
@@ -192,6 +194,7 @@ class LatencyMiddleware:
|
|
| 192 |
|
| 193 |
start = time.perf_counter()
|
| 194 |
request_id = uuid.uuid4().hex[:12]
|
|
|
|
| 195 |
status = 500 # default until we see http.response.start
|
| 196 |
|
| 197 |
async def send_wrapper(message: Message) -> None:
|
|
@@ -202,8 +205,17 @@ class LatencyMiddleware:
|
|
| 202 |
# The Prometheus histogram (in finally) measures total time.
|
| 203 |
elapsed_ms = (time.perf_counter() - start) * 1000
|
| 204 |
headers = list(message.get("headers", []))
|
|
|
|
| 205 |
headers.append((b"x-response-time-ms", f"{elapsed_ms:.1f}".encode()))
|
| 206 |
headers.append((b"x-request-id", request_id.encode()))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
message = {**message, "headers": headers}
|
| 208 |
await send(message)
|
| 209 |
|
|
|
|
| 1 |
"""
|
| 2 |
+
Request latency middleware, security headers, and graceful shutdown coordinator.
|
| 3 |
|
| 4 |
Logs method/path/status/elapsed_ms for every request and records
|
| 5 |
+
Prometheus histogram observations. Adds security headers and
|
| 6 |
+
``X-Response-Time-Ms`` header.
|
| 7 |
|
| 8 |
Uses a pure ASGI middleware (not BaseHTTPMiddleware) to avoid buffering
|
| 9 |
SSE streams.
|
|
|
|
| 25 |
from starlette.responses import JSONResponse
|
| 26 |
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
| 27 |
|
| 28 |
+
from sage.api.context import set_request_id
|
| 29 |
from sage.api.metrics import observe_duration, record_request
|
| 30 |
from sage.config import get_logger
|
| 31 |
|
|
|
|
| 194 |
|
| 195 |
start = time.perf_counter()
|
| 196 |
request_id = uuid.uuid4().hex[:12]
|
| 197 |
+
set_request_id(request_id) # Propagate to all child operations
|
| 198 |
status = 500 # default until we see http.response.start
|
| 199 |
|
| 200 |
async def send_wrapper(message: Message) -> None:
|
|
|
|
| 205 |
# The Prometheus histogram (in finally) measures total time.
|
| 206 |
elapsed_ms = (time.perf_counter() - start) * 1000
|
| 207 |
headers = list(message.get("headers", []))
|
| 208 |
+
# Timing and correlation headers
|
| 209 |
headers.append((b"x-response-time-ms", f"{elapsed_ms:.1f}".encode()))
|
| 210 |
headers.append((b"x-request-id", request_id.encode()))
|
| 211 |
+
# Security headers
|
| 212 |
+
headers.append((b"x-content-type-options", b"nosniff"))
|
| 213 |
+
headers.append((b"x-frame-options", b"DENY"))
|
| 214 |
+
headers.append((b"x-xss-protection", b"1; mode=block"))
|
| 215 |
+
headers.append((b"referrer-policy", b"strict-origin-when-cross-origin"))
|
| 216 |
+
headers.append(
|
| 217 |
+
(b"cache-control", b"no-store, no-cache, must-revalidate")
|
| 218 |
+
)
|
| 219 |
message = {**message, "headers": headers}
|
| 220 |
await send(message)
|
| 221 |
|
sage/api/routes.py
CHANGED
|
@@ -15,7 +15,7 @@ from __future__ import annotations
|
|
| 15 |
import asyncio
|
| 16 |
import json
|
| 17 |
import os
|
| 18 |
-
from concurrent.futures import ThreadPoolExecutor
|
| 19 |
from typing import AsyncIterator
|
| 20 |
|
| 21 |
import numpy as np
|
|
@@ -26,6 +26,7 @@ from pydantic import BaseModel, Field
|
|
| 26 |
from sage.adapters.vector_store import collection_exists
|
| 27 |
from sage.api.metrics import metrics_response, record_cache_event, record_error
|
| 28 |
from sage.config import MAX_EVIDENCE, get_logger
|
|
|
|
| 29 |
from sage.core import (
|
| 30 |
AggregationMethod,
|
| 31 |
ExplanationResult,
|
|
@@ -39,6 +40,9 @@ from sage.services.retrieval import get_candidates
|
|
| 39 |
# good parallelism while bounding total concurrent LLM calls.
|
| 40 |
_MAX_EXPLAIN_WORKERS = 4
|
| 41 |
|
|
|
|
|
|
|
|
|
|
| 42 |
# Request timeout in seconds. David's rule: 10s max end-to-end.
|
| 43 |
# If the LLM hangs, cut it off and return what we have.
|
| 44 |
REQUEST_TIMEOUT_SECONDS = float(os.getenv("REQUEST_TIMEOUT_SECONDS", "10.0"))
|
|
@@ -206,6 +210,16 @@ def _build_evidence_list(result: ExplanationResult) -> list[dict]:
|
|
| 206 |
return result.to_evidence_dicts()
|
| 207 |
|
| 208 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
# ---------------------------------------------------------------------------
|
| 210 |
# Health
|
| 211 |
# ---------------------------------------------------------------------------
|
|
@@ -313,11 +327,7 @@ async def ready(request: Request):
|
|
| 313 |
|
| 314 |
# Core components must be ready (explainer is optional)
|
| 315 |
core_ready = all(
|
| 316 |
-
|
| 317 |
-
components.get("qdrant", False),
|
| 318 |
-
components.get("embedder", False),
|
| 319 |
-
components.get("hhem", False),
|
| 320 |
-
]
|
| 321 |
)
|
| 322 |
|
| 323 |
if core_ready and components.get("explainer", False):
|
|
@@ -349,6 +359,103 @@ async def ready(request: Request):
|
|
| 349 |
# ---------------------------------------------------------------------------
|
| 350 |
|
| 351 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 352 |
def _sync_recommend(
|
| 353 |
body: RecommendationRequest,
|
| 354 |
app,
|
|
@@ -361,77 +468,44 @@ def _sync_recommend(
|
|
| 361 |
cache = app.state.cache
|
| 362 |
q = body.query
|
| 363 |
explain = body.explain
|
|
|
|
|
|
|
| 364 |
|
| 365 |
-
# Check cache before any heavy work (
|
| 366 |
-
#
|
| 367 |
-
# avoiding the cost of a second embed_single_query call.
|
| 368 |
if explain:
|
| 369 |
query_embedding = app.state.embedder.embed_single_query(q)
|
| 370 |
-
cached
|
| 371 |
-
record_cache_event(f"hit_{hit_type}" if hit_type != "miss" else "miss")
|
| 372 |
-
if cached is not None:
|
| 373 |
return cached
|
| 374 |
else:
|
| 375 |
query_embedding = None
|
| 376 |
|
| 377 |
products = _fetch_products(body, app, query_embedding=query_embedding)
|
| 378 |
-
|
| 379 |
if not products:
|
| 380 |
return {"query": q, "recommendations": []}
|
| 381 |
|
| 382 |
-
recommendations
|
| 383 |
-
|
| 384 |
if explain:
|
| 385 |
if app.state.explainer is None:
|
| 386 |
raise RuntimeError("Explanation service unavailable")
|
| 387 |
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
er = explainer.generate_explanation(
|
| 396 |
-
query=q,
|
| 397 |
-
product=product,
|
| 398 |
-
max_evidence=MAX_EVIDENCE,
|
| 399 |
-
)
|
| 400 |
-
hr = detector.check_explanation(
|
| 401 |
-
evidence_texts=er.evidence_texts,
|
| 402 |
-
explanation=er.explanation,
|
| 403 |
-
)
|
| 404 |
-
cr = verify_citations(er.explanation, er.evidence_ids, er.evidence_texts)
|
| 405 |
-
return er, hr, cr
|
| 406 |
-
|
| 407 |
-
with ThreadPoolExecutor(
|
| 408 |
-
max_workers=min(len(products), _MAX_EXPLAIN_WORKERS)
|
| 409 |
-
) as pool:
|
| 410 |
-
results = list(pool.map(_explain, products))
|
| 411 |
-
|
| 412 |
-
for i, (product, (er, hr, cr)) in enumerate(
|
| 413 |
-
zip(products, results, strict=True),
|
| 414 |
-
1,
|
| 415 |
-
):
|
| 416 |
-
rec = _build_product_dict(i, product)
|
| 417 |
-
rec["explanation"] = er.explanation
|
| 418 |
-
rec["confidence"] = {
|
| 419 |
-
"hhem_score": round(hr.score, 3),
|
| 420 |
-
"is_grounded": not hr.is_hallucinated,
|
| 421 |
-
"threshold": hr.threshold,
|
| 422 |
-
}
|
| 423 |
-
rec["citations_verified"] = cr.all_valid
|
| 424 |
-
rec["evidence_sources"] = _build_evidence_list(er)
|
| 425 |
-
recommendations.append(rec)
|
| 426 |
else:
|
| 427 |
-
|
| 428 |
-
|
|
|
|
| 429 |
|
| 430 |
result = {"query": q, "recommendations": recommendations}
|
| 431 |
|
| 432 |
-
# Store in cache (explain path only
|
| 433 |
if explain:
|
| 434 |
-
cache.put(
|
| 435 |
|
| 436 |
return result
|
| 437 |
|
|
|
|
| 15 |
import asyncio
|
| 16 |
import json
|
| 17 |
import os
|
| 18 |
+
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError
|
| 19 |
from typing import AsyncIterator
|
| 20 |
|
| 21 |
import numpy as np
|
|
|
|
| 26 |
from sage.adapters.vector_store import collection_exists
|
| 27 |
from sage.api.metrics import metrics_response, record_cache_event, record_error
|
| 28 |
from sage.config import MAX_EVIDENCE, get_logger
|
| 29 |
+
from sage.utils import normalize_text
|
| 30 |
from sage.core import (
|
| 31 |
AggregationMethod,
|
| 32 |
ExplanationResult,
|
|
|
|
| 40 |
# good parallelism while bounding total concurrent LLM calls.
|
| 41 |
_MAX_EXPLAIN_WORKERS = 4
|
| 42 |
|
| 43 |
+
# Per-worker timeout for explanation generation (prevents hung workers)
|
| 44 |
+
_EXPLAIN_WORKER_TIMEOUT = 30.0
|
| 45 |
+
|
| 46 |
# Request timeout in seconds. David's rule: 10s max end-to-end.
|
| 47 |
# If the LLM hangs, cut it off and return what we have.
|
| 48 |
REQUEST_TIMEOUT_SECONDS = float(os.getenv("REQUEST_TIMEOUT_SECONDS", "10.0"))
|
|
|
|
| 210 |
return result.to_evidence_dicts()
|
| 211 |
|
| 212 |
|
| 213 |
+
def _build_cache_key(query: str, k: int, explain: bool, min_rating: float) -> str:
    """Build a cache key that includes all request parameters.

    Keys on the normalized query plus every parameter that affects the
    response, so e.g. a request with k=3 never returns a result cached
    for k=5.
    """
    parts = (
        normalize_text(query),
        f"k={k}",
        f"explain={explain}",
        f"rating={min_rating:.1f}",
    )
    return ":".join(parts)
|
| 221 |
+
|
| 222 |
+
|
| 223 |
# ---------------------------------------------------------------------------
|
| 224 |
# Health
|
| 225 |
# ---------------------------------------------------------------------------
|
|
|
|
| 327 |
|
| 328 |
# Core components must be ready (explainer is optional)
|
| 329 |
core_ready = all(
|
| 330 |
+
components.get(key, False) for key in ("qdrant", "embedder", "hhem")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 331 |
)
|
| 332 |
|
| 333 |
if core_ready and components.get("explainer", False):
|
|
|
|
| 359 |
# ---------------------------------------------------------------------------
|
| 360 |
|
| 361 |
|
| 362 |
+
def _check_cache(
    cache,
    cache_key: str,
    query_embedding: np.ndarray,
) -> dict | None:
    """Look up *cache_key* in the cache and record a metrics event.

    Returns the cached result dict on a hit, or None on a miss.
    """
    cached, hit_type = cache.get(cache_key, query_embedding)
    event = "miss" if hit_type == "miss" else f"hit_{hit_type}"
    record_cache_event(event)
    return cached
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
def _generate_explanation_for_product(
    query: str,
    product: ProductScore,
    explainer,
    detector,
) -> tuple:
    """Run the full explanation pipeline for a single product.

    Generates the LLM explanation, scores it with the HHEM hallucination
    detector, and verifies that its citations refer to real evidence.
    Thread-safe: LLM clients use httpx, HHEM model is read-only.

    Returns:
        (ExplanationResult, HallucinationResult, CitationVerificationResult)
    """
    er = explainer.generate_explanation(
        query=query, product=product, max_evidence=MAX_EVIDENCE
    )
    hr = detector.check_explanation(
        evidence_texts=er.evidence_texts,
        explanation=er.explanation,
    )
    cr = verify_citations(er.explanation, er.evidence_ids, er.evidence_texts)
    return er, hr, cr
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
def _generate_explanations_parallel(
    query: str,
    products: list[ProductScore],
    explainer,
    detector,
) -> list[tuple[ProductScore, tuple]]:
    """Generate explanations for multiple products in parallel.

    Uses ThreadPoolExecutor with a per-future timeout so a hung LLM call
    cannot stall result collection indefinitely. Products whose explanation
    times out or raises are logged and skipped. Results preserve the input
    product order (futures are collected in submission order).

    NOTE(review): a running worker cannot be interrupted; the executor's
    context exit still waits for it to finish after a timeout - confirm
    this fits the end-to-end request-timeout budget.
    """
    results: list[tuple[ProductScore, tuple]] = []
    with ThreadPoolExecutor(
        max_workers=min(len(products), _MAX_EXPLAIN_WORKERS)
    ) as pool:
        futures = {
            pool.submit(
                _generate_explanation_for_product, query, p, explainer, detector
            ): p
            for p in products
        }
        # Collect in submission order so results stay aligned with ranking.
        for future, product in futures.items():
            try:
                result = future.result(timeout=_EXPLAIN_WORKER_TIMEOUT)
                results.append((product, result))
            except FuturesTimeoutError:
                # Best-effort cancel: drops the task if it is still queued;
                # an already-running call cannot be cancelled.
                future.cancel()
                logger.warning(
                    "Explanation timeout for product %s after %.1fs",
                    product.product_id,
                    _EXPLAIN_WORKER_TIMEOUT,
                )
            except Exception:
                logger.exception(
                    "Explanation failed for product %s", product.product_id
                )
    return results
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
def _build_recommendation_with_explanation(
    rank: int,
    product: ProductScore,
    er: ExplanationResult,
    hr,
    cr,
) -> dict:
    """Build recommendation dict with explanation and confidence metrics."""
    rec = _build_product_dict(rank, product)
    confidence = {
        "hhem_score": round(hr.score, 3),
        "is_grounded": not hr.is_hallucinated,
        "threshold": hr.threshold,
    }
    rec.update(
        explanation=er.explanation,
        confidence=confidence,
        citations_verified=cr.all_valid,
        evidence_sources=_build_evidence_list(er),
    )
    return rec
|
| 457 |
+
|
| 458 |
+
|
| 459 |
def _sync_recommend(
|
| 460 |
body: RecommendationRequest,
|
| 461 |
app,
|
|
|
|
| 468 |
cache = app.state.cache
|
| 469 |
q = body.query
|
| 470 |
explain = body.explain
|
| 471 |
+
min_rating = body.filters.min_rating if body.filters else 4.0
|
| 472 |
+
cache_key = _build_cache_key(q, body.k, explain, min_rating)
|
| 473 |
|
| 474 |
+
# Check cache before any heavy work (explain path only).
|
| 475 |
+
# Embedding computed here is reused for candidate retrieval.
|
|
|
|
| 476 |
if explain:
|
| 477 |
query_embedding = app.state.embedder.embed_single_query(q)
|
| 478 |
+
if (cached := _check_cache(cache, cache_key, query_embedding)) is not None:
|
|
|
|
|
|
|
| 479 |
return cached
|
| 480 |
else:
|
| 481 |
query_embedding = None
|
| 482 |
|
| 483 |
products = _fetch_products(body, app, query_embedding=query_embedding)
|
|
|
|
| 484 |
if not products:
|
| 485 |
return {"query": q, "recommendations": []}
|
| 486 |
|
| 487 |
+
# Build recommendations with or without explanations
|
|
|
|
| 488 |
if explain:
|
| 489 |
if app.state.explainer is None:
|
| 490 |
raise RuntimeError("Explanation service unavailable")
|
| 491 |
|
| 492 |
+
explanation_results = _generate_explanations_parallel(
|
| 493 |
+
q, products, app.state.explainer, app.state.detector
|
| 494 |
+
)
|
| 495 |
+
recommendations = [
|
| 496 |
+
_build_recommendation_with_explanation(i, product, er, hr, cr)
|
| 497 |
+
for i, (product, (er, hr, cr)) in enumerate(explanation_results, 1)
|
| 498 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 499 |
else:
|
| 500 |
+
recommendations = [
|
| 501 |
+
_build_product_dict(i, product) for i, product in enumerate(products, 1)
|
| 502 |
+
]
|
| 503 |
|
| 504 |
result = {"query": q, "recommendations": recommendations}
|
| 505 |
|
| 506 |
+
# Store in cache (explain path only)
|
| 507 |
if explain:
|
| 508 |
+
cache.put(cache_key, query_embedding, result)
|
| 509 |
|
| 510 |
return result
|
| 511 |
|
sage/config/__init__.py
CHANGED
|
@@ -135,6 +135,13 @@ CACHE_MAX_ENTRIES = int(os.getenv("CACHE_MAX_ENTRIES", "1000"))
|
|
| 135 |
CACHE_TTL_SECONDS = float(os.getenv("CACHE_TTL_SECONDS", "3600"))
|
| 136 |
|
| 137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
# ---------------------------------------------------------------------------
|
| 139 |
# Evidence Quality Gate
|
| 140 |
# ---------------------------------------------------------------------------
|
|
@@ -244,6 +251,8 @@ __all__ = [
|
|
| 244 |
"CACHE_SIMILARITY_THRESHOLD",
|
| 245 |
"CACHE_MAX_ENTRIES",
|
| 246 |
"CACHE_TTL_SECONDS",
|
|
|
|
|
|
|
| 247 |
# Evidence gate
|
| 248 |
"MAX_EVIDENCE",
|
| 249 |
"MIN_EVIDENCE_CHUNKS",
|
|
|
|
| 135 |
CACHE_TTL_SECONDS = float(os.getenv("CACHE_TTL_SECONDS", "3600"))
|
| 136 |
|
| 137 |
|
| 138 |
+
# ---------------------------------------------------------------------------
|
| 139 |
+
# Citation Format
|
| 140 |
+
# ---------------------------------------------------------------------------
|
| 141 |
+
|
| 142 |
+
CITATION_PREFIX = "review_" # Prefix for citation IDs (e.g., "review_123")
|
| 143 |
+
|
| 144 |
+
|
| 145 |
# ---------------------------------------------------------------------------
|
| 146 |
# Evidence Quality Gate
|
| 147 |
# ---------------------------------------------------------------------------
|
|
|
|
| 251 |
"CACHE_SIMILARITY_THRESHOLD",
|
| 252 |
"CACHE_MAX_ENTRIES",
|
| 253 |
"CACHE_TTL_SECONDS",
|
| 254 |
+
# Citation
|
| 255 |
+
"CITATION_PREFIX",
|
| 256 |
# Evidence gate
|
| 257 |
"MAX_EVIDENCE",
|
| 258 |
"MIN_EVIDENCE_CHUNKS",
|
sage/config/logging.py
CHANGED
|
@@ -72,12 +72,21 @@ class ConsoleFormatter(logging.Formatter):
|
|
| 72 |
"ERROR": "\033[31m", # Red
|
| 73 |
"CRITICAL": "\033[35m", # Magenta
|
| 74 |
"RESET": "\033[0m",
|
|
|
|
| 75 |
}
|
| 76 |
|
| 77 |
def format(self, record: logging.LogRecord) -> str:
|
| 78 |
# Check if we're in a TTY (supports colors)
|
| 79 |
use_colors = hasattr(sys.stdout, "isatty") and sys.stdout.isatty()
|
| 80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
# Format timestamp
|
| 82 |
timestamp = self.formatTime(record, "%H:%M:%S")
|
| 83 |
|
|
@@ -86,9 +95,12 @@ class ConsoleFormatter(logging.Formatter):
|
|
| 86 |
if use_colors:
|
| 87 |
color = self.COLORS.get(level, "")
|
| 88 |
reset = self.COLORS["RESET"]
|
|
|
|
| 89 |
level_str = f"{color}{level:<8}{reset}"
|
|
|
|
| 90 |
else:
|
| 91 |
level_str = f"{level:<8}"
|
|
|
|
| 92 |
|
| 93 |
# Format message
|
| 94 |
message = record.getMessage()
|
|
@@ -101,6 +113,8 @@ class ConsoleFormatter(logging.Formatter):
|
|
| 101 |
|
| 102 |
extra_str = f" [{', '.join(extras)}]" if extras else ""
|
| 103 |
|
|
|
|
|
|
|
| 104 |
return f"{timestamp} {level_str} {message}{extra_str}"
|
| 105 |
|
| 106 |
|
|
@@ -116,10 +130,19 @@ class JSONFormatter(logging.Formatter):
|
|
| 116 |
import json
|
| 117 |
from datetime import datetime, timezone
|
| 118 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
log_entry = {
|
| 120 |
"timestamp": datetime.now(timezone.utc).isoformat(),
|
| 121 |
"level": record.levelname,
|
| 122 |
"logger": record.name,
|
|
|
|
| 123 |
"message": record.getMessage(),
|
| 124 |
}
|
| 125 |
|
|
|
|
| 72 |
"ERROR": "\033[31m", # Red
|
| 73 |
"CRITICAL": "\033[35m", # Magenta
|
| 74 |
"RESET": "\033[0m",
|
| 75 |
+
"DIM": "\033[2m", # Dim for request ID
|
| 76 |
}
|
| 77 |
|
| 78 |
def format(self, record: logging.LogRecord) -> str:
|
| 79 |
# Check if we're in a TTY (supports colors)
|
| 80 |
use_colors = hasattr(sys.stdout, "isatty") and sys.stdout.isatty()
|
| 81 |
|
| 82 |
+
# Get request ID from context
|
| 83 |
+
try:
|
| 84 |
+
from sage.api.context import get_request_id
|
| 85 |
+
|
| 86 |
+
request_id = get_request_id()
|
| 87 |
+
except ImportError:
|
| 88 |
+
request_id = "-"
|
| 89 |
+
|
| 90 |
# Format timestamp
|
| 91 |
timestamp = self.formatTime(record, "%H:%M:%S")
|
| 92 |
|
|
|
|
| 95 |
if use_colors:
|
| 96 |
color = self.COLORS.get(level, "")
|
| 97 |
reset = self.COLORS["RESET"]
|
| 98 |
+
dim = self.COLORS["DIM"]
|
| 99 |
level_str = f"{color}{level:<8}{reset}"
|
| 100 |
+
rid_str = f"{dim}[{request_id}]{reset}" if request_id != "-" else ""
|
| 101 |
else:
|
| 102 |
level_str = f"{level:<8}"
|
| 103 |
+
rid_str = f"[{request_id}]" if request_id != "-" else ""
|
| 104 |
|
| 105 |
# Format message
|
| 106 |
message = record.getMessage()
|
|
|
|
| 113 |
|
| 114 |
extra_str = f" [{', '.join(extras)}]" if extras else ""
|
| 115 |
|
| 116 |
+
if rid_str:
|
| 117 |
+
return f"{timestamp} {level_str} {rid_str} {message}{extra_str}"
|
| 118 |
return f"{timestamp} {level_str} {message}{extra_str}"
|
| 119 |
|
| 120 |
|
|
|
|
| 130 |
import json
|
| 131 |
from datetime import datetime, timezone
|
| 132 |
|
| 133 |
+
# Import here to avoid circular imports
|
| 134 |
+
try:
|
| 135 |
+
from sage.api.context import get_request_id
|
| 136 |
+
|
| 137 |
+
request_id = get_request_id()
|
| 138 |
+
except ImportError:
|
| 139 |
+
request_id = "-"
|
| 140 |
+
|
| 141 |
log_entry = {
|
| 142 |
"timestamp": datetime.now(timezone.utc).isoformat(),
|
| 143 |
"level": record.levelname,
|
| 144 |
"logger": record.name,
|
| 145 |
+
"request_id": request_id,
|
| 146 |
"message": record.getMessage(),
|
| 147 |
}
|
| 148 |
|
sage/core/__init__.py
CHANGED
|
@@ -65,7 +65,6 @@ from sage.core.verification import (
|
|
| 65 |
check_forbidden_phrases,
|
| 66 |
extract_citations,
|
| 67 |
extract_quotes,
|
| 68 |
-
normalize_text,
|
| 69 |
verify_citation,
|
| 70 |
verify_citations,
|
| 71 |
verify_explanation,
|
|
@@ -132,7 +131,6 @@ __all__ = [
|
|
| 132 |
"check_forbidden_phrases",
|
| 133 |
"extract_citations",
|
| 134 |
"extract_quotes",
|
| 135 |
-
"normalize_text",
|
| 136 |
"verify_citation",
|
| 137 |
"verify_citations",
|
| 138 |
"verify_explanation",
|
|
|
|
| 65 |
check_forbidden_phrases,
|
| 66 |
extract_citations,
|
| 67 |
extract_quotes,
|
|
|
|
| 68 |
verify_citation,
|
| 69 |
verify_citations,
|
| 70 |
verify_explanation,
|
|
|
|
| 131 |
"check_forbidden_phrases",
|
| 132 |
"extract_citations",
|
| 133 |
"extract_quotes",
|
|
|
|
| 134 |
"verify_citation",
|
| 135 |
"verify_citations",
|
| 136 |
"verify_explanation",
|
sage/core/verification.py
CHANGED
|
@@ -13,6 +13,7 @@ non-existent review IDs.
|
|
| 13 |
import re
|
| 14 |
from dataclasses import dataclass
|
| 15 |
|
|
|
|
| 16 |
from sage.core.models import (
|
| 17 |
CitationResult,
|
| 18 |
CitationVerificationResult,
|
|
@@ -218,16 +219,17 @@ def extract_citations(text: str) -> list[tuple[str, str | None]]:
|
|
| 218 |
|
| 219 |
# Pattern for quote followed by citation(s): "quote" [review_123] or [review_123, review_456]
|
| 220 |
quote_citation_pattern = r'"([^"]+)"\s*\[([^\]]+)\]'
|
|
|
|
| 221 |
for match in re.finditer(quote_citation_pattern, text):
|
| 222 |
quote_text = match.group(1)
|
| 223 |
citation_block = match.group(2)
|
| 224 |
# Split multiple citations like "review_123, review_456"
|
| 225 |
-
for citation_id in re.findall(
|
| 226 |
citations.append((citation_id, quote_text))
|
| 227 |
|
| 228 |
# Pattern for standalone citations not preceded by a quote
|
| 229 |
# Find all citations, then filter out ones already captured with quotes
|
| 230 |
-
all_citation_ids = set(re.findall(
|
| 231 |
quoted_citation_ids = {c[0] for c in citations}
|
| 232 |
standalone_ids = all_citation_ids - quoted_citation_ids
|
| 233 |
|
|
|
|
| 13 |
import re
|
| 14 |
from dataclasses import dataclass
|
| 15 |
|
| 16 |
+
from sage.config import CITATION_PREFIX
|
| 17 |
from sage.core.models import (
|
| 18 |
CitationResult,
|
| 19 |
CitationVerificationResult,
|
|
|
|
| 219 |
|
| 220 |
# Pattern for quote followed by citation(s): "quote" [review_123] or [review_123, review_456]
|
| 221 |
quote_citation_pattern = r'"([^"]+)"\s*\[([^\]]+)\]'
|
| 222 |
+
citation_id_pattern = rf"{re.escape(CITATION_PREFIX)}\d+"
|
| 223 |
for match in re.finditer(quote_citation_pattern, text):
|
| 224 |
quote_text = match.group(1)
|
| 225 |
citation_block = match.group(2)
|
| 226 |
# Split multiple citations like "review_123, review_456"
|
| 227 |
+
for citation_id in re.findall(citation_id_pattern, citation_block):
|
| 228 |
citations.append((citation_id, quote_text))
|
| 229 |
|
| 230 |
# Pattern for standalone citations not preceded by a quote
|
| 231 |
# Find all citations, then filter out ones already captured with quotes
|
| 232 |
+
all_citation_ids = set(re.findall(citation_id_pattern, text))
|
| 233 |
quoted_citation_ids = {c[0] for c in citations}
|
| 234 |
standalone_ids = all_citation_ids - quoted_citation_ids
|
| 235 |
|
sage/services/faithfulness.py
CHANGED
|
@@ -15,6 +15,7 @@ import asyncio
|
|
| 15 |
|
| 16 |
import numpy as np
|
| 17 |
|
|
|
|
| 18 |
from sage.core import (
|
| 19 |
AdjustedFaithfulnessReport,
|
| 20 |
AgreementReport,
|
|
@@ -120,10 +121,8 @@ def create_ragas_sample(query: str, explanation: str, evidence_texts: list[str])
|
|
| 120 |
Raises:
|
| 121 |
ImportError: If ragas is not installed.
|
| 122 |
"""
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
except ImportError:
|
| 126 |
-
raise ImportError("ragas package required. Install with: pip install ragas")
|
| 127 |
|
| 128 |
# Clean explanation for RAGAS evaluation
|
| 129 |
cleaned_explanation = _clean_explanation_for_ragas(explanation)
|
|
@@ -162,10 +161,8 @@ def get_ragas_llm(provider: str | None = None):
|
|
| 162 |
Returns:
|
| 163 |
RAGAS-compatible LLM wrapper.
|
| 164 |
"""
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
except ImportError:
|
| 168 |
-
raise ImportError("ragas package required. Install with: pip install ragas")
|
| 169 |
|
| 170 |
provider = provider or LLM_PROVIDER
|
| 171 |
|
|
@@ -211,10 +208,8 @@ class FaithfulnessEvaluator:
|
|
| 211 |
provider: LLM provider for RAGAS.
|
| 212 |
target: Faithfulness target score (default 0.85).
|
| 213 |
"""
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
except ImportError:
|
| 217 |
-
raise ImportError("ragas package required. Install with: pip install ragas")
|
| 218 |
|
| 219 |
self.llm = get_ragas_llm(provider)
|
| 220 |
self.scorer = Faithfulness(llm=self.llm)
|
|
@@ -262,11 +257,9 @@ class FaithfulnessEvaluator:
|
|
| 262 |
explanation_results: list[ExplanationResult],
|
| 263 |
) -> FaithfulnessReport:
|
| 264 |
"""Evaluate faithfulness for multiple explanations."""
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
except ImportError:
|
| 269 |
-
raise ImportError("ragas package required. Install with: pip install ragas")
|
| 270 |
|
| 271 |
samples = _explanation_results_to_samples(explanation_results)
|
| 272 |
dataset = EvaluationDataset(samples=samples)
|
|
|
|
| 15 |
|
| 16 |
import numpy as np
|
| 17 |
|
| 18 |
+
from sage.utils import ensure_ragas_installed
|
| 19 |
from sage.core import (
|
| 20 |
AdjustedFaithfulnessReport,
|
| 21 |
AgreementReport,
|
|
|
|
| 121 |
Raises:
|
| 122 |
ImportError: If ragas is not installed.
|
| 123 |
"""
|
| 124 |
+
ensure_ragas_installed()
|
| 125 |
+
from ragas import SingleTurnSample
|
|
|
|
|
|
|
| 126 |
|
| 127 |
# Clean explanation for RAGAS evaluation
|
| 128 |
cleaned_explanation = _clean_explanation_for_ragas(explanation)
|
|
|
|
| 161 |
Returns:
|
| 162 |
RAGAS-compatible LLM wrapper.
|
| 163 |
"""
|
| 164 |
+
ensure_ragas_installed()
|
| 165 |
+
from ragas.llms import llm_factory
|
|
|
|
|
|
|
| 166 |
|
| 167 |
provider = provider or LLM_PROVIDER
|
| 168 |
|
|
|
|
| 208 |
provider: LLM provider for RAGAS.
|
| 209 |
target: Faithfulness target score (default 0.85).
|
| 210 |
"""
|
| 211 |
+
ensure_ragas_installed()
|
| 212 |
+
from ragas.metrics import Faithfulness
|
|
|
|
|
|
|
| 213 |
|
| 214 |
self.llm = get_ragas_llm(provider)
|
| 215 |
self.scorer = Faithfulness(llm=self.llm)
|
|
|
|
| 257 |
explanation_results: list[ExplanationResult],
|
| 258 |
) -> FaithfulnessReport:
|
| 259 |
"""Evaluate faithfulness for multiple explanations."""
|
| 260 |
+
ensure_ragas_installed()
|
| 261 |
+
from ragas import EvaluationDataset, evaluate
|
| 262 |
+
from ragas.metrics import Faithfulness
|
|
|
|
|
|
|
| 263 |
|
| 264 |
samples = _explanation_results_to_samples(explanation_results)
|
| 265 |
dataset = EvaluationDataset(samples=samples)
|
sage/utils.py
CHANGED
|
@@ -92,6 +92,22 @@ def require_imports(*packages: str | tuple[str, str]) -> list[ModuleType]:
|
|
| 92 |
return modules
|
| 93 |
|
| 94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
# ---------------------------------------------------------------------------
|
| 96 |
# Lazy Loading Utilities
|
| 97 |
# ---------------------------------------------------------------------------
|
|
|
|
| 92 |
return modules
|
| 93 |
|
| 94 |
|
| 95 |
+
def ensure_ragas_installed() -> None:
|
| 96 |
+
"""Ensure RAGAS package is installed.
|
| 97 |
+
|
| 98 |
+
Centralizes the RAGAS availability check used across faithfulness evaluation.
|
| 99 |
+
Call this before importing RAGAS components to get a clear error message.
|
| 100 |
+
|
| 101 |
+
Usage:
|
| 102 |
+
ensure_ragas_installed()
|
| 103 |
+
from ragas import SingleTurnSample # Safe to import now
|
| 104 |
+
|
| 105 |
+
Raises:
|
| 106 |
+
ImportError: If ragas is not installed with install instructions.
|
| 107 |
+
"""
|
| 108 |
+
require_import("ragas")
|
| 109 |
+
|
| 110 |
+
|
| 111 |
# ---------------------------------------------------------------------------
|
| 112 |
# Lazy Loading Utilities
|
| 113 |
# ---------------------------------------------------------------------------
|
tests/test_production.py
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for production hardening fixes.
|
| 2 |
+
|
| 3 |
+
Tests security headers, cache key generation, request ID propagation,
|
| 4 |
+
and other production-critical behaviors.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import pytest
|
| 8 |
+
from types import SimpleNamespace
|
| 9 |
+
from unittest.mock import MagicMock, patch
|
| 10 |
+
|
| 11 |
+
from fastapi import FastAPI
|
| 12 |
+
from fastapi.testclient import TestClient
|
| 13 |
+
|
| 14 |
+
from sage.api.middleware import LatencyMiddleware
|
| 15 |
+
from sage.api.routes import router, _build_cache_key
|
| 16 |
+
from sage.services.cache import SemanticCache
|
| 17 |
+
import numpy as np
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _make_app_with_middleware(**state_overrides) -> FastAPI:
|
| 21 |
+
"""Create a test app with middleware and mocked state."""
|
| 22 |
+
app = FastAPI()
|
| 23 |
+
|
| 24 |
+
# Add latency middleware (includes security headers)
|
| 25 |
+
app.add_middleware(LatencyMiddleware)
|
| 26 |
+
|
| 27 |
+
app.include_router(router)
|
| 28 |
+
|
| 29 |
+
# Mock Qdrant client
|
| 30 |
+
mock_qdrant = MagicMock()
|
| 31 |
+
mock_qdrant.get_collections.return_value = MagicMock(collections=[])
|
| 32 |
+
|
| 33 |
+
# Mock cache
|
| 34 |
+
mock_cache = MagicMock()
|
| 35 |
+
mock_cache.get.return_value = (None, "miss")
|
| 36 |
+
mock_cache.stats.return_value = SimpleNamespace(
|
| 37 |
+
size=0,
|
| 38 |
+
max_entries=100,
|
| 39 |
+
exact_hits=0,
|
| 40 |
+
semantic_hits=0,
|
| 41 |
+
misses=0,
|
| 42 |
+
evictions=0,
|
| 43 |
+
hit_rate=0.0,
|
| 44 |
+
ttl_seconds=3600.0,
|
| 45 |
+
similarity_threshold=0.92,
|
| 46 |
+
avg_semantic_similarity=0.0,
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
# Mock explainer with client attribute for health check
|
| 50 |
+
mock_explainer = MagicMock()
|
| 51 |
+
mock_explainer.client = MagicMock()
|
| 52 |
+
|
| 53 |
+
app.state.qdrant = state_overrides.get("qdrant", mock_qdrant)
|
| 54 |
+
app.state.embedder = state_overrides.get("embedder", MagicMock())
|
| 55 |
+
app.state.detector = state_overrides.get("detector", MagicMock())
|
| 56 |
+
app.state.explainer = state_overrides.get("explainer", mock_explainer)
|
| 57 |
+
app.state.cache = state_overrides.get("cache", mock_cache)
|
| 58 |
+
|
| 59 |
+
return app
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class TestSecurityHeaders:
|
| 63 |
+
"""Test that security headers are added to all responses."""
|
| 64 |
+
|
| 65 |
+
@pytest.fixture
|
| 66 |
+
def client(self):
|
| 67 |
+
app = _make_app_with_middleware()
|
| 68 |
+
return TestClient(app)
|
| 69 |
+
|
| 70 |
+
@patch("sage.api.routes.collection_exists", return_value=True)
|
| 71 |
+
def test_security_headers_present(self, mock_collection_exists, client):
|
| 72 |
+
resp = client.get("/health")
|
| 73 |
+
assert resp.status_code == 200
|
| 74 |
+
|
| 75 |
+
# Check security headers
|
| 76 |
+
assert resp.headers.get("x-content-type-options") == "nosniff"
|
| 77 |
+
assert resp.headers.get("x-frame-options") == "DENY"
|
| 78 |
+
assert resp.headers.get("x-xss-protection") == "1; mode=block"
|
| 79 |
+
assert resp.headers.get("referrer-policy") == "strict-origin-when-cross-origin"
|
| 80 |
+
assert "no-store" in resp.headers.get("cache-control", "")
|
| 81 |
+
|
| 82 |
+
@patch("sage.api.routes.collection_exists", return_value=True)
|
| 83 |
+
def test_request_id_header_present(self, mock_collection_exists, client):
|
| 84 |
+
resp = client.get("/health")
|
| 85 |
+
assert resp.status_code == 200
|
| 86 |
+
|
| 87 |
+
# Check request ID is present and has expected format
|
| 88 |
+
request_id = resp.headers.get("x-request-id")
|
| 89 |
+
assert request_id is not None
|
| 90 |
+
assert len(request_id) == 12 # UUID hex[:12]
|
| 91 |
+
|
| 92 |
+
@patch("sage.api.routes.collection_exists", return_value=True)
|
| 93 |
+
def test_response_time_header_present(self, mock_collection_exists, client):
|
| 94 |
+
resp = client.get("/health")
|
| 95 |
+
assert resp.status_code == 200
|
| 96 |
+
|
| 97 |
+
# Check response time header
|
| 98 |
+
response_time = resp.headers.get("x-response-time-ms")
|
| 99 |
+
assert response_time is not None
|
| 100 |
+
assert float(response_time) >= 0
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class TestCacheKeyGeneration:
|
| 104 |
+
"""Test that cache keys include all request parameters."""
|
| 105 |
+
|
| 106 |
+
def test_cache_key_includes_query(self):
|
| 107 |
+
key1 = _build_cache_key("headphones", k=3, explain=True, min_rating=4.0)
|
| 108 |
+
key2 = _build_cache_key("earbuds", k=3, explain=True, min_rating=4.0)
|
| 109 |
+
assert key1 != key2
|
| 110 |
+
|
| 111 |
+
def test_cache_key_includes_k(self):
|
| 112 |
+
key1 = _build_cache_key("headphones", k=3, explain=True, min_rating=4.0)
|
| 113 |
+
key2 = _build_cache_key("headphones", k=5, explain=True, min_rating=4.0)
|
| 114 |
+
assert key1 != key2
|
| 115 |
+
assert "k=3" in key1
|
| 116 |
+
assert "k=5" in key2
|
| 117 |
+
|
| 118 |
+
def test_cache_key_includes_explain(self):
|
| 119 |
+
key1 = _build_cache_key("headphones", k=3, explain=True, min_rating=4.0)
|
| 120 |
+
key2 = _build_cache_key("headphones", k=3, explain=False, min_rating=4.0)
|
| 121 |
+
assert key1 != key2
|
| 122 |
+
assert "explain=True" in key1
|
| 123 |
+
assert "explain=False" in key2
|
| 124 |
+
|
| 125 |
+
def test_cache_key_includes_rating(self):
|
| 126 |
+
key1 = _build_cache_key("headphones", k=3, explain=True, min_rating=4.0)
|
| 127 |
+
key2 = _build_cache_key("headphones", k=3, explain=True, min_rating=3.5)
|
| 128 |
+
assert key1 != key2
|
| 129 |
+
assert "rating=4.0" in key1
|
| 130 |
+
assert "rating=3.5" in key2
|
| 131 |
+
|
| 132 |
+
def test_cache_key_normalizes_query(self):
|
| 133 |
+
key1 = _build_cache_key(
|
| 134 |
+
" Best Headphones ", k=3, explain=True, min_rating=4.0
|
| 135 |
+
)
|
| 136 |
+
key2 = _build_cache_key("best headphones", k=3, explain=True, min_rating=4.0)
|
| 137 |
+
assert key1 == key2
|
| 138 |
+
|
| 139 |
+
def test_cache_key_case_insensitive(self):
|
| 140 |
+
key1 = _build_cache_key("HEADPHONES", k=3, explain=True, min_rating=4.0)
|
| 141 |
+
key2 = _build_cache_key("headphones", k=3, explain=True, min_rating=4.0)
|
| 142 |
+
assert key1 == key2
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class TestCacheIntegration:
|
| 146 |
+
"""Integration tests for cache with request parameters."""
|
| 147 |
+
|
| 148 |
+
def test_same_query_different_k_different_cache_entries(self):
|
| 149 |
+
cache = SemanticCache(max_entries=100, ttl_seconds=3600)
|
| 150 |
+
|
| 151 |
+
# Create fake embeddings
|
| 152 |
+
embedding = np.random.rand(384).astype(np.float32)
|
| 153 |
+
|
| 154 |
+
# Store result with k=3
|
| 155 |
+
key1 = _build_cache_key("headphones", k=3, explain=True, min_rating=4.0)
|
| 156 |
+
result1 = {"query": "headphones", "recommendations": ["p1", "p2", "p3"]}
|
| 157 |
+
cache.put(key1, embedding, result1)
|
| 158 |
+
|
| 159 |
+
# Store result with k=5
|
| 160 |
+
key2 = _build_cache_key("headphones", k=5, explain=True, min_rating=4.0)
|
| 161 |
+
result2 = {
|
| 162 |
+
"query": "headphones",
|
| 163 |
+
"recommendations": ["p1", "p2", "p3", "p4", "p5"],
|
| 164 |
+
}
|
| 165 |
+
cache.put(key2, embedding, result2)
|
| 166 |
+
|
| 167 |
+
# Retrieve k=3 result
|
| 168 |
+
cached1, hit_type1 = cache.get(key1, embedding)
|
| 169 |
+
assert cached1 is not None
|
| 170 |
+
assert len(cached1["recommendations"]) == 3
|
| 171 |
+
|
| 172 |
+
# Retrieve k=5 result
|
| 173 |
+
cached2, hit_type2 = cache.get(key2, embedding)
|
| 174 |
+
assert cached2 is not None
|
| 175 |
+
assert len(cached2["recommendations"]) == 5
|
| 176 |
+
|
| 177 |
+
def test_same_query_different_rating_different_cache_entries(self):
|
| 178 |
+
cache = SemanticCache(max_entries=100, ttl_seconds=3600)
|
| 179 |
+
embedding = np.random.rand(384).astype(np.float32)
|
| 180 |
+
|
| 181 |
+
# Store with rating=4.0
|
| 182 |
+
key1 = _build_cache_key("headphones", k=3, explain=True, min_rating=4.0)
|
| 183 |
+
cache.put(key1, embedding, {"rating_filter": 4.0})
|
| 184 |
+
|
| 185 |
+
# Store with rating=3.5
|
| 186 |
+
key2 = _build_cache_key("headphones", k=3, explain=True, min_rating=3.5)
|
| 187 |
+
cache.put(key2, embedding, {"rating_filter": 3.5})
|
| 188 |
+
|
| 189 |
+
# Verify they're separate entries
|
| 190 |
+
cached1, _ = cache.get(key1, embedding)
|
| 191 |
+
cached2, _ = cache.get(key2, embedding)
|
| 192 |
+
assert cached1["rating_filter"] == 4.0
|
| 193 |
+
assert cached2["rating_filter"] == 3.5
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
class TestRequestContext:
|
| 197 |
+
"""Test request ID context propagation."""
|
| 198 |
+
|
| 199 |
+
def test_request_id_context_var(self):
|
| 200 |
+
from sage.api.context import get_request_id, set_request_id
|
| 201 |
+
|
| 202 |
+
# Default value
|
| 203 |
+
assert get_request_id() == "-"
|
| 204 |
+
|
| 205 |
+
# Set and get
|
| 206 |
+
set_request_id("abc123")
|
| 207 |
+
assert get_request_id() == "abc123"
|
| 208 |
+
|
| 209 |
+
# Reset for other tests
|
| 210 |
+
set_request_id("-")
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
class TestCORSConfiguration:
|
| 214 |
+
"""Test CORS configuration security."""
|
| 215 |
+
|
| 216 |
+
def test_cors_not_applied_when_empty(self):
|
| 217 |
+
"""When CORS_ORIGINS is empty, no CORS middleware should be added."""
|
| 218 |
+
from sage.api.app import CORS_ORIGINS
|
| 219 |
+
|
| 220 |
+
# This test verifies the default behavior
|
| 221 |
+
# In production, CORS_ORIGINS should be explicitly set
|
| 222 |
+
# Default is empty list (no CORS)
|
| 223 |
+
assert isinstance(CORS_ORIGINS, list)
|
| 224 |
+
|
| 225 |
+
def test_cors_origins_parsing(self):
|
| 226 |
+
"""Test that CORS origins are parsed correctly."""
|
| 227 |
+
import os
|
| 228 |
+
|
| 229 |
+
# Save original
|
| 230 |
+
original = os.environ.get("CORS_ORIGINS")
|
| 231 |
+
|
| 232 |
+
try:
|
| 233 |
+
# Test with explicit origins
|
| 234 |
+
os.environ["CORS_ORIGINS"] = "https://example.com,http://localhost:3000"
|
| 235 |
+
# Would need to reload the module to test this properly
|
| 236 |
+
# Just verify the format is correct
|
| 237 |
+
origins = [
|
| 238 |
+
o.strip() for o in os.environ["CORS_ORIGINS"].split(",") if o.strip()
|
| 239 |
+
]
|
| 240 |
+
assert origins == ["https://example.com", "http://localhost:3000"]
|
| 241 |
+
|
| 242 |
+
# Test with empty string
|
| 243 |
+
os.environ["CORS_ORIGINS"] = ""
|
| 244 |
+
origins = [
|
| 245 |
+
o.strip() for o in os.environ["CORS_ORIGINS"].split(",") if o.strip()
|
| 246 |
+
]
|
| 247 |
+
assert origins == []
|
| 248 |
+
|
| 249 |
+
finally:
|
| 250 |
+
# Restore original
|
| 251 |
+
if original is not None:
|
| 252 |
+
os.environ["CORS_ORIGINS"] = original
|
| 253 |
+
elif "CORS_ORIGINS" in os.environ:
|
| 254 |
+
del os.environ["CORS_ORIGINS"]
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
class TestInputValidation:
|
| 258 |
+
"""Test input validation edge cases."""
|
| 259 |
+
|
| 260 |
+
@pytest.fixture
|
| 261 |
+
def client(self):
|
| 262 |
+
app = _make_app_with_middleware()
|
| 263 |
+
return TestClient(app)
|
| 264 |
+
|
| 265 |
+
def test_empty_query_rejected(self, client):
|
| 266 |
+
resp = client.post("/recommend", json={"query": ""})
|
| 267 |
+
assert resp.status_code == 422
|
| 268 |
+
|
| 269 |
+
def test_query_too_long_rejected(self, client):
|
| 270 |
+
resp = client.post("/recommend", json={"query": "x" * 501})
|
| 271 |
+
assert resp.status_code == 422
|
| 272 |
+
|
| 273 |
+
def test_k_zero_rejected(self, client):
|
| 274 |
+
resp = client.post("/recommend", json={"query": "test", "k": 0})
|
| 275 |
+
assert resp.status_code == 422
|
| 276 |
+
|
| 277 |
+
def test_k_too_large_rejected(self, client):
|
| 278 |
+
resp = client.post("/recommend", json={"query": "test", "k": 11})
|
| 279 |
+
assert resp.status_code == 422
|
| 280 |
+
|
| 281 |
+
def test_invalid_min_rating_rejected(self, client):
|
| 282 |
+
resp = client.post(
|
| 283 |
+
"/recommend",
|
| 284 |
+
json={"query": "test", "filters": {"min_rating": 10.0}},
|
| 285 |
+
)
|
| 286 |
+
assert resp.status_code == 422
|
| 287 |
+
|
| 288 |
+
def test_negative_price_rejected(self, client):
|
| 289 |
+
resp = client.post(
|
| 290 |
+
"/recommend",
|
| 291 |
+
json={"query": "test", "filters": {"min_price": -1.0}},
|
| 292 |
+
)
|
| 293 |
+
assert resp.status_code == 422
|