Vanta / vanta /models /vanta.py
Komalpreet Kaur
sync: update server and scripts (model unchanged)
e51885a unverified
"""The full Vanta model: encoder + separator + decoder + speaker encoder.
Forward pass:
1. Speaker encoder (frozen ECAPA-TDNN) turns the enrollment into a 192-d
fingerprint. Runs under no_grad so it doesn't train.
2. Audio encoder turns the mixture into a (B, N, T') feature map.
3. Separator predicts a mask (B, N, T'), conditioned on the fingerprint.
4. Mask * features -> masked features.
5. Audio decoder turns masked features back into a waveform.
Phase 3 deliverable: this runs end-to-end with random init. The weights are
trash (untrained), so the output is garbage audio, but shapes, gradients, and
conditioning pathways must all work.
"""
from __future__ import annotations
from dataclasses import dataclass
import torch
import torch.nn as nn
from vanta.models.audio_encoder import AudioEncoder, AudioDecoder, expected_output_samples
from vanta.models.separator import Separator
from vanta.models.speaker_encoder import SpeakerEncoder
class SpecAugmentTime(nn.Module):
"""Zero out random time spans of an encoded feature map during training.
Inspired by SpecAugment (Park et al., 2019) but operates on Conv-TasNet
encoder outputs rather than spectrograms. Forces the separator to rely on
the full temporal context instead of memorizing narrow patterns — a direct
attack on the overfitting we saw in earlier runs. No-op in eval mode.
"""
def __init__(self, num_masks: int = 2, max_width: int = 40):
super().__init__()
self.num_masks = num_masks
self.max_width = max_width
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: (B, C, T'). Only mask in training; leave inference untouched.
if not self.training or self.num_masks <= 0 or self.max_width <= 0:
return x
B, _, T = x.shape
out = x.clone()
for b in range(B):
for _ in range(self.num_masks):
width = int(torch.randint(1, self.max_width + 1, (1,)).item())
if T - width <= 0:
continue
start = int(torch.randint(0, T - width, (1,)).item())
out[b, :, start : start + width] = 0
return out
@dataclass
class VantaConfig:
enc_channels: int = 512
enc_kernel: int = 16
enc_stride: int = 8
bottleneck: int = 128
hidden: int = 512
tcn_kernel: int = 3
blocks_per_repeat: int = 8
repeats: int = 3
speaker_dim: int = 192
freeze_speaker: bool = True
dropout: float = 0.0
specaug_num_masks: int = 0 # 0 disables SpecAugment
specaug_max_width: int = 40
class Vanta(nn.Module):
def __init__(self, cfg: VantaConfig | None = None):
super().__init__()
cfg = cfg or VantaConfig()
self.cfg = cfg
self.audio_encoder = AudioEncoder(
num_filters=cfg.enc_channels,
kernel_size=cfg.enc_kernel,
stride=cfg.enc_stride,
)
self.audio_decoder = AudioDecoder(
num_filters=cfg.enc_channels,
kernel_size=cfg.enc_kernel,
stride=cfg.enc_stride,
)
self.separator = Separator(
enc_channels=cfg.enc_channels,
bottleneck=cfg.bottleneck,
hidden=cfg.hidden,
kernel=cfg.tcn_kernel,
blocks_per_repeat=cfg.blocks_per_repeat,
repeats=cfg.repeats,
speaker_dim=cfg.speaker_dim,
dropout=cfg.dropout,
)
self.specaug = SpecAugmentTime(
num_masks=cfg.specaug_num_masks,
max_width=cfg.specaug_max_width,
)
self.speaker_encoder = SpeakerEncoder(freeze=cfg.freeze_speaker)
def embed_speaker(self, enrollment: torch.Tensor) -> torch.Tensor:
"""enrollment: (B, T_enroll). Returns (B, speaker_dim)."""
return self.speaker_encoder(enrollment)
def forward(
self, mixture: torch.Tensor, enrollment: torch.Tensor | None = None,
speaker_embedding: torch.Tensor | None = None,
) -> torch.Tensor:
"""Extract the target speaker from the mixture.
Pass either `enrollment` (we'll encode it) or a precomputed
`speaker_embedding`. Precomputing is faster when the same enrollment is
reused across many mixtures (inference-time trick).
"""
if speaker_embedding is None:
if enrollment is None:
raise ValueError("must pass enrollment or speaker_embedding")
speaker_embedding = self.embed_speaker(enrollment)
enc = self.audio_encoder(mixture) # (B, N, T')
enc = self.specaug(enc) # no-op at inference
mask = self.separator(enc, speaker_embedding) # (B, N, T')
masked = enc * mask
wav = self.audio_decoder(masked) # (B, T_out)
# ConvTranspose1d output length may differ from the input waveform by
# up to (kernel - stride) samples. Align so downstream losses can use
# the original mixture length directly.
target_len = mixture.shape[-1]
if wav.shape[-1] > target_len:
wav = wav[..., :target_len]
elif wav.shape[-1] < target_len:
wav = torch.nn.functional.pad(wav, (0, target_len - wav.shape[-1]))
return wav