Add docker-compose, multi-stage builds, and developer tooling
- .dockerignore +1 -0
- .env.example +17 -22
- .gitignore +11 -4
- Dockerfile +50 -13
- Makefile +143 -38
- README.md +112 -46
- docker-compose.yml +66 -0
- render.yaml +4 -1
- reports/eda_report.md +0 -150
- requirements.txt +58 -0
- sage/adapters/embeddings.py +5 -19
- sage/adapters/hhem.py +19 -22
- sage/adapters/llm.py +178 -58
- sage/adapters/vector_store.py +3 -6
- sage/api/app.py +33 -3
- sage/api/metrics.py +117 -6
- sage/api/middleware.py +155 -22
- sage/api/routes.py +458 -153
- sage/api/run.py +1 -1
- sage/config/__init__.py +7 -5
- sage/config/queries.py +42 -37
- sage/core/prompts.py +2 -3
- sage/core/verification.py +1 -15
- sage/services/baselines.py +5 -8
- sage/services/cache.py +90 -34
- sage/services/cold_start.py +5 -18
- sage/services/evaluation.py +2 -4
- sage/services/explanation.py +11 -16
- sage/services/retrieval.py +17 -35
- sage/utils.py +310 -0
- scripts/demo.py +4 -5
- scripts/e2e_success_rate.py +7 -10
- scripts/eda.py +198 -0
- scripts/evaluation.py +6 -37
- scripts/explanation.py +3 -7
- scripts/faithfulness.py +4 -8
- scripts/human_eval.py +22 -6
- scripts/lib/__init__.py +5 -0
- scripts/lib/services.py +24 -0
- scripts/load_test.py +230 -0
- scripts/pipeline.py +2 -1
- scripts/sanity_checks.py +22 -36
- tests/conftest.py +71 -0
- tests/test_aggregation.py +18 -29
- tests/test_api.py +102 -97
- tests/test_evidence.py +18 -40
.dockerignore
CHANGED
```diff
@@ -6,6 +6,7 @@ venv/
 data/
 home/
 scripts/
+tests/
 *.parquet
 *.npy
 __pycache__/
```
.env.example
CHANGED
```diff
@@ -1,28 +1,23 @@
 # Sage RAG Recommendation System - Environment Variables
 # Copy this file to .env and fill in your values
 
-#
-
-
-
-# HuggingFace (optional, for private models)
-HF_TOKEN=your_huggingface_token
-
-# LLM Provider for explanation generation
-# Options: "anthropic" or "openai"
-LLM_PROVIDER=anthropic
-
-# Anthropic API Key (required if LLM_PROVIDER=anthropic)
+# =============================================================================
+# LLM Provider (required)
+# =============================================================================
+LLM_PROVIDER=anthropic  # or "openai"
 ANTHROPIC_API_KEY=your_anthropic_api_key
+# OPENAI_API_KEY=your_openai_api_key
 
-#
-
-
-#
-#
-#
+# =============================================================================
+# Qdrant Vector Database
+# =============================================================================
+# Local: docker-compose handles this automatically (no config needed)
+# Cloud: uncomment and set for deployment or to use Qdrant Cloud
+# QDRANT_URL=https://your-cluster.cloud.qdrant.io
+# QDRANT_API_KEY=your_qdrant_api_key
 
-#
-#
-#
-#
+# =============================================================================
+# Optional
+# =============================================================================
+# HF_TOKEN=your_huggingface_token  # For private models
+# PORT=8000  # Render/Railway inject automatically
```
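How these variables reach the application is not shown in this commit; with `python-dotenv` pinned in the new requirements.txt, the loading in `sage/config` plausibly looks like the sketch below. The exact defaults, including the localhost Qdrant fallback, are assumptions.

```python
# Sketch of how .env plausibly reaches sage/config via python-dotenv
# (pinned in the new requirements.txt). The loading point and defaults
# are not shown in this commit; treat them as assumptions.
import os

from dotenv import load_dotenv

load_dotenv()  # no-op when .env is absent (e.g. in Docker, where env_file supplies vars)

LLM_PROVIDER = os.getenv("LLM_PROVIDER", "anthropic")
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")  # None -> unauthenticated local instance
```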
.gitignore
CHANGED
```diff
@@ -8,10 +8,6 @@ __pycache__/
 
 # Data (too large for git)
 data/
-*.parquet
-*.csv
-*.json
-!.env.example
 
 # IDE
 .vscode/
@@ -26,8 +22,19 @@ data/
 
 # Build
 *.egg-info/
+*.egg
 dist/
 build/
 
+# Testing & Linting
+.pytest_cache/
+.mypy_cache/
+.ruff_cache/
+.coverage
+htmlcov/
+
+# Logs
+*.log
+
 # Personal
 home/
```
Dockerfile
CHANGED
```diff
@@ -1,36 +1,42 @@
-
+# =============================================================================
+# Stage 1: Builder - install dependencies and download models
+# =============================================================================
+FROM python:3.11-slim-bookworm AS builder
 
 WORKDIR /app
 
-# System dependencies
+# System dependencies for building
 RUN apt-get update && \
     apt-get install -y --no-install-recommends curl && \
     rm -rf /var/lib/apt/lists/*
 
-#
-RUN addgroup --system sage && adduser --system --ingroup sage sage
-
-# Ensure pip uses CPU-only torch for all subsequent installs
+# Use CPU-only torch (avoids 2GB+ CUDA libs)
 ENV PIP_EXTRA_INDEX_URL=https://download.pytorch.org/whl/cpu
 
 # Install torch CPU-only first
 RUN pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu
 
-# Install
+# Install pinned dependencies from requirements.txt for reproducible builds
+COPY requirements.txt .
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Copy application code and install package (--no-deps since deps already installed)
+# Note: pyproject.toml is copied last to maximize layer caching. If only
+# pyproject.toml changes (e.g., version bump), only this layer rebuilds.
 COPY pyproject.toml .
 COPY sage/ sage/
-RUN pip install --no-cache-dir
+RUN pip install --no-cache-dir . --no-deps
 
-#
+# Pre-download models to cache directory
 ENV HF_HOME=/app/.cache/huggingface
 
-#
+# Download E5-small embedding model (~134MB)
 RUN python -c "\
 from sentence_transformers import SentenceTransformer; \
 SentenceTransformer('intfloat/e5-small-v2')"
 
-#
-# HHEM uses
+# Download HHEM hallucination detection model (~892MB)
+# HHEM uses custom config pointing to foundation T5 model for tokenizer
 RUN python -c "\
 from transformers import AutoConfig, AutoTokenizer; \
 from huggingface_hub import hf_hub_download; \
@@ -39,6 +45,36 @@ AutoTokenizer.from_pretrained(config.foundation); \
 AutoConfig.from_pretrained(config.foundation); \
 hf_hub_download('vectara/hallucination_evaluation_model', 'model.safetensors')"
 
+
+# =============================================================================
+# Stage 2: Runtime - slim image with only what's needed
+# =============================================================================
+FROM python:3.11-slim-bookworm AS runtime
+
+WORKDIR /app
+
+# Only curl for healthcheck (no build tools)
+RUN apt-get update && \
+    apt-get install -y --no-install-recommends curl && \
+    rm -rf /var/lib/apt/lists/*
+
+# Non-root user for security
+RUN addgroup --system sage && adduser --system --ingroup sage sage
+
+# Copy installed packages from builder
+COPY --from=builder /usr/local/lib/python3.11/site-packages /usr/local/lib/python3.11/site-packages
+COPY --from=builder /usr/local/bin /usr/local/bin
+
+# Copy application code
+COPY --from=builder /app/sage /app/sage
+
+# Copy pre-downloaded models from builder
+COPY --from=builder /app/.cache /app/.cache
+
+# Environment
+ENV HF_HOME=/app/.cache/huggingface
+ENV PYTHONUNBUFFERED=1
+
 # Fix ownership for non-root user
 RUN chown -R sage:sage /app
 
@@ -47,6 +83,7 @@ USER sage
 # Default port; overridden by PORT env var at runtime (Render, Railway)
 EXPOSE 8000
 
+# Health check with startup grace period (models take ~30s to load)
 HEALTHCHECK --interval=30s --timeout=5s --start-period=60s --retries=3 \
     CMD curl -sf http://localhost:${PORT:-8000}/health || exit 1
```
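The HHEM pre-download is packed into a `python -c` one-liner above; expanded for readability it is roughly the sketch below. The line that constructs `config` falls outside the visible hunk, so that `AutoConfig` call (and its `trust_remote_code` flag) is an assumption; the tokenizer, config, and weight downloads mirror the visible commands.

```python
# The builder-stage HHEM pre-download, expanded from the Dockerfile
# one-liner for readability. The `config = ...` line is not visible in
# the hunk above, so that call is an assumption; the rest mirrors the
# visible commands.
from huggingface_hub import hf_hub_download
from transformers import AutoConfig, AutoTokenizer

# Assumed: HHEM ships a custom config whose `foundation` field names the
# underlying T5 checkpoint used for the tokenizer.
config = AutoConfig.from_pretrained(
    "vectara/hallucination_evaluation_model", trust_remote_code=True
)
AutoTokenizer.from_pretrained(config.foundation)
AutoConfig.from_pretrained(config.foundation)

# Fetch the weights into HF_HOME so the runtime stage never hits the Hub.
hf_hub_download("vectara/hallucination_evaluation_model", "model.safetensors")
```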
Makefile
CHANGED
```diff
@@ -1,4 +1,14 @@
-.PHONY: all setup data eval eval-deep eval-quick demo reset reset-hard check-env qdrant-up qdrant-down qdrant-status eda serve serve-dev docker-build docker-run deploy-info human-eval-generate human-eval human-eval-analyze test lint typecheck help
+.PHONY: all setup data data-validate eval eval-deep eval-quick demo demo-interview reset reset-hard check-env qdrant-up qdrant-down qdrant-status eda serve serve-dev docker-build docker-run deploy-info human-eval-generate human-eval human-eval-analyze test lint typecheck ci info summary metrics-snapshot health help
+
+# ---------------------------------------------------------------------------
+# Configurable Variables (override: make demo QUERY="gaming mouse")
+# ---------------------------------------------------------------------------
+
+QUERY ?= wireless headphones with noise cancellation
+TOP_K ?= 1
+SAMPLES ?= 10
+SEED ?= 42
+PORT ?= 8000
 
 # ---------------------------------------------------------------------------
 # Environment Check
@@ -41,13 +51,30 @@ data: check-env
 	@test -f data/splits/train.parquet || (echo "FAIL: train.parquet not created" && exit 1)
 	@echo "Data pipeline complete"
 
-#
+# Validate data outputs exist and have expected structure
+data-validate:
+	@echo "Validating data outputs..."
+	@test -f data/splits/train.parquet || (echo "FAIL: train.parquet missing" && exit 1)
+	@test -f data/splits/test.parquet || (echo "FAIL: test.parquet missing" && exit 1)
+	@python -c "\
+	import pandas as pd; import numpy as np; from pathlib import Path; \
+	t = pd.read_parquet('data/splits/train.parquet'); \
+	e = list(Path('data').glob('embeddings_*.npy')); \
+	emb = np.load(e[0]) if e else None; \
+	print(f'Train: {len(t):,} rows, {t.parent_asin.nunique():,} products'); \
+	print(f'Embeddings: {emb.shape if emb is not None else \"not found\"}'); \
+	assert len(t) > 1000, 'Train set too small'; \
+	assert emb is not None and emb.shape[1] == 384, 'Embedding dimension mismatch'; \
+	print('Validation passed')"
+
+# Exploratory data analysis (generates figures + report)
 eda:
 	@echo "=== EDA ANALYSIS ==="
 	@mkdir -p data/figures
+	@mkdir -p reports
 	python scripts/eda.py
 	@echo "Figures saved to data/figures/"
-	@echo "
+	@echo "Report generated: reports/eda_report.md"
 
 # ---------------------------------------------------------------------------
 # Evaluation Suite
@@ -75,7 +102,7 @@ eval: check-env
 	python scripts/explanation.py --section cold && \
 	echo "" && \
 	echo "--- Faithfulness evaluation (HHEM + RAGAS) ---" && \
-	python scripts/faithfulness.py --samples
+	python scripts/faithfulness.py --samples $(SAMPLES) --ragas && \
 	echo "" && \
 	echo "--- Sanity checks (spot) ---" && \
 	python scripts/sanity_checks.py --section spot && \
@@ -119,7 +146,22 @@ eval-quick: check-env
 # Interactive recommendation with explanation
 demo: check-env
 	@echo "=== DEMO ==="
-	python scripts/demo.py --query "
+	python scripts/demo.py --query "$(QUERY)" --top-k $(TOP_K)
+
+# Interview demo: 3 queries showcasing cache hit
+demo-interview: check-env
+	@echo "=== SAGE INTERVIEW DEMO ==="
+	@echo ""
+	@echo "--- Query 1: Basic ---"
+	python scripts/demo.py --query "wireless earbuds for running" --top-k 1
+	@echo ""
+	@echo "--- Query 2: Complex (retrieval depth) ---"
+	python scripts/demo.py --query "noise cancelling headphones for office with long battery" --top-k 1
+	@echo ""
+	@echo "--- Query 3: Cache Hit (same as Query 1) ---"
+	python scripts/demo.py --query "wireless earbuds for running" --top-k 1
+	@echo ""
+	@echo "=== Demo Complete ==="
 
 # ---------------------------------------------------------------------------
 # Full Pipeline
@@ -159,7 +201,7 @@ deploy-info:
 
 human-eval-generate: check-env
 	@echo "=== GENERATING HUMAN EVAL SAMPLES ==="
-	python scripts/human_eval.py --generate
+	python scripts/human_eval.py --generate --seed $(SEED)
 
 human-eval: check-env
 	@echo "=== HUMAN EVALUATION ==="
@@ -183,6 +225,43 @@ typecheck:
 test:
 	python -m pytest tests/ -v
 
+ci: lint typecheck test
+	@echo "All CI checks passed"
+
+# ---------------------------------------------------------------------------
+# Info & Metrics
+# ---------------------------------------------------------------------------
+
+info:
+	@python -c "\
+	import sys; from sage.config import EMBEDDING_MODEL, QDRANT_URL, LLM_PROVIDER, ANTHROPIC_MODEL, OPENAI_MODEL; \
+	print('Sage v0.1.0'); \
+	print(f'Python: {sys.version_info.major}.{sys.version_info.minor}'); \
+	print(f'Embedding: {EMBEDDING_MODEL}'); \
+	print(f'Qdrant: {QDRANT_URL}'); \
+	print(f'LLM: {LLM_PROVIDER} ({ANTHROPIC_MODEL if LLM_PROVIDER == \"anthropic\" else OPENAI_MODEL})')"
+
+summary:
+	@python scripts/summary.py
+
+metrics-snapshot:
+	@python -c "\
+	import json; from pathlib import Path; \
+	r = Path('data/eval_results'); \
+	loo = json.load(open(r/'eval_loo_history_latest.json', encoding='utf-8')) if (r/'eval_loo_history_latest.json').exists() else {}; \
+	faith = json.load(open(r/'faithfulness_latest.json', encoding='utf-8')) if (r/'faithfulness_latest.json').exists() else {}; \
+	human = json.load(open(r/'human_eval_latest.json', encoding='utf-8')) if (r/'human_eval_latest.json').exists() else {}; \
+	pm = loo.get('primary_metrics', {}); mm = faith.get('multi_metric', {}); \
+	print('=== SAGE METRICS ==='); \
+	print(f'NDCG@10: {pm.get(\"ndcg_at_10\", \"n/a\")}'); \
+	print(f'Claim HHEM: {mm.get(\"claim_level_avg_score\", \"n/a\")}'); \
+	print(f'Quote Verif: {mm.get(\"quote_verification_rate\", \"n/a\")}'); \
+	print(f'Human Eval: {human.get(\"overall_helpfulness\", \"n/a\")}/5.0 (n={human.get(\"n_samples\", 0)})')"
+
+health:
+	@curl -sf http://localhost:$(PORT)/health | python -m json.tool 2>/dev/null || \
+	echo "API not running at localhost:$(PORT). Start with: make serve"
+
 # ---------------------------------------------------------------------------
 # Reset
 # ---------------------------------------------------------------------------
@@ -197,8 +276,8 @@ reset:
 	rm -f data/eval_results/eval_*.json
 	rm -f data/eval_results/faithfulness_*.json
 	@echo " (human_eval_*.json preserved — use rm -rf data/eval_results/ to clear)"
-	rm -rf data/explanations/
 	rm -rf data/figures/
+	rm -f reports/eda_report.md
 	@echo "Clearing Qdrant collection..."
 	@python -c "\
 	from sage.adapters.vector_store import get_client; \
@@ -207,13 +286,20 @@ reset:
 	echo " Qdrant not reachable, skipping collection cleanup"
 	@echo "Done. (Raw download cache preserved — use 'make reset-hard' to clear)"
 
-# Hard reset:
+# Hard reset: remove EVERYTHING (ground zero for fresh start)
 reset-hard: reset
 	@echo "Removing raw download cache..."
 	rm -f data/reviews_[0-9]*.parquet
 	rm -f data/reviews_full.parquet
 	rm -rf data/qdrant_storage/
-	@echo "
+	@echo "Removing human eval data..."
+	rm -rf data/human_eval/
+	rm -f data/eval_results/human_eval_*.json
+	@echo "Removing e2e success results..."
+	rm -f data/eval_results/e2e_success_*.json
+	@echo "Removing any remaining eval results..."
+	rm -rf data/eval_results/
+	@echo "Hard reset complete. Project at ground zero."
 
 # ---------------------------------------------------------------------------
 # Qdrant Management
@@ -229,11 +315,12 @@ qdrant-up:
 	docker start qdrant 2>/dev/null || true
 	@echo "Waiting for Qdrant..."
 	@for i in 1 2 3 4 5 6 7 8 9 10; do \
-
+		python -c "from sage.adapters.vector_store import get_client; get_client().get_collections()" 2>/dev/null && break; \
 		sleep 1; \
 	done
-	@
-
+	@python -c "\
+	from sage.adapters.vector_store import get_client; from sage.config import QDRANT_URL; \
+	get_client().get_collections(); print(f'Qdrant running at {QDRANT_URL}')" 2>/dev/null || \
 	(echo "ERROR: Qdrant failed to start within 10 seconds" && exit 1)
 
 qdrant-down:
@@ -256,43 +343,61 @@ qdrant-status:
 help:
 	@echo "Sage - RAG Recommendation System"
 	@echo ""
-	@echo "
+	@echo "QUICK START:"
 	@echo " make setup               Create venv and install dependencies"
-	@echo " make qdrant-up           Start Qdrant vector database (Docker)"
-	@echo " make qdrant-down         Stop Qdrant"
-	@echo " make qdrant-status       Check Qdrant status"
-	@echo ""
-	@echo "PIPELINE:"
 	@echo " make data                Load, chunk, embed, and index reviews"
-	@echo " make
-	@echo " make eval                Standard evaluation (primary metrics + RAGAS + spot-checks)"
-	@echo " make eval-deep           Deep evaluation (all ablations + baselines + calibration)"
-	@echo " make eval-quick          Quick eval (skip RAGAS)"
-	@echo " make demo                Run demo query"
+	@echo " make demo                Run demo query (customizable: QUERY, TOP_K)"
 	@echo " make all                 Full pipeline (data + eval + demo + summary)"
 	@echo ""
+	@echo "DEMO:"
+	@echo " make demo                        Single recommendation with explanation"
+	@echo " make demo QUERY=\"gaming mouse\"   Custom query"
+	@echo " make demo-interview              3-query showcase (includes cache hit)"
+	@echo ""
+	@echo "INFO & METRICS:"
+	@echo " make info                Show version, models, and URLs"
+	@echo " make summary             Print evaluation summary"
+	@echo " make metrics-snapshot    Quick metrics display"
+	@echo " make health              Check API health (requires running server)"
+	@echo ""
+	@echo "PIPELINE:"
+	@echo " make data                Load, chunk, embed, and index reviews"
+	@echo " make data-validate       Validate data outputs"
+	@echo " make eda                 Exploratory data analysis (generates figures)"
+	@echo " make eval                Standard evaluation (SAMPLES=10 default)"
+	@echo " make eval-deep           Deep evaluation (all ablations + baselines)"
+	@echo " make eval-quick          Quick eval (skip RAGAS)"
+	@echo ""
 	@echo "API:"
-	@echo " make serve
-	@echo " make serve-dev
-	@echo " make docker-build
-	@echo " make docker-run
-	@echo " make deploy-info
+	@echo " make serve               Start API server (PORT=8000)"
+	@echo " make serve-dev           Start API with auto-reload"
+	@echo " make docker-build        Build Docker image"
+	@echo " make docker-run          Run Docker container"
+	@echo " make deploy-info         Show Render deployment instructions"
 	@echo ""
 	@echo "HUMAN EVALUATION:"
-	@echo " make human-eval-generate Generate 50 eval samples"
+	@echo " make human-eval-generate Generate 50 eval samples (SEED=42)"
 	@echo " make human-eval          Rate samples interactively"
 	@echo " make human-eval-analyze  Compute results from ratings"
 	@echo ""
 	@echo "QUALITY:"
-	@echo " make lint
-	@echo " make typecheck
-	@echo " make test
+	@echo " make lint                Run ruff linter and formatter check"
+	@echo " make typecheck           Run mypy type checking"
+	@echo " make test                Run unit tests"
+	@echo " make ci                  Run all CI checks (lint + typecheck + test)"
+	@echo ""
+	@echo "QDRANT:"
+	@echo " make qdrant-up           Start Qdrant vector database (Docker)"
+	@echo " make qdrant-down         Stop Qdrant"
+	@echo " make qdrant-status       Check Qdrant status"
 	@echo ""
 	@echo "CLEANUP:"
-	@echo " make reset
-	@echo " make reset-hard
+	@echo " make reset               Clear generated data and Qdrant collection"
+	@echo " make reset-hard          Reset + clear raw data cache"
 	@echo ""
-	@echo "
-	@echo "
-	@echo "
-	@echo "
+	@echo "VARIABLES:"
+	@echo " QUERY    Demo query (default: wireless headphones...)"
+	@echo " TOP_K    Number of results (default: 1)"
+	@echo " SAMPLES  Faithfulness eval samples (default: 10)"
+	@echo " SEED     Random seed for human eval (default: 42)"
+	@echo " PORT     API port (default: 8000)"
```
README.md
CHANGED
````diff
@@ -1,77 +1,143 @@
 # Sage
 
-RAG-powered product recommendation system with explainable AI. Retrieves relevant products
+RAG-powered product recommendation system with explainable AI. Retrieves relevant products via semantic search over customer reviews, generates natural language explanations grounded in evidence, and verifies faithfulness using hallucination detection.
 
-##
+## Targets
 
 | Metric | Target |
 |--------|--------|
-| Recommendation Quality (NDCG@10) | 0.30 |
-| Explanation Faithfulness (
+| Recommendation Quality (NDCG@10) | > 0.30 |
+| Explanation Faithfulness (RAGAS) | > 0.85 |
+| System Latency (P99) | < 500ms |
+| Human Evaluation (n=50) | > 3.5/5.0 |
 
-## Architecture
-
-```
-Query → Semantic Search (Qdrant) → Rank Products → Generate Explanation (LLM)
-         ↓
-Verify Citations ← Retrieve Evidence
-         ↓
-Check Faithfulness (HHEM) → Response + Confidence
-```
-
 ## Tech Stack
 
-- **Embeddings:** E5-small (384-dim
+- **Embeddings:** E5-small (384-dim)
 - **Vector DB:** Qdrant with semantic caching
 - **LLM:** Claude Sonnet / GPT-4o-mini
 - **Faithfulness:** HHEM (Vectara hallucination detector) + quote verification
-- **API:** FastAPI with streaming support
+- **API:** FastAPI with async handlers and streaming support
+- **Metrics:** Prometheus (latency histograms, cache hit rates, error counts)
 
 ## Quick Start
 
+### Option 1: Docker (easiest)
+
 ```bash
-
-make
+git clone https://github.com/vxa8502/sage-recommendations
+cd sage-recommendations
+cp .env.example .env
+# Edit .env and set ANTHROPIC_API_KEY (or OPENAI_API_KEY)
 
-#
-make serve
+docker-compose up
+curl http://localhost:8000/health
+```
+
+### Option 2: Local Development
+
+```bash
+python3 -m venv .venv
+source .venv/bin/activate
+pip install -e ".[dev,pipeline,api,anthropic]"  # or openai
+
+cp .env.example .env
+# Edit .env: add LLM key + Qdrant (local via `make qdrant-up` or Qdrant Cloud)
+
+make data   # Load data and embeddings
+make serve  # Start API
 ```
 
-##
+## Environment Variables
 
 ```bash
-
+# Required
+LLM_PROVIDER=anthropic  # or "openai"
+ANTHROPIC_API_KEY=your_key_here
+
+# Optional: Qdrant Cloud (for deployment or instead of local)
+# QDRANT_URL=https://your-cluster.cloud.qdrant.io
+# QDRANT_API_KEY=your_qdrant_key
 ```
 
-      "evidence": [{"id": "review_127", "text": "..."}]
-    }]
-}
-```
+## API Reference
+
+### POST /recommend
+
+```bash
+curl -X POST http://localhost:8000/recommend \
+  -H "Content-Type: application/json" \
+  -d '{"query": "wireless earbuds for running", "k": 3, "explain": true}'
+```
+
+Returns ranked products with explanations grounded in customer reviews, HHEM confidence scores, and citation verification.
+
+### POST /recommend/stream
+
+Stream recommendations with token-by-token explanation delivery (SSE).
+
+### GET /health
+
+Service health check.
+
+### GET /metrics
+
+Prometheus metrics: latency histograms, cache hit rates, error counts.
+
+### GET /cache/stats
+
+Cache performance statistics.
+
+## Failure Modes (By Design)
+
+| Condition | System Behavior |
+|-----------|-----------------|
+| Insufficient evidence | Refuses to explain |
+| Quote not found in source | Falls back to paraphrased claims |
+| HHEM confidence below threshold | Flags explanation as uncertain |
+
+The system refuses to hallucinate rather than confidently stating unsupported claims.
+
+## Development
 
 ```bash
-make
-make
-make
+make test  # Run tests
+make lint  # Run linter
+make eval  # Run evaluation suite
+make all   # Full pipeline
 ```
+
+## Project Structure
+
+```
+sage/
+├── adapters/   # External integrations (Qdrant, LLM, HHEM)
+├── api/        # FastAPI routes, middleware, metrics
+├── config/     # Settings, constants, queries
+├── core/       # Domain models, aggregation, verification
+├── services/   # Business logic (retrieval, explanation, cache)
+scripts/
+├── pipeline.py         # Data ingestion and embedding
+├── demo.py             # Interactive demo
+├── evaluation.py       # Recommendation metrics (NDCG, precision, recall)
+├── faithfulness.py     # RAGAS + HHEM faithfulness evaluation
+├── explanation.py      # Explanation quality tests
+├── human_eval.py       # Human evaluation workflow
+├── sanity_checks.py    # Spot checks and calibration
+├── load_test.py        # Latency benchmarking
+├── eda.py              # Exploratory data analysis
+tests/
+├── test_api.py
+├── test_evidence.py
+├── test_aggregation.py
+```
+
+## Future Work
+
+1. **Cross-encoder reranking** for improved precision on top-k candidates
+2. **User feedback loops** for learning from implicit signals
+3. **Hybrid retrieval** with BM25 + dense fusion
+4. **Expanded human evaluation** with stratified sampling
 
 ## License
 
````
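For completeness, the documented `/recommend` call from Python, as a sketch: the request payload matches the curl example in the README, while the response field names are assumptions since only the request shape is documented.

```python
# Python equivalent of the curl example in the README, using httpx
# (pinned in requirements.txt). Response field names are assumptions;
# the request payload matches the documented curl call.
import httpx

resp = httpx.post(
    "http://localhost:8000/recommend",
    json={"query": "wireless earbuds for running", "k": 3, "explain": True},
    timeout=30.0,
)
resp.raise_for_status()
for rec in resp.json().get("recommendations", []):  # field name assumed
    print(rec.get("explanation"), rec.get("confidence"))
```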
docker-compose.yml
ADDED
```diff
@@ -0,0 +1,66 @@
+# Sage RAG Recommendation System - Docker Compose
+#
+# Usage:
+#   1. Copy .env.example to .env and fill in your API keys
+#   2. Run: docker-compose up
+#   3. Hit: http://localhost:8000/health
+#
+# This brings up:
+#   - Sage API (FastAPI) on port 8000
+#   - Qdrant (vector DB) on port 6333
+#
+# For Qdrant Cloud instead of local Qdrant:
+#   Set QDRANT_URL and QDRANT_API_KEY in .env
+#   Local Qdrant still starts but is unused; API connects to cloud
+
+services:
+  # ==========================================================================
+  # Sage API - FastAPI recommendation service
+  # ==========================================================================
+  sage:
+    build:
+      context: .
+      dockerfile: Dockerfile
+    ports:
+      - "${PORT:-8000}:${PORT:-8000}"
+    env_file:
+      - .env
+    environment:
+      - PORT=${PORT:-8000}
+      # Use local Qdrant if QDRANT_URL not set in .env
+      - QDRANT_URL=${QDRANT_URL:-http://qdrant:6333}
+    depends_on:
+      qdrant:
+        condition: service_healthy
+    healthcheck:
+      test: ["CMD", "curl", "-sf", "http://localhost:${PORT:-8000}/health"]
+      interval: 30s
+      timeout: 5s
+      start_period: 90s  # Models take ~60s to load
+      retries: 3
+    restart: unless-stopped
+
+  # ==========================================================================
+  # Qdrant - Vector database for embeddings
+  # ==========================================================================
+  qdrant:
+    image: qdrant/qdrant:v1.7.4
+    ports:
+      - "6333:6333"
+      - "6334:6334"  # gRPC
+    volumes:
+      # Persist vectors across container restarts
+      - qdrant_data:/qdrant/storage
+    environment:
+      - QDRANT__SERVICE__GRPC_PORT=6334
+    healthcheck:
+      test: ["CMD", "curl", "-sf", "http://localhost:6333/readyz"]
+      interval: 10s
+      timeout: 5s
+      start_period: 10s
+      retries: 3
+    restart: unless-stopped
+
+volumes:
+  qdrant_data:
+    driver: local
```
render.yaml
CHANGED
```diff
@@ -2,7 +2,8 @@ services:
   - type: web
     name: sage
     runtime: docker
-    plan:
+    plan: starter
+    region: oregon
     healthCheckPath: /health
     envVars:
       - key: QDRANT_URL
@@ -11,5 +12,7 @@ services:
         sync: false
       - key: ANTHROPIC_API_KEY
         sync: false
+      - key: OPENAI_API_KEY
+        sync: false
       - key: LLM_PROVIDER
         value: anthropic
```
reports/eda_report.md
DELETED
```diff
@@ -1,150 +0,0 @@
-# Exploratory Data Analysis: Amazon Electronics Reviews
-
-**Dataset:** McAuley-Lab/Amazon-Reviews-2023 (Electronics category)
-**Subset:** 100,000 raw reviews → 2,635 after 5-core filtering
-
----
-
-## Dataset Overview
-
-The Amazon Electronics reviews dataset provides rich user feedback data for building recommendation systems. After standard preprocessing and 5-core filtering (requiring users and items to have at least 5 interactions), the dataset exhibits the characteristic sparsity of real-world recommendation scenarios.
-
-| Metric | Raw | After 5-Core |
-|--------|-----|--------------|
-| Total Reviews | 100,000 | 2,635 |
-| Unique Users | 15,322 | 334 |
-| Unique Items | 59,429 | 318 |
-| Avg Rating | 4.26 | 4.44 |
-| Retention | — | 2.6% |
-
----
-
-## Rating Distribution
-
-Amazon reviews exhibit a well-known J-shaped distribution, heavily skewed toward 5-star ratings. This reflects both genuine satisfaction and selection bias (dissatisfied customers often don't leave reviews).
-
-
-
-**Key Observations:**
-- 5-star ratings dominate (65.4% of reviews)
-- 1-star reviews form the second largest group (8.0%)
-- Middle ratings (2-4 stars) are relatively rare (26.6% combined)
-- This polarization is typical for e-commerce review data
-
-**Implications for Modeling:**
-- Binary classification (positive/negative) may be more robust than regression
-- Rating-weighted aggregation should account for the skewed distribution
-- Evidence from 4-5 star reviews carries stronger positive signal
-
----
-
-## Review Length Analysis
-
-Review length varies significantly and correlates with the chunking strategy for the RAG pipeline. Most reviews are short enough to embed directly without chunking.
-
-
-
-**Length Statistics:**
-- Median: 183 characters (~45 tokens)
-- Mean: 369 characters (~92 tokens)
-- Reviews exceeding 200 tokens: 11.2% (require chunking)
-
-**Chunking Strategy Validation:**
-The tiered chunking approach is well-suited to this distribution:
-- **Short (<200 tokens):** No chunking needed — majority of reviews
-- **Medium (200-500 tokens):** Semantic chunking at topic boundaries
-- **Long (>500 tokens):** Semantic + sliding window fallback
-
----
-
-## Review Length by Rating
-
-Negative reviews tend to be longer than positive ones. Users who are dissatisfied often provide detailed explanations of issues, while satisfied users may simply express approval.
-
-
-
-**Pattern:**
-- 1-star reviews: 187 chars median
-- 2-3 star reviews: 258-265 chars median (users explain nuance)
-- 4-star reviews: 297 chars median (longest — detailed positive feedback)
-- 5-star reviews: 152 chars median (shortest — quick endorsements)
-
-**Implications:**
-- Negative reviews provide richer evidence for issue identification
-- Positive reviews may require multiple chunks for substantive explanations
-- Rating filters (min_rating=4) naturally bias toward shorter evidence
-
----
-
-## Temporal Distribution
-
-The dataset spans multiple years of reviews, enabling proper temporal train/validation/test splits that prevent data leakage.
-
-
-
-**Temporal Split Strategy:**
-- **Train (70%):** Oldest reviews — model learns from historical patterns
-- **Validation (10%):** Middle period — hyperparameter tuning
-- **Test (20%):** Most recent — simulates production deployment
-
-This chronological ordering ensures the model never sees "future" data during training.
-
----
-
-## User and Item Activity
-
-The long-tail distribution is pronounced: most users write few reviews, and most items receive few reviews. This sparsity is the fundamental challenge recommendation systems address.
-
-
-
-**User Activity:**
-- Users with only 1 review: 30.1%
-- Users with 5+ reviews: 4,991 (32.6%)
-- Power user max: 820 reviews
-
-**Item Popularity:**
-- Items with only 1 review: 76.0%
-- Items with 5+ reviews: 2,434 (4.1%)
-- Most reviewed item: 326 reviews
-
-**Cold-Start Implications:**
-- Many items have sparse evidence — content-based features are critical
-- User cold-start is common — onboarding preferences help
-- 5-core filtering ensures minimum evidence density for evaluation
-
----
-
-## Data Quality Assessment
-
-The raw dataset contains several quality issues addressed during preprocessing.
-
-| Issue | Count | Resolution |
-|-------|-------|------------|
-| Missing text | 0 | — |
-| Empty reviews | 21 | Removed |
-| Very short (<10 chars) | 2,512 | Removed |
-| Duplicate texts | 5,219 | Kept (valid re-purchases) |
-| Invalid ratings | 0 | — |
-
-**Post-Cleaning:**
-- All reviews have valid text content
-- All ratings are in [1, 5] range
-- All user/product identifiers present
-
----
-
-## Summary
-
-The Amazon Electronics dataset, after 5-core filtering and cleaning, provides a solid foundation for building and evaluating a RAG-based recommendation system:
-
-1. **Scale:** 2,635 reviews across 334 users and 318 items
-2. **Sparsity:** 97.5% — realistic for recommendation evaluation
-3. **Quality:** Clean text, valid ratings, proper identifiers
-4. **Temporal:** Supports chronological train/val/test splits
-5. **Content:** Review lengths suit the tiered chunking strategy
-
-The J-shaped rating distribution and long-tail user/item activity are characteristic of real e-commerce data, making this an appropriate benchmark for portfolio demonstration.
-
----
-
-*Figures generated by `scripts/eda.py` at 300 DPI. Run `make figures` to regenerate.*
```
requirements.txt
ADDED
```diff
@@ -0,0 +1,58 @@
+# Pinned dependencies for Docker builds
+# Generated from: pip freeze on Python 3.11
+#
+# To regenerate:
+#   pip install -e ".[api,anthropic]" && pip freeze | grep -v "^-e" > requirements.txt
+#
+# Core ML dependencies
+torch==2.5.1
+sentence-transformers==3.3.1
+transformers==4.47.1
+huggingface-hub==0.27.0
+safetensors==0.4.5
+numpy==2.2.1
+
+# Vector database
+qdrant-client==1.12.1
+
+# API server
+fastapi==0.115.6
+uvicorn==0.34.0
+starlette==0.41.3
+pydantic==2.10.3
+
+# LLM client
+anthropic==0.42.0
+
+# Metrics
+prometheus-client==0.21.1
+
+# Utilities
+python-dotenv==1.0.1
+sentencepiece==0.2.0
+httpx==0.28.1
+anyio==4.7.0
+certifi==2024.12.14
+charset-normalizer==3.4.1
+click==8.1.8
+filelock==3.16.1
+fsspec==2024.12.0
+h11==0.14.0
+idna==3.10
+Jinja2==3.1.4
+joblib==1.4.2
+MarkupSafe==3.0.2
+packaging==24.2
+pillow==11.1.0
+portalocker==3.0.0
+PyYAML==6.0.2
+regex==2024.11.6
+requests==2.32.3
+scikit-learn==1.6.0
+scipy==1.15.0
+sniffio==1.3.1
+threadpoolctl==3.5.0
+tokenizers==0.21.0
+tqdm==4.67.1
+typing_extensions==4.12.2
+urllib3==2.3.0
```
sage/adapters/embeddings.py
CHANGED
```diff
@@ -13,12 +13,12 @@ the content words overlap. Mitigation: use rating filters to enforce sentiment
 alignment (negative reviews typically have low ratings).
 """
 
-import threading
 from pathlib import Path
 
 import numpy as np
 
 from sage.config import EMBEDDING_BATCH_SIZE, EMBEDDING_MODEL, get_logger
+from sage.utils import require_import, thread_safe_singleton
 
 logger = get_logger(__name__)
 
@@ -40,13 +40,8 @@ class E5Embedder:
         Raises:
             ImportError: If sentence_transformers is not installed.
         """
-        try:
-            from sentence_transformers import SentenceTransformer
-        except ImportError:
-            raise ImportError(
-                "sentence_transformers package required. "
-                "Install with: pip install sentence-transformers"
-            )
+        st = require_import("sentence_transformers", pip_name="sentence-transformers")
+        SentenceTransformer = st.SentenceTransformer
 
         logger.info("Loading embedding model: %s", model_name)
         self.model = SentenceTransformer(model_name)
@@ -152,16 +147,7 @@ class E5Embedder:
         return self.embed_queries([query])[0]
 
 
-_embedder: E5Embedder | None = None
-_embedder_lock = threading.Lock()
-
-
+@thread_safe_singleton
 def get_embedder() -> E5Embedder:
     """Get or create the global embedder instance (thread-safe singleton)."""
-    global _embedder
-    if _embedder is None:
-        with _embedder_lock:
-            if _embedder is None:
-                _embedder = E5Embedder()
-    return _embedder
+    return E5Embedder()
```
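Both call sites above depend on the new `sage/utils.py` (+310 lines), whose body this view does not show. Inferred purely from the call sites, the two helpers plausibly look like this sketch; the real implementations may differ.

```python
# Hypothetical sketch of the sage/utils.py helpers used above; the real
# module (+310 lines in this commit) is not shown, so names and behavior
# here are inferred from the call sites only.
import functools
import importlib
import threading
from collections.abc import Callable
from types import ModuleType
from typing import TypeVar

T = TypeVar("T")


def require_import(module: str, pip_name: str | None = None) -> ModuleType:
    """Import a module, raising a helpful ImportError naming the pip package."""
    try:
        return importlib.import_module(module)
    except ImportError as exc:
        pkg = pip_name or module
        raise ImportError(
            f"{module} package required. Install with: pip install {pkg}"
        ) from exc


def thread_safe_singleton(factory: Callable[[], T]) -> Callable[[], T]:
    """Wrap a zero-arg factory so it runs once; later calls return the cached instance."""
    lock = threading.Lock()
    instance: list[T] = []

    @functools.wraps(factory)
    def get() -> T:
        if not instance:
            with lock:  # double-checked locking, as in the code it replaces
                if not instance:
                    instance.append(factory())
        return instance[0]

    return get
```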
sage/adapters/hhem.py
CHANGED
```diff
@@ -20,9 +20,10 @@ Limitations:
 - Safe evidence budget: ~400 tokens (~3 chunks at 100 tokens each).
 """
 
-import threading
+import time
 import warnings
 
+from sage.api.metrics import observe_hhem_duration
 from sage.core import (
     ClaimResult,
     HallucinationResult,
@@ -33,6 +34,7 @@ from sage.config import (
     HHEM_MODEL,
     get_logger,
 )
+from sage.utils import require_import, thread_safe_singleton
 
 logger = get_logger(__name__)
 
@@ -67,16 +69,17 @@ class HallucinationDetector:
         Raises:
             ImportError: If required packages are not installed.
         """
-
-
-
-
-
-
-
-
-
-
+        # Import required packages
+        torch = require_import("torch")
+        hf_hub = require_import("huggingface_hub")
+        safetensors_torch = require_import("safetensors.torch", pip_name="safetensors")
+        transformers = require_import("transformers")
+
+        hf_hub_download = hf_hub.hf_hub_download
+        load_file = safetensors_torch.load_file
+        AutoConfig = transformers.AutoConfig
+        AutoTokenizer = transformers.AutoTokenizer
+        T5ForTokenClassification = transformers.T5ForTokenClassification
 
         self.threshold = threshold
         self.device = device
@@ -232,8 +235,11 @@ class HallucinationDetector:
         Returns:
             HallucinationResult with score and hallucination flag.
         """
+        t0 = time.perf_counter()
         premise = self._format_premise(evidence_texts, hypothesis=explanation)
         scores = self._predict([(premise, explanation)])
+        hhem_duration = time.perf_counter() - t0
+        observe_hhem_duration(hhem_duration)
         return self._make_result(scores[0], explanation, len(premise))
 
     def check_claims(
@@ -297,19 +303,10 @@ class HallucinationDetector:
         ]
 
 
-_detector: HallucinationDetector | None = None
-_detector_lock = threading.Lock()
-
-
+@thread_safe_singleton
 def get_detector() -> HallucinationDetector:
     """Get or create the global hallucination detector (thread-safe singleton)."""
-    global _detector
-    if _detector is None:
-        with _detector_lock:
-            if _detector is None:
-                _detector = HallucinationDetector()
-    return _detector
+    return HallucinationDetector()
 
 
 def check_hallucination(
```
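`observe_hhem_duration` comes from the expanded `sage/api/metrics.py` (+117 lines), not shown in this view. A minimal sketch using `prometheus_client` (pinned in requirements.txt); the metric name and bucket boundaries are assumptions.

```python
# Hypothetical sketch of observe_hhem_duration from sage/api/metrics.py;
# the metric name and buckets are assumptions -- only the function name
# and its call site (a float of seconds) are visible in this commit.
from prometheus_client import Histogram

HHEM_DURATION_SECONDS = Histogram(
    "sage_hhem_duration_seconds",
    "Wall-clock time of a single HHEM hallucination check",
    buckets=(0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0),
)


def observe_hhem_duration(seconds: float) -> None:
    """Record one HHEM inference duration in the latency histogram."""
    HHEM_DURATION_SECONDS.observe(seconds)
```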
sage/adapters/llm.py
CHANGED
|
@@ -2,9 +2,19 @@
 LLM client adapters.

 Provides unified interface for LLM providers (Anthropic Claude, OpenAI GPT).
 """

 from sage.config import (
     ANTHROPIC_API_KEY,

@@ -16,7 +26,82 @@ from sage.config import (
     LLM_TIMEOUT,
     OPENAI_API_KEY,
     OPENAI_MODEL,
 )

 # ---------------------------------------------------------------------------

@@ -60,24 +145,59 @@ class LLMClient(Protocol):

 # ---------------------------------------------------------------------------
-#
 # ---------------------------------------------------------------------------

-    """
-    [body of the old module-level error translator truncated in this view]

@@ -85,7 +205,7 @@ def _translate_api_error(exc: Exception, sdk, name: str) -> NoReturn:
 # ---------------------------------------------------------------------------

-class AnthropicClient:
     """
     Anthropic Claude client for explanation generation.

@@ -116,29 +236,27 @@ class AnthropicClient:
        Raises:
            ImportError: If anthropic package is not installed.
        """
-        try:
-            import anthropic
-        except ImportError:
-            raise ImportError(
-                "anthropic package required. Install with: pip install anthropic"
-            )

        self.client = anthropic.Anthropic(
            api_key=api_key or ANTHROPIC_API_KEY,
            timeout=timeout,
            max_retries=max_retries,
        )
-        self.
-        [old attribute assignments truncated in this view]
        )

    def generate(self, system: str, user: str) -> tuple[str, int]:
        """
        Generate explanation using Claude.

@@ -152,7 +270,7 @@ class AnthropicClient:
        Raises:
            TimeoutError: If API request times out.
-            RuntimeError: If rate limited.
            ConnectionError: If connection fails.
        """
        try:

@@ -172,7 +290,7 @@ class AnthropicClient:
            tokens = response.usage.input_tokens + response.usage.output_tokens
            return text, tokens
        except self._api_errors as exc:
-            [old error-translation call truncated in this view]

    def generate_stream(self, system: str, user: str) -> Iterator[str]:
        """

@@ -201,7 +319,7 @@ class AnthropicClient:
            for text in stream.text_stream:
                yield text
        except self._api_errors as exc:
-            [old error-translation call truncated in this view]

 # ---------------------------------------------------------------------------

@@ -209,7 +327,7 @@ class AnthropicClient:
 # ---------------------------------------------------------------------------

-class OpenAIClient:
     """
     OpenAI client for explanation generation.

@@ -240,30 +358,28 @@ class OpenAIClient:
        Raises:
            ImportError: If openai package is not installed.
        """
-        try:
-            from openai import OpenAI
-        except ImportError:
-            raise ImportError(
-                "openai package required. Install with: pip install openai"
-            )

        self.client = OpenAI(
            api_key=api_key or OPENAI_API_KEY,
            timeout=timeout,
            max_retries=max_retries,
        )
-        self.
-        [old attribute assignments truncated in this view]
        )

    def generate(self, system: str, user: str) -> tuple[str, int]:
        """
        Generate explanation using GPT.

@@ -277,7 +393,7 @@ class OpenAIClient:
        Raises:
            TimeoutError: If API request times out.
-            RuntimeError: If rate limited.
            ConnectionError: If connection fails.
        """
        try:

@@ -294,7 +410,7 @@ class OpenAIClient:
            tokens = response.usage.total_tokens if response.usage else 0
            return text, tokens
        except self._api_errors as exc:
-            [old error-translation call truncated in this view]

    def generate_stream(self, system: str, user: str) -> Iterator[str]:
        """

@@ -327,7 +443,7 @@ class OpenAIClient:
            if chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content
        except self._api_errors as exc:
-            [old error-translation call truncated in this view]

 # ---------------------------------------------------------------------------

@@ -340,7 +456,7 @@ def get_llm_client(provider: str | None = None) -> LLMClient:
    Get the configured LLM client.

    Args:
-        provider: LLM provider (
        Defaults to LLM_PROVIDER from config.

    Returns:

@@ -351,19 +467,23 @@ def get_llm_client(provider: str | None = None) -> LLMClient:
    """
    provider = provider or LLM_PROVIDER

-    if provider ==
        return AnthropicClient()
-    elif provider ==
        return OpenAIClient()
    else:
        raise ValueError(
-            f"Unknown LLM provider: {provider}.
        )

 __all__ = [
     "LLMClient",
     "AnthropicClient",
     "OpenAIClient",
     "get_llm_client",
 ]
 LLM client adapters.

 Provides unified interface for LLM providers (Anthropic Claude, OpenAI GPT).
+
+Includes exponential backoff with jitter for rate limit handling:
+- Initial delay: 1 second
+- Max delay: 60 seconds
+- Jitter: 0-25% random variation
+- Max retries: configurable (default 3 for rate limits)
 """

+import random
+import time
+from abc import ABC, abstractmethod
+from functools import wraps
+from typing import Any, Callable, Iterator, NoReturn, Protocol, TypeVar

 from sage.config import (
     ANTHROPIC_API_KEY,

     LLM_TIMEOUT,
     OPENAI_API_KEY,
     OPENAI_MODEL,
+    PROVIDER_ANTHROPIC,
+    PROVIDER_OPENAI,
+    get_logger,
 )
+from sage.utils import require_import
+
+logger = get_logger(__name__)
+
+T = TypeVar("T")
+
+# Exponential backoff settings for rate limits
+RATE_LIMIT_INITIAL_DELAY = 1.0  # seconds
+RATE_LIMIT_MAX_DELAY = 60.0  # seconds
+RATE_LIMIT_MAX_RETRIES = 3  # additional retries for rate limits
+RATE_LIMIT_JITTER = 0.25  # 25% random jitter
+
+
+def _calculate_backoff_delay(attempt: int, jitter: float = RATE_LIMIT_JITTER) -> float:
+    """Calculate exponential backoff delay with jitter.
+
+    Args:
+        attempt: Current retry attempt (0-indexed).
+        jitter: Maximum jitter factor (0.25 = up to 25% variation).
+
+    Returns:
+        Delay in seconds.
+    """
+    base_delay = RATE_LIMIT_INITIAL_DELAY * (2**attempt)
+    delay = min(base_delay, RATE_LIMIT_MAX_DELAY)
+    # Add random jitter to prevent thundering herd
+    jitter_amount = delay * jitter * random.random()
+    return delay + jitter_amount
+
+
+def with_rate_limit_retry(func: Callable[..., T]) -> Callable[..., T]:
+    """Decorator for retrying on rate limit errors with exponential backoff.
+
+    Wraps LLM generate methods to handle rate limit errors gracefully.
+    Uses exponential backoff with jitter to avoid thundering herd.
+    """
+
+    @wraps(func)
+    def wrapper(self, *args, **kwargs) -> T:
+        last_exception = None
+
+        for attempt in range(RATE_LIMIT_MAX_RETRIES + 1):
+            try:
+                return func(self, *args, **kwargs)
+            except RuntimeError as e:
+                # Check if this is a rate limit error (translated from SDK)
+                if "rate limit" not in str(e).lower():
+                    raise
+
+                last_exception = e
+
+                if attempt < RATE_LIMIT_MAX_RETRIES:
+                    delay = _calculate_backoff_delay(attempt)
+                    logger.warning(
+                        "Rate limited (attempt %d/%d), backing off %.1fs: %s",
+                        attempt + 1,
+                        RATE_LIMIT_MAX_RETRIES + 1,
+                        delay,
+                        e,
+                    )
+                    time.sleep(delay)
+                else:
+                    logger.error(
+                        "Rate limit persists after %d retries: %s",
+                        RATE_LIMIT_MAX_RETRIES + 1,
+                        e,
+                    )
+
+        # All retries exhausted
+        raise last_exception  # type: ignore[misc]
+
+    return wrapper

 # ---------------------------------------------------------------------------

 # ---------------------------------------------------------------------------
+# Base class with shared logic
 # ---------------------------------------------------------------------------

+class LLMClientBase(ABC):
+    """Base class with shared initialization and error handling."""
+
+    client: Any
+    model: str
+    temperature: float
+    max_tokens: int
+    _sdk: Any
+    _name: str
+    _api_errors: tuple[type[Exception], ...]
+
+    def _init_common(
+        self,
+        model: str,
+        temperature: float,
+        max_tokens: int,
+        sdk: Any,
+        name: str,
+        api_errors: tuple[type[Exception], ...],
+    ) -> None:
+        """Initialize common attributes."""
+        self.model = model
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self._sdk = sdk
+        self._name = name
+        self._api_errors = api_errors
+
+    def _translate_error(self, exc: Exception) -> NoReturn:
+        """Translate SDK-specific API errors to built-in exceptions."""
+        if isinstance(exc, self._sdk.APITimeoutError):
+            raise TimeoutError(f"{self._name} API request timed out: {exc}") from exc
+        if isinstance(exc, self._sdk.RateLimitError):
+            raise RuntimeError(f"{self._name} API rate limited: {exc}") from exc
+        if isinstance(exc, self._sdk.APIConnectionError):
+            raise ConnectionError(
+                f"Failed to connect to {self._name} API: {exc}"
+            ) from exc
+        raise exc
+
+    @abstractmethod
+    def generate(self, system: str, user: str) -> tuple[str, int]:
+        """Generate a response from the LLM."""
+        ...
+
+    @abstractmethod
+    def generate_stream(self, system: str, user: str) -> Iterator[str]:
+        """Stream response tokens from the LLM."""
+        ...

 # ---------------------------------------------------------------------------

 # ---------------------------------------------------------------------------

+class AnthropicClient(LLMClientBase):
     """
     Anthropic Claude client for explanation generation.

        Raises:
            ImportError: If anthropic package is not installed.
        """
+        anthropic = require_import("anthropic")

        self.client = anthropic.Anthropic(
            api_key=api_key or ANTHROPIC_API_KEY,
            timeout=timeout,
            max_retries=max_retries,
        )
+        self._init_common(
+            model=model,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            sdk=anthropic,
+            name="Anthropic",
+            api_errors=(
+                anthropic.APITimeoutError,
+                anthropic.RateLimitError,
+                anthropic.APIConnectionError,
+            ),
        )

+    @with_rate_limit_retry
    def generate(self, system: str, user: str) -> tuple[str, int]:
        """
        Generate explanation using Claude.

        Raises:
            TimeoutError: If API request times out.
+            RuntimeError: If rate limited (after retries exhausted).
            ConnectionError: If connection fails.
        """
        try:

            tokens = response.usage.input_tokens + response.usage.output_tokens
            return text, tokens
        except self._api_errors as exc:
+            self._translate_error(exc)

    def generate_stream(self, system: str, user: str) -> Iterator[str]:
        """

            for text in stream.text_stream:
                yield text
        except self._api_errors as exc:
+            self._translate_error(exc)

 # ---------------------------------------------------------------------------

 # ---------------------------------------------------------------------------

+class OpenAIClient(LLMClientBase):
     """
     OpenAI client for explanation generation.

        Raises:
            ImportError: If openai package is not installed.
        """
+        openai = require_import("openai")
+        OpenAI = openai.OpenAI

        self.client = OpenAI(
            api_key=api_key or OPENAI_API_KEY,
            timeout=timeout,
            max_retries=max_retries,
        )
+        self._init_common(
+            model=model,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            sdk=openai,
+            name="OpenAI",
+            api_errors=(
+                openai.APITimeoutError,
+                openai.RateLimitError,
+                openai.APIConnectionError,
+            ),
        )

+    @with_rate_limit_retry
    def generate(self, system: str, user: str) -> tuple[str, int]:
        """
        Generate explanation using GPT.

        Raises:
            TimeoutError: If API request times out.
+            RuntimeError: If rate limited (after retries exhausted).
            ConnectionError: If connection fails.
        """
        try:

            tokens = response.usage.total_tokens if response.usage else 0
            return text, tokens
        except self._api_errors as exc:
+            self._translate_error(exc)

    def generate_stream(self, system: str, user: str) -> Iterator[str]:
        """

            if chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content
        except self._api_errors as exc:
+            self._translate_error(exc)

 # ---------------------------------------------------------------------------

    Get the configured LLM client.

    Args:
+        provider: LLM provider (PROVIDER_ANTHROPIC or PROVIDER_OPENAI).
            Defaults to LLM_PROVIDER from config.

    Returns:

    """
    provider = provider or LLM_PROVIDER

+    if provider == PROVIDER_ANTHROPIC:
        return AnthropicClient()
+    elif provider == PROVIDER_OPENAI:
        return OpenAIClient()
    else:
        raise ValueError(
+            f"Unknown LLM provider: {provider}. "
+            f"Use '{PROVIDER_ANTHROPIC}' or '{PROVIDER_OPENAI}'."
        )

 __all__ = [
     "LLMClient",
+    "LLMClientBase",
     "AnthropicClient",
     "OpenAIClient",
     "get_llm_client",
+    "with_rate_limit_retry",
+    "RATE_LIMIT_MAX_RETRIES",
 ]
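With the defaults above, the decorator's worst-case sleep schedule is roughly 1s, 2s, 4s (each inflated by up to 25% jitter) before the final RuntimeError is re-raised. A toy demonstration of that behavior; the client class and failure pattern are fabricated purely for illustration:

# Illustrative only: a fake client that is rate limited twice, then succeeds.
class FlakyClient:
    """Fails twice with a rate-limit RuntimeError, then succeeds."""

    def __init__(self) -> None:
        self.calls = 0

    @with_rate_limit_retry
    def generate(self, system: str, user: str) -> tuple[str, int]:
        self.calls += 1
        if self.calls < 3:
            # Message matches what _translate_error raises, so the decorator retries
            raise RuntimeError("Anthropic API rate limited: 429")
        return "ok", 42

print(FlakyClient().generate("sys", "user"))  # sleeps ~1s, then ~2s, prints ('ok', 42)

And a typical use of the factory; the prompt strings here are made up:

# Sketch: pick the provider from config and generate one explanation.
client = get_llm_client()  # AnthropicClient or OpenAIClient per LLM_PROVIDER
text, tokens = client.generate(
    system="You are a product recommendation assistant.",
    user="Why does this keyboard suit a quiet office?",
)
print(f"{tokens} tokens used")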
sage/adapters/vector_store.py
CHANGED
|
@@ -20,6 +20,7 @@ from sage.config import (
     QDRANT_URL,
     get_logger,
 )
+from sage.utils import require_import

 logger = get_logger(__name__)

@@ -44,12 +45,8 @@ def get_client():
    Raises:
        ImportError: If qdrant-client is not installed.
    """
-    try:
-        from qdrant_client import QdrantClient
-    except ImportError:
-        raise ImportError(
-            "qdrant-client package required. Install with: pip install qdrant-client"
-        )
+    qdrant = require_import("qdrant_client", pip_name="qdrant-client")
+    QdrantClient = qdrant.QdrantClient

    if QDRANT_API_KEY:
        return QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
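A quick smoke test of the adapter against the docker-compose Qdrant service; the call itself is real qdrant-client API, but running it assumes the local stack is up:

# Sketch: verify the adapter can reach Qdrant (assumes docker-compose is running).
from sage.adapters.vector_store import get_client

client = get_client()  # raises ImportError with a pip hint if qdrant-client is absent
print(client.get_collections())  # lists collections once ingestion has run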
sage/api/app.py
CHANGED
|
@@ -4,6 +4,10 @@ FastAPI application factory.
 Creates the app with lifespan-managed singletons (embedder, Qdrant client,
 HHEM detector, LLM explainer, semantic cache) so heavy models are loaded
 once at startup and shared across requests.
+
+Graceful shutdown:
+- On SIGTERM, waits for active requests to complete (up to 30s)
+- New requests during shutdown return 503 with Retry-After header
 """

 from __future__ import annotations

@@ -14,20 +18,37 @@ from contextlib import asynccontextmanager
 from fastapi import FastAPI
 from starlette.middleware.cors import CORSMiddleware

-from sage.api.middleware import LatencyMiddleware
+from sage.api.middleware import (
+    LatencyMiddleware,
+    get_shutdown_coordinator,
+    reset_shutdown_coordinator,
+)
 from sage.api.routes import router
 from sage.config import get_logger

 CORS_ORIGINS = [o.strip() for o in os.getenv("CORS_ORIGINS", "*").split(",")]

+# Graceful shutdown timeout (seconds to wait for active requests)
+SHUTDOWN_TIMEOUT = float(os.getenv("SHUTDOWN_TIMEOUT", "30.0"))
+
 logger = get_logger(__name__)


 @asynccontextmanager
 async def _lifespan(app: FastAPI):
-    """Initialize shared resources at startup, release at shutdown."""
+    """Initialize shared resources at startup, release at shutdown.
+
+    Shutdown sequence:
+    1. Signal shutdown coordinator (new requests get 503)
+    2. Wait for active requests to complete (up to SHUTDOWN_TIMEOUT)
+    3. Release resources
+    """
     logger.info("Starting Sage API...")

+    # Reset shutdown coordinator for this app instance
+    reset_shutdown_coordinator()
+    coordinator = get_shutdown_coordinator()
+
     # Validate LLM credentials early
     from sage.config import ANTHROPIC_API_KEY, LLM_PROVIDER, OPENAI_API_KEY

@@ -92,7 +113,16 @@ async def _lifespan(app: FastAPI):

     logger.info("Sage API ready")
     yield
-
+
+    # Graceful shutdown: wait for active requests to complete
+    logger.info("Sage API shutting down...")
+    completed = await coordinator.wait_for_shutdown(timeout=SHUTDOWN_TIMEOUT)
+    if not completed:
+        logger.warning(
+            "Forced shutdown with %d requests still active",
+            coordinator.active_requests,
+        )
+    logger.info("Sage API shutdown complete")


 def create_app() -> FastAPI:
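To exercise the shutdown path, run the app under uvicorn and send SIGTERM; a minimal launcher sketch (host and port are illustrative):

# Sketch: uvicorn delivers SIGTERM to the ASGI lifespan, which triggers the
# coordinator wait above before resources are released.
import uvicorn

from sage.api.app import create_app

app = create_app()

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)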
sage/api/metrics.py
CHANGED
|
@@ -3,6 +3,25 @@ Prometheus metrics with graceful degradation.

 If ``prometheus-client`` is not installed, all metric operations become no-ops
 so the application can run without the optional dependency.
+
+Metrics exposed at GET /metrics:
+- sage_request_latency_seconds: End-to-end request latency (p50/p95/p99)
+- sage_requests_total: Total requests by endpoint/method/status
+- sage_cache_events_total: Cache hits (L1/L2) and misses
+- sage_llm_duration_seconds: Time spent waiting on LLM API
+- sage_retrieval_duration_seconds: Time spent on Qdrant vector search
+- sage_embedding_duration_seconds: Time spent computing query embeddings
+- sage_errors_total: Errors by type (timeout, llm_error, retrieval_error, etc.)
+
+Latency budget breakdown (target p99 < 500ms):
+1. Embedding query:    ~20ms
+2. Cache check:        ~1ms (L1) or ~50ms (L2 semantic)
+3. Vector retrieval:   ~50-100ms
+4. LLM generation:     ~200-400ms
+5. HHEM verification:  ~50-100ms
+----------------------------------------
+Total (no cache):  ~400-600ms
+Total (cache hit): <100ms
 """

 from __future__ import annotations

@@ -23,25 +42,84 @@ try:
        CONTENT_TYPE_LATEST,
    )

+    # Standard bucket sizes for latency histograms (in seconds)
+    # Covers 5ms to 30s range for p50/p95/p99 calculation
+    LATENCY_BUCKETS = (
+        0.005, 0.01, 0.025, 0.05, 0.1, 0.25,
+        0.5, 1.0, 2.5, 5.0, 10.0, 30.0,
+    )
+
+    # -------------------------------------------------------------------
+    # Request-level metrics
+    # -------------------------------------------------------------------

    REQUEST_COUNT = Counter(
        "sage_requests_total",
        "Total HTTP requests",
        ["endpoint", "method", "status"],
    )

+    REQUEST_LATENCY = Histogram(
+        "sage_request_latency_seconds",
+        "End-to-end request latency in seconds",
        ["endpoint"],
+        buckets=LATENCY_BUCKETS,
    )

+    ERRORS = Counter(
+        "sage_errors_total",
+        "Total errors by type",
+        ["error_type"],  # timeout, llm_error, retrieval_error, validation_error
+    )
+
+    # -------------------------------------------------------------------
+    # Cache metrics
+    # -------------------------------------------------------------------

    CACHE_EVENTS = Counter(
        "sage_cache_events_total",
        "Cache lookup results",
        ["result"],  # hit_exact, hit_semantic, miss
    )

+    # -------------------------------------------------------------------
+    # Component-level latency metrics (for latency budget breakdown)
+    # -------------------------------------------------------------------
+
+    EMBEDDING_DURATION = Histogram(
+        "sage_embedding_duration_seconds",
+        "Time to compute query embedding",
+        buckets=LATENCY_BUCKETS,
+    )
+
+    RETRIEVAL_DURATION = Histogram(
+        "sage_retrieval_duration_seconds",
+        "Time for Qdrant vector search",
+        buckets=LATENCY_BUCKETS,
+    )
+
+    LLM_DURATION = Histogram(
+        "sage_llm_duration_seconds",
+        "Time waiting on LLM API for explanation generation",
+        buckets=LATENCY_BUCKETS,
+    )
+
+    HHEM_DURATION = Histogram(
+        "sage_hhem_duration_seconds",
+        "Time for HHEM hallucination check",
+        buckets=LATENCY_BUCKETS,
+    )

    _PROMETHEUS_AVAILABLE = True

 except ImportError:

@@ -61,9 +139,18 @@ def record_request(endpoint: str, method: str, status: int) -> None:


 def observe_duration(endpoint: str, duration_ms: float) -> None:
-    """Record request
+    """Record end-to-end request latency (converts ms to seconds for Prometheus)."""
    if _PROMETHEUS_AVAILABLE:
+        REQUEST_LATENCY.labels(endpoint=endpoint).observe(duration_ms / 1000.0)
+
+
+def record_error(error_type: str) -> None:
+    """Record an error by type.
+
+    Common error types: timeout, llm_error, retrieval_error, validation_error
+    """
+    if _PROMETHEUS_AVAILABLE:
+        ERRORS.labels(error_type=error_type).inc()


 def record_cache_event(result: str) -> None:

@@ -75,6 +162,30 @@ def record_cache_event(result: str) -> None:
        CACHE_EVENTS.labels(result=result).inc()


+def observe_embedding_duration(duration_seconds: float) -> None:
+    """Record query embedding computation time."""
+    if _PROMETHEUS_AVAILABLE:
+        EMBEDDING_DURATION.observe(duration_seconds)
+
+
+def observe_retrieval_duration(duration_seconds: float) -> None:
+    """Record Qdrant vector search time."""
+    if _PROMETHEUS_AVAILABLE:
+        RETRIEVAL_DURATION.observe(duration_seconds)
+
+
+def observe_llm_duration(duration_seconds: float) -> None:
+    """Record LLM API call time."""
+    if _PROMETHEUS_AVAILABLE:
+        LLM_DURATION.observe(duration_seconds)
+
+
+def observe_hhem_duration(duration_seconds: float) -> None:
+    """Record HHEM hallucination check time."""
+    if _PROMETHEUS_AVAILABLE:
+        HHEM_DURATION.observe(duration_seconds)


 def prometheus_available() -> bool:
     """Return True if prometheus-client is importable."""
     return _PROMETHEUS_AVAILABLE
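The component histograms are meant to be fed exactly the way the HHEM timer in hhem.py above does it. An illustrative wrapper for the retrieval leg; the collection name and call site are assumptions, not code from this PR:

# Sketch: wrap a Qdrant search with the retrieval histogram.
import time

from sage.api.metrics import observe_retrieval_duration


def timed_search(client, query_vector, k: int = 10):
    """Time a vector search and record it, even if the call raises."""
    t0 = time.perf_counter()
    try:
        return client.search(
            collection_name="reviews",  # hypothetical collection name
            query_vector=query_vector,
            limit=k,
        )
    finally:
        observe_retrieval_duration(time.perf_counter() - t0)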
sage/api/middleware.py
CHANGED
|
@@ -1,18 +1,26 @@
 """
-Request latency middleware.
+Request latency middleware and graceful shutdown coordinator.

 Logs method/path/status/elapsed_ms for every request and records
 Prometheus histogram observations. Adds ``X-Response-Time-Ms`` header.

 Uses a pure ASGI middleware (not BaseHTTPMiddleware) to avoid buffering
 SSE streams.
+
+Graceful shutdown:
+- Tracks active request count
+- On SIGTERM, waits for active requests to complete (up to timeout)
+- Prevents new requests during shutdown (returns 503)
 """

 from __future__ import annotations

+import asyncio
 import time
 import uuid
+from contextlib import asynccontextmanager
+from dataclasses import dataclass, field

+from starlette.responses import JSONResponse
 from starlette.types import ASGIApp, Message, Receive, Scope, Send

 from sage.api.metrics import observe_duration, record_request

@@ -20,13 +28,123 @@ from sage.config import get_logger

 logger = get_logger(__name__)


+# ---------------------------------------------------------------------------
+# Graceful Shutdown Coordinator
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class ShutdownCoordinator:
+    """Coordinates graceful shutdown by tracking active requests.
+
+    Usage:
+        coordinator = ShutdownCoordinator()
+
+        # In middleware: track requests
+        async with coordinator.track_request():
+            await handle_request()
+
+        # In lifespan shutdown: wait for completion
+        await coordinator.wait_for_shutdown(timeout=30.0)
+    """
+
+    _active_requests: int = field(default=0, init=False)
+    _shutting_down: bool = field(default=False, init=False)
+    _shutdown_event: asyncio.Event = field(default_factory=asyncio.Event, init=False)
+    _lock: asyncio.Lock = field(default_factory=asyncio.Lock, init=False)
+
+    @property
+    def active_requests(self) -> int:
+        """Number of currently active requests."""
+        return self._active_requests
+
+    @property
+    def is_shutting_down(self) -> bool:
+        """True if shutdown has been initiated."""
+        return self._shutting_down
+
+    @asynccontextmanager  # required: a bare async generator cannot back `async with`
+    async def track_request(self):
+        """Context manager to track an active request."""
+        async with self._lock:
+            self._active_requests += 1
+
+        try:
+            yield
+        finally:
+            async with self._lock:
+                self._active_requests -= 1
+                if self._active_requests == 0 and self._shutting_down:
+                    self._shutdown_event.set()
+
+    async def initiate_shutdown(self) -> None:
+        """Signal that shutdown has begun."""
+        async with self._lock:
+            self._shutting_down = True
+            if self._active_requests == 0:
+                self._shutdown_event.set()
+        logger.info("Shutdown initiated, %d active requests", self._active_requests)
+
+    async def wait_for_shutdown(self, timeout: float = 30.0) -> bool:
+        """Wait for active requests to complete.
+
+        Args:
+            timeout: Maximum seconds to wait for requests to complete.
+
+        Returns:
+            True if all requests completed, False if timed out.
+        """
+        await self.initiate_shutdown()
+
+        if self._active_requests == 0:
+            logger.info("No active requests, shutdown immediate")
+            return True
+
+        logger.info(
+            "Waiting up to %.1fs for %d active requests",
+            timeout,
+            self._active_requests,
+        )
+
+        try:
+            await asyncio.wait_for(self._shutdown_event.wait(), timeout=timeout)
+            logger.info("All requests completed, proceeding with shutdown")
+            return True
+        except asyncio.TimeoutError:
+            logger.warning(
+                "Shutdown timeout: %d requests still active after %.1fs",
+                self._active_requests,
+                timeout,
+            )
+            return False
+
+
+# Global coordinator instance (set during app lifespan)
+_shutdown_coordinator: ShutdownCoordinator | None = None
+
+
+def get_shutdown_coordinator() -> ShutdownCoordinator:
+    """Get the global shutdown coordinator."""
+    global _shutdown_coordinator
+    if _shutdown_coordinator is None:
+        _shutdown_coordinator = ShutdownCoordinator()
+    return _shutdown_coordinator
+
+
+def reset_shutdown_coordinator() -> None:
+    """Reset the global shutdown coordinator (for testing)."""
+    global _shutdown_coordinator
+    _shutdown_coordinator = None
+
+
 # Paths excluded from per-request logging (still measured by Prometheus)
-_QUIET_PATHS = {"/metrics", "/health"}
+_QUIET_PATHS = {"/metrics", "/health", "/ready"}

 # Known route patterns -- map raw paths to normalized labels to prevent
 # unbounded Prometheus cardinality from bot scanners hitting random paths.
 _KNOWN_ROUTES = {
     "/health": "/health",
+    "/ready": "/ready",
     "/recommend": "/recommend",
     "/recommend/stream": "/recommend/stream",
     "/cache/stats": "/cache/stats",

@@ -42,9 +160,10 @@ def _normalize_path(path: str) -> str:


 class LatencyMiddleware:
-    """Pure ASGI middleware for latency measurement.
+    """Pure ASGI middleware for latency measurement and graceful shutdown.

     Does NOT buffer response bodies, so SSE streaming works correctly.
+    During shutdown, rejects new requests with 503 Service Unavailable.
     """

     def __init__(self, app: ASGIApp) -> None:

@@ -55,9 +174,21 @@ class LatencyMiddleware:
            await self.app(scope, receive, send)
            return

+        coordinator = get_shutdown_coordinator()
        path = _normalize_path(scope["path"])
        method = scope["method"]
+
+        # During shutdown, reject new requests (except health checks)
+        if coordinator.is_shutting_down and path not in {"/health", "/ready"}:
+            response = JSONResponse(
+                status_code=503,
+                content={"error": "Server is shutting down", "retry_after": 5},
+                headers={"Retry-After": "5"},
+            )
+            await response(scope, receive, send)
+            return
+
+        start = time.perf_counter()
        request_id = uuid.uuid4().hex[:12]
        status = 500  # default until we see http.response.start

@@ -74,21 +205,23 @@ class LatencyMiddleware:
            message = {**message, "headers": headers}
            await send(message)

+        # Track request for graceful shutdown
+        async with coordinator.track_request():
+            try:
+                await self.app(scope, receive, send_wrapper)
+            except Exception:
+                logger.exception("%s %s [%s] failed", method, path, request_id)
+                raise
+            finally:
+                elapsed_ms = (time.perf_counter() - start) * 1000
+                record_request(path, method, status)
+                observe_duration(path, elapsed_ms)
+                if path not in _QUIET_PATHS:
+                    logger.info(
+                        "%s %s %d %.1fms [%s]",
+                        method,
+                        path,
+                        status,
+                        elapsed_ms,
+                        request_id,
+                    )
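A self-contained demonstration of the coordinator's semantics (timings are arbitrary):

# Sketch: one in-flight request completes within the shutdown budget.
import asyncio

from sage.api.middleware import ShutdownCoordinator


async def demo() -> None:
    coord = ShutdownCoordinator()

    async def fake_request() -> None:
        async with coord.track_request():
            await asyncio.sleep(0.5)  # simulated in-flight work

    task = asyncio.create_task(fake_request())
    await asyncio.sleep(0.1)  # let the request start
    ok = await coord.wait_for_shutdown(timeout=2.0)
    print(ok, coord.active_requests)  # True 0 -- request finished within budget
    await task


asyncio.run(demo())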
sage/api/routes.py
CHANGED
|
@@ -2,30 +2,29 @@
 API route definitions.

 Endpoints:
-    GET /health
-
-
-    GET /cache/stats
-    POST /cache/clear
-    GET /metrics
 """

 from __future__ import annotations

 import json
 from concurrent.futures import ThreadPoolExecutor
-from
-from typing import TYPE_CHECKING, Iterator

-if TYPE_CHECKING:
-    import numpy as np
 from fastapi.responses import JSONResponse, StreamingResponse
-from pydantic import BaseModel

 from sage.adapters.vector_store import collection_exists
-from sage.api.metrics import metrics_response, record_cache_event
 from sage.config import MAX_EVIDENCE, get_logger
 from sage.core import (
     AggregationMethod,

@@ -40,31 +39,76 @@ from sage.services.retrieval import get_candidates
 # good parallelism while bounding total concurrent LLM calls.
 _MAX_EXPLAIN_WORKERS = 4

 logger = get_logger(__name__)

 router = APIRouter()

 # ---------------------------------------------------------------------------
 # Response models
 # ---------------------------------------------------------------------------

 class EvidenceSource(BaseModel):
     id: str
     text: str

 class ConfidenceScore(BaseModel):
     hhem_score: float
     is_grounded: bool
     threshold: float

 class RecommendationItem(BaseModel):
     rank: int
-    product_id: str
     avg_rating: float
     explanation: str | None = None
     confidence: ConfidenceScore | None = None

@@ -72,22 +116,40 @@ class RecommendationItem(BaseModel):
     evidence_sources: list[EvidenceSource] | None = None

-class
     query: str
     recommendations: list[RecommendationItem]

 class HealthResponse(BaseModel):
     status: str
     qdrant_connected: bool

 class ErrorResponse(BaseModel):
     error: str
     query: str

 class CacheStatsResponse(BaseModel):
     size: int
     max_entries: int
     exact_hits: int

@@ -105,25 +167,20 @@
-@dataclass
-class RecommendParams:
-    """Query parameters shared by /recommend and /recommend/stream."""
-
-    q: str = Query(..., min_length=1, max_length=500, description="Search query")
-    k: int = Query(3, ge=1, le=10, description="Number of products")
-    min_rating: float = Query(4.0, ge=1.0, le=5.0, description="Minimum rating")

 def _fetch_products(
-
-    app
-    query_embedding:
 ) -> list[ProductScore]:
-    """Run candidate generation with lifespan-managed singletons.
     return get_candidates(
-        query=
-        k=
-        min_rating=
         aggregation=AggregationMethod.MAX,
         client=app.state.qdrant,
         embedder=app.state.embedder,

@@ -132,11 +189,14 @@
 def _build_product_dict(rank: int, product: ProductScore) -> dict:
-    """Build the base product metadata dict (shared by all response paths).
     return {
         "rank": rank,
         "product_id": product.product_id,
-        "
         "avg_rating": round(product.avg_rating, 1),
     }

@@ -151,21 +211,137 @@ def _build_evidence_list(result: ExplanationResult) -> list[dict]:
 # ---------------------------------------------------------------------------

 @router.get("/health", response_model=HealthResponse)
-def health(request: Request):
-    """Deployment readiness probe.
     """
     try:
-
-        ok = collection_exists(client)
     except Exception:
         logger.exception("Health check: Qdrant unreachable")

 # ---------------------------------------------------------------------------

@@ -173,106 +349,188 @@ def health(request: Request):
 # ---------------------------------------------------------------------------

-):
-    """Return product recommendations with optional grounded explanations."""
-    app = request.app
     cache = app.state.cache
-    q =

-    # The embedding computed here is reused for candidate retrieval below,
-    # avoiding the cost of a second embed_single_query call.
-    if explain:
-        query_embedding = app.state.embedder.embed_single_query(q)
-    cached, hit_type = cache.get(q, query_embedding)
-    record_cache_event(f"hit_{hit_type}" if hit_type != "miss" else "miss")
-    if cached is not None:
-        return cached
-    else:
-        query_embedding = None

-    [old /recommend response-assembly body truncated in this view]
-            "is_grounded": not hr.is_hallucinated,
-            "threshold": hr.threshold,
-        }
-        rec["citations_verified"] = cr.all_valid
-        rec["evidence_sources"] = _build_evidence_list(er)
-        recommendations.append(rec)
-    else:
-        for i, product in enumerate(products, 1):
-            recommendations.append(_build_product_dict(i, product))

     return result

-    except
-        logger.
     )

@@ -284,11 +542,22 @@
     return f"event: {event}\ndata: {data}\n\n"

-def
     app,
-) ->
-    """
     yield _sse_event(
         "metadata",
         json.dumps(

@@ -301,7 +570,7 @@ def _stream_recommendations(
     )
     try:
-        products =
     except Exception:
         logger.exception("Streaming: candidate generation failed")
         yield _sse_event("error", json.dumps({"detail": "Failed to retrieve products"}))

@@ -309,7 +578,9 @@
         return

     if not products:
-        yield _sse_event(
         return

     explainer = app.state.explainer

@@ -324,20 +595,52 @@
         yield _sse_event("product", json.dumps(_build_product_dict(i, product)))

         try:
-
             )
-
             yield _sse_event("token", json.dumps({"text": token}))

-            result = stream.get_complete_result()
             yield _sse_event(
                 "evidence",
                 json.dumps({"evidence_sources": _build_evidence_list(result)}),
             )

         except ValueError as exc:
             # Quality gate refusal — evidence insufficient for this product.
             # Surface the reason so clients can display it meaningfully.

@@ -352,19 +655,21 @@
     yield _sse_event("done", json.dumps({"status": "complete"}))

-@router.
-def recommend_stream(
-    request: Request,
-    params: RecommendParams = Depends(),
-):
     """Stream product recommendations with explanations via SSE.

     The streaming path does not check or populate the semantic cache and
     does not compute HHEM confidence scores. For cached or grounded
-    responses, use the non-streaming ``/recommend`` endpoint.
     """
     return StreamingResponse(
-        _stream_recommendations(
         media_type="text/event-stream",
         headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
     )

@@ -376,7 +681,7 @@ def recommend_stream(
 @router.get("/cache/stats", response_model=CacheStatsResponse)
-def cache_stats(request: Request):
     """Return cache performance statistics."""
     stats = request.app.state.cache.stats()
     return {

@@ -394,7 +699,7 @@ def cache_stats(request: Request):
 @router.post("/cache/clear")
-def cache_clear(request: Request):
     """Clear all cached entries."""
     request.app.state.cache.clear()
     return {"status": "cleared"}

@@ -406,7 +711,7 @@ def cache_clear(request: Request):
 @router.get("/metrics")
-def metrics():
     """Prometheus metrics endpoint."""
     body, content_type = metrics_response()
     return Response(content=body, media_type=content_type)
 API route definitions.

 Endpoints:
+    GET  /health            Deployment health check
+    POST /recommend         Product recommendations (optional explanations)
+    POST /recommend/stream  SSE streaming explanations
+    GET  /cache/stats       Cache statistics
+    POST /cache/clear       Clear the semantic cache
+    GET  /metrics           Prometheus metrics
 """

 from __future__ import annotations

+import asyncio
 import json
+import os
 from concurrent.futures import ThreadPoolExecutor
+from typing import AsyncIterator

+import numpy as np
+from fastapi import APIRouter, Request, Response
 from fastapi.responses import JSONResponse, StreamingResponse
+from pydantic import BaseModel, Field

 from sage.adapters.vector_store import collection_exists
+from sage.api.metrics import metrics_response, record_cache_event, record_error
 from sage.config import MAX_EVIDENCE, get_logger
 from sage.core import (
     AggregationMethod,

 # good parallelism while bounding total concurrent LLM calls.
 _MAX_EXPLAIN_WORKERS = 4

+# Request timeout in seconds. David's rule: 10s max end-to-end.
+# If the LLM hangs, cut it off and return what we have.
+REQUEST_TIMEOUT_SECONDS = float(os.getenv("REQUEST_TIMEOUT_SECONDS", "10.0"))
+
+# Per-product timeout for streaming (allows partial results on timeout)
+STREAM_PRODUCT_TIMEOUT = float(os.getenv("STREAM_PRODUCT_TIMEOUT", "15.0"))
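A sketch of how such a budget is typically enforced in an async handler; the handler body itself is beyond this excerpt, so the function and its caller contract are illustrative:

# Sketch: bound the blocking retrieval call by the end-to-end budget.
async def fetch_with_budget(body, request):
    """Illustrative helper, not code from this PR."""
    try:
        return await asyncio.wait_for(
            asyncio.to_thread(_fetch_products, body, request.app),
            timeout=REQUEST_TIMEOUT_SECONDS,
        )
    except asyncio.TimeoutError:
        record_error("timeout")
        return None  # caller maps this to a 504 ErrorResponse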
| 49 |
logger = get_logger(__name__)
|
| 50 |
|
| 51 |
router = APIRouter()
|
| 52 |
|
| 53 |
|
| 54 |
+
# ---------------------------------------------------------------------------
|
| 55 |
+
# Request models
|
| 56 |
+
# ---------------------------------------------------------------------------
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class RequestFilters(BaseModel):
|
| 60 |
+
"""Optional filters for recommendation requests."""
|
| 61 |
+
|
| 62 |
+
category: str | None = Field(None, description="Product category filter")
|
| 63 |
+
min_price: float | None = Field(None, ge=0, description="Minimum price")
|
| 64 |
+
max_price: float | None = Field(None, ge=0, description="Maximum price (budget)")
|
| 65 |
+
min_rating: float = Field(4.0, ge=1.0, le=5.0, description="Minimum rating filter")
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class RecommendationRequest(BaseModel):
|
| 69 |
+
"""Request body for /recommend and /recommend/stream endpoints."""
|
| 70 |
+
|
| 71 |
+
query: str = Field(
|
| 72 |
+
..., min_length=1, max_length=500, description="Natural language search query"
|
| 73 |
+
)
|
| 74 |
+
user_id: str | None = Field(
|
| 75 |
+
None, description="Optional user ID for personalization"
|
| 76 |
+
)
|
| 77 |
+
k: int = Field(3, ge=1, le=10, description="Number of products to return")
|
| 78 |
+
filters: RequestFilters | None = Field(None, description="Optional filters")
|
| 79 |
+
explain: bool = Field(True, description="Generate LLM explanations")
|
| 80 |
+
|
| 81 |
+
|
| 82 |   # ---------------------------------------------------------------------------
| 83 |   # Response models
| 84 |   # ---------------------------------------------------------------------------
| 85 |
| 86 |
| 87 |   class EvidenceSource(BaseModel):
| 88 | +     """A single piece of evidence (review excerpt) supporting the recommendation."""
| 89 | +
| 90 |       id: str
| 91 |       text: str
| 92 |
| 93 |
| 94 |   class ConfidenceScore(BaseModel):
| 95 | +     """Confidence metrics for explanation grounding."""
| 96 | +
| 97 |       hhem_score: float
| 98 |       is_grounded: bool
| 99 |       threshold: float
| 100 |
| 101 |
| 102 |   class RecommendationItem(BaseModel):
| 103 | +     """A single product recommendation with optional explanation.
| 104 | +
| 105 | +     Matches the 'killer demo' format: product, score, explanation,
| 106 | +     confidence, evidence_sources.
| 107 | +     """
| 108 | +
| 109 |       rank: int
| 110 | +     product_id: str  # Note: product name requires catalog lookup (future enhancement)
| 111 | +     score: float = Field(..., description="Relevance score (0-1)")
| 112 |       avg_rating: float
| 113 |       explanation: str | None = None
| 114 |       confidence: ConfidenceScore | None = None
...
| 116 |       evidence_sources: list[EvidenceSource] | None = None
| 117 |
| 118 |
| 119 | + class RecommendationResponse(BaseModel):
| 120 | +     """Response body for /recommend endpoint."""
| 121 | +
| 122 |       query: str
| 123 |       recommendations: list[RecommendationItem]
| 124 |
| 125 |
| 126 |   class HealthResponse(BaseModel):
| 127 | +     """Health check response with component status."""
| 128 | +
| 129 |       status: str
| 130 |       qdrant_connected: bool
| 131 | +     llm_reachable: bool
| 132 | +
| 133 | +
| 134 | + class ReadinessResponse(BaseModel):
| 135 | +     """Readiness probe response with detailed component status."""
| 136 | +
| 137 | +     ready: bool
| 138 | +     status: str
| 139 | +     components: dict[str, bool]
| 140 | +     message: str | None = None
| 141 |
| 142 |
| 143 |   class ErrorResponse(BaseModel):
| 144 | +     """Structured error response (not stack traces)."""
| 145 | +
| 146 |       error: str
| 147 |       query: str
| 148 |
| 149 |
| 150 |   class CacheStatsResponse(BaseModel):
| 151 | +     """Semantic cache performance statistics."""
| 152 | +
| 153 |       size: int
| 154 |       max_entries: int
| 155 |       exact_hits: int
...
| 167 |   # ---------------------------------------------------------------------------
| 168 |
| 169 |
| 170 |   def _fetch_products(
| 171 | +     request: RecommendationRequest,
| 172 | +     app,
| 173 | +     query_embedding: np.ndarray | None = None,
| 174 |   ) -> list[ProductScore]:
| 175 | +     """Run candidate generation with lifespan-managed singletons.
| 176 | +
| 177 | +     This is a blocking call - run via asyncio.to_thread() in async handlers.
| 178 | +     """
| 179 | +     min_rating = request.filters.min_rating if request.filters else 4.0
| 180 |       return get_candidates(
| 181 | +         query=request.query,
| 182 | +         k=request.k,
| 183 | +         min_rating=min_rating,
| 184 |           aggregation=AggregationMethod.MAX,
| 185 |           client=app.state.qdrant,
| 186 |           embedder=app.state.embedder,
...
| 189 |
| 190 |
| 191 |   def _build_product_dict(rank: int, product: ProductScore) -> dict:
| 192 | +     """Build the base product metadata dict (shared by all response paths).
| 193 | +
| 194 | +     Uses 'score' instead of 'relevance_score' to match the killer demo format.
| 195 | +     """
| 196 |       return {
| 197 |           "rank": rank,
| 198 |           "product_id": product.product_id,
| 199 | +         "score": round(product.score, 3),
| 200 |           "avg_rating": round(product.avg_rating, 1),
| 201 |       }
| 202 |
| 211 |   # ---------------------------------------------------------------------------
| 212 |
| 213 |
| 214 | + def _check_llm_reachable(app) -> bool:
| 215 | +     """Lightweight LLM reachability check.
| 216 | +
| 217 | +     Returns True if explainer is configured and client is initialized.
| 218 | +     Does NOT make an API call (would incur cost on every probe).
| 219 | +     LLM API failures surface as 503 on /recommend.
| 220 | +     """
| 221 | +     if app.state.explainer is None:
| 222 | +         return False
| 223 | +     # Check that the underlying client is initialized
| 224 | +     return (
| 225 | +         hasattr(app.state.explainer, "client")
| 226 | +         and app.state.explainer.client is not None
| 227 | +     )
| 228 | +
| 229 | +
| 230 |   @router.get("/health", response_model=HealthResponse)
| 231 | + async def health(request: Request):
| 232 | +     """Deployment liveness probe.
| 233 |
| 234 | +     Checks:
| 235 | +     - Qdrant connectivity (required for recommendations)
| 236 | +     - LLM explainer availability (required for explanations)
| 237 | +
| 238 | +     Note: LLM check verifies configuration, not API reachability.
| 239 | +     Making an actual LLM call would incur cost on every probe.
| 240 |       """
| 241 | +     app = request.app
| 242 | +
| 243 | +     # Check Qdrant
| 244 |       try:
| 245 | +         qdrant_ok = await asyncio.to_thread(collection_exists, app.state.qdrant)
| 246 |       except Exception:
| 247 |           logger.exception("Health check: Qdrant unreachable")
| 248 | +         qdrant_ok = False
| 249 | +
| 250 | +     # Check LLM
| 251 | +     llm_ok = _check_llm_reachable(app)
| 252 | +
| 253 | +     # Status is healthy only if all components are available
| 254 | +     if qdrant_ok and llm_ok:
| 255 | +         status = "healthy"
| 256 | +     elif qdrant_ok:
| 257 | +         status = "degraded"  # Can recommend but not explain
| 258 | +     else:
| 259 | +         status = "unhealthy"
| 260 | +
| 261 | +     return {"status": status, "qdrant_connected": qdrant_ok, "llm_reachable": llm_ok}
| 262 | +
| 263 | +
| 264 | + @router.get("/ready", response_model=ReadinessResponse)
| 265 | + async def ready(request: Request):
| 266 | +     """Kubernetes-style readiness probe.
| 267 | +
| 268 | +     Unlike /health (liveness), this endpoint verifies all components are
| 269 | +     actually ready to serve requests:
| 270 | +     - Qdrant: Collection exists and is queryable
| 271 | +     - Embedder: Model loaded and can embed text
| 272 | +     - HHEM: Detector loaded
| 273 | +     - Explainer: LLM client configured
| 274 | +
| 275 | +     Returns 200 if ready, 503 if not ready (for load balancer integration).
| 276 | +     """
| 277 | +     app = request.app
| 278 | +     components = {}
| 279 | +     messages = []
| 280 | +
| 281 | +     # Check Qdrant connectivity
| 282 | +     try:
| 283 | +         qdrant_ok = await asyncio.to_thread(collection_exists, app.state.qdrant)
| 284 | +         components["qdrant"] = qdrant_ok
| 285 | +         if not qdrant_ok:
| 286 | +             messages.append("Qdrant collection not found")
| 287 | +     except Exception as e:
| 288 | +         components["qdrant"] = False
| 289 | +         messages.append(f"Qdrant unreachable: {e}")
| 290 | +
| 291 | +     # Check embedder
| 292 | +     try:
| 293 | +         if app.state.embedder is not None:
| 294 | +             # Quick sanity check: embed a single word
| 295 | +             _ = await asyncio.to_thread(app.state.embedder.embed_single_query, "test")
| 296 | +             components["embedder"] = True
| 297 | +         else:
| 298 | +             components["embedder"] = False
| 299 | +             messages.append("Embedder not loaded")
| 300 | +     except Exception as e:
| 301 | +         components["embedder"] = False
| 302 | +         messages.append(f"Embedder error: {e}")
| 303 | +
| 304 | +     # Check HHEM detector
| 305 | +     components["hhem"] = app.state.detector is not None
| 306 | +     if not components["hhem"]:
| 307 | +         messages.append("HHEM detector not loaded")
| 308 | +
| 309 | +     # Check explainer (optional - degraded mode acceptable)
| 310 | +     components["explainer"] = app.state.explainer is not None
| 311 | +     if not components["explainer"]:
| 312 | +         messages.append("Explainer not available (degraded mode)")
| 313 | +
| 314 | +     # Core components must be ready (explainer is optional)
| 315 | +     core_ready = all(
| 316 | +         [
| 317 | +             components.get("qdrant", False),
| 318 | +             components.get("embedder", False),
| 319 | +             components.get("hhem", False),
| 320 | +         ]
| 321 | +     )
| 322 | +
| 323 | +     if core_ready and components.get("explainer", False):
| 324 | +         status = "ready"
| 325 | +         message = None
| 326 | +     elif core_ready:
| 327 | +         status = "degraded"
| 328 | +         message = "Explainer unavailable; explain=false only"
| 329 | +     else:
| 330 | +         status = "not_ready"
| 331 | +         message = "; ".join(messages) if messages else "Core components not ready"
| 332 | +
| 333 | +     response_data = {
| 334 | +         "ready": core_ready,
| 335 | +         "status": status,
| 336 | +         "components": components,
| 337 | +         "message": message,
| 338 | +     }
| 339 | +
| 340 | +     # Return 503 if not ready (for load balancer health checks)
| 341 | +     if not core_ready:
| 342 | +         return JSONResponse(status_code=503, content=response_data)
| 343 | +
| 344 | +     return response_data
| 345 |
| 346 |
| 347 |   # ---------------------------------------------------------------------------
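A sketch of how a deploy script might gate on `/ready`, treating the 503 path above as "keep waiting". The base URL and retry budget are illustrative assumptions:

```python
# Poll /ready until core components come up; a 503 means "not ready yet".
import time

import requests

def wait_until_ready(base_url: str = "http://localhost:8000", attempts: int = 30) -> bool:
    for _ in range(attempts):
        try:
            r = requests.get(f"{base_url}/ready", timeout=2)
            if r.status_code == 200 and r.json().get("ready"):
                return True
        except requests.RequestException:
            pass  # server process not accepting connections yet
        time.sleep(2)
    return False

if __name__ == "__main__":
    print("ready" if wait_until_ready() else "gave up")
```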
...
| 349 |   # ---------------------------------------------------------------------------
| 350 |
| 351 |
| 352 | + def _sync_recommend(
| 353 | +     body: RecommendationRequest,
| 354 | +     app,
| 355 | + ) -> dict:
| 356 | +     """Synchronous recommendation logic.
| 357 | +
| 358 | +     Separated for use with asyncio.to_thread() and timeout handling.
| 359 | +     Returns the response dict or raises an exception.
| 360 | +     """
| 361 |       cache = app.state.cache
| 362 | +     q = body.query
| 363 | +     explain = body.explain
| 364 | +
| 365 | +     # Check cache before any heavy work (only for the explain path).
| 366 | +     # The embedding computed here is reused for candidate retrieval below,
| 367 | +     # avoiding the cost of a second embed_single_query call.
| 368 | +     if explain:
| 369 | +         query_embedding = app.state.embedder.embed_single_query(q)
| 370 | +         cached, hit_type = cache.get(q, query_embedding)
| 371 | +         record_cache_event(f"hit_{hit_type}" if hit_type != "miss" else "miss")
| 372 | +         if cached is not None:
| 373 | +             return cached
| 374 | +     else:
| 375 | +         query_embedding = None
| 376 | +
| 377 | +     products = _fetch_products(body, app, query_embedding=query_embedding)
| 378 |
| 379 | +     if not products:
| 380 | +         return {"query": q, "recommendations": []}
| 381 |
| 382 | +     recommendations = []
| 383 |
| 384 | +     if explain:
| 385 | +         if app.state.explainer is None:
| 386 | +             raise RuntimeError("Explanation service unavailable")
| 387 |
| 388 | +         explainer = app.state.explainer
| 389 | +         detector = app.state.detector
| 390 |
| 391 | +         def _explain(product: ProductScore):
| 392 | +             # Thread safety: LLM clients use httpx (thread-safe).
| 393 | +             # HHEM model in eval() + no_grad() = read-only forward
| 394 | +             # pass with no state mutation. Tokenizer is stateless.
| 395 | +             er = explainer.generate_explanation(
| 396 | +                 query=q,
| 397 | +                 product=product,
| 398 | +                 max_evidence=MAX_EVIDENCE,
| 399 | +             )
| 400 | +             hr = detector.check_explanation(
| 401 | +                 evidence_texts=er.evidence_texts,
| 402 | +                 explanation=er.explanation,
| 403 | +             )
| 404 | +             cr = verify_citations(er.explanation, er.evidence_ids, er.evidence_texts)
| 405 | +             return er, hr, cr
| 406 | +
| 407 | +         with ThreadPoolExecutor(
| 408 | +             max_workers=min(len(products), _MAX_EXPLAIN_WORKERS)
| 409 | +         ) as pool:
| 410 | +             results = list(pool.map(_explain, products))
| 411 | +
| 412 | +         for i, (product, (er, hr, cr)) in enumerate(
| 413 | +             zip(products, results, strict=True),
| 414 | +             1,
| 415 | +         ):
| 416 | +             rec = _build_product_dict(i, product)
| 417 | +             rec["explanation"] = er.explanation
| 418 | +             rec["confidence"] = {
| 419 | +                 "hhem_score": round(hr.score, 3),
| 420 | +                 "is_grounded": not hr.is_hallucinated,
| 421 | +                 "threshold": hr.threshold,
| 422 | +             }
| 423 | +             rec["citations_verified"] = cr.all_valid
| 424 | +             rec["evidence_sources"] = _build_evidence_list(er)
| 425 | +             recommendations.append(rec)
| 426 | +     else:
| 427 | +         for i, product in enumerate(products, 1):
| 428 | +             recommendations.append(_build_product_dict(i, product))
| 429 | +
| 430 | +     result = {"query": q, "recommendations": recommendations}
| 431 |
| 432 | +     # Store in cache (explain path only; embedding was computed above)
| 433 | +     if explain:
| 434 | +         cache.put(q, query_embedding, result)
| 435 |
| 436 | +     return result
| 437 | +
| 438 | +
| 439 | + @router.post(
| 440 | +     "/recommend",
| 441 | +     response_model=RecommendationResponse,
| 442 | +     responses={
| 443 | +         408: {"model": ErrorResponse},
| 444 | +         500: {"model": ErrorResponse},
| 445 | +         503: {"model": ErrorResponse},
| 446 | +     },
| 447 | + )
| 448 | + async def recommend(request: Request, body: RecommendationRequest):
| 449 | +     """Return product recommendations with optional grounded explanations.
| 450 |
| 451 | +     Accepts a JSON body with query, optional user_id, filters, and k.
| 452 | +     Async handler with a 10s timeout; on timeout a 408 suggests explain=false.
| 453 | +     """
| 454 | +     app = request.app
| 455 | +     q = body.query
| 456 | +
| 457 | +     try:
| 458 | +         # Run blocking code in thread pool with timeout
| 459 | +         result = await asyncio.wait_for(
| 460 | +             asyncio.to_thread(_sync_recommend, body, app),
| 461 | +             timeout=REQUEST_TIMEOUT_SECONDS,
| 462 | +         )
| 463 |           return result
| 464 |
| 465 | +     except asyncio.TimeoutError:
| 466 | +         logger.warning("Request timeout for query: %s", q)
| 467 | +         record_error("timeout")
| 468 | +         # The whole request timed out (usually during explanation
| 469 | +         # generation); point clients at the faster explain=false path.
| 470 | +         return _error_response(
| 471 | +             408,
| 472 | +             f"Request timeout ({REQUEST_TIMEOUT_SECONDS}s). Try with explain=false.",
| 473 | +             q,
| 474 |           )
| 475 |
| 476 | +     except ConnectionError as e:
| 477 | +         # Qdrant or LLM API connection failed
| 478 | +         error_msg = str(e).lower()
| 479 | +         if "qdrant" in error_msg or "vector" in error_msg:
| 480 | +             logger.error("Qdrant connection failed for query: %s - %s", q, e)
| 481 | +             record_error("qdrant_unavailable")
| 482 | +             return _error_response(
| 483 | +                 503, "Vector database unavailable. Please try again later.", q
| 484 | +             )
| 485 | +         else:
| 486 | +             # LLM API connection failed
| 487 | +             logger.error("LLM API connection failed for query: %s - %s", q, e)
| 488 | +             record_error("llm_connection_error")
| 489 | +             return _error_response(
| 490 | +                 503, "LLM service connection failed. Please try again later.", q
| 491 | +             )
| 492 | +
| 493 | +     except TimeoutError as e:
| 494 | +         # LLM API timeout (different from asyncio.TimeoutError)
| 495 | +         logger.warning("LLM API timeout for query: %s - %s", q, e)
| 496 | +         record_error("llm_timeout")
| 497 | +         return _error_response(
| 498 | +             504, "LLM service timeout. Try with explain=false for faster response.", q
| 499 | +         )
| 500 | +
| 501 | +     except RuntimeError as e:
| 502 | +         error_msg = str(e)
| 503 | +         # Explanation service unavailable
| 504 | +         if "Explanation service unavailable" in error_msg:
| 505 | +             logger.warning("Explanation service unavailable for query: %s", q)
| 506 | +             record_error("llm_unavailable")
| 507 | +             return _error_response(503, str(e), q)
| 508 | +         # LLM rate limited (translated from API error)
| 509 | +         if "rate limit" in error_msg.lower():
| 510 | +             logger.warning("LLM rate limited for query: %s", q)
| 511 | +             record_error("llm_rate_limited")
| 512 | +             return _error_response(
| 513 | +                 429, "LLM API rate limited. Please try again later.", q
| 514 | +             )
| 515 | +         record_error("runtime_error")
| 516 | +         raise
| 517 | +
| 518 | +     except Exception as e:
| 519 | +         # Check for Qdrant-specific errors
| 520 | +         error_type = type(e).__name__
| 521 | +         error_msg = str(e).lower()
| 522 | +
| 523 | +         if "qdrant" in error_type.lower() or "qdrant" in error_msg:
| 524 | +             logger.error("Qdrant error for query: %s - %s", q, e)
| 525 | +             record_error("qdrant_error")
| 526 | +             return _error_response(
| 527 | +                 503, "Vector database error. Please try again later.", q
| 528 | +             )
| 529 | +
| 530 | +         logger.exception("Recommendation failed for query: %s", q)
| 531 | +         record_error("internal_error")
| 532 | +         return _error_response(500, "Internal server error", q)
| 533 | +
| 534 |
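The handler's core trick is bounding a blocking call with `asyncio.wait_for(asyncio.to_thread(...))`. A self-contained sketch of just that pattern — the slow function is a stand-in, not project code:

```python
import asyncio
import time

def slow_sync_work(seconds: float) -> str:
    # Stand-in for _sync_recommend: blocking CPU/IO work.
    time.sleep(seconds)
    return "done"

async def main() -> None:
    try:
        # Run the blocking call in a worker thread; from the caller's view it
        # is abandoned after 1s (the thread itself keeps running to completion).
        result = await asyncio.wait_for(
            asyncio.to_thread(slow_sync_work, 5.0), timeout=1.0
        )
        print(result)
    except asyncio.TimeoutError:
        print("timed out; degrade gracefully")

asyncio.run(main())
```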
| 535 |   # ---------------------------------------------------------------------------
| 536 |   # Recommend (SSE streaming)
...
| 542 |       return f"event: {event}\ndata: {data}\n\n"
| 543 |
| 544 |
| 545 | + def _error_response(status_code: int, error_msg: str, query: str) -> JSONResponse:
| 546 | +     """Build a standardized JSON error response."""
| 547 | +     return JSONResponse(
| 548 | +         status_code=status_code,
| 549 | +         content={"error": error_msg, "query": query},
| 550 | +     )
| 551 | +
| 552 | +
| 553 | + async def _stream_recommendations(
| 554 | +     body: RecommendationRequest,
| 555 |       app,
| 556 | + ) -> AsyncIterator[str]:
| 557 | +     """Async generator that yields SSE events for streaming recommendations.
| 558 | +
| 559 | +     Uses asyncio.to_thread for blocking calls to avoid blocking the event loop.
| 560 | +     """
| 561 |       yield _sse_event(
| 562 |           "metadata",
| 563 |           json.dumps(
...
| 570 |       )
| 571 |
| 572 |       try:
| 573 | +         products = await asyncio.to_thread(_fetch_products, body, app)
| 574 |       except Exception:
| 575 |           logger.exception("Streaming: candidate generation failed")
| 576 |           yield _sse_event("error", json.dumps({"detail": "Failed to retrieve products"}))
...
| 578 |           return
| 579 |
| 580 |       if not products:
| 581 | +         yield _sse_event(
| 582 | +             "done", json.dumps({"query": body.query, "recommendations": []})
| 583 | +         )
| 584 |           return
| 585 |
| 586 |       explainer = app.state.explainer
...
| 595 |           yield _sse_event("product", json.dumps(_build_product_dict(i, product)))
| 596 |
| 597 |           try:
| 598 | +             # Helper to generate explanation with timeout protection
| 599 | +             async def _generate_with_timeout(prod):
| 600 | +                 # Get the stream object in a thread (it sets up the connection)
| 601 | +                 stream = await asyncio.to_thread(
| 602 | +                     explainer.generate_explanation_stream,
| 603 | +                     body.query,
| 604 | +                     prod,
| 605 | +                     MAX_EVIDENCE,
| 606 | +                 )
| 607 | +
| 608 | +                 # Iterate over tokens - each token retrieval is blocking
| 609 | +                 def _get_tokens():
| 610 | +                     tokens = list(stream)
| 611 | +                     return tokens, stream.get_complete_result()
| 612 | +
| 613 | +                 return await asyncio.to_thread(_get_tokens)
| 614 | +
| 615 | +             # Wrap in timeout to prevent hanging streams
| 616 | +             tokens, result = await asyncio.wait_for(
| 617 | +                 _generate_with_timeout(product),
| 618 | +                 timeout=STREAM_PRODUCT_TIMEOUT,
| 619 |               )
| 620 | +
| 621 | +             for token in tokens:
| 622 |                   yield _sse_event("token", json.dumps({"text": token}))
| 623 |
| 624 |               yield _sse_event(
| 625 |                   "evidence",
| 626 |                   json.dumps({"evidence_sources": _build_evidence_list(result)}),
| 627 |               )
| 628 |
| 629 | +         except asyncio.TimeoutError:
| 630 | +             logger.warning(
| 631 | +                 "Streaming timeout for product %s after %.1fs",
| 632 | +                 product.product_id,
| 633 | +                 STREAM_PRODUCT_TIMEOUT,
| 634 | +             )
| 635 | +             yield _sse_event(
| 636 | +                 "error",
| 637 | +                 json.dumps(
| 638 | +                     {
| 639 | +                         "detail": f"Explanation timed out ({STREAM_PRODUCT_TIMEOUT}s)",
| 640 | +                         "product_id": product.product_id,
| 641 | +                     }
| 642 | +                 ),
| 643 | +             )
| 644 |           except ValueError as exc:
| 645 |               # Quality gate refusal — evidence insufficient for this product.
| 646 |               # Surface the reason so clients can display it meaningfully.
...
| 655 |       yield _sse_event("done", json.dumps({"status": "complete"}))
| 656 |
| 657 |
| 658 | + @router.post("/recommend/stream")
| 659 | + async def recommend_stream(request: Request, body: RecommendationRequest):
| 660 |       """Stream product recommendations with explanations via SSE.
| 661 |
| 662 | +     Accepts a JSON body with query, optional user_id, filters, and k.
| 663 | +
| 664 |       The streaming path does not check or populate the semantic cache and
| 665 |       does not compute HHEM confidence scores. For cached or grounded
| 666 |       responses, use the non-streaming ``POST /recommend`` endpoint.
| 667 | +
| 668 | +     David's rule: streaming is non-negotiable. Users perceive streaming
| 669 | +     as 40% faster.
| 670 |       """
| 671 |       return StreamingResponse(
| 672 | +         _stream_recommendations(body, request.app),
| 673 |           media_type="text/event-stream",
| 674 |           headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
| 675 |       )
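A minimal consumer for this stream, assuming a local server and `httpx`; the event parsing is deliberately simplified to the `event:`/`data:` framing that `_sse_event` produces:

```python
# Sketch of an SSE client for /recommend/stream (localhost is an assumption).
import json

import httpx

payload = {"query": "bluetooth speaker with good bass", "k": 2}
with httpx.stream(
    "POST", "http://localhost:8000/recommend/stream", json=payload, timeout=None
) as r:
    event = None
    for line in r.iter_lines():
        if line.startswith("event: "):
            event = line[len("event: "):]
        elif line.startswith("data: ") and event:
            data = json.loads(line[len("data: "):])
            if event == "token":
                print(data["text"], end="", flush=True)  # incremental explanation
            elif event == "done":
                print("\n[stream complete]", data)
```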
| 681 |
| 682 |
| 683 |   @router.get("/cache/stats", response_model=CacheStatsResponse)
| 684 | + async def cache_stats(request: Request):
| 685 |       """Return cache performance statistics."""
| 686 |       stats = request.app.state.cache.stats()
| 687 |       return {
...
| 699 |
| 700 |
| 701 |   @router.post("/cache/clear")
| 702 | + async def cache_clear(request: Request):
| 703 |       """Clear all cached entries."""
| 704 |       request.app.state.cache.clear()
| 705 |       return {"status": "cleared"}
...
| 711 |
| 712 |
| 713 |   @router.get("/metrics")
| 714 | + async def metrics():
| 715 |       """Prometheus metrics endpoint."""
| 716 |       body, content_type = metrics_response()
| 717 |       return Response(content=body, media_type=content_type)
sage/api/run.py
CHANGED
@@ -20,7 +20,7 @@ from sage.api.app import create_app
| 20 |   from sage.config import configure_logging
| 21 |
| 22 |
| 23 | - def main():
| 23 | + def main() -> None:
| 24 |       parser = argparse.ArgumentParser(description="Sage API server")
| 25 |       parser.add_argument("--host", default="0.0.0.0", help="Bind address")
| 26 |       parser.add_argument(
sage/config/__init__.py
CHANGED
@@ -21,9 +21,6 @@ PROJECT_ROOT = Path(__file__).parent.parent.parent
| 21 |   DATA_DIR = PROJECT_ROOT / "data"
| 22 |   DATA_DIR.mkdir(exist_ok=True)
| 23 |
| 24 | - EXPLANATIONS_DIR = DATA_DIR / "explanations"
| 25 | - EXPLANATIONS_DIR.mkdir(exist_ok=True)
| 26 | -
| 24 |   RESULTS_DIR = DATA_DIR / "eval_results"
| 25 |   RESULTS_DIR.mkdir(exist_ok=True)
| 26 |

@@ -89,7 +86,11 @@ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
| 86 |   # LLM Settings
| 87 |   # ---------------------------------------------------------------------------
| 88 |
| 92 | - [line truncated in source]
| 89 | + # Provider constants
| 90 | + PROVIDER_ANTHROPIC = "anthropic"
| 91 | + PROVIDER_OPENAI = "openai"
| 92 | +
| 93 | + LLM_PROVIDER = os.getenv("LLM_PROVIDER", PROVIDER_ANTHROPIC)
| 94 |
| 95 |   # Model selection
| 96 |   ANTHROPIC_MODEL = "claude-sonnet-4-20250514"

@@ -193,7 +194,6 @@ __all__ = [
| 194 |       # Paths
| 195 |       "PROJECT_ROOT",
| 196 |       "DATA_DIR",
| 196 | -     "EXPLANATIONS_DIR",
| 197 |       "RESULTS_DIR",
| 198 |       # Dataset
| 199 |       "DATASET_NAME",

@@ -220,6 +220,8 @@ __all__ = [
| 220 |       "ANTHROPIC_API_KEY",
| 221 |       "OPENAI_API_KEY",
| 222 |       # LLM
| 223 | +     "PROVIDER_ANTHROPIC",
| 224 | +     "PROVIDER_OPENAI",
| 225 |       "LLM_PROVIDER",
| 226 |       "ANTHROPIC_MODEL",
| 227 |       "OPENAI_MODEL",
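With the provider constants exported, call sites can dispatch on `LLM_PROVIDER` without magic strings. A sketch of the kind of selection logic this enables — `pick_model` is a hypothetical helper, though the constant and model names come from `sage.config`:

```python
# Illustrative dispatch on the new provider constants.
from sage.config import (
    ANTHROPIC_MODEL,
    LLM_PROVIDER,
    OPENAI_MODEL,
    PROVIDER_ANTHROPIC,
    PROVIDER_OPENAI,
)

def pick_model(provider: str = LLM_PROVIDER) -> str:
    if provider == PROVIDER_ANTHROPIC:
        return ANTHROPIC_MODEL
    if provider == PROVIDER_OPENAI:
        return OPENAI_MODEL
    raise ValueError(f"Unknown LLM_PROVIDER: {provider!r}")
```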
sage/config/queries.py
CHANGED
@@ -3,22 +3,30 @@ Standard evaluation queries.
| 3 |
| 4 |   Separated from main config to keep configuration declarative.
| 5 |   These are test fixtures used by evaluation scripts.
| 6 | +
| 7 | + Query organization:
| 8 | + - CORE_QUERIES: Common queries appearing in all evaluations
| 9 | + - STANDARD_QUERIES: Standard product category queries
| 10 | + - EDGE_CASE_QUERIES: Challenging queries for failure analysis
| 11 | + - Derived lists compose these bases for specific use cases
| 12 |   """
| 13 |
| 8 | - #
| 9 | - EVALUATION_QUERIES = [
| 10 | - # Common product categories (high confidence expected)
| 14 | + # Core queries - used across all evaluations (5 queries)
| 15 | + CORE_QUERIES = [
| 16 |       "wireless headphones with noise cancellation",
| 12 | -     "laptop charger for MacBook",
| 17 | +     "laptop charger for MacBook",
| 18 |       "USB hub with multiple ports",
| 14 | -     "portable battery pack for travel",
| 19 | +     "portable battery pack for travel",
| 20 |       "bluetooth speaker with good bass",
| 21 | + ]
| 22 | +
| 23 | + # Standard product queries - common categories (13 queries)
| 24 | + STANDARD_QUERIES = [
| 25 |       "HDMI cable for 4K TV",
| 26 |       "external hard drive for backup",
| 27 |       "webcam for video calls",
| 28 |       "wireless mouse for laptop",
| 29 |       "keyboard with backlight",
| 21 | -     # Specific attribute queries (medium confidence)
| 30 |       "screen protector for phone",
| 31 |       "phone case with good protection",
| 32 |       "earbuds for working out",

@@ -27,17 +35,10 @@ EVALUATION_QUERIES = [
| 35 |       "surge protector with USB ports",
| 36 |       "wireless charging pad",
| 37 |       "fast charging USB-C cable",
| 30 | -     "noise cancelling headphones for travel",
| 31 | -     "portable speaker with good bass",
| 38 |   ]
| 39 |
| 34 | - #
| 35 | - ANALYSIS_QUERIES = [
| 36 | -     "wireless headphones with noise cancellation",
| 37 | -     "laptop charger for MacBook",
| 38 | -     "USB hub with multiple ports",
| 39 | -     "portable battery pack for travel",
| 40 | -     "bluetooth speaker with good bass",
| 40 | + # Edge case queries - tests failure modes (9 queries)
| 41 | + EDGE_CASE_QUERIES = [
| 42 |       "cheap but good quality earbuds",
| 43 |       "durable phone case that looks nice",
| 44 |       "fast charging cable that won't break",

@@ -49,26 +50,30 @@ ANALYSIS_QUERIES = [
| 50 |       "gift for someone who likes music",
| 51 |   ]
| 52 |
| 53 | + # Primary evaluation queries - used for general RAGAS/HHEM evaluation
| 54 | + # Combines core + standard + 2 semantic variants
| 55 | + EVALUATION_QUERIES = (
| 56 | +     CORE_QUERIES
| 57 | +     + STANDARD_QUERIES
| 58 | +     + [
| 59 | +         "noise cancelling headphones for travel",
| 60 | +         "portable speaker with good bass",
| 61 | +     ]
| 62 | + )
| 63 | +
| 64 | + # Queries for failure analysis - focused on edge cases and challenging queries
| 65 | + ANALYSIS_QUERIES = CORE_QUERIES + EDGE_CASE_QUERIES
| 66 | +
| 67 |   # Queries for end-to-end success rate evaluation - comprehensive coverage
| 53 | - E2E_EVAL_QUERIES = [
| 54-64 | - [eleven removed query lines, truncated in source]
| 65 | -     "charger that actually works",
| 66 | -     "waterproof speaker for shower",
| 67 | -     "gift for someone who likes music",
| 68 | -     "tablet stand for kitchen",
| 69 | -     "wireless mouse for laptop",
| 70 | -     "HDMI cable for monitor",
| 71 | -     "phone mount for car",
| 72 | -     "screen protector for phone",
| 73 | -     "backup battery for camping",
| 74 | - ]
| 68 | + E2E_EVAL_QUERIES = (
| 69 | +     CORE_QUERIES
| 70 | +     + EDGE_CASE_QUERIES
| 71 | +     + [
| 72 | +         "tablet stand for kitchen",
| 73 | +         "wireless mouse for laptop",
| 74 | +         "HDMI cable for monitor",
| 75 | +         "phone mount for car",
| 76 | +         "screen protector for phone",
| 77 | +         "backup battery for camping",
| 78 | +     ]
| 79 | + )
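Since the derived lists are now built by concatenation, a quick check can confirm the composition. The expected sizes below are taken from the comments in the file (5/13/9 base queries); the assertions are the kind of sanity test one might add, not existing project code:

```python
# Sanity check on the composed query fixtures.
from sage.config.queries import (
    ANALYSIS_QUERIES,
    CORE_QUERIES,
    E2E_EVAL_QUERIES,
    EDGE_CASE_QUERIES,
    EVALUATION_QUERIES,
    STANDARD_QUERIES,
)

assert len(CORE_QUERIES) == 5 and len(STANDARD_QUERIES) == 13
assert len(EDGE_CASE_QUERIES) == 9
assert len(EVALUATION_QUERIES) == 5 + 13 + 2
assert len(ANALYSIS_QUERIES) == 5 + 9
assert len(E2E_EVAL_QUERIES) == 5 + 9 + 6
# Composition should not introduce duplicates within any derived list.
for name, qs in [("EVALUATION", EVALUATION_QUERIES), ("E2E", E2E_EVAL_QUERIES)]:
    assert len(set(qs)) == len(qs), f"duplicate query in {name}"
print("query fixtures OK")
```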
sage/core/prompts.py
CHANGED
@@ -11,6 +11,7 @@ Prompt design rationale:
| 11 |   """
| 12 |
| 13 |   from sage.core.models import ProductScore, RetrievedChunk
| 14 | + from sage.utils import extract_evidence
| 15 |
| 16 |
| 17 |   EXPLANATION_SYSTEM_PROMPT = """You explain product recommendations using ONLY direct quotes from customer reviews.

@@ -97,9 +98,7 @@ def build_explanation_prompt(
| 98 |       Returns:
| 99 |           Tuple of (system_prompt, user_prompt, evidence_texts, evidence_ids).
| 100 |       """
| 100 | -
| 101 | -     evidence_texts = [c.text for c in chunks_used]
| 102 | -     evidence_ids = [c.review_id for c in chunks_used]
| 101 | +     evidence_texts, evidence_ids = extract_evidence(product.evidence, max_evidence)
| 102 |       evidence_formatted = format_evidence(product.evidence, max_evidence)
| 103 |
| 104 |       valid_ids = ", ".join(evidence_ids)
sage/core/verification.py
CHANGED
@@ -19,6 +19,7 @@ from sage.core.models import (
| 19 |       QuoteVerification,
| 20 |       VerificationResult,
| 21 |   )
| 22 | + from sage.utils import normalize_text
| 23 |
| 24 |
| 25 |   # Forbidden phrases that violate prompt constraints.

@@ -106,21 +107,6 @@ def extract_quotes(text: str, min_length: int = 4) -> list[str]:
| 107 |       return list(dict.fromkeys(quotes))  # Preserve order, remove duplicates
| 108 |
| 109 |
| 109 | - def normalize_text(text: str) -> str:
| 110 | -     """
| 111 | -     Normalize text for fuzzy matching.
| 112 | -
| 113 | -     Converts to lowercase and collapses whitespace.
| 114 | -
| 115 | -     Args:
| 116 | -         text: Text to normalize.
| 117 | -
| 118 | -     Returns:
| 119 | -         Normalized text string.
| 120 | -     """
| 121 | -     return " ".join(text.lower().split())
| 122 | -
| 123 | -
| 110 |   def verify_quote_in_evidence(
| 111 |       quote: str,
| 112 |       evidence_texts: list[str],
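`normalize_text` now lives in `sage.utils`, but its behavior (lowercase plus whitespace collapse, per the deleted body above) is what makes the quote check tolerant of spacing and case. A standalone illustration:

```python
# The deleted helper, reproduced standalone to show the matching behavior.
def normalize_text(text: str) -> str:
    # Lowercase and collapse every whitespace run to a single space.
    return " ".join(text.lower().split())

evidence = "Battery life is  GREAT —\nlasted two weeks."
quote = "battery life is great"
# Substring check survives case, double spaces, and line breaks.
assert normalize_text(quote) in normalize_text(evidence)
```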
sage/services/baselines.py
CHANGED
@@ -17,6 +17,7 @@ import numpy as np
| 17 |   import pandas as pd
| 18 |
| 19 |   from sage.config import COLLECTION_NAME
| 20 | + from sage.utils import normalize_vectors
| 21 |
| 22 |
| 23 |   class RandomBaseline:

@@ -122,9 +123,7 @@ class ItemKNNBaseline:
| 123 |       )
| 124 |
| 125 |       # Normalize embeddings for cosine similarity
| 125 | -     [line truncated in source]
| 126 | -     norms = np.where(norms == 0, 1, norms)
| 127 | -     self.embeddings_norm = self.embeddings / norms
| 126 | +     self.embeddings_norm = normalize_vectors(self.embeddings)
| 127 |
| 128 |       self.embedder = embedder

@@ -146,7 +145,7 @@ class ItemKNNBaseline:
| 145 |
| 146 |       # Embed query
| 147 |       query_emb = self.embedder.embed_single_query(query)
| 149 | -     query_emb = [truncated in source]
| 148 | +     query_emb = normalize_vectors(query_emb)
| 149 |
| 150 |       # Compute similarities (dot product of normalized vectors = cosine)
| 151 |       similarities = self.embeddings_norm @ query_emb

@@ -192,8 +191,7 @@ def build_product_embeddings(
| 191 |       )
| 192 |
| 193 |       # Normalize
| 195 | -     [line truncated in source]
| 196 | -     product_embeddings[product_id] = agg_emb
| 194 | +     product_embeddings[product_id] = normalize_vectors(agg_emb)
| 195 |
| 196 |       return product_embeddings

@@ -241,7 +239,6 @@ def load_product_embeddings_from_qdrant() -> dict[str, np.ndarray]:
| 239 |       product_embeddings = {}
| 240 |       for product_id, vectors in product_vectors.items():
| 241 |           mean_vec = np.mean(vectors, axis=0)
| 244 | -         [line truncated in source]
| 245 | -         product_embeddings[product_id] = mean_vec
| 242 | +         product_embeddings[product_id] = normalize_vectors(mean_vec)
| 243 |
| 244 |       return product_embeddings
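The surviving `np.where(norms == 0, 1, norms)` guard in the old code suggests what `normalize_vectors` does. The relevant hunk of `sage/utils.py` is not shown in this diff, so this is one plausible implementation, handling both a single vector and a batch:

```python
# Sketch of a normalize_vectors helper consistent with the deleted call sites;
# the real definition in sage/utils.py may differ.
import numpy as np

def normalize_vectors(x: np.ndarray) -> np.ndarray:
    """L2-normalize a vector or a batch of row vectors, leaving zero
    vectors untouched instead of dividing by zero."""
    if x.ndim == 1:
        norm = np.linalg.norm(x)
        return x if norm == 0 else x / norm
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    norms = np.where(norms == 0, 1, norms)  # guard against zero rows
    return x / norms

batch = np.array([[3.0, 4.0], [0.0, 0.0]])
print(normalize_vectors(batch))  # rows: [0.6, 0.8] and [0.0, 0.0]
```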
sage/services/cache.py
CHANGED
@@ -3,6 +3,52 @@ Semantic query cache with exact-match (L1) and embedding-similarity (L2) layers.
| 3 |
| 4 |   Provides sub-millisecond cache hits for repeated queries and ~50ms hits for
| 5 |   semantically equivalent queries, avoiding redundant retrieval + LLM calls.
| 6 | +
| 7 | + Architecture (cache sits BETWEEN user and vector DB):
| 8 | +
| 9 | +     User Query
| 10 | +         │
| 11 | +         ▼
| 12 | +     ┌─────────────────┐
| 13 | +     │ L1: Exact Match │ ─── hit ──▶ Return cached response (<1ms)
| 14 | +     │ (query string)  │
| 15 | +     └────────┬────────┘
| 16 | +              │ miss
| 17 | +              ▼
| 18 | +     ┌─────────────────┐
| 19 | +     │ L2: Semantic    │ ─── hit ──▶ Return cached response (~50ms)
| 20 | +     │ (embedding sim) │
| 21 | +     └────────┬────────┘
| 22 | +              │ miss
| 23 | +              ▼
| 24 | +     ┌─────────────────┐
| 25 | +     │ Vector DB Query │
| 26 | +     │ (Qdrant)        │
| 27 | +     └────────┬────────┘
| 28 | +              │
| 29 | +              ▼
| 30 | +     ┌─────────────────┐
| 31 | +     │ LLM Explanation │
| 32 | +     │ (LLM provider)  │
| 33 | +     └────────┬────────┘
| 34 | +              │
| 35 | +              ▼
| 36 | +     Store in cache ──▶ Return response
| 37 | +
| 38 | + TTL Policy (unified at 1 hour):
| 39 | +     We use a single 3600s TTL rather than separate L1/L2 TTLs because:
| 40 | +     1. Product reviews don't change frequently (static corpus)
| 41 | +     2. LLM explanations are near-deterministic given the same evidence
| 42 | +     3. Simpler cache invalidation (one knob to tune)
| 43 | +     4. In production, we'd tie TTL to data refresh cadence
| 44 | +
| 45 | + Similarity Threshold (0.92):
| 46 | +     Chosen based on empirical testing:
| 47 | +     - 0.85: Too permissive, returns irrelevant cached results
| 48 | +     - 0.90: Some false positives on short queries
| 49 | +     - 0.92: Good balance — catches "headphones" ≈ "best headphones"
| 50 | +     - 0.95: Too strict, misses obvious paraphrases
| 51 | +     The threshold is configurable via the CACHE_SIMILARITY_THRESHOLD env var.
| 52 |   """
| 53 |
| 54 |   import copy

@@ -12,7 +58,7 @@ from dataclasses import dataclass
| 58 |
| 59 |   import numpy as np
| 60 |
| 15 | - from sage.[truncated in source]
| 61 | + from sage.utils import normalize_text, normalize_vectors
| 62 |   from sage.config import (
| 63 |       CACHE_MAX_ENTRIES,
| 64 |       CACHE_SIMILARITY_THRESHOLD,

@@ -79,14 +125,10 @@ class CacheStats:
| 125 |   class SemanticCache:
| 126 |       """Thread-safe in-memory cache with exact-match and semantic-similarity layers.
| 127 |
| 82 | -     Parameters
| 83 | -     ----------
| 84 | -     similarity_threshold : float
| 85 | -         Minimum cosine similarity for a semantic hit.
| 86 | -     max_entries : int
| 87 | -         Maximum cached entries before LRU eviction.
| 88 | -     ttl_seconds : float
| 89 | -         Time-to-live in seconds. Entries older than this are evicted on access.
| 128 | +     Args:
| 129 | +         similarity_threshold: Minimum cosine similarity for a semantic cache hit (0.0-1.0).
| 130 | +         max_entries: Maximum cached entries before LRU eviction.
| 131 | +         ttl_seconds: Time-to-live in seconds. Entries older than this are evicted on access.
| 132 |       """
| 133 |
| 134 |       def __init__(

@@ -125,19 +167,14 @@ class SemanticCache:
| 167 |       ) -> tuple[dict | None, str]:
| 168 |           """Look up a cached result.
| 169 |
| 128 | -         Parameters
| 129 | -         ----------
| 130 | -         query : str
| 131 | -             The user query.
| 132 | -         query_embedding : np.ndarray | None
| 133 | -             Pre-computed embedding for semantic matching.
| 134 | -             If None, only exact match is attempted.
| 135 | -
| 136 | -         Returns
| 137 | -         -------
| 138 | -         tuple[dict | None, str]
| 139 | -             (cached_result, hit_type) where hit_type is "exact", "semantic",
| 140 | -             or "miss".
| 170 | +         Args:
| 171 | +             query: The user query.
| 172 | +             query_embedding: Pre-computed embedding for semantic matching.
| 173 | +                 If None, only exact match is attempted.
| 174 | +
| 175 | +         Returns:
| 176 | +             Tuple of (cached_result, hit_type) where hit_type is "exact",
| 177 | +             "semantic", or "miss".
| 178 |           """
| 179 |           key = normalize_text(query)
| 180 |           now = time.monotonic()

@@ -151,6 +188,11 @@ class SemanticCache:
| 188 |               entry.last_accessed = now
| 189 |               entry.hit_count += 1
| 190 |               self._exact_hits += 1
| 191 | +             logger.info(
| 192 | +                 "Cache L1 HIT (exact): query=%r, hits=%d",
| 193 | +                 query[:50],
| 194 | +                 entry.hit_count,
| 195 | +             )
| 196 |               return copy.deepcopy(entry.result), "exact"
| 197 |
| 198 |           # L2: semantic similarity

@@ -161,22 +203,27 @@ class SemanticCache:
| 203 |               best_entry.hit_count += 1
| 204 |               self._semantic_hits += 1
| 205 |               self._semantic_similarity_sum += best_sim
| 206 | +             logger.info(
| 207 | +                 "Cache L2 HIT (semantic): query=%r, matched=%r, sim=%.3f",
| 208 | +                 query[:50],
| 209 | +                 best_entry.key[:50],
| 210 | +                 best_sim,
| 211 | +             )
| 212 |               return copy.deepcopy(best_entry.result), "semantic"
| 213 |
| 214 |           self._misses += 1
| 215 | +         logger.info(
| 216 | +             "Cache MISS: query=%r, cache_size=%d", query[:50], len(self._entries)
| 217 | +         )
| 218 |           return None, "miss"
| 219 |
| 220 |       def put(self, query: str, query_embedding: np.ndarray, result: dict) -> None:
| 221 |           """Store a result in the cache.
| 222 |
| 172 | -         Parameters
| 173 | -         ----------
| 174 | -         query : str
| 175 | -             The user query.
| 176 | -         query_embedding : np.ndarray
| 177 | -             The query embedding vector.
| 178 | -         result : dict
| 179 | -             The serializable result to cache.
| 223 | +         Args:
| 224 | +             query: The user query.
| 225 | +             query_embedding: The query embedding vector.
| 226 | +             result: The serializable result to cache.
| 227 |           """
| 228 |           key = normalize_text(query)
| 229 |           now = time.monotonic()

@@ -204,6 +251,12 @@ class SemanticCache:
| 251 |           )
| 252 |           self._exact[key] = entry
| 253 |           self._entries.append(entry)
| 254 | +         logger.info(
| 255 | +             "Cache PUT: query=%r, cache_size=%d/%d",
| 256 | +             query[:50],
| 257 | +             len(self._entries),
| 258 | +             self._max_entries,
| 259 | +         )
| 260 |
| 261 |       def stats(self) -> CacheStats:
| 262 |           """Return a snapshot of cache statistics."""

@@ -245,15 +298,18 @@ class SemanticCache:
| 298 |           Must be called while holding self._lock and with len(self._entries) > 0.
| 299 |           """
| 300 |           cached_embeddings = np.array([e.embedding for e in self._entries])
| 248 | -         query_norm = [truncated in source]
| 249 | -         [line truncated in source]
| 250 | -         cached_normed = cached_embeddings / norms
| 301 | +         query_norm = normalize_vectors(query_embedding)
| 302 | +         cached_normed = normalize_vectors(cached_embeddings)
| 303 |           similarities = cached_normed @ query_norm
| 304 |           best_idx = int(np.argmax(similarities))
| 305 |           return self._entries[best_idx], float(similarities[best_idx])
| 306 |
| 307 |       def _remove_entry(self, entry: _CacheEntry) -> None:
| 256 | -         """Remove an entry from both indexes. Must be called while holding self._lock."""
| 308 | +         """Remove an entry from both indexes. Must be called while holding self._lock.
| 309 | +
| 310 | +         Note: Uses O(n) list.remove() which is acceptable for max_entries <= 1000.
| 311 | +         For larger caches, consider a heap or ordered dict structure.
| 312 | +         """
| 313 |           self._exact.pop(entry.key, None)
| 314 |           self._entries.remove(entry)
| 315 |           self._evictions += 1
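How the two layers behave from a caller's perspective. The constructor keywords follow the class docstring above, but the exact signature isn't shown in this diff, and the toy vectors stand in for real sentence embeddings:

```python
# Toy walk-through of the L1/L2 cache semantics (constructor args assumed).
import numpy as np

from sage.services.cache import SemanticCache

cache = SemanticCache(similarity_threshold=0.92, max_entries=100, ttl_seconds=3600)
emb = np.array([1.0, 0.0])
cache.put("wireless headphones", emb, {"recommendations": ["B001"]})

# L1: same string after normalize_text (case/whitespace folded) -> exact hit.
result, hit = cache.get("Wireless   Headphones", emb)
assert hit == "exact"

# L2: different string, near-identical embedding -> semantic hit if the
# cosine similarity clears the 0.92 threshold.
near = np.array([0.999, 0.045])
result, hit = cache.get("best wireless headphones", near)
assert hit in ("semantic", "miss")
```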
sage/services/cold_start.py
CHANGED
@@ -15,8 +15,8 @@ from __future__ import annotations
| 15 |
| 16 |   from typing import TYPE_CHECKING, Literal
| 17 |
| 18 | - from sage.adapters.[truncated in source]
| 19 | - from sage.[truncated in source]
| 18 | + from sage.adapters.vector_store import search
| 19 | + from sage.utils import LazyServiceMixin
| 20 |   from sage.core import (
| 21 |       AggregationMethod,
| 22 |       NewItem,

@@ -71,12 +71,13 @@ def preferences_to_query(prefs: UserPreferences) -> str:
| 71 |       return query if query else DEFAULT_COLD_START_QUERY
| 72 |
| 73 |
| 74 | - class ColdStartService:
| 74 | + class ColdStartService(LazyServiceMixin):
| 75 |       """
| 76 |       Service for handling cold-start scenarios.
| 77 |
| 78 |       Provides strategies for new users and new items.
| 79 |       Uses composition with RetrievalService for recommendation logic.
| 80 | +     Uses LazyServiceMixin for on-demand embedder and client initialization.
| 81 |       """
| 82 |
| 83 |       def __init__(

@@ -106,20 +107,6 @@ class ColdStartService:
| 107 |           self._retrieval = RetrievalService(collection_name=self.collection_name)
| 108 |           return self._retrieval
| 109 |
| 109 | -     @property
| 110 | -     def embedder(self):
| 111 | -         """Lazy-load embedder."""
| 112 | -         if self._embedder is None:
| 113 | -             self._embedder = get_embedder()
| 114 | -         return self._embedder
| 115 | -
| 116 | -     @property
| 117 | -     def client(self):
| 118 | -         """Lazy-load Qdrant client."""
| 119 | -         if self._client is None:
| 120 | -             self._client = get_client()
| 121 | -         return self._client
| 122 | -
| 110 |       def recommend_for_new_user(
| 111 |           self,
| 112 |           preferences: UserPreferences | None = None,

@@ -144,7 +131,7 @@ class ColdStartService:
| 131 |           elif preferences:
| 132 |               search_query = preferences_to_query(preferences)
| 133 |           else:
| 147 | -             search_query = [truncated in source]
| 134 | +             search_query = DEFAULT_COLD_START_QUERY
| 135 |
| 136 |           return self.retrieval.recommend(
| 137 |               query=search_query,
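`LazyServiceMixin` itself is defined in `sage/utils.py`, outside the hunks shown here. From how it replaces the deleted properties above (and the identical deletion in `RetrievalService` below), it plausibly looks like this; the import paths for `get_embedder`/`get_client` are assumptions inferred from the truncated old imports:

```python
# Plausible shape of LazyServiceMixin, inferred from the deleted properties;
# the real definition is not shown in this diff.
from sage.adapters.embeddings import get_embedder  # assumed import path
from sage.adapters.vector_store import get_client  # assumed import path

class LazyServiceMixin:
    """Lazy, cached access to the embedder and Qdrant client."""

    _embedder = None
    _client = None

    @property
    def embedder(self):
        # Created on first access, then reused for the object's lifetime.
        if self._embedder is None:
            self._embedder = get_embedder()
        return self._embedder

    @property
    def client(self):
        if self._client is None:
            self._client = get_client()
        return self._client
```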
sage/services/evaluation.py
CHANGED
@@ -21,6 +21,7 @@ from typing import Callable
| 21 |   import numpy as np
| 22 |
| 23 |   from sage.core import EvalCase, EvalResult, MetricsReport
| 24 | + from sage.utils import normalize_vectors
| 25 |
| 26 |
| 27 |   # Core ranking metrics

@@ -108,10 +109,7 @@ def intra_list_diversity(embeddings: np.ndarray) -> float:
| 109 |       if n < 2:
| 110 |           return 0.0
| 111 |
| 111 | -     [line truncated in source]
| 112 | -     norms = np.where(norms == 0, 1, norms)
| 113 | -     normalized = embeddings / norms
| 114 | -
| 112 | +     normalized = normalize_vectors(embeddings)
| 113 |       similarities = normalized @ normalized.T
| 114 |       distances = 1 - similarities
| 115 |       upper_tri = np.triu(distances, k=1)
sage/services/explanation.py
CHANGED
@@ -5,10 +5,10 @@ Orchestrates LLM-based explanation generation with evidence quality gates
| 5 |   and post-generation verification.
| 6 |   """
| 7 |
| 8 | - import time
| 9 | -
| 8 |   from sage.adapters.llm import LLMClient, get_llm_client
| 9 | + from sage.api.metrics import observe_llm_duration
| 10 |   from sage.config import get_logger
| 11 | + from sage.utils import extract_evidence, timed_operation
| 12 |   from sage.core import (
| 13 |       CitationVerificationResult,
| 14 |       EvidenceQuality,

@@ -67,13 +67,13 @@ def _build_refusal_result(
| 67 |   ) -> ExplanationResult:
| 68 |       """Build an ExplanationResult for a quality gate refusal."""
| 69 |       refusal = generate_refusal_message(query, quality)
| 70 | -
| 70 | +     evidence_texts, evidence_ids = extract_evidence(product.evidence, max_evidence)
| 71 |       return ExplanationResult(
| 72 |           explanation=refusal,
| 73 |           product_id=product.product_id,
| 74 |           query=query,
| 75 | -         evidence_texts=[truncated in source]
| 76 | -         evidence_ids=[truncated in source]
| 75 | +         evidence_texts=evidence_texts,
| 76 | +         evidence_ids=evidence_ids,
| 77 |           tokens_used=0,
| 78 |           model="quality_gate_refusal",
| 79 |       )

@@ -113,17 +113,12 @@ class Explainer:
| 113 |           build_explanation_prompt(query, product, max_evidence)
| 114 |       )
| 115 |
| 116-120 | - [truncated in source: old inline t0 = time.perf_counter() timing plus the generate() call]
| 121 | -     logger.info(
| 122 | -         "LLM generation for %s: %.0fms, %d tokens",
| 123 | -         product.product_id,
| 124 | -         (time.perf_counter() - t0) * 1000,
| 125 | -         tokens,
| 126 | -     )
| 116 | +     with timed_operation("LLM generation", logger, observe_llm_duration):
| 117 | +         explanation, tokens = self.client.generate(
| 118 | +             system=system_prompt,
| 119 | +             user=user_prompt,
| 120 | +         )
| 121 | +     logger.info("Generated for %s: %d tokens", product.product_id, tokens)
| 122 |
| 123 |       return explanation, tokens, evidence_texts, evidence_ids, user_prompt
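`timed_operation(name, logger, observe_fn)` is used as a context manager here and in the retrieval service below, but its body falls outside the hunks shown. A minimal sketch consistent with that call shape; the real helper in `sage/utils.py` may differ:

```python
# Minimal sketch of a timed_operation context manager: time the block,
# log the duration, and feed it to an optional metrics observer.
import logging
import time
from contextlib import contextmanager
from typing import Callable, Iterator

@contextmanager
def timed_operation(
    name: str,
    logger: logging.Logger,
    observe: Callable[[float], None] | None = None,
) -> Iterator[None]:
    t0 = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - t0
        logger.info("%s: %.0fms", name, elapsed * 1000)
        if observe is not None:
            observe(elapsed)  # e.g. a Prometheus histogram's observe()
```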
sage/services/retrieval.py
CHANGED
@@ -13,11 +13,11 @@ Aggregation strategies for chunk-to-product scoring:
| 13 |
| 14 |   from __future__ import annotations
| 15 |
| 16 | - import time
| 16 |   from typing import TYPE_CHECKING
| 17 |
| 19 | - from sage.adapters.[truncated in source]
| 20 | - from sage.[truncated in source]
| 18 | + from sage.adapters.vector_store import search
| 19 | + from sage.api.metrics import observe_embedding_duration, observe_retrieval_duration
| 20 | + from sage.utils import LazyServiceMixin, timed_operation
| 21 |   from sage.core import (
| 22 |       AggregationMethod,
| 23 |       ProductScore,

@@ -54,11 +54,12 @@ DEFAULT_SIMILARITY_WEIGHT = 0.8  # alpha: weight for semantic similarity
| 54 |   DEFAULT_RATING_WEIGHT = 0.2  # beta: weight for normalized rating
| 55 |
| 56 |
| 57 | - class RetrievalService:
| 57 | + class RetrievalService(LazyServiceMixin):
| 58 |       """
| 59 |       Service for retrieving and ranking product recommendations.
| 60 |
| 61 |       Coordinates between embedder, vector store, and aggregation logic.
| 62 | +     Uses LazyServiceMixin for on-demand embedder and client initialization.
| 63 |       """
| 64 |
| 65 |       def __init__(

@@ -88,20 +89,6 @@ class RetrievalService:
| 89 |           self._embedder = embedder
| 90 |           self._client = client
| 91 |
| 91 | -     @property
| 92 | -     def embedder(self):
| 93 | -         """Lazy-load embedder."""
| 94 | -         if self._embedder is None:
| 95 | -             self._embedder = get_embedder()
| 96 | -         return self._embedder
| 97 | -
| 98 | -     @property
| 99 | -     def client(self):
| 100 | -        """Lazy-load Qdrant client."""
| 101 | -        if self._client is None:
| 102 | -            self._client = get_client()
| 103 | -        return self._client
| 104 | -
| 92 |       def retrieve_chunks(
| 93 |           self,
| 94 |           query: str,

@@ -126,23 +113,18 @@ class RetrievalService:
| 113 |           limit = limit or self.candidate_limit
| 114 |
| 115 |           if query_embedding is None:
| 129-139 | - [truncated in source: old inline embed + search call with perf_counter timing]
| 140 | -             )
| 141 | -             logger.info(
| 142 | -                 "Qdrant search: %.0fms, %d results",
| 143 | -                 (time.perf_counter() - t0) * 1000,
| 144 | -                 len(results),
| 145 | -             )
| 116 | +             with timed_operation("Embedding", logger, observe_embedding_duration):
| 117 | +                 query_embedding = self.embedder.embed_single_query(query)
| 118 | +
| 119 | +             with timed_operation("Qdrant search", logger, observe_retrieval_duration):
| 120 | +                 results = search(
| 121 | +                     client=self.client,
| 122 | +                     query_embedding=query_embedding.tolist(),
| 123 | +                     collection_name=self.collection_name,
| 124 | +                     limit=limit,
| 125 | +                     min_rating=min_rating,
| 126 | +                 )
| 127 | +             logger.info("Retrieved %d raw results", len(results))
| 128 |
| 129 |           chunks = []
| 130 |           for r in results:
sage/utils.py
CHANGED

@@ -2,9 +2,319 @@
 Shared utility functions.
 """
 
+from __future__ import annotations
+
+import importlib
 import json
+import threading
+import time
+from contextlib import contextmanager
 from datetime import datetime
+from functools import wraps
 from pathlib import Path
+from types import ModuleType
+from typing import TYPE_CHECKING, Callable, Generator, TypeVar
+
+if TYPE_CHECKING:
+    import logging
+
+    import numpy as np
+
+T = TypeVar("T")
+
+
+# ---------------------------------------------------------------------------
+# Import Utilities
+# ---------------------------------------------------------------------------
+
+
+def require_import(
+    package: str,
+    *,
+    pip_name: str | None = None,
+    extras: str | None = None,
+) -> ModuleType:
+    """Import a package with a standardized error message.
+
+    Centralizes the try-import pattern used across adapters to provide
+    consistent, helpful error messages when optional dependencies are missing.
+
+    Usage:
+        torch = require_import("torch")
+        qdrant = require_import("qdrant_client", pip_name="qdrant-client")
+        st = require_import("sentence_transformers", pip_name="sentence-transformers")
+
+    Args:
+        package: The Python package name to import.
+        pip_name: The pip install name if different from package name.
+        extras: Optional extras to include (e.g., "[api]").
+
+    Returns:
+        The imported module.
+
+    Raises:
+        ImportError: With a helpful message including install command.
+    """
+    try:
+        return importlib.import_module(package)
+    except ImportError as e:
+        install_name = pip_name or package
+        if extras:
+            install_name = f"{install_name}{extras}"
+        raise ImportError(
+            f"{package} package required. Install with: pip install {install_name}"
+        ) from e
+
+
+def require_imports(*packages: str | tuple[str, str]) -> list[ModuleType]:
+    """Import multiple packages with standardized error messages.
+
+    Usage:
+        torch, transformers = require_imports("torch", "transformers")
+        qdrant, = require_imports(("qdrant_client", "qdrant-client"))
+
+    Args:
+        packages: Package names or (package, pip_name) tuples.
+
+    Returns:
+        List of imported modules in the same order.
+
+    Raises:
+        ImportError: With a helpful message for the first missing package.
+    """
+    modules = []
+    for pkg in packages:
+        if isinstance(pkg, tuple):
+            package, pip_name = pkg
+            modules.append(require_import(package, pip_name=pip_name))
+        else:
+            modules.append(require_import(pkg))
+    return modules
+
+
+# ---------------------------------------------------------------------------
+# Lazy Loading Utilities
+# ---------------------------------------------------------------------------
+
+
+class LazyServiceMixin:
+    """Mixin providing lazy-loaded embedder and Qdrant client properties.
+
+    Use this mixin in services that need on-demand access to the embedder
+    and/or Qdrant client. Avoids duplicating the lazy-load pattern.
+
+    Usage:
+        class MyService(LazyServiceMixin):
+            def __init__(self, client=None, embedder=None):
+                self._client = client
+                self._embedder = embedder
+
+            def do_something(self):
+                # Uses lazy-loaded properties from mixin
+                results = self.client.search(...)
+                embedding = self.embedder.embed_single_query(...)
+
+    The mixin expects _client and _embedder instance attributes to be set
+    (can be None for lazy initialization).
+    """
+
+    _client: object | None
+    _embedder: object | None
+
+    @property
+    def embedder(self):
+        """Lazy-load the E5 embedder."""
+        if getattr(self, "_embedder", None) is None:
+            from sage.adapters.embeddings import get_embedder
+
+            self._embedder = get_embedder()
+        return self._embedder
+
+    @property
+    def client(self):
+        """Lazy-load the Qdrant client."""
+        if getattr(self, "_client", None) is None:
+            from sage.adapters.vector_store import get_client
+
+            self._client = get_client()
+        return self._client
+
+
+# ---------------------------------------------------------------------------
+# Singleton Utilities
+# ---------------------------------------------------------------------------
+
+
+def thread_safe_singleton(factory_fn: Callable[[], T]) -> Callable[[], T]:
+    """Decorator for thread-safe lazy singleton initialization.
+
+    Usage:
+        @thread_safe_singleton
+        def get_embedder():
+            return E5Embedder()
+
+        # Later:
+        embedder = get_embedder()  # Creates on first call, returns cached thereafter
+
+    Args:
+        factory_fn: Zero-argument callable that creates the instance.
+
+    Returns:
+        A wrapper function that returns the singleton instance.
+    """
+    instance: T | None = None
+    lock = threading.Lock()
+
+    @wraps(factory_fn)
+    def get_instance() -> T:
+        nonlocal instance
+        if instance is None:
+            with lock:
+                if instance is None:
+                    instance = factory_fn()
+        return instance
+
+    return get_instance
+
+
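`thread_safe_singleton` is classic double-checked locking: the unlocked `is None` test keeps repeat calls lock-free, and the second test inside the lock stops two threads that raced past the first check from both running the factory. A self-contained sketch against the decorator above, assuming `sage.utils` is importable (the counter exists only to observe how many times the factory runs):

    import threading

    from sage.utils import thread_safe_singleton

    calls = 0

    @thread_safe_singleton
    def get_resource() -> dict:
        global calls
        calls += 1  # factory body: should execute exactly once
        return {"id": calls}

    threads = [threading.Thread(target=get_resource) for _ in range(8)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    assert calls == 1
    assert get_resource() is get_resource()  # same cached instance
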
+@contextmanager
+def timed_operation(
+    name: str,
+    logger: logging.Logger | None = None,
+    metrics_observer: Callable[[float], None] | None = None,
+    log_format: str = "%s: %.0fms",
+) -> Generator[None, None, None]:
+    """Context manager for timing operations with optional logging and metrics.
+
+    Usage:
+        with timed_operation("Embedding", logger, observe_embedding_duration):
+            result = embedder.embed(query)
+
+    Args:
+        name: Operation name for logging.
+        logger: Logger instance for info-level timing output.
+        metrics_observer: Callback that receives duration in seconds.
+        log_format: Format string for log message (name, ms).
+
+    Yields:
+        None. Duration is computed and reported on exit.
+    """
+    t0 = time.perf_counter()
+    try:
+        yield
+    finally:
+        duration = time.perf_counter() - t0
+        if metrics_observer is not None:
+            metrics_observer(duration)
+        if logger is not None:
+            logger.info(log_format, name, duration * 1000)
+
+
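Because the reporting happens inside `finally`, the duration is logged and observed even when the body raises; note also that the observer receives seconds while the log line prints milliseconds. A quick sketch with a plain list standing in for a Prometheus histogram:

    import logging
    import time

    from sage.utils import timed_operation

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("demo")
    durations: list[float] = []  # stand-in metrics observer

    with timed_operation("Sleep", logger, durations.append):
        time.sleep(0.05)

    # durations[0] is ~0.05 (seconds); the log line reads "Sleep: 50ms".
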
+def normalize_text(text: str) -> str:
+    """Normalize text for fuzzy matching.
+
+    Converts to lowercase and collapses whitespace.
+
+    Args:
+        text: Text to normalize.
+
+    Returns:
+        Normalized text string.
+    """
+    return " ".join(text.lower().split())
+
+
+def normalize_vectors(vectors: np.ndarray, eps: float = 1e-10) -> np.ndarray:
+    """L2-normalize vectors to unit norm with numerical stability.
+
+    Args:
+        vectors: Array of shape (n, d) or (d,) to normalize.
+        eps: Small constant for numerical stability.
+
+    Returns:
+        Normalized vectors with the same shape as input.
+    """
+    import numpy as np
+
+    if vectors.ndim == 1:
+        norm = np.linalg.norm(vectors) + eps
+        return vectors / norm
+
+    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
+    norms = np.where(norms == 0, 1, norms + eps)
+    return vectors / norms
+
+
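A quick numeric check of both branches; note that in the batch branch an exactly-zero row has its norm replaced by 1, so it maps to itself instead of blowing up toward 1/eps:

    import numpy as np

    from sage.utils import normalize_vectors

    v = normalize_vectors(np.array([3.0, 4.0]))
    assert np.allclose(v, [0.6, 0.8])  # 1-D branch: divides by 5 + eps

    batch = normalize_vectors(np.array([[3.0, 4.0], [0.0, 0.0]]))
    assert np.isclose(np.linalg.norm(batch[0]), 1.0)
    assert np.allclose(batch[1], 0.0)  # zero row stays zero
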
+# ---------------------------------------------------------------------------
+# Evidence Extraction Utilities
+# ---------------------------------------------------------------------------
+
+
+def extract_evidence_texts(
+    chunks: list,
+    max_chunks: int | None = None,
+) -> list[str]:
+    """Extract text content from evidence chunks.
+
+    Centralizes the common pattern: [c.text for c in chunks[:max_chunks]]
+
+    Args:
+        chunks: List of chunk objects with .text attribute (RetrievedChunk, etc.)
+        max_chunks: Optional limit on number of chunks to extract.
+
+    Returns:
+        List of text strings from the chunks.
+    """
+    if max_chunks is not None:
+        chunks = chunks[:max_chunks]
+    return [c.text for c in chunks]
+
+
+def extract_evidence_ids(
+    chunks: list,
+    max_chunks: int | None = None,
+) -> list[str]:
+    """Extract review IDs from evidence chunks.
+
+    Centralizes the common pattern: [c.review_id for c in chunks[:max_chunks]]
+
+    Args:
+        chunks: List of chunk objects with .review_id attribute.
+        max_chunks: Optional limit on number of chunks to extract.
+
+    Returns:
+        List of review ID strings from the chunks.
+    """
+    if max_chunks is not None:
+        chunks = chunks[:max_chunks]
+    return [c.review_id for c in chunks]
+
+
+def extract_evidence(
+    chunks: list,
+    max_chunks: int | None = None,
+) -> tuple[list[str], list[str]]:
+    """Extract both texts and IDs from evidence chunks.
+
+    Convenience function combining extract_evidence_texts and extract_evidence_ids.
+
+    Args:
+        chunks: List of chunk objects with .text and .review_id attributes.
+        max_chunks: Optional limit on number of chunks to extract.
+
+    Returns:
+        Tuple of (texts, ids) lists.
+    """
+    if max_chunks is not None:
+        chunks = chunks[:max_chunks]
+    texts = [c.text for c in chunks]
+    ids = [c.review_id for c in chunks]
+    return texts, ids
+
+
+# ---------------------------------------------------------------------------
+# File Utilities
+# ---------------------------------------------------------------------------
 
 
 def save_results(data: dict, prefix: str, directory: Path | None = None) -> Path:
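The extractors only rely on `.text` and `.review_id` attributes, so any chunk-like object satisfies them; a minimal sketch with a hypothetical stand-in for `RetrievedChunk`:

    from dataclasses import dataclass

    from sage.utils import extract_evidence

    @dataclass
    class Chunk:  # hypothetical stand-in for RetrievedChunk
        text: str
        review_id: str

    chunks = [
        Chunk("battery lasts two days", "r1"),
        Chunk("hinge feels sturdy", "r2"),
        Chunk("fan gets loud", "r3"),
    ]
    texts, ids = extract_evidence(chunks, max_chunks=2)
    assert texts == ["battery lasts two days", "hinge feels sturdy"]
    assert ids == ["r1", "r2"]
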
scripts/demo.py
CHANGED

@@ -17,8 +17,6 @@ import json
 
 from sage.core import AggregationMethod
 from sage.config import FAITHFULNESS_TARGET, get_logger, log_banner, log_section
-from sage.services.explanation import Explainer
-from sage.adapters.hhem import HallucinationDetector
 from sage.services.retrieval import get_candidates
 
 logger = get_logger(__name__)
@@ -45,9 +43,10 @@ def demo_recommendation(query: str, top_k: int = 3, max_evidence: int = 3):
         logger.warning("No products found matching query")
         return None
 
-    # Initialize
-    explainer = Explainer()
-    detector = HallucinationDetector()
+    # Initialize services
+    from scripts.lib.services import get_explanation_services
+
+    explainer, detector = get_explanation_services()
 
     results = []
 
scripts/e2e_success_rate.py
CHANGED

@@ -20,6 +20,7 @@ from datetime import datetime
 
 from sage.config import (
     E2E_EVAL_QUERIES,
+    FAITHFULNESS_TARGET,
     RESULTS_DIR,
     get_logger,
     log_banner,
@@ -103,8 +104,7 @@ class E2EReport:
 
 def run_e2e_evaluation(n_samples: int = 20) -> E2EReport:
     """Run end-to-end success rate evaluation."""
-    from sage.services.explanation import Explainer
-    from sage.adapters.hhem import HallucinationDetector
+    from scripts.lib.services import get_explanation_services
    from sage.services.faithfulness import (
         is_refusal,
         is_mismatch_warning,
@@ -116,8 +116,7 @@ def run_e2e_evaluation(n_samples: int = 20) -> E2EReport:
     log_banner(logger, "END-TO-END SUCCESS RATE EVALUATION")
     logger.info("Samples: %d", len(queries))
 
-    explainer = Explainer()
-    detector = HallucinationDetector()
+    explainer, detector = get_explanation_services()
 
     all_cases: list[CaseResult] = []
     case_id = 0
@@ -290,8 +289,6 @@ def run_e2e_evaluation(n_samples: int = 20) -> E2EReport:
     raw_e2e = n_raw_success / n_total if n_total > 0 else 0
     adjusted_e2e = n_adjusted_success / n_total if n_total > 0 else 0
 
-    target = 0.85
-
     report = E2EReport(
         timestamp=datetime.now().isoformat(),
         n_total=n_total,
@@ -305,9 +302,9 @@ def run_e2e_evaluation(n_samples: int = 20) -> E2EReport:
         hhem_pass_rate=hhem_pass_rate,
         raw_e2e_success_rate=raw_e2e,
         adjusted_e2e_success_rate=adjusted_e2e,
-        target=target,
-        meets_target=adjusted_e2e >= target,
-        gap_to_target=target - adjusted_e2e,
+        target=FAITHFULNESS_TARGET,
+        meets_target=adjusted_e2e >= FAITHFULNESS_TARGET,
+        gap_to_target=FAITHFULNESS_TARGET - adjusted_e2e,
     )
 
     # Print report
@@ -359,7 +356,7 @@ def run_e2e_evaluation(n_samples: int = 20) -> E2EReport:
         n_total,
         adjusted_e2e * 100,
     )
-    logger.info("Target: %.1f%%", target * 100)
+    logger.info("Target: %.1f%%", FAITHFULNESS_TARGET * 100)
     logger.info("Gap to target: %.1f%%", report.gap_to_target * 100)
     logger.info("Meets target: %s", "YES" if report.meets_target else "NO")
 
scripts/eda.py
CHANGED

@@ -313,3 +313,201 @@ print(
 )
 print(f"Data quality issues: {empty_reviews + very_short + duplicate_texts}")
 print(f"\nPlots saved to: {FIGURES_DIR}")
+
+# %% Generate markdown report
+from pathlib import Path
+
+REPORTS_DIR = Path("reports")
+REPORTS_DIR.mkdir(exist_ok=True)
+
+# Compute all stats for report
+raw_total = len(df)
+prepared_total = len(df_prepared)
+unique_users_raw = df["user_id"].nunique()
+unique_items_raw = df["parent_asin"].nunique()
+unique_users_prepared = prepared_stats["unique_users"]
+unique_items_prepared = prepared_stats["unique_items"]
+avg_rating_raw = stats["avg_rating"]
+avg_rating_prepared = prepared_stats["avg_rating"]
+retention_pct = prepared_total / raw_total * 100
+
+median_chars = df["text_length"].median()
+mean_chars = df["text_length"].mean()
+median_tokens = df["estimated_tokens"].median()
+chunking_pct = needs_chunking / len(df) * 100
+
+five_star_pct = rating_counts.get(5, 0) / len(df) * 100
+one_star_pct = rating_counts.get(1, 0) / len(df) * 100
+middle_pct = 100 - five_star_pct - one_star_pct
+
+users_one_review = (user_counts == 1).sum()
+users_one_review_pct = users_one_review / len(user_counts) * 100
+users_5plus = (user_counts >= 5).sum()
+max_user_reviews = user_counts.max()
+
+items_one_review = (item_counts == 1).sum()
+items_one_review_pct = items_one_review / len(item_counts) * 100
+items_5plus = (item_counts >= 5).sum()
+max_item_reviews = item_counts.max()
+
+length_1star = length_by_rating.get(1, 0)
+length_2star = length_by_rating.get(2, 0)
+length_3star = length_by_rating.get(3, 0)
+length_4star = length_by_rating.get(4, 0)
+length_5star = length_by_rating.get(5, 0)
+
+report_content = f"""# Exploratory Data Analysis: Amazon Electronics Reviews
+
+**Dataset:** McAuley-Lab/Amazon-Reviews-2023 (Electronics category)
+**Subset:** {raw_total:,} raw reviews -> {prepared_total:,} after 5-core filtering
+
+---
+
+## Dataset Overview
+
+The Amazon Electronics reviews dataset provides rich user feedback data for building recommendation systems. After standard preprocessing and 5-core filtering (requiring users and items to have at least 5 interactions), the dataset exhibits the characteristic sparsity of real-world recommendation scenarios.
+
+| Metric | Raw | After 5-Core |
+|--------|-----|--------------|
+| Total Reviews | {raw_total:,} | {prepared_total:,} |
+| Unique Users | {unique_users_raw:,} | {unique_users_prepared:,} |
+| Unique Items | {unique_items_raw:,} | {unique_items_prepared:,} |
+| Avg Rating | {avg_rating_raw:.2f} | {avg_rating_prepared:.2f} |
+| Retention | - | {retention_pct:.1f}% |
+
+---
+
+## Rating Distribution
+
+Amazon reviews exhibit a well-known J-shaped distribution, heavily skewed toward 5-star ratings. This reflects both genuine satisfaction and selection bias (dissatisfied customers often don't leave reviews).
+
+![Rating Distribution](figures/rating_distribution.png)
+
+**Key Observations:**
+- 5-star ratings dominate ({five_star_pct:.1f}% of reviews)
+- 1-star reviews form the second largest group ({one_star_pct:.1f}%)
+- Middle ratings (2-4 stars) are relatively rare ({middle_pct:.1f}% combined)
+- This polarization is typical for e-commerce review data
+
+**Implications for Modeling:**
+- Binary classification (positive/negative) may be more robust than regression
+- Rating-weighted aggregation should account for the skewed distribution
+- Evidence from 4-5 star reviews carries stronger positive signal
+
+---
+
+## Review Length Analysis
+
+Review length varies significantly and informs the chunking strategy for the RAG pipeline. Most reviews are short enough to embed directly without chunking.
+
+![Review Length Distribution](figures/review_length_distribution.png)
+
+**Length Statistics:**
+- Median: {median_chars:.0f} characters (~{median_tokens:.0f} tokens)
+- Mean: {mean_chars:.0f} characters (~{mean_chars / 4:.0f} tokens)
+- Reviews exceeding 200 tokens: {chunking_pct:.1f}% (require chunking)
+
+**Chunking Strategy Validation:**
+The tiered chunking approach is well-suited to this distribution:
+- **Short (<200 tokens):** No chunking needed - majority of reviews
+- **Medium (200-500 tokens):** Semantic chunking at topic boundaries
+- **Long (>500 tokens):** Semantic + sliding window fallback
+
+---
+
+## Review Length by Rating
+
+Negative reviews tend to be longer than positive ones. Users who are dissatisfied often provide detailed explanations of issues, while satisfied users may simply express approval.
+
+![Length by Rating](figures/length_by_rating.png)
+
+**Pattern:**
+- 1-star reviews: {length_1star:.0f} chars median
+- 2-3 star reviews: {length_2star:.0f}-{length_3star:.0f} chars median (users explain nuance)
+- 4-star reviews: {length_4star:.0f} chars median
+- 5-star reviews: {length_5star:.0f} chars median
+
+**Implications:**
+- Negative reviews provide richer evidence for issue identification
+- Positive reviews may require multiple chunks for substantive explanations
+- Rating filters (min_rating=4) naturally bias toward shorter evidence
+
+---
+
+## Temporal Distribution
+
+The dataset spans multiple years of reviews, enabling proper temporal train/validation/test splits that prevent data leakage.
+
+![Temporal Distribution](figures/temporal_distribution.png)
+
+**Temporal Split Strategy:**
+- **Train (70%):** Oldest reviews - model learns from historical patterns
+- **Validation (10%):** Middle period - hyperparameter tuning
+- **Test (20%):** Most recent - simulates production deployment
+
+This chronological ordering ensures the model never sees "future" data during training.
+
+---
+
+## User and Item Activity
+
+The long-tail distribution is pronounced: most users write few reviews, and most items receive few reviews. This sparsity is the fundamental challenge recommendation systems address.
+
+![User and Item Activity](figures/user_item_activity.png)
+
+**User Activity:**
+- Users with only 1 review: {users_one_review_pct:.1f}%
+- Users with 5+ reviews: {users_5plus:,}
+- Power user max: {max_user_reviews} reviews
+
+**Item Popularity:**
+- Items with only 1 review: {items_one_review_pct:.1f}%
+- Items with 5+ reviews: {items_5plus:,}
+- Most reviewed item: {max_item_reviews} reviews
+
+**Cold-Start Implications:**
+- Many items have sparse evidence - content-based features are critical
+- User cold-start is common - onboarding preferences help
+- 5-core filtering ensures minimum evidence density for evaluation
+
+---
+
+## Data Quality Assessment
+
+The raw dataset contains several quality issues addressed during preprocessing.
+
+| Issue | Count | Resolution |
+|-------|-------|------------|
+| Missing text | 0 | - |
+| Empty reviews | {empty_reviews} | Removed |
+| Very short (<10 chars) | {very_short:,} | Removed |
+| Duplicate texts | {duplicate_texts:,} | Kept (valid re-purchases) |
+| Invalid ratings | 0 | - |
+
+**Post-Cleaning:**
+- All reviews have valid text content
+- All ratings are in [1, 5] range
+- All user/product identifiers present
+
+---
+
+## Summary
+
+The Amazon Electronics dataset, after 5-core filtering and cleaning, provides a solid foundation for building and evaluating a RAG-based recommendation system:
+
+1. **Scale:** {prepared_total:,} reviews across {unique_users_prepared:,} users and {unique_items_prepared:,} items
+2. **Sparsity:** {100 - retention_pct:.1f}% filtered - realistic for recommendation evaluation
+3. **Quality:** Clean text, valid ratings, proper identifiers
+4. **Temporal:** Supports chronological train/val/test splits
+5. **Content:** Review lengths suit the tiered chunking strategy
+
+The J-shaped rating distribution and long-tail user/item activity are characteristic of real e-commerce data, making this an appropriate benchmark for portfolio demonstration.
+
+---
+
+*Report auto-generated by `scripts/eda.py`. Run `make eda` to regenerate.*
+"""
+
+report_path = REPORTS_DIR / "eda_report.md"
+report_path.write_text(report_content)
+print(f"\nReport generated: {report_path}")
scripts/evaluation.py
CHANGED

@@ -18,19 +18,19 @@ Run from project root.
 """
 
 import argparse
-import json
 from collections.abc import Callable
 from datetime import datetime
 from pathlib import Path
 
 from sage.core import AggregationMethod
+from sage.utils import save_results
 from sage.services.baselines import (
     ItemKNNBaseline,
     PopularityBaseline,
     RandomBaseline,
     load_product_embeddings_from_qdrant,
 )
-from sage.config import RESULTS_DIR, get_logger, log_banner, log_section
+from sage.config import get_logger, log_banner, log_section, log_kv
 from sage.data import load_eval_cases, load_splits
 from sage.services.evaluation import compute_item_popularity, evaluate_recommendations
 from sage.services.retrieval import recommend
@@ -62,31 +62,6 @@ def create_recommend_fn(
     return _recommend
 
 
-def save_results(
-    results: dict, filename: str | None = None, dataset: str | None = None
-) -> Path:
-    """Save evaluation results to JSON file.
-
-    Also writes a fixed-name "latest" file so downstream scripts (e.g.
-    summary.py) can locate the most recent run without globbing.
-    """
-    if filename is None:
-        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-        filename = f"eval_results_{timestamp}.json"
-    filepath = RESULTS_DIR / filename
-    with open(filepath, "w", encoding="utf-8") as f:
-        json.dump(results, f, indent=2)
-
-    # Write latest symlink for the summary script
-    if dataset:
-        stem = Path(dataset).stem  # e.g. "eval_loo_history"
-        latest_path = RESULTS_DIR / f"{stem}_latest.json"
-        with open(latest_path, "w", encoding="utf-8") as f:
-            json.dump(results, f, indent=2)
-
-    return filepath
-
-
 # ============================================================================
 # SECTION: Primary Evaluation
 # ============================================================================
@@ -296,14 +271,7 @@ def run_baseline_comparison(cases, train_records, all_products, product_embeddin
     def itemknn_recommend(query: str) -> list[str]:
         return itemknn_baseline.recommend(query, top_k=10)
 
-    def rag_recommend(query: str) -> list[str]:
-        recs = recommend(
-            query=query,
-            top_k=10,
-            candidate_limit=100,
-            aggregation=AggregationMethod.MAX,
-        )
-        return [r.product_id for r in recs]
+    rag_recommend = create_recommend_fn(top_k=10, aggregation=AggregationMethod.MAX)
 
     results = {}
     methods = [
@@ -434,8 +402,9 @@ def main():
     if args.baselines:
         run_baseline_comparison(cases, train_records, all_products, item_embeddings)
 
-    # Save results
-
+    # Save results (uses dataset stem as prefix for both timestamped and latest files)
+    prefix = Path(args.dataset).stem
+    results_path = save_results(all_results, prefix)
     logger.info("Results saved to: %s", results_path)
 
     log_banner(logger, "EVALUATION COMPLETE")
scripts/explanation.py
CHANGED

@@ -43,8 +43,7 @@ PRODUCTS_PER_QUERY = 2
 
 def run_basic_tests():
     """Test basic explanation generation and HHEM detection."""
-    from sage.services.explanation import Explainer
-    from sage.adapters.hhem import HallucinationDetector
+    from scripts.lib.services import get_explanation_services
 
     log_banner(logger, "BASIC EXPLANATION TESTS")
     logger.info("Using LLM provider: %s", LLM_PROVIDER)
@@ -71,7 +70,7 @@ def run_basic_tests():
 
     # Generate explanations
     log_section(logger, "2. GENERATING EXPLANATIONS")
-    explainer = Explainer()
+    explainer, detector = get_explanation_services()
     all_explanations = []
 
     for query, products in query_results.items():
@@ -84,7 +83,6 @@ def run_basic_tests():
 
     # Run HHEM
     log_section(logger, "3. HHEM HALLUCINATION DETECTION")
-    detector = HallucinationDetector()
     hhem_results = [
         detector.check_explanation(expl.evidence_texts, expl.explanation)
         for expl in all_explanations
@@ -108,9 +106,7 @@ def run_basic_tests():
     logger.info("Streaming: ")
 
     stream = explainer.generate_explanation_stream(test_query, test_product)
-    chunks = []
-    for token in stream:
-        chunks.append(token)
+    chunks = list(stream)
     logger.info("".join(chunks))
 
     streamed_result = stream.get_complete_result()
scripts/faithfulness.py
CHANGED

@@ -51,8 +51,7 @@ TOP_K_PRODUCTS = 3
 
 def run_evaluation(n_samples: int, run_ragas: bool = False):
     """Run faithfulness evaluation on sample queries."""
-    from sage.services.explanation import Explainer
-    from sage.adapters.hhem import HallucinationDetector
+    from scripts.lib.services import get_explanation_services
 
     queries = EVALUATION_QUERIES[:n_samples]
 
@@ -62,7 +61,7 @@ def run_evaluation(n_samples: int, run_ragas: bool = False):
     # Generate explanations
     log_section(logger, "1. GENERATING EXPLANATIONS")
 
-    explainer = Explainer()
+    explainer, detector = get_explanation_services()
     all_explanations = []
 
     for i, query in enumerate(queries, 1):
@@ -95,7 +94,6 @@ def run_evaluation(n_samples: int, run_ragas: bool = False):
     # Run HHEM
     log_section(logger, "2. HHEM HALLUCINATION DETECTION")
 
-    detector = HallucinationDetector()
     hhem_results = [
         detector.check_explanation(expl.evidence_texts, expl.explanation)
         for expl in all_explanations
@@ -204,13 +202,11 @@ def run_evaluation(n_samples: int, run_ragas: bool = False):
 
 def run_failure_analysis():
     """Analyze failure cases to identify root causes."""
-    from sage.services.explanation import Explainer
-    from sage.adapters.hhem import HallucinationDetector
+    from scripts.lib.services import get_explanation_services
 
     log_banner(logger, "FAILURE CASE ANALYSIS")
 
-    explainer = Explainer()
-    detector = HallucinationDetector()
+    explainer, detector = get_explanation_services()
 
     all_cases = []
     case_id = 0
scripts/human_eval.py
CHANGED

@@ -100,11 +100,12 @@ def _select_config_queries(exclude: set[str], target: int = 15) -> list[str]:
     return selected
 
 
-def generate_samples(force: bool = False):
+def generate_samples(force: bool = False, seed: int = 42):
     """Generate recommendation+explanation samples for human evaluation."""
+    import random
+
     from sage.services.retrieval import get_candidates
-    from sage.services.explanation import Explainer
-    from sage.adapters.hhem import HallucinationDetector
+    from scripts.lib.services import get_explanation_services
 
     # Protect existing rated samples from accidental overwrite
     if SAMPLES_FILE.exists() and not force:
@@ -124,11 +125,18 @@ def generate_samples(force: bool = False, seed: int = 42):
     RESULTS_DIR.mkdir(parents=True, exist_ok=True)
 
     log_banner(logger, "GENERATING HUMAN EVAL SAMPLES")
+    logger.info("Random seed: %d", seed)
+
+    # Set seed for reproducibility
+    random.seed(seed)
 
     # Select diverse query set
     natural = _select_diverse_natural_queries(35)
     config = _select_config_queries(set(natural), 15)
     all_queries = natural + config
+
+    # Shuffle with seeded random for reproducibility
+    random.shuffle(all_queries)
     logger.info(
         "Queries: %d natural + %d config = %d total",
         len(natural),
@@ -146,8 +154,7 @@ def generate_samples(force: bool = False, seed: int = 42):
     )
 
     # Initialize services
-    explainer = Explainer()
-    detector = HallucinationDetector()
+    explainer, detector = get_explanation_services()
 
     samples = []
     for i, query in enumerate(all_queries, 1):
@@ -496,13 +503,22 @@ def main():
         action="store_true",
         help="Overwrite existing rated samples (with --generate)",
     )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="Random seed for query selection (with --generate)",
+    )
     args = parser.parse_args()
 
     if args.force and not args.generate:
         parser.error("--force can only be used with --generate")
 
+    if args.seed != 42 and not args.generate:
+        parser.error("--seed can only be used with --generate")
+
     if args.generate:
-        generate_samples(force=args.force)
+        generate_samples(force=args.force, seed=args.seed)
     elif args.annotate:
         annotate_samples()
     elif args.analyze:
scripts/lib/__init__.py
ADDED

@@ -0,0 +1,5 @@
+"""Shared utilities for scripts."""
+
+from scripts.lib.services import get_explanation_services
+
+__all__ = ["get_explanation_services"]
scripts/lib/services.py
ADDED

@@ -0,0 +1,24 @@
+"""Shared service initialization for scripts."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from sage.adapters.hhem import HallucinationDetector
+    from sage.services.explanation import Explainer
+
+
+def get_explanation_services() -> tuple[Explainer, HallucinationDetector]:
+    """Initialize Explainer and HallucinationDetector.
+
+    Centralizes the common pattern of creating both services together.
+    Import is deferred to avoid loading heavy models until needed.
+
+    Returns:
+        Tuple of (Explainer, HallucinationDetector) instances.
+    """
+    from sage.adapters.hhem import HallucinationDetector
+    from sage.services.explanation import Explainer
+
+    return Explainer(), HallucinationDetector()
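Since both heavy imports live inside the function body, importing `scripts.lib` stays cheap and model weights only load on the first call. Typical call-site shape, mirroring the scripts above (the query/product inputs come from retrieval and are omitted here):

    from scripts.lib.services import get_explanation_services

    explainer, detector = get_explanation_services()  # Explainer + HHEM load here

    # Downstream the pair is used together, e.g. (per scripts/explanation.py):
    #   stream = explainer.generate_explanation_stream(query, product)
    #   expl = stream.get_complete_result()
    #   detector.check_explanation(expl.evidence_texts, expl.explanation)
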
scripts/load_test.py
ADDED

@@ -0,0 +1,230 @@
+#!/usr/bin/env python3
+"""
+Load test script for Sage API.
+
+Runs sequential requests and reports p50, p95, p99 latency.
+
+Usage:
+    # Start the API first:
+    python -m sage.api.run
+
+    # Then run the load test:
+    python scripts/load_test.py --requests 100 --url http://localhost:8000
+
+    # Test without explanations (faster):
+    python scripts/load_test.py --no-explain
+
+David's target: p99 < 500ms
+"""
+
+import argparse
+import statistics
+import sys
+import time
+
+import httpx
+
+
+# Test queries covering different scenarios
+QUERIES = [
+    "wireless headphones for working out",
+    "laptop for video editing under $1500",
+    "best phone case for iPhone",
+    "comfortable running shoes",
+    "noise canceling earbuds",
+    "gaming keyboard mechanical",
+    "portable charger high capacity",
+    "bluetooth speaker waterproof",
+    "monitor for programming",
+    "ergonomic office chair",
+]
+
+
+def percentile(data: list[float], p: float) -> float:
+    """Calculate the p-th percentile of data."""
+    if not data:
+        return 0.0
+    sorted_data = sorted(data)
+    k = (len(sorted_data) - 1) * (p / 100)
+    f = int(k)
+    c = f + 1
+    if c >= len(sorted_data):
+        return sorted_data[-1]
+    return sorted_data[f] + (sorted_data[c] - sorted_data[f]) * (k - f)
+
+
+def run_load_test(
+    base_url: str,
+    num_requests: int,
+    explain: bool,
+    timeout: float,
+) -> dict:
+    """Run load test and return metrics."""
+    latencies: list[float] = []
+    errors = 0
+    cache_hits = 0
+
+    client = httpx.Client(timeout=timeout)
+    endpoint = f"{base_url}/recommend"
+
+    print(f"\nRunning {num_requests} requests to {endpoint}")
+    print(f"  explain={explain}, timeout={timeout}s")
+    print("-" * 50)
+
+    for i in range(num_requests):
+        query = QUERIES[i % len(QUERIES)]
+        payload = {
+            "query": query,
+            "k": 3,
+            "explain": explain,
+        }
+
+        try:
+            start = time.perf_counter()
+            resp = client.post(endpoint, json=payload)
+            elapsed = time.perf_counter() - start
+
+            if resp.status_code == 200:
+                latencies.append(elapsed * 1000)  # Convert to ms
+
+                # Check for cache hit (response time < 100ms typically indicates cache)
+                if elapsed < 0.1:
+                    cache_hits += 1
+            else:
+                errors += 1
+                print(f"  [{i + 1}] Error: {resp.status_code} - {resp.text[:100]}")
+
+        except Exception as e:
+            errors += 1
+            print(f"  [{i + 1}] Exception: {e}")
+
+        # Progress indicator
+        if (i + 1) % 10 == 0:
+            print(f"  Completed {i + 1}/{num_requests} requests...")
+
+    client.close()
+
+    # Calculate statistics
+    if latencies:
+        results = {
+            "total_requests": num_requests,
+            "successful": len(latencies),
+            "errors": errors,
+            "cache_hits": cache_hits,
+            "min_ms": min(latencies),
+            "max_ms": max(latencies),
+            "mean_ms": statistics.mean(latencies),
+            "median_ms": statistics.median(latencies),
+            "p50_ms": percentile(latencies, 50),
+            "p95_ms": percentile(latencies, 95),
+            "p99_ms": percentile(latencies, 99),
+            "stdev_ms": statistics.stdev(latencies) if len(latencies) > 1 else 0,
+        }
+    else:
+        results = {
+            "total_requests": num_requests,
+            "successful": 0,
+            "errors": errors,
+            "cache_hits": 0,
+        }
+
+    return results
+
+
+def print_results(results: dict, target_p99_ms: float = 500.0) -> None:
+    """Print formatted results."""
+    print("\n" + "=" * 50)
+    print("LOAD TEST RESULTS")
+    print("=" * 50)
+
+    print(f"\nRequests: {results['successful']}/{results['total_requests']} successful")
+    print(f"Errors: {results['errors']}")
+    print(f"Cache hits: {results.get('cache_hits', 0)}")
+
+    if results["successful"] > 0:
+        print("\nLatency (ms):")
+        print(f"  Min:    {results['min_ms']:.1f}")
+        print(f"  Max:    {results['max_ms']:.1f}")
+        print(f"  Mean:   {results['mean_ms']:.1f}")
+        print(f"  Median: {results['median_ms']:.1f}")
+        print(f"  StdDev: {results['stdev_ms']:.1f}")
+
+        print("\nPercentiles (ms):")
+        print(f"  p50: {results['p50_ms']:.1f}")
+        print(f"  p95: {results['p95_ms']:.1f}")
+        print(f"  p99: {results['p99_ms']:.1f}")
+
+        # Target check
+        p99 = results["p99_ms"]
+        if p99 <= target_p99_ms:
+            print(f"\n Target p99 < {target_p99_ms}ms: PASS ({p99:.1f}ms)")
+        else:
+            print(f"\n Target p99 < {target_p99_ms}ms: FAIL ({p99:.1f}ms)")
+            print(
+                "  Bottleneck: Likely LLM generation (check sage_llm_duration_seconds)"
+            )
+
+    print("\n" + "=" * 50)
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Load test Sage API")
+    parser.add_argument(
+        "--url",
+        default="http://localhost:8000",
+        help="Base URL of the API (default: http://localhost:8000)",
+    )
+    parser.add_argument(
+        "--requests",
+        type=int,
+        default=100,
+        help="Number of requests to send (default: 100)",
+    )
+    parser.add_argument(
+        "--no-explain",
+        action="store_true",
+        help="Disable explanations (faster, tests retrieval only)",
+    )
+    parser.add_argument(
+        "--timeout",
+        type=float,
+        default=30.0,
+        help="Request timeout in seconds (default: 30)",
+    )
+    parser.add_argument(
+        "--target-p99",
+        type=float,
+        default=500.0,
+        help="Target p99 latency in ms (default: 500)",
+    )
+
+    args = parser.parse_args()
+
+    # Quick health check
+    try:
+        resp = httpx.get(f"{args.url}/health", timeout=5.0)
+        if resp.status_code != 200:
+            print(f"API health check failed: {resp.status_code}")
+            sys.exit(1)
+        health = resp.json()
+        print(f"API Status: {health.get('status', 'unknown')}")
+        print(
+            f"Qdrant: {'connected' if health.get('qdrant_connected') else 'disconnected'}"
+        )
+        print(f"LLM: {'available' if health.get('llm_reachable') else 'unavailable'}")
+    except Exception as e:
+        print(f"Cannot connect to API at {args.url}: {e}")
+        sys.exit(1)
+
+    results = run_load_test(
+        base_url=args.url,
+        num_requests=args.requests,
+        explain=not args.no_explain,
+        timeout=args.timeout,
+    )
+
+    print_results(results, target_p99_ms=args.target_p99)
+
+
+if __name__ == "__main__":
+    main()
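The `percentile` helper linearly interpolates between the two nearest order statistics instead of snapping to one of them. A worked example for p95 over four samples, assuming `scripts` is importable as a package from the project root and httpx is installed (the module imports it at top level):

    from scripts.load_test import percentile

    data = [100.0, 200.0, 300.0, 400.0]
    # k = (4 - 1) * 0.95 = 2.85  ->  f = 2, c = 3
    # result = 300 + (400 - 300) * 0.85 = 385.0
    assert abs(percentile(data, 95) - 385.0) < 1e-9
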
scripts/pipeline.py
CHANGED

@@ -26,6 +26,7 @@ from sage.config import (
     CHARS_PER_TOKEN,
     DEV_SUBSET_SIZE,
     DATA_DIR,
+    EMBEDDING_MODEL,
     get_logger,
     log_banner,
     log_section,
@@ -68,7 +69,7 @@ def run_tokenizer_validation():
     logger.info("Loaded reviews and sampled 500", extra={"total": len(df)})
     logger.info("Loading E5 tokenizer...")
 
-    tokenizer = AutoTokenizer.from_pretrained(
+    tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL)
 
     ratios = []
     for text in sample:
scripts/sanity_checks.py
CHANGED

@@ -17,14 +17,16 @@ Usage:
 Run from project root.
 """
 
+from __future__ import annotations
+
 import argparse
 from dataclasses import dataclass
+from typing import TYPE_CHECKING
 
 import numpy as np
 
 from sage.core import AggregationMethod, ProductScore, RetrievedChunk
 from sage.config import (
-    DATA_DIR,
     EVALUATION_QUERIES,
     get_logger,
     log_banner,
@@ -32,10 +34,11 @@ from sage.config import (
 )
 from sage.services.retrieval import get_candidates
 
-RESULTS_DIR.mkdir(exist_ok=True)
+if TYPE_CHECKING:
+    from sage.adapters.hhem import HallucinationDetector
+    from sage.services.explanation import Explainer
 
+logger = get_logger(__name__)
 
 # ============================================================================
@@ -43,16 +46,10 @@ RESULTS_DIR.mkdir(exist_ok=True)
 # ============================================================================
 
 
-def run_spot_check():
+def run_spot_check(explainer: Explainer, detector: HallucinationDetector):
     """Manual spot-check of explanations vs evidence."""
-    from sage.services.explanation import Explainer
-    from sage.adapters.hhem import HallucinationDetector
-
     log_banner(logger, "SPOT-CHECK: Manual Inspection", width=70)
 
-    explainer = Explainer()
-    detector = HallucinationDetector()
-
     results = []
     queries = EVALUATION_QUERIES[:5]
 
@@ -94,16 +91,10 @@ def run_spot_check():
 # ============================================================================
 
 
-def run_adversarial_tests():
+def run_adversarial_tests(explainer: Explainer, detector: HallucinationDetector):
     """Test with contradictory evidence."""
-    from sage.services.explanation import Explainer
-    from sage.adapters.hhem import HallucinationDetector
-
     log_banner(logger, "ADVERSARIAL: Contradictory Evidence", width=70)
 
-    explainer = Explainer()
-    detector = HallucinationDetector()
-
     cases = [
         {
             "name": "Battery Contradiction",
@@ -169,16 +160,10 @@ def run_adversarial_tests():
 # ============================================================================
 
 
-def run_empty_context_tests():
+def run_empty_context_tests(explainer: Explainer, detector: HallucinationDetector):
     """Test graceful refusal with irrelevant evidence."""
-    from sage.services.explanation import Explainer
-    from sage.adapters.hhem import HallucinationDetector
-
     log_banner(logger, "EMPTY CONTEXT: Graceful Refusal", width=70)
 
-    explainer = Explainer()
-    detector = HallucinationDetector()
-
     cases = [
         {
             "name": "Irrelevant",
@@ -250,16 +235,10 @@ class CalibrationSample:
     hhem_score: float
 
 
-def run_calibration_check():
+def run_calibration_check(explainer: Explainer, detector: HallucinationDetector):
     """Analyze confidence vs faithfulness correlation."""
-    from sage.services.explanation import Explainer
-    from sage.adapters.hhem import HallucinationDetector
-
     log_banner(logger, "CALIBRATION: Confidence vs Faithfulness", width=70)
 
-    explainer = Explainer()
-    detector = HallucinationDetector()
-
     samples = []
     queries = EVALUATION_QUERIES[:15]
 
@@ -330,6 +309,9 @@ def run_calibration_check():
 
 
 def main():
+    from sage.adapters.hhem import HallucinationDetector
+    from sage.services.explanation import Explainer
+
     parser = argparse.ArgumentParser(description="Run pipeline sanity checks")
     parser.add_argument(
         "--section",
@@ -340,14 +322,18 @@ def main():
     )
     args = parser.parse_args()
 
+    # Initialize services once
+    explainer = Explainer()
+    detector = HallucinationDetector()
+
     if args.section in ("all", "spot"):
-        run_spot_check()
+        run_spot_check(explainer, detector)
     if args.section in ("all", "adversarial"):
-        run_adversarial_tests()
+        run_adversarial_tests(explainer, detector)
    if args.section in ("all", "empty"):
-        run_empty_context_tests()
+        run_empty_context_tests(explainer, detector)
     if args.section in ("all", "calibration"):
-        run_calibration_check()
+        run_calibration_check(explainer, detector)
 
     log_banner(logger, "SANITY CHECKS COMPLETE", width=70)
 
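The sanity-check refactor is a straightforward dependency-injection pass: each `run_*` section used to import and construct its own `Explainer` and `HallucinationDetector`, so `--section all` paid the model-loading cost four times. Now `main()` builds both once and threads them through, while the `TYPE_CHECKING` block keeps the type annotations from forcing heavy imports at module load. The shape of the pattern, condensed from the diff above (the body of `run_check` is elided):

```python
# Condensed sketch of the dependency-injection pattern this diff applies.
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:  # annotation-only imports: nothing heavy loads at import time
    from sage.adapters.hhem import HallucinationDetector
    from sage.services.explanation import Explainer


def run_check(explainer: Explainer, detector: HallucinationDetector) -> None:
    ...  # each section receives the shared instances instead of building its own


def main() -> None:
    # Real imports deferred to main(), so the model-backed classes load exactly once
    from sage.adapters.hhem import HallucinationDetector
    from sage.services.explanation import Explainer

    run_check(Explainer(), HallucinationDetector())
```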
tests/conftest.py
ADDED

@@ -0,0 +1,71 @@
+"""Shared pytest fixtures for Sage tests."""
+
+import pytest
+
+from sage.core.models import ProductScore, RetrievedChunk
+
+
+@pytest.fixture
+def make_chunk():
+    """Factory fixture for creating RetrievedChunk instances."""
+
+    def _make_chunk(
+        product_id: str = "P1",
+        score: float = 0.85,
+        rating: float = 4.5,
+        text: str | None = None,
+        review_id: str | None = None,
+    ) -> RetrievedChunk:
+        return RetrievedChunk(
+            text=text or f"Review for {product_id}",
+            score=score,
+            product_id=product_id,
+            rating=rating,
+            review_id=review_id or f"rev_{product_id}",
+        )
+
+    return _make_chunk
+
+
+@pytest.fixture
+def make_product():
+    """Factory fixture for creating ProductScore instances with evidence."""
+
+    def _make_product(
+        product_id: str = "P1",
+        score: float = 0.85,
+        n_chunks: int = 2,
+        avg_rating: float = 4.5,
+        text_len: int = 200,
+    ) -> ProductScore:
+        evidence = [
+            RetrievedChunk(
+                text="x" * text_len,
+                score=score - i * 0.01,
+                product_id=product_id,
+                rating=avg_rating,
+                review_id=f"rev_{i}",
+            )
+            for i in range(n_chunks)
+        ]
+        return ProductScore(
+            product_id=product_id,
+            score=score,
+            chunk_count=n_chunks,
+            avg_rating=avg_rating,
+            evidence=evidence,
+        )
+
+    return _make_product
+
+
+@pytest.fixture
+def sample_chunk(make_chunk) -> RetrievedChunk:
+    """A sample RetrievedChunk for simple tests."""
+    return make_chunk(product_id="P1", score=0.9, rating=4.5, text="Good product")
+
+
+@pytest.fixture
+def sample_product(make_product) -> ProductScore:
+    """A sample ProductScore for simple tests."""
+    return make_product(product_id="P1", score=0.9, n_chunks=2, avg_rating=4.5)
tests/test_aggregation.py
CHANGED

@@ -3,71 +3,60 @@
 import pytest
 
 from sage.core.aggregation import aggregate_chunks_to_products, apply_weighted_ranking
-from sage.core.models import AggregationMethod, ProductScore
-
-
-def _chunk(product_id: str, score: float, rating: float = 4.5) -> RetrievedChunk:
-    """Helper to build a RetrievedChunk."""
-    return RetrievedChunk(
-        text=f"Review for {product_id}",
-        score=score,
-        product_id=product_id,
-        rating=rating,
-        review_id=f"rev_{product_id}",
-    )
+from sage.core.models import AggregationMethod, ProductScore
 
 
 class TestAggregateChunksToProducts:
-    def test_single_chunk_per_product(self):
-        chunks = [
+    def test_single_chunk_per_product(self, make_chunk):
+        chunks = [make_chunk("A", 0.9), make_chunk("B", 0.8)]
         products = aggregate_chunks_to_products(chunks, AggregationMethod.MAX)
         assert len(products) == 2
         ids = {p.product_id for p in products}
         assert ids == {"A", "B"}
 
-    def test_max_aggregation(self):
-        chunks = [
+    def test_max_aggregation(self, make_chunk):
+        chunks = [make_chunk("A", 0.9), make_chunk("A", 0.7), make_chunk("A", 0.5)]
         products = aggregate_chunks_to_products(chunks, AggregationMethod.MAX)
         assert len(products) == 1
         assert products[0].score == pytest.approx(0.9)
 
-    def test_mean_aggregation(self):
-        chunks = [
+    def test_mean_aggregation(self, make_chunk):
+        chunks = [make_chunk("A", 0.9), make_chunk("A", 0.7), make_chunk("A", 0.5)]
         products = aggregate_chunks_to_products(chunks, AggregationMethod.MEAN)
         assert len(products) == 1
         assert products[0].score == pytest.approx(0.7, abs=0.01)
 
-    def test_weighted_mean_aggregation(self):
+    def test_weighted_mean_aggregation(self, make_chunk):
         chunks = [
+            make_chunk("A", 0.9, rating=5.0),
+            make_chunk("A", 0.5, rating=1.0),
         ]
         products = aggregate_chunks_to_products(chunks, AggregationMethod.WEIGHTED_MEAN)
         assert len(products) == 1
         # Weighted by rating: (0.9*5 + 0.5*1) / (5+1) = 5.0/6 = 0.833
         assert products[0].score == pytest.approx(0.833, abs=0.01)
 
-    def test_sorted_by_score_descending(self):
-        chunks = [
+    def test_sorted_by_score_descending(self, make_chunk):
+        chunks = [make_chunk("A", 0.5), make_chunk("B", 0.9), make_chunk("C", 0.7)]
         products = aggregate_chunks_to_products(chunks, AggregationMethod.MAX)
         scores = [p.score for p in products]
         assert scores == sorted(scores, reverse=True)
 
-    def test_chunk_count_tracked(self):
-        chunks = [
+    def test_chunk_count_tracked(self, make_chunk):
+        chunks = [make_chunk("A", 0.9), make_chunk("A", 0.7), make_chunk("B", 0.8)]
         products = aggregate_chunks_to_products(chunks, AggregationMethod.MAX)
         product_a = next(p for p in products if p.product_id == "A")
         product_b = next(p for p in products if p.product_id == "B")
         assert product_a.chunk_count == 2
         assert product_b.chunk_count == 1
 
-    def test_avg_rating_computed(self):
-        chunks = [
+    def test_avg_rating_computed(self, make_chunk):
+        chunks = [make_chunk("A", 0.9, rating=5.0), make_chunk("A", 0.7, rating=3.0)]
         products = aggregate_chunks_to_products(chunks, AggregationMethod.MAX)
         assert products[0].avg_rating == pytest.approx(4.0)
 
-    def test_evidence_preserved(self):
-        chunks = [
+    def test_evidence_preserved(self, make_chunk):
+        chunks = [make_chunk("A", 0.9), make_chunk("A", 0.7)]
         products = aggregate_chunks_to_products(chunks, AggregationMethod.MAX)
         assert len(products[0].evidence) == 2
 
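The `test_weighted_mean_aggregation` case pins the rating-weighted formula down numerically: `(0.9 * 5 + 0.5 * 1) / (5 + 1) = 5.0 / 6 ≈ 0.833`. A standalone restatement of that arithmetic (the real implementation lives in `sage.core.aggregation`; this is only the behavior the assertion fixes):

```python
# Reference arithmetic for WEIGHTED_MEAN as the tests above pin it down --
# not the project's implementation, which lives in sage.core.aggregation.
def weighted_mean(pairs: list[tuple[float, float]]) -> float:
    """Rating-weighted mean of (score, rating) pairs."""
    return sum(s * r for s, r in pairs) / sum(r for _, r in pairs)


assert abs(weighted_mean([(0.9, 5.0), (0.5, 1.0)]) - 0.833) < 0.01
```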
tests/test_api.py
CHANGED

@@ -4,7 +4,7 @@ Uses a test app with mocked state to avoid loading heavy models.
 """
 
 from types import SimpleNamespace
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, patch
 
 import pytest
 from fastapi import FastAPI
@@ -39,10 +39,14 @@ def _make_app(**state_overrides) -> FastAPI:
         avg_semantic_similarity=0.0,
     )
 
+    # Mock explainer with client attribute for health check
+    mock_explainer = MagicMock()
+    mock_explainer.client = MagicMock()
+
     app.state.qdrant = state_overrides.get("qdrant", mock_qdrant)
     app.state.embedder = state_overrides.get("embedder", MagicMock())
     app.state.detector = state_overrides.get("detector", MagicMock())
-    app.state.explainer = state_overrides.get("explainer",
+    app.state.explainer = state_overrides.get("explainer", mock_explainer)
     app.state.cache = state_overrides.get("cache", mock_cache)
 
     return app
@@ -55,118 +59,119 @@ def client():
     return TestClient(app)
 
 
-        original = routes_mod.collection_exists
-        routes_mod.collection_exists = lambda client: True
-        try:
-            resp = c.get("/health")
-            assert resp.status_code == 200
-            data = resp.json()
-            assert data["status"] == "healthy"
-            assert data["qdrant_connected"] is True
-        finally:
-            routes_mod.collection_exists = original
-
-    def test_degraded_when_collection_missing(self):
+@pytest.fixture
+def sample_product() -> ProductScore:
+    """Sample product for recommendation tests."""
+    return ProductScore(
+        product_id="P1",
+        score=0.9,
+        chunk_count=2,
+        avg_rating=4.5,
+        evidence=[
+            RetrievedChunk(
+                text="Good", score=0.9, product_id="P1", rating=4.5, review_id="r1"
+            ),
+        ],
+    )
+
+
+class TestHealthEndpoint:
+    @patch("sage.api.routes.collection_exists", return_value=True)
+    def test_healthy_when_all_components_available(self, mock_collection_exists):
         app = _make_app()
-        finally:
-            routes_mod.collection_exists = original
+        with TestClient(app) as c:
+            resp = c.get("/health")
+            assert resp.status_code == 200
+            data = resp.json()
+            assert data["status"] == "healthy"
+            assert data["qdrant_connected"] is True
+            assert data["llm_reachable"] is True
+
+    @patch("sage.api.routes.collection_exists", return_value=True)
+    def test_degraded_when_qdrant_available_but_llm_unavailable(
+        self, mock_collection_exists
+    ):
+        app = _make_app(explainer=None)
+        with TestClient(app) as c:
+            resp = c.get("/health")
+            assert resp.status_code == 200
+            data = resp.json()
+            assert data["status"] == "degraded"
+            assert data["qdrant_connected"] is True
+            assert data["llm_reachable"] is False
+
+    @patch("sage.api.routes.collection_exists", return_value=False)
+    def test_unhealthy_when_qdrant_unavailable(self, mock_collection_exists):
+        app = _make_app()
+        with TestClient(app) as c:
+            resp = c.get("/health")
+            assert resp.status_code == 200
+            data = resp.json()
+            assert data["status"] == "unhealthy"
+            assert data["qdrant_connected"] is False
 
 
 class TestRecommendEndpoint:
     def test_missing_query_returns_422(self, client):
+        # POST with empty body should fail validation
+        resp = client.post("/recommend", json={})
         assert resp.status_code == 422
 
+    @patch("sage.api.routes.get_candidates", return_value=[])
+    def test_empty_results(self, mock_get_candidates, client):
+        resp = client.post("/recommend", json={"query": "test query", "explain": False})
+        assert resp.status_code == 200
+        data = resp.json()
+        assert data["recommendations"] == []
 
+    @patch("sage.api.routes.get_candidates")
+    def test_returns_products_without_explain(
+        self, mock_get_candidates, sample_product
+    ):
+        mock_get_candidates.return_value = [sample_product]
+        app = _make_app()
+        with TestClient(app) as c:
+            resp = c.post("/recommend", json={"query": "headphones", "explain": False})
             assert resp.status_code == 200
             data = resp.json()
-        assert data["recommendations"] ==
-                RetrievedChunk(
-                    text="Good", score=0.9, product_id="P1", rating=4.5, review_id="r1"
-                ),
-            ],
-        )
-        import sage.api.routes as routes_mod
-
-        original = routes_mod.get_candidates
-        routes_mod.get_candidates = lambda **kw: [product]
+            assert len(data["recommendations"]) == 1
+            rec = data["recommendations"][0]
+            assert rec["product_id"] == "P1"
+            assert rec["rank"] == 1
+            # Response uses 'score' not 'relevance_score' (killer demo format)
+            assert "score" in rec
+            assert "explanation" not in rec or rec["explanation"] is None
+
+    @patch("sage.api.routes.get_candidates")
+    def test_request_with_filters(self, mock_get_candidates, sample_product):
+        mock_get_candidates.return_value = [sample_product]
         app = _make_app()
+        with TestClient(app) as c:
+            resp = c.post(
+                "/recommend",
+                json={
+                    "query": "laptop for video editing",
+                    "k": 5,
+                    "filters": {"min_rating": 4.5, "max_price": 1500},
+                    "explain": False,
+                },
+            )
+            assert resp.status_code == 200
+            data = resp.json()
+            assert len(data["recommendations"]) == 1
 
-    def test_explainer_unavailable_returns_503(self):
-        product = ProductScore(
-            product_id="P1",
-            score=0.9,
-            chunk_count=2,
-            avg_rating=4.5,
-            evidence=[
-                RetrievedChunk(
-                    text="Good", score=0.9, product_id="P1", rating=4.5, review_id="r1"
-                ),
-            ],
-        )
-        import sage.api.routes as routes_mod
-
-        original = routes_mod.get_candidates
-        routes_mod.get_candidates = lambda **kw: [product]
+    @patch("sage.api.routes.get_candidates")
+    def test_explainer_unavailable_returns_503(
+        self, mock_get_candidates, sample_product
+    ):
+        mock_get_candidates.return_value = [sample_product]
         mock_embedder = MagicMock()
         mock_embedder.embed_single_query.return_value = [0.1] * 384
         app = _make_app(explainer=None, embedder=mock_embedder)
-            assert "unavailable" in resp.json()["error"].lower()
-        finally:
-            routes_mod.get_candidates = original
+        with TestClient(app) as c:
+            resp = c.post("/recommend", json={"query": "headphones", "explain": True})
+            assert resp.status_code == 503
+            assert "unavailable" in resp.json()["error"].lower()
 
 
 class TestCacheEndpoints:
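The common thread in the `test_api.py` rewrite is replacing manual monkey-patching (save the module attribute, swap in a lambda, restore in `finally`) with `unittest.mock.patch`, which undoes itself even when an assertion fails, plus a shared `sample_product` fixture instead of duplicated inline `ProductScore` constructions. The pattern in isolation, with a hypothetical test name (`_make_app` is the test-app factory defined at the top of the file):

```python
# Hypothetical test showing the @patch + TestClient pattern the suite now uses.
from unittest.mock import patch

from fastapi.testclient import TestClient


@patch("sage.api.routes.get_candidates", return_value=[])
def test_recommend_with_no_candidates(mock_get_candidates):
    app = _make_app()  # test-app factory from test_api.py
    with TestClient(app) as c:
        resp = c.post("/recommend", json={"query": "anything", "explain": False})
    assert resp.status_code == 200
    assert mock_get_candidates.called  # patch is restored automatically afterwards
```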
tests/test_evidence.py
CHANGED

@@ -3,51 +3,29 @@
 import pytest
 
 from sage.core.evidence import check_evidence_quality, generate_refusal_message
-from sage.core.models import ProductScore, RetrievedChunk
-
-
-def _product(score: float, n_chunks: int, text_len: int = 200) -> ProductScore:
-    """Build a ProductScore with n evidence chunks."""
-    evidence = [
-        RetrievedChunk(
-            text="x" * text_len,
-            score=score - i * 0.01,
-            product_id="P1",
-            rating=4.5,
-            review_id=f"rev_{i}",
-        )
-        for i in range(n_chunks)
-    ]
-    return ProductScore(
-        product_id="P1",
-        score=score,
-        chunk_count=n_chunks,
-        avg_rating=4.5,
-        evidence=evidence,
-    )
 
 
 class TestCheckEvidenceQuality:
-    def test_sufficient_evidence_passes(self):
-        product =
+    def test_sufficient_evidence_passes(self, make_product):
+        product = make_product(score=0.85, n_chunks=3, text_len=300)
         quality = check_evidence_quality(product)
         assert quality.is_sufficient is True
         assert quality.failure_reason is None
 
-    def test_too_few_chunks_fails(self):
-        product =
+    def test_too_few_chunks_fails(self, make_product):
+        product = make_product(score=0.85, n_chunks=1, text_len=300)
         quality = check_evidence_quality(product, min_chunks=2)
         assert quality.is_sufficient is False
         assert "chunk" in quality.failure_reason.lower()
 
-    def test_too_few_tokens_fails(self):
-        product =
+    def test_too_few_tokens_fails(self, make_product):
+        product = make_product(score=0.85, n_chunks=3, text_len=5)
         quality = check_evidence_quality(product, min_tokens=50)
         assert quality.is_sufficient is False
         assert "token" in quality.failure_reason.lower()
 
-    def test_low_relevance_fails(self):
-        product =
+    def test_low_relevance_fails(self, make_product):
+        product = make_product(score=0.3, n_chunks=3, text_len=300)
         quality = check_evidence_quality(product, min_score=0.7)
         assert quality.is_sufficient is False
         assert (
@@ -55,34 +33,34 @@ class TestCheckEvidenceQuality:
             or "score" in quality.failure_reason.lower()
         )
 
-    def test_tracks_chunk_count(self):
-        product =
+    def test_tracks_chunk_count(self, make_product):
+        product = make_product(score=0.85, n_chunks=4, text_len=200)
         quality = check_evidence_quality(product)
         assert quality.chunk_count == 4
 
-    def test_tracks_top_score(self):
-        product =
+    def test_tracks_top_score(self, make_product):
+        product = make_product(score=0.92, n_chunks=3)
         quality = check_evidence_quality(product)
         assert quality.top_score == pytest.approx(0.92, abs=0.01)
 
 
 class TestGenerateRefusalMessage:
-    def test_generates_message_for_insufficient_chunks(self):
-        product =
+    def test_generates_message_for_insufficient_chunks(self, make_product):
+        product = make_product(score=0.85, n_chunks=1, text_len=300)
         quality = check_evidence_quality(product, min_chunks=2)
         msg = generate_refusal_message("wireless headphones", quality)
         assert isinstance(msg, str)
         assert len(msg) > 0
 
-    def test_generates_message_for_low_relevance(self):
-        product =
+    def test_generates_message_for_low_relevance(self, make_product):
+        product = make_product(score=0.3, n_chunks=3, text_len=300)
        quality = check_evidence_quality(product, min_score=0.7)
         msg = generate_refusal_message("laptop charger", quality)
         assert isinstance(msg, str)
         assert len(msg) > 0
 
-    def test_includes_query_context(self):
-        product =
+    def test_includes_query_context(self, make_product):
+        product = make_product(score=0.3, n_chunks=1)
         quality = check_evidence_quality(product, min_chunks=2)
         msg = generate_refusal_message("bluetooth speaker", quality)
         # Message should reference the query or product context