| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| """Splits related API.""" |
|
|
| import abc |
| import collections |
| import copy |
| import dataclasses |
| import re |
| from dataclasses import dataclass |
| from typing import Optional, Union |
|
|
| from .arrow_reader import FileInstructions, make_file_instructions |
| from .naming import _split_re |
| from .utils.py_utils import NonMutableDict, asdict |
|
|
|
|
@dataclass
class SplitInfo:
    """Size and naming metadata for a single split of a dataset.

    Every field has a default, so `SplitInfo()` is a valid (empty) record.
    """

    # NOTE(review): the `include_in_asdict_even_if_is_default` metadata flag is
    # presumably read by `.utils.py_utils.asdict` so that these fields are always
    # serialized, even when they hold their default value -- confirm there.
    name: str = dataclasses.field(
        default="",
        metadata={"include_in_asdict_even_if_is_default": True},
    )
    num_bytes: int = dataclasses.field(
        default=0,
        metadata={"include_in_asdict_even_if_is_default": True},
    )
    num_examples: int = dataclasses.field(
        default=0,
        metadata={"include_in_asdict_even_if_is_default": True},
    )
    shard_lengths: Optional[list[int]] = None

    # Name of the dataset this split belongs to; `SplitDict.add()` overwrites it
    # with the owning dict's `dataset_name`.
    dataset_name: Optional[str] = dataclasses.field(
        default=None,
        metadata={"include_in_asdict_even_if_is_default": True},
    )

    @property
    def file_instructions(self):
        """Returns the list of dict(filename, take, skip) for the whole split.

        The instructions are computed by `make_file_instructions` using this
        split alone and its own name as the read instruction.
        """
        return make_file_instructions(
            name=self.dataset_name,
            split_infos=[self],
            instruction=str(self.name),
        ).file_instructions
|
|
|
|
@dataclass
class SubSplitInfo:
    """Wrapper around a sub split info.

    This class exposes info on the subsplit:

    ```
    ds, info = datasets.load_dataset(..., split='train[75%:]', with_info=True)
    info.splits['train[75%:]'].num_examples
    ```
    """

    # Resolved read instructions; both properties below simply forward to it.
    instructions: FileInstructions

    @property
    def num_examples(self):
        """Returns the number of examples in the subsplit."""
        resolved = self.instructions
        return resolved.num_examples

    @property
    def file_instructions(self):
        """Returns the list of dict(filename, take, skip)."""
        resolved = self.instructions
        return resolved.file_instructions
