<div align="center">
# MixFlow Training: Alleviating Exposure Bias with Slowed Interpolation Mixture
[Hui Li](https://scholar.google.com/citations?user=QeQnG7IAAAAJ&hl=zh-CN)<sup>1</sup> · [Jiayue Lyu](https://scholar.google.com.hk/citations?user=Q4LVvegAAAAJ&hl=zh-CN)<sup>1</sup> · [Fu-Yun Wang](https://g-u-n.github.io/)<sup>2</sup> · [Kaihui Cheng](https://github.com/Kaihui-Cheng)<sup>1</sup>
[Siyu Zhu](https://sites.google.com/site/zhusiyucs/home)<sup>1,4,5</sup> · [Jingdong Wang](https://jingdongwang2017.github.io/)<sup>3</sup>
<sup>1</sup>Fudan University &emsp; <sup>2</sup>The Chinese University of Hong Kong &emsp; <sup>3</sup>Baidu &emsp;
<sup>4</sup>Shanghai Innovation Institute &emsp; <sup>5</sup>Shanghai Academy of AI for Science
[🌐 Project page](https://mixflowgen.github.io/) &ensp; [🤗 Models](https://huggingface.co/fudan-generative-ai/MixFlow) &ensp; [📄 Paper](https://arxiv.org/abs/2512.19311)
</div>
---
This is the official PyTorch implementation of **MixFlow**, a novel post-training approach for improving diffusion and flow matching models by alleviating the training-testing discrepancy (exposure bias).
## Method Overview
We present a novel training approach, named MixFlow, for improving generation quality after pretraining. Our approach is motivated by the Slow Flow phenomenon: at a given sampling timestep, the ground-truth interpolation nearest to the generated noisy data is observed to correspond to a higher-noise timestep (termed the slowed timestep); that is, the corresponding ground-truth timestep lags behind the sampling timestep. MixFlow leverages the interpolations at these slowed timesteps, named the slowed interpolation mixture, to post-train the prediction network at each training timestep.
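To make the timestep construction concrete, here is a dependency-free sketch (plain Python; the function name and variable names are ours, not the repository's API) that draws a training timestep `t` and its slowed counterpart `mt` the same way the MixFlow-RAE snippet in the Implementation section does:

```python
import math
import random

def sample_timesteps(gamma=0.4):
    """Draw a training timestep t and its slowed timestep mt.

    Illustrative only: mirrors the MixFlow-RAE diff, where `gamma` is
    the mixture range coefficient controlling the width of the slowed
    timestep interval.
    """
    u = random.random()
    t = 1 - math.sqrt(u)                        # reshaped timestep distribution
    mt = t + random.random() * gamma * (1 - t)  # slowed timestep in [t, t + gamma*(1 - t)]
    return t, mt
```

Note that `mt` always lies in an interval of width `gamma * (1 - t)` beyond `t`, so the mixture range shrinks as `t` approaches 1.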
### Implementation
The implementation is simple. For example, for MixFlow-RAE, 4 lines are added and 1 line is modified in the file `src/stage2/transport/transport.py`:
```diff
 def sample(self, x1):
     """Sampling x0 & t based on shape of x1 (if needed)

     Args:
         x1 - data point; [batch, *dim]
     """
     # ...
     if dist_options[0] == "uniform":
         t = th.rand((x1.shape[0],)) * (t1 - t0) + t0
+        t = 1 - th.sqrt(t)  # t = 1 - s with s ~ Beta(2,1)
     # ...
     return t, x0, x1

 def training_losses(
     self,
     model,
     x1,
+    gamma=0.4,  # mixture range coefficient
     model_kwargs=None,
 ):
     # ...
     t, x0, x1 = self.sample(x1)
-    t, xt, ut = self.path_sampler.plan(t, x0, x1)
+    t, _, ut = self.path_sampler.plan(t, x0, x1)  # discard xt; the slowed interpolation replaces it
+    mt = t + th.rand(*t.size(), device=t.device, dtype=t.dtype) * gamma * (1 - t)  # slowed timestep mt ~ U[t, t + gamma * (1 - t)]
+    _, xt, _ = self.path_sampler.plan(mt, x0, x1)  # compute the slowed interpolation
     model_output = model(xt, t, **model_kwargs)
```
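Put together, one MixFlow training term can be sketched end to end in a toy scalar setting. This assumes the linear interpolation `x_t = (1 - t) * x0 + t * x1`; the repository's `path_sampler` may use a different parameterization, and `predict_velocity` is a stand-in for the network, not a real API:

```python
import random

def mixflow_loss_sketch(x0, x1, predict_velocity, gamma=0.4):
    """One MixFlow training term for a scalar linear-interpolation toy.

    The network input is the interpolation at the slowed timestep mt,
    while the conditioning timestep and regression target stay at t.
    """
    t = random.random()
    ut = x1 - x0                                # ground-truth velocity of the linear path
    mt = t + random.random() * gamma * (1 - t)  # slowed timestep
    xt = (1 - mt) * x0 + mt * x1                # slowed interpolation fed to the network
    v = predict_velocity(xt, t)                 # conditioned on t, not mt
    return (v - ut) ** 2
```

A perfect velocity predictor (one returning `x1 - x0` everywhere) drives this term to zero, as in standard flow matching; the mixture only changes where the network is evaluated, not what it must predict.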
## Repository Structure
This repository includes four folders. Each folder provides the training scripts, inference pipelines, and model weights for one of the following configurations:
- MixFlow + RAE (Folder: `MixFlow-RAE`)
- MixFlow + REPA (Folder: `MixFlow-REPA`)
- MixFlow + SiT (Folder: `MixFlow-SiT`)
- MixFlow + SD3.5-M (Folder: `MixFlow-SD3.5`) (TBD)
## Results
### ImageNet 256x256
| Model | Params | FID (w/o cfg) | FID (w/ cfg) | Checkpoint |
|:------|:------:|:-------------:|:------------:|:----------:|
| MixFlow + SiT-XL | 675M | 7.56 | 1.97 | [Download](https://huggingface.co/fudan-generative-ai/MixFlow) |
| MixFlow + REPA-XL | 675M | 5.00 | 1.22 | [Download](https://huggingface.co/fudan-generative-ai/MixFlow) |
| MixFlow + RAE-XL | 839M | **1.43** | **1.10** | [Download](https://huggingface.co/fudan-generative-ai/MixFlow/) |
### ImageNet 512x512
| Model | Params | FID (w/o cfg) | FID (w/ cfg) | Checkpoint |
|:------|:------:|:-------------:|:------------:|:----------:|
| MixFlow + RAE-XL | 839M | **1.55** | **1.10** | [Download](https://huggingface.co/fudan-generative-ai/MixFlow) |
## Getting Started
### 1. Environment Setup
To set up our environment, please run:
```bash
git clone https://github.com//MixFlow.git
cd MixFlow
# Using conda
conda create -n mixflow python=3.10 -y
conda activate mixflow
# Or using venv from the Python standard library (optional)
python3.10 -m venv .venv
source .venv/bin/activate
# Install uv
pip install uv
# Install PyTorch 2.8.0 with CUDA 12.4
uv pip install torch==2.8.0 torchvision==0.23.0 --index-url https://download.pytorch.org/whl/cu124
# Install other dependencies
uv pip install timm==0.9.16 accelerate==0.23.0 torchdiffeq==0.2.5 wandb
uv pip install "numpy<2" transformers einops omegaconf diffusers requests ftfy regex
```
### 2. Data and Model Download
Data preparation (ImageNet-1k) and pretrained model downloads should follow the settings detailed in the RAE documentation [here](https://github.com/bytetriper/RAE?tab=readme-ov-file#download-pre-trained-models).
The pretrained model comprises two parts: the encoder-decoder pair and the DiTDH-XL. After downloading these checkpoints into the `MixFlow-RAE/models` folder, the training step below uses the pretrained DiTDH-XL as its starting checkpoint.
### 3. Post-train the RAE model with MixFlow
```bash
cd MixFlow-RAE
torchrun --standalone --nnodes=1 --nproc_per_node=N \
src/train.py \
--config configs/stage2/training/ImageNet256/DiTDH-XL_DINOv2-B.yaml \
--data-path <imagenet_train_split> \
--results-dir results/mixflow \
--precision fp32 \
--ckpt models/DiTs/Dinov2/wReg_base/ImageNet256/DiTDH-XL/stage2_model.pt # load the pretrained model
```
### 4. Evaluation
For distributed sampling, please run:
```bash
torchrun --standalone --nnodes=1 --nproc_per_node=N \
src/sample_ddp.py \
--config <sample_config> \
--sample-dir samples \
--precision fp32 \
--label-sampling equal
```
Note that we utilize an autoguidance scale of 1.5 for both 256 and 512 resolutions. This differs from the original RAE settings, which use 1.42 for 256 resolution and 1.5 for 512 resolution.
After generating 50k samples, evaluate the results using the ADM evaluation suite. For detailed instructions, please refer to the RAE documentation [here](https://github.com/bytetriper/RAE?tab=readme-ov-file#evaluation).
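If you need to assemble a sample archive yourself, the ADM evaluation scripts read a single `.npz` file whose array is stored under NumPy's default key `arr_0`. A minimal packing sketch (NumPy only; the helper name is ours, and it assumes samples are already HxWx3 uint8 arrays):

```python
import numpy as np

def pack_for_adm(images, out_path="samples.npz"):
    """Stack HxWx3 uint8 images into one [N, H, W, 3] array and save as .npz.

    np.savez stores a positional array under the key 'arr_0', which is
    what the ADM evaluation suite loads. Helper name is illustrative.
    """
    arr = np.stack(images).astype(np.uint8)
    assert arr.ndim == 4 and arr.shape[-1] == 3, "expected [N, H, W, 3]"
    np.savez(out_path, arr)
    return arr.shape
```

Check the RAE/ADM documentation for the expected sample count (50k here) and reference batch before running the FID computation.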
## Citation
If you find this work useful, please cite:
```bibtex
@article{mixflow2025,
title={MixFlow Training: Alleviating Exposure Bias with Slowed Interpolation Mixture},
  author={Hui Li and Jiayue Lyu and Fu-Yun Wang and Kaihui Cheng and Siyu Zhu and Jingdong Wang},
journal={arXiv preprint arXiv: 2512.19311},
year={2025}
}
```
## Acknowledgements
This codebase builds upon the following excellent works:
- [SiT](https://github.com/willisma/SiT) - Scalable Interpolant Transformers
- [REPA](https://github.com/sihyun-yu/REPA) - Representation Alignment
- [RAE](https://github.com/bytetriper/RAE) - Representation Autoencoders
- [diffusion-pipe](https://github.com/tdrussell/diffusion-pipe) - A pipeline parallel training script for diffusion models.
---
license: mit
---