Spaces:
Sleeping
Sleeping
File size: 7,761 Bytes
85ba398 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 |
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import copy
import logging
from typing import Any, Dict, Iterator, List
import torch
from fairseq import utils
from omegaconf import open_dict
from torch import nn
from tqdm import tqdm
from fairseq.hub_utils import GeneratorHubInterface
logger = logging.getLogger(__name__)
class MultichannelGeneratorHubInterface(GeneratorHubInterface):
    """PyTorch Hub interface for generating sequences from a pre-trained
    multichannel language model.

    A "sentence" is a dictionary mapping each channel name to its text
    (before encoding) or to a 1-D LongTensor of token ids (after encoding).
    All channels of one sentence must encode to tensors of the same length.
    """

    def __init__(self, cfg, task, models):
        super().__init__(cfg, task, models)
        self.cfg = cfg
        self.task = task
        self.models = nn.ModuleList(models)
        # Per-channel dictionaries provided by the multichannel task.
        self.src_dicts = task.source_dictionaries
        self.tgt_dicts = task.target_dictionaries
        # Ordered channel names; this order defines the channel axis of the
        # generator's (bsz x max_len x n_channels) output.
        self.channels = task.channels
        # optimize model for generation
        for model in self.models:
            model.prepare_for_inference_(cfg)

    def sample(
        self,
        sentences: List[Dict[str, str]],
        beam: int = 1,
        verbose: bool = False,
        **kwargs
    ) -> List[str]:
        """Generate from each input sentence and decode the best hypothesis
        back into a channel-name -> text dictionary.

        A single dictionary may be passed instead of a list, in which case a
        single decoded result is returned.
        """
        if isinstance(sentences, dict):
            return self.sample([sentences], beam=beam, verbose=verbose, **kwargs)[0]
        tokenized_sentences = [self.encode(sentence) for sentence in sentences]
        batched_hypos = self.generate(tokenized_sentences, beam, verbose, **kwargs)
        # hypos[0] is the highest-scoring hypothesis for each input.
        return [self.decode(hypos[0]["tokens"]) for hypos in batched_hypos]

    def score(self, sentences: List[Dict[str, str]], **kwargs):
        """Scoring is not implemented for multichannel models."""
        raise NotImplementedError(
            "MultichannelGeneratorHubInterface doesn't support score() method"
        )

    def generate(
        self,
        tokenized_sentences: List[Dict[str, torch.LongTensor]],
        beam: int = 5,
        verbose: bool = False,
        skip_invalid_size_inputs=False,
        inference_step_args=None,
        **kwargs
    ) -> List[List[Dict[str, torch.Tensor]]]:
        """Run generation on pre-encoded sentences.

        Returns, for each input sentence, a list of hypothesis dicts whose
        "tokens" entry is a channel-name -> LongTensor mapping, sorted to
        match the input order.
        """
        if isinstance(tokenized_sentences, dict):
            # Fix: forward *all* optional arguments in the single-sentence
            # shortcut; the original dropped skip_invalid_size_inputs and
            # inference_step_args here, so single calls behaved differently
            # from batched calls.
            return self.generate(
                [tokenized_sentences],
                beam=beam,
                verbose=verbose,
                skip_invalid_size_inputs=skip_invalid_size_inputs,
                inference_step_args=inference_step_args,
                **kwargs,
            )[0]
        # build generator using current args as well as any kwargs
        gen_args = copy.deepcopy(self.cfg.generation)
        with open_dict(gen_args):
            gen_args.beam = beam
            for k, v in kwargs.items():
                setattr(gen_args, k, v)
        generator = self.task.build_generator(self.models, gen_args)
        inference_step_args = inference_step_args or {}
        results = []
        for batch in tqdm(
            self._build_batches(tokenized_sentences, skip_invalid_size_inputs)
        ):
            batch = utils.apply_to_sample(lambda t: t.to(self.device), batch)
            translations = self.task.inference_step(
                generator, self.models, batch, **inference_step_args
            )
            for sample_id, hypos in zip(batch["id"].tolist(), translations):
                # The output of the generator is supposed to be a tensor of
                # size (bsz x max_len x n_channels), so convert each
                # hypothesis to dictionary form, one entry per channel.
                for i in range(len(hypos)):
                    hypos[i]["tokens"] = {
                        channel: hypos[i]["tokens"][..., j]
                        for j, channel in enumerate(self.channels)
                    }
                results.append((sample_id, hypos))
        # sort output to match input order
        outputs = [hypos for _, hypos in sorted(results, key=lambda x: x[0])]
        if verbose:
            for source_tokens, target_hypotheses in zip(tokenized_sentences, outputs):
                src_str_with_unk = {
                    channel: self.string(source_tokens[channel], channel)
                    for channel in source_tokens
                }
                logger.info("S\t{}".format(src_str_with_unk))
                for hypo in target_hypotheses:
                    hypo_str = self.decode(hypo["tokens"])
                    logger.info("H\t{}\t{}".format(hypo["score"], hypo_str))
                    # hypo["positional_scores"]: T x n_channels
                    # Fix: index score columns by self.channels so the
                    # channel -> column mapping matches the token split
                    # above; the original enumerated the source dict, whose
                    # iteration order need not match self.channels.
                    pos_scores = {}
                    for c, channel in enumerate(self.channels):
                        pos_scores[channel] = " ".join(
                            map(
                                lambda x: "{:.4f}".format(x),
                                hypo["positional_scores"][:, c].tolist(),
                            )
                        )
                    logger.info("P\t{}".format(pos_scores))
        return outputs

    def encode(self, sentence: Dict[str, str]) -> Dict[str, torch.LongTensor]:
        """Tokenize, apply BPE, and binarize every channel of *sentence*.

        All channels must be present and must binarize to tensors of equal
        length.
        """
        assert isinstance(
            sentence, dict
        ), "Input sentence is expected to be a dictionary over channels"
        assert set(sentence.keys()) == set(
            self.channels
        ), "Mismatch between input sentence keys and model channels ({} vs {})".format(
            set(sentence.keys()), set(self.channels)
        )
        encoded_sentence = {}
        for channel in sentence:
            sentence_channel = sentence[channel]
            sentence_channel = self.tokenize(sentence_channel)
            sentence_channel = self.apply_bpe(sentence_channel)
            sentence_channel = self.binarize(sentence_channel, channel)
            encoded_sentence[channel] = sentence_channel
        # Use the first channel's size as the reference length.
        sentence_size = encoded_sentence[self.channels[0]].size()
        assert all(
            encoded_sentence[channel].size() == sentence_size
            for channel in encoded_sentence
        ), "Input tensors are expected to have the same size in all channels"
        return encoded_sentence

    def decode(self, tokens: Dict[str, torch.LongTensor]) -> Dict[str, str]:
        """Invert :meth:`encode`: convert per-channel token tensors back to
        detokenized text, one string per channel.
        """
        assert isinstance(
            tokens, dict
        ), "Input tokens are expected to be a dictionary over channels"
        assert set(tokens.keys()) == set(
            self.channels
        ), "Mismatch between input tokens keys and model channels ({} vs {})".format(
            set(tokens.keys()), set(self.channels)
        )
        decoded_sentence = {}
        for channel in tokens:
            tokens_channel = tokens[channel]
            sentence_channel = self.string(tokens_channel, channel)
            sentence_channel = self.remove_bpe(sentence_channel)
            sentence_channel = self.detokenize(sentence_channel)
            decoded_sentence[channel] = sentence_channel
        return decoded_sentence

    def binarize(self, sentence: str, channel: str) -> torch.LongTensor:
        """Map a BPE-tokenized string to token ids using *channel*'s source
        dictionary (unknown tokens are not added)."""
        return (
            self.src_dicts[channel].encode_line(sentence, add_if_not_exist=False).long()
        )

    def string(self, tokens: torch.LongTensor, channel: str) -> str:
        """Render token ids as a string using *channel*'s target dictionary."""
        return self.tgt_dicts[channel].string(tokens)

    def _build_batches(
        self, tokens: List[Dict[str, List[int]]], skip_invalid_size_inputs: bool
    ) -> Iterator[Dict[str, Any]]:
        """Build a non-shuffled batch iterator over *tokens* for inference.

        Lengths are taken from an arbitrary channel of each sentence, which
        is valid because encode() enforces equal sizes across channels.
        """
        lengths = torch.LongTensor([next(iter(d.values())).numel() for d in tokens])
        batch_iterator = self.task.get_batch_iterator(
            dataset=self.task.build_dataset_for_inference(tokens, lengths),
            max_tokens=self.cfg.dataset.max_tokens,
            max_sentences=self.cfg.dataset.batch_size,
            max_positions=self.max_positions,
            ignore_invalid_inputs=skip_invalid_size_inputs,
            disable_iterator_cache=True,
        ).next_epoch_itr(shuffle=False)
        return batch_iterator
|