Spaces:
Running
Running
Add galaxy
Browse files- src/galaxy_data_loader.py +37 -59
src/galaxy_data_loader.py
CHANGED
|
@@ -20,11 +20,10 @@ from src.config import (
|
|
| 20 |
|
| 21 |
logger = logging.getLogger(__name__)
|
| 22 |
|
| 23 |
-
|
| 24 |
-
_SHUFFLE_BUFFER = 1_000
|
| 25 |
|
| 26 |
|
| 27 |
-
def _make_dataset(seed: int, pool_size: int
|
| 28 |
"""Return a shuffled, length-limited streaming dataset iterator."""
|
| 29 |
from datasets import load_dataset
|
| 30 |
from datasets import Image as HFImage
|
|
@@ -37,13 +36,8 @@ def _make_dataset(seed: int, pool_size: int, with_images: bool):
|
|
| 37 |
token=HF_TOKEN if HF_TOKEN else None,
|
| 38 |
)
|
| 39 |
features = getattr(ds, "features", None)
|
| 40 |
-
if
|
| 41 |
-
|
| 42 |
-
ds = ds.cast_column(IMAGE_COLUMN, HFImage(decode=False))
|
| 43 |
-
else:
|
| 44 |
-
# Column projection at Parquet level β image bytes never downloaded
|
| 45 |
-
if features and IMAGE_COLUMN in features:
|
| 46 |
-
ds = ds.select_columns([c for c in features if c != IMAGE_COLUMN])
|
| 47 |
|
| 48 |
ds = ds.shuffle(seed=seed, buffer_size=_SHUFFLE_BUFFER)
|
| 49 |
ds = ds.take(pool_size)
|
|
@@ -126,74 +120,58 @@ def sample_pool_streaming(
|
|
| 126 |
seed: int | None = None,
|
| 127 |
prefetch_images: int = 100,
|
| 128 |
) -> tuple[list[int], dict[int, dict], int]:
|
| 129 |
-
"""Build the galaxy pool
|
| 130 |
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
Pass 2b (async): background thread fills the rest of the image cache.
|
| 135 |
-
|
| 136 |
-
Both passes use the same seed and shuffle buffer so row i in pass 1
|
| 137 |
-
is the same galaxy as row i in pass 2.
|
| 138 |
|
| 139 |
Returns:
|
| 140 |
ids: sequential ints 0..N-1
|
| 141 |
metadata_map: {id -> row_dict (no image column)}
|
| 142 |
-
seed: seed used
|
| 143 |
"""
|
| 144 |
if seed is None:
|
| 145 |
seed = random.randint(0, 2**32 - 1)
|
| 146 |
|
| 147 |
-
|
| 148 |
-
# Pass 1: metadata only β fast, no image bytes
|
| 149 |
-
# ------------------------------------------------------------------
|
| 150 |
-
logger.info("Streaming metadata for %d galaxies (seed=%d)...", pool_size, seed)
|
| 151 |
-
ids: list[int] = []
|
| 152 |
-
metadata_map: dict[int, dict] = {}
|
| 153 |
-
|
| 154 |
-
for i, row in enumerate(_make_dataset(seed, pool_size, with_images=False)):
|
| 155 |
-
metadata_map[i] = row
|
| 156 |
-
ids.append(i)
|
| 157 |
|
| 158 |
-
|
|
|
|
|
|
|
| 159 |
|
| 160 |
-
|
| 161 |
-
# Pass 2: images β same seed/buffer so row order matches pass 1
|
| 162 |
-
# ------------------------------------------------------------------
|
| 163 |
sync_count = min(prefetch_images, pool_size)
|
| 164 |
-
logger.info("Pre-caching first %d images...", sync_count)
|
| 165 |
-
|
| 166 |
-
img_iter = _make_dataset(seed, pool_size, with_images=True)
|
| 167 |
|
| 168 |
-
|
| 169 |
-
img_col = row.get(IMAGE_COLUMN)
|
| 170 |
-
if isinstance(img_col, dict):
|
| 171 |
-
img_bytes = img_col.get("bytes")
|
| 172 |
-
if img_bytes:
|
| 173 |
-
image_cache.put(i, img_bytes)
|
| 174 |
-
return
|
| 175 |
-
logger.warning("No image bytes for row %d", i)
|
| 176 |
-
|
| 177 |
-
# Synchronous: first sync_count images
|
| 178 |
for i in range(sync_count):
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
for i in range(sync_count, pool_size):
|
| 189 |
try:
|
| 190 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
except StopIteration:
|
| 192 |
break
|
| 193 |
except Exception as e:
|
| 194 |
-
logger.warning("Background
|
| 195 |
-
logger.info("Background
|
| 196 |
|
| 197 |
-
threading.Thread(target=
|
| 198 |
|
| 199 |
return ids, metadata_map, seed
|
|
|
|
| 20 |
|
| 21 |
logger = logging.getLogger(__name__)
|
| 22 |
|
| 23 |
+
_SHUFFLE_BUFFER = 200
|
|
|
|
| 24 |
|
| 25 |
|
| 26 |
+
def _make_dataset(seed: int, pool_size: int):
|
| 27 |
"""Return a shuffled, length-limited streaming dataset iterator."""
|
| 28 |
from datasets import load_dataset
|
| 29 |
from datasets import Image as HFImage
|
|
|
|
| 36 |
token=HF_TOKEN if HF_TOKEN else None,
|
| 37 |
)
|
| 38 |
features = getattr(ds, "features", None)
|
| 39 |
+
if features and IMAGE_COLUMN in features:
|
| 40 |
+
ds = ds.cast_column(IMAGE_COLUMN, HFImage(decode=False))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
ds = ds.shuffle(seed=seed, buffer_size=_SHUFFLE_BUFFER)
|
| 43 |
ds = ds.take(pool_size)
|
|
|
|
| 120 |
seed: int | None = None,
|
| 121 |
prefetch_images: int = 100,
|
| 122 |
) -> tuple[list[int], dict[int, dict], int]:
|
| 123 |
+
"""Build the galaxy pool, caching a small batch of images before returning.
|
| 124 |
|
| 125 |
+
Single streaming pass with cast_column(decode=False) to avoid Pillow.
|
| 126 |
+
The shuffle buffer is small (200 rows) so only ~300 images are downloaded
|
| 127 |
+
before the app starts serving. The rest are cached in a background thread.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
|
| 129 |
Returns:
|
| 130 |
ids: sequential ints 0..N-1
|
| 131 |
metadata_map: {id -> row_dict (no image column)}
|
| 132 |
+
seed: seed used
|
| 133 |
"""
|
| 134 |
if seed is None:
|
| 135 |
seed = random.randint(0, 2**32 - 1)
|
| 136 |
|
| 137 |
+
logger.info("Streaming %d galaxies (seed=%d)...", pool_size, seed)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
|
| 139 |
+
# Pool IDs are just 0..pool_size-1 β known upfront
|
| 140 |
+
ids = list(range(pool_size))
|
| 141 |
+
metadata_map: dict[int, dict] = {}
|
| 142 |
|
| 143 |
+
it = _make_dataset(seed, pool_size)
|
|
|
|
|
|
|
| 144 |
sync_count = min(prefetch_images, pool_size)
|
|
|
|
|
|
|
|
|
|
| 145 |
|
| 146 |
+
# Synchronous: first sync_count rows β populate metadata + cache images
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
for i in range(sync_count):
|
| 148 |
+
row = next(it)
|
| 149 |
+
img_col = row.get(IMAGE_COLUMN)
|
| 150 |
+
if isinstance(img_col, dict) and img_col.get("bytes"):
|
| 151 |
+
image_cache.put(i, img_col["bytes"])
|
| 152 |
+
else:
|
| 153 |
+
logger.warning("No image bytes for row %d", i)
|
| 154 |
+
metadata_map[i] = {k: v for k, v in row.items() if k != IMAGE_COLUMN}
|
| 155 |
+
|
| 156 |
+
logger.info("%d images cached β app ready, filling remaining %d in background",
|
| 157 |
+
sync_count, pool_size - sync_count)
|
| 158 |
+
|
| 159 |
+
# Asynchronous: rest of the pool
|
| 160 |
+
if sync_count < pool_size:
|
| 161 |
+
def _bg():
|
| 162 |
for i in range(sync_count, pool_size):
|
| 163 |
try:
|
| 164 |
+
row = next(it)
|
| 165 |
+
img_col = row.get(IMAGE_COLUMN)
|
| 166 |
+
if isinstance(img_col, dict) and img_col.get("bytes"):
|
| 167 |
+
image_cache.put(i, img_col["bytes"])
|
| 168 |
+
metadata_map[i] = {k: v for k, v in row.items() if k != IMAGE_COLUMN}
|
| 169 |
except StopIteration:
|
| 170 |
break
|
| 171 |
except Exception as e:
|
| 172 |
+
logger.warning("Background error at row %d: %s", i, e)
|
| 173 |
+
logger.info("Background streaming complete")
|
| 174 |
|
| 175 |
+
threading.Thread(target=_bg, daemon=True).start()
|
| 176 |
|
| 177 |
return ids, metadata_map, seed
|