Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| from .align import AlignBlock | |
| from .blocks import ( | |
| FE, | |
| DecoderBlock, | |
| EncoderBlock, | |
| S4DBottleneck, | |
| ) | |
| from .ccm import CCM | |
class DCTEncoder(nn.Module):
    """Analyse a waveform with a strided Conv1d initialised to a DCT-II basis.

    The convolution has ``kernel_size`` filters of length ``kernel_size`` and
    no bias. Its weights are overwritten with an orthonormally scaled DCT-II
    matrix at construction time; nothing in this class freezes them.
    """

    def __init__(self, n_freqs=256, kernel_size=512, stride=256):
        super().__init__()
        self.n_freqs = n_freqs
        # Square transform: one filter per sample of the analysis window.
        self.conv = nn.Conv1d(
            1, kernel_size, kernel_size, stride=stride, bias=False
        )
        self._init_dct(kernel_size, kernel_size)

    def _init_dct(self, n_filters, kernel_size):
        """Copy an orthonormally scaled DCT-II basis into ``self.conv.weight``."""
        import math

        row_idx = torch.arange(n_filters, dtype=torch.float32).unsqueeze(1)
        col_idx = torch.arange(kernel_size, dtype=torch.float32).unsqueeze(0)
        dct = torch.cos(math.pi * (2 * col_idx + 1) * row_idx / (2 * kernel_size))
        # Row 0 (the constant row) uses 1/sqrt(K); all other rows use sqrt(2/K).
        dct[0].mul_(1.0 / (kernel_size ** 0.5))
        dct[1:].mul_((2.0 / kernel_size) ** 0.5)
        with torch.no_grad():
            self.conv.weight.copy_(dct.unsqueeze(1))

    def forward(self, x):
        """Return the framed coefficients as a ``(B, n_freqs, T, 2)`` tensor.

        ``x`` is zero-padded by half a kernel on both ends of its last
        dimension and given a channel axis at dim 1 before the convolution.
        The convolution's channel axis is then split into ``(n_freqs, 2)``,
        so consecutive channel pairs end up in the trailing size-2 axis.
        """
        half_kernel = self.conv.kernel_size[0] // 2
        padded = torch.nn.functional.pad(x, (half_kernel, half_kernel))
        coeffs = self.conv(padded.unsqueeze(1))
        batch, _, n_frames = coeffs.shape
        paired = coeffs.reshape(batch, self.n_freqs, 2, n_frames)
        return paired.transpose(2, 3)
class DCTDecoder(nn.Module):
    """Synthesise a waveform from framed coefficients by overlap-add.

    Each frame of coefficients is mapped to ``kernel_size`` samples by a
    bias-free linear layer. The frames are overlap-added with
    ``torch.nn.functional.fold`` and divided by the number of frames
    covering each sample.
    """

    def __init__(self, n_freqs=256, kernel_size=512, stride=256):
        super().__init__()
        n_filters = kernel_size
        self.n_freqs = n_freqs
        self.linear = nn.Linear(n_filters, kernel_size, bias=False)
        self.kernel_size = kernel_size
        self.stride = stride
        self.pad = kernel_size // 2
        # Lazily computed cache of per-sample overlap counts. It is derived
        # purely from (kernel_size, stride, number of frames), so it is kept
        # out of the state_dict: a persistent buffer that starts as None and
        # is filled on the first forward pass would make checkpoints saved
        # after a forward pass fail a strict load into a fresh model.
        self.register_buffer('_overlap_count', None, persistent=False)

    def _init_from_encoder(self, encoder):
        """Copy the transposed weights of ``encoder.conv`` into the linear layer."""
        with torch.no_grad():
            self.linear.weight.copy_(encoder.conv.weight.squeeze(1).T)

    def forward(self, x, length=None):
        """Reconstruct a ``(B, samples)`` signal from a ``(B, F, T, 2)`` tensor.

        Args:
            x: 4-D tensor whose last axis has the two interleaved channels of
                each frequency pair; ``F * 2`` must equal the linear layer's
                input size.
            length: Optional output length. Shorter results are zero-padded
                on the right, longer ones are truncated.

        Returns:
            Tensor of shape ``(B, length)`` when ``length`` is given, otherwise
            ``(B, (T - 1) * stride + kernel_size - kernel_size // 2)``.
        """
        batch, n_pairs, n_frames, _ = x.shape
        # Undo the (pairs, 2) split: back to a flat channel axis per frame.
        coeffs = x.permute(0, 1, 3, 2).reshape(batch, n_pairs * 2, n_frames)
        frames = self.linear(coeffs.transpose(1, 2))

        out_len = (n_frames - 1) * self.stride + self.kernel_size
        fold_kwargs = {
            'output_size': (1, out_len),
            'kernel_size': (1, self.kernel_size),
            'stride': (1, self.stride),
        }
        output = torch.nn.functional.fold(
            frames.transpose(1, 2), **fold_kwargs
        ).squeeze(2).squeeze(1)

        # Recompute the overlap count only when the output length changes.
        if self._overlap_count is None or self._overlap_count.shape[-1] != out_len:
            ones = torch.ones(1, self.kernel_size, n_frames, device=x.device)
            self._overlap_count = torch.nn.functional.fold(
                ones, **fold_kwargs
            ).squeeze().clamp(min=1)
        output = output / self._overlap_count

        # Drop the leading half-kernel of samples.
        output = output[:, self.pad:]
        if length is not None:
            if output.shape[-1] < length:
                output = torch.nn.functional.pad(
                    output, (0, length - output.shape[-1])
                )
            else:
                output = output[:, :length]
        return output
