data-archetype committed on
Commit
1b703d5
·
0 Parent(s):

Upload DINAC-AE export package

.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
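The patterns above route matching files through Git LFS. As a rough way to see which filenames they would catch, the sketch below approximates a few of the simple glob patterns with Python's `fnmatch`. This is only an illustration under stated assumptions: real gitattributes matching follows gitignore-style rules (for example for `saved_model/**/*`), which `fnmatch` does not fully reproduce, and the helper name `is_lfs_tracked` and the sample filenames are hypothetical.

```python
from fnmatch import fnmatchcase

# A subset of the basename-style patterns from the .gitattributes above.
LFS_PATTERNS = ["*.bin", "*.safetensors", "*.tar.*", "*.pt", "*tfevents*"]


def is_lfs_tracked(filename: str) -> bool:
    """Approximate check: does the basename match any simple LFS glob?"""
    basename = filename.rsplit("/", 1)[-1]
    return any(fnmatchcase(basename, pattern) for pattern in LFS_PATTERNS)


if __name__ == "__main__":
    # Hypothetical filenames, chosen only to exercise the patterns.
    for name in ["model.safetensors", "README.md", "common/norms.py"]:
        print(name, is_lfs_tracked(name))
```

Source files such as `README.md` and `common/norms.py` fall through every pattern, so under this approximation they would stay as ordinary Git blobs.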
README.md ADDED
@@ -0,0 +1,135 @@
+ ---
+ license: apache-2.0
+ tags:
+ - diffusion
+ - autoencoder
+ - image-reconstruction
+ - latent-space
+ - dino
+ - pytorch
+ ---
+
+ # data-archetype/dinac_ae
+
+ **DINAC-AE** is a **DIN**O-**A**ligned **C**lass-token **A**uto**E**ncoder.
+ It follows the [SemDisDiffAE](https://huggingface.co/data-archetype/semdisdiffae)
+ family: patch-16 spatial latents, a VP diffusion decoder, and DINO-aligned
+ representations.
+
+ Relative to SemDisDiffAE, DINAC-AE changes the encoder from FCDM blocks to a
+ 6-block ViT/DiT-style transformer encoder and uses DINOv3 ViT-B/16 alignment.
+ The latent-to-DINO alignment head is extended to predict the DINO class token
+ as well as patch tokens. `predict_class(latents)` exposes that class-token
+ feature directly from latents.
+
+ ## 2k PSNR Benchmark
+
+ | Model | Mean PSNR (dB) | Std (dB) | Median (dB) | P5 (dB) | P95 (dB) |
+ |---|---:|---:|---:|---:|---:|
+ | dinac_ae | `35.19` | `4.53` | `35.06` | `28.02` | `42.43` |
+ | FLUX.2 VAE | `36.28` | `4.53` | `36.07` | `28.89` | `43.63` |
+
+ Evaluated on `2000` validation images.
+
+ DINAC-AE targets a compromise between high reconstruction quality, a learnable
+ latent space with KL-like variance expansion, DINOv3 alignment, and robustness
+ to local token errors.
+
+ [Results viewer](https://huggingface.co/spaces/data-archetype/dinac_ae-results)
+ shows the 39-image reconstruction set with DINAC-AE and FLUX.2 VAE
+ reconstructions, RGB differences, and latent PCA.
+ The released export recheck on that 39-image set gives `35.15 dB` mean PSNR
+ (`25.73` min, `45.99` max).
+
+ [Full technical report](https://huggingface.co/data-archetype/dinac_ae/blob/main/technical_report_dinac_ae.md)
+
+ ## Encode Throughput
+
+ Measured on an `NVIDIA GeForce RTX 5090` in `bfloat16`, averaging repeated
+ batches per resolution.
+
+ | Resolution | Batch Size | dinac_ae encode (ms/batch) | FLUX.2 encode (ms/batch) | dinac_ae peak VRAM (MiB) | FLUX.2 peak VRAM (MiB) | Speedup vs FLUX.2 | Peak VRAM Reduction vs FLUX.2 |
+ |---:|---:|---:|---:|---:|---:|---:|---:|
+ | `256x256` | `128` | `50` | `383` | `1,637` | `12,511` | `7.62x` | `86.9%` |
+ | `512x512` | `32` | `53` | `354` | `1,639` | `12,511` | `6.72x` | `86.9%` |
+
+ The transformer encoder is slightly slower and larger than the full_capacitor
+ FCDM encoder, but remains much faster and much smaller than the FLUX.2 VAE
+ encoder.
+
+ ## Latent Interface
+
+ - `encode()` returns DINAC-AE's own whitened latent space.
+ - `decode()` expects that same whitened latent space and dewhitens internally.
+ - `predict_class()` expects the same whitened latent space, dewhitens
+   internally, and predicts a DINOv3 ViT-B/16 class-token feature.
+ - `whiten()` and `dewhiten()` are exposed for explicit control.
+ - `encode_posterior()` returns the raw exported posterior before whitening.
+ - `DinacAEInferenceConfig.num_steps` counts decoder evaluations directly:
+   `num_steps=1` means one NFE.
+
+ The export ships weights in `float32`. The recommended and default runtime path
+ is `bfloat16` AMP for the main encoder, decoder, and class-token path, with
+ `float32` retained for sensitive operations such as whitening/dewhitening,
+ normalization math, RoPE frequency construction, and VP diffusion schedule
+ helpers.
+
+ ## Usage
+
+ ```python
+ import torch
+
+ from dinac_ae import DinacAE, DinacAEInferenceConfig
+
+
+ device = "cuda"
+ model = DinacAE.from_pretrained(
+     "data-archetype/dinac_ae",
+     device=device,
+     dtype=torch.bfloat16,
+ )
+
+ image = ...  # [1, 3, H, W] in [-1, 1], H and W divisible by 16
+
+ with torch.inference_mode():
+     latents = model.encode(image.to(device=device, dtype=torch.bfloat16))
+     class_token = model.predict_class(latents)
+     recon = model.decode(
+         latents,
+         height=int(image.shape[-2]),
+         width=int(image.shape[-1]),
+         inference_config=DinacAEInferenceConfig(num_steps=1),
+     )
+ ```
+
+ ## Details
+
+ - DINAC-AE uses a `6`-block ViT/DiT-style transformer encoder and an `8`-block
+   FCDM decoder.
+ - Patch size is `16`, model width is `896`, and latent width is `128`.
+ - The DINO alignment head predicts spatial patch tokens and is extended with a
+   class-token output in DINOv3 ViT-B/16 feature space.
+ - The class-token output is used to improve semantic organization of the latent
+   space and to support FD-loss / Representation Frechet Distance objectives
+   directly in latent space.
+ - `predict_class(latents)` reaches mean cosine similarity `0.757458` against
+   the frozen DINOv3 ViT-B/16 teacher class token on the same `2000` images.
+ - DINO alignment is applied directly to clean latent tokens. Robustness to
+   local token errors is handled by random-token logSNR offset regularization.
+ - Results viewer: https://huggingface.co/spaces/data-archetype/dinac_ae-results
+ - Related: [SemDisDiffAE](https://huggingface.co/data-archetype/semdisdiffae),
+   [full_capacitor](https://huggingface.co/data-archetype/full_capacitor),
+   [capacitor_decoder](https://huggingface.co/data-archetype/capacitor_decoder)
+
+ ## Citation
+
+ ```bibtex
+ @misc{dinac_ae,
+   title = {DINAC-AE: a DINO-aligned class-token diffusion autoencoder},
+   author = {data-archetype},
+   email = {data-archetype@proton.me},
+   year = {2026},
+   month = may,
+   url = {https://huggingface.co/data-archetype/dinac_ae},
+ }
+ ```
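The PSNR figures in the README are reported in dB for images scaled to `[-1, 1]`. The README does not show the evaluation code, so the following is only a sketch of the standard PSNR definition, under the assumption that the peak-to-peak data range is 2. The function name `psnr_db` and the flat-list image representation are hypothetical.

```python
import math
from collections.abc import Sequence


def psnr_db(
    reference: Sequence[float],
    recon: Sequence[float],
    data_range: float = 2.0,  # assumption: pixels live in [-1, 1]
) -> float:
    """Standard PSNR in dB: 10 * log10(data_range**2 / MSE)."""
    if len(reference) != len(recon) or len(reference) == 0:
        raise ValueError("inputs must be non-empty and the same length")
    mse = sum((a - b) ** 2 for a, b in zip(reference, recon)) / len(reference)
    if mse == 0.0:
        return math.inf  # identical signals
    return 10.0 * math.log10(data_range**2 / mse)


if __name__ == "__main__":
    # Hypothetical four-pixel example with a uniform error of 0.02.
    print(psnr_db([0.0, 0.0, 0.0, 0.0], [0.02, 0.02, 0.02, 0.02]))
```

With this convention, a uniform per-pixel error of 0.02 corresponds to roughly 40 dB, which gives a feel for the scale of the 35 to 36 dB means in the table.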
common/norms.py ADDED
@@ -0,0 +1,236 @@
+ from __future__ import annotations
+
+ from collections.abc import Sequence
+
+ import torch
+ from torch import Tensor, nn
+ from torch.nn import functional as F
+
+ __all__ = [
+     "ChannelWiseRMSNorm",
+     "GlobalRMSNorm",
+     "GroupNormF32",
+     "LayerNorm",
+     "LayerNorm2d",
+     "RMSNorm",
+     "global_rms_norm",
+     "row_norm",
+ ]
+
+
+ _HALF_PRECISION_DTYPES: tuple[torch.dtype, ...] = (torch.float16, torch.bfloat16)
+
+
+ def _cast_to_float32(x: Tensor) -> tuple[Tensor, torch.dtype]:
+     """Return tensor cast to fp32 for compute along with the original dtype."""
+     dtype = x.dtype
+     if dtype in _HALF_PRECISION_DTYPES:
+         return x.float(), dtype
+     return x, dtype
+
+
+ def _restore_dtype(x: Tensor, dtype: torch.dtype) -> Tensor:
+     return x if x.dtype == dtype else x.to(dtype)
+
+
+ class RMSNorm(nn.Module):
+     """Thin wrapper around ``torch.nn.RMSNorm`` that preserves our API.
+
+     - Keeps an ``_eps`` attribute used by tests.
+     - Maps ``affine`` -> ``elementwise_affine``.
+     - Delegates all compute to the native implementation.
+
+     Notes on precision
+     - PyTorch ≥ 2.8 computes RMSNorm reductions in ``opmath`` dtype
+       (float32 for float16/bfloat16) internally, then restores the input dtype.
+     """
+
+     def __init__(self, dim: int, eps: float = 1e-6, affine: bool = True) -> None:
+         super().__init__()
+         self._eps: float = float(eps)
+         self._impl: nn.RMSNorm = nn.RMSNorm(
+             dim, eps=self._eps, elementwise_affine=affine
+         )
+         self._dim: int = int(dim)
+
+     @property
+     def weight(self) -> Tensor | None:  # expose for tests/compat
+         return self._impl.weight
+
+     def forward(self, x: Tensor) -> Tensor:  # type: ignore[override]
+         """Apply RMSNorm while avoiding dtype-mismatch warnings under AMP.
+
+         When inputs are bfloat16/float16 under autocast and the stored affine
+         weight is float32 (common when model weights remain FP32), PyTorch emits
+         a warning about mismatched dtypes and disables the fused path.
+
+         We pass a view of the weight cast to the input dtype into the functional
+         RMSNorm to enable the fused implementation without changing the
+         parameter storage dtype (which remains FP32 for stability).
+         """
+         # Prefer functional to control the weight dtype for the kernel
+         w: Tensor | None = self._impl.weight
+         w_cast = w.to(dtype=x.dtype) if w is not None else None
+         # Bias is not present in RMSNorm; functional takes (input, shape, weight, eps)
+         return F.rms_norm(x, (self._dim,), w_cast, self._eps)
+
+
+ class LayerNorm(nn.LayerNorm):
+     """Thin wrapper over ``torch.nn.LayerNorm`` with an ``_eps`` attribute.
+
+     Notes on precision
+     - Native LayerNorm kernels accumulate statistics in ``opmath`` dtype
+       (float32 for float16/bfloat16) before casting results back.
+     """
+
+     def __init__(
+         self,
+         normalized_shape: int | Sequence[int],
+         eps: float = 1e-6,
+         elementwise_affine: bool = True,
+     ) -> None:
+         shape: int | list[int]
+         match normalized_shape:
+             case int() as dim:
+                 shape = dim
+             case _:
+                 shape = [int(v) for v in normalized_shape]
+         super().__init__(shape, eps=eps, elementwise_affine=elementwise_affine)
+         self._eps: float = float(eps)
+
+     def forward(self, x: Tensor) -> Tensor:  # type: ignore[override]
+         # Delegate to native LayerNorm
+         return super().forward(x)
+
+
+ # Prefer numerically stable GroupNormF32 below.
+
+
+ class GroupNormF32(nn.GroupNorm):
+     """Thin wrapper over ``torch.nn.GroupNorm`` with an ``_eps`` attribute.
+
+     Notes on precision
+     - Native GroupNorm uses ``opmath`` accumulation (float32 for
+       float16/bfloat16) for statistics and fused scale/bias math; results
+       are cast back to the input dtype.
+     - Despite the class name, this wrapper does not force a cast; it
+       delegates to the native implementation.
+     """
+
+     def __init__(
+         self,
+         num_groups: int,
+         num_channels: int,
+         eps: float = 1e-6,
+         affine: bool = True,
+     ) -> None:
+         super().__init__(num_groups, num_channels, eps=eps, affine=affine)
+         self._eps: float = float(eps)
+
+
+ class ChannelWiseRMSNorm(nn.Module):
+     """Channel-wise RMSNorm for NCHW tensors (fast NCHW path).
+
+     - Normalizes across channels per spatial position without reshaping, using
+       a float32 reduction for numerical stability and keeping elementwise ops
+       in input dtype for throughput.
+     - Supports optional per-channel affine ``weight`` and ``bias``.
+     """
+
+     def __init__(self, channels: int, eps: float = 1e-6, affine: bool = True) -> None:
+         super().__init__()
+         self.channels: int = int(channels)
+         self._eps: float = float(eps)
+         self.affine: bool = bool(affine)
+         if self.affine:
+             self.weight = nn.Parameter(torch.ones(self.channels))
+             self.bias = nn.Parameter(torch.zeros(self.channels))
+         else:
+             self.register_parameter("weight", None)
+             self.register_parameter("bias", None)
+
+     def forward(self, x: Tensor) -> Tensor:  # type: ignore[override]
+         if x.dim() < 2:
+             return x
+         C = x.size(1)
+         if self.channels != C:
+             raise ValueError(f"ChannelWiseRMSNorm expected C={self.channels}, got {C}")
+         # Keep only the reductions in fp32; scale/apply in the input dtype.
+         ms = torch.mean(torch.square(x), dim=1, keepdim=True, dtype=torch.float32)
+         inv_rms = torch.rsqrt(ms + self._eps)  # float32
+         y = x * inv_rms.to(dtype=x.dtype)
+         if self.affine and self.weight is not None:
+             shape = (1, -1) + (1,) * (x.dim() - 2)
+             y = y * self.weight.view(shape).to(dtype=x.dtype)
+             if self.bias is not None:
+                 y = y + self.bias.view(shape).to(dtype=x.dtype)
+         return y
+
+
+ def global_rms_norm(x: Tensor, eps: float = 1e-6) -> Tensor:
+     """Project each sample to unit RMS across all non-batch dimensions.
+
+     This is equivalent to RMSNorm with ``normalized_shape=x.shape[1:]`` and no
+     affine parameters. Delegating to the native functional keeps the fast fused
+     CUDA path and the same opmath accumulation behavior as ``torch.nn.RMSNorm``.
+     """
+
+     if x.dim() < 2:
+         return x
+     normalized_shape = tuple(int(dim) for dim in x.shape[1:])
+     return F.rms_norm(x, normalized_shape, None, eps)
+
+
+ class GlobalRMSNorm(nn.Module):
+     """RMSNorm across all dims except batch — sphere projection for NCHW tensors.
+
+     Unlike :class:`ChannelWiseRMSNorm` (which normalizes per spatial position
+     over channels), this normalizes the *entire* feature volume jointly,
+     projecting each sample onto a hypersphere. No learnable parameters.
+     """
+
+     def __init__(self, eps: float = 1e-6) -> None:
+         super().__init__()
+         self._eps: float = float(eps)
+
+     def forward(self, x: Tensor) -> Tensor:  # type: ignore[override]
+         return global_rms_norm(x, eps=self._eps)
+
+
+ class LayerNorm2d(nn.LayerNorm):
+     """Channel-wise LayerNorm using native ``F.layer_norm`` on a reshaped view.
+
+     - Normalizes over channels only for each spatial location (B, h, w).
+     - Weight and bias follow the base class semantics (shape [C]).
+
+     Notes on precision
+     - ``F.layer_norm`` calls the native LayerNorm kernel which accumulates in
+       ``opmath`` dtype (float32 for float16/bfloat16), then casts back.
+     """
+
+     def forward(self, x: Tensor) -> Tensor:  # type: ignore[override]
+         if x.dim() < 3:
+             return super().forward(x)
+         B, C = x.shape[:2]
+         spatial = x.shape[2:]
+         x_view = x.permute(0, *range(2, x.dim()), 1).contiguous().view(-1, C)
+         y = F.layer_norm(x_view, (C,), self.weight, self.bias, self.eps)
+         y = y.view(B, *spatial, C).permute(0, x.dim() - 1, *range(1, x.dim() - 1))
+         return y.contiguous()
+
+
+ def row_norm(W: Tensor, eps: float = 1e-6) -> Tensor:
+     """Row-normalise weight matrices along the last dimension.
+
+     Precision and performance
+     - Accumulates the squared sum in float32 without materializing a full fp32
+       copy of ``W`` via ``sum(..., dtype=torch.float32)``.
+     - Uses ``rsqrt`` and clamps the inverse norm via ``clamp_max(1/eps)`` to
+       match ``clamp_min(eps)`` on the denominator.
+     - Scales in the input dtype for throughput; callers relying on exact
+       float32 scaling should cast explicitly.
+     """
+     # Sum of squares in fp32 for stability
+     ss = torch.sum(torch.square(W), dim=-1, keepdim=True, dtype=torch.float32)
+     inv = torch.rsqrt(ss).clamp_max(1.0 / float(eps))  # float32
+     return W * inv.to(dtype=W.dtype)
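To make the difference between `ChannelWiseRMSNorm` and `global_rms_norm` concrete, here is a dependency-free sketch on a tiny nested list standing in for one sample of shape `(C, N)` (channels by flattened spatial positions). It deliberately ignores the affine parameters and the mixed-precision handling of the module above, and all helper names are hypothetical.

```python
import math

EPS = 1e-6


def channel_wise_rms_norm(sample: list[list[float]], eps: float = EPS) -> list[list[float]]:
    """Scale each spatial position to unit RMS across channels."""
    num_channels = len(sample)
    num_positions = len(sample[0])
    out = [[0.0] * num_positions for _ in range(num_channels)]
    for n in range(num_positions):
        mean_sq = sum(sample[c][n] ** 2 for c in range(num_channels)) / num_channels
        inv_rms = 1.0 / math.sqrt(mean_sq + eps)
        for c in range(num_channels):
            out[c][n] = sample[c][n] * inv_rms
    return out


def global_rms_norm_sketch(sample: list[list[float]], eps: float = EPS) -> list[list[float]]:
    """Scale the whole sample to unit RMS over all elements jointly."""
    flat = [v for row in sample for v in row]
    mean_sq = sum(v * v for v in flat) / len(flat)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [[v * inv_rms for v in row] for row in sample]


def rms(values: list[float]) -> float:
    """Root mean square of a flat list."""
    return math.sqrt(sum(v * v for v in values) / len(values))


if __name__ == "__main__":
    # Two channels, two positions; position 1 has much smaller magnitude.
    sample = [[3.0, 0.1], [4.0, 0.2]]
    cw = channel_wise_rms_norm(sample)
    gl = global_rms_norm_sketch(sample)
    print("channel-wise RMS at position 1:", rms([cw[0][1], cw[1][1]]))
    print("global RMS at position 1:", rms([gl[0][1], gl[1][1]]))
```

The channel-wise variant equalizes magnitude at every position, while the global variant preserves relative magnitudes between positions and only fixes the overall scale of the sample.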
common/rope.py ADDED
@@ -0,0 +1,536 @@
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from collections.abc import Callable
5
+
6
+ import torch
7
+ from torch import nn
8
+
9
+
10
+ class Rope1D(nn.Module):
11
+ """
12
+ Rotary Position Embedding (RoPE) 1D.
13
+
14
+ Based on the reference LLaMA implementation (Hugging Face
15
+ `modeling_llama.py`), adapted to this codebase without behavior changes.
16
+
17
+ - dim: per-head dimension
18
+ - max_position_embeddings: length used to precompute cached cos/sin (not required
19
+ by forward)
20
+ - base: RoPE base theta
21
+
22
+ Forward expects:
23
+ - x: (B, H, T, D)
24
+ - position_ids: (B, T) integer positions
25
+ Returns:
26
+ - cos, sin: (B, T, D)
27
+ """
28
+
29
+ inv_freq: torch.Tensor
30
+ _cos_cached: torch.Tensor
31
+ _sin_cached: torch.Tensor
32
+
33
+ def __init__(
34
+ self,
35
+ dim: int,
36
+ max_position_embeddings: int = 2048,
37
+ base: float = 10000.0,
38
+ device: torch.device | None = None,
39
+ scaling_factor: float = 1.0,
40
+ ) -> None:
41
+ super().__init__()
42
+ if dim % 2 != 0:
43
+ raise AssertionError("head_dim must be even for RoPE")
44
+ self.scaling_factor: float = float(scaling_factor)
45
+ self.dim: int = int(dim)
46
+ self.max_position_embeddings: int = int(max_position_embeddings)
47
+ self.base: float = float(base)
48
+ inv_freq = self._build_inv_freq(device=device)
49
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
50
+
51
+ # Cached cos/sin (not used in application, but kept for parity with reference)
52
+ self.max_seq_len_cached: int = self.max_position_embeddings
53
+ cos_cached, sin_cached = self._build_cached_trig(device=device)
54
+ self.register_buffer("_cos_cached", cos_cached, persistent=False)
55
+ self.register_buffer("_sin_cached", sin_cached, persistent=False)
56
+
57
+ def _build_inv_freq(self, *, device: torch.device | None) -> torch.Tensor:
58
+ """Return the RoPE inverse-frequency vector in float32."""
59
+
60
+ return 1.0 / (
61
+ self.base
62
+ ** (
63
+ torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
64
+ / float(self.dim)
65
+ )
66
+ )
67
+
68
+ def _build_cached_trig(
69
+ self, *, device: torch.device | None
70
+ ) -> tuple[torch.Tensor, torch.Tensor]:
71
+ """Return cached RoPE trig tensors in float32."""
72
+
73
+ inv_freq = self._build_inv_freq(device=device)
74
+ t = torch.arange(
75
+ self.max_seq_len_cached,
76
+ device=device,
77
+ dtype=torch.float32,
78
+ )
79
+ t = t / self.scaling_factor
80
+ freqs = torch.outer(t, inv_freq)
81
+ emb = torch.cat((freqs, freqs), dim=-1)
82
+ return emb.cos(), emb.sin()
83
+
84
+ def _apply(
85
+ self,
86
+ fn: Callable[[torch.Tensor], torch.Tensor],
87
+ recurse: bool = True,
88
+ ) -> Rope1D:
89
+ """Apply module moves/casts while preserving fp32 RoPE buffers."""
90
+
91
+ out = super()._apply(fn, recurse=recurse)
92
+ with torch.no_grad():
93
+ device = self.inv_freq.device
94
+ self.inv_freq.data = self._build_inv_freq(device=device)
95
+ cos_cached, sin_cached = self._build_cached_trig(device=device)
96
+ self._cos_cached.data = cos_cached
97
+ self._sin_cached.data = sin_cached
98
+ return out
99
+
100
+ @torch.no_grad()
101
+ def forward(
102
+ self, x: torch.Tensor, position_ids: torch.Tensor
103
+ ) -> tuple[torch.Tensor, torch.Tensor]:
104
+ inv_freq_tensor = self._build_inv_freq(device=x.device)
105
+ inv_freq_expanded = (
106
+ inv_freq_tensor[None, :, None].float().expand(position_ids.shape[0], -1, 1)
107
+ )
108
+ position_ids_expanded = position_ids[:, None, :].float() / self.scaling_factor
109
+ device_type = x.device.type
110
+ device_type = (
111
+ device_type
112
+ if isinstance(device_type, str) and device_type != "mps"
113
+ else "cpu"
114
+ )
115
+ with torch.autocast(device_type=device_type, enabled=False):
116
+ freqs = (
117
+ inv_freq_expanded.float() @ position_ids_expanded.float()
118
+ ).transpose(1, 2)
119
+ emb = torch.cat((freqs, freqs), dim=-1)
120
+ cos = emb.cos()
121
+ sin = emb.sin()
122
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
123
+
124
+
125
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
126
+ x1 = x[..., : x.shape[-1] // 2]
127
+ x2 = x[..., x.shape[-1] // 2 :]
128
+ return torch.cat((-x2, x1), dim=-1)
129
+
130
+
131
+ def rotate_half_adjacent(x: torch.Tensor) -> torch.Tensor:
132
+ """Rotate consecutive pairs in the last dimension.
133
+
134
+ This matches the common EVA-02 / SpeedrunDiT RoPE convention where the last
135
+ dimension is interpreted as pairs ``(x0, x1), (x2, x3), ...``.
136
+ """
137
+ if x.shape[-1] % 2 != 0:
138
+ raise ValueError("rotate_half_adjacent requires an even last dimension")
139
+ x_pairs = x.reshape(*x.shape[:-1], x.shape[-1] // 2, 2)
140
+ x1 = x_pairs[..., 0]
141
+ x2 = x_pairs[..., 1]
142
+ return torch.stack((-x2, x1), dim=-1).reshape_as(x)
143
+
144
+
145
+ def apply_rotary_pos_emb(
146
+ q: torch.Tensor,
147
+ k: torch.Tensor,
148
+ cos: torch.Tensor,
149
+ sin: torch.Tensor,
150
+ *,
151
+ unsqueeze_dim: int = 1,
152
+ ) -> tuple[torch.Tensor, torch.Tensor]:
153
+ cos = cos.unsqueeze(unsqueeze_dim)
154
+ sin = sin.unsqueeze(unsqueeze_dim)
155
+ q_embed = (q * cos) + (rotate_half(q) * sin)
156
+ k_embed = (k * cos) + (rotate_half(k) * sin)
157
+ return q_embed, k_embed
158
+
159
+
160
+ class LearnableRoPE2D(nn.Module):
161
+ r"""
162
+ Learnable mixed 2D RoPE with axial RoPE2D-compatible initialization.
163
+
164
+ - Learnable frequency banks for X and Y.
165
+ - Frequencies can be shared across groups of attention heads (see
166
+ ``rope_param_dim``).
167
+ - Angle per pair: theta = x * fx[g, i] + y * fy[g, i]
168
+ - Initialization matches the axial RoPE2D parameterization used by DiTTrunk
169
+ for ``ROPE_2D_AXIAL_FREQ_AWARE`` (AxialRoPE2DConfig(base=100, dim_layout=HALF_SPLIT)):
170
+ - Angle multiplier ``2π``.
171
+ - Period base ``100`` (DINOv3-style), applied per-axis.
172
+ Each head group starts identically (deterministic init) so the learnable
173
+ variant is functionally identical to axial RoPE2D at step 0.
174
+ - Rotation is implemented with real-valued sin/cos to avoid complex tensors
175
+ (torch.compile/inductor cannot codegen complex dtypes).
176
+
177
+ Shapes:
178
+ - Expects q,k of shape (B, H, T, D) with D % 4 == 0.
179
+ - Positions xy: (T, 2) or (B, T, 2), any real dtype (cast to float32).
180
+ - Parameter `freqs`: (2, G, D//2) in float32; index 0 = x, 1 = y.
181
+
182
+ Head grouping / parameter budget
183
+ -------------------------------
184
+ ``rope_param_dim`` controls the total number of learned RoPE frequency
185
+ parameters (scalars) for this module.
186
+
187
+ Let:
188
+ - ``head_dim = D`` (per-head width)
189
+ - ``num_heads = H``
190
+ - ``rope_param_dim = P``
191
+
192
+ Then the module uses:
193
+ - ``num_groups = G = P // D``
194
+ - ``heads_per_group = H // G``
195
+
196
+ This is fail-fast: ``P`` must be divisible by ``D`` and ``H`` must be
197
+ divisible by ``G``. When ``rope_param_dim`` is None (default), the module
198
+ uses the classic per-head parameterization with ``P = H * D``.
199
+ """
200
+
201
+ def __init__(
202
+ self,
203
+ head_dim: int,
204
+ *,
205
+ num_heads: int,
206
+ rope_param_dim: int | None = None,
207
+ rope_base: float = 100.0,
208
+ angle_multiplier: float = 2.0 * float(math.pi),
209
+ learnable: bool = True,
210
+ persist_buffers: bool = True,
211
+ ) -> None:
212
+ super().__init__()
213
+ if head_dim % 4 != 0:
214
+ raise AssertionError("head_dim must be divisible by 4 for mixed 2D RoPE")
215
+ self.head_dim: int = int(head_dim)
216
+ # Avoid naming collisions with nn.Module.half() (dtype casting helper).
217
+ self.half_dim: int = self.head_dim // 2
218
+ self.num_heads: int = int(num_heads)
219
+ effective_param_dim = (
220
+ int(rope_param_dim)
221
+ if rope_param_dim is not None
222
+ else self.num_heads * self.head_dim
223
+ )
224
+ if effective_param_dim <= 0:
225
+ raise ValueError("rope_param_dim must be positive for LearnableRoPE2D")
226
+ self.rope_param_dim: int = int(effective_param_dim)
227
+ self._learnable: bool = bool(learnable)
228
+ theta = float(rope_base)
229
+ mult = float(angle_multiplier)
230
+ if not math.isfinite(theta) or theta <= 0.0:
231
+ raise ValueError("rope_base must be finite and > 0 for LearnableRoPE2D")
232
+ if not math.isfinite(mult) or mult <= 0.0:
233
+ raise ValueError(
234
+ "angle_multiplier must be finite and > 0 for LearnableRoPE2D"
235
+ )
236
+
237
+ if self.rope_param_dim % self.head_dim != 0:
238
+ raise ValueError(
239
+ "rope_param_dim must be divisible by head_dim for LearnableRoPE2D "
240
+ f"(got rope_param_dim={self.rope_param_dim}, head_dim={self.head_dim})"
241
+ )
242
+ self.num_groups: int = self.rope_param_dim // self.head_dim
243
+ if self.num_groups <= 0:
244
+ raise RuntimeError("num_groups must be positive for LearnableRoPE2D")
245
+ if self.num_heads % self.num_groups != 0:
246
+ raise ValueError(
247
+ "num_heads must be divisible by (rope_param_dim / head_dim) for LearnableRoPE2D "
248
+ f"(got num_heads={self.num_heads}, num_groups={self.num_groups}, "
249
+ f"rope_param_dim={self.rope_param_dim}, head_dim={self.head_dim})"
250
+ )
251
+ self.heads_per_group: int = self.num_heads // self.num_groups
252
+ if self.heads_per_group <= 0:
253
+ raise RuntimeError("heads_per_group must be positive for LearnableRoPE2D")
254
+
255
+ # Axial-compatible deterministic init:
256
+ # - periods match AxialRoPE2DConfig(base=100, dim_layout=HALF_SPLIT)
257
+ # - angle = 2π * coord / period
258
+ qtr = self.head_dim // 4
259
+ exponents = (
260
+ 2.0
261
+ * torch.arange(int(qtr), dtype=torch.float32)
262
+ / float(self.head_dim // 2)
263
+ )
264
+ periods = torch.tensor(theta, dtype=torch.float32) ** exponents # [qtr]
265
+ axis_freqs = (mult / periods).to(dtype=torch.float32) # [qtr]
266
+
267
+ zeros = torch.zeros_like(axis_freqs)
268
+ # Match AxialRoPE2D(HALF_SPLIT) flatten order: [y-axis, x-axis].
269
+ # Our xy columns are (x, y), so:
270
+ # - x contributes to the second quarter (x-axis part)
271
+ # - y contributes to the first quarter (y-axis part)
272
+ fx_half = torch.cat((zeros, axis_freqs), dim=0) # [half_dim]
273
+ fy_half = torch.cat((axis_freqs, zeros), dim=0) # [half_dim]
274
+
275
+ freqs_x = fx_half.expand(int(self.num_groups), -1).clone()
276
+ freqs_y = fy_half.expand(int(self.num_groups), -1).clone()
277
+ freqs = torch.stack([freqs_x, freqs_y], dim=0) # (2, G, half)
278
+ if self._learnable:
279
+ self.freqs = nn.Parameter(freqs, requires_grad=True)
280
+ else:
281
+ self.register_buffer("freqs", freqs, persistent=persist_buffers)
282
+
283
+ def _apply(
284
+ self,
285
+ fn: Callable[[torch.Tensor], torch.Tensor],
286
+ recurse: bool = True,
287
+ ) -> LearnableRoPE2D:
288
+ """Apply module moves/casts while preserving fp32 frequency tensors."""
289
+
290
+ out = super()._apply(fn, recurse=recurse)
291
+ with torch.no_grad():
292
+ self.freqs.data = self.freqs.data.to(dtype=torch.float32)
293
+ return out
294
+
295
+ def _apply_rotary_from_trig(
296
+ self,
297
+ x: torch.Tensor,
298
+ *,
299
+ sin: torch.Tensor,
300
+ cos: torch.Tensor,
301
+ ) -> torch.Tensor:
302
+ """Rotate Q/K using precomputed grouped sin/cos buffers (HALF_SPLIT layout).
303
+
304
+ This matches AxialRoPE2DConfig(dim_layout=HALF_SPLIT) rotation and keeps
305
+ the learnable variant identical at initialization when combined with
306
+ axial-compatible frequency init.
307
+
308
+ Args:
309
+ x: Tensor shaped ``(B, H, T, D)``.
310
+ sin: Sin tensor shaped ``(G, T, D//2)`` or ``(B, G, T, D//2)``.
311
+ cos: Cos tensor shaped ``(G, T, D//2)`` or ``(B, G, T, D//2)``.
312
+
313
+ Returns:
314
+ Tensor with the same shape/dtype/device as ``x``.
315
+ """
316
+ if x.dim() != 4:
317
+ raise ValueError("x must be shaped (B, H, T, D)")
318
+ B, H, T, D = x.shape
319
+ if self.num_heads != int(H):
320
+ raise ValueError("num_heads mismatch for LearnableRoPE2D")
321
+ if self.head_dim != int(D):
322
+ raise ValueError("head_dim mismatch for LearnableRoPE2D")
323
+
324
+ if sin.dim() == 3 and cos.dim() == 3:
325
+ sin = sin.unsqueeze(0)
326
+ cos = cos.unsqueeze(0)
327
+ if sin.dim() != 4 or cos.dim() != 4:
328
+ raise RuntimeError("Unexpected sin/cos rank for LearnableRoPE2D")
329
+ if int(D) % 2 != 0:
330
+ raise RuntimeError("LearnableRoPE2D requires even head_dim for HALF_SPLIT")
331
+ half = int(D) // 2
332
+ if int(sin.shape[-1]) != half or int(cos.shape[-1]) != half:
333
+ raise RuntimeError(
334
+ "LearnableRoPE2D expected sin/cos last dim == head_dim//2 "
335
+ f"(got sin={tuple(sin.shape)}, cos={tuple(cos.shape)}, head_dim={int(D)})"
336
+ )
337
+
338
+ sin = sin[:, :, None, :, :] # [B, G, 1, T, half]
339
+ cos = cos[:, :, None, :, :] # [B, G, 1, T, half]
340
+
341
+ grouped = x.reshape(
342
+ int(B),
343
+ int(self.num_groups),
344
+ int(self.heads_per_group),
345
+ int(T),
346
+ int(D),
347
+ )
348
+ x1 = grouped[..., :half]
349
+ x2 = grouped[..., half:]
350
+ out1 = x1 * cos - x2 * sin
351
+ out2 = x2 * cos + x1 * sin
352
+ out = torch.cat((out1, out2), dim=-1).reshape(int(B), int(H), int(T), int(D))
353
+ return out.to(dtype=x.dtype)
354
+
+     def _compute_mixed_cis(self, xy: torch.Tensor) -> torch.Tensor:
+         # Returns complex cis(angle) = exp(1j * angle), shaped (G, T, half) or (B, G, T, half)
+         if xy.dim() == 2:
+             # (T, 2) -> (G, T, half)
+             t_x = xy[:, 0].to(dtype=torch.float32)
+             t_y = xy[:, 1].to(dtype=torch.float32)
+             with torch.autocast(device_type=t_x.device.type, enabled=False):
+                 # Memory notes:
+                 # - Avoid materializing both fx and fy; accumulate in-place into angles.
+                 # - Avoid torch.ones_like(angles) (full-size allocation); a scalar
+                 #   magnitude broadcasts in torch.polar.
+                 angles = t_x.unsqueeze(-1).unsqueeze(-1) * self.freqs[0].unsqueeze(
+                     0
+                 )  # (T, G, half)
+                 angles.add_(
+                     t_y.unsqueeze(-1).unsqueeze(-1) * self.freqs[1].unsqueeze(0)
+                 )
+                 angles = angles.permute(1, 0, 2)  # (G, T, half)
+                 cis = torch.polar(
+                     torch.ones((), device=angles.device, dtype=angles.dtype), angles
+                 )
+             return cis
+         elif xy.dim() == 3:
+             # (B, T, 2) -> (B, G, T, half)
+             t_x = xy[..., 0].to(dtype=torch.float32)
+             t_y = xy[..., 1].to(dtype=torch.float32)
+             with torch.autocast(device_type=t_x.device.type, enabled=False):
+                 angles = t_x.unsqueeze(-1).unsqueeze(-1) * self.freqs[0].unsqueeze(
+                     0
+                 ).unsqueeze(0)
+                 angles.add_(
+                     t_y.unsqueeze(-1).unsqueeze(-1)
+                     * self.freqs[1].unsqueeze(0).unsqueeze(0)
+                 )
+                 angles = angles.permute(0, 2, 1, 3)  # (B, G, T, half)
+                 cis = torch.polar(
+                     torch.ones((), device=angles.device, dtype=angles.dtype), angles
+                 )
+             return cis
+         else:
+             raise ValueError("xy must have shape (T,2) or (B,T,2)")
+
+     def _compute_mixed_angles(self, xy: torch.Tensor) -> torch.Tensor:
+         """Return mixed RoPE2D angles without applying cis/polar.
+
+         Args:
+             xy: XY positions shaped ``(T, 2)`` or ``(B, T, 2)``.
+
+         Returns:
+             Float tensor of angles shaped ``(G, T, half)`` or ``(B, G, T, half)``.
+         """
+         if xy.dim() == 2:
+             t_x = xy[:, 0].to(dtype=torch.float32)
+             t_y = xy[:, 1].to(dtype=torch.float32)
+             with torch.autocast(device_type=t_x.device.type, enabled=False):
+                 angles = t_x.unsqueeze(-1).unsqueeze(-1) * self.freqs[0].unsqueeze(0)
+                 angles.add_(
+                     t_y.unsqueeze(-1).unsqueeze(-1) * self.freqs[1].unsqueeze(0)
+                 )
+             return angles.permute(1, 0, 2)
+         if xy.dim() == 3:
+             t_x = xy[..., 0].to(dtype=torch.float32)
+             t_y = xy[..., 1].to(dtype=torch.float32)
+             with torch.autocast(device_type=t_x.device.type, enabled=False):
+                 angles = t_x.unsqueeze(-1).unsqueeze(-1) * self.freqs[0].unsqueeze(
+                     0
+                 ).unsqueeze(0)
+                 angles.add_(
+                     t_y.unsqueeze(-1).unsqueeze(-1)
+                     * self.freqs[1].unsqueeze(0).unsqueeze(0)
+                 )
+             return angles.permute(0, 2, 1, 3)
+         raise ValueError("xy must have shape (T,2) or (B,T,2)")
+
+     def _cos_sin_half_from_xy(
+         self,
+         xy: torch.Tensor,
+         *,
+         device: torch.device | None = None,
+         out_dtype: torch.dtype | None = None,
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         # Helper used in tests to build real-valued cos/sin tensors.
+         cis = self._compute_mixed_cis(xy.to(device=device) if device else xy)
+         # Convert complex cis to cos/sin (real/imag) with matching shapes
+         if cis.is_complex():
+             cos_h = cis.real
+             sin_h = cis.imag
+         else:
+             # Should not happen; torch.polar returns complex64/128
+             raise RuntimeError("Expected complex cis tensor from polar")
+         if out_dtype is not None:
+             cos_h = cos_h.to(dtype=out_dtype)
+             sin_h = sin_h.to(dtype=out_dtype)
+         return cos_h, sin_h
+
+     def _cos_sin_from_xy(
+         self,
+         xy: torch.Tensor,
+         *,
+         device: torch.device | None = None,
+         out_dtype: torch.dtype | None = None,
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         cos_h, sin_h = self._cos_sin_half_from_xy(
+             xy, device=device, out_dtype=out_dtype
+         )
+         emb_cos = torch.cat((cos_h, cos_h), dim=-1)
+         emb_sin = torch.cat((sin_h, sin_h), dim=-1)
+         return emb_cos, emb_sin
+
+     def rotate_qk(
+         self,
+         q: torch.Tensor,
+         k: torch.Tensor,
+         xy: torch.Tensor,
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         if q.dim() != 4 or k.dim() != 4:
+             raise ValueError("q,k must be shaped (B,H,T,D)")
+         _, H, _, D = q.shape
+         if self.num_heads != H:
+             raise ValueError("num_heads mismatch for LearnableRoPE2D")
+         if self.head_dim != D:
+             raise ValueError("head_dim mismatch for LearnableRoPE2D")
+         if D % 4 != 0:
+             raise AssertionError("head_dim must be divisible by 4 for mixed 2D RoPE")
+
+         # Use real-valued sin/cos rotation to keep torch.compile/inductor on the
+         # fast path (inductor cannot codegen complex tensors).
+         angles = self._compute_mixed_angles(xy.to(device=q.device))
+         sin = torch.sin(angles)
+         cos = torch.cos(angles)
+         q_out = self._apply_rotary_from_trig(q, sin=sin, cos=cos)
+         k_out = self._apply_rotary_from_trig(k, sin=sin, cos=cos)
+         return q_out, k_out
+
+     def rotate_qk_with_dilation(
+         self,
+         q: torch.Tensor,
+         k: torch.Tensor,
+         *,
+         xy: torch.Tensor,
+         scales: torch.Tensor,
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         """Rotate Q/K using mixed 2D RoPE with per-sample isotropic dilation.
+
+         This implements dilation by scaling the RoPE angle, i.e.
+         ``theta_dilated = scale * theta_base`` where ``theta_base`` comes from the
+         undilated XY coordinates.
+
+         Args:
+             q: Query tensor shaped ``(B, H, T, D)``.
+             k: Key tensor shaped ``(B, H, T, D)``.
+             xy: Base XY coordinates shaped ``(T, 2)`` or ``(B, T, 2)``.
+             scales: Per-sample dilation scales shaped ``(B,)``.
+
+         Raises:
+             ValueError: If shapes are inconsistent or scales are not 1D.
+         """
+         if q.dim() != 4 or k.dim() != 4:
+             raise ValueError("q,k must be shaped (B,H,T,D)")
+         B, H, T, D = q.shape
+         if self.num_heads != H:
+             raise ValueError("num_heads mismatch for LearnableRoPE2D")
+         if self.head_dim != D:
+             raise ValueError("head_dim mismatch for LearnableRoPE2D")
+         if scales.dim() != 1 or scales.shape[0] != B:
+             raise ValueError("scales must have shape (B,) matching q batch size")
+         if xy.dim() == 2 and xy.shape[0] != T:
+             raise ValueError("xy length must match q sequence length")
+         if xy.dim() == 3 and (xy.shape[0] != B or xy.shape[1] != T):
+             raise ValueError("xy must have shape (B,T,2) matching q batch/sequence")
+         if xy.shape[-1] != 2:
+             raise ValueError("xy must have last dimension 2")
+
+         angles = self._compute_mixed_angles(xy.to(device=q.device))
+         angles = angles * scales.to(device=q.device, dtype=torch.float32).view(
+             B, 1, 1, 1
+         )
+         sin = torch.sin(angles)
+         cos = torch.cos(angles)
+         q_out = self._apply_rotary_from_trig(q, sin=sin, cos=cos)
+         k_out = self._apply_rotary_from_trig(k, sin=sin, cos=cos)
+         return q_out, k_out
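The half-split rotation and the angle dilation in the methods above can be checked by hand on a single vector. The following is a minimal sketch in plain Python rather than torch; `mixed_angles`, `rotate_half_split`, `dot`, and the frequency values `FX`/`FY` are hypothetical names and numbers chosen for illustration and are not part of the exported module.

```python
import math


def mixed_angles(x, y, freqs_x, freqs_y, scale=1.0):
    """Mixed 2D RoPE angle per frequency slot: scale * (x * fx + y * fy)."""
    return [scale * (x * fx + y * fy) for fx, fy in zip(freqs_x, freqs_y)]


def rotate_half_split(vec, angles):
    """Rotate each (vec[i], vec[i + half]) pair by angles[i] (half-split layout)."""
    half = len(vec) // 2
    if len(vec) % 2 != 0 or len(angles) != half:
        raise ValueError("vec must have even length equal to 2 * len(angles)")
    x1, x2 = vec[:half], vec[half:]
    out1 = [a * math.cos(t) - b * math.sin(t) for a, b, t in zip(x1, x2, angles)]
    out2 = [b * math.cos(t) + a * math.sin(t) for a, b, t in zip(x1, x2, angles)]
    return out1 + out2


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


# Hypothetical frequencies for head_dim = 4 (two frequency slots).
FX = [1.0, 0.25]
FY = [0.5, 0.125]

q = [0.3, -1.2, 0.7, 0.4]
k = [1.1, 0.2, -0.5, 0.9]

# The same relative offset (dx=2, dy=1) at two different absolute positions.
score_a = dot(
    rotate_half_split(q, mixed_angles(3, 4, FX, FY)),
    rotate_half_split(k, mixed_angles(1, 3, FX, FY)),
)
score_b = dot(
    rotate_half_split(q, mixed_angles(10, 7, FX, FY)),
    rotate_half_split(k, mixed_angles(8, 6, FX, FY)),
)
```

Because the angle is linear in the coordinates, the two scores agree up to rounding, and scaling the angle by `scale` is equivalent to evaluating the same frequencies at dilated coordinates.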
config.json ADDED
@@ -0,0 +1,26 @@
+ {
+   "in_channels": 3,
+   "patch_size": 16,
+   "model_dim": 896,
+   "encoder_depth": 6,
+   "decoder_depth": 8,
+   "decoder_start_blocks": 2,
+   "decoder_end_blocks": 2,
+   "bottleneck_dim": 128,
+   "mlp_ratio": 4.0,
+   "encoder_mlp_type": "gelu",
+   "depthwise_kernel_size": 7,
+   "adaln_low_rank_rank": 128,
+   "bottleneck_posterior_kind": "diagonal_gaussian",
+   "bottleneck_norm_mode": "disabled",
+   "logsnr_min": -10.0,
+   "logsnr_max": 10.0,
+   "pixel_noise_std": 0.558,
+   "latent_running_stats_eps": 0.0001,
+   "class_head_feature_dim": 768,
+   "class_head_model_dim": 768,
+   "class_head_head_dim": 64,
+   "class_head_mlp_ratio": 4.0,
+   "class_head_mlp_type": "gelu",
+   "class_head_register_token_count": 4
+ }
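As a rough reading aid for these values, the sketch below derives the latent grid shape and the per-patch compression factor implied by `patch_size`, `in_channels`, and `bottleneck_dim`. The helper names `latent_shape` and `compression_ratio` are hypothetical, and the requirement that image sides be divisible by the patch size is an assumption, since the patchify layer is not shown here.

```python
def latent_shape(height, width, *, patch_size=16, bottleneck_dim=128):
    """Latent (channels, h, w) for an image, assuming stride == patch_size."""
    # Assumption: the image sides must be exact multiples of the patch size.
    if height % patch_size != 0 or width % patch_size != 0:
        raise ValueError("image sides must be divisible by patch_size")
    return (bottleneck_dim, height // patch_size, width // patch_size)


def compression_ratio(*, in_channels=3, patch_size=16, bottleneck_dim=128):
    """Pixel values per patch divided by latent values per patch."""
    return in_channels * patch_size**2 / bottleneck_dim


shape_256 = latent_shape(256, 256)
ratio = compression_ratio()
```

With the defaults above, each 16x16 RGB patch (768 values) maps to one 128-channel latent vector.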
dinac_ae/__init__.py ADDED
@@ -0,0 +1,12 @@
+ """DINAC-AE: DINO-aligned class-token autoencoder export."""
+
+ from .config import DinacAEConfig, DinacAEInferenceConfig
+ from .encoder import EncoderPosterior
+ from .model import DinacAE
+
+ __all__ = [
+     "DinacAE",
+     "DinacAEConfig",
+     "DinacAEInferenceConfig",
+     "EncoderPosterior",
+ ]
dinac_ae/adaln.py ADDED
@@ -0,0 +1,75 @@
+ """Scale+Gate AdaLN (2-way) for FCDM decoder blocks."""
+
+ from __future__ import annotations
+
+ from torch import Tensor, nn
+
+
+ class AdaLNScaleGateZeroProjector(nn.Module):
+     """Packed 2-way AdaLN projection (SiLU -> Linear), zero-initialized.
+
+     Outputs [B, 2*d_model] packed as (scale, gate).
+     """
+
+     def __init__(self, d_model: int, d_cond: int) -> None:
+         super().__init__()
+         self.d_model: int = int(d_model)
+         self.d_cond: int = int(d_cond)
+         self.act: nn.SiLU = nn.SiLU()
+         self.proj: nn.Linear = nn.Linear(self.d_cond, 2 * self.d_model)
+         nn.init.zeros_(self.proj.weight)
+         nn.init.zeros_(self.proj.bias)
+
+     def project_activated(self, act_cond: Tensor) -> Tensor:
+         """Return packed modulation for a pre-activated conditioning vector."""
+
+         if act_cond.dim() != 2:
+             raise ValueError(
+                 "AdaLNScaleGateZeroProjector expects act_cond with shape [B, d_cond]"
+             )
+         if act_cond.shape[1] != self.d_cond:
+             raise ValueError(
+                 f"act_cond width {int(act_cond.shape[1])} does not match d_cond={self.d_cond}"
+             )
+         return self.proj(act_cond)
+
+     def forward(self, cond: Tensor) -> Tensor:
+         """Return packed modulation [B, 2*d_model]."""
+         if cond.dim() != 2:
+             raise ValueError(
+                 "AdaLNScaleGateZeroProjector expects cond with shape [B, d_cond]"
+             )
+         if cond.shape[1] != self.d_cond:
+             raise ValueError(
+                 f"cond width {int(cond.shape[1])} does not match d_cond={self.d_cond}"
+             )
+         return self.project_activated(self.act(cond))
+
+
+ class AdaLNScaleGateZeroLowRankDelta(nn.Module):
+     """Low-rank delta for 2-way AdaLN: down(d_cond -> rank) -> up(rank -> 2*d_model).
+
+     Zero-initialized up projection preserves zero-output semantics at init.
+     """
+
+     def __init__(self, *, d_model: int, d_cond: int, rank: int) -> None:
+         super().__init__()
+         self.d_model: int = int(d_model)
+         self.d_cond: int = int(d_cond)
+         self.rank: int = int(rank)
+         self.down: nn.Linear = nn.Linear(self.d_cond, self.rank, bias=False)
+         self.up: nn.Linear = nn.Linear(self.rank, 2 * self.d_model, bias=False)
+         nn.init.normal_(self.down.weight, mean=0.0, std=0.02)
+         nn.init.zeros_(self.up.weight)
+
+     def forward(self, act_cond: Tensor) -> Tensor:
+         """Return packed delta modulation [B, 2*d_model]."""
+         if act_cond.dim() != 2:
+             raise ValueError(
+                 "AdaLNScaleGateZeroLowRankDelta expects act_cond with shape [B, d_cond]"
+             )
+         if act_cond.shape[1] != self.d_cond:
+             raise ValueError(
+                 f"act_cond width {int(act_cond.shape[1])} does not match d_cond={self.d_cond}"
+             )
+         return self.up(self.down(act_cond))
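The "zero-output semantics at init" mentioned in the docstrings can be illustrated without torch. The sketch below uses plain lists for one conditioning vector; `linear`, `silu`, and the tiny dimensions are hypothetical stand-ins for illustration, not the module's API. It assumes only the initialization visible in the file: zero weight and bias in the base projector, a small random `down` matrix, and a zero `up` matrix.

```python
import math
import random


def linear(x, weight, bias=None):
    """y = W x (+ b) for one vector x; weight is a list of rows."""
    out = [sum(w * v for w, v in zip(row, x)) for row in weight]
    if bias is not None:
        out = [o + b for o, b in zip(out, bias)]
    return out


def silu(x):
    return [v / (1.0 + math.exp(-v)) for v in x]


d_model, d_cond, rank = 3, 4, 2
rng = random.Random(0)

# Base projector: zero weight and zero bias.
base_w = [[0.0] * d_cond for _ in range(2 * d_model)]
base_b = [0.0] * (2 * d_model)
# Low-rank delta: small random down projection, zero up projection.
down_w = [[rng.gauss(0.0, 0.02) for _ in range(d_cond)] for _ in range(rank)]
up_w = [[0.0] * rank for _ in range(2 * d_model)]

cond = [0.5, -1.0, 2.0, 0.1]
act = silu(cond)
base_m = linear(act, base_w, base_b)
delta_m = linear(linear(act, down_w), up_w)
packed = [b + d for b, d in zip(base_m, delta_m)]
scale, gate = packed[:d_model], packed[d_model:]
```

With `scale` and `gate` both zero, a block that computes `x + gate * h` returns `x` unchanged, so each decoder block starts as an identity map regardless of the conditioning vector.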
dinac_ae/config.py ADDED
@@ -0,0 +1,75 @@
+ """Frozen model architecture and user-tunable inference configuration."""
+
+ from __future__ import annotations
+
+ import json
+ from dataclasses import asdict, dataclass
+ from pathlib import Path
+
+
+ @dataclass(frozen=True)
+ class DinacAEConfig:
+     """Frozen architecture config stored alongside exported weights."""
+
+     in_channels: int = 3
+     patch_size: int = 16
+     model_dim: int = 896
+     encoder_depth: int = 6
+     decoder_depth: int = 8
+     decoder_start_blocks: int = 2
+     decoder_end_blocks: int = 2
+     bottleneck_dim: int = 128
+     mlp_ratio: float = 4.0
+     encoder_mlp_type: str = "gelu"
+     depthwise_kernel_size: int = 7
+     adaln_low_rank_rank: int = 128
+     bottleneck_posterior_kind: str = "diagonal_gaussian"
+     bottleneck_norm_mode: str = "disabled"
+     logsnr_min: float = -10.0
+     logsnr_max: float = 10.0
+     pixel_noise_std: float = 0.558
+     latent_running_stats_eps: float = 1e-4
+     class_head_feature_dim: int = 768
+     class_head_model_dim: int = 768
+     class_head_head_dim: int = 64
+     class_head_mlp_ratio: float = 4.0
+     class_head_mlp_type: str = "gelu"
+     class_head_register_token_count: int = 4
+
+     @property
+     def latent_channels(self) -> int:
+         """Return the exported latent channel width."""
+
+         return int(self.bottleneck_dim)
+
+     @property
+     def effective_patch_size(self) -> int:
+         """Return the image-to-latent stride."""
+
+         return int(self.patch_size)
+
+     def save(self, path: str | Path) -> None:
+         """Save config as JSON."""
+
+         output_path = Path(path)
+         output_path.parent.mkdir(parents=True, exist_ok=True)
+         output_path.write_text(json.dumps(asdict(self), indent=2) + "\n")
+
+     @classmethod
+     def load(cls, path: str | Path) -> DinacAEConfig:
+         """Load config from JSON."""
+
+         data = json.loads(Path(path).read_text())
+         return cls(**data)
+
+
+ @dataclass
+ class DinacAEInferenceConfig:
+     """User-tunable VP diffusion decode settings."""
+
+     num_steps: int = 1
+     sampler: str = "ddim"
+     schedule: str = "linear"
+     pdg: bool = False
+     pdg_strength: float = 2.0
+     seed: int | None = None
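The save/load pattern used by `DinacAEConfig` can be exercised in isolation. The sketch below reproduces it on a hypothetical two-field `TinyConfig` (not part of the package) so the JSON round trip and the frozen-dataclass behaviour can be checked with the standard library alone.

```python
from __future__ import annotations

import json
import tempfile
from dataclasses import FrozenInstanceError, asdict, dataclass
from pathlib import Path


@dataclass(frozen=True)
class TinyConfig:
    """Hypothetical stand-in with the same save/load shape as the real config."""

    patch_size: int = 16
    bottleneck_dim: int = 128

    def save(self, path: str | Path) -> None:
        output_path = Path(path)
        # Create missing parent directories before writing.
        output_path.parent.mkdir(parents=True, exist_ok=True)
        output_path.write_text(json.dumps(asdict(self), indent=2) + "\n")

    @classmethod
    def load(cls, path: str | Path) -> TinyConfig:
        # Unknown keys in the JSON raise TypeError from the dataclass __init__.
        return cls(**json.loads(Path(path).read_text()))


with tempfile.TemporaryDirectory() as tmp:
    config_path = Path(tmp) / "nested" / "config.json"
    TinyConfig(patch_size=8).save(config_path)
    restored = TinyConfig.load(config_path)

try:
    restored.patch_size = 32  # frozen dataclasses reject attribute assignment
    mutated = True
except FrozenInstanceError:
    mutated = False
```

One consequence of `cls(**data)` is that a JSON file written by a newer version with extra fields will fail to load rather than being silently truncated.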
dinac_ae/decoder.py ADDED
@@ -0,0 +1,163 @@
+ """Decoder matching the exported FCDM decoder stack."""
+
+ from __future__ import annotations
+
+ import torch
+ from torch import Tensor, nn
+
+ from .adaln import AdaLNScaleGateZeroLowRankDelta, AdaLNScaleGateZeroProjector
+ from .fcdm_block import FCDMBlock
+ from .straight_through_encoder import Patchify
+ from .time_embed import SinusoidalTimeEmbeddingMLP
+
+
+ class Decoder(nn.Module):
+     """VP diffusion decoder conditioned on encoder latents and timestep.
+
+     Architecture (skip-concat, 2+4+2 default):
+         Patchify x_t -> Fuse with upsampled z
+         -> Start blocks (2) -> Middle blocks (4) -> Skip fuse -> End blocks (2)
+         -> Conv1x1 -> PixelShuffle
+
+     Path-Drop Guidance (PDG) at inference:
+         - Replace middle block output with ``path_drop_mask_feature`` to create
+           an unconditional prediction, then extrapolate.
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         patch_size: int,
+         model_dim: int,
+         depth: int,
+         start_block_count: int,
+         end_block_count: int,
+         bottleneck_dim: int,
+         mlp_ratio: float,
+         depthwise_kernel_size: int,
+         adaln_low_rank_rank: int,
+     ) -> None:
+         super().__init__()
+         self.patch_size = int(patch_size)
+         self.model_dim = int(model_dim)
+
+         self.patchify = Patchify(
+             in_channels,
+             patch_size,
+             model_dim,
+         )
+
+         self.latent_up = nn.Conv2d(bottleneck_dim, model_dim, kernel_size=1, bias=True)
+         self.fuse_in = nn.Conv2d(2 * model_dim, model_dim, kernel_size=1, bias=True)
+
+         # Time embedding
+         self.time_embed = SinusoidalTimeEmbeddingMLP(model_dim)
+
+         # 2-way AdaLN: shared base projector + per-block low-rank deltas
+         self.adaln_base = AdaLNScaleGateZeroProjector(
+             d_model=model_dim, d_cond=model_dim
+         )
+         self.adaln_deltas = nn.ModuleList(
+             [
+                 AdaLNScaleGateZeroLowRankDelta(
+                     d_model=model_dim, d_cond=model_dim, rank=adaln_low_rank_rank
+                 )
+                 for _ in range(depth)
+             ]
+         )
+
+         # Block layout: start + middle + end
+         middle_count = depth - start_block_count - end_block_count
+         self._middle_start_idx = start_block_count
+         self._end_start_idx = start_block_count + middle_count
+
+         def _make_blocks(count: int) -> nn.ModuleList:
+             return nn.ModuleList(
+                 [
+                     FCDMBlock(
+                         model_dim,
+                         mlp_ratio,
+                         depthwise_kernel_size=depthwise_kernel_size,
+                         use_external_adaln=True,
+                     )
+                     for _ in range(count)
+                 ]
+             )
+
+         self.start_blocks = _make_blocks(start_block_count)
+         self.middle_blocks = _make_blocks(middle_count)
+         self.fuse_skip = nn.Conv2d(2 * model_dim, model_dim, kernel_size=1, bias=True)
+         self.end_blocks = _make_blocks(end_block_count)
+
+         self.path_drop_mask_feature = nn.Parameter(torch.zeros((1, model_dim, 1, 1)))
+
+         self.out_proj = nn.Conv2d(
+             model_dim, in_channels * (patch_size**2), kernel_size=1, bias=True
+         )
+         self.unpatchify = nn.PixelShuffle(patch_size)
+
+     def _adaln_m_for_layer(self, cond: Tensor, layer_idx: int) -> Tensor:
+         """Compute packed AdaLN modulation = shared_base + per-layer delta."""
+         act = self.adaln_base.act(cond)
+         base_m = self.adaln_base.project_activated(act)
+         delta_m = self.adaln_deltas[layer_idx](act)
+         return base_m + delta_m
+
+     def _run_blocks(
+         self, blocks: nn.ModuleList, x: Tensor, cond: Tensor, start_index: int
+     ) -> Tensor:
+         """Run a group of decoder blocks with per-block AdaLN modulation."""
+         for local_idx, block in enumerate(blocks):
+             adaln_m = self._adaln_m_for_layer(cond, layer_idx=start_index + local_idx)
+             x = block(x, adaln_m=adaln_m)
+         return x
+
+     def forward(
+         self,
+         x_t: Tensor,
+         t: Tensor,
+         latents: Tensor,
+         *,
+         drop_middle_blocks: bool = False,
+     ) -> Tensor:
+         """Single decoder forward pass.
+
+         Args:
+             x_t: Noised image [B, C, H, W].
+             t: Timestep [B] in [0, 1].
+             latents: Encoder latents [B, bottleneck_dim, h, w].
+             drop_middle_blocks: Replace middle block output with mask feature (PDG).
+
+         Returns:
+             x0 prediction [B, C, H, W].
+         """
+         x_feat = self.patchify(x_t)
+         z_up = self.latent_up(latents)
+
+         fused = torch.cat([x_feat, z_up], dim=1)
+         fused = self.fuse_in(fused)
+
+         cond = self.time_embed(t.to(torch.float32).to(device=x_t.device))
+
+         start_out = self._run_blocks(self.start_blocks, fused, cond, start_index=0)
+
+         if drop_middle_blocks:
+             middle_out = self.path_drop_mask_feature.to(
+                 device=x_t.device, dtype=x_t.dtype
+             ).expand_as(start_out)
+         else:
+             middle_out = self._run_blocks(
+                 self.middle_blocks,
+                 start_out,
+                 cond,
+                 start_index=self._middle_start_idx,
+             )
+
+         skip_fused = torch.cat([start_out, middle_out], dim=1)
+         skip_fused = self.fuse_skip(skip_fused)
+
+         end_out = self._run_blocks(
+             self.end_blocks, skip_fused, cond, start_index=self._end_start_idx
+         )
+         patches = self.out_proj(end_out)
+         return self.unpatchify(patches)
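The start/middle/end split determines which `adaln_deltas` entry each block uses. The sketch below reproduces only that index arithmetic; `decoder_block_layout` is a hypothetical helper, and its negative-count check is an addition for illustration that the decoder constructor itself does not perform.

```python
def decoder_block_layout(depth, start_block_count, end_block_count):
    """Map each block group to the global layer indices it consumes."""
    middle_count = depth - start_block_count - end_block_count
    if middle_count < 0:
        # Added guard (not in the original constructor).
        raise ValueError("start + end block counts exceed decoder depth")
    middle_start_idx = start_block_count
    end_start_idx = start_block_count + middle_count
    return {
        "start": list(range(0, middle_start_idx)),
        "middle": list(range(middle_start_idx, end_start_idx)),
        "end": list(range(end_start_idx, depth)),
    }


layout = decoder_block_layout(depth=8, start_block_count=2, end_block_count=2)
```

For the exported 2+4+2 configuration, the middle blocks use delta indices 2 through 5; when `drop_middle_blocks=True` those four deltas are simply never evaluated.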
dinac_ae/encoder.py ADDED
@@ -0,0 +1,215 @@
+ """Encoder matching the exported mixed DitBlock/FCDM architecture."""
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+
+ import torch
+ import torch.nn.functional as F
+ from torch import Tensor, nn
+
+ from dit.axial_rope2d import (
+     AxialRoPE2D,
+     AxialRoPE2DConfig,
+     AxialRoPE2DCoordMode,
+     AxialRoPE2DDimLayout,
+     AxialRoPE2DNormalizeCoords,
+ )
+ from dit.blocks import DitBlock
+ from dit.body_config import DiTConditioning
+ from dit.mlp_types import MLPType
+ from dit.position_encoding import DiTPositionEncoding
+
+ from .straight_through_encoder import Patchify
+
+ _ENCODER_HEAD_DIM = 64
+
+
+ def _resolve_encoder_mlp_type(name: str) -> MLPType:
+     """Return the encoder DiT MLP enum for the serialized config value."""
+
+     match str(name):
+         case "gelu":
+             return MLPType.GELU
+         case "silu":
+             return MLPType.SILU
+         case "relu":
+             return MLPType.RELU
+         case _ as unreachable:
+             raise ValueError(
+                 "Unsupported encoder_mlp_type for DinacAE export: " f"{unreachable!r}"
+             )
+
+
+ @dataclass(frozen=True)
+ class EncoderPosterior:
+     """VP-parameterized diagonal Gaussian posterior."""
+
+     mean: Tensor
+     logsnr: Tensor
+
+     @property
+     def alpha(self) -> Tensor:
+         """Return the VP signal coefficient."""
+
+         logsnr_fp32 = self.logsnr.to(torch.float32)
+         return torch.exp(0.5 * F.logsigmoid(logsnr_fp32))
+
+     @property
+     def sigma(self) -> Tensor:
+         """Return the VP noise coefficient."""
+
+         logsnr_fp32 = self.logsnr.to(torch.float32)
+         return torch.exp(0.5 * F.logsigmoid(-logsnr_fp32))
+
+     def mode(self) -> Tensor:
+         """Return the posterior mode in token space."""
+
+         return (self.alpha * self.mean.to(torch.float32)).to(dtype=self.mean.dtype)
+
+     def sample(self, *, generator: torch.Generator | None = None) -> Tensor:
+         """Sample from the posterior."""
+
+         mean_fp32 = self.mean.to(torch.float32)
+         eps = torch.randn(
+             mean_fp32.shape,
+             device=mean_fp32.device,
+             dtype=torch.float32,
+             generator=generator,
+         )
+         return (self.alpha * mean_fp32 + self.sigma * eps).to(dtype=self.mean.dtype)
+
+
+ class Encoder(nn.Module):
+     """Residual-patchify plus DitBlock encoder."""
+
+     def __init__(
+         self,
+         *,
+         in_channels: int,
+         patch_size: int,
+         model_dim: int,
+         depth: int,
+         bottleneck_dim: int,
+         mlp_ratio: float,
+         mlp_type: str,
+         bottleneck_posterior_kind: str,
+         bottleneck_norm_mode: str,
+     ) -> None:
+         super().__init__()
+         if int(model_dim) % int(_ENCODER_HEAD_DIM) != 0:
+             raise ValueError("model_dim must be divisible by encoder head dim")
+         self.bottleneck_dim: int = int(bottleneck_dim)
+         self.bottleneck_posterior_kind: str = str(bottleneck_posterior_kind)
+         self.bottleneck_norm_mode: str = str(bottleneck_norm_mode)
+         if self.bottleneck_norm_mode != "disabled":
+             raise ValueError("DINAC-AE export requires disabled bottleneck norm")
+         self.patchify = Patchify(
+             in_channels,
+             patch_size,
+             model_dim,
+         )
+         self.blocks = nn.ModuleList(
+             [
+                 DitBlock(
+                     d_model=int(model_dim),
+                     n_heads=int(model_dim) // int(_ENCODER_HEAD_DIM),
+                     mlp_ratio=float(mlp_ratio),
+                     mlp_type=_resolve_encoder_mlp_type(mlp_type),
+                     block_index=int(index),
+                     use_norms=True,
+                     position_encoding=DiTPositionEncoding.ROPE_2D_AXIAL_UNNORMALIZED,
+                     conditioning=DiTConditioning.UNCOND,
+                 )
+                 for index in range(int(depth))
+             ]
+         )
+         self.rope = AxialRoPE2D(
+             head_dim=int(_ENCODER_HEAD_DIM),
+             cfg=AxialRoPE2DConfig(
+                 base=10_000.0,
+                 min_period=None,
+                 max_period=None,
+                 coord_mode=AxialRoPE2DCoordMode.PATCH_INDICES,
+                 normalize_coords=AxialRoPE2DNormalizeCoords.MAX,
+                 dim_layout=AxialRoPE2DDimLayout.PAIR_INTERLEAVED,
+                 angle_multiplier=1.0,
+                 coord_offset=0.0,
+                 frequency_aware=None,
+                 beta_warp=None,
+                 alpha_warp=None,
+             ),
+         )
+         match self.bottleneck_posterior_kind:
+             case "deterministic":
+                 output_channels = int(bottleneck_dim)
+             case "diagonal_gaussian":
+                 output_channels = 2 * int(bottleneck_dim)
+             case _ as unreachable:
+                 raise RuntimeError(
+                     f"Unsupported bottleneck_posterior_kind: {unreachable}"
+                 )
+         self.to_bottleneck = nn.Conv2d(
+             int(model_dim),
+             output_channels,
+             kernel_size=1,
+             bias=True,
+         )
+
+     def _encode_projection(self, images: Tensor) -> Tensor:
+         """Encode images to the raw bottleneck projection."""
+
+         z = self.patchify(images)
+         batch, channels, height, width = z.shape
+         cond = torch.zeros(
+             (int(batch), int(channels)),
+             device=z.device,
+             dtype=z.dtype,
+         )
+         rope_sincos = self.rope(H=int(height), W=int(width), scales=None)
+         y = z
+         for block in self.blocks:
+             y = block(
+                 y,
+                 hw=(int(height), int(width)),
+                 cond_vec=cond,
+                 adaln_m=None,
+                 rope_sincos=rope_sincos,
+                 generator=None,
+             )
+         return self.to_bottleneck(y)
+
+     def _apply_bottleneck_norm(self, z: Tensor) -> Tensor:
+         """Return ``z`` unchanged; bottleneck norm is disabled in this export."""
+
+         return z
+
+     def encode_posterior(self, images: Tensor) -> EncoderPosterior:
+         """Encode images and return the posterior."""
+
+         if self.bottleneck_posterior_kind != "diagonal_gaussian":
+             raise RuntimeError(
+                 "encode_posterior requires bottleneck_posterior_kind=diagonal_gaussian"
+             )
+         projection = self._encode_projection(images)
+         mean, logsnr = projection.chunk(2, dim=1)
+         mean = self._apply_bottleneck_norm(mean)
+         return EncoderPosterior(mean=mean, logsnr=logsnr)
+
+     def forward(self, images: Tensor) -> Tensor:
+         """Encode images to latent tokens."""
+
+         projection = self._encode_projection(images)
+         match self.bottleneck_posterior_kind:
+             case "diagonal_gaussian":
+                 mean, logsnr = projection.chunk(2, dim=1)
+                 mean = self._apply_bottleneck_norm(mean)
+                 logsnr_fp32 = logsnr.to(torch.float32)
+                 alpha = torch.exp(0.5 * F.logsigmoid(logsnr_fp32))
+                 return (alpha * mean.to(torch.float32)).to(dtype=mean.dtype)
+             case "deterministic":
+                 return self._apply_bottleneck_norm(projection)
+             case _ as unreachable:
+                 raise RuntimeError(
+                     f"Unsupported bottleneck_posterior_kind: {unreachable}"
+                 )
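The posterior's `alpha` and `sigma` are both derived from `logsnr` through a log-sigmoid, which makes them variance preserving by construction. The scalar sketch below mirrors that computation with the standard library; `log_sigmoid` and `vp_coefficients` are hypothetical helper names, not functions from the package.

```python
import math


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x)) for a scalar."""
    return min(x, 0.0) - math.log1p(math.exp(-abs(x)))


def vp_coefficients(logsnr):
    """Return (alpha, sigma) with alpha^2 = sigmoid(logsnr), sigma^2 = sigmoid(-logsnr)."""
    alpha = math.exp(0.5 * log_sigmoid(logsnr))
    sigma = math.exp(0.5 * log_sigmoid(-logsnr))
    return alpha, sigma


alpha_mid, sigma_mid = vp_coefficients(0.0)   # equal signal and noise power
alpha_hi, sigma_hi = vp_coefficients(10.0)    # logsnr_max from the config
alpha_2, sigma_2 = vp_coefficients(2.0)
```

Since sigmoid(x) + sigmoid(-x) = 1, `alpha**2 + sigma**2` equals 1 for every `logsnr`, and `log(alpha**2 / sigma**2)` recovers `logsnr`. The posterior mode is then `alpha * mean`, which is what `Encoder.forward` returns for the diagonal Gaussian case.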
dinac_ae/fcdm_block.py ADDED
@@ -0,0 +1,103 @@
+ """FCDM block: ConvNeXt-style conv block with GRN and scale+gate AdaLN."""
+
+ from __future__ import annotations
+
+ import torch
+ import torch.nn.functional as F
+ from torch import Tensor, nn
+
+ from .norms import ChannelWiseRMSNorm
+
+
+ class GRN(nn.Module):
+     """Global Response Normalization for NCHW tensors."""
+
+     def __init__(self, channels: int, *, eps: float = 1e-6) -> None:
+         super().__init__()
+         self.eps: float = float(eps)
+         c = int(channels)
+         self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1), dtype=torch.float32))
+         self.beta = nn.Parameter(torch.zeros((1, c, 1, 1), dtype=torch.float32))
+
+     def forward(self, x: Tensor) -> Tensor:
+         g = torch.linalg.vector_norm(x, ord=2, dim=(2, 3), keepdim=True)
+         g_fp32 = g.to(dtype=torch.float32)
+         n = (g_fp32 / (g_fp32.mean(dim=1, keepdim=True) + self.eps)).to(dtype=x.dtype)
+         gamma = self.gamma.to(device=x.device, dtype=x.dtype)
+         beta = self.beta.to(device=x.device, dtype=x.dtype)
+         return gamma * (x * n) + beta + x
+
+
+ class FCDMBlock(nn.Module):
+     """ConvNeXt-style block with scale+gate AdaLN and GRN.
+
+     Two modes:
+         - Unconditioned (encoder): uses learned layer-scale for near-identity init.
+         - External AdaLN (decoder): receives packed [B, 2*C] modulation (scale, gate).
+           The gate is applied raw (no tanh).
+     """
+
+     def __init__(
+         self,
+         channels: int,
+         mlp_ratio: float,
+         *,
+         depthwise_kernel_size: int = 7,
+         use_external_adaln: bool = False,
+         norm_eps: float = 1e-6,
+         layer_scale_init: float = 1e-3,
+     ) -> None:
+         super().__init__()
+         self.channels: int = int(channels)
+         self.mlp_ratio: float = float(mlp_ratio)
+
+         self.dwconv = nn.Conv2d(
+             channels,
+             channels,
+             kernel_size=depthwise_kernel_size,
+             padding=depthwise_kernel_size // 2,
+             stride=1,
+             groups=channels,
+             bias=True,
+         )
+         self.norm = ChannelWiseRMSNorm(channels, eps=float(norm_eps), affine=False)
+         hidden = max(int(float(channels) * float(mlp_ratio)), 1)
+         self.pwconv1 = nn.Conv2d(channels, hidden, kernel_size=1, bias=True)
+         self.grn = GRN(hidden, eps=1e-6)
+         self.pwconv2 = nn.Conv2d(hidden, channels, kernel_size=1, bias=True)
+
+         if not use_external_adaln:
+             self.layer_scale = nn.Parameter(
+                 torch.full((channels,), float(layer_scale_init))
+             )
+         else:
+             self.register_parameter("layer_scale", None)
+
+     def forward(self, x: Tensor, *, adaln_m: Tensor | None = None) -> Tensor:
+         b, c, _, _ = x.shape
+
+         if adaln_m is not None:
+             m = adaln_m.to(device=x.device, dtype=x.dtype)
+             scale, gate = m.chunk(2, dim=-1)
+         else:
+             scale = gate = None
+
+         h = self.dwconv(x)
+         h = self.norm(h)
+
+         if scale is not None:
+             h = h * (1.0 + scale.view(b, c, 1, 1))
+
+         h = self.pwconv1(h)
+         h = F.gelu(h)
+         h = self.grn(h)
+         h = self.pwconv2(h)
+
+         if gate is not None:
+             gate_view = gate.view(b, c, 1, 1)
+         else:
+             gate_view = self.layer_scale.view(1, c, 1, 1).to(  # type: ignore[union-attr]
+                 device=h.device, dtype=h.dtype
+             )
+
+         return x + gate_view * h
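The `GRN.forward` arithmetic can be traced on a tiny input. The sketch below applies the same formula to one sample stored as nested lists shaped `[C][H][W]`; the function name `grn` and the example values are hypothetical, and batching and dtype handling are left out.

```python
import math


def grn(x, gamma, beta, eps=1e-6):
    """Global Response Normalization for one [C][H][W] sample (nested lists)."""
    # Per-channel L2 norm over the spatial dimensions.
    norms = [math.sqrt(sum(v * v for row in ch for v in row)) for ch in x]
    # Divide each channel norm by the mean norm across channels.
    mean_norm = sum(norms) / len(norms)
    n = [g / (mean_norm + eps) for g in norms]
    return [
        [[gamma[c] * (v * n[c]) + beta[c] + v for v in row] for row in ch]
        for c, ch in enumerate(x)
    ]


x = [
    [[1.0, 2.0], [2.0, 4.0]],  # channel 0, L2 norm 5
    [[0.0, 0.0], [3.0, 4.0]],  # channel 1, L2 norm 5
]

identity_out = grn(x, gamma=[0.0, 0.0], beta=[0.0, 0.0])
scaled_out = grn(x, gamma=[1.0, 1.0], beta=[0.0, 0.0])
```

Because `gamma` and `beta` start at zero in the module, GRN is an exact identity at initialization; with `gamma = 1` and equal channel norms, each value is roughly doubled.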
dinac_ae/model.py ADDED
@@ -0,0 +1,333 @@
+ """Standalone mixed DitBlock/FCDM diffusion autoencoder export."""
+
+ from __future__ import annotations
+
+ from pathlib import Path
+
+ import torch
+ from torch import Tensor, nn
+
+ from dit.mlp_types import MLPType
+ from dit.repa_projection import DinoTokenAlignmentHead
+
+ from .config import DinacAEConfig, DinacAEInferenceConfig
+ from .decoder import Decoder
+ from .encoder import Encoder, EncoderPosterior
+ from .samplers import run_ddim, run_dpmpp_2m
+ from .vp_diffusion import get_schedule, make_initial_state, sample_noise
+
+
+ def _resolve_model_dir(
+     path_or_repo_id: str | Path,
+     *,
+     revision: str | None,
+     cache_dir: str | Path | None,
+ ) -> Path:
+     """Resolve a local path or Hugging Face repo ID to a directory."""
+
+     local = Path(path_or_repo_id)
+     if local.is_dir():
+         return local
+     repo_id = str(path_or_repo_id)
+     try:
+         from huggingface_hub import snapshot_download
+     except ImportError as exc:
+         raise ImportError(
+             f"'{repo_id}' is not an existing local directory. Install "
+             "huggingface_hub to load from the Hub."
+         ) from exc
+     cache_dir_str = str(cache_dir) if cache_dir is not None else None
+     return Path(
+         snapshot_download(
+             repo_id,
+             revision=revision,
+             cache_dir=cache_dir_str,
+         )
+     )
+
+
+ def _resolve_class_head_mlp_type(name: str) -> MLPType:
+     """Return the token-head MLP enum for the serialized config value."""
+
+     match str(name):
+         case "gelu":
+             return MLPType.GELU
+         case "silu":
+             return MLPType.SILU
+         case "relu":
+             return MLPType.RELU
+         case _ as unreachable:
+             raise ValueError(
+                 "Unsupported class_head_mlp_type for DinacAE export: "
+                 f"{unreachable!r}"
+             )
+
+
+ class DinacAE(nn.Module):
+     """Exported DINAC-AE wrapper with encode/decode/predict_class APIs."""
+
+     def __init__(self, config: DinacAEConfig) -> None:
+         super().__init__()
+         self.config = config
+         self.register_buffer(
+             "latent_norm_running_mean",
+             torch.zeros((config.latent_channels,), dtype=torch.float32),
+         )
+         self.register_buffer(
+             "latent_norm_running_var",
+             torch.ones((config.latent_channels,), dtype=torch.float32),
+         )
+         self.encoder = Encoder(
+             in_channels=int(config.in_channels),
+             patch_size=int(config.patch_size),
+             model_dim=int(config.model_dim),
+             depth=int(config.encoder_depth),
+             bottleneck_dim=int(config.bottleneck_dim),
+             mlp_ratio=float(config.mlp_ratio),
+             mlp_type=str(config.encoder_mlp_type),
+             bottleneck_posterior_kind=str(config.bottleneck_posterior_kind),
+             bottleneck_norm_mode=str(config.bottleneck_norm_mode),
+         )
+         self.decoder = Decoder(
+             in_channels=int(config.in_channels),
+             patch_size=int(config.patch_size),
+             model_dim=int(config.model_dim),
+             depth=int(config.decoder_depth),
+             start_block_count=int(config.decoder_start_blocks),
+             end_block_count=int(config.decoder_end_blocks),
+             bottleneck_dim=int(config.bottleneck_dim),
+             mlp_ratio=float(config.mlp_ratio),
+             depthwise_kernel_size=int(config.depthwise_kernel_size),
+             adaln_low_rank_rank=int(config.adaln_low_rank_rank),
+         )
+         self.dino_token_alignment_head = DinoTokenAlignmentHead(
+             in_channels=int(config.bottleneck_dim),
+             feature_dim=int(config.class_head_feature_dim),
+             model_dim=int(config.class_head_model_dim),
+             head_dim=int(config.class_head_head_dim),
+             mlp_ratio=float(config.class_head_mlp_ratio),
+             mlp_activation=_resolve_class_head_mlp_type(config.class_head_mlp_type),
+             block_index=10_001,
+             register_token_count=int(config.class_head_register_token_count),
+         )
+
+     def _restore_float32_norm_buffers(self) -> None:
+         """Keep latent running stats in float32 after device/dtype moves."""
+
+         self.latent_norm_running_mean = self.latent_norm_running_mean.to(
+             dtype=torch.float32
+         )
+         self.latent_norm_running_var = self.latent_norm_running_var.to(
+             dtype=torch.float32
+         )
+
+     def to(self, *args: object, **kwargs: object) -> DinacAE:
+         """Move the model while preserving float32 latent stats buffers."""
+
+         moved = super().to(*args, **kwargs)
+         if not isinstance(moved, DinacAE):
+             raise RuntimeError(
+                 f"Expected DinacAE after nn.Module.to(), got {type(moved).__name__}"
+             )
+         moved._restore_float32_norm_buffers()
+         return moved
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         path_or_repo_id: str | Path,
+         *,
+         dtype: torch.dtype = torch.bfloat16,
+         device: str | torch.device = "cpu",
+         revision: str | None = None,
+         cache_dir: str | Path | None = None,
+     ) -> DinacAE:
+         """Load a pretrained export from a local directory or the Hub."""
+
+         model_dir = _resolve_model_dir(
+             path_or_repo_id,
+             revision=revision,
+             cache_dir=cache_dir,
+         )
+         config = DinacAEConfig.load(model_dir / "config.json")
+         model = cls(config)
+         safetensors_path = model_dir / "model.safetensors"
+         if safetensors_path.exists():
+             try:
+                 from safetensors.torch import load_file
+             except ImportError as exc:
+                 raise ImportError(
+                     "safetensors is required to load model.safetensors"
+                 ) from exc
+             state_dict = load_file(str(safetensors_path), device="cpu")
+         else:
+             raise FileNotFoundError(
+                 f"No model weights found in {model_dir}. Expected model.safetensors."
+             )
+         model.load_state_dict(state_dict, strict=True)
+         model = model.to(dtype=dtype, device=torch.device(device))
+         model.eval()
+         return model
+
+     def _latent_norm_stats(self) -> tuple[Tensor, Tensor]:
+         """Return ``(mean, std)`` tensors for latent whitening."""
+
+         mean = self.latent_norm_running_mean.view(1, -1, 1, 1)
+         var = self.latent_norm_running_var.view(1, -1, 1, 1)
+         std = torch.sqrt(
+             var.to(torch.float32) + float(self.config.latent_running_stats_eps)
+         )
+         return mean.to(torch.float32), std
+
+     def _require_image_size_divisible(self, height: int, width: int) -> None:
+         """Require image dimensions compatible with the exported patch size."""
+
+         patch = int(self.config.effective_patch_size)
+         if int(height) % patch != 0 or int(width) % patch != 0:
+             raise ValueError(
+                 f"Image height={height} and width={width} must be divisible by "
+                 f"effective_patch_size={patch}"
+             )
+
+     def whiten(self, latents: Tensor) -> Tensor:
+         """Whiten raw latents using exported running stats."""
+
+         z = latents.to(torch.float32)
+         mean, std = self._latent_norm_stats()
+         return (z - mean.to(device=z.device)) / std.to(device=z.device)
+
+     def dewhiten(self, latents: Tensor) -> Tensor:
+         """Undo latent whitening back to the raw decoder scale."""
+
+         z = latents.to(torch.float32)
+         mean, std = self._latent_norm_stats()
+         return z * std.to(device=z.device) + mean.to(device=z.device)
+
+     def encode(self, images: Tensor) -> Tensor:
+         """Encode images to the exported whitened latent space."""
+
+         self._require_image_size_divisible(
+             height=int(images.shape[2]),
+             width=int(images.shape[3]),
+         )
+         model_dtype = next(self.parameters()).dtype
+         latents = self.encoder(images.to(dtype=model_dtype))
+         return self.whiten(latents).to(dtype=model_dtype)
+
+     def encode_posterior(self, images: Tensor) -> EncoderPosterior:
+         """Encode images and return the raw posterior."""
+
+         self._require_image_size_divisible(
+             height=int(images.shape[2]),
+             width=int(images.shape[3]),
+         )
+         model_dtype = next(self.parameters()).dtype
+         return self.encoder.encode_posterior(images.to(dtype=model_dtype))
+
+     def predict_class(self, latents: Tensor) -> Tensor:
+         """Predict the exported DINO class token from whitened latents."""
+
+         dewhitened = self.dewhiten(latents)
+         t_zero = torch.zeros(
+             (int(latents.shape[0]),),
+             device=latents.device,
+             dtype=torch.float32,
+         )
+         head_dtype = self.dino_token_alignment_head.in_proj.weight.dtype
+         device_type = "cuda" if latents.device.type == "cuda" else "cpu"
+         with torch.autocast(device_type=device_type, enabled=False):
+             out = self.dino_token_alignment_head(
+                 dewhitened.to(device=latents.device, dtype=head_dtype),
+                 t=t_zero,
+             )
+         return out.class_token.to(torch.float32)
+
+     def decode(
+         self,
+         latents: Tensor,
+         height: int,
+         width: int,
+         *,
+         inference_config: DinacAEInferenceConfig | None = None,
+     ) -> Tensor:
+         """Decode exported whitened latents to images via VP diffusion."""
+
+         cfg = (
+             inference_config
+             if inference_config is not None
+             else DinacAEInferenceConfig()
+         )
+         self._require_image_size_divisible(height=int(height), width=int(width))
+         batch = int(latents.shape[0])
+         device = latents.device
+         model_dtype = next(self.parameters()).dtype
+         decoder_latents = self.dewhiten(latents).to(device=device, dtype=model_dtype)
+         noise = sample_noise(
+             (batch, int(self.config.in_channels), int(height), int(width)),
+             noise_std=float(self.config.pixel_noise_std),
+             seed=cfg.seed,
+             device=torch.device("cpu"),
+             dtype=torch.float32,
+         )
+         schedule = get_schedule(cfg.schedule, cfg.num_steps).to(device=device)
+         initial_state = make_initial_state(
+             noise=noise.to(device=device),
+             t_start=schedule[0:1],
+             logsnr_min=float(self.config.logsnr_min),
+             logsnr_max=float(self.config.logsnr_max),
+         )
+         device_type = "cuda" if device.type == "cuda" else "cpu"
+         with torch.autocast(device_type=device_type, enabled=False):
+
+             def _forward_fn(
+                 x_t: Tensor,
+                 t: Tensor,
+                 latents_in: Tensor,
+                 *,
+                 drop_middle_blocks: bool = False,
+                 mask_latent_tokens: bool = False,
+             ) -> Tensor:
+                 _ = mask_latent_tokens
+                 return self.decoder(
+                     x_t.to(dtype=model_dtype),
+                     t,
+                     latents_in.to(dtype=model_dtype),
+                     drop_middle_blocks=bool(drop_middle_blocks),
+                 )
+
+             match cfg.sampler:
+                 case "ddim":
+                     sampler_fn = run_ddim
+                 case "dpmpp_2m":
+                     sampler_fn = run_dpmpp_2m
+                 case _ as unreachable:
+                     raise ValueError(f"Unsupported sampler: {unreachable!r}")
+             pdg_mode = "path_drop" if bool(cfg.pdg) else "disabled"
+             return sampler_fn(
+                 forward_fn=_forward_fn,
+                 initial_state=initial_state,
+                 schedule=schedule,
+                 latents=decoder_latents,
+                 logsnr_min=float(self.config.logsnr_min),
+                 logsnr_max=float(self.config.logsnr_max),
+                 pdg_mode=pdg_mode,
+                 pdg_strength=float(cfg.pdg_strength),
+                 device=device,
+             )
+
+     def reconstruct(
+         self,
+         images: Tensor,
+         *,
+         inference_config: DinacAEInferenceConfig | None = None,
+     ) -> Tensor:
+         """Encode then decode one image batch."""
+
+         latents = self.encode(images)
+         _batch, _channels, height, width = images.shape
+         return self.decode(
+             latents,
+             height=int(height),
+             width=int(width),
+             inference_config=inference_config,
+         )
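The `whiten` and `dewhiten` methods above are exact inverses of each other. A minimal scalar sketch of that round trip, in plain Python rather than torch, is shown below. The function names, the sample statistics, and the `eps` default are illustrative assumptions; the real value comes from `config.latent_running_stats_eps`, which is not shown in this diff.

```python
import math


def whiten(values, mean, var, eps=1e-5):
    """Per-channel whitening: (z - mean) / sqrt(var + eps). eps is a placeholder."""
    return [(z - m) / math.sqrt(v + eps) for z, m, v in zip(values, mean, var)]


def dewhiten(values, mean, var, eps=1e-5):
    """Inverse of whiten: z * sqrt(var + eps) + mean."""
    return [z * math.sqrt(v + eps) + m for z, m, v in zip(values, mean, var)]


# Hypothetical per-channel latent values and running statistics.
raw = [0.5, -2.0, 3.25]
mean = [0.1, -1.0, 3.0]
var = [4.0, 0.25, 1.0]

whitened = whiten(raw, mean, var)
restored = dewhiten(whitened, mean, var)
```

Because `encode` returns whitened latents and `decode` and `predict_class` both call `dewhiten` first, callers that manipulate latents are expected to stay in the whitened space.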
dinac_ae/norms.py ADDED
@@ -0,0 +1,39 @@
+ """Channel-wise RMSNorm for NCHW tensors."""
+
+ from __future__ import annotations
+
+ import torch
+ from torch import Tensor, nn
+
+
+ class ChannelWiseRMSNorm(nn.Module):
+     """Channel-wise RMSNorm with float32 reduction for numerical stability.
+
+     Normalizes across channels per spatial position. Supports optional
+     per-channel affine weight and bias.
+     """
+
+     def __init__(self, channels: int, eps: float = 1e-6, affine: bool = True) -> None:
+         super().__init__()
+         self.channels: int = int(channels)
+         self._eps: float = float(eps)
+         if affine:
+             self.weight = nn.Parameter(torch.ones(self.channels))
+             self.bias = nn.Parameter(torch.zeros(self.channels))
+         else:
+             self.register_parameter("weight", None)
+             self.register_parameter("bias", None)
+
+     def forward(self, x: Tensor) -> Tensor:
+         if x.dim() < 2:
+             return x
+         # Float32 accumulation for stability
+         ms = torch.mean(torch.square(x), dim=1, keepdim=True, dtype=torch.float32)
+         inv_rms = torch.rsqrt(ms + self._eps)
+         y = x * inv_rms.to(dtype=x.dtype)
+         if self.weight is not None:
+             shape = (1, -1) + (1,) * (x.dim() - 2)
+             y = y * self.weight.view(shape).to(dtype=x.dtype)
+             if self.bias is not None:
+                 y = y + self.bias.view(shape).to(dtype=x.dtype)
+         return y
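To illustrate what "normalizes across channels per spatial position" means, here is a plain-Python sketch that applies the same arithmetic to the channel vector at a single pixel. The helper name `rms_norm_channels` is hypothetical and the sketch omits the float32 casting that the torch module performs.

```python
import math


def rms_norm_channels(values, eps=1e-6, weight=None, bias=None):
    """RMS-normalize one channel vector (the channels at a single spatial position)."""
    mean_square = sum(v * v for v in values) / len(values)
    inv_rms = 1.0 / math.sqrt(mean_square + eps)
    out = [v * inv_rms for v in values]
    if weight is not None:
        # Optional per-channel scale, then optional per-channel shift.
        out = [o * w for o, w in zip(out, weight)]
        if bias is not None:
            out = [o + b for o, b in zip(out, bias)]
    return out


# Two channels at one pixel; after normalization the mean square is close to 1.
normed = rms_norm_channels([3.0, -4.0])
scaled = rms_norm_channels([3.0, -4.0], weight=[2.0, 2.0], bias=[1.0, 1.0])
```

Unlike LayerNorm, no mean is subtracted, so the sign and relative proportions of the channels are preserved.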
dinac_ae/samplers.py ADDED
@@ -0,0 +1,258 @@
+ """DDIM and DPM++2M samplers for VP diffusion with path-drop PDG support."""
+
+ from __future__ import annotations
+
+ from typing import Protocol
+
+ import torch
+ from torch import Tensor
+
+ from .vp_diffusion import (
+     alpha_sigma_from_logsnr,
+     broadcast_time_like,
+     shifted_cosine_interpolated_logsnr_from_t,
+ )
+
+
+ class DecoderForwardFn(Protocol):
+     """Callable that predicts x0 from (x_t, t, latents) with path-drop PDG flag."""
+
+     def __call__(
+         self,
+         x_t: Tensor,
+         t: Tensor,
+         latents: Tensor,
+         *,
+         drop_middle_blocks: bool = False,
+         mask_latent_tokens: bool = False,
+     ) -> Tensor: ...
+
+
+ def _reconstruct_eps_from_x0(
+     *, x_t: Tensor, x0_hat: Tensor, alpha: Tensor, sigma: Tensor
+ ) -> Tensor:
+     """Reconstruct eps_hat from (x_t, x0_hat) under VP parameterization.
+
+     eps_hat = (x_t - alpha * x0_hat) / sigma. All float32.
+     """
+     alpha_view = broadcast_time_like(alpha, x_t).to(dtype=torch.float32)
+     sigma_view = broadcast_time_like(sigma, x_t).to(dtype=torch.float32)
+     x_t_f32 = x_t.to(torch.float32)
+     x0_f32 = x0_hat.to(torch.float32)
+     return (x_t_f32 - alpha_view * x0_f32) / sigma_view
+
+
+ def _ddim_step(
+     *,
+     x0_hat: Tensor,
+     eps_hat: Tensor,
+     alpha_next: Tensor,
+     sigma_next: Tensor,
+     ref: Tensor,
+ ) -> Tensor:
+     """DDIM step: x_next = alpha_next * x0_hat + sigma_next * eps_hat."""
+     a = broadcast_time_like(alpha_next, ref).to(dtype=torch.float32)
+     s = broadcast_time_like(sigma_next, ref).to(dtype=torch.float32)
+     return a * x0_hat + s * eps_hat
+
+
+ def _predict_with_pdg(
+     forward_fn: DecoderForwardFn,
+     state: Tensor,
+     t_vec: Tensor,
+     latents: Tensor,
+     *,
+     pdg_mode: str,
+     pdg_strength: float,
+ ) -> Tensor:
+     """Run decoder forward with optional PDG guidance.
+
+     Args:
+         forward_fn: Decoder forward function.
+         state: Current noised state [B, C, H, W].
+         t_vec: Timestep vector [B].
+         latents: Encoder latents.
+         pdg_mode: "disabled" or "path_drop".
+         pdg_strength: CFG-like strength for PDG.
+
+     Returns:
+         x0_hat prediction in float32.
+     """
+     match pdg_mode:
+         case "path_drop":
+             x0_uncond = forward_fn(state, t_vec, latents, drop_middle_blocks=True).to(
+                 torch.float32
+             )
+             x0_cond = forward_fn(state, t_vec, latents, drop_middle_blocks=False).to(
+                 torch.float32
+             )
+             return x0_uncond + pdg_strength * (x0_cond - x0_uncond)
+         case "disabled":
+             return forward_fn(state, t_vec, latents, drop_middle_blocks=False).to(
+                 torch.float32
+             )
+         case _ as unreachable:
+             raise ValueError(f"Unsupported PDG mode: {unreachable!r}")
+
+
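The `path_drop` branch above combines two predictions the same way classifier-free guidance does, with the middle-blocks-dropped pass playing the role of the weaker prediction. The following plain-Python sketch shows the arithmetic on toy lists; the helper name `guide` and the sample values are illustrative assumptions, not part of the export.

```python
def guide(x0_weak, x0_full, strength):
    """Extrapolate from the weak prediction toward (and past) the full one."""
    return [w + strength * (f - w) for w, f in zip(x0_weak, x0_full)]


# Toy x0 predictions: "weak" stands in for drop_middle_blocks=True.
weak = [0.0, 1.0]
full = [1.0, 3.0]

no_guidance = guide(weak, full, 1.0)   # strength 1.0 reproduces the full prediction
amplified = guide(weak, full, 1.5)     # strength > 1.0 pushes past it
weak_only = guide(weak, full, 0.0)     # strength 0.0 falls back to the weak prediction
```

Note that `path_drop` costs two decoder forward passes per sampling step, whereas `disabled` costs one.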
+ def run_ddim(
+     *,
+     forward_fn: DecoderForwardFn,
+     initial_state: Tensor,
+     schedule: Tensor,
+     latents: Tensor,
+     logsnr_min: float,
+     logsnr_max: float,
+     log_change_high: float = 0.0,
+     log_change_low: float = 0.0,
+     pdg_mode: str = "disabled",
+     pdg_strength: float = 1.5,
+     device: torch.device | None = None,
+ ) -> Tensor:
+     """Run DDIM sampling loop with path-drop PDG support.
+
+     Args:
+         forward_fn: Decoder forward function (x_t, t, latents) -> x0_hat.
+         initial_state: Starting noised state [B, C, H, W] in float32.
+         schedule: Descending t-schedule [num_steps + 1] in [0, 1].
+         latents: Encoder latents [B, bottleneck_dim, h, w].
+         logsnr_min, logsnr_max: VP schedule endpoints.
+         log_change_high, log_change_low: Shifted-cosine schedule parameters.
+         pdg_mode: "disabled" or "path_drop".
+         pdg_strength: CFG-like strength for PDG.
+         device: Target device.
+
+     Returns:
+         Denoised samples [B, C, H, W] in float32.
+     """
+     run_device = device or initial_state.device
+     batch_size = int(initial_state.shape[0])
+     state = initial_state.to(device=run_device, dtype=torch.float32)
+
+     # Precompute logSNR, alpha, sigma for all schedule points
+     lmb = shifted_cosine_interpolated_logsnr_from_t(
+         schedule.to(device=run_device),
+         logsnr_min=logsnr_min,
+         logsnr_max=logsnr_max,
+         log_change_high=log_change_high,
+         log_change_low=log_change_low,
+     )
+     alpha_sched, sigma_sched = alpha_sigma_from_logsnr(lmb)
+
+     for i in range(int(schedule.numel()) - 1):
+         t_i = schedule[i]
+         a_t = alpha_sched[i].expand(batch_size)
+         s_t = sigma_sched[i].expand(batch_size)
+         a_next = alpha_sched[i + 1].expand(batch_size)
+         s_next = sigma_sched[i + 1].expand(batch_size)
+
+         # Model prediction with optional PDG
+         t_vec = t_i.expand(batch_size).to(device=run_device, dtype=torch.float32)
+         x0_hat = _predict_with_pdg(
+             forward_fn,
+             state,
+             t_vec,
+             latents,
+             pdg_mode=pdg_mode,
+             pdg_strength=pdg_strength,
+         )
+
+         eps_hat = _reconstruct_eps_from_x0(
+             x_t=state, x0_hat=x0_hat, alpha=a_t, sigma=s_t
+         )
+         state = _ddim_step(
+             x0_hat=x0_hat,
+             eps_hat=eps_hat,
+             alpha_next=a_next,
+             sigma_next=s_next,
+             ref=state,
+         )
+
+     return state
+
+
+ def run_dpmpp_2m(
+     *,
+     forward_fn: DecoderForwardFn,
+     initial_state: Tensor,
+     schedule: Tensor,
+     latents: Tensor,
+     logsnr_min: float,
+     logsnr_max: float,
+     log_change_high: float = 0.0,
+     log_change_low: float = 0.0,
+     pdg_mode: str = "disabled",
+     pdg_strength: float = 1.5,
+     device: torch.device | None = None,
+ ) -> Tensor:
+     """Run DPM++2M sampling loop with path-drop PDG support.
+
+     Multi-step solver using exponential integrator formulation in half-lambda space.
+     """
+     run_device = device or initial_state.device
+     batch_size = int(initial_state.shape[0])
+     state = initial_state.to(device=run_device, dtype=torch.float32)
+
+     # Precompute logSNR, alpha, sigma, half-lambda for all schedule points
+     lmb = shifted_cosine_interpolated_logsnr_from_t(
+         schedule.to(device=run_device),
+         logsnr_min=logsnr_min,
+         logsnr_max=logsnr_max,
+         log_change_high=log_change_high,
+         log_change_low=log_change_low,
+     )
+     alpha_sched, sigma_sched = alpha_sigma_from_logsnr(lmb)
+     half_lambda = 0.5 * lmb.to(torch.float32)
+
+     x0_prev: Tensor | None = None
+
+     for i in range(int(schedule.numel()) - 1):
+         t_i = schedule[i]
+         s_t = sigma_sched[i].expand(batch_size)
+         a_next = alpha_sched[i + 1].expand(batch_size)
+         s_next = sigma_sched[i + 1].expand(batch_size)
+
+         # Model prediction with optional PDG
+         t_vec = t_i.expand(batch_size).to(device=run_device, dtype=torch.float32)
+         x0_hat = _predict_with_pdg(
+             forward_fn,
+             state,
+             t_vec,
+             latents,
+             pdg_mode=pdg_mode,
+             pdg_strength=pdg_strength,
+         )
+
+         lam_t = half_lambda[i].expand(batch_size)
+         lam_next = half_lambda[i + 1].expand(batch_size)
+         h = (lam_next - lam_t).to(torch.float32)
+         phi_1 = torch.expm1(-h)
+
+         sigma_ratio = (s_next / s_t).to(torch.float32)
+
+         if i == 0 or x0_prev is None:
+             # First-order step
+             state = (
+                 sigma_ratio.view(-1, *([1] * (state.dim() - 1))) * state
+                 - broadcast_time_like(a_next, state).to(torch.float32)
+                 * broadcast_time_like(phi_1, state).to(torch.float32)
+                 * x0_hat
+             )
+         else:
+             # Second-order step
+             lam_prev = half_lambda[i - 1].expand(batch_size)
+             h_0 = (lam_t - lam_prev).to(torch.float32)
+             r0 = h_0 / h
+             d1_0 = (x0_hat - x0_prev) / broadcast_time_like(r0, x0_hat)
+             common = broadcast_time_like(a_next, state).to(
+                 torch.float32
+             ) * broadcast_time_like(phi_1, state).to(torch.float32)
+             state = (
+                 sigma_ratio.view(-1, *([1] * (state.dim() - 1))) * state
+                 - common * x0_hat
+                 - 0.5 * common * d1_0
+             )
+
+         x0_prev = x0_hat
+
+     return state
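The first-order branch of `run_dpmpp_2m` and the `run_ddim` update are algebraically the same step written two ways, since `exp(-h)` equals `(alpha_t / sigma_t) * (sigma_next / alpha_next)` when `h` is the difference in half-logSNR. The scalar sketch below checks this numerically in plain Python. The function names and the sample logSNR values are illustrative assumptions, and the code works on single floats rather than batched tensors.

```python
import math


def alpha_sigma(logsnr):
    """VP coefficients: alpha^2 = sigmoid(logsnr), sigma^2 = sigmoid(-logsnr)."""
    alpha = math.sqrt(1.0 / (1.0 + math.exp(-logsnr)))
    sigma = math.sqrt(1.0 / (1.0 + math.exp(logsnr)))
    return alpha, sigma


def ddim_step(x_t, x0_hat, logsnr_t, logsnr_next):
    """Recover eps_hat from (x_t, x0_hat), then re-noise to the next level."""
    a_t, s_t = alpha_sigma(logsnr_t)
    a_next, s_next = alpha_sigma(logsnr_next)
    eps_hat = (x_t - a_t * x0_hat) / s_t
    return a_next * x0_hat + s_next * eps_hat


def dpmpp_first_order_step(x_t, x0_hat, logsnr_t, logsnr_next):
    """Exponential-integrator form with h measured in half-logSNR."""
    _, s_t = alpha_sigma(logsnr_t)
    a_next, s_next = alpha_sigma(logsnr_next)
    h = 0.5 * logsnr_next - 0.5 * logsnr_t
    return (s_next / s_t) * x_t - a_next * math.expm1(-h) * x0_hat


# One denoising step from logSNR -2.0 (noisy) to 1.0 (cleaner).
x_ddim = ddim_step(0.8, 0.3, -2.0, 1.0)
x_dpm = dpmpp_first_order_step(0.8, 0.3, -2.0, 1.0)
```

The two samplers therefore only diverge from the second step onward, where DPM++2M adds the `0.5 * common * d1_0` correction built from the previous `x0_hat`.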
dinac_ae/straight_through_encoder.py ADDED
@@ -0,0 +1,57 @@
+ """Patch embedding used by the exported DINAC-AE model."""
+
+ from __future__ import annotations
+
+ from typing import Final
+
+ from torch import Tensor, nn
+
+ __all__ = ["Patchify", "StraightThroughEncoder"]
+
+
+ class StraightThroughEncoder(nn.Module):
+     """Project non-overlapping image patches with a stride-patch convolution."""
+
+     def __init__(
+         self,
+         in_channels: int,
+         patch: int,
+         out_channels: int,
+     ) -> None:
+         super().__init__()
+         if in_channels <= 0:
+             raise ValueError("in_channels must be positive")
+         if patch <= 0:
+             raise ValueError("patch must be positive")
+         if out_channels <= 0:
+             raise ValueError("out_channels must be positive")
+         self.in_channels: Final[int] = int(in_channels)
+         self.patch: Final[int] = int(patch)
+         self._output_channels: Final[int] = int(out_channels)
+         self.proj = nn.Conv2d(
+             self.in_channels,
+             self._output_channels,
+             kernel_size=self.patch,
+             stride=self.patch,
+             bias=True,
+         )
+
+     def forward(self, x: Tensor) -> Tensor:  # type: ignore[override]
+         """Return the patchified token grid."""
+
+         return self.proj(x)
+
+     @property
+     def output_channels(self) -> int:
+         """Return the output channel width produced by the encoder."""
+
+         return int(self._output_channels)
+
+     @property
+     def latent_channels(self) -> int:
+         """Alias for ``output_channels`` to match encoder interface shape."""
+
+         return int(self._output_channels)
+
+
+ Patchify = StraightThroughEncoder
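Because the convolution uses `kernel_size == stride == patch` with no padding, each output position covers exactly one non-overlapping patch, and an input of height `H` and width `W` yields an `H / patch` by `W / patch` token grid when both are divisible by `patch`. The small sketch below computes that grid in plain Python; the helper name `patch_grid` is hypothetical, and the divisibility check mirrors the wrapper's `_require_image_size_divisible` rather than anything the convolution itself enforces.

```python
def patch_grid(height, width, patch):
    """Return the (rows, cols) token grid for non-overlapping square patches."""
    if patch <= 0:
        raise ValueError("patch must be positive")
    if height % patch != 0 or width % patch != 0:
        raise ValueError(
            f"height={height} and width={width} must be divisible by patch={patch}"
        )
    return height // patch, width // patch


# A 256x192 image with 16x16 patches gives a 16x12 grid (192 tokens).
grid = patch_grid(256, 192, 16)
token_count = grid[0] * grid[1]
```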
dinac_ae/time_embed.py ADDED
@@ -0,0 +1,83 @@
+ """Sinusoidal timestep embedding with MLP projection."""
+
+ from __future__ import annotations
+
+ import math
+
+ import torch
+ from torch import Tensor, nn
+
+
+ def _log_spaced_frequencies(
+     half: int, max_period: float, *, device: torch.device | None = None
+ ) -> Tensor:
+     """Log-spaced frequencies for sinusoidal embedding."""
+     return torch.exp(
+         -math.log(max_period)
+         * torch.arange(half, device=device, dtype=torch.float32)
+         / max(float(half - 1), 1.0)
+     )
+
+
+ def sinusoidal_time_embedding(
+     t: Tensor,
+     dim: int,
+     *,
+     max_period: float = 10000.0,
+     scale: float | None = None,
+     freqs: Tensor | None = None,
+ ) -> Tensor:
+     """Sinusoidal timestep embedding (DDPM/DiT-style). Always float32."""
+     t32 = t.to(torch.float32)
+     if scale is not None:
+         t32 = t32 * float(scale)
+     half = dim // 2
+     if freqs is not None:
+         freqs = freqs.to(device=t32.device, dtype=torch.float32)
+     else:
+         freqs = _log_spaced_frequencies(half, max_period, device=t32.device)
+     angles = t32[:, None] * freqs[None, :]
+     return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
+
+
+ class SinusoidalTimeEmbeddingMLP(nn.Module):
+     """Sinusoidal time embedding followed by Linear -> SiLU -> Linear."""
+
+     def __init__(
+         self,
+         dim: int,
+         *,
+         freq_dim: int = 256,
+         hidden_mult: float = 1.0,
+         time_scale: float = 1000.0,
+         max_period: float = 10000.0,
+     ) -> None:
+         super().__init__()
+         self.dim = int(dim)
+         self.freq_dim = int(freq_dim)
+         self.time_scale = float(time_scale)
+         self.max_period = float(max_period)
+         hidden_dim = max(int(round(int(dim) * float(hidden_mult))), 1)
+
+         freqs = _log_spaced_frequencies(self.freq_dim // 2, self.max_period)
+         self.register_buffer("freqs", freqs, persistent=True)
+
+         self.proj_in = nn.Linear(self.freq_dim, hidden_dim)
+         self.act = nn.SiLU()
+         self.proj_out = nn.Linear(hidden_dim, self.dim)
+
+     def forward(self, t: Tensor) -> Tensor:
+         freqs: Tensor = self.freqs  # type: ignore[assignment]
+         emb_freq = sinusoidal_time_embedding(
+             t.to(torch.float32),
+             self.freq_dim,
+             max_period=self.max_period,
+             scale=self.time_scale,
+             freqs=freqs,
+         )
+         dtype_in = self.proj_in.weight.dtype
+         hidden = self.proj_in(emb_freq.to(dtype_in))
+         hidden = self.act(hidden)
+         if hidden.dtype != self.proj_out.weight.dtype:
+             hidden = hidden.to(self.proj_out.weight.dtype)
+         return self.proj_out(hidden)
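As a rough illustration of the frequency layout used above, the following plain-Python sketch embeds a single scalar timestep. It follows the same conventions as the file (frequencies spaced from 1 down to `1 / max_period`, divided by `half - 1`, with all sines first and all cosines second), but the function names are hypothetical and it omits the MLP projection.

```python
import math


def log_spaced_frequencies(half, max_period=10000.0):
    """Frequencies from 1.0 down to 1/max_period, evenly spaced in log space."""
    denom = max(float(half - 1), 1.0)
    return [math.exp(-math.log(max_period) * k / denom) for k in range(half)]


def sinusoidal_embedding(t, dim, max_period=10000.0, scale=None):
    """Embed one scalar timestep as [sin(t*f_0..), cos(t*f_0..)]."""
    if scale is not None:
        t = t * scale
    freqs = log_spaced_frequencies(dim // 2, max_period)
    angles = [t * f for f in freqs]
    return [math.sin(a) for a in angles] + [math.cos(a) for a in angles]


# t in [0, 1] is scaled by 1000 before embedding, as time_scale does in the module.
freqs = log_spaced_frequencies(4)
emb = sinusoidal_embedding(0.5, 8, scale=1000.0)
```

One consequence of `dim // 2` is that an odd `dim` would produce `dim - 1` features, so an even `freq_dim` such as the default 256 is the safe choice.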
dinac_ae/vp_diffusion.py ADDED
@@ -0,0 +1,152 @@
+ """VP diffusion math: logSNR schedules, alpha/sigma computation, noise construction."""
+
+ from __future__ import annotations
+
+ import math
+
+ import torch
+ import torch.nn.functional as F
+ from torch import Tensor
+
+
+ def alpha_sigma_from_logsnr(lmb: Tensor) -> tuple[Tensor, Tensor]:
+     """Compute (alpha, sigma) from logSNR in float32.
+
+     VP constraint: alpha^2 + sigma^2 = 1.
+     """
+     lmb32 = lmb.to(dtype=torch.float32)
+     alpha = torch.exp(0.5 * F.logsigmoid(lmb32))
+     sigma = torch.exp(0.5 * F.logsigmoid(-lmb32))
+     return alpha, sigma
+
+
+ def broadcast_time_like(coeff: Tensor, x: Tensor) -> Tensor:
+     """Broadcast [B] coefficient to match x for per-sample scaling."""
+     view_shape = (int(x.shape[0]),) + (1,) * (x.dim() - 1)
+     return coeff.view(view_shape)
+
+
+ def _cosine_interpolated_params(
+     logsnr_min: float, logsnr_max: float
+ ) -> tuple[float, float]:
+     """Compute (a, b) for cosine-interpolated logSNR schedule.
+
+     logsnr(t) = -2 * log(tan(a*t + b))
+     logsnr(0) = logsnr_max, logsnr(1) = logsnr_min
+     """
+     b = math.atan(math.exp(-0.5 * logsnr_max))
+     a = math.atan(math.exp(-0.5 * logsnr_min)) - b
+     return a, b
+
+
+ def cosine_interpolated_logsnr_from_t(
+     t: Tensor, *, logsnr_min: float, logsnr_max: float
+ ) -> Tensor:
+     """Map t in [0,1] to logSNR via cosine-interpolated schedule. Always float32."""
+     a, b = _cosine_interpolated_params(logsnr_min, logsnr_max)
+     t32 = t.to(dtype=torch.float32)
+     a_t = torch.tensor(a, device=t32.device, dtype=torch.float32)
+     b_t = torch.tensor(b, device=t32.device, dtype=torch.float32)
+     u = a_t * t32 + b_t
+     return -2.0 * torch.log(torch.tan(u))
+
+
+ def shifted_cosine_interpolated_logsnr_from_t(
+     t: Tensor,
+     *,
+     logsnr_min: float,
+     logsnr_max: float,
+     log_change_high: float = 0.0,
+     log_change_low: float = 0.0,
+ ) -> Tensor:
+     """SiD2 "shifted cosine" schedule: logSNR with resolution-dependent shifts.
+
+     lambda(t) = (1-t) * (base(t) + log_change_high) + t * (base(t) + log_change_low)
+     """
+     base = cosine_interpolated_logsnr_from_t(
+         t, logsnr_min=logsnr_min, logsnr_max=logsnr_max
+     )
+     t32 = t.to(dtype=torch.float32)
+     high = base + float(log_change_high)
+     low = base + float(log_change_low)
+     return (1.0 - t32) * high + t32 * low
+
+
+ def get_schedule(schedule_type: str, num_steps: int) -> Tensor:
+     """Generate a descending t-schedule in [0, 1] for VP diffusion sampling.
+
+     ``num_steps`` is the number of function evaluations (NFE = decoder forward
+     passes). Internally the schedule has ``num_steps + 1`` time points
+     (including both endpoints).
+
+     Args:
+         schedule_type: "linear" or "cosine".
+         num_steps: Number of decoder forward passes (NFE), >= 1.
+
+     Returns:
+         Descending 1D tensor with ``num_steps + 1`` elements from ~1.0 to ~0.0.
+     """
+     if int(num_steps) < 1:
+         raise ValueError("num_steps must be at least 1")
+     n = int(num_steps) + 1
+     match schedule_type:
+         case "linear":
+             base = torch.linspace(0.0, 1.0, n)
+         case "cosine":
+             i = torch.arange(n, dtype=torch.float32)
+             base = 0.5 * (1.0 - torch.cos(math.pi * (i / (n - 1))))
+         case _ as unreachable:
+             raise ValueError(
+                 f"Unsupported schedule type: {unreachable!r}. "
+                 "Use 'linear' or 'cosine'."
+             )
+     # Descending: high t (noisy) -> low t (clean)
+     return torch.flip(base, dims=[0])
+
+
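To make the `num_steps + 1` convention in `get_schedule` concrete, here is a plain-Python sketch that builds the same two schedules as lists. The function name `make_schedule` is hypothetical, and double-precision floats are used here, so values can differ from the float32 torch tensors in the last digits.

```python
import math


def make_schedule(kind, num_steps):
    """Descending time points in [0, 1]: num_steps intervals, num_steps + 1 points."""
    if num_steps < 1:
        raise ValueError("num_steps must be at least 1")
    n = num_steps + 1
    if kind == "linear":
        base = [i / (n - 1) for i in range(n)]
    elif kind == "cosine":
        base = [0.5 * (1.0 - math.cos(math.pi * i / (n - 1))) for i in range(n)]
    else:
        raise ValueError(f"Unsupported schedule type: {kind!r}")
    # Reverse so sampling starts at t=1 (noisy) and ends at t=0 (clean).
    return base[::-1]


# Four decoder evaluations need five time points.
linear = make_schedule("linear", 4)
cosine = make_schedule("cosine", 4)
```

The cosine variant places its points more densely near both endpoints than the linear one, so the first and last sampling steps are smaller.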
+ def make_initial_state(
+     *,
+     noise: Tensor,
+     t_start: Tensor,
+     logsnr_min: float,
+     logsnr_max: float,
+     log_change_high: float = 0.0,
+     log_change_low: float = 0.0,
+ ) -> Tensor:
+     """Construct VP initial state x_t0 = sigma_start * noise (since x0=0).
+
+     All math in float32.
+     """
+     batch = int(noise.shape[0])
+     lmb_start = shifted_cosine_interpolated_logsnr_from_t(
+         t_start.expand(batch).to(dtype=torch.float32),
+         logsnr_min=logsnr_min,
+         logsnr_max=logsnr_max,
+         log_change_high=log_change_high,
+         log_change_low=log_change_low,
+     )
+     _alpha_start, sigma_start = alpha_sigma_from_logsnr(lmb_start)
+     sigma_view = broadcast_time_like(sigma_start, noise)
+     return sigma_view * noise.to(dtype=torch.float32)
+
+
+ def sample_noise(
+     shape: tuple[int, ...],
+     *,
+     noise_std: float = 1.0,
+     seed: int | None = None,
+     device: torch.device | None = None,
+     dtype: torch.dtype = torch.float32,
+ ) -> Tensor:
+     """Sample Gaussian noise with optional seeding. CPU-seeded for reproducibility."""
+     if seed is None:
+         noise = torch.randn(
+             shape, device=device or torch.device("cpu"), dtype=torch.float32
+         )
+     else:
+         gen = torch.Generator(device="cpu")
+         gen.manual_seed(int(seed))
+         noise = torch.randn(shape, generator=gen, device="cpu", dtype=torch.float32)
+     noise = noise.mul(float(noise_std))
+     target_device = device if device is not None else torch.device("cpu")
+     return noise.to(device=target_device, dtype=dtype)
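The schedule math in this file can be sanity-checked with scalars: the cosine-interpolated map should hit `logsnr_max` at `t = 0` and `logsnr_min` at `t = 1`, and the resulting `alpha` and `sigma` should satisfy the VP constraint. The plain-Python sketch below does that. The endpoint values `-10.0` and `10.0` are placeholders chosen for illustration; the exported model reads its real endpoints from `config.logsnr_min` and `config.logsnr_max`, which are not shown here.

```python
import math


def cosine_interp_params(logsnr_min, logsnr_max):
    """Solve for (a, b) so that -2*log(tan(a*t + b)) hits both endpoints."""
    b = math.atan(math.exp(-0.5 * logsnr_max))
    a = math.atan(math.exp(-0.5 * logsnr_min)) - b
    return a, b


def logsnr_from_t(t, logsnr_min, logsnr_max):
    """Map t in [0, 1] to logSNR; t=0 is clean (max), t=1 is noisy (min)."""
    a, b = cosine_interp_params(logsnr_min, logsnr_max)
    return -2.0 * math.log(math.tan(a * t + b))


def alpha_sigma(logsnr):
    """alpha = sqrt(sigmoid(logsnr)), sigma = sqrt(sigmoid(-logsnr))."""
    alpha = math.sqrt(1.0 / (1.0 + math.exp(-logsnr)))
    sigma = math.sqrt(1.0 / (1.0 + math.exp(logsnr)))
    return alpha, sigma


# Placeholder endpoints, assumed for illustration only.
LOGSNR_MIN, LOGSNR_MAX = -10.0, 10.0

at_clean = logsnr_from_t(0.0, LOGSNR_MIN, LOGSNR_MAX)
at_noisy = logsnr_from_t(1.0, LOGSNR_MIN, LOGSNR_MAX)
alpha_start, sigma_start = alpha_sigma(at_noisy)
```

With these placeholder endpoints `sigma_start` is very close to 1, which is why `make_initial_state` can treat the starting state as scaled pure noise.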
dit/attention_blocks.py ADDED
@@ -0,0 +1,240 @@
+ """Dense SDPA attention blocks used by the DINAC-AE export."""
+
+ from __future__ import annotations
+
+ from collections.abc import Callable
+
+ import torch
+ import torch.nn.functional as F
+ from torch import Tensor, nn
+
+ from common.norms import RMSNorm
+ from common.rope import rotate_half, rotate_half_adjacent
+ from dit.position_encoding import DiTPositionEncoding
+
+
+ def _axial_rope_rotate_fn(
+     position_encoding: DiTPositionEncoding,
+ ) -> Callable[[Tensor], Tensor]:
+     """Return the head-dimension rotation matching the configured RoPE layout."""
+
+     match position_encoding:
+         case (
+             DiTPositionEncoding.ROPE_2D_AXIAL_DILATED
+             | DiTPositionEncoding.ROPE_2D_AXIAL_NORMALIZED
+             | DiTPositionEncoding.ROPE_2D_AXIAL_FREQ_AWARE
+             | DiTPositionEncoding.ROPE_1D
+         ):
+             return rotate_half
+         case (
+             DiTPositionEncoding.ROPE_2D_AXIAL_UNNORMALIZED
+             | DiTPositionEncoding.ROPE_2D_AXIAL_UNNORMALIZED_DILATED
+             | DiTPositionEncoding.ROPE_2D_AXIAL_BETA_WARP
+             | DiTPositionEncoding.ROPE_2D_AXIAL_ALPHA_WARP
+             | DiTPositionEncoding.ROPE_3D_ZIMAGE
+         ):
+             return rotate_half_adjacent
+         case _ as unreachable:
+             raise ValueError(f"Unsupported RoPE position encoding: {unreachable}")
+
+
+ class DitSelfAttentionCore(nn.Module):
+     """Dense self-attention core with optional axial RoPE on Q/K."""
+
+     d_model: int
+     n_heads: int
+     head_dim: int
+     position_encoding: DiTPositionEncoding
+     qkv: nn.Linear
+     proj_out: nn.Linear
+     q_norm: RMSNorm
+     k_norm: RMSNorm
+
+     def __init__(
+         self,
+         d_model: int,
+         n_heads: int,
+         *,
+         position_encoding: DiTPositionEncoding,
+     ) -> None:
+         super().__init__()
+         if d_model % n_heads != 0:
+             raise ValueError("d_model must be divisible by n_heads")
+         self.d_model = int(d_model)
+         self.n_heads = int(n_heads)
+         self.head_dim = int(self.d_model // self.n_heads)
+         self.position_encoding = position_encoding
+         self.qkv = nn.Linear(self.d_model, 3 * self.d_model, bias=False)
+         self.proj_out = nn.Linear(self.d_model, self.d_model, bias=False)
+         self.q_norm = RMSNorm(self.head_dim)
+         self.k_norm = RMSNorm(self.head_dim)
+
+     def reset_parameters(self) -> None:
+         """Reset projections to their initialization."""
+
+         nn.init.xavier_uniform_(self.qkv.weight)
+         nn.init.xavier_uniform_(self.proj_out.weight)
+
+     def forward(
+         self, tokens: Tensor, *, rope_sincos: tuple[Tensor, Tensor] | None
+     ) -> Tensor:
+         """Apply dense self-attention to ``[B, N, D]`` tokens."""
+
+         batch, sequence_length, _width = tokens.shape
+         qkv = self.qkv(tokens)
+         q, k, v = qkv.chunk(3, dim=-1)
+         q = q.view(batch, sequence_length, self.n_heads, self.head_dim).transpose(1, 2)
+         k = k.view(batch, sequence_length, self.n_heads, self.head_dim).transpose(1, 2)
+         v = v.view(batch, sequence_length, self.n_heads, self.head_dim).transpose(1, 2)
+         q = self.q_norm(q.contiguous())
+         k = self.k_norm(k.contiguous())
+         q, k = self._apply_axial_rope_dense(q, k, rope_sincos=rope_sincos)
+         attn = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
+         attn = (
+             attn.transpose(1, 2).contiguous().view(batch, sequence_length, self.d_model)
+         )
+         return self.proj_out(attn)
+
98
+ def _apply_axial_rope_dense(
99
+ self,
100
+ q: Tensor,
101
+ k: Tensor,
102
+ *,
103
+ rope_sincos: tuple[Tensor, Tensor] | None,
104
+ ) -> tuple[Tensor, Tensor]:
105
+ """Apply axial RoPE to dense Q/K tensors."""
106
+
107
+ if rope_sincos is None:
108
+ return q, k
109
+ sin, cos = rope_sincos
110
+ rope_len = int(sin.shape[-2])
111
+ rope_dtype = sin.dtype
112
+ q_dtype = q.dtype
113
+ k_dtype = k.dtype
114
+ q_rope = q.to(dtype=rope_dtype)
115
+ k_rope = k.to(dtype=rope_dtype)
116
+ match sin.dim():
117
+ case 2:
118
+ sin_b = sin.view(1, 1, rope_len, self.head_dim)
119
+ cos_b = cos.view(1, 1, rope_len, self.head_dim)
120
+ case 3:
121
+ sin_b = sin.view(int(q.shape[0]), 1, rope_len, self.head_dim)
122
+ cos_b = cos.view(int(q.shape[0]), 1, rope_len, self.head_dim)
123
+ case _ as unreachable:
124
+ raise ValueError(f"Unsupported RoPE tensor rank: {int(unreachable)}")
125
+ rotate = _axial_rope_rotate_fn(self.position_encoding)
126
+ q_span = q_rope[:, :, :rope_len, :]
127
+ k_span = k_rope[:, :, :rope_len, :]
128
+ q_head = (q_span * cos_b) + (rotate(q_span) * sin_b)
129
+ k_head = (k_span * cos_b) + (rotate(k_span) * sin_b)
130
+ q_rope = torch.cat([q_head, q_rope[:, :, rope_len:, :]], dim=2)
131
+ k_rope = torch.cat([k_head, k_rope[:, :, rope_len:, :]], dim=2)
132
+ return q_rope.to(dtype=q_dtype), k_rope.to(dtype=k_dtype)
133
+
134
+
+ class CrossAttentionCore(nn.Module):
+     """Dense cross-attention core used by the class-token readout."""
+
+     query_dim: int
+     context_dim: int
+     context_extra_dim: int
+     key_extra_dim: int
+     n_heads: int
+     head_dim: int
+     attn_dim: int
+     context_in_dim: int
+     attn_dropout: float
+     kv_proj: nn.Linear
+     k_extra_proj: nn.Linear | None
+     out_proj: nn.Linear
+     q_norm_heads: RMSNorm
+     k_norm_heads: RMSNorm
+
+     def __init__(
+         self,
+         *,
+         query_dim: int,
+         context_dim: int,
+         n_heads: int,
+         head_dim: int,
+         context_extra_dim: int = 0,
+         key_extra_dim: int = 0,
+         attn_dropout: float = 0.0,
+     ) -> None:
+         super().__init__()
+         self.query_dim = int(query_dim)
+         self.context_dim = int(context_dim)
+         self.context_extra_dim = int(context_extra_dim)
+         self.key_extra_dim = int(key_extra_dim)
+         self.n_heads = int(n_heads)
+         self.head_dim = int(head_dim)
+         self.attn_dim = int(self.n_heads * self.head_dim)
+         self.context_in_dim = int(self.context_dim + self.context_extra_dim)
+         self.attn_dropout = float(attn_dropout)
+         self.kv_proj = nn.Linear(self.context_in_dim, 2 * self.attn_dim, bias=False)
+         if self.key_extra_dim == 0:
+             self.k_extra_proj = None
+         else:
+             self.k_extra_proj = nn.Linear(self.key_extra_dim, self.attn_dim, bias=False)
+         self.out_proj = nn.Linear(self.attn_dim, self.query_dim, bias=False)
+         self.q_norm_heads = RMSNorm(self.head_dim)
+         self.k_norm_heads = RMSNorm(self.head_dim)
+
+     def reset_parameters(self) -> None:
+         """Reset projections to their initialization."""
+
+         nn.init.xavier_uniform_(self.kv_proj.weight)
+         if self.k_extra_proj is not None:
+             nn.init.xavier_uniform_(self.k_extra_proj.weight)
+         nn.init.xavier_uniform_(self.out_proj.weight)
+
+     def _split_heads(self, x: Tensor) -> Tensor:
+         batch, sequence_length, _width = x.shape
+         return x.view(batch, sequence_length, self.n_heads, self.head_dim).transpose(
+             1, 2
+         )
+
+     def _merge_heads(self, x: Tensor) -> Tensor:
+         batch, _heads, sequence_length, _head_dim = x.shape
+         return (
+             x.transpose(1, 2).contiguous().view(batch, sequence_length, self.attn_dim)
+         )
+
+     def forward(
+         self,
+         q_tokens: Tensor,
+         kv_tokens: Tensor,
+         *,
+         training: bool,
+         key_extra: Tensor | None = None,
+         key_padding_mask: Tensor | None = None,
+     ) -> Tensor:
+         """Apply dense cross-attention to query and context tokens."""
+
+         kv = self.kv_proj(kv_tokens)
+         k, v = kv.chunk(2, dim=-1)
+         if self.k_extra_proj is not None and key_extra is not None:
+             k = k + self.k_extra_proj(key_extra)
+         q = self.q_norm_heads(self._split_heads(q_tokens).contiguous())
+         k = self.k_norm_heads(self._split_heads(k).contiguous())
+         v = self._split_heads(v).contiguous()
+         if key_padding_mask is None:
+             attn_mask = None
+         else:
+             # Additive float mask: keys whose key_padding_mask entry is False
+             # receive -inf (excluded); entries that are True receive 0.0.
+             attn_mask = (~key_padding_mask).to(dtype=q.dtype)
+             attn_mask = attn_mask.view(
+                 key_padding_mask.shape[0], 1, 1, key_padding_mask.shape[1]
+             )
+             attn_mask = attn_mask.masked_fill(attn_mask > 0, float("-inf"))
+         attn = F.scaled_dot_product_attention(
+             q,
+             k,
+             v,
+             attn_mask=attn_mask,
+             dropout_p=self.attn_dropout if training else 0.0,
+             is_causal=False,
+         )
+         return self.out_proj(self._merge_heads(attn))
+
+
+ __all__ = ["CrossAttentionCore", "DitSelfAttentionCore"]
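Read literally, the mask built in `CrossAttentionCore.forward` assigns `-inf` to keys whose `key_padding_mask` entry is `False` and `0.0` to keys whose entry is `True`. The plain-Python sketch below mirrors that arithmetic for a single query row. The helper names are hypothetical, and whether callers of the real module pass `True` for the keys they want attended is an assumption worth checking against the call site.

```python
import math


def additive_mask(key_padding_mask: list[bool]) -> list[float]:
    """Mirror of `(~mask).to(float)` followed by `masked_fill(mask > 0, -inf)`."""
    inverted = [0.0 if flag else 1.0 for flag in key_padding_mask]
    return [float("-inf") if value > 0 else value for value in inverted]


def masked_softmax(scores: list[float], mask: list[float]) -> list[float]:
    """Softmax over one row of attention logits after adding the mask."""
    shifted = [s + m for s, m in zip(scores, mask)]
    peak = max(shifted)
    exps = [math.exp(v - peak) for v in shifted]
    total = sum(exps)
    return [e / total for e in exps]


mask = additive_mask([True, True, False])
weights = masked_softmax([1.0, 2.0, 3.0], mask)
print(mask)
print([round(w, 4) for w in weights])
```

In this toy row the third key has the largest logit but ends up with zero weight, because its mask entry is `False`. A row in which every entry is `False` would leave the softmax undefined, so this sketch assumes at least one `True` per row.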
dit/axial_rope2d.py ADDED
@@ -0,0 +1,1728 @@
+ from __future__ import annotations
+
+ import math
+ from dataclasses import dataclass, replace
+ from enum import Enum
+ from typing import Final, cast
+
+ import torch
+ from torch import Tensor, nn
+
+ __all__ = [
+     "AxialRoPE2D",
+     "AxialRoPE2DAlphaWarpConfig",
+     "AxialRoPE2DBetaWarpConfig",
+     "AxialRoPE2DConfig",
+     "AxialRoPE2DCoordMode",
+     "AxialRoPE2DDimLayout",
+     "AxialRoPE2DDyPE",
+     "AxialRoPE2DDyPEConfig",
+     "AxialRoPE2DFrequencyAwareConfig",
+     "AxialRoPE2DNormalizeCoords",
+     "DyPERoPEMethod",
+     "build_axial_rope2d_dype",
+     "build_axial_rope2d_inference_warp_with_strength",
+     "build_axial_rope2d_with_lumina_frequency_warp",
+     "lumina_frequency_aware_periods_for_axis",
+     "set_axial_rope2d_dype_noise_time",
+ ]
+
+
+ class AxialRoPE2DNormalizeCoords(Enum):
+     """Coordinate normalization strategy for axial 2D RoPE (DINOv3-style)."""
+
+     MIN = "min"
+     MAX = "max"
+     SEPARATE = "separate"
+
+
+ class AxialRoPE2DCoordMode(Enum):
+     """Coordinate grid mode for axial 2D RoPE.
+
+     - ``DINOV3_NORMALIZED``: DINOv3-style normalized patch-centre coordinates in
+       ``[-1, 1]`` (after normalization).
+     - ``PATCH_INDICES``: Standard unnormalized patch-grid coordinates in patch
+       units (e.g., ``x in [0, W-1]``, ``y in [0, H-1]``).
+     """
+
+     DINOV3_NORMALIZED = "dinov3_normalized"
+     PATCH_INDICES = "patch_indices"
+
+
+ class AxialRoPE2DDimLayout(Enum):
+     """Layout of angles along the head-dimension.
+
+     The layout must match the rotation convention used when applying RoPE to Q/K.
+
+     - ``HALF_SPLIT``: LLaMA-style layout compatible with ``common.rope.rotate_half``
+       (splits last dim into two halves).
+     - ``PAIR_INTERLEAVED``: EVA-02 / SpeedrunDiT-style layout compatible with an
+       adjacent-pair rotate_half (pairs consecutive dims).
+
+     TODO(refactor): Standardize on ``PAIR_INTERLEAVED`` throughout DiT to reduce
+     complexity and avoid layout mismatches, then delete ``HALF_SPLIT`` and any
+     related branching once the migration is complete.
+     """
+
+     HALF_SPLIT = "half_split"
+     PAIR_INTERLEAVED = "pair_interleaved"
+
+
+ class DyPERoPEMethod(Enum):
+     """Dynamic position extrapolation method applied to inference RoPE."""
+
+     VISION_YARN = "vision_yarn"
+     DY_YARN = "dy_yarn"
+     DY_NTK = "dy_ntk"
+
+
+ @dataclass(frozen=True)
+ class AxialRoPE2DDyPEConfig:
+     """Inference-only DyPE controls for axial RoPE.
+
+     Args:
+         method: Dynamic extrapolation rule to apply.
+         ref_h_tokens: Training/reference token height.
+         ref_w_tokens: Training/reference token width.
+         lambda_s: Dynamic extrapolation magnitude.
+         lambda_t: Dynamic extrapolation noise-time exponent.
+         yarn_beta_0: YaRN first-ramp high rotation threshold.
+         yarn_beta_1: YaRN first-ramp low rotation threshold.
+         yarn_gamma_0: YaRN base-blend high rotation threshold.
+         yarn_gamma_1: YaRN base-blend low rotation threshold.
+         yarn_attention_scale: Apply YaRN's static attention magnitude correction.
+     """
+
+     method: DyPERoPEMethod
+     ref_h_tokens: int
+     ref_w_tokens: int
+     lambda_s: float = 2.0
+     lambda_t: float = 2.0
+     yarn_beta_0: float = 1.25
+     yarn_beta_1: float = 0.75
+     yarn_gamma_0: float = 16.0
+     yarn_gamma_1: float = 2.0
+     yarn_attention_scale: bool = True
+
+     def __post_init__(self) -> None:
+         if not isinstance(self.method, DyPERoPEMethod):
+             raise TypeError("method must be a DyPERoPEMethod")
+         if int(self.ref_h_tokens) <= 0 or int(self.ref_w_tokens) <= 0:
+             raise ValueError("ref_h_tokens and ref_w_tokens must be positive")
+         for name, value in (
+             ("lambda_s", self.lambda_s),
+             ("lambda_t", self.lambda_t),
+             ("yarn_beta_0", self.yarn_beta_0),
+             ("yarn_beta_1", self.yarn_beta_1),
+             ("yarn_gamma_0", self.yarn_gamma_0),
+             ("yarn_gamma_1", self.yarn_gamma_1),
+         ):
+             v = float(value)
+             if not math.isfinite(v) or v <= 0.0:
+                 raise ValueError(f"{name} must be finite and > 0")
+         if not isinstance(self.yarn_attention_scale, bool):
+             raise TypeError("yarn_attention_scale must be a bool")
+
+
+ @dataclass(frozen=True)
+ class AxialRoPE2DFrequencyAwareConfig:
+     """Lumina/Next-DiT-style frequency-aware RoPE warping for one token grid.
+
+     This config implements a per-axis, per-band frequency warp that depends on
+     the input axis length ``L`` relative to a reference length ``L_ref``:
+
+     - Define the axis scale ``s = L / L_ref``.
+     - RoPE is parameterized by *periods* (wavelengths in tokens) ``period[d]``.
+       In this module's axial parameterization (with patch-index coordinates),
+       the angle for coordinate ``p`` and band ``d`` is:
+
+           angle(p, d) = 2π * p / period[d]
+
+       so the wavelength of band ``d`` is exactly ``period[d]`` tokens.
+
+     - Pick a *boundary wavelength* ``L_boundary`` (in tokens), expressed as a
+       trainable multiplier around the reference length:
+
+           L_boundary = L_ref * exp(boundary_log_multiplier)
+
+       The scalar ``boundary_log_multiplier`` is shared across H/W axes (and
+       initialized by this config).
+
+     - Define a (possibly fractional) boundary band index ``d*`` as the band
+       whose wavelength equals ``L_boundary``:
+
+           period(d*) = L_boundary
+
+       In practice we compute ``d*`` by linear interpolation in log-period space
+       (periods are geometric for both supported period parametrizations).
+
+     - The Lumina/Next-DiT implicit exponent ramp is then:
+
+           alpha[d] = clamp(d / d*, 0, 1)
+
+       where:
+       - high-frequency bands (small d) have alpha≈0 (extrapolation-like),
+       - low-frequency bands (large d) have alpha→1 (interpolation-like),
+       - alpha is capped at 1 to ensure we never compress a band more than
+         plain position interpolation would.
+
+     - Finally, warp the periods per axis:
+
+           period'[d] = period[d] * s ** alpha[d]
+
+       Equivalently, angular frequencies warp as:
+
+           omega'[d] = omega[d] / s ** alpha[d]
+
+     Notes
+     -----
+     - This warp is only meaningful for patch-index coordinates
+       (``AxialRoPE2DCoordMode.PATCH_INDICES``). Mixing it with normalized
+       coordinates would create an implicit "gauge switch"; we fail fast.
+     - The boundary multiplier is trainable by construction (it is stored as an
+       nn.Parameter inside AxialRoPE2D when this config is present).
+     """
+
+     ref_h_tokens: int
+     ref_w_tokens: int
+     boundary_log_multiplier_init: float
+
+     def __post_init__(self) -> None:
+         if int(self.ref_h_tokens) <= 0 or int(self.ref_w_tokens) <= 0:
+             raise ValueError("ref_h_tokens and ref_w_tokens must be positive")
+         init = float(self.boundary_log_multiplier_init)
+         if not math.isfinite(init):
+             raise ValueError("boundary_log_multiplier_init must be finite")
+
+
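The frequency-aware warp described in the `AxialRoPE2DFrequencyAwareConfig` docstring can be traced by hand on a toy band set. The sketch below is plain Python rather than the module's tensor code. It assumes `angle_multiplier = 2π` so that a period equals a wavelength in tokens, and its helper names and toy periods are made up for illustration.

```python
import math


def geometric_periods(min_period: float, max_period: float, count: int) -> list[float]:
    """Hypothetical helper: `count` geometrically spaced periods, in tokens."""
    ratio = max_period / min_period
    return [min_period * ratio ** (i / (count - 1)) for i in range(count)]


def boundary_band_index(periods: list[float], boundary: float) -> float:
    """Fractional d* with period(d*) == boundary, interpolated in log-period space."""
    log_p0, log_p1 = math.log(periods[0]), math.log(periods[-1])
    return (len(periods) - 1) * (math.log(boundary) - log_p0) / (log_p1 - log_p0)


def frequency_aware_periods(
    periods: list[float],
    axis_len: int,
    ref_axis_len: int,
    boundary_log_multiplier: float,
) -> tuple[list[float], list[float]]:
    """Return (warped periods, alpha ramp) for one axis."""
    scale = axis_len / ref_axis_len  # s = L / L_ref
    boundary = ref_axis_len * math.exp(boundary_log_multiplier)  # L_boundary
    d_star = boundary_band_index(periods, boundary)
    alpha = [min(max(d / d_star, 0.0), 1.0) for d in range(len(periods))]
    warped = [p * scale**a for p, a in zip(periods, alpha)]
    return warped, alpha


periods = geometric_periods(1.0, 256.0, 5)  # roughly [1, 4, 16, 64, 256]
warped, alpha = frequency_aware_periods(
    periods, axis_len=32, ref_axis_len=16, boundary_log_multiplier=0.0
)
print([round(a, 3) for a in alpha])
print([round(p, 3) for p in warped])
```

With these toy numbers the boundary wavelength is 16 tokens, so `d*` lands at about 2: the highest-frequency band is left untouched, the next band is stretched by roughly `sqrt(2)`, and every band at or beyond the boundary is stretched by the full scale `s = 2`. The sketch omits the validation the real helpers perform (for example, rejecting `d* <= 0`).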
+ @dataclass(frozen=True)
+ class AxialRoPE2DBetaWarpConfig:
+     """Trainable beta-curve warping for axial 2D RoPE periods (per token grid).
+
+     This config defines a per-axis period warp that depends on the runtime axis
+     length ``L`` relative to a reference length ``L_ref``:
+
+         s = L / L_ref
+         period'[d] = period[d] * s ** beta[d]
+
+     where the per-band exponent curve beta(d) is parameterized by three
+     trainable u-space scalars (shared across H/W axes):
+
+         beta_hi = beta_max * tanh(beta_hi_u)      (high-frequency endpoint, d=0)
+         beta_lo = beta_max * tanh(beta_lo_u)      (low-frequency endpoint, d=qtr-1)
+         beta_bend = beta_max * tanh(beta_bend_u)  (mid-band bump amplitude)
+
+     and the per-band curve is:
+
+         t = d / (qtr - 1) in [0, 1]
+         beta(t) = lerp(beta_hi, beta_lo, t) + beta_bend * 4*t*(1-t)
+
+     Interpretation
+     --------------
+     - ``beta(d) == 0``: identity / "extrapolation-like" (no warping; periods do not
+       change with axis length).
+     - ``beta(d) == 1``: position-interpolation-like for that band
+       (``period'[d] = period[d] * s`` so ``omega'[d] = omega[d] / s``).
+
+     This parameterization provides strong and smooth control over the effective
+     scaling of each frequency band, including allowing beta<0 (increasing
+     frequencies when s>1), which can be important for unnormalized RoPE bases
+     (e.g. base=10_000) where some very low-frequency bands barely rotate on
+     practical token grids.
+
+     Notes:
+     - This warp requires patch-index coordinates (coord_mode=PATCH_INDICES).
+     - The u parameters are stored as nn.Parameter inside AxialRoPE2D when this
+       config is present.
+     """
+
+     ref_h_tokens: int
+     ref_w_tokens: int
+     beta_max: float
+     beta_hi_u_init: float
+     beta_lo_u_init: float
+     beta_bend_u_init: float
+
+     def __post_init__(self) -> None:
+         if int(self.ref_h_tokens) <= 0 or int(self.ref_w_tokens) <= 0:
+             raise ValueError("ref_h_tokens and ref_w_tokens must be positive")
+         bmax = float(self.beta_max)
+         if not math.isfinite(bmax) or bmax <= 0.0:
+             raise ValueError("beta_max must be finite and > 0")
+         for name, value in (
+             ("beta_hi_u_init", self.beta_hi_u_init),
+             ("beta_lo_u_init", self.beta_lo_u_init),
+             ("beta_bend_u_init", self.beta_bend_u_init),
+         ):
+             v = float(value)
+             if not math.isfinite(v):
+                 raise ValueError(f"{name} must be finite")
+
+
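The beta curve in the `AxialRoPE2DBetaWarpConfig` docstring is easy to tabulate for a handful of bands. The following plain-Python sketch follows the docstring formulas only. The function names and the toy periods are hypothetical, and it assumes at least two bands so that `t = d / (qtr - 1)` is defined.

```python
import math


def beta_curve(
    beta_max: float, beta_hi_u: float, beta_lo_u: float, beta_bend_u: float, qtr: int
) -> list[float]:
    """Per-band exponents beta[d] for d in [0, qtr), following the docstring."""
    beta_hi = beta_max * math.tanh(beta_hi_u)
    beta_lo = beta_max * math.tanh(beta_lo_u)
    beta_bend = beta_max * math.tanh(beta_bend_u)
    betas = []
    for d in range(qtr):
        t = d / (qtr - 1)
        linear = beta_hi + (beta_lo - beta_hi) * t  # lerp(beta_hi, beta_lo, t)
        betas.append(linear + beta_bend * 4.0 * t * (1.0 - t))
    return betas


def warp_periods(
    periods: list[float], axis_len: int, ref_axis_len: int, betas: list[float]
) -> list[float]:
    """period'[d] = period[d] * s ** beta[d], with s = axis_len / ref_axis_len."""
    scale = axis_len / ref_axis_len
    return [p * scale**b for p, b in zip(periods, betas)]


# Endpoints chosen so beta runs from 0 (identity) to 1 (position interpolation).
lo_u = math.atanh(1.0 / 1.5)
betas = beta_curve(1.5, 0.0, lo_u, 0.0, qtr=5)
bumped = beta_curve(1.5, 0.0, lo_u, 0.5, qtr=5)
warped = warp_periods([1.0, 4.0, 16.0, 64.0, 256.0], 32, 16, betas)
print([round(b, 3) for b in betas])
print([round(b, 3) for b in bumped])
print([round(p, 3) for p in warped])
```

Because the bump term `4*t*(1-t)` vanishes at `t = 0` and `t = 1`, `beta_bend_u` reshapes only the middle of the curve and leaves both endpoints where `beta_hi_u` and `beta_lo_u` put them.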
+ @dataclass(frozen=True)
+ class AxialRoPE2DAlphaWarpConfig:
+     """Per-band power-law warping of axial 2D RoPE frequencies (shared across axes).
+
+     This config warps RoPE frequencies per band using a learned exponent vector
+     ``alpha[d]`` shared across H/W axes:
+
+         f'[d] = f[d] * s ** alpha[d]    where s = L / L_ref
+
+     Since this module parameterizes angles via periods ``period[d]`` with
+     ``f[d] ∝ 1 / period[d]``, the equivalent period warp implemented in AxialRoPE2D is:
+
+         period'[d] = period[d] / s ** alpha[d]
+
+     Notes:
+     - This warp requires patch-index coordinates (coord_mode=PATCH_INDICES).
+     - ``alpha`` is stored as an unconstrained nn.Parameter vector of length Q
+       (bands per axis), initialized to ``alpha_init`` for all bands.
+     """
+
+     ref_h_tokens: int
+     ref_w_tokens: int
+     alpha_init: float
+
+     def __post_init__(self) -> None:
+         if int(self.ref_h_tokens) <= 0 or int(self.ref_w_tokens) <= 0:
+             raise ValueError("ref_h_tokens and ref_w_tokens must be positive")
+         init = float(self.alpha_init)
+         if not math.isfinite(init):
+             raise ValueError("alpha_init must be finite")
+
+
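Note that the alpha warp is stated on frequencies, so its sign convention is the opposite of the beta warp's: going by the two docstrings, `alpha = -1` corresponds to `beta = 1` (position interpolation). A small plain-Python check of the period form, with a hypothetical helper name and toy periods:

```python
def alpha_warp_periods(
    periods: list[float], axis_len: int, ref_axis_len: int, alpha: list[float]
) -> list[float]:
    """period'[d] = period[d] / s ** alpha[d], with s = axis_len / ref_axis_len."""
    scale = axis_len / ref_axis_len
    return [p / scale**a for p, a in zip(periods, alpha)]


periods = [2.0, 8.0, 32.0]
identity = alpha_warp_periods(periods, 64, 32, [0.0, 0.0, 0.0])
interpolated = alpha_warp_periods(periods, 64, 32, [-1.0, -1.0, -1.0])
sharpened = alpha_warp_periods(periods, 64, 32, [1.0, 1.0, 1.0])
print(identity, interpolated, sharpened)
```

Here `alpha = 0` leaves the periods unchanged, `alpha = -1` doubles them when the axis doubles, and `alpha = 1` halves them, which raises every frequency by the scale factor.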
+ @dataclass(frozen=True)
+ class AxialRoPE2DConfig:
+     """Configuration for axial 2D RoPE sin/cos generation.
+
+     This module supports two coordinate conventions via ``coord_mode``:
+     - ``DINOV3_NORMALIZED``: DINOv3-style normalized patch-centre coordinates in
+       ``[-1, 1]`` (after normalization).
+     - ``PATCH_INDICES``: Standard unnormalized patch-grid coordinates in patch
+       units (e.g., ``x in [0, W-1]``).
+
+     Period parametrization
+     ----------------------
+     The periods parametrization matches DINOv3:
+     - Provide either `base` (and leave `min_period/max_period` unset), or
+     - Provide both `min_period` and `max_period` (and set `base=None`).
+     """
+
+     base: float | None = 100.0
+     min_period: float | None = None
+     max_period: float | None = None
+     coord_mode: AxialRoPE2DCoordMode = AxialRoPE2DCoordMode.DINOV3_NORMALIZED
+     normalize_coords: AxialRoPE2DNormalizeCoords = AxialRoPE2DNormalizeCoords.MAX
+     dim_layout: AxialRoPE2DDimLayout = AxialRoPE2DDimLayout.HALF_SPLIT
+     angle_multiplier: float = 2.0 * float(math.pi)
+     coord_offset: float = 0.5
+     frequency_aware: AxialRoPE2DFrequencyAwareConfig | None = None
+     beta_warp: AxialRoPE2DBetaWarpConfig | None = None
+     alpha_warp: AxialRoPE2DAlphaWarpConfig | None = None
+
+     def __post_init__(self) -> None:
+         both_periods = self.min_period is not None and self.max_period is not None
+         if (self.base is None and not both_periods) or (
+             self.base is not None and both_periods
+         ):
+             raise ValueError(
+                 "AxialRoPE2DConfig requires either base!=None, or both min_period and max_period."
+             )
+         if self.base is not None and float(self.base) <= 0.0:
+             raise ValueError("AxialRoPE2DConfig.base must be positive when provided")
+         if self.min_period is not None and float(self.min_period) <= 0.0:
+             raise ValueError(
+                 "AxialRoPE2DConfig.min_period must be positive when provided"
+             )
+         if self.max_period is not None and float(self.max_period) <= 0.0:
+             raise ValueError(
+                 "AxialRoPE2DConfig.max_period must be positive when provided"
+             )
+         if self.min_period is not None and self.max_period is not None:
+             if float(self.max_period) <= float(self.min_period):
+                 raise ValueError("AxialRoPE2DConfig.max_period must be > min_period")
+         if not isinstance(self.coord_mode, AxialRoPE2DCoordMode):
+             raise TypeError(
+                 "AxialRoPE2DConfig.coord_mode must be an AxialRoPE2DCoordMode"
+             )
+         if not isinstance(self.normalize_coords, AxialRoPE2DNormalizeCoords):
+             raise TypeError(
+                 "AxialRoPE2DConfig.normalize_coords must be an AxialRoPE2DNormalizeCoords"
+             )
+         if not isinstance(self.dim_layout, AxialRoPE2DDimLayout):
+             raise TypeError(
+                 "AxialRoPE2DConfig.dim_layout must be an AxialRoPE2DDimLayout"
+             )
+         mult = float(self.angle_multiplier)
+         if not math.isfinite(mult) or mult <= 0.0:
+             raise ValueError(
+                 "AxialRoPE2DConfig.angle_multiplier must be finite and > 0"
+             )
+         off = float(self.coord_offset)
+         if not math.isfinite(off):
+             raise ValueError("AxialRoPE2DConfig.coord_offset must be finite")
+         if self.frequency_aware is not None and not isinstance(
+             self.frequency_aware, AxialRoPE2DFrequencyAwareConfig
+         ):
+             raise TypeError(
+                 "AxialRoPE2DConfig.frequency_aware must be an AxialRoPE2DFrequencyAwareConfig"
+             )
+         if self.beta_warp is not None and not isinstance(
+             self.beta_warp, AxialRoPE2DBetaWarpConfig
+         ):
+             raise TypeError(
+                 "AxialRoPE2DConfig.beta_warp must be an AxialRoPE2DBetaWarpConfig"
+             )
+         if self.alpha_warp is not None and not isinstance(
+             self.alpha_warp, AxialRoPE2DAlphaWarpConfig
+         ):
+             raise TypeError(
+                 "AxialRoPE2DConfig.alpha_warp must be an AxialRoPE2DAlphaWarpConfig"
+             )
+         warp_count = (
+             int(self.frequency_aware is not None)
+             + int(self.beta_warp is not None)
+             + int(self.alpha_warp is not None)
+         )
+         if warp_count > 1:
+             raise ValueError(
+                 "AxialRoPE2DConfig requires at most one of frequency_aware, beta_warp, or alpha_warp"
+             )
+         if self.frequency_aware is not None and (
+             self.coord_mode is not AxialRoPE2DCoordMode.PATCH_INDICES
+         ):
+             raise ValueError(
+                 "AxialRoPE2D frequency-aware warping requires coord_mode=PATCH_INDICES"
+             )
+         if self.beta_warp is not None and (
+             self.coord_mode is not AxialRoPE2DCoordMode.PATCH_INDICES
+         ):
+             raise ValueError("AxialRoPE2D beta warp requires coord_mode=PATCH_INDICES")
+         if self.alpha_warp is not None and (
+             self.coord_mode is not AxialRoPE2DCoordMode.PATCH_INDICES
+         ):
+             raise ValueError("AxialRoPE2D alpha warp requires coord_mode=PATCH_INDICES")
+
+
+ _AXIAL_COORDS_CACHE: dict[
+     tuple[
+         int, int, torch.device, AxialRoPE2DCoordMode, AxialRoPE2DNormalizeCoords, float
+     ],
+     Tensor,
+ ] = {}
+
+
+ def _get_dinov3_normalized_coords(
+     H: int,
+     W: int,
+     *,
+     device: torch.device,
+     normalize: AxialRoPE2DNormalizeCoords,
+     offset: float,
+ ) -> Tensor:
+     """Return DINOv3-style flattened coords in [-1, 1] with shape [HW, 2]."""
+     if H <= 0 or W <= 0:
+         raise ValueError("H and W must be positive for axial RoPE coords")
+     key = (
+         int(H),
+         int(W),
+         device,
+         AxialRoPE2DCoordMode.DINOV3_NORMALIZED,
+         normalize,
+         float(offset),
+     )
+     cached = _AXIAL_COORDS_CACHE.get(key)
+     if cached is not None:
+         return cached
+     start = float(offset)
+     end_h = start + float(int(H))
+     end_w = start + float(int(W))
+     match normalize:
+         case AxialRoPE2DNormalizeCoords.MAX:
+             denom = float(max(int(H), int(W)))
+             coords_h = (
+                 torch.arange(start, end_h, device=device, dtype=torch.float32) / denom
+             )
+             coords_w = (
+                 torch.arange(start, end_w, device=device, dtype=torch.float32) / denom
+             )
+         case AxialRoPE2DNormalizeCoords.MIN:
+             denom = float(min(int(H), int(W)))
+             coords_h = (
+                 torch.arange(start, end_h, device=device, dtype=torch.float32) / denom
+             )
+             coords_w = (
+                 torch.arange(start, end_w, device=device, dtype=torch.float32) / denom
+             )
+         case AxialRoPE2DNormalizeCoords.SEPARATE:
+             coords_h = torch.arange(
+                 start, end_h, device=device, dtype=torch.float32
+             ) / float(int(H))
+             coords_w = torch.arange(
+                 start, end_w, device=device, dtype=torch.float32
+             ) / float(int(W))
+         case _ as unreachable:  # pragma: no cover - defensive
+             raise RuntimeError(f"Unsupported normalize_coords: {unreachable}")
+     coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)
+     coords = coords.flatten(0, 1)
+     coords = 2.0 * coords - 1.0
+     # torch.compile cannot trace `torch.is_inference_mode_enabled()` and should
+     # not record Python-side cache mutations in the graph.
+     if torch.compiler.is_compiling():
+         return coords
+     if torch.is_inference_mode_enabled():
+         return coords
+     _AXIAL_COORDS_CACHE[key] = coords
+     return coords
+
+
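To make the three normalization modes of `_get_dinov3_normalized_coords` concrete, here is a plain-Python sketch of the same arithmetic on a small grid. It skips the cache and the tensors, and the function names are hypothetical.

```python
def normalized_axis_coords(length: int, denom: float, offset: float = 0.5) -> list[float]:
    """Patch-centre coordinates (offset + i) / denom, then mapped by 2 * c - 1."""
    return [2.0 * ((offset + i) / denom) - 1.0 for i in range(length)]


def normalized_grid(
    H: int, W: int, mode: str = "max", offset: float = 0.5
) -> list[tuple[float, float]]:
    """Flattened (y, x) pairs in row-major order, like meshgrid(indexing='ij')."""
    if mode == "max":
        denom_h = denom_w = float(max(H, W))
    elif mode == "min":
        denom_h = denom_w = float(min(H, W))
    elif mode == "separate":
        denom_h, denom_w = float(H), float(W)
    else:
        raise ValueError(f"Unsupported mode: {mode}")
    ys = normalized_axis_coords(H, denom_h, offset)
    xs = normalized_axis_coords(W, denom_w, offset)
    return [(y, x) for y in ys for x in xs]


grid_max = normalized_grid(2, 4, mode="max")
grid_sep = normalized_grid(2, 4, mode="separate")
grid_min = normalized_grid(2, 4, mode="min")
print(grid_max[:4])
print(grid_sep[:4])
print(grid_min[:4])
```

For a 2x4 grid, `max` keeps the aspect ratio and leaves the short axis spanning only part of `[-1, 1]`, `separate` stretches each axis to the full range independently, and `min` divides by the shorter side, so in this sketch the longer axis runs past 1 (up to 2.5 here).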
+ def _get_patch_index_coords(
+     H: int,
+     W: int,
+     *,
+     device: torch.device,
+     offset: float,
+ ) -> Tensor:
+     """Return unnormalized patch-grid coords with shape [HW, 2] and (y, x) columns."""
+     if H <= 0 or W <= 0:
+         raise ValueError("H and W must be positive for axial RoPE coords")
+     key = (
+         int(H),
+         int(W),
+         device,
+         AxialRoPE2DCoordMode.PATCH_INDICES,
+         AxialRoPE2DNormalizeCoords.MAX,
+         float(offset),
+     )
+     cached = _AXIAL_COORDS_CACHE.get(key)
+     if cached is not None:
+         return cached
+     start = float(offset)
+     end_h = start + float(int(H))
+     end_w = start + float(int(W))
+     coords_h = torch.arange(start, end_h, device=device, dtype=torch.float32)
+     coords_w = torch.arange(start, end_w, device=device, dtype=torch.float32)
+     coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)
+     coords = coords.flatten(0, 1)
+     if torch.compiler.is_compiling():
+         return coords
+     if torch.is_inference_mode_enabled():
+         return coords
+     _AXIAL_COORDS_CACHE[key] = coords
+     return coords
+
+
+ def _lumina_boundary_band_index(
+     *,
+     periods: Tensor,
+     boundary_wavelength: Tensor,
+ ) -> Tensor:
+     """Return the fractional boundary band index d* for a given boundary wavelength.
+
+     This implements the Lumina/Next-DiT definition:
+
+         period(d*) = boundary_wavelength
+
+     We compute d* by linear interpolation in log-period space. For the supported
+     period parameterizations, periods are geometric and log(period) is linear in
+     band index.
+
+     Args:
+         periods: 1D float tensor of length Q containing monotonically increasing
+             periods in tokens.
+         boundary_wavelength: Scalar positive float tensor giving the desired
+             boundary wavelength in tokens.
+
+     Returns:
+         Scalar float32 tensor giving the (possibly fractional) boundary index d*.
+
+     Raises:
+         ValueError: If periods are invalid or the boundary is outside valid range
+             for a well-defined positive d*.
+     """
+     if periods.dim() != 1:
+         raise ValueError("periods must be 1D for boundary band index")
+     if int(periods.numel()) < 2:
+         raise ValueError("periods must have length >= 2 for boundary band index")
+     if boundary_wavelength.dim() != 0:
+         raise ValueError("boundary_wavelength must be a scalar tensor")
+     if not torch.isfinite(boundary_wavelength).item():
+         raise ValueError("boundary_wavelength must be finite")
+     if float(boundary_wavelength.item()) <= 0.0:
+         raise ValueError("boundary_wavelength must be > 0")
+
+     periods_f = periods.to(dtype=torch.float32)
+     if not torch.isfinite(periods_f).all().item():
+         raise ValueError("periods must be finite for boundary band index")
+     if float(periods_f[0].item()) <= 0.0:
+         raise ValueError("periods must be positive for boundary band index")
+     if not (periods_f[1:] > periods_f[:-1]).all().item():
+         raise ValueError("periods must be strictly increasing for boundary band index")
+
+     log_p0 = torch.log(periods_f[0])
+     log_p1 = torch.log(periods_f[-1])
+     denom = log_p1 - log_p0
+     if float(denom.item()) <= 0.0:
+         raise ValueError("Invalid periods range for boundary band index")
+     log_boundary = torch.log(boundary_wavelength.to(dtype=torch.float32))
+     q = int(periods_f.numel())
+     d_star = (float(q - 1) * (log_boundary - log_p0)) / denom
+     if not torch.isfinite(d_star).item():
+         raise ValueError("Computed non-finite boundary band index d*")
+     if float(d_star.item()) <= 0.0:
+         raise ValueError(
+             "Boundary wavelength implies d* <= 0; increase the boundary wavelength "
+             "(or its multiplier) to be >= the wavelength of the first non-zero band."
+         )
+     return d_star
+
+
580
+ def _lumina_alpha_ramp(
581
+ *,
582
+ qtr: int,
583
+ d_star: Tensor,
584
+ device: torch.device,
585
+ ) -> Tensor:
586
+ """Return alpha[d] = clamp(d / d*, 0, 1) for d in [0, qtr).
587
+
588
+ Args:
589
+ qtr: Number of RoPE bands per axis (Q).
590
+ d_star: Positive scalar float tensor holding the boundary index d*.
591
+ device: Device for the returned alpha tensor.
592
+
593
+ Returns:
594
+ Float32 tensor of shape [Q] with values in [0, 1].
595
+ """
596
+ if int(qtr) <= 0:
597
+ raise ValueError("qtr must be positive for alpha ramp")
598
+ if d_star.dim() != 0:
599
+ raise ValueError("d_star must be a scalar tensor for alpha ramp")
600
+ if float(d_star.item()) <= 0.0:
601
+ raise ValueError("d_star must be > 0 for alpha ramp")
602
+ d = torch.arange(int(qtr), device=device, dtype=torch.float32)
603
+ alpha = d / d_star.to(device=device, dtype=torch.float32)
604
+ return torch.clamp(alpha, min=0.0, max=1.0)
605
+
606
+
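The ramp itself is a clamped linear function of the band index. A minimal sketch in plain Python, where `alpha_ramp` is a hypothetical name that mirrors the tensor version using lists:

```python
def alpha_ramp(qtr: int, d_star: float) -> list[float]:
    """alpha[d] = clamp(d / d_star, 0, 1) for d = 0 .. qtr-1."""
    if qtr <= 0:
        raise ValueError("qtr must be positive")
    if d_star <= 0.0:
        raise ValueError("d_star must be > 0")
    return [min(1.0, max(0.0, d / d_star)) for d in range(qtr)]


# Bands below d_star get a fractional exponent; bands at or above it saturate at 1.
alpha = alpha_ramp(6, 2.5)
```

Band 0 always gets exponent 0, so the highest-frequency band is never rescaled.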
607
+ def lumina_frequency_aware_periods_for_axis(
608
+ *,
609
+ periods: Tensor,
610
+ axis_len: int,
611
+ ref_axis_len: int,
612
+ boundary_log_multiplier: Tensor,
613
+ angle_multiplier: float,
614
+ ) -> Tensor:
615
+ """Return Lumina/Next-DiT frequency-aware warped periods for one axis.
616
+
617
+ Implements:
618
+ s = axis_len / ref_axis_len
619
+ L_boundary = ref_axis_len * exp(boundary_log_multiplier)
620
+ d* = boundary band index where period(d*) = L_boundary
621
+ alpha[d] = clamp(d / d*, 0, 1)
622
+ period'[d] = period[d] * s**alpha[d]
623
+
624
+ Notes on ``angle_multiplier``
625
+ -----------------------------
626
+ This module parameterizes angles as:
627
+
628
+ angle(p, d) = angle_multiplier * p / period[d]
629
+
630
+ The *wavelength* (period in tokens) is the delta in ``p`` that increases
631
+ the angle by ``2π``:
632
+
633
+ wavelength[d] = 2π * period[d] / angle_multiplier
634
+
635
+ Lumina/Next-DiT define the boundary by matching *wavelength* to the
636
+ reference axis length. We therefore convert the boundary wavelength
637
+ ``L_boundary`` into a boundary period via:
638
+
639
+ period_boundary = (angle_multiplier / 2π) * L_boundary
640
+
641
+ When ``angle_multiplier == 2π`` (the DINOv3-style parameterization), this
642
+ reduces to ``period_boundary == L_boundary``.
643
+
644
+ Args:
645
+ periods: Base periods ``[Q]`` (wavelengths in tokens when ``angle_multiplier == 2π``).
646
+ axis_len: Input axis length ``L`` in tokens.
647
+ ref_axis_len: Reference axis length ``L_ref`` in tokens.
648
+ boundary_log_multiplier: Scalar tensor; shared trainable log-multiplier.
649
+ angle_multiplier: RoPE angle multiplier used when converting periods to
650
+ physical wavelengths in tokens.
651
+
652
+ Returns:
653
+ Warped periods ``[Q]`` as float32.
654
+
655
+ Raises:
656
+ ValueError: If inputs are malformed or imply an invalid boundary index.
657
+ """
658
+ if int(axis_len) <= 0:
659
+ raise ValueError("axis_len must be positive for frequency-aware periods")
660
+ if int(ref_axis_len) <= 0:
661
+ raise ValueError("ref_axis_len must be positive for frequency-aware periods")
662
+ if boundary_log_multiplier.dim() != 0:
663
+ raise ValueError("boundary_log_multiplier must be a scalar tensor")
664
+ if not torch.isfinite(boundary_log_multiplier).item():
665
+ raise ValueError("boundary_log_multiplier must be finite")
666
+ mult = float(angle_multiplier)
667
+ if not math.isfinite(mult) or mult <= 0.0:
668
+ raise ValueError("angle_multiplier must be finite and > 0")
669
+
670
+ device = periods.device
671
+ qtr = int(periods.numel())
672
+ s = float(int(axis_len)) / float(int(ref_axis_len))
673
+ if not math.isfinite(s) or s <= 0.0:
674
+ raise ValueError("axis_len/ref_axis_len must be finite and > 0")
675
+ boundary_wavelength = float(int(ref_axis_len)) * torch.exp(
676
+ boundary_log_multiplier.to(device=device, dtype=torch.float32)
677
+ )
678
+ boundary_period = (mult / (2.0 * float(math.pi))) * boundary_wavelength
679
+ d_star = _lumina_boundary_band_index(
680
+ periods=periods, boundary_wavelength=boundary_period
681
+ )
682
+ alpha = _lumina_alpha_ramp(qtr=qtr, d_star=d_star, device=device)
683
+ scale = torch.pow(torch.tensor(s, device=device, dtype=torch.float32), alpha)
684
+ return periods.to(device=device, dtype=torch.float32) * scale
685
+
686
+
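The whole per-axis warp, including the wavelength-to-period conversion described in the docstring, can be sketched as below. This is a sketch under stated assumptions: `warped_periods` is a hypothetical list-based stand-in, and the sample periods and lengths are illustrative values only.

```python
import math


def warped_periods(
    periods: list[float],
    axis_len: int,
    ref_axis_len: int,
    boundary_log_multiplier: float,
    angle_multiplier: float,
) -> list[float]:
    """period'[d] = period[d] * s ** clamp(d / d*, 0, 1), with s = axis_len / ref_axis_len."""
    s = axis_len / ref_axis_len
    # The boundary is defined as a wavelength in tokens ...
    boundary_wavelength = ref_axis_len * math.exp(boundary_log_multiplier)
    # ... and converted to the module's period units before locating d*.
    boundary_period = (angle_multiplier / (2.0 * math.pi)) * boundary_wavelength
    log_p0, log_p1 = math.log(periods[0]), math.log(periods[-1])
    d_star = (len(periods) - 1) * (math.log(boundary_period) - log_p0) / (log_p1 - log_p0)
    if d_star <= 0.0:
        raise ValueError("boundary implies d* <= 0")
    return [p * s ** min(1.0, max(0.0, d / d_star)) for d, p in enumerate(periods)]


periods = [2.0, 8.0, 32.0, 128.0]
out = warped_periods(
    periods,
    axis_len=64,
    ref_axis_len=32,
    boundary_log_multiplier=0.0,      # boundary wavelength equals ref_axis_len
    angle_multiplier=2.0 * math.pi,   # periods are wavelengths in this parameterization
)
```

With these inputs the boundary falls on band 2, so band 0 is unchanged, band 1 is scaled by the square root of s, and bands 2 and 3 are scaled by the full factor s = 2.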
687
+ def build_axial_rope2d_with_lumina_frequency_warp(
688
+ base: AxialRoPE2D,
689
+ *,
690
+ ref_h_tokens: int,
691
+ ref_w_tokens: int,
692
+ boundary_log_multiplier: float | None,
693
+ boundary_band_multiplier: float | None,
694
+ ) -> AxialRoPE2D:
695
+ """Return an AxialRoPE2D module that applies Lumina-style frequency warping.
696
+
697
+ This helper is intended for inference-time experimentation on checkpoints
698
+ that were trained without frequency-aware warping (e.g.
699
+ ``position_encoding=ROPE_2D_AXIAL_UNNORMALIZED``). It constructs a new
700
+ AxialRoPE2D instance that:
701
+ - Keeps the base RoPE periods and layout identical to ``base``.
702
+ - Applies Lumina/Next-DiT per-axis warping based on the runtime token
703
+ lengths ``H`` and ``W`` relative to reference lengths.
704
+ - Uses a fixed (non-trainable) scalar boundary multiplier for inference.
705
+
706
+ Args:
707
+ base: Existing AxialRoPE2D instance from a loaded model.
708
+ ref_h_tokens: Reference H token length (L_ref,h).
709
+ ref_w_tokens: Reference W token length (L_ref,w).
710
+ boundary_log_multiplier: Optional log multiplier applied to reference
711
+ lengths to define the boundary wavelength. Use 0.0 for "boundary at
712
+ L_ref". Mutually exclusive with boundary_band_multiplier.
713
+ boundary_band_multiplier: Optional multiplier that directly selects the
714
+ boundary band index d* relative to the lowest-frequency band index
715
+ (qtr-1). Concretely, with qtr bands per axis:
716
+
717
+ d* = boundary_band_multiplier * (qtr - 1)
718
+
719
+ This lets you move the transition point in frequency space:
720
+ - smaller values => more bands become PI-like (more interpolation)
721
+ - larger values => fewer bands become PI-like (more extrapolation)
722
+
723
+ When provided, we compute the implied boundary wavelength and store
724
+ it as boundary_log_multiplier for the module.
725
+
726
+ Returns:
727
+ New AxialRoPE2D instance on the same device as ``base``.
728
+
729
+ Raises:
730
+ TypeError: If base is not an AxialRoPE2D.
731
+ ValueError: If base uses incompatible coordinates for Lumina warping.
732
+ """
733
+ if not isinstance(base, AxialRoPE2D):
734
+ raise TypeError("base must be an AxialRoPE2D")
735
+ if base.cfg.coord_mode is not AxialRoPE2DCoordMode.PATCH_INDICES:
736
+ raise ValueError(
737
+ "Lumina frequency-aware warping requires coord_mode=PATCH_INDICES"
738
+ )
739
+ if (boundary_log_multiplier is None) == (boundary_band_multiplier is None):
740
+ raise ValueError(
741
+ "Provide exactly one of boundary_log_multiplier or boundary_band_multiplier"
742
+ )
743
+
744
+ resolved_log_multiplier: float
745
+ if boundary_band_multiplier is not None:
746
+ if int(ref_h_tokens) != int(ref_w_tokens):
747
+ raise ValueError(
748
+ "boundary_band_multiplier requires ref_h_tokens == ref_w_tokens when using a shared scalar boundary"
749
+ )
750
+ mult = float(boundary_band_multiplier)
751
+ if not math.isfinite(mult) or mult <= 0.0:
752
+ raise ValueError("boundary_band_multiplier must be finite and > 0")
753
+ qtr = int(base.periods.numel())
754
+ if qtr < 2:
755
+ raise ValueError(
756
+ "AxialRoPE2D periods length must be >= 2 for boundary band selection"
757
+ )
758
+ # Solve for the boundary wavelength implied by choosing d* directly.
759
+ #
760
+ # We use geometric interpolation in log-period space:
761
+ # log(period(d*)) = log(period0) + (d*/(qtr-1)) * (log(period_max) - log(period0))
762
+ # with:
763
+ # d* = boundary_band_multiplier * (qtr-1)
764
+ #
765
+ # This allows d* outside the trained band range (multiplier > 1), which
766
+ # corresponds to pushing the transition beyond the lowest-frequency band.
767
+ with torch.no_grad():
768
+ periods_f = base.periods.to(dtype=torch.float32, device=torch.device("cpu"))
769
+ if not (periods_f[1:] > periods_f[:-1]).all().item():
770
+ raise ValueError(
771
+ "base.periods must be strictly increasing for boundary band selection"
772
+ )
773
+ log_p0 = float(torch.log(periods_f[0]).item())
774
+ log_p1 = float(torch.log(periods_f[-1]).item())
775
+ d_star = mult * float(qtr - 1)
776
+ log_boundary_period = log_p0 + (d_star / float(qtr - 1)) * (log_p1 - log_p0)
777
+ boundary_period = math.exp(log_boundary_period)
778
+ angle_mult = float(base.cfg.angle_multiplier)
779
+ if not math.isfinite(angle_mult) or angle_mult <= 0.0:
780
+ raise ValueError("base.cfg.angle_multiplier must be finite and > 0")
781
+ boundary_wavelength = (2.0 * float(math.pi) / angle_mult) * boundary_period
782
+ resolved_log_multiplier = math.log(
783
+ boundary_wavelength / float(int(ref_h_tokens))
784
+ )
785
+ else:
786
+ if boundary_log_multiplier is None: # pragma: no cover - validated above
787
+ raise RuntimeError("boundary_log_multiplier missing despite validation")
788
+ resolved_log_multiplier = float(boundary_log_multiplier)
789
+
790
+ freq_cfg = AxialRoPE2DFrequencyAwareConfig(
791
+ ref_h_tokens=int(ref_h_tokens),
792
+ ref_w_tokens=int(ref_w_tokens),
793
+ boundary_log_multiplier_init=resolved_log_multiplier,
794
+ )
795
+ cfg = replace(base.cfg, frequency_aware=freq_cfg, beta_warp=None, alpha_warp=None)
796
+ device = base.periods.device
797
+ warped = AxialRoPE2D(head_dim=int(base.head_dim), cfg=cfg).to(device=device)
798
+ with torch.no_grad():
799
+ warped.periods.copy_(base.periods.to(device=device, dtype=torch.float32))
800
+ if warped.boundary_log_multiplier is None: # pragma: no cover - defensive
801
+ raise RuntimeError("Expected boundary_log_multiplier to be initialized")
802
+ warped.boundary_log_multiplier.copy_(
803
+ torch.tensor(resolved_log_multiplier, device=device, dtype=torch.float32)
804
+ )
805
+ warped.boundary_log_multiplier.requires_grad_(False)
806
+ return warped
807
+
808
+
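Resolving `boundary_band_multiplier` into a stored log multiplier, and the inverse mapping used at forward time, can be sketched as a round trip. The two function names below are hypothetical; the arithmetic follows the comments in the builder.

```python
import math


def resolve_log_multiplier(periods, band_multiplier, ref_tokens, angle_multiplier):
    """Log multiplier whose boundary lands at d* = band_multiplier * (Q - 1)."""
    log_p0, log_p1 = math.log(periods[0]), math.log(periods[-1])
    # d* / (Q - 1) equals band_multiplier, so interpolate directly in log-period space.
    boundary_period = math.exp(log_p0 + band_multiplier * (log_p1 - log_p0))
    boundary_wavelength = (2.0 * math.pi / angle_multiplier) * boundary_period
    return math.log(boundary_wavelength / ref_tokens)


def implied_d_star(periods, log_multiplier, ref_tokens, angle_multiplier):
    """Inverse direction: recover d* from a stored log multiplier."""
    boundary_period = (angle_multiplier / (2.0 * math.pi)) * ref_tokens * math.exp(log_multiplier)
    log_p0, log_p1 = math.log(periods[0]), math.log(periods[-1])
    return (len(periods) - 1) * (math.log(boundary_period) - log_p0) / (log_p1 - log_p0)


periods = [1.0, 4.0, 16.0, 64.0, 256.0]
log_mult = resolve_log_multiplier(periods, 0.5, ref_tokens=32, angle_multiplier=1.0)
d_star = implied_d_star(periods, log_mult, ref_tokens=32, angle_multiplier=1.0)
```

The angle multiplier and reference length cancel in the round trip, which is why a single scalar can carry the chosen boundary when both reference lengths are equal.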
809
+ def build_axial_rope2d_inference_warp_with_strength(
810
+ base: AxialRoPE2D,
811
+ *,
812
+ ref_h_tokens: int,
813
+ ref_w_tokens: int,
814
+ beta_hi_u: float,
815
+ beta_lo_u: float,
816
+ beta_bend_u: float,
817
+ beta_max: float,
818
+ ) -> AxialRoPE2D:
819
+ """Build an inference-only RoPE warp parameterized by a 3-knob beta(t) curve.
820
+
821
+ This helper is meant for notebook experimentation on checkpoints trained
822
+ with patch-index axial RoPE (e.g. ``position_encoding=ROPE_2D_AXIAL_UNNORMALIZED``).
823
+
824
+ We warp per-axis RoPE periods (wavelengths, in tokens) as:
825
+
826
+ period'[d] = period[d] * s ** beta[d] where s = L / L_ref
827
+
828
+ with a smooth exponent curve beta(d) over bands. Unlike a strict
829
+ interpolation-only exponent (0..1), beta is allowed to be negative or > 1,
830
+ which is important for unnormalized RoPE (e.g. base=10_000) where some very
831
+ low-frequency bands are effectively "dead" on practical token grids unless
832
+ their frequencies can be increased (beta < 0).
833
+
834
+ Knobs (bounded via u-space)
835
+ ---------------------------
836
+ We use three unconstrained parameters (u-space) which map to bounded beta
837
+ values via tanh:
838
+
839
+ beta_hi = beta_max * tanh(beta_hi_u) (high-frequency endpoint, d=0)
840
+ beta_lo = beta_max * tanh(beta_lo_u) (low-frequency endpoint, d=qtr-1)
841
+ beta_bend = beta_max * tanh(beta_bend_u) ("bump" amplitude in the middle)
842
+
843
+ Then define the per-band curve over t in [0,1] (high -> low frequency):
844
+
845
+ t = d / (qtr - 1)
846
+ beta(t) = lerp(beta_hi, beta_lo, t) + beta_bend * 4*t*(1-t)
847
+
848
+ The bump shape 4*t*(1-t) is 0 at the endpoints and peaks at 1 at t=0.5.
849
+
850
+ Notes
851
+ -----
852
+ - This wrapper is inference-only: it is not saved in checkpoints.
853
+ - It requires patch-index coordinates (no normalized "gauge").
854
+ - It preserves the base module's periods and layout exactly.
855
+
856
+ Args:
857
+ base: Existing AxialRoPE2D instance from a loaded model.
858
+ ref_h_tokens: Reference H token length (L_ref,h).
859
+ ref_w_tokens: Reference W token length (L_ref,w).
860
+ beta_hi_u: Unconstrained u for beta_hi.
861
+ beta_lo_u: Unconstrained u for beta_lo.
862
+ beta_bend_u: Unconstrained u for beta_bend (mid-band bump).
863
+ beta_max: Bound (> 0) on |beta_hi|, |beta_lo|, and |beta_bend|; mid-band beta(t) can approach 2 * beta_max.
864
+
865
+ Returns:
866
+ An AxialRoPE2D instance whose forward applies the inference-only warp.
867
+ """
868
+ if not isinstance(base, AxialRoPE2D):
869
+ raise TypeError("base must be an AxialRoPE2D")
870
+ if base.cfg.coord_mode is not AxialRoPE2DCoordMode.PATCH_INDICES:
871
+ raise ValueError(
872
+ "Inference freq-warp requires base.cfg.coord_mode=PATCH_INDICES"
873
+ )
874
+ if int(ref_h_tokens) <= 0 or int(ref_w_tokens) <= 0:
875
+ raise ValueError("ref_h_tokens and ref_w_tokens must be positive")
876
+ hi_u = float(beta_hi_u)
877
+ lo_u = float(beta_lo_u)
878
+ bend_u = float(beta_bend_u)
879
+ if not math.isfinite(hi_u):
880
+ raise ValueError("beta_hi_u must be finite")
881
+ if not math.isfinite(lo_u):
882
+ raise ValueError("beta_lo_u must be finite")
883
+ if not math.isfinite(bend_u):
884
+ raise ValueError("beta_bend_u must be finite")
885
+ bmax = float(beta_max)
886
+ if not math.isfinite(bmax) or bmax <= 0.0:
887
+ raise ValueError("beta_max must be finite and > 0")
888
+
889
+ class _AxialRoPE2DInferenceWarp(AxialRoPE2D):
890
+ """Inference-only axial RoPE variant with beta-curve knobs."""
891
+
892
+ def __init__(self, *, device: torch.device) -> None:
893
+ super().__init__(head_dim=int(base.head_dim), cfg=base.cfg)
894
+ self.ref_h_tokens: Final[int] = int(ref_h_tokens)
895
+ self.ref_w_tokens: Final[int] = int(ref_w_tokens)
896
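+ # The parent __init__ registers these names as parameters (None when
+ # beta_warp is unset), and register_buffer raises KeyError for a name that
+ # already exists, so drop the parameter entries before re-registering.
+ for _name in ("beta_hi_u", "beta_lo_u", "beta_bend_u"): self._parameters.pop(_name, None)
+ # Store the knobs as non-persistent buffers; to change them, rebuild the module.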
897
+ self.register_buffer(
898
+ "beta_hi_u",
899
+ torch.tensor(float(hi_u), dtype=torch.float32),
900
+ persistent=False,
901
+ )
902
+ self.register_buffer(
903
+ "beta_lo_u",
904
+ torch.tensor(float(lo_u), dtype=torch.float32),
905
+ persistent=False,
906
+ )
907
+ self.register_buffer(
908
+ "beta_bend_u",
909
+ torch.tensor(float(bend_u), dtype=torch.float32),
910
+ persistent=False,
911
+ )
912
+ self.register_buffer(
913
+ "beta_max",
914
+ torch.tensor(float(bmax), dtype=torch.float32),
915
+ persistent=False,
916
+ )
917
+ self.to(device=device)
918
+ with torch.no_grad():
919
+ self.periods.copy_(
920
+ base.periods.detach().to(device=device, dtype=torch.float32)
921
+ )
922
+
923
+ def forward(
924
+ self,
925
+ *,
926
+ H: int,
927
+ W: int,
928
+ scales: Tensor | None,
929
+ ) -> tuple[Tensor, Tensor]:
930
+ if scales is not None:
931
+ raise ValueError("Inference freq-warp does not support dilation scales")
932
+ if int(H) <= 0 or int(W) <= 0:
933
+ raise ValueError("H and W must be positive for axial RoPE")
934
+ device = self.periods.device
935
+ offset = float(self.cfg.coord_offset)
936
+ coords = _get_patch_index_coords(
937
+ int(H), int(W), device=device, offset=offset
938
+ )
939
+ if coords.dim() != 2 or coords.shape[1] != 2:
940
+ raise RuntimeError("Axial RoPE coords must have shape [HW, 2]")
941
+
942
+ qtr = int(self.periods.numel())
943
+ if qtr <= 0:
944
+ raise RuntimeError("Axial RoPE periods length must be positive")
945
+
946
+ beta_max_t = cast("Tensor", self.beta_max).to(
947
+ device=device, dtype=torch.float32
948
+ )
949
+ beta_hi = beta_max_t * torch.tanh(
950
+ cast("Tensor", self.beta_hi_u).to(device=device, dtype=torch.float32)
951
+ )
952
+ beta_lo = beta_max_t * torch.tanh(
953
+ cast("Tensor", self.beta_lo_u).to(device=device, dtype=torch.float32)
954
+ )
955
+ beta_bend = beta_max_t * torch.tanh(
956
+ cast("Tensor", self.beta_bend_u).to(device=device, dtype=torch.float32)
957
+ )
958
+
959
+ if qtr == 1:
960
+ beta = beta_hi[None]
961
+ else:
962
+ t = torch.arange(int(qtr), device=device, dtype=torch.float32) / float(
963
+ qtr - 1
964
+ )
965
+ bump = 4.0 * t * (1.0 - t)
966
+ beta = (1.0 - t) * beta_hi + t * beta_lo + beta_bend * bump
967
+
968
+ s_h = float(int(H)) / float(int(self.ref_h_tokens))
969
+ s_w = float(int(W)) / float(int(self.ref_w_tokens))
970
+ if (
971
+ not math.isfinite(s_h)
972
+ or s_h <= 0.0
973
+ or not math.isfinite(s_w)
974
+ or s_w <= 0.0
975
+ ):
976
+ raise ValueError(
977
+ "H/ref_h_tokens and W/ref_w_tokens must be finite and > 0"
978
+ )
979
+
980
+ periods_h = self.periods * torch.pow(
981
+ torch.tensor(s_h, device=device, dtype=torch.float32), beta
982
+ )
983
+ periods_w = self.periods * torch.pow(
984
+ torch.tensor(s_w, device=device, dtype=torch.float32), beta
985
+ )
986
+ axis_periods = torch.stack([periods_h, periods_w], dim=0) # [2, Q]
987
+
988
+ angles = (
989
+ float(self.cfg.angle_multiplier)
990
+ * coords[:, :, None].to(dtype=torch.float32)
991
+ / axis_periods[None, :, :].to(dtype=torch.float32)
992
+ )
993
+ match self.cfg.dim_layout:
994
+ case AxialRoPE2DDimLayout.HALF_SPLIT:
995
+ angles = angles.flatten(1, 2).repeat(1, 2)
996
+ case AxialRoPE2DDimLayout.PAIR_INTERLEAVED:
997
+ angles = angles.repeat_interleave(2, dim=-1).flatten(1, 2)
998
+ case _ as unreachable: # pragma: no cover - defensive
999
+ raise RuntimeError(f"Unsupported dim_layout: {unreachable}")
1000
+ if angles.shape != (int(H) * int(W), int(self.head_dim)):
1001
+ raise RuntimeError(
1002
+ "Unexpected angles shape in inference freq-warp: "
1003
+ f"{tuple(angles.shape)} for H={int(H)} W={int(W)}"
1004
+ )
1005
+ return torch.sin(angles), torch.cos(angles)
1006
+
1007
+ return _AxialRoPE2DInferenceWarp(device=base.periods.device)
1008
+
1009
+
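The three-knob exponent curve can be sketched in plain Python as follows. `beta_curve` is a hypothetical helper that mirrors the formulas in the docstring; the knob values are arbitrary illustrations.

```python
import math


def beta_curve(
    qtr: int, beta_hi_u: float, beta_lo_u: float, beta_bend_u: float, beta_max: float
) -> list[float]:
    """beta(t) = lerp(beta_hi, beta_lo, t) + beta_bend * 4 t (1 - t), with t = d / (qtr - 1)."""
    # tanh maps each unconstrained knob into (-beta_max, beta_max).
    hi = beta_max * math.tanh(beta_hi_u)
    lo = beta_max * math.tanh(beta_lo_u)
    bend = beta_max * math.tanh(beta_bend_u)
    if qtr == 1:
        return [hi]
    out = []
    for d in range(qtr):
        t = d / (qtr - 1)
        out.append((1.0 - t) * hi + t * lo + bend * 4.0 * t * (1.0 - t))
    return out


beta = beta_curve(5, beta_hi_u=0.0, beta_lo_u=1.0, beta_bend_u=-0.5, beta_max=2.0)
```

The endpoints depend only on `beta_hi_u` and `beta_lo_u`; the bend knob affects interior bands and is strongest at the middle band.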
1010
+ class AxialRoPE2D(nn.Module):
1011
+ """DINOv3-style axial 2D RoPE sin/cos generator.
1012
+
1013
+ The base periods are fixed by ``AxialRoPE2DConfig``. Optionally, this module
1014
+ can include learnable parameters when using:
1015
+ - ``frequency_aware`` (boundary_log_multiplier), or
1016
+ - ``beta_warp`` (beta_hi_u/beta_lo_u/beta_bend_u), or
1017
+ - ``alpha_warp`` (alpha per-band exponents).
1018
+ """
1019
+
1020
+ periods: Tensor
1021
+
1022
+ def __init__(self, *, head_dim: int, cfg: AxialRoPE2DConfig) -> None:
1023
+ super().__init__()
1024
+ if int(head_dim) <= 0:
1025
+ raise ValueError("head_dim must be positive for AxialRoPE2D")
1026
+ if int(head_dim) % 4 != 0:
1027
+ raise ValueError(
1028
+ "AxialRoPE2D requires head_dim % 4 == 0 (DINOv3 constraint); "
1029
+ f"got head_dim={int(head_dim)}"
1030
+ )
1031
+ if not isinstance(cfg, AxialRoPE2DConfig):
1032
+ raise TypeError("cfg must be an AxialRoPE2DConfig for AxialRoPE2D")
1033
+ self.head_dim: Final[int] = int(head_dim)
1034
+ self.cfg: Final[AxialRoPE2DConfig] = cfg
1035
+ self._d_head: Final[int] = self.head_dim
1036
+ self.register_buffer(
1037
+ "periods",
1038
+ torch.empty(self._d_head // 4, dtype=torch.float32),
1039
+ persistent=True,
1040
+ )
1041
+ if cfg.frequency_aware is None:
1042
+ self.register_parameter("boundary_log_multiplier", None)
1043
+ else:
1044
+ init = float(cfg.frequency_aware.boundary_log_multiplier_init)
1045
+ self.boundary_log_multiplier = nn.Parameter(
1046
+ torch.tensor(init, dtype=torch.float32),
1047
+ requires_grad=True,
1048
+ )
1049
+ if cfg.beta_warp is None:
1050
+ self.register_parameter("beta_hi_u", None)
1051
+ self.register_parameter("beta_lo_u", None)
1052
+ self.register_parameter("beta_bend_u", None)
1053
+ else:
1054
+ beta = cfg.beta_warp
1055
+ self.beta_hi_u = nn.Parameter(
1056
+ torch.tensor(float(beta.beta_hi_u_init), dtype=torch.float32),
1057
+ requires_grad=True,
1058
+ )
1059
+ self.beta_lo_u = nn.Parameter(
1060
+ torch.tensor(float(beta.beta_lo_u_init), dtype=torch.float32),
1061
+ requires_grad=True,
1062
+ )
1063
+ self.beta_bend_u = nn.Parameter(
1064
+ torch.tensor(float(beta.beta_bend_u_init), dtype=torch.float32),
1065
+ requires_grad=True,
1066
+ )
1067
+ if cfg.alpha_warp is None:
1068
+ self.register_parameter("alpha", None)
1069
+ else:
1070
+ qtr = int(self._d_head) // 4
1071
+ if qtr <= 0: # pragma: no cover - defensive
1072
+ raise RuntimeError("AxialRoPE2D periods length must be positive")
1073
+ init = float(cfg.alpha_warp.alpha_init)
1074
+ if not math.isfinite(init):
1075
+ raise RuntimeError("alpha_init must be finite for alpha-warp RoPE")
1076
+ self.alpha = nn.Parameter(
1077
+ torch.full((int(qtr),), init, dtype=torch.float32),
1078
+ requires_grad=True,
1079
+ )
1080
+ self._init_periods()
1081
+
1082
+ def _apply(self, fn): # type: ignore[override]
1083
+ out = super()._apply(fn)
1084
+ with torch.no_grad():
1085
+ self.periods.data = self.periods.data.to(dtype=torch.float32)
1086
+ if self.boundary_log_multiplier is not None:
1087
+ self.boundary_log_multiplier.data = (
1088
+ self.boundary_log_multiplier.data.to(dtype=torch.float32)
1089
+ )
1090
+ if self.beta_hi_u is not None:
1091
+ self.beta_hi_u.data = self.beta_hi_u.data.to(dtype=torch.float32)
1092
+ if self.beta_lo_u is not None:
1093
+ self.beta_lo_u.data = self.beta_lo_u.data.to(dtype=torch.float32)
1094
+ if self.beta_bend_u is not None:
1095
+ self.beta_bend_u.data = self.beta_bend_u.data.to(dtype=torch.float32)
1096
+ if self.alpha is not None:
1097
+ self.alpha.data = self.alpha.data.to(dtype=torch.float32)
1098
+ return out
1099
+
1100
+ def _init_periods(self) -> None:
1101
+ """Initialize per-dimension periods using DINOv3 formulas."""
1102
+ device: torch.device = self.periods.device
1103
+ dtype: torch.dtype = self.periods.dtype
1104
+ d_head = int(self._d_head)
1105
+ qtr = d_head // 4
1106
+ if qtr <= 0:
1107
+ raise RuntimeError("AxialRoPE2D periods length must be positive")
1108
+ if self.cfg.base is not None:
1109
+ base = float(self.cfg.base)
1110
+ exponents = (
1111
+ 2.0
1112
+ * torch.arange(int(qtr), device=device, dtype=dtype)
1113
+ / float(d_head // 2)
1114
+ )
1115
+ periods = torch.tensor(base, device=device, dtype=dtype) ** exponents
1116
+ else:
1117
+ if self.cfg.min_period is None or self.cfg.max_period is None:
1118
+ raise RuntimeError(
1119
+ "AxialRoPE2DConfig must provide min_period and max_period when base is None"
1120
+ )
1121
+ min_p = float(self.cfg.min_period)
1122
+ max_p = float(self.cfg.max_period)
1123
+ base = max_p / min_p
1124
+ exponents = torch.linspace(0.0, 1.0, int(qtr), device=device, dtype=dtype)
1125
+ periods = torch.tensor(base, device=device, dtype=dtype) ** exponents
1126
+ periods = periods / torch.tensor(base, device=device, dtype=dtype)
1127
+ periods = periods * torch.tensor(max_p, device=device, dtype=dtype)
1128
+ self.periods.data = periods
1129
+
1130
+ def forward(
1131
+ self,
1132
+ *,
1133
+ H: int,
1134
+ W: int,
1135
+ scales: Tensor | None,
1136
+ ) -> tuple[Tensor, Tensor]:
1137
+ """Return (sin, cos) buffers for axial 2D RoPE.
1138
+
1139
+ Args:
1140
+ H: Patch-grid height.
1141
+ W: Patch-grid width.
1142
+ scales: Optional per-batch dilation scale (scalar tensor) that
1143
+ multiplies every angle. The result is shared sin/cos shaped
1144
+ ``[HW, head_dim]`` with or without it. Raises ValueError when
1145
+ combined with ``frequency_aware``, ``beta_warp``, or ``alpha_warp``.
1146
+ """
1147
+ if int(H) <= 0 or int(W) <= 0:
1148
+ raise ValueError("H and W must be positive for AxialRoPE2D forward")
1149
+ device = self.periods.device
1150
+ offset = float(self.cfg.coord_offset)
1151
+ coords: Tensor
1152
+ match self.cfg.coord_mode:
1153
+ case AxialRoPE2DCoordMode.DINOV3_NORMALIZED:
1154
+ coords = _get_dinov3_normalized_coords(
1155
+ int(H),
1156
+ int(W),
1157
+ device=device,
1158
+ normalize=self.cfg.normalize_coords,
1159
+ offset=offset,
1160
+ )
1161
+ case AxialRoPE2DCoordMode.PATCH_INDICES:
1162
+ coords = _get_patch_index_coords(
1163
+ int(H), int(W), device=device, offset=offset
1164
+ )
1165
+ case _ as unreachable: # pragma: no cover - defensive
1166
+ raise RuntimeError(f"Unsupported coord_mode: {unreachable}")
1167
+ if coords.dim() != 2 or coords.shape[1] != 2:
1168
+ raise RuntimeError("AxialRoPE2D coords must have shape [HW, 2]")
1169
+ if self.cfg.frequency_aware is not None:
1170
+ if scales is not None:
1171
+ raise ValueError(
1172
+ "frequency-aware axial RoPE does not support dilation scales"
1173
+ )
1174
+ if self.boundary_log_multiplier is None:
1175
+ raise RuntimeError(
1176
+ "boundary_log_multiplier parameter missing for frequency-aware RoPE"
1177
+ )
1178
+ ref_h = int(self.cfg.frequency_aware.ref_h_tokens)
1179
+ ref_w = int(self.cfg.frequency_aware.ref_w_tokens)
1180
+ periods_h = lumina_frequency_aware_periods_for_axis(
1181
+ periods=self.periods,
1182
+ axis_len=int(H),
1183
+ ref_axis_len=ref_h,
1184
+ boundary_log_multiplier=self.boundary_log_multiplier,
1185
+ angle_multiplier=float(self.cfg.angle_multiplier),
1186
+ )
1187
+ periods_w = lumina_frequency_aware_periods_for_axis(
1188
+ periods=self.periods,
1189
+ axis_len=int(W),
1190
+ ref_axis_len=ref_w,
1191
+ boundary_log_multiplier=self.boundary_log_multiplier,
1192
+ angle_multiplier=float(self.cfg.angle_multiplier),
1193
+ )
1194
+ axis_periods = torch.stack([periods_h, periods_w], dim=0) # [2, Q]
1195
+ elif self.cfg.beta_warp is not None:
1196
+ if scales is not None:
1197
+ raise ValueError(
1198
+ "beta-warp axial RoPE does not support dilation scales"
1199
+ )
1200
+ if (
1201
+ self.beta_hi_u is None
1202
+ or self.beta_lo_u is None
1203
+ or self.beta_bend_u is None
1204
+ ):
1205
+ raise RuntimeError("beta warp parameters missing for beta-warp RoPE")
1206
+ beta_cfg = self.cfg.beta_warp
1207
+ ref_h = int(beta_cfg.ref_h_tokens)
1208
+ ref_w = int(beta_cfg.ref_w_tokens)
1209
+ qtr = int(self.periods.numel())
1210
+ if qtr <= 0: # pragma: no cover - defensive (checked elsewhere)
1211
+ raise RuntimeError("AxialRoPE2D periods length must be positive")
1212
+ beta_max = float(beta_cfg.beta_max)
1213
+ if not math.isfinite(beta_max) or beta_max <= 0.0:
1214
+ raise RuntimeError("beta_max must be finite and > 0")
1215
+ beta_max_t = torch.tensor(beta_max, device=device, dtype=torch.float32)
1216
+ beta_hi = beta_max_t * torch.tanh(self.beta_hi_u.to(dtype=torch.float32))
1217
+ beta_lo = beta_max_t * torch.tanh(self.beta_lo_u.to(dtype=torch.float32))
1218
+ beta_bend = beta_max_t * torch.tanh(
1219
+ self.beta_bend_u.to(dtype=torch.float32)
1220
+ )
1221
+ if qtr == 1:
1222
+ beta = beta_hi[None]
1223
+ else:
1224
+ t = torch.arange(int(qtr), device=device, dtype=torch.float32) / float(
1225
+ qtr - 1
1226
+ )
1227
+ bump = 4.0 * t * (1.0 - t)
1228
+ beta = (1.0 - t) * beta_hi + t * beta_lo + beta_bend * bump
1229
+
1230
+ s_h = float(int(H)) / float(ref_h)
1231
+ s_w = float(int(W)) / float(ref_w)
1232
+ if (
1233
+ not math.isfinite(s_h)
1234
+ or s_h <= 0.0
1235
+ or not math.isfinite(s_w)
1236
+ or s_w <= 0.0
1237
+ ):
1238
+ raise RuntimeError(
1239
+ "Computed invalid axis scale factors for beta-warp RoPE"
1240
+ )
1241
+ periods_h = self.periods.to(dtype=torch.float32) * torch.pow(
1242
+ torch.tensor(s_h, device=device, dtype=torch.float32), beta
1243
+ )
1244
+ periods_w = self.periods.to(dtype=torch.float32) * torch.pow(
1245
+ torch.tensor(s_w, device=device, dtype=torch.float32), beta
1246
+ )
1247
+ axis_periods = torch.stack([periods_h, periods_w], dim=0) # [2, Q]
1248
+ elif self.cfg.alpha_warp is not None:
1249
+ if scales is not None:
1250
+ raise ValueError(
1251
+ "alpha-warp axial RoPE does not support dilation scales"
1252
+ )
1253
+ if self.alpha is None:
1254
+ raise RuntimeError("alpha parameter missing for alpha-warp RoPE")
1255
+ alpha_cfg = self.cfg.alpha_warp
1256
+ ref_h = int(alpha_cfg.ref_h_tokens)
1257
+ ref_w = int(alpha_cfg.ref_w_tokens)
1258
+ qtr = int(self.periods.numel())
1259
+ if int(self.alpha.numel()) != qtr:
1260
+ raise RuntimeError(
1261
+ "alpha length must match RoPE periods length for alpha-warp RoPE"
1262
+ )
1263
+ s_h = float(int(H)) / float(ref_h)
1264
+ s_w = float(int(W)) / float(ref_w)
1265
+ if (
1266
+ not math.isfinite(s_h)
1267
+ or s_h <= 0.0
1268
+ or not math.isfinite(s_w)
1269
+ or s_w <= 0.0
1270
+ ):
1271
+ raise RuntimeError(
1272
+ "Computed invalid axis scale factors for alpha-warp RoPE"
1273
+ )
1274
+ alpha = self.alpha.to(device=device, dtype=torch.float32)
1275
+ scale_h = torch.pow(
1276
+ torch.tensor(s_h, device=device, dtype=torch.float32), alpha
1277
+ )
1278
+ scale_w = torch.pow(
1279
+ torch.tensor(s_w, device=device, dtype=torch.float32), alpha
1280
+ )
1281
+ periods_h = self.periods.to(dtype=torch.float32) / scale_h
1282
+ periods_w = self.periods.to(dtype=torch.float32) / scale_w
1283
+ axis_periods = torch.stack([periods_h, periods_w], dim=0) # [2, Q]
1284
+ else:
1285
+ axis_periods = self.periods[None, :].expand(2, -1).to(dtype=torch.float32)
1286
+
1287
+ # Angles: angle_multiplier * coords / periods, flattened and tiled.
1288
+ angles = (
1289
+ float(self.cfg.angle_multiplier)
1290
+ * coords[:, :, None].to(dtype=torch.float32)
1291
+ / axis_periods[None, :, :].to(dtype=torch.float32)
1292
+ )
1293
+ match self.cfg.dim_layout:
1294
+ case AxialRoPE2DDimLayout.HALF_SPLIT:
1295
+ angles = angles.flatten(1, 2).repeat(1, 2)
1296
+ case AxialRoPE2DDimLayout.PAIR_INTERLEAVED:
1297
+ angles = angles.repeat_interleave(2, dim=-1).flatten(1, 2)
1298
+ case _ as unreachable: # pragma: no cover - defensive
1299
+ raise RuntimeError(f"Unsupported dim_layout: {unreachable}")
1300
+ if angles.shape != (int(H) * int(W), int(self._d_head)):
1301
+ raise RuntimeError(
1302
+ "Unexpected angles shape in AxialRoPE2D: "
1303
+ f"{tuple(angles.shape)} for H={int(H)} W={int(W)}"
1304
+ )
1305
+ if scales is not None:
1306
+ if scales.dim() != 0:
1307
+ raise ValueError(
1308
+ "AxialRoPE2D scales must be a scalar tensor for per-batch dilation; "
1309
+ "per-sample dilation is not supported"
1310
+ )
1311
+ angles = angles * scales.to(device=device, dtype=torch.float32)
1312
+ cos = torch.cos(angles)
1313
+ sin = torch.sin(angles)
1314
+ return sin, cos
1315
+
1316
+
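The two period-initialization modes in `_init_periods` can be sketched without tensors. `init_periods` below is a hypothetical list-based stand-in; it assumes the same `head_dim % 4 == 0` constraint as the class.

```python
def init_periods(head_dim, base=None, min_period=None, max_period=None):
    """Return head_dim // 4 base periods, one per band and shared by both axes."""
    if head_dim <= 0 or head_dim % 4 != 0:
        raise ValueError("head_dim must be a positive multiple of 4")
    qtr = head_dim // 4
    if base is not None:
        # Base mode: period[i] = base ** (2 i / (head_dim / 2)).
        return [base ** (2.0 * i / (head_dim // 2)) for i in range(qtr)]
    if min_period is None or max_period is None:
        raise ValueError("min_period and max_period are required when base is None")
    ratio = max_period / min_period
    # Range mode: geometric spacing from min_period to max_period.
    # A single band gets exponent 0, matching a one-step linspace.
    exps = [i / (qtr - 1) for i in range(qtr)] if qtr > 1 else [0.0]
    return [max_period * ratio ** (e - 1.0) for e in exps]


by_base = init_periods(16, base=100.0)
by_range = init_periods(16, min_period=1.0, max_period=1000.0)
```

In base mode the exponents stop short of 1, so the largest period stays below `base`; in range mode the endpoints are hit exactly up to rounding.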
1317
+ def _dy_ntk_periods_for_axis(
1318
+ *,
1319
+ periods: Tensor,
1320
+ axis_len: int,
1321
+ ref_axis_len: int,
1322
+ noise_time: Tensor,
1323
+ lambda_s: float,
1324
+ lambda_t: float,
1325
+ ) -> Tensor:
1326
+ """Return Dy-NTK periods for one spatial axis.
1327
+
1328
+ Raises:
1329
+ ValueError: If token lengths or scheduler values are invalid.
1330
+ """
1331
+
1332
+ if int(axis_len) <= 0 or int(ref_axis_len) <= 0:
1333
+ raise ValueError("axis_len and ref_axis_len must be positive for Dy-NTK")
1334
+ qtr = int(periods.numel())
1335
+ if qtr <= 0:
1336
+ raise ValueError("periods must be non-empty for Dy-NTK")
1337
+ scale = float(int(axis_len)) / float(int(ref_axis_len))
1338
+ if not math.isfinite(scale) or scale <= 0.0:
1339
+ raise ValueError("Dy-NTK axis scale must be finite and > 0")
1340
+ return _dy_ntk_periods_for_scale(
1341
+ periods=periods,
1342
+ scale=scale,
1343
+ noise_time=noise_time,
1344
+ lambda_s=float(lambda_s),
1345
+ lambda_t=float(lambda_t),
1346
+ )
1347
+
1348
+
1349
+ def _dy_ntk_periods_for_scale(
1350
+ *,
1351
+ periods: Tensor,
1352
+ scale: float,
1353
+ noise_time: Tensor,
1354
+ lambda_s: float,
1355
+ lambda_t: float,
1356
+ ) -> Tensor:
1357
+ """Return Dy-NTK periods for a precomputed axis scale."""
1358
+
1359
+ axis_scale = float(scale)
1360
+ if not math.isfinite(axis_scale) or axis_scale <= 0.0:
1361
+ raise ValueError("Dy-NTK scale must be finite and > 0")
1362
+ qtr = int(periods.numel())
1363
+ if qtr <= 0:
1364
+ raise ValueError("periods must be non-empty for Dy-NTK")
1365
+ if axis_scale <= 1.0:
1366
+ return periods.to(dtype=torch.float32)
1367
+ if qtr == 1:
1368
+ exponent = torch.zeros((1,), device=periods.device, dtype=torch.float32)
1369
+ else:
1370
+ exponent = torch.arange(qtr, device=periods.device, dtype=torch.float32) / (
1371
+ float(qtr - 1)
1372
+ )
1373
+ kappa = float(lambda_s) * torch.pow(
1374
+ noise_time.to(device=periods.device, dtype=torch.float32),
1375
+ float(lambda_t),
1376
+ )
1377
+ return periods.to(dtype=torch.float32) * torch.pow(
1378
+ torch.tensor(axis_scale, device=periods.device, dtype=torch.float32),
1379
+ kappa * exponent,
1380
+ )
1381
+
1382
+
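The Dy-NTK scaling rule above can be sketched in plain Python. `dy_ntk_periods` is a hypothetical list-based version; it assumes `noise_time` is a non-negative float rather than a tensor.

```python
import math


def dy_ntk_periods(periods, scale, noise_time, lambda_s, lambda_t):
    """period'[d] = period[d] * scale ** (kappa * d / (Q - 1)), kappa = lambda_s * noise_time ** lambda_t."""
    if not (math.isfinite(scale) and scale > 0.0):
        raise ValueError("scale must be finite and > 0")
    if not periods:
        raise ValueError("periods must be non-empty")
    if scale <= 1.0:
        # No extrapolation needed at or below the reference length.
        return list(periods)
    q = len(periods)
    kappa = lambda_s * noise_time**lambda_t
    return [
        p * scale ** (kappa * (d / (q - 1) if q > 1 else 0.0))
        for d, p in enumerate(periods)
    ]


full = dy_ntk_periods([1.0, 10.0, 100.0], scale=2.0, noise_time=1.0, lambda_s=1.0, lambda_t=2.0)
none = dy_ntk_periods([1.0, 10.0, 100.0], scale=2.0, noise_time=0.0, lambda_s=1.0, lambda_t=2.0)
```

At `noise_time = 1` with `lambda_s = 1` the lowest-frequency band is stretched by the full scale; as `noise_time` goes to 0 the warp fades out and the base periods are returned.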
+ def _dype_dynamic_exponent(
+     *, noise_time: float, lambda_s: float, lambda_t: float
+ ) -> float:
+     """Return the Comfy/DyPE-style dynamic exponent for normalized noise time.
+
+     The result is ``lambda_s * t**lambda_t`` with ``t`` clamped to ``[0, 1]``.
+     """
+
+     noise = float(noise_time)
+     if not math.isfinite(noise):
+         raise ValueError("DyPE noise_time must be finite")
+     noise = max(0.0, min(1.0, noise))
+     scale = float(lambda_s)
+     exponent = float(lambda_t)
+     if not math.isfinite(scale) or scale <= 0.0:
+         raise ValueError("DyPE lambda_s must be finite and > 0")
+     if not math.isfinite(exponent) or exponent <= 0.0:
+         raise ValueError("DyPE lambda_t must be finite and > 0")
+     return scale * (noise**exponent)
+
+
+ def _dype_correction_factor(
+     *,
+     periods: Tensor,
+     rotations: float,
+     ref_axis_len: int,
+     angle_multiplier: float,
+ ) -> float:
+     """Return the fractional band index for a rotation-count boundary.
+
+     The boundary band is the one whose wavelength completes ``rotations`` turns
+     over ``ref_axis_len`` tokens.
+     """
+
+     if int(ref_axis_len) <= 0:
+         raise ValueError("ref_axis_len must be positive for DyPE correction")
+     rot = float(rotations)
+     if not math.isfinite(rot) or rot <= 0.0:
+         raise ValueError("rotations must be finite and > 0")
+     mult = float(angle_multiplier)
+     if not math.isfinite(mult) or mult <= 0.0:
+         raise ValueError("angle_multiplier must be finite and > 0")
+     if int(periods.numel()) < 2:
+         return 0.0
+     periods_cpu = periods.detach().to(device=torch.device("cpu"), dtype=torch.float32)
+     p0 = float(periods_cpu[0].item())
+     p1 = float(periods_cpu[-1].item())
+     if p0 <= 0.0 or p1 <= p0:
+         raise ValueError("periods must be positive and strictly increasing for DyPE")
+     boundary_wavelength = float(int(ref_axis_len)) / rot
+     boundary_period = (mult / (2.0 * float(math.pi))) * boundary_wavelength
+     log_p0 = math.log(p0)
+     log_p1 = math.log(p1)
+     return float(periods.numel() - 1) * (
+         (math.log(boundary_period) - log_p0) / (log_p1 - log_p0)
+     )
+
+
+ def _dype_ramp_mask(
+     *,
+     periods: Tensor,
+     threshold_high_rotations: float,
+     threshold_low_rotations: float,
+     ref_axis_len: int,
+     angle_multiplier: float,
+ ) -> Tensor:
+     """Return YaRN's high-to-low band mask for one dynamic threshold pair."""
+
+     qtr = int(periods.numel())
+     if qtr <= 0:
+         raise ValueError("periods must be non-empty for DyPE ramp mask")
+     device = periods.device
+     if qtr == 1:
+         return torch.ones((1,), device=device, dtype=torch.float32)
+     low = math.floor(
+         _dype_correction_factor(
+             periods=periods,
+             rotations=float(threshold_high_rotations),
+             ref_axis_len=int(ref_axis_len),
+             angle_multiplier=float(angle_multiplier),
+         )
+     )
+     high = math.ceil(
+         _dype_correction_factor(
+             periods=periods,
+             rotations=float(threshold_low_rotations),
+             ref_axis_len=int(ref_axis_len),
+             angle_multiplier=float(angle_multiplier),
+         )
+     )
+     low = max(0, min(qtr - 1, int(low)))
+     high = max(0, min(qtr, int(high)))
+     if low == high:
+         high = min(qtr, low + 1)
+     band = torch.arange(qtr, device=device, dtype=torch.float32)
+     ramp = (band - float(low)) / float(high - low)
+     return 1.0 - torch.clamp(ramp, min=0.0, max=1.0)
+
+
+ def _dy_yarn_periods_for_axis(
+     *,
+     periods: Tensor,
+     linear_scale: float,
+     ntk_scale: float,
+     ref_axis_len: int,
+     noise_time: float,
+     lambda_s: float,
+     cfg: AxialRoPE2DDyPEConfig,
+     angle_multiplier: float,
+ ) -> Tensor:
+     """Return Dy-YaRN periods for one spatial axis."""
+
+     if int(ref_axis_len) <= 0:
+         raise ValueError("ref_axis_len must be positive for Dy-YaRN")
+     linear_s = float(linear_scale)
+     ntk_s = float(ntk_scale)
+     if (
+         not math.isfinite(linear_s)
+         or linear_s <= 0.0
+         or not math.isfinite(ntk_s)
+         or ntk_s <= 0.0
+     ):
+         raise ValueError("Dy-YaRN axis scales must be finite and > 0")
+     periods_f = periods.to(dtype=torch.float32)
+     if max(linear_s, ntk_s) <= 1.0:
+         return periods_f
+     kappa = _dype_dynamic_exponent(
+         noise_time=float(noise_time),
+         lambda_s=float(lambda_s),
+         lambda_t=float(cfg.lambda_t),
+     )
+     if kappa <= 1e-6:
+         return periods_f
+     freq_base = float(angle_multiplier) / periods_f
+     freq_linear = float(angle_multiplier) / (periods_f * max(1.0, linear_s))
+     periods_ntk = _dy_ntk_periods_for_scale(
+         periods=periods_f,
+         scale=max(1.0, ntk_s),
+         noise_time=torch.ones((), device=periods.device, dtype=torch.float32),
+         lambda_s=1.0,
+         lambda_t=1.0,
+     )
+     freq_ntk = float(angle_multiplier) / periods_ntk
+
+     beta_mask = _dype_ramp_mask(
+         periods=periods_f,
+         threshold_high_rotations=float(cfg.yarn_beta_0) ** kappa,
+         threshold_low_rotations=float(cfg.yarn_beta_1) ** kappa,
+         ref_axis_len=int(ref_axis_len),
+         angle_multiplier=float(angle_multiplier),
+     )
+     freq = freq_linear * (1.0 - beta_mask) + freq_ntk * beta_mask
+
+     gamma_mask = _dype_ramp_mask(
+         periods=periods_f,
+         threshold_high_rotations=float(cfg.yarn_gamma_0) ** kappa,
+         threshold_low_rotations=float(cfg.yarn_gamma_1) ** kappa,
+         ref_axis_len=int(ref_axis_len),
+         angle_multiplier=float(angle_multiplier),
+     )
+     freq = freq * (1.0 - gamma_mask) + freq_base * gamma_mask
+     return float(angle_multiplier) / freq
+
+
+ class AxialRoPE2DDyPE(AxialRoPE2D):
+     """Inference-only axial RoPE wrapper using dynamic position extrapolation."""
+
+     dype_cfg: AxialRoPE2DDyPEConfig
+     dype_noise_time: Tensor
+     dype_noise_time_values: list[float]
+
+     def __init__(self, *, base: AxialRoPE2D, cfg: AxialRoPE2DDyPEConfig) -> None:
+         if not isinstance(base, AxialRoPE2D):
+             raise TypeError("base must be an AxialRoPE2D")
+         if not isinstance(cfg, AxialRoPE2DDyPEConfig):
+             raise TypeError("cfg must be an AxialRoPE2DDyPEConfig")
+         if base.cfg.coord_mode is not AxialRoPE2DCoordMode.PATCH_INDICES:
+             raise ValueError("DyPE requires patch-index axial RoPE coordinates")
+         super().__init__(head_dim=int(base.head_dim), cfg=base.cfg)
+         self.dype_cfg = cfg  # ty: ignore[unresolved-attribute]
+         self.register_buffer(
+             "dype_noise_time",
+             torch.tensor(1.0, dtype=torch.float32),
+             persistent=False,
+         )
+         self.dype_noise_time_values: list[float] = [1.0]
+         with torch.no_grad():
+             self.periods.copy_(base.periods.detach().to(dtype=torch.float32))
+
+     def set_dype_noise_time(self, noise_time: float) -> None:
+         """Set the current normalized diffusion noise time in ``[0, 1]``."""
+
+         t = float(noise_time)
+         if not math.isfinite(t) or t < 0.0 or t > 1.0:
+             raise ValueError("DyPE noise_time must be finite and within [0, 1]")
+         self.dype_noise_time.fill_(t)
+         self.dype_noise_time_values[0] = t
+
+     def _dype_axis_periods(
+         self,
+         *,
+         axis_len: int,
+         ref_axis_len: int,
+         global_scale: float,
+         lambda_s: float,
+     ) -> Tensor:
+         """Return method-specific periods for one spatial axis."""
+
+         cfg = self.dype_cfg
+         axis_scale = float(int(axis_len)) / float(int(ref_axis_len))
+         shared_scale = float(global_scale)
+         if (
+             not math.isfinite(axis_scale)
+             or axis_scale <= 0.0
+             or not math.isfinite(shared_scale)
+             or shared_scale <= 0.0
+         ):
+             raise ValueError("DyPE axis and global scales must be finite and > 0")
+         match cfg.method:
+             case DyPERoPEMethod.DY_NTK:
+                 return _dy_ntk_periods_for_scale(
+                     periods=self.periods,
+                     scale=shared_scale,
+                     noise_time=self.dype_noise_time,
+                     lambda_s=float(lambda_s),
+                     lambda_t=float(cfg.lambda_t),
+                 )
+             case DyPERoPEMethod.VISION_YARN:
+                 return _dy_yarn_periods_for_axis(
+                     periods=self.periods,
+                     linear_scale=axis_scale,
+                     ntk_scale=shared_scale,
+                     ref_axis_len=int(ref_axis_len),
+                     noise_time=float(self.dype_noise_time_values[0]),
+                     lambda_s=float(lambda_s),
+                     cfg=cfg,
+                     angle_multiplier=float(self.cfg.angle_multiplier),
+                 )
+             case DyPERoPEMethod.DY_YARN:
+                 return _dy_yarn_periods_for_axis(
+                     periods=self.periods,
+                     linear_scale=shared_scale,
+                     ntk_scale=shared_scale,
+                     ref_axis_len=int(ref_axis_len),
+                     noise_time=float(self.dype_noise_time_values[0]),
+                     lambda_s=float(lambda_s),
+                     cfg=cfg,
+                     angle_multiplier=float(self.cfg.angle_multiplier),
+                 )
+             case _ as unreachable:
+                 raise RuntimeError(f"Unsupported DyPE method: {unreachable}")
+
+     def forward(
+         self,
+         *,
+         H: int,
+         W: int,
+         scales: Tensor | None,
+     ) -> tuple[Tensor, Tensor]:
+         """Return timestep-aware DyPE sin/cos buffers."""
+
+         if scales is not None:
+             raise ValueError("DyPE axial RoPE does not support dilation scales")
+         if int(H) <= 0 or int(W) <= 0:
+             raise ValueError("H and W must be positive for DyPE axial RoPE")
+         device = self.periods.device
+         coords = _get_patch_index_coords(
+             int(H), int(W), device=device, offset=float(self.cfg.coord_offset)
+         )
+         scale_h = float(int(H)) / float(int(self.dype_cfg.ref_h_tokens))
+         scale_w = float(int(W)) / float(int(self.dype_cfg.ref_w_tokens))
+         global_scale = max(scale_h, scale_w)
+         periods_h = self._dype_axis_periods(
+             axis_len=int(H),
+             ref_axis_len=int(self.dype_cfg.ref_h_tokens),
+             global_scale=global_scale,
+             lambda_s=float(self.dype_cfg.lambda_s),
+         )
+         periods_w = self._dype_axis_periods(
+             axis_len=int(W),
+             ref_axis_len=int(self.dype_cfg.ref_w_tokens),
+             global_scale=global_scale,
+             lambda_s=float(self.dype_cfg.lambda_s),
+         )
+         axis_periods = torch.stack([periods_h, periods_w], dim=0)
+         angles = (
+             float(self.cfg.angle_multiplier)
+             * coords[:, :, None].to(dtype=torch.float32)
+             / axis_periods[None, :, :].to(dtype=torch.float32)
+         )
+         match self.cfg.dim_layout:
+             case AxialRoPE2DDimLayout.HALF_SPLIT:
+                 angles = angles.flatten(1, 2).repeat(1, 2)
+             case AxialRoPE2DDimLayout.PAIR_INTERLEAVED:
+                 angles = angles.repeat_interleave(2, dim=-1).flatten(1, 2)
+             case _ as unreachable:
+                 raise RuntimeError(f"Unsupported dim_layout: {unreachable}")
+         expected_shape = (int(H) * int(W), int(self.head_dim))
+         if angles.shape != expected_shape:
+             raise RuntimeError(
+                 "Unexpected angles shape in DyPE axial RoPE: "
+                 f"{tuple(angles.shape)} for expected {expected_shape}"
+             )
+         sin = torch.sin(angles)
+         cos = torch.cos(angles)
+         if (
+             self.dype_cfg.method in (DyPERoPEMethod.VISION_YARN, DyPERoPEMethod.DY_YARN)
+             and bool(self.dype_cfg.yarn_attention_scale)
+             and global_scale > 1.0
+         ):
+             match self.dype_cfg.method:
+                 case DyPERoPEMethod.VISION_YARN:
+                     mscale_start = 0.1 * math.log(global_scale) + 1.0
+                     kappa = _dype_dynamic_exponent(
+                         noise_time=float(self.dype_noise_time_values[0]),
+                         lambda_s=1.0,
+                         lambda_t=float(self.dype_cfg.lambda_t),
+                     )
+                     mscale = 1.0 + (mscale_start - 1.0) * kappa
+                 case DyPERoPEMethod.DY_YARN:
+                     mscale = 1.0 + 0.1 * math.log(global_scale) / math.sqrt(
+                         global_scale
+                     )
+                 case _ as unreachable:  # pragma: no cover - guarded above
+                     raise RuntimeError(
+                         f"Unsupported YaRN attention scale: {unreachable}"
+                     )
+             if mscale > 1.0:
+                 sin = sin * float(mscale)
+                 cos = cos * float(mscale)
+         return sin, cos
+
+
+ def build_axial_rope2d_dype(
+     *, base: AxialRoPE2D, cfg: AxialRoPE2DDyPEConfig
+ ) -> AxialRoPE2DDyPE:
+     """Build an inference-only DyPE wrapper for an existing axial RoPE."""
+
+     return AxialRoPE2DDyPE(base=base, cfg=cfg).to(device=base.periods.device)
+
+
+ def set_axial_rope2d_dype_noise_time(module: nn.Module, *, noise_time: float) -> bool:
+     """Set DyPE noise time on all axial DyPE modules inside ``module``."""
+
+     updated = False
+     for child in module.modules():
+         match child:
+             case AxialRoPE2DDyPE() as dype:
+                 dype.set_dype_noise_time(float(noise_time))
+                 updated = True
+             case _:
+                 pass
+     return updated
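The Dy-NTK rule used by `_dy_ntk_periods_for_scale` can be illustrated without PyTorch. The following is a minimal pure-Python sketch, assuming plain lists in place of tensors; `dy_ntk_periods` is a hypothetical name chosen for illustration and is not part of the export.

```python
import math


def dy_ntk_periods(periods, scale, noise_time, lambda_s, lambda_t):
    """Pure-Python sketch of Dy-NTK period scaling (illustrative only).

    Band i is stretched by scale ** (kappa * i / (n - 1)), where
    kappa = lambda_s * noise_time ** lambda_t.
    """
    if not math.isfinite(scale) or scale <= 0.0:
        raise ValueError("scale must be finite and > 0")
    n = len(periods)
    if n == 0:
        raise ValueError("periods must be non-empty")
    if scale <= 1.0:
        # At or below the reference resolution the periods are left unchanged.
        return [float(p) for p in periods]
    kappa = lambda_s * (noise_time ** lambda_t)
    out = []
    for i, p in enumerate(periods):
        # The first band (i = 0) is untouched; the last band gets the full stretch.
        frac = 0.0 if n == 1 else i / (n - 1)
        out.append(float(p) * scale ** (kappa * frac))
    return out


base = [1.0, 10.0, 100.0]
print(dy_ntk_periods(base, scale=2.0, noise_time=1.0, lambda_s=1.0, lambda_t=1.0))
print(dy_ntk_periods(base, scale=2.0, noise_time=0.0, lambda_s=1.0, lambda_t=1.0))
```

With these inputs the stretch is strongest at `noise_time=1.0` and vanishes at `noise_time=0.0`, which is the timestep dependence the wrapper relies on.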
dit/blocks.py ADDED
@@ -0,0 +1,259 @@
+ """Dense unconditional DiT blocks used by the DINAC-AE export."""
+
+ from __future__ import annotations
+
+ import torch
+ from torch import Tensor, nn
+
+ from common.norms import RMSNorm
+ from common.rope import Rope1D
+ from dit.attention_blocks import DitSelfAttentionCore
+ from dit.body_config import DiTConditioning
+ from dit.mlp import build_dit_mlp, reset_module_parameters
+ from dit.mlp_types import MLPType
+ from dit.position_encoding import DiTPositionEncoding
+
+
+ def _flatten_tokens(
+     x: Tensor, hw: tuple[int, int] | None
+ ) -> tuple[Tensor, tuple[int, int], bool]:
+     """Return dense tokens plus spatial metadata."""
+
+     if x.dim() == 4:
+         batch, channels, height, width = x.shape
+         tokens = x.permute(0, 2, 3, 1).reshape(batch, height * width, channels)
+         return tokens, (int(height), int(width)), True
+     return x, (hw if hw is not None else (int(x.shape[1]), 1)), False
+
+
+ def _restore_spatial(tokens: Tensor, hw: tuple[int, int]) -> Tensor:
+     """Restore dense tokens to NCHW features."""
+
+     batch, _sequence_length, channels = tokens.shape
+     height, spatial_width = hw
+     return tokens.transpose(1, 2).reshape(batch, channels, height, spatial_width)
+
+
+ class TransformerBlock(nn.Module):
+     """Dense pre-norm transformer block kept for import compatibility."""
+
+     d_model: int
+     n_heads: int
+     attn_norm: RMSNorm | None
+     mlp_norm: RMSNorm | None
+     self_attn: DitSelfAttentionCore
+     rope_1d: Rope1D | None
+     mlp: nn.Module
+
+     def __init__(
+         self,
+         *,
+         d_model: int,
+         n_heads: int,
+         mlp_ratio: float,
+         mlp_type: MLPType,
+         activation_config: object | None = None,
+         block_index: int = 0,
+         use_norms: bool = True,
+         position_encoding: DiTPositionEncoding = DiTPositionEncoding.NONE,
+         rope_theta: float | None = None,
+         rope_max_position_embeddings: int | None = None,
+     ) -> None:
+         super().__init__()
+         self.d_model = int(d_model)
+         self.n_heads = int(n_heads)
+         self.attn_norm = RMSNorm(self.d_model) if bool(use_norms) else None
+         self.mlp_norm = RMSNorm(self.d_model) if bool(use_norms) else None
+         self.self_attn = DitSelfAttentionCore(
+             d_model=self.d_model,
+             n_heads=self.n_heads,
+             position_encoding=position_encoding,
+         )
+         self.rope_1d = self._build_rope_1d(
+             position_encoding=position_encoding,
+             rope_theta=rope_theta,
+             rope_max_position_embeddings=rope_max_position_embeddings,
+         )
+         self.mlp = build_dit_mlp(
+             mlp_type=mlp_type,
+             in_features=self.d_model,
+             hidden_budget=int(round(float(mlp_ratio) * self.d_model)),
+             activation_config=activation_config,
+             block_index=int(block_index),
+             bias_up=False,
+             bias_down=False,
+         )
+
+     def reset_parameters(self) -> None:
+         """Reset attention and MLP parameters."""
+
+         self.self_attn.reset_parameters()
+         reset_module_parameters(self.mlp)
+
+     def _build_rope_1d(
+         self,
+         *,
+         position_encoding: DiTPositionEncoding,
+         rope_theta: float | None,
+         rope_max_position_embeddings: int | None,
+     ) -> Rope1D | None:
+         """Build 1D RoPE for sequence-only transformer blocks."""
+
+         match position_encoding:
+             case DiTPositionEncoding.NONE:
+                 return None
+             case DiTPositionEncoding.ROPE_1D:
+                 if rope_theta is None or rope_max_position_embeddings is None:
+                     raise ValueError("ROPE_1D requires theta and max positions")
+                 return Rope1D(
+                     dim=int(self.d_model // self.n_heads),
+                     max_position_embeddings=int(rope_max_position_embeddings),
+                     base=float(rope_theta),
+                 )
+             case _ as unreachable:
+                 raise ValueError(f"Unsupported TransformerBlock RoPE: {unreachable}")
+
+     def forward(self, tokens: Tensor, *, generator: torch.Generator | None) -> Tensor:  # type: ignore[override]
+         """Apply dense self-attention and MLP to token sequences."""
+
+         _ = generator
+         attn_in = self.attn_norm(tokens) if self.attn_norm is not None else tokens
+         rope_sincos = self._build_rope_sincos(attn_in)
+         x = tokens + self.self_attn(attn_in, rope_sincos=rope_sincos)
+         mlp_in = self.mlp_norm(x) if self.mlp_norm is not None else x
+         return x + self.mlp(mlp_in)
+
+     def _build_rope_sincos(self, tokens: Tensor) -> tuple[Tensor, Tensor] | None:
+         """Return dense 1D RoPE sin/cos buffers."""
+
+         rope = self.rope_1d
+         if rope is None:
+             return None
+         batch = int(tokens.shape[0])
+         seqlen = int(tokens.shape[1])
+         position_ids = torch.arange(
+             seqlen,
+             device=tokens.device,
+             dtype=torch.int64,
+         ).unsqueeze(0)
+         position_ids = position_ids.expand(batch, seqlen)
+         dummy = tokens.new_empty(batch, self.n_heads, seqlen, rope.dim)
+         cos, sin = rope(dummy, position_ids)
+         return sin, cos
+
+
+ class DitBlock(nn.Module):
+     """Dense unconditional DiT self-attention block."""
+
+     d: int
+     h: int
+     dh: int
+     hidden_budget: int
+     position_encoding: DiTPositionEncoding
+     conditioning: DiTConditioning
+     adaln: object | None
+     gate_attn: nn.Parameter | None
+     gate_mlp: nn.Parameter | None
+     use_norms: bool
+     attn_norm1: RMSNorm
+     attn_norm2: RMSNorm
+     mlp_norm1: RMSNorm
+     mlp_norm2: RMSNorm
+     attn_core: DitSelfAttentionCore
+     qkv: nn.Linear
+     proj_out: nn.Linear
+     mlp: nn.Module
+
+     def __init__(
+         self,
+         d_model: int,
+         n_heads: int,
+         mlp_ratio: float,
+         *,
+         adaln: object | None = None,
+         mlp_type: MLPType = MLPType.GELU,
+         activation_config: object | None = None,
+         block_index: int = 0,
+         use_norms: bool = True,
+         position_encoding: DiTPositionEncoding = DiTPositionEncoding.NONE,
+         conditioning: DiTConditioning = DiTConditioning.UNCOND,
+     ) -> None:
+         super().__init__()
+         if conditioning is not DiTConditioning.UNCOND or adaln is not None:
+             raise ValueError("DINAC-AE export only supports unconditional DitBlock")
+         self.d = int(d_model)
+         self.h = int(n_heads)
+         self.dh = int(self.d // self.h)
+         self.hidden_budget = int(float(mlp_ratio) * self.d)
+         self.position_encoding = position_encoding
+         self.conditioning = conditioning
+         self.adaln = None
+         self.gate_attn = None
+         self.gate_mlp = None
+         self.use_norms = bool(use_norms)
+         self.attn_norm1 = RMSNorm(self.d)
+         self.attn_norm2 = RMSNorm(self.d)
+         self.mlp_norm1 = RMSNorm(self.d)
+         self.mlp_norm2 = RMSNorm(self.d)
+         self.attn_core = DitSelfAttentionCore(
+             d_model=self.d,
+             n_heads=self.h,
+             position_encoding=position_encoding,
+         )
+         self.qkv = self.attn_core.qkv
+         self.proj_out = self.attn_core.proj_out
+         self.mlp = build_dit_mlp(
+             mlp_type=mlp_type,
+             in_features=self.d,
+             hidden_budget=self.hidden_budget,
+             activation_config=activation_config,
+             block_index=int(block_index),
+             bias_up=False,
+             bias_down=False,
+         )
+         self.reset_parameters()
+
+     def reset_parameters(self) -> None:
+         """Reset attention and MLP parameters."""
+
+         self.attn_core.reset_parameters()
+         reset_module_parameters(self.mlp)
+
+     def compile_for_training(self, *, fullgraph: bool, dynamic: bool) -> None:
+         """No-op hook kept for API compatibility."""
+
+         _ = fullgraph, dynamic
+
+     def compile_for_eval(self, *, fullgraph: bool, dynamic: bool) -> None:
+         """No-op hook kept for API compatibility."""
+
+         _ = fullgraph, dynamic
+
+     def forward(
+         self,
+         x: Tensor,
+         hw: tuple[int, int],
+         cond_vec: Tensor,
+         adaln_m: Tensor | None = None,
+         *,
+         rope_sincos: tuple[Tensor, Tensor] | None = None,
+         generator: torch.Generator | None = None,
+     ) -> Tensor:
+         """Apply the dense unconditional block to spatial features or tokens."""
+
+         _ = cond_vec, adaln_m, generator
+         tokens, hw_tokens, was_spatial = _flatten_tokens(x, hw)
+         attn_in = self.attn_norm1(tokens) if self.use_norms else tokens
+         y = self.attn_core(attn_in, rope_sincos=rope_sincos)
+         attn_out = self.attn_norm2(y) if self.use_norms else y
+         tokens = tokens + attn_out
+         mlp_in = self.mlp_norm1(tokens) if self.use_norms else tokens
+         mlp_out = self.mlp(mlp_in)
+         mlp_out = self.mlp_norm2(mlp_out) if self.use_norms else mlp_out
+         tokens = tokens + mlp_out
+         if was_spatial:
+             return _restore_spatial(tokens, hw_tokens)
+         return tokens
+
+
+ __all__ = ["DitBlock", "TransformerBlock"]
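The token layout produced by `_flatten_tokens` and undone by `_restore_spatial` can be checked by hand on a tiny grid. The following is a minimal sketch, assuming nested Python lists stand in for NCHW tensors; `flatten_tokens` and `restore_spatial` are hypothetical stand-ins written for illustration, not the exported helpers.

```python
def flatten_tokens(x):
    """Nested-list analogue of permute(0, 2, 3, 1).reshape(B, H * W, C)."""
    batch, channels = len(x), len(x[0])
    height, width = len(x[0][0]), len(x[0][0][0])
    tokens = [
        [
            [x[b][c][i][j] for c in range(channels)]
            for i in range(height)
            for j in range(width)
        ]
        for b in range(batch)
    ]
    return tokens, (height, width)


def restore_spatial(tokens, hw):
    """Nested-list analogue of transpose(1, 2).reshape(B, C, H, W)."""
    height, width = hw
    channels = len(tokens[0][0])
    return [
        [
            [[sample[i * width + j][c] for j in range(width)] for i in range(height)]
            for c in range(channels)
        ]
        for sample in tokens
    ]


# One sample, 2 channels, a 2x3 grid.
x = [[[[1, 2, 3], [4, 5, 6]], [[10, 20, 30], [40, 50, 60]]]]
tokens, hw = flatten_tokens(x)
print(hw)
print(tokens[0][0], tokens[0][4])
print(restore_spatial(tokens, hw) == x)
```

Token `k` corresponds to grid position `(k // W, k % W)` in row-major order, so restoring with the same `(H, W)` pair recovers the original layout.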
dit/body_config.py ADDED
@@ -0,0 +1,33 @@
+ """Small DiT configuration enums required by the DINAC-AE export."""
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from enum import Enum, auto
+
+
+ class DiTConditioning(Enum):
+     """Conditioning strategy for exported DiT blocks."""
+
+     ADALN = auto()
+     GATED_UNCOND = auto()
+     UNCOND = auto()
+
+
+ class AdaLNSharingMode(Enum):
+     """AdaLN sharing modes retained for auxiliary config imports."""
+
+     PER_BLOCK = auto()
+     SHARED_BASE_LOW_RANK_DELTA = auto()
+
+
+ @dataclass
+ class DiTBodyConfig:
+     """Minimal body config placeholder for unused auxiliary heads."""
+
+     depth: int = 1
+     d_model: int = 768
+     n_heads: int = 12
+
+
+ __all__ = ["AdaLNSharingMode", "DiTBodyConfig", "DiTConditioning"]
dit/mlp.py ADDED
@@ -0,0 +1,117 @@
+ """Small MLP factory for DINAC-AE DiT blocks."""
+
+ from __future__ import annotations
+
+ from collections.abc import Callable
+ from typing import Protocol, cast
+
+ import torch.nn.functional as F
+ from torch import Tensor, nn
+
+ from dit.mlp_types import MLPType
+
+
+ class Resettable(Protocol):
+     """Typing protocol for modules with ``reset_parameters``."""
+
+     def reset_parameters(self) -> None:
+         """Reset module parameters."""
+
+
+ def reset_module_parameters(module: nn.Module) -> None:
+     """Reset a module that exposes ``reset_parameters``."""
+
+     cast(Resettable, module).reset_parameters()
+
+
+ class SimpleActivationMLP(nn.Module):
+     """Feedforward MLP: ``down(activation(up(x)))``."""
+
+     in_features: int
+     hidden_features: int
+     activation: Callable[[Tensor], Tensor]
+     activation_name: str
+     up: nn.Linear
+     down: nn.Linear
+
+     def __init__(
+         self,
+         in_features: int,
+         hidden_features: int,
+         *,
+         activation: Callable[[Tensor], Tensor],
+         activation_name: str,
+         bias_up: bool,
+         bias_down: bool,
+     ) -> None:
+         super().__init__()
+         self.in_features = int(in_features)
+         self.hidden_features = int(hidden_features)
+         self.activation = activation
+         self.activation_name = str(activation_name)
+         self.up = nn.Linear(self.in_features, self.hidden_features, bias=bias_up)
+         self.down = nn.Linear(self.hidden_features, self.in_features, bias=bias_down)
+         self.reset_parameters()
+
+     def reset_parameters(self) -> None:
+         """Reset linear projections."""
+
+         nn.init.xavier_uniform_(self.up.weight)
+         if self.up.bias is not None:
+             nn.init.zeros_(self.up.bias)
+         nn.init.xavier_uniform_(self.down.weight)
+         if self.down.bias is not None:
+             nn.init.zeros_(self.down.bias)
+
+     def forward(self, x: Tensor) -> Tensor:  # type: ignore[override]
+         """Apply the MLP."""
+
+         return self.down(self.activation(self.up(x)))
+
+
+ def build_dit_mlp(
+     *,
+     mlp_type: MLPType,
+     in_features: int,
+     hidden_budget: int,
+     activation_config: object | None = None,
+     block_index: int = 0,
+     bias_up: bool = False,
+     bias_down: bool = False,
+ ) -> nn.Module:
+     """Build the exported MLP variant."""
+
+     _ = activation_config, block_index
+     match mlp_type:
+         case MLPType.GELU:
+             return SimpleActivationMLP(
+                 in_features=int(in_features),
+                 hidden_features=int(hidden_budget),
+                 activation=F.gelu,
+                 activation_name="gelu",
+                 bias_up=bool(bias_up),
+                 bias_down=bool(bias_down),
+             )
+         case MLPType.SILU:
+             return SimpleActivationMLP(
+                 in_features=int(in_features),
+                 hidden_features=int(hidden_budget),
+                 activation=F.silu,
+                 activation_name="silu",
+                 bias_up=bool(bias_up),
+                 bias_down=bool(bias_down),
+             )
+         case MLPType.RELU:
+             return SimpleActivationMLP(
+                 in_features=int(in_features),
+                 hidden_features=int(hidden_budget),
+                 activation=F.relu,
+                 activation_name="relu",
+                 bias_up=bool(bias_up),
+                 bias_down=bool(bias_down),
+             )
+         case _ as unreachable:
+             raise ValueError(f"Unsupported exported MLP type: {unreachable}")
+
+
+ __all__ = ["SimpleActivationMLP", "build_dit_mlp", "reset_module_parameters"]
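The `down(activation(up(x)))` structure and the name-based dispatch in `build_dit_mlp` can be traced on a toy example. The following is a minimal pure-Python sketch, assuming bias-free projections stored as lists of rows and the erf-based form of GELU; `simple_mlp`, `matvec`, and `ACTIVATIONS` are hypothetical names used only for illustration.

```python
import math


def gelu(v):
    # Exact (erf-based) GELU.
    return 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0)))


def silu(v):
    # SiLU is v * sigmoid(v).
    return v / (1.0 + math.exp(-v))


def relu(v):
    return max(0.0, v)


ACTIVATIONS = {"gelu": gelu, "silu": silu, "relu": relu}


def matvec(weight, vec):
    """Multiply a row-major matrix (list of rows) by a vector."""
    return [sum(w * v for w, v in zip(row, vec)) for row in weight]


def simple_mlp(x, up_weight, down_weight, activation_name):
    """Bias-free feedforward: down(activation(up(x)))."""
    try:
        act = ACTIVATIONS[activation_name]
    except KeyError:
        raise ValueError(f"Unsupported MLP type: {activation_name}") from None
    hidden = [act(h) for h in matvec(up_weight, x)]
    return matvec(down_weight, hidden)


up = [[1.0, 0.0], [0.0, -1.0], [1.0, 1.0]]  # projects 2 -> 3 features
down = [[1.0, 1.0, 1.0], [0.0, 0.0, 2.0]]  # projects 3 -> 2 features
print(simple_mlp([2.0, 3.0], up, down, "relu"))
```

As in the exported factory, an activation name outside the supported set is rejected rather than silently replaced with a default.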
dit/mlp_types.py ADDED
@@ -0,0 +1,51 @@
+ from __future__ import annotations
+
+ from enum import Enum
+
+
+ class MLPType(Enum):
+     """MLP implementation variants for DiT blocks.
+
+     - SWI: baseline SwiGLU MLP
+     - SWINE: Sigmoid-gated sine GLU (σ·sin) with trig promoted to float32
+     - SWINER: Sigmoid-gated FINER-style chirp (σ·sin(ω₀·((1+|x|)·x)))
+     - SPWIDER: sqrt-gated sine GLU (√|a|·sin(ω₀·b))
+     - RELU: Plain ReLU-activated feedforward
+     - RELU2: ReLU-squared activation (ReLU(x)^2) feedforward
+     - SILU: Plain SiLU-activated feedforward
+     - GELU: Plain GELU-activated feedforward
+     - SIREN: Pure sine-activated MLP
+     - SPIDER: Sine with sqrt magnitude (sin(ω₀·x)·√|x|)
+     - SINC: Sinc-activated MLP with log-spaced per-channel scales
+     - FINER: FINER activation MLP with a fixed global scale (non-learnable)
+     - RBF: Low-rank per-patch RBF with Gaussian kernel
+     - RBF_ODD: RBF with odd-Gaussian kernel (z·exp(-z^2))
+     - RBF_SHARP: RBF with sharpness exponent alpha (exp(-(s·|x-b|)^alpha))
+     - RBF_SIREN: RBF using sine basis sin(ω₀·(s·(x-b)))
+     - RBF_FINER: RBF using FINER (chirp) basis sin(ω₀·((1+|z|)·z)), z=s·(x-b)
+     - RBF_DAMPED_SINE: RBF using damped sine sin(ω₀·z)·exp(-|z|), z=s·(x-b)
+     - RBF_SINC: RBF using sinc basis sinc(z)=sin(z)/z with z=s·(x-b)
+     """
+
+     SWI = "swi"
+     SWINE = "swine"
+     SWINER = "swiner"
+     SPWIDER = "spwider"
+     RELU = "relu"
+     RELU2 = "relu2"
+     SILU = "silu"
+     GELU = "gelu"
+     SIREN = "siren"
+     SPIDER = "spider"
+     SINC = "sinc"
+     FINER = "finer"
+     RBF = "rbf"
+     RBF_ODD = "rbf_odd"
+     RBF_SHARP = "rbf_sharp"
+     RBF_SIREN = "rbf_siren"
+     RBF_FINER = "rbf_finer"
+     RBF_DAMPED_SINE = "rbf_damped_sine"
+     RBF_SINC = "rbf_sinc"
+
+
+ __all__ = ["MLPType"]
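Because `MLPType` members carry string values, a config string can be mapped to a member by value lookup. The following is a minimal sketch, assuming a hypothetical three-member enum `MLPKind` that mirrors the same layout; `parse_mlp_kind` is an illustrative helper and does not exist in the export.

```python
from enum import Enum


class MLPKind(Enum):
    """Hypothetical three-member subset mirroring the string-valued layout."""

    RELU = "relu"
    SILU = "silu"
    GELU = "gelu"


def parse_mlp_kind(name):
    """Map a config string to an enum member, with a clear error on a miss."""
    try:
        return MLPKind(name.strip().lower())
    except ValueError:
        valid = ", ".join(member.value for member in MLPKind)
        raise ValueError(f"Unknown MLP type {name!r}; expected one of: {valid}") from None


print(parse_mlp_kind("GELU") is MLPKind.GELU)
print(parse_mlp_kind(" silu ").value)
```

Normalizing case and whitespace before the lookup is a design choice of this sketch; the enum itself only matches exact values.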
dit/position_encoding.py ADDED
@@ -0,0 +1,23 @@
+ """Position encoding options used by exported dense DiT blocks."""
+
+ from __future__ import annotations
+
+ from enum import Enum
+
+
+ class DiTPositionEncoding(Enum):
+     """Position encoding strategy used inside exported DiT blocks."""
+
+     ROPE_2D_AXIAL_DILATED = "rope_2d_axial_dilated"
+     ROPE_2D_AXIAL_UNNORMALIZED_DILATED = "rope_2d_axial_unnormalized_dilated"
+     ROPE_2D_AXIAL_NORMALIZED = "rope_2d_axial_normalized"
+     ROPE_2D_AXIAL_UNNORMALIZED = "rope_2d_axial_unnormalized"
+     ROPE_2D_AXIAL_FREQ_AWARE = "rope_2d_axial_freq_aware"
+     ROPE_2D_AXIAL_BETA_WARP = "rope_2d_axial_beta_warp"
+     ROPE_2D_AXIAL_ALPHA_WARP = "rope_2d_axial_alpha_warp"
+     ROPE_3D_ZIMAGE = "rope_3d_zimage"
+     ROPE_1D = "rope_1d"
+     NONE = "none"
+
+
+ __all__ = ["DiTPositionEncoding"]
dit/repa_projection.py ADDED
@@ -0,0 +1,226 @@
+ """DINO token/class alignment head used by the DINAC-AE export."""
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import TYPE_CHECKING
+
+ import torch
+ from torch import Tensor, nn
+
+ from common.norms import RMSNorm
+ from dit.axial_rope2d import (
+     AxialRoPE2D,
+     AxialRoPE2DConfig,
+     AxialRoPE2DCoordMode,
+     AxialRoPE2DDimLayout,
+     AxialRoPE2DNormalizeCoords,
+ )
+ from dit.blocks import DitBlock
+ from dit.body_config import DiTConditioning
+ from dit.position_encoding import DiTPositionEncoding
+ from dit.xattn_blocks import CrossAttentionBlock, CrossAttentionConfig
+
+ if TYPE_CHECKING:
+     from dit.mlp_types import MLPType
+
+
+ @dataclass(frozen=True)
+ class DinoTokenAlignmentOutput:
+     """Predicted DINO class token and spatial patch tokens."""
+
+     class_token: Tensor
+     spatial_tokens: Tensor
+
+
+ def _prepend_identity_rope_prefix(
+     *,
+     rope_sincos: tuple[Tensor, Tensor],
+     prefix_token_count: int,
+     device: torch.device,
+ ) -> tuple[Tensor, Tensor]:
+     """Prepend no-op RoPE entries for class/register prefix tokens."""
+
+     sin, cos = rope_sincos
+     prefix_shape = (int(prefix_token_count), int(sin.shape[-1]))
+     prefix_sin = torch.zeros(prefix_shape, device=device, dtype=sin.dtype)
+     prefix_cos = torch.ones(prefix_shape, device=device, dtype=cos.dtype)
+     match sin.dim():
+         case 2:
+             return (
+                 torch.cat([prefix_sin, sin.to(device=device)], dim=0),
+                 torch.cat([prefix_cos, cos.to(device=device)], dim=0),
+             )
+         case 3:
+             batch = int(sin.shape[0])
+             return (
+                 torch.cat(
+                     [
+                         prefix_sin.unsqueeze(0).expand(batch, -1, -1),
+                         sin.to(device=device),
+                     ],
+                     dim=1,
+                 ),
+                 torch.cat(
+                     [
+                         prefix_cos.unsqueeze(0).expand(batch, -1, -1),
+                         cos.to(device=device),
+                     ],
+                     dim=1,
+                 ),
+             )
+         case _ as unreachable:
+             raise ValueError(f"Unsupported RoPE tensor rank: {int(unreachable)}")
+
+
+ class DinoTokenAlignmentHead(nn.Module):
77
+ """Predict DINOv3 spatial tokens and a class token from latent grids."""
78
+
79
+ in_channels: int
80
+ feature_dim: int
81
+ model_dim: int
82
+ register_token_count: int
83
+ in_proj: nn.Conv2d
84
+ initial_class_token: nn.Parameter
85
+ register_tokens: nn.Parameter
86
+ block: DitBlock
87
+ spatial_output_norm: RMSNorm
88
+ class_readout: CrossAttentionBlock
89
+ class_output_norm: RMSNorm
90
+ _axial_rope2d: AxialRoPE2D
91
+
92
+ def __init__(
93
+ self,
94
+ *,
95
+ in_channels: int,
96
+ feature_dim: int,
97
+ model_dim: int,
98
+ head_dim: int,
99
+ mlp_ratio: float,
100
+ mlp_activation: MLPType,
101
+ block_index: int,
102
+ register_token_count: int,
103
+ ) -> None:
104
+ super().__init__()
105
+ if int(feature_dim) != int(model_dim):
106
+ raise ValueError("DINAC-AE class head requires feature_dim == model_dim")
107
+ if int(register_token_count) != 4:
108
+ raise ValueError("DINAC-AE class head requires four register tokens")
109
+ self.register_token_count = int(register_token_count)
110
+ self.in_channels = int(in_channels)
111
+ self.feature_dim = int(feature_dim)
112
+ self.model_dim = int(model_dim)
113
+ self.in_proj = nn.Conv2d(
114
+ self.in_channels,
115
+ self.model_dim,
116
+ kernel_size=1,
117
+ padding=0,
118
+ stride=1,
119
+ bias=True,
120
+ )
121
+ self.initial_class_token = nn.Parameter(torch.empty((1, self.model_dim)))
122
+ self.register_tokens = nn.Parameter(
123
+ torch.empty((self.register_token_count, self.model_dim))
124
+ )
125
+ nn.init.normal_(self.initial_class_token, mean=0.0, std=0.02)
126
+ nn.init.normal_(self.register_tokens, mean=0.0, std=0.02)
127
+ conditioning = DiTConditioning.UNCOND
128
+ self.block = DitBlock(
129
+ d_model=self.model_dim,
130
+ n_heads=int(self.model_dim // int(head_dim)),
131
+ mlp_ratio=float(mlp_ratio),
132
+ mlp_type=mlp_activation,
133
+ block_index=int(block_index),
134
+ use_norms=True,
135
+ position_encoding=DiTPositionEncoding.ROPE_2D_AXIAL_UNNORMALIZED,
136
+ conditioning=conditioning,
137
+ )
138
+ self.spatial_output_norm = RMSNorm(self.model_dim, affine=False)
139
+ self.class_readout = CrossAttentionBlock(
140
+ query_dim=self.model_dim,
141
+ context_dim=self.model_dim,
142
+ cfg=CrossAttentionConfig(
143
+ n_heads=int(self.model_dim // int(head_dim)),
144
+ head_dim=int(head_dim),
145
+ query_extra_dim=0,
146
+ context_extra_dim=0,
147
+ mlp_ratio=float(mlp_ratio),
148
+ attn_dropout=0.0,
149
+ mlp_type=mlp_activation,
150
+ activation_config=None,
151
+ use_norms=True,
152
+ block_index=int(block_index) + 1,
153
+ use_attn_residual=True,
154
+ ),
155
+ )
156
+ self.class_output_norm = RMSNorm(self.model_dim, affine=False)
157
+ self._axial_rope2d = AxialRoPE2D(
158
+ head_dim=int(head_dim),
159
+ cfg=AxialRoPE2DConfig(
160
+ base=10_000.0,
161
+ min_period=None,
162
+ max_period=None,
163
+ coord_mode=AxialRoPE2DCoordMode.PATCH_INDICES,
164
+ normalize_coords=AxialRoPE2DNormalizeCoords.MAX,
165
+ dim_layout=AxialRoPE2DDimLayout.PAIR_INTERLEAVED,
166
+ angle_multiplier=1.0,
167
+ coord_offset=0.0,
168
+ frequency_aware=None,
169
+ beta_warp=None,
170
+ alpha_warp=None,
171
+ ),
172
+ )
173
+
174
+ def compile_for_training(self, *, fullgraph: bool, dynamic: bool) -> None:
175
+ """No-op hook kept for source API compatibility."""
176
+
177
+ _ = fullgraph, dynamic
178
+
179
+ def compile_for_eval(self, *, fullgraph: bool, dynamic: bool) -> None:
180
+ """No-op hook kept for source API compatibility."""
181
+
182
+ _ = fullgraph, dynamic
183
+
184
+ def forward(self, latents: Tensor, *, t: Tensor) -> DinoTokenAlignmentOutput:
185
+ """Return predicted class and spatial DINO tokens."""
186
+
187
+ y = self.in_proj(latents)
188
+ batch, _channels, height, width = y.shape
189
+ spatial_tokens = y.flatten(2).transpose(1, 2)
190
+ class_token = self.initial_class_token.to(device=y.device, dtype=y.dtype)
191
+ class_token = class_token.unsqueeze(0).expand(int(batch), -1, -1)
192
+ register_tokens = self.register_tokens.to(device=y.device, dtype=y.dtype)
193
+ register_tokens = register_tokens.unsqueeze(0).expand(int(batch), -1, -1)
194
+ tokens = torch.cat([class_token, register_tokens, spatial_tokens], dim=1)
195
+ rope_sincos = _prepend_identity_rope_prefix(
196
+ rope_sincos=self._axial_rope2d(H=int(height), W=int(width), scales=None),
197
+ prefix_token_count=int(1 + self.register_token_count),
198
+ device=y.device,
199
+ )
200
+ _ = t
201
+ cond = torch.zeros(
202
+ (int(batch), self.model_dim),
203
+ device=y.device,
204
+ dtype=y.dtype,
205
+ )
206
+ tokens = self.block(
207
+ tokens,
208
+ hw=(int(height), int(width)),
209
+ cond_vec=cond,
210
+ adaln_m=None,
211
+ rope_sincos=rope_sincos,
212
+ generator=None,
213
+ )
214
+ class_query = tokens[:, :1, :]
215
+ context = tokens[:, 1:, :]
216
+ class_output = self.class_readout(class_query, context)[:, 0, :]
217
+ class_output = self.class_output_norm(class_output)
218
+ prefix_token_count = int(1 + self.register_token_count)
219
+ predicted_spatial = self.spatial_output_norm(tokens[:, prefix_token_count:, :])
220
+ return DinoTokenAlignmentOutput(
221
+ class_token=class_output,
222
+ spatial_tokens=predicted_spatial,
223
+ )
224
+
225
+
226
+ __all__ = ["DinoTokenAlignmentHead", "DinoTokenAlignmentOutput"]
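A reading aid for `_prepend_identity_rope_prefix` in the file above: the helper gives class and register tokens `sin = 0` and `cos = 1`, which amounts to rotating by angle zero. The sketch below shows why that leaves a prefix token unchanged. It is plain Python and assumes the usual interleaved-pair rotation; `rotate_pairs` is a hypothetical helper, not a function from this package, and how `DitBlock` applies the tables is not shown in this diff.

```python
import math


def rotate_pairs(x, sin, cos):
    """Rotate consecutive (even, odd) pairs of x by one angle per pair (assumed layout)."""
    out = []
    for i in range(0, len(x), 2):
        s, c = sin[i // 2], cos[i // 2]
        a, b = x[i], x[i + 1]
        out.extend([a * c - b * s, a * s + b * c])
    return out


token = [0.5, -1.0, 2.0, 3.0]

# Prefix tokens: sin = 0 and cos = 1, so every pair is rotated by angle 0.
prefix_out = rotate_pairs(token, sin=[0.0, 0.0], cos=[1.0, 1.0])

# A spatial token with nonzero angles is actually rotated (illustrative angles).
angles = [0.3, 0.7]
spatial_out = rotate_pairs(
    token,
    sin=[math.sin(a) for a in angles],
    cos=[math.cos(a) for a in angles],
)
```

Here `prefix_out` equals `token`, while `spatial_out` differs from it but keeps the same Euclidean norm, as a rotation should.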
dit/xattn_blocks.py ADDED
@@ -0,0 +1,177 @@
1
+ """Dense cross-attention block used by the DINAC-AE class-token head."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+
7
+ from torch import Tensor, nn
8
+
9
+ from common.norms import RMSNorm
10
+ from dit.attention_blocks import CrossAttentionCore
11
+ from dit.mlp import build_dit_mlp, reset_module_parameters
12
+ from dit.mlp_types import MLPType
13
+
14
+
15
+ @dataclass
16
+ class CrossAttentionConfig:
17
+ """Configuration for the exported dense cross-attention block."""
18
+
19
+ n_heads: int = 16
20
+ head_dim: int | None = None
21
+ query_extra_dim: int = 0
22
+ context_extra_dim: int = 0
23
+ key_extra_dim: int = 0
24
+ mlp_ratio: float = 2.0
25
+ attn_dropout: float = 0.0
26
+ mlp_type: MLPType = MLPType.GELU
27
+ activation_config: object | None = None
28
+ use_norms: bool = True
29
+ block_index: int = 0
30
+ use_attn_residual: bool = True
31
+
32
+
33
+ class CrossAttentionBlock(nn.Module):
34
+ """Dense pre-norm cross-attention plus residual MLP."""
35
+
36
+ query_dim: int
37
+ context_dim: int
38
+ query_extra_dim: int
39
+ context_extra_dim: int
40
+ key_extra_dim: int
41
+ n_heads: int
42
+ head_dim: int
43
+ attn_dim: int
44
+ use_norms: bool
45
+ attn_dropout: float
46
+ use_attn_residual: bool
47
+ query_norm: RMSNorm | None
48
+ context_norm: RMSNorm | None
49
+ mlp_norm: RMSNorm | None
50
+ q_proj: nn.Linear
51
+ attn_core: CrossAttentionCore
52
+ kv_proj: nn.Linear
53
+ out_proj: nn.Linear
54
+ mlp: nn.Module
55
+
56
+ def __init__(
57
+ self,
58
+ *,
59
+ query_dim: int,
60
+ context_dim: int,
61
+ cfg: CrossAttentionConfig,
62
+ ) -> None:
63
+ super().__init__()
64
+ n_heads = int(cfg.n_heads)
65
+ if cfg.head_dim is None:
66
+ if query_dim % n_heads != 0:
67
+ raise ValueError("query_dim must be divisible by n_heads")
68
+ head_dim = query_dim // n_heads
69
+ else:
70
+ head_dim = int(cfg.head_dim)
71
+ self.query_dim = int(query_dim)
72
+ self.context_dim = int(context_dim)
73
+ self.query_extra_dim = int(cfg.query_extra_dim)
74
+ self.context_extra_dim = int(cfg.context_extra_dim)
75
+ self.key_extra_dim = int(cfg.key_extra_dim)
76
+ self.n_heads = n_heads
77
+ self.head_dim = int(head_dim)
78
+ self.attn_dim = int(self.n_heads * self.head_dim)
79
+ self.use_norms = bool(cfg.use_norms)
80
+ self.attn_dropout = float(cfg.attn_dropout)
81
+ if not cfg.use_attn_residual:
82
+ raise ValueError("DINAC-AE export requires attention residuals")
83
+ self.use_attn_residual = True
84
+ self.query_norm = RMSNorm(self.query_dim) if self.use_norms else None
85
+ self.context_norm = RMSNorm(self.context_dim) if self.use_norms else None
86
+ self.mlp_norm = RMSNorm(query_dim) if self.use_norms else None
87
+ self.q_proj = nn.Linear(
88
+ self.query_dim + self.query_extra_dim, self.attn_dim, bias=False
89
+ )
90
+ self.attn_core = CrossAttentionCore(
91
+ query_dim=query_dim,
92
+ context_dim=context_dim,
93
+ context_extra_dim=self.context_extra_dim,
94
+ key_extra_dim=self.key_extra_dim,
95
+ n_heads=self.n_heads,
96
+ head_dim=self.head_dim,
97
+ attn_dropout=self.attn_dropout,
98
+ )
99
+ self.kv_proj = self.attn_core.kv_proj
100
+ self.out_proj = self.attn_core.out_proj
101
+ hidden = int(round(cfg.mlp_ratio * query_dim))
102
+ self.mlp = build_dit_mlp(
103
+ mlp_type=cfg.mlp_type,
104
+ in_features=query_dim,
105
+ hidden_budget=hidden,
106
+ activation_config=cfg.activation_config,
107
+ block_index=int(cfg.block_index),
108
+ bias_up=False,
109
+ bias_down=False,
110
+ )
111
+ self.reset_parameters()
112
+
113
+ def reset_parameters(self) -> None:
114
+ """Reset projections and MLP parameters."""
115
+
116
+ nn.init.xavier_uniform_(self.q_proj.weight)
117
+ self.attn_core.reset_parameters()
118
+ reset_module_parameters(self.mlp)
119
+
120
+ def forward(
121
+ self,
122
+ query: Tensor,
123
+ context: Tensor,
124
+ *,
125
+ query_extra: Tensor | None = None,
126
+ context_extra: Tensor | None = None,
127
+ key_extra: Tensor | None = None,
128
+ key_padding_mask: Tensor | None = None,
129
+ ) -> Tensor: # type: ignore[override]
130
+ """Run dense cross-attention followed by the residual MLP."""
131
+
132
+ query_tokens = self.query_norm(query) if self.query_norm is not None else query
133
+ if query_extra is not None:
134
+ q_in = query_tokens.new_empty(
135
+ *query_tokens.shape[:-1],
136
+ int(query_tokens.shape[-1]) + int(query_extra.shape[-1]),
137
+ )
138
+ q_in[..., : int(query_tokens.shape[-1])] = query_tokens
139
+ q_in[..., int(query_tokens.shape[-1]) :] = query_extra
140
+ else:
141
+ q_in = query_tokens
142
+ context_tokens = (
143
+ self.context_norm(context) if self.context_norm is not None else context
144
+ )
145
+ if context_extra is not None:
146
+ kv_tokens = context_tokens.new_empty(
147
+ *context_tokens.shape[:-1],
148
+ int(context_tokens.shape[-1]) + int(context_extra.shape[-1]),
149
+ )
150
+ kv_tokens[..., : int(context_tokens.shape[-1])] = context_tokens
151
+ kv_tokens[..., int(context_tokens.shape[-1]) :] = context_extra
152
+ else:
153
+ kv_tokens = context_tokens
154
+ q_attn_tokens = self.q_proj(q_in)
155
+ attn_out = self.attn_core(
156
+ q_attn_tokens,
157
+ kv_tokens,
158
+ training=self.training,
159
+ key_extra=key_extra,
160
+ key_padding_mask=key_padding_mask,
161
+ )
162
+ tokens = query + attn_out
163
+ mlp_in = self.mlp_norm(tokens) if self.mlp_norm is not None else tokens
164
+ return tokens + self.mlp(mlp_in)
165
+
166
+ def compile_for_training(self, *, fullgraph: bool, dynamic: bool) -> None:
167
+ """No-op hook kept for the token-alignment head API."""
168
+
169
+ _ = fullgraph, dynamic
170
+
171
+ def compile_for_eval(self, *, fullgraph: bool, dynamic: bool) -> None:
172
+ """No-op hook kept for the token-alignment head API."""
173
+
174
+ _ = fullgraph, dynamic
175
+
176
+
177
+ __all__ = ["CrossAttentionBlock", "CrossAttentionConfig"]
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b610db6eb9ba995f14ddfbe3ca683044b3b7b4ebe2409fec9465c545a5ec88f7
3
+ size 633374472
technical_report_dinac_ae.md ADDED
@@ -0,0 +1,390 @@
1
+ # DINAC-AE Technical Report
2
+
3
+ `dinac_ae` is a DINO-aligned class-token autoencoder in the SemDisDiffAE
4
+ family: patch-16 spatial latents, a VP diffusion decoder, and semantic
5
+ alignment to frozen vision features.
6
+
7
+ Relative to [SemDisDiffAE](https://huggingface.co/data-archetype/semdisdiffae),
8
+ DINAC-AE replaces the FCDM encoder with a 6-block ViT/DiT-style transformer
9
+ encoder, uses DINOv3 ViT-B/16 features, and extends the latent-to-DINO
10
+ alignment head with a class-token output. The decoder remains in the same FCDM
11
+ VP-diffusion family described in the SemDisDiffAE and capacitor reports.
12
+
13
+ Related reports:
14
+
15
+ - SemDisDiffAE report: https://huggingface.co/data-archetype/semdisdiffae/blob/main/technical_report_semantic.md
16
+ - full_capacitor report: https://huggingface.co/data-archetype/full_capacitor/blob/main/technical_report_full_capacitor.md
17
+ - capacitor_decoder report: https://huggingface.co/data-archetype/capacitor_decoder/blob/main/technical_report_capacitor_decoder.md
18
+
19
+ ## 1. Motivation
20
+
21
+ We trained both FCDM-encoder and transformer-encoder variants. The transformer
22
+ encoder latents were easier for downstream DiT models to learn from, so this
23
+ release keeps the FCDM diffusion decoder and changes the encoder.
24
+
25
+ The second change is the class-token output. The DINO alignment head predicts
26
+ patch tokens as before, and is extended with a class-token prediction path.
27
+ `predict_class(latents)` exposes that feature from latents, enabling
28
+ Representation Frechet Distance / FD-loss style objectives without decoding to
29
+ RGB, and empirically helping make the latents more semantically aligned.
30
+
31
+ ## 2. Architecture Summary
32
+
33
+ | Component | SemDisDiffAE | dinac_ae |
34
+ | --- | ---: | ---: |
35
+ | Patch size | 16 | 16 |
36
+ | Latent channels | 128 | 128 |
37
+ | Encoder block family | FCDM | DiT transformer |
38
+ | Encoder width | 896 | 896 |
39
+ | Encoder depth | 4 | 6 |
40
+ | Decoder block family | FCDM | FCDM |
41
+ | Decoder width | 896 | 896 |
42
+ | Decoder depth | 8 | 8 |
43
+ | Decoder skip layout | start/middle/end skip concat | start/middle/end skip concat |
44
+ | DINO alignment | DINOv3 ViT-S/16 patch tokens | DINOv3 ViT-B/16 patch tokens + class token |
45
+
46
+ Parameter counts in the released checkpoint:
47
+
48
+ | Module | Parameters |
49
+ | --- | ---: |
50
+ | Encoder | 64,939,521 |
51
+ | Decoder | 68,133,505 |
52
+ | DINO alignment head | 14,264,320 |
53
+ | Total | 147,337,346 |
54
+
55
+ ## 3. Encoder
56
+
57
+ The encoder is a single-scale transformer patch encoder. As in SemDisDiffAE,
58
+ all blocks operate at the final latent grid resolution.
59
+
60
+ ### 3.1 Patch Embedding
61
+
62
+ The image is first converted to a patch grid with a stride-16 patch projection:
63
+
64
+ ```text
65
+ image [B, 3, H, W]
66
+ -> stride-16 patch projection (3 x 16 x 16 -> 896) [B, 896, h, w]
67
+ ```
68
+
69
+ ### 3.2 DiT Encoder Block
70
+
71
+ The patch grid is processed by `6` unconditional `DitBlock` layers at width
72
+ `896`. Each block has `14` attention heads with head dimension `64`, an MLP
73
+ ratio of `4.0`, and GELU activations. The encoder uses our shared
74
+ AdaLN-capable `DitBlock` implementation with conditioning disabled. As a result
75
+ it keeps that block's RMSNorm sandwich structure: RMSNorm before and after the
76
+ attention branch, RMSNorm before and after the MLP branch, plus per-head RMSNorm
77
+ on Q and K:
78
+
79
+ ```text
80
+ x
81
+ -> RMSNorm
82
+ -> biasless QKV projection
83
+ -> per-head RMSNorm on Q and K
84
+ -> axial 2D RoPE on Q and K
85
+ -> scaled dot-product attention
86
+ -> biasless output projection
87
+ -> RMSNorm
88
+ -> residual add
89
+ -> RMSNorm
90
+ -> biasless Linear(896 -> 3584)
91
+ -> GELU
92
+ -> biasless Linear(3584 -> 896)
93
+ -> RMSNorm
94
+ -> residual add
95
+ ```
96
+
97
+ ### 3.3 Posterior Projection
98
+
99
+ After the six transformer blocks, a pointwise projection maps from `896` to
100
+ `256` channels and splits into mean and logSNR:
101
+
102
+ ```text
103
+ features [B, 896, h, w]
104
+ -> pointwise projection (896 -> 256)
105
+ -> split: mean [B, 128, h, w], logSNR [B, 128, h, w]
106
+ -> posterior mode: sqrt(sigmoid(logSNR)) * mean
107
+ ```
108
+
109
+ The posterior is a VP-parameterized diagonal Gaussian. The mean branch is used
110
+ directly in the posterior mode computation.
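A minimal scalar sketch of the posterior arithmetic described here. The mode formula `sqrt(sigmoid(logSNR)) * mean` comes from the diagram above. The sampling rule `alpha * mean + sigma * eps` with `sigma^2 = sigmoid(-logSNR)` is the standard VP form and is an assumption about this model; the function names are hypothetical.

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def vp_coefficients(log_snr):
    """Return (alpha, sigma) with alpha^2 + sigma^2 = 1 (standard VP form, assumed)."""
    alpha = math.sqrt(sigmoid(log_snr))
    sigma = math.sqrt(sigmoid(-log_snr))
    return alpha, sigma


def posterior_mode(mean, log_snr):
    """Posterior mode as stated in the report: sqrt(sigmoid(logSNR)) * mean."""
    alpha, _ = vp_coefficients(log_snr)
    return alpha * mean


def posterior_sample(mean, log_snr, rng):
    """Draw one sample; the noise term is an assumption, not taken from the report."""
    alpha, sigma = vp_coefficients(log_snr)
    return alpha * mean + sigma * rng.gauss(0.0, 1.0)


mode_at_zero = posterior_mode(2.0, 0.0)  # alpha = sqrt(0.5) at logSNR = 0
```

At high logSNR the sample collapses onto the mean, and at low logSNR it approaches unit-variance noise, which is the behaviour the variance-related loss terms act on.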
111
+
112
+ The transformer encoder uses 2D axial RoPE with unnormalized patch-index
113
+ coordinates:
114
+
115
+ - Position encoding: axial 2D RoPE.
116
+ - Coordinate mode: patch indices.
117
+ - Coordinate normalization: max-coordinate normalization.
118
+ - Pair layout: interleaved pairs.
119
+ - RoPE base: 10000.
120
+
121
+ This RoPE choice is shared by the encoder blocks and DINO alignment head.
122
+
123
+ ## 4. Decoder
124
+
125
+ The decoder is the same FCDM decoder family used by full_capacitor and
126
+ capacitor_decoder:
127
+
128
+ - 8 FCDM decoder blocks.
129
+ - Width 896.
130
+ - Patch size 16 (one latent token per 16x16 pixel patch).
131
+ - 128 latent channels.
132
+ - Start/middle/end skip-concat architecture with 2 start blocks and 2 end
133
+ blocks.
134
+ - Depthwise convolution kernel size 7.
135
+ - GELU MLP activations and SiLU convolution activations.
136
+
137
+ Each decoder FCDM block uses the same single residual path as SemDisDiffAE:
138
+
139
+ ```text
140
+ x
141
+ -> depthwise Conv 7x7
142
+ -> RMSNorm
143
+ -> scale modulation from timestep AdaLN
144
+ -> pointwise projection
145
+ -> GELU
146
+ -> GRN
147
+ -> pointwise projection
148
+ -> gate modulation from timestep AdaLN
149
+ -> residual add
150
+ ```
151
+
152
+ Timestep conditioning uses the low-rank AdaLN scheme from the capacitor
153
+ decoder: a shared base projection plus per-layer low-rank deltas, split into
154
+ `scale` and `gate`. The decoder topology is:
155
+
156
+ ```text
157
+ noisy image x_t
158
+ -> stride-16 image patch projection
159
+ -> concatenate projected latents
160
+ -> 2 start FCDM blocks
161
+ -> 4 middle FCDM blocks
162
+ -> concatenate start and middle activations
163
+ -> 2 end FCDM blocks
164
+ -> patch-output projection + PixelShuffle(16)
165
+ -> x0 prediction
166
+ ```
167
+
168
+ DINAC-AE keeps this decoder path; the architectural changes are in the encoder
169
+ and alignment head.
170
+
171
+ ## 5. DINOv3 Alignment
172
+
173
+ The alignment target is DINOv3 ViT-B/16 trained on LVD1689M:
174
+
175
+ - Teacher: `dino_v3_vit_base_patch16_lvd1689m`.
176
+ - Feature type: `dino_v3_vit_base_patch16_tokens`.
177
+ - Target feature dimension: 768.
178
+ - The loss supervises both the DINO class token and DINO spatial patch tokens.
179
+
180
+ The alignment head first maps unwhitened DINAC-AE latents into DINO token
181
+ space:
182
+
183
+ ```text
184
+ latents [B, 128, h, w]
185
+ -> pointwise projection (128 -> 768)
186
+ -> flatten spatial tokens [B, h*w, 768]
187
+ -> prepend learned class token [B, 1, 768]
188
+ -> insert 4 learned register tokens after the class token [B, 4, 768]
189
+ -> 1 unconditional RoPE `DitBlock` over all tokens
190
+ -> RMSNorm(spatial tokens)
191
+ -> spatial negative-cosine target
192
+ ```
193
+
194
+ The token-alignment block uses the same unconditional RoPE `DitBlock` form as
195
+ the encoder, but at DINO width: `768` channels, `12` heads, head dimension `64`,
196
+ MLP ratio `4.0`, and GELU MLP. The learned prefix tokens receive identity RoPE
197
+ rotation; spatial tokens receive the same axial 2D RoPE configuration as the
198
+ encoder.
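The head counts and MLP widths quoted for the encoder and the alignment head follow from the integer arithmetic visible in the exported code (`model_dim // head_dim` and `round(mlp_ratio * dim)`). A small sketch of that arithmetic; `block_geometry` is a hypothetical helper, and the divisibility check is added here for illustration only.

```python
def block_geometry(model_dim, head_dim, mlp_ratio):
    """Derive head count and MLP hidden width from width, head size, and ratio."""
    if model_dim % head_dim != 0:
        # Illustrative guard; the exported head itself uses plain integer division.
        raise ValueError("model_dim must be divisible by head_dim")
    return {
        "n_heads": model_dim // head_dim,
        "mlp_hidden": int(round(mlp_ratio * model_dim)),
    }


encoder_geometry = block_geometry(model_dim=896, head_dim=64, mlp_ratio=4.0)
alignment_geometry = block_geometry(model_dim=768, head_dim=64, mlp_ratio=4.0)
```

This gives 14 heads and a 3584-wide MLP for the encoder, and 12 heads and a 3072-wide MLP for the alignment head, matching the figures in the report.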
199
+
200
+ The class-token path uses an additional residual cross-attention block:
201
+
202
+ ```text
203
+ updated class token [B, 1, 768]
204
+ updated register + spatial tokens [B, 4 + h*w, 768]
205
+ -> class query cross-attends to register + spatial tokens (RMSNorm applied to both inputs first)
206
+ -> residual add
207
+ -> RMSNorm
208
+ -> biasless Linear(768 -> 3072)
209
+ -> GELU
210
+ -> biasless Linear(3072 -> 768)
211
+ -> residual add
212
+ -> RMSNorm(class token)
213
+ -> class negative-cosine target
214
+ ```
215
+
216
+ The class readout uses the same DINO width and head geometry as the
217
+ token-alignment block, but replaces self-attention with dense cross-attention.
218
+
219
+ The DINO alignment loss is negative cosine on RMS-normalized features:
220
+
221
+ ```text
222
+ class_negative_cosine_loss + spatial_negative_cosine_loss
223
+ ```
224
+
225
+ The DINO alignment weight is `0.01`. Alignment is applied directly to clean
226
+ latent tokens. Robustness to local token errors is handled separately by
227
+ [random-token logSNR offset regularization](#8-logsnr-offset-regularization).
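A plain-Python sketch of the negative-cosine objective on RMS-normalized features. The report only states the two loss names and the `0.01` weight, so the mean reduction over spatial tokens, the `eps` value, and the function names below are assumptions.

```python
import math


def rms_normalize(v, eps=1e-6):
    """Divide a vector by its root-mean-square (eps value is an assumption)."""
    rms = math.sqrt(sum(x * x for x in v) / len(v) + eps)
    return [x / rms for x in v]


def negative_cosine(pred, target):
    """Return -cos(pred, target) after RMS normalization of both vectors."""
    p, t = rms_normalize(pred), rms_normalize(target)
    dot = sum(a * b for a, b in zip(p, t))
    norm = math.sqrt(sum(a * a for a in p)) * math.sqrt(sum(b * b for b in t))
    return -dot / norm


def alignment_loss(class_pred, class_target, spatial_pred, spatial_target):
    """Class term plus the mean spatial term (mean reduction is an assumption)."""
    spatial = sum(
        negative_cosine(p, t) for p, t in zip(spatial_pred, spatial_target)
    ) / len(spatial_pred)
    return negative_cosine(class_pred, class_target) + spatial


ALIGNMENT_WEIGHT = 0.01  # weight quoted in the report
perfect = alignment_loss([1.0, 2.0], [1.0, 2.0], [[0.5, 0.5]], [[0.5, 0.5]])
```

A perfectly aligned prediction gives `-2.0` before weighting (`-1` from each term), so the weighted contribution is bounded below by `-0.02`.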
228
+
229
+ ## 6. Class-Token Output
230
+
231
+ `predict_class(latents)` is part of the public API. It expects the same
232
+ whitened latent convention returned by `encode(...)`, applies dewhitening, and
233
+ runs the DINO alignment head. The returned tensor has shape `[B, 768]` and
234
+ lives in the DINOv3 ViT-B/16 class-token feature space.
235
+
236
+ This output has two intended uses:
237
+
238
+ - It adds a global semantic pressure during autoencoder training, complementing
239
+ the spatial patch-token alignment loss.
240
+ - It provides a latent-space feature endpoint for FD-loss / Representation
241
+ Frechet Distance objectives, avoiding the cost and gradient path of decoding
242
+ latents back to RGB before computing representation statistics.
243
+
244
+ The class-token output is trained by negative cosine alignment to the frozen
245
+ DINOv3 class token.
246
+
247
+ ## 7. Training Losses
248
+
249
+ The checkpoint was trained with these active loss terms:
250
+
251
+ | Loss term | Weight / value |
252
+ | --- | ---: |
253
+ | Main VP diffusion reconstruction loss | 1.0 |
254
+ | DINO class + spatial token alignment | 0.01 |
255
+ | Latent posterior VE variance loss | 0.00003 |
256
+ | Latent log-variance scale penalty | 0.0003 |
257
+
258
+ The main decoder objective is VP diffusion `x_pred` with the SID2 x-prediction
259
+ variant. The timestep sampler is uniform, with logSNR range `[-10, 10]` and
260
+ training logSNR shift `-1.0`.
261
+
262
+ The latent posterior regularization follows the same KL-like variance expansion
263
+ idea described in the SemDisDiffAE report:
264
+ https://huggingface.co/data-archetype/semdisdiffae/blob/main/technical_report_semantic.md#32-variance-expansion-loss
265
+
266
+ Latent running statistics use momentum `0.0005` and epsilon `0.0001`.
267
+ `encode(...)` returns whitened latents; `decode(...)` and
268
+ `predict_class(...)` dewhiten those latents before applying the decoder or
269
+ class output path.
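One way to picture the whitening convention is the scalar sketch below. It assumes per-channel running mean and variance updated as an exponential moving average, with epsilon added inside the square root; the report gives only the momentum and epsilon values, so the update rule and function names are hypothetical.

```python
def update_running(stat, batch_stat, momentum=0.0005):
    """Exponential moving average update of a running statistic (assumed form)."""
    return (1.0 - momentum) * stat + momentum * batch_stat


def whiten(x, mean, var, eps=0.0001):
    """Map a raw latent value to the whitened convention returned by encode(...)."""
    return (x - mean) / (var + eps) ** 0.5


def dewhiten(z, mean, var, eps=0.0001):
    """Invert whiten(...); decode(...) and predict_class(...) do this first."""
    return z * (var + eps) ** 0.5 + mean


round_trip = dewhiten(whiten(1.7, mean=0.3, var=2.0), mean=0.3, var=2.0)
```

The round trip returns the original value, which is why latents from `encode(...)` can be passed straight to `decode(...)` or `predict_class(...)` without manual rescaling.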
270
+
271
+ ## 8. LogSNR Offset Regularization
272
+
273
+ During training, 10% of spatial latent tokens across the batch/grid are
274
+ selected at random and their posterior logSNR is shifted by `-2.0` before
275
+ sampling.
276
+
277
+ This injects non-smooth token-level latent errors during training. The decoder
278
+ and the DINO alignment head both see these perturbed latents, encouraging
279
+ robustness to downstream DiT prediction mistakes.
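The perturbation can be sketched as follows, assuming an independent 10% draw per token and one logSNR value per token for brevity. The real model has 128 channels per token, and the exact selection scheme is not stated in the report; `offset_log_snr` is a hypothetical name.

```python
import random


def offset_log_snr(log_snr_tokens, rng, fraction=0.10, offset=-2.0):
    """Shift the logSNR of randomly selected tokens before posterior sampling."""
    shifted, mask = [], []
    for value in log_snr_tokens:
        hit = rng.random() < fraction  # independent per-token draw (assumption)
        mask.append(hit)
        shifted.append(value + offset if hit else value)
    return shifted, mask


rng = random.Random(0)
log_snr = [3.0] * 1000  # toy grid: 1000 tokens, all at logSNR 3.0
shifted, mask = offset_log_snr(log_snr, rng)
```

Selected tokens end up at logSNR `1.0` and are therefore sampled with more noise than their neighbours, which produces the non-smooth, token-level errors described above.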
280
+
281
+ ## 9. Training Recipe
282
+
283
+ The model was trained on 12M images using a single NVIDIA RTX PRO 6000
284
+ Blackwell 96GB GPU. Training used two stages:
285
+
286
+ | Stage | Resolution schedule | Batch size | Approx steps |
287
+ | --- | --- | ---: | ---: |
288
+ | Stage 1 | 90% 256-scale AR buckets, 10% 384-scale AR buckets | 128 | 150k |
289
+ | Stage 2 | equal mix of 256/384/512/768/1024 buckets | 64 | 200k |
290
+
291
+ The mixed-resolution second stage is especially important for the transformer
292
+ encoder. In practice, and even with 2D RoPE, transformer blocks tolerate only
293
+ limited resolution extrapolation unless they see the higher-resolution patch
294
+ grids during training.
295
+
296
+ Optimizer/training parameters:
297
+
298
+ | Setting | Stage 1 | Stage 2 |
299
+ | --- | ---: | ---: |
300
+ | Optimizer | AdamW | AdamW |
301
+ | Learning rate | 1e-4 | 5e-5 |
302
+ | Betas | (0.9, 0.99) | (0.9, sqrt(0.98) ~= 0.98995) |
303
+ | Weight decay | 0.0 | 0.0 |
304
+ | EMA decay | 0.9995 | 0.9995 |
305
+ | Warmup | 2,000 steps | 10,000 steps |
306
+ | Precision | AMP BF16, TF32 matmul | AMP BF16, TF32 matmul |
307
+ | Gradient clipping | 1.0 | 1.0 |
308
+ | Optimizer state dtype | BF16 | BF16 |
309
+
310
+ Resolution and accumulation settings:
311
+
312
+ | Resolution | Stage 1 mix | Stage 1 grad accumulation | Stage 2 mix | Stage 2 grad accumulation |
313
+ | ---: | ---: | ---: | ---: | ---: |
314
+ | 256 | 90% | 1 | 20% | 1 |
315
+ | 384 | 10% | 1 | 20% | 1 |
316
+ | 512 | 0% | 0 | 20% | 1 |
317
+ | 768 | 0% | 0 | 20% | 2 |
318
+ | 1024 | 0% | 0 | 20% | 2 |
319
+
320
+ ## 10. Reconstruction Quality
321
+
322
+ Reconstruction quality on `2000` validation images:
323
+
324
+ | Model | Mean PSNR | Std | Median | Min | p5 | p95 | Max |
325
+ | --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: |
326
+ | dinac_ae | 35.19 | 4.53 | 35.06 | 22.44 | 28.02 | 42.43 | 47.31 |
327
+ | FLUX.2 VAE | 36.28 | 4.53 | 36.07 | 22.73 | 28.89 | 43.63 | 47.38 |
328
+
329
+ The 39-image reconstruction viewer includes originals, DINAC-AE
330
+ reconstructions, FLUX.2 VAE references, RGB deltas, and latent PCA:
331
+ https://huggingface.co/spaces/data-archetype/dinac_ae-results
332
+
333
+ The released export recheck on that viewer set gives `35.15 dB` mean PSNR
334
+ (`25.73` min, `45.99` max).
335
+
336
+ ## 11. Class-Token Alignment Results
337
+
338
+ The `predict_class(...)` path is evaluated against the frozen DINOv3 ViT-B/16
339
+ teacher class token on the same `2000` images used for reconstruction PSNR.
340
+
341
+ | Metric | Cosine similarity |
342
+ | --- | ---: |
343
+ | Mean | 0.757458 |
344
+ | Std | 0.076265 |
345
+ | Median | 0.765958 |
346
+ | Min | 0.394647 |
347
+ | p5 | 0.623156 |
348
+ | p10 | 0.656243 |
349
+ | p25 | 0.711337 |
350
+ | p75 | 0.813098 |
351
+ | p90 | 0.849525 |
352
+ | p95 | 0.865219 |
353
+ | Max | 0.932722 |
354
+
355
+ ## 12. Encoder Throughput
356
+
357
+ Encoder timing was measured with the released package on an NVIDIA GeForce RTX
358
+ 5090. Decoder timing is unchanged from the capacitor decoder/full_capacitor
359
+ release because DINAC-AE uses the same decoder architecture.
360
+
361
+ | Resolution | Batch | FLUX.2 encode ms/batch | full_capacitor ms/batch | dinac_ae ms/batch | Speedup vs FLUX.2 | dinac_ae vs full_capacitor |
362
+ | --- | ---: | ---: | ---: | ---: | ---: | ---: |
363
+ | 256x256 | 128 | 383.41 | 42.56 | 50.32 | 7.62x | 1.18x slower |
364
+ | 512x512 | 32 | 353.58 | 44.97 | 52.65 | 6.72x | 1.17x slower |
365
+
366
+ Peak allocated encoder memory:
367
+
368
+ | Resolution | Batch | FLUX.2 MiB | full_capacitor MiB | dinac_ae MiB | Reduction vs FLUX.2 |
369
+ | --- | ---: | ---: | ---: | ---: | ---: |
370
+ | 256x256 | 128 | 12,511.0 | 1,008.2 | 1,637.3 | 86.9% |
371
+ | 512x512 | 32 | 12,511.0 | 1,005.6 | 1,638.5 | 86.9% |
372
+
373
+ The transformer encoder is slightly slower and uses more memory than the
374
+ full_capacitor FCDM encoder, but it remains much faster and much smaller than
375
+ the FLUX.2 VAE encoder.
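The derived columns in the two tables can be recomputed from the raw timings and memory figures. The sketch below only redoes that arithmetic with the numbers quoted above; the dictionary layout and function names are illustrative.

```python
encode_ms = {
    "256x256": {"flux2": 383.41, "full_capacitor": 42.56, "dinac_ae": 50.32},
    "512x512": {"flux2": 353.58, "full_capacitor": 44.97, "dinac_ae": 52.65},
}
peak_mib = {
    "256x256": {"flux2": 12511.0, "dinac_ae": 1637.3},
    "512x512": {"flux2": 12511.0, "dinac_ae": 1638.5},
}


def ratio(numerator, denominator):
    return numerator / denominator


def reduction_percent(baseline, candidate):
    return 100.0 * (1.0 - candidate / baseline)


summary = {
    res: {
        # How many times faster dinac_ae encodes than FLUX.2.
        "speedup_vs_flux2": ratio(t["flux2"], t["dinac_ae"]),
        # How many times slower dinac_ae is than full_capacitor.
        "slowdown_vs_full_capacitor": ratio(t["dinac_ae"], t["full_capacitor"]),
        # Peak-memory saving relative to FLUX.2, in percent.
        "memory_reduction": reduction_percent(
            peak_mib[res]["flux2"], peak_mib[res]["dinac_ae"]
        ),
    }
    for res, t in encode_ms.items()
}
```

Rounded to the table's precision, these values reproduce the reported speedups, slowdowns, and memory reductions.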
376
+
377
+ ## References
378
+
379
+ - Oriane Siméoni, Huy V. Vo, Maximilian Seitzer, Federico Baldassarre,
380
+ Maxime Oquab, Cijo Jose, Vasil Khalidov, Marc Szafraniec, Seungeun Yi,
381
+ Michaël Ramamonjisoa, Francisco Massa, Daniel Haziza, Luca Wehrstedt,
382
+ Jianyuan Wang, Timothée Darcet, Théo Moutakanni, Leonel Sentana,
383
+ Claire Roberts, Andrea Vedaldi, Jamie Tolan, John Brandt, Camille Couprie,
384
+ Julien Mairal, Hervé Jégou, Patrick Labatut, and Piotr Bojanowski.
385
+ DINOv3. arXiv:2508.10104, 2025.
386
+ https://arxiv.org/abs/2508.10104
387
+ - Jiawei Yang, Zhengyang Geng, Xuan Ju, Yonglong Tian, and Yue Wang.
388
+ Representation Frechet Loss for Visual Generation. arXiv:2604.28190, 2026.
389
+ https://arxiv.org/abs/2604.28190
390
+ - FD-Loss implementation: https://github.com/Jiawei-Yang/FD-Loss