Spaces:
Running
Running
| """Dataset streaming utilities.""" | |
| from __future__ import annotations | |
| from typing import Iterator | |
| from datasets import IterableDataset, load_dataset | |
| from .config import CONFIG | |
| from .utils import set_seed | |
class WikiArtStream:
    """Iterable wrapper around a shuffled streaming dataset.

    The dataset described by ``CONFIG.dataset`` is loaded with
    ``load_dataset`` and shuffled using the configured seed and shuffle
    buffer size. Iterating the wrapper yields at most ``sample_size``
    samples.
    """

    def __init__(self, sample_size: int | None = None) -> None:
        """Load and shuffle the configured dataset.

        Args:
            sample_size: Maximum number of samples yielded per iteration.
                ``None`` falls back to ``CONFIG.dataset.sample_size``; an
                explicit ``0`` is honoured and yields nothing.

        Raises:
            TypeError: If ``load_dataset`` does not return an
                ``IterableDataset``.
        """
        self.cfg = CONFIG.dataset
        # Compare against None rather than relying on truthiness, so that an
        # explicit 0 is not silently replaced by the configured default.
        self.sample_size = (
            self.cfg.sample_size if sample_size is None else sample_size
        )
        set_seed(self.cfg.seed)
        # NOTE(review): trust_remote_code=True lets the dataset repository run
        # its own loading script on this machine. Confirm the configured
        # dataset is trusted, or pin a revision.
        streaming_ds = load_dataset(
            self.cfg.name,
            split=self.cfg.split,
            streaming=self.cfg.streaming,
            trust_remote_code=True,
        )
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would let a non-iterable dataset slip through.
        if not isinstance(streaming_ds, IterableDataset):
            raise TypeError(
                "Expected load_dataset to return an IterableDataset, got "
                f"{type(streaming_ds).__name__}"
            )
        self.dataset: IterableDataset = streaming_ds.shuffle(
            seed=self.cfg.seed,
            buffer_size=self.cfg.shuffle_buffer,
        )

    def __iter__(self) -> Iterator[dict]:
        """Yield at most ``self.sample_size`` samples from the dataset.

        A non-positive ``sample_size`` yields nothing.
        """
        from itertools import islice

        # islice stops after exactly N items; an enumerate/break loop would
        # pull one extra sample from the stream before noticing the limit.
        # Negative limits are clamped to 0 because islice rejects them.
        yield from islice(self.dataset, max(self.sample_size, 0))