Spaces:
Sleeping
Sleeping
Raw elo
Browse files
- app.py +11 -35
- dataset_config.yaml +1 -5
- src/callbacks.py +26 -103
- src/components.py +11 -60
- src/config.py +3 -7
- src/elo.py +80 -243
app.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""
|
| 2 |
|
| 3 |
import logging
|
| 4 |
|
|
@@ -11,19 +11,17 @@ from src.callbacks import register_callbacks
|
|
| 11 |
from src import elo
|
| 12 |
from src.galaxy_data_loader import sample_pool_streaming, image_cache
|
| 13 |
from src.galaxy_profiles import register_metadata
|
| 14 |
-
from src.config import POOL_SIZE
|
| 15 |
|
| 16 |
logging.basicConfig(
|
| 17 |
level=logging.INFO,
|
| 18 |
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
| 19 |
)
|
| 20 |
-
# Suppress noisy httpx request logs
|
| 21 |
logging.getLogger("httpx").setLevel(logging.WARNING)
|
| 22 |
logger = logging.getLogger(__name__)
|
| 23 |
|
| 24 |
|
| 25 |
def create_app() -> dash.Dash:
|
| 26 |
-
"""Create and configure the Dash application."""
|
| 27 |
app = dash.Dash(
|
| 28 |
__name__,
|
| 29 |
external_stylesheets=[dbc.themes.BOOTSTRAP, dbc.icons.FONT_AWESOME],
|
|
@@ -34,7 +32,6 @@ def create_app() -> dash.Dash:
|
|
| 34 |
|
| 35 |
server = app.server
|
| 36 |
|
| 37 |
-
# Serve galaxy images from cache (populated at startup via streaming)
|
| 38 |
@server.route("/galaxy-images/<int:row_index>.jpg")
|
| 39 |
def serve_galaxy_image(row_index):
|
| 40 |
path = image_cache.get_path(row_index)
|
|
@@ -42,41 +39,20 @@ def create_app() -> dash.Dash:
|
|
| 42 |
abort(404)
|
| 43 |
return send_file(path, mimetype="image/jpeg")
|
| 44 |
|
| 45 |
-
#
|
| 46 |
-
logger.info("
|
| 47 |
-
|
|
|
|
| 48 |
|
| 49 |
-
#
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
logger.info(
|
| 54 |
-
"Streaming pool of %d galaxies (seed=%s)...",
|
| 55 |
-
POOL_SIZE,
|
| 56 |
-
seed if seed is not None else "random",
|
| 57 |
-
)
|
| 58 |
-
try:
|
| 59 |
-
pool, metadata_map, used_seed = sample_pool_streaming(POOL_SIZE, seed=seed)
|
| 60 |
-
register_metadata(metadata_map)
|
| 61 |
-
if not loaded:
|
| 62 |
-
elo.initialize_tournament(pool, pool_seed=used_seed)
|
| 63 |
-
else:
|
| 64 |
-
# Persist seed into existing state so future reloads can reuse it
|
| 65 |
-
elo.set_pool_seed(used_seed)
|
| 66 |
-
logger.info(
|
| 67 |
-
"Tournament state restored: round %d, %d active galaxies",
|
| 68 |
-
elo.get_tournament_info().get("current_round", 1),
|
| 69 |
-
len(pool),
|
| 70 |
-
)
|
| 71 |
-
except Exception as e:
|
| 72 |
-
logger.error("Failed to stream galaxy pool: %s", e)
|
| 73 |
-
raise
|
| 74 |
|
| 75 |
-
# Layout and callbacks
|
| 76 |
app.layout = create_layout()
|
| 77 |
register_callbacks(app)
|
| 78 |
|
| 79 |
-
logger.info("
|
| 80 |
return app
|
| 81 |
|
| 82 |
|
|
|
|
| 1 |
+
"""Perihelion - Galaxy Interestingness Ranking."""
|
| 2 |
|
| 3 |
import logging
|
| 4 |
|
|
|
|
| 11 |
from src import elo
|
| 12 |
from src.galaxy_data_loader import sample_pool_streaming, image_cache
|
| 13 |
from src.galaxy_profiles import register_metadata
|
| 14 |
+
from src.config import POOL_SIZE, POOL_SEED
|
| 15 |
|
| 16 |
logging.basicConfig(
|
| 17 |
level=logging.INFO,
|
| 18 |
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
| 19 |
)
|
|
|
|
| 20 |
logging.getLogger("httpx").setLevel(logging.WARNING)
|
| 21 |
logger = logging.getLogger(__name__)
|
| 22 |
|
| 23 |
|
| 24 |
def create_app() -> dash.Dash:
|
|
|
|
| 25 |
app = dash.Dash(
|
| 26 |
__name__,
|
| 27 |
external_stylesheets=[dbc.themes.BOOTSTRAP, dbc.icons.FONT_AWESOME],
|
|
|
|
| 32 |
|
| 33 |
server = app.server
|
| 34 |
|
|
|
|
| 35 |
@server.route("/galaxy-images/<int:row_index>.jpg")
|
| 36 |
def serve_galaxy_image(row_index):
|
| 37 |
path = image_cache.get_path(row_index)
|
|
|
|
| 39 |
abort(404)
|
| 40 |
return send_file(path, mimetype="image/jpeg")
|
| 41 |
|
| 42 |
+
# Always stream with the fixed seed so every participant sees the same pool
|
| 43 |
+
logger.info("Streaming pool of %d galaxies (seed=%d)...", POOL_SIZE, POOL_SEED)
|
| 44 |
+
pool, metadata_map, _ = sample_pool_streaming(POOL_SIZE, seed=POOL_SEED)
|
| 45 |
+
register_metadata(metadata_map)
|
| 46 |
|
| 47 |
+
# Load persisted ELO state or start fresh
|
| 48 |
+
if not elo.load_elo_state():
|
| 49 |
+
logger.info("No saved state found — initializing fresh ELO rankings")
|
| 50 |
+
elo.initialize_elo(pool)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
|
|
|
| 52 |
app.layout = create_layout()
|
| 53 |
register_callbacks(app)
|
| 54 |
|
| 55 |
+
logger.info("Perihelion ready!")
|
| 56 |
return app
|
| 57 |
|
| 58 |
|
dataset_config.yaml
CHANGED
|
@@ -4,10 +4,6 @@ split: "train"
|
|
| 4 |
image_column: "image"
|
| 5 |
id_column: "id_str"
|
| 6 |
pool_size: 1000
|
| 7 |
-
|
| 8 |
-
max_comparisons_per_round: 5
|
| 9 |
-
elimination_fraction: 0.5
|
| 10 |
-
final_pool_size: 100
|
| 11 |
image_cache_dir: "cache/images"
|
| 12 |
image_cache_max_bytes: 524288000 # 500 MB
|
| 13 |
-
cache_prefetch_count: 20
|
|
|
|
| 4 |
image_column: "image"
|
| 5 |
id_column: "id_str"
|
| 6 |
pool_size: 1000
|
| 7 |
+
pool_seed: 42
|
|
|
|
|
|
|
|
|
|
| 8 |
image_cache_dir: "cache/images"
|
| 9 |
image_cache_max_bytes: 524288000 # 500 MB
|
|
|
src/callbacks.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""Dash callbacks for
|
| 2 |
|
| 3 |
import uuid
|
| 4 |
import logging
|
|
@@ -17,43 +17,27 @@ logger = logging.getLogger(__name__)
|
|
| 17 |
def register_callbacks(app):
|
| 18 |
"""Register all Dash callbacks."""
|
| 19 |
|
| 20 |
-
# Initial load: populate the arena with the first pair
|
| 21 |
@app.callback(
|
| 22 |
[
|
| 23 |
Output("arena-container", "children"),
|
| 24 |
Output("current-pair", "data"),
|
| 25 |
Output("leaderboard-body", "children"),
|
| 26 |
Output("session-id", "data"),
|
| 27 |
-
Output("
|
| 28 |
Output("progress-dashboard-container", "children"),
|
| 29 |
],
|
| 30 |
Input("arena-container", "id"),
|
| 31 |
)
|
| 32 |
def initial_load(_):
|
| 33 |
session_id = uuid.uuid4().hex
|
| 34 |
-
|
| 35 |
pair = elo.select_pair(set())
|
| 36 |
-
if pair is None:
|
| 37 |
-
arena = create_arena(None, None)
|
| 38 |
-
current_pair_data = None
|
| 39 |
-
else:
|
| 40 |
-
arena = create_arena(pair[0], pair[1])
|
| 41 |
-
current_pair_data = list(pair)
|
| 42 |
-
|
| 43 |
leaderboard = create_leaderboard_rows(elo.get_leaderboard())
|
| 44 |
-
info = elo.get_tournament_info()
|
| 45 |
dashboard = create_progress_dashboard(info)
|
|
|
|
| 46 |
|
| 47 |
-
return (
|
| 48 |
-
arena,
|
| 49 |
-
current_pair_data,
|
| 50 |
-
leaderboard,
|
| 51 |
-
session_id,
|
| 52 |
-
info,
|
| 53 |
-
dashboard,
|
| 54 |
-
)
|
| 55 |
-
|
| 56 |
-
# Card click: pick a winner, update ELO, load next pair
|
| 57 |
@app.callback(
|
| 58 |
[
|
| 59 |
Output("arena-container", "children", allow_duplicate=True),
|
|
@@ -61,13 +45,10 @@ def register_callbacks(app):
|
|
| 61 |
Output("seen-pairs", "data", allow_duplicate=True),
|
| 62 |
Output("comparison-count", "data", allow_duplicate=True),
|
| 63 |
Output("leaderboard-body", "children", allow_duplicate=True),
|
| 64 |
-
Output("
|
| 65 |
Output("progress-dashboard-container", "children", allow_duplicate=True),
|
| 66 |
],
|
| 67 |
-
[
|
| 68 |
-
Input("left-card-btn", "n_clicks"),
|
| 69 |
-
Input("right-card-btn", "n_clicks"),
|
| 70 |
-
],
|
| 71 |
[
|
| 72 |
State("current-pair", "data"),
|
| 73 |
State("seen-pairs", "data"),
|
|
@@ -79,10 +60,8 @@ def register_callbacks(app):
|
|
| 79 |
def handle_card_click(left_clicks, right_clicks, current_pair, seen_pairs, comp_count, session_id):
|
| 80 |
if not ctx.triggered_id:
|
| 81 |
raise PreventUpdate
|
| 82 |
-
|
| 83 |
if (left_clicks in [0, None]) and (right_clicks in [0, None]):
|
| 84 |
raise PreventUpdate
|
| 85 |
-
|
| 86 |
if current_pair is None:
|
| 87 |
raise PreventUpdate
|
| 88 |
|
|
@@ -91,94 +70,54 @@ def register_callbacks(app):
|
|
| 91 |
if comp_count is None:
|
| 92 |
comp_count = 0
|
| 93 |
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
if triggered == "left-card-btn":
|
| 97 |
-
winner_side = "left"
|
| 98 |
-
elif triggered == "right-card-btn":
|
| 99 |
-
winner_side = "right"
|
| 100 |
-
else:
|
| 101 |
-
raise PreventUpdate
|
| 102 |
-
|
| 103 |
-
left_idx = current_pair[0]
|
| 104 |
-
right_idx = current_pair[1]
|
| 105 |
-
|
| 106 |
-
if winner_side == "left":
|
| 107 |
winner_idx, loser_idx = left_idx, right_idx
|
| 108 |
else:
|
| 109 |
winner_idx, loser_idx = right_idx, left_idx
|
| 110 |
|
| 111 |
-
# Record comparison
|
| 112 |
result = elo.record_comparison(winner_idx, loser_idx)
|
| 113 |
|
| 114 |
-
# Log to HF
|
| 115 |
log_query_event({
|
| 116 |
"log_type": "comparison",
|
| 117 |
"session_id": session_id,
|
| 118 |
"galaxy_left": left_idx,
|
| 119 |
"galaxy_right": right_idx,
|
| 120 |
"winner": winner_idx,
|
| 121 |
-
"
|
| 122 |
-
"
|
| 123 |
-
"
|
| 124 |
-
"
|
| 125 |
-
"elo_left_after": result["winner_elo_after"] if winner_side == "left" else result["loser_elo_after"],
|
| 126 |
-
"elo_right_after": result["loser_elo_after"] if winner_side == "left" else result["winner_elo_after"],
|
| 127 |
})
|
| 128 |
|
| 129 |
-
# Update seen pairs and count
|
| 130 |
seen_pairs.append([left_idx, right_idx])
|
| 131 |
comp_count += 1
|
| 132 |
|
| 133 |
-
|
| 134 |
-
seen_set = set()
|
| 135 |
-
for p in seen_pairs:
|
| 136 |
-
seen_set.add((p[0], p[1]))
|
| 137 |
-
seen_set.add((p[1], p[0]))
|
| 138 |
-
|
| 139 |
pair = elo.select_pair(seen_set)
|
|
|
|
|
|
|
| 140 |
|
| 141 |
-
if pair is None:
|
| 142 |
-
arena = create_arena(None, None)
|
| 143 |
-
current_pair_data = None
|
| 144 |
-
else:
|
| 145 |
-
arena = create_arena(pair[0], pair[1])
|
| 146 |
-
current_pair_data = list(pair)
|
| 147 |
-
|
| 148 |
-
info = elo.get_tournament_info()
|
| 149 |
leaderboard = create_leaderboard_rows(elo.get_leaderboard())
|
| 150 |
dashboard = create_progress_dashboard(info)
|
| 151 |
|
| 152 |
-
return (
|
| 153 |
-
arena,
|
| 154 |
-
current_pair_data,
|
| 155 |
-
seen_pairs,
|
| 156 |
-
comp_count,
|
| 157 |
-
leaderboard,
|
| 158 |
-
info,
|
| 159 |
-
dashboard,
|
| 160 |
-
)
|
| 161 |
|
| 162 |
-
# Progress dashboard update (interval-driven)
|
| 163 |
@app.callback(
|
| 164 |
[
|
| 165 |
-
Output("
|
| 166 |
Output("progress-dashboard-container", "children", allow_duplicate=True),
|
| 167 |
],
|
| 168 |
Input("progress-interval", "n_intervals"),
|
| 169 |
prevent_initial_call=True,
|
| 170 |
)
|
| 171 |
def update_progress(n_intervals):
|
| 172 |
-
info = elo.get_tournament_info()
|
| 173 |
-
dashboard = create_progress_dashboard(info)
|
| 174 |
-
return info, dashboard
|
| 175 |
|
| 176 |
-
# Leaderboard toggle
|
| 177 |
@app.callback(
|
| 178 |
-
[
|
| 179 |
-
Output("leaderboard-body", "style"),
|
| 180 |
-
Output("leaderboard-arrow", "style"),
|
| 181 |
-
],
|
| 182 |
Input("leaderboard-toggle", "n_clicks"),
|
| 183 |
State("leaderboard-body", "style"),
|
| 184 |
prevent_initial_call=True,
|
|
@@ -194,7 +133,6 @@ def register_callbacks(app):
|
|
| 194 |
{"transition": "transform 0.3s", "fontSize": "0.65rem", "transform": "rotate(0deg)"},
|
| 195 |
)
|
| 196 |
|
| 197 |
-
# Reset session (client-side only — does NOT restart tournament)
|
| 198 |
@app.callback(
|
| 199 |
[
|
| 200 |
Output("arena-container", "children", allow_duplicate=True),
|
|
@@ -209,23 +147,8 @@ def register_callbacks(app):
|
|
| 209 |
def reset_session(n_clicks):
|
| 210 |
if not n_clicks:
|
| 211 |
raise PreventUpdate
|
| 212 |
-
|
| 213 |
pair = elo.select_pair(set())
|
| 214 |
-
if pair is None:
|
| 215 |
-
arena = create_arena(None, None)
|
| 216 |
-
current_pair_data = None
|
| 217 |
-
else:
|
| 218 |
-
arena = create_arena(pair[0], pair[1])
|
| 219 |
-
current_pair_data = list(pair)
|
| 220 |
-
|
| 221 |
leaderboard = create_leaderboard_rows(elo.get_leaderboard())
|
| 222 |
-
|
| 223 |
-
return (
|
| 224 |
-
arena,
|
| 225 |
-
current_pair_data,
|
| 226 |
-
[],
|
| 227 |
-
0,
|
| 228 |
-
leaderboard,
|
| 229 |
-
)
|
| 230 |
-
|
| 231 |
-
|
|
|
|
| 1 |
+
"""Dash callbacks for Perihelion."""
|
| 2 |
|
| 3 |
import uuid
|
| 4 |
import logging
|
|
|
|
| 17 |
def register_callbacks(app):
|
| 18 |
"""Register all Dash callbacks."""
|
| 19 |
|
|
|
|
| 20 |
@app.callback(
|
| 21 |
[
|
| 22 |
Output("arena-container", "children"),
|
| 23 |
Output("current-pair", "data"),
|
| 24 |
Output("leaderboard-body", "children"),
|
| 25 |
Output("session-id", "data"),
|
| 26 |
+
Output("elo-info", "data"),
|
| 27 |
Output("progress-dashboard-container", "children"),
|
| 28 |
],
|
| 29 |
Input("arena-container", "id"),
|
| 30 |
)
|
| 31 |
def initial_load(_):
|
| 32 |
session_id = uuid.uuid4().hex
|
|
|
|
| 33 |
pair = elo.select_pair(set())
|
| 34 |
+
arena = create_arena(pair[0], pair[1]) if pair else create_arena(None, None)
|
| 35 |
+
current_pair_data = list(pair) if pair else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
leaderboard = create_leaderboard_rows(elo.get_leaderboard())
|
| 37 |
+
info = elo.get_info()
|
| 38 |
dashboard = create_progress_dashboard(info)
|
| 39 |
+
return arena, current_pair_data, leaderboard, session_id, info, dashboard
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
@app.callback(
|
| 42 |
[
|
| 43 |
Output("arena-container", "children", allow_duplicate=True),
|
|
|
|
| 45 |
Output("seen-pairs", "data", allow_duplicate=True),
|
| 46 |
Output("comparison-count", "data", allow_duplicate=True),
|
| 47 |
Output("leaderboard-body", "children", allow_duplicate=True),
|
| 48 |
+
Output("elo-info", "data", allow_duplicate=True),
|
| 49 |
Output("progress-dashboard-container", "children", allow_duplicate=True),
|
| 50 |
],
|
| 51 |
+
[Input("left-card-btn", "n_clicks"), Input("right-card-btn", "n_clicks")],
|
|
|
|
|
|
|
|
|
|
| 52 |
[
|
| 53 |
State("current-pair", "data"),
|
| 54 |
State("seen-pairs", "data"),
|
|
|
|
| 60 |
def handle_card_click(left_clicks, right_clicks, current_pair, seen_pairs, comp_count, session_id):
|
| 61 |
if not ctx.triggered_id:
|
| 62 |
raise PreventUpdate
|
|
|
|
| 63 |
if (left_clicks in [0, None]) and (right_clicks in [0, None]):
|
| 64 |
raise PreventUpdate
|
|
|
|
| 65 |
if current_pair is None:
|
| 66 |
raise PreventUpdate
|
| 67 |
|
|
|
|
| 70 |
if comp_count is None:
|
| 71 |
comp_count = 0
|
| 72 |
|
| 73 |
+
left_idx, right_idx = current_pair[0], current_pair[1]
|
| 74 |
+
if ctx.triggered_id == "left-card-btn":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
winner_idx, loser_idx = left_idx, right_idx
|
| 76 |
else:
|
| 77 |
winner_idx, loser_idx = right_idx, left_idx
|
| 78 |
|
|
|
|
| 79 |
result = elo.record_comparison(winner_idx, loser_idx)
|
| 80 |
|
|
|
|
| 81 |
log_query_event({
|
| 82 |
"log_type": "comparison",
|
| 83 |
"session_id": session_id,
|
| 84 |
"galaxy_left": left_idx,
|
| 85 |
"galaxy_right": right_idx,
|
| 86 |
"winner": winner_idx,
|
| 87 |
+
"elo_left_before": result["winner_elo_before"] if winner_idx == left_idx else result["loser_elo_before"],
|
| 88 |
+
"elo_right_before": result["loser_elo_before"] if winner_idx == left_idx else result["winner_elo_before"],
|
| 89 |
+
"elo_left_after": result["winner_elo_after"] if winner_idx == left_idx else result["loser_elo_after"],
|
| 90 |
+
"elo_right_after": result["loser_elo_after"] if winner_idx == left_idx else result["winner_elo_after"],
|
|
|
|
|
|
|
| 91 |
})
|
| 92 |
|
|
|
|
| 93 |
seen_pairs.append([left_idx, right_idx])
|
| 94 |
comp_count += 1
|
| 95 |
|
| 96 |
+
seen_set = {(p[0], p[1]) for p in seen_pairs} | {(p[1], p[0]) for p in seen_pairs}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
pair = elo.select_pair(seen_set)
|
| 98 |
+
arena = create_arena(pair[0], pair[1]) if pair else create_arena(None, None)
|
| 99 |
+
current_pair_data = list(pair) if pair else None
|
| 100 |
|
| 101 |
+
info = elo.get_info()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
leaderboard = create_leaderboard_rows(elo.get_leaderboard())
|
| 103 |
dashboard = create_progress_dashboard(info)
|
| 104 |
|
| 105 |
+
return arena, current_pair_data, seen_pairs, comp_count, leaderboard, info, dashboard
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
|
|
|
|
| 107 |
@app.callback(
|
| 108 |
[
|
| 109 |
+
Output("elo-info", "data", allow_duplicate=True),
|
| 110 |
Output("progress-dashboard-container", "children", allow_duplicate=True),
|
| 111 |
],
|
| 112 |
Input("progress-interval", "n_intervals"),
|
| 113 |
prevent_initial_call=True,
|
| 114 |
)
|
| 115 |
def update_progress(n_intervals):
|
| 116 |
+
info = elo.get_info()
|
| 117 |
+
return info, create_progress_dashboard(info)
|
|
|
|
| 118 |
|
|
|
|
| 119 |
@app.callback(
|
| 120 |
+
[Output("leaderboard-body", "style"), Output("leaderboard-arrow", "style")],
|
|
|
|
|
|
|
|
|
|
| 121 |
Input("leaderboard-toggle", "n_clicks"),
|
| 122 |
State("leaderboard-body", "style"),
|
| 123 |
prevent_initial_call=True,
|
|
|
|
| 133 |
{"transition": "transform 0.3s", "fontSize": "0.65rem", "transform": "rotate(0deg)"},
|
| 134 |
)
|
| 135 |
|
|
|
|
| 136 |
@app.callback(
|
| 137 |
[
|
| 138 |
Output("arena-container", "children", allow_duplicate=True),
|
|
|
|
| 147 |
def reset_session(n_clicks):
|
| 148 |
if not n_clicks:
|
| 149 |
raise PreventUpdate
|
|
|
|
| 150 |
pair = elo.select_pair(set())
|
| 151 |
+
arena = create_arena(pair[0], pair[1]) if pair else create_arena(None, None)
|
| 152 |
+
current_pair_data = list(pair) if pair else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
leaderboard = create_leaderboard_rows(elo.get_leaderboard())
|
| 154 |
+
return arena, current_pair_data, [], 0, leaderboard
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/components.py
CHANGED
|
@@ -340,7 +340,7 @@ def create_arena(left_idx=None, right_idx=None):
|
|
| 340 |
return html.Div(
|
| 341 |
[
|
| 342 |
html.Div(
|
| 343 |
-
"
|
| 344 |
style={
|
| 345 |
"fontFamily": "'Playfair Display', serif",
|
| 346 |
"fontSize": "1.8rem",
|
|
@@ -350,8 +350,7 @@ def create_arena(left_idx=None, right_idx=None):
|
|
| 350 |
},
|
| 351 |
),
|
| 352 |
html.P(
|
| 353 |
-
"
|
| 354 |
-
"Check the leaderboard below for final rankings!",
|
| 355 |
style={"color": "rgba(255,255,255,0.5)", "maxWidth": "400px", "margin": "0 auto 24px"},
|
| 356 |
),
|
| 357 |
dbc.Button(
|
|
@@ -393,40 +392,25 @@ def create_arena(left_idx=None, right_idx=None):
|
|
| 393 |
|
| 394 |
|
| 395 |
def create_progress_dashboard(info: dict):
|
| 396 |
-
"""Build the
|
| 397 |
-
current_round = info.get("current_round", 0)
|
| 398 |
pool_size = info.get("pool_size", 0)
|
| 399 |
total_comps = info.get("total_comparisons", 0)
|
| 400 |
-
eliminated_count = info.get("eliminated_count", 0)
|
| 401 |
-
est_remaining = info.get("est_remaining_this_round", 0)
|
| 402 |
elo_values = info.get("elo_values", [])
|
| 403 |
-
tournament_complete = info.get("tournament_complete", False)
|
| 404 |
|
| 405 |
-
# Stats row
|
| 406 |
-
status_text = "COMPLETE" if tournament_complete else f"ROUND {current_round}"
|
| 407 |
stats_row = dbc.Row(
|
| 408 |
[
|
| 409 |
-
dbc.Col(html.Div([
|
| 410 |
-
html.Div(status_text, className="progress-stat-value"),
|
| 411 |
-
html.Div("STATUS", className="progress-stat-label"),
|
| 412 |
-
], className="progress-stat"), width=3),
|
| 413 |
dbc.Col(html.Div([
|
| 414 |
html.Div(str(pool_size), className="progress-stat-value"),
|
| 415 |
-
html.Div("
|
| 416 |
-
], className="progress-stat"), width=
|
| 417 |
dbc.Col(html.Div([
|
| 418 |
html.Div(str(total_comps), className="progress-stat-value"),
|
| 419 |
html.Div("COMPARISONS", className="progress-stat-label"),
|
| 420 |
-
], className="progress-stat"), width=
|
| 421 |
-
dbc.Col(html.Div([
|
| 422 |
-
html.Div(str(eliminated_count), className="progress-stat-value"),
|
| 423 |
-
html.Div("ELIMINATED", className="progress-stat-label"),
|
| 424 |
-
], className="progress-stat"), width=3),
|
| 425 |
],
|
| 426 |
className="mb-3",
|
| 427 |
)
|
| 428 |
|
| 429 |
-
# ELO distribution histogram
|
| 430 |
if elo_values:
|
| 431 |
fig = go.Figure(data=[go.Histogram(
|
| 432 |
x=elo_values,
|
|
@@ -443,47 +427,14 @@ def create_progress_dashboard(info: dict):
|
|
| 443 |
font_size=10,
|
| 444 |
margin=dict(l=30, r=10, t=10, b=30),
|
| 445 |
height=120,
|
| 446 |
-
xaxis=dict(
|
| 447 |
-
|
| 448 |
-
title_text="ELO Rating",
|
| 449 |
-
title_font_size=9,
|
| 450 |
-
),
|
| 451 |
-
yaxis=dict(
|
| 452 |
-
gridcolor="rgba(255,255,255,0.05)",
|
| 453 |
-
title_text="Count",
|
| 454 |
-
title_font_size=9,
|
| 455 |
-
),
|
| 456 |
-
)
|
| 457 |
-
histogram = dcc.Graph(
|
| 458 |
-
figure=fig,
|
| 459 |
-
config={"displayModeBar": False},
|
| 460 |
-
style={"height": "120px"},
|
| 461 |
)
|
|
|
|
| 462 |
else:
|
| 463 |
histogram = html.Div()
|
| 464 |
|
| 465 |
-
|
| 466 |
-
remaining_text = (
|
| 467 |
-
"Tournament complete!" if tournament_complete
|
| 468 |
-
else f"~{est_remaining} comparisons remaining this round"
|
| 469 |
-
)
|
| 470 |
-
|
| 471 |
-
return html.Div(
|
| 472 |
-
[
|
| 473 |
-
stats_row,
|
| 474 |
-
histogram,
|
| 475 |
-
html.Div(
|
| 476 |
-
remaining_text,
|
| 477 |
-
style={
|
| 478 |
-
"textAlign": "center",
|
| 479 |
-
"fontSize": "0.7rem",
|
| 480 |
-
"color": "rgba(255,255,255,0.3)",
|
| 481 |
-
"marginTop": "8px",
|
| 482 |
-
},
|
| 483 |
-
),
|
| 484 |
-
],
|
| 485 |
-
className="progress-dashboard",
|
| 486 |
-
)
|
| 487 |
|
| 488 |
|
| 489 |
def create_leaderboard_rows(leaderboard_data):
|
|
@@ -571,7 +522,7 @@ def create_layout():
|
|
| 571 |
dcc.Store(id="seen-pairs", data=[]),
|
| 572 |
dcc.Store(id="current-pair", data=None),
|
| 573 |
dcc.Store(id="comparison-count", data=0),
|
| 574 |
-
dcc.Store(id="
|
| 575 |
dcc.Store(id="session-id", data=""),
|
| 576 |
|
| 577 |
# Interval for progress updates
|
|
|
|
| 340 |
return html.Div(
|
| 341 |
[
|
| 342 |
html.Div(
|
| 343 |
+
"You've seen every pair!",
|
| 344 |
style={
|
| 345 |
"fontFamily": "'Playfair Display', serif",
|
| 346 |
"fontSize": "1.8rem",
|
|
|
|
| 350 |
},
|
| 351 |
),
|
| 352 |
html.P(
|
| 353 |
+
"Reset your session to keep voting and refine the rankings.",
|
|
|
|
| 354 |
style={"color": "rgba(255,255,255,0.5)", "maxWidth": "400px", "margin": "0 auto 24px"},
|
| 355 |
),
|
| 356 |
dbc.Button(
|
|
|
|
| 392 |
|
| 393 |
|
| 394 |
def create_progress_dashboard(info: dict):
|
| 395 |
+
"""Build the ELO ranking progress dashboard."""
|
|
|
|
| 396 |
pool_size = info.get("pool_size", 0)
|
| 397 |
total_comps = info.get("total_comparisons", 0)
|
|
|
|
|
|
|
| 398 |
elo_values = info.get("elo_values", [])
|
|
|
|
| 399 |
|
|
|
|
|
|
|
| 400 |
stats_row = dbc.Row(
|
| 401 |
[
|
|
|
|
|
|
|
|
|
|
|
|
|
| 402 |
dbc.Col(html.Div([
|
| 403 |
html.Div(str(pool_size), className="progress-stat-value"),
|
| 404 |
+
html.Div("GALAXIES", className="progress-stat-label"),
|
| 405 |
+
], className="progress-stat"), width=6),
|
| 406 |
dbc.Col(html.Div([
|
| 407 |
html.Div(str(total_comps), className="progress-stat-value"),
|
| 408 |
html.Div("COMPARISONS", className="progress-stat-label"),
|
| 409 |
+
], className="progress-stat"), width=6),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 410 |
],
|
| 411 |
className="mb-3",
|
| 412 |
)
|
| 413 |
|
|
|
|
| 414 |
if elo_values:
|
| 415 |
fig = go.Figure(data=[go.Histogram(
|
| 416 |
x=elo_values,
|
|
|
|
| 427 |
font_size=10,
|
| 428 |
margin=dict(l=30, r=10, t=10, b=30),
|
| 429 |
height=120,
|
| 430 |
+
xaxis=dict(gridcolor="rgba(255,255,255,0.05)", title_text="ELO Rating", title_font_size=9),
|
| 431 |
+
yaxis=dict(gridcolor="rgba(255,255,255,0.05)", title_text="Count", title_font_size=9),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 432 |
)
|
| 433 |
+
histogram = dcc.Graph(figure=fig, config={"displayModeBar": False}, style={"height": "120px"})
|
| 434 |
else:
|
| 435 |
histogram = html.Div()
|
| 436 |
|
| 437 |
+
return html.Div([stats_row, histogram], className="progress-dashboard")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 438 |
|
| 439 |
|
| 440 |
def create_leaderboard_rows(leaderboard_data):
|
|
|
|
| 522 |
dcc.Store(id="seen-pairs", data=[]),
|
| 523 |
dcc.Store(id="current-pair", data=None),
|
| 524 |
dcc.Store(id="comparison-count", data=0),
|
| 525 |
+
dcc.Store(id="elo-info", data={}),
|
| 526 |
dcc.Store(id="session-id", data=""),
|
| 527 |
|
| 528 |
# Interval for progress updates
|
src/config.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""
|
| 2 |
|
| 3 |
import os
|
| 4 |
from pathlib import Path
|
|
@@ -8,7 +8,7 @@ from dotenv import load_dotenv
|
|
| 8 |
|
| 9 |
load_dotenv()
|
| 10 |
|
| 11 |
-
# HuggingFace
|
| 12 |
HF_TOKEN = os.getenv("HF_TOKEN", "")
|
| 13 |
HF_LOG_REPO_ID = os.getenv("HF_LOG_REPO_ID", "")
|
| 14 |
HF_LOG_EVERY_MINUTES = int(os.getenv("HF_LOG_EVERY_MINUTES", "10"))
|
|
@@ -28,10 +28,6 @@ DATASET_SPLIT = _dataset_config.get("split", "train")
|
|
| 28 |
IMAGE_COLUMN = _dataset_config.get("image_column", "image")
|
| 29 |
ID_COLUMN = _dataset_config.get("id_column", "id_str")
|
| 30 |
POOL_SIZE = _dataset_config.get("pool_size", 300)
|
| 31 |
-
|
| 32 |
-
MAX_COMPS_PER_ROUND = _dataset_config.get("max_comparisons_per_round", 5)
|
| 33 |
-
ELIMINATION_FRACTION = _dataset_config.get("elimination_fraction", 0.5)
|
| 34 |
-
FINAL_POOL_SIZE = _dataset_config.get("final_pool_size", 100)
|
| 35 |
IMAGE_CACHE_DIR = _dataset_config.get("image_cache_dir", "cache/images")
|
| 36 |
IMAGE_CACHE_MAX_BYTES = _dataset_config.get("image_cache_max_bytes", 524288000)
|
| 37 |
-
CACHE_PREFETCH_COUNT = _dataset_config.get("cache_prefetch_count", 20)
|
|
|
|
| 1 |
+
"""Perihelion configuration."""
|
| 2 |
|
| 3 |
import os
|
| 4 |
from pathlib import Path
|
|
|
|
| 8 |
|
| 9 |
load_dotenv()
|
| 10 |
|
| 11 |
+
# HuggingFace (secrets stay as env vars)
|
| 12 |
HF_TOKEN = os.getenv("HF_TOKEN", "")
|
| 13 |
HF_LOG_REPO_ID = os.getenv("HF_LOG_REPO_ID", "")
|
| 14 |
HF_LOG_EVERY_MINUTES = int(os.getenv("HF_LOG_EVERY_MINUTES", "10"))
|
|
|
|
| 28 |
IMAGE_COLUMN = _dataset_config.get("image_column", "image")
|
| 29 |
ID_COLUMN = _dataset_config.get("id_column", "id_str")
|
| 30 |
POOL_SIZE = _dataset_config.get("pool_size", 300)
|
| 31 |
+
POOL_SEED = _dataset_config.get("pool_seed", 42)
|
|
|
|
|
|
|
|
|
|
| 32 |
IMAGE_CACHE_DIR = _dataset_config.get("image_cache_dir", "cache/images")
|
| 33 |
IMAGE_CACHE_MAX_BYTES = _dataset_config.get("image_cache_max_bytes", 524288000)
|
|
|
src/elo.py
CHANGED
|
@@ -1,9 +1,8 @@
|
|
| 1 |
-
"""ELO rating system
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
import json
|
| 6 |
-
import math
|
| 7 |
import random
|
| 8 |
import threading
|
| 9 |
import logging
|
|
@@ -12,14 +11,12 @@ from pathlib import Path
|
|
| 12 |
from huggingface_hub import CommitScheduler, hf_hub_download
|
| 13 |
|
| 14 |
from src.config import (
|
|
|
|
| 15 |
DEFAULT_ELO,
|
| 16 |
ELO_K_FACTOR,
|
| 17 |
-
ELIMINATION_FRACTION,
|
| 18 |
-
FINAL_POOL_SIZE,
|
| 19 |
HF_LOG_EVERY_MINUTES,
|
| 20 |
HF_LOG_REPO_ID,
|
| 21 |
HF_TOKEN,
|
| 22 |
-
MIN_COMPS_PER_ROUND,
|
| 23 |
)
|
| 24 |
|
| 25 |
logger = logging.getLogger(__name__)
|
|
@@ -28,61 +25,44 @@ STATE_DIR = Path("state")
|
|
| 28 |
STATE_FILE = STATE_DIR / "elo_state.json"
|
| 29 |
|
| 30 |
_lock = threading.Lock()
|
| 31 |
-
_state:
|
| 32 |
_state_scheduler = None
|
| 33 |
|
| 34 |
|
| 35 |
-
class
|
| 36 |
-
"""
|
| 37 |
|
| 38 |
def __init__(
|
| 39 |
self,
|
| 40 |
-
|
| 41 |
elo_ratings: dict[int, float] | None = None,
|
| 42 |
-
round_comparisons: dict[int, int] | None = None,
|
| 43 |
-
current_round: int = 1,
|
| 44 |
-
eliminated: list[int] | None = None,
|
| 45 |
total_comparisons: int = 0,
|
| 46 |
-
|
| 47 |
-
pool_seed: int | None = None,
|
| 48 |
):
|
| 49 |
-
self.
|
| 50 |
-
self.elo_ratings = elo_ratings or {idx: DEFAULT_ELO for idx in
|
| 51 |
-
self.round_comparisons = round_comparisons or {idx: 0 for idx in active_pool}
|
| 52 |
-
self.current_round = current_round
|
| 53 |
-
self.eliminated = eliminated or []
|
| 54 |
self.total_comparisons = total_comparisons
|
| 55 |
-
self.
|
| 56 |
-
self.pool_seed = pool_seed
|
| 57 |
|
| 58 |
def to_dict(self) -> dict:
|
| 59 |
return {
|
| 60 |
-
"
|
| 61 |
"elo_ratings": {str(k): v for k, v in self.elo_ratings.items()},
|
| 62 |
-
"round_comparisons": {str(k): v for k, v in self.round_comparisons.items()},
|
| 63 |
-
"current_round": self.current_round,
|
| 64 |
-
"eliminated": self.eliminated,
|
| 65 |
"total_comparisons": self.total_comparisons,
|
| 66 |
-
"
|
| 67 |
-
"pool_seed": self.pool_seed,
|
| 68 |
}
|
| 69 |
|
| 70 |
@classmethod
|
| 71 |
-
def from_dict(cls, d: dict) ->
|
| 72 |
return cls(
|
| 73 |
-
|
| 74 |
elo_ratings={int(k): v for k, v in d["elo_ratings"].items()},
|
| 75 |
-
round_comparisons={int(k): v for k, v in d["round_comparisons"].items()},
|
| 76 |
-
current_round=d["current_round"],
|
| 77 |
-
eliminated=d.get("eliminated", []),
|
| 78 |
total_comparisons=d.get("total_comparisons", 0),
|
| 79 |
-
|
| 80 |
-
pool_seed=d.get("pool_seed"),
|
| 81 |
)
|
| 82 |
|
| 83 |
|
| 84 |
def _init_scheduler():
|
| 85 |
-
"""Initialize the CommitScheduler for state persistence."""
|
| 86 |
global _state_scheduler
|
| 87 |
if not HF_LOG_REPO_ID:
|
| 88 |
return
|
|
@@ -98,24 +78,26 @@ def _init_scheduler():
|
|
| 98 |
logger.info("ELO state scheduler initialized (repo=%s)", HF_LOG_REPO_ID)
|
| 99 |
|
| 100 |
|
| 101 |
-
def
|
| 102 |
-
"""Create
|
| 103 |
global _state
|
| 104 |
with _lock:
|
| 105 |
-
_state =
|
| 106 |
_save_state()
|
| 107 |
_init_scheduler()
|
| 108 |
-
logger.info("
|
| 109 |
|
| 110 |
|
| 111 |
-
def
|
| 112 |
-
"""Try to restore
|
| 113 |
|
|
|
|
| 114 |
Returns True if state was loaded, False if starting fresh.
|
| 115 |
"""
|
| 116 |
global _state
|
| 117 |
|
| 118 |
-
|
|
|
|
| 119 |
if HF_LOG_REPO_ID:
|
| 120 |
try:
|
| 121 |
local_path = hf_hub_download(
|
|
@@ -126,46 +108,46 @@ def load_tournament_state() -> bool:
|
|
| 126 |
)
|
| 127 |
with open(local_path) as f:
|
| 128 |
raw = json.load(f)
|
| 129 |
-
|
| 130 |
-
if "active_pool" in raw:
|
| 131 |
-
with _lock:
|
| 132 |
-
_state = TournamentState.from_dict(raw)
|
| 133 |
-
_init_scheduler()
|
| 134 |
-
_save_state()
|
| 135 |
-
logger.info(
|
| 136 |
-
"Loaded tournament state from HF: round %d, %d active galaxies",
|
| 137 |
-
_state.current_round,
|
| 138 |
-
len(_state.active_pool),
|
| 139 |
-
)
|
| 140 |
-
return True
|
| 141 |
-
else:
|
| 142 |
-
logger.info("Old-format state found on HF, ignoring")
|
| 143 |
except Exception as e:
|
| 144 |
logger.warning("Could not load state from HF: %s", e)
|
| 145 |
|
| 146 |
-
|
| 147 |
-
if STATE_FILE.exists():
|
| 148 |
try:
|
| 149 |
with open(STATE_FILE) as f:
|
| 150 |
raw = json.load(f)
|
| 151 |
-
|
| 152 |
-
with _lock:
|
| 153 |
-
_state = TournamentState.from_dict(raw)
|
| 154 |
-
_init_scheduler()
|
| 155 |
-
logger.info(
|
| 156 |
-
"Loaded tournament state from local file: round %d, %d active",
|
| 157 |
-
_state.current_round,
|
| 158 |
-
len(_state.active_pool),
|
| 159 |
-
)
|
| 160 |
-
return True
|
| 161 |
except Exception as e:
|
| 162 |
logger.warning("Could not load local state: %s", e)
|
| 163 |
|
| 164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
|
| 167 |
def _save_state():
|
| 168 |
-
"""Write current tournament state to local JSON file."""
|
| 169 |
STATE_DIR.mkdir(parents=True, exist_ok=True)
|
| 170 |
with _lock:
|
| 171 |
if _state is None:
|
|
@@ -180,13 +162,10 @@ def _expected_score(rating_a: float, rating_b: float) -> float:
|
|
| 180 |
|
| 181 |
|
| 182 |
def record_comparison(winner_idx: int, loser_idx: int) -> dict:
|
| 183 |
-
"""Record a comparison
|
| 184 |
-
|
| 185 |
-
Returns dict with before/after ratings and round info.
|
| 186 |
-
"""
|
| 187 |
with _lock:
|
| 188 |
if _state is None:
|
| 189 |
-
raise RuntimeError("
|
| 190 |
|
| 191 |
elo_w_before = _state.elo_ratings.get(winner_idx, DEFAULT_ELO)
|
| 192 |
elo_l_before = _state.elo_ratings.get(loser_idx, DEFAULT_ELO)
|
|
@@ -199,14 +178,8 @@ def record_comparison(winner_idx: int, loser_idx: int) -> dict:
|
|
| 199 |
|
| 200 |
_state.elo_ratings[winner_idx] = elo_w_after
|
| 201 |
_state.elo_ratings[loser_idx] = elo_l_after
|
| 202 |
-
|
| 203 |
-
_state.round_comparisons[winner_idx] = _state.round_comparisons.get(winner_idx, 0) + 1
|
| 204 |
-
_state.round_comparisons[loser_idx] = _state.round_comparisons.get(loser_idx, 0) + 1
|
| 205 |
_state.total_comparisons += 1
|
| 206 |
|
| 207 |
-
round_before = _state.current_round
|
| 208 |
-
advanced = _check_and_advance_round()
|
| 209 |
-
|
| 210 |
_save_state()
|
| 211 |
|
| 212 |
return {
|
|
@@ -214,204 +187,68 @@ def record_comparison(winner_idx: int, loser_idx: int) -> dict:
|
|
| 214 |
"winner_elo_after": elo_w_after,
|
| 215 |
"loser_elo_before": elo_l_before,
|
| 216 |
"loser_elo_after": elo_l_after,
|
| 217 |
-
"round": round_before,
|
| 218 |
-
"round_advanced": advanced,
|
| 219 |
}
|
| 220 |
|
| 221 |
|
| 222 |
-
def _check_and_advance_round() -> bool:
|
| 223 |
-
"""Check if all active galaxies have enough comparisons; if so, advance.
|
| 224 |
-
|
| 225 |
-
Caller must hold _lock.
|
| 226 |
-
Returns True if a round was advanced.
|
| 227 |
-
"""
|
| 228 |
-
if _state is None or _state.tournament_complete:
|
| 229 |
-
return False
|
| 230 |
-
|
| 231 |
-
for idx in _state.active_pool:
|
| 232 |
-
if _state.round_comparisons.get(idx, 0) < MIN_COMPS_PER_ROUND:
|
| 233 |
-
return False
|
| 234 |
-
|
| 235 |
-
# All galaxies have enough comparisons — advance round
|
| 236 |
-
_advance_round()
|
| 237 |
-
return True
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
def _advance_round():
|
| 241 |
-
"""Eliminate bottom fraction, advance to next round. Caller holds _lock."""
|
| 242 |
-
if _state is None:
|
| 243 |
-
return
|
| 244 |
-
|
| 245 |
-
# Sort active pool by ELO descending
|
| 246 |
-
sorted_pool = sorted(
|
| 247 |
-
_state.active_pool,
|
| 248 |
-
key=lambda idx: _state.elo_ratings.get(idx, DEFAULT_ELO),
|
| 249 |
-
reverse=True,
|
| 250 |
-
)
|
| 251 |
-
|
| 252 |
-
keep_count = max(
|
| 253 |
-
FINAL_POOL_SIZE,
|
| 254 |
-
int(math.ceil(len(sorted_pool) * (1 - ELIMINATION_FRACTION))),
|
| 255 |
-
)
|
| 256 |
-
|
| 257 |
-
survivors = sorted_pool[:keep_count]
|
| 258 |
-
eliminated = sorted_pool[keep_count:]
|
| 259 |
-
|
| 260 |
-
_state.eliminated.extend(eliminated)
|
| 261 |
-
_state.active_pool = survivors
|
| 262 |
-
_state.round_comparisons = {idx: 0 for idx in survivors}
|
| 263 |
-
_state.current_round += 1
|
| 264 |
-
|
| 265 |
-
if len(survivors) <= FINAL_POOL_SIZE:
|
| 266 |
-
_state.tournament_complete = True
|
| 267 |
-
logger.info("Tournament complete! %d galaxies in final pool.", len(survivors))
|
| 268 |
-
else:
|
| 269 |
-
logger.info(
|
| 270 |
-
"Round %d: %d -> %d galaxies (eliminated %d)",
|
| 271 |
-
_state.current_round - 1,
|
| 272 |
-
len(sorted_pool),
|
| 273 |
-
len(survivors),
|
| 274 |
-
len(eliminated),
|
| 275 |
-
)
|
| 276 |
-
|
| 277 |
-
|
| 278 |
def select_pair(seen_pairs: set[tuple[int, int]]) -> tuple[int, int] | None:
|
| 279 |
-
"""
|
| 280 |
|
| 281 |
-
|
| 282 |
-
Returns None if tournament is complete or no pairs available.
|
| 283 |
"""
|
| 284 |
with _lock:
|
| 285 |
-
if _state is None
|
| 286 |
return None
|
| 287 |
-
|
| 288 |
-
pool = list(_state.active_pool)
|
| 289 |
if len(pool) < 2:
|
| 290 |
return None
|
| 291 |
|
| 292 |
-
# Prioritize galaxies needing more comparisons
|
| 293 |
-
needs_more = [
|
| 294 |
-
idx for idx in pool
|
| 295 |
-
if _state.round_comparisons.get(idx, 0) < MIN_COMPS_PER_ROUND
|
| 296 |
-
]
|
| 297 |
-
|
| 298 |
-
if not needs_more:
|
| 299 |
-
# All have enough — round should advance soon, but pick a pair anyway
|
| 300 |
-
needs_more = pool
|
| 301 |
-
|
| 302 |
-
# Swiss-style: pair galaxies with similar ELO
|
| 303 |
if random.random() < 0.3:
|
| 304 |
-
|
| 305 |
-
if len(needs_more) >= 2:
|
| 306 |
-
pair = random.sample(needs_more, 2)
|
| 307 |
-
else:
|
| 308 |
-
pair = random.sample(pool, 2)
|
| 309 |
else:
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
if (pair[0], pair[1]) in seen_pairs or (pair[1], pair[0]) in seen_pairs:
|
| 325 |
-
# Try a few more random attempts
|
| 326 |
-
for _ in range(50):
|
| 327 |
-
pair = random.sample(pool, 2)
|
| 328 |
-
if (pair[0], pair[1]) not in seen_pairs and (pair[1], pair[0]) not in seen_pairs:
|
| 329 |
-
break
|
| 330 |
-
else:
|
| 331 |
-
# All pairs exhausted for this session
|
| 332 |
-
return None
|
| 333 |
-
|
| 334 |
-
# Randomize left/right
|
| 335 |
if random.random() < 0.5:
|
| 336 |
return (pair[1], pair[0])
|
| 337 |
return (pair[0], pair[1])
|
| 338 |
|
| 339 |
|
| 340 |
-
def
|
| 341 |
-
"""Return
|
| 342 |
-
with _lock:
|
| 343 |
-
return _state.pool_seed if _state else None
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
def set_pool_seed(seed: int):
|
| 347 |
-
"""Store the pool seed into the current tournament state and save."""
|
| 348 |
-
with _lock:
|
| 349 |
-
if _state is not None:
|
| 350 |
-
_state.pool_seed = seed
|
| 351 |
-
_save_state()
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
def get_tournament_info() -> dict:
|
| 355 |
-
"""Return a snapshot of tournament state for the progress dashboard."""
|
| 356 |
with _lock:
|
| 357 |
if _state is None:
|
| 358 |
-
return {
|
| 359 |
-
"current_round": 0,
|
| 360 |
-
"pool_size": 0,
|
| 361 |
-
"total_comparisons": 0,
|
| 362 |
-
"tournament_complete": False,
|
| 363 |
-
"elo_values": [],
|
| 364 |
-
"top_indices": [],
|
| 365 |
-
"eliminated_count": 0,
|
| 366 |
-
}
|
| 367 |
-
|
| 368 |
-
elo_values = [_state.elo_ratings.get(idx, DEFAULT_ELO) for idx in _state.active_pool]
|
| 369 |
-
|
| 370 |
-
# Top 100 by ELO
|
| 371 |
-
sorted_pool = sorted(
|
| 372 |
-
_state.active_pool,
|
| 373 |
-
key=lambda idx: _state.elo_ratings.get(idx, DEFAULT_ELO),
|
| 374 |
-
reverse=True,
|
| 375 |
-
)
|
| 376 |
-
top_indices = sorted_pool[:100]
|
| 377 |
-
|
| 378 |
-
# Estimate remaining comparisons
|
| 379 |
-
comps_needed_this_round = sum(
|
| 380 |
-
max(0, MIN_COMPS_PER_ROUND - _state.round_comparisons.get(idx, 0))
|
| 381 |
-
for idx in _state.active_pool
|
| 382 |
-
)
|
| 383 |
-
# Each comparison covers 2 galaxies
|
| 384 |
-
est_remaining_this_round = max(0, comps_needed_this_round // 2)
|
| 385 |
-
|
| 386 |
return {
|
| 387 |
-
"
|
| 388 |
-
"pool_size": len(_state.active_pool),
|
| 389 |
"total_comparisons": _state.total_comparisons,
|
| 390 |
-
"
|
| 391 |
-
"elo_values": elo_values,
|
| 392 |
-
"top_indices": top_indices,
|
| 393 |
-
"eliminated_count": len(_state.eliminated),
|
| 394 |
-
"est_remaining_this_round": est_remaining_this_round,
|
| 395 |
}
|
| 396 |
|
| 397 |
|
| 398 |
def get_leaderboard() -> list[dict]:
|
| 399 |
-
"""
|
| 400 |
with _lock:
|
| 401 |
if _state is None:
|
| 402 |
return []
|
| 403 |
return sorted(
|
| 404 |
-
[
|
| 405 |
-
{"id": idx, "elo": _state.elo_ratings.get(idx, DEFAULT_ELO)}
|
| 406 |
-
for idx in _state.active_pool
|
| 407 |
-
],
|
| 408 |
key=lambda x: x["elo"],
|
| 409 |
reverse=True,
|
| 410 |
)[:20]
|
| 411 |
|
| 412 |
|
| 413 |
def get_rating(galaxy_idx: int) -> float:
|
| 414 |
-
"""Get current ELO rating for a galaxy."""
|
| 415 |
with _lock:
|
| 416 |
if _state is None:
|
| 417 |
return DEFAULT_ELO
|
|
|
|
| 1 |
+
"""ELO rating system for a persistent galaxy ranking."""
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
import json
|
|
|
|
| 6 |
import random
|
| 7 |
import threading
|
| 8 |
import logging
|
|
|
|
| 11 |
from huggingface_hub import CommitScheduler, hf_hub_download
|
| 12 |
|
| 13 |
from src.config import (
|
| 14 |
+
DATASET_ID,
|
| 15 |
DEFAULT_ELO,
|
| 16 |
ELO_K_FACTOR,
|
|
|
|
|
|
|
| 17 |
HF_LOG_EVERY_MINUTES,
|
| 18 |
HF_LOG_REPO_ID,
|
| 19 |
HF_TOKEN,
|
|
|
|
| 20 |
)
|
| 21 |
|
| 22 |
logger = logging.getLogger(__name__)
|
|
|
|
| 25 |
STATE_FILE = STATE_DIR / "elo_state.json"
|
| 26 |
|
| 27 |
_lock = threading.Lock()
|
| 28 |
+
_state: EloState | None = None
|
| 29 |
_state_scheduler = None
|
| 30 |
|
| 31 |
|
| 32 |
+
class EloState:
|
| 33 |
+
"""ELO ratings for a fixed pool of galaxies."""
|
| 34 |
|
| 35 |
def __init__(
|
| 36 |
self,
|
| 37 |
+
pool: list[int],
|
| 38 |
elo_ratings: dict[int, float] | None = None,
|
|
|
|
|
|
|
|
|
|
| 39 |
total_comparisons: int = 0,
|
| 40 |
+
dataset_id: str = "",
|
|
|
|
| 41 |
):
|
| 42 |
+
self.pool = list(pool)
|
| 43 |
+
self.elo_ratings = elo_ratings or {idx: DEFAULT_ELO for idx in pool}
|
|
|
|
|
|
|
|
|
|
| 44 |
self.total_comparisons = total_comparisons
|
| 45 |
+
self.dataset_id = dataset_id
|
|
|
|
| 46 |
|
| 47 |
def to_dict(self) -> dict:
|
| 48 |
return {
|
| 49 |
+
"pool": self.pool,
|
| 50 |
"elo_ratings": {str(k): v for k, v in self.elo_ratings.items()},
|
|
|
|
|
|
|
|
|
|
| 51 |
"total_comparisons": self.total_comparisons,
|
| 52 |
+
"dataset_id": self.dataset_id,
|
|
|
|
| 53 |
}
|
| 54 |
|
| 55 |
@classmethod
|
| 56 |
+
def from_dict(cls, d: dict) -> EloState:
|
| 57 |
return cls(
|
| 58 |
+
pool=d["pool"],
|
| 59 |
elo_ratings={int(k): v for k, v in d["elo_ratings"].items()},
|
|
|
|
|
|
|
|
|
|
| 60 |
total_comparisons=d.get("total_comparisons", 0),
|
| 61 |
+
dataset_id=d.get("dataset_id", ""),
|
|
|
|
| 62 |
)
|
| 63 |
|
| 64 |
|
| 65 |
def _init_scheduler():
|
|
|
|
| 66 |
global _state_scheduler
|
| 67 |
if not HF_LOG_REPO_ID:
|
| 68 |
return
|
|
|
|
| 78 |
logger.info("ELO state scheduler initialized (repo=%s)", HF_LOG_REPO_ID)
|
| 79 |
|
| 80 |
|
| 81 |
+
def initialize_elo(pool_indices: list[int]):
|
| 82 |
+
"""Create fresh ELO state for the given pool."""
|
| 83 |
global _state
|
| 84 |
with _lock:
|
| 85 |
+
_state = EloState(pool=pool_indices, dataset_id=DATASET_ID)
|
| 86 |
_save_state()
|
| 87 |
_init_scheduler()
|
| 88 |
+
logger.info("ELO state initialized with %d galaxies", len(pool_indices))
|
| 89 |
|
| 90 |
|
| 91 |
+
def load_elo_state() -> bool:
|
| 92 |
+
"""Try to restore ELO state from HF Hub or local file.
|
| 93 |
|
| 94 |
+
Discards saved state if it belongs to a different dataset.
|
| 95 |
Returns True if state was loaded, False if starting fresh.
|
| 96 |
"""
|
| 97 |
global _state
|
| 98 |
|
| 99 |
+
raw = None
|
| 100 |
+
|
| 101 |
if HF_LOG_REPO_ID:
|
| 102 |
try:
|
| 103 |
local_path = hf_hub_download(
|
|
|
|
| 108 |
)
|
| 109 |
with open(local_path) as f:
|
| 110 |
raw = json.load(f)
|
| 111 |
+
logger.info("Loaded state from HF Hub")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
except Exception as e:
|
| 113 |
logger.warning("Could not load state from HF: %s", e)
|
| 114 |
|
| 115 |
+
if raw is None and STATE_FILE.exists():
|
|
|
|
| 116 |
try:
|
| 117 |
with open(STATE_FILE) as f:
|
| 118 |
raw = json.load(f)
|
| 119 |
+
logger.info("Loaded state from local file")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
except Exception as e:
|
| 121 |
logger.warning("Could not load local state: %s", e)
|
| 122 |
|
| 123 |
+
if raw is None:
|
| 124 |
+
return False
|
| 125 |
+
|
| 126 |
+
# Validate dataset match
|
| 127 |
+
saved_dataset = raw.get("dataset_id", "")
|
| 128 |
+
if saved_dataset and saved_dataset != DATASET_ID:
|
| 129 |
+
logger.info(
|
| 130 |
+
"Saved state is for dataset '%s', current is '%s' — starting fresh",
|
| 131 |
+
saved_dataset,
|
| 132 |
+
DATASET_ID,
|
| 133 |
+
)
|
| 134 |
+
return False
|
| 135 |
+
|
| 136 |
+
# Must have 'pool' key (new format); ignore old tournament-format files
|
| 137 |
+
if "pool" not in raw:
|
| 138 |
+
logger.info("Saved state is old format — starting fresh")
|
| 139 |
+
return False
|
| 140 |
+
|
| 141 |
+
with _lock:
|
| 142 |
+
_state = EloState.from_dict(raw)
|
| 143 |
+
_init_scheduler()
|
| 144 |
+
_save_state()
|
| 145 |
+
logger.info("Restored ELO state: %d galaxies, %d comparisons",
|
| 146 |
+
len(_state.pool), _state.total_comparisons)
|
| 147 |
+
return True
|
| 148 |
|
| 149 |
|
| 150 |
def _save_state():
|
|
|
|
| 151 |
STATE_DIR.mkdir(parents=True, exist_ok=True)
|
| 152 |
with _lock:
|
| 153 |
if _state is None:
|
|
|
|
| 162 |
|
| 163 |
|
| 164 |
def record_comparison(winner_idx: int, loser_idx: int) -> dict:
|
| 165 |
+
"""Record a comparison and update ELO ratings."""
|
|
|
|
|
|
|
|
|
|
| 166 |
with _lock:
|
| 167 |
if _state is None:
|
| 168 |
+
raise RuntimeError("ELO state not initialized")
|
| 169 |
|
| 170 |
elo_w_before = _state.elo_ratings.get(winner_idx, DEFAULT_ELO)
|
| 171 |
elo_l_before = _state.elo_ratings.get(loser_idx, DEFAULT_ELO)
|
|
|
|
| 178 |
|
| 179 |
_state.elo_ratings[winner_idx] = elo_w_after
|
| 180 |
_state.elo_ratings[loser_idx] = elo_l_after
|
|
|
|
|
|
|
|
|
|
| 181 |
_state.total_comparisons += 1
|
| 182 |
|
|
|
|
|
|
|
|
|
|
| 183 |
_save_state()
|
| 184 |
|
| 185 |
return {
|
|
|
|
| 187 |
"winner_elo_after": elo_w_after,
|
| 188 |
"loser_elo_before": elo_l_before,
|
| 189 |
"loser_elo_after": elo_l_after,
|
|
|
|
|
|
|
| 190 |
}
|
| 191 |
|
| 192 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
def select_pair(seen_pairs: set[tuple[int, int]]) -> tuple[int, int] | None:
|
| 194 |
+
"""Select a pair to compare.
|
| 195 |
|
| 196 |
+
70% close-ELO matchup, 30% random. Returns None if no unseen pair available.
|
|
|
|
| 197 |
"""
|
| 198 |
with _lock:
|
| 199 |
+
if _state is None:
|
| 200 |
return None
|
| 201 |
+
pool = list(_state.pool)
|
|
|
|
| 202 |
if len(pool) < 2:
|
| 203 |
return None
|
| 204 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
if random.random() < 0.3:
|
| 206 |
+
pair = random.sample(pool, 2)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
else:
|
| 208 |
+
rated = sorted(pool, key=lambda idx: _state.elo_ratings.get(idx, DEFAULT_ELO))
|
| 209 |
+
start = random.randint(0, len(rated) - 2)
|
| 210 |
+
pair = [rated[start], rated[start + 1]]
|
| 211 |
+
|
| 212 |
+
if (pair[0], pair[1]) in seen_pairs or (pair[1], pair[0]) in seen_pairs:
|
| 213 |
+
with _lock:
|
| 214 |
+
pool = list(_state.pool)
|
| 215 |
+
for _ in range(50):
|
| 216 |
+
pair = random.sample(pool, 2)
|
| 217 |
+
if (pair[0], pair[1]) not in seen_pairs and (pair[1], pair[0]) not in seen_pairs:
|
| 218 |
+
break
|
| 219 |
+
else:
|
| 220 |
+
return None
|
| 221 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
if random.random() < 0.5:
|
| 223 |
return (pair[1], pair[0])
|
| 224 |
return (pair[0], pair[1])
|
| 225 |
|
| 226 |
|
| 227 |
+
def get_info() -> dict:
|
| 228 |
+
"""Return a snapshot of ELO state for the progress dashboard."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
with _lock:
|
| 230 |
if _state is None:
|
| 231 |
+
return {"pool_size": 0, "total_comparisons": 0, "elo_values": []}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
return {
|
| 233 |
+
"pool_size": len(_state.pool),
|
|
|
|
| 234 |
"total_comparisons": _state.total_comparisons,
|
| 235 |
+
"elo_values": [_state.elo_ratings.get(idx, DEFAULT_ELO) for idx in _state.pool],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
}
|
| 237 |
|
| 238 |
|
| 239 |
def get_leaderboard() -> list[dict]:
|
| 240 |
+
"""Return top 20 galaxies by ELO descending."""
|
| 241 |
with _lock:
|
| 242 |
if _state is None:
|
| 243 |
return []
|
| 244 |
return sorted(
|
| 245 |
+
[{"id": idx, "elo": _state.elo_ratings.get(idx, DEFAULT_ELO)} for idx in _state.pool],
|
|
|
|
|
|
|
|
|
|
| 246 |
key=lambda x: x["elo"],
|
| 247 |
reverse=True,
|
| 248 |
)[:20]
|
| 249 |
|
| 250 |
|
| 251 |
def get_rating(galaxy_idx: int) -> float:
|
|
|
|
| 252 |
with _lock:
|
| 253 |
if _state is None:
|
| 254 |
return DEFAULT_ELO
|