Spaces:
Runtime error
Runtime error
| import math | |
| from collections import OrderedDict | |
| from typing import Optional | |
| from torch import Tensor | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch.optim as optim | |
| from torchmetrics.functional import( | |
| scale_invariant_signal_noise_ratio as si_snr, | |
| signal_noise_ratio as snr, | |
| signal_distortion_ratio as sdr, | |
| scale_invariant_signal_distortion_ratio as si_sdr) | |
| from speechbrain.lobes.models.transformer.Transformer import PositionalEncoding | |
def mod_pad(x, chunk_size, pad):
    """Right-pad ``x`` to a multiple of ``chunk_size``, then apply ``pad``.

    Args:
        x: Tensor padded along its last dimension.
        chunk_size: The last dimension is rounded up to a multiple of this,
            so that an integer number of chunked inferences can be run.
        pad: Padding spec forwarded unchanged to ``F.pad`` after the
            rounding step (e.g. ``(left, right)``).

    Returns:
        Tuple ``(padded, mod)`` where ``mod`` is how many zeros the rounding
        step appended on the right (0 when already aligned), so callers can
        strip them from the output later.
    """
    # Python's modulo of the negated length is exactly the shortfall to the
    # next multiple of chunk_size (and 0 when the length is already aligned).
    mod = -x.shape[-1] % chunk_size
    rounded = F.pad(x, (0, mod))
    return F.pad(rounded, pad), mod
class LayerNormPermuted(nn.LayerNorm):
    """``nn.LayerNorm`` applied to a channels-first [B, C, T] tensor.

    The input is moved to [B, T, C] so the standard LayerNorm (which
    normalizes the trailing ``normalized_shape`` dims) acts on the channel
    axis, and the result is moved back to [B, C, T].
    """

    def __init__(self, *args, **kwargs):
        # Takes exactly the same arguments as nn.LayerNorm.
        super().__init__(*args, **kwargs)

    def forward(self, x):
        """
        Args:
            x: [B, C, T]

        Returns:
            Tensor of shape [B, C, T].
        """
        channels_last = x.permute(0, 2, 1)  # [B, T, C]
        normed = super().forward(channels_last)
        return normed.permute(0, 2, 1)  # [B, C, T]
class DepthwiseSeparableConv(nn.Module):
    """
    Depthwise-separable 1-D convolution block.

    A per-channel (depthwise, ``groups=in_channels``) convolution with the
    requested kernel, stride, padding and dilation, followed by a 1x1
    (pointwise) convolution mixing channels. Each convolution is followed by
    a channel-wise LayerNorm and a ReLU.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 padding, dilation):
        super().__init__()
        # Modules are created in the same order as the classic layout so the
        # Sequential indices (layers.0 ... layers.5) and init order match.
        stages = [
            nn.Conv1d(in_channels, in_channels, kernel_size, stride,
                      padding, groups=in_channels, dilation=dilation),
            LayerNormPermuted(in_channels),
            nn.ReLU(),
            nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1,
                      padding=0),
            LayerNormPermuted(out_channels),
            nn.ReLU(),
        ]
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Apply depthwise conv -> LN -> ReLU -> pointwise conv -> LN -> ReLU."""
        return self.layers(x)
