Smith42 committed on
Commit
1a6597d
·
1 Parent(s): a63a087

added gals

Browse files
Files changed (3) hide show
  1. app.py +2 -2
  2. src/config.py +1 -0
  3. src/galaxy_data_loader.py +86 -61
app.py CHANGED
@@ -11,7 +11,7 @@ from src.callbacks import register_callbacks
11
  from src import elo
12
  from src.galaxy_data_loader import sample_pool_streaming, image_cache
13
  from src.galaxy_profiles import register_metadata
14
- from src.config import POOL_SIZE, POOL_SEED
15
 
16
  logging.basicConfig(
17
  level=logging.INFO,
@@ -41,7 +41,7 @@ def create_app() -> dash.Dash:
41
 
42
  # Always stream with the fixed seed so every participant sees the same pool
43
  logger.info("Streaming pool of %d galaxies (seed=%d)...", POOL_SIZE, POOL_SEED)
44
- pool, metadata_map, _ = sample_pool_streaming(POOL_SIZE, seed=POOL_SEED)
45
  register_metadata(metadata_map)
46
 
47
  # Load persisted ELO state or start fresh
 
11
  from src import elo
12
  from src.galaxy_data_loader import sample_pool_streaming, image_cache
13
  from src.galaxy_profiles import register_metadata
14
+ from src.config import POOL_SIZE, POOL_SEED, IMAGE_PREFETCH_COUNT
15
 
16
  logging.basicConfig(
17
  level=logging.INFO,
 
41
 
42
  # Always stream with the fixed seed so every participant sees the same pool
43
  logger.info("Streaming pool of %d galaxies (seed=%d)...", POOL_SIZE, POOL_SEED)
44
+ pool, metadata_map, _ = sample_pool_streaming(POOL_SIZE, seed=POOL_SEED, prefetch_images=IMAGE_PREFETCH_COUNT)
45
  register_metadata(metadata_map)
46
 
47
  # Load persisted ELO state or start fresh
src/config.py CHANGED
@@ -22,6 +22,7 @@ IMAGE_COLUMN = os.getenv("IMAGE_COLUMN", "image")
22
  ID_COLUMN = os.getenv("ID_COLUMN", "id_str")
23
  POOL_SIZE = int(os.getenv("POOL_SIZE", "5000"))
24
  POOL_SEED = int(os.getenv("POOL_SEED", "42"))
 
25
 
26
  # Image cache
27
  IMAGE_CACHE_DIR = os.getenv("IMAGE_CACHE_DIR", "cache/images")
 
22
  ID_COLUMN = os.getenv("ID_COLUMN", "id_str")
23
  POOL_SIZE = int(os.getenv("POOL_SIZE", "5000"))
24
  POOL_SEED = int(os.getenv("POOL_SEED", "42"))
25
+ IMAGE_PREFETCH_COUNT = int(os.getenv("IMAGE_PREFETCH_COUNT", "100"))
26
 
27
  # Image cache
28
  IMAGE_CACHE_DIR = os.getenv("IMAGE_CACHE_DIR", "cache/images")
src/galaxy_data_loader.py CHANGED
@@ -18,18 +18,41 @@ from src.config import (
18
 
19
  logger = logging.getLogger(__name__)
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  # ---------------------------------------------------------------------------
23
  # ImageCache β€” thread-safe, disk-based LRU
24
  # ---------------------------------------------------------------------------
25
 
26
  class ImageCache:
27
- """Disk-based LRU image cache keyed by sequential pool index.
28
-
29
- Images are written at startup by sample_pool_streaming and served from
30
- disk thereafter. There is no network fallback β€” if an image was not
31
- captured during streaming it simply won't be available.
32
- """
33
 
34
  def __init__(self, cache_dir: str = IMAGE_CACHE_DIR, max_bytes: int = IMAGE_CACHE_MAX_BYTES):
35
  self._dir = Path(cache_dir)
@@ -57,7 +80,6 @@ class ImageCache:
57
  logger.info("Image cache: %d files, %.1f MB", len(self._access_times), total / 1e6)
58
 
59
  def get_path(self, row_index: int) -> Path | None:
60
- """Return cached file path if present, updating access time."""
61
  p = self._path_for(row_index)
62
  if p.exists():
63
  with self._lock:
@@ -66,7 +88,6 @@ class ImageCache:
66
  return None
67
 
68
  def put(self, row_index: int, image_bytes: bytes) -> Path:
69
- """Write image bytes to cache, evicting LRU entries if needed."""
70
  p = self._path_for(row_index)
71
  p.write_bytes(image_bytes)
72
  size = len(image_bytes)
@@ -77,7 +98,6 @@ class ImageCache:
77
  return p
78
 
79
  def _evict_if_needed(self):
80
- """Evict LRU entries until total size is within bounds. Caller holds lock."""
81
  while self._total_bytes > self._max_bytes and self._access_times:
82
  lru_idx = min(self._access_times, key=self._access_times.get)
83
  p = self._path_for(lru_idx)
@@ -89,12 +109,6 @@ class ImageCache:
89
  pass
90
  del self._access_times[lru_idx]
91
 
92
- def prefetch(self, row_indices: list[int]):
93
- """Log which requested indices are missing from cache (no-op fetch)."""
94
- missing = [idx for idx in row_indices if self.get_path(idx) is None]
95
- if missing:
96
- logger.debug("prefetch: %d indices not in cache (no refetch): %s", len(missing), missing[:5])
97
-
98
 
99
  # Module-level singleton
100
  image_cache = ImageCache()
@@ -105,67 +119,78 @@ image_cache = ImageCache()
105
  # ---------------------------------------------------------------------------
106
 
107
  def sample_pool_streaming(
108
- pool_size: int, seed: int | None = None
 
 
109
  ) -> tuple[list[int], dict[int, dict], int]:
110
- """Stream pool_size shuffled galaxies from HF Datasets, pre-caching images.
 
 
 
 
 
111
 
112
- Args:
113
- pool_size: Number of galaxies to include in the pool.
114
- seed: Shuffle seed. Pass the same seed on subsequent startups to
115
- reproduce the exact same pool so saved ELO state stays valid.
116
 
117
  Returns:
118
- ids: sequential ints 0..N-1 used as galaxy IDs throughout the app
119
- metadata_map: {id -> row_dict (without image column)} for display names
120
- seed: the seed that was used (store in tournament state for reuse)
121
  """
122
- from datasets import load_dataset
123
- from datasets import Image as HFImage
124
-
125
  if seed is None:
126
  seed = random.randint(0, 2**32 - 1)
127
 
128
- logger.info(
129
- "Streaming %d galaxies from %s (shuffle seed=%d)...",
130
- pool_size,
131
- DATASET_ID,
132
- seed,
133
- )
134
 
135
- ds = load_dataset(
136
- DATASET_ID,
137
- DATASET_CONFIG,
138
- split=DATASET_SPLIT,
139
- streaming=True,
140
- token=HF_TOKEN if HF_TOKEN else None,
141
- )
142
 
143
- features = getattr(ds, "features", None)
144
- if features and IMAGE_COLUMN in features:
145
- ds = ds.cast_column(IMAGE_COLUMN, HFImage(decode=False))
146
 
147
- ds = ds.shuffle(seed=seed, buffer_size=10_000)
148
- ds = ds.take(pool_size)
 
 
 
149
 
150
- ids: list[int] = []
151
- metadata_map: dict[int, dict] = {}
152
 
153
- for i, row in enumerate(ds):
154
  img_col = row.get(IMAGE_COLUMN)
155
- img_bytes: bytes | None = None
156
  if isinstance(img_col, dict):
157
  img_bytes = img_col.get("bytes")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
- if img_bytes:
160
- image_cache.put(i, img_bytes)
161
- else:
162
- logger.warning("No image bytes for streamed row %d", i)
163
-
164
- metadata_map[i] = {k: v for k, v in row.items() if k != IMAGE_COLUMN}
165
- ids.append(i)
166
-
167
- if (i + 1) % 100 == 0:
168
- logger.info("Streamed %d/%d galaxies", i + 1, pool_size)
169
-
170
- logger.info("Finished streaming %d galaxies", len(ids))
171
  return ids, metadata_map, seed
 
18
 
19
  logger = logging.getLogger(__name__)
20
 
21
# Must be identical across both streaming passes so row order is reproducible.
_SHUFFLE_BUFFER = 1_000


def _make_dataset(seed: int, pool_size: int, *, with_images: bool):
    """Return a shuffled, length-limited streaming dataset iterator.

    Args:
        seed: Shuffle seed. Both passes must use the same seed (and the
            same _SHUFFLE_BUFFER) so row i is the same galaxy in each pass.
        pool_size: Number of rows to take after shuffling.
        with_images: When True, keep the image column as undecoded raw
            bytes; when False, drop it so the metadata pass stays light.

    Returns:
        An iterator over at most pool_size shuffled rows.
    """
    from datasets import load_dataset
    from datasets import Image as HFImage

    ds = load_dataset(
        DATASET_ID,
        DATASET_CONFIG,
        split=DATASET_SPLIT,
        streaming=True,
        token=HF_TOKEN if HF_TOKEN else None,
    )

    # NOTE(review): streaming datasets sometimes expose no feature schema;
    # in that case neither branch below applies, so pass 1 may still pull
    # image bytes over the wire — confirm against the actual dataset.
    features = getattr(ds, "features", None)
    if features and IMAGE_COLUMN in features:
        if with_images:
            # Keep raw encoded bytes rather than eagerly decoded images.
            ds = ds.cast_column(IMAGE_COLUMN, HFImage(decode=False))
        else:
            ds = ds.remove_columns([IMAGE_COLUMN])

    ds = ds.shuffle(seed=seed, buffer_size=_SHUFFLE_BUFFER)
    ds = ds.take(pool_size)
    return iter(ds)
48
+
49
 
50
  # ---------------------------------------------------------------------------
51
  # ImageCache β€” thread-safe, disk-based LRU
52
  # ---------------------------------------------------------------------------
53
 
54
  class ImageCache:
55
+ """Disk-based LRU image cache keyed by sequential pool index."""
 
 
 
 
 
56
 
57
  def __init__(self, cache_dir: str = IMAGE_CACHE_DIR, max_bytes: int = IMAGE_CACHE_MAX_BYTES):
58
  self._dir = Path(cache_dir)
 
80
  logger.info("Image cache: %d files, %.1f MB", len(self._access_times), total / 1e6)
81
 
82
  def get_path(self, row_index: int) -> Path | None:
 
83
  p = self._path_for(row_index)
84
  if p.exists():
85
  with self._lock:
 
88
  return None
89
 
90
  def put(self, row_index: int, image_bytes: bytes) -> Path:
 
91
  p = self._path_for(row_index)
92
  p.write_bytes(image_bytes)
93
  size = len(image_bytes)
 
98
  return p
99
 
100
  def _evict_if_needed(self):
 
101
  while self._total_bytes > self._max_bytes and self._access_times:
102
  lru_idx = min(self._access_times, key=self._access_times.get)
103
  p = self._path_for(lru_idx)
 
109
  pass
110
  del self._access_times[lru_idx]
111
 
 
 
 
 
 
 
112
 
113
  # Module-level singleton
114
  image_cache = ImageCache()
 
119
  # ---------------------------------------------------------------------------
120
 
121
def sample_pool_streaming(
    pool_size: int,
    seed: int | None = None,
    prefetch_images: int = 100,
) -> tuple[list[int], dict[int, dict], int]:
    """Build the galaxy pool with lazy image loading.

    Pass 1 (fast): streams metadata only — no image bytes downloaded.
    Pass 2a (sync): caches the first `prefetch_images` images before
        returning, so the app can serve immediately.
    Pass 2b (async): a daemon thread fills the rest of the image cache.

    Both passes use the same seed and shuffle buffer so row i in pass 1
    is the same galaxy as row i in pass 2.

    Args:
        pool_size: Number of galaxies to include in the pool.
        seed: Shuffle seed; pass the same seed on subsequent startups to
            reproduce the exact same pool. Randomized when None.
        prefetch_images: How many images to cache synchronously before
            returning (clamped to pool_size).

    Returns:
        ids: sequential ints 0..N-1 used as galaxy IDs throughout the app
        metadata_map: {id -> row_dict (no image column)}
        seed: the seed that was used (store in tournament state for reuse)
    """
    if seed is None:
        seed = random.randint(0, 2**32 - 1)

    # ------------------------------------------------------------------
    # Pass 1: metadata only — fast, no image bytes
    # ------------------------------------------------------------------
    logger.info("Streaming metadata for %d galaxies (seed=%d)...", pool_size, seed)
    ids: list[int] = []
    metadata_map: dict[int, dict] = {}

    for i, row in enumerate(_make_dataset(seed, pool_size, with_images=False)):
        # Strip the image column defensively: _make_dataset only removes it
        # when the stream exposes a feature schema, and the documented
        # contract is that metadata rows never carry image payloads.
        metadata_map[i] = {k: v for k, v in row.items() if k != IMAGE_COLUMN}
        ids.append(i)

    logger.info("Metadata ready: %d galaxies", len(ids))

    # ------------------------------------------------------------------
    # Pass 2: images — same seed/buffer so row order matches pass 1
    # ------------------------------------------------------------------
    sync_count = min(prefetch_images, pool_size)
    logger.info("Pre-caching first %d images...", sync_count)

    img_iter = _make_dataset(seed, pool_size, with_images=True)

    def _cache_row(i: int, row: dict):
        # Persist one row's image bytes under its pool index, if present.
        img_col = row.get(IMAGE_COLUMN)
        if isinstance(img_col, dict):
            img_bytes = img_col.get("bytes")
            if img_bytes:
                image_cache.put(i, img_bytes)
                return
        logger.warning("No image bytes for row %d", i)

    # Synchronous: first sync_count images
    for i in range(sync_count):
        _cache_row(i, next(img_iter))

    logger.info("Initial %d images cached — app ready", sync_count)

    # Asynchronous: remainder in background. The synchronous loop above has
    # fully finished with img_iter before the thread starts, so the iterator
    # is never used from two threads at once.
    remaining = pool_size - sync_count
    if remaining > 0:
        def _bg_cache():
            logger.info("Background: caching %d remaining images...", remaining)
            for i in range(sync_count, pool_size):
                try:
                    _cache_row(i, next(img_iter))
                except StopIteration:
                    break
                except Exception as e:
                    # Best-effort: keep going, but if next(img_iter) itself
                    # failed, the stream may be misaligned and later rows
                    # could cache under shifted indices. NOTE(review):
                    # consider breaking on iterator errors.
                    logger.warning("Background cache error at row %d: %s", i, e)
            logger.info("Background image caching complete")

        threading.Thread(target=_bg_cache, daemon=True).start()

    return ids, metadata_map, seed