#!/usr/bin/env python
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from .norm import ChannelwiseLayerNorm, GlobalLayerNorm
from .cnns import Conv1D, ConvTrans1D, TCNBlock, TCNBlock_Spk, ResBlock
import warnings


class SpEx_Plus_Double(nn.Module):
    """Multi-scale (SpEx+-style) target-speaker extractor with refinement.

    A first extraction pass produces ``est1``. ``ira`` is then applied twice.
    Each pass re-encodes the current estimate, recomputes the speaker
    embedding from ``[estimate, reference]`` concatenated in time, and
    extracts again.

    Args:
        L1, L2, L3: kernel sizes of the short/middle/long encoders and
            decoders. All of them use stride ``L1 // 2``.
        N: encoder output channels.
        B: blocks per TCN stack. The first block is the speaker-conditioned
            ``TCNBlock_Spk``; ``B - 1`` plain ``TCNBlock``s follow.
        O, P, Q: TCN bottleneck channels, hidden channels and kernel size.
        num_spks: output size of ``pred_linear``. It is not used in this
            module's forward; presumably it is used by a training loss
            elsewhere (TODO confirm).
        spk_embed_dim: speaker-embedding size.
        causal: build causal TCN blocks; forces ``norm_type`` to "cLN"
            unless it is already "cgLN" or "cLN".
        norm_type: normalization type passed to the extractor.
        fusion_type: passed through to ``TCNBlock_Spk``.
        is_innorm: apply ``InstanceNorm1d`` to the mixture/estimate encoder
            outputs (never to the reference).
    """

    def __init__(self,
                 L1=20,
                 L2=80,
                 L3=160,
                 N=256,
                 B=8,
                 O=256,
                 P=512,
                 Q=3,
                 num_spks=101,
                 spk_embed_dim=256,
                 causal=False,
                 norm_type='gLN',
                 fusion_type='cat',
                 is_innorm=False,
                 ):
        super(SpEx_Plus_Double, self).__init__()
        # n x S => n x N x T, S = 4s*8000 = 32000
        self.L1 = L1
        self.L2 = L2
        self.L3 = L3
        # All encoders share stride L1 // 2, so they produce the same number
        # of frames T once their input is right-padded (see _encode).
        self.encoder_1d_short = Conv1D(1, N, L1, stride=L1 // 2, padding=0)
        self.encoder_1d_middle = Conv1D(1, N, L2, stride=L1 // 2, padding=0)
        self.encoder_1d_long = Conv1D(1, N, L3, stride=L1 // 2, padding=0)
        # Optional normalization of encoder outputs, used when is_innorm=True.
        self.instancenorm = nn.InstanceNorm1d(N)
        self.decoder_1d_short = ConvTrans1D(N, 1, kernel_size=L1, stride=L1 // 2, bias=True)
        self.decoder_1d_middle = ConvTrans1D(N, 1, kernel_size=L2, stride=L1 // 2, bias=True)
        self.decoder_1d_long = ConvTrans1D(N, 1, kernel_size=L3, stride=L1 // 2, bias=True)
        self.num_spks = num_spks
        self.pred_linear = nn.Linear(spk_embed_dim, num_spks)
        self.is_innorm = is_innorm

        if causal and norm_type not in ["cgLN", "cLN"]:
            # Warn *before* overwriting, so the message names the requested type.
            warnings.warn(
                "In causal configuration cumulative layer normalization (cgLN) "
                "or channel-wise layer normalization (cLN) "
                f"must be used. Changing {norm_type} to cLN"
            )
            norm_type = "cLN"

        self.speaker_encoder = Speaker_Model(
            L1=L1,
            L2=L2,
            L3=L3,
            N=N,
            O=O,
            P=P,
            spk_embed_dim=spk_embed_dim,
        )
        self.extractor = Extractor(
            L1=L1,
            L2=L2,
            L3=L3,
            N=N,
            B=B,
            O=O,
            P=P,
            Q=Q,
            num_spks=num_spks,
            spk_embed_dim=spk_embed_dim,
            causal=causal,
            fusion_type=fusion_type,
            norm_type=norm_type,
        )
        # 1x1 convs fusing [estimate features, mixture features] per scale.
        self.frameconv1 = Conv1D(2 * N, N, 1)
        self.frameconv2 = Conv1D(2 * N, N, 1)
        self.frameconv3 = Conv1D(2 * N, N, 1)
        # Learnable weights for summing the three decoded scales.
        self.fusion1 = nn.Parameter(th.tensor(0.8))
        self.fusion2 = nn.Parameter(th.tensor(0.1))
        self.fusion3 = nn.Parameter(th.tensor(0.1))

    def align_to_w(self, frame, w):
        """Crop or zero-pad ``frame`` on its last axis to match ``w``.

        Returns:
            ``(frame, w)``; ``w`` is returned unchanged.
        """
        diff = frame.shape[-1] - w.shape[-1]
        if diff > 0:
            frame = frame[..., :w.shape[-1]]  # crop
        elif diff < 0:
            frame = F.pad(frame, (0, -diff))  # zero-pad
        return frame, w

    def _padded_lengths(self, num_frames):
        """Return input lengths for the middle/long encoders that yield
        exactly ``num_frames`` output frames with hop ``L1 // 2``."""
        hop = self.L1 // 2
        return ((num_frames - 1) * hop + self.L2,
                (num_frames - 1) * hop + self.L3)

    def _encode(self, wav, len2=None, len3=None):
        """Encode waveform ``wav`` (n x S) at three scales (each n x N x T).

        ``wav`` is right-padded (or cropped) to ``len2``/``len3`` before the
        middle/long encoders. When these are omitted, they are derived from
        the short encoder's frame count, so all three outputs share T.
        """
        w1 = F.relu(self.encoder_1d_short(wav))
        if len2 is None or len3 is None:
            len2, len3 = self._padded_lengths(w1.shape[-1])
        cur_len = wav.shape[-1]
        w2 = F.relu(self.encoder_1d_middle(F.pad(wav, (0, len2 - cur_len), "constant", 0)))
        w3 = F.relu(self.encoder_1d_long(F.pad(wav, (0, len3 - cur_len), "constant", 0)))
        return w1, w2, w3

    def _decode(self, S1, S2, S3, length):
        """Decode three masked feature maps into waveforms of exactly
        ``length`` samples, then return their ``fusion*``-weighted sum."""
        s1 = self.decoder_1d_short(S1)
        # The short decoder may come out shorter than the input; pad it back.
        s1 = F.pad(s1, (0, max(0, length - s1.shape[-1])))[..., :length]
        s2 = self.decoder_1d_middle(S2)[..., :length]
        s3 = self.decoder_1d_long(S3)[..., :length]
        return self.fusion1 * s1 + self.fusion2 * s2 + self.fusion3 * s3

    def ira(self, est1, aux, aux_len, xlen1, xlen2, xlen3, w1, w2, w3):
        """Run one refinement pass that re-extracts the target from ``est1``.

        Args:
            est1: current waveform estimate, n x xlen1.
            aux: raw reference *waveform* (not the speaker embedding),
                n x A samples.
            aux_len: reference length(s) in samples, as passed to ``forward``.
            xlen1: mixture length in samples; the output has this length.
            xlen2, xlen3: padded input lengths for the middle/long encoders.
            w1, w2, w3: mixture encoder outputs computed in ``forward``.

        Returns:
            The refined waveform estimate, n x xlen1.
        """
        # Speaker embedding from [estimate, reference] concatenated in time.
        # The valid length of the concatenation is xlen1 + aux_len samples.
        concat_aux = th.cat((est1, aux), dim=1)
        concat_aux_len = aux_len + xlen1
        spk_emb = self.speaker_encoder(th.cat(self._encode(concat_aux), 1), concat_aux_len)

        frame1, frame2, frame3 = self._encode(est1, xlen2, xlen3)
        if self.is_innorm:
            frame1 = self.instancenorm(frame1)
            frame2 = self.instancenorm(frame2)
            frame3 = self.instancenorm(frame3)

        # Defensive: force the estimate frames to the mixture's frame count.
        frame1, w1 = self.align_to_w(frame1, w1)
        frame2, w2 = self.align_to_w(frame2, w2)
        frame3, w3 = self.align_to_w(frame3, w3)

        concat1 = self.frameconv1(th.cat([frame1, w1], 1))
        concat2 = self.frameconv2(th.cat([frame2, w2], 1))
        concat3 = self.frameconv3(th.cat([frame3, w3], 1))
        mask1, mask2, mask3 = self.extractor(concat1, concat2, concat3, spk_emb)

        # Keep the output at the mixture length so subsequent passes line up.
        return self._decode(concat1 * mask1, concat2 * mask2, concat3 * mask3, xlen1)

    def forward(self, x, aux, aux_len):
        """Extract the reference speaker from the mixture ``x``.

        Args:
            x: mixture waveform, n x S (or S for a single utterance).
            aux: reference waveform, n x A (or A for a single utterance).
            aux_len: reference length(s) in samples.

        Returns:
            The estimate after the initial pass and two ``ira`` refinements,
            n x S.

        Raises:
            RuntimeError: if ``x`` has 3 or more dimensions.
        """
        if x.dim() >= 3:
            raise RuntimeError(
                "{} accept 1/2D tensor as input, but got {:d}".format(
                    type(self).__name__, x.dim()))
        # when inference, only one utt
        if x.dim() == 1:
            x = th.unsqueeze(x, 0)
        if aux.dim() == 1:
            aux = th.unsqueeze(aux, 0)

        xlen1 = x.shape[-1]
        # n x 1 x S => n x N x T at three scales
        w1, w2, w3 = self._encode(x)
        xlen2, xlen3 = self._padded_lengths(w1.shape[-1])
        if self.is_innorm:
            w1 = self.instancenorm(w1)
            w2 = self.instancenorm(w2)
            w3 = self.instancenorm(w3)

        # The speaker encoder shares the speech encoders' parameters.
        # Keep `aux` as the raw waveform: ira() needs it, not the embedding.
        spk_emb = self.speaker_encoder(th.cat(self._encode(aux), 1), aux_len)

        m1, m2, m3 = self.extractor(w1, w2, w3, spk_emb)
        est1 = self._decode(w1 * m1, w2 * m2, w3 * m3, xlen1)

        est2 = self.ira(est1, aux, aux_len, xlen1, xlen2, xlen3, w1, w2, w3)
        est3 = self.ira(est2, aux, aux_len, xlen1, xlen2, xlen3, w1, w2, w3)
        return est3


class Extractor(nn.Module):
    """Predict three N-channel masks from multi-scale features.

    The concatenated short/middle/long features (3N channels) are normalized
    and projected to O channels. They then pass through four stacks, each a
    speaker-conditioned ``TCNBlock_Spk`` followed by ``B - 1`` plain
    ``TCNBlock``s. Finally they are mapped to three ReLU masks.
    """

    def __init__(self,
                 L1=20,
                 L2=80,
                 L3=160,
                 N=256,
                 B=8,
                 O=256,
                 P=512,
                 Q=3,
                 num_spks=101,
                 spk_embed_dim=256,
                 causal=False,
                 fusion_type='cat',
                 norm_type='gLN',
                 ):
        super(Extractor, self).__init__()
        # n x 3N x T => n x O x T
        self.ln = ChannelwiseLayerNorm(3 * N)
        self.proj = Conv1D(3 * N, O, 1)
        self.conv_block_1 = TCNBlock_Spk(spk_embed_dim=spk_embed_dim, in_channels=O, conv_channels=P, kernel_size=Q,
                                         causal=causal, dilation=1, fusion_type=fusion_type, norm_type=norm_type)
        self.conv_block_1_other = self._build_stacks(num_blocks=B, in_channels=O, conv_channels=P, kernel_size=Q,
                                                     causal=causal, norm_type=norm_type)
        self.conv_block_2 = TCNBlock_Spk(spk_embed_dim=spk_embed_dim, in_channels=O, conv_channels=P, kernel_size=Q,
                                         causal=causal, dilation=1, fusion_type=fusion_type, norm_type=norm_type)
        self.conv_block_2_other = self._build_stacks(num_blocks=B, in_channels=O, conv_channels=P, kernel_size=Q,
                                                     causal=causal, norm_type=norm_type)
        self.conv_block_3 = TCNBlock_Spk(spk_embed_dim=spk_embed_dim, in_channels=O, conv_channels=P, kernel_size=Q,
                                         causal=causal, dilation=1, fusion_type=fusion_type, norm_type=norm_type)
        self.conv_block_3_other = self._build_stacks(num_blocks=B, in_channels=O, conv_channels=P, kernel_size=Q,
                                                     causal=causal, norm_type=norm_type)
        self.conv_block_4 = TCNBlock_Spk(spk_embed_dim=spk_embed_dim, in_channels=O, conv_channels=P, kernel_size=Q,
                                         causal=causal, dilation=1, fusion_type=fusion_type, norm_type=norm_type)
        self.conv_block_4_other = self._build_stacks(num_blocks=B, in_channels=O, conv_channels=P, kernel_size=Q,
                                                     causal=causal, norm_type=norm_type)
        # n x O x T => n x N x T
        self.mask1 = Conv1D(O, N, 1)
        self.mask2 = Conv1D(O, N, 1)
        self.mask3 = Conv1D(O, N, 1)

    def _build_stacks(self, num_blocks, **block_kwargs):
        """Build ``num_blocks - 1`` plain TCN blocks with dilations
        2, 4, ..., 2**(num_blocks - 1).

        The dilation-1, speaker-conditioned first block of each stack is
        built separately as a ``TCNBlock_Spk``.
        """
        blocks = [
            TCNBlock(**block_kwargs, dilation=(2 ** b))
            for b in range(1, num_blocks)
        ]
        return nn.Sequential(*blocks)

    def forward(self, w1, w2, w3, aux):
        """Return masks ``(m1, m2, m3)`` for the three scales.

        ``aux`` is the speaker embedding fed to every ``TCNBlock_Spk``.
        """
        y = self.ln(th.cat([w1, w2, w3], 1))
        # n x O x T
        y = self.proj(y)
        y = self.conv_block_1(y, aux)
        y = self.conv_block_1_other(y)
        y = self.conv_block_2(y, aux)
        y = self.conv_block_2_other(y)
        y = self.conv_block_3(y, aux)
        y = self.conv_block_3_other(y)
        y = self.conv_block_4(y, aux)
        y = self.conv_block_4_other(y)

        # n x N x T
        m1 = F.relu(self.mask1(y))
        m2 = F.relu(self.mask2(y))
        m3 = F.relu(self.mask3(y))
        return m1, m2, m3


class Speaker_Model(nn.Module):
    """Map multi-scale encoder features (3N channels) to a speaker embedding.

    The network output is summed over time and divided by the number of
    valid frames implied by ``aux_len``.
    """

    def __init__(self,
                 L1=20,
                 L2=80,
                 L3=160,
                 N=256,
                 O=256,
                 P=512,
                 spk_embed_dim=256,
                 ):
        super(Speaker_Model, self).__init__()
        self.L1 = L1
        self.L2 = L2
        self.L3 = L3
        self.spk_encoder = nn.Sequential(
            ChannelwiseLayerNorm(3 * N),
            Conv1D(3 * N, O, 1),
            ResBlock(O, O),
            ResBlock(O, P),
            ResBlock(P, P),
            Conv1D(P, spk_embed_dim, 1),
        )

    def forward(self, aux, aux_len):
        """Return the time-averaged embedding for features ``aux``.

        Args:
            aux: concatenated encoder features for the reference.
            aux_len: reference length(s) in samples (int or tensor).
        """
        aux = self.spk_encoder(aux)
        # Frames produced by the short encoder (kernel L1, hop L1 // 2).
        num_frames = th.as_tensor(aux_len, device=aux.device)
        num_frames = (num_frames - self.L1) // (self.L1 // 2) + 1
        # Three // 3 steps: presumably one 3x temporal downsampling per
        # ResBlock (TODO confirm against .cnns.ResBlock).
        num_frames = ((num_frames // 3) // 3) // 3
        # Very short references would otherwise yield 0 frames -> inf/nan.
        num_frames = num_frames.clamp(min=1)
        return th.sum(aux, -1) / num_frames.view(-1, 1).float()