---
license: mit
datasets:
- JacobLinCool/VoiceBank-DEMAND-16k
base_model:
- liduojia/MeanFlowSE
---
<div align="center">
<p align="center">
<h1>MeanFlowSE: One-Step Generative Speech Enhancement</h1>
[Paper](https://arxiv.org/abs/2509.14858)
[Hugging Face](https://huggingface.co/liduojia/MeanFlowSE)
[GitHub](https://github.com/liduojia1/MeanFlowSE)
</p>
</div>
**MeanFlowSE** is a conditional generative approach to speech enhancement that learns average velocities over short time spans and performs enhancement in a single step. Instead of rolling out a long ODE trajectory, it applies one backward-in-time displacement directly in the complex STFT domain, delivering competitive quality at a fraction of the compute and latency. The model is trained end-to-end with a local JVP-based objective and remains consistent with conditional flow matching on the diagonal; no teacher models, schedulers, or distillation are required. In practice, 1-NFE inference makes real-time deployment on standard hardware straightforward.
* **Demo**: demo page coming soon.
---
## Table of Contents
* [Highlights](#highlights)
* [What's inside](#whats-inside)
* [Quick start](#quick-start)
* [Installation](#installation)
* [Data preparation](#data-preparation)
* [Training](#training)
* [Inference](#inference)
* [Configuration](#configuration)
* [Repository structure](#repository-structure)
* [Built upon & related work](#built-upon--related-work)
* [Pretrained models](#pretrained-models)
* [Acknowledgments](#acknowledgments)
* [Citation](#citation)
## Highlights
* **One-step enhancement (1-NFE):** A single displacement update replaces long ODE rollouts, fast enough for real-time use on standard GPUs/CPUs.
* **No teachers, no distillation:** Trains with a local, JVP-based objective; on the diagonal it exactly matches conditional flow matching.
* **Same model, two samplers:** Use the displacement sampler for 1-step (or few-step) inference; fall back to Euler along the instantaneous field if you prefer multi-step.
* **Competitive & fast:** Strong ESTOI / SI-SDR / DNSMOS with **very low RTF** on VoiceBank-DEMAND.
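The "average velocity" idea behind the 1-NFE highlight can be sanity-checked numerically. The sketch below uses a toy scalar velocity field and hypothetical names (the real fields are learned, conditioned networks): if `u` is the average of the instantaneous velocity over an interval, then a single displacement reproduces the full ODE rollout over that interval.

```python
import torch

# Numeric sanity check of the identity behind 1-NFE sampling (toy scalar
# field, hypothetical names): if u(t, r) averages the instantaneous velocity
# v(s) over [r, t], one displacement (t - r) * u(t, r) equals the full
# integral of v from r to t, with no intermediate solver steps.
t, r = 1.0, 0.03
s = torch.linspace(r, t, 20001, dtype=torch.float64)
v = 3.0 * s**2                            # toy instantaneous velocity v(s)
u_avg = torch.trapz(v, s) / (t - r)       # average velocity over [r, t]
one_step = (t - r) * u_avg                # single displacement update
```

For this toy field the exact rollout is `t**3 - r**3`, which the single displacement matches; the training objective makes the network predict `u_avg` directly so that no numerical integration is needed at inference time.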
## What's inside
* **Training** with average-field supervision (for the 1-step displacement sampler).
* **Inference** with `euler_mf`: single-step displacement along the average field.
* **Audio front-end**: complex STFT pipeline; configurable transforms & normalization.
* **Metrics**: PESQ, ESTOI, SI-SDR; end-to-end **RTF** measurement.
## Quick start
### Installation
```bash
# Python 3.10 recommended
pip install -r requirements.txt
# Use a recent PyTorch + CUDA build for multi-GPU training
```
### Data preparation
Expected layout:
```
<BASE_DIR>/
  train/clean/*.wav   train/noisy/*.wav
  valid/clean/*.wav   valid/noisy/*.wav
  test/clean/*.wav    test/noisy/*.wav
```
Defaults assume 16 kHz audio, centered frames, Hann windows, and a complex STFT representation (see `SpecsDataModule` for knobs).
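A minimal sketch of such a front-end using PyTorch's built-in STFT is shown below. The parameter values (`n_fft=510`, `hop=128`) are illustrative assumptions, not the module's actual defaults; consult `SpecsDataModule` for the real knobs.

```python
import torch

# Illustrative complex-STFT round trip (parameter values are assumptions):
# 16 kHz audio, centered frames, Hann window, complex output.
n_fft, hop = 510, 128
window = torch.hann_window(n_fft)
wav = torch.randn(1, 16000)                      # 1 s of 16 kHz audio
spec = torch.stft(wav, n_fft=n_fft, hop_length=hop, window=window,
                  center=True, return_complex=True)
# spec: (batch, n_fft // 2 + 1, frames) complex tensor
recon = torch.istft(spec, n_fft=n_fft, hop_length=hop, window=window,
                    center=True, length=wav.shape[-1])
```

With a Hann window and this much overlap the STFT is invertible, so `recon` matches `wav` up to floating-point error; enhancement operates on `spec` and resynthesizes with the inverse transform.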
### Training
**Single machine, multi-GPU (DDP)**:
```bash
# Edit DATA_DIR and GPUs inside the script if needed
bash train_vbd.sh
```
Or run directly:
```bash
torchrun --standalone --nproc_per_node=4 train.py \
--backbone ncsnpp \
--ode flowmatching \
--base_dir <BASE_DIR> \
--batch_size 2 \
--num_workers 8 \
--max_epochs 150 \
--precision 32 \
--gradient_clip_val 1.0 \
--t_eps 0.03 --T_rev 1.0 \
--sigma_min 0.0 --sigma_max 0.487 \
--use_mfse \
--mf_weight_final 0.25 \
--mf_warmup_frac 0.5 \
--mf_delta_gamma_start 8.0 --mf_delta_gamma_end 1.0 \
--mf_delta_warmup_frac 0.7 \
--mf_r_equals_t_prob 0.1 \
--mf_jvp_clip 5.0 --mf_jvp_eps 1e-3 \
--mf_jvp_impl fd --mf_jvp_chunk 1 \
--mf_skip_weight_thresh 0.05 \
--val_metrics_every_n_epochs 1 \
--default_root_dir lightning_logs
```
* **Logging & checkpoints** live under `lightning_logs/<exp_name>/version_x/`.
* Heavy validation (PESQ/ESTOI/SI-SDR) runs **every N epochs** on **rank-0**; placeholders are logged otherwise so checkpoint monitors remain valid.
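The placeholder pattern can be sketched as follows (hypothetical names and a NaN sentinel chosen for illustration; the actual gating lives in `flowmse/model.py`):

```python
# Hypothetical sketch of gated validation: heavy metrics run every N epochs
# on rank 0 only; other epochs/ranks return a placeholder value so the
# monitored key (e.g. for a checkpoint callback) always exists.
def validation_metrics(epoch, rank, every_n, slow_metric):
    if rank == 0 and epoch % every_n == 0:
        return {"pesq": slow_metric(), "computed": True}
    return {"pesq": float("nan"), "computed": False}
```

Logging the key unconditionally means the checkpoint monitor never encounters a missing metric, regardless of rank or epoch.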
### Inference
Use the helper script:
```bash
# MODE = multistep | multistep_mf | onestep
MODE=onestep STEPS=1 \
TEST_DATA_DIR=<BASE_DIR> \
CKPT_INPUT=path/to/best.ckpt \
bash run_inference.sh
```
Or call the evaluator:
```bash
python evaluate.py \
--test_dir <BASE_DIR> \
--folder_destination /path/to/output \
--ckpt path/to/best.ckpt \
--odesolver euler_mf \
--reverse_starting_point 1.0 \
--last_eval_point 0.0 \
--one_step
```
> `evaluate.py` writes **enhanced WAVs**.
> If `--odesolver` is not given, it **auto-picks** (`euler_mf` when MF-SE was used; otherwise `euler`).
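On a toy path the two samplers can be sketched like this (illustrative only, with made-up field names; the real solvers are in `flowmse/sampling/odesolvers.py`):

```python
import torch

# Toy illustration of the two samplers (not the repository's code). On a
# linear path x_t = (1 - t) * x0 + t * x1 the instantaneous velocity is
# v = x1 - x0, which also equals the average velocity, so one displacement
# step lands exactly on x0.
def euler(x1, v_field, steps, T_rev=1.0, t_eps=0.0):
    x, dt = x1, (T_rev - t_eps) / steps
    for k in range(steps):                   # integrate backward in time
        t = T_rev - k * dt
        x = x - dt * v_field(x, t)
    return x

def euler_mf(x1, u_field, T_rev=1.0, t_eps=0.0):
    return x1 - (T_rev - t_eps) * u_field(x1, T_rev, t_eps)   # 1-NFE

x0 = torch.tensor([1.0, -2.0])               # "clean" endpoint
x1 = x0 + torch.tensor([0.5, 0.5])           # "noisy" endpoint
v_field = lambda x, t: x1 - x0               # instantaneous velocity
u_field = lambda x, t, r: x1 - x0            # average velocity over [r, t]
```

Here both samplers land on `x0`; with a learned, non-constant field the multi-step Euler rollout accumulates discretization error, while the displacement sampler queries the average field once.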
## Configuration
Common flags you may want to tweak:
* **Time & schedule**
* `--T_rev` (reverse start, default 1.0), `--t_eps` (terminal time), `--sigma_min`, `--sigma_max`
* **MF-SE stability**
* `--mf_jvp_impl {auto,fd,autograd}`, `--mf_jvp_chunk`, `--mf_jvp_clip`, `--mf_jvp_eps`
* Curriculum: `--mf_weight_final`, `--mf_warmup_frac`, `--mf_delta_*`, `--mf_r_equals_t_prob`
* **Validation cost**
* `--val_metrics_every_n_epochs`, `--num_eval_files`
* **Backbone & front-end**
* Defined in `backbones/` and `SpecsDataModule` (STFT, transforms, normalization)
## Repository structure
```
MeanFlowSE/
├── train.py                 # Lightning entry
├── evaluate.py              # Enhancement script (WAV out)
├── run_inference.sh         # One-step / few-step convenience runner
├── flowmse/
│   ├── model.py             # Losses, JVP, curriculum, logging
│   ├── odes.py              # Path definition & registry
│   ├── sampling/
│   │   ├── __init__.py
│   │   └── odesolvers.py    # Euler (instantaneous) & Euler-MF (displacement)
│   ├── backbones/
│   │   ├── ncsnpp.py        # U-Net w/ time & delta embeddings
│   │   └── ...
│   ├── data_module.py       # STFT I/O pipeline
│   └── util/                # metrics, registry, tensors, inference helpers
├── requirements.txt
└── scripts/
    └── train_vbd.sh
```
## Built upon & related work
This repository builds upon previous great works:
* **SGMSE**: [https://github.com/sp-uhh/sgmse](https://github.com/sp-uhh/sgmse)
* **SGMSE-CRP**: [https://github.com/sp-uhh/sgmse_crp](https://github.com/sp-uhh/sgmse_crp)
* **SGMSE-BBED**: [https://github.com/sp-uhh/sgmse-bbed](https://github.com/sp-uhh/sgmse-bbed)
* **FLOWMSE (FlowSE)**: [https://github.com/seongq/flowmse](https://github.com/seongq/flowmse)
Many design choices (complex STFT pipeline, training infrastructure) are inspired by these excellent projects.
## Pretrained models
* **VoiceBank-DEMAND (16 kHz)**: the weight files are hosted on Google Drive: [Google Drive Link](https://drive.google.com/file/d/1QAxgd5BWrxiNi0q2qD3n1Xcv6bW0X86-/view?usp=sharing)
## Acknowledgments
We gratefully acknowledge **Prof. Xie Chen's group (X-LANCE Lab, SJTU)** for their **valuable guidance and support** on training practices and engineering tips that greatly helped this work.
## Citation
The paper is currently under review; we will add a BibTeX entry and article link once it is available.
**Questions or issues?** Please open a GitHub issue or pull request.
We welcome contributions, from bug fixes to new backbones and front-ends.