class DilatedCausalConvEncoder(nn.Module):
    """
    A dilated causal convolution based encoder for encoding
    time domain audio input into latent space.

    Layer ``i`` is a depthwise-separable convolution with dilation ``2**i``
    and no padding. Causality comes from prepending, for every layer, the
    last ``(kernel_size - 1) * 2**i`` frames that layer saw before; those
    frames live in one flat context buffer (see ``init_ctx_buf``).
    """

    def __init__(self, channels, num_layers, kernel_size=3):
        super().__init__()
        self.channels = channels
        self.num_layers = num_layers
        self.kernel_size = kernel_size

        # buf_lengths[i] = (kernel_size - 1) * dilation[i]: the number of past
        # frames layer i needs so its unpadded conv output is exactly as long
        # as its new input (required by the residual connection).
        self.buf_lengths = [(kernel_size - 1) * 2**i
                            for i in range(num_layers)]

        # Start offset of each layer's slice inside the flat context buffer.
        self.buf_indices = [0]
        for i in range(num_layers - 1):
            self.buf_indices.append(
                self.buf_indices[-1] + self.buf_lengths[i])

        # Dilated causal conv layers aggregate previous context to obtain
        # contextful encoded input.
        _dcc_layers = OrderedDict()
        for i in range(num_layers):
            # Must use the configured kernel size: buf_lengths assumes it,
            # and a mismatch changes the conv output length.
            dcc_layer = DepthwiseSeparableConv(
                channels, channels, kernel_size=kernel_size, stride=1,
                padding=0, dilation=2**i)
            _dcc_layers.update({'dcc_%d' % i: dcc_layer})
        self.dcc_layers = nn.Sequential(_dcc_layers)

    def init_ctx_buf(self, batch_size, device):
        """
        Returns a zero-initialized context buffer for a given batch size.

        Shape: [batch_size, channels, sum(buf_lengths)], where
        sum(buf_lengths) == (kernel_size - 1) * (2**num_layers - 1).
        """
        return torch.zeros(
            (batch_size, self.channels,
             (self.kernel_size - 1) * (2**self.num_layers - 1)),
            device=device)

    def forward(self, x, ctx_buf):
        """
        Encodes `x` into latent space using past context from `ctx_buf`,
        and stores the newest frames back into `ctx_buf`.

        Args:
            x: [B, channels, T]
                Input frames for the current chunk.
            ctx_buf: [B, channels, sum(self.buf_lengths)]
                Flat context buffer (see `init_ctx_buf`). The slice
                [buf_indices[i], buf_indices[i] + buf_lengths[i]) holds the
                context of layer i. It is updated in place.
        Returns:
            x: [B, channels, T]
                Encoded output (each layer adds a residual).
            ctx_buf: the same, in-place updated, context buffer.
        """
        for i in range(self.num_layers):
            buf_len = self.buf_lengths[i]
            buf_start_idx = self.buf_indices[i]
            buf_end_idx = buf_start_idx + buf_len

            # DCC input: this layer's stored context followed by the input
            dcc_in = torch.cat(
                (ctx_buf[..., buf_start_idx:buf_end_idx], x), dim=-1)

            # Push the most recent frames into the context buffer. Slicing
            # from an explicit start (not -buf_len) stays correct when
            # buf_len == 0, where [-0:] would select the whole tensor.
            ctx_buf[..., buf_start_idx:buf_end_idx] = \
                dcc_in[..., dcc_in.shape[-1] - buf_len:]

            # Residual connection
            x = x + self.dcc_layers[i](dcc_in)

        return x, ctx_buf
class CausalTransformerDecoderLayer(torch.nn.TransformerDecoderLayer):
    """
    Transformer decoder layer that only computes outputs for the trailing
    `chunk_size` positions of `tgt`, while using every position of `tgt`
    as keys/values for self-attention.

    Always uses post-norm ordering (residual add, then LayerNorm); the
    parent's ``norm_first`` option is not consulted.

    Adapted from:
    "https://github.com/alexmt-scale/causal-transformer-decoder/blob/"
    "0caf6ad71c46488f76d89845b0123d2550ef792f/"
    "causal_transformer_decoder/model.py#L77"
    """

    def forward(
        self,
        tgt: Tensor,
        memory: Optional[Tensor] = None,
        chunk_size: int = 1
    ) -> "tuple[Tensor, Tensor, Optional[Tensor]]":
        """
        Args:
            tgt: [N, S, E] window (indexing assumes a batch-first layout);
                the last `chunk_size` positions are the queries.
            memory: [N, S_mem, E] or None. When None, the cross-attention
                sub-block is skipped.
            chunk_size: number of trailing positions of `tgt` to decode.
        Returns:
            (out, sa_map, ca_map): `out` is [N, chunk_size, E]; `sa_map` and
            `ca_map` are the attention weights returned by `self_attn` and
            `multihead_attn` (`ca_map` is None when `memory` is None).
        """
        tgt_last_tok = tgt[:, -chunk_size:, :]

        # Self-attention part. No mask: the queries are already the newest
        # positions. With chunk_size > 1, positions inside the chunk can
        # also see each other (including later ones in the same chunk).
        tmp_tgt, sa_map = self.self_attn(
            tgt_last_tok,
            tgt,
            tgt,
            attn_mask=None,
            key_padding_mask=None,
        )
        tgt_last_tok = tgt_last_tok + self.dropout1(tmp_tgt)
        tgt_last_tok = self.norm1(tgt_last_tok)

        # Encoder-decoder attention (skipped without memory). ca_map must be
        # bound on both paths, otherwise the return below raises.
        ca_map = None
        if memory is not None:
            tmp_tgt, ca_map = self.multihead_attn(
                tgt_last_tok,
                memory,
                memory,
                attn_mask=None,  # Attend to the entire chunk
                key_padding_mask=None,
            )
            tgt_last_tok = tgt_last_tok + self.dropout2(tmp_tgt)
            tgt_last_tok = self.norm2(tgt_last_tok)

        # Final feed-forward network
        tmp_tgt = self.linear2(
            self.dropout(self.activation(self.linear1(tgt_last_tok)))
        )
        tgt_last_tok = tgt_last_tok + self.dropout3(tmp_tgt)
        tgt_last_tok = self.norm3(tgt_last_tok)
        return tgt_last_tok, sa_map, ca_map
