Spaces:
Paused
Paused
Commit ·
1af34cd
0
Parent(s):
first commit
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .dockerignore +5 -0
- .gitattributes +35 -0
- .gitignore +25 -0
- Dockerfile +24 -0
- README.md +129 -0
- examples/clean_unet/run.sh +181 -0
- examples/clean_unet/step_1_prepare_data.py +201 -0
- examples/clean_unet/step_2_train_model.py +419 -0
- examples/clean_unet/step_3_evaluation.py +6 -0
- examples/clean_unet/yaml/config.yaml +14 -0
- examples/conv_tasnet/run.sh +154 -0
- examples/conv_tasnet/step_1_prepare_data.py +164 -0
- examples/conv_tasnet/step_2_train_model.py +509 -0
- examples/conv_tasnet/yaml/config.yaml +28 -0
- examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_clean_emotional_speech.py +90 -0
- examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_clean_read_speech.py +129 -0
- examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_demand.py +71 -0
- examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_impulse_responses.py +93 -0
- examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_noise.py +77 -0
- examples/data_preprocess/dns_challenge_to_8k/process_musan.py +8 -0
- examples/data_preprocess/ms_snsd_to_8k/process_ms_snsd.py +70 -0
- examples/dfnet/run.sh +156 -0
- examples/dfnet/step_1_prepare_data.py +164 -0
- examples/dfnet/step_2_train_model.py +461 -0
- examples/dfnet/yaml/config.yaml +74 -0
- examples/dfnet2/run.sh +164 -0
- examples/dfnet2/step_1_prepare_data.py +164 -0
- examples/dfnet2/step_2_train_model.py +469 -0
- examples/dfnet2/yaml/config.yaml +75 -0
- examples/dtln/run.sh +171 -0
- examples/dtln/step_1_prepare_data.py +164 -0
- examples/dtln/step_2_train_model.py +437 -0
- examples/dtln/yaml/config-1024.yaml +29 -0
- examples/dtln/yaml/config-256.yaml +29 -0
- examples/dtln/yaml/config-512.yaml +29 -0
- examples/dtln_mp3_to_wav/run.sh +168 -0
- examples/dtln_mp3_to_wav/step_1_prepare_data.py +127 -0
- examples/dtln_mp3_to_wav/step_2_train_model.py +445 -0
- examples/dtln_mp3_to_wav/yaml/config-1024.yaml +29 -0
- examples/dtln_mp3_to_wav/yaml/config-256.yaml +29 -0
- examples/dtln_mp3_to_wav/yaml/config-512.yaml +29 -0
- examples/frcrn/run.sh +159 -0
- examples/frcrn/step_1_prepare_data.py +164 -0
- examples/frcrn/step_2_train_model.py +457 -0
- examples/frcrn/yaml/config-10.yaml +31 -0
- examples/frcrn/yaml/config-14.yaml +31 -0
- examples/frcrn/yaml/config-20.yaml +31 -0
- examples/frcrn_mp3_to_wav/run.sh +156 -0
- examples/frcrn_mp3_to_wav/step_1_prepare_data.py +127 -0
- examples/frcrn_mp3_to_wav/step_2_train_model.py +442 -0
.dockerignore
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
.git/
|
| 3 |
+
.idea/
|
| 4 |
+
|
| 5 |
+
/examples/
|
.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
.gradio/
|
| 3 |
+
.git/
|
| 4 |
+
.idea/
|
| 5 |
+
|
| 6 |
+
**/evaluation_audio/
|
| 7 |
+
**/file_dir/
|
| 8 |
+
**/flagged/
|
| 9 |
+
**/log/
|
| 10 |
+
**/logs/
|
| 11 |
+
**/__pycache__/
|
| 12 |
+
|
| 13 |
+
/data/
|
| 14 |
+
/docs/
|
| 15 |
+
/dotenv/
|
| 16 |
+
/hub_datasets/
|
| 17 |
+
/script/
|
| 18 |
+
/thirdparty/
|
| 19 |
+
/trained_models/
|
| 20 |
+
/temp/
|
| 21 |
+
|
| 22 |
+
**/*.wav
|
| 23 |
+
**/*.xlsx
|
| 24 |
+
|
| 25 |
+
requirements-python-3-9-9.txt
|
Dockerfile
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.12
|
| 2 |
+
|
| 3 |
+
WORKDIR /code
|
| 4 |
+
|
| 5 |
+
COPY . /code
|
| 6 |
+
|
| 7 |
+
RUN apt-get update
|
| 8 |
+
RUN apt-get install -y ffmpeg build-essential
|
| 9 |
+
|
| 10 |
+
RUN pip install --upgrade pip
|
| 11 |
+
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
|
| 12 |
+
|
| 13 |
+
RUN useradd -m -u 1000 user
|
| 14 |
+
|
| 15 |
+
USER user
|
| 16 |
+
|
| 17 |
+
ENV HOME=/home/user \
|
| 18 |
+
PATH=/home/user/.local/bin:$PATH
|
| 19 |
+
|
| 20 |
+
WORKDIR $HOME/app
|
| 21 |
+
|
| 22 |
+
COPY --chown=user . $HOME/app
|
| 23 |
+
|
| 24 |
+
CMD ["python3", "main.py"]
|
README.md
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: NX Denoise
|
| 3 |
+
emoji: 🐢
|
| 4 |
+
colorFrom: purple
|
| 5 |
+
colorTo: blue
|
| 6 |
+
sdk: docker
|
| 7 |
+
pinned: false
|
| 8 |
+
license: apache-2.0
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
| 12 |
+
## NX Denoise
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
### datasets
|
| 16 |
+
|
| 17 |
+
```text
|
| 18 |
+
|
| 19 |
+
AISHELL (15G)
|
| 20 |
+
https://openslr.trmal.net/resources/33/
|
| 21 |
+
|
| 22 |
+
AISHELL-3 (19G)
|
| 23 |
+
http://www.openslr.org/93/
|
| 24 |
+
|
| 25 |
+
DNS3
|
| 26 |
+
https://github.com/microsoft/DNS-Challenge/blob/master/download-dns-challenge-3.sh
|
| 27 |
+
噪音数据来源于 DEMAND, FreeSound, AudioSet.
|
| 28 |
+
|
| 29 |
+
MS-SNSD
|
| 30 |
+
https://github.com/microsoft/MS-SNSD
|
| 31 |
+
噪音数据来源于 DEMAND, FreeSound.
|
| 32 |
+
|
| 33 |
+
MUSAN
|
| 34 |
+
https://www.openslr.org/17/
|
| 35 |
+
其中包含 music, noise, speech.
|
| 36 |
+
music 是一些纯音乐, noise 包含 free-sound, sound-bible, sound-bible部分也许可以做为补充部分.
|
| 37 |
+
总的来说, 有用的不部不多, 可能噪音数据仍然需要自己收集为主, 更加可靠.
|
| 38 |
+
|
| 39 |
+
CHiME-4
|
| 40 |
+
https://www.chimechallenge.org/challenges/chime4/download.html
|
| 41 |
+
|
| 42 |
+
freesound
|
| 43 |
+
https://freesound.org/
|
| 44 |
+
|
| 45 |
+
AudioSet
|
| 46 |
+
https://research.google.com/audioset/index.html
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
### ### 创建训练容器
|
| 51 |
+
|
| 52 |
+
```text
|
| 53 |
+
在容器中训练模型,需要能够从容器中访问到 GPU,参考:
|
| 54 |
+
https://hub.docker.com/r/ollama/ollama
|
| 55 |
+
|
| 56 |
+
docker run -itd \
|
| 57 |
+
--name nx_denoise \
|
| 58 |
+
--network host \
|
| 59 |
+
--gpus all \
|
| 60 |
+
--privileged \
|
| 61 |
+
--ipc=host \
|
| 62 |
+
-v /data/tianxing/HuggingDatasets/nx_noise/data:/data/tianxing/HuggingDatasets/nx_noise/data \
|
| 63 |
+
-v /data/tianxing/PycharmProjects/nx_denoise:/data/tianxing/PycharmProjects/nx_denoise \
|
| 64 |
+
python:3.12
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
查看GPU
|
| 68 |
+
nvidia-smi
|
| 69 |
+
watch -n 1 -d nvidia-smi
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
```text
|
| 75 |
+
在容器中访问 GPU
|
| 76 |
+
|
| 77 |
+
参考:
|
| 78 |
+
https://blog.csdn.net/footless_bird/article/details/136291344
|
| 79 |
+
步骤:
|
| 80 |
+
# 安装
|
| 81 |
+
yum install -y nvidia-container-toolkit
|
| 82 |
+
|
| 83 |
+
# 编辑文件 /etc/docker/daemon.json
|
| 84 |
+
cat /etc/docker/daemon.json
|
| 85 |
+
{
|
| 86 |
+
"data-root": "/data/lib/docker",
|
| 87 |
+
"default-runtime": "nvidia",
|
| 88 |
+
"runtimes": {
|
| 89 |
+
"nvidia": {
|
| 90 |
+
"path": "/usr/bin/nvidia-container-runtime",
|
| 91 |
+
"runtimeArgs": []
|
| 92 |
+
}
|
| 93 |
+
},
|
| 94 |
+
"registry-mirrors": [
|
| 95 |
+
"https://docker.m.daocloud.io",
|
| 96 |
+
"https://dockerproxy.com",
|
| 97 |
+
"https://docker.mirrors.ustc.edu.cn",
|
| 98 |
+
"https://docker.nju.edu.cn"
|
| 99 |
+
]
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
# 重启 docker
|
| 103 |
+
systemctl restart docker
|
| 104 |
+
systemctl daemon-reload
|
| 105 |
+
|
| 106 |
+
# 测试容器内能否访问 GPU.
|
| 107 |
+
docker run --gpus all python:3.12-slim nvidia-smi
|
| 108 |
+
|
| 109 |
+
# 通过这种方式启动容器, 在容器中, 可以查看到 GPU. 但是容器中没有 GPU驱动 nvidia-smi 不工作.
|
| 110 |
+
docker run -it --privileged python:3.12-slim /bin/bash
|
| 111 |
+
apt update
|
| 112 |
+
apt install -y pciutils
|
| 113 |
+
lspci | grep -i nvidia
|
| 114 |
+
#00:08.0 3D controller: NVIDIA Corporation TU104GL [Tesla T4] (rev a1)
|
| 115 |
+
|
| 116 |
+
# 网上看的是这种启动容器的方式, 但是进去后仍然是 nvidia-smi 不工作.
|
| 117 |
+
docker run \
|
| 118 |
+
--device /dev/nvidia0:/dev/nvidia0 \
|
| 119 |
+
--device /dev/nvidiactl:/dev/nvidiactl \
|
| 120 |
+
--device /dev/nvidia-uvm:/dev/nvidia-uvm \
|
| 121 |
+
-v /usr/local/nvidia:/usr/local/nvidia \
|
| 122 |
+
-it --privileged python:3.12-slim /bin/bash
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# 这种方式进入容器, nvidia-smi 可以工作. 应该关键是 --gpus all 参数.
|
| 126 |
+
docker run -itd --gpus all --name open_unsloth python:3.12-slim /bin/bash
|
| 127 |
+
docker run -itd --gpus all --name Qwen2-7B-Instruct python:3.12-slim /bin/bash
|
| 128 |
+
|
| 129 |
+
```
|
examples/clean_unet/run.sh
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
|
| 3 |
+
: <<'END'
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir \
|
| 7 |
+
--noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \
|
| 8 |
+
--speech_dir "E:/programmer/asr_datasets/aishell/data_aishell/wav/train"
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir \
|
| 12 |
+
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
| 13 |
+
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
|
| 14 |
+
|
| 15 |
+
sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir \
|
| 16 |
+
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
| 17 |
+
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
END
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# params
|
| 24 |
+
system_version="windows";
|
| 25 |
+
verbose=true;
|
| 26 |
+
stage=0 # start from 0 if you need to start from data preparation
|
| 27 |
+
stop_stage=9
|
| 28 |
+
|
| 29 |
+
work_dir="$(pwd)"
|
| 30 |
+
file_folder_name=file_folder_name
|
| 31 |
+
final_model_name=final_model_name
|
| 32 |
+
config_file="yaml/config.yaml"
|
| 33 |
+
limit=10
|
| 34 |
+
|
| 35 |
+
noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
|
| 36 |
+
speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
|
| 37 |
+
|
| 38 |
+
max_count=10000000
|
| 39 |
+
|
| 40 |
+
nohup_name=nohup.out
|
| 41 |
+
|
| 42 |
+
# model params
|
| 43 |
+
batch_size=64
|
| 44 |
+
max_epochs=200
|
| 45 |
+
save_top_k=10
|
| 46 |
+
patience=5
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# parse options
|
| 50 |
+
while true; do
|
| 51 |
+
[ -z "${1:-}" ] && break; # break if there are no arguments
|
| 52 |
+
case "$1" in
|
| 53 |
+
--*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
|
| 54 |
+
eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
|
| 55 |
+
old_value="(eval echo \\$$name)";
|
| 56 |
+
if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
|
| 57 |
+
was_bool=true;
|
| 58 |
+
else
|
| 59 |
+
was_bool=false;
|
| 60 |
+
fi
|
| 61 |
+
|
| 62 |
+
# Set the variable to the right value-- the escaped quotes make it work if
|
| 63 |
+
# the option had spaces, like --cmd "queue.pl -sync y"
|
| 64 |
+
eval "${name}=\"$2\"";
|
| 65 |
+
|
| 66 |
+
# Check that Boolean-valued arguments are really Boolean.
|
| 67 |
+
if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
|
| 68 |
+
echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
|
| 69 |
+
exit 1;
|
| 70 |
+
fi
|
| 71 |
+
shift 2;
|
| 72 |
+
;;
|
| 73 |
+
|
| 74 |
+
*) break;
|
| 75 |
+
esac
|
| 76 |
+
done
|
| 77 |
+
|
| 78 |
+
file_dir="${work_dir}/${file_folder_name}"
|
| 79 |
+
final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
|
| 80 |
+
evaluation_audio_dir="${file_dir}/evaluation_audio"
|
| 81 |
+
|
| 82 |
+
dataset="${file_dir}/dataset.xlsx"
|
| 83 |
+
train_dataset="${file_dir}/train.xlsx"
|
| 84 |
+
valid_dataset="${file_dir}/valid.xlsx"
|
| 85 |
+
|
| 86 |
+
$verbose && echo "system_version: ${system_version}"
|
| 87 |
+
$verbose && echo "file_folder_name: ${file_folder_name}"
|
| 88 |
+
|
| 89 |
+
if [ $system_version == "windows" ]; then
|
| 90 |
+
alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
|
| 91 |
+
elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
|
| 92 |
+
#source /data/local/bin/nx_denoise/bin/activate
|
| 93 |
+
alias python3='/data/local/bin/nx_denoise/bin/python3'
|
| 94 |
+
fi
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
| 98 |
+
$verbose && echo "stage 1: prepare data"
|
| 99 |
+
cd "${work_dir}" || exit 1
|
| 100 |
+
python3 step_1_prepare_data.py \
|
| 101 |
+
--file_dir "${file_dir}" \
|
| 102 |
+
--noise_dir "${noise_dir}" \
|
| 103 |
+
--speech_dir "${speech_dir}" \
|
| 104 |
+
--train_dataset "${train_dataset}" \
|
| 105 |
+
--valid_dataset "${valid_dataset}" \
|
| 106 |
+
--max_count "${max_count}" \
|
| 107 |
+
|
| 108 |
+
fi
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
| 112 |
+
$verbose && echo "stage 2: train model"
|
| 113 |
+
cd "${work_dir}" || exit 1
|
| 114 |
+
python3 step_2_train_model.py \
|
| 115 |
+
--train_dataset "${train_dataset}" \
|
| 116 |
+
--valid_dataset "${valid_dataset}" \
|
| 117 |
+
--serialization_dir "${file_dir}" \
|
| 118 |
+
--config_file "${config_file}" \
|
| 119 |
+
|
| 120 |
+
fi
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
| 124 |
+
$verbose && echo "stage 3: test model"
|
| 125 |
+
cd "${work_dir}" || exit 1
|
| 126 |
+
python3 step_3_evaluation.py \
|
| 127 |
+
--valid_dataset "${valid_dataset}" \
|
| 128 |
+
--model_dir "${file_dir}/best" \
|
| 129 |
+
--evaluation_audio_dir "${evaluation_audio_dir}" \
|
| 130 |
+
--limit "${limit}" \
|
| 131 |
+
|
| 132 |
+
fi
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
| 136 |
+
$verbose && echo "stage 4: export model"
|
| 137 |
+
cd "${work_dir}" || exit 1
|
| 138 |
+
python3 step_5_export_models.py \
|
| 139 |
+
--vocabulary_dir "${vocabulary_dir}" \
|
| 140 |
+
--model_dir "${file_dir}/best" \
|
| 141 |
+
--serialization_dir "${file_dir}" \
|
| 142 |
+
|
| 143 |
+
fi
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
| 147 |
+
$verbose && echo "stage 5: collect files"
|
| 148 |
+
cd "${work_dir}" || exit 1
|
| 149 |
+
|
| 150 |
+
mkdir -p ${final_model_dir}
|
| 151 |
+
|
| 152 |
+
cp "${file_dir}/best"/* "${final_model_dir}"
|
| 153 |
+
cp -r "${file_dir}/vocabulary" "${final_model_dir}"
|
| 154 |
+
|
| 155 |
+
cp "${file_dir}/evaluation.xlsx" "${final_model_dir}/evaluation.xlsx"
|
| 156 |
+
|
| 157 |
+
cp "${file_dir}/trace_model.zip" "${final_model_dir}/trace_model.zip"
|
| 158 |
+
cp "${file_dir}/trace_quant_model.zip" "${final_model_dir}/trace_quant_model.zip"
|
| 159 |
+
cp "${file_dir}/script_model.zip" "${final_model_dir}/script_model.zip"
|
| 160 |
+
cp "${file_dir}/script_quant_model.zip" "${final_model_dir}/script_quant_model.zip"
|
| 161 |
+
|
| 162 |
+
cd "${final_model_dir}/.." || exit 1;
|
| 163 |
+
|
| 164 |
+
if [ -e "${final_model_name}.zip" ]; then
|
| 165 |
+
rm -rf "${final_model_name}_backup.zip"
|
| 166 |
+
mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
|
| 167 |
+
fi
|
| 168 |
+
|
| 169 |
+
zip -r "${final_model_name}.zip" "${final_model_name}"
|
| 170 |
+
rm -rf "${final_model_name}"
|
| 171 |
+
|
| 172 |
+
fi
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
|
| 176 |
+
$verbose && echo "stage 6: clear file_dir"
|
| 177 |
+
cd "${work_dir}" || exit 1
|
| 178 |
+
|
| 179 |
+
rm -rf "${file_dir}";
|
| 180 |
+
|
| 181 |
+
fi
|
examples/clean_unet/step_1_prepare_data.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
import argparse
|
| 4 |
+
import os
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import random
|
| 7 |
+
import sys
|
| 8 |
+
import shutil
|
| 9 |
+
|
| 10 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
| 11 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
| 12 |
+
|
| 13 |
+
import pandas as pd
|
| 14 |
+
from scipy.io import wavfile
|
| 15 |
+
from tqdm import tqdm
|
| 16 |
+
import librosa
|
| 17 |
+
|
| 18 |
+
from project_settings import project_path
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_args():
|
| 22 |
+
parser = argparse.ArgumentParser()
|
| 23 |
+
parser.add_argument("--file_dir", default="./", type=str)
|
| 24 |
+
|
| 25 |
+
parser.add_argument(
|
| 26 |
+
"--noise_dir",
|
| 27 |
+
default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
|
| 28 |
+
type=str
|
| 29 |
+
)
|
| 30 |
+
parser.add_argument(
|
| 31 |
+
"--speech_dir",
|
| 32 |
+
default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
|
| 33 |
+
type=str
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
parser.add_argument("--train_dataset", default="train.xlsx", type=str)
|
| 37 |
+
parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
|
| 38 |
+
|
| 39 |
+
parser.add_argument("--duration", default=2.0, type=float)
|
| 40 |
+
parser.add_argument("--min_snr_db", default=-10, type=float)
|
| 41 |
+
parser.add_argument("--max_snr_db", default=20, type=float)
|
| 42 |
+
|
| 43 |
+
parser.add_argument("--target_sample_rate", default=8000, type=int)
|
| 44 |
+
|
| 45 |
+
parser.add_argument("--max_count", default=10000, type=int)
|
| 46 |
+
|
| 47 |
+
args = parser.parse_args()
|
| 48 |
+
return args
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def filename_generator(data_dir: str):
|
| 52 |
+
data_dir = Path(data_dir)
|
| 53 |
+
for filename in data_dir.glob("**/*.wav"):
|
| 54 |
+
yield filename.as_posix()
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000):
|
| 58 |
+
data_dir = Path(data_dir)
|
| 59 |
+
for filename in data_dir.glob("**/*.wav"):
|
| 60 |
+
signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
|
| 61 |
+
raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
|
| 62 |
+
|
| 63 |
+
if raw_duration < duration:
|
| 64 |
+
# print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
|
| 65 |
+
continue
|
| 66 |
+
if signal.ndim != 1:
|
| 67 |
+
raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
|
| 68 |
+
|
| 69 |
+
signal_length = len(signal)
|
| 70 |
+
win_size = int(duration * sample_rate)
|
| 71 |
+
for begin in range(0, signal_length - win_size, win_size):
|
| 72 |
+
row = {
|
| 73 |
+
"filename": filename.as_posix(),
|
| 74 |
+
"raw_duration": round(raw_duration, 4),
|
| 75 |
+
"offset": round(begin / sample_rate, 4),
|
| 76 |
+
"duration": round(duration, 4),
|
| 77 |
+
}
|
| 78 |
+
yield row
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def get_dataset(args):
|
| 82 |
+
file_dir = Path(args.file_dir)
|
| 83 |
+
file_dir.mkdir(exist_ok=True)
|
| 84 |
+
|
| 85 |
+
noise_dir = Path(args.noise_dir)
|
| 86 |
+
speech_dir = Path(args.speech_dir)
|
| 87 |
+
|
| 88 |
+
noise_generator = target_second_signal_generator(
|
| 89 |
+
noise_dir.as_posix(),
|
| 90 |
+
duration=args.duration,
|
| 91 |
+
sample_rate=args.target_sample_rate
|
| 92 |
+
)
|
| 93 |
+
speech_generator = target_second_signal_generator(
|
| 94 |
+
speech_dir.as_posix(),
|
| 95 |
+
duration=args.duration,
|
| 96 |
+
sample_rate=args.target_sample_rate
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
dataset = list()
|
| 100 |
+
|
| 101 |
+
count = 0
|
| 102 |
+
process_bar = tqdm(desc="build dataset excel")
|
| 103 |
+
for noise, speech in zip(noise_generator, speech_generator):
|
| 104 |
+
if count >= args.max_count:
|
| 105 |
+
break
|
| 106 |
+
|
| 107 |
+
noise_filename = noise["filename"]
|
| 108 |
+
noise_raw_duration = noise["raw_duration"]
|
| 109 |
+
noise_offset = noise["offset"]
|
| 110 |
+
noise_duration = noise["duration"]
|
| 111 |
+
|
| 112 |
+
speech_filename = speech["filename"]
|
| 113 |
+
speech_raw_duration = speech["raw_duration"]
|
| 114 |
+
speech_offset = speech["offset"]
|
| 115 |
+
speech_duration = speech["duration"]
|
| 116 |
+
|
| 117 |
+
random1 = random.random()
|
| 118 |
+
random2 = random.random()
|
| 119 |
+
|
| 120 |
+
row = {
|
| 121 |
+
"noise_filename": noise_filename,
|
| 122 |
+
"noise_raw_duration": noise_raw_duration,
|
| 123 |
+
"noise_offset": noise_offset,
|
| 124 |
+
"noise_duration": noise_duration,
|
| 125 |
+
|
| 126 |
+
"speech_filename": speech_filename,
|
| 127 |
+
"speech_raw_duration": speech_raw_duration,
|
| 128 |
+
"speech_offset": speech_offset,
|
| 129 |
+
"speech_duration": speech_duration,
|
| 130 |
+
|
| 131 |
+
"snr_db": random.uniform(args.min_snr_db, args.max_snr_db),
|
| 132 |
+
|
| 133 |
+
"random1": random1,
|
| 134 |
+
"random2": random2,
|
| 135 |
+
"flag": "TRAIN" if random2 < 0.8 else "TEST",
|
| 136 |
+
}
|
| 137 |
+
dataset.append(row)
|
| 138 |
+
count += 1
|
| 139 |
+
duration_seconds = count * args.duration
|
| 140 |
+
duration_hours = duration_seconds / 3600
|
| 141 |
+
|
| 142 |
+
process_bar.update(n=1)
|
| 143 |
+
process_bar.set_postfix({
|
| 144 |
+
# "duration_seconds": round(duration_seconds, 4),
|
| 145 |
+
"duration_hours": round(duration_hours, 4),
|
| 146 |
+
|
| 147 |
+
})
|
| 148 |
+
|
| 149 |
+
dataset = pd.DataFrame(dataset)
|
| 150 |
+
dataset = dataset.sort_values(by=["random1"], ascending=False)
|
| 151 |
+
dataset.to_excel(
|
| 152 |
+
file_dir / "dataset.xlsx",
|
| 153 |
+
index=False,
|
| 154 |
+
)
|
| 155 |
+
return
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def split_dataset(args):
|
| 160 |
+
"""分割训练集, 测试集"""
|
| 161 |
+
file_dir = Path(args.file_dir)
|
| 162 |
+
file_dir.mkdir(exist_ok=True)
|
| 163 |
+
|
| 164 |
+
df = pd.read_excel(file_dir / "dataset.xlsx")
|
| 165 |
+
|
| 166 |
+
train = list()
|
| 167 |
+
test = list()
|
| 168 |
+
|
| 169 |
+
for i, row in df.iterrows():
|
| 170 |
+
flag = row["flag"]
|
| 171 |
+
if flag == "TRAIN":
|
| 172 |
+
train.append(row)
|
| 173 |
+
else:
|
| 174 |
+
test.append(row)
|
| 175 |
+
|
| 176 |
+
train = pd.DataFrame(train)
|
| 177 |
+
train.to_excel(
|
| 178 |
+
args.train_dataset,
|
| 179 |
+
index=False,
|
| 180 |
+
# encoding="utf_8_sig"
|
| 181 |
+
)
|
| 182 |
+
test = pd.DataFrame(test)
|
| 183 |
+
test.to_excel(
|
| 184 |
+
args.valid_dataset,
|
| 185 |
+
index=False,
|
| 186 |
+
# encoding="utf_8_sig"
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
return
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def main():
|
| 193 |
+
args = get_args()
|
| 194 |
+
|
| 195 |
+
get_dataset(args)
|
| 196 |
+
split_dataset(args)
|
| 197 |
+
return
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
if __name__ == "__main__":
|
| 201 |
+
main()
|
examples/clean_unet/step_2_train_model.py
ADDED
|
@@ -0,0 +1,419 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
https://github.com/NVIDIA/CleanUNet/blob/main/train.py
|
| 5 |
+
|
| 6 |
+
https://github.com/NVIDIA/CleanUNet/blob/main/configs/DNS-large-full.json
|
| 7 |
+
"""
|
| 8 |
+
import argparse
|
| 9 |
+
import json
|
| 10 |
+
import logging
|
| 11 |
+
from logging.handlers import TimedRotatingFileHandler
|
| 12 |
+
import os
|
| 13 |
+
import platform
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
import random
|
| 16 |
+
import sys
|
| 17 |
+
import shutil
|
| 18 |
+
from typing import List
|
| 19 |
+
|
| 20 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
| 21 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
| 22 |
+
|
| 23 |
+
import numpy as np
|
| 24 |
+
import torch
|
| 25 |
+
import torch.nn as nn
|
| 26 |
+
from torch.nn import functional as F
|
| 27 |
+
from torch.utils.data.dataloader import DataLoader
|
| 28 |
+
from tqdm import tqdm
|
| 29 |
+
|
| 30 |
+
from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
|
| 31 |
+
from toolbox.torchaudio.models.clean_unet.configuration_clean_unet import CleanUNetConfig
|
| 32 |
+
from toolbox.torchaudio.models.clean_unet.modeling_clean_unet import CleanUNetPretrainedModel
|
| 33 |
+
from toolbox.torchaudio.models.clean_unet.training import LinearWarmupCosineDecay
|
| 34 |
+
from toolbox.torchaudio.models.clean_unet.loss import MultiResolutionSTFTLoss
|
| 35 |
+
from toolbox.torchaudio.models.clean_unet.metrics import run_pesq_score
|
| 36 |
+
|
| 37 |
+
torch.autograd.set_detect_anomaly(True)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def get_args():
    """Build the CLI parser and return the parsed training arguments."""
    parser = argparse.ArgumentParser()

    # dataset manifests
    parser.add_argument("--train_dataset", default="train.xlsx", type=str)
    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)

    # training schedule
    parser.add_argument("--max_epochs", default=100, type=int)
    parser.add_argument("--batch_size", default=64, type=int)
    parser.add_argument("--learning_rate", default=2e-4, type=float)

    # checkpointing / early stopping
    parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
    parser.add_argument("--patience", default=5, type=int)
    parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
    parser.add_argument("--seed", default=0, type=int)

    # model configuration
    parser.add_argument("--config_file", default="config.yaml", type=str)

    return parser.parse_args()
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def logging_config(file_dir: str):
    """Configure console logging plus a daily-rotating file log.

    A ``main.log`` file is created inside *file_dir*; it rotates daily and
    keeps seven backups. Returns this module's logger.
    """
    log_format = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    # console output goes through the root logger
    logging.basicConfig(
        format=log_format,
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    # daily-rotating file output, seven backups kept
    rotating_handler = TimedRotatingFileHandler(
        filename=os.path.join(file_dir, "main.log"),
        encoding="utf-8",
        when="D",
        interval=1,
        backupCount=7,
    )
    rotating_handler.setLevel(logging.INFO)
    rotating_handler.setFormatter(logging.Formatter(log_format))

    module_logger = logging.getLogger(__name__)
    module_logger.addHandler(rotating_handler)
    return module_logger
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class CollateFunction(object):
    """Stack per-sample clean/noisy waveforms into batched tensors.

    Each sample dict must provide ``speech_wave`` (clean) and ``mix_wave``
    (noisy) tensors of equal length. Raises ``AssertionError`` if either
    stacked batch contains NaN or Inf values.
    """

    def __init__(self):
        pass

    def __call__(self, batch: List[dict]):
        clean_audios = torch.stack([sample["speech_wave"] for sample in batch])
        noisy_audios = torch.stack([sample["mix_wave"] for sample in batch])

        # guard: corrupted samples must not silently enter training
        if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
            raise AssertionError("nan or inf in clean_audios")
        if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
            raise AssertionError("nan or inf in noisy_audios")
        return clean_audios, noisy_audios
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# Module-level collate instance shared by the train and valid DataLoaders.
collate_fn = CollateFunction()
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def main():
    """Train CleanUNet for speech denoising.

    Loads train/valid Excel datasets, builds the model from a YAML config,
    optionally resumes from the newest ``epoch-*`` checkpoint, and runs the
    per-epoch train/eval loop with an L1 waveform loss plus a
    multi-resolution STFT loss. Checkpoint selection and early stopping are
    driven by the running-average PESQ score (after the eval loop this holds
    the validation average, since the totals are reset before evaluation).
    """
    args = get_args()

    config = CleanUNetConfig.from_pretrained(
        pretrained_model_name_or_path=args.config_file,
    )

    serialization_dir = Path(args.serialization_dir)
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config(serialization_dir)

    # seed all RNGs for reproducibility
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    logger.info(f"set seed: {args.seed}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info(f"GPU available count: {n_gpu}; device: {device}")

    # datasets (8 kHz audio; samples scaled by 1/32768 per the dataset args)
    train_dataset = DenoiseExcelDataset(
        excel_file=args.train_dataset,
        expected_sample_rate=8000,
        max_wave_value=32768.0,
    )
    valid_dataset = DenoiseExcelDataset(
        excel_file=args.valid_dataset,
        expected_sample_rate=8000,
        max_wave_value=32768.0,
    )
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        sampler=None,
        # On Linux multiple worker subprocesses can load data; on Windows they cannot.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        # prefetch_factor=64,
    )
    valid_data_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        sampler=None,
        # On Linux multiple worker subprocesses can load data; on Windows they cannot.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        # prefetch_factor=64,
    )

    # models
    logger.info(f"prepare models. config_file: {args.config_file}")
    model = CleanUNetPretrainedModel(config).to(device)

    # optimizer
    logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
    optimizer = torch.optim.AdamW(model.parameters(), args.learning_rate)

    # resume training: find the highest-numbered epoch-* checkpoint directory
    last_epoch = -1
    for epoch_i in serialization_dir.glob("epoch-*"):
        epoch_i = Path(epoch_i)
        epoch_idx = epoch_i.stem.split("-")[1]
        epoch_idx = int(epoch_idx)
        if epoch_idx > last_epoch:
            last_epoch = epoch_idx

    if last_epoch != -1:
        logger.info(f"resume from epoch-{last_epoch}.")
        model_pt = serialization_dir / f"epoch-{last_epoch}/model.pt"
        optimizer_pth = serialization_dir / f"epoch-{last_epoch}/optimizer.pth"

        logger.info(f"load state dict for model.")
        with open(model_pt.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)

        logger.info(f"load state dict for optimizer.")
        with open(optimizer_pth.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        optimizer.load_state_dict(state_dict)

    # NOTE(review): n_iter == iteration == 250000 appears to start the schedule
    # at its final step — confirm against LinearWarmupCosineDecay's signature.
    lr_scheduler = LinearWarmupCosineDecay(
        optimizer,
        lr_max=args.learning_rate,
        n_iter=250000,
        iteration=250000,
        divider=25,
        warmup_proportion=0.05,
        phase=("linear", "cosine"),
    )

    # ae_loss_fn = nn.MSELoss(reduction="mean")
    ae_loss_fn = nn.L1Loss(reduction="mean").to(device)

    # multi-resolution STFT loss: spectral-convergence + magnitude terms
    mr_stft_loss_fn = MultiResolutionSTFTLoss(
        fft_sizes=[256, 512, 1024],
        hop_sizes=[25, 50, 120],
        win_lengths=[120, 240, 600],
        sc_lambda=0.5,
        mag_lambda=0.5,
        band="full"
    ).to(device)

    # training loop

    # state: running averages (large sentinel until first batch completes)
    average_pesq_score = 10000000000
    average_loss = 10000000000
    average_ae_loss = 10000000000
    average_sc_loss = 10000000000
    average_mag_loss = 10000000000

    model_list = list()
    best_idx_epoch = None
    best_metric = None
    patience_count = 0

    logger.info("training")
    for idx_epoch in range(max(0, last_epoch+1), args.max_epochs):
        # train
        model.train()

        total_pesq_score = 0.
        total_loss = 0.
        total_ae_loss = 0.
        total_sc_loss = 0.
        total_mag_loss = 0.
        total_batches = 0.

        progress_bar = tqdm(
            total=len(train_data_loader),
            desc="Training; epoch: {}".format(idx_epoch),
        )
        for batch in train_data_loader:
            clean_audios, noisy_audios = batch
            clean_audios = clean_audios.to(device)
            noisy_audios = noisy_audios.to(device)

            enhanced_audios = model.forward(noisy_audios)
            # drop the channel dimension so shapes match the clean targets
            enhanced_audios = torch.squeeze(enhanced_audios, dim=1)

            ae_loss = ae_loss_fn(enhanced_audios, clean_audios)
            sc_loss, mag_loss = mr_stft_loss_fn(enhanced_audios, clean_audios)

            loss = ae_loss + sc_loss + mag_loss

            # PESQ is computed on CPU numpy copies; narrow-band mode at 8 kHz
            enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy())
            clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
            pesq_score = run_pesq_score(clean_audios_list_r, enhanced_audios_list_r, sample_rate=8000, mode="nb")

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # scheduler is stepped once per batch (per-iteration schedule)
            lr_scheduler.step()

            total_pesq_score += pesq_score
            total_loss += loss.item()
            total_ae_loss += ae_loss.item()
            total_sc_loss += sc_loss.item()
            total_mag_loss += mag_loss.item()
            total_batches += 1

            average_pesq_score = round(total_pesq_score / total_batches, 4)
            average_loss = round(total_loss / total_batches, 4)
            average_ae_loss = round(total_ae_loss / total_batches, 4)
            average_sc_loss = round(total_sc_loss / total_batches, 4)
            average_mag_loss = round(total_mag_loss / total_batches, 4)

            progress_bar.update(1)
            progress_bar.set_postfix({
                "pesq_score": average_pesq_score,
                "loss": average_loss,
                "ae_loss": average_ae_loss,
                "sc_loss": average_sc_loss,
                "mag_loss": average_mag_loss,
            })

        # evaluation
        model.eval()

        torch.cuda.empty_cache()

        # reset running totals; averages below are validation-phase values
        total_pesq_score = 0.
        total_loss = 0.
        total_ae_loss = 0.
        total_sc_loss = 0.
        total_mag_loss = 0.
        total_batches = 0.

        progress_bar = tqdm(
            total=len(valid_data_loader),
            desc="Evaluation; epoch: {}".format(idx_epoch),
        )
        with torch.no_grad():
            for batch in valid_data_loader:
                clean_audios, noisy_audios = batch
                clean_audios = clean_audios.to(device)
                noisy_audios = noisy_audios.to(device)

                enhanced_audios = model.forward(noisy_audios)
                enhanced_audios = torch.squeeze(enhanced_audios, dim=1)

                ae_loss = ae_loss_fn(enhanced_audios, clean_audios)
                sc_loss, mag_loss = mr_stft_loss_fn(enhanced_audios, clean_audios)

                loss = ae_loss + sc_loss + mag_loss

                enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy())
                clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
                pesq_score = run_pesq_score(clean_audios_list_r, enhanced_audios_list_r, sample_rate=8000, mode="nb")

                total_pesq_score += pesq_score
                total_loss += loss.item()
                total_ae_loss += ae_loss.item()
                total_sc_loss += sc_loss.item()
                total_mag_loss += mag_loss.item()
                total_batches += 1

                average_pesq_score = round(total_pesq_score / total_batches, 4)
                average_loss = round(total_loss / total_batches, 4)
                average_ae_loss = round(total_ae_loss / total_batches, 4)
                average_sc_loss = round(total_sc_loss / total_batches, 4)
                average_mag_loss = round(total_mag_loss / total_batches, 4)

                progress_bar.update(1)
                progress_bar.set_postfix({
                    "pesq_score": average_pesq_score,
                    "loss": average_loss,
                    "ae_loss": average_ae_loss,
                    "sc_loss": average_sc_loss,
                    "mag_loss": average_mag_loss,
                })

        # scheduler
        # NOTE(review): the scheduler is already stepped once per batch above;
        # this extra per-epoch step may be unintended — verify.
        lr_scheduler.step()

        # save path
        epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
        epoch_dir.mkdir(parents=True, exist_ok=False)

        # save models
        model.save_pretrained(epoch_dir.as_posix())

        # NOTE(review): ">=" deletes once the list reaches the limit, so at
        # most num_serialized_models_to_keep - 1 epoch dirs are retained.
        model_list.append(epoch_dir)
        if len(model_list) >= args.num_serialized_models_to_keep:
            model_to_delete: Path = model_list.pop(0)
            shutil.rmtree(model_to_delete.as_posix())

        # save optim
        torch.save(optimizer.state_dict(), (epoch_dir / "optimizer.pth").as_posix())

        # save metric (higher PESQ is better)
        if best_metric is None:
            best_idx_epoch = idx_epoch
            best_metric = average_pesq_score
        elif average_pesq_score > best_metric:
            # great is better.
            best_idx_epoch = idx_epoch
            best_metric = average_pesq_score
        else:
            pass

        metrics = {
            "idx_epoch": idx_epoch,
            "best_idx_epoch": best_idx_epoch,

            "pesq_score": average_pesq_score,
            "loss": average_loss,
            "ae_loss": average_ae_loss,
            "sc_loss": average_sc_loss,
            "mag_loss": average_mag_loss,

        }
        metrics_filename = epoch_dir / "metrics_epoch.json"
        with open(metrics_filename, "w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=4, ensure_ascii=False)

        # save best: mirror the current epoch dir into "best"
        best_dir = serialization_dir / "best"
        if best_idx_epoch == idx_epoch:
            if best_dir.exists():
                shutil.rmtree(best_dir)
            shutil.copytree(epoch_dir, best_dir)

        # early stop: count epochs without a new best PESQ score
        early_stop_flag = False
        if best_idx_epoch == idx_epoch:
            patience_count = 0
        else:
            patience_count += 1
            if patience_count >= args.patience:
                early_stop_flag = True

        # early stop
        if early_stop_flag:
            break

    return
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
# Script entry point.
if __name__ == "__main__":
    main()
|
examples/clean_unet/step_3_evaluation.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
# -*- coding: utf-8 -*-
# Placeholder: the clean_unet evaluation step is not implemented yet
# (run.sh stage 3 invokes this script with --valid_dataset/--model_dir args).


if __name__ == '__main__':
    pass
|
examples/clean_unet/yaml/config.yaml
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model_name: "clean_unet"

# CleanUNet architecture hyperparameters — presumably mirroring
# NVIDIA/CleanUNet configs/DNS-large-full.json; confirm against
# CleanUNetConfig's field definitions.
channels_input: 1
channels_output: 1
channels_h: 64
max_h: 768
encoder_n_layers: 8
kernel_size: 4
stride: 2
# transformer bottleneck settings
tsfm_n_layers: 5
tsfm_n_head: 8
tsfm_d_model: 512
tsfm_d_inner: 2048
|
| 14 |
+
|
examples/conv_tasnet/run.sh
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash

# Usage example (no-op heredoc; not executed):
: <<'END'


sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name conv-tasnet-dns3-20250319 \
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/dns3-noise" \
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech" \
--max_epochs 400


END


# params
system_version="windows";
verbose=true;
stage=0 # start from 0 if you need to start from data preparation
stop_stage=9

work_dir="$(pwd)"
file_folder_name=file_folder_name
final_model_name=final_model_name
config_file="yaml/config.yaml"
limit=10

# corpus locations (defaults point at the training server paths)
noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train

max_count=10000000

nohup_name=nohup.out

# model params
batch_size=64
max_epochs=200
save_top_k=10
patience=5
|
| 40 |
+
|
| 41 |
+
# parse options
# Accepts "--name value" pairs. Each --name (dashes mapped to underscores)
# must match a variable already declared above; variables whose current
# value is "true"/"false" only accept boolean replacements.
while true; do
  [ -z "${1:-}" ] && break;  # break if there are no arguments
  case "$1" in
    --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
      # reject options that do not correspond to a declared variable
      eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
      # fix: use command substitution to capture the variable's current value;
      # the original 'old_value="(eval echo \\$$name)"' assigned a literal
      # string, so was_bool was always false and the boolean check below
      # could never trigger.
      old_value="$(eval echo \$$name)";
      if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
        was_bool=true;
      else
        was_bool=false;
      fi

      # Set the variable to the right value-- the escaped quotes make it work if
      # the option had spaces, like --cmd "queue.pl -sync y"
      eval "${name}=\"$2\"";

      # Check that Boolean-valued arguments are really Boolean.
      if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
        echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
        exit 1;
      fi
      shift 2;
      ;;

    *) break;
  esac
done
|
| 69 |
+
|
| 70 |
+
file_dir="${work_dir}/${file_folder_name}"
final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
evaluation_audio_dir="${file_dir}/evaluation_audio"

train_dataset="${file_dir}/train.jsonl"
valid_dataset="${file_dir}/valid.jsonl"

$verbose && echo "system_version: ${system_version}"
$verbose && echo "file_folder_name: ${file_folder_name}"

# pick the platform-specific virtualenv interpreter
if [ $system_version == "windows" ]; then
  alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
  #source /data/local/bin/nx_denoise/bin/activate
  alias python3='/data/local/bin/nx_denoise/bin/python3'
fi


# stage 1: build train/valid JSONL manifests from the noise and speech corpora
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  $verbose && echo "stage 1: prepare data"
  cd "${work_dir}" || exit 1
  python3 step_1_prepare_data.py \
    --file_dir "${file_dir}" \
    --noise_dir "${noise_dir}" \
    --speech_dir "${speech_dir}" \
    --train_dataset "${train_dataset}" \
    --valid_dataset "${valid_dataset}" \
    --max_count "${max_count}" \

fi


# stage 2: train the model; checkpoints land in ${file_dir}
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
  $verbose && echo "stage 2: train model"
  cd "${work_dir}" || exit 1
  python3 step_2_train_model.py \
    --train_dataset "${train_dataset}" \
    --valid_dataset "${valid_dataset}" \
    --serialization_dir "${file_dir}" \
    --config_file "${config_file}" \

fi


# stage 3: run evaluation with the best checkpoint
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
  $verbose && echo "stage 3: test model"
  cd "${work_dir}" || exit 1
  python3 step_3_evaluation.py \
    --valid_dataset "${valid_dataset}" \
    --model_dir "${file_dir}/best" \
    --evaluation_audio_dir "${evaluation_audio_dir}" \
    --limit "${limit}" \

fi


# stage 4: package the best model plus evaluation audio into a zip archive
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
  $verbose && echo "stage 4: collect files"
  cd "${work_dir}" || exit 1

  mkdir -p ${final_model_dir}

  cp "${file_dir}/best"/* "${final_model_dir}"
  cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"

  cd "${final_model_dir}/.." || exit 1;

  # keep one backup of a previously packaged archive
  if [ -e "${final_model_name}.zip" ]; then
    rm -rf "${final_model_name}_backup.zip"
    mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
  fi

  zip -r "${final_model_name}.zip" "${final_model_name}"
  rm -rf "${final_model_name}"

fi


# stage 5: delete the working directory (destructive)
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
  $verbose && echo "stage 5: clear file_dir"
  cd "${work_dir}" || exit 1

  rm -rf "${file_dir}";

fi
|
examples/conv_tasnet/step_1_prepare_data.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
import random
|
| 8 |
+
import sys
|
| 9 |
+
|
| 10 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
| 11 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
| 12 |
+
|
| 13 |
+
import librosa
|
| 14 |
+
import numpy as np
|
| 15 |
+
from tqdm import tqdm
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_args():
    """Build the CLI parser and return the parsed data-preparation arguments."""
    parser = argparse.ArgumentParser()

    # output / corpus locations
    parser.add_argument("--file_dir", default="./", type=str)
    parser.add_argument(
        "--noise_dir",
        default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
        type=str,
    )
    parser.add_argument(
        "--speech_dir",
        default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
        type=str,
    )

    # output manifest files
    parser.add_argument("--train_dataset", default="train.jsonl", type=str)
    parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)

    # segment length and SNR range for the mixes
    parser.add_argument("--duration", default=4.0, type=float)
    parser.add_argument("--min_snr_db", default=-10, type=float)
    parser.add_argument("--max_snr_db", default=20, type=float)

    parser.add_argument("--target_sample_rate", default=8000, type=int)

    parser.add_argument("--max_count", default=10000, type=int)

    return parser.parse_args()
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def filename_generator(data_dir: str):
    """Yield POSIX paths of every .wav file under *data_dir*, recursively."""
    root = Path(data_dir)
    yield from (wav_path.as_posix() for wav_path in root.glob("**/*.wav"))
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000, max_epoch: int = 20000):
    """Yield fixed-duration, non-silent segment descriptors for .wav files.

    Scans *data_dir* recursively (repeating up to ``max_epoch`` passes),
    loads each .wav at ``sample_rate``, and yields one dict per
    non-overlapping window of ``duration`` seconds with keys:
    ``epoch_idx``, ``filename``, ``raw_duration``, ``offset`` (seconds),
    ``duration`` (seconds).

    Files shorter than ``duration`` and all-zero (silent) windows are
    skipped. Raises AssertionError for non-mono signals.
    """
    data_dir = Path(data_dir)
    for epoch_idx in range(max_epoch):
        for filename in data_dir.glob("**/*.wav"):
            signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
            raw_duration = librosa.get_duration(y=signal, sr=sample_rate)

            if raw_duration < duration:
                # too short to produce even one window
                # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
                continue
            if signal.ndim != 1:
                raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")

            signal_length = len(signal)
            win_size = int(duration * sample_rate)
            for begin in range(0, signal_length - win_size, win_size):
                # fix: detect truly silent (all-zero) windows with any(); the
                # original `np.sum(window) == 0` also skipped non-silent
                # windows whose positive and negative samples cancel out.
                if not np.any(signal[begin: begin + win_size]):
                    continue
                row = {
                    "epoch_idx": epoch_idx,
                    "filename": filename.as_posix(),
                    "raw_duration": round(raw_duration, 4),
                    "offset": round(begin / sample_rate, 4),
                    "duration": round(duration, 4),
                }
                yield row
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def main():
    """Build train/valid JSONL manifests of paired noise/speech segments.

    Zips a long-cycling noise-segment generator with a single pass over the
    speech corpus, assigns each pair a random SNR in
    [min_snr_db, max_snr_db], and writes one JSON row per pair. Roughly
    1/300 of the rows go to the validation file; the rest go to training.
    Stops after ``max_count`` pairs (when positive) or when the speech
    corpus is exhausted.
    """
    args = get_args()

    file_dir = Path(args.file_dir)
    file_dir.mkdir(exist_ok=True)

    noise_dir = Path(args.noise_dir)
    speech_dir = Path(args.speech_dir)

    # noise segments cycle over the corpus many times; speech is consumed once,
    # so the zip ends when the speech generator is exhausted
    noise_generator = target_second_signal_generator(
        noise_dir.as_posix(),
        duration=args.duration,
        sample_rate=args.target_sample_rate,
        max_epoch=100000,
    )
    speech_generator = target_second_signal_generator(
        speech_dir.as_posix(),
        duration=args.duration,
        sample_rate=args.target_sample_rate,
        max_epoch=1,
    )

    count = 0
    process_bar = tqdm(desc="build dataset excel")
    with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
        for noise, speech in zip(noise_generator, speech_generator):
            if count >= args.max_count > 0:
                break

            noise_filename = noise["filename"]
            noise_raw_duration = noise["raw_duration"]
            noise_offset = noise["offset"]
            noise_duration = noise["duration"]

            speech_filename = speech["filename"]
            speech_raw_duration = speech["raw_duration"]
            speech_offset = speech["offset"]
            speech_duration = speech["duration"]

            random1 = random.random()
            # random2 drives the train/valid split below; random1 is stored
            # in the row (presumably for downstream sampling — confirm).
            random2 = random.random()

            row = {
                "count": count,

                "noise_filename": noise_filename,
                "noise_raw_duration": noise_raw_duration,
                "noise_offset": noise_offset,
                "noise_duration": noise_duration,

                "speech_filename": speech_filename,
                "speech_raw_duration": speech_raw_duration,
                "speech_offset": speech_offset,
                "speech_duration": speech_duration,

                "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),

                "random1": random1,
            }
            row = json.dumps(row, ensure_ascii=False)
            # ~1/300 of rows become validation data
            if random2 < (1 / 300 / 1):
                fvalid.write(f"{row}\n")
            else:
                ftrain.write(f"{row}\n")

            count += 1
            duration_seconds = count * args.duration
            duration_hours = duration_seconds / 3600

            process_bar.update(n=1)
            process_bar.set_postfix({
                # "duration_seconds": round(duration_seconds, 4),
                "duration_hours": round(duration_hours, 4),

            })

    return
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
# Script entry point.
if __name__ == "__main__":
    main()
|
examples/conv_tasnet/step_2_train_model.py
ADDED
|
@@ -0,0 +1,509 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
https://github.com/kaituoxu/Conv-TasNet/tree/master/src
|
| 5 |
+
|
| 6 |
+
一般场景:
|
| 7 |
+
|
| 8 |
+
目标 SI-SNR ≥ 10 dB,适用于电话通信、基础语音助手等。
|
| 9 |
+
|
| 10 |
+
高要求场景(如医疗助听、语音识别):
|
| 11 |
+
需 SI-SNR ≥ 14 dB,并配合 PESQ ≥ 3.0 和 STOI ≥ 0.851812。
|
| 12 |
+
|
| 13 |
+
DeepFilterNet2 模型在 DNS4 数据集,超过500小时的音频上训练了 100 个 epoch。
|
| 14 |
+
https://arxiv.org/abs/2205.05474
|
| 15 |
+
|
| 16 |
+
"""
|
| 17 |
+
import argparse
|
| 18 |
+
import json
|
| 19 |
+
import logging
|
| 20 |
+
from logging.handlers import TimedRotatingFileHandler
|
| 21 |
+
import os
|
| 22 |
+
import platform
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
import random
|
| 25 |
+
import sys
|
| 26 |
+
import shutil
|
| 27 |
+
from typing import List
|
| 28 |
+
|
| 29 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
| 30 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
| 31 |
+
|
| 32 |
+
import numpy as np
|
| 33 |
+
import torch
|
| 34 |
+
import torch.nn as nn
|
| 35 |
+
from torch.nn import functional as F
|
| 36 |
+
from torch.utils.data.dataloader import DataLoader
|
| 37 |
+
from tqdm import tqdm
|
| 38 |
+
|
| 39 |
+
from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
|
| 40 |
+
from toolbox.torchaudio.models.conv_tasnet.configuration_conv_tasnet import ConvTasNetConfig
|
| 41 |
+
from toolbox.torchaudio.models.conv_tasnet.modeling_conv_tasnet import ConvTasNet, ConvTasNetPretrainedModel
|
| 42 |
+
from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
|
| 43 |
+
from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
|
| 44 |
+
from toolbox.torchaudio.losses.perceptual import NegSTOILoss, PesqLoss
|
| 45 |
+
from toolbox.torchaudio.metrics.pesq import run_pesq_score
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def get_args():
    """Build and parse the command-line arguments for the training script."""
    arg_parser = argparse.ArgumentParser()

    # Dataset manifests.
    arg_parser.add_argument("--train_dataset", default="train.xlsx", type=str)
    arg_parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)

    # Training schedule.
    arg_parser.add_argument("--max_epochs", default=200, type=int)

    # Batch size, checkpoint retention, early stopping, and output location.
    arg_parser.add_argument("--batch_size", default=8, type=int)
    arg_parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
    arg_parser.add_argument("--patience", default=5, type=int)
    arg_parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
    arg_parser.add_argument("--seed", default=1234, type=int)

    # Model configuration file (ConvTasNetConfig yaml).
    arg_parser.add_argument("--config_file", default="config.yaml", type=str)

    return arg_parser.parse_args()
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def logging_config(file_dir: str):
    """Configure console logging and attach a daily-rotating file handler.

    The file handler writes to ``<file_dir>/main.log``, rotating once a day
    and keeping 7 backups. Returns this module's logger.
    """
    log_format = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    logging.basicConfig(
        format=log_format,
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    rotating_handler = TimedRotatingFileHandler(
        filename=os.path.join(file_dir, "main.log"),
        encoding="utf-8",
        when="D",
        interval=1,
        backupCount=7,
    )
    rotating_handler.setLevel(logging.INFO)
    rotating_handler.setFormatter(logging.Formatter(log_format))

    logger = logging.getLogger(__name__)
    logger.addHandler(rotating_handler)
    return logger
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class CollateFunction(object):
    """DataLoader collator: stacks per-sample clean/noisy waveforms into batches.

    Each sample dict must carry ``speech_wave`` (clean) and ``mix_wave``
    (noisy) tensors of equal length. Raises AssertionError if either stacked
    batch contains NaN or Inf.
    """

    def __init__(self):
        pass

    def __call__(self, batch: List[dict]):
        # sample may also carry "noise_wave" / "snr_db"; only the two
        # waveforms are used here.
        clean_batch = torch.stack([sample["speech_wave"] for sample in batch])
        noisy_batch = torch.stack([sample["mix_wave"] for sample in batch])

        # Fail fast on corrupted input so bad data never reaches the model.
        for tensor_name, tensor in (("clean_audios", clean_batch), ("noisy_audios", noisy_batch)):
            if torch.any(torch.isnan(tensor)) or torch.any(torch.isinf(tensor)):
                raise AssertionError(f"nan or inf in {tensor_name}")

        return clean_batch, noisy_batch
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# Module-level collator instance shared by the train/valid DataLoaders in main().
collate_fn = CollateFunction()
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def main():
    """Train a Conv-TasNet speech-denoising model.

    Loads a ConvTasNetConfig from --config_file, builds jsonl-backed
    clean/noisy datasets, trains with a weighted combination of L1,
    negative SI-SNR, multi-resolution STFT, negative STOI, and PESQ losses,
    evaluates every ``config.eval_steps`` steps, keeps a rolling window of
    checkpoints, tracks the best PESQ score, and early-stops after
    --patience evaluations without improvement.
    """
    args = get_args()

    config = ConvTasNetConfig.from_pretrained(
        pretrained_model_name_or_path=args.config_file,
    )

    serialization_dir = Path(args.serialization_dir)
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config(serialization_dir)

    # Seed all RNG sources for reproducibility.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    logger.info(f"set seed: {args.seed}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info(f"GPU available count: {n_gpu}; device: {device}")

    # datasets
    train_dataset = DenoiseJsonlDataset(
        jsonl_file=args.train_dataset,
        expected_sample_rate=config.sample_rate,
        max_wave_value=32768.0,
        min_snr_db=config.min_snr_db,
        max_snr_db=config.max_snr_db,
        # skip=225000,
    )
    valid_dataset = DenoiseJsonlDataset(
        jsonl_file=args.valid_dataset,
        expected_sample_rate=config.sample_rate,
        max_wave_value=32768.0,
        min_snr_db=config.min_snr_db,
        max_snr_db=config.max_snr_db,
    )
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        # shuffle=True,
        sampler=None,
        # Multiple worker subprocesses can be used on Linux but not on Windows.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        prefetch_factor=2,
    )
    valid_data_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=args.batch_size,
        # shuffle=True,
        sampler=None,
        # Multiple worker subprocesses can be used on Linux but not on Windows.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        prefetch_factor=2,
    )

    # models
    logger.info(f"prepare models. config_file: {args.config_file}")
    model = ConvTasNetPretrainedModel(config).to(device)
    model.to(device)
    model.train()

    # optimizer
    logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
    optimizer = torch.optim.AdamW(model.parameters(), config.lr)

    # resume training: find the highest-numbered "steps-*" checkpoint dir.
    last_step_idx = -1
    last_epoch = -1
    for step_idx_str in serialization_dir.glob("steps-*"):
        step_idx_str = Path(step_idx_str)
        step_idx = step_idx_str.stem.split("-")[1]
        step_idx = int(step_idx)
        if step_idx > last_step_idx:
            last_step_idx = step_idx
            # NOTE(review): last_epoch is set to the constant 1 rather than the
            # epoch the checkpoint was saved at, so a resumed run always
            # restarts the epoch loop at epoch 2 — confirm this is intended.
            last_epoch = 1

    if last_step_idx != -1:
        logger.info(f"resume from steps-{last_step_idx}.")
        model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
        optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth"

        logger.info(f"load state dict for model.")
        with open(model_pt.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)

        logger.info(f"load state dict for optimizer.")
        with open(optimizer_pth.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        optimizer.load_state_dict(state_dict)

    if config.lr_scheduler == "CosineAnnealingLR":
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            last_epoch=last_epoch,
            # T_max=10 * config.eval_steps,
            # eta_min=0.01 * config.lr,
            **config.lr_scheduler_kwargs,
        )
    elif config.lr_scheduler == "MultiStepLR":
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            last_epoch=last_epoch,
            milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
        )
    else:
        raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")

    # Loss functions; the final training loss is a weighted sum of these.
    ae_loss_fn = nn.L1Loss(reduction="mean").to(device)
    neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
    neg_stoi_loss_fn = NegSTOILoss(sample_rate=config.sample_rate, reduction="mean").to(device)
    mr_stft_loss_fn = MultiResolutionSTFTLoss(
        fft_size_list=[256, 512, 1024],
        win_size_list=[120, 240, 480],
        hop_size_list=[25, 50, 100],
        factor_sc=1.5,
        factor_mag=1.0,
        reduction="mean"
    ).to(device)
    pesq_loss_fn = PesqLoss(0.5, sample_rate=config.sample_rate).to(device)

    # training loop

    # state: sentinel "infinity" initial values; overwritten after the first batch.
    average_pesq_score = 1000000000
    average_loss = 1000000000
    average_ae_loss = 1000000000
    average_neg_si_snr_loss = 1000000000
    average_neg_stoi_loss = 1000000000

    model_list = list()
    best_epoch_idx = None
    best_step_idx = None
    best_metric = None
    patience_count = 0

    step_idx = 0 if last_step_idx == -1 else last_step_idx

    logger.info("training")
    for epoch_idx in range(max(0, last_epoch+1), args.max_epochs):
        # train
        model.train()

        # Running totals backing the progress-bar averages; reset each epoch
        # and after each evaluation.
        total_pesq_score = 0.
        total_loss = 0.
        total_ae_loss = 0.
        total_neg_si_snr_loss = 0.
        total_neg_stoi_loss = 0.
        total_mr_stft_loss = 0.
        total_pesq_loss = 0.
        total_batches = 0.

        progress_bar_train = tqdm(
            initial=step_idx,
            desc="Training; epoch-{}".format(epoch_idx),
        )
        for train_batch in train_data_loader:
            clean_audios, noisy_audios = train_batch
            clean_audios: torch.Tensor = clean_audios.to(device)
            noisy_audios: torch.Tensor = noisy_audios.to(device)

            denoise_audios = model.forward(noisy_audios)
            denoise_audios = torch.squeeze(denoise_audios, dim=1)

            if torch.any(torch.isnan(denoise_audios)) or torch.any(torch.isinf(denoise_audios)):
                raise AssertionError("nan or inf in denoise_audios")

            ae_loss = ae_loss_fn.forward(denoise_audios, clean_audios)
            neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
            neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
            mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
            pesq_loss = pesq_loss_fn.forward(denoise_audios, clean_audios)

            # Earlier loss-weight combinations kept for reference:
            # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss
            # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
            # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.5 * mr_stft_loss + 0.3 * neg_stoi_loss
            # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss
            # loss = 2.0 * mr_stft_loss + 0.8 * ae_loss + 0.7 * neg_si_snr_loss + 0.5 * neg_stoi_loss
            # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss + 0.5 * pesq_loss
            loss = 0.1 * ae_loss + 0.1 * neg_si_snr_loss + 1.0 * mr_stft_loss + 0.2 * neg_stoi_loss + 0.2 * pesq_loss
            # Skip the update entirely (no backward/step) on a degenerate loss.
            if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                logger.info(f"find nan or inf in loss.")
                continue

            # Reference PESQ (narrow-band) computed on CPU for monitoring.
            denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
            clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
            pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()

            total_pesq_score += pesq_score
            total_loss += loss.item()
            total_ae_loss += ae_loss.item()
            total_neg_si_snr_loss += neg_si_snr_loss.item()
            total_neg_stoi_loss += neg_stoi_loss.item()
            total_mr_stft_loss += mr_stft_loss.item()
            total_pesq_loss += pesq_loss.item()
            total_batches += 1

            average_pesq_score = round(total_pesq_score / total_batches, 4)
            average_loss = round(total_loss / total_batches, 4)
            average_ae_loss = round(total_ae_loss / total_batches, 4)
            average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
            average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
            average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
            average_pesq_loss = round(total_pesq_loss / total_batches, 4)

            progress_bar_train.update(1)
            progress_bar_train.set_postfix({
                "lr": lr_scheduler.get_last_lr()[0],
                "pesq_score": average_pesq_score,
                "loss": average_loss,
                "ae_loss": average_ae_loss,
                "neg_si_snr_loss": average_neg_si_snr_loss,
                "neg_stoi_loss": average_neg_stoi_loss,
                "mr_stft_loss": average_mr_stft_loss,
                "pesq_loss": average_pesq_loss,
            })

            # evaluation: every config.eval_steps optimizer steps.
            step_idx += 1
            if step_idx % config.eval_steps == 0:
                model.eval()
                with torch.no_grad():
                    torch.cuda.empty_cache()

                    # Reset running totals for the validation pass.
                    total_pesq_score = 0.
                    total_loss = 0.
                    total_ae_loss = 0.
                    total_neg_si_snr_loss = 0.
                    total_neg_stoi_loss = 0.
                    total_mr_stft_loss = 0.
                    total_pesq_loss = 0.
                    total_batches = 0.

                    progress_bar_train.close()
                    progress_bar_eval = tqdm(
                        desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
                    )
                    for eval_batch in valid_data_loader:
                        clean_audios, noisy_audios = eval_batch
                        clean_audios = clean_audios.to(device)
                        noisy_audios = noisy_audios.to(device)

                        denoise_audios = model.forward(noisy_audios)
                        denoise_audios = torch.squeeze(denoise_audios, dim=1)

                        ae_loss = ae_loss_fn.forward(denoise_audios, clean_audios)
                        neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
                        neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
                        mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
                        pesq_loss = pesq_loss_fn.forward(denoise_audios, clean_audios)

                        # Same weighted combination as in training.
                        # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss
                        # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
                        # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.5 * mr_stft_loss + 0.3 * neg_stoi_loss
                        # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss
                        # loss = 2.0 * mr_stft_loss + 0.8 * ae_loss + 0.7 * neg_si_snr_loss + 0.5 * neg_stoi_loss
                        # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss + 0.5 * pesq_loss
                        loss = 0.1 * ae_loss + 0.1 * neg_si_snr_loss + 1.0 * mr_stft_loss + 0.2 * neg_stoi_loss + 0.2 * pesq_loss
                        if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                            logger.info(f"find nan or inf in loss.")
                            continue

                        denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
                        clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
                        pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")

                        total_pesq_score += pesq_score
                        total_loss += loss.item()
                        total_ae_loss += ae_loss.item()
                        total_neg_si_snr_loss += neg_si_snr_loss.item()
                        total_neg_stoi_loss += neg_stoi_loss.item()
                        total_mr_stft_loss += mr_stft_loss.item()
                        total_pesq_loss += pesq_loss.item()
                        total_batches += 1

                        average_pesq_score = round(total_pesq_score / total_batches, 4)
                        average_loss = round(total_loss / total_batches, 4)
                        average_ae_loss = round(total_ae_loss / total_batches, 4)
                        average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
                        average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
                        average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
                        average_pesq_loss = round(total_pesq_loss / total_batches, 4)

                        progress_bar_eval.update(1)
                        progress_bar_eval.set_postfix({
                            "lr": lr_scheduler.get_last_lr()[0],
                            "pesq_score": average_pesq_score,
                            "loss": average_loss,
                            "ae_loss": average_ae_loss,
                            "neg_si_snr_loss": average_neg_si_snr_loss,
                            "neg_stoi_loss": average_neg_stoi_loss,
                            "mr_stft_loss": average_mr_stft_loss,
                            "pesq_loss": average_pesq_loss,
                        })

                    # Reset totals again so the resumed training averages start fresh.
                    total_pesq_score = 0.
                    total_loss = 0.
                    total_ae_loss = 0.
                    total_neg_si_snr_loss = 0.
                    total_neg_stoi_loss = 0.
                    total_mr_stft_loss = 0.
                    total_pesq_loss = 0.
                    total_batches = 0.

                    progress_bar_eval.close()
                    progress_bar_train = tqdm(
                        initial=progress_bar_train.n,
                        postfix=progress_bar_train.postfix,
                        desc=progress_bar_train.desc,
                    )

                    # save path
                    save_dir = serialization_dir / "steps-{}".format(step_idx)
                    save_dir.mkdir(parents=True, exist_ok=False)

                    # save models
                    model.save_pretrained(save_dir.as_posix())

                    # Rolling checkpoint window.
                    # NOTE(review): deleting when len >= keep means at most
                    # keep-1 checkpoints remain on disk — confirm intended.
                    model_list.append(save_dir)
                    if len(model_list) >= args.num_serialized_models_to_keep:
                        model_to_delete: Path = model_list.pop(0)
                        shutil.rmtree(model_to_delete.as_posix())

                    # save optim
                    torch.save(optimizer.state_dict(), (save_dir / "optimizer.pth").as_posix())

                    # save metric: best model is the one with the highest
                    # validation PESQ average.
                    if best_metric is None:
                        best_epoch_idx = epoch_idx
                        best_step_idx = step_idx
                        best_metric = average_pesq_score
                    elif average_pesq_score > best_metric:
                        # great is better.
                        best_epoch_idx = epoch_idx
                        best_step_idx = step_idx
                        best_metric = average_pesq_score
                    else:
                        pass

                    metrics = {
                        "epoch_idx": epoch_idx,
                        "best_epoch_idx": best_epoch_idx,
                        "best_step_idx": best_step_idx,
                        "pesq_score": average_pesq_score,
                        "loss": average_loss,
                        "ae_loss": average_ae_loss,
                        "neg_si_snr_loss": average_neg_si_snr_loss,
                        "neg_stoi_loss": average_neg_stoi_loss,
                        "mr_stft_loss": average_mr_stft_loss,
                        "pesq_loss": average_pesq_loss,
                    }
                    metrics_filename = save_dir / "metrics_epoch.json"
                    with open(metrics_filename, "w", encoding="utf-8") as f:
                        json.dump(metrics, f, indent=4, ensure_ascii=False)

                    # save best: mirror the current checkpoint into "best/".
                    best_dir = serialization_dir / "best"
                    if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
                        if best_dir.exists():
                            shutil.rmtree(best_dir)
                        shutil.copytree(save_dir, best_dir)

                    # early stop
                    early_stop_flag = False
                    if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
                        patience_count = 0
                    else:
                        patience_count += 1
                        if patience_count >= args.patience:
                            early_stop_flag = True

                    # early stop
                    # NOTE(review): this break only exits the inner batch loop;
                    # the epoch loop continues with the next epoch — confirm
                    # whether early stop should terminate training entirely.
                    if early_stop_flag:
                        break
                model.train()

    return


