wcz333 committed faa035a (verified; parent: 28b4f12): Upload 8 files
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+g_dns_best filter=lfs diff=lfs merge=lfs -text
+g_phase_retrieval_voicebank_best filter=lfs diff=lfs merge=lfs -text
+g_universal_best filter=lfs diff=lfs merge=lfs -text
+g_vbd_best filter=lfs diff=lfs merge=lfs -text
config.json ADDED
@@ -0,0 +1,36 @@
{
    "num_gpus": 0,
    "batch_size": 4,
    "learning_rate": 0.0005,
    "adam_b1": 0.8,
    "adam_b2": 0.99,
    "lr_decay": 0.99,
    "seed": 1234,

    "grad_clip_val": 5,

    "dense_channel": 64,
    "compress_factor": 0.3,
    "num_tsconformers": 4,
    "beta": 2.0,

    "sampling_rate": 16000,
    "segment_size": 32000,
    "n_fft": 400,
    "hop_size": 100,
    "win_size": 400,

    "amp_chn": 48,
    "ang_chn": 16,
    "n_heads": 4,
    "amp_attnhead_dim": 12,
    "ang_attnhead_dim": 6,

    "num_workers": 16,

    "dist_config": {
        "dist_backend": "nccl",
        "dist_url": "tcp://localhost:12345",
        "world_size": 1
    }
}
config_small.json ADDED
@@ -0,0 +1,36 @@
{
    "num_gpus": 0,
    "batch_size": 4,
    "learning_rate": 0.0005,
    "adam_b1": 0.8,
    "adam_b2": 0.99,
    "lr_decay": 0.99,
    "seed": 1234,

    "grad_clip_val": 5,

    "dense_channel": 64,
    "compress_factor": 0.3,
    "num_tsconformers": 4,
    "beta": 2.0,

    "sampling_rate": 16000,
    "segment_size": 32000,
    "n_fft": 400,
    "hop_size": 100,
    "win_size": 400,

    "amp_chn": 32,
    "ang_chn": 16,
    "n_heads": 4,
    "amp_attnhead_dim": 8,
    "ang_attnhead_dim": 6,

    "num_workers": 16,

    "dist_config": {
        "dist_backend": "nccl",
        "dist_url": "tcp://localhost:12345",
        "world_size": 1
    }
}
config_universal.json ADDED
@@ -0,0 +1,36 @@
{
    "num_gpus": 0,
    "batch_size": 4,
    "learning_rate": 0.0005,
    "adam_b1": 0.8,
    "adam_b2": 0.99,
    "lr_decay": 0.999,
    "seed": 1234,

    "grad_clip_val": 5,

    "dense_channel": 64,
    "compress_factor": 0.3,
    "num_tsconformers": 4,
    "beta": 2.0,

    "sampling_rate": 16000,
    "segment_size": 32000,
    "n_fft": 400,
    "hop_size": 100,
    "win_size": 400,

    "amp_chn": 48,
    "ang_chn": 16,
    "n_heads": 4,
    "amp_attnhead_dim": 12,
    "ang_attnhead_dim": 6,

    "num_workers": 16,

    "dist_config": {
        "dist_backend": "nccl",
        "dist_url": "tcp://localhost:12345",
        "world_size": 1
    }
}
g_dns_best ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d2e41e0fb3c23a1224211e610f8344076c7d5a6243ebccb09ee716cacc93c11b
size 6363371
g_phase_retrieval_voicebank_best ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:618f48587a27a3ba29d21ca0b6834bb9459fc25126f98015c5b621d0432f7572
size 3765758
g_universal_best ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:935d6b9f1212fa31bbb097b9221be8bb7fa9e3cc91698b8122731c5d77184cf3
size 6363774
g_vbd_best ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3bee17454c2c3751b624905a5d5686750c01a85a436dd35f6b0d30809d139b51
size 6363774
readme.md ADDED
@@ -0,0 +1,299 @@
# Global Rotation Equivariant Phase Modeling for Speech Enhancement with Deep Magnitude-Phase Interaction

This repository hosts the official implementation for the paper:

**Global Rotation Equivariant Phase Modeling for Speech Enhancement with Deep Magnitude-Phase Interaction** (submitted to IEEE TASLP).

Authors: Chengzhong Wang, Andong Li, Dingding Yao and Junfeng Li*

We propose a manifold-aware magnitude–phase dual-stream framework that enforces **Global Rotation Equivariance (GRE)** in the phase stream, enabling robust phase modeling with strong generalization across denoising, dereverberation, bandwidth extension, and mixed distortions.

Training logs, audio samples, and supplementary analysis: https://wangchengzhong.github.io/RENet-Supplementary-Materials/

---

## Implementation Summary
- **GRE as inductive bias:** explicit global rotation equivariance for phase modeling.
- **Deep Magnitude–Phase Interaction:** MPICM provides cross-stream gating without breaking equivariance.
- **Hybrid Attention Dual-FFN (HADF):** attention fusion in the score domain plus stream-specific FFNs.
- **Strong results with a compact model:** 1.55M parameters, with quality competitive with or better than advanced baselines across SE tasks.

---

## Method Overview

### Architecture
![Overview of RENet architecture](assets/fig_architecture.png)

The model uses a dual-stream encoder–decoder with a GRE-constrained complex phase branch and a real-valued magnitude branch. Key modules:

- **MPICM (Magnitude-Phase Interactive Convolutional Module):**
  - bias-free complex convolution for the phase stream
  - RMSNorm + SiLU for the magnitude stream
  - cross-stream modulus-based gating that preserves GRE

- **HADF (Hybrid-Attention Dual-FFN):**
  - hybrid attention with a shared score map
  - independent magnitude/phase value projections
  - GRU-based FFN for magnitude, complex-valued convolutional GLU for phase

### Global Rotation Equivariance
GRE ensures $\mathcal{F}(\mathbf{x}e^{j\theta}) = \mathcal{F}(\mathbf{x})e^{j\theta}$, preventing the phase stream from learning arbitrary absolute orientations while preserving relative phase structure (GD/IP).

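The GRE property is easy to check numerically: a bias-free complex-linear map followed by modulus-based gating commutes with a global rotation $\mathbf{x} \mapsto \mathbf{x}e^{j\theta}$, while a complex bias term breaks the property. A minimal NumPy sketch (illustrative only, not the repository's model code):

```python
import numpy as np

rng = np.random.default_rng(0)

# Bias-free complex-linear layer + modulus-based (phase-invariant) gate:
# both operations commute with a global rotation x -> x * e^{j*theta}.
W = rng.standard_normal((8, 8)) + 1j * rng.standard_normal((8, 8))
b = rng.standard_normal(8) + 1j * rng.standard_normal(8)

def f_equivariant(x):
    y = W @ x
    return y * np.tanh(np.abs(y))   # gate depends only on |y|, so GRE holds

def f_biased(x):
    return W @ x + b                # complex bias breaks GRE

x = rng.standard_normal(8) + 1j * rng.standard_normal(8)
rot = np.exp(1j * 0.7)              # global rotation by theta = 0.7 rad

print(np.allclose(f_equivariant(x * rot), f_equivariant(x) * rot))  # True
print(np.allclose(f_biased(x * rot), f_biased(x) * rot))            # False
```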
---

## 🔵 Experiments

We evaluate on three settings:
1. **Phase Retrieval** (clean magnitude, zero phase)
2. **Denoising** (VoiceBank+DEMAND, DNS-2020)
3. **Universal SE** (DNS-2021 training, WSJ0+WHAMR! test; DN/DR/BWE/mixed)

### Phase Retrieval (VoiceBank)

| Model | Params (M) | MACs (G/s) | PESQ | SI-SDR | WOPD $\downarrow$ | PD $\downarrow$ |
| --- | ---: | ---: | ---: | ---: | ---: | ---: |
| Griffin-Lim | - | - | 4.23 | -17.07 | 0.342 | 90.07 |
| DiffPhase | 65.6 | 3330 | 4.41 | -11.75 | 0.230 | 85.66 |
| MP-SENet Up.* | 1.99 | 38.80 | 4.60 | 14.64 | 0.058 | 11.38 |
| SEMamba* | 1.88 | 38.01 | 4.59 | 13.63 | 0.059 | 12.46 |
| **Proposed (Small)** | **0.90** | **22.89** | **4.61** | **16.03** | **0.044** | **8.47** |

\* Single phase decoder for phase retrieval.
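The phase-retrieval input keeps the clean magnitude and discards the phase. A self-contained NumPy illustration of how such an input can be constructed, using the paper's STFT settings (n_fft=400, hop=100 at 16 kHz); this is a sketch, not the repository's data pipeline:

```python
import numpy as np

def stft(x, n_fft=400, hop=100):
    """Framed rfft with a Hann window."""
    win = np.hanning(n_fft)
    n_frames = 1 + (len(x) - n_fft) // hop
    frames = np.stack([x[i * hop : i * hop + n_fft] * win for i in range(n_frames)])
    return np.fft.rfft(frames, axis=-1)

def istft(X, n_fft=400, hop=100):
    """Overlap-add inverse with window-squared normalization."""
    win = np.hanning(n_fft)
    frames = np.fft.irfft(X, n=n_fft, axis=-1)
    out = np.zeros((len(frames) - 1) * hop + n_fft)
    norm = np.zeros_like(out)
    for i, f in enumerate(frames):
        out[i * hop : i * hop + n_fft] += f * win
        norm[i * hop : i * hop + n_fft] += win ** 2
    return out / np.maximum(norm, 1e-8)

sr = 16000
t = np.arange(sr) / sr
clean = np.sin(2 * np.pi * 440 * t)

X = stft(clean)
degraded_spec = np.abs(X)        # clean magnitude, zero phase
degraded = istft(degraded_spec)  # heavily distorted waveform the model must repair
```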

### Denoising (VBD & DNS-2020)
![Denoising results](assets/fig_denoising.png)

\* SEMamba reported w/o PCS.

Key result: strong zero-shot transfer from VBD to DNS-2020, with consistent gains across PESQ, STOI, UTMOS, and PD, and SOTA results on the large-scale DNS-2020 set.

### Universal SE (DNS-2021 → WSJ0+WHAMR!)
![Universal SE results](assets/fig_universal_results.png)

Our model achieves top-tier performance across DN/DR/BWE and mixed distortions.

---

## Repository Structure
- Training:
  - train_denoising_dns.py
  - train_denoising_vbd.py
  - train_phase_retrieval.py
  - train_universal_dns.py
- Inference:
  - inference_denoising.py
  - inference_phase.py
  - inference_universal.py
- Core modules:
  - models/model.py
  - models/transformer.py
  - models/mpd_and_metricd.py
- Data:
  - dataset.py
  - dns_dataset.py
  - data_gen/
- Metrics:
  - cal_metrics_singledir.py
  - cal_metrics_hierarchicaldir.py

---

## Configurations
We provide multiple configs for different settings:
- config.json (Standard)
- config_small.json (Phase Retrieval)
- config_universal.json (Universal SE)

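The configs are plain JSON. Codebases in this family (e.g. MP-SENet) typically load them into an attribute-style namespace; a minimal loader sketch (the `AttrDict` name is an assumption, not necessarily this repository's helper):

```python
import json

class AttrDict(dict):
    """Dict whose keys are also attributes, e.g. h.batch_size."""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__dict__ = self

# a fragment of config.json, inlined for the demo
config_text = """
{
  "batch_size": 4,
  "n_fft": 400,
  "hop_size": 100,
  "dist_config": {"world_size": 1}
}
"""
h = AttrDict(json.loads(config_text))
print(h.batch_size, h.n_fft)  # 4 400
```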
---

## Setup
This project depends on PyTorch and common audio/metric libraries. Make sure your environment includes:
- torch
- librosa
- soundfile
- numpy
- pesq
- pystoi
- tablib[xlsx]
- tqdm

---

## Data Preparation

### 1) VoiceBank+DEMAND (Denoising/Phase Retrieval)
Place 16 kHz wavs here:
- filelist_VBD/wavs_clean
- filelist_VBD/wavs_noisy

The filelists follow the same format as those of MP-SENet:
- filelist_VBD/training.txt
- filelist_VBD/test.txt

### 2) DNS-2020 (Denoising)
Place clean wavs and noisy wavs in two separate folders and create the filelist (3000 h).

Filelist format (the clean file directory is set in the training script):
```
clean_fileid_118096.wav|/abs/path/to/noisy_fileid_118096.wav
```

You can generate this list using:
- data/generate_filelist.py

Default paths:
- filelist_DNS20/training.txt
- filelist_DNS20/test.txt

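The list pairs each clean file with the absolute path of its noisy counterpart via the shared DNS fileid. A minimal sketch of producing such a list (a hypothetical helper written against the format above, not necessarily what data/generate_filelist.py does):

```python
import os
import re
import tempfile

def write_filelist(clean_dir, noisy_dir, out_path):
    """Pair clean_fileid_N.wav with noisy_*_fileid_N.wav via the shared fileid."""
    noisy_by_id = {}
    for name in os.listdir(noisy_dir):
        m = re.search(r"fileid_(\d+)\.wav$", name)
        if m:
            noisy_by_id[m.group(1)] = os.path.join(noisy_dir, name)
    with open(out_path, "w") as f:
        for name in sorted(os.listdir(clean_dir)):
            m = re.search(r"fileid_(\d+)\.wav$", name)
            if m and m.group(1) in noisy_by_id:
                f.write(f"{name}|{noisy_by_id[m.group(1)]}\n")

# demo with empty placeholder files
root = tempfile.mkdtemp()
cd, nd = os.path.join(root, "clean"), os.path.join(root, "noisy")
os.makedirs(cd); os.makedirs(nd)
open(os.path.join(cd, "clean_fileid_7.wav"), "w").close()
open(os.path.join(nd, "noisy_snr5_fileid_7.wav"), "w").close()
out = os.path.join(root, "training.txt")
write_filelist(cd, nd, out)
print(open(out).read().strip())  # clean_fileid_7.wav|<abs path to noisy wav>
```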
### 3) DNS-2021 + WSJ0+WHAMR! Test (Universal SE)
Prepare a DNS-2021-style list with the same format as DNS-2020 (300 h):
```
clean_fileid_000123.wav|/abs/path/to/noisy_fileid_000123.wav
```

Default path:
- filelist_DNS21/training.txt

We provide the generated WSJ+WHAMR universal SE test set [here](https://drive.google.com/file/d/123-WvyaKZkKqbh81Q_gMTOdTGgxPB_3z/view?usp=sharing).

---

## 🚀 Training

Pre-trained checkpoints for each task are released in the `checkpoint/` folder.

### VBD Denoising
```
python train_denoising_vbd.py --config config.json
```

### DNS-2020 Denoising
```
python train_denoising_dns.py --config config.json
```

### Phase Retrieval (Small)
```
python train_phase_retrieval.py --config config_small.json
```

### Universal SE (DNS-2021)
```
python train_universal_dns.py \
    --test_noisy_dir /path/to/wsj_whamr/noisy_test \
    --test_clean_dir /path/to/wsj_whamr/clean_test \
    --config config_universal.json
```

---

## Inference

### Unified Inference
```
python inference_{denoising|phase|universal}.py \
    --checkpoint_file /path/to/checkpoint \
    --input_noisy_wavs_dir /path/to/input_wavs \
    --output_dir /path/to/output
```

Notes:
- Use `inference_denoising.py` with a denoising checkpoint and a noisy input folder.
- Use `inference_phase.py` with a PR-trained checkpoint and a clean input folder (the script drops the phase itself).
- Use `inference_universal.py` with a universal-SE checkpoint and the universal degraded test input.

The inference scripts load the corresponding config file from the checkpoint folder automatically.

---

## 📈 Evaluation

Note: to evaluate UTMOS and DNSMOS, the required metric checkpoint files are not included in this repository. Please place them under `cal_metrics/dns` and `cal_metrics/UTMOS_demo` before running the evaluator.

We provide a single-directory evaluator that computes PESQ/STOI/SI-SNR/CSIG/CBAK/COVL/UTMOS/DNSMOS and phase metrics (PD/WOPD):
```
python cal_metrics_singledir.py \
    --clean_dir /path/to/clean \
    --enhanced_dir /path/to/enhanced \
    --excel_name results.xlsx
```

To compute only a subset of metrics, use `--metrics` with a comma-separated list (or `all` for everything):
```
python cal_metrics_singledir.py \
    --clean_dir /path/to/clean \
    --enhanced_dir /path/to/enhanced \
    --excel_name results.xlsx \
    --metrics PESQ,STOI,SISNR,PD,WOPD
```

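Among these metrics, SI-SNR has a particularly simple closed form: project the estimate onto the reference and compare the target component against the residual. A self-contained NumPy reference (an illustrative implementation, not this repo's evaluator):

```python
import numpy as np

def si_snr(ref, est, eps=1e-8):
    """Scale-invariant SNR in dB between a reference and an estimate."""
    ref = ref - ref.mean()
    est = est - est.mean()
    # project est onto ref; the projection is the "target", the rest is "noise"
    target = np.dot(est, ref) / (np.dot(ref, ref) + eps) * ref
    noise = est - target
    return 10 * np.log10((target @ target + eps) / (noise @ noise + eps))

rng = np.random.default_rng(0)
clean = rng.standard_normal(16000)
print(si_snr(clean, 0.5 * clean))  # scale-invariant: very high (no residual)
print(si_snr(clean, clean + 0.1 * rng.standard_normal(16000)))  # ~20 dB
```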
For hierarchical test sets (e.g., the universal subfolders), the enhanced wavs should be in the same relative structure as the clean set.

Hierarchical-directory example (`cal_metrics_hierarchicaldir.py`):
```
python cal_metrics_hierarchicaldir.py \
    --clean_dir /path/to/clean_root \
    --enhanced_dir /path/to/enhanced_root \
    --excel_name results_hierarchical.csv \
    --metrics PESQ,STOI,SISNR,PD,WOPD
```

Target directory layout (clean and enhanced must mirror each other):
```
clean_root/
    noise_limit/|noise_reverb/|noise_reverb_limit/|only_noise/
        -5db/|0db/|5db/|10db/|15db/
            0001.wav ...
    only_reverb/
        0001.wav ...
    only_bandlimit/
        2khz/|4khz/
            0001.wav ...

enhanced_root/
    (same structure and filenames as clean_root)
```
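Before running the evaluator, it can help to verify that the two trees actually mirror each other. A small hypothetical checker (not part of the repo) that walks the clean tree and reports wavs missing from the enhanced tree:

```python
import os
import tempfile

def mirror_mismatches(clean_root, enhanced_root):
    """Return relative wav paths present under clean_root but missing under enhanced_root."""
    missing = []
    for dirpath, _, files in os.walk(clean_root):
        rel = os.path.relpath(dirpath, clean_root)
        for name in files:
            if name.endswith(".wav") and not os.path.isfile(
                os.path.join(enhanced_root, rel, name)
            ):
                missing.append(os.path.join(rel, name))
    return sorted(missing)

# demo: enhanced tree is missing one file
root = tempfile.mkdtemp()
for sub in ("clean_root/only_reverb", "enhanced_root/only_reverb"):
    os.makedirs(os.path.join(root, sub))
open(os.path.join(root, "clean_root/only_reverb/0001.wav"), "w").close()
open(os.path.join(root, "clean_root/only_reverb/0002.wav"), "w").close()
open(os.path.join(root, "enhanced_root/only_reverb/0001.wav"), "w").close()
print(mirror_mismatches(os.path.join(root, "clean_root"),
                        os.path.join(root, "enhanced_root")))
```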

---

## Calculating MACs
Since we use several custom operations, counting only standard conv/deconv/GRU/MHA layers underestimates MACs. We therefore implement MAC counting for InteConvBlock(Transpose), CustomAttention, and ComplexFFN.

Run:
```
python cal_mac.py
```

To modify the model size, edit the configuration near the bottom of cal_mac.py.

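For intuition on the standard-layer baseline that cal_mac.py extends, the dominant term can be estimated analytically: a Conv2d costs one multiply-accumulate per input channel per kernel tap per output element. A back-of-envelope sketch (generic formula with assumed toy shapes, not the repo's counter):

```python
def conv2d_macs(c_in, c_out, kh, kw, h_out, w_out):
    """MACs for one Conv2d forward pass: each output element costs c_in*kh*kw multiplies."""
    return c_in * c_out * kh * kw * h_out * w_out

# toy numbers: a 3x3 conv over a 64-channel 201x100 feature map (assumed shapes)
macs = conv2d_macs(64, 64, 3, 3, 201, 100)
print(f"{macs / 1e9:.2f} GMACs")  # 0.74 GMACs
```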
## Acknowledgements
We acknowledge the contributions of the following repositories, which served as important references for our implementation:
- [MP-SENet](https://github.com/yxlu-0102/MP-SENet)
- [SEMamba](https://github.com/RoyChao19477/SEMamba)

---

## Citation
If you find this work useful, please cite the paper.