---
license: mit
datasets:
- uoft-cs/cifar10
- marcosv/ffhq-dataset
- jiachenlei/imagenet
- huggan/AFHQv2
- laion/laion-coco
language:
- en
base_model:
- CompVis/stable-diffusion-v1-4
pipeline_tag: text-to-image
---
# Learning to Discretize Denoising Diffusion ODEs

πŸ† ![ICLR2025 Oral](https://img.shields.io/badge/ICLR2025-Oral-blue)

### [Paper on OpenReview](https://openreview.net/forum?id=xDrFWUmCne)

Implementation of LD3, a lightweight framework that learns the optimal time discretization for sampling from pre-trained Diffusion Probabilistic Models (DPMs). LD3 can be combined with various samplers and consistently improves generation quality without retraining the resource-intensive underlying networks, offering an efficient route to fast sampling from pre-trained diffusion models.

![Illustration of LD3](visualizations/illustration-lddd.png)
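
The core idea can be summarized in a few lines: treat the sampling timesteps themselves as learnable parameters, and optimize them so that a few-step "student" solve matches a many-step "teacher" solve. Below is a conceptual sketch only, with a toy Euler solver and a stand-in denoiser, not the repository's code; the actual training optimizes a perceptual (LPIPS-type) objective, while the sketch uses plain MSE for brevity:

```python
import torch

def sample_ode(denoiser, x_T, timesteps):
    """Toy first-order (Euler) ODE solve along the given time schedule."""
    x = x_T
    for t_cur, t_next in zip(timesteps[:-1], timesteps[1:]):
        x = x + (t_next - t_cur) * denoiser(x, t_cur)
    return x

denoiser = lambda x, t: -x              # stand-in for a real pre-trained model
x_T = torch.randn(4, 3, 32, 32)         # shared initial noise

# Teacher: many fine steps; its output is the regression target.
with torch.no_grad():
    teacher = sample_ode(denoiser, x_T, torch.linspace(1.0, 1e-3, 21))

# Student: few steps, with the schedule itself as the learnable parameter.
raw = torch.zeros(10, requires_grad=True)   # 10 student steps
opt = torch.optim.Adam([raw], lr=1e-2)

for _ in range(100):
    # Softmax increments keep the schedule monotone decreasing on [1e-3, 1].
    deltas = torch.softmax(raw, dim=0)
    ts = torch.cat([torch.ones(1), 1.0 - (1.0 - 1e-3) * torch.cumsum(deltas, 0)])
    loss = torch.mean((sample_ode(denoiser, x_T, ts) - teacher) ** 2)
    opt.zero_grad()
    loss.backward()
    opt.step()
```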

## 🔥 Latest News

- **March 2025**: We have successfully applied **LD3** to the **Flux-dev** model and observed promising results.  
- We will release the trained time steps for the Flux model soon. Stay tuned for updates.


## Setup Environment

We will set up the environment using [Anaconda](https://docs.anaconda.com/anaconda/install/index.html). 

```bash
conda env create -f requirements.yml
conda activate ld3
pip install -e ./src/clip/
pip install -e ./src/taming-transformers/
pip install omegaconf PyYAML requests scipy torchmetrics
```

## Download Pretrained Models and FID Reference Sets

The script below downloads all necessary models and FID reference sets automatically. Note that this may take some time; if you wish to skip certain downloads, comment out the corresponding lines in the script.

```bash
bash scripts/download_model.sh
wget https://raw.githubusercontent.com/tylin/coco-caption/master/annotations/captions_val2014.json
```


## 🚀 Generating Training Data for LD3

Before training **LD3**, we first need to generate training data with the teacher solver. The script `gen_data.py` handles this process. Below is an example that generates training data with **20 sampling steps** for **CIFAR-10**, using the `uni_pc` solver and `edm` time discretization.

### 📌 Example: Generating CIFAR-10 Training Data

```bash
CUDA_VISIBLE_DEVICES=0 python3 gen_data.py \
                    --all_config configs/cifar10.yml \
                    --total_samples 100 \
                    --sampling_batch_size 10 \
                    --steps 20 \
                    --solver_name uni_pc \
                    --skip_type edm \
                    --save_pt --save_png --data_dir train_data/train_data_cifar10 \
                    --low_gpu
```
#### 📌 Key Arguments:

- `all_config`: Path to the default configuration file (mandatory). If other arguments are not specified, their values will be taken from this file.
- `solver_name`: Solver to use. Options include `uni_pc`, `dpm_solver++`, `euler`, and `ipndm`.
- `skip_type`: Discretization method. Options include `edm`, `time_uniform`, and `time_quadratic`.
- `low_gpu`: Enables PyTorch's `checkpoint` feature (gradient checkpointing) to reduce GPU memory usage at the cost of extra compute; see the sketch after this list.
- `data_dir`: Root directory for saving the generated data. The script creates a subdirectory within this path using the naming format `${solver_name}_NFE${steps}_${skip_type}`; the command above therefore writes to `train_data/train_data_cifar10/uni_pc_NFE20_edm`.
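
Below is a generic, self-contained illustration of the gradient-checkpointing mechanism that `low_gpu` refers to; the exact usage inside `gen_data.py` may differ:

```python
import torch
from torch.utils.checkpoint import checkpoint

# Activations inside the checkpointed call are recomputed during backward
# instead of being stored, trading extra compute for lower peak memory.
net = torch.nn.Sequential(
    torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 64)
)
x = torch.randn(8, 64, requires_grad=True)
y = checkpoint(net, x, use_reentrant=False)  # forward without storing activations
y.sum().backward()                           # activations recomputed here
```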

### 📌 Example: Generating Stable Diffusion Training Data

For Stable Diffusion, you must additionally specify the prompt file and the number of prompts. Below is an example:

```bash
CUDA_VISIBLE_DEVICES=0 python3 gen_data.py \
                    --all_config configs/stable_diff_v1-4.yml \
                    --total_samples 100 \
                    --sampling_batch_size 2 \
                    --steps 6 \
                    --solver_name uni_pc \
                    --skip_type time_uniform \
                    --save_pt --save_png --data_dir train_data/train_data_stable_diff_v1-4 \
                    --low_gpu \
                    --num_prompts 5 --prompt_path captions_val2014.json
```
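
If you want to sanity-check the prompt file, note that the COCO captions JSON stores its prompts under the `annotations` key. The snippet below only inspects the file; how `gen_data.py` samples prompts from it is determined by the script itself:

```python
import json

# Inspect the COCO captions file used as the prompt source.
with open("captions_val2014.json") as f:
    coco = json.load(f)

print(len(coco["annotations"]))           # total number of captions
print(coco["annotations"][0]["caption"])  # an example prompt
```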

## Training LD3
After generating the training data, you can train **LD3** using the `main.py` script. Below is an example of training **LD3** on **CIFAR-10** with the following configurations:
- **Teacher**: 20 sampling steps, `uni_pc` solver, and `edm` time discretization.
- **Student**: 10 sampling steps, `dpm_solver++` solver.

```bash
CUDA_VISIBLE_DEVICES=0 python3 main.py \
                    --all_config configs/cifar10.yml \
                    --data_dir train_data/train_data_cifar10/uni_pc_NFE20_edm \
                    --num_train 50 --num_valid 50 \
                    --main_train_batch_size 1 \
                    --main_valid_batch_size 10 \
                    --solver_name dpm_solver++ \
                    --training_rounds_v1 2 \
                    --training_rounds_v2 5 \
                    --steps 10 \
                    --log_path logs/logs_cifar10
```

**Trained timesteps are available [here](https://docs.google.com/spreadsheets/d/1nUrTDvvtpPHZuRuJcn3zzxGmVKrNX4fFHu8wYIIoGSM/edit?usp=sharing) and are still being updated.**



#### 📌 Key Arguments:

- `data_dir`: The full path to the generated training-data subdirectory (not the root directory passed during data generation).
- `log_path`: The root directory for saving logs and models. The script creates a subdirectory within this path using the naming format `${solver_name}-N${steps}-b${bound}-${loss_type}-lr2${lr2}rv1${rv1}-rv2${rv2}`, for example, `uni_pc-N10-b0.03072-LPIPS-lr20.01rv12-rv25`.

## FID Evaluation


### ⚠️ Different FID Scores

It is important to note that FID (Fréchet Inception Distance) scores can vary significantly depending on the image-processing pipeline used. To ensure transparency and reproducibility, our framework provides a script, `compute_fid.py`, that supports FID evaluation for both EDM and Latent Diffusion models.
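
For reference, the Fréchet distance itself is standard; the variation comes mostly from how images are resized and encoded before Inception features are extracted. A minimal sketch of the distance computation from feature statistics, not necessarily the exact code path in `compute_fid.py`:

```python
import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    """FID between two Gaussians fitted to Inception features."""
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
    covmean = covmean.real  # drop tiny imaginary parts from the matrix sqrt
    return float(diff @ diff + np.trace(sigma1 + sigma2 - 2.0 * covmean))
```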

### 📌 How FID Evaluation Works

The `compute_fid.py` script is a streamlined version of `gen_data.py`, with a few differences:

- The `--save_dir`, `--save_pt`, and `--save_png` arguments are ignored: generated images are processed for FID directly, without being saved.
- The data is automatically forwarded to the FID computation module to extract features.
- Optionally, you can pass your own timesteps via `--custom_ts_1` and `--custom_ts_2`. If `custom_ts_2` is not specified, it defaults to the value of `custom_ts_1`.

### 📌 Example: Computing FID for Stable Diffusion

```bash 
CUDA_VISIBLE_DEVICES=0 python3 compute_fid.py \
                    --all_config configs/stable_diff_v1-5.yml \
                    --total_samples 100 \
                    --sampling_batch_size 2 \
                    --steps 6 \
                    --solver_name uni_pc \
                    --skip_type time_uniform \
                    --low_gpu \
                    --num_prompts 5 --prompt_path captions_val2014.json

CUDA_VISIBLE_DEVICES=0 python3 compute_fid.py \
                    --all_config configs/stable_diff_v1-5.yml \
                    --total_samples 100 \
                    --sampling_batch_size 2 \
                    --steps 4 \
                    --solver_name ipndm \
                    --skip_type custom \
                    --custom_ts_1  [1.0000e+00,7.6668e-01,4.8113e-01,1.8417e-01,1.0000e-03] \
                    --custom_ts_2  [1.0000e+00,7.6706e-01,4.8103e-01,1.8396e-01,1.0000e-03] \
                    --low_gpu \
                    --num_prompts 5 --prompt_path captions_val2014.json

```
## Citation 

```bibtex
@inproceedings{tong2024learning,
  title     = {Learning to Discretize Denoising Diffusion ODEs},
  author    = {Tong, Vinh and Hoang, Trung-Dung and Liu, Anji and Van den Broeck, Guy and Niepert, Mathias},
  booktitle = {Proceedings of the 13th International Conference on Learning Representations},
  year      = {2025}
}
```

```bibtex
@article{tong2024learning,
  title={Learning to Discretize Denoising Diffusion ODEs},
  author={Tong, Vinh and Hoang, Trung-Dung and Liu, Anji and Van den Broeck, Guy and Niepert, Mathias},
  journal={arXiv preprint arXiv:2405.15506},
  year={2024}
}
```

## License
MIT