def compute_freq_progression(n_freqs, kernel_size, n_stages=5):
    """List the frequency-axis size before and after each of ``n_stages`` stages.

    Only the second entry of ``kernel_size`` (the width) is used. Each stage
    maps a size ``f`` to ``(f + pad_left + pad_right - width) // 2 + 1`` with
    a total padding of ``width - 1``. This matches the output-size formula of
    a stride-2 convolution; presumably it mirrors the encoder blocks, to be
    verified against them.

    Returns:
        A list of ``n_stages + 1`` sizes, starting with ``n_freqs``.
    """
    _, width = kernel_size
    left = (width - 1) // 2
    right = width - 1 - left
    sizes = [n_freqs]
    for _ in range(n_stages):
        padded = sizes[-1] + left + right
        sizes.append((padded - width) // 2 + 1)
    return sizes
class LocalVQE(nn.Module):
    """Two-branch encoder/decoder network over DCT-framed mic and reference audio.

    The microphone and far-end reference signals share one ``DCTEncoder``.
    Each passes through its own feature extractor and two encoder blocks.
    The reference features are aligned to the microphone features by
    ``AlignBlock`` and concatenated with them. Three more encoder blocks, a
    bottleneck and five skip-connected decoder blocks produce the input to
    the ``CCM`` mask, which is applied to the encoded microphone signal.
    """

    def __init__(
        self,
        mic_channels=None,
        far_channels=None,
        align_hidden=32,
        dmax=32,
        power_law_c=0.3,
        n_freqs=257,
        kernel_size=(4, 3),
        bottleneck_hidden=0,
    ):
        """Build all sub-modules.

        Args:
            mic_channels: Six channel counts for the microphone branch
                (input plus five encoder stages). Defaults to
                ``[2, 64, 128, 128, 128, 128]``.
            far_channels: Three channel counts for the reference branch.
                Defaults to ``[2, 32, 128]``.
            align_hidden: ``hidden_channels`` passed to ``AlignBlock``.
            dmax: ``dmax`` passed to ``AlignBlock``.
            power_law_c: ``c`` passed to both ``FE`` modules.
            n_freqs: Frequency-axis size given to the DCT encoder/decoder and
                used to size the bottleneck.
            kernel_size: Kernel size forwarded to every encoder/decoder block.
            bottleneck_hidden: Bottleneck hidden size; values ``<= 0`` select
                half of the bottleneck input size.
        """
        super().__init__()
        if mic_channels is None:
            mic_channels = [2, 64, 128, 128, 128, 128]
        if far_channels is None:
            far_channels = [2, 32, 128]
        # NOTE(review): DCTEncoder reshapes its 512-channel output into
        # (n_freqs, 2), which only fits when n_freqs == 256; the default of
        # 257 looks inconsistent with kernel_size=512 -- confirm against the
        # configs actually used.
        self.n_freqs = n_freqs
        ks = tuple(kernel_size)

        # Analysis/synthesis transforms; the decoder starts from the
        # transposed encoder weights.
        self.encoder = DCTEncoder(n_freqs=n_freqs, kernel_size=512, stride=256)
        self.decoder = DCTDecoder(n_freqs=n_freqs, kernel_size=512, stride=256)
        self.decoder._init_from_encoder(self.encoder)

        self.fe_mic = FE(c=power_law_c)
        self.fe_ref = FE(c=power_law_c)

        # Separate front-end encoders for the microphone and reference branches.
        self.mic_enc1 = EncoderBlock(mic_channels[0], mic_channels[1], kernel_size=ks)
        self.mic_enc2 = EncoderBlock(mic_channels[1], mic_channels[2], kernel_size=ks)
        self.far_enc1 = EncoderBlock(far_channels[0], far_channels[1], kernel_size=ks)
        self.far_enc2 = EncoderBlock(far_channels[1], far_channels[2], kernel_size=ks)

        self.align = AlignBlock(
            in_channels=mic_channels[2],
            hidden_channels=align_hidden,
            dmax=dmax,
        )

        # The third stage sees mic features concatenated with the aligned
        # reference features, hence twice the channel count.
        self.mic_enc3 = EncoderBlock(mic_channels[2] * 2, mic_channels[3], kernel_size=ks)
        self.mic_enc4 = EncoderBlock(mic_channels[3], mic_channels[4], kernel_size=ks)
        self.mic_enc5 = EncoderBlock(mic_channels[4], mic_channels[5], kernel_size=ks)

        # Bottleneck input size = channels x frequency size after five stages.
        freqs = compute_freq_progression(n_freqs, ks)
        bn_input = mic_channels[5] * freqs[5]
        bn_hidden = bottleneck_hidden if bottleneck_hidden > 0 else bn_input // 2
        self.bottleneck = S4DBottleneck(bn_input, bn_hidden)

        self.dec5 = DecoderBlock(mic_channels[5], mic_channels[4], kernel_size=ks)
        self.dec4 = DecoderBlock(mic_channels[4], mic_channels[3], kernel_size=ks)
        self.dec3 = DecoderBlock(mic_channels[3], mic_channels[2], kernel_size=ks)
        self.dec2 = DecoderBlock(mic_channels[2], mic_channels[1], kernel_size=ks)
        self.dec1 = DecoderBlock(mic_channels[1], 27, kernel_size=ks, is_last=True)

        self.mask = CCM()
        self._init_ccm_identity()

    def _init_ccm_identity(self):
        """Zero the last decoder convolution's bias, then set entries 7 and 34 to 1.

        Presumably this makes the initial mask an identity (per the method
        name). The indices assume the underlying convolution has at least 35
        output channels -- TODO confirm against DecoderBlock/CCM.
        """
        conv = self.dec1.deconv.conv
        with torch.no_grad():
            conv.bias.zero_()
            conv.bias[7] = 1.0
            conv.bias[34] = 1.0

    def forward(self, mic_wav, ref_wav):
        """Run the network and return the output of the ``CCM`` mask.

        Args:
            mic_wav: Microphone signal fed to the shared DCT encoder.
            ref_wav: Far-end reference signal fed to the same encoder.

        Returns:
            Whatever ``self.mask`` returns for the final decoder output and
            the encoded microphone signal. ``self.decoder`` is not applied
            here; presumably callers invoke it separately -- TODO confirm.
        """
        mic_enc = self.encoder(mic_wav)
        ref_enc = self.encoder(ref_wav)
        mic_fe = self.fe_mic(mic_enc)
        ref_fe = self.fe_ref(ref_enc)

        mic_e1 = self.mic_enc1(mic_fe)
        mic_e2 = self.mic_enc2(mic_e1)
        far_e1 = self.far_enc1(ref_fe)
        far_e2 = self.far_enc2(far_e1)

        aligned_far = self.align(mic_e2, far_e2)
        concat = torch.cat([mic_e2, aligned_far], dim=1)

        mic_e3 = self.mic_enc3(concat)
        mic_e4 = self.mic_enc4(mic_e3)
        mic_e5 = self.mic_enc5(mic_e4)

        bn = self.bottleneck(mic_e5)

        # Each decoder output is trimmed along its last axis to the size of
        # the feature map it is combined with at the next level.
        d5 = self.dec5(bn, mic_e5)[..., : mic_e4.shape[-1]]
        d4 = self.dec4(d5, mic_e4)[..., : mic_e3.shape[-1]]
        d3 = self.dec3(d4, mic_e3)[..., : mic_e2.shape[-1]]
        d2 = self.dec2(d3, mic_e2)[..., : mic_e1.shape[-1]]
        d1 = self.dec1(d2, mic_e1)[..., : mic_fe.shape[-1]]

        enhanced = self.mask(d1, mic_enc)
        return enhanced

    @classmethod
    def from_config(cls, cfg):
        """Alternate constructor reading hyper-parameters from ``cfg``.

        Reads ``cfg.model.{mic_channels, far_channels, align_hidden, dmax,
        power_law_c, kernel_size, bottleneck_hidden}`` and
        ``cfg.audio.n_freqs``.
        """
        return cls(
            mic_channels=cfg.model.mic_channels,
            far_channels=cfg.model.far_channels,
            align_hidden=cfg.model.align_hidden,
            dmax=cfg.model.dmax,
            power_law_c=cfg.model.power_law_c,
            n_freqs=cfg.audio.n_freqs,
            kernel_size=cfg.model.kernel_size,
            bottleneck_hidden=cfg.model.bottleneck_hidden,
        )