richiejp's picture
Initial upload: LocalVQE demo Space
4313d1d verified
import torch
import torch.nn as nn
from .align import AlignBlock
from .blocks import (
FE,
DecoderBlock,
EncoderBlock,
S4DBottleneck,
)
from .ccm import CCM
class DCTEncoder(nn.Module):
    """Framewise DCT-II analysis implemented as a strided 1-D convolution.

    The convolution has ``kernel_size`` output filters, each initialised to
    one orthonormal DCT-II basis vector, so every hop of ``stride`` samples
    produces ``kernel_size`` coefficients.  ``forward`` groups adjacent
    coefficients into pairs, which only works when
    ``2 * n_freqs == kernel_size``.

    Args:
        n_freqs: Number of coefficient pairs per frame.
        kernel_size: Frame length in samples (also the number of filters).
        stride: Hop between successive frames, in samples.
    """

    def __init__(self, n_freqs=256, kernel_size=512, stride=256):
        super().__init__()
        self.n_freqs = n_freqs
        # Square transform: as many filters as there are taps in a frame.
        n_filters = kernel_size
        self.conv = nn.Conv1d(
            in_channels=1,
            out_channels=n_filters,
            kernel_size=kernel_size,
            stride=stride,
            bias=False,
        )
        self._init_dct(n_filters, kernel_size)

    def _init_dct(self, n_filters, kernel_size):
        """Overwrite the convolution weights with a scaled DCT-II basis."""
        import math

        freq_idx = torch.arange(n_filters).unsqueeze(1).float()
        time_idx = torch.arange(kernel_size).unsqueeze(0).float()
        dct = torch.cos(math.pi * (2 * time_idx + 1) * freq_idx / (2 * kernel_size))
        # Orthonormal scaling: sqrt(1/N) for the first row, sqrt(2/N) for the rest.
        dct[0].mul_(1.0 / (kernel_size ** 0.5))
        dct[1:].mul_((2.0 / kernel_size) ** 0.5)
        with torch.no_grad():
            # Conv1d weights are (out_channels, in_channels, taps).
            self.conv.weight.copy_(dct.unsqueeze(1))

    def forward(self, x):
        """Transform waveforms into paired DCT coefficients.

        Args:
            x: Waveform batch of shape ``(batch, samples)``.

        Returns:
            Tensor of shape ``(batch, n_freqs, frames, 2)`` in which the last
            axis holds two adjacent DCT coefficients.
        """
        # Pad half a frame on each side before framing.
        half_frame = self.conv.kernel_size[0] // 2
        padded = torch.nn.functional.pad(x, (half_frame, half_frame))
        coeffs = self.conv(padded.unsqueeze(1))
        batch, _, frames = coeffs.shape
        paired = coeffs.reshape(batch, self.n_freqs, 2, frames)
        return paired.permute(0, 1, 3, 2)
class DCTDecoder(nn.Module):
    """Overlap-add synthesis matching the layout produced by ``DCTEncoder``.

    Each frame of paired coefficients is mapped back to ``kernel_size``
    samples by a bias-free linear layer, the frames are overlap-added with
    hop ``stride``, and every sample is divided by the number of frames that
    cover it.

    Args:
        n_freqs: Number of coefficient pairs per frame (stored only).
        kernel_size: Frame length in samples.
        stride: Hop between successive frames, in samples.
    """

    def __init__(self, n_freqs=256, kernel_size=512, stride=256):
        super().__init__()
        n_filters = kernel_size
        self.n_freqs = n_freqs
        self.linear = nn.Linear(n_filters, kernel_size, bias=False)
        self.kernel_size = kernel_size
        self.stride = stride
        # Leading samples dropped after overlap-add (half a frame).
        self.pad = kernel_size // 2
        # Cache of the per-sample overlap count.  It is derived from the
        # input length inside ``forward``, so it is non-persistent: a
        # persistent buffer would appear in ``state_dict()`` only after the
        # first forward pass and break strict loading into a fresh module.
        self.register_buffer('_overlap_count', None, persistent=False)

    def _init_from_encoder(self, encoder):
        """Copy the transposed filters of ``encoder`` into the linear layer.

        Args:
            encoder: Module exposing ``conv.weight`` of shape
                ``(n_filters, 1, kernel_size)``.
        """
        with torch.no_grad():
            self.linear.weight.copy_(encoder.conv.weight.squeeze(1).T)

    def forward(self, x, length=None):
        """Reconstruct waveforms from paired coefficients.

        Args:
            x: Tensor of shape ``(batch, n_freqs, frames, 2)``.
            length: Optional number of output samples.  The result is
                zero-padded on the right or truncated to this length.

        Returns:
            Tensor of shape ``(batch, samples)``.
        """
        batch, n_freqs, n_frames, _ = x.shape
        # Undo the pairing: (B, F, T, 2) -> (B, 2F, T).
        coeffs = x.permute(0, 1, 3, 2).reshape(batch, n_freqs * 2, n_frames)
        frames = self.linear(coeffs.transpose(1, 2))
        out_len = (n_frames - 1) * self.stride + self.kernel_size
        fold_args = {
            'output_size': (1, out_len),
            'kernel_size': (1, self.kernel_size),
            'stride': (1, self.stride),
        }
        # Overlap-add: fold yields (B, 1, 1, out_len).
        output = torch.nn.functional.fold(
            frames.transpose(1, 2), **fold_args
        ).squeeze(2).squeeze(1)
        if self._overlap_count is None or self._overlap_count.shape[-1] != out_len:
            # Folding all-ones frames counts how many frames cover each sample.
            ones = torch.ones(1, self.kernel_size, n_frames, device=x.device)
            self._overlap_count = torch.nn.functional.fold(
                ones, **fold_args
            ).reshape(out_len).clamp(min=1)
        output = output / self._overlap_count
        # Drop the leading half frame.
        output = output[:, self.pad:]
        if length is not None:
            if output.shape[-1] < length:
                output = torch.nn.functional.pad(output, (0, length - output.shape[-1]))
            else:
                output = output[:, :length]
        return output
