Spaces:
Running
Running
| """ELO rating system for a persistent galaxy ranking.""" | |
| from __future__ import annotations | |
| import json | |
| import random | |
| import threading | |
| import logging | |
| from pathlib import Path | |
| from huggingface_hub import CommitScheduler, hf_hub_download | |
| from src.config import ( | |
| DATASET_ID, | |
| DEFAULT_ELO, | |
| ELO_K_FACTOR, | |
| HF_LOG_EVERY_MINUTES, | |
| HF_LOG_REPO_ID, | |
| HF_TOKEN, | |
| ) | |
| # Imported lazily to avoid circular import at module load time | |
| def _get_display_name(row_index: int) -> str: | |
| try: | |
| from src.galaxy_profiles import get_display_name | |
| return get_display_name(row_index) | |
| except Exception: | |
| return str(row_index) | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Local directory / file where the ELO state is persisted as JSON.
STATE_DIR = Path("state")
STATE_FILE = STATE_DIR / "elo_state.json"

# Guards access to the module-level ``_state``. This is a plain (non-reentrant)
# Lock, so code holding it must not call a function that acquires it again.
_lock = threading.Lock()
# Current ELO state; None until initialize_elo() or load_elo_state() sets it.
_state: EloState | None = None
# CommitScheduler created by _init_scheduler(); stays None until then (and
# _init_scheduler() creates none when HF_LOG_REPO_ID is falsy).
_state_scheduler: CommitScheduler | None = None
| class EloState: | |
| """ELO ratings for a fixed pool of galaxies.""" | |
| def __init__( | |
| self, | |
| pool: list[int], | |
| elo_ratings: dict[int, float] | None = None, | |
| total_comparisons: int = 0, | |
| dataset_id: str = "", | |
| ): | |
| self.pool = list(pool) | |
| self.elo_ratings = elo_ratings or {idx: DEFAULT_ELO for idx in pool} | |
| self.total_comparisons = total_comparisons | |
| self.dataset_id = dataset_id | |
| def to_dict(self) -> dict: | |
| named_elo = { | |
| _get_display_name(idx): self.elo_ratings.get(idx, DEFAULT_ELO) | |
| for idx in self.pool | |
| } | |
| rankings = sorted( | |
| [{"galaxy_id": gid, "elo": elo} for gid, elo in named_elo.items()], | |
| key=lambda x: x["elo"], | |
| reverse=True, | |
| ) | |
| return { | |
| "pool": [_get_display_name(idx) for idx in self.pool], | |
| "elo_ratings": named_elo, | |
| "total_comparisons": self.total_comparisons, | |
| "dataset_id": self.dataset_id, | |
| "rankings": rankings, | |
| } | |
| def from_dict(cls, d: dict, id_to_row: dict[str, int] | None = None) -> EloState: | |
| """Restore from a saved dict. | |
| If *id_to_row* is provided (display-name β row-index map), pool entries | |
| and elo_ratings keys are treated as display names and converted back to | |
| row indices. Entries that have no mapping are silently dropped. | |
| """ | |
| if id_to_row is not None: | |
| pool = [id_to_row[gid] for gid in d["pool"] if gid in id_to_row] | |
| elo_ratings = { | |
| id_to_row[gid]: v | |
| for gid, v in d["elo_ratings"].items() | |
| if gid in id_to_row | |
| } | |
| else: | |
| pool = d["pool"] | |
| elo_ratings = {int(k): v for k, v in d["elo_ratings"].items()} | |
| return cls( | |
| pool=pool, | |
| elo_ratings=elo_ratings, | |
| total_comparisons=d.get("total_comparisons", 0), | |
| dataset_id=d.get("dataset_id", ""), | |
| ) | |
def _init_scheduler():
    """Create the CommitScheduler that syncs STATE_DIR to the Hub dataset repo.

    Does nothing when ``HF_LOG_REPO_ID`` is falsy, or when a scheduler has
    already been created. Both ``initialize_elo`` and ``load_elo_state`` call
    this, so without the second check a repeated call would build another
    scheduler over the same folder.
    """
    global _state_scheduler
    if not HF_LOG_REPO_ID:
        return
    if _state_scheduler is not None:
        # Already initialised; keep the existing scheduler (and its lock).
        return
    STATE_DIR.mkdir(parents=True, exist_ok=True)
    _state_scheduler = CommitScheduler(
        repo_id=HF_LOG_REPO_ID,
        repo_type="dataset",
        folder_path=STATE_DIR,
        path_in_repo="state",
        every=HF_LOG_EVERY_MINUTES,
        token=HF_TOKEN if HF_TOKEN else None,
    )
    logger.info("ELO state scheduler initialized (repo=%s)", HF_LOG_REPO_ID)
