# mDiffAE: A Fast Masked Diffusion Autoencoder β€” Technical Report

**Version 1** β€” March 2026

## 1. Introduction

mDiffAE (**M**asked **Diff**usion **A**uto**E**ncoder) builds on the [iRDiffAE](https://huggingface.co/data-archetype/irdiffae-v1/blob/main/technical_report.md) model family, which provides a fast, single-GPU-trainable diffusion autoencoder with good reconstruction quality, making it a good platform for experimenting with latent space regularization. See that report for full background on the shared components: VP diffusion, DiCo blocks, patchify encoder, AdaLN-Zero conditioning, and Path-Drop Guidance (PDG).

iRDiffAE v1 used REPA (aligning encoder features with a frozen DINOv2 teacher) to regularize the latent space. REPA produces well-structured latents but tends toward overly smooth representations. Here we replace it with **decoder token masking**.

### 1.1 Token Masking as Regularizer

With 50% probability per sample, the decoder sees only **25% of the tokens** in the fused conditioning input. The spatial token grid is divided into non-overlapping 2Γ—2 groups; within each group one token is randomly kept and the other three are replaced with a learned mask feature. The high masking ratio (75%) forces each spatial token to carry enough information for reconstruction even when most of its neighbors are absent. Lower masking ratios yield latents that let downstream models learn sharp details quickly but fail to develop spatial coherence: the reconstruction task becomes too close to local inpainting. We tested lower ratios and confirmed this tradeoff (see also He et al., 2022).

The 50% application probability controls the tradeoff between reconstruction quality and latent regularity.

### 1.2 Latent Noise Regularization

With 10% probability per sample, random noise is added to the latent representation. Unlike iRDiffAE (and the DiTo paper), which synchronize the latent noise level with the pixel-space diffusion timestep, here the noise level is sampled independently from a **Beta(2,2)** distribution with a **logSNR shift of +1.0**, biasing it toward low noise. This improves robustness to incomplete convergence of downstream models and encourages local smoothness of the latent distribution.
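The report does not spell out how the Beta sample maps to a noise level. One plausible sketch, assuming the sample is mapped linearly onto the `[logsnr_min, logsnr_max]` range from Section 5, shifted by +1.0, and converted to a noise std via the VP relation sigma^2 = sigmoid(-logSNR) (the function name and mapping are illustrative assumptions, not the training code):

```python
import math
import random

def sample_latent_noise_level(logsnr_min=-10.0, logsnr_max=10.0, shift=1.0):
    """Hypothetical sketch: draw u ~ Beta(2,2), map linearly to logSNR,
    apply the +1.0 shift (higher logSNR = less noise), then convert to a
    noise std via the VP relation sigma^2 = sigmoid(-logSNR)."""
    u = random.betavariate(2.0, 2.0)                   # in (0, 1), peaked at 0.5
    logsnr = logsnr_min + u * (logsnr_max - logsnr_min)
    logsnr += shift                                    # bias toward low noise
    sigma = math.sqrt(1.0 / (1.0 + math.exp(logsnr)))  # sigmoid(-logsnr) ** 0.5
    return sigma

# applied with 10% probability per sample:
apply_noise = random.random() < 0.10
```

With a positive shift the same Beta draw lands at a higher logSNR, so the resulting noise std is strictly smaller, which matches the stated low-noise bias.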

### 1.3 Simplified Decoder

The decoder uses only **4 blocks** (down from 8 in iRDiffAE v1) in a flat sequential layout: no start/middle/end groups, no skip connections. This halves the decoder's parameter count and makes decoding roughly 2Γ— faster.

### 1.4 Bottleneck

iRDiffAE v1 used 128 bottleneck channels, partly because REPA alignment occupied half of them. Without REPA, 64 channels suffice and give better channel utilization. This yields a 12Γ— compression ratio at patch size 16 (vs 6Γ— for iRDiffAE).
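The compression ratio follows directly from the patch geometry (pure arithmetic on the values stated in Section 5, no further assumptions):

```python
# Each 16x16 patch of a 3-channel image carries 16 * 16 * 3 = 768 scalar values.
# The bottleneck represents that patch with `bottleneck_dim` channels.
patch_size, in_channels = 16, 3
pixels_per_patch = patch_size * patch_size * in_channels  # 768

def compression(bottleneck_dim):
    return pixels_per_patch / bottleneck_dim

print(compression(64))   # mDiffAE v1  -> 12.0
print(compression(128))  # iRDiffAE v1 -> 6.0
```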

### 1.5 Empirical Results

Compared to iRDiffAE v1, mDiffAE achieves slightly higher PSNR with less oversmoothed latent PCA. In downstream diffusion model training, mDiffAE's latent space does not show the steep initial loss descent of iRDiffAE, but catches up after 50k–100k steps, producing more spatially coherent images with better high-frequency detail.

### 1.6 References

- He, K., Chen, X., Xie, S., Li, Y., DollΓ‘r, P., & Girshick, R. (2022). *Masked Autoencoders Are Scalable Vision Learners*. CVPR 2022.
- Li, T., Chang, H., Mishra, S.K., Zhang, H., Katabi, D., & Krishnan, D. (2023). *MAGE: MAsked Generative Encoder to Unify Representation Learning and Image Synthesis*. CVPR 2023.

## 2. Architecture Differences from iRDiffAE

| Aspect | iRDiffAE v1 (halfrepa c128) | mDiffAE v1 (masked c64) |
|--------|------------------------------|--------------------------|
| Bottleneck dim | 128 | **64** |
| Decoder depth | 8 (2 start + 4 middle + 2 end) | **4 (flat sequential)** |
| Decoder topology | START_MIDDLE_END_SKIP_CONCAT | **FLAT (no skip concat)** |
| Skip fusion | Yes (`fuse_skip` Conv1Γ—1) | **No** |
| PDG mechanism | Drop middle blocks β†’ mask_feature | **Token-level masking** (75% spatial tokens β†’ mask_feature) |
| PDG sensitivity | Moderate (strength 1.5–3.0) | **Very sensitive** (strength 1.01–1.05) |
| Training regularizer | REPA (half-channel DINOv2 alignment) + covreg | **Decoder token masking** (75% ratio, 50% apply prob) |
| Latent noise reg | Same mechanism | **Independent Beta(2,2), logSNR shift +1.0, 10% prob** |
| Depthwise kernel | 7Γ—7 | 7Γ—7 (same) |
| Model dim | 896 | 896 (same) |
| Encoder depth | 4 | 4 (same) |
| Best decode | 1 step DDIM | 1 step DDIM (same) |

## 3. Training-Time Masking Details

### 3.1 Token Masking Procedure

During training, with 50% probability per sample:
1. The fused decoder input (patchified x_t + upsampled encoder latents) is divided into non-overlapping 2Γ—2 spatial groups
2. Within each group, 3 of 4 tokens (75%) are selected for masking using random scoring
3. Masked tokens are replaced with a learned `mask_feature` parameter (same dimensionality as model_dim)
4. The decoder processes the partially-masked input normally through all blocks
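The steps above can be sketched in a few lines. A minimal NumPy sketch; the function name, the `[B, C, h, w]` tensor layout, and the RNG handling are illustrative assumptions, not the actual training code:

```python
import numpy as np

def mask_tokens_2x2(fused, mask_feature, rng):
    """Groupwise 2x2 token masking: keep 1 random token per 2x2 group,
    replace the other 3 with the learned mask_feature vector.
    fused: [B, C, h, w] with h, w even; mask_feature: [C]."""
    B, C, h, w = fused.shape
    # Random score per token; the argmax within each 2x2 group is kept.
    scores = rng.random((B, 1, h, w))
    groups = scores.reshape(B, 1, h // 2, 2, w // 2, 2)
    groups = groups.transpose(0, 1, 2, 4, 3, 5)           # [B,1,h/2,w/2,2,2]
    flat = groups.reshape(B, 1, h // 2, w // 2, 4)
    keep = flat.argmax(-1, keepdims=True)                 # index of the keeper
    keep_mask = np.zeros_like(flat, dtype=bool)
    np.put_along_axis(keep_mask, keep, True, axis=-1)
    keep_mask = (keep_mask.reshape(B, 1, h // 2, w // 2, 2, 2)
                 .transpose(0, 1, 2, 4, 3, 5)
                 .reshape(B, 1, h, w))
    # Broadcast the learned mask feature over the masked positions.
    return np.where(keep_mask, fused, mask_feature[None, :, None, None])
```

On a 4Γ—4 token grid this masks exactly 12 of 16 positions, one keeper per 2Γ—2 group, matching the 75% ratio.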

### 3.2 PDG at Inference

At inference, the trained mask_feature can be used for Path-Drop Guidance (PDG): the conditional pass uses the full input, the unconditional pass applies 2Γ—2 groupwise masking at 75%, and the two are interpolated as usual. PDG can sharpen reconstructions but should be kept very low (strength 1.01–1.05); higher values cause artifacts.
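The report does not write out the interpolation; assuming the standard guidance form (extrapolating from the unconditional toward the conditional prediction), it would look like this sketch, with illustrative names:

```python
import numpy as np

def pdg_decode(cond_out, uncond_out, strength=1.03):
    """Path-Drop Guidance combination (assumed standard guidance form):
    extrapolate from the masked (unconditional) prediction toward the
    full-input (conditional) one. strength=1.0 recovers the plain
    conditional output; the report recommends 1.01-1.05."""
    return uncond_out + strength * (cond_out - uncond_out)
```

Because the recommended strengths barely exceed 1.0, the output stays very close to the conditional prediction; the guidance term only gently amplifies what the full input adds over the masked pass.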

## 4. Flat Decoder Architecture

### 4.1 iRDiffAE v1 Decoder (for comparison)

```
Fused input β†’ Start blocks (2) β†’ [save for skip] β†’
  Middle blocks (4) β†’ [cat with saved skip] β†’ FuseSkip Conv1Γ—1 β†’
  End blocks (2) β†’ Output head
```

8 blocks split into three groups with a skip connection. For PDG, the middle blocks are dropped and replaced with a learned mask feature.

### 4.2 mDiffAE v1 Decoder

```
Patchify(x_t) β†’ RMSNorm β†’ x_feat [B, 896, h, w]
LatentUp(z) β†’ RMSNorm β†’ z_up [B, 896, h, w]
FuseIn(cat(x_feat, z_up)) β†’ fused [B, 896, h, w]
[Optional: token masking for PDG]
TimeEmbed(t) β†’ cond [B, 896]
Block_0 β†’ Block_1 β†’ Block_2 β†’ Block_3 β†’ out [B, 896, h, w]
RMSNorm β†’ Conv1x1 β†’ PixelShuffle β†’ x0_hat [B, 3, H, W]
```

4 flat sequential blocks, no skip connections. Roughly half the decoder parameters of iRDiffAE.

## 5. Model Configuration

| Parameter | Value |
|-----------|-------|
| `in_channels` | 3 |
| `patch_size` | 16 |
| `model_dim` | 896 |
| `encoder_depth` | 4 |
| `decoder_depth` | 4 |
| `bottleneck_dim` | 64 |
| `mlp_ratio` | 4.0 |
| `depthwise_kernel_size` | 7 |
| `adaln_low_rank_rank` | 128 |
| `logsnr_min` | βˆ’10.0 |
| `logsnr_max` | 10.0 |
| `pixel_noise_std` | 0.558 |
| `pdg_mask_ratio` | 0.75 |

Training checkpoint: step 708,000 (EMA weights).

## 6. Inference Recommendations

| Setting | Value | Notes |
|---------|-------|-------|
| Sampler | DDIM | Best for 1-step |
| Steps | 1 | PSNR-optimal |
| PDG | Disabled | Default |
| PDG strength | 1.01–1.05 | If enabled; can sharpen but artifacts above ~1.1 |

## 7. Results

Reconstruction quality evaluated on two image sets: a large benchmark (N=2000, 2/3 photographs + 1/3 book covers) for summary statistics, and a curated 39-image set for per-image comparisons. Flux.1 and Flux.2 VAEs are included as references. All models use 1-step DDIM, seed 42, no PDG, bfloat16.

### 7.1 Summary PSNR (N=2000 images)

| Model | Mean PSNR (dB) | Std (dB) | Median (dB) |
|-------|---------------|----------|-------------|
| mDiffAE v1 (1 step) | 34.15 | 5.14 | 33.82 |
| Flux.1 VAE | 34.62 | 4.31 | 35.17 |
| Flux.2 VAE | 36.30 | 4.58 | 36.14 |

**Percentile distribution:**

| Percentile | mDiffAE v1 | Flux.1 VAE | Flux.2 VAE |
|------------|-----------|------------|------------|
| p5 | 26.22 | 27.06 | 28.99 |
| p10 | 27.54 | 28.45 | 30.38 |
| p25 | 30.22 | 31.58 | 32.87 |
| p50 | 33.82 | 35.17 | 36.14 |
| p75 | 38.20 | 37.99 | 39.85 |
| p90 | 41.21 | 39.75 | 42.51 |
| p95 | 42.49 | 40.57 | 43.64 |

> Timings on the 39-image set (batch 8, bf16, NVIDIA RTX Pro 6000 Blackwell): mDiffAE encode 2.4 ms + decode 3.0 ms = **5.4 ms/image** total, vs Flux.1 at 202 ms and Flux.2 at 138 ms β€” roughly **37Γ—** and **26Γ—** faster end-to-end.

### 7.2 Interactive Viewer

**[Open full-resolution comparison viewer](https://huggingface.co/spaces/data-archetype/mdiffae-results)** β€” side-by-side reconstructions, RGB deltas, and latent PCA with adjustable image size.

### 7.3 Per-Image Results (39-image curated set)

Inference settings: 1-step DDIM, seed 42, no PDG, batch size 8.

| Metric | mdiffae_v1 (1 step) | Flux.1 VAE | Flux.2 VAE |
|--------|--------|--------|--------|
| Avg PSNR (dB) | 31.89 | 32.76 | 34.16 |
| Avg encode (ms/image) | 2.4 | 63.9 | 45.7 |
| Avg decode (ms/image) | 3.0 | 138.2 | 92.8 |

### 7.4 Per-Image PSNR (dB)

| Image | mdiffae_v1 (1 step) | Flux.1 VAE | Flux.2 VAE |
|-------|--------|--------|--------|
| p640x1536:94623 | 31.20 | 31.28 | 33.50 |
| p640x1536:94624 | 27.32 | 27.62 | 30.03 |
| p640x1536:94625 | 30.68 | 31.65 | 33.98 |
| p640x1536:94626 | 29.14 | 29.44 | 31.53 |
| p640x1536:94627 | 29.63 | 28.70 | 30.53 |
| p640x1536:94628 | 25.60 | 26.38 | 28.88 |
| p960x1024:216264 | 44.50 | 40.87 | 45.39 |
| p960x1024:216265 | 26.42 | 25.82 | 27.80 |
| p960x1024:216266 | 44.90 | 47.77 | 46.20 |
| p960x1024:216267 | 37.78 | 37.65 | 39.23 |
| p960x1024:216268 | 36.15 | 35.27 | 36.13 |
| p960x1024:216269 | 29.37 | 28.45 | 30.24 |
| p960x1024:216270 | 32.43 | 31.92 | 34.18 |
| p960x1024:216271 | 41.23 | 38.92 | 42.18 |
| p704x1472:94699 | 41.88 | 40.43 | 41.79 |
| p704x1472:94700 | 29.66 | 29.52 | 32.08 |
| p704x1472:94701 | 35.14 | 35.43 | 37.90 |
| p704x1472:94702 | 30.90 | 30.73 | 32.50 |
| p704x1472:94703 | 28.65 | 29.08 | 31.35 |
| p704x1472:94704 | 28.98 | 29.22 | 31.84 |
| p704x1472:94705 | 36.09 | 36.38 | 37.44 |
| p704x1472:94706 | 31.53 | 31.50 | 33.66 |
| r256_p1344x704:15577 | 27.89 | 28.32 | 29.98 |
| r256_p1344x704:15578 | 28.07 | 29.35 | 30.79 |
| r256_p1344x704:15579 | 29.56 | 30.44 | 31.83 |
| r256_p1344x704:15580 | 32.89 | 36.12 | 36.03 |
| r256_p1344x704:15581 | 32.26 | 37.42 | 36.94 |
| r256_p1344x704:15582 | 28.74 | 30.64 | 32.10 |
| r256_p1344x704:15583 | 31.99 | 34.67 | 34.54 |
| r256_p1344x704:15584 | 28.42 | 30.34 | 31.76 |
| r256_p896x1152:144131 | 30.02 | 33.10 | 33.60 |
| r256_p896x1152:144132 | 33.19 | 34.23 | 35.32 |
| r256_p896x1152:144133 | 35.42 | 37.85 | 37.33 |
| r256_p896x1152:144134 | 31.41 | 34.25 | 34.47 |
| r256_p896x1152:144135 | 27.13 | 28.17 | 29.87 |
| r256_p896x1152:144136 | 32.75 | 35.24 | 35.68 |
| r256_p896x1152:144137 | 28.60 | 32.70 | 32.86 |
| r256_p896x1152:144138 | 24.76 | 24.15 | 25.63 |
| VAE_accuracy_test_image | 31.52 | 36.69 | 35.25 |