Smith42 committed on
Commit
3d2b027
·
1 Parent(s): df55019

make sure we have ids

Browse files
Files changed (1) hide show
  1. src/galaxy_data_loader.py +33 -22
src/galaxy_data_loader.py CHANGED
@@ -23,8 +23,13 @@ logger = logging.getLogger(__name__)
23
  _SHUFFLE_BUFFER = 200
24
 
25
 
26
- def _make_dataset(seed: int, pool_size: int):
27
- """Return a shuffled, length-limited streaming dataset iterator."""
 
 
 
 
 
28
  from datasets import load_dataset
29
  from datasets import Image as HFImage
30
 
@@ -36,8 +41,12 @@ def _make_dataset(seed: int, pool_size: int):
36
  token=HF_TOKEN if HF_TOKEN else None,
37
  )
38
  features = getattr(ds, "features", None)
39
- if features and IMAGE_COLUMN in features:
40
- ds = ds.cast_column(IMAGE_COLUMN, HFImage(decode=False))
 
 
 
 
41
 
42
  ds = ds.shuffle(seed=seed, buffer_size=_SHUFFLE_BUFFER)
43
  ds = ds.take(pool_size)
@@ -134,56 +143,58 @@ def sample_pool_streaming(
134
  if seed is None:
135
  seed = random.randint(0, 2**32 - 1)
136
 
137
- logger.info("Streaming %d galaxies (seed=%d)...", pool_size, seed)
138
 
139
- # Pool IDs are just 0..pool_size-1 — known upfront
140
- ids = list(range(pool_size))
141
  metadata_map: dict[int, dict] = {}
 
 
 
142
 
143
- it = _make_dataset(seed, pool_size)
144
- sync_count = min(prefetch_images, pool_size)
145
 
 
146
  def _extract_bytes(img_col) -> bytes | None:
147
- """Extract raw JPEG bytes from either a bytes-dict or a PIL Image."""
148
  if isinstance(img_col, dict):
149
  return img_col.get("bytes")
150
  if img_col is not None:
151
- # cast_column didn't take effect — img_col is a PIL Image
152
  try:
153
  import io
154
  buf = io.BytesIO()
155
  img_col.save(buf, format="JPEG")
156
  return buf.getvalue()
157
  except Exception as e:
158
- logger.warning("Failed to convert PIL image to bytes: %s", e)
159
  return None
160
 
161
- def _process_row(i: int, row: dict):
 
 
 
 
162
  img_bytes = _extract_bytes(row.get(IMAGE_COLUMN))
163
  if img_bytes:
164
  image_cache.put(i, img_bytes)
165
  else:
166
  logger.warning("No image bytes for row %d", i)
167
- metadata_map[i] = {k: v for k, v in row.items() if k != IMAGE_COLUMN}
168
-
169
- # Synchronous: first sync_count rows — populate metadata + cache images
170
- for i in range(sync_count):
171
- _process_row(i, next(it))
172
 
173
- logger.info("%d images cached — app ready, filling remaining %d in background",
174
  sync_count, pool_size - sync_count)
175
 
176
- # Asynchronous: rest of the pool
177
  if sync_count < pool_size:
178
  def _bg():
179
  for i in range(sync_count, pool_size):
180
  try:
181
- _process_row(i, next(it))
 
 
 
182
  except StopIteration:
183
  break
184
  except Exception as e:
185
  logger.warning("Background error at row %d: %s", i, e)
186
- logger.info("Background streaming complete")
187
 
188
  threading.Thread(target=_bg, daemon=True).start()
189
 
 
23
  _SHUFFLE_BUFFER = 200
24
 
25
 
26
+ def _make_dataset(seed: int, pool_size: int, with_images: bool = True):
27
+ """Return a shuffled, length-limited streaming dataset iterator.
28
+
29
+ with_images=False uses select_columns to skip the image column entirely
30
+ at the Parquet level — no image bytes downloaded.
31
+ Both modes use the same seed+buffer so row i is always the same galaxy.
32
+ """
33
  from datasets import load_dataset
34
  from datasets import Image as HFImage
35
 
 
41
  token=HF_TOKEN if HF_TOKEN else None,
42
  )
43
  features = getattr(ds, "features", None)
44
+ if with_images:
45
+ if features and IMAGE_COLUMN in features:
46
+ ds = ds.cast_column(IMAGE_COLUMN, HFImage(decode=False))
47
+ else:
48
+ if features and IMAGE_COLUMN in features:
49
+ ds = ds.select_columns([c for c in features if c != IMAGE_COLUMN])
50
 
51
  ds = ds.shuffle(seed=seed, buffer_size=_SHUFFLE_BUFFER)
52
  ds = ds.take(pool_size)
 
143
  if seed is None:
144
  seed = random.randint(0, 2**32 - 1)
145
 
146
+ logger.info("Streaming metadata for %d galaxies (seed=%d)...", pool_size, seed)
147
 
148
+ # Pass 1: metadata only — fast, no images downloaded
149
+ ids: list[int] = []
150
  metadata_map: dict[int, dict] = {}
151
+ for i, row in enumerate(_make_dataset(seed, pool_size, with_images=False)):
152
+ metadata_map[i] = row
153
+ ids.append(i)
154
 
155
+ logger.info("All %d galaxy IDs ready", len(ids))
 
156
 
157
+ # Pass 2: images — same seed+buffer → same row order as pass 1
158
  def _extract_bytes(img_col) -> bytes | None:
 
159
  if isinstance(img_col, dict):
160
  return img_col.get("bytes")
161
  if img_col is not None:
 
162
  try:
163
  import io
164
  buf = io.BytesIO()
165
  img_col.save(buf, format="JPEG")
166
  return buf.getvalue()
167
  except Exception as e:
168
+ logger.warning("PIL conversion failed: %s", e)
169
  return None
170
 
171
+ img_it = _make_dataset(seed, pool_size, with_images=True)
172
+ sync_count = min(prefetch_images, pool_size)
173
+
174
+ for i in range(sync_count):
175
+ row = next(img_it)
176
  img_bytes = _extract_bytes(row.get(IMAGE_COLUMN))
177
  if img_bytes:
178
  image_cache.put(i, img_bytes)
179
  else:
180
  logger.warning("No image bytes for row %d", i)
 
 
 
 
 
181
 
182
+ logger.info("%d images cached — app ready, %d remaining in background",
183
  sync_count, pool_size - sync_count)
184
 
 
185
  if sync_count < pool_size:
186
  def _bg():
187
  for i in range(sync_count, pool_size):
188
  try:
189
+ row = next(img_it)
190
+ img_bytes = _extract_bytes(row.get(IMAGE_COLUMN))
191
+ if img_bytes:
192
+ image_cache.put(i, img_bytes)
193
  except StopIteration:
194
  break
195
  except Exception as e:
196
  logger.warning("Background error at row %d: %s", i, e)
197
+ logger.info("Background image caching complete")
198
 
199
  threading.Thread(target=_bg, daemon=True).start()
200