File size: 6,714 Bytes
21ce873
159bb46
f229b66
 
159bb46
879d582
 
 
 
 
 
 
 
 
 
285999a
879d582
 
 
 
159bb46
 
 
0a4af2e
1a6597d
 
3d2b027
 
 
 
 
 
 
1a6597d
 
 
 
 
 
 
 
 
 
 
3d2b027
 
 
 
285999a
 
1a6597d
 
 
 
 
879d582
 
 
 
 
 
1a6597d
879d582
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159bb46
879d582
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21ce873
 
 
 
 
 
346f507
1a6597d
 
 
346f507
0a4af2e
1a6597d
0a4af2e
 
 
346f507
21ce873
1a6597d
 
0a4af2e
21ce873
346f507
 
 
3d2b027
21ce873
3d2b027
 
0a4af2e
3d2b027
285999a
3d2b027
21ce873
3d2b027
21ce873
3d2b027
df55019
 
 
 
 
 
 
 
 
 
3d2b027
df55019
 
3d2b027
 
 
 
 
df55019
 
 
0a4af2e
 
df55019
3d2b027
0a4af2e
 
 
 
1a6597d
 
3d2b027
 
 
 
1a6597d
 
 
0a4af2e
3d2b027
1a6597d
0a4af2e
21ce873
346f507
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
"""Galaxy data loading: HF Datasets streaming sampler + disk-based LRU image cache."""

from __future__ import annotations

import logging
import random
import threading
import time
from pathlib import Path

from src.config import (
    DATASET_CONFIG,
    DATASET_ID,
    DATASET_SPLIT,
    HF_TOKEN,
    ID_COLUMN,
    IMAGE_CACHE_DIR,
    IMAGE_CACHE_MAX_BYTES,
    IMAGE_COLUMN,
)

logger = logging.getLogger(__name__)

_SHUFFLE_BUFFER = 200


def _make_dataset(seed: int, pool_size: int, with_images: bool = True):
    """Return a shuffled, length-limited streaming dataset iterator.

    with_images=False uses select_columns to skip the image column entirely
    at the Parquet level — no image bytes downloaded.
    Both modes use the same seed+buffer so row i is always the same galaxy.
    """
    from datasets import Image as HFImage
    from datasets import load_dataset

    stream = load_dataset(
        DATASET_ID,
        DATASET_CONFIG,
        split=DATASET_SPLIT,
        streaming=True,
        token=HF_TOKEN if HF_TOKEN else None,
    )

    feats = getattr(stream, "features", None)
    if with_images and feats and IMAGE_COLUMN in feats:
        # Keep raw encoded bytes — skip Pillow decoding inside `datasets`.
        stream = stream.cast_column(IMAGE_COLUMN, HFImage(decode=False))
    elif not with_images and feats and ID_COLUMN in feats:
        # Project down to the ID column so no image bytes are fetched at all.
        stream = stream.select_columns([ID_COLUMN])

    # Identical seed + buffer in both modes keeps row order aligned across passes.
    shuffled = stream.shuffle(seed=seed, buffer_size=_SHUFFLE_BUFFER)
    return iter(shuffled.take(pool_size))


# ---------------------------------------------------------------------------
# ImageCache β€” thread-safe, disk-based LRU
# ---------------------------------------------------------------------------

class ImageCache:
    """Disk-based LRU image cache keyed by sequential pool index.

    Entries live as ``<row_index>.jpg`` under the cache directory.  Recency
    is tracked with wall-clock timestamps (``time.time()``) so that entries
    discovered at startup via file mtimes and entries touched at runtime
    share one comparable time base.  (The original mixed epoch ``st_mtime``
    with ``time.monotonic()`` — uptime-scale values far below epoch time —
    which inverted the LRU ordering against pre-existing files.)
    All bookkeeping is guarded by a lock; file I/O happens outside it.
    """

    def __init__(self, cache_dir: str = IMAGE_CACHE_DIR, max_bytes: int = IMAGE_CACHE_MAX_BYTES):
        self._dir = Path(cache_dir)
        self._dir.mkdir(parents=True, exist_ok=True)
        self._max_bytes = max_bytes
        self._lock = threading.Lock()
        # row_index -> last-access timestamp; time.time() is the same base
        # as st_mtime, so startup-scanned and runtime entries sort together.
        self._access_times: dict[int, float] = {}
        self._total_bytes = 0
        self._scan_existing()

    def _path_for(self, row_index: int) -> Path:
        """Map a pool index to its on-disk JPEG path."""
        return self._dir / f"{row_index}.jpg"

    def _scan_existing(self) -> None:
        """Rebuild size/recency bookkeeping from files already on disk."""
        total = 0
        for p in self._dir.glob("*.jpg"):
            try:
                idx = int(p.stem)
                st = p.stat()  # one stat() call yields both size and mtime
            except (ValueError, OSError):
                continue
            total += st.st_size
            self._access_times[idx] = st.st_mtime
        self._total_bytes = total
        logger.info("Image cache: %d files, %.1f MB", len(self._access_times), total / 1e6)

    def get_path(self, row_index: int) -> Path | None:
        """Return the cached file path for *row_index*, or None on a miss."""
        p = self._path_for(row_index)
        if p.exists():
            with self._lock:
                self._access_times[row_index] = time.time()
            return p
        return None

    def put(self, row_index: int, image_bytes: bytes) -> Path:
        """Write *image_bytes* to disk, mark it most-recent, and evict if over budget."""
        p = self._path_for(row_index)
        with self._lock:
            # Re-writing an existing entry must not double-count its bytes.
            if row_index in self._access_times:
                try:
                    self._total_bytes -= p.stat().st_size
                except OSError:
                    pass
        p.write_bytes(image_bytes)
        with self._lock:
            self._access_times[row_index] = time.time()
            self._total_bytes += len(image_bytes)
            self._evict_if_needed()
        return p

    def _evict_if_needed(self) -> None:
        """Drop least-recently-used files until under the byte budget (lock held)."""
        while self._total_bytes > self._max_bytes and self._access_times:
            lru_idx = min(self._access_times, key=self._access_times.get)
            victim = self._path_for(lru_idx)
            try:
                size = victim.stat().st_size
                victim.unlink()
                self._total_bytes -= size
            except OSError:
                # File already gone; the bookkeeping entry is dropped below.
                pass
            del self._access_times[lru_idx]