if __name__ == "__main__":
    main()
|
examples/conv_tasnet/yaml/config.yaml
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model_name: "conv_tasnet"
|
| 2 |
+
|
| 3 |
+
sample_rate: 8000
|
| 4 |
+
segment_size: 4
|
| 5 |
+
|
| 6 |
+
win_size: 20
|
| 7 |
+
freq_bins: 256
|
| 8 |
+
bottleneck_channels: 256
|
| 9 |
+
num_speakers: 1
|
| 10 |
+
num_blocks: 4
|
| 11 |
+
num_sub_blocks: 8
|
| 12 |
+
sub_blocks_channels: 512
|
| 13 |
+
sub_blocks_kernel_size: 3
|
| 14 |
+
|
| 15 |
+
norm_type: "gLN"
|
| 16 |
+
causal: false
|
| 17 |
+
mask_nonlinear: "relu"
|
| 18 |
+
|
| 19 |
+
min_snr_db: -10
|
| 20 |
+
max_snr_db: 20
|
| 21 |
+
|
| 22 |
+
lr: 0.005
|
| 23 |
+
lr_scheduler: "CosineAnnealingLR"
|
| 24 |
+
lr_scheduler_kwargs:
|
| 25 |
+
T_max: 250000
|
| 26 |
+
eta_min: 0.00005
|
| 27 |
+
|
| 28 |
+
eval_steps: 25000
|
examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_clean_emotional_speech.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
https://github.com/microsoft/DNS-Challenge/blob/master/download-dns-challenge-3.sh
|
| 5 |
+
|
| 6 |
+
1.2G
|
| 7 |
+
wget https://dns3public.blob.core.windows.net/dns3archive/DEMAND.tar.bz2
|
| 8 |
+
|
| 9 |
+
14G
|
| 10 |
+
wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.noise.tar.bz2
|
| 11 |
+
|
| 12 |
+
38G
|
| 13 |
+
wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.read_speech.tar.bz2
|
| 14 |
+
|
| 15 |
+
247M
|
| 16 |
+
wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.emotional_speech.tar.bz2
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
"""
|
| 20 |
+
import argparse
|
| 21 |
+
import os
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
import sys
|
| 24 |
+
|
| 25 |
+
import numpy as np
|
| 26 |
+
from tqdm import tqdm
|
| 27 |
+
|
| 28 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
| 29 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
| 30 |
+
|
| 31 |
+
import librosa
|
| 32 |
+
from scipy.io import wavfile
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def get_args():
    """Parse command-line options for the emotional-speech resampling script."""
    arg_parser = argparse.ArgumentParser()

    # Source DNS-Challenge clean emotional-speech directory.
    arg_parser.add_argument(
        "--data_dir",
        default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.emotional_speech\datasets\clean\emotional_speech",
        type=str,
    )
    # Destination directory for the resampled 16-bit wavs.
    arg_parser.add_argument(
        "--output_dir",
        default=r"E:\programmer\asr_datasets\denoise\dns-clean-emotional-speech-8k",
        type=str,
    )
    arg_parser.add_argument("--sample_rate", default=8000, type=int)

    return arg_parser.parse_args()
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def main():
    """Resample every wav under --data_dir into --output_dir as 16-bit PCM.

    Output files are grouped by their immediate parent directory name and the
    run is resumable: file stems already present in --output_dir are skipped.
    """
    args = get_args()

    data_dir = Path(args.data_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # finished_set: stems already written to output_dir, so reruns skip them.
    # NOTE(review): stems are compared without their parent directory, so two
    # source files with the same name in different folders would collide.
    finished_set = set()
    for filename in tqdm(output_dir.glob("**/*.wav")):
        name = filename.stem
        finished_set.add(name)
    print(f"finished_set count: {len(finished_set)}")

    for filename in tqdm(data_dir.glob("**/*.wav")):
        label = filename.parts[-2]  # immediate parent directory name
        name = filename.stem
        # print(f"filename: {filename.as_posix()}")
        if name in finished_set:
            continue

        # librosa returns float32 samples in [-1, 1] at the requested rate.
        signal, _ = librosa.load(filename.as_posix(), sr=args.sample_rate)

        # Scale to 16-bit PCM range.
        # NOTE(review): samples at exactly +/-1.0 map to +/-32768, which wraps
        # around in int16 — confirm inputs never reach full scale, or clip.
        signal = signal * (1 << 15)
        signal = np.array(signal, dtype=np.int16)

        to_file = output_dir / f"{label}/{name}.wav"
        to_file.parent.mkdir(parents=True, exist_ok=True)
        wavfile.write(
            to_file.as_posix(),
            rate=args.sample_rate,
            data=signal,
        )
    return


if __name__ == "__main__":
    main()
|
examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_clean_read_speech.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
https://github.com/microsoft/DNS-Challenge/blob/master/download-dns-challenge-3.sh
|
| 5 |
+
|
| 6 |
+
1.2G
|
| 7 |
+
wget https://dns3public.blob.core.windows.net/dns3archive/DEMAND.tar.bz2
|
| 8 |
+
|
| 9 |
+
14G
|
| 10 |
+
wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.noise.tar.bz2
|
| 11 |
+
|
| 12 |
+
38G
|
| 13 |
+
wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.read_speech.tar.bz2
|
| 14 |
+
|
| 15 |
+
12G
|
| 16 |
+
wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.french_data.tar.bz2
|
| 17 |
+
|
| 18 |
+
43G
|
| 19 |
+
wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.german_speech.tar.bz2
|
| 20 |
+
|
| 21 |
+
7.9G
|
| 22 |
+
wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.italian_speech.tar.bz2
|
| 23 |
+
|
| 24 |
+
12G
|
| 25 |
+
wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.mandarin_speech.tar.bz2
|
| 26 |
+
|
| 27 |
+
3.1G
|
| 28 |
+
wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.russian_speech.tar.bz2
|
| 29 |
+
|
| 30 |
+
9.7G
|
| 31 |
+
wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.spanish_speech.tar.bz2
|
| 32 |
+
|
| 33 |
+
617M
|
| 34 |
+
wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.singing_voice.tar.bz2
|
| 35 |
+
|
| 36 |
+
"""
|
| 37 |
+
import argparse
|
| 38 |
+
import os
|
| 39 |
+
from pathlib import Path
|
| 40 |
+
import sys
|
| 41 |
+
|
| 42 |
+
import numpy as np
|
| 43 |
+
from tqdm import tqdm
|
| 44 |
+
|
| 45 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
| 46 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
| 47 |
+
|
| 48 |
+
import librosa
|
| 49 |
+
from scipy.io import wavfile
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def get_args():
    """Parse command-line options for the read-speech resampling script.

    The commented-out defaults are the other DNS-Challenge language subsets
    this script has been run over; switch the active default (or pass
    --data_dir / --output_dir explicitly) to process a different subset.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--data_dir",
        # default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.read_speech\datasets\clean",
        # default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.mandarin_speech\datasets\clean\mandarin_speech",
        # default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.singing_voice\datasets\clean\singing_voice",
        # default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.french_data\datasets\clean\french_data",
        default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.german_speech\datasets\clean\german_speech",
        # default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.italian_speech\datasets\clean\italian_speech",
        # default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.russian_speech\datasets\clean\russian_speech",
        # default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.spanish_speech\datasets\clean\spanish_speech",
        type=str
    )
    parser.add_argument(
        "--output_dir",
        # default=r"E:\programmer\asr_datasets\denoise\dns-clean-read-speech-8k",
        # default=r"E:\programmer\asr_datasets\denoise\dns-clean-mandarin-speech-8k",
        # default=r"E:\programmer\asr_datasets\denoise\dns-clean-singing-voice-8k",
        # default=r"E:\programmer\asr_datasets\denoise\dns-clean-french-speech-8k",
        default=r"E:\programmer\asr_datasets\denoise\dns-clean-german-speech-8k",
        # default=r"E:\programmer\asr_datasets\denoise\dns-clean-italian-speech-8k",
        # default=r"E:\programmer\asr_datasets\denoise\dns-clean-russian-speech-8k",
        # default=r"E:\programmer\asr_datasets\denoise\dns-clean-spanish-speech-8k",
        type=str
    )
    parser.add_argument("--sample_rate", default=8000, type=int)
    args = parser.parse_args()
    return args
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def main():
    """Resample wavs under --data_dir into --output_dir as 16-bit PCM.

    The directory structure is mirrored, and the run is resumable: source
    files whose output-relative path already exists in --output_dir are
    skipped. Files that librosa cannot load are skipped with a message.
    """
    args = get_args()

    data_dir = Path(args.data_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # finished_set: output-relative posix paths already produced.
    finished_set = set()
    for filename in tqdm(output_dir.glob("**/*.wav")):
        filename = Path(filename)
        relative_name = filename.relative_to(output_dir)
        relative_name_ = relative_name.as_posix()
        finished_set.add(relative_name_)
    print(f"finished_set count: {len(finished_set)}")

    for filename in tqdm(data_dir.glob("**/*.wav")):
        relative_name = filename.relative_to(data_dir)
        relative_name_ = relative_name.as_posix()
        if relative_name_ in finished_set:
            continue
        finished_set.add(relative_name_)

        try:
            # mono=False keeps the channel layout as stored on disk.
            signal, _ = librosa.load(filename.as_posix(), mono=False, sr=args.sample_rate)
        except Exception:
            # Unreadable/corrupt files are skipped rather than aborting the run.
            print(f"skip file: {filename.as_posix()}")
            continue
        # NOTE(review): with mono=False a multi-channel file loads as a 2-D
        # array, so this aborts on the first stereo input instead of
        # downmixing — confirm all inputs are mono.
        if signal.ndim != 1:
            raise AssertionError

        # Scale float32 [-1, 1] to 16-bit PCM range.
        # NOTE(review): samples at exactly +/-1.0 wrap around in int16 —
        # confirm inputs never reach full scale, or clip before casting.
        signal = signal * (1 << 15)
        signal = np.array(signal, dtype=np.int16)

        to_file = output_dir / relative_name.as_posix()
        to_file.parent.mkdir(parents=True, exist_ok=True)
        wavfile.write(
            to_file.as_posix(),
            rate=args.sample_rate,
            data=signal,
        )
    return


if __name__ == "__main__":
    main()
|
examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_demand.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
https://github.com/microsoft/DNS-Challenge/blob/master/download-dns-challenge-3.sh
|
| 5 |
+
|
| 6 |
+
1.2G
|
| 7 |
+
wget https://dns3public.blob.core.windows.net/dns3archive/DEMAND.tar.bz2
|
| 8 |
+
|
| 9 |
+
"""
|
| 10 |
+
import argparse
|
| 11 |
+
import os
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
import random
|
| 14 |
+
import sys
|
| 15 |
+
import shutil
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
|
| 19 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
| 20 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
| 21 |
+
|
| 22 |
+
import librosa
|
| 23 |
+
from scipy.io import wavfile
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def get_args():
|
| 27 |
+
parser = argparse.ArgumentParser()
|
| 28 |
+
|
| 29 |
+
parser.add_argument(
|
| 30 |
+
"--data_dir",
|
| 31 |
+
default=r"E:\programmer\asr_datasets\dns-challenge\DEMAND\demand",
|
| 32 |
+
type=str
|
| 33 |
+
)
|
| 34 |
+
parser.add_argument(
|
| 35 |
+
"--output_dir",
|
| 36 |
+
default=r"E:\programmer\asr_datasets\denoise\demand-8k",
|
| 37 |
+
type=str
|
| 38 |
+
)
|
| 39 |
+
parser.add_argument("--sample_rate", default=8000, type=int)
|
| 40 |
+
args = parser.parse_args()
|
| 41 |
+
return args
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def main():
|
| 45 |
+
args = get_args()
|
| 46 |
+
|
| 47 |
+
data_dir = Path(args.data_dir)
|
| 48 |
+
output_dir = Path(args.output_dir)
|
| 49 |
+
output_dir.mkdir(parents=True, exist_ok=False)
|
| 50 |
+
|
| 51 |
+
for filename in data_dir.glob("**/ch01.wav"):
|
| 52 |
+
label = filename.parts[-2]
|
| 53 |
+
name = filename.stem
|
| 54 |
+
|
| 55 |
+
signal, _ = librosa.load(filename.as_posix(), sr=args.sample_rate)
|
| 56 |
+
|
| 57 |
+
signal = signal * (1 << 15)
|
| 58 |
+
signal = np.array(signal, dtype=np.int16)
|
| 59 |
+
|
| 60 |
+
to_file = output_dir / f"{label}/{name}.wav"
|
| 61 |
+
to_file.parent.mkdir(parents=True, exist_ok=True)
|
| 62 |
+
wavfile.write(
|
| 63 |
+
to_file.as_posix(),
|
| 64 |
+
rate=args.sample_rate,
|
| 65 |
+
data=signal,
|
| 66 |
+
)
|
| 67 |
+
return
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
if __name__ == '__main__':
|
| 71 |
+
main()
|
examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_impulse_responses.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
https://github.com/microsoft/DNS-Challenge/blob/master/download-dns-challenge-3.sh
|
| 5 |
+
|
| 6 |
+
1.2G
|
| 7 |
+
wget https://dns3public.blob.core.windows.net/dns3archive/DEMAND.tar.bz2
|
| 8 |
+
|
| 9 |
+
14G
|
| 10 |
+
wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.noise.tar.bz2
|
| 11 |
+
|
| 12 |
+
38G
|
| 13 |
+
wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.read_speech.tar.bz2
|
| 14 |
+
|
| 15 |
+
247M
|
| 16 |
+
wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.emotional_speech.tar.bz2
|
| 17 |
+
|
| 18 |
+
240M
|
| 19 |
+
wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.impulse_responses.tar.bz2
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
"""
|
| 23 |
+
import argparse
|
| 24 |
+
import os
|
| 25 |
+
from pathlib import Path
|
| 26 |
+
import sys
|
| 27 |
+
|
| 28 |
+
import numpy as np
|
| 29 |
+
from tqdm import tqdm
|
| 30 |
+
|
| 31 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
| 32 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
| 33 |
+
|
| 34 |
+
import librosa
|
| 35 |
+
from scipy.io import wavfile
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_args():
|
| 39 |
+
parser = argparse.ArgumentParser()
|
| 40 |
+
|
| 41 |
+
parser.add_argument(
|
| 42 |
+
"--data_dir",
|
| 43 |
+
default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.emotional_speech\datasets\clean\emotional_speech",
|
| 44 |
+
type=str
|
| 45 |
+
)
|
| 46 |
+
parser.add_argument(
|
| 47 |
+
"--output_dir",
|
| 48 |
+
default=r"E:\programmer\asr_datasets\denoise\dns-clean-emotional-speech-8k",
|
| 49 |
+
type=str
|
| 50 |
+
)
|
| 51 |
+
parser.add_argument("--sample_rate", default=8000, type=int)
|
| 52 |
+
args = parser.parse_args()
|
| 53 |
+
return args
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def main():
|
| 57 |
+
args = get_args()
|
| 58 |
+
|
| 59 |
+
data_dir = Path(args.data_dir)
|
| 60 |
+
output_dir = Path(args.output_dir)
|
| 61 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 62 |
+
|
| 63 |
+
# finished_set
|
| 64 |
+
finished_set = set()
|
| 65 |
+
for filename in tqdm(output_dir.glob("**/*.wav")):
|
| 66 |
+
name = filename.stem
|
| 67 |
+
finished_set.add(name)
|
| 68 |
+
print(f"finished_set count: {len(finished_set)}")
|
| 69 |
+
|
| 70 |
+
for filename in tqdm(data_dir.glob("**/*.wav")):
|
| 71 |
+
label = filename.parts[-2]
|
| 72 |
+
name = filename.stem
|
| 73 |
+
# print(f"filename: {filename.as_posix()}")
|
| 74 |
+
if name in finished_set:
|
| 75 |
+
continue
|
| 76 |
+
|
| 77 |
+
signal, _ = librosa.load(filename.as_posix(), sr=args.sample_rate)
|
| 78 |
+
|
| 79 |
+
signal = signal * (1 << 15)
|
| 80 |
+
signal = np.array(signal, dtype=np.int16)
|
| 81 |
+
|
| 82 |
+
to_file = output_dir / f"{label}/{name}.wav"
|
| 83 |
+
to_file.parent.mkdir(parents=True, exist_ok=True)
|
| 84 |
+
wavfile.write(
|
| 85 |
+
to_file.as_posix(),
|
| 86 |
+
rate=args.sample_rate,
|
| 87 |
+
data=signal,
|
| 88 |
+
)
|
| 89 |
+
return
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
if __name__ == "__main__":
|
| 93 |
+
main()
|
examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_noise.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
https://github.com/microsoft/DNS-Challenge/blob/master/download-dns-challenge-3.sh
|
| 5 |
+
|
| 6 |
+
1.2G
|
| 7 |
+
wget https://dns3public.blob.core.windows.net/dns3archive/DEMAND.tar.bz2
|
| 8 |
+
|
| 9 |
+
14G
|
| 10 |
+
wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.noise.tar.bz2
|
| 11 |
+
|
| 12 |
+
38G
|
| 13 |
+
wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.read_speech.tar.bz2
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
import argparse
|
| 17 |
+
import os
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
import sys
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
from tqdm import tqdm
|
| 23 |
+
|
| 24 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
| 25 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
| 26 |
+
|
| 27 |
+
import librosa
|
| 28 |
+
from scipy.io import wavfile
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def get_args():
|
| 32 |
+
parser = argparse.ArgumentParser()
|
| 33 |
+
|
| 34 |
+
parser.add_argument(
|
| 35 |
+
"--data_dir",
|
| 36 |
+
default=r"E:\programmer\asr_datasets\dns-challenge\datasets.noise\datasets",
|
| 37 |
+
type=str
|
| 38 |
+
)
|
| 39 |
+
parser.add_argument(
|
| 40 |
+
"--output_dir",
|
| 41 |
+
default=r"E:\programmer\asr_datasets\denoise\dns-noise-8k",
|
| 42 |
+
type=str
|
| 43 |
+
)
|
| 44 |
+
parser.add_argument("--sample_rate", default=8000, type=int)
|
| 45 |
+
args = parser.parse_args()
|
| 46 |
+
return args
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def main():
|
| 50 |
+
args = get_args()
|
| 51 |
+
|
| 52 |
+
data_dir = Path(args.data_dir)
|
| 53 |
+
output_dir = Path(args.output_dir)
|
| 54 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 55 |
+
|
| 56 |
+
for filename in tqdm(data_dir.glob("**/*.wav")):
|
| 57 |
+
label = filename.parts[-2]
|
| 58 |
+
name = filename.stem
|
| 59 |
+
# print(f"filename: {filename.as_posix()}")
|
| 60 |
+
|
| 61 |
+
signal, _ = librosa.load(filename.as_posix(), sr=args.sample_rate)
|
| 62 |
+
|
| 63 |
+
signal = signal * (1 << 15)
|
| 64 |
+
signal = np.array(signal, dtype=np.int16)
|
| 65 |
+
|
| 66 |
+
to_file = output_dir / f"{label}/{name}.wav"
|
| 67 |
+
to_file.parent.mkdir(parents=True, exist_ok=True)
|
| 68 |
+
wavfile.write(
|
| 69 |
+
to_file.as_posix(),
|
| 70 |
+
rate=args.sample_rate,
|
| 71 |
+
data=signal,
|
| 72 |
+
)
|
| 73 |
+
return
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
if __name__ == '__main__':
|
| 77 |
+
main()
|
examples/data_preprocess/dns_challenge_to_8k/process_musan.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
https://www.openslr.org/17/
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
if __name__ == '__main__':
|
| 8 |
+
pass
|
examples/data_preprocess/ms_snsd_to_8k/process_ms_snsd.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
MS-SNSD
|
| 5 |
+
https://github.com/microsoft/MS-SNSD
|
| 6 |
+
"""
|
| 7 |
+
import argparse
|
| 8 |
+
import os
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
import sys
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
| 16 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
| 17 |
+
|
| 18 |
+
import librosa
|
| 19 |
+
from scipy.io import wavfile
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def get_args():
|
| 23 |
+
parser = argparse.ArgumentParser()
|
| 24 |
+
|
| 25 |
+
parser.add_argument(
|
| 26 |
+
"--data_dir",
|
| 27 |
+
default=r"E:\programmer\asr_datasets\MS-SNSD",
|
| 28 |
+
type=str
|
| 29 |
+
)
|
| 30 |
+
parser.add_argument(
|
| 31 |
+
"--output_dir",
|
| 32 |
+
default=r"E:\programmer\asr_datasets\denoise\ms-snsd-noise-8k",
|
| 33 |
+
type=str
|
| 34 |
+
)
|
| 35 |
+
parser.add_argument("--sample_rate", default=8000, type=int)
|
| 36 |
+
args = parser.parse_args()
|
| 37 |
+
return args
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def main():
|
| 41 |
+
args = get_args()
|
| 42 |
+
|
| 43 |
+
data_dir = Path(args.data_dir)
|
| 44 |
+
output_dir = Path(args.output_dir)
|
| 45 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 46 |
+
|
| 47 |
+
for filename in tqdm(data_dir.glob("**/*.wav")):
|
| 48 |
+
label = filename.parts[-2]
|
| 49 |
+
name = filename.stem
|
| 50 |
+
|
| 51 |
+
if label not in ["noise_train", "noise_test", "clean_train", "clean_test"]:
|
| 52 |
+
continue
|
| 53 |
+
|
| 54 |
+
signal, _ = librosa.load(filename.as_posix(), sr=args.sample_rate)
|
| 55 |
+
|
| 56 |
+
signal = signal * (1 << 15)
|
| 57 |
+
signal = np.array(signal, dtype=np.int16)
|
| 58 |
+
|
| 59 |
+
to_file = output_dir / f"{label}/{name}.wav"
|
| 60 |
+
to_file.parent.mkdir(parents=True, exist_ok=True)
|
| 61 |
+
wavfile.write(
|
| 62 |
+
to_file.as_posix(),
|
| 63 |
+
rate=args.sample_rate,
|
| 64 |
+
data=signal,
|
| 65 |
+
)
|
| 66 |
+
return
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
if __name__ == "__main__":
|
| 70 |
+
main()
|
examples/dfnet/run.sh
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
|
| 3 |
+
: <<'END'
|
| 4 |
+
|
| 5 |
+
sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir --final_model_name dfnet-nx-speech \
|
| 6 |
+
--noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \
|
| 7 |
+
--speech_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/speech"
|
| 8 |
+
|
| 9 |
+
sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name dfnet-nx-dns3 \
|
| 10 |
+
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
| 11 |
+
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
END
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# params
|
| 18 |
+
system_version="windows";
|
| 19 |
+
verbose=true;
|
| 20 |
+
stage=0 # start from 0 if you need to start from data preparation
|
| 21 |
+
stop_stage=9
|
| 22 |
+
|
| 23 |
+
work_dir="$(pwd)"
|
| 24 |
+
file_folder_name=file_folder_name
|
| 25 |
+
final_model_name=final_model_name
|
| 26 |
+
config_file="yaml/config.yaml"
|
| 27 |
+
limit=10
|
| 28 |
+
|
| 29 |
+
noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
|
| 30 |
+
speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
|
| 31 |
+
|
| 32 |
+
max_count=10000000
|
| 33 |
+
|
| 34 |
+
nohup_name=nohup.out
|
| 35 |
+
|
| 36 |
+
# model params
|
| 37 |
+
batch_size=64
|
| 38 |
+
max_epochs=200
|
| 39 |
+
save_top_k=10
|
| 40 |
+
patience=5
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# parse options
|
| 44 |
+
while true; do
|
| 45 |
+
[ -z "${1:-}" ] && break; # break if there are no arguments
|
| 46 |
+
case "$1" in
|
| 47 |
+
--*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
|
| 48 |
+
eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
|
| 49 |
+
old_value="(eval echo \\$$name)";
|
| 50 |
+
if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
|
| 51 |
+
was_bool=true;
|
| 52 |
+
else
|
| 53 |
+
was_bool=false;
|
| 54 |
+
fi
|
| 55 |
+
|
| 56 |
+
# Set the variable to the right value-- the escaped quotes make it work if
|
| 57 |
+
# the option had spaces, like --cmd "queue.pl -sync y"
|
| 58 |
+
eval "${name}=\"$2\"";
|
| 59 |
+
|
| 60 |
+
# Check that Boolean-valued arguments are really Boolean.
|
| 61 |
+
if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
|
| 62 |
+
echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
|
| 63 |
+
exit 1;
|
| 64 |
+
fi
|
| 65 |
+
shift 2;
|
| 66 |
+
;;
|
| 67 |
+
|
| 68 |
+
*) break;
|
| 69 |
+
esac
|
| 70 |
+
done
|
| 71 |
+
|
| 72 |
+
file_dir="${work_dir}/${file_folder_name}"
|
| 73 |
+
final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
|
| 74 |
+
evaluation_audio_dir="${file_dir}/evaluation_audio"
|
| 75 |
+
|
| 76 |
+
train_dataset="${file_dir}/train.jsonl"
|
| 77 |
+
valid_dataset="${file_dir}/valid.jsonl"
|
| 78 |
+
|
| 79 |
+
$verbose && echo "system_version: ${system_version}"
|
| 80 |
+
$verbose && echo "file_folder_name: ${file_folder_name}"
|
| 81 |
+
|
| 82 |
+
if [ $system_version == "windows" ]; then
|
| 83 |
+
alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
|
| 84 |
+
elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
|
| 85 |
+
#source /data/local/bin/nx_denoise/bin/activate
|
| 86 |
+
alias python3='/data/local/bin/nx_denoise/bin/python3'
|
| 87 |
+
fi
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
| 91 |
+
$verbose && echo "stage 1: prepare data"
|
| 92 |
+
cd "${work_dir}" || exit 1
|
| 93 |
+
python3 step_1_prepare_data.py \
|
| 94 |
+
--file_dir "${file_dir}" \
|
| 95 |
+
--noise_dir "${noise_dir}" \
|
| 96 |
+
--speech_dir "${speech_dir}" \
|
| 97 |
+
--train_dataset "${train_dataset}" \
|
| 98 |
+
--valid_dataset "${valid_dataset}" \
|
| 99 |
+
--max_count "${max_count}" \
|
| 100 |
+
|
| 101 |
+
fi
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
| 105 |
+
$verbose && echo "stage 2: train model"
|
| 106 |
+
cd "${work_dir}" || exit 1
|
| 107 |
+
python3 step_2_train_model.py \
|
| 108 |
+
--train_dataset "${train_dataset}" \
|
| 109 |
+
--valid_dataset "${valid_dataset}" \
|
| 110 |
+
--serialization_dir "${file_dir}" \
|
| 111 |
+
--config_file "${config_file}" \
|
| 112 |
+
|
| 113 |
+
fi
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
| 117 |
+
$verbose && echo "stage 3: test model"
|
| 118 |
+
cd "${work_dir}" || exit 1
|
| 119 |
+
python3 step_3_evaluation.py \
|
| 120 |
+
--valid_dataset "${valid_dataset}" \
|
| 121 |
+
--model_dir "${file_dir}/best" \
|
| 122 |
+
--evaluation_audio_dir "${evaluation_audio_dir}" \
|
| 123 |
+
--limit "${limit}" \
|
| 124 |
+
|
| 125 |
+
fi
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
| 129 |
+
$verbose && echo "stage 4: collect files"
|
| 130 |
+
cd "${work_dir}" || exit 1
|
| 131 |
+
|
| 132 |
+
mkdir -p ${final_model_dir}
|
| 133 |
+
|
| 134 |
+
cp "${file_dir}/best"/* "${final_model_dir}"
|
| 135 |
+
cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"
|
| 136 |
+
|
| 137 |
+
cd "${final_model_dir}/.." || exit 1;
|
| 138 |
+
|
| 139 |
+
if [ -e "${final_model_name}.zip" ]; then
|
| 140 |
+
rm -rf "${final_model_name}_backup.zip"
|
| 141 |
+
mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
|
| 142 |
+
fi
|
| 143 |
+
|
| 144 |
+
zip -r "${final_model_name}.zip" "${final_model_name}"
|
| 145 |
+
rm -rf "${final_model_name}"
|
| 146 |
+
|
| 147 |
+
fi
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
| 151 |
+
$verbose && echo "stage 5: clear file_dir"
|
| 152 |
+
cd "${work_dir}" || exit 1
|
| 153 |
+
|
| 154 |
+
rm -rf "${file_dir}";
|
| 155 |
+
|
| 156 |
+
fi
|
examples/dfnet/step_1_prepare_data.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
import random
|
| 8 |
+
import sys
|
| 9 |
+
|
| 10 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
| 11 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
| 12 |
+
|
| 13 |
+
import librosa
|
| 14 |
+
import numpy as np
|
| 15 |
+
from tqdm import tqdm
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_args():
|
| 19 |
+
parser = argparse.ArgumentParser()
|
| 20 |
+
parser.add_argument("--file_dir", default="./", type=str)
|
| 21 |
+
|
| 22 |
+
parser.add_argument(
|
| 23 |
+
"--noise_dir",
|
| 24 |
+
default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
|
| 25 |
+
type=str
|
| 26 |
+
)
|
| 27 |
+
parser.add_argument(
|
| 28 |
+
"--speech_dir",
|
| 29 |
+
default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
|
| 30 |
+
type=str
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
parser.add_argument("--train_dataset", default="train.jsonl", type=str)
|
| 34 |
+
parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
|
| 35 |
+
|
| 36 |
+
parser.add_argument("--duration", default=4.0, type=float)
|
| 37 |
+
parser.add_argument("--min_snr_db", default=-10, type=float)
|
| 38 |
+
parser.add_argument("--max_snr_db", default=20, type=float)
|
| 39 |
+
|
| 40 |
+
parser.add_argument("--target_sample_rate", default=8000, type=int)
|
| 41 |
+
|
| 42 |
+
parser.add_argument("--max_count", default=10000, type=int)
|
| 43 |
+
|
| 44 |
+
args = parser.parse_args()
|
| 45 |
+
return args
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def filename_generator(data_dir: str):
|
| 49 |
+
data_dir = Path(data_dir)
|
| 50 |
+
for filename in data_dir.glob("**/*.wav"):
|
| 51 |
+
yield filename.as_posix()
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000, max_epoch: int = 20000):
|
| 55 |
+
data_dir = Path(data_dir)
|
| 56 |
+
for epoch_idx in range(max_epoch):
|
| 57 |
+
for filename in data_dir.glob("**/*.wav"):
|
| 58 |
+
signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
|
| 59 |
+
raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
|
| 60 |
+
|
| 61 |
+
if raw_duration < duration:
|
| 62 |
+
# print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
|
| 63 |
+
continue
|
| 64 |
+
if signal.ndim != 1:
|
| 65 |
+
raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
|
| 66 |
+
|
| 67 |
+
signal_length = len(signal)
|
| 68 |
+
win_size = int(duration * sample_rate)
|
| 69 |
+
for begin in range(0, signal_length - win_size, win_size):
|
| 70 |
+
if np.sum(signal[begin: begin+win_size]) == 0:
|
| 71 |
+
continue
|
| 72 |
+
row = {
|
| 73 |
+
"epoch_idx": epoch_idx,
|
| 74 |
+
"filename": filename.as_posix(),
|
| 75 |
+
"raw_duration": round(raw_duration, 4),
|
| 76 |
+
"offset": round(begin / sample_rate, 4),
|
| 77 |
+
"duration": round(duration, 4),
|
| 78 |
+
}
|
| 79 |
+
yield row
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def main():
|
| 83 |
+
args = get_args()
|
| 84 |
+
|
| 85 |
+
file_dir = Path(args.file_dir)
|
| 86 |
+
file_dir.mkdir(exist_ok=True)
|
| 87 |
+
|
| 88 |
+
noise_dir = Path(args.noise_dir)
|
| 89 |
+
speech_dir = Path(args.speech_dir)
|
| 90 |
+
|
| 91 |
+
noise_generator = target_second_signal_generator(
|
| 92 |
+
noise_dir.as_posix(),
|
| 93 |
+
duration=args.duration,
|
| 94 |
+
sample_rate=args.target_sample_rate,
|
| 95 |
+
max_epoch=100000,
|
| 96 |
+
)
|
| 97 |
+
speech_generator = target_second_signal_generator(
|
| 98 |
+
speech_dir.as_posix(),
|
| 99 |
+
duration=args.duration,
|
| 100 |
+
sample_rate=args.target_sample_rate,
|
| 101 |
+
max_epoch=1,
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
dataset = list()
|
| 105 |
+
|
| 106 |
+
count = 0
|
| 107 |
+
process_bar = tqdm(desc="build dataset jsonl")
|
| 108 |
+
with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
|
| 109 |
+
for noise, speech in zip(noise_generator, speech_generator):
|
| 110 |
+
if count >= args.max_count > 0:
|
| 111 |
+
break
|
| 112 |
+
|
| 113 |
+
noise_filename = noise["filename"]
|
| 114 |
+
noise_raw_duration = noise["raw_duration"]
|
| 115 |
+
noise_offset = noise["offset"]
|
| 116 |
+
noise_duration = noise["duration"]
|
| 117 |
+
|
| 118 |
+
speech_filename = speech["filename"]
|
| 119 |
+
speech_raw_duration = speech["raw_duration"]
|
| 120 |
+
speech_offset = speech["offset"]
|
| 121 |
+
speech_duration = speech["duration"]
|
| 122 |
+
|
| 123 |
+
random1 = random.random()
|
| 124 |
+
random2 = random.random()
|
| 125 |
+
|
| 126 |
+
row = {
|
| 127 |
+
"count": count,
|
| 128 |
+
|
| 129 |
+
"noise_filename": noise_filename,
|
| 130 |
+
"noise_raw_duration": noise_raw_duration,
|
| 131 |
+
"noise_offset": noise_offset,
|
| 132 |
+
"noise_duration": noise_duration,
|
| 133 |
+
|
| 134 |
+
"speech_filename": speech_filename,
|
| 135 |
+
"speech_raw_duration": speech_raw_duration,
|
| 136 |
+
"speech_offset": speech_offset,
|
| 137 |
+
"speech_duration": speech_duration,
|
| 138 |
+
|
| 139 |
+
"snr_db": random.uniform(args.min_snr_db, args.max_snr_db),
|
| 140 |
+
|
| 141 |
+
"random1": random1,
|
| 142 |
+
}
|
| 143 |
+
row = json.dumps(row, ensure_ascii=False)
|
| 144 |
+
if random2 < (1 / 300 / 1):
|
| 145 |
+
fvalid.write(f"{row}\n")
|
| 146 |
+
else:
|
| 147 |
+
ftrain.write(f"{row}\n")
|
| 148 |
+
|
| 149 |
+
count += 1
|
| 150 |
+
duration_seconds = count * args.duration
|
| 151 |
+
duration_hours = duration_seconds / 3600
|
| 152 |
+
|
| 153 |
+
process_bar.update(n=1)
|
| 154 |
+
process_bar.set_postfix({
|
| 155 |
+
# "duration_seconds": round(duration_seconds, 4),
|
| 156 |
+
"duration_hours": round(duration_hours, 4),
|
| 157 |
+
|
| 158 |
+
})
|
| 159 |
+
|
| 160 |
+
return
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
if __name__ == "__main__":
|
| 164 |
+
main()
|
examples/dfnet/step_2_train_model.py
ADDED
|
@@ -0,0 +1,461 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
https://github.com/Rikorose/DeepFilterNet
|
| 5 |
+
"""
|
| 6 |
+
import argparse
|
| 7 |
+
import json
|
| 8 |
+
import logging
|
| 9 |
+
from logging.handlers import TimedRotatingFileHandler
|
| 10 |
+
import os
|
| 11 |
+
import platform
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
import random
|
| 14 |
+
import sys
|
| 15 |
+
import shutil
|
| 16 |
+
from typing import List
|
| 17 |
+
|
| 18 |
+
from fontTools.varLib.plot import stops
|
| 19 |
+
|
| 20 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
| 21 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
| 22 |
+
|
| 23 |
+
import numpy as np
|
| 24 |
+
import torch
|
| 25 |
+
import torch.nn as nn
|
| 26 |
+
from torch.nn import functional as F
|
| 27 |
+
from torch.utils.data.dataloader import DataLoader
|
| 28 |
+
from tqdm import tqdm
|
| 29 |
+
|
| 30 |
+
from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
|
| 31 |
+
from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
|
| 32 |
+
from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
|
| 33 |
+
from toolbox.torchaudio.metrics.pesq import run_pesq_score
|
| 34 |
+
from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig
|
| 35 |
+
from toolbox.torchaudio.models.dfnet.modeling_dfnet import DfNet, DfNetPretrainedModel
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_args():
|
| 39 |
+
parser = argparse.ArgumentParser()
|
| 40 |
+
parser.add_argument("--train_dataset", default="train.jsonl", type=str)
|
| 41 |
+
parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
|
| 42 |
+
|
| 43 |
+
parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
|
| 44 |
+
parser.add_argument("--patience", default=10, type=int)
|
| 45 |
+
parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
|
| 46 |
+
|
| 47 |
+
parser.add_argument("--config_file", default="config.yaml", type=str)
|
| 48 |
+
|
| 49 |
+
args = parser.parse_args()
|
| 50 |
+
return args
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def logging_config(file_dir: str):
|
| 54 |
+
fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
|
| 55 |
+
|
| 56 |
+
logging.basicConfig(format=fmt,
|
| 57 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
| 58 |
+
level=logging.INFO)
|
| 59 |
+
file_handler = TimedRotatingFileHandler(
|
| 60 |
+
filename=os.path.join(file_dir, "main.log"),
|
| 61 |
+
encoding="utf-8",
|
| 62 |
+
when="D",
|
| 63 |
+
interval=1,
|
| 64 |
+
backupCount=7
|
| 65 |
+
)
|
| 66 |
+
file_handler.setLevel(logging.INFO)
|
| 67 |
+
file_handler.setFormatter(logging.Formatter(fmt))
|
| 68 |
+
logger = logging.getLogger(__name__)
|
| 69 |
+
logger.addHandler(file_handler)
|
| 70 |
+
|
| 71 |
+
return logger
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class CollateFunction(object):
    """DataLoader collator: stack per-sample clean/noisy waveforms into batches.

    Each sample dict must provide ``"speech_wave"`` (clean) and ``"mix_wave"``
    (noisy) tensors of identical shape. Raises ``AssertionError`` if the
    stacked batch contains NaN or Inf values.
    """

    def __init__(self):
        pass

    def __call__(self, batch: List[dict]):
        # Gather the per-sample waveforms, then stack once per tensor.
        clean_list = [sample["speech_wave"] for sample in batch]
        noisy_list = [sample["mix_wave"] for sample in batch]

        clean_audios = torch.stack(clean_list)
        noisy_audios = torch.stack(noisy_list)

        # Reject corrupted inputs before they reach the model.
        for label, tensor in (("clean_audios", clean_audios),
                              ("noisy_audios", noisy_audios)):
            if torch.any(torch.isnan(tensor)) or torch.any(torch.isinf(tensor)):
                raise AssertionError(f"nan or inf in {label}")
        return clean_audios, noisy_audios
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
collate_fn = CollateFunction()
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def main():
    """End-to-end training entry point for the DfNet denoiser.

    Loads the yaml model config, builds jsonl-backed train/valid datasets,
    then runs an epoch/step training loop with periodic evaluation,
    checkpoint saving, best-model tracking (by PESQ, higher is better)
    and patience-based early stopping.
    """
    args = get_args()

    config = DfNetConfig.from_pretrained(
        pretrained_model_name_or_path=args.config_file,
    )

    serialization_dir = Path(args.serialization_dir)
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config(serialization_dir)

    # Seed every RNG used downstream for reproducibility.
    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    logger.info(f"set seed: {config.seed}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info(f"GPU available count: {n_gpu}; device: {device}")

    # datasets: each jsonl row pairs a speech window with a noise window;
    # the dataset mixes them at a random SNR in [min_snr_db, max_snr_db].
    train_dataset = DenoiseJsonlDataset(
        jsonl_file=args.train_dataset,
        expected_sample_rate=config.sample_rate,
        max_wave_value=32768.0,
        min_snr_db=config.min_snr_db,
        max_snr_db=config.max_snr_db,
        # skip=225000,
    )
    valid_dataset = DenoiseJsonlDataset(
        jsonl_file=args.valid_dataset,
        expected_sample_rate=config.sample_rate,
        max_wave_value=32768.0,
        min_snr_db=config.min_snr_db,
        max_snr_db=config.max_snr_db,
    )
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=config.batch_size,
        # shuffle=True,
        sampler=None,
        # Multiple data-loading worker subprocesses work on Linux but not on Windows.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        prefetch_factor=None if platform.system() == "Windows" else 2,
    )
    valid_data_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=config.batch_size,
        # shuffle=True,
        sampler=None,
        # Multiple data-loading worker subprocesses work on Linux but not on Windows.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        prefetch_factor=None if platform.system() == "Windows" else 2,
    )

    # models
    logger.info(f"prepare models. config_file: {args.config_file}")
    # NOTE(review): DfNet is imported but DfNetPretrainedModel is instantiated —
    # confirm this is the intended class. Also `.to(device)` is applied twice
    # (here and on the next line); the second call is redundant but harmless.
    model = DfNetPretrainedModel(config).to(device)
    model.to(device)
    model.train()

    # optimizer
    logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric")
    optimizer = torch.optim.AdamW(model.parameters(), config.lr)

    # resume training: find the checkpoint with the highest step index.
    last_step_idx = -1
    last_epoch = -1
    for step_idx_str in serialization_dir.glob("steps-*"):
        step_idx_str = Path(step_idx_str)
        step_idx = step_idx_str.stem.split("-")[1]
        step_idx = int(step_idx)
        if step_idx > last_step_idx:
            last_step_idx = step_idx
            # last_epoch = 1

    if last_step_idx != -1:
        logger.info(f"resume from steps-{last_step_idx}.")
        model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"

        logger.info(f"load state dict for model.")
        with open(model_pt.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)

    # lr scheduler selected by config; kwargs come from the yaml file.
    if config.lr_scheduler == "CosineAnnealingLR":
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            last_epoch=last_epoch,
            # T_max=10 * config.eval_steps,
            # eta_min=0.01 * config.lr,
            **config.lr_scheduler_kwargs,
        )
    elif config.lr_scheduler == "MultiStepLR":
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            last_epoch=last_epoch,
            milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
        )
    else:
        raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")

    # loss functions; total loss is a fixed-weight combination (see below).
    neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
    mr_stft_loss_fn = MultiResolutionSTFTLoss(
        fft_size_list=[256, 512, 1024],
        win_size_list=[256, 512, 1024],
        hop_size_list=[128, 256, 512],
        factor_sc=1.5,
        factor_mag=1.0,
        reduction="mean"
    ).to(device)

    # training loop

    # state: running averages start at a sentinel value until first computed.
    average_pesq_score = 1000000000
    average_loss = 1000000000
    average_mr_stft_loss = 1000000000
    average_neg_si_snr_loss = 1000000000
    average_mask_loss = 1000000000
    average_lsnr_loss = 1000000000

    model_list = list()
    best_epoch_idx = None
    best_step_idx = None
    best_metric = None
    patience_count = 0

    step_idx = 0 if last_step_idx == -1 else last_step_idx

    logger.info("training")
    early_stop_flag = False
    for epoch_idx in range(max(0, last_epoch+1), config.max_epochs):
        if early_stop_flag:
            break

        # train
        model.train()

        total_pesq_score = 0.
        total_loss = 0.
        total_mr_stft_loss = 0.
        total_neg_si_snr_loss = 0.
        total_mask_loss = 0.
        total_lsnr_loss = 0.
        total_batches = 0.

        progress_bar_train = tqdm(
            initial=step_idx,
            desc="Training; epoch-{}".format(epoch_idx),
        )
        for train_batch in train_data_loader:
            clean_audios, noisy_audios = train_batch
            clean_audios: torch.Tensor = clean_audios.to(device)
            noisy_audios: torch.Tensor = noisy_audios.to(device)

            est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios)

            mr_stft_loss = mr_stft_loss_fn.forward(est_wav, clean_audios)
            neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios)
            mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
            lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)

            # Fixed loss weights; lsnr term is down-weighted to 0.3.
            loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + 0.3 * lsnr_loss
            if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                # Skip the batch instead of corrupting weights with a bad gradient.
                logger.info(f"find nan or inf in loss.")
                continue

            # PESQ (narrow-band) computed on every training batch as a metric.
            denoise_audios_list_r = list(est_wav.detach().cpu().numpy())
            clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
            pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
            optimizer.step()
            # Scheduler is stepped per batch, not per epoch.
            lr_scheduler.step()

            total_pesq_score += pesq_score
            total_loss += loss.item()
            total_mr_stft_loss += mr_stft_loss.item()
            total_neg_si_snr_loss += neg_si_snr_loss.item()
            total_mask_loss += mask_loss.item()
            total_lsnr_loss += lsnr_loss.item()
            total_batches += 1

            average_pesq_score = round(total_pesq_score / total_batches, 4)
            average_loss = round(total_loss / total_batches, 4)
            average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
            average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
            average_mask_loss = round(total_mask_loss / total_batches, 4)
            average_lsnr_loss = round(total_lsnr_loss / total_batches, 4)

            progress_bar_train.update(1)
            progress_bar_train.set_postfix({
                "lr": lr_scheduler.get_last_lr()[0],
                "pesq_score": average_pesq_score,
                "loss": average_loss,
                "mr_stft_loss": average_mr_stft_loss,
                "neg_si_snr_loss": average_neg_si_snr_loss,
                "mask_loss": average_mask_loss,
                "lsnr_loss": average_lsnr_loss,
            })

            # evaluation: every config.eval_steps optimizer steps.
            step_idx += 1
            if step_idx % config.eval_steps == 0:
                model.eval()
                with torch.no_grad():
                    torch.cuda.empty_cache()

                    # Reset accumulators for the validation pass.
                    total_pesq_score = 0.
                    total_loss = 0.
                    total_mr_stft_loss = 0.
                    total_neg_si_snr_loss = 0.
                    total_mask_loss = 0.
                    total_lsnr_loss = 0.
                    total_batches = 0.

                    progress_bar_train.close()
                    progress_bar_eval = tqdm(
                        desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
                    )
                    for eval_batch in valid_data_loader:
                        clean_audios, noisy_audios = eval_batch
                        clean_audios: torch.Tensor = clean_audios.to(device)
                        noisy_audios: torch.Tensor = noisy_audios.to(device)

                        est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios)

                        mr_stft_loss = mr_stft_loss_fn.forward(est_wav, clean_audios)
                        neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios)
                        mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
                        lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)

                        loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + 0.3 * lsnr_loss
                        if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                            logger.info(f"find nan or inf in loss.")
                            continue

                        denoise_audios_list_r = list(est_wav.detach().cpu().numpy())
                        clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
                        pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")

                        total_pesq_score += pesq_score
                        total_loss += loss.item()
                        total_mr_stft_loss += mr_stft_loss.item()
                        total_neg_si_snr_loss += neg_si_snr_loss.item()
                        total_mask_loss += mask_loss.item()
                        total_lsnr_loss += lsnr_loss.item()
                        total_batches += 1

                        average_pesq_score = round(total_pesq_score / total_batches, 4)
                        average_loss = round(total_loss / total_batches, 4)
                        average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
                        average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
                        average_mask_loss = round(total_mask_loss / total_batches, 4)
                        average_lsnr_loss = round(total_lsnr_loss / total_batches, 4)

                        progress_bar_eval.update(1)
                        progress_bar_eval.set_postfix({
                            "lr": lr_scheduler.get_last_lr()[0],
                            "pesq_score": average_pesq_score,
                            "loss": average_loss,
                            "mr_stft_loss": average_mr_stft_loss,
                            "neg_si_snr_loss": average_neg_si_snr_loss,
                            "mask_loss": average_mask_loss,
                            "lsnr_loss": average_lsnr_loss,
                        })

                    # Reset accumulators again so training stats restart cleanly
                    # (the eval averages are still held in the average_* locals).
                    total_pesq_score = 0.
                    total_loss = 0.
                    total_mr_stft_loss = 0.
                    total_neg_si_snr_loss = 0.
                    total_mask_loss = 0.
                    total_lsnr_loss = 0.
                    total_batches = 0.

                    progress_bar_eval.close()
                    progress_bar_train = tqdm(
                        initial=progress_bar_train.n,
                        postfix=progress_bar_train.postfix,
                        desc=progress_bar_train.desc,
                    )

                # save path
                save_dir = serialization_dir / "steps-{}".format(step_idx)
                save_dir.mkdir(parents=True, exist_ok=False)

                # save models
                model.save_pretrained(save_dir.as_posix())

                # NOTE(review): `>=` deletes once the list *reaches* the cap,
                # so at most num_serialized_models_to_keep - 1 checkpoints
                # remain on disk — confirm this is intended.
                model_list.append(save_dir)
                if len(model_list) >= args.num_serialized_models_to_keep:
                    model_to_delete: Path = model_list.pop(0)
                    shutil.rmtree(model_to_delete.as_posix())

                # save metric: best checkpoint is the one with highest eval PESQ.
                if best_metric is None:
                    best_epoch_idx = epoch_idx
                    best_step_idx = step_idx
                    best_metric = average_pesq_score
                elif average_pesq_score >= best_metric:
                    # great is better.
                    best_epoch_idx = epoch_idx
                    best_step_idx = step_idx
                    best_metric = average_pesq_score
                else:
                    pass

                metrics = {
                    "epoch_idx": epoch_idx,
                    "best_epoch_idx": best_epoch_idx,
                    "best_step_idx": best_step_idx,
                    "pesq_score": average_pesq_score,
                    "loss": average_loss,
                    "mr_stft_loss": average_mr_stft_loss,
                    "neg_si_snr_loss": average_neg_si_snr_loss,
                    "mask_loss": average_mask_loss,
                    "lsnr_loss": average_lsnr_loss,
                }
                metrics_filename = save_dir / "metrics_epoch.json"
                with open(metrics_filename, "w", encoding="utf-8") as f:
                    json.dump(metrics, f, indent=4, ensure_ascii=False)

                # save best: mirror the newly-best checkpoint into "best".
                best_dir = serialization_dir / "best"
                if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
                    if best_dir.exists():
                        shutil.rmtree(best_dir)
                    shutil.copytree(save_dir, best_dir)

                # early stop: patience counted in evaluation rounds.
                early_stop_flag = False
                if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
                    patience_count = 0
                else:
                    patience_count += 1
                    if patience_count >= args.patience:
                        early_stop_flag = True

                # early stop
                if early_stop_flag:
                    break
                model.train()

    return
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
# Script entry point.
if __name__ == "__main__":
    main()
|
examples/dfnet/yaml/config.yaml
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model_name: "dfnet"
|
| 2 |
+
|
| 3 |
+
# spec
|
| 4 |
+
sample_rate: 8000
|
| 5 |
+
nfft: 512
|
| 6 |
+
win_size: 200
|
| 7 |
+
hop_size: 80
|
| 8 |
+
|
| 9 |
+
spec_bins: 256
|
| 10 |
+
|
| 11 |
+
# model
|
| 12 |
+
conv_channels: 64
|
| 13 |
+
conv_kernel_size_input:
|
| 14 |
+
- 3
|
| 15 |
+
- 3
|
| 16 |
+
conv_kernel_size_inner:
|
| 17 |
+
- 1
|
| 18 |
+
- 3
|
| 19 |
+
conv_lookahead: 0
|
| 20 |
+
|
| 21 |
+
convt_kernel_size_inner:
|
| 22 |
+
- 1
|
| 23 |
+
- 3
|
| 24 |
+
|
| 25 |
+
embedding_hidden_size: 256
|
| 26 |
+
encoder_combine_op: "concat"
|
| 27 |
+
|
| 28 |
+
encoder_emb_skip_op: "none"
|
| 29 |
+
encoder_emb_linear_groups: 16
|
| 30 |
+
encoder_emb_hidden_size: 256
|
| 31 |
+
|
| 32 |
+
encoder_linear_groups: 32
|
| 33 |
+
|
| 34 |
+
decoder_emb_num_layers: 3
|
| 35 |
+
decoder_emb_skip_op: "none"
|
| 36 |
+
decoder_emb_linear_groups: 16
|
| 37 |
+
decoder_emb_hidden_size: 256
|
| 38 |
+
|
| 39 |
+
df_decoder_hidden_size: 256
|
| 40 |
+
df_num_layers: 2
|
| 41 |
+
df_order: 5
|
| 42 |
+
df_bins: 96
|
| 43 |
+
df_gru_skip: "grouped_linear"
|
| 44 |
+
df_decoder_linear_groups: 16
|
| 45 |
+
df_pathway_kernel_size_t: 5
|
| 46 |
+
df_lookahead: 2
|
| 47 |
+
|
| 48 |
+
# lsnr
|
| 49 |
+
n_frame: 3
|
| 50 |
+
lsnr_max: 30
|
| 51 |
+
lsnr_min: -15
|
| 52 |
+
norm_tau: 1.
|
| 53 |
+
|
| 54 |
+
# data
|
| 55 |
+
min_snr_db: -10
|
| 56 |
+
max_snr_db: 20
|
| 57 |
+
|
| 58 |
+
# train
|
| 59 |
+
lr: 0.001
|
| 60 |
+
lr_scheduler: "CosineAnnealingLR"
|
| 61 |
+
lr_scheduler_kwargs:
|
| 62 |
+
T_max: 250000
|
| 63 |
+
eta_min: 0.0001
|
| 64 |
+
|
| 65 |
+
max_epochs: 100
|
| 66 |
+
clip_grad_norm: 10.0
|
| 67 |
+
seed: 1234
|
| 68 |
+
|
| 69 |
+
num_workers: 8
|
| 70 |
+
batch_size: 64
|
| 71 |
+
eval_steps: 10000
|
| 72 |
+
|
| 73 |
+
# runtime
|
| 74 |
+
use_post_filter: true
|
examples/dfnet2/run.sh
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash

# Pipeline driver for the dfnet2 example: prepare data, train, evaluate,
# collect artifacts, and clean up, controlled by --stage/--stop_stage.

: <<'END'

sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir --final_model_name dfnet-nx-speech \
--noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \
--speech_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/speech"

sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name dfnet2-nx-dns3 \
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"

sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name dfnet2-nx2 \
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/nx-noise" \
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2"

sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name dfnet2-nx2-dns3 --final_model_name dfnet2-nx2-dns3 \
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/" \
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/"


END


# params
system_version="windows";
verbose=true;
stage=0 # start from 0 if you need to start from data preparation
stop_stage=9

work_dir="$(pwd)"
file_folder_name=file_folder_name
final_model_name=final_model_name
config_file="yaml/config.yaml"
limit=10

noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train

max_count=-1

nohup_name=nohup.out

# model params
batch_size=64
max_epochs=200
save_top_k=10
patience=5


# parse options: every --foo-bar value pair sets the shell variable foo_bar.
while true; do
  [ -z "${1:-}" ] && break;  # break if there are no arguments
  case "$1" in
    --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
      eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
      # BUGFIX: capture the variable's current value with command substitution.
      # The previous form, old_value="(eval echo \$$name)", stored that literal
      # string, so was_bool was always false and the boolean check below could
      # never trigger.
      old_value="$(eval echo \$$name)";
      if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
        was_bool=true;
      else
        was_bool=false;
      fi

      # Set the variable to the right value-- the escaped quotes make it work if
      # the option had spaces, like --cmd "queue.pl -sync y"
      eval "${name}=\"$2\"";

      # Check that Boolean-valued arguments are really Boolean.
      if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
        echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
        exit 1;
      fi
      shift 2;
      ;;

    *) break;
  esac
done

# Derived paths.
file_dir="${work_dir}/${file_folder_name}"
final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
evaluation_audio_dir="${file_dir}/evaluation_audio"

train_dataset="${file_dir}/train.jsonl"
valid_dataset="${file_dir}/valid.jsonl"

$verbose && echo "system_version: ${system_version}"
$verbose && echo "file_folder_name: ${file_folder_name}"

# Select the python interpreter for the current host.
if [ $system_version == "windows" ]; then
  alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
  #source /data/local/bin/nx_denoise/bin/activate
  alias python3='/data/local/bin/nx_denoise/bin/python3'
fi


if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  $verbose && echo "stage 1: prepare data"
  cd "${work_dir}" || exit 1
  python3 step_1_prepare_data.py \
  --file_dir "${file_dir}" \
  --noise_dir "${noise_dir}" \
  --speech_dir "${speech_dir}" \
  --train_dataset "${train_dataset}" \
  --valid_dataset "${valid_dataset}" \
  --max_count "${max_count}" \

fi


if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
  $verbose && echo "stage 2: train model"
  cd "${work_dir}" || exit 1
  python3 step_2_train_model.py \
  --train_dataset "${train_dataset}" \
  --valid_dataset "${valid_dataset}" \
  --serialization_dir "${file_dir}" \
  --config_file "${config_file}" \

fi


if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
  $verbose && echo "stage 3: test model"
  cd "${work_dir}" || exit 1
  python3 step_3_evaluation.py \
  --valid_dataset "${valid_dataset}" \
  --model_dir "${file_dir}/best" \
  --evaluation_audio_dir "${evaluation_audio_dir}" \
  --limit "${limit}" \

fi


if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
  $verbose && echo "stage 4: collect files"
  cd "${work_dir}" || exit 1

  mkdir -p ${final_model_dir}

  cp "${file_dir}/best"/* "${final_model_dir}"
  cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"

  cd "${final_model_dir}/.." || exit 1;

  # Keep one backup of the previous archive before overwriting.
  if [ -e "${final_model_name}.zip" ]; then
    rm -rf "${final_model_name}_backup.zip"
    mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
  fi

  zip -r "${final_model_name}.zip" "${final_model_name}"
  rm -rf "${final_model_name}"

fi


if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
  $verbose && echo "stage 5: clear file_dir"
  cd "${work_dir}" || exit 1

  rm -rf "${file_dir}";

fi
|
examples/dfnet2/step_1_prepare_data.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
import random
|
| 8 |
+
import sys
|
| 9 |
+
|
| 10 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
| 11 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
| 12 |
+
|
| 13 |
+
import librosa
|
| 14 |
+
import numpy as np
|
| 15 |
+
from tqdm import tqdm
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_args(argv=None):
    """Parse command-line arguments for the data-preparation script.

    :param argv: optional list of argument strings. Defaults to ``None``,
        in which case argparse reads ``sys.argv[1:]`` — identical to the
        previous behavior. Passing an explicit list makes the function
        testable without touching the process arguments.
    :return: the parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--file_dir", default="./", type=str)

    # Source directories scanned recursively for *.wav files.
    parser.add_argument(
        "--noise_dir",
        default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
        type=str
    )
    parser.add_argument(
        "--speech_dir",
        default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
        type=str
    )

    # Output jsonl manifests.
    parser.add_argument("--train_dataset", default="train.jsonl", type=str)
    parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)

    # Window length (seconds) and SNR range written into each row.
    parser.add_argument("--duration", default=2.0, type=float)
    parser.add_argument("--min_snr_db", default=-10, type=float)
    parser.add_argument("--max_snr_db", default=20, type=float)

    parser.add_argument("--target_sample_rate", default=8000, type=int)

    # -1 means no limit on the number of rows produced.
    parser.add_argument("--max_count", default=-1, type=int)

    args = parser.parse_args(argv)
    return args
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def filename_generator(data_dir: str):
    """Yield POSIX-style paths of every ``*.wav`` file under ``data_dir``, recursively."""
    root = Path(data_dir)
    for wav_path in root.rglob("*.wav"):
        yield wav_path.as_posix()
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def target_second_signal_generator(data_dir: str, duration: float = 2, sample_rate: int = 8000, max_epoch: int = 20000):
    """Yield fixed-length, non-silent window descriptors for wav files under ``data_dir``.

    Each yielded row describes one ``duration``-second window (filename, offset,
    duration in seconds); the audio samples themselves are not returned.
    Files shorter than ``duration`` are skipped.  The directory tree is
    re-scanned up to ``max_epoch`` times, so the generator can supply many
    passes over the same data.

    :param data_dir: root directory searched recursively for ``*.wav`` files.
    :param duration: window length in seconds (callers pass floats, e.g. 2.0).
    :param sample_rate: rate the audio is loaded/resampled at by librosa.
    :param max_epoch: number of passes over the directory.
    """
    data_dir = Path(data_dir)
    for epoch_idx in range(max_epoch):
        for filename in data_dir.glob("**/*.wav"):
            signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
            raw_duration = librosa.get_duration(y=signal, sr=sample_rate)

            if raw_duration < duration:
                # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
                continue
            if signal.ndim != 1:
                raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")

            signal_length = len(signal)
            win_size = int(duration * sample_rate)
            for begin in range(0, signal_length - win_size, win_size):
                # NOTE(review): this skips windows whose sample *sum* is zero —
                # intended as a silence check, but a non-silent window summing
                # to exactly zero would also be skipped. Confirm intent.
                if np.sum(signal[begin: begin+win_size]) == 0:
                    continue
                row = {
                    "epoch_idx": epoch_idx,
                    "filename": filename.as_posix(),
                    "raw_duration": round(raw_duration, 4),
                    "offset": round(begin / sample_rate, 4),
                    "duration": round(duration, 4),
                }
                yield row
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def main():
    """Build train/valid jsonl manifests by pairing noise windows with speech windows.

    Each output row records file paths, offsets, durations, and a random SNR;
    the actual mixing is done later by the training dataset. Rows are split
    ~1:300 between validation and training.
    """
    args = get_args()

    file_dir = Path(args.file_dir)
    file_dir.mkdir(exist_ok=True)

    noise_dir = Path(args.noise_dir)
    speech_dir = Path(args.speech_dir)

    # Noise windows are cycled many times (max_epoch=100000); speech is
    # consumed in a single pass, so the speech generator bounds dataset size.
    noise_generator = target_second_signal_generator(
        noise_dir.as_posix(),
        duration=args.duration,
        sample_rate=args.target_sample_rate,
        max_epoch=100000,
    )
    speech_generator = target_second_signal_generator(
        speech_dir.as_posix(),
        duration=args.duration,
        sample_rate=args.target_sample_rate,
        max_epoch=1,
    )

    # NOTE(review): `dataset` is never appended to — rows are streamed
    # directly to disk below. Looks like dead code; confirm before removing.
    dataset = list()

    count = 0
    process_bar = tqdm(desc="build dataset jsonl")
    with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
        for noise, speech in zip(noise_generator, speech_generator):
            # Chained comparison: stop only when max_count is positive and reached.
            if count >= args.max_count > 0:
                break

            noise_filename = noise["filename"]
            noise_raw_duration = noise["raw_duration"]
            noise_offset = noise["offset"]
            noise_duration = noise["duration"]

            speech_filename = speech["filename"]
            speech_raw_duration = speech["raw_duration"]
            speech_offset = speech["offset"]
            speech_duration = speech["duration"]

            random1 = random.random()
            random2 = random.random()

            row = {
                "count": count,

                "noise_filename": noise_filename,
                "noise_raw_duration": noise_raw_duration,
                "noise_offset": noise_offset,
                "noise_duration": noise_duration,

                "speech_filename": speech_filename,
                "speech_raw_duration": speech_raw_duration,
                "speech_offset": speech_offset,
                "speech_duration": speech_duration,

                # SNR at which the training dataset will mix this pair.
                "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),

                "random1": random1,
            }
            row = json.dumps(row, ensure_ascii=False)
            # Roughly 1 in 300 rows goes to the validation split.
            if random2 < (1 / 300 / 1):
                fvalid.write(f"{row}\n")
            else:
                ftrain.write(f"{row}\n")

            count += 1
            duration_seconds = count * args.duration
            duration_hours = duration_seconds / 3600

            process_bar.update(n=1)
            process_bar.set_postfix({
                # "duration_seconds": round(duration_seconds, 4),
                "duration_hours": round(duration_hours, 4),

            })

    return
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
# Script entry point.
if __name__ == "__main__":
    main()
