Upload 8 files
Browse files- .gitattributes +4 -0
- config.json +36 -0
- config_small.json +36 -0
- config_universal.json +36 -0
- g_dns_best +3 -0
- g_phase_retrieval_voicebank_best +3 -0
- g_universal_best +3 -0
- g_vbd_best +3 -0
- readme.md +299 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
g_dns_best filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
g_phase_retrieval_voicebank_best filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
g_universal_best filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
g_vbd_best filter=lfs diff=lfs merge=lfs -text
|
config.json
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"num_gpus": 0,
|
| 3 |
+
"batch_size": 4,
|
| 4 |
+
"learning_rate": 0.0005,
|
| 5 |
+
"adam_b1": 0.8,
|
| 6 |
+
"adam_b2": 0.99,
|
| 7 |
+
"lr_decay": 0.99,
|
| 8 |
+
"seed": 1234,
|
| 9 |
+
|
| 10 |
+
"grad_clip_val": 5,
|
| 11 |
+
|
| 12 |
+
"dense_channel": 64,
|
| 13 |
+
"compress_factor": 0.3,
|
| 14 |
+
"num_tsconformers": 4,
|
| 15 |
+
"beta": 2.0,
|
| 16 |
+
|
| 17 |
+
"sampling_rate": 16000,
|
| 18 |
+
"segment_size": 32000,
|
| 19 |
+
"n_fft": 400,
|
| 20 |
+
"hop_size": 100,
|
| 21 |
+
"win_size": 400,
|
| 22 |
+
|
| 23 |
+
"amp_chn":48,
|
| 24 |
+
"ang_chn":16,
|
| 25 |
+
"n_heads":4,
|
| 26 |
+
"amp_attnhead_dim":12,
|
| 27 |
+
"ang_attnhead_dim":6,
|
| 28 |
+
|
| 29 |
+
"num_workers": 16,
|
| 30 |
+
|
| 31 |
+
"dist_config": {
|
| 32 |
+
"dist_backend": "nccl",
|
| 33 |
+
"dist_url": "tcp://localhost:12345",
|
| 34 |
+
"world_size": 1
|
| 35 |
+
}
|
| 36 |
+
}
|
config_small.json
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"num_gpus": 0,
|
| 3 |
+
"batch_size": 4,
|
| 4 |
+
"learning_rate": 0.0005,
|
| 5 |
+
"adam_b1": 0.8,
|
| 6 |
+
"adam_b2": 0.99,
|
| 7 |
+
"lr_decay": 0.99,
|
| 8 |
+
"seed": 1234,
|
| 9 |
+
|
| 10 |
+
"grad_clip_val": 5,
|
| 11 |
+
|
| 12 |
+
"dense_channel": 64,
|
| 13 |
+
"compress_factor": 0.3,
|
| 14 |
+
"num_tsconformers": 4,
|
| 15 |
+
"beta": 2.0,
|
| 16 |
+
|
| 17 |
+
"sampling_rate": 16000,
|
| 18 |
+
"segment_size": 32000,
|
| 19 |
+
"n_fft": 400,
|
| 20 |
+
"hop_size": 100,
|
| 21 |
+
"win_size": 400,
|
| 22 |
+
|
| 23 |
+
"amp_chn":32,
|
| 24 |
+
"ang_chn":16,
|
| 25 |
+
"n_heads":4,
|
| 26 |
+
"amp_attnhead_dim":8,
|
| 27 |
+
"ang_attnhead_dim":6,
|
| 28 |
+
|
| 29 |
+
"num_workers": 16,
|
| 30 |
+
|
| 31 |
+
"dist_config": {
|
| 32 |
+
"dist_backend": "nccl",
|
| 33 |
+
"dist_url": "tcp://localhost:12345",
|
| 34 |
+
"world_size": 1
|
| 35 |
+
}
|
| 36 |
+
}
|
config_universal.json
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"num_gpus": 0,
|
| 3 |
+
"batch_size": 4,
|
| 4 |
+
"learning_rate": 0.0005,
|
| 5 |
+
"adam_b1": 0.8,
|
| 6 |
+
"adam_b2": 0.99,
|
| 7 |
+
"lr_decay": 0.999,
|
| 8 |
+
"seed": 1234,
|
| 9 |
+
|
| 10 |
+
"grad_clip_val": 5,
|
| 11 |
+
|
| 12 |
+
"dense_channel": 64,
|
| 13 |
+
"compress_factor": 0.3,
|
| 14 |
+
"num_tsconformers": 4,
|
| 15 |
+
"beta": 2.0,
|
| 16 |
+
|
| 17 |
+
"sampling_rate": 16000,
|
| 18 |
+
"segment_size": 32000,
|
| 19 |
+
"n_fft": 400,
|
| 20 |
+
"hop_size": 100,
|
| 21 |
+
"win_size": 400,
|
| 22 |
+
|
| 23 |
+
"amp_chn":48,
|
| 24 |
+
"ang_chn":16,
|
| 25 |
+
"n_heads":4,
|
| 26 |
+
"amp_attnhead_dim":12,
|
| 27 |
+
"ang_attnhead_dim":6,
|
| 28 |
+
|
| 29 |
+
"num_workers": 16,
|
| 30 |
+
|
| 31 |
+
"dist_config": {
|
| 32 |
+
"dist_backend": "nccl",
|
| 33 |
+
"dist_url": "tcp://localhost:12345",
|
| 34 |
+
"world_size": 1
|
| 35 |
+
}
|
| 36 |
+
}
|
g_dns_best
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d2e41e0fb3c23a1224211e610f8344076c7d5a6243ebccb09ee716cacc93c11b
|
| 3 |
+
size 6363371
|
g_phase_retrieval_voicebank_best
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:618f48587a27a3ba29d21ca0b6834bb9459fc25126f98015c5b621d0432f7572
|
| 3 |
+
size 3765758
|
g_universal_best
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:935d6b9f1212fa31bbb097b9221be8bb7fa9e3cc91698b8122731c5d77184cf3
|
| 3 |
+
size 6363774
|
g_vbd_best
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3bee17454c2c3751b624905a5d5686750c01a85a436dd35f6b0d30809d139b51
|
| 3 |
+
size 6363774
|
readme.md
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Global Rotation Equivariant Phase Modeling for Speech Enhancement with Deep Magnitude-Phase Interaction
|
| 2 |
+
This repository hosts the official implementation for the paper:
|
| 3 |
+
|
| 4 |
+
**Global Rotation Equivariant Phase Modeling for Speech Enhancement with Deep Magnitude-Phase Interaction** (submitted to IEEE TASLP).
|
| 5 |
+
|
| 6 |
+
Authors: Chengzhong Wang, Andong Li, Dingding Yao and Junfeng Li*
|
| 7 |
+
|
| 8 |
+
A manifold-aware magnitude–phase dual-stream framework is proposed that enforces **Global Rotation Equivariance (GRE)** in the phase stream, enabling robust phase modeling with strong generalization across denoising, dereverberation, bandwidth extension, and mixed distortions.
|
| 9 |
+
|
| 10 |
+
Training logs, audio samples and supplementary analysis: https://wangchengzhong.github.io/RENet-Supplementary-Materials/
|
| 11 |
+
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
## Implementation Summary
|
| 15 |
+
- **GRE as inductive bias:** explicit global rotation equivariance for phase modeling.
|
| 16 |
+
- **Deep Magnitude–Phase Interaction:** MPICM for cross-stream gating without breaking equivariance.
|
| 17 |
+
- **Hybrid Attention Dual-FFN (HADF):** attention fusion in score domain + stream-specific FFNs.
|
| 18 |
+
- **Strong results with compact model:** 1.55M parameters, competitive or better quality than advanced baselines across SE tasks.
|
| 19 |
+
|
| 20 |
+
---
|
| 21 |
+
|
| 22 |
+
## Method Overview
|
| 23 |
+
|
| 24 |
+
### Architecture
|
| 25 |
+

