File size: 12,079 Bytes
85ba398
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import logging
from argparse import Namespace
from dataclasses import dataclass, field
from typing import Any, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import II, MISSING, open_dict

from fairseq import checkpoint_utils, tasks, utils
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.models import BaseFairseqModel, FairseqEncoder, register_model
from fairseq.models.wav2vec.wav2vec2 import MASKING_DISTRIBUTION_CHOICES, Wav2Vec2Config
from fairseq.models.wav2vec.wav2vec2_asr import Embedding, Linear, Wav2VecEncoder, Wav2Vec2AsrConfig
from fairseq.tasks import FairseqTask

# NOTE(review): configures the root logger at import time (DEBUG level) — a
# module-level side effect that affects the whole process; confirm this is
# intended for library code rather than a leftover from debugging.
logging.basicConfig(level=logging.DEBUG)


@dataclass
class Wav2Vec2ClassificationConfig(Wav2Vec2AsrConfig):
    """Configuration for wav2vec 2.0 sequence-classification models.

    Extends the ASR config with a pooling-head choice and an optional latent
    bottleneck dimension between the encoder and the classifier.
    """

    latent_embed_dim: Optional[int] = field(
        # Fix: the help string was missing its closing parenthesis.
        default=None,
        metadata={"help": "latent dim (encoder w2v -> latent -> class)"},
    )
    pooling: str = field(
        default="first_token",
        metadata={"help": "pooling layer choices"},
    )
    activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
        default="gelu", metadata={"help": "activation function to use"}
    )


@register_model("wav2vec_classification", dataclass=Wav2Vec2ClassificationConfig)
class Wav2VecClassification(BaseFairseqModel):
    """Sequence classifier built from a wav2vec 2.0 encoder plus a pooling head.

    TODO: Can be shared/merged with ASR model class as w2v_encoder params are
    common.
    """

    def __init__(
        self,
        cfg: Wav2Vec2ClassificationConfig,
        w2v_encoder: BaseFairseqModel,
        pooling_layer,
    ):
        super().__init__()
        self.cfg = cfg
        self.w2v_encoder = w2v_encoder
        self.pooling_layer = pooling_layer

    def upgrade_state_dict_named(self, state_dict, name):
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, cfg: Wav2Vec2ClassificationConfig, task: FairseqTask):
        """Build a new model instance."""
        encoder = Wav2VecEncoder(cfg, None)
        transformer_layers = encoder.w2v_model.encoder.layers
        head = get_pooling_layer(
            cfg,
            transformer_layers[-1].embedding_dim,
            len(task.target_dictionary),
            len(transformer_layers),
        )
        return cls(cfg, encoder, head)

    def get_normalized_probs(self, net_output, log_probs):
        """Get normalized probabilities (or log probs) from a net's output."""
        normalize = utils.log_softmax if log_probs else utils.softmax
        return normalize(net_output.float(), dim=-1)

    def get_logits(self, net_output):
        # The pooling head already emits raw logits.
        return net_output

    def forward(self, **kwargs):
        encoder_output = self.w2v_encoder(**kwargs)
        # encoder_out: TxBxC features; padding_mask: BxT.
        return self.pooling_layer(
            last_layer_feats=encoder_output["encoder_out"],
            padding_mask=encoder_output["padding_mask"],
        )


def get_pooling_layer(
    cfg: Wav2Vec2ClassificationConfig,
    encoder_embed_dim: int,
    num_targets: int,
    encoder_layers: int,
):
    """Instantiate the pooling head selected by ``cfg.pooling``.

    Args:
        cfg: model configuration (reads ``cfg.pooling``).
        encoder_embed_dim: feature dimension of the encoder's last layer.
        num_targets: number of classification targets.
        encoder_layers: number of encoder layers (used by the "elmo" head).

    Raises:
        NotImplementedError: for an unknown ``cfg.pooling`` value.
    """
    # Bug fix: a leftover debug `assert cfg.pooling == 'mean'` made every
    # non-"mean" branch below unreachable; it has been removed.
    if cfg.pooling == "first_token":
        return FirstToken(cfg, encoder_embed_dim, num_targets)
    elif cfg.pooling == "mean":
        return MeanPoolingFast(cfg, encoder_embed_dim, num_targets)
    elif cfg.pooling == "mean_amsoftmax":
        return MeanPoolingFastAMSoftmax(cfg, encoder_embed_dim, num_targets)
    elif cfg.pooling == "max":
        return MaxPoolingFast(cfg, encoder_embed_dim, num_targets)
    elif cfg.pooling == "elmo":
        return LayerWeightedMeanPooling(
            cfg, encoder_embed_dim, num_targets, encoder_layers
        )
    else:
        raise NotImplementedError(f"{cfg.pooling} has not been implemented yet.")


class Pooling(nn.Module):
    """Base class for pooling heads: owns the final projection to targets."""

    def __init__(
        self,
        cfg: Wav2Vec2ClassificationConfig,
        encoder_embed_dim: int,
        num_targets: int,
    ):
        super().__init__()
        self.projection = Linear(encoder_embed_dim, num_targets)

    def forward(self, last_layer_feats, **kwargs):
        """Subclasses must pool features over time and project to logits."""
        raise NotImplementedError()


