| |
| |
| """ |
| @File : htdemucs.py |
| @Time : 2023/8/8 下午4:27 |
| @Author : waytan |
| @Contact : waytan@tencent.com |
| @License : (C)Copyright 2023, Tencent |
| @Desc : The spectrogram and Hybrid version of Demucs |
| """ |
|
|
| import math |
| import typing as tp |
| from copy import deepcopy |
| from fractions import Fraction |
|
|
| import torch |
| from torch import nn |
| from torch.nn import functional as F |
| from einops import rearrange |
| from openunmix.filtering import wiener |
|
|
| from .transformer import CrossTransformerEncoder |
| from .demucs import DConv, rescale_module |
| from .states import capture_init |
| from .spec import spectro, ispectro |
|
|
|
|
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
    """Pad the last dimension of `x` with `F.pad`, tolerating reflect padding on short inputs.

    `F.pad(..., mode='reflect')` needs the padded dimension to be longer than each
    pad amount. When the input is too short for that, zeros are first appended
    (right side first, then left) so the signal has `max(left, right) + 1` samples,
    and the reflect amounts are reduced by the same quantities. Either way the
    untouched input ends up at ``out[..., left:left + length]``.

    Args:
        x: tensor whose last dimension is padded.
        paddings: ``(left, right)`` padding amounts.
        mode: padding mode forwarded to `F.pad`.
        value: fill value forwarded to `F.pad`.

    Returns:
        Tensor whose last dimension has size ``length + left + right``.
    """
    left, right = paddings
    original = x
    n = x.shape[-1]
    widest = max(left, right)
    if mode == 'reflect' and n <= widest:
        # Grow the signal to `widest + 1` samples so the reflection is legal.
        fill = widest - n + 1
        fill_right = min(right, fill)
        fill_left = fill - fill_right
        x = F.pad(x, (fill_left, fill_right))
        paddings = (left - fill_left, right - fill_right)
    padded = F.pad(x, paddings, mode, value)
    # Sanity checks: final size, and the original samples sit untouched in the middle.
    assert padded.shape[-1] == n + left + right
    assert (padded[..., left: left + n] == original).all()
    return padded
|
|
|
|
class ScaledEmbedding(nn.Module):
    """Embedding table with a boosted effective learning rate.

    The stored parameters are divided by `scale` at construction time and multiplied
    back on every lookup. This is equivalent to scaling the embedding learning rate
    by `scale`. With `smooth=True`, the initial rows are a running sum of the random
    initialization divided by `sqrt(row_index + 1)`, so neighbouring rows start out
    similar ("continuous" embeddings).
    """

    def __init__(self, num_embeddings: int, embedding_dim: int,
                 scale: float = 10., smooth=False):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        table = self.embedding.weight.data
        if smooth:
            # A running sum of n i.i.d. normal draws grows like sqrt(n); dividing by
            # sqrt(n) keeps every row at roughly the initial magnitude.
            counts = torch.arange(1, num_embeddings + 1).to(table)
            table[:] = torch.cumsum(table, dim=0) / counts.sqrt()[:, None]
        table /= scale
        self.scale = scale

    @property
    def weight(self):
        """Effective (rescaled) embedding matrix."""
        return self.embedding.weight * self.scale

    def forward(self, x):
        """Look up the indices `x` and undo the storage down-scaling."""
        return self.embedding(x) * self.scale
