Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from . import BaseWrapperDataset | |
class ReplaceDataset(BaseWrapperDataset):
    """Replaces tokens found in the dataset by a specified replacement token.

    The replacement is written in place (``masked_fill_``) into the objects
    returned by the wrapped dataset; no copy of the item is made here.
    NOTE(review): this presumes the items are torch tensors, whose basic
    slices are views of the original storage -- confirm against the wrapped
    dataset.

    Every key of ``replace_map`` is matched against the values the item had
    *before* any replacement was applied, so the result does not depend on the
    iteration order of the map and replacements never cascade (with
    ``{1: 2, 2: 3}`` an original ``1`` becomes ``2``, not ``3``).

    Args:
        dataset (~torch.utils.data.Dataset): dataset to replace tokens in
        replace_map (Dict[int, int]): map of token to replace -> replacement token
        offsets (List[int]): do not replace tokens before (from left if pos,
            right if neg) this offset. should be as many as the number of
            objects returned by the underlying dataset __getitem__ method.
            Offsets and returned objects are paired with ``zip``, so pairing
            stops at the shorter of the two: surplus objects are returned
            unmodified and surplus offsets are ignored.
    """

    def __init__(self, dataset, replace_map, offsets):
        super().__init__(dataset)
        # Kept as an assert so the exception type seen by callers is unchanged.
        assert len(replace_map) > 0, "replace_map must not be empty"
        self.replace_map = replace_map
        self.offsets = offsets

    def __getitem__(self, index):
        """Return item ``index`` of the wrapped dataset with tokens replaced.

        A tuple item is processed element by element; any other item is
        treated as a single object. The object returned is the very item
        obtained from the wrapped dataset, modified in place.
        """
        item = self.dataset[index]
        srcs = item if isinstance(item, tuple) else [item]

        for offset, src in zip(self.offsets, srcs):
            # The protected region does not depend on the map entry, so the
            # slice is taken once per object instead of once per entry.
            # offset >= 0 skips the first `offset` positions; a negative
            # offset skips the last `-offset` positions.
            src_off = src[offset:] if offset >= 0 else src[:offset]

            # Compute all masks from the untouched values first. Filling while
            # iterating would let an already-replaced token match a later key.
            masks = [
                (src_off == old_token, new_token)
                for old_token, new_token in self.replace_map.items()
            ]
            for mask, new_token in masks:
                src_off.masked_fill_(mask, new_token)

        # `srcs` is either `item` itself or a one-element list wrapping it,
        # and all edits were in place, so `item` already carries the result.
        return item