Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 -u | |
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import copy | |
| import logging | |
| from typing import Any, Dict, Iterator, List | |
| import torch | |
| from fairseq import utils | |
| from omegaconf import open_dict | |
| from torch import nn | |
| from tqdm import tqdm | |
| from fairseq.hub_utils import GeneratorHubInterface | |
| logger = logging.getLogger(__name__) | |
class MultichannelGeneratorHubInterface(GeneratorHubInterface):
    """Pytorch Hub interface for generating sequences from a pre-trained
    multichannel language model.

    Sentences and token sequences are exchanged as dictionaries keyed by
    channel name, with exactly one entry per channel in ``task.channels``.
    """

    def __init__(self, cfg, task, models):
        """Store the config/task/models and prepare the models for inference.

        Args:
            cfg: configuration object; ``cfg.generation`` and ``cfg.dataset``
                are read later by :meth:`generate` and :meth:`_build_batches`.
            task: task exposing ``source_dictionaries``,
                ``target_dictionaries`` and ``channels``.
            models: iterable of models, wrapped in an ``nn.ModuleList``.
        """
        super().__init__(cfg, task, models)
        # NOTE(review): the base class presumably assigns cfg/task/models
        # already; these re-assignments are kept to preserve behavior.
        self.cfg = cfg
        self.task = task
        self.models = nn.ModuleList(models)
        # Per-channel dictionaries, indexed by channel name.
        self.src_dicts = task.source_dictionaries
        self.tgt_dicts = task.target_dictionaries
        self.channels = task.channels

        # optimize model for generation
        for model in self.models:
            model.prepare_for_inference_(cfg)

    def sample(
        self,
        sentences: List[Dict[str, str]],
        beam: int = 1,
        verbose: bool = False,
        **kwargs
    ) -> List[Dict[str, str]]:
        """Encode the inputs, generate, and decode the best hypothesis of each.

        Args:
            sentences: list of ``{channel: text}`` dictionaries. A single
                dictionary is also accepted, in which case a single decoded
                dictionary is returned instead of a list.
            beam: beam size forwarded to :meth:`generate`.
            verbose: forwarded to :meth:`generate`.
            **kwargs: forwarded to :meth:`generate`.

        Returns:
            One ``{channel: text}`` dictionary per input, built from the
            first hypothesis returned for that input.
        """
        if isinstance(sentences, dict):
            return self.sample([sentences], beam=beam, verbose=verbose, **kwargs)[0]
        tokenized_sentences = [self.encode(sentence) for sentence in sentences]
        batched_hypos = self.generate(tokenized_sentences, beam, verbose, **kwargs)
        return [self.decode(hypos[0]["tokens"]) for hypos in batched_hypos]

    def score(self, sentences: List[Dict[str, str]], **kwargs):
        """Not supported by this interface.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "MultichannelGeneratorHubInterface doesn't support score() method"
        )

    def generate(
        self,
        tokenized_sentences: List[Dict[str, torch.LongTensor]],
        beam: int = 5,
        verbose: bool = False,
        skip_invalid_size_inputs=False,
        inference_step_args=None,
        **kwargs
    ) -> List[List[Dict[str, Any]]]:
        """Generate hypotheses for already-tokenized multichannel inputs.

        Args:
            tokenized_sentences: list of ``{channel: tensor}`` dictionaries.
                A single dictionary is also accepted, in which case only the
                hypothesis list for that input is returned.
            beam: beam size written into the generation config.
            verbose: when true, log the source (``S``), each hypothesis
                (``H``) and its per-channel positional scores (``P``).
            skip_invalid_size_inputs: forwarded to the batch iterator as
                ``ignore_invalid_inputs``.
            inference_step_args: optional extra keyword arguments for
                ``task.inference_step``.
            **kwargs: extra attributes set on the generation config.

        Returns:
            Hypothesis lists sorted by batch sample id. In every hypothesis
            the ``"tokens"`` entry is replaced by a ``{channel: tensor}``
            dictionary obtained by slicing the last axis in
            ``self.channels`` order.
        """
        if isinstance(tokenized_sentences, dict):
            # Forward *all* options, not only beam/verbose, so the single
            # input shortcut behaves like the list form.
            return self.generate(
                [tokenized_sentences],
                beam=beam,
                verbose=verbose,
                skip_invalid_size_inputs=skip_invalid_size_inputs,
                inference_step_args=inference_step_args,
                **kwargs
            )[0]

        # Nothing to batch or generate for an empty input list.
        if not tokenized_sentences:
            return []

        # build generator using current args as well as any kwargs
        gen_args = copy.deepcopy(self.cfg.generation)
        with open_dict(gen_args):
            gen_args.beam = beam
            for k, v in kwargs.items():
                setattr(gen_args, k, v)
        generator = self.task.build_generator(self.models, gen_args)

        inference_step_args = inference_step_args or {}
        results = []
        for batch in tqdm(
            self._build_batches(tokenized_sentences, skip_invalid_size_inputs)
        ):
            batch = utils.apply_to_sample(lambda t: t.to(self.device), batch)
            translations = self.task.inference_step(
                generator, self.models, batch, **inference_step_args
            )
            for sample_id, hypos in zip(batch["id"].tolist(), translations):
                # The output of the generator is supposed to be a tensor of size (bsz x max_len x n_channels)
                # So we need to convert it to dictionary form
                for hypo in hypos:
                    token_tensor = hypo["tokens"]
                    hypo["tokens"] = {
                        channel: token_tensor[..., j]
                        for j, channel in enumerate(self.channels)
                    }
                results.append((sample_id, hypos))

        # sort output to match input order
        outputs = [hypos for _, hypos in sorted(results, key=lambda x: x[0])]

        if verbose:
            # NOTE(review): inputs and outputs are paired positionally; if
            # inputs were skipped as invalid the pairing may be off - confirm.
            for source_tokens, target_hypotheses in zip(tokenized_sentences, outputs):
                src_str_with_unk = {
                    channel: self.string(source_tokens[channel], channel)
                    for channel in source_tokens
                }
                logger.info("S\t{}".format(src_str_with_unk))
                for hypo in target_hypotheses:
                    hypo_str = self.decode(hypo["tokens"])
                    logger.info("H\t{}\t{}".format(hypo["score"], hypo_str))
                    # hypo["positional_scores"]: T x n_channels
                    # Columns are labelled in self.channels order, the same
                    # order used above to split the token tensor, rather
                    # than the key order of the caller-supplied dictionary.
                    pos_scores = {
                        channel: " ".join(
                            "{:.4f}".format(x)
                            for x in hypo["positional_scores"][:, c].tolist()
                        )
                        for c, channel in enumerate(self.channels)
                    }
                    logger.info("P\t{}".format(pos_scores))

        return outputs

    def encode(self, sentence: Dict[str, str]) -> Dict[str, torch.LongTensor]:
        """Tokenize, BPE-encode and binarize every channel of ``sentence``.

        Args:
            sentence: ``{channel: text}`` with exactly the model's channels.

        Returns:
            ``{channel: tensor}`` in the key order of ``sentence``.

        Raises:
            AssertionError: if ``sentence`` is not a dict, its keys differ
                from ``self.channels``, or the encoded channels do not all
                have the same size.
        """
        assert isinstance(
            sentence, dict
        ), "Input sentence is expected to be a dictionary over channels"
        assert set(sentence.keys()) == set(
            self.channels
        ), "Mismatch between input sentence keys and model channels ({} vs {})".format(
            set(sentence.keys()), set(self.channels)
        )

        encoded_sentence = {}
        for channel, text in sentence.items():
            text = self.tokenize(text)
            text = self.apply_bpe(text)
            encoded_sentence[channel] = self.binarize(text, channel)

        # All channels must line up: compare against the first model channel.
        sentence_size = encoded_sentence[self.channels[0]].size()
        assert all(
            encoded.size() == sentence_size for encoded in encoded_sentence.values()
        ), "Input tensors are expected to have the same size in all channels"
        return encoded_sentence

    def decode(self, tokens: Dict[str, torch.LongTensor]) -> Dict[str, str]:
        """Convert per-channel token tensors back to text.

        Each channel is stringified with its target dictionary, then passed
        through ``remove_bpe`` and ``detokenize``.

        Args:
            tokens: ``{channel: tensor}`` with exactly the model's channels.

        Returns:
            ``{channel: text}`` in the key order of ``tokens``.

        Raises:
            AssertionError: if ``tokens`` is not a dict or its keys differ
                from ``self.channels``.
        """
        assert isinstance(
            tokens, dict
        ), "Input tokens are expected to be a dictionary over channels"
        assert set(tokens.keys()) == set(
            self.channels
        ), "Mismatch between input tokens keys and model channels ({} vs {})".format(
            set(tokens.keys()), set(self.channels)
        )

        decoded_sentence = {}
        for channel, channel_tokens in tokens.items():
            text = self.string(channel_tokens, channel)
            text = self.remove_bpe(text)
            decoded_sentence[channel] = self.detokenize(text)
        return decoded_sentence

    def binarize(self, sentence: str, channel: str) -> torch.LongTensor:
        """Encode ``sentence`` with the source dictionary of ``channel``.

        Symbols missing from the dictionary are not added to it.
        """
        return (
            self.src_dicts[channel].encode_line(sentence, add_if_not_exist=False).long()
        )

    def string(self, tokens: torch.LongTensor, channel: str) -> str:
        """Render ``tokens`` as a string with the target dictionary of ``channel``."""
        return self.tgt_dicts[channel].string(tokens)

    def _build_batches(
        self,
        tokens: List[Dict[str, torch.LongTensor]],
        skip_invalid_size_inputs: bool,
    ) -> Iterator[Dict[str, Any]]:
        """Build an unshuffled batch iterator over the tokenized inputs.

        Args:
            tokens: list of ``{channel: tensor}`` dictionaries.
            skip_invalid_size_inputs: passed to the task's batch iterator as
                ``ignore_invalid_inputs``.

        Returns:
            The iterator for the next epoch, with shuffling disabled.
        """
        # The length of each sample is taken from the first channel present
        # in its dictionary.
        lengths = torch.LongTensor([next(iter(d.values())).numel() for d in tokens])
        dataset = self.task.build_dataset_for_inference(tokens, lengths)
        epoch_iterator = self.task.get_batch_iterator(
            dataset=dataset,
            max_tokens=self.cfg.dataset.max_tokens,
            max_sentences=self.cfg.dataset.batch_size,
            max_positions=self.max_positions,
            ignore_invalid_inputs=skip_invalid_size_inputs,
            disable_iterator_cache=True,
        )
        return epoch_iterator.next_epoch_itr(shuffle=False)