|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
|
from fairseq.data import data_utils |
|
|
from . import BaseWrapperDataset |
|
|
|
|
|
|
|
|
class PaddingMaskDataset(BaseWrapperDataset):
    """Wrap a dataset so that each item becomes a boolean padding mask.

    Every item of the wrapped dataset is replaced by an all-``False`` bool
    tensor with the same shape (and device) as the original item. When
    samples are collated, ``True`` is handed to
    ``data_utils.collate_tokens`` as its padding value, so positions added
    by padding are presumably ``True`` — confirm against ``collate_tokens``.

    Args:
        dataset: the dataset to wrap; its items must be tensors, since they
            are passed to ``torch.zeros_like``.
        left_pad: forwarded to ``collate_tokens`` as ``left_pad``.
        pad_length: forwarded to ``collate_tokens`` as ``pad_to_length``;
            ``None`` (the default) leaves the padded length up to
            ``collate_tokens``.
    """

    def __init__(self, dataset, left_pad, pad_length=None):
        super().__init__(dataset)
        self.left_pad = left_pad
        self.pad_length = pad_length

    def __getitem__(self, index):
        """Return an all-``False`` bool tensor shaped like ``self.dataset[index]``."""
        item = self.dataset[index]
        # Allocate the bool tensor directly rather than creating a zero tensor
        # of the item's dtype and then converting it with ``.bool()``.
        return torch.zeros_like(item, dtype=torch.bool)

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.dataset)

    def collater(self, samples):
        """Collate per-item masks into one batch via ``collate_tokens``.

        The second positional argument (``True``) is the padding value, so
        padded positions are filled with ``True``.
        """
        return data_utils.collate_tokens(
            samples, True, left_pad=self.left_pad, pad_to_length=self.pad_length
        )
|
|
|
|
|
|
|
|
class LeftPaddingMaskDataset(PaddingMaskDataset):
    """``PaddingMaskDataset`` preconfigured with ``left_pad=True``.

    ``pad_length`` is not passed, so the parent's default applies.
    """

    def __init__(self, dataset):
        super().__init__(dataset=dataset, left_pad=True)
|
|
|
|
|
|
|
|
class RightPaddingMaskDataset(PaddingMaskDataset):
    """``PaddingMaskDataset`` preconfigured with ``left_pad=False``.

    ``pad_length`` is not passed, so the parent's default applies.
    """

    def __init__(self, dataset):
        super().__init__(dataset=dataset, left_pad=False)
|
|
|