vxa8502 commited on
Commit
bf39698
·
1 Parent(s): 4a10224

Apply ruff formatting

Browse files
sage/adapters/hhem.py CHANGED
@@ -148,7 +148,9 @@ class HallucinationDetector:
148
  remaining = [t for t in evidence_texts if hyp_lower not in t.lower()]
149
  evidence_texts = containing + remaining
150
 
151
- hypothesis_tokens = len(self.tokenizer(hypothesis, add_special_tokens=False).input_ids)
 
 
152
  budget = HHEM_MAX_TOKENS - HHEM_TEMPLATE_OVERHEAD - hypothesis_tokens
153
 
154
  kept = []
@@ -253,13 +255,20 @@ class HallucinationDetector:
253
  List of ClaimResult objects, one per claim.
254
  """
255
  pairs = [
256
- (self._format_premise(evidence_texts, hypothesis=claim, prioritize_hypothesis=True), claim)
 
 
 
 
 
257
  for claim in claims
258
  ]
259
  scores = self._predict(pairs)
260
 
261
  return [
262
- ClaimResult(claim=claim, score=score, is_hallucinated=score < self.threshold)
 
 
263
  for claim, score in zip(claims, scores)
264
  ]
265
 
 
148
  remaining = [t for t in evidence_texts if hyp_lower not in t.lower()]
149
  evidence_texts = containing + remaining
150
 
151
+ hypothesis_tokens = len(
152
+ self.tokenizer(hypothesis, add_special_tokens=False).input_ids
153
+ )
154
  budget = HHEM_MAX_TOKENS - HHEM_TEMPLATE_OVERHEAD - hypothesis_tokens
155
 
156
  kept = []
 
255
  List of ClaimResult objects, one per claim.
256
  """
257
  pairs = [
258
+ (
259
+ self._format_premise(
260
+ evidence_texts, hypothesis=claim, prioritize_hypothesis=True
261
+ ),
262
+ claim,
263
+ )
264
  for claim in claims
265
  ]
266
  scores = self._predict(pairs)
267
 
268
  return [
269
+ ClaimResult(
270
+ claim=claim, score=score, is_hallucinated=score < self.threshold
271
+ )
272
  for claim, score in zip(claims, scores)
273
  ]
274
 
sage/adapters/llm.py CHANGED
@@ -356,7 +356,9 @@ def get_llm_client(provider: str | None = None) -> LLMClient:
356
  elif provider == "openai":
357
  return OpenAIClient()
358
  else:
359
- raise ValueError(f"Unknown LLM provider: {provider}. Use 'anthropic' or 'openai'.")
 
 
360
 
361
 
