|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
from fairseq.data import data_utils, FairseqDataset |
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
def collate(
    samples,
    pad_idx,
    eos_idx,
    left_pad_source=True,
    left_pad_target=False,
    input_feeding=True,
    pad_to_length=None,
):
    """Merge a list of samples into a single mini-batch dictionary.

    Rows of every batched tensor are ordered by descending source length
    (number of non-pad source tokens).

    Args:
        samples (List[dict]): samples to collate. Each sample is read for
            the keys 'id' and 'source', and optionally 'target',
            'prev_output_tokens' and 'alignment'. Only ``samples[0]`` is
            inspected to decide whether an optional key is present.
        pad_idx (int): padding index, forwarded to
            ``data_utils.collate_tokens`` and used to count real tokens.
        eos_idx (int): end-of-sentence index, forwarded to
            ``data_utils.collate_tokens``.
        left_pad_source (bool, optional): pad source on the left
            (default: True).
        left_pad_target (bool, optional): pad target on the left
            (default: False).
        input_feeding (bool, optional): when the samples carry no
            'prev_output_tokens', derive it from 'target' with
            ``move_eos_to_beginning=True`` (default: True).
        pad_to_length (dict, optional): ``{'source': int, 'target': int}``
            lengths forwarded to ``data_utils.collate_tokens``. The
            'target' entry is only read when the samples have a target.

    Returns:
        dict: ``{}`` for an empty *samples* list, otherwise a dict with
        the keys 'id', 'nsentences', 'ntokens', 'net_input' (holding
        'src_tokens', 'src_lengths' and, when available,
        'prev_output_tokens') and 'target' (``None`` without targets).
        'alignments' and 'align_weights' are added when the samples carry
        alignments and at least one of them passes validation.
    """
    if not samples:
        return {}

    def pad_len(key):
        # Looked up lazily so that a pad_to_length dict without a 'target'
        # entry keeps working for source-only batches.
        return pad_to_length[key] if pad_to_length is not None else None

    def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
        """Pad and stack ``sample[key]`` over all samples into one tensor."""
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx, eos_idx, left_pad, move_eos_to_beginning,
            pad_to_length=pad_to_length,
        )

    def check_alignment(alignment, src_len, tgt_len):
        """Return True when *alignment* is non-empty and within bounds.

        An alignment whose largest source (target) index is
        ``>= src_len - 1`` (``>= tgt_len - 1``) is rejected with a warning.
        NOTE(review): presumably the final position is reserved for EOS,
        which is why the bound is ``len - 1`` -- confirm against the data.
        """
        if alignment is None or len(alignment) == 0:
            return False
        src_out_of_range = alignment[:, 0].max().item() >= src_len - 1
        tgt_out_of_range = alignment[:, 1].max().item() >= tgt_len - 1
        if src_out_of_range or tgt_out_of_range:
            logger.warning("alignment size mismatch found, skipping alignment!")
            return False
        return True

    def compute_alignment_weights(alignments):
        """
        Given a tensor of shape [:, 2] containing the source-target indices
        corresponding to the alignments, a weight vector containing the
        inverse frequency of each target index is computed.
        For e.g. if alignments = [[5, 7], [2, 3], [1, 3], [4, 2]], then
        a tensor containing [1., 0.5, 0.5, 1] should be returned (since target
        index 3 is repeated twice)
        """
        align_tgt = alignments[:, 1]
        _, inverse, counts = torch.unique(
            align_tgt, return_inverse=True, return_counts=True
        )
        # counts[inverse] is, per alignment row, how often its target index
        # occurs in the whole batch.
        return 1. / counts[inverse].float()

    sample_ids = torch.LongTensor([s['id'] for s in samples])
    src_tokens = merge(
        'source', left_pad=left_pad_source,
        pad_to_length=pad_len('source'),
    )

    # Sort the batch by descending number of non-pad source tokens.
    src_lengths = torch.LongTensor([
        s['source'].ne(pad_idx).long().sum() for s in samples
    ])
    src_lengths, sort_order = src_lengths.sort(descending=True)
    sample_ids = sample_ids.index_select(0, sort_order)
    src_tokens = src_tokens.index_select(0, sort_order)

    prev_output_tokens = None
    target = None
    if samples[0].get('target', None) is not None:
        target = merge(
            'target', left_pad=left_pad_target,
            pad_to_length=pad_len('target'),
        )
        target = target.index_select(0, sort_order)
        tgt_lengths = torch.LongTensor([
            s['target'].ne(pad_idx).long().sum() for s in samples
        ]).index_select(0, sort_order)
        ntokens = tgt_lengths.sum().item()

        if samples[0].get('prev_output_tokens', None) is not None:
            # Pad the provided decoder input with the same target length so
            # its width matches `target` when pad_to_length is given.
            prev_output_tokens = merge(
                'prev_output_tokens',
                left_pad=left_pad_target,
                pad_to_length=pad_len('target'),
            )
        elif input_feeding:
            # Build the decoder input from the target itself, with EOS moved
            # to the first position (shifted copy for teacher forcing).
            prev_output_tokens = merge(
                'target',
                left_pad=left_pad_target,
                move_eos_to_beginning=True,
                pad_to_length=pad_len('target'),
            )
    else:
        ntokens = src_lengths.sum().item()

    batch = {
        'id': sample_ids,
        'nsentences': len(samples),
        'ntokens': ntokens,
        'net_input': {
            'src_tokens': src_tokens,
            'src_lengths': src_lengths,
        },
        'target': target,
    }
    if prev_output_tokens is not None:
        # Merged in the original sample order above, so reorder here.
        batch['net_input']['prev_output_tokens'] = prev_output_tokens.index_select(0, sort_order)

    if samples[0].get('alignment', None) is not None:
        bsz, tgt_sz = batch['target'].shape
        src_sz = batch['net_input']['src_tokens'].shape[1]

        # Per-row offsets added to (source index, target index) pairs:
        #  * the target column gets row * tgt_sz, i.e. a position in the
        #    target flattened over the batch;
        #  * with left padding, each column is additionally shifted by the
        #    amount of padding in front of the row.
        offsets = torch.zeros((len(sort_order), 2), dtype=torch.long)
        offsets[:, 1] += (torch.arange(len(sort_order), dtype=torch.long) * tgt_sz)
        if left_pad_source:
            offsets[:, 0] += (src_sz - src_lengths)
        if left_pad_target:
            offsets[:, 1] += (tgt_sz - tgt_lengths)

        alignments = []
        for align_idx, offset, src_len, tgt_len in zip(
            sort_order, offsets, src_lengths, tgt_lengths
        ):
            alignment = samples[align_idx]['alignment'].view(-1, 2)
            if check_alignment(alignment, src_len, tgt_len):
                alignments.append(alignment + offset)

        if len(alignments) > 0:
            alignments = torch.cat(alignments, dim=0)
            align_weights = compute_alignment_weights(alignments)

            batch['alignments'] = alignments
            batch['align_weights'] = align_weights

    return batch
