Commit 1b703d5 · 0 parent(s)
Upload DINAC-AE export package
- .gitattributes +35 -0
- README.md +135 -0
- common/norms.py +236 -0
- common/rope.py +536 -0
- config.json +26 -0
- dinac_ae/__init__.py +12 -0
- dinac_ae/adaln.py +75 -0
- dinac_ae/config.py +75 -0
- dinac_ae/decoder.py +163 -0
- dinac_ae/encoder.py +215 -0
- dinac_ae/fcdm_block.py +103 -0
- dinac_ae/model.py +333 -0
- dinac_ae/norms.py +39 -0
- dinac_ae/samplers.py +258 -0
- dinac_ae/straight_through_encoder.py +57 -0
- dinac_ae/time_embed.py +83 -0
- dinac_ae/vp_diffusion.py +152 -0
- dit/attention_blocks.py +240 -0
- dit/axial_rope2d.py +1728 -0
- dit/blocks.py +259 -0
- dit/body_config.py +33 -0
- dit/mlp.py +117 -0
- dit/mlp_types.py +51 -0
- dit/position_encoding.py +23 -0
- dit/repa_projection.py +226 -0
- dit/xattn_blocks.py +177 -0
- model.safetensors +3 -0
- technical_report_dinac_ae.md +390 -0
.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,135 @@
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
tags:
|
| 4 |
+
- diffusion
|
| 5 |
+
- autoencoder
|
| 6 |
+
- image-reconstruction
|
| 7 |
+
- latent-space
|
| 8 |
+
- dino
|
| 9 |
+
- pytorch
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
# data-archetype/dinac_ae
|
| 13 |
+
|
| 14 |
+
**DINAC-AE** is a **DIN**O-**A**ligned **C**lass-token **A**uto**E**ncoder.
|
| 15 |
+
It follows the [SemDisDiffAE](https://huggingface.co/data-archetype/semdisdiffae)
|
| 16 |
+
family: patch-16 spatial latents, a VP diffusion decoder, and DINO-aligned
|
| 17 |
+
representations.
|
| 18 |
+
|
| 19 |
+
Relative to SemDisDiffAE, DINAC-AE changes the encoder from FCDM blocks to a
|
| 20 |
+
6-block ViT/DiT-style transformer encoder and uses DINOv3 ViT-B/16 alignment.
|
| 21 |
+
The latent-to-DINO alignment head is extended to predict the DINO class token
|
| 22 |
+
as well as patch tokens. `predict_class(latents)` exposes that class-token
|
| 23 |
+
feature directly from latents.
|
| 24 |
+
|
| 25 |
+
## 2k PSNR Benchmark
|
| 26 |
+
|
| 27 |
+
| Model | Mean PSNR (dB) | Std (dB) | Median (dB) | P5 (dB) | P95 (dB) |
|
| 28 |
+
|---|---:|---:|---:|---:|---:|
|
| 29 |
+
| dinac_ae | `35.19` | `4.53` | `35.06` | `28.02` | `42.43` |
|
| 30 |
+
| FLUX.2 VAE | `36.28` | `4.53` | `36.07` | `28.89` | `43.63` |
|
| 31 |
+
|
| 32 |
+
Evaluated on `2000` validation images.
|
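For readers unfamiliar with the metric reported above, here is a minimal sketch of PSNR in plain Python. The peak value is an assumption: the README does not say which dynamic range the benchmark uses, so `peak` is left as a parameter. `psnr_db` is a hypothetical helper name, not part of the export.

```python
import math
from collections.abc import Sequence


def psnr_db(
    reference: Sequence[float],
    reconstruction: Sequence[float],
    peak: float,
) -> float:
    """PSNR in dB: 10 * log10(peak^2 / MSE). `peak` is the assumed signal range."""
    if len(reference) != len(reconstruction) or len(reference) == 0:
        raise ValueError("inputs must be non-empty and equally long")
    mse = sum((a - b) ** 2 for a, b in zip(reference, reconstruction)) / len(reference)
    if mse == 0.0:
        return math.inf  # identical signals
    return 10.0 * math.log10(peak * peak / mse)
```

The choice of `peak` matters when comparing numbers across reports: using `peak=2.0` (the width of `[-1, 1]`) instead of `peak=1.0` shifts every value by about 6 dB.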
| 33 |
+
|
| 34 |
+
DINAC-AE targets a compromise between high reconstruction quality, a learnable
|
| 35 |
+
latent space with KL-like variance expansion, DINOv3 alignment, and robustness
|
| 36 |
+
to local token errors.
|
| 37 |
+
|
| 38 |
+
[Results viewer](https://huggingface.co/spaces/data-archetype/dinac_ae-results)
|
| 39 |
+
shows the 39-image reconstruction set with DINAC-AE and FLUX.2 VAE
|
| 40 |
+
reconstructions, RGB differences, and latent PCA.
|
| 41 |
+
The released export recheck on that 39-image set gives `35.15 dB` mean PSNR
|
| 42 |
+
(`25.73` min, `45.99` max).
|
| 43 |
+
|
| 44 |
+
[Full technical report](https://huggingface.co/data-archetype/dinac_ae/blob/main/technical_report_dinac_ae.md)
|
| 45 |
+
|
| 46 |
+
## Encode Throughput
|
| 47 |
+
|
| 48 |
+
Measured on an `NVIDIA GeForce RTX 5090` in `bfloat16`, averaging repeated
|
| 49 |
+
batches per resolution.
|
| 50 |
+
|
| 51 |
+
| Resolution | Batch Size | dinac_ae encode (ms/batch) | FLUX.2 encode (ms/batch) | dinac_ae peak VRAM (MiB) | FLUX.2 peak VRAM (MiB) | Speedup vs FLUX.2 | Peak VRAM Reduction vs FLUX.2 |
|
| 52 |
+
|---:|---:|---:|---:|---:|---:|---:|---:|
|
| 53 |
+
| `256x256` | `128` | `50` | `383` | `1,637` | `12,511` | `7.62x` | `86.9%` |
|
| 54 |
+
| `512x512` | `32` | `53` | `354` | `1,639` | `12,511` | `6.72x` | `86.9%` |
|
| 55 |
+
|
| 56 |
+
The transformer encoder is slightly slower and larger than the full_capacitor
|
| 57 |
+
FCDM encoder, but remains much faster and much smaller than the FLUX.2 VAE
|
| 58 |
+
encoder.
|
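The two derived columns can be recomputed from the rounded table entries as a sanity check. This is only arithmetic on the published figures, and the helper names are hypothetical. Because the table shows rounded milliseconds, the recomputed speedups differ slightly from the published `7.62x` and `6.72x`, which were presumably derived from unrounded timings.

```python
def speedup(baseline_ms: float, candidate_ms: float) -> float:
    """How many times faster the candidate is than the baseline."""
    return baseline_ms / candidate_ms


def vram_reduction_pct(baseline_mib: float, candidate_mib: float) -> float:
    """Peak-memory saving relative to the baseline, in percent."""
    return 100.0 * (1.0 - candidate_mib / baseline_mib)


# Rounded entries copied from the table:
# (resolution, dinac_ae ms, FLUX.2 ms, dinac_ae MiB, FLUX.2 MiB)
rows = [
    ("256x256", 50, 383, 1637, 12511),
    ("512x512", 53, 354, 1639, 12511),
]
for name, d_ms, f_ms, d_mib, f_mib in rows:
    print(
        f"{name}: {speedup(f_ms, d_ms):.2f}x faster, "
        f"{vram_reduction_pct(f_mib, d_mib):.1f}% less peak VRAM"
    )
```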
| 59 |
+
|
| 60 |
+
## Latent Interface
|
| 61 |
+
|
| 62 |
+
- `encode()` returns DINAC-AE's own whitened latent space.
|
| 63 |
+
- `decode()` expects that same whitened latent space and dewhitens internally.
|
| 64 |
+
- `predict_class()` expects the same whitened latent space, dewhitens
|
| 65 |
+
internally, and predicts a DINOv3 ViT-B/16 class-token feature.
|
| 66 |
+
- `whiten()` and `dewhiten()` are exposed for explicit control.
|
| 67 |
+
- `encode_posterior()` returns the raw exported posterior before whitening.
|
| 68 |
+
- `DinacAEInferenceConfig.num_steps` counts decoder evaluations directly:
|
| 69 |
+
`num_steps=1` means one NFE.
|
| 70 |
+
|
| 71 |
+
The export ships weights in `float32`. The recommended and default runtime path
|
| 72 |
+
is `bfloat16` AMP for the main encoder, decoder, and class-token path, with
|
| 73 |
+
`float32` retained for sensitive operations such as whitening/dewhitening,
|
| 74 |
+
normalization math, RoPE frequency construction, and VP diffusion schedule
|
| 75 |
+
helpers.
|
| 76 |
+
|
| 77 |
+
## Usage
|
| 78 |
+
|
| 79 |
+
```python
|
| 80 |
+
import torch
|
| 81 |
+
|
| 82 |
+
from dinac_ae import DinacAE, DinacAEInferenceConfig
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
device = "cuda"
|
| 86 |
+
model = DinacAE.from_pretrained(
|
| 87 |
+
    "data-archetype/dinac_ae",
|
| 88 |
+
    device=device,
|
| 89 |
+
    dtype=torch.bfloat16,
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
image = ... # [1, 3, H, W] in [-1, 1], H and W divisible by 16
|
| 93 |
+
|
| 94 |
+
with torch.inference_mode():
|
| 95 |
+
    latents = model.encode(image.to(device=device, dtype=torch.bfloat16))
|
| 96 |
+
    class_token = model.predict_class(latents)
|
| 97 |
+
    recon = model.decode(
|
| 98 |
+
        latents,
|
| 99 |
+
        height=int(image.shape[-2]),
|
| 100 |
+
        width=int(image.shape[-1]),
|
| 101 |
+
        inference_config=DinacAEInferenceConfig(num_steps=1),
|
| 102 |
+
    )
|
| 103 |
+
```
|
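The usage snippet leaves `image` as a placeholder with two constraints: values in `[-1, 1]`, and height and width divisible by 16. A minimal sketch of the arithmetic behind that preparation follows, in plain Python so it runs without PyTorch. The helper names are hypothetical, the 8-bit input range is an assumption, and the latent grid size assumes one latent token per 16x16 patch, as the patch size stated under Details suggests.

```python
PATCH_SIZE = 16  # patch size stated in this README


def floor_to_patch_multiple(size: int, patch: int = PATCH_SIZE) -> int:
    """Largest multiple of `patch` that fits in `size` (e.g. for a center crop)."""
    if size < patch:
        raise ValueError(f"size {size} is smaller than one patch ({patch})")
    return (size // patch) * patch


def uint8_to_signed_unit(value: int) -> float:
    """Map an 8-bit pixel value in [0, 255] to [-1, 1] (assumed input convention)."""
    return value / 127.5 - 1.0


def latent_grid(height: int, width: int, patch: int = PATCH_SIZE) -> tuple[int, int]:
    """Spatial latent grid, assuming one latent token per patch."""
    if height % patch or width % patch:
        raise ValueError("height and width must be divisible by the patch size")
    return height // patch, width // patch
```

For example, a 1000x750 photo would be cropped to 992x736 under this scheme, giving a 62x46 latent grid.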
| 104 |
+
|
| 105 |
+
## Details
|
| 106 |
+
|
| 107 |
+
- DINAC-AE uses a `6`-block ViT/DiT-style transformer encoder and an `8`-block
|
| 108 |
+
FCDM decoder.
|
| 109 |
+
- Patch size is `16`, model width is `896`, and latent width is `128`.
|
| 110 |
+
- The DINO alignment head predicts spatial patch tokens and is extended with a
|
| 111 |
+
class-token output in DINOv3 ViT-B/16 feature space.
|
| 112 |
+
- The class-token output is used to improve semantic organization of the latent
|
| 113 |
+
space and to support FD-loss / Representation Fréchet Distance objectives
|
| 114 |
+
directly in latent space.
|
| 115 |
+
- `predict_class(latents)` reaches mean cosine similarity `0.757458` against
|
| 116 |
+
the frozen DINOv3 ViT-B/16 teacher class token on the same `2000` images.
|
| 117 |
+
- DINO alignment is applied directly to clean latent tokens. Robustness to
|
| 118 |
+
local token errors is handled by random-token logSNR offset regularization.
|
| 119 |
+
- Results viewer: https://huggingface.co/spaces/data-archetype/dinac_ae-results
|
| 120 |
+
- Related: [SemDisDiffAE](https://huggingface.co/data-archetype/semdisdiffae),
|
| 121 |
+
[full_capacitor](https://huggingface.co/data-archetype/full_capacitor),
|
| 122 |
+
[capacitor_decoder](https://huggingface.co/data-archetype/capacitor_decoder)
|
| 123 |
+
|
| 124 |
+
## Citation
|
| 125 |
+
|
| 126 |
+
```bibtex
|
| 127 |
+
@misc{dinac_ae,
|
| 128 |
+
title = {DINAC-AE: a DINO-aligned class-token diffusion autoencoder},
|
| 129 |
+
author = {data-archetype},
|
| 130 |
+
email = {data-archetype@proton.me},
|
| 131 |
+
year = {2026},
|
| 132 |
+
month = may,
|
| 133 |
+
url = {https://huggingface.co/data-archetype/dinac_ae},
|
| 134 |
+
}
|
| 135 |
+
```
|
common/norms.py
ADDED
|
@@ -0,0 +1,236 @@
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from collections.abc import Sequence
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import Tensor, nn
|
| 7 |
+
from torch.nn import functional as F
|
| 8 |
+
|
| 9 |
+
__all__ = [
|
| 10 |
+
"ChannelWiseRMSNorm",
|
| 11 |
+
"GlobalRMSNorm",
|
| 12 |
+
"GroupNormF32",
|
| 13 |
+
"LayerNorm",
|
| 14 |
+
"LayerNorm2d",
|
| 15 |
+
"RMSNorm",
|
| 16 |
+
"global_rms_norm",
|
| 17 |
+
"row_norm",
|
| 18 |
+
]
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
_HALF_PRECISION_DTYPES: tuple[torch.dtype, ...] = (torch.float16, torch.bfloat16)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _cast_to_float32(x: Tensor) -> tuple[Tensor, torch.dtype]:
|
| 25 |
+
"""Return tensor cast to fp32 for compute along with the original dtype."""
|
| 26 |
+
dtype = x.dtype
|
| 27 |
+
if dtype in _HALF_PRECISION_DTYPES:
|
| 28 |
+
return x.float(), dtype
|
| 29 |
+
return x, dtype
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _restore_dtype(x: Tensor, dtype: torch.dtype) -> Tensor:
|
| 33 |
+
return x if x.dtype == dtype else x.to(dtype)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class RMSNorm(nn.Module):
|
| 37 |
+
"""Thin wrapper around ``torch.nn.RMSNorm`` that preserves our API.
|
| 38 |
+
|
| 39 |
+
- Keeps an ``_eps`` attribute used by tests.
|
| 40 |
+
- Maps ``affine`` -> ``elementwise_affine``.
|
| 41 |
+
- Delegates all compute to the native implementation.
|
| 42 |
+
|
| 43 |
+
Notes on precision
|
| 44 |
+
- PyTorch ≥ 2.8 computes RMSNorm reductions in ``opmath`` dtype
|
| 45 |
+
(float32 for float16/bfloat16) internally, then restores the input dtype.
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
def __init__(self, dim: int, eps: float = 1e-6, affine: bool = True) -> None:
|
| 49 |
+
super().__init__()
|
| 50 |
+
self._eps: float = float(eps)
|
| 51 |
+
self._impl: nn.RMSNorm = nn.RMSNorm(
|
| 52 |
+
dim, eps=self._eps, elementwise_affine=affine
|
| 53 |
+
)
|
| 54 |
+
self._dim: int = int(dim)
|
| 55 |
+
|
| 56 |
+
@property
|
| 57 |
+
def weight(self) -> Tensor | None: # expose for tests/compat
|
| 58 |
+
return self._impl.weight
|
| 59 |
+
|
| 60 |
+
def forward(self, x: Tensor) -> Tensor: # type: ignore[override]
|
| 61 |
+
"""Apply RMSNorm while avoiding dtype-mismatch warnings under AMP.
|
| 62 |
+
|
| 63 |
+
When inputs are bfloat16/float16 under autocast and the stored affine
|
| 64 |
+
weight is float32 (common when model weights remain FP32), PyTorch emits
|
| 65 |
+
a warning about mismatched dtypes and disables the fused path.
|
| 66 |
+
|
| 67 |
+
We pass a view of the weight cast to the input dtype into the functional
|
| 68 |
+
RMSNorm to enable the fused implementation without changing the
|
| 69 |
+
parameter storage dtype (which remains FP32 for stability).
|
| 70 |
+
"""
|
| 71 |
+
# Prefer functional to control the weight dtype for the kernel
|
| 72 |
+
w: Tensor | None = self._impl.weight
|
| 73 |
+
w_cast = w.to(dtype=x.dtype) if w is not None else None
|
| 74 |
+
# Bias is not present in RMSNorm; functional takes (input, shape, weight, eps)
|
| 75 |
+
return F.rms_norm(x, (self._dim,), w_cast, self._eps)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class LayerNorm(nn.LayerNorm):
|
| 79 |
+
"""Thin wrapper over ``torch.nn.LayerNorm`` with an ``_eps`` attribute.
|
| 80 |
+
|
| 81 |
+
Notes on precision
|
| 82 |
+
- Native LayerNorm kernels accumulate statistics in ``opmath`` dtype
|
| 83 |
+
(float32 for float16/bfloat16) before casting results back.
|
| 84 |
+
"""
|
| 85 |
+
|
| 86 |
+
def __init__(
|
| 87 |
+
self,
|
| 88 |
+
normalized_shape: int | Sequence[int],
|
| 89 |
+
eps: float = 1e-6,
|
| 90 |
+
elementwise_affine: bool = True,
|
| 91 |
+
) -> None:
|
| 92 |
+
shape: int | list[int]
|
| 93 |
+
match normalized_shape:
|
| 94 |
+
case int() as dim:
|
| 95 |
+
shape = dim
|
| 96 |
+
case _:
|
| 97 |
+
shape = [int(v) for v in normalized_shape]
|
| 98 |
+
super().__init__(shape, eps=eps, elementwise_affine=elementwise_affine)
|
| 99 |
+
self._eps: float = float(eps)
|
| 100 |
+
|
| 101 |
+
def forward(self, x: Tensor) -> Tensor: # type: ignore[override]
|
| 102 |
+
# Delegate to native LayerNorm
|
| 103 |
+
return super().forward(x)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# Prefer numerically stable GroupNormF32 below.
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class GroupNormF32(nn.GroupNorm):
|
| 110 |
+
"""Thin wrapper over ``torch.nn.GroupNorm`` with an ``_eps`` attribute.
|
| 111 |
+
|
| 112 |
+
Notes on precision
|
| 113 |
+
- Native GroupNorm uses ``opmath`` accumulation (float32 for
|
| 114 |
+
float16/bfloat16) for statistics and fused scale/bias math; results
|
| 115 |
+
are cast back to the input dtype.
|
| 116 |
+
- Despite the class name, this wrapper does not force a cast; it
|
| 117 |
+
delegates to the native implementation.
|
| 118 |
+
"""
|
| 119 |
+
|
| 120 |
+
def __init__(
|
| 121 |
+
self,
|
| 122 |
+
num_groups: int,
|
| 123 |
+
num_channels: int,
|
| 124 |
+
eps: float = 1e-6,
|
| 125 |
+
affine: bool = True,
|
| 126 |
+
) -> None:
|
| 127 |
+
super().__init__(num_groups, num_channels, eps=eps, affine=affine)
|
| 128 |
+
self._eps: float = float(eps)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class ChannelWiseRMSNorm(nn.Module):
|
| 132 |
+
"""Channel-wise RMSNorm for NCHW tensors (fast NCHW path).
|
| 133 |
+
|
| 134 |
+
- Normalizes across channels per spatial position without reshaping, using
|
| 135 |
+
a float32 reduction for numerical stability and keeping elementwise ops
|
| 136 |
+
in input dtype for throughput.
|
| 137 |
+
- Supports optional per-channel affine ``weight`` and ``bias``.
|
| 138 |
+
"""
|
| 139 |
+
|
| 140 |
+
def __init__(self, channels: int, eps: float = 1e-6, affine: bool = True) -> None:
|
| 141 |
+
super().__init__()
|
| 142 |
+
self.channels: int = int(channels)
|
| 143 |
+
self._eps: float = float(eps)
|
| 144 |
+
self.affine: bool = bool(affine)
|
| 145 |
+
if self.affine:
|
| 146 |
+
self.weight = nn.Parameter(torch.ones(self.channels))
|
| 147 |
+
self.bias = nn.Parameter(torch.zeros(self.channels))
|
| 148 |
+
else:
|
| 149 |
+
self.register_parameter("weight", None)
|
| 150 |
+
self.register_parameter("bias", None)
|
| 151 |
+
|
| 152 |
+
def forward(self, x: Tensor) -> Tensor: # type: ignore[override]
|
| 153 |
+
if x.dim() < 2:
|
| 154 |
+
return x
|
| 155 |
+
C = x.size(1)
|
| 156 |
+
if self.channels != C:
|
| 157 |
+
raise ValueError(f"ChannelWiseRMSNorm expected C={self.channels}, got {C}")
|
| 158 |
+
# Keep only the reductions in fp32; scale/apply in the input dtype.
|
| 159 |
+
ms = torch.mean(torch.square(x), dim=1, keepdim=True, dtype=torch.float32)
|
| 160 |
+
inv_rms = torch.rsqrt(ms + self._eps) # float32
|
| 161 |
+
y = x * inv_rms.to(dtype=x.dtype)
|
| 162 |
+
if self.affine and self.weight is not None:
|
| 163 |
+
shape = (1, -1) + (1,) * (x.dim() - 2)
|
| 164 |
+
y = y * self.weight.view(shape).to(dtype=x.dtype)
|
| 165 |
+
if self.bias is not None:
|
| 166 |
+
y = y + self.bias.view(shape).to(dtype=x.dtype)
|
| 167 |
+
return y
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def global_rms_norm(x: Tensor, eps: float = 1e-6) -> Tensor:
|
| 171 |
+
"""Project each sample to unit RMS across all non-batch dimensions.
|
| 172 |
+
|
| 173 |
+
This is equivalent to RMSNorm with ``normalized_shape=x.shape[1:]`` and no
|
| 174 |
+
affine parameters. Delegating to the native functional keeps the fast fused
|
| 175 |
+
CUDA path and the same opmath accumulation behavior as ``torch.nn.RMSNorm``.
|
| 176 |
+
"""
|
| 177 |
+
|
| 178 |
+
if x.dim() < 2:
|
| 179 |
+
return x
|
| 180 |
+
normalized_shape = tuple(int(dim) for dim in x.shape[1:])
|
| 181 |
+
return F.rms_norm(x, normalized_shape, None, eps)
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class GlobalRMSNorm(nn.Module):
|
| 185 |
+
"""RMSNorm across all dims except batch — sphere projection for NCHW tensors.
|
| 186 |
+
|
| 187 |
+
Unlike :class:`ChannelWiseRMSNorm` (which normalizes per spatial position
|
| 188 |
+
over channels), this normalizes the *entire* feature volume jointly,
|
| 189 |
+
projecting each sample onto a hypersphere. No learnable parameters.
|
| 190 |
+
"""
|
| 191 |
+
|
| 192 |
+
def __init__(self, eps: float = 1e-6) -> None:
|
| 193 |
+
super().__init__()
|
| 194 |
+
self._eps: float = float(eps)
|
| 195 |
+
|
| 196 |
+
def forward(self, x: Tensor) -> Tensor: # type: ignore[override]
|
| 197 |
+
return global_rms_norm(x, eps=self._eps)
|
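To make the difference between `ChannelWiseRMSNorm` and `GlobalRMSNorm` concrete, here is a small plain-Python sketch on a single sample stored as nested lists `[C][H][W]`. It illustrates the math only (no affine parameters, no dtype handling) and is not the module code; both function names are hypothetical.

```python
import math

Volume = list[list[list[float]]]  # one sample, indexed [channel][row][col]


def channelwise_rms_norm(x: Volume, eps: float = 1e-6) -> Volume:
    """Normalize each spatial position by the RMS over its channels."""
    c, h, w = len(x), len(x[0]), len(x[0][0])
    out = [[[0.0] * w for _ in range(h)] for _ in range(c)]
    for i in range(h):
        for j in range(w):
            ms = sum(x[k][i][j] ** 2 for k in range(c)) / c
            inv_rms = 1.0 / math.sqrt(ms + eps)
            for k in range(c):
                out[k][i][j] = x[k][i][j] * inv_rms
    return out


def global_rms_norm_ref(x: Volume, eps: float = 1e-6) -> Volume:
    """Normalize the whole sample by one RMS over all channels and positions."""
    values = [v for plane in x for row in plane for v in row]
    ms = sum(v * v for v in values) / len(values)
    inv_rms = 1.0 / math.sqrt(ms + eps)
    return [[[v * inv_rms for v in row] for row in plane] for plane in x]


x = [[[3.0, 0.3]], [[4.0, 0.4]]]  # C=2, H=1, W=2
print(channelwise_rms_norm(x))  # both positions end up with nearly the same values
print(global_rms_norm_ref(x))   # the 10x magnitude gap between positions is kept
```

The channel-wise variant discards per-position magnitude, while the global variant keeps relative magnitudes and only fixes the overall scale of the sample.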
| 198 |
+
|
| 199 |
+
|
| 200 |
+
class LayerNorm2d(nn.LayerNorm):
|
| 201 |
+
"""Channel-wise LayerNorm using native ``F.layer_norm`` on a reshaped view.
|
| 202 |
+
|
| 203 |
+
- Normalizes over channels only for each spatial location (B, h, w).
|
| 204 |
+
- Weight and bias follow the base class semantics (shape [C]).
|
| 205 |
+
|
| 206 |
+
Notes on precision
|
| 207 |
+
- ``F.layer_norm`` calls the native LayerNorm kernel which accumulates in
|
| 208 |
+
``opmath`` dtype (float32 for float16/bfloat16), then casts back.
|
| 209 |
+
"""
|
| 210 |
+
|
| 211 |
+
def forward(self, x: Tensor) -> Tensor: # type: ignore[override]
|
| 212 |
+
if x.dim() < 3:
|
| 213 |
+
return super().forward(x)
|
| 214 |
+
B, C = x.shape[:2]
|
| 215 |
+
spatial = x.shape[2:]
|
| 216 |
+
x_view = x.permute(0, *range(2, x.dim()), 1).contiguous().view(-1, C)
|
| 217 |
+
y = F.layer_norm(x_view, (C,), self.weight, self.bias, self.eps)
|
| 218 |
+
y = y.view(B, *spatial, C).permute(0, x.dim() - 1, *range(1, x.dim() - 1))
|
| 219 |
+
return y.contiguous()
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def row_norm(W: Tensor, eps: float = 1e-6) -> Tensor:
|
| 223 |
+
"""Row-normalise weight matrices along the last dimension.
|
| 224 |
+
|
| 225 |
+
Precision and performance
|
| 226 |
+
- Accumulates the squared sum in float32 without materializing a full fp32
|
| 227 |
+
copy of ``W`` via ``sum(..., dtype=torch.float32)``.
|
| 228 |
+
- Uses ``rsqrt`` and clamps the inverse norm via ``clamp_max(1/eps)`` to
|
| 229 |
+
match ``clamp_min(eps)`` on the denominator.
|
| 230 |
+
- Scales in the input dtype for throughput; callers relying on exact
|
| 231 |
+
float32 scaling should cast explicitly.
|
| 232 |
+
"""
|
| 233 |
+
# Sum of squares in fp32 for stability
|
| 234 |
+
ss = torch.sum(torch.square(W), dim=-1, keepdim=True, dtype=torch.float32)
|
| 235 |
+
inv = torch.rsqrt(ss).clamp_max(1.0 / float(eps)) # float32
|
| 236 |
+
return W * inv.to(dtype=W.dtype)
|
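The docstring's claim that `clamp_max(1/eps)` on the inverse norm matches `clamp_min(eps)` on the denominator can be checked on one row with plain Python. This is a sketch of the math only; `row_norm_ref` is a hypothetical name, and the real function operates on tensors.

```python
import math


def row_norm_ref(row: list[float], eps: float = 1e-6) -> list[float]:
    """Scale a row to unit L2 norm, capping the inverse norm at 1/eps."""
    ss = sum(v * v for v in row)
    # rsqrt(0) is +inf in floating point; the cap turns that into 1/eps.
    inv = 1.0 / math.sqrt(ss) if ss > 0.0 else math.inf
    inv = min(inv, 1.0 / eps)  # same effect as dividing by max(norm, eps)
    return [v * inv for v in row]


print(row_norm_ref([3.0, 4.0]))  # unit-norm direction of the row
print(row_norm_ref([0.0, 0.0]))  # an all-zero row stays zero instead of becoming NaN
```

Rows whose norm is below `eps` are scaled by `1/eps` rather than to unit length, which is the same result a `max(norm, eps)` denominator would give.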
common/rope.py
ADDED
|
@@ -0,0 +1,536 @@
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from collections.abc import Callable
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch import nn
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Rope1D(nn.Module):
|
| 11 |
+
"""
|
| 12 |
+
Rotary Position Embedding (RoPE) 1D.
|
| 13 |
+
|
| 14 |
+
Based on the reference LLaMA implementation (Hugging Face
|
| 15 |
+
`modeling_llama.py`), adapted to this codebase without behavior changes.
|
| 16 |
+
|
| 17 |
+
- dim: per-head dimension
|
| 18 |
+
- max_position_embeddings: length used to precompute cached cos/sin (not required
|
| 19 |
+
by forward)
|
| 20 |
+
- base: RoPE base theta
|
| 21 |
+
|
| 22 |
+
Forward expects:
|
| 23 |
+
- x: (B, H, T, D)
|
| 24 |
+
- position_ids: (B, T) integer positions
|
| 25 |
+
Returns:
|
| 26 |
+
- cos, sin: (B, T, D)
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
inv_freq: torch.Tensor
|
| 30 |
+
_cos_cached: torch.Tensor
|
| 31 |
+
_sin_cached: torch.Tensor
|
| 32 |
+
|
| 33 |
+
def __init__(
|
| 34 |
+
self,
|
| 35 |
+
dim: int,
|
| 36 |
+
max_position_embeddings: int = 2048,
|
| 37 |
+
base: float = 10000.0,
|
| 38 |
+
device: torch.device | None = None,
|
| 39 |
+
scaling_factor: float = 1.0,
|
| 40 |
+
) -> None:
|
| 41 |
+
super().__init__()
|
| 42 |
+
if dim % 2 != 0:
|
| 43 |
+
raise AssertionError("head_dim must be even for RoPE")
|
| 44 |
+
self.scaling_factor: float = float(scaling_factor)
|
| 45 |
+
self.dim: int = int(dim)
|
| 46 |
+
self.max_position_embeddings: int = int(max_position_embeddings)
|
| 47 |
+
self.base: float = float(base)
|
| 48 |
+
inv_freq = self._build_inv_freq(device=device)
|
| 49 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 50 |
+
|
| 51 |
+
# Cached cos/sin (not used in application, but kept for parity with reference)
|
| 52 |
+
self.max_seq_len_cached: int = self.max_position_embeddings
|
| 53 |
+
cos_cached, sin_cached = self._build_cached_trig(device=device)
|
| 54 |
+
self.register_buffer("_cos_cached", cos_cached, persistent=False)
|
| 55 |
+
self.register_buffer("_sin_cached", sin_cached, persistent=False)
|
| 56 |
+
|
| 57 |
+
def _build_inv_freq(self, *, device: torch.device | None) -> torch.Tensor:
|
| 58 |
+
"""Return the RoPE inverse-frequency vector in float32."""
|
| 59 |
+
|
| 60 |
+
return 1.0 / (
|
| 61 |
+
self.base
|
| 62 |
+
** (
|
| 63 |
+
torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
|
| 64 |
+
/ float(self.dim)
|
| 65 |
+
)
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
def _build_cached_trig(
|
| 69 |
+
self, *, device: torch.device | None
|
| 70 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 71 |
+
"""Return cached RoPE trig tensors in float32."""
|
| 72 |
+
|
| 73 |
+
inv_freq = self._build_inv_freq(device=device)
|
| 74 |
+
t = torch.arange(
|
| 75 |
+
self.max_seq_len_cached,
|
| 76 |
+
device=device,
|
| 77 |
+
dtype=torch.float32,
|
| 78 |
+
)
|
| 79 |
+
t = t / self.scaling_factor
|
| 80 |
+
freqs = torch.outer(t, inv_freq)
|
| 81 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 82 |
+
return emb.cos(), emb.sin()
|
| 83 |
+
|
| 84 |
+
def _apply(
|
| 85 |
+
self,
|
| 86 |
+
fn: Callable[[torch.Tensor], torch.Tensor],
|
| 87 |
+
recurse: bool = True,
|
| 88 |
+
) -> Rope1D:
|
| 89 |
+
"""Apply module moves/casts while preserving fp32 RoPE buffers."""
|
| 90 |
+
|
| 91 |
+
out = super()._apply(fn, recurse=recurse)
|
| 92 |
+
with torch.no_grad():
|
| 93 |
+
device = self.inv_freq.device
|
| 94 |
+
self.inv_freq.data = self._build_inv_freq(device=device)
|
| 95 |
+
cos_cached, sin_cached = self._build_cached_trig(device=device)
|
| 96 |
+
self._cos_cached.data = cos_cached
|
| 97 |
+
self._sin_cached.data = sin_cached
|
| 98 |
+
return out
|
| 99 |
+
|
| 100 |
+
@torch.no_grad()
|
| 101 |
+
def forward(
|
| 102 |
+
self, x: torch.Tensor, position_ids: torch.Tensor
|
| 103 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 104 |
+
inv_freq_tensor = self._build_inv_freq(device=x.device)
|
| 105 |
+
inv_freq_expanded = (
|
| 106 |
+
inv_freq_tensor[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
| 107 |
+
)
|
| 108 |
+
position_ids_expanded = position_ids[:, None, :].float() / self.scaling_factor
|
| 109 |
+
device_type = x.device.type
|
| 110 |
+
device_type = (
|
| 111 |
+
device_type
|
| 112 |
+
if isinstance(device_type, str) and device_type != "mps"
|
| 113 |
+
else "cpu"
|
| 114 |
+
)
|
| 115 |
+
with torch.autocast(device_type=device_type, enabled=False):
|
| 116 |
+
freqs = (
|
| 117 |
+
inv_freq_expanded.float() @ position_ids_expanded.float()
|
| 118 |
+
).transpose(1, 2)
|
| 119 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 120 |
+
cos = emb.cos()
|
| 121 |
+
sin = emb.sin()
|
| 122 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
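A plain-Python sketch of what `Rope1D.forward` computes for a single position may help here. It covers the inverse frequencies, the angle per frequency, and the half-split layout in which the angle vector is concatenated with itself. It mirrors the formulas in the code above but skips batching, dtype casts, and autocast handling; the function names are hypothetical.

```python
import math


def rope_inv_freq(dim: int, base: float = 10000.0) -> list[float]:
    """Inverse frequencies 1 / base^(i/dim) for i = 0, 2, 4, ..., dim - 2."""
    if dim % 2 != 0:
        raise ValueError("dim must be even for RoPE")
    return [1.0 / base ** (i / dim) for i in range(0, dim, 2)]


def rope_cos_sin(
    position: float,
    dim: int,
    base: float = 10000.0,
    scaling_factor: float = 1.0,
) -> tuple[list[float], list[float]]:
    """cos/sin tables of length `dim` for one position (half-split layout)."""
    scaled = position / scaling_factor
    angles = [scaled * f for f in rope_inv_freq(dim, base)]
    emb = angles + angles  # plays the role of torch.cat((freqs, freqs), dim=-1)
    return [math.cos(a) for a in emb], [math.sin(a) for a in emb]
```

At position 0 every angle is zero, so `cos` is all ones and `sin` is all zeros, which leaves queries and keys unrotated.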
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def rotate_half(x: torch.Tensor) -> torch.Tensor:
|
| 126 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 127 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 128 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def rotate_half_adjacent(x: torch.Tensor) -> torch.Tensor:
|
| 132 |
+
"""Rotate consecutive pairs in the last dimension.
|
| 133 |
+
|
| 134 |
+
This matches the common EVA-02 / SpeedrunDiT RoPE convention where the last
|
| 135 |
+
dimension is interpreted as pairs ``(x0, x1), (x2, x3), ...``.
|
| 136 |
+
"""
|
| 137 |
+
if x.shape[-1] % 2 != 0:
|
| 138 |
+
raise ValueError("rotate_half_adjacent requires an even last dimension")
|
| 139 |
+
x_pairs = x.reshape(*x.shape[:-1], x.shape[-1] // 2, 2)
|
| 140 |
+
x1 = x_pairs[..., 0]
|
| 141 |
+
x2 = x_pairs[..., 1]
|
| 142 |
+
return torch.stack((-x2, x1), dim=-1).reshape_as(x)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def apply_rotary_pos_emb(
|
| 146 |
+
q: torch.Tensor,
|
| 147 |
+
k: torch.Tensor,
|
| 148 |
+
cos: torch.Tensor,
|
| 149 |
+
sin: torch.Tensor,
|
| 150 |
+
*,
|
| 151 |
+
unsqueeze_dim: int = 1,
|
| 152 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 153 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
| 154 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
| 155 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 156 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 157 |
+
return q_embed, k_embed
|
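The two pairing conventions are easy to confuse, so here is a plain-Python sketch on a single vector. `rotate_half` pairs element `i` with element `i + D/2`, while `rotate_half_adjacent` pairs neighbours. The last function applies the same formula as `apply_rotary_pos_emb` to one vector. These are illustrations with hypothetical `_ref` names, not the tensor implementation.

```python
import math


def rotate_half_ref(x: list[float]) -> list[float]:
    """Half-split convention: halves (x1, x2) become (-x2, x1)."""
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]


def rotate_half_adjacent_ref(x: list[float]) -> list[float]:
    """Adjacent convention: each pair (a, b) becomes (-b, a)."""
    if len(x) % 2 != 0:
        raise ValueError("requires an even length")
    out: list[float] = []
    for a, b in zip(x[0::2], x[1::2]):
        out += [-b, a]
    return out


def apply_rope_ref(x: list[float], cos: list[float], sin: list[float]) -> list[float]:
    """x * cos + rotate_half(x) * sin, elementwise."""
    rot = rotate_half_ref(x)
    return [v * c + r * s for v, r, c, s in zip(x, rot, cos, sin)]


x = [1.0, 2.0, 3.0, 4.0]
print(rotate_half_ref(x))           # [-3.0, -4.0, 1.0, 2.0]
print(rotate_half_adjacent_ref(x))  # [-2.0, 1.0, -4.0, 3.0]
```

With `cos` and `sin` built in the half-split layout (each angle repeated at `i` and `i + D/2`), `apply_rope_ref` is a plane rotation of each pair, so it preserves the vector's norm. Mixing a half-split table with the adjacent rotation would pair the wrong elements.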
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class LearnableRoPE2D(nn.Module):
|
| 161 |
+
r"""
|
| 162 |
+
Learnable mixed 2D RoPE with axial RoPE2D-compatible initialization.
|
| 163 |
+
|
| 164 |
+
- Learnable frequency banks for X and Y.
|
| 165 |
+
- Frequencies can be shared across groups of attention heads (see
|
| 166 |
+
``rope_param_dim``).
|
| 167 |
+
- Angle per pair: theta = x * fx[g, i] + y * fy[g, i]
|
| 168 |
+
- Initialization matches the axial RoPE2D parameterization used by DiTTrunk
|
| 169 |
+
for ``ROPE_2D_AXIAL_FREQ_AWARE`` (AxialRoPE2DConfig(base=100, dim_layout=HALF_SPLIT)):
|
| 170 |
+
- Angle multiplier ``2π``.
|
| 171 |
+
- Period base ``100`` (DINOv3-style), applied per-axis.
|
| 172 |
+
Each head group starts identically (deterministic init) so the learnable
|
| 173 |
+
variant is functionally identical to axial RoPE2D at step 0.
|
| 174 |
+
- Rotation is implemented with real-valued sin/cos to avoid complex tensors
|
| 175 |
+
(torch.compile/inductor cannot codegen complex dtypes).
|
| 176 |
+
|
| 177 |
+
Shapes:
|
| 178 |
+
- Expects q,k of shape (B, H, T, D) with D % 4 == 0.
|
| 179 |
+
- Positions xy: (T, 2) or (B, T, 2), any real dtype (cast to float32).
|
| 180 |
+
- Parameter `freqs`: (2, G, D//2) in float32; index 0 = x, 1 = y.
|
| 181 |
+
|
| 182 |
+
Head grouping / parameter budget
|
| 183 |
+
-------------------------------
|
| 184 |
+
``rope_param_dim`` controls the total number of learned RoPE frequency
|
| 185 |
+
parameters (scalars) for this module.
|
| 186 |
+
|
| 187 |
+
Let:
|
| 188 |
+
- ``head_dim = D`` (per-head width)
|
| 189 |
+
- ``num_heads = H``
|
| 190 |
+
- ``rope_param_dim = P``
|
| 191 |
+
|
| 192 |
+
Then the module uses:
|
| 193 |
+
- ``num_groups = G = P // D``
|
| 194 |
+
- ``heads_per_group = H // G``
|
| 195 |
+
|
| 196 |
+
This is fail-fast: ``P`` must be divisible by ``D`` and ``H`` must be
|
| 197 |
+
divisible by ``G``. When ``rope_param_dim`` is None (default), the module
|
| 198 |
+
uses the classic per-head parameterization with ``P = H * D``.
|
| 199 |
+
"""
|
| 200 |
+
|
| 201 |
+
def __init__(
|
| 202 |
+
self,
|
| 203 |
+
head_dim: int,
|
| 204 |
+
*,
|
| 205 |
+
num_heads: int,
|
| 206 |
+
rope_param_dim: int | None = None,
|
| 207 |
+
rope_base: float = 100.0,
|
| 208 |
+
angle_multiplier: float = 2.0 * float(math.pi),
|
| 209 |
+
learnable: bool = True,
|
| 210 |
+
persist_buffers: bool = True,
|
| 211 |
+
) -> None:
|
| 212 |
+
super().__init__()
|
| 213 |
+
if head_dim % 4 != 0:
|
| 214 |
+
raise AssertionError("head_dim must be divisible by 4 for mixed 2D RoPE")
|
| 215 |
+
self.head_dim: int = int(head_dim)
|
| 216 |
+
# Avoid naming collisions with nn.Module.half() (dtype casting helper).
|
| 217 |
+
self.half_dim: int = self.head_dim // 2
|
| 218 |
+
self.num_heads: int = int(num_heads)
|
| 219 |
+
effective_param_dim = (
|
| 220 |
+
int(rope_param_dim)
|
| 221 |
+
if rope_param_dim is not None
|
| 222 |
+
else self.num_heads * self.head_dim
|
| 223 |
+
)
|
| 224 |
+
if effective_param_dim <= 0:
|
| 225 |
+
raise ValueError("rope_param_dim must be positive for LearnableRoPE2D")
|
| 226 |
+
self.rope_param_dim: int = int(effective_param_dim)
|
| 227 |
+
self._learnable: bool = bool(learnable)
|
| 228 |
+
theta = float(rope_base)
|
| 229 |
+
mult = float(angle_multiplier)
|
| 230 |
+
if not math.isfinite(theta) or theta <= 0.0:
|
| 231 |
+
raise ValueError("rope_base must be finite and > 0 for LearnableRoPE2D")
|
| 232 |
+
if not math.isfinite(mult) or mult <= 0.0:
|
| 233 |
+
raise ValueError(
|
| 234 |
+
"angle_multiplier must be finite and > 0 for LearnableRoPE2D"
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
if self.rope_param_dim % self.head_dim != 0:
|
| 238 |
+
raise ValueError(
|
| 239 |
+
"rope_param_dim must be divisible by head_dim for LearnableRoPE2D "
|
| 240 |
+
f"(got rope_param_dim={self.rope_param_dim}, head_dim={self.head_dim})"
|
| 241 |
+
)
|
| 242 |
+
self.num_groups: int = self.rope_param_dim // self.head_dim
|
| 243 |
+
if self.num_groups <= 0:
|
| 244 |
+
raise RuntimeError("num_groups must be positive for LearnableRoPE2D")
|
| 245 |
+
if self.num_heads % self.num_groups != 0:
|
| 246 |
+
raise ValueError(
|
| 247 |
+
"num_heads must be divisible by (rope_param_dim / head_dim) for LearnableRoPE2D "
|
| 248 |
+
f"(got num_heads={self.num_heads}, num_groups={self.num_groups}, "
|
| 249 |
+
f"rope_param_dim={self.rope_param_dim}, head_dim={self.head_dim})"
|
| 250 |
+
)
|
| 251 |
+
self.heads_per_group: int = self.num_heads // self.num_groups
|
| 252 |
+
if self.heads_per_group <= 0:
|
| 253 |
+
raise RuntimeError("heads_per_group must be positive for LearnableRoPE2D")
|
| 254 |
+
|
| 255 |
+
# Axial-compatible deterministic init:
|
| 256 |
+
# - periods match AxialRoPE2DConfig(base=100, dim_layout=HALF_SPLIT)
|
| 257 |
+
# - angle = 2π * coord / period
|
| 258 |
+
qtr = self.head_dim // 4
|
| 259 |
+
exponents = (
|
| 260 |
+
2.0
|
| 261 |
+
* torch.arange(int(qtr), dtype=torch.float32)
|
| 262 |
+
/ float(self.head_dim // 2)
|
| 263 |
+
)
|
| 264 |
+
periods = torch.tensor(theta, dtype=torch.float32) ** exponents # [qtr]
|
| 265 |
+
axis_freqs = (mult / periods).to(dtype=torch.float32) # [qtr]
|
| 266 |
+
|
| 267 |
+
zeros = torch.zeros_like(axis_freqs)
|
| 268 |
+
# Match AxialRoPE2D(HALF_SPLIT) flatten order: [y-axis, x-axis].
|
| 269 |
+
# Our xy columns are (x, y), so:
|
| 270 |
+
# - x contributes to the second quarter (x-axis part)
|
| 271 |
+
# - y contributes to the first quarter (y-axis part)
|
| 272 |
+
fx_half = torch.cat((zeros, axis_freqs), dim=0) # [half_dim]
|
| 273 |
+
fy_half = torch.cat((axis_freqs, zeros), dim=0) # [half_dim]
|
| 274 |
+
|
| 275 |
+
freqs_x = fx_half.expand(int(self.num_groups), -1).clone()
|
| 276 |
+
freqs_y = fy_half.expand(int(self.num_groups), -1).clone()
|
| 277 |
+
freqs = torch.stack([freqs_x, freqs_y], dim=0) # (2, G, half)
|
| 278 |
+
if self._learnable:
|
| 279 |
+
self.freqs = nn.Parameter(freqs, requires_grad=True)
|
| 280 |
+
else:
|
| 281 |
+
self.register_buffer("freqs", freqs, persistent=persist_buffers)
|
| 282 |
+
|
| 283 |
+
def _apply(
|
| 284 |
+
self,
|
| 285 |
+
fn: Callable[[torch.Tensor], torch.Tensor],
|
| 286 |
+
recurse: bool = True,
|
| 287 |
+
) -> LearnableRoPE2D:
|
| 288 |
+
"""Apply module moves/casts while preserving fp32 frequency tensors."""
|
| 289 |
+
|
| 290 |
+
out = super()._apply(fn, recurse=recurse)
|
| 291 |
+
with torch.no_grad():
|
| 292 |
+
self.freqs.data = self.freqs.data.to(dtype=torch.float32)
|
| 293 |
+
return out
|
| 294 |
+
|
| 295 |
+
def _apply_rotary_from_trig(
|
| 296 |
+
self,
|
| 297 |
+
x: torch.Tensor,
|
| 298 |
+
*,
|
| 299 |
+
sin: torch.Tensor,
|
| 300 |
+
cos: torch.Tensor,
|
| 301 |
+
) -> torch.Tensor:
|
| 302 |
+
"""Rotate Q/K using precomputed grouped sin/cos buffers (HALF_SPLIT layout).
|
| 303 |
+
|
| 304 |
+
This matches AxialRoPE2DConfig(dim_layout=HALF_SPLIT) rotation and keeps
|
| 305 |
+
the learnable variant identical at initialization when combined with
|
| 306 |
+
axial-compatible frequency init.
|
| 307 |
+
|
| 308 |
+
Args:
|
| 309 |
+
x: Tensor shaped ``(B, H, T, D)``.
|
| 310 |
+
sin: Sin tensor shaped ``(G, T, D//2)`` or ``(B, G, T, D//2)``.
|
| 311 |
+
cos: Cos tensor shaped ``(G, T, D//2)`` or ``(B, G, T, D//2)``.
|
| 312 |
+
|
| 313 |
+
Returns:
|
| 314 |
+
Tensor with the same shape/dtype/device as ``x``.
|
| 315 |
+
"""
|
| 316 |
+
if x.dim() != 4:
|
| 317 |
+
raise ValueError("x must be shaped (B, H, T, D)")
|
| 318 |
+
B, H, T, D = x.shape
|
| 319 |
+
if self.num_heads != int(H):
|
| 320 |
+
raise ValueError("num_heads mismatch for LearnableRoPE2D")
|
| 321 |
+
if self.head_dim != int(D):
|
| 322 |
+
raise ValueError("head_dim mismatch for LearnableRoPE2D")
|
| 323 |
+
|
| 324 |
+
if sin.dim() == 3 and cos.dim() == 3:
|
| 325 |
+
sin = sin.unsqueeze(0)
|
| 326 |
+
cos = cos.unsqueeze(0)
|
| 327 |
+
if sin.dim() != 4 or cos.dim() != 4:
|
| 328 |
+
raise RuntimeError("Unexpected sin/cos rank for LearnableRoPE2D")
|
| 329 |
+
if int(D) % 2 != 0:
|
| 330 |
+
raise RuntimeError("LearnableRoPE2D requires even head_dim for HALF_SPLIT")
|
| 331 |
+
half = int(D) // 2
|
| 332 |
+
if int(sin.shape[-1]) != half or int(cos.shape[-1]) != half:
|
| 333 |
+
raise RuntimeError(
|
| 334 |
+
"LearnableRoPE2D expected sin/cos last dim == head_dim//2 "
|
| 335 |
+
f"(got sin={tuple(sin.shape)}, cos={tuple(cos.shape)}, head_dim={int(D)})"
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
sin = sin[:, :, None, :, :] # [B, G, 1, T, half]
|
| 339 |
+
cos = cos[:, :, None, :, :] # [B, G, 1, T, half]
|
| 340 |
+
|
| 341 |
+
grouped = x.reshape(
|
| 342 |
+
int(B),
|
| 343 |
+
int(self.num_groups),
|
| 344 |
+
int(self.heads_per_group),
|
| 345 |
+
int(T),
|
| 346 |
+
int(D),
|
| 347 |
+
)
|
| 348 |
+
x1 = grouped[..., :half]
|
| 349 |
+
x2 = grouped[..., half:]
|
| 350 |
+
out1 = x1 * cos - x2 * sin
|
| 351 |
+
out2 = x2 * cos + x1 * sin
|
| 352 |
+
out = torch.cat((out1, out2), dim=-1).reshape(int(B), int(H), int(T), int(D))
|
| 353 |
+
return out.to(dtype=x.dtype)
|
| 354 |
+
|
| 355 |
+
def _compute_mixed_cis(self, xy: torch.Tensor) -> torch.Tensor:
|
| 356 |
+
# Returns complex unit phasors exp(i * angle) with shape (G, T, half) or (B, G, T, half)
|
| 357 |
+
if xy.dim() == 2:
|
| 358 |
+
# (T, 2) -> (G, T, half)
|
| 359 |
+
t_x = xy[:, 0].to(dtype=torch.float32)
|
| 360 |
+
t_y = xy[:, 1].to(dtype=torch.float32)
|
| 361 |
+
with torch.autocast(device_type=t_x.device.type, enabled=False):
|
| 362 |
+
# Memory notes:
|
| 363 |
+
# - Avoid materializing both fx and fy; accumulate in-place into angles.
|
| 364 |
+
# - Avoid torch.ones_like(angles) (full-size allocation); a scalar
|
| 365 |
+
# magnitude broadcasts in torch.polar.
|
| 366 |
+
angles = t_x.unsqueeze(-1).unsqueeze(-1) * self.freqs[0].unsqueeze(
|
| 367 |
+
0
|
| 368 |
+
) # (T, G, half)
|
| 369 |
+
angles.add_(
|
| 370 |
+
t_y.unsqueeze(-1).unsqueeze(-1) * self.freqs[1].unsqueeze(0)
|
| 371 |
+
)
|
| 372 |
+
angles = angles.permute(1, 0, 2) # (G, T, half)
|
| 373 |
+
cis = torch.polar(
|
| 374 |
+
torch.ones((), device=angles.device, dtype=angles.dtype), angles
|
| 375 |
+
)
|
| 376 |
+
return cis
|
| 377 |
+
elif xy.dim() == 3:
|
| 378 |
+
# (B, T, 2) -> (B, G, T, half)
|
| 379 |
+
t_x = xy[..., 0].to(dtype=torch.float32)
|
| 380 |
+
t_y = xy[..., 1].to(dtype=torch.float32)
|
| 381 |
+
with torch.autocast(device_type=t_x.device.type, enabled=False):
|
| 382 |
+
angles = t_x.unsqueeze(-1).unsqueeze(-1) * self.freqs[0].unsqueeze(
|
| 383 |
+
0
|
| 384 |
+
).unsqueeze(0)
|
| 385 |
+
angles.add_(
|
| 386 |
+
t_y.unsqueeze(-1).unsqueeze(-1)
|
| 387 |
+
* self.freqs[1].unsqueeze(0).unsqueeze(0)
|
| 388 |
+
)
|
| 389 |
+
angles = angles.permute(0, 2, 1, 3) # (B, G, T, half)
|
| 390 |
+
cis = torch.polar(
|
| 391 |
+
torch.ones((), device=angles.device, dtype=angles.dtype), angles
|
| 392 |
+
)
|
| 393 |
+
return cis
|
| 394 |
+
else:
|
| 395 |
+
raise ValueError("xy must have shape (T,2) or (B,T,2)")
|
| 396 |
+
|
| 397 |
+
def _compute_mixed_angles(self, xy: torch.Tensor) -> torch.Tensor:
|
| 398 |
+
"""Return mixed RoPE2D angles without applying cis/polar.
|
| 399 |
+
|
| 400 |
+
Args:
|
| 401 |
+
xy: XY positions shaped ``(T, 2)`` or ``(B, T, 2)``.
|
| 402 |
+
|
| 403 |
+
Returns:
|
| 404 |
+
Float tensor of angles shaped ``(G, T, half)`` or ``(B, G, T, half)``.
|
| 405 |
+
"""
|
| 406 |
+
if xy.dim() == 2:
|
| 407 |
+
t_x = xy[:, 0].to(dtype=torch.float32)
|
| 408 |
+
t_y = xy[:, 1].to(dtype=torch.float32)
|
| 409 |
+
with torch.autocast(device_type=t_x.device.type, enabled=False):
|
| 410 |
+
angles = t_x.unsqueeze(-1).unsqueeze(-1) * self.freqs[0].unsqueeze(0)
|
| 411 |
+
angles.add_(
|
| 412 |
+
t_y.unsqueeze(-1).unsqueeze(-1) * self.freqs[1].unsqueeze(0)
|
| 413 |
+
)
|
| 414 |
+
return angles.permute(1, 0, 2)
|
| 415 |
+
if xy.dim() == 3:
|
| 416 |
+
t_x = xy[..., 0].to(dtype=torch.float32)
|
| 417 |
+
t_y = xy[..., 1].to(dtype=torch.float32)
|
| 418 |
+
with torch.autocast(device_type=t_x.device.type, enabled=False):
|
| 419 |
+
angles = t_x.unsqueeze(-1).unsqueeze(-1) * self.freqs[0].unsqueeze(
|
| 420 |
+
0
|
| 421 |
+
).unsqueeze(0)
|
| 422 |
+
angles.add_(
|
| 423 |
+
t_y.unsqueeze(-1).unsqueeze(-1)
|
| 424 |
+
* self.freqs[1].unsqueeze(0).unsqueeze(0)
|
| 425 |
+
)
|
| 426 |
+
return angles.permute(0, 2, 1, 3)
|
| 427 |
+
raise ValueError("xy must have shape (T,2) or (B,T,2)")
|
| 428 |
+
|
| 429 |
+
def _cos_sin_half_from_xy(
|
| 430 |
+
self,
|
| 431 |
+
xy: torch.Tensor,
|
| 432 |
+
*,
|
| 433 |
+
device: torch.device | None = None,
|
| 434 |
+
out_dtype: torch.dtype | None = None,
|
| 435 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 436 |
+
# Helper used in tests to build real-valued cos/sin tensors.
|
| 437 |
+
cis = self._compute_mixed_cis(xy.to(device=device) if device is not None else xy)
|
| 438 |
+
# Convert complex cis to cos/sin (real/imag) with matching shapes
|
| 439 |
+
if cis.is_complex():
|
| 440 |
+
cos_h = cis.real
|
| 441 |
+
sin_h = cis.imag
|
| 442 |
+
else:
|
| 443 |
+
# Should not happen; torch.polar returns complex64/128
|
| 444 |
+
raise RuntimeError("Expected complex cis tensor from polar")
|
| 445 |
+
if out_dtype is not None:
|
| 446 |
+
cos_h = cos_h.to(dtype=out_dtype)
|
| 447 |
+
sin_h = sin_h.to(dtype=out_dtype)
|
| 448 |
+
return cos_h, sin_h
|
| 449 |
+
|
| 450 |
+
def _cos_sin_from_xy(
|
| 451 |
+
self,
|
| 452 |
+
xy: torch.Tensor,
|
| 453 |
+
*,
|
| 454 |
+
device: torch.device | None = None,
|
| 455 |
+
out_dtype: torch.dtype | None = None,
|
| 456 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 457 |
+
cos_h, sin_h = self._cos_sin_half_from_xy(
|
| 458 |
+
xy, device=device, out_dtype=out_dtype
|
| 459 |
+
)
|
| 460 |
+
emb_cos = torch.cat((cos_h, cos_h), dim=-1)
|
| 461 |
+
emb_sin = torch.cat((sin_h, sin_h), dim=-1)
|
| 462 |
+
return emb_cos, emb_sin
|
| 463 |
+
|
| 464 |
+
def rotate_qk(
|
| 465 |
+
self,
|
| 466 |
+
q: torch.Tensor,
|
| 467 |
+
k: torch.Tensor,
|
| 468 |
+
xy: torch.Tensor,
|
| 469 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 470 |
+
if q.dim() != 4 or k.dim() != 4:
|
| 471 |
+
raise ValueError("q,k must be shaped (B,H,T,D)")
|
| 472 |
+
_, H, _, D = q.shape
|
| 473 |
+
if self.num_heads != H:
|
| 474 |
+
raise ValueError("num_heads mismatch for LearnableRoPE2D")
|
| 475 |
+
if self.head_dim != D:
|
| 476 |
+
raise ValueError("head_dim mismatch for LearnableRoPE2D")
|
| 477 |
+
if D % 4 != 0:
|
| 478 |
+
raise AssertionError("head_dim must be divisible by 4 for mixed 2D RoPE")
|
| 479 |
+
|
| 480 |
+
# Use real-valued sin/cos rotation to keep torch.compile/inductor on the
|
| 481 |
+
# fast path (inductor cannot codegen complex tensors).
|
| 482 |
+
angles = self._compute_mixed_angles(xy.to(device=q.device))
|
| 483 |
+
sin = torch.sin(angles)
|
| 484 |
+
cos = torch.cos(angles)
|
| 485 |
+
q_out = self._apply_rotary_from_trig(q, sin=sin, cos=cos)
|
| 486 |
+
k_out = self._apply_rotary_from_trig(k, sin=sin, cos=cos)
|
| 487 |
+
return q_out, k_out
|
| 488 |
+
|
| 489 |
+
def rotate_qk_with_dilation(
|
| 490 |
+
self,
|
| 491 |
+
q: torch.Tensor,
|
| 492 |
+
k: torch.Tensor,
|
| 493 |
+
*,
|
| 494 |
+
xy: torch.Tensor,
|
| 495 |
+
scales: torch.Tensor,
|
| 496 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 497 |
+
"""Rotate Q/K using mixed 2D RoPE with per-sample isotropic dilation.
|
| 498 |
+
|
| 499 |
+
This implements dilation by scaling the RoPE angle, i.e.
|
| 500 |
+
``theta_dilated = scale * theta_base`` where ``theta_base`` comes from the
|
| 501 |
+
undilated XY coordinates.
|
| 502 |
+
|
| 503 |
+
Args:
|
| 504 |
+
q: Query tensor shaped ``(B, H, T, D)``.
|
| 505 |
+
k: Key tensor shaped ``(B, H, T, D)``.
|
| 506 |
+
xy: Base XY coordinates shaped ``(T, 2)`` or ``(B, T, 2)``.
|
| 507 |
+
scales: Per-sample dilation scales shaped ``(B,)``.
|
| 508 |
+
|
| 509 |
+
Raises:
|
| 510 |
+
ValueError: If shapes are inconsistent or scales are not 1D.
|
| 511 |
+
"""
|
| 512 |
+
if q.dim() != 4 or k.dim() != 4:
|
| 513 |
+
raise ValueError("q,k must be shaped (B,H,T,D)")
|
| 514 |
+
B, H, T, D = q.shape
|
| 515 |
+
if self.num_heads != H:
|
| 516 |
+
raise ValueError("num_heads mismatch for LearnableRoPE2D")
|
| 517 |
+
if self.head_dim != D:
|
| 518 |
+
raise ValueError("head_dim mismatch for LearnableRoPE2D")
|
| 519 |
+
if scales.dim() != 1 or scales.shape[0] != B:
|
| 520 |
+
raise ValueError("scales must have shape (B,) matching q batch size")
|
| 521 |
+
if xy.dim() == 2 and xy.shape[0] != T:
|
| 522 |
+
raise ValueError("xy length must match q sequence length")
|
| 523 |
+
if xy.dim() == 3 and (xy.shape[0] != B or xy.shape[1] != T):
|
| 524 |
+
raise ValueError("xy must have shape (B,T,2) matching q batch/sequence")
|
| 525 |
+
if xy.shape[-1] != 2:
|
| 526 |
+
raise ValueError("xy must have last dimension 2")
|
| 527 |
+
|
| 528 |
+
angles = self._compute_mixed_angles(xy.to(device=q.device))
|
| 529 |
+
angles = angles * scales.to(device=q.device, dtype=torch.float32).view(
|
| 530 |
+
B, 1, 1, 1
|
| 531 |
+
)
|
| 532 |
+
sin = torch.sin(angles)
|
| 533 |
+
cos = torch.cos(angles)
|
| 534 |
+
q_out = self._apply_rotary_from_trig(q, sin=sin, cos=cos)
|
| 535 |
+
k_out = self._apply_rotary_from_trig(k, sin=sin, cos=cos)
|
| 536 |
+
return q_out, k_out
|
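The axial-compatible initialization and the HALF_SPLIT rotation in `LearnableRoPE2D` can be illustrated without tensors. The following is a minimal pure-Python sketch, not the module's implementation: the helper names `axial_init_freqs`, `rotate_half_split`, and `dot` are hypothetical, and it handles a single head vector with a single frequency group. It follows the formulas stated in the file above: periods `base ** (2 * i / half)`, angle `theta = x * fx[i] + y * fy[i]`, and pairing element `i` with element `i + half`.

```python
import math


def axial_init_freqs(head_dim, base=100.0, mult=2.0 * math.pi):
    """Return (fx, fy), each of length head_dim // 2 (hypothetical helper)."""
    if head_dim % 4 != 0:
        raise ValueError("head_dim must be divisible by 4")
    qtr = head_dim // 4
    half = head_dim // 2
    # period_i = base ** (2 * i / half); frequency_i = mult / period_i
    axis = [mult / (base ** (2.0 * i / half)) for i in range(qtr)]
    zeros = [0.0] * qtr
    fx = zeros + axis  # x only drives the second quarter at init
    fy = axis + zeros  # y only drives the first quarter at init
    return fx, fy


def rotate_half_split(vec, x, y, fx, fy):
    """Rotate one head vector at position (x, y) using the HALF_SPLIT pairing."""
    half = len(vec) // 2
    out = [0.0] * len(vec)
    for i in range(half):
        theta = x * fx[i] + y * fy[i]  # mixed 2D angle for pair i
        c, s = math.cos(theta), math.sin(theta)
        a, b = vec[i], vec[i + half]
        out[i] = a * c - b * s
        out[i + half] = b * c + a * s
    return out


def dot(u, v):
    """Plain dot product of two equal-length lists."""
    return sum(a * b for a, b in zip(u, v))


fx, fy = axial_init_freqs(8)
q = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
q_rot = rotate_half_split(q, x=0.25, y=0.5, fx=fx, fy=fy)
```

Because each pair undergoes a plain 2D rotation, the vector norm is unchanged, and the dot product of a rotated query and key depends only on the difference of their positions. That relative-position property is why the rotation is applied to both Q and K.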
config.json
ADDED
|
@@ -0,0 +1,26 @@
|
| 1 |
+
{
|
| 2 |
+
"in_channels": 3,
|
| 3 |
+
"patch_size": 16,
|
| 4 |
+
"model_dim": 896,
|
| 5 |
+
"encoder_depth": 6,
|
| 6 |
+
"decoder_depth": 8,
|
| 7 |
+
"decoder_start_blocks": 2,
|
| 8 |
+
"decoder_end_blocks": 2,
|
| 9 |
+
"bottleneck_dim": 128,
|
| 10 |
+
"mlp_ratio": 4.0,
|
| 11 |
+
"encoder_mlp_type": "gelu",
|
| 12 |
+
"depthwise_kernel_size": 7,
|
| 13 |
+
"adaln_low_rank_rank": 128,
|
| 14 |
+
"bottleneck_posterior_kind": "diagonal_gaussian",
|
| 15 |
+
"bottleneck_norm_mode": "disabled",
|
| 16 |
+
"logsnr_min": -10.0,
|
| 17 |
+
"logsnr_max": 10.0,
|
| 18 |
+
"pixel_noise_std": 0.558,
|
| 19 |
+
"latent_running_stats_eps": 0.0001,
|
| 20 |
+
"class_head_feature_dim": 768,
|
| 21 |
+
"class_head_model_dim": 768,
|
| 22 |
+
"class_head_head_dim": 64,
|
| 23 |
+
"class_head_mlp_ratio": 4.0,
|
| 24 |
+
"class_head_mlp_type": "gelu",
|
| 25 |
+
"class_head_register_token_count": 4
|
| 26 |
+
}
|
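A few quantities follow from these config values by arithmetic alone. The sketch below is illustrative only: `derived_shapes` is a hypothetical helper, the 256x256 image size is an assumed example input, and the divisibility check is an assumption about how the encoder treats image sizes. It uses only a subset of the keys shown above.

```python
import json

# Subset of the exported config.json shown above.
CONFIG_JSON = """
{
  "in_channels": 3,
  "patch_size": 16,
  "decoder_depth": 8,
  "decoder_start_blocks": 2,
  "decoder_end_blocks": 2,
  "bottleneck_dim": 128
}
"""


def derived_shapes(cfg, image_hw):
    """Compute latent grid shape and per-patch compression (hypothetical helper)."""
    height, width = image_hw
    patch = cfg["patch_size"]
    # Assumption: image sides are expected to be multiples of the patch size.
    if height % patch != 0 or width % patch != 0:
        raise ValueError("image sides must be multiples of patch_size")
    values_per_patch = cfg["in_channels"] * patch * patch
    middle = (
        cfg["decoder_depth"]
        - cfg["decoder_start_blocks"]
        - cfg["decoder_end_blocks"]
    )
    return {
        "latent_shape": (cfg["bottleneck_dim"], height // patch, width // patch),
        "values_per_patch": values_per_patch,
        "compression_ratio": values_per_patch / cfg["bottleneck_dim"],
        "decoder_middle_blocks": middle,
    }


cfg = json.loads(CONFIG_JSON)
shapes = derived_shapes(cfg, (256, 256))  # 256x256 is an assumed example size
```

With these values each 16x16 RGB patch (768 numbers) maps to one 128-channel latent position, and the decoder depth of 8 splits into the 2+4+2 layout.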
dinac_ae/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
| 1 |
+
"""DINAC-AE: DINO-aligned class-token autoencoder export."""
|
| 2 |
+
|
| 3 |
+
from .config import DinacAEConfig, DinacAEInferenceConfig
|
| 4 |
+
from .encoder import EncoderPosterior
|
| 5 |
+
from .model import DinacAE
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
"DinacAE",
|
| 9 |
+
"DinacAEConfig",
|
| 10 |
+
"DinacAEInferenceConfig",
|
| 11 |
+
"EncoderPosterior",
|
| 12 |
+
]
|
dinac_ae/adaln.py
ADDED
|
@@ -0,0 +1,75 @@
|
| 1 |
+
"""Scale+Gate AdaLN (2-way) for FCDM decoder blocks."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from torch import Tensor, nn
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class AdaLNScaleGateZeroProjector(nn.Module):
|
| 9 |
+
"""Packed 2-way AdaLN projection (SiLU -> Linear), zero-initialized.
|
| 10 |
+
|
| 11 |
+
Outputs [B, 2*d_model] packed as (scale, gate).
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
def __init__(self, d_model: int, d_cond: int) -> None:
|
| 15 |
+
super().__init__()
|
| 16 |
+
self.d_model: int = int(d_model)
|
| 17 |
+
self.d_cond: int = int(d_cond)
|
| 18 |
+
self.act: nn.SiLU = nn.SiLU()
|
| 19 |
+
self.proj: nn.Linear = nn.Linear(self.d_cond, 2 * self.d_model)
|
| 20 |
+
nn.init.zeros_(self.proj.weight)
|
| 21 |
+
nn.init.zeros_(self.proj.bias)
|
| 22 |
+
|
| 23 |
+
def project_activated(self, act_cond: Tensor) -> Tensor:
|
| 24 |
+
"""Return packed modulation for a pre-activated conditioning vector."""
|
| 25 |
+
|
| 26 |
+
if act_cond.dim() != 2:
|
| 27 |
+
raise ValueError(
|
| 28 |
+
"AdaLNScaleGateZeroProjector expects act_cond with shape [B, d_cond]"
|
| 29 |
+
)
|
| 30 |
+
if act_cond.shape[1] != self.d_cond:
|
| 31 |
+
raise ValueError(
|
| 32 |
+
f"act_cond width {int(act_cond.shape[1])} does not match d_cond={self.d_cond}"
|
| 33 |
+
)
|
| 34 |
+
return self.proj(act_cond)
|
| 35 |
+
|
| 36 |
+
def forward(self, cond: Tensor) -> Tensor:
|
| 37 |
+
"""Return packed modulation [B, 2*d_model]."""
|
| 38 |
+
if cond.dim() != 2:
|
| 39 |
+
raise ValueError(
|
| 40 |
+
"AdaLNScaleGateZeroProjector expects cond with shape [B, d_cond]"
|
| 41 |
+
)
|
| 42 |
+
if cond.shape[1] != self.d_cond:
|
| 43 |
+
raise ValueError(
|
| 44 |
+
f"cond width {int(cond.shape[1])} does not match d_cond={self.d_cond}"
|
| 45 |
+
)
|
| 46 |
+
return self.project_activated(self.act(cond))
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class AdaLNScaleGateZeroLowRankDelta(nn.Module):
|
| 50 |
+
"""Low-rank delta for 2-way AdaLN: down(d_cond -> rank) -> up(rank -> 2*d_model).
|
| 51 |
+
|
| 52 |
+
Zero-initialized up projection preserves zero-output semantics at init.
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
def __init__(self, *, d_model: int, d_cond: int, rank: int) -> None:
|
| 56 |
+
super().__init__()
|
| 57 |
+
self.d_model: int = int(d_model)
|
| 58 |
+
self.d_cond: int = int(d_cond)
|
| 59 |
+
self.rank: int = int(rank)
|
| 60 |
+
self.down: nn.Linear = nn.Linear(self.d_cond, self.rank, bias=False)
|
| 61 |
+
self.up: nn.Linear = nn.Linear(self.rank, 2 * self.d_model, bias=False)
|
| 62 |
+
nn.init.normal_(self.down.weight, mean=0.0, std=0.02)
|
| 63 |
+
nn.init.zeros_(self.up.weight)
|
| 64 |
+
|
| 65 |
+
def forward(self, act_cond: Tensor) -> Tensor:
|
| 66 |
+
"""Return packed delta modulation [B, 2*d_model]."""
|
| 67 |
+
if act_cond.dim() != 2:
|
| 68 |
+
raise ValueError(
|
| 69 |
+
"AdaLNScaleGateZeroLowRankDelta expects act_cond with shape [B, d_cond]"
|
| 70 |
+
)
|
| 71 |
+
if act_cond.shape[1] != self.d_cond:
|
| 72 |
+
raise ValueError(
|
| 73 |
+
f"act_cond width {int(act_cond.shape[1])} does not match d_cond={self.d_cond}"
|
| 74 |
+
)
|
| 75 |
+
return self.up(self.down(act_cond))
|
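The two classes above split AdaLN modulation into a full projector and a low-rank delta. A rough pure-Python sketch of the parameter arithmetic and of the zero-init behavior follows. The helper names are hypothetical, and the tiny matrices are made-up illustration values rather than exported weights.

```python
def full_projector_params(d_model, d_cond):
    """Weights plus bias of Linear(d_cond -> 2 * d_model)."""
    return d_cond * 2 * d_model + 2 * d_model


def low_rank_delta_params(d_model, d_cond, rank):
    """Bias-free down(d_cond -> rank) followed by up(rank -> 2 * d_model)."""
    return d_cond * rank + rank * 2 * d_model


def matvec(weight, vec):
    """Multiply a row-major matrix (list of rows) by a vector."""
    return [sum(w * v for w, v in zip(row, vec)) for row in weight]


def low_rank_delta(act_cond, down, up):
    """Packed delta modulation: up(down(act_cond))."""
    return matvec(up, matvec(down, act_cond))


# Toy sizes: d_cond=3, rank=2, d_model=2, so the packed output has 4 entries.
down = [[0.01, -0.02, 0.03], [0.02, 0.01, -0.01]]  # small random-like init
up = [[0.0, 0.0] for _ in range(4)]                # zero-initialized
delta = low_rank_delta([1.0, -2.0, 0.5], down, up)

# Counts for model_dim=896 and rank=128, the values in the exported config.
shared = full_projector_params(896, 896)
per_block = low_rank_delta_params(896, 896, 128)
```

Since `up` starts at zero, the delta is zero for any input at initialization, so every block begins with the shared base modulation only. The per-block delta is also several times smaller than a full projector.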
dinac_ae/config.py
ADDED
|
@@ -0,0 +1,75 @@
|
| 1 |
+
"""Frozen model architecture and user-tunable inference configuration."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
from dataclasses import asdict, dataclass
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass(frozen=True)
|
| 11 |
+
class DinacAEConfig:
|
| 12 |
+
"""Frozen architecture config stored alongside exported weights."""
|
| 13 |
+
|
| 14 |
+
in_channels: int = 3
|
| 15 |
+
patch_size: int = 16
|
| 16 |
+
model_dim: int = 896
|
| 17 |
+
encoder_depth: int = 6
|
| 18 |
+
decoder_depth: int = 8
|
| 19 |
+
decoder_start_blocks: int = 2
|
| 20 |
+
decoder_end_blocks: int = 2
|
| 21 |
+
bottleneck_dim: int = 128
|
| 22 |
+
mlp_ratio: float = 4.0
|
| 23 |
+
encoder_mlp_type: str = "gelu"
|
| 24 |
+
depthwise_kernel_size: int = 7
|
| 25 |
+
adaln_low_rank_rank: int = 128
|
| 26 |
+
bottleneck_posterior_kind: str = "diagonal_gaussian"
|
| 27 |
+
bottleneck_norm_mode: str = "disabled"
|
| 28 |
+
logsnr_min: float = -10.0
|
| 29 |
+
logsnr_max: float = 10.0
|
| 30 |
+
pixel_noise_std: float = 0.558
|
| 31 |
+
latent_running_stats_eps: float = 1e-4
|
| 32 |
+
class_head_feature_dim: int = 768
|
| 33 |
+
class_head_model_dim: int = 768
|
| 34 |
+
class_head_head_dim: int = 64
|
| 35 |
+
class_head_mlp_ratio: float = 4.0
|
| 36 |
+
class_head_mlp_type: str = "gelu"
|
| 37 |
+
class_head_register_token_count: int = 4
|
| 38 |
+
|
| 39 |
+
@property
|
| 40 |
+
def latent_channels(self) -> int:
|
| 41 |
+
"""Return the exported latent channel width."""
|
| 42 |
+
|
| 43 |
+
return int(self.bottleneck_dim)
|
| 44 |
+
|
| 45 |
+
@property
|
| 46 |
+
def effective_patch_size(self) -> int:
|
| 47 |
+
"""Return the image-to-latent stride."""
|
| 48 |
+
|
| 49 |
+
return int(self.patch_size)
|
| 50 |
+
|
| 51 |
+
def save(self, path: str | Path) -> None:
|
| 52 |
+
"""Save config as JSON."""
|
| 53 |
+
|
| 54 |
+
output_path = Path(path)
|
| 55 |
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 56 |
+
output_path.write_text(json.dumps(asdict(self), indent=2) + "\n")
|
| 57 |
+
|
| 58 |
+
@classmethod
|
| 59 |
+
def load(cls, path: str | Path) -> DinacAEConfig:
|
| 60 |
+
"""Load config from JSON."""
|
| 61 |
+
|
| 62 |
+
data = json.loads(Path(path).read_text())
|
| 63 |
+
return cls(**data)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@dataclass
|
| 67 |
+
class DinacAEInferenceConfig:
|
| 68 |
+
"""User-tunable VP diffusion decode settings."""
|
| 69 |
+
|
| 70 |
+
num_steps: int = 1
|
| 71 |
+
sampler: str = "ddim"
|
| 72 |
+
schedule: str = "linear"
|
| 73 |
+
pdg: bool = False
|
| 74 |
+
pdg_strength: float = 2.0
|
| 75 |
+
seed: int | None = None
|
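The save/load pattern used by `DinacAEConfig` can be tried in isolation. This is a minimal sketch with a reduced stand-in class: `MiniConfig` and `roundtrip` are hypothetical names, and only two of the real fields are kept.

```python
import json
import tempfile
from dataclasses import FrozenInstanceError, asdict, dataclass
from pathlib import Path


@dataclass(frozen=True)
class MiniConfig:
    """Reduced stand-in for a frozen architecture config (hypothetical)."""

    patch_size: int = 16
    bottleneck_dim: int = 128

    def save(self, path):
        """Write the fields as indented JSON, creating parent directories."""
        out = Path(path)
        out.parent.mkdir(parents=True, exist_ok=True)
        out.write_text(json.dumps(asdict(self), indent=2) + "\n")

    @classmethod
    def load(cls, path):
        """Rebuild the config; unknown JSON keys raise TypeError."""
        return cls(**json.loads(Path(path).read_text()))


def roundtrip(cfg):
    """Save to a temporary nested path and load the config back."""
    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / "nested" / "config.json"
        cfg.save(path)
        return MiniConfig.load(path)


restored = roundtrip(MiniConfig(patch_size=8))
```

Two consequences of this design are worth noting: a frozen dataclass rejects attribute assignment after construction, and `cls(**data)` fails fast on JSON keys that are not declared fields, so a mismatched config file is detected at load time.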
dinac_ae/decoder.py
ADDED
|
@@ -0,0 +1,163 @@
|
| 1 |
+
"""Decoder matching the exported FCDM decoder stack."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import Tensor, nn
|
| 7 |
+
|
| 8 |
+
from .adaln import AdaLNScaleGateZeroLowRankDelta, AdaLNScaleGateZeroProjector
|
| 9 |
+
from .fcdm_block import FCDMBlock
|
| 10 |
+
from .straight_through_encoder import Patchify
|
| 11 |
+
from .time_embed import SinusoidalTimeEmbeddingMLP
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Decoder(nn.Module):
|
| 15 |
+
"""VP diffusion decoder conditioned on encoder latents and timestep.
|
| 16 |
+
|
| 17 |
+
Architecture (skip-concat, 2+4+2 default):
|
| 18 |
+
Patchify x_t -> Fuse with upsampled z
|
| 19 |
+
-> Start blocks (2) -> Middle blocks (4) -> Skip fuse -> End blocks (2)
|
| 20 |
+
-> Conv1x1 -> PixelShuffle
|
| 21 |
+
|
| 22 |
+
Path-Drop Guidance (PDG) at inference:
|
| 23 |
+
- Replace middle block output with ``path_drop_mask_feature`` to create
|
| 24 |
+
an unconditional prediction, then extrapolate.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
def __init__(
|
| 28 |
+
self,
|
| 29 |
+
in_channels: int,
|
| 30 |
+
patch_size: int,
|
| 31 |
+
model_dim: int,
|
| 32 |
+
depth: int,
|
| 33 |
+
start_block_count: int,
|
| 34 |
+
end_block_count: int,
|
| 35 |
+
bottleneck_dim: int,
|
| 36 |
+
mlp_ratio: float,
|
| 37 |
+
depthwise_kernel_size: int,
|
| 38 |
+
adaln_low_rank_rank: int,
|
| 39 |
+
) -> None:
|
| 40 |
+
super().__init__()
|
| 41 |
+
self.patch_size = int(patch_size)
|
| 42 |
+
self.model_dim = int(model_dim)
|
| 43 |
+
|
| 44 |
+
self.patchify = Patchify(
|
| 45 |
+
in_channels,
|
| 46 |
+
patch_size,
|
| 47 |
+
model_dim,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
self.latent_up = nn.Conv2d(bottleneck_dim, model_dim, kernel_size=1, bias=True)
|
| 51 |
+
self.fuse_in = nn.Conv2d(2 * model_dim, model_dim, kernel_size=1, bias=True)
|
| 52 |
+
|
| 53 |
+
# Time embedding
|
| 54 |
+
self.time_embed = SinusoidalTimeEmbeddingMLP(model_dim)
|
| 55 |
+
|
| 56 |
+
# 2-way AdaLN: shared base projector + per-block low-rank deltas
|
| 57 |
+
self.adaln_base = AdaLNScaleGateZeroProjector(
|
| 58 |
+
d_model=model_dim, d_cond=model_dim
|
| 59 |
+
)
|
| 60 |
+
self.adaln_deltas = nn.ModuleList(
|
| 61 |
+
[
|
| 62 |
+
AdaLNScaleGateZeroLowRankDelta(
|
| 63 |
+
d_model=model_dim, d_cond=model_dim, rank=adaln_low_rank_rank
|
| 64 |
+
)
|
| 65 |
+
for _ in range(depth)
|
| 66 |
+
]
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
# Block layout: start + middle + end
|
| 70 |
+
middle_count = depth - start_block_count - end_block_count
|
| 71 |
+
self._middle_start_idx = start_block_count
|
| 72 |
+
self._end_start_idx = start_block_count + middle_count
|
| 73 |
+
|
| 74 |
+
def _make_blocks(count: int) -> nn.ModuleList:
|
| 75 |
+
return nn.ModuleList(
|
| 76 |
+
[
|
| 77 |
+
FCDMBlock(
|
| 78 |
+
model_dim,
|
| 79 |
+
mlp_ratio,
|
| 80 |
+
depthwise_kernel_size=depthwise_kernel_size,
|
| 81 |
+
use_external_adaln=True,
|
| 82 |
+
)
|
| 83 |
+
for _ in range(count)
|
| 84 |
+
]
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
self.start_blocks = _make_blocks(start_block_count)
|
| 88 |
+
self.middle_blocks = _make_blocks(middle_count)
|
| 89 |
+
self.fuse_skip = nn.Conv2d(2 * model_dim, model_dim, kernel_size=1, bias=True)
|
| 90 |
+
self.end_blocks = _make_blocks(end_block_count)
|
| 91 |
+
|
| 92 |
+
self.path_drop_mask_feature = nn.Parameter(torch.zeros((1, model_dim, 1, 1)))
|
| 93 |
+
|
| 94 |
+
self.out_proj = nn.Conv2d(
|
| 95 |
+
model_dim, in_channels * (patch_size**2), kernel_size=1, bias=True
|
| 96 |
+
)
|
| 97 |
+
self.unpatchify = nn.PixelShuffle(patch_size)
|
| 98 |
+
|
| 99 |
+
def _adaln_m_for_layer(self, cond: Tensor, layer_idx: int) -> Tensor:
|
| 100 |
+
"""Compute packed AdaLN modulation = shared_base + per-layer delta."""
|
| 101 |
+
act = self.adaln_base.act(cond)
|
| 102 |
+
base_m = self.adaln_base.project_activated(act)
|
| 103 |
+
delta_m = self.adaln_deltas[layer_idx](act)
|
| 104 |
+
return base_m + delta_m
|
| 105 |
+
|
| 106 |
+
def _run_blocks(
|
| 107 |
+
self, blocks: nn.ModuleList, x: Tensor, cond: Tensor, start_index: int
|
| 108 |
+
) -> Tensor:
|
| 109 |
+
"""Run a group of decoder blocks with per-block AdaLN modulation."""
|
| 110 |
+
for local_idx, block in enumerate(blocks):
|
| 111 |
+
adaln_m = self._adaln_m_for_layer(cond, layer_idx=start_index + local_idx)
|
| 112 |
+
x = block(x, adaln_m=adaln_m)
|
| 113 |
+
return x
|
| 114 |
+
|
| 115 |
+
def forward(
|
| 116 |
+
self,
|
| 117 |
+
x_t: Tensor,
|
| 118 |
+
t: Tensor,
|
| 119 |
+
latents: Tensor,
|
| 120 |
+
*,
|
| 121 |
+
drop_middle_blocks: bool = False,
|
| 122 |
+
) -> Tensor:
|
| 123 |
+
"""Single decoder forward pass.
|
| 124 |
+
|
| 125 |
+
Args:
|
| 126 |
+
x_t: Noised image [B, C, H, W].
|
| 127 |
+
t: Timestep [B] in [0, 1].
|
| 128 |
+
latents: Encoder latents [B, bottleneck_dim, h, w].
|
| 129 |
+
drop_middle_blocks: Replace middle block output with mask feature (PDG).
|
| 130 |
+
|
| 131 |
+
Returns:
|
| 132 |
+
x0 prediction [B, C, H, W].
|
| 133 |
+
"""
|
| 134 |
+
x_feat = self.patchify(x_t)
|
| 135 |
+
z_up = self.latent_up(latents)
|
| 136 |
+
|
| 137 |
+
fused = torch.cat([x_feat, z_up], dim=1)
|
| 138 |
+
fused = self.fuse_in(fused)
|
| 139 |
+
|
| 140 |
+
cond = self.time_embed(t.to(device=x_t.device, dtype=torch.float32))
|
| 141 |
+
|
| 142 |
+
start_out = self._run_blocks(self.start_blocks, fused, cond, start_index=0)
|
| 143 |
+
|
| 144 |
+
if drop_middle_blocks:
|
| 145 |
+
middle_out = self.path_drop_mask_feature.to(
|
| 146 |
+
device=x_t.device, dtype=x_t.dtype
|
| 147 |
+
).expand_as(start_out)
|
| 148 |
+
else:
|
| 149 |
+
middle_out = self._run_blocks(
|
| 150 |
+
self.middle_blocks,
|
| 151 |
+
start_out,
|
| 152 |
+
cond,
|
| 153 |
+
start_index=self._middle_start_idx,
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
skip_fused = torch.cat([start_out, middle_out], dim=1)
|
| 157 |
+
skip_fused = self.fuse_skip(skip_fused)
|
| 158 |
+
|
| 159 |
+
end_out = self._run_blocks(
|
| 160 |
+
self.end_blocks, skip_fused, cond, start_index=self._end_start_idx
|
| 161 |
+
)
|
| 162 |
+
patches = self.out_proj(end_out)
|
| 163 |
+
return self.unpatchify(patches)
|
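The decoder code above builds each block's modulation as a shared base plus a per-layer delta, and `FCDMBlock` later splits that packed vector into a scale and a raw gate. A minimal pure-Python sketch of that arithmetic for a single sample follows. It assumes plain lists instead of tensors and a caller-supplied `branch` function standing in for the pointwise-conv path, and it leaves out the depthwise convolution and normalization that precede the scale. The names `packed_modulation` and `modulated_residual` are illustrative and not part of the export.

```python
from typing import Callable, Sequence


def packed_modulation(base: Sequence[float], delta: Sequence[float]) -> list[float]:
    """Per-layer modulation = shared base + layer-specific delta."""
    if len(base) != len(delta):
        raise ValueError("base and delta must have the same length")
    return [b + d for b, d in zip(base, delta)]


def modulated_residual(
    x: Sequence[float],
    m: Sequence[float],
    branch: Callable[[list[float]], list[float]],
) -> list[float]:
    """Scale the branch input by (1 + scale), then gate the residual update."""
    c = len(x)
    if len(m) != 2 * c:
        raise ValueError("modulation must pack (scale, gate) as 2*C values")
    # First half is the scale, second half is the gate (like m.chunk(2, dim=-1)).
    scale, gate = m[:c], m[c:]
    h = [xi * (1.0 + s) for xi, s in zip(x, scale)]
    h = branch(h)
    # The gate is applied raw (no tanh), matching the FCDMBlock docstring.
    return [xi + g * hi for xi, g, hi in zip(x, gate, h)]
```

With an all-zero modulation the gate is zero, so the block reduces to the identity on `x`.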
dinac_ae/encoder.py
ADDED
|
@@ -0,0 +1,215 @@
| 1 |
+
"""Encoder matching the exported mixed DitBlock/FCDM architecture."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from torch import Tensor, nn
|
| 10 |
+
|
| 11 |
+
from dit.axial_rope2d import (
|
| 12 |
+
AxialRoPE2D,
|
| 13 |
+
AxialRoPE2DConfig,
|
| 14 |
+
AxialRoPE2DCoordMode,
|
| 15 |
+
AxialRoPE2DDimLayout,
|
| 16 |
+
AxialRoPE2DNormalizeCoords,
|
| 17 |
+
)
|
| 18 |
+
from dit.blocks import DitBlock
|
| 19 |
+
from dit.body_config import DiTConditioning
|
| 20 |
+
from dit.mlp_types import MLPType
|
| 21 |
+
from dit.position_encoding import DiTPositionEncoding
|
| 22 |
+
|
| 23 |
+
from .straight_through_encoder import Patchify
|
| 24 |
+
|
| 25 |
+
_ENCODER_HEAD_DIM = 64
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _resolve_encoder_mlp_type(name: str) -> MLPType:
|
| 29 |
+
"""Return the encoder DiT MLP enum for the serialized config value."""
|
| 30 |
+
|
| 31 |
+
match str(name):
|
| 32 |
+
case "gelu":
|
| 33 |
+
return MLPType.GELU
|
| 34 |
+
case "silu":
|
| 35 |
+
return MLPType.SILU
|
| 36 |
+
case "relu":
|
| 37 |
+
return MLPType.RELU
|
| 38 |
+
case _ as unreachable:
|
| 39 |
+
raise ValueError(
|
| 40 |
+
f"Unsupported encoder_mlp_type for DinacAE export: {unreachable!r}"
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@dataclass(frozen=True)
|
| 45 |
+
class EncoderPosterior:
|
| 46 |
+
"""VP-parameterized diagonal Gaussian posterior."""
|
| 47 |
+
|
| 48 |
+
mean: Tensor
|
| 49 |
+
logsnr: Tensor
|
| 50 |
+
|
| 51 |
+
@property
|
| 52 |
+
def alpha(self) -> Tensor:
|
| 53 |
+
"""Return the VP signal coefficient."""
|
| 54 |
+
|
| 55 |
+
logsnr_fp32 = self.logsnr.to(torch.float32)
|
| 56 |
+
return torch.exp(0.5 * F.logsigmoid(logsnr_fp32))
|
| 57 |
+
|
| 58 |
+
@property
|
| 59 |
+
def sigma(self) -> Tensor:
|
| 60 |
+
"""Return the VP noise coefficient."""
|
| 61 |
+
|
| 62 |
+
logsnr_fp32 = self.logsnr.to(torch.float32)
|
| 63 |
+
return torch.exp(0.5 * F.logsigmoid(-logsnr_fp32))
|
| 64 |
+
|
| 65 |
+
def mode(self) -> Tensor:
|
| 66 |
+
"""Return the posterior mode in token space."""
|
| 67 |
+
|
| 68 |
+
return (self.alpha * self.mean.to(torch.float32)).to(dtype=self.mean.dtype)
|
| 69 |
+
|
| 70 |
+
def sample(self, *, generator: torch.Generator | None = None) -> Tensor:
|
| 71 |
+
"""Sample from the posterior."""
|
| 72 |
+
|
| 73 |
+
mean_fp32 = self.mean.to(torch.float32)
|
| 74 |
+
eps = torch.randn(
|
| 75 |
+
mean_fp32.shape,
|
| 76 |
+
device=mean_fp32.device,
|
| 77 |
+
dtype=torch.float32,
|
| 78 |
+
generator=generator,
|
| 79 |
+
)
|
| 80 |
+
return (self.alpha * mean_fp32 + self.sigma * eps).to(dtype=self.mean.dtype)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class Encoder(nn.Module):
|
| 84 |
+
"""Residual-patchify plus DitBlock encoder."""
|
| 85 |
+
|
| 86 |
+
def __init__(
|
| 87 |
+
self,
|
| 88 |
+
*,
|
| 89 |
+
in_channels: int,
|
| 90 |
+
patch_size: int,
|
| 91 |
+
model_dim: int,
|
| 92 |
+
depth: int,
|
| 93 |
+
bottleneck_dim: int,
|
| 94 |
+
mlp_ratio: float,
|
| 95 |
+
mlp_type: str,
|
| 96 |
+
bottleneck_posterior_kind: str,
|
| 97 |
+
bottleneck_norm_mode: str,
|
| 98 |
+
) -> None:
|
| 99 |
+
super().__init__()
|
| 100 |
+
if int(model_dim) % int(_ENCODER_HEAD_DIM) != 0:
|
| 101 |
+
raise ValueError("model_dim must be divisible by encoder head dim")
|
| 102 |
+
self.bottleneck_dim: int = int(bottleneck_dim)
|
| 103 |
+
self.bottleneck_posterior_kind: str = str(bottleneck_posterior_kind)
|
| 104 |
+
self.bottleneck_norm_mode: str = str(bottleneck_norm_mode)
|
| 105 |
+
if self.bottleneck_norm_mode != "disabled":
|
| 106 |
+
raise ValueError("DINAC-AE export requires disabled bottleneck norm")
|
| 107 |
+
self.patchify = Patchify(
|
| 108 |
+
in_channels,
|
| 109 |
+
patch_size,
|
| 110 |
+
model_dim,
|
| 111 |
+
)
|
| 112 |
+
self.blocks = nn.ModuleList(
|
| 113 |
+
[
|
| 114 |
+
DitBlock(
|
| 115 |
+
d_model=int(model_dim),
|
| 116 |
+
n_heads=int(model_dim) // int(_ENCODER_HEAD_DIM),
|
| 117 |
+
mlp_ratio=float(mlp_ratio),
|
| 118 |
+
mlp_type=_resolve_encoder_mlp_type(mlp_type),
|
| 119 |
+
block_index=int(index),
|
| 120 |
+
use_norms=True,
|
| 121 |
+
position_encoding=DiTPositionEncoding.ROPE_2D_AXIAL_UNNORMALIZED,
|
| 122 |
+
conditioning=DiTConditioning.UNCOND,
|
| 123 |
+
)
|
| 124 |
+
for index in range(int(depth))
|
| 125 |
+
]
|
| 126 |
+
)
|
| 127 |
+
self.rope = AxialRoPE2D(
|
| 128 |
+
head_dim=int(_ENCODER_HEAD_DIM),
|
| 129 |
+
cfg=AxialRoPE2DConfig(
|
| 130 |
+
base=10_000.0,
|
| 131 |
+
min_period=None,
|
| 132 |
+
max_period=None,
|
| 133 |
+
coord_mode=AxialRoPE2DCoordMode.PATCH_INDICES,
|
| 134 |
+
normalize_coords=AxialRoPE2DNormalizeCoords.MAX,
|
| 135 |
+
dim_layout=AxialRoPE2DDimLayout.PAIR_INTERLEAVED,
|
| 136 |
+
angle_multiplier=1.0,
|
| 137 |
+
coord_offset=0.0,
|
| 138 |
+
frequency_aware=None,
|
| 139 |
+
beta_warp=None,
|
| 140 |
+
alpha_warp=None,
|
| 141 |
+
),
|
| 142 |
+
)
|
| 143 |
+
match self.bottleneck_posterior_kind:
|
| 144 |
+
case "deterministic":
|
| 145 |
+
output_channels = int(bottleneck_dim)
|
| 146 |
+
case "diagonal_gaussian":
|
| 147 |
+
output_channels = 2 * int(bottleneck_dim)
|
| 148 |
+
case _ as unreachable:
|
| 149 |
+
raise RuntimeError(
|
| 150 |
+
f"Unsupported bottleneck_posterior_kind: {unreachable}"
|
| 151 |
+
)
|
| 152 |
+
self.to_bottleneck = nn.Conv2d(
|
| 153 |
+
int(model_dim),
|
| 154 |
+
output_channels,
|
| 155 |
+
kernel_size=1,
|
| 156 |
+
bias=True,
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
def _encode_projection(self, images: Tensor) -> Tensor:
|
| 160 |
+
"""Encode images to the raw bottleneck projection."""
|
| 161 |
+
|
| 162 |
+
z = self.patchify(images)
|
| 163 |
+
batch, channels, height, width = z.shape
|
| 164 |
+
cond = torch.zeros(
|
| 165 |
+
(int(batch), int(channels)),
|
| 166 |
+
device=z.device,
|
| 167 |
+
dtype=z.dtype,
|
| 168 |
+
)
|
| 169 |
+
rope_sincos = self.rope(H=int(height), W=int(width), scales=None)
|
| 170 |
+
y = z
|
| 171 |
+
for block in self.blocks:
|
| 172 |
+
y = block(
|
| 173 |
+
y,
|
| 174 |
+
hw=(int(height), int(width)),
|
| 175 |
+
cond_vec=cond,
|
| 176 |
+
adaln_m=None,
|
| 177 |
+
rope_sincos=rope_sincos,
|
| 178 |
+
generator=None,
|
| 179 |
+
)
|
| 180 |
+
return self.to_bottleneck(y)
|
| 181 |
+
|
| 182 |
+
def _apply_bottleneck_norm(self, z: Tensor) -> Tensor:
|
| 183 |
+
"""Identity pass-through: the export requires ``bottleneck_norm_mode="disabled"``, so the mean is returned unnormalized."""
|
| 184 |
+
|
| 185 |
+
return z
|
| 186 |
+
|
| 187 |
+
def encode_posterior(self, images: Tensor) -> EncoderPosterior:
|
| 188 |
+
"""Encode images and return the posterior."""
|
| 189 |
+
|
| 190 |
+
if self.bottleneck_posterior_kind != "diagonal_gaussian":
|
| 191 |
+
raise RuntimeError(
|
| 192 |
+
"encode_posterior requires bottleneck_posterior_kind=diagonal_gaussian"
|
| 193 |
+
)
|
| 194 |
+
projection = self._encode_projection(images)
|
| 195 |
+
mean, logsnr = projection.chunk(2, dim=1)
|
| 196 |
+
mean = self._apply_bottleneck_norm(mean)
|
| 197 |
+
return EncoderPosterior(mean=mean, logsnr=logsnr)
|
| 198 |
+
|
| 199 |
+
def forward(self, images: Tensor) -> Tensor:
|
| 200 |
+
"""Encode images to latent tokens."""
|
| 201 |
+
|
| 202 |
+
projection = self._encode_projection(images)
|
| 203 |
+
match self.bottleneck_posterior_kind:
|
| 204 |
+
case "diagonal_gaussian":
|
| 205 |
+
mean, logsnr = projection.chunk(2, dim=1)
|
| 206 |
+
mean = self._apply_bottleneck_norm(mean)
|
| 207 |
+
logsnr_fp32 = logsnr.to(torch.float32)
|
| 208 |
+
alpha = torch.exp(0.5 * F.logsigmoid(logsnr_fp32))
|
| 209 |
+
return (alpha * mean.to(torch.float32)).to(dtype=mean.dtype)
|
| 210 |
+
case "deterministic":
|
| 211 |
+
return self._apply_bottleneck_norm(projection)
|
| 212 |
+
case _ as unreachable:
|
| 213 |
+
raise RuntimeError(
|
| 214 |
+
f"Unsupported bottleneck_posterior_kind: {unreachable}"
|
| 215 |
+
)
|
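The `EncoderPosterior` properties above derive the VP signal and noise coefficients from a log-SNR value as `alpha = sqrt(sigmoid(logsnr))` and `sigma = sqrt(sigmoid(-logsnr))`, so `alpha**2 + sigma**2 == 1` and `alpha**2 / sigma**2 == exp(logsnr)`. A minimal scalar sketch of the same computation follows, using only the standard library. The helper names `log_sigmoid` and `vp_alpha_sigma` are illustrative and do not appear in the export.

```python
import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x)) for a Python float."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def vp_alpha_sigma(logsnr: float) -> tuple[float, float]:
    """Return (alpha, sigma) for a variance-preserving parameterization."""
    alpha = math.exp(0.5 * log_sigmoid(logsnr))
    sigma = math.exp(0.5 * log_sigmoid(-logsnr))
    return alpha, sigma
```

Working in log space, as the source does with `F.logsigmoid`, avoids taking the square root of a sigmoid that has underflowed to zero at extreme log-SNR values.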
dinac_ae/fcdm_block.py
ADDED
|
@@ -0,0 +1,103 @@
| 1 |
+
"""FCDM block: ConvNeXt-style conv block with GRN and scale+gate AdaLN."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch import Tensor, nn
|
| 8 |
+
|
| 9 |
+
from .norms import ChannelWiseRMSNorm
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class GRN(nn.Module):
|
| 13 |
+
"""Global Response Normalization for NCHW tensors."""
|
| 14 |
+
|
| 15 |
+
def __init__(self, channels: int, *, eps: float = 1e-6) -> None:
|
| 16 |
+
super().__init__()
|
| 17 |
+
self.eps: float = float(eps)
|
| 18 |
+
c = int(channels)
|
| 19 |
+
self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1), dtype=torch.float32))
|
| 20 |
+
self.beta = nn.Parameter(torch.zeros((1, c, 1, 1), dtype=torch.float32))
|
| 21 |
+
|
| 22 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 23 |
+
g = torch.linalg.vector_norm(x, ord=2, dim=(2, 3), keepdim=True)
|
| 24 |
+
g_fp32 = g.to(dtype=torch.float32)
|
| 25 |
+
n = (g_fp32 / (g_fp32.mean(dim=1, keepdim=True) + self.eps)).to(dtype=x.dtype)
|
| 26 |
+
gamma = self.gamma.to(device=x.device, dtype=x.dtype)
|
| 27 |
+
beta = self.beta.to(device=x.device, dtype=x.dtype)
|
| 28 |
+
return gamma * (x * n) + beta + x
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class FCDMBlock(nn.Module):
|
| 32 |
+
"""ConvNeXt-style block with scale+gate AdaLN and GRN.
|
| 33 |
+
|
| 34 |
+
Two modes:
|
| 35 |
+
- Unconditioned (encoder): uses learned layer-scale for near-identity init.
|
| 36 |
+
- External AdaLN (decoder): receives packed [B, 2*C] modulation (scale, gate).
|
| 37 |
+
The gate is applied raw (no tanh).
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
channels: int,
|
| 43 |
+
mlp_ratio: float,
|
| 44 |
+
*,
|
| 45 |
+
depthwise_kernel_size: int = 7,
|
| 46 |
+
use_external_adaln: bool = False,
|
| 47 |
+
norm_eps: float = 1e-6,
|
| 48 |
+
layer_scale_init: float = 1e-3,
|
| 49 |
+
) -> None:
|
| 50 |
+
super().__init__()
|
| 51 |
+
self.channels: int = int(channels)
|
| 52 |
+
self.mlp_ratio: float = float(mlp_ratio)
|
| 53 |
+
|
| 54 |
+
self.dwconv = nn.Conv2d(
|
| 55 |
+
channels,
|
| 56 |
+
channels,
|
| 57 |
+
kernel_size=depthwise_kernel_size,
|
| 58 |
+
padding=depthwise_kernel_size // 2,
|
| 59 |
+
stride=1,
|
| 60 |
+
groups=channels,
|
| 61 |
+
bias=True,
|
| 62 |
+
)
|
| 63 |
+
self.norm = ChannelWiseRMSNorm(channels, eps=float(norm_eps), affine=False)
|
| 64 |
+
hidden = max(int(float(channels) * float(mlp_ratio)), 1)
|
| 65 |
+
self.pwconv1 = nn.Conv2d(channels, hidden, kernel_size=1, bias=True)
|
| 66 |
+
self.grn = GRN(hidden, eps=1e-6)
|
| 67 |
+
self.pwconv2 = nn.Conv2d(hidden, channels, kernel_size=1, bias=True)
|
| 68 |
+
|
| 69 |
+
if not use_external_adaln:
|
| 70 |
+
self.layer_scale = nn.Parameter(
|
| 71 |
+
torch.full((channels,), float(layer_scale_init))
|
| 72 |
+
)
|
| 73 |
+
else:
|
| 74 |
+
self.register_parameter("layer_scale", None)
|
| 75 |
+
|
| 76 |
+
def forward(self, x: Tensor, *, adaln_m: Tensor | None = None) -> Tensor:
|
| 77 |
+
b, c, _, _ = x.shape
|
| 78 |
+
|
| 79 |
+
if adaln_m is not None:
|
| 80 |
+
m = adaln_m.to(device=x.device, dtype=x.dtype)
|
| 81 |
+
scale, gate = m.chunk(2, dim=-1)
|
| 82 |
+
else:
|
| 83 |
+
scale = gate = None
|
| 84 |
+
|
| 85 |
+
h = self.dwconv(x)
|
| 86 |
+
h = self.norm(h)
|
| 87 |
+
|
| 88 |
+
if scale is not None:
|
| 89 |
+
h = h * (1.0 + scale.view(b, c, 1, 1))
|
| 90 |
+
|
| 91 |
+
h = self.pwconv1(h)
|
| 92 |
+
h = F.gelu(h)
|
| 93 |
+
h = self.grn(h)
|
| 94 |
+
h = self.pwconv2(h)
|
| 95 |
+
|
| 96 |
+
if gate is not None:
|
| 97 |
+
gate_view = gate.view(b, c, 1, 1)
|
| 98 |
+
else:
|
| 99 |
+
gate_view = self.layer_scale.view(1, c, 1, 1).to( # type: ignore[union-attr]
|
| 100 |
+
device=h.device, dtype=h.dtype
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
return x + gate_view * h
|
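The `GRN` module above computes a per-channel L2 norm over the spatial axes, divides it by the mean norm across channels, and adds the rescaled features back onto the input. A minimal pure-Python sketch of that formula for one sample follows. It assumes each channel is given as a flat list of spatial values; the function name `grn` and the list-based layout are illustrative only.

```python
import math
from typing import Sequence


def grn(
    x: Sequence[Sequence[float]],
    gamma: Sequence[float],
    beta: Sequence[float],
    eps: float = 1e-6,
) -> list[list[float]]:
    """Global Response Normalization on x[channel][spatial_index]."""
    # Per-channel L2 norm over all spatial positions.
    norms = [math.sqrt(sum(v * v for v in channel)) for channel in x]
    # Divide each norm by the mean norm across channels.
    mean_norm = sum(norms) / len(norms)
    scale = [g / (mean_norm + eps) for g in norms]
    # gamma * (x * n) + beta + x, with a residual connection to the input.
    return [
        [gamma[c] * (v * scale[c]) + beta[c] + v for v in channel]
        for c, channel in enumerate(x)
    ]
```

Because `gamma` and `beta` are initialized to zero in the source, the layer starts out as the identity.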
dinac_ae/model.py
ADDED
|
@@ -0,0 +1,333 @@
| 1 |
+
"""Standalone mixed DitBlock/FCDM diffusion autoencoder export."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch import Tensor, nn
|
| 9 |
+
|
| 10 |
+
from dit.mlp_types import MLPType
|
| 11 |
+
from dit.repa_projection import DinoTokenAlignmentHead
|
| 12 |
+
|
| 13 |
+
from .config import DinacAEConfig, DinacAEInferenceConfig
|
| 14 |
+
from .decoder import Decoder
|
| 15 |
+
from .encoder import Encoder, EncoderPosterior
|
| 16 |
+
from .samplers import run_ddim, run_dpmpp_2m
|
| 17 |
+
from .vp_diffusion import get_schedule, make_initial_state, sample_noise
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _resolve_model_dir(
|
| 21 |
+
path_or_repo_id: str | Path,
|
| 22 |
+
*,
|
| 23 |
+
revision: str | None,
|
| 24 |
+
cache_dir: str | Path | None,
|
| 25 |
+
) -> Path:
|
| 26 |
+
"""Resolve a local path or Hugging Face repo ID to a directory."""
|
| 27 |
+
|
| 28 |
+
local = Path(path_or_repo_id)
|
| 29 |
+
if local.is_dir():
|
| 30 |
+
return local
|
| 31 |
+
repo_id = str(path_or_repo_id)
|
| 32 |
+
try:
|
| 33 |
+
from huggingface_hub import snapshot_download
|
| 34 |
+
except ImportError as exc:
|
| 35 |
+
raise ImportError(
|
| 36 |
+
f"'{repo_id}' is not an existing local directory. Install "
|
| 37 |
+
"huggingface_hub to load from the Hub."
|
| 38 |
+
) from exc
|
| 39 |
+
cache_dir_str = str(cache_dir) if cache_dir is not None else None
|
| 40 |
+
return Path(
|
| 41 |
+
snapshot_download(
|
| 42 |
+
repo_id,
|
| 43 |
+
revision=revision,
|
| 44 |
+
cache_dir=cache_dir_str,
|
| 45 |
+
)
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _resolve_class_head_mlp_type(name: str) -> MLPType:
|
| 50 |
+
"""Return the token-head MLP enum for the serialized config value."""
|
| 51 |
+
|
| 52 |
+
match str(name):
|
| 53 |
+
case "gelu":
|
| 54 |
+
return MLPType.GELU
|
| 55 |
+
case "silu":
|
| 56 |
+
return MLPType.SILU
|
| 57 |
+
case "relu":
|
| 58 |
+
return MLPType.RELU
|
| 59 |
+
case _ as unreachable:
|
| 60 |
+
raise ValueError(
|
| 61 |
+
"Unsupported class_head_mlp_type for DinacAE export: "
|
| 62 |
+
f"{unreachable!r}"
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class DinacAE(nn.Module):
|
| 67 |
+
"""Exported DINAC-AE wrapper with encode/decode/predict_class APIs."""
|
| 68 |
+
|
| 69 |
+
def __init__(self, config: DinacAEConfig) -> None:
|
| 70 |
+
super().__init__()
|
| 71 |
+
self.config = config
|
| 72 |
+
self.register_buffer(
|
| 73 |
+
"latent_norm_running_mean",
|
| 74 |
+
torch.zeros((config.latent_channels,), dtype=torch.float32),
|
| 75 |
+
)
|
| 76 |
+
self.register_buffer(
|
| 77 |
+
"latent_norm_running_var",
|
| 78 |
+
torch.ones((config.latent_channels,), dtype=torch.float32),
|
| 79 |
+
)
|
| 80 |
+
self.encoder = Encoder(
|
| 81 |
+
in_channels=int(config.in_channels),
|
| 82 |
+
patch_size=int(config.patch_size),
|
| 83 |
+
model_dim=int(config.model_dim),
|
| 84 |
+
depth=int(config.encoder_depth),
|
| 85 |
+
bottleneck_dim=int(config.bottleneck_dim),
|
| 86 |
+
mlp_ratio=float(config.mlp_ratio),
|
| 87 |
+
mlp_type=str(config.encoder_mlp_type),
|
| 88 |
+
bottleneck_posterior_kind=str(config.bottleneck_posterior_kind),
|
| 89 |
+
bottleneck_norm_mode=str(config.bottleneck_norm_mode),
|
| 90 |
+
)
|
| 91 |
+
self.decoder = Decoder(
|
| 92 |
+
in_channels=int(config.in_channels),
|
| 93 |
+
patch_size=int(config.patch_size),
|
| 94 |
+
model_dim=int(config.model_dim),
|
| 95 |
+
depth=int(config.decoder_depth),
|
| 96 |
+
start_block_count=int(config.decoder_start_blocks),
|
| 97 |
+
end_block_count=int(config.decoder_end_blocks),
|
| 98 |
+
bottleneck_dim=int(config.bottleneck_dim),
|
| 99 |
+
mlp_ratio=float(config.mlp_ratio),
|
| 100 |
+
depthwise_kernel_size=int(config.depthwise_kernel_size),
|
| 101 |
+
adaln_low_rank_rank=int(config.adaln_low_rank_rank),
|
| 102 |
+
)
|
| 103 |
+
self.dino_token_alignment_head = DinoTokenAlignmentHead(
|
| 104 |
+
in_channels=int(config.bottleneck_dim),
|
| 105 |
+
feature_dim=int(config.class_head_feature_dim),
|
| 106 |
+
model_dim=int(config.class_head_model_dim),
|
| 107 |
+
head_dim=int(config.class_head_head_dim),
|
| 108 |
+
mlp_ratio=float(config.class_head_mlp_ratio),
|
| 109 |
+
mlp_activation=_resolve_class_head_mlp_type(config.class_head_mlp_type),
|
| 110 |
+
block_index=10_001,
|
| 111 |
+
register_token_count=int(config.class_head_register_token_count),
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
def _restore_float32_norm_buffers(self) -> None:
|
| 115 |
+
"""Keep latent running stats in float32 after device/dtype moves."""
|
| 116 |
+
|
| 117 |
+
self.latent_norm_running_mean = self.latent_norm_running_mean.to(
|
| 118 |
+
dtype=torch.float32
|
| 119 |
+
)
|
| 120 |
+
self.latent_norm_running_var = self.latent_norm_running_var.to(
|
| 121 |
+
dtype=torch.float32
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
def to(self, *args: object, **kwargs: object) -> DinacAE:
|
| 125 |
+
"""Move the model while preserving float32 latent stats buffers."""
|
| 126 |
+
|
| 127 |
+
moved = super().to(*args, **kwargs)
|
| 128 |
+
if not isinstance(moved, DinacAE):
|
| 129 |
+
raise RuntimeError(
|
| 130 |
+
f"Expected DinacAE after nn.Module.to(), got {type(moved).__name__}"
|
| 131 |
+
)
|
| 132 |
+
moved._restore_float32_norm_buffers()
|
| 133 |
+
return moved
|
| 134 |
+
|
| 135 |
+
@classmethod
|
| 136 |
+
def from_pretrained(
|
| 137 |
+
cls,
|
| 138 |
+
path_or_repo_id: str | Path,
|
| 139 |
+
*,
|
| 140 |
+
dtype: torch.dtype = torch.bfloat16,
|
| 141 |
+
device: str | torch.device = "cpu",
|
| 142 |
+
revision: str | None = None,
|
| 143 |
+
cache_dir: str | Path | None = None,
|
| 144 |
+
) -> DinacAE:
|
| 145 |
+
"""Load a pretrained export from a local directory or the Hub."""
|
| 146 |
+
|
| 147 |
+
model_dir = _resolve_model_dir(
|
| 148 |
+
path_or_repo_id,
|
| 149 |
+
revision=revision,
|
| 150 |
+
cache_dir=cache_dir,
|
| 151 |
+
)
|
| 152 |
+
config = DinacAEConfig.load(model_dir / "config.json")
|
| 153 |
+
model = cls(config)
|
| 154 |
+
safetensors_path = model_dir / "model.safetensors"
|
| 155 |
+
if safetensors_path.exists():
|
| 156 |
+
try:
|
| 157 |
+
from safetensors.torch import load_file
|
| 158 |
+
except ImportError as exc:
|
| 159 |
+
raise ImportError(
|
| 160 |
+
"safetensors is required to load model.safetensors"
|
| 161 |
+
) from exc
|
| 162 |
+
state_dict = load_file(str(safetensors_path), device="cpu")
|
| 163 |
+
else:
|
| 164 |
+
raise FileNotFoundError(
|
| 165 |
+
f"No model weights found in {model_dir}. Expected model.safetensors."
|
| 166 |
+
)
|
| 167 |
+
model.load_state_dict(state_dict, strict=True)
|
| 168 |
+
model = model.to(dtype=dtype, device=torch.device(device))
|
| 169 |
+
model.eval()
|
| 170 |
+
return model
|
| 171 |
+
|
| 172 |
+
def _latent_norm_stats(self) -> tuple[Tensor, Tensor]:
|
| 173 |
+
"""Return ``(mean, std)`` tensors for latent whitening."""
|
| 174 |
+
|
| 175 |
+
mean = self.latent_norm_running_mean.view(1, -1, 1, 1)
|
| 176 |
+
var = self.latent_norm_running_var.view(1, -1, 1, 1)
|
| 177 |
+
std = torch.sqrt(
|
| 178 |
+
var.to(torch.float32) + float(self.config.latent_running_stats_eps)
|
| 179 |
+
)
|
| 180 |
+
return mean.to(torch.float32), std
|
| 181 |
+
|
| 182 |
+
def _require_image_size_divisible(self, height: int, width: int) -> None:
|
| 183 |
+
"""Require image dimensions compatible with the exported patch size."""
|
| 184 |
+
|
| 185 |
+
patch = int(self.config.effective_patch_size)
|
| 186 |
+
if int(height) % patch != 0 or int(width) % patch != 0:
|
| 187 |
+
raise ValueError(
|
| 188 |
+
f"Image height={height} and width={width} must be divisible by "
|
| 189 |
+
f"effective_patch_size={patch}"
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
def whiten(self, latents: Tensor) -> Tensor:
|
| 193 |
+
"""Whiten raw latents using exported running stats."""
|
| 194 |
+
|
| 195 |
+
z = latents.to(torch.float32)
|
| 196 |
+
mean, std = self._latent_norm_stats()
|
| 197 |
+
return (z - mean.to(device=z.device)) / std.to(device=z.device)
|
| 198 |
+
|
| 199 |
+
def dewhiten(self, latents: Tensor) -> Tensor:
|
| 200 |
+
"""Undo latent whitening back to the raw decoder scale."""
|
| 201 |
+
|
| 202 |
+
z = latents.to(torch.float32)
|
| 203 |
+
mean, std = self._latent_norm_stats()
|
| 204 |
+
return z * std.to(device=z.device) + mean.to(device=z.device)
|
| 205 |
+
|
| 206 |
+
def encode(self, images: Tensor) -> Tensor:
|
| 207 |
+
"""Encode images to the exported whitened latent space."""
|
| 208 |
+
|
| 209 |
+
self._require_image_size_divisible(
|
| 210 |
+
height=int(images.shape[2]),
|
| 211 |
+
width=int(images.shape[3]),
|
| 212 |
+
)
|
| 213 |
+
model_dtype = next(self.parameters()).dtype
|
| 214 |
+
latents = self.encoder(images.to(dtype=model_dtype))
|
| 215 |
+
return self.whiten(latents).to(dtype=model_dtype)
|
| 216 |
+
|
| 217 |
+
def encode_posterior(self, images: Tensor) -> EncoderPosterior:
|
| 218 |
+
"""Encode images and return the raw posterior."""
|
| 219 |
+
|
| 220 |
+
self._require_image_size_divisible(
|
| 221 |
+
height=int(images.shape[2]),
|
| 222 |
+
width=int(images.shape[3]),
|
| 223 |
+
)
|
| 224 |
+
model_dtype = next(self.parameters()).dtype
|
| 225 |
+
return self.encoder.encode_posterior(images.to(dtype=model_dtype))
|
| 226 |
+
|
| 227 |
+
def predict_class(self, latents: Tensor) -> Tensor:
|
| 228 |
+
"""Predict the exported DINO class token from whitened latents."""
|
| 229 |
+
|
| 230 |
+
dewhitened = self.dewhiten(latents)
|
| 231 |
+
t_zero = torch.zeros(
|
| 232 |
+
(int(latents.shape[0]),),
|
| 233 |
+
device=latents.device,
|
| 234 |
+
dtype=torch.float32,
|
| 235 |
+
)
|
| 236 |
+
head_dtype = self.dino_token_alignment_head.in_proj.weight.dtype
|
| 237 |
+
device_type = "cuda" if latents.device.type == "cuda" else "cpu"
|
| 238 |
+
with torch.autocast(device_type=device_type, enabled=False):
|
| 239 |
+
out = self.dino_token_alignment_head(
|
| 240 |
+
dewhitened.to(device=latents.device, dtype=head_dtype),
|
| 241 |
+
t=t_zero,
|
| 242 |
+
)
|
| 243 |
+
return out.class_token.to(torch.float32)
|
| 244 |
+
|
| 245 |
+
def decode(
|
| 246 |
+
self,
|
| 247 |
+
latents: Tensor,
|
| 248 |
+
height: int,
|
| 249 |
+
width: int,
|
| 250 |
+
*,
|
| 251 |
+
inference_config: DinacAEInferenceConfig | None = None,
|
| 252 |
+
) -> Tensor:
|
| 253 |
+
"""Decode exported whitened latents to images via VP diffusion."""
|
| 254 |
+
|
| 255 |
+
cfg = (
|
| 256 |
+
inference_config
|
| 257 |
+
if inference_config is not None
|
| 258 |
+
else DinacAEInferenceConfig()
|
| 259 |
+
)
|
| 260 |
+
self._require_image_size_divisible(height=int(height), width=int(width))
|
| 261 |
+
batch = int(latents.shape[0])
|
| 262 |
+
device = latents.device
|
| 263 |
+
model_dtype = next(self.parameters()).dtype
|
| 264 |
+
decoder_latents = self.dewhiten(latents).to(device=device, dtype=model_dtype)
|
| 265 |
+
noise = sample_noise(
|
| 266 |
+
(batch, int(self.config.in_channels), int(height), int(width)),
|
| 267 |
+
noise_std=float(self.config.pixel_noise_std),
|
| 268 |
+
seed=cfg.seed,
|
| 269 |
+
device=torch.device("cpu"),
|
| 270 |
+
dtype=torch.float32,
|
| 271 |
+
)
|
| 272 |
+
schedule = get_schedule(cfg.schedule, cfg.num_steps).to(device=device)
|
| 273 |
+
initial_state = make_initial_state(
|
| 274 |
+
noise=noise.to(device=device),
|
| 275 |
+
t_start=schedule[0:1],
|
| 276 |
+
logsnr_min=float(self.config.logsnr_min),
|
| 277 |
+
logsnr_max=float(self.config.logsnr_max),
|
| 278 |
+
)
|
| 279 |
+
device_type = "cuda" if device.type == "cuda" else "cpu"
|
| 280 |
+
with torch.autocast(device_type=device_type, enabled=False):
|
| 281 |
+
|
| 282 |
+
def _forward_fn(
|
| 283 |
+
x_t: Tensor,
|
| 284 |
+
t: Tensor,
|
| 285 |
+
latents_in: Tensor,
|
| 286 |
+
*,
|
| 287 |
+
drop_middle_blocks: bool = False,
|
| 288 |
+
mask_latent_tokens: bool = False,
|
| 289 |
+
) -> Tensor:
|
| 290 |
+
_ = mask_latent_tokens  # accepted to match the DecoderForwardFn signature; unused here
|
| 291 |
+
return self.decoder(
|
| 292 |
+
x_t.to(dtype=model_dtype),
|
| 293 |
+
t,
|
| 294 |
+
latents_in.to(dtype=model_dtype),
|
| 295 |
+
drop_middle_blocks=bool(drop_middle_blocks),
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
match cfg.sampler:
|
| 299 |
+
case "ddim":
|
| 300 |
+
sampler_fn = run_ddim
|
| 301 |
+
case "dpmpp_2m":
|
| 302 |
+
sampler_fn = run_dpmpp_2m
|
| 303 |
+
case _ as unreachable:
|
| 304 |
+
raise ValueError(f"Unsupported sampler: {unreachable!r}")
|
| 305 |
+
pdg_mode = "path_drop" if bool(cfg.pdg) else "disabled"
|
| 306 |
+
return sampler_fn(
|
| 307 |
+
forward_fn=_forward_fn,
|
| 308 |
+
initial_state=initial_state,
|
| 309 |
+
schedule=schedule,
|
| 310 |
+
latents=decoder_latents,
|
| 311 |
+
logsnr_min=float(self.config.logsnr_min),
|
| 312 |
+
logsnr_max=float(self.config.logsnr_max),
|
| 313 |
+
pdg_mode=pdg_mode,
|
| 314 |
+
pdg_strength=float(cfg.pdg_strength),
|
| 315 |
+
device=device,
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
def reconstruct(
|
| 319 |
+
self,
|
| 320 |
+
images: Tensor,
|
| 321 |
+
*,
|
| 322 |
+
inference_config: DinacAEInferenceConfig | None = None,
|
| 323 |
+
) -> Tensor:
|
| 324 |
+
"""Encode then decode one image batch."""
|
| 325 |
+
|
| 326 |
+
latents = self.encode(images)
|
| 327 |
+
_batch, _channels, height, width = images.shape
|
| 328 |
+
return self.decode(
|
| 329 |
+
latents,
|
| 330 |
+
height=int(height),
|
| 331 |
+
width=int(width),
|
| 332 |
+
inference_config=inference_config,
|
| 333 |
+
)
|
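`DinacAE.whiten` and `DinacAE.dewhiten` above are per-channel affine maps built from the exported running mean and variance, and each is the inverse of the other. A minimal sketch for a single channel follows, using plain floats. The default `eps` is a placeholder assumption, since the real value comes from `config.latent_running_stats_eps`, which is not shown here.

```python
import math
from typing import Sequence


def whiten(values: Sequence[float], mean: float, var: float, eps: float = 1e-5) -> list[float]:
    """Map raw latents to zero-mean, unit-variance scale for one channel."""
    std = math.sqrt(var + eps)
    return [(v - mean) / std for v in values]


def dewhiten(values: Sequence[float], mean: float, var: float, eps: float = 1e-5) -> list[float]:
    """Invert whiten(): map whitened latents back to the raw decoder scale."""
    std = math.sqrt(var + eps)
    return [v * std + mean for v in values]
```

This is why `decode` and `predict_class` call `dewhiten` first: both take whitened latents at the API boundary but run the decoder and token head on the raw scale.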
dinac_ae/norms.py
ADDED
|
@@ -0,0 +1,39 @@
| 1 |
+
"""Channel-wise RMSNorm for NCHW tensors."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import Tensor, nn
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class ChannelWiseRMSNorm(nn.Module):
|
| 10 |
+
"""Channel-wise RMSNorm with float32 reduction for numerical stability.
|
| 11 |
+
|
| 12 |
+
Normalizes across channels per spatial position. Supports optional
|
| 13 |
+
per-channel affine weight and bias.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
def __init__(self, channels: int, eps: float = 1e-6, affine: bool = True) -> None:
|
| 17 |
+
super().__init__()
|
| 18 |
+
self.channels: int = int(channels)
|
| 19 |
+
self._eps: float = float(eps)
|
| 20 |
+
if affine:
|
| 21 |
+
self.weight = nn.Parameter(torch.ones(self.channels))
|
| 22 |
+
self.bias = nn.Parameter(torch.zeros(self.channels))
|
| 23 |
+
else:
|
| 24 |
+
self.register_parameter("weight", None)
|
| 25 |
+
self.register_parameter("bias", None)
|
| 26 |
+
|
| 27 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 28 |
+
if x.dim() < 2:
|
| 29 |
+
return x
|
| 30 |
+
# Float32 accumulation for stability
|
| 31 |
+
ms = torch.mean(torch.square(x), dim=1, keepdim=True, dtype=torch.float32)
|
| 32 |
+
inv_rms = torch.rsqrt(ms + self._eps)
|
| 33 |
+
y = x * inv_rms.to(dtype=x.dtype)
|
| 34 |
+
if self.weight is not None:
|
| 35 |
+
shape = (1, -1) + (1,) * (x.dim() - 2)
|
| 36 |
+
y = y * self.weight.view(shape).to(dtype=x.dtype)
|
| 37 |
+
if self.bias is not None:
|
| 38 |
+
y = y + self.bias.view(shape).to(dtype=x.dtype)
|
| 39 |
+
return y
|
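`ChannelWiseRMSNorm` above normalizes the channel vector at each spatial position by its root mean square. A minimal sketch for one spatial position follows, without the optional affine weight and bias. The function name `channel_rms_norm` is illustrative only.

```python
import math
from typing import Sequence


def channel_rms_norm(pixel: Sequence[float], eps: float = 1e-6) -> list[float]:
    """Divide one position's channel vector by sqrt(mean(x**2) + eps)."""
    mean_square = sum(v * v for v in pixel) / len(pixel)
    inv_rms = 1.0 / math.sqrt(mean_square + eps)
    return [v * inv_rms for v in pixel]
```

After normalization the mean square over channels is close to 1, and the `eps` term keeps an all-zero vector finite.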
dinac_ae/samplers.py
ADDED
|
@@ -0,0 +1,258 @@
"""DDIM and DPM++2M samplers for VP diffusion with path-drop PDG support."""

from __future__ import annotations

from typing import Protocol

import torch
from torch import Tensor

from .vp_diffusion import (
    alpha_sigma_from_logsnr,
    broadcast_time_like,
    shifted_cosine_interpolated_logsnr_from_t,
)


class DecoderForwardFn(Protocol):
    """Callable that predicts x0 from (x_t, t, latents) with path-drop PDG flag."""

    def __call__(
        self,
        x_t: Tensor,
        t: Tensor,
        latents: Tensor,
        *,
        drop_middle_blocks: bool = False,
        mask_latent_tokens: bool = False,
    ) -> Tensor: ...


def _reconstruct_eps_from_x0(
    *, x_t: Tensor, x0_hat: Tensor, alpha: Tensor, sigma: Tensor
) -> Tensor:
    """Reconstruct eps_hat from (x_t, x0_hat) under VP parameterization.

    eps_hat = (x_t - alpha * x0_hat) / sigma. All float32.
    """
    alpha_view = broadcast_time_like(alpha, x_t).to(dtype=torch.float32)
    sigma_view = broadcast_time_like(sigma, x_t).to(dtype=torch.float32)
    x_t_f32 = x_t.to(torch.float32)
    x0_f32 = x0_hat.to(torch.float32)
    return (x_t_f32 - alpha_view * x0_f32) / sigma_view


def _ddim_step(
    *,
    x0_hat: Tensor,
    eps_hat: Tensor,
    alpha_next: Tensor,
    sigma_next: Tensor,
    ref: Tensor,
) -> Tensor:
    """DDIM step: x_next = alpha_next * x0_hat + sigma_next * eps_hat."""
    a = broadcast_time_like(alpha_next, ref).to(dtype=torch.float32)
    s = broadcast_time_like(sigma_next, ref).to(dtype=torch.float32)
    return a * x0_hat + s * eps_hat


def _predict_with_pdg(
    forward_fn: DecoderForwardFn,
    state: Tensor,
    t_vec: Tensor,
    latents: Tensor,
    *,
    pdg_mode: str,
    pdg_strength: float,
) -> Tensor:
    """Run decoder forward with optional PDG guidance.

    Args:
        forward_fn: Decoder forward function.
        state: Current noised state [B, C, H, W].
        t_vec: Timestep vector [B].
        latents: Encoder latents.
        pdg_mode: "disabled" or "path_drop".
        pdg_strength: CFG-like strength for PDG.

    Returns:
        x0_hat prediction in float32.
    """
    match pdg_mode:
        case "path_drop":
            x0_uncond = forward_fn(state, t_vec, latents, drop_middle_blocks=True).to(
                torch.float32
            )
            x0_cond = forward_fn(state, t_vec, latents, drop_middle_blocks=False).to(
                torch.float32
            )
            return x0_uncond + pdg_strength * (x0_cond - x0_uncond)
        case "disabled":
            return forward_fn(state, t_vec, latents, drop_middle_blocks=False).to(
                torch.float32
            )
        case _ as unreachable:
            raise ValueError(f"Unsupported PDG mode: {unreachable!r}")


def run_ddim(
    *,
    forward_fn: DecoderForwardFn,
    initial_state: Tensor,
    schedule: Tensor,
    latents: Tensor,
    logsnr_min: float,
    logsnr_max: float,
    log_change_high: float = 0.0,
    log_change_low: float = 0.0,
    pdg_mode: str = "disabled",
    pdg_strength: float = 1.5,
    device: torch.device | None = None,
) -> Tensor:
    """Run DDIM sampling loop with path-drop PDG support.

    Args:
        forward_fn: Decoder forward function (x_t, t, latents) -> x0_hat.
        initial_state: Starting noised state [B, C, H, W] in float32.
        schedule: Descending t-schedule [num_steps + 1] in [0, 1]; one decoder
            call is made per consecutive pair of time points.
        latents: Encoder latents [B, bottleneck_dim, h, w].
        logsnr_min, logsnr_max: VP schedule endpoints.
        log_change_high, log_change_low: Shifted-cosine schedule parameters.
        pdg_mode: "disabled" or "path_drop".
        pdg_strength: CFG-like strength for PDG.
        device: Target device.

    Returns:
        Denoised samples [B, C, H, W] in float32.
    """
    run_device = device or initial_state.device
    batch_size = int(initial_state.shape[0])
    state = initial_state.to(device=run_device, dtype=torch.float32)

    # Precompute logSNR, alpha, sigma for all schedule points
    lmb = shifted_cosine_interpolated_logsnr_from_t(
        schedule.to(device=run_device),
        logsnr_min=logsnr_min,
        logsnr_max=logsnr_max,
        log_change_high=log_change_high,
        log_change_low=log_change_low,
    )
    alpha_sched, sigma_sched = alpha_sigma_from_logsnr(lmb)

    for i in range(int(schedule.numel()) - 1):
        t_i = schedule[i]
        a_t = alpha_sched[i].expand(batch_size)
        s_t = sigma_sched[i].expand(batch_size)
        a_next = alpha_sched[i + 1].expand(batch_size)
        s_next = sigma_sched[i + 1].expand(batch_size)

        # Model prediction with optional PDG
        t_vec = t_i.expand(batch_size).to(device=run_device, dtype=torch.float32)
        x0_hat = _predict_with_pdg(
            forward_fn,
            state,
            t_vec,
            latents,
            pdg_mode=pdg_mode,
            pdg_strength=pdg_strength,
        )

        eps_hat = _reconstruct_eps_from_x0(
            x_t=state, x0_hat=x0_hat, alpha=a_t, sigma=s_t
        )
        state = _ddim_step(
            x0_hat=x0_hat,
            eps_hat=eps_hat,
            alpha_next=a_next,
            sigma_next=s_next,
            ref=state,
        )

    return state


def run_dpmpp_2m(
    *,
    forward_fn: DecoderForwardFn,
    initial_state: Tensor,
    schedule: Tensor,
    latents: Tensor,
    logsnr_min: float,
    logsnr_max: float,
    log_change_high: float = 0.0,
    log_change_low: float = 0.0,
    pdg_mode: str = "disabled",
    pdg_strength: float = 1.5,
    device: torch.device | None = None,
) -> Tensor:
    """Run DPM++2M sampling loop with path-drop PDG support.

    Multi-step solver using exponential integrator formulation in half-lambda space.
    """
    run_device = device or initial_state.device
    batch_size = int(initial_state.shape[0])
    state = initial_state.to(device=run_device, dtype=torch.float32)

    # Precompute logSNR, alpha, sigma, half-lambda for all schedule points
    lmb = shifted_cosine_interpolated_logsnr_from_t(
        schedule.to(device=run_device),
        logsnr_min=logsnr_min,
        logsnr_max=logsnr_max,
        log_change_high=log_change_high,
        log_change_low=log_change_low,
    )
    alpha_sched, sigma_sched = alpha_sigma_from_logsnr(lmb)
    half_lambda = 0.5 * lmb.to(torch.float32)

    x0_prev: Tensor | None = None

    for i in range(int(schedule.numel()) - 1):
        t_i = schedule[i]
        s_t = sigma_sched[i].expand(batch_size)
        a_next = alpha_sched[i + 1].expand(batch_size)
        s_next = sigma_sched[i + 1].expand(batch_size)

        # Model prediction with optional PDG
        t_vec = t_i.expand(batch_size).to(device=run_device, dtype=torch.float32)
        x0_hat = _predict_with_pdg(
            forward_fn,
            state,
            t_vec,
            latents,
            pdg_mode=pdg_mode,
            pdg_strength=pdg_strength,
        )

        lam_t = half_lambda[i].expand(batch_size)
        lam_next = half_lambda[i + 1].expand(batch_size)
        h = (lam_next - lam_t).to(torch.float32)
        phi_1 = torch.expm1(-h)

        sigma_ratio = (s_next / s_t).to(torch.float32)

        if i == 0 or x0_prev is None:
            # First-order step
            state = (
                sigma_ratio.view(-1, *([1] * (state.dim() - 1))) * state
                - broadcast_time_like(a_next, state).to(torch.float32)
                * broadcast_time_like(phi_1, state).to(torch.float32)
                * x0_hat
            )
        else:
            # Second-order step
            lam_prev = half_lambda[i - 1].expand(batch_size)
            h_0 = (lam_t - lam_prev).to(torch.float32)
            r0 = h_0 / h
            d1_0 = (x0_hat - x0_prev) / broadcast_time_like(r0, x0_hat)
            common = broadcast_time_like(a_next, state).to(
                torch.float32
            ) * broadcast_time_like(phi_1, state).to(torch.float32)
            state = (
                sigma_ratio.view(-1, *([1] * (state.dim() - 1))) * state
                - common * x0_hat
                - 0.5 * common * d1_0
            )

        x0_prev = x0_hat

    return state
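The two samplers share their first step: with an x0-prediction model, the first-order DPM++ update written in half-logSNR space is algebraically the same as the DDIM update. The following scalar, stdlib-only sketch checks this numerically and also shows the PDG blend. The function names and the sample values are illustrative assumptions, not part of the exported package.

```python
import math


def alpha_sigma(logsnr: float) -> tuple[float, float]:
    """VP coefficients: alpha^2 = sigmoid(logsnr), sigma^2 = sigmoid(-logsnr)."""
    alpha = math.sqrt(1.0 / (1.0 + math.exp(-logsnr)))
    sigma = math.sqrt(1.0 / (1.0 + math.exp(logsnr)))
    return alpha, sigma


def ddim_step(x_t: float, x0_hat: float, logsnr_t: float, logsnr_next: float) -> float:
    """x_next = alpha_next * x0_hat + sigma_next * eps_hat."""
    a_t, s_t = alpha_sigma(logsnr_t)
    a_n, s_n = alpha_sigma(logsnr_next)
    eps_hat = (x_t - a_t * x0_hat) / s_t
    return a_n * x0_hat + s_n * eps_hat


def dpmpp_first_order_step(
    x_t: float, x0_hat: float, logsnr_t: float, logsnr_next: float
) -> float:
    """x_next = (sigma_next / sigma_t) * x_t - alpha_next * expm1(-h) * x0_hat."""
    _, s_t = alpha_sigma(logsnr_t)
    a_n, s_n = alpha_sigma(logsnr_next)
    h = 0.5 * logsnr_next - 0.5 * logsnr_t  # step size in half-lambda space
    return (s_n / s_t) * x_t - a_n * math.expm1(-h) * x0_hat


def pdg_combine(x0_uncond: float, x0_cond: float, strength: float) -> float:
    """CFG-like blend used for path-drop guidance."""
    return x0_uncond + strength * (x0_cond - x0_uncond)


x_ddim = ddim_step(0.8, 0.3, logsnr_t=-2.0, logsnr_next=1.0)
x_dpm = dpmpp_first_order_step(0.8, 0.3, logsnr_t=-2.0, logsnr_next=1.0)
blend_at_one = pdg_combine(0.2, 0.5, 1.0)  # strength 1.0 recovers the conditional prediction
```

The second-order correction in `run_dpmpp_2m` only enters from the second step onward, once a previous `x0_hat` is available.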
dinac_ae/straight_through_encoder.py
ADDED
|
@@ -0,0 +1,57 @@
"""Patch embedding used by the exported DINAC-AE model."""

from __future__ import annotations

from typing import Final

from torch import Tensor, nn

__all__ = ["Patchify", "StraightThroughEncoder"]


class StraightThroughEncoder(nn.Module):
    """Project non-overlapping image patches with a stride-patch convolution."""

    def __init__(
        self,
        in_channels: int,
        patch: int,
        out_channels: int,
    ) -> None:
        super().__init__()
        if in_channels <= 0:
            raise ValueError("in_channels must be positive")
        if patch <= 0:
            raise ValueError("patch must be positive")
        if out_channels <= 0:
            raise ValueError("out_channels must be positive")
        self.in_channels: Final[int] = int(in_channels)
        self.patch: Final[int] = int(patch)
        self._output_channels: Final[int] = int(out_channels)
        self.proj = nn.Conv2d(
            self.in_channels,
            self._output_channels,
            kernel_size=self.patch,
            stride=self.patch,
            bias=True,
        )

    def forward(self, x: Tensor) -> Tensor:  # type: ignore[override]
        """Return the patchified token grid."""

        return self.proj(x)

    @property
    def output_channels(self) -> int:
        """Return the output channel width produced by the encoder."""

        return int(self._output_channels)

    @property
    def latent_channels(self) -> int:
        """Alias for ``output_channels`` to match encoder interface shape."""

        return int(self._output_channels)


Patchify = StraightThroughEncoder
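As a quick illustration of what a stride-patch convolution does to the spatial grid, this stdlib-only sketch applies standard convolution arithmetic for `kernel_size == stride == patch` with no padding. The helper `patch_grid_shape` and the example sizes are hypothetical and are not taken from the exported config.

```python
def patch_grid_shape(height: int, width: int, patch: int) -> tuple[int, int]:
    """Spatial size of a Conv2d output with kernel_size == stride == patch, no padding."""
    if patch <= 0:
        raise ValueError("patch must be positive")
    # Standard convolution arithmetic: floor((size - kernel) / stride) + 1.
    return (height - patch) // patch + 1, (width - patch) // patch + 1


grid = patch_grid_shape(256, 256, 16)    # evenly divisible input
ragged = patch_grid_shape(250, 256, 16)  # trailing rows that do not fill a patch are dropped
```

Inputs whose sides are not multiples of the patch size lose their trailing rows or columns, so callers would typically pad or crop beforehand.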
dinac_ae/time_embed.py
ADDED
|
@@ -0,0 +1,83 @@
"""Sinusoidal timestep embedding with MLP projection."""

from __future__ import annotations

import math

import torch
from torch import Tensor, nn


def _log_spaced_frequencies(
    half: int, max_period: float, *, device: torch.device | None = None
) -> Tensor:
    """Log-spaced frequencies for sinusoidal embedding."""
    return torch.exp(
        -math.log(max_period)
        * torch.arange(half, device=device, dtype=torch.float32)
        / max(float(half - 1), 1.0)
    )


def sinusoidal_time_embedding(
    t: Tensor,
    dim: int,
    *,
    max_period: float = 10000.0,
    scale: float | None = None,
    freqs: Tensor | None = None,
) -> Tensor:
    """Sinusoidal timestep embedding (DDPM/DiT-style). Always float32."""
    t32 = t.to(torch.float32)
    if scale is not None:
        t32 = t32 * float(scale)
    half = dim // 2
    if freqs is not None:
        freqs = freqs.to(device=t32.device, dtype=torch.float32)
    else:
        freqs = _log_spaced_frequencies(half, max_period, device=t32.device)
    angles = t32[:, None] * freqs[None, :]
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)


class SinusoidalTimeEmbeddingMLP(nn.Module):
    """Sinusoidal time embedding followed by Linear -> SiLU -> Linear."""

    def __init__(
        self,
        dim: int,
        *,
        freq_dim: int = 256,
        hidden_mult: float = 1.0,
        time_scale: float = 1000.0,
        max_period: float = 10000.0,
    ) -> None:
        super().__init__()
        self.dim = int(dim)
        self.freq_dim = int(freq_dim)
        self.time_scale = float(time_scale)
        self.max_period = float(max_period)
        hidden_dim = max(int(round(int(dim) * float(hidden_mult))), 1)

        freqs = _log_spaced_frequencies(self.freq_dim // 2, self.max_period)
        self.register_buffer("freqs", freqs, persistent=True)

        self.proj_in = nn.Linear(self.freq_dim, hidden_dim)
        self.act = nn.SiLU()
        self.proj_out = nn.Linear(hidden_dim, self.dim)

    def forward(self, t: Tensor) -> Tensor:
        freqs: Tensor = self.freqs  # type: ignore[assignment]
        emb_freq = sinusoidal_time_embedding(
            t.to(torch.float32),
            self.freq_dim,
            max_period=self.max_period,
            scale=self.time_scale,
            freqs=freqs,
        )
        dtype_in = self.proj_in.weight.dtype
        hidden = self.proj_in(emb_freq.to(dtype_in))
        hidden = self.act(hidden)
        if hidden.dtype != self.proj_out.weight.dtype:
            hidden = hidden.to(self.proj_out.weight.dtype)
        return self.proj_out(hidden)
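The frequency layout above can be reproduced for a single scalar timestep without PyTorch. This is a minimal sketch that mirrors the same formulas with `math`; the function names are hypothetical, and the default `scale=1000.0` is assumed only because it matches the `time_scale` default in `SinusoidalTimeEmbeddingMLP`.

```python
import math


def log_spaced_frequencies(half: int, max_period: float) -> list[float]:
    """Frequencies from 1 down to 1 / max_period, evenly spaced in log space."""
    denom = max(float(half - 1), 1.0)
    return [math.exp(-math.log(max_period) * i / denom) for i in range(half)]


def sinusoidal_embedding(
    t: float, dim: int, *, max_period: float = 10000.0, scale: float = 1000.0
) -> list[float]:
    """Concatenate sin and cos of the scaled timestep at each frequency."""
    freqs = log_spaced_frequencies(dim // 2, max_period)
    angles = [t * scale * f for f in freqs]
    return [math.sin(a) for a in angles] + [math.cos(a) for a in angles]


freqs = log_spaced_frequencies(4, 10000.0)
emb_at_zero = sinusoidal_embedding(0.0, 8)  # all sines are 0, all cosines are 1
```

Note that the sine half comes first and the cosine half second, so an odd `dim` would yield `dim - 1` features under this layout.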
dinac_ae/vp_diffusion.py
ADDED
|
@@ -0,0 +1,152 @@
"""VP diffusion math: logSNR schedules, alpha/sigma computation, noise construction."""

from __future__ import annotations

import math

import torch
import torch.nn.functional as F
from torch import Tensor


def alpha_sigma_from_logsnr(lmb: Tensor) -> tuple[Tensor, Tensor]:
    """Compute (alpha, sigma) from logSNR in float32.

    VP constraint: alpha^2 + sigma^2 = 1.
    """
    lmb32 = lmb.to(dtype=torch.float32)
    alpha = torch.exp(0.5 * F.logsigmoid(lmb32))
    sigma = torch.exp(0.5 * F.logsigmoid(-lmb32))
    return alpha, sigma


def broadcast_time_like(coeff: Tensor, x: Tensor) -> Tensor:
    """Broadcast [B] coefficient to match x for per-sample scaling."""
    view_shape = (int(x.shape[0]),) + (1,) * (x.dim() - 1)
    return coeff.view(view_shape)


def _cosine_interpolated_params(
    logsnr_min: float, logsnr_max: float
) -> tuple[float, float]:
    """Compute (a, b) for cosine-interpolated logSNR schedule.

    logsnr(t) = -2 * log(tan(a*t + b))
    logsnr(0) = logsnr_max, logsnr(1) = logsnr_min
    """
    b = math.atan(math.exp(-0.5 * logsnr_max))
    a = math.atan(math.exp(-0.5 * logsnr_min)) - b
    return a, b


def cosine_interpolated_logsnr_from_t(
    t: Tensor, *, logsnr_min: float, logsnr_max: float
) -> Tensor:
    """Map t in [0,1] to logSNR via cosine-interpolated schedule. Always float32."""
    a, b = _cosine_interpolated_params(logsnr_min, logsnr_max)
    t32 = t.to(dtype=torch.float32)
    a_t = torch.tensor(a, device=t32.device, dtype=torch.float32)
    b_t = torch.tensor(b, device=t32.device, dtype=torch.float32)
    u = a_t * t32 + b_t
    return -2.0 * torch.log(torch.tan(u))


def shifted_cosine_interpolated_logsnr_from_t(
    t: Tensor,
    *,
    logsnr_min: float,
    logsnr_max: float,
    log_change_high: float = 0.0,
    log_change_low: float = 0.0,
) -> Tensor:
    """SiD2 "shifted cosine" schedule: logSNR with resolution-dependent shifts.

    lambda(t) = (1-t) * (base(t) + log_change_high) + t * (base(t) + log_change_low)
    """
    base = cosine_interpolated_logsnr_from_t(
        t, logsnr_min=logsnr_min, logsnr_max=logsnr_max
    )
    t32 = t.to(dtype=torch.float32)
    high = base + float(log_change_high)
    low = base + float(log_change_low)
    return (1.0 - t32) * high + t32 * low


def get_schedule(schedule_type: str, num_steps: int) -> Tensor:
    """Generate a descending t-schedule in [0, 1] for VP diffusion sampling.

    ``num_steps`` is the number of function evaluations (NFE = decoder forward
    passes). Internally the schedule has ``num_steps + 1`` time points
    (including both endpoints).

    Args:
        schedule_type: "linear" or "cosine".
        num_steps: Number of decoder forward passes (NFE), >= 1.

    Returns:
        Descending 1D tensor with ``num_steps + 1`` elements from ~1.0 to ~0.0.
    """
    if int(num_steps) < 1:
        raise ValueError("num_steps must be at least 1")
    n = int(num_steps) + 1
    match schedule_type:
        case "linear":
            base = torch.linspace(0.0, 1.0, n)
        case "cosine":
            i = torch.arange(n, dtype=torch.float32)
            base = 0.5 * (1.0 - torch.cos(math.pi * (i / (n - 1))))
        case _ as unreachable:
            raise ValueError(
                f"Unsupported schedule type: {unreachable!r}. "
                "Use 'linear' or 'cosine'."
            )
    # Descending: high t (noisy) -> low t (clean)
    return torch.flip(base, dims=[0])


def make_initial_state(
    *,
    noise: Tensor,
    t_start: Tensor,
    logsnr_min: float,
    logsnr_max: float,
    log_change_high: float = 0.0,
    log_change_low: float = 0.0,
) -> Tensor:
    """Construct VP initial state x_t0 = sigma_start * noise (since x0=0).

    All math in float32.
    """
    batch = int(noise.shape[0])
    lmb_start = shifted_cosine_interpolated_logsnr_from_t(
        t_start.expand(batch).to(dtype=torch.float32),
        logsnr_min=logsnr_min,
        logsnr_max=logsnr_max,
        log_change_high=log_change_high,
        log_change_low=log_change_low,
    )
    _alpha_start, sigma_start = alpha_sigma_from_logsnr(lmb_start)
    sigma_view = broadcast_time_like(sigma_start, noise)
    return sigma_view * noise.to(dtype=torch.float32)


def sample_noise(
    shape: tuple[int, ...],
    *,
    noise_std: float = 1.0,
    seed: int | None = None,
    device: torch.device | None = None,
    dtype: torch.dtype = torch.float32,
) -> Tensor:
    """Sample Gaussian noise with optional seeding. CPU-seeded for reproducibility."""
    if seed is None:
        noise = torch.randn(
            shape, device=device or torch.device("cpu"), dtype=torch.float32
        )
    else:
        gen = torch.Generator(device="cpu")
        gen.manual_seed(int(seed))
        noise = torch.randn(shape, generator=gen, device="cpu", dtype=torch.float32)
    noise = noise.mul(float(noise_std))
    target_device = device if device is not None else torch.device("cpu")
    return noise.to(device=target_device, dtype=dtype)
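To sanity-check the schedule math, the scalar sketch below re-derives the cosine-interpolated logSNR and the VP coefficients with `math` only. The endpoint values `-10.0` and `10.0` are illustrative assumptions, not values read from the exported `config.json`, and the function names are hypothetical.

```python
import math


def cosine_interpolated_logsnr(t: float, logsnr_min: float, logsnr_max: float) -> float:
    """logsnr(t) = -2 * log(tan(a * t + b)) with logsnr(0) = max, logsnr(1) = min."""
    b = math.atan(math.exp(-0.5 * logsnr_max))
    a = math.atan(math.exp(-0.5 * logsnr_min)) - b
    return -2.0 * math.log(math.tan(a * t + b))


def alpha_sigma(logsnr: float) -> tuple[float, float]:
    """VP coefficients: alpha^2 = sigmoid(logsnr), sigma^2 = sigmoid(-logsnr)."""
    alpha = math.sqrt(1.0 / (1.0 + math.exp(-logsnr)))
    sigma = math.sqrt(1.0 / (1.0 + math.exp(logsnr)))
    return alpha, sigma


lo, hi = -10.0, 10.0  # assumed endpoints for illustration only
start = cosine_interpolated_logsnr(0.0, lo, hi)  # clean end of the schedule
end = cosine_interpolated_logsnr(1.0, lo, hi)    # noisy end of the schedule
alpha_mid, sigma_mid = alpha_sigma(cosine_interpolated_logsnr(0.5, lo, hi))
```

With both shift parameters left at their default of 0.0, the shifted-cosine schedule reduces to this base schedule.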
dit/attention_blocks.py
ADDED
|
@@ -0,0 +1,240 @@
"""Dense SDPA attention blocks used by the DINAC-AE export."""

from __future__ import annotations

from collections.abc import Callable

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from common.norms import RMSNorm
from common.rope import rotate_half, rotate_half_adjacent
from dit.position_encoding import DiTPositionEncoding


def _axial_rope_rotate_fn(
    position_encoding: DiTPositionEncoding,
) -> Callable[[Tensor], Tensor]:
    """Return the head-dimension rotation matching the configured RoPE layout."""

    match position_encoding:
        case (
            DiTPositionEncoding.ROPE_2D_AXIAL_DILATED
            | DiTPositionEncoding.ROPE_2D_AXIAL_NORMALIZED
            | DiTPositionEncoding.ROPE_2D_AXIAL_FREQ_AWARE
            | DiTPositionEncoding.ROPE_1D
        ):
            return rotate_half
        case (
            DiTPositionEncoding.ROPE_2D_AXIAL_UNNORMALIZED
            | DiTPositionEncoding.ROPE_2D_AXIAL_UNNORMALIZED_DILATED
            | DiTPositionEncoding.ROPE_2D_AXIAL_BETA_WARP
            | DiTPositionEncoding.ROPE_2D_AXIAL_ALPHA_WARP
            | DiTPositionEncoding.ROPE_3D_ZIMAGE
        ):
            return rotate_half_adjacent
        case _ as unreachable:
            raise ValueError(f"Unsupported RoPE position encoding: {unreachable}")


class DitSelfAttentionCore(nn.Module):
    """Dense self-attention core with optional axial RoPE on Q/K."""

    d_model: int
    n_heads: int
    head_dim: int
    position_encoding: DiTPositionEncoding
    qkv: nn.Linear
    proj_out: nn.Linear
    q_norm: RMSNorm
    k_norm: RMSNorm

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        *,
        position_encoding: DiTPositionEncoding,
    ) -> None:
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads")
        self.d_model = int(d_model)
        self.n_heads = int(n_heads)
        self.head_dim = int(self.d_model // self.n_heads)
        self.position_encoding = position_encoding
        self.qkv = nn.Linear(self.d_model, 3 * self.d_model, bias=False)
        self.proj_out = nn.Linear(self.d_model, self.d_model, bias=False)
        self.q_norm = RMSNorm(self.head_dim)
        self.k_norm = RMSNorm(self.head_dim)

    def reset_parameters(self) -> None:
        """Reset projections to their initialization."""

        nn.init.xavier_uniform_(self.qkv.weight)
        nn.init.xavier_uniform_(self.proj_out.weight)

    def forward(
        self, tokens: Tensor, *, rope_sincos: tuple[Tensor, Tensor] | None
    ) -> Tensor:
        """Apply dense self-attention to ``[B, N, D]`` tokens."""

        batch, sequence_length, _width = tokens.shape
        qkv = self.qkv(tokens)
        q, k, v = qkv.chunk(3, dim=-1)
        q = q.view(batch, sequence_length, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch, sequence_length, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch, sequence_length, self.n_heads, self.head_dim).transpose(1, 2)
        q = self.q_norm(q.contiguous())
        k = self.k_norm(k.contiguous())
        q, k = self._apply_axial_rope_dense(q, k, rope_sincos=rope_sincos)
        attn = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
        attn = (
            attn.transpose(1, 2).contiguous().view(batch, sequence_length, self.d_model)
        )
        return self.proj_out(attn)

    def _apply_axial_rope_dense(
        self,
        q: Tensor,
        k: Tensor,
        *,
        rope_sincos: tuple[Tensor, Tensor] | None,
    ) -> tuple[Tensor, Tensor]:
        """Apply axial RoPE to dense Q/K tensors."""

        if rope_sincos is None:
            return q, k
        sin, cos = rope_sincos
        rope_len = int(sin.shape[-2])
        rope_dtype = sin.dtype
        q_dtype = q.dtype
        k_dtype = k.dtype
        q_rope = q.to(dtype=rope_dtype)
        k_rope = k.to(dtype=rope_dtype)
        match sin.dim():
            case 2:
                sin_b = sin.view(1, 1, rope_len, self.head_dim)
                cos_b = cos.view(1, 1, rope_len, self.head_dim)
            case 3:
                sin_b = sin.view(int(q.shape[0]), 1, rope_len, self.head_dim)
                cos_b = cos.view(int(q.shape[0]), 1, rope_len, self.head_dim)
            case _ as unreachable:
                raise ValueError(f"Unsupported RoPE tensor rank: {int(unreachable)}")
        rotate = _axial_rope_rotate_fn(self.position_encoding)
        q_span = q_rope[:, :, :rope_len, :]
        k_span = k_rope[:, :, :rope_len, :]
        q_head = (q_span * cos_b) + (rotate(q_span) * sin_b)
        k_head = (k_span * cos_b) + (rotate(k_span) * sin_b)
        q_rope = torch.cat([q_head, q_rope[:, :, rope_len:, :]], dim=2)
        k_rope = torch.cat([k_head, k_rope[:, :, rope_len:, :]], dim=2)
        return q_rope.to(dtype=q_dtype), k_rope.to(dtype=k_dtype)


class CrossAttentionCore(nn.Module):
    """Dense cross-attention core used by the class-token readout."""

    query_dim: int
    context_dim: int
    context_extra_dim: int
    key_extra_dim: int
    n_heads: int
    head_dim: int
    attn_dim: int
    context_in_dim: int
    attn_dropout: float
    kv_proj: nn.Linear
    k_extra_proj: nn.Linear | None
    out_proj: nn.Linear
    q_norm_heads: RMSNorm
    k_norm_heads: RMSNorm

    def __init__(
        self,
        *,
        query_dim: int,
        context_dim: int,
        n_heads: int,
        head_dim: int,
        context_extra_dim: int = 0,
        key_extra_dim: int = 0,
        attn_dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.query_dim = int(query_dim)
        self.context_dim = int(context_dim)
        self.context_extra_dim = int(context_extra_dim)
        self.key_extra_dim = int(key_extra_dim)
        self.n_heads = int(n_heads)
        self.head_dim = int(head_dim)
        self.attn_dim = int(self.n_heads * self.head_dim)
        self.context_in_dim = int(self.context_dim + self.context_extra_dim)
        self.attn_dropout = float(attn_dropout)
        self.kv_proj = nn.Linear(self.context_in_dim, 2 * self.attn_dim, bias=False)
        if self.key_extra_dim == 0:
            self.k_extra_proj = None
        else:
            self.k_extra_proj = nn.Linear(self.key_extra_dim, self.attn_dim, bias=False)
        self.out_proj = nn.Linear(self.attn_dim, self.query_dim, bias=False)
        self.q_norm_heads = RMSNorm(self.head_dim)
        self.k_norm_heads = RMSNorm(self.head_dim)

    def reset_parameters(self) -> None:
        """Reset projections to their initialization."""

        nn.init.xavier_uniform_(self.kv_proj.weight)
        if self.k_extra_proj is not None:
            nn.init.xavier_uniform_(self.k_extra_proj.weight)
        nn.init.xavier_uniform_(self.out_proj.weight)

    def _split_heads(self, x: Tensor) -> Tensor:
        batch, sequence_length, _width = x.shape
        return x.view(batch, sequence_length, self.n_heads, self.head_dim).transpose(
            1, 2
        )

    def _merge_heads(self, x: Tensor) -> Tensor:
        batch, _heads, sequence_length, _head_dim = x.shape
        return (
            x.transpose(1, 2).contiguous().view(batch, sequence_length, self.attn_dim)
        )

    def forward(
        self,
        q_tokens: Tensor,
        kv_tokens: Tensor,
        *,
        training: bool,
        key_extra: Tensor | None = None,
        key_padding_mask: Tensor | None = None,
    ) -> Tensor:
        """Apply dense cross-attention to query and context tokens."""

        kv = self.kv_proj(kv_tokens)
        k, v = kv.chunk(2, dim=-1)
|
| 216 |
+
if self.k_extra_proj is not None and key_extra is not None:
|
| 217 |
+
k = k + self.k_extra_proj(key_extra)
|
| 218 |
+
q = self.q_norm_heads(self._split_heads(q_tokens).contiguous())
|
| 219 |
+
k = self.k_norm_heads(self._split_heads(k).contiguous())
|
| 220 |
+
v = self._split_heads(v).contiguous()
|
| 221 |
+
if key_padding_mask is None:
|
| 222 |
+
attn_mask = None
|
| 223 |
+
else:
|
| 224 |
+
attn_mask = (~key_padding_mask).to(dtype=q.dtype)
|
| 225 |
+
attn_mask = attn_mask.view(
|
| 226 |
+
key_padding_mask.shape[0], 1, 1, key_padding_mask.shape[1]
|
| 227 |
+
)
|
| 228 |
+
attn_mask = attn_mask.masked_fill(attn_mask > 0, float("-inf"))
|
| 229 |
+
attn = F.scaled_dot_product_attention(
|
| 230 |
+
q,
|
| 231 |
+
k,
|
| 232 |
+
v,
|
| 233 |
+
attn_mask=attn_mask,
|
| 234 |
+
dropout_p=self.attn_dropout if training else 0.0,
|
| 235 |
+
is_causal=False,
|
| 236 |
+
)
|
| 237 |
+
return self.out_proj(self._merge_heads(attn))
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
__all__ = ["CrossAttentionCore", "DitSelfAttentionCore"]
|
dit/axial_rope2d.py
ADDED
|
@@ -0,0 +1,1728 @@
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from dataclasses import dataclass, replace
|
| 5 |
+
from enum import Enum
|
| 6 |
+
from typing import Final, cast
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch import Tensor, nn
|
| 10 |
+
|
| 11 |
+
__all__ = [
|
| 12 |
+
"AxialRoPE2D",
|
| 13 |
+
"AxialRoPE2DAlphaWarpConfig",
|
| 14 |
+
"AxialRoPE2DBetaWarpConfig",
|
| 15 |
+
"AxialRoPE2DConfig",
|
| 16 |
+
"AxialRoPE2DCoordMode",
|
| 17 |
+
"AxialRoPE2DDimLayout",
|
| 18 |
+
"AxialRoPE2DDyPE",
|
| 19 |
+
"AxialRoPE2DDyPEConfig",
|
| 20 |
+
"AxialRoPE2DFrequencyAwareConfig",
|
| 21 |
+
"AxialRoPE2DNormalizeCoords",
|
| 22 |
+
"DyPERoPEMethod",
|
| 23 |
+
"build_axial_rope2d_dype",
|
| 24 |
+
"build_axial_rope2d_inference_warp_with_strength",
|
| 25 |
+
"build_axial_rope2d_with_lumina_frequency_warp",
|
| 26 |
+
"lumina_frequency_aware_periods_for_axis",
|
| 27 |
+
"set_axial_rope2d_dype_noise_time",
|
| 28 |
+
]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class AxialRoPE2DNormalizeCoords(Enum):
|
| 32 |
+
"""Coordinate normalization strategy for axial 2D RoPE (DINOv3-style)."""
|
| 33 |
+
|
| 34 |
+
MIN = "min"
|
| 35 |
+
MAX = "max"
|
| 36 |
+
SEPARATE = "separate"
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class AxialRoPE2DCoordMode(Enum):
|
| 40 |
+
"""Coordinate grid mode for axial 2D RoPE.
|
| 41 |
+
|
| 42 |
+
- ``DINOV3_NORMALIZED``: DINOv3-style normalized patch-centre coordinates in
|
| 43 |
+
``[-1, 1]`` (after normalization).
|
| 44 |
+
- ``PATCH_INDICES``: Standard unnormalized patch-grid coordinates in patch
|
| 45 |
+
units (e.g., ``x in [0, W-1]``, ``y in [0, H-1]``).
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
DINOV3_NORMALIZED = "dinov3_normalized"
|
| 49 |
+
PATCH_INDICES = "patch_indices"
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class AxialRoPE2DDimLayout(Enum):
|
| 53 |
+
"""Layout of angles along the head-dimension.
|
| 54 |
+
|
| 55 |
+
The layout must match the rotation convention used when applying RoPE to Q/K.
|
| 56 |
+
|
| 57 |
+
- ``HALF_SPLIT``: LLaMA-style layout compatible with ``common.rope.rotate_half``
|
| 58 |
+
(splits last dim into two halves).
|
| 59 |
+
- ``PAIR_INTERLEAVED``: EVA-02 / SpeedrunDiT-style layout compatible with an
|
| 60 |
+
adjacent-pair rotate_half (pairs consecutive dims).
|
| 61 |
+
|
| 62 |
+
TODO(refactor): Standardize on ``PAIR_INTERLEAVED`` throughout DiT to reduce
|
| 63 |
+
complexity and avoid layout mismatches, then delete ``HALF_SPLIT`` and any
|
| 64 |
+
related branching once the migration is complete.
|
| 65 |
+
"""
|
| 66 |
+
|
| 67 |
+
HALF_SPLIT = "half_split"
|
| 68 |
+
PAIR_INTERLEAVED = "pair_interleaved"
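A minimal sketch of the two layouts described in the `AxialRoPE2DDimLayout` docstring, written on plain Python lists so it runs without torch. The helper names (`rotate_half_split`, `rotate_pair_interleaved`, `apply_rope`) are hypothetical and are not taken from `common.rope`; the sign convention `[-second, first]` is the common LLaMA-style one and is assumed here rather than confirmed by this file.

```python
import math


def rotate_half_split(x):
    """HALF_SPLIT (assumed convention): [first_half, second_half] -> [-second_half, first_half]."""
    half = len(x) // 2
    return [-v for v in x[half:]] + list(x[:half])


def rotate_pair_interleaved(x):
    """PAIR_INTERLEAVED (assumed convention): each adjacent pair (a, b) -> (-b, a)."""
    out = []
    for i in range(0, len(x), 2):
        a, b = x[i], x[i + 1]
        out.extend([-b, a])
    return out


def apply_rope(x, angles, rotate):
    """Apply x * cos(angle) + rotate(x) * sin(angle) element-wise."""
    rotated = rotate(x)
    return [
        xi * math.cos(a) + ri * math.sin(a)
        for xi, ri, a in zip(x, rotated, angles)
    ]


x = [1.0, 2.0, 3.0, 4.0]
theta0, theta1 = 0.3, 1.1

# The angle layout has to follow the rotation layout:
# half-split pairs dim i with dim i + half, pair-interleaved pairs dim 2i with 2i + 1.
half_split_out = apply_rope(x, [theta0, theta1, theta0, theta1], rotate_half_split)
interleaved_out = apply_rope(x, [theta0, theta0, theta1, theta1], rotate_pair_interleaved)
```

When the angle layout matches the rotation layout, each pair of dimensions undergoes a pure 2D rotation, so the vector norm is preserved. Mixing the layouts (for example interleaved rotation with half-split angles) is no longer a rotation in general, which is the mismatch the docstring warns about.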
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class DyPERoPEMethod(Enum):
|
| 72 |
+
"""Dynamic position extrapolation method applied to inference RoPE."""
|
| 73 |
+
|
| 74 |
+
VISION_YARN = "vision_yarn"
|
| 75 |
+
DY_YARN = "dy_yarn"
|
| 76 |
+
DY_NTK = "dy_ntk"
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@dataclass(frozen=True)
|
| 80 |
+
class AxialRoPE2DDyPEConfig:
|
| 81 |
+
"""Inference-only DyPE controls for axial RoPE.
|
| 82 |
+
|
| 83 |
+
Args:
|
| 84 |
+
method: Dynamic extrapolation rule to apply.
|
| 85 |
+
ref_h_tokens: Training/reference token height.
|
| 86 |
+
ref_w_tokens: Training/reference token width.
|
| 87 |
+
lambda_s: Dynamic extrapolation magnitude.
|
| 88 |
+
lambda_t: Dynamic extrapolation noise-time exponent.
|
| 89 |
+
yarn_beta_0: YaRN first-ramp high rotation threshold.
|
| 90 |
+
yarn_beta_1: YaRN first-ramp low rotation threshold.
|
| 91 |
+
yarn_gamma_0: YaRN base-blend high rotation threshold.
|
| 92 |
+
yarn_gamma_1: YaRN base-blend low rotation threshold.
|
| 93 |
+
yarn_attention_scale: Whether to apply YaRN's static attention magnitude correction.
|
| 94 |
+
"""
|
| 95 |
+
|
| 96 |
+
method: DyPERoPEMethod
|
| 97 |
+
ref_h_tokens: int
|
| 98 |
+
ref_w_tokens: int
|
| 99 |
+
lambda_s: float = 2.0
|
| 100 |
+
lambda_t: float = 2.0
|
| 101 |
+
yarn_beta_0: float = 1.25
|
| 102 |
+
yarn_beta_1: float = 0.75
|
| 103 |
+
yarn_gamma_0: float = 16.0
|
| 104 |
+
yarn_gamma_1: float = 2.0
|
| 105 |
+
yarn_attention_scale: bool = True
|
| 106 |
+
|
| 107 |
+
def __post_init__(self) -> None:
|
| 108 |
+
if not isinstance(self.method, DyPERoPEMethod):
|
| 109 |
+
raise TypeError("method must be a DyPERoPEMethod")
|
| 110 |
+
if int(self.ref_h_tokens) <= 0 or int(self.ref_w_tokens) <= 0:
|
| 111 |
+
raise ValueError("ref_h_tokens and ref_w_tokens must be positive")
|
| 112 |
+
for name, value in (
|
| 113 |
+
("lambda_s", self.lambda_s),
|
| 114 |
+
("lambda_t", self.lambda_t),
|
| 115 |
+
("yarn_beta_0", self.yarn_beta_0),
|
| 116 |
+
("yarn_beta_1", self.yarn_beta_1),
|
| 117 |
+
("yarn_gamma_0", self.yarn_gamma_0),
|
| 118 |
+
("yarn_gamma_1", self.yarn_gamma_1),
|
| 119 |
+
):
|
| 120 |
+
v = float(value)
|
| 121 |
+
if not math.isfinite(v) or v <= 0.0:
|
| 122 |
+
raise ValueError(f"{name} must be finite and > 0")
|
| 123 |
+
if not isinstance(self.yarn_attention_scale, bool):
|
| 124 |
+
raise TypeError("yarn_attention_scale must be a bool")
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
@dataclass(frozen=True)
|
| 128 |
+
class AxialRoPE2DFrequencyAwareConfig:
|
| 129 |
+
"""Lumina/Next-DiT-style frequency-aware RoPE warping for one token grid.
|
| 130 |
+
|
| 131 |
+
This config implements a per-axis, per-band frequency warp that depends on
|
| 132 |
+
the input axis length ``L`` relative to a reference length ``L_ref``:
|
| 133 |
+
|
| 134 |
+
- Define the axis scale ``s = L / L_ref``.
|
| 135 |
+
- RoPE is parameterized by *periods* (wavelengths in tokens) ``period[d]``.
|
| 136 |
+
In this module's axial parameterization (with patch-index coordinates),
|
| 137 |
+
the angle for coordinate ``p`` and band ``d`` is:
|
| 138 |
+
|
| 139 |
+
angle(p, d) = 2π * p / period[d]
|
| 140 |
+
|
| 141 |
+
so the wavelength of band ``d`` is exactly ``period[d]`` tokens.
|
| 142 |
+
|
| 143 |
+
- Pick a *boundary wavelength* ``L_boundary`` (in tokens), expressed as a
|
| 144 |
+
trainable multiplier around the reference length:
|
| 145 |
+
|
| 146 |
+
L_boundary = L_ref * exp(boundary_log_multiplier)
|
| 147 |
+
|
| 148 |
+
The scalar ``boundary_log_multiplier`` is shared across H/W axes (and
|
| 149 |
+
initialized by this config).
|
| 150 |
+
|
| 151 |
+
- Define a (possibly fractional) boundary band index ``d*`` as the band
|
| 152 |
+
whose wavelength equals ``L_boundary``:
|
| 153 |
+
|
| 154 |
+
period(d*) = L_boundary
|
| 155 |
+
|
| 156 |
+
In practice we compute ``d*`` by linear interpolation in log-period space
|
| 157 |
+
(periods are geometric for both supported period parametrizations).
|
| 158 |
+
|
| 159 |
+
- The Lumina/Next-DiT implicit exponent ramp is then:
|
| 160 |
+
|
| 161 |
+
alpha[d] = clamp(d / d*, 0, 1)
|
| 162 |
+
|
| 163 |
+
where:
|
| 164 |
+
- high-frequency bands (small d) have alpha≈0 (extrapolation-like),
|
| 165 |
+
- low-frequency bands (large d) have alpha→1 (interpolation-like),
|
| 166 |
+
- alpha is capped at 1 to ensure we never compress a band more than
|
| 167 |
+
plain position interpolation would.
|
| 168 |
+
|
| 169 |
+
- Finally, warp the periods per axis:
|
| 170 |
+
|
| 171 |
+
period'[d] = period[d] * s ** alpha[d]
|
| 172 |
+
|
| 173 |
+
Equivalently, angular frequencies warp as:
|
| 174 |
+
|
| 175 |
+
omega'[d] = omega[d] / s ** alpha[d]
|
| 176 |
+
|
| 177 |
+
Notes
|
| 178 |
+
-----
|
| 179 |
+
- This warp is only meaningful for patch-index coordinates
|
| 180 |
+
(``AxialRoPE2DCoordMode.PATCH_INDICES``). Mixing it with normalized
|
| 181 |
+
coordinates would create an implicit "gauge switch"; we fail fast.
|
| 182 |
+
- The boundary multiplier is trainable by construction (it is stored as an
|
| 183 |
+
nn.Parameter inside AxialRoPE2D when this config is present).
|
| 184 |
+
"""
|
| 185 |
+
|
| 186 |
+
ref_h_tokens: int
|
| 187 |
+
ref_w_tokens: int
|
| 188 |
+
boundary_log_multiplier_init: float
|
| 189 |
+
|
| 190 |
+
def __post_init__(self) -> None:
|
| 191 |
+
if int(self.ref_h_tokens) <= 0 or int(self.ref_w_tokens) <= 0:
|
| 192 |
+
raise ValueError("ref_h_tokens and ref_w_tokens must be positive")
|
| 193 |
+
init = float(self.boundary_log_multiplier_init)
|
| 194 |
+
if not math.isfinite(init):
|
| 195 |
+
raise ValueError("boundary_log_multiplier_init must be finite")
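The frequency-aware warp described in the `AxialRoPE2DFrequencyAwareConfig` docstring can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the module's implementation: the geometric period ladder in `geometric_periods`, the function names, and the handling of a non-positive `d*` are all hypothetical choices made for this example.

```python
import math


def geometric_periods(min_period, max_period, n_bands):
    """Hypothetical geometric ladder of periods (n_bands >= 2 assumed)."""
    ratio = max_period / min_period
    return [min_period * ratio ** (d / (n_bands - 1)) for d in range(n_bands)]


def boundary_band_index(periods, l_boundary):
    """Fractional index d* with period(d*) = l_boundary, interpolated in log-period space."""
    log_lo, log_hi = math.log(periods[0]), math.log(periods[-1])
    return (math.log(l_boundary) - log_lo) / (log_hi - log_lo) * (len(periods) - 1)


def frequency_aware_periods(periods, axis_len, ref_len, boundary_log_multiplier):
    """Warp periods as period'[d] = period[d] * s ** alpha[d], alpha[d] = clamp(d / d*, 0, 1)."""
    s = axis_len / ref_len
    l_boundary = ref_len * math.exp(boundary_log_multiplier)
    d_star = boundary_band_index(periods, l_boundary)
    warped = []
    for d, period in enumerate(periods):
        # Assumption: if d* <= 0 every band is already past the boundary, so use alpha = 1.
        alpha = 1.0 if d_star <= 0.0 else min(max(d / d_star, 0.0), 1.0)
        warped.append(period * s ** alpha)
    return warped


periods = geometric_periods(2.0, 128.0, 7)          # roughly 2, 4, 8, ..., 128 tokens
same = frequency_aware_periods(periods, 16, 16, 0.0)   # s = 1: identity
warped = frequency_aware_periods(periods, 32, 16, 0.0)  # s = 2, L_boundary = 16
```

With `L_ref = 16` and a zero log-multiplier, the boundary wavelength is 16 tokens, which lands on band index 3 of this ladder. Band 0 keeps its period (alpha = 0), while bands 3 and above are stretched by the full factor `s = 2` (alpha = 1), matching the extrapolation-like versus interpolation-like behaviour the docstring describes.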
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
@dataclass(frozen=True)
|
| 199 |
+
class AxialRoPE2DBetaWarpConfig:
|
| 200 |
+
"""Trainable beta-curve warping for axial 2D RoPE periods (per token grid).
|
| 201 |
+
|
| 202 |
+
This config defines a per-axis period warp that depends on the runtime axis
|
| 203 |
+
length ``L`` relative to a reference length ``L_ref``:
|
| 204 |
+
|
| 205 |
+
s = L / L_ref
|
| 206 |
+
period'[d] = period[d] * s ** beta[d]
|
| 207 |
+
|
| 208 |
+
where the per-band exponent curve beta(d) is parameterized by three
|
| 209 |
+
trainable u-space scalars (shared across H/W axes):
|
| 210 |
+
|
| 211 |
+
beta_hi = beta_max * tanh(beta_hi_u) (high-frequency endpoint, d=0)
|
| 212 |
+
beta_lo = beta_max * tanh(beta_lo_u) (low-frequency endpoint, d=qtr-1)
|
| 213 |
+
beta_bend = beta_max * tanh(beta_bend_u) (mid-band bump amplitude)
|
| 214 |
+
|
| 215 |
+
and the per-band curve is:
|
| 216 |
+
|
| 217 |
+
t = d / (qtr - 1) in [0, 1]
|
| 218 |
+
beta(t) = lerp(beta_hi, beta_lo, t) + beta_bend * 4*t*(1-t)
|
| 219 |
+
|
| 220 |
+
Interpretation
|
| 221 |
+
--------------
|
| 222 |
+
- ``beta(d) == 0``: identity / "extrapolation-like" (no warping; periods do not
|
| 223 |
+
change with axis length).
|
| 224 |
+
- ``beta(d) == 1``: position-interpolation-like for that band
|
| 225 |
+
(``period'[d] = period[d] * s`` so ``omega'[d] = omega[d] / s``).
|
| 226 |
+
|
| 227 |
+
This parameterization provides strong and smooth control over the effective
|
| 228 |
+
scaling of each frequency band, including allowing beta<0 (increasing
|
| 229 |
+
frequencies when s>1), which can be important for unnormalized RoPE bases
|
| 230 |
+
(e.g. base=10_000) where some very low-frequency bands barely rotate on
|
| 231 |
+
practical token grids.
|
| 232 |
+
|
| 233 |
+
Notes:
|
| 234 |
+
- This warp requires patch-index coordinates (coord_mode=PATCH_INDICES).
|
| 235 |
+
- The u parameters are stored as nn.Parameter inside AxialRoPE2D when this
|
| 236 |
+
config is present.
|
| 237 |
+
"""
|
| 238 |
+
|
| 239 |
+
ref_h_tokens: int
|
| 240 |
+
ref_w_tokens: int
|
| 241 |
+
beta_max: float
|
| 242 |
+
beta_hi_u_init: float
|
| 243 |
+
beta_lo_u_init: float
|
| 244 |
+
beta_bend_u_init: float
|
| 245 |
+
|
| 246 |
+
def __post_init__(self) -> None:
|
| 247 |
+
if int(self.ref_h_tokens) <= 0 or int(self.ref_w_tokens) <= 0:
|
| 248 |
+
raise ValueError("ref_h_tokens and ref_w_tokens must be positive")
|
| 249 |
+
bmax = float(self.beta_max)
|
| 250 |
+
if not math.isfinite(bmax) or bmax <= 0.0:
|
| 251 |
+
raise ValueError("beta_max must be finite and > 0")
|
| 252 |
+
for name, value in (
|
| 253 |
+
("beta_hi_u_init", self.beta_hi_u_init),
|
| 254 |
+
("beta_lo_u_init", self.beta_lo_u_init),
|
| 255 |
+
("beta_bend_u_init", self.beta_bend_u_init),
|
| 256 |
+
):
|
| 257 |
+
v = float(value)
|
| 258 |
+
if not math.isfinite(v):
|
| 259 |
+
raise ValueError(f"{name} must be finite")
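The beta curve from the `AxialRoPE2DBetaWarpConfig` docstring can be evaluated by hand with a few lines of plain Python. This is a sketch that follows the formulas in the docstring; the function names and the single-band fallback (`t = 0`) are assumptions made for this example, and the real module stores the u-space scalars as `nn.Parameter` values rather than floats.

```python
import math


def beta_curve(n_bands, beta_max, beta_hi_u, beta_lo_u, beta_bend_u):
    """Return beta[d] = lerp(beta_hi, beta_lo, t) + beta_bend * 4 * t * (1 - t)."""
    beta_hi = beta_max * math.tanh(beta_hi_u)      # endpoint at d = 0
    beta_lo = beta_max * math.tanh(beta_lo_u)      # endpoint at d = n_bands - 1
    beta_bend = beta_max * math.tanh(beta_bend_u)  # mid-band bump amplitude
    betas = []
    for d in range(n_bands):
        # Assumption: with a single band, t is taken to be 0.
        t = d / (n_bands - 1) if n_bands > 1 else 0.0
        betas.append(beta_hi + (beta_lo - beta_hi) * t + beta_bend * 4.0 * t * (1.0 - t))
    return betas


def beta_warp_periods(periods, axis_len, ref_len, betas):
    """Warp periods as period'[d] = period[d] * s ** beta[d] with s = L / L_ref."""
    s = axis_len / ref_len
    return [p * s ** b for p, b in zip(periods, betas)]


identity = beta_curve(5, 1.0, 0.0, 0.0, 0.0)   # all u = 0 gives beta = 0 everywhere
bumped = beta_curve(5, 2.0, 0.0, 0.0, 0.5)     # flat endpoints, mid-band bump only
warped = beta_warp_periods([2.0, 4.0, 8.0, 16.0, 32.0], 32, 16, identity)
```

Because `4 * t * (1 - t)` is 0 at both endpoints and 1 at `t = 0.5`, the bend term changes only the interior bands, and the `tanh` squashing keeps every endpoint and the bump amplitude within `(-beta_max, beta_max)`.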
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
@dataclass(frozen=True)
|
| 263 |
+
class AxialRoPE2DAlphaWarpConfig:
|
| 264 |
+
"""Per-band power-law warping of axial 2D RoPE frequencies (shared across axes).
|
| 265 |
+
|
| 266 |
+
This config warps RoPE frequencies per band using a learned exponent vector
|
| 267 |
+
``alpha[d]`` shared across H/W axes:
|
| 268 |
+
|
| 269 |
+
f'[d] = f[d] * s ** alpha[d] where s = L / L_ref
|
| 270 |
+
|
| 271 |
+
Since this module parameterizes angles via periods ``period[d]`` with
|
| 272 |
+
``f[d] ∝ 1 / period[d]``, the equivalent period warp implemented in AxialRoPE2D is:
|
| 273 |
+
|
| 274 |
+
period'[d] = period[d] / s ** alpha[d]
|
| 275 |
+
|
| 276 |
+
Notes:
|
| 277 |
+
- This warp requires patch-index coordinates (coord_mode=PATCH_INDICES).
|
| 278 |
+
- ``alpha`` is stored as an unconstrained nn.Parameter vector of length Q
|
| 279 |
+
(bands per axis), initialized to ``alpha_init`` for all bands.
|
| 280 |
+
"""
|
| 281 |
+
|
| 282 |
+
ref_h_tokens: int
|
| 283 |
+
ref_w_tokens: int
|
| 284 |
+
alpha_init: float
|
| 285 |
+
|
| 286 |
+
def __post_init__(self) -> None:
|
| 287 |
+
if int(self.ref_h_tokens) <= 0 or int(self.ref_w_tokens) <= 0:
|
| 288 |
+
raise ValueError("ref_h_tokens and ref_w_tokens must be positive")
|
| 289 |
+
init = float(self.alpha_init)
|
| 290 |
+
if not math.isfinite(init):
|
| 291 |
+
raise ValueError("alpha_init must be finite")
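The equivalence between the frequency form and the period form stated in the `AxialRoPE2DAlphaWarpConfig` docstring can be checked numerically. The sketch below uses plain floats and a hypothetical helper name; in the module itself `alpha` is a learned `nn.Parameter` vector.

```python
import math


def alpha_warp_periods(periods, axis_len, ref_len, alphas):
    """Warp periods as period'[d] = period[d] / s ** alpha[d] with s = L / L_ref."""
    s = axis_len / ref_len
    return [p / s ** a for p, a in zip(periods, alphas)]


periods = [4.0, 8.0]
alphas = [1.0, -1.0]
warped = alpha_warp_periods(periods, 32, 16, alphas)  # s = 2

# Frequencies are proportional to 1 / period, so f' = f * s ** alpha must hold.
s = 32 / 16
freq_ok = all(
    math.isclose(1.0 / w, (1.0 / p) * s ** a)
    for w, p, a in zip(warped, periods, alphas)
)
```

Note the sign convention: here a positive `alpha` raises the frequency (shortens the period) as the grid grows, which is the opposite direction from the exponent used by the beta and frequency-aware warps, where a positive exponent lengthens the period.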
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
@dataclass(frozen=True)
|
| 295 |
+
class AxialRoPE2DConfig:
|
| 296 |
+
"""Configuration for axial 2D RoPE sin/cos generation.
|
| 297 |
+
|
| 298 |
+
This module supports two coordinate conventions via ``coord_mode``:
|
| 299 |
+
- ``DINOV3_NORMALIZED``: DINOv3-style normalized patch-centre coordinates in
|
| 300 |
+
``[-1, 1]`` (after normalization).
|
| 301 |
+
- ``PATCH_INDICES``: Standard unnormalized patch-grid coordinates in patch
|
| 302 |
+
units (e.g., ``x in [0, W-1]``).
|
| 303 |
+
|
| 304 |
+
Period parametrization
|
| 305 |
+
----------------------
|
| 306 |
+
The period parametrization matches DINOv3:
|
| 307 |
+
- Provide either `base` (and leave `min_period/max_period` unset), or
|
| 308 |
+
- Provide both `min_period` and `max_period` (and set `base=None`).
|
| 309 |
+
"""
|
| 310 |
+
|
| 311 |
+
base: float | None = 100.0
|
| 312 |
+
min_period: float | None = None
|
| 313 |
+
max_period: float | None = None
|
| 314 |
+
coord_mode: AxialRoPE2DCoordMode = AxialRoPE2DCoordMode.DINOV3_NORMALIZED
|
| 315 |
+
normalize_coords: AxialRoPE2DNormalizeCoords = AxialRoPE2DNormalizeCoords.MAX
|
| 316 |
+
dim_layout: AxialRoPE2DDimLayout = AxialRoPE2DDimLayout.HALF_SPLIT
|
| 317 |
+
angle_multiplier: float = 2.0 * float(math.pi)
|
| 318 |
+
coord_offset: float = 0.5
|
| 319 |
+
frequency_aware: AxialRoPE2DFrequencyAwareConfig | None = None
|
| 320 |
+
beta_warp: AxialRoPE2DBetaWarpConfig | None = None
|
| 321 |
+
alpha_warp: AxialRoPE2DAlphaWarpConfig | None = None
|
| 322 |
+
|
| 323 |
+
def __post_init__(self) -> None:
|
| 324 |
+
both_periods = self.min_period is not None and self.max_period is not None
|
| 325 |
+
if (self.base is None and not both_periods) or (
|
| 326 |
+
self.base is not None and both_periods
|
| 327 |
+
):
|
| 328 |
+
raise ValueError(
|
| 329 |
+
"AxialRoPE2DConfig requires either base!=None, or both min_period and max_period."
|
| 330 |
+
)
|
| 331 |
+
if self.base is not None and float(self.base) <= 0.0:
|
| 332 |
+
raise ValueError("AxialRoPE2DConfig.base must be positive when provided")
|
| 333 |
+
if self.min_period is not None and float(self.min_period) <= 0.0:
|
| 334 |
+
raise ValueError(
|
| 335 |
+
"AxialRoPE2DConfig.min_period must be positive when provided"
|
| 336 |
+
)
|
| 337 |
+
if self.max_period is not None and float(self.max_period) <= 0.0:
|
| 338 |
+
raise ValueError(
|
| 339 |
+
"AxialRoPE2DConfig.max_period must be positive when provided"
|
| 340 |
+
)
|
| 341 |
+
if self.min_period is not None and self.max_period is not None:
|
| 342 |
+
if float(self.max_period) <= float(self.min_period):
|
| 343 |
+
raise ValueError("AxialRoPE2DConfig.max_period must be > min_period")
|
| 344 |
+
if not isinstance(self.coord_mode, AxialRoPE2DCoordMode):
|
| 345 |
+
raise TypeError(
|
| 346 |
+
"AxialRoPE2DConfig.coord_mode must be an AxialRoPE2DCoordMode"
|
| 347 |
+
)
|
| 348 |
+
if not isinstance(self.normalize_coords, AxialRoPE2DNormalizeCoords):
|
| 349 |
+
raise TypeError(
|
| 350 |
+
"AxialRoPE2DConfig.normalize_coords must be an AxialRoPE2DNormalizeCoords"
|
| 351 |
+
)
|
| 352 |
+
if not isinstance(self.dim_layout, AxialRoPE2DDimLayout):
|
| 353 |
+
raise TypeError(
|
| 354 |
+
"AxialRoPE2DConfig.dim_layout must be an AxialRoPE2DDimLayout"
|
| 355 |
+
)
|
| 356 |
+
mult = float(self.angle_multiplier)
|
| 357 |
+
if not math.isfinite(mult) or mult <= 0.0:
|
| 358 |
+
raise ValueError(
|
| 359 |
+
"AxialRoPE2DConfig.angle_multiplier must be finite and > 0"
|
| 360 |
+
)
|
| 361 |
+
off = float(self.coord_offset)
|
| 362 |
+
if not math.isfinite(off):
|
| 363 |
+
raise ValueError("AxialRoPE2DConfig.coord_offset must be finite")
|
| 364 |
+
if self.frequency_aware is not None and not isinstance(
|
| 365 |
+
self.frequency_aware, AxialRoPE2DFrequencyAwareConfig
|
| 366 |
+
):
|
| 367 |
+
raise TypeError(
|
| 368 |
+
"AxialRoPE2DConfig.frequency_aware must be an AxialRoPE2DFrequencyAwareConfig"
|
| 369 |
+
)
|
| 370 |
+
if self.beta_warp is not None and not isinstance(
|
| 371 |
+
self.beta_warp, AxialRoPE2DBetaWarpConfig
|
| 372 |
+
):
|
| 373 |
+
raise TypeError(
|
| 374 |
+
"AxialRoPE2DConfig.beta_warp must be an AxialRoPE2DBetaWarpConfig"
|
| 375 |
+
)
|
| 376 |
+
if self.alpha_warp is not None and not isinstance(
|
| 377 |
+
self.alpha_warp, AxialRoPE2DAlphaWarpConfig
|
| 378 |
+
):
|
| 379 |
+
raise TypeError(
|
| 380 |
+
"AxialRoPE2DConfig.alpha_warp must be an AxialRoPE2DAlphaWarpConfig"
|
| 381 |
+
)
|
| 382 |
+
warp_count = (
|
| 383 |
+
int(self.frequency_aware is not None)
|
| 384 |
+
+ int(self.beta_warp is not None)
|
| 385 |
+
+ int(self.alpha_warp is not None)
|
| 386 |
+
)
|
| 387 |
+
if warp_count > 1:
|
| 388 |
+
raise ValueError(
|
| 389 |
+
"AxialRoPE2DConfig requires at most one of frequency_aware, beta_warp, or alpha_warp"
|
| 390 |
+
)
|
| 391 |
+
if self.frequency_aware is not None and (
|
| 392 |
+
self.coord_mode is not AxialRoPE2DCoordMode.PATCH_INDICES
|
| 393 |
+
):
|
| 394 |
+
raise ValueError(
|
| 395 |
+
"AxialRoPE2D frequency-aware warping requires coord_mode=PATCH_INDICES"
|
| 396 |
+
)
|
| 397 |
+
if self.beta_warp is not None and (
|
| 398 |
+
self.coord_mode is not AxialRoPE2DCoordMode.PATCH_INDICES
|
| 399 |
+
):
|
| 400 |
+
raise ValueError("AxialRoPE2D beta warp requires coord_mode=PATCH_INDICES")
|
| 401 |
+
if self.alpha_warp is not None and (
|
| 402 |
+
self.coord_mode is not AxialRoPE2DCoordMode.PATCH_INDICES
|
| 403 |
+
):
|
| 404 |
+
raise ValueError("AxialRoPE2D alpha warp requires coord_mode=PATCH_INDICES")
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
_AXIAL_COORDS_CACHE: dict[
|
| 408 |
+
tuple[
|
| 409 |
+
int, int, torch.device, AxialRoPE2DCoordMode, AxialRoPE2DNormalizeCoords, float
|
| 410 |
+
],
|
| 411 |
+
Tensor,
|
| 412 |
+
] = {}
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def _get_dinov3_normalized_coords(
|
| 416 |
+
H: int,
|
| 417 |
+
W: int,
|
| 418 |
+
*,
|
| 419 |
+
device: torch.device,
|
| 420 |
+
normalize: AxialRoPE2DNormalizeCoords,
|
| 421 |
+
offset: float,
|
| 422 |
+
) -> Tensor:
|
| 423 |
+
"""Return DINOv3-style flattened coords in [-1, 1] with shape [HW, 2]."""
|
| 424 |
+
if H <= 0 or W <= 0:
|
| 425 |
+
raise ValueError("H and W must be positive for axial RoPE coords")
|
| 426 |
+
key = (
|
| 427 |
+
int(H),
|
| 428 |
+
int(W),
|
| 429 |
+
device,
|
| 430 |
+
AxialRoPE2DCoordMode.DINOV3_NORMALIZED,
|
| 431 |
+
normalize,
|
| 432 |
+
float(offset),
|
| 433 |
+
)
|
| 434 |
+
cached = _AXIAL_COORDS_CACHE.get(key)
|
| 435 |
+
if cached is not None:
|
| 436 |
+
return cached
|
| 437 |
+
start = float(offset)
|
| 438 |
+
end_h = start + float(int(H))
|
| 439 |
+
end_w = start + float(int(W))
|
| 440 |
+
match normalize:
|
| 441 |
+
case AxialRoPE2DNormalizeCoords.MAX:
|
| 442 |
+
denom = float(max(int(H), int(W)))
|
| 443 |
+
coords_h = (
|
| 444 |
+
torch.arange(start, end_h, device=device, dtype=torch.float32) / denom
|
| 445 |
+
)
|
| 446 |
+
coords_w = (
|
| 447 |
+
torch.arange(start, end_w, device=device, dtype=torch.float32) / denom
|
| 448 |
+
)
|
| 449 |
+
case AxialRoPE2DNormalizeCoords.MIN:
|
| 450 |
+
denom = float(min(int(H), int(W)))
|
| 451 |
+
coords_h = (
|
| 452 |
+
torch.arange(start, end_h, device=device, dtype=torch.float32) / denom
|
| 453 |
+
)
|
| 454 |
+
coords_w = (
|
| 455 |
+
torch.arange(start, end_w, device=device, dtype=torch.float32) / denom
|
| 456 |
+
)
|
| 457 |
+
case AxialRoPE2DNormalizeCoords.SEPARATE:
|
| 458 |
+
coords_h = torch.arange(
|
| 459 |
+
start, end_h, device=device, dtype=torch.float32
|
| 460 |
+
) / float(int(H))
|
| 461 |
+
coords_w = torch.arange(
|
| 462 |
+
start, end_w, device=device, dtype=torch.float32
|
| 463 |
+
) / float(int(W))
|
| 464 |
+
case _ as unreachable: # pragma: no cover - defensive
|
| 465 |
+
raise RuntimeError(f"Unsupported normalize_coords: {unreachable}")
|
| 466 |
+
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)
|
| 467 |
+
coords = coords.flatten(0, 1)
|
| 468 |
+
coords = 2.0 * coords - 1.0
|
| 469 |
+
# torch.compile cannot trace `torch.is_inference_mode_enabled()` and should
|
| 470 |
+
# not record Python-side cache mutations in the graph.
|
| 471 |
+
if torch.compiler.is_compiling():
|
| 472 |
+
return coords
|
| 473 |
+
if torch.is_inference_mode_enabled():
|
| 474 |
+
return coords
|
| 475 |
+
_AXIAL_COORDS_CACHE[key] = coords
|
| 476 |
+
return coords
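The three normalization modes above can be illustrated without PyTorch. The following is a minimal pure-Python sketch, not the module's code: `normalized_coords`, its string `mode` argument, and the `offset=0.5` default are hypothetical names and values chosen for illustration. It mirrors the arithmetic `2 * (offset + i) / denom - 1` and the row-major `(y, x)` flattening.

```python
from itertools import product


def normalized_coords(h, w, mode="separate", offset=0.5):
    """Pure-Python mirror of the normalized (y, x) grid, flattened row-major."""
    if h <= 0 or w <= 0:
        raise ValueError("h and w must be positive")
    if mode == "max":
        denom_h = denom_w = float(max(h, w))
    elif mode == "min":
        denom_h = denom_w = float(min(h, w))
    elif mode == "separate":
        denom_h, denom_w = float(h), float(w)
    else:
        raise ValueError(f"unsupported mode: {mode}")
    # Same arithmetic as arange(start, end) / denom followed by 2 * c - 1.
    ys = [2.0 * (offset + i) / denom_h - 1.0 for i in range(h)]
    xs = [2.0 * (offset + j) / denom_w - 1.0 for j in range(w)]
    # meshgrid(indexing="ij") + flatten(0, 1) gives row-major (y, x) pairs.
    return list(product(ys, xs))


coords = normalized_coords(2, 4, mode="separate", offset=0.5)
print(coords[0], coords[-1])

# With "min" normalization the longer axis can leave [-1, 1].
coords_min = normalized_coords(2, 4, mode="min", offset=0.5)
print(max(x for _, x in coords_min))
```

Under these assumptions, only the `max` and `separate` modes keep every coordinate inside `[-1, 1]` for offsets in `[0, 1]`; the `min` mode stretches the longer axis beyond that range.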
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
def _get_patch_index_coords(
|
| 480 |
+
H: int,
|
| 481 |
+
W: int,
|
| 482 |
+
*,
|
| 483 |
+
device: torch.device,
|
| 484 |
+
offset: float,
|
| 485 |
+
) -> Tensor:
|
| 486 |
+
"""Return unnormalized patch-grid coords with shape [HW, 2] and (y, x) columns."""
|
| 487 |
+
if H <= 0 or W <= 0:
|
| 488 |
+
raise ValueError("H and W must be positive for axial RoPE coords")
|
| 489 |
+
key = (
|
| 490 |
+
int(H),
|
| 491 |
+
int(W),
|
| 492 |
+
device,
|
| 493 |
+
AxialRoPE2DCoordMode.PATCH_INDICES,
|
| 494 |
+
AxialRoPE2DNormalizeCoords.MAX,
|
| 495 |
+
float(offset),
|
| 496 |
+
)
|
| 497 |
+
cached = _AXIAL_COORDS_CACHE.get(key)
|
| 498 |
+
if cached is not None:
|
| 499 |
+
return cached
|
| 500 |
+
start = float(offset)
|
| 501 |
+
end_h = start + float(int(H))
|
| 502 |
+
end_w = start + float(int(W))
|
| 503 |
+
coords_h = torch.arange(start, end_h, device=device, dtype=torch.float32)
|
| 504 |
+
coords_w = torch.arange(start, end_w, device=device, dtype=torch.float32)
|
| 505 |
+
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)
|
| 506 |
+
coords = coords.flatten(0, 1)
|
| 507 |
+
if torch.compiler.is_compiling():
|
| 508 |
+
return coords
|
| 509 |
+
if torch.is_inference_mode_enabled():
|
| 510 |
+
return coords
|
| 511 |
+
_AXIAL_COORDS_CACHE[key] = coords
|
| 512 |
+
return coords
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
def _lumina_boundary_band_index(
|
| 516 |
+
*,
|
| 517 |
+
periods: Tensor,
|
| 518 |
+
boundary_wavelength: Tensor,
|
| 519 |
+
) -> Tensor:
|
| 520 |
+
"""Return the fractional boundary band index d* for a given boundary wavelength.
|
| 521 |
+
|
| 522 |
+
This implements the Lumina/Next-DiT definition:
|
| 523 |
+
|
| 524 |
+
period(d*) = boundary_wavelength
|
| 525 |
+
|
| 526 |
+
We compute d* by linear interpolation in log-period space. For the supported
|
| 527 |
+
period parameterizations, periods are geometric and log(period) is linear in
|
| 528 |
+
band index.
|
| 529 |
+
|
| 530 |
+
Args:
|
| 531 |
+
periods: 1D float tensor of length Q containing strictly increasing
|
| 532 |
+
periods in tokens.
|
| 533 |
+
boundary_wavelength: Scalar positive float tensor giving the desired
|
| 534 |
+
boundary wavelength in tokens.
|
| 535 |
+
|
| 536 |
+
Returns:
|
| 537 |
+
Scalar float32 tensor giving the (possibly fractional) boundary index d*.
|
| 538 |
+
|
| 539 |
+
Raises:
|
| 540 |
+
ValueError: If periods are invalid or the boundary is outside the valid range
|
| 541 |
+
for a well-defined positive d*.
|
| 542 |
+
"""
|
| 543 |
+
if periods.dim() != 1:
|
| 544 |
+
raise ValueError("periods must be 1D for boundary band index")
|
| 545 |
+
if int(periods.numel()) < 2:
|
| 546 |
+
raise ValueError("periods must have length >= 2 for boundary band index")
|
| 547 |
+
if boundary_wavelength.dim() != 0:
|
| 548 |
+
raise ValueError("boundary_wavelength must be a scalar tensor")
|
| 549 |
+
if not torch.isfinite(boundary_wavelength).item():
|
| 550 |
+
raise ValueError("boundary_wavelength must be finite")
|
| 551 |
+
if float(boundary_wavelength.item()) <= 0.0:
|
| 552 |
+
raise ValueError("boundary_wavelength must be > 0")
|
| 553 |
+
|
| 554 |
+
periods_f = periods.to(dtype=torch.float32)
|
| 555 |
+
if not torch.isfinite(periods_f).all().item():
|
| 556 |
+
raise ValueError("periods must be finite for boundary band index")
|
| 557 |
+
if float(periods_f[0].item()) <= 0.0:
|
| 558 |
+
raise ValueError("periods must be positive for boundary band index")
|
| 559 |
+
if not (periods_f[1:] > periods_f[:-1]).all().item():
|
| 560 |
+
raise ValueError("periods must be strictly increasing for boundary band index")
|
| 561 |
+
|
| 562 |
+
log_p0 = torch.log(periods_f[0])
|
| 563 |
+
log_p1 = torch.log(periods_f[-1])
|
| 564 |
+
denom = log_p1 - log_p0
|
| 565 |
+
if float(denom.item()) <= 0.0:
|
| 566 |
+
raise ValueError("Invalid periods range for boundary band index")
|
| 567 |
+
log_boundary = torch.log(boundary_wavelength.to(dtype=torch.float32))
|
| 568 |
+
q = int(periods_f.numel())
|
| 569 |
+
d_star = (float(q - 1) * (log_boundary - log_p0)) / denom
|
| 570 |
+
if not torch.isfinite(d_star).item():
|
| 571 |
+
raise ValueError("Computed non-finite boundary band index d*")
|
| 572 |
+
if float(d_star.item()) <= 0.0:
|
| 573 |
+
raise ValueError(
|
| 574 |
+
"Boundary wavelength implies d* <= 0; increase the boundary wavelength "
|
| 575 |
+
"(or its multiplier) to be >= the wavelength of the first non-zero band."
|
| 576 |
+
)
|
| 577 |
+
return d_star
|
| 578 |
+
|
| 579 |
+
|
| 580 |
+
def _lumina_alpha_ramp(
|
| 581 |
+
*,
|
| 582 |
+
qtr: int,
|
| 583 |
+
d_star: Tensor,
|
| 584 |
+
device: torch.device,
|
| 585 |
+
) -> Tensor:
|
| 586 |
+
"""Return alpha[d] = clamp(d / d*, 0, 1) for d in [0, qtr).
|
| 587 |
+
|
| 588 |
+
Args:
|
| 589 |
+
qtr: Number of RoPE bands per axis (Q).
|
| 590 |
+
d_star: Scalar positive float tensor boundary index d*.
|
| 591 |
+
device: Device for the returned alpha tensor.
|
| 592 |
+
|
| 593 |
+
Returns:
|
| 594 |
+
Float32 tensor of shape [Q] with values in [0, 1].
|
| 595 |
+
"""
|
| 596 |
+
if int(qtr) <= 0:
|
| 597 |
+
raise ValueError("qtr must be positive for alpha ramp")
|
| 598 |
+
if d_star.dim() != 0:
|
| 599 |
+
raise ValueError("d_star must be a scalar tensor for alpha ramp")
|
| 600 |
+
if float(d_star.item()) <= 0.0:
|
| 601 |
+
raise ValueError("d_star must be > 0 for alpha ramp")
|
| 602 |
+
d = torch.arange(int(qtr), device=device, dtype=torch.float32)
|
| 603 |
+
alpha = d / d_star.to(device=device, dtype=torch.float32)
|
| 604 |
+
return torch.clamp(alpha, min=0.0, max=1.0)
|
| 605 |
+
|
| 606 |
+
|
| 607 |
+
def lumina_frequency_aware_periods_for_axis(
|
| 608 |
+
*,
|
| 609 |
+
periods: Tensor,
|
| 610 |
+
axis_len: int,
|
| 611 |
+
ref_axis_len: int,
|
| 612 |
+
boundary_log_multiplier: Tensor,
|
| 613 |
+
angle_multiplier: float,
|
| 614 |
+
) -> Tensor:
|
| 615 |
+
"""Return Lumina/Next-DiT frequency-aware warped periods for one axis.
|
| 616 |
+
|
| 617 |
+
Implements:
|
| 618 |
+
s = axis_len / ref_axis_len
|
| 619 |
+
L_boundary = ref_axis_len * exp(boundary_log_multiplier)
|
| 620 |
+
d* = boundary band index where period(d*) = L_boundary
|
| 621 |
+
alpha[d] = clamp(d / d*, 0, 1)
|
| 622 |
+
period'[d] = period[d] * s**alpha[d]
|
| 623 |
+
|
| 624 |
+
Notes on ``angle_multiplier``
|
| 625 |
+
-----------------------------
|
| 626 |
+
This module parameterizes angles as:
|
| 627 |
+
|
| 628 |
+
angle(p, d) = angle_multiplier * p / period[d]
|
| 629 |
+
|
| 630 |
+
The *wavelength* (period in tokens) is the delta in ``p`` that increases
|
| 631 |
+
the angle by ``2π``:
|
| 632 |
+
|
| 633 |
+
wavelength[d] = 2π * period[d] / angle_multiplier
|
| 634 |
+
|
| 635 |
+
Lumina/Next-DiT define the boundary by matching *wavelength* to the
|
| 636 |
+
reference axis length. We therefore convert the boundary wavelength
|
| 637 |
+
``L_boundary`` into a boundary period via:
|
| 638 |
+
|
| 639 |
+
period_boundary = (angle_multiplier / 2π) * L_boundary
|
| 640 |
+
|
| 641 |
+
When ``angle_multiplier == 2π`` (the DINOv3-style parameterization), this
|
| 642 |
+
reduces to ``period_boundary == L_boundary``.
|
| 643 |
+
|
| 644 |
+
Args:
|
| 645 |
+
periods: Base periods ``[Q]`` (equal to wavelengths in tokens only when ``angle_multiplier == 2π``).
|
| 646 |
+
axis_len: Input axis length ``L`` in tokens.
|
| 647 |
+
ref_axis_len: Reference axis length ``L_ref`` in tokens.
|
| 648 |
+
boundary_log_multiplier: Scalar tensor; shared trainable log-multiplier.
|
| 649 |
+
angle_multiplier: RoPE angle multiplier used when converting periods to
|
| 650 |
+
physical wavelengths in tokens.
|
| 651 |
+
|
| 652 |
+
Returns:
|
| 653 |
+
Warped periods ``[Q]`` as float32.
|
| 654 |
+
|
| 655 |
+
Raises:
|
| 656 |
+
ValueError: If inputs are malformed or imply an invalid boundary index.
|
| 657 |
+
"""
|
| 658 |
+
if int(axis_len) <= 0:
|
| 659 |
+
raise ValueError("axis_len must be positive for frequency-aware periods")
|
| 660 |
+
if int(ref_axis_len) <= 0:
|
| 661 |
+
raise ValueError("ref_axis_len must be positive for frequency-aware periods")
|
| 662 |
+
if boundary_log_multiplier.dim() != 0:
|
| 663 |
+
raise ValueError("boundary_log_multiplier must be a scalar tensor")
|
| 664 |
+
if not torch.isfinite(boundary_log_multiplier).item():
|
| 665 |
+
raise ValueError("boundary_log_multiplier must be finite")
|
| 666 |
+
mult = float(angle_multiplier)
|
| 667 |
+
if not math.isfinite(mult) or mult <= 0.0:
|
| 668 |
+
raise ValueError("angle_multiplier must be finite and > 0")
|
| 669 |
+
|
| 670 |
+
device = periods.device
|
| 671 |
+
qtr = int(periods.numel())
|
| 672 |
+
s = float(int(axis_len)) / float(int(ref_axis_len))
|
| 673 |
+
if not math.isfinite(s) or s <= 0.0:
|
| 674 |
+
raise ValueError("axis_len/ref_axis_len must be finite and > 0")
|
| 675 |
+
boundary_wavelength = float(int(ref_axis_len)) * torch.exp(
|
| 676 |
+
boundary_log_multiplier.to(device=device, dtype=torch.float32)
|
| 677 |
+
)
|
| 678 |
+
boundary_period = (mult / (2.0 * float(math.pi))) * boundary_wavelength
|
| 679 |
+
d_star = _lumina_boundary_band_index(
|
| 680 |
+
periods=periods, boundary_wavelength=boundary_period
|
| 681 |
+
)
|
| 682 |
+
alpha = _lumina_alpha_ramp(qtr=qtr, d_star=d_star, device=device)
|
| 683 |
+
scale = torch.pow(torch.tensor(s, device=device, dtype=torch.float32), alpha)
|
| 684 |
+
return periods.to(device=device, dtype=torch.float32) * scale
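The warp described in the docstring above can be traced with plain floats. This is a hedged sketch rather than the module's implementation: `warped_periods` is a hypothetical helper, and the geometric periods `1, 4, 16, 64, 256` are made-up values picked so the boundary lands on a whole band index.

```python
import math


def warped_periods(periods, axis_len, ref_axis_len, log_mult=0.0,
                   angle_multiplier=2.0 * math.pi):
    """Pure-Python mirror of the per-axis frequency-aware period warp."""
    q = len(periods)
    s = axis_len / ref_axis_len
    # Boundary wavelength in tokens, then converted to a boundary period.
    boundary_wavelength = ref_axis_len * math.exp(log_mult)
    boundary_period = angle_multiplier / (2.0 * math.pi) * boundary_wavelength
    # Fractional band index d* by linear interpolation in log-period space.
    log_p0, log_p1 = math.log(periods[0]), math.log(periods[-1])
    d_star = (q - 1) * (math.log(boundary_period) - log_p0) / (log_p1 - log_p0)
    if d_star <= 0.0:
        raise ValueError("boundary period must exceed periods[0]")
    # alpha ramps linearly from 0 at band 0 to 1 at band d*, then saturates.
    alpha = [min(max(d / d_star, 0.0), 1.0) for d in range(q)]
    return d_star, [p * s ** a for p, a in zip(periods, alpha)]


# Hypothetical geometric periods (Q = 5); the axis is twice the reference length.
periods = [4.0 ** d for d in range(5)]
d_star, warped = warped_periods(periods, axis_len=32, ref_axis_len=16)
print(d_star)
print(warped)
```

In this example the highest-frequency band is left untouched (`alpha = 0`), while every band at or beyond `d*` is stretched by the full factor `s`, which is the interpolation-like regime.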
|
| 685 |
+
|
| 686 |
+
|
| 687 |
+
def build_axial_rope2d_with_lumina_frequency_warp(
|
| 688 |
+
base: AxialRoPE2D,
|
| 689 |
+
*,
|
| 690 |
+
ref_h_tokens: int,
|
| 691 |
+
ref_w_tokens: int,
|
| 692 |
+
boundary_log_multiplier: float | None,
|
| 693 |
+
boundary_band_multiplier: float | None,
|
| 694 |
+
) -> AxialRoPE2D:
|
| 695 |
+
"""Return an AxialRoPE2D module that applies Lumina-style frequency warping.
|
| 696 |
+
|
| 697 |
+
This helper is intended for inference-time experimentation on checkpoints
|
| 698 |
+
that were trained without frequency-aware warping (e.g.
|
| 699 |
+
``position_encoding=ROPE_2D_AXIAL_UNNORMALIZED``). It constructs a new
|
| 700 |
+
AxialRoPE2D instance that:
|
| 701 |
+
- Keeps the base RoPE periods and layout identical to ``base``.
|
| 702 |
+
- Applies Lumina/Next-DiT per-axis warping based on the runtime token
|
| 703 |
+
lengths ``H`` and ``W`` relative to reference lengths.
|
| 704 |
+
- Uses a fixed (non-trainable) scalar boundary multiplier for inference.
|
| 705 |
+
|
| 706 |
+
Args:
|
| 707 |
+
base: Existing AxialRoPE2D instance from a loaded model.
|
| 708 |
+
ref_h_tokens: Reference H token length (L_ref,h).
|
| 709 |
+
ref_w_tokens: Reference W token length (L_ref,w).
|
| 710 |
+
boundary_log_multiplier: Optional log multiplier applied to reference
|
| 711 |
+
lengths to define the boundary wavelength. Use 0.0 for "boundary at
|
| 712 |
+
L_ref". Mutually exclusive with boundary_band_multiplier.
|
| 713 |
+
boundary_band_multiplier: Optional multiplier that directly selects the
|
| 714 |
+
boundary band index d* relative to the lowest-frequency band index
|
| 715 |
+
(qtr-1). Concretely, with qtr bands per axis:
|
| 716 |
+
|
| 717 |
+
d* = boundary_band_multiplier * (qtr - 1)
|
| 718 |
+
|
| 719 |
+
This lets you move the transition point in frequency space:
|
| 720 |
+
- smaller values => more bands become position-interpolation (PI)-like (more interpolation)
|
| 721 |
+
- larger values => fewer bands become PI-like (more extrapolation)
|
| 722 |
+
|
| 723 |
+
When provided, we compute the implied boundary wavelength and store
|
| 724 |
+
it as boundary_log_multiplier for the module.
|
| 725 |
+
|
| 726 |
+
Returns:
|
| 727 |
+
New AxialRoPE2D instance on the same device as ``base``.
|
| 728 |
+
|
| 729 |
+
Raises:
|
| 730 |
+
TypeError: If base is not an AxialRoPE2D.
|
| 731 |
+
ValueError: If base uses incompatible coordinates for Lumina warping.
|
| 732 |
+
"""
|
| 733 |
+
if not isinstance(base, AxialRoPE2D):
|
| 734 |
+
raise TypeError("base must be an AxialRoPE2D")
|
| 735 |
+
if base.cfg.coord_mode is not AxialRoPE2DCoordMode.PATCH_INDICES:
|
| 736 |
+
raise ValueError(
|
| 737 |
+
"Lumina frequency-aware warping requires coord_mode=PATCH_INDICES"
|
| 738 |
+
)
|
| 739 |
+
if (boundary_log_multiplier is None) == (boundary_band_multiplier is None):
|
| 740 |
+
raise ValueError(
|
| 741 |
+
"Provide exactly one of boundary_log_multiplier or boundary_band_multiplier"
|
| 742 |
+
)
|
| 743 |
+
|
| 744 |
+
resolved_log_multiplier: float
|
| 745 |
+
if boundary_band_multiplier is not None:
|
| 746 |
+
if int(ref_h_tokens) != int(ref_w_tokens):
|
| 747 |
+
raise ValueError(
|
| 748 |
+
"boundary_band_multiplier requires ref_h_tokens == ref_w_tokens when using a shared scalar boundary"
|
| 749 |
+
)
|
| 750 |
+
mult = float(boundary_band_multiplier)
|
| 751 |
+
if not math.isfinite(mult) or mult <= 0.0:
|
| 752 |
+
raise ValueError("boundary_band_multiplier must be finite and > 0")
|
| 753 |
+
qtr = int(base.periods.numel())
|
| 754 |
+
if qtr < 2:
|
| 755 |
+
raise ValueError(
|
| 756 |
+
"AxialRoPE2D periods length must be >= 2 for boundary band selection"
|
| 757 |
+
)
|
| 758 |
+
# Solve for the boundary wavelength implied by choosing d* directly.
|
| 759 |
+
#
|
| 760 |
+
# We use geometric interpolation in log-period space:
|
| 761 |
+
# log(period(d*)) = log(period0) + (d*/(qtr-1)) * (log(period_max) - log(period0))
|
| 762 |
+
# with:
|
| 763 |
+
# d* = boundary_band_multiplier * (qtr-1)
|
| 764 |
+
#
|
| 765 |
+
# This allows d* outside the trained band range (multiplier > 1), which
|
| 766 |
+
# corresponds to pushing the transition beyond the lowest-frequency band.
|
| 767 |
+
with torch.no_grad():
|
| 768 |
+
periods_f = base.periods.to(dtype=torch.float32, device=torch.device("cpu"))
|
| 769 |
+
if not (periods_f[1:] > periods_f[:-1]).all().item():
|
| 770 |
+
raise ValueError(
|
| 771 |
+
"base.periods must be strictly increasing for boundary band selection"
|
| 772 |
+
)
|
| 773 |
+
log_p0 = float(torch.log(periods_f[0]).item())
|
| 774 |
+
log_p1 = float(torch.log(periods_f[-1]).item())
|
| 775 |
+
d_star = mult * float(qtr - 1)
|
| 776 |
+
log_boundary_period = log_p0 + (d_star / float(qtr - 1)) * (log_p1 - log_p0)
|
| 777 |
+
boundary_period = math.exp(log_boundary_period)
|
| 778 |
+
angle_mult = float(base.cfg.angle_multiplier)
|
| 779 |
+
if not math.isfinite(angle_mult) or angle_mult <= 0.0:
|
| 780 |
+
raise ValueError("base.cfg.angle_multiplier must be finite and > 0")
|
| 781 |
+
boundary_wavelength = (2.0 * float(math.pi) / angle_mult) * boundary_period
|
| 782 |
+
resolved_log_multiplier = math.log(
|
| 783 |
+
boundary_wavelength / float(int(ref_h_tokens))
|
| 784 |
+
)
|
| 785 |
+
else:
|
| 786 |
+
if boundary_log_multiplier is None: # pragma: no cover - validated above
|
| 787 |
+
raise RuntimeError("boundary_log_multiplier missing despite validation")
|
| 788 |
+
resolved_log_multiplier = float(boundary_log_multiplier)
|
| 789 |
+
|
| 790 |
+
freq_cfg = AxialRoPE2DFrequencyAwareConfig(
|
| 791 |
+
ref_h_tokens=int(ref_h_tokens),
|
| 792 |
+
ref_w_tokens=int(ref_w_tokens),
|
| 793 |
+
boundary_log_multiplier_init=resolved_log_multiplier,
|
| 794 |
+
)
|
| 795 |
+
cfg = replace(base.cfg, frequency_aware=freq_cfg, beta_warp=None, alpha_warp=None)
|
| 796 |
+
device = base.periods.device
|
| 797 |
+
warped = AxialRoPE2D(head_dim=int(base.head_dim), cfg=cfg).to(device=device)
|
| 798 |
+
with torch.no_grad():
|
| 799 |
+
warped.periods.copy_(base.periods.to(device=device, dtype=torch.float32))
|
| 800 |
+
if warped.boundary_log_multiplier is None: # pragma: no cover - defensive
|
| 801 |
+
raise RuntimeError("Expected boundary_log_multiplier to be initialized")
|
| 802 |
+
warped.boundary_log_multiplier.copy_(
|
| 803 |
+
torch.tensor(resolved_log_multiplier, device=device, dtype=torch.float32)
|
| 804 |
+
)
|
| 805 |
+
warped.boundary_log_multiplier.requires_grad_(False)
|
| 806 |
+
return warped
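The conversion from `boundary_band_multiplier` to a log multiplier can be checked with a round trip. The sketch below is illustrative only: both helper functions and the sample periods are hypothetical. It assumes geometric periods and a shared reference length, as the builder above requires.

```python
import math


def log_multiplier_from_band(periods, band_multiplier, ref_len,
                             angle_multiplier=2.0 * math.pi):
    """Convert a direct choice of d* into the equivalent log multiplier."""
    q = len(periods)
    log_p0, log_p1 = math.log(periods[0]), math.log(periods[-1])
    d_star = band_multiplier * (q - 1)
    # Geometric interpolation in log-period space gives period(d*).
    boundary_period = math.exp(log_p0 + d_star / (q - 1) * (log_p1 - log_p0))
    boundary_wavelength = 2.0 * math.pi / angle_multiplier * boundary_period
    return math.log(boundary_wavelength / ref_len)


def band_index_from_log_multiplier(periods, log_mult, ref_len,
                                   angle_multiplier=2.0 * math.pi):
    """Inverse direction: recover d* from a log multiplier."""
    q = len(periods)
    log_p0, log_p1 = math.log(periods[0]), math.log(periods[-1])
    boundary_wavelength = ref_len * math.exp(log_mult)
    boundary_period = angle_multiplier / (2.0 * math.pi) * boundary_wavelength
    return (q - 1) * (math.log(boundary_period) - log_p0) / (log_p1 - log_p0)


# Hypothetical geometric periods 1, 4, 16, 64, 256 (Q = 5).
periods = [4.0 ** d for d in range(5)]
log_mult = log_multiplier_from_band(periods, band_multiplier=0.75, ref_len=16)
d_star = band_index_from_log_multiplier(periods, log_mult, ref_len=16)
print(log_mult, d_star)
```

Here a multiplier of `0.75` selects band index 3 of 4, whose period is 64 tokens, so the stored log multiplier is `log(64 / 16)`; feeding it back recovers the same `d*` up to rounding.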
|
| 807 |
+
|
| 808 |
+
|
| 809 |
+
def build_axial_rope2d_inference_warp_with_strength(
|
| 810 |
+
base: AxialRoPE2D,
|
| 811 |
+
*,
|
| 812 |
+
ref_h_tokens: int,
|
| 813 |
+
ref_w_tokens: int,
|
| 814 |
+
beta_hi_u: float,
|
| 815 |
+
beta_lo_u: float,
|
| 816 |
+
beta_bend_u: float,
|
| 817 |
+
beta_max: float,
|
| 818 |
+
) -> AxialRoPE2D:
|
| 819 |
+
"""Build an inference-only RoPE warp parameterized by a 3-knob beta(t) curve.
|
| 820 |
+
|
| 821 |
+
This helper is meant for notebook experimentation on checkpoints trained
|
| 822 |
+
with patch-index axial RoPE (e.g. ``position_encoding=ROPE_2D_AXIAL_UNNORMALIZED``).
|
| 823 |
+
|
| 824 |
+
We warp per-axis RoPE periods (wavelengths, in tokens) as:
|
| 825 |
+
|
| 826 |
+
period'[d] = period[d] * s ** beta[d] where s = L / L_ref
|
| 827 |
+
|
| 828 |
+
with a smooth exponent curve beta(d) over bands. Unlike a strict
|
| 829 |
+
interpolation-only exponent (0..1), beta is allowed to be negative or > 1,
|
| 830 |
+
which is important for unnormalized RoPE (e.g. base=10_000) where some very
|
| 831 |
+
low-frequency bands are effectively "dead" on practical token grids unless
|
| 832 |
+
their frequencies can be increased (beta < 0).
|
| 833 |
+
|
| 834 |
+
Knobs (bounded via u-space)
|
| 835 |
+
---------------------------
|
| 836 |
+
We use three unconstrained parameters (u-space) which map to bounded beta
|
| 837 |
+
values via tanh:
|
| 838 |
+
|
| 839 |
+
beta_hi = beta_max * tanh(beta_hi_u) (high-frequency endpoint, d=0)
|
| 840 |
+
beta_lo = beta_max * tanh(beta_lo_u) (low-frequency endpoint, d=qtr-1)
|
| 841 |
+
beta_bend = beta_max * tanh(beta_bend_u) ("bump" amplitude in the middle)
|
| 842 |
+
|
| 843 |
+
Then define the per-band curve over t in [0,1] (high -> low frequency):
|
| 844 |
+
|
| 845 |
+
t = d / (qtr - 1)
|
| 846 |
+
beta(t) = lerp(beta_hi, beta_lo, t) + beta_bend * 4*t*(1-t)
|
| 847 |
+
|
| 848 |
+
The bump factor 4*t*(1-t) is 0 at the endpoints and peaks at 1 at t=0.5, so the bump term peaks at beta_bend.
|
| 849 |
+
|
| 850 |
+
Notes
|
| 851 |
+
-----
|
| 852 |
+
- This wrapper is inference-only: it is not saved in checkpoints.
|
| 853 |
+
- It requires patch-index coordinates (no normalized "gauge").
|
| 854 |
+
- It preserves the base module's periods and layout exactly.
|
| 855 |
+
|
| 856 |
+
Args:
|
| 857 |
+
base: Existing AxialRoPE2D instance from a loaded model.
|
| 858 |
+
ref_h_tokens: Reference H token length (L_ref,h).
|
| 859 |
+
ref_w_tokens: Reference W token length (L_ref,w).
|
| 860 |
+
beta_hi_u: Unconstrained u for beta_hi.
|
| 861 |
+
beta_lo_u: Unconstrained u for beta_lo.
|
| 862 |
+
beta_bend_u: Unconstrained u for beta_bend (mid-band bump).
|
| 863 |
+
beta_max: Bound (> 0) on the magnitude of each knob (beta_hi, beta_lo, beta_bend); the combined curve can approach 2 * beta_max at t=0.5.
|
| 864 |
+
|
| 865 |
+
Returns:
|
| 866 |
+
An AxialRoPE2D instance whose forward applies the inference-only warp.
|
| 867 |
+
"""
|
| 868 |
+
if not isinstance(base, AxialRoPE2D):
|
| 869 |
+
raise TypeError("base must be an AxialRoPE2D")
|
| 870 |
+
if base.cfg.coord_mode is not AxialRoPE2DCoordMode.PATCH_INDICES:
|
| 871 |
+
raise ValueError(
|
| 872 |
+
"Inference freq-warp requires base.cfg.coord_mode=PATCH_INDICES"
|
| 873 |
+
)
|
| 874 |
+
if int(ref_h_tokens) <= 0 or int(ref_w_tokens) <= 0:
|
| 875 |
+
raise ValueError("ref_h_tokens and ref_w_tokens must be positive")
|
| 876 |
+
hi_u = float(beta_hi_u)
|
| 877 |
+
lo_u = float(beta_lo_u)
|
| 878 |
+
bend_u = float(beta_bend_u)
|
| 879 |
+
if not math.isfinite(hi_u):
|
| 880 |
+
raise ValueError("beta_hi_u must be finite")
|
| 881 |
+
if not math.isfinite(lo_u):
|
| 882 |
+
raise ValueError("beta_lo_u must be finite")
|
| 883 |
+
if not math.isfinite(bend_u):
|
| 884 |
+
raise ValueError("beta_bend_u must be finite")
|
| 885 |
+
bmax = float(beta_max)
|
| 886 |
+
if not math.isfinite(bmax) or bmax <= 0.0:
|
| 887 |
+
raise ValueError("beta_max must be finite and > 0")
|
| 888 |
+
|
| 889 |
+
class _AxialRoPE2DInferenceWarp(AxialRoPE2D):
|
| 890 |
+
"""Inference-only axial RoPE variant with beta-curve knobs."""
|
| 891 |
+
|
| 892 |
+
def __init__(self, *, device: torch.device) -> None:
|
| 893 |
+
super().__init__(head_dim=int(base.head_dim), cfg=base.cfg)
|
| 894 |
+
self.ref_h_tokens: Final[int] = int(ref_h_tokens)
|
| 895 |
+
self.ref_w_tokens: Final[int] = int(ref_w_tokens)
|
| 896 |
+
# Store the knobs as non-persistent buffers; a notebook changes them by building a replacement module.
|
| 897 |
+
self.register_buffer(
|
| 898 |
+
"beta_hi_u",
|
| 899 |
+
torch.tensor(float(hi_u), dtype=torch.float32),
|
| 900 |
+
persistent=False,
|
| 901 |
+
)
|
| 902 |
+
self.register_buffer(
|
| 903 |
+
"beta_lo_u",
|
| 904 |
+
torch.tensor(float(lo_u), dtype=torch.float32),
|
| 905 |
+
persistent=False,
|
| 906 |
+
)
|
| 907 |
+
self.register_buffer(
|
| 908 |
+
"beta_bend_u",
|
| 909 |
+
torch.tensor(float(bend_u), dtype=torch.float32),
|
| 910 |
+
persistent=False,
|
| 911 |
+
)
|
| 912 |
+
self.register_buffer(
|
| 913 |
+
"beta_max",
|
| 914 |
+
torch.tensor(float(bmax), dtype=torch.float32),
|
| 915 |
+
persistent=False,
|
| 916 |
+
)
|
| 917 |
+
self.to(device=device)
|
| 918 |
+
with torch.no_grad():
|
| 919 |
+
self.periods.copy_(
|
| 920 |
+
base.periods.detach().to(device=device, dtype=torch.float32)
|
| 921 |
+
)
|
| 922 |
+
|
| 923 |
+
def forward(
|
| 924 |
+
self,
|
| 925 |
+
*,
|
| 926 |
+
H: int,
|
| 927 |
+
W: int,
|
| 928 |
+
scales: Tensor | None,
|
| 929 |
+
) -> tuple[Tensor, Tensor]:
|
| 930 |
+
if scales is not None:
|
| 931 |
+
raise ValueError("Inference freq-warp does not support dilation scales")
|
| 932 |
+
if int(H) <= 0 or int(W) <= 0:
|
| 933 |
+
raise ValueError("H and W must be positive for axial RoPE")
|
| 934 |
+
device = self.periods.device
|
| 935 |
+
offset = float(self.cfg.coord_offset)
|
| 936 |
+
coords = _get_patch_index_coords(
|
| 937 |
+
int(H), int(W), device=device, offset=offset
|
| 938 |
+
)
|
| 939 |
+
if coords.dim() != 2 or coords.shape[1] != 2:
|
| 940 |
+
raise RuntimeError("Axial RoPE coords must have shape [HW, 2]")
|
| 941 |
+
|
| 942 |
+
qtr = int(self.periods.numel())
|
| 943 |
+
if qtr <= 0:
|
| 944 |
+
raise RuntimeError("Axial RoPE periods length must be positive")
|
| 945 |
+
|
| 946 |
+
beta_max_t = cast("Tensor", self.beta_max).to(
|
| 947 |
+
device=device, dtype=torch.float32
|
| 948 |
+
)
|
| 949 |
+
beta_hi = beta_max_t * torch.tanh(
|
| 950 |
+
cast("Tensor", self.beta_hi_u).to(device=device, dtype=torch.float32)
|
| 951 |
+
)
|
| 952 |
+
beta_lo = beta_max_t * torch.tanh(
|
| 953 |
+
cast("Tensor", self.beta_lo_u).to(device=device, dtype=torch.float32)
|
| 954 |
+
)
|
| 955 |
+
beta_bend = beta_max_t * torch.tanh(
|
| 956 |
+
cast("Tensor", self.beta_bend_u).to(device=device, dtype=torch.float32)
|
| 957 |
+
)
|
| 958 |
+
|
| 959 |
+
if qtr == 1:
|
| 960 |
+
beta = beta_hi[None]
|
| 961 |
+
else:
|
| 962 |
+
t = torch.arange(int(qtr), device=device, dtype=torch.float32) / float(
|
| 963 |
+
qtr - 1
|
| 964 |
+
)
|
| 965 |
+
bump = 4.0 * t * (1.0 - t)
|
| 966 |
+
beta = (1.0 - t) * beta_hi + t * beta_lo + beta_bend * bump
|
| 967 |
+
|
| 968 |
+
s_h = float(int(H)) / float(int(self.ref_h_tokens))
|
| 969 |
+
s_w = float(int(W)) / float(int(self.ref_w_tokens))
|
| 970 |
+
if (
|
| 971 |
+
not math.isfinite(s_h)
|
| 972 |
+
or s_h <= 0.0
|
| 973 |
+
or not math.isfinite(s_w)
|
| 974 |
+
or s_w <= 0.0
|
| 975 |
+
):
|
| 976 |
+
raise ValueError(
|
| 977 |
+
"H/ref_h_tokens and W/ref_w_tokens must be finite and > 0"
|
| 978 |
+
)
|
| 979 |
+
|
| 980 |
+
periods_h = self.periods * torch.pow(
|
| 981 |
+
torch.tensor(s_h, device=device, dtype=torch.float32), beta
|
| 982 |
+
)
|
| 983 |
+
periods_w = self.periods * torch.pow(
|
| 984 |
+
torch.tensor(s_w, device=device, dtype=torch.float32), beta
|
| 985 |
+
)
|
| 986 |
+
axis_periods = torch.stack([periods_h, periods_w], dim=0) # [2, Q]
|
| 987 |
+
|
| 988 |
+
angles = (
|
| 989 |
+
float(self.cfg.angle_multiplier)
|
| 990 |
+
* coords[:, :, None].to(dtype=torch.float32)
|
| 991 |
+
/ axis_periods[None, :, :].to(dtype=torch.float32)
|
| 992 |
+
)
|
| 993 |
+
match self.cfg.dim_layout:
|
| 994 |
+
case AxialRoPE2DDimLayout.HALF_SPLIT:
|
| 995 |
+
angles = angles.flatten(1, 2).repeat(1, 2)
|
| 996 |
+
case AxialRoPE2DDimLayout.PAIR_INTERLEAVED:
|
| 997 |
+
angles = angles.repeat_interleave(2, dim=-1).flatten(1, 2)
|
| 998 |
+
case _ as unreachable: # pragma: no cover - defensive
|
| 999 |
+
raise RuntimeError(f"Unsupported dim_layout: {unreachable}")
|
| 1000 |
+
if angles.shape != (int(H) * int(W), int(self.head_dim)):
|
| 1001 |
+
raise RuntimeError(
|
| 1002 |
+
"Unexpected angles shape in inference freq-warp: "
|
| 1003 |
+
f"{tuple(angles.shape)} for H={int(H)} W={int(W)}"
|
| 1004 |
+
)
|
| 1005 |
+
return torch.sin(angles), torch.cos(angles)
|
| 1006 |
+
|
| 1007 |
+
return _AxialRoPE2DInferenceWarp(device=base.periods.device)
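The three-knob exponent curve from the docstring above can be tabulated with the standard library alone. This is a sketch under stated assumptions: `beta_curve` is a hypothetical stand-alone helper, and the knob values are arbitrary examples, not recommended settings.

```python
import math


def beta_curve(qtr, beta_hi_u, beta_lo_u, beta_bend_u, beta_max):
    """Per-band exponents beta[d] for d in [0, qtr), high to low frequency."""
    # tanh maps each unconstrained knob into (-beta_max, beta_max).
    hi = beta_max * math.tanh(beta_hi_u)
    lo = beta_max * math.tanh(beta_lo_u)
    bend = beta_max * math.tanh(beta_bend_u)
    if qtr == 1:
        return [hi]
    out = []
    for d in range(qtr):
        t = d / (qtr - 1)
        bump = 4.0 * t * (1.0 - t)  # 0 at both endpoints, 1 at t = 0.5
        out.append((1.0 - t) * hi + t * lo + bend * bump)
    return out


betas = beta_curve(qtr=5, beta_hi_u=0.0, beta_lo_u=1.0, beta_bend_u=-0.5,
                   beta_max=2.0)
s = 32 / 16  # runtime axis length over reference axis length
scales = [s ** b for b in betas]  # per-band factor applied to each period
print(betas)
print(scales)

# Each knob is bounded by beta_max, but their sum at mid-band is not.
print(beta_curve(3, 5.0, 5.0, 5.0, 1.0)[1])
```

With `beta_hi_u = 0` the highest-frequency band keeps its period, while a negative bend lowers the exponent of the middle bands relative to the straight line between the endpoints.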
|
| 1008 |
+
|
| 1009 |
+
|
| 1010 |
+
class AxialRoPE2D(nn.Module):
|
| 1011 |
+
"""DINOv3-style axial 2D RoPE sin/cos generator.
|
| 1012 |
+
|
| 1013 |
+
The base periods are fixed by ``AxialRoPE2DConfig``. Optionally, this module
|
| 1014 |
+
can include learnable scalar parameters when using:
|
| 1015 |
+
- ``frequency_aware`` (boundary_log_multiplier), or
|
| 1016 |
+
- ``beta_warp`` (beta_hi_u/beta_lo_u/beta_bend_u), or
|
| 1017 |
+
- ``alpha_warp`` (alpha per-band exponents).
|
| 1018 |
+
"""
|
| 1019 |
+
|
| 1020 |
+
periods: Tensor
|
| 1021 |
+
|
| 1022 |
+
def __init__(self, *, head_dim: int, cfg: AxialRoPE2DConfig) -> None:
|
| 1023 |
+
super().__init__()
|
| 1024 |
+
if int(head_dim) <= 0:
|
| 1025 |
+
raise ValueError("head_dim must be positive for AxialRoPE2D")
|
| 1026 |
+
if int(head_dim) % 4 != 0:
|
| 1027 |
+
raise ValueError(
|
| 1028 |
+
"AxialRoPE2D requires head_dim % 4 == 0 (DINOv3 constraint); "
|
| 1029 |
+
f"got head_dim={int(head_dim)}"
|
| 1030 |
+
)
|
| 1031 |
+
if not isinstance(cfg, AxialRoPE2DConfig):
|
| 1032 |
+
raise TypeError("cfg must be an AxialRoPE2DConfig for AxialRoPE2D")
|
| 1033 |
+
self.head_dim: Final[int] = int(head_dim)
|
| 1034 |
+
self.cfg: Final[AxialRoPE2DConfig] = cfg
|
| 1035 |
+
self._d_head: Final[int] = self.head_dim
|
| 1036 |
+
self.register_buffer(
|
| 1037 |
+
"periods",
|
| 1038 |
+
torch.empty(self._d_head // 4, dtype=torch.float32),
|
| 1039 |
+
persistent=True,
|
| 1040 |
+
)
|
| 1041 |
+
if cfg.frequency_aware is None:
|
| 1042 |
+
self.register_parameter("boundary_log_multiplier", None)
|
| 1043 |
+
else:
|
| 1044 |
+
init = float(cfg.frequency_aware.boundary_log_multiplier_init)
|
| 1045 |
+
self.boundary_log_multiplier = nn.Parameter(
|
| 1046 |
+
torch.tensor(init, dtype=torch.float32),
|
| 1047 |
+
requires_grad=True,
|
| 1048 |
+
)
|
| 1049 |
+
if cfg.beta_warp is None:
|
| 1050 |
+
self.register_parameter("beta_hi_u", None)
|
| 1051 |
+
self.register_parameter("beta_lo_u", None)
|
| 1052 |
+
self.register_parameter("beta_bend_u", None)
|
| 1053 |
+
else:
|
| 1054 |
+
beta = cfg.beta_warp
|
| 1055 |
+
self.beta_hi_u = nn.Parameter(
|
| 1056 |
+
torch.tensor(float(beta.beta_hi_u_init), dtype=torch.float32),
|
| 1057 |
+
requires_grad=True,
|
| 1058 |
+
)
|
| 1059 |
+
self.beta_lo_u = nn.Parameter(
|
| 1060 |
+
torch.tensor(float(beta.beta_lo_u_init), dtype=torch.float32),
|
| 1061 |
+
requires_grad=True,
|
| 1062 |
+
)
|
| 1063 |
+
self.beta_bend_u = nn.Parameter(
|
| 1064 |
+
torch.tensor(float(beta.beta_bend_u_init), dtype=torch.float32),
|
| 1065 |
+
requires_grad=True,
|
| 1066 |
+
)
|
| 1067 |
+
if cfg.alpha_warp is None:
|
| 1068 |
+
self.register_parameter("alpha", None)
|
| 1069 |
+
else:
|
| 1070 |
+
qtr = int(self._d_head) // 4
|
| 1071 |
+
if qtr <= 0: # pragma: no cover - defensive
|
| 1072 |
+
raise RuntimeError("AxialRoPE2D periods length must be positive")
|
| 1073 |
+
            init = float(cfg.alpha_warp.alpha_init)
            if not math.isfinite(init):
                raise RuntimeError("alpha_init must be finite for alpha-warp RoPE")
            self.alpha = nn.Parameter(
                torch.full((int(qtr),), init, dtype=torch.float32),
                requires_grad=True,
            )
        self._init_periods()

    def _apply(self, fn):  # type: ignore[override]
        out = super()._apply(fn)
        with torch.no_grad():
            self.periods.data = self.periods.data.to(dtype=torch.float32)
            if self.boundary_log_multiplier is not None:
                self.boundary_log_multiplier.data = (
                    self.boundary_log_multiplier.data.to(dtype=torch.float32)
                )
            if self.beta_hi_u is not None:
                self.beta_hi_u.data = self.beta_hi_u.data.to(dtype=torch.float32)
            if self.beta_lo_u is not None:
                self.beta_lo_u.data = self.beta_lo_u.data.to(dtype=torch.float32)
            if self.beta_bend_u is not None:
                self.beta_bend_u.data = self.beta_bend_u.data.to(dtype=torch.float32)
            if self.alpha is not None:
                self.alpha.data = self.alpha.data.to(dtype=torch.float32)
        return out

    def _init_periods(self) -> None:
        """Initialize per-dimension periods using DINOv3 formulas."""
        device: torch.device = self.periods.device
        dtype: torch.dtype = self.periods.dtype
        d_head = int(self._d_head)
        qtr = d_head // 4
        if qtr <= 0:
            raise RuntimeError("AxialRoPE2D periods length must be positive")
        if self.cfg.base is not None:
            base = float(self.cfg.base)
            exponents = (
                2.0
                * torch.arange(int(qtr), device=device, dtype=dtype)
                / float(d_head // 2)
            )
            periods = torch.tensor(base, device=device, dtype=dtype) ** exponents
        else:
            if self.cfg.min_period is None or self.cfg.max_period is None:
                raise RuntimeError(
                    "AxialRoPE2DConfig must provide min_period and max_period when base is None"
                )
            min_p = float(self.cfg.min_period)
            max_p = float(self.cfg.max_period)
            base = max_p / min_p
            exponents = torch.linspace(0.0, 1.0, int(qtr), device=device, dtype=dtype)
            periods = torch.tensor(base, device=device, dtype=dtype) ** exponents
            periods = periods / torch.tensor(base, device=device, dtype=dtype)
            periods = periods * torch.tensor(max_p, device=device, dtype=dtype)
        self.periods.data = periods
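
The two branches of `_init_periods` reduce to simple closed forms. With `base` set, band `i` gets period `base ** (2 * i / (d_head // 2))`. Otherwise the periods are log-spaced from `min_period` to `max_period`. The following is a minimal pure-Python sketch of that arithmetic. The standalone helper `init_periods` is hypothetical (it is not part of the export) and uses plain floats instead of tensors.

```python
def init_periods(d_head, base=None, min_period=None, max_period=None):
    """Hypothetical float-only mirror of the period initialization."""
    qtr = d_head // 4
    if qtr <= 0:
        raise ValueError("periods length must be positive")
    if base is not None:
        # Base branch: period_i = base ** (2 * i / (d_head // 2)).
        half = float(d_head // 2)
        return [float(base) ** (2.0 * i / half) for i in range(qtr)]
    if min_period is None or max_period is None:
        raise ValueError("min_period and max_period are required when base is None")
    ratio = float(max_period) / float(min_period)
    # Evenly spaced exponents in [0, 1]; a single band gets exponent 0.
    exponents = [0.0] if qtr == 1 else [i / (qtr - 1) for i in range(qtr)]
    # ratio ** e / ratio * max_period runs from min_period up to max_period.
    return [float(max_period) * ratio ** (e - 1.0) for e in exponents]


log_spaced = init_periods(64, min_period=0.5, max_period=8.0)
base_spaced = init_periods(64, base=100.0)
```

In the min/max branch the first band lands on `min_period` and the last on `max_period`, which is why the code divides by `base` before multiplying by `max_p`.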

    def forward(
        self,
        *,
        H: int,
        W: int,
        scales: Tensor | None,
    ) -> tuple[Tensor, Tensor]:
        """Return (sin, cos) buffers for axial 2D RoPE.

        Args:
            H: Patch-grid height.
            W: Patch-grid width.
            scales: Optional per-batch dilation scale (scalar tensor). When
                None, returns shared sin/cos shaped ``[HW, head_dim]``. When
                provided, applies the scalar dilation and still returns shared
                sin/cos shaped ``[HW, head_dim]``.
        """
        if int(H) <= 0 or int(W) <= 0:
            raise ValueError("H and W must be positive for AxialRoPE2D forward")
        device = self.periods.device
        offset = float(self.cfg.coord_offset)
        coords: Tensor
        match self.cfg.coord_mode:
            case AxialRoPE2DCoordMode.DINOV3_NORMALIZED:
                coords = _get_dinov3_normalized_coords(
                    int(H),
                    int(W),
                    device=device,
                    normalize=self.cfg.normalize_coords,
                    offset=offset,
                )
            case AxialRoPE2DCoordMode.PATCH_INDICES:
                coords = _get_patch_index_coords(
                    int(H), int(W), device=device, offset=offset
                )
            case _ as unreachable:  # pragma: no cover - defensive
                raise RuntimeError(f"Unsupported coord_mode: {unreachable}")
        if coords.dim() != 2 or coords.shape[1] != 2:
            raise RuntimeError("AxialRoPE2D coords must have shape [HW, 2]")
        if self.cfg.frequency_aware is not None:
            if scales is not None:
                raise ValueError(
                    "frequency-aware axial RoPE does not support dilation scales"
                )
            if self.boundary_log_multiplier is None:
                raise RuntimeError(
                    "boundary_log_multiplier parameter missing for frequency-aware RoPE"
                )
            ref_h = int(self.cfg.frequency_aware.ref_h_tokens)
            ref_w = int(self.cfg.frequency_aware.ref_w_tokens)
            periods_h = lumina_frequency_aware_periods_for_axis(
                periods=self.periods,
                axis_len=int(H),
                ref_axis_len=ref_h,
                boundary_log_multiplier=self.boundary_log_multiplier,
                angle_multiplier=float(self.cfg.angle_multiplier),
            )
            periods_w = lumina_frequency_aware_periods_for_axis(
                periods=self.periods,
                axis_len=int(W),
                ref_axis_len=ref_w,
                boundary_log_multiplier=self.boundary_log_multiplier,
                angle_multiplier=float(self.cfg.angle_multiplier),
            )
            axis_periods = torch.stack([periods_h, periods_w], dim=0)  # [2, Q]
        elif self.cfg.beta_warp is not None:
            if scales is not None:
                raise ValueError(
                    "beta-warp axial RoPE does not support dilation scales"
                )
            if (
                self.beta_hi_u is None
                or self.beta_lo_u is None
                or self.beta_bend_u is None
            ):
                raise RuntimeError("beta warp parameters missing for beta-warp RoPE")
            beta_cfg = self.cfg.beta_warp
            ref_h = int(beta_cfg.ref_h_tokens)
            ref_w = int(beta_cfg.ref_w_tokens)
            qtr = int(self.periods.numel())
            if qtr <= 0:  # pragma: no cover - defensive (checked elsewhere)
                raise RuntimeError("AxialRoPE2D periods length must be positive")
            beta_max = float(beta_cfg.beta_max)
            if not math.isfinite(beta_max) or beta_max <= 0.0:
                raise RuntimeError("beta_max must be finite and > 0")
            beta_max_t = torch.tensor(beta_max, device=device, dtype=torch.float32)
            beta_hi = beta_max_t * torch.tanh(self.beta_hi_u.to(dtype=torch.float32))
            beta_lo = beta_max_t * torch.tanh(self.beta_lo_u.to(dtype=torch.float32))
            beta_bend = beta_max_t * torch.tanh(
                self.beta_bend_u.to(dtype=torch.float32)
            )
            if qtr == 1:
                beta = beta_hi[None]
            else:
                t = torch.arange(int(qtr), device=device, dtype=torch.float32) / float(
                    qtr - 1
                )
                bump = 4.0 * t * (1.0 - t)
                beta = (1.0 - t) * beta_hi + t * beta_lo + beta_bend * bump

            s_h = float(int(H)) / float(ref_h)
            s_w = float(int(W)) / float(ref_w)
            if (
                not math.isfinite(s_h)
                or s_h <= 0.0
                or not math.isfinite(s_w)
                or s_w <= 0.0
            ):
                raise RuntimeError(
                    "Computed invalid axis scale factors for beta-warp RoPE"
                )
            periods_h = self.periods.to(dtype=torch.float32) * torch.pow(
                torch.tensor(s_h, device=device, dtype=torch.float32), beta
            )
            periods_w = self.periods.to(dtype=torch.float32) * torch.pow(
                torch.tensor(s_w, device=device, dtype=torch.float32), beta
            )
            axis_periods = torch.stack([periods_h, periods_w], dim=0)  # [2, Q]
        elif self.cfg.alpha_warp is not None:
            if scales is not None:
                raise ValueError(
                    "alpha-warp axial RoPE does not support dilation scales"
                )
            if self.alpha is None:
                raise RuntimeError("alpha parameter missing for alpha-warp RoPE")
            alpha_cfg = self.cfg.alpha_warp
            ref_h = int(alpha_cfg.ref_h_tokens)
            ref_w = int(alpha_cfg.ref_w_tokens)
            qtr = int(self.periods.numel())
            if int(self.alpha.numel()) != qtr:
                raise RuntimeError(
                    "alpha length must match RoPE periods length for alpha-warp RoPE"
                )
            s_h = float(int(H)) / float(ref_h)
            s_w = float(int(W)) / float(ref_w)
            if (
                not math.isfinite(s_h)
                or s_h <= 0.0
                or not math.isfinite(s_w)
                or s_w <= 0.0
            ):
                raise RuntimeError(
                    "Computed invalid axis scale factors for alpha-warp RoPE"
                )
            alpha = self.alpha.to(device=device, dtype=torch.float32)
            scale_h = torch.pow(
                torch.tensor(s_h, device=device, dtype=torch.float32), alpha
            )
            scale_w = torch.pow(
                torch.tensor(s_w, device=device, dtype=torch.float32), alpha
            )
            periods_h = self.periods.to(dtype=torch.float32) / scale_h
            periods_w = self.periods.to(dtype=torch.float32) / scale_w
            axis_periods = torch.stack([periods_h, periods_w], dim=0)  # [2, Q]
        else:
            axis_periods = self.periods[None, :].expand(2, -1).to(dtype=torch.float32)

        # Angles: angle_multiplier * coords / periods, flattened and tiled.
        angles = (
            float(self.cfg.angle_multiplier)
            * coords[:, :, None].to(dtype=torch.float32)
            / axis_periods[None, :, :].to(dtype=torch.float32)
        )
        match self.cfg.dim_layout:
            case AxialRoPE2DDimLayout.HALF_SPLIT:
                angles = angles.flatten(1, 2).repeat(1, 2)
            case AxialRoPE2DDimLayout.PAIR_INTERLEAVED:
                angles = angles.repeat_interleave(2, dim=-1).flatten(1, 2)
            case _ as unreachable:  # pragma: no cover - defensive
                raise RuntimeError(f"Unsupported dim_layout: {unreachable}")
        if angles.shape != (int(H) * int(W), int(self._d_head)):
            raise RuntimeError(
                "Unexpected angles shape in AxialRoPE2D: "
                f"{tuple(angles.shape)} for H={int(H)} W={int(W)}"
            )
        if scales is not None:
            if scales.dim() != 0:
                raise ValueError(
                    "AxialRoPE2D scales must be a scalar tensor for per-batch dilation; "
                    "per-sample dilation is not supported"
                )
            angles = angles * scales.to(device=device, dtype=torch.float32)
        cos = torch.cos(angles)
        sin = torch.sin(angles)
        return sin, cos
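
The final tiling step in `forward` is easier to follow for a single token. The sketch below computes `angle_multiplier * coord / period` per axis and then lays the angles out the way the `HALF_SPLIT` and `PAIR_INTERLEAVED` branches do. It is a plain-list illustration only: `token_angles` and the string layout names are hypothetical stand-ins for the tensor code and the `AxialRoPE2DDimLayout` enum, and it assumes the first coordinate column is the height axis, matching the `[periods_h, periods_w]` stacking order.

```python
def token_angles(y, x, periods_h, periods_w, angle_multiplier=1.0, layout="half_split"):
    """Hypothetical single-token version of the angle tiling."""
    ang_h = [angle_multiplier * y / p for p in periods_h]
    ang_w = [angle_multiplier * x / p for p in periods_w]
    flat = ang_h + ang_w  # equivalent of flatten(1, 2) on a [2, Q] slice
    if layout == "half_split":
        # [h..., w..., h..., w...]: the whole half is repeated once.
        return flat + flat
    if layout == "pair_interleaved":
        # [h0, h0, h1, h1, ..., w0, w0, ...]: each angle is duplicated in place.
        return [a for a in flat for _ in range(2)]
    raise ValueError(f"unsupported layout: {layout}")


half = token_angles(2.0, 3.0, [1.0, 2.0], [1.0, 2.0], layout="half_split")
pair = token_angles(2.0, 3.0, [1.0, 2.0], [1.0, 2.0], layout="pair_interleaved")
```

With `Q = 2` bands per axis both layouts produce `4 * Q = 8` angles, which is the `d_head` length that the shape check in `forward` enforces.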


def _dy_ntk_periods_for_axis(
    *,
    periods: Tensor,
    axis_len: int,
    ref_axis_len: int,
    noise_time: Tensor,
    lambda_s: float,
    lambda_t: float,
) -> Tensor:
    """Return Dy-NTK periods for one spatial axis.

    Raises:
        ValueError: If token lengths or scheduler values are invalid.
    """

    if int(axis_len) <= 0 or int(ref_axis_len) <= 0:
        raise ValueError("axis_len and ref_axis_len must be positive for Dy-NTK")
    qtr = int(periods.numel())
    if qtr <= 0:
        raise ValueError("periods must be non-empty for Dy-NTK")
    scale = float(int(axis_len)) / float(int(ref_axis_len))
    if not math.isfinite(scale) or scale <= 0.0:
        raise ValueError("Dy-NTK axis scale must be finite and > 0")
    return _dy_ntk_periods_for_scale(
        periods=periods,
        scale=scale,
        noise_time=noise_time,
        lambda_s=float(lambda_s),
        lambda_t=float(lambda_t),
    )


def _dy_ntk_periods_for_scale(
    *,
    periods: Tensor,
    scale: float,
    noise_time: Tensor,
    lambda_s: float,
    lambda_t: float,
) -> Tensor:
    """Return Dy-NTK periods for a precomputed axis scale."""

    axis_scale = float(scale)
    if not math.isfinite(axis_scale) or axis_scale <= 0.0:
        raise ValueError("Dy-NTK scale must be finite and > 0")
    qtr = int(periods.numel())
    if qtr <= 0:
        raise ValueError("periods must be non-empty for Dy-NTK")
    if scale <= 1.0:
        return periods.to(dtype=torch.float32)
    if qtr == 1:
        exponent = torch.zeros((1,), device=periods.device, dtype=torch.float32)
    else:
        exponent = torch.arange(qtr, device=periods.device, dtype=torch.float32) / (
            float(qtr - 1)
        )
    kappa = float(lambda_s) * torch.pow(
        noise_time.to(device=periods.device, dtype=torch.float32),
        float(lambda_t),
    )
    return periods.to(dtype=torch.float32) * torch.pow(
        torch.tensor(axis_scale, device=periods.device, dtype=torch.float32),
        kappa * exponent,
    )
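
As a worked illustration of the Dy-NTK rule above: band `i` of `Q` bands is stretched by `scale ** (kappa * i / (Q - 1))` with `kappa = lambda_s * noise_time ** lambda_t`, so the shortest period is never changed and the longest is stretched the most. The sketch below restates this with plain floats. The helper `dy_ntk_periods` is hypothetical and skips the input validation performed by the real function.

```python
def dy_ntk_periods(periods, scale, noise_time, lambda_s, lambda_t):
    """Hypothetical float-only mirror of the Dy-NTK period stretch."""
    if scale <= 1.0:
        # No extrapolation needed at or below the reference grid size.
        return list(periods)
    qtr = len(periods)
    kappa = lambda_s * noise_time ** lambda_t
    out = []
    for i, p in enumerate(periods):
        frac = 0.0 if qtr == 1 else i / (qtr - 1)
        out.append(p * scale ** (kappa * frac))
    return out


# Full noise (t = 1): the last band is stretched by the full scale factor.
noisy = dy_ntk_periods([1.0, 2.0, 4.0], scale=2.0, noise_time=1.0, lambda_s=1.0, lambda_t=1.0)
# No noise (t = 0): kappa is 0, so the periods are returned unchanged.
clean = dy_ntk_periods([1.0, 2.0, 4.0], scale=2.0, noise_time=0.0, lambda_s=1.0, lambda_t=1.0)
```

Tying the stretch to `noise_time` means the extrapolation fades out as sampling approaches the clean image.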


def _dype_dynamic_exponent(
    *, noise_time: float, lambda_s: float, lambda_t: float
) -> float:
    """Return Comfy/DyPE-style dynamic magnitude for normalized noise time."""

    noise = float(noise_time)
    if not math.isfinite(noise):
        raise ValueError("DyPE noise_time must be finite")
    noise = max(0.0, min(1.0, noise))
    scale = float(lambda_s)
    exponent = float(lambda_t)
    if not math.isfinite(scale) or scale <= 0.0:
        raise ValueError("DyPE lambda_s must be finite and > 0")
    if not math.isfinite(exponent) or exponent <= 0.0:
        raise ValueError("DyPE lambda_t must be finite and > 0")
    return scale * (noise**exponent)


def _dype_correction_factor(
    *,
    periods: Tensor,
    rotations: float,
    ref_axis_len: int,
    angle_multiplier: float,
) -> float:
    """Return fractional band index whose wavelength makes ``rotations`` turns."""

    if int(ref_axis_len) <= 0:
        raise ValueError("ref_axis_len must be positive for DyPE correction")
    rot = float(rotations)
    if not math.isfinite(rot) or rot <= 0.0:
        raise ValueError("rotations must be finite and > 0")
    mult = float(angle_multiplier)
    if not math.isfinite(mult) or mult <= 0.0:
        raise ValueError("angle_multiplier must be finite and > 0")
    if int(periods.numel()) < 2:
        return 0.0
    periods_cpu = periods.detach().to(device=torch.device("cpu"), dtype=torch.float32)
    p0 = float(periods_cpu[0].item())
    p1 = float(periods_cpu[-1].item())
    if p0 <= 0.0 or p1 <= p0:
        raise ValueError("periods must be positive and strictly increasing for DyPE")
    boundary_wavelength = float(int(ref_axis_len)) / rot
    boundary_period = (mult / (2.0 * float(math.pi))) * boundary_wavelength
    log_p0 = math.log(p0)
    log_p1 = math.log(p1)
    return float(periods.numel() - 1) * (
        (math.log(boundary_period) - log_p0) / (log_p1 - log_p0)
    )


def _dype_ramp_mask(
    *,
    periods: Tensor,
    threshold_high_rotations: float,
    threshold_low_rotations: float,
    ref_axis_len: int,
    angle_multiplier: float,
) -> Tensor:
    """Return YaRN's high-to-low band mask for one dynamic threshold pair."""

    qtr = int(periods.numel())
    if qtr <= 0:
        raise ValueError("periods must be non-empty for DyPE ramp mask")
    device = periods.device
    if qtr == 1:
        return torch.ones((1,), device=device, dtype=torch.float32)
    low = math.floor(
        _dype_correction_factor(
            periods=periods,
            rotations=float(threshold_high_rotations),
            ref_axis_len=int(ref_axis_len),
            angle_multiplier=float(angle_multiplier),
        )
    )
    high = math.ceil(
        _dype_correction_factor(
            periods=periods,
            rotations=float(threshold_low_rotations),
            ref_axis_len=int(ref_axis_len),
            angle_multiplier=float(angle_multiplier),
        )
    )
    low = max(0, min(qtr - 1, int(low)))
    high = max(0, min(qtr, int(high)))
    if low == high:
        high = min(qtr, low + 1)
    band = torch.arange(qtr, device=device, dtype=torch.float32)
    ramp = (band - float(low)) / float(high - low)
    return 1.0 - torch.clamp(ramp, min=0.0, max=1.0)
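
The correction factor and the ramp mask work together. The first maps a rotation count to a fractional band index by interpolating in log-period space between the first and last band. The second turns two such indices into a mask that is 1 for the high-frequency bands, 0 for the low-frequency bands, and linear in between. The following is a float-only sketch under those assumptions; `band_index` and `ramp_mask` are hypothetical names and omit the validation of the functions above.

```python
import math


def band_index(periods, rotations, ref_axis_len, angle_multiplier):
    """Fractional band whose wavelength completes `rotations` turns."""
    if len(periods) < 2:
        return 0.0
    boundary_wavelength = ref_axis_len / rotations
    boundary_period = angle_multiplier / (2.0 * math.pi) * boundary_wavelength
    log_p0, log_p1 = math.log(periods[0]), math.log(periods[-1])
    return (len(periods) - 1) * (math.log(boundary_period) - log_p0) / (log_p1 - log_p0)


def ramp_mask(periods, high_rotations, low_rotations, ref_axis_len, angle_multiplier):
    """1.0 below the `low` band, 0.0 from the `high` band on, linear between."""
    qtr = len(periods)
    if qtr == 1:
        return [1.0]
    low = math.floor(band_index(periods, high_rotations, ref_axis_len, angle_multiplier))
    high = math.ceil(band_index(periods, low_rotations, ref_axis_len, angle_multiplier))
    low = max(0, min(qtr - 1, low))
    high = max(0, min(qtr, high))
    if low == high:
        high = min(qtr, low + 1)  # avoid a zero-width ramp
    return [1.0 - min(1.0, max(0.0, (i - low) / (high - low))) for i in range(qtr)]


# Nine log-spaced bands; thresholds chosen to fall at band indices 1.5 and 5.5.
periods = [2.0 ** i for i in range(9)]
mask = ramp_mask(periods, 64 / 2 ** 1.5, 64 / 2 ** 5.5, 64, 2.0 * math.pi)
```

Here the ramp runs from band 1 to band 6, so bands 0 and 1 keep weight 1, bands 6 through 8 get weight 0, and the bands in between fall off in steps of 0.2.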


def _dy_yarn_periods_for_axis(
    *,
    periods: Tensor,
    linear_scale: float,
    ntk_scale: float,
    ref_axis_len: int,
    noise_time: float,
    lambda_s: float,
    cfg: AxialRoPE2DDyPEConfig,
    angle_multiplier: float,
) -> Tensor:
    """Return Dy-YaRN periods for one spatial axis."""

    if int(ref_axis_len) <= 0:
        raise ValueError("ref_axis_len must be positive for Dy-YaRN")
    linear_s = float(linear_scale)
    ntk_s = float(ntk_scale)
    if (
        not math.isfinite(linear_s)
        or linear_s <= 0.0
        or not math.isfinite(ntk_s)
        or ntk_s <= 0.0
    ):
        raise ValueError("Dy-YaRN axis scales must be finite and > 0")
    periods_f = periods.to(dtype=torch.float32)
    if max(linear_s, ntk_s) <= 1.0:
        return periods_f
    kappa = _dype_dynamic_exponent(
        noise_time=float(noise_time),
        lambda_s=float(lambda_s),
        lambda_t=float(cfg.lambda_t),
    )
    if kappa <= 1e-6:
        return periods_f
    freq_base = float(angle_multiplier) / periods_f
    freq_linear = float(angle_multiplier) / (periods_f * max(1.0, linear_s))
    periods_ntk = _dy_ntk_periods_for_scale(
        periods=periods_f,
        scale=max(1.0, ntk_s),
        noise_time=torch.ones((), device=periods.device, dtype=torch.float32),
        lambda_s=1.0,
        lambda_t=1.0,
    )
    freq_ntk = float(angle_multiplier) / periods_ntk

    beta_mask = _dype_ramp_mask(
        periods=periods_f,
        threshold_high_rotations=float(cfg.yarn_beta_0) ** kappa,
        threshold_low_rotations=float(cfg.yarn_beta_1) ** kappa,
        ref_axis_len=int(ref_axis_len),
        angle_multiplier=float(angle_multiplier),
    )
    freq = freq_linear * (1.0 - beta_mask) + freq_ntk * beta_mask

    gamma_mask = _dype_ramp_mask(
        periods=periods_f,
        threshold_high_rotations=float(cfg.yarn_gamma_0) ** kappa,
        threshold_low_rotations=float(cfg.yarn_gamma_1) ** kappa,
        ref_axis_len=int(ref_axis_len),
        angle_multiplier=float(angle_multiplier),
    )
    freq = freq * (1.0 - gamma_mask) + freq_base * gamma_mask
    return float(angle_multiplier) / freq


class AxialRoPE2DDyPE(AxialRoPE2D):
    """Inference-only axial RoPE wrapper using dynamic position extrapolation."""

    dype_cfg: AxialRoPE2DDyPEConfig
    dype_noise_time: Tensor
    dype_noise_time_values: list[float]

    def __init__(self, *, base: AxialRoPE2D, cfg: AxialRoPE2DDyPEConfig) -> None:
        if not isinstance(base, AxialRoPE2D):
            raise TypeError("base must be an AxialRoPE2D")
        if not isinstance(cfg, AxialRoPE2DDyPEConfig):
            raise TypeError("cfg must be an AxialRoPE2DDyPEConfig")
        if base.cfg.coord_mode is not AxialRoPE2DCoordMode.PATCH_INDICES:
            raise ValueError("DyPE requires patch-index axial RoPE coordinates")
        super().__init__(head_dim=int(base.head_dim), cfg=base.cfg)
        self.dype_cfg = cfg  # ty: ignore[unresolved-attribute]
        self.register_buffer(
            "dype_noise_time",
            torch.tensor(1.0, dtype=torch.float32),
            persistent=False,
        )
        self.dype_noise_time_values: list[float] = [1.0]
        with torch.no_grad():
            self.periods.copy_(base.periods.detach().to(dtype=torch.float32))

    def set_dype_noise_time(self, noise_time: float) -> None:
        """Set the current normalized diffusion noise time in ``[0, 1]``."""

        t = float(noise_time)
        if not math.isfinite(t) or t < 0.0 or t > 1.0:
            raise ValueError("DyPE noise_time must be finite and within [0, 1]")
        self.dype_noise_time.fill_(t)
        self.dype_noise_time_values[0] = t

    def _dype_axis_periods(
        self,
        *,
        axis_len: int,
        ref_axis_len: int,
        global_scale: float,
        lambda_s: float,
    ) -> Tensor:
        """Return method-specific periods for one spatial axis."""

        cfg = self.dype_cfg
        axis_scale = float(int(axis_len)) / float(int(ref_axis_len))
        shared_scale = float(global_scale)
        if (
            not math.isfinite(axis_scale)
            or axis_scale <= 0.0
            or not math.isfinite(shared_scale)
            or shared_scale <= 0.0
        ):
            raise ValueError("DyPE axis and global scales must be finite and > 0")
        match cfg.method:
            case DyPERoPEMethod.DY_NTK:
                return _dy_ntk_periods_for_scale(
                    periods=self.periods,
                    scale=shared_scale,
                    noise_time=self.dype_noise_time,
                    lambda_s=float(lambda_s),
                    lambda_t=float(cfg.lambda_t),
                )
            case DyPERoPEMethod.VISION_YARN:
                return _dy_yarn_periods_for_axis(
                    periods=self.periods,
                    linear_scale=axis_scale,
                    ntk_scale=shared_scale,
                    ref_axis_len=int(ref_axis_len),
                    noise_time=float(self.dype_noise_time_values[0]),
                    lambda_s=float(lambda_s),
                    cfg=cfg,
                    angle_multiplier=float(self.cfg.angle_multiplier),
                )
            case DyPERoPEMethod.DY_YARN:
                return _dy_yarn_periods_for_axis(
                    periods=self.periods,
                    linear_scale=shared_scale,
                    ntk_scale=shared_scale,
                    ref_axis_len=int(ref_axis_len),
                    noise_time=float(self.dype_noise_time_values[0]),
                    lambda_s=float(lambda_s),
                    cfg=cfg,
                    angle_multiplier=float(self.cfg.angle_multiplier),
                )
            case _ as unreachable:
                raise RuntimeError(f"Unsupported DyPE method: {unreachable}")

    def forward(
        self,
        *,
        H: int,
        W: int,
        scales: Tensor | None,
    ) -> tuple[Tensor, Tensor]:
        """Return timestep-aware DyPE sin/cos buffers."""

        if scales is not None:
            raise ValueError("DyPE axial RoPE does not support dilation scales")
        if int(H) <= 0 or int(W) <= 0:
            raise ValueError("H and W must be positive for DyPE axial RoPE")
        device = self.periods.device
        coords = _get_patch_index_coords(
            int(H), int(W), device=device, offset=float(self.cfg.coord_offset)
        )
        scale_h = float(int(H)) / float(int(self.dype_cfg.ref_h_tokens))
        scale_w = float(int(W)) / float(int(self.dype_cfg.ref_w_tokens))
        global_scale = max(scale_h, scale_w)
        periods_h = self._dype_axis_periods(
            axis_len=int(H),
            ref_axis_len=int(self.dype_cfg.ref_h_tokens),
            global_scale=global_scale,
            lambda_s=float(self.dype_cfg.lambda_s),
        )
        periods_w = self._dype_axis_periods(
            axis_len=int(W),
            ref_axis_len=int(self.dype_cfg.ref_w_tokens),
            global_scale=global_scale,
            lambda_s=float(self.dype_cfg.lambda_s),
        )
        axis_periods = torch.stack([periods_h, periods_w], dim=0)
        angles = (
            float(self.cfg.angle_multiplier)
            * coords[:, :, None].to(dtype=torch.float32)
            / axis_periods[None, :, :].to(dtype=torch.float32)
        )
        match self.cfg.dim_layout:
            case AxialRoPE2DDimLayout.HALF_SPLIT:
                angles = angles.flatten(1, 2).repeat(1, 2)
            case AxialRoPE2DDimLayout.PAIR_INTERLEAVED:
                angles = angles.repeat_interleave(2, dim=-1).flatten(1, 2)
            case _ as unreachable:
                raise RuntimeError(f"Unsupported dim_layout: {unreachable}")
        expected_shape = (int(H) * int(W), int(self.head_dim))
        if angles.shape != expected_shape:
            raise RuntimeError(
                "Unexpected angles shape in DyPE axial RoPE: "
                f"{tuple(angles.shape)} for expected {expected_shape}"
            )
        sin = torch.sin(angles)
        cos = torch.cos(angles)
        if (
            self.dype_cfg.method in (DyPERoPEMethod.VISION_YARN, DyPERoPEMethod.DY_YARN)
            and bool(self.dype_cfg.yarn_attention_scale)
            and global_scale > 1.0
        ):
            match self.dype_cfg.method:
                case DyPERoPEMethod.VISION_YARN:
                    mscale_start = 0.1 * math.log(global_scale) + 1.0
                    kappa = _dype_dynamic_exponent(
                        noise_time=float(self.dype_noise_time_values[0]),
                        lambda_s=1.0,
                        lambda_t=float(self.dype_cfg.lambda_t),
                    )
                    mscale = 1.0 + (mscale_start - 1.0) * kappa
                case DyPERoPEMethod.DY_YARN:
                    mscale = 1.0 + 0.1 * math.log(global_scale) / math.sqrt(
                        global_scale
                    )
                case _ as unreachable:  # pragma: no cover - guarded above
                    raise RuntimeError(
                        f"Unsupported YaRN attention scale: {unreachable}"
                    )
            if mscale > 1.0:
                sin = sin * float(mscale)
                cos = cos * float(mscale)
        return sin, cos
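
The attention-scale step at the end of `forward` multiplies both `sin` and `cos` by a scalar `mscale` that depends on the method. A minimal float-only sketch of the two formulas follows. The helper `yarn_attention_scale` and its string method names are hypothetical, and it assumes the `yarn_attention_scale` config flag is enabled.

```python
import math


def yarn_attention_scale(method, global_scale, noise_time=1.0, lambda_t=1.0):
    """Hypothetical mirror of the mscale applied to sin/cos."""
    if global_scale <= 1.0:
        return 1.0  # no scaling at or below the reference grid size
    if method == "vision_yarn":
        # Starts at 1 + 0.1 * ln(s) and decays towards 1 as noise_time -> 0.
        mscale_start = 0.1 * math.log(global_scale) + 1.0
        kappa = max(0.0, min(1.0, noise_time)) ** lambda_t
        mscale = 1.0 + (mscale_start - 1.0) * kappa
    elif method == "dy_yarn":
        # Time-independent: 1 + 0.1 * ln(s) / sqrt(s).
        mscale = 1.0 + 0.1 * math.log(global_scale) / math.sqrt(global_scale)
    else:
        raise ValueError(f"unsupported method: {method}")
    # The source only rescales when mscale > 1, which is the same as flooring at 1.
    return max(1.0, mscale)


vision_full_noise = yarn_attention_scale("vision_yarn", math.e, noise_time=1.0)
vision_no_noise = yarn_attention_scale("vision_yarn", math.e, noise_time=0.0)
dy_yarn = yarn_attention_scale("dy_yarn", 4.0)
```

Only the `VISION_YARN` scale depends on the noise time. The `DY_YARN` scale is fixed by the grid-size ratio alone.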


def build_axial_rope2d_dype(
    *, base: AxialRoPE2D, cfg: AxialRoPE2DDyPEConfig
) -> AxialRoPE2DDyPE:
    """Build an inference-only DyPE wrapper for an existing axial RoPE."""

    return AxialRoPE2DDyPE(base=base, cfg=cfg).to(device=base.periods.device)


def set_axial_rope2d_dype_noise_time(module: nn.Module, *, noise_time: float) -> bool:
    """Set DyPE noise time on all axial DyPE modules inside ``module``."""

    updated = False
    for child in module.modules():
        match child:
            case AxialRoPE2DDyPE() as dype:
                dype.set_dype_noise_time(float(noise_time))
                updated = True
            case _:
                pass
    return updated
dit/blocks.py
ADDED
|
@@ -0,0 +1,259 @@
| 1 |
+
"""Dense unconditional DiT blocks used by the DINAC-AE export."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import Tensor, nn
|
| 7 |
+
|
| 8 |
+
from common.norms import RMSNorm
|
| 9 |
+
from common.rope import Rope1D
|
| 10 |
+
from dit.attention_blocks import DitSelfAttentionCore
|
| 11 |
+
from dit.body_config import DiTConditioning
|
| 12 |
+
from dit.mlp import build_dit_mlp, reset_module_parameters
|
| 13 |
+
from dit.mlp_types import MLPType
|
| 14 |
+
from dit.position_encoding import DiTPositionEncoding
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _flatten_tokens(
|
| 18 |
+
x: Tensor, hw: tuple[int, int] | None
|
| 19 |
+
) -> tuple[Tensor, tuple[int, int], bool]:
|
| 20 |
+
"""Return dense tokens plus spatial metadata."""
|
| 21 |
+
|
| 22 |
+
if x.dim() == 4:
|
| 23 |
+
batch, channels, height, width = x.shape
|
| 24 |
+
tokens = x.permute(0, 2, 3, 1).reshape(batch, height * width, channels)
|
| 25 |
+
return tokens, (int(height), int(width)), True
|
| 26 |
+
return x, hw if hw is not None else (int(x.shape[1]), 1), False
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _restore_spatial(tokens: Tensor, hw: tuple[int, int]) -> Tensor:
|
| 30 |
+
"""Restore dense tokens to NCHW features."""
|
| 31 |
+
|
| 32 |
+
batch, _sequence_length, width = tokens.shape
|
| 33 |
+
height, spatial_width = hw
|
| 34 |
+
return tokens.transpose(1, 2).reshape(batch, width, height, spatial_width)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class TransformerBlock(nn.Module):
|
| 38 |
+
"""Dense pre-norm transformer block kept for import compatibility."""
|
| 39 |
+
|
| 40 |
+
d_model: int
|
| 41 |
+
n_heads: int
|
| 42 |
+
attn_norm: RMSNorm | None
|
| 43 |
+
mlp_norm: RMSNorm | None
|
| 44 |
+
self_attn: DitSelfAttentionCore
|
| 45 |
+
rope_1d: Rope1D | None
|
| 46 |
+
mlp: nn.Module
|
| 47 |
+
|
| 48 |
+
def __init__(
|
| 49 |
+
self,
|
| 50 |
+
*,
|
| 51 |
+
d_model: int,
|
| 52 |
+
n_heads: int,
|
| 53 |
+
mlp_ratio: float,
|
| 54 |
+
mlp_type: MLPType,
|
| 55 |
+
activation_config: object | None = None,
|
| 56 |
+
block_index: int = 0,
|
| 57 |
+
use_norms: bool = True,
|
| 58 |
+
position_encoding: DiTPositionEncoding = DiTPositionEncoding.NONE,
|
| 59 |
+
rope_theta: float | None = None,
|
| 60 |
+
rope_max_position_embeddings: int | None = None,
|
| 61 |
+
) -> None:
|
| 62 |
+
super().__init__()
|
| 63 |
+
self.d_model = int(d_model)
|
| 64 |
+
self.n_heads = int(n_heads)
|
| 65 |
+
self.attn_norm = RMSNorm(self.d_model) if bool(use_norms) else None
|
| 66 |
+
self.mlp_norm = RMSNorm(self.d_model) if bool(use_norms) else None
|
| 67 |
+
self.self_attn = DitSelfAttentionCore(
|
| 68 |
+
d_model=self.d_model,
|
| 69 |
+
n_heads=self.n_heads,
|
| 70 |
+
position_encoding=position_encoding,
|
| 71 |
+
)
|
| 72 |
+
self.rope_1d = self._build_rope_1d(
|
| 73 |
+
position_encoding=position_encoding,
|
| 74 |
+
rope_theta=rope_theta,
|
| 75 |
+
rope_max_position_embeddings=rope_max_position_embeddings,
|
| 76 |
+
)
|
| 77 |
+
self.mlp = build_dit_mlp(
|
| 78 |
+
mlp_type=mlp_type,
|
| 79 |
+
in_features=self.d_model,
|
| 80 |
+
hidden_budget=int(round(float(mlp_ratio) * self.d_model)),
|
| 81 |
+
activation_config=activation_config,
|
| 82 |
+
block_index=int(block_index),
|
| 83 |
+
bias_up=False,
|
| 84 |
+
bias_down=False,
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
def reset_parameters(self) -> None:
|
| 88 |
+
"""Reset attention and MLP parameters."""
|
| 89 |
+
|
| 90 |
+
self.self_attn.reset_parameters()
|
| 91 |
+
reset_module_parameters(self.mlp)
|
| 92 |
+
|
| 93 |
+
def _build_rope_1d(
|
| 94 |
+
self,
|
| 95 |
+
*,
|
| 96 |
+
position_encoding: DiTPositionEncoding,
|
| 97 |
+
rope_theta: float | None,
|
| 98 |
+
rope_max_position_embeddings: int | None,
|
| 99 |
+
) -> Rope1D | None:
|
| 100 |
+
"""Build 1D RoPE for sequence-only transformer blocks."""
|
| 101 |
+
|
| 102 |
+
match position_encoding:
|
| 103 |
+
case DiTPositionEncoding.NONE:
|
| 104 |
+
return None
|
| 105 |
+
case DiTPositionEncoding.ROPE_1D:
|
| 106 |
+
if rope_theta is None or rope_max_position_embeddings is None:
|
| 107 |
+
raise ValueError("ROPE_1D requires rope_theta and rope_max_position_embeddings")
|
| 108 |
+
return Rope1D(
|
| 109 |
+
dim=int(self.d_model // self.n_heads),
|
| 110 |
+
max_position_embeddings=int(rope_max_position_embeddings),
|
| 111 |
+
base=float(rope_theta),
|
| 112 |
+
)
|
| 113 |
+
case _ as unreachable:
|
| 114 |
+
raise ValueError(f"Unsupported TransformerBlock RoPE: {unreachable}")
|
| 115 |
+
|
| 116 |
+
def forward(self, tokens: Tensor, *, generator: torch.Generator | None) -> Tensor: # type: ignore[override]
|
| 117 |
+
"""Apply pre-norm dense self-attention and MLP to token sequences (``generator`` is unused)."""
|
| 118 |
+
|
| 119 |
+
_ = generator
|
| 120 |
+
attn_in = self.attn_norm(tokens) if self.attn_norm is not None else tokens
|
| 121 |
+
rope_sincos = self._build_rope_sincos(attn_in)
|
| 122 |
+
x = tokens + self.self_attn(attn_in, rope_sincos=rope_sincos)
|
| 123 |
+
mlp_in = self.mlp_norm(x) if self.mlp_norm is not None else x
|
| 124 |
+
return x + self.mlp(mlp_in)
|
| 125 |
+
|
| 126 |
+
def _build_rope_sincos(self, tokens: Tensor) -> tuple[Tensor, Tensor] | None:
|
| 127 |
+
"""Return dense 1D RoPE sin/cos buffers."""
|
| 128 |
+
|
| 129 |
+
rope = self.rope_1d
|
| 130 |
+
if rope is None:
|
| 131 |
+
return None
|
| 132 |
+
batch = int(tokens.shape[0])
|
| 133 |
+
seqlen = int(tokens.shape[1])
|
| 134 |
+
position_ids = torch.arange(
|
| 135 |
+
seqlen,
|
| 136 |
+
device=tokens.device,
|
| 137 |
+
dtype=torch.int64,
|
| 138 |
+
).unsqueeze(0)
|
| 139 |
+
position_ids = position_ids.expand(batch, seqlen)
|
| 140 |
+
dummy = tokens.new_empty(batch, self.n_heads, seqlen, rope.dim)
|
| 141 |
+
cos, sin = rope(dummy, position_ids)  # unpacked as (cos, sin); returned below as (sin, cos)
|
| 142 |
+
return sin, cos
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class DitBlock(nn.Module):
|
| 146 |
+
"""Dense unconditional DiT self-attention block."""
|
| 147 |
+
|
| 148 |
+
d: int
|
| 149 |
+
h: int
|
| 150 |
+
dh: int
|
| 151 |
+
hidden_budget: int
|
| 152 |
+
position_encoding: DiTPositionEncoding
|
| 153 |
+
conditioning: DiTConditioning
|
| 154 |
+
adaln: object | None
|
| 155 |
+
gate_attn: nn.Parameter | None
|
| 156 |
+
gate_mlp: nn.Parameter | None
|
| 157 |
+
use_norms: bool
|
| 158 |
+
attn_norm1: RMSNorm
|
| 159 |
+
attn_norm2: RMSNorm
|
| 160 |
+
mlp_norm1: RMSNorm
|
| 161 |
+
mlp_norm2: RMSNorm
|
| 162 |
+
attn_core: DitSelfAttentionCore
|
| 163 |
+
qkv: nn.Linear
|
| 164 |
+
proj_out: nn.Linear
|
| 165 |
+
mlp: nn.Module
|
| 166 |
+
|
| 167 |
+
def __init__(
|
| 168 |
+
self,
|
| 169 |
+
d_model: int,
|
| 170 |
+
n_heads: int,
|
| 171 |
+
mlp_ratio: float,
|
| 172 |
+
*,
|
| 173 |
+
adaln: object | None = None,
|
| 174 |
+
mlp_type: MLPType = MLPType.GELU,
|
| 175 |
+
activation_config: object | None = None,
|
| 176 |
+
block_index: int = 0,
|
| 177 |
+
use_norms: bool = True,
|
| 178 |
+
position_encoding: DiTPositionEncoding = DiTPositionEncoding.NONE,
|
| 179 |
+
conditioning: DiTConditioning = DiTConditioning.UNCOND,
|
| 180 |
+
) -> None:
|
| 181 |
+
super().__init__()
|
| 182 |
+
if conditioning is not DiTConditioning.UNCOND or adaln is not None:
|
| 183 |
+
raise ValueError("DINAC-AE export only supports unconditional DitBlock")
|
| 184 |
+
self.d = int(d_model)
|
| 185 |
+
self.h = int(n_heads)
|
| 186 |
+
self.dh = int(self.d // self.h)
|
| 187 |
+
self.hidden_budget = int(float(mlp_ratio) * self.d)
|
| 188 |
+
self.position_encoding = position_encoding
|
| 189 |
+
self.conditioning = conditioning
|
| 190 |
+
self.adaln = None
|
| 191 |
+
self.gate_attn = None
|
| 192 |
+
self.gate_mlp = None
|
| 193 |
+
self.use_norms = bool(use_norms)
|
| 194 |
+
self.attn_norm1 = RMSNorm(self.d)
|
| 195 |
+
self.attn_norm2 = RMSNorm(self.d)
|
| 196 |
+
self.mlp_norm1 = RMSNorm(self.d)
|
| 197 |
+
self.mlp_norm2 = RMSNorm(self.d)
|
| 198 |
+
self.attn_core = DitSelfAttentionCore(
|
| 199 |
+
d_model=self.d,
|
| 200 |
+
n_heads=self.h,
|
| 201 |
+
position_encoding=position_encoding,
|
| 202 |
+
)
|
| 203 |
+
self.qkv = self.attn_core.qkv
|
| 204 |
+
self.proj_out = self.attn_core.proj_out
|
| 205 |
+
self.mlp = build_dit_mlp(
|
| 206 |
+
mlp_type=mlp_type,
|
| 207 |
+
in_features=self.d,
|
| 208 |
+
hidden_budget=self.hidden_budget,
|
| 209 |
+
activation_config=activation_config,
|
| 210 |
+
block_index=int(block_index),
|
| 211 |
+
bias_up=False,
|
| 212 |
+
bias_down=False,
|
| 213 |
+
)
|
| 214 |
+
self.reset_parameters()
|
| 215 |
+
|
| 216 |
+
def reset_parameters(self) -> None:
|
| 217 |
+
"""Reset attention and MLP parameters."""
|
| 218 |
+
|
| 219 |
+
self.attn_core.reset_parameters()
|
| 220 |
+
reset_module_parameters(self.mlp)
|
| 221 |
+
|
| 222 |
+
def compile_for_training(self, *, fullgraph: bool, dynamic: bool) -> None:
|
| 223 |
+
"""No-op hook kept for API compatibility."""
|
| 224 |
+
|
| 225 |
+
_ = fullgraph, dynamic
|
| 226 |
+
|
| 227 |
+
def compile_for_eval(self, *, fullgraph: bool, dynamic: bool) -> None:
|
| 228 |
+
"""No-op hook kept for API compatibility."""
|
| 229 |
+
|
| 230 |
+
_ = fullgraph, dynamic
|
| 231 |
+
|
| 232 |
+
def forward(
|
| 233 |
+
self,
|
| 234 |
+
x: Tensor,
|
| 235 |
+
hw: tuple[int, int],
|
| 236 |
+
cond_vec: Tensor,
|
| 237 |
+
adaln_m: Tensor | None = None,
|
| 238 |
+
*,
|
| 239 |
+
rope_sincos: tuple[Tensor, Tensor] | None = None,
|
| 240 |
+
generator: torch.Generator | None = None,
|
| 241 |
+
) -> Tensor:
|
| 242 |
+
"""Apply the dense unconditional block to spatial features or tokens; ``cond_vec``, ``adaln_m`` and ``generator`` are accepted for API compatibility and ignored."""
|
| 243 |
+
|
| 244 |
+
_ = cond_vec, adaln_m, generator
|
| 245 |
+
tokens, hw_tokens, was_spatial = _flatten_tokens(x, hw)
|
| 246 |
+
attn_in = self.attn_norm1(tokens) if self.use_norms else tokens
|
| 247 |
+
y = self.attn_core(attn_in, rope_sincos=rope_sincos)
|
| 248 |
+
attn_out = self.attn_norm2(y) if self.use_norms else y
|
| 249 |
+
tokens = tokens + attn_out
|
| 250 |
+
mlp_in = self.mlp_norm1(tokens) if self.use_norms else tokens
|
| 251 |
+
mlp_out = self.mlp(mlp_in)
|
| 252 |
+
mlp_out = self.mlp_norm2(mlp_out) if self.use_norms else mlp_out
|
| 253 |
+
tokens = tokens + mlp_out
|
| 254 |
+
if was_spatial:
|
| 255 |
+
return _restore_spatial(tokens, hw_tokens)
|
| 256 |
+
return tokens
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
__all__ = ["DitBlock", "TransformerBlock"]
|
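A detail worth noting in the two blocks above: `TransformerBlock` sizes its MLP with `int(round(mlp_ratio * d_model))`, while `DitBlock` truncates with `int(mlp_ratio * d_model)`. For integer ratios the two agree, but for fractional ratios they can differ by one hidden unit. The following is a minimal standalone sketch of that arithmetic; the helper names are hypothetical and are not part of the export.

```python
def hidden_budget_rounded(mlp_ratio: float, d_model: int) -> int:
    """Mirror the TransformerBlock expression: round to the nearest integer."""
    return int(round(float(mlp_ratio) * d_model))


def hidden_budget_truncated(mlp_ratio: float, d_model: int) -> int:
    """Mirror the DitBlock expression: truncate toward zero."""
    return int(float(mlp_ratio) * d_model)


# Integer ratios give the same width under both expressions.
assert hidden_budget_rounded(4.0, 768) == hidden_budget_truncated(4.0, 768) == 3072

# A fractional ratio can differ by one hidden unit.
print(hidden_budget_rounded(2.7, 768), hidden_budget_truncated(2.7, 768))
```

If a checkpoint was trained with one expression and loaded with the other, a fractional `mlp_ratio` could therefore produce a shape mismatch in the MLP weights; with integer ratios the distinction does not matter.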
dit/body_config.py
ADDED
|
@@ -0,0 +1,33 @@
|
| 1 |
+
"""Small DiT configuration enums required by the DINAC-AE export."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from enum import Enum, auto
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class DiTConditioning(Enum):
|
| 10 |
+
"""Conditioning strategy for exported DiT blocks."""
|
| 11 |
+
|
| 12 |
+
ADALN = auto()
|
| 13 |
+
GATED_UNCOND = auto()
|
| 14 |
+
UNCOND = auto()
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class AdaLNSharingMode(Enum):
|
| 18 |
+
"""AdaLN sharing modes retained for auxiliary config imports."""
|
| 19 |
+
|
| 20 |
+
PER_BLOCK = auto()
|
| 21 |
+
SHARED_BASE_LOW_RANK_DELTA = auto()
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass
|
| 25 |
+
class DiTBodyConfig:
|
| 26 |
+
"""Minimal body config placeholder for unused auxiliary heads."""
|
| 27 |
+
|
| 28 |
+
depth: int = 1
|
| 29 |
+
d_model: int = 768
|
| 30 |
+
n_heads: int = 12
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
__all__ = ["AdaLNSharingMode", "DiTBodyConfig", "DiTConditioning"]
|
dit/mlp.py
ADDED
|
@@ -0,0 +1,117 @@
|
| 1 |
+
"""Small MLP factory for DINAC-AE DiT blocks."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from collections.abc import Callable
|
| 6 |
+
from typing import Protocol, cast
|
| 7 |
+
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from torch import Tensor, nn
|
| 10 |
+
|
| 11 |
+
from dit.mlp_types import MLPType
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Resettable(Protocol):
|
| 15 |
+
"""Typing protocol for modules with ``reset_parameters``."""
|
| 16 |
+
|
| 17 |
+
def reset_parameters(self) -> None:
|
| 18 |
+
"""Reset module parameters."""
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def reset_module_parameters(module: nn.Module) -> None:
|
| 22 |
+
"""Reset a module that exposes ``reset_parameters``."""
|
| 23 |
+
|
| 24 |
+
cast(Resettable, module).reset_parameters()
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class SimpleActivationMLP(nn.Module):
|
| 28 |
+
"""Feedforward MLP: ``down(activation(up(x)))``."""
|
| 29 |
+
|
| 30 |
+
in_features: int
|
| 31 |
+
hidden_features: int
|
| 32 |
+
activation: Callable[[Tensor], Tensor]
|
| 33 |
+
activation_name: str
|
| 34 |
+
up: nn.Linear
|
| 35 |
+
down: nn.Linear
|
| 36 |
+
|
| 37 |
+
def __init__(
|
| 38 |
+
self,
|
| 39 |
+
in_features: int,
|
| 40 |
+
hidden_features: int,
|
| 41 |
+
*,
|
| 42 |
+
activation: Callable[[Tensor], Tensor],
|
| 43 |
+
activation_name: str,
|
| 44 |
+
bias_up: bool,
|
| 45 |
+
bias_down: bool,
|
| 46 |
+
) -> None:
|
| 47 |
+
super().__init__()
|
| 48 |
+
self.in_features = int(in_features)
|
| 49 |
+
self.hidden_features = int(hidden_features)
|
| 50 |
+
self.activation = activation
|
| 51 |
+
self.activation_name = str(activation_name)
|
| 52 |
+
self.up = nn.Linear(self.in_features, self.hidden_features, bias=bias_up)
|
| 53 |
+
self.down = nn.Linear(self.hidden_features, self.in_features, bias=bias_down)
|
| 54 |
+
self.reset_parameters()
|
| 55 |
+
|
| 56 |
+
def reset_parameters(self) -> None:
|
| 57 |
+
"""Reset linear projections."""
|
| 58 |
+
|
| 59 |
+
nn.init.xavier_uniform_(self.up.weight)
|
| 60 |
+
if self.up.bias is not None:
|
| 61 |
+
nn.init.zeros_(self.up.bias)
|
| 62 |
+
nn.init.xavier_uniform_(self.down.weight)
|
| 63 |
+
if self.down.bias is not None:
|
| 64 |
+
nn.init.zeros_(self.down.bias)
|
| 65 |
+
|
| 66 |
+
def forward(self, x: Tensor) -> Tensor: # type: ignore[override]
|
| 67 |
+
"""Apply the MLP."""
|
| 68 |
+
|
| 69 |
+
return self.down(self.activation(self.up(x)))
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def build_dit_mlp(
|
| 73 |
+
*,
|
| 74 |
+
mlp_type: MLPType,
|
| 75 |
+
in_features: int,
|
| 76 |
+
hidden_budget: int,
|
| 77 |
+
activation_config: object | None = None,
|
| 78 |
+
block_index: int = 0,
|
| 79 |
+
bias_up: bool = False,
|
| 80 |
+
bias_down: bool = False,
|
| 81 |
+
) -> nn.Module:
|
| 82 |
+
"""Build the exported MLP variant (GELU, SILU, or RELU); ``activation_config`` and ``block_index`` are ignored."""
|
| 83 |
+
|
| 84 |
+
_ = activation_config, block_index
|
| 85 |
+
match mlp_type:
|
| 86 |
+
case MLPType.GELU:
|
| 87 |
+
return SimpleActivationMLP(
|
| 88 |
+
in_features=int(in_features),
|
| 89 |
+
hidden_features=int(hidden_budget),
|
| 90 |
+
activation=F.gelu,
|
| 91 |
+
activation_name="gelu",
|
| 92 |
+
bias_up=bool(bias_up),
|
| 93 |
+
bias_down=bool(bias_down),
|
| 94 |
+
)
|
| 95 |
+
case MLPType.SILU:
|
| 96 |
+
return SimpleActivationMLP(
|
| 97 |
+
in_features=int(in_features),
|
| 98 |
+
hidden_features=int(hidden_budget),
|
| 99 |
+
activation=F.silu,
|
| 100 |
+
activation_name="silu",
|
| 101 |
+
bias_up=bool(bias_up),
|
| 102 |
+
bias_down=bool(bias_down),
|
| 103 |
+
)
|
| 104 |
+
case MLPType.RELU:
|
| 105 |
+
return SimpleActivationMLP(
|
| 106 |
+
in_features=int(in_features),
|
| 107 |
+
hidden_features=int(hidden_budget),
|
| 108 |
+
activation=F.relu,
|
| 109 |
+
activation_name="relu",
|
| 110 |
+
bias_up=bool(bias_up),
|
| 111 |
+
bias_down=bool(bias_down),
|
| 112 |
+
)
|
| 113 |
+
case _ as unreachable:
|
| 114 |
+
raise ValueError(f"Unsupported exported MLP type: {unreachable}")
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
__all__ = ["SimpleActivationMLP", "build_dit_mlp", "reset_module_parameters"]
|
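To make the `down(activation(up(x)))` structure in `dit/mlp.py` concrete, here is a dependency-free sketch on plain Python lists. It is an illustration under stated assumptions, not the exported implementation: the names `ACTIVATIONS` and `simple_mlp` are hypothetical, the weights use the `(out_features, in_features)` row layout of `nn.Linear`, and biases are omitted to match `bias_up=False` / `bias_down=False`. The GELU here is the exact erf-based form, which is the default of `torch.nn.functional.gelu`.

```python
import math


def gelu(x: float) -> float:
    """Exact (erf-based) GELU."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def silu(x: float) -> float:
    """SiLU: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))


def relu(x: float) -> float:
    """ReLU: max(0, x)."""
    return max(0.0, x)


# Hypothetical lookup table keyed by the MLPType string values.
ACTIVATIONS = {"gelu": gelu, "silu": silu, "relu": relu}


def simple_mlp(x, up, down, activation_name):
    """Bias-free down(activation(up(x))) for a single input vector.

    up has one row per hidden unit; down has one row per output unit.
    """
    act = ACTIVATIONS[activation_name]
    hidden = [act(sum(w * xi for w, xi in zip(row, x))) for row in up]
    return [sum(w * h for w, h in zip(row, hidden)) for row in down]


# Identity up-projection, sum down-projection: ReLU zeroes the negative entry.
print(simple_mlp([1.0, -2.0], up=[[1.0, 0.0], [0.0, 1.0]], down=[[1.0, 1.0]], activation_name="relu"))
```

A table lookup like this is one alternative to the three near-identical `match` arms in `build_dit_mlp`; the explicit arms in the source keep the unsupported-type error path obvious.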
dit/mlp_types.py
ADDED
|
@@ -0,0 +1,51 @@
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from enum import Enum
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class MLPType(Enum):
|
| 7 |
+
"""MLP implementation variants for DiT blocks (this export's ``build_dit_mlp`` builds only RELU, SILU, and GELU).
|
| 8 |
+
|
| 9 |
+
- SWI: baseline SwiGLU MLP
|
| 10 |
+
- SWINE: Sigmoid-gated sine GLU (σ·sin) with trig promoted to float32
|
| 11 |
+
- SWINER: Sigmoid-gated FINER-style chirp (σ·sin(ω₀·((1+|x|)·x)))
|
| 12 |
+
- SPWIDER: sqrt-gated sine GLU (√|a|·sin(ω₀·b))
|
| 13 |
+
- RELU: Plain ReLU-activated feedforward
|
| 14 |
+
- RELU2: ReLU-squared activation (ReLU(x)^2) feedforward
|
| 15 |
+
- SILU: Plain SiLU-activated feedforward
|
| 16 |
+
- GELU: Plain GELU-activated feedforward
|
| 17 |
+
- SIREN: Pure sine-activated MLP
|
| 18 |
+
- SPIDER: Sine with sqrt magnitude (sin(ω₀·x)·√|x|)
|
| 19 |
+
- SINC: Sinc-activated MLP with log-spaced per-channel scales
|
| 20 |
+
- FINER: FINER activation MLP with a fixed global scale (non-learnable)
|
| 21 |
+
- RBF: Low-rank per-patch RBF with Gaussian kernel
|
| 22 |
+
- RBF_ODD: RBF with odd-Gaussian kernel (z·exp(-z^2))
|
| 23 |
+
- RBF_SHARP: RBF with sharpness exponent alpha (exp(-(s·|x-b|)^alpha))
|
| 24 |
+
- RBF_SIREN: RBF using sine basis sin(ω₀·(s·(x-b)))
|
| 25 |
+
- RBF_FINER: RBF using FINER (chirp) basis sin(ω₀·((1+|z|)·z)), z=s·(x-b)
|
| 26 |
+
- RBF_DAMPED_SINE: RBF using damped sine sin(ω₀·z)·exp(-|z|), z=s·(x-b)
|
| 27 |
+
- RBF_SINC: RBF using sinc basis sinc(z)=sin(z)/z with z=s·(x-b)
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
SWI = "swi"
|
| 31 |
+
SWINE = "swine"
|
| 32 |
+
SWINER = "swiner"
|
| 33 |
+
SPWIDER = "spwider"
|
| 34 |
+
RELU = "relu"
|
| 35 |
+
RELU2 = "relu2"
|
| 36 |
+
SILU = "silu"
|
| 37 |
+
GELU = "gelu"
|
| 38 |
+
SIREN = "siren"
|
| 39 |
+
SPIDER = "spider"
|
| 40 |
+
SINC = "sinc"
|
| 41 |
+
FINER = "finer"
|
| 42 |
+
RBF = "rbf"
|
| 43 |
+
RBF_ODD = "rbf_odd"
|
| 44 |
+
RBF_SHARP = "rbf_sharp"
|
| 45 |
+
RBF_SIREN = "rbf_siren"
|
| 46 |
+
RBF_FINER = "rbf_finer"
|
| 47 |
+
RBF_DAMPED_SINE = "rbf_damped_sine"
|
| 48 |
+
RBF_SINC = "rbf_sinc"
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
__all__ = ["MLPType"]
|
dit/position_encoding.py
ADDED
|
@@ -0,0 +1,23 @@
|
| 1 |
+
"""Position encoding options used by exported dense DiT blocks."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from enum import Enum
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class DiTPositionEncoding(Enum):
|
| 9 |
+
"""Position encoding strategy used inside exported DiT blocks."""
|
| 10 |
+
|
| 11 |
+
ROPE_2D_AXIAL_DILATED = "rope_2d_axial_dilated"
|
| 12 |
+
ROPE_2D_AXIAL_UNNORMALIZED_DILATED = "rope_2d_axial_unnormalized_dilated"
|
| 13 |
+
ROPE_2D_AXIAL_NORMALIZED = "rope_2d_axial_normalized"
|
| 14 |
+
ROPE_2D_AXIAL_UNNORMALIZED = "rope_2d_axial_unnormalized"
|
| 15 |
+
ROPE_2D_AXIAL_FREQ_AWARE = "rope_2d_axial_freq_aware"
|
| 16 |
+
ROPE_2D_AXIAL_BETA_WARP = "rope_2d_axial_beta_warp"
|
| 17 |
+
ROPE_2D_AXIAL_ALPHA_WARP = "rope_2d_axial_alpha_warp"
|
| 18 |
+
ROPE_3D_ZIMAGE = "rope_3d_zimage"
|
| 19 |
+
ROPE_1D = "rope_1d"
|
| 20 |
+
NONE = "none"
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
__all__ = ["DiTPositionEncoding"]
|
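Because `DiTPositionEncoding` is a value-backed `Enum`, a string from a config file can be turned into a member by calling the enum with that value. The sketch below assumes that usage; it defines an abbreviated, hypothetical stand-in enum (`PositionEncoding`) and a hypothetical helper (`parse_position_encoding`) so that it runs on its own, and neither name exists in the export.

```python
from enum import Enum


class PositionEncoding(Enum):
    """Abbreviated, hypothetical stand-in for DiTPositionEncoding."""

    ROPE_2D_AXIAL_UNNORMALIZED = "rope_2d_axial_unnormalized"
    ROPE_1D = "rope_1d"
    NONE = "none"


def parse_position_encoding(name: str) -> PositionEncoding:
    """Map a config string to an enum member, with a readable error."""
    try:
        return PositionEncoding(name)
    except ValueError:
        valid = ", ".join(member.value for member in PositionEncoding)
        raise ValueError(
            f"Unknown position encoding {name!r}; expected one of: {valid}"
        ) from None


print(parse_position_encoding("rope_1d"))
```

Lookup by value (`PositionEncoding("rope_1d")`) is distinct from lookup by name (`PositionEncoding["ROPE_1D"]`); the string values in the source are lowercase, so value lookup is the natural fit for serialized configs.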
dit/repa_projection.py
ADDED
|
@@ -0,0 +1,226 @@
|
| 1 |
+
"""DINO token/class alignment head used by the DINAC-AE export."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from typing import TYPE_CHECKING
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch import Tensor, nn
|
| 10 |
+
|
| 11 |
+
from common.norms import RMSNorm
|
| 12 |
+
from dit.axial_rope2d import (
|
| 13 |
+
AxialRoPE2D,
|
| 14 |
+
AxialRoPE2DConfig,
|
| 15 |
+
AxialRoPE2DCoordMode,
|
| 16 |
+
AxialRoPE2DDimLayout,
|
| 17 |
+
AxialRoPE2DNormalizeCoords,
|
| 18 |
+
)
|
| 19 |
+
from dit.blocks import DitBlock
|
| 20 |
+
from dit.body_config import DiTConditioning
|
| 21 |
+
from dit.position_encoding import DiTPositionEncoding
|
| 22 |
+
from dit.xattn_blocks import CrossAttentionBlock, CrossAttentionConfig
|
| 23 |
+
|
| 24 |
+
if TYPE_CHECKING:
|
| 25 |
+
from dit.mlp_types import MLPType
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass(frozen=True)
|
| 29 |
+
class DinoTokenAlignmentOutput:
|
| 30 |
+
"""Predicted DINO class token and spatial patch tokens."""
|
| 31 |
+
|
| 32 |
+
class_token: Tensor
|
| 33 |
+
spatial_tokens: Tensor
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _prepend_identity_rope_prefix(
|
| 37 |
+
*,
|
| 38 |
+
rope_sincos: tuple[Tensor, Tensor],
|
| 39 |
+
prefix_token_count: int,
|
| 40 |
+
device: torch.device,
|
| 41 |
+
) -> tuple[Tensor, Tensor]:
|
| 42 |
+
"""Prepend no-op RoPE entries for class/register prefix tokens."""
|
| 43 |
+
|
| 44 |
+
sin, cos = rope_sincos
|
| 45 |
+
prefix_shape = (int(prefix_token_count), int(sin.shape[-1]))
|
| 46 |
+
prefix_sin = torch.zeros(prefix_shape, device=device, dtype=sin.dtype)
|
| 47 |
+
prefix_cos = torch.ones(prefix_shape, device=device, dtype=cos.dtype)
|
| 48 |
+
match sin.dim():
|
| 49 |
+
case 2:
|
| 50 |
+
return (
|
| 51 |
+
torch.cat([prefix_sin, sin.to(device=device)], dim=0),
|
| 52 |
+
torch.cat([prefix_cos, cos.to(device=device)], dim=0),
|
| 53 |
+
)
|
| 54 |
+
case 3:
|
| 55 |
+
batch = int(sin.shape[0])
|
| 56 |
+
return (
|
| 57 |
+
torch.cat(
|
| 58 |
+
[
|
| 59 |
+
prefix_sin.unsqueeze(0).expand(batch, -1, -1),
|
| 60 |
+
sin.to(device=device),
|
| 61 |
+
],
|
| 62 |
+
dim=1,
|
| 63 |
+
),
|
| 64 |
+
torch.cat(
|
| 65 |
+
[
|
| 66 |
+
prefix_cos.unsqueeze(0).expand(batch, -1, -1),
|
| 67 |
+
cos.to(device=device),
|
| 68 |
+
],
|
| 69 |
+
dim=1,
|
| 70 |
+
),
|
| 71 |
+
)
|
| 72 |
+
case _ as unreachable:
|
| 73 |
+
raise ValueError(f"Unsupported RoPE tensor rank: {int(unreachable)}")
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class DinoTokenAlignmentHead(nn.Module):
|
| 77 |
+
"""Predict DINOv3 spatial tokens and a class token from latent grids."""
|
| 78 |
+
|
| 79 |
+
in_channels: int
|
| 80 |
+
feature_dim: int
|
| 81 |
+
model_dim: int
|
| 82 |
+
register_token_count: int
|
| 83 |
+
in_proj: nn.Conv2d
|
| 84 |
+
initial_class_token: nn.Parameter
|
| 85 |
+
register_tokens: nn.Parameter
|
| 86 |
+
block: DitBlock
|
| 87 |
+
spatial_output_norm: RMSNorm
|
| 88 |
+
class_readout: CrossAttentionBlock
|
| 89 |
+
class_output_norm: RMSNorm
|
| 90 |
+
_axial_rope2d: AxialRoPE2D
|
| 91 |
+
|
| 92 |
+
def __init__(
|
| 93 |
+
self,
|
| 94 |
+
*,
|
| 95 |
+
in_channels: int,
|
| 96 |
+
feature_dim: int,
|
| 97 |
+
model_dim: int,
|
| 98 |
+
head_dim: int,
|
| 99 |
+
mlp_ratio: float,
|
| 100 |
+
mlp_activation: MLPType,
|
| 101 |
+
block_index: int,
|
| 102 |
+
register_token_count: int,
|
| 103 |
+
) -> None:
|
| 104 |
+
super().__init__()
|
| 105 |
+
if int(feature_dim) != int(model_dim):
|
| 106 |
+
raise ValueError("DINAC-AE class head requires feature_dim == model_dim")
|
| 107 |
+
if int(register_token_count) != 4:
|
| 108 |
+
raise ValueError("DINAC-AE class head requires four register tokens")
|
| 109 |
+
self.register_token_count = int(register_token_count)
|
| 110 |
+
self.in_channels = int(in_channels)
|
| 111 |
+
self.feature_dim = int(feature_dim)
|
| 112 |
+
self.model_dim = int(model_dim)
|
| 113 |
+
self.in_proj = nn.Conv2d(
|
| 114 |
+
self.in_channels,
|
| 115 |
+
self.model_dim,
|
| 116 |
+
kernel_size=1,
|
| 117 |
+
padding=0,
|
| 118 |
+
stride=1,
|
| 119 |
+
bias=True,
|
| 120 |
+
)
|
| 121 |
+
self.initial_class_token = nn.Parameter(torch.empty((1, self.model_dim)))
|
| 122 |
+
self.register_tokens = nn.Parameter(
|
| 123 |
+
torch.empty((self.register_token_count, self.model_dim))
|
| 124 |
+
)
|
| 125 |
+
nn.init.normal_(self.initial_class_token, mean=0.0, std=0.02)
|
| 126 |
+
nn.init.normal_(self.register_tokens, mean=0.0, std=0.02)
|
| 127 |
+
conditioning = DiTConditioning.UNCOND
|
| 128 |
+
self.block = DitBlock(
|
| 129 |
+
d_model=self.model_dim,
|
| 130 |
+
n_heads=int(self.model_dim // int(head_dim)),
|
| 131 |
+
mlp_ratio=float(mlp_ratio),
|
| 132 |
+
mlp_type=mlp_activation,
|
| 133 |
+
block_index=int(block_index),
|
| 134 |
+
use_norms=True,
|
| 135 |
+
position_encoding=DiTPositionEncoding.ROPE_2D_AXIAL_UNNORMALIZED,
|
| 136 |
+
conditioning=conditioning,
|
| 137 |
+
)
|
| 138 |
+
self.spatial_output_norm = RMSNorm(self.model_dim, affine=False)
|
| 139 |
+
self.class_readout = CrossAttentionBlock(
|
| 140 |
+
query_dim=self.model_dim,
|
| 141 |
+
context_dim=self.model_dim,
|
| 142 |
+
cfg=CrossAttentionConfig(
|
| 143 |
+
n_heads=int(self.model_dim // int(head_dim)),
|
| 144 |
+
head_dim=int(head_dim),
|
| 145 |
+
query_extra_dim=0,
|
| 146 |
+
context_extra_dim=0,
|
| 147 |
+
mlp_ratio=float(mlp_ratio),
|
| 148 |
+
attn_dropout=0.0,
|
| 149 |
+
mlp_type=mlp_activation,
|
| 150 |
+
activation_config=None,
|
| 151 |
+
use_norms=True,
|
| 152 |
+
block_index=int(block_index) + 1,
|
| 153 |
+
use_attn_residual=True,
|
| 154 |
+
),
|
| 155 |
+
)
|
| 156 |
+
self.class_output_norm = RMSNorm(self.model_dim, affine=False)
|
| 157 |
+
self._axial_rope2d = AxialRoPE2D(
|
| 158 |
+
head_dim=int(head_dim),
|
| 159 |
+
cfg=AxialRoPE2DConfig(
|
| 160 |
+
base=10_000.0,
|
| 161 |
+
min_period=None,
|
| 162 |
+
max_period=None,
|
| 163 |
+
coord_mode=AxialRoPE2DCoordMode.PATCH_INDICES,
|
| 164 |
+
normalize_coords=AxialRoPE2DNormalizeCoords.MAX,
|
| 165 |
+
dim_layout=AxialRoPE2DDimLayout.PAIR_INTERLEAVED,
|
| 166 |
+
angle_multiplier=1.0,
|
| 167 |
+
coord_offset=0.0,
|
| 168 |
+
frequency_aware=None,
|
| 169 |
+
beta_warp=None,
|
| 170 |
+
alpha_warp=None,
|
| 171 |
+
),
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
def compile_for_training(self, *, fullgraph: bool, dynamic: bool) -> None:
|
| 175 |
+
"""No-op hook kept for source API compatibility."""
|
| 176 |
+
|
| 177 |
+
_ = fullgraph, dynamic
|
| 178 |
+
|
| 179 |
+
def compile_for_eval(self, *, fullgraph: bool, dynamic: bool) -> None:
|
| 180 |
+
"""No-op hook kept for source API compatibility."""
|
| 181 |
+
|
| 182 |
+
_ = fullgraph, dynamic
|
| 183 |
+
|
| 184 |
+
def forward(self, latents: Tensor, *, t: Tensor) -> DinoTokenAlignmentOutput:
|
| 185 |
+
"""Return predicted class and spatial DINO tokens; ``t`` is accepted for API compatibility and ignored."""
|
| 186 |
+
|
| 187 |
+
y = self.in_proj(latents)
|
| 188 |
+
batch, _channels, height, width = y.shape
|
| 189 |
+
spatial_tokens = y.flatten(2).transpose(1, 2)
|
| 190 |
+
class_token = self.initial_class_token.to(device=y.device, dtype=y.dtype)
|
| 191 |
+
class_token = class_token.unsqueeze(0).expand(int(batch), -1, -1)
|
| 192 |
+
register_tokens = self.register_tokens.to(device=y.device, dtype=y.dtype)
|
| 193 |
+
register_tokens = register_tokens.unsqueeze(0).expand(int(batch), -1, -1)
|
| 194 |
+
tokens = torch.cat([class_token, register_tokens, spatial_tokens], dim=1)
|
| 195 |
+
rope_sincos = _prepend_identity_rope_prefix(
|
| 196 |
+
rope_sincos=self._axial_rope2d(H=int(height), W=int(width), scales=None),
|
| 197 |
+
prefix_token_count=int(1 + self.register_token_count),
|
| 198 |
+
device=y.device,
|
| 199 |
+
)
|
| 200 |
+
_ = t
|
| 201 |
+
cond = torch.zeros(
|
| 202 |
+
(int(batch), self.model_dim),
|
| 203 |
+
device=y.device,
|
| 204 |
+
dtype=y.dtype,
|
| 205 |
+
)
|
| 206 |
+
tokens = self.block(
|
| 207 |
+
tokens,
|
| 208 |
+
hw=(int(height), int(width)),
|
| 209 |
+
cond_vec=cond,
|
| 210 |
+
adaln_m=None,
|
| 211 |
+
rope_sincos=rope_sincos,
|
| 212 |
+
generator=None,
|
| 213 |
+
)
|
| 214 |
+
class_query = tokens[:, :1, :]
|
| 215 |
+
context = tokens[:, 1:, :]
|
| 216 |
+
class_output = self.class_readout(class_query, context)[:, 0, :]
|
| 217 |
+
class_output = self.class_output_norm(class_output)
|
| 218 |
+
prefix_token_count = int(1 + self.register_token_count)
|
| 219 |
+
predicted_spatial = self.spatial_output_norm(tokens[:, prefix_token_count:, :])
|
| 220 |
+
return DinoTokenAlignmentOutput(
|
| 221 |
+
class_token=class_output,
|
| 222 |
+
spatial_tokens=predicted_spatial,
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
__all__ = ["DinoTokenAlignmentHead", "DinoTokenAlignmentOutput"]
|
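The head in `dit/repa_projection.py` builds its sequence as one class token, four register tokens, and `H * W` spatial tokens, and `_prepend_identity_rope_prefix` pads the RoPE tables with `sin = 0`, `cos = 1` for the prefix. The sketch below illustrates why those values act as a no-op, assuming the standard pairwise RoPE rotation; the exact convention inside `DitSelfAttentionCore` is not shown in this section, and `token_layout` and `rotate_pair` are hypothetical helpers.

```python
def token_layout(height: int, width: int, register_token_count: int = 4) -> dict[str, int]:
    """Count prefix (class + register) and spatial tokens for a latent grid."""
    prefix = 1 + register_token_count
    spatial = height * width
    return {"prefix": prefix, "spatial": spatial, "total": prefix + spatial}


def rotate_pair(x1: float, x2: float, sin: float, cos: float) -> tuple[float, float]:
    """Standard 2D rotation applied by RoPE to one channel pair (assumed convention)."""
    return (x1 * cos - x2 * sin, x1 * sin + x2 * cos)


# A 16x16 latent grid yields 5 prefix tokens and 256 spatial tokens.
print(token_layout(16, 16))

# With sin=0 and cos=1 the rotation is the identity, so prefix tokens are unrotated.
print(rotate_pair(0.3, -1.2, sin=0.0, cos=1.0))
```

This is also why `forward` slices `tokens[:, prefix_token_count:, :]` for the spatial output: the first `1 + register_token_count` positions carry no spatial position.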
dit/xattn_blocks.py
ADDED
|
@@ -0,0 +1,177 @@
|
| 1 |
+
"""Dense cross-attention block used by the DINAC-AE class-token head."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
|
| 7 |
+
from torch import Tensor, nn
|
| 8 |
+
|
| 9 |
+
from common.norms import RMSNorm
|
| 10 |
+
from dit.attention_blocks import CrossAttentionCore
|
| 11 |
+
from dit.mlp import build_dit_mlp, reset_module_parameters
|
| 12 |
+
from dit.mlp_types import MLPType
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
|
| 16 |
+
class CrossAttentionConfig:
|
| 17 |
+
"""Configuration for the exported dense cross-attention block."""
|
| 18 |
+
|
| 19 |
+
n_heads: int = 16
|
| 20 |
+
head_dim: int | None = None
|
| 21 |
+
query_extra_dim: int = 0
|
| 22 |
+
context_extra_dim: int = 0
|
| 23 |
+
key_extra_dim: int = 0
|
| 24 |
+
mlp_ratio: float = 2.0
|
| 25 |
+
attn_dropout: float = 0.0
|
| 26 |
+
mlp_type: MLPType = MLPType.GELU
|
| 27 |
+
activation_config: object | None = None
|
| 28 |
+
use_norms: bool = True
|
| 29 |
+
block_index: int = 0
|
| 30 |
+
use_attn_residual: bool = True
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class CrossAttentionBlock(nn.Module):
|
| 34 |
+
"""Dense pre-norm cross-attention plus residual MLP."""
|
| 35 |
+
|
| 36 |
+
query_dim: int
|
| 37 |
+
context_dim: int
|
| 38 |
+
query_extra_dim: int
|
| 39 |
+
context_extra_dim: int
|
| 40 |
+
key_extra_dim: int
|
| 41 |
+
n_heads: int
|
| 42 |
+
head_dim: int
|
| 43 |
+
attn_dim: int
|
| 44 |
+
use_norms: bool
|
| 45 |
+
attn_dropout: float
|
| 46 |
+
use_attn_residual: bool
|
| 47 |
+
query_norm: RMSNorm | None
|
| 48 |
+
context_norm: RMSNorm | None
|
| 49 |
+
mlp_norm: RMSNorm | None
|
| 50 |
+
q_proj: nn.Linear
|
| 51 |
+
attn_core: CrossAttentionCore
|
| 52 |
+
kv_proj: nn.Linear
|
| 53 |
+
out_proj: nn.Linear
|
| 54 |
+
mlp: nn.Module
|
| 55 |
+
|
| 56 |
+
def __init__(
|
| 57 |
+
self,
|
| 58 |
+
*,
|
| 59 |
+
query_dim: int,
|
| 60 |
+
context_dim: int,
|
| 61 |
+
cfg: CrossAttentionConfig,
|
| 62 |
+
) -> None:
|
| 63 |
+
super().__init__()
|
| 64 |
+
n_heads = int(cfg.n_heads)
|
| 65 |
+
if cfg.head_dim is None:
|
| 66 |
+
if query_dim % n_heads != 0:
|
| 67 |
+
raise ValueError("query_dim must be divisible by n_heads")
|
| 68 |
+
head_dim = query_dim // n_heads
|
| 69 |
+
else:
|
| 70 |
+
head_dim = int(cfg.head_dim)
|
| 71 |
+
self.query_dim = int(query_dim)
|
| 72 |
+
self.context_dim = int(context_dim)
|
| 73 |
+
self.query_extra_dim = int(cfg.query_extra_dim)
|
| 74 |
+
self.context_extra_dim = int(cfg.context_extra_dim)
|
| 75 |
+
self.key_extra_dim = int(cfg.key_extra_dim)
|
| 76 |
+
self.n_heads = n_heads
|
| 77 |
+
self.head_dim = int(head_dim)
|
| 78 |
+
self.attn_dim = int(self.n_heads * self.head_dim)
|
| 79 |
+
self.use_norms = bool(cfg.use_norms)
|
| 80 |
+
self.attn_dropout = float(cfg.attn_dropout)
|
| 81 |
+
if not cfg.use_attn_residual:
|
| 82 |
+
raise ValueError("DINAC-AE export requires attention residuals")
|
| 83 |
+
self.use_attn_residual = True
|
| 84 |
+
self.query_norm = RMSNorm(self.query_dim) if self.use_norms else None
|
| 85 |
+
self.context_norm = RMSNorm(self.context_dim) if self.use_norms else None
|
| 86 |
+
self.mlp_norm = RMSNorm(self.query_dim) if self.use_norms else None
|
| 87 |
+
self.q_proj = nn.Linear(
|
| 88 |
+
self.query_dim + self.query_extra_dim, self.attn_dim, bias=False
|
| 89 |
+
)
|
| 90 |
+
self.attn_core = CrossAttentionCore(
|
| 91 |
+
query_dim=query_dim,
|
| 92 |
+
context_dim=context_dim,
|
| 93 |
+
context_extra_dim=self.context_extra_dim,
|
| 94 |
+
key_extra_dim=self.key_extra_dim,
|
| 95 |
+
n_heads=self.n_heads,
|
| 96 |
+
head_dim=self.head_dim,
|
| 97 |
+
attn_dropout=self.attn_dropout,
|
| 98 |
+
)
|
| 99 |
+
self.kv_proj = self.attn_core.kv_proj
|
| 100 |
+
self.out_proj = self.attn_core.out_proj
|
| 101 |
+
hidden = int(round(cfg.mlp_ratio * query_dim))
|
| 102 |
+
self.mlp = build_dit_mlp(
|
| 103 |
+
mlp_type=cfg.mlp_type,
|
| 104 |
+
in_features=query_dim,
|
| 105 |
+
hidden_budget=hidden,
|
| 106 |
+
activation_config=cfg.activation_config,
|
| 107 |
+
block_index=int(cfg.block_index),
|
| 108 |
+
bias_up=False,
|
| 109 |
+
bias_down=False,
|
| 110 |
+
)
|
| 111 |
+
self.reset_parameters()
|
| 112 |
+
|
| 113 |
+
def reset_parameters(self) -> None:
|
| 114 |
+
"""Reset projections and MLP parameters."""
|
| 115 |
+
|
| 116 |
+
nn.init.xavier_uniform_(self.q_proj.weight)
|
| 117 |
+
self.attn_core.reset_parameters()
|
| 118 |
+
reset_module_parameters(self.mlp)
|
| 119 |
+
|
| 120 |
+
def forward(
|
| 121 |
+
self,
|
| 122 |
+
query: Tensor,
|
| 123 |
+
context: Tensor,
|
| 124 |
+
*,
|
| 125 |
+
query_extra: Tensor | None = None,
|
| 126 |
+
context_extra: Tensor | None = None,
|
| 127 |
+
key_extra: Tensor | None = None,
|
| 128 |
+
key_padding_mask: Tensor | None = None,
|
| 129 |
+
) -> Tensor: # type: ignore[override]
|
| 130 |
+
"""Run dense cross-attention followed by the residual MLP."""
|
| 131 |
+
|
| 132 |
+
query_tokens = self.query_norm(query) if self.query_norm is not None else query
|
| 133 |
+
if query_extra is not None:
|
| 134 |
+
q_in = query_tokens.new_empty(
|
| 135 |
+
*query_tokens.shape[:-1],
|
| 136 |
+
int(query_tokens.shape[-1]) + int(query_extra.shape[-1]),
|
| 137 |
+
)
|
| 138 |
+
q_in[..., : int(query_tokens.shape[-1])] = query_tokens
|
| 139 |
+
q_in[..., int(query_tokens.shape[-1]) :] = query_extra
|
| 140 |
+
else:
|
| 141 |
+
q_in = query_tokens
|
| 142 |
+
context_tokens = (
|
| 143 |
+
self.context_norm(context) if self.context_norm is not None else context
|
| 144 |
+
)
|
| 145 |
+
if context_extra is not None:
|
| 146 |
+
kv_tokens = context_tokens.new_empty(
|
| 147 |
+
*context_tokens.shape[:-1],
|
| 148 |
+
int(context_tokens.shape[-1]) + int(context_extra.shape[-1]),
|
| 149 |
+
)
|
| 150 |
+
kv_tokens[..., : int(context_tokens.shape[-1])] = context_tokens
|
| 151 |
+
kv_tokens[..., int(context_tokens.shape[-1]) :] = context_extra
|
| 152 |
+
else:
|
| 153 |
+
kv_tokens = context_tokens
|
| 154 |
+
q_attn_tokens = self.q_proj(q_in)
|
| 155 |
+
attn_out = self.attn_core(
|
| 156 |
+
q_attn_tokens,
|
| 157 |
+
kv_tokens,
|
| 158 |
+
training=self.training,
|
| 159 |
+
key_extra=key_extra,
|
| 160 |
+
key_padding_mask=key_padding_mask,
|
| 161 |
+
)
|
| 162 |
+
tokens = query + attn_out
|
| 163 |
+
mlp_in = self.mlp_norm(tokens) if self.mlp_norm is not None else tokens
|
| 164 |
+
return tokens + self.mlp(mlp_in)
|
| 165 |
+
|
| 166 |
+
def compile_for_training(self, *, fullgraph: bool, dynamic: bool) -> None:
|
| 167 |
+
"""No-op hook kept for the token-alignment head API."""
|
| 168 |
+
|
| 169 |
+
_ = fullgraph, dynamic
|
| 170 |
+
|
| 171 |
+
def compile_for_eval(self, *, fullgraph: bool, dynamic: bool) -> None:
|
| 172 |
+
"""No-op hook kept for the token-alignment head API."""
|
| 173 |
+
|
| 174 |
+
_ = fullgraph, dynamic
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
__all__ = ["CrossAttentionBlock", "CrossAttentionConfig"]
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b610db6eb9ba995f14ddfbe3ca683044b3b7b4ebe2409fec9465c545a5ec88f7
|
| 3 |
+
size 633374472
|
technical_report_dinac_ae.md
ADDED
|
@@ -0,0 +1,390 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# DINAC-AE Technical Report
|
| 2 |
+
|
| 3 |
+
`dinac_ae` is a DINO-aligned class-token autoencoder in the SemDisDiffAE
|
| 4 |
+
family: patch-16 spatial latents, a VP diffusion decoder, and semantic
|
| 5 |
+
alignment to frozen vision features.
|
| 6 |
+
|
| 7 |
+
Relative to [SemDisDiffAE](https://huggingface.co/data-archetype/semdisdiffae),
|
| 8 |
+
DINAC-AE replaces the FCDM encoder with a 6-block ViT/DiT-style transformer
|
| 9 |
+
encoder, uses DINOv3 ViT-B/16 features, and extends the latent-to-DINO
|
| 10 |
+
alignment head with a class-token output. The decoder remains in the same FCDM
|
| 11 |
+
VP-diffusion family described in the SemDisDiffAE and capacitor reports.
|
| 12 |
+
|
| 13 |
+
Related reports:
|
| 14 |
+
|
| 15 |
+
- SemDisDiffAE report: https://huggingface.co/data-archetype/semdisdiffae/blob/main/technical_report_semantic.md
|
| 16 |
+
- full_capacitor report: https://huggingface.co/data-archetype/full_capacitor/blob/main/technical_report_full_capacitor.md
|
| 17 |
+
- capacitor_decoder report: https://huggingface.co/data-archetype/capacitor_decoder/blob/main/technical_report_capacitor_decoder.md
|
| 18 |
+
|
| 19 |
+
## 1. Motivation
|
| 20 |
+
|
| 21 |
+
We trained both FCDM-encoder and transformer-encoder variants. The transformer
|
| 22 |
+
encoder latents were easier for downstream DiT models to learn from, so this
|
| 23 |
+
release keeps the FCDM diffusion decoder and changes the encoder.
|
| 24 |
+
|
| 25 |
+
The second change is the class-token output. The DINO alignment head predicts
|
| 26 |
+
patch tokens as before, and is extended with a class-token prediction path.
|
| 27 |
+
`predict_class(latents)` exposes that feature directly from latents, enabling
|
| 28 |
+
Representation Frechet Distance / FD-loss style objectives without decoding to
|
| 29 |
+
RGB. Empirically, the class-token loss also makes the latents more semantically aligned.
|
| 30 |
+
|
| 31 |
+
## 2. Architecture Summary
|
| 32 |
+
|
| 33 |
+
| Component | SemDisDiffAE | dinac_ae |
|
| 34 |
+
| --- | ---: | ---: |
|
| 35 |
+
| Patch size | 16 | 16 |
|
| 36 |
+
| Latent channels | 128 | 128 |
|
| 37 |
+
| Encoder block family | FCDM | DiT transformer |
|
| 38 |
+
| Encoder width | 896 | 896 |
|
| 39 |
+
| Encoder depth | 4 | 6 |
|
| 40 |
+
| Decoder block family | FCDM | FCDM |
|
| 41 |
+
| Decoder width | 896 | 896 |
|
| 42 |
+
| Decoder depth | 8 | 8 |
|
| 43 |
+
| Decoder skip layout | start/middle/end skip concat | start/middle/end skip concat |
|
| 44 |
+
| DINO alignment | DINOv3 ViT-S/16 patch tokens | DINOv3 ViT-B/16 patch tokens + class token |
|
| 45 |
+
|
| 46 |
+
Parameter counts in the released checkpoint:
|
| 47 |
+
|
| 48 |
+
| Module | Parameters |
|
| 49 |
+
| --- | ---: |
|
| 50 |
+
| Encoder | 64,939,521 |
|
| 51 |
+
| Decoder | 68,133,505 |
|
| 52 |
+
| DINO alignment head | 14,264,320 |
|
| 53 |
+
| Total | 147,337,346 |
|
| 54 |
+
|
| 55 |
+
## 3. Encoder
|
| 56 |
+
|
| 57 |
+
The encoder is a single-scale transformer patch encoder. As in SemDisDiffAE,
|
| 58 |
+
all blocks operate at the final latent grid resolution.
|
| 59 |
+
|
| 60 |
+
### 3.1 Patch Embedding
|
| 61 |
+
|
| 62 |
+
The image is first converted to a patch grid with a stride-16 patch projection:
|
| 63 |
+
|
| 64 |
+
```text
|
| 65 |
+
image [B, 3, H, W]
|
| 66 |
+
-> stride-16 patch projection (3 x 16 x 16 -> 896) [B, 896, h, w]
|
| 67 |
+
```
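
The shape arithmetic above can be sketched in a few lines. This is only an illustration: the helper name `patch_grid_shape` is hypothetical, and the requirement that `H` and `W` be multiples of 16 is an assumption, not something the report states.

```python
PATCH_SIZE = 16      # stride of the patch projection (from the report)
ENCODER_WIDTH = 896  # embedding width (from the report)


def patch_grid_shape(batch: int, height: int, width: int) -> tuple:
    """Return the [B, 896, h, w] shape produced by a stride-16 patch projection."""
    if height % PATCH_SIZE or width % PATCH_SIZE:
        # Assumption: inputs are cropped or padded to multiples of 16 beforehand.
        raise ValueError("height and width must be multiples of 16")
    return (batch, ENCODER_WIDTH, height // PATCH_SIZE, width // PATCH_SIZE)


print(patch_grid_shape(2, 256, 384))  # (2, 896, 16, 24)
```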
|
| 68 |
+
|
| 69 |
+
### 3.2 DiT Encoder Block
|
| 70 |
+
|
| 71 |
+
The patch grid is processed by `6` unconditional `DitBlock` layers at width
|
| 72 |
+
`896`. Each block has `14` attention heads with head dimension `64`, an MLP
|
| 73 |
+
ratio of `4.0`, and GELU activations. The encoder uses our shared
|
| 74 |
+
AdaLN-capable `DitBlock` implementation with conditioning disabled. As a result
|
| 75 |
+
it keeps that block's RMSNorm sandwich structure: RMSNorm before and after the
|
| 76 |
+
attention branch, RMSNorm before and after the MLP branch, plus per-head RMSNorm
|
| 77 |
+
on Q and K:
|
| 78 |
+
|
| 79 |
+
```text
|
| 80 |
+
x
|
| 81 |
+
-> RMSNorm
|
| 82 |
+
-> biasless QKV projection
|
| 83 |
+
-> per-head RMSNorm on Q and K
|
| 84 |
+
-> axial 2D RoPE on Q and K
|
| 85 |
+
-> scaled dot-product attention
|
| 86 |
+
-> biasless output projection
|
| 87 |
+
-> RMSNorm
|
| 88 |
+
-> residual add
|
| 89 |
+
-> RMSNorm
|
| 90 |
+
-> biasless Linear(896 -> 3584)
|
| 91 |
+
-> GELU
|
| 92 |
+
-> biasless Linear(3584 -> 896)
|
| 93 |
+
-> RMSNorm
|
| 94 |
+
-> residual add
|
| 95 |
+
```
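
As a rough illustration of the per-head RMSNorm step, the sketch below splits one 896-wide token into 14 heads of 64 channels and normalizes each head to unit root-mean-square. The `eps` value and the omission of a learnable gain are assumptions; the function names are hypothetical and do not come from the released package.

```python
import math

N_HEADS = 14   # from the report
HEAD_DIM = 64  # from the report


def rms_norm(vec, eps=1e-6):
    """Scale vec to unit root-mean-square (eps assumed, learnable gain omitted)."""
    rms = math.sqrt(sum(v * v for v in vec) / len(vec) + eps)
    return [v / rms for v in vec]


def split_heads(token):
    """Split one 896-wide token into 14 heads of 64 channels each."""
    assert len(token) == N_HEADS * HEAD_DIM
    return [token[h * HEAD_DIM:(h + 1) * HEAD_DIM] for h in range(N_HEADS)]


def per_head_rms_norm(token):
    """Normalize each head independently, as described for Q and K."""
    return [rms_norm(head) for head in split_heads(token)]
```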
|
| 96 |
+
|
| 97 |
+
### 3.3 Posterior Projection
|
| 98 |
+
|
| 99 |
+
After the six transformer blocks, a pointwise projection maps from `896` to
|
| 100 |
+
`256` channels and splits into mean and logSNR:
|
| 101 |
+
|
| 102 |
+
```text
|
| 103 |
+
features [B, 896, h, w]
|
| 104 |
+
-> pointwise projection (896 -> 256)
|
| 105 |
+
-> split: mean [B, 128, h, w], logSNR [B, 128, h, w]
|
| 106 |
+
-> posterior mode: sqrt(sigmoid(logSNR)) * mean
|
| 107 |
+
```
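
A scalar sketch of the posterior mode formula follows. The mode `sqrt(sigmoid(logSNR)) * mean` is taken from the diagram above; the sampling rule `alpha * mean + sigma * eps` with `sigma = sqrt(sigmoid(-logSNR))` is an assumption based on the usual variance-preserving convention, and all function names are hypothetical.

```python
import math
import random


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def vp_coefficients(log_snr: float):
    """Return (alpha, sigma) with alpha**2 + sigma**2 == 1."""
    alpha = math.sqrt(sigmoid(log_snr))
    sigma = math.sqrt(sigmoid(-log_snr))
    return alpha, sigma


def posterior_mode(mean: float, log_snr: float) -> float:
    """Mode from the report: sqrt(sigmoid(logSNR)) * mean."""
    alpha, _ = vp_coefficients(log_snr)
    return alpha * mean


def posterior_sample(mean: float, log_snr: float, rng: random.Random) -> float:
    """Assumed VP sampling rule: alpha * mean + sigma * eps, eps ~ N(0, 1)."""
    alpha, sigma = vp_coefficients(log_snr)
    return alpha * mean + sigma * rng.gauss(0.0, 1.0)
```

Under this convention a high logSNR makes a sample collapse onto the mode, while a low logSNR pushes it toward unit-variance noise.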
|
| 108 |
+
|
| 109 |
+
The posterior is a VP-parameterized diagonal Gaussian. The mean branch is used
|
| 110 |
+
directly in the posterior mode computation.
|
| 111 |
+
|
| 112 |
+
The transformer encoder uses 2D axial RoPE over patch-index coordinates with
|
| 113 |
+
the following settings:
|
| 114 |
+
|
| 115 |
+
- Position encoding: axial 2D RoPE.
|
| 116 |
+
- Coordinate mode: patch indices.
|
| 117 |
+
- Coordinate normalization: max-coordinate normalization.
|
| 118 |
+
- Pair layout: interleaved pairs.
|
| 119 |
+
- RoPE base: 10000.
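
The settings above can be illustrated with a minimal pure-Python sketch of axial 2D RoPE on a single head vector. Several details are assumptions rather than facts from the report: the first half of the head dimension is rotated by the row index and the second half by the column index, frequencies follow `base ** (-2i / d)`, and coordinate normalization is left out because its exact form is not specified here. The actual implementation lives in `dit/axial_rope2d.py` and may differ.

```python
import math

ROPE_BASE = 10000.0  # from the report


def rope_1d(vec, pos, base=ROPE_BASE):
    """Rotate interleaved pairs (vec[2i], vec[2i+1]) by pos * base**(-2i/d)."""
    d = len(vec)
    out = [0.0] * d
    for i in range(d // 2):
        theta = pos * base ** (-2.0 * i / d)
        c, s = math.cos(theta), math.sin(theta)
        a, b = vec[2 * i], vec[2 * i + 1]
        out[2 * i] = a * c - b * s
        out[2 * i + 1] = a * s + b * c
    return out


def axial_rope_2d(vec, row, col):
    """Assumed axial split: first half encodes the row, second half the column."""
    half = len(vec) // 2
    return rope_1d(vec[:half], row) + rope_1d(vec[half:], col)
```

With this construction the dot product between a rotated query and a rotated key depends only on the row and column offsets between the two tokens, which is the property attention relies on.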
|
| 120 |
+
|
| 121 |
+
This RoPE choice is shared by the encoder blocks and DINO alignment head.
|
| 122 |
+
|
| 123 |
+
## 4. Decoder
|
| 124 |
+
|
| 125 |
+
The decoder is the same FCDM decoder family used by full_capacitor and
|
| 126 |
+
capacitor_decoder:
|
| 127 |
+
|
| 128 |
+
- 8 FCDM decoder blocks.
|
| 129 |
+
- Width 896.
|
| 130 |
+
- Patch size 16 (one latent token per 16x16 pixel patch).
|
| 131 |
+
- 128 latent channels.
|
| 132 |
+
- Start/middle/end skip-concat architecture with 2 start blocks and 2 end
|
| 133 |
+
blocks.
|
| 134 |
+
- Depthwise convolution kernel size 7.
|
| 135 |
+
- GELU MLP activations and SiLU convolution activations.
|
| 136 |
+
|
| 137 |
+
Each decoder FCDM block uses the same single residual path as SemDisDiffAE:
|
| 138 |
+
|
| 139 |
+
```text
|
| 140 |
+
x
|
| 141 |
+
-> depthwise Conv 7x7
|
| 142 |
+
-> RMSNorm
|
| 143 |
+
-> scale modulation from timestep AdaLN
|
| 144 |
+
-> pointwise projection
|
| 145 |
+
-> GELU
|
| 146 |
+
-> GRN
|
| 147 |
+
-> pointwise projection
|
| 148 |
+
-> gate modulation from timestep AdaLN
|
| 149 |
+
-> residual add
|
| 150 |
+
```
|
| 151 |
+
|
| 152 |
+
Timestep conditioning uses the low-rank AdaLN scheme from the capacitor
|
| 153 |
+
decoder: a shared base projection plus per-layer low-rank deltas, split into
|
| 154 |
+
`scale` and `gate`. The decoder topology is:
|
| 155 |
+
|
| 156 |
+
```text
|
| 157 |
+
noisy image x_t
|
| 158 |
+
-> stride-16 image patch projection
|
| 159 |
+
-> concatenate projected latents
|
| 160 |
+
-> 2 start FCDM blocks
|
| 161 |
+
-> 4 middle FCDM blocks
|
| 162 |
+
-> concatenate start and middle activations
|
| 163 |
+
-> 2 end FCDM blocks
|
| 164 |
+
-> patch-output projection + PixelShuffle(16)
|
| 165 |
+
-> x0 prediction
|
| 166 |
+
```
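
The low-rank AdaLN conditioning described before the topology diagram can be sketched as a shared projection plus a per-layer rank-`r` correction. This is a sketch under assumptions: the matrix layout, the names `base_weight`, `down`, and `up`, and the ordering of `scale` before `gate` in the split are all hypothetical, and how the block applies `scale` and `gate` is not shown.

```python
def matvec(matrix, vec):
    """Multiply a row-major matrix (list of rows) by a vector."""
    return [sum(w * v for w, v in zip(row, vec)) for row in matrix]


def low_rank_adaln(cond, base_weight, down, up):
    """Return (scale, gate) for one layer.

    base_weight: [2*C][D], shared by all layers.
    down:        [r][D],   per-layer low-rank factor.
    up:          [2*C][r], per-layer low-rank factor.
    """
    base = matvec(base_weight, cond)
    delta = matvec(up, matvec(down, cond))
    out = [b + d for b, d in zip(base, delta)]
    channels = len(out) // 2
    # Assumed layout: first half is scale, second half is gate.
    return out[:channels], out[channels:]


cond = [1.0, 2.0]
base_weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0]]
scale, gate = low_rank_adaln(cond, base_weight, [[1.0, 1.0]], [[0.0], [1.0], [0.0], [2.0]])
print(scale, gate)  # [1.0, 5.0] [3.0, 6.0]
```

The design keeps most conditioning parameters in the shared base projection, so each extra layer only pays for its two small factors.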
|
| 167 |
+
|
| 168 |
+
DINAC-AE keeps this decoder path; the architectural changes are in the encoder
|
| 169 |
+
and alignment head.
|
| 170 |
+
|
| 171 |
+
## 5. DINOv3 Alignment
|
| 172 |
+
|
| 173 |
+
The alignment target is DINOv3 ViT-B/16 trained on LVD1689M:
|
| 174 |
+
|
| 175 |
+
- Teacher: `dino_v3_vit_base_patch16_lvd1689m`.
|
| 176 |
+
- Feature type: `dino_v3_vit_base_patch16_tokens`.
|
| 177 |
+
- Target feature dimension: 768.
|
| 178 |
+
- The loss supervises both the DINO class token and DINO spatial patch tokens.
|
| 179 |
+
|
| 180 |
+
The alignment head first maps unwhitened DINAC-AE latents into DINO token
|
| 181 |
+
space:
|
| 182 |
+
|
| 183 |
+
```text
|
| 184 |
+
latents [B, 128, h, w]
|
| 185 |
+
-> pointwise projection (128 -> 768)
|
| 186 |
+
-> flatten spatial tokens [B, h*w, 768]
|
| 187 |
+
-> prepend learned class token [B, 1, 768]
|
| 188 |
+
-> prepend 4 learned register tokens [B, 4, 768]
|
| 189 |
+
-> 1 unconditional RoPE `DitBlock` over all tokens
|
| 190 |
+
-> RMSNorm(spatial tokens)
|
| 191 |
+
-> spatial negative-cosine target
|
| 192 |
+
```
|
| 193 |
+
|
| 194 |
+
The token-alignment block uses the same unconditional RoPE `DitBlock` form as
|
| 195 |
+
the encoder, but at DINO width: `768` channels, `12` heads, head dimension `64`,
|
| 196 |
+
MLP ratio `4.0`, and GELU MLP. The learned prefix tokens receive identity RoPE
|
| 197 |
+
rotation; spatial tokens receive the same axial 2D RoPE configuration as the
|
| 198 |
+
encoder.
|
| 199 |
+
|
| 200 |
+
The class-token path uses an additional residual cross-attention block:
|
| 201 |
+
|
| 202 |
+
```text
|
| 203 |
+
updated class token [B, 1, 768]
|
| 204 |
+
updated register + spatial tokens [B, 4 + h*w, 768]
|
| 205 |
+
-> class query cross-attends to register + spatial tokens
|
| 206 |
+
-> residual add
|
| 207 |
+
-> RMSNorm
|
| 208 |
+
-> biasless Linear(768 -> 3072)
|
| 209 |
+
-> GELU
|
| 210 |
+
-> biasless Linear(3072 -> 768)
|
| 211 |
+
-> residual add
|
| 212 |
+
-> RMSNorm(class token)
|
| 213 |
+
-> class negative-cosine target
|
| 214 |
+
```
|
| 215 |
+
|
| 216 |
+
The class readout uses the same DINO width and head geometry as the
|
| 217 |
+
token-alignment block, but replaces self-attention with dense cross-attention.
|
| 218 |
+
|
| 219 |
+
The DINO alignment loss is negative cosine on RMS-normalized features:
|
| 220 |
+
|
| 221 |
+
```text
|
| 222 |
+
class_negative_cosine_loss + spatial_negative_cosine_loss
|
| 223 |
+
```
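
A minimal sketch of this loss on plain Python lists is given below. It assumes that the spatial term is averaged over patch tokens and that RMS normalization carries no learnable gain; the report does not state either detail, and the function names are hypothetical.

```python
import math


def rms_normalize(vec, eps=1e-6):
    """Scale vec to unit root-mean-square (eps is an assumed value)."""
    rms = math.sqrt(sum(v * v for v in vec) / len(vec) + eps)
    return [v / rms for v in vec]


def negative_cosine(pred, target, eps=1e-8):
    """Return -cos(pred, target) computed on RMS-normalized vectors."""
    p, t = rms_normalize(pred), rms_normalize(target)
    dot = sum(a * b for a, b in zip(p, t))
    norm = math.sqrt(sum(a * a for a in p)) * math.sqrt(sum(b * b for b in t))
    return -dot / (norm + eps)


def alignment_loss(class_pred, class_target, patch_preds, patch_targets):
    """Class term plus spatial term; the mean over patch tokens is an assumption."""
    class_term = negative_cosine(class_pred, class_target)
    spatial_terms = [negative_cosine(p, t) for p, t in zip(patch_preds, patch_targets)]
    return class_term + sum(spatial_terms) / len(spatial_terms)
```

With perfectly aligned predictions each term approaches `-1`, so the combined loss approaches `-2` before the `0.01` weight is applied.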
|
| 224 |
+
|
| 225 |
+
The DINO alignment weight is `0.01`. Alignment is applied directly to clean
|
| 226 |
+
latent tokens. Robustness to local token errors is handled separately by
|
| 227 |
+
[random-token logSNR offset regularization](#8-logsnr-offset-regularization).
|
| 228 |
+
|
| 229 |
+
## 6. Class-Token Output
|
| 230 |
+
|
| 231 |
+
`predict_class(latents)` is part of the public API. It expects the same
|
| 232 |
+
whitened latent convention returned by `encode(...)`, applies dewhitening, and
|
| 233 |
+
runs the DINO alignment head. The returned tensor has shape `[B, 768]` and
|
| 234 |
+
lives in the DINOv3 ViT-B/16 class-token feature space.
|
| 235 |
+
|
| 236 |
+
This output has two intended uses:
|
| 237 |
+
|
| 238 |
+
- It adds a global semantic pressure during autoencoder training, complementing
|
| 239 |
+
the spatial patch-token alignment loss.
|
| 240 |
+
- It provides a latent-space feature endpoint for FD-loss / Representation
|
| 241 |
+
Frechet Distance objectives, avoiding the cost and gradient path of decoding
|
| 242 |
+
latents back to RGB before computing representation statistics.
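
To make the second use concrete, the sketch below computes a Frechet distance between two sets of feature vectors, such as `predict_class(...)` outputs for real and generated latents. It is deliberately simplified: covariances are assumed diagonal, which reduces the trace term to a sum over per-dimension standard deviations. This is not the FD-Loss implementation referenced by the report, which should be consulted for the full-covariance form.

```python
import math


def mean_and_variance(features):
    """Per-dimension mean and population variance of a list of feature vectors."""
    n, dim = len(features), len(features[0])
    mean = [sum(f[d] for f in features) / n for d in range(dim)]
    var = [sum((f[d] - mean[d]) ** 2 for f in features) / n for d in range(dim)]
    return mean, var


def diagonal_frechet_distance(features_a, features_b):
    """Frechet distance between two Gaussians with diagonal covariance (simplification)."""
    mean_a, var_a = mean_and_variance(features_a)
    mean_b, var_b = mean_and_variance(features_b)
    mean_term = sum((a - b) ** 2 for a, b in zip(mean_a, mean_b))
    cov_term = sum((math.sqrt(a) - math.sqrt(b)) ** 2 for a, b in zip(var_a, var_b))
    return mean_term + cov_term
```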
|
| 243 |
+
|
| 244 |
+
The class-token output is trained by negative cosine alignment to the frozen
|
| 245 |
+
DINOv3 class token.
|
| 246 |
+
|
| 247 |
+
## 7. Training Losses
|
| 248 |
+
|
| 249 |
+
The checkpoint was trained with these active loss terms:
|
| 250 |
+
|
| 251 |
+
| Loss term | Weight / value |
|
| 252 |
+
| --- | ---: |
|
| 253 |
+
| Main VP diffusion reconstruction loss | 1.0 |
|
| 254 |
+
| DINO class + spatial token alignment | 0.01 |
|
| 255 |
+
| Latent posterior VE variance loss | 0.00003 |
|
| 256 |
+
| Latent log-variance scale penalty | 0.0003 |
|
| 257 |
+
|
| 258 |
+
The main decoder objective is VP diffusion `x_pred` with the SID2 x-prediction
|
| 259 |
+
variant. The timestep sampler is uniform, with logSNR range `[-10, 10]` and
|
| 260 |
+
training logSNR shift `-1.0`.
|
| 261 |
+
|
| 262 |
+
The latent posterior regularization follows the same KL-like variance expansion
|
| 263 |
+
idea described in the SemDisDiffAE report:
|
| 264 |
+
https://huggingface.co/data-archetype/semdisdiffae/blob/main/technical_report_semantic.md#32-variance-expansion-loss
|
| 265 |
+
|
| 266 |
+
Latent running statistics use momentum `0.0005` and epsilon `0.0001`.
|
| 267 |
+
`encode(...)` returns whitened latents; `decode(...)` and
|
| 268 |
+
`predict_class(...)` dewhiten those latents before applying the decoder or
|
| 269 |
+
class output path.
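
The whitening convention can be sketched as follows. The momentum and epsilon values come from the report; the per-channel statistics, the exponential-moving-average update rule, and the class name `RunningWhitener` are assumptions made for illustration.

```python
import math

MOMENTUM = 0.0005  # from the report
EPSILON = 0.0001   # from the report


class RunningWhitener:
    """Hypothetical per-channel running mean/variance tracker."""

    def __init__(self, channels: int) -> None:
        self.mean = [0.0] * channels
        self.var = [1.0] * channels

    def update(self, batch_mean, batch_var) -> None:
        # Assumed EMA update: stat += momentum * (batch_stat - stat).
        for c in range(len(self.mean)):
            self.mean[c] += MOMENTUM * (batch_mean[c] - self.mean[c])
            self.var[c] += MOMENTUM * (batch_var[c] - self.var[c])

    def whiten(self, latent):
        return [(z - m) / math.sqrt(v + EPSILON)
                for z, m, v in zip(latent, self.mean, self.var)]

    def dewhiten(self, latent):
        return [z * math.sqrt(v + EPSILON) + m
                for z, m, v in zip(latent, self.mean, self.var)]
```

Under these assumptions `dewhiten(whiten(z))` returns `z`, which matches the round trip between `encode(...)` and the dewhitening performed inside `decode(...)` and `predict_class(...)`.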
|
| 270 |
+
|
| 271 |
+
## 8. LogSNR Offset Regularization
|
| 272 |
+
|
| 273 |
+
During training, 10% of spatial latent tokens across the batch/grid are
|
| 274 |
+
selected at random and their posterior logSNR is shifted by `-2.0` before
|
| 275 |
+
sampling.
|
| 276 |
+
|
| 277 |
+
This injects non-smooth token-level latent errors during training. The decoder
|
| 278 |
+
and DINO alignment heads both see these perturbed latents, encouraging
|
| 279 |
+
robustness to downstream DiT prediction mistakes.
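
A minimal sketch of the token selection and shift is shown below. The 10% fraction and the `-2.0` offset come from the report; selecting an exact count of tokens (rather than an independent 10% draw per token) and the function name are assumptions.

```python
import random

OFFSET_FRACTION = 0.10  # from the report
LOGSNR_OFFSET = -2.0    # from the report


def apply_logsnr_offset(log_snr_tokens, rng: random.Random):
    """Shift the logSNR of a random 10% of tokens by -2.0.

    log_snr_tokens is a flat list with one logSNR value per spatial token.
    Returns the shifted list and the set of selected indices.
    """
    n = len(log_snr_tokens)
    k = round(OFFSET_FRACTION * n)
    chosen = set(rng.sample(range(n), k))
    shifted = [v + LOGSNR_OFFSET if i in chosen else v
               for i, v in enumerate(log_snr_tokens)]
    return shifted, chosen
```

Lowering the logSNR of a token raises its posterior noise level, so the selected tokens are sampled with visibly larger errors than their neighbours.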
|
| 280 |
+
|
| 281 |
+
## 9. Training Recipe
|
| 282 |
+
|
| 283 |
+
The model was trained on 12M images using a single NVIDIA RTX PRO 6000
|
| 284 |
+
Blackwell 96GB GPU. Training used two stages:
|
| 285 |
+
|
| 286 |
+
| Stage | Resolution schedule | Batch size | Approx steps |
|
| 287 |
+
| --- | --- | ---: | ---: |
|
| 288 |
+
| Stage 1 | 90% 256-scale AR buckets, 10% 384-scale AR buckets | 128 | 150k |
|
| 289 |
+
| Stage 2 | equal mix of 256/384/512/768/1024 buckets | 64 | 200k |
|
| 290 |
+
|
| 291 |
+
The mixed-resolution second stage is especially important for the transformer
|
| 292 |
+
encoder. In practice, and even with 2D RoPE, transformer blocks tolerate only
|
| 293 |
+
limited resolution extrapolation unless they see the higher-resolution patch
|
| 294 |
+
grids during training.
|
| 295 |
+
|
| 296 |
+
Optimizer/training parameters:
|
| 297 |
+
|
| 298 |
+
| Setting | Stage 1 | Stage 2 |
|
| 299 |
+
| --- | ---: | ---: |
|
| 300 |
+
| Optimizer | AdamW | AdamW |
|
| 301 |
+
| Learning rate | 1e-4 | 5e-5 |
|
| 302 |
+
| Betas | (0.9, 0.99) | (0.9, sqrt(0.98) ~= 0.98995) |
|
| 303 |
+
| Weight decay | 0.0 | 0.0 |
|
| 304 |
+
| EMA decay | 0.9995 | 0.9995 |
|
| 305 |
+
| Warmup | 2,000 steps | 10,000 steps |
|
| 306 |
+
| Precision | AMP BF16, TF32 matmul | AMP BF16, TF32 matmul |
|
| 307 |
+
| Gradient clipping | 1.0 | 1.0 |
|
| 308 |
+
| Optimizer state dtype | BF16 | BF16 |
|
| 309 |
+
|
| 310 |
+
Resolution and accumulation settings:
|
| 311 |
+
|
| 312 |
+
| Resolution | Stage 1 mix | Stage 1 grad accumulation | Stage 2 mix | Stage 2 grad accumulation |
|
| 313 |
+
| ---: | ---: | ---: | ---: | ---: |
|
| 314 |
+
| 256 | 90% | 1 | 20% | 1 |
|
| 315 |
+
| 384 | 10% | 1 | 20% | 1 |
|
| 316 |
+
| 512 | 0% | n/a | 20% | 1 |
|
| 317 |
+
| 768 | 0% | n/a | 20% | 2 |
|
| 318 |
+
| 1024 | 0% | n/a | 20% | 2 |
|
| 319 |
+
|
| 320 |
+
## 10. Reconstruction Quality
|
| 321 |
+
|
| 322 |
+
Reconstruction quality on `2000` validation images:
|
| 323 |
+
|
| 324 |
+
| Model | Mean PSNR | Std | Median | Min | p5 | p95 | Max |
|
| 325 |
+
| --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: |
|
| 326 |
+
| dinac_ae | 35.19 | 4.53 | 35.06 | 22.44 | 28.02 | 42.43 | 47.31 |
|
| 327 |
+
| FLUX.2 VAE | 36.28 | 4.53 | 36.07 | 22.73 | 28.89 | 43.63 | 47.38 |
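
For reference, PSNR in dB can be computed per image as sketched below. The `[0, 1]` data range and the flat-list input format are assumptions; the report does not state the range or color space used for the numbers in the table.

```python
import math


def psnr(reference, reconstruction, data_range: float = 1.0) -> float:
    """Peak signal-to-noise ratio in dB between two flat lists of pixel values."""
    mse = sum((a - b) ** 2 for a, b in zip(reference, reconstruction)) / len(reference)
    if mse == 0.0:
        return math.inf
    return 10.0 * math.log10(data_range ** 2 / mse)
```

A uniform error of `0.1` on a `[0, 1]` scale gives an MSE of `0.01`, which corresponds to roughly 20 dB.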
|
| 328 |
+
|
| 329 |
+
The 39-image reconstruction viewer includes originals, DINAC-AE
|
| 330 |
+
reconstructions, FLUX.2 VAE references, RGB deltas, and latent PCA:
|
| 331 |
+
https://huggingface.co/spaces/data-archetype/dinac_ae-results
|
| 332 |
+
|
| 333 |
+
Re-running the released export on that viewer set gives `35.15 dB` mean PSNR
|
| 334 |
+
(`25.73` min, `45.99` max).
|
| 335 |
+
|
| 336 |
+
## 11. Class-Token Alignment Results
|
| 337 |
+
|
| 338 |
+
The `predict_class(...)` path is evaluated against the frozen DINOv3 ViT-B/16
|
| 339 |
+
teacher class token on the same `2000` images used for reconstruction PSNR.
|
| 340 |
+
|
| 341 |
+
| Metric | Cosine similarity |
|
| 342 |
+
| --- | ---: |
|
| 343 |
+
| Mean | 0.757458 |
|
| 344 |
+
| Std | 0.076265 |
|
| 345 |
+
| Median | 0.765958 |
|
| 346 |
+
| Min | 0.394647 |
|
| 347 |
+
| p5 | 0.623156 |
|
| 348 |
+
| p10 | 0.656243 |
|
| 349 |
+
| p25 | 0.711337 |
|
| 350 |
+
| p75 | 0.813098 |
|
| 351 |
+
| p90 | 0.849525 |
|
| 352 |
+
| p95 | 0.865219 |
|
| 353 |
+
| Max | 0.932722 |
|
| 354 |
+
|
| 355 |
+
## 12. Encoder Throughput
|
| 356 |
+
|
| 357 |
+
Encoder timing was measured with the released package on an NVIDIA GeForce RTX
|
| 358 |
+
5090. Decoder timing is unchanged from the capacitor decoder/full_capacitor
|
| 359 |
+
release because DINAC-AE uses the same decoder architecture.
|
| 360 |
+
|
| 361 |
+
| Resolution | Batch | FLUX.2 encode ms/batch | full_capacitor ms/batch | dinac_ae ms/batch | Speedup vs FLUX.2 | dinac_ae vs full_capacitor |
|
| 362 |
+
| --- | ---: | ---: | ---: | ---: | ---: | ---: |
|
| 363 |
+
| 256x256 | 128 | 383.41 | 42.56 | 50.32 | 7.62x | 1.18x slower |
|
| 364 |
+
| 512x512 | 32 | 353.58 | 44.97 | 52.65 | 6.72x | 1.17x slower |
|
| 365 |
+
|
| 366 |
+
Peak allocated encoder memory:
|
| 367 |
+
|
| 368 |
+
| Resolution | Batch | FLUX.2 MiB | full_capacitor MiB | dinac_ae MiB | Reduction vs FLUX.2 |
|
| 369 |
+
| --- | ---: | ---: | ---: | ---: | ---: |
|
| 370 |
+
| 256x256 | 128 | 12,511.0 | 1,008.2 | 1,637.3 | 86.9% |
|
| 371 |
+
| 512x512 | 32 | 12,511.0 | 1,005.6 | 1,638.5 | 86.9% |
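
The derived columns in the two tables above follow directly from the measured values. The short sketch below reproduces them; the helper names are hypothetical.

```python
def speedup(baseline_ms: float, model_ms: float) -> float:
    """How many times faster the model is than the baseline."""
    return baseline_ms / model_ms


def memory_reduction_pct(baseline_mib: float, model_mib: float) -> float:
    """Percentage of baseline peak memory saved by the model."""
    return 100.0 * (1.0 - model_mib / baseline_mib)


# Measured values copied from the tables (256x256, batch 128).
print(round(speedup(383.41, 50.32), 2))                 # 7.62
print(round(speedup(50.32, 42.56), 2))                  # 1.18 (dinac_ae slower)
print(round(memory_reduction_pct(12511.0, 1637.3), 1))  # 86.9
```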
|
| 372 |
+
|
| 373 |
+
The transformer encoder is slightly slower and uses more memory than the
|
| 374 |
+
full_capacitor FCDM encoder, but it remains much faster and much smaller than
|
| 375 |
+
the FLUX.2 VAE encoder.
|
| 376 |
+
|
| 377 |
+
## References
|
| 378 |
+
|
| 379 |
+
- Oriane Siméoni, Huy V. Vo, Maximilian Seitzer, Federico Baldassarre,
|
| 380 |
+
Maxime Oquab, Cijo Jose, Vasil Khalidov, Marc Szafraniec, Seungeun Yi,
|
| 381 |
+
Michaël Ramamonjisoa, Francisco Massa, Daniel Haziza, Luca Wehrstedt,
|
| 382 |
+
Jianyuan Wang, Timothée Darcet, Théo Moutakanni, Leonel Sentana,
|
| 383 |
+
Claire Roberts, Andrea Vedaldi, Jamie Tolan, John Brandt, Camille Couprie,
|
| 384 |
+
Julien Mairal, Hervé Jégou, Patrick Labatut, and Piotr Bojanowski.
|
| 385 |
+
DINOv3. arXiv:2508.10104, 2025.
|
| 386 |
+
https://arxiv.org/abs/2508.10104
|
| 387 |
+
- Jiawei Yang, Zhengyang Geng, Xuan Ju, Yonglong Tian, and Yue Wang.
|
| 388 |
+
Representation Frechet Loss for Visual Generation. arXiv:2604.28190, 2026.
|
| 389 |
+
https://arxiv.org/abs/2604.28190
|
| 390 |
+
- FD-Loss implementation: https://github.com/Jiawei-Yang/FD-Loss
|