vxa8502 committed on
Commit
ca96fbf
·
1 Parent(s): a9bab1a

Add bootstrap confidence intervals to evaluation metrics

Browse files
data/eval_results/eval_natural_queries_20260305_161900_824897.json ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "timestamp": "2026-03-05T16:18:49.380747",
3
+ "dataset": "eval_natural_queries.json",
4
+ "catalog_size": 21827,
5
+ "experiments": {},
6
+ "primary_metrics": {
7
+ "ndcg_at_10": 0.4871922222425982,
8
+ "hit_at_10": 0.7380952380952381,
9
+ "mrr": 0.42086167800453517,
10
+ "precision_at_10": 0.12857142857142856,
11
+ "recall_at_10": 0.4722222222222222,
12
+ "diversity": 0.01957190520646696,
13
+ "coverage": 0.015531222797452697,
14
+ "novelty": 9.808908578271737,
15
+ "ndcg_ci": {
16
+ "mean": 0.4872,
17
+ "ci_lower": 0.3779,
18
+ "ci_upper": 0.6078,
19
+ "confidence": 0.95
20
+ },
21
+ "hit_ci": {
22
+ "mean": 0.7381,
23
+ "ci_lower": 0.5952,
24
+ "ci_upper": 0.8571,
25
+ "confidence": 0.95
26
+ },
27
+ "mrr_ci": {
28
+ "mean": 0.4209,
29
+ "ci_lower": 0.301,
30
+ "ci_upper": 0.5545,
31
+ "confidence": 0.95
32
+ }
33
+ }
34
+ }
data/eval_results/eval_natural_queries_latest.json DELETED
@@ -1,16 +0,0 @@
1
- {
2
- "timestamp": "2026-02-10T11:49:08.500849",
3
- "dataset": "eval_natural_queries.json",
4
- "catalog_size": 21827,
5
- "experiments": {},
6
- "primary_metrics": {
7
- "ndcg_at_10": 0.4871922222425982,
8
- "hit_at_10": 0.7380952380952381,
9
- "mrr": 0.42086167800453517,
10
- "precision_at_10": 0.12857142857142856,
11
- "recall_at_10": 0.4722222222222222,
12
- "diversity": 0.01957190520646696,
13
- "coverage": 0.015531222797452697,
14
- "novelty": 9.808908578271737
15
- }
16
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
data/eval_results/eval_natural_queries_latest.json ADDED
@@ -0,0 +1 @@
 
 
1
+ eval_natural_queries_20260305_161900_824897.json
sage/core/__init__.py CHANGED
@@ -37,6 +37,7 @@ from sage.core.models import (
37
  FaithfulnessReport,
38
  FaithfulnessResult,
39
  # Evaluation
 
40
  EvalCase,
41
  EvalResult,
42
  MetricsReport,
@@ -110,6 +111,7 @@ __all__ = [
110
  "MultiMetricFaithfulnessReport",
111
  "FaithfulnessReport",
112
  "FaithfulnessResult",
 
113
  "EvalCase",
114
  "EvalResult",
115
  "MetricsReport",
 
37
  FaithfulnessReport,
38
  FaithfulnessResult,
39
  # Evaluation
40
+ ConfidenceInterval,
41
  EvalCase,
42
  EvalResult,
43
  MetricsReport,
 
111
  "MultiMetricFaithfulnessReport",
112
  "FaithfulnessReport",
113
  "FaithfulnessResult",
114
+ "ConfidenceInterval",
115
  "EvalCase",
116
  "EvalResult",
117
  "MetricsReport",
sage/core/models.py CHANGED
@@ -13,7 +13,7 @@ from __future__ import annotations
13
 
14
  from dataclasses import dataclass, field
15
  from enum import Enum
16
- from typing import Iterator
17
 
18
 
19
  # ============================================================================
@@ -555,6 +555,27 @@ class EvalResult:
555
  recall: float = 0.0
556
 
557
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
558
  @dataclass
559
  class MetricsReport:
560
  """
@@ -575,9 +596,14 @@ class MetricsReport:
575
  novelty: float = 0.0
576
  k: int = 10
577
 
578
- def to_dict(self) -> dict:
 
 
 
 
 
579
  """Convert to dictionary for easy serialization."""
580
- return {
581
  "n_cases": self.n_cases,
582
  f"ndcg@{self.k}": round(self.ndcg_at_k, 4),
583
  f"hit@{self.k}": round(self.hit_at_k, 4),
@@ -588,19 +614,35 @@ class MetricsReport:
588
  "coverage": round(self.coverage, 4),
589
  "novelty": round(self.novelty, 4),
590
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
591
 
592
  def __str__(self) -> str:
593
  lines = [
594
  f"Evaluation Results (n={self.n_cases}, k={self.k})",
595
- "-" * 40,
596
- f"NDCG@{self.k}: {self.ndcg_at_k:.4f}",
597
- f"Hit@{self.k}: {self.hit_at_k:.4f}",
598
- f"MRR: {self.mrr:.4f}",
599
- f"Precision@{self.k}: {self.precision_at_k:.4f}",
600
- f"Recall@{self.k}: {self.recall_at_k:.4f}",
601
- "-" * 40,
602
- f"Diversity: {self.diversity:.4f}",
603
- f"Coverage: {self.coverage:.4f}",
604
- f"Novelty: {self.novelty:.4f}",
605
  ]
606
  return "\n".join(lines)
 
13
 
14
  from dataclasses import dataclass, field
15
  from enum import Enum
16
+ from typing import Any, Iterator
17
 
18
 
19
  # ============================================================================
 
555
  recall: float = 0.0
556
 
557
 
558
@dataclass
class ConfidenceInterval:
    """Point estimate of a metric together with its bootstrap interval.

    ``mean`` is the point estimate, ``lower`` and ``upper`` are the interval
    bounds, and ``confidence`` is the level the interval was computed at
    (0.95 unless overridden).
    """

    mean: float
    lower: float
    upper: float
    confidence: float = 0.95

    def __str__(self) -> str:
        """Render as ``mean [lower, upper]`` with three decimals each."""
        return "{:.3f} [{:.3f}, {:.3f}]".format(self.mean, self.lower, self.upper)

    def to_dict(self) -> dict[str, float]:
        """Return a serializable mapping with values rounded to 4 decimals."""
        rounded_fields = (
            ("mean", self.mean),
            ("ci_lower", self.lower),
            ("ci_upper", self.upper),
        )
        payload = {key: round(number, 4) for key, number in rounded_fields}
        # The confidence level is emitted exactly as stored, without rounding.
        payload["confidence"] = self.confidence
        return payload
577
+
578
+
579
  @dataclass
580
  class MetricsReport:
581
  """
 
596
  novelty: float = 0.0
597
  k: int = 10
598
 
599
+ # Bootstrap confidence intervals (optional)
600
+ ndcg_ci: ConfidenceInterval | None = None
601
+ hit_ci: ConfidenceInterval | None = None
602
+ mrr_ci: ConfidenceInterval | None = None
603
+
604
+ def to_dict(self) -> dict[str, Any]:
605
  """Convert to dictionary for easy serialization."""
606
+ result: dict[str, Any] = {
607
  "n_cases": self.n_cases,
608
  f"ndcg@{self.k}": round(self.ndcg_at_k, 4),
609
  f"hit@{self.k}": round(self.hit_at_k, 4),
 
614
  "coverage": round(self.coverage, 4),
615
  "novelty": round(self.novelty, 4),
616
  }
617
+ for name, ci in [
618
+ ("ndcg_ci", self.ndcg_ci),
619
+ ("hit_ci", self.hit_ci),
620
+ ("mrr_ci", self.mrr_ci),
621
+ ]:
622
+ if ci:
623
+ result[name] = ci.to_dict()
624
+ return result
625
+
626
+ def _fmt_metric(
627
+ self, name: str, value: float, ci: ConfidenceInterval | None
628
+ ) -> str:
629
+ """Format a metric with optional CI."""
630
+ if ci:
631
+ return f"{name:<14s} {value:.4f} [{ci.lower:.3f}, {ci.upper:.3f}]"
632
+ return f"{name:<14s} {value:.4f}"
633
 
634
  def __str__(self) -> str:
635
  lines = [
636
  f"Evaluation Results (n={self.n_cases}, k={self.k})",
637
+ "-" * 50,
638
+ self._fmt_metric(f"NDCG@{self.k}:", self.ndcg_at_k, self.ndcg_ci),
639
+ self._fmt_metric(f"Hit@{self.k}:", self.hit_at_k, self.hit_ci),
640
+ self._fmt_metric("MRR:", self.mrr, self.mrr_ci),
641
+ self._fmt_metric(f"Precision@{self.k}:", self.precision_at_k, None),
642
+ self._fmt_metric(f"Recall@{self.k}:", self.recall_at_k, None),
643
+ "-" * 50,
644
+ self._fmt_metric("Diversity:", self.diversity, None),
645
+ self._fmt_metric("Coverage:", self.coverage, None),
646
+ self._fmt_metric("Novelty:", self.novelty, None),
647
  ]
648
  return "\n".join(lines)
sage/services/baselines.py CHANGED
@@ -230,8 +230,6 @@ def load_product_embeddings_from_qdrant() -> dict[str, np.ndarray]:
230
  product_id = point.payload.get("product_id")
231
  product_vectors[product_id].append(np.array(point.vector))
232
 
233
- client.close()
234
-
235
  # Mean aggregation + normalize
236
  return {
237
  product_id: normalize_vectors(np.mean(vectors, axis=0))
@@ -265,8 +263,6 @@ def compute_item_popularity_from_qdrant(
265
  if point.payload.get("product_id")
266
  )
267
 
268
- client.close()
269
-
270
  if not normalize:
271
  return dict(counts)
272
 
 
230
  product_id = point.payload.get("product_id")
231
  product_vectors[product_id].append(np.array(point.vector))
232
 
 
 
233
  # Mean aggregation + normalize
234
  return {
235
  product_id: normalize_vectors(np.mean(vectors, axis=0))
 
263
  if point.payload.get("product_id")
264
  )
265
 
 
 
266
  if not normalize:
267
  return dict(counts)
268
 
sage/services/evaluation.py CHANGED
@@ -20,7 +20,7 @@ from typing import Callable
20
 
21
  import numpy as np
22
 
23
- from sage.core import EvalCase, EvalResult, MetricsReport
24
  from sage.utils import normalize_vectors
25
 
26
 
@@ -160,6 +160,19 @@ def compute_item_popularity(
160
  return {item: count / total for item, count in counts.items()}
161
 
162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  class EvaluationService:
164
  """
165
  Service for evaluating recommendation quality.
@@ -260,15 +273,16 @@ class EvaluationService:
260
  report = MetricsReport(
261
  n_cases=len(eval_cases),
262
  k=self.k,
263
- ndcg_at_k=float(np.mean(ndcg_scores)) if ndcg_scores else 0.0,
264
- hit_at_k=float(np.mean(hit_scores)) if hit_scores else 0.0,
265
- mrr=float(np.mean(mrr_scores)) if mrr_scores else 0.0,
266
- precision_at_k=float(np.mean(precision_scores))
267
- if precision_scores
268
- else 0.0,
269
- recall_at_k=float(np.mean(recall_scores)) if recall_scores else 0.0,
270
- diversity=float(np.mean(diversity_scores)) if diversity_scores else 0.0,
271
- novelty=float(np.mean(novelty_scores)) if novelty_scores else 0.0,
 
272
  )
273
 
274
  if self.total_items:
 
20
 
21
  import numpy as np
22
 
23
+ from sage.core import ConfidenceInterval, EvalCase, EvalResult, MetricsReport
24
  from sage.utils import normalize_vectors
25
 
26
 
 
160
  return {item: count / total for item, count in counts.items()}
161
 
162
 
163
+ def _safe_mean(scores: list[float]) -> float:
164
+ """Compute mean of scores, returning 0.0 for empty list."""
165
+ return float(np.mean(scores)) if scores else 0.0
166
+
167
+
168
+ def _compute_ci(scores: list[float]) -> ConfidenceInterval | None:
169
+ """Compute bootstrap CI for scores, returning None for empty list."""
170
+ if not scores:
171
+ return None
172
+ mean, lower, upper = bootstrap_confidence_interval(scores)
173
+ return ConfidenceInterval(mean=mean, lower=lower, upper=upper)
174
+
175
+
176
  class EvaluationService:
177
  """
178
  Service for evaluating recommendation quality.
 
273
  report = MetricsReport(
274
  n_cases=len(eval_cases),
275
  k=self.k,
276
+ ndcg_at_k=_safe_mean(ndcg_scores),
277
+ hit_at_k=_safe_mean(hit_scores),
278
+ mrr=_safe_mean(mrr_scores),
279
+ precision_at_k=_safe_mean(precision_scores),
280
+ recall_at_k=_safe_mean(recall_scores),
281
+ diversity=_safe_mean(diversity_scores),
282
+ novelty=_safe_mean(novelty_scores),
283
+ ndcg_ci=_compute_ci(ndcg_scores),
284
+ hit_ci=_compute_ci(hit_scores),
285
+ mrr_ci=_compute_ci(mrr_scores),
286
  )
287
 
288
  if self.total_items:
scripts/evaluation.py CHANGED
@@ -87,7 +87,7 @@ def run_primary_evaluation(cases, item_embeddings, item_popularity, total_items)
87
  )
88
  logger.info(str(report))
89
 
90
- return {
91
  "ndcg_at_10": report.ndcg_at_k,
92
  "hit_at_10": report.hit_at_k,
93
  "mrr": report.mrr,
@@ -98,6 +98,17 @@ def run_primary_evaluation(cases, item_embeddings, item_popularity, total_items)
98
  "novelty": report.novelty,
99
  }
100
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
  # ============================================================================
103
  # SECTION: Aggregation Methods
 
87
  )
88
  logger.info(str(report))
89
 
90
+ result = {
91
  "ndcg_at_10": report.ndcg_at_k,
92
  "hit_at_10": report.hit_at_k,
93
  "mrr": report.mrr,
 
98
  "novelty": report.novelty,
99
  }
100
 
101
+ # Add confidence intervals if available
102
+ for name, ci in [
103
+ ("ndcg_ci", report.ndcg_ci),
104
+ ("hit_ci", report.hit_ci),
105
+ ("mrr_ci", report.mrr_ci),
106
+ ]:
107
+ if ci:
108
+ result[name] = ci.to_dict()
109
+
110
+ return result
111
+
112
 
113
  # ============================================================================
114
  # SECTION: Aggregation Methods
scripts/summary.py CHANGED
@@ -42,6 +42,17 @@ def fmt(value: float | None, decimals: int = 4) -> str:
42
  return f"{value:.{decimals}f}"
43
 
44
 
 
 
 
 
 
 
 
 
 
 
 
45
  def print_section(title: str):
46
  print(f"\n{title}")
47
 
@@ -56,9 +67,12 @@ def main():
56
  print_section("Recommendation Quality (Natural Queries):")
57
  if nat and "primary_metrics" in nat:
58
  m = nat["primary_metrics"]
59
- print(f" NDCG@10: {fmt(m.get('ndcg_at_10'))}")
60
- print(f" Hit@10: {fmt(m.get('hit_at_10'))}")
61
- print(f" MRR: {fmt(m.get('mrr'))}")
 
 
 
62
  else:
63
  print(" (not available)")
64
 
 
42
  return f"{value:.{decimals}f}"
43
 
44
 
45
def fmt_with_ci(value: float | None, ci: dict | None, decimals: int = 3) -> str:
    """Format a value, appending ``[lower, upper]`` when a usable CI is given.

    Args:
        value: The metric value, or ``None`` when it is unavailable.
        ci: Optional mapping holding ``"ci_lower"`` and ``"ci_upper"``.
        decimals: Number of decimal places for the value and both bounds.

    Returns:
        ``" ---"`` when ``value`` is ``None``; otherwise the formatted value,
        followed by the bracketed bounds only if both can be formatted.
    """
    if value is None:
        return " ---"
    text = f"{value:.{decimals}f}"
    try:
        bounds = f" [{ci['ci_lower']:.{decimals}f}, {ci['ci_upper']:.{decimals}f}]"
    except (KeyError, TypeError, ValueError):
        # Missing CI, missing key, or a null/non-numeric bound in the results
        # file: show the plain value instead of aborting the whole summary.
        bounds = ""
    return text + bounds
54
+
55
+
56
  def print_section(title: str):
57
  print(f"\n{title}")
58
 
 
67
  print_section("Recommendation Quality (Natural Queries):")
68
  if nat and "primary_metrics" in nat:
69
  m = nat["primary_metrics"]
70
+ for label, key, ci_key in [
71
+ ("NDCG@10", "ndcg_at_10", "ndcg_ci"),
72
+ ("Hit@10", "hit_at_10", "hit_ci"),
73
+ ("MRR", "mrr", "mrr_ci"),
74
+ ]:
75
+ print(f" {label + ':':<10s} {fmt_with_ci(m.get(key), m.get(ci_key))}")
76
  else:
77
  print(" (not available)")
78