|
|
|
|
class HEncLayer(nn.Module):
    """Encoder layer, used by both the time and the frequency branch.

    Pipeline: strided convolution (over the frequency axis of a 4D input when
    `freq=True`, over time otherwise) -> optional injection of the time branch ->
    norm + GELU -> optional residual `DConv` -> optional "rewrite" convolution + GLU.
    """

    def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False,
                 freq=True, dconv=True, norm=True, context=0, dconv_kw=None, pad=True,
                 rewrite=True):
        """Encoder layer. This is used both by the time and the frequency branch.

        Args:
            chin: number of input channels.
            chout: number of output channels.
            kernel_size: kernel size of the main convolution (along frequency when
                `freq` is True, along time otherwise).
            stride: stride of the main convolution.
            norm_groups: number of groups of the GroupNorm layers.
            empty: if True, only the main convolution is built and `forward` returns
                its raw output.
            freq: if True, use a `[kernel_size, 1]` Conv2d over the frequency axis;
                otherwise a Conv1d over time.
            dconv: if truthy, add a residual `DConv` branch.
            norm: if True, use GroupNorm; otherwise identity.
            context: context of the rewrite convolution (kernel `1 + 2 * context`).
            dconv_kw: keyword arguments forwarded to `DConv`; `None` means none.
            pad: if True, pad the main convolution by `kernel_size // 4`.
            rewrite: if True, add the rewrite convolution followed by a GLU.
        """
        super().__init__()
        # `None` default avoids a shared mutable dict, but must still be usable with
        # the default `dconv=True` (`DConv(chout, **None)` would raise TypeError).
        if dconv_kw is None:
            dconv_kw = {}

        def norm_fn(d):
            # GroupNorm when normalization is enabled, otherwise a no-op module.
            return nn.GroupNorm(norm_groups, d) if norm else nn.Identity()

        pad = kernel_size // 4 if pad else 0
        klass = nn.Conv1d
        self.freq = freq
        self.kernel_size = kernel_size
        self.stride = stride
        self.empty = empty
        self.norm = norm
        self.pad = pad
        if freq:
            # Convolve along frequency only: kernel/stride/padding of 1/1/0 on time.
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
            pad = [pad, 0]
            klass = nn.Conv2d
        self.conv = klass(chin, chout, kernel_size, stride, pad)
        if self.empty:
            return
        self.norm1 = norm_fn(chout)
        self.rewrite = None
        if rewrite:
            self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)
            self.norm2 = norm_fn(2 * chout)

        self.dconv = None
        if dconv:
            self.dconv = DConv(chout, **dconv_kw)

    def forward(self, x, inject=None):
        """Encode `x`.

        `inject` is used to inject the result from the time branch into the frequency
        branch, when both have the same stride. Its last dimension must match the
        convolution output; a 3D `inject` is broadcast over the frequency axis of a
        4D output.
        """
        if not self.freq:
            if x.dim() == 4:
                # Fold frequencies into channels: (B, C, Fr, T) -> (B, C * Fr, T).
                x = x.flatten(1, 2)
            # Right-pad time so its length is a multiple of the stride.
            remainder = x.shape[-1] % self.stride
            if remainder:
                x = F.pad(x, (0, self.stride - remainder))
        y = self.conv(x)
        if self.empty:
            return y
        if inject is not None:
            assert inject.shape[-1] == y.shape[-1], (inject.shape, y.shape)
            if inject.dim() == 3 and y.dim() == 4:
                inject = inject[:, :, None]
            y = y + inject
        y = F.gelu(self.norm1(y))
        if self.dconv is not None:
            if self.freq:
                # DConv works on 3D input: fold frequencies into the batch dimension.
                b, c, fr, t = y.shape
                y = y.permute(0, 2, 1, 3).reshape(-1, c, t)
            y = self.dconv(y)
            if self.freq:
                y = y.view(b, fr, c, t).permute(0, 2, 1, 3)
        if self.rewrite is not None:
            z = self.norm2(self.rewrite(y))
            z = F.glu(z, dim=1)
        else:
            z = y
        return z
