---
license: apache-2.0
language:
- en
library_name: diffusers
tags:
- diffusers
- image-generation
- class-conditional
- dit-moe
pipeline_tag: unconditional-image-generation
---

# DiT-MoE-diffusers

Diffusers implementation of **DiT-MoE** (Diffusion Transformer with Mixture of Experts) for class-conditional ImageNet generation.

Each variant folder is self-contained:

- `pipeline.py`: `DiTMoEPipeline`
- `scheduler/scheduler_config.json`: `DDIMScheduler` (S/B) or `DiTMoEFlowMatchScheduler` (XL/G)
- `transformer/transformer_dit_moe.py`: `DiTMoETransformer2DModel`
- `vae/`: `AutoencoderKL` (`stabilityai/sd-vae-ft-mse`)

## ImageNet class labels

Each variant stores an English `id2label` map directly in its own `model_index.json` (DiT-style).

- `pipe.id2label`: inspect the id → English label mapping
- `pipe.labels`: reverse map (English synonym → id), sorted for browsing
- `pipe.get_label_ids("golden retriever")`: resolve an English label to its class id(s)
- `pipe(class_labels="golden retriever", ...)`: string labels are resolved automatically

## Available checkpoints

| Checkpoint | Path | Resolution | Sampler |
| --- | --- | --- | --- |
| DiT-MoE-S/2-8E2A | `./DiT-MoE-S-8E2A` | 256×256 | DDIM |
| DiT-MoE-B/2-8E2A | `./DiT-MoE-B-8E2A` | 256×256 | DDIM |
| DiT-MoE-XL/2-8E2A | `./DiT-MoE-XL-8E2A` | 256×256 | RF (rectified flow) |
| DiT-MoE-G/2-16E2A | *(convert with `--rectified-flow --num-experts 16`)* | 256×256 | RF (rectified flow) |

## Convert from official weights

```bash
conda activate rsgen
cd libs/DiT-MoE-diffusers

python scripts/convert_dit_moe_to_diffusers.py \
  --checkpoint ../../models/feizhengcong/DiT-MoE/dit_moe_s_8E2A.pt \
  --output ../../models/BiliSakura/DiT-MoE-diffusers/DiT-MoE-S-8E2A \
  --model DiT-S/2 \
  --num-experts 8 \
  --num-experts-per-tok 2 \
  --copy-vae ../../models/feizhengcong/DiT-MoE/sd-vae-ft-mse \
  --check-load
```

## Inference

The examples and `sample_dit_moe.py` default to `torch.bfloat16`; use it on Ampere or newer GPUs.
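The bfloat16 default assumes an Ampere-or-newer GPU. If the same script also has to run on older GPUs or on CPU, the dtype can be picked at runtime. This is a minimal sketch only: `pick_dtype` is a hypothetical helper that is not part of this repository, and the float16 fallback for pre-Ampere GPUs and float32 on CPU are assumptions that have not been validated against DiT-MoE outputs.

```python
import torch


def pick_dtype() -> torch.dtype:
    """Hypothetical helper: choose an inference dtype for the current machine."""
    if not torch.cuda.is_available():
        # Assumption: run in full precision on CPU
        return torch.float32
    major, _minor = torch.cuda.get_device_capability()
    # Ampere (compute capability 8.x) and newer support bfloat16 natively
    if major >= 8:
        return torch.bfloat16
    # Assumption: older GPUs (e.g. Volta/Turing) fall back to float16
    return torch.float16


dtype = pick_dtype()
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device, dtype)
```

The result could be passed as `torch_dtype=dtype` to `DiffusionPipeline.from_pretrained`, with `device` used in place of the hard-coded `"cuda"` in `pipe.to(...)` and `torch.Generator(device=...)`.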
```python
from pathlib import Path

import torch
from diffusers import DiffusionPipeline

model_dir = Path("./DiT-MoE-S-8E2A").resolve()
pipe = DiffusionPipeline.from_pretrained(
    str(model_dir),
    local_files_only=True,
    custom_pipeline=str(model_dir / "pipeline.py"),
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

# Inspect the label mapping
print(pipe.id2label[207])
print(pipe.get_label_ids("golden retriever"))

# Fixed seed for reproducible sampling
generator = torch.Generator(device="cuda").manual_seed(42)
image = pipe(
    class_labels="golden retriever",  # resolved to its ImageNet class id
    height=256,
    width=256,
    num_inference_steps=50,
    guidance_scale=4.0,  # classifier-free guidance scale
    generator=generator,
).images[0]
image.save("demo.png")
```

## Citation

```bibtex
@article{FeiDiTMoE2024,
  title={Scaling Diffusion Transformers to 16 Billion Parameters},
  author={Zhengcong Fei and Mingyuan Fan and Changqian Yu and Debang Li and Junshi Huang},
  year={2024},
  journal={arXiv preprint arXiv:2407.11633},
}
```