| |
| |
| |
| |
|
|
| from fairseq.data import data_utils |
|
|
| from . import BaseWrapperDataset |
|
|
|
|
class PadDataset(BaseWrapperDataset):
    """Wrapper dataset whose ``collater`` delegates to ``data_utils.collate_tokens``.

    Args:
        dataset: the dataset being wrapped; handed to the base-class constructor.
        pad_idx: value passed to ``collate_tokens`` as its second positional
            argument (presumably the padding symbol index -- confirm against
            ``collate_tokens``).
        left_pad: forwarded to ``collate_tokens`` as ``left_pad``.
        pad_length: forwarded to ``collate_tokens`` as ``pad_to_length``;
            defaults to ``None``.
    """

    def __init__(self, dataset, pad_idx, left_pad, pad_length=None):
        super().__init__(dataset)
        # Stored verbatim; they are only read back in ``collater``.
        self.pad_length = pad_length
        self.left_pad = left_pad
        self.pad_idx = pad_idx

    def collater(self, samples):
        """Collate ``samples`` using the padding settings stored on this instance."""
        collate_options = {
            "left_pad": self.left_pad,
            "pad_to_length": self.pad_length,
        }
        return data_utils.collate_tokens(samples, self.pad_idx, **collate_options)
|
|
|
|
class LeftPadDataset(PadDataset):
    """``PadDataset`` variant that always constructs its parent with ``left_pad=True``.

    Args:
        dataset: the dataset to wrap; forwarded to ``PadDataset``.
        pad_idx: forwarded to ``PadDataset`` as ``pad_idx``.
        pad_length: optional value forwarded to ``PadDataset`` as
            ``pad_length``. Defaults to ``None``, which reproduces the
            behaviour of the earlier two-argument constructor.
    """

    def __init__(self, dataset, pad_idx, pad_length=None):
        # Expose the parent's ``pad_length`` so callers of this convenience
        # wrapper are not forced to fall back to ``PadDataset`` directly.
        super().__init__(dataset, pad_idx, left_pad=True, pad_length=pad_length)
|
|
|
|
class RightPadDataset(PadDataset):
    """``PadDataset`` variant that always constructs its parent with ``left_pad=False``.

    Args:
        dataset: the dataset to wrap; forwarded to ``PadDataset``.
        pad_idx: forwarded to ``PadDataset`` as ``pad_idx``.
        pad_length: optional value forwarded to ``PadDataset`` as
            ``pad_length``. Defaults to ``None``, which reproduces the
            behaviour of the earlier two-argument constructor.
    """

    def __init__(self, dataset, pad_idx, pad_length=None):
        # Expose the parent's ``pad_length`` so callers of this convenience
        # wrapper are not forced to fall back to ``PadDataset`` directly.
        super().__init__(dataset, pad_idx, left_pad=False, pad_length=pad_length)
|
|