---
license: mit
library_name: pytorch
pipeline_tag: image-segmentation
base_model: black-forest-labs/FLUX.1-dev
tags:
- diffusion
- diffusion-transformer
- flux
- dinov3
- dense-prediction
- semantic-segmentation
- salient-object-detection
- depth-estimation
- synthetic-data
---
# MMDiff: Extending Diffusion Transformers for Multi-Modal Generation
*Yagmur Akarken, Orest Kupyn, Christian Rupprecht — Visual Geometry Group, University of Oxford*
[![arXiv](https://img.shields.io/badge/arXiv-2606.16673-b31b1b.svg)](https://arxiv.org/abs/2606.16673)
 [**Paper**](https://arxiv.org/abs/2606.16673) · [**Code**](https://github.com/yagmurakarken/mmdiff) · [**Project page**](https://yagmurakarken.github.io/mmdiff/)
**MMDiff** turns a **frozen** diffusion transformer into a multi-modal generator. To create an
image, a diffusion model must build up the semantic and geometric structure of the scene — and
normally discards it once the image is rendered. MMDiff keeps that structure and decodes it into
aligned dense outputs (semantic segmentation, salient object detection, monocular depth) in the
same generation pass. Because the maps come straight from the generator, one frozen model can
produce an image **and** its annotations together, enabling synthetic data generation at scale.
This repository hosts the trained **decoder-head checkpoints**. The FLUX.1-dev backbone and the
optional DINOv3 encoder are never fine-tuned; only the lightweight decoder heads and the
feature-fusion module (~36M parameters in total) are trained.
- **Backbone:** [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) (frozen) + optional [DINOv3 ViT-B/16](https://huggingface.co/facebook/dinov3-vitb16-pretrain-lvd1689m) (frozen)
- **Trained parameters:** ~36M (decoder heads + feature-fusion module only)
## How it works
1. **Multi-timestep feature fusion.** A small learned module reads FLUX features from several
denoising steps and fuses them with spatially varying aggregation weights. Perceptual
information is spread across the denoising trajectory rather than concentrated at one step, so
this component contributes the most: up to **+28.7% mIoU** over single-timestep extraction.
2. **Concept-driven attention.** The frozen model provides interpretable spatial guidance
(e.g. object vs. background, near vs. far) as extra cues for each task.
3. **Per-task decoder.** A standard lightweight decoder is trained per task (DeepLabV3+ for
segmentation/saliency, DPT for depth). The generator is never trained.
4. **Complementary to encoders.** Frozen FLUX features are competitive with — and complementary
to — state-of-the-art encoders such as DINOv3; combining them improves every task.
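Step 1 can be pictured as a per-pixel softmax over timesteps. The sketch below is a minimal NumPy illustration under assumed shapes, not the MMDiff implementation: `fuse_timesteps`, the `(T, C, H, W)` layout, and the scoring vector `w` (standing in for a learned projection) are all hypothetical, and the actual module in the code repository may be structured differently.

```python
import numpy as np

def fuse_timesteps(feats: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Fuse features from T denoising steps with per-pixel weights.

    feats: (T, C, H, W) features, one map per timestep (assumed layout).
    w:     (C,) scoring vector, a stand-in for a learned 1x1 projection.
    Returns a (C, H, W) fused feature map.
    """
    # One scalar score per timestep and pixel: shape (T, H, W).
    scores = np.einsum("tchw,c->thw", feats, w)
    # Softmax over the timestep axis, so the weights at each pixel sum to 1.
    scores = scores - scores.max(axis=0, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=0, keepdims=True)
    # Weighted sum over timesteps; weights broadcast across channels.
    return (weights[:, None] * feats).sum(axis=0)

rng = np.random.default_rng(0)
feats = rng.normal(size=(4, 8, 16, 16))  # 4 timesteps, 8 channels, 16x16 grid
fused = fuse_timesteps(feats, rng.normal(size=8))
print(fused.shape)
```

Because the weights vary per pixel, different image regions can draw on different denoising steps, which is the property the "spatially varying" wording refers to.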
## Checkpoints
| File | Task | Dataset | Config (in code repo) |
|------|------|---------|------------------------|
| `pascal_segmentation.ckpt` | Semantic segmentation | Pascal VOC 2012 | `configs/pascal_voc_config.yaml` |
| `duts_saliency.ckpt` | Salient object detection | DUTS | `configs/duts_config.yaml` |
| `nyu_depth.ckpt` | Monocular depth | NYU Depth V2 | `configs/nyu_depth_config.yaml` |
Each file is a PyTorch Lightning checkpoint; inference loads only the model weights (`state_dict`).
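To inspect the weights outside the provided scripts, the `state_dict` can be read directly. This is a minimal sketch assuming the standard Lightning checkpoint layout (a pickled dict with a `state_dict` key); `strip_prefix`, `load_weights`, and the `model.` key prefix are illustrative assumptions, so check the actual key names in the checkpoint you load.

```python
def strip_prefix(state_dict: dict, prefix: str) -> dict:
    """Drop `prefix` from keys that start with it; leave other keys unchanged."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

def load_weights(ckpt_path: str, prefix: str = "model.") -> dict:
    """Read only the weights from a Lightning checkpoint file."""
    import torch  # imported lazily so strip_prefix works without PyTorch

    # Lightning checkpoints can hold non-tensor objects next to the weights,
    # hence weights_only=False. Only load checkpoint files you trust.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    return strip_prefix(ckpt["state_dict"], prefix)

# Toy stand-in for a state_dict (real values would be tensors).
demo = {"model.decoder.weight": 1, "model.decoder.bias": 2, "head.bias": 3}
print(strip_prefix(demo, "model."))
```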
## Results
**Feature quality** (frozen backbone, lightweight heads). Higher is better unless marked ↓.
| Variant | VOC mIoU ↑ | DUTS Sₘ ↑ | DUTS MAE ↓ | NYU AbsRel ↓ | NYU RMSE ↓ |
|---------|-----------|-----------|------------|--------------|------------|
| MMDiff | 78.9 | 0.918 | 0.020 | 0.1175 | 0.370 |
| MMDiff + DINOv3 | 84.95 | 0.934 | 0.018 | 0.1164 | 0.365 |
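For readers unfamiliar with the column names, the sketch below spells out the usual definitions of mIoU, AbsRel, and RMSE on flat lists of pixels. It illustrates the standard formulas only and is not the evaluation code behind the table: `mean_iou`, `abs_rel`, and `rmse` are hypothetical helper names, and protocol details such as ignore labels, valid-pixel masks, or depth caps are omitted.

```python
import math

def mean_iou(pred, gt, num_classes):
    """Mean intersection-over-union across classes present in pred or gt."""
    ious = []
    for c in range(num_classes):
        inter = sum(1 for p, g in zip(pred, gt) if p == c and g == c)
        union = sum(1 for p, g in zip(pred, gt) if p == c or g == c)
        if union:  # skip classes absent from both prediction and ground truth
            ious.append(inter / union)
    return sum(ious) / len(ious)

def abs_rel(pred, gt):
    """Mean of |pred - gt| / gt over pixels with positive ground-truth depth."""
    pairs = [(p, g) for p, g in zip(pred, gt) if g > 0]
    return sum(abs(p - g) / g for p, g in pairs) / len(pairs)

def rmse(pred, gt):
    """Root-mean-square error over pixels with positive ground-truth depth."""
    pairs = [(p, g) for p, g in zip(pred, gt) if g > 0]
    return math.sqrt(sum((p - g) ** 2 for p, g in pairs) / len(pairs))

# Tiny toy inputs: four labelled pixels and two depth values.
print(mean_iou([0, 0, 1, 1], [0, 1, 1, 1], num_classes=2))
print(abs_rel([2.0, 4.0], [2.0, 5.0]), rmse([1.0, 3.0], [1.0, 1.0]))
```

`mean_iou` returns a fraction in [0, 1]; the table reports mIoU as a percentage.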
**Synthetic-data training** (decoders trained on MMDiff-generated images + labels).
| Setting | VOC mIoU ↑ | DUTS Sₘ ↑ | NYU AbsRel ↓ |
|---------|-----------|-----------|--------------|
| Synthetic only | 78.9 | 0.784 | 0.1880 |
| Synthetic + real fine-tune | 87.8 | 0.863 | 0.1185 |
Trained purely on synthetic data, MMDiff outperforms prior synthetic-data methods (DatasetDM,
DiffuMask, Dataset Diffusion). See the [paper](https://arxiv.org/abs/2606.16673) for full tables.
## Usage
Clone the [code repository](https://github.com/yagmurakarken/mmdiff), then download a checkpoint
and run inference:
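```python
from huggingface_hub import hf_hub_download
# Downloads into the local Hugging Face cache and returns the file path.
ckpt = hf_hub_download("yagmurakarken/mmdiff", "pascal_segmentation.ckpt")
print(ckpt)
```
```bash
# Set ckpt to the path printed by the Python snippet above.
ckpt=/path/to/pascal_segmentation.ckpt
python scripts/inference.py --task pascal \
  --config configs/pascal_voc_config.yaml --checkpoint "$ckpt" \
  --image my_image.jpg --output_dir outputs/
```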
Swap `pascal` for `duts` or `nyu` (with the matching config and checkpoint) to run the other tasks.
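One way to keep the task name, config, and checkpoint in sync is to store the mapping from the Checkpoints table in a single place. In this small sketch, `TASKS` and `inference_command` are hypothetical helpers that are not part of the code repository; the flags and file names are the ones shown in this README.

```python
from pathlib import Path

# Task name -> (config in the code repo, checkpoint file in this repository).
TASKS = {
    "pascal": ("configs/pascal_voc_config.yaml", "pascal_segmentation.ckpt"),
    "duts": ("configs/duts_config.yaml", "duts_saliency.ckpt"),
    "nyu": ("configs/nyu_depth_config.yaml", "nyu_depth.ckpt"),
}

def inference_command(task: str, image: str, output_dir: str,
                      ckpt_dir: str = ".") -> list[str]:
    """Build the scripts/inference.py argument list for one task."""
    config, ckpt_name = TASKS[task]  # raises KeyError for an unknown task
    return [
        "python", "scripts/inference.py",
        "--task", task,
        "--config", config,
        "--checkpoint", str(Path(ckpt_dir) / ckpt_name),
        "--image", image,
        "--output_dir", output_dir,
    ]

print(" ".join(inference_command("nyu", "my_image.jpg", "outputs/")))
```

The returned list can be passed to `subprocess.run(cmd, check=True)` from the root of the cloned code repository.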
> Architecture flags (`--hidden_dim`, `--num_transformer_layers`, `--num_timesteps`,
> `--dino_model`) must match the checkpoint you load.
You can also generate an image and its aligned annotation together with the same frozen backbone:
```bash
python scripts/generate.py --task pascal \
--config configs/pascal_voc_config.yaml --checkpoint "$ckpt" \
--prompts_file prompts.txt --output_dir synth/
```
## Intended uses & limitations
- **Intended use:** dense prediction (segmentation, saliency, depth) from generated or real images,
and large-scale synthetic dataset generation with aligned labels for research.
- **Requires FLUX.1-dev.** These are decoder heads only — you need access to the (gated)
[FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) backbone to run them.
- **Not a plug-and-play `transformers`/`diffusers` pipeline.** Use the code repository to load and
run the checkpoints.
- Performance reflects the training datasets above and may not transfer to very different domains.
## Citation
```bibtex
@article{akarken2026mmdiff,
title = {{MMDiff}: Extending Diffusion Transformers for Multi-Modal Generation},
author = {Akarken, Yagmur and Kupyn, Orest and Rupprecht, Christian},
journal = {arXiv preprint arXiv:2606.16673},
year = {2026}
}
```
## License
The decoder checkpoints and code in this project are released under the **MIT License**. Note that
the frozen **FLUX.1-dev** backbone they depend on is governed by the
[FLUX.1-dev Non-Commercial License](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md),
and **DINOv3** by its own license — review those before any non-research use.
## Acknowledgements
Built on [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev),
[DINOv3](https://github.com/facebookresearch/dinov3),
[diffusers](https://github.com/huggingface/diffusers), and
[PyTorch Lightning](https://github.com/Lightning-AI/pytorch-lightning).