MMDiff: Extending Diffusion Transformers for Multi-Modal Generation
Yagmur Akarken, Orest Kupyn, Christian Rupprecht — Visual Geometry Group, University of Oxford
Paper · Code · Project page
MMDiff turns a frozen diffusion transformer into a multi-modal generator. To synthesize an image, a diffusion model must first build up the semantic and geometric structure of the scene, but that structure is normally discarded once the image is rendered. MMDiff keeps it and decodes it into aligned dense outputs (semantic segmentation, salient object detection, monocular depth) in the same generation pass. Because the maps come directly from the generator, a single frozen model produces an image together with its aligned annotations, which enables synthetic data generation at scale.
This repository hosts the trained decoder-head checkpoints. The FLUX.1-dev backbone and the optional DINOv3 encoder are never fine-tuned; only the lightweight decoder heads and the feature-fusion module (~36M parameters) are trained.
- Backbone: FLUX.1-dev (frozen) + optional DINOv3 ViT-B/16 (frozen)
- Trained parameters: ~36M (decoder heads + feature-fusion module only)
How it works
- Multi-timestep feature fusion. A small learned module reads FLUX features from several denoising steps and fuses them with spatially varying aggregation weights. Because perceptual information is distributed across the denoising trajectory rather than concentrated at a single step, this fusion is the largest single contributor to performance: up to +28.7% mIoU over single-timestep extraction.
- Concept-driven attention. The frozen model provides interpretable spatial guidance (e.g. object vs. background, near vs. far) as extra cues for each task.
- Per-task decoder. A standard lightweight decoder is trained per task (DeepLabV3+ for segmentation/saliency, DPT for depth). The generator is never trained.
- Complementary to encoders. Frozen FLUX features are competitive with state-of-the-art encoders such as DINOv3 and complementary to them: combining the two improves every task.
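The fusion step can be pictured as a per-pixel weighted sum over timesteps. The NumPy snippet below is a minimal sketch, not the MMDiff module: it assumes the spatially varying weights come from a softmax over timesteps of per-pixel logits (a small learned network would produce those logits in practice), and the function name `fuse_timestep_features` and the `(T, C, H, W)` layout are hypothetical.

```python
import numpy as np


def fuse_timestep_features(features: np.ndarray, logits: np.ndarray) -> np.ndarray:
    """Fuse per-timestep feature maps with spatially varying weights.

    features: (T, C, H, W) backbone features from T denoising steps.
    logits:   (T, H, W) unnormalized per-pixel scores (learned in practice).
    Returns:  (C, H, W) fused features.
    """
    # Softmax over the timestep axis, independently at every pixel,
    # so the weights at each (h, w) are positive and sum to 1.
    shifted = logits - logits.max(axis=0, keepdims=True)
    weights = np.exp(shifted)
    weights /= weights.sum(axis=0, keepdims=True)
    # Broadcast the weights over channels and sum across timesteps.
    return (weights[:, None, :, :] * features).sum(axis=0)


rng = np.random.default_rng(0)
feats = rng.standard_normal((4, 8, 16, 16))  # 4 timesteps, 8 channels, 16x16 grid
logits = rng.standard_normal((4, 16, 16))
fused = fuse_timestep_features(feats, logits)
print(fused.shape)  # (8, 16, 16)
```

Normalizing the weights keeps the fused features on the same scale as any single timestep, while letting each pixel draw on different timesteps independently of its neighbors.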
Checkpoints
| File | Task | Dataset | Config (in code repo) |
|---|---|---|---|
| `pascal_segmentation.ckpt` | Semantic segmentation | Pascal VOC 2012 | `configs/pascal_voc_config.yaml` |
| `duts_saliency.ckpt` | Salient object detection | DUTS | `configs/duts_config.yaml` |
| `nyu_depth.ckpt` | Monocular depth | NYU Depth V2 | `configs/nyu_depth_config.yaml` |
Each file is a PyTorch Lightning checkpoint; inference loads only the model weights (state_dict).
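Because each file is a standard Lightning checkpoint (a dict whose `state_dict` entry holds the weights), the weights can be pulled out with plain PyTorch. This is a minimal sketch under two assumptions: the trained modules sit under a `model.` attribute prefix (hypothetical; inspect the keys of a real checkpoint), and loading with `weights_only=False` is acceptable because you trust the file. The helper names are illustrative, not part of the code repository.

```python
from typing import Any, Dict


def extract_state_dict(ckpt: Dict[str, Any], prefix: str = "model.") -> Dict[str, Any]:
    """Return the weights stored in a Lightning checkpoint dict.

    Lightning keeps parameters under "state_dict", with names prefixed by the
    LightningModule attribute that owns them. `prefix` is a hypothetical
    attribute name; keys that do not start with it are kept unchanged.
    """
    state = ckpt["state_dict"]
    return {
        (k[len(prefix):] if prefix and k.startswith(prefix) else k): v
        for k, v in state.items()
    }


def checkpoint_hparams(ckpt: Dict[str, Any]) -> Dict[str, Any]:
    """Hyperparameters saved by Lightning, or {} if none were stored."""
    return dict(ckpt.get("hyper_parameters", {}))


def load_head_weights(path: str, prefix: str = "model.") -> Dict[str, Any]:
    import torch  # imported lazily so the helpers above do not need torch

    # weights_only=False because Lightning checkpoints can contain non-tensor
    # objects; only load checkpoints from sources you trust.
    ckpt = torch.load(path, map_location="cpu", weights_only=False)
    return extract_state_dict(ckpt, prefix)
```

If the module called `save_hyperparameters()`, Lightning also stores a `hyper_parameters` entry, which `checkpoint_hparams` returns; comparing it with the architecture flags is one way to catch a mismatch before running inference.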
Results
Feature quality (frozen backbone, lightweight heads). Higher is better unless marked ↓.
| Variant | VOC mIoU ↑ | DUTS Sₘ ↑ | DUTS MAE ↓ | NYU AbsRel ↓ | NYU RMSE ↓ |
|---|---|---|---|---|---|
| MMDiff | 78.9 | 0.918 | 0.020 | 0.1175 | 0.370 |
| MMDiff + DINOv3 | 84.95 | 0.934 | 0.018 | 0.1164 | 0.365 |
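For reference, the mIoU, MAE, AbsRel and RMSE columns follow standard definitions, sketched below in NumPy. This is an illustrative sketch, not the paper's evaluation code: the valid-depth mask (`gt > 0`), the VOC ignore index 255 and the helper names are assumptions, and protocol details such as depth caps or crops may differ. VOC mIoU is normally computed from one confusion matrix accumulated over the whole dataset rather than averaged per image. The S-measure (Sₘ) is a structural similarity score and is not reproduced here.

```python
import numpy as np


def _valid(gt, valid):
    # Assumed convention: a depth of 0 marks missing ground truth.
    return gt > 0 if valid is None else valid


def abs_rel(pred, gt, valid=None):
    """Mean absolute relative error, mean(|d - d*| / d*), over valid pixels."""
    m = _valid(gt, valid)
    return float(np.mean(np.abs(pred[m] - gt[m]) / gt[m]))


def rmse(pred, gt, valid=None):
    """Root-mean-square error, sqrt(mean((d - d*)^2)), in the units of d*."""
    m = _valid(gt, valid)
    return float(np.sqrt(np.mean((pred[m] - gt[m]) ** 2)))


def mae(sal, gt):
    """Saliency MAE: mean |S - G| with both maps already scaled to [0, 1]."""
    return float(np.mean(np.abs(sal.astype(np.float64) - gt.astype(np.float64))))


def miou(pred, gt, num_classes, ignore_index=255):
    """Mean IoU = mean over classes of TP / (TP + FP + FN)."""
    keep = gt != ignore_index
    p = pred[keep].astype(np.int64)
    g = gt[keep].astype(np.int64)
    # Confusion matrix: rows are ground-truth classes, columns are predictions.
    cm = np.bincount(g * num_classes + p, minlength=num_classes ** 2)
    cm = cm.reshape(num_classes, num_classes)
    tp = np.diag(cm)
    union = cm.sum(axis=0) + cm.sum(axis=1) - tp
    present = union > 0  # skip classes absent from both prediction and label
    return float((tp[present] / union[present]).mean())
```

AbsRel divides by the true depth, so it weights errors relative to distance, whereas RMSE is in metres and is dominated by far-away pixels with large absolute errors.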
Synthetic-data training (decoders trained on MMDiff-generated images + labels).
| Setting | VOC mIoU ↑ | DUTS Sₘ ↑ | NYU AbsRel ↓ |
|---|---|---|---|
| Synthetic only | 78.9 | 0.784 | 0.1880 |
| Synthetic + real fine-tune | 87.8 | 0.863 | 0.1185 |
Trained purely on synthetic data, MMDiff outperforms prior synthetic-data methods (DatasetDM, DiffuMask, Dataset Diffusion). See the paper for full tables.
Usage
Clone the code repository, then download a checkpoint and run inference:
```bash
# Download a checkpoint from the Hub and keep its local path in $ckpt
ckpt=$(python -c 'from huggingface_hub import hf_hub_download; print(hf_hub_download("yagmurakarken/mmdiff", "pascal_segmentation.ckpt"))')

python scripts/inference.py --task pascal \
  --config configs/pascal_voc_config.yaml --checkpoint "$ckpt" \
  --image my_image.jpg --output_dir outputs/
```
Swap `pascal` for `duts` or `nyu` (with the matching config and checkpoint) to run the other tasks.
Architecture flags (`--hidden_dim`, `--num_transformer_layers`, `--num_timesteps`, `--dino_model`) must match the checkpoint you load.
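To script all three tasks, the task-to-config mapping from the Checkpoints table can drive the same command line. This is a minimal sketch: `TASKS` and `inference_command` are hypothetical helpers, the flags are exactly the ones shown in the command above, and any architecture flags are passed through `extra_args` unchanged, so their values must still match the checkpoint.

```python
import shlex
from typing import List, Sequence

# Task -> (checkpoint file, config path), copied from the Checkpoints table.
TASKS = {
    "pascal": ("pascal_segmentation.ckpt", "configs/pascal_voc_config.yaml"),
    "duts": ("duts_saliency.ckpt", "configs/duts_config.yaml"),
    "nyu": ("nyu_depth.ckpt", "configs/nyu_depth_config.yaml"),
}


def inference_command(task: str, ckpt_path: str, image: str,
                      output_dir: str = "outputs/",
                      extra_args: Sequence[str] = ()) -> List[str]:
    """Build the argument list for scripts/inference.py for one task.

    extra_args is where architecture flags such as --hidden_dim would go.
    Raises KeyError for a task name not in TASKS.
    """
    _, config = TASKS[task]
    return ["python", "scripts/inference.py", "--task", task,
            "--config", config, "--checkpoint", ckpt_path,
            "--image", image, "--output_dir", output_dir, *extra_args]


print(shlex.join(inference_command("duts", "duts_saliency.ckpt", "my_image.jpg")))
# python scripts/inference.py --task duts --config configs/duts_config.yaml --checkpoint duts_saliency.ckpt --image my_image.jpg --output_dir outputs/
```

In practice `ckpt_path` would be the local path returned by `hf_hub_download` for `TASKS[task][0]`, and the list can be passed straight to `subprocess.run(..., check=True)` from inside the cloned code repository.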
You can also generate an image and its aligned annotation together with the same frozen backbone:
```bash
python scripts/generate.py --task pascal \
  --config configs/pascal_voc_config.yaml --checkpoint "$ckpt" \
  --prompts_file prompts.txt --output_dir synth/
```
Intended uses & limitations
- Intended use: dense prediction (segmentation, saliency, depth) from generated or real images, and large-scale synthetic dataset generation with aligned labels for research.
- Requires FLUX.1-dev. These are decoder heads only — you need access to the (gated) FLUX.1-dev backbone to run them.
- Not a plug-and-play `transformers`/`diffusers` pipeline. Use the code repository to load and run the checkpoints.
- Performance reflects the training datasets above and may not transfer to very different domains.
Citation
```bibtex
@article{akarken2026mmdiff,
  title   = {{MMDiff}: Extending Diffusion Transformers for Multi-Modal Generation},
  author  = {Akarken, Yagmur and Kupyn, Orest and Rupprecht, Christian},
  journal = {arXiv preprint arXiv:2606.16673},
  year    = {2026}
}
```
License
The decoder checkpoints and code in this project are released under the MIT License. Note that the frozen FLUX.1-dev backbone they depend on is governed by the FLUX.1-dev Non-Commercial License, and DINOv3 by its own license — review those before any non-research use.
Acknowledgements
Built on FLUX.1-dev, DINOv3, diffusers, and PyTorch Lightning.