import logging
import os
import contextlib

import numpy as np
import torch

from fairseq.data import FairseqDataset, data_utils


logger = logging.getLogger(__name__)


class ExtractedFeaturesDataset(FairseqDataset):
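    """Dataset of pre-extracted speech features for one split.

    Expects `<path>/<split>.npy` (features for all utterances concatenated
    along the time axis, loaded memory-mapped) plus `<path>/<split>.lengths`,
    and optionally per-utterance labels in `<path>/<split>.<labels>`.
    """
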
    def __init__(
        self,
        path,
        split,
        min_length=3,
        max_length=None,
        labels=None,
        label_dict=None,
        shuffle=True,
        sort_by_length=True,
    ):
        super().__init__()

        self.min_length = min_length
        self.max_length = max_length
        self.shuffle = shuffle
        self.sort_by_length = sort_by_length
        self.label_dict = label_dict

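        # decoding labels into index tensors requires a dictionary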
        if labels is not None:
            assert label_dict is not None

        self.sizes = []
        self.offsets = []
        self.labels = []

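        # all utterances of the split live in one flat .npy array;
        # mmap_mode="r" reads rows lazily instead of loading it all into memory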
        data_path = os.path.join(path, split)
        self.data = np.load(data_path + ".npy", mmap_mode="r")

        offset = 0
        skipped = 0

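        # fall back to unlabeled mode if the requested label file is missing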
        if not os.path.exists(data_path + f".{labels}"):
            labels = None

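        # contextlib.ExitStack() acts as a no-op context manager so the same
        # `with` statement works whether or not a label file is being read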
        with open(data_path + ".lengths", "r") as len_f, open(
            data_path + f".{labels}", "r"
        ) if labels is not None else contextlib.ExitStack() as lbl_f:
            for line in len_f:
                length = int(line.rstrip())
                lbl = None if labels is None else next(lbl_f).rstrip().split()
                if length >= min_length and (
                    max_length is None or length <= max_length
                ):
                    self.sizes.append(length)
                    self.offsets.append(offset)
                    if lbl is not None:
                        self.labels.append(lbl)
                else:
                    skipped += 1
                offset += length

        self.sizes = np.asarray(self.sizes)
        self.offsets = np.asarray(self.offsets)

        logger.info(f"loaded {len(self.offsets)} samples, skipped {skipped}")

    def __getitem__(self, index):
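        # each utterance is a contiguous row range [offset, offset + size) in
        # the flat feature array; .copy() detaches the slice from the read-only
        # mmap before converting it to a tensor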
        offset = self.offsets[index]
        end = self.sizes[index] + offset
        feats = torch.from_numpy(self.data[offset:end].copy()).float()

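        # labels were already split into tokens in __init__, so encode_line
        # gets an identity tokenizer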
| | res = {"id": index, "features": feats} |
| | if len(self.labels) > 0: |
| | res["target"] = self.label_dict.encode_line( |
| | self.labels[index], |
| | line_tokenizer=lambda x: x, |
| | append_eos=False, |
| | ) |
| |
|
| | return res |
| |
|
    def __len__(self):
        return len(self.sizes)

    def collater(self, samples):
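        # right-pad variable-length feature sequences to the batch maximum and
        # record which timesteps are padding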
        if len(samples) == 0:
            return {}

        features = [s["features"] for s in samples]
        sizes = [len(s) for s in features]

        target_size = max(sizes)

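        # padding_mask follows the fairseq convention: True marks padded
        # timesteps, False marks real ones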
        collated_features = features[0].new_zeros(
            len(features), target_size, features[0].size(-1)
        )
        padding_mask = torch.BoolTensor(collated_features.shape[:-1]).fill_(False)
        for i, (f, size) in enumerate(zip(features, sizes)):
            collated_features[i, :size] = f
            padding_mask[i, size:] = True

        res = {
            "id": torch.LongTensor([s["id"] for s in samples]),
            "net_input": {"features": collated_features, "padding_mask": padding_mask},
        }

        if len(self.labels) > 0:
            target = data_utils.collate_tokens(
                [s["target"] for s in samples],
                pad_idx=self.label_dict.pad(),
                left_pad=False,
            )
            res["target"] = target
        return res

    def num_tokens(self, index):
        return self.size(index)

    def size(self, index):
        return self.sizes[index]

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]

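        # np.lexsort uses the last key as the primary key, so with [::-1] this
        # yields indices sorted by length (descending), with ties broken by the
        # shuffled or sequential order above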
        if self.sort_by_length:
            order.append(self.sizes)
            return np.lexsort(order)[::-1]
        else:
            return order[0]
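

# A minimal usage sketch (added illustration, not part of the original module).
# It assumes a directory `feats/` containing `train.npy` and `train.lengths`
# produced by a feature-extraction step, plus an optional `train.phn` label
# file and `dict.phn.txt` fairseq dictionary; all of these names are
# hypothetical.
#
#     from fairseq.data import Dictionary
#
#     label_dict = Dictionary.load("feats/dict.phn.txt")
#     dataset = ExtractedFeaturesDataset(
#         path="feats",
#         split="train",
#         min_length=3,
#         labels="phn",
#         label_dict=label_dict,
#     )
#     batch = dataset.collater([dataset[i] for i in dataset.ordered_indices()[:4]])
#     # batch["net_input"]["features"]: (4, max_len, feat_dim) float tensor
#     # batch["net_input"]["padding_mask"]: (4, max_len) bool tensor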