class CausalTransformerDecoder(nn.Module):
    """
    A causal transformer decoder which decodes input vectors using
    precisely `ctx_len` past vectors in the sequence, and no vectors from
    future chunks.

    Sequences are processed in chunks of `chunk_size` frames; each chunk
    attends to itself plus the preceding `ctx_len` frames, which are carried
    across calls in a context buffer (see `init_ctx_buf`).
    """

    def __init__(self, model_dim, ctx_len, chunk_size, num_layers,
                 nhead, use_pos_enc, ff_dim):
        super().__init__()
        self.num_layers = num_layers
        self.model_dim = model_dim
        self.ctx_len = ctx_len
        self.chunk_size = chunk_size
        self.nhead = nhead
        self.use_pos_enc = use_pos_enc
        # Windows of ctx_len + chunk_size frames, hopping by chunk_size.
        self.unfold = nn.Unfold(kernel_size=(ctx_len + chunk_size, 1),
                                stride=chunk_size)
        # The positional table must cover a whole window
        # (ctx_len + chunk_size positions). 200 is kept as the floor so
        # configurations that already fit keep the same table.
        self.pos_enc = PositionalEncoding(
            model_dim, max_len=max(200, ctx_len + chunk_size))
        self.tf_dec_layers = nn.ModuleList([CausalTransformerDecoderLayer(
            d_model=model_dim, nhead=nhead, dim_feedforward=ff_dim,
            batch_first=True) for _ in range(num_layers)])

    def init_ctx_buf(self, batch_size, device):
        """
        Returns a zero context buffer of shape
        [batch_size, num_layers + 1, ctx_len, model_dim].

        Slot 0 holds past `mem` frames, slot i + 1 holds past inputs of
        decoder layer i.
        """
        return torch.zeros(
            (batch_size, self.num_layers + 1, self.ctx_len, self.model_dim),
            device=device)

    def _causal_unfold(self, x):
        """
        Unfolds the sequence into a batch of overlapping windows, each being
        one `chunk_size` chunk prepended with its `ctx_len` predecessors.

        Args:
            x: [B, ctx_len + L, C], with L a multiple of chunk_size
        Returns:
            [B * (L // chunk_size), ctx_len + chunk_size, C]
        """
        B, _, C = x.shape
        win = self.ctx_len + self.chunk_size
        x = x.permute(0, 2, 1)  # [B, C, ctx_len + L]
        x = self.unfold(x.unsqueeze(-1))  # [B, C * win, n_chunks]
        x = x.permute(0, 2, 1)  # [B, n_chunks, C * win]
        x = x.reshape(B, -1, C, win)
        x = x.reshape(-1, C, win)  # [B * n_chunks, C, win]
        x = x.permute(0, 2, 1)  # [B * n_chunks, win, C]
        return x

    def forward(self, tgt, mem, ctx_buf, probe=False):
        """
        Args:
            tgt: [B, model_dim, T]
                Decoder input (queries).
            mem: [B, model_dim, T]
                Memory for cross-attention; assumed to have the same length
                as `tgt` so the unfolded windows line up -- TODO confirm at
                call sites.
            ctx_buf: [B, num_layers + 1, ctx_len, model_dim]
                Context buffer from `init_ctx_buf`; updated in place.
            probe: unused.
        Returns:
            tgt: [B, model_dim, T] decoded output (chunk padding removed).
            ctx_buf: the updated context buffer.
        """
        mem, _ = mod_pad(mem, self.chunk_size, (0, 0))
        tgt, mod = mod_pad(tgt, self.chunk_size, (0, 0))

        # Padded input sequence length (a multiple of chunk_size)
        B, C, T = tgt.shape

        tgt = tgt.permute(0, 2, 1)  # [B, T, C]
        mem = mem.permute(0, 2, 1)  # [B, T, C]

        # Prepend mem with its stored context, then save its newest frames.
        # An explicit start index keeps ctx_len == 0 working ([-0:] would
        # select everything).
        mem = torch.cat((ctx_buf[:, 0, :, :], mem), dim=1)
        ctx_buf[:, 0, :, :] = mem[:, mem.shape[1] - self.ctx_len:, :]
        mem_ctx = self._causal_unfold(mem)
        if self.use_pos_enc:
            mem_ctx = mem_ctx + self.pos_enc(mem_ctx)

        # Attention chunk size: windows are fed to each layer in slices of
        # at most K so long sequences don't trigger out-of-memory errors.
        K = 1000
        for layer_idx, tf_dec_layer in enumerate(self.tf_dec_layers):
            # Prepend this layer's stored context, then save its newest input.
            tgt = torch.cat((ctx_buf[:, layer_idx + 1, :, :], tgt), dim=1)
            ctx_buf[:, layer_idx + 1, :, :] = \
                tgt[:, tgt.shape[1] - self.ctx_len:, :]

            # Compute encoded output
            tgt_ctx = self._causal_unfold(tgt)
            if self.use_pos_enc and layer_idx == 0:
                tgt_ctx = tgt_ctx + self.pos_enc(tgt_ctx)

            # One output chunk per window: [B * T // chunk_size, chunk_size, C]
            tgt = torch.zeros_like(tgt_ctx)[:, -self.chunk_size:, :]
            for j in range(math.ceil(tgt.shape[0] / K)):
                tgt[j*K:(j+1)*K], _sa_map, _ca_map = tf_dec_layer(
                    tgt_ctx[j*K:(j+1)*K], mem_ctx[j*K:(j+1)*K],
                    self.chunk_size)

            # Back to [B, T, C] so the next layer can prepend its context.
            tgt = tgt.reshape(B, T, C)

        tgt = tgt.permute(0, 2, 1)  # [B, C, T]
        if mod != 0:
            tgt = tgt[..., :-mod]

        return tgt, ctx_buf
