# mDiffAE: A Fast Masked Diffusion Autoencoder – Technical Report
**Version 1** (March 2026)
## 1. Introduction
mDiffAE (**M**asked **Diff**usion **A**uto**E**ncoder) builds on the [iRDiffAE](https://huggingface.co/data-archetype/irdiffae-v1/blob/main/technical_report.md) model family, which provides a fast, single-GPU-trainable diffusion autoencoder with good reconstruction quality, making it a good platform for experimenting with latent space regularization. See that report for full background on the shared components: VP diffusion, DiCo blocks, patchify encoder, AdaLN-Zero conditioning, and Path-Drop Guidance (PDG).
iRDiffAE v1 used REPA (aligning encoder features with a frozen DINOv2 teacher) to regularize the latent space. REPA produces well-structured latents but tends toward overly smooth representations. Here we replace it with **decoder token masking**.
### 1.1 Token Masking as Regularizer
With 50% probability per sample, the decoder only sees **25% of tokens** in the fused conditioning input. The spatial token grid is divided into non-overlapping 2×2 groups; within each group a single token is randomly kept and the other three are replaced with a learned mask feature. The high masking ratio (75%) forces each spatial token to carry enough information for reconstruction even when most neighbors are absent. Lower masking ratios let downstream models learn sharp details quickly but fail to induce spatial coherence: the task becomes too close to local inpainting. We tested lower ratios and confirmed this tradeoff (see also He et al., 2022).
The 50% application probability controls the tradeoff between reconstruction quality and latent regularity.
### 1.2 Latent Noise Regularization
10% of the time, random noise is added to the latent representation. Unlike iRDiffAE (and the DiTo paper), which synchronizes the latent noise level with the pixel-space diffusion timestep, here the noise level is sampled independently from a **Beta(2,2)** distribution with a **logSNR shift of +1.0**, biasing it toward low noise. This improves robustness to incomplete convergence of downstream models and encourages local smoothness of the latent space distribution.
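A minimal sketch of this regularizer is below. The report specifies only the Beta(2,2) distribution, the +1.0 logSNR shift, and the 10% application probability; the linear mapping from the Beta sample to the [logsnr_min, logsnr_max] range and the VP-style alpha/sigma parameterization are assumptions, and all names are illustrative:

```python
import torch

def maybe_noise_latents(z, p_apply=0.10, logsnr_shift=1.0,
                        logsnr_min=-10.0, logsnr_max=10.0):
    """Sketch of the latent noise regularizer. The exact t -> logSNR mapping
    is an assumption; only Beta(2,2), the +1.0 shift, and the 10% apply
    probability come from the report.

    z: [B, C, h, w] encoder latents.
    """
    B = z.shape[0]
    # Noise level sampled per latent from Beta(2,2), independent of the
    # pixel-space diffusion timestep (unlike iRDiffAE / DiTo)
    t = torch.distributions.Beta(2.0, 2.0).sample((B,)).to(z.device)
    # Assumed mapping: linear in logSNR between the schedule endpoints,
    # then shifted by +1.0 (higher logSNR = lower noise)
    logsnr = logsnr_max + t * (logsnr_min - logsnr_max) + logsnr_shift
    alpha = torch.sigmoid(logsnr).sqrt().view(B, 1, 1, 1)
    sigma = torch.sigmoid(-logsnr).sqrt().view(B, 1, 1, 1)
    noised = alpha * z + sigma * torch.randn_like(z)
    # Apply per sample with probability p_apply (10%)
    apply = (torch.rand(B, device=z.device) < p_apply).view(B, 1, 1, 1)
    return torch.where(apply, noised, z)
```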
### 1.3 Simplified Decoder
The decoder uses only **4 blocks** (down from 8 in iRDiffAE v1) in a flat sequential layout: no start/middle/end groups, no skip connections. This halves the decoder's parameter count and makes decoding roughly 2× faster.
### 1.4 Bottleneck
iRDiffAE v1 used 128 bottleneck channels, partly because REPA alignment occupies half the channels. Without REPA, 64 channels suffice and give better channel utilisation. This yields a 12× compression ratio at patch size 16 (vs 6× for iRDiffAE).
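The compression ratios follow directly from the patch geometry (pixel values per patch divided by latent channels per patch):

```python
# Compression ratio = pixel values per patch / latent channels per patch
patch_size, in_channels = 16, 3
pixels_per_patch = patch_size ** 2 * in_channels   # 768 values per 16x16 patch

ratio_mdiffae = pixels_per_patch / 64    # 64-channel bottleneck
ratio_irdiffae = pixels_per_patch / 128  # 128-channel bottleneck
print(ratio_mdiffae, ratio_irdiffae)     # 12.0 6.0
```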
### 1.5 Empirical Results
Compared to iRDiffAE v1, mDiffAE achieves slightly higher PSNR with less oversmoothed latent PCA. In downstream diffusion model training, mDiffAE's latent space does not show the steep initial loss descent of iRDiffAE, but catches up after 50kβ100k steps, producing more spatially coherent images with better high-frequency detail.
### 1.6 References
- He, K., Chen, X., Xie, S., Li, Y., Dollár, P., & Girshick, R. (2022). *Masked Autoencoders Are Scalable Vision Learners*. CVPR 2022.
- Li, T., Chang, H., Mishra, S.K., Zhang, H., Katabi, D., & Krishnan, D. (2023). *MAGE: MAsked Generative Encoder to Unify Representation Learning and Image Synthesis*. CVPR 2023.
## 2. Architecture Differences from iRDiffAE
| Aspect | iRDiffAE v1 (halfrepa c128) | mDiffAE v1 (masked c64) |
|--------|------------------------------|--------------------------|
| Bottleneck dim | 128 | **64** |
| Decoder depth | 8 (2 start + 4 middle + 2 end) | **4 (flat sequential)** |
| Decoder topology | START_MIDDLE_END_SKIP_CONCAT | **FLAT (no skip concat)** |
| Skip fusion | Yes (`fuse_skip` Conv1×1) | **No** |
| PDG mechanism | Drop middle blocks → mask_feature | **Token-level masking** (75% spatial tokens → mask_feature) |
| PDG sensitivity | Moderate (strength 1.5–3.0) | **Very sensitive** (strength 1.01–1.05) |
| Training regularizer | REPA (half-channel DINOv2 alignment) + covreg | **Decoder token masking** (75% ratio, 50% apply prob) |
| Latent noise reg | Same mechanism | **Independent Beta(2,2), logSNR shift +1.0, 10% prob** |
| Depthwise kernel | 7×7 | 7×7 (same) |
| Model dim | 896 | 896 (same) |
| Encoder depth | 4 | 4 (same) |
| Best decode | 1 step DDIM | 1 step DDIM (same) |
## 3. Training-Time Masking Details
### 3.1 Token Masking Procedure
During training, with 50% probability per sample:
1. The fused decoder input (patchified x_t + upsampled encoder latents) is divided into non-overlapping 2×2 spatial groups
2. Within each group, 3 of 4 tokens (75%) are selected for masking using random scoring
3. Masked tokens are replaced with a learned `mask_feature` parameter (same dimensionality as model_dim)
4. The decoder processes the partially-masked input normally through all blocks
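The four steps above can be sketched as follows, assuming the fused input is a [B, C, h, w] feature map with h and w divisible by 2 (function and variable names are illustrative):

```python
import torch

def groupwise_token_mask(fused, mask_feature, keep_per_group=1, group=2):
    """Sketch of 2x2 groupwise token masking.

    fused:        [B, C, h, w] fused decoder input (h, w divisible by `group`)
    mask_feature: [C] learned mask token, broadcast over masked positions
    Keeps `keep_per_group` randomly-scored token(s) per non-overlapping
    group x group window and replaces the rest with mask_feature
    (1 of 4 kept = 75% masking ratio).
    """
    B, C, h, w = fused.shape
    g = group
    # Random score per token, grouped into g*g windows: [B, h//g, w//g, g*g]
    scores = torch.rand(B, h // g, w // g, g * g, device=fused.device)
    # Rank the scores within each group; keep the top `keep_per_group`
    keep_rank = scores.argsort(dim=-1).argsort(dim=-1)
    keep = keep_rank < keep_per_group                   # [B, h//g, w//g, g*g]
    # Un-flatten groups back onto the [B, h, w] token grid
    keep = keep.view(B, h // g, w // g, g, g)
    keep = keep.permute(0, 1, 3, 2, 4).reshape(B, h, w)
    mask_tok = mask_feature.view(1, C, 1, 1)
    # Kept tokens pass through; masked tokens become the learned mask_feature
    return torch.where(keep.unsqueeze(1), fused, mask_tok)
```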
### 3.2 PDG at Inference
At inference, the trained mask_feature can be used for Path-Drop Guidance (PDG): the conditional pass uses the full input, the unconditional pass applies 2×2 groupwise masking at 75%, and the two are interpolated as usual. PDG can sharpen reconstructions but should be kept very low (strength 1.01–1.05); higher values cause artifacts.
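A minimal sketch of this two-pass guidance, using the standard guidance interpolation formula (the `decoder`/`mask_fn` signatures here are illustrative, not the actual API):

```python
import torch

def pdg_decode(decoder, fused, cond, mask_fn, strength=1.02):
    """Sketch of Path-Drop Guidance at inference.

    decoder:  callable mapping (fused_input, cond) -> x0 prediction
    fused:    full (unmasked) fused decoder input
    mask_fn:  applies 2x2 groupwise masking at 75% with the trained
              mask_feature, as during training
    strength: PDG strength; keep very low (~1.01-1.05), higher values
              cause artifacts
    """
    x_cond = decoder(fused, cond)             # conditional pass: full input
    x_uncond = decoder(mask_fn(fused), cond)  # unconditional pass: masked input
    # Standard guidance: extrapolate away from the masked (unconditional) path
    return x_uncond + strength * (x_cond - x_uncond)
```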
## 4. Flat Decoder Architecture
### 4.1 iRDiffAE v1 Decoder (for comparison)
```
Fused input → Start blocks (2) → [save for skip] →
Middle blocks (4) → [cat with saved skip] → FuseSkip Conv1×1 →
End blocks (2) → Output head
```
8 blocks split into three groups with a skip connection. For PDG, the middle blocks are dropped and replaced with a learned mask feature.
### 4.2 mDiffAE v1 Decoder
```
Patchify(x_t) → RMSNorm → x_feat [B, 896, h, w]
LatentUp(z) → RMSNorm → z_up [B, 896, h, w]
FuseIn(cat(x_feat, z_up)) → fused [B, 896, h, w]
[Optional: token masking for PDG]
TimeEmbed(t) → cond [B, 896]
Block_0 → Block_1 → Block_2 → Block_3 → out [B, 896, h, w]
RMSNorm → Conv1x1 → PixelShuffle → x0_hat [B, 3, H, W]
```
4 flat sequential blocks, no skip connections. Roughly half the decoder parameters of iRDiffAE.
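The flat topology can be sketched as a minimal PyTorch module. This is a structural sketch only: the DiCo blocks, RMSNorm placement, and AdaLN-Zero time conditioning are simplified to plain depthwise-conv blocks, and all names are illustrative:

```python
import torch
import torch.nn as nn

class FlatDecoderSketch(nn.Module):
    """Structural sketch of the mDiffAE v1 decoder data flow.
    DiCo blocks and AdaLN-Zero conditioning are simplified to
    depthwise-conv placeholder blocks; only the topology follows
    the report."""

    def __init__(self, dim=896, patch=16, in_ch=3, z_ch=64, depth=4):
        super().__init__()
        self.patchify = nn.Conv2d(in_ch, dim, patch, stride=patch)
        self.latent_up = nn.Conv2d(z_ch, dim, 1)   # latents at token resolution
        self.fuse_in = nn.Conv2d(2 * dim, dim, 1)
        self.blocks = nn.ModuleList(
            nn.Sequential(                          # placeholder for a DiCo block
                nn.GroupNorm(1, dim),
                nn.Conv2d(dim, dim, 7, padding=3, groups=dim),  # depthwise 7x7
                nn.Conv2d(dim, dim, 1),
            ) for _ in range(depth))
        self.head = nn.Sequential(
            nn.GroupNorm(1, dim),
            nn.Conv2d(dim, in_ch * patch * patch, 1),
            nn.PixelShuffle(patch))                 # tokens back to pixels

    def forward(self, x_t, z):
        x_feat = self.patchify(x_t)                 # [B, dim, h, w]
        fused = self.fuse_in(torch.cat([x_feat, self.latent_up(z)], dim=1))
        for blk in self.blocks:                     # 4 flat blocks, no skips
            fused = fused + blk(fused)
        return self.head(fused)                     # x0_hat [B, 3, H, W]
```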
## 5. Model Configuration
| Parameter | Value |
|-----------|-------|
| `in_channels` | 3 |
| `patch_size` | 16 |
| `model_dim` | 896 |
| `encoder_depth` | 4 |
| `decoder_depth` | 4 |
| `bottleneck_dim` | 64 |
| `mlp_ratio` | 4.0 |
| `depthwise_kernel_size` | 7 |
| `adaln_low_rank_rank` | 128 |
| `logsnr_min` | -10.0 |
| `logsnr_max` | 10.0 |
| `pixel_noise_std` | 0.558 |
| `pdg_mask_ratio` | 0.75 |
Training checkpoint: step 708,000 (EMA weights).
## 6. Inference Recommendations
| Setting | Value | Notes |
|---------|-------|-------|
| Sampler | DDIM | Best for 1-step |
| Steps | 1 | PSNR-optimal |
| PDG | Disabled | Default |
| PDG strength | 1.01–1.05 | If enabled; can sharpen but artifacts above ~1.1 |
## 7. Results
Reconstruction quality evaluated on two image sets: a large benchmark (N=2000, 2/3 photographs + 1/3 book covers) for summary statistics, and a curated 39-image set for per-image comparisons. Flux.1 and Flux.2 VAEs are included as references. All models use 1-step DDIM, seed 42, no PDG, bfloat16.
### 7.1 Summary PSNR (N=2000 images)
| Model | Mean PSNR (dB) | Std (dB) | Median (dB) |
|-------|---------------|----------|-------------|
| mDiffAE v1 (1 step) | 34.15 | 5.14 | 33.82 |
| Flux.1 VAE | 34.62 | 4.31 | 35.17 |
| Flux.2 VAE | 36.30 | 4.58 | 36.14 |
**Percentile distribution:**
| Percentile | mDiffAE v1 | Flux.1 VAE | Flux.2 VAE |
|------------|-----------|------------|------------|
| p5 | 26.22 | 27.06 | 28.99 |
| p10 | 27.54 | 28.45 | 30.38 |
| p25 | 30.22 | 31.58 | 32.87 |
| p50 | 33.82 | 35.17 | 36.14 |
| p75 | 38.20 | 37.99 | 39.85 |
| p90 | 41.21 | 39.75 | 42.51 |
| p95 | 42.49 | 40.57 | 43.64 |
> Timings on the 39-image set (batch 8, bf16, NVIDIA RTX Pro 6000 Blackwell): mDiffAE encode 2.4 ms + decode 3.0 ms = **5.4 ms/image** total, vs Flux.1 at 202 ms and Flux.2 at 138 ms: roughly **37×** and **26×** faster end-to-end.
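The speedup factors follow from the per-image encode/decode times in Section 7.3:

```python
# End-to-end latency per image (encode + decode, ms), 39-image set
mdiffae = 2.4 + 3.0     # 5.4 ms
flux1 = 63.9 + 138.2    # 202.1 ms
flux2 = 45.7 + 92.8     # 138.5 ms
print(round(flux1 / mdiffae), round(flux2 / mdiffae))  # 37 26
```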
### 7.2 Interactive Viewer
**[Open full-resolution comparison viewer](https://huggingface.co/spaces/data-archetype/mdiffae-results)** β side-by-side reconstructions, RGB deltas, and latent PCA with adjustable image size.
### 7.3 Per-Image Results (39-image curated set)
Inference settings: 1-step DDIM, seed 42, no PDG, batch size 8.
| Metric | mdiffae_v1 (1 step) | Flux.1 VAE | Flux.2 VAE |
|--------|--------|--------|--------|
| Avg PSNR (dB) | 31.89 | 32.76 | 34.16 |
| Avg encode (ms/image) | 2.4 | 63.9 | 45.7 |
| Avg decode (ms/image) | 3.0 | 138.2 | 92.8 |
### 7.4 Per-Image PSNR (dB)
| Image | mdiffae_v1 (1 step) | Flux.1 VAE | Flux.2 VAE |
|-------|--------|--------|--------|
| p640x1536:94623 | 31.20 | 31.28 | 33.50 |
| p640x1536:94624 | 27.32 | 27.62 | 30.03 |
| p640x1536:94625 | 30.68 | 31.65 | 33.98 |
| p640x1536:94626 | 29.14 | 29.44 | 31.53 |
| p640x1536:94627 | 29.63 | 28.70 | 30.53 |
| p640x1536:94628 | 25.60 | 26.38 | 28.88 |
| p960x1024:216264 | 44.50 | 40.87 | 45.39 |
| p960x1024:216265 | 26.42 | 25.82 | 27.80 |
| p960x1024:216266 | 44.90 | 47.77 | 46.20 |
| p960x1024:216267 | 37.78 | 37.65 | 39.23 |
| p960x1024:216268 | 36.15 | 35.27 | 36.13 |
| p960x1024:216269 | 29.37 | 28.45 | 30.24 |
| p960x1024:216270 | 32.43 | 31.92 | 34.18 |
| p960x1024:216271 | 41.23 | 38.92 | 42.18 |
| p704x1472:94699 | 41.88 | 40.43 | 41.79 |
| p704x1472:94700 | 29.66 | 29.52 | 32.08 |
| p704x1472:94701 | 35.14 | 35.43 | 37.90 |
| p704x1472:94702 | 30.90 | 30.73 | 32.50 |
| p704x1472:94703 | 28.65 | 29.08 | 31.35 |
| p704x1472:94704 | 28.98 | 29.22 | 31.84 |
| p704x1472:94705 | 36.09 | 36.38 | 37.44 |
| p704x1472:94706 | 31.53 | 31.50 | 33.66 |
| r256_p1344x704:15577 | 27.89 | 28.32 | 29.98 |
| r256_p1344x704:15578 | 28.07 | 29.35 | 30.79 |
| r256_p1344x704:15579 | 29.56 | 30.44 | 31.83 |
| r256_p1344x704:15580 | 32.89 | 36.12 | 36.03 |
| r256_p1344x704:15581 | 32.26 | 37.42 | 36.94 |
| r256_p1344x704:15582 | 28.74 | 30.64 | 32.10 |
| r256_p1344x704:15583 | 31.99 | 34.67 | 34.54 |
| r256_p1344x704:15584 | 28.42 | 30.34 | 31.76 |
| r256_p896x1152:144131 | 30.02 | 33.10 | 33.60 |
| r256_p896x1152:144132 | 33.19 | 34.23 | 35.32 |
| r256_p896x1152:144133 | 35.42 | 37.85 | 37.33 |
| r256_p896x1152:144134 | 31.41 | 34.25 | 34.47 |
| r256_p896x1152:144135 | 27.13 | 28.17 | 29.87 |
| r256_p896x1152:144136 | 32.75 | 35.24 | 35.68 |
| r256_p896x1152:144137 | 28.60 | 32.70 | 32.86 |
| r256_p896x1152:144138 | 24.76 | 24.15 | 25.63 |
| VAE_accuracy_test_image | 31.52 | 36.69 | 35.25 |