# Module-level singleton
# NOTE(review): constructed at import time — this scans the cache directory
# (one stat per cached file) before the importing module continues.
image_cache = ImageCache()


# ---------------------------------------------------------------------------
# Streaming pool sampler
# ---------------------------------------------------------------------------

def sample_pool_streaming(
    pool_size: int,
    seed: int | None = None,
    prefetch_images: int = 100,
) -> tuple[list[int], dict[int, dict], int]:
    """Build the galaxy pool, caching a small batch of images before returning.

    Two streaming passes over the same seeded shuffle:
    pass 1 pulls metadata only (no image bytes downloaded); pass 2 pulls
    images with cast_column(decode=False) to avoid Pillow, caching
    ``prefetch_images`` of them synchronously and the rest in a daemon
    background thread.  Fix vs. original: the synchronous prefetch loop
    guarded against ``StopIteration``, and all counts are bounded by the
    number of rows the stream actually yielded (which may be < pool_size).

    Args:
        pool_size: requested number of galaxies.
        seed: shuffle seed; a random 32-bit value is drawn when None.
        prefetch_images: images cached before this function returns.

    Returns:
        ids: sequential ints 0..N-1 (N may be less than pool_size)
        metadata_map: {id -> row_dict (no image column)}
        seed: seed used
    """
    if seed is None:
        seed = random.randint(0, 2**32 - 1)

    logger.info("Streaming metadata for %d galaxies (seed=%d)...", pool_size, seed)

    # Pass 1: metadata only — fast, no images downloaded
    ids: list[int] = []
    metadata_map: dict[int, dict] = {}
    for i, row in enumerate(_make_dataset(seed, pool_size, with_images=False)):
        metadata_map[i] = {ID_COLUMN: row.get(ID_COLUMN)}
        ids.append(i)

    # The stream may be shorter than requested; every later loop uses this.
    actual_count = len(ids)
    logger.info("All %d galaxy IDs ready", actual_count)

    # Pass 2: images — same seed+buffer → same row order as pass 1
    def _extract_bytes(img_col) -> bytes | None:
        """Return raw JPEG bytes from an undecoded dict or a PIL fallback."""
        if isinstance(img_col, dict):
            return img_col.get("bytes")
        if img_col is not None:
            try:
                import io
                buf = io.BytesIO()
                img_col.save(buf, format="JPEG")
                return buf.getvalue()
            except Exception as e:
                logger.warning("PIL conversion failed: %s", e)
        return None

    img_it = _make_dataset(seed, pool_size, with_images=True)
    sync_count = min(prefetch_images, actual_count)

    cached = 0
    for i in range(sync_count):
        try:
            # Guard: the image stream can end before the metadata stream did.
            row = next(img_it)
        except StopIteration:
            break
        img_bytes = _extract_bytes(row.get(IMAGE_COLUMN))
        if img_bytes:
            image_cache.put(i, img_bytes)
        else:
            logger.warning("No image bytes for row %d", i)
        cached += 1

    logger.info("%d images cached — app ready, %d remaining in background",
                cached, actual_count - cached)

    if cached < actual_count:
        start = cached  # bind now: _bg may run long after this frame returns

        def _bg():
            for i in range(start, actual_count):
                try:
                    row = next(img_it)
                    img_bytes = _extract_bytes(row.get(IMAGE_COLUMN))
                    if img_bytes:
                        image_cache.put(i, img_bytes)
                except StopIteration:
                    break
                except Exception as e:
                    logger.warning("Background error at row %d: %s", i, e)
            logger.info("Background image caching complete")

        threading.Thread(target=_bg, daemon=True).start()

    return ids, metadata_map, seed