# Target-speaker-extraction/model/spex_plus_plus.py
# Last commit by swc2: ef932f5 "add model select"
#!/usr/bin/env python
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from .norm import ChannelwiseLayerNorm, GlobalLayerNorm
from .cnns import Conv1D, ConvTrans1D, TCNBlock, TCNBlock_Spk, ResBlock
import warnings
# inference aux_len
class SpEx_Plus_Double(nn.Module):
    """SpEx+-style target speaker extractor with two extra refinement passes.

    The mixture ``x`` and the reference utterance ``aux`` are encoded at three
    window lengths (L1, L2, L3) that all use the hop ``L1 // 2``, so the three
    feature streams share one frame grid. A speaker embedding computed from the
    reference conditions the extractor, which predicts one mask per stream; the
    masked streams are decoded and mixed with learnable fusion weights.
    :meth:`ira` then runs twice. Each pass re-estimates the speaker embedding
    from ``[estimate ; reference]`` and re-extracts from features that combine
    the previous estimate with the mixture features.
    """

    def __init__(self,
                 L1=20,
                 L2=80,
                 L3=160,
                 N=256,
                 B=8,
                 O=256,
                 P=512,
                 Q=3,
                 num_spks=101,
                 spk_embed_dim=256,
                 causal=False,
                 norm_type='gLN',
                 fusion_type='cat',
                 is_innorm=False,
                 ):
        super(SpEx_Plus_Double, self).__init__()
        # n x S => n x N x T, S = 4s*8000 = 32000
        self.L1 = L1
        self.L2 = L2
        self.L3 = L3
        # All encoders/decoders use stride L1 // 2 so their frame grids line up.
        self.encoder_1d_short = Conv1D(1, N, L1, stride=L1 // 2, padding=0)
        self.encoder_1d_middle = Conv1D(1, N, L2, stride=L1 // 2, padding=0)
        self.encoder_1d_long = Conv1D(1, N, L3, stride=L1 // 2, padding=0)
        # Optional instance norm on encoder outputs (only used when is_innorm).
        self.instancenorm = nn.InstanceNorm1d(N)
        self.decoder_1d_short = ConvTrans1D(N, 1, kernel_size=L1, stride=L1 // 2, bias=True)
        self.decoder_1d_middle = ConvTrans1D(N, 1, kernel_size=L2, stride=L1 // 2, bias=True)
        self.decoder_1d_long = ConvTrans1D(N, 1, kernel_size=L3, stride=L1 // 2, bias=True)
        self.num_spks = num_spks
        # Not used inside this class; presumably consumed by training code
        # (speaker classification) — confirm against the trainer.
        self.pred_linear = nn.Linear(spk_embed_dim, num_spks)
        self.is_innorm = is_innorm
        if causal and norm_type not in ["cgLN", "cLN"]:
            # Warn before overwriting so the message names the requested norm.
            warnings.warn(
                "In causal configuration cumulative layer normalization (cgLN) "
                "or channel-wise layer normalization (chanLN) "
                f"must be used. Changing {norm_type} to cLN"
            )
            norm_type = "cLN"
        self.speaker_encoder = Speaker_Model(
            L1=L1,
            L2=L2,
            L3=L3,
            N=N,
            O=O,
            P=P,
            spk_embed_dim=spk_embed_dim,
        )
        self.extractor = Extractor(
            L1=L1,
            L2=L2,
            L3=L3,
            N=N,
            B=B,
            O=O,
            P=P,
            Q=Q,
            num_spks=num_spks,
            spk_embed_dim=spk_embed_dim,
            causal=causal,
            fusion_type=fusion_type,
            norm_type=norm_type,
        )
        # 1x1 convs fusing [estimate frames ; mixture frames] back to N channels.
        self.frameconv1 = Conv1D(2 * N, N, 1)
        self.frameconv2 = Conv1D(2 * N, N, 1)
        self.frameconv3 = Conv1D(2 * N, N, 1)
        # Learnable mixing weights for the three decoded streams.
        self.fusion1 = nn.Parameter(th.tensor(0.8))
        self.fusion2 = nn.Parameter(th.tensor(0.1))
        self.fusion3 = nn.Parameter(th.tensor(0.1))

    def align_to_w(self, frame, w):
        """Crop or zero-pad the last axis of ``frame`` to ``w``'s length.

        ``w`` is returned unchanged; only ``frame`` is adjusted.
        """
        diff = frame.shape[-1] - w.shape[-1]
        if diff > 0:
            frame = frame[..., :w.shape[-1]]  # crop
        elif diff < 0:
            frame = F.pad(frame, (0, -diff))  # zero-pad
        return frame, w

    @staticmethod
    def _fit_length(wav, length):
        """Zero-pad or truncate the last axis of ``wav`` to exactly ``length``."""
        shortfall = length - wav.shape[-1]
        if shortfall > 0:
            return F.pad(wav, (0, shortfall))
        return wav[..., :length]

    def _encode_multiscale(self, wav):
        """Encode ``wav`` with the short/middle/long encoders on one frame grid.

        The middle and long inputs are right-padded so that they yield the
        same number of frames as the short encoder.

        Returns:
            ``((w1, w2, w3), (len1, len2, len3))``: the ReLU'd encoder
            outputs, plus the input length and the two padded lengths.
        """
        w1 = F.relu(self.encoder_1d_short(wav))
        num_frames = w1.shape[-1]
        hop = self.L1 // 2
        len1 = wav.shape[-1]
        len2 = (num_frames - 1) * hop + self.L2
        len3 = (num_frames - 1) * hop + self.L3
        w2 = F.relu(self.encoder_1d_middle(F.pad(wav, (0, len2 - len1), "constant", 0)))
        w3 = F.relu(self.encoder_1d_long(F.pad(wav, (0, len3 - len1), "constant", 0)))
        return (w1, w2, w3), (len1, len2, len3)

    def _decode_and_fuse(self, feats1, feats2, feats3, length):
        """Decode the three masked streams to ``length`` samples and mix them."""
        s1 = self._fit_length(self.decoder_1d_short(feats1), length)
        s2 = self._fit_length(self.decoder_1d_middle(feats2), length)
        s3 = self._fit_length(self.decoder_1d_long(feats3), length)
        return self.fusion1 * s1 + self.fusion2 * s2 + self.fusion3 * s3

    def ira(self, est1, aux, aux_len, xlen1, xlen2, xlen3, w1, w2, w3):
        """Run one refinement pass on a previous estimate.

        Args:
            est1: previous estimate; presumably (batch, samples), as it is
                concatenated with ``aux`` on dim 1.
            aux: reference *waveform* (not the speaker embedding). It is
                appended to ``est1`` to re-estimate the speaker embedding.
            aux_len: valid reference lengths in samples, forwarded to the
                speaker encoder.
            xlen1: mixture length in samples; the result has exactly this length.
            xlen2, xlen3: padded mixture lengths for the middle/long encoders.
                Unused here because ``est1`` is re-encoded from its own length;
                kept for signature compatibility.
            w1, w2, w3: mixture encoder features (after optional instance norm).

        Returns:
            The refined estimate, ``xlen1`` samples long.
        """
        # Speaker embedding from [estimate ; reference]; the valid length grows
        # by the number of estimate samples prepended.
        concat_wav = th.cat((est1, aux), dim=1)
        concat_len = aux_len + est1.shape[-1]
        (c1, c2, c3), _ = self._encode_multiscale(concat_wav)
        spk_emb = self.speaker_encoder(th.cat([c1, c2, c3], 1), concat_len)

        (frame1, frame2, frame3), _ = self._encode_multiscale(est1)
        if self.is_innorm:
            frame1 = self.instancenorm(frame1)
            frame2 = self.instancenorm(frame2)
            frame3 = self.instancenorm(frame3)
        # est1 normally spans xlen1 samples, so frame counts already match w*;
        # this is only a safety net for externally supplied estimates.
        frame1, w1 = self.align_to_w(frame1, w1)
        frame2, w2 = self.align_to_w(frame2, w2)
        frame3, w3 = self.align_to_w(frame3, w3)
        concat1 = self.frameconv1(th.cat([frame1, w1], 1))
        concat2 = self.frameconv2(th.cat([frame2, w2], 1))
        concat3 = self.frameconv3(th.cat([frame3, w3], 1))
        mask1, mask2, mask3 = self.extractor(concat1, concat2, concat3, spk_emb)
        # Pad back to the mixture length so every pass returns the same length
        # (otherwise the next pass sees a shorter input and misaligned frames).
        return self._decode_and_fuse(concat1 * mask1, concat2 * mask2, concat3 * mask3, xlen1)

    def forward(self, x, aux, aux_len):
        """Extract the reference speaker from mixture ``x``.

        Args:
            x: mixture, 1-D (single utterance) or 2-D batch of waveforms.
            aux: reference waveform(s), 1-D or 2-D like ``x``.
            aux_len: valid reference lengths in samples.

        Returns:
            The estimate after the initial extraction plus two refinement passes.

        Raises:
            RuntimeError: if ``x`` has 3 or more dimensions.
        """
        if x.dim() >= 3:
            raise RuntimeError(
                "{} accept 1/2D tensor as input, but got {:d}".format(
                    type(self).__name__, x.dim()))
        # when inference, only one utt
        if x.dim() == 1:
            x = th.unsqueeze(x, 0)
        if aux.dim() == 1:
            aux = th.unsqueeze(aux, 0)
        # n x 1 x S => n x N x T (per scale)
        (w1, w2, w3), (xlen1, xlen2, xlen3) = self._encode_multiscale(x)
        if self.is_innorm:
            w1 = self.instancenorm(w1)
            w2 = self.instancenorm(w2)
            w3 = self.instancenorm(w3)
        # Speaker encoder shares the speech encoders. Keep the waveform in
        # `aux` because the refinement passes need it.
        (aux_w1, aux_w2, aux_w3), _ = self._encode_multiscale(aux)
        spk_emb = self.speaker_encoder(th.cat([aux_w1, aux_w2, aux_w3], 1), aux_len)
        m1, m2, m3 = self.extractor(w1, w2, w3, spk_emb)
        est1 = self._decode_and_fuse(w1 * m1, w2 * m2, w3 * m3, xlen1)
        est2 = self.ira(est1, aux, aux_len, xlen1, xlen2, xlen3, w1, w2, w3)
        est3 = self.ira(est2, aux, aux_len, xlen1, xlen2, xlen3, w1, w2, w3)
        return est3
class Extractor(nn.Module):
    """Speaker-conditioned mask estimator over three encoder streams.

    The streams are concatenated on the channel axis, normalised with
    ChannelwiseLayerNorm and projected to O channels. The result goes through
    four rounds of [TCNBlock_Spk (dilation 1, takes the speaker embedding)
    followed by B - 1 plain TCNBlocks with dilations 2, 4, ..., 2**(B - 1)].
    Three 1x1 convolutions with ReLU then give one N-channel mask per stream.

    L1, L2, L3 and num_spks are accepted but not used by this module.
    """

    def __init__(self,
                 L1=20,
                 L2=80,
                 L3=160,
                 N=256,
                 B=8,
                 O=256,
                 P=512,
                 Q=3,
                 num_spks=101,
                 spk_embed_dim=256,
                 causal=False,
                 fusion_type='cat',
                 norm_type='gLN',
                 ):
        super(Extractor, self).__init__()
        # n x 3N x T => n x O x T
        self.ln = ChannelwiseLayerNorm(3 * N)
        self.proj = Conv1D(3 * N, O, 1)
        tcn_kwargs = dict(
            in_channels=O,
            conv_channels=P,
            kernel_size=Q,
            causal=causal,
            norm_type=norm_type,
        )
        # Registered as conv_block_{k} then conv_block_{k}_other, in the same
        # order as explicit attribute assignments, so submodule names and
        # registration order are unchanged.
        for k in range(1, 5):
            setattr(self, f"conv_block_{k}", TCNBlock_Spk(
                spk_embed_dim=spk_embed_dim,
                dilation=1,
                fusion_type=fusion_type,
                **tcn_kwargs,
            ))
            setattr(self, f"conv_block_{k}_other",
                    self._build_stacks(num_blocks=B, **tcn_kwargs))
        # n x O x T => n x N x T, one head per stream
        self.mask1 = Conv1D(O, N, 1)
        self.mask2 = Conv1D(O, N, 1)
        self.mask3 = Conv1D(O, N, 1)

    def _build_stacks(self, num_blocks, **block_kwargs):
        """Build the plain TCNBlocks that follow one speaker-conditioned block.

        Produces ``num_blocks - 1`` blocks with dilations 2, 4, ...,
        2**(num_blocks - 1); the dilation-1 block is the separate
        TCNBlock_Spk. ``num_blocks <= 1`` gives an empty Sequential.
        """
        return nn.Sequential(*(
            TCNBlock(**block_kwargs, dilation=2 ** exponent)
            for exponent in range(1, num_blocks)
        ))

    def forward(self, w1, w2, w3, aux):
        """Return ReLU masks ``(m1, m2, m3)`` for the three streams.

        ``aux`` is the speaker embedding passed to every TCNBlock_Spk.
        """
        hidden = self.proj(self.ln(th.cat([w1, w2, w3], 1)))
        for k in range(1, 5):
            hidden = getattr(self, f"conv_block_{k}")(hidden, aux)
            hidden = getattr(self, f"conv_block_{k}_other")(hidden)
        heads = (self.mask1, self.mask2, self.mask3)
        return tuple(F.relu(head(hidden)) for head in heads)
class Speaker_Model(nn.Module):
    """Turn multi-scale reference features into one speaker embedding per item.

    The 3N-channel features (the three encoder streams concatenated) pass
    through ChannelwiseLayerNorm, a 1x1 projection, three ResBlocks and a final
    1x1 conv to ``spk_embed_dim`` channels. The output is then summed over
    time and divided by an estimate of the number of valid frames.
    """

    def __init__(self,
                 L1=20,
                 L2=80,
                 L3=160,
                 N=256,
                 O=256,
                 P=512,
                 spk_embed_dim=256,
                 ):
        super(Speaker_Model, self).__init__()
        self.L1 = L1
        self.L2 = L2
        self.L3 = L3
        self.spk_encoder = nn.Sequential(
            ChannelwiseLayerNorm(3 * N),
            Conv1D(3 * N, O, 1),
            ResBlock(O, O),
            ResBlock(O, P),
            ResBlock(P, P),
            Conv1D(P, spk_embed_dim, 1),
        )

    def forward(self, aux, aux_len):
        """Return speaker embeddings, shape (batch, spk_embed_dim).

        Args:
            aux: concatenated encoder features of the reference.
            aux_len: valid reference lengths in samples. A tensor of shape
                (batch,) or a plain Python int.
        """
        aux = self.spk_encoder(aux)
        # Accept ints and keep the divisor on the same device as the features.
        aux_len = th.as_tensor(aux_len, device=aux.device)
        # Number of short-encoder frames (window L1, hop L1 // 2).
        aux_T = (aux_len - self.L1) // (self.L1 // 2) + 1
        # Three successive //3, presumably matching a 3x temporal downsampling
        # in each of the three ResBlocks (defined in .cnns — confirm).
        aux_T = ((aux_T // 3) // 3) // 3
        # Very short references would otherwise give 0 frames -> inf/nan.
        aux_T = aux_T.clamp(min=1)
        return th.sum(aux, -1) / aux_T.view(-1, 1).float()