File size: 12,365 Bytes
957e2dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

from typing import Union, List, Tuple

from src.models.model import Model

################################################################################
# Extends Model class for speech recognition, with optional decoding
################################################################################


class Decoder(object):
    """
    Base class for decoder objects, which turn frame-by-frame token
    probabilities emitted by a model into string transcriptions. Subclasses
    supply the actual decoding strategy by overriding `decode()`.

    Code adapted from DeepSpeech PyTorch (https://tinyurl.com/2p89d35e).
    """
    def __init__(self,
                 labels: Union[List[str], Tuple[str]],
                 sep_idx: int = None,
                 blank_idx: int = 0):
        """
        Parameters
        ----------
        labels (list):   character corresponding to each token index

        sep_idx (int):   index corresponding to space / separating character.
                         If omitted, the index of ' ' in `labels` is used,
                         else the index of '|', else `len(labels)`

        blank_idx (int): index corresponding to blank '_' character
        """
        self.labels = labels
        self.blank_idx = blank_idx

        # token index -> character lookup table
        self.int_to_char = dict(enumerate(labels))

        if sep_idx is None:
            # search for a known separating character, preferring ' ' to '|'
            for candidate in (' ', '|'):
                if candidate in labels:
                    sep_idx = labels.index(candidate)
                    break
            else:
                # no separator present: use out-of-bounds index as a marker
                sep_idx = len(labels)

        self.sep_idx = sep_idx

    def get_labels(self):
        """Return the token labels this decoder was constructed with."""
        return self.labels

    def get_sep_idx(self):
        """Return the index of the separating character."""
        return self.sep_idx

    def get_blank_idx(self):
        """Return the index of the blank token."""
        return self.blank_idx

    def __call__(self, emission: torch.Tensor, sizes=None):
        """Shorthand for `decode(emission, sizes)`."""
        return self.decode(emission, sizes)

    def decode(self, emission: torch.Tensor, sizes=None):
        """
        Decode emitted token probabilities to obtain a string transcription.
        Must be implemented by subclasses.

        Parameters
        ----------
        emission (Tensor): shape (n_batch, n_frames, n_tokens)

        sizes (Tensor):    length in frames of each emission in batch
        """
        raise NotImplementedError


class GreedyCTCDecoder(Decoder):
    """
    A simple decoder module to map token probability sequences to transcripts.
    Decodes 'greedily' by selecting the maximum-probability token at each time
    step. Code adapted from DeepSpeech PyTorch (https://tinyurl.com/2p89d35e).
    """
    def __init__(self,
                 labels: Union[List[str], Tuple[str]],
                 sep_idx: int = None,
                 blank_idx: int = 0):
        """
        Parameters
        ----------
        labels (list):   character corresponding to each token index

        sep_idx (int):   index corresponding to space / separating character

        blank_idx (int): index corresponding to blank '_' character
        """
        super().__init__(labels, sep_idx, blank_idx)

    def convert_to_strings(self,
                           sequences,
                           sizes=None,
                           remove_repetitions=False,
                           return_offsets=False):
        """
        Given a batch of sequences holding token indices, return the
        corresponding strings. Optionally, collapse repeated tokens and
        return the frame offsets of the characters that were kept.

        Parameters
        ----------
        sequences (Tensor):        shape (n_batch, n_frames); holds argmax
                                   token index for each frame

        sizes (Tensor):            number of leading frames to process in each
                                   sequence; if None, every sequence is
                                   processed in full

        remove_repetitions (bool): if True, drop a token that is identical to
                                   the token in the preceding frame

        return_offsets (bool):     if True, also return per-character frame
                                   offsets

        Returns
        -------
        strings (list[str]):       one decoded string per sequence

        offsets (list[Tensor]):    only if `return_offsets`; for each sequence,
                                   an int tensor holding the frame index of
                                   every character in the decoded string
        """
        strings = []
        offsets = []

        for i, sequence in enumerate(sequences):
            seq_len = sizes[i] if sizes is not None else len(sequence)
            string, string_offsets = self.process_string(
                sequence, seq_len, remove_repetitions
            )
            strings.append(string)
            offsets.append(string_offsets)

        if return_offsets:
            return strings, offsets
        return strings

    def process_string(self,
                       sequence,
                       size,
                       remove_repetitions=False):
        """
        Convert a single sequence of token indices to a string, dropping
        blank tokens.

        Parameters
        ----------
        sequence (Tensor):         shape (n_frames,); token index per frame

        size (int):                number of leading frames to process

        remove_repetitions (bool): if True, drop a token that is identical to
                                   the token in the preceding frame

        Returns
        -------
        string (str):              decoded characters

        offsets (Tensor):          int tensor holding the frame index of each
                                   character in `string`
        """
        blank_char = self.int_to_char[self.blank_idx]

        chars = []
        offsets = []
        prev_char = None

        for i in range(size):
            char = self.int_to_char[sequence[i].item()]

            # a frame is a repetition if it matches the frame directly before
            # it; an intervening blank therefore separates genuine doubles
            is_repeat = remove_repetitions and i != 0 and char == prev_char
            prev_char = char

            if char == blank_char or is_repeat:
                continue

            # The separator needs no special handling: it is emitted like any
            # other character. Looking it up via `labels[sep_idx]` must be
            # avoided, since `sep_idx` is deliberately out of bounds when the
            # label set contains no separating character.
            chars.append(char)
            offsets.append(i)

        return ''.join(chars), torch.tensor(offsets, dtype=torch.int)

    def decode(self, emission, sizes=None):
        """
        Returns the argmax decoding given the emitted token probabilities.
        According to connectionist temporal classification (CTC), removes
        repeated elements in the decoded token sequence, as well as blanks.

        Parameters
        ----------
        emission (Tensor): shape (n_batch, n_frames, n_tokens), or
                           (n_frames, n_tokens) for a single unbatched item

        sizes (Tensor):    length in frames of each emission in batch

        Returns
        -------
        transcription (list[str]): string transcription for each item in batch

        offsets (list[Tensor]):    for each item in batch, an int tensor
                                   holding the frame index of each predicted
                                   character
        """
        if emission.ndim == 2:  # require shape (n_batch, n_frames, n_tokens)
            emission = emission.unsqueeze(0)

        # compute max-probability label at each sequence index
        max_probs = torch.argmax(emission, dim=-1)  # (n_batch, sequence_len)

        strings, offsets = self.convert_to_strings(max_probs,
                                                   sizes,
                                                   remove_repetitions=True,
                                                   return_offsets=True)
        return strings, offsets


