Spaces:
Sleeping
Sleeping
Fix error
Browse files- app.py +26 -10
- dataset_config.yaml +1 -1
- src/elo.py +20 -2
- src/galaxy_data_loader.py +40 -4
app.py
CHANGED
|
@@ -45,16 +45,32 @@ def create_app() -> dash.Dash:
|
|
| 45 |
# Initialize tournament
|
| 46 |
logger.info("Loading tournament state...")
|
| 47 |
loaded = elo.load_tournament_state()
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
# Layout and callbacks
|
| 60 |
app.layout = create_layout()
|
|
|
|
| 45 |
# Initialize tournament
|
| 46 |
logger.info("Loading tournament state...")
|
| 47 |
loaded = elo.load_tournament_state()
|
| 48 |
+
|
| 49 |
+
# Always re-stream the pool to populate the image + metadata caches.
|
| 50 |
+
# On reload we reuse the saved seed so the same galaxies are sampled in the
|
| 51 |
+
# same order, keeping ELO rankings consistent across restarts.
|
| 52 |
+
seed = elo.get_pool_seed() if loaded else None
|
| 53 |
+
logger.info(
|
| 54 |
+
"Streaming pool of %d galaxies (seed=%s)...",
|
| 55 |
+
POOL_SIZE,
|
| 56 |
+
seed if seed is not None else "random",
|
| 57 |
+
)
|
| 58 |
+
try:
|
| 59 |
+
pool, metadata_map, used_seed = sample_pool_streaming(POOL_SIZE, seed=seed)
|
| 60 |
+
register_metadata(metadata_map)
|
| 61 |
+
if not loaded:
|
| 62 |
+
elo.initialize_tournament(pool, pool_seed=used_seed)
|
| 63 |
+
else:
|
| 64 |
+
# Persist seed into existing state so future reloads can reuse it
|
| 65 |
+
elo.set_pool_seed(used_seed)
|
| 66 |
+
logger.info(
|
| 67 |
+
"Tournament state restored: round %d, %d active galaxies",
|
| 68 |
+
elo.get_tournament_info().get("current_round", 1),
|
| 69 |
+
len(pool),
|
| 70 |
+
)
|
| 71 |
+
except Exception as e:
|
| 72 |
+
logger.error("Failed to stream galaxy pool: %s", e)
|
| 73 |
+
raise
|
| 74 |
|
| 75 |
# Layout and callbacks
|
| 76 |
app.layout = create_layout()
|
dataset_config.yaml
CHANGED
|
@@ -3,7 +3,7 @@ config: "default"
|
|
| 3 |
split: "train"
|
| 4 |
image_column: "image"
|
| 5 |
id_column: "id_str"
|
| 6 |
-
pool_size:
|
| 7 |
min_comparisons_per_round: 3
|
| 8 |
max_comparisons_per_round: 5
|
| 9 |
elimination_fraction: 0.5
|
|
|
|
| 3 |
split: "train"
|
| 4 |
image_column: "image"
|
| 5 |
id_column: "id_str"
|
| 6 |
+
pool_size: 1000
|
| 7 |
min_comparisons_per_round: 3
|
| 8 |
max_comparisons_per_round: 5
|
| 9 |
elimination_fraction: 0.5
|
src/elo.py
CHANGED
|
@@ -44,6 +44,7 @@ class TournamentState:
|
|
| 44 |
eliminated: list[int] | None = None,
|
| 45 |
total_comparisons: int = 0,
|
| 46 |
tournament_complete: bool = False,
|
|
|
|
| 47 |
):
|
| 48 |
self.active_pool = list(active_pool)
|
| 49 |
self.elo_ratings = elo_ratings or {idx: DEFAULT_ELO for idx in active_pool}
|
|
@@ -52,6 +53,7 @@ class TournamentState:
|
|
| 52 |
self.eliminated = eliminated or []
|
| 53 |
self.total_comparisons = total_comparisons
|
| 54 |
self.tournament_complete = tournament_complete
|
|
|
|
| 55 |
|
| 56 |
def to_dict(self) -> dict:
|
| 57 |
return {
|
|
@@ -62,6 +64,7 @@ class TournamentState:
|
|
| 62 |
"eliminated": self.eliminated,
|
| 63 |
"total_comparisons": self.total_comparisons,
|
| 64 |
"tournament_complete": self.tournament_complete,
|
|
|
|
| 65 |
}
|
| 66 |
|
| 67 |
@classmethod
|
|
@@ -74,6 +77,7 @@ class TournamentState:
|
|
| 74 |
eliminated=d.get("eliminated", []),
|
| 75 |
total_comparisons=d.get("total_comparisons", 0),
|
| 76 |
tournament_complete=d.get("tournament_complete", False),
|
|
|
|
| 77 |
)
|
| 78 |
|
| 79 |
|
|
@@ -94,11 +98,11 @@ def _init_scheduler():
|
|
| 94 |
logger.info("ELO state scheduler initialized (repo=%s)", HF_LOG_REPO_ID)
|
| 95 |
|
| 96 |
|
| 97 |
-
def initialize_tournament(pool_indices: list[int]):
|
| 98 |
"""Create a fresh tournament with the given pool."""
|
| 99 |
global _state
|
| 100 |
with _lock:
|
| 101 |
-
_state = TournamentState(active_pool=pool_indices)
|
| 102 |
_save_state()
|
| 103 |
_init_scheduler()
|
| 104 |
logger.info("Tournament initialized with %d galaxies", len(pool_indices))
|
|
@@ -333,6 +337,20 @@ def select_pair(seen_pairs: set[tuple[int, int]]) -> tuple[int, int] | None:
|
|
| 333 |
return (pair[0], pair[1])
|
| 334 |
|
| 335 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
def get_tournament_info() -> dict:
|
| 337 |
"""Return a snapshot of tournament state for the progress dashboard."""
|
| 338 |
with _lock:
|
|
|
|
| 44 |
eliminated: list[int] | None = None,
|
| 45 |
total_comparisons: int = 0,
|
| 46 |
tournament_complete: bool = False,
|
| 47 |
+
pool_seed: int | None = None,
|
| 48 |
):
|
| 49 |
self.active_pool = list(active_pool)
|
| 50 |
self.elo_ratings = elo_ratings or {idx: DEFAULT_ELO for idx in active_pool}
|
|
|
|
| 53 |
self.eliminated = eliminated or []
|
| 54 |
self.total_comparisons = total_comparisons
|
| 55 |
self.tournament_complete = tournament_complete
|
| 56 |
+
self.pool_seed = pool_seed
|
| 57 |
|
| 58 |
def to_dict(self) -> dict:
|
| 59 |
return {
|
|
|
|
| 64 |
"eliminated": self.eliminated,
|
| 65 |
"total_comparisons": self.total_comparisons,
|
| 66 |
"tournament_complete": self.tournament_complete,
|
| 67 |
+
"pool_seed": self.pool_seed,
|
| 68 |
}
|
| 69 |
|
| 70 |
@classmethod
|
|
|
|
| 77 |
eliminated=d.get("eliminated", []),
|
| 78 |
total_comparisons=d.get("total_comparisons", 0),
|
| 79 |
tournament_complete=d.get("tournament_complete", False),
|
| 80 |
+
pool_seed=d.get("pool_seed"),
|
| 81 |
)
|
| 82 |
|
| 83 |
|
|
|
|
| 98 |
logger.info("ELO state scheduler initialized (repo=%s)", HF_LOG_REPO_ID)
|
| 99 |
|
| 100 |
|
| 101 |
+
def initialize_tournament(pool_indices: list[int], pool_seed: int | None = None):
|
| 102 |
"""Create a fresh tournament with the given pool."""
|
| 103 |
global _state
|
| 104 |
with _lock:
|
| 105 |
+
_state = TournamentState(active_pool=pool_indices, pool_seed=pool_seed)
|
| 106 |
_save_state()
|
| 107 |
_init_scheduler()
|
| 108 |
logger.info("Tournament initialized with %d galaxies", len(pool_indices))
|
|
|
|
| 337 |
return (pair[0], pair[1])
|
| 338 |
|
| 339 |
|
| 340 |
+
def get_pool_seed() -> int | None:
|
| 341 |
+
"""Return the shuffle seed used when the current pool was sampled."""
|
| 342 |
+
with _lock:
|
| 343 |
+
return _state.pool_seed if _state else None
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
def set_pool_seed(seed: int):
|
| 347 |
+
"""Store the pool seed into the current tournament state and save."""
|
| 348 |
+
with _lock:
|
| 349 |
+
if _state is not None:
|
| 350 |
+
_state.pool_seed = seed
|
| 351 |
+
_save_state()
|
| 352 |
+
|
| 353 |
+
|
| 354 |
def get_tournament_info() -> dict:
|
| 355 |
"""Return a snapshot of tournament state for the progress dashboard."""
|
| 356 |
with _lock:
|
src/galaxy_data_loader.py
CHANGED
|
@@ -60,6 +60,29 @@ def sample_pool_indices(total: int, pool_size: int) -> list[int]:
|
|
| 60 |
# Row / image fetching
|
| 61 |
# ---------------------------------------------------------------------------
|
| 62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
def fetch_rows(offsets: list[int]) -> dict[int, dict]:
|
| 64 |
"""Fetch rows by offset via the HF dataset-viewer /rows endpoint.
|
| 65 |
|
|
@@ -206,20 +229,33 @@ image_cache = ImageCache()
|
|
| 206 |
# Streaming pool sampler
|
| 207 |
# ---------------------------------------------------------------------------
|
| 208 |
|
| 209 |
-
def sample_pool_streaming(
|
|
|
|
|
|
|
| 210 |
"""Stream pool_size shuffled galaxies from HF Datasets, pre-caching images.
|
| 211 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
Returns:
|
| 213 |
ids: sequential ints 0..N-1 used as galaxy IDs throughout the app
|
| 214 |
metadata_map: {id -> row_dict (without image column)} for display names
|
|
|
|
| 215 |
"""
|
| 216 |
from datasets import load_dataset
|
| 217 |
from datasets import Image as HFImage
|
| 218 |
|
|
|
|
|
|
|
|
|
|
| 219 |
logger.info(
|
| 220 |
-
"Streaming %d galaxies from %s (shuffle
|
| 221 |
pool_size,
|
| 222 |
DATASET_ID,
|
|
|
|
| 223 |
)
|
| 224 |
|
| 225 |
ds = load_dataset(
|
|
@@ -235,7 +271,7 @@ def sample_pool_streaming(pool_size: int) -> tuple[list[int], dict[int, dict]]:
|
|
| 235 |
if features and IMAGE_COLUMN in features:
|
| 236 |
ds = ds.cast_column(IMAGE_COLUMN, HFImage(decode=False))
|
| 237 |
|
| 238 |
-
ds = ds.shuffle(seed=
|
| 239 |
ds = ds.take(pool_size)
|
| 240 |
|
| 241 |
ids: list[int] = []
|
|
@@ -259,4 +295,4 @@ def sample_pool_streaming(pool_size: int) -> tuple[list[int], dict[int, dict]]:
|
|
| 259 |
logger.info("Streamed %d/%d galaxies", i + 1, pool_size)
|
| 260 |
|
| 261 |
logger.info("Finished streaming %d galaxies", len(ids))
|
| 262 |
-
return ids, metadata_map
|
|
|
|
| 60 |
# Row / image fetching
|
| 61 |
# ---------------------------------------------------------------------------
|
| 62 |
|
| 63 |
+
def fetch_image_bytes(row_index: int) -> bytes | None:
|
| 64 |
+
"""Fetch raw image bytes for a single row via the dataset-viewer API.
|
| 65 |
+
|
| 66 |
+
Uses fetch_rows to get the signed image URL, then downloads the image.
|
| 67 |
+
Returns None on any failure.
|
| 68 |
+
"""
|
| 69 |
+
rows = fetch_rows([row_index])
|
| 70 |
+
row = rows.get(row_index)
|
| 71 |
+
if row is None:
|
| 72 |
+
return None
|
| 73 |
+
img_url = _extract_image_url(row)
|
| 74 |
+
if not img_url:
|
| 75 |
+
logger.warning("No image URL in row %d", row_index)
|
| 76 |
+
return None
|
| 77 |
+
try:
|
| 78 |
+
resp = requests.get(img_url, headers=_hf_headers(), timeout=30)
|
| 79 |
+
resp.raise_for_status()
|
| 80 |
+
return resp.content
|
| 81 |
+
except Exception as e:
|
| 82 |
+
logger.warning("Failed to download image for row %d: %s", row_index, e)
|
| 83 |
+
return None
|
| 84 |
+
|
| 85 |
+
|
| 86 |
def fetch_rows(offsets: list[int]) -> dict[int, dict]:
|
| 87 |
"""Fetch rows by offset via the HF dataset-viewer /rows endpoint.
|
| 88 |
|
|
|
|
| 229 |
# Streaming pool sampler
|
| 230 |
# ---------------------------------------------------------------------------
|
| 231 |
|
| 232 |
+
def sample_pool_streaming(
|
| 233 |
+
pool_size: int, seed: int | None = None
|
| 234 |
+
) -> tuple[list[int], dict[int, dict], int]:
|
| 235 |
"""Stream pool_size shuffled galaxies from HF Datasets, pre-caching images.
|
| 236 |
|
| 237 |
+
Args:
|
| 238 |
+
pool_size: Number of galaxies to include in the pool.
|
| 239 |
+
seed: Shuffle seed. If None, a random seed is generated. Pass the same
|
| 240 |
+
seed on subsequent startups to reproduce the exact same pool order
|
| 241 |
+
so that saved ELO state remains valid across restarts.
|
| 242 |
+
|
| 243 |
Returns:
|
| 244 |
ids: sequential ints 0..N-1 used as galaxy IDs throughout the app
|
| 245 |
metadata_map: {id -> row_dict (without image column)} for display names
|
| 246 |
+
seed: the seed that was used (store in tournament state for reuse)
|
| 247 |
"""
|
| 248 |
from datasets import load_dataset
|
| 249 |
from datasets import Image as HFImage
|
| 250 |
|
| 251 |
+
if seed is None:
|
| 252 |
+
seed = random.randint(0, 2**32 - 1)
|
| 253 |
+
|
| 254 |
logger.info(
|
| 255 |
+
"Streaming %d galaxies from %s (shuffle seed=%d)...",
|
| 256 |
pool_size,
|
| 257 |
DATASET_ID,
|
| 258 |
+
seed,
|
| 259 |
)
|
| 260 |
|
| 261 |
ds = load_dataset(
|
|
|
|
| 271 |
if features and IMAGE_COLUMN in features:
|
| 272 |
ds = ds.cast_column(IMAGE_COLUMN, HFImage(decode=False))
|
| 273 |
|
| 274 |
+
ds = ds.shuffle(seed=seed, buffer_size=10_000)
|
| 275 |
ds = ds.take(pool_size)
|
| 276 |
|
| 277 |
ids: list[int] = []
|
|
|
|
| 295 |
logger.info("Streamed %d/%d galaxies", i + 1, pool_size)
|
| 296 |
|
| 297 |
logger.info("Finished streaming %d galaxies", len(ids))
|
| 298 |
+
return ids, metadata_map, seed
|