#!/usr/bin/env python

import warnings

import torch as th
import torch.nn as nn
import torch.nn.functional as F

from .norm import ChannelwiseLayerNorm, GlobalLayerNorm
from .cnns import Conv1D, ConvTrans1D, TCNBlock, TCNBlock_Spk, ResBlock


class SpEx_Plus(nn.Module):
    def __init__(self,
                 L1=20,
                 L2=80,
                 L3=160,
                 N=256,
                 B=8,
                 O=256,
                 P=512,
                 Q=3,
                 num_spks=101,
                 spk_embed_dim=256,
                 causal=False,
                 norm_type='gLN',
                 fusion_type='cat',
                 is_innorm=False):
        super(SpEx_Plus, self).__init__()
        # n x S => n x N x T, S = 4s * 8000Hz = 32000
        self.L1 = L1
        self.L2 = L2
        self.L3 = L3
        # multi-scale encoders; all three share the same stride (L1 // 2)
        # so their outputs are frame-aligned along T
        self.encoder_1d_short = Conv1D(1, N, L1, stride=L1 // 2, padding=0)
        self.encoder_1d_middle = Conv1D(1, N, L2, stride=L1 // 2, padding=0)
        self.encoder_1d_long = Conv1D(1, N, L3, stride=L1 // 2, padding=0)
        # optional normalization of the encoder outputs before the repeated
        # blocks (applied only when is_innorm is True)
        self.instancenorm = nn.InstanceNorm1d(N)
        self.decoder_1d_short = ConvTrans1D(N, 1, kernel_size=L1,
                                            stride=L1 // 2, bias=True)
        self.decoder_1d_middle = ConvTrans1D(N, 1, kernel_size=L2,
                                             stride=L1 // 2, bias=True)
        self.decoder_1d_long = ConvTrans1D(N, 1, kernel_size=L3,
                                           stride=L1 // 2, bias=True)
        self.num_spks = num_spks
        # speaker classification head used by the auxiliary CE loss
        self.pred_linear = nn.Linear(spk_embed_dim, num_spks)
        self.is_innorm = is_innorm
        if causal and norm_type not in ["cgLN", "cLN"]:
            # warn before overwriting norm_type so the message reports
            # the value that was actually replaced
            warnings.warn(
                "In causal configuration cumulative layer normalization (cgLN) "
                "or channel-wise layer normalization (cLN) must be used. "
                f"Changing {norm_type} to cLN")
            norm_type = "cLN"
        self.speaker_encoder = Speaker_Model(L1=L1,
                                             L2=L2,
                                             L3=L3,
                                             N=N,
                                             O=O,
                                             P=P,
                                             spk_embed_dim=spk_embed_dim)
        self.extractor = Extractor(L1=L1,
                                   L2=L2,
                                   L3=L3,
                                   N=N,
                                   B=B,
                                   O=O,
                                   P=P,
                                   Q=Q,
                                   num_spks=num_spks,
                                   spk_embed_dim=spk_embed_dim,
                                   causal=causal,
                                   fusion_type=fusion_type,
                                   norm_type=norm_type)

    def forward(self, x, aux, aux_len):
        """
        x: mixture waveform(s), 1D (single utterance at inference) or 2D n x S
        aux: enrollment waveform(s) of the target speaker
        aux_len: true lengths of aux (required at inference as well)
        """
        if x.dim() >= 3:
            raise RuntimeError(
                "{} accepts 1/2D tensor as input, but got {:d}".format(
                    self.__class__.__name__, x.dim()))
        # at inference time only one utterance is passed in
        if x.dim() == 1:
            x = th.unsqueeze(x, 0)

        # n x 1 x S => n x N x T
        w1 = F.relu(self.encoder_1d_short(x))
        T = w1.shape[-1]
        xlen1 = x.shape[-1]
        # zero-pad the input so the middle/long encoders also yield T frames
        xlen2 = (T - 1) * (self.L1 // 2) + self.L2
        xlen3 = (T - 1) * (self.L1 // 2) + self.L3
        w2 = F.relu(self.encoder_1d_middle(
            F.pad(x, (0, xlen2 - xlen1), "constant", 0)))
        w3 = F.relu(self.encoder_1d_long(
            F.pad(x, (0, xlen3 - xlen1), "constant", 0)))

        # speaker encoder (shares parameters with the speech encoder)
        aux_w1 = F.relu(self.encoder_1d_short(aux))
        aux_T_shape = aux_w1.shape[-1]
        aux_len1 = aux.shape[-1]
        aux_len2 = (aux_T_shape - 1) * (self.L1 // 2) + self.L2
        aux_len3 = (aux_T_shape - 1) * (self.L1 // 2) + self.L3
        aux_w2 = F.relu(self.encoder_1d_middle(
            F.pad(aux, (0, aux_len2 - aux_len1), "constant", 0)))
        aux_w3 = F.relu(self.encoder_1d_long(
            F.pad(aux, (0, aux_len3 - aux_len1), "constant", 0)))
        # n x 3N x T' => n x spk_embed_dim
        aux = self.speaker_encoder(th.cat([aux_w1, aux_w2, aux_w3], 1),
                                   aux_len)

        if self.is_innorm:
            w1 = self.instancenorm(w1)
            w2 = self.instancenorm(w2)
            w3 = self.instancenorm(w3)

        m1, m2, m3 = self.extractor(w1, w2, w3, aux)
        S1 = w1 * m1
        S2 = w2 * m2
        S3 = w3 * m3
        out1 = self.decoder_1d_short(S1)
        # out2 = self.decoder_1d_middle(S2)[:, :xlen1]
        # out3 = self.decoder_1d_long(S3)[:, :xlen1]
        return out1


class Extractor(nn.Module):
    def __init__(self,
                 L1=20,
                 L2=80,
                 L3=160,
                 N=256,
                 B=8,
                 O=256,
                 P=512,
                 Q=3,
                 num_spks=101,
                 spk_embed_dim=256,
                 causal=False,
                 fusion_type='cat',
                 norm_type='gLN'):
        super(Extractor, self).__init__()
        # n x 3N x T => n x O x T
        self.ln = ChannelwiseLayerNorm(3 * N)
        self.proj = Conv1D(3 * N, O, 1)
        # four repeats of B TCN blocks; the first block of each repeat is a
        # TCNBlock_Spk that fuses the speaker embedding, the remaining
        # B - 1 blocks are plain TCNBlocks with growing dilation
        self.conv_block_1 = TCNBlock_Spk(spk_embed_dim=spk_embed_dim,
                                         in_channels=O,
                                         conv_channels=P,
                                         kernel_size=Q,
                                         causal=causal,
                                         dilation=1,
                                         fusion_type=fusion_type,
                                         norm_type=norm_type)
        self.conv_block_1_other = self._build_stacks(num_blocks=B,
                                                     in_channels=O,
                                                     conv_channels=P,
                                                     kernel_size=Q,
                                                     causal=causal,
                                                     norm_type=norm_type)
        self.conv_block_2 = TCNBlock_Spk(spk_embed_dim=spk_embed_dim,
                                         in_channels=O,
                                         conv_channels=P,
                                         kernel_size=Q,
                                         causal=causal,
                                         dilation=1,
                                         fusion_type=fusion_type,
                                         norm_type=norm_type)
        self.conv_block_2_other = self._build_stacks(num_blocks=B,
                                                     in_channels=O,
                                                     conv_channels=P,
                                                     kernel_size=Q,
                                                     causal=causal,
                                                     norm_type=norm_type)
        self.conv_block_3 = TCNBlock_Spk(spk_embed_dim=spk_embed_dim,
                                         in_channels=O,
                                         conv_channels=P,
                                         kernel_size=Q,
                                         causal=causal,
                                         dilation=1,
                                         fusion_type=fusion_type,
                                         norm_type=norm_type)
        self.conv_block_3_other = self._build_stacks(num_blocks=B,
                                                     in_channels=O,
                                                     conv_channels=P,
                                                     kernel_size=Q,
                                                     causal=causal,
                                                     norm_type=norm_type)
        self.conv_block_4 = TCNBlock_Spk(spk_embed_dim=spk_embed_dim,
                                         in_channels=O,
                                         conv_channels=P,
                                         kernel_size=Q,
                                         causal=causal,
                                         dilation=1,
                                         fusion_type=fusion_type,
                                         norm_type=norm_type)
        self.conv_block_4_other = self._build_stacks(num_blocks=B,
                                                     in_channels=O,
                                                     conv_channels=P,
                                                     kernel_size=Q,
                                                     causal=causal,
                                                     norm_type=norm_type)
        # n x O x T => n x N x T
        self.mask1 = Conv1D(O, N, 1)
        self.mask2 = Conv1D(O, N, 1)
        self.mask3 = Conv1D(O, N, 1)

    def _build_stacks(self, num_blocks, **block_kwargs):
        """
        Stack the B - 1 remaining TCN blocks of one repeat, with dilations
        2^1 ... 2^(B-1). The first block of each repeat (dilation 1) is the
        TCNBlock_Spk built separately above, which takes the speaker
        embedding.
        """
        blocks = [
            TCNBlock(**block_kwargs, dilation=(2 ** b))
            for b in range(1, num_blocks)
        ]
        return nn.Sequential(*blocks)

    def forward(self, w1, w2, w3, aux):
        # n x 3N x T
        y = self.ln(th.cat([w1, w2, w3], 1))
        # n x O x T
        y = self.proj(y)
        y = self.conv_block_1(y, aux)
        y = self.conv_block_1_other(y)
        y = self.conv_block_2(y, aux)
        y = self.conv_block_2_other(y)
        y = self.conv_block_3(y, aux)
        y = self.conv_block_3_other(y)
        y = self.conv_block_4(y, aux)
        y = self.conv_block_4_other(y)
        # n x N x T
        m1 = F.relu(self.mask1(y))
        m2 = F.relu(self.mask2(y))
        m3 = F.relu(self.mask3(y))
        return m1, m2, m3


class Speaker_Model(nn.Module):
    def __init__(self,
                 L1=20,
                 L2=80,
                 L3=160,
                 N=256,
                 O=256,
                 P=512,
                 spk_embed_dim=256):
        super(Speaker_Model, self).__init__()
        self.L1 = L1
        self.L2 = L2
        self.L3 = L3
        self.spk_encoder = nn.Sequential(
            ChannelwiseLayerNorm(3 * N),
            Conv1D(3 * N, O, 1),
            ResBlock(O, O),
            ResBlock(O, P),
            ResBlock(P, P),
            Conv1D(P, spk_embed_dim, 1),
        )

    def forward(self, aux, aux_len):
        # n x 3N x T' => n x spk_embed_dim x T''
        aux = self.spk_encoder(aux)
        # number of encoder frames for the true (unpadded) length of each
        # auxiliary utterance
        aux_T = (aux_len - self.L1) // (self.L1 // 2) + 1
        # the three // 3 divisions mirror the 3x time downsampling inside
        # each of the three ResBlocks
        aux_T = ((aux_T // 3) // 3) // 3
        # mean-pool over valid frames only: summing over all frames but
        # dividing by the true frame count keeps zero-padding from biasing
        # the embedding
        aux = th.sum(aux, -1) / aux_T.view(-1, 1).float()
        return aux
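
# A minimal smoke-test sketch, not part of the original module. It assumes
# this file lives inside a package so the relative imports above resolve
# (run it via `python -m <package>.<module>`), and it uses the 8 kHz / 4 s
# convention noted in SpEx_Plus.__init__ (S = 32000 samples). The exact
# output shape depends on whether ConvTrans1D squeezes the channel axis, as
# the commented-out `[:, :xlen1]` slicing in SpEx_Plus.forward suggests.
if __name__ == "__main__":
    net = SpEx_Plus()
    mix = th.randn(2, 32000)                # n x S mixture waveforms
    enroll = th.randn(2, 24000)             # n x S' enrollment utterances
    enroll_len = th.tensor([24000, 18000])  # true lengths before padding
    est = net(mix, enroll, enroll_len)
    print(est.shape)  # roughly n x S, up to decoder framing at the edges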