Spaces:
Sleeping
Sleeping
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import torch | |
| from fairseq.data import data_utils | |
| from . import BaseWrapperDataset | |
class PaddingMaskDataset(BaseWrapperDataset):
    """Wrap a dataset and serve an all-zero boolean mask for every item.

    Indexing returns a boolean tensor of zeros with the same shape as the
    corresponding item of the wrapped dataset. Batches are assembled by
    ``data_utils.collate_tokens``.

    Args:
        dataset: the wrapped dataset; its items are passed to
            ``torch.zeros_like``, so they are expected to be tensors.
        left_pad: forwarded to ``collate_tokens`` as ``left_pad``.
        pad_length: forwarded to ``collate_tokens`` as ``pad_to_length``
            (default: None).
    """

    def __init__(self, dataset, left_pad, pad_length=None):
        super().__init__(dataset)
        self.pad_length = pad_length
        self.left_pad = left_pad

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return a boolean all-zeros tensor shaped like item ``index``."""
        source = self.dataset[index]
        return torch.zeros_like(source, dtype=torch.bool)

    def collater(self, samples):
        """Collate a list of per-item masks with ``collate_tokens``.

        ``True`` is passed as the second positional argument.
        NOTE(review): presumably this is the fill value for padded
        positions, so padding ends up marked ``True`` in the batch --
        confirm against ``data_utils.collate_tokens``.
        """
        collate_kwargs = {
            "left_pad": self.left_pad,
            "pad_to_length": self.pad_length,
        }
        return data_utils.collate_tokens(samples, True, **collate_kwargs)
class LeftPaddingMaskDataset(PaddingMaskDataset):
    """``PaddingMaskDataset`` with ``left_pad`` fixed to ``True``.

    Args:
        dataset: the dataset to wrap.
    """

    def __init__(self, dataset):
        # Positional form of ``left_pad=True``; pad_length keeps its default.
        super().__init__(dataset, True)
class RightPaddingMaskDataset(PaddingMaskDataset):
    """``PaddingMaskDataset`` with ``left_pad`` fixed to ``False``.

    Args:
        dataset: the dataset to wrap.
    """

    def __init__(self, dataset):
        # Positional form of ``left_pad=False``; pad_length keeps its default.
        super().__init__(dataset, False)