# proactive_hearing / src / models / network / net_conversation_joint.py
# (HuggingFace upload-page residue converted to comments so the file parses:
#  "guilinhu's picture" / "Upload folder using huggingface_hub" /
#  commit "df9f13e verified")
import torch
import torch.nn as nn
from .model1 import Conv_Emb_Generator
from .model2_joint import TSH
import torch.nn.functional as F
import numpy as np
import copy
def mod_pad(x, chunk_size, pad):
    """Right-pad ``x`` (last dim) up to a multiple of ``chunk_size``, then
    apply the extra ``(left, right)`` padding given by ``pad``.

    Returns:
        (padded tensor, number of alignment samples appended)
    """
    remainder = x.shape[-1] % chunk_size
    mod = chunk_size - remainder if remainder else 0
    if mod:
        x = F.pad(x, (0, mod))
    return F.pad(x, pad), mod
# A TF-domain network guided by an embedding vector
class Net_Conversation(nn.Module):
    """Streaming two-stage TF-domain network guided by an embedding vector.

    Stage 1 (``self.model1``, a ``Conv_Emb_Generator``) runs over the full
    mixture and produces a per-frame "conversation embedding".  Stage 2
    (``self.model2``, a ``TSH``) processes a chunk-aligned slice of the audio,
    conditioned on the matching slice of that embedding, and synthesizes the
    separated waveform.  STFT analysis/synthesis is performed chunk-by-chunk
    with explicit state buffers so the module can run in streaming fashion.
    """

    def __init__(self,
                 model1_block_name,
                 model1_block_params,
                 model2_block_name,
                 model2_block_params,
                 stft_chunk_size=64,
                 stft_pad_size=32,
                 stft_back_pad=32,
                 num_input_channels=1,
                 num_output_channels=1,
                 num_sources=1,
                 speaker_embed = 256,
                 num_layers_model1=3,
                 num_layers_model2=3,
                 latent_dim_model1=16,
                 latent_dim_model2=32,
                 use_sp_feats=False,
                 use_first_ln=True,
                 n_imics=1,
                 window="hann",
                 lstm_fold_chunk=400,
                 E=2,
                 use_speaker_emb_model1=True,
                 one_emb_model1=True,
                 use_self_speech_model2=True,
                 local_context_len=-1
                 ):
        """Build both sub-models and the STFT analysis/synthesis windows.

        Key parameters (the rest are forwarded to ``Conv_Emb_Generator`` /
        ``TSH``):
            stft_chunk_size: STFT hop size in samples (one streaming chunk).
            stft_pad_size:   look-ahead padding in samples.
            stft_back_pad:   look-back padding in samples.
            num_sources:     must be 1 (asserted below).
            window:          "hann" or "rect" analysis window.
            lstm_fold_chunk: STFT frames covered by one model1 embedding frame
                             (each frame is tiled this many times in
                             ``forward``).
            local_context_len: context length in samples; -1 disables it.
        """
        super(Net_Conversation, self).__init__()
        assert num_sources == 1
        # num input/output channels
        self.nI = num_input_channels
        self.nO = num_output_channels
        # num channels fed to the TF-network: real+imag per input channel,
        # plus 3*(nI-1) extra feature channels when use_sp_feats is set
        # (the exact feature type is defined by the sub-models).
        num_separator_inputs = self.nI * 2 + use_sp_feats * (3 * (self.nI - 1))
        self.stft_chunk_size = stft_chunk_size
        self.stft_pad_size = stft_pad_size
        self.stft_back_pad = stft_back_pad
        self.n_srcs = num_sources
        self.use_sp_feats = use_sp_feats
        # FFT length = look-back + hop + look-ahead
        self.nfft = stft_back_pad + stft_chunk_size + stft_pad_size
        self.nfreqs = self.nfft//2 + 1
        self.lstm_fold_chunk=lstm_fold_chunk
        # Construct synthesis/analysis windows (hann or rect)
        if window=="hann":
            window_fn = lambda x: np.hanning(x)
        elif window=="rect":
            window_fn = lambda x: np.ones(x)
        else:
            raise ValueError("Invalid window type!")
        if ((stft_pad_size) % stft_chunk_size) == 0:
            # Look-ahead is a whole number of hops: build a synthesis window
            # that exactly inverts the overlap-added analysis window.
            print("Using perfect STFT windows")
            self.analysis_window = torch.from_numpy(window_fn(self.nfft)).float()
            # i.e. the inverse-STFT window
            self.synthesis_window = torch.zeros(stft_pad_size + stft_chunk_size).float()
            A = self.synthesis_window.shape[0]   # synthesis window length
            B = self.stft_chunk_size             # hop size
            N = self.analysis_window.shape[0]    # analysis window length
            assert (A % B) == 0
            for i in range(A):
                # Normalize the tail of the analysis window by the summed
                # squared overlap so overlap-add reconstructs the signal.
                num = self.analysis_window[N - A + i]
                denom = 0
                for k in range(A//B):
                    denom += (self.analysis_window[N - A + (i % B) + k * B] ** 2)
                self.synthesis_window[i] = num / denom
        else:
            # Fall back to a plain window; reconstruction is approximate.
            print("Using imperfect STFT windows")
            self.analysis_window = torch.from_numpy( window_fn(self.nfft) ).float()
            self.synthesis_window = torch.from_numpy( window_fn(stft_chunk_size + stft_pad_size) ).float()
        # Number of past IFFT frames that still overlap the current output
        self.istft_lookback = 1 + (self.synthesis_window.shape[0] - 1) // self.stft_chunk_size
        if local_context_len!=-1:
            # Convert samples -> STFT frames -> model1 fold-chunks
            local_context_len=local_context_len//stft_chunk_size//lstm_fold_chunk
        # Stage 1: conversation-embedding generator over the full mixture
        self.model1 = Conv_Emb_Generator(
            model1_block_name,
            model1_block_params,
            spk_dim = speaker_embed,
            latent_dim = latent_dim_model1,
            n_srcs = num_output_channels * num_sources,
            n_fft = self.nfft,
            num_inputs = num_separator_inputs,
            n_layers = num_layers_model1,
            use_first_ln=use_first_ln,
            n_imics=n_imics,
            lstm_fold_chunk=lstm_fold_chunk,
            E=E,
            use_speaker_emb=use_speaker_emb_model1,
            one_emb=one_emb_model1,
            local_context_len=local_context_len
        )
        self.quantized = False
        self.use_self_speech_model2=use_self_speech_model2
        # Stage 2: separation network conditioned on the conversation embedding
        self.model2=TSH(
            model2_block_name,
            model2_block_params,
            spk_dim = speaker_embed,
            latent_dim = latent_dim_model2,
            latent_dim_model1=latent_dim_model1,
            n_srcs = num_output_channels * num_sources,
            n_fft = self.nfft,
            num_inputs = num_separator_inputs,
            n_layers = num_layers_model2,
            use_first_ln=use_first_ln,
            n_imics=n_imics,
            lstm_fold_chunk=lstm_fold_chunk,
            stft_chunk_size=stft_chunk_size,
            use_speaker_emb=use_speaker_emb_model1,
            use_self_speech_model2=use_self_speech_model2
        )
        self.use_speaker_emb_model1=use_speaker_emb_model1

    def init_buffers(self, batch_size, device):
        """Create zeroed streaming state for both sub-models plus the iSTFT
        overlap-add look-back buffer used by ``synthesis``."""
        buffers = {}
        buffers['model1_bufs'] = self.model1.init_buffers(batch_size, device)
        buffers['model2_bufs'] = self.model2.init_buffers(batch_size, device)
        # Past windowed IFFT frames kept for overlap-add continuity
        buffers['istft_buf'] = torch.zeros(batch_size * self.n_srcs * self.nO,
                                           self.synthesis_window.shape[0],
                                           self.istft_lookback, device=device)
        return buffers

    # compute STFT
    def extract_features(self, x):
        """Compute the STFT per mic channel and stack real/imag as channels.

        x: (B, M, T) time-domain input
        returns: (B, C*M, T, F) with C=2 (real, imag), T = STFT frames
        """
        B, M, T = x.shape
        x = x.reshape(B*M, T)
        # center=False: frames depend only on samples already provided
        x = torch.stft(x, n_fft = self.nfft, hop_length = self.stft_chunk_size,
                       win_length = self.nfft, window=self.analysis_window.to(x.device),
                       center=False, normalized=False, return_complex=True)
        x = torch.view_as_real(x) # [B*M, F, T, 2]
        BM, _F, T, C = x.shape
        x = x.reshape(B, M, _F, T, C) # [B, M, F, T, 2]
        x = x.permute(0, 4, 1, 3, 2) # [B, 2, M, T, F]
        x = x.reshape(B, C*M, T, _F)
        return x

    def synthesis(self, x, input_state):
        """Inverse STFT by overlap-add, using a persistent look-back buffer.

        x: (B, S, T, C*F) real/imag-stacked spectra
        returns: ((B, S, t) waveform, updated input_state)
        """
        istft_buf = input_state['istft_buf']
        x = x.transpose(2, 3) # [B, S, CF, T]
        B, S, CF, T = x.shape
        X = x.reshape(B*S, CF, T)
        X = X.reshape(B*S, 2, -1, T).permute(0, 2, 3, 1) # [BS, F, T, C]
        X = X[..., 0] + 1j * X[..., 1]
        x = torch.fft.irfft(X, dim=1) # [BS, iW, T]
        # Keep only the tail covered by the synthesis window
        x = x[:, -self.synthesis_window.shape[0]:] # [BS, oW, T]
        # Apply synthesis window
        x = x * self.synthesis_window.unsqueeze(0).unsqueeze(-1).to(x.device)
        oW = self.synthesis_window.shape[0]
        # Concatenate blocks from previous IFFTs
        x = torch.cat([istft_buf, x], dim=-1)
        # NOTE(review): this retains up to oW past frames (shape[1]) rather
        # than istft_lookback (shape[-1], the size the buffer was created
        # with).  The surplus history only feeds samples sliced off below, so
        # the output appears unaffected — confirm whether shape[-1] was
        # intended; it would keep the buffer smaller.
        istft_buf = x[..., -istft_buf.shape[1]:] # Update buffer
        # Overlap-add all frames into one long signal
        x = F.fold(x, output_size=(self.stft_chunk_size * x.shape[-1] + (oW - self.stft_chunk_size), 1),
                   kernel_size=(oW, 1), stride=(self.stft_chunk_size, 1)) # [BS, 1, t]
        # Keep the T freshly-completed hops; drop the look-ahead tail
        x = x[:, :, -T * self.stft_chunk_size - self.stft_pad_size: - self.stft_pad_size]
        x = x.reshape(B, S, -1) # [B, S, t]
        input_state['istft_buf'] = istft_buf
        return x, input_state

    def predict_model1(self, x, input_state, speaker_embedding, pad=True):
        """Run stage 1 (embedding generator) on the mixture.

        B: batch
        M: mic
        t: time step (time-domain)
        x: (B, M, t)
        R: real or imaginary
        Returns (conversation_emb, updated input_state).
        """
        mod = 0
        if pad:
            # Align to chunk size and add look-back/look-ahead context
            pad_size = (self.stft_back_pad, self.stft_pad_size)
            x, mod = mod_pad(x, chunk_size=self.stft_chunk_size, pad=pad_size)
        # Time-domain to TF-domain
        x = self.extract_features(x) # [B, RM, T, F]
        if speaker_embedding is not None:
            speaker_embedding=speaker_embedding.unsqueeze(2)
        conversation_emb, input_state['model1_bufs'] = self.model1(x, speaker_embedding, input_state['model1_bufs'], self.quantized)
        return conversation_emb, input_state

    def predict_model2(self, x, conversation_emb, input_state, pad=True):
        """Run stage 2 on an audio slice, conditioned on the embedding slice.

        B: batch
        M: mic
        t: time step (time-domain)
        x: (B, M, t)
        R: real or imaginary
        Returns ((B, S*M, t) waveform, updated input_state).
        """
        mod = 0
        if pad:
            pad_size = (self.stft_back_pad, self.stft_pad_size)
            x, mod = mod_pad(x, chunk_size=self.stft_chunk_size, pad=pad_size)
        x = self.extract_features(x)
        x, input_state['model2_bufs']=self.model2(x, conversation_emb, input_state['model2_bufs'], self.quantized)
        # TF-domain to time-domain
        x, next_state = self.synthesis(x, input_state) # [B, S * M, t]
        if mod != 0:
            # Strip the alignment samples appended by mod_pad
            x = x[:, :, :-mod]
        return x, next_state

    def forward(self, inputs, input_state = None, pad=True):
        """Separate the requested sample range of the mixture.

        inputs: dict with
            'mixture':     (B, M, t) input audio
            'start_idx' / 'end_idx': sample range to separate; their
                           difference must be a multiple of stft_chunk_size
            'embed':       speaker embedding passed to model1
            'self_speech': extra reference audio, read only when
                           use_speaker_emb_model1 is False
        input_state: streaming buffers from a previous call, or None to reset.
        Returns dict with 'output' (B, nO, t'), 'next_state', 'audio_range'.
        """
        x = inputs['mixture']
        start_idx_input=inputs['start_idx']
        end_idx_input=inputs['end_idx']
        assert ((end_idx_input - start_idx_input) % self.stft_chunk_size) == 0
        # Snap start and end down to chunk boundaries
        start_idx_input = (start_idx_input // self.stft_chunk_size) * self.stft_chunk_size
        end_idx_input = (end_idx_input // self.stft_chunk_size) * self.stft_chunk_size
        B, M, t=x.shape
        audio_range=torch.tensor([start_idx_input, end_idx_input]).to(x.device)
        audio_range = audio_range.unsqueeze(0).repeat(B, 1)
        spk_embed = inputs['embed']
        self_speech=None
        if not self.use_speaker_emb_model1:
            # Feed the user's own speech as extra input channels to model1
            # instead of (or alongside) a speaker embedding
            self_speech=inputs['self_speech']
            combined_audio = torch.cat((x, self_speech), dim=1)
            x=combined_audio
        if input_state is None:
            input_state = self.init_buffers(x.shape[0], x.device)
        B, M, t = x.shape
        # Stage 1 ("slow model"): conversation embedding over the mixture
        conversation_emb, input_state = self.predict_model1(x, input_state, spk_embed, pad=pad) # [B, F, T, C]
        # Shift the embedding one fold-chunk later and zero the first frame,
        # so frame n is conditioned on embeddings computed through frame n-1
        B, _F, T, C = conversation_emb.shape
        conversation_emb = conversation_emb.permute(0, 1, 3, 2) # [B, F, C, T]
        conversation_emb = torch.roll(conversation_emb, 1, dims=-1)
        conversation_emb[..., 0] = 0
        # Upsample: tile each embedding frame lstm_fold_chunk times so it
        # lines up 1:1 with STFT frames
        conversation_emb = conversation_emb.flatten(0,3).unsqueeze(1) # [*, 1]
        multiplier = torch.tile(conversation_emb, (1, self.lstm_fold_chunk)) # [*, L]
        multiplier = multiplier.reshape(B, _F, C, T, self.lstm_fold_chunk).flatten(3,4) # [B, F, C, T*L]
        multiplier = multiplier.permute(0, 1, 3, 2) # [B, F, T*L, C]
        # Slice the audio with look-back/look-ahead context, zero-padding
        # wherever the padded range falls outside the recording
        slicing_length=end_idx_input-start_idx_input+self.stft_back_pad+self.stft_pad_size
        padded_start=start_idx_input-self.stft_back_pad
        padded_end=end_idx_input+self.stft_pad_size
        pad_left=max(-padded_start, 0)
        pad_right=max(padded_end-t, 0)
        actual_start=max(padded_start, 0)
        actual_end=min(padded_end, t)
        if self.use_self_speech_model2:
            sliced_x=x[:, :, actual_start:actual_end]
        else:
            # Stage 2 sees only the raw mixture, without self-speech channels
            x_no_self_speech=inputs["mixture"]
            sliced_x=x_no_self_speech[:, :, actual_start:actual_end]
        padding = (pad_left, pad_right, 0, 0, 0, 0)
        sliced_x=F.pad(sliced_x, padding, "constant", 0)
        # Matching embedding slice, indexed in STFT-frame units
        converted_start_idx=start_idx_input//self.stft_chunk_size
        converted_end_idx=end_idx_input//self.stft_chunk_size
        sliced_emb=multiplier[:, :, converted_start_idx:converted_end_idx, :]
        assert sliced_x.shape[2]==slicing_length
        assert sliced_emb.shape[2]==(slicing_length-self.stft_back_pad-self.stft_pad_size)//self.stft_chunk_size
        # Stage 2: the slice already carries its context padding, so pad=False
        model2_output, input_state = self.predict_model2(sliced_x, sliced_emb, input_state, pad=False)
        model2_output = model2_output.reshape(B, self.n_srcs, self.nO, model2_output.shape[-1])
        return {'output': model2_output[:, 0], 'next_state': input_state, 'audio_range': audio_range}
if __name__ == "__main__":
    # No standalone entry point; this module is intended to be imported.
    pass