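"""Audio modality encoder for data2vec: a wav2vec 2.0 style convolutional
feature extractor followed by a convolutional relative positional encoder
and a transformer prenet."""
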
from functools import partial

import torch
import torch.nn as nn
import numpy as np

from dataclasses import dataclass, field
from typing import Callable, Dict, Optional

from fairseq.models.wav2vec import ConvFeatureExtractionModel
from fairseq.modules import (
    LayerNorm,
    SamePad,
    TransposeLast,
)
from fairseq.tasks import FairseqTask

from .base import D2vModalityConfig, ModalitySpecificEncoder, get_alibi_bias
from .modules import BlockEncoder, Decoder1d
from examples.data2vec.data.modality import Modality


@dataclass
class D2vAudioConfig(D2vModalityConfig):
    type: Modality = Modality.AUDIO

    extractor_mode: str = "layer_norm"
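    # The default spec mirrors wav2vec 2.0: seven conv layers whose strides
    # multiply to 320, so 16 kHz input is downsampled to ~50 frames/second.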
    feature_encoder_spec: str = field(
        default="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
        metadata={
            "help": "string describing convolutional feature extraction layers in form of a python list that contains "
            "[(dim, kernel_size, stride), ...]"
        },
    )
    conv_pos_width: int = field(
        default=95,
        metadata={
            "help": "total kernel width of the convolutional positional embedding, "
            "split across conv_pos_depth layers"
        },
    )
    conv_pos_groups: int = field(
        default=16,
        metadata={"help": "number of groups for convolutional positional embedding"},
    )
    conv_pos_depth: int = field(
        default=5,
        metadata={"help": "depth of positional encoder network"},
    )
    conv_pos_pre_ln: bool = False


class AudioEncoder(ModalitySpecificEncoder):

    modality_cfg: D2vAudioConfig

    def __init__(
        self,
        modality_cfg: D2vAudioConfig,
        embed_dim: int,
        make_block: Callable[[float], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases: Dict,
        task: Optional[FairseqTask],
    ):
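        # The spec string is evaluated (it is trusted config, not user input)
        # into a list of (dim, kernel_size, stride) tuples; the channel width
        # of the last conv layer is the dimensionality of the features.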
        self.feature_enc_layers = eval(modality_cfg.feature_encoder_spec)
        feature_embed_dim = self.feature_enc_layers[-1][0]
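
        # Local encoder: the strided conv feature extractor that maps the raw
        # waveform to a sequence of feature frames.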
        local_encoder = ConvFeatureExtractionModel(
            conv_layers=self.feature_enc_layers,
            dropout=0.0,
            mode=modality_cfg.extractor_mode,
            conv_bias=False,
        )
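
        # Normalize the conv features, then project them from the extractor's
        # channel width (512 by default) to the transformer embedding size.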
        project_features = nn.Sequential(
            TransposeLast(),
            nn.LayerNorm(feature_embed_dim),
            nn.Linear(feature_embed_dim, embed_dim),
        )
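
        # Split the positional kernel-width budget across the stack depth;
        # with the defaults (width 95, depth 5) each layer uses a kernel of 19.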
        num_pos_layers = modality_cfg.conv_pos_depth
        k = max(3, modality_cfg.conv_pos_width // num_pos_layers)
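
        # Relative positional encoder: a stack of grouped 1d convolutions over
        # time, each followed by a non-affine LayerNorm and GELU. SamePad trims
        # the surplus output frame whenever the kernel size is even.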
        positional_encoder = nn.Sequential(
            TransposeLast(),
            *[
                nn.Sequential(
                    nn.Conv1d(
                        embed_dim,
                        embed_dim,
                        kernel_size=k,
                        padding=k // 2,
                        groups=modality_cfg.conv_pos_groups,
                    ),
                    SamePad(k),
                    TransposeLast(),
                    LayerNorm(embed_dim, elementwise_affine=False),
                    TransposeLast(),
                    nn.GELU(),
                )
                for _ in range(num_pos_layers)
            ],
            TransposeLast(),
        )

        if modality_cfg.conv_pos_pre_ln:
            positional_encoder = nn.Sequential(LayerNorm(embed_dim), positional_encoder)
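
        # Drop-path (stochastic depth) rates increase linearly from
        # start_drop_path_rate to end_drop_path_rate across the prenet blocks.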
        dpr = np.linspace(
            modality_cfg.start_drop_path_rate,
            modality_cfg.end_drop_path_rate,
            modality_cfg.prenet_depth,
        )
        context_encoder = BlockEncoder(
            nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)),
            norm_layer(embed_dim) if not layer_norm_first else None,
            layer_norm_first,
            modality_cfg.prenet_layerdrop,
            modality_cfg.prenet_dropout,
        )
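
        # Optional 1d-conv decoder; only built when a decoder config is given.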
        decoder = (
            Decoder1d(modality_cfg.decoder, embed_dim)
            if modality_cfg.decoder is not None
            else None
        )
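
        # Bind the shared alibi_biases dict so computed biases can be cached
        # and reused across encoders.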
        alibi_bias_fn = partial(get_alibi_bias, alibi_biases=alibi_biases)

        super().__init__(
            modality_cfg=modality_cfg,
            embed_dim=embed_dim,
            local_encoder=local_encoder,
            project_features=project_features,
            fixed_positional_encoder=None,
            relative_positional_encoder=positional_encoder,
            context_encoder=context_encoder,
            decoder=decoder,
            get_alibi_bias=alibi_bias_fn,
        )

    def convert_padding_mask(self, x, padding_mask):
        def get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
            """
            Computes the output length of the convolutional layers
            """
            def _conv_out_length(input_length, kernel_size, stride):
                return torch.floor((input_length - kernel_size) / stride + 1)
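
            # Fold the input lengths through every layer of the extractor.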
            for i in range(len(self.feature_enc_layers)):
                input_lengths = _conv_out_length(
                    input_lengths,
                    self.feature_enc_layers[i][1],
                    self.feature_enc_layers[i][2],
                )

            return input_lengths.to(torch.long)
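
        # Downsample the sample-level padding mask to feature-frame resolution.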
        if padding_mask is not None:
            input_lengths = (1 - padding_mask.long()).sum(-1)
            output_lengths = get_feat_extract_output_lengths(input_lengths)

            if padding_mask.any():
                padding_mask = torch.zeros(x.shape[:2], dtype=x.dtype, device=x.device)
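
                # Set a 1 at each sequence's last valid frame; the reversed
                # cumsum then marks every frame up to that index, and inverting
                # leaves True only on the padded tail.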
                padding_mask[
                    (
                        torch.arange(padding_mask.shape[0], device=padding_mask.device),
                        output_lengths - 1,
                    )
                ] = 1
                padding_mask = (
                    1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])
                ).bool()
            else:
                padding_mask = torch.zeros(
                    x.shape[:2], dtype=torch.bool, device=x.device
                )

        return padding_mask

    def reset_parameters(self):
        super().reset_parameters()
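        # On top of the base reset, re-initialize the feature projection and
        # the decoder (when one exists).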
        for mod in self.project_features.children():
            if isinstance(mod, nn.Linear):
                mod.reset_parameters()
        if self.decoder is not None:
            self.decoder.reset_parameters()