|
|
|
|
|
|
|
|
import math |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from fairseq.data.data_utils import compute_mask_indices |
|
|
from fairseq.models import FairseqEncoder |
|
|
from fairseq.models.wav2vec import ConvFeatureExtractionModel |
|
|
from fairseq.modules import GradMultiply, LayerNorm, SamePad, TransformerEncoderLayer |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SpeechWavTransformerEncoder(FairseqEncoder):
    """wav2vec 2.0-style speech encoder.

    Pipeline: raw waveform -> strided conv feature extractor
    (``ConvFeatureExtractionModel``) -> layer norm -> optional linear
    projection to the transformer width -> convolutional positional
    embedding -> stack of ``TransformerEncoderLayer``s.

    During training (or always, when ``alway_mask`` is set) spans of the
    feature sequence are replaced in place with a learned mask embedding
    (``self.mask_emb``); when ``features_only`` is False the forward pass
    additionally runs the transformer over the *unmasked* features so a
    criterion can compare masked vs. unmasked representations.
    """

    @staticmethod
    def add_args(parser):
        """Add encoder-specific CLI arguments to *parser*."""
        parser.add_argument(
            "--dropout-input",
            type=float,
            metavar="D",
            help="dropout to apply to the input (after feat extr)",
        )
        parser.add_argument(
            "--dropout-features",
            type=float,
            metavar="D",
            help="dropout to apply to the unmasked features (after feat extr)",
        )
        parser.add_argument(
            "--speech-extractor-mode",
            type=str,
            default="layer_norm",
            choices=["default", "layer_norm"],
            help="feature extractor norm",
        )

        parser.add_argument(
            "--speech-conv-bias",
            action="store_true",
            help="include bias in speech conv encoder",
        )

        parser.add_argument(
            "--conv-feature-layers",
            default="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
            help="string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]",
        )

        # NOTE(review): help text below looks copy-pasted from elsewhere;
        # this is the span length of each time mask — confirm before relying
        # on the help string.
        parser.add_argument(
            "--speech-mask-length",
            type=int,
            help="repeat the mask indices multiple times",
        )

        parser.add_argument(
            "--speech-mask-prob",
            type=float,
            help="probability of replacing a token with mask",
        )

        parser.add_argument(
            "--speech-mask-selection",
            type=str,
            choices=["static", "uniform", "normal", "poisson"],
            help="how to choose masks",
        )

        parser.add_argument(
            "--speech-mask-other",
            type=float,
            help="stdev of the mask length in case of 'normal' selection strategy",
        )

        # NOTE(review): despite the help text, store_true here *disallows*
        # overlap (see no_overlap= in apply_mask).
        parser.add_argument(
            "--speech-no-mask-overlap",
            action="store_true",
            help="whether to allow masks to overlap",
        )

        parser.add_argument(
            "--speech-mask-min-space",
            type=int,
            help="min space between spans (if no overlap is enabled)",
        )

        # Channel-axis masking variants of the options above.
        parser.add_argument(
            "--speech-mask-channel-length",
            type=int,
            help="repeat the mask indices multiple times",
        )

        parser.add_argument(
            "--speech-mask-channel-prob",
            type=float,
            help="probability of replacing a token with mask",
        )

        parser.add_argument(
            "--speech-mask-channel-selection",
            type=str,
            choices=["static", "uniform", "normal", "poisson"],
            help="how to choose masks",
        )

        parser.add_argument(
            "--speech-mask-channel-other",
            type=float,
            help="stdev of the mask length in case of 'normal' selection strategy",
        )

        parser.add_argument(
            "--speech-no-mask-channel-overlap",
            action="store_true",
            help="whether to allow masks to overlap",
        )

        parser.add_argument(
            "--no-scale-feature",
            action="store_true",
            help="no scale for the calculated features",
        )

        parser.add_argument(
            "--speech-mask-channel-min-space",
            type=int,
            help="min space between spans (if no overlap is enabled)",
        )

        parser.add_argument(
            "--feature-grad-mult",
            type=float,
            help="reset feature grad mult in wav2vec 2.0 to this",
        )

        # Positional embeddings (conv-based, as in wav2vec 2.0).
        parser.add_argument(
            "--conv-pos",
            type=int,
            default=128,
            help="number of filters for convolutional positional embeddings",
        )

        parser.add_argument(
            "--conv-pos-groups",
            type=int,
            default=16,
            help="number of groups for convolutional positional embedding",
        )

        parser.add_argument(
            "--speech-encoder-layers",
            type=int,
            help="number of speech encoder layers",
        )
        parser.add_argument(
            "--text-encoder-layers",
            type=int,
            help="number of text encoder layers",
        )

    def __init__(self, args, alway_mask=False):
        """Build the encoder from parsed *args*.

        ``alway_mask`` (sic — typo preserved because it is part of the
        public signature) forces span masking even outside training mode.
        """
        super().__init__(args)
        self.args = args
        self.dropout = args.dropout
        self.embedding_dim = args.encoder_embed_dim
        # Scale extracted features by sqrt(dim) (transformer embedding
        # convention) unless disabled via --no-scale-feature.
        self.feat_scale = math.sqrt(args.encoder_embed_dim)
        if args.no_scale_feature:
            self.feat_scale = 1.0

        # Waveform -> frame features via strided 1-D convolutions.
        # NOTE(review): eval() of a config string — fine for trusted
        # configs, never for untrusted input. Also evaluated twice (here
        # and just below).
        subsample = ConvFeatureExtractionModel(
            conv_layers=eval(args.conv_feature_layers),
            dropout=0.0,
            mode=args.speech_extractor_mode,
            conv_bias=args.speech_conv_bias,
        )
        self.feature_enc_layers = eval(args.conv_feature_layers)
        self.subsample = subsample
        # Project conv features to the transformer width only when needed.
        self.feat_proj = (
            nn.Linear(self.feature_enc_layers[-1][0], self.embedding_dim)
            if self.feature_enc_layers[-1][0] != self.embedding_dim
            else None
        )

        self.feat_layer_norm = LayerNorm(self.feature_enc_layers[-1][0])

        # Convolutional relative positional embedding (wav2vec 2.0 style):
        # grouped Conv1d + weight norm + SamePad (trims the extra frame
        # produced by an even kernel) + GELU.
        self.embed_positions = nn.Conv1d(
            self.embedding_dim,
            self.embedding_dim,
            kernel_size=args.conv_pos,
            padding=args.conv_pos // 2,
            groups=args.conv_pos_groups,
        )
        std = math.sqrt(4 / (args.conv_pos * self.embedding_dim))
        nn.init.normal_(self.embed_positions.weight, mean=0, std=std)
        nn.init.constant_(self.embed_positions.bias, 0)

        self.embed_positions = nn.utils.weight_norm(
            self.embed_positions, name="weight", dim=2
        )
        self.embed_positions = nn.Sequential(
            self.embed_positions, SamePad(args.conv_pos), nn.GELU()
        )

        # Time-axis masking hyper-parameters (fed to compute_mask_indices).
        self.mask_prob = args.speech_mask_prob
        self.mask_selection = args.speech_mask_selection
        self.mask_other = args.speech_mask_other
        self.mask_length = args.speech_mask_length
        self.no_mask_overlap = args.speech_no_mask_overlap
        self.mask_min_space = args.speech_mask_min_space

        # Channel-axis masking hyper-parameters.
        self.mask_channel_prob = args.speech_mask_channel_prob
        self.mask_channel_selection = args.speech_mask_channel_selection
        self.mask_channel_other = args.speech_mask_channel_other
        self.mask_channel_length = args.speech_mask_channel_length
        self.no_mask_channel_overlap = args.speech_no_mask_channel_overlap
        self.mask_channel_min_space = args.speech_mask_channel_min_space

        self.dropout_input = nn.Dropout(args.dropout_input)
        self.dropout_features = nn.Dropout(args.dropout_features)

        # Scale applied (via GradMultiply) to gradients flowing back into
        # the conv feature extractor; 0 freezes it during forward().
        self.feature_grad_mult = args.feature_grad_mult

        # Learned embedding substituted at masked time steps.
        self.mask_emb = nn.Parameter(
            torch.FloatTensor(args.encoder_embed_dim).uniform_()
        )

        self.layers = nn.ModuleList(
            [TransformerEncoderLayer(args) for _ in range(args.encoder_layers)]
        )
        self.layer_norm = LayerNorm(args.encoder_embed_dim)
        self.normalize_before = args.encoder_normalize_before
        self.alway_mask = alway_mask

    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
        """
        Computes the output length of the convolutional layers
        """

        def _conv_out_length(input_length, kernel_size, stride):
            # Standard conv output-length formula (no padding, dilation 1).
            return torch.floor((input_length - kernel_size) / stride + 1)

        # Fold the length through every (dim, kernel, stride) layer.
        for i in range(len(self.feature_enc_layers)):
            input_lengths = _conv_out_length(
                input_lengths,
                self.feature_enc_layers[i][1],
                self.feature_enc_layers[i][2],
            )

        return input_lengths.to(torch.long)

    def apply_mask(self, x, padding_mask):
        """Mask *x* in place along time and/or channel axes.

        Returns the (mutated) tensor and the boolean time-mask indices
        (or None when ``mask_prob == 0``). Channel-mask indices are not
        returned.
        """
        B, T, C = x.shape
        if self.mask_prob > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_prob,
                self.mask_length,
                self.mask_selection,
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space,
            )
            mask_indices = torch.from_numpy(mask_indices).to(x.device)
            # In-place replacement with the learned mask embedding.
            x[mask_indices] = self.mask_emb
        else:
            mask_indices = None

        if self.mask_channel_prob > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            # Same channel mask broadcast across all T time steps.
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0

        return x, mask_indices

    def forward(
        self,
        src_tokens,
        src_lengths,
        return_all_hiddens=False,
        padding_mask=None,
        features_only=True,
    ):
        """Encode raw waveforms.

        Returns a fairseq encoder-out dict (lists keyed by
        "encoder_out", "encoder_padding_mask", ...). With
        ``features_only=False`` it additionally contains
        "encoder_unmasked_out": the transformer output over the
        unmasked features, for mask-prediction style losses.
        """
        mask = self.training or self.alway_mask
        if self.feature_grad_mult > 0 and self.training:
            features = self.subsample(src_tokens)
            if self.feature_grad_mult != 1.0:
                # Down-scale extractor gradients without changing activations.
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            # feature_grad_mult <= 0 (or eval): extractor is frozen.
            with torch.no_grad():
                features = self.subsample(src_tokens)
        features = features.transpose(1, 2)  # B x C x T' -> B x T' x C
        features = self.feat_layer_norm(features)
        if self.feat_proj is not None:
            features = self.feat_proj(features)

        # Recover per-example input lengths from the mask if provided.
        if padding_mask is not None:
            input_lengths = (1 - padding_mask.long()).sum(-1)
        else:
            input_lengths = src_lengths

        output_lengths = self._get_feat_extract_output_lengths(input_lengths)

        # Rebuild the padding mask at the subsampled frame rate: put a 1 at
        # each row's last valid frame, then flip/cumsum/flip marks frames
        # 0..len-1; inverting yields True exactly on padding positions.
        padding_mask = torch.zeros(
            features.shape[:2], dtype=features.dtype, device=features.device
        )

        padding_mask[
            (
                torch.arange(padding_mask.shape[0], device=padding_mask.device),
                output_lengths - 1,
            )
        ] = 1
        padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()

        features = self.feat_scale * features if self.feat_scale != 1.0 else features
        # Keep a copy before masking/dropout for the unmasked branch below.
        unmasked_features = features.clone()

        features = self.dropout_input(features)
        unmasked_features = self.dropout_features(unmasked_features)
        if mask:
            # apply_mask mutates `features` in place; `x` aliases it.
            x, mask_indices = self.apply_mask(features, padding_mask)
        else:
            x = features
            mask_indices = None

        def cal_transformer_layers(x, encoder_padding_mask, return_all_hiddens=False):
            # Add conv positional embeddings, then run the transformer
            # stack (T x B x C internally).
            positions = self.embed_positions(x.transpose(1, 2)).transpose(1, 2)
            x = x + positions
            if not self.normalize_before:
                x = self.layer_norm(x)

            x = x.transpose(0, 1)  # B x T x C -> T x B x C
            encoder_states = []
            for layer in self.layers:
                x = layer(x, encoder_padding_mask)
                if return_all_hiddens:
                    encoder_states.append(x)
            if self.normalize_before:
                x = self.layer_norm(x)
            return x, encoder_states

        x, encoder_states = cal_transformer_layers(x, padding_mask, return_all_hiddens)
        if features_only:
            return {
                "encoder_out": [x],
                "encoder_padding_mask": [padding_mask]
                if padding_mask is not None
                else [],
                "encoder_embedding": [],
                "encoder_states": encoder_states,
                "src_tokens": [],
                "src_lengths": [],
                "mask_indices": [mask_indices],
            }

        # Second pass over the unmasked features (only needed if any
        # masking was actually configured).
        x_unmasked = x
        if self.mask_prob > 0 or self.mask_channel_prob > 0:
            x_unmasked, _ = cal_transformer_layers(unmasked_features, padding_mask)
        return {
            "encoder_out": [x],
            "encoder_unmasked_out": [x_unmasked],
            "encoder_padding_mask": [padding_mask]
            if padding_mask is not None
            else [],
            "encoder_embedding": [],
            "encoder_states": encoder_states,
            "src_tokens": [],
            "src_lengths": [],
            "mask_indices": [mask_indices] if mask_indices is not None else [],
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """Reorder the encoder output to match *new_order* (beam search).

        encoder_out tensors are T x B x C, so batch is dim 1; padding
        masks and embeddings are batch-first, so batch is dim 0.
        """
        new_encoder_out = (
            []
            if len(encoder_out["encoder_out"]) == 0
            else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]]
        )

        new_encoder_padding_mask = (
            []
            if len(encoder_out["encoder_padding_mask"]) == 0
            else [
                x.index_select(0, new_order)
                for x in encoder_out["encoder_padding_mask"]
            ]
        )

        new_encoder_embedding = (
            []
            if len(encoder_out["encoder_embedding"]) == 0
            else [
                x.index_select(0, new_order) for x in encoder_out["encoder_embedding"]
            ]
        )

        # NOTE: mutates the incoming list in place (fairseq convention).
        encoder_states = encoder_out["encoder_states"]
        if len(encoder_states) > 0:
            for idx, state in enumerate(encoder_states):
                encoder_states[idx] = state.index_select(1, new_order)

        return {
            "encoder_out": new_encoder_out,
            "encoder_padding_mask": new_encoder_padding_mask,
            "encoder_embedding": new_encoder_embedding,
            "encoder_states": encoder_states,
            "src_tokens": [],
            "src_lengths": [],
        }
