File size: 17,665 Bytes
85ba398
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
#!/usr/bin/env python3

import math

import torch
import torch.nn as nn

from fairseq.data.data_utils import compute_mask_indices
from fairseq.models import FairseqEncoder
from fairseq.models.wav2vec import ConvFeatureExtractionModel
from fairseq.modules import GradMultiply, LayerNorm, SamePad, TransformerEncoderLayer


#   Transformer encoder with waveform input, adapted from the wav2vec 2.0 encoder.
#       uses wav input
#       uses trained position embedding so it is easier to match with text input
class SpeechWavTransformerEncoder(FairseqEncoder):

    # extra parameters for speech encoder besides those defined in transformermodel
    @staticmethod
    def add_args(parser):
        """Register the speech-encoder command-line options on ``parser``.

        The options are described as ``(flag, keyword-arguments)`` pairs and
        handed to ``parser.add_argument`` one by one, in declaration order.
        """
        option_table = [
            # dropout applied after the feature extractor
            (
                "--dropout-input",
                {
                    "type": float,
                    "metavar": "D",
                    "help": "dropout to apply to the input (after feat extr)",
                },
            ),
            (
                "--dropout-features",
                {
                    "type": float,
                    "metavar": "D",
                    "help": "dropout to apply to the unmasked features (after feat extr)",
                },
            ),
            # convolutional feature extractor
            (
                "--speech-extractor-mode",
                {
                    "type": str,
                    "default": "layer_norm",
                    "choices": ["default", "layer_norm"],
                    "help": "feature extractor norm",
                },
            ),
            (
                "--speech-conv-bias",
                {
                    "action": "store_true",
                    "help": "include bias in speech conv encoder",
                },
            ),
            (
                "--conv-feature-layers",
                {
                    "default": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
                    "help": "string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]",
                },
            ),
            # masking along the time axis
            (
                "--speech-mask-length",
                {
                    "type": int,
                    "help": "repeat the mask indices multiple times",
                },
            ),
            (
                "--speech-mask-prob",
                {
                    "type": float,
                    "help": "probability of replacing a token with mask",
                },
            ),
            (
                "--speech-mask-selection",
                {
                    "type": str,
                    "choices": ["static", "uniform", "normal", "poisson"],
                    "help": "how to choose masks",
                },
            ),
            (
                "--speech-mask-other",
                {
                    "type": float,
                    "help": "stdev of the mask length in case of 'normal' selection strategy",
                },
            ),
            (
                "--speech-no-mask-overlap",
                {
                    "action": "store_true",
                    "help": "whether to allow masks to overlap",
                },
            ),
            (
                "--speech-mask-min-space",
                {
                    "type": int,
                    "help": "min space between spans (if no overlap is enabled)",
                },
            ),
            # masking along the channel axis
            (
                "--speech-mask-channel-length",
                {
                    "type": int,
                    "help": "repeat the mask indices multiple times",
                },
            ),
            (
                "--speech-mask-channel-prob",
                {
                    "type": float,
                    "help": "probability of replacing a token with mask",
                },
            ),
            (
                "--speech-mask-channel-selection",
                {
                    "type": str,
                    "choices": ["static", "uniform", "normal", "poisson"],
                    "help": "how to choose masks",
                },
            ),
            (
                "--speech-mask-channel-other",
                {
                    "type": float,
                    "help": "stdev of the mask length in case of 'normal' selection strategy",
                },
            ),
            (
                "--speech-no-mask-channel-overlap",
                {
                    "action": "store_true",
                    "help": "whether to allow masks to overlap",
                },
            ),
            (
                "--no-scale-feature",
                {
                    "action": "store_true",
                    "help": "no scale for the calculated features",
                },
            ),
            (
                "--speech-mask-channel-min-space",
                {
                    "type": int,
                    "help": "min space between spans (if no overlap is enabled)",
                },
            ),
            (
                "--feature-grad-mult",
                {
                    "type": float,
                    "help": "reset feature grad mult in wav2vec 2.0 to this",
                },
            ),
            # positional embeddings
            (
                "--conv-pos",
                {
                    "type": int,
                    "default": 128,
                    "help": "number of filters for convolutional positional embeddings",
                },
            ),
            (
                "--conv-pos-groups",
                {
                    "type": int,
                    "default": 16,
                    "help": "number of groups for convolutional positional embedding",
                },
            ),
            # model configures
            (
                "--speech-encoder-layers",
                {
                    "type": int,
                    "help": "number of speech encoder layers",
                },
            ),
            (
                "--text-encoder-layers",
                {
                    "type": int,
                    "help": "number of text encoder layers",
                },
            ),
        ]

        for flag, settings in option_table:
            parser.add_argument(flag, **settings)

    def __init__(self, args, alway_mask=False):
        """Build the conv feature extractor, positional conv and transformer stack.

        Args:
            args: hyper-parameter namespace. The attributes read here are
                ``dropout``, ``encoder_embed_dim``, ``no_scale_feature``,
                ``conv_feature_layers``, ``speech_extractor_mode``,
                ``speech_conv_bias``, ``conv_pos``, ``conv_pos_groups``, the
                ``speech_mask_*`` / ``speech_no_mask_*`` options,
                ``dropout_input``, ``dropout_features``,
                ``feature_grad_mult``, ``encoder_layers`` and
                ``encoder_normalize_before``.
            alway_mask: when true, ``forward`` applies masking even when the
                module is not in training mode.
        """
        super().__init__(args)
        self.args = args
        self.dropout = args.dropout
        self.embedding_dim = args.encoder_embed_dim
        # features are scaled by sqrt(embed_dim) unless --no-scale-feature is set
        self.feat_scale = math.sqrt(args.encoder_embed_dim)
        if args.no_scale_feature:
            self.feat_scale = 1.0

        # NOTE(security): the layer description is a Python *expression* (the
        # default value relies on list "+" and "*"), so it has to go through
        # eval() and must only ever come from a trusted configuration.  It is
        # evaluated exactly once and the result is shared below.
        self.feature_enc_layers = eval(args.conv_feature_layers)
        self.subsample = ConvFeatureExtractionModel(
            conv_layers=self.feature_enc_layers,
            dropout=0.0,
            mode=args.speech_extractor_mode,  # default, layer_norm
            conv_bias=args.speech_conv_bias,
        )

        # channel count of the last conv layer: first element of its spec
        final_conv_dim = self.feature_enc_layers[-1][0]
        # project to the transformer width only when the widths differ
        self.feat_proj = (
            nn.Linear(final_conv_dim, self.embedding_dim)
            if final_conv_dim != self.embedding_dim
            else None
        )

        self.feat_layer_norm = LayerNorm(final_conv_dim)

        # convolutional positional embedding: grouped Conv1d with
        # weight normalisation, followed by SamePad and GELU
        self.embed_positions = nn.Conv1d(
            self.embedding_dim,
            self.embedding_dim,
            kernel_size=args.conv_pos,
            padding=args.conv_pos // 2,
            groups=args.conv_pos_groups,
        )
        std = math.sqrt(4 / (args.conv_pos * self.embedding_dim))
        nn.init.normal_(self.embed_positions.weight, mean=0, std=std)
        nn.init.constant_(self.embed_positions.bias, 0)

        self.embed_positions = nn.utils.weight_norm(
            self.embed_positions, name="weight", dim=2
        )
        self.embed_positions = nn.Sequential(
            self.embed_positions, SamePad(args.conv_pos), nn.GELU()
        )

        # time-axis masking configuration (consumed by apply_mask)
        self.mask_prob = args.speech_mask_prob
        self.mask_selection = args.speech_mask_selection
        self.mask_other = args.speech_mask_other
        self.mask_length = args.speech_mask_length
        self.no_mask_overlap = args.speech_no_mask_overlap
        self.mask_min_space = args.speech_mask_min_space

        # channel-axis masking configuration (consumed by apply_mask)
        self.mask_channel_prob = args.speech_mask_channel_prob
        self.mask_channel_selection = args.speech_mask_channel_selection
        self.mask_channel_other = args.speech_mask_channel_other
        self.mask_channel_length = args.speech_mask_channel_length
        self.no_mask_channel_overlap = args.speech_no_mask_channel_overlap
        self.mask_channel_min_space = args.speech_mask_channel_min_space

        self.dropout_input = nn.Dropout(args.dropout_input)
        self.dropout_features = nn.Dropout(args.dropout_features)

        self.feature_grad_mult = args.feature_grad_mult

        # learned vector written into time-masked positions by apply_mask
        self.mask_emb = nn.Parameter(
            torch.FloatTensor(args.encoder_embed_dim).uniform_()
        )

        self.layers = nn.ModuleList(
            [TransformerEncoderLayer(args) for _ in range(args.encoder_layers)]
        )
        self.layer_norm = LayerNorm(args.encoder_embed_dim)
        # selects whether layer_norm runs after (True) or before (False)
        # the transformer layers in forward
        self.normalize_before = args.encoder_normalize_before
        self.alway_mask = alway_mask

    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
        """
        Map input lengths to their lengths after the convolutional layers.

        Each entry of ``self.feature_enc_layers`` contributes one step
        ``floor((length - kernel_size) / stride + 1)``, with the entry's
        second and third elements used as kernel size and stride.  The final
        result is cast to ``torch.long``.
        """
        lengths = input_lengths
        for layer_spec in self.feature_enc_layers:
            kernel_size, stride = layer_spec[1], layer_spec[2]
            lengths = torch.floor((lengths - kernel_size) / stride + 1)
        return lengths.to(torch.long)

    def apply_mask(self, x, padding_mask):
        """Mask ``x`` in place along the time and/or channel axis.

        ``x`` is unpacked as a 3-D tensor (B x T x C).  When
        ``self.mask_prob`` is positive, a B x T mask is drawn with
        ``compute_mask_indices`` and the selected positions are overwritten
        with ``self.mask_emb``.  When ``self.mask_channel_prob`` is positive,
        a B x C mask is drawn, expanded over all T steps, and the selected
        entries are set to zero.

        Returns:
            ``(x, mask_indices)``: the (modified) input and the time mask as
            a tensor on ``x``'s device, or ``None`` when no time mask was
            drawn.
        """
        batch_size, num_steps, num_channels = x.shape

        mask_indices = None
        if self.mask_prob > 0:
            time_mask = compute_mask_indices(
                (batch_size, num_steps),
                padding_mask,
                self.mask_prob,
                self.mask_length,
                self.mask_selection,
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space,
            )
            # compute_mask_indices hands back a numpy array
            mask_indices = torch.from_numpy(time_mask).to(x.device)
            x[mask_indices] = self.mask_emb

        if self.mask_channel_prob > 0:
            channel_mask = compute_mask_indices(
                (batch_size, num_channels),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            channel_mask = torch.from_numpy(channel_mask).to(x.device)
            # B x C -> B x T x C: the same channels are masked at every step
            channel_mask = channel_mask.unsqueeze(1).expand(-1, num_steps, -1)
            x[channel_mask] = 0

        return x, mask_indices

    def forward(
        self,
        src_tokens,
        src_lengths,
        return_all_hiddens=False,
        padding_mask=None,
        features_only=True,
    ):
        """Encode the input with the conv extractor and the transformer stack.

        Args:
            src_tokens: input handed to ``self.subsample`` (presumably the
                raw waveform batch -- confirm against the caller).
            src_lengths: input lengths; only used when ``padding_mask`` is
                ``None``.
            return_all_hiddens: when true, the output of every transformer
                layer is collected in ``"encoder_states"``.
            padding_mask: optional input-level mask; lengths are derived as
                the per-row count of zero entries.
            features_only: when true, skip the extra pass over the unmasked
                features.

        Returns:
            A dict with ``"encoder_out"`` (T x B x C),
            ``"encoder_padding_mask"`` (B x T), ``"encoder_states"`` and
            ``"mask_indices"``.  With ``features_only`` the ``"mask_indices"``
            entry is always a one-element list (holding ``None`` when no mask
            was applied); otherwise it is empty when no mask was applied and
            ``"encoder_unmasked_out"`` is added.
        """

        def run_transformer(inputs, encoder_padding_mask, collect_states=False):
            # inputs: B x T x C
            pos = self.embed_positions(inputs.transpose(1, 2)).transpose(1, 2)
            hidden = inputs + pos
            if not self.normalize_before:
                hidden = self.layer_norm(hidden)

            # B x T x C -> T x B x C
            hidden = hidden.transpose(0, 1)
            states = []
            for layer in self.layers:
                hidden = layer(hidden, encoder_padding_mask)
                if collect_states:
                    states.append(hidden)
            if self.normalize_before:
                hidden = self.layer_norm(hidden)
            return hidden, states

        apply_masking = self.training or self.alway_mask

        # the extractor only receives gradients while training with a
        # positive feature_grad_mult; otherwise it runs under no_grad
        extractor_trainable = self.feature_grad_mult > 0 and self.training
        if extractor_trainable:
            features = self.subsample(src_tokens)
            if self.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            with torch.no_grad():
                features = self.subsample(src_tokens)

        features = self.feat_layer_norm(features.transpose(1, 2))
        if self.feat_proj is not None:
            features = self.feat_proj(features)

        if padding_mask is None:
            input_lengths = src_lengths
        else:
            input_lengths = (1 - padding_mask.long()).sum(-1)
        # apply conv formula to get real output_lengths
        output_lengths = self._get_feat_extract_output_lengths(input_lengths)

        # Put a single 1 at the last valid frame of every row; the reversed
        # cumulative sum then marks every frame up to and including it, and
        # the complement is the frame-level padding mask.
        last_frame = torch.zeros(
            features.shape[:2], dtype=features.dtype, device=features.device
        )
        row_index = torch.arange(last_frame.shape[0], device=last_frame.device)
        last_frame[(row_index, output_lengths - 1)] = 1
        padding_mask = (1 - last_frame.flip([-1]).cumsum(-1).flip([-1])).bool()

        if self.feat_scale != 1.0:
            features = self.feat_scale * features
        # keep a copy because apply_mask writes into its input
        unmasked_features = features.clone()

        features = self.dropout_input(features)
        unmasked_features = self.dropout_features(unmasked_features)

        x = features
        mask_indices = None
        if apply_masking:
            x, mask_indices = self.apply_mask(features, padding_mask)

        x, encoder_states = run_transformer(x, padding_mask, return_all_hiddens)
        padding_mask_out = [padding_mask] if padding_mask is not None else []  # B x T

        if features_only:
            return {
                "encoder_out": [x],  # [T x B x C]
                "encoder_padding_mask": padding_mask_out,
                "encoder_embedding": [],  #
                "encoder_states": encoder_states,  # List[T x B x C]
                "src_tokens": [],
                "src_lengths": [],
                "mask_indices": [mask_indices],
            }

        # second pass over the unmasked features whenever masking is configured
        if self.mask_prob > 0 or self.mask_channel_prob > 0:
            x_unmasked, _ = run_transformer(unmasked_features, padding_mask)
        else:
            x_unmasked = x

        mask_indices_out = [] if mask_indices is None else [mask_indices]  # B X T
        return {
            "encoder_out": [x],  # [T x B x C]
            "encoder_unmasked_out": [x_unmasked],  # [T x B x C]
            "encoder_padding_mask": padding_mask_out,
            "encoder_embedding": [],  #
            "encoder_states": encoder_states,  # List[T x B x C]
            "src_tokens": [],
            "src_lengths": [],
            "mask_indices": mask_indices_out,
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """Reorder the batch dimension of ``encoder_out`` following ``new_order``.

        ``"encoder_out"`` and ``"encoder_states"`` entries are selected along
        dimension 1, ``"encoder_padding_mask"`` and ``"encoder_embedding"``
        along dimension 0.  The ``"encoder_states"`` list of the input dict
        is updated in place and reused in the result.  ``"src_tokens"`` and
        ``"src_lengths"`` are returned empty; other keys of the input are
        not carried over.
        """

        def gather(key, dim):
            # an empty list stays empty
            return [t.index_select(dim, new_order) for t in encoder_out[key]]

        reordered_out = gather("encoder_out", 1)
        reordered_padding_mask = gather("encoder_padding_mask", 0)
        reordered_embedding = gather("encoder_embedding", 0)

        states = encoder_out["encoder_states"]
        for position, hidden in enumerate(states):
            states[position] = hidden.index_select(1, new_order)

        return {
            "encoder_out": reordered_out,  # T x B x C
            "encoder_padding_mask": reordered_padding_mask,  # B x T
            "encoder_embedding": reordered_embedding,  # B x T x C
            "encoder_states": states,  # List[T x B x C]
            "src_tokens": [],  # B x T
            "src_lengths": [],  # B x 1
        }


class StackedSpeechWavTransformerEncoder(FairseqEncoder):
    """Speech encoder followed by a stack of additional encoder layers.

    The output of ``speech_enc`` is passed through every layer in
    ``text_enc_layers`` and then through ``text_layer_norm`` when that is
    not ``None``.
    """

    def __init__(self, speech_enc, text_enc_layers, text_layer_norm):
        super().__init__(None)
        self.speech_encoder = speech_enc
        self.text_encoder_layers = text_enc_layers
        self.final_layer_norm = text_layer_norm

    def forward(
        self,
        src_tokens,
        src_lengths=None,
        return_all_hiddens=False,
        padding_mask=None,
        features_only=True,
    ):
        """Run the speech encoder, then the stacked layers on its output.

        All arguments are forwarded unchanged to the speech encoder.  With
        ``features_only`` false, the speech encoder's
        ``"encoder_unmasked_out"`` is pushed through the stacked layers as
        well and its ``"mask_indices"`` are passed along in the result.
        """

        def run_text_layers(hidden, mask, collect_states=False):
            states = []
            for text_layer in self.text_encoder_layers:
                hidden = text_layer(hidden, mask)
                if collect_states:
                    states.append(hidden)
            if self.final_layer_norm is not None:
                hidden = self.final_layer_norm(hidden)
            return hidden, states

        speech_out = self.speech_encoder.forward(
            src_tokens,
            src_lengths,
            return_all_hiddens,
            padding_mask=padding_mask,
            features_only=features_only,
        )
        x = speech_out["encoder_out"][0]
        speech_masks = speech_out["encoder_padding_mask"]
        encoder_padding_mask = speech_masks[0] if len(speech_masks) > 0 else None

        x, encoder_states = run_text_layers(
            x, encoder_padding_mask, return_all_hiddens
        )
        padding_mask_out = (
            [] if encoder_padding_mask is None else [encoder_padding_mask]
        )  # B x T

        if features_only:
            return {
                "encoder_out": [x],  # T x B x C
                "encoder_padding_mask": padding_mask_out,
                "encoder_embedding": [],  # B x T x C
                "encoder_states": encoder_states,  # List[T x B x C]
                "src_tokens": [],
                "src_lengths": [],
            }

        x_u, _ = run_text_layers(
            speech_out["encoder_unmasked_out"][0], encoder_padding_mask
        )

        return {
            "encoder_out": [x],  # [T x B x C]
            "encoder_unmasked_out": [x_u],  # [T x B x C]
            "encoder_padding_mask": padding_mask_out,
            "encoder_embedding": [],  #
            "encoder_states": encoder_states,  # List[T x B x C]
            "src_tokens": [],
            "src_lengths": [],
            "mask_indices": speech_out["mask_indices"],  # B X T
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """Delegate batch reordering to the wrapped speech encoder."""
        return self.speech_encoder.reorder_encoder_out(encoder_out, new_order)