Smith42 committed on
Commit
346f507
·
1 Parent(s): 7fefd3b

Fix error

Browse files
Files changed (4) hide show
  1. app.py +26 -10
  2. dataset_config.yaml +1 -1
  3. src/elo.py +20 -2
  4. src/galaxy_data_loader.py +40 -4
app.py CHANGED
@@ -45,16 +45,32 @@ def create_app() -> dash.Dash:
45
  # Initialize tournament
46
  logger.info("Loading tournament state...")
47
  loaded = elo.load_tournament_state()
48
- if not loaded:
49
- logger.info("No existing tournament found. Streaming new pool...")
50
- try:
51
- logger.info("Streaming pool of %d galaxies from HF dataset...", POOL_SIZE)
52
- pool, metadata_map = sample_pool_streaming(POOL_SIZE)
53
- register_metadata(metadata_map)
54
- elo.initialize_tournament(pool)
55
- except Exception as e:
56
- logger.error("Failed to initialize tournament: %s", e)
57
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  # Layout and callbacks
60
  app.layout = create_layout()
 
45
  # Initialize tournament
46
  logger.info("Loading tournament state...")
47
  loaded = elo.load_tournament_state()
48
+
49
+ # Always re-stream the pool to populate the image + metadata caches.
50
+ # On reload we reuse the saved seed so the same galaxies are sampled in the
51
+ # same order, keeping ELO rankings consistent across restarts.
52
+ seed = elo.get_pool_seed() if loaded else None
53
+ logger.info(
54
+ "Streaming pool of %d galaxies (seed=%s)...",
55
+ POOL_SIZE,
56
+ seed if seed is not None else "random",
57
+ )
58
+ try:
59
+ pool, metadata_map, used_seed = sample_pool_streaming(POOL_SIZE, seed=seed)
60
+ register_metadata(metadata_map)
61
+ if not loaded:
62
+ elo.initialize_tournament(pool, pool_seed=used_seed)
63
+ else:
64
+ # Persist seed into existing state so future reloads can reuse it
65
+ elo.set_pool_seed(used_seed)
66
+ logger.info(
67
+ "Tournament state restored: round %d, %d active galaxies",
68
+ elo.get_tournament_info().get("current_round", 1),
69
+ len(pool),
70
+ )
71
+ except Exception as e:
72
+ logger.error("Failed to stream galaxy pool: %s", e)
73
+ raise
74
 
75
  # Layout and callbacks
76
  app.layout = create_layout()
dataset_config.yaml CHANGED
@@ -3,7 +3,7 @@ config: "default"
3
  split: "train"
4
  image_column: "image"
5
  id_column: "id_str"
6
- pool_size: 3000
7
  min_comparisons_per_round: 3
8
  max_comparisons_per_round: 5
9
  elimination_fraction: 0.5
 
3
  split: "train"
4
  image_column: "image"
5
  id_column: "id_str"
6
+ pool_size: 1000
7
  min_comparisons_per_round: 3
8
  max_comparisons_per_round: 5
9
  elimination_fraction: 0.5
src/elo.py CHANGED
@@ -44,6 +44,7 @@ class TournamentState:
44
  eliminated: list[int] | None = None,
45
  total_comparisons: int = 0,
46
  tournament_complete: bool = False,
 
47
  ):
48
  self.active_pool = list(active_pool)
49
  self.elo_ratings = elo_ratings or {idx: DEFAULT_ELO for idx in active_pool}
@@ -52,6 +53,7 @@ class TournamentState:
52
  self.eliminated = eliminated or []
53
  self.total_comparisons = total_comparisons
54
  self.tournament_complete = tournament_complete
 
55
 
56
  def to_dict(self) -> dict:
57
  return {
@@ -62,6 +64,7 @@ class TournamentState:
62
  "eliminated": self.eliminated,
63
  "total_comparisons": self.total_comparisons,
64
  "tournament_complete": self.tournament_complete,
 
65
  }
66
 
67
  @classmethod
@@ -74,6 +77,7 @@ class TournamentState:
74
  eliminated=d.get("eliminated", []),
75
  total_comparisons=d.get("total_comparisons", 0),
76
  tournament_complete=d.get("tournament_complete", False),
 
