PhotoBench-Protected / src /leaderboard_manager.py
SorrowTea's picture
Initial PhotoBench-Protected Leaderboard
01f4cb5
import hashlib
import json
from datetime import datetime, timedelta, timezone
from typing import Any

import pandas as pd

from src.storage import load_leaderboard, save_leaderboard
# All available metric columns (computed).
# LeaderboardManager.add_result uses this list (plus the aggregate keys
# "Recall" and "NDCG") as the allow-list of metric keys copied into an entry.
ALL_METRIC_COLS = [
"Recall@1", "Recall@5", "Recall@10", "Recall@20", "Recall@50", "Recall@100",
"NDCG@1", "NDCG@5", "NDCG@10", "NDCG@20", "NDCG@50", "NDCG@100",
]
# Default metric columns shown on the leaderboard; used by
# LeaderboardManager.get_display_df when the caller passes no metric_cols.
DEFAULT_DISPLAY_METRICS = [
"Recall@1", "Recall@5", "Recall@20", "Recall@50",
"NDCG@1", "NDCG@5", "NDCG@20", "NDCG@50",
]
# Base columns always shown, placed before the metric columns.
BASE_COLS = ["rank", "model_name"]
# Metric used to rank entries: default sort key for display, the score compared
# when a resubmission competes with an existing entry, and the key that decides
# which entries count as "top" during cleanup.
# NOTE(review): "Recall@10" is not in DEFAULT_DISPLAY_METRICS, so the default
# view is ordered by a column it does not display -- confirm this is intended.
_DEFAULT_SORT = "Recall@10"
# Number of best entries protected from cleanup; also the default row limit of
# get_display_df.
_TOP_N = 30
# Age in days after which an entry becomes eligible for removal by cleanup.
_RETENTION_DAYS = 30
def make_id(email: str, model_name: str) -> str:
    """Return a short, deterministic identifier for an (email, model) pair.

    The two values are joined with a colon, hashed with SHA-256, and the
    first 16 hexadecimal characters of the digest are returned.
    """
    payload = f"{email}:{model_name}".encode()
    hex_digest = hashlib.sha256(payload).hexdigest()
    return hex_digest[:16]