362
  __all__ = [
 
356
  elif provider == "openai":
357
  return OpenAIClient()
358
  else:
359
+ raise ValueError(
360
+ f"Unknown LLM provider: {provider}. Use 'anthropic' or 'openai'."
361
+ )
362
 
363
 
364
  __all__ = [
sage/adapters/vector_store.py CHANGED
@@ -42,8 +42,7 @@ def get_client():
42
  from qdrant_client import QdrantClient
43
  except ImportError:
44
  raise ImportError(
45
- "qdrant-client package required. "
46
- "Install with: pip install qdrant-client"
47
  )
48
 
49
  if QDRANT_API_KEY:
 
42
  from qdrant_client import QdrantClient
43
  except ImportError:
44
  raise ImportError(
45
+ "qdrant-client package required. Install with: pip install qdrant-client"
 
46
  )
47
 
48
  if QDRANT_API_KEY:
sage/api/app.py CHANGED
@@ -31,24 +31,19 @@ async def _lifespan(app: FastAPI):
31
 
32
  # Validate LLM credentials early
33
  from sage.config import ANTHROPIC_API_KEY, LLM_PROVIDER, OPENAI_API_KEY
 
34
  if not ANTHROPIC_API_KEY and not OPENAI_API_KEY:
35
  logger.error(
36
- "No LLM API key set -- add ANTHROPIC_API_KEY "
37
- "or OPENAI_API_KEY to .env"
38
  )
39
  elif LLM_PROVIDER == "anthropic" and not ANTHROPIC_API_KEY:
40
- logger.warning(
41
- "LLM_PROVIDER=anthropic but ANTHROPIC_API_KEY "
42
- "is not set"
43
- )
44
  elif LLM_PROVIDER == "openai" and not OPENAI_API_KEY:
45
- logger.warning(
46
- "LLM_PROVIDER=openai but OPENAI_API_KEY "
47
- "is not set"
48
- )
49
 
50
  # Embedder (loads E5-small model) -- required for all requests
51
  from sage.adapters.embeddings import get_embedder
 
52
  try:
53
  app.state.embedder = get_embedder()
54
  logger.info("Embedder loaded")
@@ -58,6 +53,7 @@ async def _lifespan(app: FastAPI):
58
 
59
  # Qdrant client
60
  from sage.adapters.vector_store import get_client, collection_exists
 
61
  app.state.qdrant = get_client()
62
  try:
63
  if collection_exists(app.state.qdrant):
@@ -69,6 +65,7 @@ async def _lifespan(app: FastAPI):
69
 
70
  # HHEM hallucination detector (loads T5 model) -- required for grounding
71
  from sage.adapters.hhem import HallucinationDetector
 
72
  try:
73
  app.state.detector = HallucinationDetector()
74
  logger.info("HHEM detector loaded")
@@ -78,18 +75,19 @@ async def _lifespan(app: FastAPI):
78
 
79
  # LLM explainer -- graceful degradation if unavailable
80
  from sage.services.explanation import Explainer
 
81
  try:
82
  app.state.explainer = Explainer()
83
  logger.info("Explainer ready (%s)", app.state.explainer.model)
84
  except Exception:
85
  logger.exception(
86
- "Failed to initialize explainer -- "
87
- "explain=true requests will fail"
88
  )
89
  app.state.explainer = None
90
 
91
  # Semantic cache
92
  from sage.services.cache import SemanticCache
 
93
  app.state.cache = SemanticCache()
94
  logger.info("Semantic cache initialized")
95
 
 
31
 
32
  # Validate LLM credentials early
33
  from sage.config import ANTHROPIC_API_KEY, LLM_PROVIDER, OPENAI_API_KEY
34
+
35
  if not ANTHROPIC_API_KEY and not OPENAI_API_KEY:
36
  logger.error(
37
+ "No LLM API key set -- add ANTHROPIC_API_KEY or OPENAI_API_KEY to .env"
 
38
  )
39
  elif LLM_PROVIDER == "anthropic" and not ANTHROPIC_API_KEY:
40
+ logger.warning("LLM_PROVIDER=anthropic but ANTHROPIC_API_KEY is not set")
 
 
 
41
  elif LLM_PROVIDER == "openai" and not OPENAI_API_KEY:
42
+ logger.warning("LLM_PROVIDER=openai but OPENAI_API_KEY is not set")
 
 
 
43
 
44
  # Embedder (loads E5-small model) -- required for all requests
45
  from sage.adapters.embeddings import get_embedder
46
+
47
  try:
48
  app.state.embedder = get_embedder()
49
  logger.info("Embedder loaded")
 
53
 
54
  # Qdrant client
55
  from sage.adapters.vector_store import get_client, collection_exists
56
+
57
  app.state.qdrant = get_client()
58
  try:
59
  if collection_exists(app.state.qdrant):
 
65
 
66
  # HHEM hallucination detector (loads T5 model) -- required for grounding
67
  from sage.adapters.hhem import HallucinationDetector
68
+
69
  try:
70
  app.state.detector = HallucinationDetector()
71
  logger.info("HHEM detector loaded")
 
75
 
76
  # LLM explainer -- graceful degradation if unavailable
77
  from sage.services.explanation import Explainer
78
+
79
  try:
80
  app.state.explainer = Explainer()
81
  logger.info("Explainer ready (%s)", app.state.explainer.model)
82
  except Exception:
83
  logger.exception(
84
+ "Failed to initialize explainer -- explain=true requests will fail"
 
85
  )
86
  app.state.explainer = None
87
 
88
  # Semantic cache
89
  from sage.services.cache import SemanticCache
90
+
91
  app.state.cache = SemanticCache()
92
  logger.info("Semantic cache initialized")
93
 
sage/api/metrics.py CHANGED
@@ -16,7 +16,12 @@ logger = get_logger(__name__)
16
  # ---------------------------------------------------------------------------
17
 
18
  try:
19
- from prometheus_client import Counter, Histogram, generate_latest, CONTENT_TYPE_LATEST
 
 
 
 
 
20
 
21
  REQUEST_COUNT = Counter(
22
  "sage_requests_total",
@@ -48,6 +53,7 @@ except ImportError:
48
  # Public helpers
49
  # ---------------------------------------------------------------------------
50
 
 
51
  def record_request(endpoint: str, method: str, status: int) -> None:
52
  """Increment the request counter."""
53
  if _PROMETHEUS_AVAILABLE:
 
16
  # ---------------------------------------------------------------------------
17
 
18
  try:
19
+ from prometheus_client import (
20
+ Counter,
21
+ Histogram,
22
+ generate_latest,
23
+ CONTENT_TYPE_LATEST,
24
+ )
25
 
26
  REQUEST_COUNT = Counter(
27
  "sage_requests_total",
 
53
  # Public helpers
54
  # ---------------------------------------------------------------------------
55
 
56
+
57
  def record_request(endpoint: str, method: str, status: int) -> None:
58
  """Increment the request counter."""
59
  if _PROMETHEUS_AVAILABLE:
sage/api/middleware.py CHANGED
@@ -69,9 +69,7 @@ class LatencyMiddleware:
69
  # The Prometheus histogram (in finally) measures total time.
70
  elapsed_ms = (time.perf_counter() - start) * 1000
71
  headers = list(message.get("headers", []))
72
- headers.append(
73
- (b"x-response-time-ms", f"{elapsed_ms:.1f}".encode())
74
- )
75
  headers.append((b"x-request-id", request_id.encode()))
76
  message = {**message, "headers": headers}
77
  await send(message)
@@ -88,6 +86,9 @@ class LatencyMiddleware:
88
  if path not in _QUIET_PATHS:
89
  logger.info(
90
  "%s %s %d %.1fms [%s]",
91
- method, path, status,
92
- elapsed_ms, request_id,
 
 
 
93
  )
 
69
  # The Prometheus histogram (in finally) measures total time.
70
  elapsed_ms = (time.perf_counter() - start) * 1000
71
  headers = list(message.get("headers", []))
72
+ headers.append((b"x-response-time-ms", f"{elapsed_ms:.1f}".encode()))
 
 
73
  headers.append((b"x-request-id", request_id.encode()))
74
  message = {**message, "headers": headers}
75
  await send(message)
 
86
  if path not in _QUIET_PATHS:
87
  logger.info(
88
  "%s %s %d %.1fms [%s]",
89
+ method,
90
+ path,
91
+ status,
92
+ elapsed_ms,
93
+ request_id,
94
  )
sage/api/routes.py CHANGED
@@ -24,7 +24,12 @@ from pydantic import BaseModel
24
  from sage.adapters.vector_store import collection_exists
25
  from sage.api.metrics import metrics_response, record_cache_event
26
  from sage.config import MAX_EVIDENCE, get_logger
27
- from sage.core import AggregationMethod, ExplanationResult, ProductScore, verify_citations
 
 
 
 
 
28
  from sage.services.retrieval import get_candidates
29
 
30
  # Cap parallel LLM+HHEM workers per request. With k=10 and concurrent
@@ -41,6 +46,7 @@ router = APIRouter()
41
  # Response models
42
  # ---------------------------------------------------------------------------
43
 
 
44
  class EvidenceSource(BaseModel):
45
  id: str
46
  text: str
@@ -95,6 +101,7 @@ class CacheStatsResponse(BaseModel):
95
  # Shared helpers
96
  # ---------------------------------------------------------------------------
97
 
 
98
  @dataclass
99
  class RecommendParams:
100
  """Query parameters shared by /recommend and /recommend/stream."""
@@ -105,7 +112,9 @@ class RecommendParams:
105
 
106
 
107
  def _fetch_products(
108
- params: RecommendParams, app, query_embedding=None,
 
 
109
  ) -> list[ProductScore]:
110
  """Run candidate generation with lifespan-managed singletons."""
111
  return get_candidates(
@@ -138,6 +147,7 @@ def _build_evidence_list(result: ExplanationResult) -> list[dict]:
138
  # Health
139
  # ---------------------------------------------------------------------------
140
 
 
141
  @router.get("/health", response_model=HealthResponse)
142
  def health(request: Request):
143
  """Deployment readiness probe. Checks Qdrant connectivity.
@@ -159,6 +169,7 @@ def health(request: Request):
159
  # Recommend (non-streaming)
160
  # ---------------------------------------------------------------------------
161
 
 
162
  @router.get(
163
  "/recommend",
164
  response_model=RecommendResponse,
@@ -208,20 +219,27 @@ def recommend(
208
  # HHEM model in eval() + no_grad() = read-only forward
209
  # pass with no state mutation. Tokenizer is stateless.
210
  er = explainer.generate_explanation(
211
- query=q, product=product, max_evidence=MAX_EVIDENCE,
 
 
212
  )
213
  hr = detector.check_explanation(
214
  evidence_texts=er.evidence_texts,
215
  explanation=er.explanation,
216
  )
217
- cr = verify_citations(er.explanation, er.evidence_ids, er.evidence_texts)
 
 
218
  return er, hr, cr
219
 
220
- with ThreadPoolExecutor(max_workers=min(len(products), _MAX_EXPLAIN_WORKERS)) as pool:
 
 
221
  results = list(pool.map(_explain, products))
222
 
223
  for i, (product, (er, hr, cr)) in enumerate(
224
- zip(products, results), 1,
 
225
  ):
226
  rec = _build_product_dict(i, product)
227
  rec["explanation"] = er.explanation
@@ -257,6 +275,7 @@ def recommend(
257
  # Recommend (SSE streaming)
258
  # ---------------------------------------------------------------------------
259
 
 
260
  def _sse_event(event: str, data: str) -> str:
261
  """Format a single SSE event."""
262
  return f"event: {event}\ndata: {data}\n\n"
@@ -267,9 +286,16 @@ def _stream_recommendations(
267
  app,
268
  ) -> Iterator[str]:
269
  """Generator that yields SSE events for streaming recommendations."""
270
- yield _sse_event("metadata", json.dumps({
271
- "verified": False, "cache": False, "hhem": False,
272
- }))
 
 
 
 
 
 
 
273
 
274
  try:
275
  products = _fetch_products(params, app)
@@ -285,7 +311,9 @@ def _stream_recommendations(
285
 
286
  explainer = app.state.explainer
287
  if explainer is None:
288
- yield _sse_event("error", json.dumps({"detail": "Explanation service unavailable"}))
 
 
289
  yield _sse_event("done", json.dumps({"status": "error"}))
290
  return
291
 
@@ -314,7 +342,9 @@ def _stream_recommendations(
314
  yield _sse_event("refusal", json.dumps({"detail": str(exc)}))
315
  except Exception:
316
  logger.exception("Streaming error for product %s", product.product_id)
317
- yield _sse_event("error", json.dumps({"detail": "Failed to generate explanation"}))
 
 
318
 
319
  yield _sse_event("done", json.dumps({"status": "complete"}))
320
 
@@ -341,6 +371,7 @@ def recommend_stream(
341
  # Cache management
342
  # ---------------------------------------------------------------------------
343
 
 
344
  @router.get("/cache/stats", response_model=CacheStatsResponse)
345
  def cache_stats(request: Request):
346
  """Return cache performance statistics."""
@@ -370,6 +401,7 @@ def cache_clear(request: Request):
370
  # Prometheus metrics
371
  # ---------------------------------------------------------------------------
372
 
 
373
  @router.get("/metrics")
374
  def metrics():
375
  """Prometheus metrics endpoint."""
 
24
  from sage.adapters.vector_store import collection_exists
25
  from sage.api.metrics import metrics_response, record_cache_event
26
  from sage.config import MAX_EVIDENCE, get_logger
27
+ from sage.core import (
28
+ AggregationMethod,
29
+ ExplanationResult,
30
+ ProductScore,
31
+ verify_citations,
32
+ )
33
  from sage.services.retrieval import get_candidates
34
 
35
  # Cap parallel LLM+HHEM workers per request. With k=10 and concurrent
 
46
  # Response models
47
  # ---------------------------------------------------------------------------
48
 
49
+
50
  class EvidenceSource(BaseModel):
51
  id: str
52
  text: str
 
101
  # Shared helpers
102
  # ---------------------------------------------------------------------------
103
 
104
+
105
  @dataclass
106
  class RecommendParams:
107
  """Query parameters shared by /recommend and /recommend/stream."""
 
112
 
113
 
114
  def _fetch_products(
115
+ params: RecommendParams,
116
+ app,
117
+ query_embedding=None,
118
  ) -> list[ProductScore]:
119
  """Run candidate generation with lifespan-managed singletons."""
120
  return get_candidates(
 
147
  # Health
148
  # ---------------------------------------------------------------------------
149
 
150
+
151
  @router.get("/health", response_model=HealthResponse)
152
  def health(request: Request):
153
  """Deployment readiness probe. Checks Qdrant connectivity.
 
169
  # Recommend (non-streaming)
170
  # ---------------------------------------------------------------------------
171
 
172
+
173
  @router.get(
174
  "/recommend",
175
  response_model=RecommendResponse,
 
219
  # HHEM model in eval() + no_grad() = read-only forward
220
  # pass with no state mutation. Tokenizer is stateless.
221
  er = explainer.generate_explanation(
222
+ query=q,
223
+ product=product,
224
+ max_evidence=MAX_EVIDENCE,
225
  )
226
  hr = detector.check_explanation(
227
  evidence_texts=er.evidence_texts,
228
  explanation=er.explanation,
229
  )
230
+ cr = verify_citations(
231
+ er.explanation, er.evidence_ids, er.evidence_texts
232
+ )
233
  return er, hr, cr
234
 
235
+ with ThreadPoolExecutor(
236
+ max_workers=min(len(products), _MAX_EXPLAIN_WORKERS)
237
+ ) as pool:
238
  results = list(pool.map(_explain, products))
239
 
240
  for i, (product, (er, hr, cr)) in enumerate(
241
+ zip(products, results),
242
+ 1,
243
  ):
244
  rec = _build_product_dict(i, product)
245
  rec["explanation"] = er.explanation
 
275
  # Recommend (SSE streaming)
276
  # ---------------------------------------------------------------------------
277
 
278
+
279
  def _sse_event(event: str, data: str) -> str:
280
  """Format a single SSE event."""
281
  return f"event: {event}\ndata: {data}\n\n"
 
286
  app,
287
  ) -> Iterator[str]:
288
  """Generator that yields SSE events for streaming recommendations."""
289
+ yield _sse_event(
290
+ "metadata",
291
+ json.dumps(
292
+ {
293
+ "verified": False,
294
+ "cache": False,
295
+ "hhem": False,
296
+ }
297
+ ),
298
+ )
299
 
300
  try:
301
  products = _fetch_products(params, app)
 
311
 
312
  explainer = app.state.explainer
313
  if explainer is None:
314
+ yield _sse_event(
315
+ "error", json.dumps({"detail": "Explanation service unavailable"})
316
+ )
317
  yield _sse_event("done", json.dumps({"status": "error"}))
318
  return
319
 
 
342
  yield _sse_event("refusal", json.dumps({"detail": str(exc)}))
343
  except Exception:
344
  logger.exception("Streaming error for product %s", product.product_id)
345
+ yield _sse_event(
346
+ "error", json.dumps({"detail": "Failed to generate explanation"})
347
+ )
348
 
349
  yield _sse_event("done", json.dumps({"status": "complete"}))
350
 
 
371
  # Cache management
372
  # ---------------------------------------------------------------------------
373
 
374
+
375
  @router.get("/cache/stats", response_model=CacheStatsResponse)
376
  def cache_stats(request: Request):
377
  """Return cache performance statistics."""
 
401
  # Prometheus metrics
402
  # ---------------------------------------------------------------------------
403
 
404
+
405
  @router.get("/metrics")
406
  def metrics():
407
  """Prometheus metrics endpoint."""
sage/api/run.py CHANGED
@@ -24,7 +24,8 @@ def main():
24
  parser = argparse.ArgumentParser(description="Sage API server")
25
  parser.add_argument("--host", default="0.0.0.0", help="Bind address")
26
  parser.add_argument(
27
- "--port", type=int,
 
28
  default=int(os.getenv("PORT", "8000")),
29
  help="Port (defaults to PORT env var, then 8000)",
30
  )
 
24
  parser = argparse.ArgumentParser(description="Sage API server")
25
  parser.add_argument("--host", default="0.0.0.0", help="Bind address")
26
  parser.add_argument(
27
+ "--port",
28
+ type=int,
29
  default=int(os.getenv("PORT", "8000")),
30
  help="Port (defaults to PORT env var, then 8000)",
31
  )
sage/config/logging.py CHANGED
@@ -28,27 +28,48 @@ LOG_FORMAT = os.getenv("SAGE_LOG_FORMAT", "console") # "console" or "json"
28
 
29
  # Standard LogRecord attributes to ignore when extracting user-specified extras.
30
  # These are built-in attributes from logging.LogRecord plus taskName from asyncio.
31
- _STANDARD_LOG_ATTRS = frozenset({
32
- "name", "msg", "args", "created", "filename", "funcName",
33
- "levelname", "levelno", "lineno", "module", "msecs",
34
- "pathname", "process", "processName", "relativeCreated",
35
- "stack_info", "exc_info", "exc_text", "thread", "threadName",
36
- "message", "asctime", "taskName",
37
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
 
40
  # ---------------------------------------------------------------------------
41
  # Custom Formatter (Console)
42
  # ---------------------------------------------------------------------------
43
 
 
44
  class ConsoleFormatter(logging.Formatter):
45
  """Human-readable formatter with visual hierarchy."""
46
 
47
  COLORS = {
48
- "DEBUG": "\033[36m", # Cyan
49
- "INFO": "\033[32m", # Green
50
- "WARNING": "\033[33m", # Yellow
51
- "ERROR": "\033[31m", # Red
52
  "CRITICAL": "\033[35m", # Magenta
53
  "RESET": "\033[0m",
54
  }
@@ -87,6 +108,7 @@ class ConsoleFormatter(logging.Formatter):
87
  # Custom Formatter (JSON)
88
  # ---------------------------------------------------------------------------
89
 
 
90
  class JSONFormatter(logging.Formatter):
91
  """Machine-parseable JSON formatter for production."""
92
 
@@ -180,14 +202,19 @@ def get_logger(name: str) -> logging.Logger:
180
  # Convenience functions for visual output
181
  # ---------------------------------------------------------------------------
182
 
183
- def log_banner(logger: logging.Logger, title: str, char: str = "=", width: int = 60) -> None:
 
 
 
184
  """Log a visual banner for section headers."""
185
  logger.info(char * width)
186
  logger.info(title)
187
  logger.info(char * width)
188
 
189
 
190
- def log_section(logger: logging.Logger, title: str, char: str = "-", width: int = 60) -> None:
 
 
191
  """Log a section divider."""
192
  logger.info("")
193
  logger.info(char * width)
 
28
 
29
  # Standard LogRecord attributes to ignore when extracting user-specified extras.
30
  # These are built-in attributes from logging.LogRecord plus taskName from asyncio.
31
+ _STANDARD_LOG_ATTRS = frozenset(
32
+ {
33
+ "name",
34
+ "msg",
35
+ "args",
36
+ "created",
37
+ "filename",
38
+ "funcName",
39
+ "levelname",
40
+ "levelno",
41
+ "lineno",
42
+ "module",
43
+ "msecs",
44
+ "pathname",
45
+ "process",
46
+ "processName",
47
+ "relativeCreated",
48
+ "stack_info",
49
+ "exc_info",
50
+ "exc_text",
51
+ "thread",
52
+ "threadName",
53
+ "message",
54
+ "asctime",
55
+ "taskName",
56
+ }
57
+ )
58
 
59
 
60
  # ---------------------------------------------------------------------------
61
  # Custom Formatter (Console)
62
  # ---------------------------------------------------------------------------
63
 
64
+
65
  class ConsoleFormatter(logging.Formatter):
66
  """Human-readable formatter with visual hierarchy."""
67
 
68
  COLORS = {
69
+ "DEBUG": "\033[36m", # Cyan
70
+ "INFO": "\033[32m", # Green
71
+ "WARNING": "\033[33m", # Yellow
72
+ "ERROR": "\033[31m", # Red
73
  "CRITICAL": "\033[35m", # Magenta
74
  "RESET": "\033[0m",
75
  }
 
108
  # Custom Formatter (JSON)
109
  # ---------------------------------------------------------------------------
110
 
111
+
112
  class JSONFormatter(logging.Formatter):
113
  """Machine-parseable JSON formatter for production."""
114
 
 
202
  # Convenience functions for visual output
203
  # ---------------------------------------------------------------------------
204
 
205
+
206
+ def log_banner(
207
+ logger: logging.Logger, title: str, char: str = "=", width: int = 60
208
+ ) -> None:
209
  """Log a visual banner for section headers."""
210
  logger.info(char * width)
211
  logger.info(title)
212
  logger.info(char * width)
213
 
214
 
215
+ def log_section(
216
+ logger: logging.Logger, title: str, char: str = "-", width: int = 60
217
+ ) -> None:
218
  """Log a section divider."""
219
  logger.info("")
220
  logger.info(char * width)
sage/core/aggregation.py CHANGED
@@ -55,13 +55,15 @@ def aggregate_chunks_to_products(
55
  else:
56
  agg_score = max(scores)
57
 
58
- product_scores.append(ProductScore(
59
- product_id=product_id,
60
- score=agg_score,
61
- chunk_count=len(prod_chunks),
62
- avg_rating=float(np.mean(ratings)),
63
- evidence=sorted(prod_chunks, key=lambda c: c.score, reverse=True),
64
- ))
 
 
65
 
66
  # Sort by score descending
67
  return sorted(product_scores, key=lambda p: p.score, reverse=True)
@@ -112,12 +114,14 @@ def apply_weighted_ranking(
112
  # Create new ProductScore objects with updated scores
113
  reranked = []
114
  for i, product in enumerate(products):
115
- reranked.append(ProductScore(
116
- product_id=product.product_id,
117
- score=float(final_scores[i]),
118
- chunk_count=product.chunk_count,
119
- avg_rating=product.avg_rating,
120
- evidence=product.evidence,
121
- ))
 
 
122
 
123
  return sorted(reranked, key=lambda p: p.score, reverse=True)
 
55
  else:
56
  agg_score = max(scores)
57
 
58
+ product_scores.append(
59
+ ProductScore(
60
+ product_id=product_id,
61
+ score=agg_score,
62
+ chunk_count=len(prod_chunks),
63
+ avg_rating=float(np.mean(ratings)),
64
+ evidence=sorted(prod_chunks, key=lambda c: c.score, reverse=True),
65
+ )
66
+ )
67
 
68
  # Sort by score descending
69
  return sorted(product_scores, key=lambda p: p.score, reverse=True)
 
114
  # Create new ProductScore objects with updated scores
115
  reranked = []
116
  for i, product in enumerate(products):
117
+ reranked.append(
118
+ ProductScore(
119
+ product_id=product.product_id,
120
+ score=float(final_scores[i]),
121
+ chunk_count=product.chunk_count,
122
+ avg_rating=product.avg_rating,
123
+ evidence=product.evidence,
124
+ )
125
+ )
126
 
127
  return sorted(reranked, key=lambda p: p.score, reverse=True)
sage/core/chunking.py CHANGED
@@ -19,16 +19,16 @@ from sage.config import CHARS_PER_TOKEN
19
 
20
 
21
  # Chunking thresholds (tokens)
22
- NO_CHUNK_THRESHOLD = 200 # Texts under this: no chunking
23
- SEMANTIC_THRESHOLD = 500 # Texts under this: semantic only
24
- MAX_CHUNK_TOKENS = 400 # Chunks larger than this get sliding window
25
 
26
  # Semantic chunking config
27
- SIMILARITY_PERCENTILE = 85 # Split at drops below this percentile
28
 
29
  # Sliding window config (fallback)
30
- SLIDING_CHUNK_SIZE = 150 # Target tokens per sliding window chunk
31
- SLIDING_OVERLAP = 30 # Token overlap between chunks
32
 
33
 
34
  def estimate_tokens(text: str) -> int:
 
19
 
20
 
21
  # Chunking thresholds (tokens)
22
+ NO_CHUNK_THRESHOLD = 200 # Texts under this: no chunking
23
+ SEMANTIC_THRESHOLD = 500 # Texts under this: semantic only
24
+ MAX_CHUNK_TOKENS = 400 # Chunks larger than this get sliding window
25
 
26
  # Semantic chunking config
27
+ SIMILARITY_PERCENTILE = 85 # Split at drops below this percentile
28
 
29
  # Sliding window config (fallback)
30
+ SLIDING_CHUNK_SIZE = 150 # Target tokens per sliding window chunk
31
+ SLIDING_OVERLAP = 30 # Token overlap between chunks
32
 
33
 
34
  def estimate_tokens(text: str) -> int:
sage/core/evidence.py CHANGED
@@ -101,8 +101,14 @@ def check_evidence_quality(
101
 
102
  # Check thresholds using table-driven validation
103
  thresholds = [
104
- (chunk_count < min_chunks, f"insufficient_chunks: {chunk_count} < {min_chunks}"),
105
- (total_tokens < min_tokens, f"insufficient_tokens: {total_tokens} < {min_tokens}"),
 
 
 
 
 
 
106
  (top_score < min_score, f"low_relevance: {top_score:.3f} < {min_score}"),
107
  ]
108
 
@@ -150,17 +156,17 @@ def generate_refusal_message(
150
  f"I cannot provide a confident recommendation for this product based on "
151
  f"the available review evidence. Only {quality.chunk_count} review excerpt(s) "
152
  f"were found, which is insufficient to make a well-grounded recommendation "
153
- f"for your query about \"{query}\"."
154
  )
155
  elif "insufficient_tokens" in reason:
156
  return (
157
  f"I cannot provide a meaningful recommendation for this product. "
158
  f"The available review evidence is too brief ({quality.total_tokens} tokens) "
159
- f"to support a well-grounded explanation for your query about \"{query}\"."
160
  )
161
  elif "low_relevance" in reason:
162
  return (
163
- f"I cannot recommend this product for your query about \"{query}\" because "
164
  f"the available reviews do not appear to be sufficiently relevant "
165
  f"(relevance score: {quality.top_score:.2f}). The reviews may discuss "
166
  f"different aspects or product features than what you're looking for."
@@ -168,5 +174,5 @@ def generate_refusal_message(
168
  else:
169
  return (
170
  f"I cannot provide a recommendation for this product due to "
171
- f"insufficient review evidence for your query about \"{query}\"."
172
  )
 
101
 
102
  # Check thresholds using table-driven validation
103
  thresholds = [
104
+ (
105
+ chunk_count < min_chunks,
106
+ f"insufficient_chunks: {chunk_count} < {min_chunks}",
107
+ ),
108
+ (
109
+ total_tokens < min_tokens,
110
+ f"insufficient_tokens: {total_tokens} < {min_tokens}",
111
+ ),
112
  (top_score < min_score, f"low_relevance: {top_score:.3f} < {min_score}"),
113
  ]
114
 
 
156
  f"I cannot provide a confident recommendation for this product based on "
157
  f"the available review evidence. Only {quality.chunk_count} review excerpt(s) "
158
  f"were found, which is insufficient to make a well-grounded recommendation "
159
+ f'for your query about "{query}".'
160
  )
161
  elif "insufficient_tokens" in reason:
162
  return (
163
  f"I cannot provide a meaningful recommendation for this product. "
164
  f"The available review evidence is too brief ({quality.total_tokens} tokens) "
165
+ f'to support a well-grounded explanation for your query about "{query}".'
166
  )
167
  elif "low_relevance" in reason:
168
  return (
169
+ f'I cannot recommend this product for your query about "{query}" because '
170
  f"the available reviews do not appear to be sufficiently relevant "
171
  f"(relevance score: {quality.top_score:.2f}). The reviews may discuss "
172
  f"different aspects or product features than what you're looking for."
 
174
  else:
175
  return (
176
  f"I cannot provide a recommendation for this product due to "
177
+ f'insufficient review evidence for your query about "{query}".'
178
  )
sage/core/models.py CHANGED
@@ -20,8 +20,10 @@ from typing import Iterator
20
  # RETRIEVAL & RECOMMENDATION MODELS
21
  # ============================================================================
22
 
 
23
  class AggregationMethod(Enum):
24
  """Methods for aggregating chunk scores to product scores."""
 
25
  MAX = "max"
26
  MEAN = "mean"
27
  WEIGHTED_MEAN = "weighted_mean"
@@ -35,6 +37,7 @@ class Chunk:
35
  This is the unit stored in the vector database. Reviews are chunked
36
  using semantic or sliding-window strategies based on length.
37
  """
 
38
  text: str
39
  chunk_index: int
40
  total_chunks: int
@@ -52,6 +55,7 @@ class RetrievedChunk:
52
  This is returned by semantic search and used as evidence for
53
  explanation generation.
54
  """
 
55
  text: str
56
  score: float
57
  product_id: str
@@ -67,6 +71,7 @@ class ProductScore:
67
  Multiple chunks may belong to the same product. This dataclass
68
  holds the aggregated score and all supporting evidence.
69
  """
 
70
  product_id: str
71
  score: float
72
  chunk_count: int
@@ -89,6 +94,7 @@ class Recommendation:
89
  This is the output of the recommendation pipeline, ready for
90
  display or API response.
91
  """
 
92
  rank: int
93
  product_id: str
94
  score: float
@@ -102,6 +108,7 @@ class Recommendation:
102
  # COLD START MODELS
103
  # ============================================================================
104
 
 
105
  @dataclass
106
  class UserPreferences:
107
  """
@@ -110,6 +117,7 @@ class UserPreferences:
110
  In production, these would be collected via an onboarding flow:
111
  "What categories interest you?" "What's your budget?" etc.
112
  """
 
113
  categories: list[str] | None = None
114
  budget: str | None = None # "low", "medium", "high", or specific like "$50-100"
115
  priorities: list[str] | None = None # ["quality", "value", "durability"]
@@ -123,6 +131,7 @@ class NewItem:
123
 
124
  In production, this would come from the product catalog.
125
  """
 
126
  product_id: str
127
  title: str
128
  description: str | None = None
@@ -136,6 +145,7 @@ class NewItem:
136
  # EXPLANATION MODELS
137
  # ============================================================================
138
 
 
139
  @dataclass
140
  class ExplanationResult:
141
  """
@@ -144,6 +154,7 @@ class ExplanationResult:
144
  Contains the generated explanation along with evidence attribution
145
  for traceability and faithfulness verification.
146
  """
 
147
  explanation: str
148
  product_id: str
149
  query: str
@@ -174,6 +185,7 @@ class StreamingExplanation:
174
  print(token, end="", flush=True)
175
  result = stream.get_complete_result()
176
  """
 
177
  token_iterator: Iterator[str]
178
  product_id: str
179
  query: str
@@ -215,6 +227,7 @@ class EvidenceQuality:
215
  when evidence is too thin. Thin evidence (1 chunk, few tokens)
216
  correlates strongly with LLM overclaiming.
217
  """
 
218
  is_sufficient: bool
219
  chunk_count: int
220
  total_tokens: int
@@ -226,9 +239,11 @@ class EvidenceQuality:
226
  # VERIFICATION MODELS
227
  # ============================================================================
228
 
 
229
  @dataclass
230
  class QuoteVerification:
231
  """Result of verifying a single quoted claim against evidence."""
 
232
  quote: str
233
  found: bool
234
  source_text: str | None = None # Which evidence text contained it
@@ -243,6 +258,7 @@ class VerificationResult:
243
  exists in the provided evidence. Catches wrong attribution where
244
  LLM cites quotes that don't exist.
245
  """
 
246
  all_verified: bool
247
  quotes_found: int
248
  quotes_missing: int
@@ -254,6 +270,7 @@ class VerificationResult:
254
  # HALLUCINATION DETECTION MODELS
255
  # ============================================================================
256
 
 
257
  @dataclass
258
  class HallucinationResult:
259
  """
@@ -263,6 +280,7 @@ class HallucinationResult:
263
  consistency between evidence (premise) and explanation (hypothesis).
264
  Score < 0.5 indicates hallucination.
265
  """
 
266
  score: float
267
  is_hallucinated: bool
268
  threshold: float
@@ -273,6 +291,7 @@ class HallucinationResult:
273
  @dataclass
274
  class ClaimResult:
275
  """Result of hallucination check for a single claim."""
 
276
  claim: str
277
  score: float
278
  is_hallucinated: bool
@@ -286,6 +305,7 @@ class AgreementReport:
286
  Useful for understanding when the two methods disagree and
287
  calibrating thresholds.
288
  """
 
289
  n_samples: int
290
  agreement_rate: float # Proportion where both agree on pass/fail
291
  hhem_pass_rate: float
@@ -306,6 +326,7 @@ class AdjustedFaithfulnessReport:
306
  Refusals (e.g., "I cannot recommend...") are correct LLM behavior
307
  but get penalized by HHEM. This report adjusts for that.
308
  """
 
309
  n_total: int
310
  n_refusals: int
311
  n_evaluated: int # n_total - n_refusals
@@ -334,6 +355,7 @@ class ClaimLevelReport:
334
  - min_score: Lowest scoring claim (weakest grounding)
335
  - pass_rate: Proportion of claims scoring >= threshold
336
  """
 
337
  n_explanations: int
338
  n_claims: int
339
  avg_score: float
@@ -368,6 +390,7 @@ class MultiMetricFaithfulnessReport:
368
  individually. Full-explanation HHEM (57%) measures structural patterns
369
  that HHEM was trained on, not actual hallucination."
370
  """
 
371
  n_samples: int
372
  # Quote verification (lexical)
373
  quote_verification_rate: float
@@ -410,19 +433,19 @@ class MultiMetricFaithfulnessReport:
410
  "=" * 50,
411
  "",
412
  "Quote Verification (lexical grounding):",
413
- f" Pass rate: {self.quote_verification_rate*100:.1f}% ({self.quotes_found}/{self.quotes_total})",
414
  "",
415
  "Claim-Level HHEM (semantic grounding per claim):",
416
- f" Pass rate: {self.claim_level_pass_rate*100:.1f}%",
417
  f" Avg score: {self.claim_level_avg_score:.3f}",
418
  f" Min score: {self.claim_level_min_score:.3f}",
419
  "",
420
  "Full-Explanation HHEM (structural compatibility):",
421
- f" Pass rate: {self.full_explanation_pass_rate*100:.1f}%",
422
  f" Avg score: {self.full_explanation_avg_score:.3f}",
423
  "",
424
  "-" * 50,
425
- f"PRIMARY METRIC ({self.primary_metric}): {self.claim_level_pass_rate*100:.1f}%",
426
  ]
427
  return "\n".join(lines)
428
 
@@ -431,9 +454,11 @@ class MultiMetricFaithfulnessReport:
431
  # FAITHFULNESS EVALUATION MODELS (RAGAS)
432
  # ============================================================================
433
 
 
434
  @dataclass
435
  class FaithfulnessResult:
436
  """Result of RAGAS faithfulness evaluation for a single explanation."""
 
437
  score: float
438
  query: str
439
  explanation: str
@@ -444,6 +469,7 @@ class FaithfulnessResult:
444
  @dataclass
445
  class FaithfulnessReport:
446
  """Aggregate report for batch faithfulness evaluation (legacy format)."""
 
447
  mean_score: float
448
  min_score: float
449
  max_score: float
@@ -459,6 +485,7 @@ class FaithfulnessReport:
459
  # EVALUATION METRICS MODELS
460
  # ============================================================================
461
 
 
462
  @dataclass
463
  class EvalCase:
464
  """
@@ -470,6 +497,7 @@ class EvalCase:
470
  For binary relevance, use 1 for relevant, 0 for not.
471
  user_id: Optional user identifier for the case.
472
  """
 
473
  query: str
474
  relevant_items: dict[str, float]
475
  user_id: str | None = None
@@ -483,6 +511,7 @@ class EvalCase:
483
  @dataclass
484
  class EvalResult:
485
  """Results from evaluating a single recommendation list."""
 
486
  ndcg: float = 0.0
487
  hit: float = 0.0
488
  mrr: float = 0.0
@@ -498,6 +527,7 @@ class MetricsReport:
498
  Includes both accuracy metrics (NDCG, Hit, MRR) and
499
  beyond-accuracy metrics (diversity, coverage, novelty).
500
  """
 
501
  n_cases: int = 0
502
  ndcg_at_k: float = 0.0
503
  hit_at_k: float = 0.0
 
20
  # RETRIEVAL & RECOMMENDATION MODELS
21
  # ============================================================================
22
 
23
+
24
  class AggregationMethod(Enum):
25
  """Methods for aggregating chunk scores to product scores."""
26
+
27
  MAX = "max"
28
  MEAN = "mean"
29
  WEIGHTED_MEAN = "weighted_mean"
 
37
  This is the unit stored in the vector database. Reviews are chunked
38
  using semantic or sliding-window strategies based on length.
39
  """
40
+
41
  text: str
42
  chunk_index: int
43
  total_chunks: int
 
55
  This is returned by semantic search and used as evidence for
56
  explanation generation.
57
  """
58
+
59
  text: str
60
  score: float
61
  product_id: str
 
71
  Multiple chunks may belong to the same product. This dataclass
72
  holds the aggregated score and all supporting evidence.
73
  """
74
+
75
  product_id: str
76
  score: float
77
  chunk_count: int
 
94
  This is the output of the recommendation pipeline, ready for
95
  display or API response.
96
  """
97
+
98
  rank: int
99
  product_id: str
100
  score: float
 
108
  # COLD START MODELS
109
  # ============================================================================
110
 
111
+
112
  @dataclass
113
  class UserPreferences:
114
  """
 
117
  In production, these would be collected via an onboarding flow:
118
  "What categories interest you?" "What's your budget?" etc.
119
  """
120
+
121
  categories: list[str] | None = None
122
  budget: str | None = None # "low", "medium", "high", or specific like "$50-100"
123
  priorities: list[str] | None = None # ["quality", "value", "durability"]
 
131
 
132
  In production, this would come from the product catalog.
133
  """
134
+
135
  product_id: str
136
  title: str
137
  description: str | None = None
 
145
  # EXPLANATION MODELS
146
  # ============================================================================
147
 
148
+
149
  @dataclass
150
  class ExplanationResult:
151
  """
 
154
  Contains the generated explanation along with evidence attribution
155
  for traceability and faithfulness verification.
156
  """
157
+
158
  explanation: str
159
  product_id: str
160
  query: str
 
185
  print(token, end="", flush=True)
186
  result = stream.get_complete_result()
187
  """
188
+
189
  token_iterator: Iterator[str]
190
  product_id: str
191
  query: str
 
227
  when evidence is too thin. Thin evidence (1 chunk, few tokens)
228
  correlates strongly with LLM overclaiming.
229
  """
230
+
231
  is_sufficient: bool
232
  chunk_count: int
233
  total_tokens: int
 
239
  # VERIFICATION MODELS
240
  # ============================================================================
241
 
242
+
243
  @dataclass
244
  class QuoteVerification:
245
  """Result of verifying a single quoted claim against evidence."""
246
+
247
  quote: str
248
  found: bool
249
  source_text: str | None = None # Which evidence text contained it
 
258
  exists in the provided evidence. Catches wrong attribution where
259
  LLM cites quotes that don't exist.
260
  """
261
+
262
  all_verified: bool
263
  quotes_found: int
264
  quotes_missing: int
 
270
  # HALLUCINATION DETECTION MODELS
271
  # ============================================================================
272
 
273
+
274
  @dataclass
275
  class HallucinationResult:
276
  """
 
280
  consistency between evidence (premise) and explanation (hypothesis).
281
  Score < 0.5 indicates hallucination.
282
  """
283
+
284
  score: float
285
  is_hallucinated: bool
286
  threshold: float
 
291
  @dataclass
292
  class ClaimResult:
293
  """Result of hallucination check for a single claim."""
294
+
295
  claim: str
296
  score: float
297
  is_hallucinated: bool
 
305
  Useful for understanding when the two methods disagree and
306
  calibrating thresholds.
307
  """
308
+
309
  n_samples: int
310
  agreement_rate: float # Proportion where both agree on pass/fail
311
  hhem_pass_rate: float
 
326
  Refusals (e.g., "I cannot recommend...") are correct LLM behavior
327
  but get penalized by HHEM. This report adjusts for that.
328
  """
329
+
330
  n_total: int
331
  n_refusals: int
332
  n_evaluated: int # n_total - n_refusals
 
355
  - min_score: Lowest scoring claim (weakest grounding)
356
  - pass_rate: Proportion of claims scoring >= threshold
357
  """
358
+
359
  n_explanations: int
360
  n_claims: int
361
  avg_score: float
 
390
  individually. Full-explanation HHEM (57%) measures structural patterns
391
  that HHEM was trained on, not actual hallucination."
392
  """
393
+
394
  n_samples: int
395
  # Quote verification (lexical)
396
  quote_verification_rate: float
 
433
  "=" * 50,
434
  "",
435
  "Quote Verification (lexical grounding):",
436
+ f" Pass rate: {self.quote_verification_rate * 100:.1f}% ({self.quotes_found}/{self.quotes_total})",
437
  "",
438
  "Claim-Level HHEM (semantic grounding per claim):",
439
+ f" Pass rate: {self.claim_level_pass_rate * 100:.1f}%",
440
  f" Avg score: {self.claim_level_avg_score:.3f}",
441
  f" Min score: {self.claim_level_min_score:.3f}",
442
  "",
443
  "Full-Explanation HHEM (structural compatibility):",
444
+ f" Pass rate: {self.full_explanation_pass_rate * 100:.1f}%",
445
  f" Avg score: {self.full_explanation_avg_score:.3f}",
446
  "",
447
  "-" * 50,
448
+ f"PRIMARY METRIC ({self.primary_metric}): {self.claim_level_pass_rate * 100:.1f}%",
449
  ]
450
  return "\n".join(lines)
451
 
 
454
  # FAITHFULNESS EVALUATION MODELS (RAGAS)
455
  # ============================================================================
456
 
457
+
458
  @dataclass
459
  class FaithfulnessResult:
460
  """Result of RAGAS faithfulness evaluation for a single explanation."""
461
+
462
  score: float
463
  query: str
464
  explanation: str
 
469
  @dataclass
470
  class FaithfulnessReport:
471
  """Aggregate report for batch faithfulness evaluation (legacy format)."""
472
+
473
  mean_score: float
474
  min_score: float
475
  max_score: float
 
485
  # EVALUATION METRICS MODELS
486
  # ============================================================================
487
 
488
+
489
  @dataclass
490
  class EvalCase:
491
  """
 
497
  For binary relevance, use 1 for relevant, 0 for not.
498
  user_id: Optional user identifier for the case.
499
  """
500
+
501
  query: str
502
  relevant_items: dict[str, float]
503
  user_id: str | None = None
 
511
  @dataclass
512
  class EvalResult:
513
  """Results from evaluating a single recommendation list."""
514
+
515
  ndcg: float = 0.0
516
  hit: float = 0.0
517
  mrr: float = 0.0
 
527
  Includes both accuracy metrics (NDCG, Hit, MRR) and
528
  beyond-accuracy metrics (diversity, coverage, novelty).
529
  """
530
+
531
  n_cases: int = 0
532
  ndcg_at_k: float = 0.0
533
  hit_at_k: float = 0.0
sage/core/prompts.py CHANGED
@@ -76,7 +76,7 @@ def format_evidence(
76
  return "(No review evidence available)"
77
 
78
  return "\n\n".join(
79
- f"[{chunk.review_id}] ({int(chunk.rating or 0)}/5 stars): \"{chunk.text}\""
80
  for chunk in chunks[:max_chunks]
81
  )
82
 
 
76
  return "(No review evidence available)"
77
 
78
  return "\n\n".join(
79
+ f'[{chunk.review_id}] ({int(chunk.rating or 0)}/5 stars): "{chunk.text}"'
80
  for chunk in chunks[:max_chunks]
81
  )
82
 
sage/core/verification.py CHANGED
@@ -86,9 +86,9 @@ def extract_quotes(text: str, min_length: int = 4) -> list[str]:
86
  List of unique quoted strings found in the text.
87
  """
88
  patterns = [
89
- r'"([^"]+)"', # Regular double quotes
90
- r'"([^"]+)"', # Curly double quotes
91
- r"'([^']+)'", # Single quotes
92
  ]
93
 
94
  quotes = []
@@ -206,6 +206,7 @@ def verify_explanation(
206
  # Citation ID Verification
207
  # =============================================================================
208
 
 
209
  @dataclass
210
  class CitationResult:
211
  """Result of verifying a single citation."""
@@ -256,12 +257,12 @@ def extract_citations(text: str) -> list[tuple[str, str | None]]:
256
  quote_text = match.group(1)
257
  citation_block = match.group(2)
258
  # Split multiple citations like "review_123, review_456"
259
- for citation_id in re.findall(r'review_\d+', citation_block):
260
  citations.append((citation_id, quote_text))
261
 
262
  # Pattern for standalone citations not preceded by a quote
263
  # Find all citations, then filter out ones already captured with quotes
264
- all_citation_ids = set(re.findall(r'review_\d+', text))
265
  quoted_citation_ids = {c[0] for c in citations}
266
  standalone_ids = all_citation_ids - quoted_citation_ids
267
 
@@ -294,9 +295,7 @@ def verify_citation(
294
  """
295
  # Collect all chunks belonging to this citation ID (a single review
296
  # may produce multiple chunks after splitting long reviews).
297
- matching_indices = [
298
- i for i, eid in enumerate(evidence_ids) if eid == citation_id
299
- ]
300
 
301
  if not matching_indices:
302
  return CitationResult(
 
86
  List of unique quoted strings found in the text.
87
  """
88
  patterns = [
89
+ r'"([^"]+)"', # Regular double quotes
90
+ r'"([^"]+)"', # Curly double quotes
91
+ r"'([^']+)'", # Single quotes
92
  ]
93
 
94
  quotes = []
 
206
  # Citation ID Verification
207
  # =============================================================================
208
 
209
+
210
  @dataclass
211
  class CitationResult:
212
  """Result of verifying a single citation."""
 
257
  quote_text = match.group(1)
258
  citation_block = match.group(2)
259
  # Split multiple citations like "review_123, review_456"
260
+ for citation_id in re.findall(r"review_\d+", citation_block):
261
  citations.append((citation_id, quote_text))
262
 
263
  # Pattern for standalone citations not preceded by a quote
264
  # Find all citations, then filter out ones already captured with quotes
265
+ all_citation_ids = set(re.findall(r"review_\d+", text))
266
  quoted_citation_ids = {c[0] for c in citations}
267
  standalone_ids = all_citation_ids - quoted_citation_ids
268
 
 
295
  """
296
  # Collect all chunks belonging to this citation ID (a single review
297
  # may produce multiple chunks after splitting long reviews).
298
+ matching_indices = [i for i, eid in enumerate(evidence_ids) if eid == citation_id]
 
 
299
 
300
  if not matching_indices:
301
  return CitationResult(
sage/services/__init__.py CHANGED
@@ -59,6 +59,7 @@ _LAZY_IMPORTS = {
59
  def __getattr__(name: str):
60
  if name in _LAZY_IMPORTS:
61
  import importlib
 
62
  module = importlib.import_module(_LAZY_IMPORTS[name])
63
  return getattr(module, name)
64
  raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
 
59
  def __getattr__(name: str):
60
  if name in _LAZY_IMPORTS:
61
  import importlib
62
+
63
  module = importlib.import_module(_LAZY_IMPORTS[name])
64
  return getattr(module, name)
65
  raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
sage/services/baselines.py CHANGED
@@ -73,9 +73,7 @@ class PopularityBaseline:
73
  counts = Counter(i[item_key] for i in interactions if item_key in i)
74
 
75
  # Sort by popularity (descending)
76
- self.ranked_items = [
77
- item for item, _ in counts.most_common()
78
- ]
79
 
80
  self.popularity = counts
81
 
@@ -119,9 +117,9 @@ class ItemKNNBaseline:
119
  embedder: E5Embedder instance for query embedding.
120
  """
121
  self.product_ids = list(product_embeddings.keys())
122
- self.embeddings = np.array([
123
- product_embeddings[pid] for pid in self.product_ids
124
- ])
125
 
126
  # Normalize embeddings for cosine similarity
127
  norms = np.linalg.norm(self.embeddings, axis=1, keepdims=True)
@@ -143,6 +141,7 @@ class ItemKNNBaseline:
143
  """
144
  if self.embedder is None:
145
  from sage.adapters.embeddings import get_embedder
 
146
  self.embedder = get_embedder()
147
 
148
  # Embed query
@@ -188,7 +187,9 @@ def build_product_embeddings(
188
  elif aggregation == "max":
189
  agg_emb = product_embs.max(axis=0)
190
  else:
191
- raise ValueError(f"Unknown aggregation method: {aggregation}. Use 'mean' or 'max'.")
 
 
192
 
193
  # Normalize
194
  agg_emb = agg_emb / (np.linalg.norm(agg_emb) + 1e-8)
 
73
  counts = Counter(i[item_key] for i in interactions if item_key in i)
74
 
75
  # Sort by popularity (descending)
76
+ self.ranked_items = [item for item, _ in counts.most_common()]
 
 
77
 
78
  self.popularity = counts
79
 
 
117
  embedder: E5Embedder instance for query embedding.
118
  """
119
  self.product_ids = list(product_embeddings.keys())
120
+ self.embeddings = np.array(
121
+ [product_embeddings[pid] for pid in self.product_ids]
122
+ )
123
 
124
  # Normalize embeddings for cosine similarity
125
  norms = np.linalg.norm(self.embeddings, axis=1, keepdims=True)
 
141
  """
142
  if self.embedder is None:
143
  from sage.adapters.embeddings import get_embedder
144
+
145
  self.embedder = get_embedder()
146
 
147
  # Embed query
 
187
  elif aggregation == "max":
188
  agg_emb = product_embs.max(axis=0)
189
  else:
190
+ raise ValueError(
191
+ f"Unknown aggregation method: {aggregation}. Use 'mean' or 'max'."
192
+ )
193
 
194
  # Normalize
195
  agg_emb = agg_emb / (np.linalg.norm(agg_emb) + 1e-8)
sage/services/cache.py CHANGED
@@ -27,6 +27,7 @@ logger = get_logger(__name__)
27
  # Cache entry
28
  # ---------------------------------------------------------------------------
29
 
 
30
  @dataclass
31
  class _CacheEntry:
32
  """Single cached result with metadata for eviction."""
@@ -43,6 +44,7 @@ class _CacheEntry:
43
  # Cache stats
44
  # ---------------------------------------------------------------------------
45
 
 
46
  @dataclass
47
  class CacheStats:
48
  """Snapshot of cache performance metrics."""
@@ -73,6 +75,7 @@ class CacheStats:
73
  # Semantic cache
74
  # ---------------------------------------------------------------------------
75
 
 
76
  class SemanticCache:
77
  """Thread-safe in-memory cache with exact-match and semantic-similarity layers.
78
 
@@ -116,7 +119,9 @@ class SemanticCache:
116
  # ------------------------------------------------------------------
117
 
118
  def get(
119
- self, query: str, query_embedding: np.ndarray | None = None,
 
 
120
  ) -> tuple[dict | None, str]:
121
  """Look up a cached result.
122
 
@@ -232,7 +237,8 @@ class SemanticCache:
232
  # ------------------------------------------------------------------
233
 
234
  def _find_semantic_match(
235
- self, query_embedding: np.ndarray,
 
236
  ) -> tuple[_CacheEntry, float]:
237
  """Find the best semantic match among cached entries.
238
 
 
27
  # Cache entry
28
  # ---------------------------------------------------------------------------
29
 
30
+
31
  @dataclass
32
  class _CacheEntry:
33
  """Single cached result with metadata for eviction."""
 
44
  # Cache stats
45
  # ---------------------------------------------------------------------------
46
 
47
+
48
  @dataclass
49
  class CacheStats:
50
  """Snapshot of cache performance metrics."""
 
75
  # Semantic cache
76
  # ---------------------------------------------------------------------------
77
 
78
+
79
  class SemanticCache:
80
  """Thread-safe in-memory cache with exact-match and semantic-similarity layers.
81
 
 
119
  # ------------------------------------------------------------------
120
 
121
  def get(
122
+ self,
123
+ query: str,
124
+ query_embedding: np.ndarray | None = None,
125
  ) -> tuple[dict | None, str]:
126
  """Look up a cached result.
127
 
 
237
  # ------------------------------------------------------------------
238
 
239
  def _find_semantic_match(
240
+ self,
241
+ query_embedding: np.ndarray,
242
  ) -> tuple[_CacheEntry, float]:
243
  """Find the best semantic match among cached entries.
244
 
sage/services/cold_start.py CHANGED
@@ -102,6 +102,7 @@ class ColdStartService:
102
  """Lazy-load retrieval service."""
103
  if self._retrieval is None:
104
  from sage.services.retrieval import RetrievalService
 
105
  self._retrieval = RetrievalService(collection_name=self.collection_name)
106
  return self._retrieval
107
 
@@ -179,7 +180,9 @@ class ColdStartService:
179
  item_text = " ".join(text_parts)
180
 
181
  # Embed as a passage
182
- item_embedding = self.embedder.embed_passages([item_text], show_progress=False)[0]
 
 
183
 
184
  # Search for similar chunks
185
  results = search(
 
102
  """Lazy-load retrieval service."""
103
  if self._retrieval is None:
104
  from sage.services.retrieval import RetrievalService
105
+
106
  self._retrieval = RetrievalService(collection_name=self.collection_name)
107
  return self._retrieval
108
 
 
180
  item_text = " ".join(text_parts)
181
 
182
  # Embed as a passage
183
+ item_embedding = self.embedder.embed_passages([item_text], show_progress=False)[
184
+ 0
185
+ ]
186
 
187
  # Search for similar chunks
188
  results = search(
sage/services/evaluation.py CHANGED
@@ -263,7 +263,9 @@ class EvaluationService:
263
  ndcg_at_k=float(np.mean(ndcg_scores)) if ndcg_scores else 0.0,
264
  hit_at_k=float(np.mean(hit_scores)) if hit_scores else 0.0,
265
  mrr=float(np.mean(mrr_scores)) if mrr_scores else 0.0,
266
- precision_at_k=float(np.mean(precision_scores)) if precision_scores else 0.0,
 
 
267
  recall_at_k=float(np.mean(recall_scores)) if recall_scores else 0.0,
268
  diversity=float(np.mean(diversity_scores)) if diversity_scores else 0.0,
269
  novelty=float(np.mean(novelty_scores)) if novelty_scores else 0.0,
 
263
  ndcg_at_k=float(np.mean(ndcg_scores)) if ndcg_scores else 0.0,
264
  hit_at_k=float(np.mean(hit_scores)) if hit_scores else 0.0,
265
  mrr=float(np.mean(mrr_scores)) if mrr_scores else 0.0,
266
+ precision_at_k=float(np.mean(precision_scores))
267
+ if precision_scores
268
+ else 0.0,
269
  recall_at_k=float(np.mean(recall_scores)) if recall_scores else 0.0,
270
  diversity=float(np.mean(diversity_scores)) if diversity_scores else 0.0,
271
  novelty=float(np.mean(novelty_scores)) if novelty_scores else 0.0,
sage/services/explanation.py CHANGED
@@ -109,8 +109,8 @@ class Explainer:
109
  Returns:
110
  (explanation, tokens, evidence_texts, evidence_ids, user_prompt).
111
  """
112
- system_prompt, user_prompt, evidence_texts, evidence_ids = build_explanation_prompt(
113
- query, product, max_evidence
114
  )
115
 
116
  t0 = time.perf_counter()
@@ -120,7 +120,9 @@ class Explainer:
120
  )
121
  logger.info(
122
  "LLM generation for %s: %.0fms, %d tokens",
123
- product.product_id, (time.perf_counter() - t0) * 1000, tokens,
 
 
124
  )
125
 
126
  return explanation, tokens, evidence_texts, evidence_ids, user_prompt
@@ -241,8 +243,8 @@ class Explainer:
241
  f"Client {type(self.client).__name__} does not support streaming."
242
  )
243
 
244
- system_prompt, user_prompt, evidence_texts, evidence_ids = build_explanation_prompt(
245
- query, product, max_evidence
246
  )
247
 
248
  token_iterator = self.client.generate_stream(
@@ -335,7 +337,11 @@ class Explainer:
335
  """
336
  return [
337
  self.generate_explanation(
338
- query, product, max_evidence, enforce_quality_gate, enforce_forbidden_phrases
 
 
 
 
339
  )
340
  for product in products
341
  ]
 
109
  Returns:
110
  (explanation, tokens, evidence_texts, evidence_ids, user_prompt).
111
  """
112
+ system_prompt, user_prompt, evidence_texts, evidence_ids = (
113
+ build_explanation_prompt(query, product, max_evidence)
114
  )
115
 
116
  t0 = time.perf_counter()
 
120
  )
121
  logger.info(
122
  "LLM generation for %s: %.0fms, %d tokens",
123
+ product.product_id,
124
+ (time.perf_counter() - t0) * 1000,
125
+ tokens,
126
  )
127
 
128
  return explanation, tokens, evidence_texts, evidence_ids, user_prompt
 
243
  f"Client {type(self.client).__name__} does not support streaming."
244
  )
245
 
246
+ system_prompt, user_prompt, evidence_texts, evidence_ids = (
247
+ build_explanation_prompt(query, product, max_evidence)
248
  )
249
 
250
  token_iterator = self.client.generate_stream(
 
337
  """
338
  return [
339
  self.generate_explanation(
340
+ query,
341
+ product,
342
+ max_evidence,
343
+ enforce_quality_gate,
344
+ enforce_forbidden_phrases,
345
  )
346
  for product in products
347
  ]
sage/services/faithfulness.py CHANGED
@@ -73,7 +73,9 @@ def create_ragas_sample(query: str, explanation: str, evidence_texts: list[str])
73
  )
74
 
75
 
76
- def _explanation_results_to_samples(explanation_results: list[ExplanationResult]) -> list:
 
 
77
  """Convert ExplanationResults to RAGAS samples."""
78
  return [
79
  create_ragas_sample(
@@ -239,6 +241,7 @@ class FaithfulnessEvaluator:
239
  results=individual_results,
240
  )
241
 
 
242
  def evaluate_faithfulness(
243
  explanation_results: list[ExplanationResult],
244
  provider: str | None = None,
@@ -396,7 +399,8 @@ def compute_adjusted_faithfulness(
396
  # - Valid non-recommendations count as passes (correct behavior)
397
  # - Regular recommendations evaluated by HHEM
398
  regular_passes = sum(
399
- 1 for r, is_non_rec in zip(results, valid_non_recs)
 
400
  if not is_non_rec and not r.is_hallucinated
401
  )
402
  adjusted_passes = regular_passes + n_valid_non_recs
@@ -594,14 +598,13 @@ def compute_multi_metric_faithfulness(
594
  detector = get_detector()
595
 
596
  # 1. Full-explanation HHEM (structural)
597
- full_scores = [
598
- detector.check_explanation(ev, exp).score
599
- for ev, exp in items
600
- ]
601
 
602
  # 2. Claim-level HHEM
603
  claim_report = compute_claim_level_hhem(
604
- items, threshold, full_explanation_scores=full_scores,
 
 
605
  )
606
 
607
  # 3. Quote verification (lexical)
 
73
  )
74
 
75
 
76
+ def _explanation_results_to_samples(
77
+ explanation_results: list[ExplanationResult],
78
+ ) -> list:
79
  """Convert ExplanationResults to RAGAS samples."""
80
  return [
81
  create_ragas_sample(
 
241
  results=individual_results,
242
  )
243
 
244
+
245
  def evaluate_faithfulness(
246
  explanation_results: list[ExplanationResult],
247
  provider: str | None = None,
 
399
  # - Valid non-recommendations count as passes (correct behavior)
400
  # - Regular recommendations evaluated by HHEM
401
  regular_passes = sum(
402
+ 1
403
+ for r, is_non_rec in zip(results, valid_non_recs)
404
  if not is_non_rec and not r.is_hallucinated
405
  )
406
  adjusted_passes = regular_passes + n_valid_non_recs
 
598
  detector = get_detector()
599
 
600
  # 1. Full-explanation HHEM (structural)
601
+ full_scores = [detector.check_explanation(ev, exp).score for ev, exp in items]
 
 
 
602
 
603
  # 2. Claim-level HHEM
604
  claim_report = compute_claim_level_hhem(
605
+ items,
606
+ threshold,
607
+ full_explanation_scores=full_scores,
608
  )
609
 
610
  # 3. Quote verification (lexical)
sage/services/retrieval.py CHANGED
@@ -138,7 +138,11 @@ class RetrievalService:
138
  limit=limit,
139
  min_rating=min_rating,
140
  )
141
- logger.info("Qdrant search: %.0fms, %d results", (time.perf_counter() - t0) * 1000, len(results))
 
 
 
 
142
 
143
  chunks = []
144
  for r in results:
@@ -157,7 +161,9 @@ class RetrievalService:
157
  )
158
 
159
  product_ids = {c.product_id for c in chunks}
160
- logger.info("Retrieved %d chunks across %d products", len(chunks), len(product_ids))
 
 
161
 
162
  return chunks
163
 
@@ -247,7 +253,11 @@ def retrieve_chunks(
247
  """Retrieve relevant chunks from the vector store."""
248
  service = RetrievalService(client=client, embedder=embedder)
249
  return service.retrieve_chunks(
250
- query, limit, min_rating, exclude_products, query_embedding,
 
 
 
 
251
  )
252
 
253
 
@@ -347,7 +357,8 @@ def recommend_for_user(
347
 
348
  # Get products to exclude
349
  exclude: set[str] = {
350
- pid for r in user_history
 
351
  if (pid := r.get("product_id")) is not None and isinstance(pid, str)
352
  }
353
 
 
138
  limit=limit,
139
  min_rating=min_rating,
140
  )
141
+ logger.info(
142
+ "Qdrant search: %.0fms, %d results",
143
+ (time.perf_counter() - t0) * 1000,
144
+ len(results),
145
+ )
146
 
147
  chunks = []
148
  for r in results:
 
161
  )
162
 
163
  product_ids = {c.product_id for c in chunks}
164
+ logger.info(
165
+ "Retrieved %d chunks across %d products", len(chunks), len(product_ids)
166
+ )
167
 
168
  return chunks
169
 
 
253
  """Retrieve relevant chunks from the vector store."""
254
  service = RetrievalService(client=client, embedder=embedder)
255
  return service.retrieve_chunks(
256
+ query,
257
+ limit,
258
+ min_rating,
259
+ exclude_products,
260
+ query_embedding,
261
  )
262
 
263
 
 
357
 
358
  # Get products to exclude
359
  exclude: set[str] = {
360
+ pid
361
+ for r in user_history
362
  if (pid := r.get("product_id")) is not None and isinstance(pid, str)
363
  }
364
 
sage/utils.py CHANGED
@@ -20,6 +20,7 @@ def save_results(data: dict, prefix: str, directory: Path | None = None) -> Path
20
  """
21
  if directory is None:
22
  from sage.config import RESULTS_DIR
 
23
  directory = RESULTS_DIR
24
 
25
  directory.mkdir(parents=True, exist_ok=True)
 
20
  """
21
  if directory is None:
22
  from sage.config import RESULTS_DIR
23
+
24
  directory = RESULTS_DIR
25
 
26
  directory.mkdir(parents=True, exist_ok=True)
scripts/build_eval_dataset.py CHANGED
@@ -35,27 +35,199 @@ EVAL_DIR = DATA_DIR / "eval"
35
 
36
  # Common stopwords to filter out
37
  STOPWORDS = {
38
- "i", "me", "my", "myself", "we", "our", "ours", "ourselves", "you", "your",
39
- "yours", "yourself", "yourselves", "he", "him", "his", "himself", "she",
40
- "her", "hers", "herself", "it", "its", "itself", "they", "them", "their",
41
- "theirs", "themselves", "what", "which", "who", "whom", "this", "that",
42
- "these", "those", "am", "is", "are", "was", "were", "be", "been", "being",
43
- "have", "has", "had", "having", "do", "does", "did", "doing", "a", "an",
44
- "the", "and", "but", "if", "or", "because", "as", "until", "while", "of",
45
- "at", "by", "for", "with", "about", "against", "between", "into", "through",
46
- "during", "before", "after", "above", "below", "to", "from", "up", "down",
47
- "in", "out", "on", "off", "over", "under", "again", "further", "then",
48
- "once", "here", "there", "when", "where", "why", "how", "all", "each",
49
- "few", "more", "most", "other", "some", "such", "no", "nor", "not", "only",
50
- "own", "same", "so", "than", "too", "very", "s", "t", "can", "will", "just",
51
- "don", "should", "now", "d", "ll", "m", "o", "re", "ve", "y", "ain", "aren",
52
- "couldn", "didn", "doesn", "hadn", "hasn", "haven", "isn", "ma", "mightn",
53
- "mustn", "needn", "shan", "shouldn", "wasn", "weren", "won", "wouldn",
54
- "also", "would", "could", "get", "got", "one", "two", "really", "like",
55
- "just", "even", "well", "much", "still", "back", "way", "thing", "things",
56
- "make", "made", "work", "works", "worked", "use", "used", "using", "good",
57
- "great", "nice", "product", "item", "bought", "buy", "amazon", "review",
58
- "ordered", "order", "received", "came", "arrived", "shipping", "shipped",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  }
60
 
61
 
@@ -75,7 +247,7 @@ def extract_keywords(text: str, max_keywords: int = 8) -> list[str]:
75
  # Clean text
76
  text = text.lower()
77
  text = re.sub(r"<br\s*/?>", " ", text) # Remove HTML breaks
78
- text = re.sub(r"[^a-z\s]", " ", text) # Keep only letters
79
  text = re.sub(r"\s+", " ", text).strip()
80
 
81
  # Tokenize and filter
@@ -165,6 +337,7 @@ def generate_query_from_history(
165
  # Evaluation Dataset Construction
166
  # ---------------------------------------------------------------------------
167
 
 
168
  def build_leave_one_out_cases(
169
  df: pd.DataFrame,
170
  min_reviews: int = 2,
@@ -231,16 +404,21 @@ def build_leave_one_out_cases(
231
 
232
  # Only include if target has positive relevance
233
  if relevance > 0:
234
- eval_cases.append(EvalCase(
235
- query=query,
236
- relevant_items={target_product: relevance},
237
- user_id=user_id,
238
- ))
 
 
239
 
240
  if verbose:
241
  logger.info("Users with enough reviews: %d", len(user_groups) - skipped_users)
242
  logger.info("Eval cases created: %d", len(eval_cases))
243
- logger.info("Skipped (low relevance): %d", len(user_groups) - skipped_users - len(eval_cases))
 
 
 
244
 
245
  return eval_cases
246
 
@@ -310,16 +488,20 @@ def build_multi_relevant_cases(
310
  )
311
 
312
  if relevant_items:
313
- eval_cases.append(EvalCase(
314
- query=query,
315
- relevant_items=relevant_items,
316
- user_id=user_id,
317
- ))
 
 
318
 
319
  if verbose:
320
  logger.info("Users with train history: %d", len(train_users))
321
  logger.info("Eval cases created: %d", len(eval_cases))
322
- avg_relevant = np.mean([len(c.relevant_items) for c in eval_cases]) if eval_cases else 0
 
 
323
  logger.info("Avg relevant items per case: %.1f", avg_relevant)
324
 
325
  return eval_cases
@@ -400,7 +582,12 @@ if __name__ == "__main__":
400
  # Load splits
401
  log_section(logger, "Loading data splits")
402
  train_df, val_df, test_df = load_splits()
403
- logger.info("Train: %s | Val: %s | Test: %s", f"{len(train_df):,}", f"{len(val_df):,}", f"{len(test_df):,}")
 
 
 
 
 
404
 
405
  # Strategy 1: Leave-one-out with keyword queries
406
  # WARNING: This strategy has TARGET LEAKAGE - queries are generated from
@@ -418,8 +605,12 @@ if __name__ == "__main__":
418
  # Show examples
419
  logger.info("Sample queries:")
420
  for case in loo_keyword_cases[:5]:
421
- logger.info(" Query: \"%s\"", case.query)
422
- logger.info(" Target: %s (rel=%s)", list(case.relevant_items.keys())[0], list(case.relevant_items.values())[0])
 
 
 
 
423
 
424
  save_eval_cases(loo_keyword_cases, "eval_loo_keyword.json")
425
 
@@ -435,8 +626,12 @@ if __name__ == "__main__":
435
  # Show examples
436
  logger.info("Sample queries:")
437
  for case in loo_history_cases[:5]:
438
- logger.info(" Query: \"%s\"", case.query)
439
- logger.info(" Target: %s (rel=%s)", list(case.relevant_items.keys())[0], list(case.relevant_items.values())[0])
 
 
 
 
440
 
441
  save_eval_cases(loo_history_cases, "eval_loo_history.json")
442
 
@@ -452,7 +647,7 @@ if __name__ == "__main__":
452
  if multi_cases:
453
  logger.info("Sample queries:")
454
  for case in multi_cases[:3]:
455
- logger.info(" Query: \"%s...\"", case.query[:60])
456
  logger.info(" Relevant: %d items", len(case.relevant_items))
457
 
458
  save_eval_cases(multi_cases, "eval_multi_relevant.json")
 
35
 
36
  # Common stopwords to filter out
37
  STOPWORDS = {
38
+ "i",
39
+ "me",
40
+ "my",
41
+ "myself",
42
+ "we",
43
+ "our",
44
+ "ours",
45
+ "ourselves",
46
+ "you",
47
+ "your",
48
+ "yours",
49
+ "yourself",
50
+ "yourselves",
51
+ "he",
52
+ "him",
53
+ "his",
54
+ "himself",
55
+ "she",
56
+ "her",
57
+ "hers",
58
+ "herself",
59
+ "it",
60
+ "its",
61
+ "itself",
62
+ "they",
63
+ "them",
64
+ "their",
65
+ "theirs",
66
+ "themselves",
67
+ "what",
68
+ "which",
69
+ "who",
70
+ "whom",
71
+ "this",
72
+ "that",
73
+ "these",
74
+ "those",
75
+ "am",
76
+ "is",
77
+ "are",
78
+ "was",
79
+ "were",
80
+ "be",
81
+ "been",
82
+ "being",
83
+ "have",
84
+ "has",
85
+ "had",
86
+ "having",
87
+ "do",
88
+ "does",
89
+ "did",
90
+ "doing",
91
+ "a",
92
+ "an",
93
+ "the",
94
+ "and",
95
+ "but",
96
+ "if",
97
+ "or",
98
+ "because",
99
+ "as",
100
+ "until",
101
+ "while",
102
+ "of",
103
+ "at",
104
+ "by",
105
+ "for",
106
+ "with",
107
+ "about",
108
+ "against",
109
+ "between",
110
+ "into",
111
+ "through",
112
+ "during",
113
+ "before",
114
+ "after",
115
+ "above",
116
+ "below",
117
+ "to",
118
+ "from",
119
+ "up",
120
+ "down",
121
+ "in",
122
+ "out",
123
+ "on",
124
+ "off",
125
+ "over",
126
+ "under",
127
+ "again",
128
+ "further",
129
+ "then",
130
+ "once",
131
+ "here",
132
+ "there",
133
+ "when",
134
+ "where",
135
+ "why",
136
+ "how",
137
+ "all",
138
+ "each",
139
+ "few",
140
+ "more",
141
+ "most",
142
+ "other",
143
+ "some",
144
+ "such",
145
+ "no",
146
+ "nor",
147
+ "not",
148
+ "only",
149
+ "own",
150
+ "same",
151
+ "so",
152
+ "than",
153
+ "too",
154
+ "very",
155
+ "s",
156
+ "t",
157
+ "can",
158
+ "will",
159
+ "just",
160
+ "don",
161
+ "should",
162
+ "now",
163
+ "d",
164
+ "ll",
165
+ "m",
166
+ "o",
167
+ "re",
168
+ "ve",
169
+ "y",
170
+ "ain",
171
+ "aren",
172
+ "couldn",
173
+ "didn",
174
+ "doesn",
175
+ "hadn",
176
+ "hasn",
177
+ "haven",
178
+ "isn",
179
+ "ma",
180
+ "mightn",
181
+ "mustn",
182
+ "needn",
183
+ "shan",
184
+ "shouldn",
185
+ "wasn",
186
+ "weren",
187
+ "won",
188
+ "wouldn",
189
+ "also",
190
+ "would",
191
+ "could",
192
+ "get",
193
+ "got",
194
+ "one",
195
+ "two",
196
+ "really",
197
+ "like",
198
+ "just",
199
+ "even",
200
+ "well",
201
+ "much",
202
+ "still",
203
+ "back",
204
+ "way",
205
+ "thing",
206
+ "things",
207
+ "make",
208
+ "made",
209
+ "work",
210
+ "works",
211
+ "worked",
212
+ "use",
213
+ "used",
214
+ "using",
215
+ "good",
216
+ "great",
217
+ "nice",
218
+ "product",
219
+ "item",
220
+ "bought",
221
+ "buy",
222
+ "amazon",
223
+ "review",
224
+ "ordered",
225
+ "order",
226
+ "received",
227
+ "came",
228
+ "arrived",
229
+ "shipping",
230
+ "shipped",
231
  }
232
 
233
 
 
247
  # Clean text
248
  text = text.lower()
249
  text = re.sub(r"<br\s*/?>", " ", text) # Remove HTML breaks
250
+ text = re.sub(r"[^a-z\s]", " ", text) # Keep only letters
251
  text = re.sub(r"\s+", " ", text).strip()
252
 
253
  # Tokenize and filter
 
337
  # Evaluation Dataset Construction
338
  # ---------------------------------------------------------------------------
339
 
340
+
341
  def build_leave_one_out_cases(
342
  df: pd.DataFrame,
343
  min_reviews: int = 2,
 
404
 
405
  # Only include if target has positive relevance
406
  if relevance > 0:
407
+ eval_cases.append(
408
+ EvalCase(
409
+ query=query,
410
+ relevant_items={target_product: relevance},
411
+ user_id=user_id,
412
+ )
413
+ )
414
 
415
  if verbose:
416
  logger.info("Users with enough reviews: %d", len(user_groups) - skipped_users)
417
  logger.info("Eval cases created: %d", len(eval_cases))
418
+ logger.info(
419
+ "Skipped (low relevance): %d",
420
+ len(user_groups) - skipped_users - len(eval_cases),
421
+ )
422
 
423
  return eval_cases
424
 
 
488
  )
489
 
490
  if relevant_items:
491
+ eval_cases.append(
492
+ EvalCase(
493
+ query=query,
494
+ relevant_items=relevant_items,
495
+ user_id=user_id,
496
+ )
497
+ )
498
 
499
  if verbose:
500
  logger.info("Users with train history: %d", len(train_users))
501
  logger.info("Eval cases created: %d", len(eval_cases))
502
+ avg_relevant = (
503
+ np.mean([len(c.relevant_items) for c in eval_cases]) if eval_cases else 0
504
+ )
505
  logger.info("Avg relevant items per case: %.1f", avg_relevant)
506
 
507
  return eval_cases
 
582
  # Load splits
583
  log_section(logger, "Loading data splits")
584
  train_df, val_df, test_df = load_splits()
585
+ logger.info(
586
+ "Train: %s | Val: %s | Test: %s",
587
+ f"{len(train_df):,}",
588
+ f"{len(val_df):,}",
589
+ f"{len(test_df):,}",
590
+ )
591
 
592
  # Strategy 1: Leave-one-out with keyword queries
593
  # WARNING: This strategy has TARGET LEAKAGE - queries are generated from
 
605
  # Show examples
606
  logger.info("Sample queries:")
607
  for case in loo_keyword_cases[:5]:
608
+ logger.info(' Query: "%s"', case.query)
609
+ logger.info(
610
+ " Target: %s (rel=%s)",
611
+ list(case.relevant_items.keys())[0],
612
+ list(case.relevant_items.values())[0],
613
+ )
614
 
615
  save_eval_cases(loo_keyword_cases, "eval_loo_keyword.json")
616
 
 
626
  # Show examples
627
  logger.info("Sample queries:")
628
  for case in loo_history_cases[:5]:
629
+ logger.info(' Query: "%s"', case.query)
630
+ logger.info(
631
+ " Target: %s (rel=%s)",
632
+ list(case.relevant_items.keys())[0],
633
+ list(case.relevant_items.values())[0],
634
+ )
635
 
636
  save_eval_cases(loo_history_cases, "eval_loo_history.json")
637
 
 
647
  if multi_cases:
648
  logger.info("Sample queries:")
649
  for case in multi_cases[:3]:
650
+ logger.info(' Query: "%s..."', case.query[:60])
651
  logger.info(" Relevant: %d items", len(case.relevant_items))
652
 
653
  save_eval_cases(multi_cases, "eval_multi_relevant.json")
scripts/build_natural_eval_dataset.py CHANGED
@@ -70,7 +70,6 @@ NATURAL_QUERIES = [
70
  "category": "echo_devices",
71
  "intent": "feature_specific",
72
  },
73
-
74
  # === FIRE TABLET QUERIES ===
75
  {
76
  "query": "tablet for reading books and light browsing",
@@ -111,7 +110,6 @@ NATURAL_QUERIES = [
111
  "category": "fire_tablets",
112
  "intent": "use_case",
113
  },
114
-
115
  # === FIRE TV / STREAMING QUERIES ===
116
  {
117
  "query": "streaming device for my tv",
@@ -151,7 +149,6 @@ NATURAL_QUERIES = [
151
  "category": "fire_tv",
152
  "intent": "use_case",
153
  },
154
-
155
  # === SMART HOME QUERIES ===
156
  {
157
  "query": "smart plug to control lights with alexa",
@@ -191,7 +188,6 @@ NATURAL_QUERIES = [
191
  "category": "smart_home",
192
  "intent": "feature_specific",
193
  },
194
-
195
  # === STORAGE QUERIES ===
196
  {
197
  "query": "sd card for camera",
@@ -232,7 +228,6 @@ NATURAL_QUERIES = [
232
  "category": "storage",
233
  "intent": "feature_specific",
234
  },
235
-
236
  # === HEADPHONES / AUDIO QUERIES ===
237
  {
238
  "query": "wireless headphones for working out",
@@ -283,7 +278,6 @@ NATURAL_QUERIES = [
283
  "category": "headphones_audio",
284
  "intent": "use_case",
285
  },
286
-
287
  # === CABLES / ADAPTERS QUERIES ===
288
  {
289
  "query": "usb c charging cable",
@@ -322,7 +316,6 @@ NATURAL_QUERIES = [
322
  "category": "cables_adapters",
323
  "intent": "feature_specific",
324
  },
325
-
326
  # === KEYBOARD / MOUSE QUERIES ===
327
  {
328
  "query": "wireless keyboard for computer",
@@ -353,7 +346,6 @@ NATURAL_QUERIES = [
353
  "category": "keyboards_mice",
354
  "intent": "feature_specific",
355
  },
356
-
357
  # === GIFT QUERIES ===
358
  {
359
  "query": "gift for someone who likes music",
@@ -395,7 +387,6 @@ NATURAL_QUERIES = [
395
  "category": "gifts",
396
  "intent": "gift",
397
  },
398
-
399
  # === PROBLEM-SOLVING QUERIES ===
400
  {
401
  "query": "headphones that dont hurt ears",
@@ -424,7 +415,6 @@ NATURAL_QUERIES = [
424
  "category": "fire_tv",
425
  "intent": "problem_solving",
426
  },
427
-
428
  # === COMPARISON / BEST QUERIES ===
429
  {
430
  "query": "best value fire tablet",
@@ -460,15 +450,19 @@ def build_natural_eval_cases() -> list[EvalCase]:
460
  """Convert natural queries to EvalCase objects."""
461
  cases = []
462
  for item in NATURAL_QUERIES:
463
- cases.append(EvalCase(
464
- query=item["query"],
465
- relevant_items=item["relevant_items"],
466
- user_id=None, # No user for natural queries
467
- ))
 
 
468
  return cases
469
 
470
 
471
- def save_natural_eval_cases(cases: list[EvalCase], filename: str = "eval_natural_queries.json"):
 
 
472
  """Save evaluation cases with metadata."""
473
  EVAL_DIR.mkdir(exist_ok=True)
474
  filepath = EVAL_DIR / filename
@@ -476,12 +470,14 @@ def save_natural_eval_cases(cases: list[EvalCase], filename: str = "eval_natural
476
  # Include metadata for analysis
477
  data = []
478
  for i, item in enumerate(NATURAL_QUERIES):
479
- data.append({
480
- "query": item["query"],
481
- "relevant_items": item["relevant_items"],
482
- "category": item.get("category", "unknown"),
483
- "intent": item.get("intent", "general"),
484
- })
 
 
485
 
486
  with open(filepath, "w") as f:
487
  json.dump(data, f, indent=2)
@@ -522,9 +518,9 @@ def analyze_dataset():
522
  # Sample queries
523
  log_section(logger, "SAMPLE QUERIES")
524
  for q in NATURAL_QUERIES[:5]:
525
- logger.info("Query: \"%s\"", q['query'])
526
- logger.info(" Category: %s | Intent: %s", q['category'], q['intent'])
527
- logger.info(" Relevant: %d products", len(q['relevant_items']))
528
 
529
 
530
  if __name__ == "__main__":
 
70
  "category": "echo_devices",
71
  "intent": "feature_specific",
72
  },
 
73
  # === FIRE TABLET QUERIES ===
74
  {
75
  "query": "tablet for reading books and light browsing",
 
110
  "category": "fire_tablets",
111
  "intent": "use_case",
112
  },
 
113
  # === FIRE TV / STREAMING QUERIES ===
114
  {
115
  "query": "streaming device for my tv",
 
149
  "category": "fire_tv",
150
  "intent": "use_case",
151
  },
 
152
  # === SMART HOME QUERIES ===
153
  {
154
  "query": "smart plug to control lights with alexa",
 
188
  "category": "smart_home",
189
  "intent": "feature_specific",
190
  },
 
191
  # === STORAGE QUERIES ===
192
  {
193
  "query": "sd card for camera",
 
228
  "category": "storage",
229
  "intent": "feature_specific",
230
  },
 
231
  # === HEADPHONES / AUDIO QUERIES ===
232
  {
233
  "query": "wireless headphones for working out",
 
278
  "category": "headphones_audio",
279
  "intent": "use_case",
280
  },
 
281
  # === CABLES / ADAPTERS QUERIES ===
282
  {
283
  "query": "usb c charging cable",
 
316
  "category": "cables_adapters",
317
  "intent": "feature_specific",
318
  },
 
319
  # === KEYBOARD / MOUSE QUERIES ===
320
  {
321
  "query": "wireless keyboard for computer",
 
346
  "category": "keyboards_mice",
347
  "intent": "feature_specific",
348
  },
 
349
  # === GIFT QUERIES ===
350
  {
351
  "query": "gift for someone who likes music",
 
387
  "category": "gifts",
388
  "intent": "gift",
389
  },
 
390
  # === PROBLEM-SOLVING QUERIES ===
391
  {
392
  "query": "headphones that dont hurt ears",
 
415
  "category": "fire_tv",
416
  "intent": "problem_solving",
417
  },
 
418
  # === COMPARISON / BEST QUERIES ===
419
  {
420
  "query": "best value fire tablet",
 
450
  """Convert natural queries to EvalCase objects."""
451
  cases = []
452
  for item in NATURAL_QUERIES:
453
+ cases.append(
454
+ EvalCase(
455
+ query=item["query"],
456
+ relevant_items=item["relevant_items"],
457
+ user_id=None, # No user for natural queries
458
+ )
459
+ )
460
  return cases
461
 
462
 
463
+ def save_natural_eval_cases(
464
+ cases: list[EvalCase], filename: str = "eval_natural_queries.json"
465
+ ):
466
  """Save evaluation cases with metadata."""
467
  EVAL_DIR.mkdir(exist_ok=True)
468
  filepath = EVAL_DIR / filename
 
470
  # Include metadata for analysis
471
  data = []
472
  for i, item in enumerate(NATURAL_QUERIES):
473
+ data.append(
474
+ {
475
+ "query": item["query"],
476
+ "relevant_items": item["relevant_items"],
477
+ "category": item.get("category", "unknown"),
478
+ "intent": item.get("intent", "general"),
479
+ }
480
+ )
481
 
482
  with open(filepath, "w") as f:
483
  json.dump(data, f, indent=2)
 
518
  # Sample queries
519
  log_section(logger, "SAMPLE QUERIES")
520
  for q in NATURAL_QUERIES[:5]:
521
+ logger.info('Query: "%s"', q["query"])
522
+ logger.info(" Category: %s | Intent: %s", q["category"], q["intent"])
523
+ logger.info(" Relevant: %d products", len(q["relevant_items"]))
524
 
525
 
526
  if __name__ == "__main__":
scripts/demo.py CHANGED
@@ -31,7 +31,7 @@ def demo_recommendation(query: str, top_k: int = 3, max_evidence: int = 3):
31
  Returns dict suitable for JSON serialization.
32
  """
33
  log_banner(logger, "SAGE RECOMMENDATION DEMO", width=70)
34
- logger.info("Query: \"%s\"", query)
35
 
36
  # Get candidates
37
  products = get_candidates(
@@ -91,7 +91,7 @@ def demo_recommendation(query: str, top_k: int = 3, max_evidence: int = 3):
91
  # Truncate long evidence for display
92
  display_text = ev_text[:200] + "..." if len(ev_text) > 200 else ev_text
93
  logger.info("[%s]:", ev_id)
94
- logger.info(" \"%s\"", display_text)
95
 
96
  # Compile result
97
  result = {
@@ -108,8 +108,7 @@ def demo_recommendation(query: str, top_k: int = 3, max_evidence: int = 3):
108
  "evidence_sources": [
109
  {"id": ev_id, "text": ev_text}
110
  for ev_id, ev_text in zip(
111
- explanation_result.evidence_ids,
112
- explanation_result.evidence_texts
113
  )
114
  ],
115
  }
@@ -131,13 +130,15 @@ def demo_recommendation(query: str, top_k: int = 3, max_evidence: int = 3):
131
  def main():
132
  parser = argparse.ArgumentParser(description="Demo recommendation pipeline")
133
  parser.add_argument(
134
- "--query", "-q",
 
135
  type=str,
136
  default="wireless earbuds for running",
137
  help="Query to demonstrate",
138
  )
139
  parser.add_argument(
140
- "--top-k", "-k",
 
141
  type=int,
142
  default=1,
143
  help="Number of products to recommend (default: 1)",
 
31
  Returns dict suitable for JSON serialization.
32
  """
33
  log_banner(logger, "SAGE RECOMMENDATION DEMO", width=70)
34
+ logger.info('Query: "%s"', query)
35
 
36
  # Get candidates
37
  products = get_candidates(
 
91
  # Truncate long evidence for display
92
  display_text = ev_text[:200] + "..." if len(ev_text) > 200 else ev_text
93
  logger.info("[%s]:", ev_id)
94
+ logger.info(' "%s"', display_text)
95
 
96
  # Compile result
97
  result = {
 
108
  "evidence_sources": [
109
  {"id": ev_id, "text": ev_text}
110
  for ev_id, ev_text in zip(
111
+ explanation_result.evidence_ids, explanation_result.evidence_texts
 
112
  )
113
  ],
114
  }
 
130
  def main():
131
  parser = argparse.ArgumentParser(description="Demo recommendation pipeline")
132
  parser.add_argument(
133
+ "--query",
134
+ "-q",
135
  type=str,
136
  default="wireless earbuds for running",
137
  help="Query to demonstrate",
138
  )
139
  parser.add_argument(
140
+ "--top-k",
141
+ "-k",
142
  type=int,
143
  default=1,
144
  help="Number of products to recommend (default: 1)",
scripts/e2e_success_rate.py CHANGED
@@ -149,7 +149,7 @@ def run_e2e_evaluation(n_samples: int = 20) -> E2EReport:
149
  case_id = 0
150
 
151
  for query in queries:
152
- logger.info("Query: \"%s\"", query)
153
 
154
  products = get_candidates(
155
  query=query,
@@ -275,23 +275,43 @@ def run_e2e_evaluation(n_samples: int = 20) -> E2EReport:
275
  n_evidence_insufficient = sum(1 for c in all_cases if not c.evidence_sufficient)
276
  n_generated = sum(1 for c in all_cases if c.evidence_sufficient)
277
  n_forbidden_violations = sum(1 for c in all_cases if c.has_forbidden_phrases)
278
- n_hhem_failures = sum(1 for c in all_cases if c.evidence_sufficient and not c.hhem_pass and not c.is_valid_non_recommendation)
 
 
 
 
 
 
279
  n_valid_non_recs = sum(1 for c in all_cases if c.is_valid_non_recommendation)
280
 
281
  # Success counts
282
  n_raw_success = sum(1 for c in all_cases if c.e2e_success)
283
- n_adjusted_success = n_raw_success + n_valid_non_recs # Valid non-recs are correct behavior
 
 
284
 
285
  # Rates
286
  evidence_pass_rate = n_generated / n_total if n_total > 0 else 0
287
 
288
  # Forbidden phrase compliance among generated explanations
289
  generated_cases = [c for c in all_cases if c.evidence_sufficient]
290
- phrase_compliance = sum(1 for c in generated_cases if not c.has_forbidden_phrases) / len(generated_cases) if generated_cases else 0
 
 
 
 
 
291
 
292
  # HHEM pass rate among non-refusal generated explanations
293
- non_refusal_generated = [c for c in generated_cases if not c.is_valid_non_recommendation]
294
- hhem_pass_rate = sum(1 for c in non_refusal_generated if c.hhem_pass) / len(non_refusal_generated) if non_refusal_generated else 0
 
 
 
 
 
 
 
295
 
296
  raw_e2e = n_raw_success / n_total if n_total > 0 else 0
297
  adjusted_e2e = n_adjusted_success / n_total if n_total > 0 else 0
@@ -321,11 +341,31 @@ def run_e2e_evaluation(n_samples: int = 20) -> E2EReport:
321
 
322
  log_section(logger, "Stage Breakdown")
323
  logger.info("Total cases: %d", n_total)
324
- logger.info("Evidence insufficient: %d (%.1f%%)", n_evidence_insufficient, n_evidence_insufficient / n_total * 100)
325
- logger.info("Generated explanations: %d (%.1f%%)", n_generated, n_generated / n_total * 100)
326
- logger.info("Forbidden phrase fails: %d (%.1f%%)", n_forbidden_violations, n_forbidden_violations / n_total * 100)
327
- logger.info("HHEM failures: %d (%.1f%%)", n_hhem_failures, n_hhem_failures / n_total * 100)
328
- logger.info("Valid non-recommendations:%d (%.1f%%)", n_valid_non_recs, n_valid_non_recs / n_total * 100)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
329
 
330
  log_section(logger, "Component Rates")
331
  logger.info("Evidence pass rate: %.1f%%", evidence_pass_rate * 100)
@@ -333,8 +373,18 @@ def run_e2e_evaluation(n_samples: int = 20) -> E2EReport:
333
  logger.info("HHEM pass rate: %.1f%%", hhem_pass_rate * 100)
334
 
335
  log_section(logger, "END-TO-END SUCCESS RATES")
336
- logger.info("Raw E2E success: %d/%d = %.1f%%", n_raw_success, n_total, raw_e2e * 100)
337
- logger.info("Adjusted E2E success: %d/%d = %.1f%%", n_adjusted_success, n_total, adjusted_e2e * 100)
 
 
 
 
 
 
 
 
 
 
338
  logger.info("Target: %.1f%%", target * 100)
339
  logger.info("Gap to target: %.1f%%", report.gap_to_target * 100)
340
  logger.info("Meets target: %s", "YES" if report.meets_target else "NO")
@@ -355,7 +405,9 @@ def run_e2e_evaluation(n_samples: int = 20) -> E2EReport:
355
  "cases": [c.to_dict() for c in all_cases],
356
  }
357
 
358
- output_file = RESULTS_DIR / f"e2e_success_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
 
 
359
  with open(output_file, "w") as f:
360
  json.dump(output, f, indent=2)
361
  logger.info("Saved: %s", output_file)
 
149
  case_id = 0
150
 
151
  for query in queries:
152
+ logger.info('Query: "%s"', query)
153
 
154
  products = get_candidates(
155
  query=query,
 
275
  n_evidence_insufficient = sum(1 for c in all_cases if not c.evidence_sufficient)
276
  n_generated = sum(1 for c in all_cases if c.evidence_sufficient)
277
  n_forbidden_violations = sum(1 for c in all_cases if c.has_forbidden_phrases)
278
+ n_hhem_failures = sum(
279
+ 1
280
+ for c in all_cases
281
+ if c.evidence_sufficient
282
+ and not c.hhem_pass
283
+ and not c.is_valid_non_recommendation
284
+ )
285
  n_valid_non_recs = sum(1 for c in all_cases if c.is_valid_non_recommendation)
286
 
287
  # Success counts
288
  n_raw_success = sum(1 for c in all_cases if c.e2e_success)
289
+ n_adjusted_success = (
290
+ n_raw_success + n_valid_non_recs
291
+ ) # Valid non-recs are correct behavior
292
 
293
  # Rates
294
  evidence_pass_rate = n_generated / n_total if n_total > 0 else 0
295
 
296
  # Forbidden phrase compliance among generated explanations
297
  generated_cases = [c for c in all_cases if c.evidence_sufficient]
298
+ phrase_compliance = (
299
+ sum(1 for c in generated_cases if not c.has_forbidden_phrases)
300
+ / len(generated_cases)
301
+ if generated_cases
302
+ else 0
303
+ )
304
 
305
  # HHEM pass rate among non-refusal generated explanations
306
+ non_refusal_generated = [
307
+ c for c in generated_cases if not c.is_valid_non_recommendation
308
+ ]
309
+ hhem_pass_rate = (
310
+ sum(1 for c in non_refusal_generated if c.hhem_pass)
311
+ / len(non_refusal_generated)
312
+ if non_refusal_generated
313
+ else 0
314
+ )
315
 
316
  raw_e2e = n_raw_success / n_total if n_total > 0 else 0
317
  adjusted_e2e = n_adjusted_success / n_total if n_total > 0 else 0
 
341
 
342
  log_section(logger, "Stage Breakdown")
343
  logger.info("Total cases: %d", n_total)
344
+ logger.info(
345
+ "Evidence insufficient: %d (%.1f%%)",
346
+ n_evidence_insufficient,
347
+ n_evidence_insufficient / n_total * 100,
348
+ )
349
+ logger.info(
350
+ "Generated explanations: %d (%.1f%%)",
351
+ n_generated,
352
+ n_generated / n_total * 100,
353
+ )
354
+ logger.info(
355
+ "Forbidden phrase fails: %d (%.1f%%)",
356
+ n_forbidden_violations,
357
+ n_forbidden_violations / n_total * 100,
358
+ )
359
+ logger.info(
360
+ "HHEM failures: %d (%.1f%%)",
361
+ n_hhem_failures,
362
+ n_hhem_failures / n_total * 100,
363
+ )
364
+ logger.info(
365
+ "Valid non-recommendations:%d (%.1f%%)",
366
+ n_valid_non_recs,
367
+ n_valid_non_recs / n_total * 100,
368
+ )
369
 
370
  log_section(logger, "Component Rates")
371
  logger.info("Evidence pass rate: %.1f%%", evidence_pass_rate * 100)
 
373
  logger.info("HHEM pass rate: %.1f%%", hhem_pass_rate * 100)
374
 
375
  log_section(logger, "END-TO-END SUCCESS RATES")
376
+ logger.info(
377
+ "Raw E2E success: %d/%d = %.1f%%",
378
+ n_raw_success,
379
+ n_total,
380
+ raw_e2e * 100,
381
+ )
382
+ logger.info(
383
+ "Adjusted E2E success: %d/%d = %.1f%%",
384
+ n_adjusted_success,
385
+ n_total,
386
+ adjusted_e2e * 100,
387
+ )
388
  logger.info("Target: %.1f%%", target * 100)
389
  logger.info("Gap to target: %.1f%%", report.gap_to_target * 100)
390
  logger.info("Meets target: %s", "YES" if report.meets_target else "NO")
 
405
  "cases": [c.to_dict() for c in all_cases],
406
  }
407
 
408
+ output_file = (
409
+ RESULTS_DIR / f"e2e_success_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
410
+ )
411
  with open(output_file, "w") as f:
412
  json.dump(output, f, indent=2)
413
  logger.info("Saved: %s", output_file)
scripts/eda.py CHANGED
@@ -17,19 +17,22 @@ FIGURES_DIR.mkdir(exist_ok=True)
17
 
18
  # Plot configuration
19
  plt.style.use("seaborn-v0_8-whitegrid")
20
- plt.rcParams.update({
21
- "figure.figsize": (10, 5),
22
- "figure.dpi": 100,
23
- "savefig.dpi": 150,
24
- "font.size": 11,
25
- "axes.titlesize": 12,
26
- "axes.labelsize": 11,
27
- "figure.autolayout": True,
28
- })
 
 
29
 
30
  # Enable retina display for Jupyter notebooks
31
  try:
32
  from IPython import get_ipython
 
33
  if get_ipython() is not None:
34
  get_ipython().run_line_magic("matplotlib", "inline")
35
  get_ipython().run_line_magic("config", "InlineBackend.figure_format='retina'")
@@ -56,15 +59,23 @@ for key, value in stats.items():
56
  # %% Rating distribution
57
  fig, ax = plt.subplots()
58
  rating_counts = pd.Series(stats["rating_dist"])
59
- bars = ax.bar(rating_counts.index, rating_counts.values, color=PRIMARY_COLOR, edgecolor="black")
 
 
60
  ax.set_xlabel("Rating")
61
  ax.set_ylabel("Count")
62
  ax.set_title("Rating Distribution")
63
  ax.set_xticks(rating_counts.index)
64
 
65
  for bar, count in zip(bars, rating_counts.values):
66
- ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 50,
67
- f"{count:,}", ha="center", va="bottom", fontsize=10)
 
 
 
 
 
 
68
 
69
  plt.tight_layout()
70
  plt.savefig(FIGURES_DIR / "rating_distribution.png", dpi=150)
@@ -84,16 +95,25 @@ fig, axes = plt.subplots(1, 2, figsize=FIGURE_SIZE_WIDE)
84
 
85
  # Character length histogram
86
  ax1 = axes[0]
87
- df["text_length"].clip(upper=2000).hist(bins=50, ax=ax1, color=PRIMARY_COLOR, edgecolor="white")
 
 
88
  ax1.set_xlabel("Character Length (clipped at 2000)")
89
  ax1.set_ylabel("Count")
90
  ax1.set_title("Review Length Distribution")
91
- ax1.axvline(df["text_length"].median(), color="red", linestyle="--", label=f"Median: {df['text_length'].median():.0f}")
 
 
 
 
 
92
  ax1.legend()
93
 
94
  # Token estimate histogram
95
  ax2 = axes[1]
96
- df["estimated_tokens"].clip(upper=500).hist(bins=50, ax=ax2, color=SECONDARY_COLOR, edgecolor="white")
 
 
97
  ax2.set_xlabel("Estimated Tokens (clipped at 500)")
98
  ax2.set_ylabel("Count")
99
  ax2.set_title("Estimated Token Distribution")
@@ -108,12 +128,19 @@ needs_chunking = (df["estimated_tokens"] > 200).sum()
108
  print("\nReview length stats:")
109
  print(f" Median characters: {df['text_length'].median():.0f}")
110
  print(f" Median tokens (est): {df['estimated_tokens'].median():.0f}")
111
- print(f" Reviews > 200 tokens: {needs_chunking:,} ({needs_chunking/len(df)*100:.1f}%)")
 
 
112
 
113
  # %% Review length by rating
114
  fig, ax = plt.subplots()
115
  length_by_rating = df.groupby("rating")["text_length"].median()
116
- bars = ax.bar(length_by_rating.index, length_by_rating.values, color=PRIMARY_COLOR, edgecolor="white")
 
 
 
 
 
117
  ax.set_xlabel("Rating")
118
  ax.set_ylabel("Median Review Length (chars)")
119
  ax.set_title("Review Length by Rating")
@@ -134,7 +161,9 @@ df["year_month"] = df["datetime"].dt.to_period("M")
134
  reviews_over_time = df.groupby("year_month").size()
135
 
136
  fig, ax = plt.subplots(figsize=FIGURE_SIZE_WIDE)
137
- reviews_over_time.plot(kind="line", ax=ax, marker="o", markersize=3, linewidth=1, color=PRIMARY_COLOR)
 
 
138
  ax.set_xlabel("Month")
139
  ax.set_ylabel("Number of Reviews")
140
  ax.set_title("Reviews Over Time")
@@ -156,7 +185,7 @@ missing = df.isnull().sum()
156
  print("\nMissing values:")
157
  for col, count in missing.items():
158
  if count > 0:
159
- print(f" {col}: {count:,} ({count/len(df)*100:.2f}%)")
160
  if missing.sum() == 0:
161
  print(" None!")
162
 
@@ -185,14 +214,18 @@ fig, axes = plt.subplots(1, 2, figsize=FIGURE_SIZE_WIDE)
185
 
186
  # Reviews per user
187
  ax1 = axes[0]
188
- user_counts.clip(upper=20).value_counts().sort_index().plot(kind="bar", ax=ax1, color=PRIMARY_COLOR)
 
 
189
  ax1.set_xlabel("Reviews per User")
190
  ax1.set_ylabel("Number of Users")
191
  ax1.set_title("User Activity Distribution")
192
 
193
  # Reviews per item
194
  ax2 = axes[1]
195
- item_counts.clip(upper=20).value_counts().sort_index().plot(kind="bar", ax=ax2, color=SECONDARY_COLOR)
 
 
196
  ax2.set_xlabel("Reviews per Item")
197
  ax2.set_ylabel("Number of Items")
198
  ax2.set_title("Item Popularity Distribution")
@@ -202,12 +235,16 @@ plt.savefig(FIGURES_DIR / "user_item_distribution.png", dpi=150)
202
  plt.show()
203
 
204
  print("\nUser activity:")
205
- print(f" Users with 1 review: {(user_counts == 1).sum():,} ({(user_counts == 1).sum()/len(user_counts)*100:.1f}%)")
 
 
206
  print(f" Users with 5+ reviews: {(user_counts >= 5).sum():,}")
207
  print(f" Max reviews by one user: {user_counts.max()}")
208
 
209
  print("\nItem popularity:")
210
- print(f" Items with 1 review: {(item_counts == 1).sum():,} ({(item_counts == 1).sum()/len(item_counts)*100:.1f}%)")
 
 
211
  print(f" Items with 5+ reviews: {(item_counts >= 5).sum():,}")
212
  print(f" Max reviews for one item: {item_counts.max()}")
213
 
@@ -217,7 +254,9 @@ items_5plus = set(item_counts[item_counts >= 5].index)
217
 
218
  eligible_mask = df["user_id"].isin(users_5plus) & df["parent_asin"].isin(items_5plus)
219
  print("\n5-core filtering preview:")
220
- print(f" Reviews eligible (first pass): {eligible_mask.sum():,} ({eligible_mask.sum()/len(df)*100:.1f}%)")
 
 
221
 
222
  # %% Sample reviews across length buckets
223
  print("\n=== Sample Reviews by Length Bucket ===")
@@ -232,14 +271,18 @@ length_buckets = [
232
  ]
233
 
234
  for min_tok, max_tok, label in length_buckets:
235
- bucket_mask = (df["estimated_tokens"] >= min_tok) & (df["estimated_tokens"] < max_tok)
 
 
236
  bucket_df = df[bucket_mask]
237
 
238
  if len(bucket_df) == 0:
239
  print(f"{label}: No reviews")
240
  continue
241
 
242
- print(f"{label}: {len(bucket_df):,} reviews ({len(bucket_df)/len(df)*100:.1f}%)")
 
 
243
 
244
  samples = bucket_df.sample(min(3, len(bucket_df)), random_state=42)
245
  for _, row in samples.iterrows():
@@ -256,10 +299,14 @@ df_prepared = prepare_data(subset_size=DEV_SUBSET_SIZE, verbose=False)
256
  prepared_stats = get_review_stats(df_prepared)
257
 
258
  print(f"Raw reviews: {len(df):,}")
259
- print(f"Prepared reviews: {len(df_prepared):,} ({len(df_prepared)/len(df)*100:.1f}% retained)")
 
 
260
  print(f"Unique users: {prepared_stats['unique_users']:,}")
261
  print(f"Unique items: {prepared_stats['unique_items']:,}")
262
- print(f"Avg rating: {prepared_stats['avg_rating']:.2f} (raw: {stats['avg_rating']:.2f})")
 
 
263
 
264
  # %% Summary
265
  print("\n" + "=" * 50)
@@ -269,7 +316,9 @@ print(f"Total reviews: {len(df):,}")
269
  print(f"Unique users: {df['user_id'].nunique():,}")
270
  print(f"Unique items: {df['parent_asin'].nunique():,}")
271
  print(f"Average rating: {df['rating'].mean():.2f}")
272
- print(f"Reviews needing chunking: {needs_chunking:,} ({needs_chunking/len(df)*100:.1f}%)")
 
 
273
  print(f"Data quality issues: {empty_reviews + very_short + duplicate_texts}")
274
  print(f"\nPlots saved to: {FIGURES_DIR}")
275
 
 
17
 
18
  # Plot configuration
19
  plt.style.use("seaborn-v0_8-whitegrid")
20
+ plt.rcParams.update(
21
+ {
22
+ "figure.figsize": (10, 5),
23
+ "figure.dpi": 100,
24
+ "savefig.dpi": 150,
25
+ "font.size": 11,
26
+ "axes.titlesize": 12,
27
+ "axes.labelsize": 11,
28
+ "figure.autolayout": True,
29
+ }
30
+ )
31
 
32
  # Enable retina display for Jupyter notebooks
33
  try:
34
  from IPython import get_ipython
35
+
36
  if get_ipython() is not None:
37
  get_ipython().run_line_magic("matplotlib", "inline")
38
  get_ipython().run_line_magic("config", "InlineBackend.figure_format='retina'")
 
59
  # %% Rating distribution
60
  fig, ax = plt.subplots()
61
  rating_counts = pd.Series(stats["rating_dist"])
62
+ bars = ax.bar(
63
+ rating_counts.index, rating_counts.values, color=PRIMARY_COLOR, edgecolor="black"
64
+ )
65
  ax.set_xlabel("Rating")
66
  ax.set_ylabel("Count")
67
  ax.set_title("Rating Distribution")
68
  ax.set_xticks(rating_counts.index)
69
 
70
  for bar, count in zip(bars, rating_counts.values):
71
+ ax.text(
72
+ bar.get_x() + bar.get_width() / 2,
73
+ bar.get_height() + 50,
74
+ f"{count:,}",
75
+ ha="center",
76
+ va="bottom",
77
+ fontsize=10,
78
+ )
79
 
80
  plt.tight_layout()
81
  plt.savefig(FIGURES_DIR / "rating_distribution.png", dpi=150)
 
95
 
96
  # Character length histogram
97
  ax1 = axes[0]
98
+ df["text_length"].clip(upper=2000).hist(
99
+ bins=50, ax=ax1, color=PRIMARY_COLOR, edgecolor="white"
100
+ )
101
  ax1.set_xlabel("Character Length (clipped at 2000)")
102
  ax1.set_ylabel("Count")
103
  ax1.set_title("Review Length Distribution")
104
+ ax1.axvline(
105
+ df["text_length"].median(),
106
+ color="red",
107
+ linestyle="--",
108
+ label=f"Median: {df['text_length'].median():.0f}",
109
+ )
110
  ax1.legend()
111
 
112
  # Token estimate histogram
113
  ax2 = axes[1]
114
+ df["estimated_tokens"].clip(upper=500).hist(
115
+ bins=50, ax=ax2, color=SECONDARY_COLOR, edgecolor="white"
116
+ )
117
  ax2.set_xlabel("Estimated Tokens (clipped at 500)")
118
  ax2.set_ylabel("Count")
119
  ax2.set_title("Estimated Token Distribution")
 
128
  print("\nReview length stats:")
129
  print(f" Median characters: {df['text_length'].median():.0f}")
130
  print(f" Median tokens (est): {df['estimated_tokens'].median():.0f}")
131
+ print(
132
+ f" Reviews > 200 tokens: {needs_chunking:,} ({needs_chunking / len(df) * 100:.1f}%)"
133
+ )
134
 
135
  # %% Review length by rating
136
  fig, ax = plt.subplots()
137
  length_by_rating = df.groupby("rating")["text_length"].median()
138
+ bars = ax.bar(
139
+ length_by_rating.index,
140
+ length_by_rating.values,
141
+ color=PRIMARY_COLOR,
142
+ edgecolor="white",
143
+ )
144
  ax.set_xlabel("Rating")
145
  ax.set_ylabel("Median Review Length (chars)")
146
  ax.set_title("Review Length by Rating")
 
161
  reviews_over_time = df.groupby("year_month").size()
162
 
163
  fig, ax = plt.subplots(figsize=FIGURE_SIZE_WIDE)
164
+ reviews_over_time.plot(
165
+ kind="line", ax=ax, marker="o", markersize=3, linewidth=1, color=PRIMARY_COLOR
166
+ )
167
  ax.set_xlabel("Month")
168
  ax.set_ylabel("Number of Reviews")
169
  ax.set_title("Reviews Over Time")
 
185
  print("\nMissing values:")
186
  for col, count in missing.items():
187
  if count > 0:
188
+ print(f" {col}: {count:,} ({count / len(df) * 100:.2f}%)")
189
  if missing.sum() == 0:
190
  print(" None!")
191
 
 
214
 
215
  # Reviews per user
216
  ax1 = axes[0]
217
+ user_counts.clip(upper=20).value_counts().sort_index().plot(
218
+ kind="bar", ax=ax1, color=PRIMARY_COLOR
219
+ )
220
  ax1.set_xlabel("Reviews per User")
221
  ax1.set_ylabel("Number of Users")
222
  ax1.set_title("User Activity Distribution")
223
 
224
  # Reviews per item
225
  ax2 = axes[1]
226
+ item_counts.clip(upper=20).value_counts().sort_index().plot(
227
+ kind="bar", ax=ax2, color=SECONDARY_COLOR
228
+ )
229
  ax2.set_xlabel("Reviews per Item")
230
  ax2.set_ylabel("Number of Items")
231
  ax2.set_title("Item Popularity Distribution")
 
235
  plt.show()
236
 
237
  print("\nUser activity:")
238
+ print(
239
+ f" Users with 1 review: {(user_counts == 1).sum():,} ({(user_counts == 1).sum() / len(user_counts) * 100:.1f}%)"
240
+ )
241
  print(f" Users with 5+ reviews: {(user_counts >= 5).sum():,}")
242
  print(f" Max reviews by one user: {user_counts.max()}")
243
 
244
  print("\nItem popularity:")
245
+ print(
246
+ f" Items with 1 review: {(item_counts == 1).sum():,} ({(item_counts == 1).sum() / len(item_counts) * 100:.1f}%)"
247
+ )
248
  print(f" Items with 5+ reviews: {(item_counts >= 5).sum():,}")
249
  print(f" Max reviews for one item: {item_counts.max()}")
250
 
 
254
 
255
  eligible_mask = df["user_id"].isin(users_5plus) & df["parent_asin"].isin(items_5plus)
256
  print("\n5-core filtering preview:")
257
+ print(
258
+ f" Reviews eligible (first pass): {eligible_mask.sum():,} ({eligible_mask.sum() / len(df) * 100:.1f}%)"
259
+ )
260
 
261
  # %% Sample reviews across length buckets
262
  print("\n=== Sample Reviews by Length Bucket ===")
 
271
  ]
272
 
273
  for min_tok, max_tok, label in length_buckets:
274
+ bucket_mask = (df["estimated_tokens"] >= min_tok) & (
275
+ df["estimated_tokens"] < max_tok
276
+ )
277
  bucket_df = df[bucket_mask]
278
 
279
  if len(bucket_df) == 0:
280
  print(f"{label}: No reviews")
281
  continue
282
 
283
+ print(
284
+ f"{label}: {len(bucket_df):,} reviews ({len(bucket_df) / len(df) * 100:.1f}%)"
285
+ )
286
 
287
  samples = bucket_df.sample(min(3, len(bucket_df)), random_state=42)
288
  for _, row in samples.iterrows():
 
299
  prepared_stats = get_review_stats(df_prepared)
300
 
301
  print(f"Raw reviews: {len(df):,}")
302
+ print(
303
+ f"Prepared reviews: {len(df_prepared):,} ({len(df_prepared) / len(df) * 100:.1f}% retained)"
304
+ )
305
  print(f"Unique users: {prepared_stats['unique_users']:,}")
306
  print(f"Unique items: {prepared_stats['unique_items']:,}")
307
+ print(
308
+ f"Avg rating: {prepared_stats['avg_rating']:.2f} (raw: {stats['avg_rating']:.2f})"
309
+ )
310
 
311
  # %% Summary
312
  print("\n" + "=" * 50)
 
316
  print(f"Unique users: {df['user_id'].nunique():,}")
317
  print(f"Unique items: {df['parent_asin'].nunique():,}")
318
  print(f"Average rating: {df['rating'].mean():.2f}")
319
+ print(
320
+ f"Reviews needing chunking: {needs_chunking:,} ({needs_chunking / len(df) * 100:.1f}%)"
321
+ )
322
  print(f"Data quality issues: {empty_reviews + very_short + duplicate_texts}")
323
  print(f"\nPlots saved to: {FIGURES_DIR}")
324
 
scripts/evaluation.py CHANGED
@@ -48,6 +48,7 @@ def create_recommend_fn(
48
  rating_weight: float = 0.0,
49
  ):
50
  """Create a recommend function for evaluation."""
 
51
  def _recommend(query: str) -> list[str]:
52
  recs = recommend(
53
  query=query,
@@ -59,10 +60,13 @@ def create_recommend_fn(
59
  rating_weight=rating_weight,
60
  )
61
  return [r.product_id for r in recs]
 
62
  return _recommend
63
 
64
 
65
- def save_results(results: dict, filename: str | None = None, dataset: str | None = None) -> Path:
 
 
66
  """Save evaluation results to JSON file.
67
 
68
  Also writes a fixed-name "latest" file so downstream scripts (e.g.
@@ -89,6 +93,7 @@ def save_results(results: dict, filename: str | None = None, dataset: str | None
89
  # SECTION: Primary Evaluation
90
  # ============================================================================
91
 
 
92
  def run_primary_evaluation(cases, item_embeddings, item_popularity, total_items):
93
  """Run primary evaluation on leave-one-out dataset."""
94
  log_banner(logger, "EVALUATION: Leave-One-Out (History Queries)")
@@ -124,6 +129,7 @@ def run_primary_evaluation(cases, item_embeddings, item_popularity, total_items)
124
  # SECTION: Aggregation Methods
125
  # ============================================================================
126
 
 
127
  def run_aggregation_comparison(cases):
128
  """Compare different aggregation methods."""
129
  log_banner(logger, "AGGREGATION METHOD COMPARISON")
@@ -154,6 +160,7 @@ def run_aggregation_comparison(cases):
154
  # SECTION: Rating Filter
155
  # ============================================================================
156
 
 
157
  def run_rating_filter_comparison(cases):
158
  """Compare different rating filters."""
159
  log_banner(logger, "RATING FILTER COMPARISON")
@@ -177,6 +184,7 @@ def run_rating_filter_comparison(cases):
177
  # SECTION: K Values
178
  # ============================================================================
179
 
 
180
  def run_k_value_comparison(cases):
181
  """Compare metrics at different K values."""
182
  log_banner(logger, "METRICS AT DIFFERENT K VALUES")
@@ -200,16 +208,23 @@ def run_k_value_comparison(cases):
200
  # SECTION: Weight Tuning
201
  # ============================================================================
202
 
 
203
  def run_weight_tuning(cases):
204
  """Run ranking weight tuning experiment."""
205
  log_banner(logger, "RANKING WEIGHT TUNING (alpha*sim + beta*rating)")
206
 
207
  weight_configs = [
208
- (1.0, 0.0), (0.9, 0.1), (0.8, 0.2),
209
- (0.7, 0.3), (0.6, 0.4), (0.5, 0.5),
 
 
 
 
210
  ]
211
 
212
- logger.info("%-10s %-12s %-10s %-10s %-10s", "alpha", "beta", "NDCG@10", "Hit@10", "MRR")
 
 
213
  logger.info("-" * 52)
214
 
215
  results = []
@@ -227,15 +242,22 @@ def run_weight_tuning(cases):
227
  k=10,
228
  verbose=False,
229
  )
230
- results.append({
231
- "alpha": alpha, "beta": beta,
232
- "ndcg_at_10": report.ndcg_at_k,
233
- "hit_at_10": report.hit_at_k,
234
- "mrr": report.mrr,
235
- })
 
 
 
236
  logger.info(
237
  "%-10.1f %-12.1f %-10.4f %-10.4f %-10.4f",
238
- alpha, beta, report.ndcg_at_k, report.hit_at_k, report.mrr
 
 
 
 
239
  )
240
 
241
  if report.ndcg_at_k > best_ndcg:
@@ -245,7 +267,9 @@ def run_weight_tuning(cases):
245
  logger.info("-" * 52)
246
  logger.info(
247
  "Best: alpha=%.1f, beta=%.1f (NDCG@10=%.4f)",
248
- best_weights[0], best_weights[1], best_ndcg
 
 
249
  )
250
 
251
  return results, best_weights, best_ndcg
@@ -255,6 +279,7 @@ def run_weight_tuning(cases):
255
  # SECTION: Baseline Comparison
256
  # ============================================================================
257
 
 
258
  def run_baseline_comparison(cases, train_records, all_products, product_embeddings):
259
  """Compare against baselines: Random, Popularity, ItemKNN."""
260
  log_banner(logger, "BASELINE COMPARISON")
@@ -274,7 +299,12 @@ def run_baseline_comparison(cases, train_records, all_products, product_embeddin
274
  return itemknn_baseline.recommend(query, top_k=10)
275
 
276
  def rag_recommend(query: str) -> list[str]:
277
- recs = recommend(query=query, top_k=10, candidate_limit=100, aggregation=AggregationMethod.MAX)
 
 
 
 
 
278
  return [r.product_id for r in recs]
279
 
280
  results = {}
@@ -305,7 +335,10 @@ def run_baseline_comparison(cases, train_records, all_products, product_embeddin
305
  for name, report in results.items():
306
  logger.info(
307
  "%-15s %10.4f %10.4f %10.4f",
308
- name, report.ndcg_at_k, report.hit_at_k, report.mrr
 
 
 
309
  )
310
 
311
  # Relative improvements
@@ -323,17 +356,22 @@ def run_baseline_comparison(cases, train_records, all_products, product_embeddin
323
  # Main
324
  # ============================================================================
325
 
 
326
  def main():
327
  parser = argparse.ArgumentParser(description="Run recommendation evaluation")
328
- parser.add_argument("--baselines", action="store_true", help="Include baseline comparison")
329
  parser.add_argument(
330
- "--section", "-s",
 
 
 
 
331
  choices=["all", "primary", "aggregation", "rating", "k", "weights"],
332
  default="primary",
333
  help="Which section to run (default: primary)",
334
  )
335
  parser.add_argument(
336
- "--dataset", "-d",
 
337
  default="eval_loo_history.json",
338
  help="Evaluation dataset file (default: eval_loo_history.json)",
339
  )
@@ -375,7 +413,9 @@ def main():
375
  )
376
 
377
  if args.section in ("all", "aggregation"):
378
- all_results["experiments"]["aggregation_methods"] = run_aggregation_comparison(cases)
 
 
379
 
380
  if args.section in ("all", "rating"):
381
  run_rating_filter_comparison(cases)
 
48
  rating_weight: float = 0.0,
49
  ):
50
  """Create a recommend function for evaluation."""
51
+
52
  def _recommend(query: str) -> list[str]:
53
  recs = recommend(
54
  query=query,
 
60
  rating_weight=rating_weight,
61
  )
62
  return [r.product_id for r in recs]
63
+
64
  return _recommend
65
 
66
 
67
+ def save_results(
68
+ results: dict, filename: str | None = None, dataset: str | None = None
69
+ ) -> Path:
70
  """Save evaluation results to JSON file.
71
 
72
  Also writes a fixed-name "latest" file so downstream scripts (e.g.
 
93
  # SECTION: Primary Evaluation
94
  # ============================================================================
95
 
96
+
97
  def run_primary_evaluation(cases, item_embeddings, item_popularity, total_items):
98
  """Run primary evaluation on leave-one-out dataset."""
99
  log_banner(logger, "EVALUATION: Leave-One-Out (History Queries)")
 
129
  # SECTION: Aggregation Methods
130
  # ============================================================================
131
 
132
+
133
  def run_aggregation_comparison(cases):
134
  """Compare different aggregation methods."""
135
  log_banner(logger, "AGGREGATION METHOD COMPARISON")
 
160
  # SECTION: Rating Filter
161
  # ============================================================================
162
 
163
+
164
  def run_rating_filter_comparison(cases):
165
  """Compare different rating filters."""
166
  log_banner(logger, "RATING FILTER COMPARISON")
 
184
  # SECTION: K Values
185
  # ============================================================================
186
 
187
+
188
  def run_k_value_comparison(cases):
189
  """Compare metrics at different K values."""
190
  log_banner(logger, "METRICS AT DIFFERENT K VALUES")
 
208
  # SECTION: Weight Tuning
209
  # ============================================================================
210
 
211
+
212
  def run_weight_tuning(cases):
213
  """Run ranking weight tuning experiment."""
214
  log_banner(logger, "RANKING WEIGHT TUNING (alpha*sim + beta*rating)")
215
 
216
  weight_configs = [
217
+ (1.0, 0.0),
218
+ (0.9, 0.1),
219
+ (0.8, 0.2),
220
+ (0.7, 0.3),
221
+ (0.6, 0.4),
222
+ (0.5, 0.5),
223
  ]
224
 
225
+ logger.info(
226
+ "%-10s %-12s %-10s %-10s %-10s", "alpha", "beta", "NDCG@10", "Hit@10", "MRR"
227
+ )
228
  logger.info("-" * 52)
229
 
230
  results = []
 
242
  k=10,
243
  verbose=False,
244
  )
245
+ results.append(
246
+ {
247
+ "alpha": alpha,
248
+ "beta": beta,
249
+ "ndcg_at_10": report.ndcg_at_k,
250
+ "hit_at_10": report.hit_at_k,
251
+ "mrr": report.mrr,
252
+ }
253
+ )
254
  logger.info(
255
  "%-10.1f %-12.1f %-10.4f %-10.4f %-10.4f",
256
+ alpha,
257
+ beta,
258
+ report.ndcg_at_k,
259
+ report.hit_at_k,
260
+ report.mrr,
261
  )
262
 
263
  if report.ndcg_at_k > best_ndcg:
 
267
  logger.info("-" * 52)
268
  logger.info(
269
  "Best: alpha=%.1f, beta=%.1f (NDCG@10=%.4f)",
270
+ best_weights[0],
271
+ best_weights[1],
272
+ best_ndcg,
273
  )
274
 
275
  return results, best_weights, best_ndcg
 
279
  # SECTION: Baseline Comparison
280
  # ============================================================================
281
 
282
+
283
  def run_baseline_comparison(cases, train_records, all_products, product_embeddings):
284
  """Compare against baselines: Random, Popularity, ItemKNN."""
285
  log_banner(logger, "BASELINE COMPARISON")
 
299
  return itemknn_baseline.recommend(query, top_k=10)
300
 
301
  def rag_recommend(query: str) -> list[str]:
302
+ recs = recommend(
303
+ query=query,
304
+ top_k=10,
305
+ candidate_limit=100,
306
+ aggregation=AggregationMethod.MAX,
307
+ )
308
  return [r.product_id for r in recs]
309
 
310
  results = {}
 
335
  for name, report in results.items():
336
  logger.info(
337
  "%-15s %10.4f %10.4f %10.4f",
338
+ name,
339
+ report.ndcg_at_k,
340
+ report.hit_at_k,
341
+ report.mrr,
342
  )
343
 
344
  # Relative improvements
 
356
  # Main
357
  # ============================================================================
358
 
359
+
360
  def main():
361
  parser = argparse.ArgumentParser(description="Run recommendation evaluation")
 
362
  parser.add_argument(
363
+ "--baselines", action="store_true", help="Include baseline comparison"
364
+ )
365
+ parser.add_argument(
366
+ "--section",
367
+ "-s",
368
  choices=["all", "primary", "aggregation", "rating", "k", "weights"],
369
  default="primary",
370
  help="Which section to run (default: primary)",
371
  )
372
  parser.add_argument(
373
+ "--dataset",
374
+ "-d",
375
  default="eval_loo_history.json",
376
  help="Evaluation dataset file (default: eval_loo_history.json)",
377
  )
 
413
  )
414
 
415
  if args.section in ("all", "aggregation"):
416
+ all_results["experiments"]["aggregation_methods"] = run_aggregation_comparison(
417
+ cases
418
+ )
419
 
420
  if args.section in ("all", "rating"):
421
  run_rating_filter_comparison(cases)
scripts/explanation.py CHANGED
@@ -40,6 +40,7 @@ PRODUCTS_PER_QUERY = 2
40
  # SECTION: Basic Explanation Generation
41
  # ============================================================================
42
 
 
43
  def run_basic_tests():
44
  """Test basic explanation generation and HHEM detection."""
45
  from sage.services.explanation import Explainer
@@ -59,10 +60,13 @@ def run_basic_tests():
59
  query_results = {}
60
  for query in test_queries:
61
  products = get_candidates(
62
- query=query, k=TOP_K_PRODUCTS, min_rating=4.0, aggregation=AggregationMethod.MAX
 
 
 
63
  )
64
  query_results[query] = products
65
- logger.info("Query: \"%s\"", query)
66
  logger.info(" Found %d products", len(products))
67
 
68
  # Generate explanations
@@ -71,7 +75,7 @@ def run_basic_tests():
71
  all_explanations = []
72
 
73
  for query, products in query_results.items():
74
- logger.info("--- Query: \"%s\" ---", query)
75
  for product in products[:PRODUCTS_PER_QUERY]:
76
  result = explainer.generate_explanation(query, product)
77
  all_explanations.append(result)
@@ -100,7 +104,7 @@ def run_basic_tests():
100
  if query_results:
101
  test_query = list(query_results.keys())[0]
102
  test_product = query_results[test_query][0]
103
- logger.info("Query: \"%s\"", test_query)
104
  logger.info("Streaming: ")
105
 
106
  stream = explainer.generate_explanation_stream(test_query, test_product)
@@ -110,7 +114,9 @@ def run_basic_tests():
110
  logger.info("".join(chunks))
111
 
112
  streamed_result = stream.get_complete_result()
113
- hhem = detector.check_explanation(streamed_result.evidence_texts, streamed_result.explanation)
 
 
114
  logger.info("HHEM Score: %.3f", hhem.score)
115
 
116
  log_banner(logger, "BASIC TESTS COMPLETE")
@@ -120,7 +126,10 @@ def run_basic_tests():
120
  # SECTION: Evidence Quality Gate
121
  # ============================================================================
122
 
123
- def create_mock_product(n_chunks: int, tokens_per_chunk: int = 100, product_score: float = 0.85) -> ProductScore:
 
 
 
124
  """Create a mock ProductScore for testing."""
125
  chunks = [
126
  RetrievedChunk(
@@ -145,7 +154,11 @@ def run_quality_gate_tests():
145
  """Test the evidence quality gate."""
146
  from sage.core.evidence import check_evidence_quality, generate_refusal_message
147
  from sage.services.faithfulness import is_refusal
148
- from sage.config import MIN_EVIDENCE_CHUNKS, MIN_EVIDENCE_TOKENS, MIN_RETRIEVAL_SCORE
 
 
 
 
149
 
150
  log_banner(logger, "EVIDENCE QUALITY GATE TESTS")
151
 
@@ -161,7 +174,14 @@ def run_quality_gate_tests():
161
  product = create_mock_product(n_chunks, tok, score)
162
  quality = check_evidence_quality(product)
163
  status = "PASS" if quality.is_sufficient == expected else "FAIL"
164
- logger.info("[%s] %d chunks, %d tok, score=%.2f -> %s", status, n_chunks, tok, score, reason)
 
 
 
 
 
 
 
165
  assert quality.is_sufficient == expected
166
 
167
  log_section(logger, "2. REFUSAL GENERATION")
@@ -172,12 +192,18 @@ def run_quality_gate_tests():
172
  quality = check_evidence_quality(product)
173
  refusal = generate_refusal_message(query, quality)
174
  detected = is_refusal(refusal)
175
- logger.info("[%s] Refusal detected for %s", "PASS" if detected else "FAIL", quality.failure_reason)
 
 
 
 
176
  assert detected
177
 
178
  logger.info(
179
  "Thresholds: chunks=%d, tokens=%d, score=%.2f",
180
- MIN_EVIDENCE_CHUNKS, MIN_EVIDENCE_TOKENS, MIN_RETRIEVAL_SCORE
 
 
181
  )
182
  log_banner(logger, "QUALITY GATE TESTS COMPLETE")
183
 
@@ -186,9 +212,14 @@ def run_quality_gate_tests():
186
  # SECTION: Verification Loop
187
  # ============================================================================
188
 
 
189
  def run_verification_tests():
190
  """Test the post-generation verification loop."""
191
- from sage.core.verification import extract_quotes, verify_quote_in_evidence, verify_explanation
 
 
 
 
192
 
193
  log_banner(logger, "VERIFICATION LOOP TESTS")
194
 
@@ -235,6 +266,7 @@ def run_verification_tests():
235
  # SECTION: Cold-Start
236
  # ============================================================================
237
 
 
238
  def run_cold_start_tests():
239
  """Test cold-start handling."""
240
  from sage.services.cold_start import (
@@ -264,7 +296,9 @@ def run_cold_start_tests():
264
  for count in test_counts:
265
  level = get_warmup_level(count)
266
  weight = get_content_weight(count)
267
- logger.info(" %d interactions: level=%s, content_weight=%.1f", count, level, weight)
 
 
268
 
269
  # Test preferences to query
270
  log_section(logger, "2. PREFERENCES TO QUERY")
@@ -277,7 +311,7 @@ def run_cold_start_tests():
277
  )
278
  query = preferences_to_query(prefs)
279
  logger.info("Preferences: %s", prefs)
280
- logger.info("Query: \"%s\"", query)
281
 
282
  # Test cold-start recommendations
283
  log_section(logger, "3. COLD-START RECOMMENDATIONS")
@@ -290,7 +324,9 @@ def run_cold_start_tests():
290
  )
291
  logger.info("Got %d recommendations", len(recs))
292
  for r in recs[:3]:
293
- logger.info(" %s: score=%.3f, rating=%.1f", r.product_id, r.score, r.avg_rating)
 
 
294
 
295
  logger.info("Query-based (cold user):")
296
  recs = recommend_cold_start_user(
@@ -337,10 +373,12 @@ def run_cold_start_tests():
337
  # Main
338
  # ============================================================================
339
 
 
340
  def main():
341
  parser = argparse.ArgumentParser(description="Run explanation tests")
342
  parser.add_argument(
343
- "--section", "-s",
 
344
  choices=["all", "basic", "gate", "verify", "cold"],
345
  default="all",
346
  help="Which section to run",
 
40
  # SECTION: Basic Explanation Generation
41
  # ============================================================================
42
 
43
+
44
  def run_basic_tests():
45
  """Test basic explanation generation and HHEM detection."""
46
  from sage.services.explanation import Explainer
 
60
  query_results = {}
61
  for query in test_queries:
62
  products = get_candidates(
63
+ query=query,
64
+ k=TOP_K_PRODUCTS,
65
+ min_rating=4.0,
66
+ aggregation=AggregationMethod.MAX,
67
  )
68
  query_results[query] = products
69
+ logger.info('Query: "%s"', query)
70
  logger.info(" Found %d products", len(products))
71
 
72
  # Generate explanations
 
75
  all_explanations = []
76
 
77
  for query, products in query_results.items():
78
+ logger.info('--- Query: "%s" ---', query)
79
  for product in products[:PRODUCTS_PER_QUERY]:
80
  result = explainer.generate_explanation(query, product)
81
  all_explanations.append(result)
 
104
  if query_results:
105
  test_query = list(query_results.keys())[0]
106
  test_product = query_results[test_query][0]
107
+ logger.info('Query: "%s"', test_query)
108
  logger.info("Streaming: ")
109
 
110
  stream = explainer.generate_explanation_stream(test_query, test_product)
 
114
  logger.info("".join(chunks))
115
 
116
  streamed_result = stream.get_complete_result()
117
+ hhem = detector.check_explanation(
118
+ streamed_result.evidence_texts, streamed_result.explanation
119
+ )
120
  logger.info("HHEM Score: %.3f", hhem.score)
121
 
122
  log_banner(logger, "BASIC TESTS COMPLETE")
 
126
  # SECTION: Evidence Quality Gate
127
  # ============================================================================
128
 
129
+
130
+ def create_mock_product(
131
+ n_chunks: int, tokens_per_chunk: int = 100, product_score: float = 0.85
132
+ ) -> ProductScore:
133
  """Create a mock ProductScore for testing."""
134
  chunks = [
135
  RetrievedChunk(
 
154
  """Test the evidence quality gate."""
155
  from sage.core.evidence import check_evidence_quality, generate_refusal_message
156
  from sage.services.faithfulness import is_refusal
157
+ from sage.config import (
158
+ MIN_EVIDENCE_CHUNKS,
159
+ MIN_EVIDENCE_TOKENS,
160
+ MIN_RETRIEVAL_SCORE,
161
+ )
162
 
163
  log_banner(logger, "EVIDENCE QUALITY GATE TESTS")
164
 
 
174
  product = create_mock_product(n_chunks, tok, score)
175
  quality = check_evidence_quality(product)
176
  status = "PASS" if quality.is_sufficient == expected else "FAIL"
177
+ logger.info(
178
+ "[%s] %d chunks, %d tok, score=%.2f -> %s",
179
+ status,
180
+ n_chunks,
181
+ tok,
182
+ score,
183
+ reason,
184
+ )
185
  assert quality.is_sufficient == expected
186
 
187
  log_section(logger, "2. REFUSAL GENERATION")
 
192
  quality = check_evidence_quality(product)
193
  refusal = generate_refusal_message(query, quality)
194
  detected = is_refusal(refusal)
195
+ logger.info(
196
+ "[%s] Refusal detected for %s",
197
+ "PASS" if detected else "FAIL",
198
+ quality.failure_reason,
199
+ )
200
  assert detected
201
 
202
  logger.info(
203
  "Thresholds: chunks=%d, tokens=%d, score=%.2f",
204
+ MIN_EVIDENCE_CHUNKS,
205
+ MIN_EVIDENCE_TOKENS,
206
+ MIN_RETRIEVAL_SCORE,
207
  )
208
  log_banner(logger, "QUALITY GATE TESTS COMPLETE")
209
 
 
212
  # SECTION: Verification Loop
213
  # ============================================================================
214
 
215
+
216
  def run_verification_tests():
217
  """Test the post-generation verification loop."""
218
+ from sage.core.verification import (
219
+ extract_quotes,
220
+ verify_quote_in_evidence,
221
+ verify_explanation,
222
+ )
223
 
224
  log_banner(logger, "VERIFICATION LOOP TESTS")
225
 
 
266
  # SECTION: Cold-Start
267
  # ============================================================================
268
 
269
+
270
  def run_cold_start_tests():
271
  """Test cold-start handling."""
272
  from sage.services.cold_start import (
 
296
  for count in test_counts:
297
  level = get_warmup_level(count)
298
  weight = get_content_weight(count)
299
+ logger.info(
300
+ " %d interactions: level=%s, content_weight=%.1f", count, level, weight
301
+ )
302
 
303
  # Test preferences to query
304
  log_section(logger, "2. PREFERENCES TO QUERY")
 
311
  )
312
  query = preferences_to_query(prefs)
313
  logger.info("Preferences: %s", prefs)
314
+ logger.info('Query: "%s"', query)
315
 
316
  # Test cold-start recommendations
317
  log_section(logger, "3. COLD-START RECOMMENDATIONS")
 
324
  )
325
  logger.info("Got %d recommendations", len(recs))
326
  for r in recs[:3]:
327
+ logger.info(
328
+ " %s: score=%.3f, rating=%.1f", r.product_id, r.score, r.avg_rating
329
+ )
330
 
331
  logger.info("Query-based (cold user):")
332
  recs = recommend_cold_start_user(
 
373
  # Main
374
  # ============================================================================
375
 
376
+
377
  def main():
378
  parser = argparse.ArgumentParser(description="Run explanation tests")
379
  parser.add_argument(
380
+ "--section",
381
+ "-s",
382
  choices=["all", "basic", "gate", "verify", "cold"],
383
  default="all",
384
  help="Which section to run",
scripts/faithfulness.py CHANGED
@@ -47,6 +47,7 @@ TOP_K_PRODUCTS = 3
47
  # SECTION: Core Evaluation
48
  # ============================================================================
49
 
 
50
  def run_evaluation(n_samples: int, run_ragas: bool = False):
51
  """Run faithfulness evaluation on sample queries."""
52
  from sage.services.explanation import Explainer
@@ -64,8 +65,13 @@ def run_evaluation(n_samples: int, run_ragas: bool = False):
64
  all_explanations = []
65
 
66
  for i, query in enumerate(queries, 1):
67
- logger.info("[%d/%d] \"%s\"", i, len(queries), query)
68
- products = get_candidates(query=query, k=TOP_K_PRODUCTS, min_rating=4.0, aggregation=AggregationMethod.MAX)
 
 
 
 
 
69
 
70
  if not products:
71
  logger.info(" No products found")
@@ -73,7 +79,9 @@ def run_evaluation(n_samples: int, run_ragas: bool = False):
73
 
74
  product = products[0]
75
  try:
76
- result = explainer.generate_explanation(query, product, max_evidence=MAX_EVIDENCE)
 
 
77
  all_explanations.append(result)
78
  logger.info(" %s: %s...", product.product_id, result.explanation[:60])
79
  except Exception:
@@ -101,7 +109,9 @@ def run_evaluation(n_samples: int, run_ragas: bool = False):
101
 
102
  logger.info(
103
  "HHEM (full-explanation): %d/%d grounded, mean=%.3f",
104
- len(hhem_results) - n_hallucinated, len(hhem_results), np.mean(hhem_scores),
 
 
105
  )
106
 
107
  # Multi-metric faithfulness (claim-level as primary)
@@ -109,20 +119,24 @@ def run_evaluation(n_samples: int, run_ragas: bool = False):
109
 
110
  from sage.services.faithfulness import compute_multi_metric_faithfulness
111
 
112
- multi_items = [
113
- (expl.evidence_texts, expl.explanation) for expl in all_explanations
114
- ]
115
  multi_report = compute_multi_metric_faithfulness(multi_items)
116
 
117
- logger.info("Quote verification: %d/%d (%.1f%%)",
118
- multi_report.quotes_found, multi_report.quotes_total,
 
 
119
  multi_report.quote_verification_rate * 100,
120
  )
121
- logger.info("Claim-level HHEM: %.3f avg, %.1f%% pass rate",
122
- multi_report.claim_level_avg_score, multi_report.claim_level_pass_rate * 100,
 
 
123
  )
124
- logger.info("Full-explanation: %.3f avg, %.1f%% pass rate (reference only)",
125
- multi_report.full_explanation_avg_score, multi_report.full_explanation_pass_rate * 100,
 
 
126
  )
127
 
128
  # RAGAS (optional)
@@ -132,16 +146,17 @@ def run_evaluation(n_samples: int, run_ragas: bool = False):
132
 
133
  try:
134
  from sage.services.faithfulness import FaithfulnessEvaluator
 
135
  evaluator = FaithfulnessEvaluator()
136
  ragas_report = evaluator.evaluate_batch(all_explanations)
137
 
138
  logger.info(
139
  "Faithfulness: %.3f +/- %.3f",
140
- ragas_report.mean_score, ragas_report.std_score
 
141
  )
142
  logger.info(
143
- "Passing: %d/%d",
144
- ragas_report.n_passing, ragas_report.n_samples
145
  )
146
  except Exception:
147
  logger.exception("RAGAS evaluation failed")
@@ -217,8 +232,10 @@ def run_failure_analysis():
217
  case_id = 0
218
 
219
  for query in ANALYSIS_QUERIES:
220
- logger.info("Query: \"%s\"", query)
221
- products = get_candidates(query=query, k=3, min_rating=3.5, aggregation=AggregationMethod.MAX)
 
 
222
 
223
  if not products:
224
  continue
@@ -226,21 +243,27 @@ def run_failure_analysis():
226
  for product in products[:2]:
227
  try:
228
  result = explainer.generate_explanation(query, product, max_evidence=3)
229
- hhem = detector.check_explanation(result.evidence_texts, result.explanation)
 
 
230
 
231
  case_id += 1
232
- all_cases.append({
233
- "case_id": case_id,
234
- "query": query,
235
- "product_id": product.product_id,
236
- "explanation": result.explanation,
237
- "evidence_texts": result.evidence_texts,
238
- "hhem_score": hhem.score,
239
- "is_hallucinated": hhem.is_hallucinated,
240
- })
 
 
241
 
242
  status = "FAIL" if hhem.is_hallucinated else "PASS"
243
- logger.info(" [%s] %.3f - %s...", status, hhem.score, product.product_id[:20])
 
 
244
  except Exception:
245
  logger.exception(" Error processing product")
246
 
@@ -254,7 +277,9 @@ def run_failure_analysis():
254
 
255
  log_banner(logger, "ANALYSIS SUMMARY")
256
  logger.info("Total cases: %d", len(all_cases))
257
- logger.info("Failures: %d (%.1f%%)", len(failures), len(failures) / len(all_cases) * 100)
 
 
258
  logger.info("Passes: %d", len(passes))
259
 
260
  # Categorize failures
@@ -274,6 +299,7 @@ def run_failure_analysis():
274
  # SECTION: Adjusted Faithfulness
275
  # ============================================================================
276
 
 
277
  def run_adjusted_calculation():
278
  """Calculate adjusted faithfulness with refusals excluded."""
279
  from sage.services.faithfulness import is_refusal
@@ -296,8 +322,14 @@ def run_adjusted_calculation():
296
 
297
  # Classify
298
  refusals = [c for c in cases if is_refusal(c["explanation"])]
299
- non_refusal_passes = [c for c in cases if not is_refusal(c["explanation"]) and not c["is_hallucinated"]]
300
- non_refusal_fails = [c for c in cases if not is_refusal(c["explanation"]) and c["is_hallucinated"]]
 
 
 
 
 
 
301
 
302
  n_total = len(cases)
303
  raw_pass = sum(1 for c in cases if not c["is_hallucinated"])
@@ -309,9 +341,21 @@ def run_adjusted_calculation():
309
  logger.info("Non-refusal fails: %d", len(non_refusal_fails))
310
 
311
  log_section(logger, "Metrics")
312
- logger.info("Raw pass rate: %d/%d = %.1f%%", raw_pass, n_total, raw_pass / n_total * 100)
313
- logger.info("Adjusted pass rate: %d/%d = %.1f%%", adjusted_pass, n_total, adjusted_pass / n_total * 100)
314
- logger.info("Improvement: +%.1f%%", (adjusted_pass / n_total - raw_pass / n_total) * 100)
 
 
 
 
 
 
 
 
 
 
 
 
315
 
316
  # Save
317
  output = {
@@ -328,12 +372,15 @@ def run_adjusted_calculation():
328
  # Main
329
  # ============================================================================
330
 
 
331
  def main():
332
  parser = argparse.ArgumentParser(description="Run faithfulness evaluation")
333
  parser.add_argument("--samples", "-n", type=int, default=DEFAULT_SAMPLES)
334
  parser.add_argument("--ragas", action="store_true", help="Include RAGAS evaluation")
335
  parser.add_argument("--analyze", action="store_true", help="Run failure analysis")
336
- parser.add_argument("--adjusted", action="store_true", help="Calculate adjusted metrics")
 
 
337
  args = parser.parse_args()
338
 
339
  if args.analyze:
 
47
  # SECTION: Core Evaluation
48
  # ============================================================================
49
 
50
+
51
  def run_evaluation(n_samples: int, run_ragas: bool = False):
52
  """Run faithfulness evaluation on sample queries."""
53
  from sage.services.explanation import Explainer
 
65
  all_explanations = []
66
 
67
  for i, query in enumerate(queries, 1):
68
+ logger.info('[%d/%d] "%s"', i, len(queries), query)
69
+ products = get_candidates(
70
+ query=query,
71
+ k=TOP_K_PRODUCTS,
72
+ min_rating=4.0,
73
+ aggregation=AggregationMethod.MAX,
74
+ )
75
 
76
  if not products:
77
  logger.info(" No products found")
 
79
 
80
  product = products[0]
81
  try:
82
+ result = explainer.generate_explanation(
83
+ query, product, max_evidence=MAX_EVIDENCE
84
+ )
85
  all_explanations.append(result)
86
  logger.info(" %s: %s...", product.product_id, result.explanation[:60])
87
  except Exception:
 
109
 
110
  logger.info(
111
  "HHEM (full-explanation): %d/%d grounded, mean=%.3f",
112
+ len(hhem_results) - n_hallucinated,
113
+ len(hhem_results),
114
+ np.mean(hhem_scores),
115
  )
116
 
117
  # Multi-metric faithfulness (claim-level as primary)
 
119
 
120
  from sage.services.faithfulness import compute_multi_metric_faithfulness
121
 
122
+ multi_items = [(expl.evidence_texts, expl.explanation) for expl in all_explanations]
 
 
123
  multi_report = compute_multi_metric_faithfulness(multi_items)
124
 
125
+ logger.info(
126
+ "Quote verification: %d/%d (%.1f%%)",
127
+ multi_report.quotes_found,
128
+ multi_report.quotes_total,
129
  multi_report.quote_verification_rate * 100,
130
  )
131
+ logger.info(
132
+ "Claim-level HHEM: %.3f avg, %.1f%% pass rate",
133
+ multi_report.claim_level_avg_score,
134
+ multi_report.claim_level_pass_rate * 100,
135
  )
136
+ logger.info(
137
+ "Full-explanation: %.3f avg, %.1f%% pass rate (reference only)",
138
+ multi_report.full_explanation_avg_score,
139
+ multi_report.full_explanation_pass_rate * 100,
140
  )
141
 
142
  # RAGAS (optional)
 
146
 
147
  try:
148
  from sage.services.faithfulness import FaithfulnessEvaluator
149
+
150
  evaluator = FaithfulnessEvaluator()
151
  ragas_report = evaluator.evaluate_batch(all_explanations)
152
 
153
  logger.info(
154
  "Faithfulness: %.3f +/- %.3f",
155
+ ragas_report.mean_score,
156
+ ragas_report.std_score,
157
  )
158
  logger.info(
159
+ "Passing: %d/%d", ragas_report.n_passing, ragas_report.n_samples
 
160
  )
161
  except Exception:
162
  logger.exception("RAGAS evaluation failed")
 
232
  case_id = 0
233
 
234
  for query in ANALYSIS_QUERIES:
235
+ logger.info('Query: "%s"', query)
236
+ products = get_candidates(
237
+ query=query, k=3, min_rating=3.5, aggregation=AggregationMethod.MAX
238
+ )
239
 
240
  if not products:
241
  continue
 
243
  for product in products[:2]:
244
  try:
245
  result = explainer.generate_explanation(query, product, max_evidence=3)
246
+ hhem = detector.check_explanation(
247
+ result.evidence_texts, result.explanation
248
+ )
249
 
250
  case_id += 1
251
+ all_cases.append(
252
+ {
253
+ "case_id": case_id,
254
+ "query": query,
255
+ "product_id": product.product_id,
256
+ "explanation": result.explanation,
257
+ "evidence_texts": result.evidence_texts,
258
+ "hhem_score": hhem.score,
259
+ "is_hallucinated": hhem.is_hallucinated,
260
+ }
261
+ )
262
 
263
  status = "FAIL" if hhem.is_hallucinated else "PASS"
264
+ logger.info(
265
+ " [%s] %.3f - %s...", status, hhem.score, product.product_id[:20]
266
+ )
267
  except Exception:
268
  logger.exception(" Error processing product")
269
 
 
277
 
278
  log_banner(logger, "ANALYSIS SUMMARY")
279
  logger.info("Total cases: %d", len(all_cases))
280
+ logger.info(
281
+ "Failures: %d (%.1f%%)", len(failures), len(failures) / len(all_cases) * 100
282
+ )
283
  logger.info("Passes: %d", len(passes))
284
 
285
  # Categorize failures
 
299
  # SECTION: Adjusted Faithfulness
300
  # ============================================================================
301
 
302
+
303
  def run_adjusted_calculation():
304
  """Calculate adjusted faithfulness with refusals excluded."""
305
  from sage.services.faithfulness import is_refusal
 
322
 
323
  # Classify
324
  refusals = [c for c in cases if is_refusal(c["explanation"])]
325
+ non_refusal_passes = [
326
+ c
327
+ for c in cases
328
+ if not is_refusal(c["explanation"]) and not c["is_hallucinated"]
329
+ ]
330
+ non_refusal_fails = [
331
+ c for c in cases if not is_refusal(c["explanation"]) and c["is_hallucinated"]
332
+ ]
333
 
334
  n_total = len(cases)
335
  raw_pass = sum(1 for c in cases if not c["is_hallucinated"])
 
341
  logger.info("Non-refusal fails: %d", len(non_refusal_fails))
342
 
343
  log_section(logger, "Metrics")
344
+ logger.info(
345
+ "Raw pass rate: %d/%d = %.1f%%",
346
+ raw_pass,
347
+ n_total,
348
+ raw_pass / n_total * 100,
349
+ )
350
+ logger.info(
351
+ "Adjusted pass rate: %d/%d = %.1f%%",
352
+ adjusted_pass,
353
+ n_total,
354
+ adjusted_pass / n_total * 100,
355
+ )
356
+ logger.info(
357
+ "Improvement: +%.1f%%", (adjusted_pass / n_total - raw_pass / n_total) * 100
358
+ )
359
 
360
  # Save
361
  output = {
 
372
  # Main
373
  # ============================================================================
374
 
375
+
376
  def main():
377
  parser = argparse.ArgumentParser(description="Run faithfulness evaluation")
378
  parser.add_argument("--samples", "-n", type=int, default=DEFAULT_SAMPLES)
379
  parser.add_argument("--ragas", action="store_true", help="Include RAGAS evaluation")
380
  parser.add_argument("--analyze", action="store_true", help="Run failure analysis")
381
+ parser.add_argument(
382
+ "--adjusted", action="store_true", help="Calculate adjusted metrics"
383
+ )
384
  args = parser.parse_args()
385
 
386
  if args.analyze:
scripts/human_eval.py CHANGED
@@ -51,6 +51,7 @@ NATURAL_QUERIES_FILE = DATA_DIR / "eval" / "eval_natural_queries.json"
51
  # Sample Generation
52
  # ============================================================================
53
 
 
54
  def _select_diverse_natural_queries(target: int = 35) -> list[str]:
55
  """Select diverse queries from natural eval dataset, balanced by category."""
56
  if not NATURAL_QUERIES_FILE.exists():
@@ -114,7 +115,8 @@ def generate_samples(force: bool = False):
114
  logger.error(
115
  "%s contains %d rated samples. "
116
  "Use --force to overwrite, or run --annotate to continue.",
117
- SAMPLES_FILE, rated,
 
118
  )
119
  sys.exit(1)
120
 
@@ -129,7 +131,9 @@ def generate_samples(force: bool = False):
129
  all_queries = natural + config
130
  logger.info(
131
  "Queries: %d natural + %d config = %d total",
132
- len(natural), len(config), len(all_queries),
 
 
133
  )
134
 
135
  if len(all_queries) < TARGET_SAMPLES:
@@ -137,7 +141,8 @@ def generate_samples(force: bool = False):
137
  "Only %d unique queries available (target: %d). "
138
  "Results will lack statistical power. "
139
  "Run 'make eval' to build natural query dataset.",
140
- len(all_queries), TARGET_SAMPLES,
 
141
  )
142
 
143
  # Initialize services
@@ -146,10 +151,12 @@ def generate_samples(force: bool = False):
146
 
147
  samples = []
148
  for i, query in enumerate(all_queries, 1):
149
- logger.info("[%d/%d] \"%s\"", i, len(all_queries), query)
150
 
151
  products = get_candidates(
152
- query=query, k=1, min_rating=4.0,
 
 
153
  aggregation=AggregationMethod.MAX,
154
  )
155
  if not products:
@@ -159,10 +166,13 @@ def generate_samples(force: bool = False):
159
  product = products[0]
160
  try:
161
  expl = explainer.generate_explanation(
162
- query, product, max_evidence=MAX_EVIDENCE,
 
 
163
  )
164
  hhem = detector.check_explanation(
165
- expl.evidence_texts, expl.explanation,
 
166
  )
167
 
168
  sample = {
@@ -178,7 +188,9 @@ def generate_samples(force: bool = False):
178
  samples.append(sample)
179
  logger.info(
180
  " %s (%.1f stars) HHEM=%.3f",
181
- product.product_id, product.avg_rating, hhem.score,
 
 
182
  )
183
  except ValueError as exc:
184
  logger.info(" Quality gate refusal: %s", exc)
@@ -197,6 +209,7 @@ def generate_samples(force: bool = False):
197
  # Interactive Annotation
198
  # ============================================================================
199
 
 
200
  def _load_samples() -> list[dict]:
201
  """Load samples from disk."""
202
  if not SAMPLES_FILE.exists():
@@ -261,7 +274,7 @@ def annotate_samples():
261
  text = ev["text"]
262
  if len(text) > 200:
263
  text = text[:200] + "..."
264
- print(f" [{ev['id']}]: \"{text}\"")
265
  print()
266
 
267
  # Collect ratings
@@ -286,6 +299,7 @@ def annotate_samples():
286
  # Analysis
287
  # ============================================================================
288
 
 
289
  def analyze_results():
290
  """Compute aggregate metrics from rated samples."""
291
  samples = _load_samples()
@@ -306,7 +320,7 @@ def analyze_results():
306
  n = len(scores)
307
  mean = sum(scores) / n
308
  variance = sum((x - mean) ** 2 for x in scores) / (n - 1) if n > 1 else 0.0
309
- std = variance ** 0.5
310
  dimensions_results[dim_key] = {
311
  "mean": round(mean, 2),
312
  "std": round(std, 2),
@@ -315,7 +329,11 @@ def analyze_results():
315
  }
316
  logger.info(
317
  " %-15s mean=%.2f std=%.2f range=[%d, %d]",
318
- dim_key + ":", mean, std, min(scores), max(scores),
 
 
 
 
319
  )
320
 
321
  # Overall helpfulness: mean of per-sample averages
@@ -328,15 +346,20 @@ def analyze_results():
328
  passed = overall >= HELPFULNESS_TARGET
329
 
330
  logger.info("")
331
- logger.info("Overall helpfulness: %.2f (target: %.1f) [%s]",
332
- overall, HELPFULNESS_TARGET, "PASS" if passed else "FAIL")
 
 
 
 
333
 
334
  # HHEM vs Trust correlation (Spearman)
335
  correlation = _compute_hhem_trust_correlation(rated)
336
  if correlation:
337
  logger.info(
338
  "HHEM-Trust correlation: r=%.3f, p=%.4f",
339
- correlation["spearman_r"], correlation["p_value"],
 
340
  )
341
 
342
  # Save results
@@ -368,6 +391,7 @@ def _compute_hhem_trust_correlation(rated: list[dict]) -> dict | None:
368
 
369
  try:
370
  from scipy.stats import spearmanr
 
371
  r, p = spearmanr(hhem_scores, trust_scores)
372
  return {"spearman_r": round(float(r), 4), "p_value": round(float(p), 4)}
373
  except ImportError:
@@ -399,13 +423,13 @@ def _manual_spearman(x: list[float], y: list[float]) -> dict | None:
399
  ry = _rank(y)
400
 
401
  d_sq = sum((rx[i] - ry[i]) ** 2 for i in range(n))
402
- rho = 1 - (6 * d_sq) / (n * (n ** 2 - 1))
403
 
404
  # Approximate p-value via t-distribution (large sample)
405
  if abs(rho) >= 1.0:
406
  p = 0.0
407
  else:
408
- t = rho * math.sqrt((n - 2) / (1 - rho ** 2))
409
  # Two-tailed p-value approximation
410
  p = 2 * (1 - _t_cdf_approx(abs(t), n - 2))
411
 
@@ -427,6 +451,7 @@ def _t_cdf_approx(t: float, df: int) -> float:
427
  # Status
428
  # ============================================================================
429
 
 
430
  def show_status():
431
  """Show annotation progress."""
432
  if not SAMPLES_FILE.exists():
@@ -450,21 +475,27 @@ def show_status():
450
  # Main
451
  # ============================================================================
452
 
 
453
  def main():
454
  parser = argparse.ArgumentParser(
455
  description="Human evaluation of recommendation explanations",
456
  )
457
  group = parser.add_mutually_exclusive_group(required=True)
458
- group.add_argument("--generate", action="store_true",
459
- help="Generate recommendation samples")
460
- group.add_argument("--annotate", action="store_true",
461
- help="Rate samples interactively (resumable)")
462
- group.add_argument("--analyze", action="store_true",
463
- help="Compute aggregate results from ratings")
464
- group.add_argument("--status", action="store_true",
465
- help="Show annotation progress")
466
- parser.add_argument("--force", action="store_true",
467
- help="Overwrite existing rated samples (with --generate)")
 
 
 
 
 
468
  args = parser.parse_args()
469
 
470
  if args.force and not args.generate:
 
51
  # Sample Generation
52
  # ============================================================================
53
 
54
+
55
  def _select_diverse_natural_queries(target: int = 35) -> list[str]:
56
  """Select diverse queries from natural eval dataset, balanced by category."""
57
  if not NATURAL_QUERIES_FILE.exists():
 
115
  logger.error(
116
  "%s contains %d rated samples. "
117
  "Use --force to overwrite, or run --annotate to continue.",
118
+ SAMPLES_FILE,
119
+ rated,
120
  )
121
  sys.exit(1)
122
 
 
131
  all_queries = natural + config
132
  logger.info(
133
  "Queries: %d natural + %d config = %d total",
134
+ len(natural),
135
+ len(config),
136
+ len(all_queries),
137
  )
138
 
139
  if len(all_queries) < TARGET_SAMPLES:
 
141
  "Only %d unique queries available (target: %d). "
142
  "Results will lack statistical power. "
143
  "Run 'make eval' to build natural query dataset.",
144
+ len(all_queries),
145
+ TARGET_SAMPLES,
146
  )
147
 
148
  # Initialize services
 
151
 
152
  samples = []
153
  for i, query in enumerate(all_queries, 1):
154
+ logger.info('[%d/%d] "%s"', i, len(all_queries), query)
155
 
156
  products = get_candidates(
157
+ query=query,
158
+ k=1,
159
+ min_rating=4.0,
160
  aggregation=AggregationMethod.MAX,
161
  )
162
  if not products:
 
166
  product = products[0]
167
  try:
168
  expl = explainer.generate_explanation(
169
+ query,
170
+ product,
171
+ max_evidence=MAX_EVIDENCE,
172
  )
173
  hhem = detector.check_explanation(
174
+ expl.evidence_texts,
175
+ expl.explanation,
176
  )
177
 
178
  sample = {
 
188
  samples.append(sample)
189
  logger.info(
190
  " %s (%.1f stars) HHEM=%.3f",
191
+ product.product_id,
192
+ product.avg_rating,
193
+ hhem.score,
194
  )
195
  except ValueError as exc:
196
  logger.info(" Quality gate refusal: %s", exc)
 
209
  # Interactive Annotation
210
  # ============================================================================
211
 
212
+
213
  def _load_samples() -> list[dict]:
214
  """Load samples from disk."""
215
  if not SAMPLES_FILE.exists():
 
274
  text = ev["text"]
275
  if len(text) > 200:
276
  text = text[:200] + "..."
277
+ print(f' [{ev["id"]}]: "{text}"')
278
  print()
279
 
280
  # Collect ratings
 
299
  # Analysis
300
  # ============================================================================
301
 
302
+
303
  def analyze_results():
304
  """Compute aggregate metrics from rated samples."""
305
  samples = _load_samples()
 
320
  n = len(scores)
321
  mean = sum(scores) / n
322
  variance = sum((x - mean) ** 2 for x in scores) / (n - 1) if n > 1 else 0.0
323
+ std = variance**0.5
324
  dimensions_results[dim_key] = {
325
  "mean": round(mean, 2),
326
  "std": round(std, 2),
 
329
  }
330
  logger.info(
331
  " %-15s mean=%.2f std=%.2f range=[%d, %d]",
332
+ dim_key + ":",
333
+ mean,
334
+ std,
335
+ min(scores),
336
+ max(scores),
337
  )
338
 
339
  # Overall helpfulness: mean of per-sample averages
 
346
  passed = overall >= HELPFULNESS_TARGET
347
 
348
  logger.info("")
349
+ logger.info(
350
+ "Overall helpfulness: %.2f (target: %.1f) [%s]",
351
+ overall,
352
+ HELPFULNESS_TARGET,
353
+ "PASS" if passed else "FAIL",
354
+ )
355
 
356
  # HHEM vs Trust correlation (Spearman)
357
  correlation = _compute_hhem_trust_correlation(rated)
358
  if correlation:
359
  logger.info(
360
  "HHEM-Trust correlation: r=%.3f, p=%.4f",
361
+ correlation["spearman_r"],
362
+ correlation["p_value"],
363
  )
364
 
365
  # Save results
 
391
 
392
  try:
393
  from scipy.stats import spearmanr
394
+
395
  r, p = spearmanr(hhem_scores, trust_scores)
396
  return {"spearman_r": round(float(r), 4), "p_value": round(float(p), 4)}
397
  except ImportError:
 
423
  ry = _rank(y)
424
 
425
  d_sq = sum((rx[i] - ry[i]) ** 2 for i in range(n))
426
+ rho = 1 - (6 * d_sq) / (n * (n**2 - 1))
427
 
428
  # Approximate p-value via t-distribution (large sample)
429
  if abs(rho) >= 1.0:
430
  p = 0.0
431
  else:
432
+ t = rho * math.sqrt((n - 2) / (1 - rho**2))
433
  # Two-tailed p-value approximation
434
  p = 2 * (1 - _t_cdf_approx(abs(t), n - 2))
435
 
 
451
  # Status
452
  # ============================================================================
453
 
454
+
455
  def show_status():
456
  """Show annotation progress."""
457
  if not SAMPLES_FILE.exists():
 
475
  # Main
476
  # ============================================================================
477
 
478
+
479
  def main():
480
  parser = argparse.ArgumentParser(
481
  description="Human evaluation of recommendation explanations",
482
  )
483
  group = parser.add_mutually_exclusive_group(required=True)
484
+ group.add_argument(
485
+ "--generate", action="store_true", help="Generate recommendation samples"
486
+ )
487
+ group.add_argument(
488
+ "--annotate", action="store_true", help="Rate samples interactively (resumable)"
489
+ )
490
+ group.add_argument(
491
+ "--analyze", action="store_true", help="Compute aggregate results from ratings"
492
+ )
493
+ group.add_argument("--status", action="store_true", help="Show annotation progress")
494
+ parser.add_argument(
495
+ "--force",
496
+ action="store_true",
497
+ help="Overwrite existing rated samples (with --generate)",
498
+ )
499
  args = parser.parse_args()
500
 
501
  if args.force and not args.generate:
scripts/pipeline.py CHANGED
@@ -54,6 +54,7 @@ logger = get_logger(__name__)
54
  # TOKENIZER VALIDATION (--validate-tokenizer)
55
  # ============================================================================
56
 
 
57
  def run_tokenizer_validation():
58
  """Validate the chars/token ratio assumption used in chunker.py."""
59
  from transformers import AutoTokenizer
@@ -90,10 +91,16 @@ def run_tokenizer_validation():
90
  # CHUNKING QUALITY TEST (--test-chunking)
91
  # ============================================================================
92
 
 
93
  def run_chunking_test():
94
  """Test chunking quality on long reviews."""
95
  import pandas as pd
96
- from sage.core.chunking import chunk_text, split_sentences, estimate_tokens, NO_CHUNK_THRESHOLD
 
 
 
 
 
97
 
98
  log_banner(logger, "CHUNKING QUALITY TEST", width=70)
99
 
@@ -113,28 +120,36 @@ def run_chunking_test():
113
  chunks = chunk_text(text, embedder=embedder)
114
  sentences = split_sentences(text)
115
 
116
- results.append({
117
- "tokens": tokens,
118
- "sentences": len(sentences),
119
- "chunks": len(chunks),
120
- "avg_chunk_tokens": np.mean([estimate_tokens(c) for c in chunks]),
121
- })
 
 
122
 
123
  if idx < 5:
124
  logger.info(
125
  "Review %d [%d*] (%d tok) -> %d chunks",
126
- idx + 1, rating, tokens, len(chunks)
 
 
 
127
  )
128
 
129
  results_df = pd.DataFrame(results)
130
  log_section(logger, f"Summary ({len(results_df)} reviews)")
131
  logger.info(
132
  "Chunks per review: %.2f (median: %.0f)",
133
- results_df["chunks"].mean(), results_df["chunks"].median()
 
134
  )
135
  logger.info("Avg tokens/chunk: %.0f", results_df["avg_chunk_tokens"].mean())
136
 
137
- expansion = (results_df["chunks"] * results_df["avg_chunk_tokens"]).sum() / results_df["tokens"].sum()
 
 
138
  logger.info("Expansion ratio: %.2fx", expansion)
139
 
140
 
@@ -142,6 +157,7 @@ def run_chunking_test():
142
  # MAIN PIPELINE
143
  # ============================================================================
144
 
 
145
  def run_pipeline(subset_size: int, force: bool):
146
  """Run the full data pipeline: load, chunk, embed, upload."""
147
  logger.info("Config", extra={"subset_size": subset_size, "force": force})
@@ -174,7 +190,8 @@ def run_pipeline(subset_size: int, force: bool):
174
  needs_chunking = (df["estimated_tokens"] > 200).sum()
175
  logger.info(
176
  "Reviews needing chunking (>200 tokens): %d (%.1f%%)",
177
- needs_chunking, needs_chunking / len(df) * 100
 
178
  )
179
 
180
  # Prepare reviews for chunking
@@ -192,7 +209,9 @@ def run_pipeline(subset_size: int, force: bool):
192
  chunks = chunk_reviews_batch(reviews_for_chunking, embedder=embedder)
193
  logger.info(
194
  "Created %d chunks from %d reviews (expansion: %.2fx)",
195
- len(chunks), len(reviews_for_chunking), len(chunks) / len(reviews_for_chunking)
 
 
196
  )
197
 
198
  # Generate embeddings
@@ -200,7 +219,9 @@ def run_pipeline(subset_size: int, force: bool):
200
  cache_path = DATA_DIR / f"embeddings_{len(chunks)}.npy"
201
 
202
  logger.info("Embedding %d chunks...", len(chunk_texts))
203
- embeddings = embedder.embed_passages(chunk_texts, cache_path=cache_path, force=force)
 
 
204
  logger.info("Embeddings shape: %s", embeddings.shape)
205
 
206
  # Embedding technical validation
@@ -243,12 +264,21 @@ def run_pipeline(subset_size: int, force: bool):
243
  sim_out = float(np.dot(emb_query, emb_out))
244
 
245
  logger.info("Query: '%s'", test_query)
246
- logger.info(" In-domain (same topic): '%s' = %.3f", in_domain_similar, sim_in_similar)
247
- logger.info(" In-domain (diff topic): '%s' = %.3f", in_domain_different, sim_in_different)
 
 
 
 
248
  logger.info(" Out-of-domain: '%s' = %.3f", out_of_domain, sim_out)
249
 
250
  if sim_in_similar > sim_in_different > sim_out:
251
- logger.info("Ranking correct: %.3f > %.3f > %.3f", sim_in_similar, sim_in_different, sim_out)
 
 
 
 
 
252
  else:
253
  logger.warning("Unexpected ranking")
254
 
@@ -307,10 +337,23 @@ def run_pipeline(subset_size: int, force: bool):
307
 
308
  def main():
309
  parser = argparse.ArgumentParser(description="Run the data pipeline")
310
- parser.add_argument("--force", action="store_true", help="Force recreate collection")
311
- parser.add_argument("--subset-size", type=int, default=DEV_SUBSET_SIZE, help="Number of reviews to load initially")
312
- parser.add_argument("--validate-tokenizer", action="store_true", help="Run tokenizer validation only")
313
- parser.add_argument("--test-chunking", action="store_true", help="Run chunking quality test only")
 
 
 
 
 
 
 
 
 
 
 
 
 
314
  args = parser.parse_args()
315
 
316
  if args.validate_tokenizer:
 
54
  # TOKENIZER VALIDATION (--validate-tokenizer)
55
  # ============================================================================
56
 
57
+
58
  def run_tokenizer_validation():
59
  """Validate the chars/token ratio assumption used in chunker.py."""
60
  from transformers import AutoTokenizer
 
91
  # CHUNKING QUALITY TEST (--test-chunking)
92
  # ============================================================================
93
 
94
+
95
  def run_chunking_test():
96
  """Test chunking quality on long reviews."""
97
  import pandas as pd
98
+ from sage.core.chunking import (
99
+ chunk_text,
100
+ split_sentences,
101
+ estimate_tokens,
102
+ NO_CHUNK_THRESHOLD,
103
+ )
104
 
105
  log_banner(logger, "CHUNKING QUALITY TEST", width=70)
106
 
 
120
  chunks = chunk_text(text, embedder=embedder)
121
  sentences = split_sentences(text)
122
 
123
+ results.append(
124
+ {
125
+ "tokens": tokens,
126
+ "sentences": len(sentences),
127
+ "chunks": len(chunks),
128
+ "avg_chunk_tokens": np.mean([estimate_tokens(c) for c in chunks]),
129
+ }
130
+ )
131
 
132
  if idx < 5:
133
  logger.info(
134
  "Review %d [%d*] (%d tok) -> %d chunks",
135
+ idx + 1,
136
+ rating,
137
+ tokens,
138
+ len(chunks),
139
  )
140
 
141
  results_df = pd.DataFrame(results)
142
  log_section(logger, f"Summary ({len(results_df)} reviews)")
143
  logger.info(
144
  "Chunks per review: %.2f (median: %.0f)",
145
+ results_df["chunks"].mean(),
146
+ results_df["chunks"].median(),
147
  )
148
  logger.info("Avg tokens/chunk: %.0f", results_df["avg_chunk_tokens"].mean())
149
 
150
+ expansion = (
151
+ results_df["chunks"] * results_df["avg_chunk_tokens"]
152
+ ).sum() / results_df["tokens"].sum()
153
  logger.info("Expansion ratio: %.2fx", expansion)
154
 
155
 
 
157
  # MAIN PIPELINE
158
  # ============================================================================
159
 
160
+
161
  def run_pipeline(subset_size: int, force: bool):
162
  """Run the full data pipeline: load, chunk, embed, upload."""
163
  logger.info("Config", extra={"subset_size": subset_size, "force": force})
 
190
  needs_chunking = (df["estimated_tokens"] > 200).sum()
191
  logger.info(
192
  "Reviews needing chunking (>200 tokens): %d (%.1f%%)",
193
+ needs_chunking,
194
+ needs_chunking / len(df) * 100,
195
  )
196
 
197
  # Prepare reviews for chunking
 
209
  chunks = chunk_reviews_batch(reviews_for_chunking, embedder=embedder)
210
  logger.info(
211
  "Created %d chunks from %d reviews (expansion: %.2fx)",
212
+ len(chunks),
213
+ len(reviews_for_chunking),
214
+ len(chunks) / len(reviews_for_chunking),
215
  )
216
 
217
  # Generate embeddings
 
219
  cache_path = DATA_DIR / f"embeddings_{len(chunks)}.npy"
220
 
221
  logger.info("Embedding %d chunks...", len(chunk_texts))
222
+ embeddings = embedder.embed_passages(
223
+ chunk_texts, cache_path=cache_path, force=force
224
+ )
225
  logger.info("Embeddings shape: %s", embeddings.shape)
226
 
227
  # Embedding technical validation
 
264
  sim_out = float(np.dot(emb_query, emb_out))
265
 
266
  logger.info("Query: '%s'", test_query)
267
+ logger.info(
268
+ " In-domain (same topic): '%s' = %.3f", in_domain_similar, sim_in_similar
269
+ )
270
+ logger.info(
271
+ " In-domain (diff topic): '%s' = %.3f", in_domain_different, sim_in_different
272
+ )
273
  logger.info(" Out-of-domain: '%s' = %.3f", out_of_domain, sim_out)
274
 
275
  if sim_in_similar > sim_in_different > sim_out:
276
+ logger.info(
277
+ "Ranking correct: %.3f > %.3f > %.3f",
278
+ sim_in_similar,
279
+ sim_in_different,
280
+ sim_out,
281
+ )
282
  else:
283
  logger.warning("Unexpected ranking")
284
 
 
337
 
338
  def main():
339
  parser = argparse.ArgumentParser(description="Run the data pipeline")
340
+ parser.add_argument(
341
+ "--force", action="store_true", help="Force recreate collection"
342
+ )
343
+ parser.add_argument(
344
+ "--subset-size",
345
+ type=int,
346
+ default=DEV_SUBSET_SIZE,
347
+ help="Number of reviews to load initially",
348
+ )
349
+ parser.add_argument(
350
+ "--validate-tokenizer",
351
+ action="store_true",
352
+ help="Run tokenizer validation only",
353
+ )
354
+ parser.add_argument(
355
+ "--test-chunking", action="store_true", help="Run chunking quality test only"
356
+ )
357
  args = parser.parse_args()
358
 
359
  if args.validate_tokenizer:
scripts/sanity_checks.py CHANGED
@@ -42,6 +42,7 @@ RESULTS_DIR.mkdir(exist_ok=True)
42
  # SECTION: Spot-Check
43
  # ============================================================================
44
 
 
45
  def run_spot_check():
46
  """Manual spot-check of explanations vs evidence."""
47
  from sage.services.explanation import Explainer
@@ -56,18 +57,24 @@ def run_spot_check():
56
  queries = EVALUATION_QUERIES[:5]
57
 
58
  for query in queries:
59
- products = get_candidates(query=query, k=2, min_rating=4.0, aggregation=AggregationMethod.MAX)
 
 
60
 
61
  for product in products[:2]:
62
  result = explainer.generate_explanation(query, product, max_evidence=3)
63
  hhem = detector.check_explanation(result.evidence_texts, result.explanation)
64
 
65
  log_section(logger, f"SAMPLE {len(results) + 1}")
66
- logger.info("Query: \"%s\"", query)
67
- logger.info("HHEM: %.3f (%s)", hhem.score, "PASS" if not hhem.is_hallucinated else "FAIL")
 
 
 
 
68
  logger.info("EVIDENCE:")
69
  for ev in result.evidence_texts[:2]:
70
- logger.info(" \"%s...\"", ev[:100])
71
  logger.info("EXPLANATION:")
72
  logger.info(" %s", result.explanation)
73
 
@@ -86,6 +93,7 @@ def run_spot_check():
86
  # SECTION: Adversarial Tests
87
  # ============================================================================
88
 
 
89
  def run_adversarial_tests():
90
  """Test with contradictory evidence."""
91
  from sage.services.explanation import Explainer
@@ -118,10 +126,28 @@ def run_adversarial_tests():
118
  log_section(logger, case["name"])
119
 
120
  chunks = [
121
- RetrievedChunk(text=case["positive"], score=0.9, product_id="TEST", rating=5.0, review_id="pos"),
122
- RetrievedChunk(text=case["negative"], score=0.85, product_id="TEST", rating=1.0, review_id="neg"),
 
 
 
 
 
 
 
 
 
 
 
 
123
  ]
124
- product = ProductScore(product_id="TEST", score=0.85, chunk_count=2, avg_rating=3.0, evidence=chunks)
 
 
 
 
 
 
125
 
126
  result = explainer.generate_explanation(case["query"], product, max_evidence=2)
127
  hhem = detector.check_explanation(result.evidence_texts, result.explanation)
@@ -142,6 +168,7 @@ def run_adversarial_tests():
142
  # SECTION: Empty Context Tests
143
  # ============================================================================
144
 
 
145
  def run_empty_context_tests():
146
  """Test graceful refusal with irrelevant evidence."""
147
  from sage.services.explanation import Explainer
@@ -153,19 +180,46 @@ def run_empty_context_tests():
153
  detector = HallucinationDetector()
154
 
155
  cases = [
156
- {"name": "Irrelevant", "query": "quantum computing textbook", "evidence": "Great USB cable."},
 
 
 
 
157
  {"name": "Minimal", "query": "high-quality camera lens", "evidence": "OK."},
158
- {"name": "Foreign", "query": "wireless mouse", "evidence": "Muy bueno el producto."},
 
 
 
 
159
  ]
160
 
161
- refusal_words = ["cannot", "can't", "unable", "no evidence", "insufficient", "limited"]
 
 
 
 
 
 
 
162
  results = []
163
 
164
  for case in cases:
165
  log_section(logger, case["name"])
166
 
167
- chunk = RetrievedChunk(text=case["evidence"], score=0.3, product_id="TEST", rating=3.0, review_id="r1")
168
- product = ProductScore(product_id="TEST", score=0.3, chunk_count=1, avg_rating=3.0, evidence=[chunk])
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
  result = explainer.generate_explanation(case["query"], product, max_evidence=1)
171
  _hhem = detector.check_explanation(result.evidence_texts, result.explanation)
@@ -185,6 +239,7 @@ def run_empty_context_tests():
185
  # SECTION: Calibration Check
186
  # ============================================================================
187
 
 
188
  @dataclass
189
  class CalibrationSample:
190
  query: str
@@ -210,21 +265,27 @@ def run_calibration_check():
210
 
211
  logger.info("Generating samples...")
212
  for query in queries:
213
- products = get_candidates(query=query, k=5, min_rating=3.0, aggregation=AggregationMethod.MAX)
 
 
214
 
215
  for product in products[:2]:
216
  try:
217
  result = explainer.generate_explanation(query, product, max_evidence=3)
218
- hhem = detector.check_explanation(result.evidence_texts, result.explanation)
219
-
220
- samples.append(CalibrationSample(
221
- query=query,
222
- product_id=product.product_id,
223
- retrieval_score=product.score,
224
- evidence_count=product.chunk_count,
225
- avg_rating=product.avg_rating,
226
- hhem_score=hhem.score,
227
- ))
 
 
 
 
228
  except Exception:
229
  logger.debug("Error generating sample", exc_info=True)
230
 
@@ -251,24 +312,28 @@ def run_calibration_check():
251
  # Stratified analysis
252
  sorted_samples = sorted(samples, key=lambda s: s.retrieval_score)
253
  n = len(sorted_samples)
254
- low = sorted_samples[:n//3]
255
- mid = sorted_samples[n//3:2*n//3]
256
- high = sorted_samples[2*n//3:]
257
 
258
  log_section(logger, "HHEM by Confidence Tier")
259
  logger.info(" LOW (n=%2d): %.3f", len(low), np.mean([s.hhem_score for s in low]))
260
  logger.info(" MED (n=%2d): %.3f", len(mid), np.mean([s.hhem_score for s in mid]))
261
- logger.info(" HIGH (n=%2d): %.3f", len(high), np.mean([s.hhem_score for s in high]))
 
 
262
 
263
 
264
  # ============================================================================
265
  # Main
266
  # ============================================================================
267
 
 
268
  def main():
269
  parser = argparse.ArgumentParser(description="Run pipeline sanity checks")
270
  parser.add_argument(
271
- "--section", "-s",
 
272
  choices=["all", "spot", "adversarial", "empty", "calibration"],
273
  default="all",
274
  help="Which section to run",
 
42
  # SECTION: Spot-Check
43
  # ============================================================================
44
 
45
+
46
  def run_spot_check():
47
  """Manual spot-check of explanations vs evidence."""
48
  from sage.services.explanation import Explainer
 
57
  queries = EVALUATION_QUERIES[:5]
58
 
59
  for query in queries:
60
+ products = get_candidates(
61
+ query=query, k=2, min_rating=4.0, aggregation=AggregationMethod.MAX
62
+ )
63
 
64
  for product in products[:2]:
65
  result = explainer.generate_explanation(query, product, max_evidence=3)
66
  hhem = detector.check_explanation(result.evidence_texts, result.explanation)
67
 
68
  log_section(logger, f"SAMPLE {len(results) + 1}")
69
+ logger.info('Query: "%s"', query)
70
+ logger.info(
71
+ "HHEM: %.3f (%s)",
72
+ hhem.score,
73
+ "PASS" if not hhem.is_hallucinated else "FAIL",
74
+ )
75
  logger.info("EVIDENCE:")
76
  for ev in result.evidence_texts[:2]:
77
+ logger.info(' "%s..."', ev[:100])
78
  logger.info("EXPLANATION:")
79
  logger.info(" %s", result.explanation)
80
 
 
93
  # SECTION: Adversarial Tests
94
  # ============================================================================
95
 
96
+
97
  def run_adversarial_tests():
98
  """Test with contradictory evidence."""
99
  from sage.services.explanation import Explainer
 
126
  log_section(logger, case["name"])
127
 
128
  chunks = [
129
+ RetrievedChunk(
130
+ text=case["positive"],
131
+ score=0.9,
132
+ product_id="TEST",
133
+ rating=5.0,
134
+ review_id="pos",
135
+ ),
136
+ RetrievedChunk(
137
+ text=case["negative"],
138
+ score=0.85,
139
+ product_id="TEST",
140
+ rating=1.0,
141
+ review_id="neg",
142
+ ),
143
  ]
144
+ product = ProductScore(
145
+ product_id="TEST",
146
+ score=0.85,
147
+ chunk_count=2,
148
+ avg_rating=3.0,
149
+ evidence=chunks,
150
+ )
151
 
152
  result = explainer.generate_explanation(case["query"], product, max_evidence=2)
153
  hhem = detector.check_explanation(result.evidence_texts, result.explanation)
 
168
  # SECTION: Empty Context Tests
169
  # ============================================================================
170
 
171
+
172
  def run_empty_context_tests():
173
  """Test graceful refusal with irrelevant evidence."""
174
  from sage.services.explanation import Explainer
 
180
  detector = HallucinationDetector()
181
 
182
  cases = [
183
+ {
184
+ "name": "Irrelevant",
185
+ "query": "quantum computing textbook",
186
+ "evidence": "Great USB cable.",
187
+ },
188
  {"name": "Minimal", "query": "high-quality camera lens", "evidence": "OK."},
189
+ {
190
+ "name": "Foreign",
191
+ "query": "wireless mouse",
192
+ "evidence": "Muy bueno el producto.",
193
+ },
194
  ]
195
 
196
+ refusal_words = [
197
+ "cannot",
198
+ "can't",
199
+ "unable",
200
+ "no evidence",
201
+ "insufficient",
202
+ "limited",
203
+ ]
204
  results = []
205
 
206
  for case in cases:
207
  log_section(logger, case["name"])
208
 
209
+ chunk = RetrievedChunk(
210
+ text=case["evidence"],
211
+ score=0.3,
212
+ product_id="TEST",
213
+ rating=3.0,
214
+ review_id="r1",
215
+ )
216
+ product = ProductScore(
217
+ product_id="TEST",
218
+ score=0.3,
219
+ chunk_count=1,
220
+ avg_rating=3.0,
221
+ evidence=[chunk],
222
+ )
223
 
224
  result = explainer.generate_explanation(case["query"], product, max_evidence=1)
225
  _hhem = detector.check_explanation(result.evidence_texts, result.explanation)
 
239
  # SECTION: Calibration Check
240
  # ============================================================================
241
 
242
+
243
  @dataclass
244
  class CalibrationSample:
245
  query: str
 
265
 
266
  logger.info("Generating samples...")
267
  for query in queries:
268
+ products = get_candidates(
269
+ query=query, k=5, min_rating=3.0, aggregation=AggregationMethod.MAX
270
+ )
271
 
272
  for product in products[:2]:
273
  try:
274
  result = explainer.generate_explanation(query, product, max_evidence=3)
275
+ hhem = detector.check_explanation(
276
+ result.evidence_texts, result.explanation
277
+ )
278
+
279
+ samples.append(
280
+ CalibrationSample(
281
+ query=query,
282
+ product_id=product.product_id,
283
+ retrieval_score=product.score,
284
+ evidence_count=product.chunk_count,
285
+ avg_rating=product.avg_rating,
286
+ hhem_score=hhem.score,
287
+ )
288
+ )
289
  except Exception:
290
  logger.debug("Error generating sample", exc_info=True)
291
 
 
312
  # Stratified analysis
313
  sorted_samples = sorted(samples, key=lambda s: s.retrieval_score)
314
  n = len(sorted_samples)
315
+ low = sorted_samples[: n // 3]
316
+ mid = sorted_samples[n // 3 : 2 * n // 3]
317
+ high = sorted_samples[2 * n // 3 :]
318
 
319
  log_section(logger, "HHEM by Confidence Tier")
320
  logger.info(" LOW (n=%2d): %.3f", len(low), np.mean([s.hhem_score for s in low]))
321
  logger.info(" MED (n=%2d): %.3f", len(mid), np.mean([s.hhem_score for s in mid]))
322
+ logger.info(
323
+ " HIGH (n=%2d): %.3f", len(high), np.mean([s.hhem_score for s in high])
324
+ )
325
 
326
 
327
  # ============================================================================
328
  # Main
329
  # ============================================================================
330
 
331
+
332
  def main():
333
  parser = argparse.ArgumentParser(description="Run pipeline sanity checks")
334
  parser.add_argument(
335
+ "--section",
336
+ "-s",
337
  choices=["all", "spot", "adversarial", "empty", "calibration"],
338
  default="all",
339
  help="Which section to run",
scripts/summary.py CHANGED
@@ -16,7 +16,12 @@ import json
16
  import sys
17
  from pathlib import Path
18
 
19
- from sage.config import EVAL_DIMENSIONS, FAITHFULNESS_TARGET, HELPFULNESS_TARGET, RESULTS_DIR
 
 
 
 
 
20
 
21
  WIDTH = 60
22
  SEP = "=" * WIDTH
@@ -82,14 +87,20 @@ def main():
82
  quotes_total = mm.get("quotes_total", 0)
83
 
84
  if claim_pass is not None:
85
- print(f" Claim HHEM: {fmt(claim_avg, 3)} ({claim_pass*100:.0f}% pass)")
86
- print(f" Quote Verif: {fmt(quote_rate, 3)} ({quotes_found}/{quotes_total})")
 
 
 
 
87
 
88
  # Full-explanation HHEM (reference)
89
  h = faith["hhem"]
90
  n_grounded = n_samples - h.get("n_hallucinated", 0)
91
  full_avg = h.get("mean_score")
92
- print(f" Full HHEM: {fmt(full_avg, 3)} ({n_grounded}/{n_samples} grounded, reference)")
 
 
93
 
94
  # RAGAS if available
95
  ragas = faith.get("ragas", {})
@@ -98,8 +109,10 @@ def main():
98
  print(f" RAGAS Faith: {fmt(ragas_faith, 3)}")
99
 
100
  # Pass/fail: use claim-level as primary, fall back to RAGAS, then full HHEM
101
- effective = claim_avg if claim_avg is not None else (
102
- ragas_faith if ragas_faith is not None else full_avg
 
 
103
  )
104
  if effective is not None:
105
  status = "PASS" if effective >= FAITHFULNESS_TARGET else "FAIL"
@@ -123,7 +136,9 @@ def main():
123
  print(f" {label + ':':<15s} {fmt(m, 2) if m is not None else ' ---'}")
124
  if overall is not None:
125
  status = "PASS" if human.get("pass", False) else "FAIL"
126
- print(f" Helpfulness: {fmt(overall, 2)} (target: {target:.1f}) [{status}]")
 
 
127
  corr = human.get("hhem_trust_correlation", {})
128
  r = corr.get("spearman_r")
129
  if r is not None:
 
16
  import sys
17
  from pathlib import Path
18
 
19
+ from sage.config import (
20
+ EVAL_DIMENSIONS,
21
+ FAITHFULNESS_TARGET,
22
+ HELPFULNESS_TARGET,
23
+ RESULTS_DIR,
24
+ )
25
 
26
  WIDTH = 60
27
  SEP = "=" * WIDTH
 
87
  quotes_total = mm.get("quotes_total", 0)
88
 
89
  if claim_pass is not None:
90
+ print(
91
+ f" Claim HHEM: {fmt(claim_avg, 3)} ({claim_pass * 100:.0f}% pass)"
92
+ )
93
+ print(
94
+ f" Quote Verif: {fmt(quote_rate, 3)} ({quotes_found}/{quotes_total})"
95
+ )
96
 
97
  # Full-explanation HHEM (reference)
98
  h = faith["hhem"]
99
  n_grounded = n_samples - h.get("n_hallucinated", 0)
100
  full_avg = h.get("mean_score")
101
+ print(
102
+ f" Full HHEM: {fmt(full_avg, 3)} ({n_grounded}/{n_samples} grounded, reference)"
103
+ )
104
 
105
  # RAGAS if available
106
  ragas = faith.get("ragas", {})
 
109
  print(f" RAGAS Faith: {fmt(ragas_faith, 3)}")
110
 
111
  # Pass/fail: use claim-level as primary, fall back to RAGAS, then full HHEM
112
+ effective = (
113
+ claim_avg
114
+ if claim_avg is not None
115
+ else (ragas_faith if ragas_faith is not None else full_avg)
116
  )
117
  if effective is not None:
118
  status = "PASS" if effective >= FAITHFULNESS_TARGET else "FAIL"
 
136
  print(f" {label + ':':<15s} {fmt(m, 2) if m is not None else ' ---'}")
137
  if overall is not None:
138
  status = "PASS" if human.get("pass", False) else "FAIL"
139
+ print(
140
+ f" Helpfulness: {fmt(overall, 2)} (target: {target:.1f}) [{status}]"
141
+ )
142
  corr = human.get("hhem_trust_correlation", {})
143
  r = corr.get("spearman_r")
144
  if r is not None:
tests/test_aggregation.py CHANGED
@@ -82,7 +82,9 @@ class TestApplyWeightedRanking:
82
  ProductScore(product_id="A", score=0.9, chunk_count=2, avg_rating=3.0),
83
  ProductScore(product_id="B", score=0.7, chunk_count=1, avg_rating=5.0),
84
  ]
85
- ranked = apply_weighted_ranking(products, similarity_weight=0.5, rating_weight=0.5)
 
 
86
  assert len(ranked) == 2
87
  # B has higher rating, so with 50/50 weights it might rank higher
88
  assert all(isinstance(p, ProductScore) for p in ranked)
@@ -92,7 +94,9 @@ class TestApplyWeightedRanking:
92
  ProductScore(product_id="A", score=0.9, chunk_count=1, avg_rating=1.0),
93
  ProductScore(product_id="B", score=0.5, chunk_count=1, avg_rating=5.0),
94
  ]
95
- ranked = apply_weighted_ranking(products, similarity_weight=1.0, rating_weight=0.0)
 
 
96
  assert ranked[0].product_id == "A"
97
 
98
  def test_pure_rating_reranks(self):
@@ -100,7 +104,9 @@ class TestApplyWeightedRanking:
100
  ProductScore(product_id="A", score=0.9, chunk_count=1, avg_rating=1.0),
101
  ProductScore(product_id="B", score=0.5, chunk_count=1, avg_rating=5.0),
102
  ]
103
- ranked = apply_weighted_ranking(products, similarity_weight=0.0, rating_weight=1.0)
 
 
104
  assert ranked[0].product_id == "B"
105
 
106
  def test_single_product(self):
 
82
  ProductScore(product_id="A", score=0.9, chunk_count=2, avg_rating=3.0),
83
  ProductScore(product_id="B", score=0.7, chunk_count=1, avg_rating=5.0),
84
  ]
85
+ ranked = apply_weighted_ranking(
86
+ products, similarity_weight=0.5, rating_weight=0.5
87
+ )
88
  assert len(ranked) == 2
89
  # B has higher rating, so with 50/50 weights it might rank higher
90
  assert all(isinstance(p, ProductScore) for p in ranked)
 
94
  ProductScore(product_id="A", score=0.9, chunk_count=1, avg_rating=1.0),
95
  ProductScore(product_id="B", score=0.5, chunk_count=1, avg_rating=5.0),
96
  ]
97
+ ranked = apply_weighted_ranking(
98
+ products, similarity_weight=1.0, rating_weight=0.0
99
+ )
100
  assert ranked[0].product_id == "A"
101
 
102
  def test_pure_rating_reranks(self):
 
104
  ProductScore(product_id="A", score=0.9, chunk_count=1, avg_rating=1.0),
105
  ProductScore(product_id="B", score=0.5, chunk_count=1, avg_rating=5.0),
106
  ]
107
+ ranked = apply_weighted_ranking(
108
+ products, similarity_weight=0.0, rating_weight=1.0
109
+ )
110
  assert ranked[0].product_id == "B"
111
 
112
  def test_single_product(self):
tests/test_api.py CHANGED
@@ -27,9 +27,16 @@ def _make_app(**state_overrides) -> FastAPI:
27
  mock_cache = MagicMock()
28
  mock_cache.get.return_value = (None, "miss")
29
  mock_cache.stats.return_value = SimpleNamespace(
30
- size=0, max_entries=100, exact_hits=0, semantic_hits=0,
31
- misses=0, evictions=0, hit_rate=0.0, ttl_seconds=3600.0,
32
- similarity_threshold=0.92, avg_semantic_similarity=0.0,
 
 
 
 
 
 
 
33
  )
34
 
35
  app.state.qdrant = state_overrides.get("qdrant", mock_qdrant)
@@ -56,6 +63,7 @@ class TestHealthEndpoint:
56
  with TestClient(app) as c:
57
  # Patch collection_exists to return True
58
  import sage.api.routes as routes_mod
 
59
  original = routes_mod.collection_exists
60
  routes_mod.collection_exists = lambda client: True
61
  try:
@@ -70,6 +78,7 @@ class TestHealthEndpoint:
70
  def test_degraded_when_collection_missing(self):
71
  app = _make_app()
72
  import sage.api.routes as routes_mod
 
73
  original = routes_mod.collection_exists
74
  routes_mod.collection_exists = lambda client: False
75
  try:
@@ -90,6 +99,7 @@ class TestRecommendEndpoint:
90
 
91
  def test_empty_results(self, client):
92
  import sage.api.routes as routes_mod
 
93
  original = routes_mod.get_candidates
94
  routes_mod.get_candidates = lambda **kw: []
95
  try:
@@ -102,12 +112,18 @@ class TestRecommendEndpoint:
102
 
103
  def test_returns_products_without_explain(self):
104
  product = ProductScore(
105
- product_id="P1", score=0.9, chunk_count=2, avg_rating=4.5,
 
 
 
106
  evidence=[
107
- RetrievedChunk(text="Good", score=0.9, product_id="P1", rating=4.5, review_id="r1"),
 
 
108
  ],
109
  )
110
  import sage.api.routes as routes_mod
 
111
  original = routes_mod.get_candidates
112
  routes_mod.get_candidates = lambda **kw: [product]
113
  app = _make_app()
@@ -126,12 +142,18 @@ class TestRecommendEndpoint:
126
 
127
  def test_explainer_unavailable_returns_503(self):
128
  product = ProductScore(
129
- product_id="P1", score=0.9, chunk_count=2, avg_rating=4.5,
 
 
 
130
  evidence=[
131
- RetrievedChunk(text="Good", score=0.9, product_id="P1", rating=4.5, review_id="r1"),
 
 
132
  ],
133
  )
134
  import sage.api.routes as routes_mod
 
135
  original = routes_mod.get_candidates
136
  routes_mod.get_candidates = lambda **kw: [product]
137
 
 
27
  mock_cache = MagicMock()
28
  mock_cache.get.return_value = (None, "miss")
29
  mock_cache.stats.return_value = SimpleNamespace(
30
+ size=0,
31
+ max_entries=100,
32
+ exact_hits=0,
33
+ semantic_hits=0,
34
+ misses=0,
35
+ evictions=0,
36
+ hit_rate=0.0,
37
+ ttl_seconds=3600.0,
38
+ similarity_threshold=0.92,
39
+ avg_semantic_similarity=0.0,
40
  )
41
 
42
  app.state.qdrant = state_overrides.get("qdrant", mock_qdrant)
 
63
  with TestClient(app) as c:
64
  # Patch collection_exists to return True
65
  import sage.api.routes as routes_mod
66
+
67
  original = routes_mod.collection_exists
68
  routes_mod.collection_exists = lambda client: True
69
  try:
 
78
  def test_degraded_when_collection_missing(self):
79
  app = _make_app()
80
  import sage.api.routes as routes_mod
81
+
82
  original = routes_mod.collection_exists
83
  routes_mod.collection_exists = lambda client: False
84
  try:
 
99
 
100
  def test_empty_results(self, client):
101
  import sage.api.routes as routes_mod
102
+
103
  original = routes_mod.get_candidates
104
  routes_mod.get_candidates = lambda **kw: []
105
  try:
 
112
 
113
  def test_returns_products_without_explain(self):
114
  product = ProductScore(
115
+ product_id="P1",
116
+ score=0.9,
117
+ chunk_count=2,
118
+ avg_rating=4.5,
119
  evidence=[
120
+ RetrievedChunk(
121
+ text="Good", score=0.9, product_id="P1", rating=4.5, review_id="r1"
122
+ ),
123
  ],
124
  )
125
  import sage.api.routes as routes_mod
126
+
127
  original = routes_mod.get_candidates
128
  routes_mod.get_candidates = lambda **kw: [product]
129
  app = _make_app()
 
142
 
143
  def test_explainer_unavailable_returns_503(self):
144
  product = ProductScore(
145
+ product_id="P1",
146
+ score=0.9,
147
+ chunk_count=2,
148
+ avg_rating=4.5,
149
  evidence=[
150
+ RetrievedChunk(
151
+ text="Good", score=0.9, product_id="P1", rating=4.5, review_id="r1"
152
+ ),
153
  ],
154
  )
155
  import sage.api.routes as routes_mod
156
+
157
  original = routes_mod.get_candidates
158
  routes_mod.get_candidates = lambda **kw: [product]
159
 
tests/test_chunking.py CHANGED
@@ -67,7 +67,9 @@ class TestSlidingWindowChunk:
67
 
68
  def test_long_text_creates_multiple_chunks(self):
69
  # Create text long enough to require multiple chunks
70
- sentences = [f"This is sentence number {i} with some padding text." for i in range(20)]
 
 
71
  text = " ".join(sentences)
72
  chunks = sliding_window_chunk(text, chunk_size=50, overlap=10)
73
  assert len(chunks) > 1
 
67
 
68
  def test_long_text_creates_multiple_chunks(self):
69
  # Create text long enough to require multiple chunks
70
+ sentences = [
71
+ f"This is sentence number {i} with some padding text." for i in range(20)
72
+ ]
73
  text = " ".join(sentences)
74
  chunks = sliding_window_chunk(text, chunk_size=50, overlap=10)
75
  assert len(chunks) > 1
tests/test_evidence.py CHANGED
@@ -50,7 +50,10 @@ class TestCheckEvidenceQuality:
50
  product = _product(score=0.3, n_chunks=3, text_len=300)
51
  quality = check_evidence_quality(product, min_score=0.7)
52
  assert quality.is_sufficient is False
53
- assert "relevance" in quality.failure_reason.lower() or "score" in quality.failure_reason.lower()
 
 
 
54
 
55
  def test_tracks_chunk_count(self):
56
  product = _product(score=0.85, n_chunks=4, text_len=200)
 
50
  product = _product(score=0.3, n_chunks=3, text_len=300)
51
  quality = check_evidence_quality(product, min_score=0.7)
52
  assert quality.is_sufficient is False
53
+ assert (
54
+ "relevance" in quality.failure_reason.lower()
55
+ or "score" in quality.failure_reason.lower()
56
+ )
57
 
58
  def test_tracks_chunk_count(self):
59
  product = _product(score=0.85, n_chunks=4, text_len=200)
tests/test_faithfulness.py CHANGED
@@ -29,13 +29,20 @@ class TestIsRefusal:
29
 
30
  class TestIsMismatchWarning:
31
  def test_detects_not_best_match(self):
32
- assert is_mismatch_warning("This product may not be the best match for your needs.") is True
 
 
 
 
 
33
 
34
  def test_detects_not_designed_for(self):
35
  assert is_mismatch_warning("This is not designed for that purpose.") is True
36
 
37
  def test_detects_not_suitable(self):
38
- assert is_mismatch_warning("This product is not suitable for heavy use.") is True
 
 
39
 
40
  def test_normal_explanation_not_mismatch(self):
41
  assert is_mismatch_warning("Great headphones with noise cancellation.") is False
 
29
 
30
  class TestIsMismatchWarning:
31
  def test_detects_not_best_match(self):
32
+ assert (
33
+ is_mismatch_warning(
34
+ "This product may not be the best match for your needs."
35
+ )
36
+ is True
37
+ )
38
 
39
  def test_detects_not_designed_for(self):
40
  assert is_mismatch_warning("This is not designed for that purpose.") is True
41
 
42
  def test_detects_not_suitable(self):
43
+ assert (
44
+ is_mismatch_warning("This product is not suitable for heavy use.") is True
45
+ )
46
 
47
  def test_normal_explanation_not_mismatch(self):
48
  assert is_mismatch_warning("Great headphones with noise cancellation.") is False
tests/test_models.py CHANGED
@@ -35,19 +35,32 @@ class TestNewItem:
35
  class TestProductScore:
36
  def test_top_evidence_returns_highest(self):
37
  chunks = [
38
- RetrievedChunk(text="low", score=0.5, product_id="P1", rating=4.0, review_id="r1"),
39
- RetrievedChunk(text="high", score=0.9, product_id="P1", rating=4.0, review_id="r2"),
40
- RetrievedChunk(text="mid", score=0.7, product_id="P1", rating=4.0, review_id="r3"),
 
 
 
 
 
 
41
  ]
42
  product = ProductScore(
43
- product_id="P1", score=0.9, chunk_count=3, avg_rating=4.0, evidence=chunks,
 
 
 
 
44
  )
45
  assert product.top_evidence.text == "high"
46
  assert product.top_evidence.score == 0.9
47
 
48
  def test_top_evidence_empty(self):
49
  product = ProductScore(
50
- product_id="P1", score=0.5, chunk_count=0, avg_rating=4.0,
 
 
 
51
  )
52
  assert product.top_evidence is None
53
 
@@ -116,7 +129,10 @@ class TestStreamingExplanation:
116
  class TestEvidenceQuality:
117
  def test_sufficient(self):
118
  eq = EvidenceQuality(
119
- is_sufficient=True, chunk_count=3, total_tokens=150, top_score=0.9,
 
 
 
120
  )
121
  assert eq.is_sufficient is True
122
  assert eq.failure_reason is None
 
35
  class TestProductScore:
36
  def test_top_evidence_returns_highest(self):
37
  chunks = [
38
+ RetrievedChunk(
39
+ text="low", score=0.5, product_id="P1", rating=4.0, review_id="r1"
40
+ ),
41
+ RetrievedChunk(
42
+ text="high", score=0.9, product_id="P1", rating=4.0, review_id="r2"
43
+ ),
44
+ RetrievedChunk(
45
+ text="mid", score=0.7, product_id="P1", rating=4.0, review_id="r3"
46
+ ),
47
  ]
48
  product = ProductScore(
49
+ product_id="P1",
50
+ score=0.9,
51
+ chunk_count=3,
52
+ avg_rating=4.0,
53
+ evidence=chunks,
54
  )
55
  assert product.top_evidence.text == "high"
56
  assert product.top_evidence.score == 0.9
57
 
58
  def test_top_evidence_empty(self):
59
  product = ProductScore(
60
+ product_id="P1",
61
+ score=0.5,
62
+ chunk_count=0,
63
+ avg_rating=4.0,
64
  )
65
  assert product.top_evidence is None
66
 
 
129
  class TestEvidenceQuality:
130
  def test_sufficient(self):
131
  eq = EvidenceQuality(
132
+ is_sufficient=True,
133
+ chunk_count=3,
134
+ total_tokens=150,
135
+ top_score=0.9,
136
  )
137
  assert eq.is_sufficient is True
138
  assert eq.failure_reason is None