data-archetype committed on
Commit
128cb34
·
verified ·
1 Parent(s): 3700176

Upload folder using huggingface_hub

README.md ADDED
@@ -0,0 +1,144 @@
+ ---
+ license: apache-2.0
+ tags:
+ - diffusion
+ - autoencoder
+ - image-reconstruction
+ - pytorch
+ - masked-autoencoder
+ library_name: mdiffae
+ ---
+
+ # mdiffae_v1
+
+ **mDiffAE**: a **M**asked **Diff**usion **A**uto**E**ncoder.
+ A fast, single-GPU-trainable diffusion autoencoder with a **64-channel**
+ spatial bottleneck. Uses decoder token masking as an implicit regularizer instead of
+ REPA alignment, achieving Flux.2-level conditioning quality with a simpler
+ flat decoder architecture (4 blocks, no skip connections).
+
+ This variant (mdiffae_v1): 81.4M parameters, 310.6 MB.
+ Bottleneck: **64 channels** at patch size 16
+ (compression ratio 12x).
+
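The quoted compression ratio follows directly from the patch size and bottleneck width; a quick arithmetic check (plain Python, assuming an illustrative 256x256 RGB input):

```python
# Sanity-check the 12x compression ratio for a 256x256 RGB input.
# With patch size 16 and a 64-channel bottleneck, the latent grid is 16x16.
H = W = 256
patch = 16
bottleneck = 64

pixel_values = 3 * H * W  # values per image in pixel space
latent_values = bottleneck * (H // patch) * (W // patch)  # values in latent space
ratio = pixel_values / latent_values
print(ratio)  # 12.0
```

The ratio is resolution-independent, since both counts scale with H * W.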
+ ## Documentation
+
+ - [Technical Report](technical_report_mdiffae.md): architecture, masking strategy, and results
+ - [iRDiffAE Technical Report](https://huggingface.co/data-archetype/irdiffae-v1/blob/main/technical_report.md): full background on VP diffusion, DiCo blocks, the patchify encoder, and AdaLN
+ - [Results (interactive viewer)](https://huggingface.co/spaces/data-archetype/mdiffae-results): full-resolution side-by-side comparison
+
+ ## Quick Start
+
+ ```python
+ import torch
+ from m_diffae import MDiffAE
+
+ # Load from HuggingFace Hub (or a local path)
+ model = MDiffAE.from_pretrained("data-archetype/mdiffae_v1", device="cuda")
+
+ # Encode
+ images = ...  # [B, 3, H, W] in [-1, 1], H and W divisible by 16
+ latents = model.encode(images)
+
+ # Decode (1 step by default; PSNR-optimal)
+ H, W = images.shape[-2:]
+ recon = model.decode(latents, height=H, width=W)
+
+ # Reconstruct (encode + 1-step decode)
+ recon = model.reconstruct(images)
+ ```
+
+ > **Note:** Hub downloads require `pip install huggingface_hub safetensors`.
+ > You can also pass a local directory path to `from_pretrained()`.
+
+ ## Architecture
+
+ | Property | Value |
+ |---|---|
+ | Parameters | 81,410,624 |
+ | File size | 310.6 MB |
+ | Patch size | 16 |
+ | Model dim | 896 |
+ | Encoder depth | 4 |
+ | Decoder depth | 4 |
+ | Decoder topology | Flat sequential (no skip connections) |
+ | Bottleneck dim | 64 |
+ | MLP ratio | 4.0 |
+ | Depthwise kernel | 7 |
+ | AdaLN rank | 128 |
+ | PDG mechanism | Token-level masking (ratio 0.75) |
+ | Training regularizer | Decoder token masking (75% ratio, 50% apply prob) |
+
+ **Encoder**: Deterministic. Patchify (PixelUnshuffle + 1x1 conv) followed by
+ DiCo blocks (depthwise conv + compact channel attention + GELU MLP) with
+ learned residual gates.
+
+ **Decoder**: VP diffusion conditioned on encoder latents and timestep via
+ shared-base + per-layer low-rank AdaLN-Zero. 4 flat
+ sequential blocks (no skip connections). Supports token-level Path-Drop
+ Guidance (PDG) at inference; it is very sensitive, so use small strengths only.
+
+ **Compared to iRDiffAE's decoder**: iRDiffAE uses an 8-block decoder split into
+ start (2), middle (4), and end (2) groups, with a skip connection that concatenates
+ the start-block output with the middle-block output and fuses them through a Conv1x1 before
+ the end blocks. Its PDG works by dropping the entire middle-block computation and
+ replacing it with a learned mask feature. In contrast, mDiffAE uses a simple flat
+ stack of 4 blocks with no skip connections or block groups.
+ PDG instead works at the token level: 75% of spatial tokens in the fused decoder
+ input are replaced with a learned mask feature, providing a much finer-grained
+ guidance signal. The bottleneck is also halved from 128 to 64
+ channels, giving a 12x compression ratio vs iRDiffAE's 6x.
+
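The recommended PDG strengths just above 1.0 suggest a classifier-free-guidance-style combination of the full and token-masked decoder predictions. The exact combination rule is not shown in this upload, so the sketch below assumes the standard CFG form; `pdg_combine`, `x0_full`, and `x0_masked` are illustrative names, not package API:

```python
import torch


def pdg_combine(
    x0_full: torch.Tensor, x0_masked: torch.Tensor, strength: float
) -> torch.Tensor:
    """Hypothetical CFG-style PDG combination: extrapolate away from the
    token-masked ("weak") prediction. strength == 1.0 recovers x0_full."""
    return x0_masked + strength * (x0_full - x0_masked)


x0_full = torch.zeros(1, 3, 32, 32)
x0_masked = torch.ones(1, 3, 32, 32)
out = pdg_combine(x0_full, x0_masked, strength=1.0)
assert torch.allclose(out, x0_full)  # strength 1.0 is a no-op
```

Under this form, a strength of 1.1 pushes the output only 10% beyond the full prediction, consistent with the "very sensitive" guidance described above.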
+ ### Key Differences from iRDiffAE
+
+ | Aspect | iRDiffAE v1 | mDiffAE v1 |
+ |---|---|---|
+ | Bottleneck dim | 128 | **64** |
+ | Decoder depth | 8 (2+4+2 skip-concat) | **4 (flat sequential)** |
+ | PDG mechanism | Block dropping | **Token masking** |
+ | Training regularizer | REPA + covariance reg | **Decoder token masking** |
+ | PDG sensitivity | Moderate (1.5–3.0) | **Very sensitive (1.05–1.2)** |
+
+ ## Recommended Settings
+
+ Best quality is achieved with just **1 DDIM step** and PDG disabled,
+ making inference extremely fast.
+
+ PDG in mDiffAE is **very sensitive**: use tiny strengths (1.05–1.2)
+ if enabled. Higher values will cause artifacts.
+
+ | Setting | Default |
+ |---|---|
+ | Sampler | DDIM |
+ | Steps | 1 |
+ | PDG | Disabled |
+ | PDG strength (if enabled) | 1.1 |
+
+ ```python
+ from m_diffae import MDiffAEInferenceConfig
+
+ # PSNR-optimal (fast, 1 step)
+ cfg = MDiffAEInferenceConfig(num_steps=1, sampler="ddim")
+ recon = model.decode(latents, height=H, width=W, inference_config=cfg)
+ ```
+
+ ## Citation
+
+ ```bibtex
+ @misc{m_diffae,
+   title = {mDiffAE: A Masked Diffusion Autoencoder with Flat Decoder and Token-Level Guidance},
+   author = {data-archetype},
+   year = {2026},
+   month = mar,
+   url = {https://huggingface.co/data-archetype/mdiffae_v1},
+ }
+ ```
+
+ ## Dependencies
+
+ - PyTorch >= 2.0
+ - safetensors (for loading weights)
+
+ ## License
+
+ Apache 2.0
config.json ADDED
@@ -0,0 +1,15 @@
+ {
+   "in_channels": 3,
+   "patch_size": 16,
+   "model_dim": 896,
+   "encoder_depth": 4,
+   "decoder_depth": 4,
+   "bottleneck_dim": 64,
+   "mlp_ratio": 4.0,
+   "depthwise_kernel_size": 7,
+   "adaln_low_rank_rank": 128,
+   "logsnr_min": -10.0,
+   "logsnr_max": 10.0,
+   "pixel_noise_std": 0.558,
+   "pdg_mask_ratio": 0.75
+ }
m_diffae/__init__.py ADDED
@@ -0,0 +1,28 @@
+ """mDiffAE: Standalone diffusion autoencoder for HuggingFace distribution.
+
+ Masked Diffusion AutoEncoder: a compact diffusion autoencoder
+ that encodes images to spatial latents and decodes via iterative VP diffusion.
+ Uses decoder token masking as an implicit regularizer instead of REPA alignment,
+ with a flat 4-block decoder (no skip connections).
+
+ Usage::
+
+     from m_diffae import MDiffAE, MDiffAEInferenceConfig
+
+     model = MDiffAE.from_pretrained("path/to/weights", device="cuda")
+
+     # Encode
+     latents = model.encode(images)  # images: [B,3,H,W] in [-1,1]
+
+     # Decode: PSNR-optimal (1 step, default)
+     recon = model.decode(latents, height=H, width=W)
+
+     # Decode: perceptual sharpness (10 steps + PDG)
+     cfg = MDiffAEInferenceConfig(num_steps=10, sampler="ddim", pdg_enabled=True)
+     recon = model.decode(latents, height=H, width=W, inference_config=cfg)
+ """
+
+ from .config import MDiffAEConfig, MDiffAEInferenceConfig
+ from .model import MDiffAE
+
+ __all__ = ["MDiffAE", "MDiffAEConfig", "MDiffAEInferenceConfig"]
m_diffae/adaln.py ADDED
@@ -0,0 +1,51 @@
+ """AdaLN-Zero modules for shared-base + low-rank-delta conditioning."""
+
+ from __future__ import annotations
+
+ from torch import Tensor, nn
+
+
+ class AdaLNZeroProjector(nn.Module):
+     """Shared base AdaLN projection: SiLU -> Linear(d_cond -> 4*d_model).
+
+     Returns packed modulation tensor [B, 4*d_model]. Zero-initialized.
+     """
+
+     def __init__(self, d_model: int, d_cond: int) -> None:
+         super().__init__()
+         self.d_model = int(d_model)
+         self.d_cond = int(d_cond)
+         self.act = nn.SiLU()
+         self.proj = nn.Linear(self.d_cond, 4 * self.d_model)
+         nn.init.zeros_(self.proj.weight)
+         nn.init.zeros_(self.proj.bias)
+
+     def forward(self, cond: Tensor) -> Tensor:
+         """Return packed modulation [B, 4*d_model] from conditioning [B, d_cond]."""
+         act = self.act(cond)
+         return self.proj(act)
+
+     def forward_activated(self, act_cond: Tensor) -> Tensor:
+         """Return packed modulation from pre-activated conditioning."""
+         return self.proj(act_cond)
+
+
+ class AdaLNZeroLowRankDelta(nn.Module):
+     """Per-layer low-rank delta: down(d_cond -> rank) -> up(rank -> 4*d_model).
+
+     Zero-initialized up-projection preserves AdaLN "zero output" at init.
+     """
+
+     def __init__(self, *, d_model: int, d_cond: int, rank: int) -> None:
+         super().__init__()
+         self.d_model = int(d_model)
+         self.d_cond = int(d_cond)
+         self.rank = int(rank)
+         self.down = nn.Linear(self.d_cond, self.rank, bias=False)
+         self.up = nn.Linear(self.rank, 4 * self.d_model, bias=False)
+         nn.init.normal_(self.down.weight, mean=0.0, std=0.02)
+         nn.init.zeros_(self.up.weight)
+
+     def forward(self, act_cond: Tensor) -> Tensor:
+         """Return packed delta modulation [B, 4*d_model] from activated cond."""
+         return self.up(self.down(act_cond))
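The zero-initialized projections above mean the packed modulation is exactly zero at initialization, so AdaLN-Zero scaling and gating start as a no-op. A minimal standalone check of that property (plain torch, not importing the package; dimensions are illustrative):

```python
import torch
from torch import nn

d_model, d_cond, rank = 8, 8, 4

# Shared base: zero-initialized Linear, as in AdaLNZeroProjector
base = nn.Linear(d_cond, 4 * d_model)
nn.init.zeros_(base.weight)
nn.init.zeros_(base.bias)

# Low-rank delta: normal-init down, zero-init up, as in AdaLNZeroLowRankDelta
down = nn.Linear(d_cond, rank, bias=False)
up = nn.Linear(rank, 4 * d_model, bias=False)
nn.init.normal_(down.weight, mean=0.0, std=0.02)
nn.init.zeros_(up.weight)

cond = torch.randn(2, d_cond)
act = torch.nn.functional.silu(cond)
m = base(act) + up(down(act))  # base + delta, as combined in the decoder
assert m.shape == (2, 4 * d_model)
assert torch.all(m == 0)  # zero modulation at init -> identity AdaLN
```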
m_diffae/compact_channel_attention.py ADDED
@@ -0,0 +1,21 @@
+ """Compact Channel Attention (CCA) module."""
+
+ from __future__ import annotations
+
+ from torch import Tensor, nn
+
+
+ class CompactChannelAttention(nn.Module):
+     """Global average pool -> 1x1 Conv2d -> Sigmoid channel gate."""
+
+     def __init__(self, channels: int) -> None:
+         super().__init__()
+         c = int(channels)
+         self.pool = nn.AdaptiveAvgPool2d(1)
+         self.proj = nn.Conv2d(c, c, kernel_size=1, padding=0, stride=1, bias=True)
+         self.act = nn.Sigmoid()
+
+     def forward(self, x: Tensor) -> Tensor:
+         w = self.pool(x)
+         w = self.proj(w)
+         return self.act(w)
m_diffae/config.py ADDED
@@ -0,0 +1,56 @@
+ """Frozen model architecture and user-tunable inference configuration."""
+
+ from __future__ import annotations
+
+ import json
+ from dataclasses import asdict, dataclass
+ from pathlib import Path
+
+
+ @dataclass(frozen=True)
+ class MDiffAEConfig:
+     """Frozen model architecture config. Stored alongside weights as config.json."""
+
+     in_channels: int = 3
+     patch_size: int = 16
+     model_dim: int = 896
+     encoder_depth: int = 4
+     decoder_depth: int = 4
+     bottleneck_dim: int = 64
+     mlp_ratio: float = 4.0
+     depthwise_kernel_size: int = 7
+     adaln_low_rank_rank: int = 128
+     # VP diffusion schedule endpoints
+     logsnr_min: float = -10.0
+     logsnr_max: float = 10.0
+     # Pixel-space noise std for VP diffusion initialization
+     pixel_noise_std: float = 0.558
+     # Token mask ratio for PDG (fraction of spatial tokens replaced with mask_feature)
+     pdg_mask_ratio: float = 0.75
+
+     def save(self, path: str | Path) -> None:
+         """Save config as JSON."""
+         p = Path(path)
+         p.parent.mkdir(parents=True, exist_ok=True)
+         p.write_text(json.dumps(asdict(self), indent=2) + "\n")
+
+     @classmethod
+     def load(cls, path: str | Path) -> MDiffAEConfig:
+         """Load config from JSON."""
+         data = json.loads(Path(path).read_text())
+         return cls(**data)
+
+
+ @dataclass
+ class MDiffAEInferenceConfig:
+     """User-tunable inference parameters with sensible defaults.
+
+     PDG is very sensitive in mDiffAE; use small strengths (1.05–1.2).
+     """
+
+     num_steps: int = 1  # decoder forward passes (NFE)
+     sampler: str = "ddim"  # "ddim" or "dpmpp_2m"
+     schedule: str = "linear"  # "linear" or "cosine"
+     pdg_enabled: bool = False
+     pdg_strength: float = 1.1
+     seed: int | None = None
m_diffae/conv_mlp.py ADDED
@@ -0,0 +1,26 @@
+ """Conv-based MLP with GELU activation for DiCo blocks."""
+
+ from __future__ import annotations
+
+ import torch.nn.functional as F
+ from torch import Tensor, nn
+
+ from .norms import ChannelWiseRMSNorm
+
+
+ class ConvMLP(nn.Module):
+     """1x1 Conv-based MLP: RMSNorm -> Conv1x1 -> GELU -> Conv1x1."""
+
+     def __init__(
+         self, channels: int, hidden_channels: int, norm_eps: float = 1e-6
+     ) -> None:
+         super().__init__()
+         self.norm = ChannelWiseRMSNorm(channels, eps=norm_eps, affine=False)
+         self.conv_in = nn.Conv2d(channels, hidden_channels, kernel_size=1, bias=True)
+         self.conv_out = nn.Conv2d(hidden_channels, channels, kernel_size=1, bias=True)
+
+     def forward(self, x: Tensor) -> Tensor:
+         y = self.norm(x)
+         y = self.conv_in(y)
+         y = F.gelu(y)
+         return self.conv_out(y)
m_diffae/decoder.py ADDED
@@ -0,0 +1,198 @@
+ """mDiffAE decoder: flat sequential DiCoBlocks with token-level PDG masking."""
+
+ from __future__ import annotations
+
+ import math
+
+ import torch
+ from torch import Tensor, nn
+
+ from .adaln import AdaLNZeroLowRankDelta, AdaLNZeroProjector
+ from .dico_block import DiCoBlock
+ from .norms import ChannelWiseRMSNorm
+ from .straight_through_encoder import Patchify
+ from .time_embed import SinusoidalTimeEmbeddingMLP
+
+
+ class Decoder(nn.Module):
+     """VP diffusion decoder conditioned on encoder latents and timestep.
+
+     Architecture:
+         Patchify x_t -> Norm -> Fuse with upsampled z
+         -> Blocks (flat sequential, depth blocks) -> Norm -> Conv1x1 -> PixelShuffle
+
+     Token-level PDG: at inference, a fraction of spatial tokens in the fused input
+     are replaced with a learned mask_feature before the decoder blocks. Comparing
+     the masked vs unmasked outputs provides the guidance signal.
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         patch_size: int,
+         model_dim: int,
+         depth: int,
+         bottleneck_dim: int,
+         mlp_ratio: float,
+         depthwise_kernel_size: int,
+         adaln_low_rank_rank: int,
+         pdg_mask_ratio: float = 0.75,
+     ) -> None:
+         super().__init__()
+         self.patch_size = int(patch_size)
+         self.model_dim = int(model_dim)
+         self.pdg_mask_ratio = float(pdg_mask_ratio)
+
+         # Input processing
+         self.patchify = Patchify(in_channels, patch_size, model_dim)
+         self.norm_in = ChannelWiseRMSNorm(model_dim, eps=1e-6, affine=True)
+
+         # Latent conditioning path
+         self.latent_up = nn.Conv2d(bottleneck_dim, model_dim, kernel_size=1, bias=True)
+         self.latent_norm = ChannelWiseRMSNorm(model_dim, eps=1e-6, affine=True)
+         self.fuse_in = nn.Conv2d(2 * model_dim, model_dim, kernel_size=1, bias=True)
+
+         # Time embedding
+         self.time_embed = SinusoidalTimeEmbeddingMLP(model_dim)
+
+         # AdaLN: shared base projector + per-block low-rank deltas
+         self.adaln_base = AdaLNZeroProjector(d_model=model_dim, d_cond=model_dim)
+         self.adaln_deltas = nn.ModuleList(
+             [
+                 AdaLNZeroLowRankDelta(
+                     d_model=model_dim, d_cond=model_dim, rank=adaln_low_rank_rank
+                 )
+                 for _ in range(depth)
+             ]
+         )
+
+         # Flat sequential blocks (no start/middle/end split, no skip connections)
+         self.blocks = nn.ModuleList(
+             [
+                 DiCoBlock(
+                     model_dim,
+                     mlp_ratio,
+                     depthwise_kernel_size=depthwise_kernel_size,
+                     use_external_adaln=True,
+                 )
+                 for _ in range(depth)
+             ]
+         )
+
+         # Learned mask feature for token-level PDG
+         self.mask_feature = nn.Parameter(torch.zeros((1, model_dim, 1, 1)))
+
+         # Output head
+         self.norm_out = ChannelWiseRMSNorm(model_dim, eps=1e-6, affine=True)
+         self.out_proj = nn.Conv2d(
+             model_dim, in_channels * (patch_size**2), kernel_size=1, bias=True
+         )
+         self.unpatchify = nn.PixelShuffle(patch_size)
+
+     def _adaln_m_for_layer(self, cond: Tensor, layer_idx: int) -> Tensor:
+         """Compute packed AdaLN modulation = shared_base + per-layer delta."""
+         act = self.adaln_base.act(cond)
+         base_m = self.adaln_base.forward_activated(act)
+         delta_m = self.adaln_deltas[layer_idx](act)
+         return base_m + delta_m
+
+     def _apply_token_mask(self, fused: Tensor) -> Tensor:
+         """Replace a fraction of spatial tokens with mask_feature (2x2 groupwise).
+
+         Divides the spatial grid into 2x2 groups. Within each group, masks
+         floor(ratio * 4) tokens deterministically (lowest random scores).
+
+         Args:
+             fused: [B, C, H, W] fused decoder input.
+
+         Returns:
+             Masked tensor with same shape, where masked positions contain mask_feature.
+         """
+         b, c, h, w = fused.shape
+         # Pad to even dims if needed
+         h_pad = (2 - h % 2) % 2
+         w_pad = (2 - w % 2) % 2
+         if h_pad > 0 or w_pad > 0:
+             fused = torch.nn.functional.pad(fused, (0, w_pad, 0, h_pad))
+             _, _, h, w = fused.shape
+
+         # Reshape into 2x2 groups: [B, C, H/2, 2, W/2, 2] -> [B, C, H/2, W/2, 4]
+         x = fused.reshape(b, c, h // 2, 2, w // 2, 2)
+         x = x.permute(0, 1, 2, 4, 3, 5).reshape(b, c, h // 2, w // 2, 4)
+
+         # Random scores for each token in each group
+         scores = torch.rand(b, 1, h // 2, w // 2, 4, device=fused.device)
+
+         # Mask the floor(ratio * 4) lowest-scoring tokens per group
+         num_mask = math.floor(self.pdg_mask_ratio * 4)
+         if num_mask > 0:
+             # argsort ascending, mask the first num_mask
+             _, indices = scores.sort(dim=-1)
+             mask = torch.zeros_like(scores, dtype=torch.bool)
+             mask.scatter_(-1, indices[..., :num_mask], True)
+         else:
+             mask = torch.zeros_like(scores, dtype=torch.bool)
+
+         # Apply mask: replace masked tokens with mask_feature
+         mask_feat = self.mask_feature.to(device=fused.device, dtype=fused.dtype)
+         mask_feat = mask_feat.squeeze(-1).squeeze(-1)  # [1, C]
+         mask_feat = mask_feat.view(1, c, 1, 1, 1).expand_as(x)
+         mask_expanded = mask.expand_as(x)
+         x = torch.where(mask_expanded, mask_feat, x)
+
+         # Reshape back to [B, C, H, W]
+         x = x.reshape(b, c, h // 2, w // 2, 2, 2)
+         x = x.permute(0, 1, 2, 4, 3, 5).reshape(b, c, h, w)
+
+         # Remove padding if applied
+         if h_pad > 0 or w_pad > 0:
+             x = x[:, :, : h - h_pad, : w - w_pad]
+
+         return x
+
+     def forward(
+         self,
+         x_t: Tensor,
+         t: Tensor,
+         latents: Tensor,
+         *,
+         mask_tokens: bool = False,
+     ) -> Tensor:
+         """Single decoder forward pass.
+
+         Args:
+             x_t: Noised image [B, C, H, W].
+             t: Timestep [B] in [0, 1].
+             latents: Encoder latents [B, bottleneck_dim, h, w].
+             mask_tokens: If True, apply token-level masking to decoder input (for PDG).
+
+         Returns:
+             x0 prediction [B, C, H, W].
+         """
+         # Patchify and normalize x_t
+         x_feat = self.patchify(x_t)
+         x_feat = self.norm_in(x_feat)
+
+         # Upsample and normalize latents, fuse with x_feat
+         z_up = self.latent_up(latents)
+         z_up = self.latent_norm(z_up)
+         fused = torch.cat([x_feat, z_up], dim=1)
+         fused = self.fuse_in(fused)
+
+         # Token masking for PDG (replaces tokens with mask_feature)
+         if mask_tokens:
+             fused = self._apply_token_mask(fused)
+
+         # Time conditioning
+         cond = self.time_embed(t.to(torch.float32).to(device=x_t.device))
+
+         # Run all blocks sequentially
+         x = fused
+         for layer_idx, block in enumerate(self.blocks):
+             adaln_m = self._adaln_m_for_layer(cond, layer_idx=layer_idx)
+             x = block(x, adaln_m=adaln_m)
+
+         # Output head
+         x = self.norm_out(x)
+         patches = self.out_proj(x)
+         return self.unpatchify(patches)
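A standalone sketch of the 2x2 groupwise selection in `_apply_token_mask` above (plain torch, no package imports, even H and W assumed): with the default ratio 0.75, exactly floor(0.75 * 4) = 3 of the 4 tokens in every 2x2 group are selected, so the masked fraction is exactly 0.75 regardless of the random scores.

```python
import math

import torch


def groupwise_mask_2x2(x: torch.Tensor, ratio: float = 0.75) -> torch.Tensor:
    """Boolean mask [B, 1, H, W]: True where a token would be replaced.
    Mirrors Decoder._apply_token_mask's per-group selection (even H, W assumed)."""
    b, _, h, w = x.shape
    # One random score per token, grouped into 2x2 cells
    scores = torch.rand(b, 1, h // 2, w // 2, 4)
    num_mask = math.floor(ratio * 4)
    _, idx = scores.sort(dim=-1)
    mask = torch.zeros_like(scores, dtype=torch.bool)
    mask.scatter_(-1, idx[..., :num_mask], True)
    # [B, 1, H/2, W/2, 4] -> [B, 1, H, W] (inverse of the grouping permute)
    mask = mask.reshape(b, 1, h // 2, w // 2, 2, 2)
    return mask.permute(0, 1, 2, 4, 3, 5).reshape(b, 1, h, w)


mask = groupwise_mask_2x2(torch.zeros(1, 16, 8, 8))
assert mask.float().mean().item() == 0.75  # exactly 3 of 4 tokens per group
```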
m_diffae/dico_block.py ADDED
@@ -0,0 +1,111 @@
+ """DiCo block: conv path (1x1 -> depthwise -> SiLU -> CCA -> 1x1) + GELU MLP."""
+
+ from __future__ import annotations
+
+ import torch
+ import torch.nn.functional as F
+ from torch import Tensor, nn
+
+ from .compact_channel_attention import CompactChannelAttention
+ from .conv_mlp import ConvMLP
+ from .norms import ChannelWiseRMSNorm
+
+
+ class DiCoBlock(nn.Module):
+     """DiCo-style conv block with optional external AdaLN conditioning.
+
+     Two modes:
+         - Unconditioned (encoder): uses learned per-channel residual gates.
+         - External AdaLN (decoder): receives packed modulation [B, 4*C] via adaln_m.
+     """
+
+     def __init__(
+         self,
+         channels: int,
+         mlp_ratio: float,
+         *,
+         depthwise_kernel_size: int = 7,
+         use_external_adaln: bool = False,
+         norm_eps: float = 1e-6,
+     ) -> None:
+         super().__init__()
+         self.channels = int(channels)
+         self.use_external_adaln = bool(use_external_adaln)
+
+         # Pre-norm for conv and MLP paths (no affine)
+         self.norm1 = ChannelWiseRMSNorm(self.channels, eps=norm_eps, affine=False)
+         self.norm2 = ChannelWiseRMSNorm(self.channels, eps=norm_eps, affine=False)
+
+         # Conv path: 1x1 -> depthwise kxk -> SiLU -> CCA -> 1x1
+         self.conv1 = nn.Conv2d(self.channels, self.channels, kernel_size=1, bias=True)
+         self.conv2 = nn.Conv2d(
+             self.channels,
+             self.channels,
+             kernel_size=depthwise_kernel_size,
+             padding=depthwise_kernel_size // 2,
+             groups=self.channels,
+             bias=True,
+         )
+         self.conv3 = nn.Conv2d(self.channels, self.channels, kernel_size=1, bias=True)
+         self.cca = CompactChannelAttention(self.channels)
+
+         # MLP path: GELU activation
+         hidden_channels = max(int(round(float(self.channels) * mlp_ratio)), 1)
+         self.mlp = ConvMLP(self.channels, hidden_channels, norm_eps=norm_eps)
+
+         # Conditioning: learned gates (encoder) or external adaln_m (decoder)
+         if not self.use_external_adaln:
+             self.gate_attn = nn.Parameter(torch.zeros(self.channels))
+             self.gate_mlp = nn.Parameter(torch.zeros(self.channels))
+
+     def forward(self, x: Tensor, *, adaln_m: Tensor | None = None) -> Tensor:
+         b, c = x.shape[:2]
+
+         if self.use_external_adaln:
+             if adaln_m is None:
+                 raise ValueError(
+                     "adaln_m required for externally-conditioned DiCoBlock"
+                 )
+             adaln_m_cast = adaln_m.to(device=x.device, dtype=x.dtype)
+             scale_a, gate_a, scale_m, gate_m = adaln_m_cast.chunk(4, dim=-1)
+         elif adaln_m is not None:
+             raise ValueError("adaln_m must be None for unconditioned DiCoBlock")
+
+         residual = x
+
+         # Conv path
+         x_att = self.norm1(x)
+         if self.use_external_adaln:
+             x_att = x_att * (1.0 + scale_a.view(b, c, 1, 1))  # type: ignore[possibly-undefined]
+         y = self.conv1(x_att)
+         y = self.conv2(y)
+         y = F.silu(y)
+         y = y * self.cca(y)
+         y = self.conv3(y)
+
+         if self.use_external_adaln:
+             gate_a_view = torch.tanh(gate_a).view(b, c, 1, 1)  # type: ignore[possibly-undefined]
+             x = residual + gate_a_view * y
+         else:
+             gate = self.gate_attn.view(1, self.channels, 1, 1).to(
+                 dtype=y.dtype, device=y.device
+             )
+             x = residual + gate * y
+
+         # MLP path
+         residual_mlp = x
+         x_mlp = self.norm2(x)
+         if self.use_external_adaln:
+             x_mlp = x_mlp * (1.0 + scale_m.view(b, c, 1, 1))  # type: ignore[possibly-undefined]
+         y_mlp = self.mlp(x_mlp)
+
+         if self.use_external_adaln:
+             gate_m_view = torch.tanh(gate_m).view(b, c, 1, 1)  # type: ignore[possibly-undefined]
+             x = residual_mlp + gate_m_view * y_mlp
+         else:
+             gate = self.gate_mlp.view(1, self.channels, 1, 1).to(
+                 dtype=y_mlp.dtype, device=y_mlp.device
+             )
+             x = residual_mlp + gate * y_mlp
+
+         return x
m_diffae/encoder.py ADDED
@@ -0,0 +1,55 @@
+ """mDiffAE encoder: patchify -> DiCoBlocks -> bottleneck projection."""
+
+ from __future__ import annotations
+
+ from torch import Tensor, nn
+
+ from .dico_block import DiCoBlock
+ from .norms import ChannelWiseRMSNorm
+ from .straight_through_encoder import Patchify
+
+
+ class Encoder(nn.Module):
+     """Deterministic encoder: Image [B,3,H,W] -> latents [B,bottleneck_dim,h,w].
+
+     Pipeline: Patchify -> RMSNorm -> DiCoBlocks (unconditioned) -> Conv1x1 -> RMSNorm (no affine)
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         patch_size: int,
+         model_dim: int,
+         depth: int,
+         bottleneck_dim: int,
+         mlp_ratio: float,
+         depthwise_kernel_size: int,
+     ) -> None:
+         super().__init__()
+         self.patchify = Patchify(in_channels, patch_size, model_dim)
+         self.norm_in = ChannelWiseRMSNorm(model_dim, eps=1e-6, affine=True)
+         self.blocks = nn.ModuleList(
+             [
+                 DiCoBlock(
+                     model_dim,
+                     mlp_ratio,
+                     depthwise_kernel_size=depthwise_kernel_size,
+                     use_external_adaln=False,
+                 )
+                 for _ in range(depth)
+             ]
+         )
+         self.to_bottleneck = nn.Conv2d(
+             model_dim, bottleneck_dim, kernel_size=1, bias=True
+         )
+         self.norm_out = ChannelWiseRMSNorm(bottleneck_dim, eps=1e-6, affine=False)
+
+     def forward(self, images: Tensor) -> Tensor:
+         """Encode images [B,3,H,W] in [-1,1] to latents [B,bottleneck_dim,h,w]."""
+         z = self.patchify(images)
+         z = self.norm_in(z)
+         for block in self.blocks:
+             z = block(z)
+         z = self.to_bottleneck(z)
+         z = self.norm_out(z)
+         return z
m_diffae/model.py ADDED
@@ -0,0 +1,294 @@
+ """MDiffAE: standalone HuggingFace-compatible mDiffAE model."""
+
+ from __future__ import annotations
+
+ from pathlib import Path
+
+ import torch
+ from torch import Tensor, nn
+
+ from .config import MDiffAEConfig, MDiffAEInferenceConfig
+ from .decoder import Decoder
+ from .encoder import Encoder
+ from .samplers import run_ddim, run_dpmpp_2m
+ from .vp_diffusion import get_schedule, make_initial_state, sample_noise
+
+
+ def _resolve_model_dir(
+     path_or_repo_id: str | Path,
+     *,
+     revision: str | None,
+     cache_dir: str | Path | None,
+ ) -> Path:
+     """Resolve a local path or HuggingFace Hub repo ID to a local directory."""
+     local = Path(path_or_repo_id)
+     if local.is_dir():
+         return local
+     # Not a local directory; try HuggingFace Hub
+     repo_id = str(path_or_repo_id)
+     try:
+         from huggingface_hub import snapshot_download
+     except ImportError:
+         raise ImportError(
+             f"'{repo_id}' is not an existing local directory. "
+             "To download from HuggingFace Hub, install huggingface_hub: "
+             "pip install huggingface_hub"
+         )
+     cache_dir_str = str(cache_dir) if cache_dir is not None else None
+     local_dir = snapshot_download(
+         repo_id,
+         revision=revision,
+         cache_dir=cache_dir_str,
+     )
+     return Path(local_dir)
+
+
+ class MDiffAE(nn.Module):
+     """Standalone mDiffAE model for HuggingFace distribution.
+
+     A masked diffusion autoencoder that encodes images to compact latents and
+     decodes them back via iterative VP diffusion. Uses a flat 4-block decoder
+     with token-level masking for PDG instead of the skip-concat + block-drop
+     approach of iRDiffAE.
+
+     Usage::
+
+         model = MDiffAE.from_pretrained("data-archetype/mdiffae-v1")
+         model = model.to("cuda", dtype=torch.bfloat16)
+
+         # Encode
+         latents = model.encode(images)  # images: [B,3,H,W] in [-1,1]
+
+         # Decode (1 step by default; PSNR-optimal)
+         recon = model.decode(latents, height=H, width=W)
+
+         # Reconstruct (encode + 1-step decode)
+         recon = model.reconstruct(images)
+     """
+
+     def __init__(self, config: MDiffAEConfig) -> None:
+         super().__init__()
+         self.config = config
+
+         self.encoder = Encoder(
+             in_channels=config.in_channels,
+             patch_size=config.patch_size,
+             model_dim=config.model_dim,
+             depth=config.encoder_depth,
+             bottleneck_dim=config.bottleneck_dim,
+             mlp_ratio=config.mlp_ratio,
+             depthwise_kernel_size=config.depthwise_kernel_size,
+         )
+
+         self.decoder = Decoder(
+             in_channels=config.in_channels,
+             patch_size=config.patch_size,
+             model_dim=config.model_dim,
+             depth=config.decoder_depth,
+             bottleneck_dim=config.bottleneck_dim,
+             mlp_ratio=config.mlp_ratio,
+             depthwise_kernel_size=config.depthwise_kernel_size,
+             adaln_low_rank_rank=config.adaln_low_rank_rank,
+             pdg_mask_ratio=config.pdg_mask_ratio,
+         )
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         path_or_repo_id: str | Path,
+         *,
+         dtype: torch.dtype = torch.bfloat16,
+         device: str | torch.device = "cpu",
+         revision: str | None = None,
+         cache_dir: str | Path | None = None,
+     ) -> MDiffAE:
+         """Load a pretrained model from a local directory or HuggingFace Hub.
+
+         The directory (or repo) should contain:
+         - config.json: Model architecture config.
+         - model.safetensors (preferred) or model.pt: Model weights.
+
+         Args:
+             path_or_repo_id: Local directory path or HuggingFace Hub repo ID
+                 (e.g. ``"data-archetype/mdiffae-v1"``).
+             dtype: Load weights in this dtype (float32 or bfloat16).
+             device: Target device.
+             revision: Git revision (branch, tag, or commit) for Hub downloads.
+             cache_dir: Where to cache Hub downloads. Uses HF default if None.
+
+         Returns:
+             Loaded model in eval mode.
+         """
+         model_dir = _resolve_model_dir(
+             path_or_repo_id, revision=revision, cache_dir=cache_dir
+         )
+         config = MDiffAEConfig.load(model_dir / "config.json")
+         model = cls(config)
+
+         # Try safetensors first, fall back to .pt
+         safetensors_path = model_dir / "model.safetensors"
+         pt_path = model_dir / "model.pt"
+
+         if safetensors_path.exists():
+             try:
+                 from safetensors.torch import load_file
+
+                 state_dict = load_file(str(safetensors_path), device=str(device))
+             except ImportError:
+                 raise ImportError(
+                     "safetensors package required to load .safetensors files. "
+                     "Install with: pip install safetensors"
+                 )
+         elif pt_path.exists():
+             state_dict = torch.load(
+                 str(pt_path), map_location=device, weights_only=True
+             )
+         else:
+             raise FileNotFoundError(
+                 f"No model weights found in {model_dir}. "
+                 "Expected model.safetensors or model.pt."
+             )
+
+         model.load_state_dict(state_dict)
+         model = model.to(dtype=dtype, device=torch.device(device))
+         model.eval()
+         return model
+
+     def encode(self, images: Tensor) -> Tensor:
+         """Encode images to latents.
+
+         Args:
+             images: [B, 3, H, W] in [-1, 1], H and W must be divisible by patch_size.
+
+         Returns:
+             Latents [B, bottleneck_dim, H/patch, W/patch].
+         """
+         try:
+             model_dtype = next(self.parameters()).dtype
+         except StopIteration:
+             model_dtype = torch.float32
+         return self.encoder(images.to(dtype=model_dtype))
+
+     @torch.no_grad()
+     def decode(
+         self,
+         latents: Tensor,
+         height: int,
+         width: int,
+         *,
+         inference_config: MDiffAEInferenceConfig | None = None,
+     ) -> Tensor:
+         """Decode latents to images via VP diffusion.
+
+         Args:
+             latents: [B, bottleneck_dim, h, w] encoder latents.
+             height: Output image height (must be divisible by patch_size).
+             width: Output image width (must be divisible by patch_size).
+             inference_config: Optional inference parameters. Uses defaults if None.
+
+         Returns:
+             Reconstructed images [B, 3, H, W] in float32.
+         """
+         cfg = inference_config or MDiffAEInferenceConfig()
+         config = self.config
+         batch = int(latents.shape[0])
+         device = latents.device
+
+         # Determine model dtype from parameters
+         try:
+             model_dtype = next(self.parameters()).dtype
+         except StopIteration:
+             model_dtype = torch.float32
+
+         # Validate dimensions
+         if height % config.patch_size != 0 or width % config.patch_size != 0:
+             raise ValueError(
+                 f"height={height} and width={width} must be divisible by patch_size={config.patch_size}"
+             )
+
+         # Generate initial noise
+         shape = (batch, config.in_channels, height, width)
+         noise = sample_noise(
+             shape,
+             noise_std=config.pixel_noise_std,
+             seed=cfg.seed,
+             device=torch.device("cpu"),
+             dtype=torch.float32,
+         )
+
+         # Build schedule
+         schedule = get_schedule(cfg.schedule, cfg.num_steps).to(device=device)
222
+
223
+ # Construct initial state: sigma_start * noise
224
+ initial_state = make_initial_state(
225
+ noise=noise.to(device=device),
226
+ t_start=schedule[0:1],
227
+ logsnr_min=config.logsnr_min,
228
+ logsnr_max=config.logsnr_max,
229
+ )
230
+
231
+ # Disable autocast for numerical precision
232
+ device_type = "cuda" if device.type == "cuda" else "cpu"
233
+ with torch.autocast(device_type=device_type, enabled=False):
234
+ latents_in = latents.to(device=device)
235
+
236
+ def _forward_fn(
237
+ x_t: Tensor,
238
+ t: Tensor,
239
+ latents: Tensor,
240
+ *,
241
+ mask_tokens: bool = False,
242
+ ) -> Tensor:
243
+ return self.decoder(
244
+ x_t.to(dtype=model_dtype),
245
+ t,
246
+ latents.to(dtype=model_dtype),
247
+ mask_tokens=mask_tokens,
248
+ )
249
+
250
+ # Select sampler
251
+ if cfg.sampler == "ddim":
252
+ sampler_fn = run_ddim
253
+ elif cfg.sampler == "dpmpp_2m":
254
+ sampler_fn = run_dpmpp_2m
255
+ else:
256
+ raise ValueError(
257
+ f"Unsupported sampler: {cfg.sampler!r}. Use 'ddim' or 'dpmpp_2m'."
258
+ )
259
+
260
+ result = sampler_fn(
261
+ forward_fn=_forward_fn,
262
+ initial_state=initial_state,
263
+ schedule=schedule,
264
+ latents=latents_in,
265
+ logsnr_min=config.logsnr_min,
266
+ logsnr_max=config.logsnr_max,
267
+ pdg_enabled=cfg.pdg_enabled,
268
+ pdg_strength=cfg.pdg_strength,
269
+ device=device,
270
+ )
271
+
272
+ return result
273
+
274
+ @torch.no_grad()
275
+ def reconstruct(
276
+ self,
277
+ images: Tensor,
278
+ *,
279
+ inference_config: MDiffAEInferenceConfig | None = None,
280
+ ) -> Tensor:
281
+ """Encode then decode. Convenience wrapper.
282
+
283
+ Args:
284
+ images: [B, 3, H, W] in [-1, 1].
285
+ inference_config: Optional inference parameters.
286
+
287
+ Returns:
288
+ Reconstructed images [B, 3, H, W] in float32.
289
+ """
290
+ latents = self.encode(images)
291
+ _, _, h, w = images.shape
292
+ return self.decode(
293
+ latents, height=h, width=w, inference_config=inference_config
294
+ )
m_diffae/norms.py ADDED
@@ -0,0 +1,39 @@
+ """Channel-wise RMSNorm for NCHW tensors."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import torch
6
+ from torch import Tensor, nn
7
+
8
+
9
+ class ChannelWiseRMSNorm(nn.Module):
10
+ """Channel-wise RMSNorm with float32 reduction for numerical stability.
11
+
12
+ Normalizes across channels per spatial position. Supports optional
13
+ per-channel affine weight and bias.
14
+ """
15
+
16
+ def __init__(self, channels: int, eps: float = 1e-6, affine: bool = True) -> None:
17
+ super().__init__()
18
+ self.channels: int = int(channels)
19
+ self._eps: float = float(eps)
20
+ if affine:
21
+ self.weight = nn.Parameter(torch.ones(self.channels))
22
+ self.bias = nn.Parameter(torch.zeros(self.channels))
23
+ else:
24
+ self.register_parameter("weight", None)
25
+ self.register_parameter("bias", None)
26
+
27
+ def forward(self, x: Tensor) -> Tensor:
28
+ if x.dim() < 2:
29
+ return x
30
+ # Float32 accumulation for stability
31
+ ms = torch.mean(torch.square(x), dim=1, keepdim=True, dtype=torch.float32)
32
+ inv_rms = torch.rsqrt(ms + self._eps)
33
+ y = x * inv_rms
34
+ # Define the broadcast shape once so the bias branch never sees an unbound local.
+ shape = (1, -1) + (1,) * (x.dim() - 2)
+ if self.weight is not None:
+ y = y * self.weight.view(shape).to(dtype=y.dtype)
+ if self.bias is not None:
+ y = y + self.bias.view(shape).to(dtype=y.dtype)
39
+ return y.to(dtype=x.dtype)
m_diffae/samplers.py ADDED
@@ -0,0 +1,226 @@
+ """DDIM and DPM++2M samplers for VP diffusion with x-prediction objective."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Protocol
6
+
7
+ import torch
8
+ from torch import Tensor
9
+
10
+ from .vp_diffusion import (
11
+ alpha_sigma_from_logsnr,
12
+ broadcast_time_like,
13
+ shifted_cosine_interpolated_logsnr_from_t,
14
+ )
15
+
16
+
17
+ class DecoderForwardFn(Protocol):
18
+ """Callable that predicts x0 from (x_t, t, latents)."""
19
+
20
+ def __call__(
21
+ self,
22
+ x_t: Tensor,
23
+ t: Tensor,
24
+ latents: Tensor,
25
+ *,
26
+ mask_tokens: bool = False,
27
+ ) -> Tensor: ...
28
+
29
+
30
+ def _reconstruct_eps_from_x0(
31
+ *, x_t: Tensor, x0_hat: Tensor, alpha: Tensor, sigma: Tensor
32
+ ) -> Tensor:
33
+ """Reconstruct eps_hat from (x_t, x0_hat) under VP parameterization.
34
+
35
+ eps_hat = (x_t - alpha * x0_hat) / sigma. All float32.
36
+ """
37
+ alpha_view = broadcast_time_like(alpha, x_t).to(dtype=torch.float32)
38
+ sigma_view = broadcast_time_like(sigma, x_t).to(dtype=torch.float32)
39
+ x_t_f32 = x_t.to(torch.float32)
40
+ x0_f32 = x0_hat.to(torch.float32)
41
+ return (x_t_f32 - alpha_view * x0_f32) / sigma_view
42
+
43
+
44
+ def _ddim_step(
45
+ *,
46
+ x0_hat: Tensor,
47
+ eps_hat: Tensor,
48
+ alpha_next: Tensor,
49
+ sigma_next: Tensor,
50
+ ref: Tensor,
51
+ ) -> Tensor:
52
+ """DDIM step: x_next = alpha_next * x0_hat + sigma_next * eps_hat."""
53
+ a = broadcast_time_like(alpha_next, ref).to(dtype=torch.float32)
54
+ s = broadcast_time_like(sigma_next, ref).to(dtype=torch.float32)
55
+ return a * x0_hat + s * eps_hat
56
+
57
+
58
+ def run_ddim(
59
+ *,
60
+ forward_fn: DecoderForwardFn,
61
+ initial_state: Tensor,
62
+ schedule: Tensor,
63
+ latents: Tensor,
64
+ logsnr_min: float,
65
+ logsnr_max: float,
66
+ log_change_high: float = 0.0,
67
+ log_change_low: float = 0.0,
68
+ pdg_enabled: bool = False,
69
+ pdg_strength: float = 1.1,
70
+ device: torch.device | None = None,
71
+ ) -> Tensor:
72
+ """Run DDIM sampling loop.
73
+
74
+ Args:
75
+ forward_fn: Decoder forward function (x_t, t, latents) -> x0_hat.
76
+ initial_state: Starting noised state [B, C, H, W] in float32.
77
+ schedule: Descending t-schedule [num_steps + 1] in [0, 1].
78
+ latents: Encoder latents [B, bottleneck_dim, h, w].
79
+ logsnr_min, logsnr_max: VP schedule endpoints.
80
+ log_change_high, log_change_low: Shifted-cosine schedule parameters.
81
+ pdg_enabled: Whether to use token-level Path-Drop Guidance.
82
+ pdg_strength: CFG-like strength for PDG (use small values: 1.05–1.2).
83
+ device: Target device.
84
+
85
+ Returns:
86
+ Denoised samples [B, C, H, W] in float32.
87
+ """
88
+ run_device = device or initial_state.device
89
+ batch_size = int(initial_state.shape[0])
90
+ state = initial_state.to(device=run_device, dtype=torch.float32)
91
+
92
+ # Precompute logSNR, alpha, sigma for all schedule points
93
+ lmb = shifted_cosine_interpolated_logsnr_from_t(
94
+ schedule.to(device=run_device),
95
+ logsnr_min=logsnr_min,
96
+ logsnr_max=logsnr_max,
97
+ log_change_high=log_change_high,
98
+ log_change_low=log_change_low,
99
+ )
100
+ alpha_sched, sigma_sched = alpha_sigma_from_logsnr(lmb)
101
+
102
+ for i in range(int(schedule.numel()) - 1):
103
+ t_i = schedule[i]
104
+ a_t = alpha_sched[i].expand(batch_size)
105
+ s_t = sigma_sched[i].expand(batch_size)
106
+ a_next = alpha_sched[i + 1].expand(batch_size)
107
+ s_next = sigma_sched[i + 1].expand(batch_size)
108
+
109
+ # Model prediction
110
+ t_vec = t_i.expand(batch_size).to(device=run_device, dtype=torch.float32)
111
+ if pdg_enabled:
112
+ x0_uncond = forward_fn(state, t_vec, latents, mask_tokens=True).to(
113
+ torch.float32
114
+ )
115
+ x0_cond = forward_fn(state, t_vec, latents, mask_tokens=False).to(
116
+ torch.float32
117
+ )
118
+ x0_hat = x0_uncond + pdg_strength * (x0_cond - x0_uncond)
119
+ else:
120
+ x0_hat = forward_fn(state, t_vec, latents, mask_tokens=False).to(
121
+ torch.float32
122
+ )
123
+
124
+ eps_hat = _reconstruct_eps_from_x0(
125
+ x_t=state, x0_hat=x0_hat, alpha=a_t, sigma=s_t
126
+ )
127
+ state = _ddim_step(
128
+ x0_hat=x0_hat,
129
+ eps_hat=eps_hat,
130
+ alpha_next=a_next,
131
+ sigma_next=s_next,
132
+ ref=state,
133
+ )
134
+
135
+ return state
136
+
137
+
138
+ def run_dpmpp_2m(
139
+ *,
140
+ forward_fn: DecoderForwardFn,
141
+ initial_state: Tensor,
142
+ schedule: Tensor,
143
+ latents: Tensor,
144
+ logsnr_min: float,
145
+ logsnr_max: float,
146
+ log_change_high: float = 0.0,
147
+ log_change_low: float = 0.0,
148
+ pdg_enabled: bool = False,
149
+ pdg_strength: float = 1.1,
150
+ device: torch.device | None = None,
151
+ ) -> Tensor:
152
+ """Run DPM++2M sampling loop.
153
+
154
+ Multi-step solver using exponential integrator formulation in half-lambda space.
155
+ """
156
+ run_device = device or initial_state.device
157
+ batch_size = int(initial_state.shape[0])
158
+ state = initial_state.to(device=run_device, dtype=torch.float32)
159
+
160
+ # Precompute logSNR, alpha, sigma, half-lambda for all schedule points
161
+ lmb = shifted_cosine_interpolated_logsnr_from_t(
162
+ schedule.to(device=run_device),
163
+ logsnr_min=logsnr_min,
164
+ logsnr_max=logsnr_max,
165
+ log_change_high=log_change_high,
166
+ log_change_low=log_change_low,
167
+ )
168
+ alpha_sched, sigma_sched = alpha_sigma_from_logsnr(lmb)
169
+ half_lambda = 0.5 * lmb.to(torch.float32)
170
+
171
+ x0_prev: Tensor | None = None
172
+
173
+ for i in range(int(schedule.numel()) - 1):
174
+ t_i = schedule[i]
175
+ s_t = sigma_sched[i].expand(batch_size)
176
+ a_next = alpha_sched[i + 1].expand(batch_size)
177
+ s_next = sigma_sched[i + 1].expand(batch_size)
178
+
179
+ # Model prediction
180
+ t_vec = t_i.expand(batch_size).to(device=run_device, dtype=torch.float32)
181
+ if pdg_enabled:
182
+ x0_uncond = forward_fn(state, t_vec, latents, mask_tokens=True).to(
183
+ torch.float32
184
+ )
185
+ x0_cond = forward_fn(state, t_vec, latents, mask_tokens=False).to(
186
+ torch.float32
187
+ )
188
+ x0_hat = x0_uncond + pdg_strength * (x0_cond - x0_uncond)
189
+ else:
190
+ x0_hat = forward_fn(state, t_vec, latents, mask_tokens=False).to(
191
+ torch.float32
192
+ )
193
+
194
+ lam_t = half_lambda[i].expand(batch_size)
195
+ lam_next = half_lambda[i + 1].expand(batch_size)
196
+ h = (lam_next - lam_t).to(torch.float32)
197
+ phi_1 = torch.expm1(-h)
198
+
199
+ sigma_ratio = (s_next / s_t).to(torch.float32)
200
+
201
+ if i == 0 or x0_prev is None:
202
+ # First-order step
203
+ state = (
204
+ sigma_ratio.view(-1, *([1] * (state.dim() - 1))) * state
205
+ - broadcast_time_like(a_next, state).to(torch.float32)
206
+ * broadcast_time_like(phi_1, state).to(torch.float32)
207
+ * x0_hat
208
+ )
209
+ else:
210
+ # Second-order step
211
+ lam_prev = half_lambda[i - 1].expand(batch_size)
212
+ h_0 = (lam_t - lam_prev).to(torch.float32)
213
+ r0 = h_0 / h
214
+ d1_0 = (x0_hat - x0_prev) / broadcast_time_like(r0, x0_hat)
215
+ common = broadcast_time_like(a_next, state).to(
216
+ torch.float32
217
+ ) * broadcast_time_like(phi_1, state).to(torch.float32)
218
+ state = (
219
+ sigma_ratio.view(-1, *([1] * (state.dim() - 1))) * state
220
+ - common * x0_hat
221
+ - 0.5 * common * d1_0
222
+ )
223
+
224
+ x0_prev = x0_hat
225
+
226
+ return state
m_diffae/straight_through_encoder.py ADDED
@@ -0,0 +1,27 @@
+ """PixelUnshuffle-based patchifier (no residual conv path)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from torch import Tensor, nn
6
+
7
+
8
+ class Patchify(nn.Module):
9
+ """PixelUnshuffle(patch) -> Conv2d 1x1 projection.
10
+
11
+ Converts [B, C, H, W] images into [B, out_channels, H/patch, W/patch] features.
12
+ """
13
+
14
+ def __init__(self, in_channels: int, patch: int, out_channels: int) -> None:
15
+ super().__init__()
16
+ self.patch = int(patch)
17
+ self.unshuffle = nn.PixelUnshuffle(self.patch)
18
+ in_after = in_channels * (self.patch * self.patch)
19
+ self.proj = nn.Conv2d(in_after, out_channels, kernel_size=1, bias=True)
20
+
21
+ def forward(self, x: Tensor) -> Tensor:
22
+ if x.shape[2] % self.patch != 0 or x.shape[3] % self.patch != 0:
23
+ raise ValueError(
24
+ f"Input H={x.shape[2]} and W={x.shape[3]} must be divisible by patch={self.patch}"
25
+ )
26
+ y = self.unshuffle(x)
27
+ return self.proj(y)
m_diffae/time_embed.py ADDED
@@ -0,0 +1,83 @@
+ """Sinusoidal timestep embedding with MLP projection."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+
7
+ import torch
8
+ from torch import Tensor, nn
9
+
10
+
11
+ def _log_spaced_frequencies(
12
+ half: int, max_period: float, *, device: torch.device | None = None
13
+ ) -> Tensor:
14
+ """Log-spaced frequencies for sinusoidal embedding."""
15
+ return torch.exp(
16
+ -math.log(max_period)
17
+ * torch.arange(half, device=device, dtype=torch.float32)
18
+ / max(float(half - 1), 1.0)
19
+ )
20
+
21
+
22
+ def sinusoidal_time_embedding(
23
+ t: Tensor,
24
+ dim: int,
25
+ *,
26
+ max_period: float = 10000.0,
27
+ scale: float | None = None,
28
+ freqs: Tensor | None = None,
29
+ ) -> Tensor:
30
+ """Sinusoidal timestep embedding (DDPM/DiT-style). Always float32."""
31
+ t32 = t.to(torch.float32)
32
+ if scale is not None:
33
+ t32 = t32 * float(scale)
34
+ half = dim // 2
35
+ if freqs is not None:
36
+ freqs = freqs.to(device=t32.device, dtype=torch.float32)
37
+ else:
38
+ freqs = _log_spaced_frequencies(half, max_period, device=t32.device)
39
+ angles = t32[:, None] * freqs[None, :]
40
+ return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
41
+
42
+
43
+ class SinusoidalTimeEmbeddingMLP(nn.Module):
44
+ """Sinusoidal time embedding followed by Linear -> SiLU -> Linear."""
45
+
46
+ def __init__(
47
+ self,
48
+ dim: int,
49
+ *,
50
+ freq_dim: int = 256,
51
+ hidden_mult: float = 1.0,
52
+ time_scale: float = 1000.0,
53
+ max_period: float = 10000.0,
54
+ ) -> None:
55
+ super().__init__()
56
+ self.dim = int(dim)
57
+ self.freq_dim = int(freq_dim)
58
+ self.time_scale = float(time_scale)
59
+ self.max_period = float(max_period)
60
+ hidden_dim = max(int(round(int(dim) * float(hidden_mult))), 1)
61
+
62
+ freqs = _log_spaced_frequencies(self.freq_dim // 2, self.max_period)
63
+ self.register_buffer("freqs", freqs, persistent=True)
64
+
65
+ self.proj_in = nn.Linear(self.freq_dim, hidden_dim)
66
+ self.act = nn.SiLU()
67
+ self.proj_out = nn.Linear(hidden_dim, self.dim)
68
+
69
+ def forward(self, t: Tensor) -> Tensor:
70
+ freqs: Tensor = self.freqs # type: ignore[assignment]
71
+ emb_freq = sinusoidal_time_embedding(
72
+ t.to(torch.float32),
73
+ self.freq_dim,
74
+ max_period=self.max_period,
75
+ scale=self.time_scale,
76
+ freqs=freqs,
77
+ )
78
+ dtype_in = self.proj_in.weight.dtype
79
+ hidden = self.proj_in(emb_freq.to(dtype_in))
80
+ hidden = self.act(hidden)
81
+ if hidden.dtype != self.proj_out.weight.dtype:
82
+ hidden = hidden.to(self.proj_out.weight.dtype)
83
+ return self.proj_out(hidden)
m_diffae/vp_diffusion.py ADDED
@@ -0,0 +1,151 @@
+ """VP diffusion math: logSNR schedules, alpha/sigma computation, noise construction."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+
7
+ import torch
8
+ from torch import Tensor
9
+
10
+
11
+ def alpha_sigma_from_logsnr(lmb: Tensor) -> tuple[Tensor, Tensor]:
12
+ """Compute (alpha, sigma) from logSNR in float32.
13
+
14
+ VP constraint: alpha^2 + sigma^2 = 1.
15
+ """
16
+ lmb32 = lmb.to(dtype=torch.float32)
17
+ alpha = torch.sqrt(torch.sigmoid(lmb32))
18
+ sigma = torch.sqrt(torch.sigmoid(-lmb32))
19
+ return alpha, sigma
20
+
21
+
22
+ def broadcast_time_like(coeff: Tensor, x: Tensor) -> Tensor:
23
+ """Broadcast [B] coefficient to match x for per-sample scaling."""
24
+ view_shape = (int(x.shape[0]),) + (1,) * (x.dim() - 1)
25
+ return coeff.view(view_shape)
26
+
27
+
28
+ def _cosine_interpolated_params(
29
+ logsnr_min: float, logsnr_max: float
30
+ ) -> tuple[float, float]:
31
+ """Compute (a, b) for cosine-interpolated logSNR schedule.
32
+
33
+ logsnr(t) = -2 * log(tan(a*t + b))
34
+ logsnr(0) = logsnr_max, logsnr(1) = logsnr_min
35
+ """
36
+ b = math.atan(math.exp(-0.5 * logsnr_max))
37
+ a = math.atan(math.exp(-0.5 * logsnr_min)) - b
38
+ return a, b
39
+
40
+
41
+ def cosine_interpolated_logsnr_from_t(
42
+ t: Tensor, *, logsnr_min: float, logsnr_max: float
43
+ ) -> Tensor:
44
+ """Map t in [0,1] to logSNR via cosine-interpolated schedule. Always float32."""
45
+ a, b = _cosine_interpolated_params(logsnr_min, logsnr_max)
46
+ t32 = t.to(dtype=torch.float32)
47
+ a_t = torch.tensor(a, device=t32.device, dtype=torch.float32)
48
+ b_t = torch.tensor(b, device=t32.device, dtype=torch.float32)
49
+ u = a_t * t32 + b_t
50
+ return -2.0 * torch.log(torch.tan(u))
51
+
52
+
53
+ def shifted_cosine_interpolated_logsnr_from_t(
54
+ t: Tensor,
55
+ *,
56
+ logsnr_min: float,
57
+ logsnr_max: float,
58
+ log_change_high: float = 0.0,
59
+ log_change_low: float = 0.0,
60
+ ) -> Tensor:
61
+ """SiD2 "shifted cosine" schedule: logSNR with resolution-dependent shifts.
62
+
63
+ lambda(t) = (1-t) * (base(t) + log_change_high) + t * (base(t) + log_change_low)
64
+ """
65
+ base = cosine_interpolated_logsnr_from_t(
66
+ t, logsnr_min=logsnr_min, logsnr_max=logsnr_max
67
+ )
68
+ t32 = t.to(dtype=torch.float32)
69
+ high = base + float(log_change_high)
70
+ low = base + float(log_change_low)
71
+ return (1.0 - t32) * high + t32 * low
72
+
73
+
74
+ def get_schedule(schedule_type: str, num_steps: int) -> Tensor:
75
+ """Generate a descending t-schedule in [0, 1] for VP diffusion sampling.
76
+
77
+ ``num_steps`` is the number of function evaluations (NFE = decoder forward
78
+ passes). Internally the schedule has ``num_steps + 1`` time points
79
+ (including both endpoints).
80
+
81
+ Args:
82
+ schedule_type: "linear" or "cosine".
83
+ num_steps: Number of decoder forward passes (NFE), >= 1.
84
+
85
+ Returns:
86
+ Descending 1D tensor with ``num_steps + 1`` elements from ~1.0 to ~0.0.
87
+ """
88
+ # NOTE: the upstream training code (src/ode/time_schedules.py) uses a
89
+ # different convention where num_steps counts schedule *points* (so NFE =
90
+ # num_steps - 1). This export package corrects the off-by-one so that
91
+ # num_steps means NFE directly. TODO: align the upstream convention.
92
+ n = max(int(num_steps) + 1, 2)
93
+ if schedule_type == "linear":
94
+ base = torch.linspace(0.0, 1.0, n)
95
+ elif schedule_type == "cosine":
96
+ i = torch.arange(n, dtype=torch.float32)
97
+ base = 0.5 * (1.0 - torch.cos(math.pi * (i / (n - 1))))
98
+ else:
99
+ raise ValueError(
100
+ f"Unsupported schedule type: {schedule_type!r}. Use 'linear' or 'cosine'."
101
+ )
102
+ # Descending: high t (noisy) -> low t (clean)
103
+ return torch.flip(base, dims=[0])
104
+
105
+
106
+ def make_initial_state(
107
+ *,
108
+ noise: Tensor,
109
+ t_start: Tensor,
110
+ logsnr_min: float,
111
+ logsnr_max: float,
112
+ log_change_high: float = 0.0,
113
+ log_change_low: float = 0.0,
114
+ ) -> Tensor:
115
+ """Construct VP initial state x_t0 = sigma_start * noise (since x0=0).
116
+
117
+ All math in float32.
118
+ """
119
+ batch = int(noise.shape[0])
120
+ lmb_start = shifted_cosine_interpolated_logsnr_from_t(
121
+ t_start.expand(batch).to(dtype=torch.float32),
122
+ logsnr_min=logsnr_min,
123
+ logsnr_max=logsnr_max,
124
+ log_change_high=log_change_high,
125
+ log_change_low=log_change_low,
126
+ )
127
+ _alpha_start, sigma_start = alpha_sigma_from_logsnr(lmb_start)
128
+ sigma_view = broadcast_time_like(sigma_start, noise)
129
+ return sigma_view * noise.to(dtype=torch.float32)
130
+
131
+
132
+ def sample_noise(
133
+ shape: tuple[int, ...],
134
+ *,
135
+ noise_std: float = 1.0,
136
+ seed: int | None = None,
137
+ device: torch.device | None = None,
138
+ dtype: torch.dtype = torch.float32,
139
+ ) -> Tensor:
140
+ """Sample Gaussian noise with optional seeding. CPU-seeded for reproducibility."""
141
+ if seed is None:
142
+ noise = torch.randn(
143
+ shape, device=device or torch.device("cpu"), dtype=torch.float32
144
+ )
145
+ else:
146
+ gen = torch.Generator(device="cpu")
147
+ gen.manual_seed(int(seed))
148
+ noise = torch.randn(shape, generator=gen, device="cpu", dtype=torch.float32)
149
+ noise = noise.mul(float(noise_std))
150
+ target_device = device if device is not None else torch.device("cpu")
151
+ return noise.to(device=target_device, dtype=dtype)
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:196bd7a0607c495c0b72bb8c306c66d871792ea3d3f120afa14115ba3ca8e7ae
3
+ size 325656824
technical_report_mdiffae.md ADDED
@@ -0,0 +1,211 @@
+ # mDiffAE: Masked Diffusion AutoEncoder — Technical Report
2
+
3
+ **Version 1** — March 2026
4
+
5
+ ## 1. Introduction
6
+
7
+ mDiffAE (**M**asked **Diff**usion **A**uto**E**ncoder) builds on the [iRDiffAE](https://huggingface.co/data-archetype/irdiffae-v1/blob/main/technical_report.md) model family, which provides a fast, single-GPU-trainable diffusion autoencoder platform capable of high-quality image reconstruction. See that report for full background on the shared components: VP diffusion math (logSNR schedules, alpha/sigma, x-prediction), DiCo block architecture (depthwise conv + compact channel attention + GELU MLP), patchify encoder (PixelUnshuffle + 1×1 conv), shared-base + low-rank AdaLN-Zero conditioning, and the Path-Drop Guidance (PDG) concept.
8
+
9
+ The iRDiffAE platform is designed to make it easy to experiment with different ways of regularizing the latent space. iRDiffAE v1 used REPA — aligning encoder features with a frozen DINOv2 teacher — which produces well-structured latents but tends toward overly smooth representations that are hard to reconcile with fine detail. Here we take a different approach entirely: **decoder token masking**.
10
+
11
+ ### 1.1 Token Masking as Regularizer
12
+
13
+ For 50% of samples per batch, the decoder sees only **25% of tokens** in the fused conditioning input. The spatial token grid is divided into non-overlapping 2×2 groups, and within each group a single token is randomly kept while the other three are replaced with a learned mask feature. Hiding such a large fraction (75%) pushes the encoder toward representation consistency: each spatial token must carry enough information to support reconstruction even when most of its neighbors are absent. With a smaller masking fraction the model learns sharp details quickly but far less spatial coherence; the task becomes too close to local inpainting, and the encoder is not pressured into globally consistent representations. The importance of a high masking ratio echoes findings in the masked autoencoder literature (He et al., 2022); we tested lower ratios and confirmed this tradeoff empirically.
14
+
15
+ The 50% per-sample application probability is the knob that controls the compromise between reconstruction quality and latent space quality: samples that receive masking push the encoder toward consistent representations, while unmasked samples maintain reconstruction fidelity.
16
+
17
+ ### 1.2 Latent Noise Regularization
18
+
19
+ To further regularize the latent space, we retain the random latent noising mechanism 10% of the time. However, unlike the pixel-space diffusion noise, the latent noise level is sampled independently using a **Beta(2,2)** distribution (stratified), with a **logSNR shift of +1.0** that biases it toward low noise levels (low *t* in our convention). This decouples the latent regularization schedule from the decoder's diffusion schedule, providing a gentle push toward noise-robust representations without disrupting reconstruction training.
20
+
21
+ ### 1.3 Simplified Decoder
22
+
23
+ To keep the representational pressure on the encoder, we restrict the decoder to only **4 blocks** (down from 8 in iRDiffAE v1) and simplify it to a flat sequential architecture — no start/middle/end block groups, no skip connections. This halves the decoder's parameter count and makes it roughly 2× faster, while forcing the encoder to compensate by producing more informative latents.
24
+
25
+ ### 1.4 Empirical Results
26
+
27
+ Compared to the REPA-regularized iRDiffAE v1, mDiffAE achieves slightly higher PSNR (to be confirmed with final benchmarks, though initial results were decisive), and its latent-space PCA looks less oversmoothed while remaining very locally consistent. In downstream diffusion model training, mDiffAE's latent space does not exhibit the very steep initial loss descent seen with iRDiffAE, but it quickly catches up after 50k–100k training steps, producing more spatially coherent images earlier with better high-frequency detail.
28
+
29
+ ### 1.5 References
30
+
31
+ - He, K., Chen, X., Xie, S., Li, Y., Dollár, P., & Girshick, R. (2022). *Masked Autoencoders Are Scalable Vision Learners*. CVPR 2022.
32
+ - Li, T., Chang, H., Mishra, S.K., Zhang, H., Katabi, D., & Krishnan, D. (2023). *MAGE: MAsked Generative Encoder to Unify Representation Learning and Image Synthesis*. CVPR 2023.
33
+
34
+ ## 2. Architecture Differences from iRDiffAE
35
+
36
+ | Aspect | iRDiffAE v1 (halfrepa c128) | mDiffAE v1 (masked c64) |
37
+ |--------|------------------------------|--------------------------|
38
+ | Bottleneck dim | 128 | **64** |
39
+ | Decoder depth | 8 (2 start + 4 middle + 2 end) | **4 (flat sequential)** |
40
+ | Decoder topology | START_MIDDLE_END_SKIP_CONCAT | **FLAT (no skip concat)** |
41
+ | Skip fusion | Yes (`fuse_skip` Conv1×1) | **No** |
42
+ | PDG mechanism | Drop middle blocks → mask_feature | **Token-level masking** (75% spatial tokens → mask_feature) |
43
+ | PDG sensitivity | Moderate (strength 1.5–3.0) | **Very sensitive** (strength 1.05–1.2 only) |
44
+ | Training regularizer | REPA (half-channel DINOv2 alignment) + covreg | **Decoder token masking** (75% ratio, 50% apply prob) |
45
+ | Latent noise reg | Same mechanism | **Independent Beta(2,2), logSNR shift +1.0, 10% prob** |
46
+ | Depthwise kernel | 7×7 | 7×7 (same) |
47
+ | Model dim | 896 | 896 (same) |
48
+ | Encoder depth | 4 | 4 (same) |
49
+ | Best decode | 1 step DDIM | 1 step DDIM (same) |
50
+
51
+ ## 3. Training-Time Masking Details
52
+
53
+ ### 3.1 Token Masking Procedure
54
+
55
+ During training, with 50% probability per sample:
56
+ 1. The fused decoder input (patchified x_t + upsampled encoder latents) is divided into non-overlapping 2×2 spatial groups
57
+ 2. Within each group, 3 of 4 tokens (75%) are selected for masking using random scoring
58
+ 3. Masked tokens are replaced with a learned `mask_feature` parameter (same dimensionality as model_dim)
59
+ 4. The decoder processes the partially-masked input normally through all blocks
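The steps above can be sketched as follows (an illustrative re-implementation, not the repository's exact code; `mask_feature` stands for the learned per-channel parameter):

```python
import torch

def groupwise_token_mask(fused: torch.Tensor, mask_feature: torch.Tensor,
                         group: int = 2) -> torch.Tensor:
    """Keep one random token per non-overlapping group x group block.

    fused: [B, C, H, W] decoder input (H, W divisible by `group`).
    mask_feature: [C] learned mask parameter substituted for dropped tokens.
    """
    B, C, H, W = fused.shape
    # Random score per spatial token; the argmax within each group is kept.
    scores = torch.rand(B, 1, H, W, device=fused.device)
    s = scores.view(B, 1, H // group, group, W // group, group)
    s = s.permute(0, 1, 2, 4, 3, 5).reshape(B, 1, H // group, W // group, group * group)
    keep_idx = s.argmax(dim=-1, keepdim=True)
    keep = torch.zeros_like(s).scatter_(-1, keep_idx, 1.0).bool()
    # Unflatten the group dimension back to a [B, 1, H, W] keep-mask.
    keep = keep.view(B, 1, H // group, W // group, group, group)
    keep = keep.permute(0, 1, 2, 4, 3, 5).reshape(B, 1, H, W)
    mask = mask_feature.view(1, -1, 1, 1).to(fused.dtype)
    return torch.where(keep, fused, mask)
```

With `group=2` this keeps exactly 1 of every 4 tokens, i.e. the 75% masking ratio described above.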
60
+
61
+ ### 3.2 Inference-Time PDG via Token Masking
62
+
63
+ At inference, the trained mask_feature enables Path-Drop Guidance (PDG) through token-level masking rather than block-level dropping:
64
+
65
+ - **Conditional pass**: Full decoder input (no masking)
66
+ - **Unconditional pass**: Apply 2×2 groupwise token masking at the trained ratio (75%)
67
+ - **Guided output**: `x0 = x0_uncond + strength × (x0_cond − x0_uncond)`
68
+
69
+ Because the decoder has only 4 blocks and no skip connections, the guidance signal from token masking is very concentrated. This makes PDG extremely sensitive — even a strength of 1.2 produces noticeable sharpening, and values above 1.5 cause severe artifacts.
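In code, one guided x0 prediction per sampling step looks like the following, mirroring the two-pass scheme in `samplers.py` (`decoder` stands for any callable accepting the `mask_tokens` flag):

```python
def pdg_prediction(decoder, x_t, t, latents, strength: float = 1.1):
    # Unconditional pass: fused input with 75% of tokens masked.
    x0_uncond = decoder(x_t, t, latents, mask_tokens=True)
    # Conditional pass: full, unmasked input.
    x0_cond = decoder(x_t, t, latents, mask_tokens=False)
    # CFG-style extrapolation; keep strength in 1.05-1.2 for this model.
    return x0_uncond + strength * (x0_cond - x0_uncond)
```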
70
+
71
+ ## 4. Flat Decoder Architecture
72
+
73
+ ### 4.1 iRDiffAE v1 Decoder (for comparison)
74
+
75
+ The iRDiffAE v1 decoder uses an 8-block layout split into three groups with a skip connection:
76
+
77
+ ```
+ Fused input → Start blocks (2) → [save for skip] →
+ Middle blocks (4) → [cat with saved skip] → FuseSkip Conv1×1 →
+ End blocks (2) → Output head
+ ```
82
+
83
+ The skip connection concatenates the start-block output with the middle-block output and fuses them through a learned Conv1×1 before feeding into the end blocks. For PDG, the entire middle block computation is dropped and replaced with a broadcasted learned `mask_feature`, effectively removing all 4 middle blocks from the forward pass. This produces a coarse "unconditional" signal for classifier-free guidance.
84
+
85
+ ### 4.2 mDiffAE v1 Decoder
86
+
87
+ The mDiffAE decoder replaces this with a flat sequential architecture — no block groups, no skip connection:
88
+
89
+ ```
+ Input: x_t [B, 3, H, W], t [B], z [B, 64, h, w]
+
+ Patchify(x_t) → RMSNorm → x_feat [B, 896, h, w]
+ LatentUp(z) → RMSNorm → z_up [B, 896, h, w]
+ FuseIn(cat(x_feat, z_up)) → fused [B, 896, h, w]
+ [Optional: token masking for PDG]
+ TimeEmbed(t) → cond [B, 896]
+ Block_0(fused, AdaLN(cond)) → ...
+ Block_1(..., AdaLN(cond)) → ...
+ Block_2(..., AdaLN(cond)) → ...
+ Block_3(..., AdaLN(cond)) → out [B, 896, h, w]
+ RMSNorm → Conv1x1 → PixelShuffle → x0_hat [B, 3, H, W]
+ ```
103
+
104
+ With only 4 blocks and no skip fusion layer, the decoder has roughly half the parameters of iRDiffAE's decoder. The `fuse_skip` Conv1×1 layer is eliminated entirely. For PDG, instead of dropping blocks, 75% of spatial tokens in the fused input are replaced with a learned `mask_feature` before the blocks run. This token-level masking provides a finer-grained guidance signal but is much more sensitive to strength — the decoder sees the full block computation in both the conditional and unconditional paths, so the difference between them is subtle.
### 4.3 Bottleneck

The bottleneck dimension is halved from 128 channels (iRDiffAE) to 64 channels, giving a 12x compression ratio at patch size 16 (vs 6x for iRDiffAE). Despite the higher compression, the masking regularizer forces the encoder to produce informative per-token representations, maintaining reconstruction quality.
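The quoted ratios follow directly from the patch arithmetic: each 16×16 RGB patch holds 768 pixel values, which the encoder maps to one latent token.

```python
# Compression ratio of a patchify autoencoder: pixel values per patch
# divided by bottleneck channels per latent token.
patch_size, in_channels = 16, 3
pixels_per_patch = in_channels * patch_size ** 2   # 3 * 16 * 16 = 768
ratio_mdiffae = pixels_per_patch / 64              # 64-channel bottleneck -> 12.0
ratio_irdiffae = pixels_per_patch / 128            # 128-channel bottleneck -> 6.0
```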
109
+
110
+ ## 5. Model Configuration
111
+
112
+ | Parameter | Value |
113
+ |-----------|-------|
114
+ | `in_channels` | 3 |
115
+ | `patch_size` | 16 |
116
+ | `model_dim` | 896 |
117
+ | `encoder_depth` | 4 |
118
+ | `decoder_depth` | 4 |
119
+ | `bottleneck_dim` | 64 |
120
+ | `mlp_ratio` | 4.0 |
121
+ | `depthwise_kernel_size` | 7 |
122
+ | `adaln_low_rank_rank` | 128 |
123
+ | `logsnr_min` | βˆ’10.0 |
124
+ | `logsnr_max` | 10.0 |
125
+ | `pixel_noise_std` | 0.558 |
126
+ | `pdg_mask_ratio` | 0.75 |
127
+
128
+ Training checkpoint: step 708,000 (EMA weights).
## 6. Inference Recommendations

| Setting | Value | Notes |
|---------|-------|-------|
| Sampler | DDIM | Best for 1-step |
| Steps | 1 | PSNR-optimal |
| PDG | Disabled | Default, safest |
| PDG strength | 1.05–1.2 | If enabled; very sensitive |
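Since the decoder predicts `x0_hat` directly (see the architecture diagram above), single-step DDIM reconstruction reduces to one decoder call. The sketch below illustrates this; `decoder` is a stand-in callable, not the repo's API.

```python
import torch

def ddim_one_step(decoder, z, shape):
    """One-step DDIM when the model parameterizes x0 directly."""
    x_t = torch.randn(shape)        # start from pure noise at t = 1
    t = torch.ones(shape[0])        # a single timestep
    x0_hat = decoder(x_t, t, z)     # decoder outputs x0_hat directly
    return x0_hat                   # with one step, the prediction is the sample
```

More steps are possible, but per the table above a single step is PSNR-optimal for reconstruction.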
## 7. Results

Reconstruction quality is evaluated on a curated set of test images covering photographs, book covers, and documents. Flux.1 VAE (patch 8, 16 channels) is included as a reference at the same 12x compression ratio as the c64 variant.

### 7.1 Interactive Viewer

**[Open full-resolution comparison viewer](https://huggingface.co/spaces/data-archetype/mdiffae-results)**: side-by-side reconstructions, RGB deltas, and latent PCA with adjustable image size.

### 7.2 Inference Settings

| Setting | Value |
|---------|-------|
| Sampler | ddim |
| Steps | 1 |
| Schedule | linear |
| Seed | 42 |
| PDG | `no_path_dropg` |
| Batch size (timing) | 8 |

> All models run in bfloat16. Timings measured on an NVIDIA RTX Pro 6000 (Blackwell).
### 7.3 Global Metrics

| Metric | mdiffae_v1 (1 step) | Flux.1 VAE | Flux.2 VAE |
|--------|--------|--------|--------|
| Avg PSNR (dB) | 31.89 | 32.76 | 34.16 |
| Avg encode (ms/image) | 2.4 | 63.9 | 45.7 |
| Avg decode (ms/image) | 3.0 | 138.2 | 92.8 |

### 7.4 Per-Image PSNR (dB)

| Image | mdiffae_v1 (1 step) | Flux.1 VAE | Flux.2 VAE |
|-------|--------|--------|--------|
| p640x1536:94623 | 31.20 | 31.28 | 33.50 |
| p640x1536:94624 | 27.32 | 27.62 | 30.03 |
| p640x1536:94625 | 30.68 | 31.65 | 33.98 |
| p640x1536:94626 | 29.14 | 29.44 | 31.53 |
| p640x1536:94627 | 29.63 | 28.70 | 30.53 |
| p640x1536:94628 | 25.60 | 26.38 | 28.88 |
| p960x1024:216264 | 44.50 | 40.87 | 45.39 |
| p960x1024:216265 | 26.42 | 25.82 | 27.80 |
| p960x1024:216266 | 44.90 | 47.77 | 46.20 |
| p960x1024:216267 | 37.78 | 37.65 | 39.23 |
| p960x1024:216268 | 36.15 | 35.27 | 36.13 |
| p960x1024:216269 | 29.37 | 28.45 | 30.24 |
| p960x1024:216270 | 32.43 | 31.92 | 34.18 |
| p960x1024:216271 | 41.23 | 38.92 | 42.18 |
| p704x1472:94699 | 41.88 | 40.43 | 41.79 |
| p704x1472:94700 | 29.66 | 29.52 | 32.08 |
| p704x1472:94701 | 35.14 | 35.43 | 37.90 |
| p704x1472:94702 | 30.90 | 30.73 | 32.50 |
| p704x1472:94703 | 28.65 | 29.08 | 31.35 |
| p704x1472:94704 | 28.98 | 29.22 | 31.84 |
| p704x1472:94705 | 36.09 | 36.38 | 37.44 |
| p704x1472:94706 | 31.53 | 31.50 | 33.66 |
| r256_p1344x704:15577 | 27.89 | 28.32 | 29.98 |
| r256_p1344x704:15578 | 28.07 | 29.35 | 30.79 |
| r256_p1344x704:15579 | 29.56 | 30.44 | 31.83 |
| r256_p1344x704:15580 | 32.89 | 36.12 | 36.03 |
| r256_p1344x704:15581 | 32.26 | 37.42 | 36.94 |
| r256_p1344x704:15582 | 28.74 | 30.64 | 32.10 |
| r256_p1344x704:15583 | 31.99 | 34.67 | 34.54 |
| r256_p1344x704:15584 | 28.42 | 30.34 | 31.76 |
| r256_p896x1152:144131 | 30.02 | 33.10 | 33.60 |
| r256_p896x1152:144132 | 33.19 | 34.23 | 35.32 |
| r256_p896x1152:144133 | 35.42 | 37.85 | 37.33 |
| r256_p896x1152:144134 | 31.41 | 34.25 | 34.47 |
| r256_p896x1152:144135 | 27.13 | 28.17 | 29.87 |
| r256_p896x1152:144136 | 32.75 | 35.24 | 35.68 |
| r256_p896x1152:144137 | 28.60 | 32.70 | 32.86 |
| r256_p896x1152:144138 | 24.76 | 24.15 | 25.63 |
| VAE_accuracy_test_image | 31.52 | 36.69 | 35.25 |