File size: 5,824 Bytes
32de4f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
"""TCN separator with speaker conditioning — the "neural spotlight" of Vanta.

Architecture (Conv-TasNet-style):
    encoded mixture (B, N, T')
        -> gLN + bottleneck 1x1 Conv (N -> B_chan)
        -> [R repeats of X stacked TCN blocks with exponentially growing dilation]
           at every block input, add a projected speaker embedding.
        -> PReLU + 1x1 Conv (B_chan -> N) -> ReLU
        -> mask (B, N, T')

The mask is multiplied elementwise with the encoded mixture to produce
speaker-masked features, which the audio decoder turns back into a waveform.
"""

from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F


class GlobalLayerNorm(nn.Module):
    """Global layer normalization (gLN) for ``(B, C, T)`` feature maps.

    One mean and one (biased) variance are computed per example, pooled
    jointly over the channel and time axes, so the statistics summarize the
    whole utterance instead of a single time-step. The normalized activations
    are then rescaled by a learnable per-channel gain and shifted by a
    learnable per-channel offset.

    Args:
        channels: Size of the channel axis ``C``; fixes the shape of the
            learnable ``gamma`` / ``beta`` parameters to ``(1, C, 1)``.
        eps: Constant added to the variance before the square root so the
            division stays finite for (near-)constant inputs.
    """

    def __init__(self, channels: int, eps: float = 1e-8):
        super().__init__()
        self.eps = eps
        # Per-channel affine; the singleton batch/time axes broadcast in forward().
        affine_shape = (1, channels, 1)
        self.gamma = nn.Parameter(torch.ones(affine_shape))
        self.beta = nn.Parameter(torch.zeros(affine_shape))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` of shape ``(B, C, T)``; the output has the same shape."""
        pooled_dims = (1, 2)  # channel and time axes together
        mu = x.mean(dim=pooled_dims, keepdim=True)
        sigma2 = x.var(dim=pooled_dims, keepdim=True, unbiased=False)
        normalized = (x - mu) / (sigma2 + self.eps).sqrt()
        return self.beta + self.gamma * normalized


class TCNBlock(nn.Module):
    """One dilated convolutional block with speaker conditioning.

    Layout (input -> output):
        + speaker embedding (broadcast over time)
        1x1 Conv  (B_chan -> H)
        PReLU + gLN
        Depthwise 1D Conv with dilation d, kernel P
        PReLU + gLN
        Dropout1d (only when dropout > 0)
        1x1 Conv  (H -> B_chan)                -> residual

    Args:
        b_chan: Channel count of the block's input and output (B_chan).
        h_chan: Channel count inside the block (H).
        kernel: Kernel size P of the depthwise convolution.
        dilation: Dilation d of the depthwise convolution.
        dropout: Channel-wise dropout probability applied before the output
            1x1 conv; values <= 0 disable dropout entirely.

    Raises:
        ValueError: If ``(kernel - 1) * dilation`` is odd. Symmetric padding
            can then not keep the sequence length, so the residual addition
            in ``forward`` would either fail on a shape mismatch or silently
            broadcast a shortened sequence.
    """

    def __init__(
        self,
        b_chan: int,
        h_chan: int,
        kernel: int,
        dilation: int,
        dropout: float = 0.0,
    ):
        super().__init__()

        # The depthwise conv shortens the sequence by (kernel - 1) * dilation
        # samples; padding both sides by half of that restores the length
        # only when the total is even. Reject the other case up front rather
        # than letting the residual add blow up (or broadcast) later.
        receptive_span = (kernel - 1) * dilation
        if receptive_span % 2 != 0:
            raise ValueError(
                "TCNBlock requires (kernel - 1) * dilation to be even so that "
                "symmetric padding preserves the sequence length; "
                f"got kernel={kernel}, dilation={dilation}."
            )
        padding = receptive_span // 2  # "same" padding

        self.pointwise_in = nn.Conv1d(b_chan, h_chan, kernel_size=1)
        self.prelu1 = nn.PReLU(h_chan)
        self.norm1 = GlobalLayerNorm(h_chan)

        self.depthwise = nn.Conv1d(
            h_chan,
            h_chan,
            kernel_size=kernel,
            padding=padding,
            dilation=dilation,
            groups=h_chan,  # depthwise: one filter per channel
        )
        self.prelu2 = nn.PReLU(h_chan)
        self.norm2 = GlobalLayerNorm(h_chan)

        # Channel-wise dropout on the block's output path: Dropout1d zeros
        # whole feature channels rather than individual elements. With
        # dropout <= 0 an Identity is used instead, so the block runs no
        # dropout op at all.
        self.dropout = nn.Dropout1d(dropout) if dropout > 0 else nn.Identity()

        self.pointwise_out = nn.Conv1d(h_chan, b_chan, kernel_size=1)

    def forward(self, x: torch.Tensor, spk_bias: torch.Tensor) -> torch.Tensor:
        """Apply the block and return ``x`` plus the block's residual branch.

        x: (B, B_chan, T'). spk_bias: (B, B_chan, 1) broadcasts over time.
        Returns a tensor shaped like ``x``.
        """
        residual = x
        h = x + spk_bias  # the "neural spotlight" reminder
        h = self.pointwise_in(h)
        h = self.norm1(self.prelu1(h))
        h = self.depthwise(h)
        h = self.norm2(self.prelu2(h))
        h = self.dropout(h)
        h = self.pointwise_out(h)
        return residual + h


class Separator(nn.Module):
    """Predict a non-negative mask from an encoded mixture and a speaker embedding.

    Pipeline: gLN -> 1x1 conv down to the bottleneck width -> a stack of
    ``repeats * blocks_per_repeat`` TCN blocks, each receiving the same
    projected speaker bias -> PReLU -> 1x1 conv back to the encoder width
    -> ReLU.

    Args:
        enc_channels: N, channel count of the encoded mixture (the original
            note says this must match ``AudioEncoder.num_filters``).
        bottleneck: B, channel count flowing between TCN blocks.
        hidden: H, channel count inside each TCN block.
        kernel: P, kernel size of each block's depthwise convolution.
        blocks_per_repeat: X, blocks per repeat; dilation within a repeat
            runs 1, 2, 4, ..., 2**(X - 1).
        repeats: R, how many times the dilation schedule is repeated.
        speaker_dim: Size of the incoming speaker embedding (the original
            note cites the ECAPA-TDNN embedding dim).
        dropout: Per-block Dropout1d probability, forwarded to every block.
    """

    def __init__(
        self,
        enc_channels: int = 512,
        bottleneck: int = 128,
        hidden: int = 512,
        kernel: int = 3,
        blocks_per_repeat: int = 8,
        repeats: int = 3,
        speaker_dim: int = 192,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.enc_channels = enc_channels

        # Input stage: normalize, then squeeze N -> B channels.
        self.in_norm = GlobalLayerNorm(enc_channels)
        self.in_proj = nn.Conv1d(enc_channels, bottleneck, kernel_size=1)

        # A single speaker projection shared by all blocks (rather than one
        # projection per block), which keeps the parameter count down.
        self.speaker_proj = nn.Linear(speaker_dim, bottleneck)

        # Dilation schedule 1, 2, 4, ... restarted at the top of every repeat.
        dilation_schedule = [2**level for level in range(blocks_per_repeat)] * repeats
        self.blocks = nn.ModuleList(
            [
                TCNBlock(
                    b_chan=bottleneck,
                    h_chan=hidden,
                    kernel=kernel,
                    dilation=rate,
                    dropout=dropout,
                )
                for rate in dilation_schedule
            ]
        )

        # Output stage: expand B -> N channels; ReLU is applied in forward().
        self.out_prelu = nn.PReLU(bottleneck)
        self.out_proj = nn.Conv1d(bottleneck, enc_channels, kernel_size=1)

    def forward(
        self, enc_mix: torch.Tensor, spk_emb: torch.Tensor
    ) -> torch.Tensor:
        """Return the mask for ``enc_mix`` conditioned on ``spk_emb``.

        enc_mix: (B, N, T'). spk_emb: (B, speaker_dim). Returns mask (B, N, T').
        """
        # Projected once up front; the trailing singleton axis lets the
        # (B, B_chan, 1) bias broadcast across all T' frames in each block.
        spk_bias = torch.unsqueeze(self.speaker_proj(spk_emb), dim=-1)

        features = self.in_proj(self.in_norm(enc_mix))
        for tcn in self.blocks:
            features = tcn(features, spk_bias)

        return F.relu(self.out_proj(self.out_prelu(features)))