Spaces:
Running
Running
| import hashlib | |
| import json | |
| from datetime import datetime, timedelta | |
| from typing import Any | |
| import pandas as pd | |
| from src.storage import load_leaderboard, save_leaderboard | |
# All available metric columns (computed): every metric family at every
# rank cutoff, Recall columns first, then NDCG columns.
ALL_METRIC_COLS = [
    f"{family}@{k}"
    for family in ("Recall", "NDCG")
    for k in (1, 5, 10, 20, 50, 100)
]

# Default columns shown on leaderboard (a subset of ALL_METRIC_COLS).
DEFAULT_DISPLAY_METRICS = [
    f"{family}@{k}"
    for family in ("Recall", "NDCG")
    for k in (1, 5, 20, 50)
]

# Base columns always shown, ahead of the metric columns.
BASE_COLS = ["rank", "model_name"]

# Metric used for default ordering, best-score comparison and top-N retention.
_DEFAULT_SORT = "Recall@10"
# Number of rows shown by default and number of top entries exempt from expiry.
_TOP_N = 30
# Age in days after which a non-paper, non-top entry is removed.
_RETENTION_DAYS = 30
def make_id(email: str, model_name: str) -> str:
    """Return a 16-hex-character identifier for an (email, model_name) pair.

    The identifier is the first 16 characters of the SHA-256 hex digest of
    ``"<email>:<model_name>"``, so the same pair always yields the same id.
    """
    payload = "{}:{}".format(email, model_name).encode()
    full_digest = hashlib.sha256(payload).hexdigest()
    return full_digest[:16]
