| | |
| | |
| | |
| | |
| |
|
| | from fairseq.data import data_utils |
| |
|
| | from . import BaseWrapperDataset |
| |
|
| |
|
class PadDataset(BaseWrapperDataset):
    """Dataset wrapper whose collater delegates to ``data_utils.collate_tokens``.

    Args:
        dataset: the dataset to wrap; passed straight to the
            ``BaseWrapperDataset`` constructor.
        pad_idx: value forwarded to ``collate_tokens`` as its second
            positional argument (the padding index).
        left_pad: value forwarded to ``collate_tokens`` as the ``left_pad``
            keyword argument.
    """

    def __init__(self, dataset, pad_idx, left_pad):
        super().__init__(dataset)
        # Both settings are kept on the instance so ``collater`` can reuse them.
        self.pad_idx, self.left_pad = pad_idx, left_pad

    def collater(self, samples):
        """Collate ``samples`` using the stored ``pad_idx`` and ``left_pad``.

        Returns whatever ``data_utils.collate_tokens`` returns for these
        arguments.
        """
        pad_idx, left_pad = self.pad_idx, self.left_pad
        return data_utils.collate_tokens(samples, pad_idx, left_pad=left_pad)
| |
|
| |
|
class LeftPadDataset(PadDataset):
    """``PadDataset`` with ``left_pad`` fixed to ``True``.

    Args:
        dataset: the dataset to wrap.
        pad_idx: padding index handed on to ``PadDataset``.
    """

    def __init__(self, dataset, pad_idx):
        # Third positional argument of PadDataset.__init__ is ``left_pad``.
        super().__init__(dataset, pad_idx, True)
| |
|
| |
|
class RightPadDataset(PadDataset):
    """``PadDataset`` with ``left_pad`` fixed to ``False``.

    Args:
        dataset: the dataset to wrap.
        pad_idx: padding index handed on to ``PadDataset``.
    """

    def __init__(self, dataset, pad_idx):
        # Third positional argument of PadDataset.__init__ is ``left_pad``.
        super().__init__(dataset, pad_idx, False)
| |
|