---
license: apache-2.0
tags:
- diffusion
- autoencoder
- image-reconstruction
- pytorch
- masked-autoencoder
library_name: mdiffae
---
# mdiffae-v2
**mDiffAE v2** — **M**asked **Diff**usion **A**uto**E**ncoder v2.
A fast, single-GPU-trainable diffusion autoencoder with a **96-channel**
spatial bottleneck and optional PDG sharpening.

**This is the recommended version** — it offers substantially better
reconstruction than [v1](https://huggingface.co/data-archetype/mdiffae-v1)
(+1.7 dB mean PSNR) while maintaining the same or better convergence for
downstream latent diffusion models.

This variant (mdiffae-v2) has 120.9M parameters (461.2 MB on disk), with a
**96-channel** bottleneck at patch size 16 (compression ratio 8x).
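The compression figure above follows directly from the patch size and
bottleneck width. A minimal arithmetic check (shapes are inferred from the
numbers above, not queried from the mdiffae API):

```python
# Latent-shape arithmetic for mdiffae-v2 (illustrative; derived from the
# patch size and bottleneck dim stated above).
patch, bottleneck, in_ch = 16, 96, 3

values_per_patch = in_ch * patch * patch     # 3 * 16 * 16 = 768 input values
compression = values_per_patch / bottleneck  # 768 / 96 = 8.0

# For a 256x256 input, the spatial latent grid is 16x16 tokens:
H, W = 256, 256
latent_shape = (bottleneck, H // patch, W // patch)  # (96, 16, 16)
print(compression, latent_shape)
```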
## Documentation
- [Technical Report](technical_report_mdiffae_v2.md) — architecture, training changes from v1, and results
- [Results — interactive viewer](https://huggingface.co/spaces/data-archetype/mdiffae-v2-results) — full-resolution side-by-side comparison
- [mDiffAE v1](https://huggingface.co/data-archetype/mdiffae-v1) — previous version
- [iRDiffAE Technical Report](https://huggingface.co/data-archetype/irdiffae-v1/blob/main/technical_report.md) — full background on VP diffusion, DiCo blocks, patchify encoder, AdaLN
## Quick Start
```python
import torch
from m_diffae_v2 import MDiffAEV2

# Load from HuggingFace Hub (or a local path)
model = MDiffAEV2.from_pretrained("data-archetype/mdiffae-v2", device="cuda")

# Encode: images are [B, 3, H, W] in [-1, 1], H and W divisible by 16
B, H, W = 1, 256, 256
images = torch.rand(B, 3, H, W, device="cuda") * 2 - 1  # dummy input for illustration
latents = model.encode(images)

# Decode (2 steps by default — PSNR-optimal)
recon = model.decode(latents, height=H, width=W)

# Reconstruct in one call (encode + 2-step decode)
recon = model.reconstruct(images)
```
> **Note:** Requires `pip install huggingface_hub safetensors` for Hub downloads.
> You can also pass a local directory path to `from_pretrained()`.
## Architecture
| Property | Value |
|---|---|
| Parameters | 120,893,792 |
| File size | 461.2 MB |
| Patch size | 16 |
| Model dim | 896 |
| Encoder depth | 4 |
| Decoder depth | 8 (2+4+2 skip-concat) |
| Bottleneck dim | 96 |
| Compression ratio | 8x |
| MLP ratio | 4.0 |
| Depthwise kernel | 7 |
| AdaLN rank | 128 |
| PDG | Conditioning degradation for CFG-style sharpening at inference |
| Training regularizer | Token masking (25-75% ratio, 90% apply prob) + Path drop (10% drop prob) |
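The token-masking regularizer in the table can be sketched in a few lines.
This is a hedged illustration of the described behavior (zero out a uniformly
drawn fraction of latent tokens with a given apply probability); the function
name and zero-masking mechanism are assumptions, not the mdiffae API:

```python
import torch

def mask_tokens(tokens, apply_prob=0.9, ratio_range=(0.25, 0.75)):
    """Sketch of the token-masking regularizer: with probability
    `apply_prob`, zero out a uniformly drawn fraction of latent tokens
    (per-sample random selection). tokens: [B, N, D]."""
    B, N, D = tokens.shape
    if torch.rand(()) >= apply_prob:
        return tokens
    ratio = torch.empty(()).uniform_(*ratio_range).item()
    n_mask = int(N * ratio)
    # Pick n_mask random token indices independently for each sample
    idx = torch.rand(B, N).argsort(dim=1)[:, :n_mask]
    masked = tokens.clone()
    masked.scatter_(1, idx.unsqueeze(-1).expand(-1, -1, D), 0.0)
    return masked
```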
**Encoder**: Deterministic. Patchify (PixelUnshuffle + 1x1 conv) followed by
DiCo blocks with learned residual gates. No input RMSNorm. Post-bottleneck
RMSNorm (affine=False) normalizes the latent tokens.

**Decoder**: VP diffusion conditioned on encoder latents and timestep via
shared-base + per-layer low-rank AdaLN-Zero. Skip-concat topology
(2 start + 4 middle + 2 end blocks) with skip connections from start to end
blocks. No outer RMSNorms (input, latent conditioning, and output norms all
removed).
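The 2+4+2 skip-concat wiring can be illustrated with a toy module. Block
internals (DiCo, AdaLN conditioning) are replaced with plain linear layers
here; only the skip topology is shown, and all names are illustrative:

```python
import torch
import torch.nn as nn

class SkipConcatDecoder(nn.Module):
    """Toy illustration of a 2+4+2 skip-concat topology: outputs of the
    2 start blocks are saved and concatenated (in reverse order) onto the
    inputs of the 2 end blocks."""
    def __init__(self, dim=64):
        super().__init__()
        self.start = nn.ModuleList([nn.Linear(dim, dim) for _ in range(2)])
        self.middle = nn.ModuleList([nn.Linear(dim, dim) for _ in range(4)])
        # End blocks consume [hidden, skip] concatenated, project back to dim
        self.end = nn.ModuleList([nn.Linear(2 * dim, dim) for _ in range(2)])

    def forward(self, x):
        skips = []
        for blk in self.start:    # 2 start blocks; outputs saved as skips
            x = blk(x)
            skips.append(x)
        for blk in self.middle:   # 4 middle blocks, plain sequential
            x = blk(x)
        for blk in self.end:      # 2 end blocks; concat with reversed skips
            x = blk(torch.cat([x, skips.pop()], dim=-1))
        return x
```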
### Changes from v1
| Aspect | mDiffAE v1 | mDiffAE v2 |
|---|---|---|
| Bottleneck dim | 64 (12x compression) | **96** (8x compression) |
| Decoder topology | 4 flat sequential blocks | **8 blocks (2+4+2 skip-concat)** |
| Token mask apply prob | 50% | **90%** |
| Token mask ratio | Fixed 75% | **Uniform(25%, 75%)** |
| PDG training regularizer | Token masking (50%) | **Token masking (90%) + path drop (10%)** |
| Latent noise prob | 10% | **50%** |
| Encoder input norm | RMSNorm (affine) | **Removed** |
| Decoder input norm | RMSNorm (affine) | **Removed** |
| Decoder latent norm | RMSNorm (affine) | **Removed** |
| Decoder output norm | RMSNorm (affine) | **Removed** |
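The path-drop regularizer added in v2 (10% drop probability in the table
above) can be sketched as randomly replacing a sample's encoder latents with
a null embedding, so the decoder occasionally trains unconditionally and
CFG-style PDG becomes possible at inference. The function name and
null-embedding mechanism are assumptions, not the mdiffae API:

```python
import torch

def drop_conditioning(latents, null_latent, drop_prob=0.1):
    """Sketch of path drop: with probability `drop_prob` per sample,
    replace that sample's latents [N, D] with a (learned) null embedding.
    latents: [B, N, D]; null_latent: [N, D]."""
    B = latents.shape[0]
    drop = torch.rand(B) < drop_prob          # per-sample drop mask
    out = latents.clone()
    out[drop] = null_latent.to(latents.dtype)  # broadcast null over dropped samples
    return out
```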
## Recommended Settings
| Mode | Steps | PDG | Strength |
|---|---|---|---|
| **Default** (best PSNR) | 2 | off | — |
| **Sharp** (perceptual) | 10 | on | 2.0 |
```python
from m_diffae_v2 import MDiffAEV2InferenceConfig

# Default — best PSNR, fast (2 steps, no PDG)
recon = model.decode(latents, height=H, width=W)

# Sharp mode — perceptual sharpening (10 steps + PDG)
cfg = MDiffAEV2InferenceConfig(num_steps=10, pdg=True, pdg_strength=2.0)
recon = model.decode(latents, height=H, width=W, inference_config=cfg)
```
## Citation
```bibtex
@misc{mdiffae_v2,
  title  = {mDiffAE v2: A Fast Masked Diffusion Autoencoder},
  author = {data-archetype},
  year   = {2026},
  month  = mar,
  url    = {https://huggingface.co/data-archetype/mdiffae-v2},
}
```
## Dependencies
- PyTorch >= 2.0
- safetensors (for loading weights)
## License
Apache 2.0