import math

import torch
import torch.nn as nn

from .align import AlignBlock
from .blocks import (
    FE,
    DecoderBlock,
    EncoderBlock,
    S4DBottleneck,
)
from .ccm import CCM


def _check_dct_geometry(n_freqs, kernel_size):
    """Raise ValueError unless ``kernel_size`` coefficients split into ``n_freqs`` pairs.

    DCTEncoder produces ``kernel_size`` coefficients per frame and reshapes them
    to (n_freqs, 2), so any other combination fails at the first forward pass.
    """
    if 2 * n_freqs != kernel_size:
        raise ValueError(
            f"DCT front-end needs kernel_size == 2 * n_freqs "
            f"(got n_freqs={n_freqs}, kernel_size={kernel_size})"
        )


class DCTEncoder(nn.Module):
    """Framed DCT-II analysis implemented as a Conv1d initialised to the DCT basis.

    Each frame of ``kernel_size`` samples (hop ``stride``) is projected onto
    ``kernel_size`` orthonormal DCT-II basis vectors. The coefficients are then
    regrouped into ``n_freqs`` consecutive pairs, so the output has shape
    (B, n_freqs, T, 2). The conv weight is an ordinary parameter, so it may be
    trained away from the DCT initialisation.

    Raises:
        ValueError: if ``kernel_size != 2 * n_freqs``.
    """

    def __init__(self, n_freqs=256, kernel_size=512, stride=256):
        super().__init__()
        # Fail at construction instead of inside forward's reshape.
        _check_dct_geometry(n_freqs, kernel_size)
        n_filters = kernel_size
        self.n_freqs = n_freqs
        self.conv = nn.Conv1d(1, n_filters, kernel_size, stride=stride, bias=False)
        self._init_dct(n_filters, kernel_size)

    def _init_dct(self, n_filters, kernel_size):
        """Copy the orthonormal DCT-II basis (one row per filter) into the conv weight."""
        k = torch.arange(n_filters).unsqueeze(1).float()
        n = torch.arange(kernel_size).unsqueeze(0).float()
        basis = torch.cos(math.pi * (2 * n + 1) * k / (2 * kernel_size))
        # Orthonormal scaling: sqrt(1/N) for the DC row, sqrt(2/N) for all others.
        basis[0] *= 1.0 / (kernel_size ** 0.5)
        basis[1:] *= (2.0 / kernel_size) ** 0.5
        with torch.no_grad():
            self.conv.weight.copy_(basis.unsqueeze(1))

    def forward(self, x):
        """Transform waveforms ``x`` of shape (B, L) into (B, n_freqs, T, 2).

        The signal is zero-padded by half a frame on both sides, so frames are
        centred on multiples of the hop.
        """
        pad = self.conv.kernel_size[0] // 2
        x_padded = torch.nn.functional.pad(x, (pad, pad))
        out = self.conv(x_padded.unsqueeze(1))
        B, _N, T = out.shape
        return out.reshape(B, self.n_freqs, 2, T).permute(0, 1, 3, 2)


