# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from collections import OrderedDict

import torch
from torch.utils.data.dataloader import default_collate

from . import FairseqDataset


def _flatten(dico, prefix=None):
    """Flatten a nested dictionary."""
    new_dico = OrderedDict()
    if isinstance(dico, dict):
        prefix = prefix + "." if prefix is not None else ""
        for k, v in dico.items():
            if v is None:
                continue
            new_dico.update(_flatten(v, prefix + k))
    elif isinstance(dico, list):
        for i, v in enumerate(dico):
            new_dico.update(_flatten(v, prefix + ".[" + str(i) + "]"))
    else:
        new_dico = OrderedDict({prefix: dico})
    return new_dico
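

# Illustrative only (comments, so nothing runs at import time): nested dict
# keys are joined with "." and list positions become "[i]" segments, e.g.
#
#   _flatten({"a": {"b": 1}, "c": [2, 3]})
#   # -> OrderedDict([("a.b", 1), ("c.[0]", 2), ("c.[1]", 3)])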


def _unflatten(dico):
    """Unflatten a flattened dictionary into a nested dictionary."""
    new_dico = OrderedDict()
    for full_k, v in dico.items():
        full_k = full_k.split(".")
        node = new_dico
        for k in full_k[:-1]:
            if k.startswith("[") and k.endswith("]"):
                k = int(k[1:-1])
            if k not in node:
                node[k] = OrderedDict()
            node = node[k]
        node[full_k[-1]] = v
    return new_dico
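

# Illustrative only: _unflatten reverses _flatten for dict keys, but
# intermediate "[i]" segments come back as integer keys in an OrderedDict
# rather than as real lists, e.g.
#
#   _unflatten(OrderedDict([("a.b", 1), ("c.[0].d", 2)]))
#   # -> OrderedDict([("a", OrderedDict([("b", 1)])),
#   #                 ("c", OrderedDict([(0, OrderedDict([("d", 2)]))]))])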


class NestedDictionaryDataset(FairseqDataset):
    def __init__(self, defn, sizes=None):
        super().__init__()
        self.defn = _flatten(defn)
        self.sizes = [sizes] if not isinstance(sizes, (list, tuple)) else sizes

        first = None
        for v in self.defn.values():
            if not isinstance(v, (FairseqDataset, torch.utils.data.Dataset)):
                raise ValueError("Expected Dataset but found: {}".format(v.__class__))
            first = first or v
            if len(v) > 0:
                assert len(v) == len(first), "dataset lengths must match"

        self._len = len(first)

    def __getitem__(self, index):
        return OrderedDict((k, ds[index]) for k, ds in self.defn.items())

    def __len__(self):
        return self._len

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch suitable for forwarding with a Model
        """
        if len(samples) == 0:
            return {}
        sample = OrderedDict()
        for k, ds in self.defn.items():
            try:
                sample[k] = ds.collater([s[k] for s in samples])
            except NotImplementedError:
                sample[k] = default_collate([s[k] for s in samples])
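        # At this point each entry of `sample` maps a flattened key (e.g.
        # "net_input.src_tokens") to a collated leaf; _unflatten re-nests
        # them, e.g. {"net_input": {"src_tokens": <batch>}}.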
        return _unflatten(sample)

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return max(s[index] for s in self.sizes)

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        if len(self.sizes) == 1:
            return self.sizes[0][index]
        else:
            # Return a tuple rather than a one-shot generator so the result
            # matches the docstring and can be iterated more than once.
            return tuple(s[index] for s in self.sizes)

    @property
    def supports_prefetch(self):
        """Whether this dataset supports prefetching."""
        return any(ds.supports_prefetch for ds in self.defn.values())

    def prefetch(self, indices):
        """Prefetch the data required for this epoch."""
        for ds in self.defn.values():
            if getattr(ds, "supports_prefetch", False):
                ds.prefetch(indices)

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return all(ds.can_reuse_epoch_itr_across_epochs for ds in self.defn.values())

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        for ds in self.defn.values():
            ds.set_epoch(epoch)
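

# Minimal usage sketch (illustrative only; not part of the module).
# `_ToyDataset` below is a hypothetical stand-in defined just for this
# example. It inherits FairseqDataset's default collater, which raises
# NotImplementedError, so NestedDictionaryDataset falls back to torch's
# default_collate when batching its items:
#
#   class _ToyDataset(FairseqDataset):
#       def __init__(self, tensors):
#           self.tensors = tensors
#
#       def __getitem__(self, index):
#           return self.tensors[index]
#
#       def __len__(self):
#           return len(self.tensors)
#
#   src = _ToyDataset([torch.tensor([1, 2]), torch.tensor([3, 4])])
#   dataset = NestedDictionaryDataset(
#       {"net_input": {"src_tokens": src}},
#       sizes=[[2, 2]],  # one size array, with one entry per example
#   )
#   batch = dataset.collater([dataset[0], dataset[1]])
#   # batch["net_input"]["src_tokens"] -> tensor([[1, 2], [3, 4]])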