File size: 7,761 Bytes
85ba398
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import copy
import logging
from typing import Any, Dict, Iterator, List

import torch
from fairseq import utils
from omegaconf import open_dict
from torch import nn

from tqdm import tqdm

from fairseq.hub_utils import GeneratorHubInterface


logger = logging.getLogger(__name__)


class MultichannelGeneratorHubInterface(GeneratorHubInterface):
    """PyTorch Hub interface for generating sequences from a pre-trained
    multichannel language model.

    Unlike the single-channel parent class, sentences and token tensors are
    represented as dictionaries mapping a channel name to the per-channel
    string / ``LongTensor``. All channels of one sentence must have the same
    length, and the generator is expected to emit a tensor of shape
    ``(max_len x n_channels)`` per hypothesis, which this class splits back
    into per-channel views.
    """

    def __init__(self, cfg, task, models):
        """Store config, task and models, and cache per-channel dictionaries.

        Args:
            cfg: fairseq config; ``cfg.generation`` and ``cfg.dataset`` are
                read at generation time.
            task: task object exposing ``source_dictionaries``,
                ``target_dictionaries`` and ``channels``.
            models: iterable of trained models used for generation.
        """
        super().__init__(cfg, task, models)
        self.cfg = cfg
        self.task = task
        self.models = nn.ModuleList(models)
        # Per-channel fairseq dictionaries: channel name -> Dictionary.
        self.src_dicts = task.source_dictionaries
        self.tgt_dicts = task.target_dictionaries
        self.channels = task.channels

        # Optimize each model for generation (e.g. fuse/eval-mode tweaks).
        for model in self.models:
            model.prepare_for_inference_(cfg)

    def sample(
        self,
        sentences: List[Dict[str, str]],
        beam: int = 1,
        verbose: bool = False,
        **kwargs
    ) -> List[Dict[str, str]]:
        """Generate from raw multichannel sentences.

        Args:
            sentences: list of {channel: text} dicts (a single dict is also
                accepted, in which case a single result dict is returned).
            beam: beam size.
            verbose: log sources, hypotheses and scores.
            **kwargs: extra generation options forwarded to :meth:`generate`.

        Returns:
            One {channel: text} dict per input sentence, taken from the
            top-scoring hypothesis.
        """
        if isinstance(sentences, dict):
            # Single-sentence convenience path: wrap, generate, unwrap.
            return self.sample([sentences], beam=beam, verbose=verbose, **kwargs)[0]
        tokenized_sentences = [self.encode(sentence) for sentence in sentences]
        batched_hypos = self.generate(tokenized_sentences, beam, verbose, **kwargs)
        # Keep only the best (first) hypothesis for each input.
        return [self.decode(hypos[0]["tokens"]) for hypos in batched_hypos]

    def score(self, sentences: List[Dict[str, str]], **kwargs):
        """Scoring is not supported for multichannel models."""
        raise NotImplementedError(
            "MultichannelGeneratorHubInterface doesn't support score() method"
        )

    def generate(
        self,
        tokenized_sentences: List[Dict[str, torch.LongTensor]],
        beam: int = 5,
        verbose: bool = False,
        skip_invalid_size_inputs=False,
        inference_step_args=None,
        **kwargs
    ) -> List[List[Dict[str, torch.Tensor]]]:
        """Generate hypotheses for pre-encoded multichannel inputs.

        Args:
            tokenized_sentences: list of {channel: LongTensor} inputs (a
                single dict is accepted and yields a single hypothesis list).
            beam: beam size override for this call.
            verbose: log sources, hypotheses and positional scores.
            skip_invalid_size_inputs: drop inputs exceeding model positions
                instead of raising.
            inference_step_args: extra kwargs for ``task.inference_step``.
            **kwargs: additional generation config overrides.

        Returns:
            For each input, a list of hypothesis dicts whose ``"tokens"``
            entry is a {channel: LongTensor} dict.
        """
        if isinstance(tokenized_sentences, dict):
            # Fix: forward skip_invalid_size_inputs and inference_step_args
            # on the single-input path instead of silently dropping them.
            return self.generate(
                [tokenized_sentences],
                beam=beam,
                verbose=verbose,
                skip_invalid_size_inputs=skip_invalid_size_inputs,
                inference_step_args=inference_step_args,
                **kwargs
            )[0]

        # Build generator using current args as well as any kwargs overrides.
        gen_args = copy.deepcopy(self.cfg.generation)
        with open_dict(gen_args):
            gen_args.beam = beam
            for k, v in kwargs.items():
                setattr(gen_args, k, v)
        generator = self.task.build_generator(self.models, gen_args)

        inference_step_args = inference_step_args or {}
        results = []
        for batch in tqdm(
            self._build_batches(tokenized_sentences, skip_invalid_size_inputs)
        ):
            batch = utils.apply_to_sample(lambda t: t.to(self.device), batch)
            translations = self.task.inference_step(
                generator, self.models, batch, **inference_step_args
            )
            # `sample_id` (renamed from `id` to avoid shadowing the builtin)
            # tracks the original input position for re-sorting below.
            for sample_id, hypos in zip(batch["id"].tolist(), translations):
                # The generator output is a tensor of size
                # (bsz x max_len x n_channels); split the channel axis into
                # a {channel: tensor} dict, following self.channels order.
                for i in range(len(hypos)):
                    hypos[i]["tokens"] = {
                        channel: hypos[i]["tokens"][..., j]
                        for j, channel in enumerate(self.channels)
                    }
                results.append((sample_id, hypos))

        # Sort output to match input order (batching may shuffle inputs).
        outputs = [hypos for _, hypos in sorted(results, key=lambda x: x[0])]

        if verbose:
            for source_tokens, target_hypotheses in zip(tokenized_sentences, outputs):
                src_str_with_unk = {
                    channel: self.string(source_tokens[channel], channel)
                    for channel in source_tokens
                }
                logger.info("S\t{}".format(src_str_with_unk))
                for hypo in target_hypotheses:
                    hypo_str = self.decode(hypo["tokens"])
                    logger.info("H\t{}\t{}".format(hypo["score"], hypo_str))
                    # hypo["positional_scores"]: T x n_channels.
                    # NOTE(review): column index c follows the key order of
                    # `source_tokens`, which is assumed to match the channel
                    # order of the generator output (self.channels) — confirm.
                    pos_scores = {}
                    for c, channel in enumerate(source_tokens):
                        pos_scores[channel] = " ".join(
                            map(
                                lambda x: "{:.4f}".format(x),
                                hypo["positional_scores"][:, c].tolist(),
                            )
                        )
                    logger.info("P\t{}".format(pos_scores))

        return outputs

    def encode(self, sentence: Dict[str, str]) -> Dict[str, torch.LongTensor]:
        """Tokenize, BPE-encode and binarize each channel of a sentence.

        Raises ``AssertionError`` if the channel keys do not match the
        model's channels or the encoded channels differ in length.
        """
        assert isinstance(
            sentence, dict
        ), "Input sentence is expected to be a dictionary over channels"
        assert set(sentence.keys()) == set(
            self.channels
        ), "Mismatch between input sentence keys and model channels ({} vs {})".format(
            set(sentence.keys()), set(self.channels)
        )
        encoded_sentence = {}
        for channel in sentence:
            sentence_channel = sentence[channel]
            sentence_channel = self.tokenize(sentence_channel)
            sentence_channel = self.apply_bpe(sentence_channel)
            sentence_channel = self.binarize(sentence_channel, channel)
            encoded_sentence[channel] = sentence_channel
        # All channels must align token-for-token for multichannel decoding.
        sentence_size = encoded_sentence[self.channels[0]].size()
        assert all(
            encoded_sentence[channel].size() == sentence_size
            for channel in encoded_sentence
        ), "Input tensors are expected to have the same size in all channels"
        return encoded_sentence

    def decode(self, tokens: Dict[str, torch.LongTensor]) -> Dict[str, str]:
        """Inverse of :meth:`encode`: de-binarize, strip BPE and detokenize
        each channel back to a string.
        """
        assert isinstance(
            tokens, dict
        ), "Input tokens are expected to be a dictionary over channels"
        assert set(tokens.keys()) == set(
            self.channels
        ), "Mismatch between input tokens keys and model channels ({} vs {})".format(
            set(tokens.keys()), set(self.channels)
        )
        decoded_sentence = {}
        for channel in tokens:
            tokens_channel = tokens[channel]
            sentence_channel = self.string(tokens_channel, channel)
            sentence_channel = self.remove_bpe(sentence_channel)
            sentence_channel = self.detokenize(sentence_channel)
            decoded_sentence[channel] = sentence_channel
        return decoded_sentence

    def binarize(self, sentence: str, channel: str) -> torch.LongTensor:
        """Map a BPE-encoded string to token ids using the channel's source
        dictionary (unknown tokens map to <unk>, nothing is added)."""
        return (
            self.src_dicts[channel].encode_line(sentence, add_if_not_exist=False).long()
        )

    def string(self, tokens: torch.LongTensor, channel: str) -> str:
        """Render token ids as a string via the channel's target dictionary."""
        return self.tgt_dicts[channel].string(tokens)

    def _build_batches(
        self, tokens: List[Dict[str, List[int]]], skip_invalid_size_inputs: bool
    ) -> Iterator[Dict[str, Any]]:
        """Build an epoch iterator over inference batches for `tokens`.

        Lengths are taken from an arbitrary channel of each input, which is
        valid because :meth:`encode` asserts all channels share one size.
        """
        lengths = torch.LongTensor([next(iter(d.values())).numel() for d in tokens])
        batch_iterator = self.task.get_batch_iterator(
            dataset=self.task.build_dataset_for_inference(tokens, lengths),
            max_tokens=self.cfg.dataset.max_tokens,
            max_sentences=self.cfg.dataset.batch_size,
            max_positions=self.max_positions,
            ignore_invalid_inputs=skip_invalid_size_inputs,
            # Inference inputs change every call; never reuse a cached iterator.
            disable_iterator_cache=True,
        ).next_epoch_itr(shuffle=False)
        return batch_iterator