|
examples/dfnet2/step_2_train_model.py
ADDED
|
@@ -0,0 +1,469 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
https://github.com/Rikorose/DeepFilterNet
|
| 5 |
+
"""
|
| 6 |
+
import argparse
|
| 7 |
+
import json
|
| 8 |
+
import logging
|
| 9 |
+
from logging.handlers import TimedRotatingFileHandler
|
| 10 |
+
import os
|
| 11 |
+
import platform
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
import random
|
| 14 |
+
import sys
|
| 15 |
+
import shutil
|
| 16 |
+
from typing import List
|
| 17 |
+
|
| 18 |
+
from fontTools.varLib.plot import stops
|
| 19 |
+
|
| 20 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
| 21 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
| 22 |
+
|
| 23 |
+
import numpy as np
|
| 24 |
+
import torch
|
| 25 |
+
import torch.nn as nn
|
| 26 |
+
from torch.nn import functional as F
|
| 27 |
+
from torch.utils.data.dataloader import DataLoader
|
| 28 |
+
from tqdm import tqdm
|
| 29 |
+
|
| 30 |
+
from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
|
| 31 |
+
from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
|
| 32 |
+
from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
|
| 33 |
+
from toolbox.torchaudio.metrics.pesq import run_pesq_score
|
| 34 |
+
from toolbox.torchaudio.models.dfnet2.configuration_dfnet2 import DfNet2Config
|
| 35 |
+
from toolbox.torchaudio.models.dfnet2.modeling_dfnet2 import DfNet2, DfNet2PretrainedModel
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_args():
    """Parse the command-line options for the DfNet2 training script."""
    parser = argparse.ArgumentParser()

    # Dataset manifests produced by step_1_prepare_data.py.
    parser.add_argument("--train_dataset", default="train.jsonl", type=str)
    parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)

    # Checkpoint retention and early-stopping behaviour.
    parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
    parser.add_argument("--patience", default=30, type=int)
    parser.add_argument("--serialization_dir", default="serialization_dir", type=str)

    # Model / training hyper-parameters (yaml).
    parser.add_argument("--config_file", default="config.yaml", type=str)

    return parser.parse_args()
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def logging_config(file_dir: str):
    """Configure console logging and a daily-rotating file log.

    Writes ``<file_dir>/main.log`` (rotated daily, 7 backups kept) and returns
    the module logger.  NOTE: each call adds one more file handler to the same
    logger object.
    """
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    logging.basicConfig(format=fmt, datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO)

    rotating_handler = TimedRotatingFileHandler(
        filename=os.path.join(file_dir, "main.log"),
        encoding="utf-8",
        when="D",
        interval=1,
        backupCount=7,
    )
    rotating_handler.setLevel(logging.INFO)
    rotating_handler.setFormatter(logging.Formatter(fmt))

    logger = logging.getLogger(__name__)
    logger.addHandler(rotating_handler)
    return logger
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class CollateFunction(object):
    """Stack per-sample clean/noisy waveforms into batch tensors for the DataLoader."""

    def __init__(self):
        pass

    def __call__(self, batch: List[dict]):
        # Each sample dict carries "speech_wave" (clean) and "mix_wave" (noisy).
        clean_audios = torch.stack([sample["speech_wave"] for sample in batch])
        noisy_audios = torch.stack([sample["mix_wave"] for sample in batch])

        # Refuse to hand non-finite values to the training loop.
        if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
            raise AssertionError("nan or inf in clean_audios")
        if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
            raise AssertionError("nan or inf in noisy_audios")
        return clean_audios, noisy_audios
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
# Module-level singleton passed as collate_fn to both DataLoaders.
collate_fn = CollateFunction()
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def main():
    """Train the DfNet2 denoiser.

    Streams (clean, noisy) waveform batches from jsonl manifests, optimizes a
    weighted combination of multi-resolution STFT, negative SI-SNR, mask and
    LSNR losses, evaluates every ``config.eval_steps`` optimizer steps, keeps a
    rolling window of checkpoints, and early-stops when the validation PESQ
    score stops improving for ``args.patience`` evaluations.
    """
    args = get_args()

    config = DfNet2Config.from_pretrained(
        pretrained_model_name_or_path=args.config_file,
    )

    serialization_dir = Path(args.serialization_dir)
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config(serialization_dir)

    # Seed every RNG source for reproducibility.
    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    logger.info(f"set seed: {config.seed}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info(f"GPU available count: {n_gpu}; device: {device}")

    # datasets
    train_dataset = DenoiseJsonlDataset(
        jsonl_file=args.train_dataset,
        expected_sample_rate=config.sample_rate,
        max_wave_value=32768.0,
        min_snr_db=config.min_snr_db,
        max_snr_db=config.max_snr_db,
        # skip=225000,
    )
    valid_dataset = DenoiseJsonlDataset(
        jsonl_file=args.valid_dataset,
        expected_sample_rate=config.sample_rate,
        max_wave_value=32768.0,
        min_snr_db=config.min_snr_db,
        max_snr_db=config.max_snr_db,
    )
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=config.batch_size,
        # shuffle=True,
        sampler=None,
        # Worker subprocesses can be used to load data on Linux, but not on Windows.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        prefetch_factor=None if platform.system() == "Windows" else 2,
    )
    valid_data_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=config.batch_size,
        # shuffle=True,
        sampler=None,
        # Worker subprocesses can be used to load data on Linux, but not on Windows.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        prefetch_factor=None if platform.system() == "Windows" else 2,
    )

    # models
    logger.info(f"prepare models. config_file: {args.config_file}")
    model = DfNet2PretrainedModel(config).to(device)
    model.to(device)
    model.train()

    # optimizer
    logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric")
    optimizer = torch.optim.AdamW(model.parameters(), config.lr)

    # resume training: locate the highest "steps-N" checkpoint directory, if any.
    last_step_idx = -1
    last_epoch = -1
    for step_idx_str in serialization_dir.glob("steps-*"):
        step_idx_str = Path(step_idx_str)
        step_idx = step_idx_str.stem.split("-")[1]
        step_idx = int(step_idx)
        if step_idx > last_step_idx:
            last_step_idx = step_idx
            # last_epoch = 1

    if last_step_idx != -1:
        logger.info(f"resume from steps-{last_step_idx}.")
        model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"

        logger.info(f"load state dict for model.")
        with open(model_pt.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)

    if config.lr_scheduler == "CosineAnnealingLR":
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            last_epoch=last_epoch,
            # T_max=10 * config.eval_steps,
            # eta_min=0.01 * config.lr,
            **config.lr_scheduler_kwargs,
        )
    elif config.lr_scheduler == "MultiStepLR":
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            last_epoch=last_epoch,
            milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
        )
    else:
        raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")

    neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
    mr_stft_loss_fn = MultiResolutionSTFTLoss(
        fft_size_list=[256, 512, 1024],
        win_size_list=[256, 512, 1024],
        hop_size_list=[128, 256, 512],
        factor_sc=1.5,
        factor_mag=1.0,
        reduction="mean"
    ).to(device)

    # training loop

    # state: running averages shown on the progress bar and written to metrics.
    # Initialized to a large sentinel so the first evaluation always overwrites them.
    average_pesq_score = 1000000000
    average_loss = 1000000000
    average_mr_stft_loss = 1000000000
    average_neg_si_snr_loss = 1000000000
    average_mask_loss = 1000000000
    average_lsnr_loss = 1000000000

    model_list = list()
    best_epoch_idx = None
    best_step_idx = None
    best_metric = None
    patience_count = 0

    step_idx = 0 if last_step_idx == -1 else last_step_idx

    logger.info("training")
    early_stop_flag = False
    for epoch_idx in range(max(0, last_epoch+1), config.max_epochs):
        if early_stop_flag:
            break

        # train
        model.train()

        total_pesq_score = 0.
        total_loss = 0.
        total_mr_stft_loss = 0.
        total_neg_si_snr_loss = 0.
        total_mask_loss = 0.
        total_lsnr_loss = 0.
        total_batches = 0.

        progress_bar_train = tqdm(
            initial=step_idx,
            desc="Training; epoch-{}".format(epoch_idx),
        )
        for train_batch in train_data_loader:
            clean_audios, noisy_audios = train_batch
            clean_audios: torch.Tensor = clean_audios.to(device)
            noisy_audios: torch.Tensor = noisy_audios.to(device)

            est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios)
            # est_wav shape: [b, 1, n_samples]
            est_wav = torch.squeeze(est_wav, dim=1)
            # est_wav shape: [b, n_samples]

            mr_stft_loss = mr_stft_loss_fn.forward(est_wav, clean_audios)
            neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios)
            mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
            lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)

            # Weighted multi-task loss; the LSNR head is heavily down-weighted.
            loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + 0.01 * lsnr_loss
            if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                logger.info(f"find nan or inf in loss. continue.")
                continue

            denoise_audios_list_r = list(est_wav.detach().cpu().numpy())
            clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
            pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
            optimizer.step()
            # Scheduler is stepped per optimizer step (not per epoch).
            lr_scheduler.step()

            total_pesq_score += pesq_score
            total_loss += loss.item()
            total_mr_stft_loss += mr_stft_loss.item()
            total_neg_si_snr_loss += neg_si_snr_loss.item()
            total_mask_loss += mask_loss.item()
            total_lsnr_loss += lsnr_loss.item()
            total_batches += 1

            average_pesq_score = round(total_pesq_score / total_batches, 4)
            average_loss = round(total_loss / total_batches, 4)
            average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
            average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
            average_mask_loss = round(total_mask_loss / total_batches, 4)
            average_lsnr_loss = round(total_lsnr_loss / total_batches, 4)

            progress_bar_train.update(1)
            progress_bar_train.set_postfix({
                "lr": lr_scheduler.get_last_lr()[0],
                "pesq_score": average_pesq_score,
                "loss": average_loss,
                "mr_stft_loss": average_mr_stft_loss,
                "neg_si_snr_loss": average_neg_si_snr_loss,
                "mask_loss": average_mask_loss,
                "lsnr_loss": average_lsnr_loss,
            })

            # evaluation
            step_idx += 1
            if step_idx % config.eval_steps == 0:
                with torch.no_grad():
                    torch.cuda.empty_cache()

                    model.eval()

                    # Reset the running totals: the averages below are eval-only.
                    total_pesq_score = 0.
                    total_loss = 0.
                    total_mr_stft_loss = 0.
                    total_neg_si_snr_loss = 0.
                    total_mask_loss = 0.
                    total_lsnr_loss = 0.
                    total_batches = 0.

                    progress_bar_train.close()
                    progress_bar_eval = tqdm(
                        desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
                    )
                    for eval_batch in valid_data_loader:
                        clean_audios, noisy_audios = eval_batch
                        clean_audios: torch.Tensor = clean_audios.to(device)
                        noisy_audios: torch.Tensor = noisy_audios.to(device)

                        est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios)
                        # est_wav shape: [b, 1, n_samples]
                        est_wav = torch.squeeze(est_wav, dim=1)
                        # est_wav shape: [b, n_samples]

                        mr_stft_loss = mr_stft_loss_fn.forward(est_wav, clean_audios)
                        neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios)
                        mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
                        lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)

                        loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + 0.01 * lsnr_loss
                        if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                            logger.info(f"find nan or inf in loss. continue.")
                            continue

                        denoise_audios_list_r = list(est_wav.detach().cpu().numpy())
                        clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
                        pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")

                        total_pesq_score += pesq_score
                        total_loss += loss.item()
                        total_mr_stft_loss += mr_stft_loss.item()
                        total_neg_si_snr_loss += neg_si_snr_loss.item()
                        total_mask_loss += mask_loss.item()
                        total_lsnr_loss += lsnr_loss.item()
                        total_batches += 1

                        average_pesq_score = round(total_pesq_score / total_batches, 4)
                        average_loss = round(total_loss / total_batches, 4)
                        average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
                        average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
                        average_mask_loss = round(total_mask_loss / total_batches, 4)
                        average_lsnr_loss = round(total_lsnr_loss / total_batches, 4)

                        progress_bar_eval.update(1)
                        progress_bar_eval.set_postfix({
                            "lr": lr_scheduler.get_last_lr()[0],
                            "pesq_score": average_pesq_score,
                            "loss": average_loss,
                            "mr_stft_loss": average_mr_stft_loss,
                            "neg_si_snr_loss": average_neg_si_snr_loss,
                            "mask_loss": average_mask_loss,
                            "lsnr_loss": average_lsnr_loss,
                        })

                    model.train()

                    # Reset totals again so the next training window starts fresh.
                    total_pesq_score = 0.
                    total_loss = 0.
                    total_mr_stft_loss = 0.
                    total_neg_si_snr_loss = 0.
                    total_mask_loss = 0.
                    total_lsnr_loss = 0.
                    total_batches = 0.

                    progress_bar_eval.close()
                    # Recreate the training bar, preserving position and postfix.
                    progress_bar_train = tqdm(
                        initial=progress_bar_train.n,
                        postfix=progress_bar_train.postfix,
                        desc=progress_bar_train.desc,
                    )

                    # save path
                    save_dir = serialization_dir / "steps-{}".format(step_idx)
                    save_dir.mkdir(parents=True, exist_ok=False)

                    # save models
                    model.save_pretrained(save_dir.as_posix())

                    # Rolling window of recent checkpoints; oldest is removed.
                    model_list.append(save_dir)
                    if len(model_list) >= args.num_serialized_models_to_keep:
                        model_to_delete: Path = model_list.pop(0)
                        shutil.rmtree(model_to_delete.as_posix())

                    # save metric
                    if best_metric is None:
                        best_epoch_idx = epoch_idx
                        best_step_idx = step_idx
                        best_metric = average_pesq_score
                    elif average_pesq_score >= best_metric:
                        # great is better.
                        best_epoch_idx = epoch_idx
                        best_step_idx = step_idx
                        best_metric = average_pesq_score
                    else:
                        pass

                    metrics = {
                        "epoch_idx": epoch_idx,
                        "best_epoch_idx": best_epoch_idx,
                        "best_step_idx": best_step_idx,
                        "pesq_score": average_pesq_score,
                        "loss": average_loss,
                        "mr_stft_loss": average_mr_stft_loss,
                        "neg_si_snr_loss": average_neg_si_snr_loss,
                        "mask_loss": average_mask_loss,
                        "lsnr_loss": average_lsnr_loss,
                    }
                    metrics_filename = save_dir / "metrics_epoch.json"
                    with open(metrics_filename, "w", encoding="utf-8") as f:
                        json.dump(metrics, f, indent=4, ensure_ascii=False)

                    # save best
                    best_dir = serialization_dir / "best"
                    if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
                        if best_dir.exists():
                            shutil.rmtree(best_dir)
                        shutil.copytree(save_dir, best_dir)

                    # early stop
                    early_stop_flag = False
                    if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
                        patience_count = 0
                    else:
                        patience_count += 1
                        if patience_count >= args.patience:
                            early_stop_flag = True

                    # early stop: leaves the batch loop; the epoch loop exits via
                    # the early_stop_flag check at its top.
                    if early_stop_flag:
                        break

    return
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
# Script entry point: run DfNet2 training.
if __name__ == "__main__":
    main()
|
examples/dfnet2/yaml/config.yaml
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model_name: "dfnet2"

# spec (STFT framing for 8 kHz audio)
sample_rate: 8000
nfft: 512
win_size: 200
hop_size: 80

spec_bins: 256
erb_bins: 32
min_freq_bins_for_erb: 2
use_ema_norm: true

# model
conv_channels: 64
conv_kernel_size_input:
- 3
- 3
conv_kernel_size_inner:
- 1
- 3
convt_kernel_size_inner:
- 1
- 3

embedding_hidden_size: 256
encoder_combine_op: "concat"

encoder_emb_skip_op: "none"
encoder_emb_linear_groups: 16
encoder_emb_hidden_size: 256

encoder_linear_groups: 32

decoder_emb_num_layers: 3
decoder_emb_skip_op: "none"
decoder_emb_linear_groups: 16
decoder_emb_hidden_size: 256

# deep-filtering decoder head
df_decoder_hidden_size: 256
df_num_layers: 2
df_order: 5
df_bins: 96
df_gru_skip: "grouped_linear"
df_decoder_linear_groups: 16
df_pathway_kernel_size_t: 5
df_lookahead: 2

# lsnr (local SNR estimation head)
n_frame: 3
lsnr_max: 30
lsnr_min: -15
norm_tau: 1.

# data (SNR range used when mixing noise into speech)
min_snr_db: -5
max_snr_db: 40

# train
lr: 0.001
lr_scheduler: "CosineAnnealingLR"
lr_scheduler_kwargs:
  T_max: 250000
  eta_min: 0.0001

max_epochs: 100
clip_grad_norm: 10.0
seed: 1234

num_workers: 8
batch_size: 96
eval_steps: 10000

# runtime
use_post_filter: true
|
examples/dtln/run.sh
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash

# DTLN training pipeline driver: stage 1 prepares data, stage 2 trains,
# stage 3 evaluates, stage 4 packages the trained model, stage 5 cleans up.
# The heredoc below records example invocations and is never executed.
: <<'END'

sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir-256 --final_model_name dtln-256-nx-dns3 \
--config_file "yaml/config-256.yaml" \
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"


sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir-512 --final_model_name dtln-512-nx-dns3 \
--config_file "yaml/config-512.yaml" \
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"


sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name dtnl-1024-nx2 --final_model_name dtln-1024-nx2 \
--config_file "yaml/config-1024.yaml" \
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/nx-noise" \
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2"


sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name dtln-256-nx2-dns3 --final_model_name dtln-256-nx2-dns3 \
--config_file "yaml/config-256.yaml" \
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech"


END


# params
system_version="windows";
verbose=true;
stage=0 # start from 0 if you need to start from data preparation
stop_stage=9

work_dir="$(pwd)"
file_folder_name=file_folder_name
final_model_name=final_model_name
config_file="yaml/config.yaml"
limit=10

noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train

max_count=-1

nohup_name=nohup.out

# model params
batch_size=64
max_epochs=200
save_top_k=10
patience=5


# parse options: kaldi-style --name value pairs become shell variables.
while true; do
  [ -z "${1:-}" ] && break;  # break if there are no arguments
  case "$1" in
    --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
      eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
      # NOTE(review): upstream parse-options helpers read the current value with
      # old_value="$(eval echo \$$name)"; the form below assigns a literal
      # string instead, so the boolean check never triggers — verify against the
      # original script, this may be transcription damage.
      old_value="(eval echo \\$$name)";
      if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
        was_bool=true;
      else
        was_bool=false;
      fi

      # Set the variable to the right value-- the escaped quotes make it work if
      # the option had spaces, like --cmd "queue.pl -sync y"
      eval "${name}=\"$2\"";

      # Check that Boolean-valued arguments are really Boolean.
      if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
        echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
        exit 1;
      fi
      shift 2;
      ;;

    *) break;
  esac
done

# Derived paths.
file_dir="${work_dir}/${file_folder_name}"
final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
evaluation_audio_dir="${file_dir}/evaluation_audio"

train_dataset="${file_dir}/train.jsonl"
valid_dataset="${file_dir}/valid.jsonl"

$verbose && echo "system_version: ${system_version}"
$verbose && echo "file_folder_name: ${file_folder_name}"

# Pick the python interpreter for the host platform.
if [ $system_version == "windows" ]; then
  alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
  #source /data/local/bin/nx_denoise/bin/activate
  alias python3='/data/local/bin/nx_denoise/bin/python3'
fi


if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  $verbose && echo "stage 1: prepare data"
  cd "${work_dir}" || exit 1
  python3 step_1_prepare_data.py \
    --file_dir "${file_dir}" \
    --noise_dir "${noise_dir}" \
    --speech_dir "${speech_dir}" \
    --train_dataset "${train_dataset}" \
    --valid_dataset "${valid_dataset}" \
    --max_count "${max_count}" \

fi


if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
  $verbose && echo "stage 2: train model"
  cd "${work_dir}" || exit 1
  python3 step_2_train_model.py \
    --train_dataset "${train_dataset}" \
    --valid_dataset "${valid_dataset}" \
    --serialization_dir "${file_dir}" \
    --config_file "${config_file}" \

fi


if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
  $verbose && echo "stage 3: test model"
  cd "${work_dir}" || exit 1
  python3 step_3_evaluation.py \
    --valid_dataset "${valid_dataset}" \
    --model_dir "${file_dir}/best" \
    --evaluation_audio_dir "${evaluation_audio_dir}" \
    --limit "${limit}" \

fi


if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
  $verbose && echo "stage 4: collect files"
  cd "${work_dir}" || exit 1

  mkdir -p ${final_model_dir}

  cp "${file_dir}/best"/* "${final_model_dir}"
  cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"

  cd "${final_model_dir}/.." || exit 1;

  # Keep one backup of the previous packaged model.
  if [ -e "${final_model_name}.zip" ]; then
    rm -rf "${final_model_name}_backup.zip"
    mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
  fi

  zip -r "${final_model_name}.zip" "${final_model_name}"
  rm -rf "${final_model_name}"

fi


if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
  $verbose && echo "stage 5: clear file_dir"
  cd "${work_dir}" || exit 1

  rm -rf "${file_dir}";

fi
|
examples/dtln/step_1_prepare_data.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
import random
|
| 8 |
+
import sys
|
| 9 |
+
|
| 10 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
| 11 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
| 12 |
+
|
| 13 |
+
import librosa
|
| 14 |
+
import numpy as np
|
| 15 |
+
from tqdm import tqdm
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_args(argv=None):
    """Parse command-line options for dataset preparation.

    Args:
        argv: optional list of argument strings; defaults to ``sys.argv[1:]``
            (passing a list makes the function testable without patching
            ``sys.argv`` — backward compatible with the old zero-arg call).

    Returns:
        argparse.Namespace with the fields below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--file_dir", default="./", type=str)

    parser.add_argument(
        "--noise_dir",
        default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
        type=str
    )
    parser.add_argument(
        "--speech_dir",
        default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
        type=str
    )

    parser.add_argument("--train_dataset", default="train.jsonl", type=str)
    parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)

    # Window length (seconds) of each excerpt referenced by a manifest row.
    parser.add_argument("--duration", default=2.0, type=float)
    parser.add_argument("--min_snr_db", default=-10, type=float)
    parser.add_argument("--max_snr_db", default=20, type=float)

    parser.add_argument("--target_sample_rate", default=8000, type=int)

    # -1 means unlimited; a positive value caps the number of emitted rows.
    parser.add_argument("--max_count", default=-1, type=int)

    args = parser.parse_args(argv)
    return args
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def filename_generator(data_dir: str):
    """Yield the POSIX-style path of every ``*.wav`` file found recursively under ``data_dir``."""
    root = Path(data_dir)
    for wav_file in root.rglob("*.wav"):
        yield wav_file.as_posix()
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def target_second_signal_generator(data_dir: str, duration: float = 2, sample_rate: int = 8000, max_epoch: int = 20000):
    """Yield metadata rows for non-silent, fixed-length windows of wav files.

    Recursively scans ``data_dir`` for ``*.wav`` files and repeats the whole
    scan ``max_epoch`` times. Files shorter than ``duration`` seconds are
    skipped. For each remaining file, one row is yielded per non-overlapping
    window of ``duration`` seconds whose sample sum is non-zero.

    Each yielded dict holds: ``epoch_idx``, ``filename`` (POSIX path),
    ``raw_duration`` (whole-file seconds), ``offset`` (window start, seconds)
    and ``duration`` (window length, seconds) — no audio is returned; the
    consumer re-loads the slice itself.

    Raises:
        AssertionError: if a loaded signal is not 1-D (multi-channel audio).
    """
    data_dir = Path(data_dir)
    for epoch_idx in range(max_epoch):
        for filename in data_dir.glob("**/*.wav"):
            # librosa resamples to `sample_rate` on load, so window math below
            # is always in the target rate.
            signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
            raw_duration = librosa.get_duration(y=signal, sr=sample_rate)

            if raw_duration < duration:
                # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
                continue
            if signal.ndim != 1:
                raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")

            signal_length = len(signal)
            win_size = int(duration * sample_rate)
            # NOTE(review): range's stop is exclusive, so the window starting
            # exactly at signal_length - win_size is never emitted — confirm
            # dropping that last full window is intended.
            for begin in range(0, signal_length - win_size, win_size):
                # Zero sample-sum is used as a cheap digital-silence test;
                # a non-silent window could in principle also sum to zero.
                if np.sum(signal[begin: begin+win_size]) == 0:
                    continue
                row = {
                    "epoch_idx": epoch_idx,
                    "filename": filename.as_posix(),
                    "raw_duration": round(raw_duration, 4),
                    "offset": round(begin / sample_rate, 4),
                    "duration": round(duration, 4),
                }
                yield row
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def main():
    """Build train/valid jsonl manifests by pairing noise and speech windows.

    Each output row references a fixed-duration window (offset + duration)
    inside one noise file and one speech file, plus a randomly drawn SNR.
    The actual audio mixing happens later, at training time.
    """
    args = get_args()

    file_dir = Path(args.file_dir)
    # NOTE(review): no parents=True — the parent of file_dir must already
    # exist; confirm callers guarantee this.
    file_dir.mkdir(exist_ok=True)

    noise_dir = Path(args.noise_dir)
    speech_dir = Path(args.speech_dir)

    # Noise windows cycle for up to 100000 passes so noise never runs out
    # before speech does.
    noise_generator = target_second_signal_generator(
        noise_dir.as_posix(),
        duration=args.duration,
        sample_rate=args.target_sample_rate,
        max_epoch=100000,
    )
    # Speech is traversed exactly once; zip() below therefore terminates when
    # the speech generator is exhausted.
    speech_generator = target_second_signal_generator(
        speech_dir.as_posix(),
        duration=args.duration,
        sample_rate=args.target_sample_rate,
        max_epoch=1,
    )

    dataset = list()  # NOTE(review): never appended to — appears to be dead code.

    count = 0
    process_bar = tqdm(desc="build dataset jsonl")
    with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
        for noise, speech in zip(noise_generator, speech_generator):
            # Chained comparison: stop only when max_count is positive AND reached.
            if count >= args.max_count > 0:
                break

            noise_filename = noise["filename"]
            noise_raw_duration = noise["raw_duration"]
            noise_offset = noise["offset"]
            noise_duration = noise["duration"]

            speech_filename = speech["filename"]
            speech_raw_duration = speech["raw_duration"]
            speech_offset = speech["offset"]
            speech_duration = speech["duration"]

            # random1 is stored in the row for downstream use; random2 only
            # decides the train/valid split below.
            random1 = random.random()
            random2 = random.random()

            row = {
                "count": count,

                "noise_filename": noise_filename,
                "noise_raw_duration": noise_raw_duration,
                "noise_offset": noise_offset,
                "noise_duration": noise_duration,

                "speech_filename": speech_filename,
                "speech_raw_duration": speech_raw_duration,
                "speech_offset": speech_offset,
                "speech_duration": speech_duration,

                # SNR is drawn uniformly per row within [min_snr_db, max_snr_db].
                "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),

                "random1": random1,
            }
            row = json.dumps(row, ensure_ascii=False)
            # Roughly 1 row in 300 goes to the validation manifest.
            if random2 < (1 / 300 / 1):
                fvalid.write(f"{row}\n")
            else:
                ftrain.write(f"{row}\n")

            count += 1
            duration_seconds = count * args.duration
            duration_hours = duration_seconds / 3600

            process_bar.update(n=1)
            process_bar.set_postfix({
                # "duration_seconds": round(duration_seconds, 4),
                "duration_hours": round(duration_hours, 4),

            })

    return


