---
license: mit
datasets:
- hzzheng/MaskDiT-imagenet
tags:
- code
---

# Fast Training of Diffusion Models with Masked Transformers
This repository hosts the model checkpoints for the PyTorch implementation of [MaskDiT](https://github.com/Anima-Lab/MaskDiT.git).

Paper:<br>
**[Fast Training of Diffusion Models with Masked Transformers](https://openreview.net/pdf?id=vTBjBtGioE)**
<br>
Hongkai Zheng*, Weili Nie*, Arash Vahdat, Anima Anandkumar <br>
(*Equal contribution)<br>

Abstract: *While masked transformers have been extensively explored for representation learning, their application to
generative learning is less explored in the vision domain. Our work is the first to exploit masked training to reduce
the training cost of diffusion models significantly. Specifically, we randomly mask out a high proportion (e.g., 50%) of
patches in diffused input images during training. For masked training, we introduce an asymmetric encoder-decoder
architecture consisting of a transformer encoder that operates only on unmasked patches and a lightweight transformer
decoder on full patches. To promote a long-range understanding of full patches, we add an auxiliary task of
reconstructing masked patches to the denoising score matching objective that learns the score of unmasked patches.
Experiments on ImageNet-256x256 and ImageNet-512x512 show that our approach achieves competitive and even better
generative performance than the state-of-the-art Diffusion Transformer (DiT) model, using only around 30% of its
original training time. Thus, our method shows a promising way of efficiently training large transformer-based diffusion
models without sacrificing the generative performance.*
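The masked-training recipe described in the abstract can be sketched in a few lines of NumPy. This is a simplified illustration only: the real model operates on latent patches with learned mask tokens, and the `mask_patches` helper below is hypothetical, not part of the repository's API.

```python
import numpy as np

rng = np.random.default_rng(0)

def mask_patches(patch_tokens, mask_ratio=0.5):
    """Randomly split patch tokens into visible and masked sets.

    The transformer encoder would run only on the visible tokens;
    the lightweight decoder later operates on all positions
    (encoder outputs plus mask tokens).
    """
    n = patch_tokens.shape[0]
    n_keep = int(round(n * (1.0 - mask_ratio)))
    perm = rng.permutation(n)
    keep_idx = np.sort(perm[:n_keep])
    mask_idx = np.sort(perm[n_keep:])
    return patch_tokens[keep_idx], keep_idx, mask_idx

# e.g. a 256-patch latent image with 768-dim patch embeddings
tokens = rng.standard_normal((256, 768))
visible, keep_idx, mask_idx = mask_patches(tokens, mask_ratio=0.5)
# visible.shape == (128, 768): the encoder sees only half the patches,
# which is where the training-cost savings come from.
```

With a 50% mask ratio, the expensive encoder processes half as many tokens per step, while the auxiliary reconstruction loss on the masked positions keeps the model aware of the full image.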



## Requirements

- Training MaskDiT on ImageNet-256x256 takes around 260 hours with 8 A100 GPUs to perform 2M updates with a batch size of
  1024.
- Training MaskDiT on ImageNet-512x512 takes around 210 A100 GPU days to perform 1M updates with a batch size of 1024.
- At least one high-end GPU is needed for sampling.
- A [Dockerfile](Dockerfile) is provided for the exact software environment.

## Efficiency

Our MaskDiT applies Automatic Mixed Precision (AMP) by default. We also include MaskDiT without AMP (Ours_ft32) for
reference.

### Training speed

<img src="assets/figs/bar_speed_256.png" width=45% style="display: inline-block;"><img src="assets/figs/bar_speed_512.png" width=46%  style="display: inline-block;">

### GPU memory

<img src="assets/figs/bar_mem_256.png" width=45% style="display: inline-block;"><img src="assets/figs/bar_mem_512.png" width=44.3%  style="display: inline-block;">

## Pretrained Models
We provide pretrained MaskDiT models for ImageNet-256x256 and ImageNet-512x512 in the table below. For FID with guidance, the guidance scale is set to 1.5 by default.
| Guidance | Resolution | FID   | Model                                                                                                                     |
| :------- | :--------- | :---- | :------------------------------------------------------------------------------------------------------------------------ |
| Yes      | 256x256    | 2.28  | [imagenet256-guidance.pt](checkpoints/imagenet256_with_guidance.pt)       |
| No       | 256x256    | 5.69  | [imagenet256-conditional.pt](checkpoints/imagenet256_without_guidance.pt) |
| Yes      | 512x512    | 2.50  | [imagenet512-guidance.pt](checkpoints/imagenet512_with_guidance.pt)                                   |
| No       | 512x512    | 10.79 | [imagenet512-conditional.pt](checkpoints/imagenet512_without_guidance.pt) |

## Generate from pretrained models

To generate samples from the provided checkpoints, run

```bash
python3 generate.py --config configs/test/maskdit-512.yaml --ckpt_path [path to checkpoints] --class_idx [class index from 0-999] --cfg_scale [guidance scale]
```
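The `--cfg_scale` flag controls classifier-free guidance (CFG). For reference, the standard combination rule can be sketched as follows; this is an illustrative snippet, not code from this repository:

```python
import numpy as np

def cfg_combine(eps_uncond, eps_cond, cfg_scale):
    """Classifier-free guidance: move the prediction from the unconditional
    output toward (and past) the conditional one by the guidance scale."""
    return eps_uncond + cfg_scale * (eps_cond - eps_uncond)

# cfg_scale = 1.0 recovers the purely conditional prediction;
# larger scales (e.g. 1.5) strengthen the class conditioning.
```

At each sampling step the model is run twice (with and without the class label), and the two predictions are combined this way before the denoising update.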

<img src="assets/figs/12samples_compressed.png" title="Generated samples from MaskDiT 256x256." width="850" style="display: block; margin: 0 auto;"/>
<p align='center'> Generated samples from MaskDiT 256x256. Upper panel: without CFG. Lower panel: with CFG (scale=1.5). </p>

<img src="assets/figs/imagenet512.png" title="Generated samples from MaskDiT 512x512." width="850" style="display: block; margin: 0 auto;"/>
<p align='center'> Generated samples from MaskDiT 512x512 with CFG (scale=1.5). </p>

## Dataset

We first encode the ImageNet dataset into latent space with a pre-trained VAE. The pre-encoded ImageNet-256x256
and ImageNet-512x512 latents can be downloaded from [hzzheng/MaskDiT-imagenet](https://huggingface.co/datasets/hzzheng/MaskDiT-imagenet).
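For orientation, assuming the Stable Diffusion VAE that DiT-style models typically use (8x spatial downsampling, 4 latent channels; this is an assumption, so check the training configs for the exact setup), the latent shapes work out as:

```python
def latent_shape(image_size, down_factor=8, latent_channels=4):
    """Shape (C, H, W) of the VAE latent for a square RGB image,
    assuming an 8x-downsampling, 4-channel VAE."""
    side = image_size // down_factor
    return (latent_channels, side, side)

# ImageNet-256 images become 4x32x32 latents; ImageNet-512 becomes 4x64x64.
```

Training on these compact latents rather than raw pixels is what makes the transformer's patch count (and thus the masking savings) manageable.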

### LMDB to Webdataset

When training on ImageNet-256x256, we store the data in LMDB format. When training on ImageNet-512x512, we
use [webdataset](https://github.com/webdataset/webdataset) for faster I/O performance. To convert an LMDB dataset into a
webdataset, run

```bash
python3 lmdb2wds.py --datadir [path to lmdb] --outdir [path to save webdataset] --resolution [latent resolution] --num_channels [number of latent channels]
```

## Train

### ImageNet-256x256

First, train MaskDiT with a 50% mask ratio and AMP enabled.

```bash
accelerate launch --multi_gpu train.py --config configs/train/imagenet256-latent.yaml
```

Then fine-tune without masking.

```bash
accelerate launch --multi_gpu train.py --config configs/finetune/imagenet256-latent-const.yaml --ckpt_path [path to checkpoint] --use_ckpt_path False --use_strict_load False --no_amp
```

### ImageNet-512x512

Train MaskDiT with a 50% mask ratio and AMP enabled. Here is an example 4-node training script.

```bash
bash scripts/train_latent512.sh
```

Then fine-tune without masking.

```bash
bash scripts/finetune_latent512.sh
```

## Evaluation

### FID evaluation

To compute the FID of a pretrained model, run

```bash
accelerate launch --multi_gpu eval_latent.py --config configs/test/maskdit-256.yaml --ckpt [path to the pretrained model] --cfg_scale [guidance scale]
```
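For intuition, FID is the Fréchet distance between two Gaussians fitted to Inception-v3 features of real and generated images. A minimal sketch of the distance itself (not the evaluation pipeline used here; `scipy` is assumed to be available):

```python
import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Frechet distance between N(mu1, sigma1) and N(mu2, sigma2):
    ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 (sigma1 sigma2)^{1/2})."""
    diff = mu1 - mu2
    covmean = linalg.sqrtm(sigma1 @ sigma2)
    if np.iscomplexobj(covmean):  # sqrtm can pick up tiny imaginary noise
        covmean = covmean.real
    return float(diff @ diff + np.trace(sigma1 + sigma2 - 2.0 * covmean))

# Identical Gaussians have distance 0; shifting the mean by 1 in each of
# three dimensions gives a distance of 3.
```

The production evaluators compute the `mu`/`sigma` statistics from tens of thousands of Inception features, which is why a reference batch must be downloaded first (see below).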

### Full evaluation

First, download the reference batch directly from the [ADM repo](https://github.com/openai/guided-diffusion/tree/main/evaluations).
You can also fetch it with `download_assets.py` by running

```bash
python3 download_assets.py --name imagenet256 --dest [destination directory]
```

Then evaluate the generated samples with `evaluator.py` from
the [ADM repo](https://github.com/openai/guided-diffusion/tree/main/evaluations), or `fid.py` from
the [EDM repo](https://github.com/NVlabs/edm).


## Citation
```
@article{Zheng2024MaskDiT,
  title={Fast Training of Diffusion Models with Masked Transformers},
  author={Zheng, Hongkai and Nie, Weili and Vahdat, Arash and Anandkumar, Anima},
  journal={Transactions on Machine Learning Research},
  year={2024}
}
```


## Acknowledgements

Thanks to the open-source codebases [DiT](https://github.com/facebookresearch/DiT),
[MAE](https://github.com/facebookresearch/mae), [U-ViT](https://github.com/baofff/U-ViT),
[ADM](https://github.com/openai/guided-diffusion), and [EDM](https://github.com/NVlabs/edm). Our codebase is built on
them.