Update README.md

README.md

---
license: mit
---

## Scalable Diffusion Models with Transformers (DiT)<br><sub>Official PyTorch Implementation</sub>

### [Paper](http://arxiv.org/abs/2212.09748) | [Project Page](https://www.wpeebles.com/DiT) | Run DiT-XL/2: [Hugging Face Space](https://huggingface.co/spaces/wpeebles/DiT) | [Colab](http://colab.research.google.com/github/facebookresearch/DiT/blob/main/run_DiT.ipynb) | [Replicate](https://replicate.com/arielreplicate/scalable_diffusion_with_transformers)

This repo contains PyTorch model definitions, pre-trained weights, and training/sampling code for our paper exploring diffusion models with transformers (DiTs). You can find more visualizations on our [project page](https://www.wpeebles.com/DiT).

> [**Scalable Diffusion Models with Transformers**](https://www.wpeebles.com/DiT)<br>
> [William Peebles](https://www.wpeebles.com), [Saining Xie](https://www.sainingxie.com)<br>
> UC Berkeley, New York University

We train latent diffusion models, replacing the commonly used U-Net backbone with a transformer that operates on latent patches. We analyze the scalability of our Diffusion Transformers (DiTs) through the lens of forward-pass complexity as measured by Gflops. We find that DiTs with higher Gflops (through increased transformer depth/width or an increased number of input tokens) consistently have lower FID. In addition to good scalability properties, our DiT-XL/2 models outperform all prior diffusion models on the class-conditional ImageNet 512×512 and 256×256 benchmarks, achieving a state-of-the-art FID of 2.27 on the latter.

This repository contains:

* 🪐 A simple PyTorch [implementation](models.py) of DiT
* ⚡️ Pre-trained class-conditional DiT models trained on ImageNet (512x512 and 256x256)
* 💥 A self-contained [Hugging Face Space](https://huggingface.co/spaces/wpeebles/DiT) and [Colab notebook](http://colab.research.google.com/github/facebookresearch/DiT/blob/main/run_DiT.ipynb) for running pre-trained DiT-XL/2 models
* 🛸 A DiT [training script](train.py) using PyTorch DDP

An implementation of DiT directly in Hugging Face `diffusers` can also be found [here](https://github.com/huggingface/diffusers/blob/main/docs/source/en/api/pipelines/dit.mdx).

## Setup

First, download and set up the repo:

```bash
git clone https://github.com/facebookresearch/DiT.git
cd DiT
```

We provide an [`environment.yml`](environment.yml) file that can be used to create a Conda environment. If you only want to run pre-trained models locally on CPU, you can remove the `cudatoolkit` and `pytorch-cuda` requirements from the file.

```bash
conda env create -f environment.yml
conda activate DiT
```

## Sampling [Hugging Face Space](https://huggingface.co/spaces/wpeebles/DiT) | [Colab](http://colab.research.google.com/github/facebookresearch/DiT/blob/main/run_DiT.ipynb)

**Pre-trained DiT checkpoints.** You can sample from our pre-trained DiT models with [`sample.py`](sample.py). Weights for our pre-trained DiT models will be automatically downloaded depending on the model you use. The script has various arguments to switch between the 256x256 and 512x512 models, adjust sampling steps, change the classifier-free guidance scale, etc. For example, to sample from our 512x512 DiT-XL/2 model, you can use:

```bash
python sample.py --image-size 512 --seed 1
```

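To adjust the sampling-step count or guidance strength mentioned above, pass the corresponding flags. A sketch, assuming `sample.py` exposes `--num-sampling-steps` and `--cfg-scale` arguments (run `python sample.py --help` for the authoritative list):

```bash
# Hypothetical flag names: fewer DDPM steps and a stronger classifier-free guidance scale.
python sample.py --image-size 256 --num-sampling-steps 100 --cfg-scale 6.0 --seed 1
```
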
For convenience, our pre-trained DiT models can be downloaded directly here as well:

| DiT Model | Image Resolution | FID-50K | Inception Score | Gflops |
|-----------|------------------|---------|-----------------|--------|
| [XL/2](https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-256x256.pt) | 256x256 | 2.27 | 278.24 | 119 |
| [XL/2](https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-512x512.pt) | 512x512 | 3.04 | 240.82 | 525 |

**Custom DiT checkpoints.** If you've trained a new DiT model with [`train.py`](train.py) (see [below](#training-dit)), you can add the `--ckpt` argument to use your own checkpoint instead. For example, to sample from the EMA weights of a custom 256x256 DiT-L/4 model, run:

```bash
python sample.py --model DiT-L/4 --image-size 256 --ckpt /path/to/model.pt
```

## Training DiT

We provide a training script for DiT in [`train.py`](train.py). This script can be used to train class-conditional DiT models, but it can be easily modified to support other types of conditioning. To launch DiT-XL/2 (256x256) training with `N` GPUs on one node:

```bash
torchrun --nnodes=1 --nproc_per_node=N train.py --model DiT-XL/2 --data-path /path/to/imagenet/train
```

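Because the script launches through `torchrun`, multi-node training should follow `torchrun`'s standard rendezvous flags. A sketch for two 8-GPU nodes (the `$MASTER_ADDR` variable and node count are illustrative assumptions, not something this repo defines):

```bash
# Hypothetical two-node launch; run once per node with the matching --node_rank (0 and 1).
torchrun --nnodes=2 --node_rank=0 --nproc_per_node=8 \
  --master_addr=$MASTER_ADDR --master_port=29500 \
  train.py --model DiT-XL/2 --data-path /path/to/imagenet/train
```
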
### PyTorch Training Results

We've trained DiT-XL/2 and DiT-B/4 models from scratch with the PyTorch training script to verify that it reproduces the original JAX results up to several hundred thousand training iterations. Across our experiments, the PyTorch-trained models give similar (and sometimes slightly better) results compared to the JAX-trained models up to reasonable random variation. Some data points:

| DiT Model | Train Steps | FID-50K<br>(JAX Training) | FID-50K<br>(PyTorch Training) | PyTorch Global Training Seed |
|-----------|-------------|---------------------------|-------------------------------|------------------------------|
| XL/2      | 400K        | 19.5                      | **18.1**                      | 42                           |
| B/4       | 400K        | **68.4**                  | 68.9                          | 42                           |
| B/4       | 400K        | 68.4                      | **68.3**                      | 100                          |

These models were trained at 256x256 resolution; we used 8x A100s to train XL/2 and 4x A100s to train B/4. Note that FID here is computed with 250 DDPM sampling steps, with the `mse` VAE decoder and without guidance (`cfg-scale=1`).

**TF32 Note (important for A100 users).** When we ran the above tests, TF32 matmuls were disabled per PyTorch's defaults. We've enabled them at the top of `train.py` and `sample.py` because doing so makes training and sampling significantly faster on A100s (and should on other Ampere GPUs too), but note that the use of TF32 may lead to some differences compared to the above results.

### Enhancements

Training (and sampling) could likely be sped up significantly by:
- [ ] using [Flash Attention](https://github.com/HazyResearch/flash-attention) in the DiT model
- [ ] using `torch.compile` in PyTorch 2.0

Basic features that would be nice to add:
- [ ] Monitor FID and other metrics
- [ ] Generate and save samples from the EMA model periodically
- [ ] Resume training from a checkpoint
- [ ] AMP/bfloat16 support

**🔥 Feature Update** Check out [fast-DiT](https://github.com/chuanyangjin/fast-DiT) for a selection of training-speed and memory-saving features, including gradient checkpointing, mixed-precision training, and pre-extracted VAE features. With these advancements, we have achieved a training speed of 0.84 steps/sec for DiT-XL/2 using just a single A100 GPU.

## Evaluation (FID, Inception Score, etc.)

We include a [`sample_ddp.py`](sample_ddp.py) script which samples a large number of images from a DiT model in parallel. This script generates a folder of samples as well as a `.npz` file which can be directly used with [ADM's TensorFlow evaluation suite](https://github.com/openai/guided-diffusion/tree/main/evaluations) to compute FID, Inception Score and other metrics. For example, to sample 50K images from our pre-trained DiT-XL/2 model over `N` GPUs, run:

```bash
torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py --model DiT-XL/2 --num-fid-samples 50000
```

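The resulting `.npz` file is the "sample batch" that ADM's evaluator consumes. A sketch of that downstream step, assuming a checkout of [guided-diffusion](https://github.com/openai/guided-diffusion) and its published ImageNet 256x256 reference batch (file names follow ADM's evaluations README; treat them as assumptions and check that README for specifics):

```bash
# Hypothetical invocation of ADM's TensorFlow evaluator: reference batch first, sample batch second.
python evaluations/evaluator.py VIRTUAL_imagenet256_labeled.npz /path/to/DiT/samples.npz
```
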
There are several additional options; see [`sample_ddp.py`](sample_ddp.py) for details.

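For instance, to mirror the FID settings quoted in the PyTorch training results above (250 DDPM steps, the `mse` VAE decoder, no guidance), the invocation might look like the sketch below; the flag names `--num-sampling-steps`, `--vae`, and `--cfg-scale` are assumptions based on the options this README describes, so verify them against `sample_ddp.py`:

```bash
# Hypothetical flags reproducing the "PyTorch Training Results" FID setup.
torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py --model DiT-XL/2 \
  --num-fid-samples 50000 --num-sampling-steps 250 --vae mse --cfg-scale 1.0
```
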
## Differences from JAX

Our models were originally trained in JAX on TPUs. The weights in this repo are ported directly from the JAX models. There may be minor differences in results stemming from sampling with different floating point precisions. We re-evaluated our ported PyTorch weights at FP32, and they actually perform marginally better than sampling in JAX (2.21 FID versus 2.27 in the paper).

## BibTeX

```bibtex
@article{Peebles2022DiT,
  title={Scalable Diffusion Models with Transformers},
  author={William Peebles and Saining Xie},
  year={2022},
  journal={arXiv preprint arXiv:2212.09748},
}
```

## Acknowledgments

We thank Kaiming He, Ronghang Hu, Alexander Berg, Shoubhik Debnath, Tim Brooks, Ilija Radosavovic and Tete Xiao for helpful discussions. William Peebles is supported by the NSF Graduate Research Fellowship.

This codebase borrows from OpenAI's diffusion repos, most notably [ADM](https://github.com/openai/guided-diffusion).

## License

The code and model weights are licensed under CC-BY-NC. See [`LICENSE.txt`](LICENSE.txt) for details.