if __name__ == "__main__":
    main()
|
examples/dtln/step_2_train_model.py
ADDED
|
@@ -0,0 +1,437 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
https://github.com/breizhn/DTLN
|
| 5 |
+
|
| 6 |
+
"""
|
| 7 |
+
import argparse
|
| 8 |
+
import json
|
| 9 |
+
import logging
|
| 10 |
+
from logging.handlers import TimedRotatingFileHandler
|
| 11 |
+
import os
|
| 12 |
+
import platform
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
import random
|
| 15 |
+
import sys
|
| 16 |
+
import shutil
|
| 17 |
+
from typing import List
|
| 18 |
+
|
| 19 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
| 20 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
import torch
|
| 24 |
+
import torch.nn as nn
|
| 25 |
+
from torch.nn import functional as F
|
| 26 |
+
from torch.utils.data.dataloader import DataLoader
|
| 27 |
+
from tqdm import tqdm
|
| 28 |
+
|
| 29 |
+
from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
|
| 30 |
+
from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
|
| 31 |
+
from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
|
| 32 |
+
from toolbox.torchaudio.metrics.pesq import run_pesq_score
|
| 33 |
+
from toolbox.torchaudio.models.dtln.configuration_dtln import DTLNConfig
|
| 34 |
+
from toolbox.torchaudio.models.dtln.modeling_dtln import DTLNModel, DTLNPretrainedModel
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def get_args(argv=None):
    """Parse command-line options for model training.

    Args:
        argv: optional list of argument strings; defaults to ``sys.argv[1:]``
            (passing a list makes the function testable without patching
            ``sys.argv`` — backward compatible with the old zero-arg call).

    Returns:
        argparse.Namespace with the fields below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_dataset", default="train.jsonl", type=str)
    parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)

    # Rolling window of on-disk checkpoints to retain.
    parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
    # Evaluation rounds without a new best metric before early stopping.
    parser.add_argument("--patience", default=30, type=int)
    parser.add_argument("--serialization_dir", default="serialization_dir", type=str)

    parser.add_argument("--config_file", default="config.yaml", type=str)

    args = parser.parse_args(argv)
    return args
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def logging_config(file_dir: str):
    """Configure console logging plus a daily-rotating file log in ``file_dir``.

    Basic (console) logging is configured at INFO level, and a
    ``TimedRotatingFileHandler`` writing ``main.log`` — rotated daily, keeping
    seven backups — is attached to this module's logger, which is returned.
    """
    log_format = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    logging.basicConfig(
        format=log_format,
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    rotating_handler = TimedRotatingFileHandler(
        filename=os.path.join(file_dir, "main.log"),
        encoding="utf-8",
        when="D",       # rotate once per day
        interval=1,
        backupCount=7,  # keep one week of rotated logs
    )
    rotating_handler.setLevel(logging.INFO)
    rotating_handler.setFormatter(logging.Formatter(log_format))

    logger = logging.getLogger(__name__)
    logger.addHandler(rotating_handler)

    return logger
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class CollateFunction(object):
    """Collate per-sample dicts into stacked (clean, noisy) waveform batches.

    Each sample must provide 1-D waveform tensors under the keys
    ``"speech_wave"`` (clean target) and ``"mix_wave"`` (noisy input);
    all samples in one batch must share the same length so ``torch.stack``
    succeeds.
    """

    def __init__(self):
        pass

    def __call__(self, batch: List[dict]):
        """Return ``(clean_audios, noisy_audios)``, each of shape (batch, num_samples).

        Raises:
            AssertionError: if either stacked tensor contains NaN or Inf.
        """
        # FIX: removed the dead local `snr_db_list`, which was created but
        # never filled or returned.
        clean_audios = list()
        noisy_audios = list()

        for sample in batch:
            # noise_wave: torch.Tensor = sample["noise_wave"]
            clean_audio: torch.Tensor = sample["speech_wave"]
            noisy_audio: torch.Tensor = sample["mix_wave"]
            # snr_db: float = sample["snr_db"]

            clean_audios.append(clean_audio)
            noisy_audios.append(noisy_audio)

        clean_audios = torch.stack(clean_audios)
        noisy_audios = torch.stack(noisy_audios)

        # Fail fast on corrupted inputs — NaN/Inf here would silently poison
        # the training losses downstream.
        if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
            raise AssertionError("nan or inf in clean_audios")
        if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
            raise AssertionError("nan or inf in noisy_audios")
        return clean_audios, noisy_audios


collate_fn = CollateFunction()
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def main():
    """Train a DTLN denoiser with periodic evaluation, checkpointing and early stop.

    Loss is ``1.0 * MultiResolutionSTFT + 1.0 * NegativeSISNR``; PESQ (narrow
    band) is tracked as the model-selection metric (higher is better).
    Checkpoints are written to ``serialization_dir/steps-N`` every
    ``config.eval_steps`` optimizer steps, and the best one is mirrored into
    ``serialization_dir/best``.
    """
    args = get_args()

    config = DTLNConfig.from_pretrained(
        pretrained_model_name_or_path=args.config_file,
    )

    serialization_dir = Path(args.serialization_dir)
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config(serialization_dir)

    # Seed python, numpy and torch for reproducibility.
    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    logger.info(f"set seed: {config.seed}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info(f"GPU available count: {n_gpu}; device: {device}")

    # datasets
    train_dataset = DenoiseJsonlDataset(
        jsonl_file=args.train_dataset,
        expected_sample_rate=config.sample_rate,
        max_wave_value=32768.0,
        min_snr_db=config.min_snr_db,
        max_snr_db=config.max_snr_db,
        # skip=225000,
    )
    valid_dataset = DenoiseJsonlDataset(
        jsonl_file=args.valid_dataset,
        expected_sample_rate=config.sample_rate,
        max_wave_value=32768.0,
        min_snr_db=config.min_snr_db,
        max_snr_db=config.max_snr_db,
    )
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=config.batch_size,
        # shuffle=True,
        sampler=None,
        # On Linux multiple worker subprocesses can load data; on Windows they cannot.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        prefetch_factor=None if platform.system() == "Windows" else 2,
    )
    valid_data_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=config.batch_size,
        # shuffle=True,
        sampler=None,
        # On Linux multiple worker subprocesses can load data; on Windows they cannot.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        prefetch_factor=None if platform.system() == "Windows" else 2,
    )

    # models
    logger.info(f"prepare models. config_file: {args.config_file}")
    model = DTLNPretrainedModel(config).to(device)
    model.to(device)  # NOTE(review): redundant — already moved by .to(device) above.
    model.train()

    # optimizer
    logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric")
    optimizer = torch.optim.AdamW(model.parameters(), config.lr)

    # resume training: pick the highest steps-N checkpoint directory, if any.
    last_step_idx = -1
    last_epoch = -1
    for step_idx_str in serialization_dir.glob("steps-*"):
        step_idx_str = Path(step_idx_str)
        step_idx = step_idx_str.stem.split("-")[1]
        step_idx = int(step_idx)
        if step_idx > last_step_idx:
            last_step_idx = step_idx
            # last_epoch = 1

    if last_step_idx != -1:
        logger.info(f"resume from steps-{last_step_idx}.")
        model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"

        logger.info(f"load state dict for model.")
        with open(model_pt.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)

    # NOTE(review): optimizer state is NOT restored on resume — only model
    # weights; confirm this is acceptable for AdamW.
    if config.lr_scheduler == "CosineAnnealingLR":
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            last_epoch=last_epoch,
            # T_max=10 * config.eval_steps,
            # eta_min=0.01 * config.lr,
            **config.lr_scheduler_kwargs,
        )
    elif config.lr_scheduler == "MultiStepLR":
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            last_epoch=last_epoch,
            milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
        )
    else:
        raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")

    neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
    mr_stft_loss_fn = MultiResolutionSTFTLoss(
        fft_size_list=[256, 512, 1024],
        win_size_list=[256, 512, 1024],
        hop_size_list=[128, 256, 512],
        factor_sc=1.5,
        factor_mag=1.0,
        reduction="mean"
    ).to(device)

    # training loop

    # state — sentinel values; overwritten on the first processed batch.
    average_pesq_score = 1000000000
    average_loss = 1000000000
    average_mr_stft_loss = 1000000000
    average_neg_si_snr_loss = 1000000000

    model_list = list()
    best_epoch_idx = None
    best_step_idx = None
    best_metric = None
    patience_count = 0

    step_idx = 0 if last_step_idx == -1 else last_step_idx

    logger.info("training")
    early_stop_flag = False
    for epoch_idx in range(max(0, last_epoch+1), config.max_epochs):
        if early_stop_flag:
            break

        # train
        model.train()

        total_pesq_score = 0.
        total_loss = 0.
        total_mr_stft_loss = 0.
        total_neg_si_snr_loss = 0.
        total_batches = 0.

        progress_bar_train = tqdm(
            initial=step_idx,
            desc="Training; epoch-{}".format(epoch_idx),
        )
        for train_batch in train_data_loader:
            clean_audios, noisy_audios = train_batch
            clean_audios: torch.Tensor = clean_audios.to(device)
            noisy_audios: torch.Tensor = noisy_audios.to(device)

            denoise_audios = model.forward(noisy_audios)
            denoise_audios = torch.squeeze(denoise_audios, dim=1)

            mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
            neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)

            loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss
            # Skip (don't backprop) batches whose loss is NaN/Inf.
            if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                logger.info(f"find nan or inf in loss.")
                continue

            # NOTE(review): PESQ is computed on every training batch, which is
            # CPU-expensive — confirm it is worth the slowdown vs eval-only.
            denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
            clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
            pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
            optimizer.step()
            lr_scheduler.step()  # stepped per batch, matching T_max in steps

            total_pesq_score += pesq_score
            total_loss += loss.item()
            total_mr_stft_loss += mr_stft_loss.item()
            total_neg_si_snr_loss += neg_si_snr_loss.item()
            total_batches += 1

            average_pesq_score = round(total_pesq_score / total_batches, 4)
            average_loss = round(total_loss / total_batches, 4)
            average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
            average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)

            progress_bar_train.update(1)
            progress_bar_train.set_postfix({
                "lr": lr_scheduler.get_last_lr()[0],
                "pesq_score": average_pesq_score,
                "loss": average_loss,
                "mr_stft_loss": average_mr_stft_loss,
                "neg_si_snr_loss": average_neg_si_snr_loss,
            })

            # evaluation — every config.eval_steps optimizer steps.
            step_idx += 1
            if step_idx % config.eval_steps == 0:
                model.eval()
                with torch.no_grad():
                    torch.cuda.empty_cache()

                    total_pesq_score = 0.
                    total_loss = 0.
                    total_mr_stft_loss = 0.
                    total_neg_si_snr_loss = 0.
                    total_batches = 0.

                    progress_bar_train.close()
                    progress_bar_eval = tqdm(
                        desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
                    )
                    for eval_batch in valid_data_loader:
                        clean_audios, noisy_audios = eval_batch
                        clean_audios: torch.Tensor = clean_audios.to(device)
                        noisy_audios: torch.Tensor = noisy_audios.to(device)

                        denoise_audios = model.forward(noisy_audios)
                        denoise_audios = torch.squeeze(denoise_audios, dim=1)

                        mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
                        neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)

                        loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss
                        if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                            logger.info(f"find nan or inf in loss.")
                            continue

                        denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
                        clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
                        pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")

                        total_pesq_score += pesq_score
                        total_loss += loss.item()
                        total_mr_stft_loss += mr_stft_loss.item()
                        total_neg_si_snr_loss += neg_si_snr_loss.item()
                        total_batches += 1

                        average_pesq_score = round(total_pesq_score / total_batches, 4)
                        average_loss = round(total_loss / total_batches, 4)
                        average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
                        average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)

                        progress_bar_eval.update(1)
                        progress_bar_eval.set_postfix({
                            "lr": lr_scheduler.get_last_lr()[0],
                            "pesq_score": average_pesq_score,
                            "loss": average_loss,
                            "mr_stft_loss": average_mr_stft_loss,
                            "neg_si_snr_loss": average_neg_si_snr_loss,

                        })

                    # Reset running totals so the next training window starts fresh.
                    total_pesq_score = 0.
                    total_loss = 0.
                    total_mr_stft_loss = 0.
                    total_neg_si_snr_loss = 0.
                    total_batches = 0.

                    progress_bar_eval.close()
                    progress_bar_train = tqdm(
                        initial=progress_bar_train.n,
                        postfix=progress_bar_train.postfix,
                        desc=progress_bar_train.desc,
                    )

                    # save path
                    save_dir = serialization_dir / "steps-{}".format(step_idx)
                    save_dir.mkdir(parents=True, exist_ok=False)

                    # save models
                    model.save_pretrained(save_dir.as_posix())

                    model_list.append(save_dir)
                    # NOTE(review): `>=` deletes once the list REACHES the limit,
                    # so at most num_serialized_models_to_keep - 1 checkpoints
                    # survive — possible off-by-one vs the argument name.
                    if len(model_list) >= args.num_serialized_models_to_keep:
                        model_to_delete: Path = model_list.pop(0)
                        shutil.rmtree(model_to_delete.as_posix())

                    # save metric — model selection on validation PESQ (higher is better).
                    if best_metric is None:
                        best_epoch_idx = epoch_idx
                        best_step_idx = step_idx
                        best_metric = average_pesq_score
                    elif average_pesq_score >= best_metric:
                        # great is better.
                        best_epoch_idx = epoch_idx
                        best_step_idx = step_idx
                        best_metric = average_pesq_score
                    else:
                        pass

                    metrics = {
                        "epoch_idx": epoch_idx,
                        "best_epoch_idx": best_epoch_idx,
                        "best_step_idx": best_step_idx,
                        "pesq_score": average_pesq_score,
                        "loss": average_loss,
                        "mr_stft_loss": average_mr_stft_loss,
                        "neg_si_snr_loss": average_neg_si_snr_loss,
                    }
                    metrics_filename = save_dir / "metrics_epoch.json"
                    with open(metrics_filename, "w", encoding="utf-8") as f:
                        json.dump(metrics, f, indent=4, ensure_ascii=False)

                    # save best — mirror the fresh best checkpoint into "best".
                    best_dir = serialization_dir / "best"
                    if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
                        if best_dir.exists():
                            shutil.rmtree(best_dir)
                        shutil.copytree(save_dir, best_dir)

                    # early stop
                    early_stop_flag = False
                    if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
                        patience_count = 0
                    else:
                        patience_count += 1
                        if patience_count >= args.patience:
                            early_stop_flag = True

                    # early stop — break out of the batch loop; the epoch loop
                    # checks early_stop_flag at its top.
                    if early_stop_flag:
                        break
                model.train()

    return


