"""Galaxy data loading: HF Datasets streaming sampler + disk-based LRU image cache."""
from __future__ import annotations
import logging
import random
import threading
import time
from pathlib import Path
from src.config import (
DATASET_CONFIG,
DATASET_ID,
DATASET_SPLIT,
HF_TOKEN,
ID_COLUMN,
IMAGE_CACHE_DIR,
IMAGE_CACHE_MAX_BYTES,
IMAGE_COLUMN,
)
logger = logging.getLogger(__name__)
_SHUFFLE_BUFFER = 200
def _make_dataset(seed: int, pool_size: int, with_images: bool = True):
    """Return a shuffled, length-limited streaming dataset iterator.

    When ``with_images`` is False, ``select_columns`` drops the image column
    at the Parquet level so no image bytes are downloaded at all. Both modes
    shuffle with the identical seed and buffer size, which guarantees that
    row i refers to the same galaxy in either stream.
    """
    from datasets import Image as HFImage
    from datasets import load_dataset

    stream = load_dataset(
        DATASET_ID,
        DATASET_CONFIG,
        split=DATASET_SPLIT,
        streaming=True,
        token=HF_TOKEN if HF_TOKEN else None,
    )
    feats = getattr(stream, "features", None)
    if with_images:
        # Keep raw bytes; decoding to PIL would pull in Pillow needlessly.
        if feats and IMAGE_COLUMN in feats:
            stream = stream.cast_column(IMAGE_COLUMN, HFImage(decode=False))
    elif feats and ID_COLUMN in feats:
        stream = stream.select_columns([ID_COLUMN])
    shuffled = stream.shuffle(seed=seed, buffer_size=_SHUFFLE_BUFFER)
    return iter(shuffled.take(pool_size))
# ---------------------------------------------------------------------------
# ImageCache β thread-safe, disk-based LRU
# ---------------------------------------------------------------------------
class ImageCache:
    """Disk-based LRU image cache keyed by sequential pool index.

    Files are stored as ``<row_index>.jpg`` under ``cache_dir``. A lock guards
    the in-memory bookkeeping (access times and the running byte total);
    eviction deletes least-recently-used files until the total fits under
    ``max_bytes``.
    """

    def __init__(self, cache_dir: str = IMAGE_CACHE_DIR, max_bytes: int = IMAGE_CACHE_MAX_BYTES):
        self._dir = Path(cache_dir)
        self._dir.mkdir(parents=True, exist_ok=True)
        self._max_bytes = max_bytes
        self._lock = threading.Lock()
        # row_index -> last-access timestamp. Uses time.time() (epoch seconds)
        # so values are directly comparable with st_mtime from _scan_existing().
        self._access_times: dict[int, float] = {}
        self._total_bytes = 0
        self._scan_existing()

    def _path_for(self, row_index: int) -> Path:
        return self._dir / f"{row_index}.jpg"

    def _scan_existing(self) -> None:
        """Rebuild bookkeeping from files already on disk (survives restarts)."""
        total = 0
        for p in self._dir.glob("*.jpg"):
            try:
                idx = int(p.stem)
                st = p.stat()  # one stat() call for both size and mtime
                total += st.st_size
                self._access_times[idx] = st.st_mtime
            except (ValueError, OSError):
                continue  # non-numeric stem, or file vanished mid-scan
        self._total_bytes = total
        logger.info("Image cache: %d files, %.1f MB", len(self._access_times), total / 1e6)

    def get_path(self, row_index: int) -> Path | None:
        """Return the cached file path for ``row_index``, or None on a miss."""
        p = self._path_for(row_index)
        if p.exists():
            with self._lock:
                # BUGFIX: was time.monotonic(), whose epoch is arbitrary (often
                # boot-relative) and far smaller than the Unix-epoch st_mtime
                # values loaded by _scan_existing() — pre-existing files always
                # looked newest, so fresh entries were evicted first.
                self._access_times[row_index] = time.time()
            return p
        return None

    def put(self, row_index: int, image_bytes: bytes) -> Path:
        """Write ``image_bytes`` to disk, update accounting, and evict if over budget."""
        p = self._path_for(row_index)
        # BUGFIX: when overwriting an existing entry, subtract the old file's
        # size first; the previous code only added, so _total_bytes drifted
        # upward and triggered spurious evictions.
        old_size = 0
        if p.exists():
            try:
                old_size = p.stat().st_size
            except OSError:
                old_size = 0
        p.write_bytes(image_bytes)
        with self._lock:
            self._access_times[row_index] = time.time()
            self._total_bytes += len(image_bytes) - old_size
            self._evict_if_needed()
        return p

    def _evict_if_needed(self) -> None:
        """Delete LRU files until under ``max_bytes``. Caller must hold ``_lock``."""
        while self._total_bytes > self._max_bytes and self._access_times:
            lru_idx = min(self._access_times, key=self._access_times.get)
            p = self._path_for(lru_idx)
            try:
                size = p.stat().st_size
                p.unlink()
                self._total_bytes -= size
            except OSError:
                pass  # file already gone; still drop the stale entry below
            del self._access_times[lru_idx]
