File size: 1,120 Bytes
61d3625
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
"""Dataset streaming utilities."""
from __future__ import annotations

from itertools import islice
from typing import Iterator

from datasets import IterableDataset, load_dataset

from .config import CONFIG
from .utils import set_seed


class WikiArtStream:
    """Iterable wrapper around a seeded-shuffle streaming dataset.

    The dataset is loaded according to ``CONFIG.dataset`` (``name``,
    ``split``, ``streaming``) and shuffled with the configured ``seed`` and
    ``shuffle_buffer``. Iterating over an instance yields at most
    ``sample_size`` samples.
    """

    def __init__(self, sample_size: int | None = None) -> None:
        """Load the configured dataset and wrap it in a seeded shuffle.

        Args:
            sample_size: Maximum number of samples yielded per iteration.
                ``None`` falls back to ``CONFIG.dataset.sample_size``; an
                explicit ``0`` is honoured and yields nothing.

        Raises:
            TypeError: If ``load_dataset`` does not return an
                ``IterableDataset``.
        """
        self.cfg = CONFIG.dataset
        # Compare against None rather than relying on truthiness so that an
        # explicit ``sample_size=0`` is not silently replaced by the default.
        self.sample_size = (
            sample_size if sample_size is not None else self.cfg.sample_size
        )

        # Project helper called with the configured seed before loading.
        set_seed(self.cfg.seed)

        # SECURITY NOTE(review): trust_remote_code=True permits the dataset
        # repository's own loading code to run locally -- confirm that
        # ``cfg.name`` always refers to a trusted source.
        streaming_ds = load_dataset(
            self.cfg.name,
            split=self.cfg.split,
            streaming=self.cfg.streaming,
            trust_remote_code=True,
        )
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not isinstance(streaming_ds, IterableDataset):
            raise TypeError(
                "WikiArtStream requires an IterableDataset, got "
                f"{type(streaming_ds).__name__}"
            )
        self.dataset: IterableDataset = streaming_ds.shuffle(
            seed=self.cfg.seed,
            buffer_size=self.cfg.shuffle_buffer,
        )

    def __iter__(self) -> Iterator[dict]:
        """Yield up to ``sample_size`` samples from the shuffled dataset.

        Iteration stops as soon as the limit is reached, without pulling an
        extra sample from the underlying stream. A ``sample_size`` of
        ``None`` means no limit; a negative value yields nothing.
        """
        limit = self.sample_size
        if limit is not None:
            # islice rejects negative stops; treat them as "yield nothing".
            limit = max(limit, 0)
        yield from islice(self.dataset, limit)