# mDiffAE: A Fast Masked Diffusion Autoencoder - Technical Report

**Version 1**, March 2026

## 1. Introduction

mDiffAE (**M**asked **Diff**usion **A**uto**E**ncoder) builds on the [iRDiffAE](https://huggingface.co/data-archetype/irdiffae-v1/blob/main/technical_report.md) model family, a fast, single-GPU-trainable diffusion autoencoder with good reconstruction quality, which makes it a good platform for experimenting with latent space regularization. See that report for full background on the shared components: VP diffusion, DiCo blocks, patchify encoder, AdaLN-Zero conditioning, and Path-Drop Guidance (PDG).

iRDiffAE v1 used REPA (aligning encoder features with a frozen DINOv2 teacher) to regularize the latent space. REPA produces well-structured latents but tends toward overly smooth representations. Here we replace it with **decoder token masking**.

### 1.1 Token Masking as Regularizer

With 50% probability per sample, the decoder sees only **25% of tokens** in the fused conditioning input. The spatial token grid is divided into non-overlapping 2×2 groups; within each group a single token is randomly kept and the other three are replaced with a learned mask feature.

The high masking ratio (75%) forces each spatial token to carry enough information for reconstruction even when most neighbors are absent. With lower masking ratios, downstream models learn sharp details quickly but fail to learn spatial coherence, because the task becomes too close to local inpainting. We tested lower ratios and confirmed this tradeoff (see also He et al., 2022).

The 50% application probability controls the tradeoff between reconstruction quality and latent regularity.

### 1.2 Latent Noise Regularization

10% of the time, random noise is added to the latent representation.
Unlike iRDiffAE (and the DiTo paper), which synchronize the latent noise level with the pixel-space diffusion timestep, here the noise level is sampled independently from a **Beta(2,2)** distribution with a **logSNR shift of +1.0**, biasing it toward low noise. This improves robustness to incomplete convergence of downstream models and encourages local smoothness of the latent space distribution.

### 1.3 Simplified Decoder

The decoder uses only **4 blocks** (down from 8 in iRDiffAE v1) in a flat sequential layout: no start/middle/end groups, no skip connections. This halves the decoder's parameter count and makes decoding roughly 2× faster.

### 1.4 Bottleneck

iRDiffAE v1 used 128 bottleneck channels, partly because REPA alignment occupies half the channels. Without REPA, 64 channels suffice and give better channel utilisation. At patch size 16, each 16×16×3 patch (768 values) maps to 64 latent channels, a 12× compression ratio (vs 6× for iRDiffAE).

### 1.5 Empirical Results

Compared to iRDiffAE v1, mDiffAE achieves slightly higher PSNR with less oversmoothed latent PCA. In downstream diffusion model training, mDiffAE's latent space does not show the steep initial loss descent of iRDiffAE, but catches up after 50k–100k steps, producing more spatially coherent images with better high-frequency detail.

### 1.6 References

- He, K., Chen, X., Xie, S., Li, Y., Dollár, P., & Girshick, R. (2022). *Masked Autoencoders Are Scalable Vision Learners*. CVPR 2022.
- Li, T., Chang, H., Mishra, S.K., Zhang, H., Katabi, D., & Krishnan, D. (2023). *MAGE: MAsked Generative Encoder to Unify Representation Learning and Image Synthesis*. CVPR 2023.

## 2. Architecture Differences from iRDiffAE

| Aspect | iRDiffAE v1 (halfrepa c128) | mDiffAE v1 (masked c64) |
|--------|------------------------------|--------------------------|
| Bottleneck dim | 128 | **64** |
| Decoder depth | 8 (2 start + 4 middle + 2 end) | **4 (flat sequential)** |
| Decoder topology | START_MIDDLE_END_SKIP_CONCAT | **FLAT (no skip concat)** |
| Skip fusion | Yes (`fuse_skip` Conv1×1) | **No** |
| PDG mechanism | Drop middle blocks → mask_feature | **Token-level masking** (75% spatial tokens → mask_feature) |
| PDG sensitivity | Moderate (strength 1.5–3.0) | **Very sensitive** (strength 1.01–1.05) |
| Training regularizer | REPA (half-channel DINOv2 alignment) + covreg | **Decoder token masking** (75% ratio, 50% apply prob) |
| Latent noise reg | Same mechanism | **Independent Beta(2,2), logSNR shift +1.0, 10% prob** |
| Depthwise kernel | 7×7 | 7×7 (same) |
| Model dim | 896 | 896 (same) |
| Encoder depth | 4 | 4 (same) |
| Best decode | 1 step DDIM | 1 step DDIM (same) |

## 3. Training-Time Masking Details

### 3.1 Token Masking Procedure

During training, with 50% probability per sample:

1. The fused decoder input (patchified x_t + upsampled encoder latents) is divided into non-overlapping 2×2 spatial groups.
2. Within each group, 3 of 4 tokens (75%) are selected for masking using random scoring.
3. Masked tokens are replaced with a learned `mask_feature` parameter (same dimensionality as `model_dim`).
4. The decoder processes the partially masked input normally through all blocks.

### 3.2 PDG at Inference

At inference, the trained `mask_feature` can be used for Path-Drop Guidance (PDG): the conditional pass uses the full input, the unconditional pass applies 2×2 groupwise masking at 75%, and the two are interpolated as usual. PDG can sharpen reconstructions but should be kept very low (strength 1.01–1.05); higher values cause artifacts.

## 4. Flat Decoder Architecture

### 4.1 iRDiffAE v1 Decoder (for comparison)

```
Fused input → Start blocks (2) → [save for skip]
  → Middle blocks (4)
  → [cat with saved skip] → FuseSkip Conv1×1
  → End blocks (2)
  → Output head
```

8 blocks split into three groups with a skip connection. For PDG, the middle blocks are dropped and replaced with a learned mask feature.

### 4.2 mDiffAE v1 Decoder

```
Patchify(x_t) → RMSNorm → x_feat [B, 896, h, w]
LatentUp(z) → RMSNorm → z_up [B, 896, h, w]
FuseIn(cat(x_feat, z_up)) → fused [B, 896, h, w]
[Optional: token masking for PDG]
TimeEmbed(t) → cond [B, 896]
Block_0 → Block_1 → Block_2 → Block_3 → out [B, 896, h, w]
RMSNorm → Conv1x1 → PixelShuffle → x0_hat [B, 3, H, W]
```

4 flat sequential blocks, no skip connections. Roughly half the decoder parameters of iRDiffAE.

## 5. Model Configuration

| Parameter | Value |
|-----------|-------|
| `in_channels` | 3 |
| `patch_size` | 16 |
| `model_dim` | 896 |
| `encoder_depth` | 4 |
| `decoder_depth` | 4 |
| `bottleneck_dim` | 64 |
| `mlp_ratio` | 4.0 |
| `depthwise_kernel_size` | 7 |
| `adaln_low_rank_rank` | 128 |
| `logsnr_min` | −10.0 |
| `logsnr_max` | 10.0 |
| `pixel_noise_std` | 0.558 |
| `pdg_mask_ratio` | 0.75 |

Training checkpoint: step 708,000 (EMA weights).

## 6. Inference Recommendations

| Setting | Value | Notes |
|---------|-------|-------|
| Sampler | DDIM | Best for 1-step |
| Steps | 1 | PSNR-optimal |
| PDG | Disabled | Default |
| PDG strength | 1.01–1.05 | If enabled; can sharpen but artifacts above ~1.1 |

## 7. Results

Reconstruction quality is evaluated on two image sets: a large benchmark (N=2000, 2/3 photographs + 1/3 book covers) for summary statistics, and a curated 39-image set for per-image comparisons. Flux.1 and Flux.2 VAEs are included as references. All models use 1-step DDIM, seed 42, no PDG, bfloat16.
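
All tables in this section report PSNR in dB. For readers who want to reproduce the metric, here is a minimal sketch of the standard definition. The function name `psnr_db`, the flat-list pixel layout, and the peak value of 1.0 are illustrative assumptions, because the report does not state the evaluation code or the data range it used.

```python
import math
from typing import Sequence


def psnr_db(reference: Sequence[float],
            reconstruction: Sequence[float],
            peak: float = 1.0) -> float:
    """Peak signal-to-noise ratio in dB between two equally sized pixel lists.

    Assumes pixel values lie in [0, peak]; this range is an assumption,
    not something the report specifies.
    """
    if len(reference) != len(reconstruction) or len(reference) == 0:
        raise ValueError("inputs must be non-empty and of equal length")
    # Mean squared error over all pixel values (all channels flattened).
    mse = sum((r - x) ** 2 for r, x in zip(reference, reconstruction)) / len(reference)
    if mse == 0.0:
        return math.inf  # identical images
    return 10.0 * math.log10(peak ** 2 / mse)


# Toy usage: every value is off by 0.1, so the MSE is about 0.01.
toy_reference = [0.0, 0.5, 1.0]
toy_reconstruction = [0.1, 0.6, 0.9]
print(f"{psnr_db(toy_reference, toy_reconstruction):.2f} dB")
```

With a uniform error of 0.1 on a unit range, the toy example evaluates to about 20 dB, which gives a sense of scale for the 25–45 dB values reported below.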
### 7.1 Summary PSNR (N=2000 images)

| Model | Mean PSNR (dB) | Std (dB) | Median (dB) |
|-------|---------------|----------|-------------|
| mDiffAE v1 (1 step) | 34.15 | 5.14 | 33.82 |
| Flux.1 VAE | 34.62 | 4.31 | 35.17 |
| Flux.2 VAE | 36.30 | 4.58 | 36.14 |

**Percentile distribution (PSNR, dB):**

| Percentile | mDiffAE v1 | Flux.1 VAE | Flux.2 VAE |
|------------|-----------|------------|------------|
| p5 | 26.22 | 27.06 | 28.99 |
| p10 | 27.54 | 28.45 | 30.38 |
| p25 | 30.22 | 31.58 | 32.87 |
| p50 | 33.82 | 35.17 | 36.14 |
| p75 | 38.20 | 37.99 | 39.85 |
| p90 | 41.21 | 39.75 | 42.51 |
| p95 | 42.49 | 40.57 | 43.64 |

> Timings on the 39-image set (batch 8, bf16, NVIDIA RTX Pro 6000 Blackwell): mDiffAE encode 2.4 ms + decode 3.0 ms = **5.4 ms/image** total, vs Flux.1 at 202 ms and Flux.2 at 138 ms, roughly **37×** and **26×** faster end-to-end.

### 7.2 Interactive Viewer

**[Open full-resolution comparison viewer](https://huggingface.co/spaces/data-archetype/mdiffae-results)**: side-by-side reconstructions, RGB deltas, and latent PCA with adjustable image size.

### 7.3 Per-Image Results (39-image curated set)

Inference settings: 1-step DDIM, seed 42, no PDG, batch size 8.


| Metric | mdiffae_v1 (1 step) | Flux.1 VAE | Flux.2 VAE |
|--------|--------|--------|--------|
| Avg PSNR (dB) | 31.89 | 32.76 | 34.16 |
| Avg encode (ms/image) | 2.4 | 63.9 | 45.7 |
| Avg decode (ms/image) | 3.0 | 138.2 | 92.8 |

### 7.4 Per-Image PSNR (dB)

| Image | mdiffae_v1 (1 step) | Flux.1 VAE | Flux.2 VAE |
|-------|--------|--------|--------|
| p640x1536:94623 | 31.20 | 31.28 | 33.50 |
| p640x1536:94624 | 27.32 | 27.62 | 30.03 |
| p640x1536:94625 | 30.68 | 31.65 | 33.98 |
| p640x1536:94626 | 29.14 | 29.44 | 31.53 |
| p640x1536:94627 | 29.63 | 28.70 | 30.53 |
| p640x1536:94628 | 25.60 | 26.38 | 28.88 |
| p960x1024:216264 | 44.50 | 40.87 | 45.39 |
| p960x1024:216265 | 26.42 | 25.82 | 27.80 |
| p960x1024:216266 | 44.90 | 47.77 | 46.20 |
| p960x1024:216267 | 37.78 | 37.65 | 39.23 |
| p960x1024:216268 | 36.15 | 35.27 | 36.13 |
| p960x1024:216269 | 29.37 | 28.45 | 30.24 |
| p960x1024:216270 | 32.43 | 31.92 | 34.18 |
| p960x1024:216271 | 41.23 | 38.92 | 42.18 |
| p704x1472:94699 | 41.88 | 40.43 | 41.79 |
| p704x1472:94700 | 29.66 | 29.52 | 32.08 |
| p704x1472:94701 | 35.14 | 35.43 | 37.90 |
| p704x1472:94702 | 30.90 | 30.73 | 32.50 |
| p704x1472:94703 | 28.65 | 29.08 | 31.35 |
| p704x1472:94704 | 28.98 | 29.22 | 31.84 |
| p704x1472:94705 | 36.09 | 36.38 | 37.44 |
| p704x1472:94706 | 31.53 | 31.50 | 33.66 |
| r256_p1344x704:15577 | 27.89 | 28.32 | 29.98 |
| r256_p1344x704:15578 | 28.07 | 29.35 | 30.79 |
| r256_p1344x704:15579 | 29.56 | 30.44 | 31.83 |
| r256_p1344x704:15580 | 32.89 | 36.12 | 36.03 |
| r256_p1344x704:15581 | 32.26 | 37.42 | 36.94 |
| r256_p1344x704:15582 | 28.74 | 30.64 | 32.10 |
| r256_p1344x704:15583 | 31.99 | 34.67 | 34.54 |
| r256_p1344x704:15584 | 28.42 | 30.34 | 31.76 |
| r256_p896x1152:144131 | 30.02 | 33.10 | 33.60 |
| r256_p896x1152:144132 | 33.19 | 34.23 | 35.32 |
| r256_p896x1152:144133 | 35.42 | 37.85 | 37.33 |
| r256_p896x1152:144134 | 31.41 | 34.25 | 34.47 |
| r256_p896x1152:144135 | 27.13 | 28.17 | 29.87 |
| r256_p896x1152:144136 | 32.75 | 35.24 | 35.68 |
| r256_p896x1152:144137 | 28.60 | 32.70 | 32.86 |
| r256_p896x1152:144138 | 24.76 | 24.15 | 25.63 |
| VAE_accuracy_test_image | 31.52 | 36.69 | 35.25 |
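
The mDiffAE `Avg PSNR` entry in Section 7.3 and the end-to-end speedups quoted in Section 7.1 can be recomputed from the numbers tabulated in this section. The sketch below does that as a consistency check. Variable names are illustrative, and it assumes plain arithmetic means and simple encode + decode totals, which the report does not spell out.

```python
from statistics import mean, median

# mDiffAE v1 column of the per-image PSNR table (dB), in table order.
MDIFFAE_PSNR_DB = [
    31.20, 27.32, 30.68, 29.14, 29.63, 25.60,
    44.50, 26.42, 44.90, 37.78, 36.15, 29.37, 32.43, 41.23,
    41.88, 29.66, 35.14, 30.90, 28.65, 28.98, 36.09, 31.53,
    27.89, 28.07, 29.56, 32.89, 32.26, 28.74, 31.99, 28.42,
    30.02, 33.19, 35.42, 31.41, 27.13, 32.75, 28.60, 24.76,
    31.52,
]

# Average per-image timings in ms from the Section 7.3 table: (encode, decode).
TIMINGS_MS = {
    "mDiffAE v1": (2.4, 3.0),
    "Flux.1 VAE": (63.9, 138.2),
    "Flux.2 VAE": (45.7, 92.8),
}

# Arithmetic mean over the 39 curated images (assumed aggregation).
avg_psnr = mean(MDIFFAE_PSNR_DB)

# End-to-end latency per image and latency ratio relative to mDiffAE.
total_ms = {name: enc + dec for name, (enc, dec) in TIMINGS_MS.items()}
speedup = {name: total / total_ms["mDiffAE v1"] for name, total in total_ms.items()}

print(f"images: {len(MDIFFAE_PSNR_DB)}")
print(f"mean PSNR: {avg_psnr:.2f} dB, median PSNR: {median(MDIFFAE_PSNR_DB):.2f} dB")
for name, total in total_ms.items():
    print(f"{name}: {total:.1f} ms/image, {speedup[name]:.1f}x mDiffAE latency")
```

The recomputed mean rounds to the 31.89 dB listed in Section 7.3, and the latency ratios round to the 37× and 26× figures quoted with the timings.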