class DCTDecoder(nn.Module):
    """Synthesis counterpart of DCTEncoder: per-frame linear inverse + averaged overlap-add."""

    def __init__(self, n_freqs=256, kernel_size=512, stride=256):
        super().__init__()
        n_filters = kernel_size
        self.n_freqs = n_freqs
        self.linear = nn.Linear(n_filters, kernel_size, bias=False)
        self.kernel_size = kernel_size
        self.stride = stride
        self.pad = kernel_size // 2
        # Cache of per-sample overlap counts. It depends only on the output
        # length, dtype and device, so it is non-persistent: a persistent
        # buffer would put a length-dependent tensor into state_dict() and
        # break strict loading into a freshly constructed model.
        self.register_buffer('_overlap_count', None, persistent=False)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # Older checkpoints may still contain the cache (it used to be a
        # persistent buffer); drop it so strict loading keeps working.
        state_dict.pop(prefix + '_overlap_count', None)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def _init_from_encoder(self, encoder):
        """Set the synthesis matrix to the transpose of the encoder's analysis basis.

        For the orthonormal DCT initialisation this transpose is the exact inverse.
        """
        with torch.no_grad():
            self.linear.weight.copy_(encoder.conv.weight.squeeze(1).T)

    def _get_overlap_count(self, n_frames, out_len, like):
        """Return (and cache) how many frames cover each output sample, clamped to >= 1.

        The cache is rebuilt when the length, device or dtype differs from
        ``like``. This avoids cross-device reuse and float32 promotion of
        half-precision outputs.
        """
        cached = self._overlap_count
        if (
            cached is None
            or cached.shape[-1] != out_len
            or cached.device != like.device
            or cached.dtype != like.dtype
        ):
            ones = torch.ones(
                1, self.kernel_size, n_frames, device=like.device, dtype=like.dtype
            )
            self._overlap_count = torch.nn.functional.fold(
                ones,
                output_size=(1, out_len),
                kernel_size=(1, self.kernel_size),
                stride=(1, self.stride),
            ).squeeze().clamp(min=1)
        return self._overlap_count

    def forward(self, x, length=None):
        """Reconstruct waveforms of shape (B, samples) from ``x`` of shape (B, F, T, 2).

        Args:
            x: coefficient pairs, laid out as DCTEncoder produces them.
            length: if given, the output is zero-padded or truncated to exactly
                this many samples.
        """
        B, F, T, _2 = x.shape
        # Undo the encoder's (F, 2) regrouping back to F*2 coefficients per frame.
        x = x.permute(0, 1, 3, 2).reshape(B, F * 2, T)
        frames = self.linear(x.transpose(1, 2))
        out_len = (T - 1) * self.stride + self.kernel_size
        # Overlap-add the frames, then average by coverage count.
        output = torch.nn.functional.fold(
            frames.transpose(1, 2),
            output_size=(1, out_len),
            kernel_size=(1, self.kernel_size),
            stride=(1, self.stride),
        ).squeeze(2).squeeze(1)
        output = output / self._get_overlap_count(T, out_len, frames)
        # Remove the half-frame of leading padding added by DCTEncoder.forward.
        output = output[:, self.pad:]
        if length is not None:
            if output.shape[-1] < length:
                output = torch.nn.functional.pad(output, (0, length - output.shape[-1]))
            else:
                output = output[:, :length]
        return output


def compute_freq_progression(n_freqs, kernel_size, n_stages=5):
    """Return the frequency-axis size before and after each of ``n_stages`` stages.

    Each stage is modelled as a stride-2 convolution of width ``kernel_size[1]``
    with (kw - 1) // 2 padding on the left and the remainder on the right.
    NOTE(review): this is assumed to mirror EncoderBlock's padding/stride; confirm.

    Returns:
        list of length ``n_stages + 1``; element 0 is ``n_freqs``.
    """
    _kh, kw = kernel_size
    pad_left = (kw - 1) // 2
    pad_right = kw - 1 - pad_left
    f = n_freqs
    freqs = [f]
    for _ in range(n_stages):
        f = (f + pad_left + pad_right - kw) // 2 + 1
        freqs.append(f)
    return freqs