77
  )
78
 
79
 
@@ -94,11 +98,11 @@ def _init_scheduler():
94
  logger.info("ELO state scheduler initialized (repo=%s)", HF_LOG_REPO_ID)
95
 
96
 
97
- def initialize_tournament(pool_indices: list[int]):
98
  """Create a fresh tournament with the given pool."""
99
  global _state
100
  with _lock:
101
- _state = TournamentState(active_pool=pool_indices)
102
  _save_state()
103
  _init_scheduler()
104
  logger.info("Tournament initialized with %d galaxies", len(pool_indices))
@@ -333,6 +337,20 @@ def select_pair(seen_pairs: set[tuple[int, int]]) -> tuple[int, int] | None:
333
  return (pair[0], pair[1])
334
 
335
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
  def get_tournament_info() -> dict:
337
  """Return a snapshot of tournament state for the progress dashboard."""
338
  with _lock:
 
44
  eliminated: list[int] | None = None,
45
  total_comparisons: int = 0,
46
  tournament_complete: bool = False,
47
+ pool_seed: int | None = None,
48
  ):
49
  self.active_pool = list(active_pool)
50
  self.elo_ratings = elo_ratings or {idx: DEFAULT_ELO for idx in active_pool}
 
53
  self.eliminated = eliminated or []
54
  self.total_comparisons = total_comparisons
55
  self.tournament_complete = tournament_complete
56
+ self.pool_seed = pool_seed
57
 
58
  def to_dict(self) -> dict:
59
  return {
 
64
  "eliminated": self.eliminated,
65
  "total_comparisons": self.total_comparisons,
66
  "tournament_complete": self.tournament_complete,
67
+ "pool_seed": self.pool_seed,
68
  }
69
 
70
  @classmethod
 
77
  eliminated=d.get("eliminated", []),
78
  total_comparisons=d.get("total_comparisons", 0),
79
  tournament_complete=d.get("tournament_complete", False),
80
+ pool_seed=d.get("pool_seed"),
81
  )
82
 
83
 
 
98
  logger.info("ELO state scheduler initialized (repo=%s)", HF_LOG_REPO_ID)
99
 
100
 
101
+ def initialize_tournament(pool_indices: list[int], pool_seed: int | None = None):
102
  """Create a fresh tournament with the given pool."""
103
  global _state
104
  with _lock:
105
+ _state = TournamentState(active_pool=pool_indices, pool_seed=pool_seed)
106
  _save_state()
107
  _init_scheduler()
108
  logger.info("Tournament initialized with %d galaxies", len(pool_indices))
 
337
  return (pair[0], pair[1])
338
 
339
 
340
+ def get_pool_seed() -> int | None:
341
+ """Return the shuffle seed used when the current pool was sampled."""
342
+ with _lock:
343
+ return _state.pool_seed if _state else None
344
+
345
+
346
+ def set_pool_seed(seed: int):
347
+ """Store the pool seed into the current tournament state and save."""
348
+ with _lock:
349
+ if _state is not None:
350
+ _state.pool_seed = seed
351
+ _save_state()
352
+
353
+
354
  def get_tournament_info() -> dict:
355
  """Return a snapshot of tournament state for the progress dashboard."""
356
  with _lock:
src/galaxy_data_loader.py CHANGED
@@ -60,6 +60,29 @@ def sample_pool_indices(total: int, pool_size: int) -> list[int]:
60
  # Row / image fetching
61
  # ---------------------------------------------------------------------------
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  def fetch_rows(offsets: list[int]) -> dict[int, dict]:
64
  """Fetch rows by offset via the HF dataset-viewer /rows endpoint.
65
 
@@ -206,20 +229,33 @@ image_cache = ImageCache()
206
  # Streaming pool sampler
207
  # ---------------------------------------------------------------------------
208
 
209
- def sample_pool_streaming(pool_size: int) -> tuple[list[int], dict[int, dict]]:
 
 
210
  """Stream pool_size shuffled galaxies from HF Datasets, pre-caching images.
211
 
 
 
 
 
 
 
212
  Returns:
213
  ids: sequential ints 0..N-1 used as galaxy IDs throughout the app
