Spaces:
Running
Running
| """The full Vanta model: encoder + separator + decoder + speaker encoder. | |
| Forward pass: | |
| 1. Speaker encoder (frozen ECAPA-TDNN) turns the enrollment into a 192-d | |
| fingerprint. Runs under no_grad so it doesn't train. | |
| 2. Audio encoder turns the mixture into a (B, N, T') feature map. | |
| 3. Separator predicts a mask (B, N, T'), conditioned on the fingerprint. | |
| 4. Mask * features -> masked features. | |
| 5. Audio decoder turns masked features back into a waveform. | |
| Phase 3 deliverable: this runs end-to-end with random init. The weights are | |
| trash (untrained), so the output is garbage audio, but shapes, gradients, and | |
| conditioning pathways must all work. | |
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| import torch | |
| import torch.nn as nn | |
| from vanta.models.audio_encoder import AudioEncoder, AudioDecoder, expected_output_samples | |
| from vanta.models.separator import Separator | |
| from vanta.models.speaker_encoder import SpeakerEncoder | |
| class SpecAugmentTime(nn.Module): | |
| """Zero out random time spans of an encoded feature map during training. | |
| Inspired by SpecAugment (Park et al., 2019) but operates on Conv-TasNet | |
| encoder outputs rather than spectrograms. Forces the separator to rely on | |
| the full temporal context instead of memorizing narrow patterns — a direct | |
| attack on the overfitting we saw in earlier runs. No-op in eval mode. | |
| """ | |
| def __init__(self, num_masks: int = 2, max_width: int = 40): | |
| super().__init__() | |
| self.num_masks = num_masks | |
| self.max_width = max_width | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| # x: (B, C, T'). Only mask in training; leave inference untouched. | |
| if not self.training or self.num_masks <= 0 or self.max_width <= 0: | |
| return x | |
| B, _, T = x.shape | |
| out = x.clone() | |
| for b in range(B): | |
| for _ in range(self.num_masks): | |
| width = int(torch.randint(1, self.max_width + 1, (1,)).item()) | |
| if T - width <= 0: | |
| continue | |
| start = int(torch.randint(0, T - width, (1,)).item()) | |
| out[b, :, start : start + width] = 0 | |
| return out | |
| class VantaConfig: | |
| enc_channels: int = 512 | |
| enc_kernel: int = 16 | |
| enc_stride: int = 8 | |
| bottleneck: int = 128 | |
| hidden: int = 512 | |
| tcn_kernel: int = 3 | |
| blocks_per_repeat: int = 8 | |
| repeats: int = 3 | |
| speaker_dim: int = 192 | |
| freeze_speaker: bool = True | |
| dropout: float = 0.0 | |
| specaug_num_masks: int = 0 # 0 disables SpecAugment | |
| specaug_max_width: int = 40 | |
| class Vanta(nn.Module): | |
| def __init__(self, cfg: VantaConfig | None = None): | |
| super().__init__() | |
| cfg = cfg or VantaConfig() | |
| self.cfg = cfg | |
| self.audio_encoder = AudioEncoder( | |
| num_filters=cfg.enc_channels, | |
| kernel_size=cfg.enc_kernel, | |
| stride=cfg.enc_stride, | |
| ) | |
| self.audio_decoder = AudioDecoder( | |
| num_filters=cfg.enc_channels, | |
| kernel_size=cfg.enc_kernel, | |
| stride=cfg.enc_stride, | |
| ) | |
| self.separator = Separator( | |
| enc_channels=cfg.enc_channels, | |
| bottleneck=cfg.bottleneck, | |
| hidden=cfg.hidden, | |
| kernel=cfg.tcn_kernel, | |
| blocks_per_repeat=cfg.blocks_per_repeat, | |
| repeats=cfg.repeats, | |
| speaker_dim=cfg.speaker_dim, | |
| dropout=cfg.dropout, | |
| ) | |
| self.specaug = SpecAugmentTime( | |
| num_masks=cfg.specaug_num_masks, | |
| max_width=cfg.specaug_max_width, | |
| ) | |
| self.speaker_encoder = SpeakerEncoder(freeze=cfg.freeze_speaker) | |
| def embed_speaker(self, enrollment: torch.Tensor) -> torch.Tensor: | |
| """enrollment: (B, T_enroll). Returns (B, speaker_dim).""" | |
| return self.speaker_encoder(enrollment) | |
| def forward( | |
| self, mixture: torch.Tensor, enrollment: torch.Tensor | None = None, | |
| speaker_embedding: torch.Tensor | None = None, | |
| ) -> torch.Tensor: | |
| """Extract the target speaker from the mixture. | |
| Pass either `enrollment` (we'll encode it) or a precomputed | |
| `speaker_embedding`. Precomputing is faster when the same enrollment is | |
| reused across many mixtures (inference-time trick). | |
| """ | |
| if speaker_embedding is None: | |
| if enrollment is None: | |
| raise ValueError("must pass enrollment or speaker_embedding") | |
| speaker_embedding = self.embed_speaker(enrollment) | |
| enc = self.audio_encoder(mixture) # (B, N, T') | |
| enc = self.specaug(enc) # no-op at inference | |
| mask = self.separator(enc, speaker_embedding) # (B, N, T') | |
| masked = enc * mask | |
| wav = self.audio_decoder(masked) # (B, T_out) | |
| # ConvTranspose1d output length may differ from the input waveform by | |
| # up to (kernel - stride) samples. Align so downstream losses can use | |
| # the original mixture length directly. | |
| target_len = mixture.shape[-1] | |
| if wav.shape[-1] > target_len: | |
| wav = wav[..., :target_len] | |
| elif wav.shape[-1] < target_len: | |
| wav = torch.nn.functional.pad(wav, (0, target_len - wav.shape[-1])) | |
| return wav | |