---
license: mit
library_name: pytorch
pipeline_tag: image-segmentation
base_model: black-forest-labs/FLUX.1-dev
tags:
- diffusion
- diffusion-transformer
- flux
- dinov3
- dense-prediction
- semantic-segmentation
- salient-object-detection
- depth-estimation
- synthetic-data
---

# MMDiff: Extending Diffusion Transformers for Multi-Modal Generation

*Yagmur Akarken, Orest Kupyn, Christian Rupprecht (Visual Geometry Group, University of Oxford)*

[**Paper**](https://arxiv.org/abs/2606.16673) · [**Code**](https://github.com/yagmurakarken/mmdiff) · [**Project page**](https://yagmurakarken.github.io/mmdiff/)

**MMDiff** turns a **frozen** diffusion transformer into a multi-modal generator. To create an
image, a diffusion model must build up the semantic and geometric structure of the scene, and it
normally discards that structure once the image is rendered. MMDiff keeps it and decodes it into
aligned dense outputs (semantic segmentation, salient object detection, monocular depth) in the
same generation pass. Because the maps come straight from the generator, one frozen model can
produce an image **and** its annotations together, enabling synthetic data generation at scale.

This repository hosts the trained **decoder-head checkpoints**. The FLUX.1-dev backbone and the
optional DINOv3 encoder are never fine-tuned; only the lightweight decoder heads and the
feature-fusion module (~36M parameters in total) are trained.

- **Backbone:** [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) (frozen) + optional [DINOv3 ViT-B/16](https://huggingface.co/facebook/dinov3-vitb16-pretrain-lvd1689m) (frozen)
- **Trained parameters:** ~36M (decoder heads + feature-fusion module only)

## How it works

1. **Multi-timestep feature fusion.** A small learned module reads FLUX features from several
   denoising steps and fuses them with spatially varying aggregation weights. Perceptual
   information is spread across the denoising trajectory rather than concentrated in one step,
   so this is the largest contributor: up to **+28.7% mIoU** over single-timestep extraction.
2. **Concept-driven attention.** The frozen model provides interpretable spatial guidance
   (e.g. object vs. background, near vs. far) as extra cues for each task.
3. **Per-task decoder.** A standard lightweight decoder is trained per task (DeepLabV3+ for
   segmentation/saliency, DPT for depth). The generator is never trained.
4. **Complementary to encoders.** Frozen FLUX features are competitive with, and complementary
   to, state-of-the-art encoders such as DINOv3; combining them improves every task.

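The fusion step in item 1 can be illustrated with a dependency-free sketch. It assumes the aggregation weights are a softmax over timesteps at each spatial location, which is one common way to obtain spatially varying weights; the card does not state the exact parameterisation, and the names `softmax`, `fuse_timesteps`, `features`, and `logits` are hypothetical. In the real module the scores would be predicted by learned layers operating on FLUX feature tensors.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def fuse_timesteps(features, logits):
    """Fuse per-timestep features with per-location weights.

    features[t][p] is the feature vector (list of floats) from timestep t at
    location p. logits[t][p] is an unnormalised score for that timestep at
    that location. Assumption: weights are a softmax over t for each p.
    """
    num_steps, num_locations = len(features), len(features[0])
    fused = []
    for p in range(num_locations):
        # One weight per timestep, specific to this location.
        weights = softmax([logits[t][p] for t in range(num_steps)])
        channels = len(features[0][p])
        fused.append([
            sum(weights[t] * features[t][p][c] for t in range(num_steps))
            for c in range(channels)
        ])
    return fused


# Toy input: two timesteps, two locations, one channel.
features = [[[1.0], [10.0]], [[3.0], [20.0]]]
logits = [[0.0, 5.0], [0.0, -5.0]]
fused = fuse_timesteps(features, logits)
```

At the first location both timesteps get equal weight, so the fused value is their mean; at the second, the first timestep dominates. That per-location difference is what "spatially varying" refers to here.
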
## Checkpoints

| File | Task | Dataset | Config (in code repo) |
|------|------|---------|------------------------|
| `pascal_segmentation.ckpt` | Semantic segmentation | Pascal VOC 2012 | `configs/pascal_voc_config.yaml` |
| `duts_saliency.ckpt` | Salient object detection | DUTS | `configs/duts_config.yaml` |
| `nyu_depth.ckpt` | Monocular depth | NYU Depth V2 | `configs/nyu_depth_config.yaml` |

Each file is a PyTorch Lightning checkpoint; inference loads only the model weights (`state_dict`).

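As a rough illustration of what "loads only the model weights" means, the sketch below pulls the `state_dict` entry out of a Lightning-style checkpoint dictionary and optionally strips a key prefix. The helper name `extract_weights` and the `model.` prefix are assumptions made for the example; the actual key layout of these checkpoints is defined by the code repository. A plain dictionary stands in for the loaded file so the snippet runs without PyTorch.

```python
def extract_weights(checkpoint, prefix=""):
    """Return only the model weights from a Lightning-style checkpoint dict.

    Keys that do not start with `prefix` are dropped, and the prefix is
    removed from the keys that remain.
    """
    state_dict = checkpoint["state_dict"]
    if not prefix:
        return dict(state_dict)
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


# With PyTorch installed, the dict would come from something like:
#   checkpoint = torch.load(ckpt_path, map_location="cpu")
# Here a stand-in dict with made-up keys and values is used instead.
checkpoint = {
    "epoch": 12,
    "optimizer_states": [{}],
    "state_dict": {"model.head.weight": [0.1, 0.2], "model.head.bias": [0.0]},
}
weights = extract_weights(checkpoint, prefix="model.")
print(sorted(weights))  # ['head.bias', 'head.weight']
```

Depending on the PyTorch version, loading a full Lightning checkpoint may require passing `weights_only=False` to `torch.load`; only do that for files you trust.
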
## Results

**Feature quality** (frozen backbone, lightweight heads). Higher is better unless marked ↓.

| Variant | VOC mIoU ↑ | DUTS Sₘ ↑ | DUTS MAE ↓ | NYU AbsRel ↓ | NYU RMSE ↓ |
|---------|-----------|-----------|------------|--------------|------------|
| MMDiff | 78.9 | 0.918 | 0.020 | 0.1175 | 0.370 |
| MMDiff + DINOv3 | 84.95 | 0.934 | 0.018 | 0.1164 | 0.365 |

**Synthetic-data training** (decoders trained on MMDiff-generated images + labels).

| Setting | VOC mIoU ↑ | DUTS Sₘ ↑ | NYU AbsRel ↓ |
|---------|-----------|-----------|--------------|
| Synthetic only | 78.9 | 0.784 | 0.1880 |
| Synthetic + real fine-tune | 87.8 | 0.863 | 0.1185 |

Trained purely on synthetic data, MMDiff outperforms prior synthetic-data methods (DatasetDM,
DiffuMask, Dataset Diffusion). See the [paper](https://arxiv.org/abs/2606.16673) for full tables.

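For readers less familiar with the column headers, here is a minimal sketch of the textbook definitions of AbsRel, RMSE, MAE, and mIoU on flat lists of values. It is not the evaluation code behind the tables: benchmark protocols usually add details such as valid-pixel masks, ignore labels, and dataset-level accumulation, and the S-measure (Sₘ) is not covered. The function names are illustrative.

```python
import math


def abs_rel(pred, target):
    """Mean absolute relative error; target depths must be positive."""
    return sum(abs(p - t) / t for p, t in zip(pred, target)) / len(target)


def rmse(pred, target):
    """Root mean squared error."""
    return math.sqrt(sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target))


def mae(pred, target):
    """Mean absolute error, e.g. between saliency maps scaled to [0, 1]."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(target)


def mean_iou(pred, target, num_classes):
    """Mean intersection-over-union over classes that occur in pred or target."""
    ious = []
    for c in range(num_classes):
        inter = sum(1 for p, t in zip(pred, target) if p == c and t == c)
        union = sum(1 for p, t in zip(pred, target) if p == c or t == c)
        if union > 0:
            ious.append(inter / union)
    return sum(ious) / len(ious)


# Toy depth values (metres) and toy class labels.
depth_error = abs_rel([2.0, 4.0], [1.0, 4.0])          # (1/1 + 0/4) / 2
label_score = mean_iou([0, 0, 1, 1], [0, 1, 1, 1], 2)  # (1/2 + 2/3) / 2
```

The VOC column appears to report mIoU as a percentage, so a value such as 78.9 would correspond to 0.789 on the 0 to 1 scale that `mean_iou` returns.
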
## Usage

Clone the [code repository](https://github.com/yagmurakarken/mmdiff), then download a checkpoint:

```python
from huggingface_hub import hf_hub_download

# Downloads the file into the local Hugging Face cache and returns its path.
ckpt = hf_hub_download("yagmurakarken/mmdiff", "pascal_segmentation.ckpt")
print(ckpt)
```

Run inference from the cloned repository, with the shell variable `ckpt` set to the path printed
above:

```bash
python scripts/inference.py --task pascal \
  --config configs/pascal_voc_config.yaml --checkpoint "$ckpt" \
  --image my_image.jpg --output_dir outputs/
```

Swap `pascal` → `duts` / `nyu` (with the matching config and checkpoint) for the other tasks.

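One way to keep task, config, and checkpoint in sync is a small lookup that assembles the command line. The sketch below is not part of the repository: `TASKS` and `inference_command` are hypothetical names, and the pairing of each task name with its files is inferred from the checkpoint table and the command above. It uses only the flags shown in this card.

```python
# Task name -> (checkpoint file on the Hub, config path in the code repo).
TASKS = {
    "pascal": ("pascal_segmentation.ckpt", "configs/pascal_voc_config.yaml"),
    "duts": ("duts_saliency.ckpt", "configs/duts_config.yaml"),
    "nyu": ("nyu_depth.ckpt", "configs/nyu_depth_config.yaml"),
}


def inference_command(task, checkpoint_path, image, output_dir="outputs/"):
    """Build the argument list for scripts/inference.py for one task."""
    if task not in TASKS:
        raise ValueError(f"unknown task {task!r}; expected one of {sorted(TASKS)}")
    _, config = TASKS[task]
    return [
        "python", "scripts/inference.py",
        "--task", task,
        "--config", config,
        "--checkpoint", checkpoint_path,
        "--image", image,
        "--output_dir", output_dir,
    ]


cmd = inference_command("nyu", "/tmp/nyu_depth.ckpt", "room.jpg")
print(" ".join(cmd))
```

The returned list can be passed to `subprocess.run(cmd, check=True)` with the cloned repository as the working directory. The first element of each `TASKS` entry is the filename to request from `hf_hub_download`.
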
> Architecture flags (`--hidden_dim`, `--num_transformer_layers`, `--num_timesteps`,
> `--dino_model`) must match the checkpoint you load.

You can also generate an image and its aligned annotation together with the same frozen backbone:

```bash
python scripts/generate.py --task pascal \
  --config configs/pascal_voc_config.yaml --checkpoint "$ckpt" \
  --prompts_file prompts.txt --output_dir synth/
```

## Intended uses & limitations

- **Intended use:** dense prediction (segmentation, saliency, depth) from generated or real images,
  and large-scale synthetic dataset generation with aligned labels for research.
- **Requires FLUX.1-dev.** These are decoder heads only, so you need access to the (gated)
  [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) backbone to run them.
- **Not a plug-and-play `transformers`/`diffusers` pipeline.** Use the code repository to load and
  run the checkpoints.
- **Domain shift.** Performance reflects the training datasets listed under Checkpoints (Pascal VOC
  2012, DUTS, NYU Depth V2) and may not transfer to very different domains.

## Citation

```bibtex
@article{akarken2026mmdiff,
  title = {{MMDiff}: Extending Diffusion Transformers for Multi-Modal Generation},
  author = {Akarken, Yagmur and Kupyn, Orest and Rupprecht, Christian},
  journal = {arXiv preprint arXiv:2606.16673},
  year = {2026}
}
```

## License

The decoder checkpoints and code in this project are released under the **MIT License**. Note that
the frozen **FLUX.1-dev** backbone they depend on is governed by the
[FLUX.1-dev Non-Commercial License](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md),
and **DINOv3** by its own license. Review both before any non-research use.

## Acknowledgements

Built on [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev),
[DINOv3](https://github.com/facebookresearch/dinov3),
[diffusers](https://github.com/huggingface/diffusers), and
[PyTorch Lightning](https://github.com/Lightning-AI/pytorch-lightning).