214
  metadata_map: {id -> row_dict (without image column)} for display names
 
215
  """
216
  from datasets import load_dataset
217
  from datasets import Image as HFImage
218
 
 
 
 
219
  logger.info(
220
- "Streaming %d galaxies from %s (shuffle buffer=10000)...",
221
  pool_size,
222
  DATASET_ID,
 
223
  )
224
 
225
  ds = load_dataset(
@@ -235,7 +271,7 @@ def sample_pool_streaming(pool_size: int) -> tuple[list[int], dict[int, dict]]:
235
  if features and IMAGE_COLUMN in features:
236
  ds = ds.cast_column(IMAGE_COLUMN, HFImage(decode=False))
237
 
238
- ds = ds.shuffle(seed=random.randint(0, 2**32 - 1), buffer_size=10_000)
239
  ds = ds.take(pool_size)
240
 
241
  ids: list[int] = []
@@ -259,4 +295,4 @@ def sample_pool_streaming(pool_size: int) -> tuple[list[int], dict[int, dict]]:
259
  logger.info("Streamed %d/%d galaxies", i + 1, pool_size)
260
 
261
  logger.info("Finished streaming %d galaxies", len(ids))
262
- return ids, metadata_map
 
60
  # Row / image fetching
61
  # ---------------------------------------------------------------------------
62
 
63
+ def fetch_image_bytes(row_index: int) -> bytes | None:
64
+ """Fetch raw image bytes for a single row via the dataset-viewer API.
65
+
66
+ Uses fetch_rows to get the signed image URL, then downloads the image.
67
+ Returns None on any failure.
68
+ """
69
+ rows = fetch_rows([row_index])
70
+ row = rows.get(row_index)
71
+ if row is None:
72
+ return None
73
+ img_url = _extract_image_url(row)
74
+ if not img_url:
75
+ logger.warning("No image URL in row %d", row_index)
76
+ return None
77
+ try:
78
+ resp = requests.get(img_url, headers=_hf_headers(), timeout=30)
79
+ resp.raise_for_status()
80
+ return resp.content
81
+ except Exception as e:
82
+ logger.warning("Failed to download image for row %d: %s", row_index, e)
83
+ return None
84
+
85
+
86
  def fetch_rows(offsets: list[int]) -> dict[int, dict]:
87
  """Fetch rows by offset via the HF dataset-viewer /rows endpoint.
88
 
 
229
  # Streaming pool sampler
230
  # ---------------------------------------------------------------------------
231
 
232
+ def sample_pool_streaming(
233
+ pool_size: int, seed: int | None = None
234
+ ) -> tuple[list[int], dict[int, dict], int]:
235
  """Stream pool_size shuffled galaxies from HF Datasets, pre-caching images.
236
 
237
+ Args:
238
+ pool_size: Number of galaxies to include in the pool.
239
+ seed: Shuffle seed. If None, a random seed is generated. Pass the same
240
+ seed on subsequent startups to reproduce the exact same pool order
241
+ so that saved ELO state remains valid across restarts.
242
+
243
  Returns:
244
  ids: sequential ints 0..N-1 used as galaxy IDs throughout the app
245
  metadata_map: {id -> row_dict (without image column)} for display names
246
+ seed: the seed that was used (store in tournament state for reuse)
247
  """
248
  from datasets import load_dataset
249
  from datasets import Image as HFImage
250
 
251
+ if seed is None:
252
+ seed = random.randint(0, 2**32 - 1)
253
+
254
  logger.info(
255
+ "Streaming %d galaxies from %s (shuffle seed=%d)...",
256
  pool_size,
257
  DATASET_ID,
258
+ seed,
259
  )
260
 
261
  ds = load_dataset(
 
271
  if features and IMAGE_COLUMN in features:
272
  ds = ds.cast_column(IMAGE_COLUMN, HFImage(decode=False))
273
 
274
+ ds = ds.shuffle(seed=seed, buffer_size=10_000)
275
  ds = ds.take(pool_size)
276
 
277
  ids: list[int] = []
 
295
  logger.info("Streamed %d/%d galaxies", i + 1, pool_size)
296
 
297
  logger.info("Finished streaming %d galaxies", len(ids))
298
+ return ids, metadata_map, seed