# perihelion/src/galaxy_data_loader.py — "only load dr8 id on startup" (commit 285999a)
"""Galaxy data loading: HF Datasets streaming sampler + disk-based LRU image cache."""
from __future__ import annotations
import logging
import random
import threading
import time
from pathlib import Path
from src.config import (
DATASET_CONFIG,
DATASET_ID,
DATASET_SPLIT,
HF_TOKEN,
ID_COLUMN,
IMAGE_CACHE_DIR,
IMAGE_CACHE_MAX_BYTES,
IMAGE_COLUMN,
)
logger = logging.getLogger(__name__)
_SHUFFLE_BUFFER = 200
def _make_dataset(seed: int, pool_size: int, with_images: bool = True):
    """Return a shuffled, length-limited iterator over the streaming dataset.

    When ``with_images`` is False the image column is dropped via
    ``select_columns`` so no image bytes are ever downloaded; when True the
    image column is cast with ``decode=False`` so raw bytes are kept and
    Pillow decoding is skipped.  Both modes shuffle with the same seed and
    buffer size, so row i always corresponds to the same galaxy.
    """
    from datasets import Image as HFImage
    from datasets import load_dataset

    stream = load_dataset(
        DATASET_ID,
        DATASET_CONFIG,
        split=DATASET_SPLIT,
        streaming=True,
        token=HF_TOKEN if HF_TOKEN else None,
    )
    # Streaming datasets may not expose features until resolved; guard access.
    cols = getattr(stream, "features", None)
    if with_images:
        if cols and IMAGE_COLUMN in cols:
            stream = stream.cast_column(IMAGE_COLUMN, HFImage(decode=False))
    elif cols and ID_COLUMN in cols:
        stream = stream.select_columns([ID_COLUMN])
    shuffled = stream.shuffle(seed=seed, buffer_size=_SHUFFLE_BUFFER)
    return iter(shuffled.take(pool_size))
# ---------------------------------------------------------------------------
# ImageCache β€” thread-safe, disk-based LRU
# ---------------------------------------------------------------------------
class ImageCache:
    """Thread-safe, disk-based LRU image cache keyed by sequential pool index.

    Entries are stored as ``<row_index>.jpg`` under ``cache_dir``.  Last-access
    times are kept in memory, seeded from file mtimes on startup; once the
    total size exceeds ``max_bytes`` the least-recently-used files are evicted.

    All timestamps use ``time.time()`` (epoch seconds) so they are directly
    comparable with the ``st_mtime`` values read during the startup scan.
    (Using ``time.monotonic()`` here would break LRU ordering across restarts,
    since monotonic values are unrelated to file mtimes.)
    """

    def __init__(self, cache_dir: str = IMAGE_CACHE_DIR, max_bytes: int = IMAGE_CACHE_MAX_BYTES):
        self._dir = Path(cache_dir)
        self._dir.mkdir(parents=True, exist_ok=True)
        self._max_bytes = max_bytes
        self._lock = threading.Lock()
        # row_index -> last-access timestamp (epoch seconds)
        self._access_times: dict[int, float] = {}
        self._total_bytes = 0
        self._scan_existing()

    def _path_for(self, row_index: int) -> Path:
        """Return the on-disk path for a pool index."""
        return self._dir / f"{row_index}.jpg"

    def _scan_existing(self):
        """Rebuild the in-memory index from files already on disk."""
        total = 0
        for p in self._dir.glob("*.jpg"):
            try:
                idx = int(p.stem)
                st = p.stat()
                total += st.st_size
                # mtime is epoch time, comparable with time.time() below.
                self._access_times[idx] = st.st_mtime
            except (ValueError, OSError):
                # Non-numeric filename or file vanished mid-scan; skip it.
                continue
        self._total_bytes = total
        logger.info("Image cache: %d files, %.1f MB", len(self._access_times), total / 1e6)

    def get_path(self, row_index: int) -> Path | None:
        """Return the cached file path for ``row_index``, or None on a miss."""
        p = self._path_for(row_index)
        if p.exists():
            with self._lock:
                self._access_times[row_index] = time.time()
            return p
        return None

    def put(self, row_index: int, image_bytes: bytes) -> Path:
        """Write ``image_bytes`` to disk, update accounting, and evict if over budget."""
        p = self._path_for(row_index)
        # If we are overwriting an existing entry, remember its old size so
        # _total_bytes is not double-counted for the same index.
        try:
            old_size = p.stat().st_size
        except OSError:
            old_size = 0
        p.write_bytes(image_bytes)
        with self._lock:
            self._access_times[row_index] = time.time()
            self._total_bytes += len(image_bytes) - old_size
            self._evict_if_needed()
        return p

    def _evict_if_needed(self):
        """Evict LRU entries until within budget.  Caller must hold self._lock."""
        while self._total_bytes > self._max_bytes and self._access_times:
            lru_idx = min(self._access_times, key=self._access_times.get)
            p = self._path_for(lru_idx)
            try:
                size = p.stat().st_size
                p.unlink()
                self._total_bytes -= size
            except OSError:
                # File already gone; still drop the bookkeeping entry below.
                pass
            del self._access_times[lru_idx]
