Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import torch | |
| from fairseq import utils | |
| from . import FairseqDataset | |
| def backtranslate_samples(samples, collate_fn, generate_fn, cuda=True): | |
| """Backtranslate a list of samples. | |
| Given an input (*samples*) of the form: | |
| [{'id': 1, 'source': 'hallo welt'}] | |
| this will return: | |
| [{'id': 1, 'source': 'hello world', 'target': 'hallo welt'}] | |
| Args: | |
| samples (List[dict]): samples to backtranslate. Individual samples are | |
| expected to have a 'source' key, which will become the 'target' | |
| after backtranslation. | |
| collate_fn (callable): function to collate samples into a mini-batch | |
| generate_fn (callable): function to generate backtranslations | |
| cuda (bool): use GPU for generation (default: ``True``) | |
| Returns: | |
| List[dict]: an updated list of samples with a backtranslated source | |
| """ | |
| collated_samples = collate_fn(samples) | |
| s = utils.move_to_cuda(collated_samples) if cuda else collated_samples | |
| generated_sources = generate_fn(s) | |
| id_to_src = {sample["id"]: sample["source"] for sample in samples} | |
| # Go through each tgt sentence in batch and its corresponding best | |
| # generated hypothesis and create a backtranslation data pair | |
| # {id: id, source: generated backtranslation, target: original tgt} | |
| return [ | |
| { | |
| "id": id.item(), | |
| "target": id_to_src[id.item()], | |
| "source": hypos[0]["tokens"].cpu(), | |
| } | |
| for id, hypos in zip(collated_samples["id"], generated_sources) | |
| ] | |
| class BacktranslationDataset(FairseqDataset): | |
| """ | |
| Sets up a backtranslation dataset which takes a tgt batch, generates | |
| a src using a tgt-src backtranslation function (*backtranslation_fn*), | |
| and returns the corresponding `{generated src, input tgt}` batch. | |
| Args: | |
| tgt_dataset (~fairseq.data.FairseqDataset): the dataset to be | |
| backtranslated. Only the source side of this dataset will be used. | |
| After backtranslation, the source sentences in this dataset will be | |
| returned as the targets. | |
| src_dict (~fairseq.data.Dictionary): the dictionary of backtranslated | |
| sentences. | |
| tgt_dict (~fairseq.data.Dictionary, optional): the dictionary of | |
| sentences to be backtranslated. | |
| backtranslation_fn (callable, optional): function to call to generate | |
| backtranslations. This is typically the `generate` method of a | |
| :class:`~fairseq.sequence_generator.SequenceGenerator` object. | |
| Pass in None when it is not available at initialization time, and | |
| use set_backtranslation_fn function to set it when available. | |
| output_collater (callable, optional): function to call on the | |
| backtranslated samples to create the final batch | |
| (default: ``tgt_dataset.collater``). | |
| cuda: use GPU for generation | |
| """ | |
| def __init__( | |
| self, | |
| tgt_dataset, | |
| src_dict, | |
| tgt_dict=None, | |
| backtranslation_fn=None, | |
| output_collater=None, | |
| cuda=True, | |
| **kwargs | |
| ): | |
| self.tgt_dataset = tgt_dataset | |
| self.backtranslation_fn = backtranslation_fn | |
| self.output_collater = ( | |
| output_collater if output_collater is not None else tgt_dataset.collater | |
| ) | |
| self.cuda = cuda if torch.cuda.is_available() else False | |
| self.src_dict = src_dict | |
| self.tgt_dict = tgt_dict | |
| def __getitem__(self, index): | |
| """ | |
| Returns a single sample from *tgt_dataset*. Note that backtranslation is | |
| not applied in this step; use :func:`collater` instead to backtranslate | |
| a batch of samples. | |
| """ | |
| return self.tgt_dataset[index] | |
| def __len__(self): | |
| return len(self.tgt_dataset) | |
| def set_backtranslation_fn(self, backtranslation_fn): | |
| self.backtranslation_fn = backtranslation_fn | |
| def collater(self, samples): | |
| """Merge and backtranslate a list of samples to form a mini-batch. | |
| Using the samples from *tgt_dataset*, load a collated target sample to | |
| feed to the backtranslation model. Then take the backtranslation with | |
| the best score as the source and the original input as the target. | |
| Note: we expect *tgt_dataset* to provide a function `collater()` that | |
| will collate samples into the format expected by *backtranslation_fn*. | |
| After backtranslation, we will feed the new list of samples (i.e., the | |
| `(backtranslated source, original source)` pairs) to *output_collater* | |
| and return the result. | |
| Args: | |
| samples (List[dict]): samples to backtranslate and collate | |
| Returns: | |
| dict: a mini-batch with keys coming from *output_collater* | |
| """ | |
| if samples[0].get("is_dummy", False): | |
| return samples | |
| samples = backtranslate_samples( | |
| samples=samples, | |
| collate_fn=self.tgt_dataset.collater, | |
| generate_fn=(lambda net_input: self.backtranslation_fn(net_input)), | |
| cuda=self.cuda, | |
| ) | |
| return self.output_collater(samples) | |
| def num_tokens(self, index): | |
| """Just use the tgt dataset num_tokens""" | |
| return self.tgt_dataset.num_tokens(index) | |
| def ordered_indices(self): | |
| """Just use the tgt dataset ordered_indices""" | |
| return self.tgt_dataset.ordered_indices() | |
| def size(self, index): | |
| """Return an example's size as a float or tuple. This value is used | |
| when filtering a dataset with ``--max-positions``. | |
| Note: we use *tgt_dataset* to approximate the length of the source | |
| sentence, since we do not know the actual length until after | |
| backtranslation. | |
| """ | |
| tgt_size = self.tgt_dataset.size(index)[0] | |
| return (tgt_size, tgt_size) | |
| def supports_prefetch(self): | |
| return getattr(self.tgt_dataset, "supports_prefetch", False) | |
| def prefetch(self, indices): | |
| return self.tgt_dataset.prefetch(indices) | |