# NOTE: this file was recovered from a web paste ("Spaces: Running" UI residue
# removed). SpEx+-style dual-pass speaker extraction model definitions follow.
| #!/usr/bin/env python | |
| import torch as th | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from .norm import ChannelwiseLayerNorm, GlobalLayerNorm | |
| from .cnns import Conv1D, ConvTrans1D, TCNBlock, TCNBlock_Spk, ResBlock | |
| import warnings | |
| # inference aux_len | |
class SpEx_Plus_Double(nn.Module):
    """SpEx+-style dual-pass time-domain speaker extraction network.

    A multi-scale (short/middle/long window) Conv1D encoder/decoder pair
    wraps a speaker-conditioned TCN mask estimator (`Extractor`).  A first
    extraction pass produces ``est1``; two iterative refinement passes
    (:meth:`ira`) then re-embed the speaker from the concatenation of the
    current estimate and the enrollment audio and decode a refined estimate.

    Args:
        L1, L2, L3: encoder window lengths (samples) of the three scales.
        N: number of encoder filters per scale.
        B: number of TCN blocks per repeat in the extractor.
        O: TCN bottleneck channels.
        P: TCN convolution channels.
        Q: TCN kernel size.
        num_spks: size of the speaker-classification head.
        spk_embed_dim: speaker embedding dimension.
        causal: use causal configuration (forces a causal norm type).
        norm_type: normalization used inside the TCN blocks.
        fusion_type: how the speaker embedding is fused in TCNBlock_Spk.
        is_innorm: apply InstanceNorm1d to encoder outputs before the TCN.
    """

    def __init__(self,
                 L1=20,
                 L2=80,
                 L3=160,
                 N=256,
                 B=8,
                 O=256,
                 P=512,
                 Q=3,
                 num_spks=101,
                 spk_embed_dim=256,
                 causal=False,
                 norm_type='gLN',
                 fusion_type='cat',
                 is_innorm=False,
                 ):
        super(SpEx_Plus_Double, self).__init__()
        # n x S => n x N x T, S = 4s*8000 = 32000
        self.L1 = L1
        self.L2 = L2
        self.L3 = L3
        # All three encoders (and decoders) share the short-window stride
        # L1 // 2 so the three scales produce the same number of frames T.
        self.encoder_1d_short = Conv1D(1, N, L1, stride=L1 // 2, padding=0)
        self.encoder_1d_middle = Conv1D(1, N, L2, stride=L1 // 2, padding=0)
        self.encoder_1d_long = Conv1D(1, N, L3, stride=L1 // 2, padding=0)
        # before repeat blocks, always cLN
        self.instancenorm = nn.InstanceNorm1d(N)
        self.decoder_1d_short = ConvTrans1D(N, 1, kernel_size=L1, stride=L1 // 2, bias=True)
        self.decoder_1d_middle = ConvTrans1D(N, 1, kernel_size=L2, stride=L1 // 2, bias=True)
        self.decoder_1d_long = ConvTrans1D(N, 1, kernel_size=L3, stride=L1 // 2, bias=True)
        self.num_spks = num_spks
        # Speaker-classification head (used by training code elsewhere;
        # not called in this forward pass).
        self.pred_linear = nn.Linear(spk_embed_dim, num_spks)
        self.is_innorm = is_innorm
        if causal and norm_type not in ["cgLN", "cLN"]:
            # BUG FIX: format the warning with the *requested* norm type.
            # The original reassigned norm_type before building the message,
            # so it always printed "Changing cLN to cLN".
            requested = norm_type
            norm_type = "cLN"
            warnings.warn(
                "In causal configuration cumulative layer normalization (cgLN)"
                "or channel-wise layer normalization (chanLN) "
                f"must be used. Changing {requested} to cLN"
            )
        self.speaker_encoder = Speaker_Model(
            L1=L1,
            L2=L2,
            L3=L3,
            N=N,
            O=O,
            P=P,
            spk_embed_dim=spk_embed_dim,
        )
        self.extractor = Extractor(
            L1=L1,
            L2=L2,
            L3=L3,
            N=N,
            B=B,
            O=O,
            P=P,
            Q=Q,
            num_spks=num_spks,
            spk_embed_dim=spk_embed_dim,
            causal=causal,
            fusion_type=fusion_type,
            norm_type=norm_type,
        )
        # 1x1 convs fusing [re-encoded estimate ; mixture features] per scale.
        self.frameconv1 = Conv1D(2 * N, N, 1)
        self.frameconv2 = Conv1D(2 * N, N, 1)
        self.frameconv3 = Conv1D(2 * N, N, 1)
        # Learnable per-scale fusion weights for the decoded waveforms.
        self.fusion1 = nn.Parameter(th.tensor(0.8))
        self.fusion2 = nn.Parameter(th.tensor(0.1))
        self.fusion3 = nn.Parameter(th.tensor(0.1))

    def align_to_w(self, frame, w):
        """Crop or zero-pad ``frame`` along time to match ``w``'s length.

        ``w`` is returned unchanged; only ``frame`` is adjusted.
        """
        diff = frame.shape[-1] - w.shape[-1]
        if diff > 0:
            frame = frame[..., :w.shape[-1]]  # crop
        elif diff < 0:
            frame = th.nn.functional.pad(frame, (0, -diff))  # zero-pad
        return frame, w  # w stays untouched

    def ira(self, est1, aux, aux_len, xlen1, xlen2, xlen3, w1, w2, w3):
        """One iterative refinement pass.

        Re-embeds the speaker from the concatenation of the current estimate
        and the raw enrollment waveform ``aux``, fuses the re-encoded
        estimate with the mixture features ``w1``/``w2``/``w3``, and decodes
        a refined estimate.

        Args:
            est1: current waveform estimate, n x S.
            aux: enrollment waveform, n x S_aux (NOT the speaker embedding).
            aux_len: valid enrollment lengths (tensor), per utterance.
            xlen1/xlen2/xlen3: padded input lengths per scale from forward().
            w1/w2/w3: mixture encoder outputs per scale.
        Returns:
            Refined waveform estimate.
        """
        # Speaker re-embedding from [estimate ; enrollment] audio.
        concat_aux = th.cat((est1, aux), dim=1)
        concat_aux_len = aux_len + xlen1
        concat_aux_w1 = F.relu(self.encoder_1d_short(concat_aux))
        concat_aux_T_shape = concat_aux_w1.shape[-1]
        concat_aux_len1 = concat_aux.shape[-1]
        concat_aux_len2 = (concat_aux_T_shape - 1) * (self.L1 // 2) + self.L2
        concat_aux_len3 = (concat_aux_T_shape - 1) * (self.L1 // 2) + self.L3
        concat_aux_w2 = F.relu(self.encoder_1d_middle(F.pad(concat_aux, (0, concat_aux_len2 - concat_aux_len1), "constant", 0)))
        concat_aux_w3 = F.relu(self.encoder_1d_long(F.pad(concat_aux, (0, concat_aux_len3 - concat_aux_len1), "constant", 0)))
        concat_aux = self.speaker_encoder(th.cat([concat_aux_w1, concat_aux_w2, concat_aux_w3], 1), concat_aux_len)
        # Re-encode the current estimate at the three scales.
        # NOTE(review): on the second refinement pass est1's length may
        # differ slightly from the original xlen1 — align_to_w below absorbs
        # the resulting off-by-one frame counts.
        frame1 = F.relu(self.encoder_1d_short(est1))
        frame2 = F.relu(self.encoder_1d_middle(F.pad(est1, (0, xlen2 - xlen1), "constant", 0)))
        frame3 = F.relu(self.encoder_1d_long(F.pad(est1, (0, xlen3 - xlen1), "constant", 0)))
        if self.is_innorm:
            frame1 = self.instancenorm(frame1)
            frame2 = self.instancenorm(frame2)
            frame3 = self.instancenorm(frame3)
        # Frame counts can mismatch by one (e.g. 4098 vs 4099): align each
        # re-encoded estimate to the matching mixture representation.
        frame1, w1 = self.align_to_w(frame1, w1)
        frame2, w2 = self.align_to_w(frame2, w2)
        frame3, w3 = self.align_to_w(frame3, w3)
        concat1 = self.frameconv1(th.cat([frame1, w1], 1))
        concat2 = self.frameconv2(th.cat([frame2, w2], 1))
        concat3 = self.frameconv3(th.cat([frame3, w3], 1))
        mask1, mask2, mask3 = self.extractor(concat1, concat2, concat3, concat_aux)
        F1 = concat1 * mask1
        F2 = concat2 * mask2
        F3 = concat3 * mask3
        f1 = self.decoder_1d_short(F1)
        # Rebind xlen1 to the decoded short-scale length and trim the other
        # scales to it before fusing.
        xlen1 = f1.shape[-1]
        f2 = self.decoder_1d_middle(F2)[:, :xlen1]
        f3 = self.decoder_1d_long(F3)[:, :xlen1]
        est2 = self.fusion1 * f1 + self.fusion2 * f2 + self.fusion3 * f3
        return est2

    def forward(self, x, aux, aux_len):
        """Extract the target speaker from mixture ``x``.

        Args:
            x: mixture waveform, 1-D (single utterance) or 2-D (n x S).
            aux: enrollment waveform of the target speaker, n x S_aux.
            aux_len: valid enrollment lengths (tensor), per utterance.
        Returns:
            The twice-refined estimate ``est3``.
        Raises:
            RuntimeError: if ``x`` has 3 or more dimensions.
        """
        if x.dim() >= 3:
            # BUG FIX: nn.Module instances have no __name__ attribute, so the
            # original self.__name__ raised AttributeError on this error path.
            raise RuntimeError(
                "{} accept 1/2D tensor as input, but got {:d}".format(
                    type(self).__name__, x.dim()))
        # when inference, only one utt
        if x.dim() == 1:
            x = th.unsqueeze(x, 0)
        # n x 1 x S => n x N x T
        w1 = F.relu(self.encoder_1d_short(x))
        T = w1.shape[-1]
        xlen1 = x.shape[-1]
        # Pad the input so the middle/long encoders also produce T frames.
        xlen2 = (T - 1) * (self.L1 // 2) + self.L2
        xlen3 = (T - 1) * (self.L1 // 2) + self.L3
        w2 = F.relu(self.encoder_1d_middle(F.pad(x, (0, xlen2 - xlen1), "constant", 0)))
        w3 = F.relu(self.encoder_1d_long(F.pad(x, (0, xlen3 - xlen1), "constant", 0)))
        # n x 3N x T
        # speaker encoder (share params from speech encoder)
        if self.is_innorm:
            w1 = self.instancenorm(w1)
            w2 = self.instancenorm(w2)
            w3 = self.instancenorm(w3)
        aux_w1 = F.relu(self.encoder_1d_short(aux))
        aux_T_shape = aux_w1.shape[-1]
        aux_len1 = aux.shape[-1]
        aux_len2 = (aux_T_shape - 1) * (self.L1 // 2) + self.L2
        aux_len3 = (aux_T_shape - 1) * (self.L1 // 2) + self.L3
        aux_w2 = F.relu(self.encoder_1d_middle(F.pad(aux, (0, aux_len2 - aux_len1), "constant", 0)))
        aux_w3 = F.relu(self.encoder_1d_long(F.pad(aux, (0, aux_len3 - aux_len1), "constant", 0)))
        # BUG FIX: keep the raw enrollment waveform in `aux`. The original
        # code overwrote `aux` with the speaker embedding and then passed the
        # embedding into ira(), which concatenates it to est1 as if it were
        # audio and computes concat_aux_len from the *waveform* length.
        aux_emb = self.speaker_encoder(th.cat([aux_w1, aux_w2, aux_w3], 1), aux_len)
        m1, m2, m3 = self.extractor(w1, w2, w3, aux_emb)
        S1 = w1 * m1
        S2 = w2 * m2
        S3 = w3 * m3
        # Decode each scale, pad/trim to the input length, then fuse.
        # (Decode once and reuse — the original called the decoder twice.)
        s1_raw = self.decoder_1d_short(S1)
        s1 = F.pad(s1_raw, (0, max(0, xlen1 - s1_raw.shape[1])))[:, :xlen1]
        s2 = self.decoder_1d_middle(S2)[:, :xlen1]
        s3 = self.decoder_1d_long(S3)[:, :xlen1]
        est1 = self.fusion1 * s1 + self.fusion2 * s2 + self.fusion3 * s3
        # Two iterative refinement passes on the raw enrollment audio.
        est2 = self.ira(est1, aux, aux_len, xlen1, xlen2, xlen3, w1, w2, w3)
        est3 = self.ira(est2, aux, aux_len, xlen1, xlen2, xlen3, w1, w2, w3)
        return est3
class Extractor(nn.Module):
    """Speaker-conditioned TCN mask estimator.

    Four stacked TCN repeats, each headed by a TCNBlock_Spk (dilation 1,
    fused with the speaker embedding) followed by B-1 plain TCNBlocks with
    exponentially growing dilations, then three per-scale 1x1 mask heads.
    """

    def __init__(self,
                 L1=20,
                 L2=80,
                 L3=160,
                 N=256,
                 B=8,
                 O=256,
                 P=512,
                 Q=3,
                 num_spks=101,
                 spk_embed_dim=256,
                 causal=False,
                 fusion_type='cat',
                 norm_type='gLN',
                 ):
        super(Extractor, self).__init__()
        # n x N x T => n x O x T
        self.ln = ChannelwiseLayerNorm(3*N)
        self.proj = Conv1D(3*N, O, 1)
        # Shared constructor arguments for the four repeats.
        spk_kwargs = dict(spk_embed_dim=spk_embed_dim, in_channels=O,
                          conv_channels=P, kernel_size=Q, causal=causal,
                          dilation=1, fusion_type=fusion_type,
                          norm_type=norm_type)
        tail_kwargs = dict(num_blocks=B, in_channels=O, conv_channels=P,
                           kernel_size=Q, causal=causal, norm_type=norm_type)
        self.conv_block_1 = TCNBlock_Spk(**spk_kwargs)
        self.conv_block_1_other = self._build_stacks(**tail_kwargs)
        self.conv_block_2 = TCNBlock_Spk(**spk_kwargs)
        self.conv_block_2_other = self._build_stacks(**tail_kwargs)
        self.conv_block_3 = TCNBlock_Spk(**spk_kwargs)
        self.conv_block_3_other = self._build_stacks(**tail_kwargs)
        self.conv_block_4 = TCNBlock_Spk(**spk_kwargs)
        self.conv_block_4_other = self._build_stacks(**tail_kwargs)
        # n x O x T => n x N x T
        self.mask1 = Conv1D(O, N, 1)
        self.mask2 = Conv1D(O, N, 1)
        self.mask3 = Conv1D(O, N, 1)

    def _build_stacks(self, num_blocks, **block_kwargs):
        """
        Build blocks 1..num_blocks-1 of a TCN repeat (dilations 2, 4, ...);
        the dilation-1 block of each repeat is the speaker-conditioned
        TCNBlock_Spk created separately in __init__.
        """
        return nn.Sequential(*(
            TCNBlock(**block_kwargs, dilation=2 ** b)
            for b in range(1, num_blocks)
        ))

    def forward(self, w1, w2, w3, aux):
        # n x 3N x T -> n x O x T
        y = self.proj(self.ln(th.cat([w1, w2, w3], 1)))
        # Each repeat: speaker-conditioned head, then the plain TCN tail.
        stages = (
            (self.conv_block_1, self.conv_block_1_other),
            (self.conv_block_2, self.conv_block_2_other),
            (self.conv_block_3, self.conv_block_3_other),
            (self.conv_block_4, self.conv_block_4_other),
        )
        for spk_block, tail in stages:
            y = tail(spk_block(y, aux))
        # n x O x T -> three n x N x T non-negative masks
        m1 = F.relu(self.mask1(y))
        m2 = F.relu(self.mask2(y))
        m3 = F.relu(self.mask3(y))
        return m1, m2, m3
class Speaker_Model(nn.Module):
    """Speaker encoder producing a fixed-size embedding.

    Projects the concatenated multi-scale encoder features (n x 3N x T)
    through a ResBlock stack to spk_embed_dim channels, then averages over
    time using the valid frame count derived from ``aux_len``.
    """

    def __init__(self,
                 L1=20,
                 L2=80,
                 L3=160,
                 N=256,
                 O=256,
                 P=512,
                 spk_embed_dim=256,
                 ):
        super(Speaker_Model, self).__init__()
        self.L1 = L1
        self.L2 = L2
        self.L3 = L3
        self.spk_encoder = nn.Sequential(
            ChannelwiseLayerNorm(3*N),
            Conv1D(3*N, O, 1),
            ResBlock(O, O),
            ResBlock(O, P),
            ResBlock(P, P),
            Conv1D(P, spk_embed_dim, 1),
        )

    def forward(self, aux, aux_len):
        """Encode ``aux`` and mean-pool over the valid time frames.

        ``aux_len`` is expected to be a tensor of per-utterance waveform
        lengths (it is reshaped with .view below).
        """
        feats = self.spk_encoder(aux)
        # Frame count of the short encoder for a waveform of length aux_len.
        valid_T = (aux_len - self.L1) // (self.L1 // 2) + 1
        # Each of the three ResBlocks reduces time by a factor of 3
        # (integer division) — presumably pooling inside ResBlock; TODO
        # confirm against the ResBlock implementation in .cnns.
        valid_T = ((valid_T // 3) // 3) // 3
        # Sum over time divided by the valid frame count == masked mean.
        return th.sum(feats, -1) / valid_T.view(-1, 1).float()