File size: 5,352 Bytes
32de4f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e51885a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32de4f6
 
 
 
 
 
 
 
 
 
 
 
 
e51885a
 
32de4f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e51885a
 
 
 
32de4f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e51885a
32de4f6
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
"""The full Vanta model: encoder + separator + decoder + speaker encoder.

Forward pass:
    1. Speaker encoder (frozen ECAPA-TDNN) turns the enrollment into a 192-d
       fingerprint. Runs under no_grad so it doesn't train.
    2. Audio encoder turns the mixture into a (B, N, T') feature map
       (optionally time-masked by SpecAugmentTime while training).
    3. Separator predicts a mask (B, N, T'), conditioned on the fingerprint.
    4. Mask * features -> masked features.
    5. Audio decoder turns masked features back into a waveform.

Phase 3 deliverable: this runs end-to-end with random init. The weights are
trash (untrained), so the output is garbage audio, but shapes, gradients, and
conditioning pathways must all work.
"""

from __future__ import annotations

from dataclasses import dataclass

import torch
import torch.nn as nn

from vanta.models.audio_encoder import AudioEncoder, AudioDecoder, expected_output_samples
from vanta.models.separator import Separator
from vanta.models.speaker_encoder import SpeakerEncoder


class SpecAugmentTime(nn.Module):
    """Zero out random time spans of an encoded feature map during training.

    Inspired by SpecAugment (Park et al., 2019) but operates on Conv-TasNet
    encoder outputs rather than spectrograms. Forces the separator to rely on
    the full temporal context instead of memorizing narrow patterns — a direct
    attack on the overfitting we saw in earlier runs. No-op in eval mode.

    Args:
        num_masks: Number of spans drawn per batch item. ``<= 0`` disables
            the augmentation.
        max_width: Largest span width in frames (inclusive). ``<= 0`` disables
            the augmentation.
    """

    def __init__(self, num_masks: int = 2, max_width: int = 40):
        super().__init__()
        self.num_masks = num_masks
        self.max_width = max_width

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with random time spans zeroed (training mode only).

        Args:
            x: Feature map of shape (B, C, T').

        Returns:
            ``x`` itself when the module is in eval mode or disabled;
            otherwise a masked copy (the input is never modified in place).
        """
        # Only mask in training; leave inference untouched.
        if not self.training or self.num_masks <= 0 or self.max_width <= 0:
            return x
        batch, _, frames = x.shape
        out = x.clone()
        for b in range(batch):
            for _ in range(self.num_masks):
                width = int(torch.randint(1, self.max_width + 1, (1,)).item())
                # A span as wide as the whole sequence would erase the sample
                # entirely, so skip the draw instead of applying it.
                if width >= frames:
                    continue
                # torch.randint's upper bound is exclusive: use
                # frames - width + 1 so the span may end on the last frame.
                start = int(torch.randint(0, frames - width + 1, (1,)).item())
                out[b, :, start : start + width] = 0
        return out


@dataclass
class VantaConfig:
    """Hyperparameters used to construct a `Vanta` model.

    Every field is forwarded by `Vanta.__init__` to one of its submodules;
    the comments below say which one receives it.
    """

    # Passed to AudioEncoder/AudioDecoder as `num_filters` and to Separator
    # as `enc_channels`.
    enc_channels: int = 512
    # Passed to AudioEncoder/AudioDecoder as `kernel_size`.
    enc_kernel: int = 16
    # Passed to AudioEncoder/AudioDecoder as `stride`.
    enc_stride: int = 8
    # Separator `bottleneck` argument.
    bottleneck: int = 128
    # Separator `hidden` argument.
    hidden: int = 512
    # Separator `kernel` argument.
    tcn_kernel: int = 3
    # Separator `blocks_per_repeat` argument.
    blocks_per_repeat: int = 8
    # Separator `repeats` argument.
    repeats: int = 3
    # Separator `speaker_dim` argument; the default matches the 192-d
    # speaker fingerprint described in the module docstring.
    speaker_dim: int = 192
    # Passed to SpeakerEncoder as `freeze`.
    freeze_speaker: bool = True
    # Separator `dropout` argument.
    dropout: float = 0.0
    # SpecAugmentTime `num_masks`: spans zeroed per batch item in training.
    specaug_num_masks: int = 0        # 0 disables SpecAugment
    # SpecAugmentTime `max_width`: upper bound on a masked span, in frames.
    specaug_max_width: int = 40


class Vanta(nn.Module):
    """Target-speaker extraction model.

    Wires together an audio encoder, a speaker-conditioned separator, an
    audio decoder, a training-time SpecAugment layer and a speaker encoder,
    all sized from a `VantaConfig`.
    """

    def __init__(self, cfg: VantaConfig | None = None):
        """Build every submodule from `cfg` (defaults to `VantaConfig()`)."""
        super().__init__()
        settings = cfg or VantaConfig()
        self.cfg = settings

        # Encoder and decoder are built from the same filterbank geometry.
        codec_kwargs = dict(
            num_filters=settings.enc_channels,
            kernel_size=settings.enc_kernel,
            stride=settings.enc_stride,
        )
        self.audio_encoder = AudioEncoder(**codec_kwargs)
        self.audio_decoder = AudioDecoder(**codec_kwargs)

        self.separator = Separator(
            enc_channels=settings.enc_channels,
            bottleneck=settings.bottleneck,
            hidden=settings.hidden,
            kernel=settings.tcn_kernel,
            blocks_per_repeat=settings.blocks_per_repeat,
            repeats=settings.repeats,
            speaker_dim=settings.speaker_dim,
            dropout=settings.dropout,
        )
        self.specaug = SpecAugmentTime(
            num_masks=settings.specaug_num_masks,
            max_width=settings.specaug_max_width,
        )
        self.speaker_encoder = SpeakerEncoder(freeze=settings.freeze_speaker)

    def embed_speaker(self, enrollment: torch.Tensor) -> torch.Tensor:
        """Encode an enrollment clip (B, T_enroll) into (B, speaker_dim)."""
        return self.speaker_encoder(enrollment)

    def forward(
        self, mixture: torch.Tensor, enrollment: torch.Tensor | None = None,
        speaker_embedding: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Extract the target speaker's waveform from `mixture`.

        Supply either `enrollment`, which is encoded here, or an already
        computed `speaker_embedding`; the embedding takes precedence when
        both are given. Reusing a precomputed embedding avoids re-encoding
        the same enrollment for every mixture at inference time.

        Returns:
            A waveform whose last dimension equals `mixture.shape[-1]`.

        Raises:
            ValueError: if neither `enrollment` nor `speaker_embedding` is
                provided.
        """
        if speaker_embedding is None and enrollment is None:
            raise ValueError("must pass enrollment or speaker_embedding")
        if speaker_embedding is None:
            speaker_embedding = self.embed_speaker(enrollment)

        # (B, N, T') features; specaug is a no-op at inference.
        features = self.specaug(self.audio_encoder(mixture))
        # Mask has the same (B, N, T') shape as the features.
        gated = features * self.separator(features, speaker_embedding)
        estimate = self.audio_decoder(gated)                    # (B, T_out)

        # The transposed-conv decoder can land a few samples (up to
        # kernel - stride) away from the input length. Trim or right-pad
        # with zeros so losses can compare directly against the mixture.
        shortfall = mixture.shape[-1] - estimate.shape[-1]
        if shortfall < 0:
            return estimate[..., : mixture.shape[-1]]
        if shortfall > 0:
            return nn.functional.pad(estimate, (0, shortfall))
        return estimate