def initialize_elo(pool_indices: list[int]):
    """Replace the module state with a fresh ``EloState`` for *pool_indices*.

    The new state is tagged with ``DATASET_ID``, written to disk, and the Hub
    scheduler is set up afterwards.
    """
    global _state
    fresh_state = EloState(pool=pool_indices, dataset_id=DATASET_ID)
    with _lock:
        _state = fresh_state
    # _save_state() acquires _lock itself, so it has to run after the lock
    # above has been released.
    _save_state()
    _init_scheduler()
    logger.info("ELO state initialized with %d galaxies", len(pool_indices))
def _read_saved_state():
    """Return the parsed saved-state JSON from the Hub or the local file.

    The Hub copy is tried first (only when ``HF_LOG_REPO_ID`` is set), then
    ``STATE_FILE``. Any failure is logged as a warning and treated as "not
    available". Returns ``None`` when neither source yields data.
    """
    raw = None
    if HF_LOG_REPO_ID:
        try:
            local_path = hf_hub_download(
                repo_id=HF_LOG_REPO_ID,
                repo_type="dataset",
                filename="state/elo_state.json",
                token=HF_TOKEN if HF_TOKEN else None,
                force_download=True,
            )
            with open(local_path, encoding="utf-8") as f:
                raw = json.load(f)
            logger.info("Loaded state from HF Hub")
        except Exception as e:
            logger.warning("Could not load state from HF: %s", e)
    if raw is None and STATE_FILE.exists():
        try:
            with open(STATE_FILE, encoding="utf-8") as f:
                raw = json.load(f)
            logger.info("Loaded state from local file")
        except Exception as e:
            logger.warning("Could not load local state: %s", e)
    return raw


def load_elo_state() -> bool:
    """Try to restore ELO state from HF Hub or local file.

    Discards saved state if it belongs to a different dataset, is in the old
    (pool-less) format, is malformed, or if none of its galaxies can be mapped
    back to a row index.

    Returns True if state was loaded, False if starting fresh.
    """
    global _state
    raw = _read_saved_state()
    if raw is None:
        return False
    if not isinstance(raw, dict):
        # Valid JSON but not an object (e.g. a list): unusable as state.
        logger.warning("Saved state is not a JSON object — starting fresh")
        return False

    # Validate dataset match
    saved_dataset = raw.get("dataset_id", "")
    if saved_dataset and saved_dataset != DATASET_ID:
        logger.info(
            "Saved state is for dataset '%s', current is '%s' — starting fresh",
            saved_dataset,
            DATASET_ID,
        )
        return False

    # Must have 'pool' key (new format); ignore old tournament-format files
    if "pool" not in raw:
        logger.info("Saved state is old format — starting fresh")
        return False

    # Build reverse map: display name → row index (requires metadata to be loaded first)
    id_to_row: dict[str, int] | None = None
    saved_pool = raw["pool"]
    if saved_pool and isinstance(saved_pool[0], str):
        # New format: pool contains display names — reverse-map via metadata cache
        from src.galaxy_profiles import get_row_index_by_id

        id_to_row = {}
        for gid in saved_pool:
            row = get_row_index_by_id(gid)
            if row is not None:
                id_to_row[gid] = row

    try:
        restored = EloState.from_dict(raw, id_to_row=id_to_row)
    except (KeyError, ValueError) as e:
        # Missing 'elo_ratings' or non-numeric legacy keys: treat as corrupt.
        logger.warning("Saved state is malformed (%s) — starting fresh", e)
        return False

    if saved_pool and not restored.pool:
        # Nothing could be mapped back. Restoring would immediately save an
        # empty state over the stored ratings, so refuse instead.
        logger.warning("No saved galaxy could be mapped to a row index — starting fresh")
        return False

    with _lock:
        _state = restored
    # Both calls run after releasing _lock: _save_state() acquires it itself.
    _init_scheduler()
    _save_state()
    # Log from the local object rather than re-reading _state outside the lock.
    logger.info(
        "Restored ELO state: %d galaxies, %d comparisons",
        len(restored.pool),
        restored.total_comparisons,
    )
    return True
def _save_state():
    """Write the current ELO state to ``STATE_FILE`` as indented JSON.

    ``STATE_DIR`` is created if needed. Nothing is written when no state is
    loaded. When a scheduler exists, its lock is held during the write.

    The snapshot and the write both happen under ``_lock`` so that concurrent
    saves cannot interleave or overwrite a newer snapshot with an older one.
    Callers must therefore NOT hold ``_lock`` when calling this function: it is
    a non-reentrant lock and the call would deadlock.
    """
    STATE_DIR.mkdir(parents=True, exist_ok=True)
    with _lock:
        if _state is None:
            return
        # Serialise before touching the file: opening with "w" truncates it, so
        # a serialisation error must not be able to happen after that point.
        payload = json.dumps(_state.to_dict(), indent=2)
        scheduler = _state_scheduler
        if scheduler is None:
            STATE_FILE.write_text(payload, encoding="utf-8")
            return
        with scheduler.lock:
            STATE_FILE.write_text(payload, encoding="utf-8")
