---
license: apache-2.0
language:
  - en
library_name: diffusers
tags:
  - diffusers
  - image-generation
  - class-conditional
  - dit-moe
pipeline_tag: unconditional-image-generation
---

# DiT-MoE-diffusers

Diffusers implementation of **DiT-MoE** (Diffusion Transformer with Mixture of Experts) for class-conditional ImageNet generation. Each variant folder is self-contained:

- `pipeline.py`: `DiTMoEPipeline`
- `scheduler/scheduler_config.json`: `DDIMScheduler` (S/B) or `DiTMoEFlowMatchScheduler` (XL/G)
- `transformer/transformer_dit_moe.py`: `DiTMoETransformer2DModel`
- `vae/`: `AutoencoderKL` (`stabilityai/sd-vae-ft-mse`)
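
As a quick sanity check after downloading or converting a variant, the layout above can be verified with a few lines of Python. This is only a sketch: `missing_components` and `EXPECTED` are hypothetical names, not part of this repository, and the check looks only at whether the paths exist, not at their contents.

```python
from pathlib import Path

# Relative paths taken from the component list above. model_index.json is the
# per-variant index file this README mentions in the class-label section.
EXPECTED = (
    "model_index.json",
    "pipeline.py",
    "scheduler/scheduler_config.json",
    "transformer/transformer_dit_moe.py",
    "vae",
)


def missing_components(variant_dir):
    """Return the expected relative paths that are absent from variant_dir."""
    root = Path(variant_dir)
    return [rel for rel in EXPECTED if not (root / rel).exists()]


if __name__ == "__main__":
    # An empty list means every expected file or folder was found.
    print(missing_components("./DiT-MoE-S-8E2A"))
```

An empty result only says the folder has the expected shape; loading the pipeline is still the real test.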

## ImageNet class labels

Each variant keeps an English `id2label` map directly in its own `model_index.json` (DiT-style).

- `pipe.id2label`: inspect the id → English label correspondence
- `pipe.labels`: reverse map (English synonym → id), sorted for browsing
- `pipe.get_label_ids("golden retriever")`
- `pipe(class_labels="golden retriever", ...)`: string labels are resolved automatically
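
The relationship between `id2label` and the reverse map can be illustrated with a small standalone sketch. It assumes the behaviour follows the `DiTPipeline` in diffusers, where each `id2label` value is a comma-separated list of synonyms. The helper names and the two-entry excerpt below are illustrative, and the actual `DiTMoEPipeline` code may differ.

```python
def build_label_index(id2label):
    """Invert an id -> "synonym, synonym, ..." map into a sorted synonym -> id map."""
    labels = {}
    for class_id, names in id2label.items():
        for name in names.split(","):
            labels[name.strip()] = int(class_id)
    return dict(sorted(labels.items()))


def get_label_ids(labels, query):
    """Resolve one label or a list of labels to a list of class ids."""
    if isinstance(query, str):
        query = [query]
    unknown = [name for name in query if name not in labels]
    if unknown:
        raise ValueError(f"Unknown labels: {unknown}")
    return [labels[name] for name in query]


# Illustrative excerpt only; JSON object keys arrive as strings.
ID2LABEL_EXCERPT = {
    "1": "goldfish, Carassius auratus",
    "207": "golden retriever",
}

if __name__ == "__main__":
    index = build_label_index(ID2LABEL_EXCERPT)
    print(list(index))
    print(get_label_ids(index, "golden retriever"))
```

Splitting on commas is what lets any synonym of a class resolve to the same id.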

## Available checkpoints

| Checkpoint | Path | Resolution | Sampler |
| --- | --- | --- | --- |
| DiT-MoE-S/2-8E2A | `./DiT-MoE-S-8E2A` | 256×256 | DDIM |
| DiT-MoE-B/2-8E2A | `./DiT-MoE-B-8E2A` | 256×256 | DDIM |
| DiT-MoE-XL/2-8E2A | `./DiT-MoE-XL-8E2A` | 256×256 | RF |
| DiT-MoE-G/2-16E2A | *(convert with `--rectified-flow --num-experts 16`)* | 256×256 | RF |

## Convert from official weights

```bash
conda activate rsgen
cd libs/DiT-MoE-diffusers

python scripts/convert_dit_moe_to_diffusers.py \
  --checkpoint ../../models/feizhengcong/DiT-MoE/dit_moe_s_8E2A.pt \
  --output ../../models/BiliSakura/DiT-MoE-diffusers/DiT-MoE-S-8E2A \
  --model DiT-S/2 \
  --num-experts 8 \
  --num-experts-per-tok 2 \
  --copy-vae ../../models/feizhengcong/DiT-MoE/sd-vae-ft-mse \
  --check-load
```

## Inference

Use `torch.bfloat16` on Ampere or newer GPUs; it is the default in the examples and in `sample_dit_moe.py`.

```python
from pathlib import Path
import torch
from diffusers import DiffusionPipeline

model_dir = Path("./DiT-MoE-S-8E2A").resolve()
pipe = DiffusionPipeline.from_pretrained(
    str(model_dir),
    local_files_only=True,
    custom_pipeline=str(model_dir / "pipeline.py"),
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

print(pipe.id2label[207])
print(pipe.get_label_ids("golden retriever"))

generator = torch.Generator(device="cuda").manual_seed(42)
image = pipe(
    class_labels="golden retriever",
    height=256,
    width=256,
    num_inference_steps=50,
    guidance_scale=4.0,
    generator=generator,
).images[0]
image.save("demo.png")
```
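
For scripts that may run on mixed hardware, the dtype recommendation above can be turned into a small selection helper. This is a sketch under stated assumptions: `pick_dtype_name` is a hypothetical helper, "Ampere or newer" is taken to mean CUDA compute capability 8.0 or higher, and the `float16` and `float32` fallbacks are assumptions of this example, not something the repository prescribes.

```python
def pick_dtype_name(cuda_available, compute_capability=None):
    """Pick a torch dtype name from basic device information.

    compute_capability is a (major, minor) tuple, e.g. (8, 0) for an A100.
    """
    if not cuda_available:
        # Assumption: run in full precision on CPU.
        return "float32"
    major, _minor = compute_capability
    # Ampere corresponds to compute capability 8.x.
    # Assumption: fall back to float16 on older GPUs.
    return "bfloat16" if major >= 8 else "float16"


if __name__ == "__main__":
    print(pick_dtype_name(True, (8, 0)))
    print(pick_dtype_name(True, (7, 5)))
    print(pick_dtype_name(False))
```

With PyTorch installed, `torch.cuda.is_available()` and `torch.cuda.get_device_capability()` supply the two arguments, and `getattr(torch, name)` converts the returned name into a value for `torch_dtype`.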

## Citation

```bibtex
@article{FeiDiTMoE2024,
  title={Scaling Diffusion Transformers to 16 Billion Parameters},
  author={Zhengcong Fei and Mingyuan Fan and Changqian Yu and Debang Li and Junshi Huang},
  year={2024},
  journal={arXiv preprint arXiv:2407.11633},
}
```