|
| 26 |
+
|
| 27 |
+
The model uses a dual-stream encoder–decoder with a GRE-constrained complex phase branch and a real-valued magnitude branch. Key modules:
|
| 28 |
+
|
| 29 |
+
- **MPICM (Magnitude-Phase Interactive Convolutional Module):**
|
| 30 |
+
- bias-free complex convolution for phase stream
|
| 31 |
+
- RMSNorm + SiLU for magnitude stream
|
| 32 |
+
- cross-stream modulus-based gating preserving GRE
|
| 33 |
+
|
| 34 |
+
- **HADF (Hybrid-Attention Dual-FFN):**
|
| 35 |
+
- hybrid attention with a shared score map
|
| 36 |
+
- independent magnitude/phase value projections
|
| 37 |
+
- GRU-based FFN for magnitude, complex-valued convolutional GLU for phase
|
| 38 |
+
|
| 39 |
+
### Global Rotation Equivariance
|
| 40 |
+
GRE ensures $
|
| 41 |
+
\mathcal{F}(\mathbf{x}e^{j\theta}) = \mathcal{F}(\mathbf{x})e^{j\theta}
|
| 42 |
+
$, preventing the phase stream from learning arbitrary absolute orientations while preserving relative phase structure (GD/IP).
|
| 43 |
+
|
| 44 |
+
---
|
| 45 |
+
|
| 46 |
+
## 🔵 Experiments
|
| 47 |
+
|
| 48 |
+
We evaluate on three settings:
|
| 49 |
+
1. **Phase Retrieval** (clean magnitude, zero phase)
|
| 50 |
+
2. **Denoising** (VoiceBank+DEMAND, DNS-2020)
|
| 51 |
+
3. **Universal SE** (DNS-2021 training, WSJ0+WHAMR! test; DN/DR/BWE/mixed)
|
| 52 |
+
|
| 53 |
+
### Phase Retrieval (VoiceBank)
|
| 54 |
+
|
| 55 |
+
| Model | Params (M) | MACs (G/s) | PESQ | SI-SDR | WOPD $\downarrow$ | PD $\downarrow$ |
|
| 56 |
+
| --- | ---: | ---: | ---: | ---: | ---: | ---: |
|
| 57 |
+
| Griffin-Lim | - | - | 4.23 | -17.07 | 0.342 | 90.07 |
|
| 58 |
+
| DiffPhase | 65.6 | 3330 | 4.41 | -11.75 | 0.230 | 85.66 |
|
| 59 |
+
| MP-SENet Up.* | 1.99 | 38.80 | 4.60 | 14.64 | 0.058 | 11.38 |
|
| 60 |
+
| SEMamba* | 1.88 | 38.01 | 4.59 | 13.63 | 0.059 | 12.46 |
|
| 61 |
+
| **Proposed (Small)** | **0.90** | **22.89** | **4.61** | **16.03** | **0.044** | **8.47** |
|
| 62 |
+
|
| 63 |
+
\* Single phase decoder for phase retrieval.
|
| 64 |
+
|
| 65 |
+
### Denoising (VBD & DNS-2020)
|
| 66 |
+

