Spaces:
Running
Running
File size: 5,352 Bytes
"""The full Vanta model: encoder + separator + decoder + speaker encoder.
Forward pass:
1. Speaker encoder (frozen ECAPA-TDNN) turns the enrollment into a 192-d
fingerprint. Runs under no_grad so it doesn't train.
2. Audio encoder turns the mixture into a (B, N, T') feature map.
3. Separator predicts a mask (B, N, T'), conditioned on the fingerprint.
4. Mask * features -> masked features.
5. Audio decoder turns masked features back into a waveform.
Phase 3 deliverable: this runs end-to-end with random init. The weights are
trash (untrained), so the output is garbage audio, but shapes, gradients, and
conditioning pathways must all work.
"""
from __future__ import annotations
from dataclasses import dataclass
import torch
import torch.nn as nn
from vanta.models.audio_encoder import AudioEncoder, AudioDecoder, expected_output_samples
from vanta.models.separator import Separator
from vanta.models.speaker_encoder import SpeakerEncoder
class SpecAugmentTime(nn.Module):
    """Zero out random time spans of an encoded feature map during training.

    Inspired by SpecAugment (Park et al., 2019) but operates on Conv-TasNet
    encoder outputs rather than spectrograms. Forces the separator to rely on
    the full temporal context instead of memorizing narrow patterns — a direct
    attack on the overfitting we saw in earlier runs. No-op in eval mode.

    Args:
        num_masks: Number of spans zeroed per batch item. Values <= 0 disable
            the augmentation.
        max_width: Largest span length in frames (inclusive). Values <= 0
            disable the augmentation.
    """

    def __init__(self, num_masks: int = 2, max_width: int = 40):
        super().__init__()
        self.num_masks = num_masks
        self.max_width = max_width

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a copy of ``x`` with random time spans set to zero.

        Args:
            x: Feature map of shape (B, C, T').

        Returns:
            ``x`` itself when the module is in eval mode or disabled;
            otherwise a new tensor of the same shape. The input is never
            modified in place.
        """
        # Only mask in training; leave inference untouched.
        if not self.training or self.num_masks <= 0 or self.max_width <= 0:
            return x

        batch, _, frames = x.shape
        out = x.clone()
        for b in range(batch):
            # Each batch item gets its own independently drawn spans.
            for _ in range(self.num_masks):
                width = int(torch.randint(1, self.max_width + 1, (1,)).item())
                # A span as long as the whole sequence would erase the
                # example entirely, so skip it rather than mask everything.
                if width >= frames:
                    continue
                # torch.randint's upper bound is exclusive: use
                # frames - width + 1 so the span may end on the last frame.
                # The previous bound (frames - width) could never mask it.
                start = int(
                    torch.randint(0, frames - width + 1, (1,)).item()
                )
                out[b, :, start : start + width] = 0
        return out
@dataclass
class VantaConfig:
    """Hyper-parameters used by ``Vanta.__init__`` to build its sub-modules.

    Field order and defaults are part of the interface: the dataclass can be
    constructed positionally.
    """

    # Passed as `num_filters` to AudioEncoder/AudioDecoder and as
    # `enc_channels` to Separator.
    enc_channels: int = 512
    # Passed as `kernel_size` to AudioEncoder/AudioDecoder.
    enc_kernel: int = 16
    # Passed as `stride` to AudioEncoder/AudioDecoder.
    enc_stride: int = 8
    # Passed as `bottleneck` to Separator.
    bottleneck: int = 128
    # Passed as `hidden` to Separator.
    hidden: int = 512
    # Passed as `kernel` to Separator.
    tcn_kernel: int = 3
    # Passed as `blocks_per_repeat` to Separator.
    blocks_per_repeat: int = 8
    # Passed as `repeats` to Separator.
    repeats: int = 3
    # Passed as `speaker_dim` to Separator; the module docstring describes
    # the speaker fingerprint as 192-d.
    speaker_dim: int = 192
    # Passed as `freeze` to SpeakerEncoder.
    freeze_speaker: bool = True
    # Passed as `dropout` to Separator.
    dropout: float = 0.0
    # Passed as `num_masks` to SpecAugmentTime.
    specaug_num_masks: int = 0 # 0 disables SpecAugment
    # Passed as `max_width` to SpecAugmentTime.
    specaug_max_width: int = 40
class Vanta(nn.Module):
    """Target-speaker extraction model assembled from a ``VantaConfig``.

    Wires together an audio encoder, an optional time-masking augmentation,
    a speaker-conditioned separator that predicts a mask, an audio decoder,
    and a speaker encoder that turns an enrollment clip into an embedding.
    """

    def __init__(self, cfg: VantaConfig | None = None):
        """Build all sub-modules; a default ``VantaConfig`` is used if none is given."""
        super().__init__()
        self.cfg = cfg = cfg or VantaConfig()

        # Encoder and decoder are built from the same filterbank settings.
        codec_kwargs = {
            "num_filters": cfg.enc_channels,
            "kernel_size": cfg.enc_kernel,
            "stride": cfg.enc_stride,
        }
        self.audio_encoder = AudioEncoder(**codec_kwargs)
        self.audio_decoder = AudioDecoder(**codec_kwargs)

        self.separator = Separator(
            enc_channels=cfg.enc_channels,
            bottleneck=cfg.bottleneck,
            hidden=cfg.hidden,
            kernel=cfg.tcn_kernel,
            blocks_per_repeat=cfg.blocks_per_repeat,
            repeats=cfg.repeats,
            speaker_dim=cfg.speaker_dim,
            dropout=cfg.dropout,
        )
        self.specaug = SpecAugmentTime(
            num_masks=cfg.specaug_num_masks,
            max_width=cfg.specaug_max_width,
        )
        self.speaker_encoder = SpeakerEncoder(freeze=cfg.freeze_speaker)

    def embed_speaker(self, enrollment: torch.Tensor) -> torch.Tensor:
        """Encode an enrollment clip (B, T_enroll) into a (B, speaker_dim) embedding."""
        embedding = self.speaker_encoder(enrollment)
        return embedding

    def forward(
        self,
        mixture: torch.Tensor,
        enrollment: torch.Tensor | None = None,
        speaker_embedding: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Extract the target speaker's waveform from ``mixture``.

        Supply either ``enrollment`` (encoded here) or an already computed
        ``speaker_embedding``. Reusing a precomputed embedding saves work
        when one enrollment serves many mixtures. If both are supplied the
        embedding wins and ``enrollment`` is ignored.

        Args:
            mixture: Mixed waveform; its last dimension is the sample axis.
            enrollment: Optional enrollment audio, (B, T_enroll).
            speaker_embedding: Optional precomputed speaker embedding.

        Returns:
            Waveform whose last dimension equals ``mixture.shape[-1]``.

        Raises:
            ValueError: If neither ``enrollment`` nor ``speaker_embedding``
                is provided.
        """
        if speaker_embedding is None and enrollment is None:
            raise ValueError("must pass enrollment or speaker_embedding")
        if speaker_embedding is None:
            speaker_embedding = self.embed_speaker(enrollment)

        features = self.specaug(self.audio_encoder(mixture))  # (B, N, T'); specaug is a no-op at inference
        mask = self.separator(features, speaker_embedding)  # (B, N, T')
        wav = self.audio_decoder(features * mask)  # (B, T_out)

        # The transposed-conv decoder may be off from the input length by up
        # to (kernel - stride) samples; trim or right-pad with zeros so the
        # loss can compare against the original mixture length directly.
        n_samples = mixture.shape[-1]
        surplus = wav.shape[-1] - n_samples
        if surplus > 0:
            wav = wav[..., :n_samples]
        elif surplus < 0:
            wav = torch.nn.functional.pad(wav, (0, -surplus))
        return wav
|