Spaces:
Running
Running
File size: 5,824 Bytes
"""TCN separator with speaker conditioning — the "neural spotlight" of Vanta.

Architecture (Conv-TasNet-style):

    encoded mixture (B, N, T')
      -> gLN + bottleneck 1x1 Conv (N -> B_chan)
      -> [R repeats of X stacked TCN blocks with exponentially growing dilation]
         at every block input, add a projected speaker embedding.
      -> PReLU + 1x1 Conv (B_chan -> N) -> ReLU
      -> mask (B, N, T')

The mask is multiplied elementwise with the encoded mixture to produce
speaker-masked features, which the audio decoder turns back into a waveform.
"""
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
class GlobalLayerNorm(nn.Module):
    """Global layer normalization (gLN) for ``(B, C, T)`` feature maps.

    Every example is standardized with one mean and one variance computed
    jointly over its channel and time axes, i.e. over the whole utterance
    rather than per time-step. A learnable per-channel scale (``gamma``) and
    shift (``beta``) are applied afterwards.

    Per-time-step normalization is brittle when audio volume drifts (someone
    whispers, then shouts); pooling the statistics across the entire
    utterance matches the "we care about texture, not volume" invariant the
    plan describes.

    Args:
        channels: Number of channels ``C`` of the input.
        eps: Constant added to the variance before taking the square root.
    """

    def __init__(self, channels: int, eps: float = 1e-8):
        super().__init__()
        self.eps = eps
        # Per-channel affine parameters, shaped to broadcast over (B, C, T).
        affine_shape = (1, channels, 1)
        self.gamma = nn.Parameter(torch.ones(affine_shape))
        self.beta = nn.Parameter(torch.zeros(affine_shape))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` of shape ``(B, C, T)`` and apply the affine map."""
        reduce_dims = (1, 2)  # channel and time axes
        mu = x.mean(dim=reduce_dims, keepdim=True)
        # Biased (population) variance: divide by the element count, not n - 1.
        sigma_sq = x.var(dim=reduce_dims, keepdim=True, unbiased=False)
        normalized = (x - mu) / torch.sqrt(sigma_sq + self.eps)
        return normalized * self.gamma + self.beta
class TCNBlock(nn.Module):
    """One dilated convolutional block with speaker conditioning.

    Layout (input -> output):
        + speaker embedding (broadcast over time)
        1x1 Conv (B_chan -> H)
        PReLU + gLN
        Depthwise 1D Conv with dilation d, kernel P
        PReLU + gLN
        (optional channel-wise dropout)
        1x1 Conv (H -> B_chan) -> residual

    Args:
        b_chan: Number of bottleneck channels at the block input and output.
        h_chan: Number of hidden channels inside the block.
        kernel: Kernel size of the depthwise convolution.
        dilation: Dilation of the depthwise convolution.
        dropout: Dropout1d probability on the hidden features; values
            ``<= 0`` disable dropout.

    Raises:
        ValueError: If ``(kernel - 1) * dilation`` is odd. Symmetric padding
            cannot preserve the sequence length in that case, so the
            residual addition in ``forward`` would fail on a shape mismatch.
    """

    def __init__(
        self,
        b_chan: int,
        h_chan: int,
        kernel: int,
        dilation: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        # Total number of frames the dilated kernel spans beyond its centre.
        span = (kernel - 1) * dilation
        if span % 2 != 0:
            # Fail at construction time instead of with an opaque shape error
            # on the residual add during the first forward pass.
            raise ValueError(
                "TCNBlock requires (kernel - 1) * dilation to be even so that "
                "symmetric padding preserves the sequence length; got "
                f"kernel={kernel}, dilation={dilation}"
            )
        padding = span // 2  # "same" padding: output length == input length

        self.pointwise_in = nn.Conv1d(b_chan, h_chan, kernel_size=1)
        self.prelu1 = nn.PReLU(h_chan)
        self.norm1 = GlobalLayerNorm(h_chan)
        self.depthwise = nn.Conv1d(
            h_chan,
            h_chan,
            kernel_size=kernel,
            padding=padding,
            dilation=dilation,
            groups=h_chan,  # depthwise: one filter per channel
        )
        self.prelu2 = nn.PReLU(h_chan)
        self.norm2 = GlobalLayerNorm(h_chan)
        # Channel-wise dropout on the block's output path. Zeros an entire
        # feature channel (not random elements), which preserves the temporal
        # structure the next block expects. With dropout <= 0 an Identity is
        # used instead, so legacy checkpoints see no behavior change.
        self.dropout = nn.Dropout1d(dropout) if dropout > 0 else nn.Identity()
        self.pointwise_out = nn.Conv1d(h_chan, b_chan, kernel_size=1)

    def forward(self, x: torch.Tensor, spk_bias: torch.Tensor) -> torch.Tensor:
        """Apply the block.

        Args:
            x: Features of shape ``(B, B_chan, T')``.
            spk_bias: Speaker bias of shape ``(B, B_chan, 1)``; it broadcasts
                over the time axis.

        Returns:
            Tensor of shape ``(B, B_chan, T')``: the input plus the block's
            residual update.
        """
        residual = x
        h = x + spk_bias  # the "neural spotlight" reminder
        h = self.pointwise_in(h)
        h = self.norm1(self.prelu1(h))
        h = self.depthwise(h)
        h = self.norm2(self.prelu2(h))
        h = self.dropout(h)
        h = self.pointwise_out(h)
        return residual + h
class Separator(nn.Module):
    """Mask predictor: encoded mixture + speaker embedding -> mask.

    The encoded mixture is normalized (gLN) and projected to the bottleneck
    width, passed through ``repeats * blocks_per_repeat`` speaker-conditioned
    TCN blocks, and projected back to ``enc_channels`` through PReLU, a 1x1
    convolution and a final ReLU.

    Args:
        enc_channels: Encoder channel count ``N``; must match
            ``AudioEncoder.num_filters``.
        bottleneck: Bottleneck channel count ``B``.
        hidden: Hidden channel count ``H`` inside each TCN block.
        kernel: Depthwise kernel size ``P``.
        blocks_per_repeat: Number of blocks ``X`` per repeat; block ``i`` of
            a repeat uses dilation ``2**i``.
        repeats: Number of repeats ``R`` of the dilation stack.
        speaker_dim: Dimension of the speaker embedding (ECAPA-TDNN).
        dropout: Per-block Dropout1d probability.
    """

    def __init__(
        self,
        enc_channels: int = 512,  # N — must match AudioEncoder.num_filters
        bottleneck: int = 128,  # B
        hidden: int = 512,  # H
        kernel: int = 3,  # P
        blocks_per_repeat: int = 8,  # X
        repeats: int = 3,  # R
        speaker_dim: int = 192,  # ECAPA-TDNN embedding dim
        dropout: float = 0.0,  # per-block Dropout1d probability
    ):
        super().__init__()
        self.enc_channels = enc_channels

        # Input stage: global layer norm followed by the bottleneck 1x1 conv.
        self.in_norm = GlobalLayerNorm(enc_channels)
        self.in_proj = nn.Conv1d(enc_channels, bottleneck, kernel_size=1)

        # One speaker projection, reused at every block. Fewer params than
        # per-block projections and works just as well in practice.
        self.speaker_proj = nn.Linear(speaker_dim, bottleneck)

        # Dilation schedule: 1, 2, 4, ... restarted at the top of each repeat.
        dilation_schedule = [
            2**level
            for _ in range(repeats)
            for level in range(blocks_per_repeat)
        ]
        self.blocks = nn.ModuleList(
            [
                TCNBlock(
                    b_chan=bottleneck,
                    h_chan=hidden,
                    kernel=kernel,
                    dilation=rate,
                    dropout=dropout,
                )
                for rate in dilation_schedule
            ]
        )

        # Output stage: PReLU then 1x1 conv back to the encoder width.
        self.out_prelu = nn.PReLU(bottleneck)
        self.out_proj = nn.Conv1d(bottleneck, enc_channels, kernel_size=1)

    def forward(
        self, enc_mix: torch.Tensor, spk_emb: torch.Tensor
    ) -> torch.Tensor:
        """Predict a non-negative mask for the target speaker.

        Args:
            enc_mix: Encoded mixture of shape ``(B, N, T')``.
            spk_emb: Speaker embedding of shape ``(B, speaker_dim)``.

        Returns:
            Mask of shape ``(B, N, T')``.
        """
        features = self.in_proj(self.in_norm(enc_mix))

        # Project the speaker embedding once; the trailing singleton axis
        # lets (B, B_chan, 1) broadcast against (B, B_chan, T') in each block.
        spk_bias = self.speaker_proj(spk_emb).unsqueeze(-1)

        for tcn_block in self.blocks:
            features = tcn_block(features, spk_bias)

        activated = self.out_prelu(features)
        return F.relu(self.out_proj(activated))
|