File size: 5,211 Bytes
01f4cb5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import hashlib
import json
from datetime import datetime, timedelta
from typing import Any

import pandas as pd

from src.storage import load_leaderboard, save_leaderboard

# Every metric column that can be computed: each metric family at each rank
# cutoff, grouped by family (all Recall@k first, then all NDCG@k).
ALL_METRIC_COLS = [
    f"{family}@{k}"
    for family in ("Recall", "NDCG")
    for k in (1, 5, 10, 20, 50, 100)
]

# Subset of the metric columns shown on the leaderboard by default.
DEFAULT_DISPLAY_METRICS = [
    f"{family}@{k}"
    for family in ("Recall", "NDCG")
    for k in (1, 5, 20, 50)
]

# Base columns that are always shown.
BASE_COLS = ["rank", "model_name"]

# Metric used as the default sort key.
_DEFAULT_SORT = "Recall@10"
# Default number of leading rows kept/shown.
_TOP_N = 30
# Age, in days, after which entries become eligible for cleanup.
_RETENTION_DAYS = 30


def make_id(email: str, model_name: str) -> str:
    """Return a short identifier derived from an email and a model name.

    The two values are combined as ``"<email>:<model_name>"``, hashed with
    SHA-256, and the first 16 characters of the hex digest are returned.
    """
    payload = f"{email}:{model_name}".encode()
    hasher = hashlib.sha256()
    hasher.update(payload)
    return hasher.hexdigest()[:16]


