import torch

from . import BaseWrapperDataset, data_utils


class AddTargetDataset(BaseWrapperDataset):
    """Wrap a dataset so each sample also carries a label (target) sequence.

    ``labels[i]`` is attached to ``dataset[i]``, optionally transformed by
    ``process_label``, and ``collater`` folds the labels into the batch as
    ``target``, ``target_lengths`` and ``ntokens``.
    """

    def __init__(
        self,
        dataset,
        labels,
        pad,
        eos,
        batch_targets,
        process_label=None,
        add_to_input=False,
    ):
        super().__init__(dataset)
        self.labels = labels
        self.batch_targets = batch_targets
        self.pad = pad
        self.eos = eos
        self.process_label = process_label
        self.add_to_input = add_to_input

    def get_label(self, index):
        return (
            self.labels[index]
            if self.process_label is None
            else self.process_label(self.labels[index])
        )

    def __getitem__(self, index):
        item = self.dataset[index]
        item["label"] = self.get_label(index)
        return item

    def size(self, index):
        # report (source size, label length) so fairseq can filter by size
        sz = self.dataset.size(index)
        own_sz = len(self.get_label(index))
        return (sz, own_sz)

    def collater(self, samples):
        collated = self.dataset.collater(samples)
        if len(collated) == 0:
            return collated
        # keep only the labels of samples that survived the base collater
        indices = set(collated["id"].tolist())
        target = [s["label"] for s in samples if s["id"] in indices]

        if self.batch_targets:
            # pad the label sequences into a single (batch x max_len) tensor
            collated["target_lengths"] = torch.LongTensor([len(t) for t in target])
            target = data_utils.collate_tokens(target, pad_idx=self.pad, left_pad=False)
            collated["ntokens"] = collated["target_lengths"].sum().item()
        else:
            collated["ntokens"] = sum([len(t) for t in target])

        collated["target"] = target

        if self.add_to_input:
            # append <eos> to the target and prepend it to prev_output_tokens
            # for teacher forcing (assumes batch_targets, i.e. a tensor target)
            eos = target.new_full((target.size(0), 1), self.eos)
            collated["target"] = torch.cat([target, eos], dim=-1).long()
            collated["net_input"]["prev_output_tokens"] = torch.cat(
                [eos, target], dim=-1
            ).long()
            collated["ntokens"] += target.size(0)
        return collated
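
if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of fairseq's API):
    # demonstrates how the wrapper pads labels into a batch. The toy base
    # dataset below is hypothetical and only provides the pieces this
    # wrapper's collater relies on ("id" plus a dict-style batch). Because
    # of the relative imports above, run it as a module, e.g.
    # `python -m fairseq.data.add_target_dataset`.
    class _ToyDataset:
        def __init__(self, feats):
            self.feats = feats

        def __len__(self):
            return len(self.feats)

        def __getitem__(self, index):
            return {"id": index, "source": self.feats[index]}

        def size(self, index):
            return self.feats[index].size(0)

        def collater(self, samples):
            return {
                "id": torch.LongTensor([s["id"] for s in samples]),
                "net_input": {
                    "source": torch.stack([s["source"] for s in samples])
                },
            }

    base = _ToyDataset([torch.randn(16) for _ in range(3)])
    labels = [[5, 6], [7], [8, 9, 10]]  # one token-id sequence per sample
    ds = AddTargetDataset(
        base,
        labels,
        pad=1,
        eos=2,
        batch_targets=True,
        process_label=lambda lbl: torch.LongTensor(lbl),
    )
    batch = ds.collater([ds[i] for i in range(len(ds))])
    print(batch["target"])          # 3 x 3 LongTensor, padded with pad=1
    print(batch["target_lengths"])  # tensor([2, 1, 3])
    print(batch["ntokens"])         # 6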