if __name__ == "__main__":
    main()
|
examples/dtln/yaml/config-1024.yaml
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model_name: "DTLN"
|
| 2 |
+
|
| 3 |
+
# spec
|
| 4 |
+
sample_rate: 8000
|
| 5 |
+
fft_size: 512  # NOTE(review): file is named config-1024 yet fft_size stays 512 — the "1024" here refers to encoder_size below; confirm this is intended and not a copy-paste slip.
|
| 6 |
+
hop_size: 128
|
| 7 |
+
win_type: hann
|
| 8 |
+
|
| 9 |
+
# data
|
| 10 |
+
min_snr_db: -5
|
| 11 |
+
max_snr_db: 25
|
| 12 |
+
|
| 13 |
+
# model
|
| 14 |
+
encoder_size: 1024
|
| 15 |
+
|
| 16 |
+
# train
|
| 17 |
+
lr: 0.001
|
| 18 |
+
lr_scheduler: "CosineAnnealingLR"
|
| 19 |
+
lr_scheduler_kwargs:
|
| 20 |
+
T_max: 250000
|
| 21 |
+
eta_min: 0.0001
|
| 22 |
+
|
| 23 |
+
max_epochs: 100
|
| 24 |
+
clip_grad_norm: 10.0
|
| 25 |
+
seed: 1234
|
| 26 |
+
|
| 27 |
+
num_workers: 4
|
| 28 |
+
batch_size: 64
|
| 29 |
+
eval_steps: 15000
|
examples/dtln/yaml/config-256.yaml
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model_name: "DTLN"
|
| 2 |
+
|
| 3 |
+
# spec
|
| 4 |
+
sample_rate: 8000
|
| 5 |
+
fft_size: 256
|
| 6 |
+
hop_size: 128
|
| 7 |
+
win_type: hann
|
| 8 |
+
|
| 9 |
+
# data
|
| 10 |
+
min_snr_db: -5
|
| 11 |
+
max_snr_db: 25
|
| 12 |
+
|
| 13 |
+
# model
|
| 14 |
+
encoder_size: 256
|
| 15 |
+
|
| 16 |
+
# train
|
| 17 |
+
lr: 0.001
|
| 18 |
+
lr_scheduler: "CosineAnnealingLR"
|
| 19 |
+
lr_scheduler_kwargs:
|
| 20 |
+
T_max: 250000
|
| 21 |
+
eta_min: 0.0001
|
| 22 |
+
|
| 23 |
+
max_epochs: 100
|
| 24 |
+
clip_grad_norm: 10.0
|
| 25 |
+
seed: 1234
|
| 26 |
+
|
| 27 |
+
num_workers: 4
|
| 28 |
+
batch_size: 64
|
| 29 |
+
eval_steps: 15000
|
examples/dtln/yaml/config-512.yaml
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model_name: "DTLN"
|
| 2 |
+
|
| 3 |
+
# spec
|
| 4 |
+
sample_rate: 8000
|
| 5 |
+
fft_size: 512
|
| 6 |
+
hop_size: 128
|
| 7 |
+
win_type: hann
|
| 8 |
+
|
| 9 |
+
# data
|
| 10 |
+
min_snr_db: -5
|
| 11 |
+
max_snr_db: 25
|
| 12 |
+
|
| 13 |
+
# model
|
| 14 |
+
encoder_size: 512
|
| 15 |
+
|
| 16 |
+
# train
|
| 17 |
+
lr: 0.001
|
| 18 |
+
lr_scheduler: "CosineAnnealingLR"
|
| 19 |
+
lr_scheduler_kwargs:
|
| 20 |
+
T_max: 250000
|
| 21 |
+
eta_min: 0.0001
|
| 22 |
+
|
| 23 |
+
max_epochs: 100
|
| 24 |
+
clip_grad_norm: 10.0
|
| 25 |
+
seed: 1234
|
| 26 |
+
|
| 27 |
+
num_workers: 4
|
| 28 |
+
batch_size: 64
|
| 29 |
+
eval_steps: 15000
|
examples/dtln_mp3_to_wav/run.sh
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash

