# -*- coding: utf-8 -*-
"""
This File contains everything to train the DTLN model.
For running the training see "run_training.py".
To run evaluation with the provided pretrained model see "run_evaluation.py".
Author: Nils L. Westhausen (nils.westhausen@uol.de)
Version: 24.06.2020
This code is licensed under the terms of the MIT-license.
"""
import os, fnmatch
import tensorflow.keras as keras
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Activation, Dense, LSTM, Dropout, \
Lambda, Input, Multiply, Layer, Conv1D
from tensorflow.keras.callbacks import ReduceLROnPlateau, CSVLogger, \
EarlyStopping, ModelCheckpoint
import tensorflow as tf
import soundfile as sf
from wavinfo import WavInfoReader
from random import shuffle, seed
import numpy as np
class audio_generator():
'''
Class to create a Tensorflow dataset based on an iterator from a large scale
audio dataset. This audio generator only supports single channel audio files.
'''
def __init__(self, path_to_input, path_to_s1, len_of_samples, fs, train_flag=False):
'''
Constructor of the audio generator class.
Inputs:
path_to_input path to the mixtures
path_to_s1 path to the target source data
len_of_samples length of audio snippets in samples
fs sampling rate
            train_flag        flag to activate shuffling of the files
'''
# set inputs to properties
self.path_to_input = path_to_input
self.path_to_s1 = path_to_s1
self.len_of_samples = len_of_samples
self.fs = fs
self.train_flag=train_flag
# count the number of samples in your data set (depending on your disk,
# this can take some time)
self.count_samples()
# create iterable tf.data.Dataset object
self.create_tf_data_obj()
def count_samples(self):
'''
Method to list the data of the dataset and count the number of samples.
'''
# list .wav files in directory
self.file_names = fnmatch.filter(os.listdir(self.path_to_input), '*.wav')
# count the number of samples contained in the dataset
self.total_samples = 0
for file in self.file_names:
info = WavInfoReader(os.path.join(self.path_to_input, file))
self.total_samples = self.total_samples + \
int(np.fix(info.data.frame_count/self.len_of_samples))
def create_generator(self):
'''
Method to create the iterator.
'''
# check if training or validation
if self.train_flag:
shuffle(self.file_names)
# iterate over the files
for file in self.file_names:
# read the audio files
noisy, fs_1 = sf.read(os.path.join(self.path_to_input, file))
speech, fs_2 = sf.read(os.path.join(self.path_to_s1, file))
# check if the sampling rates are matching the specifications
if fs_1 != self.fs or fs_2 != self.fs:
raise ValueError('Sampling rates do not match.')
if noisy.ndim != 1 or speech.ndim != 1:
                raise ValueError('Too many audio channels. The DTLN audio_generator '
                                 'only supports single channel audio data.')
# count the number of samples in one file
num_samples = int(np.fix(noisy.shape[0]/self.len_of_samples))
# iterate over the number of samples
for idx in range(num_samples):
# cut the audio files in chunks
in_dat = noisy[int(idx*self.len_of_samples):int((idx+1)*
self.len_of_samples)]
tar_dat = speech[int(idx*self.len_of_samples):int((idx+1)*
self.len_of_samples)]
# yield the chunks as float32 data
yield in_dat.astype('float32'), tar_dat.astype('float32')
def create_tf_data_obj(self):
'''
        Method to create the tf.data.Dataset.
'''
# creating the tf.data.Dataset from the iterator
self.tf_data_set = tf.data.Dataset.from_generator(
self.create_generator,
(tf.float32, tf.float32),
output_shapes=(tf.TensorShape([self.len_of_samples]), \
tf.TensorShape([self.len_of_samples])),
args=None
)
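# Example usage of audio_generator (a sketch; the directory paths and batch
# size are hypothetical, 15 s snippets at 16 kHz as in the DTLN defaults):
#   gen = audio_generator('./train/noisy', './train/clean',
#                         len_of_samples=15 * 16000, fs=16000,
#                         train_flag=True)
#   dataset = gen.tf_data_set.batch(32, drop_remainder=True).repeat()
#   for noisy_batch, clean_batch in dataset.take(1):
#       print(noisy_batch.shape)  # (32, 240000)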
class DTLN_model():
'''
Class to create and train the DTLN model
'''
def __init__(self):
'''
Constructor
'''
# defining default cost function
self.cost_function = self.snr_cost
# empty property for the model
self.model = []
# defining default parameters
self.fs = 16000
self.batchsize = 32
self.len_samples = 15
self.activation = 'sigmoid'
self.numUnits = 128
self.numLayer = 2
self.blockLen = 512
self.block_shift = 128
self.dropout = 0.25
self.lr = 1e-3
self.max_epochs = 200
self.encoder_size = 256
self.eps = 1e-7
        # reset all seeds to 42 to reduce variance between training runs
os.environ['PYTHONHASHSEED']=str(42)
seed(42)
np.random.seed(42)
tf.random.set_seed(42)
        # allow GPU memory growth so TF 2.x does not grab all GPU memory
physical_devices = tf.config.experimental.list_physical_devices('GPU')
if len(physical_devices) > 0:
for device in physical_devices:
tf.config.experimental.set_memory_growth(device, enable=True)
@staticmethod
def snr_cost(s_estimate, s_true):
'''
Static Method defining the cost function.
The negative signal to noise ratio is calculated here. The loss is
always calculated over the last dimension.
'''
# calculating the SNR
snr = tf.reduce_mean(tf.math.square(s_true), axis=-1, keepdims=True) / \
(tf.reduce_mean(tf.math.square(s_true-s_estimate), axis=-1, keepdims=True)+1e-7)
# using some more lines, because TF has no log10
num = tf.math.log(snr)
denom = tf.math.log(tf.constant(10, dtype=num.dtype))
loss = -10*(num / (denom))
# returning the loss
return loss
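    # Worked example for the loss above (added note): a power ratio
    # mean(s_true**2) / mean((s_true - s_estimate)**2) of 100 corresponds to
    # an SNR of 10*log10(100) = 20 dB, so the returned loss is -20;
    # minimizing the loss therefore maximizes the SNR.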
def lossWrapper(self):
'''
        A wrapper function which returns the loss function. This is done to
        enable additional arguments to the loss function if necessary.
'''
def lossFunction(y_true,y_pred):
# calculating loss and squeezing single dimensions away
loss = tf.squeeze(self.cost_function(y_pred,y_true))
# calculate mean over batches
loss = tf.reduce_mean(loss)
# return the loss
return loss
# returning the loss function as handle
return lossFunction
'''
In the following some helper layers are defined.
'''
def stftLayer(self, x):
'''
Method for an STFT helper layer used with a Lambda layer. The layer
calculates the STFT on the last dimension and returns the magnitude and
phase of the STFT.
'''
# creating frames from the continuous waveform
frames = tf.signal.frame(x, self.blockLen, self.block_shift)
# calculating the fft over the time frames. rfft returns NFFT/2+1 bins.
stft_dat = tf.signal.rfft(frames)
# calculating magnitude and phase from the complex signal
mag = tf.abs(stft_dat)
phase = tf.math.angle(stft_dat)
# returning magnitude and phase as list
return [mag, phase]
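    # Shape sketch for the defaults (blockLen=512, block_shift=128): a
    # (batch, N) waveform yields 1 + (N - 512) // 128 frames, and rfft
    # returns 512 // 2 + 1 = 257 bins, so mag and phase both have shape
    # (batch, num_frames, 257).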
def fftLayer(self, x):
'''
        Method for an FFT helper layer used with a Lambda layer. The layer
        calculates the rFFT on the last dimension and returns the magnitude
        and phase of the FFT.
'''
# expanding dimensions
frame = tf.expand_dims(x, axis=1)
# calculating the fft over the time frames. rfft returns NFFT/2+1 bins.
stft_dat = tf.signal.rfft(frame)
# calculating magnitude and phase from the complex signal
mag = tf.abs(stft_dat)
phase = tf.math.angle(stft_dat)
# returning magnitude and phase as list
return [mag, phase]
def ifftLayer(self, x):
'''
        Method for an inverse FFT layer used with a Lambda layer. This layer
        calculates time domain frames from magnitude and phase information.
        The input x must be a list [mag, phase].
'''
# calculating the complex representation
s1_stft = (tf.cast(x[0], tf.complex64) *
tf.exp( (1j * tf.cast(x[1], tf.complex64))))
# returning the time domain frames
return tf.signal.irfft(s1_stft)
def overlapAddLayer(self, x):
'''
Method for an overlap and add helper layer used with a Lambda layer.
This layer reconstructs the waveform from a framed signal.
'''
# calculating and returning the reconstructed waveform
return tf.signal.overlap_and_add(x, self.block_shift)
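    # Signal-flow note (added): stftLayer -> ifftLayer -> overlapAddLayer is
    # not a perfect inverse because no analysis/synthesis window is applied.
    # With 75% overlap (blockLen=512, block_shift=128) each steady-state
    # sample of an unmodified signal is summed blockLen / block_shift = 4
    # times; during training the network sees and compensates this scaling.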
def seperation_kernel(self, num_layer, mask_size, x, stateful=False):
'''
Method to create a separation kernel.
        !! Important !!: Do not use this layer with a Lambda layer. If used
        with a Lambda layer the gradients are not updated correctly.
Inputs:
num_layer Number of LSTM layers
mask_size Output size of the mask and size of the Dense layer
'''
# creating num_layer number of LSTM layers
for idx in range(num_layer):
x = LSTM(self.numUnits, return_sequences=True, stateful=stateful)(x)
# using dropout between the LSTM layer for regularization
if idx<(num_layer-1):
x = Dropout(self.dropout)(x)
# creating the mask with a Dense and an Activation layer
mask = Dense(mask_size)(x)
mask = Activation(self.activation)(mask)
# returning the mask
return mask
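    # Shape sketch: an input x of shape (batch, num_frames, features) yields
    # a mask of shape (batch, num_frames, mask_size); with the default
    # sigmoid activation the mask values lie in (0, 1).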
def seperation_kernel_with_states(self, num_layer, mask_size, x,
in_states):
'''
Method to create a separation kernel, which returns the LSTM states.
        !! Important !!: Do not use this layer with a Lambda layer. If used
        with a Lambda layer the gradients are not updated correctly.
        Inputs:
            num_layer       Number of LSTM layers
            mask_size       Output size of the mask and size of the Dense layer
            in_states       LSTM states of all layers stacked into one tensor
                            of shape (1, num_layer, numUnits, 2)
'''
states_h = []
states_c = []
# creating num_layer number of LSTM layers
for idx in range(num_layer):
in_state = [in_states[:,idx,:, 0], in_states[:,idx,:, 1]]
x, h_state, c_state = LSTM(self.numUnits, return_sequences=True,
unroll=True, return_state=True)(x, initial_state=in_state)
# using dropout between the LSTM layer for regularization
if idx<(num_layer-1):
x = Dropout(self.dropout)(x)
states_h.append(h_state)
states_c.append(c_state)
# creating the mask with a Dense and an Activation layer
mask = Dense(mask_size)(x)
mask = Activation(self.activation)(mask)
out_states_h = tf.reshape(tf.stack(states_h, axis=0),
[1,num_layer,self.numUnits])
out_states_c = tf.reshape(tf.stack(states_c, axis=0),
[1,num_layer,self.numUnits])
out_states = tf.stack([out_states_h, out_states_c], axis=-1)
# returning the mask and states
return mask, out_states
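    # State layout sketch: in_states and out_states have shape
    # (1, num_layer, numUnits, 2), where [..., 0] holds the hidden states and
    # [..., 1] the cell states of the LSTM layers (see the in_state indexing
    # above). unroll=True avoids tf.while loops inside the LSTM, which eases
    # the TF-lite conversion below.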
def build_DTLN_model(self, norm_stft=False):
'''
        Method to build and compile the DTLN model. The model takes time
        domain batches of size (batchsize, len_in_samples) and returns
        enhanced clips of the same dimensions. The Adam optimizer with a
        gradient norm clipping of 3 is used for training.
        The model contains two separation cores: the first uses an STFT
        signal transformation, the second a learned transformation based on
        a 1D-Conv layer.
'''
# input layer for time signal
time_dat = Input(batch_shape=(None, None))
# calculate STFT
mag,angle = Lambda(self.stftLayer)(time_dat)
        # normalize the log magnitude STFT to be more robust against level variations
if norm_stft:
mag_norm = InstantLayerNormalization()(tf.math.log(mag + 1e-7))
else:
# behaviour like in the paper
mag_norm = mag
# predicting mask with separation kernel
mask_1 = self.seperation_kernel(self.numLayer, (self.blockLen//2+1), mag_norm)
# multiply mask with magnitude
estimated_mag = Multiply()([mag, mask_1])
# transform frames back to time domain
estimated_frames_1 = Lambda(self.ifftLayer)([estimated_mag,angle])
# encode time domain frames to feature domain
encoded_frames = Conv1D(self.encoder_size,1,strides=1,use_bias=False)(estimated_frames_1)
# normalize the input to the separation kernel
encoded_frames_norm = InstantLayerNormalization()(encoded_frames)
# predict mask based on the normalized feature frames
mask_2 = self.seperation_kernel(self.numLayer, self.encoder_size, encoded_frames_norm)
# multiply encoded frames with the mask
estimated = Multiply()([encoded_frames, mask_2])
# decode the frames back to time domain
decoded_frames = Conv1D(self.blockLen, 1, padding='causal',use_bias=False)(estimated)
# create waveform with overlap and add procedure
estimated_sig = Lambda(self.overlapAddLayer)(decoded_frames)
# create the model
self.model = Model(inputs=time_dat, outputs=estimated_sig)
# show the model summary
print(self.model.summary())
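    # Example usage (a sketch): build and compile the model for training.
    #   modelClass = DTLN_model()
    #   modelClass.build_DTLN_model(norm_stft=True)
    #   modelClass.compile_model()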
def build_DTLN_model_stateful(self, norm_stft=False):
'''
        Method to build the stateful DTLN model for real time processing. The
        model takes one time domain frame of size (1, blockLen) and returns
        one enhanced frame of the same size.
'''
# input layer for time signal
time_dat = Input(batch_shape=(1, self.blockLen))
# calculate STFT
mag,angle = Lambda(self.fftLayer)(time_dat)
        # normalize the log magnitude STFT to be more robust against level variations
if norm_stft:
mag_norm = InstantLayerNormalization()(tf.math.log(mag + 1e-7))
else:
# behaviour like in the paper
mag_norm = mag
# predicting mask with separation kernel
mask_1 = self.seperation_kernel(self.numLayer, (self.blockLen//2+1), mag_norm, stateful=True)
# multiply mask with magnitude
estimated_mag = Multiply()([mag, mask_1])
# transform frames back to time domain
estimated_frames_1 = Lambda(self.ifftLayer)([estimated_mag,angle])
# encode time domain frames to feature domain
encoded_frames = Conv1D(self.encoder_size,1,strides=1,use_bias=False)(estimated_frames_1)
# normalize the input to the separation kernel
encoded_frames_norm = InstantLayerNormalization()(encoded_frames)
# predict mask based on the normalized feature frames
mask_2 = self.seperation_kernel(self.numLayer, self.encoder_size, encoded_frames_norm, stateful=True)
# multiply encoded frames with the mask
estimated = Multiply()([encoded_frames, mask_2])
# decode the frames back to time domain
decoded_frame = Conv1D(self.blockLen, 1, padding='causal',use_bias=False)(estimated)
# create the model
self.model = Model(inputs=time_dat, outputs=decoded_frame)
# show the model summary
print(self.model.summary())
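    # Frame-by-frame inference sketch (the weights file path is hypothetical):
    #   modelClass = DTLN_model()
    #   modelClass.build_DTLN_model_stateful()
    #   modelClass.model.load_weights('./pretrained_model.h5')
    #   for frame in frames:  # each frame of shape (1, blockLen)
    #       enhanced_frame = modelClass.model.predict_on_batch(frame)
    #   # call modelClass.model.reset_states() between utterances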
def compile_model(self):
'''
Method to compile the model for training
'''
# use the Adam optimizer with a clipnorm of 3
        optimizerAdam = keras.optimizers.Adam(learning_rate=self.lr, clipnorm=3.0)
# compile model with loss function
self.model.compile(loss=self.lossWrapper(), optimizer=optimizerAdam)
def create_saved_model(self, weights_file, target_name):
'''
        Method to create a SavedModel folder from a weights file
'''
# check for type
if weights_file.find('_norm_') != -1:
norm_stft = True
else:
norm_stft = False
# build model
self.build_DTLN_model_stateful(norm_stft=norm_stft)
# load weights
self.model.load_weights(weights_file)
# save model
tf.saved_model.save(self.model, target_name)
def create_tf_lite_model(self, weights_file, target_name, use_dynamic_range_quant=False):
'''
        Method to create TF-lite models from a weights file.
        The conversion creates two models, one for each separation core.
        TF-lite does not support complex numbers yet, so the FFT-related
        processing must be done outside the models.
        For further information on how real time processing can be
        implemented see "real_time_processing_tf_lite.py".
        The conversion only works with TF 2.3.
'''
# check for type
if weights_file.find('_norm_') != -1:
norm_stft = True
num_elements_first_core = 2 + self.numLayer * 3 + 2
else:
norm_stft = False
num_elements_first_core = self.numLayer * 3 + 2
# build model
self.build_DTLN_model_stateful(norm_stft=norm_stft)
# load weights
self.model.load_weights(weights_file)
#### Model 1 ##########################
mag = Input(batch_shape=(1, 1, (self.blockLen//2+1)))
states_in_1 = Input(batch_shape=(1, self.numLayer, self.numUnits, 2))
        # normalize the log magnitude STFT to be more robust against level variations
if norm_stft:
mag_norm = InstantLayerNormalization()(tf.math.log(mag + 1e-7))
else:
# behaviour like in the paper
mag_norm = mag
# predicting mask with separation kernel
mask_1, states_out_1 = self.seperation_kernel_with_states(self.numLayer,
(self.blockLen//2+1),
mag_norm, states_in_1)
model_1 = Model(inputs=[mag, states_in_1], outputs=[mask_1, states_out_1])
#### Model 2 ###########################
estimated_frame_1 = Input(batch_shape=(1, 1, (self.blockLen)))
states_in_2 = Input(batch_shape=(1, self.numLayer, self.numUnits, 2))
# encode time domain frames to feature domain
encoded_frames = Conv1D(self.encoder_size,1,strides=1,
use_bias=False)(estimated_frame_1)
# normalize the input to the separation kernel
encoded_frames_norm = InstantLayerNormalization()(encoded_frames)
# predict mask based on the normalized feature frames
mask_2, states_out_2 = self.seperation_kernel_with_states(self.numLayer,
self.encoder_size,
encoded_frames_norm,
states_in_2)
# multiply encoded frames with the mask
estimated = Multiply()([encoded_frames, mask_2])
# decode the frames back to time domain
decoded_frame = Conv1D(self.blockLen, 1, padding='causal',
use_bias=False)(estimated)
model_2 = Model(inputs=[estimated_frame_1, states_in_2],
outputs=[decoded_frame, states_out_2])
# set weights to submodels
weights = self.model.get_weights()
model_1.set_weights(weights[:num_elements_first_core])
model_2.set_weights(weights[num_elements_first_core:])
# convert first model
converter = tf.lite.TFLiteConverter.from_keras_model(model_1)
if use_dynamic_range_quant:
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()
with tf.io.gfile.GFile(target_name + '_1.tflite', 'wb') as f:
f.write(tflite_model)
# convert second model
converter = tf.lite.TFLiteConverter.from_keras_model(model_2)
if use_dynamic_range_quant:
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()
with tf.io.gfile.GFile(target_name + '_2.tflite', 'wb') as f:
f.write(tflite_model)
print('TF lite conversion complete!')
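    # Inference sketch for the two converted models (assumes the file names
    # produced above and the default numLayer=2, numUnits=128; the complete
    # pipeline, including the FFT and iFFT done outside the models, is shown
    # in "real_time_processing_tf_lite.py"):
    #   interpreter = tf.lite.Interpreter(model_path=target_name + '_1.tflite')
    #   interpreter.allocate_tensors()
    #   states_1 = np.zeros((1, 2, 128, 2), dtype=np.float32)
    #   # per frame: feed the magnitude spectrum and states_1 to model 1,
    #   # multiply the returned mask with the complex spectrum, take the
    #   # iFFT and feed the result plus states_2 to model 2 to obtain the
    #   # decoded frame for overlap-add.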
def train_model(self, runName, path_to_train_mix, path_to_train_speech, \
path_to_val_mix, path_to_val_speech):
'''
Method to train the DTLN model.
'''
# create save path if not existent
savePath = './models_'+ runName+'/'
if not os.path.exists(savePath):
os.makedirs(savePath)
# create log file writer
csv_logger = CSVLogger(savePath+ 'training_' +runName+ '.log')
# create callback for the adaptive learning rate
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5,
patience=3, min_lr=10**(-10), cooldown=1)
# create callback for early stopping
early_stopping = EarlyStopping(monitor='val_loss', min_delta=0,
patience=10, verbose=0, mode='auto', baseline=None)
# create model check pointer to save the best model
checkpointer = ModelCheckpoint(savePath+runName+'.h5',
monitor='val_loss',
verbose=1,
save_best_only=True,
save_weights_only=True,
mode='auto',
save_freq='epoch'
)
# calculate length of audio chunks in samples
len_in_samples = int(np.fix(self.fs * self.len_samples /
self.block_shift)*self.block_shift)
# create data generator for training data
generator_input = audio_generator(path_to_train_mix,
path_to_train_speech,
len_in_samples,
self.fs, train_flag=True)
dataset = generator_input.tf_data_set
dataset = dataset.batch(self.batchsize, drop_remainder=True).repeat()
# calculate number of training steps in one epoch
steps_train = generator_input.total_samples//self.batchsize
# create data generator for validation data
generator_val = audio_generator(path_to_val_mix,
path_to_val_speech,
len_in_samples, self.fs)
dataset_val = generator_val.tf_data_set
dataset_val = dataset_val.batch(self.batchsize, drop_remainder=True).repeat()
# calculate number of validation steps
steps_val = generator_val.total_samples//self.batchsize
# start the training of the model
self.model.fit(
x=dataset,
batch_size=None,
steps_per_epoch=steps_train,
epochs=self.max_epochs,
verbose=1,
validation_data=dataset_val,
validation_steps=steps_val,
callbacks=[checkpointer, reduce_lr, csv_logger, early_stopping],
max_queue_size=50,
workers=4,
use_multiprocessing=True)
# clear out garbage
tf.keras.backend.clear_session()
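# Example training call (a sketch; all four paths are hypothetical):
#   modelTrainer = DTLN_model()
#   modelTrainer.build_DTLN_model(norm_stft=True)
#   modelTrainer.compile_model()
#   modelTrainer.train_model('my_run', './train/noisy', './train/clean',
#                            './val/noisy', './val/clean')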
class InstantLayerNormalization(Layer):
'''
Class implementing instant layer normalization. It can also be called
channel-wise layer normalization and was proposed by
Luo & Mesgarani (https://arxiv.org/abs/1809.07454v2)
'''
def __init__(self, **kwargs):
'''
Constructor
'''
super(InstantLayerNormalization, self).__init__(**kwargs)
self.epsilon = 1e-7
self.gamma = None
self.beta = None
def build(self, input_shape):
'''
Method to build the weights.
'''
shape = input_shape[-1:]
# initialize gamma
self.gamma = self.add_weight(shape=shape,
initializer='ones',
trainable=True,
name='gamma')
# initialize beta
self.beta = self.add_weight(shape=shape,
initializer='zeros',
trainable=True,
name='beta')
def call(self, inputs):
'''
        Method to call the layer. Computes, per frame over the last dimension,
        outputs = gamma * (inputs - mean) / sqrt(variance + eps) + beta.
'''
# calculate mean of each frame
mean = tf.math.reduce_mean(inputs, axis=[-1], keepdims=True)
# calculate variance of each frame
variance = tf.math.reduce_mean(tf.math.square(inputs - mean),
axis=[-1], keepdims=True)
# calculate standard deviation
std = tf.math.sqrt(variance + self.epsilon)
# normalize each frame independently
outputs = (inputs - mean) / std
# scale with gamma
outputs = outputs * self.gamma
# add the bias beta
outputs = outputs + self.beta
# return output
return outputs
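# Minimal smoke test (an added sketch, not part of the training pipeline):
# build the non-stateful model and pass random noise through it to check the
# input/output shapes.
if __name__ == '__main__':
    modelClass = DTLN_model()
    modelClass.build_DTLN_model(norm_stft=False)
    modelClass.compile_model()
    # one batch of two 0.5 s random signals at 16 kHz
    dummy = np.random.randn(2, 8000).astype('float32')
    enhanced = modelClass.model.predict(dummy)
    # the output is slightly shorter than the input because the last
    # incomplete block of 512 samples is dropped by tf.signal.frame
    print('input shape:', dummy.shape, 'output shape:', enhanced.shape)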