|
|
|
|
class MultiWrap(nn.Module):
    """
    Take one layer and replicate it N times, each replica acting on one frequency band.

    The bands are cut so that, if all N replicas had the same weights, the result would
    be exactly equivalent to applying the original module on all frequencies.
    Wraps either an `HEncLayer` (encoder) or an `HDecLayer` (decoder).
    """
    def __init__(self, layer, split_ratios):
        """
        Args:
            layer: frequency layer to replicate. It must have `freq=True`, no
                normalization and padding enabled (checked by the asserts below);
                decoder layers must also have `context_freq=False`.
            split_ratios: fractions of the frequency axis at which bands end
                (presumably increasing values in (0, 1) — not validated here). An
                implicit final ratio of 1 is appended in `forward`, so there are
                `len(split_ratios) + 1` replicas.
        """
        super().__init__()
        self.split_ratios = split_ratios
        self.layers = nn.ModuleList()
        # True when wrapping an encoder layer, False for a decoder layer.
        self.conv = isinstance(layer, HEncLayer)
        assert not layer.norm
        assert layer.freq
        assert layer.pad
        if not self.conv:
            assert not layer.context_freq
        for _ in range(len(split_ratios) + 1):
            lay = deepcopy(layer)
            if self.conv:
                # Frequency padding is applied manually in `forward`, only on the
                # outer edges of the first and last bands.
                lay.conv.padding = (0, 0)
            else:
                # Disable HDecLayer's own frequency cropping; `forward` crops the
                # outer edges itself.
                lay.pad = False
            # Give every replica its own freshly initialized parameters.
            for m in lay.modules():
                if hasattr(m, 'reset_parameters'):
                    m.reset_parameters()
            self.layers.append(lay)

    def forward(self, x, skip=None, length=None):
        """Run each replica on its frequency band and stitch the results along dim 2.

        Args:
            x: 4D input whose dim 2 (frequency) is split into bands.
            skip: decoder only — skip connection, sliced with the same band limits.
            length: unused; kept so the signature matches `HDecLayer.forward`.

        Returns:
            Encoder: the concatenated band outputs. Decoder: `(out, None)`.
        """
        _, _, fr, _ = x.shape

        # The last band always extends to the top of the frequency axis.
        ratios = list(self.split_ratios) + [1]
        start = 0
        outs = []
        for ratio, layer in zip(ratios, self.layers):
            if self.conv:
                pad = layer.kernel_size // 4
                if ratio == 1:
                    limit = fr
                    frames = -1  # not used for the last band
                else:
                    limit = int(round(fr * ratio))
                    le = limit - start
                    if start == 0:
                        le += pad
                    # Number of conv output frames (rounded) covering this band; then
                    # snap `limit` to the last input bin that those frames read.
                    frames = round((le - layer.kernel_size) / layer.stride + 1)
                    limit = start + (frames - 1) * layer.stride + layer.kernel_size
                    if start == 0:
                        limit -= pad
                assert limit - start > 0, (limit, start)
                assert limit <= fr, (limit, fr)
                y = x[:, :, start:limit, :]
                # Zero-pad only the outer edges of the full frequency axis.
                if start == 0:
                    y = F.pad(y, (0, 0, pad, 0))
                if ratio == 1:
                    y = F.pad(y, (0, 0, 0, pad))
                outs.append(layer(y))
                # The next band starts where the next convolution window would start.
                start = limit - layer.kernel_size + layer.stride
            else:
                if ratio == 1:
                    limit = fr
                else:
                    limit = int(round(fr * ratio))
                # Temporarily mark the replica as last so it skips its GELU; the
                # activation is applied once after stitching (see below).
                last = layer.last
                layer.last = True

                y = x[:, :, start:limit]
                s = skip[:, :, start:limit]
                out, _ = layer(y, s, None)
                if outs:
                    # The first `stride` output rows overlap the tail of the previous
                    # band: overlap-add them there and drop them here. This replica's
                    # bias is subtracted so it is not counted twice (its output has no
                    # norm and no GELU, per the asserts and `last = True` above).
                    outs[-1][:, :, -layer.stride:] += (
                        out[:, :, :layer.stride] - layer.conv_tr.bias.view(1, -1, 1, 1))
                    out = out[:, :, layer.stride:]
                # Crop the outer edges in place of the disabled HDecLayer cropping.
                # NOTE(review): `-layer.stride // 2` parses as `(-layer.stride) // 2`,
                # which crops one extra row compared to `layer.stride // 2` when the
                # stride is odd — confirm this asymmetry is intended.
                if ratio == 1:
                    out = out[:, :, :-layer.stride // 2, :]
                if start == 0:
                    out = out[:, :, layer.stride // 2:, :]
                outs.append(out)
                layer.last = last
                start = limit
        out = torch.cat(outs, dim=2)
        # `last` is only defined in the decoder branch; `not self.conv` short-circuits
        # for encoders. It holds the wrapped layer's own `last` flag.
        if not self.conv and not last:
            out = F.gelu(out)
        if self.conv:
            return out
        else:
            return out, None
|
|
|
|
class HDecLayer(nn.Module):
    """Decoder layer, used by both the time and the frequency branch.

    Mirror of `HEncLayer`:
    - add the skip connection;
    - apply an optional rewrite convolution + GLU;
    - apply an optional residual `DConv`;
    - apply a strided transposed convolution, upsampling along frequency
      (`freq=True`) or along time.
    """

    def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False,
                 freq=True, dconv=True, norm=True, context=1, dconv_kw=None, pad=True,
                 context_freq=True, rewrite=True):
        """
        Same as HEncLayer but for decoder. See `HEncLayer` for documentation.

        Additional args:
            last: if True, the final GELU is skipped (outermost layer).
            empty: if True, only the transposed convolution (and `norm2`) is built,
                and `forward` requires `skip` to be None.
            context_freq: if True, the rewrite convolution has context along both
                axes (kernel `1 + 2 * context`). Otherwise it has context only
                along the last axis (kernel `[1, 1 + 2 * context]`).
            dconv_kw: keyword arguments forwarded to `DConv`; `None` means none.
        """
        super().__init__()
        # `None` default avoids a shared mutable dict, but must still be usable with
        # the default `dconv=True` (`DConv(chin, **None)` would raise TypeError).
        if dconv_kw is None:
            dconv_kw = {}

        def norm_fn(d):
            # GroupNorm when normalization is enabled, otherwise a no-op module.
            return nn.GroupNorm(norm_groups, d) if norm else nn.Identity()

        pad = kernel_size // 4 if pad else 0
        self.pad = pad
        self.last = last
        self.freq = freq
        self.chin = chin
        self.empty = empty
        self.stride = stride
        self.kernel_size = kernel_size
        self.norm = norm
        self.context_freq = context_freq
        klass = nn.Conv1d
        klass_tr = nn.ConvTranspose1d
        if freq:
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
            klass = nn.Conv2d
            klass_tr = nn.ConvTranspose2d
        self.conv_tr = klass_tr(chin, chout, kernel_size, stride)
        self.norm2 = norm_fn(chout)
        if self.empty:
            return
        self.rewrite = None
        if rewrite:
            if context_freq:
                self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
            else:
                self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1,
                                     [0, context])
            self.norm1 = norm_fn(2 * chin)

        self.dconv = None
        if dconv:
            self.dconv = DConv(chin, **dconv_kw)

    def forward(self, x, skip, length):
        """Decode one level.

        Args:
            x: input. For a frequency layer, a 3D `(B, C, T)` input is reshaped to
                `(B, chin, -1, T)`.
            skip: skip connection added to `x` (must be None for empty layers).
            length: target time length for time-branch layers (ignored when `freq`).

        Returns:
            `(z, y)`: the layer output, and the activation fed to the transposed
            convolution.
        """
        if self.freq and x.dim() == 3:
            b, _, t = x.shape
            x = x.view(b, self.chin, -1, t)

        if self.empty:
            y = x
            assert skip is None
        else:
            x = x + skip
            if self.rewrite is not None:
                y = F.glu(self.norm1(self.rewrite(x)), dim=1)
            else:
                y = x
            if self.dconv is not None:
                if self.freq:
                    # DConv works on 3D input: fold frequencies into the batch dim.
                    b, c, fr, t = y.shape
                    y = y.permute(0, 2, 1, 3).reshape(-1, c, t)
                y = self.dconv(y)
                if self.freq:
                    y = y.view(b, fr, c, t).permute(0, 2, 1, 3)
        z = self.norm2(self.conv_tr(y))
        if self.freq:
            if self.pad:
                # Crop `pad` rows on both sides of the frequency axis.
                z = z[..., self.pad:-self.pad, :]
        else:
            # Crop the padding and restore the exact requested time length.
            z = z[..., self.pad:self.pad + length]
            assert z.shape[-1] == length, (z.shape[-1], length)
        if not self.last:
            z = F.gelu(z)
        return z, y
