File size: 6,619 Bytes
6789f6f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from functools import partial
import torch
import torch.nn as nn
import numpy as np
from dataclasses import dataclass, field
from typing import Callable, Dict, Optional
from fairseq.models.wav2vec import ConvFeatureExtractionModel
from fairseq.modules import (
LayerNorm,
SamePad,
TransposeLast,
)
from fairseq.tasks import FairseqTask
from .base import D2vModalityConfig, ModalitySpecificEncoder, get_alibi_bias
from .modules import BlockEncoder, Decoder1d
from examples.data2vec.data.modality import Modality
@dataclass
class D2vAudioConfig(D2vModalityConfig):
    """Configuration for the audio-modality encoder (``AudioEncoder``)."""

    type: Modality = Modality.AUDIO
    extractor_mode: str = field(
        default="layer_norm",
        metadata={"help": "mode passed to ConvFeatureExtractionModel"},
    )
    feature_encoder_spec: str = field(
        default="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
        metadata={
            "help": "string describing convolutional feature extraction layers in form of a python list that contains "
            "[(dim, kernel_size, stride), ...]"
        },
    )
    conv_pos_width: int = field(
        default=95,
        metadata={
            "help": "total kernel width of the convolutional positional encoder; "
            "each of the conv_pos_depth layers uses kernel size "
            "max(3, conv_pos_width // conv_pos_depth)"
        },
    )
    conv_pos_groups: int = field(
        default=16,
        metadata={"help": "number of groups for convolutional positional embedding"},
    )
    conv_pos_depth: int = field(
        default=5,
        metadata={"help": "depth of positional encoder network"},
    )
    conv_pos_pre_ln: bool = field(
        default=False,
        metadata={
            "help": "apply a LayerNorm before the convolutional positional encoder"
        },
    )
