---
license: mit
datasets:
- JacobLinCool/VoiceBank-DEMAND-16k
base_model:
- liduojia/MeanFlowSE
---
<div align="center">
<p align="center">
  <h1>MeanFlowSE — One-Step Generative Speech Enhancement</h1>

  [![Paper](https://img.shields.io/badge/Paper-arXiv-b31b1b?logo=arxiv&logoColor=white)](https://arxiv.org/abs/2509.14858)
  [![Hugging Face Model](https://img.shields.io/badge/Model-HuggingFace-yellow?logo=huggingface)](https://huggingface.co/liduojia/MeanFlowSE)
  [![Code](https://img.shields.io/badge/Code-Repo-black?style=flat&logo=github&logoColor=white)](https://github.com/liduojia1/MeanFlowSE)

</p>
</div>

**MeanFlowSE** is a conditional generative approach to speech enhancement that learns average velocities over short time spans and performs enhancement in a single step. Instead of rolling out a long ODE trajectory, it applies one backward-in-time displacement directly in the complex STFT domain, delivering competitive quality at a fraction of the compute and latency. The model is trained end-to-end with a local JVP-based objective and remains consistent with conditional flow matching on the diagonal — no teacher models, schedulers, or distillation required. In practice, 1-NFE inference makes real-time deployment on standard hardware straightforward.
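
The one-step idea can be sketched in a few lines. This is an illustrative sketch, not the repository's code: `one_step_enhance` and the model signature are hypothetical, and the real pipeline also samples prior noise when initializing `x_t`.

```python
import torch

def one_step_enhance(model, y_spec, T_rev=1.0, t_eps=0.03):
    """One backward-in-time displacement in the complex STFT domain.

    The model predicts the *average* velocity u over the interval
    [t_eps, T_rev]; a single displacement then yields the estimate,
    instead of integrating an ODE over many small steps.
    """
    x_t = y_spec  # the actual method also mixes in sampled prior noise
    t = torch.full((y_spec.shape[0],), T_rev)
    r = torch.full((y_spec.shape[0],), t_eps)
    u = model(x_t, t, r, cond=y_spec)          # average velocity on [r, t]
    x_0 = x_t - (t - r).view(-1, 1, 1, 1) * u  # single displacement update
    return x_0
```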

* 🎧 **Demo**: demo page coming soon.

---

## Table of Contents

* [Highlights](#highlights)
* [What's inside](#whats-inside)
* [Quick start](#quick-start)

  * [Installation](#installation)
  * [Data preparation](#data-preparation)
  * [Training](#training)
  * [Inference](#inference)
* [Configuration](#configuration)
* [Repository structure](#repository-structure)
* [Built upon & related work](#built-upon--related-work)
* [Pretrained models](#pretrained-models)
* [Acknowledgments](#acknowledgments)
* [Citation](#citation)



## Highlights

* **One-step enhancement (1-NFE):** A single displacement update replaces long ODE rollouts — fast enough for real-time use on standard GPUs/CPUs.
* **No teachers, no distillation:** Trains with a local, JVP-based objective; on the diagonal it exactly matches conditional flow matching.
* **Same model, two samplers:** Use the displacement sampler for 1-step (or few-step) inference; fall back to Euler along the instantaneous field if you prefer multi-step.
* **Competitive & fast:** Strong ESTOI / SI-SDR / DNSMOS with **very low RTF** on VoiceBank-DEMAND.
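
The JVP term in the training objective is the total derivative of the average-velocity field along the trajectory; with a finite-difference implementation (cf. `--mf_jvp_impl fd`) it can be approximated without autograd. The sketch below is the general idea only, not the repository's exact implementation; `fd_total_derivative` and `u_fn` are hypothetical names.

```python
import torch

def fd_total_derivative(u_fn, x, r, t, v, eps=1e-3):
    """Finite-difference stand-in for a JVP: approximate the total
    derivative d/dt u(x_t, r, t) along the trajectory dx/dt = v."""
    u0 = u_fn(x, r, t)
    u1 = u_fn(x + eps * v, r, t + eps)  # nudge state and time together
    return (u1 - u0) / eps
```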



## What's inside

* **Training** with average-field supervision (for the 1-step displacement sampler).
* **Inference** with `euler_mf` — single-step displacement along the average field.
* **Audio front-end**: complex STFT pipeline; configurable transforms & normalization.
* **Metrics**: PESQ, ESTOI, SI-SDR; end-to-end **RTF** measurement.
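
RTF (real-time factor) is wall-clock processing time divided by audio duration; RTF < 1 means faster than real time. A minimal illustrative measurement (`measure_rtf` and `enhance_fn` are hypothetical names, not this repository's API):

```python
import time

def measure_rtf(enhance_fn, audio, sample_rate):
    """Real-time factor: processing time / audio duration."""
    start = time.perf_counter()
    enhance_fn(audio)
    elapsed = time.perf_counter() - start
    duration = len(audio) / sample_rate  # seconds of audio processed
    return elapsed / duration
```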



## Quick start

### Installation

```bash
# Python 3.10 recommended

pip install -r requirements.txt
# Use a recent PyTorch + CUDA build for multi-GPU training
```

### Data preparation

Expected layout:

```
<BASE_DIR>/
  train/clean/*.wav   train/noisy/*.wav
  valid/clean/*.wav   valid/noisy/*.wav
  test/clean/*.wav    test/noisy/*.wav
```

Defaults assume 16 kHz audio, centered frames, Hann windows, and a complex STFT representation (see `SpecsDataModule` for knobs).
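
A quick sanity check for the layout above can save a failed run. This sketch assumes clean/noisy pairs share the same filename, which is the usual VoiceBank-DEMAND convention; `check_layout` is a hypothetical helper, not part of the repository.

```python
from pathlib import Path

def check_layout(base_dir):
    """Report clean/noisy filename mismatches for each split."""
    problems = {}
    for split in ("train", "valid", "test"):
        clean = {p.name for p in Path(base_dir, split, "clean").glob("*.wav")}
        noisy = {p.name for p in Path(base_dir, split, "noisy").glob("*.wav")}
        if clean != noisy:
            problems[split] = clean ^ noisy  # files without a counterpart
    return problems
```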

### Training

**Single machine, multi-GPU (DDP)**:

```bash
# Edit DATA_DIR and GPUs inside the script if needed
bash train_vbd.sh
```

Or run directly:

```bash
torchrun --standalone --nproc_per_node=4 train.py \
  --backbone ncsnpp \
  --ode flowmatching \
  --base_dir <BASE_DIR> \
  --batch_size 2 \
  --num_workers 8 \
  --max_epochs 150 \
  --precision 32 \
  --gradient_clip_val 1.0 \
  --t_eps 0.03 --T_rev 1.0 \
  --sigma_min 0.0 --sigma_max 0.487 \
  --use_mfse \
  --mf_weight_final 0.25 \
  --mf_warmup_frac 0.5 \
  --mf_delta_gamma_start 8.0 --mf_delta_gamma_end 1.0 \
  --mf_delta_warmup_frac 0.7 \
  --mf_r_equals_t_prob 0.1 \
  --mf_jvp_clip 5.0 --mf_jvp_eps 1e-3 \
  --mf_jvp_impl fd --mf_jvp_chunk 1 \
  --mf_skip_weight_thresh 0.05 \
  --val_metrics_every_n_epochs 1 \
  --default_root_dir lightning_logs
```

* **Logging & checkpoints** live under `lightning_logs/<exp_name>/version_x/`.
* Heavy validation (PESQ/ESTOI/SI-SDR) runs **every N epochs** on **rank-0**; placeholders are logged otherwise so checkpoint monitors remain valid.

### Inference

Use the helper script:

```bash
# MODE = multistep | multistep_mf | onestep
MODE=onestep STEPS=1 \
TEST_DATA_DIR=<BASE_DIR> \
CKPT_INPUT=path/to/best.ckpt \
bash run_inference.sh
```

Or call the evaluator:

```bash
python evaluate.py \
  --test_dir <BASE_DIR> \
  --folder_destination /path/to/output \
  --ckpt path/to/best.ckpt \
  --odesolver euler_mf \
  --reverse_starting_point 1.0 \
  --last_eval_point 0.0 \
  --one_step
```

> `evaluate.py` writes **enhanced WAVs**.
> If `--odesolver` is not given, it **auto-picks** (`euler_mf` when MF-SE was used; otherwise `euler`).
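
For few-step inference (`multistep_mf`), the same displacement update is applied over a partition of the interval from `reverse_starting_point` down to `last_eval_point`. A hedged sketch of that loop, under an assumed model signature (`euler_mf_sample` is a hypothetical name, not the repository's solver code):

```python
import torch

def euler_mf_sample(model, x_T, cond, steps=1, t_start=1.0, t_end=0.0):
    """Displacement sampler: march from t_start down to t_end in `steps`
    backward-in-time jumps, each using the predicted average velocity."""
    ts = torch.linspace(t_start, t_end, steps + 1)
    x = x_T
    for i in range(steps):
        t, r = ts[i], ts[i + 1]
        tb = t.expand(x.shape[0])          # broadcast times over the batch
        rb = r.expand(x.shape[0])
        u = model(x, tb, rb, cond=cond)    # average velocity on [r, t]
        x = x - (t - r) * u                # displacement over the subinterval
    return x
```

With `steps=1` this reduces to the one-step (`onestep`) update.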



## Configuration

Common flags you may want to tweak:

* **Time & schedule**

  * `--T_rev` (reverse start, default 1.0), `--t_eps` (terminal time), `--sigma_min`, `--sigma_max`
* **MF-SE stability**

  * `--mf_jvp_impl {auto,fd,autograd}`, `--mf_jvp_chunk`, `--mf_jvp_clip`, `--mf_jvp_eps`
  * Curriculum: `--mf_weight_final`, `--mf_warmup_frac`, `--mf_delta_*`, `--mf_r_equals_t_prob`
* **Validation cost**

  * `--val_metrics_every_n_epochs`, `--num_eval_files`
* **Backbone & front-end**

  * Defined in `backbones/` and `SpecsDataModule` (STFT, transforms, normalization)
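
The curriculum flags suggest a warmup of the MF loss weight toward `--mf_weight_final` over the first `--mf_warmup_frac` of training. A hypothetical sketch of such a schedule, assuming a linear ramp (the actual schedule lives in `flowmse/model.py` and may differ):

```python
def mf_weight(step, total_steps, mf_weight_final=0.25, mf_warmup_frac=0.5):
    """Ramp the MF loss weight linearly from 0 to its final value over
    the first `mf_warmup_frac` of training, then hold it constant."""
    warmup_steps = mf_warmup_frac * total_steps
    if step >= warmup_steps:
        return mf_weight_final
    return mf_weight_final * step / warmup_steps
```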



## Repository structure

```
MeanFlowSE/
├── train.py                 # Lightning entry
├── evaluate.py              # Enhancement script (WAV out)
├── run_inference.sh         # One-step / few-step convenience runner
├── flowmse/
│   ├── model.py             # Losses, JVP, curriculum, logging
│   ├── odes.py              # Path definition & registry
│   ├── sampling/
│   │   ├── __init__.py
│   │   └── odesolvers.py    # Euler (instantaneous) & Euler-MF (displacement)
│   ├── backbones/
│   │   ├── ncsnpp.py        # U-Net w/ time & delta embeddings
│   │   └── ...
│   ├── data_module.py       # STFT I/O pipeline
│   └── util/                # metrics, registry, tensors, inference helpers
├── requirements.txt
└── scripts/
    └── train_vbd.sh
```

## Built upon & related work

This repository builds upon previous great works:

* **SGMSE** — [https://github.com/sp-uhh/sgmse](https://github.com/sp-uhh/sgmse)
* **SGMSE-CRP** — [https://github.com/sp-uhh/sgmse_crp](https://github.com/sp-uhh/sgmse_crp)
* **SGMSE-BBED** — [https://github.com/sp-uhh/sgmse-bbed](https://github.com/sp-uhh/sgmse-bbed)
* **FLOWMSE (FlowSE)** — [https://github.com/seongq/flowmse](https://github.com/seongq/flowmse)

Many design choices (complex STFT pipeline, training infrastructure) are inspired by these excellent projects.



## Pretrained models

* **VoiceBank-DEMAND (16 kHz)**: Weights are hosted on Google Drive — [Google Drive Link](https://drive.google.com/file/d/1QAxgd5BWrxiNi0q2qD3n1Xcv6bW0X86-/view?usp=sharing)



## Acknowledgments

We gratefully acknowledge **Prof. Xie Chen's group (X-LANCE Lab, SJTU)** for their **valuable guidance and support** on training practices and engineering tips that greatly benefited this work.



## Citation

The paper is currently under review; a BibTeX entry will be added once it is available. In the meantime, please refer to the [arXiv preprint](https://arxiv.org/abs/2509.14858).




**Questions or issues?** Please open a GitHub issue or pull request.
We welcome contributions — from bug fixes to new backbones and front-ends.