Vanta / vanta /models /audio_encoder.py
Komalpreet Kaur
Initial deploy: Vanta TSE backend (FastAPI + trained checkpoint)
32de4f6 unverified
"""Learnable time-domain encoder/decoder pair (Conv-TasNet style).
Why a learnable encoder instead of an STFT front end: masking only the STFT
magnitude reuses the noisy mixture phase, so reconstructions can sound metallic.
A 1-D convolution plays the role of a learned filter bank whose features
preserve enough information to rebuild clean audio via its transposed twin.
Conv-TasNet defaults (Luo & Mesgarani, 2019):
num_filters = 512
kernel_size = 16 (1 ms at 16 kHz)
stride = 8 (50% overlap)
"""
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
class AudioEncoder(nn.Module):
    """Learned analysis filter bank: waveform in, rectified features out.

    A single strided, unpadded, bias-free ``Conv1d`` followed by ReLU. For an
    input of ``T`` samples the output has
    ``T' = (T - kernel_size) // stride + 1`` frames and ``num_filters``
    channels, i.e. shape ``(B, N, T')``.

    Args:
        num_filters: Number of learned basis filters (output channels, ``N``).
        kernel_size: Length of each filter in samples.
        stride: Hop between consecutive frames in samples.
    """

    def __init__(self, num_filters: int = 512, kernel_size: int = 16, stride: int = 8):
        super().__init__()
        # Hyperparameters are kept on the module so other code can read them back.
        self.num_filters = num_filters
        self.kernel_size = kernel_size
        self.stride = stride
        # No bias term: the encoder is a pure learned filter bank ahead of the ReLU.
        self.conv = nn.Conv1d(
            1,
            num_filters,
            kernel_size,
            stride=stride,
            padding=0,
            bias=False,
        )

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        """Encode a batch of waveforms.

        Args:
            wav: Tensor shaped ``(B, T)`` or ``(B, 1, T)``.

        Returns:
            Non-negative features shaped ``(B, N, T')``.
        """
        # Conv1d needs an explicit channel axis; add one for (B, T) input.
        needs_channel_axis = wav.dim() == 2
        signal = wav.unsqueeze(1) if needs_channel_axis else wav
        # Rectify so that a downstream multiplicative mask can only scale the
        # features, never flip their sign.
        return F.relu(self.conv(signal))
class AudioDecoder(nn.Module):
    """Learned synthesis filter bank: features ``(B, N, T')`` back to audio ``(B, T)``.

    A single strided, unpadded, bias-free ``ConvTranspose1d`` that mirrors the
    encoder's ``kernel_size`` and ``stride`` so the overlap-add frame layout
    lines up with the encoder's framing. The output length is
    ``(T' - 1) * stride + kernel_size`` samples.

    Args:
        num_filters: Number of input feature channels (``N``).
        kernel_size: Length of each synthesis filter in samples.
        stride: Hop between consecutive frames in samples.
    """

    def __init__(self, num_filters: int = 512, kernel_size: int = 16, stride: int = 8):
        super().__init__()
        # Hyperparameters are kept on the module so other code can read them back.
        self.num_filters = num_filters
        self.kernel_size = kernel_size
        self.stride = stride
        self.deconv = nn.ConvTranspose1d(
            num_filters,
            1,
            kernel_size,
            stride=stride,
            padding=0,
            bias=False,
        )

    def forward(self, feat: torch.Tensor) -> torch.Tensor:
        """Synthesize waveforms from a feature map.

        Args:
            feat: Tensor shaped ``(B, N, T')``.

        Returns:
            Waveform tensor shaped ``(B, T)``.
        """
        # The transposed conv emits (B, 1, T); drop the singleton channel axis.
        return self.deconv(feat).squeeze(1)
def n_frames_for(samples: int, kernel_size: int = 16, stride: int = 8) -> int:
    """Return how many frames an unpadded strided Conv1d yields for ``samples`` inputs.

    Args:
        samples: Length of the input waveform in samples.
        kernel_size: Convolution window length in samples.
        stride: Hop between consecutive windows in samples.

    Returns:
        ``(samples - kernel_size) // stride + 1`` when at least one full window
        fits, otherwise ``0``.
    """
    # Without this guard the floor division goes negative for short inputs
    # (e.g. samples=0 gives -1), which is not a meaningful frame count.
    if samples < kernel_size:
        return 0
    return (samples - kernel_size) // stride + 1
def expected_output_samples(n_frames: int, kernel_size: int = 16, stride: int = 8) -> int:
    """Return the waveform length a ConvTranspose1d emits for ``n_frames`` frames.

    Assumes no padding and no output_padding: the last frame starts at
    ``(n_frames - 1) * stride`` and contributes ``kernel_size`` samples.

    Args:
        n_frames: Number of feature frames fed to the transposed convolution.
        kernel_size: Synthesis window length in samples.
        stride: Hop between consecutive frames in samples.

    Returns:
        ``(n_frames - 1) * stride + kernel_size``.
    """
    last_frame_start = stride * (n_frames - 1)
    return last_frame_start + kernel_size