class LocalVQE(nn.Module):
    """Two-input (microphone + reference) U-Net operating on a DCT front-end.

    Data flow:
        encoder/FE -> two mic and two far encoder stages -> AlignBlock ->
        concatenation -> three more mic stages -> S4D bottleneck ->
        five decoder stages with skip connections -> CCM applied to the mic
        DCT representation.
    """

    def __init__(
        self,
        mic_channels=None,
        far_channels=None,
        align_hidden=32,
        dmax=32,
        power_law_c=0.3,
        n_freqs=257,
        kernel_size=(4, 3),
        bottleneck_hidden=0,
    ):
        super().__init__()
        if mic_channels is None:
            mic_channels = [2, 64, 128, 128, 128, 128]
        if far_channels is None:
            far_channels = [2, 32, 128]
        self.n_freqs = n_freqs
        ks = tuple(kernel_size)
        # The DCT front-end yields 2 * n_freqs coefficients per frame, so the
        # frame length must be 2 * n_freqs (hop n_freqs, 50% overlap).
        # n_freqs=256 gives the 512-sample / 256-hop framing; a hard-coded
        # 512 frame would make any other n_freqs (including the default 257)
        # crash in the encoder's reshape.
        dct_size = 2 * n_freqs
        self.encoder = DCTEncoder(n_freqs=n_freqs, kernel_size=dct_size, stride=n_freqs)
        self.decoder = DCTDecoder(n_freqs=n_freqs, kernel_size=dct_size, stride=n_freqs)
        self.decoder._init_from_encoder(self.encoder)
        self.fe_mic = FE(c=power_law_c)
        self.fe_ref = FE(c=power_law_c)

        self.mic_enc1 = EncoderBlock(mic_channels[0], mic_channels[1], kernel_size=ks)
        self.mic_enc2 = EncoderBlock(mic_channels[1], mic_channels[2], kernel_size=ks)
        self.far_enc1 = EncoderBlock(far_channels[0], far_channels[1], kernel_size=ks)
        self.far_enc2 = EncoderBlock(far_channels[1], far_channels[2], kernel_size=ks)
        self.align = AlignBlock(
            in_channels=mic_channels[2],
            hidden_channels=align_hidden,
            dmax=dmax,
        )
        # Input is mic features concatenated with aligned far features along
        # dim 1 in forward, hence the doubled channel count.
        self.mic_enc3 = EncoderBlock(mic_channels[2] * 2, mic_channels[3], kernel_size=ks)
        self.mic_enc4 = EncoderBlock(mic_channels[3], mic_channels[4], kernel_size=ks)
        self.mic_enc5 = EncoderBlock(mic_channels[4], mic_channels[5], kernel_size=ks)

        # Bottleneck width = channels * frequency bins remaining after 5 stages.
        freqs = compute_freq_progression(n_freqs, ks)
        bn_input = mic_channels[5] * freqs[5]
        bn_hidden = bottleneck_hidden if bottleneck_hidden > 0 else bn_input // 2
        self.bottleneck = S4DBottleneck(bn_input, bn_hidden)

        self.dec5 = DecoderBlock(mic_channels[5], mic_channels[4], kernel_size=ks)
        self.dec4 = DecoderBlock(mic_channels[4], mic_channels[3], kernel_size=ks)
        self.dec3 = DecoderBlock(mic_channels[3], mic_channels[2], kernel_size=ks)
        self.dec2 = DecoderBlock(mic_channels[2], mic_channels[1], kernel_size=ks)
        self.dec1 = DecoderBlock(mic_channels[1], 27, kernel_size=ks, is_last=True)
        self.mask = CCM()
        self._init_ccm_identity()

    def _init_ccm_identity(self):
        """Zero the last decoder conv's bias, then set channels 7 and 34 to 1.

        NOTE(review): presumably this makes the CCM start as an identity mask.
        dec1 emits 27 channels and 34 == 7 + 27, which suggests the underlying
        conv produces two 27-channel groups with index 7 as the pass-through
        tap in each. Confirm against DecoderBlock and CCM.
        """
        conv = self.dec1.deconv.conv
        with torch.no_grad():
            conv.bias.zero_()
            conv.bias[7] = 1.0
            conv.bias[34] = 1.0

    def forward(self, mic_wav, ref_wav):
        """Run the network on a microphone / reference waveform pair.

        Both waveforms go through ``self.encoder``, so each is expected as
        (B, L). Returns the output of ``self.mask`` (CCM) applied to the mic
        DCT representation; ``self.decoder`` is not applied here.
        NOTE(review): presumably callers convert back to audio with
        ``self.decoder``; confirm.
        """
        mic_enc = self.encoder(mic_wav)
        ref_enc = self.encoder(ref_wav)
        mic_fe = self.fe_mic(mic_enc)
        ref_fe = self.fe_ref(ref_enc)

        mic_e1 = self.mic_enc1(mic_fe)
        mic_e2 = self.mic_enc2(mic_e1)
        far_e1 = self.far_enc1(ref_fe)
        far_e2 = self.far_enc2(far_e1)

        aligned_far = self.align(mic_e2, far_e2)
        concat = torch.cat([mic_e2, aligned_far], dim=1)

        mic_e3 = self.mic_enc3(concat)
        mic_e4 = self.mic_enc4(mic_e3)
        mic_e5 = self.mic_enc5(mic_e4)

        bn = self.bottleneck(mic_e5)

        # Each decoder output is trimmed on the last axis to its skip
        # connection's size, undoing any overshoot from upsampling.
        d5 = self.dec5(bn, mic_e5)[..., : mic_e4.shape[-1]]
        d4 = self.dec4(d5, mic_e4)[..., : mic_e3.shape[-1]]
        d3 = self.dec3(d4, mic_e3)[..., : mic_e2.shape[-1]]
        d2 = self.dec2(d3, mic_e2)[..., : mic_e1.shape[-1]]
        d1 = self.dec1(d2, mic_e1)[..., : mic_fe.shape[-1]]

        enhanced = self.mask(d1, mic_enc)
        return enhanced

    @classmethod
    def from_config(cls, cfg):
        """Build a model from ``cfg.model.*`` and ``cfg.audio.n_freqs``."""
        return cls(
            mic_channels=cfg.model.mic_channels,
            far_channels=cfg.model.far_channels,
            align_hidden=cfg.model.align_hidden,
            dmax=cfg.model.dmax,
            power_law_c=cfg.model.power_law_c,
            n_freqs=cfg.audio.n_freqs,
            kernel_size=cfg.model.kernel_size,
            bottleneck_hidden=cfg.model.bottleneck_hidden,
        )