Spaces:
Running
Running
update change 2
Browse files- .gitignore +1 -0
- app.py +63 -33
- config/config.yaml +23 -0
- datahandler.py +216 -0
- decode.py +88 -0
- model/__pycache__/cnns.cpython-37.pyc +0 -0
- model/__pycache__/cnns.cpython-38.pyc +0 -0
- model/__pycache__/norm.cpython-37.pyc +0 -0
- model/__pycache__/norm.cpython-38.pyc +0 -0
- model/__pycache__/spex_plus.cpython-37.pyc +0 -0
- model/__pycache__/spex_plus.cpython-38.pyc +0 -0
- {nnet → model}/cnns.py +116 -15
- {nnet → model}/norm.py +39 -0
- {nnet → model}/spex_plus.py +138 -165
- nnet/ResNet34.py +0 -213
- nnet/__init__.py +0 -0
- nnet/pooling.py +0 -100
- nnet/speaker_encoder.py +0 -47
- noises/00840.wav +0 -0
- noises/022928.wav +0 -0
- noises/04338.wav +0 -0
- noises/046324.wav +0 -0
- noises/093004.wav +0 -0
- noises/11129.wav +0 -0
- noises/133254.wav +0 -0
- noises/30100.wav +0 -0
- noises/30135.wav +0 -0
- noises/30437.wav +0 -0
- noises/30603.wav +0 -0
- requirements.txt +6 -1
- temp_extracted.wav +0 -0
- test_mix.wav +0 -0
- test_output_mixture.wav +0 -0
- utils/__init__.py +0 -0
- utils/audio.py +0 -124
- utils/dataset copy.py +0 -284
- utils/dataset.py +0 -402
- utils/load_obj.py +0 -18
- utils/logger.py +0 -22
- utils/sisdr.py +0 -23
- utils/timer.py +0 -17
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
ckpt/
|
app.py
CHANGED
|
@@ -1,55 +1,85 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
-
|
|
|
|
|
|
|
| 3 |
|
| 4 |
-
#
|
| 5 |
-
|
|
|
|
| 6 |
|
| 7 |
-
# def convert_audio_to_wav(file_path):
|
| 8 |
-
# """Convert any supported format (mp3, etc.) to wav using librosa"""
|
| 9 |
-
# output_path = "temp_input.wav"
|
| 10 |
-
# audio, sr = librosa.load(file_path, sr=None) # 加载音频文件
|
| 11 |
-
# librosa.output.write_wav(output_path, audio, sr) # 转换并保存为 WAV 格式
|
| 12 |
-
# return output_path
|
| 13 |
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
"""
|
| 16 |
-
|
| 17 |
-
|
| 18 |
"""
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
print(f"
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
with gr.Blocks() as demo:
|
| 29 |
gr.Markdown("## Target Speaker Extraction Demo")
|
| 30 |
gr.Markdown(
|
| 31 |
-
"This demo
|
| 32 |
-
"with or without noises and reverberations."
|
| 33 |
)
|
| 34 |
-
|
| 35 |
-
# input
|
| 36 |
with gr.Row():
|
| 37 |
-
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
-
# output
|
| 41 |
with gr.Row():
|
|
|
|
| 42 |
noisy_audio_output = gr.Audio(label="Noisy Audio (Processed input audio)", type="filepath")
|
| 43 |
extracted_audio_output = gr.Audio(label="Extracted target speaker audio", type="filepath")
|
| 44 |
|
| 45 |
-
#
|
| 46 |
convert_button = gr.Button("Extract")
|
| 47 |
-
|
| 48 |
-
# event
|
| 49 |
convert_button.click(
|
| 50 |
fn=gradio_TSE,
|
| 51 |
-
inputs=[input_audio, enroll_audio],
|
| 52 |
-
outputs=[noisy_audio_output, extracted_audio_output]
|
| 53 |
)
|
| 54 |
|
| 55 |
if __name__ == "__main__":
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
import os
|
| 3 |
+
import soundfile as sf
|
| 4 |
+
import numpy as np
|
| 5 |
|
| 6 |
+
# 这是你现有的推理管线
|
| 7 |
+
from decode import InferencePipeline
|
| 8 |
+
from datahandler import AudioMixer, fix_audio_format
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
+
|
| 12 |
+
#####################################
|
| 13 |
+
# 这是你的推理 pipeline
|
| 14 |
+
#####################################
|
| 15 |
+
inter = InferencePipeline()
|
| 16 |
+
datamix = AudioMixer()
|
| 17 |
+
|
| 18 |
+
#####################################
|
| 19 |
+
# 这是供 Gradio 点击时调用的函数
|
| 20 |
+
#####################################
|
| 21 |
+
def gradio_TSE(input_audio_path, enroll_audio_path, audio_type):
|
| 22 |
"""
|
| 23 |
+
如果 audio_type 是 "clean",就调用 data_handler(此处是 produce_mixture_from_clean)
|
| 24 |
+
把输入先变成 mix;如果是 "mix",则直接使用原始文件。
|
| 25 |
"""
|
| 26 |
+
print(f"User uploaded audio path: {input_audio_path}")
|
| 27 |
+
print(f"User enroll audio path: {enroll_audio_path}")
|
| 28 |
+
print(f"User chose audio_type: {audio_type}")
|
| 29 |
+
|
| 30 |
+
if audio_type == "clean":
|
| 31 |
+
# 先把 clean 转成 mix
|
| 32 |
+
mix_path = datamix.produce_mixture_from_clean(input_audio_path)
|
| 33 |
+
print(f"Converted clean -> mix: {mix_path}")
|
| 34 |
+
else:
|
| 35 |
+
# 如果是已经是混合音频,直接用它
|
| 36 |
+
mix_path = input_audio_path
|
| 37 |
+
|
| 38 |
+
input_wav = fix_audio_format(mix_path)
|
| 39 |
+
mix_wav = "mix.wav"
|
| 40 |
+
sf.write(mix_wav, input_wav, 16000)
|
| 41 |
+
enroll_wav = fix_audio_format(enroll_audio_path)
|
| 42 |
+
eol_wav = "eol.wav"
|
| 43 |
+
sf.write(eol_wav, enroll_wav, 16000)
|
| 44 |
+
|
| 45 |
+
est_path = inter.computer(mix_wav, eol_wav)
|
| 46 |
+
# 接下来走你的推理流程
|
| 47 |
+
return mix_path,est_path
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
#####################################
|
| 51 |
+
# 搭建 Gradio 界面
|
| 52 |
+
#####################################
|
| 53 |
with gr.Blocks() as demo:
|
| 54 |
gr.Markdown("## Target Speaker Extraction Demo")
|
| 55 |
gr.Markdown(
|
| 56 |
+
"This demo can handle either clean audio (which we'll turn into a mix) or directly a mix audio."
|
|
|
|
| 57 |
)
|
| 58 |
+
|
|
|
|
| 59 |
with gr.Row():
|
| 60 |
+
# 上传或录制的“待处理”音频
|
| 61 |
+
input_audio = gr.Audio(label="Upload/record your audio", type="filepath")
|
| 62 |
+
# 让用户手动指定音频类型
|
| 63 |
+
audio_type = gr.Radio(
|
| 64 |
+
choices=["clean", "mix"],
|
| 65 |
+
value="clean",
|
| 66 |
+
label="Input audio type?"
|
| 67 |
+
)
|
| 68 |
+
with gr.Row():
|
| 69 |
+
# enroll 音频
|
| 70 |
+
enroll_audio = gr.Audio(label="Upload your enroll audio", type="filepath")
|
| 71 |
|
|
|
|
| 72 |
with gr.Row():
|
| 73 |
+
# 输出:处理后的 noisy 和 提取的目标说话人
|
| 74 |
noisy_audio_output = gr.Audio(label="Noisy Audio (Processed input audio)", type="filepath")
|
| 75 |
extracted_audio_output = gr.Audio(label="Extracted target speaker audio", type="filepath")
|
| 76 |
|
| 77 |
+
# 点击按钮触发
|
| 78 |
convert_button = gr.Button("Extract")
|
|
|
|
|
|
|
| 79 |
convert_button.click(
|
| 80 |
fn=gradio_TSE,
|
| 81 |
+
inputs=[input_audio, enroll_audio, audio_type],
|
| 82 |
+
outputs=[noisy_audio_output, extracted_audio_output]
|
| 83 |
)
|
| 84 |
|
| 85 |
if __name__ == "__main__":
|
config/config.yaml
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
model:
|
| 3 |
+
_target_: model.spex_plus.SpEx_Plus # str, model class name
|
| 4 |
+
L1: 40
|
| 5 |
+
L2: 160
|
| 6 |
+
L3: 320
|
| 7 |
+
N: 256
|
| 8 |
+
B: 8
|
| 9 |
+
O: 256
|
| 10 |
+
P: 512
|
| 11 |
+
Q: 3
|
| 12 |
+
num_spks: 1410 # with speed perturbation 470 -> 1410
|
| 13 |
+
spk_embed_dim: 256
|
| 14 |
+
causal: false
|
| 15 |
+
is_innorm: true
|
| 16 |
+
fusion_type: 'cat' #cat mul film att
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
test:
|
| 20 |
+
checkpoint: "./ckpt/v2.0.pt.tar"
|
| 21 |
+
gpu: -1
|
| 22 |
+
sample_rate: 16000
|
| 23 |
+
|
datahandler.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
import warnings
|
| 4 |
+
import numpy as np
|
| 5 |
+
import soundfile as sf
|
| 6 |
+
import pyloudnorm
|
| 7 |
+
import glob
|
| 8 |
+
import librosa
|
| 9 |
+
|
| 10 |
+
def fix_audio_format(audio_path, out_sr=16000):
|
| 11 |
+
"""
|
| 12 |
+
将音频读进来(自动识别格式)并强制重采样到 out_sr,转换为单声道。
|
| 13 |
+
最终返回:
|
| 14 |
+
- data (numpy array): 处理后(单声道、out_sr)的音频数据
|
| 15 |
+
- sr (int): 处理后的采样率 (默认 16k)
|
| 16 |
+
不写入临时文件,只做内存操作。
|
| 17 |
+
"""
|
| 18 |
+
# librosa.load 会自动解析不同格式的音频(wav/mp3/flac等)
|
| 19 |
+
# 并将其重采样到 out_sr, 同时 mono=True 意味着转换为单声道
|
| 20 |
+
data, sr = librosa.load(audio_path, sr=out_sr, mono=True)
|
| 21 |
+
|
| 22 |
+
return data
|
| 23 |
+
|
| 24 |
+
class AudioMixer(object):
|
| 25 |
+
def __init__(
|
| 26 |
+
self,
|
| 27 |
+
sample_rate=16000,
|
| 28 |
+
mean_snr=-7,
|
| 29 |
+
var_snr=25,
|
| 30 |
+
mean_loudness=-24,
|
| 31 |
+
var_loudness=20
|
| 32 |
+
):
|
| 33 |
+
"""
|
| 34 |
+
初始化一些参数、随机种子和响度计算工具等。
|
| 35 |
+
"""
|
| 36 |
+
self.sample_rate = sample_rate
|
| 37 |
+
self.mean_snr = mean_snr
|
| 38 |
+
self.var_snr = var_snr
|
| 39 |
+
self.MEAN_LOUNDNESS = mean_loudness
|
| 40 |
+
self.VAR_LOUNDNESS = var_loudness
|
| 41 |
+
|
| 42 |
+
self.EPS = 1e-10
|
| 43 |
+
self.MAX_AMP = 0.9
|
| 44 |
+
|
| 45 |
+
# pyloudnorm 的 Meter,用于计算音频响度
|
| 46 |
+
self.meter = pyloudnorm.Meter(self.sample_rate)
|
| 47 |
+
|
| 48 |
+
# # 也可固定随机种子,保证每次混合一致(如果想要可复现)
|
| 49 |
+
# self.seed = 1453
|
| 50 |
+
# random.seed(self.seed)
|
| 51 |
+
# np.random.seed(self.seed)
|
| 52 |
+
|
| 53 |
+
def read_wav(self, wav_path):
|
| 54 |
+
"""
|
| 55 |
+
读取音频文件并返回 wave 数据和采样率
|
| 56 |
+
"""
|
| 57 |
+
data, sr = sf.read(wav_path, dtype='float32')
|
| 58 |
+
# 如果读到的是多通道,可只取其中一个通道
|
| 59 |
+
if data.ndim > 1:
|
| 60 |
+
data = data[:, 0]
|
| 61 |
+
return data, sr
|
| 62 |
+
|
| 63 |
+
def normalize(self, signal, is_noise=False):
|
| 64 |
+
"""
|
| 65 |
+
对输入的 signal 做响度归一化,并确保不会过载失真。
|
| 66 |
+
"""
|
| 67 |
+
c_loudness = self.meter.integrated_loudness(signal)
|
| 68 |
+
if is_noise:
|
| 69 |
+
# 噪声的目标响度可以偏高一些或随便设置
|
| 70 |
+
target_loudness = np.random.normal(self.MEAN_LOUNDNESS + 4, self.VAR_LOUNDNESS**0.5)
|
| 71 |
+
else:
|
| 72 |
+
# mix 或者语音的目标响度
|
| 73 |
+
target_loudness = np.random.normal(self.MEAN_LOUNDNESS, self.VAR_LOUNDNESS**0.5)
|
| 74 |
+
|
| 75 |
+
with warnings.catch_warnings():
|
| 76 |
+
warnings.filterwarnings("error", category=RuntimeWarning)
|
| 77 |
+
signal = pyloudnorm.normalize.loudness(signal, c_loudness, target_loudness)
|
| 78 |
+
|
| 79 |
+
# # 再检查是否会 clipping
|
| 80 |
+
# peak = np.max(np.abs(signal))
|
| 81 |
+
# if peak >= 1.0:
|
| 82 |
+
# signal = signal * self.MAX_AMP / peak
|
| 83 |
+
|
| 84 |
+
return signal
|
| 85 |
+
|
| 86 |
+
def snr_norm(self, signal, noise, is_noise=True):
|
| 87 |
+
"""
|
| 88 |
+
根据预设的 mean_snr、var_snr 来随机决定一个目标 SNR,然后
|
| 89 |
+
以此对 noise 做缩放,得到与 signal 相匹配的噪声幅度。
|
| 90 |
+
"""
|
| 91 |
+
if is_noise:
|
| 92 |
+
desired_snr = np.random.normal(self.mean_snr, self.var_snr**0.5)
|
| 93 |
+
else:
|
| 94 |
+
# 如果你还有别的需求,比如想做正 SNR 范围,可以改这里
|
| 95 |
+
desired_snr = np.random.uniform(2, 10)
|
| 96 |
+
|
| 97 |
+
current_snr = 10 * np.log10(
|
| 98 |
+
np.mean(signal ** 2) / (np.mean(noise ** 2) + self.EPS) + self.EPS
|
| 99 |
+
)
|
| 100 |
+
scale_factor = 10 ** ((current_snr - desired_snr) / 20)
|
| 101 |
+
|
| 102 |
+
scaled_noise = noise * scale_factor
|
| 103 |
+
|
| 104 |
+
# # 防止噪声自身 clipping
|
| 105 |
+
# peak = np.max(np.abs(scaled_noise))
|
| 106 |
+
# if peak >= 1.0:
|
| 107 |
+
# scaled_noise = scaled_noise * self.MAX_AMP / peak
|
| 108 |
+
|
| 109 |
+
return scaled_noise
|
| 110 |
+
|
| 111 |
+
def _mix(self, sources_list):
|
| 112 |
+
"""
|
| 113 |
+
将多路音频进行叠加,防止溢出。
|
| 114 |
+
"""
|
| 115 |
+
# 假设 sources_list[0] 是 mix 音频,sources_list[1] 是已拼好长度的 noise
|
| 116 |
+
mix_length = len(sources_list[0])
|
| 117 |
+
mixture = np.zeros(mix_length, dtype=np.float32)
|
| 118 |
+
for s in sources_list:
|
| 119 |
+
mixture += s[:mix_length] # 仅叠加到 mix 的长度
|
| 120 |
+
|
| 121 |
+
# 再做一次峰值校正,避免溢出
|
| 122 |
+
peak = np.max(np.abs(mixture))
|
| 123 |
+
if peak >= 1.0:
|
| 124 |
+
mixture = mixture * self.MAX_AMP / peak
|
| 125 |
+
|
| 126 |
+
return mixture
|
| 127 |
+
|
| 128 |
+
def _prepare_noise_for_mix(self, noise_files, mix_length):
|
| 129 |
+
"""
|
| 130 |
+
传入一组 noise 文件路径,先对它们打乱,再依次读取、拼接。
|
| 131 |
+
如果总长度还不够覆盖 mix_length,可以再次拼接自己(循环)。
|
| 132 |
+
|
| 133 |
+
- noise_files: 存储多个噪声文件路径的列表
|
| 134 |
+
- mix_length: 需要的总长度(采样点数)
|
| 135 |
+
|
| 136 |
+
返回: 拼接后的 noise 波形
|
| 137 |
+
"""
|
| 138 |
+
# 先随机打乱
|
| 139 |
+
random.shuffle(noise_files)
|
| 140 |
+
|
| 141 |
+
# 依次读取并拼接
|
| 142 |
+
noise_all = []
|
| 143 |
+
total_len = 0
|
| 144 |
+
|
| 145 |
+
# ���一次先拼完所有 noise 文件,如果还不够,就重复拼接
|
| 146 |
+
while total_len < mix_length:
|
| 147 |
+
for nf in noise_files:
|
| 148 |
+
noise_data, _ = self.read_wav(nf)
|
| 149 |
+
|
| 150 |
+
# 可选:对每条 noise 做一次 normalize,提升多样性
|
| 151 |
+
# (或者只在外部做一次统一的 normalize)
|
| 152 |
+
#noise_data = self.normalize(noise_data, is_noise=True)
|
| 153 |
+
|
| 154 |
+
noise_all.append(noise_data)
|
| 155 |
+
total_len += len(noise_data)
|
| 156 |
+
|
| 157 |
+
if total_len >= mix_length:
|
| 158 |
+
break
|
| 159 |
+
|
| 160 |
+
# 如果已经拼完一轮,可能还不够,就继续 while 循环再拼一轮
|
| 161 |
+
|
| 162 |
+
# 拼接后截断到 mix_length
|
| 163 |
+
concatenated_noise = np.concatenate(noise_all)[:mix_length]
|
| 164 |
+
return concatenated_noise
|
| 165 |
+
|
| 166 |
+
def mix_with_noise_folder(self, mix_wave,sr_mix noise_folder):
|
| 167 |
+
"""
|
| 168 |
+
读取一条 mix 文件和一个 noise 文件夹,做如下处理:
|
| 169 |
+
1. 读取 mix wave,并做响度归一化
|
| 170 |
+
2. 根据 mix 的长度,在 noise 文件夹中随机打乱全部 wav,依次拼接满足同长度
|
| 171 |
+
3. 对最终拼好的 noise 做 snr_norm
|
| 172 |
+
4. 叠加输出
|
| 173 |
+
"""
|
| 174 |
+
# 1. 读取 mix
|
| 175 |
+
# mix_wave, sr_mix = self.read_wav(mix_path)
|
| 176 |
+
|
| 177 |
+
# 如果文件夹下找不到任何 noise 文件,就直接返回原音频
|
| 178 |
+
noise_files = sorted(glob.glob(os.path.join(noise_folder, "*.wav")))
|
| 179 |
+
if not noise_files:
|
| 180 |
+
raise RuntimeError(f"噪声文件夹 {noise_folder} 内未发现 .wav 文件")
|
| 181 |
+
|
| 182 |
+
mix_wave = self.normalize(mix_wave, is_noise=False)
|
| 183 |
+
mix_length = len(mix_wave)
|
| 184 |
+
|
| 185 |
+
# 2. 先把 noise 文件拼接到 match mix_length
|
| 186 |
+
# (会将 noise_files 打乱后依次读、拼接)
|
| 187 |
+
noise_ready = self._prepare_noise_for_mix(noise_files, mix_length)
|
| 188 |
+
|
| 189 |
+
# 3. SNR 调整
|
| 190 |
+
noise_ready = self.snr_norm(mix_wave, noise_ready, is_noise=True)
|
| 191 |
+
|
| 192 |
+
# 4. 叠加
|
| 193 |
+
mixture = self._mix([mix_wave, noise_ready])
|
| 194 |
+
|
| 195 |
+
out_noisy = "temp_noisy.wav" # 可以理解为把输入的混合音频直接另存为
|
| 196 |
+
|
| 197 |
+
# 返回混合后的音频以及采样率
|
| 198 |
+
sf.write(out_noisy, mixture, sr_mix)
|
| 199 |
+
|
| 200 |
+
return out_noisy
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
if __name__ == "__main__":
|
| 204 |
+
# 假设你有一个 mix.wav 以及一个 noise 文件夹(含若干个 .wav 噪声文件)
|
| 205 |
+
mix_path_test = "test_mix.wav"
|
| 206 |
+
mix_wave, sr_mix = self.read_wav(mix_path_test)
|
| 207 |
+
noise_folder_test = "noises/" # 比如里面有 10 条 noise*.wav
|
| 208 |
+
|
| 209 |
+
mixer = AudioMixer()
|
| 210 |
+
|
| 211 |
+
# 执行混合
|
| 212 |
+
mixed_wav, sr = mixer.mix_with_noise_folder(mix_wave, sr_mix, noise_folder_test)
|
| 213 |
+
|
| 214 |
+
# 这里你可以选择把结果写回本地文件,或直接返回 numpy 数组做后续处理
|
| 215 |
+
sf.write("test_output_mixture.wav", mixed_wav, sr)
|
| 216 |
+
print("混合完成,已输出到 test_output_mixture.wav")
|
decode.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import logging
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch as th
|
| 7 |
+
import soundfile as sf
|
| 8 |
+
import hydra
|
| 9 |
+
from omegaconf import OmegaConf
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# ================ 网络推理类 ================
|
| 14 |
+
class NnetComputer(object):
|
| 15 |
+
def __init__(self, cpt_dir, gpuid, nnet_conf):
|
| 16 |
+
self.device = th.device(f"cuda:{gpuid}") if gpuid >= 0 else th.device("cpu")
|
| 17 |
+
nnet = self._load_nnet(cpt_dir, nnet_conf)
|
| 18 |
+
self.nnet = nnet.to(self.device) if gpuid >= 0 else nnet
|
| 19 |
+
self.nnet.eval()
|
| 20 |
+
|
| 21 |
+
def _load_nnet(self, cpt_dir, model):
|
| 22 |
+
cpt = th.load(cpt_dir, map_location="cpu")
|
| 23 |
+
model.load_state_dict(cpt["model_state_dict"])
|
| 24 |
+
|
| 25 |
+
return model
|
| 26 |
+
|
| 27 |
+
def compute(self, samps, aux_samps, aux_samps_len):
|
| 28 |
+
with th.no_grad():
|
| 29 |
+
raw = th.tensor(samps, dtype=th.float32, device=self.device)
|
| 30 |
+
aux = th.tensor(aux_samps, dtype=th.float32, device=self.device)
|
| 31 |
+
aux_len = th.tensor(aux_samps_len, dtype=th.float32, device=self.device)
|
| 32 |
+
aux = aux.unsqueeze(0)
|
| 33 |
+
print("raw",raw.shape)
|
| 34 |
+
print("aux",aux.shape)
|
| 35 |
+
sps, sps2, sps3, spk_pred = self.nnet(raw, aux, aux_len)
|
| 36 |
+
sp_samps = np.squeeze(sps.detach().cpu().numpy())
|
| 37 |
+
return sp_samps
|
| 38 |
+
|
| 39 |
+
class InferencePipeline:
|
| 40 |
+
"""
|
| 41 |
+
外部只需传入 config,即可完成:
|
| 42 |
+
1) 模型实例化 (含 hydra.instantiate 逻辑)
|
| 43 |
+
2) 加载 checkpoint
|
| 44 |
+
3) 推理
|
| 45 |
+
"""
|
| 46 |
+
def __init__(self, config):
|
| 47 |
+
"""
|
| 48 |
+
在构造时就把所有初始化做好,包括:
|
| 49 |
+
- hydra.instantiate(config.model) -> 得到一个 nn.Module
|
| 50 |
+
- 用 NnetComputer(...) 封装
|
| 51 |
+
"""
|
| 52 |
+
# 如果 config.model 里含有 _target_ 字段,可以用 hydra.instantiate
|
| 53 |
+
# 注意: hydra.instantiate 需要在这里显式地导入 hydra.utils
|
| 54 |
+
|
| 55 |
+
# 1. 根据 config.model 构建模型
|
| 56 |
+
model_inst = hydra.utils.instantiate(config.model)
|
| 57 |
+
|
| 58 |
+
self.computer_ = NnetComputer(config.test.checkpoint,config.test.gpu, model_inst)
|
| 59 |
+
|
| 60 |
+
def run_inference(self, input_audio_path: str, enroll_audio_path: str) -> str:
|
| 61 |
+
"""
|
| 62 |
+
给定混合音频 + enroll 音频,执行推理并返回输出文件路径。
|
| 63 |
+
"""
|
| 64 |
+
# 1. 读取音频
|
| 65 |
+
mix_samps, sr = sf.read(input_audio_path)
|
| 66 |
+
aux_samps, sr2 = sf.read(enroll_audio_path)
|
| 67 |
+
|
| 68 |
+
# 2. 调用底层 compute
|
| 69 |
+
samps = self.computer_.compute(mix_samps, aux_samps, len(aux_samps))
|
| 70 |
+
norm = np.linalg.norm(mix_samps, np.inf)
|
| 71 |
+
samps = samps[:mix_samps.size]
|
| 72 |
+
samps = samps * norm / np.max(np.abs(samps))
|
| 73 |
+
|
| 74 |
+
# 3. 写到临时文件
|
| 75 |
+
out_wav = "temp_extracted.wav"
|
| 76 |
+
sf.write(out_wav, samps, sr)
|
| 77 |
+
return out_wav
|
| 78 |
+
|
| 79 |
+
if __name__ == "__main__":
|
| 80 |
+
cfg = OmegaConf.load("config/config.yaml")
|
| 81 |
+
pipeline = InferencePipeline(cfg)
|
| 82 |
+
|
| 83 |
+
mix_path = "test_output_mixture.wav"
|
| 84 |
+
enroll_path = "test_mix.wav"
|
| 85 |
+
out_wav = pipeline.run_inference(mix_path, enroll_path)
|
| 86 |
+
print("Done:", out_wav)
|
| 87 |
+
|
| 88 |
+
|
model/__pycache__/cnns.cpython-37.pyc
ADDED
|
Binary file (8.14 kB). View file
|
|
|
model/__pycache__/cnns.cpython-38.pyc
ADDED
|
Binary file (8.01 kB). View file
|
|
|
model/__pycache__/norm.cpython-37.pyc
ADDED
|
Binary file (3.88 kB). View file
|
|
|
model/__pycache__/norm.cpython-38.pyc
ADDED
|
Binary file (3.77 kB). View file
|
|
|
model/__pycache__/spex_plus.cpython-37.pyc
ADDED
|
Binary file (5.9 kB). View file
|
|
|
model/__pycache__/spex_plus.cpython-38.pyc
ADDED
|
Binary file (5.98 kB). View file
|
|
|
{nnet → model}/cnns.py
RENAMED
|
@@ -2,8 +2,9 @@
|
|
| 2 |
|
| 3 |
import torch as th
|
| 4 |
import torch.nn as nn
|
|
|
|
| 5 |
|
| 6 |
-
from .norm import ChannelwiseLayerNorm, GlobalLayerNorm
|
| 7 |
|
| 8 |
class Conv1D(nn.Conv1d):
|
| 9 |
"""
|
|
@@ -58,12 +59,23 @@ class TCNBlock(nn.Module):
|
|
| 58 |
conv_channels=512,
|
| 59 |
kernel_size=3,
|
| 60 |
dilation=1,
|
| 61 |
-
causal=False
|
|
|
|
| 62 |
super(TCNBlock, self).__init__()
|
| 63 |
self.conv1x1 = Conv1D(in_channels, conv_channels, 1)
|
| 64 |
self.prelu1 = nn.PReLU()
|
| 65 |
-
self.norm1 = GlobalLayerNorm(conv_channels, elementwise_affine=True) if not causal else (
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
dconv_pad = (dilation * (kernel_size - 1)) // 2 if not causal else (
|
| 68 |
dilation * (kernel_size - 1))
|
| 69 |
self.dconv = nn.Conv1d(
|
|
@@ -75,8 +87,8 @@ class TCNBlock(nn.Module):
|
|
| 75 |
dilation=dilation,
|
| 76 |
bias=True)
|
| 77 |
self.prelu2 = nn.PReLU()
|
| 78 |
-
self.norm2 = GlobalLayerNorm(conv_channels, elementwise_affine=True) if not causal else (
|
| 79 |
-
|
| 80 |
self.sconv = nn.Conv1d(conv_channels, in_channels, 1, bias=True)
|
| 81 |
self.causal = causal
|
| 82 |
self.dconv_pad = dconv_pad
|
|
@@ -108,12 +120,40 @@ class TCNBlock_Spk(nn.Module):
|
|
| 108 |
conv_channels=512,
|
| 109 |
kernel_size=3,
|
| 110 |
dilation=1,
|
| 111 |
-
causal=False
|
|
|
|
|
|
|
| 112 |
super(TCNBlock_Spk, self).__init__()
|
| 113 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
self.prelu1 = nn.PReLU()
|
| 115 |
-
self.norm1 = GlobalLayerNorm(conv_channels, elementwise_affine=True) if not causal else (
|
| 116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
dconv_pad = (dilation * (kernel_size - 1)) // 2 if not causal else (
|
| 118 |
dilation * (kernel_size - 1))
|
| 119 |
self.dconv = nn.Conv1d(
|
|
@@ -125,19 +165,80 @@ class TCNBlock_Spk(nn.Module):
|
|
| 125 |
dilation=dilation,
|
| 126 |
bias=True)
|
| 127 |
self.prelu2 = nn.PReLU()
|
| 128 |
-
self.norm2 = GlobalLayerNorm(conv_channels, elementwise_affine=True) if not causal else (
|
| 129 |
-
|
| 130 |
self.sconv = nn.Conv1d(conv_channels, in_channels, 1, bias=True)
|
| 131 |
self.causal = causal
|
| 132 |
self.dconv_pad = dconv_pad
|
| 133 |
self.dilation = dilation
|
| 134 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
def forward(self, x, aux):
|
| 136 |
# Repeatedly concated speaker embedding aux to each frame of the representation x
|
| 137 |
T = x.shape[-1]
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
y = self.conv1x1(y)
|
| 142 |
y = self.norm1(self.prelu1(y))
|
| 143 |
y = self.dconv(y)
|
|
|
|
| 2 |
|
| 3 |
import torch as th
|
| 4 |
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
|
| 7 |
+
from .norm import ChannelwiseLayerNorm, GlobalLayerNorm, CumLN
|
| 8 |
|
| 9 |
class Conv1D(nn.Conv1d):
|
| 10 |
"""
|
|
|
|
| 59 |
conv_channels=512,
|
| 60 |
kernel_size=3,
|
| 61 |
dilation=1,
|
| 62 |
+
causal=False,
|
| 63 |
+
norm_type='gLN'):
|
| 64 |
super(TCNBlock, self).__init__()
|
| 65 |
self.conv1x1 = Conv1D(in_channels, conv_channels, 1)
|
| 66 |
self.prelu1 = nn.PReLU()
|
| 67 |
+
# self.norm1 = GlobalLayerNorm(conv_channels, elementwise_affine=True) if not causal else (
|
| 68 |
+
# ChannelwiseLayerNorm(conv_channels, elementwise_affine=True))
|
| 69 |
+
if norm_type == 'gLN':
|
| 70 |
+
self.norm1 = GlobalLayerNorm(conv_channels, elementwise_affine=True)
|
| 71 |
+
self.norm2 = GlobalLayerNorm(conv_channels, elementwise_affine=True)
|
| 72 |
+
elif norm_type == 'cLN':
|
| 73 |
+
self.norm1 = ChannelwiseLayerNorm(conv_channels, elementwise_affine=True)
|
| 74 |
+
self.norm2 = ChannelwiseLayerNorm(conv_channels, elementwise_affine=True)
|
| 75 |
+
elif norm_type == 'cgLN':
|
| 76 |
+
self.norm1 = CumLN(conv_channels, elementwise_affine=True)
|
| 77 |
+
self.norm2 = CumLN(conv_channels, elementwise_affine=True)
|
| 78 |
+
|
| 79 |
dconv_pad = (dilation * (kernel_size - 1)) // 2 if not causal else (
|
| 80 |
dilation * (kernel_size - 1))
|
| 81 |
self.dconv = nn.Conv1d(
|
|
|
|
| 87 |
dilation=dilation,
|
| 88 |
bias=True)
|
| 89 |
self.prelu2 = nn.PReLU()
|
| 90 |
+
# self.norm2 = GlobalLayerNorm(conv_channels, elementwise_affine=True) if not causal else (
|
| 91 |
+
# ChannelwiseLayerNorm(conv_channels, elementwise_affine=True))
|
| 92 |
self.sconv = nn.Conv1d(conv_channels, in_channels, 1, bias=True)
|
| 93 |
self.causal = causal
|
| 94 |
self.dconv_pad = dconv_pad
|
|
|
|
| 120 |
conv_channels=512,
|
| 121 |
kernel_size=3,
|
| 122 |
dilation=1,
|
| 123 |
+
causal=False,
|
| 124 |
+
norm_type='gLN',
|
| 125 |
+
fusion_type='cat'):
|
| 126 |
super(TCNBlock_Spk, self).__init__()
|
| 127 |
+
self.fusion_type = fusion_type
|
| 128 |
+
if fusion_type == 'cat':
|
| 129 |
+
self.conv1x1 = Conv1D(in_channels+spk_embed_dim, conv_channels, 1)
|
| 130 |
+
if fusion_type in ('add', 'mul'):
|
| 131 |
+
self.fusion_linear = nn.Linear(spk_embed_dim, in_channels)
|
| 132 |
+
self.conv1x1 = Conv1D(in_channels, conv_channels, 1)
|
| 133 |
+
if fusion_type == 'film':
|
| 134 |
+
self.fusion_linear_1 = nn.Linear(spk_embed_dim, in_channels)
|
| 135 |
+
self.fusion_linear_2 = nn.Linear(spk_embed_dim, in_channels)
|
| 136 |
+
self.conv1x1 = Conv1D(in_channels, conv_channels, 1)
|
| 137 |
+
if fusion_type == 'att':
|
| 138 |
+
self.fusion_linear = nn.Linear(spk_embed_dim, in_channels)
|
| 139 |
+
self.average = Conv1D(in_channels, in_channels, kernel_size, kernel_size, groups=in_channels)
|
| 140 |
+
self.average.weight = nn.Parameter(th.ones(in_channels, 1, kernel_size) / kernel_size)
|
| 141 |
+
self.average.bias = nn.Parameter(th.zeros(in_channels))
|
| 142 |
+
for p in self.average.parameters():
|
| 143 |
+
p.requires_grad = False
|
| 144 |
+
self.conv1x1 = Conv1D(in_channels, conv_channels, 1)
|
| 145 |
self.prelu1 = nn.PReLU()
|
| 146 |
+
# self.norm1 = GlobalLayerNorm(conv_channels, elementwise_affine=True) if not causal else (
|
| 147 |
+
# ChannelwiseLayerNorm(conv_channels, elementwise_affine=True))
|
| 148 |
+
if norm_type == 'gLN':
|
| 149 |
+
self.norm1 = GlobalLayerNorm(conv_channels, elementwise_affine=True)
|
| 150 |
+
self.norm2 = GlobalLayerNorm(conv_channels, elementwise_affine=True)
|
| 151 |
+
elif norm_type == 'cLN':
|
| 152 |
+
self.norm1 = ChannelwiseLayerNorm(conv_channels, elementwise_affine=True)
|
| 153 |
+
self.norm2 = ChannelwiseLayerNorm(conv_channels, elementwise_affine=True)
|
| 154 |
+
elif norm_type == 'cgLN':
|
| 155 |
+
self.norm1 = CumLN(conv_channels, elementwise_affine=True)
|
| 156 |
+
self.norm2 = CumLN(conv_channels, elementwise_affine=True)
|
| 157 |
dconv_pad = (dilation * (kernel_size - 1)) // 2 if not causal else (
|
| 158 |
dilation * (kernel_size - 1))
|
| 159 |
self.dconv = nn.Conv1d(
|
|
|
|
| 165 |
dilation=dilation,
|
| 166 |
bias=True)
|
| 167 |
self.prelu2 = nn.PReLU()
|
| 168 |
+
# self.norm2 = GlobalLayerNorm(conv_channels, elementwise_affine=True) if not causal else (
|
| 169 |
+
# ChannelwiseLayerNorm(conv_channels, elementwise_affine=True))
|
| 170 |
self.sconv = nn.Conv1d(conv_channels, in_channels, 1, bias=True)
|
| 171 |
self.causal = causal
|
| 172 |
self.dconv_pad = dconv_pad
|
| 173 |
self.dilation = dilation
|
| 174 |
|
| 175 |
+
def _concatenation(self, aux, output, L):
|
| 176 |
+
aux_concat = th.unsqueeze(aux, -1)
|
| 177 |
+
aux_concat = aux_concat.repeat(1, 1, L)
|
| 178 |
+
# -> [B, N(embeddings_size), L]
|
| 179 |
+
output = th.cat([output, aux_concat], 1)
|
| 180 |
+
# -> [B, N(input_size + embeddings_size), L]
|
| 181 |
+
return output
|
| 182 |
+
|
| 183 |
+
def _addition(self, aux, output, L, fusion_linear):
|
| 184 |
+
aux_add = fusion_linear(aux)
|
| 185 |
+
# -> [B, N(input_size)]
|
| 186 |
+
aux_add = th.unsqueeze(aux_add, -1)
|
| 187 |
+
aux_add = aux_add.repeat(1, 1, L)
|
| 188 |
+
# -> [B, N(input_size), L]
|
| 189 |
+
output = output + aux_add
|
| 190 |
+
# -> [B, N(input_size, L]
|
| 191 |
+
return output
|
| 192 |
+
|
| 193 |
+
def _multiplication(self, aux, output, L, fusion_linear):
|
| 194 |
+
aux_mul = fusion_linear(aux)
|
| 195 |
+
# -> [B, N(input_size)]
|
| 196 |
+
aux_mul = th.unsqueeze(aux_mul, -1)
|
| 197 |
+
aux_mul = aux_mul.repeat(1, 1, L)
|
| 198 |
+
# -> [B, N(input_size), L]
|
| 199 |
+
output = output * aux_mul
|
| 200 |
+
# -> [B, N(input_size, L]
|
| 201 |
+
return output
|
| 202 |
+
|
| 203 |
+
def _attention(self, aux, output, fusion_linear):
|
| 204 |
+
L = output.shape[-1]
|
| 205 |
+
aux_att = fusion_linear(aux)
|
| 206 |
+
aux_att = th.unsqueeze(aux_att, -1)
|
| 207 |
+
aux_att = aux_att.repeat(1, 1, L)
|
| 208 |
+
att = th.sum(output * aux_att, 1, keepdim=True)
|
| 209 |
+
att = F.softmax(att, -1)
|
| 210 |
+
att = att * aux_att
|
| 211 |
+
return att + aux_att
|
| 212 |
+
|
| 213 |
+
def _film(self, aux, output, L):
|
| 214 |
+
output = self._multiplication(aux, output, L, self.fusion_linear_1)
|
| 215 |
+
# -> [B, N(input_size, L]
|
| 216 |
+
output = self._addition(aux, output, L, self.fusion_linear_2)
|
| 217 |
+
# -> [B, N(input_size, L]
|
| 218 |
+
return output
|
| 219 |
+
|
| 220 |
def forward(self, x, aux):
|
| 221 |
# Repeatedly concated speaker embedding aux to each frame of the representation x
|
| 222 |
T = x.shape[-1]
|
| 223 |
+
if self.fusion_type == 'cat':
|
| 224 |
+
y = self._concatenation(aux, x, T)
|
| 225 |
+
# -> [B, N(input_size + embeddings_size), L]
|
| 226 |
+
if self.fusion_type == 'add':
|
| 227 |
+
y = self._addition(aux, x, T, self.fusion_linear)
|
| 228 |
+
# -> [B, N(input_size), L]
|
| 229 |
+
if self.fusion_type == 'mul':
|
| 230 |
+
y = self._multiplication(aux, x, T, self.fusion_linear)
|
| 231 |
+
# -> [B, N(input_size), L]
|
| 232 |
+
if self.fusion_type == 'film':
|
| 233 |
+
y = self._film(aux, x, T)
|
| 234 |
+
# -> [B, N(input_size), L]
|
| 235 |
+
if self.fusion_type == 'att':
|
| 236 |
+
output_avg = self.average(x)
|
| 237 |
+
att_out = self._attention(aux, output_avg, self.fusion_linear)
|
| 238 |
+
upsampling = nn.Upsample(size=T, mode='nearest')
|
| 239 |
+
att_out = upsampling(att_out)
|
| 240 |
+
y = x * att_out
|
| 241 |
+
|
| 242 |
y = self.conv1x1(y)
|
| 243 |
y = self.norm1(self.prelu1(y))
|
| 244 |
y = self.dconv(y)
|
{nnet → model}/norm.py
RENAMED
|
@@ -22,6 +22,44 @@ class ChannelwiseLayerNorm(nn.LayerNorm):
|
|
| 22 |
x = th.transpose(x, 1, 2)
|
| 23 |
return x
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
class GlobalLayerNorm(nn.Module):
|
| 26 |
"""
|
| 27 |
Global layer normalization
|
|
@@ -57,3 +95,4 @@ class GlobalLayerNorm(nn.Module):
|
|
| 57 |
def extra_repr(self):
|
| 58 |
return "{normalized_dim}, eps={eps}, " \
|
| 59 |
"elementwise_affine={elementwise_affine}".format(**self.__dict__)
|
|
|
|
|
|
| 22 |
x = th.transpose(x, 1, 2)
|
| 23 |
return x
|
| 24 |
|
| 25 |
+
|
| 26 |
+
class CumLN(nn.Module):
|
| 27 |
+
"""
|
| 28 |
+
Cumulative Global layer normalization
|
| 29 |
+
Input: 3D tensor with [batch_size(N), channel_size(C), frame_num(T)]
|
| 30 |
+
Output: 3D tensor with same shape
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
def __init__(self, dim, eps=1e-05, elementwise_affine=True):
|
| 34 |
+
super(CumLN, self).__init__()
|
| 35 |
+
self.eps = eps
|
| 36 |
+
self.elementwise_affine = elementwise_affine
|
| 37 |
+
self.normalized_dim = dim
|
| 38 |
+
if elementwise_affine:
|
| 39 |
+
self.beta = nn.Parameter(th.zeros(dim, 1))
|
| 40 |
+
self.gamma = nn.Parameter(th.ones(dim, 1))
|
| 41 |
+
else:
|
| 42 |
+
self.register_parameter("weight", None)
|
| 43 |
+
self.register_parameter("bias", None)
|
| 44 |
+
|
| 45 |
+
def forward(self, x):
|
| 46 |
+
if x.dim() != 3:
|
| 47 |
+
raise RuntimeError("{} requires a 3D tensor input".format(self.__class__.__name__))
|
| 48 |
+
batch, chan, spec_len = x.size()
|
| 49 |
+
cum_sum = th.cumsum(x.sum(1, keepdim=True), dim=-1)
|
| 50 |
+
cum_pow_sum = th.cumsum(x.pow(2).sum(1, keepdim=True), dim=-1) #th.cumsum 后加前 逐元素相加
|
| 51 |
+
cnt = th.arange(start=chan, end=chan * (spec_len + 1), step=chan, dtype=x.dtype, device=x.device).view(1, 1, -1)
|
| 52 |
+
cum_mean = cum_sum / cnt
|
| 53 |
+
cum_var = cum_pow_sum / cnt - cum_mean.pow(2)
|
| 54 |
+
normalized_x = (x - cum_mean) / (cum_var + self.eps).sqrt()
|
| 55 |
+
if self.elementwise_affine:
|
| 56 |
+
normalized_x = self.gamma * normalized_x + self.beta
|
| 57 |
+
return normalized_x
|
| 58 |
+
|
| 59 |
+
def extra_repr(self):
|
| 60 |
+
return "{normalized_dim}, eps={eps}, elementwise_affine={elementwise_affine}".format(**self.__dict__)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
class GlobalLayerNorm(nn.Module):
|
| 64 |
"""
|
| 65 |
Global layer normalization
|
|
|
|
| 95 |
def extra_repr(self):
|
| 96 |
return "{normalized_dim}, eps={eps}, " \
|
| 97 |
"elementwise_affine={elementwise_affine}".format(**self.__dict__)
|
| 98 |
+
|
{nnet → model}/spex_plus.py
RENAMED
|
@@ -6,15 +6,9 @@ import torch.nn.functional as F
|
|
| 6 |
|
| 7 |
from .norm import ChannelwiseLayerNorm, GlobalLayerNorm
|
| 8 |
from .cnns import Conv1D, ConvTrans1D, TCNBlock, TCNBlock_Spk, ResBlock
|
|
|
|
| 9 |
|
| 10 |
-
|
| 11 |
-
from .ResNet34 import Speaker_Encoder
|
| 12 |
-
# from .sunine.trainer.utils import PreEmphasis
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
# 考虑两种可能,频域就不大可能有所谓的多时间尺度,所以肯定speaker是直接频谱,那speech呢?
|
| 17 |
-
# 注意下维度 是 B N T 还是 B T N
|
| 18 |
|
| 19 |
class SpEx_Plus(nn.Module):
|
| 20 |
def __init__(self,
|
|
@@ -28,14 +22,15 @@ class SpEx_Plus(nn.Module):
|
|
| 28 |
Q=3,
|
| 29 |
num_spks=101,
|
| 30 |
spk_embed_dim=256,
|
| 31 |
-
sample_rate = 16000,
|
| 32 |
-
n_mels = 80,
|
| 33 |
causal=False,
|
|
|
|
|
|
|
|
|
|
| 34 |
):
|
| 35 |
super(SpEx_Plus, self).__init__()
|
|
|
|
| 36 |
# n x S => n x N x T, S = 4s*8000 = 32000
|
| 37 |
-
|
| 38 |
-
self.n_mels = n_mels
|
| 39 |
self.L1 = L1
|
| 40 |
self.L2 = L2
|
| 41 |
self.L3 = L3
|
|
@@ -43,82 +38,49 @@ class SpEx_Plus(nn.Module):
|
|
| 43 |
self.encoder_1d_middle = Conv1D(1, N, L2, stride=L1 // 2, padding=0)
|
| 44 |
self.encoder_1d_long = Conv1D(1, N, L3, stride=L1 // 2, padding=0)
|
| 45 |
# before repeat blocks, always cLN
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
self.conv_block_1 = TCNBlock_Spk(spk_embed_dim=spk_embed_dim, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal, dilation=1)
|
| 50 |
-
self.conv_block_1_other = self._build_stacks(num_blocks=B, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal)
|
| 51 |
-
self.conv_block_2 = TCNBlock_Spk(spk_embed_dim=spk_embed_dim, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal, dilation=1)
|
| 52 |
-
self.conv_block_2_other = self._build_stacks(num_blocks=B, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal)
|
| 53 |
-
self.conv_block_3 = TCNBlock_Spk(spk_embed_dim=spk_embed_dim, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal, dilation=1)
|
| 54 |
-
self.conv_block_3_other = self._build_stacks(num_blocks=B, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal)
|
| 55 |
-
self.conv_block_4 = TCNBlock_Spk(spk_embed_dim=spk_embed_dim, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal, dilation=1)
|
| 56 |
-
self.conv_block_4_other = self._build_stacks(num_blocks=B, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal)
|
| 57 |
-
# n x O x T => n x N x T
|
| 58 |
-
self.mask1 = Conv1D(O, N, 1)
|
| 59 |
-
self.mask2 = Conv1D(O, N, 1)
|
| 60 |
-
self.mask3 = Conv1D(O, N, 1)
|
| 61 |
-
# using ConvTrans1D: n x N x T => n x 1 x To
|
| 62 |
-
# To = (T - 1) * L // 2 + L
|
| 63 |
-
#############################################################
|
| 64 |
self.decoder_1d_short = ConvTrans1D(N, 1, kernel_size=L1, stride=L1 // 2, bias=True)
|
| 65 |
self.decoder_1d_middle = ConvTrans1D(N, 1, kernel_size=L2, stride=L1 // 2, bias=True)
|
| 66 |
self.decoder_1d_long = ConvTrans1D(N, 1, kernel_size=L3, stride=L1 // 2, bias=True)
|
| 67 |
self.num_spks = num_spks
|
| 68 |
-
# self.spk_encoder = nn.Sequential(
|
| 69 |
-
# ChannelwiseLayerNorm(3*N),
|
| 70 |
-
# Conv1D(3*N, O, 1),
|
| 71 |
-
# ResBlock(O, O),
|
| 72 |
-
# ResBlock(O, P),
|
| 73 |
-
# ResBlock(P, P),
|
| 74 |
-
# Conv1D(P, spk_embed_dim, 1),
|
| 75 |
-
# )
|
| 76 |
-
|
| 77 |
-
# self.pred_linear = nn.Linear(spk_embed_dim, num_spks)
|
| 78 |
-
|
| 79 |
-
# 改为pretrain
|
| 80 |
-
# 考虑两种可能,频域就不大可能有所谓的多时间尺度,所以肯定speaker是直接频谱,那speech呢?
|
| 81 |
-
# /work105/youzhenghai/model/resnet_asp_aam_adamw_welr
|
| 82 |
-
# import ..sunine/trainer/speaker encoder
|
| 83 |
-
# **kwargs 无需关心 找到 self.hparams就行 按照 main_infer改就行
|
| 84 |
-
#############################################################
|
| 85 |
-
|
| 86 |
-
# # 1. Acoustic Feature
|
| 87 |
-
# self.mel_trans = th.nn.Sequential(
|
| 88 |
-
# PreEmphasis(),
|
| 89 |
-
# torchaudio.transforms.MelSpectrogram(sample_rate=self.sample_rate, n_fft=512,
|
| 90 |
-
# win_length=400, hop_length=160, window_fn=th.hamming_window, n_mels=self.n_mels)
|
| 91 |
-
# )
|
| 92 |
-
|
| 93 |
-
# self.instancenorm = nn.InstanceNorm1d(self.n_mels)
|
| 94 |
-
|
| 95 |
-
# # 在调用的地方设置超参数 记得后面写为参数传入
|
| 96 |
-
# self.hparams = {'embedding_dim': spk_embed_dim, 'pooling_type': 'ASP' , 'n_mels': self.n_mels}
|
| 97 |
-
# # 使用 **self.hparams 调用函数
|
| 98 |
-
# self.speaker_encoder = Speaker_Encoder(**self.hparams)
|
| 99 |
-
self.speaker_embedding_extracter = Speaker_Model(pooling_type='ASP', spk_embed_dim=spk_embed_dim, sample_rate=self.sample_rate, n_mels=self.n_mels)
|
| 100 |
self.pred_linear = nn.Linear(spk_embed_dim, num_spks)
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
|
| 123 |
def forward(self, x, aux, aux_len):
|
| 124 |
if x.dim() >= 3:
|
|
@@ -128,8 +90,9 @@ class SpEx_Plus(nn.Module):
|
|
| 128 |
# when inference, only one utt
|
| 129 |
if x.dim() == 1:
|
| 130 |
x = th.unsqueeze(x, 0)
|
| 131 |
-
|
| 132 |
# n x 1 x S => n x N x T
|
|
|
|
|
|
|
| 133 |
w1 = F.relu(self.encoder_1d_short(x))
|
| 134 |
T = w1.shape[-1]
|
| 135 |
xlen1 = x.shape[-1]
|
|
@@ -137,42 +100,79 @@ class SpEx_Plus(nn.Module):
|
|
| 137 |
xlen3 = (T - 1) * (self.L1 // 2) + self.L3
|
| 138 |
w2 = F.relu(self.encoder_1d_middle(F.pad(x, (0, xlen2 - xlen1), "constant", 0)))
|
| 139 |
w3 = F.relu(self.encoder_1d_long(F.pad(x, (0, xlen3 - xlen1), "constant", 0)))
|
| 140 |
-
|
| 141 |
# n x 3N x T
|
| 142 |
-
y = self.ln(th.cat([w1, w2, w3], 1))
|
| 143 |
-
# n x O x T
|
| 144 |
-
y = self.proj(y)
|
| 145 |
-
|
| 146 |
# speaker encoder (share params from speech encoder)
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
|
| 155 |
-
|
| 156 |
-
# aux = self.spk_encoder(th.cat([aux_w1, aux_w2, aux_w3], 1))
|
| 157 |
-
# aux_T = (aux_len - self.L1) // (self.L1 // 2) + 1
|
| 158 |
-
# aux_T = ((aux_T // 3) // 3) // 3
|
| 159 |
-
# aux = th.sum(aux, -1)/aux_T.view(-1,1).float()
|
| 160 |
|
| 161 |
-
|
| 162 |
-
|
|
|
|
|
|
|
| 163 |
|
|
|
|
| 164 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
|
| 168 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
-
|
| 171 |
-
#aux = aux.reshape(-1, self.hparams.nPerSpeaker, self.spk_embed_dim)
|
| 172 |
-
# loss, acc = self.loss(x, label)
|
| 173 |
-
# return loss.mean(), acc
|
| 174 |
-
# 考虑 loss 是否也要
|
| 175 |
|
|
|
|
|
|
|
|
|
|
| 176 |
y = self.conv_block_1(y, aux)
|
| 177 |
y = self.conv_block_1_other(y)
|
| 178 |
y = self.conv_block_2(y, aux)
|
|
@@ -186,62 +186,35 @@ class SpEx_Plus(nn.Module):
|
|
| 186 |
m1 = F.relu(self.mask1(y))
|
| 187 |
m2 = F.relu(self.mask2(y))
|
| 188 |
m3 = F.relu(self.mask3(y))
|
| 189 |
-
S1 = w1 * m1
|
| 190 |
-
S2 = w2 * m2
|
| 191 |
-
S3 = w3 * m3
|
| 192 |
-
|
| 193 |
-
return self.decoder_1d_short(S1), self.decoder_1d_middle(S2)[:, :xlen1], self.decoder_1d_long(S3)[:, :xlen1], self.pred_linear(aux)
|
| 194 |
-
|
| 195 |
-
class PreEmphasis(th.nn.Module):
|
| 196 |
-
def __init__(self, coef: float = 0.97):
|
| 197 |
-
super().__init__()
|
| 198 |
-
self.coef = coef
|
| 199 |
-
# make kernel
|
| 200 |
-
# In pyth, the convolution operation uses cross-correlation. So, filter is flipped.
|
| 201 |
-
self.register_buffer(
|
| 202 |
-
'flipped_filter', th.FloatTensor([-self.coef, 1.]).unsqueeze(0).unsqueeze(0)
|
| 203 |
-
)
|
| 204 |
|
| 205 |
-
|
| 206 |
-
assert len(inputs.size()) == 2, 'The number of dimensions of inputs tensor must be 2!'
|
| 207 |
-
# reflect padding to match lengths of in/out
|
| 208 |
-
inputs = inputs.unsqueeze(1)
|
| 209 |
-
inputs = F.pad(inputs, (1, 0), 'reflect')
|
| 210 |
-
return F.conv1d(inputs, self.flipped_filter).squeeze(1)
|
| 211 |
|
| 212 |
|
| 213 |
class Speaker_Model(nn.Module):
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
self.
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
x = self.instancenorm(x)
|
| 242 |
-
x = self.speaker_encoder(x)
|
| 243 |
-
return x
|
| 244 |
-
|
| 245 |
-
def forward(self, x):
|
| 246 |
-
x = self.extract_speaker_embedding(x)
|
| 247 |
-
return x
|
|
|
|
| 6 |
|
| 7 |
from .norm import ChannelwiseLayerNorm, GlobalLayerNorm
|
| 8 |
from .cnns import Conv1D, ConvTrans1D, TCNBlock, TCNBlock_Spk, ResBlock
|
| 9 |
+
import warnings
|
| 10 |
|
| 11 |
+
# inference aux_len
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
class SpEx_Plus(nn.Module):
|
| 14 |
def __init__(self,
|
|
|
|
| 22 |
Q=3,
|
| 23 |
num_spks=101,
|
| 24 |
spk_embed_dim=256,
|
|
|
|
|
|
|
| 25 |
causal=False,
|
| 26 |
+
norm_type='gLN',
|
| 27 |
+
fusion_type='cat',
|
| 28 |
+
is_innorm=False,
|
| 29 |
):
|
| 30 |
super(SpEx_Plus, self).__init__()
|
| 31 |
+
|
| 32 |
# n x S => n x N x T, S = 4s*8000 = 32000
|
| 33 |
+
|
|
|
|
| 34 |
self.L1 = L1
|
| 35 |
self.L2 = L2
|
| 36 |
self.L3 = L3
|
|
|
|
| 38 |
self.encoder_1d_middle = Conv1D(1, N, L2, stride=L1 // 2, padding=0)
|
| 39 |
self.encoder_1d_long = Conv1D(1, N, L3, stride=L1 // 2, padding=0)
|
| 40 |
# before repeat blocks, always cLN
|
| 41 |
+
|
| 42 |
+
self.instancenorm = nn.InstanceNorm1d(N)
|
| 43 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
self.decoder_1d_short = ConvTrans1D(N, 1, kernel_size=L1, stride=L1 // 2, bias=True)
|
| 45 |
self.decoder_1d_middle = ConvTrans1D(N, 1, kernel_size=L2, stride=L1 // 2, bias=True)
|
| 46 |
self.decoder_1d_long = ConvTrans1D(N, 1, kernel_size=L3, stride=L1 // 2, bias=True)
|
| 47 |
self.num_spks = num_spks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
self.pred_linear = nn.Linear(spk_embed_dim, num_spks)
|
| 49 |
+
self.is_innorm = is_innorm
|
| 50 |
+
|
| 51 |
+
if causal and norm_type not in ["cgLN", "cLN"]:
|
| 52 |
+
norm_type = "cLN"
|
| 53 |
+
warnings.warn(
|
| 54 |
+
"In causal configuration cumulative layer normalization (cgLN)"
|
| 55 |
+
"or channel-wise layer normalization (chanLN) "
|
| 56 |
+
f"must be used. Changing {norm_type} to cLN"
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
self.speaker_encoder = Speaker_Model(
|
| 60 |
+
L1=L1,
|
| 61 |
+
L2=L2,
|
| 62 |
+
L3=L3,
|
| 63 |
+
N=N,
|
| 64 |
+
O=O,
|
| 65 |
+
P=P,
|
| 66 |
+
spk_embed_dim=spk_embed_dim,
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
self.extractor = Extractor(
|
| 70 |
+
L1=L1,
|
| 71 |
+
L2=L2,
|
| 72 |
+
L3=L3,
|
| 73 |
+
N=N,
|
| 74 |
+
B=B,
|
| 75 |
+
O=O,
|
| 76 |
+
P=P,
|
| 77 |
+
Q=Q,
|
| 78 |
+
num_spks=num_spks,
|
| 79 |
+
spk_embed_dim=spk_embed_dim,
|
| 80 |
+
causal=causal,
|
| 81 |
+
fusion_type=fusion_type,
|
| 82 |
+
norm_type=norm_type,
|
| 83 |
+
)
|
| 84 |
|
| 85 |
def forward(self, x, aux, aux_len):
|
| 86 |
if x.dim() >= 3:
|
|
|
|
| 90 |
# when inference, only one utt
|
| 91 |
if x.dim() == 1:
|
| 92 |
x = th.unsqueeze(x, 0)
|
|
|
|
| 93 |
# n x 1 x S => n x N x T
|
| 94 |
+
|
| 95 |
+
|
| 96 |
w1 = F.relu(self.encoder_1d_short(x))
|
| 97 |
T = w1.shape[-1]
|
| 98 |
xlen1 = x.shape[-1]
|
|
|
|
| 100 |
xlen3 = (T - 1) * (self.L1 // 2) + self.L3
|
| 101 |
w2 = F.relu(self.encoder_1d_middle(F.pad(x, (0, xlen2 - xlen1), "constant", 0)))
|
| 102 |
w3 = F.relu(self.encoder_1d_long(F.pad(x, (0, xlen3 - xlen1), "constant", 0)))
|
|
|
|
| 103 |
# n x 3N x T
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
# speaker encoder (share params from speech encoder)
|
| 105 |
+
aux_w1 = F.relu(self.encoder_1d_short(aux))
|
| 106 |
+
aux_T_shape = aux_w1.shape[-1]
|
| 107 |
+
aux_len1 = aux.shape[-1]
|
| 108 |
+
aux_len2 = (aux_T_shape - 1) * (self.L1 // 2) + self.L2
|
| 109 |
+
aux_len3 = (aux_T_shape - 1) * (self.L1 // 2) + self.L3
|
| 110 |
+
aux_w2 = F.relu(self.encoder_1d_middle(F.pad(aux, (0, aux_len2 - aux_len1), "constant", 0)))
|
| 111 |
+
aux_w3 = F.relu(self.encoder_1d_long(F.pad(aux, (0, aux_len3 - aux_len1), "constant", 0)))
|
| 112 |
|
| 113 |
+
aux = self.speaker_encoder(th.cat([aux_w1, aux_w2, aux_w3], 1), aux_len)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
+
if self.is_innorm:
|
| 116 |
+
w1 = self.instancenorm(w1)
|
| 117 |
+
w2 = self.instancenorm(w2)
|
| 118 |
+
w3 = self.instancenorm(w3)
|
| 119 |
|
| 120 |
+
m1, m2, m3 = self.extractor(w1, w2, w3, aux)
|
| 121 |
|
| 122 |
+
S1 = w1 * m1
|
| 123 |
+
S2 = w2 * m2
|
| 124 |
+
S3 = w3 * m3
|
| 125 |
+
|
| 126 |
+
return self.decoder_1d_short(S1), self.decoder_1d_middle(S2)[:, :xlen1], self.decoder_1d_long(S3)[:, :xlen1], self.pred_linear(aux)
|
| 127 |
|
| 128 |
+
class Extractor(nn.Module):
|
| 129 |
+
def __init__(self,
|
| 130 |
+
L1=20,
|
| 131 |
+
L2=80,
|
| 132 |
+
L3=160,
|
| 133 |
+
N=256,
|
| 134 |
+
B=8,
|
| 135 |
+
O=256,
|
| 136 |
+
P=512,
|
| 137 |
+
Q=3,
|
| 138 |
+
num_spks=101,
|
| 139 |
+
spk_embed_dim=256,
|
| 140 |
+
causal=False,
|
| 141 |
+
fusion_type='cat',
|
| 142 |
+
norm_type='gLN',
|
| 143 |
+
):
|
| 144 |
+
super(Extractor, self).__init__()
|
| 145 |
+
# n x N x T => n x O x T
|
| 146 |
+
self.ln = ChannelwiseLayerNorm(3*N)
|
| 147 |
+
self.proj = Conv1D(3*N, O, 1)
|
| 148 |
+
self.conv_block_1 = TCNBlock_Spk(spk_embed_dim=spk_embed_dim, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal, dilation=1,fusion_type=fusion_type,norm_type=norm_type)
|
| 149 |
+
self.conv_block_1_other = self._build_stacks(num_blocks=B, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal,norm_type=norm_type)
|
| 150 |
+
self.conv_block_2 = TCNBlock_Spk(spk_embed_dim=spk_embed_dim, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal, dilation=1,fusion_type=fusion_type,norm_type=norm_type)
|
| 151 |
+
self.conv_block_2_other = self._build_stacks(num_blocks=B, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal,norm_type=norm_type)
|
| 152 |
+
self.conv_block_3 = TCNBlock_Spk(spk_embed_dim=spk_embed_dim, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal, dilation=1,fusion_type=fusion_type,norm_type=norm_type)
|
| 153 |
+
self.conv_block_3_other = self._build_stacks(num_blocks=B, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal,norm_type=norm_type)
|
| 154 |
+
self.conv_block_4 = TCNBlock_Spk(spk_embed_dim=spk_embed_dim, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal, dilation=1,fusion_type=fusion_type,norm_type=norm_type)
|
| 155 |
+
self.conv_block_4_other = self._build_stacks(num_blocks=B, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal,norm_type=norm_type)
|
| 156 |
+
# n x O x T => n x N x T
|
| 157 |
+
self.mask1 = Conv1D(O, N, 1)
|
| 158 |
+
self.mask2 = Conv1D(O, N, 1)
|
| 159 |
+
self.mask3 = Conv1D(O, N, 1)
|
| 160 |
|
| 161 |
+
def _build_stacks(self, num_blocks, **block_kwargs):
|
| 162 |
+
"""
|
| 163 |
+
Stack B numbers of TCN block, the first TCN block takes the speaker embedding
|
| 164 |
+
"""
|
| 165 |
+
blocks = [
|
| 166 |
+
TCNBlock(**block_kwargs, dilation=(2**b))
|
| 167 |
+
for b in range(1,num_blocks)
|
| 168 |
+
]
|
| 169 |
+
return nn.Sequential(*blocks)
|
| 170 |
|
| 171 |
+
def forward(self, w1, w2, w3, aux):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
|
| 173 |
+
y = self.ln(th.cat([w1, w2, w3], 1))
|
| 174 |
+
# n x O x T
|
| 175 |
+
y = self.proj(y)
|
| 176 |
y = self.conv_block_1(y, aux)
|
| 177 |
y = self.conv_block_1_other(y)
|
| 178 |
y = self.conv_block_2(y, aux)
|
|
|
|
| 186 |
m1 = F.relu(self.mask1(y))
|
| 187 |
m2 = F.relu(self.mask2(y))
|
| 188 |
m3 = F.relu(self.mask3(y))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
|
| 190 |
+
return m1, m2, m3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
|
| 192 |
|
| 193 |
class Speaker_Model(nn.Module):
|
| 194 |
+
def __init__(self,
|
| 195 |
+
L1=20,
|
| 196 |
+
L2=80,
|
| 197 |
+
L3=160,
|
| 198 |
+
N=256,
|
| 199 |
+
O=256,
|
| 200 |
+
P=512,
|
| 201 |
+
spk_embed_dim=256,
|
| 202 |
+
):
|
| 203 |
+
super(Speaker_Model, self).__init__()
|
| 204 |
+
self.L1 = L1
|
| 205 |
+
self.L2 = L2
|
| 206 |
+
self.L3 = L3
|
| 207 |
+
self.spk_encoder = nn.Sequential(
|
| 208 |
+
ChannelwiseLayerNorm(3*N),
|
| 209 |
+
Conv1D(3*N, O, 1),
|
| 210 |
+
ResBlock(O, O),
|
| 211 |
+
ResBlock(O, P),
|
| 212 |
+
ResBlock(P, P),
|
| 213 |
+
Conv1D(P, spk_embed_dim, 1),
|
| 214 |
+
)
|
| 215 |
+
def forward(self, aux, aux_len):
|
| 216 |
+
aux = self.spk_encoder(aux)
|
| 217 |
+
aux_T = (aux_len - self.L1) // (self.L1 // 2) + 1
|
| 218 |
+
aux_T = ((aux_T // 3) // 3) // 3
|
| 219 |
+
aux = th.sum(aux, -1)/aux_T.view(-1,1).float()
|
| 220 |
+
return aux
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
nnet/ResNet34.py
DELETED
|
@@ -1,213 +0,0 @@
|
|
| 1 |
-
#! /usr/bin/python
|
| 2 |
-
# -*- encoding: utf-8 -*-
|
| 3 |
-
'''
|
| 4 |
-
Fast ResNet
|
| 5 |
-
https://arxiv.org/pdf/2003.11982.pdf
|
| 6 |
-
'''
|
| 7 |
-
|
| 8 |
-
import torch
|
| 9 |
-
import torch.nn as nn
|
| 10 |
-
import torch.nn.functional as F
|
| 11 |
-
from torch.nn import Parameter
|
| 12 |
-
try:
|
| 13 |
-
from .pooling import *
|
| 14 |
-
except:
|
| 15 |
-
from pooling import *
|
| 16 |
-
|
| 17 |
-
class SEBasicBlock(nn.Module):
|
| 18 |
-
expansion = 1
|
| 19 |
-
|
| 20 |
-
def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8):
|
| 21 |
-
super(SEBasicBlock, self).__init__()
|
| 22 |
-
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
| 23 |
-
self.bn1 = nn.BatchNorm2d(planes)
|
| 24 |
-
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
|
| 25 |
-
self.bn2 = nn.BatchNorm2d(planes)
|
| 26 |
-
self.relu = nn.ReLU(inplace=True)
|
| 27 |
-
self.se = SELayer(planes, reduction)
|
| 28 |
-
self.downsample = downsample
|
| 29 |
-
self.stride = stride
|
| 30 |
-
|
| 31 |
-
def forward(self, x):
|
| 32 |
-
residual = x
|
| 33 |
-
|
| 34 |
-
out = self.conv1(x)
|
| 35 |
-
out = self.relu(out)
|
| 36 |
-
out = self.bn1(out)
|
| 37 |
-
|
| 38 |
-
out = self.conv2(out)
|
| 39 |
-
out = self.bn2(out)
|
| 40 |
-
out = self.se(out)
|
| 41 |
-
|
| 42 |
-
if self.downsample is not None:
|
| 43 |
-
residual = self.downsample(x)
|
| 44 |
-
|
| 45 |
-
out += residual
|
| 46 |
-
out = self.relu(out)
|
| 47 |
-
return out
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
class SEBottleneck(nn.Module):
|
| 51 |
-
expansion = 4
|
| 52 |
-
|
| 53 |
-
def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8):
|
| 54 |
-
super(SEBottleneck, self).__init__()
|
| 55 |
-
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
| 56 |
-
self.bn1 = nn.BatchNorm2d(planes)
|
| 57 |
-
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
|
| 58 |
-
padding=1, bias=False)
|
| 59 |
-
self.bn2 = nn.BatchNorm2d(planes)
|
| 60 |
-
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
|
| 61 |
-
self.bn3 = nn.BatchNorm2d(planes * 4)
|
| 62 |
-
self.relu = nn.ReLU(inplace=True)
|
| 63 |
-
self.se = SELayer(planes * 4, reduction)
|
| 64 |
-
self.downsample = downsample
|
| 65 |
-
self.stride = stride
|
| 66 |
-
|
| 67 |
-
def forward(self, x):
|
| 68 |
-
residual = x
|
| 69 |
-
|
| 70 |
-
out = self.conv1(x)
|
| 71 |
-
out = self.bn1(out)
|
| 72 |
-
out = self.relu(out)
|
| 73 |
-
|
| 74 |
-
out = self.conv2(out)
|
| 75 |
-
out = self.bn2(out)
|
| 76 |
-
out = self.relu(out)
|
| 77 |
-
|
| 78 |
-
out = self.conv3(out)
|
| 79 |
-
out = self.bn3(out)
|
| 80 |
-
out = self.se(out)
|
| 81 |
-
|
| 82 |
-
if self.downsample is not None:
|
| 83 |
-
residual = self.downsample(x)
|
| 84 |
-
|
| 85 |
-
out += residual
|
| 86 |
-
out = self.relu(out)
|
| 87 |
-
|
| 88 |
-
return out
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
class SELayer(nn.Module):
|
| 92 |
-
def __init__(self, channel, reduction=8):
|
| 93 |
-
super(SELayer, self).__init__()
|
| 94 |
-
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
| 95 |
-
self.fc = nn.Sequential(
|
| 96 |
-
nn.Linear(channel, channel // reduction),
|
| 97 |
-
nn.ReLU(inplace=True),
|
| 98 |
-
nn.Linear(channel // reduction, channel),
|
| 99 |
-
nn.Sigmoid()
|
| 100 |
-
)
|
| 101 |
-
|
| 102 |
-
def forward(self, x):
|
| 103 |
-
b, c, _, _ = x.size()
|
| 104 |
-
y = self.avg_pool(x).view(b, c)
|
| 105 |
-
y = self.fc(y).view(b, c, 1, 1)
|
| 106 |
-
return x * y
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
class ResNetSE(nn.Module):
|
| 110 |
-
def __init__(self, block, layers, num_filters, embedding_dim, n_mels=80, pooling_type="TSP", **kwargs):
|
| 111 |
-
super(ResNetSE, self).__init__()
|
| 112 |
-
|
| 113 |
-
self.inplanes = num_filters[0]
|
| 114 |
-
self.conv1 = nn.Conv2d(1, num_filters[0] , kernel_size=3, stride=(1, 1), padding=1,
|
| 115 |
-
bias=False)
|
| 116 |
-
self.bn1 = nn.BatchNorm2d(num_filters[0])
|
| 117 |
-
self.relu = nn.ReLU(inplace=True)
|
| 118 |
-
|
| 119 |
-
self.layer1 = self._make_layer(block, num_filters[0], layers[0])
|
| 120 |
-
self.layer2 = self._make_layer(block, num_filters[1], layers[1], stride=(2, 2))
|
| 121 |
-
self.layer3 = self._make_layer(block, num_filters[2], layers[2], stride=(2, 2))
|
| 122 |
-
self.layer4 = self._make_layer(block, num_filters[3], layers[3], stride=(2, 2))
|
| 123 |
-
|
| 124 |
-
out_dim = num_filters[3] * block.expansion * (n_mels//8)
|
| 125 |
-
|
| 126 |
-
if pooling_type == "Temporal_Average_Pooling" or pooling_type == "TAP":
|
| 127 |
-
self.pooling = Temporal_Average_Pooling()
|
| 128 |
-
self.bn2 = nn.BatchNorm1d(out_dim)
|
| 129 |
-
self.fc = nn.Linear(out_dim, embedding_dim)
|
| 130 |
-
self.bn3 = nn.BatchNorm1d(embedding_dim)
|
| 131 |
-
|
| 132 |
-
elif pooling_type == "Temporal_Statistics_Pooling" or pooling_type == "TSP":
|
| 133 |
-
self.pooling = Temporal_Statistics_Pooling()
|
| 134 |
-
self.bn2 = nn.BatchNorm1d(out_dim * 2)
|
| 135 |
-
self.fc = nn.Linear(out_dim * 2, embedding_dim)
|
| 136 |
-
self.bn3 = nn.BatchNorm1d(embedding_dim)
|
| 137 |
-
|
| 138 |
-
elif pooling_type == "Self_Attentive_Pooling" or pooling_type == "SAP":
|
| 139 |
-
self.pooling = Self_Attentive_Pooling(out_dim)
|
| 140 |
-
self.bn2 = nn.BatchNorm1d(out_dim)
|
| 141 |
-
self.fc = nn.Linear(out_dim, embedding_dim)
|
| 142 |
-
self.bn3 = nn.BatchNorm1d(embedding_dim)
|
| 143 |
-
|
| 144 |
-
elif pooling_type == "Attentive_Statistics_Pooling" or pooling_type == "ASP":
|
| 145 |
-
self.pooling = Attentive_Statistics_Pooling(out_dim)
|
| 146 |
-
self.bn2 = nn.BatchNorm1d(out_dim * 2)
|
| 147 |
-
self.fc = nn.Linear(out_dim * 2, embedding_dim)
|
| 148 |
-
self.bn3 = nn.BatchNorm1d(embedding_dim)
|
| 149 |
-
|
| 150 |
-
else:
|
| 151 |
-
raise ValueError('{} pooling type is not defined'.format(pooling_type))
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
for m in self.modules():
|
| 155 |
-
if isinstance(m, nn.Conv2d):
|
| 156 |
-
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
| 157 |
-
elif isinstance(m, nn.BatchNorm2d):
|
| 158 |
-
nn.init.constant_(m.weight, 1)
|
| 159 |
-
nn.init.constant_(m.bias, 0)
|
| 160 |
-
|
| 161 |
-
def _make_layer(self, block, planes, blocks, stride=1):
|
| 162 |
-
downsample = None
|
| 163 |
-
if stride != 1 or self.inplanes != planes * block.expansion:
|
| 164 |
-
downsample = nn.Sequential(
|
| 165 |
-
nn.Conv2d(self.inplanes, planes * block.expansion,
|
| 166 |
-
kernel_size=1, stride=stride, bias=False),
|
| 167 |
-
nn.BatchNorm2d(planes * block.expansion),
|
| 168 |
-
)
|
| 169 |
-
|
| 170 |
-
layers = []
|
| 171 |
-
layers.append(block(self.inplanes, planes, stride, downsample))
|
| 172 |
-
self.inplanes = planes * block.expansion
|
| 173 |
-
for i in range(1, blocks):
|
| 174 |
-
layers.append(block(self.inplanes, planes))
|
| 175 |
-
|
| 176 |
-
return nn.Sequential(*layers)
|
| 177 |
-
|
| 178 |
-
def forward(self, x):
|
| 179 |
-
x = x.unsqueeze(1)
|
| 180 |
-
x = self.conv1(x)
|
| 181 |
-
x = self.bn1(x)
|
| 182 |
-
x = self.relu(x)
|
| 183 |
-
|
| 184 |
-
x = self.layer1(x)
|
| 185 |
-
x = self.layer2(x)
|
| 186 |
-
x = self.layer3(x)
|
| 187 |
-
x = self.layer4(x)
|
| 188 |
-
|
| 189 |
-
x = x.reshape(x.shape[0], -1, x.shape[-1])
|
| 190 |
-
|
| 191 |
-
x = self.pooling(x)
|
| 192 |
-
x = self.bn2(x)
|
| 193 |
-
x = torch.flatten(x, 1)
|
| 194 |
-
x = self.fc(x)
|
| 195 |
-
x = self.bn3(x)
|
| 196 |
-
return x
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
def Speaker_Encoder(embedding_dim=256, **kwargs):
|
| 200 |
-
# Number of filters
|
| 201 |
-
num_filters = [32, 64, 128, 256]
|
| 202 |
-
model = ResNetSE(SEBasicBlock, [3, 4, 6, 3], num_filters, embedding_dim, **kwargs)
|
| 203 |
-
return model
|
| 204 |
-
|
| 205 |
-
if __name__ == '__main__':
|
| 206 |
-
model = Speaker_Encoder()
|
| 207 |
-
total = sum([param.nelement() for param in model.parameters()])
|
| 208 |
-
print(total/1e6)
|
| 209 |
-
data = torch.randn(10, 80, 100)
|
| 210 |
-
out = model(data)
|
| 211 |
-
print(data.shape)
|
| 212 |
-
print(out.shape)
|
| 213 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
nnet/__init__.py
DELETED
|
File without changes
|
nnet/pooling.py
DELETED
|
@@ -1,100 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
import torch.nn as nn
|
| 3 |
-
import torch.nn.functional as F
|
| 4 |
-
from speechbrain.lobes.models.ECAPA_TDNN import AttentiveStatisticsPooling
|
| 5 |
-
|
| 6 |
-
class Temporal_Average_Pooling(nn.Module):
|
| 7 |
-
def __init__(self, **kwargs):
|
| 8 |
-
"""TAP
|
| 9 |
-
Paper: Multi-Task Learning with High-Order Statistics for X-vector based Text-Independent Speaker Verification
|
| 10 |
-
Link: https://arxiv.org/pdf/1903.12058.pdf
|
| 11 |
-
"""
|
| 12 |
-
super(Temporal_Average_Pooling, self).__init__()
|
| 13 |
-
|
| 14 |
-
def forward(self, x):
|
| 15 |
-
"""Computes Temporal Average Pooling Module
|
| 16 |
-
Args:
|
| 17 |
-
x (torch.Tensor): Input tensor (#batch, channels, frames).
|
| 18 |
-
Returns:
|
| 19 |
-
torch.Tensor: Output tensor (#batch, channels)
|
| 20 |
-
"""
|
| 21 |
-
x = torch.mean(x, axis=2)
|
| 22 |
-
return x
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
class Temporal_Statistics_Pooling(nn.Module):
|
| 26 |
-
def __init__(self, **kwargs):
|
| 27 |
-
"""TSP
|
| 28 |
-
Paper: X-vectors: Robust DNN Embeddings for Speaker Recognition
|
| 29 |
-
Link: http://www.danielpovey.com/files/2018_icassp_xvectors.pdf
|
| 30 |
-
"""
|
| 31 |
-
super(Temporal_Statistics_Pooling, self).__init__()
|
| 32 |
-
|
| 33 |
-
def forward(self, x):
|
| 34 |
-
"""Computes Temporal Statistics Pooling Module
|
| 35 |
-
Args:
|
| 36 |
-
x (torch.Tensor): Input tensor (#batch, channels, frames).
|
| 37 |
-
Returns:
|
| 38 |
-
torch.Tensor: Output tensor (#batch, channels*2)
|
| 39 |
-
"""
|
| 40 |
-
mean = torch.mean(x, axis=2)
|
| 41 |
-
var = torch.var(x, axis=2)
|
| 42 |
-
x = torch.cat((mean, var), axis=1)
|
| 43 |
-
return x
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
''' Self attentive weighted mean pooling.
|
| 47 |
-
'''
|
| 48 |
-
class Self_Attentive_Pooling(nn.Module):
|
| 49 |
-
def __init__(self, dim, **kwargs):
|
| 50 |
-
# Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
|
| 51 |
-
# attention dim = 128
|
| 52 |
-
super(Self_Attentive_Pooling, self).__init__()
|
| 53 |
-
self.linear1 = nn.Conv1d(dim, dim, kernel_size=1) # equals W and b in the paper
|
| 54 |
-
self.linear2 = nn.Conv1d(dim, dim, kernel_size=1) # equals V and k in the paper
|
| 55 |
-
|
| 56 |
-
def forward(self, x):
|
| 57 |
-
# DON'T use ReLU here! In experiments, I find ReLU hard to converge.
|
| 58 |
-
alpha = torch.tanh(self.linear1(x))
|
| 59 |
-
alpha = torch.softmax(self.linear2(alpha), dim=2)
|
| 60 |
-
mean = torch.sum(alpha * x, dim=2)
|
| 61 |
-
return mean
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
''' Attentive weighted mean and standard deviation pooling.
|
| 65 |
-
'''
|
| 66 |
-
class Attentive_Statistics_Pooling(nn.Module):
|
| 67 |
-
def __init__(self, dim, **kwargs):
|
| 68 |
-
# Use AttentiveStatisticsPooling and BatchNorm1d from speechbrain
|
| 69 |
-
super(Attentive_Statistics_Pooling, self).__init__()
|
| 70 |
-
self.pooling = AttentiveStatisticsPooling(dim)
|
| 71 |
-
|
| 72 |
-
def forward(self, x):
|
| 73 |
-
x = self.pooling(x)
|
| 74 |
-
return x
|
| 75 |
-
|
| 76 |
-
# class Attentive_Statistics_Pooling(nn.Module):
|
| 77 |
-
# def __init__(self, dim, **kwargs):
|
| 78 |
-
# # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
|
| 79 |
-
# # attention dim = 128
|
| 80 |
-
# super(Attentive_Statistics_Pooling, self).__init__()
|
| 81 |
-
# self.linear1 = nn.Conv1d(dim, dim, kernel_size=1) # equals W and b in the paper
|
| 82 |
-
# self.linear2 = nn.Conv1d(dim, dim, kernel_size=1) # equals V and k in the paper
|
| 83 |
-
#
|
| 84 |
-
# def forward(self, x):
|
| 85 |
-
# # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
|
| 86 |
-
# alpha = torch.tanh(self.linear1(x))
|
| 87 |
-
# alpha = torch.softmax(self.linear2(alpha), dim=2)
|
| 88 |
-
# mean = torch.sum(alpha * x, dim=2)
|
| 89 |
-
# residuals = torch.sum(alpha * x ** 2, dim=2) - mean ** 2
|
| 90 |
-
# std = torch.sqrt(residuals.clamp(min=1e-9))
|
| 91 |
-
# return torch.cat([mean, std], dim=1)
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
if __name__ == "__main__":
|
| 96 |
-
data = torch.randn(10, 128, 100)
|
| 97 |
-
pooling = Self_Attentive_Pooling(128)
|
| 98 |
-
out = pooling(data)
|
| 99 |
-
print(data.shape)
|
| 100 |
-
print(out.shape)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
nnet/speaker_encoder.py
DELETED
|
@@ -1,47 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
import torchaudio
|
| 3 |
-
import torch.nn as nn
|
| 4 |
-
from torch.nn import functional as F
|
| 5 |
-
from .ResNet34 import Speaker_Encoder
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
class Speaker_Model(torch.nn.Module):
|
| 9 |
-
#class Speaker_Model(LightningModule):
|
| 10 |
-
def __init__(self, pooling_type, spk_embed_dim, sample_rate, n_mels):
|
| 11 |
-
super().__init__()
|
| 12 |
-
# self.save_hyperparameters()
|
| 13 |
-
|
| 14 |
-
self.pooling_type = pooling_type
|
| 15 |
-
self.spk_embed_dim = spk_embed_dim
|
| 16 |
-
self.sample_rate = sample_rate
|
| 17 |
-
self.n_mels = n_mels
|
| 18 |
-
sr = self.sample_rate
|
| 19 |
-
|
| 20 |
-
self.mel_trans = torch.nn.Sequential(
|
| 21 |
-
PreEmphasis(),
|
| 22 |
-
torchaudio.transforms.MelSpectrogram(sample_rate=sr, n_fft=512,
|
| 23 |
-
win_length=sr * 25 // 1000, hop_length=sr * 10 // 1000,
|
| 24 |
-
window_fn=torch.hamming_window, n_mels=self.n_mels)
|
| 25 |
-
)
|
| 26 |
-
self.instancenorm = nn.InstanceNorm1d(self.n_mels)
|
| 27 |
-
|
| 28 |
-
self.hparams = {'embedding_dim': self.spk_embed_dim, 'pooling_type': self.pooling_type , 'n_mels': self.n_mels}
|
| 29 |
-
|
| 30 |
-
self.speaker_encoder = Speaker_Encoder(**dict(self.hparams))
|
| 31 |
-
|
| 32 |
-
class PreEmphasis(torch.nn.Module):
|
| 33 |
-
def __init__(self, coef: float = 0.97):
|
| 34 |
-
super().__init__()
|
| 35 |
-
self.coef = coef
|
| 36 |
-
# make kernel
|
| 37 |
-
# In pytorch, the convolution operation uses cross-correlation. So, filter is flipped.
|
| 38 |
-
self.register_buffer(
|
| 39 |
-
'flipped_filter', torch.FloatTensor([-self.coef, 1.]).unsqueeze(0).unsqueeze(0)
|
| 40 |
-
)
|
| 41 |
-
|
| 42 |
-
def forward(self, inputs: torch.tensor) -> torch.tensor:
|
| 43 |
-
assert len(inputs.size()) == 2, 'The number of dimensions of inputs tensor must be 2!'
|
| 44 |
-
# reflect padding to match lengths of in/out
|
| 45 |
-
inputs = inputs.unsqueeze(1)
|
| 46 |
-
inputs = F.pad(inputs, (1, 0), 'reflect')
|
| 47 |
-
return F.conv1d(inputs, self.flipped_filter).squeeze(1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
noises/00840.wav
ADDED
|
Binary file (963 kB). View file
|
|
|
noises/022928.wav
ADDED
|
Binary file (961 kB). View file
|
|
|
noises/04338.wav
ADDED
|
Binary file (959 kB). View file
|
|
|
noises/046324.wav
ADDED
|
Binary file (961 kB). View file
|
|
|
noises/093004.wav
ADDED
|
Binary file (960 kB). View file
|
|
|
noises/11129.wav
ADDED
|
Binary file (963 kB). View file
|
|
|
noises/133254.wav
ADDED
|
Binary file (960 kB). View file
|
|
|
noises/30100.wav
ADDED
|
Binary file (959 kB). View file
|
|
|
noises/30135.wav
ADDED
|
Binary file (959 kB). View file
|
|
|
noises/30437.wav
ADDED
|
Binary file (963 kB). View file
|
|
|
noises/30603.wav
ADDED
|
Binary file (959 kB). View file
|
|
|
requirements.txt
CHANGED
|
@@ -1,2 +1,7 @@
|
|
| 1 |
-
soundfile
|
| 2 |
gradio
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
soundfile==0.12.1
|
| 2 |
gradio
|
| 3 |
+
hydra-core==1.3.2
|
| 4 |
+
torch==1.11.0
|
| 5 |
+
pyloudnorm
|
| 6 |
+
numpy==1.24.4
|
| 7 |
+
librosa
|
temp_extracted.wav
ADDED
|
Binary file (180 kB). View file
|
|
|
test_mix.wav
ADDED
|
Binary file (180 kB). View file
|
|
|
test_output_mixture.wav
ADDED
|
Binary file (180 kB). View file
|
|
|
utils/__init__.py
DELETED
|
File without changes
|
utils/audio.py
DELETED
|
@@ -1,124 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python
|
| 2 |
-
|
| 3 |
-
import os
|
| 4 |
-
import numpy as np
|
| 5 |
-
import soundfile as sf
|
| 6 |
-
|
| 7 |
-
def write_wav(fname, samps, sample_rate=16000, normalize=True):
|
| 8 |
-
"""
|
| 9 |
-
Write wav files in float32, support single/multi-channel
|
| 10 |
-
"""
|
| 11 |
-
|
| 12 |
-
# wham and whamr mixture and clean data are float 32, can not use scipy.io.wavfile to read and write int16
|
| 13 |
-
# change to soundfile to read and write, although reference speech is int16, soundfile still can read and outputs as float
|
| 14 |
-
fdir = os.path.dirname(fname)
|
| 15 |
-
if fdir and not os.path.exists(fdir):
|
| 16 |
-
os.makedirs(fdir)
|
| 17 |
-
sf.write(fname, samps, sample_rate, subtype='FLOAT')
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
def read_wav(fname, normalize=True, return_rate=False):
|
| 21 |
-
"""
|
| 22 |
-
Read wave files (support multi-channel)
|
| 23 |
-
"""
|
| 24 |
-
|
| 25 |
-
# wham and whamr mixture and clean data are float 32, can not use scipy.io.wavfile to read and write int16
|
| 26 |
-
# change to soundfile to read and write, although reference speech is int16, soundfile still can read and outputs as float
|
| 27 |
-
samps, samp_rate = sf.read(fname)
|
| 28 |
-
if return_rate:
|
| 29 |
-
return samp_rate, samps
|
| 30 |
-
return samps
|
| 31 |
-
|
| 32 |
-
def parse_scripts(scp_path, value_processor=lambda x: x, num_tokens=2):
|
| 33 |
-
"""
|
| 34 |
-
Parse kaldi's script(.scp) file
|
| 35 |
-
If num_tokens >= 2, function will check token number
|
| 36 |
-
"""
|
| 37 |
-
scp_dict = dict()
|
| 38 |
-
line = 0
|
| 39 |
-
with open(scp_path, "r") as f:
|
| 40 |
-
for raw_line in f:
|
| 41 |
-
scp_tokens = raw_line.strip().split()
|
| 42 |
-
line += 1
|
| 43 |
-
if num_tokens >= 2 and len(scp_tokens) != num_tokens or len(
|
| 44 |
-
scp_tokens) < 2:
|
| 45 |
-
raise RuntimeError(
|
| 46 |
-
"For {}, format error in line[{:d}]: {}".format(
|
| 47 |
-
scp_path, line, raw_line))
|
| 48 |
-
if num_tokens == 2:
|
| 49 |
-
key, value = scp_tokens
|
| 50 |
-
else:
|
| 51 |
-
key, value = scp_tokens[0], scp_tokens[1:]
|
| 52 |
-
if key in scp_dict:
|
| 53 |
-
raise ValueError("Duplicated key \'{0}\' exists in {1}".format(
|
| 54 |
-
key, scp_path))
|
| 55 |
-
scp_dict[key] = value_processor(value)
|
| 56 |
-
return scp_dict
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
class Reader(object):
|
| 60 |
-
"""
|
| 61 |
-
Basic Reader Class
|
| 62 |
-
"""
|
| 63 |
-
|
| 64 |
-
def __init__(self, scp_path, value_processor=lambda x: x):
|
| 65 |
-
self.index_dict = parse_scripts(
|
| 66 |
-
scp_path, value_processor=value_processor, num_tokens=2)
|
| 67 |
-
self.index_keys = list(self.index_dict.keys())
|
| 68 |
-
|
| 69 |
-
def _load(self, key):
|
| 70 |
-
# return path
|
| 71 |
-
return self.index_dict[key]
|
| 72 |
-
|
| 73 |
-
# number of utterance
|
| 74 |
-
def __len__(self):
|
| 75 |
-
return len(self.index_dict)
|
| 76 |
-
|
| 77 |
-
# avoid key error
|
| 78 |
-
def __contains__(self, key):
|
| 79 |
-
return key in self.index_dict
|
| 80 |
-
|
| 81 |
-
# sequential index
|
| 82 |
-
def __iter__(self):
|
| 83 |
-
for key in self.index_keys:
|
| 84 |
-
yield key, self._load(key)
|
| 85 |
-
|
| 86 |
-
# random index, support str/int as index
|
| 87 |
-
def __getitem__(self, index):
|
| 88 |
-
if type(index) not in [int, str]:
|
| 89 |
-
raise IndexError("Unsupported index type: {}".format(type(index)))
|
| 90 |
-
if type(index) == int:
|
| 91 |
-
# from int index to key
|
| 92 |
-
num_utts = len(self.index_keys)
|
| 93 |
-
if index >= num_utts or index < 0:
|
| 94 |
-
raise KeyError(
|
| 95 |
-
"Interger index out of range, {:d} vs {:d}".format(
|
| 96 |
-
index, num_utts))
|
| 97 |
-
index = self.index_keys[index]
|
| 98 |
-
if index not in self.index_dict:
|
| 99 |
-
raise KeyError("Missing utterance {}!".format(index))
|
| 100 |
-
return self._load(index)
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
class WaveReader(Reader):
|
| 104 |
-
"""
|
| 105 |
-
Sequential/Random Reader for single channel wave
|
| 106 |
-
Format of wav.scp follows Kaldi's definition:
|
| 107 |
-
key1 /path/to/wav
|
| 108 |
-
...
|
| 109 |
-
"""
|
| 110 |
-
|
| 111 |
-
def __init__(self, wav_scp, sample_rate=None, normalize=True):
|
| 112 |
-
super(WaveReader, self).__init__(wav_scp)
|
| 113 |
-
self.samp_rate = sample_rate
|
| 114 |
-
self.normalize = normalize
|
| 115 |
-
|
| 116 |
-
def _load(self, key):
|
| 117 |
-
# return C x N or N
|
| 118 |
-
samp_rate, samps = read_wav(
|
| 119 |
-
self.index_dict[key], normalize=self.normalize, return_rate=True)
|
| 120 |
-
# if given samp_rate, check it
|
| 121 |
-
if self.samp_rate is not None and samp_rate != self.samp_rate:
|
| 122 |
-
raise RuntimeError("SampleRate mismatch: {:d} vs {:d}".format(
|
| 123 |
-
samp_rate, self.samp_rate))
|
| 124 |
-
return samps
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/dataset copy.py
DELETED
|
@@ -1,284 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python
|
| 2 |
-
|
| 3 |
-
import random
|
| 4 |
-
import torch as th
|
| 5 |
-
import numpy as np
|
| 6 |
-
|
| 7 |
-
from torch.utils.data.dataloader import default_collate
|
| 8 |
-
import torch.utils.data as dat
|
| 9 |
-
from torch.nn.utils.rnn import pad_sequence
|
| 10 |
-
|
| 11 |
-
from .audio import WaveReader
|
| 12 |
-
|
| 13 |
-
import soundfile as sf
|
| 14 |
-
|
| 15 |
-
# random_seed = 1453
|
| 16 |
-
# random.seed(random_seed)
|
| 17 |
-
|
| 18 |
-
def make_dataloader(train=True,
|
| 19 |
-
utt_scp_file=None,
|
| 20 |
-
spk_list=None,
|
| 21 |
-
sample_rate=16000,
|
| 22 |
-
num_workers=4,
|
| 23 |
-
chunk_size=32000,
|
| 24 |
-
batch_size=16):
|
| 25 |
-
dataset = Dataset(utt_scp_file=utt_scp_file,
|
| 26 |
-
spk_list=spk_list,
|
| 27 |
-
chunk_size=chunk_size,
|
| 28 |
-
sample_rate=sample_rate)
|
| 29 |
-
return DataLoader(dataset,
|
| 30 |
-
train=train,
|
| 31 |
-
chunk_size=chunk_size,
|
| 32 |
-
batch_size=batch_size,
|
| 33 |
-
num_workers=num_workers)
|
| 34 |
-
|
| 35 |
-
class Dataset(object):
|
| 36 |
-
"""
|
| 37 |
-
Per Utterance Loader
|
| 38 |
-
"""
|
| 39 |
-
def __init__(self, utt_scp_file="", spk_list=None,chunk_size=32000, sample_rate=8000):
|
| 40 |
-
self.sample_rate = sample_rate
|
| 41 |
-
self.spk_list = self._load_spk(spk_list)
|
| 42 |
-
|
| 43 |
-
self.seg_least= int(chunk_size // 2 )
|
| 44 |
-
|
| 45 |
-
# self.mix = WaveReader(mix_scp, sample_rate=sample_rate)
|
| 46 |
-
# self.ref = WaveReader(ref_scp, sample_rate=sample_rate)
|
| 47 |
-
# self.aux = WaveReader(aux_scp, sample_rate=sample_rate)
|
| 48 |
-
|
| 49 |
-
with open(utt_scp_file, 'r') as f:
|
| 50 |
-
lines = f.readlines()
|
| 51 |
-
self.data = []
|
| 52 |
-
self.total_lines = len(self.data)
|
| 53 |
-
for line in lines:
|
| 54 |
-
parts = line.strip().split()
|
| 55 |
-
sentence_id = parts[0]
|
| 56 |
-
sentence_path = parts[1]
|
| 57 |
-
data_len = parts[2]
|
| 58 |
-
spk_id = (sentence_id.split('-')[0])[1:5]
|
| 59 |
-
self.data.append((sentence_id, spk_id, sentence_path, data_len))
|
| 60 |
-
|
| 61 |
-
if not self.data:
|
| 62 |
-
raise ValueError("No valid lines found in the input file.")
|
| 63 |
-
self.total_lines = len(self.data)
|
| 64 |
-
|
| 65 |
-
def _load_spk(self, spk_list_path):
|
| 66 |
-
if spk_list_path is None:
|
| 67 |
-
return []
|
| 68 |
-
lines = open(spk_list_path).readlines()
|
| 69 |
-
new_lines = []
|
| 70 |
-
for line in lines:
|
| 71 |
-
new_lines.append(line.strip())
|
| 72 |
-
|
| 73 |
-
return new_lines
|
| 74 |
-
|
| 75 |
-
def __len__(self):
|
| 76 |
-
return len(self.data)
|
| 77 |
-
|
| 78 |
-
def _get_segment_start_stop(self, seg_len, length):
|
| 79 |
-
if seg_len is not None:
|
| 80 |
-
start = random.randint(0, length - seg_len)
|
| 81 |
-
stop = start + seg_len
|
| 82 |
-
else:
|
| 83 |
-
start = 0
|
| 84 |
-
stop = None
|
| 85 |
-
return start, stop
|
| 86 |
-
|
| 87 |
-
def _mix(self, sources_list):
|
| 88 |
-
|
| 89 |
-
# if self.seg_len:
|
| 90 |
-
# mix_length = self.seg_len
|
| 91 |
-
|
| 92 |
-
# else:
|
| 93 |
-
# mix_length = self.common_length
|
| 94 |
-
mix_length = self.common_length
|
| 95 |
-
mixture = np.zeros(mix_length)
|
| 96 |
-
for i, _ in enumerate(sources_list):
|
| 97 |
-
mixture += sources_list[i]
|
| 98 |
-
|
| 99 |
-
return mixture
|
| 100 |
-
|
| 101 |
-
def __getitem__(self, idx):
|
| 102 |
-
source_id, source_spk, source_path, all_source_length= self.data[idx]
|
| 103 |
-
all_source_length = int(all_source_length)
|
| 104 |
-
spk_idx = self.spk_list.index(source_spk)
|
| 105 |
-
|
| 106 |
-
other_counter = 0
|
| 107 |
-
while True:
|
| 108 |
-
random_idx = np.random.randint(0, self.total_lines)
|
| 109 |
-
if self.data[random_idx][1] != source_spk:
|
| 110 |
-
other_id, other_spk, other_path, other_length = self.data[random_idx]
|
| 111 |
-
other_length = int(other_length)
|
| 112 |
-
|
| 113 |
-
if other_length > self.seg_least:
|
| 114 |
-
break
|
| 115 |
-
|
| 116 |
-
other_counter += 1
|
| 117 |
-
|
| 118 |
-
if other_counter >= self.total_lines:
|
| 119 |
-
raise ValueError("All Data too shorter to mix")
|
| 120 |
-
|
| 121 |
-
enroll_counter = 0
|
| 122 |
-
|
| 123 |
-
while True:
|
| 124 |
-
random_idx = np.random.randint(0, self.total_lines)
|
| 125 |
-
if self.data[random_idx][1] == source_spk:
|
| 126 |
-
enroll_id, enroll_spk, enroll_path, all_enroll_length= self.data[random_idx]
|
| 127 |
-
all_enroll_length = int(all_enroll_length)
|
| 128 |
-
if all_enroll_length > self.seg_least:
|
| 129 |
-
break
|
| 130 |
-
|
| 131 |
-
enroll_counter += 1
|
| 132 |
-
if enroll_counter >= self.total_lines:
|
| 133 |
-
raise ValueError("All Data too shorter to enroll")
|
| 134 |
-
# lengths = [all_source_length, other_length]
|
| 135 |
-
|
| 136 |
-
if all_source_length >= other_length:
|
| 137 |
-
self.common_length = other_length
|
| 138 |
-
start, stop = self._get_segment_start_stop(other_length, all_source_length)
|
| 139 |
-
source_tmp,_ = sf.read(source_path, dtype="float32", start=start, stop=stop)
|
| 140 |
-
other_tmp,_ = sf.read(other_path, dtype="float32")
|
| 141 |
-
elif all_source_length <= other_length:
|
| 142 |
-
self.common_length = all_source_length
|
| 143 |
-
start, stop = self._get_segment_start_stop(all_source_length, other_length)
|
| 144 |
-
source_tmp,_ = sf.read(source_path, dtype="float32")
|
| 145 |
-
other_tmp,_ = sf.read(other_path, dtype="float32", start=start, stop=stop)
|
| 146 |
-
|
| 147 |
-
source = source_tmp[:, np.random.randint(0, source_tmp.shape[1])]
|
| 148 |
-
|
| 149 |
-
other = other_tmp[:, np.random.randint(0, other_tmp.shape[1])]
|
| 150 |
-
|
| 151 |
-
mixture = self._mix([source, other])
|
| 152 |
-
mixture = mixture.astype(np.float32)
|
| 153 |
-
|
| 154 |
-
enroll_tmp, _ = sf.read(enroll_path, dtype="float32")
|
| 155 |
-
enroll = enroll_tmp[:, np.random.randint(0, enroll_tmp.shape[1])]
|
| 156 |
-
|
| 157 |
-
return {
|
| 158 |
-
"mix": mixture,
|
| 159 |
-
"ref": source,
|
| 160 |
-
"aux": enroll,
|
| 161 |
-
"aux_len": len(enroll),
|
| 162 |
-
"spk_idx": spk_idx
|
| 163 |
-
}
|
| 164 |
-
|
| 165 |
-
class ChunkSplitter(object):
|
| 166 |
-
"""
|
| 167 |
-
Split utterance into small chunks
|
| 168 |
-
"""
|
| 169 |
-
def __init__(self, chunk_size, train=True, least=16000):
|
| 170 |
-
self.chunk_size = chunk_size
|
| 171 |
-
self.least = least
|
| 172 |
-
self.train = train
|
| 173 |
-
|
| 174 |
-
def _make_chunk(self, eg, s):
|
| 175 |
-
"""
|
| 176 |
-
Make a chunk instance, which contains:
|
| 177 |
-
"mix": ndarray,
|
| 178 |
-
"ref": [ndarray...]
|
| 179 |
-
"""
|
| 180 |
-
chunk = dict()
|
| 181 |
-
chunk["mix"] = eg["mix"][s:s + self.chunk_size]
|
| 182 |
-
chunk["ref"] = eg["ref"][s:s + self.chunk_size]
|
| 183 |
-
chunk["aux"] = eg["aux"]
|
| 184 |
-
chunk["aux_len"] = eg["aux_len"]
|
| 185 |
-
chunk["valid_len"] = int(self.chunk_size)
|
| 186 |
-
chunk["spk_idx"] = eg["spk_idx"]
|
| 187 |
-
return chunk
|
| 188 |
-
|
| 189 |
-
def split(self, eg):
|
| 190 |
-
N = eg["mix"].size
|
| 191 |
-
# too short, throw away
|
| 192 |
-
if N < self.least:
|
| 193 |
-
return []
|
| 194 |
-
chunks = []
|
| 195 |
-
# padding zeros
|
| 196 |
-
if N < self.chunk_size:
|
| 197 |
-
P = self.chunk_size - N
|
| 198 |
-
chunk = dict()
|
| 199 |
-
chunk["mix"] = np.pad(eg["mix"], (0, P), "constant")
|
| 200 |
-
chunk["ref"] = np.pad(eg["ref"], (0, P), "constant")
|
| 201 |
-
chunk["aux"] = eg["aux"]
|
| 202 |
-
chunk["aux_len"] = eg["aux_len"]
|
| 203 |
-
chunk["valid_len"] = int(N)
|
| 204 |
-
chunk["spk_idx"] = eg["spk_idx"]
|
| 205 |
-
chunks.append(chunk)
|
| 206 |
-
else:
|
| 207 |
-
# random select start point for training
|
| 208 |
-
s = random.randint(0, N % self.least) if self.train else 0
|
| 209 |
-
while True:
|
| 210 |
-
if s + self.chunk_size > N:
|
| 211 |
-
break
|
| 212 |
-
chunk = self._make_chunk(eg, s)
|
| 213 |
-
chunks.append(chunk)
|
| 214 |
-
s += self.least
|
| 215 |
-
return chunks
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
class DataLoader(object):
|
| 219 |
-
"""
|
| 220 |
-
Online dataloader for chunk-level
|
| 221 |
-
"""
|
| 222 |
-
def __init__(self,
|
| 223 |
-
dataset,
|
| 224 |
-
num_workers=4,
|
| 225 |
-
chunk_size=32000,
|
| 226 |
-
batch_size=16,
|
| 227 |
-
train=True):
|
| 228 |
-
self.batch_size = batch_size
|
| 229 |
-
self.train = train
|
| 230 |
-
self.splitter = ChunkSplitter(chunk_size,
|
| 231 |
-
train=train,
|
| 232 |
-
least=chunk_size // 2)
|
| 233 |
-
# just return batch of egs, support multiple workers
|
| 234 |
-
self.eg_loader = dat.DataLoader(dataset,
|
| 235 |
-
batch_size=batch_size // 2,
|
| 236 |
-
num_workers=num_workers,
|
| 237 |
-
shuffle=train,
|
| 238 |
-
collate_fn=self._collate)
|
| 239 |
-
|
| 240 |
-
def _collate(self, batch):
|
| 241 |
-
"""
|
| 242 |
-
Online split utterances
|
| 243 |
-
"""
|
| 244 |
-
chunk = []
|
| 245 |
-
for eg in batch:
|
| 246 |
-
chunk += self.splitter.split(eg)
|
| 247 |
-
return chunk
|
| 248 |
-
|
| 249 |
-
def _pad_aux(self, chunk_list):
|
| 250 |
-
lens_list = []
|
| 251 |
-
for chunk_item in chunk_list:
|
| 252 |
-
lens_list.append(chunk_item['aux_len'])
|
| 253 |
-
max_len = np.max(lens_list)
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
for idx in range(len(chunk_list)):
|
| 257 |
-
P = max_len - len(chunk_list[idx]["aux"])
|
| 258 |
-
chunk_list[idx]["aux"] = np.pad(chunk_list[idx]["aux"], (0, P), "constant")
|
| 259 |
-
|
| 260 |
-
return chunk_list
|
| 261 |
-
|
| 262 |
-
def _merge(self, chunk_list):
|
| 263 |
-
"""
|
| 264 |
-
Merge chunk list into mini-batch
|
| 265 |
-
"""
|
| 266 |
-
N = len(chunk_list)
|
| 267 |
-
if self.train:
|
| 268 |
-
random.shuffle(chunk_list)
|
| 269 |
-
blist = []
|
| 270 |
-
for s in range(0, N - self.batch_size + 1, self.batch_size):
|
| 271 |
-
# padding aux info
|
| 272 |
-
#self._pad_aux(chunk_list[s:s + self.batch_size])
|
| 273 |
-
batch = default_collate(self._pad_aux(chunk_list[s:s + self.batch_size]))
|
| 274 |
-
blist.append(batch)
|
| 275 |
-
rn = N % self.batch_size
|
| 276 |
-
return blist, chunk_list[-rn:] if rn else []
|
| 277 |
-
|
| 278 |
-
def __iter__(self):
|
| 279 |
-
chunk_list = []
|
| 280 |
-
for chunks in self.eg_loader:
|
| 281 |
-
chunk_list += chunks
|
| 282 |
-
batch, chunk_list = self._merge(chunk_list)
|
| 283 |
-
for obj in batch:
|
| 284 |
-
yield obj
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/dataset.py
DELETED
|
@@ -1,402 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python
|
| 2 |
-
|
| 3 |
-
import random
|
| 4 |
-
import torch as th
|
| 5 |
-
import numpy as np
|
| 6 |
-
|
| 7 |
-
from torch.utils.data.dataloader import default_collate
|
| 8 |
-
import torch.utils.data as dat
|
| 9 |
-
from torch.nn.utils.rnn import pad_sequence
|
| 10 |
-
|
| 11 |
-
from .audio import WaveReader
|
| 12 |
-
|
| 13 |
-
import soundfile as sf
|
| 14 |
-
|
| 15 |
-
# random_seed = 1453
|
| 16 |
-
# random.seed(random_seed)
|
| 17 |
-
|
| 18 |
-
# "aux_len": all_enroll_length,
|
| 19 |
-
|
| 20 |
-
EPS = 1e-10
|
| 21 |
-
def make_dataloader(train=True,
|
| 22 |
-
mix_scp_file=None,
|
| 23 |
-
enroll_scp_file=None,
|
| 24 |
-
noise_scp_file=None,
|
| 25 |
-
spk_list=None,
|
| 26 |
-
sample_rate=16000,
|
| 27 |
-
num_workers=4,
|
| 28 |
-
chunk_size=32000,
|
| 29 |
-
batch_size=16):
|
| 30 |
-
dataset = Dataset(mix_scp_file=mix_scp_file,
|
| 31 |
-
enroll_scp_file=enroll_scp_file,
|
| 32 |
-
noise_scp_file=noise_scp_file,
|
| 33 |
-
spk_list=spk_list,
|
| 34 |
-
chunk_size=chunk_size,
|
| 35 |
-
sample_rate=sample_rate)
|
| 36 |
-
return DataLoader(dataset,
|
| 37 |
-
train=train,
|
| 38 |
-
chunk_size=chunk_size,
|
| 39 |
-
batch_size=batch_size,
|
| 40 |
-
num_workers=num_workers)
|
| 41 |
-
|
| 42 |
-
class Dataset(object):
|
| 43 |
-
"""
|
| 44 |
-
Per Utterance Loader
|
| 45 |
-
"""
|
| 46 |
-
def __init__(self, mix_scp_file="", enroll_scp_file="", noise_scp_file="", spk_list=None,chunk_size=32000, sample_rate=8000):
|
| 47 |
-
self.sample_rate = sample_rate
|
| 48 |
-
self.spk_list = self._load_spk(spk_list)
|
| 49 |
-
|
| 50 |
-
self.seg_least= int(chunk_size // 2 )
|
| 51 |
-
|
| 52 |
-
with open(mix_scp_file, 'r') as f:
|
| 53 |
-
lines = f.readlines()
|
| 54 |
-
self.data = []
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
for line in lines:
|
| 58 |
-
parts = line.strip().split()
|
| 59 |
-
sentence_id = parts[0]
|
| 60 |
-
sentence_path = parts[1]
|
| 61 |
-
data_len = parts[2]
|
| 62 |
-
spk_id = (sentence_id.split('-')[0])[1:5]
|
| 63 |
-
self.data.append((sentence_id, spk_id, sentence_path, data_len))
|
| 64 |
-
|
| 65 |
-
with open(enroll_scp_file, 'r') as f:
|
| 66 |
-
enroll_lines = f.readlines()
|
| 67 |
-
self.enroll_data = []
|
| 68 |
-
|
| 69 |
-
for line in enroll_lines:
|
| 70 |
-
parts = line.strip().split()
|
| 71 |
-
sentence_id = parts[0]
|
| 72 |
-
sentence_path = parts[1]
|
| 73 |
-
data_len = parts[2]
|
| 74 |
-
spk_id = (sentence_id.split('-')[0])[1:5]
|
| 75 |
-
self.enroll_data.append((sentence_id, spk_id, sentence_path, data_len))
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
with open(noise_scp_file, 'r') as f:
|
| 79 |
-
noise_lines = f.readlines()
|
| 80 |
-
self.noise_data = []
|
| 81 |
-
|
| 82 |
-
for line in noise_lines:
|
| 83 |
-
parts = line.strip().split()
|
| 84 |
-
sentence_id = parts[0]
|
| 85 |
-
sentence_path = parts[1]
|
| 86 |
-
data_len = parts[2]
|
| 87 |
-
# spk_id = (sentence_id.split('-')[0])[1:5]
|
| 88 |
-
self.noise_data.append((sentence_id, sentence_path, data_len))
|
| 89 |
-
|
| 90 |
-
self.total_lines = len(self.data)
|
| 91 |
-
self.total_enroll = self._enroll_data_len()
|
| 92 |
-
self.total_noise = self._noise_data_len()
|
| 93 |
-
|
| 94 |
-
if not self.data:
|
| 95 |
-
raise ValueError("No valid lines found in the input file.")
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
def _load_spk(self, spk_list_path):
|
| 99 |
-
if spk_list_path is None:
|
| 100 |
-
return []
|
| 101 |
-
lines = open(spk_list_path).readlines()
|
| 102 |
-
new_lines = []
|
| 103 |
-
for line in lines:
|
| 104 |
-
new_lines.append(line.strip())
|
| 105 |
-
|
| 106 |
-
return new_lines
|
| 107 |
-
|
| 108 |
-
def __len__(self):
|
| 109 |
-
return len(self.data)
|
| 110 |
-
|
| 111 |
-
def _enroll_data_len(self):
|
| 112 |
-
return len(self.enroll_data)
|
| 113 |
-
|
| 114 |
-
def _noise_data_len(self):
|
| 115 |
-
return len(self.noise_data)
|
| 116 |
-
|
| 117 |
-
def _get_segment_start_stop(self, seg_len, length):
|
| 118 |
-
if seg_len is not None:
|
| 119 |
-
start = random.randint(0, length - seg_len)
|
| 120 |
-
stop = start + seg_len
|
| 121 |
-
else:
|
| 122 |
-
start = 0
|
| 123 |
-
stop = None
|
| 124 |
-
return start, stop
|
| 125 |
-
|
| 126 |
-
def _mix(self, sources_list):
|
| 127 |
-
|
| 128 |
-
# if self.seg_len:
|
| 129 |
-
# mix_length = self.seg_len
|
| 130 |
-
|
| 131 |
-
# else:
|
| 132 |
-
# mix_length = self.common_length
|
| 133 |
-
mix_length = self.common_length
|
| 134 |
-
mixture = np.zeros(mix_length)
|
| 135 |
-
for i, _ in enumerate(sources_list):
|
| 136 |
-
mixture += sources_list[i]
|
| 137 |
-
|
| 138 |
-
return mixture
|
| 139 |
-
|
| 140 |
-
def __getitem__(self, idx):
|
| 141 |
-
source_id, source_spk, source_path, all_source_length= self.data[idx]
|
| 142 |
-
all_source_length = int(all_source_length)
|
| 143 |
-
spk_idx = self.spk_list.index(source_spk)
|
| 144 |
-
|
| 145 |
-
other_counter = 0
|
| 146 |
-
while True:
|
| 147 |
-
random_idx = np.random.randint(0, self.total_lines)
|
| 148 |
-
if self.data[random_idx][1] != source_spk:
|
| 149 |
-
other_id, other_spk, other_path, other_length = self.data[random_idx]
|
| 150 |
-
other_length = int(other_length)
|
| 151 |
-
|
| 152 |
-
if other_length > self.seg_least:
|
| 153 |
-
break
|
| 154 |
-
|
| 155 |
-
other_counter += 1
|
| 156 |
-
|
| 157 |
-
if other_counter >= self.total_lines:
|
| 158 |
-
raise ValueError("All Data too shorter to mix")
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
if all_source_length >= other_length:
|
| 162 |
-
self.common_length = other_length
|
| 163 |
-
start, stop = self._get_segment_start_stop(self.common_length, all_source_length)
|
| 164 |
-
source_tmp,_ = sf.read(source_path, dtype="float32", start=start, stop=stop)
|
| 165 |
-
other_tmp,_ = sf.read(other_path, dtype="float32")
|
| 166 |
-
elif all_source_length <= other_length:
|
| 167 |
-
self.common_length = all_source_length
|
| 168 |
-
start, stop = self._get_segment_start_stop(self.common_length, other_length)
|
| 169 |
-
source_tmp,_ = sf.read(source_path, dtype="float32")
|
| 170 |
-
other_tmp,_ = sf.read(other_path, dtype="float32", start=start, stop=stop)
|
| 171 |
-
|
| 172 |
-
noise_counter = 0
|
| 173 |
-
while True:
|
| 174 |
-
random_idx = np.random.randint(0, self.total_noise)
|
| 175 |
-
|
| 176 |
-
noise_id, noise_path, all_noise_length= self.noise_data[random_idx]
|
| 177 |
-
all_noise_length = int(all_noise_length)
|
| 178 |
-
|
| 179 |
-
if all_noise_length >= self.common_length:
|
| 180 |
-
break
|
| 181 |
-
noise_counter += 1
|
| 182 |
-
if noise_counter >= self.total_noise:
|
| 183 |
-
raise ValueError("All Data can't as noise")
|
| 184 |
-
|
| 185 |
-
enroll_counter = 0
|
| 186 |
-
while True:
|
| 187 |
-
random_idx = np.random.randint(0, self.total_enroll)
|
| 188 |
-
if self.enroll_data[random_idx][1] == source_spk:
|
| 189 |
-
enroll_id, enroll_spk, enroll_path, all_enroll_length= self.enroll_data[random_idx]
|
| 190 |
-
all_enroll_length = int(all_enroll_length)
|
| 191 |
-
break
|
| 192 |
-
|
| 193 |
-
enroll_counter += 1
|
| 194 |
-
if enroll_counter >= self.total_enroll:
|
| 195 |
-
raise ValueError("All Data can't as enroll")
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
source = source_tmp[:, np.random.randint(0, source_tmp.shape[1])]
|
| 201 |
-
other = other_tmp[:, np.random.randint(0, other_tmp.shape[1])]
|
| 202 |
-
|
| 203 |
-
noise_start, noise_stop = self._get_segment_start_stop(self.common_length, all_noise_length)
|
| 204 |
-
noise,_ = sf.read(noise_path, dtype="float32", start=noise_start, stop=noise_stop) # single channel?
|
| 205 |
-
# noise = noise_tmp[:, np.random.randint(0, noise_tmp.shape[1])]
|
| 206 |
-
# other_noise = self._mix([other,noise])
|
| 207 |
-
desired_snr = np.random.uniform(-4, 4) # 设置目标 SNR
|
| 208 |
-
current_snr = 10 * np.log10(np.mean(source ** 2) / (np.mean(noise ** 2) + EPS) + EPS)
|
| 209 |
-
scale_factor = 10 ** ((current_snr - desired_snr ) / 20)
|
| 210 |
-
scaled_noise = noise * scale_factor
|
| 211 |
-
|
| 212 |
-
snr = 10 * np.log10(np.mean(source ** 2) / (np.mean(scaled_noise ** 2) + EPS) + EPS)
|
| 213 |
-
mixture = self._mix([source,other,scaled_noise])
|
| 214 |
-
|
| 215 |
-
mixture = mixture.astype(np.float32)
|
| 216 |
-
|
| 217 |
-
enroll_tmp, _ = sf.read(enroll_path, dtype="float32")
|
| 218 |
-
enroll = enroll_tmp[:, np.random.randint(0, enroll_tmp.shape[1])]
|
| 219 |
-
|
| 220 |
-
return {
|
| 221 |
-
"mix": mixture,
|
| 222 |
-
"ref": source,
|
| 223 |
-
"aux": enroll,
|
| 224 |
-
"aux_len": all_enroll_length,
|
| 225 |
-
"spk_idx": spk_idx
|
| 226 |
-
}
|
| 227 |
-
|
| 228 |
-
class ChunkSplitter(object):
    """
    Cut an utterance example into fixed-size chunks.

    Training mode emits one randomly-positioned chunk per utterance;
    evaluation mode slides over the utterance with stride ``least``.
    Utterances shorter than ``least`` are discarded; utterances shorter
    than ``chunk_size`` are zero-padded up to ``chunk_size``.
    """

    def __init__(self, chunk_size, train=True, least=16000):
        # chunk_size: samples per emitted chunk
        # least: minimum usable utterance length (also the eval stride)
        self.chunk_size = chunk_size
        self.least = least
        self.train = train

    def _make_chunk(self, eg, s):
        """
        Build one chunk dict starting at sample offset ``s``:
        "mix"/"ref" are sliced, "aux"/"aux_len"/"spk_idx" pass through,
        "valid_len" records the number of real (non-padded) samples.
        """
        return {
            "mix": eg["mix"][s:s + self.chunk_size],
            "ref": eg["ref"][s:s + self.chunk_size],
            "aux": eg["aux"],
            "aux_len": eg["aux_len"],
            "valid_len": int(self.chunk_size),
            "spk_idx": eg["spk_idx"],
        }

    def split(self, eg):
        """Return a (possibly empty) list of chunk dicts for example ``eg``."""
        total = eg["mix"].size
        # too short to be useful: drop the utterance entirely
        if total < self.least:
            return []
        # shorter than one chunk: zero-pad mix/ref up to chunk_size
        if total < self.chunk_size:
            pad = self.chunk_size - total
            return [{
                "mix": np.pad(eg["mix"], (0, pad), "constant"),
                "ref": np.pad(eg["ref"], (0, pad), "constant"),
                "aux": eg["aux"],
                "aux_len": eg["aux_len"],
                "valid_len": int(total),
                "spk_idx": eg["spk_idx"],
            }]
        if self.train:
            # one chunk at a random start point per utterance
            start = random.randint(0, total - self.chunk_size)
            return [self._make_chunk(eg, start)]
        # evaluation: deterministic sliding window with stride `least`
        chunks = []
        start = 0
        while start + self.chunk_size <= total:
            chunks.append(self._make_chunk(eg, start))
            start += self.least
        return chunks
|
| 295 |
-
|
| 296 |
-
class DataLoader(object):
    """
    Online chunk-level dataloader.

    Wraps a torch DataLoader that yields raw utterance examples, splits
    each example into fixed-size chunks on the fly (ChunkSplitter), and
    regroups the resulting chunks into mini-batches of ``batch_size``.
    """

    def __init__(self,
                 dataset,
                 num_workers=4,
                 chunk_size=32000,
                 batch_size=16,
                 train=True):
        self.batch_size = batch_size
        self.train = train
        self.splitter = ChunkSplitter(chunk_size,
                                      train=train,
                                      least=chunk_size // 2)
        # The inner loader only groups raw examples; the real batching
        # happens after chunk splitting — hence batch_size // 2 here.
        self.eg_loader = dat.DataLoader(dataset,
                                        batch_size=batch_size // 2,
                                        num_workers=num_workers,
                                        shuffle=train,
                                        collate_fn=self._collate)

    def _collate(self, batch):
        """Split each raw example into chunks and flatten into one list."""
        flattened = []
        for sample in batch:
            flattened.extend(self.splitter.split(sample))
        return flattened

    def _pad_aux(self, chunk_list):
        """
        Zero-pad every "aux" signal in ``chunk_list`` in place so they all
        share one length, then return the (mutated) list.
        """
        # NOTE(review): the target length comes from the "aux_len" metadata,
        # not from len(chunk["aux"]) itself — assumes the two agree for every
        # chunk (np.pad raises on a negative pad width); verify upstream.
        longest = np.max([item["aux_len"] for item in chunk_list])
        for item in chunk_list:
            missing = longest - len(item["aux"])
            item["aux"] = np.pad(item["aux"], (0, missing), "constant")
        return chunk_list

    def _merge(self, chunk_list):
        """
        Collate as many full batches as possible out of ``chunk_list``.

        Returns (batches, leftover): ``batches`` is a list of collated
        mini-batches, ``leftover`` the < batch_size tail to carry over.
        """
        total = len(chunk_list)
        if self.train:
            random.shuffle(chunk_list)
        batches = []
        for start in range(0, total - self.batch_size + 1, self.batch_size):
            group = self._pad_aux(chunk_list[start:start + self.batch_size])
            batches.append(default_collate(group))
        remainder = total % self.batch_size
        return batches, chunk_list[-remainder:] if remainder else []

    def __iter__(self):
        # Accumulate chunks across raw examples; leftover chunks roll over
        # into the next round (and are dropped at the end of the epoch).
        pending = []
        for chunks in self.eg_loader:
            pending += chunks
            batches, pending = self._merge(pending)
            for batch in batches:
                yield batch
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
# def snr_xy(x, y):
|
| 374 |
-
# return 10 * np.log10(np.mean(x ** 2) / (np.mean(y ** 2) + EPS) + EPS)
|
| 375 |
-
|
| 376 |
-
# def main(args):
|
| 377 |
-
# wham_noise_dir = args.wham_dir
|
| 378 |
-
# # Get train dir
|
| 379 |
-
# subdir = os.path.join(wham_noise_dir, 'tr')
|
| 380 |
-
# # List files in that dir
|
| 381 |
-
# sound_paths = glob.glob(os.path.join(subdir, '**/*.wav'),
|
| 382 |
-
# recursive=True)
|
| 383 |
-
# # Avoid running this script if it already have been run
|
| 384 |
-
# if len(sound_paths) == 60000:
|
| 385 |
-
# print("It appears that augmented files have already been generated.\n"
|
| 386 |
-
# "Skipping data augmentation.")
|
| 387 |
-
# return
|
| 388 |
-
# elif len(sound_paths) != 20000:
|
| 389 |
-
# print("It appears that augmented files have not been generated properly\n"
|
| 390 |
-
# "Resuming augmentation.")
|
| 391 |
-
# originals = [x for x in sound_paths if 'sp' not in x]
|
| 392 |
-
# to_be_removed_08 = [x.replace('sp08','') for x in sound_paths if 'sp08' in x]
|
| 393 |
-
# to_be_removed_12 = [x.replace('sp12','') for x in sound_paths if 'sp12' in x ]
|
| 394 |
-
# sound_paths_08 = list(set(originals) - set(to_be_removed_08))
|
| 395 |
-
# sound_paths_12 = list(set(originals) - set(to_be_removed_12))
|
| 396 |
-
# augment_noise(sound_paths_08, 0.8)
|
| 397 |
-
# augment_noise(sound_paths_12, 1.2)
|
| 398 |
-
# else:
|
| 399 |
-
# print(f'Augmenting {subdir} files')
|
| 400 |
-
# # Transform audio speed
|
| 401 |
-
# augment_noise(sound_paths, 0.8)
|
| 402 |
-
# augment_noise(sound_paths, 1.2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/load_obj.py
DELETED
|
@@ -1,18 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python
|
| 2 |
-
|
| 3 |
-
import torch as th
|
| 4 |
-
|
| 5 |
-
def load_obj(obj, device):
    """
    Recursively offload every torch.Tensor inside ``obj`` to ``device``.

    Args:
        obj: a tensor, or an arbitrarily nested dict/list/tuple of tensors;
            non-tensor leaves are returned unchanged.
        device: target device accepted by ``Tensor.to`` (e.g. "cpu", "cuda:0").

    Returns:
        The same structure with every tensor moved to ``device``.
    """

    def cuda(obj):
        # non-tensor leaves pass through untouched
        return obj.to(device) if isinstance(obj, th.Tensor) else obj

    if isinstance(obj, dict):
        return {key: load_obj(obj[key], device) for key in obj}
    elif isinstance(obj, (list, tuple)):
        # fix: also recurse into tuples (previously tensors nested in a
        # tuple were silently left on their original device);
        # the container type is preserved
        return type(obj)(load_obj(val, device) for val in obj)
    else:
        return cuda(obj)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/logger.py
DELETED
|
@@ -1,22 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python
|
| 2 |
-
|
| 3 |
-
import logging
|
| 4 |
-
|
| 5 |
-
def get_logger(
        name,
        format_str="%(asctime)s [%(pathname)s:%(lineno)s - %(levelname)s ] %(message)s",
        date_format="%Y-%m-%d %H:%M:%S",
        file=False):
    """
    Get a configured python logger instance.

    Args:
        name: logger name; when ``file=True`` it is also used as the
            log file path (existing behavior, kept as-is).
        format_str: log record format passed to logging.Formatter.
        date_format: timestamp format passed to logging.Formatter.
        file: log to a file named ``name`` instead of the console.

    Returns:
        logging.Logger at INFO level with exactly one handler attached.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    # fix: logging.getLogger returns the same object for the same name, so
    # the original code attached one more handler per call, duplicating
    # every log line; only attach a handler on first configuration
    if not logger.handlers:
        # file or console
        handler = logging.FileHandler(name) if file else logging.StreamHandler()
        handler.setLevel(logging.INFO)
        formatter = logging.Formatter(fmt=format_str, datefmt=date_format)
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/sisdr.py
DELETED
|
@@ -1,23 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python
|
| 2 |
-
|
| 3 |
-
import numpy as np
|
| 4 |
-
|
| 5 |
-
def sisdr(x, s, remove_dc=True):
    """
    Compute the scale-invariant SDR (in dB).

    x: extracted/estimated signal
    s: reference signal (ground truth)
    remove_dc: subtract the mean from both signals before projecting
    """
    if remove_dc:
        x = x - np.mean(x)
        s = s - np.mean(s)
    # project x onto the reference; the residual is the distortion term
    target = np.inner(x, s) * s / np.linalg.norm(s, 2) ** 2
    residual = x - target
    return 20 * np.log10(np.linalg.norm(target, 2) / np.linalg.norm(residual, 2))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/timer.py
DELETED
|
@@ -1,17 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python
|
| 2 |
-
|
| 3 |
-
import time
|
| 4 |
-
|
| 5 |
-
class Timer(object):
    """
    A timer to record elapsed wall-clock time, reported in minutes.
    """

    def __init__(self):
        # start counting immediately on construction
        self.reset()

    def reset(self):
        """Restart the timer from the current moment."""
        self.start = time.time()

    def elapsed(self):
        """Return the number of minutes since the last reset()."""
        now = time.time()
        return (now - self.start) / 60
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|