# Module-level singleton: one shared cache instance for the whole process.
image_cache: ImageCache = ImageCache()
# ---------------------------------------------------------------------------
# Streaming pool sampler
# ---------------------------------------------------------------------------
def sample_pool_streaming(
    pool_size: int,
    seed: int | None = None,
    prefetch_images: int = 100,
) -> tuple[list[int], dict[int, dict], int]:
    """Build the galaxy pool, caching a small batch of images before returning.

    Uses two streaming passes with the same seed and shuffle buffer so the
    row order is identical: pass 1 fetches IDs only (no image bytes), pass 2
    fetches images with cast_column(decode=False) to avoid Pillow decoding.
    A small synchronous prefetch runs before returning; the remaining images
    are cached in a daemon background thread.

    Args:
        pool_size: number of galaxies requested from the stream.
        seed: shuffle seed; a random 32-bit seed is drawn when None.
        prefetch_images: how many images to cache synchronously before returning.

    Returns:
        ids: sequential ints 0..N-1 (N may be < pool_size if the stream is shorter)
        metadata_map: {id -> row_dict (no image column)}
        seed: the seed actually used
    """
    if seed is None:
        seed = random.randint(0, 2**32 - 1)
    logger.info("Streaming metadata for %d galaxies (seed=%d)...", pool_size, seed)

    # Pass 1: metadata only — fast, no images downloaded.
    ids: list[int] = []
    metadata_map: dict[int, dict] = {}
    for i, row in enumerate(_make_dataset(seed, pool_size, with_images=False)):
        metadata_map[i] = {ID_COLUMN: row.get(ID_COLUMN)}
        ids.append(i)
    logger.info("All %d galaxy IDs ready", len(ids))

    # The stream may hold fewer rows than pool_size; pass 1 measured the
    # real count, so use it to bound both image loops below.
    total_rows = len(ids)

    # Pass 2: images — same seed+buffer → same row order as pass 1.
    def _extract_bytes(img_col) -> bytes | None:
        """Return raw JPEG bytes from a decode=False dict, or re-encode a PIL image."""
        if isinstance(img_col, dict):
            return img_col.get("bytes")
        if img_col is not None:
            # Fallback when cast_column(decode=False) was not applied upstream.
            try:
                import io
                buf = io.BytesIO()
                img_col.save(buf, format="JPEG")
                return buf.getvalue()
            except Exception as e:
                logger.warning("PIL conversion failed: %s", e)
        return None

    img_it = _make_dataset(seed, pool_size, with_images=True)
    # Cap the synchronous prefetch at the known row count so a short stream
    # cannot raise StopIteration and crash startup.
    sync_count = min(prefetch_images, total_rows)
    for i in range(sync_count):
        try:
            row = next(img_it)
        except StopIteration:
            # Defensive: image pass ended before the metadata pass did.
            logger.warning("Image stream ended early at row %d", i)
            break
        img_bytes = _extract_bytes(row.get(IMAGE_COLUMN))
        if img_bytes:
            image_cache.put(i, img_bytes)
        else:
            logger.warning("No image bytes for row %d", i)
    logger.info("%d images cached — app ready, %d remaining in background",
                sync_count, total_rows - sync_count)

    if sync_count < total_rows:
        def _bg():
            """Cache the remaining images without blocking startup."""
            for i in range(sync_count, total_rows):
                try:
                    row = next(img_it)
                    img_bytes = _extract_bytes(row.get(IMAGE_COLUMN))
                    if img_bytes:
                        image_cache.put(i, img_bytes)
                    else:
                        # Mirror the sync loop so missing images are visible.
                        logger.warning("No image bytes for row %d", i)
                except StopIteration:
                    break
                except Exception as e:
                    logger.warning("Background error at row %d: %s", i, e)
            logger.info("Background image caching complete")
        threading.Thread(target=_bg, daemon=True).start()

    return ids, metadata_map, seed