: <<'END'

sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir-256 --final_model_name dtln-256-nx-dns3 \
--config_file "yaml/config-256.yaml" \
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"


sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir-512 --final_model_name dtln-512-nx-dns3 \
--config_file "yaml/config-512.yaml" \
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"


sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name dtnl-1024-nx2 --final_model_name dtln-1024-nx2 \
--config_file "yaml/config-1024.yaml" \
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/nx-noise" \
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2"


bash run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name dtln-256-nx2-dns3-mp3 --final_model_name dtln-256-nx2-dns3-mp3 \
--config_file "yaml/config-256.yaml" \
--audio_dir "/data/tianxing/HuggingDatasets/nx_noise/data" \


END


# params
system_version="windows";
verbose=true;
stage=0 # start from 0 if you need to start from data preparation
stop_stage=9

work_dir="$(pwd)"
file_folder_name=file_folder_name
final_model_name=final_model_name
config_file="yaml/config.yaml"
limit=10

audio_dir=/data/tianxing/HuggingDatasets/nx_noise/data

max_count=-1

nohup_name=nohup.out

# model params
batch_size=64
max_epochs=200
save_top_k=10
patience=5


# parse options
while true; do
  [ -z "${1:-}" ] && break;  # break if there are no arguments
  case "$1" in
    --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
      eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
      # FIX: this was old_value="(eval echo \$$name)" — a literal string, not a
      # command substitution — so the boolean check below could never trigger.
      old_value="$(eval echo \\$$name)";
      if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
        was_bool=true;
      else
        was_bool=false;
      fi

      # Set the variable to the right value-- the escaped quotes make it work if
      # the option had spaces, like --cmd "queue.pl -sync y"
      eval "${name}=\"$2\"";

      # Check that Boolean-valued arguments are really Boolean.
      if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
        echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
        exit 1;
      fi
      shift 2;
      ;;

    *) break;
  esac
done

file_dir="${work_dir}/${file_folder_name}"
final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
evaluation_audio_dir="${file_dir}/evaluation_audio"

train_dataset="${file_dir}/train.jsonl"
valid_dataset="${file_dir}/valid.jsonl"

$verbose && echo "system_version: ${system_version}"
$verbose && echo "file_folder_name: ${file_folder_name}"

# NOTE(review): non-interactive bash does not expand aliases unless
# `shopt -s expand_aliases` is set — confirm these aliases actually take
# effect, or switch to a PYTHON3 variable.
if [ "$system_version" == "windows" ]; then
  alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
elif [ "$system_version" == "centos" ] || [ "$system_version" == "ubuntu" ]; then
  #source /data/local/bin/nx_denoise/bin/activate
  alias python3='/data/local/bin/nx_denoise/bin/python3'
fi


if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  $verbose && echo "stage 1: prepare data"
  cd "${work_dir}" || exit 1
  # FIX: dropped trailing backslashes that continued the command onto the
  # blank line before "fi" (fragile if a line is inserted there).
  python3 step_1_prepare_data.py \
  --file_dir "${file_dir}" \
  --audio_dir "${audio_dir}" \
  --train_dataset "${train_dataset}" \
  --valid_dataset "${valid_dataset}" \
  --max_count "${max_count}"

fi


if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
  $verbose && echo "stage 2: train model"
  cd "${work_dir}" || exit 1
  python3 step_2_train_model.py \
  --train_dataset "${train_dataset}" \
  --valid_dataset "${valid_dataset}" \
  --serialization_dir "${file_dir}" \
  --config_file "${config_file}"

fi


if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
  $verbose && echo "stage 3: test model"
  cd "${work_dir}" || exit 1
  python3 step_3_evaluation.py \
  --valid_dataset "${valid_dataset}" \
  --model_dir "${file_dir}/best" \
  --evaluation_audio_dir "${evaluation_audio_dir}" \
  --limit "${limit}"

