import math
from typing import List, Union, Iterable

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchaudio.transforms import Spectrogram
from torch.cuda.amp import autocast

from src.data import DataProperties

################################################################################
# DeepSpeech2 model (Amodei et al.) as implemented by Sean Naren
################################################################################


class SequenceWise(nn.Module):
    def __init__(self, module: nn.Module):
        """
        Collapses input of shape (seq_len, n_batch, n_features) to
        (seq_len * n_batch, n_features) and applies a nn.Module along the
        feature dimension. Allows handling of variable sequence lengths and
        batch sizes.

        Parameters
        ----------
        module (nn.Module): module to apply to input
        """
        super(SequenceWise, self).__init__()
        self.module = module

    def forward(self, x: torch.Tensor):
        # assume input shape (seq_len, n_batch, n_features)
        t, n = x.size(0), x.size(1)
        x = x.view(t * n, -1)
        x = self.module(x)
        x = x.view(t, n, -1)
        return x

    def __repr__(self):
        tmpstr = self.__class__.__name__ + ' (\n'
        tmpstr += self.module.__repr__()
        tmpstr += ')'
        return tmpstr


class MaskConv(nn.Module):
    def __init__(self, seq_module: nn.Sequential):
        """
        Adds padding to the output of each layer in a given convolution stack
        based on a set of given lengths. This ensures that the results of the
        model do not change when batch sizes change during inference.

        Expects input with shape (n_batch, n_channels, ???, seq_len)

        Parameters
        ----------
        seq_module (nn.Sequential): the sequential module containing the
                                    convolution stack
        """
        super(MaskConv, self).__init__()
        self.seq_module = seq_module

    def forward(self, x: torch.Tensor, lengths: Iterable):
        """
        Zero out all activations beyond each example's valid length after
        every layer of the convolution stack.

        Parameters
        ----------
        x (Tensor): input with shape (n_batch, n_channels, ???, seq_len)
        lengths (list): list of target lengths

        Returns
        -------
        masked (Tensor): padded output of convolution stack
        lengths (list): list of target lengths
        """
        for module in self.seq_module:
            x = module(x)

            # build the mask directly on the input's device; the previous
            # CPU-then-.cuda() pattern allocated twice and targeted the
            # *default* GPU, which is wrong on non-default CUDA devices
            mask = torch.zeros(x.size(), dtype=torch.bool, device=x.device)
            for i, length in enumerate(lengths):
                length = length.item()
                # mark everything past `length` along the final (time) axis
                if (mask[i].size(2) - length) > 0:
                    mask[i].narrow(
                        2, length, mask[i].size(2) - length).fill_(1)
            x = x.masked_fill(mask, 0)
        return x, lengths


class InferenceBatchSoftmax(nn.Module):
    """Apply softmax along final tensor dimension in inference mode only"""

    def forward(self, input_: torch.Tensor):
        # training uses raw logits (e.g. for CTC loss); only normalize
        # to probabilities at inference time
        if not self.training:
            return F.softmax(input_, dim=-1)
        else:
            return input_