class LeaderboardManager:
    """Manage leaderboard entries held in memory and persisted via storage.

    Entries are plain dicts. They are read with ``load_leaderboard()`` when
    the manager is constructed and written back with ``save_leaderboard()``
    whenever the in-memory list changes.
    """

    def __init__(self):
        """Load the stored entries and prune the expired ones."""
        self._entries: list[dict] = []
        self._load()
        self._cleanup()

    def _load(self):
        """Replace the in-memory entries with what storage returns."""
        raw = load_leaderboard()
        # An empty or None payload becomes an empty list so that later
        # ``append`` calls in ``add_result`` always work.
        self._entries = list(raw) if raw else []

    def _save(self):
        """Persist the current in-memory entries."""
        save_leaderboard(self._entries)

    @staticmethod
    def _parse_timestamp(ts_str: Any) -> datetime | None:
        """Parse an ISO-8601 string into a naive UTC datetime.

        A trailing ``Z`` is accepted as UTC. Timestamps that carry an offset
        are converted to UTC and stripped of their tzinfo so they can be
        compared with ``datetime.utcnow()``. Returns None when the value
        cannot be parsed.
        """
        try:
            ts = datetime.fromisoformat(ts_str.replace("Z", "+00:00"))
            offset = ts.utcoffset()
            if offset is not None:
                # Aware -> naive UTC; comparing aware with naive datetimes
                # raises TypeError.
                ts = ts.replace(tzinfo=None) - offset
        except (AttributeError, TypeError, ValueError, OverflowError):
            return None
        return ts

    def _cleanup(self):
        """Remove non-paper entries older than 30 days that are not in top 30.

        An entry is kept when it is flagged ``is_paper_data``, when it is
        among the ``_TOP_N`` best by ``_DEFAULT_SORT``, or when its timestamp
        is within ``_RETENTION_DAYS``. Entries whose timestamp cannot be
        parsed are kept. Storage is rewritten only if something was removed.
        """
        if not self._entries:
            return

        df = pd.DataFrame(self._entries)
        # Both columns are needed to identify the top entries; without them
        # no entry is protected by rank.
        if _DEFAULT_SORT in df.columns and "submission_id" in df.columns:
            top_ids = set(
                df.sort_values(by=_DEFAULT_SORT, ascending=False)
                .head(_TOP_N)["submission_id"]
                .tolist()
            )
        else:
            top_ids = set()

        cutoff = datetime.utcnow() - timedelta(days=_RETENTION_DAYS)
        kept = []
        for e in self._entries:
            if e.get("is_paper_data", False) or e.get("submission_id", "") in top_ids:
                kept.append(e)
                continue
            ts = self._parse_timestamp(e.get("timestamp", ""))
            # Unparseable timestamps are treated as fresh rather than dropped.
            if ts is None or ts >= cutoff:
                kept.append(e)

        removed = len(self._entries) - len(kept)
        if removed > 0:
            print(f"[CLEANUP] Removed {removed} expired entries")
            self._entries = kept
            self._save()

    def add_result(
        self,
        email: str,
        method: str,
        model_name: str,
        albums: list[str],
        evaluated_queries: int,
        total_gt_queries: int,
        global_metrics: dict,
    ) -> dict | None:
        """Add a new evaluation result.

        Only full submissions are eligible: ``albums`` must be exactly
        ``{"1", "2", "3"}`` and ``evaluated_queries`` must not be below
        ``total_gt_queries``.

        One entry is kept per ``(email, model_name)``. An existing entry is
        replaced only when the new ``_DEFAULT_SORT`` score is greater than or
        equal to the stored one.

        Returns:
            The entry built for this submission when it is eligible (even if
            a better existing entry was kept instead), or None when it is
            not eligible.
        """
        # Must be a full submission (all 3 albums, all queries matched)
        if set(albums) != {"1", "2", "3"}:
            return None
        if evaluated_queries < total_gt_queries:
            return None

        entry = {
            "submission_id": make_id(email, model_name),
            "timestamp": datetime.utcnow().isoformat() + "Z",
            "email": email,
            "method": method,
            "model_name": model_name,
            "albums": ",".join(albums),
            "is_paper_data": False,
            **{
                k: round(v, 4)
                for k, v in global_metrics.items()
                if k in ALL_METRIC_COLS or k in ("Recall", "NDCG")
            },
        }

        # Keep best score per (email, model_name)
        key = (email, model_name)
        existing_idx = next(
            (
                i
                for i, e in enumerate(self._entries)
                if (e.get("email"), e.get("model_name")) == key
            ),
            None,
        )

        if existing_idx is None:
            self._entries.append(entry)
            changed = True
        else:
            old = self._entries[existing_idx]
            # ``or 0`` guards against a stored null score.
            old_score = old.get(_DEFAULT_SORT) or 0
            changed = global_metrics.get(_DEFAULT_SORT, 0) >= old_score
            if changed:
                self._entries[existing_idx] = entry

        # Skip the write when the stored entries are unchanged.
        if changed:
            self._save()
        return entry

    def get_display_df(
        self,
        method_filter: str | None = None,
        sort_by: str = _DEFAULT_SORT,
        ascending: bool = False,
        top_n: int = _TOP_N,
        metric_cols: list[str] | None = None,
    ) -> pd.DataFrame:
        """Return a pandas DataFrame ready for gr.DataFrame.

        Args:
            method_filter: Keep only rows whose ``method`` equals this value;
                None, empty, or ``"All"`` disables filtering.
            sort_by: Column to sort on; falls back to ``_DEFAULT_SORT`` when
                the column is absent. If that is absent too, rows keep their
                stored order.
            ascending: Sort direction.
            top_n: Number of leading rows to keep after sorting.
            metric_cols: Metric columns to show; defaults to
                ``DEFAULT_DISPLAY_METRICS``.

        Returns:
            A frame with a 1-based ``rank`` column and those requested
            columns that exist in the data. With no entries at all, an empty
            frame carrying every requested column is returned.
        """
        # ``list(...)`` lets callers pass any sequence, not only a list.
        cols_to_show = BASE_COLS + list(metric_cols or DEFAULT_DISPLAY_METRICS)

        if not self._entries:
            return pd.DataFrame(columns=cols_to_show)

        df = pd.DataFrame(self._entries)

        if method_filter and method_filter != "All":
            if "method" in df.columns:
                df = df[df["method"] == method_filter]
            else:
                # No entry declares a method, so nothing can match the filter.
                df = df.iloc[0:0]

        if sort_by not in df.columns:
            sort_by = _DEFAULT_SORT

        # Sorting on a missing column would raise KeyError.
        if sort_by in df.columns:
            df = df.sort_values(by=sort_by, ascending=ascending)
        df = df.head(top_n).reset_index(drop=True)
        df["rank"] = df.index + 1

        available = [c for c in cols_to_show if c in df.columns]
        return df[available]

    def remove_entry(self, submission_id: str) -> bool:
        """Remove an entry by submission_id. Returns True if removed."""
        remaining = [e for e in self._entries if e.get("submission_id") != submission_id]
        if len(remaining) == len(self._entries):
            return False
        self._entries = remaining
        self._save()
        return True