swc2 commited on
Commit
7eddfc5
·
1 Parent(s): 6d6a788

update change 2

Browse files
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ ckpt/
app.py CHANGED
@@ -1,55 +1,85 @@
1
  import gradio as gr
2
- #from inference import InferencePipeline
 
 
3
 
4
- #i = InferencePipeline()
5
- # device = "cuda" if torch.cuda.is_available() else "cpu"
 
6
 
7
- # def convert_audio_to_wav(file_path):
8
- # """Convert any supported format (mp3, etc.) to wav using librosa"""
9
- # output_path = "temp_input.wav"
10
- # audio, sr = librosa.load(file_path, sr=None) # 加载音频文件
11
- # librosa.output.write_wav(output_path, audio, sr) # 转换并保存为 WAV 格式
12
- # return output_path
13
 
14
- def gradio_TSE(audio_file_path):
 
 
 
 
 
 
 
 
 
 
15
  """
16
- Wrapper function to handle Gradio's audio input and pass the file path to the voice conversion function.
17
- Gradio passes audio data as a tuple: (temp file path, sample rate).
18
  """
19
- # Gradio passes audio as (temp file path, sample rate)
20
- #audio_file_path = audio_data[0] # Extract the file path
21
- print(f"Here is the audio_file_path: {audio_file_path}")
22
- #print(f"Here is the audio_file_path[0]: {audio_file_path[0]}")
23
- random_wav = f"/path/to/generated_audio_{int(time.time())}.wav"
24
- #return i.voice_conversion(audio_file_path)
25
- return random_wav
26
-
27
- # Define your Gradio interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  with gr.Blocks() as demo:
29
  gr.Markdown("## Target Speaker Extraction Demo")
30
  gr.Markdown(
31
- "This demo isolates the speech signal of a target speaker from a mixture of multiple speakers, "
32
- "with or without noises and reverberations."
33
  )
34
-
35
- # input
36
  with gr.Row():
37
- input_audio = gr.Audio(label="Upload or record your clean audio", type="filepath")
38
- enroll_audio = gr.Audio(label="Upload your enroll (target speaker) audio", type="filepath")
 
 
 
 
 
 
 
 
 
39
 
40
- # output
41
  with gr.Row():
 
42
  noisy_audio_output = gr.Audio(label="Noisy Audio (Processed input audio)", type="filepath")
43
  extracted_audio_output = gr.Audio(label="Extracted target speaker audio", type="filepath")
44
 
45
- # deal
46
  convert_button = gr.Button("Extract")
47
-
48
- # event
49
  convert_button.click(
50
  fn=gradio_TSE,
51
- inputs=[input_audio, enroll_audio],
52
- outputs=[noisy_audio_output, extracted_audio_output]
53
  )
54
 
55
  if __name__ == "__main__":
 
1
  import gradio as gr
2
+ import os
3
+ import soundfile as sf
4
+ import numpy as np
5
 
6
+ # 这是你现有的推理管线
7
+ from decode import InferencePipeline
8
+ from datahandler import AudioMixer, fix_audio_format
9
 
 
 
 
 
 
 
10
 
11
+
12
+ #####################################
13
+ # 这是你的推理 pipeline
14
+ #####################################
15
+ inter = InferencePipeline()
16
+ datamix = AudioMixer()
17
+
18
+ #####################################
19
+ # 这是供 Gradio 点击时调用的函数
20
+ #####################################
21
+ def gradio_TSE(input_audio_path, enroll_audio_path, audio_type):
22
  """
23
+ 如果 audio_type "clean",就调用 data_handler(此处是 produce_mixture_from_clean)
24
+ 把输入先变成 mix;如果是 "mix",则直接使用原始文件。
25
  """
26
+ print(f"User uploaded audio path: {input_audio_path}")
27
+ print(f"User enroll audio path: {enroll_audio_path}")
28
+ print(f"User chose audio_type: {audio_type}")
29
+
30
+ if audio_type == "clean":
31
+ # 先把 clean 转成 mix
32
+ mix_path = datamix.produce_mixture_from_clean(input_audio_path)
33
+ print(f"Converted clean -> mix: {mix_path}")
34
+ else:
35
+ # 如果是已经是混合音频,直接用它
36
+ mix_path = input_audio_path
37
+
38
+ input_wav = fix_audio_format(mix_path)
39
+ mix_wav = "mix.wav"
40
+ sf.write(mix_wav, input_wav, 16000)
41
+ enroll_wav = fix_audio_format(enroll_audio_path)
42
+ eol_wav = "eol.wav"
43
+ sf.write(eol_wav, enroll_wav, 16000)
44
+
45
+ est_path = inter.computer(mix_wav, eol_wav)
46
+ # 接下来走你的推理流程
47
+ return mix_path,est_path
48
+
49
+
50
+ #####################################
51
+ # 搭建 Gradio 界面
52
+ #####################################
53
  with gr.Blocks() as demo:
54
  gr.Markdown("## Target Speaker Extraction Demo")
55
  gr.Markdown(
56
+ "This demo can handle either clean audio (which we'll turn into a mix) or directly a mix audio."
 
57
  )
58
+
 
59
  with gr.Row():
60
+ # 上传或录制的“待处理”音频
61
+ input_audio = gr.Audio(label="Upload/record your audio", type="filepath")
62
+ # 让用户手动指定音频类型
63
+ audio_type = gr.Radio(
64
+ choices=["clean", "mix"],
65
+ value="clean",
66
+ label="Input audio type?"
67
+ )
68
+ with gr.Row():
69
+ # enroll 音频
70
+ enroll_audio = gr.Audio(label="Upload your enroll audio", type="filepath")
71
 
 
72
  with gr.Row():
73
+ # 输出:处理后的 noisy 和 提取的目标说话人
74
  noisy_audio_output = gr.Audio(label="Noisy Audio (Processed input audio)", type="filepath")
75
  extracted_audio_output = gr.Audio(label="Extracted target speaker audio", type="filepath")
76
 
77
+ # 点击按钮触发
78
  convert_button = gr.Button("Extract")
 
 
79
  convert_button.click(
80
  fn=gradio_TSE,
81
+ inputs=[input_audio, enroll_audio, audio_type],
82
+ outputs=[noisy_audio_output, extracted_audio_output]
83
  )
84
 
85
  if __name__ == "__main__":
config/config.yaml ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ model:
3
+ _target_: model.spex_plus.SpEx_Plus # str, model class name
4
+ L1: 40
5
+ L2: 160
6
+ L3: 320
7
+ N: 256
8
+ B: 8
9
+ O: 256
10
+ P: 512
11
+ Q: 3
12
+ num_spks: 1410 # with speed perturbation 470 -> 1410
13
+ spk_embed_dim: 256
14
+ causal: false
15
+ is_innorm: true
16
+ fusion_type: 'cat' #cat mul film att
17
+
18
+
19
+ test:
20
+ checkpoint: "./ckpt/v2.0.pt.tar"
21
+ gpu: -1
22
+ sample_rate: 16000
23
+
datahandler.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import warnings
4
+ import numpy as np
5
+ import soundfile as sf
6
+ import pyloudnorm
7
+ import glob
8
+ import librosa
9
+
10
+ def fix_audio_format(audio_path, out_sr=16000):
11
+ """
12
+ 将音频读进来(自动识别格式)并强制重采样到 out_sr,转换为单声道。
13
+ 最终返回:
14
+ - data (numpy array): 处理后(单声道、out_sr)的音频数据
15
+ - sr (int): 处理后的采样率 (默认 16k)
16
+ 不写入临时文件,只做内存操作。
17
+ """
18
+ # librosa.load 会自动解析不同格式的音频(wav/mp3/flac等)
19
+ # 并将其重采样到 out_sr, 同时 mono=True 意味着转换为单声道
20
+ data, sr = librosa.load(audio_path, sr=out_sr, mono=True)
21
+
22
+ return data
23
+
24
+ class AudioMixer(object):
25
+ def __init__(
26
+ self,
27
+ sample_rate=16000,
28
+ mean_snr=-7,
29
+ var_snr=25,
30
+ mean_loudness=-24,
31
+ var_loudness=20
32
+ ):
33
+ """
34
+ 初始化一些参数、随机种子和响度计算工具等。
35
+ """
36
+ self.sample_rate = sample_rate
37
+ self.mean_snr = mean_snr
38
+ self.var_snr = var_snr
39
+ self.MEAN_LOUNDNESS = mean_loudness
40
+ self.VAR_LOUNDNESS = var_loudness
41
+
42
+ self.EPS = 1e-10
43
+ self.MAX_AMP = 0.9
44
+
45
+ # pyloudnorm 的 Meter,用于计算音频响度
46
+ self.meter = pyloudnorm.Meter(self.sample_rate)
47
+
48
+ # # 也可固定随机种子,保证每次混合一致(如果想要可复现)
49
+ # self.seed = 1453
50
+ # random.seed(self.seed)
51
+ # np.random.seed(self.seed)
52
+
53
+ def read_wav(self, wav_path):
54
+ """
55
+ 读取音频文件并返回 wave 数据和采样率
56
+ """
57
+ data, sr = sf.read(wav_path, dtype='float32')
58
+ # 如果读到的是多通道,可只取其中一个通道
59
+ if data.ndim > 1:
60
+ data = data[:, 0]
61
+ return data, sr
62
+
63
+ def normalize(self, signal, is_noise=False):
64
+ """
65
+ 对输入的 signal 做响度归一化,并确保不会过载失真。
66
+ """
67
+ c_loudness = self.meter.integrated_loudness(signal)
68
+ if is_noise:
69
+ # 噪声的目标响度可以偏高一些或随便设置
70
+ target_loudness = np.random.normal(self.MEAN_LOUNDNESS + 4, self.VAR_LOUNDNESS**0.5)
71
+ else:
72
+ # mix 或者语音的目标响度
73
+ target_loudness = np.random.normal(self.MEAN_LOUNDNESS, self.VAR_LOUNDNESS**0.5)
74
+
75
+ with warnings.catch_warnings():
76
+ warnings.filterwarnings("error", category=RuntimeWarning)
77
+ signal = pyloudnorm.normalize.loudness(signal, c_loudness, target_loudness)
78
+
79
+ # # 再检查是否会 clipping
80
+ # peak = np.max(np.abs(signal))
81
+ # if peak >= 1.0:
82
+ # signal = signal * self.MAX_AMP / peak
83
+
84
+ return signal
85
+
86
+ def snr_norm(self, signal, noise, is_noise=True):
87
+ """
88
+ 根据预设的 mean_snr、var_snr 来随机决定一个目标 SNR,然后
89
+ 以此对 noise 做缩放,得到与 signal 相匹配的噪声幅度。
90
+ """
91
+ if is_noise:
92
+ desired_snr = np.random.normal(self.mean_snr, self.var_snr**0.5)
93
+ else:
94
+ # 如果你还有别的需求,比如想做正 SNR 范围,可以改这里
95
+ desired_snr = np.random.uniform(2, 10)
96
+
97
+ current_snr = 10 * np.log10(
98
+ np.mean(signal ** 2) / (np.mean(noise ** 2) + self.EPS) + self.EPS
99
+ )
100
+ scale_factor = 10 ** ((current_snr - desired_snr) / 20)
101
+
102
+ scaled_noise = noise * scale_factor
103
+
104
+ # # 防止噪声自身 clipping
105
+ # peak = np.max(np.abs(scaled_noise))
106
+ # if peak >= 1.0:
107
+ # scaled_noise = scaled_noise * self.MAX_AMP / peak
108
+
109
+ return scaled_noise
110
+
111
+ def _mix(self, sources_list):
112
+ """
113
+ 将多路音频进行叠加,防止溢出。
114
+ """
115
+ # 假设 sources_list[0] 是 mix 音频,sources_list[1] 是已拼好长度的 noise
116
+ mix_length = len(sources_list[0])
117
+ mixture = np.zeros(mix_length, dtype=np.float32)
118
+ for s in sources_list:
119
+ mixture += s[:mix_length] # 仅叠加到 mix 的长度
120
+
121
+ # 再做一次峰值校正,避免溢出
122
+ peak = np.max(np.abs(mixture))
123
+ if peak >= 1.0:
124
+ mixture = mixture * self.MAX_AMP / peak
125
+
126
+ return mixture
127
+
128
+ def _prepare_noise_for_mix(self, noise_files, mix_length):
129
+ """
130
+ 传入一组 noise 文件路径,先对它们打乱,再依次读取、拼接。
131
+ 如果总长度还不够覆盖 mix_length,可以再次拼接自己(循环)。
132
+
133
+ - noise_files: 存储多个噪声文件路径的列表
134
+ - mix_length: 需要的总长度(采样点数)
135
+
136
+ 返回: 拼接后的 noise 波形
137
+ """
138
+ # 先随机打乱
139
+ random.shuffle(noise_files)
140
+
141
+ # 依次读取并拼接
142
+ noise_all = []
143
+ total_len = 0
144
+
145
+ # ���一次先拼完所有 noise 文件,如果还不够,就重复拼接
146
+ while total_len < mix_length:
147
+ for nf in noise_files:
148
+ noise_data, _ = self.read_wav(nf)
149
+
150
+ # 可选:对每条 noise 做一次 normalize,提升多样性
151
+ # (或者只在外部做一次统一的 normalize)
152
+ #noise_data = self.normalize(noise_data, is_noise=True)
153
+
154
+ noise_all.append(noise_data)
155
+ total_len += len(noise_data)
156
+
157
+ if total_len >= mix_length:
158
+ break
159
+
160
+ # 如果已经拼完一轮,可能还不够,就继续 while 循环再拼一轮
161
+
162
+ # 拼接后截断到 mix_length
163
+ concatenated_noise = np.concatenate(noise_all)[:mix_length]
164
+ return concatenated_noise
165
+
166
+ def mix_with_noise_folder(self, mix_wave,sr_mix noise_folder):
167
+ """
168
+ 读取一条 mix 文件和一个 noise 文件夹,做如下处理:
169
+ 1. 读取 mix wave,并做响度归一化
170
+ 2. 根据 mix 的长度,在 noise 文件夹中随机打乱全部 wav,依次拼接满足同长度
171
+ 3. 对最终拼好的 noise 做 snr_norm
172
+ 4. 叠加输出
173
+ """
174
+ # 1. 读取 mix
175
+ # mix_wave, sr_mix = self.read_wav(mix_path)
176
+
177
+ # 如果文件夹下找不到任何 noise 文件,就直接返回原音频
178
+ noise_files = sorted(glob.glob(os.path.join(noise_folder, "*.wav")))
179
+ if not noise_files:
180
+ raise RuntimeError(f"噪声文件夹 {noise_folder} 内未发现 .wav 文件")
181
+
182
+ mix_wave = self.normalize(mix_wave, is_noise=False)
183
+ mix_length = len(mix_wave)
184
+
185
+ # 2. 先把 noise 文件拼接到 match mix_length
186
+ # (会将 noise_files 打乱后依次读、拼接)
187
+ noise_ready = self._prepare_noise_for_mix(noise_files, mix_length)
188
+
189
+ # 3. SNR 调整
190
+ noise_ready = self.snr_norm(mix_wave, noise_ready, is_noise=True)
191
+
192
+ # 4. 叠加
193
+ mixture = self._mix([mix_wave, noise_ready])
194
+
195
+ out_noisy = "temp_noisy.wav" # 可以理解为把输入的混合音频直接另存为
196
+
197
+ # 返回混合后的音频以及采样率
198
+ sf.write(out_noisy, mixture, sr_mix)
199
+
200
+ return out_noisy
201
+
202
+
203
+ if __name__ == "__main__":
204
+ # 假设你有一个 mix.wav 以及一个 noise 文件夹(含若干个 .wav 噪声文件)
205
+ mix_path_test = "test_mix.wav"
206
+ mix_wave, sr_mix = self.read_wav(mix_path_test)
207
+ noise_folder_test = "noises/" # 比如里面有 10 条 noise*.wav
208
+
209
+ mixer = AudioMixer()
210
+
211
+ # 执行混合
212
+ mixed_wav, sr = mixer.mix_with_noise_folder(mix_wave, sr_mix, noise_folder_test)
213
+
214
+ # 这里你可以选择把结果写回本地文件,或直接返回 numpy 数组做后续处理
215
+ sf.write("test_output_mixture.wav", mixed_wav, sr)
216
+ print("混合完成,已输出到 test_output_mixture.wav")
decode.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import os
4
+ import logging
5
+ import numpy as np
6
+ import torch as th
7
+ import soundfile as sf
8
+ import hydra
9
+ from omegaconf import OmegaConf
10
+
11
+
12
+
13
+ # ================ 网络推理类 ================
14
+ class NnetComputer(object):
15
+ def __init__(self, cpt_dir, gpuid, nnet_conf):
16
+ self.device = th.device(f"cuda:{gpuid}") if gpuid >= 0 else th.device("cpu")
17
+ nnet = self._load_nnet(cpt_dir, nnet_conf)
18
+ self.nnet = nnet.to(self.device) if gpuid >= 0 else nnet
19
+ self.nnet.eval()
20
+
21
+ def _load_nnet(self, cpt_dir, model):
22
+ cpt = th.load(cpt_dir, map_location="cpu")
23
+ model.load_state_dict(cpt["model_state_dict"])
24
+
25
+ return model
26
+
27
+ def compute(self, samps, aux_samps, aux_samps_len):
28
+ with th.no_grad():
29
+ raw = th.tensor(samps, dtype=th.float32, device=self.device)
30
+ aux = th.tensor(aux_samps, dtype=th.float32, device=self.device)
31
+ aux_len = th.tensor(aux_samps_len, dtype=th.float32, device=self.device)
32
+ aux = aux.unsqueeze(0)
33
+ print("raw",raw.shape)
34
+ print("aux",aux.shape)
35
+ sps, sps2, sps3, spk_pred = self.nnet(raw, aux, aux_len)
36
+ sp_samps = np.squeeze(sps.detach().cpu().numpy())
37
+ return sp_samps
38
+
39
+ class InferencePipeline:
40
+ """
41
+ 外部只需传入 config,即可完成:
42
+ 1) 模型实例化 (含 hydra.instantiate 逻辑)
43
+ 2) 加载 checkpoint
44
+ 3) 推理
45
+ """
46
+ def __init__(self, config):
47
+ """
48
+ 在构造时就把所有初始化做好,包括:
49
+ - hydra.instantiate(config.model) -> 得到一个 nn.Module
50
+ - 用 NnetComputer(...) 封装
51
+ """
52
+ # 如果 config.model 里含有 _target_ 字段,可以用 hydra.instantiate
53
+ # 注意: hydra.instantiate 需要在这里显式地导入 hydra.utils
54
+
55
+ # 1. 根据 config.model 构建模型
56
+ model_inst = hydra.utils.instantiate(config.model)
57
+
58
+ self.computer_ = NnetComputer(config.test.checkpoint,config.test.gpu, model_inst)
59
+
60
+ def run_inference(self, input_audio_path: str, enroll_audio_path: str) -> str:
61
+ """
62
+ 给定混合音频 + enroll 音频,执行推理并返回输出文件路径。
63
+ """
64
+ # 1. 读取音频
65
+ mix_samps, sr = sf.read(input_audio_path)
66
+ aux_samps, sr2 = sf.read(enroll_audio_path)
67
+
68
+ # 2. 调用底层 compute
69
+ samps = self.computer_.compute(mix_samps, aux_samps, len(aux_samps))
70
+ norm = np.linalg.norm(mix_samps, np.inf)
71
+ samps = samps[:mix_samps.size]
72
+ samps = samps * norm / np.max(np.abs(samps))
73
+
74
+ # 3. 写到临时文件
75
+ out_wav = "temp_extracted.wav"
76
+ sf.write(out_wav, samps, sr)
77
+ return out_wav
78
+
79
+ if __name__ == "__main__":
80
+ cfg = OmegaConf.load("config/config.yaml")
81
+ pipeline = InferencePipeline(cfg)
82
+
83
+ mix_path = "test_output_mixture.wav"
84
+ enroll_path = "test_mix.wav"
85
+ out_wav = pipeline.run_inference(mix_path, enroll_path)
86
+ print("Done:", out_wav)
87
+
88
+
model/__pycache__/cnns.cpython-37.pyc ADDED
Binary file (8.14 kB). View file
 
model/__pycache__/cnns.cpython-38.pyc ADDED
Binary file (8.01 kB). View file
 
model/__pycache__/norm.cpython-37.pyc ADDED
Binary file (3.88 kB). View file
 
model/__pycache__/norm.cpython-38.pyc ADDED
Binary file (3.77 kB). View file
 
model/__pycache__/spex_plus.cpython-37.pyc ADDED
Binary file (5.9 kB). View file
 
model/__pycache__/spex_plus.cpython-38.pyc ADDED
Binary file (5.98 kB). View file
 
{nnet → model}/cnns.py RENAMED
@@ -2,8 +2,9 @@
2
 
3
  import torch as th
4
  import torch.nn as nn
 
5
 
6
- from .norm import ChannelwiseLayerNorm, GlobalLayerNorm
7
 
8
  class Conv1D(nn.Conv1d):
9
  """
@@ -58,12 +59,23 @@ class TCNBlock(nn.Module):
58
  conv_channels=512,
59
  kernel_size=3,
60
  dilation=1,
61
- causal=False):
 
62
  super(TCNBlock, self).__init__()
63
  self.conv1x1 = Conv1D(in_channels, conv_channels, 1)
64
  self.prelu1 = nn.PReLU()
65
- self.norm1 = GlobalLayerNorm(conv_channels, elementwise_affine=True) if not causal else (
66
- ChannelwiseLayerNorm(conv_channels, elementwise_affine=True))
 
 
 
 
 
 
 
 
 
 
67
  dconv_pad = (dilation * (kernel_size - 1)) // 2 if not causal else (
68
  dilation * (kernel_size - 1))
69
  self.dconv = nn.Conv1d(
@@ -75,8 +87,8 @@ class TCNBlock(nn.Module):
75
  dilation=dilation,
76
  bias=True)
77
  self.prelu2 = nn.PReLU()
78
- self.norm2 = GlobalLayerNorm(conv_channels, elementwise_affine=True) if not causal else (
79
- ChannelwiseLayerNorm(conv_channels, elementwise_affine=True))
80
  self.sconv = nn.Conv1d(conv_channels, in_channels, 1, bias=True)
81
  self.causal = causal
82
  self.dconv_pad = dconv_pad
@@ -108,12 +120,40 @@ class TCNBlock_Spk(nn.Module):
108
  conv_channels=512,
109
  kernel_size=3,
110
  dilation=1,
111
- causal=False):
 
 
112
  super(TCNBlock_Spk, self).__init__()
113
- self.conv1x1 = Conv1D(in_channels+spk_embed_dim, conv_channels, 1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  self.prelu1 = nn.PReLU()
115
- self.norm1 = GlobalLayerNorm(conv_channels, elementwise_affine=True) if not causal else (
116
- ChannelwiseLayerNorm(conv_channels, elementwise_affine=True))
 
 
 
 
 
 
 
 
 
117
  dconv_pad = (dilation * (kernel_size - 1)) // 2 if not causal else (
118
  dilation * (kernel_size - 1))
119
  self.dconv = nn.Conv1d(
@@ -125,19 +165,80 @@ class TCNBlock_Spk(nn.Module):
125
  dilation=dilation,
126
  bias=True)
127
  self.prelu2 = nn.PReLU()
128
- self.norm2 = GlobalLayerNorm(conv_channels, elementwise_affine=True) if not causal else (
129
- ChannelwiseLayerNorm(conv_channels, elementwise_affine=True))
130
  self.sconv = nn.Conv1d(conv_channels, in_channels, 1, bias=True)
131
  self.causal = causal
132
  self.dconv_pad = dconv_pad
133
  self.dilation = dilation
134
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  def forward(self, x, aux):
136
  # Repeatedly concated speaker embedding aux to each frame of the representation x
137
  T = x.shape[-1]
138
- aux = th.unsqueeze(aux, -1)
139
- aux = aux.repeat(1,1,T)
140
- y = th.cat([x, aux], 1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  y = self.conv1x1(y)
142
  y = self.norm1(self.prelu1(y))
143
  y = self.dconv(y)
 
2
 
3
  import torch as th
4
  import torch.nn as nn
5
+ import torch.nn.functional as F
6
 
7
+ from .norm import ChannelwiseLayerNorm, GlobalLayerNorm, CumLN
8
 
9
  class Conv1D(nn.Conv1d):
10
  """
 
59
  conv_channels=512,
60
  kernel_size=3,
61
  dilation=1,
62
+ causal=False,
63
+ norm_type='gLN'):
64
  super(TCNBlock, self).__init__()
65
  self.conv1x1 = Conv1D(in_channels, conv_channels, 1)
66
  self.prelu1 = nn.PReLU()
67
+ # self.norm1 = GlobalLayerNorm(conv_channels, elementwise_affine=True) if not causal else (
68
+ # ChannelwiseLayerNorm(conv_channels, elementwise_affine=True))
69
+ if norm_type == 'gLN':
70
+ self.norm1 = GlobalLayerNorm(conv_channels, elementwise_affine=True)
71
+ self.norm2 = GlobalLayerNorm(conv_channels, elementwise_affine=True)
72
+ elif norm_type == 'cLN':
73
+ self.norm1 = ChannelwiseLayerNorm(conv_channels, elementwise_affine=True)
74
+ self.norm2 = ChannelwiseLayerNorm(conv_channels, elementwise_affine=True)
75
+ elif norm_type == 'cgLN':
76
+ self.norm1 = CumLN(conv_channels, elementwise_affine=True)
77
+ self.norm2 = CumLN(conv_channels, elementwise_affine=True)
78
+
79
  dconv_pad = (dilation * (kernel_size - 1)) // 2 if not causal else (
80
  dilation * (kernel_size - 1))
81
  self.dconv = nn.Conv1d(
 
87
  dilation=dilation,
88
  bias=True)
89
  self.prelu2 = nn.PReLU()
90
+ # self.norm2 = GlobalLayerNorm(conv_channels, elementwise_affine=True) if not causal else (
91
+ # ChannelwiseLayerNorm(conv_channels, elementwise_affine=True))
92
  self.sconv = nn.Conv1d(conv_channels, in_channels, 1, bias=True)
93
  self.causal = causal
94
  self.dconv_pad = dconv_pad
 
120
  conv_channels=512,
121
  kernel_size=3,
122
  dilation=1,
123
+ causal=False,
124
+ norm_type='gLN',
125
+ fusion_type='cat'):
126
  super(TCNBlock_Spk, self).__init__()
127
+ self.fusion_type = fusion_type
128
+ if fusion_type == 'cat':
129
+ self.conv1x1 = Conv1D(in_channels+spk_embed_dim, conv_channels, 1)
130
+ if fusion_type in ('add', 'mul'):
131
+ self.fusion_linear = nn.Linear(spk_embed_dim, in_channels)
132
+ self.conv1x1 = Conv1D(in_channels, conv_channels, 1)
133
+ if fusion_type == 'film':
134
+ self.fusion_linear_1 = nn.Linear(spk_embed_dim, in_channels)
135
+ self.fusion_linear_2 = nn.Linear(spk_embed_dim, in_channels)
136
+ self.conv1x1 = Conv1D(in_channels, conv_channels, 1)
137
+ if fusion_type == 'att':
138
+ self.fusion_linear = nn.Linear(spk_embed_dim, in_channels)
139
+ self.average = Conv1D(in_channels, in_channels, kernel_size, kernel_size, groups=in_channels)
140
+ self.average.weight = nn.Parameter(th.ones(in_channels, 1, kernel_size) / kernel_size)
141
+ self.average.bias = nn.Parameter(th.zeros(in_channels))
142
+ for p in self.average.parameters():
143
+ p.requires_grad = False
144
+ self.conv1x1 = Conv1D(in_channels, conv_channels, 1)
145
  self.prelu1 = nn.PReLU()
146
+ # self.norm1 = GlobalLayerNorm(conv_channels, elementwise_affine=True) if not causal else (
147
+ # ChannelwiseLayerNorm(conv_channels, elementwise_affine=True))
148
+ if norm_type == 'gLN':
149
+ self.norm1 = GlobalLayerNorm(conv_channels, elementwise_affine=True)
150
+ self.norm2 = GlobalLayerNorm(conv_channels, elementwise_affine=True)
151
+ elif norm_type == 'cLN':
152
+ self.norm1 = ChannelwiseLayerNorm(conv_channels, elementwise_affine=True)
153
+ self.norm2 = ChannelwiseLayerNorm(conv_channels, elementwise_affine=True)
154
+ elif norm_type == 'cgLN':
155
+ self.norm1 = CumLN(conv_channels, elementwise_affine=True)
156
+ self.norm2 = CumLN(conv_channels, elementwise_affine=True)
157
  dconv_pad = (dilation * (kernel_size - 1)) // 2 if not causal else (
158
  dilation * (kernel_size - 1))
159
  self.dconv = nn.Conv1d(
 
165
  dilation=dilation,
166
  bias=True)
167
  self.prelu2 = nn.PReLU()
168
+ # self.norm2 = GlobalLayerNorm(conv_channels, elementwise_affine=True) if not causal else (
169
+ # ChannelwiseLayerNorm(conv_channels, elementwise_affine=True))
170
  self.sconv = nn.Conv1d(conv_channels, in_channels, 1, bias=True)
171
  self.causal = causal
172
  self.dconv_pad = dconv_pad
173
  self.dilation = dilation
174
 
175
+ def _concatenation(self, aux, output, L):
176
+ aux_concat = th.unsqueeze(aux, -1)
177
+ aux_concat = aux_concat.repeat(1, 1, L)
178
+ # -> [B, N(embeddings_size), L]
179
+ output = th.cat([output, aux_concat], 1)
180
+ # -> [B, N(input_size + embeddings_size), L]
181
+ return output
182
+
183
+ def _addition(self, aux, output, L, fusion_linear):
184
+ aux_add = fusion_linear(aux)
185
+ # -> [B, N(input_size)]
186
+ aux_add = th.unsqueeze(aux_add, -1)
187
+ aux_add = aux_add.repeat(1, 1, L)
188
+ # -> [B, N(input_size), L]
189
+ output = output + aux_add
190
+ # -> [B, N(input_size, L]
191
+ return output
192
+
193
+ def _multiplication(self, aux, output, L, fusion_linear):
194
+ aux_mul = fusion_linear(aux)
195
+ # -> [B, N(input_size)]
196
+ aux_mul = th.unsqueeze(aux_mul, -1)
197
+ aux_mul = aux_mul.repeat(1, 1, L)
198
+ # -> [B, N(input_size), L]
199
+ output = output * aux_mul
200
+ # -> [B, N(input_size, L]
201
+ return output
202
+
203
+ def _attention(self, aux, output, fusion_linear):
204
+ L = output.shape[-1]
205
+ aux_att = fusion_linear(aux)
206
+ aux_att = th.unsqueeze(aux_att, -1)
207
+ aux_att = aux_att.repeat(1, 1, L)
208
+ att = th.sum(output * aux_att, 1, keepdim=True)
209
+ att = F.softmax(att, -1)
210
+ att = att * aux_att
211
+ return att + aux_att
212
+
213
+ def _film(self, aux, output, L):
214
+ output = self._multiplication(aux, output, L, self.fusion_linear_1)
215
+ # -> [B, N(input_size, L]
216
+ output = self._addition(aux, output, L, self.fusion_linear_2)
217
+ # -> [B, N(input_size, L]
218
+ return output
219
+
220
  def forward(self, x, aux):
221
  # Repeatedly concated speaker embedding aux to each frame of the representation x
222
  T = x.shape[-1]
223
+ if self.fusion_type == 'cat':
224
+ y = self._concatenation(aux, x, T)
225
+ # -> [B, N(input_size + embeddings_size), L]
226
+ if self.fusion_type == 'add':
227
+ y = self._addition(aux, x, T, self.fusion_linear)
228
+ # -> [B, N(input_size), L]
229
+ if self.fusion_type == 'mul':
230
+ y = self._multiplication(aux, x, T, self.fusion_linear)
231
+ # -> [B, N(input_size), L]
232
+ if self.fusion_type == 'film':
233
+ y = self._film(aux, x, T)
234
+ # -> [B, N(input_size), L]
235
+ if self.fusion_type == 'att':
236
+ output_avg = self.average(x)
237
+ att_out = self._attention(aux, output_avg, self.fusion_linear)
238
+ upsampling = nn.Upsample(size=T, mode='nearest')
239
+ att_out = upsampling(att_out)
240
+ y = x * att_out
241
+
242
  y = self.conv1x1(y)
243
  y = self.norm1(self.prelu1(y))
244
  y = self.dconv(y)
{nnet → model}/norm.py RENAMED
@@ -22,6 +22,44 @@ class ChannelwiseLayerNorm(nn.LayerNorm):
22
  x = th.transpose(x, 1, 2)
23
  return x
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  class GlobalLayerNorm(nn.Module):
26
  """
27
  Global layer normalization
@@ -57,3 +95,4 @@ class GlobalLayerNorm(nn.Module):
57
  def extra_repr(self):
58
  return "{normalized_dim}, eps={eps}, " \
59
  "elementwise_affine={elementwise_affine}".format(**self.__dict__)
 
 
22
  x = th.transpose(x, 1, 2)
23
  return x
24
 
25
+
26
+ class CumLN(nn.Module):
27
+ """
28
+ Cumulative Global layer normalization
29
+ Input: 3D tensor with [batch_size(N), channel_size(C), frame_num(T)]
30
+ Output: 3D tensor with same shape
31
+ """
32
+
33
+ def __init__(self, dim, eps=1e-05, elementwise_affine=True):
34
+ super(CumLN, self).__init__()
35
+ self.eps = eps
36
+ self.elementwise_affine = elementwise_affine
37
+ self.normalized_dim = dim
38
+ if elementwise_affine:
39
+ self.beta = nn.Parameter(th.zeros(dim, 1))
40
+ self.gamma = nn.Parameter(th.ones(dim, 1))
41
+ else:
42
+ self.register_parameter("weight", None)
43
+ self.register_parameter("bias", None)
44
+
45
+ def forward(self, x):
46
+ if x.dim() != 3:
47
+ raise RuntimeError("{} requires a 3D tensor input".format(self.__class__.__name__))
48
+ batch, chan, spec_len = x.size()
49
+ cum_sum = th.cumsum(x.sum(1, keepdim=True), dim=-1)
50
+ cum_pow_sum = th.cumsum(x.pow(2).sum(1, keepdim=True), dim=-1) #th.cumsum 后加前 逐元素相加
51
+ cnt = th.arange(start=chan, end=chan * (spec_len + 1), step=chan, dtype=x.dtype, device=x.device).view(1, 1, -1)
52
+ cum_mean = cum_sum / cnt
53
+ cum_var = cum_pow_sum / cnt - cum_mean.pow(2)
54
+ normalized_x = (x - cum_mean) / (cum_var + self.eps).sqrt()
55
+ if self.elementwise_affine:
56
+ normalized_x = self.gamma * normalized_x + self.beta
57
+ return normalized_x
58
+
59
+ def extra_repr(self):
60
+ return "{normalized_dim}, eps={eps}, elementwise_affine={elementwise_affine}".format(**self.__dict__)
61
+
62
+
63
  class GlobalLayerNorm(nn.Module):
64
  """
65
  Global layer normalization
 
95
  def extra_repr(self):
96
  return "{normalized_dim}, eps={eps}, " \
97
  "elementwise_affine={elementwise_affine}".format(**self.__dict__)
98
+
{nnet → model}/spex_plus.py RENAMED
@@ -6,15 +6,9 @@ import torch.nn.functional as F
6
 
7
  from .norm import ChannelwiseLayerNorm, GlobalLayerNorm
8
  from .cnns import Conv1D, ConvTrans1D, TCNBlock, TCNBlock_Spk, ResBlock
 
9
 
10
- import torchaudio
11
- from .ResNet34 import Speaker_Encoder
12
- # from .sunine.trainer.utils import PreEmphasis
13
-
14
-
15
-
16
- # 考虑两种可能,频域就不大可能有所谓的多时间尺度,所以肯定speaker是直接频谱,那speech呢?
17
- # 注意下维度 是 B N T 还是 B T N
18
 
19
  class SpEx_Plus(nn.Module):
20
  def __init__(self,
@@ -28,14 +22,15 @@ class SpEx_Plus(nn.Module):
28
  Q=3,
29
  num_spks=101,
30
  spk_embed_dim=256,
31
- sample_rate = 16000,
32
- n_mels = 80,
33
  causal=False,
 
 
 
34
  ):
35
  super(SpEx_Plus, self).__init__()
 
36
  # n x S => n x N x T, S = 4s*8000 = 32000
37
- self.sample_rate = sample_rate
38
- self.n_mels = n_mels
39
  self.L1 = L1
40
  self.L2 = L2
41
  self.L3 = L3
@@ -43,82 +38,49 @@ class SpEx_Plus(nn.Module):
43
  self.encoder_1d_middle = Conv1D(1, N, L2, stride=L1 // 2, padding=0)
44
  self.encoder_1d_long = Conv1D(1, N, L3, stride=L1 // 2, padding=0)
45
  # before repeat blocks, always cLN
46
- self.ln = ChannelwiseLayerNorm(3*N)
47
- # n x N x T => n x O x T
48
- self.proj = Conv1D(3*N, O, 1)
49
- self.conv_block_1 = TCNBlock_Spk(spk_embed_dim=spk_embed_dim, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal, dilation=1)
50
- self.conv_block_1_other = self._build_stacks(num_blocks=B, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal)
51
- self.conv_block_2 = TCNBlock_Spk(spk_embed_dim=spk_embed_dim, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal, dilation=1)
52
- self.conv_block_2_other = self._build_stacks(num_blocks=B, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal)
53
- self.conv_block_3 = TCNBlock_Spk(spk_embed_dim=spk_embed_dim, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal, dilation=1)
54
- self.conv_block_3_other = self._build_stacks(num_blocks=B, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal)
55
- self.conv_block_4 = TCNBlock_Spk(spk_embed_dim=spk_embed_dim, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal, dilation=1)
56
- self.conv_block_4_other = self._build_stacks(num_blocks=B, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal)
57
- # n x O x T => n x N x T
58
- self.mask1 = Conv1D(O, N, 1)
59
- self.mask2 = Conv1D(O, N, 1)
60
- self.mask3 = Conv1D(O, N, 1)
61
- # using ConvTrans1D: n x N x T => n x 1 x To
62
- # To = (T - 1) * L // 2 + L
63
- #############################################################
64
  self.decoder_1d_short = ConvTrans1D(N, 1, kernel_size=L1, stride=L1 // 2, bias=True)
65
  self.decoder_1d_middle = ConvTrans1D(N, 1, kernel_size=L2, stride=L1 // 2, bias=True)
66
  self.decoder_1d_long = ConvTrans1D(N, 1, kernel_size=L3, stride=L1 // 2, bias=True)
67
  self.num_spks = num_spks
68
- # self.spk_encoder = nn.Sequential(
69
- # ChannelwiseLayerNorm(3*N),
70
- # Conv1D(3*N, O, 1),
71
- # ResBlock(O, O),
72
- # ResBlock(O, P),
73
- # ResBlock(P, P),
74
- # Conv1D(P, spk_embed_dim, 1),
75
- # )
76
-
77
- # self.pred_linear = nn.Linear(spk_embed_dim, num_spks)
78
-
79
- # 改为pretrain
80
- # 考虑两种可能,频域就不大可能有所谓的多时间尺度,所以肯定speaker是直接频谱,那speech呢?
81
- # /work105/youzhenghai/model/resnet_asp_aam_adamw_welr
82
- # import ..sunine/trainer/speaker encoder
83
- # **kwargs 无需关心 找到 self.hparams就行 按照 main_infer改就行
84
- #############################################################
85
-
86
- # # 1. Acoustic Feature
87
- # self.mel_trans = th.nn.Sequential(
88
- # PreEmphasis(),
89
- # torchaudio.transforms.MelSpectrogram(sample_rate=self.sample_rate, n_fft=512,
90
- # win_length=400, hop_length=160, window_fn=th.hamming_window, n_mels=self.n_mels)
91
- # )
92
-
93
- # self.instancenorm = nn.InstanceNorm1d(self.n_mels)
94
-
95
- # # 在调用的地方设置超参数 记得后面写为参数传入
96
- # self.hparams = {'embedding_dim': spk_embed_dim, 'pooling_type': 'ASP' , 'n_mels': self.n_mels}
97
- # # 使用 **self.hparams 调用函数
98
- # self.speaker_encoder = Speaker_Encoder(**self.hparams)
99
- self.speaker_embedding_extracter = Speaker_Model(pooling_type='ASP', spk_embed_dim=spk_embed_dim, sample_rate=self.sample_rate, n_mels=self.n_mels)
100
  self.pred_linear = nn.Linear(spk_embed_dim, num_spks)
101
-
102
- #############################################################
103
-
104
- # # 3. Loss / Classifier
105
- # if not self.hparams.evaluate:
106
- # LossFunction = importlib.import_module('trainer.loss.'+self.hparams.loss_type).__getattribute__('LossFunction')
107
- # self.loss = LossFunction(**dict(self.hparams))
108
-
109
-
110
- def _build_stacks(self, num_blocks, **block_kwargs):
111
- """
112
- Stack B numbers of TCN block, the first TCN block takes the speaker embedding
113
- """
114
- blocks = [
115
- TCNBlock(**block_kwargs, dilation=(2**b))
116
- for b in range(1,num_blocks)
117
- ]
118
- return nn.Sequential(*blocks)
119
- # 注意下维度 是 B N T 还是 B T N
120
-
121
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
  def forward(self, x, aux, aux_len):
124
  if x.dim() >= 3:
@@ -128,8 +90,9 @@ class SpEx_Plus(nn.Module):
128
  # when inference, only one utt
129
  if x.dim() == 1:
130
  x = th.unsqueeze(x, 0)
131
-
132
  # n x 1 x S => n x N x T
 
 
133
  w1 = F.relu(self.encoder_1d_short(x))
134
  T = w1.shape[-1]
135
  xlen1 = x.shape[-1]
@@ -137,42 +100,79 @@ class SpEx_Plus(nn.Module):
137
  xlen3 = (T - 1) * (self.L1 // 2) + self.L3
138
  w2 = F.relu(self.encoder_1d_middle(F.pad(x, (0, xlen2 - xlen1), "constant", 0)))
139
  w3 = F.relu(self.encoder_1d_long(F.pad(x, (0, xlen3 - xlen1), "constant", 0)))
140
-
141
  # n x 3N x T
142
- y = self.ln(th.cat([w1, w2, w3], 1))
143
- # n x O x T
144
- y = self.proj(y)
145
-
146
  # speaker encoder (share params from speech encoder)
147
- # aux_w1 = F.relu(self.encoder_1d_short(aux))
148
- # aux_T_shape = aux_w1.shape[-1]
149
- # aux_len1 = aux.shape[-1]
150
- # aux_len2 = (aux_T_shape - 1) * (self.L1 // 2) + self.L2
151
- # aux_len3 = (aux_T_shape - 1) * (self.L1 // 2) + self.L3
152
- # aux_w2 = F.relu(self.encoder_1d_middle(F.pad(aux, (0, aux_len2 - aux_len1), "constant", 0)))
153
- # aux_w3 = F.relu(self.encoder_1d_long(F.pad(aux, (0, aux_len3 - aux_len1), "constant", 0)))
154
 
155
- # spk_encoder + mean pooling
156
- # aux = self.spk_encoder(th.cat([aux_w1, aux_w2, aux_w3], 1))
157
- # aux_T = (aux_len - self.L1) // (self.L1 // 2) + 1
158
- # aux_T = ((aux_T // 3) // 3) // 3
159
- # aux = th.sum(aux, -1)/aux_T.view(-1,1).float()
160
 
161
- # spk_encoder + TAP pooling
162
- aux = self.speaker_embedding_extracter(aux)
 
 
163
 
 
164
 
 
 
 
 
 
165
 
166
- #aux = torch.mean(aux, axis=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
- # aux = aux.cpu().detach().numpy()
 
 
 
 
 
 
 
 
169
 
170
- # 不需要 reshape N * D 是正确的维度
171
- #aux = aux.reshape(-1, self.hparams.nPerSpeaker, self.spk_embed_dim)
172
- # loss, acc = self.loss(x, label)
173
- # return loss.mean(), acc
174
- # 考虑 loss 是否也要
175
 
 
 
 
176
  y = self.conv_block_1(y, aux)
177
  y = self.conv_block_1_other(y)
178
  y = self.conv_block_2(y, aux)
@@ -186,62 +186,35 @@ class SpEx_Plus(nn.Module):
186
  m1 = F.relu(self.mask1(y))
187
  m2 = F.relu(self.mask2(y))
188
  m3 = F.relu(self.mask3(y))
189
- S1 = w1 * m1
190
- S2 = w2 * m2
191
- S3 = w3 * m3
192
-
193
- return self.decoder_1d_short(S1), self.decoder_1d_middle(S2)[:, :xlen1], self.decoder_1d_long(S3)[:, :xlen1], self.pred_linear(aux)
194
-
195
- class PreEmphasis(th.nn.Module):
196
- def __init__(self, coef: float = 0.97):
197
- super().__init__()
198
- self.coef = coef
199
- # make kernel
200
- # In pyth, the convolution operation uses cross-correlation. So, filter is flipped.
201
- self.register_buffer(
202
- 'flipped_filter', th.FloatTensor([-self.coef, 1.]).unsqueeze(0).unsqueeze(0)
203
- )
204
 
205
- def forward(self, inputs: th.tensor) -> th.tensor:
206
- assert len(inputs.size()) == 2, 'The number of dimensions of inputs tensor must be 2!'
207
- # reflect padding to match lengths of in/out
208
- inputs = inputs.unsqueeze(1)
209
- inputs = F.pad(inputs, (1, 0), 'reflect')
210
- return F.conv1d(inputs, self.flipped_filter).squeeze(1)
211
 
212
 
213
  class Speaker_Model(nn.Module):
214
- #class Speaker_Model(LightningModule):
215
- def __init__(self, pooling_type, spk_embed_dim, sample_rate, n_mels):
216
- super().__init__()
217
- # self.save_hyperparameters()
218
-
219
- self.pooling_type = pooling_type
220
- self.spk_embed_dim = spk_embed_dim
221
- self.sample_rate = sample_rate
222
- self.n_mels = n_mels
223
- sr = self.sample_rate
224
-
225
- self.mel_trans = th.nn.Sequential(
226
- PreEmphasis(),
227
- torchaudio.transforms.MelSpectrogram(sample_rate=sr, n_fft=512,
228
- win_length=sr * 25 // 1000, hop_length=sr * 10 // 1000,
229
- window_fn=th.hamming_window, n_mels=self.n_mels)
230
- )
231
- self.instancenorm = nn.InstanceNorm1d(self.n_mels)
232
-
233
- self.hparams = {'embedding_dim': self.spk_embed_dim, 'pooling_type': self.pooling_type , 'n_mels': self.n_mels}
234
-
235
- self.speaker_encoder = Speaker_Encoder(**dict(self.hparams))
236
-
237
- def extract_speaker_embedding(self, data):
238
- x = data.reshape(-1, data.size()[-1])
239
- x = self.mel_trans(x) + 1e-6
240
- x = x.log()
241
- x = self.instancenorm(x)
242
- x = self.speaker_encoder(x)
243
- return x
244
-
245
- def forward(self, x):
246
- x = self.extract_speaker_embedding(x)
247
- return x
 
6
 
7
  from .norm import ChannelwiseLayerNorm, GlobalLayerNorm
8
  from .cnns import Conv1D, ConvTrans1D, TCNBlock, TCNBlock_Spk, ResBlock
9
+ import warnings
10
 
11
+ # inference aux_len
 
 
 
 
 
 
 
12
 
13
  class SpEx_Plus(nn.Module):
14
  def __init__(self,
 
22
  Q=3,
23
  num_spks=101,
24
  spk_embed_dim=256,
 
 
25
  causal=False,
26
+ norm_type='gLN',
27
+ fusion_type='cat',
28
+ is_innorm=False,
29
  ):
30
  super(SpEx_Plus, self).__init__()
31
+
32
  # n x S => n x N x T, S = 4s*8000 = 32000
33
+
 
34
  self.L1 = L1
35
  self.L2 = L2
36
  self.L3 = L3
 
38
  self.encoder_1d_middle = Conv1D(1, N, L2, stride=L1 // 2, padding=0)
39
  self.encoder_1d_long = Conv1D(1, N, L3, stride=L1 // 2, padding=0)
40
  # before repeat blocks, always cLN
41
+
42
+ self.instancenorm = nn.InstanceNorm1d(N)
43
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  self.decoder_1d_short = ConvTrans1D(N, 1, kernel_size=L1, stride=L1 // 2, bias=True)
45
  self.decoder_1d_middle = ConvTrans1D(N, 1, kernel_size=L2, stride=L1 // 2, bias=True)
46
  self.decoder_1d_long = ConvTrans1D(N, 1, kernel_size=L3, stride=L1 // 2, bias=True)
47
  self.num_spks = num_spks
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  self.pred_linear = nn.Linear(spk_embed_dim, num_spks)
49
+ self.is_innorm = is_innorm
50
+
51
+ if causal and norm_type not in ["cgLN", "cLN"]:
52
+ norm_type = "cLN"
53
+ warnings.warn(
54
+ "In causal configuration cumulative layer normalization (cgLN)"
55
+ "or channel-wise layer normalization (chanLN) "
56
+ f"must be used. Changing {norm_type} to cLN"
57
+ )
58
+
59
+ self.speaker_encoder = Speaker_Model(
60
+ L1=L1,
61
+ L2=L2,
62
+ L3=L3,
63
+ N=N,
64
+ O=O,
65
+ P=P,
66
+ spk_embed_dim=spk_embed_dim,
67
+ )
68
+
69
+ self.extractor = Extractor(
70
+ L1=L1,
71
+ L2=L2,
72
+ L3=L3,
73
+ N=N,
74
+ B=B,
75
+ O=O,
76
+ P=P,
77
+ Q=Q,
78
+ num_spks=num_spks,
79
+ spk_embed_dim=spk_embed_dim,
80
+ causal=causal,
81
+ fusion_type=fusion_type,
82
+ norm_type=norm_type,
83
+ )
84
 
85
  def forward(self, x, aux, aux_len):
86
  if x.dim() >= 3:
 
90
  # when inference, only one utt
91
  if x.dim() == 1:
92
  x = th.unsqueeze(x, 0)
 
93
  # n x 1 x S => n x N x T
94
+
95
+
96
  w1 = F.relu(self.encoder_1d_short(x))
97
  T = w1.shape[-1]
98
  xlen1 = x.shape[-1]
 
100
  xlen3 = (T - 1) * (self.L1 // 2) + self.L3
101
  w2 = F.relu(self.encoder_1d_middle(F.pad(x, (0, xlen2 - xlen1), "constant", 0)))
102
  w3 = F.relu(self.encoder_1d_long(F.pad(x, (0, xlen3 - xlen1), "constant", 0)))
 
103
  # n x 3N x T
 
 
 
 
104
  # speaker encoder (share params from speech encoder)
105
+ aux_w1 = F.relu(self.encoder_1d_short(aux))
106
+ aux_T_shape = aux_w1.shape[-1]
107
+ aux_len1 = aux.shape[-1]
108
+ aux_len2 = (aux_T_shape - 1) * (self.L1 // 2) + self.L2
109
+ aux_len3 = (aux_T_shape - 1) * (self.L1 // 2) + self.L3
110
+ aux_w2 = F.relu(self.encoder_1d_middle(F.pad(aux, (0, aux_len2 - aux_len1), "constant", 0)))
111
+ aux_w3 = F.relu(self.encoder_1d_long(F.pad(aux, (0, aux_len3 - aux_len1), "constant", 0)))
112
 
113
+ aux = self.speaker_encoder(th.cat([aux_w1, aux_w2, aux_w3], 1), aux_len)
 
 
 
 
114
 
115
+ if self.is_innorm:
116
+ w1 = self.instancenorm(w1)
117
+ w2 = self.instancenorm(w2)
118
+ w3 = self.instancenorm(w3)
119
 
120
+ m1, m2, m3 = self.extractor(w1, w2, w3, aux)
121
 
122
+ S1 = w1 * m1
123
+ S2 = w2 * m2
124
+ S3 = w3 * m3
125
+
126
+ return self.decoder_1d_short(S1), self.decoder_1d_middle(S2)[:, :xlen1], self.decoder_1d_long(S3)[:, :xlen1], self.pred_linear(aux)
127
 
128
+ class Extractor(nn.Module):
129
+ def __init__(self,
130
+ L1=20,
131
+ L2=80,
132
+ L3=160,
133
+ N=256,
134
+ B=8,
135
+ O=256,
136
+ P=512,
137
+ Q=3,
138
+ num_spks=101,
139
+ spk_embed_dim=256,
140
+ causal=False,
141
+ fusion_type='cat',
142
+ norm_type='gLN',
143
+ ):
144
+ super(Extractor, self).__init__()
145
+ # n x N x T => n x O x T
146
+ self.ln = ChannelwiseLayerNorm(3*N)
147
+ self.proj = Conv1D(3*N, O, 1)
148
+ self.conv_block_1 = TCNBlock_Spk(spk_embed_dim=spk_embed_dim, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal, dilation=1,fusion_type=fusion_type,norm_type=norm_type)
149
+ self.conv_block_1_other = self._build_stacks(num_blocks=B, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal,norm_type=norm_type)
150
+ self.conv_block_2 = TCNBlock_Spk(spk_embed_dim=spk_embed_dim, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal, dilation=1,fusion_type=fusion_type,norm_type=norm_type)
151
+ self.conv_block_2_other = self._build_stacks(num_blocks=B, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal,norm_type=norm_type)
152
+ self.conv_block_3 = TCNBlock_Spk(spk_embed_dim=spk_embed_dim, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal, dilation=1,fusion_type=fusion_type,norm_type=norm_type)
153
+ self.conv_block_3_other = self._build_stacks(num_blocks=B, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal,norm_type=norm_type)
154
+ self.conv_block_4 = TCNBlock_Spk(spk_embed_dim=spk_embed_dim, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal, dilation=1,fusion_type=fusion_type,norm_type=norm_type)
155
+ self.conv_block_4_other = self._build_stacks(num_blocks=B, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal,norm_type=norm_type)
156
+ # n x O x T => n x N x T
157
+ self.mask1 = Conv1D(O, N, 1)
158
+ self.mask2 = Conv1D(O, N, 1)
159
+ self.mask3 = Conv1D(O, N, 1)
160
 
161
+ def _build_stacks(self, num_blocks, **block_kwargs):
162
+ """
163
+ Stack B numbers of TCN block, the first TCN block takes the speaker embedding
164
+ """
165
+ blocks = [
166
+ TCNBlock(**block_kwargs, dilation=(2**b))
167
+ for b in range(1,num_blocks)
168
+ ]
169
+ return nn.Sequential(*blocks)
170
 
171
+ def forward(self, w1, w2, w3, aux):
 
 
 
 
172
 
173
+ y = self.ln(th.cat([w1, w2, w3], 1))
174
+ # n x O x T
175
+ y = self.proj(y)
176
  y = self.conv_block_1(y, aux)
177
  y = self.conv_block_1_other(y)
178
  y = self.conv_block_2(y, aux)
 
186
  m1 = F.relu(self.mask1(y))
187
  m2 = F.relu(self.mask2(y))
188
  m3 = F.relu(self.mask3(y))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
+ return m1, m2, m3
 
 
 
 
 
191
 
192
 
193
  class Speaker_Model(nn.Module):
194
+ def __init__(self,
195
+ L1=20,
196
+ L2=80,
197
+ L3=160,
198
+ N=256,
199
+ O=256,
200
+ P=512,
201
+ spk_embed_dim=256,
202
+ ):
203
+ super(Speaker_Model, self).__init__()
204
+ self.L1 = L1
205
+ self.L2 = L2
206
+ self.L3 = L3
207
+ self.spk_encoder = nn.Sequential(
208
+ ChannelwiseLayerNorm(3*N),
209
+ Conv1D(3*N, O, 1),
210
+ ResBlock(O, O),
211
+ ResBlock(O, P),
212
+ ResBlock(P, P),
213
+ Conv1D(P, spk_embed_dim, 1),
214
+ )
215
+ def forward(self, aux, aux_len):
216
+ aux = self.spk_encoder(aux)
217
+ aux_T = (aux_len - self.L1) // (self.L1 // 2) + 1
218
+ aux_T = ((aux_T // 3) // 3) // 3
219
+ aux = th.sum(aux, -1)/aux_T.view(-1,1).float()
220
+ return aux
 
 
 
 
 
 
 
nnet/ResNet34.py DELETED
@@ -1,213 +0,0 @@
1
- #! /usr/bin/python
2
- # -*- encoding: utf-8 -*-
3
- '''
4
- Fast ResNet
5
- https://arxiv.org/pdf/2003.11982.pdf
6
- '''
7
-
8
- import torch
9
- import torch.nn as nn
10
- import torch.nn.functional as F
11
- from torch.nn import Parameter
12
- try:
13
- from .pooling import *
14
- except:
15
- from pooling import *
16
-
17
- class SEBasicBlock(nn.Module):
18
- expansion = 1
19
-
20
- def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8):
21
- super(SEBasicBlock, self).__init__()
22
- self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
23
- self.bn1 = nn.BatchNorm2d(planes)
24
- self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
25
- self.bn2 = nn.BatchNorm2d(planes)
26
- self.relu = nn.ReLU(inplace=True)
27
- self.se = SELayer(planes, reduction)
28
- self.downsample = downsample
29
- self.stride = stride
30
-
31
- def forward(self, x):
32
- residual = x
33
-
34
- out = self.conv1(x)
35
- out = self.relu(out)
36
- out = self.bn1(out)
37
-
38
- out = self.conv2(out)
39
- out = self.bn2(out)
40
- out = self.se(out)
41
-
42
- if self.downsample is not None:
43
- residual = self.downsample(x)
44
-
45
- out += residual
46
- out = self.relu(out)
47
- return out
48
-
49
-
50
- class SEBottleneck(nn.Module):
51
- expansion = 4
52
-
53
- def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8):
54
- super(SEBottleneck, self).__init__()
55
- self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
56
- self.bn1 = nn.BatchNorm2d(planes)
57
- self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
58
- padding=1, bias=False)
59
- self.bn2 = nn.BatchNorm2d(planes)
60
- self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
61
- self.bn3 = nn.BatchNorm2d(planes * 4)
62
- self.relu = nn.ReLU(inplace=True)
63
- self.se = SELayer(planes * 4, reduction)
64
- self.downsample = downsample
65
- self.stride = stride
66
-
67
- def forward(self, x):
68
- residual = x
69
-
70
- out = self.conv1(x)
71
- out = self.bn1(out)
72
- out = self.relu(out)
73
-
74
- out = self.conv2(out)
75
- out = self.bn2(out)
76
- out = self.relu(out)
77
-
78
- out = self.conv3(out)
79
- out = self.bn3(out)
80
- out = self.se(out)
81
-
82
- if self.downsample is not None:
83
- residual = self.downsample(x)
84
-
85
- out += residual
86
- out = self.relu(out)
87
-
88
- return out
89
-
90
-
91
- class SELayer(nn.Module):
92
- def __init__(self, channel, reduction=8):
93
- super(SELayer, self).__init__()
94
- self.avg_pool = nn.AdaptiveAvgPool2d(1)
95
- self.fc = nn.Sequential(
96
- nn.Linear(channel, channel // reduction),
97
- nn.ReLU(inplace=True),
98
- nn.Linear(channel // reduction, channel),
99
- nn.Sigmoid()
100
- )
101
-
102
- def forward(self, x):
103
- b, c, _, _ = x.size()
104
- y = self.avg_pool(x).view(b, c)
105
- y = self.fc(y).view(b, c, 1, 1)
106
- return x * y
107
-
108
-
109
- class ResNetSE(nn.Module):
110
- def __init__(self, block, layers, num_filters, embedding_dim, n_mels=80, pooling_type="TSP", **kwargs):
111
- super(ResNetSE, self).__init__()
112
-
113
- self.inplanes = num_filters[0]
114
- self.conv1 = nn.Conv2d(1, num_filters[0] , kernel_size=3, stride=(1, 1), padding=1,
115
- bias=False)
116
- self.bn1 = nn.BatchNorm2d(num_filters[0])
117
- self.relu = nn.ReLU(inplace=True)
118
-
119
- self.layer1 = self._make_layer(block, num_filters[0], layers[0])
120
- self.layer2 = self._make_layer(block, num_filters[1], layers[1], stride=(2, 2))
121
- self.layer3 = self._make_layer(block, num_filters[2], layers[2], stride=(2, 2))
122
- self.layer4 = self._make_layer(block, num_filters[3], layers[3], stride=(2, 2))
123
-
124
- out_dim = num_filters[3] * block.expansion * (n_mels//8)
125
-
126
- if pooling_type == "Temporal_Average_Pooling" or pooling_type == "TAP":
127
- self.pooling = Temporal_Average_Pooling()
128
- self.bn2 = nn.BatchNorm1d(out_dim)
129
- self.fc = nn.Linear(out_dim, embedding_dim)
130
- self.bn3 = nn.BatchNorm1d(embedding_dim)
131
-
132
- elif pooling_type == "Temporal_Statistics_Pooling" or pooling_type == "TSP":
133
- self.pooling = Temporal_Statistics_Pooling()
134
- self.bn2 = nn.BatchNorm1d(out_dim * 2)
135
- self.fc = nn.Linear(out_dim * 2, embedding_dim)
136
- self.bn3 = nn.BatchNorm1d(embedding_dim)
137
-
138
- elif pooling_type == "Self_Attentive_Pooling" or pooling_type == "SAP":
139
- self.pooling = Self_Attentive_Pooling(out_dim)
140
- self.bn2 = nn.BatchNorm1d(out_dim)
141
- self.fc = nn.Linear(out_dim, embedding_dim)
142
- self.bn3 = nn.BatchNorm1d(embedding_dim)
143
-
144
- elif pooling_type == "Attentive_Statistics_Pooling" or pooling_type == "ASP":
145
- self.pooling = Attentive_Statistics_Pooling(out_dim)
146
- self.bn2 = nn.BatchNorm1d(out_dim * 2)
147
- self.fc = nn.Linear(out_dim * 2, embedding_dim)
148
- self.bn3 = nn.BatchNorm1d(embedding_dim)
149
-
150
- else:
151
- raise ValueError('{} pooling type is not defined'.format(pooling_type))
152
-
153
-
154
- for m in self.modules():
155
- if isinstance(m, nn.Conv2d):
156
- nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
157
- elif isinstance(m, nn.BatchNorm2d):
158
- nn.init.constant_(m.weight, 1)
159
- nn.init.constant_(m.bias, 0)
160
-
161
- def _make_layer(self, block, planes, blocks, stride=1):
162
- downsample = None
163
- if stride != 1 or self.inplanes != planes * block.expansion:
164
- downsample = nn.Sequential(
165
- nn.Conv2d(self.inplanes, planes * block.expansion,
166
- kernel_size=1, stride=stride, bias=False),
167
- nn.BatchNorm2d(planes * block.expansion),
168
- )
169
-
170
- layers = []
171
- layers.append(block(self.inplanes, planes, stride, downsample))
172
- self.inplanes = planes * block.expansion
173
- for i in range(1, blocks):
174
- layers.append(block(self.inplanes, planes))
175
-
176
- return nn.Sequential(*layers)
177
-
178
- def forward(self, x):
179
- x = x.unsqueeze(1)
180
- x = self.conv1(x)
181
- x = self.bn1(x)
182
- x = self.relu(x)
183
-
184
- x = self.layer1(x)
185
- x = self.layer2(x)
186
- x = self.layer3(x)
187
- x = self.layer4(x)
188
-
189
- x = x.reshape(x.shape[0], -1, x.shape[-1])
190
-
191
- x = self.pooling(x)
192
- x = self.bn2(x)
193
- x = torch.flatten(x, 1)
194
- x = self.fc(x)
195
- x = self.bn3(x)
196
- return x
197
-
198
-
199
- def Speaker_Encoder(embedding_dim=256, **kwargs):
200
- # Number of filters
201
- num_filters = [32, 64, 128, 256]
202
- model = ResNetSE(SEBasicBlock, [3, 4, 6, 3], num_filters, embedding_dim, **kwargs)
203
- return model
204
-
205
- if __name__ == '__main__':
206
- model = Speaker_Encoder()
207
- total = sum([param.nelement() for param in model.parameters()])
208
- print(total/1e6)
209
- data = torch.randn(10, 80, 100)
210
- out = model(data)
211
- print(data.shape)
212
- print(out.shape)
213
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
nnet/__init__.py DELETED
File without changes
nnet/pooling.py DELETED
@@ -1,100 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from speechbrain.lobes.models.ECAPA_TDNN import AttentiveStatisticsPooling
5
-
6
- class Temporal_Average_Pooling(nn.Module):
7
- def __init__(self, **kwargs):
8
- """TAP
9
- Paper: Multi-Task Learning with High-Order Statistics for X-vector based Text-Independent Speaker Verification
10
- Link: https://arxiv.org/pdf/1903.12058.pdf
11
- """
12
- super(Temporal_Average_Pooling, self).__init__()
13
-
14
- def forward(self, x):
15
- """Computes Temporal Average Pooling Module
16
- Args:
17
- x (torch.Tensor): Input tensor (#batch, channels, frames).
18
- Returns:
19
- torch.Tensor: Output tensor (#batch, channels)
20
- """
21
- x = torch.mean(x, axis=2)
22
- return x
23
-
24
-
25
- class Temporal_Statistics_Pooling(nn.Module):
26
- def __init__(self, **kwargs):
27
- """TSP
28
- Paper: X-vectors: Robust DNN Embeddings for Speaker Recognition
29
- Link: http://www.danielpovey.com/files/2018_icassp_xvectors.pdf
30
- """
31
- super(Temporal_Statistics_Pooling, self).__init__()
32
-
33
- def forward(self, x):
34
- """Computes Temporal Statistics Pooling Module
35
- Args:
36
- x (torch.Tensor): Input tensor (#batch, channels, frames).
37
- Returns:
38
- torch.Tensor: Output tensor (#batch, channels*2)
39
- """
40
- mean = torch.mean(x, axis=2)
41
- var = torch.var(x, axis=2)
42
- x = torch.cat((mean, var), axis=1)
43
- return x
44
-
45
-
46
- ''' Self attentive weighted mean pooling.
47
- '''
48
- class Self_Attentive_Pooling(nn.Module):
49
- def __init__(self, dim, **kwargs):
50
- # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
51
- # attention dim = 128
52
- super(Self_Attentive_Pooling, self).__init__()
53
- self.linear1 = nn.Conv1d(dim, dim, kernel_size=1) # equals W and b in the paper
54
- self.linear2 = nn.Conv1d(dim, dim, kernel_size=1) # equals V and k in the paper
55
-
56
- def forward(self, x):
57
- # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
58
- alpha = torch.tanh(self.linear1(x))
59
- alpha = torch.softmax(self.linear2(alpha), dim=2)
60
- mean = torch.sum(alpha * x, dim=2)
61
- return mean
62
-
63
-
64
- ''' Attentive weighted mean and standard deviation pooling.
65
- '''
66
- class Attentive_Statistics_Pooling(nn.Module):
67
- def __init__(self, dim, **kwargs):
68
- # Use AttentiveStatisticsPooling and BatchNorm1d from speechbrain
69
- super(Attentive_Statistics_Pooling, self).__init__()
70
- self.pooling = AttentiveStatisticsPooling(dim)
71
-
72
- def forward(self, x):
73
- x = self.pooling(x)
74
- return x
75
-
76
- # class Attentive_Statistics_Pooling(nn.Module):
77
- # def __init__(self, dim, **kwargs):
78
- # # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
79
- # # attention dim = 128
80
- # super(Attentive_Statistics_Pooling, self).__init__()
81
- # self.linear1 = nn.Conv1d(dim, dim, kernel_size=1) # equals W and b in the paper
82
- # self.linear2 = nn.Conv1d(dim, dim, kernel_size=1) # equals V and k in the paper
83
- #
84
- # def forward(self, x):
85
- # # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
86
- # alpha = torch.tanh(self.linear1(x))
87
- # alpha = torch.softmax(self.linear2(alpha), dim=2)
88
- # mean = torch.sum(alpha * x, dim=2)
89
- # residuals = torch.sum(alpha * x ** 2, dim=2) - mean ** 2
90
- # std = torch.sqrt(residuals.clamp(min=1e-9))
91
- # return torch.cat([mean, std], dim=1)
92
-
93
-
94
-
95
- if __name__ == "__main__":
96
- data = torch.randn(10, 128, 100)
97
- pooling = Self_Attentive_Pooling(128)
98
- out = pooling(data)
99
- print(data.shape)
100
- print(out.shape)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
nnet/speaker_encoder.py DELETED
@@ -1,47 +0,0 @@
1
- import torch
2
- import torchaudio
3
- import torch.nn as nn
4
- from torch.nn import functional as F
5
- from .ResNet34 import Speaker_Encoder
6
-
7
-
8
- class Speaker_Model(torch.nn.Module):
9
- #class Speaker_Model(LightningModule):
10
- def __init__(self, pooling_type, spk_embed_dim, sample_rate, n_mels):
11
- super().__init__()
12
- # self.save_hyperparameters()
13
-
14
- self.pooling_type = pooling_type
15
- self.spk_embed_dim = spk_embed_dim
16
- self.sample_rate = sample_rate
17
- self.n_mels = n_mels
18
- sr = self.sample_rate
19
-
20
- self.mel_trans = torch.nn.Sequential(
21
- PreEmphasis(),
22
- torchaudio.transforms.MelSpectrogram(sample_rate=sr, n_fft=512,
23
- win_length=sr * 25 // 1000, hop_length=sr * 10 // 1000,
24
- window_fn=torch.hamming_window, n_mels=self.n_mels)
25
- )
26
- self.instancenorm = nn.InstanceNorm1d(self.n_mels)
27
-
28
- self.hparams = {'embedding_dim': self.spk_embed_dim, 'pooling_type': self.pooling_type , 'n_mels': self.n_mels}
29
-
30
- self.speaker_encoder = Speaker_Encoder(**dict(self.hparams))
31
-
32
- class PreEmphasis(torch.nn.Module):
33
- def __init__(self, coef: float = 0.97):
34
- super().__init__()
35
- self.coef = coef
36
- # make kernel
37
- # In pytorch, the convolution operation uses cross-correlation. So, filter is flipped.
38
- self.register_buffer(
39
- 'flipped_filter', torch.FloatTensor([-self.coef, 1.]).unsqueeze(0).unsqueeze(0)
40
- )
41
-
42
- def forward(self, inputs: torch.tensor) -> torch.tensor:
43
- assert len(inputs.size()) == 2, 'The number of dimensions of inputs tensor must be 2!'
44
- # reflect padding to match lengths of in/out
45
- inputs = inputs.unsqueeze(1)
46
- inputs = F.pad(inputs, (1, 0), 'reflect')
47
- return F.conv1d(inputs, self.flipped_filter).squeeze(1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
noises/00840.wav ADDED
Binary file (963 kB)

noises/022928.wav ADDED
Binary file (961 kB)

noises/04338.wav ADDED
Binary file (959 kB)

noises/046324.wav ADDED
Binary file (961 kB)

noises/093004.wav ADDED
Binary file (960 kB)

noises/11129.wav ADDED
Binary file (963 kB)

noises/133254.wav ADDED
Binary file (960 kB)

noises/30100.wav ADDED
Binary file (959 kB)

noises/30135.wav ADDED
Binary file (959 kB)

noises/30437.wav ADDED
Binary file (963 kB)

noises/30603.wav ADDED
Binary file (959 kB)
requirements.txt CHANGED
@@ -1,2 +1,7 @@
- soundfile
+ soundfile==0.12.1
  gradio
+ hydra-core==1.3.2
+ torch==1.11.0
+ pyloudnorm
+ numpy==1.24.4
+ librosa
temp_extracted.wav ADDED
Binary file (180 kB)

test_mix.wav ADDED
Binary file (180 kB)

test_output_mixture.wav ADDED
Binary file (180 kB)
utils/__init__.py DELETED
File without changes
utils/audio.py DELETED
@@ -1,124 +0,0 @@
- #!/usr/bin/env python
-
- import os
- import numpy as np
- import soundfile as sf
-
-
- def write_wav(fname, samps, sample_rate=16000, normalize=True):
-     """
-     Write wav files in float32, supports single/multi-channel
-     """
-     # WHAM and WHAMR mixture/clean data are float32, so scipy.io.wavfile
-     # (int16) cannot be used to read and write them. soundfile is used
-     # instead; it also reads int16 reference speech and outputs floats.
-     fdir = os.path.dirname(fname)
-     if fdir and not os.path.exists(fdir):
-         os.makedirs(fdir)
-     sf.write(fname, samps, sample_rate, subtype='FLOAT')
-
-
- def read_wav(fname, normalize=True, return_rate=False):
-     """
-     Read wave files (supports multi-channel)
-     """
-     samps, samp_rate = sf.read(fname)
-     if return_rate:
-         return samp_rate, samps
-     return samps
-
-
- def parse_scripts(scp_path, value_processor=lambda x: x, num_tokens=2):
-     """
-     Parse Kaldi's script (.scp) file.
-     If num_tokens >= 2, the function checks the token count per line.
-     """
-     scp_dict = dict()
-     line = 0
-     with open(scp_path, "r") as f:
-         for raw_line in f:
-             scp_tokens = raw_line.strip().split()
-             line += 1
-             if (num_tokens >= 2 and len(scp_tokens) != num_tokens) or len(scp_tokens) < 2:
-                 raise RuntimeError(
-                     "For {}, format error in line[{:d}]: {}".format(
-                         scp_path, line, raw_line))
-             if num_tokens == 2:
-                 key, value = scp_tokens
-             else:
-                 key, value = scp_tokens[0], scp_tokens[1:]
-             if key in scp_dict:
-                 raise ValueError("Duplicated key \'{0}\' exists in {1}".format(
-                     key, scp_path))
-             scp_dict[key] = value_processor(value)
-     return scp_dict
-
-
- class Reader(object):
-     """
-     Basic Reader Class
-     """
-     def __init__(self, scp_path, value_processor=lambda x: x):
-         self.index_dict = parse_scripts(
-             scp_path, value_processor=value_processor, num_tokens=2)
-         self.index_keys = list(self.index_dict.keys())
-
-     def _load(self, key):
-         # return path
-         return self.index_dict[key]
-
-     # number of utterances
-     def __len__(self):
-         return len(self.index_dict)
-
-     # avoid key errors
-     def __contains__(self, key):
-         return key in self.index_dict
-
-     # sequential index
-     def __iter__(self):
-         for key in self.index_keys:
-             yield key, self._load(key)
-
-     # random access, supports str/int as index
-     def __getitem__(self, index):
-         if type(index) not in [int, str]:
-             raise IndexError("Unsupported index type: {}".format(type(index)))
-         if type(index) == int:
-             # from int index to key
-             num_utts = len(self.index_keys)
-             if index >= num_utts or index < 0:
-                 raise KeyError(
-                     "Integer index out of range, {:d} vs {:d}".format(
-                         index, num_utts))
-             index = self.index_keys[index]
-         if index not in self.index_dict:
-             raise KeyError("Missing utterance {}!".format(index))
-         return self._load(index)
-
-
- class WaveReader(Reader):
-     """
-     Sequential/Random Reader for single-channel wave.
-     Format of wav.scp follows Kaldi's definition:
-         key1 /path/to/wav
-         ...
-     """
-     def __init__(self, wav_scp, sample_rate=None, normalize=True):
-         super(WaveReader, self).__init__(wav_scp)
-         self.samp_rate = sample_rate
-         self.normalize = normalize
-
-     def _load(self, key):
-         # return C x N or N
-         samp_rate, samps = read_wav(
-             self.index_dict[key], normalize=self.normalize, return_rate=True)
-         # if a sample rate is given, check it
-         if self.samp_rate is not None and samp_rate != self.samp_rate:
-             raise RuntimeError("SampleRate mismatch: {:d} vs {:d}".format(
-                 samp_rate, self.samp_rate))
-         return samps
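
For reference, wav.scp follows Kaldi's `key /path/to/wav` convention, and parse_scripts turns it into a dict. A tiny sketch of the round trip (keys and paths are made up; assumes the deleted utils.audio module is importable):

```python
import tempfile

with tempfile.NamedTemporaryFile("w", suffix=".scp", delete=False) as f:
    f.write("utt1 /data/utt1.wav\n")
    f.write("utt2 /data/utt2.wav\n")
    scp_path = f.name

table = parse_scripts(scp_path)  # from the deleted utils.audio
print(table)  # {'utt1': '/data/utt1.wav', 'utt2': '/data/utt2.wav'}
```
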
utils/dataset copy.py DELETED
@@ -1,284 +0,0 @@
- #!/usr/bin/env python
-
- import random
- import torch as th
- import numpy as np
-
- from torch.utils.data.dataloader import default_collate
- import torch.utils.data as dat
- from torch.nn.utils.rnn import pad_sequence
-
- from .audio import WaveReader
-
- import soundfile as sf
-
- # random_seed = 1453
- # random.seed(random_seed)
-
-
- def make_dataloader(train=True,
-                     utt_scp_file=None,
-                     spk_list=None,
-                     sample_rate=16000,
-                     num_workers=4,
-                     chunk_size=32000,
-                     batch_size=16):
-     dataset = Dataset(utt_scp_file=utt_scp_file,
-                       spk_list=spk_list,
-                       chunk_size=chunk_size,
-                       sample_rate=sample_rate)
-     return DataLoader(dataset,
-                       train=train,
-                       chunk_size=chunk_size,
-                       batch_size=batch_size,
-                       num_workers=num_workers)
-
-
- class Dataset(object):
-     """
-     Per-Utterance Loader
-     """
-     def __init__(self, utt_scp_file="", spk_list=None, chunk_size=32000, sample_rate=8000):
-         self.sample_rate = sample_rate
-         self.spk_list = self._load_spk(spk_list)
-
-         self.seg_least = int(chunk_size // 2)
-
-         # self.mix = WaveReader(mix_scp, sample_rate=sample_rate)
-         # self.ref = WaveReader(ref_scp, sample_rate=sample_rate)
-         # self.aux = WaveReader(aux_scp, sample_rate=sample_rate)
-
-         with open(utt_scp_file, 'r') as f:
-             lines = f.readlines()
-         self.data = []
-         for line in lines:
-             parts = line.strip().split()
-             sentence_id = parts[0]
-             sentence_path = parts[1]
-             data_len = parts[2]
-             spk_id = (sentence_id.split('-')[0])[1:5]
-             self.data.append((sentence_id, spk_id, sentence_path, data_len))
-
-         if not self.data:
-             raise ValueError("No valid lines found in the input file.")
-         self.total_lines = len(self.data)
-
-     def _load_spk(self, spk_list_path):
-         if spk_list_path is None:
-             return []
-         lines = open(spk_list_path).readlines()
-         new_lines = []
-         for line in lines:
-             new_lines.append(line.strip())
-         return new_lines
-
-     def __len__(self):
-         return len(self.data)
-
-     def _get_segment_start_stop(self, seg_len, length):
-         if seg_len is not None:
-             start = random.randint(0, length - seg_len)
-             stop = start + seg_len
-         else:
-             start = 0
-             stop = None
-         return start, stop
-
-     def _mix(self, sources_list):
-         # if self.seg_len:
-         #     mix_length = self.seg_len
-         # else:
-         #     mix_length = self.common_length
-         mix_length = self.common_length
-         mixture = np.zeros(mix_length)
-         for i, _ in enumerate(sources_list):
-             mixture += sources_list[i]
-         return mixture
-
-     def __getitem__(self, idx):
-         source_id, source_spk, source_path, all_source_length = self.data[idx]
-         all_source_length = int(all_source_length)
-         spk_idx = self.spk_list.index(source_spk)
-
-         # pick an interfering utterance from a different speaker
-         other_counter = 0
-         while True:
-             random_idx = np.random.randint(0, self.total_lines)
-             if self.data[random_idx][1] != source_spk:
-                 other_id, other_spk, other_path, other_length = self.data[random_idx]
-                 other_length = int(other_length)
-                 if other_length > self.seg_least:
-                     break
-             other_counter += 1
-             if other_counter >= self.total_lines:
-                 raise ValueError("All data are too short to mix")
-
-         # pick an enrollment utterance from the same speaker
-         enroll_counter = 0
-         while True:
-             random_idx = np.random.randint(0, self.total_lines)
-             if self.data[random_idx][1] == source_spk:
-                 enroll_id, enroll_spk, enroll_path, all_enroll_length = self.data[random_idx]
-                 all_enroll_length = int(all_enroll_length)
-                 if all_enroll_length > self.seg_least:
-                     break
-             enroll_counter += 1
-             if enroll_counter >= self.total_lines:
-                 raise ValueError("All data are too short to enroll")
-
-         # lengths = [all_source_length, other_length]
-
-         # crop the longer utterance down to the shorter one's length
-         if all_source_length >= other_length:
-             self.common_length = other_length
-             start, stop = self._get_segment_start_stop(other_length, all_source_length)
-             source_tmp, _ = sf.read(source_path, dtype="float32", start=start, stop=stop)
-             other_tmp, _ = sf.read(other_path, dtype="float32")
-         else:
-             self.common_length = all_source_length
-             start, stop = self._get_segment_start_stop(all_source_length, other_length)
-             source_tmp, _ = sf.read(source_path, dtype="float32")
-             other_tmp, _ = sf.read(other_path, dtype="float32", start=start, stop=stop)
-
-         # randomly pick one channel from each multi-channel recording
-         source = source_tmp[:, np.random.randint(0, source_tmp.shape[1])]
-         other = other_tmp[:, np.random.randint(0, other_tmp.shape[1])]
-
-         mixture = self._mix([source, other])
-         mixture = mixture.astype(np.float32)
-
-         enroll_tmp, _ = sf.read(enroll_path, dtype="float32")
-         enroll = enroll_tmp[:, np.random.randint(0, enroll_tmp.shape[1])]
-
-         return {
-             "mix": mixture,
-             "ref": source,
-             "aux": enroll,
-             "aux_len": len(enroll),
-             "spk_idx": spk_idx
-         }
-
-
- class ChunkSplitter(object):
-     """
-     Split utterance into small chunks
-     """
-     def __init__(self, chunk_size, train=True, least=16000):
-         self.chunk_size = chunk_size
-         self.least = least
-         self.train = train
-
-     def _make_chunk(self, eg, s):
-         """
-         Make a chunk instance, which contains:
-             "mix": ndarray,
-             "ref": [ndarray...]
-         """
-         chunk = dict()
-         chunk["mix"] = eg["mix"][s:s + self.chunk_size]
-         chunk["ref"] = eg["ref"][s:s + self.chunk_size]
-         chunk["aux"] = eg["aux"]
-         chunk["aux_len"] = eg["aux_len"]
-         chunk["valid_len"] = int(self.chunk_size)
-         chunk["spk_idx"] = eg["spk_idx"]
-         return chunk
-
-     def split(self, eg):
-         N = eg["mix"].size
-         # too short: throw it away
-         if N < self.least:
-             return []
-         chunks = []
-         # pad with zeros
-         if N < self.chunk_size:
-             P = self.chunk_size - N
-             chunk = dict()
-             chunk["mix"] = np.pad(eg["mix"], (0, P), "constant")
-             chunk["ref"] = np.pad(eg["ref"], (0, P), "constant")
-             chunk["aux"] = eg["aux"]
-             chunk["aux_len"] = eg["aux_len"]
-             chunk["valid_len"] = int(N)
-             chunk["spk_idx"] = eg["spk_idx"]
-             chunks.append(chunk)
-         else:
-             # randomly select the start point for training
-             s = random.randint(0, N % self.least) if self.train else 0
-             while True:
-                 if s + self.chunk_size > N:
-                     break
-                 chunk = self._make_chunk(eg, s)
-                 chunks.append(chunk)
-                 s += self.least
-         return chunks
-
-
- class DataLoader(object):
-     """
-     Online dataloader at chunk level
-     """
-     def __init__(self,
-                  dataset,
-                  num_workers=4,
-                  chunk_size=32000,
-                  batch_size=16,
-                  train=True):
-         self.batch_size = batch_size
-         self.train = train
-         self.splitter = ChunkSplitter(chunk_size,
-                                       train=train,
-                                       least=chunk_size // 2)
-         # just return a batch of egs, supports multiple workers
-         self.eg_loader = dat.DataLoader(dataset,
-                                         batch_size=batch_size // 2,
-                                         num_workers=num_workers,
-                                         shuffle=train,
-                                         collate_fn=self._collate)
-
-     def _collate(self, batch):
-         """
-         Split utterances online
-         """
-         chunk = []
-         for eg in batch:
-             chunk += self.splitter.split(eg)
-         return chunk
-
-     def _pad_aux(self, chunk_list):
-         lens_list = []
-         for chunk_item in chunk_list:
-             lens_list.append(chunk_item['aux_len'])
-         max_len = np.max(lens_list)
-
-         for idx in range(len(chunk_list)):
-             P = max_len - len(chunk_list[idx]["aux"])
-             chunk_list[idx]["aux"] = np.pad(chunk_list[idx]["aux"], (0, P), "constant")
-         return chunk_list
-
-     def _merge(self, chunk_list):
-         """
-         Merge chunk list into mini-batches
-         """
-         N = len(chunk_list)
-         if self.train:
-             random.shuffle(chunk_list)
-         blist = []
-         for s in range(0, N - self.batch_size + 1, self.batch_size):
-             # padding aux info
-             # self._pad_aux(chunk_list[s:s + self.batch_size])
-             batch = default_collate(self._pad_aux(chunk_list[s:s + self.batch_size]))
-             blist.append(batch)
-         rn = N % self.batch_size
-         return blist, chunk_list[-rn:] if rn else []
-
-     def __iter__(self):
-         chunk_list = []
-         for chunks in self.eg_loader:
-             chunk_list += chunks
-             batch, chunk_list = self._merge(chunk_list)
-             for obj in batch:
-                 yield obj
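
ChunkSplitter's policy above: drop utterances shorter than `least`, zero-pad anything shorter than `chunk_size` into a single chunk, otherwise slide a window with hop `least` (training additionally randomizes the start). A standalone re-implementation of the non-random path, just to illustrate the resulting (start, valid_len) pairs:

```python
def split_lengths(n, chunk_size=32000, least=16000):
    """Mirror ChunkSplitter.split for train=False, returning (start, valid_len)."""
    if n < least:
        return []            # too short: discarded
    if n < chunk_size:
        return [(0, n)]      # one zero-padded chunk with valid length n
    return [(s, chunk_size) for s in range(0, n - chunk_size + 1, least)]


print(split_lengths(12000))  # []
print(split_lengths(20000))  # [(0, 20000)]
print(split_lengths(80000))  # [(0, 32000), (16000, 32000), (32000, 32000), (48000, 32000)]
```
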
utils/dataset.py DELETED
@@ -1,402 +0,0 @@
- #!/usr/bin/env python
-
- import random
- import torch as th
- import numpy as np
-
- from torch.utils.data.dataloader import default_collate
- import torch.utils.data as dat
- from torch.nn.utils.rnn import pad_sequence
-
- from .audio import WaveReader
-
- import soundfile as sf
-
- # random_seed = 1453
- # random.seed(random_seed)
-
- # "aux_len": all_enroll_length,
-
- EPS = 1e-10
-
-
- def make_dataloader(train=True,
-                     mix_scp_file=None,
-                     enroll_scp_file=None,
-                     noise_scp_file=None,
-                     spk_list=None,
-                     sample_rate=16000,
-                     num_workers=4,
-                     chunk_size=32000,
-                     batch_size=16):
-     dataset = Dataset(mix_scp_file=mix_scp_file,
-                       enroll_scp_file=enroll_scp_file,
-                       noise_scp_file=noise_scp_file,
-                       spk_list=spk_list,
-                       chunk_size=chunk_size,
-                       sample_rate=sample_rate)
-     return DataLoader(dataset,
-                       train=train,
-                       chunk_size=chunk_size,
-                       batch_size=batch_size,
-                       num_workers=num_workers)
-
-
- class Dataset(object):
-     """
-     Per-Utterance Loader
-     """
-     def __init__(self, mix_scp_file="", enroll_scp_file="", noise_scp_file="", spk_list=None, chunk_size=32000, sample_rate=8000):
-         self.sample_rate = sample_rate
-         self.spk_list = self._load_spk(spk_list)
-
-         self.seg_least = int(chunk_size // 2)
-
-         with open(mix_scp_file, 'r') as f:
-             lines = f.readlines()
-         self.data = []
-         for line in lines:
-             parts = line.strip().split()
-             sentence_id = parts[0]
-             sentence_path = parts[1]
-             data_len = parts[2]
-             spk_id = (sentence_id.split('-')[0])[1:5]
-             self.data.append((sentence_id, spk_id, sentence_path, data_len))
-
-         with open(enroll_scp_file, 'r') as f:
-             enroll_lines = f.readlines()
-         self.enroll_data = []
-         for line in enroll_lines:
-             parts = line.strip().split()
-             sentence_id = parts[0]
-             sentence_path = parts[1]
-             data_len = parts[2]
-             spk_id = (sentence_id.split('-')[0])[1:5]
-             self.enroll_data.append((sentence_id, spk_id, sentence_path, data_len))
-
-         with open(noise_scp_file, 'r') as f:
-             noise_lines = f.readlines()
-         self.noise_data = []
-         for line in noise_lines:
-             parts = line.strip().split()
-             sentence_id = parts[0]
-             sentence_path = parts[1]
-             data_len = parts[2]
-             # spk_id = (sentence_id.split('-')[0])[1:5]
-             self.noise_data.append((sentence_id, sentence_path, data_len))
-
-         self.total_lines = len(self.data)
-         self.total_enroll = self._enroll_data_len()
-         self.total_noise = self._noise_data_len()
-
-         if not self.data:
-             raise ValueError("No valid lines found in the input file.")
-
-     def _load_spk(self, spk_list_path):
-         if spk_list_path is None:
-             return []
-         lines = open(spk_list_path).readlines()
-         new_lines = []
-         for line in lines:
-             new_lines.append(line.strip())
-         return new_lines
-
-     def __len__(self):
-         return len(self.data)
-
-     def _enroll_data_len(self):
-         return len(self.enroll_data)
-
-     def _noise_data_len(self):
-         return len(self.noise_data)
-
-     def _get_segment_start_stop(self, seg_len, length):
-         if seg_len is not None:
-             start = random.randint(0, length - seg_len)
-             stop = start + seg_len
-         else:
-             start = 0
-             stop = None
-         return start, stop
-
-     def _mix(self, sources_list):
-         # if self.seg_len:
-         #     mix_length = self.seg_len
-         # else:
-         #     mix_length = self.common_length
-         mix_length = self.common_length
-         mixture = np.zeros(mix_length)
-         for i, _ in enumerate(sources_list):
-             mixture += sources_list[i]
-         return mixture
-
-     def __getitem__(self, idx):
-         source_id, source_spk, source_path, all_source_length = self.data[idx]
-         all_source_length = int(all_source_length)
-         spk_idx = self.spk_list.index(source_spk)
-
-         # pick an interfering utterance from a different speaker
-         other_counter = 0
-         while True:
-             random_idx = np.random.randint(0, self.total_lines)
-             if self.data[random_idx][1] != source_spk:
-                 other_id, other_spk, other_path, other_length = self.data[random_idx]
-                 other_length = int(other_length)
-                 if other_length > self.seg_least:
-                     break
-             other_counter += 1
-             if other_counter >= self.total_lines:
-                 raise ValueError("All data are too short to mix")
-
-         # crop the longer utterance down to the shorter one's length
-         if all_source_length >= other_length:
-             self.common_length = other_length
-             start, stop = self._get_segment_start_stop(self.common_length, all_source_length)
-             source_tmp, _ = sf.read(source_path, dtype="float32", start=start, stop=stop)
-             other_tmp, _ = sf.read(other_path, dtype="float32")
-         else:
-             self.common_length = all_source_length
-             start, stop = self._get_segment_start_stop(self.common_length, other_length)
-             source_tmp, _ = sf.read(source_path, dtype="float32")
-             other_tmp, _ = sf.read(other_path, dtype="float32", start=start, stop=stop)
-
-         # pick a noise clip at least as long as the common length
-         noise_counter = 0
-         while True:
-             random_idx = np.random.randint(0, self.total_noise)
-             noise_id, noise_path, all_noise_length = self.noise_data[random_idx]
-             all_noise_length = int(all_noise_length)
-             if all_noise_length >= self.common_length:
-                 break
-             noise_counter += 1
-             if noise_counter >= self.total_noise:
-                 raise ValueError("No noise clip is long enough to use")
-
-         # pick an enrollment utterance from the same speaker
-         enroll_counter = 0
-         while True:
-             random_idx = np.random.randint(0, self.total_enroll)
-             if self.enroll_data[random_idx][1] == source_spk:
-                 enroll_id, enroll_spk, enroll_path, all_enroll_length = self.enroll_data[random_idx]
-                 all_enroll_length = int(all_enroll_length)
-                 break
-             enroll_counter += 1
-             if enroll_counter >= self.total_enroll:
-                 raise ValueError("No utterance can serve as the enrollment")
-
-         # randomly pick one channel from each multi-channel recording
-         source = source_tmp[:, np.random.randint(0, source_tmp.shape[1])]
-         other = other_tmp[:, np.random.randint(0, other_tmp.shape[1])]
-
-         noise_start, noise_stop = self._get_segment_start_stop(self.common_length, all_noise_length)
-         noise, _ = sf.read(noise_path, dtype="float32", start=noise_start, stop=noise_stop)  # single channel?
-         # noise = noise_tmp[:, np.random.randint(0, noise_tmp.shape[1])]
-         # other_noise = self._mix([other, noise])
-         desired_snr = np.random.uniform(-4, 4)  # target SNR in dB
-         current_snr = 10 * np.log10(np.mean(source ** 2) / (np.mean(noise ** 2) + EPS) + EPS)
-         # scale the noise amplitude so the source/noise SNR becomes desired_snr
-         scale_factor = 10 ** ((current_snr - desired_snr) / 20)
-         scaled_noise = noise * scale_factor
-
-         # sanity check (unused): SNR after scaling should equal desired_snr
-         snr = 10 * np.log10(np.mean(source ** 2) / (np.mean(scaled_noise ** 2) + EPS) + EPS)
-         mixture = self._mix([source, other, scaled_noise])
-
-         mixture = mixture.astype(np.float32)
-
-         enroll_tmp, _ = sf.read(enroll_path, dtype="float32")
-         enroll = enroll_tmp[:, np.random.randint(0, enroll_tmp.shape[1])]
-
-         return {
-             "mix": mixture,
-             "ref": source,
-             "aux": enroll,
-             "aux_len": all_enroll_length,
-             "spk_idx": spk_idx
-         }
-
-
- class ChunkSplitter(object):
-     """
-     Split utterance into small chunks
-     """
-     def __init__(self, chunk_size, train=True, least=16000):
-         self.chunk_size = chunk_size
-         self.least = least
-         self.train = train
-
-     def _make_chunk(self, eg, s):
-         """
-         Make a chunk instance, which contains:
-             "mix": ndarray,
-             "ref": [ndarray...]
-         """
-         chunk = dict()
-         chunk["mix"] = eg["mix"][s:s + self.chunk_size]
-         chunk["ref"] = eg["ref"][s:s + self.chunk_size]
-         chunk["aux"] = eg["aux"]
-         chunk["aux_len"] = eg["aux_len"]
-         chunk["valid_len"] = int(self.chunk_size)
-         chunk["spk_idx"] = eg["spk_idx"]
-         return chunk
-
-     def split(self, eg):
-         N = eg["mix"].size
-         # too short: throw it away
-         if N < self.least:
-             return []
-         chunks = []
-         # pad with zeros
-         if N < self.chunk_size:
-             P = self.chunk_size - N
-             chunk = dict()
-             chunk["mix"] = np.pad(eg["mix"], (0, P), "constant")
-             chunk["ref"] = np.pad(eg["ref"], (0, P), "constant")
-             chunk["aux"] = eg["aux"]
-             chunk["aux_len"] = eg["aux_len"]
-             chunk["valid_len"] = int(N)
-             chunk["spk_idx"] = eg["spk_idx"]
-             chunks.append(chunk)
-         # else:
-         #     # randomly select the start point for training
-         #     s = random.randint(0, N % self.least) if self.train else 0
-         #     while True:
-         #         if s + self.chunk_size > N:
-         #             break
-         #         chunk = self._make_chunk(eg, s)
-         #         chunks.append(chunk)
-         #         s += self.least
-         #     return chunks
-         else:
-             if self.train:
-                 # randomly select a single start point for training
-                 s = random.randint(0, N - self.chunk_size)
-                 chunk = self._make_chunk(eg, s)
-                 chunks.append(chunk)
-             else:
-                 s = 0
-                 while True:
-                     if s + self.chunk_size > N:
-                         break
-                     chunk = self._make_chunk(eg, s)
-                     chunks.append(chunk)
-                     s += self.least
-         return chunks
-
-
- class DataLoader(object):
-     """
-     Online dataloader at chunk level
-     """
-     def __init__(self,
-                  dataset,
-                  num_workers=4,
-                  chunk_size=32000,
-                  batch_size=16,
-                  train=True):
-         self.batch_size = batch_size
-         self.train = train
-         self.splitter = ChunkSplitter(chunk_size,
-                                       train=train,
-                                       least=chunk_size // 2)
-         # just return a batch of egs, supports multiple workers
-         self.eg_loader = dat.DataLoader(dataset,
-                                         batch_size=batch_size // 2,
-                                         num_workers=num_workers,
-                                         shuffle=train,
-                                         collate_fn=self._collate)
-
-     def _collate(self, batch):
-         """
-         Split utterances online
-         """
-         chunk = []
-         for eg in batch:
-             chunk += self.splitter.split(eg)
-         return chunk
-
-     def _pad_aux(self, chunk_list):
-         lens_list = []
-         for chunk_item in chunk_list:
-             lens_list.append(chunk_item['aux_len'])
-         max_len = np.max(lens_list)
-         # pad with zeros
-         for idx in range(len(chunk_list)):
-             P = max_len - len(chunk_list[idx]["aux"])
-             chunk_list[idx]["aux"] = np.pad(chunk_list[idx]["aux"], (0, P), "constant")
-         # # circular padding
-         # for idx in range(len(chunk_list)):
-         #     P = max_len - len(chunk_list[idx]["aux"])
-         #     original_aux_len = len(chunk_list[idx]["aux"])
-         #     # pad by cycling through the original sentence's content
-         #     for i in range(P):
-         #         chunk_list[idx]["aux"].append(chunk_list[idx]["aux"][i % original_aux_len])
-         return chunk_list
-
-     def _merge(self, chunk_list):
-         """
-         Merge chunk list into mini-batches
-         """
-         N = len(chunk_list)
-         if self.train:
-             random.shuffle(chunk_list)
-         blist = []
-         for s in range(0, N - self.batch_size + 1, self.batch_size):
-             # padding aux info
-             # self._pad_aux(chunk_list[s:s + self.batch_size])
-             batch = default_collate(self._pad_aux(chunk_list[s:s + self.batch_size]))
-             blist.append(batch)
-         rn = N % self.batch_size
-         return blist, chunk_list[-rn:] if rn else []
-
-     def __iter__(self):
-         chunk_list = []
-         for chunks in self.eg_loader:
-             chunk_list += chunks
-             batch, chunk_list = self._merge(chunk_list)
-             for obj in batch:
-                 yield obj
-
-
- # def snr_xy(x, y):
- #     return 10 * np.log10(np.mean(x ** 2) / (np.mean(y ** 2) + EPS) + EPS)
-
- # def main(args):
- #     wham_noise_dir = args.wham_dir
- #     # Get train dir
- #     subdir = os.path.join(wham_noise_dir, 'tr')
- #     # List files in that dir
- #     sound_paths = glob.glob(os.path.join(subdir, '**/*.wav'),
- #                             recursive=True)
- #     # Avoid running this script if it has already been run
- #     if len(sound_paths) == 60000:
- #         print("It appears that augmented files have already been generated.\n"
- #               "Skipping data augmentation.")
- #         return
- #     elif len(sound_paths) != 20000:
- #         print("It appears that augmented files have not been generated properly.\n"
- #               "Resuming augmentation.")
- #         originals = [x for x in sound_paths if 'sp' not in x]
- #         to_be_removed_08 = [x.replace('sp08', '') for x in sound_paths if 'sp08' in x]
- #         to_be_removed_12 = [x.replace('sp12', '') for x in sound_paths if 'sp12' in x]
- #         sound_paths_08 = list(set(originals) - set(to_be_removed_08))
- #         sound_paths_12 = list(set(originals) - set(to_be_removed_12))
- #         augment_noise(sound_paths_08, 0.8)
- #         augment_noise(sound_paths_12, 1.2)
- #     else:
- #         print(f'Augmenting {subdir} files')
- #         # Transform audio speed
- #         augment_noise(sound_paths, 0.8)
- #         augment_noise(sound_paths, 1.2)
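
The noise-scaling step in `__getitem__` uses a standard identity: multiplying the noise amplitude by 10 ** ((current_snr - desired_snr) / 20) shifts the source-to-noise ratio by exactly (desired_snr - current_snr) dB, landing on the target. A self-contained check with random stand-in signals:

```python
import numpy as np

EPS = 1e-10
rng = np.random.default_rng(0)
source = rng.standard_normal(16000)
noise = 0.1 * rng.standard_normal(16000)


def snr_db(x, n):
    return 10 * np.log10(np.mean(x ** 2) / (np.mean(n ** 2) + EPS) + EPS)


desired_snr = -2.0
current_snr = snr_db(source, noise)
scaled_noise = noise * 10 ** ((current_snr - desired_snr) / 20)
print(round(snr_db(source, scaled_noise), 3))  # -2.0
```
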
utils/load_obj.py DELETED
@@ -1,18 +0,0 @@
- #!/usr/bin/env python
-
- import torch as th
-
-
- def load_obj(obj, device):
-     """
-     Move every tensor inside obj to the given device
-     """
-     def cuda(obj):
-         return obj.to(device) if isinstance(obj, th.Tensor) else obj
-
-     if isinstance(obj, dict):
-         return {key: load_obj(obj[key], device) for key in obj}
-     elif isinstance(obj, list):
-         return [load_obj(val, device) for val in obj]
-     else:
-         return cuda(obj)
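
load_obj recurses through dicts and lists and moves only the torch.Tensor leaves to the target device, leaving other values (ints, lists of ints) untouched. Hypothetical usage with a batch shaped like the ones the dataloaders above produce (assumes the deleted module is importable):

```python
import torch as th

batch = {"mix": th.zeros(2, 32000), "aux_len": [16000, 16000], "spk_idx": 3}
device = th.device("cuda" if th.cuda.is_available() else "cpu")
batch = load_obj(batch, device)  # from the deleted utils.load_obj
print(batch["mix"].device, batch["spk_idx"])  # e.g. cuda:0 3
```
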
utils/logger.py DELETED
@@ -1,22 +0,0 @@
- #!/usr/bin/env python
-
- import logging
-
-
- def get_logger(
-         name,
-         format_str="%(asctime)s [%(pathname)s:%(lineno)s - %(levelname)s ] %(message)s",
-         date_format="%Y-%m-%d %H:%M:%S",
-         file=False):
-     """
-     Get a Python logger instance; name doubles as the log-file path when file=True
-     """
-     logger = logging.getLogger(name)
-     logger.setLevel(logging.INFO)
-     # file or console
-     handler = logging.StreamHandler() if not file else logging.FileHandler(name)
-     handler.setLevel(logging.INFO)
-     formatter = logging.Formatter(fmt=format_str, datefmt=date_format)
-     handler.setFormatter(formatter)
-     logger.addHandler(handler)
-     return logger
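
Hypothetical usage of get_logger (assumes the deleted utils.logger is importable); note that with file=True the logger name doubles as the log-file path:

```python
console = get_logger(__name__)                # logs to the console
console.info("training started")

filelog = get_logger("train.log", file=True)  # logs to the file train.log
filelog.info("epoch 1 done")
```
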
utils/sisdr.py DELETED
@@ -1,23 +0,0 @@
- #!/usr/bin/env python
-
- import numpy as np
-
-
- def sisdr(x, s, remove_dc=True):
-     """
-     Compute SI-SDR
-     x: extracted signal
-     s: reference signal (ground truth)
-     """
-     def vec_l2norm(x):
-         return np.linalg.norm(x, 2)
-
-     if remove_dc:
-         x_zm = x - np.mean(x)
-         s_zm = s - np.mean(s)
-         t = np.inner(x_zm, s_zm) * s_zm / vec_l2norm(s_zm) ** 2
-         n = x_zm - t
-     else:
-         t = np.inner(x, s) * s / vec_l2norm(s) ** 2
-         n = x - t
-     return 20 * np.log10(vec_l2norm(t) / vec_l2norm(n))
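
A sanity check of sisdr: for x = s + e with independent noise e, the projection t approximately recovers s and the residual n approximately recovers e, so the score tracks the signal-to-noise power ratio (assumes the deleted utils.sisdr is importable):

```python
import numpy as np

rng = np.random.default_rng(0)
s = rng.standard_normal(16000)
print(round(sisdr(s + 0.10 * rng.standard_normal(16000), s), 1))  # ~ 20.0 dB
print(round(sisdr(s + 0.01 * rng.standard_normal(16000), s), 1))  # ~ 40.0 dB
```
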
utils/timer.py DELETED
@@ -1,17 +0,0 @@
- #!/usr/bin/env python
-
- import time
-
-
- class Timer(object):
-     """
-     A timer to record the elapsed time
-     """
-     def __init__(self):
-         self.reset()
-
-     def reset(self):
-         self.start = time.time()
-
-     def elapsed(self):
-         # elapsed time in minutes
-         return (time.time() - self.start) / 60
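
Hypothetical usage of Timer (assumes the deleted utils.timer is importable); elapsed() reports minutes, not seconds:

```python
timer = Timer()
# ... run an epoch ...
print(f"epoch took {timer.elapsed():.2f} min")
```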