Spaces:
Running
Running
| """TCN separator with speaker conditioning — the "neural spotlight" of Vanta. | |
| Architecture (Conv-TasNet-style): | |
| encoded mixture (B, N, T') | |
| -> gLN + bottleneck 1x1 Conv (N -> B_chan) | |
| -> [R repeats of X stacked TCN blocks with exponentially growing dilation] | |
| at every block input, add a projected speaker embedding. | |
| -> PReLU + 1x1 Conv (B_chan -> N) -> ReLU | |
| -> mask (B, N, T') | |
| The mask is multiplied elementwise with the encoded mixture to produce | |
| speaker-masked features, which the audio decoder turns back into a waveform. | |
| """ | |
| from __future__ import annotations | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class GlobalLayerNorm(nn.Module):
    """Global layer normalization (gLN) over channels and time.

    For an input of shape (B, C, T), one mean and one variance are computed
    per example by pooling over *both* the channel and time axes. A learnable
    per-channel affine transform (gamma, beta) is then applied.

    These are whole-utterance statistics, computed in one shot. They are not
    accumulated step by step over time; that would be cumulative LN (cLN),
    the causal alternative. A per-time-step LayerNorm normalizes each frame
    separately. Pooling over the full clip instead gives a single scale per
    example, so the output is unchanged when the whole input is multiplied
    by a constant.

    Args:
        channels: number of channels C; sizes gamma and beta.
        eps: added to the variance before the square root, for numerical
            stability.
    """

    def __init__(self, channels: int, eps: float = 1e-8):
        super().__init__()
        self.eps = eps
        # Learnable affine parameters, one per channel, broadcast over (B, T).
        self.gamma = nn.Parameter(torch.ones(1, channels, 1))
        self.beta = nn.Parameter(torch.zeros(1, channels, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` of shape (B, C, T); returns a tensor of the same shape."""
        # One (mean, var) pair per example, pooled over channels and time.
        mean = x.mean(dim=(1, 2), keepdim=True)
        # Biased (population) variance, as in standard layer normalization.
        var = x.var(dim=(1, 2), keepdim=True, unbiased=False)
        x = (x - mean) / torch.sqrt(var + self.eps)
        return x * self.gamma + self.beta
class TCNBlock(nn.Module):
    """One dilated depthwise-separable conv block with speaker conditioning.

    Layout (input -> output)::

        x + speaker bias (broadcast over time)
        -> 1x1 Conv (B_chan -> H) -> PReLU -> gLN
        -> depthwise Conv (kernel P, dilation d) -> PReLU -> gLN
        -> Dropout1d (only when dropout > 0)
        -> 1x1 Conv (H -> B_chan)
        -> + x (residual)

    The time length T' is preserved for any kernel size and dilation, so the
    residual add always lines up.

    Args:
        b_chan: bottleneck width B_chan (block input/output channels).
        h_chan: hidden width H inside the block.
        kernel: depthwise kernel size P (odd or even).
        dilation: dilation d of the depthwise conv.
        dropout: channel-wise dropout probability; values <= 0 disable it.
    """

    def __init__(
        self,
        b_chan: int,
        h_chan: int,
        kernel: int,
        dilation: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.pointwise_in = nn.Conv1d(b_chan, h_chan, kernel_size=1)
        self.prelu1 = nn.PReLU(h_chan)
        self.norm1 = GlobalLayerNorm(h_chan)
        # padding="same" keeps T' unchanged. For odd kernels it gives the same
        # symmetric padding as (kernel - 1) * dilation // 2. That manual
        # formula shortened the output by one frame for even kernels with odd
        # dilation, which broke the residual add. Padding has no parameters,
        # so existing checkpoints load unchanged.
        self.depthwise = nn.Conv1d(
            h_chan,
            h_chan,
            kernel_size=kernel,
            padding="same",
            dilation=dilation,
            groups=h_chan,  # depthwise: one filter per channel
        )
        self.prelu2 = nn.PReLU(h_chan)
        self.norm2 = GlobalLayerNorm(h_chan)
        # Channel-wise dropout: Dropout1d zeroes whole feature channels rather
        # than individual elements. It is inactive in eval mode and holds no
        # parameters. With dropout <= 0 an Identity is used instead, and the
        # state_dict is the same either way.
        self.dropout = nn.Dropout1d(dropout) if dropout > 0 else nn.Identity()
        self.pointwise_out = nn.Conv1d(h_chan, b_chan, kernel_size=1)

    def forward(self, x: torch.Tensor, spk_bias: torch.Tensor) -> torch.Tensor:
        """x: (B, B_chan, T'). spk_bias: (B, B_chan, 1) broadcasts over time.

        Returns a tensor with the same shape as ``x``.
        """
        residual = x
        # Re-inject the speaker bias at every block so conditioning persists
        # through the whole stack.
        h = x + spk_bias
        h = self.pointwise_in(h)
        h = self.norm1(self.prelu1(h))
        h = self.depthwise(h)
        h = self.norm2(self.prelu2(h))
        h = self.dropout(h)
        h = self.pointwise_out(h)
        return residual + h
class Separator(nn.Module):
    """Mask predictor: encoded mixture + speaker embedding -> mask.

    Pipeline:

    * gLN, then a 1x1 bottleneck conv (N -> B).
    * ``repeats`` x ``blocks_per_repeat`` TCNBlocks, with dilations
      1, 2, 4, ..., 2**(blocks_per_repeat - 1) inside each repeat.
    * PReLU, a 1x1 conv (B -> N), then ReLU.

    A single linear layer projects the speaker embedding to a bias of shape
    (B, B_chan, 1), which every block adds to its input.
    """

    def __init__(
        self,
        enc_channels: int = 512,  # N — must match AudioEncoder.num_filters
        bottleneck: int = 128,  # B
        hidden: int = 512,  # H
        kernel: int = 3,  # P
        blocks_per_repeat: int = 8,  # X
        repeats: int = 3,  # R
        speaker_dim: int = 192,  # ECAPA-TDNN embedding dim
        dropout: float = 0.0,  # per-block Dropout1d probability
    ):
        super().__init__()
        self.enc_channels = enc_channels
        self.in_norm = GlobalLayerNorm(enc_channels)
        self.in_proj = nn.Conv1d(enc_channels, bottleneck, kernel_size=1)
        # One speaker projection, reused at every block. It has fewer
        # parameters than separate per-block projections.
        self.speaker_proj = nn.Linear(speaker_dim, bottleneck)
        # Repeat-major order (same as a nested loop), so block indices and
        # state_dict keys are unchanged. Dilation restarts at 1 in each repeat.
        self.blocks = nn.ModuleList(
            TCNBlock(
                b_chan=bottleneck,
                h_chan=hidden,
                kernel=kernel,
                dilation=2**depth,
                dropout=dropout,
            )
            for _ in range(repeats)
            for depth in range(blocks_per_repeat)
        )
        self.out_prelu = nn.PReLU(bottleneck)
        self.out_proj = nn.Conv1d(bottleneck, enc_channels, kernel_size=1)

    def forward(
        self, enc_mix: torch.Tensor, spk_emb: torch.Tensor
    ) -> torch.Tensor:
        """Predict a non-negative mask for the target speaker.

        Args:
            enc_mix: encoded mixture, (B, N, T').
            spk_emb: speaker embedding, (B, speaker_dim). A tensor of shape
                (B, 1, speaker_dim) is also accepted; its singleton axis is
                squeezed.

        Returns:
            Mask of shape (B, N, T'), clamped to >= 0 by ReLU.
        """
        # Without this squeeze, a (B, 1, D) embedding yields a 4-D bias. That
        # bias broadcasts the block input to 4-D, and the first block's conv
        # then fails with an opaque shape error.
        if spk_emb.dim() == 3 and spk_emb.size(1) == 1:
            spk_emb = spk_emb.squeeze(1)
        h = self.in_proj(self.in_norm(enc_mix))
        # Speaker bias is computed once: (B, B_chan, 1) broadcasts to (B, B_chan, T').
        spk_bias = self.speaker_proj(spk_emb).unsqueeze(-1)
        for block in self.blocks:
            h = block(h, spk_bias)
        mask = self.out_proj(self.out_prelu(h))
        return F.relu(mask)