|
|
|
|
|
|
|
|
class StackedSpeechWavTransformerEncoder(FairseqEncoder):
    """Speech encoder stacked with shared text-encoder layers.

    Runs a wrapped speech encoder first, then pushes its output (and,
    when requested, the unmasked output) through a list of additional
    transformer layers followed by an optional final layer norm.
    """

    def __init__(self, speech_enc, text_enc_layers, text_layer_norm):
        super().__init__(None)
        self.speech_encoder = speech_enc
        self.text_encoder_layers = text_enc_layers
        self.final_layer_norm = text_layer_norm

    def forward(
        self,
        src_tokens,
        src_lengths=None,
        return_all_hiddens=False,
        padding_mask=None,
        features_only=True,
    ):
        """Encode speech, then refine through the text-layer stack."""
        speech_out = self.speech_encoder.forward(
            src_tokens,
            src_lengths,
            return_all_hiddens,
            padding_mask=padding_mask,
            features_only=features_only,
        )
        hidden = speech_out["encoder_out"][0]
        pad_mask = (
            speech_out["encoder_padding_mask"][0]
            if len(speech_out["encoder_padding_mask"]) > 0
            else None
        )

        def _run_text_stack(states, mask, collect_hiddens=False):
            # Apply each shared text layer; optionally keep per-layer states.
            all_states = []
            for text_layer in self.text_encoder_layers:
                states = text_layer(states, mask)
                if collect_hiddens:
                    all_states.append(states)
            if self.final_layer_norm is not None:
                states = self.final_layer_norm(states)
            return states, all_states

        hidden, encoder_states = _run_text_stack(
            hidden, pad_mask, return_all_hiddens
        )
        if features_only:
            return {
                "encoder_out": [hidden],
                "encoder_padding_mask": [pad_mask]
                if pad_mask is not None
                else [],
                "encoder_embedding": [],
                "encoder_states": encoder_states,
                "src_tokens": [],
                "src_lengths": [],
            }

        # Full mode: also refine the unmasked branch for the criterion.
        unmasked, _ = _run_text_stack(
            speech_out["encoder_unmasked_out"][0], pad_mask
        )

        return {
            "encoder_out": [hidden],
            "encoder_unmasked_out": [unmasked],
            "encoder_padding_mask": [pad_mask]
            if pad_mask is not None
            else [],
            "encoder_embedding": [],
            "encoder_states": encoder_states,
            "src_tokens": [],
            "src_lengths": [],
            "mask_indices": speech_out["mask_indices"],
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """Delegate: output dict layout matches the wrapped speech encoder."""
        return self.speech_encoder.reorder_encoder_out(encoder_out, new_order)
|
|
|