
mDiffAE: A Fast Masked Diffusion Autoencoder – Technical Report

Version 1 – March 2026

1. Introduction

mDiffAE (Masked Diffusion AutoEncoder) builds on the iRDiffAE model family, which provides a fast, single-GPU-trainable diffusion autoencoder with strong reconstruction quality, making it a convenient platform for experimenting with latent-space regularization. See that report for full background on the shared components: VP diffusion, DiCo blocks, the patchify encoder, AdaLN-Zero conditioning, and Path-Drop Guidance (PDG).

iRDiffAE v1 used REPA (aligning encoder features with a frozen DINOv2 teacher) to regularize the latent space. REPA produces well-structured latents but tends toward overly smooth representations. Here we replace it with decoder token masking.

1.1 Token Masking as Regularizer

With 50% probability per sample, the decoder sees only 25% of the tokens in the fused conditioning input. The spatial token grid is divided into non-overlapping 2×2 groups; within each group a single token is randomly kept and the other three are replaced with a learned mask feature. The high masking ratio (75%) forces each spatial token to carry enough information for reconstruction even when most of its neighbors are absent. Lower masking ratios help downstream models learn sharp details quickly but fail to teach spatial coherence: the task becomes too close to local inpainting. We tested lower ratios and confirmed this tradeoff (see also He et al., 2022).

The 50% application probability controls the tradeoff between reconstruction quality and latent regularity.

1.2 Latent Noise Regularization

10% of the time, random noise is added to the latent representation. Unlike iRDiffAE (and the DiTo paper), which synchronizes the latent noise level with the pixel-space diffusion timestep, here the noise level is sampled independently from a Beta(2,2) distribution with a logSNR shift of +1.0, biasing it toward low noise. This improves robustness to incomplete convergence of downstream models and encourages local smoothness of the latent space distribution.
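The noise-level sampling can be sketched as follows. The mapping from the Beta(2,2) draw to a logSNR value assumes a linear logSNR schedule over the model's [−10, 10] range, which is an assumption on our part; the report does not spell out the exact mapping, only the distribution, the +1.0 shift, and the 10% application probability.

```python
import numpy as np

def sample_latent_noise_level(rng, logsnr_min=-10.0, logsnr_max=10.0, shift=1.0):
    """Sketch of the independent latent noise-level sampler (applied 10% of
    the time per sample during training)."""
    # Draw a normalized time from Beta(2,2): symmetric, mass near 0.5.
    t = rng.beta(2.0, 2.0)
    # Map to logSNR (assumption: linear logSNR schedule; t=0 -> cleanest).
    logsnr = logsnr_max + t * (logsnr_min - logsnr_max)
    # Shift logSNR by +1.0, biasing toward higher SNR, i.e. lower noise.
    logsnr = logsnr + shift
    # Convert to VP-diffusion alpha/sigma: alpha^2 = sigmoid(logSNR).
    alpha = np.sqrt(1.0 / (1.0 + np.exp(-logsnr)))
    sigma = np.sqrt(1.0 / (1.0 + np.exp(logsnr)))
    return alpha, sigma
```

The noisy latent would then be `alpha * z + sigma * eps` with `eps` standard Gaussian, as in the usual VP parameterization.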

1.3 Simplified Decoder

The decoder uses only 4 blocks (down from 8 in iRDiffAE v1) in a flat sequential layout: no start/middle/end groups, no skip connections. This halves the decoder's parameter count and makes decoding roughly 2× faster.

1.4 Bottleneck

iRDiffAE v1 used 128 bottleneck channels, partly because REPA alignment occupies half of them. Without REPA, 64 channels suffice and give better channel utilization. This yields a 12× compression ratio at patch size 16 (vs 6× for iRDiffAE).
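The compression ratios follow directly from the patch geometry: each patch-size-16 token summarizes a 16×16 RGB patch.

```python
# Pixel values covered by one token at patch size 16.
pixel_values_per_token = 16 * 16 * 3   # 768

print(pixel_values_per_token / 64)     # mDiffAE v1, 64 channels:  12x
print(pixel_values_per_token / 128)    # iRDiffAE v1, 128 channels: 6x
```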

1.5 Empirical Results

Compared to iRDiffAE v1, mDiffAE achieves slightly higher PSNR with less oversmoothed latent PCA. In downstream diffusion model training, mDiffAE's latent space does not show the steep initial loss descent of iRDiffAE, but catches up after 50k–100k steps, producing more spatially coherent images with better high-frequency detail.

1.6 References

  • He, K., Chen, X., Xie, S., Li, Y., Dollár, P., & Girshick, R. (2022). Masked Autoencoders Are Scalable Vision Learners. CVPR 2022.
  • Li, T., Chang, H., Mishra, S.K., Zhang, H., Katabi, D., & Krishnan, D. (2023). MAGE: MAsked Generative Encoder to Unify Representation Learning and Image Synthesis. CVPR 2023.

2. Architecture Differences from iRDiffAE

| Aspect | iRDiffAE v1 (halfrepa c128) | mDiffAE v1 (masked c64) |
| --- | --- | --- |
| Bottleneck dim | 128 | 64 |
| Decoder depth | 8 (2 start + 4 middle + 2 end) | 4 (flat sequential) |
| Decoder topology | START_MIDDLE_END_SKIP_CONCAT | FLAT (no skip concat) |
| Skip fusion | Yes (fuse_skip Conv1×1) | No |
| PDG mechanism | Drop middle blocks → mask_feature | Token-level masking (75% spatial tokens → mask_feature) |
| PDG sensitivity | Moderate (strength 1.5–3.0) | Very sensitive (strength 1.01–1.05) |
| Training regularizer | REPA (half-channel DINOv2 alignment) + covreg | Decoder token masking (75% ratio, 50% apply prob) |
| Latent noise reg | Synced with pixel diffusion timestep | Independent Beta(2,2), logSNR shift +1.0, 10% prob |
| Depthwise kernel | 7×7 | 7×7 (same) |
| Model dim | 896 | 896 (same) |
| Encoder depth | 4 | 4 (same) |
| Best decode | 1-step DDIM | 1-step DDIM (same) |

3. Training-Time Masking Details

3.1 Token Masking Procedure

During training, with 50% probability per sample:

  1. The fused decoder input (patchified x_t + upsampled encoder latents) is divided into non-overlapping 2×2 spatial groups
  2. Within each group, 3 of 4 tokens (75%) are selected for masking using random scoring
  3. Masked tokens are replaced with a learned mask_feature parameter (same dimensionality as model_dim)
  4. The decoder processes the partially-masked input normally through all blocks
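The steps above can be sketched as follows. The function name and the random-scoring details are illustrative (the report says only that the kept token is chosen randomly per group); the released code may differ.

```python
import torch

def mask_tokens_2x2(fused, mask_feature, generator=None):
    """Groupwise token masking: within each non-overlapping 2x2 spatial
    group, keep one random token and replace the other three with the
    learned mask_feature (75% masking overall).

    fused:        [B, C, H, W] fused decoder input (H, W divisible by 2)
    mask_feature: [C] learned mask embedding
    """
    B, C, H, W = fused.shape
    # One random score per spatial token, laid out by 2x2 group.
    scores = torch.rand(B, H // 2, 2, W // 2, 2,
                        generator=generator, device=fused.device)
    scores = scores.permute(0, 1, 3, 2, 4).reshape(B, H // 2, W // 2, 4)
    # The highest-scoring token in each group is kept.
    keep_idx = scores.argmax(dim=-1)                    # [B, H/2, W/2]
    keep = torch.zeros_like(scores, dtype=torch.bool)
    keep.scatter_(-1, keep_idx.unsqueeze(-1), True)     # one True per group
    # Un-fold the group dimension back to the full token grid.
    keep = keep.reshape(B, H // 2, W // 2, 2, 2).permute(0, 1, 3, 2, 4)
    keep = keep.reshape(B, 1, H, W)
    mask_feat = mask_feature.view(1, C, 1, 1).to(fused.dtype)
    return torch.where(keep, fused, mask_feat)
```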

3.2 PDG at Inference

At inference, the trained mask_feature can be used for Path-Drop Guidance (PDG): the conditional pass uses the full input, the unconditional pass applies 2×2 groupwise masking at 75%, and the two are combined with the usual guidance extrapolation. PDG can sharpen reconstructions but should be kept very low (strength 1.01–1.05); higher values cause artifacts.
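The combination step is the standard guidance formula; a minimal sketch (the function name is illustrative), where `cond_out` is the decoder output on the full input and `uncond_out` the output on the 75%-masked input:

```python
def pdg_combine(cond_out, uncond_out, strength=1.03):
    """Guidance extrapolation: strength=1.0 returns cond_out unchanged;
    the report recommends 1.01-1.05 and warns against values above ~1.1."""
    return uncond_out + strength * (cond_out - uncond_out)
```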

4. Flat Decoder Architecture

4.1 iRDiffAE v1 Decoder (for comparison)

Fused input → Start blocks (2) → [save for skip] →
  Middle blocks (4) → [cat with saved skip] → FuseSkip Conv1×1 →
  End blocks (2) → Output head

8 blocks split into three groups with a skip connection. For PDG, the middle blocks are dropped and replaced with a learned mask feature.

4.2 mDiffAE v1 Decoder

Patchify(x_t) → RMSNorm → x_feat [B, 896, h, w]
LatentUp(z) → RMSNorm → z_up [B, 896, h, w]
FuseIn(cat(x_feat, z_up)) → fused [B, 896, h, w]
[Optional: token masking for PDG]
TimeEmbed(t) → cond [B, 896]
Block_0 → Block_1 → Block_2 → Block_3 → out [B, 896, h, w]
RMSNorm → Conv1x1 → PixelShuffle → x0_hat [B, 3, H, W]

4 flat sequential blocks, no skip connections. Roughly half the decoder parameters of iRDiffAE.
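The flat layout above can be sketched in PyTorch as follows. This is illustrative, not the released code: the block here is a simple depthwise-7×7 + pointwise stand-in for a DiCo block, and the time conditioning (AdaLN-Zero) and RMSNorm layers are omitted for brevity.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BlockSketch(nn.Module):
    """Stand-in for a DiCo block: depthwise 7x7 conv + pointwise conv with a
    residual connection. The real block also takes time conditioning."""
    def __init__(self, dim):
        super().__init__()
        self.dw = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)
        self.pw = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        return x + self.pw(F.gelu(self.dw(x)))

class FlatDecoderSketch(nn.Module):
    """Flat sequential decoder: `depth` identical blocks in a row, no
    start/middle/end grouping and no skip connections."""
    def __init__(self, dim=896, depth=4, out_ch=3, patch=16):
        super().__init__()
        self.blocks = nn.Sequential(*[BlockSketch(dim) for _ in range(depth)])
        self.head = nn.Conv2d(dim, out_ch * patch * patch, 1)
        self.unpatch = nn.PixelShuffle(patch)  # token grid -> pixel grid

    def forward(self, fused):                  # fused: [B, dim, h, w]
        x = self.blocks(fused)
        return self.unpatch(self.head(x))      # x0_hat: [B, out_ch, H, W]
```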

5. Model Configuration

| Parameter | Value |
| --- | --- |
| in_channels | 3 |
| patch_size | 16 |
| model_dim | 896 |
| encoder_depth | 4 |
| decoder_depth | 4 |
| bottleneck_dim | 64 |
| mlp_ratio | 4.0 |
| depthwise_kernel_size | 7 |
| adaln_low_rank_rank | 128 |
| logsnr_min | -10.0 |
| logsnr_max | 10.0 |
| pixel_noise_std | 0.558 |
| pdg_mask_ratio | 0.75 |

Training checkpoint: step 708,000 (EMA weights).

6. Inference Recommendations

| Setting | Value | Notes |
| --- | --- | --- |
| Sampler | DDIM | Best for 1-step |
| Steps | 1 | PSNR-optimal |
| PDG | Disabled | Default |
| PDG strength | 1.01–1.05 | If enabled; can sharpen but artifacts above ~1.1 |

7. Results

Reconstruction quality evaluated on two image sets: a large benchmark (N=2000, 2/3 photographs + 1/3 book covers) for summary statistics, and a curated 39-image set for per-image comparisons. Flux.1 and Flux.2 VAEs are included as references. All models use 1-step DDIM, seed 42, no PDG, bfloat16.

7.1 Summary PSNR (N=2000 images)

| Model | Mean PSNR (dB) | Std (dB) | Median (dB) |
| --- | --- | --- | --- |
| mDiffAE v1 (1 step) | 34.15 | 5.14 | 33.82 |
| Flux.1 VAE | 34.62 | 4.31 | 35.17 |
| Flux.2 VAE | 36.30 | 4.58 | 36.14 |

Percentile distribution:

| Percentile | mDiffAE v1 | Flux.1 VAE | Flux.2 VAE |
| --- | --- | --- | --- |
| p5 | 26.22 | 27.06 | 28.99 |
| p10 | 27.54 | 28.45 | 30.38 |
| p25 | 30.22 | 31.58 | 32.87 |
| p50 | 33.82 | 35.17 | 36.14 |
| p75 | 38.20 | 37.99 | 39.85 |
| p90 | 41.21 | 39.75 | 42.51 |
| p95 | 42.49 | 40.57 | 43.64 |

Timings on the 39-image set (batch 8, bf16, NVIDIA RTX Pro 6000 Blackwell): mDiffAE encode 2.4 ms + decode 3.0 ms = 5.4 ms/image total, vs Flux.1 at 202 ms and Flux.2 at 138 ms, roughly 37× and 26× faster end-to-end.

7.2 Interactive Viewer

Open the full-resolution comparison viewer for side-by-side reconstructions, RGB deltas, and latent PCA with adjustable image size.

7.3 Per-Image Results (39-image curated set)

Inference settings: 1-step DDIM, seed 42, no PDG, batch size 8.

| Metric | mdiffae_v1 (1 step) | Flux.1 VAE | Flux.2 VAE |
| --- | --- | --- | --- |
| Avg PSNR (dB) | 31.89 | 32.76 | 34.16 |
| Avg encode (ms/image) | 2.4 | 63.9 | 45.7 |
| Avg decode (ms/image) | 3.0 | 138.2 | 92.8 |

7.4 Per-Image PSNR (dB)

| Image | mdiffae_v1 (1 step) | Flux.1 VAE | Flux.2 VAE |
| --- | --- | --- | --- |
| p640x1536:94623 | 31.20 | 31.28 | 33.50 |
| p640x1536:94624 | 27.32 | 27.62 | 30.03 |
| p640x1536:94625 | 30.68 | 31.65 | 33.98 |
| p640x1536:94626 | 29.14 | 29.44 | 31.53 |
| p640x1536:94627 | 29.63 | 28.70 | 30.53 |
| p640x1536:94628 | 25.60 | 26.38 | 28.88 |
| p960x1024:216264 | 44.50 | 40.87 | 45.39 |
| p960x1024:216265 | 26.42 | 25.82 | 27.80 |
| p960x1024:216266 | 44.90 | 47.77 | 46.20 |
| p960x1024:216267 | 37.78 | 37.65 | 39.23 |
| p960x1024:216268 | 36.15 | 35.27 | 36.13 |
| p960x1024:216269 | 29.37 | 28.45 | 30.24 |
| p960x1024:216270 | 32.43 | 31.92 | 34.18 |
| p960x1024:216271 | 41.23 | 38.92 | 42.18 |
| p704x1472:94699 | 41.88 | 40.43 | 41.79 |
| p704x1472:94700 | 29.66 | 29.52 | 32.08 |
| p704x1472:94701 | 35.14 | 35.43 | 37.90 |
| p704x1472:94702 | 30.90 | 30.73 | 32.50 |
| p704x1472:94703 | 28.65 | 29.08 | 31.35 |
| p704x1472:94704 | 28.98 | 29.22 | 31.84 |
| p704x1472:94705 | 36.09 | 36.38 | 37.44 |
| p704x1472:94706 | 31.53 | 31.50 | 33.66 |
| r256_p1344x704:15577 | 27.89 | 28.32 | 29.98 |
| r256_p1344x704:15578 | 28.07 | 29.35 | 30.79 |
| r256_p1344x704:15579 | 29.56 | 30.44 | 31.83 |
| r256_p1344x704:15580 | 32.89 | 36.12 | 36.03 |
| r256_p1344x704:15581 | 32.26 | 37.42 | 36.94 |
| r256_p1344x704:15582 | 28.74 | 30.64 | 32.10 |
| r256_p1344x704:15583 | 31.99 | 34.67 | 34.54 |
| r256_p1344x704:15584 | 28.42 | 30.34 | 31.76 |
| r256_p896x1152:144131 | 30.02 | 33.10 | 33.60 |
| r256_p896x1152:144132 | 33.19 | 34.23 | 35.32 |
| r256_p896x1152:144133 | 35.42 | 37.85 | 37.33 |
| r256_p896x1152:144134 | 31.41 | 34.25 | 34.47 |
| r256_p896x1152:144135 | 27.13 | 28.17 | 29.87 |
| r256_p896x1152:144136 | 32.75 | 35.24 | 35.68 |
| r256_p896x1152:144137 | 28.60 | 32.70 | 32.86 |
| r256_p896x1152:144138 | 24.76 | 24.15 | 25.63 |
| VAE_accuracy_test_image | 31.52 | 36.69 | 35.25 |