class LeaderboardManager:
    """Keep the leaderboard entries in memory and persist every change.

    Entries are plain dicts read with ``load_leaderboard`` and written back
    with ``save_leaderboard``. On construction the stored entries are loaded
    and expired ones are pruned (see ``_cleanup``).
    """

    def __init__(self):
        self._entries: list[dict] = []
        self._load()
        self._cleanup()

    def _load(self) -> None:
        """Replace the in-memory entries with whatever storage returns."""
        raw = load_leaderboard()
        # Tolerate a storage layer that returns None for "nothing stored yet".
        self._entries = raw if raw is not None else []

    def _save(self) -> None:
        """Write the current in-memory entries back to storage."""
        save_leaderboard(self._entries)

    @staticmethod
    def _parse_timestamp(ts_str) -> datetime:
        """Parse a stored ISO timestamp into a naive UTC ``datetime``.

        Timestamps written by ``add_result`` end in ``"Z"``; once that is
        rewritten to ``+00:00`` the parsed value is timezone-aware, and
        comparing it with the naive cutoff from ``datetime.utcnow()`` would
        raise ``TypeError``. Aware values are therefore shifted to UTC and
        stripped of their tzinfo. Missing or unparseable values fall back to
        "now", so such entries are never treated as expired.
        """
        try:
            ts = datetime.fromisoformat(ts_str.replace("Z", "+00:00"))
            offset = ts.utcoffset()
            if offset is not None:
                ts = (ts - offset).replace(tzinfo=None)
        except (AttributeError, TypeError, ValueError, OverflowError):
            return datetime.utcnow()
        return ts

    def _top_submission_ids(self) -> set:
        """Return the submission ids of the ``_TOP_N`` best entries.

        Entries are ranked by ``_DEFAULT_SORT`` in descending order. The
        result is empty when no entry carries that metric or no entry has a
        ``submission_id``.
        """
        df = pd.DataFrame(self._entries)
        if _DEFAULT_SORT not in df.columns or "submission_id" not in df.columns:
            return set()
        ranked = df.sort_values(by=_DEFAULT_SORT, ascending=False)
        return set(ranked.head(_TOP_N)["submission_id"].tolist())

    def _cleanup(self) -> None:
        """Drop expired entries and persist the result if anything was removed.

        An entry is kept when it is paper data, when it is among the top
        ``_TOP_N`` entries by ``_DEFAULT_SORT``, or when its timestamp is no
        older than ``_RETENTION_DAYS`` days.
        """
        if not self._entries:
            return

        top_ids = self._top_submission_ids()
        cutoff = datetime.utcnow() - timedelta(days=_RETENTION_DAYS)

        kept = [
            e
            for e in self._entries
            if e.get("is_paper_data", False)
            or e.get("submission_id", "") in top_ids
            or self._parse_timestamp(e.get("timestamp", "")) >= cutoff
        ]

        removed = len(self._entries) - len(kept)
        if removed > 0:
            print(f"[CLEANUP] Removed {removed} expired entries")
            self._entries = kept
            self._save()

    def add_result(
        self,
        email: str,
        method: str,
        model_name: str,
        albums: list[str],
        evaluated_queries: int,
        total_gt_queries: int,
        global_metrics: dict,
    ) -> dict | None:
        """Add a new evaluation result.

        Only full submissions are eligible: ``albums`` must be exactly
        ``{"1", "2", "3"}`` and every ground-truth query must have been
        evaluated. One entry is kept per ``(email, model_name)``; an existing
        entry is replaced only when the new ``_DEFAULT_SORT`` score is at
        least as high as the stored one.

        Returns:
            The newly built entry for an eligible submission (even when an
            existing, better-scoring entry was kept in its place), or
            ``None`` when the submission is not eligible.
        """
        # Must be a full submission (all 3 albums, all queries matched)
        if set(albums) != {"1", "2", "3"}:
            return None
        if evaluated_queries < total_gt_queries:
            return None

        # Only known metric columns (plus the aggregate Recall/NDCG) are stored.
        metrics = {
            k: round(v, 4)
            for k, v in global_metrics.items()
            if k in ALL_METRIC_COLS or k in ("Recall", "NDCG")
        }
        entry = {
            "submission_id": make_id(email, model_name),
            "timestamp": datetime.utcnow().isoformat() + "Z",
            "email": email,
            "method": method,
            "model_name": model_name,
            "albums": ",".join(albums),
            "is_paper_data": False,
            **metrics,
        }

        # Keep best score per (email, model_name)
        key = (email, model_name)
        existing_idx = next(
            (
                i
                for i, e in enumerate(self._entries)
                if (e.get("email"), e.get("model_name")) == key
            ),
            None,
        )

        if existing_idx is None:
            self._entries.append(entry)
        else:
            old = self._entries[existing_idx]
            if global_metrics.get(_DEFAULT_SORT, 0) >= old.get(_DEFAULT_SORT, 0):
                self._entries[existing_idx] = entry

        self._save()
        return entry

    def get_display_df(
        self,
        method_filter: str | None = None,
        sort_by: str = _DEFAULT_SORT,
        ascending: bool = False,
        top_n: int = _TOP_N,
        metric_cols: list[str] | None = None,
    ) -> pd.DataFrame:
        """Return a pandas DataFrame ready for gr.DataFrame.

        Args:
            method_filter: Keep only entries whose ``method`` equals this
                value; ``None``, empty or ``"All"`` disables the filter.
            sort_by: Column to order by; falls back to ``_DEFAULT_SORT`` when
                the column is absent, and the stored order is kept when that
                column is absent too.
            ascending: Sort direction.
            top_n: Maximum number of rows returned.
            metric_cols: Metric columns to show after ``BASE_COLS``; defaults
                to ``DEFAULT_DISPLAY_METRICS``.

        Returns:
            A frame with a 1-based ``rank`` column and whichever requested
            columns exist in the data.
        """
        cols_to_show = BASE_COLS + (metric_cols or DEFAULT_DISPLAY_METRICS)

        if not self._entries:
            return pd.DataFrame(columns=cols_to_show)

        df = pd.DataFrame(self._entries)

        if method_filter and method_filter != "All":
            if "method" in df.columns:
                df = df[df["method"] == method_filter]
            else:
                # No entry records a method, so none can match the filter.
                df = df.iloc[0:0]

        if sort_by not in df.columns:
            sort_by = _DEFAULT_SORT
        # The fallback column can be missing as well (e.g. entries without
        # that metric); sorting by it would raise KeyError.
        if sort_by in df.columns:
            df = df.sort_values(by=sort_by, ascending=ascending)

        df = df.head(top_n).reset_index(drop=True)
        df["rank"] = df.index + 1

        available = [c for c in cols_to_show if c in df.columns]
        return df[available]

    def remove_entry(self, submission_id: str) -> bool:
        """Remove an entry by submission_id. Returns True if removed."""
        remaining = [
            e for e in self._entries if e.get("submission_id") != submission_id
        ]
        if len(remaining) == len(self._entries):
            return False
        self._entries = remaining
        self._save()
        return True