
iRDiffAE v1.0 β€” Technical Report

iRepa Diffusion AutoEncoder = iRDiffAE

A fast, single-GPU-trainable diffusion autoencoder with spatially structured latents for rapid downstream model convergence. Encoding runs ~5× faster than Flux VAE; single-step decoding runs ~3× faster.

Contents

  1. VP Diffusion Parameterization
  2. Architecture
  3. Design Choices
  4. Model Configuration
  5. Training
  6. Inference
  7. Results

References:

  • SiD2 β€” Hoogeboom et al., Simpler Diffusion (SiD2): 1.5 FID on ImageNet512 with pixel-space diffusion, arXiv:2410.19324, ICLR 2025.
  • DiTo β€” Yin et al., Diffusion Autoencoders are Scalable Image Tokenizers, arXiv:2501.18593, 2025.
  • DiCo β€” Ai et al., DiCo: Revitalizing ConvNets for Scalable and Efficient Diffusion Modeling, arXiv:2505.11196, 2025.
  • SPRINT β€” Park et al., Sprint: Sparse-Dense Residual Fusion for Efficient Diffusion Transformers, arXiv:2510.21986, 2025.
  • Z-image β€” Cai et al., Z-Image: An Efficient Image Generation Foundation Model with Single-Stream Diffusion Transformer, arXiv:2511.22699, 2025.
  • iREPA β€” Singh et al., What matters for Representation Alignment: Global Information or Spatial Structure?, arXiv:2512.10794, 2025.

1. VP Diffusion Parameterization

iRDiffAE uses the variance-preserving (VP) diffusion framework from SiD2 with an x-prediction objective.

1.1 Forward Process

Given a clean image \(x_0\), the forward process constructs a noisy sample at continuous time \(t \in [0, 1]\):

$$x_t = \alpha_t \, x_0 + \sigma_t \, \varepsilon, \qquad \varepsilon \sim \mathcal{N}(0, s^2 I)$$

where \(s = 0.558\) is the pixel-space noise standard deviation (estimated from the dataset image distribution) and the VP constraint holds:

$$\alpha_t^2 + \sigma_t^2 = 1$$

1.2 Log Signal-to-Noise Ratio

The schedule is parameterized through the log signal-to-noise ratio:

$$\lambda_t = \log \frac{\alpha_t^2}{\sigma_t^2}$$

which monotonically decreases as \(t \to 1\) (pure noise). From \(\lambda_t\) we recover \(\alpha_t\) and \(\sigma_t\) via the sigmoid function:

$$\alpha_t = \sqrt{\sigma(\lambda_t)}, \qquad \sigma_t = \sqrt{\sigma(-\lambda_t)}$$

where \(\sigma(\cdot)\) is the logistic sigmoid.
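
The logSNR-to-coefficient mapping can be sketched in a few lines (a minimal sketch; the function names are illustrative, not the library's API):

```python
import math

def sigmoid(x: float) -> float:
    """Logistic sigmoid."""
    return 1.0 / (1.0 + math.exp(-x))

def coeffs_from_logsnr(lam: float) -> tuple[float, float]:
    """alpha_t = sqrt(sigmoid(lambda)), sigma_t = sqrt(sigmoid(-lambda))."""
    return math.sqrt(sigmoid(lam)), math.sqrt(sigmoid(-lam))
```

Since \(\sigma(\lambda) + \sigma(-\lambda) = 1\), the VP constraint \(\alpha_t^2 + \sigma_t^2 = 1\) holds automatically for every \(\lambda\).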

1.3 Cosine-Interpolated Schedule

Following SiD2, the logSNR schedule uses cosine interpolation:

$$\lambda(t) = -2 \log \tan(a \cdot t + b)$$

where \(a\) and \(b\) are computed to satisfy the boundary conditions \(\lambda(0) = \lambda_\text{max}\) and \(\lambda(1) = \lambda_\text{min}\):

$$b = \arctan\!\bigl(e^{-\lambda_\text{max}/2}\bigr), \qquad a = \arctan\!\bigl(e^{-\lambda_\text{min}/2}\bigr) - b$$

SiD2 also defines a "shifted cosine" variant with resolution-dependent additive shifts \(\Delta_\text{high}\) and \(\Delta_\text{low}\):

$$\lambda_\text{shifted}(t) = (1 - t) \cdot [\lambda(t) + \Delta_\text{high}] + t \cdot [\lambda(t) + \Delta_\text{low}]$$

iRDiffAE uses \(\lambda_\text{min} = -10\), \(\lambda_\text{max} = 10\), and \(\Delta_\text{high} = \Delta_\text{low} = 0\) (no resolution-dependent shift), so the schedule reduces to the unshifted cosine interpolation.
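
The unshifted schedule with these boundary conditions can be sketched as (a sketch; the constant names are assumptions):

```python
import math

LAMBDA_MIN, LAMBDA_MAX = -10.0, 10.0

# Solve the boundary conditions: lambda(0) = LAMBDA_MAX, lambda(1) = LAMBDA_MIN
b = math.atan(math.exp(-LAMBDA_MAX / 2))
a = math.atan(math.exp(-LAMBDA_MIN / 2)) - b

def logsnr(t: float) -> float:
    """Cosine-interpolated logSNR: lambda(t) = -2 log tan(a*t + b)."""
    return -2.0 * math.log(math.tan(a * t + b))
```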

1.4 X-Prediction Objective

The model predicts the clean image \(\hat{x}_0 = f_\theta(x_t, t, z)\) conditioned on the encoder latents \(z\).

Schedule-invariant loss. Following SiD2, the training loss is defined as an integral over logSNR \(\lambda\), making it invariant to the choice of noise schedule:

$$\mathcal{L}(x) = \int w(\lambda) \, \| x_0 - \hat{x}_0 \|^2 \, d\lambda$$

Since timesteps are sampled uniformly \(t \sim \mathcal{U}(0,1)\) rather than integrated over \(\lambda\) directly, the change of variable \(d\lambda = \frac{d\lambda}{dt} \, dt\) introduces a Jacobian factor:

$$\mathcal{L} = \mathbb{E}_{t \sim \mathcal{U}(0,1)} \left[ \left(-\frac{d\lambda}{dt}\right) \cdot w(\lambda(t)) \cdot \| x_0 - \hat{x}_0 \|^2 \right]$$

Sigmoid weighting. SiD2 defines the weighting function in \(\varepsilon\)-prediction form as \(\sigma(b - \lambda)\), a sigmoid centered at bias \(b\). Converting from \(\varepsilon\)-prediction to \(x\)-prediction MSE via \(\|\varepsilon - \hat{\varepsilon}\|^2 = e^{\lambda} \|x_0 - \hat{x}_0\|^2\) gives:

$$\sigma(b - \lambda) \cdot e^{\lambda} = e^b \cdot \sigma(\lambda - b)$$

Combining the Jacobian with the weighting, the per-sample weight used in training is:

$$\text{weight}(t) = -\frac{1}{2} \frac{d\lambda}{dt} \cdot e^b \cdot \sigma(\lambda(t) - b)$$

The bias \(b = -2.0\) controls the relative emphasis on high-SNR (low-noise) vs low-SNR (high-noise) timesteps. A more negative \(b\) shifts emphasis toward noisier timesteps.
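
Putting the Jacobian and the converted sigmoid weighting together, the per-sample weight might be computed as follows (a sketch; `BIAS` renames the sigmoid bias \(b\) in code to avoid clashing with the schedule constant of the same name):

```python
import math

LAMBDA_MIN, LAMBDA_MAX, BIAS = -10.0, 10.0, -2.0
b = math.atan(math.exp(-LAMBDA_MAX / 2))   # schedule constant (Section 1.3)
a = math.atan(math.exp(-LAMBDA_MIN / 2)) - b

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def logsnr(t: float) -> float:
    return -2.0 * math.log(math.tan(a * t + b))

def dlogsnr_dt(t: float) -> float:
    # d/dt [-2 log tan(a*t + b)] = -4a / sin(2(a*t + b)); negative, since lambda decreases in t
    return -4.0 * a / math.sin(2.0 * (a * t + b))

def loss_weight(t: float) -> float:
    """weight(t) = -(1/2) * dlambda/dt * e^BIAS * sigmoid(lambda(t) - BIAS)."""
    return -0.5 * dlogsnr_dt(t) * math.exp(BIAS) * sigmoid(logsnr(t) - BIAS)
```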

1.5 Sampling

At inference, each timestep \(t\) in the schedule is first mapped to logSNR via the cosine-interpolated schedule (Section 1.3), then to diffusion coefficients:

$$t \;\xrightarrow{\text{schedule}}\; \lambda(t) \;\xrightarrow{\text{sigmoid}}\; \alpha_t = \sqrt{\sigma(\lambda)}, \quad \sigma_t = \sqrt{\sigma(-\lambda)}$$

DDIM. The default sampler uses a descending time schedule \(t_0 > t_1 > \cdots > t_N\) with \(N\) denoising steps. At each step:

  1. Predict \(\hat{x}_0 = f_\theta(x_{t_i}, t_i, z)\)
  2. Reconstruct \(\hat{\varepsilon} = \frac{x_{t_i} - \alpha_{t_i} \hat{x}_0}{\sigma_{t_i}}\)
  3. Step: \(x_{t_{i+1}} = \alpha_{t_{i+1}} \hat{x}_0 + \sigma_{t_{i+1}} \hat{\varepsilon}\)
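
One DDIM update, given the model's x-prediction, can be sketched as follows (`x0_pred` stands in for the network output \(f_\theta(x_{t_i}, t_i, z)\); the function name is illustrative):

```python
import numpy as np

def ddim_step(x_t, x0_pred, alpha_t, sigma_t, alpha_next, sigma_next):
    """Recover eps from the x-prediction, then re-noise at the next noise level."""
    eps = (x_t - alpha_t * x0_pred) / sigma_t
    return alpha_next * x0_pred + sigma_next * eps
```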

DPM++2M. Also supported as an alternative sampler, using a half-lambda (\(\lambda/2\)) exponential integrator for faster convergence with fewer steps.


2. Architecture

2.1 Overview

iRDiffAE consists of a deterministic encoder and an iterative VP diffusion decoder. The encoder maps an image to a compact spatial latent, and the decoder reconstructs the image by iteratively denoising from Gaussian noise, conditioned on both the latents and the diffusion timestep.

```text
Encoder:  x ∈ ℝ^{B×3×H×W}  →  z ∈ ℝ^{B×C×h×w}      (deterministic, single pass)
Decoder:  (z, t, x_t)      →  x̂₀ ∈ ℝ^{B×3×H×W}    (iterative, N diffusion steps)
```

where \(h = H / p\), \(w = W / p\), \(p\) is the patch size, and \(C\) is the bottleneck dimension.

2.2 DiCo Block

Both encoder and decoder use DiCo blocks (from the DiCo paper), a convolution-based alternative to transformer blocks. Each block consists of two residual paths:

Conv path:

$$y = \text{Conv}_{1 \times 1} \to \text{DWConv}_{k \times k} \to \text{SiLU} \to \text{CCA} \to \text{Conv}_{1 \times 1}$$

MLP path:

$$y = \text{Conv}_{1 \times 1} \to \text{GELU} \to \text{Conv}_{1 \times 1}$$

where \(\text{DWConv}_{k \times k}\) is a depthwise convolution (default \(k = 7\)) and \(\text{CCA}\) is Compact Channel Attention:

$$\text{CCA}(y) = y \odot \sigma\bigl(\text{Conv}_{1 \times 1}(\text{AvgPool}(y))\bigr)$$

Both paths use channel-wise RMSNorm (without affine parameters) as pre-norm. Residual connections use gating:

  - Encoder (unconditioned): learned per-channel gates \(x \leftarrow x + g \cdot y\), where \(g\) is a learnable vector initialized to zero.
  - Decoder (conditioned): AdaLN-Zero gating via \(x \leftarrow x + \tanh(g_\text{adaln}) \cdot y\), where \(g_\text{adaln}\) comes from the timestep conditioning.

2.3 Encoder

The encoder is deterministic: no variational posterior, no KL loss. Latent normalization uses channel-wise RMSNorm without affine parameters, following DiTo's finding that this outperforms KL regularization.

```text
Input:       x ∈ ℝ^{B×3×H×W}
Patchify:    PixelUnshuffle(p) → Conv 1×1      →  ℝ^{B×D×h×w}
Norm:        ChannelWise RMSNorm (affine)
Blocks:      DiCoBlock × depth_enc                (unconditioned, learned gates)
Bottleneck:  Conv 1×1 (D → C)
Norm out:    ChannelWise RMSNorm (no affine)
Output:      z ∈ ℝ^{B×C×h×w}
```

2.4 Decoder

The decoder predicts \(\hat{x}_0\) from the noisy input \(x_t\), conditioned on encoder latents \(z\) and timestep \(t\).

```text
Patchify x_t:  PixelUnshuffle(p) → Conv 1×1     →  ℝ^{B×D×h×w}
Norm:          ChannelWise RMSNorm (affine)
Upsample z:    Conv 1×1 (C → D) → RMSNorm       →  ℝ^{B×D×h×w}
Fuse:          Concat[x_feat, z_up] → Conv 1×1  →  ℝ^{B×D×h×w}

Time embed:    t → sinusoidal → MLP             →  cond ∈ ℝ^{B×D}

Start blocks:  DiCoBlock × 2                      (AdaLN conditioned)
Middle blocks: DiCoBlock × (depth - 4)            (AdaLN conditioned)
Skip fusion:   Concat[start_out, middle_out] → Conv 1×1
End blocks:    DiCoBlock × 2                      (AdaLN conditioned)

Norm:          ChannelWise RMSNorm (affine)
Output head:   Conv 1×1 (D → 3·p²) → PixelShuffle(p)  →  x̂₀ ∈ ℝ^{B×3×H×W}
```

2.5 AdaLN: Shared Base + Low-Rank Deltas

Timestep conditioning follows the Z-image style AdaLN (Cai et al., 2025): a shared base projection plus a low-rank delta per layer, scale-and-gate modulation with no shift, and a \(\tanh\) on the gate.

A single base projector is shared across all decoder layers, and each layer adds a low-rank correction:

$$m_i = \text{Base}(\text{SiLU}(\text{cond})) + \Delta_i(\text{SiLU}(\text{cond}))$$

where \(\text{Base}: \mathbb{R}^D \to \mathbb{R}^{4D}\) is a linear projection (zero-initialized) and \(\Delta_i: \mathbb{R}^D \xrightarrow{\text{down}} \mathbb{R}^r \xrightarrow{\text{up}} \mathbb{R}^{4D}\) is a low-rank factorization with rank \(r\) (zero-initialized up-projection).

The packed modulation \(m_i \in \mathbb{R}^{B \times 4D}\) is chunked into four vectors \((\text{scale}_\text{conv}, \text{gate}_\text{conv}, \text{scale}_\text{mlp}, \text{gate}_\text{mlp})\) which modulate the conv and MLP paths (no shift term):

$$\hat{x} = \text{RMSNorm}(x) \odot (1 + \text{scale}), \qquad x \leftarrow x + \tanh(\text{gate}) \cdot f(\hat{x})$$
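
A shape-level sketch of the shared-base + low-rank modulation (NumPy stands in for the actual linear layers; names are illustrative). With the base and up-projections zero-initialized, all four chunks start at zero, so scale-and-gate modulation leaves each block as an identity at the start of training:

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))  # x * sigmoid(x)

def adaln_modulation(cond, base_w, down_w, up_w):
    """m = Base(SiLU(cond)) + Up(Down(SiLU(cond))), chunked into 4 vectors.

    cond: (B, D); base_w: (4D, D), zero-init; down_w: (r, D); up_w: (4D, r), zero-init.
    Returns (scale_conv, gate_conv, scale_mlp, gate_mlp), each (B, D)."""
    h = silu(cond)
    m = h @ base_w.T + (h @ down_w.T) @ up_w.T   # shared base + per-layer low-rank delta
    return np.split(m, 4, axis=-1)
```

Sharing one base projector and adding only a rank-\(r\) correction per layer keeps the per-layer parameter cost at \(r(D + 4D)\) instead of \(4D^2\).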

2.6 Path-Drop Guidance (PDG)

At inference, iRDiffAE supports Path-Drop Guidance, a classifier-free guidance analogue that does not require training with conditioning dropout. Instead, it exploits the decoder's skip connection:

  1. Conditional pass: run all blocks normally → \(\hat{x}_0^\text{cond}\)
  2. Unconditional pass: replace the middle block output with a learned mask feature \(m \in \mathbb{R}^{1 \times D \times 1 \times 1}\) (initialized to zero), effectively dropping the deep processing path → \(\hat{x}_0^\text{uncond}\)
  3. Guided prediction: \(\hat{x}_0 = \hat{x}_0^\text{uncond} + s \cdot (\hat{x}_0^\text{cond} - \hat{x}_0^\text{uncond})\)

where \(s\) is the guidance strength.
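
The guided combination is the same extrapolation used in classifier-free guidance, applied to x-predictions (a one-line sketch; the function name is illustrative):

```python
import numpy as np

def path_drop_guidance(x0_cond, x0_uncond, s):
    """Guided prediction: uncond + s * (cond - uncond). s = 1 recovers the conditional pass."""
    return x0_uncond + s * (x0_cond - x0_uncond)
```

For \(s > 1\) the prediction extrapolates away from the skip-only "unconditional" output, which is what sharpens the reconstruction.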


3. Design Choices

3.1 Convolutional Architecture

iRDiffAE uses a fully convolutional architecture rather than a vision transformer. For an autoencoder whose goal is faithful pixel-level reconstruction (not global semantic understanding), convolutions offer several advantages:

  • Resolution generalization. Convolutions operate on local patches and generalize naturally to arbitrary image dimensions without interpolating position embeddings or suffering attention distribution shift from sequence length changes with global attention. Convolutions are also more efficient than sliding window attention for local operations.
  • Translation invariance. The built-in inductive bias of weight sharing across spatial positions is well matched to reconstruction, where the same local patterns (edges, textures, gradients) conditioned on the low-frequency latent recur throughout the image.
  • Locality. Reconstruction quality depends on preserving fine spatial detail. Convolutions are inherently local operators, avoiding the quadratic cost of global attention while focusing computation where it matters most for reconstruction.

Transformers are better suited for image generation (where global context and long-range dependencies are essential), but convolutions are better suited for autoencoders. The DiCo block provides a well-tested, strong building block for convolutional diffusion models, combining depthwise convolutions with compact channel attention in a design that has been validated at scale.

3.2 Single-Stride Encoder with Final Bottleneck

The encoder uses a single spatial stride (via PixelUnshuffle at the input) followed by a stack of DiCo blocks operating at constant spatial resolution, then a final 1×1 convolution to project from model dimension \(D\) to bottleneck dimension \(C\). This differs from classical VAE encoders, which use progressive downsampling with channel expansion at each stage.

The single-stride design ensures that all encoder blocks see the full spatial resolution and full channel width simultaneously. The information bottleneck is imposed only at the very end, where a single linear projection selects which \(C\) channels to retain. Progressive compression forces early layers to discard information before the full feature representation has been computed, which is both computationally heavier and representationally suboptimal.

3.3 Diffusion Decoding vs. GAN-Based Decoding

Empirically, diffusion autoencoders produce a much cleaner latent space than patch-GAN + LPIPS-driven VAEs. The iterative diffusion process acts as a strong structural prior on the decoder, which in turn relaxes the pressure on the encoder to encode every pixel perfectly: the latent space can focus on semantically meaningful structure rather than adversarial reconstruction artifacts. This makes diffusion AE latents easier for a downstream latent-space diffusion model to learn.

Training efficiency. The diffusion AE training objective is a straightforward weighted MSE loss with no adversarial component: no discriminator, no LPIPS perceptual loss, no delicate GAN balancing. At batch size 128, the model uses less than 30 GB of VRAM and runs at 7–10 iterations per second, making it trainable on a single RTX 5090 in one to two days. By contrast, GAN + LPIPS-based VAEs require many days of H100 time and are notoriously difficult to stabilize, with no publicly known working recipe for training from scratch at comparable quality.

3.4 Skip Connection and Path-Drop Guidance

The decoder's start → middle → skip-fuse → end architecture is inspired by SPRINT's sparse-dense residual fusion. The start blocks process the fused input (noised image + latents) at full fidelity, the middle blocks perform deeper processing, and the skip connection concatenates the start block output with the middle block output before the end blocks.

This design serves three purposes:

  1. Regularization. The skip path ensures that even if the middle blocks are dropped or poorly conditioned, the end blocks still receive meaningful features from the start blocks.
  2. High-frequency preservation. The start blocks (which see the input most directly) pass fine detail through the skip to the end blocks, preventing the middle blocks from washing out high-frequency information.
  3. Path-Drop Guidance (PDG). At inference, replacing the middle block output with a learned zero-initialized mask feature creates an "unconditional" prediction that preserves the skip path but drops the deep processing. Interpolating between the conditional and unconditional predictions (as in classifier-free guidance) sharpens the output distribution, and hence the reconstructed image, without requiring any training-time conditioning dropout.

3.5 Half-Channel Representation Alignment (iREPA)

Singh et al. (iREPA, arXiv:2512.10794) show that the spatial structure of pretrained encoder representations, not their global semantic accuracy, drives generation quality when representation alignment is used to guide diffusion training. Their method aligns internal diffusion features with patch tokens from a frozen vision encoder (e.g. DINOv2) using patch-wise cosine similarity, with a conv-based projection and spatial normalization to preserve local structure.

iRDiffAE adopts iREPA but aligns only the first half of the bottleneck channels (64 of 128) to a frozen DINOv3-S teacher. The rationale: models like DINOv3-S are trained for semantic understanding and do not preserve high-frequency detail. Aligning all channels biases the encoder toward dropping fine detail in favour of semantic structure. By aligning only half, the bottleneck decomposes into:

  • Channels 0–63 (aligned): semantic and spatial structure, guided by the teacher's patch tokens.
  • Channels 64–127 (free): fine detail and high-frequency information, driven purely by the reconstruction loss.

The alignment operates on the encoder output after the final RMSNorm (no affine), so the teacher sees unit-RMS normalized features.

Implementation details:

```text
Encoder latents z ∈ ℝ^{B×128×h×w}  (after RMSNorm)
                    ↓
        z_aligned = z[:, :64, :, :]
                    ↓
        Conv2d 3×3 (64 → 384, padding=1)   ← iREPA conv projection
                    ↓
        student tokens ∈ ℝ^{B×T×384}
                    ↓
        patch-wise cosine similarity with DINOv3-S tokens
```

The teacher's patch tokens are spatially normalized before comparison (\(\gamma = 0.7\), removing 70% of the global mean) following iREPA's prescription. The alignment loss is weighted at 0.5 for most of training, reduced to 0.25 toward the end to improve reconstruction fidelity.

Tradeoff. The alignment costs 2–3 dB of average reconstruction PSNR compared to training without it. In exchange, downstream diffusion and flow matching models trained on the aligned latent space converge significantly faster, empirically validating the iREPA finding that the spatial structure of the latent representation matters more than raw reconstruction fidelity for generation quality.


4. Model Configuration

| Parameter | Value |
|---|---|
| Patch size \(p\) | 16 |
| Bottleneck dim \(C\) | 128 |
| Compression ratio | 6× |
| Model dim \(D\) | 896 |
| Total parameters | 133.4M |
| Encoder depth | 4 |
| Decoder depth | 8 |
| Decoder layout | 2 start + 4 middle + 2 end |
| MLP ratio | 4.0 |
| Depthwise kernel | 7×7 |
| AdaLN rank \(r\) | 128 |
| \(\lambda_\text{min}\) | −10 |
| \(\lambda_\text{max}\) | +10 |
| Sigmoid bias \(b\) | −2.0 |
| Pixel noise std \(s\) | 0.558 |

Compression ratio = \((3 \times p^2) / C\): the factor by which the latent representation is smaller than the raw pixel data. With patch size 16 and 128 bottleneck channels, the encoder produces a 16× spatial downsampling (\(256\times\) area reduction) at 6× total compression.
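
The arithmetic, using the configuration above:

```python
p, C = 16, 128                  # patch size and bottleneck channels
compression = (3 * p**2) / C    # raw RGB values per patch vs latent channels per patch
assert compression == 6.0       # matches the 6x figure in the table
assert p * p == 256             # 16x per side -> 256x area reduction
```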


5. Training

5.1 Data

Training uses ~5M images at various resolutions: mostly photographs, with a significant proportion of illustrations and text-heavy images (documents, screenshots, book covers, diagrams) to encourage crisp line and edge reconstruction. Images are loaded via two strategies in a 50/50 mix:

  • Full-image downsampling: images are bucketed by aspect ratio and downsampled to ~256Β² resolution (preserving aspect ratio).
  • Random 256Γ—256 crops: deterministic patches extracted from images stored at β‰₯512px resolution.

This mixed strategy exposes the model to both global scene composition (via downsampled full images) and fine local detail (via crops from higher-resolution sources).

5.2 Timestep Sampling

Timesteps are drawn via stratified uniform sampling, a variance-reduction technique from Monte Carlo integration. The base distribution is uniform over the endpoint-trimmed domain \([\varepsilon, 1 - \varepsilon]\). Rather than drawing \(B\) i.i.d. samples (which can cluster or leave gaps by chance), stratified sampling divides the domain into \(B\) equal-mass buckets and draws exactly one sample per bucket:

$$t_i = u_\text{lo} + (u_\text{hi} - u_\text{lo}) \cdot \frac{i + U_i}{B}, \qquad U_i \sim \mathcal{U}(0, 1), \quad i = 0, \ldots, B-1$$

where \(u_\text{lo} = F(\varepsilon)\), \(u_\text{hi} = F(1 - \varepsilon)\), and \(F\) is the CDF of the base distribution (the identity for uniform). This guarantees that every batch covers the full timestep range evenly, reducing the variance of the per-batch gradient estimate without introducing bias.

Endpoint trimming uses \(\varepsilon = \sigma(-7.5) \approx 5.5 \times 10^{-4}\), keeping \(|\lambda| \leq 15\).
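
For the uniform base distribution, \(F\) is the identity and \(u_\text{lo} = \varepsilon\), \(u_\text{hi} = 1 - \varepsilon\), so the stratified draw reduces to (a sketch; the function name is an assumption):

```python
import math
import numpy as np

def stratified_timesteps(batch_size, eps, rng):
    """One uniform draw per equal-mass bucket of [eps, 1 - eps]."""
    i = np.arange(batch_size)
    u = rng.uniform(size=batch_size)
    return eps + (1 - 2 * eps) * (i + u) / batch_size

eps = 1.0 / (1.0 + math.exp(7.5))   # sigma(-7.5), roughly 5.5e-4
```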

5.3 Latent Noise Synchronization (DiTo Regularization)

Following DiTo, encoder latents are regularized via noise synchronization during training. With probability \(p = 0.1\), a subset of clean latents \(z_0\) are replaced with noisy versions:

$$z_\tau = (1 - \tau_\text{fm}) \cdot z_0 + \tau_\text{fm} \cdot \varepsilon_z, \qquad \varepsilon_z \sim \mathcal{N}(0, I)$$

where \(\tau\) is sampled uniformly in \([0, t]\) (ensuring the latent is never noisier than the pixel-space input) and converted to a flow-matching time via the logSNR mapping, since downstream latent-space models are expected to use flow matching:

$$\tau_\text{fm} = \sigma(-\tfrac{1}{2} \, \lambda(\tau))$$

This synchronizes the noising process in latent space with pixel space, ensuring that the latent representation remains useful when a downstream latent diffusion model adds noise during its own forward process.
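
The VP-to-flow-matching time mapping can be sketched as (schedule constants as in Section 1.3; names are illustrative):

```python
import math

LAMBDA_MIN, LAMBDA_MAX = -10.0, 10.0
b = math.atan(math.exp(-LAMBDA_MAX / 2))
a = math.atan(math.exp(-LAMBDA_MIN / 2)) - b

def logsnr(t: float) -> float:
    return -2.0 * math.log(math.tan(a * t + b))

def fm_time(tau: float) -> float:
    """tau_fm = sigmoid(-lambda(tau)/2): high SNR maps near 0 (clean), low SNR near 1 (noisy)."""
    return 1.0 / (1.0 + math.exp(0.5 * logsnr(tau)))
```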

5.4 Pixel vs. Latent Noise Standards

The model uses different noise standard deviations in pixel space and latent space:

  • Pixel space: s=0.558s = 0.558, matching an estimate of the per-channel standard deviation of natural images over the training dataset. This ensures that at t=1t = 1 the noise distribution roughly matches the data distribution scale.
  • Latent space: s=1.0s = 1.0, because encoder latents are RMSNorm'd to unit scale. Downstream latent diffusion models (which use flow matching) operate with this unit-variance assumption.

The conversion between pixel-space VP logSNR and latent-space flow-matching time uses the sigmoid mapping \(t_\text{fm} = \sigma(-\frac{1}{2}\lambda)\), which naturally accounts for the different noise scales.

5.5 Optimizer and Hyperparameters

| Hyperparameter | Value |
|---|---|
| Optimizer | AdamW |
| Learning rate | \(1 \times 10^{-4}\) |
| Weight decay | 0 |
| Adam \(\varepsilon\) | \(1 \times 10^{-8}\) |
| LR schedule | Constant (after warmup), halved for last 20% of training |
| Warmup steps | 2,000 |
| Batch size | 128 |
| EMA decay | 0.9999 |
| Precision | AMP bfloat16 (FP32 master weights, TF32 matmul) |
| Compilation | torch.compile enabled |
| Training steps | 700k |
| Training images | ~5M |
| Hardware | Single GPU |

5.6 Loss

$$\mathcal{L} = \mathcal{L}_\text{recon} + w_\text{repa} \cdot \mathcal{L}_\text{repa}$$

\(\mathcal{L}_\text{recon}\) is the SiD2 sigmoid-weighted x-prediction MSE (Section 1.4) with bias \(b = -2.0\), computed in float32 for numerical stability. \(\mathcal{L}_\text{repa}\) is the iREPA half-channel alignment loss (Section 3.5): mean patch-wise negative cosine similarity between the first 64 encoder channels (projected via 3×3 conv) and spatially normalized DINOv3-S tokens. \(w_\text{repa} = 0.5\) for the majority of training, lowered to 0.25 toward the end to recover reconstruction fidelity.


6. Inference

6.1 Sampling Pipeline

Decoding proceeds by iteratively denoising from Gaussian noise (\(\varepsilon \sim \mathcal{N}(0, s^2 I)\) with \(s = 0.558\)). A descending time schedule \(t_0 > t_1 > \cdots > t_{N-1}\) is generated (linearly spaced by default), and at each step \(t_i\) is mapped to logSNR and then to diffusion coefficients:

  1. Compute \(\lambda_i = \lambda(t_i)\) via the cosine-interpolated schedule
  2. Derive \(\alpha_i = \sqrt{\sigma(\lambda_i)}\), \(\sigma_i = \sqrt{\sigma(-\lambda_i)}\)
  3. Run the DDIM or DPM++2M update step (Section 1.5)

The initial state is \(x_{t_0} = \sigma_{t_0} \cdot \varepsilon\) (pure noise scaled by the first-step sigma).

6.2 Recommended Settings

1 DDIM step with PDG disabled is generally recommended: it achieves the best PSNR and is extremely fast (a single forward pass through the decoder). For images with sharp text or fine line art, 10–20 steps can sometimes improve edge crispness.

| Setting | Recommended | Sharp text |
|---|---|---|
| Sampler | DDIM | DDIM or DPM++2M |
| Steps | 1 | 10–20 |
| Schedule | Linear | Linear |
| PDG | Disabled | Disabled or 2.0 |

Reconstruction PSNR vs. decode steps (N=2000 images, 2/3 photos + 1/3 book covers, EMA weights):

| Decode steps | Avg PSNR (dB) |
|---|---|
| 1 | 33.71 |
| 10 | 32.69 |
| 20 | 32.30 |

PSNR decreases slightly with more steps because the model is trained for single-step x-prediction; additional sampling steps introduce accumulated discretization error. The 128-channel bottleneck preserves enough information that a single decoder pass suffices for high-fidelity reconstruction.

Multi-step sampling can help recover sharper edges on text and line art. PDG (strength 2–4) further increases perceptual sharpness but tends to hallucinate high-frequency detail, a direct manifestation of the perception-distortion tradeoff.

Inference latency (batch of 4 at 256×256, bf16, NVIDIA RTX PRO 6000 Blackwell, 100 iterations after warmup):

| Operation | iRDiffAE | Flux.1 VAE | Flux.2 VAE |
|---|---|---|---|
| Encode | 2.1 ms | 11.6 ms | 9.1 ms |
| Decode (1 step) | 8.3 ms | 24.9 ms | 20.0 ms |
| Decode (10 steps) | 52.7 ms | — | — |
| Decode (20 steps) | 100.6 ms | — | — |
| Roundtrip (enc + 1-step dec) | 11.1 ms | 36.4 ms | 29.0 ms |

Encoding is ~5× faster than Flux.1 and ~4× faster than Flux.2. Single-step decoding is ~3× faster than both Flux VAEs; multi-step decoding trades speed for perceptual sharpness.

6.3 Usage

```python
from ir_diffae import IRDiffAE, IRDiffAEInferenceConfig

model = IRDiffAE.from_pretrained("data-archetype/irdiffae-v1", device="cuda")  # bfloat16 by default

# Encode
latents = model.encode(images)  # [B, 3, H, W] → [B, 128, H/16, W/16]

# Decode: PSNR-optimal (1 step, single forward pass)
cfg = IRDiffAEInferenceConfig(num_steps=1, sampler="ddim")
recon = model.decode(latents, height=H, width=W, inference_config=cfg)

# Decode: perceptual sharpness (10 steps + PDG)
cfg_sharp = IRDiffAEInferenceConfig(
    num_steps=10, sampler="ddim", pdg_enabled=True, pdg_strength=2.0
)
recon_sharp = model.decode(latents, height=H, width=W, inference_config=cfg_sharp)
```

Citation

```bibtex
@misc{ir_diffae,
  title   = {iRDiffAE: A Fast, Representation Aligned Diffusion Autoencoder with DiCo Blocks},
  author  = {data-archetype},
  year    = {2026},
  month   = feb,
  url     = {https://github.com/data-archetype/irdiffae},
}
```

7. Results

Reconstruction quality is evaluated on a curated set of test images covering photographs, book covers, and documents. Flux.1 VAE (patch 8, 16 channels) is included as a reference at the same 12× compression ratio as the c64 variant.

7.1 Interactive Viewer

Open the full-resolution comparison viewer for side-by-side reconstructions, RGB deltas, and latent PCA with adjustable image size.

7.2 Inference Settings

| Setting | Value |
|---|---|
| Sampler | ddim |
| Steps | 1 |
| Schedule | linear |
| Seed | 42 |
| PDG | no_path_dropg |
| Batch size (timing) | 4 |

All models run in bfloat16. Timings measured on an NVIDIA RTX Pro 6000 (Blackwell).

7.3 Global Metrics

| Metric | irdiffae_v1 (1 step) | Flux.1 VAE | Flux.2 VAE |
|---|---|---|---|
| Avg PSNR (dB) | 31.77 | 32.76 | 34.16 |
| Avg encode (ms/image) | 2.5 | 64.8 | 46.3 |
| Avg decode (ms/image) | 5.7 | 138.1 | 92.5 |

7.4 Per-Image PSNR (dB)

| Image | irdiffae_v1 (1 step) | Flux.1 VAE | Flux.2 VAE |
|---|---|---|---|
| p640x1536:94623 | 30.99 | 31.29 | 33.50 |
| p640x1536:94624 | 27.21 | 27.62 | 30.03 |
| p640x1536:94625 | 30.48 | 31.65 | 33.98 |
| p640x1536:94626 | 28.96 | 29.44 | 31.53 |
| p640x1536:94627 | 29.17 | 28.70 | 30.53 |
| p640x1536:94628 | 25.55 | 26.38 | 28.88 |
| p960x1024:216264 | 40.92 | 40.87 | 45.39 |
| p960x1024:216265 | 26.18 | 25.82 | 27.80 |
| p960x1024:216266 | 43.61 | 47.77 | 46.20 |
| p960x1024:216267 | 37.12 | 37.65 | 39.23 |
| p960x1024:216268 | 35.75 | 35.27 | 36.13 |
| p960x1024:216269 | 29.14 | 28.45 | 30.24 |
| p960x1024:216270 | 32.06 | 31.92 | 34.18 |
| p960x1024:216271 | 38.73 | 38.92 | 42.18 |
| p704x1472:94699 | 40.81 | 40.43 | 41.79 |
| p704x1472:94700 | 29.52 | 29.52 | 32.08 |
| p704x1472:94701 | 35.01 | 35.44 | 37.90 |
| p704x1472:94702 | 30.74 | 30.74 | 32.50 |
| p704x1472:94703 | 28.50 | 29.07 | 31.35 |
| p704x1472:94704 | 28.68 | 29.22 | 31.84 |
| p704x1472:94705 | 35.91 | 36.38 | 37.44 |
| p704x1472:94706 | 31.12 | 31.50 | 33.66 |
| r256_p1344x704:15577 | 28.10 | 28.32 | 29.98 |
| r256_p1344x704:15578 | 28.29 | 29.35 | 30.79 |
| r256_p1344x704:15579 | 29.86 | 30.44 | 31.83 |
| r256_p1344x704:15580 | 34.01 | 36.12 | 36.03 |
| r256_p1344x704:15581 | 33.41 | 37.42 | 36.94 |
| r256_p1344x704:15582 | 29.12 | 30.64 | 32.10 |
| r256_p1344x704:15583 | 32.61 | 34.67 | 34.54 |
| r256_p1344x704:15584 | 28.72 | 30.34 | 31.76 |
| r256_p896x1152:144131 | 30.73 | 33.10 | 33.60 |
| r256_p896x1152:144132 | 33.13 | 34.23 | 35.32 |
| r256_p896x1152:144133 | 35.70 | 37.85 | 37.33 |
| r256_p896x1152:144134 | 31.72 | 34.25 | 34.47 |
| r256_p896x1152:144135 | 27.34 | 28.17 | 29.87 |
| r256_p896x1152:144136 | 32.89 | 35.24 | 35.68 |
| r256_p896x1152:144137 | 29.78 | 32.70 | 32.86 |
| r256_p896x1152:144138 | 24.86 | 24.15 | 25.63 |
| VAE_accuracy_test_image | 32.62 | 36.69 | 35.25 |