Spaces:
Running
Running
added gals
Browse files- app.py +2 -2
- src/config.py +1 -0
- src/galaxy_data_loader.py +86 -61
app.py
CHANGED
|
@@ -11,7 +11,7 @@ from src.callbacks import register_callbacks
|
|
| 11 |
from src import elo
|
| 12 |
from src.galaxy_data_loader import sample_pool_streaming, image_cache
|
| 13 |
from src.galaxy_profiles import register_metadata
|
| 14 |
-
from src.config import POOL_SIZE, POOL_SEED
|
| 15 |
|
| 16 |
logging.basicConfig(
|
| 17 |
level=logging.INFO,
|
|
@@ -41,7 +41,7 @@ def create_app() -> dash.Dash:
|
|
| 41 |
|
| 42 |
# Always stream with the fixed seed so every participant sees the same pool
|
| 43 |
logger.info("Streaming pool of %d galaxies (seed=%d)...", POOL_SIZE, POOL_SEED)
|
| 44 |
-
pool, metadata_map, _ = sample_pool_streaming(POOL_SIZE, seed=POOL_SEED)
|
| 45 |
register_metadata(metadata_map)
|
| 46 |
|
| 47 |
# Load persisted ELO state or start fresh
|
|
|
|
| 11 |
from src import elo
|
| 12 |
from src.galaxy_data_loader import sample_pool_streaming, image_cache
|
| 13 |
from src.galaxy_profiles import register_metadata
|
| 14 |
+
from src.config import POOL_SIZE, POOL_SEED, IMAGE_PREFETCH_COUNT
|
| 15 |
|
| 16 |
logging.basicConfig(
|
| 17 |
level=logging.INFO,
|
|
|
|
| 41 |
|
| 42 |
# Always stream with the fixed seed so every participant sees the same pool
|
| 43 |
logger.info("Streaming pool of %d galaxies (seed=%d)...", POOL_SIZE, POOL_SEED)
|
| 44 |
+
pool, metadata_map, _ = sample_pool_streaming(POOL_SIZE, seed=POOL_SEED, prefetch_images=IMAGE_PREFETCH_COUNT)
|
| 45 |
register_metadata(metadata_map)
|
| 46 |
|
| 47 |
# Load persisted ELO state or start fresh
|
src/config.py
CHANGED
|
@@ -22,6 +22,7 @@ IMAGE_COLUMN = os.getenv("IMAGE_COLUMN", "image")
|
|
| 22 |
ID_COLUMN = os.getenv("ID_COLUMN", "id_str")
|
| 23 |
POOL_SIZE = int(os.getenv("POOL_SIZE", "5000"))
|
| 24 |
POOL_SEED = int(os.getenv("POOL_SEED", "42"))
|
|
|
|
| 25 |
|
| 26 |
# Image cache
|
| 27 |
IMAGE_CACHE_DIR = os.getenv("IMAGE_CACHE_DIR", "cache/images")
|
|
|
|
| 22 |
ID_COLUMN = os.getenv("ID_COLUMN", "id_str")
|
| 23 |
POOL_SIZE = int(os.getenv("POOL_SIZE", "5000"))
|
| 24 |
POOL_SEED = int(os.getenv("POOL_SEED", "42"))
|
| 25 |
+
IMAGE_PREFETCH_COUNT = int(os.getenv("IMAGE_PREFETCH_COUNT", "100"))
|
| 26 |
|
| 27 |
# Image cache
|
| 28 |
IMAGE_CACHE_DIR = os.getenv("IMAGE_CACHE_DIR", "cache/images")
|
src/galaxy_data_loader.py
CHANGED
|
@@ -18,18 +18,41 @@ from src.config import (
|
|
| 18 |
|
| 19 |
logger = logging.getLogger(__name__)
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
# ---------------------------------------------------------------------------
|
| 23 |
# ImageCache — thread-safe, disk-based LRU
|
| 24 |
# ---------------------------------------------------------------------------
|
| 25 |
|
| 26 |
class ImageCache:
|
| 27 |
-
"""Disk-based LRU image cache keyed by sequential pool index.
|
| 28 |
-
|
| 29 |
-
Images are written at startup by sample_pool_streaming and served from
|
| 30 |
-
disk thereafter. There is no network fallback — if an image was not
|
| 31 |
-
captured during streaming it simply won't be available.
|
| 32 |
-
"""
|
| 33 |
|
| 34 |
def __init__(self, cache_dir: str = IMAGE_CACHE_DIR, max_bytes: int = IMAGE_CACHE_MAX_BYTES):
|
| 35 |
self._dir = Path(cache_dir)
|
|
@@ -57,7 +80,6 @@ class ImageCache:
|
|
| 57 |
logger.info("Image cache: %d files, %.1f MB", len(self._access_times), total / 1e6)
|
| 58 |
|
| 59 |
def get_path(self, row_index: int) -> Path | None:
|
| 60 |
-
"""Return cached file path if present, updating access time."""
|
| 61 |
p = self._path_for(row_index)
|
| 62 |
if p.exists():
|
| 63 |
with self._lock:
|
|
@@ -66,7 +88,6 @@ class ImageCache:
|
|
| 66 |
return None
|
| 67 |
|
| 68 |
def put(self, row_index: int, image_bytes: bytes) -> Path:
|
| 69 |
-
"""Write image bytes to cache, evicting LRU entries if needed."""
|
| 70 |
p = self._path_for(row_index)
|
| 71 |
p.write_bytes(image_bytes)
|
| 72 |
size = len(image_bytes)
|
|
@@ -77,7 +98,6 @@ class ImageCache:
|
|
| 77 |
return p
|
| 78 |
|
| 79 |
def _evict_if_needed(self):
|
| 80 |
-
"""Evict LRU entries until total size is within bounds. Caller holds lock."""
|
| 81 |
while self._total_bytes > self._max_bytes and self._access_times:
|
| 82 |
lru_idx = min(self._access_times, key=self._access_times.get)
|
| 83 |
p = self._path_for(lru_idx)
|
|
@@ -89,12 +109,6 @@ class ImageCache:
|
|
| 89 |
pass
|
| 90 |
del self._access_times[lru_idx]
|
| 91 |
|
| 92 |
-
def prefetch(self, row_indices: list[int]):
|
| 93 |
-
"""Log which requested indices are missing from cache (no-op fetch)."""
|
| 94 |
-
missing = [idx for idx in row_indices if self.get_path(idx) is None]
|
| 95 |
-
if missing:
|
| 96 |
-
logger.debug("prefetch: %d indices not in cache (no refetch): %s", len(missing), missing[:5])
|
| 97 |
-
|
| 98 |
|
| 99 |
# Module-level singleton
|
| 100 |
image_cache = ImageCache()
|
|
@@ -105,67 +119,78 @@ image_cache = ImageCache()
|
|
| 105 |
# ---------------------------------------------------------------------------
|
| 106 |
|
| 107 |
def sample_pool_streaming(
|
| 108 |
-
pool_size: int,
|
|
|
|
|
|
|
| 109 |
) -> tuple[list[int], dict[int, dict], int]:
|
| 110 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
seed: Shuffle seed. Pass the same seed on subsequent startups to
|
| 115 |
-
reproduce the exact same pool so saved ELO state stays valid.
|
| 116 |
|
| 117 |
Returns:
|
| 118 |
-
ids: sequential ints 0..N-1
|
| 119 |
-
metadata_map: {id -> row_dict (
|
| 120 |
-
seed:
|
| 121 |
"""
|
| 122 |
-
from datasets import load_dataset
|
| 123 |
-
from datasets import Image as HFImage
|
| 124 |
-
|
| 125 |
if seed is None:
|
| 126 |
seed = random.randint(0, 2**32 - 1)
|
| 127 |
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
split=DATASET_SPLIT,
|
| 139 |
-
streaming=True,
|
| 140 |
-
token=HF_TOKEN if HF_TOKEN else None,
|
| 141 |
-
)
|
| 142 |
|
| 143 |
-
|
| 144 |
-
if features and IMAGE_COLUMN in features:
|
| 145 |
-
ds = ds.cast_column(IMAGE_COLUMN, HFImage(decode=False))
|
| 146 |
|
| 147 |
-
|
| 148 |
-
|
|
|
|
|
|
|
|
|
|
| 149 |
|
| 150 |
-
|
| 151 |
-
metadata_map: dict[int, dict] = {}
|
| 152 |
|
| 153 |
-
|
| 154 |
img_col = row.get(IMAGE_COLUMN)
|
| 155 |
-
img_bytes: bytes | None = None
|
| 156 |
if isinstance(img_col, dict):
|
| 157 |
img_bytes = img_col.get("bytes")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
|
| 159 |
-
if img_bytes:
|
| 160 |
-
image_cache.put(i, img_bytes)
|
| 161 |
-
else:
|
| 162 |
-
logger.warning("No image bytes for streamed row %d", i)
|
| 163 |
-
|
| 164 |
-
metadata_map[i] = {k: v for k, v in row.items() if k != IMAGE_COLUMN}
|
| 165 |
-
ids.append(i)
|
| 166 |
-
|
| 167 |
-
if (i + 1) % 100 == 0:
|
| 168 |
-
logger.info("Streamed %d/%d galaxies", i + 1, pool_size)
|
| 169 |
-
|
| 170 |
-
logger.info("Finished streaming %d galaxies", len(ids))
|
| 171 |
return ids, metadata_map, seed
|
|
|
|
| 18 |
|
| 19 |
logger = logging.getLogger(__name__)
|
| 20 |
|
| 21 |
+
# Must be identical across both streaming passes so row order is reproducible.
|
| 22 |
+
_SHUFFLE_BUFFER = 1_000
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _make_dataset(seed: int, pool_size: int, with_images: bool):
|
| 26 |
+
"""Return a shuffled, length-limited streaming dataset iterator."""
|
| 27 |
+
from datasets import load_dataset
|
| 28 |
+
from datasets import Image as HFImage
|
| 29 |
+
|
| 30 |
+
ds = load_dataset(
|
| 31 |
+
DATASET_ID,
|
| 32 |
+
DATASET_CONFIG,
|
| 33 |
+
split=DATASET_SPLIT,
|
| 34 |
+
streaming=True,
|
| 35 |
+
token=HF_TOKEN if HF_TOKEN else None,
|
| 36 |
+
)
|
| 37 |
+
features = getattr(ds, "features", None)
|
| 38 |
+
if with_images:
|
| 39 |
+
if features and IMAGE_COLUMN in features:
|
| 40 |
+
ds = ds.cast_column(IMAGE_COLUMN, HFImage(decode=False))
|
| 41 |
+
else:
|
| 42 |
+
if features and IMAGE_COLUMN in features:
|
| 43 |
+
ds = ds.remove_columns([IMAGE_COLUMN])
|
| 44 |
+
|
| 45 |
+
ds = ds.shuffle(seed=seed, buffer_size=_SHUFFLE_BUFFER)
|
| 46 |
+
ds = ds.take(pool_size)
|
| 47 |
+
return iter(ds)
|
| 48 |
+
|
| 49 |
|
| 50 |
# ---------------------------------------------------------------------------
|
| 51 |
# ImageCache — thread-safe, disk-based LRU
|
| 52 |
# ---------------------------------------------------------------------------
|
| 53 |
|
| 54 |
class ImageCache:
|
| 55 |
+
"""Disk-based LRU image cache keyed by sequential pool index."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
def __init__(self, cache_dir: str = IMAGE_CACHE_DIR, max_bytes: int = IMAGE_CACHE_MAX_BYTES):
|
| 58 |
self._dir = Path(cache_dir)
|
|
|
|
| 80 |
logger.info("Image cache: %d files, %.1f MB", len(self._access_times), total / 1e6)
|
| 81 |
|
| 82 |
def get_path(self, row_index: int) -> Path | None:
|
|
|
|
| 83 |
p = self._path_for(row_index)
|
| 84 |
if p.exists():
|
| 85 |
with self._lock:
|
|
|
|
| 88 |
return None
|
| 89 |
|
| 90 |
def put(self, row_index: int, image_bytes: bytes) -> Path:
|
|
|
|
| 91 |
p = self._path_for(row_index)
|
| 92 |
p.write_bytes(image_bytes)
|
| 93 |
size = len(image_bytes)
|
|
|
|
| 98 |
return p
|
| 99 |
|
| 100 |
def _evict_if_needed(self):
|
|
|
|
| 101 |
while self._total_bytes > self._max_bytes and self._access_times:
|
| 102 |
lru_idx = min(self._access_times, key=self._access_times.get)
|
| 103 |
p = self._path_for(lru_idx)
|
|
|
|
| 109 |
pass
|
| 110 |
del self._access_times[lru_idx]
|
| 111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
# Module-level singleton
|
| 114 |
image_cache = ImageCache()
|
|
|
|
| 119 |
# ---------------------------------------------------------------------------
|
| 120 |
|
| 121 |
def sample_pool_streaming(
|
| 122 |
+
pool_size: int,
|
| 123 |
+
seed: int | None = None,
|
| 124 |
+
prefetch_images: int = 100,
|
| 125 |
) -> tuple[list[int], dict[int, dict], int]:
|
| 126 |
+
"""Build the galaxy pool with lazy image loading.
|
| 127 |
+
|
| 128 |
+
Pass 1 (fast): streams metadata only — no image bytes downloaded.
|
| 129 |
+
Pass 2a (sync): caches the first `prefetch_images` images before returning,
|
| 130 |
+
so the app can serve immediately.
|
| 131 |
+
Pass 2b (async): background thread fills the rest of the image cache.
|
| 132 |
|
| 133 |
+
Both passes use the same seed and shuffle buffer so row i in pass 1
|
| 134 |
+
is the same galaxy as row i in pass 2.
|
|
|
|
|
|
|
| 135 |
|
| 136 |
Returns:
|
| 137 |
+
ids: sequential ints 0..N-1
|
| 138 |
+
metadata_map: {id -> row_dict (no image column)}
|
| 139 |
+
seed: seed used (fixed for reproducibility)
|
| 140 |
"""
|
|
|
|
|
|
|
|
|
|
| 141 |
if seed is None:
|
| 142 |
seed = random.randint(0, 2**32 - 1)
|
| 143 |
|
| 144 |
+
# ------------------------------------------------------------------
|
| 145 |
+
# Pass 1: metadata only — fast, no image bytes
|
| 146 |
+
# ------------------------------------------------------------------
|
| 147 |
+
logger.info("Streaming metadata for %d galaxies (seed=%d)...", pool_size, seed)
|
| 148 |
+
ids: list[int] = []
|
| 149 |
+
metadata_map: dict[int, dict] = {}
|
| 150 |
|
| 151 |
+
for i, row in enumerate(_make_dataset(seed, pool_size, with_images=False)):
|
| 152 |
+
metadata_map[i] = row
|
| 153 |
+
ids.append(i)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
|
| 155 |
+
logger.info("Metadata ready: %d galaxies", len(ids))
|
|
|
|
|
|
|
| 156 |
|
| 157 |
+
# ------------------------------------------------------------------
|
| 158 |
+
# Pass 2: images — same seed/buffer so row order matches pass 1
|
| 159 |
+
# ------------------------------------------------------------------
|
| 160 |
+
sync_count = min(prefetch_images, pool_size)
|
| 161 |
+
logger.info("Pre-caching first %d images...", sync_count)
|
| 162 |
|
| 163 |
+
img_iter = _make_dataset(seed, pool_size, with_images=True)
|
|
|
|
| 164 |
|
| 165 |
+
def _cache_row(i: int, row: dict):
|
| 166 |
img_col = row.get(IMAGE_COLUMN)
|
|
|
|
| 167 |
if isinstance(img_col, dict):
|
| 168 |
img_bytes = img_col.get("bytes")
|
| 169 |
+
if img_bytes:
|
| 170 |
+
image_cache.put(i, img_bytes)
|
| 171 |
+
return
|
| 172 |
+
logger.warning("No image bytes for row %d", i)
|
| 173 |
+
|
| 174 |
+
# Synchronous: first sync_count images
|
| 175 |
+
for i in range(sync_count):
|
| 176 |
+
_cache_row(i, next(img_iter))
|
| 177 |
+
|
| 178 |
+
logger.info("Initial %d images cached — app ready", sync_count)
|
| 179 |
+
|
| 180 |
+
# Asynchronous: remainder in background
|
| 181 |
+
remaining = pool_size - sync_count
|
| 182 |
+
if remaining > 0:
|
| 183 |
+
def _bg_cache():
|
| 184 |
+
logger.info("Background: caching %d remaining images...", remaining)
|
| 185 |
+
for i in range(sync_count, pool_size):
|
| 186 |
+
try:
|
| 187 |
+
_cache_row(i, next(img_iter))
|
| 188 |
+
except StopIteration:
|
| 189 |
+
break
|
| 190 |
+
except Exception as e:
|
| 191 |
+
logger.warning("Background cache error at row %d: %s", i, e)
|
| 192 |
+
logger.info("Background image caching complete")
|
| 193 |
+
|
| 194 |
+
threading.Thread(target=_bg_cache, daemon=True).start()
|
| 195 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
return ids, metadata_map, seed
|