def compute_freq_progression(n_freqs, kernel_size, n_stages=5):
    """List the frequency-axis size before and after each downsampling stage.

    Every stage applies ``f -> (f + pad_left + pad_right - kw) // 2 + 1``,
    where ``kw`` is the second entry of ``kernel_size`` and the two paddings
    together add ``kw - 1`` bins.  This presumably mirrors the frequency-axis
    convolution of the encoder blocks; verify against their definition.

    Args:
        n_freqs: Size of the frequency axis at the input.
        kernel_size: ``(kh, kw)`` pair; only ``kw`` is used.
        n_stages: Number of stages to apply.

    Returns:
        A list of ``n_stages + 1`` sizes, starting with ``n_freqs``.
    """
    _, width = kernel_size
    left = (width - 1) // 2
    right = width - 1 - left
    sizes = [n_freqs]
    for _ in range(n_stages):
        sizes.append((sizes[-1] + left + right - width) // 2 + 1)
    return sizes
class LocalVQE(nn.Module):
    """Encoder-decoder network over DCT frames of a mic and a reference signal.

    ``forward`` encodes both waveforms with a shared ``DCTEncoder``, runs a
    five-stage microphone encoder (with the aligned reference branch merged
    in after stage two), a bottleneck, and a five-stage decoder with skip
    connections, then applies the ``CCM`` module to the decoder output and
    the microphone DCT frames.

    Args:
        mic_channels: Channel widths of the microphone path; entries 0-5 are
            read.  Defaults to ``[2, 64, 128, 128, 128, 128]``.
        far_channels: Channel widths of the reference path; entries 0-2 are
            read.  Defaults to ``[2, 32, 128]``.
        align_hidden: Passed to ``AlignBlock`` as ``hidden_channels``.
        dmax: Passed to ``AlignBlock`` as ``dmax``.
        power_law_c: Passed to both ``FE`` modules as ``c``.
        n_freqs: Number of coefficient pairs per DCT frame.  The DCT front
            end is fixed at 512 taps, and the encoder reshapes its 512
            coefficients to ``(n_freqs, 2)``, so 256 is the only value for
            which ``forward`` can run.  The default used to be 257, which
            made a default-constructed model fail in ``forward``.
        kernel_size: Two-element kernel size passed to every encoder and
            decoder block.
        bottleneck_hidden: Hidden size of the bottleneck; a value ``<= 0``
            selects half of the bottleneck input size.
    """

    def __init__(
        self,
        mic_channels=None,
        far_channels=None,
        align_hidden=32,
        dmax=32,
        power_law_c=0.3,
        n_freqs=256,
        kernel_size=(4, 3),
        bottleneck_hidden=0,
    ):
        super().__init__()
        if mic_channels is None:
            mic_channels = [2, 64, 128, 128, 128, 128]
        if far_channels is None:
            far_channels = [2, 32, 128]
        self.n_freqs = n_freqs
        ks = tuple(kernel_size)

        # Analysis/synthesis pair; the decoder reuses the encoder's basis.
        self.encoder = DCTEncoder(n_freqs=n_freqs, kernel_size=512, stride=256)
        self.decoder = DCTDecoder(n_freqs=n_freqs, kernel_size=512, stride=256)
        self.decoder._init_from_encoder(self.encoder)

        self.fe_mic = FE(c=power_law_c)
        self.fe_ref = FE(c=power_law_c)

        self.mic_enc1 = EncoderBlock(mic_channels[0], mic_channels[1], kernel_size=ks)
        self.mic_enc2 = EncoderBlock(mic_channels[1], mic_channels[2], kernel_size=ks)
        self.far_enc1 = EncoderBlock(far_channels[0], far_channels[1], kernel_size=ks)
        self.far_enc2 = EncoderBlock(far_channels[1], far_channels[2], kernel_size=ks)
        self.align = AlignBlock(
            in_channels=mic_channels[2],
            hidden_channels=align_hidden,
            dmax=dmax,
        )
        # Input is the concatenation of mic_e2 and the aligned reference.
        # NOTE(review): sized as 2 * mic_channels[2], which assumes the
        # aligned reference has mic_channels[2] channels -- confirm.
        self.mic_enc3 = EncoderBlock(mic_channels[2] * 2, mic_channels[3], kernel_size=ks)
        self.mic_enc4 = EncoderBlock(mic_channels[3], mic_channels[4], kernel_size=ks)
        self.mic_enc5 = EncoderBlock(mic_channels[4], mic_channels[5], kernel_size=ks)

        # Bottleneck width: channels times frequency bins after five stages.
        freqs = compute_freq_progression(n_freqs, ks)
        bn_input = mic_channels[5] * freqs[5]
        bn_hidden = bottleneck_hidden if bottleneck_hidden > 0 else bn_input // 2
        self.bottleneck = S4DBottleneck(bn_input, bn_hidden)

        self.dec5 = DecoderBlock(mic_channels[5], mic_channels[4], kernel_size=ks)
        self.dec4 = DecoderBlock(mic_channels[4], mic_channels[3], kernel_size=ks)
        self.dec3 = DecoderBlock(mic_channels[3], mic_channels[2], kernel_size=ks)
        self.dec2 = DecoderBlock(mic_channels[2], mic_channels[1], kernel_size=ks)
        # The last decoder emits the 27 channels consumed by the CCM module.
        self.dec1 = DecoderBlock(mic_channels[1], 27, kernel_size=ks, is_last=True)
        self.mask = CCM()
        self._init_ccm_identity()

    def _init_ccm_identity(self):
        """Reset the last decoder's conv bias to zero except entries 7 and 34.

        NOTE(review): judging by the method name, these two entries are the
        ones that make the CCM output equal its input at initialisation;
        confirm against the channel layout used by ``CCM`` and the last
        decoder block.
        """
        conv = self.dec1.deconv.conv
        with torch.no_grad():
            conv.bias.zero_()
            conv.bias[7] = 1.0
            conv.bias[34] = 1.0

    def forward(self, mic_wav, ref_wav):
        """Run the network on a microphone and a reference waveform.

        Args:
            mic_wav: Microphone waveform batch, ``(batch, samples)``.
            ref_wav: Reference waveform batch, ``(batch, samples)``.

        Returns:
            The output of the ``CCM`` module applied to the final decoder
            output and the microphone DCT frames.  ``self.decoder`` is not
            applied here.
        """
        # Shared DCT analysis for both inputs.
        mic_enc = self.encoder(mic_wav)
        ref_enc = self.encoder(ref_wav)
        mic_fe = self.fe_mic(mic_enc)
        ref_fe = self.fe_ref(ref_enc)

        mic_e1 = self.mic_enc1(mic_fe)
        mic_e2 = self.mic_enc2(mic_e1)
        far_e1 = self.far_enc1(ref_fe)
        far_e2 = self.far_enc2(far_e1)

        # Merge the aligned reference into the microphone path.
        aligned_far = self.align(mic_e2, far_e2)
        concat = torch.cat([mic_e2, aligned_far], dim=1)
        mic_e3 = self.mic_enc3(concat)
        mic_e4 = self.mic_enc4(mic_e3)
        mic_e5 = self.mic_enc5(mic_e4)

        bn = self.bottleneck(mic_e5)

        # Each decoder output is cropped along the last axis to the size of
        # the skip tensor consumed by the next decoder stage.
        d5 = self.dec5(bn, mic_e5)[..., : mic_e4.shape[-1]]
        d4 = self.dec4(d5, mic_e4)[..., : mic_e3.shape[-1]]
        d3 = self.dec3(d4, mic_e3)[..., : mic_e2.shape[-1]]
        d2 = self.dec2(d3, mic_e2)[..., : mic_e1.shape[-1]]
        d1 = self.dec1(d2, mic_e1)[..., : mic_fe.shape[-1]]

        enhanced = self.mask(d1, mic_enc)
        return enhanced

    @classmethod
    def from_config(cls, cfg):
        """Build a model from a config object.

        Args:
            cfg: Object exposing ``cfg.model.{mic_channels, far_channels,
                align_hidden, dmax, power_law_c, kernel_size,
                bottleneck_hidden}`` and ``cfg.audio.n_freqs``.

        Returns:
            A new ``LocalVQE`` instance.
        """
        return cls(
            mic_channels=cfg.model.mic_channels,
            far_channels=cfg.model.far_channels,
            align_hidden=cfg.model.align_hidden,
            dmax=cfg.model.dmax,
            power_law_c=cfg.model.power_law_c,
            n_freqs=cfg.audio.n_freqs,
            kernel_size=cfg.model.kernel_size,
            bottleneck_hidden=cfg.model.bottleneck_hidden,
        )