DTLN (models)
Browse files- models/DTLN (alekya)/.gitattributes +35 -0
- models/DTLN (alekya)/DTLN_model.py +641 -0
- models/DTLN (alekya)/LICENSE +21 -0
- models/DTLN (alekya)/README.md +302 -0
- models/DTLN (alekya)/convert_weights_to_onnx.py +102 -0
- models/DTLN (alekya)/convert_weights_to_saved_model.py +30 -0
- models/DTLN (alekya)/convert_weights_to_tf_lite.py +41 -0
- models/DTLN (alekya)/eval_env.yml +117 -0
- models/DTLN (alekya)/main.py +134 -0
- models/DTLN (alekya)/main_1.py +3 -0
- models/DTLN (alekya)/measure_execution_time.py +44 -0
- models/DTLN (alekya)/model.h5 +3 -0
- models/DTLN (alekya)/real_time_dtln_audio.py +147 -0
- models/DTLN (alekya)/real_time_processing.py +55 -0
- models/DTLN (alekya)/real_time_processing_onnx.py +106 -0
- models/DTLN (alekya)/real_time_processing_tf_lite.py +102 -0
- models/DTLN (alekya)/run_evaluation.py +131 -0
- models/DTLN (alekya)/run_training.py +59 -0
- models/DTLN (alekya)/source.txt +1 -0
- models/DTLN (alekya)/test.py +47 -0
- models/DTLN (alekya)/tflite_env.yml +73 -0
- models/DTLN (alekya)/train_env.yml +120 -0
- models/DTLN (yash-04)/.gitattributes +35 -0
- models/DTLN (yash-04)/DTLN_model.py +641 -0
- models/DTLN (yash-04)/LICENSE +21 -0
- models/DTLN (yash-04)/README.md +302 -0
- models/DTLN (yash-04)/convert_weights_to_onnx.py +102 -0
- models/DTLN (yash-04)/convert_weights_to_saved_model.py +30 -0
- models/DTLN (yash-04)/convert_weights_to_tf_lite.py +41 -0
- models/DTLN (yash-04)/eval_env.yml +117 -0
- models/DTLN (yash-04)/measure_execution_time.py +44 -0
- models/DTLN (yash-04)/real_time_dtln_audio.py +147 -0
- models/DTLN (yash-04)/real_time_processing.py +55 -0
- models/DTLN (yash-04)/real_time_processing_onnx.py +106 -0
- models/DTLN (yash-04)/real_time_processing_tf_lite.py +102 -0
- models/DTLN (yash-04)/run_evaluation.py +131 -0
- models/DTLN (yash-04)/run_training.py +59 -0
- models/DTLN (yash-04)/source.txt +1 -0
- models/DTLN (yash-04)/tflite_env.yml +73 -0
- models/DTLN (yash-04)/train_env.yml +120 -0
models/DTLN (alekya)/.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
models/DTLN (alekya)/DTLN_model.py
ADDED
|
@@ -0,0 +1,641 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
This File contains everything to train the DTLN model.
|
| 4 |
+
|
| 5 |
+
For running the training see "run_training.py".
|
| 6 |
+
To run evaluation with the provided pretrained model see "run_evaluation.py".
|
| 7 |
+
|
| 8 |
+
Author: Nils L. Westhausen (nils.westhausen@uol.de)
|
| 9 |
+
Version: 24.06.2020
|
| 10 |
+
|
| 11 |
+
This code is licensed under the terms of the MIT-license.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
import os, fnmatch
|
| 16 |
+
import tensorflow.keras as keras
|
| 17 |
+
from tensorflow.keras.models import Model
|
| 18 |
+
from tensorflow.keras.layers import Activation, Dense, LSTM, Dropout, \
|
| 19 |
+
Lambda, Input, Multiply, Layer, Conv1D
|
| 20 |
+
from tensorflow.keras.callbacks import ReduceLROnPlateau, CSVLogger, \
|
| 21 |
+
EarlyStopping, ModelCheckpoint
|
| 22 |
+
import tensorflow as tf
|
| 23 |
+
import soundfile as sf
|
| 24 |
+
from wavinfo import WavInfoReader
|
| 25 |
+
from random import shuffle, seed
|
| 26 |
+
import numpy as np
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class audio_generator():
    '''
    Class to create a Tensorflow dataset based on an iterator from a large scale
    audio dataset. This audio generator only supports single channel audio files.
    Noisy and clean files are paired by identical file names in the two
    directories.
    '''

    def __init__(self, path_to_input, path_to_s1, len_of_samples, fs, train_flag=False):
        '''
        Constructor of the audio generator class.
        Inputs:
            path_to_input    path to the mixtures (noisy .wav files)
            path_to_s1       path to the target source data (clean .wav files)
            len_of_samples   length of audio snippets in samples
            fs               expected sampling rate [Hz]
            train_flag       flag to activate shuffling of files (training mode)
        '''
        # set inputs to properties
        self.path_to_input = path_to_input
        self.path_to_s1 = path_to_s1
        self.len_of_samples = len_of_samples
        self.fs = fs
        self.train_flag=train_flag
        # count the number of samples in your data set (depending on your disk,
        # this can take some time)
        self.count_samples()
        # create iterable tf.data.Dataset object
        self.create_tf_data_obj()

    def count_samples(self):
        '''
        Method to list the data of the dataset and count the number of samples.
        Sets self.file_names and self.total_samples (total number of
        len_of_samples-long chunks across all files).
        '''
        # list .wav files in directory
        self.file_names = fnmatch.filter(os.listdir(self.path_to_input), '*.wav')
        # count the number of samples contained in the dataset
        self.total_samples = 0
        for file in self.file_names:
            # WavInfoReader reads only the wav header, not the audio data
            info = WavInfoReader(os.path.join(self.path_to_input, file))
            self.total_samples = self.total_samples + \
                int(np.fix(info.data.frame_count/self.len_of_samples))

    def create_generator(self):
        '''
        Method to create the iterator. Yields paired (noisy, clean) chunks
        of len_of_samples samples each, as float32 arrays.
        '''
        # shuffle file order only in training mode
        if self.train_flag:
            shuffle(self.file_names)
        # iterate over the files
        for file in self.file_names:
            # read the audio files (noisy mixture and clean target share a name)
            noisy, fs_1 = sf.read(os.path.join(self.path_to_input, file))
            speech, fs_2 = sf.read(os.path.join(self.path_to_s1, file))
            # check if the sampling rates are matching the specifications
            if fs_1 != self.fs or fs_2 != self.fs:
                raise ValueError('Sampling rates do not match.')
            if noisy.ndim != 1 or speech.ndim != 1:
                raise ValueError('Too many audio channels. The DTLN audio_generator \
                    only supports single channel audio data.')
            # count the number of samples in one file
            num_samples = int(np.fix(noisy.shape[0]/self.len_of_samples))
            # iterate over the number of samples
            for idx in range(num_samples):
                # cut the audio files in chunks
                in_dat = noisy[int(idx*self.len_of_samples):int((idx+1)*
                    self.len_of_samples)]
                tar_dat = speech[int(idx*self.len_of_samples):int((idx+1)*
                    self.len_of_samples)]
                # yield the chunks as float32 data
                yield in_dat.astype('float32'), tar_dat.astype('float32')

    def create_tf_data_obj(self):
        '''
        Method to create the tf.data.Dataset from the generator.
        Sets self.tf_data_set.
        '''
        # creating the tf.data.Dataset from the iterator
        self.tf_data_set = tf.data.Dataset.from_generator(
            self.create_generator,
            (tf.float32, tf.float32),
            output_shapes=(tf.TensorShape([self.len_of_samples]),
                           tf.TensorShape([self.len_of_samples])),
            args=None
            )
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class DTLN_model():
|
| 124 |
+
'''
|
| 125 |
+
Class to create and train the DTLN model
|
| 126 |
+
'''
|
| 127 |
+
|
| 128 |
+
def __init__(self):
    '''
    Constructor: sets the default hyper-parameters, seeds all random
    number generators and enables GPU memory growth.
    '''
    # defining default cost function (negative SNR)
    self.cost_function = self.snr_cost
    # empty property for the model
    self.model = []
    # defining default parameters
    self.fs = 16000              # sampling rate [Hz]
    self.batchsize = 32
    self.len_samples = 15        # length of training snippets [s]
    self.activation = 'sigmoid'  # activation of the mask output
    self.numUnits = 128          # LSTM units per layer
    self.numLayer = 2            # LSTM layers per separation core
    self.blockLen = 512          # frame length [samples]
    self.block_shift = 128       # frame shift [samples]
    self.dropout = 0.25
    self.lr = 1e-3
    self.max_epochs = 200
    self.encoder_size = 256      # size of the learned 1D-Conv feature transform
    self.eps = 1e-7
    # reset all seeds to 42 to reduce invariance between training runs
    os.environ['PYTHONHASHSEED']=str(42)
    seed(42)
    np.random.seed(42)
    tf.random.set_seed(42)
    # enable memory growth so TF does not reserve all GPU memory up front
    physical_devices = tf.config.experimental.list_physical_devices('GPU')
    if len(physical_devices) > 0:
        for device in physical_devices:
            tf.config.experimental.set_memory_growth(device, enable=True)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
@staticmethod
def snr_cost(s_estimate, s_true):
    '''
    Static method defining the cost function: the negative signal to
    noise ratio, always computed over the last dimension.
    '''
    # signal power over error power; eps keeps the division stable
    signal_pow = tf.reduce_mean(tf.math.square(s_true), axis=-1, keepdims=True)
    error_pow = tf.reduce_mean(tf.math.square(s_true-s_estimate), axis=-1, keepdims=True)
    snr = signal_pow / (error_pow + 1e-7)
    # TF has no log10, so divide the natural log by ln(10)
    log10_snr = tf.math.log(snr) / tf.math.log(tf.constant(10, dtype=snr.dtype))
    # negative SNR in dB as the loss
    return -10 * log10_snr
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def lossWrapper(self):
    '''
    Wrapper returning a Keras-compatible loss function. The closure
    allows additional arguments to reach the loss if ever necessary.
    '''
    def lossFunction(y_true,y_pred):
        # per-sample cost with singleton dimensions removed,
        # averaged over the batch
        per_sample = self.cost_function(y_pred,y_true)
        return tf.reduce_mean(tf.squeeze(per_sample))
    # hand back the closure
    return lossFunction
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
'''
|
| 200 |
+
In the following some helper layers are defined.
|
| 201 |
+
'''
|
| 202 |
+
|
| 203 |
+
def stftLayer(self, x):
    '''
    STFT helper for use with a Lambda layer. Frames the waveform,
    computes the rFFT on the last dimension and returns the magnitude
    and phase as a list [mag, phase].
    '''
    # frame the continuous waveform
    framed = tf.signal.frame(x, self.blockLen, self.block_shift)
    # rFFT over the frames; returns NFFT/2+1 bins
    spectrum = tf.signal.rfft(framed)
    # split the complex spectrum into magnitude and phase
    return [tf.abs(spectrum), tf.math.angle(spectrum)]
|
| 219 |
+
|
| 220 |
+
def fftLayer(self, x):
    '''
    Single-frame FFT helper for use with a Lambda layer. Computes the
    rFFT on the last dimension and returns the magnitude and phase as a
    list [mag, phase].
    '''
    # insert a frame axis so the input becomes one time frame
    frame = tf.expand_dims(x, axis=1)
    # rFFT over the frame; returns NFFT/2+1 bins
    spectrum = tf.signal.rfft(frame)
    # split the complex spectrum into magnitude and phase
    return [tf.abs(spectrum), tf.math.angle(spectrum)]
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def ifftLayer(self, x):
    '''
    Inverse FFT helper for use with a Lambda layer. Rebuilds time
    domain frames from magnitude and phase. The input x must be a list
    [mag, phase].
    '''
    mag, phase = x[0], x[1]
    # recombine into a complex spectrum: mag * e^(j*phase)
    complex_spec = tf.cast(mag, tf.complex64) * \
        tf.exp(1j * tf.cast(phase, tf.complex64))
    # back to the time domain
    return tf.signal.irfft(complex_spec)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def overlapAddLayer(self, x):
    '''
    Overlap-and-add helper for use with a Lambda layer. Reconstructs
    the continuous waveform from the framed signal x using the
    configured frame shift.
    '''
    reconstructed = tf.signal.overlap_and_add(x, self.block_shift)
    return reconstructed
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def seperation_kernel(self, num_layer, mask_size, x, stateful=False):
    '''
    Build a separation kernel: a stack of LSTM layers followed by a
    Dense + activation producing a mask.
    !! Important !!: Do not wrap this in a Lambda layer — the gradients
    may not be updated correctly in that case.

    Inputs:
        num_layer   number of LSTM layers
        mask_size   output size of the mask and size of the Dense layer
        x           input feature tensor
        stateful    build stateful LSTMs for real-time processing
    '''
    last_idx = num_layer - 1
    for layer_idx in range(num_layer):
        x = LSTM(self.numUnits, return_sequences=True, stateful=stateful)(x)
        # dropout between (but not after) the LSTM layers for regularization
        if layer_idx != last_idx:
            x = Dropout(self.dropout)(x)
    # Dense projection followed by the mask activation
    return Activation(self.activation)(Dense(mask_size)(x))
|
| 286 |
+
|
| 287 |
+
def seperation_kernel_with_states(self, num_layer, mask_size, x,
                                  in_states):
    '''
    Separation kernel variant that threads the LSTM states explicitly,
    used for the TF-lite export where the model itself must be stateless.
    !! Important !!: Do not use this with a Lambda layer — the gradients
    may not be updated correctly in that case.

    Inputs:
        num_layer   Number of LSTM layers
        mask_size   Output size of the mask and size of the Dense layer
        x           input feature tensor
        in_states   tensor of shape (1, num_layer, numUnits, 2) holding
                    the [h, c] state pair of each LSTM layer
    Returns:
        mask, out_states — out_states has the same layout as in_states
    '''
    states_h = []
    states_c = []
    # creating num_layer number of LSTM layers
    for idx in range(num_layer):
        # slice out the [h, c] pair of layer idx from the packed states
        in_state = [in_states[:,idx,:, 0], in_states[:,idx,:, 1]]
        # unroll=True for a static graph suitable for TF-lite conversion
        x, h_state, c_state = LSTM(self.numUnits, return_sequences=True,
                   unroll=True, return_state=True)(x, initial_state=in_state)
        # using dropout between the LSTM layer for regularization
        if idx<(num_layer-1):
            x = Dropout(self.dropout)(x)
        states_h.append(h_state)
        states_c.append(c_state)
    # creating the mask with a Dense and an Activation layer
    mask = Dense(mask_size)(x)
    mask = Activation(self.activation)(mask)
    # re-stack the per-layer states into the (1, num_layer, numUnits) layout
    out_states_h = tf.reshape(tf.stack(states_h, axis=0),
                              [1,num_layer,self.numUnits])
    out_states_c = tf.reshape(tf.stack(states_c, axis=0),
                              [1,num_layer,self.numUnits])
    # pack h and c along a trailing axis, mirroring the in_states layout
    out_states = tf.stack([out_states_h, out_states_c], axis=-1)
    # returning the mask and states
    return mask, out_states