class FirstToken(Pooling):
    """Pool by taking the features of the first token only (CLS-style)."""

    # The original defined a pass-through __init__ that only forwarded to the
    # parent; it added nothing, so Pooling.__init__ is inherited directly.

    def forward(self, last_layer_feats, **kwargs):
        # NOTE(review): `[:, 0]` selects index 0 on the SECOND axis, which is
        # correct for [BxTxD] input; other heads in this file receive [TxBxD]
        # from the model's forward — confirm the expected layout for this head.
        return self.projection(last_layer_feats[:, 0])


# class MeanPooling(Pooling):
#     def __init__(
#         self,
#         cfg: Wav2Vec2ClassificationConfig,
#         encoder_embed_dim: int,
#         num_targets: int,
#         **kwargs,
#     ):
#         super().__init__(cfg, encoder_embed_dim, num_targets)
#         self.activation_fn = utils.get_activation_fn(cfg.activation_fn)
#         self.linear = Linear(encoder_embed_dim, encoder_embed_dim)

#     def forward(self, last_layer_feats, padding_mask, **kwargs):
#         # last_layer_feats: [BxTxD]
#         # padding_mask: [BxT]
#         last_layer_feats = self.linear(self.activation_fn(last_layer_feats))
#         input_lengths = (1 - padding_mask.long()).sum(-1)
#         pooled_feature_list = []
#         for i in range(len(last_layer_feats)):
#             length = input_lengths[i]
#             pooled_feature = torch.mean(last_layer_feats[i][:length], dim=0)
#             pooled_feature_list.append(pooled_feature)
#         return self.projection(torch.stack(pooled_feature_list))


def fn_mean(x, mask):
    """Mean over the time axis, optionally excluding masked positions.

    Args:
        x: TxBxD
        mask: BxT (nonzero on frames to keep), or None for a plain mean
    Return:
        y: BxD
    """
    if mask is None:
        # No mask: average uniformly over all T frames.
        return x.sum(0) / x.shape[0]
    weights = mask.t().unsqueeze(-1)  # BxT -> TxBx1
    return (x * weights).sum(0) / weights.sum(0)


class MeanPoolingFast(nn.Module):
    """Mean-pool encoder features over time, then project to target logits.

    Features are first mapped to a latent space (``latent_embed_dim``),
    mean-pooled over time with padded frames excluded, then passed through
    the activation and the final projection.
    """

    def __init__(
        self,
        cfg: Wav2Vec2ClassificationConfig,
        encoder_embed_dim: int,
        num_targets: int,
        **kwargs,
    ):
        super().__init__()
        self.activation_fn = utils.get_activation_fn(cfg.activation_fn)
        # Fall back to the encoder width when no latent size is configured.
        self.latent_embed_dim = (
            cfg.latent_embed_dim
            if cfg.latent_embed_dim is not None
            else encoder_embed_dim
        )
        logging.debug(f"| {self.latent_embed_dim=}")
        self.linear = Linear(encoder_embed_dim, self.latent_embed_dim)
        self.projection = Linear(self.latent_embed_dim, num_targets)

    def forward(self, last_layer_feats, padding_mask, **kwargs):
        """
        Arguments
            last_layer_feats - [TxBxD] acoustic features
            padding_mask - [BxT] padding mask, or None
        Returns [B x num_targets] logits.
        """
        # Refactor: the masking/linear/mean sequence was duplicated between
        # forward() and forward_latent(); delegate so it lives in one place.
        feat = self.forward_latent(last_layer_feats, padding_mask, **kwargs)
        return self.projection(self.activation_fn(feat))

    def forward_latent(self, last_layer_feats, padding_mask, **kwargs):
        """Return pooled latent features [B x latent_embed_dim], without the
        activation or the final projection.
        """
        if padding_mask is not None:
            # Invert so valid frames are 1.0 and padded frames are 0.0.
            feat_mask = (~padding_mask).to(last_layer_feats.dtype)
        else:
            feat_mask = None
        feat = self.linear(last_layer_feats)
        return fn_mean(feat, feat_mask)