class BatchRNN(nn.Module):
    """RNN layer with optional batch normalization"""

    def __init__(self,
                 input_size: int,
                 hidden_size: int,
                 rnn_type=nn.LSTM,
                 bidirectional: bool = False,
                 batch_norm: bool = True):
        super(BatchRNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bidirectional = bidirectional

        # apply time-distributed batch normalization
        self.batch_norm = SequenceWise(
            nn.BatchNorm1d(input_size)) if batch_norm else None
        self.rnn = rnn_type(input_size=input_size,
                            hidden_size=hidden_size,
                            bidirectional=bidirectional,
                            bias=True)
        self.num_directions = 2 if bidirectional else 1

    def flatten_parameters(self):
        # compact cuDNN weight storage (avoids non-contiguous memory warning)
        self.rnn.flatten_parameters()

    def forward(self, x: torch.Tensor, output_lengths: torch.Tensor):
        """
        Parameters
        ----------
        x (Tensor): shape (seq_len, n_batch, n_features)
        output_lengths (Tensor): valid length per batch element

        Returns
        -------
        out (Tensor): shape (seq_len, n_batch, hidden_size)
        """
        if self.batch_norm is not None:
            x = self.batch_norm(x)

        # pack to skip computation on padded frames
        # NOTE(review): pack_padded_sequence defaults to enforce_sorted=True,
        # so callers presumably provide length-sorted batches — confirm
        x = nn.utils.rnn.pack_padded_sequence(x, output_lengths)
        x, h = self.rnn(x)
        x, _ = nn.utils.rnn.pad_packed_sequence(x)

        # sum forward and backward contexts if bidirectional
        if self.bidirectional:
            x = x.view(
                x.size(0), x.size(1), 2, -1
            ).sum(2).view(x.size(0), x.size(1), -1)  # (TxNxH*2) -> (TxNxH)

        return x


class Lookahead(nn.Module):
    """
    Lookahead Convolution Layer for Unidirectional Recurrent Neural Networks
    from Wang et al 2016.
    """

    def __init__(self, n_features: int, context: int):
        """
        Parameters
        ----------
        n_features (int): feature dimension
        context (int):    context length in frames, corresponding to a
                          lookahead of (context - 1) frames
        """
        super(Lookahead, self).__init__()
        assert context > 0, 'Must provide nonzero context length'
        self.context = context
        self.n_features = n_features

        # pad to preserve sequence length in output
        self.pad = (0, self.context - 1)
        self.conv = nn.Conv1d(
            self.n_features,
            self.n_features,
            kernel_size=self.context,
            stride=1,
            groups=self.n_features,  # depthwise: one filter per feature
            padding=0,
            bias=False
        )

    def forward(self, x: torch.Tensor):
        """
        Parameters
        ----------
        x (Tensor): shape (seq_len, n_batch, n_features)

        Returns
        -------
        out (Tensor): shape (seq_len, n_batch, n_features)
        """
        # (T, N, F) -> (N, F, T) for Conv1d, pad future frames, convolve,
        # then restore (T, N, F)
        x = x.transpose(0, 1).transpose(1, 2)
        x = F.pad(x, pad=self.pad, value=0)
        x = self.conv(x)
        x = x.transpose(1, 2).transpose(0, 1).contiguous()
        return x

    def __repr__(self):
        return self.__class__.__name__ + '(' \
               + 'n_features=' + str(self.n_features) \
               + ', context=' + str(self.context) + ')'


class DeepSpeech(nn.Module):
    def __init__(self,
                 window_size: float = 0.02,
                 window_stride: float = 0.01,
                 normalize: bool = True):
        """
        DeepSpeech2 character-level speech recognizer: spectrogram frontend,
        2D convolutional acoustic encoder, bidirectional LSTM stack, and a
        time-distributed linear classifier over the character vocabulary.

        Parameters
        ----------
        window_size (float):   STFT window length in seconds
        window_stride (float): STFT hop length in seconds
        normalize (bool):      if True, standardize the log-spectrogram to
                               zero mean / unit variance per batch
        """
        super().__init__()

        # hard-code to match pre-trained implementation
        self.sample_rate = 16000
        self.labels = [
            '_', "'", 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K',
            'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X',
            'Y', 'Z', '|']
        self.sep_idx = len(self.labels) - 1  # '|' word separator
        self.blank_idx = 0                   # '_' CTC blank
        self.hidden_size = 1024
        self.hidden_layers = 5
        self.lookahead_context = 0
        self.bidirectional: bool = True
        self.normalize = normalize
        num_classes = len(self.labels)

        # check sample rate
        if DataProperties.get("sample_rate") != self.sample_rate:
            raise ValueError(f"Incompatible data and model sample rates "
                             f"{DataProperties.get('sample_rate')}, "
                             f"{self.sample_rate}")

        # spectrogram processing - matches original Librosa implementation
        # (MSE ~1e-11 for 4s audio)
        self.spec = Spectrogram(
            n_fft=int(self.sample_rate * window_size),
            win_length=int(self.sample_rate * window_size),
            hop_length=int(self.sample_rate * window_stride),
            window_fn=torch.hamming_window,
            center=True,
            pad_mode='constant',
            power=1
        )

        # convolutional spectrogram encoder (acoustic model)
        self.conv = MaskConv(nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=(41, 11), stride=(2, 2),
                      padding=(20, 5)),
            nn.BatchNorm2d(32),
            nn.Hardtanh(0, 20, inplace=True),
            nn.Conv2d(32, 32, kernel_size=(21, 11), stride=(2, 1),
                      padding=(10, 5)),
            nn.BatchNorm2d(32),
            nn.Hardtanh(0, 20, inplace=True)
        ))

        # compute RNN input size using conv formula (W - F + 2P)/ S+1
        rnn_input_size = int(math.floor(
            (self.sample_rate * window_size) / 2) + 1)
        rnn_input_size = int(math.floor(rnn_input_size + 2 * 20 - 41) / 2 + 1)
        rnn_input_size = int(math.floor(rnn_input_size + 2 * 10 - 21) / 2 + 1)
        rnn_input_size *= 32

        # RNN stack: first layer skips batch norm (it follows the conv stack)
        self.rnns = nn.Sequential(
            BatchRNN(
                input_size=rnn_input_size,
                hidden_size=self.hidden_size,
                rnn_type=nn.LSTM,
                bidirectional=self.bidirectional,
                batch_norm=False
            ),
            *(
                BatchRNN(
                    input_size=self.hidden_size,
                    hidden_size=self.hidden_size,
                    rnn_type=nn.LSTM,
                    bidirectional=self.bidirectional
                ) for _ in range(self.hidden_layers - 1)
            )
        )

        # post-RNN lookahead (for unidirectional models)
        self.lookahead = nn.Sequential(
            Lookahead(self.hidden_size, context=self.lookahead_context),
            nn.Hardtanh(0, 20, inplace=True)
        ) if not self.bidirectional else None

        # final time-distributed linear layer for token prediction
        fully_connected = nn.Sequential(
            nn.BatchNorm1d(self.hidden_size),
            nn.Linear(self.hidden_size, num_classes, bias=False)
        )
        self.fc = nn.Sequential(
            SequenceWise(fully_connected),
        )
        self.inference_softmax = InferenceBatchSoftmax()

    def forward(self, x, lengths=None):
        """
        Parameters
        ----------
        x (Tensor): audio of shape (n_batch, signal_len) or
                    (n_batch, n_channels, signal_len); multi-channel audio is
                    averaged to mono
        lengths (Tensor): optional per-example valid lengths in spectrogram
                          frames; defaults to the full (padded) frame axis

        Returns
        -------
        out (Tensor): per-frame character scores, shape
                      (n_batch, n_frames_out, num_classes); softmax-normalized
                      only in inference mode
        """
        # ensure RNN blocks are in train mode to allow backpropagation for
        # attack optimization
        if not self.rnns.training:
            self.rnns.train()

        # require batch, channel dimensions
        assert x.ndim >= 2
        n_batch, *channel_dims, signal_len = x.shape
        if x.ndim == 2:
            x = x.unsqueeze(1)

        # convert to mono audio
        x = x.mean(dim=1, keepdim=True)

        # compute spectrogram
        x = self.spec(x)  # (n_batch, 1, n_freq, n_frames)
        x = torch.log1p(x)

        if self.normalize:
            mean = x.mean()
            std = x.std()
            x = x - mean
            x = x / std

        # BUGFIX: `lengths or torch.full(...)` raised "Boolean value of
        # Tensor ... is ambiguous" whenever a multi-element lengths tensor
        # was actually passed; test for None explicitly instead
        if lengths is None:
            lengths = torch.full((n_batch,), x.shape[-1], dtype=torch.long)
        lengths = lengths.cpu().int()
        output_lengths = self.get_seq_lens(lengths)

        x, _ = self.conv(x, output_lengths)

        sizes = x.size()
        x = x.view(
            sizes[0], sizes[1] * sizes[2], sizes[3]
        )  # Collapse feature dimension
        x = x.transpose(1, 2).transpose(0, 1).contiguous()  # TxNxH

        for rnn in self.rnns:
            x = rnn(x, output_lengths)

        if not self.bidirectional:  # no need for lookahead layer in bidirectional
            x = self.lookahead(x)

        x = self.fc(x)
        x = x.transpose(0, 1)
        return x

    def get_seq_lens(self, input_length):
        """
        Given a 1D Tensor or Variable containing integer sequence lengths,
        return a 1D tensor or variable containing the size sequences that
        will be output by the network.

        :param input_length: 1D Tensor
        :return: 1D Tensor scaled by model
        """
        seq_len = input_length
        # apply each Conv2d's output-length formula along the time axis
        # (index 1 of padding/dilation/kernel_size/stride)
        for m in self.conv.modules():
            if isinstance(m, nn.Conv2d):
                seq_len = ((seq_len + 2 * m.padding[1]
                            - m.dilation[1] * (m.kernel_size[1] - 1) - 1)
                           // m.stride[1] + 1)
        return seq_len.int()