|
|
|
|
class HTDemucs(nn.Module):
    """
    Spectrogram and hybrid Demucs model.
    The spectrogram model has the same structure as Demucs, except the first few layers are over the
    frequency axis, until there is only 1 frequency, and then it moves to time convolutions.
    Frequency layers can still access information across time steps thanks to the DConv residual.

    Hybrid models have a parallel time branch. At some layer, the time branch has the same stride
    as the frequency branch and then the two are combined. The opposite happens in the decoder.

    Models can either use naive iSTFT from masking, Wiener filtering ([Uhlich et al. 2017]),
    or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on the
    Open Unmix implementation [Stoter et al. 2019].

    The loss is always on the temporal domain, by backpropagating through the above
    output methods and iSTFT. This allows defining hybrid models nicely. However, this
    somewhat breaks Wiener filtering: doing more iterations at test time changes the
    spectrogram contribution without changing the waveform one, which leads to worse
    performance. I tried using the residual option in OpenUnmix Wiener implementation,
    but it didn't improve.
    CaC on the other hand provides similar performance for hybrid, and works naturally with
    hybrid models.

    This model also uses frequency embeddings to improve efficiency on convolutions
    over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf).

    Unlike classic Demucs, there is no resampling here, and normalization is always applied.
    """

    @capture_init
    def __init__(
        self,
        sources,
        # Channels
        audio_channels=2,
        channels=48,
        channels_time=None,
        growth=2,
        # STFT
        nfft=4096,
        wiener_iters=0,
        end_iters=0,
        wiener_residual=False,
        cac=True,
        # Main structure
        depth=4,
        rewrite=True,
        # Frequency branch
        multi_freqs=None,
        multi_freqs_depth=3,
        freq_emb=0.2,
        emb_scale=10,
        emb_smooth=True,
        # Convolutions
        kernel_size=8,
        time_stride=2,
        stride=4,
        context=1,
        context_enc=0,
        # Normalization
        norm_starts=4,
        norm_groups=4,
        # DConv residual branch
        dconv_mode=1,
        dconv_depth=2,
        dconv_comp=8,
        dconv_init=1e-3,
        # Before the Transformer
        bottom_channels=0,
        # Transformer
        t_layers=5,
        t_emb="sin",
        t_hidden_scale=4.0,
        t_heads=8,
        t_dropout=0.0,
        t_max_positions=10000,
        t_norm_in=True,
        t_norm_in_group=False,
        t_group_norm=False,
        t_norm_first=True,
        t_norm_out=True,
        t_max_period=10000.0,
        t_weight_decay=0.0,
        t_lr=None,
        t_layer_scale=True,
        t_gelu=True,
        t_weight_pos_embed=1.0,
        t_sin_random_shift=0,
        t_cape_mean_normalize=True,
        t_cape_augment=True,
        t_cape_glob_loc_scale=None,
        t_sparse_self_attn=False,
        t_sparse_cross_attn=False,
        t_mask_type="diag",
        t_mask_random_seed=42,
        t_sparse_attn_window=500,
        t_global_window=100,
        t_sparsity=0.95,
        t_auto_sparsity=False,
        # Cross-attention position
        t_cross_first=False,
        # Weight init
        rescale=0.1,
        # Metadata
        samplerate=44100,
        segment=10,
        use_train_segment=True,
    ):
        """
        Args:
            sources (list[str]): list of source names.
            audio_channels (int): input/output audio channels.
            channels (int): initial number of hidden channels.
            channels_time: if not None, use a different `channels` value for the time branch.
            growth: increase the number of hidden channels by this factor at each layer.
            nfft: STFT size; the frequency branch works on `nfft // 2` bins. Note that
                changing this requires careful computation of various shape parameters
                and will not work out of the box for hybrid models.
            wiener_iters: when using Wiener filtering, number of iterations at test time.
            end_iters: same but at train time. Must be equal to `wiener_iters`
                (asserted).
            wiener_residual: add residual source before wiener filtering.
            cac: uses complex as channels, i.e. complex numbers are 2 channels each
                in input and output. no further processing is done before ISTFT.
            depth (int): number of layers in the encoder and in the decoder.
            rewrite (bool): add 1x1 convolution to each layer.
            multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
            multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
                layers will be wrapped.
            freq_emb: add frequency embedding after the first frequency layer if > 0,
                the actual value controls the weight of the embedding.
            emb_scale: equivalent to scaling the embedding learning rate
            emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
            kernel_size: kernel_size for encoder and decoder layers.
            stride: stride for encoder and decoder layers.
            time_stride: stride for the final time layer, after the merge.
            context: context for 1x1 conv in the decoder.
            context_enc: context for 1x1 conv in the encoder.
            norm_starts: layer at which group norm starts being used.
                decoder layers are numbered in reverse order.
            norm_groups: number of groups for group norm.
            dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
            dconv_depth: depth of residual DConv branch.
            dconv_comp: compression of DConv branch.
            dconv_init: initial scale for the DConv branch LayerScale.
            bottom_channels: if >0 it adds a linear layer (1x1 Conv) before and after the
                transformer in order to change the number of channels
            t_layers: number of layers in each branch (waveform and spec) of the transformer
            t_emb: "sin", "cape" or "scaled"
            t_hidden_scale: the hidden scale of the Feedforward parts of the transformer
                for instance if C = 384 (the number of channels in the transformer) and
                t_hidden_scale = 4.0 then the intermediate layer of the FFN has dimension
                384 * 4 = 1536
            t_heads: number of heads for the transformer
            t_dropout: dropout in the transformer
            t_max_positions: max_positions for the "scaled" positional embedding, only
                useful if t_emb="scaled"
            t_norm_in: (bool) norm before adding positional embedding and getting into the
                transformer layers
            t_norm_in_group: (bool) if True while t_norm_in=True, the norm is on all the
                timesteps (GroupNorm with group=1)
            t_group_norm: (bool) if True, the norms of the Encoder Layers are on all the
                timesteps (GroupNorm with group=1)
            t_norm_first: (bool) if True the norm is before the attention and before the FFN
            t_norm_out: (bool) if True, there is a GroupNorm (group=1) at the end of each layer
            t_max_period: (float) denominator in the sinusoidal embedding expression
            t_weight_decay: (float) weight decay for the transformer
            t_lr: (float) specific learning rate for the transformer
            t_layer_scale: (bool) Layer Scale for the transformer
            t_gelu: (bool) activations of the transformer are GeLU if True, ReLU else
            t_weight_pos_embed: (float) weighting of the positional embedding
            t_sin_random_shift: (int) forwarded to `CrossTransformerEncoder` as
                `sin_random_shift`.
            t_cape_mean_normalize: (bool) if t_emb="cape", normalisation of positional embeddings
                see: https://arxiv.org/abs/2106.03143
            t_cape_augment: (bool) if t_emb="cape", must be True during training and False
                during the inference, see: https://arxiv.org/abs/2106.03143
            t_cape_glob_loc_scale: (list of 3 floats) if t_emb="cape", CAPE parameters
                see: https://arxiv.org/abs/2106.03143. Defaults to [5000.0, 1.0, 1.4].
            t_sparse_self_attn: (bool) if True, the self attentions are sparse
            t_sparse_cross_attn: (bool) if True, the cross-attentions are sparse (don't use it
                unless you designed really specific masks)
            t_mask_type: (str) can be "diag", "jmask", "random", "global" or any combination
                with '_' between: i.e. "diag_jmask_random" (note that this is permutation
                invariant i.e. "diag_jmask_random" is equivalent to "jmask_random_diag")
            t_mask_random_seed: (int) if "random" is in t_mask_type, controls the seed
                that generated the random part of the mask
            t_sparse_attn_window: (int) if "diag" is in t_mask_type, for a query (i), and
                a key (j), the mask is True if |i-j|<=t_sparse_attn_window
            t_global_window: (int) if "global" is in t_mask_type, mask[:t_global_window, :]
                and mask[:, :t_global_window] will be True
            t_sparsity: (float) if "random" is in t_mask_type, t_sparsity is the sparsity
                level of the random part of the mask.
            t_auto_sparsity: (bool) forwarded to `CrossTransformerEncoder` as
                `auto_sparsity`.
            t_cross_first: (bool) if True cross attention is the first layer of the
                transformer (False seems to be better)
            rescale: weight rescaling trick
            samplerate: audio sample rate; with `segment` it defines the training
                length `int(segment * samplerate)` in samples.
            segment: training segment length in seconds. When `use_train_segment` is
                True it is overwritten from the input length at every training
                forward pass.
            use_train_segment: (bool) if True, the actual size that is used during the
                training is used during inference.
        """
        super().__init__()
        self.cac = cac
        self.wiener_residual = wiener_residual
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.depth = depth
        self.bottom_channels = bottom_channels
        self.channels = channels
        self.samplerate = samplerate
        self.segment = segment
        self.use_train_segment = use_train_segment
        self.nfft = nfft
        self.hop_length = nfft // 4
        self.wiener_iters = wiener_iters
        self.end_iters = end_iters
        self.freq_emb = None
        assert wiener_iters == end_iters

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        self.tencoder = nn.ModuleList()
        self.tdecoder = nn.ModuleList()

        chin = audio_channels
        chin_z = chin  # number of channels for the freq branch
        if self.cac:
            chin_z *= 2
        chout = channels_time or channels
        chout_z = channels
        freqs = nfft // 2

        for index in range(depth):
            norm = index >= norm_starts
            freq = freqs > 1
            stri = stride
            ker = kernel_size
            if not freq:
                # Frequency axis already collapsed: pure time convolutions from here.
                assert freqs == 1
                ker = time_stride * 2
                stri = time_stride

            pad = True
            last_freq = False
            if freq and freqs <= kernel_size:
                # Last frequency layer: a single kernel covers all remaining bins.
                ker = freqs
                pad = False
                last_freq = True

            kw = {
                "kernel_size": ker,
                "stride": stri,
                "freq": freq,
                "pad": pad,
                "norm": norm,
                "rewrite": rewrite,
                "norm_groups": norm_groups,
                "dconv_kw": {
                    "depth": dconv_depth,
                    "compress": dconv_comp,
                    "init": dconv_init,
                    "gelu": True,
                },
            }
            # Time-branch layers always use the regular kernel/stride with padding.
            kwt = dict(kw)
            kwt["freq"] = 0
            kwt["kernel_size"] = kernel_size
            kwt["stride"] = stride
            kwt["pad"] = True
            kw_dec = dict(kw)
            multi = False
            if multi_freqs and index < multi_freqs_depth:
                multi = True
                kw_dec["context_freq"] = False

            if last_freq:
                # Both branches must have the same width where they merge.
                chout_z = max(chout, chout_z)
                chout = chout_z

            enc = HEncLayer(
                chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw
            )
            if freq:
                tenc = HEncLayer(
                    chin,
                    chout,
                    dconv=dconv_mode & 1,
                    context=context_enc,
                    empty=last_freq,
                    **kwt
                )
                self.tencoder.append(tenc)

            if multi:
                enc = MultiWrap(enc, multi_freqs)
            self.encoder.append(enc)
            if index == 0:
                # The outermost decoders output one set of channels per source.
                chin = self.audio_channels * len(self.sources)
                chin_z = chin
                if self.cac:
                    chin_z *= 2
            dec = HDecLayer(
                chout_z,
                chin_z,
                dconv=dconv_mode & 2,
                last=index == 0,
                context=context,
                **kw_dec
            )
            if multi:
                dec = MultiWrap(dec, multi_freqs)
            if freq:
                tdec = HDecLayer(
                    chout,
                    chin,
                    dconv=dconv_mode & 2,
                    empty=last_freq,
                    last=index == 0,
                    context=context,
                    **kwt
                )
                self.tdecoder.insert(0, tdec)
            self.decoder.insert(0, dec)

            chin = chout
            chin_z = chout_z
            chout = int(growth * chout)
            chout_z = int(growth * chout_z)
            if freq:
                if freqs <= kernel_size:
                    freqs = 1
                else:
                    freqs //= stride
            if index == 0 and freq_emb:
                self.freq_emb = ScaledEmbedding(
                    freqs, chin_z, smooth=emb_smooth, scale=emb_scale
                )
                self.freq_emb_scale = freq_emb

        if rescale:
            rescale_module(self, reference=rescale)

        transformer_channels = channels * growth ** (depth - 1)
        if bottom_channels:
            self.channel_upsampler = nn.Conv1d(transformer_channels, bottom_channels, 1)
            self.channel_downsampler = nn.Conv1d(
                bottom_channels, transformer_channels, 1
            )
            self.channel_upsampler_t = nn.Conv1d(
                transformer_channels, bottom_channels, 1
            )
            self.channel_downsampler_t = nn.Conv1d(
                bottom_channels, transformer_channels, 1
            )

            transformer_channels = bottom_channels

        if t_layers > 0:
            if t_cape_glob_loc_scale is None:
                t_cape_glob_loc_scale = [5000.0, 1.0, 1.4]
            self.crosstransformer = CrossTransformerEncoder(
                dim=transformer_channels,
                emb=t_emb,
                hidden_scale=t_hidden_scale,
                num_heads=t_heads,
                num_layers=t_layers,
                cross_first=t_cross_first,
                dropout=t_dropout,
                max_positions=t_max_positions,
                norm_in=t_norm_in,
                norm_in_group=t_norm_in_group,
                group_norm=t_group_norm,
                norm_first=t_norm_first,
                norm_out=t_norm_out,
                max_period=t_max_period,
                weight_decay=t_weight_decay,
                lr=t_lr,
                layer_scale=t_layer_scale,
                gelu=t_gelu,
                sin_random_shift=t_sin_random_shift,
                weight_pos_embed=t_weight_pos_embed,
                cape_mean_normalize=t_cape_mean_normalize,
                cape_augment=t_cape_augment,
                cape_glob_loc_scale=t_cape_glob_loc_scale,
                sparse_self_attn=t_sparse_self_attn,
                sparse_cross_attn=t_sparse_cross_attn,
                mask_type=t_mask_type,
                mask_random_seed=t_mask_random_seed,
                sparse_attn_window=t_sparse_attn_window,
                global_window=t_global_window,
                sparsity=t_sparsity,
                auto_sparsity=t_auto_sparsity,
            )
        else:
            self.crosstransformer = None

    def _spec(self, x):
        """Compute the STFT of `x` with exactly `ceil(length / hop_length)` frames."""
        hl = self.hop_length
        nfft = self.nfft

        # We re-pad the signal so that the number of output frames is exactly the
        # input length divided by the hop length (rounded up). This is achieved by
        # reflect-padding by 3/2 of the hop on the left (plus enough on the right to
        # reach a multiple of the hop), then dropping 2 frames on each side.
        # Having all convolutions follow this convention makes it easy to align the
        # time and frequency branches later on.
        assert hl == nfft // 4
        le = int(math.ceil(x.shape[-1] / hl))
        pad = hl // 2 * 3
        x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect")

        # Drop the last frequency bin, keeping `nfft // 2` bins (presumably `spectro`
        # returns `nfft // 2 + 1`).
        z = spectro(x, nfft, hl)[..., :-1, :]
        assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
        z = z[..., 2: 2 + le]
        return z

    def _ispec(self, z, length=None, scale=0):
        """Inverse of `_spec`: return a waveform of exactly `length` samples.

        `length` is required despite its `None` default.
        """
        hl = self.hop_length // (4**scale)
        # Restore the dropped top frequency bin and the 2 frames cut on each side.
        z = F.pad(z, (0, 0, 0, 1))
        z = F.pad(z, (2, 2))
        pad = hl // 2 * 3
        le = hl * int(math.ceil(length / hl)) + 2 * pad
        x = ispectro(z, hl, length=le)
        x = x[..., pad: pad + length]
        return x

    def _magnitude(self, z):
        """Return the network input built from the complex spectrogram `z`.

        With `cac`, real and imaginary parts become separate channels
        (`(B, C, Fr, T)` complex -> `(B, 2 * C, Fr, T)` real). Otherwise the
        magnitude is returned.
        """
        if self.cac:
            b, c, fr, t = z.shape
            m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
            m = m.reshape(b, c * 2, fr, t)
        else:
            m = z.abs()
        return m

    def _mask(self, z, m):
        """Turn the network output `m` into complex source spectrograms.

        With `cac`, `m` already is a full spectrogram (real/imag as channels) and the
        mixture spectrogram `z` is ignored. Otherwise `m` is a magnitude estimate,
        combined with the mixture phase (negative iteration count) or refined with
        Wiener filtering.
        """
        niters = self.wiener_iters
        if self.cac:
            b, s, _, fr, t = m.shape
            out = m.view(b, s, -1, 2, fr, t).permute(0, 1, 2, 4, 5, 3)
            out = torch.view_as_complex(out.contiguous())
            return out
        if self.training:
            niters = self.end_iters
        if niters < 0:
            z = z[:, None]
            return z / (1e-8 + z.abs()) * m
        else:
            return self._wiener(m, z, niters)

    def _wiener(self, mag_out, mix_stft, niters):
        """Apply OpenUnmix Wiener filtering, in windows of 300 frames per sample."""
        init = mix_stft.dtype
        wiener_win_len = 300
        residual = self.wiener_residual

        b, s, c, fq, t = mag_out.shape
        # Move time first and sources/real-imag last, presumably the layout expected
        # by `openunmix.filtering.wiener` — see its documentation.
        mag_out = mag_out.permute(0, 4, 3, 2, 1)
        mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))

        outs = []
        for sample in range(b):
            out = []
            for pos in range(0, t, wiener_win_len):
                frame = slice(pos, pos + wiener_win_len)
                z_out = wiener(
                    mag_out[sample, frame],
                    mix_stft[sample, frame],
                    niters,
                    residual=residual,
                )
                out.append(z_out.transpose(-1, -2))
            outs.append(torch.cat(out, dim=0))
        out = torch.view_as_complex(torch.stack(outs, 0))
        out = out.permute(0, 4, 3, 2, 1).contiguous()
        if residual:
            # Drop the extra residual source.
            out = out[:, :-1]
        assert list(out.shape) == [b, s, c, fq, t]
        return out.to(init)

    def valid_length(self, length: int):
        """
        Return a length that is appropriate for evaluation.
        In our case, always return the training length, unless
        it is smaller than the given length, in which case this
        raises an error.
        """
        if not self.use_train_segment:
            return length
        training_length = int(self.segment * self.samplerate)
        if training_length < length:
            raise ValueError(
                f"Given length {length} is longer than "
                f"training length {training_length}")
        return training_length

    def forward(self, mix):
        """Separate `mix` into `len(self.sources)` sources.

        Args:
            mix: waveform batch, presumably `(batch, audio_channels, time)`; the time
                statistics reduce over dims 1 and 2.

        Returns:
            Tensor of shape `(batch, len(self.sources), -1, time)` with the same time
            length as `mix`.

        Raises:
            ValueError: in evaluation mode with `use_train_segment`, if `mix` is
                longer than the training segment (see `valid_length`).
        """
        length = mix.shape[-1]
        length_pre_pad = None
        if self.use_train_segment:
            if self.training:
                # Remember the training segment length so inference can reuse it.
                self.segment = Fraction(mix.shape[-1], self.samplerate)
            else:
                training_length = int(self.segment * self.samplerate)
                # Longer inputs used to crash (or silently mis-broadcast) at the final
                # `view`; reject them explicitly, consistently with `valid_length`.
                if length > training_length:
                    raise ValueError(
                        f"Given length {length} is longer than "
                        f"training length {training_length}")
                if length < training_length:
                    length_pre_pad = length
                    mix = F.pad(mix, (0, training_length - length_pre_pad))
        # Length of the (possibly padded) signal reconstructed by both branches.
        out_length = mix.shape[-1]
        z = self._spec(mix)
        mag = self._magnitude(z).to(mix.device)
        x = mag

        b, _, fq, t = x.shape

        # Unlike previous Demucs, we always normalize because it is easier.
        mean = x.mean(dim=(1, 2, 3), keepdim=True)
        std = x.std(dim=(1, 2, 3), keepdim=True)
        x = (x - mean) / (1e-5 + std)
        # x is the freq. branch input.

        # Prepare the time branch input.
        xt = mix
        meant = xt.mean(dim=(1, 2), keepdim=True)
        stdt = xt.std(dim=(1, 2), keepdim=True)
        xt = (xt - meant) / (1e-5 + stdt)

        saved = []  # skip connections, freq.
        saved_t = []  # skip connections, time.
        lengths = []  # saved lengths to properly remove padding, freq branch.
        lengths_t = []  # saved lengths for time branch.
        for idx, encode in enumerate(self.encoder):
            lengths.append(x.shape[-1])
            inject = None
            if idx < len(self.tencoder):
                # We have not yet merged branches.
                lengths_t.append(xt.shape[-1])
                tenc = self.tencoder[idx]
                xt = tenc(xt)
                if not tenc.empty:
                    # Save for skip connection.
                    saved_t.append(xt)
                else:
                    # tenc only holds its first conv, so now the time and freq.
                    # branches have the same shape and can be merged.
                    inject = xt
            x = encode(x, inject)
            if idx == 0 and self.freq_emb is not None:
                # Add frequency embedding to allow for non equivariant convolutions
                # over the frequency axis.
                frs = torch.arange(x.shape[-2], device=x.device)
                emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
                x = x + self.freq_emb_scale * emb

            saved.append(x)
        if self.crosstransformer is not None:
            if self.bottom_channels:
                _, _, f, _ = x.shape
                x = rearrange(x, "b c f t-> b c (f t)")
                x = self.channel_upsampler(x)
                x = rearrange(x, "b c (f t)-> b c f t", f=f)
                xt = self.channel_upsampler_t(xt)

            x, xt = self.crosstransformer(x, xt)

            if self.bottom_channels:
                x = rearrange(x, "b c f t-> b c (f t)")
                x = self.channel_downsampler(x)
                x = rearrange(x, "b c (f t)-> b c f t", f=f)
                xt = self.channel_downsampler_t(xt)

        # Time decoders only exist for the outermost `len(self.tdecoder)` levels.
        offset = self.depth - len(self.tdecoder)
        for idx, decode in enumerate(self.decoder):
            skip = saved.pop(-1)
            x, pre = decode(x, skip, lengths.pop(-1))
            # `pre` is the activation just before the final transposed convolution,
            # which feeds the time branch when the two branches separate.

            if idx >= offset:
                tdec = self.tdecoder[idx - offset]
                length_t = lengths_t.pop(-1)
                if tdec.empty:
                    assert pre.shape[2] == 1, pre.shape
                    pre = pre[:, :, 0]
                    xt, _ = tdec(pre, None, length_t)
                else:
                    skip = saved_t.pop(-1)
                    xt, _ = tdec(xt, skip, length_t)

        # Make sure we used all stored skip connections.
        assert len(saved) == 0
        assert len(lengths_t) == 0
        assert len(saved_t) == 0

        s = len(self.sources)
        x = x.view(b, s, -1, fq, t)
        x = x * std[:, None] + mean[:, None]

        # Masking and the iSTFT run on CPU for MPS tensors, presumably because MPS
        # lacks complex-number support (TODO confirm and remove when it does).
        x_is_mps = x.device.type == "mps"
        if x_is_mps:
            x = x.cpu()

        zout = self._mask(z, x)
        x = self._ispec(zout, out_length)

        # Back to the MPS device.
        if x_is_mps:
            x = x.to("mps")

        xt = xt.view(b, s, -1, out_length)
        xt = xt * stdt[:, None] + meant[:, None]
        x = xt + x
        if length_pre_pad is not None:
            x = x[..., :length_pre_pad]
        return x
|
|