fi
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
| 141 |
+
$verbose && echo "stage 4: collect files"
|
| 142 |
+
cd "${work_dir}" || exit 1
|
| 143 |
+
|
| 144 |
+
mkdir -p ${final_model_dir}
|
| 145 |
+
|
| 146 |
+
cp "${file_dir}/best"/* "${final_model_dir}"
|
| 147 |
+
cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"
|
| 148 |
+
|
| 149 |
+
cd "${final_model_dir}/.." || exit 1;
|
| 150 |
+
|
| 151 |
+
if [ -e "${final_model_name}.zip" ]; then
|
| 152 |
+
rm -rf "${final_model_name}_backup.zip"
|
| 153 |
+
mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
|
| 154 |
+
fi
|
| 155 |
+
|
| 156 |
+
zip -r "${final_model_name}.zip" "${final_model_name}"
|
| 157 |
+
rm -rf "${final_model_name}"
|
| 158 |
+
|
| 159 |
+
fi
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
| 163 |
+
$verbose && echo "stage 5: clear file_dir"
|
| 164 |
+
cd "${work_dir}" || exit 1
|
| 165 |
+
|
| 166 |
+
rm -rf "${file_dir}";
|
| 167 |
+
|
| 168 |
+
fi
|
examples/dtln_mp3_to_wav/step_1_prepare_data.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
import random
|
| 8 |
+
import sys
|
| 9 |
+
|
| 10 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
| 11 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
| 12 |
+
|
| 13 |
+
import librosa
|
| 14 |
+
import numpy as np
|
| 15 |
+
from tqdm import tqdm
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_args():
    """Parse the command-line options for the dataset-preparation script."""
    parser = argparse.ArgumentParser()

    # (flag, default, type) triples; flags and defaults are identical to the
    # original script, so callers and run.sh invocations are unaffected.
    arg_specs = [
        ("--file_dir", "./", str),
        ("--audio_dir", "E:/Users/tianx/HuggingDatasets/nx_noise/data/speech", str),
        ("--train_dataset", "train.jsonl", str),
        ("--valid_dataset", "valid.jsonl", str),
        ("--duration", 4.0, float),
        ("--target_sample_rate", 8000, int),
        ("--max_count", -1, int),
    ]
    for flag, default, arg_type in arg_specs:
        parser.add_argument(flag, default=default, type=arg_type)

    return parser.parse_args()
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def target_second_signal_generator(data_dir: str, duration: float = 2, sample_rate: int = 8000, max_epoch: int = 1):
    """
    Yield descriptors of fixed-length, non-overlapping segments for every wav file under ``data_dir``.

    Files shorter than ``duration`` seconds are skipped; non-mono files raise.

    :param data_dir: root directory searched recursively for ``*.wav`` files.
    :param duration: segment length in seconds (annotation widened to float:
        callers pass values such as 4.0).
    :param sample_rate: sample rate audio is resampled to on load.
    :param max_epoch: number of passes over the directory; each yielded row
        records the pass it came from in ``epoch_idx``.
    :return: generator of dicts with keys
        ``epoch_idx`` / ``filename`` / ``raw_duration`` / ``offset`` / ``duration``.
    """
    data_dir = Path(data_dir)
    for epoch_idx in range(max_epoch):
        for filename in data_dir.glob("**/*.wav"):
            signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
            raw_duration = librosa.get_duration(y=signal, sr=sample_rate)

            if raw_duration < duration:
                # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
                continue
            if signal.ndim != 1:
                raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")

            signal_length = len(signal)
            win_size = int(duration * sample_rate)
            # Fix off-by-one: the original `range(0, signal_length - win_size, win_size)`
            # always dropped the final complete window (a file of exactly two windows
            # yielded only one). The `+ 1` keeps every full non-overlapping window.
            for begin in range(0, signal_length - win_size + 1, win_size):
                # Skip windows whose samples sum to zero (all-silence segments;
                # exact cancellation of non-silent audio is practically impossible).
                if np.sum(signal[begin: begin+win_size]) == 0:
                    continue
                row = {
                    "epoch_idx": epoch_idx,
                    "filename": filename.as_posix(),
                    "raw_duration": round(raw_duration, 4),
                    "offset": round(begin / sample_rate, 4),
                    "duration": round(duration, 4),
                }
                yield row
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def main():
    """Build train/valid JSONL manifests of fixed-length audio segments.

    Walks ``--audio_dir`` via target_second_signal_generator, writes one JSON
    line per segment, and routes roughly 1/300 of segments to the valid split.
    """
    args = get_args()

    file_dir = Path(args.file_dir)
    # parents=True so a nested --file_dir does not crash on a missing parent
    # (the original used mkdir(exist_ok=True) only).
    file_dir.mkdir(parents=True, exist_ok=True)

    audio_dir = Path(args.audio_dir)

    audio_generator = target_second_signal_generator(
        audio_dir.as_posix(),
        duration=args.duration,
        sample_rate=args.target_sample_rate,
        max_epoch=1,
    )
    count = 0
    process_bar = tqdm(desc="build dataset jsonl")
    with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
        for audio in audio_generator:
            # Stop after max_count segments; max_count <= 0 means "no limit".
            # (Equivalent to, but clearer than, the original chained
            # comparison `count >= args.max_count > 0`.)
            if args.max_count > 0 and count >= args.max_count:
                break

            filename = audio["filename"]
            raw_duration = audio["raw_duration"]
            offset = audio["offset"]
            duration = audio["duration"]

            # random1 is stored in the row (e.g. for later sub-sampling);
            # random2 decides the train/valid split.
            random1 = random.random()
            random2 = random.random()

            row = {
                "count": count,

                "filename": filename,
                "raw_duration": raw_duration,
                "offset": offset,
                "duration": duration,

                "random1": random1,
            }
            row = json.dumps(row, ensure_ascii=False)
            if random2 < (1 / 300):
                fvalid.write(f"{row}\n")
            else:
                ftrain.write(f"{row}\n")

            count += 1
            # Running total of collected audio, shown in hours on the bar.
            duration_seconds = count * args.duration
            duration_hours = duration_seconds / 3600

            process_bar.update(n=1)
            process_bar.set_postfix({
                "duration_hours": round(duration_hours, 4),
            })

    return
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
# Script entry point: run only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
examples/dtln_mp3_to_wav/step_2_train_model.py
ADDED
|
@@ -0,0 +1,445 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
https://github.com/breizhn/DTLN
|
| 5 |
+
|
| 6 |
+
"""
|
| 7 |
+
import argparse
|
| 8 |
+
import json
|
| 9 |
+
import logging
|
| 10 |
+
from logging.handlers import TimedRotatingFileHandler
|
| 11 |
+
import os
|
| 12 |
+
import platform
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
import random
|
| 15 |
+
import sys
|
| 16 |
+
import shutil
|
| 17 |
+
from typing import List
|
| 18 |
+
|
| 19 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
| 20 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
import torch
|
| 24 |
+
import torch.nn as nn
|
| 25 |
+
from torch.nn import functional as F
|
| 26 |
+
from torch.utils.data.dataloader import DataLoader
|
| 27 |
+
from tqdm import tqdm
|
| 28 |
+
|
| 29 |
+
from toolbox.torch.utils.data.dataset.mp3_to_wav_jsonl_dataset import Mp3ToWavJsonlDataset
|
| 30 |
+
from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
|
| 31 |
+
from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
|
| 32 |
+
from toolbox.torchaudio.metrics.pesq import run_pesq_score
|
| 33 |
+
from toolbox.torchaudio.models.dtln.configuration_dtln import DTLNConfig
|
| 34 |
+
from toolbox.torchaudio.models.dtln.modeling_dtln import DTLNModel, DTLNPretrainedModel
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def get_args():
    """Parse the command-line options for the DTLN training script."""
    parser = argparse.ArgumentParser()

    # (flag, default, type); flags and defaults mirror the original exactly.
    arg_specs = [
        ("--train_dataset", "train.jsonl", str),
        ("--valid_dataset", "valid.jsonl", str),
        ("--num_serialized_models_to_keep", 15, int),
        ("--patience", 30, int),
        ("--serialization_dir", "serialization_dir", str),
        ("--config_file", "config.yaml", str),
    ]
    for flag, default, arg_type in arg_specs:
        parser.add_argument(flag, default=default, type=arg_type)

    return parser.parse_args()
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def logging_config(file_dir: str):
    """
    Configure console logging via basicConfig plus a daily-rotating file
    handler writing ``<file_dir>/main.log``, then return this module's logger.
    """
    log_format = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    # Root/console logger.
    logging.basicConfig(
        format=log_format,
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    # File handler: rotate once per day, keep one week of backups.
    rotating_handler = TimedRotatingFileHandler(
        filename=os.path.join(file_dir, "main.log"),
        encoding="utf-8",
        when="D",
        interval=1,
        backupCount=7,
    )
    rotating_handler.setLevel(logging.INFO)
    rotating_handler.setFormatter(logging.Formatter(log_format))

    logger = logging.getLogger(__name__)
    logger.addHandler(rotating_handler)
    return logger
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class CollateFunction(object):
    """Batch collator: stacks per-sample mp3/wav waveforms into two tensors.

    Each sample is a dict with "mp3_waveform" and "wav_waveform" tensors of a
    common shape; the result is a pair of stacked tensors (batch first).
    """

    def __init__(self):
        pass

    def __call__(self, batch: List[dict]):
        mp3_batch = torch.stack([sample["mp3_waveform"] for sample in batch])
        wav_batch = torch.stack([sample["wav_waveform"] for sample in batch])

        # Fail fast on corrupted inputs instead of letting nan/inf reach the loss.
        if torch.any(torch.isnan(mp3_batch)) or torch.any(torch.isinf(mp3_batch)):
            raise AssertionError("nan or inf in mp3_waveform_list")
        if torch.any(torch.isnan(wav_batch)) or torch.any(torch.isinf(wav_batch)):
            raise AssertionError("nan or inf in wav_waveform_list")

        return mp3_batch, wav_batch
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
# Module-level collate instance; picklable by DataLoader worker processes.
collate_fn = CollateFunction()
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def main():
    """Train the DTLN mp3-to-wav restoration model.

    Workflow: load config, build train/valid DataLoaders over JSONL manifests,
    optionally resume from the newest ``steps-*`` checkpoint, then train with a
    combined multi-resolution-STFT + L1 + negative-SI-SNR loss.  Every
    ``config.eval_steps`` optimizer steps it evaluates, checkpoints, tracks the
    best PESQ score, and applies patience-based early stopping.
    """
    args = get_args()

    config = DTLNConfig.from_pretrained(
        pretrained_model_name_or_path=args.config_file,
    )

    serialization_dir = Path(args.serialization_dir)
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config(serialization_dir)

    # Seed all RNGs for reproducibility.
    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    logger.info(f"set seed: {config.seed}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info(f"GPU available count: {n_gpu}; device: {device}")

    # datasets
    train_dataset = Mp3ToWavJsonlDataset(
        jsonl_file=args.train_dataset,
        expected_sample_rate=config.sample_rate,
        max_wave_value=32768.0,
        # skip=225000,
    )
    valid_dataset = Mp3ToWavJsonlDataset(
        jsonl_file=args.valid_dataset,
        expected_sample_rate=config.sample_rate,
        max_wave_value=32768.0,
    )
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=config.batch_size,
        # shuffle=True,
        sampler=None,
        # On Linux multiple worker subprocesses can load data; on Windows they cannot.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        prefetch_factor=None if platform.system() == "Windows" else 2,
    )
    valid_data_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=config.batch_size,
        # shuffle=True,
        sampler=None,
        # On Linux multiple worker subprocesses can load data; on Windows they cannot.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        prefetch_factor=None if platform.system() == "Windows" else 2,
    )

    # models
    logger.info(f"prepare models. config_file: {args.config_file}")
    model = DTLNPretrainedModel(config).to(device)
    model.to(device)
    model.train()

    # optimizer
    logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric")
    optimizer = torch.optim.AdamW(model.parameters(), config.lr)

    # resume training: find the newest "steps-N" checkpoint directory, if any.
    last_step_idx = -1
    last_epoch = -1
    for step_idx_str in serialization_dir.glob("steps-*"):
        step_idx_str = Path(step_idx_str)
        step_idx = step_idx_str.stem.split("-")[1]
        step_idx = int(step_idx)
        if step_idx > last_step_idx:
            last_step_idx = step_idx
            # last_epoch = 1

    # NOTE(review): on resume only model weights are restored; last_epoch stays -1,
    # so the epoch counter and LR schedule restart — confirm this is intended.
    if last_step_idx != -1:
        logger.info(f"resume from steps-{last_step_idx}.")
        model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"

        logger.info(f"load state dict for model.")
        with open(model_pt.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)

    # Learning-rate schedule selected by config; scheduler kwargs come from YAML.
    if config.lr_scheduler == "CosineAnnealingLR":
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            last_epoch=last_epoch,
            # T_max=10 * config.eval_steps,
            # eta_min=0.01 * config.lr,
            **config.lr_scheduler_kwargs,
        )
    elif config.lr_scheduler == "MultiStepLR":
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            last_epoch=last_epoch,
            milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
        )
    else:
        raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")

    # Loss terms: negative SI-SNR + multi-resolution STFT + waveform L1.
    neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
    mr_stft_loss_fn = MultiResolutionSTFTLoss(
        fft_size_list=[256, 512, 1024],
        win_size_list=[256, 512, 1024],
        hop_size_list=[128, 256, 512],
        factor_sc=1.5,
        factor_mag=1.0,
        reduction="mean"
    ).to(device)
    audio_l1_loss_fn = nn.L1Loss(reduction="mean")

    # training loop

    # state (sentinel "very large" initial values; overwritten on first batch)
    average_pesq_score = 1000000000
    average_loss = 1000000000
    average_mr_stft_loss = 1000000000
    average_audio_l1_loss = 1000000000
    average_neg_si_snr_loss = 1000000000

    model_list = list()
    best_epoch_idx = None
    best_step_idx = None
    best_metric = None
    patience_count = 0

    step_idx = 0 if last_step_idx == -1 else last_step_idx

    logger.info("training")
    early_stop_flag = False
    for epoch_idx in range(max(0, last_epoch+1), config.max_epochs):
        if early_stop_flag:
            break

        # train
        model.train()

        total_pesq_score = 0.
        total_loss = 0.
        total_mr_stft_loss = 0.
        total_audio_l1_loss = 0.
        total_neg_si_snr_loss = 0.
        total_batches = 0.

        progress_bar_train = tqdm(
            initial=step_idx,
            desc="Training; epoch-{}".format(epoch_idx),
        )
        for train_batch in train_data_loader:
            mp3_audios, wav_audios = train_batch
            noisy_audios: torch.Tensor = mp3_audios.to(device)
            clean_audios: torch.Tensor = wav_audios.to(device)

            denoise_audios = model.forward(noisy_audios)
            denoise_audios = torch.squeeze(denoise_audios, dim=1)

            mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
            audio_l1_loss = audio_l1_loss_fn.forward(denoise_audios, clean_audios)
            neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)

            loss = 1.0 * mr_stft_loss + 1.0 * audio_l1_loss + 1.0 * neg_si_snr_loss
            # Skip the update entirely if the combined loss is not finite.
            if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                logger.info(f"find nan or inf in loss.")
                continue

            # PESQ (narrow-band) computed on CPU copies for monitoring only.
            denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
            clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
            pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
            optimizer.step()
            lr_scheduler.step()

            total_pesq_score += pesq_score
            total_loss += loss.item()
            total_mr_stft_loss += mr_stft_loss.item()
            total_audio_l1_loss += audio_l1_loss.item()
            total_neg_si_snr_loss += neg_si_snr_loss.item()
            total_batches += 1

            average_pesq_score = round(total_pesq_score / total_batches, 4)
            average_loss = round(total_loss / total_batches, 4)
            average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
            average_audio_l1_loss = round(total_audio_l1_loss / total_batches, 4)
            average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)

            progress_bar_train.update(1)
            progress_bar_train.set_postfix({
                "lr": lr_scheduler.get_last_lr()[0],
                "pesq_score": average_pesq_score,
                "loss": average_loss,
                "mr_stft_loss": average_mr_stft_loss,
                "audio_l1_loss": average_audio_l1_loss,
                "neg_si_snr_loss": average_neg_si_snr_loss,
            })

            # evaluation: every config.eval_steps optimizer steps.
            step_idx += 1
            if step_idx % config.eval_steps == 0:
                model.eval()
                with torch.no_grad():
                    torch.cuda.empty_cache()

                    total_pesq_score = 0.
                    total_loss = 0.
                    total_mr_stft_loss = 0.
                    total_audio_l1_loss = 0.
                    total_neg_si_snr_loss = 0.
                    total_batches = 0.

                    progress_bar_train.close()
                    progress_bar_eval = tqdm(
                        desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
                    )
                    for eval_batch in valid_data_loader:
                        mp3_audios, wav_audios = eval_batch
                        noisy_audios: torch.Tensor = mp3_audios.to(device)
                        clean_audios: torch.Tensor = wav_audios.to(device)

                        denoise_audios = model.forward(noisy_audios)
                        denoise_audios = torch.squeeze(denoise_audios, dim=1)

                        mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
                        audio_l1_loss = audio_l1_loss_fn.forward(denoise_audios, clean_audios)
                        neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)

                        loss = 1.0 * mr_stft_loss + 1.0 * audio_l1_loss + 1.0 * neg_si_snr_loss
                        if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                            logger.info(f"find nan or inf in loss.")
                            continue

                        denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
                        clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
                        pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")

                        total_pesq_score += pesq_score
                        total_loss += loss.item()
                        total_mr_stft_loss += mr_stft_loss.item()
                        total_audio_l1_loss += audio_l1_loss.item()
                        total_neg_si_snr_loss += neg_si_snr_loss.item()
                        total_batches += 1

                        average_pesq_score = round(total_pesq_score / total_batches, 4)
                        average_loss = round(total_loss / total_batches, 4)
                        average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
                        average_audio_l1_loss = round(total_audio_l1_loss / total_batches, 4)
                        average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)

                        progress_bar_eval.update(1)
                        progress_bar_eval.set_postfix({
                            "lr": lr_scheduler.get_last_lr()[0],
                            "pesq_score": average_pesq_score,
                            "loss": average_loss,
                            "mr_stft_loss": average_mr_stft_loss,
                            "audio_l1_loss": average_audio_l1_loss,
                            "neg_si_snr_loss": average_neg_si_snr_loss,

                        })

                    # Reset running totals so the next training stretch starts fresh.
                    total_pesq_score = 0.
                    total_loss = 0.
                    total_mr_stft_loss = 0.
                    total_audio_l1_loss = 0.
                    total_neg_si_snr_loss = 0.
                    total_batches = 0.

                    progress_bar_eval.close()
                    # Recreate the training bar, carrying over position and postfix.
                    progress_bar_train = tqdm(
                        initial=progress_bar_train.n,
                        postfix=progress_bar_train.postfix,
                        desc=progress_bar_train.desc,
                    )

                    # save path
                    save_dir = serialization_dir / "steps-{}".format(step_idx)
                    save_dir.mkdir(parents=True, exist_ok=False)

                    # save models
                    model.save_pretrained(save_dir.as_posix())

                    # NOTE(review): with ">=" this retains one fewer checkpoint than
                    # num_serialized_models_to_keep suggests — confirm intended.
                    model_list.append(save_dir)
                    if len(model_list) >= args.num_serialized_models_to_keep:
                        model_to_delete: Path = model_list.pop(0)
                        shutil.rmtree(model_to_delete.as_posix())

                    # save metric (higher eval PESQ = better)
                    if best_metric is None:
                        best_epoch_idx = epoch_idx
                        best_step_idx = step_idx
                        best_metric = average_pesq_score
                    elif average_pesq_score >= best_metric:
                        # great is better.
                        best_epoch_idx = epoch_idx
                        best_step_idx = step_idx
                        best_metric = average_pesq_score
                    else:
                        pass

                    metrics = {
                        "epoch_idx": epoch_idx,
                        "best_epoch_idx": best_epoch_idx,
                        "best_step_idx": best_step_idx,
                        "pesq_score": average_pesq_score,
                        "loss": average_loss,
                        "mr_stft_loss": average_mr_stft_loss,
                        "audio_l1_loss": average_audio_l1_loss,
                        "neg_si_snr_loss": average_neg_si_snr_loss,
                    }
                    metrics_filename = save_dir / "metrics_epoch.json"
                    with open(metrics_filename, "w", encoding="utf-8") as f:
                        json.dump(metrics, f, indent=4, ensure_ascii=False)

                    # save best: mirror the current checkpoint into "best/".
                    best_dir = serialization_dir / "best"
                    if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
                        if best_dir.exists():
                            shutil.rmtree(best_dir)
                        shutil.copytree(save_dir, best_dir)

                    # early stop: reset patience on improvement, otherwise count up.
                    early_stop_flag = False
                    if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
                        patience_count = 0
                    else:
                        patience_count += 1
                        if patience_count >= args.patience:
                            early_stop_flag = True

                    # early stop
                    if early_stop_flag:
                        break
                model.train()

    return
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
# Script entry point: run only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
examples/dtln_mp3_to_wav/yaml/config-1024.yaml
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model_name: "DTLN"
|
| 2 |
+
|
| 3 |
+
# spec
|
| 4 |
+
sample_rate: 8000
|
| 5 |
+
fft_size: 512
|
| 6 |
+
hop_size: 128
|
| 7 |
+
win_type: hann
|
| 8 |
+
|
| 9 |
+
# data
|
| 10 |
+
min_snr_db: -5
|
| 11 |
+
max_snr_db: 25
|
| 12 |
+
|
| 13 |
+
# model
|
| 14 |
+
encoder_size: 1024
|
| 15 |
+
|
| 16 |
+
# train
|
| 17 |
+
lr: 0.001
|
| 18 |
+
lr_scheduler: "CosineAnnealingLR"
|
| 19 |
+
lr_scheduler_kwargs:
|
| 20 |
+
T_max: 250000
|
| 21 |
+
eta_min: 0.0001
|
| 22 |
+
|
| 23 |
+
max_epochs: 100
|
| 24 |
+
clip_grad_norm: 10.0
|
| 25 |
+
seed: 1234
|
| 26 |
+
|
| 27 |
+
num_workers: 4
|
| 28 |
+
batch_size: 64
|
| 29 |
+
eval_steps: 15000
|
examples/dtln_mp3_to_wav/yaml/config-256.yaml
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model_name: "DTLN"
|
| 2 |
+
|
| 3 |
+
# spec
|
| 4 |
+
sample_rate: 8000
|
| 5 |
+
fft_size: 256
|
| 6 |
+
hop_size: 128
|
| 7 |
+
win_type: hann
|
| 8 |
+
|
| 9 |
+
# data
|
| 10 |
+
min_snr_db: -5
|
| 11 |
+
max_snr_db: 25
|
| 12 |
+
|
| 13 |
+
# model
|
| 14 |
+
encoder_size: 256
|
| 15 |
+
|
| 16 |
+
# train
|
| 17 |
+
lr: 0.001
|
| 18 |
+
lr_scheduler: "CosineAnnealingLR"
|
| 19 |
+
lr_scheduler_kwargs:
|
| 20 |
+
T_max: 250000
|
| 21 |
+
eta_min: 0.0001
|
| 22 |
+
|
| 23 |
+
max_epochs: 100
|
| 24 |
+
clip_grad_norm: 10.0
|
| 25 |
+
seed: 1234
|
| 26 |
+
|
| 27 |
+
num_workers: 4
|
| 28 |
+
batch_size: 64
|
| 29 |
+
eval_steps: 15000
|
examples/dtln_mp3_to_wav/yaml/config-512.yaml
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model_name: "DTLN"
|
| 2 |
+
|
| 3 |
+
# spec
|
| 4 |
+
sample_rate: 8000
|
| 5 |
+
fft_size: 512
|
| 6 |
+
hop_size: 128
|
| 7 |
+
win_type: hann
|
| 8 |
+
|
| 9 |
+
# data
|
| 10 |
+
min_snr_db: -5
|
| 11 |
+
max_snr_db: 25
|
| 12 |
+
|
| 13 |
+
# model
|
| 14 |
+
encoder_size: 512
|
| 15 |
+
|
| 16 |
+
# train
|
| 17 |
+
lr: 0.001
|
| 18 |
+
lr_scheduler: "CosineAnnealingLR"
|
| 19 |
+
lr_scheduler_kwargs:
|
| 20 |
+
T_max: 250000
|
| 21 |
+
eta_min: 0.0001
|
| 22 |
+
|
| 23 |
+
max_epochs: 100
|
| 24 |
+
clip_grad_norm: 10.0
|
| 25 |
+
seed: 1234
|
| 26 |
+
|
| 27 |
+
num_workers: 4
|
| 28 |
+
batch_size: 64
|
| 29 |
+
eval_steps: 15000
|
examples/frcrn/run.sh
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#!/usr/bin/env bash

# FRCRN training pipeline driver.
# Stages: 1) prepare jsonl datasets  2) train  3) evaluate  4) package model  5) clean up.

: <<'END'


sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name frcrn-20-512-nx-dns3 \
--config_file "yaml/config-10.yaml" \
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech"


sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name frcrn-10-nx2 \
--config_file "yaml/config-10.yaml" \
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/nx-noise" \
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2"

END


# params
system_version="windows";
verbose=true;
stage=0 # start from 0 if you need to start from data preparation
stop_stage=9

work_dir="$(pwd)"
file_folder_name=file_folder_name
final_model_name=final_model_name
config_file="yaml/config.yaml"
limit=10

noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train

max_count=10000000

nohup_name=nohup.out

# model params
batch_size=64
max_epochs=200
save_top_k=10
patience=5


# parse options: any "--opt value" pair overrides the matching variable above.
while true; do
  [ -z "${1:-}" ] && break;  # break if there are no arguments
  case "$1" in
    --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
      # reject options that do not correspond to a declared variable
      eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
      # BUGFIX: was `old_value="(eval echo \$$name)"` (literal string, no command
      # substitution), so the true/false type check below never ran.
      old_value="$(eval echo \$$name)";
      if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
        was_bool=true;
      else
        was_bool=false;
      fi

      # Set the variable to the right value-- the escaped quotes make it work if
      # the option had spaces, like --cmd "queue.pl -sync y"
      eval "${name}=\"$2\"";

      # Check that Boolean-valued arguments are really Boolean.
      if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
        echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
        exit 1;
      fi
      shift 2;
      ;;

    *) break;
  esac
done

file_dir="${work_dir}/${file_folder_name}"
final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
evaluation_audio_dir="${file_dir}/evaluation_audio"

train_dataset="${file_dir}/train.jsonl"
valid_dataset="${file_dir}/valid.jsonl"

$verbose && echo "system_version: ${system_version}"
$verbose && echo "file_folder_name: ${file_folder_name}"

# select the python interpreter for the target machine
if [ "$system_version" == "windows" ]; then
  alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
elif [ "$system_version" == "centos" ] || [ "$system_version" == "ubuntu" ]; then
  #source /data/local/bin/nx_denoise/bin/activate
  alias python3='/data/local/bin/nx_denoise/bin/python3'
fi


if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  $verbose && echo "stage 1: prepare data"
  cd "${work_dir}" || exit 1
  python3 step_1_prepare_data.py \
  --file_dir "${file_dir}" \
  --noise_dir "${noise_dir}" \
  --speech_dir "${speech_dir}" \
  --train_dataset "${train_dataset}" \
  --valid_dataset "${valid_dataset}" \
  --max_count "${max_count}"

fi


if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
  $verbose && echo "stage 2: train model"
  cd "${work_dir}" || exit 1
  python3 step_2_train_model.py \
  --train_dataset "${train_dataset}" \
  --valid_dataset "${valid_dataset}" \
  --serialization_dir "${file_dir}" \
  --config_file "${config_file}"

fi


if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
  $verbose && echo "stage 3: test model"
  cd "${work_dir}" || exit 1
  python3 step_3_evaluation.py \
  --valid_dataset "${valid_dataset}" \
  --model_dir "${file_dir}/best" \
  --evaluation_audio_dir "${evaluation_audio_dir}" \
  --limit "${limit}"

fi


if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
  $verbose && echo "stage 4: collect files"
  cd "${work_dir}" || exit 1

  mkdir -p "${final_model_dir}"

  cp "${file_dir}/best"/* "${final_model_dir}"
  cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"

  cd "${final_model_dir}/.." || exit 1;

  # keep one backup of a previously packaged model
  if [ -e "${final_model_name}.zip" ]; then
    rm -rf "${final_model_name}_backup.zip"
    mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
  fi

  zip -r "${final_model_name}.zip" "${final_model_name}"
  rm -rf "${final_model_name}"

fi


if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
  $verbose && echo "stage 5: clear file_dir"
  cd "${work_dir}" || exit 1

  rm -rf "${file_dir}";

fi
examples/frcrn/step_1_prepare_data.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
import random
|
| 8 |
+
import sys
|
| 9 |
+
|
| 10 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
| 11 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
| 12 |
+
|
| 13 |
+
import librosa
|
| 14 |
+
import numpy as np
|
| 15 |
+
from tqdm import tqdm
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_args():
    """Parse command-line options for building the denoise dataset jsonl files."""
    arg_parser = argparse.ArgumentParser()

    arg_parser.add_argument("--file_dir", default="./", type=str)

    # source audio directories (defaults point at a local Windows layout)
    arg_parser.add_argument(
        "--noise_dir",
        default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
        type=str,
    )
    arg_parser.add_argument(
        "--speech_dir",
        default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
        type=str,
    )

    # output jsonl files
    arg_parser.add_argument("--train_dataset", default="train.jsonl", type=str)
    arg_parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)

    # clip length (seconds) and mixing SNR range (dB)
    arg_parser.add_argument("--duration", default=2.0, type=float)
    arg_parser.add_argument("--min_snr_db", default=-10, type=float)
    arg_parser.add_argument("--max_snr_db", default=20, type=float)

    arg_parser.add_argument("--target_sample_rate", default=8000, type=int)

    # maximum number of rows to emit; non-positive means unlimited
    arg_parser.add_argument("--max_count", default=-1, type=int)

    return arg_parser.parse_args()
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def filename_generator(data_dir: str):
    """Yield the POSIX-style path of every ``*.wav`` file under *data_dir*, recursively."""
    root = Path(data_dir)
    for wav_path in root.glob("**/*.wav"):
        yield wav_path.as_posix()
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000, max_epoch: int = 20000):
    """Yield window descriptors for fixed-length, non-silent slices of every wav under *data_dir*.

    The directory is re-scanned ``max_epoch`` times; each yielded dict records
    the source file, its total duration, and the window offset/duration (seconds).
    Files shorter than ``duration`` are skipped; multi-channel audio raises.
    """
    root = Path(data_dir)
    for epoch_idx in range(max_epoch):
        for wav_path in root.glob("**/*.wav"):
            signal, _ = librosa.load(wav_path.as_posix(), sr=sample_rate)
            raw_duration = librosa.get_duration(y=signal, sr=sample_rate)

            if raw_duration < duration:
                # too short to cut even one full window
                continue
            if signal.ndim != 1:
                raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")

            win_size = int(duration * sample_rate)
            for begin in range(0, len(signal) - win_size, win_size):
                # skip all-zero (silent) windows
                if np.sum(signal[begin: begin + win_size]) == 0:
                    continue
                yield {
                    "epoch_idx": epoch_idx,
                    "filename": wav_path.as_posix(),
                    "raw_duration": round(raw_duration, 4),
                    "offset": round(begin / sample_rate, 4),
                    "duration": round(duration, 4),
                }
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def main():
    """Build train/valid jsonl manifests by pairing noise and speech windows.

    Each output row describes one (noise window, speech window, snr) mix;
    rows are randomly routed to the valid file with probability ~1/300,
    otherwise to the train file.
    """
    args = get_args()

    file_dir = Path(args.file_dir)
    file_dir.mkdir(exist_ok=True)

    noise_dir = Path(args.noise_dir)
    speech_dir = Path(args.speech_dir)

    # noise is recycled many times (max_epoch=100000); speech is scanned once,
    # so the pairing ends when the speech generator is exhausted.
    noise_generator = target_second_signal_generator(
        noise_dir.as_posix(),
        duration=args.duration,
        sample_rate=args.target_sample_rate,
        max_epoch=100000,
    )
    speech_generator = target_second_signal_generator(
        speech_dir.as_posix(),
        duration=args.duration,
        sample_rate=args.target_sample_rate,
        max_epoch=1,
    )

    # NOTE(review): `dataset` is never appended to or read — looks vestigial.
    dataset = list()

    count = 0
    process_bar = tqdm(desc="build dataset jsonl")
    with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
        for noise, speech in zip(noise_generator, speech_generator):
            # stop once max_count rows were written (only when max_count > 0)
            if count >= args.max_count > 0:
                break

            noise_filename = noise["filename"]
            noise_raw_duration = noise["raw_duration"]
            noise_offset = noise["offset"]
            noise_duration = noise["duration"]

            speech_filename = speech["filename"]
            speech_raw_duration = speech["raw_duration"]
            speech_offset = speech["offset"]
            speech_duration = speech["duration"]

            # random1 is stored with the row; random2 only decides the split
            random1 = random.random()
            random2 = random.random()

            row = {
                "count": count,

                "noise_filename": noise_filename,
                "noise_raw_duration": noise_raw_duration,
                "noise_offset": noise_offset,
                "noise_duration": noise_duration,

                "speech_filename": speech_filename,
                "speech_raw_duration": speech_raw_duration,
                "speech_offset": speech_offset,
                "speech_duration": speech_duration,

                # SNR (dB) at which this pair should be mixed downstream
                "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),

                "random1": random1,
            }
            row = json.dumps(row, ensure_ascii=False)
            # ~1/300 of rows go to the validation set
            if random2 < (1 / 300 / 1):
                fvalid.write(f"{row}\n")
            else:
                ftrain.write(f"{row}\n")

            count += 1
            # report total emitted audio duration in hours
            duration_seconds = count * args.duration
            duration_hours = duration_seconds / 3600

            process_bar.update(n=1)
            process_bar.set_postfix({
                # "duration_seconds": round(duration_seconds, 4),
                "duration_hours": round(duration_hours, 4),

            })

    return
| 161 |
+
|
| 162 |
+
|
| 163 |
+
# Script entry point: only run when executed directly, not on import.
if __name__ == "__main__":
    main()
|
examples/frcrn/step_2_train_model.py
ADDED
|
@@ -0,0 +1,457 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
https://arxiv.org/abs/2206.07293
|
| 5 |
+
|
| 6 |
+
FRCRN 论文中:
|
| 7 |
+
在 WSJ0 数据集上训练了 120 个 epoch 得到 pesq 3.62, stoi 98.24, si-snr 21.33
|
| 8 |
+
|
| 9 |
+
WSJ0 包含约 80小时的纯净英语语音录音.
|
| 10 |
+
|
| 11 |
+
我的音频大约是 1300 小时, 则预期大约需要 10个 epoch
|
| 12 |
+
"""
|
| 13 |
+
import argparse
|
| 14 |
+
import json
|
| 15 |
+
import logging
|
| 16 |
+
from logging.handlers import TimedRotatingFileHandler
|
| 17 |
+
import os
|
| 18 |
+
import platform
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
import random
|
| 21 |
+
import sys
|
| 22 |
+
import shutil
|
| 23 |
+
from typing import List
|
| 24 |
+
|
| 25 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
| 26 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
| 27 |
+
|
| 28 |
+
import numpy as np
|
| 29 |
+
import torch
|
| 30 |
+
import torch.nn as nn
|
| 31 |
+
from torch.nn import functional as F
|
| 32 |
+
from torch.utils.data.dataloader import DataLoader
|
| 33 |
+
from tqdm import tqdm
|
| 34 |
+
|
| 35 |
+
from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
|
| 36 |
+
from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
|
| 37 |
+
from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
|
| 38 |
+
from toolbox.torchaudio.metrics.pesq import run_pesq_score
|
| 39 |
+
from toolbox.torchaudio.models.frcrn.configuration_frcrn import FRCRNConfig
|
| 40 |
+
from toolbox.torchaudio.models.frcrn.modeling_frcrn import FRCRN, FRCRNPretrainedModel
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def get_args():
    """Parse command-line options for FRCRN training."""
    arg_parser = argparse.ArgumentParser()

    # dataset manifests produced by step_1_prepare_data.py
    arg_parser.add_argument("--train_dataset", default="train.jsonl", type=str)
    arg_parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)

    # checkpoint retention, early-stop patience, and output directory
    arg_parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
    arg_parser.add_argument("--patience", default=30, type=int)
    arg_parser.add_argument("--serialization_dir", default="serialization_dir", type=str)

    arg_parser.add_argument("--config_file", default="config.yaml", type=str)

    return arg_parser.parse_args()
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def logging_config(file_dir: str):
    """Set up console logging and a daily-rotating file log under *file_dir*.

    Returns a module logger that also writes to ``<file_dir>/main.log``
    (rotated daily, 7 backups kept).
    """
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    logging.basicConfig(
        format=fmt,
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    rotating_handler = TimedRotatingFileHandler(
        filename=os.path.join(file_dir, "main.log"),
        encoding="utf-8",
        when="D",
        interval=1,
        backupCount=7,
    )
    rotating_handler.setLevel(logging.INFO)
    rotating_handler.setFormatter(logging.Formatter(fmt))

    logger = logging.getLogger(__name__)
    logger.addHandler(rotating_handler)
    return logger
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class CollateFunction(object):
    """Batch collator: stacks per-sample clean/noisy waveforms into 2-D tensors."""

    def __init__(self):
        pass

    def __call__(self, batch: List[dict]):
        """Return ``(clean_audios, noisy_audios)`` stacked along a new batch dim.

        Raises AssertionError if either stacked tensor contains NaN or Inf.
        """
        clean_list = [sample["speech_wave"] for sample in batch]
        noisy_list = [sample["mix_wave"] for sample in batch]

        clean_batch = torch.stack(clean_list)
        noisy_batch = torch.stack(noisy_list)

        # refuse to emit corrupted batches
        if torch.any(torch.isnan(clean_batch)) or torch.any(torch.isinf(clean_batch)):
            raise AssertionError("nan or inf in clean_audios")
        if torch.any(torch.isnan(noisy_batch)) or torch.any(torch.isinf(noisy_batch)):
            raise AssertionError("nan or inf in noisy_audios")
        return clean_batch, noisy_batch
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# Module-level collate instance shared by the train and valid DataLoaders.
collate_fn = CollateFunction()
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def main():
    """Train FRCRN for speech denoising with periodic PESQ-based evaluation.

    Resumes from the newest ``steps-*`` checkpoint when one exists, evaluates
    every ``config.eval_steps`` optimizer steps, keeps the most recent
    ``args.num_serialized_models_to_keep`` checkpoints, mirrors the best one
    (highest eval PESQ) into ``<serialization_dir>/best``, and early-stops
    after ``args.patience`` evaluations without improvement.

    Fixes vs. the previous revision:
    - the evaluation loop never accumulated ``total_mr_stft_loss``, so the
      reported eval ``mr_stft_loss`` was always a stale 0;
    - checkpoint retention used ``>=`` and thus kept one checkpoint fewer
      than ``num_serialized_models_to_keep``;
    - removed a redundant second ``model.to(device)``.
    """
    args = get_args()

    config = FRCRNConfig.from_pretrained(
        pretrained_model_name_or_path=args.config_file,
    )

    serialization_dir = Path(args.serialization_dir)
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config(serialization_dir)

    # make runs reproducible
    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    logger.info(f"set seed: {config.seed}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info(f"GPU available count: {n_gpu}; device: {device}")

    # datasets
    train_dataset = DenoiseJsonlDataset(
        jsonl_file=args.train_dataset,
        expected_sample_rate=config.sample_rate,
        max_wave_value=32768.0,
        min_snr_db=config.min_snr_db,
        max_snr_db=config.max_snr_db,
        # skip=225000,
    )
    valid_dataset = DenoiseJsonlDataset(
        jsonl_file=args.valid_dataset,
        expected_sample_rate=config.sample_rate,
        max_wave_value=32768.0,
        min_snr_db=config.min_snr_db,
        max_snr_db=config.max_snr_db,
    )
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=config.batch_size,
        # shuffle=True,
        sampler=None,
        # Windows cannot fork dataloader workers; load in the main process there.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        prefetch_factor=2,
    )
    valid_data_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=config.batch_size,
        # shuffle=True,
        sampler=None,
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        prefetch_factor=2,
    )

    # models
    logger.info(f"prepare models. config_file: {args.config_file}")
    model = FRCRNPretrainedModel(config).to(device)
    model.train()

    # optimizer
    logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric")
    optimizer = torch.optim.AdamW(model.get_params(weight_decay=config.weight_decay), config.lr)

    # resume training: pick the newest "steps-*" checkpoint directory, if any
    last_step_idx = -1
    last_epoch = -1
    for step_dir in serialization_dir.glob("steps-*"):
        step_idx = int(Path(step_dir).stem.split("-")[1])
        if step_idx > last_step_idx:
            last_step_idx = step_idx
            # last_epoch = 0

    if last_step_idx != -1:
        logger.info(f"resume from steps-{last_step_idx}.")
        model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
        # optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth"

        logger.info(f"load state dict for model.")
        with open(model_pt.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)

        # logger.info(f"load state dict for optimizer.")
        # with open(optimizer_pth.as_posix(), "rb") as f:
        #     state_dict = torch.load(f, map_location="cpu", weights_only=True)
        # optimizer.load_state_dict(state_dict)

    if config.lr_scheduler == "CosineAnnealingLR":
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            last_epoch=last_epoch,
            # T_max=10 * config.eval_steps,
            # eta_min=0.01 * config.lr,
            **config.lr_scheduler_kwargs,
        )
    elif config.lr_scheduler == "MultiStepLR":
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            last_epoch=last_epoch,
            milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
        )
    else:
        raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")

    neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
    mr_stft_loss_fn = MultiResolutionSTFTLoss(
        fft_size_list=[256, 512, 1024],
        win_size_list=[256, 512, 1024],
        hop_size_list=[128, 256, 512],
        factor_sc=1.5,
        factor_mag=1.0,
        reduction="mean"
    ).to(device)

    # training loop

    # running averages; sentinel values are overwritten on the first batch
    average_pesq_score = 1000000000
    average_loss = 1000000000
    average_neg_si_snr_loss = 1000000000
    average_mask_loss = 1000000000

    model_list = list()
    best_epoch_idx = None
    best_step_idx = None
    best_metric = None
    patience_count = 0

    step_idx = 0 if last_step_idx == -1 else last_step_idx

    logger.info("training")
    early_stop_flag = False
    for epoch_idx in range(max(0, last_epoch + 1), config.max_epochs):
        if early_stop_flag:
            break

        # train
        model.train()

        total_pesq_score = 0.
        total_loss = 0.
        total_mr_stft_loss = 0.
        total_neg_si_snr_loss = 0.
        total_mask_loss = 0.
        total_batches = 0.

        progress_bar_train = tqdm(
            initial=step_idx,
            desc="Training; epoch-{}".format(epoch_idx),
        )
        for train_batch in train_data_loader:
            clean_audios, noisy_audios = train_batch
            clean_audios: torch.Tensor = clean_audios.to(device)
            noisy_audios: torch.Tensor = noisy_audios.to(device)

            est_spec, est_wav, est_mask = model.forward(noisy_audios)
            denoise_audios = est_wav

            mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
            neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
            mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)

            loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss
            # skip the optimizer step entirely when the loss is not finite
            if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                logger.info(f"find nan or inf in loss.")
                continue

            denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
            clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
            pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
            optimizer.step()
            lr_scheduler.step()

            total_pesq_score += pesq_score
            total_loss += loss.item()
            total_mr_stft_loss += mr_stft_loss.item()
            total_neg_si_snr_loss += neg_si_snr_loss.item()
            total_mask_loss += mask_loss.item()
            total_batches += 1

            average_pesq_score = round(total_pesq_score / total_batches, 4)
            average_loss = round(total_loss / total_batches, 4)
            average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
            average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
            average_mask_loss = round(total_mask_loss / total_batches, 4)

            progress_bar_train.update(1)
            progress_bar_train.set_postfix({
                "lr": lr_scheduler.get_last_lr()[0],
                "pesq_score": average_pesq_score,
                "loss": average_loss,
                "mr_stft_loss": average_mr_stft_loss,
                "neg_si_snr_loss": average_neg_si_snr_loss,
                "mask_loss": average_mask_loss,
            })

            # evaluation
            step_idx += 1
            if step_idx % config.eval_steps == 0:
                model.eval()
                with torch.no_grad():
                    torch.cuda.empty_cache()

                    total_pesq_score = 0.
                    total_loss = 0.
                    total_mr_stft_loss = 0.
                    total_neg_si_snr_loss = 0.
                    total_mask_loss = 0.
                    total_batches = 0.

                    progress_bar_train.close()
                    progress_bar_eval = tqdm(
                        desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
                    )
                    for eval_batch in valid_data_loader:
                        clean_audios, noisy_audios = eval_batch
                        clean_audios = clean_audios.to(device)
                        noisy_audios = noisy_audios.to(device)

                        est_spec, est_wav, est_mask = model.forward(noisy_audios)
                        denoise_audios = est_wav

                        mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
                        neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
                        mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)

                        loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss
                        if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                            logger.info(f"find nan or inf in loss.")
                            continue

                        denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
                        clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
                        pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")

                        total_pesq_score += pesq_score
                        total_loss += loss.item()
                        # BUGFIX: this accumulator was missing, so the eval
                        # mr_stft_loss shown below was always a stale 0.
                        total_mr_stft_loss += mr_stft_loss.item()
                        total_neg_si_snr_loss += neg_si_snr_loss.item()
                        total_mask_loss += mask_loss.item()
                        total_batches += 1

                        average_pesq_score = round(total_pesq_score / total_batches, 4)
                        average_loss = round(total_loss / total_batches, 4)
                        average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
                        average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
                        average_mask_loss = round(total_mask_loss / total_batches, 4)

                        progress_bar_eval.update(1)
                        progress_bar_eval.set_postfix({
                            "lr": lr_scheduler.get_last_lr()[0],
                            "pesq_score": average_pesq_score,
                            "loss": average_loss,
                            "mr_stft_loss": average_mr_stft_loss,
                            "neg_si_snr_loss": average_neg_si_snr_loss,
                            "mask_loss": average_mask_loss,
                        })

                    # reset the running totals for the next training stretch
                    total_pesq_score = 0.
                    total_loss = 0.
                    total_mr_stft_loss = 0.
                    total_neg_si_snr_loss = 0.
                    total_mask_loss = 0.
                    total_batches = 0.

                    progress_bar_eval.close()
                    progress_bar_train = tqdm(
                        initial=progress_bar_train.n,
                        postfix=progress_bar_train.postfix,
                        desc=progress_bar_train.desc,
                    )

                # save path
                save_dir = serialization_dir / "steps-{}".format(step_idx)
                save_dir.mkdir(parents=True, exist_ok=False)

                # save models
                model.save_pretrained(save_dir.as_posix())

                # keep only the newest num_serialized_models_to_keep checkpoints
                # (was `>=`, which retained one checkpoint fewer than requested)
                model_list.append(save_dir)
                if len(model_list) > args.num_serialized_models_to_keep:
                    model_to_delete: Path = model_list.pop(0)
                    shutil.rmtree(model_to_delete.as_posix())

                # save metric: eval PESQ, greater is better
                if best_metric is None or average_pesq_score >= best_metric:
                    best_epoch_idx = epoch_idx
                    best_step_idx = step_idx
                    best_metric = average_pesq_score

                metrics = {
                    "epoch_idx": epoch_idx,
                    "best_epoch_idx": best_epoch_idx,
                    "best_step_idx": best_step_idx,
                    "pesq_score": average_pesq_score,
                    "loss": average_loss,
                    "neg_si_snr_loss": average_neg_si_snr_loss,
                    "mask_loss": average_mask_loss,
                }
                metrics_filename = save_dir / "metrics_epoch.json"
                with open(metrics_filename, "w", encoding="utf-8") as f:
                    json.dump(metrics, f, indent=4, ensure_ascii=False)

                # save best
                best_dir = serialization_dir / "best"
                if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
                    if best_dir.exists():
                        shutil.rmtree(best_dir)
                    shutil.copytree(save_dir, best_dir)

                # early stop
                early_stop_flag = False
                if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
                    patience_count = 0
                else:
                    patience_count += 1
                    if patience_count >= args.patience:
                        early_stop_flag = True

                # early stop
                if early_stop_flag:
                    break
                model.train()

    return