|
| 67 |
+
|
| 68 |
+
\* SEMamba reported w/o PCS.
|
| 69 |
+
|
| 70 |
+
Key result: strong zero-shot transfer from VBD to DNS-2020 with consistent gains across PESQ, STOI, UTMOS, and PD; SOTA results on the large-scale DNS-2020.
|
| 71 |
+
|
| 72 |
+
### Universal SE (DNS-2021 → WSJ0+WHAMR!)
|
| 73 |
+

|
| 74 |
+
|
| 75 |
+
Our model achieves top-tier performance across DN/DR/BWE and mixed distortions.
|
| 76 |
+
|
| 77 |
+
---
|
| 78 |
+
|
| 79 |
+
## Repository Structure
|
| 80 |
+
- Training:
|
| 81 |
+
- train_denoising_dns.py
|
| 82 |
+
- train_denoising_vbd.py
|
| 83 |
+
- train_phase_retrieval.py
|
| 84 |
+
- train_universal_dns.py
|
| 85 |
+
- Inference:
|
| 86 |
+
- inference_denoising.py
|
| 87 |
+
- inference_phase.py
|
| 88 |
+
- inference_universal.py
|
| 89 |
+
- Core modules:
|
| 90 |
+
- models/model.py
|
| 91 |
+
- models/transformer.py
|
| 92 |
+
- models/mpd_and_metricd.py
|
| 93 |
+
- Data:
|
| 94 |
+
- dataset.py
|
| 95 |
+
- dns_dataset.py
|
| 96 |
+
- data_gen/
|
| 97 |
+
- Metrics:
|
| 98 |
+
- cal_metrics_singledir.py
|
| 99 |
+
- cal_metrics_hierarchicaldir.py
|
| 100 |
+
---
|
| 101 |
+
|
| 102 |
+
## Configurations
|
| 103 |
+
We provide multiple configs for different settings:
|
| 104 |
+
- config.json (Standard)
|
| 105 |
+
- config_small.json (Phase Retrieval)
|
| 106 |
+
- config_universal.json (Universal SE)
|
| 107 |
+
|
| 108 |
+
---
|
| 109 |
+
|
| 110 |
+
## Setup
|
| 111 |
+
This project depends on PyTorch and common audio/metric libraries. Make sure your environment includes:
|
| 112 |
+
- torch
|
| 113 |
+
- librosa
|
| 114 |
+
- soundfile
|
| 115 |
+
- numpy
|
| 116 |
+
- pesq
|
| 117 |
+
- pystoi
|
| 118 |
+
- tablib[xlsx]
|
| 119 |
+
- tqdm
|
| 120 |
+
|
| 121 |
+
---
|
| 122 |
+
|
| 123 |
+
## Data Preparation
|
| 124 |
+
|
| 125 |
+
### 1) VoiceBank+DEMAND (Denoising/Phase Retrieval)
|
| 126 |
+
Place 16 kHz wavs here:
|
| 127 |
+
- filelist_VBD/wavs_clean
|
| 128 |
+
- filelist_VBD/wavs_noisy
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
The Filelists are with the same formulation as that of MP-SENet:
|
| 132 |
+
- filelist_VBD/training.txt
|
| 133 |
+
- filelist_VBD/test.txt
|
| 134 |
+
|
| 135 |
+
### 2) DNS-2020 (Denoising)
|
| 136 |
+
Place clean wavs and noisy wavs in two separate folders and create the filelist (3000h).
|
| 137 |
+
|
| 138 |
+
Filelist format (the clean files path is set in the training script):
|
| 139 |
+
```
|
| 140 |
+
clean_fileid_118096.wav|/abs/path/to/noisy_fileid_118096.wav
|
| 141 |
+
```
|
| 142 |
+
|
| 143 |
+
You can generate this list using:
|
| 144 |
+
- data/generate_filelist.py
|
| 145 |
+
|
| 146 |
+
Default path:
|
| 147 |
+
- filelist_DNS20/training.txt
|
| 148 |
+
- filelist_DNS20/test.txt
|
| 149 |
+
|
| 150 |
+
### 3) DNS-2021 + WSJ-WHAMR test (Universal SE)
|
| 151 |
+
Prepare a DNS-2021-style list with the same format as DNS-2020 (300h):
|
| 152 |
+
```
|
| 153 |
+
clean_fileid_000123.wav|/abs/path/to/noisy_fileid_000123.wav
|
| 154 |
+
```
|
| 155 |
+
|
| 156 |
+
Default path:
|
| 157 |
+
- filelist_DNS21/training.txt
|
| 158 |
+
|
| 159 |
+
We provide the generated WSJ+WHAMR universal SE test set [here](https://drive.google.com/file/d/123-WvyaKZkKqbh81Q_gMTOdTGgxPB_3z/view?usp=sharing).
|
| 160 |
+
|
| 161 |
+
---
|
| 162 |
+
|
| 163 |
+
## 🚀 Training
|
| 164 |
+
|
| 165 |
+
Pre-trained checkpoints for each task are released in the `checkpoint/` folder.
|
| 166 |
+
|
| 167 |
+
### VBD Denoising
|
| 168 |
+
```
|
| 169 |
+
python train_denoising_vbd.py --config config.json
|
| 170 |
+
```
|
| 171 |
+
|
| 172 |
+
### DNS-2020 Denoising
|
| 173 |
+
```
|
| 174 |
+
python train_denoising_dns.py --config config.json
|
| 175 |
+
```
|
| 176 |
+
|
| 177 |
+
### Phase Retrieval (Small)
|
| 178 |
+
```
|
| 179 |
+
python train_phase_retrieval.py --config config_small.json
|
| 180 |
+
```
|
| 181 |
+
|
| 182 |
+
### Universal SE (DNS-2021)
|
| 183 |
+
```
|
| 184 |
+
python train_universal_dns.py \
|
| 185 |
+
--test_noisy_dir /path/to/wsj_whamr/noisy_test \
|
| 186 |
+
--test_clean_dir /path/to/wsj_whamr/clean_test \
|
| 187 |
+
--config config_universal.json
|
| 188 |
+
```
|
| 189 |
+
---
|
| 190 |
+
|
| 191 |
+
## Inference
|
| 192 |
+
|
| 193 |
+
### Unified Inference
|
| 194 |
+
```
|
| 195 |
+
python inference_{denoising|phase|universal}.py \
|
| 196 |
+
--checkpoint_file /path/to/checkpoint \
|
| 197 |
+
--input_noisy_wavs_dir /path/to/input_wavs \
|
| 198 |
+
--output_dir /path/to/output
|
| 199 |
+
```
|
| 200 |
+
|
| 201 |
+
Notes:
|
| 202 |
+
- Use `inference_denoising.py` with a denoising checkpoint and a noisy input folder.
|
| 203 |
+
- Use `inference_phase.py` with PR-trained checkpoint and a clean input folder (the script drops the phase itself).
|
| 204 |
+
- Use `inference_universal.py` with USE-trained checkpoint and universal degraded test input.
|
| 205 |
+
|
| 206 |
+
The inference scripts load the corresponding config file from the checkpoint folder automatically.
|
| 207 |
+
|
| 208 |
+
---
|
| 209 |
+
|
| 210 |
+
## 📈 Evaluation
|
| 211 |
+
|
| 212 |
+
Note: To evaluate UTMOS and DNSMOS, the required metric checkpoint files are not included in this repository. Please place them under `cal_metrics/dns` and `cal_metrics/UTMOS_demo` before running the evaluator.
|
| 213 |
+
|
| 214 |
+
We provide a single-directory evaluator that computes PESQ/STOI/SI-SNR/CSIG/CBAK/COVL/UTMOS/DNSMOS and phase metrics (PD/WOPD):
|
| 215 |
+
```
|
| 216 |
+
python cal_metrics_singledir.py \
|
| 217 |
+
--clean_dir /path/to/clean \
|
| 218 |
+
--enhanced_dir /path/to/enhanced \
|
| 219 |
+
--excel_name results.xlsx
|
| 220 |
+
```
|
| 221 |
+
|
| 222 |
+
To compute only a subset of metrics, use `--metrics` with a comma-separated list (or `all` for everything):
|
| 223 |
+
```
|
| 224 |
+
python cal_metrics_singledir.py \
|
| 225 |
+
--clean_dir /path/to/clean \
|
| 226 |
+
--enhanced_dir /path/to/enhanced \
|
| 227 |
+
--excel_name results.xlsx \
|
| 228 |
+
--metrics PESQ,STOI,SISNR,PD,WOPD
|
| 229 |
+
```
|
| 230 |
+
|
| 231 |
+
For hierarchical test sets (e.g., universal subfolders), the enhanced wavs should be in the same relative structure as the clean set.
|
| 232 |
+
|
| 233 |
+
Hierarchical-directory example (`cal_metrics_hierarchicaldir.py`):
|
| 234 |
+
```
|
| 235 |
+
python cal_metrics_hierarchicaldir.py \
|
| 236 |
+
--clean_dir /path/to/clean_root \
|
| 237 |
+
--enhanced_dir /path/to/enhanced_root \
|
| 238 |
+
--excel_name results_hierarchical.csv \
|
| 239 |
+
--metrics PESQ,STOI,SISNR,PD,WOPD
|
| 240 |
+
```
|
| 241 |
+
|
| 242 |
+
Target directory layout (clean and enhanced must mirror each other):
|
| 243 |
+
```
|
| 244 |
+
clean_root/
|
| 245 |
+
noise_limit/|noise_reverb/|noise_reverb_limit/|only_noise/
|
| 246 |
+
-5db/|0db/|5db/|10db/|15db/
|
| 247 |
+
0001.wav ...
|
| 248 |
+
only_reverb/
|
| 249 |
+
0001.wav ...
|
| 250 |
+
only_bandlimit/
|
| 251 |
+
2khz/|4khz/
|
| 252 |
+
0001.wav ...
|
| 253 |
+
|
| 254 |
+
enhanced_root/
|
| 255 |
+
(same structure and filenames as clean_root)
|
| 256 |
+
```
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
---
|
| 261 |
+
|
| 262 |
+
## Calculating MACs
|
| 263 |
+
Since we use multiple custom operations, only counting standard conv/deconv/GRU/MHA underestimates MACs. We implement MAC counting for InteConvBlock(Transpose), CustomAttention, and ComplexFFN.
|
| 264 |
+
|
| 265 |
+
Run:
|
| 266 |
+
```
|
| 267 |
+
python cal_mac.py
|
| 268 |
+
```
|
| 269 |
+
|
| 270 |
+
To modify the model size, edit the configuration near the bottom of cal_mac.py.
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
## Acknowledgements
|
| 274 |
+
We acknowledge the contributions of the following repositories, which served as important references for our code implementation:
|
| 275 |
+
- [MP-SENet](https://github.com/yxlu-0102/MP-SENet)
|
| 276 |
+
- [SEMamba](https://github.com/RoyChao19477/SEMamba)
|
| 277 |
+
|
| 278 |
+
---
|
| 279 |
+
|
| 280 |
+
## Citation
|
| 281 |
+
If you find this work useful, please cite the paper.
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
|