Apply ruff formatting
Browse files- sage/adapters/hhem.py +12 -3
- sage/adapters/llm.py +3 -1
- sage/adapters/vector_store.py +1 -2
- sage/api/app.py +10 -12
- sage/api/metrics.py +7 -1
- sage/api/middleware.py +6 -5
- sage/api/routes.py +43 -11
- sage/api/run.py +2 -1
- sage/config/logging.py +40 -13
- sage/core/aggregation.py +18 -14
- sage/core/chunking.py +6 -6
- sage/core/evidence.py +12 -6
- sage/core/models.py +34 -4
- sage/core/prompts.py +1 -1
- sage/core/verification.py +7 -8
- sage/services/__init__.py +1 -0
- sage/services/baselines.py +8 -7
- sage/services/cache.py +8 -2
- sage/services/cold_start.py +4 -1
- sage/services/evaluation.py +3 -1
- sage/services/explanation.py +12 -6
- sage/services/faithfulness.py +10 -7
- sage/services/retrieval.py +15 -4
- sage/utils.py +1 -0
- scripts/build_eval_dataset.py +235 -40
- scripts/build_natural_eval_dataset.py +21 -25
- scripts/demo.py +7 -6
- scripts/e2e_success_rate.py +66 -14
- scripts/eda.py +78 -29
- scripts/evaluation.py +58 -18
- scripts/explanation.py +53 -15
- scripts/faithfulness.py +83 -36
- scripts/human_eval.py +57 -26
- scripts/pipeline.py +63 -20
- scripts/sanity_checks.py +93 -28
- scripts/summary.py +22 -7
- tests/test_aggregation.py +9 -3
- tests/test_api.py +29 -7
- tests/test_chunking.py +3 -1
- tests/test_evidence.py +4 -1
- tests/test_faithfulness.py +9 -2
- tests/test_models.py +22 -6
sage/adapters/hhem.py
CHANGED
|
@@ -148,7 +148,9 @@ class HallucinationDetector:
|
|
| 148 |
remaining = [t for t in evidence_texts if hyp_lower not in t.lower()]
|
| 149 |
evidence_texts = containing + remaining
|
| 150 |
|
| 151 |
-
hypothesis_tokens = len(
|
|
|
|
|
|
|
| 152 |
budget = HHEM_MAX_TOKENS - HHEM_TEMPLATE_OVERHEAD - hypothesis_tokens
|
| 153 |
|
| 154 |
kept = []
|
|
@@ -253,13 +255,20 @@ class HallucinationDetector:
|
|
| 253 |
List of ClaimResult objects, one per claim.
|
| 254 |
"""
|
| 255 |
pairs = [
|
| 256 |
-
(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 257 |
for claim in claims
|
| 258 |
]
|
| 259 |
scores = self._predict(pairs)
|
| 260 |
|
| 261 |
return [
|
| 262 |
-
ClaimResult(
|
|
|
|
|
|
|
| 263 |
for claim, score in zip(claims, scores)
|
| 264 |
]
|
| 265 |
|
|
|
|
| 148 |
remaining = [t for t in evidence_texts if hyp_lower not in t.lower()]
|
| 149 |
evidence_texts = containing + remaining
|
| 150 |
|
| 151 |
+
hypothesis_tokens = len(
|
| 152 |
+
self.tokenizer(hypothesis, add_special_tokens=False).input_ids
|
| 153 |
+
)
|
| 154 |
budget = HHEM_MAX_TOKENS - HHEM_TEMPLATE_OVERHEAD - hypothesis_tokens
|
| 155 |
|
| 156 |
kept = []
|
|
|
|
| 255 |
List of ClaimResult objects, one per claim.
|
| 256 |
"""
|
| 257 |
pairs = [
|
| 258 |
+
(
|
| 259 |
+
self._format_premise(
|
| 260 |
+
evidence_texts, hypothesis=claim, prioritize_hypothesis=True
|
| 261 |
+
),
|
| 262 |
+
claim,
|
| 263 |
+
)
|
| 264 |
for claim in claims
|
| 265 |
]
|
| 266 |
scores = self._predict(pairs)
|
| 267 |
|
| 268 |
return [
|
| 269 |
+
ClaimResult(
|
| 270 |
+
claim=claim, score=score, is_hallucinated=score < self.threshold
|
| 271 |
+
)
|
| 272 |
for claim, score in zip(claims, scores)
|
| 273 |
]
|
| 274 |
|
sage/adapters/llm.py
CHANGED
|
@@ -356,7 +356,9 @@ def get_llm_client(provider: str | None = None) -> LLMClient:
|
|
| 356 |
elif provider == "openai":
|
| 357 |
return OpenAIClient()
|
| 358 |
else:
|
| 359 |
-
raise ValueError(
|
|
|
|
|
|
|
| 360 |
|
| 361 |
|
| 362 |
__all__ = [
|
|
|
|
| 356 |
elif provider == "openai":
|
| 357 |
return OpenAIClient()
|
| 358 |
else:
|
| 359 |
+
raise ValueError(
|
| 360 |
+
f"Unknown LLM provider: {provider}. Use 'anthropic' or 'openai'."
|
| 361 |
+
)
|
| 362 |
|
| 363 |
|
| 364 |
__all__ = [
|
sage/adapters/vector_store.py
CHANGED
|
@@ -42,8 +42,7 @@ def get_client():
|
|
| 42 |
from qdrant_client import QdrantClient
|
| 43 |
except ImportError:
|
| 44 |
raise ImportError(
|
| 45 |
-
"qdrant-client package required. "
|
| 46 |
-
"Install with: pip install qdrant-client"
|
| 47 |
)
|
| 48 |
|
| 49 |
if QDRANT_API_KEY:
|
|
|
|
| 42 |
from qdrant_client import QdrantClient
|
| 43 |
except ImportError:
|
| 44 |
raise ImportError(
|
| 45 |
+
"qdrant-client package required. Install with: pip install qdrant-client"
|
|
|
|
| 46 |
)
|
| 47 |
|
| 48 |
if QDRANT_API_KEY:
|
sage/api/app.py
CHANGED
|
@@ -31,24 +31,19 @@ async def _lifespan(app: FastAPI):
|
|
| 31 |
|
| 32 |
# Validate LLM credentials early
|
| 33 |
from sage.config import ANTHROPIC_API_KEY, LLM_PROVIDER, OPENAI_API_KEY
|
|
|
|
| 34 |
if not ANTHROPIC_API_KEY and not OPENAI_API_KEY:
|
| 35 |
logger.error(
|
| 36 |
-
"No LLM API key set -- add ANTHROPIC_API_KEY "
|
| 37 |
-
"or OPENAI_API_KEY to .env"
|
| 38 |
)
|
| 39 |
elif LLM_PROVIDER == "anthropic" and not ANTHROPIC_API_KEY:
|
| 40 |
-
logger.warning(
|
| 41 |
-
"LLM_PROVIDER=anthropic but ANTHROPIC_API_KEY "
|
| 42 |
-
"is not set"
|
| 43 |
-
)
|
| 44 |
elif LLM_PROVIDER == "openai" and not OPENAI_API_KEY:
|
| 45 |
-
logger.warning(
|
| 46 |
-
"LLM_PROVIDER=openai but OPENAI_API_KEY "
|
| 47 |
-
"is not set"
|
| 48 |
-
)
|
| 49 |
|
| 50 |
# Embedder (loads E5-small model) -- required for all requests
|
| 51 |
from sage.adapters.embeddings import get_embedder
|
|
|
|
| 52 |
try:
|
| 53 |
app.state.embedder = get_embedder()
|
| 54 |
logger.info("Embedder loaded")
|
|
@@ -58,6 +53,7 @@ async def _lifespan(app: FastAPI):
|
|
| 58 |
|
| 59 |
# Qdrant client
|
| 60 |
from sage.adapters.vector_store import get_client, collection_exists
|
|
|
|
| 61 |
app.state.qdrant = get_client()
|
| 62 |
try:
|
| 63 |
if collection_exists(app.state.qdrant):
|
|
@@ -69,6 +65,7 @@ async def _lifespan(app: FastAPI):
|
|
| 69 |
|
| 70 |
# HHEM hallucination detector (loads T5 model) -- required for grounding
|
| 71 |
from sage.adapters.hhem import HallucinationDetector
|
|
|
|
| 72 |
try:
|
| 73 |
app.state.detector = HallucinationDetector()
|
| 74 |
logger.info("HHEM detector loaded")
|
|
@@ -78,18 +75,19 @@ async def _lifespan(app: FastAPI):
|
|
| 78 |
|
| 79 |
# LLM explainer -- graceful degradation if unavailable
|
| 80 |
from sage.services.explanation import Explainer
|
|
|
|
| 81 |
try:
|
| 82 |
app.state.explainer = Explainer()
|
| 83 |
logger.info("Explainer ready (%s)", app.state.explainer.model)
|
| 84 |
except Exception:
|
| 85 |
logger.exception(
|
| 86 |
-
"Failed to initialize explainer -- "
|
| 87 |
-
"explain=true requests will fail"
|
| 88 |
)
|
| 89 |
app.state.explainer = None
|
| 90 |
|
| 91 |
# Semantic cache
|
| 92 |
from sage.services.cache import SemanticCache
|
|
|
|
| 93 |
app.state.cache = SemanticCache()
|
| 94 |
logger.info("Semantic cache initialized")
|
| 95 |
|
|
|
|
| 31 |
|
| 32 |
# Validate LLM credentials early
|
| 33 |
from sage.config import ANTHROPIC_API_KEY, LLM_PROVIDER, OPENAI_API_KEY
|
| 34 |
+
|
| 35 |
if not ANTHROPIC_API_KEY and not OPENAI_API_KEY:
|
| 36 |
logger.error(
|
| 37 |
+
"No LLM API key set -- add ANTHROPIC_API_KEY or OPENAI_API_KEY to .env"
|
|
|
|
| 38 |
)
|
| 39 |
elif LLM_PROVIDER == "anthropic" and not ANTHROPIC_API_KEY:
|
| 40 |
+
logger.warning("LLM_PROVIDER=anthropic but ANTHROPIC_API_KEY is not set")
|
|
|
|
|
|
|
|
|
|
| 41 |
elif LLM_PROVIDER == "openai" and not OPENAI_API_KEY:
|
| 42 |
+
logger.warning("LLM_PROVIDER=openai but OPENAI_API_KEY is not set")
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
# Embedder (loads E5-small model) -- required for all requests
|
| 45 |
from sage.adapters.embeddings import get_embedder
|
| 46 |
+
|
| 47 |
try:
|
| 48 |
app.state.embedder = get_embedder()
|
| 49 |
logger.info("Embedder loaded")
|
|
|
|
| 53 |
|
| 54 |
# Qdrant client
|
| 55 |
from sage.adapters.vector_store import get_client, collection_exists
|
| 56 |
+
|
| 57 |
app.state.qdrant = get_client()
|
| 58 |
try:
|
| 59 |
if collection_exists(app.state.qdrant):
|
|
|
|
| 65 |
|
| 66 |
# HHEM hallucination detector (loads T5 model) -- required for grounding
|
| 67 |
from sage.adapters.hhem import HallucinationDetector
|
| 68 |
+
|
| 69 |
try:
|
| 70 |
app.state.detector = HallucinationDetector()
|
| 71 |
logger.info("HHEM detector loaded")
|
|
|
|
| 75 |
|
| 76 |
# LLM explainer -- graceful degradation if unavailable
|
| 77 |
from sage.services.explanation import Explainer
|
| 78 |
+
|
| 79 |
try:
|
| 80 |
app.state.explainer = Explainer()
|
| 81 |
logger.info("Explainer ready (%s)", app.state.explainer.model)
|
| 82 |
except Exception:
|
| 83 |
logger.exception(
|
| 84 |
+
"Failed to initialize explainer -- explain=true requests will fail"
|
|
|
|
| 85 |
)
|
| 86 |
app.state.explainer = None
|
| 87 |
|
| 88 |
# Semantic cache
|
| 89 |
from sage.services.cache import SemanticCache
|
| 90 |
+
|
| 91 |
app.state.cache = SemanticCache()
|
| 92 |
logger.info("Semantic cache initialized")
|
| 93 |
|
sage/api/metrics.py
CHANGED
|
@@ -16,7 +16,12 @@ logger = get_logger(__name__)
|
|
| 16 |
# ---------------------------------------------------------------------------
|
| 17 |
|
| 18 |
try:
|
| 19 |
-
from prometheus_client import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
REQUEST_COUNT = Counter(
|
| 22 |
"sage_requests_total",
|
|
@@ -48,6 +53,7 @@ except ImportError:
|
|
| 48 |
# Public helpers
|
| 49 |
# ---------------------------------------------------------------------------
|
| 50 |
|
|
|
|
| 51 |
def record_request(endpoint: str, method: str, status: int) -> None:
|
| 52 |
"""Increment the request counter."""
|
| 53 |
if _PROMETHEUS_AVAILABLE:
|
|
|
|
| 16 |
# ---------------------------------------------------------------------------
|
| 17 |
|
| 18 |
try:
|
| 19 |
+
from prometheus_client import (
|
| 20 |
+
Counter,
|
| 21 |
+
Histogram,
|
| 22 |
+
generate_latest,
|
| 23 |
+
CONTENT_TYPE_LATEST,
|
| 24 |
+
)
|
| 25 |
|
| 26 |
REQUEST_COUNT = Counter(
|
| 27 |
"sage_requests_total",
|
|
|
|
| 53 |
# Public helpers
|
| 54 |
# ---------------------------------------------------------------------------
|
| 55 |
|
| 56 |
+
|
| 57 |
def record_request(endpoint: str, method: str, status: int) -> None:
|
| 58 |
"""Increment the request counter."""
|
| 59 |
if _PROMETHEUS_AVAILABLE:
|
sage/api/middleware.py
CHANGED
|
@@ -69,9 +69,7 @@ class LatencyMiddleware:
|
|
| 69 |
# The Prometheus histogram (in finally) measures total time.
|
| 70 |
elapsed_ms = (time.perf_counter() - start) * 1000
|
| 71 |
headers = list(message.get("headers", []))
|
| 72 |
-
headers.append(
|
| 73 |
-
(b"x-response-time-ms", f"{elapsed_ms:.1f}".encode())
|
| 74 |
-
)
|
| 75 |
headers.append((b"x-request-id", request_id.encode()))
|
| 76 |
message = {**message, "headers": headers}
|
| 77 |
await send(message)
|
|
@@ -88,6 +86,9 @@ class LatencyMiddleware:
|
|
| 88 |
if path not in _QUIET_PATHS:
|
| 89 |
logger.info(
|
| 90 |
"%s %s %d %.1fms [%s]",
|
| 91 |
-
method,
|
| 92 |
-
|
|
|
|
|
|
|
|
|
|
| 93 |
)
|
|
|
|
| 69 |
# The Prometheus histogram (in finally) measures total time.
|
| 70 |
elapsed_ms = (time.perf_counter() - start) * 1000
|
| 71 |
headers = list(message.get("headers", []))
|
| 72 |
+
headers.append((b"x-response-time-ms", f"{elapsed_ms:.1f}".encode()))
|
|
|
|
|
|
|
| 73 |
headers.append((b"x-request-id", request_id.encode()))
|
| 74 |
message = {**message, "headers": headers}
|
| 75 |
await send(message)
|
|
|
|
| 86 |
if path not in _QUIET_PATHS:
|
| 87 |
logger.info(
|
| 88 |
"%s %s %d %.1fms [%s]",
|
| 89 |
+
method,
|
| 90 |
+
path,
|
| 91 |
+
status,
|
| 92 |
+
elapsed_ms,
|
| 93 |
+
request_id,
|
| 94 |
)
|
sage/api/routes.py
CHANGED
|
@@ -24,7 +24,12 @@ from pydantic import BaseModel
|
|
| 24 |
from sage.adapters.vector_store import collection_exists
|
| 25 |
from sage.api.metrics import metrics_response, record_cache_event
|
| 26 |
from sage.config import MAX_EVIDENCE, get_logger
|
| 27 |
-
from sage.core import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
from sage.services.retrieval import get_candidates
|
| 29 |
|
| 30 |
# Cap parallel LLM+HHEM workers per request. With k=10 and concurrent
|
|
@@ -41,6 +46,7 @@ router = APIRouter()
|
|
| 41 |
# Response models
|
| 42 |
# ---------------------------------------------------------------------------
|
| 43 |
|
|
|
|
| 44 |
class EvidenceSource(BaseModel):
|
| 45 |
id: str
|
| 46 |
text: str
|
|
@@ -95,6 +101,7 @@ class CacheStatsResponse(BaseModel):
|
|
| 95 |
# Shared helpers
|
| 96 |
# ---------------------------------------------------------------------------
|
| 97 |
|
|
|
|
| 98 |
@dataclass
|
| 99 |
class RecommendParams:
|
| 100 |
"""Query parameters shared by /recommend and /recommend/stream."""
|
|
@@ -105,7 +112,9 @@ class RecommendParams:
|
|
| 105 |
|
| 106 |
|
| 107 |
def _fetch_products(
|
| 108 |
-
params: RecommendParams,
|
|
|
|
|
|
|
| 109 |
) -> list[ProductScore]:
|
| 110 |
"""Run candidate generation with lifespan-managed singletons."""
|
| 111 |
return get_candidates(
|
|
@@ -138,6 +147,7 @@ def _build_evidence_list(result: ExplanationResult) -> list[dict]:
|
|
| 138 |
# Health
|
| 139 |
# ---------------------------------------------------------------------------
|
| 140 |
|
|
|
|
| 141 |
@router.get("/health", response_model=HealthResponse)
|
| 142 |
def health(request: Request):
|
| 143 |
"""Deployment readiness probe. Checks Qdrant connectivity.
|
|
@@ -159,6 +169,7 @@ def health(request: Request):
|
|
| 159 |
# Recommend (non-streaming)
|
| 160 |
# ---------------------------------------------------------------------------
|
| 161 |
|
|
|
|
| 162 |
@router.get(
|
| 163 |
"/recommend",
|
| 164 |
response_model=RecommendResponse,
|
|
@@ -208,20 +219,27 @@ def recommend(
|
|
| 208 |
# HHEM model in eval() + no_grad() = read-only forward
|
| 209 |
# pass with no state mutation. Tokenizer is stateless.
|
| 210 |
er = explainer.generate_explanation(
|
| 211 |
-
query=q,
|
|
|
|
|
|
|
| 212 |
)
|
| 213 |
hr = detector.check_explanation(
|
| 214 |
evidence_texts=er.evidence_texts,
|
| 215 |
explanation=er.explanation,
|
| 216 |
)
|
| 217 |
-
cr = verify_citations(
|
|
|
|
|
|
|
| 218 |
return er, hr, cr
|
| 219 |
|
| 220 |
-
with ThreadPoolExecutor(
|
|
|
|
|
|
|
| 221 |
results = list(pool.map(_explain, products))
|
| 222 |
|
| 223 |
for i, (product, (er, hr, cr)) in enumerate(
|
| 224 |
-
zip(products, results),
|
|
|
|
| 225 |
):
|
| 226 |
rec = _build_product_dict(i, product)
|
| 227 |
rec["explanation"] = er.explanation
|
|
@@ -257,6 +275,7 @@ def recommend(
|
|
| 257 |
# Recommend (SSE streaming)
|
| 258 |
# ---------------------------------------------------------------------------
|
| 259 |
|
|
|
|
| 260 |
def _sse_event(event: str, data: str) -> str:
|
| 261 |
"""Format a single SSE event."""
|
| 262 |
return f"event: {event}\ndata: {data}\n\n"
|
|
@@ -267,9 +286,16 @@ def _stream_recommendations(
|
|
| 267 |
app,
|
| 268 |
) -> Iterator[str]:
|
| 269 |
"""Generator that yields SSE events for streaming recommendations."""
|
| 270 |
-
yield _sse_event(
|
| 271 |
-
"
|
| 272 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
|
| 274 |
try:
|
| 275 |
products = _fetch_products(params, app)
|
|
@@ -285,7 +311,9 @@ def _stream_recommendations(
|
|
| 285 |
|
| 286 |
explainer = app.state.explainer
|
| 287 |
if explainer is None:
|
| 288 |
-
yield _sse_event(
|
|
|
|
|
|
|
| 289 |
yield _sse_event("done", json.dumps({"status": "error"}))
|
| 290 |
return
|
| 291 |
|
|
@@ -314,7 +342,9 @@ def _stream_recommendations(
|
|
| 314 |
yield _sse_event("refusal", json.dumps({"detail": str(exc)}))
|
| 315 |
except Exception:
|
| 316 |
logger.exception("Streaming error for product %s", product.product_id)
|
| 317 |
-
yield _sse_event(
|
|
|
|
|
|
|
| 318 |
|
| 319 |
yield _sse_event("done", json.dumps({"status": "complete"}))
|
| 320 |
|
|
@@ -341,6 +371,7 @@ def recommend_stream(
|
|
| 341 |
# Cache management
|
| 342 |
# ---------------------------------------------------------------------------
|
| 343 |
|
|
|
|
| 344 |
@router.get("/cache/stats", response_model=CacheStatsResponse)
|
| 345 |
def cache_stats(request: Request):
|
| 346 |
"""Return cache performance statistics."""
|
|
@@ -370,6 +401,7 @@ def cache_clear(request: Request):
|
|
| 370 |
# Prometheus metrics
|
| 371 |
# ---------------------------------------------------------------------------
|
| 372 |
|
|
|
|
| 373 |
@router.get("/metrics")
|
| 374 |
def metrics():
|
| 375 |
"""Prometheus metrics endpoint."""
|
|
|
|
| 24 |
from sage.adapters.vector_store import collection_exists
|
| 25 |
from sage.api.metrics import metrics_response, record_cache_event
|
| 26 |
from sage.config import MAX_EVIDENCE, get_logger
|
| 27 |
+
from sage.core import (
|
| 28 |
+
AggregationMethod,
|
| 29 |
+
ExplanationResult,
|
| 30 |
+
ProductScore,
|
| 31 |
+
verify_citations,
|
| 32 |
+
)
|
| 33 |
from sage.services.retrieval import get_candidates
|
| 34 |
|
| 35 |
# Cap parallel LLM+HHEM workers per request. With k=10 and concurrent
|
|
|
|
| 46 |
# Response models
|
| 47 |
# ---------------------------------------------------------------------------
|
| 48 |
|
| 49 |
+
|
| 50 |
class EvidenceSource(BaseModel):
|
| 51 |
id: str
|
| 52 |
text: str
|
|
|
|
| 101 |
# Shared helpers
|
| 102 |
# ---------------------------------------------------------------------------
|
| 103 |
|
| 104 |
+
|
| 105 |
@dataclass
|
| 106 |
class RecommendParams:
|
| 107 |
"""Query parameters shared by /recommend and /recommend/stream."""
|
|
|
|
| 112 |
|
| 113 |
|
| 114 |
def _fetch_products(
|
| 115 |
+
params: RecommendParams,
|
| 116 |
+
app,
|
| 117 |
+
query_embedding=None,
|
| 118 |
) -> list[ProductScore]:
|
| 119 |
"""Run candidate generation with lifespan-managed singletons."""
|
| 120 |
return get_candidates(
|
|
|
|
| 147 |
# Health
|
| 148 |
# ---------------------------------------------------------------------------
|
| 149 |
|
| 150 |
+
|
| 151 |
@router.get("/health", response_model=HealthResponse)
|
| 152 |
def health(request: Request):
|
| 153 |
"""Deployment readiness probe. Checks Qdrant connectivity.
|
|
|
|
| 169 |
# Recommend (non-streaming)
|
| 170 |
# ---------------------------------------------------------------------------
|
| 171 |
|
| 172 |
+
|
| 173 |
@router.get(
|
| 174 |
"/recommend",
|
| 175 |
response_model=RecommendResponse,
|
|
|
|
| 219 |
# HHEM model in eval() + no_grad() = read-only forward
|
| 220 |
# pass with no state mutation. Tokenizer is stateless.
|
| 221 |
er = explainer.generate_explanation(
|
| 222 |
+
query=q,
|
| 223 |
+
product=product,
|
| 224 |
+
max_evidence=MAX_EVIDENCE,
|
| 225 |
)
|
| 226 |
hr = detector.check_explanation(
|
| 227 |
evidence_texts=er.evidence_texts,
|
| 228 |
explanation=er.explanation,
|
| 229 |
)
|
| 230 |
+
cr = verify_citations(
|
| 231 |
+
er.explanation, er.evidence_ids, er.evidence_texts
|
| 232 |
+
)
|
| 233 |
return er, hr, cr
|
| 234 |
|
| 235 |
+
with ThreadPoolExecutor(
|
| 236 |
+
max_workers=min(len(products), _MAX_EXPLAIN_WORKERS)
|
| 237 |
+
) as pool:
|
| 238 |
results = list(pool.map(_explain, products))
|
| 239 |
|
| 240 |
for i, (product, (er, hr, cr)) in enumerate(
|
| 241 |
+
zip(products, results),
|
| 242 |
+
1,
|
| 243 |
):
|
| 244 |
rec = _build_product_dict(i, product)
|
| 245 |
rec["explanation"] = er.explanation
|
|
|
|
| 275 |
# Recommend (SSE streaming)
|
| 276 |
# ---------------------------------------------------------------------------
|
| 277 |
|
| 278 |
+
|
| 279 |
def _sse_event(event: str, data: str) -> str:
|
| 280 |
"""Format a single SSE event."""
|
| 281 |
return f"event: {event}\ndata: {data}\n\n"
|
|
|
|
| 286 |
app,
|
| 287 |
) -> Iterator[str]:
|
| 288 |
"""Generator that yields SSE events for streaming recommendations."""
|
| 289 |
+
yield _sse_event(
|
| 290 |
+
"metadata",
|
| 291 |
+
json.dumps(
|
| 292 |
+
{
|
| 293 |
+
"verified": False,
|
| 294 |
+
"cache": False,
|
| 295 |
+
"hhem": False,
|
| 296 |
+
}
|
| 297 |
+
),
|
| 298 |
+
)
|
| 299 |
|
| 300 |
try:
|
| 301 |
products = _fetch_products(params, app)
|
|
|
|
| 311 |
|
| 312 |
explainer = app.state.explainer
|
| 313 |
if explainer is None:
|
| 314 |
+
yield _sse_event(
|
| 315 |
+
"error", json.dumps({"detail": "Explanation service unavailable"})
|
| 316 |
+
)
|
| 317 |
yield _sse_event("done", json.dumps({"status": "error"}))
|
| 318 |
return
|
| 319 |
|
|
|
|
| 342 |
yield _sse_event("refusal", json.dumps({"detail": str(exc)}))
|
| 343 |
except Exception:
|
| 344 |
logger.exception("Streaming error for product %s", product.product_id)
|
| 345 |
+
yield _sse_event(
|
| 346 |
+
"error", json.dumps({"detail": "Failed to generate explanation"})
|
| 347 |
+
)
|
| 348 |
|
| 349 |
yield _sse_event("done", json.dumps({"status": "complete"}))
|
| 350 |
|
|
|
|
| 371 |
# Cache management
|
| 372 |
# ---------------------------------------------------------------------------
|
| 373 |
|
| 374 |
+
|
| 375 |
@router.get("/cache/stats", response_model=CacheStatsResponse)
|
| 376 |
def cache_stats(request: Request):
|
| 377 |
"""Return cache performance statistics."""
|
|
|
|
| 401 |
# Prometheus metrics
|
| 402 |
# ---------------------------------------------------------------------------
|
| 403 |
|
| 404 |
+
|
| 405 |
@router.get("/metrics")
|
| 406 |
def metrics():
|
| 407 |
"""Prometheus metrics endpoint."""
|
sage/api/run.py
CHANGED
|
@@ -24,7 +24,8 @@ def main():
|
|
| 24 |
parser = argparse.ArgumentParser(description="Sage API server")
|
| 25 |
parser.add_argument("--host", default="0.0.0.0", help="Bind address")
|
| 26 |
parser.add_argument(
|
| 27 |
-
"--port",
|
|
|
|
| 28 |
default=int(os.getenv("PORT", "8000")),
|
| 29 |
help="Port (defaults to PORT env var, then 8000)",
|
| 30 |
)
|
|
|
|
| 24 |
parser = argparse.ArgumentParser(description="Sage API server")
|
| 25 |
parser.add_argument("--host", default="0.0.0.0", help="Bind address")
|
| 26 |
parser.add_argument(
|
| 27 |
+
"--port",
|
| 28 |
+
type=int,
|
| 29 |
default=int(os.getenv("PORT", "8000")),
|
| 30 |
help="Port (defaults to PORT env var, then 8000)",
|
| 31 |
)
|
sage/config/logging.py
CHANGED
|
@@ -28,27 +28,48 @@ LOG_FORMAT = os.getenv("SAGE_LOG_FORMAT", "console") # "console" or "json"
|
|
| 28 |
|
| 29 |
# Standard LogRecord attributes to ignore when extracting user-specified extras.
|
| 30 |
# These are built-in attributes from logging.LogRecord plus taskName from asyncio.
|
| 31 |
-
_STANDARD_LOG_ATTRS = frozenset(
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
|
| 40 |
# ---------------------------------------------------------------------------
|
| 41 |
# Custom Formatter (Console)
|
| 42 |
# ---------------------------------------------------------------------------
|
| 43 |
|
|
|
|
| 44 |
class ConsoleFormatter(logging.Formatter):
|
| 45 |
"""Human-readable formatter with visual hierarchy."""
|
| 46 |
|
| 47 |
COLORS = {
|
| 48 |
-
"DEBUG": "\033[36m",
|
| 49 |
-
"INFO": "\033[32m",
|
| 50 |
-
"WARNING": "\033[33m",
|
| 51 |
-
"ERROR": "\033[31m",
|
| 52 |
"CRITICAL": "\033[35m", # Magenta
|
| 53 |
"RESET": "\033[0m",
|
| 54 |
}
|
|
@@ -87,6 +108,7 @@ class ConsoleFormatter(logging.Formatter):
|
|
| 87 |
# Custom Formatter (JSON)
|
| 88 |
# ---------------------------------------------------------------------------
|
| 89 |
|
|
|
|
| 90 |
class JSONFormatter(logging.Formatter):
|
| 91 |
"""Machine-parseable JSON formatter for production."""
|
| 92 |
|
|
@@ -180,14 +202,19 @@ def get_logger(name: str) -> logging.Logger:
|
|
| 180 |
# Convenience functions for visual output
|
| 181 |
# ---------------------------------------------------------------------------
|
| 182 |
|
| 183 |
-
|
|
|
|
|
|
|
|
|
|
| 184 |
"""Log a visual banner for section headers."""
|
| 185 |
logger.info(char * width)
|
| 186 |
logger.info(title)
|
| 187 |
logger.info(char * width)
|
| 188 |
|
| 189 |
|
| 190 |
-
def log_section(
|
|
|
|
|
|
|
| 191 |
"""Log a section divider."""
|
| 192 |
logger.info("")
|
| 193 |
logger.info(char * width)
|
|
|
|
| 28 |
|
| 29 |
# Standard LogRecord attributes to ignore when extracting user-specified extras.
|
| 30 |
# These are built-in attributes from logging.LogRecord plus taskName from asyncio.
|
| 31 |
+
_STANDARD_LOG_ATTRS = frozenset(
|
| 32 |
+
{
|
| 33 |
+
"name",
|
| 34 |
+
"msg",
|
| 35 |
+
"args",
|
| 36 |
+
"created",
|
| 37 |
+
"filename",
|
| 38 |
+
"funcName",
|
| 39 |
+
"levelname",
|
| 40 |
+
"levelno",
|
| 41 |
+
"lineno",
|
| 42 |
+
"module",
|
| 43 |
+
"msecs",
|
| 44 |
+
"pathname",
|
| 45 |
+
"process",
|
| 46 |
+
"processName",
|
| 47 |
+
"relativeCreated",
|
| 48 |
+
"stack_info",
|
| 49 |
+
"exc_info",
|
| 50 |
+
"exc_text",
|
| 51 |
+
"thread",
|
| 52 |
+
"threadName",
|
| 53 |
+
"message",
|
| 54 |
+
"asctime",
|
| 55 |
+
"taskName",
|
| 56 |
+
}
|
| 57 |
+
)
|
| 58 |
|
| 59 |
|
| 60 |
# ---------------------------------------------------------------------------
|
| 61 |
# Custom Formatter (Console)
|
| 62 |
# ---------------------------------------------------------------------------
|
| 63 |
|
| 64 |
+
|
| 65 |
class ConsoleFormatter(logging.Formatter):
|
| 66 |
"""Human-readable formatter with visual hierarchy."""
|
| 67 |
|
| 68 |
COLORS = {
|
| 69 |
+
"DEBUG": "\033[36m", # Cyan
|
| 70 |
+
"INFO": "\033[32m", # Green
|
| 71 |
+
"WARNING": "\033[33m", # Yellow
|
| 72 |
+
"ERROR": "\033[31m", # Red
|
| 73 |
"CRITICAL": "\033[35m", # Magenta
|
| 74 |
"RESET": "\033[0m",
|
| 75 |
}
|
|
|
|
| 108 |
# Custom Formatter (JSON)
|
| 109 |
# ---------------------------------------------------------------------------
|
| 110 |
|
| 111 |
+
|
| 112 |
class JSONFormatter(logging.Formatter):
|
| 113 |
"""Machine-parseable JSON formatter for production."""
|
| 114 |
|
|
|
|
| 202 |
# Convenience functions for visual output
|
| 203 |
# ---------------------------------------------------------------------------
|
| 204 |
|
| 205 |
+
|
| 206 |
+
def log_banner(
|
| 207 |
+
logger: logging.Logger, title: str, char: str = "=", width: int = 60
|
| 208 |
+
) -> None:
|
| 209 |
"""Log a visual banner for section headers."""
|
| 210 |
logger.info(char * width)
|
| 211 |
logger.info(title)
|
| 212 |
logger.info(char * width)
|
| 213 |
|
| 214 |
|
| 215 |
+
def log_section(
|
| 216 |
+
logger: logging.Logger, title: str, char: str = "-", width: int = 60
|
| 217 |
+
) -> None:
|
| 218 |
"""Log a section divider."""
|
| 219 |
logger.info("")
|
| 220 |
logger.info(char * width)
|
sage/core/aggregation.py
CHANGED
|
@@ -55,13 +55,15 @@ def aggregate_chunks_to_products(
|
|
| 55 |
else:
|
| 56 |
agg_score = max(scores)
|
| 57 |
|
| 58 |
-
product_scores.append(
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
|
|
|
|
|
|
| 65 |
|
| 66 |
# Sort by score descending
|
| 67 |
return sorted(product_scores, key=lambda p: p.score, reverse=True)
|
|
@@ -112,12 +114,14 @@ def apply_weighted_ranking(
|
|
| 112 |
# Create new ProductScore objects with updated scores
|
| 113 |
reranked = []
|
| 114 |
for i, product in enumerate(products):
|
| 115 |
-
reranked.append(
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
|
|
|
|
|
|
| 122 |
|
| 123 |
return sorted(reranked, key=lambda p: p.score, reverse=True)
|
|
|
|
| 55 |
else:
|
| 56 |
agg_score = max(scores)
|
| 57 |
|
| 58 |
+
product_scores.append(
|
| 59 |
+
ProductScore(
|
| 60 |
+
product_id=product_id,
|
| 61 |
+
score=agg_score,
|
| 62 |
+
chunk_count=len(prod_chunks),
|
| 63 |
+
avg_rating=float(np.mean(ratings)),
|
| 64 |
+
evidence=sorted(prod_chunks, key=lambda c: c.score, reverse=True),
|
| 65 |
+
)
|
| 66 |
+
)
|
| 67 |
|
| 68 |
# Sort by score descending
|
| 69 |
return sorted(product_scores, key=lambda p: p.score, reverse=True)
|
|
|
|
| 114 |
# Create new ProductScore objects with updated scores
|
| 115 |
reranked = []
|
| 116 |
for i, product in enumerate(products):
|
| 117 |
+
reranked.append(
|
| 118 |
+
ProductScore(
|
| 119 |
+
product_id=product.product_id,
|
| 120 |
+
score=float(final_scores[i]),
|
| 121 |
+
chunk_count=product.chunk_count,
|
| 122 |
+
avg_rating=product.avg_rating,
|
| 123 |
+
evidence=product.evidence,
|
| 124 |
+
)
|
| 125 |
+
)
|
| 126 |
|
| 127 |
return sorted(reranked, key=lambda p: p.score, reverse=True)
|
sage/core/chunking.py
CHANGED
|
@@ -19,16 +19,16 @@ from sage.config import CHARS_PER_TOKEN
|
|
| 19 |
|
| 20 |
|
| 21 |
# Chunking thresholds (tokens)
|
| 22 |
-
NO_CHUNK_THRESHOLD = 200
|
| 23 |
-
SEMANTIC_THRESHOLD = 500
|
| 24 |
-
MAX_CHUNK_TOKENS = 400
|
| 25 |
|
| 26 |
# Semantic chunking config
|
| 27 |
-
SIMILARITY_PERCENTILE = 85
|
| 28 |
|
| 29 |
# Sliding window config (fallback)
|
| 30 |
-
SLIDING_CHUNK_SIZE = 150
|
| 31 |
-
SLIDING_OVERLAP = 30
|
| 32 |
|
| 33 |
|
| 34 |
def estimate_tokens(text: str) -> int:
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
# Chunking thresholds (tokens)
|
| 22 |
+
NO_CHUNK_THRESHOLD = 200 # Texts under this: no chunking
|
| 23 |
+
SEMANTIC_THRESHOLD = 500 # Texts under this: semantic only
|
| 24 |
+
MAX_CHUNK_TOKENS = 400 # Chunks larger than this get sliding window
|
| 25 |
|
| 26 |
# Semantic chunking config
|
| 27 |
+
SIMILARITY_PERCENTILE = 85 # Split at drops below this percentile
|
| 28 |
|
| 29 |
# Sliding window config (fallback)
|
| 30 |
+
SLIDING_CHUNK_SIZE = 150 # Target tokens per sliding window chunk
|
| 31 |
+
SLIDING_OVERLAP = 30 # Token overlap between chunks
|
| 32 |
|
| 33 |
|
| 34 |
def estimate_tokens(text: str) -> int:
|
sage/core/evidence.py
CHANGED
|
@@ -101,8 +101,14 @@ def check_evidence_quality(
|
|
| 101 |
|
| 102 |
# Check thresholds using table-driven validation
|
| 103 |
thresholds = [
|
| 104 |
-
(
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
(top_score < min_score, f"low_relevance: {top_score:.3f} < {min_score}"),
|
| 107 |
]
|
| 108 |
|
|
@@ -150,17 +156,17 @@ def generate_refusal_message(
|
|
| 150 |
f"I cannot provide a confident recommendation for this product based on "
|
| 151 |
f"the available review evidence. Only {quality.chunk_count} review excerpt(s) "
|
| 152 |
f"were found, which is insufficient to make a well-grounded recommendation "
|
| 153 |
-
f
|
| 154 |
)
|
| 155 |
elif "insufficient_tokens" in reason:
|
| 156 |
return (
|
| 157 |
f"I cannot provide a meaningful recommendation for this product. "
|
| 158 |
f"The available review evidence is too brief ({quality.total_tokens} tokens) "
|
| 159 |
-
f
|
| 160 |
)
|
| 161 |
elif "low_relevance" in reason:
|
| 162 |
return (
|
| 163 |
-
f
|
| 164 |
f"the available reviews do not appear to be sufficiently relevant "
|
| 165 |
f"(relevance score: {quality.top_score:.2f}). The reviews may discuss "
|
| 166 |
f"different aspects or product features than what you're looking for."
|
|
@@ -168,5 +174,5 @@ def generate_refusal_message(
|
|
| 168 |
else:
|
| 169 |
return (
|
| 170 |
f"I cannot provide a recommendation for this product due to "
|
| 171 |
-
f
|
| 172 |
)
|
|
|
|
| 101 |
|
| 102 |
# Check thresholds using table-driven validation
|
| 103 |
thresholds = [
|
| 104 |
+
(
|
| 105 |
+
chunk_count < min_chunks,
|
| 106 |
+
f"insufficient_chunks: {chunk_count} < {min_chunks}",
|
| 107 |
+
),
|
| 108 |
+
(
|
| 109 |
+
total_tokens < min_tokens,
|
| 110 |
+
f"insufficient_tokens: {total_tokens} < {min_tokens}",
|
| 111 |
+
),
|
| 112 |
(top_score < min_score, f"low_relevance: {top_score:.3f} < {min_score}"),
|
| 113 |
]
|
| 114 |
|
|
|
|
| 156 |
f"I cannot provide a confident recommendation for this product based on "
|
| 157 |
f"the available review evidence. Only {quality.chunk_count} review excerpt(s) "
|
| 158 |
f"were found, which is insufficient to make a well-grounded recommendation "
|
| 159 |
+
f'for your query about "{query}".'
|
| 160 |
)
|
| 161 |
elif "insufficient_tokens" in reason:
|
| 162 |
return (
|
| 163 |
f"I cannot provide a meaningful recommendation for this product. "
|
| 164 |
f"The available review evidence is too brief ({quality.total_tokens} tokens) "
|
| 165 |
+
f'to support a well-grounded explanation for your query about "{query}".'
|
| 166 |
)
|
| 167 |
elif "low_relevance" in reason:
|
| 168 |
return (
|
| 169 |
+
f'I cannot recommend this product for your query about "{query}" because '
|
| 170 |
f"the available reviews do not appear to be sufficiently relevant "
|
| 171 |
f"(relevance score: {quality.top_score:.2f}). The reviews may discuss "
|
| 172 |
f"different aspects or product features than what you're looking for."
|
|
|
|
| 174 |
else:
|
| 175 |
return (
|
| 176 |
f"I cannot provide a recommendation for this product due to "
|
| 177 |
+
f'insufficient review evidence for your query about "{query}".'
|
| 178 |
)
|
sage/core/models.py
CHANGED
|
@@ -20,8 +20,10 @@ from typing import Iterator
|
|
| 20 |
# RETRIEVAL & RECOMMENDATION MODELS
|
| 21 |
# ============================================================================
|
| 22 |
|
|
|
|
| 23 |
class AggregationMethod(Enum):
|
| 24 |
"""Methods for aggregating chunk scores to product scores."""
|
|
|
|
| 25 |
MAX = "max"
|
| 26 |
MEAN = "mean"
|
| 27 |
WEIGHTED_MEAN = "weighted_mean"
|
|
@@ -35,6 +37,7 @@ class Chunk:
|
|
| 35 |
This is the unit stored in the vector database. Reviews are chunked
|
| 36 |
using semantic or sliding-window strategies based on length.
|
| 37 |
"""
|
|
|
|
| 38 |
text: str
|
| 39 |
chunk_index: int
|
| 40 |
total_chunks: int
|
|
@@ -52,6 +55,7 @@ class RetrievedChunk:
|
|
| 52 |
This is returned by semantic search and used as evidence for
|
| 53 |
explanation generation.
|
| 54 |
"""
|
|
|
|
| 55 |
text: str
|
| 56 |
score: float
|
| 57 |
product_id: str
|
|
@@ -67,6 +71,7 @@ class ProductScore:
|
|
| 67 |
Multiple chunks may belong to the same product. This dataclass
|
| 68 |
holds the aggregated score and all supporting evidence.
|
| 69 |
"""
|
|
|
|
| 70 |
product_id: str
|
| 71 |
score: float
|
| 72 |
chunk_count: int
|
|
@@ -89,6 +94,7 @@ class Recommendation:
|
|
| 89 |
This is the output of the recommendation pipeline, ready for
|
| 90 |
display or API response.
|
| 91 |
"""
|
|
|
|
| 92 |
rank: int
|
| 93 |
product_id: str
|
| 94 |
score: float
|
|
@@ -102,6 +108,7 @@ class Recommendation:
|
|
| 102 |
# COLD START MODELS
|
| 103 |
# ============================================================================
|
| 104 |
|
|
|
|
| 105 |
@dataclass
|
| 106 |
class UserPreferences:
|
| 107 |
"""
|
|
@@ -110,6 +117,7 @@ class UserPreferences:
|
|
| 110 |
In production, these would be collected via an onboarding flow:
|
| 111 |
"What categories interest you?" "What's your budget?" etc.
|
| 112 |
"""
|
|
|
|
| 113 |
categories: list[str] | None = None
|
| 114 |
budget: str | None = None # "low", "medium", "high", or specific like "$50-100"
|
| 115 |
priorities: list[str] | None = None # ["quality", "value", "durability"]
|
|
@@ -123,6 +131,7 @@ class NewItem:
|
|
| 123 |
|
| 124 |
In production, this would come from the product catalog.
|
| 125 |
"""
|
|
|
|
| 126 |
product_id: str
|
| 127 |
title: str
|
| 128 |
description: str | None = None
|
|
@@ -136,6 +145,7 @@ class NewItem:
|
|
| 136 |
# EXPLANATION MODELS
|
| 137 |
# ============================================================================
|
| 138 |
|
|
|
|
| 139 |
@dataclass
|
| 140 |
class ExplanationResult:
|
| 141 |
"""
|
|
@@ -144,6 +154,7 @@ class ExplanationResult:
|
|
| 144 |
Contains the generated explanation along with evidence attribution
|
| 145 |
for traceability and faithfulness verification.
|
| 146 |
"""
|
|
|
|
| 147 |
explanation: str
|
| 148 |
product_id: str
|
| 149 |
query: str
|
|
@@ -174,6 +185,7 @@ class StreamingExplanation:
|
|
| 174 |
print(token, end="", flush=True)
|
| 175 |
result = stream.get_complete_result()
|
| 176 |
"""
|
|
|
|
| 177 |
token_iterator: Iterator[str]
|
| 178 |
product_id: str
|
| 179 |
query: str
|
|
@@ -215,6 +227,7 @@ class EvidenceQuality:
|
|
| 215 |
when evidence is too thin. Thin evidence (1 chunk, few tokens)
|
| 216 |
correlates strongly with LLM overclaiming.
|
| 217 |
"""
|
|
|
|
| 218 |
is_sufficient: bool
|
| 219 |
chunk_count: int
|
| 220 |
total_tokens: int
|
|
@@ -226,9 +239,11 @@ class EvidenceQuality:
|
|
| 226 |
# VERIFICATION MODELS
|
| 227 |
# ============================================================================
|
| 228 |
|
|
|
|
| 229 |
@dataclass
|
| 230 |
class QuoteVerification:
|
| 231 |
"""Result of verifying a single quoted claim against evidence."""
|
|
|
|
| 232 |
quote: str
|
| 233 |
found: bool
|
| 234 |
source_text: str | None = None # Which evidence text contained it
|
|
@@ -243,6 +258,7 @@ class VerificationResult:
|
|
| 243 |
exists in the provided evidence. Catches wrong attribution where
|
| 244 |
LLM cites quotes that don't exist.
|
| 245 |
"""
|
|
|
|
| 246 |
all_verified: bool
|
| 247 |
quotes_found: int
|
| 248 |
quotes_missing: int
|
|
@@ -254,6 +270,7 @@ class VerificationResult:
|
|
| 254 |
# HALLUCINATION DETECTION MODELS
|
| 255 |
# ============================================================================
|
| 256 |
|
|
|
|
| 257 |
@dataclass
|
| 258 |
class HallucinationResult:
|
| 259 |
"""
|
|
@@ -263,6 +280,7 @@ class HallucinationResult:
|
|
| 263 |
consistency between evidence (premise) and explanation (hypothesis).
|
| 264 |
Score < 0.5 indicates hallucination.
|
| 265 |
"""
|
|
|
|
| 266 |
score: float
|
| 267 |
is_hallucinated: bool
|
| 268 |
threshold: float
|
|
@@ -273,6 +291,7 @@ class HallucinationResult:
|
|
| 273 |
@dataclass
|
| 274 |
class ClaimResult:
|
| 275 |
"""Result of hallucination check for a single claim."""
|
|
|
|
| 276 |
claim: str
|
| 277 |
score: float
|
| 278 |
is_hallucinated: bool
|
|
@@ -286,6 +305,7 @@ class AgreementReport:
|
|
| 286 |
Useful for understanding when the two methods disagree and
|
| 287 |
calibrating thresholds.
|
| 288 |
"""
|
|
|
|
| 289 |
n_samples: int
|
| 290 |
agreement_rate: float # Proportion where both agree on pass/fail
|
| 291 |
hhem_pass_rate: float
|
|
@@ -306,6 +326,7 @@ class AdjustedFaithfulnessReport:
|
|
| 306 |
Refusals (e.g., "I cannot recommend...") are correct LLM behavior
|
| 307 |
but get penalized by HHEM. This report adjusts for that.
|
| 308 |
"""
|
|
|
|
| 309 |
n_total: int
|
| 310 |
n_refusals: int
|
| 311 |
n_evaluated: int # n_total - n_refusals
|
|
@@ -334,6 +355,7 @@ class ClaimLevelReport:
|
|
| 334 |
- min_score: Lowest scoring claim (weakest grounding)
|
| 335 |
- pass_rate: Proportion of claims scoring >= threshold
|
| 336 |
"""
|
|
|
|
| 337 |
n_explanations: int
|
| 338 |
n_claims: int
|
| 339 |
avg_score: float
|
|
@@ -368,6 +390,7 @@ class MultiMetricFaithfulnessReport:
|
|
| 368 |
individually. Full-explanation HHEM (57%) measures structural patterns
|
| 369 |
that HHEM was trained on, not actual hallucination."
|
| 370 |
"""
|
|
|
|
| 371 |
n_samples: int
|
| 372 |
# Quote verification (lexical)
|
| 373 |
quote_verification_rate: float
|
|
@@ -410,19 +433,19 @@ class MultiMetricFaithfulnessReport:
|
|
| 410 |
"=" * 50,
|
| 411 |
"",
|
| 412 |
"Quote Verification (lexical grounding):",
|
| 413 |
-
f" Pass rate: {self.quote_verification_rate*100:.1f}% ({self.quotes_found}/{self.quotes_total})",
|
| 414 |
"",
|
| 415 |
"Claim-Level HHEM (semantic grounding per claim):",
|
| 416 |
-
f" Pass rate: {self.claim_level_pass_rate*100:.1f}%",
|
| 417 |
f" Avg score: {self.claim_level_avg_score:.3f}",
|
| 418 |
f" Min score: {self.claim_level_min_score:.3f}",
|
| 419 |
"",
|
| 420 |
"Full-Explanation HHEM (structural compatibility):",
|
| 421 |
-
f" Pass rate: {self.full_explanation_pass_rate*100:.1f}%",
|
| 422 |
f" Avg score: {self.full_explanation_avg_score:.3f}",
|
| 423 |
"",
|
| 424 |
"-" * 50,
|
| 425 |
-
f"PRIMARY METRIC ({self.primary_metric}): {self.claim_level_pass_rate*100:.1f}%",
|
| 426 |
]
|
| 427 |
return "\n".join(lines)
|
| 428 |
|
|
@@ -431,9 +454,11 @@ class MultiMetricFaithfulnessReport:
|
|
| 431 |
# FAITHFULNESS EVALUATION MODELS (RAGAS)
|
| 432 |
# ============================================================================
|
| 433 |
|
|
|
|
| 434 |
@dataclass
|
| 435 |
class FaithfulnessResult:
|
| 436 |
"""Result of RAGAS faithfulness evaluation for a single explanation."""
|
|
|
|
| 437 |
score: float
|
| 438 |
query: str
|
| 439 |
explanation: str
|
|
@@ -444,6 +469,7 @@ class FaithfulnessResult:
|
|
| 444 |
@dataclass
|
| 445 |
class FaithfulnessReport:
|
| 446 |
"""Aggregate report for batch faithfulness evaluation (legacy format)."""
|
|
|
|
| 447 |
mean_score: float
|
| 448 |
min_score: float
|
| 449 |
max_score: float
|
|
@@ -459,6 +485,7 @@ class FaithfulnessReport:
|
|
| 459 |
# EVALUATION METRICS MODELS
|
| 460 |
# ============================================================================
|
| 461 |
|
|
|
|
| 462 |
@dataclass
|
| 463 |
class EvalCase:
|
| 464 |
"""
|
|
@@ -470,6 +497,7 @@ class EvalCase:
|
|
| 470 |
For binary relevance, use 1 for relevant, 0 for not.
|
| 471 |
user_id: Optional user identifier for the case.
|
| 472 |
"""
|
|
|
|
| 473 |
query: str
|
| 474 |
relevant_items: dict[str, float]
|
| 475 |
user_id: str | None = None
|
|
@@ -483,6 +511,7 @@ class EvalCase:
|
|
| 483 |
@dataclass
|
| 484 |
class EvalResult:
|
| 485 |
"""Results from evaluating a single recommendation list."""
|
|
|
|
| 486 |
ndcg: float = 0.0
|
| 487 |
hit: float = 0.0
|
| 488 |
mrr: float = 0.0
|
|
@@ -498,6 +527,7 @@ class MetricsReport:
|
|
| 498 |
Includes both accuracy metrics (NDCG, Hit, MRR) and
|
| 499 |
beyond-accuracy metrics (diversity, coverage, novelty).
|
| 500 |
"""
|
|
|
|
| 501 |
n_cases: int = 0
|
| 502 |
ndcg_at_k: float = 0.0
|
| 503 |
hit_at_k: float = 0.0
|
|
|
|
| 20 |
# RETRIEVAL & RECOMMENDATION MODELS
|
| 21 |
# ============================================================================
|
| 22 |
|
| 23 |
+
|
| 24 |
class AggregationMethod(Enum):
|
| 25 |
"""Methods for aggregating chunk scores to product scores."""
|
| 26 |
+
|
| 27 |
MAX = "max"
|
| 28 |
MEAN = "mean"
|
| 29 |
WEIGHTED_MEAN = "weighted_mean"
|
|
|
|
| 37 |
This is the unit stored in the vector database. Reviews are chunked
|
| 38 |
using semantic or sliding-window strategies based on length.
|
| 39 |
"""
|
| 40 |
+
|
| 41 |
text: str
|
| 42 |
chunk_index: int
|
| 43 |
total_chunks: int
|
|
|
|
| 55 |
This is returned by semantic search and used as evidence for
|
| 56 |
explanation generation.
|
| 57 |
"""
|
| 58 |
+
|
| 59 |
text: str
|
| 60 |
score: float
|
| 61 |
product_id: str
|
|
|
|
| 71 |
Multiple chunks may belong to the same product. This dataclass
|
| 72 |
holds the aggregated score and all supporting evidence.
|
| 73 |
"""
|
| 74 |
+
|
| 75 |
product_id: str
|
| 76 |
score: float
|
| 77 |
chunk_count: int
|
|
|
|
| 94 |
This is the output of the recommendation pipeline, ready for
|
| 95 |
display or API response.
|
| 96 |
"""
|
| 97 |
+
|
| 98 |
rank: int
|
| 99 |
product_id: str
|
| 100 |
score: float
|
|
|
|
| 108 |
# COLD START MODELS
|
| 109 |
# ============================================================================
|
| 110 |
|
| 111 |
+
|
| 112 |
@dataclass
|
| 113 |
class UserPreferences:
|
| 114 |
"""
|
|
|
|
| 117 |
In production, these would be collected via an onboarding flow:
|
| 118 |
"What categories interest you?" "What's your budget?" etc.
|
| 119 |
"""
|
| 120 |
+
|
| 121 |
categories: list[str] | None = None
|
| 122 |
budget: str | None = None # "low", "medium", "high", or specific like "$50-100"
|
| 123 |
priorities: list[str] | None = None # ["quality", "value", "durability"]
|
|
|
|
| 131 |
|
| 132 |
In production, this would come from the product catalog.
|
| 133 |
"""
|
| 134 |
+
|
| 135 |
product_id: str
|
| 136 |
title: str
|
| 137 |
description: str | None = None
|
|
|
|
| 145 |
# EXPLANATION MODELS
|
| 146 |
# ============================================================================
|
| 147 |
|
| 148 |
+
|
| 149 |
@dataclass
|
| 150 |
class ExplanationResult:
|
| 151 |
"""
|
|
|
|
| 154 |
Contains the generated explanation along with evidence attribution
|
| 155 |
for traceability and faithfulness verification.
|
| 156 |
"""
|
| 157 |
+
|
| 158 |
explanation: str
|
| 159 |
product_id: str
|
| 160 |
query: str
|
|
|
|
| 185 |
print(token, end="", flush=True)
|
| 186 |
result = stream.get_complete_result()
|
| 187 |
"""
|
| 188 |
+
|
| 189 |
token_iterator: Iterator[str]
|
| 190 |
product_id: str
|
| 191 |
query: str
|
|
|
|
| 227 |
when evidence is too thin. Thin evidence (1 chunk, few tokens)
|
| 228 |
correlates strongly with LLM overclaiming.
|
| 229 |
"""
|
| 230 |
+
|
| 231 |
is_sufficient: bool
|
| 232 |
chunk_count: int
|
| 233 |
total_tokens: int
|
|
|
|
| 239 |
# VERIFICATION MODELS
|
| 240 |
# ============================================================================
|
| 241 |
|
| 242 |
+
|
| 243 |
@dataclass
|
| 244 |
class QuoteVerification:
|
| 245 |
"""Result of verifying a single quoted claim against evidence."""
|
| 246 |
+
|
| 247 |
quote: str
|
| 248 |
found: bool
|
| 249 |
source_text: str | None = None # Which evidence text contained it
|
|
|
|
| 258 |
exists in the provided evidence. Catches wrong attribution where
|
| 259 |
LLM cites quotes that don't exist.
|
| 260 |
"""
|
| 261 |
+
|
| 262 |
all_verified: bool
|
| 263 |
quotes_found: int
|
| 264 |
quotes_missing: int
|
|
|
|
| 270 |
# HALLUCINATION DETECTION MODELS
|
| 271 |
# ============================================================================
|
| 272 |
|
| 273 |
+
|
| 274 |
@dataclass
|
| 275 |
class HallucinationResult:
|
| 276 |
"""
|
|
|
|
| 280 |
consistency between evidence (premise) and explanation (hypothesis).
|
| 281 |
Score < 0.5 indicates hallucination.
|
| 282 |
"""
|
| 283 |
+
|
| 284 |
score: float
|
| 285 |
is_hallucinated: bool
|
| 286 |
threshold: float
|
|
|
|
| 291 |
@dataclass
|
| 292 |
class ClaimResult:
|
| 293 |
"""Result of hallucination check for a single claim."""
|
| 294 |
+
|
| 295 |
claim: str
|
| 296 |
score: float
|
| 297 |
is_hallucinated: bool
|
|
|
|
| 305 |
Useful for understanding when the two methods disagree and
|
| 306 |
calibrating thresholds.
|
| 307 |
"""
|
| 308 |
+
|
| 309 |
n_samples: int
|
| 310 |
agreement_rate: float # Proportion where both agree on pass/fail
|
| 311 |
hhem_pass_rate: float
|
|
|
|
| 326 |
Refusals (e.g., "I cannot recommend...") are correct LLM behavior
|
| 327 |
but get penalized by HHEM. This report adjusts for that.
|
| 328 |
"""
|
| 329 |
+
|
| 330 |
n_total: int
|
| 331 |
n_refusals: int
|
| 332 |
n_evaluated: int # n_total - n_refusals
|
|
|
|
| 355 |
- min_score: Lowest scoring claim (weakest grounding)
|
| 356 |
- pass_rate: Proportion of claims scoring >= threshold
|
| 357 |
"""
|
| 358 |
+
|
| 359 |
n_explanations: int
|
| 360 |
n_claims: int
|
| 361 |
avg_score: float
|
|
|
|
| 390 |
individually. Full-explanation HHEM (57%) measures structural patterns
|
| 391 |
that HHEM was trained on, not actual hallucination."
|
| 392 |
"""
|
| 393 |
+
|
| 394 |
n_samples: int
|
| 395 |
# Quote verification (lexical)
|
| 396 |
quote_verification_rate: float
|
|
|
|
| 433 |
"=" * 50,
|
| 434 |
"",
|
| 435 |
"Quote Verification (lexical grounding):",
|
| 436 |
+
f" Pass rate: {self.quote_verification_rate * 100:.1f}% ({self.quotes_found}/{self.quotes_total})",
|
| 437 |
"",
|
| 438 |
"Claim-Level HHEM (semantic grounding per claim):",
|
| 439 |
+
f" Pass rate: {self.claim_level_pass_rate * 100:.1f}%",
|
| 440 |
f" Avg score: {self.claim_level_avg_score:.3f}",
|
| 441 |
f" Min score: {self.claim_level_min_score:.3f}",
|
| 442 |
"",
|
| 443 |
"Full-Explanation HHEM (structural compatibility):",
|
| 444 |
+
f" Pass rate: {self.full_explanation_pass_rate * 100:.1f}%",
|
| 445 |
f" Avg score: {self.full_explanation_avg_score:.3f}",
|
| 446 |
"",
|
| 447 |
"-" * 50,
|
| 448 |
+
f"PRIMARY METRIC ({self.primary_metric}): {self.claim_level_pass_rate * 100:.1f}%",
|
| 449 |
]
|
| 450 |
return "\n".join(lines)
|
| 451 |
|
|
|
|
| 454 |
# FAITHFULNESS EVALUATION MODELS (RAGAS)
|
| 455 |
# ============================================================================
|
| 456 |
|
| 457 |
+
|
| 458 |
@dataclass
|
| 459 |
class FaithfulnessResult:
|
| 460 |
"""Result of RAGAS faithfulness evaluation for a single explanation."""
|
| 461 |
+
|
| 462 |
score: float
|
| 463 |
query: str
|
| 464 |
explanation: str
|
|
|
|
| 469 |
@dataclass
|
| 470 |
class FaithfulnessReport:
|
| 471 |
"""Aggregate report for batch faithfulness evaluation (legacy format)."""
|
| 472 |
+
|
| 473 |
mean_score: float
|
| 474 |
min_score: float
|
| 475 |
max_score: float
|
|
|
|
| 485 |
# EVALUATION METRICS MODELS
|
| 486 |
# ============================================================================
|
| 487 |
|
| 488 |
+
|
| 489 |
@dataclass
|
| 490 |
class EvalCase:
|
| 491 |
"""
|
|
|
|
| 497 |
For binary relevance, use 1 for relevant, 0 for not.
|
| 498 |
user_id: Optional user identifier for the case.
|
| 499 |
"""
|
| 500 |
+
|
| 501 |
query: str
|
| 502 |
relevant_items: dict[str, float]
|
| 503 |
user_id: str | None = None
|
|
|
|
| 511 |
@dataclass
|
| 512 |
class EvalResult:
|
| 513 |
"""Results from evaluating a single recommendation list."""
|
| 514 |
+
|
| 515 |
ndcg: float = 0.0
|
| 516 |
hit: float = 0.0
|
| 517 |
mrr: float = 0.0
|
|
|
|
| 527 |
Includes both accuracy metrics (NDCG, Hit, MRR) and
|
| 528 |
beyond-accuracy metrics (diversity, coverage, novelty).
|
| 529 |
"""
|
| 530 |
+
|
| 531 |
n_cases: int = 0
|
| 532 |
ndcg_at_k: float = 0.0
|
| 533 |
hit_at_k: float = 0.0
|
sage/core/prompts.py
CHANGED
|
@@ -76,7 +76,7 @@ def format_evidence(
|
|
| 76 |
return "(No review evidence available)"
|
| 77 |
|
| 78 |
return "\n\n".join(
|
| 79 |
-
f
|
| 80 |
for chunk in chunks[:max_chunks]
|
| 81 |
)
|
| 82 |
|
|
|
|
| 76 |
return "(No review evidence available)"
|
| 77 |
|
| 78 |
return "\n\n".join(
|
| 79 |
+
f'[{chunk.review_id}] ({int(chunk.rating or 0)}/5 stars): "{chunk.text}"'
|
| 80 |
for chunk in chunks[:max_chunks]
|
| 81 |
)
|
| 82 |
|
sage/core/verification.py
CHANGED
|
@@ -86,9 +86,9 @@ def extract_quotes(text: str, min_length: int = 4) -> list[str]:
|
|
| 86 |
List of unique quoted strings found in the text.
|
| 87 |
"""
|
| 88 |
patterns = [
|
| 89 |
-
r'"([^"]+)"',
|
| 90 |
-
r'"([^"]+)"',
|
| 91 |
-
r"'([^']+)'",
|
| 92 |
]
|
| 93 |
|
| 94 |
quotes = []
|
|
@@ -206,6 +206,7 @@ def verify_explanation(
|
|
| 206 |
# Citation ID Verification
|
| 207 |
# =============================================================================
|
| 208 |
|
|
|
|
| 209 |
@dataclass
|
| 210 |
class CitationResult:
|
| 211 |
"""Result of verifying a single citation."""
|
|
@@ -256,12 +257,12 @@ def extract_citations(text: str) -> list[tuple[str, str | None]]:
|
|
| 256 |
quote_text = match.group(1)
|
| 257 |
citation_block = match.group(2)
|
| 258 |
# Split multiple citations like "review_123, review_456"
|
| 259 |
-
for citation_id in re.findall(r
|
| 260 |
citations.append((citation_id, quote_text))
|
| 261 |
|
| 262 |
# Pattern for standalone citations not preceded by a quote
|
| 263 |
# Find all citations, then filter out ones already captured with quotes
|
| 264 |
-
all_citation_ids = set(re.findall(r
|
| 265 |
quoted_citation_ids = {c[0] for c in citations}
|
| 266 |
standalone_ids = all_citation_ids - quoted_citation_ids
|
| 267 |
|
|
@@ -294,9 +295,7 @@ def verify_citation(
|
|
| 294 |
"""
|
| 295 |
# Collect all chunks belonging to this citation ID (a single review
|
| 296 |
# may produce multiple chunks after splitting long reviews).
|
| 297 |
-
matching_indices = [
|
| 298 |
-
i for i, eid in enumerate(evidence_ids) if eid == citation_id
|
| 299 |
-
]
|
| 300 |
|
| 301 |
if not matching_indices:
|
| 302 |
return CitationResult(
|
|
|
|
| 86 |
List of unique quoted strings found in the text.
|
| 87 |
"""
|
| 88 |
patterns = [
|
| 89 |
+
r'"([^"]+)"', # Regular double quotes
|
| 90 |
+
r'"([^"]+)"', # Curly double quotes
|
| 91 |
+
r"'([^']+)'", # Single quotes
|
| 92 |
]
|
| 93 |
|
| 94 |
quotes = []
|
|
|
|
| 206 |
# Citation ID Verification
|
| 207 |
# =============================================================================
|
| 208 |
|
| 209 |
+
|
| 210 |
@dataclass
|
| 211 |
class CitationResult:
|
| 212 |
"""Result of verifying a single citation."""
|
|
|
|
| 257 |
quote_text = match.group(1)
|
| 258 |
citation_block = match.group(2)
|
| 259 |
# Split multiple citations like "review_123, review_456"
|
| 260 |
+
for citation_id in re.findall(r"review_\d+", citation_block):
|
| 261 |
citations.append((citation_id, quote_text))
|
| 262 |
|
| 263 |
# Pattern for standalone citations not preceded by a quote
|
| 264 |
# Find all citations, then filter out ones already captured with quotes
|
| 265 |
+
all_citation_ids = set(re.findall(r"review_\d+", text))
|
| 266 |
quoted_citation_ids = {c[0] for c in citations}
|
| 267 |
standalone_ids = all_citation_ids - quoted_citation_ids
|
| 268 |
|
|
|
|
| 295 |
"""
|
| 296 |
# Collect all chunks belonging to this citation ID (a single review
|
| 297 |
# may produce multiple chunks after splitting long reviews).
|
| 298 |
+
matching_indices = [i for i, eid in enumerate(evidence_ids) if eid == citation_id]
|
|
|
|
|
|
|
| 299 |
|
| 300 |
if not matching_indices:
|
| 301 |
return CitationResult(
|
sage/services/__init__.py
CHANGED
|
@@ -59,6 +59,7 @@ _LAZY_IMPORTS = {
|
|
| 59 |
def __getattr__(name: str):
|
| 60 |
if name in _LAZY_IMPORTS:
|
| 61 |
import importlib
|
|
|
|
| 62 |
module = importlib.import_module(_LAZY_IMPORTS[name])
|
| 63 |
return getattr(module, name)
|
| 64 |
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
|
|
|
| 59 |
def __getattr__(name: str):
|
| 60 |
if name in _LAZY_IMPORTS:
|
| 61 |
import importlib
|
| 62 |
+
|
| 63 |
module = importlib.import_module(_LAZY_IMPORTS[name])
|
| 64 |
return getattr(module, name)
|
| 65 |
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
sage/services/baselines.py
CHANGED
|
@@ -73,9 +73,7 @@ class PopularityBaseline:
|
|
| 73 |
counts = Counter(i[item_key] for i in interactions if item_key in i)
|
| 74 |
|
| 75 |
# Sort by popularity (descending)
|
| 76 |
-
self.ranked_items = [
|
| 77 |
-
item for item, _ in counts.most_common()
|
| 78 |
-
]
|
| 79 |
|
| 80 |
self.popularity = counts
|
| 81 |
|
|
@@ -119,9 +117,9 @@ class ItemKNNBaseline:
|
|
| 119 |
embedder: E5Embedder instance for query embedding.
|
| 120 |
"""
|
| 121 |
self.product_ids = list(product_embeddings.keys())
|
| 122 |
-
self.embeddings = np.array(
|
| 123 |
-
product_embeddings[pid] for pid in self.product_ids
|
| 124 |
-
|
| 125 |
|
| 126 |
# Normalize embeddings for cosine similarity
|
| 127 |
norms = np.linalg.norm(self.embeddings, axis=1, keepdims=True)
|
|
@@ -143,6 +141,7 @@ class ItemKNNBaseline:
|
|
| 143 |
"""
|
| 144 |
if self.embedder is None:
|
| 145 |
from sage.adapters.embeddings import get_embedder
|
|
|
|
| 146 |
self.embedder = get_embedder()
|
| 147 |
|
| 148 |
# Embed query
|
|
@@ -188,7 +187,9 @@ def build_product_embeddings(
|
|
| 188 |
elif aggregation == "max":
|
| 189 |
agg_emb = product_embs.max(axis=0)
|
| 190 |
else:
|
| 191 |
-
raise ValueError(
|
|
|
|
|
|
|
| 192 |
|
| 193 |
# Normalize
|
| 194 |
agg_emb = agg_emb / (np.linalg.norm(agg_emb) + 1e-8)
|
|
|
|
| 73 |
counts = Counter(i[item_key] for i in interactions if item_key in i)
|
| 74 |
|
| 75 |
# Sort by popularity (descending)
|
| 76 |
+
self.ranked_items = [item for item, _ in counts.most_common()]
|
|
|
|
|
|
|
| 77 |
|
| 78 |
self.popularity = counts
|
| 79 |
|
|
|
|
| 117 |
embedder: E5Embedder instance for query embedding.
|
| 118 |
"""
|
| 119 |
self.product_ids = list(product_embeddings.keys())
|
| 120 |
+
self.embeddings = np.array(
|
| 121 |
+
[product_embeddings[pid] for pid in self.product_ids]
|
| 122 |
+
)
|
| 123 |
|
| 124 |
# Normalize embeddings for cosine similarity
|
| 125 |
norms = np.linalg.norm(self.embeddings, axis=1, keepdims=True)
|
|
|
|
| 141 |
"""
|
| 142 |
if self.embedder is None:
|
| 143 |
from sage.adapters.embeddings import get_embedder
|
| 144 |
+
|
| 145 |
self.embedder = get_embedder()
|
| 146 |
|
| 147 |
# Embed query
|
|
|
|
| 187 |
elif aggregation == "max":
|
| 188 |
agg_emb = product_embs.max(axis=0)
|
| 189 |
else:
|
| 190 |
+
raise ValueError(
|
| 191 |
+
f"Unknown aggregation method: {aggregation}. Use 'mean' or 'max'."
|
| 192 |
+
)
|
| 193 |
|
| 194 |
# Normalize
|
| 195 |
agg_emb = agg_emb / (np.linalg.norm(agg_emb) + 1e-8)
|
sage/services/cache.py
CHANGED
|
@@ -27,6 +27,7 @@ logger = get_logger(__name__)
|
|
| 27 |
# Cache entry
|
| 28 |
# ---------------------------------------------------------------------------
|
| 29 |
|
|
|
|
| 30 |
@dataclass
|
| 31 |
class _CacheEntry:
|
| 32 |
"""Single cached result with metadata for eviction."""
|
|
@@ -43,6 +44,7 @@ class _CacheEntry:
|
|
| 43 |
# Cache stats
|
| 44 |
# ---------------------------------------------------------------------------
|
| 45 |
|
|
|
|
| 46 |
@dataclass
|
| 47 |
class CacheStats:
|
| 48 |
"""Snapshot of cache performance metrics."""
|
|
@@ -73,6 +75,7 @@ class CacheStats:
|
|
| 73 |
# Semantic cache
|
| 74 |
# ---------------------------------------------------------------------------
|
| 75 |
|
|
|
|
| 76 |
class SemanticCache:
|
| 77 |
"""Thread-safe in-memory cache with exact-match and semantic-similarity layers.
|
| 78 |
|
|
@@ -116,7 +119,9 @@ class SemanticCache:
|
|
| 116 |
# ------------------------------------------------------------------
|
| 117 |
|
| 118 |
def get(
|
| 119 |
-
self,
|
|
|
|
|
|
|
| 120 |
) -> tuple[dict | None, str]:
|
| 121 |
"""Look up a cached result.
|
| 122 |
|
|
@@ -232,7 +237,8 @@ class SemanticCache:
|
|
| 232 |
# ------------------------------------------------------------------
|
| 233 |
|
| 234 |
def _find_semantic_match(
|
| 235 |
-
self,
|
|
|
|
| 236 |
) -> tuple[_CacheEntry, float]:
|
| 237 |
"""Find the best semantic match among cached entries.
|
| 238 |
|
|
|
|
| 27 |
# Cache entry
|
| 28 |
# ---------------------------------------------------------------------------
|
| 29 |
|
| 30 |
+
|
| 31 |
@dataclass
|
| 32 |
class _CacheEntry:
|
| 33 |
"""Single cached result with metadata for eviction."""
|
|
|
|
| 44 |
# Cache stats
|
| 45 |
# ---------------------------------------------------------------------------
|
| 46 |
|
| 47 |
+
|
| 48 |
@dataclass
|
| 49 |
class CacheStats:
|
| 50 |
"""Snapshot of cache performance metrics."""
|
|
|
|
| 75 |
# Semantic cache
|
| 76 |
# ---------------------------------------------------------------------------
|
| 77 |
|
| 78 |
+
|
| 79 |
class SemanticCache:
|
| 80 |
"""Thread-safe in-memory cache with exact-match and semantic-similarity layers.
|
| 81 |
|
|
|
|
| 119 |
# ------------------------------------------------------------------
|
| 120 |
|
| 121 |
def get(
|
| 122 |
+
self,
|
| 123 |
+
query: str,
|
| 124 |
+
query_embedding: np.ndarray | None = None,
|
| 125 |
) -> tuple[dict | None, str]:
|
| 126 |
"""Look up a cached result.
|
| 127 |
|
|
|
|
| 237 |
# ------------------------------------------------------------------
|
| 238 |
|
| 239 |
def _find_semantic_match(
|
| 240 |
+
self,
|
| 241 |
+
query_embedding: np.ndarray,
|
| 242 |
) -> tuple[_CacheEntry, float]:
|
| 243 |
"""Find the best semantic match among cached entries.
|
| 244 |
|
sage/services/cold_start.py
CHANGED
|
@@ -102,6 +102,7 @@ class ColdStartService:
|
|
| 102 |
"""Lazy-load retrieval service."""
|
| 103 |
if self._retrieval is None:
|
| 104 |
from sage.services.retrieval import RetrievalService
|
|
|
|
| 105 |
self._retrieval = RetrievalService(collection_name=self.collection_name)
|
| 106 |
return self._retrieval
|
| 107 |
|
|
@@ -179,7 +180,9 @@ class ColdStartService:
|
|
| 179 |
item_text = " ".join(text_parts)
|
| 180 |
|
| 181 |
# Embed as a passage
|
| 182 |
-
item_embedding = self.embedder.embed_passages([item_text], show_progress=False)[
|
|
|
|
|
|
|
| 183 |
|
| 184 |
# Search for similar chunks
|
| 185 |
results = search(
|
|
|
|
| 102 |
"""Lazy-load retrieval service."""
|
| 103 |
if self._retrieval is None:
|
| 104 |
from sage.services.retrieval import RetrievalService
|
| 105 |
+
|
| 106 |
self._retrieval = RetrievalService(collection_name=self.collection_name)
|
| 107 |
return self._retrieval
|
| 108 |
|
|
|
|
| 180 |
item_text = " ".join(text_parts)
|
| 181 |
|
| 182 |
# Embed as a passage
|
| 183 |
+
item_embedding = self.embedder.embed_passages([item_text], show_progress=False)[
|
| 184 |
+
0
|
| 185 |
+
]
|
| 186 |
|
| 187 |
# Search for similar chunks
|
| 188 |
results = search(
|
sage/services/evaluation.py
CHANGED
|
@@ -263,7 +263,9 @@ class EvaluationService:
|
|
| 263 |
ndcg_at_k=float(np.mean(ndcg_scores)) if ndcg_scores else 0.0,
|
| 264 |
hit_at_k=float(np.mean(hit_scores)) if hit_scores else 0.0,
|
| 265 |
mrr=float(np.mean(mrr_scores)) if mrr_scores else 0.0,
|
| 266 |
-
precision_at_k=float(np.mean(precision_scores))
|
|
|
|
|
|
|
| 267 |
recall_at_k=float(np.mean(recall_scores)) if recall_scores else 0.0,
|
| 268 |
diversity=float(np.mean(diversity_scores)) if diversity_scores else 0.0,
|
| 269 |
novelty=float(np.mean(novelty_scores)) if novelty_scores else 0.0,
|
|
|
|
| 263 |
ndcg_at_k=float(np.mean(ndcg_scores)) if ndcg_scores else 0.0,
|
| 264 |
hit_at_k=float(np.mean(hit_scores)) if hit_scores else 0.0,
|
| 265 |
mrr=float(np.mean(mrr_scores)) if mrr_scores else 0.0,
|
| 266 |
+
precision_at_k=float(np.mean(precision_scores))
|
| 267 |
+
if precision_scores
|
| 268 |
+
else 0.0,
|
| 269 |
recall_at_k=float(np.mean(recall_scores)) if recall_scores else 0.0,
|
| 270 |
diversity=float(np.mean(diversity_scores)) if diversity_scores else 0.0,
|
| 271 |
novelty=float(np.mean(novelty_scores)) if novelty_scores else 0.0,
|
sage/services/explanation.py
CHANGED
|
@@ -109,8 +109,8 @@ class Explainer:
|
|
| 109 |
Returns:
|
| 110 |
(explanation, tokens, evidence_texts, evidence_ids, user_prompt).
|
| 111 |
"""
|
| 112 |
-
system_prompt, user_prompt, evidence_texts, evidence_ids =
|
| 113 |
-
query, product, max_evidence
|
| 114 |
)
|
| 115 |
|
| 116 |
t0 = time.perf_counter()
|
|
@@ -120,7 +120,9 @@ class Explainer:
|
|
| 120 |
)
|
| 121 |
logger.info(
|
| 122 |
"LLM generation for %s: %.0fms, %d tokens",
|
| 123 |
-
product.product_id,
|
|
|
|
|
|
|
| 124 |
)
|
| 125 |
|
| 126 |
return explanation, tokens, evidence_texts, evidence_ids, user_prompt
|
|
@@ -241,8 +243,8 @@ class Explainer:
|
|
| 241 |
f"Client {type(self.client).__name__} does not support streaming."
|
| 242 |
)
|
| 243 |
|
| 244 |
-
system_prompt, user_prompt, evidence_texts, evidence_ids =
|
| 245 |
-
query, product, max_evidence
|
| 246 |
)
|
| 247 |
|
| 248 |
token_iterator = self.client.generate_stream(
|
|
@@ -335,7 +337,11 @@ class Explainer:
|
|
| 335 |
"""
|
| 336 |
return [
|
| 337 |
self.generate_explanation(
|
| 338 |
-
query,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 339 |
)
|
| 340 |
for product in products
|
| 341 |
]
|
|
|
|
| 109 |
Returns:
|
| 110 |
(explanation, tokens, evidence_texts, evidence_ids, user_prompt).
|
| 111 |
"""
|
| 112 |
+
system_prompt, user_prompt, evidence_texts, evidence_ids = (
|
| 113 |
+
build_explanation_prompt(query, product, max_evidence)
|
| 114 |
)
|
| 115 |
|
| 116 |
t0 = time.perf_counter()
|
|
|
|
| 120 |
)
|
| 121 |
logger.info(
|
| 122 |
"LLM generation for %s: %.0fms, %d tokens",
|
| 123 |
+
product.product_id,
|
| 124 |
+
(time.perf_counter() - t0) * 1000,
|
| 125 |
+
tokens,
|
| 126 |
)
|
| 127 |
|
| 128 |
return explanation, tokens, evidence_texts, evidence_ids, user_prompt
|
|
|
|
| 243 |
f"Client {type(self.client).__name__} does not support streaming."
|
| 244 |
)
|
| 245 |
|
| 246 |
+
system_prompt, user_prompt, evidence_texts, evidence_ids = (
|
| 247 |
+
build_explanation_prompt(query, product, max_evidence)
|
| 248 |
)
|
| 249 |
|
| 250 |
token_iterator = self.client.generate_stream(
|
|
|
|
| 337 |
"""
|
| 338 |
return [
|
| 339 |
self.generate_explanation(
|
| 340 |
+
query,
|
| 341 |
+
product,
|
| 342 |
+
max_evidence,
|
| 343 |
+
enforce_quality_gate,
|
| 344 |
+
enforce_forbidden_phrases,
|
| 345 |
)
|
| 346 |
for product in products
|
| 347 |
]
|
sage/services/faithfulness.py
CHANGED
|
@@ -73,7 +73,9 @@ def create_ragas_sample(query: str, explanation: str, evidence_texts: list[str])
|
|
| 73 |
)
|
| 74 |
|
| 75 |
|
| 76 |
-
def _explanation_results_to_samples(
|
|
|
|
|
|
|
| 77 |
"""Convert ExplanationResults to RAGAS samples."""
|
| 78 |
return [
|
| 79 |
create_ragas_sample(
|
|
@@ -239,6 +241,7 @@ class FaithfulnessEvaluator:
|
|
| 239 |
results=individual_results,
|
| 240 |
)
|
| 241 |
|
|
|
|
| 242 |
def evaluate_faithfulness(
|
| 243 |
explanation_results: list[ExplanationResult],
|
| 244 |
provider: str | None = None,
|
|
@@ -396,7 +399,8 @@ def compute_adjusted_faithfulness(
|
|
| 396 |
# - Valid non-recommendations count as passes (correct behavior)
|
| 397 |
# - Regular recommendations evaluated by HHEM
|
| 398 |
regular_passes = sum(
|
| 399 |
-
1
|
|
|
|
| 400 |
if not is_non_rec and not r.is_hallucinated
|
| 401 |
)
|
| 402 |
adjusted_passes = regular_passes + n_valid_non_recs
|
|
@@ -594,14 +598,13 @@ def compute_multi_metric_faithfulness(
|
|
| 594 |
detector = get_detector()
|
| 595 |
|
| 596 |
# 1. Full-explanation HHEM (structural)
|
| 597 |
-
full_scores = [
|
| 598 |
-
detector.check_explanation(ev, exp).score
|
| 599 |
-
for ev, exp in items
|
| 600 |
-
]
|
| 601 |
|
| 602 |
# 2. Claim-level HHEM
|
| 603 |
claim_report = compute_claim_level_hhem(
|
| 604 |
-
items,
|
|
|
|
|
|
|
| 605 |
)
|
| 606 |
|
| 607 |
# 3. Quote verification (lexical)
|
|
|
|
| 73 |
)
|
| 74 |
|
| 75 |
|
| 76 |
+
def _explanation_results_to_samples(
|
| 77 |
+
explanation_results: list[ExplanationResult],
|
| 78 |
+
) -> list:
|
| 79 |
"""Convert ExplanationResults to RAGAS samples."""
|
| 80 |
return [
|
| 81 |
create_ragas_sample(
|
|
|
|
| 241 |
results=individual_results,
|
| 242 |
)
|
| 243 |
|
| 244 |
+
|
| 245 |
def evaluate_faithfulness(
|
| 246 |
explanation_results: list[ExplanationResult],
|
| 247 |
provider: str | None = None,
|
|
|
|
| 399 |
# - Valid non-recommendations count as passes (correct behavior)
|
| 400 |
# - Regular recommendations evaluated by HHEM
|
| 401 |
regular_passes = sum(
|
| 402 |
+
1
|
| 403 |
+
for r, is_non_rec in zip(results, valid_non_recs)
|
| 404 |
if not is_non_rec and not r.is_hallucinated
|
| 405 |
)
|
| 406 |
adjusted_passes = regular_passes + n_valid_non_recs
|
|
|
|
| 598 |
detector = get_detector()
|
| 599 |
|
| 600 |
# 1. Full-explanation HHEM (structural)
|
| 601 |
+
full_scores = [detector.check_explanation(ev, exp).score for ev, exp in items]
|
|
|
|
|
|
|
|
|
|
| 602 |
|
| 603 |
# 2. Claim-level HHEM
|
| 604 |
claim_report = compute_claim_level_hhem(
|
| 605 |
+
items,
|
| 606 |
+
threshold,
|
| 607 |
+
full_explanation_scores=full_scores,
|
| 608 |
)
|
| 609 |
|
| 610 |
# 3. Quote verification (lexical)
|
sage/services/retrieval.py
CHANGED
|
@@ -138,7 +138,11 @@ class RetrievalService:
|
|
| 138 |
limit=limit,
|
| 139 |
min_rating=min_rating,
|
| 140 |
)
|
| 141 |
-
logger.info(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
|
| 143 |
chunks = []
|
| 144 |
for r in results:
|
|
@@ -157,7 +161,9 @@ class RetrievalService:
|
|
| 157 |
)
|
| 158 |
|
| 159 |
product_ids = {c.product_id for c in chunks}
|
| 160 |
-
logger.info(
|
|
|
|
|
|
|
| 161 |
|
| 162 |
return chunks
|
| 163 |
|
|
@@ -247,7 +253,11 @@ def retrieve_chunks(
|
|
| 247 |
"""Retrieve relevant chunks from the vector store."""
|
| 248 |
service = RetrievalService(client=client, embedder=embedder)
|
| 249 |
return service.retrieve_chunks(
|
| 250 |
-
query,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
)
|
| 252 |
|
| 253 |
|
|
@@ -347,7 +357,8 @@ def recommend_for_user(
|
|
| 347 |
|
| 348 |
# Get products to exclude
|
| 349 |
exclude: set[str] = {
|
| 350 |
-
pid
|
|
|
|
| 351 |
if (pid := r.get("product_id")) is not None and isinstance(pid, str)
|
| 352 |
}
|
| 353 |
|
|
|
|
| 138 |
limit=limit,
|
| 139 |
min_rating=min_rating,
|
| 140 |
)
|
| 141 |
+
logger.info(
|
| 142 |
+
"Qdrant search: %.0fms, %d results",
|
| 143 |
+
(time.perf_counter() - t0) * 1000,
|
| 144 |
+
len(results),
|
| 145 |
+
)
|
| 146 |
|
| 147 |
chunks = []
|
| 148 |
for r in results:
|
|
|
|
| 161 |
)
|
| 162 |
|
| 163 |
product_ids = {c.product_id for c in chunks}
|
| 164 |
+
logger.info(
|
| 165 |
+
"Retrieved %d chunks across %d products", len(chunks), len(product_ids)
|
| 166 |
+
)
|
| 167 |
|
| 168 |
return chunks
|
| 169 |
|
|
|
|
| 253 |
"""Retrieve relevant chunks from the vector store."""
|
| 254 |
service = RetrievalService(client=client, embedder=embedder)
|
| 255 |
return service.retrieve_chunks(
|
| 256 |
+
query,
|
| 257 |
+
limit,
|
| 258 |
+
min_rating,
|
| 259 |
+
exclude_products,
|
| 260 |
+
query_embedding,
|
| 261 |
)
|
| 262 |
|
| 263 |
|
|
|
|
| 357 |
|
| 358 |
# Get products to exclude
|
| 359 |
exclude: set[str] = {
|
| 360 |
+
pid
|
| 361 |
+
for r in user_history
|
| 362 |
if (pid := r.get("product_id")) is not None and isinstance(pid, str)
|
| 363 |
}
|
| 364 |
|
sage/utils.py
CHANGED
|
@@ -20,6 +20,7 @@ def save_results(data: dict, prefix: str, directory: Path | None = None) -> Path
|
|
| 20 |
"""
|
| 21 |
if directory is None:
|
| 22 |
from sage.config import RESULTS_DIR
|
|
|
|
| 23 |
directory = RESULTS_DIR
|
| 24 |
|
| 25 |
directory.mkdir(parents=True, exist_ok=True)
|
|
|
|
| 20 |
"""
|
| 21 |
if directory is None:
|
| 22 |
from sage.config import RESULTS_DIR
|
| 23 |
+
|
| 24 |
directory = RESULTS_DIR
|
| 25 |
|
| 26 |
directory.mkdir(parents=True, exist_ok=True)
|
scripts/build_eval_dataset.py
CHANGED
|
@@ -35,27 +35,199 @@ EVAL_DIR = DATA_DIR / "eval"
|
|
| 35 |
|
| 36 |
# Common stopwords to filter out
|
| 37 |
STOPWORDS = {
|
| 38 |
-
"i",
|
| 39 |
-
"
|
| 40 |
-
"
|
| 41 |
-
"
|
| 42 |
-
"
|
| 43 |
-
"
|
| 44 |
-
"
|
| 45 |
-
"
|
| 46 |
-
"
|
| 47 |
-
"
|
| 48 |
-
"
|
| 49 |
-
"
|
| 50 |
-
"
|
| 51 |
-
"
|
| 52 |
-
"
|
| 53 |
-
"
|
| 54 |
-
"
|
| 55 |
-
"
|
| 56 |
-
"
|
| 57 |
-
"
|
| 58 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
}
|
| 60 |
|
| 61 |
|
|
@@ -75,7 +247,7 @@ def extract_keywords(text: str, max_keywords: int = 8) -> list[str]:
|
|
| 75 |
# Clean text
|
| 76 |
text = text.lower()
|
| 77 |
text = re.sub(r"<br\s*/?>", " ", text) # Remove HTML breaks
|
| 78 |
-
text = re.sub(r"[^a-z\s]", " ", text)
|
| 79 |
text = re.sub(r"\s+", " ", text).strip()
|
| 80 |
|
| 81 |
# Tokenize and filter
|
|
@@ -165,6 +337,7 @@ def generate_query_from_history(
|
|
| 165 |
# Evaluation Dataset Construction
|
| 166 |
# ---------------------------------------------------------------------------
|
| 167 |
|
|
|
|
| 168 |
def build_leave_one_out_cases(
|
| 169 |
df: pd.DataFrame,
|
| 170 |
min_reviews: int = 2,
|
|
@@ -231,16 +404,21 @@ def build_leave_one_out_cases(
|
|
| 231 |
|
| 232 |
# Only include if target has positive relevance
|
| 233 |
if relevance > 0:
|
| 234 |
-
eval_cases.append(
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
|
|
|
|
|
|
| 239 |
|
| 240 |
if verbose:
|
| 241 |
logger.info("Users with enough reviews: %d", len(user_groups) - skipped_users)
|
| 242 |
logger.info("Eval cases created: %d", len(eval_cases))
|
| 243 |
-
logger.info(
|
|
|
|
|
|
|
|
|
|
| 244 |
|
| 245 |
return eval_cases
|
| 246 |
|
|
@@ -310,16 +488,20 @@ def build_multi_relevant_cases(
|
|
| 310 |
)
|
| 311 |
|
| 312 |
if relevant_items:
|
| 313 |
-
eval_cases.append(
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
|
|
|
|
|
|
| 318 |
|
| 319 |
if verbose:
|
| 320 |
logger.info("Users with train history: %d", len(train_users))
|
| 321 |
logger.info("Eval cases created: %d", len(eval_cases))
|
| 322 |
-
avg_relevant =
|
|
|
|
|
|
|
| 323 |
logger.info("Avg relevant items per case: %.1f", avg_relevant)
|
| 324 |
|
| 325 |
return eval_cases
|
|
@@ -400,7 +582,12 @@ if __name__ == "__main__":
|
|
| 400 |
# Load splits
|
| 401 |
log_section(logger, "Loading data splits")
|
| 402 |
train_df, val_df, test_df = load_splits()
|
| 403 |
-
logger.info(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 404 |
|
| 405 |
# Strategy 1: Leave-one-out with keyword queries
|
| 406 |
# WARNING: This strategy has TARGET LEAKAGE - queries are generated from
|
|
@@ -418,8 +605,12 @@ if __name__ == "__main__":
|
|
| 418 |
# Show examples
|
| 419 |
logger.info("Sample queries:")
|
| 420 |
for case in loo_keyword_cases[:5]:
|
| 421 |
-
logger.info(
|
| 422 |
-
logger.info(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 423 |
|
| 424 |
save_eval_cases(loo_keyword_cases, "eval_loo_keyword.json")
|
| 425 |
|
|
@@ -435,8 +626,12 @@ if __name__ == "__main__":
|
|
| 435 |
# Show examples
|
| 436 |
logger.info("Sample queries:")
|
| 437 |
for case in loo_history_cases[:5]:
|
| 438 |
-
logger.info(
|
| 439 |
-
logger.info(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 440 |
|
| 441 |
save_eval_cases(loo_history_cases, "eval_loo_history.json")
|
| 442 |
|
|
@@ -452,7 +647,7 @@ if __name__ == "__main__":
|
|
| 452 |
if multi_cases:
|
| 453 |
logger.info("Sample queries:")
|
| 454 |
for case in multi_cases[:3]:
|
| 455 |
-
logger.info(
|
| 456 |
logger.info(" Relevant: %d items", len(case.relevant_items))
|
| 457 |
|
| 458 |
save_eval_cases(multi_cases, "eval_multi_relevant.json")
|
|
|
|
| 35 |
|
| 36 |
# Common stopwords to filter out
|
| 37 |
STOPWORDS = {
|
| 38 |
+
"i",
|
| 39 |
+
"me",
|
| 40 |
+
"my",
|
| 41 |
+
"myself",
|
| 42 |
+
"we",
|
| 43 |
+
"our",
|
| 44 |
+
"ours",
|
| 45 |
+
"ourselves",
|
| 46 |
+
"you",
|
| 47 |
+
"your",
|
| 48 |
+
"yours",
|
| 49 |
+
"yourself",
|
| 50 |
+
"yourselves",
|
| 51 |
+
"he",
|
| 52 |
+
"him",
|
| 53 |
+
"his",
|
| 54 |
+
"himself",
|
| 55 |
+
"she",
|
| 56 |
+
"her",
|
| 57 |
+
"hers",
|
| 58 |
+
"herself",
|
| 59 |
+
"it",
|
| 60 |
+
"its",
|
| 61 |
+
"itself",
|
| 62 |
+
"they",
|
| 63 |
+
"them",
|
| 64 |
+
"their",
|
| 65 |
+
"theirs",
|
| 66 |
+
"themselves",
|
| 67 |
+
"what",
|
| 68 |
+
"which",
|
| 69 |
+
"who",
|
| 70 |
+
"whom",
|
| 71 |
+
"this",
|
| 72 |
+
"that",
|
| 73 |
+
"these",
|
| 74 |
+
"those",
|
| 75 |
+
"am",
|
| 76 |
+
"is",
|
| 77 |
+
"are",
|
| 78 |
+
"was",
|
| 79 |
+
"were",
|
| 80 |
+
"be",
|
| 81 |
+
"been",
|
| 82 |
+
"being",
|
| 83 |
+
"have",
|
| 84 |
+
"has",
|
| 85 |
+
"had",
|
| 86 |
+
"having",
|
| 87 |
+
"do",
|
| 88 |
+
"does",
|
| 89 |
+
"did",
|
| 90 |
+
"doing",
|
| 91 |
+
"a",
|
| 92 |
+
"an",
|
| 93 |
+
"the",
|
| 94 |
+
"and",
|
| 95 |
+
"but",
|
| 96 |
+
"if",
|
| 97 |
+
"or",
|
| 98 |
+
"because",
|
| 99 |
+
"as",
|
| 100 |
+
"until",
|
| 101 |
+
"while",
|
| 102 |
+
"of",
|
| 103 |
+
"at",
|
| 104 |
+
"by",
|
| 105 |
+
"for",
|
| 106 |
+
"with",
|
| 107 |
+
"about",
|
| 108 |
+
"against",
|
| 109 |
+
"between",
|
| 110 |
+
"into",
|
| 111 |
+
"through",
|
| 112 |
+
"during",
|
| 113 |
+
"before",
|
| 114 |
+
"after",
|
| 115 |
+
"above",
|
| 116 |
+
"below",
|
| 117 |
+
"to",
|
| 118 |
+
"from",
|
| 119 |
+
"up",
|
| 120 |
+
"down",
|
| 121 |
+
"in",
|
| 122 |
+
"out",
|
| 123 |
+
"on",
|
| 124 |
+
"off",
|
| 125 |
+
"over",
|
| 126 |
+
"under",
|
| 127 |
+
"again",
|
| 128 |
+
"further",
|
| 129 |
+
"then",
|
| 130 |
+
"once",
|
| 131 |
+
"here",
|
| 132 |
+
"there",
|
| 133 |
+
"when",
|
| 134 |
+
"where",
|
| 135 |
+
"why",
|
| 136 |
+
"how",
|
| 137 |
+
"all",
|
| 138 |
+
"each",
|
| 139 |
+
"few",
|
| 140 |
+
"more",
|
| 141 |
+
"most",
|
| 142 |
+
"other",
|
| 143 |
+
"some",
|
| 144 |
+
"such",
|
| 145 |
+
"no",
|
| 146 |
+
"nor",
|
| 147 |
+
"not",
|
| 148 |
+
"only",
|
| 149 |
+
"own",
|
| 150 |
+
"same",
|
| 151 |
+
"so",
|
| 152 |
+
"than",
|
| 153 |
+
"too",
|
| 154 |
+
"very",
|
| 155 |
+
"s",
|
| 156 |
+
"t",
|
| 157 |
+
"can",
|
| 158 |
+
"will",
|
| 159 |
+
"just",
|
| 160 |
+
"don",
|
| 161 |
+
"should",
|
| 162 |
+
"now",
|
| 163 |
+
"d",
|
| 164 |
+
"ll",
|
| 165 |
+
"m",
|
| 166 |
+
"o",
|
| 167 |
+
"re",
|
| 168 |
+
"ve",
|
| 169 |
+
"y",
|
| 170 |
+
"ain",
|
| 171 |
+
"aren",
|
| 172 |
+
"couldn",
|
| 173 |
+
"didn",
|
| 174 |
+
"doesn",
|
| 175 |
+
"hadn",
|
| 176 |
+
"hasn",
|
| 177 |
+
"haven",
|
| 178 |
+
"isn",
|
| 179 |
+
"ma",
|
| 180 |
+
"mightn",
|
| 181 |
+
"mustn",
|
| 182 |
+
"needn",
|
| 183 |
+
"shan",
|
| 184 |
+
"shouldn",
|
| 185 |
+
"wasn",
|
| 186 |
+
"weren",
|
| 187 |
+
"won",
|
| 188 |
+
"wouldn",
|
| 189 |
+
"also",
|
| 190 |
+
"would",
|
| 191 |
+
"could",
|
| 192 |
+
"get",
|
| 193 |
+
"got",
|
| 194 |
+
"one",
|
| 195 |
+
"two",
|
| 196 |
+
"really",
|
| 197 |
+
"like",
|
| 198 |
+
"just",
|
| 199 |
+
"even",
|
| 200 |
+
"well",
|
| 201 |
+
"much",
|
| 202 |
+
"still",
|
| 203 |
+
"back",
|
| 204 |
+
"way",
|
| 205 |
+
"thing",
|
| 206 |
+
"things",
|
| 207 |
+
"make",
|
| 208 |
+
"made",
|
| 209 |
+
"work",
|
| 210 |
+
"works",
|
| 211 |
+
"worked",
|
| 212 |
+
"use",
|
| 213 |
+
"used",
|
| 214 |
+
"using",
|
| 215 |
+
"good",
|
| 216 |
+
"great",
|
| 217 |
+
"nice",
|
| 218 |
+
"product",
|
| 219 |
+
"item",
|
| 220 |
+
"bought",
|
| 221 |
+
"buy",
|
| 222 |
+
"amazon",
|
| 223 |
+
"review",
|
| 224 |
+
"ordered",
|
| 225 |
+
"order",
|
| 226 |
+
"received",
|
| 227 |
+
"came",
|
| 228 |
+
"arrived",
|
| 229 |
+
"shipping",
|
| 230 |
+
"shipped",
|
| 231 |
}
|
| 232 |
|
| 233 |
|
|
|
|
| 247 |
# Clean text
|
| 248 |
text = text.lower()
|
| 249 |
text = re.sub(r"<br\s*/?>", " ", text) # Remove HTML breaks
|
| 250 |
+
text = re.sub(r"[^a-z\s]", " ", text) # Keep only letters
|
| 251 |
text = re.sub(r"\s+", " ", text).strip()
|
| 252 |
|
| 253 |
# Tokenize and filter
|
|
|
|
| 337 |
# Evaluation Dataset Construction
|
| 338 |
# ---------------------------------------------------------------------------
|
| 339 |
|
| 340 |
+
|
| 341 |
def build_leave_one_out_cases(
|
| 342 |
df: pd.DataFrame,
|
| 343 |
min_reviews: int = 2,
|
|
|
|
| 404 |
|
| 405 |
# Only include if target has positive relevance
|
| 406 |
if relevance > 0:
|
| 407 |
+
eval_cases.append(
|
| 408 |
+
EvalCase(
|
| 409 |
+
query=query,
|
| 410 |
+
relevant_items={target_product: relevance},
|
| 411 |
+
user_id=user_id,
|
| 412 |
+
)
|
| 413 |
+
)
|
| 414 |
|
| 415 |
if verbose:
|
| 416 |
logger.info("Users with enough reviews: %d", len(user_groups) - skipped_users)
|
| 417 |
logger.info("Eval cases created: %d", len(eval_cases))
|
| 418 |
+
logger.info(
|
| 419 |
+
"Skipped (low relevance): %d",
|
| 420 |
+
len(user_groups) - skipped_users - len(eval_cases),
|
| 421 |
+
)
|
| 422 |
|
| 423 |
return eval_cases
|
| 424 |
|
|
|
|
| 488 |
)
|
| 489 |
|
| 490 |
if relevant_items:
|
| 491 |
+
eval_cases.append(
|
| 492 |
+
EvalCase(
|
| 493 |
+
query=query,
|
| 494 |
+
relevant_items=relevant_items,
|
| 495 |
+
user_id=user_id,
|
| 496 |
+
)
|
| 497 |
+
)
|
| 498 |
|
| 499 |
if verbose:
|
| 500 |
logger.info("Users with train history: %d", len(train_users))
|
| 501 |
logger.info("Eval cases created: %d", len(eval_cases))
|
| 502 |
+
avg_relevant = (
|
| 503 |
+
np.mean([len(c.relevant_items) for c in eval_cases]) if eval_cases else 0
|
| 504 |
+
)
|
| 505 |
logger.info("Avg relevant items per case: %.1f", avg_relevant)
|
| 506 |
|
| 507 |
return eval_cases
|
|
|
|
| 582 |
# Load splits
|
| 583 |
log_section(logger, "Loading data splits")
|
| 584 |
train_df, val_df, test_df = load_splits()
|
| 585 |
+
logger.info(
|
| 586 |
+
"Train: %s | Val: %s | Test: %s",
|
| 587 |
+
f"{len(train_df):,}",
|
| 588 |
+
f"{len(val_df):,}",
|
| 589 |
+
f"{len(test_df):,}",
|
| 590 |
+
)
|
| 591 |
|
| 592 |
# Strategy 1: Leave-one-out with keyword queries
|
| 593 |
# WARNING: This strategy has TARGET LEAKAGE - queries are generated from
|
|
|
|
| 605 |
# Show examples
|
| 606 |
logger.info("Sample queries:")
|
| 607 |
for case in loo_keyword_cases[:5]:
|
| 608 |
+
logger.info(' Query: "%s"', case.query)
|
| 609 |
+
logger.info(
|
| 610 |
+
" Target: %s (rel=%s)",
|
| 611 |
+
list(case.relevant_items.keys())[0],
|
| 612 |
+
list(case.relevant_items.values())[0],
|
| 613 |
+
)
|
| 614 |
|
| 615 |
save_eval_cases(loo_keyword_cases, "eval_loo_keyword.json")
|
| 616 |
|
|
|
|
| 626 |
# Show examples
|
| 627 |
logger.info("Sample queries:")
|
| 628 |
for case in loo_history_cases[:5]:
|
| 629 |
+
logger.info(' Query: "%s"', case.query)
|
| 630 |
+
logger.info(
|
| 631 |
+
" Target: %s (rel=%s)",
|
| 632 |
+
list(case.relevant_items.keys())[0],
|
| 633 |
+
list(case.relevant_items.values())[0],
|
| 634 |
+
)
|
| 635 |
|
| 636 |
save_eval_cases(loo_history_cases, "eval_loo_history.json")
|
| 637 |
|
|
|
|
| 647 |
if multi_cases:
|
| 648 |
logger.info("Sample queries:")
|
| 649 |
for case in multi_cases[:3]:
|
| 650 |
+
logger.info(' Query: "%s..."', case.query[:60])
|
| 651 |
logger.info(" Relevant: %d items", len(case.relevant_items))
|
| 652 |
|
| 653 |
save_eval_cases(multi_cases, "eval_multi_relevant.json")
|
scripts/build_natural_eval_dataset.py
CHANGED
|
@@ -70,7 +70,6 @@ NATURAL_QUERIES = [
|
|
| 70 |
"category": "echo_devices",
|
| 71 |
"intent": "feature_specific",
|
| 72 |
},
|
| 73 |
-
|
| 74 |
# === FIRE TABLET QUERIES ===
|
| 75 |
{
|
| 76 |
"query": "tablet for reading books and light browsing",
|
|
@@ -111,7 +110,6 @@ NATURAL_QUERIES = [
|
|
| 111 |
"category": "fire_tablets",
|
| 112 |
"intent": "use_case",
|
| 113 |
},
|
| 114 |
-
|
| 115 |
# === FIRE TV / STREAMING QUERIES ===
|
| 116 |
{
|
| 117 |
"query": "streaming device for my tv",
|
|
@@ -151,7 +149,6 @@ NATURAL_QUERIES = [
|
|
| 151 |
"category": "fire_tv",
|
| 152 |
"intent": "use_case",
|
| 153 |
},
|
| 154 |
-
|
| 155 |
# === SMART HOME QUERIES ===
|
| 156 |
{
|
| 157 |
"query": "smart plug to control lights with alexa",
|
|
@@ -191,7 +188,6 @@ NATURAL_QUERIES = [
|
|
| 191 |
"category": "smart_home",
|
| 192 |
"intent": "feature_specific",
|
| 193 |
},
|
| 194 |
-
|
| 195 |
# === STORAGE QUERIES ===
|
| 196 |
{
|
| 197 |
"query": "sd card for camera",
|
|
@@ -232,7 +228,6 @@ NATURAL_QUERIES = [
|
|
| 232 |
"category": "storage",
|
| 233 |
"intent": "feature_specific",
|
| 234 |
},
|
| 235 |
-
|
| 236 |
# === HEADPHONES / AUDIO QUERIES ===
|
| 237 |
{
|
| 238 |
"query": "wireless headphones for working out",
|
|
@@ -283,7 +278,6 @@ NATURAL_QUERIES = [
|
|
| 283 |
"category": "headphones_audio",
|
| 284 |
"intent": "use_case",
|
| 285 |
},
|
| 286 |
-
|
| 287 |
# === CABLES / ADAPTERS QUERIES ===
|
| 288 |
{
|
| 289 |
"query": "usb c charging cable",
|
|
@@ -322,7 +316,6 @@ NATURAL_QUERIES = [
|
|
| 322 |
"category": "cables_adapters",
|
| 323 |
"intent": "feature_specific",
|
| 324 |
},
|
| 325 |
-
|
| 326 |
# === KEYBOARD / MOUSE QUERIES ===
|
| 327 |
{
|
| 328 |
"query": "wireless keyboard for computer",
|
|
@@ -353,7 +346,6 @@ NATURAL_QUERIES = [
|
|
| 353 |
"category": "keyboards_mice",
|
| 354 |
"intent": "feature_specific",
|
| 355 |
},
|
| 356 |
-
|
| 357 |
# === GIFT QUERIES ===
|
| 358 |
{
|
| 359 |
"query": "gift for someone who likes music",
|
|
@@ -395,7 +387,6 @@ NATURAL_QUERIES = [
|
|
| 395 |
"category": "gifts",
|
| 396 |
"intent": "gift",
|
| 397 |
},
|
| 398 |
-
|
| 399 |
# === PROBLEM-SOLVING QUERIES ===
|
| 400 |
{
|
| 401 |
"query": "headphones that dont hurt ears",
|
|
@@ -424,7 +415,6 @@ NATURAL_QUERIES = [
|
|
| 424 |
"category": "fire_tv",
|
| 425 |
"intent": "problem_solving",
|
| 426 |
},
|
| 427 |
-
|
| 428 |
# === COMPARISON / BEST QUERIES ===
|
| 429 |
{
|
| 430 |
"query": "best value fire tablet",
|
|
@@ -460,15 +450,19 @@ def build_natural_eval_cases() -> list[EvalCase]:
|
|
| 460 |
"""Convert natural queries to EvalCase objects."""
|
| 461 |
cases = []
|
| 462 |
for item in NATURAL_QUERIES:
|
| 463 |
-
cases.append(
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
|
|
|
|
|
|
| 468 |
return cases
|
| 469 |
|
| 470 |
|
| 471 |
-
def save_natural_eval_cases(
|
|
|
|
|
|
|
| 472 |
"""Save evaluation cases with metadata."""
|
| 473 |
EVAL_DIR.mkdir(exist_ok=True)
|
| 474 |
filepath = EVAL_DIR / filename
|
|
@@ -476,12 +470,14 @@ def save_natural_eval_cases(cases: list[EvalCase], filename: str = "eval_natural
|
|
| 476 |
# Include metadata for analysis
|
| 477 |
data = []
|
| 478 |
for i, item in enumerate(NATURAL_QUERIES):
|
| 479 |
-
data.append(
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
|
|
|
|
|
|
| 485 |
|
| 486 |
with open(filepath, "w") as f:
|
| 487 |
json.dump(data, f, indent=2)
|
|
@@ -522,9 +518,9 @@ def analyze_dataset():
|
|
| 522 |
# Sample queries
|
| 523 |
log_section(logger, "SAMPLE QUERIES")
|
| 524 |
for q in NATURAL_QUERIES[:5]:
|
| 525 |
-
logger.info(
|
| 526 |
-
logger.info(" Category: %s | Intent: %s", q[
|
| 527 |
-
logger.info(" Relevant: %d products", len(q[
|
| 528 |
|
| 529 |
|
| 530 |
if __name__ == "__main__":
|
|
|
|
| 70 |
"category": "echo_devices",
|
| 71 |
"intent": "feature_specific",
|
| 72 |
},
|
|
|
|
| 73 |
# === FIRE TABLET QUERIES ===
|
| 74 |
{
|
| 75 |
"query": "tablet for reading books and light browsing",
|
|
|
|
| 110 |
"category": "fire_tablets",
|
| 111 |
"intent": "use_case",
|
| 112 |
},
|
|
|
|
| 113 |
# === FIRE TV / STREAMING QUERIES ===
|
| 114 |
{
|
| 115 |
"query": "streaming device for my tv",
|
|
|
|
| 149 |
"category": "fire_tv",
|
| 150 |
"intent": "use_case",
|
| 151 |
},
|
|
|
|
| 152 |
# === SMART HOME QUERIES ===
|
| 153 |
{
|
| 154 |
"query": "smart plug to control lights with alexa",
|
|
|
|
| 188 |
"category": "smart_home",
|
| 189 |
"intent": "feature_specific",
|
| 190 |
},
|
|
|
|
| 191 |
# === STORAGE QUERIES ===
|
| 192 |
{
|
| 193 |
"query": "sd card for camera",
|
|
|
|
| 228 |
"category": "storage",
|
| 229 |
"intent": "feature_specific",
|
| 230 |
},
|
|
|
|
| 231 |
# === HEADPHONES / AUDIO QUERIES ===
|
| 232 |
{
|
| 233 |
"query": "wireless headphones for working out",
|
|
|
|
| 278 |
"category": "headphones_audio",
|
| 279 |
"intent": "use_case",
|
| 280 |
},
|
|
|
|
| 281 |
# === CABLES / ADAPTERS QUERIES ===
|
| 282 |
{
|
| 283 |
"query": "usb c charging cable",
|
|
|
|
| 316 |
"category": "cables_adapters",
|
| 317 |
"intent": "feature_specific",
|
| 318 |
},
|
|
|
|
| 319 |
# === KEYBOARD / MOUSE QUERIES ===
|
| 320 |
{
|
| 321 |
"query": "wireless keyboard for computer",
|
|
|
|
| 346 |
"category": "keyboards_mice",
|
| 347 |
"intent": "feature_specific",
|
| 348 |
},
|
|
|
|
| 349 |
# === GIFT QUERIES ===
|
| 350 |
{
|
| 351 |
"query": "gift for someone who likes music",
|
|
|
|
| 387 |
"category": "gifts",
|
| 388 |
"intent": "gift",
|
| 389 |
},
|
|
|
|
| 390 |
# === PROBLEM-SOLVING QUERIES ===
|
| 391 |
{
|
| 392 |
"query": "headphones that dont hurt ears",
|
|
|
|
| 415 |
"category": "fire_tv",
|
| 416 |
"intent": "problem_solving",
|
| 417 |
},
|
|
|
|
| 418 |
# === COMPARISON / BEST QUERIES ===
|
| 419 |
{
|
| 420 |
"query": "best value fire tablet",
|
|
|
|
| 450 |
"""Convert natural queries to EvalCase objects."""
|
| 451 |
cases = []
|
| 452 |
for item in NATURAL_QUERIES:
|
| 453 |
+
cases.append(
|
| 454 |
+
EvalCase(
|
| 455 |
+
query=item["query"],
|
| 456 |
+
relevant_items=item["relevant_items"],
|
| 457 |
+
user_id=None, # No user for natural queries
|
| 458 |
+
)
|
| 459 |
+
)
|
| 460 |
return cases
|
| 461 |
|
| 462 |
|
| 463 |
+
def save_natural_eval_cases(
|
| 464 |
+
cases: list[EvalCase], filename: str = "eval_natural_queries.json"
|
| 465 |
+
):
|
| 466 |
"""Save evaluation cases with metadata."""
|
| 467 |
EVAL_DIR.mkdir(exist_ok=True)
|
| 468 |
filepath = EVAL_DIR / filename
|
|
|
|
| 470 |
# Include metadata for analysis
|
| 471 |
data = []
|
| 472 |
for i, item in enumerate(NATURAL_QUERIES):
|
| 473 |
+
data.append(
|
| 474 |
+
{
|
| 475 |
+
"query": item["query"],
|
| 476 |
+
"relevant_items": item["relevant_items"],
|
| 477 |
+
"category": item.get("category", "unknown"),
|
| 478 |
+
"intent": item.get("intent", "general"),
|
| 479 |
+
}
|
| 480 |
+
)
|
| 481 |
|
| 482 |
with open(filepath, "w") as f:
|
| 483 |
json.dump(data, f, indent=2)
|
|
|
|
| 518 |
# Sample queries
|
| 519 |
log_section(logger, "SAMPLE QUERIES")
|
| 520 |
for q in NATURAL_QUERIES[:5]:
|
| 521 |
+
logger.info('Query: "%s"', q["query"])
|
| 522 |
+
logger.info(" Category: %s | Intent: %s", q["category"], q["intent"])
|
| 523 |
+
logger.info(" Relevant: %d products", len(q["relevant_items"]))
|
| 524 |
|
| 525 |
|
| 526 |
if __name__ == "__main__":
|
scripts/demo.py
CHANGED
|
@@ -31,7 +31,7 @@ def demo_recommendation(query: str, top_k: int = 3, max_evidence: int = 3):
|
|
| 31 |
Returns dict suitable for JSON serialization.
|
| 32 |
"""
|
| 33 |
log_banner(logger, "SAGE RECOMMENDATION DEMO", width=70)
|
| 34 |
-
logger.info(
|
| 35 |
|
| 36 |
# Get candidates
|
| 37 |
products = get_candidates(
|
|
@@ -91,7 +91,7 @@ def demo_recommendation(query: str, top_k: int = 3, max_evidence: int = 3):
|
|
| 91 |
# Truncate long evidence for display
|
| 92 |
display_text = ev_text[:200] + "..." if len(ev_text) > 200 else ev_text
|
| 93 |
logger.info("[%s]:", ev_id)
|
| 94 |
-
logger.info(
|
| 95 |
|
| 96 |
# Compile result
|
| 97 |
result = {
|
|
@@ -108,8 +108,7 @@ def demo_recommendation(query: str, top_k: int = 3, max_evidence: int = 3):
|
|
| 108 |
"evidence_sources": [
|
| 109 |
{"id": ev_id, "text": ev_text}
|
| 110 |
for ev_id, ev_text in zip(
|
| 111 |
-
explanation_result.evidence_ids,
|
| 112 |
-
explanation_result.evidence_texts
|
| 113 |
)
|
| 114 |
],
|
| 115 |
}
|
|
@@ -131,13 +130,15 @@ def demo_recommendation(query: str, top_k: int = 3, max_evidence: int = 3):
|
|
| 131 |
def main():
|
| 132 |
parser = argparse.ArgumentParser(description="Demo recommendation pipeline")
|
| 133 |
parser.add_argument(
|
| 134 |
-
"--query",
|
|
|
|
| 135 |
type=str,
|
| 136 |
default="wireless earbuds for running",
|
| 137 |
help="Query to demonstrate",
|
| 138 |
)
|
| 139 |
parser.add_argument(
|
| 140 |
-
"--top-k",
|
|
|
|
| 141 |
type=int,
|
| 142 |
default=1,
|
| 143 |
help="Number of products to recommend (default: 1)",
|
|
|
|
| 31 |
Returns dict suitable for JSON serialization.
|
| 32 |
"""
|
| 33 |
log_banner(logger, "SAGE RECOMMENDATION DEMO", width=70)
|
| 34 |
+
logger.info('Query: "%s"', query)
|
| 35 |
|
| 36 |
# Get candidates
|
| 37 |
products = get_candidates(
|
|
|
|
| 91 |
# Truncate long evidence for display
|
| 92 |
display_text = ev_text[:200] + "..." if len(ev_text) > 200 else ev_text
|
| 93 |
logger.info("[%s]:", ev_id)
|
| 94 |
+
logger.info(' "%s"', display_text)
|
| 95 |
|
| 96 |
# Compile result
|
| 97 |
result = {
|
|
|
|
| 108 |
"evidence_sources": [
|
| 109 |
{"id": ev_id, "text": ev_text}
|
| 110 |
for ev_id, ev_text in zip(
|
| 111 |
+
explanation_result.evidence_ids, explanation_result.evidence_texts
|
|
|
|
| 112 |
)
|
| 113 |
],
|
| 114 |
}
|
|
|
|
| 130 |
def main():
|
| 131 |
parser = argparse.ArgumentParser(description="Demo recommendation pipeline")
|
| 132 |
parser.add_argument(
|
| 133 |
+
"--query",
|
| 134 |
+
"-q",
|
| 135 |
type=str,
|
| 136 |
default="wireless earbuds for running",
|
| 137 |
help="Query to demonstrate",
|
| 138 |
)
|
| 139 |
parser.add_argument(
|
| 140 |
+
"--top-k",
|
| 141 |
+
"-k",
|
| 142 |
type=int,
|
| 143 |
default=1,
|
| 144 |
help="Number of products to recommend (default: 1)",
|
scripts/e2e_success_rate.py
CHANGED
|
@@ -149,7 +149,7 @@ def run_e2e_evaluation(n_samples: int = 20) -> E2EReport:
|
|
| 149 |
case_id = 0
|
| 150 |
|
| 151 |
for query in queries:
|
| 152 |
-
logger.info(
|
| 153 |
|
| 154 |
products = get_candidates(
|
| 155 |
query=query,
|
|
@@ -275,23 +275,43 @@ def run_e2e_evaluation(n_samples: int = 20) -> E2EReport:
|
|
| 275 |
n_evidence_insufficient = sum(1 for c in all_cases if not c.evidence_sufficient)
|
| 276 |
n_generated = sum(1 for c in all_cases if c.evidence_sufficient)
|
| 277 |
n_forbidden_violations = sum(1 for c in all_cases if c.has_forbidden_phrases)
|
| 278 |
-
n_hhem_failures = sum(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
n_valid_non_recs = sum(1 for c in all_cases if c.is_valid_non_recommendation)
|
| 280 |
|
| 281 |
# Success counts
|
| 282 |
n_raw_success = sum(1 for c in all_cases if c.e2e_success)
|
| 283 |
-
n_adjusted_success =
|
|
|
|
|
|
|
| 284 |
|
| 285 |
# Rates
|
| 286 |
evidence_pass_rate = n_generated / n_total if n_total > 0 else 0
|
| 287 |
|
| 288 |
# Forbidden phrase compliance among generated explanations
|
| 289 |
generated_cases = [c for c in all_cases if c.evidence_sufficient]
|
| 290 |
-
phrase_compliance =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
|
| 292 |
# HHEM pass rate among non-refusal generated explanations
|
| 293 |
-
non_refusal_generated = [
|
| 294 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 295 |
|
| 296 |
raw_e2e = n_raw_success / n_total if n_total > 0 else 0
|
| 297 |
adjusted_e2e = n_adjusted_success / n_total if n_total > 0 else 0
|
|
@@ -321,11 +341,31 @@ def run_e2e_evaluation(n_samples: int = 20) -> E2EReport:
|
|
| 321 |
|
| 322 |
log_section(logger, "Stage Breakdown")
|
| 323 |
logger.info("Total cases: %d", n_total)
|
| 324 |
-
logger.info(
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 329 |
|
| 330 |
log_section(logger, "Component Rates")
|
| 331 |
logger.info("Evidence pass rate: %.1f%%", evidence_pass_rate * 100)
|
|
@@ -333,8 +373,18 @@ def run_e2e_evaluation(n_samples: int = 20) -> E2EReport:
|
|
| 333 |
logger.info("HHEM pass rate: %.1f%%", hhem_pass_rate * 100)
|
| 334 |
|
| 335 |
log_section(logger, "END-TO-END SUCCESS RATES")
|
| 336 |
-
logger.info(
|
| 337 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 338 |
logger.info("Target: %.1f%%", target * 100)
|
| 339 |
logger.info("Gap to target: %.1f%%", report.gap_to_target * 100)
|
| 340 |
logger.info("Meets target: %s", "YES" if report.meets_target else "NO")
|
|
@@ -355,7 +405,9 @@ def run_e2e_evaluation(n_samples: int = 20) -> E2EReport:
|
|
| 355 |
"cases": [c.to_dict() for c in all_cases],
|
| 356 |
}
|
| 357 |
|
| 358 |
-
output_file =
|
|
|
|
|
|
|
| 359 |
with open(output_file, "w") as f:
|
| 360 |
json.dump(output, f, indent=2)
|
| 361 |
logger.info("Saved: %s", output_file)
|
|
|
|
| 149 |
case_id = 0
|
| 150 |
|
| 151 |
for query in queries:
|
| 152 |
+
logger.info('Query: "%s"', query)
|
| 153 |
|
| 154 |
products = get_candidates(
|
| 155 |
query=query,
|
|
|
|
| 275 |
n_evidence_insufficient = sum(1 for c in all_cases if not c.evidence_sufficient)
|
| 276 |
n_generated = sum(1 for c in all_cases if c.evidence_sufficient)
|
| 277 |
n_forbidden_violations = sum(1 for c in all_cases if c.has_forbidden_phrases)
|
| 278 |
+
n_hhem_failures = sum(
|
| 279 |
+
1
|
| 280 |
+
for c in all_cases
|
| 281 |
+
if c.evidence_sufficient
|
| 282 |
+
and not c.hhem_pass
|
| 283 |
+
and not c.is_valid_non_recommendation
|
| 284 |
+
)
|
| 285 |
n_valid_non_recs = sum(1 for c in all_cases if c.is_valid_non_recommendation)
|
| 286 |
|
| 287 |
# Success counts
|
| 288 |
n_raw_success = sum(1 for c in all_cases if c.e2e_success)
|
| 289 |
+
n_adjusted_success = (
|
| 290 |
+
n_raw_success + n_valid_non_recs
|
| 291 |
+
) # Valid non-recs are correct behavior
|
| 292 |
|
| 293 |
# Rates
|
| 294 |
evidence_pass_rate = n_generated / n_total if n_total > 0 else 0
|
| 295 |
|
| 296 |
# Forbidden phrase compliance among generated explanations
|
| 297 |
generated_cases = [c for c in all_cases if c.evidence_sufficient]
|
| 298 |
+
phrase_compliance = (
|
| 299 |
+
sum(1 for c in generated_cases if not c.has_forbidden_phrases)
|
| 300 |
+
/ len(generated_cases)
|
| 301 |
+
if generated_cases
|
| 302 |
+
else 0
|
| 303 |
+
)
|
| 304 |
|
| 305 |
# HHEM pass rate among non-refusal generated explanations
|
| 306 |
+
non_refusal_generated = [
|
| 307 |
+
c for c in generated_cases if not c.is_valid_non_recommendation
|
| 308 |
+
]
|
| 309 |
+
hhem_pass_rate = (
|
| 310 |
+
sum(1 for c in non_refusal_generated if c.hhem_pass)
|
| 311 |
+
/ len(non_refusal_generated)
|
| 312 |
+
if non_refusal_generated
|
| 313 |
+
else 0
|
| 314 |
+
)
|
| 315 |
|
| 316 |
raw_e2e = n_raw_success / n_total if n_total > 0 else 0
|
| 317 |
adjusted_e2e = n_adjusted_success / n_total if n_total > 0 else 0
|
|
|
|
| 341 |
|
| 342 |
log_section(logger, "Stage Breakdown")
|
| 343 |
logger.info("Total cases: %d", n_total)
|
| 344 |
+
logger.info(
|
| 345 |
+
"Evidence insufficient: %d (%.1f%%)",
|
| 346 |
+
n_evidence_insufficient,
|
| 347 |
+
n_evidence_insufficient / n_total * 100,
|
| 348 |
+
)
|
| 349 |
+
logger.info(
|
| 350 |
+
"Generated explanations: %d (%.1f%%)",
|
| 351 |
+
n_generated,
|
| 352 |
+
n_generated / n_total * 100,
|
| 353 |
+
)
|
| 354 |
+
logger.info(
|
| 355 |
+
"Forbidden phrase fails: %d (%.1f%%)",
|
| 356 |
+
n_forbidden_violations,
|
| 357 |
+
n_forbidden_violations / n_total * 100,
|
| 358 |
+
)
|
| 359 |
+
logger.info(
|
| 360 |
+
"HHEM failures: %d (%.1f%%)",
|
| 361 |
+
n_hhem_failures,
|
| 362 |
+
n_hhem_failures / n_total * 100,
|
| 363 |
+
)
|
| 364 |
+
logger.info(
|
| 365 |
+
"Valid non-recommendations:%d (%.1f%%)",
|
| 366 |
+
n_valid_non_recs,
|
| 367 |
+
n_valid_non_recs / n_total * 100,
|
| 368 |
+
)
|
| 369 |
|
| 370 |
log_section(logger, "Component Rates")
|
| 371 |
logger.info("Evidence pass rate: %.1f%%", evidence_pass_rate * 100)
|
|
|
|
| 373 |
logger.info("HHEM pass rate: %.1f%%", hhem_pass_rate * 100)
|
| 374 |
|
| 375 |
log_section(logger, "END-TO-END SUCCESS RATES")
|
| 376 |
+
logger.info(
|
| 377 |
+
"Raw E2E success: %d/%d = %.1f%%",
|
| 378 |
+
n_raw_success,
|
| 379 |
+
n_total,
|
| 380 |
+
raw_e2e * 100,
|
| 381 |
+
)
|
| 382 |
+
logger.info(
|
| 383 |
+
"Adjusted E2E success: %d/%d = %.1f%%",
|
| 384 |
+
n_adjusted_success,
|
| 385 |
+
n_total,
|
| 386 |
+
adjusted_e2e * 100,
|
| 387 |
+
)
|
| 388 |
logger.info("Target: %.1f%%", target * 100)
|
| 389 |
logger.info("Gap to target: %.1f%%", report.gap_to_target * 100)
|
| 390 |
logger.info("Meets target: %s", "YES" if report.meets_target else "NO")
|
|
|
|
| 405 |
"cases": [c.to_dict() for c in all_cases],
|
| 406 |
}
|
| 407 |
|
| 408 |
+
output_file = (
|
| 409 |
+
RESULTS_DIR / f"e2e_success_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
|
| 410 |
+
)
|
| 411 |
with open(output_file, "w") as f:
|
| 412 |
json.dump(output, f, indent=2)
|
| 413 |
logger.info("Saved: %s", output_file)
|
scripts/eda.py
CHANGED
|
@@ -17,19 +17,22 @@ FIGURES_DIR.mkdir(exist_ok=True)
|
|
| 17 |
|
| 18 |
# Plot configuration
|
| 19 |
plt.style.use("seaborn-v0_8-whitegrid")
|
| 20 |
-
plt.rcParams.update(
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
|
|
|
|
|
|
| 29 |
|
| 30 |
# Enable retina display for Jupyter notebooks
|
| 31 |
try:
|
| 32 |
from IPython import get_ipython
|
|
|
|
| 33 |
if get_ipython() is not None:
|
| 34 |
get_ipython().run_line_magic("matplotlib", "inline")
|
| 35 |
get_ipython().run_line_magic("config", "InlineBackend.figure_format='retina'")
|
|
@@ -56,15 +59,23 @@ for key, value in stats.items():
|
|
| 56 |
# %% Rating distribution
|
| 57 |
fig, ax = plt.subplots()
|
| 58 |
rating_counts = pd.Series(stats["rating_dist"])
|
| 59 |
-
bars = ax.bar(
|
|
|
|
|
|
|
| 60 |
ax.set_xlabel("Rating")
|
| 61 |
ax.set_ylabel("Count")
|
| 62 |
ax.set_title("Rating Distribution")
|
| 63 |
ax.set_xticks(rating_counts.index)
|
| 64 |
|
| 65 |
for bar, count in zip(bars, rating_counts.values):
|
| 66 |
-
ax.text(
|
| 67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
plt.tight_layout()
|
| 70 |
plt.savefig(FIGURES_DIR / "rating_distribution.png", dpi=150)
|
|
@@ -84,16 +95,25 @@ fig, axes = plt.subplots(1, 2, figsize=FIGURE_SIZE_WIDE)
|
|
| 84 |
|
| 85 |
# Character length histogram
|
| 86 |
ax1 = axes[0]
|
| 87 |
-
df["text_length"].clip(upper=2000).hist(
|
|
|
|
|
|
|
| 88 |
ax1.set_xlabel("Character Length (clipped at 2000)")
|
| 89 |
ax1.set_ylabel("Count")
|
| 90 |
ax1.set_title("Review Length Distribution")
|
| 91 |
-
ax1.axvline(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
ax1.legend()
|
| 93 |
|
| 94 |
# Token estimate histogram
|
| 95 |
ax2 = axes[1]
|
| 96 |
-
df["estimated_tokens"].clip(upper=500).hist(
|
|
|
|
|
|
|
| 97 |
ax2.set_xlabel("Estimated Tokens (clipped at 500)")
|
| 98 |
ax2.set_ylabel("Count")
|
| 99 |
ax2.set_title("Estimated Token Distribution")
|
|
@@ -108,12 +128,19 @@ needs_chunking = (df["estimated_tokens"] > 200).sum()
|
|
| 108 |
print("\nReview length stats:")
|
| 109 |
print(f" Median characters: {df['text_length'].median():.0f}")
|
| 110 |
print(f" Median tokens (est): {df['estimated_tokens'].median():.0f}")
|
| 111 |
-
print(
|
|
|
|
|
|
|
| 112 |
|
| 113 |
# %% Review length by rating
|
| 114 |
fig, ax = plt.subplots()
|
| 115 |
length_by_rating = df.groupby("rating")["text_length"].median()
|
| 116 |
-
bars = ax.bar(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
ax.set_xlabel("Rating")
|
| 118 |
ax.set_ylabel("Median Review Length (chars)")
|
| 119 |
ax.set_title("Review Length by Rating")
|
|
@@ -134,7 +161,9 @@ df["year_month"] = df["datetime"].dt.to_period("M")
|
|
| 134 |
reviews_over_time = df.groupby("year_month").size()
|
| 135 |
|
| 136 |
fig, ax = plt.subplots(figsize=FIGURE_SIZE_WIDE)
|
| 137 |
-
reviews_over_time.plot(
|
|
|
|
|
|
|
| 138 |
ax.set_xlabel("Month")
|
| 139 |
ax.set_ylabel("Number of Reviews")
|
| 140 |
ax.set_title("Reviews Over Time")
|
|
@@ -156,7 +185,7 @@ missing = df.isnull().sum()
|
|
| 156 |
print("\nMissing values:")
|
| 157 |
for col, count in missing.items():
|
| 158 |
if count > 0:
|
| 159 |
-
print(f" {col}: {count:,} ({count/len(df)*100:.2f}%)")
|
| 160 |
if missing.sum() == 0:
|
| 161 |
print(" None!")
|
| 162 |
|
|
@@ -185,14 +214,18 @@ fig, axes = plt.subplots(1, 2, figsize=FIGURE_SIZE_WIDE)
|
|
| 185 |
|
| 186 |
# Reviews per user
|
| 187 |
ax1 = axes[0]
|
| 188 |
-
user_counts.clip(upper=20).value_counts().sort_index().plot(
|
|
|
|
|
|
|
| 189 |
ax1.set_xlabel("Reviews per User")
|
| 190 |
ax1.set_ylabel("Number of Users")
|
| 191 |
ax1.set_title("User Activity Distribution")
|
| 192 |
|
| 193 |
# Reviews per item
|
| 194 |
ax2 = axes[1]
|
| 195 |
-
item_counts.clip(upper=20).value_counts().sort_index().plot(
|
|
|
|
|
|
|
| 196 |
ax2.set_xlabel("Reviews per Item")
|
| 197 |
ax2.set_ylabel("Number of Items")
|
| 198 |
ax2.set_title("Item Popularity Distribution")
|
|
@@ -202,12 +235,16 @@ plt.savefig(FIGURES_DIR / "user_item_distribution.png", dpi=150)
|
|
| 202 |
plt.show()
|
| 203 |
|
| 204 |
print("\nUser activity:")
|
| 205 |
-
print(
|
|
|
|
|
|
|
| 206 |
print(f" Users with 5+ reviews: {(user_counts >= 5).sum():,}")
|
| 207 |
print(f" Max reviews by one user: {user_counts.max()}")
|
| 208 |
|
| 209 |
print("\nItem popularity:")
|
| 210 |
-
print(
|
|
|
|
|
|
|
| 211 |
print(f" Items with 5+ reviews: {(item_counts >= 5).sum():,}")
|
| 212 |
print(f" Max reviews for one item: {item_counts.max()}")
|
| 213 |
|
|
@@ -217,7 +254,9 @@ items_5plus = set(item_counts[item_counts >= 5].index)
|
|
| 217 |
|
| 218 |
eligible_mask = df["user_id"].isin(users_5plus) & df["parent_asin"].isin(items_5plus)
|
| 219 |
print("\n5-core filtering preview:")
|
| 220 |
-
print(
|
|
|
|
|
|
|
| 221 |
|
| 222 |
# %% Sample reviews across length buckets
|
| 223 |
print("\n=== Sample Reviews by Length Bucket ===")
|
|
@@ -232,14 +271,18 @@ length_buckets = [
|
|
| 232 |
]
|
| 233 |
|
| 234 |
for min_tok, max_tok, label in length_buckets:
|
| 235 |
-
bucket_mask = (df["estimated_tokens"] >= min_tok) & (
|
|
|
|
|
|
|
| 236 |
bucket_df = df[bucket_mask]
|
| 237 |
|
| 238 |
if len(bucket_df) == 0:
|
| 239 |
print(f"{label}: No reviews")
|
| 240 |
continue
|
| 241 |
|
| 242 |
-
print(
|
|
|
|
|
|
|
| 243 |
|
| 244 |
samples = bucket_df.sample(min(3, len(bucket_df)), random_state=42)
|
| 245 |
for _, row in samples.iterrows():
|
|
@@ -256,10 +299,14 @@ df_prepared = prepare_data(subset_size=DEV_SUBSET_SIZE, verbose=False)
|
|
| 256 |
prepared_stats = get_review_stats(df_prepared)
|
| 257 |
|
| 258 |
print(f"Raw reviews: {len(df):,}")
|
| 259 |
-
print(
|
|
|
|
|
|
|
| 260 |
print(f"Unique users: {prepared_stats['unique_users']:,}")
|
| 261 |
print(f"Unique items: {prepared_stats['unique_items']:,}")
|
| 262 |
-
print(
|
|
|
|
|
|
|
| 263 |
|
| 264 |
# %% Summary
|
| 265 |
print("\n" + "=" * 50)
|
|
@@ -269,7 +316,9 @@ print(f"Total reviews: {len(df):,}")
|
|
| 269 |
print(f"Unique users: {df['user_id'].nunique():,}")
|
| 270 |
print(f"Unique items: {df['parent_asin'].nunique():,}")
|
| 271 |
print(f"Average rating: {df['rating'].mean():.2f}")
|
| 272 |
-
print(
|
|
|
|
|
|
|
| 273 |
print(f"Data quality issues: {empty_reviews + very_short + duplicate_texts}")
|
| 274 |
print(f"\nPlots saved to: {FIGURES_DIR}")
|
| 275 |
|
|
|
|
| 17 |
|
| 18 |
# Plot configuration
|
| 19 |
plt.style.use("seaborn-v0_8-whitegrid")
|
| 20 |
+
plt.rcParams.update(
|
| 21 |
+
{
|
| 22 |
+
"figure.figsize": (10, 5),
|
| 23 |
+
"figure.dpi": 100,
|
| 24 |
+
"savefig.dpi": 150,
|
| 25 |
+
"font.size": 11,
|
| 26 |
+
"axes.titlesize": 12,
|
| 27 |
+
"axes.labelsize": 11,
|
| 28 |
+
"figure.autolayout": True,
|
| 29 |
+
}
|
| 30 |
+
)
|
| 31 |
|
| 32 |
# Enable retina display for Jupyter notebooks
|
| 33 |
try:
|
| 34 |
from IPython import get_ipython
|
| 35 |
+
|
| 36 |
if get_ipython() is not None:
|
| 37 |
get_ipython().run_line_magic("matplotlib", "inline")
|
| 38 |
get_ipython().run_line_magic("config", "InlineBackend.figure_format='retina'")
|
|
|
|
| 59 |
# %% Rating distribution
|
| 60 |
fig, ax = plt.subplots()
|
| 61 |
rating_counts = pd.Series(stats["rating_dist"])
|
| 62 |
+
bars = ax.bar(
|
| 63 |
+
rating_counts.index, rating_counts.values, color=PRIMARY_COLOR, edgecolor="black"
|
| 64 |
+
)
|
| 65 |
ax.set_xlabel("Rating")
|
| 66 |
ax.set_ylabel("Count")
|
| 67 |
ax.set_title("Rating Distribution")
|
| 68 |
ax.set_xticks(rating_counts.index)
|
| 69 |
|
| 70 |
for bar, count in zip(bars, rating_counts.values):
|
| 71 |
+
ax.text(
|
| 72 |
+
bar.get_x() + bar.get_width() / 2,
|
| 73 |
+
bar.get_height() + 50,
|
| 74 |
+
f"{count:,}",
|
| 75 |
+
ha="center",
|
| 76 |
+
va="bottom",
|
| 77 |
+
fontsize=10,
|
| 78 |
+
)
|
| 79 |
|
| 80 |
plt.tight_layout()
|
| 81 |
plt.savefig(FIGURES_DIR / "rating_distribution.png", dpi=150)
|
|
|
|
| 95 |
|
| 96 |
# Character length histogram
|
| 97 |
ax1 = axes[0]
|
| 98 |
+
df["text_length"].clip(upper=2000).hist(
|
| 99 |
+
bins=50, ax=ax1, color=PRIMARY_COLOR, edgecolor="white"
|
| 100 |
+
)
|
| 101 |
ax1.set_xlabel("Character Length (clipped at 2000)")
|
| 102 |
ax1.set_ylabel("Count")
|
| 103 |
ax1.set_title("Review Length Distribution")
|
| 104 |
+
ax1.axvline(
|
| 105 |
+
df["text_length"].median(),
|
| 106 |
+
color="red",
|
| 107 |
+
linestyle="--",
|
| 108 |
+
label=f"Median: {df['text_length'].median():.0f}",
|
| 109 |
+
)
|
| 110 |
ax1.legend()
|
| 111 |
|
| 112 |
# Token estimate histogram
|
| 113 |
ax2 = axes[1]
|
| 114 |
+
df["estimated_tokens"].clip(upper=500).hist(
|
| 115 |
+
bins=50, ax=ax2, color=SECONDARY_COLOR, edgecolor="white"
|
| 116 |
+
)
|
| 117 |
ax2.set_xlabel("Estimated Tokens (clipped at 500)")
|
| 118 |
ax2.set_ylabel("Count")
|
| 119 |
ax2.set_title("Estimated Token Distribution")
|
|
|
|
| 128 |
print("\nReview length stats:")
|
| 129 |
print(f" Median characters: {df['text_length'].median():.0f}")
|
| 130 |
print(f" Median tokens (est): {df['estimated_tokens'].median():.0f}")
|
| 131 |
+
print(
|
| 132 |
+
f" Reviews > 200 tokens: {needs_chunking:,} ({needs_chunking / len(df) * 100:.1f}%)"
|
| 133 |
+
)
|
| 134 |
|
| 135 |
# %% Review length by rating
|
| 136 |
fig, ax = plt.subplots()
|
| 137 |
length_by_rating = df.groupby("rating")["text_length"].median()
|
| 138 |
+
bars = ax.bar(
|
| 139 |
+
length_by_rating.index,
|
| 140 |
+
length_by_rating.values,
|
| 141 |
+
color=PRIMARY_COLOR,
|
| 142 |
+
edgecolor="white",
|
| 143 |
+
)
|
| 144 |
ax.set_xlabel("Rating")
|
| 145 |
ax.set_ylabel("Median Review Length (chars)")
|
| 146 |
ax.set_title("Review Length by Rating")
|
|
|
|
| 161 |
reviews_over_time = df.groupby("year_month").size()
|
| 162 |
|
| 163 |
fig, ax = plt.subplots(figsize=FIGURE_SIZE_WIDE)
|
| 164 |
+
reviews_over_time.plot(
|
| 165 |
+
kind="line", ax=ax, marker="o", markersize=3, linewidth=1, color=PRIMARY_COLOR
|
| 166 |
+
)
|
| 167 |
ax.set_xlabel("Month")
|
| 168 |
ax.set_ylabel("Number of Reviews")
|
| 169 |
ax.set_title("Reviews Over Time")
|
|
|
|
| 185 |
print("\nMissing values:")
|
| 186 |
for col, count in missing.items():
|
| 187 |
if count > 0:
|
| 188 |
+
print(f" {col}: {count:,} ({count / len(df) * 100:.2f}%)")
|
| 189 |
if missing.sum() == 0:
|
| 190 |
print(" None!")
|
| 191 |
|
|
|
|
| 214 |
|
| 215 |
# Reviews per user
|
| 216 |
ax1 = axes[0]
|
| 217 |
+
user_counts.clip(upper=20).value_counts().sort_index().plot(
|
| 218 |
+
kind="bar", ax=ax1, color=PRIMARY_COLOR
|
| 219 |
+
)
|
| 220 |
ax1.set_xlabel("Reviews per User")
|
| 221 |
ax1.set_ylabel("Number of Users")
|
| 222 |
ax1.set_title("User Activity Distribution")
|
| 223 |
|
| 224 |
# Reviews per item
|
| 225 |
ax2 = axes[1]
|
| 226 |
+
item_counts.clip(upper=20).value_counts().sort_index().plot(
|
| 227 |
+
kind="bar", ax=ax2, color=SECONDARY_COLOR
|
| 228 |
+
)
|
| 229 |
ax2.set_xlabel("Reviews per Item")
|
| 230 |
ax2.set_ylabel("Number of Items")
|
| 231 |
ax2.set_title("Item Popularity Distribution")
|
|
|
|
| 235 |
plt.show()
|
| 236 |
|
| 237 |
print("\nUser activity:")
|
| 238 |
+
print(
|
| 239 |
+
f" Users with 1 review: {(user_counts == 1).sum():,} ({(user_counts == 1).sum() / len(user_counts) * 100:.1f}%)"
|
| 240 |
+
)
|
| 241 |
print(f" Users with 5+ reviews: {(user_counts >= 5).sum():,}")
|
| 242 |
print(f" Max reviews by one user: {user_counts.max()}")
|
| 243 |
|
| 244 |
print("\nItem popularity:")
|
| 245 |
+
print(
|
| 246 |
+
f" Items with 1 review: {(item_counts == 1).sum():,} ({(item_counts == 1).sum() / len(item_counts) * 100:.1f}%)"
|
| 247 |
+
)
|
| 248 |
print(f" Items with 5+ reviews: {(item_counts >= 5).sum():,}")
|
| 249 |
print(f" Max reviews for one item: {item_counts.max()}")
|
| 250 |
|
|
|
|
| 254 |
|
| 255 |
eligible_mask = df["user_id"].isin(users_5plus) & df["parent_asin"].isin(items_5plus)
|
| 256 |
print("\n5-core filtering preview:")
|
| 257 |
+
print(
|
| 258 |
+
f" Reviews eligible (first pass): {eligible_mask.sum():,} ({eligible_mask.sum() / len(df) * 100:.1f}%)"
|
| 259 |
+
)
|
| 260 |
|
| 261 |
# %% Sample reviews across length buckets
|
| 262 |
print("\n=== Sample Reviews by Length Bucket ===")
|
|
|
|
| 271 |
]
|
| 272 |
|
| 273 |
for min_tok, max_tok, label in length_buckets:
|
| 274 |
+
bucket_mask = (df["estimated_tokens"] >= min_tok) & (
|
| 275 |
+
df["estimated_tokens"] < max_tok
|
| 276 |
+
)
|
| 277 |
bucket_df = df[bucket_mask]
|
| 278 |
|
| 279 |
if len(bucket_df) == 0:
|
| 280 |
print(f"{label}: No reviews")
|
| 281 |
continue
|
| 282 |
|
| 283 |
+
print(
|
| 284 |
+
f"{label}: {len(bucket_df):,} reviews ({len(bucket_df) / len(df) * 100:.1f}%)"
|
| 285 |
+
)
|
| 286 |
|
| 287 |
samples = bucket_df.sample(min(3, len(bucket_df)), random_state=42)
|
| 288 |
for _, row in samples.iterrows():
|
|
|
|
| 299 |
prepared_stats = get_review_stats(df_prepared)
|
| 300 |
|
| 301 |
print(f"Raw reviews: {len(df):,}")
|
| 302 |
+
print(
|
| 303 |
+
f"Prepared reviews: {len(df_prepared):,} ({len(df_prepared) / len(df) * 100:.1f}% retained)"
|
| 304 |
+
)
|
| 305 |
print(f"Unique users: {prepared_stats['unique_users']:,}")
|
| 306 |
print(f"Unique items: {prepared_stats['unique_items']:,}")
|
| 307 |
+
print(
|
| 308 |
+
f"Avg rating: {prepared_stats['avg_rating']:.2f} (raw: {stats['avg_rating']:.2f})"
|
| 309 |
+
)
|
| 310 |
|
| 311 |
# %% Summary
|
| 312 |
print("\n" + "=" * 50)
|
|
|
|
| 316 |
print(f"Unique users: {df['user_id'].nunique():,}")
|
| 317 |
print(f"Unique items: {df['parent_asin'].nunique():,}")
|
| 318 |
print(f"Average rating: {df['rating'].mean():.2f}")
|
| 319 |
+
print(
|
| 320 |
+
f"Reviews needing chunking: {needs_chunking:,} ({needs_chunking / len(df) * 100:.1f}%)"
|
| 321 |
+
)
|
| 322 |
print(f"Data quality issues: {empty_reviews + very_short + duplicate_texts}")
|
| 323 |
print(f"\nPlots saved to: {FIGURES_DIR}")
|
| 324 |
|
scripts/evaluation.py
CHANGED
|
@@ -48,6 +48,7 @@ def create_recommend_fn(
|
|
| 48 |
rating_weight: float = 0.0,
|
| 49 |
):
|
| 50 |
"""Create a recommend function for evaluation."""
|
|
|
|
| 51 |
def _recommend(query: str) -> list[str]:
|
| 52 |
recs = recommend(
|
| 53 |
query=query,
|
|
@@ -59,10 +60,13 @@ def create_recommend_fn(
|
|
| 59 |
rating_weight=rating_weight,
|
| 60 |
)
|
| 61 |
return [r.product_id for r in recs]
|
|
|
|
| 62 |
return _recommend
|
| 63 |
|
| 64 |
|
| 65 |
-
def save_results(
|
|
|
|
|
|
|
| 66 |
"""Save evaluation results to JSON file.
|
| 67 |
|
| 68 |
Also writes a fixed-name "latest" file so downstream scripts (e.g.
|
|
@@ -89,6 +93,7 @@ def save_results(results: dict, filename: str | None = None, dataset: str | None
|
|
| 89 |
# SECTION: Primary Evaluation
|
| 90 |
# ============================================================================
|
| 91 |
|
|
|
|
| 92 |
def run_primary_evaluation(cases, item_embeddings, item_popularity, total_items):
|
| 93 |
"""Run primary evaluation on leave-one-out dataset."""
|
| 94 |
log_banner(logger, "EVALUATION: Leave-One-Out (History Queries)")
|
|
@@ -124,6 +129,7 @@ def run_primary_evaluation(cases, item_embeddings, item_popularity, total_items)
|
|
| 124 |
# SECTION: Aggregation Methods
|
| 125 |
# ============================================================================
|
| 126 |
|
|
|
|
| 127 |
def run_aggregation_comparison(cases):
|
| 128 |
"""Compare different aggregation methods."""
|
| 129 |
log_banner(logger, "AGGREGATION METHOD COMPARISON")
|
|
@@ -154,6 +160,7 @@ def run_aggregation_comparison(cases):
|
|
| 154 |
# SECTION: Rating Filter
|
| 155 |
# ============================================================================
|
| 156 |
|
|
|
|
| 157 |
def run_rating_filter_comparison(cases):
|
| 158 |
"""Compare different rating filters."""
|
| 159 |
log_banner(logger, "RATING FILTER COMPARISON")
|
|
@@ -177,6 +184,7 @@ def run_rating_filter_comparison(cases):
|
|
| 177 |
# SECTION: K Values
|
| 178 |
# ============================================================================
|
| 179 |
|
|
|
|
| 180 |
def run_k_value_comparison(cases):
|
| 181 |
"""Compare metrics at different K values."""
|
| 182 |
log_banner(logger, "METRICS AT DIFFERENT K VALUES")
|
|
@@ -200,16 +208,23 @@ def run_k_value_comparison(cases):
|
|
| 200 |
# SECTION: Weight Tuning
|
| 201 |
# ============================================================================
|
| 202 |
|
|
|
|
| 203 |
def run_weight_tuning(cases):
|
| 204 |
"""Run ranking weight tuning experiment."""
|
| 205 |
log_banner(logger, "RANKING WEIGHT TUNING (alpha*sim + beta*rating)")
|
| 206 |
|
| 207 |
weight_configs = [
|
| 208 |
-
(1.0, 0.0),
|
| 209 |
-
(0.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
]
|
| 211 |
|
| 212 |
-
logger.info(
|
|
|
|
|
|
|
| 213 |
logger.info("-" * 52)
|
| 214 |
|
| 215 |
results = []
|
|
@@ -227,15 +242,22 @@ def run_weight_tuning(cases):
|
|
| 227 |
k=10,
|
| 228 |
verbose=False,
|
| 229 |
)
|
| 230 |
-
results.append(
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
|
|
|
|
|
|
|
|
|
| 236 |
logger.info(
|
| 237 |
"%-10.1f %-12.1f %-10.4f %-10.4f %-10.4f",
|
| 238 |
-
alpha,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
)
|
| 240 |
|
| 241 |
if report.ndcg_at_k > best_ndcg:
|
|
@@ -245,7 +267,9 @@ def run_weight_tuning(cases):
|
|
| 245 |
logger.info("-" * 52)
|
| 246 |
logger.info(
|
| 247 |
"Best: alpha=%.1f, beta=%.1f (NDCG@10=%.4f)",
|
| 248 |
-
best_weights[0],
|
|
|
|
|
|
|
| 249 |
)
|
| 250 |
|
| 251 |
return results, best_weights, best_ndcg
|
|
@@ -255,6 +279,7 @@ def run_weight_tuning(cases):
|
|
| 255 |
# SECTION: Baseline Comparison
|
| 256 |
# ============================================================================
|
| 257 |
|
|
|
|
| 258 |
def run_baseline_comparison(cases, train_records, all_products, product_embeddings):
|
| 259 |
"""Compare against baselines: Random, Popularity, ItemKNN."""
|
| 260 |
log_banner(logger, "BASELINE COMPARISON")
|
|
@@ -274,7 +299,12 @@ def run_baseline_comparison(cases, train_records, all_products, product_embeddin
|
|
| 274 |
return itemknn_baseline.recommend(query, top_k=10)
|
| 275 |
|
| 276 |
def rag_recommend(query: str) -> list[str]:
|
| 277 |
-
recs = recommend(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 278 |
return [r.product_id for r in recs]
|
| 279 |
|
| 280 |
results = {}
|
|
@@ -305,7 +335,10 @@ def run_baseline_comparison(cases, train_records, all_products, product_embeddin
|
|
| 305 |
for name, report in results.items():
|
| 306 |
logger.info(
|
| 307 |
"%-15s %10.4f %10.4f %10.4f",
|
| 308 |
-
name,
|
|
|
|
|
|
|
|
|
|
| 309 |
)
|
| 310 |
|
| 311 |
# Relative improvements
|
|
@@ -323,17 +356,22 @@ def run_baseline_comparison(cases, train_records, all_products, product_embeddin
|
|
| 323 |
# Main
|
| 324 |
# ============================================================================
|
| 325 |
|
|
|
|
| 326 |
def main():
|
| 327 |
parser = argparse.ArgumentParser(description="Run recommendation evaluation")
|
| 328 |
-
parser.add_argument("--baselines", action="store_true", help="Include baseline comparison")
|
| 329 |
parser.add_argument(
|
| 330 |
-
"--
|
|
|
|
|
|
|
|
|
|
|
|
|
| 331 |
choices=["all", "primary", "aggregation", "rating", "k", "weights"],
|
| 332 |
default="primary",
|
| 333 |
help="Which section to run (default: primary)",
|
| 334 |
)
|
| 335 |
parser.add_argument(
|
| 336 |
-
"--dataset",
|
|
|
|
| 337 |
default="eval_loo_history.json",
|
| 338 |
help="Evaluation dataset file (default: eval_loo_history.json)",
|
| 339 |
)
|
|
@@ -375,7 +413,9 @@ def main():
|
|
| 375 |
)
|
| 376 |
|
| 377 |
if args.section in ("all", "aggregation"):
|
| 378 |
-
all_results["experiments"]["aggregation_methods"] = run_aggregation_comparison(
|
|
|
|
|
|
|
| 379 |
|
| 380 |
if args.section in ("all", "rating"):
|
| 381 |
run_rating_filter_comparison(cases)
|
|
|
|
| 48 |
rating_weight: float = 0.0,
|
| 49 |
):
|
| 50 |
"""Create a recommend function for evaluation."""
|
| 51 |
+
|
| 52 |
def _recommend(query: str) -> list[str]:
|
| 53 |
recs = recommend(
|
| 54 |
query=query,
|
|
|
|
| 60 |
rating_weight=rating_weight,
|
| 61 |
)
|
| 62 |
return [r.product_id for r in recs]
|
| 63 |
+
|
| 64 |
return _recommend
|
| 65 |
|
| 66 |
|
| 67 |
+
def save_results(
|
| 68 |
+
results: dict, filename: str | None = None, dataset: str | None = None
|
| 69 |
+
) -> Path:
|
| 70 |
"""Save evaluation results to JSON file.
|
| 71 |
|
| 72 |
Also writes a fixed-name "latest" file so downstream scripts (e.g.
|
|
|
|
| 93 |
# SECTION: Primary Evaluation
|
| 94 |
# ============================================================================
|
| 95 |
|
| 96 |
+
|
| 97 |
def run_primary_evaluation(cases, item_embeddings, item_popularity, total_items):
|
| 98 |
"""Run primary evaluation on leave-one-out dataset."""
|
| 99 |
log_banner(logger, "EVALUATION: Leave-One-Out (History Queries)")
|
|
|
|
| 129 |
# SECTION: Aggregation Methods
|
| 130 |
# ============================================================================
|
| 131 |
|
| 132 |
+
|
| 133 |
def run_aggregation_comparison(cases):
|
| 134 |
"""Compare different aggregation methods."""
|
| 135 |
log_banner(logger, "AGGREGATION METHOD COMPARISON")
|
|
|
|
| 160 |
# SECTION: Rating Filter
|
| 161 |
# ============================================================================
|
| 162 |
|
| 163 |
+
|
| 164 |
def run_rating_filter_comparison(cases):
|
| 165 |
"""Compare different rating filters."""
|
| 166 |
log_banner(logger, "RATING FILTER COMPARISON")
|
|
|
|
| 184 |
# SECTION: K Values
|
| 185 |
# ============================================================================
|
| 186 |
|
| 187 |
+
|
| 188 |
def run_k_value_comparison(cases):
|
| 189 |
"""Compare metrics at different K values."""
|
| 190 |
log_banner(logger, "METRICS AT DIFFERENT K VALUES")
|
|
|
|
| 208 |
# SECTION: Weight Tuning
|
| 209 |
# ============================================================================
|
| 210 |
|
| 211 |
+
|
| 212 |
def run_weight_tuning(cases):
|
| 213 |
"""Run ranking weight tuning experiment."""
|
| 214 |
log_banner(logger, "RANKING WEIGHT TUNING (alpha*sim + beta*rating)")
|
| 215 |
|
| 216 |
weight_configs = [
|
| 217 |
+
(1.0, 0.0),
|
| 218 |
+
(0.9, 0.1),
|
| 219 |
+
(0.8, 0.2),
|
| 220 |
+
(0.7, 0.3),
|
| 221 |
+
(0.6, 0.4),
|
| 222 |
+
(0.5, 0.5),
|
| 223 |
]
|
| 224 |
|
| 225 |
+
logger.info(
|
| 226 |
+
"%-10s %-12s %-10s %-10s %-10s", "alpha", "beta", "NDCG@10", "Hit@10", "MRR"
|
| 227 |
+
)
|
| 228 |
logger.info("-" * 52)
|
| 229 |
|
| 230 |
results = []
|
|
|
|
| 242 |
k=10,
|
| 243 |
verbose=False,
|
| 244 |
)
|
| 245 |
+
results.append(
|
| 246 |
+
{
|
| 247 |
+
"alpha": alpha,
|
| 248 |
+
"beta": beta,
|
| 249 |
+
"ndcg_at_10": report.ndcg_at_k,
|
| 250 |
+
"hit_at_10": report.hit_at_k,
|
| 251 |
+
"mrr": report.mrr,
|
| 252 |
+
}
|
| 253 |
+
)
|
| 254 |
logger.info(
|
| 255 |
"%-10.1f %-12.1f %-10.4f %-10.4f %-10.4f",
|
| 256 |
+
alpha,
|
| 257 |
+
beta,
|
| 258 |
+
report.ndcg_at_k,
|
| 259 |
+
report.hit_at_k,
|
| 260 |
+
report.mrr,
|
| 261 |
)
|
| 262 |
|
| 263 |
if report.ndcg_at_k > best_ndcg:
|
|
|
|
| 267 |
logger.info("-" * 52)
|
| 268 |
logger.info(
|
| 269 |
"Best: alpha=%.1f, beta=%.1f (NDCG@10=%.4f)",
|
| 270 |
+
best_weights[0],
|
| 271 |
+
best_weights[1],
|
| 272 |
+
best_ndcg,
|
| 273 |
)
|
| 274 |
|
| 275 |
return results, best_weights, best_ndcg
|
|
|
|
| 279 |
# SECTION: Baseline Comparison
|
| 280 |
# ============================================================================
|
| 281 |
|
| 282 |
+
|
| 283 |
def run_baseline_comparison(cases, train_records, all_products, product_embeddings):
|
| 284 |
"""Compare against baselines: Random, Popularity, ItemKNN."""
|
| 285 |
log_banner(logger, "BASELINE COMPARISON")
|
|
|
|
| 299 |
return itemknn_baseline.recommend(query, top_k=10)
|
| 300 |
|
| 301 |
def rag_recommend(query: str) -> list[str]:
|
| 302 |
+
recs = recommend(
|
| 303 |
+
query=query,
|
| 304 |
+
top_k=10,
|
| 305 |
+
candidate_limit=100,
|
| 306 |
+
aggregation=AggregationMethod.MAX,
|
| 307 |
+
)
|
| 308 |
return [r.product_id for r in recs]
|
| 309 |
|
| 310 |
results = {}
|
|
|
|
| 335 |
for name, report in results.items():
|
| 336 |
logger.info(
|
| 337 |
"%-15s %10.4f %10.4f %10.4f",
|
| 338 |
+
name,
|
| 339 |
+
report.ndcg_at_k,
|
| 340 |
+
report.hit_at_k,
|
| 341 |
+
report.mrr,
|
| 342 |
)
|
| 343 |
|
| 344 |
# Relative improvements
|
|
|
|
| 356 |
# Main
|
| 357 |
# ============================================================================
|
| 358 |
|
| 359 |
+
|
| 360 |
def main():
|
| 361 |
parser = argparse.ArgumentParser(description="Run recommendation evaluation")
|
|
|
|
| 362 |
parser.add_argument(
|
| 363 |
+
"--baselines", action="store_true", help="Include baseline comparison"
|
| 364 |
+
)
|
| 365 |
+
parser.add_argument(
|
| 366 |
+
"--section",
|
| 367 |
+
"-s",
|
| 368 |
choices=["all", "primary", "aggregation", "rating", "k", "weights"],
|
| 369 |
default="primary",
|
| 370 |
help="Which section to run (default: primary)",
|
| 371 |
)
|
| 372 |
parser.add_argument(
|
| 373 |
+
"--dataset",
|
| 374 |
+
"-d",
|
| 375 |
default="eval_loo_history.json",
|
| 376 |
help="Evaluation dataset file (default: eval_loo_history.json)",
|
| 377 |
)
|
|
|
|
| 413 |
)
|
| 414 |
|
| 415 |
if args.section in ("all", "aggregation"):
|
| 416 |
+
all_results["experiments"]["aggregation_methods"] = run_aggregation_comparison(
|
| 417 |
+
cases
|
| 418 |
+
)
|
| 419 |
|
| 420 |
if args.section in ("all", "rating"):
|
| 421 |
run_rating_filter_comparison(cases)
|
scripts/explanation.py
CHANGED
|
@@ -40,6 +40,7 @@ PRODUCTS_PER_QUERY = 2
|
|
| 40 |
# SECTION: Basic Explanation Generation
|
| 41 |
# ============================================================================
|
| 42 |
|
|
|
|
| 43 |
def run_basic_tests():
|
| 44 |
"""Test basic explanation generation and HHEM detection."""
|
| 45 |
from sage.services.explanation import Explainer
|
|
@@ -59,10 +60,13 @@ def run_basic_tests():
|
|
| 59 |
query_results = {}
|
| 60 |
for query in test_queries:
|
| 61 |
products = get_candidates(
|
| 62 |
-
query=query,
|
|
|
|
|
|
|
|
|
|
| 63 |
)
|
| 64 |
query_results[query] = products
|
| 65 |
-
logger.info(
|
| 66 |
logger.info(" Found %d products", len(products))
|
| 67 |
|
| 68 |
# Generate explanations
|
|
@@ -71,7 +75,7 @@ def run_basic_tests():
|
|
| 71 |
all_explanations = []
|
| 72 |
|
| 73 |
for query, products in query_results.items():
|
| 74 |
-
logger.info(
|
| 75 |
for product in products[:PRODUCTS_PER_QUERY]:
|
| 76 |
result = explainer.generate_explanation(query, product)
|
| 77 |
all_explanations.append(result)
|
|
@@ -100,7 +104,7 @@ def run_basic_tests():
|
|
| 100 |
if query_results:
|
| 101 |
test_query = list(query_results.keys())[0]
|
| 102 |
test_product = query_results[test_query][0]
|
| 103 |
-
logger.info(
|
| 104 |
logger.info("Streaming: ")
|
| 105 |
|
| 106 |
stream = explainer.generate_explanation_stream(test_query, test_product)
|
|
@@ -110,7 +114,9 @@ def run_basic_tests():
|
|
| 110 |
logger.info("".join(chunks))
|
| 111 |
|
| 112 |
streamed_result = stream.get_complete_result()
|
| 113 |
-
hhem = detector.check_explanation(
|
|
|
|
|
|
|
| 114 |
logger.info("HHEM Score: %.3f", hhem.score)
|
| 115 |
|
| 116 |
log_banner(logger, "BASIC TESTS COMPLETE")
|
|
@@ -120,7 +126,10 @@ def run_basic_tests():
|
|
| 120 |
# SECTION: Evidence Quality Gate
|
| 121 |
# ============================================================================
|
| 122 |
|
| 123 |
-
|
|
|
|
|
|
|
|
|
|
| 124 |
"""Create a mock ProductScore for testing."""
|
| 125 |
chunks = [
|
| 126 |
RetrievedChunk(
|
|
@@ -145,7 +154,11 @@ def run_quality_gate_tests():
|
|
| 145 |
"""Test the evidence quality gate."""
|
| 146 |
from sage.core.evidence import check_evidence_quality, generate_refusal_message
|
| 147 |
from sage.services.faithfulness import is_refusal
|
| 148 |
-
from sage.config import
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
|
| 150 |
log_banner(logger, "EVIDENCE QUALITY GATE TESTS")
|
| 151 |
|
|
@@ -161,7 +174,14 @@ def run_quality_gate_tests():
|
|
| 161 |
product = create_mock_product(n_chunks, tok, score)
|
| 162 |
quality = check_evidence_quality(product)
|
| 163 |
status = "PASS" if quality.is_sufficient == expected else "FAIL"
|
| 164 |
-
logger.info(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
assert quality.is_sufficient == expected
|
| 166 |
|
| 167 |
log_section(logger, "2. REFUSAL GENERATION")
|
|
@@ -172,12 +192,18 @@ def run_quality_gate_tests():
|
|
| 172 |
quality = check_evidence_quality(product)
|
| 173 |
refusal = generate_refusal_message(query, quality)
|
| 174 |
detected = is_refusal(refusal)
|
| 175 |
-
logger.info(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
assert detected
|
| 177 |
|
| 178 |
logger.info(
|
| 179 |
"Thresholds: chunks=%d, tokens=%d, score=%.2f",
|
| 180 |
-
MIN_EVIDENCE_CHUNKS,
|
|
|
|
|
|
|
| 181 |
)
|
| 182 |
log_banner(logger, "QUALITY GATE TESTS COMPLETE")
|
| 183 |
|
|
@@ -186,9 +212,14 @@ def run_quality_gate_tests():
|
|
| 186 |
# SECTION: Verification Loop
|
| 187 |
# ============================================================================
|
| 188 |
|
|
|
|
| 189 |
def run_verification_tests():
|
| 190 |
"""Test the post-generation verification loop."""
|
| 191 |
-
from sage.core.verification import
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
|
| 193 |
log_banner(logger, "VERIFICATION LOOP TESTS")
|
| 194 |
|
|
@@ -235,6 +266,7 @@ def run_verification_tests():
|
|
| 235 |
# SECTION: Cold-Start
|
| 236 |
# ============================================================================
|
| 237 |
|
|
|
|
| 238 |
def run_cold_start_tests():
|
| 239 |
"""Test cold-start handling."""
|
| 240 |
from sage.services.cold_start import (
|
|
@@ -264,7 +296,9 @@ def run_cold_start_tests():
|
|
| 264 |
for count in test_counts:
|
| 265 |
level = get_warmup_level(count)
|
| 266 |
weight = get_content_weight(count)
|
| 267 |
-
logger.info(
|
|
|
|
|
|
|
| 268 |
|
| 269 |
# Test preferences to query
|
| 270 |
log_section(logger, "2. PREFERENCES TO QUERY")
|
|
@@ -277,7 +311,7 @@ def run_cold_start_tests():
|
|
| 277 |
)
|
| 278 |
query = preferences_to_query(prefs)
|
| 279 |
logger.info("Preferences: %s", prefs)
|
| 280 |
-
logger.info(
|
| 281 |
|
| 282 |
# Test cold-start recommendations
|
| 283 |
log_section(logger, "3. COLD-START RECOMMENDATIONS")
|
|
@@ -290,7 +324,9 @@ def run_cold_start_tests():
|
|
| 290 |
)
|
| 291 |
logger.info("Got %d recommendations", len(recs))
|
| 292 |
for r in recs[:3]:
|
| 293 |
-
logger.info(
|
|
|
|
|
|
|
| 294 |
|
| 295 |
logger.info("Query-based (cold user):")
|
| 296 |
recs = recommend_cold_start_user(
|
|
@@ -337,10 +373,12 @@ def run_cold_start_tests():
|
|
| 337 |
# Main
|
| 338 |
# ============================================================================
|
| 339 |
|
|
|
|
| 340 |
def main():
|
| 341 |
parser = argparse.ArgumentParser(description="Run explanation tests")
|
| 342 |
parser.add_argument(
|
| 343 |
-
"--section",
|
|
|
|
| 344 |
choices=["all", "basic", "gate", "verify", "cold"],
|
| 345 |
default="all",
|
| 346 |
help="Which section to run",
|
|
|
|
| 40 |
# SECTION: Basic Explanation Generation
|
| 41 |
# ============================================================================
|
| 42 |
|
| 43 |
+
|
| 44 |
def run_basic_tests():
|
| 45 |
"""Test basic explanation generation and HHEM detection."""
|
| 46 |
from sage.services.explanation import Explainer
|
|
|
|
| 60 |
query_results = {}
|
| 61 |
for query in test_queries:
|
| 62 |
products = get_candidates(
|
| 63 |
+
query=query,
|
| 64 |
+
k=TOP_K_PRODUCTS,
|
| 65 |
+
min_rating=4.0,
|
| 66 |
+
aggregation=AggregationMethod.MAX,
|
| 67 |
)
|
| 68 |
query_results[query] = products
|
| 69 |
+
logger.info('Query: "%s"', query)
|
| 70 |
logger.info(" Found %d products", len(products))
|
| 71 |
|
| 72 |
# Generate explanations
|
|
|
|
| 75 |
all_explanations = []
|
| 76 |
|
| 77 |
for query, products in query_results.items():
|
| 78 |
+
logger.info('--- Query: "%s" ---', query)
|
| 79 |
for product in products[:PRODUCTS_PER_QUERY]:
|
| 80 |
result = explainer.generate_explanation(query, product)
|
| 81 |
all_explanations.append(result)
|
|
|
|
| 104 |
if query_results:
|
| 105 |
test_query = list(query_results.keys())[0]
|
| 106 |
test_product = query_results[test_query][0]
|
| 107 |
+
logger.info('Query: "%s"', test_query)
|
| 108 |
logger.info("Streaming: ")
|
| 109 |
|
| 110 |
stream = explainer.generate_explanation_stream(test_query, test_product)
|
|
|
|
| 114 |
logger.info("".join(chunks))
|
| 115 |
|
| 116 |
streamed_result = stream.get_complete_result()
|
| 117 |
+
hhem = detector.check_explanation(
|
| 118 |
+
streamed_result.evidence_texts, streamed_result.explanation
|
| 119 |
+
)
|
| 120 |
logger.info("HHEM Score: %.3f", hhem.score)
|
| 121 |
|
| 122 |
log_banner(logger, "BASIC TESTS COMPLETE")
|
|
|
|
| 126 |
# SECTION: Evidence Quality Gate
|
| 127 |
# ============================================================================
|
| 128 |
|
| 129 |
+
|
| 130 |
+
def create_mock_product(
|
| 131 |
+
n_chunks: int, tokens_per_chunk: int = 100, product_score: float = 0.85
|
| 132 |
+
) -> ProductScore:
|
| 133 |
"""Create a mock ProductScore for testing."""
|
| 134 |
chunks = [
|
| 135 |
RetrievedChunk(
|
|
|
|
| 154 |
"""Test the evidence quality gate."""
|
| 155 |
from sage.core.evidence import check_evidence_quality, generate_refusal_message
|
| 156 |
from sage.services.faithfulness import is_refusal
|
| 157 |
+
from sage.config import (
|
| 158 |
+
MIN_EVIDENCE_CHUNKS,
|
| 159 |
+
MIN_EVIDENCE_TOKENS,
|
| 160 |
+
MIN_RETRIEVAL_SCORE,
|
| 161 |
+
)
|
| 162 |
|
| 163 |
log_banner(logger, "EVIDENCE QUALITY GATE TESTS")
|
| 164 |
|
|
|
|
| 174 |
product = create_mock_product(n_chunks, tok, score)
|
| 175 |
quality = check_evidence_quality(product)
|
| 176 |
status = "PASS" if quality.is_sufficient == expected else "FAIL"
|
| 177 |
+
logger.info(
|
| 178 |
+
"[%s] %d chunks, %d tok, score=%.2f -> %s",
|
| 179 |
+
status,
|
| 180 |
+
n_chunks,
|
| 181 |
+
tok,
|
| 182 |
+
score,
|
| 183 |
+
reason,
|
| 184 |
+
)
|
| 185 |
assert quality.is_sufficient == expected
|
| 186 |
|
| 187 |
log_section(logger, "2. REFUSAL GENERATION")
|
|
|
|
| 192 |
quality = check_evidence_quality(product)
|
| 193 |
refusal = generate_refusal_message(query, quality)
|
| 194 |
detected = is_refusal(refusal)
|
| 195 |
+
logger.info(
|
| 196 |
+
"[%s] Refusal detected for %s",
|
| 197 |
+
"PASS" if detected else "FAIL",
|
| 198 |
+
quality.failure_reason,
|
| 199 |
+
)
|
| 200 |
assert detected
|
| 201 |
|
| 202 |
logger.info(
|
| 203 |
"Thresholds: chunks=%d, tokens=%d, score=%.2f",
|
| 204 |
+
MIN_EVIDENCE_CHUNKS,
|
| 205 |
+
MIN_EVIDENCE_TOKENS,
|
| 206 |
+
MIN_RETRIEVAL_SCORE,
|
| 207 |
)
|
| 208 |
log_banner(logger, "QUALITY GATE TESTS COMPLETE")
|
| 209 |
|
|
|
|
| 212 |
# SECTION: Verification Loop
|
| 213 |
# ============================================================================
|
| 214 |
|
| 215 |
+
|
| 216 |
def run_verification_tests():
|
| 217 |
"""Test the post-generation verification loop."""
|
| 218 |
+
from sage.core.verification import (
|
| 219 |
+
extract_quotes,
|
| 220 |
+
verify_quote_in_evidence,
|
| 221 |
+
verify_explanation,
|
| 222 |
+
)
|
| 223 |
|
| 224 |
log_banner(logger, "VERIFICATION LOOP TESTS")
|
| 225 |
|
|
|
|
| 266 |
# SECTION: Cold-Start
|
| 267 |
# ============================================================================
|
| 268 |
|
| 269 |
+
|
| 270 |
def run_cold_start_tests():
|
| 271 |
"""Test cold-start handling."""
|
| 272 |
from sage.services.cold_start import (
|
|
|
|
| 296 |
for count in test_counts:
|
| 297 |
level = get_warmup_level(count)
|
| 298 |
weight = get_content_weight(count)
|
| 299 |
+
logger.info(
|
| 300 |
+
" %d interactions: level=%s, content_weight=%.1f", count, level, weight
|
| 301 |
+
)
|
| 302 |
|
| 303 |
# Test preferences to query
|
| 304 |
log_section(logger, "2. PREFERENCES TO QUERY")
|
|
|
|
| 311 |
)
|
| 312 |
query = preferences_to_query(prefs)
|
| 313 |
logger.info("Preferences: %s", prefs)
|
| 314 |
+
logger.info('Query: "%s"', query)
|
| 315 |
|
| 316 |
# Test cold-start recommendations
|
| 317 |
log_section(logger, "3. COLD-START RECOMMENDATIONS")
|
|
|
|
| 324 |
)
|
| 325 |
logger.info("Got %d recommendations", len(recs))
|
| 326 |
for r in recs[:3]:
|
| 327 |
+
logger.info(
|
| 328 |
+
" %s: score=%.3f, rating=%.1f", r.product_id, r.score, r.avg_rating
|
| 329 |
+
)
|
| 330 |
|
| 331 |
logger.info("Query-based (cold user):")
|
| 332 |
recs = recommend_cold_start_user(
|
|
|
|
| 373 |
# Main
|
| 374 |
# ============================================================================
|
| 375 |
|
| 376 |
+
|
| 377 |
def main():
|
| 378 |
parser = argparse.ArgumentParser(description="Run explanation tests")
|
| 379 |
parser.add_argument(
|
| 380 |
+
"--section",
|
| 381 |
+
"-s",
|
| 382 |
choices=["all", "basic", "gate", "verify", "cold"],
|
| 383 |
default="all",
|
| 384 |
help="Which section to run",
|
scripts/faithfulness.py
CHANGED
|
@@ -47,6 +47,7 @@ TOP_K_PRODUCTS = 3
|
|
| 47 |
# SECTION: Core Evaluation
|
| 48 |
# ============================================================================
|
| 49 |
|
|
|
|
| 50 |
def run_evaluation(n_samples: int, run_ragas: bool = False):
|
| 51 |
"""Run faithfulness evaluation on sample queries."""
|
| 52 |
from sage.services.explanation import Explainer
|
|
@@ -64,8 +65,13 @@ def run_evaluation(n_samples: int, run_ragas: bool = False):
|
|
| 64 |
all_explanations = []
|
| 65 |
|
| 66 |
for i, query in enumerate(queries, 1):
|
| 67 |
-
logger.info(
|
| 68 |
-
products = get_candidates(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
if not products:
|
| 71 |
logger.info(" No products found")
|
|
@@ -73,7 +79,9 @@ def run_evaluation(n_samples: int, run_ragas: bool = False):
|
|
| 73 |
|
| 74 |
product = products[0]
|
| 75 |
try:
|
| 76 |
-
result = explainer.generate_explanation(
|
|
|
|
|
|
|
| 77 |
all_explanations.append(result)
|
| 78 |
logger.info(" %s: %s...", product.product_id, result.explanation[:60])
|
| 79 |
except Exception:
|
|
@@ -101,7 +109,9 @@ def run_evaluation(n_samples: int, run_ragas: bool = False):
|
|
| 101 |
|
| 102 |
logger.info(
|
| 103 |
"HHEM (full-explanation): %d/%d grounded, mean=%.3f",
|
| 104 |
-
len(hhem_results) - n_hallucinated,
|
|
|
|
|
|
|
| 105 |
)
|
| 106 |
|
| 107 |
# Multi-metric faithfulness (claim-level as primary)
|
|
@@ -109,20 +119,24 @@ def run_evaluation(n_samples: int, run_ragas: bool = False):
|
|
| 109 |
|
| 110 |
from sage.services.faithfulness import compute_multi_metric_faithfulness
|
| 111 |
|
| 112 |
-
multi_items = [
|
| 113 |
-
(expl.evidence_texts, expl.explanation) for expl in all_explanations
|
| 114 |
-
]
|
| 115 |
multi_report = compute_multi_metric_faithfulness(multi_items)
|
| 116 |
|
| 117 |
-
logger.info(
|
| 118 |
-
|
|
|
|
|
|
|
| 119 |
multi_report.quote_verification_rate * 100,
|
| 120 |
)
|
| 121 |
-
logger.info(
|
| 122 |
-
|
|
|
|
|
|
|
| 123 |
)
|
| 124 |
-
logger.info(
|
| 125 |
-
|
|
|
|
|
|
|
| 126 |
)
|
| 127 |
|
| 128 |
# RAGAS (optional)
|
|
@@ -132,16 +146,17 @@ def run_evaluation(n_samples: int, run_ragas: bool = False):
|
|
| 132 |
|
| 133 |
try:
|
| 134 |
from sage.services.faithfulness import FaithfulnessEvaluator
|
|
|
|
| 135 |
evaluator = FaithfulnessEvaluator()
|
| 136 |
ragas_report = evaluator.evaluate_batch(all_explanations)
|
| 137 |
|
| 138 |
logger.info(
|
| 139 |
"Faithfulness: %.3f +/- %.3f",
|
| 140 |
-
ragas_report.mean_score,
|
|
|
|
| 141 |
)
|
| 142 |
logger.info(
|
| 143 |
-
"Passing: %d/%d",
|
| 144 |
-
ragas_report.n_passing, ragas_report.n_samples
|
| 145 |
)
|
| 146 |
except Exception:
|
| 147 |
logger.exception("RAGAS evaluation failed")
|
|
@@ -217,8 +232,10 @@ def run_failure_analysis():
|
|
| 217 |
case_id = 0
|
| 218 |
|
| 219 |
for query in ANALYSIS_QUERIES:
|
| 220 |
-
logger.info(
|
| 221 |
-
products = get_candidates(
|
|
|
|
|
|
|
| 222 |
|
| 223 |
if not products:
|
| 224 |
continue
|
|
@@ -226,21 +243,27 @@ def run_failure_analysis():
|
|
| 226 |
for product in products[:2]:
|
| 227 |
try:
|
| 228 |
result = explainer.generate_explanation(query, product, max_evidence=3)
|
| 229 |
-
hhem = detector.check_explanation(
|
|
|
|
|
|
|
| 230 |
|
| 231 |
case_id += 1
|
| 232 |
-
all_cases.append(
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
|
|
|
|
|
|
| 241 |
|
| 242 |
status = "FAIL" if hhem.is_hallucinated else "PASS"
|
| 243 |
-
logger.info(
|
|
|
|
|
|
|
| 244 |
except Exception:
|
| 245 |
logger.exception(" Error processing product")
|
| 246 |
|
|
@@ -254,7 +277,9 @@ def run_failure_analysis():
|
|
| 254 |
|
| 255 |
log_banner(logger, "ANALYSIS SUMMARY")
|
| 256 |
logger.info("Total cases: %d", len(all_cases))
|
| 257 |
-
logger.info(
|
|
|
|
|
|
|
| 258 |
logger.info("Passes: %d", len(passes))
|
| 259 |
|
| 260 |
# Categorize failures
|
|
@@ -274,6 +299,7 @@ def run_failure_analysis():
|
|
| 274 |
# SECTION: Adjusted Faithfulness
|
| 275 |
# ============================================================================
|
| 276 |
|
|
|
|
| 277 |
def run_adjusted_calculation():
|
| 278 |
"""Calculate adjusted faithfulness with refusals excluded."""
|
| 279 |
from sage.services.faithfulness import is_refusal
|
|
@@ -296,8 +322,14 @@ def run_adjusted_calculation():
|
|
| 296 |
|
| 297 |
# Classify
|
| 298 |
refusals = [c for c in cases if is_refusal(c["explanation"])]
|
| 299 |
-
non_refusal_passes = [
|
| 300 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 301 |
|
| 302 |
n_total = len(cases)
|
| 303 |
raw_pass = sum(1 for c in cases if not c["is_hallucinated"])
|
|
@@ -309,9 +341,21 @@ def run_adjusted_calculation():
|
|
| 309 |
logger.info("Non-refusal fails: %d", len(non_refusal_fails))
|
| 310 |
|
| 311 |
log_section(logger, "Metrics")
|
| 312 |
-
logger.info(
|
| 313 |
-
|
| 314 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 315 |
|
| 316 |
# Save
|
| 317 |
output = {
|
|
@@ -328,12 +372,15 @@ def run_adjusted_calculation():
|
|
| 328 |
# Main
|
| 329 |
# ============================================================================
|
| 330 |
|
|
|
|
| 331 |
def main():
|
| 332 |
parser = argparse.ArgumentParser(description="Run faithfulness evaluation")
|
| 333 |
parser.add_argument("--samples", "-n", type=int, default=DEFAULT_SAMPLES)
|
| 334 |
parser.add_argument("--ragas", action="store_true", help="Include RAGAS evaluation")
|
| 335 |
parser.add_argument("--analyze", action="store_true", help="Run failure analysis")
|
| 336 |
-
parser.add_argument(
|
|
|
|
|
|
|
| 337 |
args = parser.parse_args()
|
| 338 |
|
| 339 |
if args.analyze:
|
|
|
|
| 47 |
# SECTION: Core Evaluation
|
| 48 |
# ============================================================================
|
| 49 |
|
| 50 |
+
|
| 51 |
def run_evaluation(n_samples: int, run_ragas: bool = False):
|
| 52 |
"""Run faithfulness evaluation on sample queries."""
|
| 53 |
from sage.services.explanation import Explainer
|
|
|
|
| 65 |
all_explanations = []
|
| 66 |
|
| 67 |
for i, query in enumerate(queries, 1):
|
| 68 |
+
logger.info('[%d/%d] "%s"', i, len(queries), query)
|
| 69 |
+
products = get_candidates(
|
| 70 |
+
query=query,
|
| 71 |
+
k=TOP_K_PRODUCTS,
|
| 72 |
+
min_rating=4.0,
|
| 73 |
+
aggregation=AggregationMethod.MAX,
|
| 74 |
+
)
|
| 75 |
|
| 76 |
if not products:
|
| 77 |
logger.info(" No products found")
|
|
|
|
| 79 |
|
| 80 |
product = products[0]
|
| 81 |
try:
|
| 82 |
+
result = explainer.generate_explanation(
|
| 83 |
+
query, product, max_evidence=MAX_EVIDENCE
|
| 84 |
+
)
|
| 85 |
all_explanations.append(result)
|
| 86 |
logger.info(" %s: %s...", product.product_id, result.explanation[:60])
|
| 87 |
except Exception:
|
|
|
|
| 109 |
|
| 110 |
logger.info(
|
| 111 |
"HHEM (full-explanation): %d/%d grounded, mean=%.3f",
|
| 112 |
+
len(hhem_results) - n_hallucinated,
|
| 113 |
+
len(hhem_results),
|
| 114 |
+
np.mean(hhem_scores),
|
| 115 |
)
|
| 116 |
|
| 117 |
# Multi-metric faithfulness (claim-level as primary)
|
|
|
|
| 119 |
|
| 120 |
from sage.services.faithfulness import compute_multi_metric_faithfulness
|
| 121 |
|
| 122 |
+
multi_items = [(expl.evidence_texts, expl.explanation) for expl in all_explanations]
|
|
|
|
|
|
|
| 123 |
multi_report = compute_multi_metric_faithfulness(multi_items)
|
| 124 |
|
| 125 |
+
logger.info(
|
| 126 |
+
"Quote verification: %d/%d (%.1f%%)",
|
| 127 |
+
multi_report.quotes_found,
|
| 128 |
+
multi_report.quotes_total,
|
| 129 |
multi_report.quote_verification_rate * 100,
|
| 130 |
)
|
| 131 |
+
logger.info(
|
| 132 |
+
"Claim-level HHEM: %.3f avg, %.1f%% pass rate",
|
| 133 |
+
multi_report.claim_level_avg_score,
|
| 134 |
+
multi_report.claim_level_pass_rate * 100,
|
| 135 |
)
|
| 136 |
+
logger.info(
|
| 137 |
+
"Full-explanation: %.3f avg, %.1f%% pass rate (reference only)",
|
| 138 |
+
multi_report.full_explanation_avg_score,
|
| 139 |
+
multi_report.full_explanation_pass_rate * 100,
|
| 140 |
)
|
| 141 |
|
| 142 |
# RAGAS (optional)
|
|
|
|
| 146 |
|
| 147 |
try:
|
| 148 |
from sage.services.faithfulness import FaithfulnessEvaluator
|
| 149 |
+
|
| 150 |
evaluator = FaithfulnessEvaluator()
|
| 151 |
ragas_report = evaluator.evaluate_batch(all_explanations)
|
| 152 |
|
| 153 |
logger.info(
|
| 154 |
"Faithfulness: %.3f +/- %.3f",
|
| 155 |
+
ragas_report.mean_score,
|
| 156 |
+
ragas_report.std_score,
|
| 157 |
)
|
| 158 |
logger.info(
|
| 159 |
+
"Passing: %d/%d", ragas_report.n_passing, ragas_report.n_samples
|
|
|
|
| 160 |
)
|
| 161 |
except Exception:
|
| 162 |
logger.exception("RAGAS evaluation failed")
|
|
|
|
| 232 |
case_id = 0
|
| 233 |
|
| 234 |
for query in ANALYSIS_QUERIES:
|
| 235 |
+
logger.info('Query: "%s"', query)
|
| 236 |
+
products = get_candidates(
|
| 237 |
+
query=query, k=3, min_rating=3.5, aggregation=AggregationMethod.MAX
|
| 238 |
+
)
|
| 239 |
|
| 240 |
if not products:
|
| 241 |
continue
|
|
|
|
| 243 |
for product in products[:2]:
|
| 244 |
try:
|
| 245 |
result = explainer.generate_explanation(query, product, max_evidence=3)
|
| 246 |
+
hhem = detector.check_explanation(
|
| 247 |
+
result.evidence_texts, result.explanation
|
| 248 |
+
)
|
| 249 |
|
| 250 |
case_id += 1
|
| 251 |
+
all_cases.append(
|
| 252 |
+
{
|
| 253 |
+
"case_id": case_id,
|
| 254 |
+
"query": query,
|
| 255 |
+
"product_id": product.product_id,
|
| 256 |
+
"explanation": result.explanation,
|
| 257 |
+
"evidence_texts": result.evidence_texts,
|
| 258 |
+
"hhem_score": hhem.score,
|
| 259 |
+
"is_hallucinated": hhem.is_hallucinated,
|
| 260 |
+
}
|
| 261 |
+
)
|
| 262 |
|
| 263 |
status = "FAIL" if hhem.is_hallucinated else "PASS"
|
| 264 |
+
logger.info(
|
| 265 |
+
" [%s] %.3f - %s...", status, hhem.score, product.product_id[:20]
|
| 266 |
+
)
|
| 267 |
except Exception:
|
| 268 |
logger.exception(" Error processing product")
|
| 269 |
|
|
|
|
| 277 |
|
| 278 |
log_banner(logger, "ANALYSIS SUMMARY")
|
| 279 |
logger.info("Total cases: %d", len(all_cases))
|
| 280 |
+
logger.info(
|
| 281 |
+
"Failures: %d (%.1f%%)", len(failures), len(failures) / len(all_cases) * 100
|
| 282 |
+
)
|
| 283 |
logger.info("Passes: %d", len(passes))
|
| 284 |
|
| 285 |
# Categorize failures
|
|
|
|
| 299 |
# SECTION: Adjusted Faithfulness
|
| 300 |
# ============================================================================
|
| 301 |
|
| 302 |
+
|
| 303 |
def run_adjusted_calculation():
|
| 304 |
"""Calculate adjusted faithfulness with refusals excluded."""
|
| 305 |
from sage.services.faithfulness import is_refusal
|
|
|
|
| 322 |
|
| 323 |
# Classify
|
| 324 |
refusals = [c for c in cases if is_refusal(c["explanation"])]
|
| 325 |
+
non_refusal_passes = [
|
| 326 |
+
c
|
| 327 |
+
for c in cases
|
| 328 |
+
if not is_refusal(c["explanation"]) and not c["is_hallucinated"]
|
| 329 |
+
]
|
| 330 |
+
non_refusal_fails = [
|
| 331 |
+
c for c in cases if not is_refusal(c["explanation"]) and c["is_hallucinated"]
|
| 332 |
+
]
|
| 333 |
|
| 334 |
n_total = len(cases)
|
| 335 |
raw_pass = sum(1 for c in cases if not c["is_hallucinated"])
|
|
|
|
| 341 |
logger.info("Non-refusal fails: %d", len(non_refusal_fails))
|
| 342 |
|
| 343 |
log_section(logger, "Metrics")
|
| 344 |
+
logger.info(
|
| 345 |
+
"Raw pass rate: %d/%d = %.1f%%",
|
| 346 |
+
raw_pass,
|
| 347 |
+
n_total,
|
| 348 |
+
raw_pass / n_total * 100,
|
| 349 |
+
)
|
| 350 |
+
logger.info(
|
| 351 |
+
"Adjusted pass rate: %d/%d = %.1f%%",
|
| 352 |
+
adjusted_pass,
|
| 353 |
+
n_total,
|
| 354 |
+
adjusted_pass / n_total * 100,
|
| 355 |
+
)
|
| 356 |
+
logger.info(
|
| 357 |
+
"Improvement: +%.1f%%", (adjusted_pass / n_total - raw_pass / n_total) * 100
|
| 358 |
+
)
|
| 359 |
|
| 360 |
# Save
|
| 361 |
output = {
|
|
|
|
| 372 |
# Main
|
| 373 |
# ============================================================================
|
| 374 |
|
| 375 |
+
|
| 376 |
def main():
|
| 377 |
parser = argparse.ArgumentParser(description="Run faithfulness evaluation")
|
| 378 |
parser.add_argument("--samples", "-n", type=int, default=DEFAULT_SAMPLES)
|
| 379 |
parser.add_argument("--ragas", action="store_true", help="Include RAGAS evaluation")
|
| 380 |
parser.add_argument("--analyze", action="store_true", help="Run failure analysis")
|
| 381 |
+
parser.add_argument(
|
| 382 |
+
"--adjusted", action="store_true", help="Calculate adjusted metrics"
|
| 383 |
+
)
|
| 384 |
args = parser.parse_args()
|
| 385 |
|
| 386 |
if args.analyze:
|
scripts/human_eval.py
CHANGED
|
@@ -51,6 +51,7 @@ NATURAL_QUERIES_FILE = DATA_DIR / "eval" / "eval_natural_queries.json"
|
|
| 51 |
# Sample Generation
|
| 52 |
# ============================================================================
|
| 53 |
|
|
|
|
| 54 |
def _select_diverse_natural_queries(target: int = 35) -> list[str]:
|
| 55 |
"""Select diverse queries from natural eval dataset, balanced by category."""
|
| 56 |
if not NATURAL_QUERIES_FILE.exists():
|
|
@@ -114,7 +115,8 @@ def generate_samples(force: bool = False):
|
|
| 114 |
logger.error(
|
| 115 |
"%s contains %d rated samples. "
|
| 116 |
"Use --force to overwrite, or run --annotate to continue.",
|
| 117 |
-
SAMPLES_FILE,
|
|
|
|
| 118 |
)
|
| 119 |
sys.exit(1)
|
| 120 |
|
|
@@ -129,7 +131,9 @@ def generate_samples(force: bool = False):
|
|
| 129 |
all_queries = natural + config
|
| 130 |
logger.info(
|
| 131 |
"Queries: %d natural + %d config = %d total",
|
| 132 |
-
len(natural),
|
|
|
|
|
|
|
| 133 |
)
|
| 134 |
|
| 135 |
if len(all_queries) < TARGET_SAMPLES:
|
|
@@ -137,7 +141,8 @@ def generate_samples(force: bool = False):
|
|
| 137 |
"Only %d unique queries available (target: %d). "
|
| 138 |
"Results will lack statistical power. "
|
| 139 |
"Run 'make eval' to build natural query dataset.",
|
| 140 |
-
len(all_queries),
|
|
|
|
| 141 |
)
|
| 142 |
|
| 143 |
# Initialize services
|
|
@@ -146,10 +151,12 @@ def generate_samples(force: bool = False):
|
|
| 146 |
|
| 147 |
samples = []
|
| 148 |
for i, query in enumerate(all_queries, 1):
|
| 149 |
-
logger.info(
|
| 150 |
|
| 151 |
products = get_candidates(
|
| 152 |
-
query=query,
|
|
|
|
|
|
|
| 153 |
aggregation=AggregationMethod.MAX,
|
| 154 |
)
|
| 155 |
if not products:
|
|
@@ -159,10 +166,13 @@ def generate_samples(force: bool = False):
|
|
| 159 |
product = products[0]
|
| 160 |
try:
|
| 161 |
expl = explainer.generate_explanation(
|
| 162 |
-
query,
|
|
|
|
|
|
|
| 163 |
)
|
| 164 |
hhem = detector.check_explanation(
|
| 165 |
-
expl.evidence_texts,
|
|
|
|
| 166 |
)
|
| 167 |
|
| 168 |
sample = {
|
|
@@ -178,7 +188,9 @@ def generate_samples(force: bool = False):
|
|
| 178 |
samples.append(sample)
|
| 179 |
logger.info(
|
| 180 |
" %s (%.1f stars) HHEM=%.3f",
|
| 181 |
-
product.product_id,
|
|
|
|
|
|
|
| 182 |
)
|
| 183 |
except ValueError as exc:
|
| 184 |
logger.info(" Quality gate refusal: %s", exc)
|
|
@@ -197,6 +209,7 @@ def generate_samples(force: bool = False):
|
|
| 197 |
# Interactive Annotation
|
| 198 |
# ============================================================================
|
| 199 |
|
|
|
|
| 200 |
def _load_samples() -> list[dict]:
|
| 201 |
"""Load samples from disk."""
|
| 202 |
if not SAMPLES_FILE.exists():
|
|
@@ -261,7 +274,7 @@ def annotate_samples():
|
|
| 261 |
text = ev["text"]
|
| 262 |
if len(text) > 200:
|
| 263 |
text = text[:200] + "..."
|
| 264 |
-
print(f
|
| 265 |
print()
|
| 266 |
|
| 267 |
# Collect ratings
|
|
@@ -286,6 +299,7 @@ def annotate_samples():
|
|
| 286 |
# Analysis
|
| 287 |
# ============================================================================
|
| 288 |
|
|
|
|
| 289 |
def analyze_results():
|
| 290 |
"""Compute aggregate metrics from rated samples."""
|
| 291 |
samples = _load_samples()
|
|
@@ -306,7 +320,7 @@ def analyze_results():
|
|
| 306 |
n = len(scores)
|
| 307 |
mean = sum(scores) / n
|
| 308 |
variance = sum((x - mean) ** 2 for x in scores) / (n - 1) if n > 1 else 0.0
|
| 309 |
-
std = variance
|
| 310 |
dimensions_results[dim_key] = {
|
| 311 |
"mean": round(mean, 2),
|
| 312 |
"std": round(std, 2),
|
|
@@ -315,7 +329,11 @@ def analyze_results():
|
|
| 315 |
}
|
| 316 |
logger.info(
|
| 317 |
" %-15s mean=%.2f std=%.2f range=[%d, %d]",
|
| 318 |
-
dim_key + ":",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 319 |
)
|
| 320 |
|
| 321 |
# Overall helpfulness: mean of per-sample averages
|
|
@@ -328,15 +346,20 @@ def analyze_results():
|
|
| 328 |
passed = overall >= HELPFULNESS_TARGET
|
| 329 |
|
| 330 |
logger.info("")
|
| 331 |
-
logger.info(
|
| 332 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
|
| 334 |
# HHEM vs Trust correlation (Spearman)
|
| 335 |
correlation = _compute_hhem_trust_correlation(rated)
|
| 336 |
if correlation:
|
| 337 |
logger.info(
|
| 338 |
"HHEM-Trust correlation: r=%.3f, p=%.4f",
|
| 339 |
-
correlation["spearman_r"],
|
|
|
|
| 340 |
)
|
| 341 |
|
| 342 |
# Save results
|
|
@@ -368,6 +391,7 @@ def _compute_hhem_trust_correlation(rated: list[dict]) -> dict | None:
|
|
| 368 |
|
| 369 |
try:
|
| 370 |
from scipy.stats import spearmanr
|
|
|
|
| 371 |
r, p = spearmanr(hhem_scores, trust_scores)
|
| 372 |
return {"spearman_r": round(float(r), 4), "p_value": round(float(p), 4)}
|
| 373 |
except ImportError:
|
|
@@ -399,13 +423,13 @@ def _manual_spearman(x: list[float], y: list[float]) -> dict | None:
|
|
| 399 |
ry = _rank(y)
|
| 400 |
|
| 401 |
d_sq = sum((rx[i] - ry[i]) ** 2 for i in range(n))
|
| 402 |
-
rho = 1 - (6 * d_sq) / (n * (n
|
| 403 |
|
| 404 |
# Approximate p-value via t-distribution (large sample)
|
| 405 |
if abs(rho) >= 1.0:
|
| 406 |
p = 0.0
|
| 407 |
else:
|
| 408 |
-
t = rho * math.sqrt((n - 2) / (1 - rho
|
| 409 |
# Two-tailed p-value approximation
|
| 410 |
p = 2 * (1 - _t_cdf_approx(abs(t), n - 2))
|
| 411 |
|
|
@@ -427,6 +451,7 @@ def _t_cdf_approx(t: float, df: int) -> float:
|
|
| 427 |
# Status
|
| 428 |
# ============================================================================
|
| 429 |
|
|
|
|
| 430 |
def show_status():
|
| 431 |
"""Show annotation progress."""
|
| 432 |
if not SAMPLES_FILE.exists():
|
|
@@ -450,21 +475,27 @@ def show_status():
|
|
| 450 |
# Main
|
| 451 |
# ============================================================================
|
| 452 |
|
|
|
|
| 453 |
def main():
|
| 454 |
parser = argparse.ArgumentParser(
|
| 455 |
description="Human evaluation of recommendation explanations",
|
| 456 |
)
|
| 457 |
group = parser.add_mutually_exclusive_group(required=True)
|
| 458 |
-
group.add_argument(
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
group.add_argument(
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 468 |
args = parser.parse_args()
|
| 469 |
|
| 470 |
if args.force and not args.generate:
|
|
|
|
| 51 |
# Sample Generation
|
| 52 |
# ============================================================================
|
| 53 |
|
| 54 |
+
|
| 55 |
def _select_diverse_natural_queries(target: int = 35) -> list[str]:
|
| 56 |
"""Select diverse queries from natural eval dataset, balanced by category."""
|
| 57 |
if not NATURAL_QUERIES_FILE.exists():
|
|
|
|
| 115 |
logger.error(
|
| 116 |
"%s contains %d rated samples. "
|
| 117 |
"Use --force to overwrite, or run --annotate to continue.",
|
| 118 |
+
SAMPLES_FILE,
|
| 119 |
+
rated,
|
| 120 |
)
|
| 121 |
sys.exit(1)
|
| 122 |
|
|
|
|
| 131 |
all_queries = natural + config
|
| 132 |
logger.info(
|
| 133 |
"Queries: %d natural + %d config = %d total",
|
| 134 |
+
len(natural),
|
| 135 |
+
len(config),
|
| 136 |
+
len(all_queries),
|
| 137 |
)
|
| 138 |
|
| 139 |
if len(all_queries) < TARGET_SAMPLES:
|
|
|
|
| 141 |
"Only %d unique queries available (target: %d). "
|
| 142 |
"Results will lack statistical power. "
|
| 143 |
"Run 'make eval' to build natural query dataset.",
|
| 144 |
+
len(all_queries),
|
| 145 |
+
TARGET_SAMPLES,
|
| 146 |
)
|
| 147 |
|
| 148 |
# Initialize services
|
|
|
|
| 151 |
|
| 152 |
samples = []
|
| 153 |
for i, query in enumerate(all_queries, 1):
|
| 154 |
+
logger.info('[%d/%d] "%s"', i, len(all_queries), query)
|
| 155 |
|
| 156 |
products = get_candidates(
|
| 157 |
+
query=query,
|
| 158 |
+
k=1,
|
| 159 |
+
min_rating=4.0,
|
| 160 |
aggregation=AggregationMethod.MAX,
|
| 161 |
)
|
| 162 |
if not products:
|
|
|
|
| 166 |
product = products[0]
|
| 167 |
try:
|
| 168 |
expl = explainer.generate_explanation(
|
| 169 |
+
query,
|
| 170 |
+
product,
|
| 171 |
+
max_evidence=MAX_EVIDENCE,
|
| 172 |
)
|
| 173 |
hhem = detector.check_explanation(
|
| 174 |
+
expl.evidence_texts,
|
| 175 |
+
expl.explanation,
|
| 176 |
)
|
| 177 |
|
| 178 |
sample = {
|
|
|
|
| 188 |
samples.append(sample)
|
| 189 |
logger.info(
|
| 190 |
" %s (%.1f stars) HHEM=%.3f",
|
| 191 |
+
product.product_id,
|
| 192 |
+
product.avg_rating,
|
| 193 |
+
hhem.score,
|
| 194 |
)
|
| 195 |
except ValueError as exc:
|
| 196 |
logger.info(" Quality gate refusal: %s", exc)
|
|
|
|
| 209 |
# Interactive Annotation
|
| 210 |
# ============================================================================
|
| 211 |
|
| 212 |
+
|
| 213 |
def _load_samples() -> list[dict]:
|
| 214 |
"""Load samples from disk."""
|
| 215 |
if not SAMPLES_FILE.exists():
|
|
|
|
| 274 |
text = ev["text"]
|
| 275 |
if len(text) > 200:
|
| 276 |
text = text[:200] + "..."
|
| 277 |
+
print(f' [{ev["id"]}]: "{text}"')
|
| 278 |
print()
|
| 279 |
|
| 280 |
# Collect ratings
|
|
|
|
| 299 |
# Analysis
|
| 300 |
# ============================================================================
|
| 301 |
|
| 302 |
+
|
| 303 |
def analyze_results():
|
| 304 |
"""Compute aggregate metrics from rated samples."""
|
| 305 |
samples = _load_samples()
|
|
|
|
| 320 |
n = len(scores)
|
| 321 |
mean = sum(scores) / n
|
| 322 |
variance = sum((x - mean) ** 2 for x in scores) / (n - 1) if n > 1 else 0.0
|
| 323 |
+
std = variance**0.5
|
| 324 |
dimensions_results[dim_key] = {
|
| 325 |
"mean": round(mean, 2),
|
| 326 |
"std": round(std, 2),
|
|
|
|
| 329 |
}
|
| 330 |
logger.info(
|
| 331 |
" %-15s mean=%.2f std=%.2f range=[%d, %d]",
|
| 332 |
+
dim_key + ":",
|
| 333 |
+
mean,
|
| 334 |
+
std,
|
| 335 |
+
min(scores),
|
| 336 |
+
max(scores),
|
| 337 |
)
|
| 338 |
|
| 339 |
# Overall helpfulness: mean of per-sample averages
|
|
|
|
| 346 |
passed = overall >= HELPFULNESS_TARGET
|
| 347 |
|
| 348 |
logger.info("")
|
| 349 |
+
logger.info(
|
| 350 |
+
"Overall helpfulness: %.2f (target: %.1f) [%s]",
|
| 351 |
+
overall,
|
| 352 |
+
HELPFULNESS_TARGET,
|
| 353 |
+
"PASS" if passed else "FAIL",
|
| 354 |
+
)
|
| 355 |
|
| 356 |
# HHEM vs Trust correlation (Spearman)
|
| 357 |
correlation = _compute_hhem_trust_correlation(rated)
|
| 358 |
if correlation:
|
| 359 |
logger.info(
|
| 360 |
"HHEM-Trust correlation: r=%.3f, p=%.4f",
|
| 361 |
+
correlation["spearman_r"],
|
| 362 |
+
correlation["p_value"],
|
| 363 |
)
|
| 364 |
|
| 365 |
# Save results
|
|
|
|
| 391 |
|
| 392 |
try:
|
| 393 |
from scipy.stats import spearmanr
|
| 394 |
+
|
| 395 |
r, p = spearmanr(hhem_scores, trust_scores)
|
| 396 |
return {"spearman_r": round(float(r), 4), "p_value": round(float(p), 4)}
|
| 397 |
except ImportError:
|
|
|
|
| 423 |
ry = _rank(y)
|
| 424 |
|
| 425 |
d_sq = sum((rx[i] - ry[i]) ** 2 for i in range(n))
|
| 426 |
+
rho = 1 - (6 * d_sq) / (n * (n**2 - 1))
|
| 427 |
|
| 428 |
# Approximate p-value via t-distribution (large sample)
|
| 429 |
if abs(rho) >= 1.0:
|
| 430 |
p = 0.0
|
| 431 |
else:
|
| 432 |
+
t = rho * math.sqrt((n - 2) / (1 - rho**2))
|
| 433 |
# Two-tailed p-value approximation
|
| 434 |
p = 2 * (1 - _t_cdf_approx(abs(t), n - 2))
|
| 435 |
|
|
|
|
| 451 |
# Status
|
| 452 |
# ============================================================================
|
| 453 |
|
| 454 |
+
|
| 455 |
def show_status():
|
| 456 |
"""Show annotation progress."""
|
| 457 |
if not SAMPLES_FILE.exists():
|
|
|
|
| 475 |
# Main
|
| 476 |
# ============================================================================
|
| 477 |
|
| 478 |
+
|
| 479 |
def main():
|
| 480 |
parser = argparse.ArgumentParser(
|
| 481 |
description="Human evaluation of recommendation explanations",
|
| 482 |
)
|
| 483 |
group = parser.add_mutually_exclusive_group(required=True)
|
| 484 |
+
group.add_argument(
|
| 485 |
+
"--generate", action="store_true", help="Generate recommendation samples"
|
| 486 |
+
)
|
| 487 |
+
group.add_argument(
|
| 488 |
+
"--annotate", action="store_true", help="Rate samples interactively (resumable)"
|
| 489 |
+
)
|
| 490 |
+
group.add_argument(
|
| 491 |
+
"--analyze", action="store_true", help="Compute aggregate results from ratings"
|
| 492 |
+
)
|
| 493 |
+
group.add_argument("--status", action="store_true", help="Show annotation progress")
|
| 494 |
+
parser.add_argument(
|
| 495 |
+
"--force",
|
| 496 |
+
action="store_true",
|
| 497 |
+
help="Overwrite existing rated samples (with --generate)",
|
| 498 |
+
)
|
| 499 |
args = parser.parse_args()
|
| 500 |
|
| 501 |
if args.force and not args.generate:
|
scripts/pipeline.py
CHANGED
|
@@ -54,6 +54,7 @@ logger = get_logger(__name__)
|
|
| 54 |
# TOKENIZER VALIDATION (--validate-tokenizer)
|
| 55 |
# ============================================================================
|
| 56 |
|
|
|
|
| 57 |
def run_tokenizer_validation():
|
| 58 |
"""Validate the chars/token ratio assumption used in chunker.py."""
|
| 59 |
from transformers import AutoTokenizer
|
|
@@ -90,10 +91,16 @@ def run_tokenizer_validation():
|
|
| 90 |
# CHUNKING QUALITY TEST (--test-chunking)
|
| 91 |
# ============================================================================
|
| 92 |
|
|
|
|
| 93 |
def run_chunking_test():
|
| 94 |
"""Test chunking quality on long reviews."""
|
| 95 |
import pandas as pd
|
| 96 |
-
from sage.core.chunking import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
log_banner(logger, "CHUNKING QUALITY TEST", width=70)
|
| 99 |
|
|
@@ -113,28 +120,36 @@ def run_chunking_test():
|
|
| 113 |
chunks = chunk_text(text, embedder=embedder)
|
| 114 |
sentences = split_sentences(text)
|
| 115 |
|
| 116 |
-
results.append(
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
|
|
|
|
|
|
| 122 |
|
| 123 |
if idx < 5:
|
| 124 |
logger.info(
|
| 125 |
"Review %d [%d*] (%d tok) -> %d chunks",
|
| 126 |
-
idx + 1,
|
|
|
|
|
|
|
|
|
|
| 127 |
)
|
| 128 |
|
| 129 |
results_df = pd.DataFrame(results)
|
| 130 |
log_section(logger, f"Summary ({len(results_df)} reviews)")
|
| 131 |
logger.info(
|
| 132 |
"Chunks per review: %.2f (median: %.0f)",
|
| 133 |
-
results_df["chunks"].mean(),
|
|
|
|
| 134 |
)
|
| 135 |
logger.info("Avg tokens/chunk: %.0f", results_df["avg_chunk_tokens"].mean())
|
| 136 |
|
| 137 |
-
expansion = (
|
|
|
|
|
|
|
| 138 |
logger.info("Expansion ratio: %.2fx", expansion)
|
| 139 |
|
| 140 |
|
|
@@ -142,6 +157,7 @@ def run_chunking_test():
|
|
| 142 |
# MAIN PIPELINE
|
| 143 |
# ============================================================================
|
| 144 |
|
|
|
|
| 145 |
def run_pipeline(subset_size: int, force: bool):
|
| 146 |
"""Run the full data pipeline: load, chunk, embed, upload."""
|
| 147 |
logger.info("Config", extra={"subset_size": subset_size, "force": force})
|
|
@@ -174,7 +190,8 @@ def run_pipeline(subset_size: int, force: bool):
|
|
| 174 |
needs_chunking = (df["estimated_tokens"] > 200).sum()
|
| 175 |
logger.info(
|
| 176 |
"Reviews needing chunking (>200 tokens): %d (%.1f%%)",
|
| 177 |
-
needs_chunking,
|
|
|
|
| 178 |
)
|
| 179 |
|
| 180 |
# Prepare reviews for chunking
|
|
@@ -192,7 +209,9 @@ def run_pipeline(subset_size: int, force: bool):
|
|
| 192 |
chunks = chunk_reviews_batch(reviews_for_chunking, embedder=embedder)
|
| 193 |
logger.info(
|
| 194 |
"Created %d chunks from %d reviews (expansion: %.2fx)",
|
| 195 |
-
len(chunks),
|
|
|
|
|
|
|
| 196 |
)
|
| 197 |
|
| 198 |
# Generate embeddings
|
|
@@ -200,7 +219,9 @@ def run_pipeline(subset_size: int, force: bool):
|
|
| 200 |
cache_path = DATA_DIR / f"embeddings_{len(chunks)}.npy"
|
| 201 |
|
| 202 |
logger.info("Embedding %d chunks...", len(chunk_texts))
|
| 203 |
-
embeddings = embedder.embed_passages(
|
|
|
|
|
|
|
| 204 |
logger.info("Embeddings shape: %s", embeddings.shape)
|
| 205 |
|
| 206 |
# Embedding technical validation
|
|
@@ -243,12 +264,21 @@ def run_pipeline(subset_size: int, force: bool):
|
|
| 243 |
sim_out = float(np.dot(emb_query, emb_out))
|
| 244 |
|
| 245 |
logger.info("Query: '%s'", test_query)
|
| 246 |
-
logger.info(
|
| 247 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 248 |
logger.info(" Out-of-domain: '%s' = %.3f", out_of_domain, sim_out)
|
| 249 |
|
| 250 |
if sim_in_similar > sim_in_different > sim_out:
|
| 251 |
-
logger.info(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
else:
|
| 253 |
logger.warning("Unexpected ranking")
|
| 254 |
|
|
@@ -307,10 +337,23 @@ def run_pipeline(subset_size: int, force: bool):
|
|
| 307 |
|
| 308 |
def main():
|
| 309 |
parser = argparse.ArgumentParser(description="Run the data pipeline")
|
| 310 |
-
parser.add_argument(
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
parser.add_argument(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 314 |
args = parser.parse_args()
|
| 315 |
|
| 316 |
if args.validate_tokenizer:
|
|
|
|
| 54 |
# TOKENIZER VALIDATION (--validate-tokenizer)
|
| 55 |
# ============================================================================
|
| 56 |
|
| 57 |
+
|
| 58 |
def run_tokenizer_validation():
|
| 59 |
"""Validate the chars/token ratio assumption used in chunker.py."""
|
| 60 |
from transformers import AutoTokenizer
|
|
|
|
| 91 |
# CHUNKING QUALITY TEST (--test-chunking)
|
| 92 |
# ============================================================================
|
| 93 |
|
| 94 |
+
|
| 95 |
def run_chunking_test():
|
| 96 |
"""Test chunking quality on long reviews."""
|
| 97 |
import pandas as pd
|
| 98 |
+
from sage.core.chunking import (
|
| 99 |
+
chunk_text,
|
| 100 |
+
split_sentences,
|
| 101 |
+
estimate_tokens,
|
| 102 |
+
NO_CHUNK_THRESHOLD,
|
| 103 |
+
)
|
| 104 |
|
| 105 |
log_banner(logger, "CHUNKING QUALITY TEST", width=70)
|
| 106 |
|
|
|
|
| 120 |
chunks = chunk_text(text, embedder=embedder)
|
| 121 |
sentences = split_sentences(text)
|
| 122 |
|
| 123 |
+
results.append(
|
| 124 |
+
{
|
| 125 |
+
"tokens": tokens,
|
| 126 |
+
"sentences": len(sentences),
|
| 127 |
+
"chunks": len(chunks),
|
| 128 |
+
"avg_chunk_tokens": np.mean([estimate_tokens(c) for c in chunks]),
|
| 129 |
+
}
|
| 130 |
+
)
|
| 131 |
|
| 132 |
if idx < 5:
|
| 133 |
logger.info(
|
| 134 |
"Review %d [%d*] (%d tok) -> %d chunks",
|
| 135 |
+
idx + 1,
|
| 136 |
+
rating,
|
| 137 |
+
tokens,
|
| 138 |
+
len(chunks),
|
| 139 |
)
|
| 140 |
|
| 141 |
results_df = pd.DataFrame(results)
|
| 142 |
log_section(logger, f"Summary ({len(results_df)} reviews)")
|
| 143 |
logger.info(
|
| 144 |
"Chunks per review: %.2f (median: %.0f)",
|
| 145 |
+
results_df["chunks"].mean(),
|
| 146 |
+
results_df["chunks"].median(),
|
| 147 |
)
|
| 148 |
logger.info("Avg tokens/chunk: %.0f", results_df["avg_chunk_tokens"].mean())
|
| 149 |
|
| 150 |
+
expansion = (
|
| 151 |
+
results_df["chunks"] * results_df["avg_chunk_tokens"]
|
| 152 |
+
).sum() / results_df["tokens"].sum()
|
| 153 |
logger.info("Expansion ratio: %.2fx", expansion)
|
| 154 |
|
| 155 |
|
|
|
|
| 157 |
# MAIN PIPELINE
|
| 158 |
# ============================================================================
|
| 159 |
|
| 160 |
+
|
| 161 |
def run_pipeline(subset_size: int, force: bool):
|
| 162 |
"""Run the full data pipeline: load, chunk, embed, upload."""
|
| 163 |
logger.info("Config", extra={"subset_size": subset_size, "force": force})
|
|
|
|
| 190 |
needs_chunking = (df["estimated_tokens"] > 200).sum()
|
| 191 |
logger.info(
|
| 192 |
"Reviews needing chunking (>200 tokens): %d (%.1f%%)",
|
| 193 |
+
needs_chunking,
|
| 194 |
+
needs_chunking / len(df) * 100,
|
| 195 |
)
|
| 196 |
|
| 197 |
# Prepare reviews for chunking
|
|
|
|
| 209 |
chunks = chunk_reviews_batch(reviews_for_chunking, embedder=embedder)
|
| 210 |
logger.info(
|
| 211 |
"Created %d chunks from %d reviews (expansion: %.2fx)",
|
| 212 |
+
len(chunks),
|
| 213 |
+
len(reviews_for_chunking),
|
| 214 |
+
len(chunks) / len(reviews_for_chunking),
|
| 215 |
)
|
| 216 |
|
| 217 |
# Generate embeddings
|
|
|
|
| 219 |
cache_path = DATA_DIR / f"embeddings_{len(chunks)}.npy"
|
| 220 |
|
| 221 |
logger.info("Embedding %d chunks...", len(chunk_texts))
|
| 222 |
+
embeddings = embedder.embed_passages(
|
| 223 |
+
chunk_texts, cache_path=cache_path, force=force
|
| 224 |
+
)
|
| 225 |
logger.info("Embeddings shape: %s", embeddings.shape)
|
| 226 |
|
| 227 |
# Embedding technical validation
|
|
|
|
| 264 |
sim_out = float(np.dot(emb_query, emb_out))
|
| 265 |
|
| 266 |
logger.info("Query: '%s'", test_query)
|
| 267 |
+
logger.info(
|
| 268 |
+
" In-domain (same topic): '%s' = %.3f", in_domain_similar, sim_in_similar
|
| 269 |
+
)
|
| 270 |
+
logger.info(
|
| 271 |
+
" In-domain (diff topic): '%s' = %.3f", in_domain_different, sim_in_different
|
| 272 |
+
)
|
| 273 |
logger.info(" Out-of-domain: '%s' = %.3f", out_of_domain, sim_out)
|
| 274 |
|
| 275 |
if sim_in_similar > sim_in_different > sim_out:
|
| 276 |
+
logger.info(
|
| 277 |
+
"Ranking correct: %.3f > %.3f > %.3f",
|
| 278 |
+
sim_in_similar,
|
| 279 |
+
sim_in_different,
|
| 280 |
+
sim_out,
|
| 281 |
+
)
|
| 282 |
else:
|
| 283 |
logger.warning("Unexpected ranking")
|
| 284 |
|
|
|
|
| 337 |
|
| 338 |
def main():
|
| 339 |
parser = argparse.ArgumentParser(description="Run the data pipeline")
|
| 340 |
+
parser.add_argument(
|
| 341 |
+
"--force", action="store_true", help="Force recreate collection"
|
| 342 |
+
)
|
| 343 |
+
parser.add_argument(
|
| 344 |
+
"--subset-size",
|
| 345 |
+
type=int,
|
| 346 |
+
default=DEV_SUBSET_SIZE,
|
| 347 |
+
help="Number of reviews to load initially",
|
| 348 |
+
)
|
| 349 |
+
parser.add_argument(
|
| 350 |
+
"--validate-tokenizer",
|
| 351 |
+
action="store_true",
|
| 352 |
+
help="Run tokenizer validation only",
|
| 353 |
+
)
|
| 354 |
+
parser.add_argument(
|
| 355 |
+
"--test-chunking", action="store_true", help="Run chunking quality test only"
|
| 356 |
+
)
|
| 357 |
args = parser.parse_args()
|
| 358 |
|
| 359 |
if args.validate_tokenizer:
|
scripts/sanity_checks.py
CHANGED
|
@@ -42,6 +42,7 @@ RESULTS_DIR.mkdir(exist_ok=True)
|
|
| 42 |
# SECTION: Spot-Check
|
| 43 |
# ============================================================================
|
| 44 |
|
|
|
|
| 45 |
def run_spot_check():
|
| 46 |
"""Manual spot-check of explanations vs evidence."""
|
| 47 |
from sage.services.explanation import Explainer
|
|
@@ -56,18 +57,24 @@ def run_spot_check():
|
|
| 56 |
queries = EVALUATION_QUERIES[:5]
|
| 57 |
|
| 58 |
for query in queries:
|
| 59 |
-
products = get_candidates(
|
|
|
|
|
|
|
| 60 |
|
| 61 |
for product in products[:2]:
|
| 62 |
result = explainer.generate_explanation(query, product, max_evidence=3)
|
| 63 |
hhem = detector.check_explanation(result.evidence_texts, result.explanation)
|
| 64 |
|
| 65 |
log_section(logger, f"SAMPLE {len(results) + 1}")
|
| 66 |
-
logger.info(
|
| 67 |
-
logger.info(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
logger.info("EVIDENCE:")
|
| 69 |
for ev in result.evidence_texts[:2]:
|
| 70 |
-
logger.info(
|
| 71 |
logger.info("EXPLANATION:")
|
| 72 |
logger.info(" %s", result.explanation)
|
| 73 |
|
|
@@ -86,6 +93,7 @@ def run_spot_check():
|
|
| 86 |
# SECTION: Adversarial Tests
|
| 87 |
# ============================================================================
|
| 88 |
|
|
|
|
| 89 |
def run_adversarial_tests():
|
| 90 |
"""Test with contradictory evidence."""
|
| 91 |
from sage.services.explanation import Explainer
|
|
@@ -118,10 +126,28 @@ def run_adversarial_tests():
|
|
| 118 |
log_section(logger, case["name"])
|
| 119 |
|
| 120 |
chunks = [
|
| 121 |
-
RetrievedChunk(
|
| 122 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
]
|
| 124 |
-
product = ProductScore(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
|
| 126 |
result = explainer.generate_explanation(case["query"], product, max_evidence=2)
|
| 127 |
hhem = detector.check_explanation(result.evidence_texts, result.explanation)
|
|
@@ -142,6 +168,7 @@ def run_adversarial_tests():
|
|
| 142 |
# SECTION: Empty Context Tests
|
| 143 |
# ============================================================================
|
| 144 |
|
|
|
|
| 145 |
def run_empty_context_tests():
|
| 146 |
"""Test graceful refusal with irrelevant evidence."""
|
| 147 |
from sage.services.explanation import Explainer
|
|
@@ -153,19 +180,46 @@ def run_empty_context_tests():
|
|
| 153 |
detector = HallucinationDetector()
|
| 154 |
|
| 155 |
cases = [
|
| 156 |
-
{
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
{"name": "Minimal", "query": "high-quality camera lens", "evidence": "OK."},
|
| 158 |
-
{
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
]
|
| 160 |
|
| 161 |
-
refusal_words = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
results = []
|
| 163 |
|
| 164 |
for case in cases:
|
| 165 |
log_section(logger, case["name"])
|
| 166 |
|
| 167 |
-
chunk = RetrievedChunk(
|
| 168 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
result = explainer.generate_explanation(case["query"], product, max_evidence=1)
|
| 171 |
_hhem = detector.check_explanation(result.evidence_texts, result.explanation)
|
|
@@ -185,6 +239,7 @@ def run_empty_context_tests():
|
|
| 185 |
# SECTION: Calibration Check
|
| 186 |
# ============================================================================
|
| 187 |
|
|
|
|
| 188 |
@dataclass
|
| 189 |
class CalibrationSample:
|
| 190 |
query: str
|
|
@@ -210,21 +265,27 @@ def run_calibration_check():
|
|
| 210 |
|
| 211 |
logger.info("Generating samples...")
|
| 212 |
for query in queries:
|
| 213 |
-
products = get_candidates(
|
|
|
|
|
|
|
| 214 |
|
| 215 |
for product in products[:2]:
|
| 216 |
try:
|
| 217 |
result = explainer.generate_explanation(query, product, max_evidence=3)
|
| 218 |
-
hhem = detector.check_explanation(
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
except Exception:
|
| 229 |
logger.debug("Error generating sample", exc_info=True)
|
| 230 |
|
|
@@ -251,24 +312,28 @@ def run_calibration_check():
|
|
| 251 |
# Stratified analysis
|
| 252 |
sorted_samples = sorted(samples, key=lambda s: s.retrieval_score)
|
| 253 |
n = len(sorted_samples)
|
| 254 |
-
low = sorted_samples[:n//3]
|
| 255 |
-
mid = sorted_samples[n//3:2*n//3]
|
| 256 |
-
high = sorted_samples[2*n//3:]
|
| 257 |
|
| 258 |
log_section(logger, "HHEM by Confidence Tier")
|
| 259 |
logger.info(" LOW (n=%2d): %.3f", len(low), np.mean([s.hhem_score for s in low]))
|
| 260 |
logger.info(" MED (n=%2d): %.3f", len(mid), np.mean([s.hhem_score for s in mid]))
|
| 261 |
-
logger.info(
|
|
|
|
|
|
|
| 262 |
|
| 263 |
|
| 264 |
# ============================================================================
|
| 265 |
# Main
|
| 266 |
# ============================================================================
|
| 267 |
|
|
|
|
| 268 |
def main():
|
| 269 |
parser = argparse.ArgumentParser(description="Run pipeline sanity checks")
|
| 270 |
parser.add_argument(
|
| 271 |
-
"--section",
|
|
|
|
| 272 |
choices=["all", "spot", "adversarial", "empty", "calibration"],
|
| 273 |
default="all",
|
| 274 |
help="Which section to run",
|
|
|
|
| 42 |
# SECTION: Spot-Check
|
| 43 |
# ============================================================================
|
| 44 |
|
| 45 |
+
|
| 46 |
def run_spot_check():
|
| 47 |
"""Manual spot-check of explanations vs evidence."""
|
| 48 |
from sage.services.explanation import Explainer
|
|
|
|
| 57 |
queries = EVALUATION_QUERIES[:5]
|
| 58 |
|
| 59 |
for query in queries:
|
| 60 |
+
products = get_candidates(
|
| 61 |
+
query=query, k=2, min_rating=4.0, aggregation=AggregationMethod.MAX
|
| 62 |
+
)
|
| 63 |
|
| 64 |
for product in products[:2]:
|
| 65 |
result = explainer.generate_explanation(query, product, max_evidence=3)
|
| 66 |
hhem = detector.check_explanation(result.evidence_texts, result.explanation)
|
| 67 |
|
| 68 |
log_section(logger, f"SAMPLE {len(results) + 1}")
|
| 69 |
+
logger.info('Query: "%s"', query)
|
| 70 |
+
logger.info(
|
| 71 |
+
"HHEM: %.3f (%s)",
|
| 72 |
+
hhem.score,
|
| 73 |
+
"PASS" if not hhem.is_hallucinated else "FAIL",
|
| 74 |
+
)
|
| 75 |
logger.info("EVIDENCE:")
|
| 76 |
for ev in result.evidence_texts[:2]:
|
| 77 |
+
logger.info(' "%s..."', ev[:100])
|
| 78 |
logger.info("EXPLANATION:")
|
| 79 |
logger.info(" %s", result.explanation)
|
| 80 |
|
|
|
|
| 93 |
# SECTION: Adversarial Tests
|
| 94 |
# ============================================================================
|
| 95 |
|
| 96 |
+
|
| 97 |
def run_adversarial_tests():
|
| 98 |
"""Test with contradictory evidence."""
|
| 99 |
from sage.services.explanation import Explainer
|
|
|
|
| 126 |
log_section(logger, case["name"])
|
| 127 |
|
| 128 |
chunks = [
|
| 129 |
+
RetrievedChunk(
|
| 130 |
+
text=case["positive"],
|
| 131 |
+
score=0.9,
|
| 132 |
+
product_id="TEST",
|
| 133 |
+
rating=5.0,
|
| 134 |
+
review_id="pos",
|
| 135 |
+
),
|
| 136 |
+
RetrievedChunk(
|
| 137 |
+
text=case["negative"],
|
| 138 |
+
score=0.85,
|
| 139 |
+
product_id="TEST",
|
| 140 |
+
rating=1.0,
|
| 141 |
+
review_id="neg",
|
| 142 |
+
),
|
| 143 |
]
|
| 144 |
+
product = ProductScore(
|
| 145 |
+
product_id="TEST",
|
| 146 |
+
score=0.85,
|
| 147 |
+
chunk_count=2,
|
| 148 |
+
avg_rating=3.0,
|
| 149 |
+
evidence=chunks,
|
| 150 |
+
)
|
| 151 |
|
| 152 |
result = explainer.generate_explanation(case["query"], product, max_evidence=2)
|
| 153 |
hhem = detector.check_explanation(result.evidence_texts, result.explanation)
|
|
|
|
| 168 |
# SECTION: Empty Context Tests
|
| 169 |
# ============================================================================
|
| 170 |
|
| 171 |
+
|
| 172 |
def run_empty_context_tests():
|
| 173 |
"""Test graceful refusal with irrelevant evidence."""
|
| 174 |
from sage.services.explanation import Explainer
|
|
|
|
| 180 |
detector = HallucinationDetector()
|
| 181 |
|
| 182 |
cases = [
|
| 183 |
+
{
|
| 184 |
+
"name": "Irrelevant",
|
| 185 |
+
"query": "quantum computing textbook",
|
| 186 |
+
"evidence": "Great USB cable.",
|
| 187 |
+
},
|
| 188 |
{"name": "Minimal", "query": "high-quality camera lens", "evidence": "OK."},
|
| 189 |
+
{
|
| 190 |
+
"name": "Foreign",
|
| 191 |
+
"query": "wireless mouse",
|
| 192 |
+
"evidence": "Muy bueno el producto.",
|
| 193 |
+
},
|
| 194 |
]
|
| 195 |
|
| 196 |
+
refusal_words = [
|
| 197 |
+
"cannot",
|
| 198 |
+
"can't",
|
| 199 |
+
"unable",
|
| 200 |
+
"no evidence",
|
| 201 |
+
"insufficient",
|
| 202 |
+
"limited",
|
| 203 |
+
]
|
| 204 |
results = []
|
| 205 |
|
| 206 |
for case in cases:
|
| 207 |
log_section(logger, case["name"])
|
| 208 |
|
| 209 |
+
chunk = RetrievedChunk(
|
| 210 |
+
text=case["evidence"],
|
| 211 |
+
score=0.3,
|
| 212 |
+
product_id="TEST",
|
| 213 |
+
rating=3.0,
|
| 214 |
+
review_id="r1",
|
| 215 |
+
)
|
| 216 |
+
product = ProductScore(
|
| 217 |
+
product_id="TEST",
|
| 218 |
+
score=0.3,
|
| 219 |
+
chunk_count=1,
|
| 220 |
+
avg_rating=3.0,
|
| 221 |
+
evidence=[chunk],
|
| 222 |
+
)
|
| 223 |
|
| 224 |
result = explainer.generate_explanation(case["query"], product, max_evidence=1)
|
| 225 |
_hhem = detector.check_explanation(result.evidence_texts, result.explanation)
|
|
|
|
| 239 |
# SECTION: Calibration Check
|
| 240 |
# ============================================================================
|
| 241 |
|
| 242 |
+
|
| 243 |
@dataclass
|
| 244 |
class CalibrationSample:
|
| 245 |
query: str
|
|
|
|
| 265 |
|
| 266 |
logger.info("Generating samples...")
|
| 267 |
for query in queries:
|
| 268 |
+
products = get_candidates(
|
| 269 |
+
query=query, k=5, min_rating=3.0, aggregation=AggregationMethod.MAX
|
| 270 |
+
)
|
| 271 |
|
| 272 |
for product in products[:2]:
|
| 273 |
try:
|
| 274 |
result = explainer.generate_explanation(query, product, max_evidence=3)
|
| 275 |
+
hhem = detector.check_explanation(
|
| 276 |
+
result.evidence_texts, result.explanation
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
samples.append(
|
| 280 |
+
CalibrationSample(
|
| 281 |
+
query=query,
|
| 282 |
+
product_id=product.product_id,
|
| 283 |
+
retrieval_score=product.score,
|
| 284 |
+
evidence_count=product.chunk_count,
|
| 285 |
+
avg_rating=product.avg_rating,
|
| 286 |
+
hhem_score=hhem.score,
|
| 287 |
+
)
|
| 288 |
+
)
|
| 289 |
except Exception:
|
| 290 |
logger.debug("Error generating sample", exc_info=True)
|
| 291 |
|
|
|
|
| 312 |
# Stratified analysis
|
| 313 |
sorted_samples = sorted(samples, key=lambda s: s.retrieval_score)
|
| 314 |
n = len(sorted_samples)
|
| 315 |
+
low = sorted_samples[: n // 3]
|
| 316 |
+
mid = sorted_samples[n // 3 : 2 * n // 3]
|
| 317 |
+
high = sorted_samples[2 * n // 3 :]
|
| 318 |
|
| 319 |
log_section(logger, "HHEM by Confidence Tier")
|
| 320 |
logger.info(" LOW (n=%2d): %.3f", len(low), np.mean([s.hhem_score for s in low]))
|
| 321 |
logger.info(" MED (n=%2d): %.3f", len(mid), np.mean([s.hhem_score for s in mid]))
|
| 322 |
+
logger.info(
|
| 323 |
+
" HIGH (n=%2d): %.3f", len(high), np.mean([s.hhem_score for s in high])
|
| 324 |
+
)
|
| 325 |
|
| 326 |
|
| 327 |
# ============================================================================
|
| 328 |
# Main
|
| 329 |
# ============================================================================
|
| 330 |
|
| 331 |
+
|
| 332 |
def main():
|
| 333 |
parser = argparse.ArgumentParser(description="Run pipeline sanity checks")
|
| 334 |
parser.add_argument(
|
| 335 |
+
"--section",
|
| 336 |
+
"-s",
|
| 337 |
choices=["all", "spot", "adversarial", "empty", "calibration"],
|
| 338 |
default="all",
|
| 339 |
help="Which section to run",
|
scripts/summary.py
CHANGED
|
@@ -16,7 +16,12 @@ import json
|
|
| 16 |
import sys
|
| 17 |
from pathlib import Path
|
| 18 |
|
| 19 |
-
from sage.config import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
WIDTH = 60
|
| 22 |
SEP = "=" * WIDTH
|
|
@@ -82,14 +87,20 @@ def main():
|
|
| 82 |
quotes_total = mm.get("quotes_total", 0)
|
| 83 |
|
| 84 |
if claim_pass is not None:
|
| 85 |
-
print(
|
| 86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
|
| 88 |
# Full-explanation HHEM (reference)
|
| 89 |
h = faith["hhem"]
|
| 90 |
n_grounded = n_samples - h.get("n_hallucinated", 0)
|
| 91 |
full_avg = h.get("mean_score")
|
| 92 |
-
print(
|
|
|
|
|
|
|
| 93 |
|
| 94 |
# RAGAS if available
|
| 95 |
ragas = faith.get("ragas", {})
|
|
@@ -98,8 +109,10 @@ def main():
|
|
| 98 |
print(f" RAGAS Faith: {fmt(ragas_faith, 3)}")
|
| 99 |
|
| 100 |
# Pass/fail: use claim-level as primary, fall back to RAGAS, then full HHEM
|
| 101 |
-
effective =
|
| 102 |
-
|
|
|
|
|
|
|
| 103 |
)
|
| 104 |
if effective is not None:
|
| 105 |
status = "PASS" if effective >= FAITHFULNESS_TARGET else "FAIL"
|
|
@@ -123,7 +136,9 @@ def main():
|
|
| 123 |
print(f" {label + ':':<15s} {fmt(m, 2) if m is not None else ' ---'}")
|
| 124 |
if overall is not None:
|
| 125 |
status = "PASS" if human.get("pass", False) else "FAIL"
|
| 126 |
-
print(
|
|
|
|
|
|
|
| 127 |
corr = human.get("hhem_trust_correlation", {})
|
| 128 |
r = corr.get("spearman_r")
|
| 129 |
if r is not None:
|
|
|
|
| 16 |
import sys
|
| 17 |
from pathlib import Path
|
| 18 |
|
| 19 |
+
from sage.config import (
|
| 20 |
+
EVAL_DIMENSIONS,
|
| 21 |
+
FAITHFULNESS_TARGET,
|
| 22 |
+
HELPFULNESS_TARGET,
|
| 23 |
+
RESULTS_DIR,
|
| 24 |
+
)
|
| 25 |
|
| 26 |
WIDTH = 60
|
| 27 |
SEP = "=" * WIDTH
|
|
|
|
| 87 |
quotes_total = mm.get("quotes_total", 0)
|
| 88 |
|
| 89 |
if claim_pass is not None:
|
| 90 |
+
print(
|
| 91 |
+
f" Claim HHEM: {fmt(claim_avg, 3)} ({claim_pass * 100:.0f}% pass)"
|
| 92 |
+
)
|
| 93 |
+
print(
|
| 94 |
+
f" Quote Verif: {fmt(quote_rate, 3)} ({quotes_found}/{quotes_total})"
|
| 95 |
+
)
|
| 96 |
|
| 97 |
# Full-explanation HHEM (reference)
|
| 98 |
h = faith["hhem"]
|
| 99 |
n_grounded = n_samples - h.get("n_hallucinated", 0)
|
| 100 |
full_avg = h.get("mean_score")
|
| 101 |
+
print(
|
| 102 |
+
f" Full HHEM: {fmt(full_avg, 3)} ({n_grounded}/{n_samples} grounded, reference)"
|
| 103 |
+
)
|
| 104 |
|
| 105 |
# RAGAS if available
|
| 106 |
ragas = faith.get("ragas", {})
|
|
|
|
| 109 |
print(f" RAGAS Faith: {fmt(ragas_faith, 3)}")
|
| 110 |
|
| 111 |
# Pass/fail: use claim-level as primary, fall back to RAGAS, then full HHEM
|
| 112 |
+
effective = (
|
| 113 |
+
claim_avg
|
| 114 |
+
if claim_avg is not None
|
| 115 |
+
else (ragas_faith if ragas_faith is not None else full_avg)
|
| 116 |
)
|
| 117 |
if effective is not None:
|
| 118 |
status = "PASS" if effective >= FAITHFULNESS_TARGET else "FAIL"
|
|
|
|
| 136 |
print(f" {label + ':':<15s} {fmt(m, 2) if m is not None else ' ---'}")
|
| 137 |
if overall is not None:
|
| 138 |
status = "PASS" if human.get("pass", False) else "FAIL"
|
| 139 |
+
print(
|
| 140 |
+
f" Helpfulness: {fmt(overall, 2)} (target: {target:.1f}) [{status}]"
|
| 141 |
+
)
|
| 142 |
corr = human.get("hhem_trust_correlation", {})
|
| 143 |
r = corr.get("spearman_r")
|
| 144 |
if r is not None:
|
tests/test_aggregation.py
CHANGED
|
@@ -82,7 +82,9 @@ class TestApplyWeightedRanking:
|
|
| 82 |
ProductScore(product_id="A", score=0.9, chunk_count=2, avg_rating=3.0),
|
| 83 |
ProductScore(product_id="B", score=0.7, chunk_count=1, avg_rating=5.0),
|
| 84 |
]
|
| 85 |
-
ranked = apply_weighted_ranking(
|
|
|
|
|
|
|
| 86 |
assert len(ranked) == 2
|
| 87 |
# B has higher rating, so with 50/50 weights it might rank higher
|
| 88 |
assert all(isinstance(p, ProductScore) for p in ranked)
|
|
@@ -92,7 +94,9 @@ class TestApplyWeightedRanking:
|
|
| 92 |
ProductScore(product_id="A", score=0.9, chunk_count=1, avg_rating=1.0),
|
| 93 |
ProductScore(product_id="B", score=0.5, chunk_count=1, avg_rating=5.0),
|
| 94 |
]
|
| 95 |
-
ranked = apply_weighted_ranking(
|
|
|
|
|
|
|
| 96 |
assert ranked[0].product_id == "A"
|
| 97 |
|
| 98 |
def test_pure_rating_reranks(self):
|
|
@@ -100,7 +104,9 @@ class TestApplyWeightedRanking:
|
|
| 100 |
ProductScore(product_id="A", score=0.9, chunk_count=1, avg_rating=1.0),
|
| 101 |
ProductScore(product_id="B", score=0.5, chunk_count=1, avg_rating=5.0),
|
| 102 |
]
|
| 103 |
-
ranked = apply_weighted_ranking(
|
|
|
|
|
|
|
| 104 |
assert ranked[0].product_id == "B"
|
| 105 |
|
| 106 |
def test_single_product(self):
|
|
|
|
| 82 |
ProductScore(product_id="A", score=0.9, chunk_count=2, avg_rating=3.0),
|
| 83 |
ProductScore(product_id="B", score=0.7, chunk_count=1, avg_rating=5.0),
|
| 84 |
]
|
| 85 |
+
ranked = apply_weighted_ranking(
|
| 86 |
+
products, similarity_weight=0.5, rating_weight=0.5
|
| 87 |
+
)
|
| 88 |
assert len(ranked) == 2
|
| 89 |
# B has higher rating, so with 50/50 weights it might rank higher
|
| 90 |
assert all(isinstance(p, ProductScore) for p in ranked)
|
|
|
|
| 94 |
ProductScore(product_id="A", score=0.9, chunk_count=1, avg_rating=1.0),
|
| 95 |
ProductScore(product_id="B", score=0.5, chunk_count=1, avg_rating=5.0),
|
| 96 |
]
|
| 97 |
+
ranked = apply_weighted_ranking(
|
| 98 |
+
products, similarity_weight=1.0, rating_weight=0.0
|
| 99 |
+
)
|
| 100 |
assert ranked[0].product_id == "A"
|
| 101 |
|
| 102 |
def test_pure_rating_reranks(self):
|
|
|
|
| 104 |
ProductScore(product_id="A", score=0.9, chunk_count=1, avg_rating=1.0),
|
| 105 |
ProductScore(product_id="B", score=0.5, chunk_count=1, avg_rating=5.0),
|
| 106 |
]
|
| 107 |
+
ranked = apply_weighted_ranking(
|
| 108 |
+
products, similarity_weight=0.0, rating_weight=1.0
|
| 109 |
+
)
|
| 110 |
assert ranked[0].product_id == "B"
|
| 111 |
|
| 112 |
def test_single_product(self):
|
tests/test_api.py
CHANGED
|
@@ -27,9 +27,16 @@ def _make_app(**state_overrides) -> FastAPI:
|
|
| 27 |
mock_cache = MagicMock()
|
| 28 |
mock_cache.get.return_value = (None, "miss")
|
| 29 |
mock_cache.stats.return_value = SimpleNamespace(
|
| 30 |
-
size=0,
|
| 31 |
-
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
)
|
| 34 |
|
| 35 |
app.state.qdrant = state_overrides.get("qdrant", mock_qdrant)
|
|
@@ -56,6 +63,7 @@ class TestHealthEndpoint:
|
|
| 56 |
with TestClient(app) as c:
|
| 57 |
# Patch collection_exists to return True
|
| 58 |
import sage.api.routes as routes_mod
|
|
|
|
| 59 |
original = routes_mod.collection_exists
|
| 60 |
routes_mod.collection_exists = lambda client: True
|
| 61 |
try:
|
|
@@ -70,6 +78,7 @@ class TestHealthEndpoint:
|
|
| 70 |
def test_degraded_when_collection_missing(self):
|
| 71 |
app = _make_app()
|
| 72 |
import sage.api.routes as routes_mod
|
|
|
|
| 73 |
original = routes_mod.collection_exists
|
| 74 |
routes_mod.collection_exists = lambda client: False
|
| 75 |
try:
|
|
@@ -90,6 +99,7 @@ class TestRecommendEndpoint:
|
|
| 90 |
|
| 91 |
def test_empty_results(self, client):
|
| 92 |
import sage.api.routes as routes_mod
|
|
|
|
| 93 |
original = routes_mod.get_candidates
|
| 94 |
routes_mod.get_candidates = lambda **kw: []
|
| 95 |
try:
|
|
@@ -102,12 +112,18 @@ class TestRecommendEndpoint:
|
|
| 102 |
|
| 103 |
def test_returns_products_without_explain(self):
|
| 104 |
product = ProductScore(
|
| 105 |
-
product_id="P1",
|
|
|
|
|
|
|
|
|
|
| 106 |
evidence=[
|
| 107 |
-
RetrievedChunk(
|
|
|
|
|
|
|
| 108 |
],
|
| 109 |
)
|
| 110 |
import sage.api.routes as routes_mod
|
|
|
|
| 111 |
original = routes_mod.get_candidates
|
| 112 |
routes_mod.get_candidates = lambda **kw: [product]
|
| 113 |
app = _make_app()
|
|
@@ -126,12 +142,18 @@ class TestRecommendEndpoint:
|
|
| 126 |
|
| 127 |
def test_explainer_unavailable_returns_503(self):
|
| 128 |
product = ProductScore(
|
| 129 |
-
product_id="P1",
|
|
|
|
|
|
|
|
|
|
| 130 |
evidence=[
|
| 131 |
-
RetrievedChunk(
|
|
|
|
|
|
|
| 132 |
],
|
| 133 |
)
|
| 134 |
import sage.api.routes as routes_mod
|
|
|
|
| 135 |
original = routes_mod.get_candidates
|
| 136 |
routes_mod.get_candidates = lambda **kw: [product]
|
| 137 |
|
|
|
|
| 27 |
mock_cache = MagicMock()
|
| 28 |
mock_cache.get.return_value = (None, "miss")
|
| 29 |
mock_cache.stats.return_value = SimpleNamespace(
|
| 30 |
+
size=0,
|
| 31 |
+
max_entries=100,
|
| 32 |
+
exact_hits=0,
|
| 33 |
+
semantic_hits=0,
|
| 34 |
+
misses=0,
|
| 35 |
+
evictions=0,
|
| 36 |
+
hit_rate=0.0,
|
| 37 |
+
ttl_seconds=3600.0,
|
| 38 |
+
similarity_threshold=0.92,
|
| 39 |
+
avg_semantic_similarity=0.0,
|
| 40 |
)
|
| 41 |
|
| 42 |
app.state.qdrant = state_overrides.get("qdrant", mock_qdrant)
|
|
|
|
| 63 |
with TestClient(app) as c:
|
| 64 |
# Patch collection_exists to return True
|
| 65 |
import sage.api.routes as routes_mod
|
| 66 |
+
|
| 67 |
original = routes_mod.collection_exists
|
| 68 |
routes_mod.collection_exists = lambda client: True
|
| 69 |
try:
|
|
|
|
| 78 |
def test_degraded_when_collection_missing(self):
|
| 79 |
app = _make_app()
|
| 80 |
import sage.api.routes as routes_mod
|
| 81 |
+
|
| 82 |
original = routes_mod.collection_exists
|
| 83 |
routes_mod.collection_exists = lambda client: False
|
| 84 |
try:
|
|
|
|
| 99 |
|
| 100 |
def test_empty_results(self, client):
|
| 101 |
import sage.api.routes as routes_mod
|
| 102 |
+
|
| 103 |
original = routes_mod.get_candidates
|
| 104 |
routes_mod.get_candidates = lambda **kw: []
|
| 105 |
try:
|
|
|
|
| 112 |
|
| 113 |
def test_returns_products_without_explain(self):
|
| 114 |
product = ProductScore(
|
| 115 |
+
product_id="P1",
|
| 116 |
+
score=0.9,
|
| 117 |
+
chunk_count=2,
|
| 118 |
+
avg_rating=4.5,
|
| 119 |
evidence=[
|
| 120 |
+
RetrievedChunk(
|
| 121 |
+
text="Good", score=0.9, product_id="P1", rating=4.5, review_id="r1"
|
| 122 |
+
),
|
| 123 |
],
|
| 124 |
)
|
| 125 |
import sage.api.routes as routes_mod
|
| 126 |
+
|
| 127 |
original = routes_mod.get_candidates
|
| 128 |
routes_mod.get_candidates = lambda **kw: [product]
|
| 129 |
app = _make_app()
|
|
|
|
| 142 |
|
| 143 |
def test_explainer_unavailable_returns_503(self):
|
| 144 |
product = ProductScore(
|
| 145 |
+
product_id="P1",
|
| 146 |
+
score=0.9,
|
| 147 |
+
chunk_count=2,
|
| 148 |
+
avg_rating=4.5,
|
| 149 |
evidence=[
|
| 150 |
+
RetrievedChunk(
|
| 151 |
+
text="Good", score=0.9, product_id="P1", rating=4.5, review_id="r1"
|
| 152 |
+
),
|
| 153 |
],
|
| 154 |
)
|
| 155 |
import sage.api.routes as routes_mod
|
| 156 |
+
|
| 157 |
original = routes_mod.get_candidates
|
| 158 |
routes_mod.get_candidates = lambda **kw: [product]
|
| 159 |
|
tests/test_chunking.py
CHANGED
|
@@ -67,7 +67,9 @@ class TestSlidingWindowChunk:
|
|
| 67 |
|
| 68 |
def test_long_text_creates_multiple_chunks(self):
|
| 69 |
# Create text long enough to require multiple chunks
|
| 70 |
-
sentences = [
|
|
|
|
|
|
|
| 71 |
text = " ".join(sentences)
|
| 72 |
chunks = sliding_window_chunk(text, chunk_size=50, overlap=10)
|
| 73 |
assert len(chunks) > 1
|
|
|
|
| 67 |
|
| 68 |
def test_long_text_creates_multiple_chunks(self):
|
| 69 |
# Create text long enough to require multiple chunks
|
| 70 |
+
sentences = [
|
| 71 |
+
f"This is sentence number {i} with some padding text." for i in range(20)
|
| 72 |
+
]
|
| 73 |
text = " ".join(sentences)
|
| 74 |
chunks = sliding_window_chunk(text, chunk_size=50, overlap=10)
|
| 75 |
assert len(chunks) > 1
|
tests/test_evidence.py
CHANGED
|
@@ -50,7 +50,10 @@ class TestCheckEvidenceQuality:
|
|
| 50 |
product = _product(score=0.3, n_chunks=3, text_len=300)
|
| 51 |
quality = check_evidence_quality(product, min_score=0.7)
|
| 52 |
assert quality.is_sufficient is False
|
| 53 |
-
assert
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
def test_tracks_chunk_count(self):
|
| 56 |
product = _product(score=0.85, n_chunks=4, text_len=200)
|
|
|
|
| 50 |
product = _product(score=0.3, n_chunks=3, text_len=300)
|
| 51 |
quality = check_evidence_quality(product, min_score=0.7)
|
| 52 |
assert quality.is_sufficient is False
|
| 53 |
+
assert (
|
| 54 |
+
"relevance" in quality.failure_reason.lower()
|
| 55 |
+
or "score" in quality.failure_reason.lower()
|
| 56 |
+
)
|
| 57 |
|
| 58 |
def test_tracks_chunk_count(self):
|
| 59 |
product = _product(score=0.85, n_chunks=4, text_len=200)
|
tests/test_faithfulness.py
CHANGED
|
@@ -29,13 +29,20 @@ class TestIsRefusal:
|
|
| 29 |
|
| 30 |
class TestIsMismatchWarning:
|
| 31 |
def test_detects_not_best_match(self):
|
| 32 |
-
assert
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
def test_detects_not_designed_for(self):
|
| 35 |
assert is_mismatch_warning("This is not designed for that purpose.") is True
|
| 36 |
|
| 37 |
def test_detects_not_suitable(self):
|
| 38 |
-
assert
|
|
|
|
|
|
|
| 39 |
|
| 40 |
def test_normal_explanation_not_mismatch(self):
|
| 41 |
assert is_mismatch_warning("Great headphones with noise cancellation.") is False
|
|
|
|
| 29 |
|
| 30 |
class TestIsMismatchWarning:
|
| 31 |
def test_detects_not_best_match(self):
|
| 32 |
+
assert (
|
| 33 |
+
is_mismatch_warning(
|
| 34 |
+
"This product may not be the best match for your needs."
|
| 35 |
+
)
|
| 36 |
+
is True
|
| 37 |
+
)
|
| 38 |
|
| 39 |
def test_detects_not_designed_for(self):
|
| 40 |
assert is_mismatch_warning("This is not designed for that purpose.") is True
|
| 41 |
|
| 42 |
def test_detects_not_suitable(self):
|
| 43 |
+
assert (
|
| 44 |
+
is_mismatch_warning("This product is not suitable for heavy use.") is True
|
| 45 |
+
)
|
| 46 |
|
| 47 |
def test_normal_explanation_not_mismatch(self):
|
| 48 |
assert is_mismatch_warning("Great headphones with noise cancellation.") is False
|
tests/test_models.py
CHANGED
|
@@ -35,19 +35,32 @@ class TestNewItem:
|
|
| 35 |
class TestProductScore:
|
| 36 |
def test_top_evidence_returns_highest(self):
|
| 37 |
chunks = [
|
| 38 |
-
RetrievedChunk(
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
]
|
| 42 |
product = ProductScore(
|
| 43 |
-
product_id="P1",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
)
|
| 45 |
assert product.top_evidence.text == "high"
|
| 46 |
assert product.top_evidence.score == 0.9
|
| 47 |
|
| 48 |
def test_top_evidence_empty(self):
|
| 49 |
product = ProductScore(
|
| 50 |
-
product_id="P1",
|
|
|
|
|
|
|
|
|
|
| 51 |
)
|
| 52 |
assert product.top_evidence is None
|
| 53 |
|
|
@@ -116,7 +129,10 @@ class TestStreamingExplanation:
|
|
| 116 |
class TestEvidenceQuality:
|
| 117 |
def test_sufficient(self):
|
| 118 |
eq = EvidenceQuality(
|
| 119 |
-
is_sufficient=True,
|
|
|
|
|
|
|
|
|
|
| 120 |
)
|
| 121 |
assert eq.is_sufficient is True
|
| 122 |
assert eq.failure_reason is None
|
|
|
|
| 35 |
class TestProductScore:
|
| 36 |
def test_top_evidence_returns_highest(self):
|
| 37 |
chunks = [
|
| 38 |
+
RetrievedChunk(
|
| 39 |
+
text="low", score=0.5, product_id="P1", rating=4.0, review_id="r1"
|
| 40 |
+
),
|
| 41 |
+
RetrievedChunk(
|
| 42 |
+
text="high", score=0.9, product_id="P1", rating=4.0, review_id="r2"
|
| 43 |
+
),
|
| 44 |
+
RetrievedChunk(
|
| 45 |
+
text="mid", score=0.7, product_id="P1", rating=4.0, review_id="r3"
|
| 46 |
+
),
|
| 47 |
]
|
| 48 |
product = ProductScore(
|
| 49 |
+
product_id="P1",
|
| 50 |
+
score=0.9,
|
| 51 |
+
chunk_count=3,
|
| 52 |
+
avg_rating=4.0,
|
| 53 |
+
evidence=chunks,
|
| 54 |
)
|
| 55 |
assert product.top_evidence.text == "high"
|
| 56 |
assert product.top_evidence.score == 0.9
|
| 57 |
|
| 58 |
def test_top_evidence_empty(self):
|
| 59 |
product = ProductScore(
|
| 60 |
+
product_id="P1",
|
| 61 |
+
score=0.5,
|
| 62 |
+
chunk_count=0,
|
| 63 |
+
avg_rating=4.0,
|
| 64 |
)
|
| 65 |
assert product.top_evidence is None
|
| 66 |
|
|
|
|
| 129 |
class TestEvidenceQuality:
|
| 130 |
def test_sufficient(self):
|
| 131 |
eq = EvidenceQuality(
|
| 132 |
+
is_sufficient=True,
|
| 133 |
+
chunk_count=3,
|
| 134 |
+
total_tokens=150,
|
| 135 |
+
top_score=0.9,
|
| 136 |
)
|
| 137 |
assert eq.is_sufficient is True
|
| 138 |
assert eq.failure_reason is None
|