class LeaderboardManager:
    """In-memory leaderboard persisted through ``src.storage``.

    Entries are plain dicts. They are read with ``load_leaderboard`` when the
    manager is constructed and written back with ``save_leaderboard`` whenever
    the list of entries changes.
    """

    def __init__(self):
        """Load the stored entries and drop the ones that have expired."""
        self._entries: list[dict] = []
        self._load()
        self._cleanup()

    @staticmethod
    def _utc_now() -> datetime:
        """Return the current time as a timezone-aware UTC datetime."""
        return datetime.now(timezone.utc)

    @staticmethod
    def _parse_timestamp(value: Any, fallback: datetime) -> datetime:
        """Parse an ISO-8601 timestamp into a timezone-aware datetime.

        A trailing ``Z`` is accepted. A timestamp without an offset is
        interpreted as UTC. ``fallback`` is returned when ``value`` is not a
        parseable string.
        """
        try:
            # "Z" is rewritten because datetime.fromisoformat only accepts it
            # from Python 3.11 onwards.
            parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
        except (AttributeError, TypeError, ValueError):
            return fallback
        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=timezone.utc)
        return parsed

    def _load(self):
        """Replace the in-memory entries with the stored leaderboard."""
        raw = load_leaderboard()
        # Guard against storage returning nothing so later code can always
        # iterate over a list.
        self._entries = list(raw) if raw else []

    def _save(self):
        """Persist the in-memory entries."""
        save_leaderboard(self._entries)

    def _cleanup(self):
        """Remove non-paper entries older than 30 days that are not in top 30.

        An entry is kept when it is flagged ``is_paper_data``, when it is among
        the ``_TOP_N`` best entries by ``_DEFAULT_SORT``, or when its timestamp
        is within ``_RETENTION_DAYS``. Entries whose timestamp cannot be parsed
        are treated as new and therefore kept.
        """
        if not self._entries:
            return

        df = pd.DataFrame(self._entries)
        if _DEFAULT_SORT in df.columns and "submission_id" in df.columns:
            top_ids = set(
                df.sort_values(by=_DEFAULT_SORT, ascending=False)
                .head(_TOP_N)["submission_id"]
                .dropna()
                .tolist()
            )
        else:
            top_ids = set()

        # Both sides of the age comparison must be timezone-aware: stored
        # timestamps end in "Z" and parse to aware datetimes, and comparing
        # those with a naive cutoff raises TypeError.
        now = self._utc_now()
        cutoff = now - timedelta(days=_RETENTION_DAYS)

        kept = [
            e
            for e in self._entries
            if e.get("is_paper_data", False)
            or e.get("submission_id", "") in top_ids
            or self._parse_timestamp(e.get("timestamp", ""), now) >= cutoff
        ]

        removed = len(self._entries) - len(kept)
        if removed > 0:
            print(f"[CLEANUP] Removed {removed} expired entries")
            self._entries = kept
            self._save()

    def add_result(
        self,
        email: str,
        method: str,
        model_name: str,
        albums: list[str],
        evaluated_queries: int,
        total_gt_queries: int,
        global_metrics: dict,
    ) -> dict | None:
        """Add a new evaluation result.

        Only full submissions are eligible: ``albums`` must be exactly albums
        "1", "2" and "3", and ``evaluated_queries`` must not be lower than
        ``total_gt_queries``.

        One entry is kept per ``(email, model_name)``. A resubmission replaces
        the stored entry only when its ``_DEFAULT_SORT`` score is at least as
        high as the stored one.

        Returns:
            The entry built for this submission, or ``None`` if the submission
            is not eligible. The entry is returned even when an earlier,
            better-scoring entry was retained instead of it.
        """
        # Must be a full submission (all 3 albums, all queries matched)
        if set(albums) != {"1", "2", "3"}:
            return None
        if evaluated_queries < total_gt_queries:
            return None

        stored_metric_keys = set(ALL_METRIC_COLS) | {"Recall", "NDCG"}
        entry = {
            "submission_id": make_id(email, model_name),
            # Naive-UTC ISO string with a "Z" suffix, same format as before.
            "timestamp": self._utc_now().replace(tzinfo=None).isoformat() + "Z",
            "email": email,
            "method": method,
            "model_name": model_name,
            "albums": ",".join(albums),
            "is_paper_data": False,
            **{
                k: round(v, 4)
                for k, v in global_metrics.items()
                if k in stored_metric_keys
            },
        }

        # Keep best score per (email, model_name)
        key = (email, model_name)
        existing_idx = next(
            (
                i
                for i, e in enumerate(self._entries)
                if (e.get("email"), e.get("model_name")) == key
            ),
            None,
        )

        if existing_idx is None:
            self._entries.append(entry)
            changed = True
        else:
            # "or 0" treats a missing or null score as 0 instead of failing
            # the comparison.
            old_score = self._entries[existing_idx].get(_DEFAULT_SORT) or 0
            new_score = global_metrics.get(_DEFAULT_SORT) or 0
            changed = new_score >= old_score
            if changed:
                self._entries[existing_idx] = entry

        # Skip the write when the stored leaderboard is unchanged.
        if changed:
            self._save()
        return entry

    def get_display_df(
        self,
        method_filter: str | None = None,
        sort_by: str = _DEFAULT_SORT,
        ascending: bool = False,
        top_n: int = _TOP_N,
        metric_cols: list[str] | None = None,
    ) -> pd.DataFrame:
        """Return a pandas DataFrame ready for gr.DataFrame.

        Args:
            method_filter: Keep only entries whose ``method`` equals this
                value. ``None``, an empty string or ``"All"`` disables the
                filter.
            sort_by: Column to sort on. Falls back to ``_DEFAULT_SORT`` when
                the column does not exist; if that is missing too, the stored
                order is kept.
            ascending: Sort direction.
            top_n: Maximum number of rows returned.
            metric_cols: Metric columns to show after ``BASE_COLS``. Defaults
                to ``DEFAULT_DISPLAY_METRICS``.

        Returns:
            A frame with a 1-based ``rank`` column reflecting the row order
            after filtering and sorting. Requested columns that no entry
            provides are omitted, except when there are no entries at all, in
            which case an empty frame with every requested column is returned.
        """
        cols_to_show = BASE_COLS + list(metric_cols or DEFAULT_DISPLAY_METRICS)
        if not self._entries:
            return pd.DataFrame(columns=cols_to_show)

        df = pd.DataFrame(self._entries)

        if method_filter and method_filter != "All":
            if "method" in df.columns:
                df = df[df["method"] == method_filter]
            else:
                # No entry declares a method, so nothing can match the filter.
                df = df.iloc[0:0]

        if sort_by not in df.columns:
            sort_by = _DEFAULT_SORT
        if sort_by in df.columns:
            df = df.sort_values(by=sort_by, ascending=ascending)

        df = df.head(top_n).reset_index(drop=True)
        df["rank"] = df.index + 1

        available = [c for c in cols_to_show if c in df.columns]
        return df[available]

    def remove_entry(self, submission_id: str) -> bool:
        """Remove an entry by submission_id. Returns True if removed."""
        remaining = [
            e for e in self._entries if e.get("submission_id") != submission_id
        ]
        if len(remaining) == len(self._entries):
            return False
        self._entries = remaining
        self._save()
        return True