class AudioEncoder(ModalitySpecificEncoder):
    """Modality-specific encoder for audio input.

    ``__init__`` assembles the components below and hands them to
    ``ModalitySpecificEncoder``:

    * ``local_encoder``: a ``ConvFeatureExtractionModel`` built from
      ``modality_cfg.feature_encoder_spec``;
    * ``project_features``: transpose, LayerNorm, then a linear map from the
      last conv layer's channel count to ``embed_dim``;
    * ``relative_positional_encoder``: ``conv_pos_depth`` grouped 1-D conv
      layers, optionally preceded by a LayerNorm;
    * ``context_encoder``: a ``BlockEncoder`` of ``prenet_depth`` blocks whose
      drop-path rates are spaced linearly between the start and end rates;
    * ``decoder``: an optional ``Decoder1d``.
    """

    modality_cfg: D2vAudioConfig

    def __init__(
        self,
        modality_cfg: D2vAudioConfig,
        embed_dim: int,
        make_block: Callable[[float], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases: Dict,
        task: Optional[FairseqTask],
    ):
        """Build the audio encoder components.

        Args:
            modality_cfg: audio modality configuration.
            embed_dim: width of the projected features and positional/context
                encoders.
            make_block: factory called with a drop-path rate to build one
                context-encoder block.
            norm_layer: factory called with ``embed_dim``; its result is passed
                to ``BlockEncoder`` only when ``layer_norm_first`` is False.
            layer_norm_first: forwarded to ``BlockEncoder``.
            alibi_biases: bound into ``get_alibi_bias`` via ``functools.partial``.
            task: not used by this constructor.
        """
        # NOTE(review): ``eval`` executes arbitrary Python taken from the
        # config. The spec relies on list arithmetic (``+`` / ``*``), which
        # ``ast.literal_eval`` cannot parse, so only load trusted configs.
        self.feature_enc_layers = eval(modality_cfg.feature_encoder_spec)
        # Channel count of the last conv layer: input width of the projection.
        feature_embed_dim = self.feature_enc_layers[-1][0]

        local_encoder = ConvFeatureExtractionModel(
            conv_layers=self.feature_enc_layers,
            dropout=0.0,
            mode=modality_cfg.extractor_mode,
            conv_bias=False,
        )

        project_features = nn.Sequential(
            TransposeLast(),
            nn.LayerNorm(feature_embed_dim),
            nn.Linear(feature_embed_dim, embed_dim),
        )

        # conv_pos_width is a total width shared across all positional layers.
        # conv_pos_depth must be >= 1 because it is used as a divisor here.
        num_pos_layers = modality_cfg.conv_pos_depth
        k = max(3, modality_cfg.conv_pos_width // num_pos_layers)

        positional_encoder = nn.Sequential(
            TransposeLast(),
            *[
                nn.Sequential(
                    nn.Conv1d(
                        embed_dim,
                        embed_dim,
                        kernel_size=k,
                        padding=k // 2,
                        groups=modality_cfg.conv_pos_groups,
                    ),
                    # Presumably trims the extra frame an even kernel produces
                    # with padding=k // 2 -- see fairseq SamePad.
                    SamePad(k),
                    TransposeLast(),
                    LayerNorm(embed_dim, elementwise_affine=False),
                    TransposeLast(),
                    nn.GELU(),
                )
                for _ in range(num_pos_layers)
            ],
            TransposeLast(),
        )

        if modality_cfg.conv_pos_pre_ln:
            positional_encoder = nn.Sequential(LayerNorm(embed_dim), positional_encoder)

        # One drop-path rate per prenet block, linearly interpolated.
        dpr = np.linspace(
            modality_cfg.start_drop_path_rate,
            modality_cfg.end_drop_path_rate,
            modality_cfg.prenet_depth,
        )

        context_encoder = BlockEncoder(
            nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)),
            norm_layer(embed_dim) if not layer_norm_first else None,
            layer_norm_first,
            modality_cfg.prenet_layerdrop,
            modality_cfg.prenet_dropout,
        )

        decoder = (
            Decoder1d(modality_cfg.decoder, embed_dim)
            if modality_cfg.decoder is not None
            else None
        )

        alibi_bias_fn = partial(get_alibi_bias, alibi_biases=alibi_biases)

        super().__init__(
            modality_cfg=modality_cfg,
            embed_dim=embed_dim,
            local_encoder=local_encoder,
            project_features=project_features,
            fixed_positional_encoder=None,
            relative_positional_encoder=positional_encoder,
            context_encoder=context_encoder,
            decoder=decoder,
            get_alibi_bias=alibi_bias_fn,
        )

    def convert_padding_mask(self, x, padding_mask):
        """Map an input-level padding mask onto the conv feature frames.

        Args:
            x: tensor whose first two dims index (row, frame); only
                ``x.shape[:2]``, ``x.dtype`` and ``x.device`` are used.
            padding_mask: mask over the raw input where truthy entries mark
                padding (summed along the last dim), or ``None``.

        Returns:
            ``None`` if ``padding_mask`` is ``None``; otherwise a bool tensor of
            shape ``x.shape[:2]`` where ``True`` marks padded frames.
        """

        def get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
            """Apply each conv layer's output-length formula in turn."""
            lengths = input_lengths
            for _, kernel_size, stride in self.feature_enc_layers:
                lengths = torch.floor((lengths - kernel_size) / stride + 1)
            return lengths.to(torch.long)

        if padding_mask is None:
            return None

        if not padding_mask.any():
            # Nothing is padded in the input, so nothing is padded after the convs.
            return torch.zeros(x.shape[:2], dtype=torch.bool, device=x.device)

        input_lengths = (1 - padding_mask.long()).sum(-1)
        output_lengths = get_feat_extract_output_lengths(input_lengths)
        # Inputs shorter than the conv receptive field give lengths <= 0. The
        # ``- 1`` below would then be a negative index that silently wraps to a
        # position counted from the end. Keep at least one valid frame so the
        # index is in range and no row ends up fully masked.
        output_lengths = output_lengths.clamp(min=1)

        mask = torch.zeros(x.shape[:2], dtype=x.dtype, device=x.device)
        # Put a single 1 on the last valid frame of each row ...
        mask[torch.arange(mask.shape[0], device=mask.device), output_lengths - 1] = 1
        # ... then a reversed cumulative sum sets every frame up to and
        # including it to 1. Inverting gives True exactly on padded frames.
        return (1 - mask.flip([-1]).cumsum(-1).flip([-1])).bool()

    def reset_parameters(self):
        """Reset base-class parameters, the projection's Linear layers, and the decoder."""
        super().reset_parameters()
        for mod in self.project_features.children():
            if isinstance(mod, nn.Linear):
                mod.reset_parameters()
        if self.decoder is not None:
            self.decoder.reset_parameters()
|