perihelion / src /elo.py
Smith42's picture
HHH
d5a68f6
"""ELO rating system for a persistent galaxy ranking."""
from __future__ import annotations
import json
import random
import threading
import logging
from pathlib import Path
from huggingface_hub import CommitScheduler, hf_hub_download
from src.config import (
DATASET_ID,
DEFAULT_ELO,
ELO_K_FACTOR,
HF_LOG_EVERY_MINUTES,
HF_LOG_REPO_ID,
HF_TOKEN,
)
# Imported lazily to avoid circular import at module load time
def _get_display_name(row_index: int) -> str:
try:
from src.galaxy_profiles import get_display_name
return get_display_name(row_index)
except Exception:
return str(row_index)
logger = logging.getLogger(__name__)
STATE_DIR = Path("state")
STATE_FILE = STATE_DIR / "elo_state.json"
_lock = threading.Lock()
_state: EloState | None = None
_state_scheduler = None
class EloState:
"""ELO ratings for a fixed pool of galaxies."""
def __init__(
self,
pool: list[int],
elo_ratings: dict[int, float] | None = None,
total_comparisons: int = 0,
dataset_id: str = "",
):
self.pool = list(pool)
self.elo_ratings = elo_ratings or {idx: DEFAULT_ELO for idx in pool}
self.total_comparisons = total_comparisons
self.dataset_id = dataset_id
def to_dict(self) -> dict:
named_elo = {
_get_display_name(idx): self.elo_ratings.get(idx, DEFAULT_ELO)
for idx in self.pool
}
rankings = sorted(
[{"galaxy_id": gid, "elo": elo} for gid, elo in named_elo.items()],
key=lambda x: x["elo"],
reverse=True,
)
return {
"pool": [_get_display_name(idx) for idx in self.pool],
"elo_ratings": named_elo,
"total_comparisons": self.total_comparisons,
"dataset_id": self.dataset_id,
"rankings": rankings,
}
@classmethod
def from_dict(cls, d: dict, id_to_row: dict[str, int] | None = None) -> EloState:
"""Restore from a saved dict.
If *id_to_row* is provided (display-name β†’ row-index map), pool entries
and elo_ratings keys are treated as display names and converted back to
row indices. Entries that have no mapping are silently dropped.
"""
if id_to_row is not None:
pool = [id_to_row[gid] for gid in d["pool"] if gid in id_to_row]
elo_ratings = {
id_to_row[gid]: v
for gid, v in d["elo_ratings"].items()
if gid in id_to_row
}
else:
pool = d["pool"]
elo_ratings = {int(k): v for k, v in d["elo_ratings"].items()}
return cls(
pool=pool,
elo_ratings=elo_ratings,
total_comparisons=d.get("total_comparisons", 0),
dataset_id=d.get("dataset_id", ""),
)
def _init_scheduler():
global _state_scheduler
if not HF_LOG_REPO_ID:
return
STATE_DIR.mkdir(parents=True, exist_ok=True)
_state_scheduler = CommitScheduler(
repo_id=HF_LOG_REPO_ID,
repo_type="dataset",
folder_path=STATE_DIR,
path_in_repo="state",
every=HF_LOG_EVERY_MINUTES,
token=HF_TOKEN if HF_TOKEN else None,
)
logger.info("ELO state scheduler initialized (repo=%s)", HF_LOG_REPO_ID)
def initialize_elo(pool_indices: list[int]):
"""Create fresh ELO state for the given pool."""
global _state
with _lock:
_state = EloState(pool=pool_indices, dataset_id=DATASET_ID)
_save_state()
_init_scheduler()
logger.info("ELO state initialized with %d galaxies", len(pool_indices))
def load_elo_state() -> bool:
"""Try to restore ELO state from HF Hub or local file.
Discards saved state if it belongs to a different dataset.
Returns True if state was loaded, False if starting fresh.
"""
global _state
raw = None
if HF_LOG_REPO_ID:
try:
local_path = hf_hub_download(
repo_id=HF_LOG_REPO_ID,
repo_type="dataset",
filename="state/elo_state.json",
token=HF_TOKEN if HF_TOKEN else None,
force_download=True,
)
with open(local_path) as f:
raw = json.load(f)
logger.info("Loaded state from HF Hub")
except Exception as e:
logger.warning("Could not load state from HF: %s", e)
if raw is None and STATE_FILE.exists():
try:
with open(STATE_FILE) as f:
raw = json.load(f)
logger.info("Loaded state from local file")
except Exception as e:
logger.warning("Could not load local state: %s", e)
if raw is None:
return False
# Validate dataset match
saved_dataset = raw.get("dataset_id", "")
if saved_dataset and saved_dataset != DATASET_ID:
logger.info(
"Saved state is for dataset '%s', current is '%s' β€” starting fresh",
saved_dataset,
DATASET_ID,
)
return False
# Must have 'pool' key (new format); ignore old tournament-format files
if "pool" not in raw:
logger.info("Saved state is old format β€” starting fresh")
return False
# Build reverse map: display name β†’ row index (requires metadata to be loaded first)
id_to_row: dict[str, int] | None = None
pool_sample = raw["pool"]
if pool_sample and isinstance(pool_sample[0], str):
# New format: pool contains display names β€” reverse-map via metadata cache
from src.galaxy_profiles import get_row_index_by_id
id_to_row = {}
for gid in raw["pool"]:
row = get_row_index_by_id(gid)
if row is not None:
id_to_row[gid] = row
with _lock:
_state = EloState.from_dict(raw, id_to_row=id_to_row)
_init_scheduler()
_save_state()
logger.info("Restored ELO state: %d galaxies, %d comparisons",
len(_state.pool), _state.total_comparisons)
return True
def _save_state():
STATE_DIR.mkdir(parents=True, exist_ok=True)
with _lock:
if _state is None:
return
data = _state.to_dict()
if _state_scheduler is not None:
with _state_scheduler.lock:
with open(STATE_FILE, "w") as f:
json.dump(data, f, indent=2)
else:
with open(STATE_FILE, "w") as f:
json.dump(data, f, indent=2)
def _expected_score(rating_a: float, rating_b: float) -> float:
return 1.0 / (1.0 + 10.0 ** ((rating_b - rating_a) / 400.0))
def record_comparison(winner_idx: int, loser_idx: int) -> dict:
"""Record a comparison and update ELO ratings."""
with _lock:
if _state is None:
raise RuntimeError("ELO state not initialized")
elo_w_before = _state.elo_ratings.get(winner_idx, DEFAULT_ELO)
elo_l_before = _state.elo_ratings.get(loser_idx, DEFAULT_ELO)
expected_w = _expected_score(elo_w_before, elo_l_before)
expected_l = _expected_score(elo_l_before, elo_w_before)
elo_w_after = elo_w_before + ELO_K_FACTOR * (1.0 - expected_w)
elo_l_after = elo_l_before + ELO_K_FACTOR * (0.0 - expected_l)
_state.elo_ratings[winner_idx] = elo_w_after
_state.elo_ratings[loser_idx] = elo_l_after
_state.total_comparisons += 1
_save_state()
return {
"winner_elo_before": elo_w_before,
"winner_elo_after": elo_w_after,
"loser_elo_before": elo_l_before,
"loser_elo_after": elo_l_after,
}
def select_pair() -> tuple[int, int] | None:
"""Select a pair to compare.
70% close-ELO matchup, 30% random.
"""
with _lock:
if _state is None:
return None
pool = list(_state.pool)
if len(pool) < 2:
return None
if random.random() < 0.3:
pair = random.sample(pool, 2)
else:
rated = sorted(pool, key=lambda idx: _state.elo_ratings.get(idx, DEFAULT_ELO))
start = random.randint(0, len(rated) - 2)
pair = [rated[start], rated[start + 1]]
if random.random() < 0.5:
return (pair[1], pair[0])
return (pair[0], pair[1])
def get_info() -> dict:
"""Return a snapshot of ELO state for the progress dashboard."""
with _lock:
if _state is None:
return {"pool_size": 0, "total_comparisons": 0, "elo_values": []}
return {
"pool_size": len(_state.pool),
"total_comparisons": _state.total_comparisons,
"elo_values": [_state.elo_ratings.get(idx, DEFAULT_ELO) for idx in _state.pool],
}
def get_leaderboard() -> list[dict]:
"""Return top 20 galaxies by ELO descending."""
with _lock:
if _state is None:
return []
return sorted(
[{"id": idx, "elo": _state.elo_ratings.get(idx, DEFAULT_ELO)} for idx in _state.pool],
key=lambda x: x["elo"],
reverse=True,
)[:20]
def get_rating(galaxy_idx: int) -> float:
with _lock:
if _state is None:
return DEFAULT_ELO
return _state.elo_ratings.get(galaxy_idx, DEFAULT_ELO)