|
|
|
|
|
|
|
|
class LanguagePairDataset(FairseqDataset):
    """
    A pair of torch.utils.data.Datasets.

    Args:
        src (torch.utils.data.Dataset): source dataset to wrap
        src_sizes (List[int]): source sentence lengths
        src_dict (~fairseq.data.Dictionary): source vocabulary
        tgt (torch.utils.data.Dataset, optional): target dataset to wrap
        tgt_sizes (List[int], optional): target sentence lengths
        tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary
        left_pad_source (bool, optional): pad source tensors on the left side
            (default: True).
        left_pad_target (bool, optional): pad target tensors on the left side
            (default: False).
        shuffle (bool, optional): shuffle dataset elements before batching
            (default: True).
        input_feeding (bool, optional): create a shifted version of the targets
            to be passed into the model for teacher forcing (default: True).
        remove_eos_from_source (bool, optional): if set, removes eos from end
            of source if it's present (default: False).
        append_eos_to_target (bool, optional): if set, appends eos to end of
            target if it's absent (default: False).
        align_dataset (torch.utils.data.Dataset, optional): dataset
            containing alignments.
        append_bos (bool, optional): if set, appends bos to the beginning of
            source/target sentence.
        eos (int, optional): index passed to the collater as ``eos_idx``
            (default: ``src_dict.eos()``).
        num_buckets (int, optional): if set to a value greater than 0, then
            batches will be bucketed into the given number of batch shapes.
        src_lang_id (int, optional): source language ID, if set, the collated batch
            will contain a field 'src_lang_id' in 'net_input' which indicates the
            source language of the samples.
        tgt_lang_id (int, optional): target language ID, if set, the collated batch
            will contain a field 'tgt_lang_id' which indicates the target language
            of the samples.
    """

    def __init__(
        self, src, src_sizes, src_dict,
        tgt=None, tgt_sizes=None, tgt_dict=None,
        left_pad_source=True, left_pad_target=False,
        shuffle=True, input_feeding=True,
        remove_eos_from_source=False, append_eos_to_target=False,
        align_dataset=None,
        append_bos=False, eos=None,
        num_buckets=0,
        src_lang_id=None,
        tgt_lang_id=None,
    ):
        # Both vocabularies must agree on the special symbols because a
        # single pad/eos index is used when collating source and target.
        if tgt_dict is not None:
            assert src_dict.pad() == tgt_dict.pad()
            assert src_dict.eos() == tgt_dict.eos()
            assert src_dict.unk() == tgt_dict.unk()
        if tgt is not None:
            assert len(src) == len(tgt), "Source and target must contain the same number of examples"
        self.src = src
        self.tgt = tgt
        self.src_sizes = np.array(src_sizes)
        self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict
        self.left_pad_source = left_pad_source
        self.left_pad_target = left_pad_target
        self.shuffle = shuffle
        self.input_feeding = input_feeding
        self.remove_eos_from_source = remove_eos_from_source
        self.append_eos_to_target = append_eos_to_target
        self.align_dataset = align_dataset
        if self.align_dataset is not None:
            assert self.tgt_sizes is not None, "Both source and target needed when alignments are provided"
        self.append_bos = append_bos
        self.eos = (eos if eos is not None else src_dict.eos())
        self.src_lang_id = src_lang_id
        self.tgt_lang_id = tgt_lang_id
        if num_buckets > 0:
            # Imported lazily, only when bucketing is requested.
            from fairseq.data import BucketPadLengthDataset
            self.src = BucketPadLengthDataset(
                self.src,
                sizes=self.src_sizes,
                num_buckets=num_buckets,
                pad_idx=self.src_dict.pad(),
                left_pad=self.left_pad_source,
            )
            self.src_sizes = self.src.sizes
            logger.info('bucketing source lengths: {}'.format(list(self.src.buckets)))
            if self.tgt is not None:
                self.tgt = BucketPadLengthDataset(
                    self.tgt,
                    sizes=self.tgt_sizes,
                    num_buckets=num_buckets,
                    pad_idx=self.tgt_dict.pad(),
                    left_pad=self.left_pad_target,
                )
                self.tgt_sizes = self.tgt.sizes
                logger.info('bucketing target lengths: {}'.format(list(self.tgt.buckets)))

            # self.num_tokens reads self.src_sizes / self.tgt_sizes, which now
            # hold the sizes reported by the bucketed datasets.
            # np.int64 replaces the `np.long` alias removed in NumPy 1.24.
            num_tokens = np.vectorize(self.num_tokens, otypes=[np.int64])
            self.bucketed_num_tokens = num_tokens(np.arange(len(self.src)))
            # One (None, size) entry per distinct bucketed size.
            # NOTE(review): the leading None presumably stands for an
            # unconstrained batch size -- confirm with get_batch_shapes users.
            self.buckets = [
                (None, bucket_size)
                for bucket_size in np.unique(self.bucketed_num_tokens)
            ]
        else:
            self.buckets = None

    def get_batch_shapes(self):
        """Return the bucket list built in ``__init__`` (None without bucketing)."""
        return self.buckets

    def __getitem__(self, index):
        """Return the example at *index* as a dict.

        The dict has the keys 'id', 'source' and 'target' ('target' is None
        when there is no target dataset), plus 'alignment' when an alignment
        dataset was given. BOS/EOS adjustments are applied cumulatively to
        the fetched items, in the order: append EOS to target, prepend BOS
        to target and source, remove EOS from source.
        """
        tgt_item = self.tgt[index] if self.tgt is not None else None
        src_item = self.src[index]

        # Append EOS to end of tgt sentence if it does not have an EOS and
        # remove EOS from end of src sentence if it exists. Every step works
        # on the running item so an earlier edit is not discarded by a later
        # one (e.g. EOS appended, then BOS prepended).
        if self.append_eos_to_target and tgt_item is not None:
            eos = self.tgt_dict.eos() if self.tgt_dict is not None else self.src_dict.eos()
            if tgt_item[-1] != eos:
                tgt_item = torch.cat([tgt_item, torch.LongTensor([eos])])

        if self.append_bos:
            if tgt_item is not None:
                bos = self.tgt_dict.bos() if self.tgt_dict is not None else self.src_dict.bos()
                if tgt_item[0] != bos:
                    tgt_item = torch.cat([torch.LongTensor([bos]), tgt_item])

            bos = self.src_dict.bos()
            if src_item[0] != bos:
                src_item = torch.cat([torch.LongTensor([bos]), src_item])

        if self.remove_eos_from_source:
            eos = self.src_dict.eos()
            if src_item[-1] == eos:
                src_item = src_item[:-1]

        example = {
            'id': index,
            'source': src_item,
            'target': tgt_item,
        }
        if self.align_dataset is not None:
            example['alignment'] = self.align_dataset[index]
        return example

    def __len__(self):
        """Return the number of examples (length of the source dataset)."""
        return len(self.src)

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate
            pad_to_length (dict, optional): a dictionary of
                {'source': source_pad_to_length, 'target': target_pad_to_length}
                to indicate the max length to pad to in source and target respectively.

        Returns:
            dict: a mini-batch with the following keys (an empty dict is
            returned for an empty *samples* list):

                - `id` (LongTensor): example IDs in the original input order
                - `ntokens` (int): total number of tokens in the batch
                - `net_input` (dict): the input to the Model, containing keys:

                  - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in
                    the source sentence of shape `(bsz, src_len)`. Padding will
                    appear on the left if *left_pad_source* is ``True``.
                  - `src_lengths` (LongTensor): 1D Tensor of the unpadded
                    lengths of each source sentence of shape `(bsz)`
                  - `prev_output_tokens` (LongTensor): a padded 2D Tensor of
                    tokens in the target sentence, shifted right by one
                    position for teacher forcing, of shape `(bsz, tgt_len)`.
                    This key will not be present if *input_feeding* is
                    ``False``.  Padding will appear on the left if
                    *left_pad_target* is ``True``.
                  - `src_lang_id` (LongTensor): a long Tensor which contains source
                    language IDs of each sample in the batch

                - `target` (LongTensor): a padded 2D Tensor of tokens in the
                  target sentence of shape `(bsz, tgt_len)`. Padding will appear
                  on the left if *left_pad_target* is ``True``.
                - `tgt_lang_id` (LongTensor): a long Tensor which contains target language
                   IDs of each sample in the batch
        """
        res = collate(
            samples,
            pad_idx=self.src_dict.pad(),
            eos_idx=self.eos,
            left_pad_source=self.left_pad_source,
            left_pad_target=self.left_pad_target,
            input_feeding=self.input_feeding,
            pad_to_length=pad_to_length,
        )
        if not res:
            # collate() returns {} for an empty sample list; there is no
            # 'net_input' to attach language ids to.
            return res
        if self.src_lang_id is not None or self.tgt_lang_id is not None:
            src_tokens = res['net_input']['src_tokens']
            bsz = src_tokens.size(0)
            # .to(src_tokens) matches the dtype/device of the source batch.
            if self.src_lang_id is not None:
                res['net_input']['src_lang_id'] = torch.LongTensor(
                    [[self.src_lang_id]]
                ).expand(bsz, 1).to(src_tokens)
            if self.tgt_lang_id is not None:
                res['tgt_lang_id'] = torch.LongTensor(
                    [[self.tgt_lang_id]]
                ).expand(bsz, 1).to(src_tokens)
        return res

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return max(self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0)

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return (self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0)

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self))
        else:
            indices = np.arange(len(self))
        if self.buckets is None:
            # Two stable (mergesort) passes: the result is ordered by source
            # length, with ties ordered by target length.
            if self.tgt_sizes is not None:
                indices = indices[
                    np.argsort(self.tgt_sizes[indices], kind='mergesort')
                ]
            return indices[np.argsort(self.src_sizes[indices], kind='mergesort')]
        else:
            # With bucketing, order by the bucketed token count only; the
            # stable sort keeps the (possibly shuffled) order within a bucket.
            return indices[
                np.argsort(self.bucketed_num_tokens[indices], kind='mergesort')
            ]

    @property
    def supports_prefetch(self):
        """True when the source dataset, and the target dataset if there is
        one, both report ``supports_prefetch``."""
        return (
            getattr(self.src, 'supports_prefetch', False)
            and (getattr(self.tgt, 'supports_prefetch', False) or self.tgt is None)
        )

    def prefetch(self, indices):
        """Forward ``prefetch(indices)`` to the source dataset and, when
        present, to the target and alignment datasets."""
        self.src.prefetch(indices)
        if self.tgt is not None:
            self.tgt.prefetch(indices)
        if self.align_dataset is not None:
            self.align_dataset.prefetch(indices)

    def filter_indices_by_size(self, indices, max_sizes):
        """ Filter a list of sample indices. Remove those that are longer
            than specified in max_sizes.

        Args:
            indices (np.array): original array of sample indices
            max_sizes (int or list[int] or tuple[int]): max sample size,
                can be defined separately for src and tgt (then list or tuple)

        Returns:
            np.array: filtered sample array
            list: list of removed indices
        """
        if max_sizes is None:
            return indices, []
        # Accept NumPy scalars as well as plain Python numbers for the
        # single-limit form.
        if isinstance(max_sizes, (int, float, np.integer, np.floating)):
            max_src_size, max_tgt_size = max_sizes, max_sizes
        else:
            max_src_size, max_tgt_size = max_sizes

        # A sample is dropped when its source or (if present) target size
        # exceeds the corresponding limit.
        too_long = self.src_sizes[indices] > max_src_size
        if self.tgt_sizes is not None:
            too_long = too_long | (self.tgt_sizes[indices] > max_tgt_size)

        ignored = indices[too_long]
        if len(ignored) > 0:
            indices = indices[~too_long]
        return indices, ignored.tolist()
|
|
|