Spaces:
Sleeping
Sleeping
make sure we have ids
Browse files- src/galaxy_data_loader.py +33 -22
src/galaxy_data_loader.py
CHANGED
|
@@ -23,8 +23,13 @@ logger = logging.getLogger(__name__)
|
|
| 23 |
_SHUFFLE_BUFFER = 200
|
| 24 |
|
| 25 |
|
| 26 |
-
def _make_dataset(seed: int, pool_size: int):
|
| 27 |
-
"""Return a shuffled, length-limited streaming dataset iterator.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
from datasets import load_dataset
|
| 29 |
from datasets import Image as HFImage
|
| 30 |
|
|
@@ -36,8 +41,12 @@ def _make_dataset(seed: int, pool_size: int):
|
|
| 36 |
token=HF_TOKEN if HF_TOKEN else None,
|
| 37 |
)
|
| 38 |
features = getattr(ds, "features", None)
|
| 39 |
-
if
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
ds = ds.shuffle(seed=seed, buffer_size=_SHUFFLE_BUFFER)
|
| 43 |
ds = ds.take(pool_size)
|
|
@@ -134,56 +143,58 @@ def sample_pool_streaming(
|
|
| 134 |
if seed is None:
|
| 135 |
seed = random.randint(0, 2**32 - 1)
|
| 136 |
|
| 137 |
-
logger.info("Streaming %d galaxies (seed=%d)...", pool_size, seed)
|
| 138 |
|
| 139 |
-
#
|
| 140 |
-
ids =
|
| 141 |
metadata_map: dict[int, dict] = {}
|
|
|
|
|
|
|
|
|
|
| 142 |
|
| 143 |
-
|
| 144 |
-
sync_count = min(prefetch_images, pool_size)
|
| 145 |
|
|
|
|
| 146 |
def _extract_bytes(img_col) -> bytes | None:
|
| 147 |
-
"""Extract raw JPEG bytes from either a bytes-dict or a PIL Image."""
|
| 148 |
if isinstance(img_col, dict):
|
| 149 |
return img_col.get("bytes")
|
| 150 |
if img_col is not None:
|
| 151 |
-
# cast_column didn't take effect — img_col is a PIL Image
|
| 152 |
try:
|
| 153 |
import io
|
| 154 |
buf = io.BytesIO()
|
| 155 |
img_col.save(buf, format="JPEG")
|
| 156 |
return buf.getvalue()
|
| 157 |
except Exception as e:
|
| 158 |
-
logger.warning("
|
| 159 |
return None
|
| 160 |
|
| 161 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
img_bytes = _extract_bytes(row.get(IMAGE_COLUMN))
|
| 163 |
if img_bytes:
|
| 164 |
image_cache.put(i, img_bytes)
|
| 165 |
else:
|
| 166 |
logger.warning("No image bytes for row %d", i)
|
| 167 |
-
metadata_map[i] = {k: v for k, v in row.items() if k != IMAGE_COLUMN}
|
| 168 |
-
|
| 169 |
-
# Synchronous: first sync_count rows — populate metadata + cache images
|
| 170 |
-
for i in range(sync_count):
|
| 171 |
-
_process_row(i, next(it))
|
| 172 |
|
| 173 |
-
logger.info("%d images cached β app ready,
|
| 174 |
sync_count, pool_size - sync_count)
|
| 175 |
|
| 176 |
-
# Asynchronous: rest of the pool
|
| 177 |
if sync_count < pool_size:
|
| 178 |
def _bg():
|
| 179 |
for i in range(sync_count, pool_size):
|
| 180 |
try:
|
| 181 |
-
|
|
|
|
|
|
|
|
|
|
| 182 |
except StopIteration:
|
| 183 |
break
|
| 184 |
except Exception as e:
|
| 185 |
logger.warning("Background error at row %d: %s", i, e)
|
| 186 |
-
logger.info("Background
|
| 187 |
|
| 188 |
threading.Thread(target=_bg, daemon=True).start()
|
| 189 |
|
|
|
|
| 23 |
_SHUFFLE_BUFFER = 200
|
| 24 |
|
| 25 |
|
| 26 |
+
def _make_dataset(seed: int, pool_size: int, with_images: bool = True):
|
| 27 |
+
"""Return a shuffled, length-limited streaming dataset iterator.
|
| 28 |
+
|
| 29 |
+
with_images=False uses select_columns to skip the image column entirely
|
| 30 |
+
at the Parquet level — no image bytes downloaded.
|
| 31 |
+
Both modes use the same seed+buffer so row i is always the same galaxy.
|
| 32 |
+
"""
|
| 33 |
from datasets import load_dataset
|
| 34 |
from datasets import Image as HFImage
|
| 35 |
|
|
|
|
| 41 |
token=HF_TOKEN if HF_TOKEN else None,
|
| 42 |
)
|
| 43 |
features = getattr(ds, "features", None)
|
| 44 |
+
if with_images:
|
| 45 |
+
if features and IMAGE_COLUMN in features:
|
| 46 |
+
ds = ds.cast_column(IMAGE_COLUMN, HFImage(decode=False))
|
| 47 |
+
else:
|
| 48 |
+
if features and IMAGE_COLUMN in features:
|
| 49 |
+
ds = ds.select_columns([c for c in features if c != IMAGE_COLUMN])
|
| 50 |
|
| 51 |
ds = ds.shuffle(seed=seed, buffer_size=_SHUFFLE_BUFFER)
|
| 52 |
ds = ds.take(pool_size)
|
|
|
|
| 143 |
if seed is None:
|
| 144 |
seed = random.randint(0, 2**32 - 1)
|
| 145 |
|
| 146 |
+
logger.info("Streaming metadata for %d galaxies (seed=%d)...", pool_size, seed)
|
| 147 |
|
| 148 |
+
# Pass 1: metadata only — fast, no images downloaded
|
| 149 |
+
ids: list[int] = []
|
| 150 |
metadata_map: dict[int, dict] = {}
|
| 151 |
+
for i, row in enumerate(_make_dataset(seed, pool_size, with_images=False)):
|
| 152 |
+
metadata_map[i] = row
|
| 153 |
+
ids.append(i)
|
| 154 |
|
| 155 |
+
logger.info("All %d galaxy IDs ready", len(ids))
|
|
|
|
| 156 |
|
| 157 |
+
# Pass 2: images — same seed+buffer → same row order as pass 1
|
| 158 |
def _extract_bytes(img_col) -> bytes | None:
|
|
|
|
| 159 |
if isinstance(img_col, dict):
|
| 160 |
return img_col.get("bytes")
|
| 161 |
if img_col is not None:
|
|
|
|
| 162 |
try:
|
| 163 |
import io
|
| 164 |
buf = io.BytesIO()
|
| 165 |
img_col.save(buf, format="JPEG")
|
| 166 |
return buf.getvalue()
|
| 167 |
except Exception as e:
|
| 168 |
+
logger.warning("PIL conversion failed: %s", e)
|
| 169 |
return None
|
| 170 |
|
| 171 |
+
img_it = _make_dataset(seed, pool_size, with_images=True)
|
| 172 |
+
sync_count = min(prefetch_images, pool_size)
|
| 173 |
+
|
| 174 |
+
for i in range(sync_count):
|
| 175 |
+
row = next(img_it)
|
| 176 |
img_bytes = _extract_bytes(row.get(IMAGE_COLUMN))
|
| 177 |
if img_bytes:
|
| 178 |
image_cache.put(i, img_bytes)
|
| 179 |
else:
|
| 180 |
logger.warning("No image bytes for row %d", i)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
|
| 182 |
+
logger.info("%d images cached β app ready, %d remaining in background",
|
| 183 |
sync_count, pool_size - sync_count)
|
| 184 |
|
|
|
|
| 185 |
if sync_count < pool_size:
|
| 186 |
def _bg():
|
| 187 |
for i in range(sync_count, pool_size):
|
| 188 |
try:
|
| 189 |
+
row = next(img_it)
|
| 190 |
+
img_bytes = _extract_bytes(row.get(IMAGE_COLUMN))
|
| 191 |
+
if img_bytes:
|
| 192 |
+
image_cache.put(i, img_bytes)
|
| 193 |
except StopIteration:
|
| 194 |
break
|
| 195 |
except Exception as e:
|
| 196 |
logger.warning("Background error at row %d: %s", i, e)
|
| 197 |
+
logger.info("Background image caching complete")
|
| 198 |
|
| 199 |
threading.Thread(target=_bg, daemon=True).start()
|
| 200 |
|