# Module-level singleton
# Created at import time; shared by sample_pool_streaming() and importers.
image_cache = ImageCache()
# ---------------------------------------------------------------------------
# Streaming pool sampler
# ---------------------------------------------------------------------------
def sample_pool_streaming(
    pool_size: int,
    seed: int | None = None,
    prefetch_images: int = 100,
) -> tuple[list[int], dict[int, dict], int]:
    """Build the galaxy pool, caching a small batch of images before returning.

    Single streaming pass with cast_column(decode=False) to avoid Pillow.
    The shuffle buffer is small (200 rows) so only a few hundred images are
    downloaded before the app starts serving; the rest are cached in a
    background daemon thread.

    Args:
        pool_size: maximum number of galaxies to stream (the dataset may
            contain fewer; then the pool is smaller).
        seed: shuffle seed; a random 32-bit seed is drawn when None.
        prefetch_images: images cached synchronously before returning.

    Returns:
        ids: sequential ints 0..N-1
        metadata_map: {id -> row_dict (no image column)}
        seed: seed used
    """
    if seed is None:
        seed = random.randint(0, 2**32 - 1)
    logger.info("Streaming metadata for %d galaxies (seed=%d)...", pool_size, seed)

    # Pass 1: metadata only - fast, no images downloaded.
    ids: list[int] = []
    metadata_map: dict[int, dict] = {}
    for i, row in enumerate(_make_dataset(seed, pool_size, with_images=False)):
        metadata_map[i] = {ID_COLUMN: row.get(ID_COLUMN)}
        ids.append(i)
    logger.info("All %d galaxy IDs ready", len(ids))

    def _extract_bytes(img_col) -> bytes | None:
        """Get raw JPEG bytes from a decode=False dict, or re-encode a PIL image."""
        if isinstance(img_col, dict):
            return img_col.get("bytes")
        if img_col is not None:
            # Fallback: column arrived decoded (PIL image) despite decode=False.
            try:
                import io
                buf = io.BytesIO()
                img_col.save(buf, format="JPEG")
                return buf.getvalue()
            except Exception as e:
                logger.warning("PIL conversion failed: %s", e)
        return None

    # Pass 2: images - same seed+buffer => same row order as pass 1.
    img_it = _make_dataset(seed, pool_size, with_images=True)
    # BUGFIX: the dataset may hold fewer rows than pool_size. Pass 1 counted
    # the real total, so cap both loops at len(ids) and guard next() — the
    # old code raised an uncaught StopIteration in the sync loop.
    available = len(ids)
    sync_count = min(prefetch_images, available)
    for i in range(sync_count):
        try:
            row = next(img_it)
        except StopIteration:  # defensive: image stream shorter than pass 1
            logger.warning("Image stream ended early at row %d", i)
            break
        img_bytes = _extract_bytes(row.get(IMAGE_COLUMN))
        if img_bytes:
            image_cache.put(i, img_bytes)
        else:
            logger.warning("No image bytes for row %d", i)
    logger.info("%d images cached β app ready, %d remaining in background",
                sync_count, available - sync_count)

    if sync_count < available:
        def _bg():
            # Caches the remaining images without blocking app startup.
            for i in range(sync_count, available):
                try:
                    row = next(img_it)
                    img_bytes = _extract_bytes(row.get(IMAGE_COLUMN))
                    if img_bytes:
                        image_cache.put(i, img_bytes)
                except StopIteration:
                    break
                except Exception as e:
                    logger.warning("Background error at row %d: %s", i, e)
            logger.info("Background image caching complete")

        threading.Thread(target=_bg, daemon=True).start()

    return ids, metadata_map, seed