| 454 |
+
|
| 455 |
+
|
| 456 |
+
# Script entry point: only run when executed directly, not on import.
if __name__ == "__main__":
    main()
|
examples/frcrn/yaml/config-10.yaml
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model_name: "frcrn"
|
| 2 |
+
|
| 3 |
+
sample_rate: 8000
|
| 4 |
+
segment_size: 32000
|
| 5 |
+
nfft: 128
|
| 6 |
+
win_size: 128
|
| 7 |
+
hop_size: 64
|
| 8 |
+
win_type: hann
|
| 9 |
+
|
| 10 |
+
use_complex_networks: true
|
| 11 |
+
model_depth: 10
|
| 12 |
+
model_complexity: -1
|
| 13 |
+
|
| 14 |
+
min_snr_db: -10
|
| 15 |
+
max_snr_db: 20
|
| 16 |
+
|
| 17 |
+
num_workers: 8
|
| 18 |
+
batch_size: 32
|
| 19 |
+
eval_steps: 20000
|
| 20 |
+
|
| 21 |
+
lr: 0.001
|
| 22 |
+
lr_scheduler: "CosineAnnealingLR"
|
| 23 |
+
lr_scheduler_kwargs:
|
| 24 |
+
T_max: 250000
|
| 25 |
+
eta_min: 0.0001
|
| 26 |
+
|
| 27 |
+
max_epochs: 100
|
| 28 |
+
weight_decay: 1.0e-05
|
| 29 |
+
clip_grad_norm: 10.0
|
| 30 |
+
seed: 1234
|
| 31 |
+
num_gpus: -1
|
examples/frcrn/yaml/config-14.yaml
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model_name: "frcrn"
|
| 2 |
+
|
| 3 |
+
sample_rate: 8000
|
| 4 |
+
segment_size: 32000
|
| 5 |
+
nfft: 640
|
| 6 |
+
win_size: 640
|
| 7 |
+
hop_size: 320
|
| 8 |
+
win_type: hann
|
| 9 |
+
|
| 10 |
+
use_complex_networks: true
|
| 11 |
+
model_depth: 14
|
| 12 |
+
model_complexity: -1
|
| 13 |
+
|
| 14 |
+
min_snr_db: -10
|
| 15 |
+
max_snr_db: 20
|
| 16 |
+
|
| 17 |
+
num_workers: 8
|
| 18 |
+
batch_size: 32
|
| 19 |
+
eval_steps: 10000
|
| 20 |
+
|
| 21 |
+
lr: 0.001
|
| 22 |
+
lr_scheduler: "CosineAnnealingLR"
|
| 23 |
+
lr_scheduler_kwargs:
|
| 24 |
+
T_max: 250000
|
| 25 |
+
eta_min: 0.0001
|
| 26 |
+
|
| 27 |
+
max_epochs: 100
|
| 28 |
+
weight_decay: 1.0e-05
|
| 29 |
+
clip_grad_norm: 10.0
|
| 30 |
+
seed: 1234
|
| 31 |
+
num_gpus: -1
|
examples/frcrn/yaml/config-20.yaml
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model_name: "frcrn"
|
| 2 |
+
|
| 3 |
+
sample_rate: 8000
|
| 4 |
+
segment_size: 32000
|
| 5 |
+
nfft: 512
|
| 6 |
+
win_size: 512
|
| 7 |
+
hop_size: 256
|
| 8 |
+
win_type: hann
|
| 9 |
+
|
| 10 |
+
use_complex_networks: true
|
| 11 |
+
model_depth: 20
|
| 12 |
+
model_complexity: 45
|
| 13 |
+
|
| 14 |
+
min_snr_db: -10
|
| 15 |
+
max_snr_db: 20
|
| 16 |
+
|
| 17 |
+
num_workers: 8
|
| 18 |
+
batch_size: 32
|
| 19 |
+
eval_steps: 10000
|
| 20 |
+
|
| 21 |
+
lr: 0.001
|
| 22 |
+
lr_scheduler: "CosineAnnealingLR"
|
| 23 |
+
lr_scheduler_kwargs:
|
| 24 |
+
T_max: 250000
|
| 25 |
+
eta_min: 0.0001
|
| 26 |
+
|
| 27 |
+
max_epochs: 100
|
| 28 |
+
weight_decay: 1.0e-05
|
| 29 |
+
clip_grad_norm: 10.0
|
| 30 |
+
seed: 1234
|
| 31 |
+
num_gpus: -1
|
examples/frcrn_mp3_to_wav/run.sh
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash

# Usage examples (heredoc is documentation only; never executed).
: <<'END'


sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name frcrn-20-512-nx-dns3 \
--config_file "yaml/config-10.yaml" \
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech"


sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name frcrn-10-nx2 \
--config_file "yaml/config-10.yaml" \
--audio_dir "/data/tianxing/HuggingDatasets/nx_noise/data" \

END


# params
system_version="windows";
verbose=true;
stage=0 # start from 0 if you need to start from data preparation
stop_stage=9

work_dir="$(pwd)"
file_folder_name=file_folder_name
final_model_name=final_model_name
config_file="yaml/config.yaml"
limit=10

audio_dir=/data/tianxing/HuggingDatasets/nx_noise/data

max_count=10000000

nohup_name=nohup.out

# model params
batch_size=64
max_epochs=200
save_top_k=10
patience=5


# parse options
while true; do
  [ -z "${1:-}" ] && break;  # break if there are no arguments
  case "$1" in
    --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
      # Reject options that do not correspond to a variable declared above.
      eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
      # BUGFIX: was `old_value="(eval echo \\$$name)"` — missing the `$` of the
      # command substitution, so old_value got the literal string
      # "(eval echo \$name)" and was_bool was always false; boolean-valued
      # options were never validated below.
      old_value="$(eval echo \\$$name)";
      if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
        was_bool=true;
      else
        was_bool=false;
      fi

      # Set the variable to the right value-- the escaped quotes make it work if
      # the option had spaces, like --cmd "queue.pl -sync y"
      eval "${name}=\"$2\"";

      # Check that Boolean-valued arguments are really Boolean.
      if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
        echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
        exit 1;
      fi
      shift 2;
      ;;

    *) break;
  esac
done

file_dir="${work_dir}/${file_folder_name}"
final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
evaluation_audio_dir="${file_dir}/evaluation_audio"

train_dataset="${file_dir}/train.jsonl"
valid_dataset="${file_dir}/valid.jsonl"

$verbose && echo "system_version: ${system_version}"
$verbose && echo "file_folder_name: ${file_folder_name}"

if [ $system_version == "windows" ]; then
  alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
  #source /data/local/bin/nx_denoise/bin/activate
  alias python3='/data/local/bin/nx_denoise/bin/python3'
fi


if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  $verbose && echo "stage 1: prepare data"
  cd "${work_dir}" || exit 1
  python3 step_1_prepare_data.py \
    --file_dir "${file_dir}" \
    --audio_dir "${audio_dir}" \
    --train_dataset "${train_dataset}" \
    --valid_dataset "${valid_dataset}" \
    --max_count "${max_count}" \

fi


if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
  $verbose && echo "stage 2: train model"
  cd "${work_dir}" || exit 1
  python3 step_2_train_model.py \
    --train_dataset "${train_dataset}" \
    --valid_dataset "${valid_dataset}" \
    --serialization_dir "${file_dir}" \
    --config_file "${config_file}" \

fi


if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
  $verbose && echo "stage 3: test model"
  cd "${work_dir}" || exit 1
  python3 step_3_evaluation.py \
    --valid_dataset "${valid_dataset}" \
    --model_dir "${file_dir}/best" \
    --evaluation_audio_dir "${evaluation_audio_dir}" \
    --limit "${limit}" \

fi


if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
  $verbose && echo "stage 4: collect files"
  cd "${work_dir}" || exit 1

  # quoted: the destination path is derived from user-provided option values
  mkdir -p "${final_model_dir}"

  cp "${file_dir}/best"/* "${final_model_dir}"
  cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"

  cd "${final_model_dir}/.." || exit 1;

  # keep one backup of a previously packaged model
  if [ -e "${final_model_name}.zip" ]; then
    rm -rf "${final_model_name}_backup.zip"
    mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
  fi

  zip -r "${final_model_name}.zip" "${final_model_name}"
  rm -rf "${final_model_name}"

fi


if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
  $verbose && echo "stage 5: clear file_dir"
  cd "${work_dir}" || exit 1

  rm -rf "${file_dir}";

fi
|
examples/frcrn_mp3_to_wav/step_1_prepare_data.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
import random
|
| 8 |
+
import sys
|
| 9 |
+
|
| 10 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
| 11 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
| 12 |
+
|
| 13 |
+
import librosa
|
| 14 |
+
import numpy as np
|
| 15 |
+
from tqdm import tqdm
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_args():
    """Build and parse the command-line arguments for dataset preparation."""
    parser = argparse.ArgumentParser()

    # where generated artifacts are placed
    parser.add_argument("--file_dir", default="./", type=str)

    # root directory scanned recursively for source .wav files
    parser.add_argument(
        "--audio_dir",
        default="E:/Users/tianx/HuggingDatasets/nx_noise/data/speech",
        type=str,
    )

    # output jsonl manifests
    parser.add_argument("--train_dataset", default="train.jsonl", type=str)
    parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)

    # segment length in seconds and resample target
    parser.add_argument("--duration", default=4.0, type=float)
    parser.add_argument("--target_sample_rate", default=8000, type=int)

    # stop after this many segments; values <= 0 mean unlimited
    parser.add_argument("--max_count", default=-1, type=int)

    return parser.parse_args()
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def target_second_signal_generator(data_dir: str, duration: float = 2, sample_rate: int = 8000, max_epoch: int = 1):
    """Yield fixed-length segment descriptors for every ``.wav`` under *data_dir*.

    Each yielded row describes one non-silent window of *duration* seconds:
    ``{"epoch_idx", "filename", "raw_duration", "offset", "duration"}``
    (offsets/durations in seconds, rounded to 4 decimals).  Files shorter
    than *duration* are skipped.  The directory is re-scanned *max_epoch*
    times; ``epoch_idx`` marks the pass.
    """
    data_dir = Path(data_dir)
    for epoch_idx in range(max_epoch):
        # recursive scan; only .wav files are considered
        for filename in data_dir.glob("**/*.wav"):
            # librosa loads (and resamples to sample_rate) the whole file
            signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
            raw_duration = librosa.get_duration(y=signal, sr=sample_rate)

            if raw_duration < duration:
                # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
                continue
            if signal.ndim != 1:
                raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")

            signal_length = len(signal)
            win_size = int(duration * sample_rate)
            # non-overlapping sliding window; the trailing partial window is dropped
            for begin in range(0, signal_length - win_size, win_size):
                # skip windows whose sample sum is exactly zero (digital silence).
                # NOTE(review): a zero *sum* can also occur for symmetric non-zero
                # samples — presumably acceptable here; confirm if that matters.
                if np.sum(signal[begin: begin+win_size]) == 0:
                    continue
                row = {
                    "epoch_idx": epoch_idx,
                    "filename": filename.as_posix(),
                    "raw_duration": round(raw_duration, 4),
                    "offset": round(begin / sample_rate, 4),
                    "duration": round(duration, 4),
                }
                yield row
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def main():
    """Scan the audio directory and write train/valid segment manifests (jsonl)."""
    args = get_args()

    # ensure the artifact directory exists (non-recursive: parent must exist)
    file_dir = Path(args.file_dir)
    file_dir.mkdir(exist_ok=True)

    audio_dir = Path(args.audio_dir)

    # lazily yields one row per non-silent fixed-length segment
    audio_generator = target_second_signal_generator(
        audio_dir.as_posix(),
        duration=args.duration,
        sample_rate=args.target_sample_rate,
        max_epoch=1,
    )
    count = 0
    process_bar = tqdm(desc="build dataset jsonl")
    with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
        for audio in audio_generator:
            # chained comparison: stop only when max_count is positive AND reached
            if count >= args.max_count > 0:
                break

            filename = audio["filename"]
            raw_duration = audio["raw_duration"]
            offset = audio["offset"]
            duration = audio["duration"]

            # random2 drives the train/valid split below; random1 is stored on
            # the row itself.  NOTE(review): random1's consumer is not visible in
            # this file — presumably used by the dataset loader; confirm there.
            random1 = random.random()
            random2 = random.random()

            row = {
                "count": count,

                "filename": filename,
                "raw_duration": raw_duration,
                "offset": offset,
                "duration": duration,

                "random1": random1,
            }
            row = json.dumps(row, ensure_ascii=False)
            # ~10% of segments go to the validation split
            if random2 < (1 / 10):
                fvalid.write(f"{row}\n")
            else:
                ftrain.write(f"{row}\n")

            count += 1
            # running total of audio collected so far, shown on the progress bar
            duration_seconds = count * args.duration
            duration_hours = duration_seconds / 3600

            process_bar.update(n=1)
            process_bar.set_postfix({
                "duration_hours": round(duration_hours, 4),
            })

    return
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
if __name__ == "__main__":
|
| 127 |
+
main()
|
examples/frcrn_mp3_to_wav/step_2_train_model.py
ADDED
|
@@ -0,0 +1,442 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
import logging
|
| 6 |
+
from logging.handlers import TimedRotatingFileHandler
|
| 7 |
+
import os
|
| 8 |
+
import platform
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
import random
|
| 11 |
+
import sys
|
| 12 |
+
import shutil
|
| 13 |
+
from typing import List
|
| 14 |
+
|
| 15 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
| 16 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
from torch.nn import functional as F
|
| 22 |
+
from torch.utils.data.dataloader import DataLoader
|
| 23 |
+
from tqdm import tqdm
|
| 24 |
+
|
| 25 |
+
from toolbox.torch.utils.data.dataset.mp3_to_wav_jsonl_dataset import Mp3ToWavJsonlDataset
|
| 26 |
+
from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
|
| 27 |
+
from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
|
| 28 |
+
from toolbox.torchaudio.metrics.pesq import run_pesq_score
|
| 29 |
+
from toolbox.torchaudio.models.frcrn.configuration_frcrn import FRCRNConfig
|
| 30 |
+
from toolbox.torchaudio.models.frcrn.modeling_frcrn import FRCRN, FRCRNPretrainedModel
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def get_args():
    """Build and parse the command-line arguments for model training."""
    parser = argparse.ArgumentParser()

    # jsonl manifests produced by step_1
    parser.add_argument("--train_dataset", default="train.jsonl", type=str)
    parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)

    # checkpoint retention and early-stopping patience
    parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
    parser.add_argument("--patience", default=30, type=int)
    parser.add_argument("--serialization_dir", default="serialization_dir", type=str)

    # model/training hyper-parameter yaml
    parser.add_argument("--config_file", default="config.yaml", type=str)

    return parser.parse_args()
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def logging_config(file_dir: str):
    """Configure console logging and a daily-rotating file log under *file_dir*.

    Returns the module logger with the rotating file handler attached.
    """
    log_format = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    logging.basicConfig(
        format=log_format,
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    # rotate once per day, keep one week of backups
    rotating_handler = TimedRotatingFileHandler(
        filename=os.path.join(file_dir, "main.log"),
        encoding="utf-8",
        when="D",
        interval=1,
        backupCount=7,
    )
    rotating_handler.setLevel(logging.INFO)
    rotating_handler.setFormatter(logging.Formatter(log_format))

    logger = logging.getLogger(__name__)
    logger.addHandler(rotating_handler)

    return logger
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class CollateFunction(object):
    """Batch collator: stacks per-sample mp3/wav waveform tensors into two batches."""

    def __init__(self):
        pass

    def __call__(self, batch: List[dict]):
        """Stack ``mp3_waveform``/``wav_waveform`` from each sample dict.

        Returns ``(mp3_batch, wav_batch)``; raises AssertionError if either
        stacked tensor contains NaN or Inf values.
        """
        mp3_batch = torch.stack([sample["mp3_waveform"] for sample in batch])
        wav_batch = torch.stack([sample["wav_waveform"] for sample in batch])

        # fail fast on corrupt audio rather than propagating NaN/Inf into the loss
        for tag, tensor in (("mp3_waveform_list", mp3_batch), ("wav_waveform_list", wav_batch)):
            if torch.any(torch.isnan(tensor)) or torch.any(torch.isinf(tensor)):
                raise AssertionError(f"nan or inf in {tag}")

        return mp3_batch, wav_batch
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
collate_fn = CollateFunction()
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def main():
    """Train the FRCRN model to restore wav-quality audio from mp3 input.

    Loads the config, builds train/valid datasets and loaders, resumes from
    the newest ``steps-*`` checkpoint if one exists, then runs the training
    loop with periodic evaluation, checkpointing, best-model tracking and
    early stopping.

    Fixes vs. the previous revision:
    - the evaluation loop now accumulates ``total_mr_stft_loss`` (it was
      never added, so the reported eval ``mr_stft_loss`` was always 0.0);
    - removed a redundant duplicate ``model.to(device)`` call;
    - checkpoint retention uses ``>`` so exactly
      ``num_serialized_models_to_keep`` checkpoints are kept (``>=`` kept
      one fewer than requested).
    """
    args = get_args()

    config = FRCRNConfig.from_pretrained(
        pretrained_model_name_or_path=args.config_file,
    )

    serialization_dir = Path(args.serialization_dir)
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config(serialization_dir)

    # reproducibility
    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    logger.info(f"set seed: {config.seed}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info(f"GPU available count: {n_gpu}; device: {device}")

    # datasets
    train_dataset = Mp3ToWavJsonlDataset(
        jsonl_file=args.train_dataset,
        expected_sample_rate=config.sample_rate,
        max_wave_value=32768.0,
        # skip=225000,
    )
    valid_dataset = Mp3ToWavJsonlDataset(
        jsonl_file=args.valid_dataset,
        expected_sample_rate=config.sample_rate,
        max_wave_value=32768.0,
    )
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=config.batch_size,
        # shuffle=True,
        sampler=None,
        # multiple data-loading subprocesses work on Linux but not on Windows.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        prefetch_factor=2,
    )
    valid_data_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=config.batch_size,
        # shuffle=True,
        sampler=None,
        # multiple data-loading subprocesses work on Linux but not on Windows.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        prefetch_factor=2,
    )

    # models
    logger.info(f"prepare models. config_file: {args.config_file}")
    model = FRCRNPretrainedModel(config).to(device)
    model.train()

    # optimizer
    logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric")
    optimizer = torch.optim.AdamW(model.get_params(weight_decay=config.weight_decay), config.lr)

    # resume training: pick the checkpoint with the highest step index.
    last_step_idx = -1
    last_epoch = -1
    for step_idx_str in serialization_dir.glob("steps-*"):
        step_idx_str = Path(step_idx_str)
        step_idx = step_idx_str.stem.split("-")[1]
        step_idx = int(step_idx)
        if step_idx > last_step_idx:
            last_step_idx = step_idx
            # last_epoch = 0

    if last_step_idx != -1:
        logger.info(f"resume from steps-{last_step_idx}.")
        model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
        # optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth"

        logger.info("load state dict for model.")
        with open(model_pt.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)

        # logger.info(f"load state dict for optimizer.")
        # with open(optimizer_pth.as_posix(), "rb") as f:
        #     state_dict = torch.load(f, map_location="cpu", weights_only=True)
        # optimizer.load_state_dict(state_dict)

    if config.lr_scheduler == "CosineAnnealingLR":
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            last_epoch=last_epoch,
            # T_max=10 * config.eval_steps,
            # eta_min=0.01 * config.lr,
            **config.lr_scheduler_kwargs,
        )
    elif config.lr_scheduler == "MultiStepLR":
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            last_epoch=last_epoch,
            milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
        )
    else:
        raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")

    neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
    mr_stft_loss_fn = MultiResolutionSTFTLoss(
        fft_size_list=[256, 512, 1024],
        win_size_list=[256, 512, 1024],
        hop_size_list=[128, 256, 512],
        factor_sc=1.5,
        factor_mag=1.0,
        reduction="mean"
    ).to(device)

    # training loop

    # state: sentinel values, overwritten on the first processed batch
    average_pesq_score = 1000000000
    average_loss = 1000000000
    average_neg_si_snr_loss = 1000000000
    average_mask_loss = 1000000000

    model_list = list()
    best_epoch_idx = None
    best_step_idx = None
    best_metric = None
    patience_count = 0

    step_idx = 0 if last_step_idx == -1 else last_step_idx

    logger.info("training")
    early_stop_flag = False
    for epoch_idx in range(max(0, last_epoch + 1), config.max_epochs):
        if early_stop_flag:
            break

        # train
        model.train()

        total_pesq_score = 0.
        total_loss = 0.
        total_mr_stft_loss = 0.
        total_neg_si_snr_loss = 0.
        total_mask_loss = 0.
        total_batches = 0.

        progress_bar_train = tqdm(
            initial=step_idx,
            desc="Training; epoch-{}".format(epoch_idx),
        )
        for train_batch in train_data_loader:
            mp3_audios, wav_audios = train_batch
            noisy_audios: torch.Tensor = mp3_audios.to(device)
            clean_audios: torch.Tensor = wav_audios.to(device)

            est_spec, est_wav, est_mask = model.forward(noisy_audios)
            denoise_audios = est_wav

            mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
            neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
            mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)

            loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss
            # skip (do not backprop) batches whose loss diverged
            if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                logger.info("find nan or inf in loss.")
                continue

            denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
            clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
            pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
            optimizer.step()
            lr_scheduler.step()

            total_pesq_score += pesq_score
            total_loss += loss.item()
            total_mr_stft_loss += mr_stft_loss.item()
            total_neg_si_snr_loss += neg_si_snr_loss.item()
            total_mask_loss += mask_loss.item()
            total_batches += 1

            average_pesq_score = round(total_pesq_score / total_batches, 4)
            average_loss = round(total_loss / total_batches, 4)
            average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
            average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
            average_mask_loss = round(total_mask_loss / total_batches, 4)

            progress_bar_train.update(1)
            progress_bar_train.set_postfix({
                "lr": lr_scheduler.get_last_lr()[0],
                "pesq_score": average_pesq_score,
                "loss": average_loss,
                "mr_stft_loss": average_mr_stft_loss,
                "neg_si_snr_loss": average_neg_si_snr_loss,
                "mask_loss": average_mask_loss,
            })

            # evaluation
            step_idx += 1
            if step_idx % config.eval_steps == 0:
                model.eval()
                with torch.no_grad():
                    torch.cuda.empty_cache()

                    total_pesq_score = 0.
                    total_loss = 0.
                    total_mr_stft_loss = 0.
                    total_neg_si_snr_loss = 0.
                    total_mask_loss = 0.
                    total_batches = 0.

                    progress_bar_train.close()
                    progress_bar_eval = tqdm(
                        desc="Evaluation; steps-{}k".format(int(step_idx / 1000)),
                    )
                    for eval_batch in valid_data_loader:
                        mp3_audios, wav_audios = eval_batch
                        noisy_audios: torch.Tensor = mp3_audios.to(device)
                        clean_audios: torch.Tensor = wav_audios.to(device)

                        est_spec, est_wav, est_mask = model.forward(noisy_audios)
                        denoise_audios = est_wav

                        mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
                        neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
                        mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)

                        loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss
                        if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                            logger.info("find nan or inf in loss.")
                            continue

                        denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
                        clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
                        pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")

                        total_pesq_score += pesq_score
                        total_loss += loss.item()
                        # FIX: mr_stft_loss was never accumulated during
                        # evaluation, so the displayed value was always 0.
                        total_mr_stft_loss += mr_stft_loss.item()
                        total_neg_si_snr_loss += neg_si_snr_loss.item()
                        total_mask_loss += mask_loss.item()
                        total_batches += 1

                        average_pesq_score = round(total_pesq_score / total_batches, 4)
                        average_loss = round(total_loss / total_batches, 4)
                        average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
                        average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
                        average_mask_loss = round(total_mask_loss / total_batches, 4)

                        progress_bar_eval.update(1)
                        progress_bar_eval.set_postfix({
                            "lr": lr_scheduler.get_last_lr()[0],
                            "pesq_score": average_pesq_score,
                            "loss": average_loss,
                            "mr_stft_loss": average_mr_stft_loss,
                            "neg_si_snr_loss": average_neg_si_snr_loss,
                            "mask_loss": average_mask_loss,
                        })

                    # reset running totals for the next training span
                    total_pesq_score = 0.
                    total_loss = 0.
                    total_mr_stft_loss = 0.
                    total_neg_si_snr_loss = 0.
                    total_mask_loss = 0.
                    total_batches = 0.

                    progress_bar_eval.close()
                    progress_bar_train = tqdm(
                        initial=progress_bar_train.n,
                        postfix=progress_bar_train.postfix,
                        desc=progress_bar_train.desc,
                    )

                    # save path
                    save_dir = serialization_dir / "steps-{}".format(step_idx)
                    save_dir.mkdir(parents=True, exist_ok=False)

                    # save models
                    model.save_pretrained(save_dir.as_posix())

                    model_list.append(save_dir)
                    # FIX: ">" keeps exactly num_serialized_models_to_keep
                    # checkpoints; ">=" kept one fewer than requested.
                    if len(model_list) > args.num_serialized_models_to_keep:
                        model_to_delete: Path = model_list.pop(0)
                        shutil.rmtree(model_to_delete.as_posix())

                    # save metric (greater PESQ is better)
                    if best_metric is None:
                        best_epoch_idx = epoch_idx
                        best_step_idx = step_idx
                        best_metric = average_pesq_score
                    elif average_pesq_score >= best_metric:
                        # greater is better.
                        best_epoch_idx = epoch_idx
                        best_step_idx = step_idx
                        best_metric = average_pesq_score
                    else:
                        pass

                    metrics = {
                        "epoch_idx": epoch_idx,
                        "best_epoch_idx": best_epoch_idx,
                        "best_step_idx": best_step_idx,
                        "pesq_score": average_pesq_score,
                        "loss": average_loss,
                        "neg_si_snr_loss": average_neg_si_snr_loss,
                        "mask_loss": average_mask_loss,
                    }
                    metrics_filename = save_dir / "metrics_epoch.json"
                    with open(metrics_filename, "w", encoding="utf-8") as f:
                        json.dump(metrics, f, indent=4, ensure_ascii=False)

                    # save best
                    best_dir = serialization_dir / "best"
                    if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
                        if best_dir.exists():
                            shutil.rmtree(best_dir)
                        shutil.copytree(save_dir, best_dir)

                    # early stop bookkeeping
                    early_stop_flag = False
                    if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
                        patience_count = 0
                    else:
                        patience_count += 1
                        if patience_count >= args.patience:
                            early_stop_flag = True

                # early stop
                if early_stop_flag:
                    break
                model.train()

    return
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
# script entry point: run training when executed directly
if __name__ == "__main__":
    main()
|