|
| 321 |
+
|
| 322 |
+
def build_DTLN_model(self, norm_stft=False):
    '''
    Method to build the DTLN training model (compilation happens in
    compile_model). The model takes time domain batches of size
    (batchsize, len_in_samples) and returns enhanced clips in the same
    dimensions. It contains two separation cores: the first works on an
    STFT signal transformation, the second on a learned transformation
    based on a 1D-Conv layer.

    Inputs:
        norm_stft   if True, apply instant layer normalization to the
                    log magnitude before the first separation kernel
    '''
    # input layer for time signal
    time_dat = Input(batch_shape=(None, None))
    # calculate STFT
    mag,angle = Lambda(self.stftLayer)(time_dat)
    # normalizing log magnitude stfts to get more robust against level variations
    if norm_stft:
        mag_norm = InstantLayerNormalization()(tf.math.log(mag + 1e-7))
    else:
        # behaviour like in the paper
        mag_norm = mag
    # core 1: predict a spectral mask with the separation kernel
    mask_1 = self.seperation_kernel(self.numLayer, (self.blockLen//2+1), mag_norm)
    # multiply mask with magnitude
    estimated_mag = Multiply()([mag, mask_1])
    # transform frames back to time domain (using the noisy phase)
    estimated_frames_1 = Lambda(self.ifftLayer)([estimated_mag,angle])
    # core 2: encode time domain frames to a learned feature domain
    encoded_frames = Conv1D(self.encoder_size,1,strides=1,use_bias=False)(estimated_frames_1)
    # normalize the input to the separation kernel
    encoded_frames_norm = InstantLayerNormalization()(encoded_frames)
    # predict mask based on the normalized feature frames
    mask_2 = self.seperation_kernel(self.numLayer, self.encoder_size, encoded_frames_norm)
    # multiply encoded frames with the mask
    estimated = Multiply()([encoded_frames, mask_2])
    # decode the frames back to time domain
    decoded_frames = Conv1D(self.blockLen, 1, padding='causal',use_bias=False)(estimated)
    # create waveform with overlap and add procedure
    estimated_sig = Lambda(self.overlapAddLayer)(decoded_frames)

    # create the model
    self.model = Model(inputs=time_dat, outputs=estimated_sig)
    # show the model summary
    print(self.model.summary())
|
| 367 |
+
|
| 368 |
+
def build_DTLN_model_stateful(self, norm_stft=False):
    '''
    Method to build the stateful DTLN model for block-wise real time
    processing. The model takes one time domain frame of size
    (1, blockLen) and returns one enhanced frame; the stateful LSTMs
    carry their states across successive calls.

    Inputs:
        norm_stft   if True, apply instant layer normalization to the
                    log magnitude before the first separation kernel
    '''
    # input layer for a single time-domain frame
    time_dat = Input(batch_shape=(1, self.blockLen))
    # single-frame FFT instead of the full STFT
    mag,angle = Lambda(self.fftLayer)(time_dat)
    # normalizing log magnitude stfts to get more robust against level variations
    if norm_stft:
        mag_norm = InstantLayerNormalization()(tf.math.log(mag + 1e-7))
    else:
        # behaviour like in the paper
        mag_norm = mag
    # core 1: predict a spectral mask with a stateful separation kernel
    mask_1 = self.seperation_kernel(self.numLayer, (self.blockLen//2+1), mag_norm, stateful=True)
    # multiply mask with magnitude
    estimated_mag = Multiply()([mag, mask_1])
    # transform frames back to time domain (using the noisy phase)
    estimated_frames_1 = Lambda(self.ifftLayer)([estimated_mag,angle])
    # core 2: encode time domain frames to the learned feature domain
    encoded_frames = Conv1D(self.encoder_size,1,strides=1,use_bias=False)(estimated_frames_1)
    # normalize the input to the separation kernel
    encoded_frames_norm = InstantLayerNormalization()(encoded_frames)
    # predict mask based on the normalized feature frames
    mask_2 = self.seperation_kernel(self.numLayer, self.encoder_size, encoded_frames_norm, stateful=True)
    # multiply encoded frames with the mask
    estimated = Multiply()([encoded_frames, mask_2])
    # decode the frames back to time domain (overlap-add happens outside)
    decoded_frame = Conv1D(self.blockLen, 1, padding='causal',use_bias=False)(estimated)
    # create the model
    self.model = Model(inputs=time_dat, outputs=decoded_frame)
    # show the model summary
    print(self.model.summary())
|
| 405 |
+
|
| 406 |
+
def compile_model(self):
    '''
    Method to compile the model for training: Adam optimizer with a
    gradient norm clipping of 3 and the negative-SNR loss.
    '''
    # use the Adam optimizer with a clipnorm of 3.
    # 'learning_rate' replaces the deprecated 'lr' argument of
    # tf.keras optimizers (removed in newer TF/Keras releases).
    optimizerAdam = keras.optimizers.Adam(learning_rate=self.lr, clipnorm=3.0)
    # compile model with loss function
    self.model.compile(loss=self.lossWrapper(), optimizer=optimizerAdam)
|
| 416 |
+
|
| 417 |
+
def create_saved_model(self, weights_file, target_name):
    '''
    Build the stateful model, load the weights from *weights_file* and
    export everything as a TF SavedModel folder *target_name*.
    '''
    # '_norm_' in the weights file name marks a model trained with
    # STFT log-magnitude normalization
    norm_stft = weights_file.find('_norm_') != -1
    # build the stateful model matching the weights
    self.build_DTLN_model_stateful(norm_stft=norm_stft)
    # load weights
    self.model.load_weights(weights_file)
    # export as SavedModel
    tf.saved_model.save(self.model, target_name)
|
| 433 |
+
|
| 434 |
+
def create_tf_lite_model(self, weights_file, target_name, use_dynamic_range_quant=False):
    '''
    Method to create tf lite models from a weights file.
    The conversion creates two models, one for each separation core
    ('<target_name>_1.tflite' and '<target_name>_2.tflite'). TF lite
    does not support complex numbers yet, so the FFT/iFFT and
    overlap-add must be done outside the models.
    For further information and how real time processing can be
    implemented see "real_time_processing_tf_lite.py".

    The conversion only works with TF 2.3.

    Inputs:
        weights_file             path to the trained weights (.h5)
        target_name              prefix for the two .tflite output files
        use_dynamic_range_quant  enable dynamic range quantization
    '''
    # '_norm_' in the file name marks weights trained with STFT normalization
    if weights_file.find('_norm_') != -1:
        norm_stft = True
        # number of weight tensors belonging to the first core:
        # presumably 2 for the InstantLayerNormalization, 3 per LSTM
        # layer (kernel, recurrent kernel, bias) and 2 for the Dense
        # layer — TODO confirm against the built model's weight list
        num_elements_first_core = 2 + self.numLayer * 3 + 2
    else:
        norm_stft = False
        num_elements_first_core = self.numLayer * 3 + 2
    # build the stateful reference model and load the trained weights
    self.build_DTLN_model_stateful(norm_stft=norm_stft)
    self.model.load_weights(weights_file)

    #### Model 1: spectral masking core ##########################
    mag = Input(batch_shape=(1, 1, (self.blockLen//2+1)))
    # explicit LSTM states, layout (1, num_layer, numUnits, [h, c])
    states_in_1 = Input(batch_shape=(1, self.numLayer, self.numUnits, 2))
    # normalizing log magnitude stfts to get more robust against level variations
    if norm_stft:
        mag_norm = InstantLayerNormalization()(tf.math.log(mag + 1e-7))
    else:
        # behaviour like in the paper
        mag_norm = mag
    # predicting mask with separation kernel
    mask_1, states_out_1 = self.seperation_kernel_with_states(self.numLayer,
                                                    (self.blockLen//2+1),
                                                    mag_norm, states_in_1)

    model_1 = Model(inputs=[mag, states_in_1], outputs=[mask_1, states_out_1])

    #### Model 2: learned feature-domain core ###########################

    estimated_frame_1 = Input(batch_shape=(1, 1, (self.blockLen)))
    states_in_2 = Input(batch_shape=(1, self.numLayer, self.numUnits, 2))

    # encode time domain frames to feature domain
    encoded_frames = Conv1D(self.encoder_size,1,strides=1,
                            use_bias=False)(estimated_frame_1)
    # normalize the input to the separation kernel
    encoded_frames_norm = InstantLayerNormalization()(encoded_frames)
    # predict mask based on the normalized feature frames
    mask_2, states_out_2 = self.seperation_kernel_with_states(self.numLayer,
                                                      self.encoder_size,
                                                      encoded_frames_norm,
                                                      states_in_2)
    # multiply encoded frames with the mask
    estimated = Multiply()([encoded_frames, mask_2])
    # decode the frames back to time domain
    decoded_frame = Conv1D(self.blockLen, 1, padding='causal',
                           use_bias=False)(estimated)

    model_2 = Model(inputs=[estimated_frame_1, states_in_2],
                    outputs=[decoded_frame, states_out_2])

    # split the trained weights between the two submodels; this relies
    # on the weight list order matching the layer creation order above
    weights = self.model.get_weights()
    model_1.set_weights(weights[:num_elements_first_core])
    model_2.set_weights(weights[num_elements_first_core:])
    # convert first model
    converter = tf.lite.TFLiteConverter.from_keras_model(model_1)
    if use_dynamic_range_quant:
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_model = converter.convert()
    with tf.io.gfile.GFile(target_name + '_1.tflite', 'wb') as f:
        f.write(tflite_model)
    # convert second model
    converter = tf.lite.TFLiteConverter.from_keras_model(model_2)
    if use_dynamic_range_quant:
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_model = converter.convert()
    with tf.io.gfile.GFile(target_name + '_2.tflite', 'wb') as f:
        f.write(tflite_model)

    print('TF lite conversion complete!')
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
def train_model(self, runName, path_to_train_mix, path_to_train_speech, \
                path_to_val_mix, path_to_val_speech):
    '''
    Train the DTLN model.

    runName: name used for the save folder, the checkpoint file and the
        CSV training log.
    path_to_train_mix / path_to_train_speech: folders with the noisy and
        clean training audio.
    path_to_val_mix / path_to_val_speech: folders with the noisy and
        clean validation audio.
    '''
    # folder collecting checkpoints and logs for this run
    save_dir = f'./models_{runName}/'
    os.makedirs(save_dir, exist_ok=True)

    # CSV log of per-epoch metrics
    logger_cb = CSVLogger(f'{save_dir}training_{runName}.log')
    # halve the learning rate when validation loss plateaus
    lr_cb = ReduceLROnPlateau(monitor='val_loss', factor=0.5,
                              patience=3, min_lr=10**(-10), cooldown=1)
    # stop training after 10 epochs without validation improvement
    stop_cb = EarlyStopping(monitor='val_loss', min_delta=0,
                            patience=10, verbose=0, mode='auto',
                            baseline=None)
    # keep only the weights with the lowest validation loss
    ckpt_cb = ModelCheckpoint(f'{save_dir}{runName}.h5',
                              monitor='val_loss',
                              verbose=1,
                              save_best_only=True,
                              save_weights_only=True,
                              mode='auto',
                              save_freq='epoch')

    # chunk length in samples, truncated to a multiple of the block shift
    len_in_samples = int(np.fix(self.fs * self.len_samples /
                                self.block_shift) * self.block_shift)

    # training data pipeline (shuffling/augmentation via train_flag)
    train_gen = audio_generator(path_to_train_mix,
                                path_to_train_speech,
                                len_in_samples,
                                self.fs, train_flag=True)
    train_set = train_gen.tf_data_set.batch(
        self.batchsize, drop_remainder=True).repeat()
    # number of optimizer steps per epoch
    steps_train = train_gen.total_samples // self.batchsize

    # validation data pipeline (no augmentation)
    val_gen = audio_generator(path_to_val_mix,
                              path_to_val_speech,
                              len_in_samples, self.fs)
    val_set = val_gen.tf_data_set.batch(
        self.batchsize, drop_remainder=True).repeat()
    steps_val = val_gen.total_samples // self.batchsize

    # run the actual training loop
    self.model.fit(
        x=train_set,
        batch_size=None,
        steps_per_epoch=steps_train,
        epochs=self.max_epochs,
        verbose=1,
        validation_data=val_set,
        validation_steps=steps_val,
        callbacks=[ckpt_cb, lr_cb, logger_cb, stop_cb],
        max_queue_size=50,
        workers=4,
        use_multiprocessing=True)
    # release graph/session resources after training
    tf.keras.backend.clear_session()
|
| 586 |
+
class InstantLayerNormalization(Layer):
    '''
    Instant layer normalization, also called channel-wise layer
    normalization, proposed by Luo & Mesgarani
    (https://arxiv.org/abs/1809.07454v2).

    Each frame is normalized independently over its feature (last) axis
    and then scaled and shifted with learnable per-channel parameters.
    '''

    def __init__(self, **kwargs):
        '''
        Constructor. The trainable weights gamma and beta are created
        lazily in build().
        '''
        super(InstantLayerNormalization, self).__init__(**kwargs)
        # small constant to avoid division by zero in the normalization
        self.epsilon = 1e-7
        self.gamma = None
        self.beta = None

    def build(self, input_shape):
        '''
        Create the trainable scale (gamma) and shift (beta) weights,
        one value per feature channel.
        '''
        param_shape = input_shape[-1:]
        self.gamma = self.add_weight(shape=param_shape,
                                     initializer='ones',
                                     trainable=True,
                                     name='gamma')
        self.beta = self.add_weight(shape=param_shape,
                                    initializer='zeros',
                                    trainable=True,
                                    name='beta')

    def call(self, inputs):
        '''
        Normalize each frame over the last axis and apply the learned
        affine transform.
        '''
        mean = tf.math.reduce_mean(inputs, axis=[-1], keepdims=True)
        variance = tf.math.reduce_mean(tf.math.square(inputs - mean),
                                       axis=[-1], keepdims=True)
        # standard deviation, stabilized with epsilon
        std = tf.math.sqrt(variance + self.epsilon)
        normalized = (inputs - mean) / std
        # scale with gamma, then shift with beta
        return normalized * self.gamma + self.beta
models/DTLN (alekya)/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2020 Nils L. Westhausen
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
models/DTLN (alekya)/README.md
ADDED
|
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Dual-signal Transformation LSTM Network
|
| 2 |
+
|
| 3 |
+
+ Tensorflow 2.x implementation of the stacked dual-signal transformation LSTM network (DTLN) for real-time noise suppression.
|
| 4 |
+
+ This repository provides the code for training, infering and serving the DTLN model in python. It also provides pretrained models in SavedModel, TF-lite and ONNX format, which can be used as baseline for your own projects. The model is able to run with real time audio on a RaspberryPi.
|
| 5 |
+
+ If you are doing cool things with this repo, tell me about it. I am always curious about what you are doing with this code or this models.
|
| 6 |
+
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
The DTLN model was handed in to the deep noise suppression challenge ([DNS-Challenge](https://github.com/microsoft/DNS-Challenge)) and the paper was presented at Interspeech 2020.
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
This approach combines a short-time Fourier transform (STFT) and a learned analysis and synthesis basis in a stacked-network approach with less than one million parameters. The model was trained on 500h of noisy speech provided by the challenge organizers. The network is capable of real-time processing (one frame in, one frame out) and reaches competitive results.
|
| 13 |
+
Combining these two types of signal transformations enables the DTLN to robustly extract information from magnitude spectra and incorporate phase information from the learned feature basis. The method shows state-of-the-art performance and outperforms the DNS-Challenge baseline by 0.24 points absolute in terms of the mean opinion score (MOS).
|
| 14 |
+
|
| 15 |
+
For more information see the [paper](https://www.isca-speech.org/archive/interspeech_2020/westhausen20_interspeech.html). The results of the DNS-Challenge are published [here](https://www.microsoft.com/en-us/research/academic-program/deep-noise-suppression-challenge-interspeech-2020/#!results). We reached a competitive 8th place out of 17 teams in the real time track.
|
| 16 |
+
|
| 17 |
+
---
|
| 18 |
+
|
| 19 |
+
For baseline usage and to reproduce the processing used for the paper run:
|
| 20 |
+
```bash
|
| 21 |
+
$ python run_evaluation.py -i in/folder/with/wav -o target/folder/processed/files -m ./pretrained_model/model.h5
|
| 22 |
+
```
|
| 23 |
+
|
| 24 |
+
---
|
| 25 |
+
|
| 26 |
+
The pretrained DTLN-aec (the DTLN applied to acoustic echo cancellation) can be found in the [DTLN-aec repository](https://github.com/breizhn/DTLN-aec).
|
| 27 |
+
|
| 28 |
+
---
|
| 29 |
+
|
| 30 |
+
Author: Nils L. Westhausen ([Communication Acoustics](https://uol.de/en/kommunikationsakustik) , Carl von Ossietzky University, Oldenburg, Germany)
|
| 31 |
+
|
| 32 |
+
This code is licensed under the terms of the MIT license.
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
---
|
| 36 |
+
### Citing:
|
| 37 |
+
|
| 38 |
+
If you are using the DTLN model, please cite:
|
| 39 |
+
|
| 40 |
+
```BibTex
|
| 41 |
+
@inproceedings{Westhausen2020,
|
| 42 |
+
author={Nils L. Westhausen and Bernd T. Meyer},
|
| 43 |
+
title={{Dual-Signal Transformation LSTM Network for Real-Time Noise Suppression}},
|
| 44 |
+
year=2020,
|
| 45 |
+
booktitle={Proc. Interspeech 2020},
|
| 46 |
+
pages={2477--2481},
|
| 47 |
+
doi={10.21437/Interspeech.2020-2631},
|
| 48 |
+
url={http://dx.doi.org/10.21437/Interspeech.2020-2631}
|
| 49 |
+
}
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
---
|
| 54 |
+
### Contents of the README:
|
| 55 |
+
|
| 56 |
+
* [Results](#results)
|
| 57 |
+
* [Execution Times](#execution-times)
|
| 58 |
+
* [Audio Samples](#audio-samples)
|
| 59 |
+
* [Contents of the repository](#contents-of-the-repository)
|
| 60 |
+
* [Python dependencies](#python-dependencies)
|
| 61 |
+
* [Training data preparation](#training-data-preparation)
|
| 62 |
+
* [Run a training of the DTLN model](#run-a-training-of-the-dtln-model)
|
| 63 |
+
* [Measuring the execution time of the DTLN model with the SavedModel format](#measuring-the-execution-time-of-the-dtln-model-with-the-savedmodel-format)
|
| 64 |
+
* [Real time processing with the SavedModel format](#real-time-processing-with-the-savedmodel-format)
|
| 65 |
+
* [Real time processing with tf-lite](#real-time-processing-with-tf-lite)
|
| 66 |
+
* [Real time audio with sounddevice and tf-lite](#real-time-audio-with-sounddevice-and-tf-lite)
|
| 67 |
+
* [Model conversion and real time processing with ONNX](#model-conversion-and-real-time-processing-with-onnx)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
---
|
| 72 |
+
### Results:
|
| 73 |
+
|
| 74 |
+
Results on the DNS-Challenge non reverberant test set:
|
| 75 |
+
Model | PESQ [mos] | STOI [%] | SI-SDR [dB] | TF version
|
| 76 |
+
--- | --- | --- | --- | ---
|
| 77 |
+
unprocessed | 2.45 | 91.52 | 9.07 |
|
| 78 |
+
NsNet (Baseline) | 2.70 | 90.56 | 12.57 |
|
| 79 |
+
| | | |
|
| 80 |
+
DTLN (500h) | 3.04 | 94.76 | 16.34 | 2.1
|
| 81 |
+
DTLN (500h)| 2.98 | 94.75 | 16.20 | TF-light
|
| 82 |
+
DTLN (500h) | 2.95 | 94.47 | 15.71 | TF-light quantized
|
| 83 |
+
| | | |
|
| 84 |
+
DTLN norm (500h) | 3.04 | 94.47 | 16.10 | 2.2
|
| 85 |
+
| | | |
|
| 86 |
+
DTLN norm (40h) | 3.05 | 94.57 | 16.88 | 2.2
|
| 87 |
+
DTLN norm (40h) | 2.98 | 94.56 | 16.58 | TF-light
|
| 88 |
+
DTLN norm (40h) | 2.98 | 94.51 | 16.22 | TF-light quantized
|
| 89 |
+
|
| 90 |
+
* The conversion to TF-light slightly reduces the performance.
|
| 91 |
+
* The dynamic range quantization of TF-light also reduces the performance a bit and introduces some quantization noise. But the audio-quality is still on a high level and the model is real-time capable on the Raspberry Pi 3 B+.
|
| 92 |
+
* The normalization of the log magnitude of the STFT does not decrease the model performance and makes it more robust against level variations.
|
| 93 |
+
* With data augmentation during training it is possible to train the DTLN model on just 40h of noise and speech data. If you have any question regarding this, just contact me.
|
| 94 |
+
|
| 95 |
+
[To contents](#contents-of-the-readme)
|
| 96 |
+
|
| 97 |
+
---
|
| 98 |
+
|
| 99 |
+
### Execution Times:
|
| 100 |
+
|
| 101 |
+
Execution times for SavedModel are measured with TF 2.2 and for TF-lite with the TF-lite runtime:
|
| 102 |
+
System | Processor | #Cores | SavedModel | TF-lite | TF-lite quantized
|
| 103 |
+
--- | --- | --- | --- | --- | ---
|
| 104 |
+
Ubuntu 18.04 | Intel I5 6600k @ 3.5 GHz | 4 | 0.65 ms | 0.36 ms | 0.27 ms
|
| 105 |
+
Macbook Air mid 2012 | Intel I7 3667U @ 2.0 GHz | 2 | 1.4 ms | 0.6 ms | 0.4 ms
|
| 106 |
+
Raspberry Pi 3 B+ | ARM Cortex A53 @ 1.4 GHz | 4 | 15.54 ms | 9.6 ms | 2.2 ms
|
| 107 |
+
|
| 108 |
+
For real-time capability the execution time must be below 8 ms.
|
| 109 |
+
|
| 110 |
+
[To contents](#contents-of-the-readme)
|
| 111 |
+
|
| 112 |
+
---
|
| 113 |
+
|
| 114 |
+
### Audio Samples:
|
| 115 |
+
|
| 116 |
+
Here some audio samples created with the tf-lite model. Sadly audio can not be integrated directly into markdown.
|
| 117 |
+
|
| 118 |
+
Noisy | Enhanced | Noise type
|
| 119 |
+
--- | --- | ---
|
| 120 |
+
[Sample 1](https://cloudsync.uol.de/s/GFHzmWWJAwgQPLf) | [Sample 1](https://cloudsync.uol.de/s/p3M48y7cjkJ2ZZg) | Air conditioning
|
| 121 |
+
[Sample 2](https://cloudsync.uol.de/s/4Y2PoSpJf7nXx9T) | [Sample 2](https://cloudsync.uol.de/s/QeK4aH5KCELPnko) | Music
|
| 122 |
+
[Sample 3](https://cloudsync.uol.de/s/Awc6oBtnTpb5pY7) | [Sample 3](https://cloudsync.uol.de/s/yNsmDgxH3MPWMTi) | Bus
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
[To contents](#contents-of-the-readme)
|
| 126 |
+
|
| 127 |
+
---
|
| 128 |
+
### Contents of the repository:
|
| 129 |
+
|
| 130 |
+
* **DTLN_model.py** \
|
| 131 |
+
This file is containing the model, data generator and the training routine.
|
| 132 |
+
* **run_training.py** \
|
| 133 |
+
Script to run the training. Before you can start the training with `$ python run_training.py`you have to set the paths to you training and validation data inside the script. The training script uses a default setup.
|
| 134 |
+
* **run_evaluation.py** \
|
| 135 |
+
Script to process a folder with optional subfolders containing .wav files with a trained DTLN model. With the pretrained model delivered with this repository a folder can be processed as following: \
|
| 136 |
+
`$ python run_evaluation.py -i /path/to/input -o /path/for/processed -m ./pretrained_model/model.h5` \
|
| 137 |
+
The evaluation script will create the new folder with the same structure as the input folder and the files will have the same name as the input files.
|
| 138 |
+
* **measure_execution_time.py** \
|
| 139 |
+
Script for measuring the execution time with the saved DTLN model in `./pretrained_model/dtln_saved_model/`. For further information see this [section](#measuring-the-execution-time-of-the-dtln-model-with-the-savedmodel-format).
|
| 140 |
+
* **real_time_processing.py** \
|
| 141 |
+
Script, which explains how real time processing with the SavedModel works. For more information see this [section](#real-time-processing-with-the-savedmodel-format).
|
| 142 |
+
+ **./pretrained_model/** \
|
| 143 |
+
* `model.h5`: Model weights as used in the DNS-Challenge DTLN model.
|
| 144 |
+
* `DTLN_norm_500h.h5`: Model weights trained on 500h with normalization of stft log magnitudes.
|
| 145 |
+
* `DTLN_norm_40h.h5`: Model weights trained on 40h with normalization of stft log magnitudes.
|
| 146 |
+
* `./dtln_saved_model`: same as `model.h5` but as a stateful model in SavedModel format.
|
| 147 |
+
* `./DTLN_norm_500h_saved_model`: same as `DTLN_norm_500h.h5` but as a stateful model in SavedModel format.
|
| 148 |
+
* `./DTLN_norm_40h_saved_model`: same as `DTLN_norm_40h.h5` but as a stateful model in SavedModel format.
|
| 149 |
+
* `model_1.tflite` together with `model_2.tflite`: same as `model.h5` but as TF-lite model with external state handling.
|
| 150 |
+
* `model_quant_1.tflite` together with `model_quant_2.tflite`: same as `model.h5` but as TF-lite model with external state handling and dynamic range quantization.
|
| 151 |
+
* `model_1.onnx` together with `model_2.onnx`: same as `model.h5` but as ONNX model with external state handling.
|
| 152 |
+
|
| 153 |
+
[To contents](#contents-of-the-readme)
|
| 154 |
+
|
| 155 |
+
---
|
| 156 |
+
### Python dependencies:
|
| 157 |
+
|
| 158 |
+
The following packages will be required for this repository:
|
| 159 |
+
* TensorFlow (2.x)
|
| 160 |
+
* librosa
|
| 161 |
+
* wavinfo
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
All additional packages (numpy, soundfile, etc.) should be installed on the fly when using conda or pip. I recommend using conda environments or [pyenv](https://github.com/pyenv/pyenv) [virtualenv](https://github.com/pyenv/pyenv-virtualenv) for the python environment. For training a GPU with at least 5 GB of memory is required. I recommend at least Tensorflow 2.1 with Nvidia driver 418 and Cuda 10.1. If you use conda Cuda will be installed on the fly and you just need the driver. For evaluation-only the CPU version of Tensorflow is enough. Everything was tested on Ubuntu 18.04.
|
| 165 |
+
|
| 166 |
+
Conda environments for training (with cuda) and for evaluation (CPU only) can be created as following:
|
| 167 |
+
|
| 168 |
+
For the training environment:
|
| 169 |
+
```shell
|
| 170 |
+
$ conda env create -f train_env.yml
|
| 171 |
+
```
|
| 172 |
+
For the evaluation environment:
|
| 173 |
+
```
|
| 174 |
+
$ conda env create -f eval_env.yml
|
| 175 |
+
```
|
| 176 |
+
For the tf-lite environment:
|
| 177 |
+
```
|
| 178 |
+
$ conda env create -f tflite_env.yml
|
| 179 |
+
```
|
| 180 |
+
The tf-lite runtime must be downloaded from [here](https://www.tensorflow.org/lite/guide/python).
|
| 181 |
+
|
| 182 |
+
[To contents](#contents-of-the-readme)
|
| 183 |
+
|
| 184 |
+
---
|
| 185 |
+
### Training data preparation:
|
| 186 |
+
|
| 187 |
+
1. Clone the forked DNS-Challenge [repository](https://github.com/breizhn/DNS-Challenge). Before cloning the repository make sure `git-lfs` is installed. Also make sure your disk has enough space. I recommend downloading the data to an SSD for faster dataset creation.
|
| 188 |
+
|
| 189 |
+
2. Run `noisyspeech_synthesizer_multiprocessing.py` to create the dataset. `noisyspeech_synthesizer.cfg`was changed according to my training setup used for the DNS-Challenge.
|
| 190 |
+
|
| 191 |
+
3. Run `split_dns_corpus.py`to divide the dataset in training and validation data. The classic 80:20 split is applied. This file was added to the forked repository by me.
|
| 192 |
+
|
| 193 |
+
[To contents](#contents-of-the-readme)
|
| 194 |
+
|
| 195 |
+
---
|
| 196 |
+
### Run a training of the DTLN model:
|
| 197 |
+
|
| 198 |
+
1. Make sure all dependencies are installed in your python environment.
|
| 199 |
+
|
| 200 |
+
2. Change the paths to your training and validation dataset in `run_training.py`.
|
| 201 |
+
|
| 202 |
+
3. Run `$ python run_training.py`.
|
| 203 |
+
|
| 204 |
+
One epoch takes around 21 minutes on a Nvidia RTX 2080 Ti when loading the training data from an SSD.
|
| 205 |
+
|
| 206 |
+
[To contents](#contents-of-the-readme)
|
| 207 |
+
|
| 208 |
+
---
|
| 209 |
+
### Measuring the execution time of the DTLN model with the SavedModel format:
|
| 210 |
+
|
| 211 |
+
In total there are three ways to measure the execution time for one block of the model: Running a sequence in Keras and dividing by the number of blocks in the sequence, building a stateful model in Keras and running block by block, and saving the stateful model in Tensorflow's SavedModel format and calling that one block by block. In the following I will explain how running the model in the SavedModel format, because it is the most portable version and can also be called from Tensorflow Serving.
|
| 212 |
+
|
| 213 |
+
A Keras model can be saved to the saved model format:
|
| 214 |
+
```python
|
| 215 |
+
import tensorflow as tf
|
| 216 |
+
'''
|
| 217 |
+
Building some model here
|
| 218 |
+
'''
|
| 219 |
+
tf.saved_model.save(your_keras_model, 'name_save_path')
|
| 220 |
+
```
|
| 221 |
+
Important here for real time block by block processing is, to make the LSTM layer stateful, so they can remember the states from the previous block.
|
| 222 |
+
|
| 223 |
+
The model can be imported with
|
| 224 |
+
```python
|
| 225 |
+
model = tf.saved_model.load('name_save_path')
|
| 226 |
+
```
|
| 227 |
+
|
| 228 |
+
For inference we now first call this for mapping signature names to functions
|
| 229 |
+
```python
|
| 230 |
+
infer = model.signatures['serving_default']
|
| 231 |
+
```
|
| 232 |
+
|
| 233 |
+
and now for inferring the block `x` call
|
| 234 |
+
```python
|
| 235 |
+
y = infer(tf.constant(x))['conv1d_1']
|
| 236 |
+
```
|
| 237 |
+
This command gives you the result on the node `'conv1d_1'`which is our output node for real time processing. For more information on using the SavedModel format and obtaining the output node see this [Guide](https://www.tensorflow.org/guide/saved_model).
|
| 238 |
+
|
| 239 |
+
For making everything easier this repository provides a stateful DTLN SavedModel.
|
| 240 |
+
For measuring the execution time call:
|
| 241 |
+
```
|
| 242 |
+
$ python measure_execution_time.py
|
| 243 |
+
```
|
| 244 |
+
|
| 245 |
+
[To contents](#contents-of-the-readme)
|
| 246 |
+
|
| 247 |
+
---
|
| 248 |
+
|
| 249 |
+
### Real time processing with the SavedModel format:
|
| 250 |
+
|
| 251 |
+
For explanation look at `real_time_processing.py`.
|
| 252 |
+
|
| 253 |
+
Here some consideration for integrating this model in your project:
|
| 254 |
+
* The sampling rate of this model is fixed at 16 kHz. It will not work smoothly with other sampling rates.
|
| 255 |
+
* The block length of 32 ms and the block shift of 8 ms are also fixed. For changing these values, the model must be retrained.
|
| 256 |
+
* The delay created by the model is the block length, so the input-output delay is 32 ms.
|
| 257 |
+
* For real time capability on your system, the execution time must be below the length of the block shift, so below 8 ms.
|
| 258 |
+
* I cannot give you support on the hardware side regarding sound cards, drivers and so on. Be aware that a lot of artifacts can come from this side.
|
| 259 |
+
|
| 260 |
+
[To contents](#contents-of-the-readme)
|
| 261 |
+
|
| 262 |
+
---
|
| 263 |
+
### Real time processing with tf-lite:
|
| 264 |
+
|
| 265 |
+
With TF 2.3 it is finally possible to convert LSTMs to tf-lite. It is still not perfect, because the states must be handled separately for a stateful model and tf-lite does not support complex numbers. That means the model is split into two submodels when converting it to tf-lite, and the calculation of the FFT and iFFT is performed outside the model. I provided an example script explaining how real time processing with the tf-lite model works (```real_time_processing_tf_lite.py```). In this script the tf-lite runtime is used. The runtime can be downloaded [here](https://www.tensorflow.org/lite/guide/python). Quantization works now.
|
| 266 |
+
|
| 267 |
+
Using the tf-lite DTLN model and the tf-lite runtime the execution time on an old Macbook Air mid 2012 can be decreased to **0.6 ms**.
|
| 268 |
+
|
| 269 |
+
[To contents](#contents-of-the-readme)
|
| 270 |
+
|
| 271 |
+
---
|
| 272 |
+
### Real time audio with sounddevice and tf-lite:
|
| 273 |
+
|
| 274 |
+
The file ```real_time_dtln_audio.py``` is an example of how real time audio with the tf-lite model and the [sounddevice](https://github.com/spatialaudio/python-sounddevice) toolbox can be implemented. The script is based on the ```wire.py``` example. It works fine on an old Macbook Air mid 2012, so it will probably run on most newer devices. In the quantized version it was successfully tested on a Raspberry Pi 3 B+.
|
| 275 |
+
|
| 276 |
+
First check for your audio devices:
|
| 277 |
+
```
|
| 278 |
+
$ python real_time_dtln_audio.py --list-devices
|
| 279 |
+
```
|
| 280 |
+
Choose the index of an input and an output device and call:
|
| 281 |
+
```
|
| 282 |
+
$ python real_time_dtln_audio.py -i in_device_idx -o out_device_idx
|
| 283 |
+
```
|
| 284 |
+
If the script is showing too many ```input underflow``` messages, restart the script. If that does not help, increase the latency with the ```--latency``` option. The default value is 0.2.
|
| 285 |
+
|
| 286 |
+
[To contents](#contents-of-the-readme)
|
| 287 |
+
|
| 288 |
+
---
|
| 289 |
+
### Model conversion and real time processing with ONNX:
|
| 290 |
+
|
| 291 |
+
Finally I got the ONNX model working.
|
| 292 |
+
For converting the model TF 2.1 and keras2onnx is required. keras2onnx can be downloaded [here](https://github.com/onnx/keras-onnx) and must be installed from source as described in the README. When all dependencies are installed, call:
|
| 293 |
+
```
|
| 294 |
+
$ python convert_weights_to_onnx.py -m /name/of/the/model.h5 -t onnx_model_name
|
| 295 |
+
```
|
| 296 |
+
to convert the model to the ONNX format. The model is split in two parts as for the TF-lite model. The conversion does not work on MacOS.
|
| 297 |
+
The real time processing works similar to the TF-lite model and can be looked up in following file: ```real_time_processing_onnx.py ```
|
| 298 |
+
The ONNX runtime required for this script can be installed with:
|
| 299 |
+
```
|
| 300 |
+
$ pip install onnxruntime
|
| 301 |
+
```
|
| 302 |
+
The execution time on the Macbook Air mid 2012 is around 1.13 ms for one block.
|
models/DTLN (alekya)/convert_weights_to_onnx.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Script to convert a .h5 weights file of the DTLN model to ONNX.
|
| 4 |
+
At the moment the conversion only works with TF 2.1 and not on Mac.
|
| 5 |
+
|
| 6 |
+
Example call:
|
| 7 |
+
$python convert_weights_to_ONNX.py -m /name/of/the/model.h5 \
|
| 8 |
+
-t name_target
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
Author: Nils L. Westhausen (nils.westhausen@uol.de)
|
| 12 |
+
Version: 03.07.2020
|
| 13 |
+
|
| 14 |
+
This code is licensed under the terms of the MIT-license.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from DTLN_model import DTLN_model, InstantLayerNormalization
|
| 18 |
+
import argparse
|
| 19 |
+
from tensorflow.keras.models import Model
|
| 20 |
+
from tensorflow.keras.layers import Input, Multiply, Conv1D
|
| 21 |
+
import tensorflow as tf
|
| 22 |
+
import keras2onnx
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
if __name__ == '__main__':
    # command line interface: .h5 weights file in, two ONNX files out
    parser = argparse.ArgumentParser(description='data evaluation')
    parser.add_argument('--weights_file', '-m',
                        help='path to .h5 weights file')
    parser.add_argument('--target_folder', '-t',
                        help='target folder for saved model')
    cli_args = parser.parse_args()
    weights_file = cli_args.weights_file

    dtln_class = DTLN_model()
    # '_norm_' in the filename marks models trained with stft log
    # magnitude normalization; those carry two extra norm weights in
    # the first core.
    if weights_file.find('_norm_') != -1:
        norm_stft = True
        num_elements_first_core = 2 + dtln_class.numLayer * 3 + 2
    else:
        norm_stft = False
        num_elements_first_core = dtln_class.numLayer * 3 + 2

    # rebuild the stateful model and load the trained weights
    dtln_class.build_DTLN_model_stateful(norm_stft=norm_stft)
    dtln_class.model.load_weights(weights_file)

    #### Model 1: magnitude mask estimation #################
    mag = Input(batch_shape=(1, 1, (dtln_class.blockLen // 2 + 1)))
    states_in_1 = Input(batch_shape=(1, dtln_class.numLayer,
                                     dtln_class.numUnits, 2))
    if norm_stft:
        # normalize log magnitude stfts for robustness to level changes
        mag_norm = InstantLayerNormalization()(tf.math.log(mag + 1e-7))
    else:
        # behaviour like in the paper
        mag_norm = mag
    # predict the mask with the first separation kernel
    mask_1, states_out_1 = dtln_class.seperation_kernel_with_states(
        dtln_class.numLayer,
        (dtln_class.blockLen // 2 + 1),
        mag_norm, states_in_1)
    model_1 = Model(inputs=[mag, states_in_1],
                    outputs=[mask_1, states_out_1])

    #### Model 2: learned feature domain enhancement ########
    estimated_frame_1 = Input(batch_shape=(1, 1, (dtln_class.blockLen)))
    states_in_2 = Input(batch_shape=(1, dtln_class.numLayer,
                                     dtln_class.numUnits, 2))
    # encode time domain frames into the learned feature domain
    encoded_frames = Conv1D(dtln_class.encoder_size, 1, strides=1,
                            use_bias=False)(estimated_frame_1)
    # normalize before the second separation kernel
    encoded_frames_norm = InstantLayerNormalization()(encoded_frames)
    mask_2, states_out_2 = dtln_class.seperation_kernel_with_states(
        dtln_class.numLayer,
        dtln_class.encoder_size,
        encoded_frames_norm,
        states_in_2)
    # apply the mask and decode back to the time domain
    estimated = Multiply()([encoded_frames, mask_2])
    decoded_frame = Conv1D(dtln_class.blockLen, 1, padding='causal',
                           use_bias=False)(estimated)
    model_2 = Model(inputs=[estimated_frame_1, states_in_2],
                    outputs=[decoded_frame, states_out_2])

    # distribute the trained weights over the two submodels
    weights = dtln_class.model.get_weights()
    model_1.set_weights(weights[:num_elements_first_core])
    model_2.set_weights(weights[num_elements_first_core:])

    # convert first model
    onnx_model = keras2onnx.convert_keras(model_1)
    temp_model_file = cli_args.target_folder + '_1.onnx'
    keras2onnx.save_model(onnx_model, temp_model_file)
    # convert second model
    onnx_model = keras2onnx.convert_keras(model_2)
    temp_model_file = cli_args.target_folder + '_2.onnx'
    keras2onnx.save_model(onnx_model, temp_model_file)

    print('ONNX conversion complete!')
|
models/DTLN (alekya)/convert_weights_to_saved_model.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Script to convert a .h5 weights file of the DTLN model to the saved model format.
|
| 3 |
+
|
| 4 |
+
Example call:
|
| 5 |
+
$python convert_weights_to_saved_model.py -m /name/of/the/model.h5 \
|
| 6 |
+
-t name_target_folder
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
Author: Nils L. Westhausen (nils.westhausen@uol.de)
|
| 10 |
+
Version: 24.06.2020
|
| 11 |
+
|
| 12 |
+
This code is licensed under the terms of the MIT-license.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from DTLN_model import DTLN_model
|
| 16 |
+
import argparse
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
if __name__ == '__main__':
    # Argument parser for running the conversion directly from the command line.
    arg_parser = argparse.ArgumentParser(description='data evaluation')
    arg_parser.add_argument('--weights_file', '-m',
                            help='path to .h5 weights file')
    arg_parser.add_argument('--target_folder', '-t',
                            help='target folder for saved model')
    cli_args = arg_parser.parse_args()

    # Build a DTLN instance and export the weights in SavedModel format.
    model_converter = DTLN_model()
    model_converter.create_saved_model(cli_args.weights_file,
                                       cli_args.target_folder)
|
models/DTLN (alekya)/convert_weights_to_tf_lite.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Script to convert a .h5 weights file of the DTLN model to tf lite.
|
| 3 |
+
|
| 4 |
+
Example call:
|
| 5 |
+
$python convert_weights_to_tf_light.py -m /name/of/the/model.h5 \
|
| 6 |
+
-t name_target
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
Author: Nils L. Westhausen (nils.westhausen@uol.de)
|
| 10 |
+
Version: 30.06.2020
|
| 11 |
+
|
| 12 |
+
This code is licensed under the terms of the MIT-license.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from DTLN_model import DTLN_model
|
| 16 |
+
import argparse
|
| 17 |
+
from pkg_resources import parse_version
|
| 18 |
+
import tensorflow as tf
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
if __name__ == '__main__':
|
| 22 |
+
# arguement parser for running directly from the command line
|
| 23 |
+
parser = argparse.ArgumentParser(description='data evaluation')
|
| 24 |
+
parser.add_argument('--weights_file', '-m',
|
| 25 |
+
help='path to .h5 weights file')
|
| 26 |
+
parser.add_argument('--target_folder', '-t',
|
| 27 |
+
help='target folder for saved model')
|
| 28 |
+
parser.add_argument('--quantization', '-q',
|
| 29 |
+
help='use quantization (True/False)',
|
| 30 |
+
default='False')
|
| 31 |
+
|
| 32 |
+
args = parser.parse_args()
|
| 33 |
+
if parse_version(tf.__version__) < parse_version('2.3.0-rc0'):
|
| 34 |
+
raise ValueError('Tf version < 2.3. Conversion of LSTMs will not work'+
|
| 35 |
+
+' with older tensorflow versions')
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
converter = DTLN_model()
|
| 39 |
+
converter.create_tf_lite_model(args.weights_file,
|
| 40 |
+
args.target_folder,
|
| 41 |
+
use_dynamic_range_quant=bool(args.quantization))
|
models/DTLN (alekya)/eval_env.yml
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: eval_env
|
| 2 |
+
channels:
|
| 3 |
+
- conda-forge
|
| 4 |
+
- defaults
|
| 5 |
+
dependencies:
|
| 6 |
+
- _libgcc_mutex=0.1=main
|
| 7 |
+
- _tflow_select=2.3.0=mkl
|
| 8 |
+
- absl-py=0.9.0=py37_0
|
| 9 |
+
- astunparse=1.6.3=py_0
|
| 10 |
+
- audioread=2.1.8=py37hc8dfbb8_2
|
| 11 |
+
- blas=1.0=mkl
|
| 12 |
+
- blinker=1.4=py37_0
|
| 13 |
+
- brotlipy=0.7.0=py37h7b6447c_1000
|
| 14 |
+
- bzip2=1.0.8=h516909a_2
|
| 15 |
+
- c-ares=1.15.0=h7b6447c_1001
|
| 16 |
+
- ca-certificates=2020.4.5.2=hecda079_0
|
| 17 |
+
- cachetools=4.1.0=py_1
|
| 18 |
+
- certifi=2020.4.5.2=py37hc8dfbb8_0
|
| 19 |
+
- cffi=1.14.0=py37he30daa8_1
|
| 20 |
+
- chardet=3.0.4=py37_1003
|
| 21 |
+
- click=7.1.2=py_0
|
| 22 |
+
- cryptography=2.9.2=py37h1ba5d50_0
|
| 23 |
+
- cycler=0.10.0=py_2
|
| 24 |
+
- decorator=4.4.2=py_0
|
| 25 |
+
- ffmpeg=4.2.3=h167e202_0
|
| 26 |
+
- freetype=2.10.2=he06d7ca_0
|
| 27 |
+
- gast=0.3.3=py_0
|
| 28 |
+
- gettext=0.19.8.1=h5e8e0c9_1
|
| 29 |
+
- gmp=6.2.0=he1b5a44_2
|
| 30 |
+
- gnutls=3.6.13=h79a8f9a_0
|
| 31 |
+
- google-auth=1.14.1=py_0
|
| 32 |
+
- google-auth-oauthlib=0.4.1=py_2
|
| 33 |
+
- google-pasta=0.2.0=py_0
|
| 34 |
+
- grpcio=1.27.2=py37hf8bcb03_0
|
| 35 |
+
- h5py=2.10.0=py37h7918eee_0
|
| 36 |
+
- hdf5=1.10.4=hb1b8bf9_0
|
| 37 |
+
- icu=58.2=hf484d3e_1000
|
| 38 |
+
- idna=2.9=py_1
|
| 39 |
+
- intel-openmp=2020.1=217
|
| 40 |
+
- joblib=0.15.1=py_0
|
| 41 |
+
- keras-preprocessing=1.1.0=py_1
|
| 42 |
+
- kiwisolver=1.2.0=py37h99015e2_0
|
| 43 |
+
- lame=3.100=h14c3975_1001
|
| 44 |
+
- ld_impl_linux-64=2.33.1=h53a641e_7
|
| 45 |
+
- libedit=3.1.20181209=hc058e9b_0
|
| 46 |
+
- libffi=3.3=he6710b0_1
|
| 47 |
+
- libflac=1.3.3=he1b5a44_0
|
| 48 |
+
- libgcc-ng=9.1.0=hdf63c60_0
|
| 49 |
+
- libgfortran-ng=7.3.0=hdf63c60_0
|
| 50 |
+
- libiconv=1.15=h516909a_1006
|
| 51 |
+
- libllvm8=8.0.1=hc9558a2_0
|
| 52 |
+
- libogg=1.3.2=h516909a_1002
|
| 53 |
+
- libpng=1.6.37=hed695b0_1
|
| 54 |
+
- libprotobuf=3.12.3=hd408876_0
|
| 55 |
+
- librosa=0.7.2=py_1
|
| 56 |
+
- libsndfile=1.0.28=he1b5a44_1000
|
| 57 |
+
- libstdcxx-ng=9.1.0=hdf63c60_0
|
| 58 |
+
- libvorbis=1.3.6=he1b5a44_2
|
| 59 |
+
- llvmlite=0.31.0=py37h5202443_1
|
| 60 |
+
- markdown=3.1.1=py37_0
|
| 61 |
+
- matplotlib-base=3.2.1=py37hef1b27d_0
|
| 62 |
+
- mkl=2020.1=217
|
| 63 |
+
- mkl-service=2.3.0=py37he904b0f_0
|
| 64 |
+
- mkl_fft=1.1.0=py37h23d657b_0
|
| 65 |
+
- mkl_random=1.1.1=py37h0573a6f_0
|
| 66 |
+
- ncurses=6.2=he6710b0_1
|
| 67 |
+
- nettle=3.4.1=h1bed415_1002
|
| 68 |
+
- numba=0.48.0=py37hb3f55d8_0
|
| 69 |
+
- numpy=1.18.1=py37h4f9e942_0
|
| 70 |
+
- numpy-base=1.18.1=py37hde5b4d6_1
|
| 71 |
+
- oauthlib=3.1.0=py_0
|
| 72 |
+
- openh264=2.1.1=h8b12597_0
|
| 73 |
+
- openssl=1.1.1g=h516909a_0
|
| 74 |
+
- opt_einsum=3.1.0=py_0
|
| 75 |
+
- pip=20.1.1=py37_1
|
| 76 |
+
- protobuf=3.12.3=py37he6710b0_0
|
| 77 |
+
- pyasn1=0.4.8=py_0
|
| 78 |
+
- pyasn1-modules=0.2.7=py_0
|
| 79 |
+
- pycparser=2.20=py_0
|
| 80 |
+
- pyjwt=1.7.1=py37_0
|
| 81 |
+
- pyopenssl=19.1.0=py37_0
|
| 82 |
+
- pyparsing=2.4.7=pyh9f0ad1d_0
|
| 83 |
+
- pysocks=1.7.1=py37_0
|
| 84 |
+
- pysoundfile=0.10.2=py_1001
|
| 85 |
+
- python=3.7.7=hcff3b4d_5
|
| 86 |
+
- python-dateutil=2.8.1=py_0
|
| 87 |
+
- python_abi=3.7=1_cp37m
|
| 88 |
+
- readline=8.0=h7b6447c_0
|
| 89 |
+
- requests=2.23.0=py37_0
|
| 90 |
+
- requests-oauthlib=1.3.0=py_0
|
| 91 |
+
- resampy=0.2.2=py_0
|
| 92 |
+
- rsa=4.0=py_0
|
| 93 |
+
- scikit-learn=0.22.1=py37hd81dba3_0
|
| 94 |
+
- scipy=1.4.1=py37h0b6359f_0
|
| 95 |
+
- setuptools=47.3.0=py37_0
|
| 96 |
+
- six=1.15.0=py_0
|
| 97 |
+
- sqlite=3.31.1=h62c20be_1
|
| 98 |
+
- tensorboard=2.2.1=pyh532a8cf_0
|
| 99 |
+
- tensorboard-plugin-wit=1.6.0=py_0
|
| 100 |
+
- tensorflow=2.2.0=mkl_py37h6e9ce2d_0
|
| 101 |
+
- tensorflow-base=2.2.0=mkl_py37hd506778_0
|
| 102 |
+
- tensorflow-estimator=2.2.0=pyh208ff02_0
|
| 103 |
+
- termcolor=1.1.0=py37_1
|
| 104 |
+
- tk=8.6.8=hbc83047_0
|
| 105 |
+
- tornado=6.0.4=py37h8f50634_1
|
| 106 |
+
- urllib3=1.25.9=py_0
|
| 107 |
+
- werkzeug=1.0.1=py_0
|
| 108 |
+
- wheel=0.34.2=py37_0
|
| 109 |
+
- wrapt=1.12.1=py37h7b6447c_1
|
| 110 |
+
- x264=1!152.20180806=h14c3975_0
|
| 111 |
+
- xz=5.2.5=h7b6447c_0
|
| 112 |
+
- zlib=1.2.11=h7b6447c_3
|
| 113 |
+
- pip:
|
| 114 |
+
- lxml==4.5.1
|
| 115 |
+
- wavinfo==1.5
|
| 116 |
+
prefix: /home/nils/anaconda3/envs/eval_env
|
| 117 |
+
|
models/DTLN (alekya)/main.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
Script to process a folder of .wav files with a trained DTLN model.
|
| 4 |
+
This script supports subfolders and names the processed files the same as the
|
| 5 |
+
original. The model expects 16kHz audio .wav files. Files with other
|
| 6 |
+
sampling rates will be resampled. Stereo files will be downmixed to mono.
|
| 7 |
+
|
| 8 |
+
The idea of this script is to use it for baseline or comparison purpose.
|
| 9 |
+
|
| 10 |
+
Example call:
|
| 11 |
+
$python run_evaluation.py -i /name/of/input/folder \
|
| 12 |
+
-o /name/of/output/folder \
|
| 13 |
+
-m /name/of/the/model.h5
|
| 14 |
+
|
| 15 |
+
Author: Nils L. Westhausen (nils.westhausen@uol.de)
|
| 16 |
+
Version: 13.05.2020
|
| 17 |
+
|
| 18 |
+
This code is licensed under the terms of the MIT-license.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import soundfile as sf
|
| 22 |
+
import librosa
|
| 23 |
+
import numpy as np
|
| 24 |
+
import os
|
| 25 |
+
import argparse
|
| 26 |
+
from DTLN_model import DTLN_model
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def process_file(model, audio_file_name, out_file_name):
    """Read an audio file, enhance it with the network and write the
    enhanced audio to a .wav file.

    Parameters
    ----------
    model : Keras model
        Keras model, which accepts audio in the size (1,timesteps).
    audio_file_name : STRING
        Name and path of the input audio file.
    out_file_name : STRING
        Name and path of the target file.
    """
    pad_len = 384  # zero padding (samples) added on each side of the signal
    # librosa handles resampling to 16 kHz and downmixing to mono
    audio, sample_rate = librosa.core.load(audio_file_name, sr=16000,
                                           mono=True)
    num_samples = len(audio)
    # pad the signal on both ends before feeding it to the network
    padding = np.zeros(pad_len)
    padded = np.concatenate((padding, audio, padding), axis=0)
    # run the model on a single batch of shape (1, timesteps)
    batch = np.expand_dims(padded, axis=0).astype(np.float32)
    enhanced = np.squeeze(model.predict_on_batch(batch))
    # strip the padding again so the output matches the input length
    enhanced = enhanced[pad_len:pad_len + num_samples]
    # write the result to the target destination
    sf.write(out_file_name, enhanced, sample_rate)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def process_folder(model, folder_name, new_folder_name):
    '''
    Function to find .wav files in the folder and subfolders of "folder_name",
    process each .wav file with an algorithm and write it back to disk in the
    folder "new_folder_name". The structure of the original directory is
    preserved. The processed files will be saved with the same name as the
    original file.

    Parameters
    ----------
    model : Keras model
        Keras model, which accepts audio in the size (1,timesteps).
    folder_name : STRING
        Input folder with .wav files.
    new_folder_name : STRING
        Target folder for the processed files.

    '''
    # lists of file names and their source/target directories
    file_names = []
    directories = []
    new_directories = []
    # walk through the directory tree
    for root, dirs, files in os.walk(folder_name):
        # mirror the source directory in the target folder
        # (hoisted: the original recomputed this replace three times per file)
        new_root = root.replace(folder_name, new_folder_name)
        for file in files:
            # look for .wav files
            if file.endswith(".wav"):
                # record paths and filenames
                file_names.append(file)
                directories.append(root)
                new_directories.append(new_root)
                # create the target directory if it does not exist yet
                os.makedirs(new_root, exist_ok=True)
    # process each collected .wav file with the model
    for src_dir, dst_dir, file_name in zip(directories, new_directories,
                                           file_names):
        process_file(model, os.path.join(src_dir, file_name),
                     os.path.join(dst_dir, file_name))
        print(file_name + ' processed successfully!')
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# if __name__ == '__main__':
|
| 110 |
+
def run_model(model, in_folder, out_folder):
    """Process all .wav files in *in_folder* with a trained DTLN model and
    write the enhanced files to *out_folder*.

    Parameters
    ----------
    model : STRING
        Path to the .h5 weights file of the enhancement model.
    in_folder : STRING
        Folder with input .wav files (subfolders are supported).
    out_folder : STRING
        Target folder for the processed files.
    """
    # determine type of model: weights trained with STFT normalization carry
    # '_norm_' in their filename
    norm_stft = '_norm_' in model
    # create class instance
    modelClass = DTLN_model()
    # build the model in default configuration
    modelClass.build_DTLN_model(norm_stft=norm_stft)
    # load weights of the .h5 file
    modelClass.model.load_weights(model)
    # process the folder
    process_folder(modelClass.model, in_folder, out_folder)
|
| 133 |
+
|
| 134 |
+
|
models/DTLN (alekya)/main_1.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from main import run_model

# Enhance every .wav file found under "test" with the pretrained DTLN
# weights and write the processed files to "./output_test_2".
run_model("./pretrained_model/model.h5", "test", "./output_test_2")
|
models/DTLN (alekya)/measure_execution_time.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
This script tests the execution time of the DTLN model on a CPU.
|
| 5 |
+
Please use TF 2.2 for comparability.
|
| 6 |
+
|
| 7 |
+
Just run "python measure_execution_time.py"
|
| 8 |
+
|
| 9 |
+
Author: Nils L. Westhausen (nils.westhausen@uol.de)
|
| 10 |
+
Version: 13.05.2020
|
| 11 |
+
|
| 12 |
+
This code is licensed under the terms of the MIT-license.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import time
|
| 16 |
+
import tensorflow as tf
|
| 17 |
+
import numpy as np
|
| 18 |
+
import os
|
| 19 |
+
|
| 20 |
+
# only use the cpu
|
| 21 |
+
os.environ["CUDA_VISIBLE_DEVICES"]=''
|
| 22 |
+
|
| 23 |
+
if __name__ == '__main__':
    # load the model exported in SavedModel format
    saved_model = tf.saved_model.load('./pretrained_model/dtln_saved_model')
    # look up the default serving signature
    infer = saved_model.signatures["serving_default"]

    timings = []
    # random dummy block used as network input
    dummy_block = np.random.randn(1, 512).astype('float32')
    for _ in range(1010):
        # time a single-block inference
        tic = time.time()
        _ = infer(tf.constant(dummy_block))['conv1d_1']
        timings.append(time.time() - tic)
    # average over all but the first ten (warm-up) iterations
    mean_ms = np.round(np.mean(np.stack(timings[10:])) * 1000, 2)
    print('Execution time per block: ' + str(mean_ms) + ' ms')

    # Ubuntu 18.04 I5 6600k @ 3.5 GHz: 0.65 ms (4 cores)
    # Macbook Air mid 2012 I7 3667U @ 2.0 GHz: 1.4 ms (2 cores)
    # Raspberry Pi 3 B+ ARM Cortex A53 @ 1.4 GHz: 15.54 (4 cores)
|
models/DTLN (alekya)/model.h5
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b6e9ecb8140f51e8ffb373884b62e186dc166e97995fa7eea3fffcc6ef30d55d
|
| 3 |
+
size 3989312
|
models/DTLN (alekya)/real_time_dtln_audio.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
This is a real time example how to implement DTLN tf light model with
|
| 5 |
+
sounddevice. The script is based on the "wire.py" example of the sounddevice
|
| 6 |
+
toolbox. If the command line shows "input underflow", restart the script.
|
| 7 |
+
If there are still a lot of dropouts, increase the latency.
|
| 8 |
+
|
| 9 |
+
First call:
|
| 10 |
+
|
| 11 |
+
$ python real_time_dtln_audio.py --list-devices
|
| 12 |
+
|
| 13 |
+
to get your audio devices. In the next step call
|
| 14 |
+
|
| 15 |
+
$ python real_time_dtln_audio.py -i in_device_idx -o out_device_idx
|
| 16 |
+
|
| 17 |
+
For .whl files of the tf light runtime go to:
|
| 18 |
+
https://www.tensorflow.org/lite/guide/python
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
Author: Nils L. Westhausen (nils.westhausen@uol.de)
|
| 23 |
+
Version: 01.07.2020
|
| 24 |
+
|
| 25 |
+
This code is licensed under the terms of the MIT-license.
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
import numpy as np
|
| 30 |
+
import sounddevice as sd
|
| 31 |
+
import tflite_runtime.interpreter as tflite
|
| 32 |
+
import argparse
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def int_or_str(text):
    """Return *text* as an int when it parses as one, otherwise unchanged.

    Used as an argparse ``type`` so devices can be given either as a
    numeric ID or as a name substring.
    """
    try:
        value = int(text)
    except ValueError:
        value = text
    return value
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# First pass: parse only "--list-devices" so it can act before the full
# parser sees (and possibly rejects) the remaining arguments.
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument(
    '-l', '--list-devices', action='store_true',
    help='show list of audio devices and exit')
args, remaining = parser.parse_known_args()
if args.list_devices:
    print(sd.query_devices())
    parser.exit(0)
# Second pass: full parser, inheriting the option above via parents=[...].
parser = argparse.ArgumentParser(
    description=__doc__,
    formatter_class=argparse.RawDescriptionHelpFormatter,
    parents=[parser])
parser.add_argument(
    '-i', '--input-device', type=int_or_str,
    help='input device (numeric ID or substring)')
parser.add_argument(
    '-o', '--output-device', type=int_or_str,
    help='output device (numeric ID or substring)')

parser.add_argument('--latency', type=float, help='latency in seconds', default=0.2)
args = parser.parse_args(remaining)

# set some parameters (fixed by how the model was trained)
block_len_ms = 32
block_shift_ms = 8
fs_target = 16000
# create the two TF-Lite interpreters (mask core and time-domain core)
interpreter_1 = tflite.Interpreter(model_path='./pretrained_model/model_1.tflite')
interpreter_1.allocate_tensors()
interpreter_2 = tflite.Interpreter(model_path='./pretrained_model/model_2.tflite')
interpreter_2.allocate_tensors()
# Get input and output tensors.
input_details_1 = interpreter_1.get_input_details()
output_details_1 = interpreter_1.get_output_details()
input_details_2 = interpreter_2.get_input_details()
output_details_2 = interpreter_2.get_output_details()
# create zero-initialized states for the LSTMs; input index 1 of each
# interpreter carries the recurrent state
states_1 = np.zeros(input_details_1[1]['shape']).astype('float32')
states_2 = np.zeros(input_details_2[1]['shape']).astype('float32')
# convert block length and shift from milliseconds to samples
block_shift = int(np.round(fs_target * (block_shift_ms / 1000)))
block_len = int(np.round(fs_target * (block_len_ms / 1000)))
# sliding analysis buffer and overlap-add synthesis buffer
in_buffer = np.zeros((block_len)).astype('float32')
out_buffer = np.zeros((block_len)).astype('float32')
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def callback(indata, outdata, frames, time, status):
    """sounddevice stream callback: denoise one block in real time.

    Maintains a sliding analysis window in ``in_buffer`` and an
    overlap-add synthesis buffer in ``out_buffer``; the LSTM states of
    both interpreters persist between invocations via module globals.
    """
    # buffers and states must survive across callback invocations
    global in_buffer, out_buffer, states_1, states_2
    if status:
        # report stream status (e.g. input underflow) from the backend
        print(status)
    # shift the analysis window left and append the new input samples
    in_buffer[:-block_shift] = in_buffer[block_shift:]
    in_buffer[-block_shift:] = np.squeeze(indata)
    # calculate fft of input block
    in_block_fft = np.fft.rfft(in_buffer)
    in_mag = np.abs(in_block_fft)
    in_phase = np.angle(in_block_fft)
    # reshape magnitude to the model's (batch, time, features) layout
    in_mag = np.reshape(in_mag, (1,1,-1)).astype('float32')
    # feed states (input index 1) and magnitude (input index 0) to model 1
    interpreter_1.set_tensor(input_details_1[1]['index'], states_1)
    interpreter_1.set_tensor(input_details_1[0]['index'], in_mag)
    # run calculation
    interpreter_1.invoke()
    # fetch the predicted mask and the updated LSTM states
    out_mask = interpreter_1.get_tensor(output_details_1[0]['index'])
    states_1 = interpreter_1.get_tensor(output_details_1[1]['index'])
    # apply the mask to the magnitude, restore phase, and go back to time
    # domain via the inverse FFT
    estimated_complex = in_mag * out_mask * np.exp(1j * in_phase)
    estimated_block = np.fft.irfft(estimated_complex)
    # reshape the time domain block for the second model
    estimated_block = np.reshape(estimated_block, (1,1,-1)).astype('float32')
    # feed states and the time-domain block to model 2
    interpreter_2.set_tensor(input_details_2[1]['index'], states_2)
    interpreter_2.set_tensor(input_details_2[0]['index'], estimated_block)
    # run calculation
    interpreter_2.invoke()
    # fetch the enhanced block and the updated LSTM states
    out_block = interpreter_2.get_tensor(output_details_2[0]['index'])
    states_2 = interpreter_2.get_tensor(output_details_2[1]['index'])
    # overlap-add: shift the synthesis buffer, zero the tail, add new block
    out_buffer[:-block_shift] = out_buffer[block_shift:]
    out_buffer[-block_shift:] = np.zeros((block_shift))
    out_buffer += np.squeeze(out_block)
    # emit the oldest, fully accumulated block_shift samples to the soundcard
    outdata[:] = np.expand_dims(out_buffer[:block_shift], axis=-1)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
try:
    # open a full-duplex stream; the callback above is invoked for every
    # block_shift samples until the user presses Return
    with sd.Stream(device=(args.input_device, args.output_device),
                   samplerate=fs_target, blocksize=block_shift,
                   dtype=np.float32, latency=args.latency,
                   channels=1, callback=callback):
        print('#' * 80)
        print('press Return to quit')
        print('#' * 80)
        input()
except KeyboardInterrupt:
    # quiet exit on Ctrl+C
    parser.exit('')
except Exception as e:
    # surface any other error through the parser's exit mechanism
    parser.exit(type(e).__name__ + ': ' + str(e))
|
| 147 |
+
|
models/DTLN (alekya)/real_time_processing.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
Created on Tue Jun 23 16:23:15 2020
|
| 5 |
+
|
| 6 |
+
@author: nils
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import soundfile as sf
|
| 10 |
+
import numpy as np
|
| 11 |
+
import tensorflow as tf
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
##########################
# the values are fixed, if you need other values, you have to retrain.
# The sampling rate of 16k is also fix.
block_len = 512      # analysis window in samples (32 ms at 16 kHz)
block_shift = 128    # hop size in samples (8 ms at 16 kHz)
# load model exported in SavedModel format and get its serving signature
model = tf.saved_model.load('./pretrained_model/dtln_saved_model')
infer = model.signatures["serving_default"]
# load audio file at 16k fs (please change the path)
audio,fs = sf.read('path_to_your_favorite_audio.wav')
# check for sampling rate
if fs != 16000:
    raise ValueError('This model only supports 16k sampling rate.')
# preallocate output audio of the same length as the input
out_file = np.zeros((len(audio)))
# sliding analysis buffer and overlap-add synthesis buffer
in_buffer = np.zeros((block_len))
out_buffer = np.zeros((block_len))
# calculate number of full blocks that fit into the signal
num_blocks = (audio.shape[0] - (block_len-block_shift)) // block_shift
# iterate over the number of blocks
for idx in range(num_blocks):
    # shift the analysis window left and append the next hop of samples
    in_buffer[:-block_shift] = in_buffer[block_shift:]
    in_buffer[-block_shift:] = audio[idx*block_shift:(idx*block_shift)+block_shift]
    # create a batch dimension of one
    in_block = np.expand_dims(in_buffer, axis=0).astype('float32')
    # process one block with the network
    out_block= infer(tf.constant(in_block))['conv1d_1']
    # overlap-add: shift the synthesis buffer, zero the tail, add new block
    out_buffer[:-block_shift] = out_buffer[block_shift:]
    out_buffer[-block_shift:] = np.zeros((block_shift))
    out_buffer += np.squeeze(out_block)
    # write the oldest, fully accumulated hop to the output signal
    out_file[idx*block_shift:(idx*block_shift)+block_shift] = out_buffer[:block_shift]


# write the enhanced signal to a .wav file
sf.write('out.wav', out_file, fs)

print('Processing finished.')
|
models/DTLN (alekya)/real_time_processing_onnx.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This is an example how to implement real time processing of the DTLN ONNX
|
| 3 |
+
model in python.
|
| 4 |
+
|
| 5 |
+
Please change the name of the .wav file at line 49 before running the script.
|
| 6 |
+
For the ONNX runtime call: $ pip install onnxruntime
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
Author: Nils L. Westhausen (nils.westhausen@uol.de)
|
| 11 |
+
Version: 03.07.2020
|
| 12 |
+
|
| 13 |
+
This code is licensed under the terms of the MIT-license.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import soundfile as sf
|
| 17 |
+
import numpy as np
|
| 18 |
+
import time
|
| 19 |
+
import onnxruntime
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
##########################
# the values are fixed, if you need other values, you have to retrain.
# The sampling rate of 16k is also fix.
block_len = 512      # analysis window in samples (32 ms at 16 kHz)
block_shift = 128    # hop size in samples (8 ms at 16 kHz)
# load first model (mask estimation core)
interpreter_1 = onnxruntime.InferenceSession('./model_1.onnx')
model_input_names_1 = [inp.name for inp in interpreter_1.get_inputs()]
# preallocate zero inputs (dynamic dims fall back to 1); the state input
# starts at zero and is fed back from the outputs below
model_inputs_1 = {
    inp.name: np.zeros(
        [dim if isinstance(dim, int) else 1 for dim in inp.shape],
        dtype=np.float32)
    for inp in interpreter_1.get_inputs()}
# load second model (time-domain enhancement core)
interpreter_2 = onnxruntime.InferenceSession('./model_2.onnx')
model_input_names_2 = [inp.name for inp in interpreter_2.get_inputs()]
# preallocate input
model_inputs_2 = {
    inp.name: np.zeros(
        [dim if isinstance(dim, int) else 1 for dim in inp.shape],
        dtype=np.float32)
    for inp in interpreter_2.get_inputs()}

# load audio file (please change the path)
audio,fs = sf.read('path/to/your/favorite.wav')
# check for sampling rate
if fs != 16000:
    raise ValueError('This model only supports 16k sampling rate.')
# preallocate output audio of the same length as the input
out_file = np.zeros((len(audio)))
# sliding analysis buffer and overlap-add synthesis buffer
in_buffer = np.zeros((block_len)).astype('float32')
out_buffer = np.zeros((block_len)).astype('float32')
# calculate number of full blocks that fit into the signal
num_blocks = (audio.shape[0] - (block_len-block_shift)) // block_shift
# iterate over the number of blocks, timing each iteration
time_array = []
for idx in range(num_blocks):
    start_time = time.time()
    # shift the analysis window left and append the next hop of samples
    in_buffer[:-block_shift] = in_buffer[block_shift:]
    in_buffer[-block_shift:] = audio[idx*block_shift:(idx*block_shift)+block_shift]
    # calculate fft of input block
    in_block_fft = np.fft.rfft(in_buffer)
    in_mag = np.abs(in_block_fft)
    in_phase = np.angle(in_block_fft)
    # reshape magnitude to the model's (batch, time, features) layout
    in_mag = np.reshape(in_mag, (1,1,-1)).astype('float32')
    # set magnitude block as the first model input
    model_inputs_1[model_input_names_1[0]] = in_mag
    # run calculation
    model_outputs_1 = interpreter_1.run(None, model_inputs_1)
    # get the predicted mask from the first model
    out_mask = model_outputs_1[0]
    # feed the updated recurrent states back as input for the next block
    model_inputs_1[model_input_names_1[1]] = model_outputs_1[1]
    # apply mask, restore phase and go back to time domain via the ifft
    estimated_complex = in_mag * out_mask * np.exp(1j * in_phase)
    estimated_block = np.fft.irfft(estimated_complex)
    # reshape the time domain block
    estimated_block = np.reshape(estimated_block, (1,1,-1)).astype('float32')
    # set time-domain block as the second model input
    model_inputs_2[model_input_names_2[0]] = estimated_block
    # run calculation
    model_outputs_2 = interpreter_2.run(None, model_inputs_2)
    # get the enhanced block
    out_block = model_outputs_2[0]
    # feed the updated recurrent states back as input for the next block
    model_inputs_2[model_input_names_2[1]] = model_outputs_2[1]
    # overlap-add: shift the synthesis buffer, zero the tail, add new block
    out_buffer[:-block_shift] = out_buffer[block_shift:]
    out_buffer[-block_shift:] = np.zeros((block_shift))
    out_buffer += np.squeeze(out_block)
    # write the oldest, fully accumulated hop to the output signal
    out_file[idx*block_shift:(idx*block_shift)+block_shift] = out_buffer[:block_shift]
    time_array.append(time.time()-start_time)

# write the enhanced signal to a .wav file
sf.write('out.wav', out_file, fs)
print('Processing Time [ms]:')
print(np.mean(np.stack(time_array))*1000)
print('Processing finished.')
|
models/DTLN (alekya)/real_time_processing_tf_lite.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This is an example how to implement real time processing of the DTLN tf light
|
| 3 |
+
model in python.
|
| 4 |
+
|
| 5 |
+
Please change the name of the .wav file at line 43 before running the sript.
|
| 6 |
+
For .whl files of the tf light runtime go to:
|
| 7 |
+
https://www.tensorflow.org/lite/guide/python
|
| 8 |
+
|
| 9 |
+
Author: Nils L. Westhausen (nils.westhausen@uol.de)
|
| 10 |
+
Version: 30.06.2020
|
| 11 |
+
|
| 12 |
+
This code is licensed under the terms of the MIT-license.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import soundfile as sf
|
| 16 |
+
import numpy as np
|
| 17 |
+
import tflite_runtime.interpreter as tflite
|
| 18 |
+
import time
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
##########################
|
| 23 |
+
# the values are fixed, if you need other values, you have to retrain.
|
| 24 |
+
# The sampling rate of 16k is also fix.
|
| 25 |
+
block_len = 512
|
| 26 |
+
block_shift = 128
|
| 27 |
+
# load models
|
| 28 |
+
interpreter_1 = tflite.Interpreter(model_path='./pretrained_model/model_1.tflite')
|
| 29 |
+
interpreter_1.allocate_tensors()
|
| 30 |
+
interpreter_2 = tflite.Interpreter(model_path='./pretrained_model/model_2.tflite')
|
| 31 |
+
interpreter_2.allocate_tensors()
|
| 32 |
+
|
| 33 |
+
# Get input and output tensors.
|
| 34 |
+
input_details_1 = interpreter_1.get_input_details()
|
| 35 |
+
output_details_1 = interpreter_1.get_output_details()
|
| 36 |
+
|
| 37 |
+
input_details_2 = interpreter_2.get_input_details()
|
| 38 |
+
output_details_2 = interpreter_2.get_output_details()
|
| 39 |
+
# create states for the lstms
|
| 40 |
+
states_1 = np.zeros(input_details_1[1]['shape']).astype('float32')
|
| 41 |
+
states_2 = np.zeros(input_details_2[1]['shape']).astype('float32')
|
| 42 |
+
# load audio file at 16k fs (please change)
|
| 43 |
+
audio,fs = sf.read('path/to/your/favorite/.wav')
|
| 44 |
+
# check for sampling rate
|
| 45 |
+
if fs != 16000:
|
| 46 |
+
raise ValueError('This model only supports 16k sampling rate.')
|
| 47 |
+
# preallocate output audio
|
| 48 |
+
out_file = np.zeros((len(audio)))
|
| 49 |
+
# create buffer
|
| 50 |
+
in_buffer = np.zeros((block_len)).astype('float32')
|
| 51 |
+
out_buffer = np.zeros((block_len)).astype('float32')
|
| 52 |
+
# calculate number of blocks
|
| 53 |
+
num_blocks = (audio.shape[0] - (block_len-block_shift)) // block_shift
|
| 54 |
+
time_array = []
|
| 55 |
+
# iterate over the number of blcoks
|
| 56 |
+
for idx in range(num_blocks):
|
| 57 |
+
start_time = time.time()
|
| 58 |
+
# shift values and write to buffer
|
| 59 |
+
in_buffer[:-block_shift] = in_buffer[block_shift:]
|
| 60 |
+
in_buffer[-block_shift:] = audio[idx*block_shift:(idx*block_shift)+block_shift]
|
| 61 |
+
# calculate fft of input block
|
| 62 |
+
in_block_fft = np.fft.rfft(in_buffer)
|
| 63 |
+
in_mag = np.abs(in_block_fft)
|
| 64 |
+
in_phase = np.angle(in_block_fft)
|
| 65 |
+
# reshape magnitude to input dimensions
|
| 66 |
+
in_mag = np.reshape(in_mag, (1,1,-1)).astype('float32')
|
| 67 |
+
# set tensors to the first model
|
| 68 |
+
interpreter_1.set_tensor(input_details_1[1]['index'], states_1)
|
| 69 |
+
interpreter_1.set_tensor(input_details_1[0]['index'], in_mag)
|
| 70 |
+
# run calculation
|
| 71 |
+
interpreter_1.invoke()
|
| 72 |
+
# get the output of the first block
|
| 73 |
+
out_mask = interpreter_1.get_tensor(output_details_1[0]['index'])
|
| 74 |
+
states_1 = interpreter_1.get_tensor(output_details_1[1]['index'])
|
| 75 |
+
# calculate the ifft
|
| 76 |
+
estimated_complex = in_mag * out_mask * np.exp(1j * in_phase)
|
| 77 |
+
estimated_block = np.fft.irfft(estimated_complex)
|
| 78 |
+
# reshape the time domain block
|
| 79 |
+
estimated_block = np.reshape(estimated_block, (1,1,-1)).astype('float32')
|
| 80 |
+
# set tensors to the second block
|
| 81 |
+
interpreter_2.set_tensor(input_details_2[1]['index'], states_2)
|
| 82 |
+
interpreter_2.set_tensor(input_details_2[0]['index'], estimated_block)
|
| 83 |
+
# run calculation
|
| 84 |
+
interpreter_2.invoke()
|
| 85 |
+
# get output tensors
|
| 86 |
+
out_block = interpreter_2.get_tensor(output_details_2[0]['index'])
|
| 87 |
+
states_2 = interpreter_2.get_tensor(output_details_2[1]['index'])
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# shift values and write to buffer
|
| 91 |
+
out_buffer[:-block_shift] = out_buffer[block_shift:]
|
| 92 |
+
out_buffer[-block_shift:] = np.zeros((block_shift))
|
| 93 |
+
out_buffer += np.squeeze(out_block)
|
| 94 |
+
# write block to output file
|
| 95 |
+
out_file[idx*block_shift:(idx*block_shift)+block_shift] = out_buffer[:block_shift]
|
| 96 |
+
time_array.append(time.time()-start_time)
|
| 97 |
+
|
| 98 |
+
# write to .wav file
|
| 99 |
+
sf.write('out.wav', out_file, fs)
|
| 100 |
+
print('Processing Time [ms]:')
|
| 101 |
+
print(np.mean(np.stack(time_array))*1000)
|
| 102 |
+
print('Processing finished.')
|
models/DTLN (alekya)/run_evaluation.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
Script to process a folder of .wav files with a trained DTLN model.
|
| 4 |
+
This script supports subfolders and names the processed files the same as the
|
| 5 |
+
original. The model expects 16kHz audio .wav files. Files with other
|
| 6 |
+
sampling rates will be resampled. Stereo files will be downmixed to mono.
|
| 7 |
+
|
| 8 |
+
The idea of this script is to use it for baseline or comparison purpose.
|
| 9 |
+
|
| 10 |
+
Example call:
|
| 11 |
+
$python run_evaluation.py -i /name/of/input/folder \
|
| 12 |
+
-o /name/of/output/folder \
|
| 13 |
+
-m /name/of/the/model.h5
|
| 14 |
+
|
| 15 |
+
Author: Nils L. Westhausen (nils.westhausen@uol.de)
|
| 16 |
+
Version: 13.05.2020
|
| 17 |
+
|
| 18 |
+
This code is licensed under the terms of the MIT-license.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import soundfile as sf
|
| 22 |
+
import librosa
|
| 23 |
+
import numpy as np
|
| 24 |
+
import os
|
| 25 |
+
import argparse
|
| 26 |
+
from DTLN_model import DTLN_model
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def process_file(model, audio_file_name, out_file_name):
    '''
    Function to read an audio file, process it by the network and write the
    enhanced audio to a .wav file.

    Parameters
    ----------
    model : Keras model
        Keras model, which accepts audio in the size (1, timesteps).
    audio_file_name : STRING
        Name and path of the input audio file.
    out_file_name : STRING
        Name and path of the target file.

    '''

    # read audio file with librosa to handle resampling and enforce mono
    in_data, fs = librosa.core.load(audio_file_name, sr=16000, mono=True)
    # get length of file
    len_orig = len(in_data)
    # pad audio on both sides so the model sees full context at the edges
    zero_pad = np.zeros(384)
    in_data = np.concatenate((zero_pad, in_data, zero_pad), axis=0)
    # predict audio with the model (add the batch dimension for the network)
    predicted = model.predict_on_batch(
        np.expand_dims(in_data, axis=0).astype(np.float32))
    # squeeze the batch dimension away
    predicted_speech = np.squeeze(predicted)
    # remove the padding again so the output matches the input length
    predicted_speech = predicted_speech[384:384 + len_orig]
    # write the file to target destination
    sf.write(out_file_name, predicted_speech, fs)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def process_folder(model, folder_name, new_folder_name):
    '''
    Function to find .wav files in the folder and subfolders of "folder_name",
    process each .wav file with an algorithm and write it back to disk in the
    folder "new_folder_name". The structure of the original directory is
    preserved. The processed files will be saved with the same name as the
    original file.

    Parameters
    ----------
    model : Keras model
        Keras model, which accepts audio in the size (1, timesteps).
    folder_name : STRING
        Input folder with .wav files.
    new_folder_name : STRING
        Target folder for the processed files.

    '''

    # empty lists for file and folder names
    file_names = []
    directories = []
    new_directories = []
    # walk through the directory tree and collect all .wav files
    for root, dirs, files in os.walk(folder_name):
        for file in files:
            if file.endswith(".wav"):
                # remember the file together with its source directory
                file_names.append(file)
                directories.append(root)
                # mirror the source directory below the target folder
                new_root = root.replace(folder_name, new_folder_name)
                new_directories.append(new_root)
                # create the target directory if it does not exist yet
                # (exist_ok avoids the check-then-create race)
                os.makedirs(new_root, exist_ok=True)
    # process all collected .wav files with the model
    for file_name, directory, new_directory in zip(file_names, directories,
                                                   new_directories):
        process_file(model, os.path.join(directory, file_name),
                     os.path.join(new_directory, file_name))
        print(file_name + ' processed successfully!')
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
if __name__ == '__main__':
    # argument parser for running directly from the command line
    parser = argparse.ArgumentParser(description='data evaluation')
    parser.add_argument('--in_folder', '-i',
                        help='folder with input files')
    parser.add_argument('--out_folder', '-o',
                        help='target folder for processed files')
    parser.add_argument('--model', '-m',
                        help='weights of the enhancement model in .h5 format')
    args = parser.parse_args()
    # models trained with STFT normalisation carry '_norm_' in their file name
    norm_stft = '_norm_' in args.model
    # create class instance
    modelClass = DTLN_model()
    # build the model in default configuration
    modelClass.build_DTLN_model(norm_stft=norm_stft)
    # load weights of the .h5 file
    modelClass.model.load_weights(args.model)
    # process the folder
    process_folder(modelClass.model, args.in_folder, args.out_folder)
|
models/DTLN (alekya)/run_training.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Script to train the DTLN model in default settings. The folders for noisy and
clean files are expected to have the same number of files and the files to
have the same name. The training procedure always saves the best weights of
the model into the folder "./models_'runName'/". Also a log file of the
training progress is written there. To change any parameters go to the
"DTLN_model.py" file or use "modelTrainer.parameter = XY" in this file.
It is recommended to run the training on a GPU. The setup is optimized for
the DNS-Challenge data set. If you use a custom data set, just play around
with the parameters.

Please change the folder names before starting the training.

Example call:
    $python run_training.py

Author: Nils L. Westhausen (nils.westhausen@uol.de)
Version: 13.05.2020

This code is licensed under the terms of the MIT-license.
"""

from DTLN_model import DTLN_model
import os

# use the GPU with idx 0
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
# activate this for some reproducibility
os.environ['TF_DETERMINISTIC_OPS'] = '1'


# path to folder containing the noisy or mixed audio training files
path_to_train_mix = '/path/to/noisy/training/data/'
# path to folder containing the clean/speech files for training
path_to_train_speech = '/path/to/clean/training/data/'
# path to folder containing the noisy or mixed audio validation data
path_to_val_mix = '/path/to/noisy/validation/data/'
# path to folder containing the clean audio validation data
path_to_val_speech = '/path/to/clean/validation/data/'

# name your training run
runName = 'DTLN_model'
# create instance of the DTLN model class
modelTrainer = DTLN_model()
# build the model
modelTrainer.build_DTLN_model()
# compile it with optimizer and cost function for training
modelTrainer.compile_model()
# train the model (no backslash needed inside the parentheses)
modelTrainer.train_model(runName, path_to_train_mix, path_to_train_speech,
                         path_to_val_mix, path_to_val_speech)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
|
models/DTLN (alekya)/source.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
https://huggingface.co/alekya/DTLN
|
models/DTLN (alekya)/test.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
Ad-hoc evaluation script: compares a clean reference recording against its
denoised counterpart using MSE and SNR, then computes wide-band PESQ on a
second file pair. Paths are hardcoded — adjust them before running.
"""
import IPython
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import wavfile
import wave as we
import soundfile as sf
import os
import librosa
from pesq import pesq

# paths to a clean reference file and its denoised counterpart
clean_path = "/home/ap/Desktop/Workspace/DTLN/valset_clean/p232_007.wav"
denoised_path = '/home/ap/Desktop/Workspace/DTLN/output-denoised-audio (1).wav'

# load both signals at their native sampling rate (sr=None keeps original fs)
original, sr1 = librosa.load(clean_path, sr=None)
denoised, sr2 = librosa.load(denoised_path, sr=None)

# truncate to the common length so the signals can be compared sample-wise
min_length = min(len(original), len(denoised))
original = original[:min_length]
denoised = denoised[:min_length]

# mean squared error between clean and denoised signal
mse = np.mean((original - denoised) ** 2)

# Calculate Signal-to-Noise Ratio (SNR); the residual acts as the noise
signal_power = np.mean(original ** 2)
noise_power = np.mean((original - denoised) ** 2)
snr = 10 * np.log10(signal_power / noise_power)

print(f"Mean Squared Error: {mse}")
print(f"Signal-to-Noise Ratio (SNR): {snr} dB")

# PESQ (wide-band mode) on a second file pair
rate, ref = wavfile.read("/home/ap/Desktop/Workspace/DTLN/test/19-198-0034.wav")
rate, deg = wavfile.read("/home/ap/Desktop/Workspace/DTLN/output_test/19-198-0034.wav")

print(pesq(rate, ref, deg, 'wb'))
|
models/DTLN (alekya)/tflite_env.yml
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: tflite-env
|
| 2 |
+
channels:
|
| 3 |
+
- conda-forge
|
| 4 |
+
- defaults
|
| 5 |
+
dependencies:
|
| 6 |
+
- audioread=2.1.8=py37hc8dfbb8_2
|
| 7 |
+
- bzip2=1.0.8=h0b31af3_2
|
| 8 |
+
- ca-certificates=2020.6.20=hecda079_0
|
| 9 |
+
- certifi=2020.6.20=py37hc8dfbb8_0
|
| 10 |
+
- cffi=1.14.0=py37hc512035_1
|
| 11 |
+
- cycler=0.10.0=py_2
|
| 12 |
+
- decorator=4.4.2=py_0
|
| 13 |
+
- ffmpeg=4.2.3=hd0c0d6a_0
|
| 14 |
+
- freetype=2.10.2=h8da9a1a_0
|
| 15 |
+
- gettext=0.19.8.1=h1f1d5ed_1
|
| 16 |
+
- gmp=6.2.0=h4a8c4bd_2
|
| 17 |
+
- gnutls=3.6.13=hc269f14_0
|
| 18 |
+
- joblib=0.15.1=py_0
|
| 19 |
+
- kiwisolver=1.2.0=py37ha1cc60f_0
|
| 20 |
+
- lame=3.100=h1de35cc_1001
|
| 21 |
+
- libblas=3.8.0=17_openblas
|
| 22 |
+
- libcblas=3.8.0=17_openblas
|
| 23 |
+
- libcxx=10.0.0=1
|
| 24 |
+
- libedit=3.1.20191231=haf1e3a3_0
|
| 25 |
+
- libffi=3.3=h0a44026_1
|
| 26 |
+
- libflac=1.3.3=h4a8c4bd_0
|
| 27 |
+
- libgfortran=4.0.0=2
|
| 28 |
+
- libiconv=1.15=h0b31af3_1006
|
| 29 |
+
- liblapack=3.8.0=17_openblas
|
| 30 |
+
- libllvm8=8.0.1=h770b8ee_0
|
| 31 |
+
- libogg=1.3.2=h0b31af3_1002
|
| 32 |
+
- libopenblas=0.3.10=h3d69b6c_0
|
| 33 |
+
- libpng=1.6.37=hbbe82c9_1
|
| 34 |
+
- librosa=0.7.2=py_1
|
| 35 |
+
- libsndfile=1.0.28=h4a8c4bd_1000
|
| 36 |
+
- libvorbis=1.3.6=h4a8c4bd_2
|
| 37 |
+
- llvm-openmp=10.0.0=h28b9765_0
|
| 38 |
+
- llvmlite=0.31.0=py37hb548287_1
|
| 39 |
+
- matplotlib-base=3.2.2=py37hddda452_0
|
| 40 |
+
- ncurses=6.2=h0a44026_1
|
| 41 |
+
- nettle=3.4.1=h3efe00b_1002
|
| 42 |
+
- numba=0.48.0=py37h4f17bb1_0
|
| 43 |
+
- numpy=1.18.5=py37h7687784_0
|
| 44 |
+
- openh264=2.1.1=hd174df1_0
|
| 45 |
+
- openssl=1.1.1g=h0b31af3_0
|
| 46 |
+
- pip=20.1.1=py37_1
|
| 47 |
+
- portaudio=19.6.0=h647c56a_4
|
| 48 |
+
- pycparser=2.20=pyh9f0ad1d_2
|
| 49 |
+
- pyparsing=2.4.7=pyh9f0ad1d_0
|
| 50 |
+
- pysoundfile=0.10.2=py_1001
|
| 51 |
+
- python=3.7.7=hf48f09d_4
|
| 52 |
+
- python-dateutil=2.8.1=py_0
|
| 53 |
+
- python-sounddevice=0.3.15=pyh8c360ce_0
|
| 54 |
+
- python_abi=3.7=1_cp37m
|
| 55 |
+
- readline=8.0=h1de35cc_0
|
| 56 |
+
- resampy=0.2.2=py_0
|
| 57 |
+
- scikit-learn=0.23.1=py37hf5857e7_0
|
| 58 |
+
- scipy=1.5.0=py37hce1b9e5_0
|
| 59 |
+
- setuptools=47.3.1=py37_0
|
| 60 |
+
- six=1.15.0=pyh9f0ad1d_0
|
| 61 |
+
- sqlite=3.32.3=hffcf06c_0
|
| 62 |
+
- threadpoolctl=2.1.0=pyh5ca1d4c_0
|
| 63 |
+
- tk=8.6.10=hb0a8c7a_0
|
| 64 |
+
- tornado=6.0.4=py37h9bfed18_1
|
| 65 |
+
- wheel=0.34.2=py37_0
|
| 66 |
+
- x264=1!152.20180806=h1de35cc_0
|
| 67 |
+
- xz=5.2.5=h1de35cc_0
|
| 68 |
+
- zlib=1.2.11=h1de35cc_3
|
| 69 |
+
- pip:
|
| 70 |
+
- flatbuffers==1.12
|
| 71 |
+
- tflite-runtime==2.1.0.post1
|
| 72 |
+
prefix: /Applications/anaconda3/envs/tflite-env
|
| 73 |
+
|
models/DTLN (alekya)/train_env.yml
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: train_env
|
| 2 |
+
channels:
|
| 3 |
+
- conda-forge
|
| 4 |
+
- defaults
|
| 5 |
+
dependencies:
|
| 6 |
+
- _libgcc_mutex=0.1=main
|
| 7 |
+
- _tflow_select=2.1.0=gpu
|
| 8 |
+
- absl-py=0.9.0=py37_0
|
| 9 |
+
- astunparse=1.6.3=py_0
|
| 10 |
+
- audioread=2.1.8=py37hc8dfbb8_2
|
| 11 |
+
- blas=1.0=mkl
|
| 12 |
+
- blinker=1.4=py37_0
|
| 13 |
+
- bzip2=1.0.8=h516909a_2
|
| 14 |
+
- c-ares=1.15.0=h7b6447c_1001
|
| 15 |
+
- ca-certificates=2020.4.5.2=hecda079_0
|
| 16 |
+
- cachetools=4.1.0=py_1
|
| 17 |
+
- certifi=2020.4.5.2=py37hc8dfbb8_0
|
| 18 |
+
- cffi=1.14.0=py37he30daa8_1
|
| 19 |
+
- chardet=3.0.4=py37_1003
|
| 20 |
+
- click=7.1.2=py_0
|
| 21 |
+
- cryptography=2.9.2=py37h1ba5d50_0
|
| 22 |
+
- cudatoolkit=10.1.243=h6bb024c_0
|
| 23 |
+
- cudnn=7.6.5=cuda10.1_0
|
| 24 |
+
- cupti=10.1.168=0
|
| 25 |
+
- cycler=0.10.0=py_2
|
| 26 |
+
- decorator=4.4.2=py_0
|
| 27 |
+
- ffmpeg=4.2.3=h167e202_0
|
| 28 |
+
- freetype=2.10.2=he06d7ca_0
|
| 29 |
+
- gast=0.3.3=py_0
|
| 30 |
+
- gettext=0.19.8.1=h5e8e0c9_1
|
| 31 |
+
- gmp=6.2.0=he1b5a44_2
|
| 32 |
+
- gnutls=3.6.13=h79a8f9a_0
|
| 33 |
+
- google-auth=1.14.1=py_0
|
| 34 |
+
- google-auth-oauthlib=0.4.1=py_2
|
| 35 |
+
- google-pasta=0.2.0=py_0
|
| 36 |
+
- grpcio=1.27.2=py37hf8bcb03_0
|
| 37 |
+
- h5py=2.10.0=py37h7918eee_0
|
| 38 |
+
- hdf5=1.10.4=hb1b8bf9_0
|
| 39 |
+
- icu=58.2=hf484d3e_1000
|
| 40 |
+
- idna=2.9=py_1
|
| 41 |
+
- intel-openmp=2020.1=217
|
| 42 |
+
- joblib=0.15.1=py_0
|
| 43 |
+
- keras-preprocessing=1.1.0=py_1
|
| 44 |
+
- kiwisolver=1.2.0=py37h99015e2_0
|
| 45 |
+
- lame=3.100=h14c3975_1001
|
| 46 |
+
- ld_impl_linux-64=2.33.1=h53a641e_7
|
| 47 |
+
- libedit=3.1.20181209=hc058e9b_0
|
| 48 |
+
- libffi=3.3=he6710b0_1
|
| 49 |
+
- libflac=1.3.3=he1b5a44_0
|
| 50 |
+
- libgcc-ng=9.1.0=hdf63c60_0
|
| 51 |
+
- libgfortran-ng=7.3.0=hdf63c60_0
|
| 52 |
+
- libiconv=1.15=h516909a_1006
|
| 53 |
+
- libllvm8=8.0.1=hc9558a2_0
|
| 54 |
+
- libogg=1.3.2=h516909a_1002
|
| 55 |
+
- libpng=1.6.37=hed695b0_1
|
| 56 |
+
- libprotobuf=3.12.3=hd408876_0
|
| 57 |
+
- librosa=0.7.2=py_1
|
| 58 |
+
- libsndfile=1.0.28=he1b5a44_1000
|
| 59 |
+
- libstdcxx-ng=9.1.0=hdf63c60_0
|
| 60 |
+
- libvorbis=1.3.6=he1b5a44_2
|
| 61 |
+
- llvmlite=0.31.0=py37h5202443_1
|
| 62 |
+
- markdown=3.1.1=py37_0
|
| 63 |
+
- matplotlib-base=3.1.3=py37hef1b27d_0
|
| 64 |
+
- mkl=2020.1=217
|
| 65 |
+
- mkl-service=2.3.0=py37he904b0f_0
|
| 66 |
+
- mkl_fft=1.0.15=py37ha843d7b_0
|
| 67 |
+
- mkl_random=1.1.1=py37h0573a6f_0
|
| 68 |
+
- ncurses=6.2=he6710b0_1
|
| 69 |
+
- nettle=3.4.1=h1bed415_1002
|
| 70 |
+
- numba=0.48.0=py37hb3f55d8_0
|
| 71 |
+
- numpy=1.18.1=py37h4f9e942_0
|
| 72 |
+
- numpy-base=1.18.1=py37hde5b4d6_1
|
| 73 |
+
- oauthlib=3.1.0=py_0
|
| 74 |
+
- openh264=2.1.1=h8b12597_0
|
| 75 |
+
- openssl=1.1.1g=h516909a_0
|
| 76 |
+
- opt_einsum=3.1.0=py_0
|
| 77 |
+
- pip=20.0.2=py37_3
|
| 78 |
+
- protobuf=3.12.3=py37he6710b0_0
|
| 79 |
+
- pyasn1=0.4.8=py_0
|
| 80 |
+
- pyasn1-modules=0.2.7=py_0
|
| 81 |
+
- pycparser=2.20=py_0
|
| 82 |
+
- pyjwt=1.7.1=py37_0
|
| 83 |
+
- pyopenssl=19.1.0=py37_0
|
| 84 |
+
- pyparsing=2.4.7=pyh9f0ad1d_0
|
| 85 |
+
- pysocks=1.7.1=py37_0
|
| 86 |
+
- pysoundfile=0.10.2=py_1001
|
| 87 |
+
- python=3.7.7=hcff3b4d_5
|
| 88 |
+
- python-dateutil=2.8.1=py_0
|
| 89 |
+
- python_abi=3.7=1_cp37m
|
| 90 |
+
- readline=8.0=h7b6447c_0
|
| 91 |
+
- requests=2.23.0=py37_0
|
| 92 |
+
- requests-oauthlib=1.3.0=py_0
|
| 93 |
+
- resampy=0.2.2=py_0
|
| 94 |
+
- rsa=4.0=py_0
|
| 95 |
+
- scikit-learn=0.22.1=py37hd81dba3_0
|
| 96 |
+
- scipy=1.4.1=py37h0b6359f_0
|
| 97 |
+
- setuptools=47.1.1=py37_0
|
| 98 |
+
- six=1.15.0=py_0
|
| 99 |
+
- sqlite=3.31.1=h62c20be_1
|
| 100 |
+
- tensorboard=2.2.1=pyh532a8cf_0
|
| 101 |
+
- tensorboard-plugin-wit=1.6.0=py_0
|
| 102 |
+
- tensorflow=2.2.0=gpu_py37h1a511ff_0
|
| 103 |
+
- tensorflow-base=2.2.0=gpu_py37h8a81be8_0
|
| 104 |
+
- tensorflow-estimator=2.2.0=pyh208ff02_0
|
| 105 |
+
- tensorflow-gpu=2.2.0=h0d30ee6_0
|
| 106 |
+
- termcolor=1.1.0=py37_1
|
| 107 |
+
- tk=8.6.8=hbc83047_0
|
| 108 |
+
- tornado=6.0.4=py37h8f50634_1
|
| 109 |
+
- urllib3=1.25.8=py37_0
|
| 110 |
+
- werkzeug=1.0.1=py_0
|
| 111 |
+
- wheel=0.34.2=py37_0
|
| 112 |
+
- wrapt=1.12.1=py37h7b6447c_1
|
| 113 |
+
- x264=1!152.20180806=h14c3975_0
|
| 114 |
+
- xz=5.2.5=h7b6447c_0
|
| 115 |
+
- zlib=1.2.11=h7b6447c_3
|
| 116 |
+
- pip:
|
| 117 |
+
- lxml==4.5.1
|
| 118 |
+
- wavinfo==1.5
|
| 119 |
+
prefix: /home/nils/anaconda3/envs/tfenv
|
| 120 |
+
|
models/DTLN (yash-04)/.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
models/DTLN (yash-04)/DTLN_model.py
ADDED
|
@@ -0,0 +1,641 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
This File contains everything to train the DTLN model.
|
| 4 |
+
|
| 5 |
+
For running the training see "run_training.py".
|
| 6 |
+
To run evaluation with the provided pretrained model see "run_evaluation.py".
|
| 7 |
+
|
| 8 |
+
Author: Nils L. Westhausen (nils.westhausen@uol.de)
|
| 9 |
+
Version: 24.06.2020
|
| 10 |
+
|
| 11 |
+
This code is licensed under the terms of the MIT-license.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
import os, fnmatch
|
| 16 |
+
import tensorflow.keras as keras
|
| 17 |
+
from tensorflow.keras.models import Model
|
| 18 |
+
from tensorflow.keras.layers import Activation, Dense, LSTM, Dropout, \
|
| 19 |
+
Lambda, Input, Multiply, Layer, Conv1D
|
| 20 |
+
from tensorflow.keras.callbacks import ReduceLROnPlateau, CSVLogger, \
|
| 21 |
+
EarlyStopping, ModelCheckpoint
|
| 22 |
+
import tensorflow as tf
|
| 23 |
+
import soundfile as sf
|
| 24 |
+
from wavinfo import WavInfoReader
|
| 25 |
+
from random import shuffle, seed
|
| 26 |
+
import numpy as np
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class audio_generator():
    '''
    Class to create a Tensorflow dataset based on an iterator from a large scale
    audio dataset. This audio generator only supports single channel audio files.
    '''

    def __init__(self, path_to_input, path_to_s1, len_of_samples, fs, train_flag=False):
        '''
        Constructor of the audio generator class.
        Inputs:
            path_to_input    path to the mixtures (noisy .wav files)
            path_to_s1       path to the target source data (clean .wav files,
                             same file names as in path_to_input)
            len_of_samples   length of audio snippets in samples
            fs               sampling rate
            train_flag       flag for activate shuffling of files
        '''
        # set inputs to properties
        self.path_to_input = path_to_input
        self.path_to_s1 = path_to_s1
        self.len_of_samples = len_of_samples
        self.fs = fs
        self.train_flag=train_flag
        # count the number of samples in your data set (depending on your disk,
        # this can take some time)
        self.count_samples()
        # create iterable tf.data.Dataset object
        self.create_tf_data_obj()

    def count_samples(self):
        '''
        Method to list the data of the dataset and count the number of samples.
        Sets self.file_names and self.total_samples.
        '''

        # list .wav files in directory
        self.file_names = fnmatch.filter(os.listdir(self.path_to_input), '*.wav')
        # count the number of samples contained in the dataset
        self.total_samples = 0
        for file in self.file_names:
            # only the wav header is parsed here, the audio is not decoded
            info = WavInfoReader(os.path.join(self.path_to_input, file))
            self.total_samples = self.total_samples + \
                int(np.fix(info.data.frame_count/self.len_of_samples))


    def create_generator(self):
        '''
        Method to create the iterator.
        Yields (noisy, clean) pairs of float32 chunks, each of length
        len_of_samples. Trailing audio shorter than one chunk is dropped.
        '''

        # check if training or validation
        if self.train_flag:
            shuffle(self.file_names)
        # iterate over the files
        for file in self.file_names:
            # read the audio files (mixture and target share the file name)
            noisy, fs_1 = sf.read(os.path.join(self.path_to_input, file))
            speech, fs_2 = sf.read(os.path.join(self.path_to_s1, file))
            # check if the sampling rates are matching the specifications
            if fs_1 != self.fs or fs_2 != self.fs:
                raise ValueError('Sampling rates do not match.')
            if noisy.ndim != 1 or speech.ndim != 1:
                raise ValueError('Too many audio channels. The DTLN audio_generator \
                                 only supports single channel audio data.')
            # count the number of samples in one file
            num_samples = int(np.fix(noisy.shape[0]/self.len_of_samples))
            # iterate over the number of samples
            for idx in range(num_samples):
                # cut the audio files in chunks
                in_dat = noisy[int(idx*self.len_of_samples):int((idx+1)*
                    self.len_of_samples)]
                tar_dat = speech[int(idx*self.len_of_samples):int((idx+1)*
                    self.len_of_samples)]
                # yield the chunks as float32 data
                yield in_dat.astype('float32'), tar_dat.astype('float32')


    def create_tf_data_obj(self):
        '''
        Method to to create the tf.data.Dataset.
        Sets self.tf_data_set, a dataset of (noisy, clean) float32 pairs.
        '''

        # creating the tf.data.Dataset from the iterator
        self.tf_data_set = tf.data.Dataset.from_generator(
            self.create_generator,
            (tf.float32, tf.float32),
            output_shapes=(tf.TensorShape([self.len_of_samples]), \
                           tf.TensorShape([self.len_of_samples])),
            args=None
            )
| 120 |
+
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class DTLN_model():
    '''
    Class to create and train the DTLN model
    '''

    def __init__(self):
        '''
        Constructor. Sets the default hyper parameters, seeds all RNGs for
        reproducibility and enables GPU memory growth if GPUs are present.
        '''

        # defining default cost function
        self.cost_function = self.snr_cost
        # empty property for the model
        self.model = []
        # defining default parameters
        self.fs = 16000
        self.batchsize = 32
        self.len_samples = 15
        self.activation = 'sigmoid'
        self.numUnits = 128
        self.numLayer = 2
        self.blockLen = 512
        self.block_shift = 128
        self.dropout = 0.25
        self.lr = 1e-3
        self.max_epochs = 200
        self.encoder_size = 256
        self.eps = 1e-7
        # reset all seeds to 42 to reduce invariance between training runs
        os.environ['PYTHONHASHSEED']=str(42)
        seed(42)
        np.random.seed(42)
        tf.random.set_seed(42)
        # some line to correctly find some libraries in TF 2.x
        physical_devices = tf.config.experimental.list_physical_devices('GPU')
        if len(physical_devices) > 0:
            for device in physical_devices:
                tf.config.experimental.set_memory_growth(device, enable=True)


    @staticmethod
    def snr_cost(s_estimate, s_true):
        '''
        Static Method defining the cost function.
        The negative signal to noise ratio is calculated here. The loss is
        always calculated over the last dimension.
        '''

        # calculating the SNR
        snr = tf.reduce_mean(tf.math.square(s_true), axis=-1, keepdims=True) / \
            (tf.reduce_mean(tf.math.square(s_true-s_estimate), axis=-1, keepdims=True)+1e-7)
        # using some more lines, because TF has no log10
        num = tf.math.log(snr)
        denom = tf.math.log(tf.constant(10, dtype=num.dtype))
        loss = -10*(num / (denom))
        # returning the loss
        return loss


    def lossWrapper(self):
        '''
        A wrapper function which returns the loss function. This is done to
        to enable additional arguments to the loss function if necessary.
        '''
        def lossFunction(y_true,y_pred):
            # calculating loss and squeezing single dimensions away
            loss = tf.squeeze(self.cost_function(y_pred,y_true))
            # calculate mean over batches
            loss = tf.reduce_mean(loss)
            # return the loss
            return loss
        # returning the loss function as handle
        return lossFunction



    '''
    In the following some helper layers are defined.
    '''

    def stftLayer(self, x):
        '''
        Method for an STFT helper layer used with a Lambda layer. The layer
        calculates the STFT on the last dimension and returns the magnitude and
        phase of the STFT.
        '''

        # creating frames from the continuous waveform
        frames = tf.signal.frame(x, self.blockLen, self.block_shift)
        # calculating the fft over the time frames. rfft returns NFFT/2+1 bins.
        stft_dat = tf.signal.rfft(frames)
        # calculating magnitude and phase from the complex signal
        mag = tf.abs(stft_dat)
        phase = tf.math.angle(stft_dat)
        # returning magnitude and phase as list
        return [mag, phase]

    def fftLayer(self, x):
        '''
        Method for an fft helper layer used with a Lambda layer. The layer
        calculates the rFFT on the last dimension and returns the magnitude and
        phase of the STFT.
        '''

        # expanding dimensions
        frame = tf.expand_dims(x, axis=1)
        # calculating the fft over the time frames. rfft returns NFFT/2+1 bins.
        stft_dat = tf.signal.rfft(frame)
        # calculating magnitude and phase from the complex signal
        mag = tf.abs(stft_dat)
        phase = tf.math.angle(stft_dat)
        # returning magnitude and phase as list
        return [mag, phase]



    def ifftLayer(self, x):
        '''
        Method for an inverse FFT layer used with an Lambda layer. This layer
        calculates time domain frames from magnitude and phase information.
        As input x a list with [mag,phase] is required.
        '''

        # calculating the complex representation
        s1_stft = (tf.cast(x[0], tf.complex64) *
                    tf.exp( (1j * tf.cast(x[1], tf.complex64))))
        # returning the time domain frames
        return tf.signal.irfft(s1_stft)


    def overlapAddLayer(self, x):
        '''
        Method for an overlap and add helper layer used with a Lambda layer.
        This layer reconstructs the waveform from a framed signal.
        '''

        # calculating and returning the reconstructed waveform
        return tf.signal.overlap_and_add(x, self.block_shift)



    def seperation_kernel(self, num_layer, mask_size, x, stateful=False):
        '''
        Method to create a separation kernel.
        !! Important !!: Do not use this layer with a Lambda layer. If used with
        a Lambda layer the gradients are not updated correctly.

        Inputs:
            num_layer       Number of LSTM layers
            mask_size       Output size of the mask and size of the Dense layer
        '''

        # creating num_layer number of LSTM layers
        for idx in range(num_layer):
            x = LSTM(self.numUnits, return_sequences=True, stateful=stateful)(x)
            # using dropout between the LSTM layer for regularization
            if idx<(num_layer-1):
                x = Dropout(self.dropout)(x)
        # creating the mask with a Dense and an Activation layer
        mask = Dense(mask_size)(x)
        mask = Activation(self.activation)(mask)
        # returning the mask
        return mask

    def seperation_kernel_with_states(self, num_layer, mask_size, x,
                                      in_states):
        '''
        Method to create a separation kernel, which returns the LSTM states.
        !! Important !!: Do not use this layer with a Lambda layer. If used with
        a Lambda layer the gradients are not updated correctly.

        Inputs:
            num_layer       Number of LSTM layers
            mask_size       Output size of the mask and size of the Dense layer
        '''

        states_h = []
        states_c = []
        # creating num_layer number of LSTM layers
        for idx in range(num_layer):
            in_state = [in_states[:,idx,:, 0], in_states[:,idx,:, 1]]
            x, h_state, c_state = LSTM(self.numUnits, return_sequences=True,
                     unroll=True, return_state=True)(x, initial_state=in_state)
            # using dropout between the LSTM layer for regularization
            if idx<(num_layer-1):
                x = Dropout(self.dropout)(x)
            states_h.append(h_state)
            states_c.append(c_state)
        # creating the mask with a Dense and an Activation layer
        mask = Dense(mask_size)(x)
        mask = Activation(self.activation)(mask)
        out_states_h = tf.reshape(tf.stack(states_h, axis=0),
                                  [1,num_layer,self.numUnits])
        out_states_c = tf.reshape(tf.stack(states_c, axis=0),
                                  [1,num_layer,self.numUnits])
        out_states = tf.stack([out_states_h, out_states_c], axis=-1)
        # returning the mask and states
        return mask, out_states

    def build_DTLN_model(self, norm_stft=False):
        '''
        Method to build and compile the DTLN model. The model takes time domain
        batches of size (batchsize, len_in_samples) and returns enhanced clips
        in the same dimensions. As optimizer for the Training process the Adam
        optimizer with a gradient norm clipping of 3 is used.
        The model contains two separation cores. The first has an STFT signal
        transformation and the second a learned transformation based on 1D-Conv
        layer.
        '''

        # input layer for time signal
        time_dat = Input(batch_shape=(None, None))
        # calculate STFT
        mag,angle = Lambda(self.stftLayer)(time_dat)
        # normalizing log magnitude stfts to get more robust against level variations
        if norm_stft:
            mag_norm = InstantLayerNormalization()(tf.math.log(mag + 1e-7))
        else:
            # behaviour like in the paper
            mag_norm = mag
        # predicting mask with separation kernel
        mask_1 = self.seperation_kernel(self.numLayer, (self.blockLen//2+1), mag_norm)
        # multiply mask with magnitude
        estimated_mag = Multiply()([mag, mask_1])
        # transform frames back to time domain
        estimated_frames_1 = Lambda(self.ifftLayer)([estimated_mag,angle])
        # encode time domain frames to feature domain
        encoded_frames = Conv1D(self.encoder_size,1,strides=1,use_bias=False)(estimated_frames_1)
        # normalize the input to the separation kernel
        encoded_frames_norm = InstantLayerNormalization()(encoded_frames)
        # predict mask based on the normalized feature frames
        mask_2 = self.seperation_kernel(self.numLayer, self.encoder_size, encoded_frames_norm)
        # multiply encoded frames with the mask
        estimated = Multiply()([encoded_frames, mask_2])
        # decode the frames back to time domain
        decoded_frames = Conv1D(self.blockLen, 1, padding='causal',use_bias=False)(estimated)
        # create waveform with overlap and add procedure
        estimated_sig = Lambda(self.overlapAddLayer)(decoded_frames)


        # create the model
        self.model = Model(inputs=time_dat, outputs=estimated_sig)
        # show the model summary (summary() prints itself; wrapping it in
        # print() would emit a spurious "None" line)
        self.model.summary()

    def build_DTLN_model_stateful(self, norm_stft=False):
        '''
        Method to build stateful DTLN model for real time processing. The model
        takes one time domain frame of size (1, blockLen) and one enhanced frame.

        '''

        # input layer for time signal
        time_dat = Input(batch_shape=(1, self.blockLen))
        # calculate STFT
        mag,angle = Lambda(self.fftLayer)(time_dat)
        # normalizing log magnitude stfts to get more robust against level variations
        if norm_stft:
            mag_norm = InstantLayerNormalization()(tf.math.log(mag + 1e-7))
        else:
            # behaviour like in the paper
            mag_norm = mag
        # predicting mask with separation kernel
        mask_1 = self.seperation_kernel(self.numLayer, (self.blockLen//2+1), mag_norm, stateful=True)
        # multiply mask with magnitude
        estimated_mag = Multiply()([mag, mask_1])
        # transform frames back to time domain
        estimated_frames_1 = Lambda(self.ifftLayer)([estimated_mag,angle])
        # encode time domain frames to feature domain
        encoded_frames = Conv1D(self.encoder_size,1,strides=1,use_bias=False)(estimated_frames_1)
        # normalize the input to the separation kernel
        encoded_frames_norm = InstantLayerNormalization()(encoded_frames)
        # predict mask based on the normalized feature frames
        mask_2 = self.seperation_kernel(self.numLayer, self.encoder_size, encoded_frames_norm, stateful=True)
        # multiply encoded frames with the mask
        estimated = Multiply()([encoded_frames, mask_2])
        # decode the frames back to time domain
        decoded_frame = Conv1D(self.blockLen, 1, padding='causal',use_bias=False)(estimated)
        # create the model
        self.model = Model(inputs=time_dat, outputs=decoded_frame)
        # show the model summary (summary() prints itself)
        self.model.summary()

    def compile_model(self):
        '''
        Method to compile the model for training

        '''

        # use the Adam optimizer with a clipnorm of 3
        # NOTE: the keyword is 'learning_rate'; the old alias 'lr' is
        # deprecated and removed in newer Keras versions
        optimizerAdam = keras.optimizers.Adam(learning_rate=self.lr, clipnorm=3.0)
        # compile model with loss function
        self.model.compile(loss=self.lossWrapper(), optimizer=optimizerAdam)

    def create_saved_model(self, weights_file, target_name):
        '''
        Method to create a saved model folder from a weights file

        '''
        # check for type (norm models carry '_norm_' in the weights file name)
        if weights_file.find('_norm_') != -1:
            norm_stft = True
        else:
            norm_stft = False
        # build model
        self.build_DTLN_model_stateful(norm_stft=norm_stft)
        # load weights
        self.model.load_weights(weights_file)
        # save model
        tf.saved_model.save(self.model, target_name)

    def create_tf_lite_model(self, weights_file, target_name, use_dynamic_range_quant=False):
        '''
        Method to create a tf lite model folder from a weights file.
        The conversion creates two models, one for each separation core.
        Tf lite does not support complex numbers yet. Some processing must be
        done outside the model.
        For further information and how real time processing can be
        implemented see "real_time_processing_tf_lite.py".

        The conversion only works with TF 2.3.

        '''
        # check for type; the element counts are used below to split the
        # weight list of the full model between the two cores
        if weights_file.find('_norm_') != -1:
            norm_stft = True
            num_elements_first_core = 2 + self.numLayer * 3 + 2
        else:
            norm_stft = False
            num_elements_first_core = self.numLayer * 3 + 2
        # build model
        self.build_DTLN_model_stateful(norm_stft=norm_stft)
        # load weights
        self.model.load_weights(weights_file)

        #### Model 1 ##########################
        mag = Input(batch_shape=(1, 1, (self.blockLen//2+1)))
        states_in_1 = Input(batch_shape=(1, self.numLayer, self.numUnits, 2))
        # normalizing log magnitude stfts to get more robust against level variations
        if norm_stft:
            mag_norm = InstantLayerNormalization()(tf.math.log(mag + 1e-7))
        else:
            # behaviour like in the paper
            mag_norm = mag
        # predicting mask with separation kernel
        mask_1, states_out_1 = self.seperation_kernel_with_states(self.numLayer,
                                                    (self.blockLen//2+1),
                                                    mag_norm, states_in_1)

        model_1 = Model(inputs=[mag, states_in_1], outputs=[mask_1, states_out_1])

        #### Model 2 ###########################

        estimated_frame_1 = Input(batch_shape=(1, 1, (self.blockLen)))
        states_in_2 = Input(batch_shape=(1, self.numLayer, self.numUnits, 2))

        # encode time domain frames to feature domain
        encoded_frames = Conv1D(self.encoder_size,1,strides=1,
                                use_bias=False)(estimated_frame_1)
        # normalize the input to the separation kernel
        encoded_frames_norm = InstantLayerNormalization()(encoded_frames)
        # predict mask based on the normalized feature frames
        mask_2, states_out_2 = self.seperation_kernel_with_states(self.numLayer,
                                                          self.encoder_size,
                                                          encoded_frames_norm,
                                                          states_in_2)
        # multiply encoded frames with the mask
        estimated = Multiply()([encoded_frames, mask_2])
        # decode the frames back to time domain
        decoded_frame = Conv1D(self.blockLen, 1, padding='causal',
                               use_bias=False)(estimated)

        model_2 = Model(inputs=[estimated_frame_1, states_in_2],
                        outputs=[decoded_frame, states_out_2])

        # set weights to submodels
        weights = self.model.get_weights()
        model_1.set_weights(weights[:num_elements_first_core])
        model_2.set_weights(weights[num_elements_first_core:])
        # convert first model
        converter = tf.lite.TFLiteConverter.from_keras_model(model_1)
        if use_dynamic_range_quant:
            converter.optimizations = [tf.lite.Optimize.DEFAULT]
        tflite_model = converter.convert()
        with tf.io.gfile.GFile(target_name + '_1.tflite', 'wb') as f:
            f.write(tflite_model)
        # convert second model
        converter = tf.lite.TFLiteConverter.from_keras_model(model_2)
        if use_dynamic_range_quant:
            converter.optimizations = [tf.lite.Optimize.DEFAULT]
        tflite_model = converter.convert()
        with tf.io.gfile.GFile(target_name + '_2.tflite', 'wb') as f:
            f.write(tflite_model)

        print('TF lite conversion complete!')


    def train_model(self, runName, path_to_train_mix, path_to_train_speech, \
                    path_to_val_mix, path_to_val_speech):
        '''
        Method to train the DTLN model.
        '''

        # create save path if not existent
        savePath = './models_'+ runName+'/'
        if not os.path.exists(savePath):
            os.makedirs(savePath)
        # create log file writer
        csv_logger = CSVLogger(savePath+ 'training_' +runName+ '.log')
        # create callback for the adaptive learning rate
        reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5,
                                      patience=3, min_lr=10**(-10), cooldown=1)
        # create callback for early stopping
        early_stopping = EarlyStopping(monitor='val_loss', min_delta=0,
                                       patience=10, verbose=0, mode='auto', baseline=None)
        # create model check pointer to save the best model
        checkpointer = ModelCheckpoint(savePath+runName+'.h5',
                                       monitor='val_loss',
                                       verbose=1,
                                       save_best_only=True,
                                       save_weights_only=True,
                                       mode='auto',
                                       save_freq='epoch'
                                       )

        # calculate length of audio chunks in samples
        len_in_samples = int(np.fix(self.fs * self.len_samples /
                                    self.block_shift)*self.block_shift)
        # create data generator for training data
        generator_input = audio_generator(path_to_train_mix,
                                          path_to_train_speech,
                                          len_in_samples,
                                          self.fs, train_flag=True)
        dataset = generator_input.tf_data_set
        dataset = dataset.batch(self.batchsize, drop_remainder=True).repeat()
        # calculate number of training steps in one epoch
        steps_train = generator_input.total_samples//self.batchsize
        # create data generator for validation data
        generator_val = audio_generator(path_to_val_mix,
                                        path_to_val_speech,
                                        len_in_samples, self.fs)
        dataset_val = generator_val.tf_data_set
        dataset_val = dataset_val.batch(self.batchsize, drop_remainder=True).repeat()
        # calculate number of validation steps
        steps_val = generator_val.total_samples//self.batchsize
        # start the training of the model
        self.model.fit(
            x=dataset,
            batch_size=None,
            steps_per_epoch=steps_train,
            epochs=self.max_epochs,
            verbose=1,
            validation_data=dataset_val,
            validation_steps=steps_val,
            callbacks=[checkpointer, reduce_lr, csv_logger, early_stopping],
            max_queue_size=50,
            workers=4,
            use_multiprocessing=True)
        # clear out garbage
        tf.keras.backend.clear_session()
| 584 |
+
|
| 585 |
+
|
| 586 |
+
class InstantLayerNormalization(Layer):
    '''
    Instant (channel-wise) layer normalization as proposed by
    Luo & Mesgarani (https://arxiv.org/abs/1809.07454v2).
    Each frame is normalized independently over its last axis and then
    scaled and shifted by the learnable parameters gamma and beta.
    '''

    def __init__(self, **kwargs):
        '''
        Constructor: forwards kwargs to the Keras base Layer and prepares
        the (lazily created) trainable parameters.
        '''
        super(InstantLayerNormalization, self).__init__(**kwargs)
        # small constant to avoid division by zero
        self.epsilon = 1e-7
        # learnable scale and shift, created in build()
        self.gamma = None
        self.beta = None

    def build(self, input_shape):
        '''
        Create the per-channel scale (gamma) and shift (beta) weights.
        '''
        param_shape = input_shape[-1:]
        # scale parameter, initialized to one
        self.gamma = self.add_weight(shape=param_shape,
                                     initializer='ones',
                                     trainable=True,
                                     name='gamma')
        # shift parameter, initialized to zero
        self.beta = self.add_weight(shape=param_shape,
                                    initializer='zeros',
                                    trainable=True,
                                    name='beta')


    def call(self, inputs):
        '''
        Normalize each frame over its last axis, then apply gamma and beta.
        '''
        # per-frame mean and centered signal
        frame_mean = tf.math.reduce_mean(inputs, axis=[-1], keepdims=True)
        centered = inputs - frame_mean
        # per-frame variance computed from the centered signal
        frame_var = tf.math.reduce_mean(tf.math.square(centered),
                                        axis=[-1], keepdims=True)
        # standard deviation with epsilon for numerical stability
        frame_std = tf.math.sqrt(frame_var + self.epsilon)
        # normalize, scale with gamma and shift by beta
        return (centered / frame_std) * self.gamma + self.beta
| 640 |
+
|
| 641 |
+
|
models/DTLN (yash-04)/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2020 Nils L. Westhausen
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
models/DTLN (yash-04)/README.md
ADDED
|
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Dual-signal Transformation LSTM Network
|
| 2 |
+
|
| 3 |
+
+ Tensorflow 2.x implementation of the stacked dual-signal transformation LSTM network (DTLN) for real-time noise suppression.
|
| 4 |
+
+ This repository provides the code for training, inferring and serving the DTLN model in Python. It also provides pretrained models in SavedModel, TF-lite and ONNX format, which can be used as a baseline for your own projects. The model is able to run with real time audio on a Raspberry Pi.
|
| 5 |
+
+ If you are doing cool things with this repo, tell me about it. I am always curious about what you are doing with this code or this models.
|
| 6 |
+
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
The DTLN model was handed in to the deep noise suppression challenge ([DNS-Challenge](https://github.com/microsoft/DNS-Challenge)) and the paper was presented at Interspeech 2020.
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
This approach combines a short-time Fourier transform (STFT) and a learned analysis and synthesis basis in a stacked-network approach with less than one million parameters. The model was trained on 500h of noisy speech provided by the challenge organizers. The network is capable of real-time processing (one frame in, one frame out) and reaches competitive results.
|
| 13 |
+
Combining these two types of signal transformations enables the DTLN to robustly extract information from magnitude spectra and incorporate phase information from the learned feature basis. The method shows state-of-the-art performance and outperforms the DNS-Challenge baseline by 0.24 points absolute in terms of the mean opinion score (MOS).
|
| 14 |
+
|
| 15 |
+
For more information see the [paper](https://www.isca-speech.org/archive/interspeech_2020/westhausen20_interspeech.html). The results of the DNS-Challenge are published [here](https://www.microsoft.com/en-us/research/academic-program/deep-noise-suppression-challenge-interspeech-2020/#!results). We reached a competitive 8th place out of 17 teams in the real time track.
|
| 16 |
+
|
| 17 |
+
---
|
| 18 |
+
|
| 19 |
+
For baseline usage and to reproduce the processing used for the paper run:
|
| 20 |
+
```bash
|
| 21 |
+
$ python run_evaluation.py -i in/folder/with/wav -o target/folder/processed/files -m ./pretrained_model/model.h5
|
| 22 |
+
```
|
| 23 |
+
|
| 24 |
+
---
|
| 25 |
+
|
| 26 |
+
The pretrained DTLN-aec (the DTLN applied to acoustic echo cancellation) can be found in the [DTLN-aec repository](https://github.com/breizhn/DTLN-aec).
|
| 27 |
+
|
| 28 |
+
---
|
| 29 |
+
|
| 30 |
+
Author: Nils L. Westhausen ([Communication Acoustics](https://uol.de/en/kommunikationsakustik) , Carl von Ossietzky University, Oldenburg, Germany)
|
| 31 |
+
|
| 32 |
+
This code is licensed under the terms of the MIT license.
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
---
|
| 36 |
+
### Citing:
|
| 37 |
+
|
| 38 |
+
If you are using the DTLN model, please cite:
|
| 39 |
+
|
| 40 |
+
```BibTex
|
| 41 |
+
@inproceedings{Westhausen2020,
|
| 42 |
+
author={Nils L. Westhausen and Bernd T. Meyer},
|
| 43 |
+
title={{Dual-Signal Transformation LSTM Network for Real-Time Noise Suppression}},
|
| 44 |
+
year=2020,
|
| 45 |
+
booktitle={Proc. Interspeech 2020},
|
| 46 |
+
pages={2477--2481},
|
| 47 |
+
doi={10.21437/Interspeech.2020-2631},
|
| 48 |
+
url={http://dx.doi.org/10.21437/Interspeech.2020-2631}
|
| 49 |
+
}
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
---
|
| 54 |
+
### Contents of the README:
|
| 55 |
+
|
| 56 |
+
* [Results](#results)
|
| 57 |
+
* [Execution Times](#execution-times)
|
| 58 |
+
* [Audio Samples](#audio-samples)
|
| 59 |
+
* [Contents of the repository](#contents-of-the-repository)
|
| 60 |
+
* [Python dependencies](#python-dependencies)
|
| 61 |
+
* [Training data preparation](#training-data-preparation)
|
| 62 |
+
* [Run a training of the DTLN model](#run-a-training-of-the-dtln-model)
|
| 63 |
+
* [Measuring the execution time of the DTLN model with the SavedModel format](#measuring-the-execution-time-of-the-dtln-model-with-the-savedmodel-format)
|
| 64 |
+
* [Real time processing with the SavedModel format](#real-time-processing-with-the-savedmodel-format)
|
| 65 |
+
* [Real time processing with tf-lite](#real-time-processing-with-tf-lite)
|
| 66 |
+
* [Real time audio with sounddevice and tf-lite](#real-time-audio-with-sounddevice-and-tf-lite)
|
| 67 |
+
* [Model conversion and real time processing with ONNX](#model-conversion-and-real-time-processing-with-onnx)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
---
|
| 72 |
+
### Results:
|
| 73 |
+
|
| 74 |
+
Results on the DNS-Challenge non reverberant test set:
|
| 75 |
+
Model | PESQ [mos] | STOI [%] | SI-SDR [dB] | TF version
|
| 76 |
+
--- | --- | --- | --- | ---
|
| 77 |
+
unprocessed | 2.45 | 91.52 | 9.07 |
|
| 78 |
+
NsNet (Baseline) | 2.70 | 90.56 | 12.57 |
|
| 79 |
+
| | | |
|
| 80 |
+
DTLN (500h) | 3.04 | 94.76 | 16.34 | 2.1
|
| 81 |
+
DTLN (500h)| 2.98 | 94.75 | 16.20 | TF-light
|
| 82 |
+
DTLN (500h) | 2.95 | 94.47 | 15.71 | TF-light quantized
|
| 83 |
+
| | | |
|
| 84 |
+
DTLN norm (500h) | 3.04 | 94.47 | 16.10 | 2.2
|
| 85 |
+
| | | |
|
| 86 |
+
DTLN norm (40h) | 3.05 | 94.57 | 16.88 | 2.2
|
| 87 |
+
DTLN norm (40h) | 2.98 | 94.56 | 16.58 | TF-light
|
| 88 |
+
DTLN norm (40h) | 2.98 | 94.51 | 16.22 | TF-light quantized
|
| 89 |
+
|
| 90 |
+
* The conversion to TF-light slightly reduces the performance.
|
| 91 |
+
* The dynamic range quantization of TF-light also reduces the performance a bit and introduces some quantization noise. But the audio-quality is still on a high level and the model is real-time capable on the Raspberry Pi 3 B+.
|
| 92 |
+
* The normalization of the log magnitude of the STFT does not decrease the model performance and makes it more robust against level variations.
|
| 93 |
+
* With data augmentation during training it is possible to train the DTLN model on just 40h of noise and speech data. If you have any question regarding this, just contact me.
|
| 94 |
+
|
| 95 |
+
[To contents](#contents-of-the-readme)
|
| 96 |
+
|
| 97 |
+
---
|
| 98 |
+
|
| 99 |
+
### Execution Times:
|
| 100 |
+
|
| 101 |
+
Execution times for SavedModel are measured with TF 2.2 and for TF-lite with the TF-lite runtime:
|
| 102 |
+
System | Processor | #Cores | SavedModel | TF-lite | TF-lite quantized
|
| 103 |
+
--- | --- | --- | --- | --- | ---
|
| 104 |
+
Ubuntu 18.04 | Intel I5 6600k @ 3.5 GHz | 4 | 0.65 ms | 0.36 ms | 0.27 ms
|
| 105 |
+
Macbook Air mid 2012 | Intel I7 3667U @ 2.0 GHz | 2 | 1.4 ms | 0.6 ms | 0.4 ms
|
| 106 |
+
Raspberry Pi 3 B+ | ARM Cortex A53 @ 1.4 GHz | 4 | 15.54 ms | 9.6 ms | 2.2 ms
|
| 107 |
+
|
| 108 |
+
For real-time capability the execution time must be below 8 ms.
|
| 109 |
+
|
| 110 |
+
[To contents](#contents-of-the-readme)
|
| 111 |
+
|
| 112 |
+
---
|
| 113 |
+
|
| 114 |
+
### Audio Samples:
|
| 115 |
+
|
| 116 |
+
Here some audio samples created with the tf-lite model. Sadly audio can not be integrated directly into markdown.
|
| 117 |
+
|
| 118 |
+
Noisy | Enhanced | Noise type
|
| 119 |
+
--- | --- | ---
|
| 120 |
+
[Sample 1](https://cloudsync.uol.de/s/GFHzmWWJAwgQPLf) | [Sample 1](https://cloudsync.uol.de/s/p3M48y7cjkJ2ZZg) | Air conditioning
|
| 121 |
+
[Sample 2](https://cloudsync.uol.de/s/4Y2PoSpJf7nXx9T) | [Sample 2](https://cloudsync.uol.de/s/QeK4aH5KCELPnko) | Music
|
| 122 |
+
[Sample 3](https://cloudsync.uol.de/s/Awc6oBtnTpb5pY7) | [Sample 3](https://cloudsync.uol.de/s/yNsmDgxH3MPWMTi) | Bus
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
[To contents](#contents-of-the-readme)
|
| 126 |
+
|
| 127 |
+
---
|
| 128 |
+
### Contents of the repository:
|
| 129 |
+
|
| 130 |
+
* **DTLN_model.py** \
|
| 131 |
+
This file is containing the model, data generator and the training routine.
|
| 132 |
+
* **run_training.py** \
|
| 133 |
+
Script to run the training. Before you can start the training with `$ python run_training.py` you have to set the paths to your training and validation data inside the script. The training script uses a default setup.
|
| 134 |
+
* **run_evaluation.py** \
|
| 135 |
+
Script to process a folder with optional subfolders containing .wav files with a trained DTLN model. With the pretrained model delivered with this repository a folder can be processed as following: \
|
| 136 |
+
`$ python run_evaluation.py -i /path/to/input -o /path/for/processed -m ./pretrained_model/model.h5` \
|
| 137 |
+
The evaluation script will create the new folder with the same structure as the input folder and the files will have the same name as the input files.
|
| 138 |
+
* **measure_execution_time.py** \
|
| 139 |
+
Script for measuring the execution time with the saved DTLN model in `./pretrained_model/dtln_saved_model/`. For further information see this [section](#measuring-the-execution-time-of-the-dtln-model-with-the-savedmodel-format).
|
| 140 |
+
* **real_time_processing.py** \
|
| 141 |
+
Script, which explains how real time processing with the SavedModel works. For more information see this [section](#real-time-processing-with-the-savedmodel-format).
|
| 142 |
+
+ **./pretrained_model/** \
|
| 143 |
+
* `model.h5`: Model weights as used in the DNS-Challenge DTLN model.
|
| 144 |
+
* `DTLN_norm_500h.h5`: Model weights trained on 500h with normalization of stft log magnitudes.
|
| 145 |
+
* `DTLN_norm_40h.h5`: Model weights trained on 40h with normalization of stft log magnitudes.
|
| 146 |
+
* `./dtln_saved_model`: same as `model.h5` but as a stateful model in SavedModel format.
|
| 147 |
+
* `./DTLN_norm_500h_saved_model`: same as `DTLN_norm_500h.h5` but as a stateful model in SavedModel format.
|
| 148 |
+
* `./DTLN_norm_40h_saved_model`: same as `DTLN_norm_40h.h5` but as a stateful model in SavedModel format.
|
| 149 |
+
* `model_1.tflite` together with `model_2.tflite`: same as `model.h5` but as TF-lite model with external state handling.
|
| 150 |
+
* `model_quant_1.tflite` together with `model_quant_2.tflite`: same as `model.h5` but as TF-lite model with external state handling and dynamic range quantization.
|
| 151 |
+
* `model_1.onnx` together with `model_2.onnx`: same as `model.h5` but as ONNX model with external state handling.
|
| 152 |
+
|
| 153 |
+
[To contents](#contents-of-the-readme)
|
| 154 |
+
|
| 155 |
+
---
|
| 156 |
+
### Python dependencies:
|
| 157 |
+
|
| 158 |
+
The following packages will be required for this repository:
|
| 159 |
+
* TensorFlow (2.x)
|
| 160 |
+
* librosa
|
| 161 |
+
* wavinfo
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
All additional packages (numpy, soundfile, etc.) should be installed on the fly when using conda or pip. I recommend using conda environments or [pyenv](https://github.com/pyenv/pyenv) [virtualenv](https://github.com/pyenv/pyenv-virtualenv) for the python environment. For training a GPU with at least 5 GB of memory is required. I recommend at least Tensorflow 2.1 with Nvidia driver 418 and Cuda 10.1. If you use conda Cuda will be installed on the fly and you just need the driver. For evaluation-only the CPU version of Tensorflow is enough. Everything was tested on Ubuntu 18.04.
|
| 165 |
+
|
| 166 |
+
Conda environments for training (with cuda) and for evaluation (CPU only) can be created as following:
|
| 167 |
+
|
| 168 |
+
For the training environment:
|
| 169 |
+
```shell
|
| 170 |
+
$ conda env create -f train_env.yml
|
| 171 |
+
```
|
| 172 |
+
For the evaluation environment:
|
| 173 |
+
```
|
| 174 |
+
$ conda env create -f eval_env.yml
|
| 175 |
+
```
|
| 176 |
+
For the tf-lite environment:
|
| 177 |
+
```
|
| 178 |
+
$ conda env create -f tflite_env.yml
|
| 179 |
+
```
|
| 180 |
+
The tf-lite runtime must be downloaded from [here](https://www.tensorflow.org/lite/guide/python).
|
| 181 |
+
|
| 182 |
+
[To contents](#contents-of-the-readme)
|
| 183 |
+
|
| 184 |
+
---
|
| 185 |
+
### Training data preparation:
|
| 186 |
+
|
| 187 |
+
1. Clone the forked DNS-Challenge [repository](https://github.com/breizhn/DNS-Challenge). Before cloning the repository make sure `git-lfs` is installed. Also make sure your disk has enough space. I recommend downloading the data to an SSD for faster dataset creation.
|
| 188 |
+
|
| 189 |
+
2. Run `noisyspeech_synthesizer_multiprocessing.py` to create the dataset. `noisyspeech_synthesizer.cfg` was changed according to my training setup used for the DNS-Challenge.
|
| 190 |
+
|
| 191 |
+
3. Run `split_dns_corpus.py` to divide the dataset into training and validation data. The classic 80:20 split is applied. This file was added to the forked repository by me.
|
| 192 |
+
|
| 193 |
+
[To contents](#contents-of-the-readme)
|
| 194 |
+
|
| 195 |
+
---
|
| 196 |
+
### Run a training of the DTLN model:
|
| 197 |
+
|
| 198 |
+
1. Make sure all dependencies are installed in your python environment.
|
| 199 |
+
|
| 200 |
+
2. Change the paths to your training and validation dataset in `run_training.py`.
|
| 201 |
+
|
| 202 |
+
3. Run `$ python run_training.py`.
|
| 203 |
+
|
| 204 |
+
One epoch takes around 21 minutes on a Nvidia RTX 2080 Ti when loading the training data from an SSD.
|
| 205 |
+
|
| 206 |
+
[To contents](#contents-of-the-readme)
|
| 207 |
+
|
| 208 |
+
---
|
| 209 |
+
### Measuring the execution time of the DTLN model with the SavedModel format:
|
| 210 |
+
|
| 211 |
+
In total there are three ways to measure the execution time for one block of the model: running a sequence in Keras and dividing by the number of blocks in the sequence, building a stateful model in Keras and running block by block, and saving the stateful model in Tensorflow's SavedModel format and calling that one block by block. In the following I will explain how to run the model in the SavedModel format, because it is the most portable version and can also be called from Tensorflow Serving.
|
| 212 |
+
|
| 213 |
+
A Keras model can be saved to the saved model format:
|
| 214 |
+
```python
|
| 215 |
+
import tensorflow as tf
|
| 216 |
+
'''
|
| 217 |
+
Building some model here
|
| 218 |
+
'''
|
| 219 |
+
tf.saved_model.save(your_keras_model, 'name_save_path')
|
| 220 |
+
```
|
| 221 |
+
Important here for real time block by block processing is, to make the LSTM layer stateful, so they can remember the states from the previous block.
|
| 222 |
+
|
| 223 |
+
The model can be imported with
|
| 224 |
+
```python
|
| 225 |
+
model = tf.saved_model.load('name_save_path')
|
| 226 |
+
```
|
| 227 |
+
|
| 228 |
+
For inference we now first call this for mapping signature names to functions
|
| 229 |
+
```python
|
| 230 |
+
infer = model.signatures['serving_default']
|
| 231 |
+
```
|
| 232 |
+
|
| 233 |
+
and now for inferring the block `x` call
|
| 234 |
+
```python
|
| 235 |
+
y = infer(tf.constant(x))['conv1d_1']
|
| 236 |
+
```
|
| 237 |
+
This command gives you the result on the node `'conv1d_1'` which is our output node for real time processing. For more information on using the SavedModel format and obtaining the output node see this [Guide](https://www.tensorflow.org/guide/saved_model).
|
| 238 |
+
|
| 239 |
+
For making everything easier this repository provides a stateful DTLN SavedModel.
|
| 240 |
+
For measuring the execution time call:
|
| 241 |
+
```
|
| 242 |
+
$ python measure_execution_time.py
|
| 243 |
+
```
|
| 244 |
+
|
| 245 |
+
[To contents](#contents-of-the-readme)
|
| 246 |
+
|
| 247 |
+
---
|
| 248 |
+
|
| 249 |
+
### Real time processing with the SavedModel format:
|
| 250 |
+
|
| 251 |
+
For explanation look at `real_time_processing.py`.
|
| 252 |
+
|
| 253 |
+
Here some consideration for integrating this model in your project:
|
| 254 |
+
* The sampling rate of this model is fixed at 16 kHz. It will not work smoothly with other sampling rates.
|
| 255 |
+
* The block length of 32 ms and the block shift of 8 ms are also fixed. For changing these values, the model must be retrained.
|
| 256 |
+
* The delay created by the model is the block length, so the input-output delay is 32 ms.
|
| 257 |
+
* For real time capability on your system, the execution time must be below the length of the block shift, so below 8 ms.
|
| 258 |
+
* I can not give you support on the hardware side, regarding soundcards, drivers and so on. Be aware, a lot of artifacts can come from this side.
|
| 259 |
+
|
| 260 |
+
[To contents](#contents-of-the-readme)
|
| 261 |
+
|
| 262 |
+
---
|
| 263 |
+
### Real time processing with tf-lite:
|
| 264 |
+
|
| 265 |
+
With TF 2.3 it is finally possible to convert LSTMs to tf-lite. It is still not perfect because the states must be handled separately for a stateful model and tf-lite does not support complex numbers. That means that the model is split in two submodels when converting it to tf-lite and the calculation of the FFT and iFFT is performed outside the model. I provided an example script explaining how real time processing with the tf-lite model works (```real_time_processing_tf_lite.py```). In this script the tf-lite runtime is used. The runtime can be downloaded [here](https://www.tensorflow.org/lite/guide/python). Quantization works now.
|
| 266 |
+
|
| 267 |
+
Using the tf-lite DTLN model and the tf-lite runtime the execution time on an old Macbook Air mid 2012 can be decreased to **0.6 ms**.
|
| 268 |
+
|
| 269 |
+
[To contents](#contents-of-the-readme)
|
| 270 |
+
|
| 271 |
+
---
|
| 272 |
+
### Real time audio with sounddevice and tf-lite:
|
| 273 |
+
|
| 274 |
+
The file ```real_time_dtln_audio.py``` is an example of how real time audio with the tf-lite model and the [sounddevice](https://github.com/spatialaudio/python-sounddevice) toolbox can be implemented. The script is based on the ```wire.py``` example. It works fine on an old Macbook Air mid 2012 and so it will probably run on most newer devices. In the quantized version it was successfully tested on a Raspberry Pi 3 B+.
|
| 275 |
+
|
| 276 |
+
First check for your audio devices:
|
| 277 |
+
```
|
| 278 |
+
$ python real_time_dtln_audio.py --list-devices
|
| 279 |
+
```
|
| 280 |
+
Choose the index of an input and an output device and call:
|
| 281 |
+
```
|
| 282 |
+
$ python real_time_dtln_audio.py -i in_device_idx -o out_device_idx
|
| 283 |
+
```
|
| 284 |
+
If the script is showing too much ```input underflow``` restart the script. If that does not help, increase the latency with the ```--latency``` option. The default value is 0.2 .
|
| 285 |
+
|
| 286 |
+
[To contents](#contents-of-the-readme)
|
| 287 |
+
|
| 288 |
+
---
|
| 289 |
+
### Model conversion and real time processing with ONNX:
|
| 290 |
+
|
| 291 |
+
Finally I got the ONNX model working.
|
| 292 |
+
For converting the model TF 2.1 and keras2onnx is required. keras2onnx can be downloaded [here](https://github.com/onnx/keras-onnx) and must be installed from source as described in the README. When all dependencies are installed, call:
|
| 293 |
+
```
|
| 294 |
+
$ python convert_weights_to_onnx.py -m /name/of/the/model.h5 -t onnx_model_name
|
| 295 |
+
```
|
| 296 |
+
to convert the model to the ONNX format. The model is split in two parts as for the TF-lite model. The conversion does not work on MacOS.
|
| 297 |
+
The real time processing works similar to the TF-lite model and can be looked up in following file: ```real_time_processing_onnx.py ```
|
| 298 |
+
The ONNX runtime required for this script can be installed with:
|
| 299 |
+
```
|
| 300 |
+
$ pip install onnxruntime
|
| 301 |
+
```
|
| 302 |
+
The execution time on the Macbook Air mid 2012 is around 1.13 ms for one block.
|
models/DTLN (yash-04)/convert_weights_to_onnx.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Script to covert a .h5 weights file of the DTLN model to ONNX.
|
| 4 |
+
At the moment the conversion only works with TF 2.1 and not on Mac.
|
| 5 |
+
|
| 6 |
+
Example call:
|
| 7 |
+
$python convert_weights_to_ONNX.py -m /name/of/the/model.h5 \
|
| 8 |
+
-t name_target
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
Author: Nils L. Westhausen (nils.westhausen@uol.de)
|
| 12 |
+
Version: 03.07.2020
|
| 13 |
+
|
| 14 |
+
This code is licensed under the terms of the MIT-license.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from DTLN_model import DTLN_model, InstantLayerNormalization
|
| 18 |
+
import argparse
|
| 19 |
+
from tensorflow.keras.models import Model
|
| 20 |
+
from tensorflow.keras.layers import Input, Multiply, Conv1D
|
| 21 |
+
import tensorflow as tf
|
| 22 |
+
import keras2onnx
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
if __name__ == '__main__':
|
| 26 |
+
# arguement parser for running directly from the command line
|
| 27 |
+
parser = argparse.ArgumentParser(description='data evaluation')
|
| 28 |
+
parser.add_argument('--weights_file', '-m',
|
| 29 |
+
help='path to .h5 weights file')
|
| 30 |
+
parser.add_argument('--target_folder', '-t',
|
| 31 |
+
help='target folder for saved model')
|
| 32 |
+
|
| 33 |
+
args = parser.parse_args()
|
| 34 |
+
weights_file = args.weights_file
|
| 35 |
+
dtln_class = DTLN_model()
|
| 36 |
+
# check for type
|
| 37 |
+
if weights_file.find('_norm_') != -1:
|
| 38 |
+
norm_stft = True
|
| 39 |
+
num_elements_first_core = 2 + dtln_class.numLayer * 3 + 2
|
| 40 |
+
else:
|
| 41 |
+
norm_stft = False
|
| 42 |
+
num_elements_first_core = dtln_class.numLayer * 3 + 2
|
| 43 |
+
# build model
|
| 44 |
+
dtln_class.build_DTLN_model_stateful(norm_stft=norm_stft)
|
| 45 |
+
# load weights
|
| 46 |
+
dtln_class.model.load_weights(weights_file)
|
| 47 |
+
#### Model 1 ##########################
|
| 48 |
+
mag = Input(batch_shape=(1, 1, (dtln_class.blockLen//2+1)))
|
| 49 |
+
states_in_1 = Input(batch_shape=(1, dtln_class.numLayer, dtln_class.numUnits, 2))
|
| 50 |
+
# normalizing log magnitude stfts to get more robust against level variations
|
| 51 |
+
if norm_stft:
|
| 52 |
+
mag_norm = InstantLayerNormalization()(tf.math.log(mag + 1e-7))
|
| 53 |
+
else:
|
| 54 |
+
# behaviour like in the paper
|
| 55 |
+
mag_norm = mag
|
| 56 |
+
# predicting mask with separation kernel
|
| 57 |
+
mask_1, states_out_1 = dtln_class.seperation_kernel_with_states(dtln_class.numLayer,
|
| 58 |
+
(dtln_class.blockLen//2+1),
|
| 59 |
+
mag_norm, states_in_1)
|
| 60 |
+
|
| 61 |
+
model_1 = Model(inputs=[mag, states_in_1], outputs=[mask_1, states_out_1])
|
| 62 |
+
|
| 63 |
+
#### Model 2 ###########################
|
| 64 |
+
|
| 65 |
+
estimated_frame_1 = Input(batch_shape=(1, 1, (dtln_class.blockLen)))
|
| 66 |
+
states_in_2 = Input(batch_shape=(1, dtln_class.numLayer, dtln_class.numUnits, 2))
|
| 67 |
+
|
| 68 |
+
# encode time domain frames to feature domain
|
| 69 |
+
encoded_frames = Conv1D(dtln_class.encoder_size,1,strides=1,
|
| 70 |
+
use_bias=False)(estimated_frame_1)
|
| 71 |
+
# normalize the input to the separation kernel
|
| 72 |
+
encoded_frames_norm = InstantLayerNormalization()(encoded_frames)
|
| 73 |
+
# predict mask based on the normalized feature frames
|
| 74 |
+
mask_2, states_out_2 = dtln_class.seperation_kernel_with_states(dtln_class.numLayer,
|
| 75 |
+
dtln_class.encoder_size,
|
| 76 |
+
encoded_frames_norm,
|
| 77 |
+
states_in_2)
|
| 78 |
+
# multiply encoded frames with the mask
|
| 79 |
+
estimated = Multiply()([encoded_frames, mask_2])
|
| 80 |
+
# decode the frames back to time domain
|
| 81 |
+
decoded_frame = Conv1D(dtln_class.blockLen, 1, padding='causal',
|
| 82 |
+
use_bias=False)(estimated)
|
| 83 |
+
|
| 84 |
+
model_2 = Model(inputs=[estimated_frame_1, states_in_2],
|
| 85 |
+
outputs=[decoded_frame, states_out_2])
|
| 86 |
+
|
| 87 |
+
# set weights to submodels
|
| 88 |
+
weights = dtln_class.model.get_weights()
|
| 89 |
+
|
| 90 |
+
model_1.set_weights(weights[:num_elements_first_core])
|
| 91 |
+
model_2.set_weights(weights[num_elements_first_core:])
|
| 92 |
+
# convert first model
|
| 93 |
+
onnx_model = keras2onnx.convert_keras(model_1)
|
| 94 |
+
temp_model_file = args.target_folder + '_1.onnx'
|
| 95 |
+
keras2onnx.save_model(onnx_model, temp_model_file)
|
| 96 |
+
# convert second model
|
| 97 |
+
onnx_model = keras2onnx.convert_keras(model_2)
|
| 98 |
+
temp_model_file = args.target_folder + '_2.onnx'
|
| 99 |
+
keras2onnx.save_model(onnx_model, temp_model_file)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
print('ONNX conversion complete!')
|
models/DTLN (yash-04)/convert_weights_to_saved_model.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Script to covert a .h weights file of the DTLN model to the saved model format.
|
| 3 |
+
|
| 4 |
+
Example call:
|
| 5 |
+
$python convert_weights_to_saved_model.py -m /name/of/the/model.h5 \
|
| 6 |
+
-t name_target_folder
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
Author: Nils L. Westhausen (nils.westhausen@uol.de)
|
| 10 |
+
Version: 24.06.2020
|
| 11 |
+
|
| 12 |
+
This code is licensed under the terms of the MIT-license.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from DTLN_model import DTLN_model
|
| 16 |
+
import argparse
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
if __name__ == '__main__':
|
| 20 |
+
# arguement parser for running directly from the command line
|
| 21 |
+
parser = argparse.ArgumentParser(description='data evaluation')
|
| 22 |
+
parser.add_argument('--weights_file', '-m',
|
| 23 |
+
help='path to .h5 weights file')
|
| 24 |
+
parser.add_argument('--target_folder', '-t',
|
| 25 |
+
help='target folder for saved model')
|
| 26 |
+
|
| 27 |
+
args = parser.parse_args()
|
| 28 |
+
|
| 29 |
+
converter = DTLN_model()
|
| 30 |
+
converter.create_saved_model(args.weights_file, args.target_folder)
|
models/DTLN (yash-04)/convert_weights_to_tf_lite.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Script to covert a .h5 weights file of the DTLN model to tf lite.
|
| 3 |
+
|
| 4 |
+
Example call:
|
| 5 |
+
$python convert_weights_to_tf_light.py -m /name/of/the/model.h5 \
|
| 6 |
+
-t name_target
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
Author: Nils L. Westhausen (nils.westhausen@uol.de)
|
| 10 |
+
Version: 30.06.2020
|
| 11 |
+
|
| 12 |
+
This code is licensed under the terms of the MIT-license.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from DTLN_model import DTLN_model
|
| 16 |
+
import argparse
|
| 17 |
+
from pkg_resources import parse_version
|
| 18 |
+
import tensorflow as tf
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
if __name__ == '__main__':
|
| 22 |
+
# arguement parser for running directly from the command line
|
| 23 |
+
parser = argparse.ArgumentParser(description='data evaluation')
|
| 24 |
+
parser.add_argument('--weights_file', '-m',
|
| 25 |
+
help='path to .h5 weights file')
|
| 26 |
+
parser.add_argument('--target_folder', '-t',
|
| 27 |
+
help='target folder for saved model')
|
| 28 |
+
parser.add_argument('--quantization', '-q',
|
| 29 |
+
help='use quantization (True/False)',
|
| 30 |
+
default='False')
|
| 31 |
+
|
| 32 |
+
args = parser.parse_args()
|
| 33 |
+
if parse_version(tf.__version__) < parse_version('2.3.0-rc0'):
|
| 34 |
+
raise ValueError('Tf version < 2.3. Conversion of LSTMs will not work'+
|
| 35 |
+
+' with older tensorflow versions')
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
converter = DTLN_model()
|
| 39 |
+
converter.create_tf_lite_model(args.weights_file,
|
| 40 |
+
args.target_folder,
|
| 41 |
+
use_dynamic_range_quant=bool(args.quantization))
|
models/DTLN (yash-04)/eval_env.yml
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: eval_env
|
| 2 |
+
channels:
|
| 3 |
+
- conda-forge
|
| 4 |
+
- defaults
|
| 5 |
+
dependencies:
|
| 6 |
+
- _libgcc_mutex=0.1=main
|
| 7 |
+
- _tflow_select=2.3.0=mkl
|
| 8 |
+
- absl-py=0.9.0=py37_0
|
| 9 |
+
- astunparse=1.6.3=py_0
|
| 10 |
+
- audioread=2.1.8=py37hc8dfbb8_2
|
| 11 |
+
- blas=1.0=mkl
|
| 12 |
+
- blinker=1.4=py37_0
|
| 13 |
+
- brotlipy=0.7.0=py37h7b6447c_1000
|
| 14 |
+
- bzip2=1.0.8=h516909a_2
|
| 15 |
+
- c-ares=1.15.0=h7b6447c_1001
|
| 16 |
+
- ca-certificates=2020.4.5.2=hecda079_0
|
| 17 |
+
- cachetools=4.1.0=py_1
|
| 18 |
+
- certifi=2020.4.5.2=py37hc8dfbb8_0
|
| 19 |
+
- cffi=1.14.0=py37he30daa8_1
|
| 20 |
+
- chardet=3.0.4=py37_1003
|
| 21 |
+
- click=7.1.2=py_0
|
| 22 |
+
- cryptography=2.9.2=py37h1ba5d50_0
|
| 23 |
+
- cycler=0.10.0=py_2
|
| 24 |
+
- decorator=4.4.2=py_0
|
| 25 |
+
- ffmpeg=4.2.3=h167e202_0
|
| 26 |
+
- freetype=2.10.2=he06d7ca_0
|
| 27 |
+
- gast=0.3.3=py_0
|
| 28 |
+
- gettext=0.19.8.1=h5e8e0c9_1
|
| 29 |
+
- gmp=6.2.0=he1b5a44_2
|
| 30 |
+
- gnutls=3.6.13=h79a8f9a_0
|
| 31 |
+
- google-auth=1.14.1=py_0
|
| 32 |
+
- google-auth-oauthlib=0.4.1=py_2
|
| 33 |
+
- google-pasta=0.2.0=py_0
|
| 34 |
+
- grpcio=1.27.2=py37hf8bcb03_0
|
| 35 |
+
- h5py=2.10.0=py37h7918eee_0
|
| 36 |
+
- hdf5=1.10.4=hb1b8bf9_0
|
| 37 |
+
- icu=58.2=hf484d3e_1000
|
| 38 |
+
- idna=2.9=py_1
|
| 39 |
+
- intel-openmp=2020.1=217
|
| 40 |
+
- joblib=0.15.1=py_0
|
| 41 |
+
- keras-preprocessing=1.1.0=py_1
|
| 42 |
+
- kiwisolver=1.2.0=py37h99015e2_0
|
| 43 |
+
- lame=3.100=h14c3975_1001
|
| 44 |
+
- ld_impl_linux-64=2.33.1=h53a641e_7
|
| 45 |
+
- libedit=3.1.20181209=hc058e9b_0
|
| 46 |
+
- libffi=3.3=he6710b0_1
|
| 47 |
+
- libflac=1.3.3=he1b5a44_0
|
| 48 |
+
- libgcc-ng=9.1.0=hdf63c60_0
|
| 49 |
+
- libgfortran-ng=7.3.0=hdf63c60_0
|
| 50 |
+
- libiconv=1.15=h516909a_1006
|
| 51 |
+
- libllvm8=8.0.1=hc9558a2_0
|
| 52 |
+
- libogg=1.3.2=h516909a_1002
|
| 53 |
+
- libpng=1.6.37=hed695b0_1
|
| 54 |
+
- libprotobuf=3.12.3=hd408876_0
|
| 55 |
+
- librosa=0.7.2=py_1
|
| 56 |
+
- libsndfile=1.0.28=he1b5a44_1000
|
| 57 |
+
- libstdcxx-ng=9.1.0=hdf63c60_0
|
| 58 |
+
- libvorbis=1.3.6=he1b5a44_2
|
| 59 |
+
- llvmlite=0.31.0=py37h5202443_1
|
| 60 |
+
- markdown=3.1.1=py37_0
|
| 61 |
+
- matplotlib-base=3.2.1=py37hef1b27d_0
|
| 62 |
+
- mkl=2020.1=217
|
| 63 |
+
- mkl-service=2.3.0=py37he904b0f_0
|
| 64 |
+
- mkl_fft=1.1.0=py37h23d657b_0
|
| 65 |
+
- mkl_random=1.1.1=py37h0573a6f_0
|
| 66 |
+
- ncurses=6.2=he6710b0_1
|
| 67 |
+
- nettle=3.4.1=h1bed415_1002
|
| 68 |
+
- numba=0.48.0=py37hb3f55d8_0
|
| 69 |
+
- numpy=1.18.1=py37h4f9e942_0
|
| 70 |
+
- numpy-base=1.18.1=py37hde5b4d6_1
|
| 71 |
+
- oauthlib=3.1.0=py_0
|
| 72 |
+
- openh264=2.1.1=h8b12597_0
|
| 73 |
+
- openssl=1.1.1g=h516909a_0
|
| 74 |
+
- opt_einsum=3.1.0=py_0
|
| 75 |
+
- pip=20.1.1=py37_1
|
| 76 |
+
- protobuf=3.12.3=py37he6710b0_0
|
| 77 |
+
- pyasn1=0.4.8=py_0
|
| 78 |
+
- pyasn1-modules=0.2.7=py_0
|
| 79 |
+
- pycparser=2.20=py_0
|
| 80 |
+
- pyjwt=1.7.1=py37_0
|
| 81 |
+
- pyopenssl=19.1.0=py37_0
|
| 82 |
+
- pyparsing=2.4.7=pyh9f0ad1d_0
|
| 83 |
+
- pysocks=1.7.1=py37_0
|
| 84 |
+
- pysoundfile=0.10.2=py_1001
|
| 85 |
+
- python=3.7.7=hcff3b4d_5
|
| 86 |
+
- python-dateutil=2.8.1=py_0
|
| 87 |
+
- python_abi=3.7=1_cp37m
|
| 88 |
+
- readline=8.0=h7b6447c_0
|
| 89 |
+
- requests=2.23.0=py37_0
|
| 90 |
+
- requests-oauthlib=1.3.0=py_0
|
| 91 |
+
- resampy=0.2.2=py_0
|
| 92 |
+
- rsa=4.0=py_0
|
| 93 |
+
- scikit-learn=0.22.1=py37hd81dba3_0
|
| 94 |
+
- scipy=1.4.1=py37h0b6359f_0
|
| 95 |
+
- setuptools=47.3.0=py37_0
|
| 96 |
+
- six=1.15.0=py_0
|
| 97 |
+
- sqlite=3.31.1=h62c20be_1
|
| 98 |
+
- tensorboard=2.2.1=pyh532a8cf_0
|
| 99 |
+
- tensorboard-plugin-wit=1.6.0=py_0
|
| 100 |
+
- tensorflow=2.2.0=mkl_py37h6e9ce2d_0
|
| 101 |
+
- tensorflow-base=2.2.0=mkl_py37hd506778_0
|
| 102 |
+
- tensorflow-estimator=2.2.0=pyh208ff02_0
|
| 103 |
+
- termcolor=1.1.0=py37_1
|
| 104 |
+
- tk=8.6.8=hbc83047_0
|
| 105 |
+
- tornado=6.0.4=py37h8f50634_1
|
| 106 |
+
- urllib3=1.25.9=py_0
|
| 107 |
+
- werkzeug=1.0.1=py_0
|
| 108 |
+
- wheel=0.34.2=py37_0
|
| 109 |
+
- wrapt=1.12.1=py37h7b6447c_1
|
| 110 |
+
- x264=1!152.20180806=h14c3975_0
|
| 111 |
+
- xz=5.2.5=h7b6447c_0
|
| 112 |
+
- zlib=1.2.11=h7b6447c_3
|
| 113 |
+
- pip:
|
| 114 |
+
- lxml==4.5.1
|
| 115 |
+
- wavinfo==1.5
|
| 116 |
+
prefix: /home/nils/anaconda3/envs/eval_env
|
| 117 |
+
|
models/DTLN (yash-04)/measure_execution_time.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
This script tests the execution time of the DTLN model on a CPU.
Please use TF 2.2 for comparability.

Just run "python measure_execution_time.py"

Author: Nils L. Westhausen (nils.westhausen@uol.de)
Version: 13.05.2020

This code is licensed under the terms of the MIT-license.
"""

import time
import tensorflow as tf
import numpy as np
import os

# only use the cpu
os.environ["CUDA_VISIBLE_DEVICES"]=''

if __name__ == '__main__':
    # restore the stateful DTLN graph from its SavedModel export
    saved = tf.saved_model.load('./pretrained_model/dtln_saved_model')
    # resolve the default serving signature to a callable
    infer = saved.signatures["serving_default"]

    timings = []
    # one random block of 512 samples as dummy network input
    test_block = np.random.randn(1, 512).astype('float32')
    for _ in range(1010):
        tic = time.time()
        # infer one block on the 'conv1d_1' output node
        _out = infer(tf.constant(test_block))['conv1d_1']
        timings.append(time.time() - tic)
    # drop the first ten (warm-up) iterations before averaging
    mean_ms = np.round(np.mean(np.stack(timings[10:])) * 1000, 2)
    print('Execution time per block: ' + str(mean_ms) + ' ms')

    # Ubuntu 18.04 I5 6600k @ 3.5 GHz: 0.65 ms (4 cores)
    # Macbook Air mid 2012 I7 3667U @ 2.0 GHz: 1.4 ms (2 cores)
    # Raspberry Pi 3 B+ ARM Cortex A53 @ 1.4 GHz: 15.54 (4 cores)
models/DTLN (yash-04)/real_time_dtln_audio.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
This is a real time example how to implement DTLN tf light model with
|
| 5 |
+
sounddevice. The script is based on the "wire.py" example of the sounddevice
|
| 6 |
+
toolbox. If the command line shows "input underflow", restart the script.
|
| 7 |
+
If there are still a lot of dropouts, increase the latency.
|
| 8 |
+
|
| 9 |
+
First call:
|
| 10 |
+
|
| 11 |
+
$ python real_time_dtln_audio.py --list-devices
|
| 12 |
+
|
| 13 |
+
to get your audio devices. In the next step call
|
| 14 |
+
|
| 15 |
+
$ python real_time_dtln_audio.py -i in_device_idx -o out_device_idx
|
| 16 |
+
|
| 17 |
+
For .whl files of the tf light runtime go to:
|
| 18 |
+
https://www.tensorflow.org/lite/guide/python
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
Author: Nils L. Westhausen (nils.westhausen@uol.de)
|
| 23 |
+
Version: 01.07.2020
|
| 24 |
+
|
| 25 |
+
This code is licensed under the terms of the MIT-license.
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
import numpy as np
|
| 30 |
+
import sounddevice as sd
|
| 31 |
+
import tflite_runtime.interpreter as tflite
|
| 32 |
+
import argparse
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def int_or_str(text):
    """Parse *text* as an int when possible, otherwise return it unchanged.

    Used as an argparse ``type`` so audio devices may be given either as a
    numeric ID or as a name substring.
    """
    try:
        value = int(text)
    except ValueError:
        return text
    return value
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# ---- command line interface ------------------------------------------------
# A minimal first parser so "--list-devices" can exit before any heavy setup.
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument(
    '-l', '--list-devices', action='store_true',
    help='show list of audio devices and exit')
args, remaining = parser.parse_known_args()
if args.list_devices:
    print(sd.query_devices())
    parser.exit(0)
# The full parser reuses the minimal one as parent so -l stays documented.
parser = argparse.ArgumentParser(
    description=__doc__,
    formatter_class=argparse.RawDescriptionHelpFormatter,
    parents=[parser])
parser.add_argument(
    '-i', '--input-device', type=int_or_str,
    help='input device (numeric ID or substring)')
parser.add_argument(
    '-o', '--output-device', type=int_or_str,
    help='output device (numeric ID or substring)')
parser.add_argument('--latency', type=float, help='latency in seconds', default=0.2)
args = parser.parse_args(remaining)

# ---- fixed processing parameters (changing them requires retraining) -------
block_len_ms = 32
block_shift_ms = 8
fs_target = 16000

# ---- TF-Lite interpreters for the two DTLN stages --------------------------
interpreter_1 = tflite.Interpreter(model_path='./pretrained_model/model_1.tflite')
interpreter_1.allocate_tensors()
interpreter_2 = tflite.Interpreter(model_path='./pretrained_model/model_2.tflite')
interpreter_2.allocate_tensors()
# tensor metadata used to feed inputs and read outputs
input_details_1 = interpreter_1.get_input_details()
output_details_1 = interpreter_1.get_output_details()
input_details_2 = interpreter_2.get_input_details()
output_details_2 = interpreter_2.get_output_details()

# zero-initialised recurrent LSTM states, carried across callback calls
states_1 = np.zeros(input_details_1[1]['shape']).astype('float32')
states_2 = np.zeros(input_details_2[1]['shape']).astype('float32')

# samples per hop and per analysis window
block_shift = int(np.round(fs_target * (block_shift_ms / 1000)))
block_len = int(np.round(fs_target * (block_len_ms / 1000)))

# overlap-add ring buffers for input and output audio
in_buffer = np.zeros((block_len)).astype('float32')
out_buffer = np.zeros((block_len)).astype('float32')
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def callback(indata, outdata, frames, time, status):
    """sounddevice stream callback: denoise one hop of audio.

    Shifts the fresh samples into the input ring buffer, runs both DTLN
    stages (magnitude masking in the STFT domain, then time-domain
    refinement), overlap-adds the result and emits one hop of output.
    Interpreters, states and buffers live at module level.
    """
    global in_buffer, out_buffer, states_1, states_2
    if status:
        print(status)
    # slide the analysis window forward by one hop and append new samples
    in_buffer[:-block_shift] = in_buffer[block_shift:]
    in_buffer[-block_shift:] = np.squeeze(indata)
    # STFT of the current window
    spectrum = np.fft.rfft(in_buffer)
    magnitude = np.abs(spectrum)
    phase = np.angle(spectrum)
    # reshape magnitude to the (batch, time, bins) layout the model expects
    magnitude = np.reshape(magnitude, (1, 1, -1)).astype('float32')
    # stage 1: predict a magnitude mask and update the recurrent state
    interpreter_1.set_tensor(input_details_1[1]['index'], states_1)
    interpreter_1.set_tensor(input_details_1[0]['index'], magnitude)
    interpreter_1.invoke()
    mask = interpreter_1.get_tensor(output_details_1[0]['index'])
    states_1 = interpreter_1.get_tensor(output_details_1[1]['index'])
    # back to the time domain using the noisy phase
    masked_spectrum = magnitude * mask * np.exp(1j * phase)
    time_block = np.fft.irfft(masked_spectrum)
    time_block = np.reshape(time_block, (1, 1, -1)).astype('float32')
    # stage 2: time-domain enhancement with its own recurrent state
    interpreter_2.set_tensor(input_details_2[1]['index'], states_2)
    interpreter_2.set_tensor(input_details_2[0]['index'], time_block)
    interpreter_2.invoke()
    enhanced = interpreter_2.get_tensor(output_details_2[0]['index'])
    states_2 = interpreter_2.get_tensor(output_details_2[1]['index'])
    # overlap-add the enhanced block into the output ring buffer
    out_buffer[:-block_shift] = out_buffer[block_shift:]
    out_buffer[-block_shift:] = np.zeros((block_shift))
    out_buffer += np.squeeze(enhanced)
    # the first hop of the buffer is final — hand it to the soundcard
    outdata[:] = np.expand_dims(out_buffer[:block_shift], axis=-1)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
try:
    # full-duplex stream: one hop per callback at the target sample rate
    stream = sd.Stream(device=(args.input_device, args.output_device),
                       samplerate=fs_target, blocksize=block_shift,
                       dtype=np.float32, latency=args.latency,
                       channels=1, callback=callback)
    with stream:
        banner = '#' * 80
        print(banner)
        print('press Return to quit')
        print(banner)
        input()
except KeyboardInterrupt:
    parser.exit('')
except Exception as exc:
    parser.exit(type(exc).__name__ + ': ' + str(exc))
|
| 147 |
+
|
models/DTLN (yash-04)/real_time_processing.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Block-wise (real-time style) processing of an audio file with the DTLN
SavedModel. Input must be 16 kHz; window and hop sizes are fixed by the
trained model.

@author: nils
"""

import soundfile as sf
import numpy as np
import tensorflow as tf


##########################
# the values are fixed, if you need other values, you have to retrain.
# The sampling rate of 16k is also fix.
block_len = 512
block_shift = 128
# load the SavedModel and grab its default serving function
model = tf.saved_model.load('./pretrained_model/dtln_saved_model')
infer = model.signatures["serving_default"]
# load audio file at 16k fs (please change)
audio, fs = sf.read('path_to_your_favorite_audio.wav')
if fs != 16000:
    raise ValueError('This model only supports 16k sampling rate.')
# preallocate the output signal and the overlap-add buffers
out_file = np.zeros((len(audio)))
in_buffer = np.zeros((block_len))
out_buffer = np.zeros((block_len))
# number of full hops that fit into the signal
num_blocks = (audio.shape[0] - (block_len - block_shift)) // block_shift
for idx in range(num_blocks):
    # slide the window forward by one hop and append fresh samples
    in_buffer[:-block_shift] = in_buffer[block_shift:]
    in_buffer[-block_shift:] = audio[idx * block_shift:(idx * block_shift) + block_shift]
    # add a batch dimension of one and run the network on this block
    in_block = np.expand_dims(in_buffer, axis=0).astype('float32')
    out_block = infer(tf.constant(in_block))['conv1d_1']
    # overlap-add the enhanced block
    out_buffer[:-block_shift] = out_buffer[block_shift:]
    out_buffer[-block_shift:] = np.zeros((block_shift))
    out_buffer += np.squeeze(out_block)
    # the first hop of the buffer is final
    out_file[idx * block_shift:(idx * block_shift) + block_shift] = out_buffer[:block_shift]


# write the enhanced signal to disk
sf.write('out.wav', out_file, fs)

print('Processing finished.')
|
models/DTLN (yash-04)/real_time_processing_onnx.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This is an example how to implement real time processing of the DTLN ONNX
|
| 3 |
+
model in python.
|
| 4 |
+
|
| 5 |
+
Please change the name of the .wav file at line 49 before running the sript.
|
| 6 |
+
For the ONNX runtime call: $ pip install onnxruntime
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
Author: Nils L. Westhausen (nils.westhausen@uol.de)
|
| 11 |
+
Version: 03.07.2020
|
| 12 |
+
|
| 13 |
+
This code is licensed under the terms of the MIT-license.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import soundfile as sf
|
| 17 |
+
import numpy as np
|
| 18 |
+
import time
|
| 19 |
+
import onnxruntime
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
##########################
|
| 24 |
+
# the values are fixed, if you need other values, you have to retrain.
|
| 25 |
+
# The sampling rate of 16k is also fix.
|
| 26 |
+
block_len = 512
|
| 27 |
+
block_shift = 128
|
| 28 |
+
# load models
|
| 29 |
+
interpreter_1 = onnxruntime.InferenceSession('./model_1.onnx')
|
| 30 |
+
model_input_names_1 = [inp.name for inp in interpreter_1.get_inputs()]
|
| 31 |
+
# preallocate input
|
| 32 |
+
model_inputs_1 = {
|
| 33 |
+
inp.name: np.zeros(
|
| 34 |
+
[dim if isinstance(dim, int) else 1 for dim in inp.shape],
|
| 35 |
+
dtype=np.float32)
|
| 36 |
+
for inp in interpreter_1.get_inputs()}
|
| 37 |
+
# load models
|
| 38 |
+
interpreter_2 = onnxruntime.InferenceSession('./model_2.onnx')
|
| 39 |
+
model_input_names_2 = [inp.name for inp in interpreter_2.get_inputs()]
|
| 40 |
+
# preallocate input
|
| 41 |
+
model_inputs_2 = {
|
| 42 |
+
inp.name: np.zeros(
|
| 43 |
+
[dim if isinstance(dim, int) else 1 for dim in inp.shape],
|
| 44 |
+
dtype=np.float32)
|
| 45 |
+
for inp in interpreter_2.get_inputs()}
|
| 46 |
+
|
| 47 |
+
# load audio file
|
| 48 |
+
audio,fs = sf.read('path/to/your/favorite.wav')
|
| 49 |
+
# check for sampling rate
|
| 50 |
+
if fs != 16000:
|
| 51 |
+
raise ValueError('This model only supports 16k sampling rate.')
|
| 52 |
+
# preallocate output audio
|
| 53 |
+
out_file = np.zeros((len(audio)))
|
| 54 |
+
# create buffer
|
| 55 |
+
in_buffer = np.zeros((block_len)).astype('float32')
|
| 56 |
+
out_buffer = np.zeros((block_len)).astype('float32')
|
| 57 |
+
# calculate number of blocks
|
| 58 |
+
num_blocks = (audio.shape[0] - (block_len-block_shift)) // block_shift
|
| 59 |
+
# iterate over the number of blcoks
|
| 60 |
+
time_array = []
|
| 61 |
+
for idx in range(num_blocks):
|
| 62 |
+
start_time = time.time()
|
| 63 |
+
# shift values and write to buffer
|
| 64 |
+
in_buffer[:-block_shift] = in_buffer[block_shift:]
|
| 65 |
+
in_buffer[-block_shift:] = audio[idx*block_shift:(idx*block_shift)+block_shift]
|
| 66 |
+
# calculate fft of input block
|
| 67 |
+
in_block_fft = np.fft.rfft(in_buffer)
|
| 68 |
+
in_mag = np.abs(in_block_fft)
|
| 69 |
+
in_phase = np.angle(in_block_fft)
|
| 70 |
+
# reshape magnitude to input dimensions
|
| 71 |
+
in_mag = np.reshape(in_mag, (1,1,-1)).astype('float32')
|
| 72 |
+
# set block to input
|
| 73 |
+
model_inputs_1[model_input_names_1[0]] = in_mag
|
| 74 |
+
# run calculation
|
| 75 |
+
model_outputs_1 = interpreter_1.run(None, model_inputs_1)
|
| 76 |
+
# get the output of the first block
|
| 77 |
+
out_mask = model_outputs_1[0]
|
| 78 |
+
# set out states back to input
|
| 79 |
+
model_inputs_1[model_input_names_1[1]] = model_outputs_1[1]
|
| 80 |
+
# calculate the ifft
|
| 81 |
+
estimated_complex = in_mag * out_mask * np.exp(1j * in_phase)
|
| 82 |
+
estimated_block = np.fft.irfft(estimated_complex)
|
| 83 |
+
# reshape the time domain block
|
| 84 |
+
estimated_block = np.reshape(estimated_block, (1,1,-1)).astype('float32')
|
| 85 |
+
# set tensors to the second block
|
| 86 |
+
# interpreter_2.set_tensor(input_details_1[1]['index'], states_2)
|
| 87 |
+
model_inputs_2[model_input_names_2[0]] = estimated_block
|
| 88 |
+
# run calculation
|
| 89 |
+
model_outputs_2 = interpreter_2.run(None, model_inputs_2)
|
| 90 |
+
# get output
|
| 91 |
+
out_block = model_outputs_2[0]
|
| 92 |
+
# set out states back to input
|
| 93 |
+
model_inputs_2[model_input_names_2[1]] = model_outputs_2[1]
|
| 94 |
+
# shift values and write to buffer
|
| 95 |
+
out_buffer[:-block_shift] = out_buffer[block_shift:]
|
| 96 |
+
out_buffer[-block_shift:] = np.zeros((block_shift))
|
| 97 |
+
out_buffer += np.squeeze(out_block)
|
| 98 |
+
# write block to output file
|
| 99 |
+
out_file[idx*block_shift:(idx*block_shift)+block_shift] = out_buffer[:block_shift]
|
| 100 |
+
time_array.append(time.time()-start_time)
|
| 101 |
+
|
| 102 |
+
# write to .wav file
|
| 103 |
+
sf.write('out.wav', out_file, fs)
|
| 104 |
+
print('Processing Time [ms]:')
|
| 105 |
+
print(np.mean(np.stack(time_array))*1000)
|
| 106 |
+
print('Processing finished.')
|
models/DTLN (yash-04)/real_time_processing_tf_lite.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This is an example how to implement real time processing of the DTLN tf light
|
| 3 |
+
model in python.
|
| 4 |
+
|
| 5 |
+
Please change the name of the .wav file at line 43 before running the sript.
|
| 6 |
+
For .whl files of the tf light runtime go to:
|
| 7 |
+
https://www.tensorflow.org/lite/guide/python
|
| 8 |
+
|
| 9 |
+
Author: Nils L. Westhausen (nils.westhausen@uol.de)
|
| 10 |
+
Version: 30.06.2020
|
| 11 |
+
|
| 12 |
+
This code is licensed under the terms of the MIT-license.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import soundfile as sf
|
| 16 |
+
import numpy as np
|
| 17 |
+
import tflite_runtime.interpreter as tflite
|
| 18 |
+
import time
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
##########################
|
| 23 |
+
# the values are fixed, if you need other values, you have to retrain.
|
| 24 |
+
# The sampling rate of 16k is also fix.
|
| 25 |
+
block_len = 512
|
| 26 |
+
block_shift = 128
|
| 27 |
+
# load models
|
| 28 |
+
interpreter_1 = tflite.Interpreter(model_path='./pretrained_model/model_1.tflite')
|
| 29 |
+
interpreter_1.allocate_tensors()
|
| 30 |
+
interpreter_2 = tflite.Interpreter(model_path='./pretrained_model/model_2.tflite')
|
| 31 |
+
interpreter_2.allocate_tensors()
|
| 32 |
+
|
| 33 |
+
# Get input and output tensors.
|
| 34 |
+
input_details_1 = interpreter_1.get_input_details()
|
| 35 |
+
output_details_1 = interpreter_1.get_output_details()
|
| 36 |
+
|
| 37 |
+
input_details_2 = interpreter_2.get_input_details()
|
| 38 |
+
output_details_2 = interpreter_2.get_output_details()
|
| 39 |
+
# create states for the lstms
|
| 40 |
+
states_1 = np.zeros(input_details_1[1]['shape']).astype('float32')
|
| 41 |
+
states_2 = np.zeros(input_details_2[1]['shape']).astype('float32')
|
| 42 |
+
# load audio file at 16k fs (please change)
|
| 43 |
+
audio,fs = sf.read('path/to/your/favorite/.wav')
|
| 44 |
+
# check for sampling rate
|
| 45 |
+
if fs != 16000:
|
| 46 |
+
raise ValueError('This model only supports 16k sampling rate.')
|
| 47 |
+
# preallocate output audio
|
| 48 |
+
out_file = np.zeros((len(audio)))
|
| 49 |
+
# create buffer
|
| 50 |
+
in_buffer = np.zeros((block_len)).astype('float32')
|
| 51 |
+
out_buffer = np.zeros((block_len)).astype('float32')
|
| 52 |
+
# calculate number of blocks
|
| 53 |
+
num_blocks = (audio.shape[0] - (block_len-block_shift)) // block_shift
|
| 54 |
+
time_array = []
|
| 55 |
+
# iterate over the number of blcoks
|
| 56 |
+
for idx in range(num_blocks):
|
| 57 |
+
start_time = time.time()
|
| 58 |
+
# shift values and write to buffer
|
| 59 |
+
in_buffer[:-block_shift] = in_buffer[block_shift:]
|
| 60 |
+
in_buffer[-block_shift:] = audio[idx*block_shift:(idx*block_shift)+block_shift]
|
| 61 |
+
# calculate fft of input block
|
| 62 |
+
in_block_fft = np.fft.rfft(in_buffer)
|
| 63 |
+
in_mag = np.abs(in_block_fft)
|
| 64 |
+
in_phase = np.angle(in_block_fft)
|
| 65 |
+
# reshape magnitude to input dimensions
|
| 66 |
+
in_mag = np.reshape(in_mag, (1,1,-1)).astype('float32')
|
| 67 |
+
# set tensors to the first model
|
| 68 |
+
interpreter_1.set_tensor(input_details_1[1]['index'], states_1)
|
| 69 |
+
interpreter_1.set_tensor(input_details_1[0]['index'], in_mag)
|
| 70 |
+
# run calculation
|
| 71 |
+
interpreter_1.invoke()
|
| 72 |
+
# get the output of the first block
|
| 73 |
+
out_mask = interpreter_1.get_tensor(output_details_1[0]['index'])
|
| 74 |
+
states_1 = interpreter_1.get_tensor(output_details_1[1]['index'])
|
| 75 |
+
# calculate the ifft
|
| 76 |
+
estimated_complex = in_mag * out_mask * np.exp(1j * in_phase)
|
| 77 |
+
estimated_block = np.fft.irfft(estimated_complex)
|
| 78 |
+
# reshape the time domain block
|
| 79 |
+
estimated_block = np.reshape(estimated_block, (1,1,-1)).astype('float32')
|
| 80 |
+
# set tensors to the second block
|
| 81 |
+
interpreter_2.set_tensor(input_details_2[1]['index'], states_2)
|
| 82 |
+
interpreter_2.set_tensor(input_details_2[0]['index'], estimated_block)
|
| 83 |
+
# run calculation
|
| 84 |
+
interpreter_2.invoke()
|
| 85 |
+
# get output tensors
|
| 86 |
+
out_block = interpreter_2.get_tensor(output_details_2[0]['index'])
|
| 87 |
+
states_2 = interpreter_2.get_tensor(output_details_2[1]['index'])
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# shift values and write to buffer
|
| 91 |
+
out_buffer[:-block_shift] = out_buffer[block_shift:]
|
| 92 |
+
out_buffer[-block_shift:] = np.zeros((block_shift))
|
| 93 |
+
out_buffer += np.squeeze(out_block)
|
| 94 |
+
# write block to output file
|
| 95 |
+
out_file[idx*block_shift:(idx*block_shift)+block_shift] = out_buffer[:block_shift]
|
| 96 |
+
time_array.append(time.time()-start_time)
|
| 97 |
+
|
| 98 |
+
# write to .wav file
|
| 99 |
+
sf.write('out.wav', out_file, fs)
|
| 100 |
+
print('Processing Time [ms]:')
|
| 101 |
+
print(np.mean(np.stack(time_array))*1000)
|
| 102 |
+
print('Processing finished.')
|
models/DTLN (yash-04)/run_evaluation.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
Script to process a folder of .wav files with a trained DTLN model.
|
| 4 |
+
This script supports subfolders and names the processed files the same as the
|
| 5 |
+
original. The model expects 16kHz audio .wav files. Files with other
|
| 6 |
+
sampling rates will be resampled. Stereo files will be downmixed to mono.
|
| 7 |
+
|
| 8 |
+
The idea of this script is to use it for baseline or comparison purpose.
|
| 9 |
+
|
| 10 |
+
Example call:
|
| 11 |
+
$python run_evaluation.py -i /name/of/input/folder \
|
| 12 |
+
-o /name/of/output/folder \
|
| 13 |
+
-m /name/of/the/model.h5
|
| 14 |
+
|
| 15 |
+
Author: Nils L. Westhausen (nils.westhausen@uol.de)
|
| 16 |
+
Version: 13.05.2020
|
| 17 |
+
|
| 18 |
+
This code is licensed under the terms of the MIT-license.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import soundfile as sf
|
| 22 |
+
import librosa
|
| 23 |
+
import numpy as np
|
| 24 |
+
import os
|
| 25 |
+
import argparse
|
| 26 |
+
from DTLN_model import DTLN_model
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def process_file(model, audio_file_name, out_file_name):
    '''
    Read an audio file, process it with the network and write the enhanced
    audio to a .wav file.

    Parameters
    ----------
    model : Keras model
        Keras model, which accepts audio in the size (1,timesteps).
    audio_file_name : STRING
        Name and path of the input audio file.
    out_file_name : STRING
        Name and path of the target file.

    '''
    # librosa handles resampling to 16 kHz and downmixing to mono
    in_data, fs = librosa.core.load(audio_file_name, sr=16000, mono=True)
    len_orig = len(in_data)
    # pad the signal on both sides with 384 zeros (warm-up/cool-down margin)
    zero_pad = np.zeros(384)
    padded = np.concatenate((zero_pad, in_data, zero_pad), axis=0)
    # run the network on the whole padded utterance with a batch of one
    predicted = model.predict_on_batch(
        np.expand_dims(padded, axis=0).astype(np.float32))
    # drop the batch dimension and remove the padding again
    predicted_speech = np.squeeze(predicted)[384:384 + len_orig]
    # write the enhanced signal to the target destination
    sf.write(out_file_name, predicted_speech, fs)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def process_folder(model, folder_name, new_folder_name):
    '''
    Find all .wav files in "folder_name" and its subfolders, process each
    one with the model and write it to "new_folder_name". The structure of
    the original directory is preserved and the processed files keep the
    same name as the original file.

    Parameters
    ----------
    model : Keras model
        Keras model, which accepts audio in the size (1,timesteps).
    folder_name : STRING
        Input folder with .wav files.
    new_folder_name : STRING
        Target folder for the processed files.

    '''
    # collect (source dir, target dir, file name) triples for all .wav files
    jobs = []
    for root, dirs, files in os.walk(folder_name):
        # compute the mirrored output directory once per visited folder
        new_root = root.replace(folder_name, new_folder_name)
        for file in files:
            if file.endswith(".wav"):
                jobs.append((root, new_root, file))
                # exist_ok avoids the check-then-create race of the original
                os.makedirs(new_root, exist_ok=True)
    # process every collected file with the model
    for root, new_root, file in jobs:
        process_file(model, os.path.join(root, file),
                     os.path.join(new_root, file))
        print(file + ' processed successfully!')
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
if __name__ == '__main__':
    # argument parser for running directly from the command line
    parser = argparse.ArgumentParser(description='data evaluation')
    parser.add_argument('--in_folder', '-i',
                        help='folder with input files')
    parser.add_argument('--out_folder', '-o',
                        help='target folder for processed files')
    parser.add_argument('--model', '-m',
                        help='weights of the enhancement model in .h5 format')
    args = parser.parse_args()
    # models trained with STFT normalization carry "_norm_" in their name
    norm_stft = args.model.find('_norm_') != -1
    # build the model in default configuration and load the .h5 weights
    modelClass = DTLN_model()
    modelClass.build_DTLN_model(norm_stft=norm_stft)
    modelClass.model.load_weights(args.model)
    # process the folder
    process_folder(modelClass.model, args.in_folder, args.out_folder)
|
models/DTLN (yash-04)/run_training.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Script to train the DTLN model in default settings. The folders for noisy
and clean files are expected to contain the same number of files with
matching names. Training always saves the best weights to
"./models_'runName'/" together with a log of the training progress. To
change any parameters edit "DTLN_model.py" or set
"modelTrainer.parameter = XY" here. Training on a GPU is recommended; the
setup is optimized for the DNS-Challenge data set.

Please change the folder names before starting the training.

Example call:
    $python run_training.py

Author: Nils L. Westhausen (nils.westhausen@uol.de)
Version: 13.05.2020

This code is licensed under the terms of the MIT-license.
"""

from DTLN_model import DTLN_model
import os

# use the GPU with idx 0
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
# activate this for some reproducibility
os.environ['TF_DETERMINISTIC_OPS'] = '1'


# noisy/clean file pairs for training
path_to_train_mix = '/path/to/noisy/training/data/'
path_to_train_speech = '/path/to/clean/training/data/'
# noisy/clean file pairs for validation
path_to_val_mix = '/path/to/noisy/validation/data/'
path_to_val_speech = '/path/to/clean/validation/data/'

# name your training run
runName = 'DTLN_model'
# build the model, compile it and start training
modelTrainer = DTLN_model()
modelTrainer.build_DTLN_model()
modelTrainer.compile_model()
modelTrainer.train_model(runName, path_to_train_mix, path_to_train_speech,
                         path_to_val_mix, path_to_val_speech)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
|
models/DTLN (yash-04)/source.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
https://huggingface.co/yash-04/DTLN
|
models/DTLN (yash-04)/tflite_env.yml
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: tflite-env
|
| 2 |
+
channels:
|
| 3 |
+
- conda-forge
|
| 4 |
+
- defaults
|
| 5 |
+
dependencies:
|
| 6 |
+
- audioread=2.1.8=py37hc8dfbb8_2
|
| 7 |
+
- bzip2=1.0.8=h0b31af3_2
|
| 8 |
+
- ca-certificates=2020.6.20=hecda079_0
|
| 9 |
+
- certifi=2020.6.20=py37hc8dfbb8_0
|
| 10 |
+
- cffi=1.14.0=py37hc512035_1
|
| 11 |
+
- cycler=0.10.0=py_2
|
| 12 |
+
- decorator=4.4.2=py_0
|
| 13 |
+
- ffmpeg=4.2.3=hd0c0d6a_0
|
| 14 |
+
- freetype=2.10.2=h8da9a1a_0
|
| 15 |
+
- gettext=0.19.8.1=h1f1d5ed_1
|
| 16 |
+
- gmp=6.2.0=h4a8c4bd_2
|
| 17 |
+
- gnutls=3.6.13=hc269f14_0
|
| 18 |
+
- joblib=0.15.1=py_0
|
| 19 |
+
- kiwisolver=1.2.0=py37ha1cc60f_0
|
| 20 |
+
- lame=3.100=h1de35cc_1001
|
| 21 |
+
- libblas=3.8.0=17_openblas
|
| 22 |
+
- libcblas=3.8.0=17_openblas
|
| 23 |
+
- libcxx=10.0.0=1
|
| 24 |
+
- libedit=3.1.20191231=haf1e3a3_0
|
| 25 |
+
- libffi=3.3=h0a44026_1
|
| 26 |
+
- libflac=1.3.3=h4a8c4bd_0
|
| 27 |
+
- libgfortran=4.0.0=2
|
| 28 |
+
- libiconv=1.15=h0b31af3_1006
|
| 29 |
+
- liblapack=3.8.0=17_openblas
|
| 30 |
+
- libllvm8=8.0.1=h770b8ee_0
|
| 31 |
+
- libogg=1.3.2=h0b31af3_1002
|
| 32 |
+
- libopenblas=0.3.10=h3d69b6c_0
|
| 33 |
+
- libpng=1.6.37=hbbe82c9_1
|
| 34 |
+
- librosa=0.7.2=py_1
|
| 35 |
+
- libsndfile=1.0.28=h4a8c4bd_1000
|
| 36 |
+
- libvorbis=1.3.6=h4a8c4bd_2
|
| 37 |
+
- llvm-openmp=10.0.0=h28b9765_0
|
| 38 |
+
- llvmlite=0.31.0=py37hb548287_1
|
| 39 |
+
- matplotlib-base=3.2.2=py37hddda452_0
|
| 40 |
+
- ncurses=6.2=h0a44026_1
|
| 41 |
+
- nettle=3.4.1=h3efe00b_1002
|
| 42 |
+
- numba=0.48.0=py37h4f17bb1_0
|
| 43 |
+
- numpy=1.18.5=py37h7687784_0
|
| 44 |
+
- openh264=2.1.1=hd174df1_0
|
| 45 |
+
- openssl=1.1.1g=h0b31af3_0
|
| 46 |
+
- pip=20.1.1=py37_1
|
| 47 |
+
- portaudio=19.6.0=h647c56a_4
|
| 48 |
+
- pycparser=2.20=pyh9f0ad1d_2
|
| 49 |
+
- pyparsing=2.4.7=pyh9f0ad1d_0
|
| 50 |
+
- pysoundfile=0.10.2=py_1001
|
| 51 |
+
- python=3.7.7=hf48f09d_4
|
| 52 |
+
- python-dateutil=2.8.1=py_0
|
| 53 |
+
- python-sounddevice=0.3.15=pyh8c360ce_0
|
| 54 |
+
- python_abi=3.7=1_cp37m
|
| 55 |
+
- readline=8.0=h1de35cc_0
|
| 56 |
+
- resampy=0.2.2=py_0
|
| 57 |
+
- scikit-learn=0.23.1=py37hf5857e7_0
|
| 58 |
+
- scipy=1.5.0=py37hce1b9e5_0
|
| 59 |
+
- setuptools=47.3.1=py37_0
|
| 60 |
+
- six=1.15.0=pyh9f0ad1d_0
|
| 61 |
+
- sqlite=3.32.3=hffcf06c_0
|
| 62 |
+
- threadpoolctl=2.1.0=pyh5ca1d4c_0
|
| 63 |
+
- tk=8.6.10=hb0a8c7a_0
|
| 64 |
+
- tornado=6.0.4=py37h9bfed18_1
|
| 65 |
+
- wheel=0.34.2=py37_0
|
| 66 |
+
- x264=1!152.20180806=h1de35cc_0
|
| 67 |
+
- xz=5.2.5=h1de35cc_0
|
| 68 |
+
- zlib=1.2.11=h1de35cc_3
|
| 69 |
+
- pip:
|
| 70 |
+
- flatbuffers==1.12
|
| 71 |
+
- tflite-runtime==2.1.0.post1
|
| 72 |
+
prefix: /Applications/anaconda3/envs/tflite-env
|
| 73 |
+
|
models/DTLN (yash-04)/train_env.yml
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: train_env
|
| 2 |
+
channels:
|
| 3 |
+
- conda-forge
|
| 4 |
+
- defaults
|
| 5 |
+
dependencies:
|
| 6 |
+
- _libgcc_mutex=0.1=main
|
| 7 |
+
- _tflow_select=2.1.0=gpu
|
| 8 |
+
- absl-py=0.9.0=py37_0
|
| 9 |
+
- astunparse=1.6.3=py_0
|
| 10 |
+
- audioread=2.1.8=py37hc8dfbb8_2
|
| 11 |
+
- blas=1.0=mkl
|
| 12 |
+
- blinker=1.4=py37_0
|
| 13 |
+
- bzip2=1.0.8=h516909a_2
|
| 14 |
+
- c-ares=1.15.0=h7b6447c_1001
|
| 15 |
+
- ca-certificates=2020.4.5.2=hecda079_0
|
| 16 |
+
- cachetools=4.1.0=py_1
|
| 17 |
+
- certifi=2020.4.5.2=py37hc8dfbb8_0
|
| 18 |
+
- cffi=1.14.0=py37he30daa8_1
|
| 19 |
+
- chardet=3.0.4=py37_1003
|
| 20 |
+
- click=7.1.2=py_0
|
| 21 |
+
- cryptography=2.9.2=py37h1ba5d50_0
|
| 22 |
+
- cudatoolkit=10.1.243=h6bb024c_0
|
| 23 |
+
- cudnn=7.6.5=cuda10.1_0
|
| 24 |
+
- cupti=10.1.168=0
|
| 25 |
+
- cycler=0.10.0=py_2
|
| 26 |
+
- decorator=4.4.2=py_0
|
| 27 |
+
- ffmpeg=4.2.3=h167e202_0
|
| 28 |
+
- freetype=2.10.2=he06d7ca_0
|
| 29 |
+
- gast=0.3.3=py_0
|
| 30 |
+
- gettext=0.19.8.1=h5e8e0c9_1
|
| 31 |
+
- gmp=6.2.0=he1b5a44_2
|
| 32 |
+
- gnutls=3.6.13=h79a8f9a_0
|
| 33 |
+
- google-auth=1.14.1=py_0
|
| 34 |
+
- google-auth-oauthlib=0.4.1=py_2
|
| 35 |
+
- google-pasta=0.2.0=py_0
|
| 36 |
+
- grpcio=1.27.2=py37hf8bcb03_0
|
| 37 |
+
- h5py=2.10.0=py37h7918eee_0
|
| 38 |
+
- hdf5=1.10.4=hb1b8bf9_0
|
| 39 |
+
- icu=58.2=hf484d3e_1000
|
| 40 |
+
- idna=2.9=py_1
|
| 41 |
+
- intel-openmp=2020.1=217
|
| 42 |
+
- joblib=0.15.1=py_0
|
| 43 |
+
- keras-preprocessing=1.1.0=py_1
|
| 44 |
+
- kiwisolver=1.2.0=py37h99015e2_0
|
| 45 |
+
- lame=3.100=h14c3975_1001
|
| 46 |
+
- ld_impl_linux-64=2.33.1=h53a641e_7
|
| 47 |
+
- libedit=3.1.20181209=hc058e9b_0
|
| 48 |
+
- libffi=3.3=he6710b0_1
|
| 49 |
+
- libflac=1.3.3=he1b5a44_0
|
| 50 |
+
- libgcc-ng=9.1.0=hdf63c60_0
|
| 51 |
+
- libgfortran-ng=7.3.0=hdf63c60_0
|
| 52 |
+
- libiconv=1.15=h516909a_1006
|
| 53 |
+
- libllvm8=8.0.1=hc9558a2_0
|
| 54 |
+
- libogg=1.3.2=h516909a_1002
|
| 55 |
+
- libpng=1.6.37=hed695b0_1
|
| 56 |
+
- libprotobuf=3.12.3=hd408876_0
|
| 57 |
+
- librosa=0.7.2=py_1
|
| 58 |
+
- libsndfile=1.0.28=he1b5a44_1000
|
| 59 |
+
- libstdcxx-ng=9.1.0=hdf63c60_0
|
| 60 |
+
- libvorbis=1.3.6=he1b5a44_2
|
| 61 |
+
- llvmlite=0.31.0=py37h5202443_1
|
| 62 |
+
- markdown=3.1.1=py37_0
|
| 63 |
+
- matplotlib-base=3.1.3=py37hef1b27d_0
|
| 64 |
+
- mkl=2020.1=217
|
| 65 |
+
- mkl-service=2.3.0=py37he904b0f_0
|
| 66 |
+
- mkl_fft=1.0.15=py37ha843d7b_0
|
| 67 |
+
- mkl_random=1.1.1=py37h0573a6f_0
|
| 68 |
+
- ncurses=6.2=he6710b0_1
|
| 69 |
+
- nettle=3.4.1=h1bed415_1002
|
| 70 |
+
- numba=0.48.0=py37hb3f55d8_0
|
| 71 |
+
- numpy=1.18.1=py37h4f9e942_0
|
| 72 |
+
- numpy-base=1.18.1=py37hde5b4d6_1
|
| 73 |
+
- oauthlib=3.1.0=py_0
|
| 74 |
+
- openh264=2.1.1=h8b12597_0
|
| 75 |
+
- openssl=1.1.1g=h516909a_0
|
| 76 |
+
- opt_einsum=3.1.0=py_0
|
| 77 |
+
- pip=20.0.2=py37_3
|
| 78 |
+
- protobuf=3.12.3=py37he6710b0_0
|
| 79 |
+
- pyasn1=0.4.8=py_0
|
| 80 |
+
- pyasn1-modules=0.2.7=py_0
|
| 81 |
+
- pycparser=2.20=py_0
|
| 82 |
+
- pyjwt=1.7.1=py37_0
|
| 83 |
+
- pyopenssl=19.1.0=py37_0
|
| 84 |
+
- pyparsing=2.4.7=pyh9f0ad1d_0
|
| 85 |
+
- pysocks=1.7.1=py37_0
|
| 86 |
+
- pysoundfile=0.10.2=py_1001
|
| 87 |
+
- python=3.7.7=hcff3b4d_5
|
| 88 |
+
- python-dateutil=2.8.1=py_0
|
| 89 |
+
- python_abi=3.7=1_cp37m
|
| 90 |
+
- readline=8.0=h7b6447c_0
|
| 91 |
+
- requests=2.23.0=py37_0
|
| 92 |
+
- requests-oauthlib=1.3.0=py_0
|
| 93 |
+
- resampy=0.2.2=py_0
|
| 94 |
+
- rsa=4.0=py_0
|
| 95 |
+
- scikit-learn=0.22.1=py37hd81dba3_0
|
| 96 |
+
- scipy=1.4.1=py37h0b6359f_0
|
| 97 |
+
- setuptools=47.1.1=py37_0
|
| 98 |
+
- six=1.15.0=py_0
|
| 99 |
+
- sqlite=3.31.1=h62c20be_1
|
| 100 |
+
- tensorboard=2.2.1=pyh532a8cf_0
|
| 101 |
+
- tensorboard-plugin-wit=1.6.0=py_0
|
| 102 |
+
- tensorflow=2.2.0=gpu_py37h1a511ff_0
|
| 103 |
+
- tensorflow-base=2.2.0=gpu_py37h8a81be8_0
|
| 104 |
+
- tensorflow-estimator=2.2.0=pyh208ff02_0
|
| 105 |
+
- tensorflow-gpu=2.2.0=h0d30ee6_0
|
| 106 |
+
- termcolor=1.1.0=py37_1
|
| 107 |
+
- tk=8.6.8=hbc83047_0
|
| 108 |
+
- tornado=6.0.4=py37h8f50634_1
|
| 109 |
+
- urllib3=1.25.8=py37_0
|
| 110 |
+
- werkzeug=1.0.1=py_0
|
| 111 |
+
- wheel=0.34.2=py37_0
|
| 112 |
+
- wrapt=1.12.1=py37h7b6447c_1
|
| 113 |
+
- x264=1!152.20180806=h14c3975_0
|
| 114 |
+
- xz=5.2.5=h7b6447c_0
|
| 115 |
+
- zlib=1.2.11=h7b6447c_3
|
| 116 |
+
- pip:
|
| 117 |
+
- lxml==4.5.1
|
| 118 |
+
- wavinfo==1.5
|
| 119 |
+
prefix: /home/nils/anaconda3/envs/tfenv
|
| 120 |
+
|