class MaskNet(nn.Module):
    """
    Predicts a mask over encoder features, conditioned on a label embedding.

    Flow: dilated causal conv encoder -> label integration by element-wise
    product -> (if `proj`) grouped 1x1 projections to `dec_dim` -> causal
    transformer decoder with cross-attention -> (if `proj`) projection back
    to `enc_dim` -> (if `skip_connection`) residual add of the label-scaled
    features.
    """

    def __init__(self, enc_dim, num_enc_layers, dec_dim, dec_buf_len,
                 dec_chunk_size, num_dec_layers, use_pos_enc, skip_connection,
                 proj):
        super().__init__()
        self.skip_connection = skip_connection
        self.proj = proj

        # Encoder based on dilated causal convolutions.
        self.encoder = DilatedCausalConvEncoder(
            channels=enc_dim, num_layers=num_enc_layers)

        def _grouped_pointwise(n_in, n_out):
            # 1x1 grouped conv + ReLU; groups=dec_dim requires both n_in and
            # n_out (i.e. enc_dim and dec_dim) to be multiples of dec_dim.
            return nn.Sequential(
                nn.Conv1d(n_in, n_out, kernel_size=1, stride=1, padding=0,
                          groups=dec_dim),
                nn.ReLU())

        # Projections between encoder and decoder widths (creation order kept).
        self.proj_e2d_e = _grouped_pointwise(enc_dim, dec_dim)
        self.proj_e2d_l = _grouped_pointwise(enc_dim, dec_dim)
        self.proj_d2e = _grouped_pointwise(dec_dim, enc_dim)

        # Transformer decoder working on chunks of `dec_chunk_size` frames
        # with `dec_buf_len` frames of context.
        self.decoder = CausalTransformerDecoder(
            model_dim=dec_dim, ctx_len=dec_buf_len, chunk_size=dec_chunk_size,
            num_layers=num_dec_layers, nhead=8, use_pos_enc=use_pos_enc,
            ff_dim=2 * dec_dim)

    def forward(self, x, l, enc_buf, dec_buf):
        """
        Generates a mask from latent input `x` and label embedding `l`.

        Args:
            x: [B, C, T]
                Latent input sequence.
            l: [B, C]
                Label embedding.
            enc_buf: context buffer of the DCC encoder.
            dec_buf: context buffer of the transformer decoder.
        Returns:
            (m, enc_buf, dec_buf): the mask and the updated buffers.
        """
        encoded, enc_buf = self.encoder(x, enc_buf)

        # Label integration: scale every frame by the label embedding.
        l = l.unsqueeze(2) * encoded

        if not self.proj:
            # Cross-attention directly at encoder width.
            m, dec_buf = self.decoder(l, encoded, dec_buf)
        else:
            mem_dec = self.proj_e2d_e(encoded)
            query_dec = self.proj_e2d_l(l)
            # Cross-attention at decoder width, then back to encoder width.
            m, dec_buf = self.decoder(query_dec, mem_dec, dec_buf)
            m = self.proj_d2e(m)

        if self.skip_connection:
            m = l + m

        return m, enc_buf, dec_buf
