Spaces:
Sleeping
Sleeping
File size: 1,120 Bytes
61d3625 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
"""Dataset streaming utilities."""
from __future__ import annotations

from itertools import islice
from typing import Iterator

from datasets import IterableDataset, load_dataset

from .config import CONFIG
from .utils import set_seed
class WikiArtStream:
    """Wrapper around a streaming dataset with deterministic shuffle.

    The dataset name, split, streaming flag, seed, default sample size and
    shuffle buffer size are all read from ``CONFIG.dataset``. Iterating an
    instance yields at most ``sample_size`` records from the shuffled stream.
    """

    def __init__(self, sample_size: int | None = None) -> None:
        """Load the configured dataset and wrap it in a seeded shuffle.

        Args:
            sample_size: Maximum number of records yielded per iteration.
                ``None`` falls back to ``CONFIG.dataset.sample_size``; an
                explicit ``0`` is honoured rather than treated as unset.

        Raises:
            TypeError: If ``load_dataset`` does not return an
                ``IterableDataset`` (e.g. when ``CONFIG.dataset.streaming`` is
                false), because the buffered shuffle below needs one.
        """
        self.cfg = CONFIG.dataset
        # `is None` instead of `or`: an explicit 0 must not silently become
        # the configured default.
        self.sample_size = (
            self.cfg.sample_size if sample_size is None else sample_size
        )
        # Project seeding helper from .utils; presumably seeds the global
        # RNGs — confirm against its implementation.
        set_seed(self.cfg.seed)
        # NOTE(review): trust_remote_code=True allows the dataset repository to
        # run arbitrary Python while loading. Keep it only if this dataset
        # needs a loading script and its source is trusted.
        streaming_ds = load_dataset(
            self.cfg.name,
            split=self.cfg.split,
            streaming=self.cfg.streaming,
            trust_remote_code=True,
        )
        # Explicit check rather than `assert`, which is stripped under -O.
        if not isinstance(streaming_ds, IterableDataset):
            raise TypeError(
                "WikiArtStream requires a streaming dataset; load_dataset "
                f"returned {type(streaming_ds).__name__}. "
                "Check CONFIG.dataset.streaming."
            )
        # Buffered shuffle seeded from the config so the order is reproducible.
        self.dataset: IterableDataset = streaming_ds.shuffle(
            seed=self.cfg.seed,
            buffer_size=self.cfg.shuffle_buffer,
        )

    def __iter__(self) -> Iterator[dict]:
        """Yield up to ``sample_size`` records from the shuffled stream.

        ``islice`` stops after exactly ``sample_size`` items, so no extra
        record is pulled from the stream just to detect the limit.
        A ``sample_size`` of ``None`` yields the entire stream, and a
        negative value yields nothing.
        """
        limit = self.sample_size
        if limit is not None:
            # islice rejects negative stops; a negative size means "nothing".
            limit = max(limit, 0)
        yield from islice(self.dataset, limit)
|