| def _expected_score(rating_a: float, rating_b: float) -> float: | |
| return 1.0 / (1.0 + 10.0 ** ((rating_b - rating_a) / 400.0)) | |
def record_comparison(winner_idx: int, loser_idx: int) -> dict:
    """Record a comparison and update ELO ratings.

    Both ratings are updated with ``ELO_K_FACTOR``; a galaxy with no stored
    rating starts from ``DEFAULT_ELO``. The comparison counter is incremented
    and the state is saved.

    Args:
        winner_idx: Row index of the galaxy that won.
        loser_idx: Row index of the galaxy that lost.

    Returns:
        A dict with ``winner_elo_before``, ``winner_elo_after``,
        ``loser_elo_before`` and ``loser_elo_after``.

    Raises:
        ValueError: If both indices are the same galaxy.
        RuntimeError: If the ELO state has not been initialized.
    """
    if winner_idx == loser_idx:
        # A self-match is meaningless; without this check the loser write
        # below would overwrite the winner write and lower the rating.
        raise ValueError("winner_idx and loser_idx must refer to different galaxies")

    with _lock:
        if _state is None:
            raise RuntimeError("ELO state not initialized")
        elo_w_before = _state.elo_ratings.get(winner_idx, DEFAULT_ELO)
        elo_l_before = _state.elo_ratings.get(loser_idx, DEFAULT_ELO)
        expected_w = _expected_score(elo_w_before, elo_l_before)
        expected_l = _expected_score(elo_l_before, elo_w_before)
        # Actual score is 1 for the winner and 0 for the loser.
        elo_w_after = elo_w_before + ELO_K_FACTOR * (1.0 - expected_w)
        elo_l_after = elo_l_before + ELO_K_FACTOR * (0.0 - expected_l)
        _state.elo_ratings[winner_idx] = elo_w_after
        _state.elo_ratings[loser_idx] = elo_l_after
        _state.total_comparisons += 1

    # _save_state() acquires _lock itself, so it runs after the lock is released.
    _save_state()
    return {
        "winner_elo_before": elo_w_before,
        "winner_elo_after": elo_w_after,
        "loser_elo_before": elo_l_before,
        "loser_elo_after": elo_l_after,
    }
def select_pair() -> tuple[int, int] | None:
    """Pick two distinct pool members to compare next.

    With probability 0.3 the two are drawn uniformly at random; otherwise a
    random pair of neighbours in rating order is chosen. The order within the
    returned pair is then flipped with probability 0.5.

    Returns ``None`` when no state is loaded or the pool has fewer than two
    members.
    """
    with _lock:
        state = _state
        if state is None:
            return None
        candidates = list(state.pool)
        ratings = dict(state.elo_ratings)

    if len(candidates) < 2:
        return None

    if random.random() >= 0.3:
        # Close-ELO matchup: neighbours in ascending rating order.
        by_rating = sorted(candidates, key=lambda gid: ratings.get(gid, DEFAULT_ELO))
        lower = random.randint(0, len(by_rating) - 2)
        first, second = by_rating[lower], by_rating[lower + 1]
    else:
        first, second = random.sample(candidates, 2)

    flip = random.random() < 0.5
    return (second, first) if flip else (first, second)
def get_info() -> dict:
    """Return pool size, comparison count and per-galaxy ratings.

    The ratings in ``elo_values`` follow pool order; members without a stored
    rating are reported at ``DEFAULT_ELO``. With no state loaded, all three
    fields are empty/zero.
    """
    with _lock:
        state = _state
        if state is None:
            return {"pool_size": 0, "total_comparisons": 0, "elo_values": []}
        ratings = state.elo_ratings
        elo_values = [ratings.get(member, DEFAULT_ELO) for member in state.pool]
        return {
            "pool_size": len(state.pool),
            "total_comparisons": state.total_comparisons,
            "elo_values": elo_values,
        }
def get_leaderboard(limit: int = 20) -> list[dict]:
    """Return the top *limit* galaxies by ELO, highest first.

    Args:
        limit: Maximum number of entries to return (default 20). Values below
            zero are treated as zero.

    Returns:
        A list of ``{"id": row_index, "elo": rating}`` dicts. Pool members
        without a stored rating are listed at ``DEFAULT_ELO``. The list is
        empty when no state is loaded.
    """
    with _lock:
        if _state is None:
            return []
        entries = [
            {"id": idx, "elo": _state.elo_ratings.get(idx, DEFAULT_ELO)}
            for idx in _state.pool
        ]
    # Sorting works on the private copy, so it does not need the lock.
    entries.sort(key=lambda entry: entry["elo"], reverse=True)
    # Clamp so a negative limit cannot turn into "all but the last n".
    return entries[: max(limit, 0)]
def get_rating(galaxy_idx: int) -> float:
    """Return the stored rating of *galaxy_idx*.

    ``DEFAULT_ELO`` is returned when no state is loaded or the galaxy has no
    stored rating.
    """
    with _lock:
        state = _state
        if state is not None:
            return state.elo_ratings.get(galaxy_idx, DEFAULT_ELO)
        return DEFAULT_ELO