class Net(nn.Module):
    """
    Label-conditioned extraction network: extracts the audio corresponding
    to a label from an input mixture.

    `in_conv` turns the waveform into `enc_dim` latent frames with hop `L`,
    `mask_gen` predicts a mask from those frames and the label embedding,
    and `out_conv` (a transposed convolution whose input is prefixed with the
    last `out_buf_len` frames of the previous call) maps the masked frames
    back to a waveform.
    """

    def __init__(self, label_len, L=8,
                 enc_dim=512, num_enc_layers=10,
                 dec_dim=256, dec_buf_len=100, num_dec_layers=2,
                 dec_chunk_size=72, out_buf_len=2,
                 use_pos_enc=True, skip_connection=True, proj=True,
                 lookahead=True):
        super().__init__()
        self.L = L
        self.out_buf_len = out_buf_len
        self.enc_dim = enc_dim
        self.lookahead = lookahead

        # Input conv to convert input audio to a latent representation.
        # With lookahead the kernel spans 3 hops (forward pads L samples on
        # each side), otherwise a single hop.
        kernel_size = 3 * L if lookahead else L
        self.in_conv = nn.Sequential(
            nn.Conv1d(in_channels=1,
                      out_channels=enc_dim, kernel_size=kernel_size, stride=L,
                      padding=0, bias=False),
            nn.ReLU())

        # Label embedding layer
        self.label_embedding = nn.Sequential(
            nn.Linear(label_len, 512),
            nn.LayerNorm(512),
            nn.ReLU(),
            nn.Linear(512, enc_dim),
            nn.LayerNorm(enc_dim),
            nn.ReLU())

        # Mask generator
        self.mask_gen = MaskNet(
            enc_dim=enc_dim, num_enc_layers=num_enc_layers,
            dec_dim=dec_dim, dec_buf_len=dec_buf_len,
            dec_chunk_size=dec_chunk_size, num_dec_layers=num_dec_layers,
            use_pos_enc=use_pos_enc, skip_connection=skip_connection,
            proj=proj)

        # Output conv layer. With kernel (out_buf_len + 1) * L, stride L and
        # padding out_buf_len * L, an input of out_buf_len + F frames yields
        # exactly F * L samples.
        self.out_conv = nn.Sequential(
            nn.ConvTranspose1d(
                in_channels=enc_dim, out_channels=1,
                kernel_size=(out_buf_len + 1) * L,
                stride=L,
                padding=out_buf_len * L, bias=False),
            nn.Tanh())

    def init_buffers(self, batch_size, device):
        """Return zero-initialized (enc_buf, dec_buf, out_buf) streaming state."""
        enc_buf = self.mask_gen.encoder.init_ctx_buf(batch_size, device)
        dec_buf = self.mask_gen.decoder.init_ctx_buf(batch_size, device)
        out_buf = torch.zeros(batch_size, self.enc_dim, self.out_buf_len,
                              device=device)
        return enc_buf, dec_buf, out_buf

    def forward(self, x, label, init_enc_buf=None, init_dec_buf=None,
                init_out_buf=None, pad=True):
        """
        Extracts the audio corresponding to `label` in the input mixture `x`.

        Args:
            x: [B, 1, T]
                Input audio mixture (in_conv takes a single input channel).
            label: [B, label_len]
                Label vector (e.g. one-hot) selecting the target sound.
            init_enc_buf, init_dec_buf, init_out_buf:
                Streaming state from a previous call (see `init_buffers`).
                Either all three are given or all three are None.
            pad: if True, right-pad `x` to a multiple of `L` (plus `L`
                samples on each side when `lookahead` is enabled) and strip
                the right padding from the output.
        Returns:
            If no buffers were passed: the extracted audio [B, 1, T'].
            Otherwise: (out, enc_buf, dec_buf, out_buf) for the next call.
        """
        mod = 0
        if pad:
            pad_size = (self.L, self.L) if self.lookahead else (0, 0)
            x, mod = mod_pad(x, chunk_size=self.L, pad=pad_size)

        if init_enc_buf is None or init_dec_buf is None or init_out_buf is None:
            assert init_enc_buf is None and \
                init_dec_buf is None and \
                init_out_buf is None, \
                "All three buffers have to be initialized, or " \
                "all of them have to be None."
            enc_buf, dec_buf, out_buf = self.init_buffers(
                x.shape[0], x.device)
        else:
            enc_buf, dec_buf, out_buf = \
                init_enc_buf, init_dec_buf, init_out_buf

        # Generate latent space representation of the input
        x = self.in_conv(x)

        # Generate label embedding
        l = self.label_embedding(label)  # [B, label_len] --> [B, channels]

        # Generate mask corresponding to the label
        m, enc_buf, dec_buf = self.mask_gen(x, l, enc_buf, dec_buf)

        # Apply mask, prepend the previous call's trailing frames and decode
        x = x * m
        x = torch.cat((out_buf, x), dim=-1)
        out_buf = x[..., -self.out_buf_len:]
        x = self.out_conv(x)

        # Remove mod padding, if present.
        if mod != 0:
            x = x[:, :, :-mod]

        if init_enc_buf is None:
            return x
        return x, enc_buf, dec_buf, out_buf
