Add bootstrap confidence intervals to evaluation metrics
Browse files- data/eval_results/eval_natural_queries_20260305_161900_824897.json +34 -0
- data/eval_results/eval_natural_queries_latest.json +0 -16
- data/eval_results/eval_natural_queries_latest.json +1 -0
- sage/core/__init__.py +2 -0
- sage/core/models.py +55 -13
- sage/services/baselines.py +0 -4
- sage/services/evaluation.py +24 -10
- scripts/evaluation.py +12 -1
- scripts/summary.py +17 -3
data/eval_results/eval_natural_queries_20260305_161900_824897.json
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"timestamp": "2026-03-05T16:18:49.380747",
|
| 3 |
+
"dataset": "eval_natural_queries.json",
|
| 4 |
+
"catalog_size": 21827,
|
| 5 |
+
"experiments": {},
|
| 6 |
+
"primary_metrics": {
|
| 7 |
+
"ndcg_at_10": 0.4871922222425982,
|
| 8 |
+
"hit_at_10": 0.7380952380952381,
|
| 9 |
+
"mrr": 0.42086167800453517,
|
| 10 |
+
"precision_at_10": 0.12857142857142856,
|
| 11 |
+
"recall_at_10": 0.4722222222222222,
|
| 12 |
+
"diversity": 0.01957190520646696,
|
| 13 |
+
"coverage": 0.015531222797452697,
|
| 14 |
+
"novelty": 9.808908578271737,
|
| 15 |
+
"ndcg_ci": {
|
| 16 |
+
"mean": 0.4872,
|
| 17 |
+
"ci_lower": 0.3779,
|
| 18 |
+
"ci_upper": 0.6078,
|
| 19 |
+
"confidence": 0.95
|
| 20 |
+
},
|
| 21 |
+
"hit_ci": {
|
| 22 |
+
"mean": 0.7381,
|
| 23 |
+
"ci_lower": 0.5952,
|
| 24 |
+
"ci_upper": 0.8571,
|
| 25 |
+
"confidence": 0.95
|
| 26 |
+
},
|
| 27 |
+
"mrr_ci": {
|
| 28 |
+
"mean": 0.4209,
|
| 29 |
+
"ci_lower": 0.301,
|
| 30 |
+
"ci_upper": 0.5545,
|
| 31 |
+
"confidence": 0.95
|
| 32 |
+
}
|
| 33 |
+
}
|
| 34 |
+
}
|
data/eval_results/eval_natural_queries_latest.json
DELETED
|
@@ -1,16 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"timestamp": "2026-02-10T11:49:08.500849",
|
| 3 |
-
"dataset": "eval_natural_queries.json",
|
| 4 |
-
"catalog_size": 21827,
|
| 5 |
-
"experiments": {},
|
| 6 |
-
"primary_metrics": {
|
| 7 |
-
"ndcg_at_10": 0.4871922222425982,
|
| 8 |
-
"hit_at_10": 0.7380952380952381,
|
| 9 |
-
"mrr": 0.42086167800453517,
|
| 10 |
-
"precision_at_10": 0.12857142857142856,
|
| 11 |
-
"recall_at_10": 0.4722222222222222,
|
| 12 |
-
"diversity": 0.01957190520646696,
|
| 13 |
-
"coverage": 0.015531222797452697,
|
| 14 |
-
"novelty": 9.808908578271737
|
| 15 |
-
}
|
| 16 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data/eval_results/eval_natural_queries_latest.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
eval_natural_queries_20260305_161900_824897.json
|
sage/core/__init__.py
CHANGED
|
@@ -37,6 +37,7 @@ from sage.core.models import (
|
|
| 37 |
FaithfulnessReport,
|
| 38 |
FaithfulnessResult,
|
| 39 |
# Evaluation
|
|
|
|
| 40 |
EvalCase,
|
| 41 |
EvalResult,
|
| 42 |
MetricsReport,
|
|
@@ -110,6 +111,7 @@ __all__ = [
|
|
| 110 |
"MultiMetricFaithfulnessReport",
|
| 111 |
"FaithfulnessReport",
|
| 112 |
"FaithfulnessResult",
|
|
|
|
| 113 |
"EvalCase",
|
| 114 |
"EvalResult",
|
| 115 |
"MetricsReport",
|
|
|
|
| 37 |
FaithfulnessReport,
|
| 38 |
FaithfulnessResult,
|
| 39 |
# Evaluation
|
| 40 |
+
ConfidenceInterval,
|
| 41 |
EvalCase,
|
| 42 |
EvalResult,
|
| 43 |
MetricsReport,
|
|
|
|
| 111 |
"MultiMetricFaithfulnessReport",
|
| 112 |
"FaithfulnessReport",
|
| 113 |
"FaithfulnessResult",
|
| 114 |
+
"ConfidenceInterval",
|
| 115 |
"EvalCase",
|
| 116 |
"EvalResult",
|
| 117 |
"MetricsReport",
|
sage/core/models.py
CHANGED
|
@@ -13,7 +13,7 @@ from __future__ import annotations
|
|
| 13 |
|
| 14 |
from dataclasses import dataclass, field
|
| 15 |
from enum import Enum
|
| 16 |
-
from typing import Iterator
|
| 17 |
|
| 18 |
|
| 19 |
# ============================================================================
|
|
@@ -555,6 +555,27 @@ class EvalResult:
|
|
| 555 |
recall: float = 0.0
|
| 556 |
|
| 557 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 558 |
@dataclass
|
| 559 |
class MetricsReport:
|
| 560 |
"""
|
|
@@ -575,9 +596,14 @@ class MetricsReport:
|
|
| 575 |
novelty: float = 0.0
|
| 576 |
k: int = 10
|
| 577 |
|
| 578 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 579 |
"""Convert to dictionary for easy serialization."""
|
| 580 |
-
|
| 581 |
"n_cases": self.n_cases,
|
| 582 |
f"ndcg@{self.k}": round(self.ndcg_at_k, 4),
|
| 583 |
f"hit@{self.k}": round(self.hit_at_k, 4),
|
|
@@ -588,19 +614,35 @@ class MetricsReport:
|
|
| 588 |
"coverage": round(self.coverage, 4),
|
| 589 |
"novelty": round(self.novelty, 4),
|
| 590 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 591 |
|
| 592 |
def __str__(self) -> str:
|
| 593 |
lines = [
|
| 594 |
f"Evaluation Results (n={self.n_cases}, k={self.k})",
|
| 595 |
-
"-" *
|
| 596 |
-
f"NDCG@{self.k}:
|
| 597 |
-
f"Hit@{self.k}:
|
| 598 |
-
|
| 599 |
-
f"Precision@{self.k}:
|
| 600 |
-
f"Recall@{self.k}:
|
| 601 |
-
"-" *
|
| 602 |
-
|
| 603 |
-
|
| 604 |
-
|
| 605 |
]
|
| 606 |
return "\n".join(lines)
|
|
|
|
| 13 |
|
| 14 |
from dataclasses import dataclass, field
|
| 15 |
from enum import Enum
|
| 16 |
+
from typing import Any, Iterator
|
| 17 |
|
| 18 |
|
| 19 |
# ============================================================================
|
|
|
|
| 555 |
recall: float = 0.0
|
| 556 |
|
| 557 |
|
| 558 |
+
@dataclass
class ConfidenceInterval:
    """Bootstrap confidence interval for a metric.

    Holds the point estimate (``mean``), the interval bounds (``lower`` and
    ``upper``) and the ``confidence`` value, which defaults to 0.95.
    """

    mean: float
    lower: float
    upper: float
    confidence: float = 0.95

    def __str__(self) -> str:
        """Render as ``mean [lower, upper]`` with three decimals each."""
        bounds = f"[{self.lower:.3f}, {self.upper:.3f}]"
        return f"{self.mean:.3f} {bounds}"

    def to_dict(self) -> dict[str, float]:
        """Return a serializable mapping of the interval.

        The mean and both bounds are rounded to four decimal places and the
        bounds are stored under ``ci_lower`` / ``ci_upper``; ``confidence``
        is passed through unrounded.
        """
        payload: dict[str, float] = {
            key: round(estimate, 4)
            for key, estimate in (
                ("mean", self.mean),
                ("ci_lower", self.lower),
                ("ci_upper", self.upper),
            )
        }
        payload["confidence"] = self.confidence
        return payload
|
| 577 |
+
|
| 578 |
+
|
| 579 |
@dataclass
|
| 580 |
class MetricsReport:
|
| 581 |
"""
|
|
|
|
| 596 |
novelty: float = 0.0
|
| 597 |
k: int = 10
|
| 598 |
|
| 599 |
+
# Bootstrap confidence intervals (optional)
|
| 600 |
+
ndcg_ci: ConfidenceInterval | None = None
|
| 601 |
+
hit_ci: ConfidenceInterval | None = None
|
| 602 |
+
mrr_ci: ConfidenceInterval | None = None
|
| 603 |
+
|
| 604 |
+
def to_dict(self) -> dict[str, Any]:
|
| 605 |
"""Convert to dictionary for easy serialization."""
|
| 606 |
+
result: dict[str, Any] = {
|
| 607 |
"n_cases": self.n_cases,
|
| 608 |
f"ndcg@{self.k}": round(self.ndcg_at_k, 4),
|
| 609 |
f"hit@{self.k}": round(self.hit_at_k, 4),
|
|
|
|
| 614 |
"coverage": round(self.coverage, 4),
|
| 615 |
"novelty": round(self.novelty, 4),
|
| 616 |
}
|
| 617 |
+
for name, ci in [
|
| 618 |
+
("ndcg_ci", self.ndcg_ci),
|
| 619 |
+
("hit_ci", self.hit_ci),
|
| 620 |
+
("mrr_ci", self.mrr_ci),
|
| 621 |
+
]:
|
| 622 |
+
if ci:
|
| 623 |
+
result[name] = ci.to_dict()
|
| 624 |
+
return result
|
| 625 |
+
|
| 626 |
+
def _fmt_metric(
|
| 627 |
+
self, name: str, value: float, ci: ConfidenceInterval | None
|
| 628 |
+
) -> str:
|
| 629 |
+
"""Format a metric with optional CI."""
|
| 630 |
+
if ci:
|
| 631 |
+
return f"{name:<14s} {value:.4f} [{ci.lower:.3f}, {ci.upper:.3f}]"
|
| 632 |
+
return f"{name:<14s} {value:.4f}"
|
| 633 |
|
| 634 |
def __str__(self) -> str:
|
| 635 |
lines = [
|
| 636 |
f"Evaluation Results (n={self.n_cases}, k={self.k})",
|
| 637 |
+
"-" * 50,
|
| 638 |
+
self._fmt_metric(f"NDCG@{self.k}:", self.ndcg_at_k, self.ndcg_ci),
|
| 639 |
+
self._fmt_metric(f"Hit@{self.k}:", self.hit_at_k, self.hit_ci),
|
| 640 |
+
self._fmt_metric("MRR:", self.mrr, self.mrr_ci),
|
| 641 |
+
self._fmt_metric(f"Precision@{self.k}:", self.precision_at_k, None),
|
| 642 |
+
self._fmt_metric(f"Recall@{self.k}:", self.recall_at_k, None),
|
| 643 |
+
"-" * 50,
|
| 644 |
+
self._fmt_metric("Diversity:", self.diversity, None),
|
| 645 |
+
self._fmt_metric("Coverage:", self.coverage, None),
|
| 646 |
+
self._fmt_metric("Novelty:", self.novelty, None),
|
| 647 |
]
|
| 648 |
return "\n".join(lines)
|
sage/services/baselines.py
CHANGED
|
@@ -230,8 +230,6 @@ def load_product_embeddings_from_qdrant() -> dict[str, np.ndarray]:
|
|
| 230 |
product_id = point.payload.get("product_id")
|
| 231 |
product_vectors[product_id].append(np.array(point.vector))
|
| 232 |
|
| 233 |
-
client.close()
|
| 234 |
-
|
| 235 |
# Mean aggregation + normalize
|
| 236 |
return {
|
| 237 |
product_id: normalize_vectors(np.mean(vectors, axis=0))
|
|
@@ -265,8 +263,6 @@ def compute_item_popularity_from_qdrant(
|
|
| 265 |
if point.payload.get("product_id")
|
| 266 |
)
|
| 267 |
|
| 268 |
-
client.close()
|
| 269 |
-
|
| 270 |
if not normalize:
|
| 271 |
return dict(counts)
|
| 272 |
|
|
|
|
| 230 |
product_id = point.payload.get("product_id")
|
| 231 |
product_vectors[product_id].append(np.array(point.vector))
|
| 232 |
|
|
|
|
|
|
|
| 233 |
# Mean aggregation + normalize
|
| 234 |
return {
|
| 235 |
product_id: normalize_vectors(np.mean(vectors, axis=0))
|
|
|
|
| 263 |
if point.payload.get("product_id")
|
| 264 |
)
|
| 265 |
|
|
|
|
|
|
|
| 266 |
if not normalize:
|
| 267 |
return dict(counts)
|
| 268 |
|
sage/services/evaluation.py
CHANGED
|
@@ -20,7 +20,7 @@ from typing import Callable
|
|
| 20 |
|
| 21 |
import numpy as np
|
| 22 |
|
| 23 |
-
from sage.core import EvalCase, EvalResult, MetricsReport
|
| 24 |
from sage.utils import normalize_vectors
|
| 25 |
|
| 26 |
|
|
@@ -160,6 +160,19 @@ def compute_item_popularity(
|
|
| 160 |
return {item: count / total for item, count in counts.items()}
|
| 161 |
|
| 162 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
class EvaluationService:
|
| 164 |
"""
|
| 165 |
Service for evaluating recommendation quality.
|
|
@@ -260,15 +273,16 @@ class EvaluationService:
|
|
| 260 |
report = MetricsReport(
|
| 261 |
n_cases=len(eval_cases),
|
| 262 |
k=self.k,
|
| 263 |
-
ndcg_at_k=
|
| 264 |
-
hit_at_k=
|
| 265 |
-
mrr=
|
| 266 |
-
precision_at_k=
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
|
|
|
| 272 |
)
|
| 273 |
|
| 274 |
if self.total_items:
|
|
|
|
| 20 |
|
| 21 |
import numpy as np
|
| 22 |
|
| 23 |
+
from sage.core import ConfidenceInterval, EvalCase, EvalResult, MetricsReport
|
| 24 |
from sage.utils import normalize_vectors
|
| 25 |
|
| 26 |
|
|
|
|
| 160 |
return {item: count / total for item, count in counts.items()}
|
| 161 |
|
| 162 |
|
| 163 |
+
def _safe_mean(scores: list[float]) -> float:
    """Average ``scores`` as a plain ``float``; an empty list yields 0.0."""
    # Guard first: np.mean of an empty sequence would produce NaN with a warning.
    if not scores:
        return 0.0
    return float(np.mean(scores))
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def _compute_ci(scores: list[float]) -> ConfidenceInterval | None:
    """Wrap the bootstrap interval of ``scores`` in a ``ConfidenceInterval``.

    Returns ``None`` when ``scores`` is empty. Otherwise the three values
    returned by ``bootstrap_confidence_interval`` are taken, in order, as the
    mean, lower bound and upper bound.
    """
    if scores:
        point, low, high = bootstrap_confidence_interval(scores)
        return ConfidenceInterval(mean=point, lower=low, upper=high)
    return None
|
| 174 |
+
|
| 175 |
+
|
| 176 |
class EvaluationService:
|
| 177 |
"""
|
| 178 |
Service for evaluating recommendation quality.
|
|
|
|
| 273 |
report = MetricsReport(
|
| 274 |
n_cases=len(eval_cases),
|
| 275 |
k=self.k,
|
| 276 |
+
ndcg_at_k=_safe_mean(ndcg_scores),
|
| 277 |
+
hit_at_k=_safe_mean(hit_scores),
|
| 278 |
+
mrr=_safe_mean(mrr_scores),
|
| 279 |
+
precision_at_k=_safe_mean(precision_scores),
|
| 280 |
+
recall_at_k=_safe_mean(recall_scores),
|
| 281 |
+
diversity=_safe_mean(diversity_scores),
|
| 282 |
+
novelty=_safe_mean(novelty_scores),
|
| 283 |
+
ndcg_ci=_compute_ci(ndcg_scores),
|
| 284 |
+
hit_ci=_compute_ci(hit_scores),
|
| 285 |
+
mrr_ci=_compute_ci(mrr_scores),
|
| 286 |
)
|
| 287 |
|
| 288 |
if self.total_items:
|
scripts/evaluation.py
CHANGED
|
@@ -87,7 +87,7 @@ def run_primary_evaluation(cases, item_embeddings, item_popularity, total_items)
|
|
| 87 |
)
|
| 88 |
logger.info(str(report))
|
| 89 |
|
| 90 |
-
|
| 91 |
"ndcg_at_10": report.ndcg_at_k,
|
| 92 |
"hit_at_10": report.hit_at_k,
|
| 93 |
"mrr": report.mrr,
|
|
@@ -98,6 +98,17 @@ def run_primary_evaluation(cases, item_embeddings, item_popularity, total_items)
|
|
| 98 |
"novelty": report.novelty,
|
| 99 |
}
|
| 100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
# ============================================================================
|
| 103 |
# SECTION: Aggregation Methods
|
|
|
|
| 87 |
)
|
| 88 |
logger.info(str(report))
|
| 89 |
|
| 90 |
+
result = {
|
| 91 |
"ndcg_at_10": report.ndcg_at_k,
|
| 92 |
"hit_at_10": report.hit_at_k,
|
| 93 |
"mrr": report.mrr,
|
|
|
|
| 98 |
"novelty": report.novelty,
|
| 99 |
}
|
| 100 |
|
| 101 |
+
# Add confidence intervals if available
|
| 102 |
+
for name, ci in [
|
| 103 |
+
("ndcg_ci", report.ndcg_ci),
|
| 104 |
+
("hit_ci", report.hit_ci),
|
| 105 |
+
("mrr_ci", report.mrr_ci),
|
| 106 |
+
]:
|
| 107 |
+
if ci:
|
| 108 |
+
result[name] = ci.to_dict()
|
| 109 |
+
|
| 110 |
+
return result
|
| 111 |
+
|
| 112 |
|
| 113 |
# ============================================================================
|
| 114 |
# SECTION: Aggregation Methods
|
scripts/summary.py
CHANGED
|
@@ -42,6 +42,17 @@ def fmt(value: float | None, decimals: int = 4) -> str:
|
|
| 42 |
return f"{value:.{decimals}f}"
|
| 43 |
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
def print_section(title: str):
|
| 46 |
print(f"\n{title}")
|
| 47 |
|
|
@@ -56,9 +67,12 @@ def main():
|
|
| 56 |
print_section("Recommendation Quality (Natural Queries):")
|
| 57 |
if nat and "primary_metrics" in nat:
|
| 58 |
m = nat["primary_metrics"]
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
|
|
|
|
|
|
|
|
|
| 62 |
else:
|
| 63 |
print(" (not available)")
|
| 64 |
|
|
|
|
| 42 |
return f"{value:.{decimals}f}"
|
| 43 |
|
| 44 |
|
| 45 |
+
def fmt_with_ci(value: float | None, ci: dict | None, decimals: int = 3) -> str:
|
| 46 |
+
"""Format a value with optional confidence interval."""
|
| 47 |
+
if value is None:
|
| 48 |
+
return " ---"
|
| 49 |
+
if ci and "ci_lower" in ci and "ci_upper" in ci:
|
| 50 |
+
lower = ci["ci_lower"]
|
| 51 |
+
upper = ci["ci_upper"]
|
| 52 |
+
return f"{value:.{decimals}f} [{lower:.{decimals}f}, {upper:.{decimals}f}]"
|
| 53 |
+
return f"{value:.{decimals}f}"
|
| 54 |
+
|
| 55 |
+
|
| 56 |
def print_section(title: str):
|
| 57 |
print(f"\n{title}")
|
| 58 |
|
|
|
|
| 67 |
print_section("Recommendation Quality (Natural Queries):")
|
| 68 |
if nat and "primary_metrics" in nat:
|
| 69 |
m = nat["primary_metrics"]
|
| 70 |
+
for label, key, ci_key in [
|
| 71 |
+
("NDCG@10", "ndcg_at_10", "ndcg_ci"),
|
| 72 |
+
("Hit@10", "hit_at_10", "hit_ci"),
|
| 73 |
+
("MRR", "mrr", "mrr_ci"),
|
| 74 |
+
]:
|
| 75 |
+
print(f" {label + ':':<10s} {fmt_with_ci(m.get(key), m.get(ci_key))}")
|
| 76 |
else:
|
| 77 |
print(" (not available)")
|
| 78 |
|