Smith42 committed
Commit 0a4af2e · 1 Parent(s): e789c3d

Add galaxy

Files changed (1)
  1. src/galaxy_data_loader.py +37 -59
src/galaxy_data_loader.py CHANGED
@@ -20,11 +20,10 @@ from src.config import (
 
 logger = logging.getLogger(__name__)
 
-# Must be identical across both streaming passes so row order is reproducible.
-_SHUFFLE_BUFFER = 1_000
+_SHUFFLE_BUFFER = 200
 
 
-def _make_dataset(seed: int, pool_size: int, with_images: bool):
+def _make_dataset(seed: int, pool_size: int):
     """Return a shuffled, length-limited streaming dataset iterator."""
     from datasets import load_dataset
     from datasets import Image as HFImage
@@ -37,13 +36,8 @@ def _make_dataset(seed: int, pool_size: int, with_images: bool):
         token=HF_TOKEN if HF_TOKEN else None,
     )
     features = getattr(ds, "features", None)
-    if with_images:
-        if features and IMAGE_COLUMN in features:
-            ds = ds.cast_column(IMAGE_COLUMN, HFImage(decode=False))
-    else:
-        # Column projection at Parquet level — image bytes never downloaded
-        if features and IMAGE_COLUMN in features:
-            ds = ds.select_columns([c for c in features if c != IMAGE_COLUMN])
+    if features and IMAGE_COLUMN in features:
+        ds = ds.cast_column(IMAGE_COLUMN, HFImage(decode=False))
 
     ds = ds.shuffle(seed=seed, buffer_size=_SHUFFLE_BUFFER)
     ds = ds.take(pool_size)
@@ -126,74 +120,58 @@ def sample_pool_streaming(
     seed: int | None = None,
     prefetch_images: int = 100,
 ) -> tuple[list[int], dict[int, dict], int]:
-    """Build the galaxy pool with lazy image loading.
+    """Build the galaxy pool, caching a small batch of images before returning.
 
-    Pass 1 (fast): streams metadata only — no image bytes downloaded.
-    Pass 2a (sync): caches the first `prefetch_images` images before returning,
-        so the app can serve immediately.
-    Pass 2b (async): background thread fills the rest of the image cache.
-
-    Both passes use the same seed and shuffle buffer so row i in pass 1
-    is the same galaxy as row i in pass 2.
+    Single streaming pass with cast_column(decode=False) to avoid Pillow.
+    The shuffle buffer is small (200 rows) so only ~300 images are downloaded
+    before the app starts serving. The rest are cached in a background thread.
 
     Returns:
         ids: sequential ints 0..N-1
         metadata_map: {id -> row_dict (no image column)}
-        seed: seed used (fixed for reproducibility)
+        seed: seed used
     """
     if seed is None:
         seed = random.randint(0, 2**32 - 1)
 
-    # ------------------------------------------------------------------
-    # Pass 1: metadata only — fast, no image bytes
-    # ------------------------------------------------------------------
-    logger.info("Streaming metadata for %d galaxies (seed=%d)...", pool_size, seed)
-    ids: list[int] = []
-    metadata_map: dict[int, dict] = {}
-
-    for i, row in enumerate(_make_dataset(seed, pool_size, with_images=False)):
-        metadata_map[i] = row
-        ids.append(i)
+    logger.info("Streaming %d galaxies (seed=%d)...", pool_size, seed)
 
-    logger.info("Metadata ready: %d galaxies", len(ids))
+    # Pool IDs are just 0..pool_size-1 — known upfront
+    ids = list(range(pool_size))
+    metadata_map: dict[int, dict] = {}
 
-    # ------------------------------------------------------------------
-    # Pass 2: images — same seed/buffer so row order matches pass 1
-    # ------------------------------------------------------------------
+    it = _make_dataset(seed, pool_size)
     sync_count = min(prefetch_images, pool_size)
-    logger.info("Pre-caching first %d images...", sync_count)
-
-    img_iter = _make_dataset(seed, pool_size, with_images=True)
 
-    def _cache_row(i: int, row: dict):
-        img_col = row.get(IMAGE_COLUMN)
-        if isinstance(img_col, dict):
-            img_bytes = img_col.get("bytes")
-            if img_bytes:
-                image_cache.put(i, img_bytes)
-                return
-        logger.warning("No image bytes for row %d", i)
-
-    # Synchronous: first sync_count images
+    # Synchronous: first sync_count rows — populate metadata + cache images
     for i in range(sync_count):
-        _cache_row(i, next(img_iter))
-
-    logger.info("Initial %d images cached — app ready", sync_count)
-
-    # Asynchronous: remainder in background
-    remaining = pool_size - sync_count
-    if remaining > 0:
-        def _bg_cache():
-            logger.info("Background: caching %d remaining images...", remaining)
+        row = next(it)
+        img_col = row.get(IMAGE_COLUMN)
+        if isinstance(img_col, dict) and img_col.get("bytes"):
+            image_cache.put(i, img_col["bytes"])
+        else:
+            logger.warning("No image bytes for row %d", i)
+        metadata_map[i] = {k: v for k, v in row.items() if k != IMAGE_COLUMN}
+
+    logger.info("%d images cached — app ready, filling remaining %d in background",
+                sync_count, pool_size - sync_count)
+
+    # Asynchronous: rest of the pool
+    if sync_count < pool_size:
+        def _bg():
             for i in range(sync_count, pool_size):
                 try:
-                    _cache_row(i, next(img_iter))
+                    row = next(it)
+                    img_col = row.get(IMAGE_COLUMN)
+                    if isinstance(img_col, dict) and img_col.get("bytes"):
+                        image_cache.put(i, img_col["bytes"])
+                    metadata_map[i] = {k: v for k, v in row.items() if k != IMAGE_COLUMN}
                 except StopIteration:
                     break
                 except Exception as e:
-                    logger.warning("Background cache error at row %d: %s", i, e)
-            logger.info("Background image caching complete")
+                    logger.warning("Background error at row %d: %s", i, e)
+            logger.info("Background streaming complete")
 
-        threading.Thread(target=_bg_cache, daemon=True).start()
+        threading.Thread(target=_bg, daemon=True).start()
 
     return ids, metadata_map, seed
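
Context for reviewers: the new single-pass approach relies on `datasets.Image(decode=False)`, which makes each streamed row carry the raw image payload as a `{"bytes": ..., "path": ...}` dict instead of a decoded `PIL.Image`. Below is a minimal standalone sketch of that behavior; the repo id and column name are placeholders, since the real values (`IMAGE_COLUMN`, `HF_TOKEN`) come from `src.config` and are not shown in this diff.

```python
# Minimal sketch (not project code): streaming rows with raw image bytes.
# "user/galaxy-dataset" and "image" are placeholder values for illustration only.
from datasets import Image as HFImage
from datasets import load_dataset

ds = load_dataset("user/galaxy-dataset", split="train", streaming=True)
ds = ds.cast_column("image", HFImage(decode=False))  # keep bytes, skip Pillow decode
ds = ds.shuffle(seed=42, buffer_size=200).take(3)    # mirrors _SHUFFLE_BUFFER / pool_size

for i, row in enumerate(ds):
    img = row["image"]  # {"bytes": b"...", "path": ...} rather than a PIL.Image
    meta = {k: v for k, v in row.items() if k != "image"}
    print(i, len(img["bytes"]) if img.get("bytes") else 0, sorted(meta))
```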
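
And a hedged sketch of the calling side, to make the return contract concrete. Only `image_cache.put(i, bytes)` appears in the diff; the `pool_size` keyword and the `image_cache.get` accessor below are assumptions used purely for illustration.

```python
# Hypothetical caller: illustrates the (ids, metadata_map, seed) contract above.
from src.galaxy_data_loader import sample_pool_streaming

ids, metadata_map, seed = sample_pool_streaming(pool_size=1_000, prefetch_images=100)

# The first `prefetch_images` rows have metadata and cached image bytes before
# this call returns; the rest are filled in by the daemon thread over time.
first_id = ids[0]
print(f"seed={seed}, first row metadata keys: {sorted(metadata_map[first_id])}")

# image_cache.get() is an assumed accessor (only .put() is shown in the diff):
# img_bytes = image_cache.get(first_id)
```

One trade-off visible in the diff: the old two-pass version populated the full `metadata_map` before returning, whereas the single-pass version only guarantees metadata for the first `prefetch_images` rows, so callers indexing arbitrary ids immediately may hit rows the background thread has not reached yet.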