class MeanPoolingFastAMSoftmax(MeanPoolingFast):
    """Mean pooling with a cosine-similarity (AM-Softmax style) output head.

    Returns cosine similarities between the L2-normalized pooled feature and
    L2-normalized projection weights instead of affine logits.
    """

    def __init__(
        self,
        cfg: Wav2Vec2ClassificationConfig,
        encoder_embed_dim: int,
        num_targets: int,
        **kwargs,
    ):
        super().__init__(cfg, encoder_embed_dim, num_targets, **kwargs)
        # Bias-free projection: its weights act as class directions for the
        # cosine scoring below.
        self.projection = Linear(self.latent_embed_dim, num_targets, bias=False)
        nn.init.xavier_normal_(self.projection.weight, gain=1)

    def forward(self, last_layer_feats, padding_mask, **kwargs):
        """
        Arguments
            last_layer_feats - [TxBxD] acoustic features
            padding_mask - [BxT] padding mask, or None
        Returns [B x num_targets] cosine similarities in [-1, 1].
        """
        # Fix: tolerate padding_mask=None like the parent class does; the
        # original negated it unconditionally and crashed on None.
        if padding_mask is not None:
            feat_mask = (~padding_mask).to(last_layer_feats.dtype)
        else:
            feat_mask = None
        feat = self.linear(last_layer_feats)
        feat = fn_mean(feat, feat_mask)  # B,D
        feat = self.activation_fn(feat)
        # Normalize pooled features and class weights, then take the dot
        # product to get cosine similarities.
        feat_norm = F.normalize(feat, p=2, dim=-1)  # B,D
        # NOTE(review): normalizing the transposed weight (D,K) along dim=-1
        # normalizes across classes rather than per class vector — verify this
        # matches the intended AM-softmax formulation.
        weight_norm = F.normalize(self.projection.weight.t(), p=2, dim=-1)  # D,K
        cos_fw = feat_norm @ weight_norm
        return cos_fw


def fn_max(x, mask):
    """Max over the time axis, excluding masked positions.

    Args:
        x: TxBxD
        mask: BxT (nonzero on frames to keep)
    Return:
        y: BxD
    """
    mask = mask.t()[:, :, None].to(torch.bool)
    # Bug fix: padded positions were filled with -1e-8 — a *tiny* negative
    # number that would win the max over any genuinely negative feature.
    # Fill with -inf so padded frames can never be selected.
    return x.masked_fill(~mask, float("-inf")).max(0)[0]


class MaxPoolingFast(Pooling):
    """Max-pool encoder features over time (padding excluded), then project."""

    def __init__(
        self,
        cfg: Wav2Vec2ClassificationConfig,
        encoder_embed_dim: int,
        num_targets: int,
        **kwargs,
    ):
        super().__init__(cfg, encoder_embed_dim, num_targets)
        self.activation_fn = utils.get_activation_fn(cfg.activation_fn)
        self.linear = Linear(encoder_embed_dim, encoder_embed_dim)

    def forward(self, last_layer_feats, padding_mask, **kwargs):
        """
        Arguments
            last_layer_feats - [TxBxD] acoustic features
            padding_mask - [BxT] padding mask, or None
        Returns [B x num_targets] logits.
        """
        # Fix: tolerate padding_mask=None (MeanPoolingFast already does). Use
        # an all-ones keep-mask so fn_max considers every frame.
        if padding_mask is not None:
            feat_mask = (~padding_mask).to(last_layer_feats.dtype)
        else:
            t_len, batch = last_layer_feats.shape[0], last_layer_feats.shape[1]
            feat_mask = last_layer_feats.new_ones(batch, t_len)
        feat = self.linear(last_layer_feats)
        feat = fn_max(feat, feat_mask)
        feat = self.activation_fn(feat)
        return self.projection(feat)


class LayerWeightedMeanPooling(MeanPoolingFast):
    """Elmo-style weighted average representation.

    Learns one scalar weight per encoder layer, softmax-normalizes the
    weights, forms the weighted average of all layers' features, and then
    mean-pools the result through the parent class's forward.
    """

    def __init__(
        self,
        cfg: Wav2Vec2ClassificationConfig,
        encoder_embed_dim: int,
        num_targets: int,
        encoder_layers: int,
    ):
        super().__init__(cfg, encoder_embed_dim, num_targets)
        self.num_layers = encoder_layers
        # One learnable mixing weight per layer, initialized uniformly.
        self.weights = nn.Parameter(torch.ones(encoder_layers))

    def forward(self, last_layer_feats, padding_mask, all_layer_feats):
        # last_layer_feats: [BxTxD]
        # padding_mask: [BxT]
        # NOTE(review): last_layer_feats is unused here — pooling acts on
        # all_layer_feats only. Also, the layer-count sanity check below runs
        # only in eval mode (`not self.training`); confirm that is intentional.
        if not self.training:
            msg = (
                f"Number of layers in input features = {len(all_layer_feats)}."
                f" Expected {self.num_layers} layers."
            )
            assert len(all_layer_feats) == self.num_layers, msg

        # Stack up all layers and reshape to (num_layers, features)
        all_layer_feats_stacked = torch.stack(all_layer_feats, dim=0)
        num_layers, *original_feat_shape = all_layer_feats_stacked.shape
        all_layer_feats_stacked_flat = all_layer_feats_stacked.view(num_layers, -1)

        # Weighted average: softmax keeps the per-layer mixing weights on the
        # simplex, then each flattened layer is scaled and summed.
        normalized_weights = F.softmax(self.weights, dim=-1)
        weighted_avg_features = (
            normalized_weights.unsqueeze(-1) * all_layer_feats_stacked_flat
        ).sum(dim=0)
        weighted_avg_features = weighted_avg_features.view(*original_feat_shape)

        # Mean Pooling on weighted average features.
        return super().forward(weighted_avg_features, padding_mask)