class SpeechRecognitionModel(Model):
    """
    Wraps a speech recognition network together with a decoder, so that
    emitted token probabilities can be converted to string transcriptions.
    The wrapped network must expose its token labels, separator index, and
    blank index, either through `get_*()` methods or plain attributes.
    """

    def __init__(self,
                 model: nn.Module,
                 decoder: Decoder = None
                 ):
        """
        Parameters
        ----------
        model (nn.Module): network to wrap; it is switched to eval mode

        decoder (Decoder): converts model output to transcriptions. If None,
                           a `GreedyCTCDecoder` is built from the labels,
                           blank index, and separator index of `model`

        Raises
        ------
        ValueError: if `model` does not expose labels, separator index, or
                    blank index
        """
        super().__init__()

        self.model = model
        self.model.eval()

        # ensure that list of viable tokens, as well as blank and separator
        # tokens, can be retrieved from wrapped model
        self._get_labels_fn = self._build_accessor("get_labels", "labels")
        self._get_sep_fn = self._build_accessor("get_sep_idx", "sep_idx")
        self._get_blank_fn = self._build_accessor("get_blank_idx", "blank_idx")

        # initialize decoder
        if decoder is None:
            decoder = GreedyCTCDecoder(
                labels=self.get_labels(),
                blank_idx=self.get_blank_idx(),
                sep_idx=self.get_sep_idx()
            )
        self.decoder = decoder

        # translate characters to token indices
        self.char_to_idx = {l: i for i, l in enumerate(decoder.labels)}

    def _build_accessor(self, method_name: str, attr_name: str):
        """
        Return a zero-argument callable that reads a value from the wrapped
        model, preferring the method `method_name` over the attribute
        `attr_name`. The wrapped model is looked up at call time rather than
        captured, so the accessor follows `self.model` if it is reassigned.
        """
        if callable(getattr(self.model, method_name, None)):
            return lambda: getattr(self.model, method_name)()
        if getattr(self.model, attr_name, None) is not None:
            return lambda: getattr(self.model, attr_name)
        raise ValueError(f'Wrapped model must have method `.{method_name}()`'
                         f' or attribute `.{attr_name}`')

    def get_labels(self):
        """Retrieve a list of valid tokens"""
        return self._get_labels_fn()

    def get_blank_idx(self):
        """Return index of blank token"""
        return self._get_blank_fn()

    def get_sep_idx(self):
        """Return index of separator token"""
        return self._get_sep_fn()

    def forward(self, x: torch.Tensor):
        """Pass input through the wrapped model and return its output."""
        # call the module rather than `.forward()` so that any hooks
        # registered on the wrapped model run, as they do in `transcribe()`
        return self.model(x)

    def transcribe(self, x: torch.Tensor, return_alignment: bool = False):
        """
        Run the wrapped model on `x` and decode its output.

        Parameters
        ----------
        x (Tensor):              model input

        return_alignment (bool): if True, return the decoder's full output;
                                 otherwise return only its first element
                                 (the transcriptions, for `GreedyCTCDecoder`)
        """
        decoded = self.decoder(self.model(x))
        return decoded if return_alignment else decoded[0]

    def load_weights(self, path: str):
        """
        Load weights from checkpoint file into the wrapped model. Entries that
        are absent from the model, or whose shapes differ, are reported and
        skipped. If `path` is empty or does not point to a file, the model is
        left unchanged.

        Parameters
        ----------
        path (str): location of a checkpoint holding a mapping from parameter
                    names to tensors
        """
        if not path:
            return

        # check if file exists
        if not os.path.isfile(path):
            print("Checkpoint file not found: {}".format(path))
            return

        model_state = self.model.state_dict()

        # NOTE: `torch.load` unpickles the file; only load checkpoints from
        # trusted sources. Loading onto the CPU allows checkpoints saved on a
        # GPU to be read on CPU-only machines; `copy_` below moves each tensor
        # to the device of the corresponding model parameter.
        loaded_state = torch.load(path, map_location="cpu")

        with torch.no_grad():
            for name, param in loaded_state.items():

                if name not in model_state:
                    print("{} is not in the model.".format(name))
                    continue

                if model_state[name].size() != param.size():
                    print(
                        "Wrong parameter length: {}, model: {}, loaded: {}".format(
                            name,
                            model_state[name].size(),
                            param.size()
                        )
                    )
                    continue

                # state-dict tensors share storage with the model's
                # parameters, so an in-place copy updates the model
                model_state[name].copy_(param)

    def extract_features(
            self,
            x: torch.Tensor
    ) -> List[torch.Tensor]:
        """
        Extract deep features.

        :param x: input
        :return: a list of tensors holding intermediate activations / features;
                 empty if the wrapped model has no `extract_features` method
        """
        # look the method up explicitly rather than catching AttributeError
        # around the call, which would also hide attribute errors raised
        # inside the wrapped model's own feature extraction
        extractor = getattr(self.model, "extract_features", None)
        if not callable(extractor):
            return []
        return extractor(x)

    def _str_to_tensor(self, seq: str):
        """Convert a string to a 1D long tensor of token indices."""
        token_indices = [self.char_to_idx[c] for c in seq]
        return torch.as_tensor(token_indices, dtype=torch.long)

    def match_predict(self,
                      y_pred: Union[List[str], torch.Tensor],
                      y_true: Union[List[str], torch.Tensor]):
        """
        Determine whether (batched) target pairs are equivalent.

        Each prediction is compared to its target over the length of the
        target only. For targets given as strings this is the string length;
        for targets given as a tensor it is the full (padded) width.

        NOTE(review): because the comparison is truncated to the target
        length, a prediction that begins with the target but continues past
        it (e.g. 'cats' against 'cat') is reported as a match. Confirm this
        is intended before relying on exact-equality semantics.

        Parameters
        ----------
        y_pred (list[str] or Tensor): predicted transcriptions, as strings or
                                      as token indices of shape
                                      (n_batch, seq_len)

        y_true (list[str] or Tensor): target transcriptions, in either form

        Returns
        -------
        matches (Tensor): one boolean entry per batch item
        """
        n_batch = len(y_pred)

        y_true_lengths = None

        # convert ground-truth transcriptions to tensor form
        if isinstance(y_true, list):
            y_true = [self._str_to_tensor(t) for t in y_true]
            y_true_lengths = [t.shape[-1] for t in y_true]
            y_true = pad_sequence(
                y_true,
                batch_first=True
            )  # (n_batch, max_seq_len)

        if y_true_lengths is None:
            y_true_lengths = [y_true.shape[-1]] * n_batch

        # convert predicted transcriptions to tensor form
        if isinstance(y_pred, list):
            y_pred = [self._str_to_tensor(t) for t in y_pred]
            y_pred = pad_sequence(
                y_pred,
                batch_first=True
            )  # (n_batch, max_seq_len)

        # tensors built from strings live on the CPU; align devices so that
        # mixed string / tensor arguments can be compared
        y_pred = y_pred.to(y_true.device)

        # zero-pad predictions that are narrower than the targets
        length_diff = max(0, y_true.shape[-1] - y_pred.shape[-1])
        if length_diff:
            y_pred = F.pad(y_pred, (0, length_diff))

        matches = []
        for i in range(n_batch):
            true_len = y_true_lengths[i]
            matches.append(
                torch.all(
                    y_pred[i, ..., :true_len] == y_true[i, ..., :true_len]
                )
            )

        return torch.as_tensor(matches)