Upload folder using huggingface_hub
- README.md +144 -0
- config.json +15 -0
- m_diffae/__init__.py +28 -0
- m_diffae/adaln.py +51 -0
- m_diffae/compact_channel_attention.py +21 -0
- m_diffae/config.py +56 -0
- m_diffae/conv_mlp.py +26 -0
- m_diffae/decoder.py +198 -0
- m_diffae/dico_block.py +111 -0
- m_diffae/encoder.py +55 -0
- m_diffae/model.py +294 -0
- m_diffae/norms.py +39 -0
- m_diffae/samplers.py +226 -0
- m_diffae/straight_through_encoder.py +27 -0
- m_diffae/time_embed.py +83 -0
- m_diffae/vp_diffusion.py +151 -0
- model.safetensors +3 -0
- technical_report_mdiffae.md +211 -0
README.md
ADDED
@@ -0,0 +1,144 @@

---
license: apache-2.0
tags:
- diffusion
- autoencoder
- image-reconstruction
- pytorch
- masked-autoencoder
library_name: mdiffae
---

# mdiffae_v1

**mDiffAE**: **M**asked **Diff**usion **A**uto**E**ncoder.
A fast, single-GPU-trainable diffusion autoencoder with a **64-channel**
spatial bottleneck. Uses decoder token masking as an implicit regularizer instead of
REPA alignment, achieving Flux.2-level conditioning quality with a simpler
flat decoder architecture (4 blocks, no skip connections).

This variant (mdiffae_v1): 81.4M parameters, 310.6 MB.
Bottleneck: **64 channels** at patch size 16
(compression ratio 12x).

## Documentation

- [Technical Report](technical_report_mdiffae.md): architecture, masking strategy, and results
- [iRDiffAE Technical Report](https://huggingface.co/data-archetype/irdiffae-v1/blob/main/technical_report.md): full background on VP diffusion, DiCo blocks, patchify encoder, AdaLN
- [Results (interactive viewer)](https://huggingface.co/spaces/data-archetype/mdiffae-results): full-resolution side-by-side comparison

## Quick Start

```python
import torch
from m_diffae import MDiffAE

# Load from HuggingFace Hub (or a local path)
model = MDiffAE.from_pretrained("data-archetype/mdiffae_v1", device="cuda")

# Encode
images = ...  # [B, 3, H, W] in [-1, 1], H and W divisible by 16
latents = model.encode(images)

# Decode (1 step by default; PSNR-optimal)
recon = model.decode(latents, height=H, width=W)

# Reconstruct (encode + 1-step decode)
recon = model.reconstruct(images)
```

> **Note:** Requires `pip install huggingface_hub safetensors` for Hub downloads.
> You can also pass a local directory path to `from_pretrained()`.

## Architecture

| Property | Value |
|---|---|
| Parameters | 81,410,624 |
| File size | 310.6 MB |
| Patch size | 16 |
| Model dim | 896 |
| Encoder depth | 4 |
| Decoder depth | 4 |
| Decoder topology | Flat sequential (no skip connections) |
| Bottleneck dim | 64 |
| MLP ratio | 4.0 |
| Depthwise kernel | 7 |
| AdaLN rank | 128 |
| PDG mechanism | Token-level masking (ratio 0.75) |
| Training regularizer | Decoder token masking (75% ratio, 50% apply prob) |

**Encoder**: Deterministic. Patchify (PixelUnshuffle + 1x1 conv) followed by
DiCo blocks (depthwise conv + compact channel attention + GELU MLP) with
learned residual gates.

**Decoder**: VP diffusion conditioned on encoder latents and timestep via
shared-base + per-layer low-rank AdaLN-Zero. 4 flat
sequential blocks (no skip connections). Supports token-level Path-Drop
Guidance (PDG) at inference; very sensitive, use small strengths only.

**Compared to iRDiffAE's decoder**: iRDiffAE uses an 8-block decoder split into
start (2), middle (4), and end (2) groups with a skip connection that concatenates
start-block output with middle-block output and fuses them through a Conv1x1 before
the end blocks. PDG works by dropping the entire middle block computation and
replacing it with a learned mask feature. In contrast, mDiffAE uses a simple flat
stack of 4 blocks with no skip connections or block groups.
PDG instead works at the token level: 75% of spatial tokens in the fused decoder
input are replaced with a learned mask feature, providing a much finer-grained
guidance signal. The bottleneck is also halved from 128 to 64
channels, giving a 12x
compression ratio vs iRDiffAE's 6x.

### Key Differences from iRDiffAE

| Aspect | iRDiffAE v1 | mDiffAE v1 |
|---|---|---|
| Bottleneck dim | 128 | **64** |
| Decoder depth | 8 (2+4+2 skip-concat) | **4 (flat sequential)** |
| PDG mechanism | Block dropping | **Token masking** |
| Training regularizer | REPA + covariance reg | **Decoder token masking** |
| PDG sensitivity | Moderate (1.5–3.0) | **Very sensitive (1.05–1.2)** |

## Recommended Settings

Best quality is achieved with just **1 DDIM step** and PDG disabled,
making inference extremely fast.

PDG in mDiffAE is **very sensitive**; use tiny strengths (1.05–1.2)
if enabled. Higher values will cause artifacts.

| Setting | Default |
|---|---|
| Sampler | DDIM |
| Steps | 1 |
| PDG | Disabled |
| PDG strength (if enabled) | 1.1 |

```python
from m_diffae import MDiffAEInferenceConfig

# PSNR-optimal (fast, 1 step)
cfg = MDiffAEInferenceConfig(num_steps=1, sampler="ddim")
recon = model.decode(latents, height=H, width=W, inference_config=cfg)
```

## Citation

```bibtex
@misc{m_diffae,
  title = {mDiffAE: A Masked Diffusion Autoencoder with Flat Decoder and Token-Level Guidance},
  author = {data-archetype},
  year = {2026},
  month = mar,
  url = {https://huggingface.co/data-archetype/mdiffae_v1},
}
```

## Dependencies

- PyTorch >= 2.0
- safetensors (for loading weights)

## License

Apache 2.0
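The README's 12x compression figure (and iRDiffAE's 6x) follows directly from the shapes it lists: each 16x16 patch of a 3-channel image holds 768 pixel values, which the encoder maps to a single 64-channel latent position. A quick arithmetic check:

```python
# Sanity check of the stated compression ratios from the shapes above.
patch_size = 16
in_channels = 3
bottleneck_dim = 64  # mDiffAE v1

pixel_values_per_patch = patch_size * patch_size * in_channels  # 768
compression_ratio = pixel_values_per_patch // bottleneck_dim

# iRDiffAE v1 keeps a 128-channel bottleneck at the same patch size.
irdiffae_ratio = pixel_values_per_patch // 128

print(compression_ratio, irdiffae_ratio)  # 12 6
```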
config.json
ADDED
@@ -0,0 +1,15 @@

{
  "in_channels": 3,
  "patch_size": 16,
  "model_dim": 896,
  "encoder_depth": 4,
  "decoder_depth": 4,
  "bottleneck_dim": 64,
  "mlp_ratio": 4.0,
  "depthwise_kernel_size": 7,
  "adaln_low_rank_rank": 128,
  "logsnr_min": -10.0,
  "logsnr_max": 10.0,
  "pixel_noise_std": 0.558,
  "pdg_mask_ratio": 0.75
}
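The `logsnr_min`/`logsnr_max` endpoints bound the decoder's noise schedule. `vp_diffusion.py` is not included in this chunk, so the exact mapping is an assumption, but under the standard variance-preserving parameterization logSNR λ gives signal power α² = sigmoid(λ) and noise power σ² = sigmoid(−λ), so λ = 10 is nearly noise-free and λ = −10 is nearly pure noise:

```python
import math

def vp_alpha_sigma_sq(logsnr: float) -> tuple[float, float]:
    # Assumed standard VP mapping (vp_diffusion.py is not shown here):
    # alpha^2 = sigmoid(logsnr), sigma^2 = 1 - alpha^2.
    alpha_sq = 1.0 / (1.0 + math.exp(-logsnr))
    return alpha_sq, 1.0 - alpha_sq

# At the schedule endpoints from config.json:
a_hi, s_hi = vp_alpha_sigma_sq(10.0)   # almost all signal
a_lo, s_lo = vp_alpha_sigma_sq(-10.0)  # almost all noise
print(round(a_hi, 4), round(s_lo, 4))
```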
m_diffae/__init__.py
ADDED
@@ -0,0 +1,28 @@

"""mDiffAE: Standalone diffusion autoencoder for HuggingFace distribution.

Masked Diffusion AutoEncoder: a compact diffusion autoencoder
that encodes images to spatial latents and decodes via iterative VP diffusion.
Uses decoder token masking as an implicit regularizer instead of REPA alignment,
with a flat 4-block decoder (no skip connections).

Usage::

    from m_diffae import MDiffAE, MDiffAEInferenceConfig

    model = MDiffAE.from_pretrained("path/to/weights", device="cuda")

    # Encode
    latents = model.encode(images)  # images: [B,3,H,W] in [-1,1]

    # Decode: PSNR-optimal (1 step, default)
    recon = model.decode(latents, height=H, width=W)

    # Decode: perceptual sharpness (10 steps + PDG)
    cfg = MDiffAEInferenceConfig(num_steps=10, sampler="ddim", pdg_enabled=True)
    recon = model.decode(latents, height=H, width=W, inference_config=cfg)
"""

from .config import MDiffAEConfig, MDiffAEInferenceConfig
from .model import MDiffAE

__all__ = ["MDiffAE", "MDiffAEConfig", "MDiffAEInferenceConfig"]
m_diffae/adaln.py
ADDED
@@ -0,0 +1,51 @@

"""AdaLN-Zero modules for shared-base + low-rank-delta conditioning."""

from __future__ import annotations

from torch import Tensor, nn


class AdaLNZeroProjector(nn.Module):
    """Shared base AdaLN projection: SiLU -> Linear(d_cond -> 4*d_model).

    Returns packed modulation tensor [B, 4*d_model]. Zero-initialized.
    """

    def __init__(self, d_model: int, d_cond: int) -> None:
        super().__init__()
        self.d_model = int(d_model)
        self.d_cond = int(d_cond)
        self.act = nn.SiLU()
        self.proj = nn.Linear(self.d_cond, 4 * self.d_model)
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, cond: Tensor) -> Tensor:
        """Return packed modulation [B, 4*d_model] from conditioning [B, d_cond]."""
        act = self.act(cond)
        return self.proj(act)

    def forward_activated(self, act_cond: Tensor) -> Tensor:
        """Return packed modulation from pre-activated conditioning."""
        return self.proj(act_cond)


class AdaLNZeroLowRankDelta(nn.Module):
    """Per-layer low-rank delta: down(d_cond -> rank) -> up(rank -> 4*d_model).

    Zero-initialized up-projection preserves AdaLN "zero output" at init.
    """

    def __init__(self, *, d_model: int, d_cond: int, rank: int) -> None:
        super().__init__()
        self.d_model = int(d_model)
        self.d_cond = int(d_cond)
        self.rank = int(rank)
        self.down = nn.Linear(self.d_cond, self.rank, bias=False)
        self.up = nn.Linear(self.rank, 4 * self.d_model, bias=False)
        nn.init.normal_(self.down.weight, mean=0.0, std=0.02)
        nn.init.zeros_(self.up.weight)

    def forward(self, act_cond: Tensor) -> Tensor:
        """Return packed delta modulation [B, 4*d_model] from activated cond."""
        return self.up(self.down(act_cond))
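The point of splitting the conditioning into one shared full projection plus per-layer rank-bottlenecked deltas is parameter amortization. A comparison sketch (not code from the repo) using the decoder's actual dimensions (d_model = d_cond = 896, rank = 128, depth = 4) shows the saving over giving every block its own full AdaLN projection:

```python
d = 896       # d_model == d_cond in the decoder
rank = 128    # adaln_low_rank_rank
depth = 4     # decoder_depth
out = 4 * d   # packed (scale_a, gate_a, scale_m, gate_m) width

# One full AdaLN projection: Linear(d -> 4d) with bias.
full_layer = d * out + out

# Shared scheme: one full base + per-layer low-rank deltas (no bias).
delta_layer = d * rank + rank * out
shared_total = full_layer + depth * delta_layer

independent_total = depth * full_layer
print(independent_total, shared_total)  # 12859392 5508608
```

The shared scheme uses roughly 2.3x fewer conditioning parameters at depth 4, and the gap widens with depth since only the cheap deltas are replicated.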
m_diffae/compact_channel_attention.py
ADDED
@@ -0,0 +1,21 @@

"""Compact Channel Attention (CCA) module."""

from __future__ import annotations

from torch import Tensor, nn


class CompactChannelAttention(nn.Module):
    """Global average pool -> 1x1 Conv2d -> Sigmoid channel gate."""

    def __init__(self, channels: int) -> None:
        super().__init__()
        c = int(channels)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.proj = nn.Conv2d(c, c, kernel_size=1, padding=0, stride=1, bias=True)
        self.act = nn.Sigmoid()

    def forward(self, x: Tensor) -> Tensor:
        w = self.pool(x)
        w = self.proj(w)
        return self.act(w)
m_diffae/config.py
ADDED
@@ -0,0 +1,56 @@

"""Frozen model architecture and user-tunable inference configuration."""

from __future__ import annotations

import json
from dataclasses import asdict, dataclass
from pathlib import Path


@dataclass(frozen=True)
class MDiffAEConfig:
    """Frozen model architecture config. Stored alongside weights as config.json."""

    in_channels: int = 3
    patch_size: int = 16
    model_dim: int = 896
    encoder_depth: int = 4
    decoder_depth: int = 4
    bottleneck_dim: int = 64
    mlp_ratio: float = 4.0
    depthwise_kernel_size: int = 7
    adaln_low_rank_rank: int = 128
    # VP diffusion schedule endpoints
    logsnr_min: float = -10.0
    logsnr_max: float = 10.0
    # Pixel-space noise std for VP diffusion initialization
    pixel_noise_std: float = 0.558
    # Token mask ratio for PDG (fraction of spatial tokens replaced with mask_feature)
    pdg_mask_ratio: float = 0.75

    def save(self, path: str | Path) -> None:
        """Save config as JSON."""
        p = Path(path)
        p.parent.mkdir(parents=True, exist_ok=True)
        p.write_text(json.dumps(asdict(self), indent=2) + "\n")

    @classmethod
    def load(cls, path: str | Path) -> MDiffAEConfig:
        """Load config from JSON."""
        data = json.loads(Path(path).read_text())
        return cls(**data)


@dataclass
class MDiffAEInferenceConfig:
    """User-tunable inference parameters with sensible defaults.

    PDG is very sensitive in mDiffAE; use small strengths (1.05-1.2).
    """

    num_steps: int = 1  # decoder forward passes (NFE)
    sampler: str = "ddim"  # "ddim" or "dpmpp_2m"
    schedule: str = "linear"  # "linear" or "cosine"
    pdg_enabled: bool = False
    pdg_strength: float = 1.1
    seed: int | None = None
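The `save`/`load` pair above is a plain JSON round trip of the frozen dataclass, which is what makes the shipped `config.json` readable by hand and reloadable as typed fields. A minimal stand-in (`ToyConfig` is invented here for illustration, with only two fields) exercises the same stdlib pattern:

```python
import json
import tempfile
from dataclasses import asdict, dataclass
from pathlib import Path

@dataclass(frozen=True)
class ToyConfig:
    # Hypothetical two-field stand-in for the full MDiffAEConfig.
    patch_size: int = 16
    bottleneck_dim: int = 64

with tempfile.TemporaryDirectory() as d:
    path = Path(d) / "config.json"
    # save(): dataclass -> dict -> pretty JSON on disk
    path.write_text(json.dumps(asdict(ToyConfig()), indent=2) + "\n")
    # load(): JSON -> dict -> typed, frozen dataclass
    loaded = ToyConfig(**json.loads(path.read_text()))

print(loaded == ToyConfig())  # True
```

Because the dataclass is frozen and `load` passes keys as keyword arguments, an unknown key in a hand-edited config fails loudly at load time rather than being silently ignored.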
m_diffae/conv_mlp.py
ADDED
@@ -0,0 +1,26 @@

"""Conv-based MLP with GELU activation for DiCo blocks."""

from __future__ import annotations

import torch.nn.functional as F
from torch import Tensor, nn

from .norms import ChannelWiseRMSNorm


class ConvMLP(nn.Module):
    """1x1 Conv-based MLP: RMSNorm -> Conv1x1 -> GELU -> Conv1x1."""

    def __init__(
        self, channels: int, hidden_channels: int, norm_eps: float = 1e-6
    ) -> None:
        super().__init__()
        self.norm = ChannelWiseRMSNorm(channels, eps=norm_eps, affine=False)
        self.conv_in = nn.Conv2d(channels, hidden_channels, kernel_size=1, bias=True)
        self.conv_out = nn.Conv2d(hidden_channels, channels, kernel_size=1, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        y = self.norm(x)
        y = self.conv_in(y)
        y = F.gelu(y)
        return self.conv_out(y)
m_diffae/decoder.py
ADDED
@@ -0,0 +1,198 @@

"""mDiffAE decoder: flat sequential DiCoBlocks with token-level PDG masking."""

from __future__ import annotations

import math

import torch
from torch import Tensor, nn

from .adaln import AdaLNZeroLowRankDelta, AdaLNZeroProjector
from .dico_block import DiCoBlock
from .norms import ChannelWiseRMSNorm
from .straight_through_encoder import Patchify
from .time_embed import SinusoidalTimeEmbeddingMLP


class Decoder(nn.Module):
    """VP diffusion decoder conditioned on encoder latents and timestep.

    Architecture:
        Patchify x_t -> Norm -> Fuse with upsampled z
        -> Blocks (flat sequential, depth blocks) -> Norm -> Conv1x1 -> PixelShuffle

    Token-level PDG: at inference, a fraction of spatial tokens in the fused input
    are replaced with a learned mask_feature before the decoder blocks. Comparing
    the masked vs unmasked outputs provides guidance signal.
    """

    def __init__(
        self,
        in_channels: int,
        patch_size: int,
        model_dim: int,
        depth: int,
        bottleneck_dim: int,
        mlp_ratio: float,
        depthwise_kernel_size: int,
        adaln_low_rank_rank: int,
        pdg_mask_ratio: float = 0.75,
    ) -> None:
        super().__init__()
        self.patch_size = int(patch_size)
        self.model_dim = int(model_dim)
        self.pdg_mask_ratio = float(pdg_mask_ratio)

        # Input processing
        self.patchify = Patchify(in_channels, patch_size, model_dim)
        self.norm_in = ChannelWiseRMSNorm(model_dim, eps=1e-6, affine=True)

        # Latent conditioning path
        self.latent_up = nn.Conv2d(bottleneck_dim, model_dim, kernel_size=1, bias=True)
        self.latent_norm = ChannelWiseRMSNorm(model_dim, eps=1e-6, affine=True)
        self.fuse_in = nn.Conv2d(2 * model_dim, model_dim, kernel_size=1, bias=True)

        # Time embedding
        self.time_embed = SinusoidalTimeEmbeddingMLP(model_dim)

        # AdaLN: shared base projector + per-block low-rank deltas
        self.adaln_base = AdaLNZeroProjector(d_model=model_dim, d_cond=model_dim)
        self.adaln_deltas = nn.ModuleList(
            [
                AdaLNZeroLowRankDelta(
                    d_model=model_dim, d_cond=model_dim, rank=adaln_low_rank_rank
                )
                for _ in range(depth)
            ]
        )

        # Flat sequential blocks (no start/middle/end split, no skip connections)
        self.blocks = nn.ModuleList(
            [
                DiCoBlock(
                    model_dim,
                    mlp_ratio,
                    depthwise_kernel_size=depthwise_kernel_size,
                    use_external_adaln=True,
                )
                for _ in range(depth)
            ]
        )

        # Learned mask feature for token-level PDG
        self.mask_feature = nn.Parameter(torch.zeros((1, model_dim, 1, 1)))

        # Output head
        self.norm_out = ChannelWiseRMSNorm(model_dim, eps=1e-6, affine=True)
        self.out_proj = nn.Conv2d(
            model_dim, in_channels * (patch_size**2), kernel_size=1, bias=True
        )
        self.unpatchify = nn.PixelShuffle(patch_size)

    def _adaln_m_for_layer(self, cond: Tensor, layer_idx: int) -> Tensor:
        """Compute packed AdaLN modulation = shared_base + per-layer delta."""
        act = self.adaln_base.act(cond)
        base_m = self.adaln_base.forward_activated(act)
        delta_m = self.adaln_deltas[layer_idx](act)
        return base_m + delta_m

    def _apply_token_mask(self, fused: Tensor) -> Tensor:
        """Replace a fraction of spatial tokens with mask_feature (2x2 groupwise).

        Divides the spatial grid into 2x2 groups. Within each group, masks
        floor(ratio * 4) tokens deterministically (lowest random scores).

        Args:
            fused: [B, C, H, W] fused decoder input.

        Returns:
            Masked tensor with same shape, where masked positions contain mask_feature.
        """
        b, c, h, w = fused.shape
        # Pad to even dims if needed
        h_pad = (2 - h % 2) % 2
        w_pad = (2 - w % 2) % 2
        if h_pad > 0 or w_pad > 0:
            fused = torch.nn.functional.pad(fused, (0, w_pad, 0, h_pad))
            _, _, h, w = fused.shape

        # Reshape into 2x2 groups: [B, C, H/2, 2, W/2, 2] -> [B, C, H/2, W/2, 4]
        x = fused.reshape(b, c, h // 2, 2, w // 2, 2)
        x = x.permute(0, 1, 2, 4, 3, 5).reshape(b, c, h // 2, w // 2, 4)

        # Random scores for each token in each group
        scores = torch.rand(b, 1, h // 2, w // 2, 4, device=fused.device)

        # Mask the floor(ratio * 4) lowest-scoring tokens per group
        num_mask = math.floor(self.pdg_mask_ratio * 4)
        if num_mask > 0:
            # argsort ascending, mask the first num_mask
            _, indices = scores.sort(dim=-1)
            mask = torch.zeros_like(scores, dtype=torch.bool)
            mask.scatter_(-1, indices[..., :num_mask], True)
        else:
            mask = torch.zeros_like(scores, dtype=torch.bool)

        # Apply mask: replace masked tokens with mask_feature
        mask_feat = self.mask_feature.to(device=fused.device, dtype=fused.dtype)
        mask_feat = mask_feat.squeeze(-1).squeeze(-1)  # [1, C]
        mask_feat = mask_feat.view(1, c, 1, 1, 1).expand_as(x)
        mask_expanded = mask.expand_as(x)
        x = torch.where(mask_expanded, mask_feat, x)

        # Reshape back to [B, C, H, W]
        x = x.reshape(b, c, h // 2, w // 2, 2, 2)
        x = x.permute(0, 1, 2, 4, 3, 5).reshape(b, c, h, w)

        # Remove padding if applied
        if h_pad > 0 or w_pad > 0:
            x = x[:, :, : h - h_pad, : w - w_pad]

        return x

    def forward(
        self,
        x_t: Tensor,
        t: Tensor,
        latents: Tensor,
        *,
        mask_tokens: bool = False,
    ) -> Tensor:
        """Single decoder forward pass.

        Args:
            x_t: Noised image [B, C, H, W].
            t: Timestep [B] in [0, 1].
            latents: Encoder latents [B, bottleneck_dim, h, w].
            mask_tokens: If True, apply token-level masking to decoder input (for PDG).

        Returns:
            x0 prediction [B, C, H, W].
        """
        # Patchify and normalize x_t
        x_feat = self.patchify(x_t)
        x_feat = self.norm_in(x_feat)

        # Upsample and normalize latents, fuse with x_feat
        z_up = self.latent_up(latents)
        z_up = self.latent_norm(z_up)
        fused = torch.cat([x_feat, z_up], dim=1)
        fused = self.fuse_in(fused)

        # Token masking for PDG (replaces tokens with mask_feature)
        if mask_tokens:
            fused = self._apply_token_mask(fused)

        # Time conditioning
        cond = self.time_embed(t.to(torch.float32).to(device=x_t.device))

        # Run all blocks sequentially
        x = fused
        for layer_idx, block in enumerate(self.blocks):
            adaln_m = self._adaln_m_for_layer(cond, layer_idx=layer_idx)
            x = block(x, adaln_m=adaln_m)

        # Output head
        x = self.norm_out(x)
        patches = self.out_proj(x)
        return self.unpatchify(patches)
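The 2x2 groupwise scheme in `_apply_token_mask` gives an exact mask ratio rather than one that merely holds in expectation: every group of 4 neighboring tokens masks floor(0.75 * 4) = 3 of them, so no 2x2 neighborhood is ever fully masked or fully visible. The bookkeeping, in plain Python:

```python
import math

def masked_per_group(ratio: float) -> int:
    # Each 2x2 spatial group masks floor(ratio * 4) of its 4 tokens.
    return math.floor(ratio * 4)

def masked_fraction(h: int, w: int, ratio: float) -> float:
    # Assumes h and w are already even (the decoder pads odd dims first).
    groups = (h // 2) * (w // 2)
    return groups * masked_per_group(ratio) / (h * w)

print(masked_per_group(0.75))         # 3 of 4 tokens per group
print(masked_fraction(16, 16, 0.75))  # 0.75 exactly, for every sample
```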
m_diffae/dico_block.py
ADDED
@@ -0,0 +1,111 @@

"""DiCo block: conv path (1x1 -> depthwise -> SiLU -> CCA -> 1x1) + GELU MLP."""

from __future__ import annotations

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from .compact_channel_attention import CompactChannelAttention
from .conv_mlp import ConvMLP
from .norms import ChannelWiseRMSNorm


class DiCoBlock(nn.Module):
    """DiCo-style conv block with optional external AdaLN conditioning.

    Two modes:
    - Unconditioned (encoder): uses learned per-channel residual gates.
    - External AdaLN (decoder): receives packed modulation [B, 4*C] via adaln_m.
    """

    def __init__(
        self,
        channels: int,
        mlp_ratio: float,
        *,
        depthwise_kernel_size: int = 7,
        use_external_adaln: bool = False,
        norm_eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.channels = int(channels)
        self.use_external_adaln = bool(use_external_adaln)

        # Pre-norm for conv and MLP paths (no affine)
        self.norm1 = ChannelWiseRMSNorm(self.channels, eps=norm_eps, affine=False)
        self.norm2 = ChannelWiseRMSNorm(self.channels, eps=norm_eps, affine=False)

        # Conv path: 1x1 -> depthwise kxk -> SiLU -> CCA -> 1x1
        self.conv1 = nn.Conv2d(self.channels, self.channels, kernel_size=1, bias=True)
        self.conv2 = nn.Conv2d(
            self.channels,
            self.channels,
            kernel_size=depthwise_kernel_size,
            padding=depthwise_kernel_size // 2,
            groups=self.channels,
            bias=True,
        )
        self.conv3 = nn.Conv2d(self.channels, self.channels, kernel_size=1, bias=True)
        self.cca = CompactChannelAttention(self.channels)

        # MLP path: GELU activation
        hidden_channels = max(int(round(float(self.channels) * mlp_ratio)), 1)
        self.mlp = ConvMLP(self.channels, hidden_channels, norm_eps=norm_eps)

        # Conditioning: learned gates (encoder) or external adaln_m (decoder)
        if not self.use_external_adaln:
            self.gate_attn = nn.Parameter(torch.zeros(self.channels))
            self.gate_mlp = nn.Parameter(torch.zeros(self.channels))

    def forward(self, x: Tensor, *, adaln_m: Tensor | None = None) -> Tensor:
        b, c = x.shape[:2]

        if self.use_external_adaln:
            if adaln_m is None:
                raise ValueError(
                    "adaln_m required for externally-conditioned DiCoBlock"
                )
            adaln_m_cast = adaln_m.to(device=x.device, dtype=x.dtype)
            scale_a, gate_a, scale_m, gate_m = adaln_m_cast.chunk(4, dim=-1)
        elif adaln_m is not None:
            raise ValueError("adaln_m must be None for unconditioned DiCoBlock")

        residual = x

        # Conv path
        x_att = self.norm1(x)
        if self.use_external_adaln:
            x_att = x_att * (1.0 + scale_a.view(b, c, 1, 1))  # type: ignore[possibly-undefined]
        y = self.conv1(x_att)
        y = self.conv2(y)
        y = F.silu(y)
        y = y * self.cca(y)
        y = self.conv3(y)

        if self.use_external_adaln:
            gate_a_view = torch.tanh(gate_a).view(b, c, 1, 1)  # type: ignore[possibly-undefined]
            x = residual + gate_a_view * y
        else:
            gate = self.gate_attn.view(1, self.channels, 1, 1).to(
                dtype=y.dtype, device=y.device
            )
            x = residual + gate * y

        # MLP path
        residual_mlp = x
        x_mlp = self.norm2(x)
        if self.use_external_adaln:
            x_mlp = x_mlp * (1.0 + scale_m.view(b, c, 1, 1))  # type: ignore[possibly-undefined]
        y_mlp = self.mlp(x_mlp)

        if self.use_external_adaln:
            gate_m_view = torch.tanh(gate_m).view(b, c, 1, 1)  # type: ignore[possibly-undefined]
            x = residual_mlp + gate_m_view * y_mlp
        else:
            gate = self.gate_mlp.view(1, self.channels, 1, 1).to(
                dtype=y_mlp.dtype, device=y_mlp.device
            )
            x = residual_mlp + gate * y_mlp

        return x
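The external-AdaLN path unpacks the modulation row with `chunk(4, dim=-1)`, so the packed layout produced by `AdaLNZeroProjector` is fixed by position: `scale_a`, `gate_a`, `scale_m`, `gate_m`, each of width C. A toy unpack (C = 3 purely for readability; the real decoder uses model_dim = 896) makes the layout explicit without needing torch:

```python
C = 3  # toy channel count for illustration
packed = list(range(4 * C))  # stands in for one row of the [B, 4*C] adaln_m

# chunk(4, dim=-1) is equivalent to four contiguous slices of width C
scale_a, gate_a, scale_m, gate_m = (packed[i * C:(i + 1) * C] for i in range(4))

print(scale_a, gate_m)  # [0, 1, 2] [9, 10, 11]
```

Keeping one packed tensor per layer means the shared base and the low-rank delta can be added before any splitting, which is exactly what `Decoder._adaln_m_for_layer` does.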
m_diffae/encoder.py
ADDED
@@ -0,0 +1,55 @@

"""mDiffAE encoder: patchify -> DiCoBlocks -> bottleneck projection."""

from __future__ import annotations

from torch import Tensor, nn

from .dico_block import DiCoBlock
from .norms import ChannelWiseRMSNorm
from .straight_through_encoder import Patchify


class Encoder(nn.Module):
    """Deterministic encoder: Image [B,3,H,W] -> latents [B,bottleneck_dim,h,w].

    Pipeline: Patchify -> RMSNorm -> DiCoBlocks (unconditioned) -> Conv1x1 -> RMSNorm(no affine)
    """

    def __init__(
        self,
        in_channels: int,
        patch_size: int,
        model_dim: int,
        depth: int,
        bottleneck_dim: int,
        mlp_ratio: float,
        depthwise_kernel_size: int,
    ) -> None:
        super().__init__()
        self.patchify = Patchify(in_channels, patch_size, model_dim)
        self.norm_in = ChannelWiseRMSNorm(model_dim, eps=1e-6, affine=True)
        self.blocks = nn.ModuleList(
            [
                DiCoBlock(
                    model_dim,
                    mlp_ratio,
                    depthwise_kernel_size=depthwise_kernel_size,
                    use_external_adaln=False,
                )
                for _ in range(depth)
            ]
        )
        self.to_bottleneck = nn.Conv2d(
            model_dim, bottleneck_dim, kernel_size=1, bias=True
        )
        self.norm_out = ChannelWiseRMSNorm(bottleneck_dim, eps=1e-6, affine=False)

    def forward(self, images: Tensor) -> Tensor:
        """Encode images [B,3,H,W] in [-1,1] to latents [B,bottleneck_dim,h,w]."""
        z = self.patchify(images)
        z = self.norm_in(z)
        for block in self.blocks:
            z = block(z)
        z = self.to_bottleneck(z)
        z = self.norm_out(z)
        return z
m_diffae/model.py ADDED
@@ -0,0 +1,294 @@

"""MDiffAE: standalone HuggingFace-compatible mDiffAE model."""

from __future__ import annotations

from pathlib import Path

import torch
from torch import Tensor, nn

from .config import MDiffAEConfig, MDiffAEInferenceConfig
from .decoder import Decoder
from .encoder import Encoder
from .samplers import run_ddim, run_dpmpp_2m
from .vp_diffusion import get_schedule, make_initial_state, sample_noise


def _resolve_model_dir(
    path_or_repo_id: str | Path,
    *,
    revision: str | None,
    cache_dir: str | Path | None,
) -> Path:
    """Resolve a local path or HuggingFace Hub repo ID to a local directory."""

    local = Path(path_or_repo_id)
    if local.is_dir():
        return local
    # Not a local directory -- try HuggingFace Hub
    repo_id = str(path_or_repo_id)
    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        raise ImportError(
            f"'{repo_id}' is not an existing local directory. "
            "To download from HuggingFace Hub, install huggingface_hub: "
            "pip install huggingface_hub"
        )
    cache_dir_str = str(cache_dir) if cache_dir is not None else None
    local_dir = snapshot_download(
        repo_id,
        revision=revision,
        cache_dir=cache_dir_str,
    )
    return Path(local_dir)


class MDiffAE(nn.Module):
    """Standalone mDiffAE model for HuggingFace distribution.

    A masked diffusion autoencoder that encodes images to compact latents and
    decodes them back via iterative VP diffusion. Uses a flat 4-block decoder
    with token-level masking for PDG instead of the skip-concat + block-drop
    approach of iRDiffAE.

    Usage::

        model = MDiffAE.from_pretrained("data-archetype/mdiffae-v1")
        model = model.to("cuda", dtype=torch.bfloat16)

        # Encode
        latents = model.encode(images)  # images: [B,3,H,W] in [-1,1]

        # Decode (1 step by default -- PSNR-optimal)
        recon = model.decode(latents, height=H, width=W)

        # Reconstruct (encode + 1-step decode)
        recon = model.reconstruct(images)
    """

    def __init__(self, config: MDiffAEConfig) -> None:
        super().__init__()
        self.config = config

        self.encoder = Encoder(
            in_channels=config.in_channels,
            patch_size=config.patch_size,
            model_dim=config.model_dim,
            depth=config.encoder_depth,
            bottleneck_dim=config.bottleneck_dim,
            mlp_ratio=config.mlp_ratio,
            depthwise_kernel_size=config.depthwise_kernel_size,
        )

        self.decoder = Decoder(
            in_channels=config.in_channels,
            patch_size=config.patch_size,
            model_dim=config.model_dim,
            depth=config.decoder_depth,
            bottleneck_dim=config.bottleneck_dim,
            mlp_ratio=config.mlp_ratio,
            depthwise_kernel_size=config.depthwise_kernel_size,
            adaln_low_rank_rank=config.adaln_low_rank_rank,
            pdg_mask_ratio=config.pdg_mask_ratio,
        )

    @classmethod
    def from_pretrained(
        cls,
        path_or_repo_id: str | Path,
        *,
        dtype: torch.dtype = torch.bfloat16,
        device: str | torch.device = "cpu",
        revision: str | None = None,
        cache_dir: str | Path | None = None,
    ) -> MDiffAE:
        """Load a pretrained model from a local directory or HuggingFace Hub.

        The directory (or repo) should contain:
        - config.json: Model architecture config.
        - model.safetensors (preferred) or model.pt: Model weights.

        Args:
            path_or_repo_id: Local directory path or HuggingFace Hub repo ID
                (e.g. ``"data-archetype/mdiffae-v1"``).
            dtype: Load weights in this dtype (float32 or bfloat16).
            device: Target device.
            revision: Git revision (branch, tag, or commit) for Hub downloads.
            cache_dir: Where to cache Hub downloads. Uses HF default if None.

        Returns:
            Loaded model in eval mode.
        """
        model_dir = _resolve_model_dir(
            path_or_repo_id, revision=revision, cache_dir=cache_dir
        )
        config = MDiffAEConfig.load(model_dir / "config.json")
        model = cls(config)

        # Try safetensors first, fall back to .pt
        safetensors_path = model_dir / "model.safetensors"
        pt_path = model_dir / "model.pt"

        if safetensors_path.exists():
            try:
                from safetensors.torch import load_file

                state_dict = load_file(str(safetensors_path), device=str(device))
            except ImportError:
                raise ImportError(
                    "safetensors package required to load .safetensors files. "
                    "Install with: pip install safetensors"
                )
        elif pt_path.exists():
            state_dict = torch.load(
                str(pt_path), map_location=device, weights_only=True
            )
        else:
            raise FileNotFoundError(
                f"No model weights found in {model_dir}. "
                "Expected model.safetensors or model.pt."
            )

        model.load_state_dict(state_dict)
        model = model.to(dtype=dtype, device=torch.device(device))
        model.eval()
        return model

    def encode(self, images: Tensor) -> Tensor:
        """Encode images to latents.

        Args:
            images: [B, 3, H, W] in [-1, 1], H and W must be divisible by patch_size.

        Returns:
            Latents [B, bottleneck_dim, H/patch, W/patch].
        """
        try:
            model_dtype = next(self.parameters()).dtype
        except StopIteration:
            model_dtype = torch.float32
        return self.encoder(images.to(dtype=model_dtype))

    @torch.no_grad()
    def decode(
        self,
        latents: Tensor,
        height: int,
        width: int,
        *,
        inference_config: MDiffAEInferenceConfig | None = None,
    ) -> Tensor:
        """Decode latents to images via VP diffusion.

        Args:
            latents: [B, bottleneck_dim, h, w] encoder latents.
            height: Output image height (must be divisible by patch_size).
            width: Output image width (must be divisible by patch_size).
            inference_config: Optional inference parameters. Uses defaults if None.

        Returns:
            Reconstructed images [B, 3, H, W] in float32.
        """
        cfg = inference_config or MDiffAEInferenceConfig()
        config = self.config
        batch = int(latents.shape[0])
        device = latents.device

        # Determine model dtype from parameters
        try:
            model_dtype = next(self.parameters()).dtype
        except StopIteration:
            model_dtype = torch.float32

        # Validate dimensions
        if height % config.patch_size != 0 or width % config.patch_size != 0:
            raise ValueError(
                f"height={height} and width={width} must be divisible by patch_size={config.patch_size}"
            )

        # Generate initial noise
        shape = (batch, config.in_channels, height, width)
        noise = sample_noise(
            shape,
            noise_std=config.pixel_noise_std,
            seed=cfg.seed,
            device=torch.device("cpu"),
            dtype=torch.float32,
        )

        # Build schedule
        schedule = get_schedule(cfg.schedule, cfg.num_steps).to(device=device)

        # Construct initial state: sigma_start * noise
        initial_state = make_initial_state(
            noise=noise.to(device=device),
            t_start=schedule[0:1],
            logsnr_min=config.logsnr_min,
            logsnr_max=config.logsnr_max,
        )

        # Disable autocast for numerical precision
        device_type = "cuda" if device.type == "cuda" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            latents_in = latents.to(device=device)

            def _forward_fn(
                x_t: Tensor,
                t: Tensor,
                latents: Tensor,
                *,
                mask_tokens: bool = False,
            ) -> Tensor:
                return self.decoder(
                    x_t.to(dtype=model_dtype),
                    t,
                    latents.to(dtype=model_dtype),
                    mask_tokens=mask_tokens,
                )

            # Select sampler
            if cfg.sampler == "ddim":
                sampler_fn = run_ddim
            elif cfg.sampler == "dpmpp_2m":
                sampler_fn = run_dpmpp_2m
            else:
                raise ValueError(
                    f"Unsupported sampler: {cfg.sampler!r}. Use 'ddim' or 'dpmpp_2m'."
                )

            result = sampler_fn(
                forward_fn=_forward_fn,
                initial_state=initial_state,
                schedule=schedule,
                latents=latents_in,
                logsnr_min=config.logsnr_min,
                logsnr_max=config.logsnr_max,
                pdg_enabled=cfg.pdg_enabled,
                pdg_strength=cfg.pdg_strength,
                device=device,
            )

        return result

    @torch.no_grad()
    def reconstruct(
        self,
        images: Tensor,
        *,
        inference_config: MDiffAEInferenceConfig | None = None,
    ) -> Tensor:
        """Encode then decode. Convenience wrapper.

        Args:
            images: [B, 3, H, W] in [-1, 1].
            inference_config: Optional inference parameters.

        Returns:
            Reconstructed images [B, 3, H, W] in float32.
        """
        latents = self.encode(images)
        _, _, h, w = images.shape
        return self.decode(
            latents, height=h, width=w, inference_config=inference_config
        )
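Both `encode` and `decode` enforce the same divisibility rule: image height and width must be multiples of `patch_size`, and the latent grid is the image size divided by the patch. A pure-Python sketch of that shape rule (the helper name `latent_grid` and the patch size 16 are illustrative; the real value comes from config.json):

```python
def latent_grid(height, width, patch_size):
    """Latent spatial size produced by the encoder for a given image size."""
    if height % patch_size != 0 or width % patch_size != 0:
        raise ValueError(
            f"height={height} and width={width} must be divisible by patch_size={patch_size}"
        )
    return height // patch_size, width // patch_size

print(latent_grid(256, 192, 16))  # -> (16, 12)
```

So a 256x192 image with a 16-pixel patch yields a 16x12 latent grid, and e.g. a 250-pixel height is rejected before any compute is spent.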
m_diffae/norms.py ADDED
@@ -0,0 +1,39 @@

"""Channel-wise RMSNorm for NCHW tensors."""

from __future__ import annotations

import torch
from torch import Tensor, nn


class ChannelWiseRMSNorm(nn.Module):
    """Channel-wise RMSNorm with float32 reduction for numerical stability.

    Normalizes across channels per spatial position. Supports optional
    per-channel affine weight and bias.
    """

    def __init__(self, channels: int, eps: float = 1e-6, affine: bool = True) -> None:
        super().__init__()
        self.channels: int = int(channels)
        self._eps: float = float(eps)
        if affine:
            self.weight = nn.Parameter(torch.ones(self.channels))
            self.bias = nn.Parameter(torch.zeros(self.channels))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

    def forward(self, x: Tensor) -> Tensor:
        if x.dim() < 2:
            return x
        # Float32 accumulation for stability
        ms = torch.mean(torch.square(x), dim=1, keepdim=True, dtype=torch.float32)
        inv_rms = torch.rsqrt(ms + self._eps)
        y = x * inv_rms
        shape = (1, -1) + (1,) * (x.dim() - 2)
        if self.weight is not None:
            y = y * self.weight.view(shape).to(dtype=y.dtype)
        if self.bias is not None:
            y = y + self.bias.view(shape).to(dtype=y.dtype)
        return y.to(dtype=x.dtype)
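The normalization itself is simple: at each spatial position, divide the channel vector by its root-mean-square. A pure-Python sketch of that math (the function name `channel_rms_norm` is mine; it mirrors the forward pass without the affine terms):

```python
import math

def channel_rms_norm(channels, eps=1e-6):
    """Normalize the channel vector at one spatial position to unit RMS."""
    ms = sum(v * v for v in channels) / len(channels)  # mean of squares
    inv_rms = 1.0 / math.sqrt(ms + eps)
    return [v * inv_rms for v in channels]

y = channel_rms_norm([3.0, 4.0])
```

After normalization the mean of squares of `y` is 1 (up to `eps`), which is what makes the affine-free variant on the encoder bottleneck act as a fixed-scale latent constraint.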
m_diffae/samplers.py ADDED
@@ -0,0 +1,226 @@

"""DDIM and DPM++2M samplers for VP diffusion with x-prediction objective."""

from __future__ import annotations

from typing import Protocol

import torch
from torch import Tensor

from .vp_diffusion import (
    alpha_sigma_from_logsnr,
    broadcast_time_like,
    shifted_cosine_interpolated_logsnr_from_t,
)


class DecoderForwardFn(Protocol):
    """Callable that predicts x0 from (x_t, t, latents)."""

    def __call__(
        self,
        x_t: Tensor,
        t: Tensor,
        latents: Tensor,
        *,
        mask_tokens: bool = False,
    ) -> Tensor: ...


def _reconstruct_eps_from_x0(
    *, x_t: Tensor, x0_hat: Tensor, alpha: Tensor, sigma: Tensor
) -> Tensor:
    """Reconstruct eps_hat from (x_t, x0_hat) under VP parameterization.

    eps_hat = (x_t - alpha * x0_hat) / sigma. All float32.
    """
    alpha_view = broadcast_time_like(alpha, x_t).to(dtype=torch.float32)
    sigma_view = broadcast_time_like(sigma, x_t).to(dtype=torch.float32)
    x_t_f32 = x_t.to(torch.float32)
    x0_f32 = x0_hat.to(torch.float32)
    return (x_t_f32 - alpha_view * x0_f32) / sigma_view


def _ddim_step(
    *,
    x0_hat: Tensor,
    eps_hat: Tensor,
    alpha_next: Tensor,
    sigma_next: Tensor,
    ref: Tensor,
) -> Tensor:
    """DDIM step: x_next = alpha_next * x0_hat + sigma_next * eps_hat."""
    a = broadcast_time_like(alpha_next, ref).to(dtype=torch.float32)
    s = broadcast_time_like(sigma_next, ref).to(dtype=torch.float32)
    return a * x0_hat + s * eps_hat


def run_ddim(
    *,
    forward_fn: DecoderForwardFn,
    initial_state: Tensor,
    schedule: Tensor,
    latents: Tensor,
    logsnr_min: float,
    logsnr_max: float,
    log_change_high: float = 0.0,
    log_change_low: float = 0.0,
    pdg_enabled: bool = False,
    pdg_strength: float = 1.1,
    device: torch.device | None = None,
) -> Tensor:
    """Run DDIM sampling loop.

    Args:
        forward_fn: Decoder forward function (x_t, t, latents) -> x0_hat.
        initial_state: Starting noised state [B, C, H, W] in float32.
        schedule: Descending t-schedule [num_steps] in [0, 1].
        latents: Encoder latents [B, bottleneck_dim, h, w].
        logsnr_min, logsnr_max: VP schedule endpoints.
        log_change_high, log_change_low: Shifted-cosine schedule parameters.
        pdg_enabled: Whether to use token-level Path-Drop Guidance.
        pdg_strength: CFG-like strength for PDG (use small values: 1.05-1.2).
        device: Target device.

    Returns:
        Denoised samples [B, C, H, W] in float32.
    """
    run_device = device or initial_state.device
    batch_size = int(initial_state.shape[0])
    state = initial_state.to(device=run_device, dtype=torch.float32)

    # Precompute logSNR, alpha, sigma for all schedule points
    lmb = shifted_cosine_interpolated_logsnr_from_t(
        schedule.to(device=run_device),
        logsnr_min=logsnr_min,
        logsnr_max=logsnr_max,
        log_change_high=log_change_high,
        log_change_low=log_change_low,
    )
    alpha_sched, sigma_sched = alpha_sigma_from_logsnr(lmb)

    for i in range(int(schedule.numel()) - 1):
        t_i = schedule[i]
        a_t = alpha_sched[i].expand(batch_size)
        s_t = sigma_sched[i].expand(batch_size)
        a_next = alpha_sched[i + 1].expand(batch_size)
        s_next = sigma_sched[i + 1].expand(batch_size)

        # Model prediction
        t_vec = t_i.expand(batch_size).to(device=run_device, dtype=torch.float32)
        if pdg_enabled:
            x0_uncond = forward_fn(state, t_vec, latents, mask_tokens=True).to(
                torch.float32
            )
            x0_cond = forward_fn(state, t_vec, latents, mask_tokens=False).to(
                torch.float32
            )
            x0_hat = x0_uncond + pdg_strength * (x0_cond - x0_uncond)
        else:
            x0_hat = forward_fn(state, t_vec, latents, mask_tokens=False).to(
                torch.float32
            )

        eps_hat = _reconstruct_eps_from_x0(
            x_t=state, x0_hat=x0_hat, alpha=a_t, sigma=s_t
        )
        state = _ddim_step(
            x0_hat=x0_hat,
            eps_hat=eps_hat,
            alpha_next=a_next,
            sigma_next=s_next,
            ref=state,
        )

    return state


def run_dpmpp_2m(
    *,
    forward_fn: DecoderForwardFn,
    initial_state: Tensor,
    schedule: Tensor,
    latents: Tensor,
    logsnr_min: float,
    logsnr_max: float,
    log_change_high: float = 0.0,
    log_change_low: float = 0.0,
    pdg_enabled: bool = False,
    pdg_strength: float = 1.1,
    device: torch.device | None = None,
) -> Tensor:
    """Run DPM++2M sampling loop.

    Multi-step solver using exponential integrator formulation in half-lambda space.
    """
    run_device = device or initial_state.device
    batch_size = int(initial_state.shape[0])
    state = initial_state.to(device=run_device, dtype=torch.float32)

    # Precompute logSNR, alpha, sigma, half-lambda for all schedule points
    lmb = shifted_cosine_interpolated_logsnr_from_t(
        schedule.to(device=run_device),
        logsnr_min=logsnr_min,
        logsnr_max=logsnr_max,
        log_change_high=log_change_high,
        log_change_low=log_change_low,
    )
    alpha_sched, sigma_sched = alpha_sigma_from_logsnr(lmb)
    half_lambda = 0.5 * lmb.to(torch.float32)

    x0_prev: Tensor | None = None

    for i in range(int(schedule.numel()) - 1):
        t_i = schedule[i]
        s_t = sigma_sched[i].expand(batch_size)
        a_next = alpha_sched[i + 1].expand(batch_size)
        s_next = sigma_sched[i + 1].expand(batch_size)

        # Model prediction
        t_vec = t_i.expand(batch_size).to(device=run_device, dtype=torch.float32)
        if pdg_enabled:
            x0_uncond = forward_fn(state, t_vec, latents, mask_tokens=True).to(
                torch.float32
            )
            x0_cond = forward_fn(state, t_vec, latents, mask_tokens=False).to(
                torch.float32
            )
            x0_hat = x0_uncond + pdg_strength * (x0_cond - x0_uncond)
        else:
            x0_hat = forward_fn(state, t_vec, latents, mask_tokens=False).to(
                torch.float32
            )

        lam_t = half_lambda[i].expand(batch_size)
        lam_next = half_lambda[i + 1].expand(batch_size)
        h = (lam_next - lam_t).to(torch.float32)
        phi_1 = torch.expm1(-h)

        sigma_ratio = (s_next / s_t).to(torch.float32)

        if i == 0 or x0_prev is None:
            # First-order step
            state = (
                sigma_ratio.view(-1, *([1] * (state.dim() - 1))) * state
                - broadcast_time_like(a_next, state).to(torch.float32)
                * broadcast_time_like(phi_1, state).to(torch.float32)
                * x0_hat
            )
        else:
            # Second-order step
            lam_prev = half_lambda[i - 1].expand(batch_size)
            h_0 = (lam_t - lam_prev).to(torch.float32)
            r0 = h_0 / h
            d1_0 = (x0_hat - x0_prev) / broadcast_time_like(r0, x0_hat)
            common = broadcast_time_like(a_next, state).to(
                torch.float32
            ) * broadcast_time_like(phi_1, state).to(torch.float32)
            state = (
                sigma_ratio.view(-1, *([1] * (state.dim() - 1))) * state
                - common * x0_hat
                - 0.5 * common * d1_0
            )

        x0_prev = x0_hat

    return state
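The core DDIM identity is worth seeing with concrete numbers: given a perfect x0 prediction, reconstructing eps from (x_t, x0) and re-noising at the next (alpha, sigma) lands exactly on the next point of the same diffusion trajectory. A scalar sketch under the VP parameterization (values 0.7, -1.2 and the logSNRs 2.0, 4.0 are arbitrary illustrations):

```python
import math

def alpha_sigma(logsnr):
    # VP parameterization: alpha = sqrt(sigmoid(logsnr)), sigma = sqrt(sigmoid(-logsnr))
    return (math.sqrt(1.0 / (1.0 + math.exp(-logsnr))),
            math.sqrt(1.0 / (1.0 + math.exp(logsnr))))

x0, eps = 0.7, -1.2                 # hypothetical clean pixel and its noise
a_t, s_t = alpha_sigma(2.0)         # current step
x_t = a_t * x0 + s_t * eps          # forward-noised state
eps_hat = (x_t - a_t * x0) / s_t    # as in _reconstruct_eps_from_x0
a_n, s_n = alpha_sigma(4.0)         # next, less noisy step
x_next = a_n * x0 + s_n * eps_hat   # as in _ddim_step
```

With an exact x0, `eps_hat` recovers `eps` and `x_next` equals `a_n * x0 + s_n * eps`; sampler error therefore comes entirely from the decoder's x0 prediction, not from the update rule.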
m_diffae/straight_through_encoder.py ADDED
@@ -0,0 +1,27 @@

"""PixelUnshuffle-based patchifier (no residual conv path)."""

from __future__ import annotations

from torch import Tensor, nn


class Patchify(nn.Module):
    """PixelUnshuffle(patch) -> Conv2d 1x1 projection.

    Converts [B, C, H, W] images into [B, out_channels, H/patch, W/patch] features.
    """

    def __init__(self, in_channels: int, patch: int, out_channels: int) -> None:
        super().__init__()
        self.patch = int(patch)
        self.unshuffle = nn.PixelUnshuffle(self.patch)
        in_after = in_channels * (self.patch * self.patch)
        self.proj = nn.Conv2d(in_after, out_channels, kernel_size=1, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        if x.shape[2] % self.patch != 0 or x.shape[3] % self.patch != 0:
            raise ValueError(
                f"Input H={x.shape[2]} and W={x.shape[3]} must be divisible by patch={self.patch}"
            )
        y = self.unshuffle(x)
        return self.proj(y)
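`nn.PixelUnshuffle` is a pure rearrangement: each non-overlapping `p x p` block of pixels is folded into `p*p` new channels at a single spatial position, so no information is lost before the 1x1 projection. A pure-Python sketch on nested lists (helper name mine; it follows the channel ordering `c*p*p + dy*p + dx` that torch's PixelUnshuffle uses):

```python
def pixel_unshuffle(x, p):
    """x: [C][H][W] nested lists -> [C*p*p][H//p][W//p]."""
    c_in, h, w = len(x), len(x[0]), len(x[0][0])
    out = []
    for c in range(c_in):
        for dy in range(p):
            for dx in range(p):
                # Output channel (c, dy, dx) collects one offset within each p x p block
                out.append(
                    [[x[c][i * p + dy][j * p + dx] for j in range(w // p)]
                     for i in range(h // p)]
                )
    return out

print(pixel_unshuffle([[[1, 2], [3, 4]]], 2))  # -> [[[1]], [[2]], [[3]], [[4]]]
```

A single 2x2 block thus becomes four 1x1 channels, which is why `in_after = in_channels * patch * patch` in the module above.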
m_diffae/time_embed.py ADDED
@@ -0,0 +1,83 @@

"""Sinusoidal timestep embedding with MLP projection."""

from __future__ import annotations

import math

import torch
from torch import Tensor, nn


def _log_spaced_frequencies(
    half: int, max_period: float, *, device: torch.device | None = None
) -> Tensor:
    """Log-spaced frequencies for sinusoidal embedding."""
    return torch.exp(
        -math.log(max_period)
        * torch.arange(half, device=device, dtype=torch.float32)
        / max(float(half - 1), 1.0)
    )


def sinusoidal_time_embedding(
    t: Tensor,
    dim: int,
    *,
    max_period: float = 10000.0,
    scale: float | None = None,
    freqs: Tensor | None = None,
) -> Tensor:
    """Sinusoidal timestep embedding (DDPM/DiT-style). Always float32."""
    t32 = t.to(torch.float32)
    if scale is not None:
        t32 = t32 * float(scale)
    half = dim // 2
    if freqs is not None:
        freqs = freqs.to(device=t32.device, dtype=torch.float32)
    else:
        freqs = _log_spaced_frequencies(half, max_period, device=t32.device)
    angles = t32[:, None] * freqs[None, :]
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)


class SinusoidalTimeEmbeddingMLP(nn.Module):
    """Sinusoidal time embedding followed by Linear -> SiLU -> Linear."""

    def __init__(
        self,
        dim: int,
        *,
        freq_dim: int = 256,
        hidden_mult: float = 1.0,
        time_scale: float = 1000.0,
        max_period: float = 10000.0,
    ) -> None:
        super().__init__()
        self.dim = int(dim)
        self.freq_dim = int(freq_dim)
        self.time_scale = float(time_scale)
        self.max_period = float(max_period)
        hidden_dim = max(int(round(int(dim) * float(hidden_mult))), 1)

        freqs = _log_spaced_frequencies(self.freq_dim // 2, self.max_period)
        self.register_buffer("freqs", freqs, persistent=True)

        self.proj_in = nn.Linear(self.freq_dim, hidden_dim)
        self.act = nn.SiLU()
        self.proj_out = nn.Linear(hidden_dim, self.dim)

    def forward(self, t: Tensor) -> Tensor:
        freqs: Tensor = self.freqs  # type: ignore[assignment]
        emb_freq = sinusoidal_time_embedding(
            t.to(torch.float32),
            self.freq_dim,
            max_period=self.max_period,
            scale=self.time_scale,
            freqs=freqs,
        )
        dtype_in = self.proj_in.weight.dtype
        hidden = self.proj_in(emb_freq.to(dtype_in))
        hidden = self.act(hidden)
        if hidden.dtype != self.proj_out.weight.dtype:
            hidden = hidden.to(self.proj_out.weight.dtype)
        return self.proj_out(hidden)
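The embedding itself is the classic DDPM/DiT recipe: scale the timestep, multiply by log-spaced frequencies, then concatenate sines and cosines. A pure-Python sketch for a single scalar timestep (function name mine, defaults mirror the module's `time_scale=1000`, `max_period=10000`):

```python
import math

def sinusoidal_embed(t, dim, max_period=10000.0, scale=1000.0):
    """Return [sin(t*scale*f_i) for each f_i] + [cos(...)] with log-spaced f_i in (0, 1]."""
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / max(half - 1, 1))
             for i in range(half)]
    return ([math.sin(t * scale * f) for f in freqs]
            + [math.cos(t * scale * f) for f in freqs])

emb = sinusoidal_embed(0.0, 8)
print(emb)  # -> [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]
```

At `t=0` every sine is 0 and every cosine is 1, and distinct timesteps produce distinct phase patterns across the frequency bank, which the Linear -> SiLU -> Linear head then maps into the model dimension.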
m_diffae/vp_diffusion.py
ADDED
|
@@ -0,0 +1,151 @@
"""VP diffusion math: logSNR schedules, alpha/sigma computation, noise construction."""

from __future__ import annotations

import math

import torch
from torch import Tensor


def alpha_sigma_from_logsnr(lmb: Tensor) -> tuple[Tensor, Tensor]:
    """Compute (alpha, sigma) from logSNR in float32.

    VP constraint: alpha^2 + sigma^2 = 1.
    """
    lmb32 = lmb.to(dtype=torch.float32)
    alpha = torch.sqrt(torch.sigmoid(lmb32))
    sigma = torch.sqrt(torch.sigmoid(-lmb32))
    return alpha, sigma


def broadcast_time_like(coeff: Tensor, x: Tensor) -> Tensor:
    """Broadcast a [B] coefficient to match x for per-sample scaling."""
    view_shape = (int(x.shape[0]),) + (1,) * (x.dim() - 1)
    return coeff.view(view_shape)


def _cosine_interpolated_params(
    logsnr_min: float, logsnr_max: float
) -> tuple[float, float]:
    """Compute (a, b) for the cosine-interpolated logSNR schedule.

    logsnr(t) = -2 * log(tan(a*t + b))
    logsnr(0) = logsnr_max, logsnr(1) = logsnr_min
    """
    b = math.atan(math.exp(-0.5 * logsnr_max))
    a = math.atan(math.exp(-0.5 * logsnr_min)) - b
    return a, b


def cosine_interpolated_logsnr_from_t(
    t: Tensor, *, logsnr_min: float, logsnr_max: float
) -> Tensor:
    """Map t in [0, 1] to logSNR via the cosine-interpolated schedule. Always float32."""
    a, b = _cosine_interpolated_params(logsnr_min, logsnr_max)
    t32 = t.to(dtype=torch.float32)
    a_t = torch.tensor(a, device=t32.device, dtype=torch.float32)
    b_t = torch.tensor(b, device=t32.device, dtype=torch.float32)
    u = a_t * t32 + b_t
    return -2.0 * torch.log(torch.tan(u))


def shifted_cosine_interpolated_logsnr_from_t(
    t: Tensor,
    *,
    logsnr_min: float,
    logsnr_max: float,
    log_change_high: float = 0.0,
    log_change_low: float = 0.0,
) -> Tensor:
    """SiD2 "shifted cosine" schedule: logSNR with resolution-dependent shifts.

    lambda(t) = (1-t) * (base(t) + log_change_high) + t * (base(t) + log_change_low)
    """
    base = cosine_interpolated_logsnr_from_t(
        t, logsnr_min=logsnr_min, logsnr_max=logsnr_max
    )
    t32 = t.to(dtype=torch.float32)
    high = base + float(log_change_high)
    low = base + float(log_change_low)
    return (1.0 - t32) * high + t32 * low


def get_schedule(schedule_type: str, num_steps: int) -> Tensor:
    """Generate a descending t-schedule in [0, 1] for VP diffusion sampling.

    ``num_steps`` is the number of function evaluations (NFE = decoder forward
    passes). Internally the schedule has ``num_steps + 1`` time points
    (including both endpoints).

    Args:
        schedule_type: "linear" or "cosine".
        num_steps: Number of decoder forward passes (NFE), >= 1.

    Returns:
        Descending 1D tensor with ``num_steps + 1`` elements from ~1.0 to ~0.0.
    """
    # NOTE: the upstream training code (src/ode/time_schedules.py) uses a
    # different convention where num_steps counts schedule *points* (so NFE =
    # num_steps - 1). This export package corrects the off-by-one so that
    # num_steps means NFE directly. TODO: align the upstream convention.
    n = max(int(num_steps) + 1, 2)
    if schedule_type == "linear":
        base = torch.linspace(0.0, 1.0, n)
    elif schedule_type == "cosine":
        i = torch.arange(n, dtype=torch.float32)
        base = 0.5 * (1.0 - torch.cos(math.pi * (i / (n - 1))))
    else:
        raise ValueError(
            f"Unsupported schedule type: {schedule_type!r}. Use 'linear' or 'cosine'."
        )
    # Descending: high t (noisy) -> low t (clean)
    return torch.flip(base, dims=[0])


def make_initial_state(
    *,
    noise: Tensor,
    t_start: Tensor,
    logsnr_min: float,
    logsnr_max: float,
    log_change_high: float = 0.0,
    log_change_low: float = 0.0,
) -> Tensor:
    """Construct the VP initial state x_t0 = sigma_start * noise (since x0 = 0).

    All math in float32.
    """
    batch = int(noise.shape[0])
    lmb_start = shifted_cosine_interpolated_logsnr_from_t(
        t_start.expand(batch).to(dtype=torch.float32),
        logsnr_min=logsnr_min,
        logsnr_max=logsnr_max,
        log_change_high=log_change_high,
        log_change_low=log_change_low,
    )
    _alpha_start, sigma_start = alpha_sigma_from_logsnr(lmb_start)
    sigma_view = broadcast_time_like(sigma_start, noise)
    return sigma_view * noise.to(dtype=torch.float32)


def sample_noise(
    shape: tuple[int, ...],
    *,
    noise_std: float = 1.0,
    seed: int | None = None,
    device: torch.device | None = None,
    dtype: torch.dtype = torch.float32,
) -> Tensor:
    """Sample Gaussian noise with optional seeding. CPU-seeded for reproducibility."""
    if seed is None:
        noise = torch.randn(
            shape, device=device or torch.device("cpu"), dtype=torch.float32
        )
    else:
        gen = torch.Generator(device="cpu")
        gen.manual_seed(int(seed))
        noise = torch.randn(shape, generator=gen, device="cpu", dtype=torch.float32)
    noise = noise.mul(float(noise_std))
    target_device = device if device is not None else torch.device("cpu")
    return noise.to(device=target_device, dtype=dtype)
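The VP constraint and the endpoint pinning of the cosine-interpolated schedule above can be verified with a scalar re-derivation using only the standard library (independent of the package code):

```python
import math


def alpha_sigma(logsnr: float) -> tuple[float, float]:
    # VP parameterization: alpha = sqrt(sigmoid(logsnr)), sigma = sqrt(sigmoid(-logsnr)).
    alpha = math.sqrt(1.0 / (1.0 + math.exp(-logsnr)))
    sigma = math.sqrt(1.0 / (1.0 + math.exp(logsnr)))
    return alpha, sigma


def cosine_logsnr(t: float, logsnr_min: float = -10.0, logsnr_max: float = 10.0) -> float:
    # logsnr(t) = -2 * log(tan(a*t + b)), pinned so logsnr(0) = max, logsnr(1) = min.
    b = math.atan(math.exp(-0.5 * logsnr_max))
    a = math.atan(math.exp(-0.5 * logsnr_min)) - b
    return -2.0 * math.log(math.tan(a * t + b))


for lam in (-10.0, 0.0, 10.0):
    a_, s_ = alpha_sigma(lam)
    assert abs(a_ * a_ + s_ * s_ - 1.0) < 1e-12  # alpha^2 + sigma^2 = 1

assert abs(cosine_logsnr(0.0) - 10.0) < 1e-6    # logsnr(0) = logsnr_max
assert abs(cosine_logsnr(1.0) + 10.0) < 1e-6    # logsnr(1) = logsnr_min
```

With symmetric bounds the schedule is antisymmetric about t = 0.5, so the midpoint logSNR is exactly zero.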
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:196bd7a0607c495c0b72bb8c306c66d871792ea3d3f120afa14115ba3ca8e7ae
size 325656824
technical_report_mdiffae.md
ADDED
|
@@ -0,0 +1,211 @@
# mDiffAE: Masked Diffusion AutoEncoder Technical Report

**Version 1**, March 2026

## 1. Introduction

mDiffAE (**M**asked **Diff**usion **A**uto**E**ncoder) builds on the [iRDiffAE](https://huggingface.co/data-archetype/irdiffae-v1/blob/main/technical_report.md) model family, which provides a fast, single-GPU-trainable diffusion autoencoder platform capable of high-quality image reconstruction. See that report for full background on the shared components: VP diffusion math (logSNR schedules, alpha/sigma, x-prediction), DiCo block architecture (depthwise conv + compact channel attention + GELU MLP), patchify encoder (PixelUnshuffle + 1×1 conv), shared-base + low-rank AdaLN-Zero conditioning, and the Path-Drop Guidance (PDG) concept.

The iRDiffAE platform is designed to make it easy to experiment with different ways of regularizing the latent space. iRDiffAE v1 used REPA (aligning encoder features with a frozen DINOv2 teacher), which produces well-structured latents but tends toward overly smooth representations that are hard to reconcile with fine detail. Here we take a different approach entirely: **decoder token masking**.

### 1.1 Token Masking as Regularizer

For 50% of samples per batch, the decoder sees only **25% of tokens** in the fused conditioning input. The spatial token grid is divided into non-overlapping 2×2 groups, and within each group a single token is randomly kept while the other three are replaced with a learned mask feature. Hiding such a large fraction (75%) pushes the encoder to learn a form of representation consistency: each spatial token must carry enough information to support reconstruction even when most of its neighbors are absent. A smaller masking fraction helps downstream models learn sharp details quickly, but they fail to learn spatial coherence nearly as well; the task becomes too close to local inpainting, and the encoder is not pressured into globally consistent representations. The importance of a high masking ratio echoes findings in the masked autoencoder literature (He et al., 2022); we tested lower ratios and confirmed this tradeoff empirically.

The 50% per-sample application probability is the knob that controls the compromise between reconstruction quality and latent space quality: samples that receive masking push the encoder toward consistent representations, while unmasked samples maintain reconstruction fidelity.

### 1.2 Latent Noise Regularization

To further regularize the latent space, we retain the random latent noising mechanism, applied to 10% of samples. However, unlike the pixel-space diffusion noise, the latent noise level is sampled independently from a stratified **Beta(2,2)** distribution, with a **logSNR shift of +1.0** that biases it toward low noise levels (low *t* in our convention). This decouples the latent regularization schedule from the decoder's diffusion schedule, providing a gentle push toward noise-robust representations without disrupting reconstruction training.
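As an illustrative sketch (the function name, the stratification details, and the exact point at which the +1.0 shift enters are assumptions, not the training code), the latent noise level could be drawn as:

```python
import math

import torch


def latent_noise_logsnr(batch_size: int, *, logsnr_min: float = -10.0,
                        logsnr_max: float = 10.0, shift: float = 1.0) -> torch.Tensor:
    # Hypothetical helper: Beta(2,2) noise times, sampled independently of the
    # decoder's pixel-space schedule, mapped through the cosine-interpolated
    # logSNR curve and shifted by +1.0 toward low noise levels.
    t = torch.distributions.Beta(2.0, 2.0).sample((batch_size,))
    b = math.atan(math.exp(-0.5 * logsnr_max))
    a = math.atan(math.exp(-0.5 * logsnr_min)) - b
    return -2.0 * torch.log(torch.tan(a * t + b)) + shift
```

With symmetric logSNR bounds the unshifted curve is zero-mean under the symmetric Beta(2,2) draw, so the +1.0 shift directly raises the average latent logSNR (i.e., lowers the average noise level).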

### 1.3 Simplified Decoder

To keep the representational pressure on the encoder, we restrict the decoder to only **4 blocks** (down from 8 in iRDiffAE v1) and simplify it to a flat sequential architecture: no start/middle/end block groups, no skip connections. This halves the decoder's parameter count and makes it roughly 2× faster, while forcing the encoder to compensate by producing more informative latents.

### 1.4 Empirical Results

Compared to the REPA-regularized iRDiffAE v1, mDiffAE achieves slightly higher PSNR (to be confirmed with final benchmarks, though initial results were quite decisive) and produces a latent space whose PCA is less oversmoothed yet very locally consistent. In downstream diffusion model training, mDiffAE's latent space does not exhibit the very steep initial loss descent seen with iRDiffAE, but it catches up after 50k-100k training steps, producing more spatially coherent images earlier with better high-frequency detail.

### 1.5 References

- He, K., Chen, X., Xie, S., Li, Y., Dollár, P., & Girshick, R. (2022). *Masked Autoencoders Are Scalable Vision Learners*. CVPR 2022.
- Li, T., Chang, H., Mishra, S.K., Zhang, H., Katabi, D., & Krishnan, D. (2023). *MAGE: MAsked Generative Encoder to Unify Representation Learning and Image Synthesis*. CVPR 2023.

## 2. Architecture Differences from iRDiffAE

| Aspect | iRDiffAE v1 (halfrepa c128) | mDiffAE v1 (masked c64) |
|--------|------------------------------|--------------------------|
| Bottleneck dim | 128 | **64** |
| Decoder depth | 8 (2 start + 4 middle + 2 end) | **4 (flat sequential)** |
| Decoder topology | START_MIDDLE_END_SKIP_CONCAT | **FLAT (no skip concat)** |
| Skip fusion | Yes (`fuse_skip` Conv1×1) | **No** |
| PDG mechanism | Drop middle blocks → mask_feature | **Token-level masking** (75% of spatial tokens → mask_feature) |
| PDG sensitivity | Moderate (strength 1.5-3.0) | **Very sensitive** (strength 1.05-1.2 only) |
| Training regularizer | REPA (half-channel DINOv2 alignment) + covreg | **Decoder token masking** (75% ratio, 50% apply prob) |
| Latent noise reg | Same mechanism | **Independent Beta(2,2), logSNR shift +1.0, 10% prob** |
| Depthwise kernel | 7×7 | 7×7 (same) |
| Model dim | 896 | 896 (same) |
| Encoder depth | 4 | 4 (same) |
| Best decode | 1-step DDIM | 1-step DDIM (same) |

## 3. Training-Time Masking Details

### 3.1 Token Masking Procedure

During training, with 50% probability per sample:

1. The fused decoder input (patchified x_t + upsampled encoder latents) is divided into non-overlapping 2×2 spatial groups.
2. Within each group, 3 of 4 tokens (75%) are selected for masking using random scoring.
3. Masked tokens are replaced with a learned `mask_feature` parameter (same dimensionality as `model_dim`).
4. The decoder processes the partially masked input normally through all blocks.

### 3.2 Inference-Time PDG via Token Masking

At inference, the trained `mask_feature` enables Path-Drop Guidance (PDG) through token-level masking rather than block-level dropping:

- **Conditional pass**: full decoder input (no masking)
- **Unconditional pass**: 2×2 groupwise token masking applied at the trained ratio (75%)
- **Guided output**: `x0 = x0_uncond + strength * (x0_cond - x0_uncond)`

Because the decoder has only 4 blocks and no skip connections, the guidance signal from token masking is very concentrated. This makes PDG extremely sensitive: even a strength of 1.2 produces noticeable sharpening, and values above 1.5 cause severe artifacts.
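The guided combination is the usual CFG-style extrapolation; as a sketch (the function name is illustrative, not the package API):

```python
import torch


def pdg_guided_x0(x0_cond: torch.Tensor, x0_uncond: torch.Tensor,
                  strength: float = 1.1) -> torch.Tensor:
    # strength = 1.0 recovers the plain conditional prediction; values slightly
    # above 1.0 extrapolate away from the token-masked (unconditional) estimate.
    return x0_uncond + strength * (x0_cond - x0_uncond)
```

Given the sensitivity noted above, strengths in the 1.05-1.2 range are the practical ceiling for this model.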

## 4. Flat Decoder Architecture

### 4.1 iRDiffAE v1 Decoder (for comparison)

The iRDiffAE v1 decoder uses an 8-block layout split into three groups with a skip connection:

```
Fused input → Start blocks (2) → [save for skip] →
Middle blocks (4) → [cat with saved skip] → FuseSkip Conv1x1 →
End blocks (2) → Output head
```

The skip connection concatenates the start-block output with the middle-block output and fuses them through a learned Conv1×1 before feeding into the end blocks. For PDG, the entire middle-block computation is dropped and replaced with a broadcasted learned `mask_feature`, effectively removing all 4 middle blocks from the forward pass. This produces a coarse "unconditional" signal for classifier-free guidance.

### 4.2 mDiffAE v1 Decoder

The mDiffAE decoder replaces this with a flat sequential architecture, with no block groups and no skip connection:

```
Input: x_t [B, 3, H, W], t [B], z [B, 64, h, w]

Patchify(x_t) → RMSNorm → x_feat [B, 896, h, w]
LatentUp(z) → RMSNorm → z_up [B, 896, h, w]
FuseIn(cat(x_feat, z_up)) → fused [B, 896, h, w]
[Optional: token masking for PDG]
TimeEmbed(t) → cond [B, 896]
Block_0(fused, AdaLN(cond)) → ...
Block_1(..., AdaLN(cond)) → ...
Block_2(..., AdaLN(cond)) → ...
Block_3(..., AdaLN(cond)) → out [B, 896, h, w]
RMSNorm → Conv1x1 → PixelShuffle → x0_hat [B, 3, H, W]
```

With only 4 blocks and no skip fusion layer, the decoder has roughly half the parameters of iRDiffAE's decoder. The `fuse_skip` Conv1×1 layer is eliminated entirely. For PDG, instead of dropping blocks, 75% of spatial tokens in the fused input are replaced with a learned `mask_feature` before the blocks run. This token-level masking provides a finer-grained guidance signal but is much more sensitive to strength: the decoder runs the full block computation in both the conditional and unconditional paths, so the difference between them is subtle.

### 4.3 Bottleneck

The bottleneck dimension is halved from 128 channels (iRDiffAE) to 64 channels, giving a 12× compression ratio at patch size 16 (vs 6× for iRDiffAE). Despite the higher compression, the masking regularizer forces the encoder to produce informative per-token representations, maintaining reconstruction quality.
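The compression figures follow directly from the patch geometry, assuming 3 input channels:

```python
# Values per spatial token before and after encoding, for a patch-16 tokenizer.
patch_size, in_channels = 16, 3
pixels_per_token = patch_size * patch_size * in_channels  # 768 raw values

for name, bottleneck_dim in [("mDiffAE c64", 64), ("iRDiffAE c128", 128)]:
    ratio = pixels_per_token // bottleneck_dim
    print(f"{name}: {pixels_per_token} -> {bottleneck_dim} values per token ({ratio}x)")
```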

## 5. Model Configuration

| Parameter | Value |
|-----------|-------|
| `in_channels` | 3 |
| `patch_size` | 16 |
| `model_dim` | 896 |
| `encoder_depth` | 4 |
| `decoder_depth` | 4 |
| `bottleneck_dim` | 64 |
| `mlp_ratio` | 4.0 |
| `depthwise_kernel_size` | 7 |
| `adaln_low_rank_rank` | 128 |
| `logsnr_min` | -10.0 |
| `logsnr_max` | 10.0 |
| `pixel_noise_std` | 0.558 |
| `pdg_mask_ratio` | 0.75 |

Training checkpoint: step 708,000 (EMA weights).

## 6. Inference Recommendations

| Setting | Value | Notes |
|---------|-------|-------|
| Sampler | DDIM | Best for 1-step decoding |
| Steps | 1 | PSNR-optimal |
| PDG | Disabled | Default, safest |
| PDG strength | 1.05-1.2 | If enabled; very sensitive |

## 7. Results

Reconstruction quality is evaluated on a curated set of test images covering photographs, book covers, and documents. The Flux.1 VAE (patch 8, 16 channels) is included as a reference at the same 12× compression ratio as the c64 variant; Flux.2 VAE numbers are included for additional context.

### 7.1 Interactive Viewer

**[Open full-resolution comparison viewer](https://huggingface.co/spaces/data-archetype/mdiffae-results)**: side-by-side reconstructions, RGB deltas, and latent PCA with adjustable image size.

### 7.2 Inference Settings

| Setting | Value |
|---------|-------|
| Sampler | ddim |
| Steps | 1 |
| Schedule | linear |
| Seed | 42 |
| PDG | no_path_dropg |
| Batch size (timing) | 8 |

> All models run in bfloat16. Timings measured on an NVIDIA RTX Pro 6000 (Blackwell).

### 7.3 Global Metrics

| Metric | mdiffae_v1 (1 step) | Flux.1 VAE | Flux.2 VAE |
|--------|--------|--------|--------|
| Avg PSNR (dB) | 31.89 | 32.76 | 34.16 |
| Avg encode (ms/image) | 2.4 | 63.9 | 45.7 |
| Avg decode (ms/image) | 3.0 | 138.2 | 92.8 |

### 7.4 Per-Image PSNR (dB)

| Image | mdiffae_v1 (1 step) | Flux.1 VAE | Flux.2 VAE |
|-------|--------|--------|--------|
| p640x1536:94623 | 31.20 | 31.28 | 33.50 |
| p640x1536:94624 | 27.32 | 27.62 | 30.03 |
| p640x1536:94625 | 30.68 | 31.65 | 33.98 |
| p640x1536:94626 | 29.14 | 29.44 | 31.53 |
| p640x1536:94627 | 29.63 | 28.70 | 30.53 |
| p640x1536:94628 | 25.60 | 26.38 | 28.88 |
| p960x1024:216264 | 44.50 | 40.87 | 45.39 |
| p960x1024:216265 | 26.42 | 25.82 | 27.80 |
| p960x1024:216266 | 44.90 | 47.77 | 46.20 |
| p960x1024:216267 | 37.78 | 37.65 | 39.23 |
| p960x1024:216268 | 36.15 | 35.27 | 36.13 |
| p960x1024:216269 | 29.37 | 28.45 | 30.24 |
| p960x1024:216270 | 32.43 | 31.92 | 34.18 |
| p960x1024:216271 | 41.23 | 38.92 | 42.18 |
| p704x1472:94699 | 41.88 | 40.43 | 41.79 |
| p704x1472:94700 | 29.66 | 29.52 | 32.08 |
| p704x1472:94701 | 35.14 | 35.43 | 37.90 |
| p704x1472:94702 | 30.90 | 30.73 | 32.50 |
| p704x1472:94703 | 28.65 | 29.08 | 31.35 |
| p704x1472:94704 | 28.98 | 29.22 | 31.84 |
| p704x1472:94705 | 36.09 | 36.38 | 37.44 |
| p704x1472:94706 | 31.53 | 31.50 | 33.66 |
| r256_p1344x704:15577 | 27.89 | 28.32 | 29.98 |
| r256_p1344x704:15578 | 28.07 | 29.35 | 30.79 |
| r256_p1344x704:15579 | 29.56 | 30.44 | 31.83 |
| r256_p1344x704:15580 | 32.89 | 36.12 | 36.03 |
| r256_p1344x704:15581 | 32.26 | 37.42 | 36.94 |
| r256_p1344x704:15582 | 28.74 | 30.64 | 32.10 |
| r256_p1344x704:15583 | 31.99 | 34.67 | 34.54 |
| r256_p1344x704:15584 | 28.42 | 30.34 | 31.76 |
| r256_p896x1152:144131 | 30.02 | 33.10 | 33.60 |
| r256_p896x1152:144132 | 33.19 | 34.23 | 35.32 |
| r256_p896x1152:144133 | 35.42 | 37.85 | 37.33 |
| r256_p896x1152:144134 | 31.41 | 34.25 | 34.47 |
| r256_p896x1152:144135 | 27.13 | 28.17 | 29.87 |
| r256_p896x1152:144136 | 32.75 | 35.24 | 35.68 |
| r256_p896x1152:144137 | 28.60 | 32.70 | 32.86 |
| r256_p896x1152:144138 | 24.76 | 24.15 | 25.63 |
| VAE_accuracy_test_image | 31.52 | 36.69 | 35.25 |