---
license: mit
datasets:
- JacobLinCool/VoiceBank-DEMAND-16k
base_model:
- liduojia/MeanFlowSE
---
<div align="center">
<p align="center">
<h1>MeanFlowSE: One-Step Generative Speech Enhancement</h1>
[![Paper](https://img.shields.io/badge/Paper-arXiv-b31b1b?logo=arxiv&logoColor=white)](https://arxiv.org/abs/2509.14858)
[![Hugging Face Model](https://img.shields.io/badge/Model-HuggingFace-yellow?logo=huggingface)](https://huggingface.co/liduojia/MeanFlowSE)
[![Code](https://img.shields.io/badge/Code-Repo-black?style=flat&logo=github&logoColor=white)](https://github.com/liduojia1/MeanFlowSE)
</p>
</div>
**MeanFlowSE** is a conditional generative approach to speech enhancement that learns average velocities over short time spans and performs enhancement in a single step. Instead of rolling out a long ODE trajectory, it applies one backward-in-time displacement directly in the complex STFT domain, delivering competitive quality at a fraction of the compute and latency. The model is trained end-to-end with a local JVP-based objective and remains consistent with conditional flow matching on the diagonal, so no teacher models, schedulers, or distillation are required. In practice, inference with a single function evaluation (1-NFE) makes real-time deployment on standard hardware straightforward.
* 🎧 **Demo**: demo page coming soon.
---
## Table of Contents
* [Highlights](#highlights)
* [What's inside](#whats-inside)
* [Quick start](#quick-start)
* [Installation](#installation)
* [Data preparation](#data-preparation)
* [Training](#training)
* [Inference](#inference)
* [Configuration](#configuration)
* [Repository structure](#repository-structure)
* [Built upon & related work](#built-upon--related-work)
* [Pretrained models](#pretrained-models)
* [Acknowledgments](#acknowledgments)
* [Citation](#citation)
## Highlights
* **One-step enhancement (1-NFE):** A single displacement update replaces long ODE rollouts, making it fast enough for real-time use on standard GPUs/CPUs.
* **No teachers, no distillation:** Trains with a local, JVP-based objective; on the diagonal it exactly matches conditional flow matching.
* **Same model, two samplers:** Use the displacement sampler for 1-step (or few-step) inference; fall back to Euler along the instantaneous field if you prefer multi-step.
* **Competitive & fast:** strong ESTOI / SI-SDR / DNSMOS with **very low RTF** on VoiceBank-DEMAND.
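To make the difference between the two samplers concrete, here is a minimal toy sketch in plain Python. It is not the repository's sampler: the scalar field `v(x, t) = A * x`, the constant `A`, and all function names are assumptions chosen so that the average velocity has a closed form. In MeanFlowSE the average field is predicted by the trained network instead.

```python
import math

A = 1.5  # decay rate of the toy field v(x, t) = A * x (assumption for illustration)

def instantaneous_velocity(x, t):
    """Toy instantaneous field v(x, t) = A * x."""
    return A * x

def average_velocity(x_t, r, t):
    """Closed-form average velocity over [r, t] for the toy field.

    Solving dx/ds = A * x backward from x_t gives x_r = x_t * exp(-A * (t - r)),
    so u = (x_t - x_r) / (t - r).
    """
    if t == r:
        return instantaneous_velocity(x_t, t)  # on the diagonal r = t, u equals v
    return x_t * (1.0 - math.exp(-A * (t - r))) / (t - r)

def displacement_step(x_t, r, t):
    """One backward-in-time displacement: x_r = x_t - (t - r) * u(x_t, r, t)."""
    return x_t - (t - r) * average_velocity(x_t, r, t)

def euler_rollout(x_t, r, t, steps):
    """Multi-step Euler along the instantaneous field from t down to r."""
    h = (t - r) / steps
    x, s = x_t, t
    for _ in range(steps):
        x = x - h * instantaneous_velocity(x, s)
        s -= h
    return x

x_start = 2.0
exact = x_start * math.exp(-A)                    # true state at r = 0
one_step = displacement_step(x_start, 0.0, 1.0)   # 1 field evaluation
euler_1 = euler_rollout(x_start, 0.0, 1.0, 1)     # 1 field evaluation
euler_50 = euler_rollout(x_start, 0.0, 1.0, 50)   # 50 field evaluations

print(f"exact={exact:.4f} one_step={one_step:.4f} "
      f"euler_1={euler_1:.4f} euler_50={euler_50:.4f}")
```

With an exact average field, the single displacement lands on the true endpoint, while Euler along the instantaneous field needs many steps to get close. A learned average field is only approximate, so the few-step option remains useful in practice.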
## What's inside
* **Training** with average-field supervision (for the 1-step displacement sampler).
* **Inference** with `euler_mf`: single-step displacement along the average field.
* **Audio front-end**: complex STFT pipeline; configurable transforms & normalization.
* **Metrics**: PESQ, ESTOI, SI-SDR; end-to-end **RTF** measurement.
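For readers unfamiliar with two of the metrics above, the following is a minimal sketch of SI-SDR and RTF in plain Python. It is an illustration under stated assumptions, not the repository's metric code: the function names are hypothetical, the signals are assumed to be zero-mean, equal-length sample sequences, and `process` stands in for any enhancement callable.

```python
import math
import time

def si_sdr(reference, estimate):
    """Scale-invariant SDR in dB for two equal-length, zero-mean sequences."""
    dot = sum(r * e for r, e in zip(reference, estimate))
    ref_energy = sum(r * r for r in reference)
    scale = dot / ref_energy                       # projection of estimate onto reference
    target = [scale * r for r in reference]
    noise = [e - t for e, t in zip(estimate, target)]
    target_energy = sum(t * t for t in target)
    noise_energy = sum(n * n for n in noise)
    if noise_energy == 0.0:
        return math.inf                            # estimate is a scaled copy of the reference
    return 10.0 * math.log10(target_energy / noise_energy)

def real_time_factor(process, samples, sample_rate=16000):
    """RTF = wall-clock processing time / audio duration (below 1 is faster than real time)."""
    start = time.perf_counter()
    process(samples)
    elapsed = time.perf_counter() - start
    return elapsed / (len(samples) / sample_rate)

reference = [1.0, 0.0, -1.0, 0.0]
estimate = [1.0, 0.1, -1.0, -0.1]
score = si_sdr(reference, estimate)                # about 20 dB for this toy pair
rtf = real_time_factor(lambda x: [0.5 * s for s in x], [0.0] * 16000)
```

Because the reference is rescaled before the error is measured, multiplying the estimate by a constant gain leaves the SI-SDR score unchanged.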
## Quick start
### Installation
```bash
# Python 3.10 recommended
pip install -r requirements.txt
# Use a recent PyTorch + CUDA build for multi-GPU training
```
### Data preparation
Expected layout:
```
<BASE_DIR>/
  train/clean/*.wav   train/noisy/*.wav
  valid/clean/*.wav   valid/noisy/*.wav
  test/clean/*.wav    test/noisy/*.wav
```
Defaults assume 16 kHz audio, centered frames, Hann windows, and a complex STFT representation (see `SpecsDataModule` for knobs).
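Before launching a long training run, it can help to sanity-check the directory layout. The sketch below is a hypothetical helper, not part of the repository. It assumes that each noisy file and its clean reference share the same filename, which the layout above suggests but does not state.

```python
from pathlib import Path

def find_pairs(base_dir, split):
    """Return (clean, noisy) path pairs for one split, matched by filename.

    Assumption: a noisy file and its clean reference share the same filename.
    """
    clean_dir = Path(base_dir) / split / "clean"
    noisy_dir = Path(base_dir) / split / "noisy"
    clean = {p.name: p for p in clean_dir.glob("*.wav")}
    noisy = {p.name: p for p in noisy_dir.glob("*.wav")}
    unmatched = sorted(set(clean) ^ set(noisy))    # present on one side only
    if unmatched:
        raise FileNotFoundError(f"{split}: files without a partner: {unmatched}")
    return [(clean[name], noisy[name]) for name in sorted(clean)]

def check_layout(base_dir, splits=("train", "valid", "test")):
    """Count matched pairs per split; raises if any file lacks a partner."""
    return {split: len(find_pairs(base_dir, split)) for split in splits}
```

Calling `check_layout("<BASE_DIR>")` returns a dictionary of pair counts per split. A count of zero usually means the path or the split names are wrong.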
### Training
**Single machine, multi-GPU (DDP)**:
```bash
# Edit DATA_DIR and GPUs inside the script if needed
bash train_vbd.sh
```
Or run directly:
```bash
torchrun --standalone --nproc_per_node=4 train.py \
--backbone ncsnpp \
--ode flowmatching \
--base_dir <BASE_DIR> \
--batch_size 2 \
--num_workers 8 \
--max_epochs 150 \
--precision 32 \
--gradient_clip_val 1.0 \
--t_eps 0.03 --T_rev 1.0 \
--sigma_min 0.0 --sigma_max 0.487 \
--use_mfse \
--mf_weight_final 0.25 \
--mf_warmup_frac 0.5 \
--mf_delta_gamma_start 8.0 --mf_delta_gamma_end 1.0 \
--mf_delta_warmup_frac 0.7 \
--mf_r_equals_t_prob 0.1 \
--mf_jvp_clip 5.0 --mf_jvp_eps 1e-3 \
--mf_jvp_impl fd --mf_jvp_chunk 1 \
--mf_skip_weight_thresh 0.05 \
--val_metrics_every_n_epochs 1 \
--default_root_dir lightning_logs
```
* **Logging & checkpoints** live under `lightning_logs/<exp_name>/version_x/`.
* Heavy validation (PESQ/ESTOI/SI-SDR) runs **every N epochs** on **rank-0**; placeholders are logged otherwise so checkpoint monitors remain valid.
### Inference
Use the helper script:
```bash
# MODE = multistep | multistep_mf | onestep
MODE=onestep STEPS=1 \
TEST_DATA_DIR=<BASE_DIR> \
CKPT_INPUT=path/to/best.ckpt \
bash run_inference.sh
```
Or call the evaluator:
```bash
python evaluate.py \
--test_dir <BASE_DIR> \
--folder_destination /path/to/output \
--ckpt path/to/best.ckpt \
--odesolver euler_mf \
--reverse_starting_point 1.0 \
--last_eval_point 0.0 \
--one_step
```
> `evaluate.py` writes **enhanced WAVs**.
> If `--odesolver` is not given, it **auto-picks** (`euler_mf` when MF-SE was used; otherwise `euler`).
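If you evaluate several checkpoints from Python, a small wrapper can assemble the same command line. This is a sketch only: `build_evaluate_command` is a hypothetical helper, the example paths are placeholders, and it uses only the flags shown above. Leaving `odesolver` as `None` omits the flag so that the auto-pick behaviour applies.

```python
import shlex

def build_evaluate_command(test_dir, out_dir, ckpt, odesolver=None, one_step=True):
    """Assemble an evaluate.py command line from the flags shown above."""
    cmd = [
        "python", "evaluate.py",
        "--test_dir", str(test_dir),
        "--folder_destination", str(out_dir),
        "--ckpt", str(ckpt),
        "--reverse_starting_point", "1.0",
        "--last_eval_point", "0.0",
    ]
    if odesolver is not None:
        cmd += ["--odesolver", odesolver]          # omit to let evaluate.py auto-pick
    if one_step:
        cmd.append("--one_step")
    return cmd

# Placeholder paths; replace with your own.
cmd = build_evaluate_command("data/vbd", "out/enhanced", "best.ckpt", odesolver="euler_mf")
print(shlex.join(cmd))
# To execute: import subprocess; subprocess.run(cmd, check=True)
```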
## Configuration
Common flags you may want to tweak:
* **Time & schedule**
  * `--T_rev` (reverse start, default 1.0), `--t_eps` (terminal time), `--sigma_min`, `--sigma_max`
* **MF-SE stability**
  * `--mf_jvp_impl {auto,fd,autograd}`, `--mf_jvp_chunk`, `--mf_jvp_clip`, `--mf_jvp_eps`
  * Curriculum: `--mf_weight_final`, `--mf_warmup_frac`, `--mf_delta_*`, `--mf_r_equals_t_prob`
* **Validation cost**
  * `--val_metrics_every_n_epochs`, `--num_eval_files`
* **Backbone & front-end**
  * Defined in `backbones/` and `SpecsDataModule` (STFT, transforms, normalization)
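The `fd` option of `--mf_jvp_impl` and the `--mf_jvp_eps` flag suggest a finite-difference Jacobian-vector product. The sketch below shows that general technique on a toy function in plain Python. It is an assumption about what the flag names refer to, not the code in `flowmse/model.py`, and `jvp_fd` and `toy` are hypothetical names.

```python
def jvp_fd(f, x, v, eps=1e-3):
    """Forward finite-difference JVP: (f(x + eps * v) - f(x)) / eps.

    Approximates J_f(x) @ v with two evaluations of f and no autograd graph.
    """
    x_shifted = [xi + eps * vi for xi, vi in zip(x, v)]
    f0, f1 = f(x), f(x_shifted)
    return [(b - a) / eps for a, b in zip(f0, f1)]

def toy(x):
    """Toy function with Jacobian [[2*x0, 0], [x1, x0]]."""
    return [x[0] ** 2, x[0] * x[1]]

# At x = (1, 2) with direction v = (1, 1), the exact product J @ v is (2, 3).
approx = jvp_fd(toy, [1.0, 2.0], [1.0, 1.0], eps=1e-3)
print(approx)
```

The step size trades truncation error (large `eps`) against floating-point cancellation (tiny `eps`), which is one reason a step size of this kind is typically exposed as a tunable parameter.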
## Repository structure
```
MeanFlowSE/
├── train.py                 # Lightning entry
├── evaluate.py              # Enhancement script (WAV out)
├── run_inference.sh         # One-step / few-step convenience runner
├── flowmse/
│   ├── model.py             # Losses, JVP, curriculum, logging
│   ├── odes.py              # Path definition & registry
│   ├── sampling/
│   │   ├── __init__.py
│   │   └── odesolvers.py    # Euler (instantaneous) & Euler-MF (displacement)
│   ├── backbones/
│   │   ├── ncsnpp.py        # U-Net w/ time & delta embeddings
│   │   └── ...
│   ├── data_module.py       # STFT I/O pipeline
│   └── util/                # metrics, registry, tensors, inference helpers
├── requirements.txt
└── scripts/
    └── train_vbd.sh
```
## Built upon & related work
This repository builds upon previous great works:
* **SGMSE**: [https://github.com/sp-uhh/sgmse](https://github.com/sp-uhh/sgmse)
* **SGMSE-CRP**: [https://github.com/sp-uhh/sgmse\_crp](https://github.com/sp-uhh/sgmse_crp)
* **SGMSE-BBED**: [https://github.com/sp-uhh/sgmse-bbed](https://github.com/sp-uhh/sgmse-bbed)
* **FLOWMSE (FlowSE)**: [https://github.com/seongq/flowmse](https://github.com/seongq/flowmse)
Many design choices (complex STFT pipeline, training infrastructure) are inspired by these excellent projects.
## Pretrained models
* **VoiceBank–DEMAND (16 kHz)**: The weight files are hosted on Google Drive: [Google Drive Link](https://drive.google.com/file/d/1QAxgd5BWrxiNi0q2qD3n1Xcv6bW0X86-/view?usp=sharing)
## Acknowledgments
We gratefully acknowledge **Prof. Xie Chen's group (X-LANCE Lab, SJTU)** for their **valuable guidance and support**; their advice on training practices and engineering greatly helped this work.
## Citation
* **Citation:** The paper is currently under review. We will add a BibTeX entry and article link once available.
**Questions or issues?** Please open a GitHub issue or pull request.
We welcome contributions, from bug fixes to new backbones and front-ends.