multimodal-hw / src /data.py
AlekMan's picture
Upload 41 files
61d3625 verified
"""Dataset streaming utilities."""
from __future__ import annotations
from typing import Iterator
from datasets import IterableDataset, load_dataset
from .config import CONFIG
from .utils import set_seed
class WikiArtStream:
"""Wrapper around streaming dataset with deterministic shuffle."""
def __init__(self, sample_size: int | None = None) -> None:
self.cfg = CONFIG.dataset
self.sample_size = sample_size or self.cfg.sample_size
set_seed(self.cfg.seed)
streaming_ds = load_dataset(
self.cfg.name,
split=self.cfg.split,
streaming=self.cfg.streaming,
trust_remote_code=True,
)
assert isinstance(streaming_ds, IterableDataset)
self.dataset: IterableDataset = streaming_ds.shuffle(
seed=self.cfg.seed,
buffer_size=self.cfg.shuffle_buffer,
)
def __iter__(self) -> Iterator[dict]:
for idx, sample in enumerate(self.dataset):
if idx >= self.sample_size:
break
yield sample