# Define optimizer, loss and metrics
def optimizer(model, data_parallel=False, **kwargs):
    """Build an Adam optimizer over all of ``model``'s parameters.

    Args:
        model: Module whose ``parameters()`` are optimized.
        data_parallel: Accepted for interface compatibility; not used here.
        **kwargs: Forwarded to ``torch.optim.Adam`` (e.g. ``lr``).
    """
    trainable = model.parameters()
    return optim.Adam(trainable, **kwargs)
def loss(pred, tgt):
    """Training loss: negated weighted mix of batch-mean SNR (weight 0.9)
    and SI-SNR (weight 0.1) between ``pred`` and ``tgt``; lower is better."""
    snr_mean = snr(pred, tgt).mean()
    si_snr_mean = si_snr(pred, tgt).mean()
    return -0.9 * snr_mean - 0.1 * si_snr_mean
def metrics(mixed, output, gt):
    """Compute per-example metric improvements of ``output`` over ``mixed``.

    For each metric in (snr, si_snr) and each example in the batch, the value
    reported is ``metric(output, gt) - metric(mixed, gt)`` converted with
    ``.item()``. ``mixed`` is first truncated to ``gt.shape[1]`` channels.

    Returns:
        dict mapping each metric function's ``__name__`` to a list with one
        value per example.
    """
    def _improvements(metric_fn, src, pred, tgt):
        return [(metric_fn(p, t) - metric_fn(s, t)).cpu().item()
                for s, t, p in zip(src, tgt, pred)]

    mixed_trimmed = mixed[:, :gt.shape[1], :]
    return {
        m_fn.__name__: _improvements(m_fn, mixed_trimmed, output, gt)
        for m_fn in (snr, si_snr)
    }