|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from .model1 import Conv_Emb_Generator |
|
|
from .model2_joint import TSH |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
import copy |
|
|
|
|
|
|
|
|
def mod_pad(x, chunk_size, pad):
    """Align the last dim of ``x`` to a multiple of ``chunk_size``, then pad.

    Zeros are appended on the right until ``x.shape[-1]`` divides
    ``chunk_size`` evenly; afterwards the extra ``pad`` tuple
    (left, right) is applied to the same dimension.

    Args:
        x: input tensor; padding applies to its last dimension.
        chunk_size: target multiple for the last dimension's length.
        pad: (left, right) zero-padding applied after alignment.

    Returns:
        Tuple ``(padded, mod)`` where ``mod`` is the number of alignment
        samples that were appended (0 when already aligned).
    """
    remainder = x.shape[-1] % chunk_size
    mod = (chunk_size - remainder) if remainder else 0
    aligned = F.pad(x, (0, mod))
    return F.pad(aligned, pad), mod
|
|
|
|
|
|
|
|
|
|
|
class Net_Conversation(nn.Module):
    """Streaming two-stage conversation source-extraction network.

    Stage 1 (``model1``, a ``Conv_Emb_Generator``) runs over the STFT of the
    whole mixture and emits one "conversation embedding" per fold of
    ``lstm_fold_chunk`` STFT frames.  Stage 2 (``model2``, a ``TSH``) consumes
    a time slice of the mixture together with the causally shifted embeddings
    and synthesizes the output waveform through a streaming overlap-add iSTFT.
    All recurrent/overlap state is kept in an explicit buffer dict (see
    ``init_buffers``) so the model can be stepped chunk-by-chunk.
    """

    def __init__(self,
                 model1_block_name,
                 model1_block_params,
                 model2_block_name,
                 model2_block_params,
                 stft_chunk_size=64,
                 stft_pad_size=32,
                 stft_back_pad=32,
                 num_input_channels=1,
                 num_output_channels=1,
                 num_sources=1,
                 speaker_embed = 256,
                 num_layers_model1=3,
                 num_layers_model2=3,
                 latent_dim_model1=16,
                 latent_dim_model2=32,
                 use_sp_feats=False,
                 use_first_ln=True,
                 n_imics=1,
                 window="hann",
                 lstm_fold_chunk=400,
                 E=2,
                 use_speaker_emb_model1=True,
                 one_emb_model1=True,
                 use_self_speech_model2=True,
                 local_context_len=-1
                 ):
        """
        Args:
            model1_block_name, model1_block_params: block type / kwargs
                forwarded to ``Conv_Emb_Generator``.
            model2_block_name, model2_block_params: block type / kwargs
                forwarded to ``TSH``.
            stft_chunk_size: STFT hop size in samples; also the streaming
                step granularity (all start/end indices are chunk-aligned).
            stft_pad_size: lookahead samples in each analysis window.
            stft_back_pad: lookback samples in each analysis window.
            num_input_channels / num_output_channels: mic / output channels.
            num_sources: must be 1 (asserted below).
            speaker_embed: dimensionality of the speaker embedding.
            window: "hann" or "rect" analysis window.
            lstm_fold_chunk: number of STFT frames folded into one
                model1 embedding step.
            use_speaker_emb_model1: condition model1 on a speaker embedding;
                when False, self-speech audio channels are used instead
                (see ``forward``).
            local_context_len: context length in samples, or -1 for
                unlimited; converted below to folded-chunk units.
        """
        super(Net_Conversation, self).__init__()

        # Single-source extraction only; the rest of the code relies on it.
        assert num_sources == 1

        self.nI = num_input_channels
        self.nO = num_output_channels

        # Per mic: real+imag spectrogram planes; optionally 3 extra spatial
        # feature planes per mic pair when use_sp_feats is set.
        num_separator_inputs = self.nI * 2 + use_sp_feats * (3 * (self.nI - 1))

        self.stft_chunk_size = stft_chunk_size
        self.stft_pad_size = stft_pad_size
        self.stft_back_pad = stft_back_pad
        self.n_srcs = num_sources
        self.use_sp_feats = use_sp_feats

        # Analysis window spans lookback + hop + lookahead samples.
        self.nfft = stft_back_pad + stft_chunk_size + stft_pad_size
        self.nfreqs = self.nfft//2 + 1
        self.lstm_fold_chunk=lstm_fold_chunk

        if window=="hann":
            window_fn = lambda x: np.hanning(x)
        elif window=="rect":
            window_fn = lambda x: np.ones(x)
        else:
            raise ValueError("Invalid window type!")

        if ((stft_pad_size) % stft_chunk_size) == 0:
            # Perfect-reconstruction case: derive a synthesis window over the
            # last (pad + hop) samples of the analysis window such that
            # analysis * synthesis overlap-adds to exactly 1 at every sample.
            print("Using perfect STFT windows")
            self.analysis_window = torch.from_numpy(window_fn(self.nfft)).float()

            self.synthesis_window = torch.zeros(stft_pad_size + stft_chunk_size).float()

            A = self.synthesis_window.shape[0]   # synthesis window length
            B = self.stft_chunk_size             # hop size
            N = self.analysis_window.shape[0]    # analysis window length

            assert (A % B) == 0
            for i in range(A):
                # Numerator: analysis value at the aligned position.
                num = self.analysis_window[N - A + i]

                # Denominator: sum of squared analysis values over the A/B
                # frames that overlap this output sample (positions congruent
                # mod the hop), so that sum_k analysis*synthesis == 1.
                denom = 0
                for k in range(A//B):
                    denom += (self.analysis_window[N - A + (i % B) + k * B] ** 2)

                self.synthesis_window[i] = num / denom
        else:
            # Fallback: reuse the raw window; overlap-add is only approximate.
            print("Using imperfect STFT windows")
            self.analysis_window = torch.from_numpy( window_fn(self.nfft) ).float()
            self.synthesis_window = torch.from_numpy( window_fn(stft_chunk_size + stft_pad_size) ).float()

        # Number of past frames whose synthesis windows still overlap the
        # current hop; sizes the streaming iSTFT buffer.
        self.istft_lookback = 1 + (self.synthesis_window.shape[0] - 1) // self.stft_chunk_size

        # Convert context length from samples to folded-chunk units
        # (-1 means unlimited and is passed through unchanged).
        if local_context_len!=-1:
            local_context_len=local_context_len//stft_chunk_size//lstm_fold_chunk

        # Stage 1: produces per-fold conversation embeddings from the mixture.
        self.model1 = Conv_Emb_Generator(
            model1_block_name,
            model1_block_params,
            spk_dim = speaker_embed,
            latent_dim = latent_dim_model1,
            n_srcs = num_output_channels * num_sources,
            n_fft = self.nfft,
            num_inputs = num_separator_inputs,
            n_layers = num_layers_model1,
            use_first_ln=use_first_ln,
            n_imics=n_imics,
            lstm_fold_chunk=lstm_fold_chunk,
            E=E,
            use_speaker_emb=use_speaker_emb_model1,
            one_emb=one_emb_model1,
            local_context_len=local_context_len
        )

        # Toggled externally for quantized inference; forwarded to submodels.
        self.quantized = False

        self.use_self_speech_model2=use_self_speech_model2

        # Stage 2: extracts the target signal conditioned on stage-1 embeddings.
        self.model2=TSH(
            model2_block_name,
            model2_block_params,
            spk_dim = speaker_embed,
            latent_dim = latent_dim_model2,
            latent_dim_model1=latent_dim_model1,
            n_srcs = num_output_channels * num_sources,
            n_fft = self.nfft,
            num_inputs = num_separator_inputs,
            n_layers = num_layers_model2,
            use_first_ln=use_first_ln,
            n_imics=n_imics,
            lstm_fold_chunk=lstm_fold_chunk,
            stft_chunk_size=stft_chunk_size,
            use_speaker_emb=use_speaker_emb_model1,
            use_self_speech_model2=use_self_speech_model2
        )

        self.use_speaker_emb_model1=use_speaker_emb_model1

    def init_buffers(self, batch_size, device):
        """Create fresh streaming state for a batch of ``batch_size`` streams.

        Returns a dict with the recurrent buffers of both submodels plus a
        zeroed overlap-add buffer for the streaming iSTFT in ``synthesis``.
        """
        buffers = {}

        buffers['model1_bufs'] = self.model1.init_buffers(batch_size, device)

        buffers['model2_bufs'] = self.model2.init_buffers(batch_size, device)

        # (batch * sources * out-channels, window_len, lookback_frames)
        buffers['istft_buf'] = torch.zeros(batch_size * self.n_srcs * self.nO,
                                           self.synthesis_window.shape[0],
                                           self.istft_lookback, device=device)

        return buffers

    def extract_features(self, x):
        """Compute a causal (center=False) STFT and lay it out as feature planes.

        x: (B, M, T) time-domain audio
        returns: (B, C*M, T, F) where C=2 is the real/imag axis, T is the
            number of STFT frames and F the frequency bins.
        """
        B, M, T = x.shape

        # Fold mics into the batch so one stft call covers all channels.
        x = x.reshape(B*M, T)
        x = torch.stft(x, n_fft = self.nfft, hop_length = self.stft_chunk_size,
                       win_length = self.nfft, window=self.analysis_window.to(x.device),
                       center=False, normalized=False, return_complex=True)

        # Split complex into a trailing real/imag axis: (B*M, F, T, 2).
        x = torch.view_as_real(x)
        BM, _F, T, C = x.shape

        x = x.reshape(B, M, _F, T, C)

        # (B, M, F, T, C) -> (B, C, M, T, F)
        x = x.permute(0, 4, 1, 3, 2)

        x = x.reshape(B, C*M, T, _F)

        return x

    def synthesis(self, x, input_state):
        """Streaming inverse STFT via windowed overlap-add.

        x: (B, S, T, C*F) stacked real/imag spectra from model2
        returns: ((B, S, t) waveform, updated input_state)

        Carries ``istft_buf`` across calls so frames whose synthesis windows
        straddle a chunk boundary are overlap-added correctly.
        """
        istft_buf = input_state['istft_buf']

        # (B, S, T, C*F) -> (B, S, C*F, T)
        x = x.transpose(2, 3)

        B, S, CF, T = x.shape
        X = x.reshape(B*S, CF, T)
        # Unstack real/imag planes and rebuild complex bins: (B*S, F, T).
        X = X.reshape(B*S, 2, -1, T).permute(0, 2, 3, 1)
        X = X[..., 0] + 1j * X[..., 1]

        x = torch.fft.irfft(X, dim=1)
        # Keep only the tail of each frame that the synthesis window covers.
        x = x[:, -self.synthesis_window.shape[0]:]

        x = x * self.synthesis_window.unsqueeze(0).unsqueeze(-1).to(x.device)

        oW = self.synthesis_window.shape[0]

        # Prepend buffered frames from the previous call, then retain the tail
        # as the next call's buffer.
        x = torch.cat([istft_buf, x], dim=-1)
        # NOTE(review): this keeps the last shape[1] (= window length) frames,
        # not istft_lookback (= shape[-1]) frames.  The extra history is
        # harmless for correctness (the final slice below discards it) but
        # looks larger than the overlap-add needs — confirm intent.
        istft_buf = x[..., -istft_buf.shape[1]:]

        # Overlap-add: fold windowed frames at hop = stft_chunk_size.
        x = F.fold(x, output_size=(self.stft_chunk_size * x.shape[-1] + (oW - self.stft_chunk_size), 1),
                   kernel_size=(oW, 1), stride=(self.stft_chunk_size, 1))

        # Keep exactly T hops of fully-summed output, dropping the lookahead pad.
        x = x[:, :, -T * self.stft_chunk_size - self.stft_pad_size: - self.stft_pad_size]
        x = x.reshape(B, S, -1)

        input_state['istft_buf'] = istft_buf

        return x, input_state

    def predict_model1(self, x, input_state, speaker_embedding, pad=True):
        """Run stage 1 and return per-fold conversation embeddings.

        B: batch
        M: mic
        t: time step (time-domain)
        x: (B, M, t)
        R: real or imaginary
        speaker_embedding: optional conditioning embedding (or None).
        pad: when True, chunk-align and add lookback/lookahead context.
        """

        mod = 0
        if pad:
            pad_size = (self.stft_back_pad, self.stft_pad_size)
            x, mod = mod_pad(x, chunk_size=self.stft_chunk_size, pad=pad_size)

        x = self.extract_features(x)

        if speaker_embedding is not None:
            # Add a singleton time axis so the embedding broadcasts per frame.
            speaker_embedding=speaker_embedding.unsqueeze(2)

        conversation_emb, input_state['model1_bufs'] = self.model1(x, speaker_embedding, input_state['model1_bufs'], self.quantized)

        return conversation_emb, input_state

    def predict_model2(self, x, conversation_emb, input_state, pad=True):
        """Run stage 2 and synthesize the time-domain output.

        B: batch
        M: mic
        t: time step (time-domain)
        x: (B, M, t)
        R: real or imaginary
        conversation_emb: per-frame conditioning from stage 1.
        pad: when True, chunk-align and add lookback/lookahead context.
        """
        mod = 0
        if pad:
            pad_size = (self.stft_back_pad, self.stft_pad_size)
            x, mod = mod_pad(x, chunk_size=self.stft_chunk_size, pad=pad_size)

        x = self.extract_features(x)

        x, input_state['model2_bufs']=self.model2(x, conversation_emb, input_state['model2_bufs'], self.quantized)

        x, next_state = self.synthesis(x, input_state)

        # Trim the alignment samples appended by mod_pad, if any.
        if mod != 0:
            x = x[:, :, :-mod]

        return x, next_state

    def forward(self, inputs, input_state = None, pad=True):
        """Extract the target signal for the [start_idx, end_idx) span.

        inputs: dict with 'mixture' (B, M, t), 'start_idx', 'end_idx',
            'embed' (speaker embedding) and — when
            ``use_speaker_emb_model1`` is False — 'self_speech'.
        input_state: streaming buffers from a previous call, or None to
            start a fresh stream.
        returns: dict with 'output' (B, nO, span), 'next_state' and the
            chunk-aligned 'audio_range' actually processed.
        """
        x = inputs['mixture']

        start_idx_input=inputs['start_idx']
        end_idx_input=inputs['end_idx']

        # The requested span must be a whole number of STFT hops.
        assert ((end_idx_input - start_idx_input) % self.stft_chunk_size) == 0

        # Snap both endpoints down to hop boundaries.
        start_idx_input = (start_idx_input // self.stft_chunk_size) * self.stft_chunk_size
        end_idx_input = (end_idx_input // self.stft_chunk_size) * self.stft_chunk_size

        B, M, t=x.shape

        audio_range=torch.tensor([start_idx_input, end_idx_input]).to(x.device)
        audio_range = audio_range.unsqueeze(0).repeat(B, 1)

        spk_embed = inputs['embed']
        self_speech=None

        if not self.use_speaker_emb_model1:
            # Condition on self-speech audio instead of a speaker embedding:
            # append it to the mixture as extra input channels.
            self_speech=inputs['self_speech']

            combined_audio = torch.cat((x, self_speech), dim=1)
            x=combined_audio

        if input_state is None:
            input_state = self.init_buffers(x.shape[0], x.device)

        B, M, t = x.shape

        # Stage 1 over the full mixture: one embedding per lstm_fold_chunk frames.
        conversation_emb, input_state = self.predict_model1(x, input_state, spk_embed, pad=pad)

        B, _F, T, C = conversation_emb.shape
        conversation_emb = conversation_emb.permute(0, 1, 3, 2)
        # Causal shift: each fold is conditioned on embeddings computed from
        # strictly earlier audio; the very first fold gets a zero embedding.
        conversation_emb = torch.roll(conversation_emb, 1, dims=-1)
        conversation_emb[..., 0] = 0
        # Repeat each per-fold embedding lstm_fold_chunk times so there is
        # one conditioning vector per STFT frame: (B, F, T*fold, C).
        conversation_emb = conversation_emb.flatten(0,3).unsqueeze(1)
        multiplier = torch.tile(conversation_emb, (1, self.lstm_fold_chunk))
        multiplier = multiplier.reshape(B, _F, C, T, self.lstm_fold_chunk).flatten(3,4)
        multiplier = multiplier.permute(0, 1, 3, 2)

        # Stage-2 input includes lookback + lookahead context around the span.
        slicing_length=end_idx_input-start_idx_input+self.stft_back_pad+self.stft_pad_size

        padded_start=start_idx_input-self.stft_back_pad
        padded_end=end_idx_input+self.stft_pad_size

        # Zero-pad where the context window runs off either end of the signal.
        pad_left=max(-padded_start, 0)
        pad_right=max(padded_end-t, 0)

        actual_start=max(padded_start, 0)
        actual_end=min(padded_end, t)

        if self.use_self_speech_model2:
            sliced_x=x[:, :, actual_start:actual_end]
        else:
            # Stage 2 sees only the raw mixture, without self-speech channels.
            x_no_self_speech=inputs["mixture"]
            sliced_x=x_no_self_speech[:, :, actual_start:actual_end]

        padding = (pad_left, pad_right, 0, 0, 0, 0)

        sliced_x=F.pad(sliced_x, padding, "constant", 0)

        # Sample indices -> STFT frame indices for the conditioning slice.
        converted_start_idx=start_idx_input//self.stft_chunk_size
        converted_end_idx=end_idx_input//self.stft_chunk_size

        sliced_emb=multiplier[:, :, converted_start_idx:converted_end_idx, :]

        # Sanity: audio slice covers span + context; embedding slice covers
        # exactly one vector per hop of the (context-free) span.
        assert sliced_x.shape[2]==slicing_length
        assert sliced_emb.shape[2]==(slicing_length-self.stft_back_pad-self.stft_pad_size)//self.stft_chunk_size

        # pad=False: context was added manually above.
        model2_output, input_state = self.predict_model2(sliced_x, sliced_emb, input_state, pad=False)
        model2_output = model2_output.reshape(B, self.n_srcs, self.nO, model2_output.shape[-1])

        # n_srcs is asserted to be 1, so [:, 0] drops the source axis.
        return {'output': model2_output[:, 0], 'next_state': input_state, 'audio_range': audio_range}
|
|
|
|
|
|
|
|
# Import-only module: no standalone entry point is defined.
if __name__ == "__main__":
    pass