| <div align="center"> | |
| # MixFlow Training: Alleviating Exposure Bias with Slowed Interpolation Mixture | |
| [Hui Li](https://scholar.google.com/citations?user=QeQnG7IAAAAJ&hl=zh-CN)<sup>1</sup> · [Jiayue Lyu](https://scholar.google.com.hk/citations?user=Q4LVvegAAAAJ&hl=zh-CN)<sup>1</sup> · [Fu-Yun Wang](https://g-u-n.github.io/)<sup>2</sup> · [Kaihui Cheng](https://github.com/Kaihui-Cheng)<sup>1</sup> | |
| [Siyu Zhu](https://sites.google.com/site/zhusiyucs/home)<sup>1,4,5</sup> · [Jingdong Wang](https://jingdongwang2017.github.io/)<sup>3</sup> | |
| <sup>1</sup>Fudan University   <sup>2</sup>The Chinese University of Hong Kong   <sup>3</sup>Baidu   | |
| <sup>4</sup>Shanghai Innovation Institute   <sup>5</sup>Shanghai Academy of AI for Science | |
| [🌐 Project page](https://mixflowgen.github.io/)   [🤗 Models](https://huggingface.co/fudan-generative-ai/MixFlow)   [📄 Paper](https://arxiv.org/abs/2512.19311) | |
| </div> | |
| --- | |
| This is the official PyTorch implementation of **MixFlow**, a novel post-training approach for improving diffusion and flow matching models by alleviating the training-testing discrepancy (exposure bias). | |
| ## Method Overview | |
| We present MixFlow, a training approach that reduces exposure bias in diffusion and flow matching models. It is motivated by the Slow Flow phenomenon: at a given sampling timestep, the ground-truth interpolation nearest to the generated noisy data is observed to correspond to a higher-noise timestep (termed the slowed timestep). In other words, the corresponding ground-truth timestep is slower than the sampling timestep. MixFlow post-trains the prediction network at each training timestep on interpolations taken at these slowed timesteps, which we call the slowed interpolation mixture. | |
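The idea can be illustrated on a toy vector. The sketch below is not the repository code: it assumes a linear path `x_t = (1 - t) * x1 + t * x0` in which `t = 1` is pure noise, and the helper name `slowed_interpolation` is hypothetical. It only shows that the network input is built at a noisier timestep `mt` while the network is still conditioned on `t`.

```python
import random


def slowed_interpolation(x0, x1, t, gamma=0.4, rng=random):
    """Return (mt, x_mt) for one sample.

    Assumptions (not taken from the repository): a linear path
    x_t = (1 - t) * x1 + t * x0, where t = 1 means pure noise.
    """
    # Slowed timestep: uniform on [t, t + gamma * (1 - t)], so never less noisy than t.
    mt = t + rng.random() * gamma * (1.0 - t)
    # Interpolate between data x1 and noise x0 at the slowed timestep.
    x_mt = [(1.0 - mt) * d + mt * n for d, n in zip(x1, x0)]
    return mt, x_mt


rng = random.Random(0)
x1 = [1.0, -2.0, 0.5]                    # toy "data" vector
x0 = [rng.gauss(0.0, 1.0) for _ in x1]   # Gaussian noise
t = 0.3                                  # training timestep given to the network
mt, x_mt = slowed_interpolation(x0, x1, t, gamma=0.4, rng=rng)

# The network would receive (x_mt, t): an input at least as noisy as x_t, labelled with t.
print(f"t={t:.2f}, mt={mt:.3f}, x_mt={x_mt}")
```

With `gamma = 0` the slowed timestep collapses to `t`, which recovers the ordinary pairing of `x_t` with `t`.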
| ### Implementation | |
| The implementation is simple. For example, for MixFlow-RAE, 4 lines are added, and 1 line is modified in the file `src/stage2/transport/transport.py`: | |
| ```diff | |
| def sample(self, x1): | |
|     """Sampling x0 & t based on shape of x1 (if needed) | |
|     Args: | |
|       x1 - data point; [batch, *dim] | |
|     """ | |
|     # ... | |
|     if dist_options[0] == "uniform": | |
|         t = th.rand((x1.shape[0],)) * (t1 - t0) + t0 | |
| +       t = 1 - th.sqrt(t)  # for t0=0, t1=1: (1 - t) ~ Beta(2,1) | |
|     # ... | |
|     return t, x0, x1 | |
| def training_losses( | |
|     self, | |
|     model, | |
|     x1, | |
| +   gamma=0.4,  # mixture range coefficient | |
|     model_kwargs=None | |
| ): | |
|     # ... | |
|     t, x0, x1 = self.sample(x1) | |
| +   t, _, ut = self.path_sampler.plan(t, x0, x1)  # modified line: discard xt here; xt is recomputed below as the slowed interpolation | |
| +   mt = t + th.rand(*t.size(), device=t.device, dtype=t.dtype) * gamma * (1 - t)  # slowed timestep: mt ~ U[t, t + gamma * (1 - t)], i.e. (1 - mt) ~ U[(1 - gamma)(1 - t), 1 - t] | |
| +   _, xt, __ = self.path_sampler.plan(mt, x0, x1)  # compute the slowed interpolation at mt | |
|     model_output = model(xt, t, **model_kwargs) | |
| ``` | |
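To see what the two sampling lines do to the timesteps, here is a small standard-library simulation. It is a sketch rather than the repository code: it takes `t0 = 0` and `t1 = 1`, and `sample_timesteps` is a hypothetical helper. It checks that `t = 1 - sqrt(u)` has mean 1/3 and that `mt` always lies in `[t, 1]` with mean `(1 + gamma) / 3`.

```python
import math
import random


def sample_timesteps(n, gamma=0.4, seed=0):
    """Draw n (t, mt) pairs using the same formulas as the diff above."""
    rng = random.Random(seed)
    pairs = []
    for _ in range(n):
        u = rng.random()                            # u ~ U[0, 1), assuming t0 = 0 and t1 = 1
        t = 1.0 - math.sqrt(u)                      # density 2 * (1 - t) on (0, 1]
        mt = t + rng.random() * gamma * (1.0 - t)   # uniform on [t, t + gamma * (1 - t)]
        pairs.append((t, mt))
    return pairs


gamma = 0.4
pairs = sample_timesteps(100_000, gamma=gamma)
mean_t = sum(t for t, _ in pairs) / len(pairs)
mean_mt = sum(mt for _, mt in pairs) / len(pairs)

# E[t] = 1 - E[sqrt(u)] = 1/3, and E[mt] = E[t] + (gamma / 2) * E[1 - t] = (1 + gamma) / 3.
print(f"mean t  = {mean_t:.3f} (theory {1 / 3:.3f})")
print(f"mean mt = {mean_mt:.3f} (theory {(1 + gamma) / 3:.3f})")
```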
| ### Repository Structure | |
| This repository includes four folders. Each folder provides the training scripts, inference pipelines, and model weights for one of the following configurations: | |
| - MixFlow + RAE (Folder: `MixFlow-RAE`) | |
| - MixFlow + REPA (Folder: `MixFlow-REPA`) | |
| - MixFlow + SiT (Folder: `MixFlow-SiT`) | |
| - SD3.5-M + MixFlow (Folder: `MixFlow-SD3.5`) (TBD) | |
| ## Results | |
| ### ImageNet 256x256 | |
| | Model | Params | FID (w/o cfg) | FID (w/ cfg) | Checkpoint | | |
| |:------|:------:|:-------------:|:------------:|:----------:| | |
| | MixFlow + SiT-XL | 675M | 7.56 | 1.97 | [Download](https://huggingface.co/fudan-generative-ai/MixFlow) | | |
| | MixFlow + REPA-XL | 675M | 5.00 | 1.22 | [Download](https://huggingface.co/fudan-generative-ai/MixFlow) | | |
| | MixFlow + RAE-XL | 839M | **1.43** | **1.10** | [Download](https://huggingface.co/fudan-generative-ai/MixFlow/) | | |
| ### ImageNet 512x512 | |
| | Model | Params | FID (w/o cfg) | FID (w/ cfg) | Checkpoint | | |
| |:------|:------:|:-------------:|:------------:|:----------:| | |
| | MixFlow + RAE-XL | 839M | **1.55** | **1.10** | [Download](https://huggingface.co/fudan-generative-ai/MixFlow) | | |
| ## Getting Started | |
| ### 1. Environment Setup | |
| To set up our environment, please run: | |
| ```bash | |
| git clone https://github.com//MixFlow.git | |
| cd MixFlow | |
| # Using conda | |
| conda create -n mixflow python=3.10 -y | |
| conda activate mixflow | |
| # Or using venv from the Python standard library (optional) | |
| python3.10 -m venv .venv | |
| source .venv/bin/activate | |
| # Install uv | |
| pip install uv | |
| # Install PyTorch 2.8.0 with CUDA 12.4 | |
| uv pip install torch==2.8.0 torchvision==0.23.0 --index-url https://download.pytorch.org/whl/cu124 | |
| # Install other dependencies | |
| uv pip install timm==0.9.16 accelerate==0.23.0 torchdiffeq==0.2.5 wandb | |
| uv pip install "numpy<2" transformers einops omegaconf diffusers requests ftfy regex | |
| ``` | |
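After installation, it can help to confirm that the pinned versions were actually resolved. The following is a minimal sketch using only the standard library. The pins are copied from the commands above, while the helper name `check_pins` is hypothetical and not part of this repository.

```python
from importlib import metadata

# Pins copied from the install commands above.
PINNED = {
    "torch": "2.8.0",
    "torchvision": "0.23.0",
    "timm": "0.9.16",
    "accelerate": "0.23.0",
    "torchdiffeq": "0.2.5",
}


def check_pins(pinned=PINNED):
    """Return {package: (expected, installed_or_None)} for every mismatch."""
    mismatches = {}
    for name, expected in pinned.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        # Wheels from the PyTorch index carry a local tag such as "+cu124"; ignore it.
        if installed is None or installed.split("+")[0] != expected:
            mismatches[name] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    problems = check_pins()
    for name, (expected, installed) in problems.items():
        print(f"{name}: expected {expected}, found {installed}")
    if not problems:
        print("All pinned packages match.")
```

Run it inside the activated environment; a package reported with `found None` is not installed there.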
| ### 2. Data and Model Download | |
| For data preparation (ImageNet-1k) and pretrained model download, follow the RAE documentation [here](https://github.com/bytetriper/RAE?tab=readme-ov-file#download-pre-trained-models). | |
| The pretrained model consists of two parts: the encoder/decoder and the DiTDH-XL. Download these checkpoints into the `MixFlow-RAE/models` folder. The training step below uses the pretrained DiTDH-XL as its starting checkpoint. | |
| ### 3. Post-train the RAE model with MixFlow | |
| ```bash | |
| cd MixFlow-RAE | |
| torchrun --standalone --nnodes=1 --nproc_per_node=N \ | |
| src/train.py \ | |
| --config configs/stage2/training/ImageNet256/DiTDH-XL_DINOv2-B.yaml \ | |
| --data-path <imagenet_train_split> \ | |
| --results-dir results/mixflow \ | |
| --precision fp32 \ | |
| --ckpt models/DiTs/Dinov2/wReg_base/ImageNet256/DiTDH-XL/stage2_model.pt # load the pretrained model | |
| ``` | |
| ### 4. Evaluation | |
| For distributed sampling, please run: | |
| ```bash | |
| torchrun --standalone --nnodes=1 --nproc_per_node=N \ | |
| src/sample_ddp.py \ | |
| --config <sample_config> \ | |
| --sample-dir samples \ | |
| --precision fp32 \ | |
| --label-sampling equal | |
| ``` | |
| Note that we use an autoguidance scale of 1.5 for both the 256 and 512 resolutions. This differs from the original RAE settings, which use 1.42 for the 256 resolution and 1.5 for the 512 resolution. | |
| After generating 50k samples, evaluate the results using the ADM evaluation suite. For detailed instructions, please refer to the RAE documentation [here](https://github.com/bytetriper/RAE?tab=readme-ov-file#evaluation). | |
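As a rough illustration of class-balanced sampling, the sketch below builds a label list in which each of the 1,000 ImageNet-1k classes appears equally often, giving 50 labels per class for 50k samples. This is an assumption about what `--label-sampling equal` is meant to achieve; the actual behavior is defined by `src/sample_ddp.py`, and `equal_labels` is a hypothetical helper.

```python
def equal_labels(num_samples=50_000, num_classes=1000):
    """Class-balanced label list: every class appears the same number of times."""
    if num_samples % num_classes != 0:
        raise ValueError("num_samples must be divisible by num_classes")
    per_class = num_samples // num_classes
    return [c for c in range(num_classes) for _ in range(per_class)]


labels = equal_labels()
print(len(labels), labels.count(0), labels.count(999))  # prints "50000 50 50"
```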
| ## Citation | |
| If you find this work useful, please cite: | |
| ```bibtex | |
| @article{mixflow2025, | |
| title={MixFlow Training: Alleviating Exposure Bias with Slowed Interpolation Mixture}, | |
| author={Hui Li and Jiayue Lyu and Fu-Yun Wang and Kaihui Cheng and Siyu Zhu and Jingdong Wang}, | |
| journal={arXiv preprint arXiv:2512.19311}, | |
| year={2025} | |
| } | |
| ``` | |
| ## Acknowledgements | |
| This codebase builds upon the following excellent works: | |
| - [SiT](https://github.com/willisma/SiT) - Scalable Interpolant Transformers | |
| - [REPA](https://github.com/sihyun-yu/REPA) - Representation Alignment | |
| - [RAE](https://github.com/bytetriper/RAE) - Representation Autoencoders | |
| - [diffusion-pipe](https://github.com/tdrussell/diffusion-pipe) - A pipeline parallel training script for diffusion models. | |
| --- | |
| license: mit | |
| --- | |