|
|
|
|
class SplitBase(metaclass=abc.ABCMeta):
    """Abstract base class for Split compositionality.

    See the
    [guide on splits](../loading#slice-splits)
    for more information.

    There are three parts to the composition:
        1) The splits are composed (defined, merged, split,...) together before
             calling the `.as_dataset()` function. This is done with the `__add__`,
             `__getitem__`, which return a tree of `SplitBase` (whose leaf
             are the `NamedSplit` objects)

        ```
            split = datasets.Split.TRAIN + datasets.Split.TEST.subsplit(datasets.percent[:50])
        ```

        2) The `SplitBase` is forwarded to the `.as_dataset()` function
             to be resolved into actual read instruction. This is done by the
             `.get_read_instruction()` method which takes the real dataset splits
             (name, number of shards,...) and parse the tree to return a
             `SplitReadInstruction()` object

        ```
            read_instruction = split.get_read_instruction(self.info.splits)
        ```

        3) The `SplitReadInstruction` is then used in the `tf.data.Dataset` pipeline
             to define which files to read and how to skip examples within file.

    """

    @abc.abstractmethod
    def get_read_instruction(self, split_dict):
        """Parse the descriptor tree and compile all read instructions together.

        Args:
            split_dict: `dict`, The `dict[split_name, SplitInfo]` of the dataset

        Returns:
            split_read_instruction: `SplitReadInstruction`
        """
        raise NotImplementedError("Abstract method")

    def __eq__(self, other):
        """Equality: datasets.Split.TRAIN == 'train'.

        A composed split never equals a named split or a string; comparing two
        composed splits is not supported and raises `NotImplementedError`.
        """
        if isinstance(other, (NamedSplit, str)):
            return False
        raise NotImplementedError("Equality is not implemented between merged/sub splits.")

    def __ne__(self, other):
        """InEquality: datasets.Split.TRAIN != 'test'."""
        return not self.__eq__(other)

    def __add__(self, other):
        """Merging: datasets.Split.TRAIN + datasets.Split.TEST."""
        return _SplitMerged(self, other)

    def subsplit(self, arg=None, k=None, percent=None, weighted=None):
        """Divides this split into subsplits.

        There are 3 ways to define subsplits, which correspond to the 3
        arguments `k` (get `k` even subsplits), `percent` (get a slice of the
        dataset with `datasets.percent`), and `weighted` (get subsplits with proportions
        specified by `weighted`).

        Example::

        ```
        # 50% train, 50% test
        train, test = split.subsplit(k=2)
        # 50% train, 25% test, 25% validation
        train, test, validation = split.subsplit(weighted=[2, 1, 1])
        # Extract last 20%
        subsplit = split.subsplit(datasets.percent[-20:])
        ```

        Warning: k and weighted will be converted into percent, which means that
        values below the percent will be rounded up or down. The final split may be
        bigger to deal with remainders. For instance:

        ```
        train, test, valid = split.subsplit(k=3)  # 33%, 33%, 34%
        s1, s2, s3, s4 = split.subsplit(weighted=[2, 2, 1, 1])  # 33%, 33%, 16%, 18%
        ```

        Args:
            arg: If no kwargs are given, `arg` will be interpreted as one of
                `k`, `percent`, or `weighted` depending on the type.
                For example:
                ```
                split.subsplit(10)  # Equivalent to split.subsplit(k=10)
                split.subsplit(datasets.percent[:-20])  # percent=datasets.percent[:-20]
                split.subsplit([1, 1, 2])  # weighted=[1, 1, 2]
                ```
            k: `int` If set, subdivide the split into `k` equal parts.
            percent: `datasets.percent slice`, return a single subsplit corresponding to
                a slice of the original split. For example:
                `split.subsplit(datasets.percent[-20:])  # Last 20% of the dataset`.
            weighted: `list[int]`, return a list of subsplits whose proportions match
                the normalized sum of the list. For example:
                `split.subsplit(weighted=[1, 1, 2])  # 25%, 25%, 50%`.

        Returns:
            A subsplit or list of subsplits extracted from this split object.

        Raises:
            ValueError: If zero or several arguments are set, if `arg` has an
                unsupported type, if `k` is not in `(0, 100]`, or if `weighted`
                contains a negative weight or sums to zero.
        """
        # Exactly one of the four arguments may carry a (truthy) value.
        if sum(bool(x) for x in (arg, k, percent, weighted)) != 1:
            raise ValueError("Only one argument of subsplit should be set.")

        # The positional shortcut is dispatched on its type.
        if isinstance(arg, int):
            k = arg
        elif isinstance(arg, slice):
            percent = arg
        elif isinstance(arg, list):
            weighted = arg

        if not (k or percent or weighted):
            raise ValueError(
                f"Invalid split argument {arg}. Only list, slice and int supported. "
                "One of k, weighted or percent should be set to a non empty value."
            )

        def assert_slices_coverage(slices):
            # Internal sanity check: laid end to end, the slices must enumerate
            # the percents 0..99 exactly once and in order.
            assert sum((list(range(*s.indices(100))) for s in slices), []) == list(range(100))

        if k:
            if not 0 < k <= 100:
                raise ValueError(f"Subsplit k should be between 0 and 100, got {k}")
            shift = 100 // k
            slices = [slice(i * shift, (i + 1) * shift) for i in range(k)]
            # The last slice absorbs the rounding remainder so that 100% is covered.
            slices[-1] = slice(slices[-1].start, 100)
            assert_slices_coverage(slices)
            return tuple(_SubSplit(self, s) for s in slices)
        elif percent:
            return _SubSplit(self, percent)
        elif weighted:
            # Reject weights that cannot be normalized into percents: a negative
            # weight would produce out-of-range slices that still pass the coverage
            # check, and an all-zero list would divide by zero below.
            if any(w < 0 for w in weighted):
                raise ValueError(f"Subsplit weights should be non-negative, got {weighted}")
            total = sum(weighted)
            if total <= 0:
                raise ValueError(f"Subsplit weights should sum to a positive value, got {weighted}")
            # Normalize the weights to (floored) integer percents.
            weighted = [100 * x // total for x in weighted]
            # Turn the percents into consecutive [start, stop) slices.
            start = 0
            stop = 0
            slices = []
            for v in weighted:
                stop += v
                slices.append(slice(start, stop))
                start = stop
            # The last slice absorbs the rounding remainder so that 100% is covered.
            slices[-1] = slice(slices[-1].start, 100)
            assert_slices_coverage(slices)
            return tuple(_SubSplit(self, s) for s in slices)
        else:
            # Should not be reachable: the checks above guarantee one branch is taken.
            raise ValueError("Could not determine the split")
|
|
|
|
| |
| |
| |
| |
| |
| |
class PercentSliceMeta(type):
    """Metaclass that makes `SomeClass[a:b]` evaluate to the plain `slice(a, b)`."""

    def __getitem__(cls, slice_value):
        """Return `slice_value` unchanged, rejecting anything that is not a slice.

        Raises:
            ValueError: If the subscript is not a `slice` object.
        """
        if isinstance(slice_value, slice):
            return slice_value
        raise ValueError(f"datasets.percent should only be called with slice, not {slice_value}")
|
|
|
|
class PercentSlice(metaclass=PercentSliceMeta):
    """Syntactic sugar for defining slice subsplits: `datasets.percent[75:-5]`.

    See the
    [guide on splits](../loading#slice-splits)
    for more information.
    """


# Lower-case alias so that slices can be spelled `percent[...]`.
percent = PercentSlice
|
|
|
|
class _SplitMerged(SplitBase):
    """Represent two split descriptors merged together."""

    def __init__(self, split1, split2):
        self._split1, self._split2 = split1, split2

    def get_read_instruction(self, split_dict):
        """Resolve both operands (left first) and add their read instructions."""
        left = self._split1.get_read_instruction(split_dict)
        right = self._split2.get_read_instruction(split_dict)
        return left + right

    def __repr__(self):
        return f"({self._split1!r} + {self._split2!r})"
|
|
|
|
class _SubSplit(SplitBase):
    """Represent a sub split of a split descriptor."""

    def __init__(self, split, slice_value):
        self._split, self._slice_value = split, slice_value

    def get_read_instruction(self, split_dict):
        """Resolve the wrapped split, then index the result with the stored slice."""
        resolved = self._split.get_read_instruction(split_dict)
        return resolved[self._slice_value]

    def __repr__(self):
        bounds = self._slice_value
        # Missing bounds are rendered as empty, as in Python slice syntax.
        start = "" if bounds.start is None else bounds.start
        stop = "" if bounds.stop is None else bounds.stop
        slice_str = f"{start}:{stop}"
        if bounds.step is not None:
            slice_str += f":{bounds.step}"
        return f"{self._split!r}(datasets.percent[{slice_str}])"
|
|
|
|
class NamedSplit(SplitBase):
    """Descriptor corresponding to a named split (train, test, ...).

    Example:
        Each descriptor can be composed with other using addition or slice:

            ```py
            split = datasets.Split.TRAIN.subsplit(datasets.percent[0:25]) + datasets.Split.TEST
            ```

        The resulting split will correspond to 25% of the train split merged with
        100% of the test split.

        A split cannot be added twice, so the following will fail:

            ```py
            split = (
                    datasets.Split.TRAIN.subsplit(datasets.percent[:25]) +
                    datasets.Split.TRAIN.subsplit(datasets.percent[75:])
            )  # Error
            split = datasets.Split.TEST + datasets.Split.ALL  # Error
            ```

        The slices can be applied only one time. So the following are valid:

            ```py
            split = (
                    datasets.Split.TRAIN.subsplit(datasets.percent[:25]) +
                    datasets.Split.TEST.subsplit(datasets.percent[:50])
            )
            split = (datasets.Split.TRAIN + datasets.Split.TEST).subsplit(datasets.percent[:50])
            ```

        But this is not valid:

            ```py
            train = datasets.Split.TRAIN
            test = datasets.Split.TEST
            split = train.subsplit(datasets.percent[:25]).subsplit(datasets.percent[:25])
            split = (train.subsplit(datasets.percent[:25]) + test).subsplit(datasets.percent[:50])
            ```
    """

    def __init__(self, name):
        """Store `name` after validating every split name it mentions.

        Args:
            name: `str`, either a bare split name or a "+"-joined instruction whose
                parts may carry a "[...]" suffix; only the text before "[" in each
                part is checked against `_split_re`.

        Raises:
            ValueError: If one of the split names does not match `_split_re`.
        """
        self._name = name
        split_names_from_instruction = [split_instruction.split("[")[0] for split_instruction in name.split("+")]
        for split_name in split_names_from_instruction:
            if not re.match(_split_re, split_name):
                raise ValueError(f"Split name should match '{_split_re}' but got '{split_name}'.")

    def __str__(self):
        return self._name

    def __repr__(self):
        return f"NamedSplit({self._name!r})"

    def __eq__(self, other):
        """Equality: datasets.Split.TRAIN == 'train'."""
        if isinstance(other, NamedSplit):
            return self._name == other._name
        if isinstance(other, str):
            return self._name == other
        # Composed splits and unrelated objects never equal a named split.
        return False

    def __lt__(self, other):
        """Order by name; like `__eq__`, plain strings are accepted as operands."""
        if isinstance(other, NamedSplit):
            return self._name < other._name
        if isinstance(other, str):
            return self._name < other
        # Let Python try the reflected operation or raise TypeError.
        return NotImplemented

    def __hash__(self):
        # Hash like the bare name so a NamedSplit and its string are the same dict key.
        return hash(self._name)

    def get_read_instruction(self, split_dict):
        """Look the split up by name in `split_dict` and wrap it in a `SplitReadInstruction`."""
        return SplitReadInstruction(split_dict[self._name])
|
|
|
|
class NamedSplitAll(NamedSplit):
    """Split corresponding to the union of all defined dataset splits."""

    def __init__(self):
        super().__init__("all")

    def __repr__(self):
        return "NamedSplitAll()"

    def get_read_instruction(self, split_dict):
        """Merge the read instructions of every split stored in `split_dict`."""
        merged = SplitReadInstruction()
        for split_info in split_dict.values():
            merged = merged + SplitReadInstruction(split_info)
        return merged
|
|
|
|
class Split:
    """`Enum` for dataset splits.

    Datasets are typically split into different subsets to be used at various
    stages of training and evaluation.

    - `TRAIN`: the training data.
    - `VALIDATION`: the validation data. If present, this is typically used as
      evaluation data while iterating on a model (e.g. changing hyperparameters,
      model architecture, etc.).
    - `TEST`: the testing data. This is the data to report metrics on. Typically
      you do not want to use this during model iteration as you may overfit to it.
    - `ALL`: the union of all defined dataset splits.

    All splits, including compositions inherit from `datasets.SplitBase`.

    See the [guide](../load_hub#splits) on splits for more information.

    Example:

    ```py
    >>> datasets.SplitGenerator(
    ...     name=datasets.Split.TRAIN,
    ...     gen_kwargs={"split_key": "train", "files": dl_manager.download_and_extract(url)},
    ... ),
    ... datasets.SplitGenerator(
    ...     name=datasets.Split.VALIDATION,
    ...     gen_kwargs={"split_key": "validation", "files": dl_manager.download_and_extract(url)},
    ... ),
    ... datasets.SplitGenerator(
    ...     name=datasets.Split.TEST,
    ...     gen_kwargs={"split_key": "test", "files": dl_manager.download_and_extract(url)},
    ... )
    ```
    """

    # Predefined split descriptors, created once at class-definition time.
    TRAIN = NamedSplit("train")
    TEST = NamedSplit("test")
    VALIDATION = NamedSplit("validation")
    ALL = NamedSplitAll()

    def __new__(cls, name):
        """Create a custom split with datasets.Split('custom_name').

        The call never produces a `Split` instance: it returns a `NamedSplitAll`
        for the name "all" and a `NamedSplit` for any other name.
        """
        if name == "all":
            return NamedSplitAll()
        return NamedSplit(name)
|
|
|
|
| |
# Pairs a split's info with the slice applied to it (`None` when the split is unsliced).
SlicedSplitInfo = collections.namedtuple("SlicedSplitInfo", ("split_info", "slice_value"))
|
|
|
|
class SplitReadInstruction:
    """Object containing the reading instruction for the dataset.

    Similarly to `SplitDescriptor` nodes, this object can be composed with itself,
    but the resolution happens instantaneously, instead of keeping track of the
    tree, such as all instructions are compiled and flattened in a single
    SplitReadInstruction object containing the list of files and slice to use.

    Once resolved, the instructions can be accessed with:

    ```
    read_instructions.get_list_sliced_split_info()  # List of splits to use
    ```

    """

    def __init__(self, split_info=None):
        # NOTE(review): `NonMutableDict` presumably refuses to overwrite an existing
        # key and reports it with `error_msg` -- confirm in `.utils.py_utils`.
        self._splits = NonMutableDict(error_msg="Overlap between splits. Split {key} has been added with itself.")

        if not split_info:
            return
        # A freshly wrapped split starts out unsliced.
        self.add(SlicedSplitInfo(split_info=split_info, slice_value=None))

    def add(self, sliced_split):
        """Add a SlicedSplitInfo to the read instructions, keyed by its split name."""
        split_name = sliced_split.split_info.name
        self._splits[split_name] = sliced_split

    def __add__(self, other):
        """Merging split together."""
        # Build a new object so that neither operand is modified.
        merged = SplitReadInstruction()
        for source in (self, other):
            merged._splits.update(source._splits)
        return merged

    def __getitem__(self, slice_value):
        """Sub-splits.

        Returns a new instruction in which every split carries `slice_value`.

        Raises:
            ValueError: If one of the splits has already been sliced.
        """
        sliced = SplitReadInstruction()
        for entry in self._splits.values():
            if entry.slice_value is not None:
                raise ValueError(f"Trying to slice Split {entry.split_info.name} which has already been sliced")
            sliced.add(SlicedSplitInfo(split_info=entry.split_info, slice_value=slice_value))
        return sliced

    def get_list_sliced_split_info(self):
        """Return the stored `SlicedSplitInfo` entries as a new list."""
        return [*self._splits.values()]
|
|
|
|
class SplitDict(dict):
    """Split info object: a `dict` of split name -> `SplitInfo` tied to a dataset name."""

    def __init__(self, *args, dataset_name=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.dataset_name = dataset_name

    def __getitem__(self, key: Union[SplitBase, str]):
        """Return the stored `SplitInfo`, or a `SubSplitInfo` for a read instruction.

        When `str(key)` is a stored split name the stored value is returned.
        Otherwise `key` is handed to `make_file_instructions` as an instruction
        over all stored splits and the result is wrapped in a `SubSplitInfo`.
        """
        # 1st case: the key names a stored split.
        if str(key) in self:
            return super().__getitem__(str(key))
        # 2nd case: the key is treated as a read instruction.
        else:
            instructions = make_file_instructions(
                name=self.dataset_name,
                split_infos=self.values(),
                instruction=key,
            )
            return SubSplitInfo(instructions)

    def __setitem__(self, key: Union[SplitBase, str], value: SplitInfo):
        """Store `value` under `key`, which must equal `value.name`.

        Raises:
            ValueError: If `key` and `value.name` differ.
        """
        if key != value.name:
            raise ValueError(f"Cannot add elem. (key mismatch: '{key}' != '{value.name}')")
        super().__setitem__(key, value)

    def add(self, split_info: SplitInfo):
        """Add the split info.

        The split's `dataset_name` is overwritten with this dict's `dataset_name`.

        Raises:
            ValueError: If a split with the same name is already present.
        """
        if split_info.name in self:
            raise ValueError(f"Split {split_info.name} already present")
        split_info.dataset_name = self.dataset_name
        super().__setitem__(split_info.name, split_info)

    @property
    def total_num_examples(self):
        """Return the total number of examples."""
        return sum(s.num_examples for s in self.values())

    @classmethod
    def from_split_dict(cls, split_infos: Union[list, dict], dataset_name: Optional[str] = None):
        """Returns a new SplitDict initialized from a Dict or List of `split_infos`.

        Args:
            split_infos: A list, or a dict whose values are used, of `SplitInfo`
                objects and/or plain dicts of `SplitInfo` fields.
            dataset_name: Name for the new `SplitDict`. When `None`, it is taken
                from the `dataset_name` of the first element, if any.
        """
        if isinstance(split_infos, dict):
            split_infos = list(split_infos.values())

        if dataset_name is None and split_infos:
            first = split_infos[0]
            # The first element may be a plain dict or a `SplitInfo` object (e.g. when
            # called from `copy()`); `.get` only exists on the former.
            if isinstance(first, dict):
                dataset_name = first.get("dataset_name")
            else:
                dataset_name = getattr(first, "dataset_name", None)

        split_dict = cls(dataset_name=dataset_name)

        for split_info in split_infos:
            if isinstance(split_info, dict):
                split_info = SplitInfo(**split_info)
            split_dict.add(split_info)

        return split_dict

    def to_split_dict(self):
        """Returns a list of deep-copied `SplitInfo`, each named after its key."""
        out = []
        for split_name, split_info in self.items():
            # Copy so that renaming does not touch the stored object.
            split_info = copy.deepcopy(split_info)
            split_info.name = split_name
            out.append(split_info)
        return out

    def copy(self):
        """Return a new `SplitDict` holding deep copies of the stored splits."""
        return SplitDict.from_split_dict(self.to_split_dict(), self.dataset_name)

    def _to_yaml_list(self) -> list:
        """Return the splits as a list of dicts without `shard_lengths` and `dataset_name`."""
        out = [asdict(s) for s in self.to_split_dict()]
        for split_info_dict in out:
            split_info_dict.pop("shard_lengths", None)
            split_info_dict.pop("dataset_name", None)
        return out

    @classmethod
    def _from_yaml_list(cls, yaml_data: list) -> "SplitDict":
        """Build a `SplitDict` from the list form produced by `_to_yaml_list`."""
        return cls.from_split_dict(yaml_data)
|
|
|
|
@dataclass
class SplitGenerator:
    """Defines the split information for the generator.

    This should be used as returned value of
    `GeneratorBasedBuilder._split_generators`.
    See `GeneratorBasedBuilder._split_generators` for more info and example
    of usage.

    Args:
        name (`str`):
            Name of the `Split` for which the generator will
            create the examples.
        **gen_kwargs (additional keyword arguments):
            Keyword arguments to forward to the `DatasetBuilder._generate_examples` method
            of the builder.

    Example:

    ```py
    >>> datasets.SplitGenerator(
    ...     name=datasets.Split.TRAIN,
    ...     gen_kwargs={"split_key": "train", "files": dl_manager.download_and_extract(url)},
    ... )
    ```
    """

    name: str
    gen_kwargs: dict = dataclasses.field(default_factory=dict)
    # Not an `__init__` argument: filled in by `__post_init__`.
    split_info: SplitInfo = dataclasses.field(init=False)

    def __post_init__(self):
        """Normalize `name` to `str`, validate it, and create the matching `SplitInfo`."""
        split_name = str(self.name)
        self.name = split_name
        # Constructed only for its validation: raises ValueError on an invalid name.
        NamedSplit(split_name)
        self.split_info = SplitInfo(name=split_name)
|
|