MMDiff: Extending Diffusion Transformers for Multi-Modal Generation

Yagmur Akarken, Orest Kupyn, Christian Rupprecht — Visual Geometry Group, University of Oxford

arXiv paper · Code · Project page

MMDiff turns a frozen diffusion transformer into a multi-modal generator. To create an image, a diffusion model must build up the semantic and geometric structure of the scene — and normally discards it once the image is rendered. MMDiff keeps that structure and decodes it into aligned dense outputs (semantic segmentation, salient object detection, monocular depth) in the same generation pass. Because the maps come straight from the generator, one frozen model can produce an image and its annotations together, enabling synthetic data generation at scale.

This repository hosts the trained decoder-head checkpoints. The FLUX.1-dev backbone and the optional DINOv3 encoder are never fine-tuned — only lightweight heads (~36M parameters) are trained.

  • Backbone: FLUX.1-dev (frozen) + optional DINOv3 ViT-B/16 (frozen)
  • Trained parameters: ~36M (decoder heads + feature-fusion module only)

How it works

  1. Multi-timestep feature fusion. A small learned module reads FLUX features from several denoising steps and fuses them with spatially varying aggregation weights. Because perceptual information is spread across the denoising trajectory rather than concentrated at a single step, this fusion is the largest single contributor: up to +28.7% mIoU over single-timestep extraction.
  2. Concept-driven attention. The frozen model provides interpretable spatial guidance (e.g. object vs. background, near vs. far) as extra cues for each task.
  3. Per-task decoder. A standard lightweight decoder is trained per task (DeepLabV3+ for segmentation/saliency, DPT for depth). The generator is never trained.
  4. Complementary to encoders. Frozen FLUX features are competitive with — and complementary to — state-of-the-art encoders such as DINOv3; combining them improves every task.
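The fusion idea in step 1 can be illustrated with a minimal NumPy sketch. It assumes that features from T denoising steps have already been extracted and resized to a common shape (T, C, H, W), and that a learned head has produced one score per timestep per pixel; here those scores are simply passed in. A softmax over the timestep axis turns them into spatially varying weights that sum to 1 at every pixel. This illustrates the technique only and is not the MMDiff fusion module, whose exact parameterization is in the paper and code.

```python
import numpy as np


def fuse_timesteps(features: np.ndarray, logits: np.ndarray) -> np.ndarray:
    """Fuse per-timestep feature maps with spatially varying weights.

    features: (T, C, H, W) backbone features from T denoising steps.
    logits:   (T, H, W) per-pixel score for each timestep; in a trained model
              these would come from a small learned module (assumed here).
    Returns:  (C, H, W) fused feature map.
    """
    if features.ndim != 4 or logits.shape != (features.shape[0],) + features.shape[2:]:
        raise ValueError("expected features (T, C, H, W) and logits (T, H, W)")
    # Numerically stable softmax over the timestep axis: weights sum to 1 per pixel.
    shifted = logits - logits.max(axis=0, keepdims=True)
    weights = np.exp(shifted)
    weights /= weights.sum(axis=0, keepdims=True)
    # Broadcast the weights over channels and take the weighted sum across timesteps.
    return (weights[:, None, :, :] * features).sum(axis=0)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    feats = rng.standard_normal((4, 8, 16, 16))  # 4 timesteps, 8 channels
    scores = rng.standard_normal((4, 16, 16))    # stand-in for learned scores
    print(fuse_timesteps(feats, scores).shape)
```

With uniform scores this reduces to a plain average over timesteps; the learned scores let each pixel favor the steps where its information is clearest.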

Checkpoints

File                      | Task                     | Dataset         | Config (in code repo)
pascal_segmentation.ckpt  | Semantic segmentation    | Pascal VOC 2012 | configs/pascal_voc_config.yaml
duts_saliency.ckpt        | Salient object detection | DUTS            | configs/duts_config.yaml
nyu_depth.ckpt            | Monocular depth          | NYU Depth V2    | configs/nyu_depth_config.yaml

Each file is a PyTorch Lightning checkpoint; inference loads only the model weights (state_dict).
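As a rough illustration of what "loads only the model weights" means: a Lightning checkpoint is a dictionary whose "state_dict" entry holds the weights, next to training state (epoch, optimizer states, and so on) that inference ignores. The helper below is a hypothetical sketch, not part of the MMDiff code; its optional prefix-stripping argument is an assumption about how the heads might be nested, so check the actual key names before relying on it.

```python
from typing import Any, Dict, Mapping


def extract_state_dict(checkpoint: Mapping[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Return only the model weights from a loaded Lightning checkpoint dict.

    `prefix` is a hypothetical option: if given, keep only keys that start
    with it and strip it (e.g. "model.head.weight" -> "head.weight").
    """
    if "state_dict" not in checkpoint:
        raise KeyError("missing 'state_dict': not a PyTorch Lightning checkpoint?")
    weights = checkpoint["state_dict"]
    if not prefix:
        return dict(weights)
    return {key[len(prefix):]: value for key, value in weights.items() if key.startswith(prefix)}


if __name__ == "__main__":
    # Stand-in for a loaded checkpoint: Lightning also stores training state we ignore.
    fake_ckpt = {
        "epoch": 10,
        "optimizer_states": [],
        "state_dict": {"model.head.weight": [1.0], "model.head.bias": [0.0]},
    }
    print(extract_state_dict(fake_ckpt, prefix="model."))
```

With PyTorch installed, the real input would be the result of `torch.load(path, map_location="cpu")`. On recent PyTorch versions, loading a full Lightning checkpoint may require `weights_only=False`, which should only be used for files you trust.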

Results

Feature quality (frozen backbone, lightweight heads). ↑ means higher is better, ↓ means lower is better; Sₘ is the structure measure (S-measure).

Variant          | VOC mIoU ↑ | DUTS Sₘ ↑ | DUTS MAE ↓ | NYU AbsRel ↓ | NYU RMSE ↓
MMDiff           | 78.9       | 0.918     | 0.020      | 0.1175       | 0.370
MMDiff + DINOv3  | 84.95      | 0.934     | 0.018      | 0.1164       | 0.365
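The claim that adding DINOv3 improves every task can be checked directly against the feature-quality table. The short script below uses only the numbers reported there; it computes each difference and flips the sign for ↓ metrics, so a positive value always means "better".

```python
# Numbers copied from the feature-quality table: (MMDiff, MMDiff + DINOv3, higher_is_better).
METRICS = {
    "VOC mIoU": (78.9, 84.95, True),
    "DUTS S_m": (0.918, 0.934, True),
    "DUTS MAE": (0.020, 0.018, False),
    "NYU AbsRel": (0.1175, 0.1164, False),
    "NYU RMSE": (0.370, 0.365, False),
}


def improvements(metrics):
    """Map each metric to its improvement from MMDiff to MMDiff + DINOv3.

    Positive always means better: for lower-is-better metrics the sign is flipped.
    """
    out = {}
    for name, (base, combined, higher_is_better) in metrics.items():
        delta = combined - base if higher_is_better else base - combined
        out[name] = round(delta, 4)  # round away floating-point noise
    return out


if __name__ == "__main__":
    for name, gain in improvements(METRICS).items():
        print(f"{name}: {gain:+}")
```

All five values come out positive, consistent with the statement that combining the two backbones helps every task; the gain is largest on VOC segmentation (+6.05 mIoU) and small on NYU depth.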

Synthetic-data training (decoders trained on MMDiff-generated images + labels).

Setting                     | VOC mIoU ↑ | DUTS Sₘ ↑ | NYU AbsRel ↓
Synthetic only              | 78.9       | 0.784     | 0.1880
Synthetic + real fine-tune  | 87.8       | 0.863     | 0.1185

Decoders trained purely on MMDiff-generated data outperform prior synthetic-data methods (DatasetDM, DiffuMask, Dataset Diffusion). See the paper for the full tables.

Usage

Clone the code repository, then, from its root, download a checkpoint and run inference:

# Download the checkpoint with huggingface_hub and keep its local path in $ckpt
ckpt=$(python -c 'from huggingface_hub import hf_hub_download; print(hf_hub_download("yagmurakarken/mmdiff", "pascal_segmentation.ckpt"))')

python scripts/inference.py --task pascal \
    --config configs/pascal_voc_config.yaml --checkpoint "$ckpt" \
    --image my_image.jpg --output_dir outputs/

Swap pascal for duts or nyu (with the matching config and checkpoint from the Checkpoints table) to run the other tasks.
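To keep each task paired with its config and checkpoint, the mapping from the Checkpoints table can live in one place. The sketch below builds the scripts/inference.py command for each task using only the flags shown in the usage example; it assumes the --task values are pascal, duts, and nyu, and that each checkpoint has been downloaded into a hypothetical checkpoints/ directory.

```python
import shlex
from pathlib import Path
from typing import List

# Task name -> (config, checkpoint file), following the Checkpoints table.
TASKS = {
    "pascal": ("configs/pascal_voc_config.yaml", "pascal_segmentation.ckpt"),
    "duts": ("configs/duts_config.yaml", "duts_saliency.ckpt"),
    "nyu": ("configs/nyu_depth_config.yaml", "nyu_depth.ckpt"),
}


def inference_command(task: str, image: str, output_dir: str,
                      ckpt_dir: str = "checkpoints") -> List[str]:
    """Return the argv list for scripts/inference.py with the matching config and checkpoint."""
    if task not in TASKS:
        raise ValueError(f"unknown task {task!r}; expected one of {sorted(TASKS)}")
    config, ckpt_name = TASKS[task]
    return [
        "python", "scripts/inference.py", "--task", task,
        "--config", config,
        "--checkpoint", str(Path(ckpt_dir) / ckpt_name),
        "--image", image,
        "--output_dir", output_dir,
    ]


if __name__ == "__main__":
    for task in TASKS:
        cmd = inference_command(task, "my_image.jpg", f"outputs/{task}/")
        # Print a shell-ready line; pass `cmd` to subprocess.run(cmd, check=True) to execute it.
        print(shlex.join(cmd))
```

Returning an argument list rather than a single string avoids shell-quoting problems when image paths contain spaces.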

Architecture flags (--hidden_dim, --num_transformer_layers, --num_timesteps, --dino_model) must match the checkpoint you load.
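One way to catch a flag/checkpoint mismatch before loading the backbone is to compare your flags with any hyperparameters saved in the checkpoint. Lightning stores them under a "hyper_parameters" key only when the module calls save_hyperparameters(); whether the MMDiff checkpoints do so, and under which key names, is an assumption here, so this sketch reports unrecorded values rather than treating them as errors.

```python
from typing import Any, Dict, List, Mapping

# Architecture flags named in this README; matching checkpoint key names are assumed.
ARCH_FLAGS = ("hidden_dim", "num_transformer_layers", "num_timesteps", "dino_model")


def arch_mismatches(checkpoint: Mapping[str, Any], flags: Mapping[str, Any]) -> List[str]:
    """List problems between command-line flags and hyperparameters saved in a checkpoint."""
    saved: Dict[str, Any] = checkpoint.get("hyper_parameters", {}) or {}
    problems = []
    for name in ARCH_FLAGS:
        if name not in flags:
            continue  # flag not set on the command line; nothing to compare
        if name not in saved:
            problems.append(f"{name}: not recorded in checkpoint, cannot verify")
        elif saved[name] != flags[name]:
            problems.append(f"{name}: flag={flags[name]!r} but checkpoint={saved[name]!r}")
    return problems


if __name__ == "__main__":
    # Hypothetical values for illustration only, not the released configuration.
    ckpt = {"hyper_parameters": {"hidden_dim": 256, "num_timesteps": 3}}
    print(arch_mismatches(ckpt, {"hidden_dim": 512, "num_timesteps": 3, "dino_model": None}))
```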

You can also generate an image and its aligned annotation together with the same frozen backbone:

python scripts/generate.py --task pascal \
    --config configs/pascal_voc_config.yaml --checkpoint "$ckpt" \
    --prompts_file prompts.txt --output_dir synth/

Intended uses & limitations

  • Intended use: dense prediction (segmentation, saliency, depth) from generated or real images, and large-scale synthetic dataset generation with aligned labels for research.
  • Requires FLUX.1-dev. These are decoder heads only — you need access to the (gated) FLUX.1-dev backbone to run them.
  • Not a plug-and-play transformers/diffusers pipeline. Use the code repository to load and run the checkpoints.
  • Performance reflects the training datasets above and may not transfer to very different domains.

Citation

@article{akarken2026mmdiff,
  title   = {{MMDiff}: Extending Diffusion Transformers for Multi-Modal Generation},
  author  = {Akarken, Yagmur and Kupyn, Orest and Rupprecht, Christian},
  journal = {arXiv preprint arXiv:2606.16673},
  year    = {2026}
}

License

The decoder checkpoints and code in this project are released under the MIT License. Note that the frozen FLUX.1-dev backbone they depend on is governed by the FLUX.1-dev Non-Commercial License, and DINOv3 by its own license — review those before any non-research use.

Acknowledgements

Built on FLUX.1-dev, DINOv3, diffusers, and PyTorch Lightning.
